1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.hadoop.hbase.rest.filter;
19
20 import java.io.IOException;
21 import java.util.Collections;
22 import java.util.HashMap;
23 import java.util.HashSet;
24 import java.util.Map;
25 import java.util.Set;
26 import java.util.regex.Matcher;
27 import java.util.regex.Pattern;
28
29 import javax.servlet.Filter;
30 import javax.servlet.FilterChain;
31 import javax.servlet.FilterConfig;
32 import javax.servlet.ServletException;
33 import javax.servlet.ServletRequest;
34 import javax.servlet.ServletResponse;
35 import javax.servlet.http.HttpServletRequest;
36 import javax.servlet.http.HttpServletResponse;
37
38 import org.apache.commons.logging.Log;
39 import org.apache.commons.logging.LogFactory;
40 import org.apache.hadoop.conf.Configuration;
41 import org.apache.hadoop.hbase.classification.InterfaceAudience;
42 import org.apache.hadoop.hbase.classification.InterfaceStability;
43
44
45
46
47
48
49
50
51 @InterfaceAudience.Public
52 @InterfaceStability.Evolving
53 public class RestCsrfPreventionFilter implements Filter {
54 private static final Log LOG = LogFactory.getLog(RestCsrfPreventionFilter.class);
55
56 public static final String HEADER_USER_AGENT = "User-Agent";
57 public static final String BROWSER_USER_AGENT_PARAM =
58 "browser-useragents-regex";
59 public static final String CUSTOM_HEADER_PARAM = "custom-header";
60 public static final String CUSTOM_METHODS_TO_IGNORE_PARAM =
61 "methods-to-ignore";
62 static final String BROWSER_USER_AGENTS_DEFAULT = "^Mozilla.*,^Opera.*";
63 public static final String HEADER_DEFAULT = "X-XSRF-HEADER";
64 static final String METHODS_TO_IGNORE_DEFAULT = "GET,OPTIONS,HEAD,TRACE";
65 private String headerName = HEADER_DEFAULT;
66 private Set<String> methodsToIgnore = null;
67 private Set<Pattern> browserUserAgents;
68
69 @Override
70 public void init(FilterConfig filterConfig) {
71 String customHeader = filterConfig.getInitParameter(CUSTOM_HEADER_PARAM);
72 if (customHeader != null) {
73 headerName = customHeader;
74 }
75 String customMethodsToIgnore =
76 filterConfig.getInitParameter(CUSTOM_METHODS_TO_IGNORE_PARAM);
77 if (customMethodsToIgnore != null) {
78 parseMethodsToIgnore(customMethodsToIgnore);
79 } else {
80 parseMethodsToIgnore(METHODS_TO_IGNORE_DEFAULT);
81 }
82
83 String agents = filterConfig.getInitParameter(BROWSER_USER_AGENT_PARAM);
84 if (agents == null) {
85 agents = BROWSER_USER_AGENTS_DEFAULT;
86 }
87 parseBrowserUserAgents(agents);
88 LOG.info(String.format("Adding cross-site request forgery (CSRF) protection, "
89 + "headerName = %s, methodsToIgnore = %s, browserUserAgents = %s",
90 headerName, methodsToIgnore, browserUserAgents));
91 }
92
93 void parseBrowserUserAgents(String userAgents) {
94 String[] agentsArray = userAgents.split(",");
95 browserUserAgents = new HashSet<>();
96 for (String patternString : agentsArray) {
97 browserUserAgents.add(Pattern.compile(patternString));
98 }
99 }
100
101 void parseMethodsToIgnore(String mti) {
102 String[] methods = mti.split(",");
103 methodsToIgnore = new HashSet<>();
104 Collections.addAll(methodsToIgnore, methods);
105 }
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122 protected boolean isBrowser(String userAgent) {
123 if (userAgent == null) {
124 return false;
125 }
126 for (Pattern pattern : browserUserAgents) {
127 Matcher matcher = pattern.matcher(userAgent);
128 if (matcher.matches()) {
129 return true;
130 }
131 }
132 return false;
133 }
134
135
136
137
138
139
140
141
142
143
144 public interface HttpInteraction {
145
146
147
148
149
150
151 String getHeader(String header);
152
153
154
155
156
157
158 String getMethod();
159
160
161
162
163
164
165
166
167 void proceed() throws IOException, ServletException;
168
169
170
171
172
173
174
175
176
177 void sendError(int code, String message) throws IOException;
178 }
179
180
181
182
183
184
185
186
187
188 public void handleHttpInteraction(HttpInteraction httpInteraction)
189 throws IOException, ServletException {
190 if (!isBrowser(httpInteraction.getHeader(HEADER_USER_AGENT)) ||
191 methodsToIgnore.contains(httpInteraction.getMethod()) ||
192 httpInteraction.getHeader(headerName) != null) {
193 httpInteraction.proceed();
194 } else {
195 httpInteraction.sendError(HttpServletResponse.SC_BAD_REQUEST,
196 "Missing Required Header for CSRF Vulnerability Protection");
197 }
198 }
199
200 @Override
201 public void doFilter(ServletRequest request, ServletResponse response,
202 final FilterChain chain) throws IOException, ServletException {
203 final HttpServletRequest httpRequest = (HttpServletRequest)request;
204 final HttpServletResponse httpResponse = (HttpServletResponse)response;
205 handleHttpInteraction(new ServletFilterHttpInteraction(httpRequest,
206 httpResponse, chain));
207 }
208
209 @Override
210 public void destroy() {
211 }
212
213
214
215
216
217
218
219
220
221
222
223
224 public static Map<String, String> getFilterParams(Configuration conf, String confPrefix) {
225 Map<String, String> filterConfigMap = new HashMap<>();
226 for (Map.Entry<String, String> entry : conf) {
227 String name = entry.getKey();
228 if (name.startsWith(confPrefix)) {
229 String value = conf.get(name);
230 name = name.substring(confPrefix.length());
231 filterConfigMap.put(name, value);
232 }
233 }
234 return filterConfigMap;
235 }
236
237
238
239
240 private static final class ServletFilterHttpInteraction implements HttpInteraction {
241 private final FilterChain chain;
242 private final HttpServletRequest httpRequest;
243 private final HttpServletResponse httpResponse;
244
245
246
247
248
249
250
251
252 public ServletFilterHttpInteraction(HttpServletRequest httpRequest,
253 HttpServletResponse httpResponse, FilterChain chain) {
254 this.httpRequest = httpRequest;
255 this.httpResponse = httpResponse;
256 this.chain = chain;
257 }
258
259 @Override
260 public String getHeader(String header) {
261 return httpRequest.getHeader(header);
262 }
263
264 @Override
265 public String getMethod() {
266 return httpRequest.getMethod();
267 }
268
269 @Override
270 public void proceed() throws IOException, ServletException {
271 chain.doFilter(httpRequest, httpResponse);
272 }
273
274 @Override
275 public void sendError(int code, String message) throws IOException {
276 httpResponse.sendError(code, message);
277 }
278 }
279 }