[DMAAP-DR] Another fix for sql injection
[dmaap/datarouter.git] / datarouter-prov / src / main / java / org / onap / dmaap / datarouter / provisioning / StatisticsServlet.java
index 964ef03..6049eea 100755 (executable)
@@ -64,6 +64,7 @@ public class StatisticsServlet extends BaseServlet {
     private static final String FMT1 = "yyyy-MM-dd'T'HH:mm:ss'Z'";\r
     private static final String FMT2 = "yyyy-MM-dd'T'HH:mm:ss.SSS'Z'";\r
     public static final String FEEDID = "FEEDID";\r
+    public static final String START = "start";\r
 \r
     //sql Strings\r
     private static final String SQL_SELECT_NAME = "SELECT (SELECT NAME FROM FEEDS AS f WHERE f.FEEDID in(";\r
@@ -301,11 +302,12 @@ public class StatisticsServlet extends BaseServlet {
 \r
 \r
     /**\r
-     * queryGeneretor - Generating sql query.\r
+     * getResultSet - Set the result from the query.\r
      *\r
      * @param map as key value pare of all user input fields\r
      */\r
-    private PreparedStatement queryGeneretor(Map<String, String> map) throws ParseException, SQLException {\r
+    private void getResultSet(Map<String, String> map, String outputType, ServletOutputStream out,\r
+        HttpServletResponse resp) throws ParseException, SQLException, IOException {\r
 \r
         String sql;\r
         String feedids = null;\r
@@ -315,8 +317,8 @@ public class StatisticsServlet extends BaseServlet {
         long startInMillis = 0;\r
         long endInMillis = 0;\r
         String subid = " ";\r
-        String compareType = null;\r
-        PreparedStatement ps = null;\r
+        String compareType;\r
+        ResultSet rs;\r
 \r
         if (map.get(FEEDIDS) != null) {\r
             feedids = map.get(FEEDIDS);\r
@@ -334,8 +336,7 @@ public class StatisticsServlet extends BaseServlet {
         if (endTime == null && startTime == null) {\r
             sql =  SQL_SELECT_NAME + "?" + SQL_FEED_ID + SQL_SELECT_COUNT + "?" + SQL_TYPE_PUB + SQL_SELECT_SUM\r
                 + "?" + SQL_PUBLISH_LENGTH + SQL_SUBSCRIBER_URL + SQL_SUB_ID + SQL_DELIVERY_TIME + SQL_AVERAGE_DELAY\r
-                + SQL_JOIN_RECORDS + "?" + ") " + "?" + SQL_STATUS_204\r
-                + SQL_GROUP_SUB_ID;\r
+                + SQL_JOIN_RECORDS + "?" + ") " + SQL_STATUS_204 + SQL_GROUP_SUB_ID;\r
             compareType = "default";\r
         } else if (startTime != null && endTime == null) {\r
             long inputTimeInMilli = 60000 * Long.parseLong(startTime);\r
@@ -344,9 +345,9 @@ public class StatisticsServlet extends BaseServlet {
             compareTime = currentTimeInMilli - inputTimeInMilli;\r
             sql = SQL_SELECT_NAME + "?" + SQL_FEED_ID + SQL_SELECT_COUNT + "?" + SQL_TYPE_PUB + SQL_SELECT_SUM\r
                 + "?" + SQL_PUBLISH_LENGTH + SQL_SUBSCRIBER_URL + SQL_SUB_ID + SQL_DELIVERY_TIME + SQL_AVERAGE_DELAY\r
-                + SQL_JOIN_RECORDS + "?" + ") " + "?" + SQL_STATUS_204\r
+                + SQL_JOIN_RECORDS + "?" + ") " + SQL_STATUS_204\r
                 + " and e.event_time>=" + "?" + SQL_GROUP_SUB_ID;\r
-            compareType = "start";\r
+            compareType = START;\r
         } else {\r
             SimpleDateFormat inFormat = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss");\r
             Date startDate = inFormat.parse(startTime);\r
@@ -355,32 +356,28 @@ public class StatisticsServlet extends BaseServlet {
             endInMillis = endDate.getTime();\r
             sql = SQL_SELECT_NAME + "?" + SQL_FEED_ID + SQL_SELECT_COUNT + "?" + SQL_TYPE_PUB + SQL_SELECT_SUM\r
                 + "?" + SQL_PUBLISH_LENGTH + SQL_SUBSCRIBER_URL + SQL_SUB_ID + SQL_DELIVERY_TIME + SQL_AVERAGE_DELAY\r
-                + SQL_JOIN_RECORDS + "?" + ") " + "?" + SQL_STATUS_204\r
+                + SQL_JOIN_RECORDS + "?" + ") " +  SQL_STATUS_204\r
                 + " and e.event_time between " + "?" + " and " + "?" + SQL_GROUP_SUB_ID;\r
             compareType = "startAndEnd";\r
         }\r
-        try (Connection conn = ProvDbUtils.getInstance().getConnection()) {\r
+        try (Connection conn = ProvDbUtils.getInstance().getConnection();\r
+            PreparedStatement ps = conn.prepareStatement(sql)) {\r
             eventlogger.debug("SQL Query for Statistics resultset. " + sql);\r
             intlogger.debug(sql);\r
-            ps = conn.prepareStatement(sql);\r
             ps.setString(1, feedids);\r
             ps.setString(2, feedids);\r
             ps.setString(3, feedids);\r
             ps.setString(4, feedids);\r
-            ps.setString(5, subid);\r
-            if (compareType.equals("start")) {\r
-                ps.setLong(6, compareTime);\r
+            if (compareType.equals(START)) {\r
+                ps.setLong(5, compareTime);\r
             }\r
             if (compareType.equals("startAndEnd")) {\r
-                ps.setLong(6, startInMillis);\r
-                ps.setLong(7, endInMillis);\r
-            }\r
-        } finally {\r
-            if (ps != null) {\r
-                ps.close();\r
+                ps.setLong(5, startInMillis);\r
+                ps.setLong(6, endInMillis);\r
             }\r
+            rs = ps.executeQuery();\r
+            parseResult(outputType, out, resp, rs);\r
         }\r
-        return ps;\r
     }\r
 \r
 \r
@@ -491,7 +488,7 @@ public class StatisticsServlet extends BaseServlet {
             }\r
         }\r
 \r
-        long stime = getTimeFromParam(req.getParameter("start"));\r
+        long stime = getTimeFromParam(req.getParameter(START));\r
         if (stime < 0) {\r
             map.put("err", "bad start");\r
             return map;\r
@@ -540,12 +537,12 @@ public class StatisticsServlet extends BaseServlet {
         HttpServletResponse resp) {\r
         eventlogger.info("Generating sql query to get Statistics resultset. ");\r
         try {\r
-            PreparedStatement ps = this.queryGeneretor(map);\r
             long start = System.currentTimeMillis();\r
-            executeQuery(outputType, out, resp, ps);\r
+            this.getResultSet(map, outputType, out, resp);\r
             intlogger.debug("Time: " + (System.currentTimeMillis() - start) + " ms");\r
         } catch (SQLException e) {\r
-            eventlogger.error("SQLException:" + e);\r
+            eventlogger.error("SQLException:" + e.getMessage());\r
+            e.printStackTrace();\r
         } catch (IOException e) {\r
             eventlogger.error("IOException - Generating JSON/CSV:" + e);\r
         } catch (JSONException e) {\r
@@ -555,22 +552,18 @@ public class StatisticsServlet extends BaseServlet {
         }\r
     }\r
 \r
-    private void executeQuery(String outputType, ServletOutputStream out, HttpServletResponse resp,\r
-        PreparedStatement ps) throws IOException {\r
-        try (ResultSet rs = ps.executeQuery()) {\r
-            if (CSV_OUTPUT_TYPE.equals(outputType)) {\r
-                resp.setContentType("application/octet-stream");\r
-                DateTimeFormatter formatter = DateTimeFormatter.ofPattern("dd-MM-yyyy HH:mm:ss");\r
-                resp.setHeader("Content-Disposition",\r
-                    "attachment; filename=\"result:" + LocalDateTime.now().format(formatter) + ".csv\"");\r
-                eventlogger.info("Generating CSV file from Statistics resultset");\r
-                rsToCSV(rs, out);\r
-            } else {\r
-                eventlogger.info("Generating JSON for Statistics resultset");\r
-                this.rsToJson(rs, out);\r
-            }\r
-        } catch (SQLException e) {\r
-            eventlogger.error("SQLException:" + e);\r
+    private void parseResult(String outputType, ServletOutputStream out, HttpServletResponse resp,\r
+        ResultSet rs) throws IOException, SQLException {\r
+        if (CSV_OUTPUT_TYPE.equals(outputType)) {\r
+            resp.setContentType("application/octet-stream");\r
+            DateTimeFormatter formatter = DateTimeFormatter.ofPattern("dd-MM-yyyy HH:mm:ss");\r
+            resp.setHeader("Content-Disposition",\r
+                "attachment; filename=\"result:" + LocalDateTime.now().format(formatter) + ".csv\"");\r
+            eventlogger.info("Generating CSV file from Statistics resultset");\r
+            rsToCSV(rs, out);\r
+        } else {\r
+            eventlogger.info("Generating JSON for Statistics resultset");\r
+            this.rsToJson(rs, out);\r
         }\r
     }\r
 }\r