[DMAAP-48] Initial code import
[dmaap/datarouter.git] / datarouter-prov / src / main / java / com / att / research / datarouter / provisioning / utils / ThrottleFilter.java
1 /*******************************************************************************\r
2  * ============LICENSE_START==================================================\r
3  * * org.onap.dmaap\r
4  * * ===========================================================================\r
5  * * Copyright © 2017 AT&T Intellectual Property. All rights reserved.\r
6  * * ===========================================================================\r
7  * * Licensed under the Apache License, Version 2.0 (the "License");\r
8  * * you may not use this file except in compliance with the License.\r
9  * * You may obtain a copy of the License at\r
10  * * \r
11  *  *      http://www.apache.org/licenses/LICENSE-2.0\r
12  * * \r
13  *  * Unless required by applicable law or agreed to in writing, software\r
14  * * distributed under the License is distributed on an "AS IS" BASIS,\r
15  * * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r
16  * * See the License for the specific language governing permissions and\r
17  * * limitations under the License.\r
18  * * ============LICENSE_END====================================================\r
19  * *\r
20  * * ECOMP is a trademark and service mark of AT&T Intellectual Property.\r
21  * *\r
22  ******************************************************************************/\r
23 \r
24 \r
25 package com.att.research.datarouter.provisioning.utils;\r
26 \r
27 import java.io.IOException;\r
28 import java.io.InputStream;\r
29 import java.util.ArrayList;\r
30 import java.util.HashMap;\r
31 import java.util.List;\r
32 import java.util.Map;\r
33 import java.util.Timer;\r
34 import java.util.TimerTask;\r
35 import java.util.Vector;\r
36 \r
37 import javax.servlet.Filter;\r
38 import javax.servlet.FilterChain;\r
39 import javax.servlet.FilterConfig;\r
40 import javax.servlet.ServletException;\r
41 import javax.servlet.ServletRequest;\r
42 import javax.servlet.ServletResponse;\r
43 import javax.servlet.http.HttpServletRequest;\r
44 import javax.servlet.http.HttpServletResponse;\r
45 \r
46 import com.att.research.datarouter.provisioning.beans.Parameters;\r
47 \r
48 import org.apache.log4j.Logger;\r
49 import org.eclipse.jetty.continuation.Continuation;\r
50 import org.eclipse.jetty.continuation.ContinuationSupport;\r
51 import org.eclipse.jetty.server.AbstractHttpConnection;\r
52 import org.eclipse.jetty.server.Request;\r
53 \r
54 /**\r
55  * This filter checks /publish requests to the provisioning server to allow ill-behaved publishers to be throttled.\r
56  * It is configured via the provisioning parameter THROTTLE_FILTER.\r
57  * The THROTTLE_FILTER provisioning parameter can have these values:\r
58  * <table>\r
59  * <tr><td>(no value)</td><td>filter disabled</td></tr>\r
60  * <tr><td>off</td><td>filter disabled</td></tr>\r
61  * <tr><td>N[,M[,action]]</td><td>set N, M, and action (used in the algorithm below).\r
62  *     Action is <i>drop</i> or <i>throttle</i>.\r
63  *     If M is missing, it defaults to 5 minutes.\r
64  *     If the action is missing, it defaults to <i>drop</i>.\r
65  * </td></tr>\r
66  * </table>\r
67  * <p>\r
68  * The <i>action</i> is triggered iff:\r
69  * <ol>\r
70  * <li>the filter is enabled, and</li>\r
71  * <li>N /publish requests come to the provisioning server in M minutes\r
72  *   <ol>\r
73  *   <li>from the same IP address</li>\r
74  *   <li>for the same feed</li>\r
75  *   <li>lacking the <i>Expect: 100-continue</i> header</li>\r
76  *   </ol>\r
77  * </li>\r
78  * </ol>\r
79  * The action that can be performed (if triggered) are:\r
80  * <ol>\r
81  * <li><i>drop</i> - the connection is dropped immediately.</li>\r
82  * <li><i>throttle</i> - [not supported] the connection is put into a low priority queue with all other throttled connections.\r
83  *   These are then processed at a slower rate.  Note: this option does not work correctly, and is disabled.\r
84  *   The only action that is supported is <i>drop</i>.\r
85  * </li>\r
86  * </ol>\r
87  *\r
88  * @author Robert Eby\r
89  * @version $Id: ThrottleFilter.java,v 1.2 2014/03/12 19:45:41 eby Exp $\r
90  */\r
91 public class ThrottleFilter extends TimerTask implements Filter {\r
92         public  static final int    DEFAULT_N       = 10;\r
93         public  static final int    DEFAULT_M       = 5;\r
94         public  static final String THROTTLE_MARKER = "com.att.research.datarouter.provisioning.THROTTLE_MARKER";\r
95         private static final String JETTY_REQUEST   = "org.eclipse.jetty.server.Request";\r
96         private static final long   ONE_MINUTE      = 60000L;\r
97         private static final int    ACTION_DROP     = 0;\r
98         private static final int    ACTION_THROTTLE = 1;\r
99 \r
100         // Configuration\r
101         private static boolean enabled = false;         // enabled or not\r
102         private static int n_requests = 0;                      // number of requests in M minutes\r
103         private static int m_minutes = 0;                       // sampling period\r
104         private static int action = ACTION_DROP;        // action to take (throttle or drop)\r
105 \r
106         private static Logger logger = Logger.getLogger("com.att.research.datarouter.provisioning.internal");\r
107         private static Map<String, Counter> map = new HashMap<String, Counter>();\r
108         private static final Timer rolex = new Timer();\r
109 \r
110         @Override\r
111         public void init(FilterConfig arg0) throws ServletException {\r
112                 configure();\r
113                 rolex.scheduleAtFixedRate(this, 5*60000L, 5*60000L);    // Run once every 5 minutes to clean map\r
114         }\r
115 \r
116         /**\r
117          * Configure the throttle.  This should be called from BaseServlet.provisioningParametersChanged(), to make sure it stays up to date.\r
118          */\r
119         public static void configure() {\r
120                 Parameters p = Parameters.getParameter(Parameters.THROTTLE_FILTER);\r
121                 if (p != null) {\r
122                         try {\r
123                                 Class.forName(JETTY_REQUEST);\r
124                                 String v = p.getValue();\r
125                                 if (v != null && !v.equals("off")) {\r
126                                         String[] pp = v.split(",");\r
127                                         if (pp != null) {\r
128                                                 n_requests = (pp.length > 0) ? getInt(pp[0], DEFAULT_N) : DEFAULT_N;\r
129                                                 m_minutes  = (pp.length > 1) ? getInt(pp[1], DEFAULT_M) : DEFAULT_M;\r
130                                                 action     = (pp.length > 2 && pp[2] != null && pp[2].equalsIgnoreCase("throttle")) ? ACTION_THROTTLE : ACTION_DROP;\r
131                                                 enabled    = true;\r
132                                                 // ACTION_THROTTLE is not currently working, so is not supported\r
133                                                 if (action == ACTION_THROTTLE) {\r
134                                                         action = ACTION_DROP;\r
135                                                         logger.info("Throttling is not currently supported; action changed to DROP");\r
136                                                 }\r
137                                                 logger.info("ThrottleFilter is ENABLED for /publish requests; N="+n_requests+", M="+m_minutes+", Action="+action);\r
138                                                 return;\r
139                                         }\r
140                                 }\r
141                         } catch (ClassNotFoundException e) {\r
142                                 logger.warn("Class "+JETTY_REQUEST+" is not available; this filter requires Jetty.");\r
143                         }\r
144                 }\r
145                 logger.info("ThrottleFilter is DISABLED for /publish requests.");\r
146                 enabled = false;\r
147                 map.clear();\r
148         }\r
149         private static int getInt(String s, int deflt) {\r
150                 try {\r
151                         return Integer.parseInt(s);\r
152                 } catch (NumberFormatException x) {\r
153                         return deflt;\r
154                 }\r
155         }\r
156         @Override\r
157         public void destroy() {\r
158                 rolex.cancel();\r
159                 map.clear();\r
160         }\r
161 \r
162         @Override\r
163         public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)\r
164                 throws IOException, ServletException\r
165         {\r
166                 if (enabled && action == ACTION_THROTTLE) {\r
167                         throttleFilter((HttpServletRequest) request, (HttpServletResponse) response, chain);\r
168                 } else if (enabled) {\r
169                         dropFilter((HttpServletRequest) request, (HttpServletResponse) response, chain);\r
170                 } else {\r
171                         chain.doFilter(request, response);\r
172                 }\r
173         }\r
174         public void dropFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain)\r
175                 throws IOException, ServletException\r
176         {\r
177                 int rate = getRequestRate((HttpServletRequest) request);\r
178                 if (rate >= n_requests) {\r
179                         // drop request - only works under Jetty\r
180                         String m = String.format("Dropping connection: %s %d bad connections in %d minutes", getConnectionId((HttpServletRequest) request), rate, m_minutes);\r
181                         logger.info(m);\r
182                         Request base_request = (request instanceof Request)\r
183                                 ? (Request) request\r
184                                 : AbstractHttpConnection.getCurrentConnection().getRequest();\r
185                         base_request.getConnection().getEndPoint().close();\r
186                 } else {\r
187                         chain.doFilter(request, response);\r
188                 }\r
189         }\r
190         public void throttleFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain)\r
191                 throws IOException, ServletException\r
192         {\r
193                 // throttle request\r
194                 String id = getConnectionId((HttpServletRequest) request);\r
195                 int rate = getRequestRate((HttpServletRequest) request);\r
196                 Object results = request.getAttribute(THROTTLE_MARKER);\r
197                 if (rate >= n_requests && results == null) {\r
198                         String m = String.format("Throttling connection: %s %d bad connections in %d minutes", getConnectionId((HttpServletRequest) request), rate, m_minutes);\r
199                         logger.info(m);\r
200                         Continuation continuation = ContinuationSupport.getContinuation(request);\r
201                         continuation.suspend();\r
202                         register(id, continuation);\r
203                         continuation.undispatch();\r
204                 } else {\r
205                         chain.doFilter(request, response);\r
206                         @SuppressWarnings("resource")\r
207                         InputStream is = request.getInputStream();\r
208                         byte[] b = new byte[4096];\r
209                         int n = is.read(b);\r
210                         while (n > 0) {\r
211                                 n = is.read(b);\r
212                         }\r
213                         resume(id);\r
214                 }\r
215         }\r
216         private Map<String, List<Continuation>> suspended_requests = new HashMap<String, List<Continuation>>();\r
217         private void register(String id, Continuation continuation) {\r
218                 synchronized (suspended_requests) {\r
219                         List<Continuation> list = suspended_requests.get(id);\r
220                         if (list == null) {\r
221                                 list = new ArrayList<Continuation>();\r
222                                 suspended_requests.put(id,  list);\r
223                         }\r
224                         list.add(continuation);\r
225                 }\r
226         }\r
227         private void resume(String id) {\r
228                 synchronized (suspended_requests) {\r
229                         List<Continuation> list = suspended_requests.get(id);\r
230                         if (list != null) {\r
231                                 // when the waited for event happens\r
232                                 Continuation continuation = list.remove(0);\r
233                                 continuation.setAttribute(ThrottleFilter.THROTTLE_MARKER, new Object());\r
234                                 continuation.resume();\r
235                         }\r
236                 }\r
237         }\r
238 \r
239         /**\r
240          * Return a count of number of requests in the last M minutes, iff this is a "bad" request.\r
241          * If the request has been resumed (if it contains the THROTTLE_MARKER) it is considered good.\r
242          * @param request the request\r
243          * @return number of requests in the last M minutes, 0 means it is a "good" request\r
244          */\r
245         private int getRequestRate(HttpServletRequest request) {\r
246                 String expecthdr = request.getHeader("Expect");\r
247                 if (expecthdr != null && expecthdr.equalsIgnoreCase("100-continue"))\r
248                         return 0;\r
249 \r
250                 String key = getConnectionId(request);\r
251                 synchronized (map) {\r
252                         Counter cnt = map.get(key);\r
253                         if (cnt == null) {\r
254                                 cnt = new Counter();\r
255                                 map.put(key, cnt);\r
256                         }\r
257                         int n = cnt.getRequestRate();\r
258                         return n;\r
259                 }\r
260         }\r
261 \r
262         public class Counter {\r
263                 private List<Long> times = new Vector<Long>();  // a record of request times\r
264                 public int prune() {\r
265                         try {\r
266                                 long n = System.currentTimeMillis() - (m_minutes * ONE_MINUTE);\r
267                                 long t = times.get(0);\r
268                                 while (t < n) {\r
269                                         times.remove(0);\r
270                                         t = times.get(0);\r
271                                 }\r
272                         } catch (IndexOutOfBoundsException e) {\r
273                                 // ignore\r
274                         }\r
275                         return times.size();\r
276                 }\r
277                 public int getRequestRate() {\r
278                         times.add(System.currentTimeMillis());\r
279                         return prune();\r
280                 }\r
281         }\r
282 \r
283         /**\r
284          *  Identify a connection by endpoint IP address, and feed ID.\r
285          */\r
286         private String getConnectionId(HttpServletRequest req) {\r
287                 return req.getRemoteAddr() + "/" + getFeedId(req);\r
288         }\r
289         private int getFeedId(HttpServletRequest req) {\r
290                 String path = req.getPathInfo();\r
291                 if (path == null || path.length() < 2)\r
292                         return -1;\r
293                 path = path.substring(1);\r
294                 int ix = path.indexOf('/');\r
295                 if (ix < 0 || ix == path.length()-1)\r
296                         return -2;\r
297                 try {\r
298                         int feedid = Integer.parseInt(path.substring(0, ix));\r
299                         return feedid;\r
300                 } catch (NumberFormatException e) {\r
301                         return -1;\r
302                 }\r
303         }\r
304 \r
305         @Override\r
306         public void run() {\r
307                 // Once every 5 minutes, go through the map, and remove empty entrys\r
308                 for (Object s : map.keySet().toArray()) {\r
309                         synchronized (map) {\r
310                                 Counter c = map.get(s);\r
311                                 if (c.prune() <= 0)\r
312                                         map.remove(s);\r
313                         }\r
314                 }\r
315         }\r
316 }\r