[DMAAP-DR] Fix sql injection bug 72/127572/1
authorefiacor <fiachra.corcoran@est.tech>
Wed, 9 Mar 2022 11:48:35 +0000 (11:48 +0000)
committerefiacor <fiachra.corcoran@est.tech>
Wed, 9 Mar 2022 11:48:41 +0000 (11:48 +0000)
Signed-off-by: efiacor <fiachra.corcoran@est.tech>
Change-Id: Icccc65b3b90c553dea74d95bf247b08ae6f78506
Issue-ID: DMAAP-1623

datarouter-prov/src/main/java/org/onap/dmaap/datarouter/provisioning/StatisticsServlet.java
datarouter-prov/src/test/java/org/onap/dmaap/datarouter/provisioning/StatisticsServletTest.java

index 4bc3977..964ef03 100755 (executable)
@@ -163,33 +163,33 @@ public class StatisticsServlet extends BaseServlet {
             map.put(OUTPUT_TYPE, JSON_OUTPUT_TYPE);\r
         }\r
         if (req.getParameter(START_TIME) != null) {\r
-            String start_time = req.getParameter(START_TIME);\r
-            try{\r
-                Long.parseLong(start_time);\r
-                map.put(START_TIME, start_time);\r
+            String startTime = req.getParameter(START_TIME);\r
+            try {\r
+                Long.parseLong(startTime);\r
+                map.put(START_TIME, startTime);\r
             }\r
-            catch (NumberFormatException e){\r
+            catch (NumberFormatException e) {\r
                 eventlogger.error("Invalid start time StatisticsServlet.doGet: " +  e.getMessage(), e);\r
             }\r
         }\r
         if (req.getParameter(END_TIME) != null) {\r
-            String end_time = req.getParameter(END_TIME);\r
-            try{\r
-                Long.parseLong(end_time);\r
-                map.put(END_TIME, end_time);\r
+            String endTime = req.getParameter(END_TIME);\r
+            try {\r
+                Long.parseLong(endTime);\r
+                map.put(END_TIME, endTime);\r
             }\r
-            catch (NumberFormatException e){\r
+            catch (NumberFormatException e) {\r
                 eventlogger.error("Invalid end time StatisticsServlet.doGet: " +  e.getMessage(), e);\r
             }\r
         }\r
         if (req.getParameter("time") != null) {\r
             String time = req.getParameter("time");\r
-            try{\r
+            try {\r
                 Long.parseLong(time);\r
                 map.put(START_TIME, time);\r
                 map.put(END_TIME, null);\r
             }\r
-            catch (NumberFormatException e){\r
+            catch (NumberFormatException e) {\r
                 eventlogger.error("Invalid end time StatisticsServlet.doGet: " +  e.getMessage(), e);\r
             }\r
         }\r
@@ -201,11 +201,6 @@ public class StatisticsServlet extends BaseServlet {
 \r
     }\r
 \r
-    private boolean validateDateInput(String date){\r
-\r
-        return true;\r
-    }\r
-\r
     /**\r
      * rsToJson - Converting RS to JSON object.\r
      *\r
@@ -310,13 +305,18 @@ public class StatisticsServlet extends BaseServlet {
      *\r
      * @param map as key value pare of all user input fields\r
      */\r
-    private String queryGeneretor(Map<String, String> map) throws ParseException {\r
+    private PreparedStatement queryGeneretor(Map<String, String> map) throws ParseException, SQLException {\r
 \r
         String sql;\r
         String feedids = null;\r
         String startTime = null;\r
         String endTime = null;\r
+        long compareTime = 0;\r
+        long startInMillis = 0;\r
+        long endInMillis = 0;\r
         String subid = " ";\r
+        String compareType = null;\r
+        PreparedStatement ps = null;\r
 \r
         if (map.get(FEEDIDS) != null) {\r
             feedids = map.get(FEEDIDS);\r
@@ -331,49 +331,56 @@ public class StatisticsServlet extends BaseServlet {
             subid = map.get(SUBID);\r
         }\r
 \r
-        eventlogger.info("Generating sql query to get Statistics resultset. ");\r
-\r
         if (endTime == null && startTime == null) {\r
-\r
-            sql =  SQL_SELECT_NAME + feedids + SQL_FEED_ID + SQL_SELECT_COUNT + feedids + SQL_TYPE_PUB\r
-                + SQL_SELECT_SUM\r
-                + feedids + SQL_PUBLISH_LENGTH\r
-                + SQL_SUBSCRIBER_URL + SQL_SUB_ID + SQL_DELIVERY_TIME + SQL_AVERAGE_DELAY + SQL_JOIN_RECORDS\r
-                + feedids + ") " + subid\r
-                + SQL_STATUS_204 + SQL_GROUP_SUB_ID;\r
-\r
-            return sql;\r
+            sql =  SQL_SELECT_NAME + "?" + SQL_FEED_ID + SQL_SELECT_COUNT + "?" + SQL_TYPE_PUB + SQL_SELECT_SUM\r
+                + "?" + SQL_PUBLISH_LENGTH + SQL_SUBSCRIBER_URL + SQL_SUB_ID + SQL_DELIVERY_TIME + SQL_AVERAGE_DELAY\r
+                + SQL_JOIN_RECORDS + "?" + ") " + "?" + SQL_STATUS_204\r
+                + SQL_GROUP_SUB_ID;\r
+            compareType = "default";\r
         } else if (startTime != null && endTime == null) {\r
-\r
             long inputTimeInMilli = 60000 * Long.parseLong(startTime);\r
             Calendar cal = Calendar.getInstance(TimeZone.getTimeZone("GMT"));\r
             long currentTimeInMilli = cal.getTimeInMillis();\r
-            long compareTime = currentTimeInMilli - inputTimeInMilli;\r
-\r
-            sql = SQL_SELECT_NAME + feedids + SQL_FEED_ID + SQL_SELECT_COUNT + feedids + SQL_TYPE_PUB\r
-                + SQL_SELECT_SUM\r
-                + feedids + SQL_PUBLISH_LENGTH\r
-                + SQL_SUBSCRIBER_URL + SQL_SUB_ID + SQL_DELIVERY_TIME + SQL_AVERAGE_DELAY + SQL_JOIN_RECORDS\r
-                + feedids + ") " + subid\r
-                + SQL_STATUS_204 + " and e.event_time>=" + compareTime + SQL_GROUP_SUB_ID;\r
-            return sql;\r
-\r
+            compareTime = currentTimeInMilli - inputTimeInMilli;\r
+            sql = SQL_SELECT_NAME + "?" + SQL_FEED_ID + SQL_SELECT_COUNT + "?" + SQL_TYPE_PUB + SQL_SELECT_SUM\r
+                + "?" + SQL_PUBLISH_LENGTH + SQL_SUBSCRIBER_URL + SQL_SUB_ID + SQL_DELIVERY_TIME + SQL_AVERAGE_DELAY\r
+                + SQL_JOIN_RECORDS + "?" + ") " + "?" + SQL_STATUS_204\r
+                + " and e.event_time>=" + "?" + SQL_GROUP_SUB_ID;\r
+            compareType = "start";\r
         } else {\r
             SimpleDateFormat inFormat = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss");\r
             Date startDate = inFormat.parse(startTime);\r
             Date endDate = inFormat.parse(endTime);\r
-\r
-            long startInMillis = startDate.getTime();\r
-            long endInMillis = endDate.getTime();\r
-\r
-            sql = SQL_SELECT_NAME + feedids + SQL_FEED_ID + SQL_SELECT_COUNT + feedids + SQL_TYPE_PUB\r
-                + SQL_SELECT_SUM\r
-                + feedids + SQL_PUBLISH_LENGTH + SQL_SUBSCRIBER_URL\r
-                + SQL_SUB_ID + SQL_DELIVERY_TIME + SQL_AVERAGE_DELAY + SQL_JOIN_RECORDS + feedids + ")" + subid + SQL_STATUS_204\r
-                +" and e.event_time between " + startInMillis + " and " + endInMillis + SQL_GROUP_SUB_ID;\r
-\r
-            return sql;\r
+            startInMillis = startDate.getTime();\r
+            endInMillis = endDate.getTime();\r
+            sql = SQL_SELECT_NAME + "?" + SQL_FEED_ID + SQL_SELECT_COUNT + "?" + SQL_TYPE_PUB + SQL_SELECT_SUM\r
+                + "?" + SQL_PUBLISH_LENGTH + SQL_SUBSCRIBER_URL + SQL_SUB_ID + SQL_DELIVERY_TIME + SQL_AVERAGE_DELAY\r
+                + SQL_JOIN_RECORDS + "?" + ") " + "?" + SQL_STATUS_204\r
+                + " and e.event_time between " + "?" + " and " + "?" + SQL_GROUP_SUB_ID;\r
+            compareType = "startAndEnd";\r
+        }\r
+        try (Connection conn = ProvDbUtils.getInstance().getConnection()) {\r
+            eventlogger.debug("SQL Query for Statistics resultset. " + sql);\r
+            intlogger.debug(sql);\r
+            ps = conn.prepareStatement(sql);\r
+            ps.setString(1, feedids);\r
+            ps.setString(2, feedids);\r
+            ps.setString(3, feedids);\r
+            ps.setString(4, feedids);\r
+            ps.setString(5, subid);\r
+            if (compareType.equals("start")) {\r
+                ps.setLong(6, compareTime);\r
+            }\r
+            if (compareType.equals("startAndEnd")) {\r
+                ps.setLong(6, startInMillis);\r
+                ps.setLong(7, endInMillis);\r
+            }\r
+        } finally {\r
+            if (ps != null) {\r
+                ps.close();\r
+            }\r
         }\r
+        return ps;\r
     }\r
 \r
 \r
@@ -459,7 +466,7 @@ public class StatisticsServlet extends BaseServlet {
                 return map;\r
             }\r
             map.put("statusSQL", sql);\r
-            map.put("resultSQL", sql.replaceAll("STATUS", "RESULT"));\r
+            map.put("resultSQL", sql.replace("STATUS", "RESULT"));\r
         }\r
 \r
         str = req.getParameter("expiryReason");\r
@@ -531,29 +538,14 @@ public class StatisticsServlet extends BaseServlet {
 \r
     private void getRecordsForSQL(Map<String, String> map, String outputType, ServletOutputStream out,\r
         HttpServletResponse resp) {\r
+        eventlogger.info("Generating sql query to get Statistics resultset. ");\r
         try {\r
-            String filterQuery = this.queryGeneretor(map);\r
-            eventlogger.debug("SQL Query for Statistics resultset. " + filterQuery);\r
-            intlogger.debug(filterQuery);\r
+            PreparedStatement ps = this.queryGeneretor(map);\r
             long start = System.currentTimeMillis();\r
-            try (Connection conn = ProvDbUtils.getInstance().getConnection();\r
-                PreparedStatement ps = conn.prepareStatement(filterQuery);\r
-                ResultSet rs = ps.executeQuery()) {\r
-                if (CSV_OUTPUT_TYPE.equals(outputType)) {\r
-                    resp.setContentType("application/octet-stream");\r
-                    DateTimeFormatter formatter = DateTimeFormatter.ofPattern("dd-MM-yyyy HH:mm:ss");\r
-                    resp.setHeader("Content-Disposition",\r
-                        "attachment; filename=\"result:" + LocalDateTime.now().format(formatter) + ".csv\"");\r
-                    eventlogger.info("Generating CSV file from Statistics resultset");\r
-                    rsToCSV(rs, out);\r
-                } else {\r
-                    eventlogger.info("Generating JSON for Statistics resultset");\r
-                    this.rsToJson(rs, out);\r
-                }\r
-            } catch (SQLException e) {\r
-                eventlogger.error("SQLException:" + e);\r
-            }\r
+            executeQuery(outputType, out, resp, ps);\r
             intlogger.debug("Time: " + (System.currentTimeMillis() - start) + " ms");\r
+        } catch (SQLException e) {\r
+            eventlogger.error("SQLException:" + e);\r
         } catch (IOException e) {\r
             eventlogger.error("IOException - Generating JSON/CSV:" + e);\r
         } catch (JSONException e) {\r
@@ -562,5 +554,24 @@ public class StatisticsServlet extends BaseServlet {
             eventlogger.error("ParseException - executing SQL query:" + e);\r
         }\r
     }\r
+\r
+    private void executeQuery(String outputType, ServletOutputStream out, HttpServletResponse resp,\r
+        PreparedStatement ps) throws IOException {\r
+        try (ResultSet rs = ps.executeQuery()) {\r
+            if (CSV_OUTPUT_TYPE.equals(outputType)) {\r
+                resp.setContentType("application/octet-stream");\r
+                DateTimeFormatter formatter = DateTimeFormatter.ofPattern("dd-MM-yyyy HH:mm:ss");\r
+                resp.setHeader("Content-Disposition",\r
+                    "attachment; filename=\"result:" + LocalDateTime.now().format(formatter) + ".csv\"");\r
+                eventlogger.info("Generating CSV file from Statistics resultset");\r
+                rsToCSV(rs, out);\r
+            } else {\r
+                eventlogger.info("Generating JSON for Statistics resultset");\r
+                this.rsToJson(rs, out);\r
+            }\r
+        } catch (SQLException e) {\r
+            eventlogger.error("SQLException:" + e);\r
+        }\r
+    }\r
 }\r
 \r
index 1fe8d9b..b6686b0 100755 (executable)
@@ -119,7 +119,7 @@ public class StatisticsServletTest {
     ServletOutputStream outStream = mock(ServletOutputStream.class);
     when(response.getOutputStream()).thenReturn(outStream);
     statisticsServlet.doGet(request, response);
-    verify(response).setStatus(eq(HttpServletResponse.SC_OK));
+    verify(response).setStatus(HttpServletResponse.SC_OK);
   }
 
   @Test
@@ -130,7 +130,7 @@ public class StatisticsServletTest {
     ServletOutputStream outStream = mock(ServletOutputStream.class);
     when(response.getOutputStream()).thenReturn(outStream);
     statisticsServlet.doGet(request, response);
-    verify(response).setStatus(eq(HttpServletResponse.SC_OK));
+    verify(response).setStatus(HttpServletResponse.SC_OK);
   }
 
   private void buildRequestParameters() {