1. Background
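At work I needed to build an annotation that can be placed on a controller class or method. It checks whether the request carries a parameter named token, looking both in the URL query string and in the request body, and whether that token follows a given generation rule. If a valid token is found, the request is let through; otherwise it is blocked.
The usage is shown in the figure below: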

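As you can see, @TokenCheck can be placed either on a class or on a method. When it is placed on a class, every method in that class is checked.
Note: only annotated classes and methods are checked; anything without the annotation is unaffected.
All of the code was written against the latest Spring Boot version (Spring Boot 3.x, i.e. Spring 6), so the servlet-related classes come from the jakarta package.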
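The class-versus-method rule above can be sketched without Spring, using plain reflection in place of Spring's HandlerMethod. The controller classes, their methods and the helper TokenCheckLookup are hypothetical, and the annotation is a local copy of @TokenCheck so that the snippet compiles on its own; it mirrors the lookup order used later in the interceptor (class first, then method).

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;

// hypothetical helper mirroring the interceptor's lookup rule
class TokenCheckLookup {
    // an annotation on the class covers every method; otherwise only annotated methods are checked
    static boolean requiresToken(Class<?> controller, String methodName) {
        try {
            Method method = controller.getMethod(methodName);
            return controller.isAnnotationPresent(TokenCheck.class)
                    || method.isAnnotationPresent(TokenCheck.class);
        } catch (NoSuchMethodException e) {
            throw new IllegalArgumentException("no public method " + methodName, e);
        }
    }

    public static void main(String[] args) {
        System.out.println(requiresToken(OrderController.class, "list"));      // true
        System.out.println(requiresToken(OrderController.class, "detail"));    // true
        System.out.println(requiresToken(UserController.class, "update"));     // true
        System.out.println(requiresToken(UserController.class, "publicInfo")); // false
    }
}

// local copy of the annotation from section 2
@Target({ElementType.TYPE, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@interface TokenCheck {
}

// hypothetical controller annotated on the class: all of its methods are checked
@TokenCheck
class OrderController {
    public String list() { return "orders"; }
    public String detail() { return "order"; }
}

// hypothetical controller with a single annotated method
class UserController {
    @TokenCheck
    public String update() { return "updated"; }
    public String publicInfo() { return "info"; }
}
```

In the real interceptor the same decision is made on the HandlerMethod that Spring resolves for the request, so unannotated handlers are simply passed through.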


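If your Spring Boot version is older (below 3.0), use javax instead.
Add the following dependency first (if the javax imports already resolve in your IDE, you don't need it):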
<dependency>
<groupId>javax.servlet</groupId>
<artifactId>javax.servlet-api</artifactId>
<version>4.0.1</version>
<scope>provided</scope>
</dependency>


Third-party dependencies I used:
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>3.12.0</version>
</dependency>
<dependency>
<groupId>cn.hutool</groupId>
<artifactId>hutool-all</artifactId>
<version>5.8.24</version>
</dependency>
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-webmvc</artifactId>
<version>6.0.11</version>
</dependency>
2. The TokenCheck annotation
package com.example.demo.interceptorToken;
import java.lang.annotation.*;
/**
* Requests to the annotated class or method must carry a valid token
*/
@Documented
@Target({ElementType.TYPE, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface TokenCheck {
}
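3. Request wrapper: RequestWrapper
This class wraps the incoming request. The interceptor has to read the request body to look for the token, but a request's input stream can only be read once, so any later read (for example when Spring binds @RequestBody) fails. The wrapper therefore reads the body into a byte array once and overrides getInputStream() so that every call returns a fresh stream over those bytes, which makes the body readable any number of times.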
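The read-once problem and the caching fix can be reproduced with plain JDK streams. In this minimal sketch a ByteArrayInputStream, used only as a one-shot stream, stands in for the servlet input stream, and the hypothetical class CachedBody plays the role of RequestWrapper: it drains the original stream once and hands out a new stream over the cached bytes on every call.

```java
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;

// framework-free illustration of the RequestWrapper idea
class CachedBody {
    private final byte[] body;

    CachedBody(InputStream original) {
        try {
            // drain the one-shot stream exactly once
            this.body = original.readAllBytes();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    // every call returns a new stream positioned at the start of the cached bytes
    InputStream newStream() {
        return new ByteArrayInputStream(body);
    }

    static String readAll(InputStream in) {
        try {
            return new String(in.readAllBytes(), StandardCharsets.UTF_8);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        InputStream oneShot = new ByteArrayInputStream("{\"token\":\"AAABBB\"}".getBytes(StandardCharsets.UTF_8));
        CachedBody cached = new CachedBody(oneShot);
        System.out.println(readAll(oneShot));            // prints an empty line: the original stream is exhausted
        System.out.println(readAll(cached.newStream())); // {"token":"AAABBB"}
        System.out.println(readAll(cached.newStream())); // {"token":"AAABBB"} again
    }
}
```

The servlet version differs only in that the stream handed out must be a ServletInputStream, which is why RequestWrapper wraps its ByteArrayInputStream in an anonymous ServletInputStream.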
package com.example.demo.interceptorToken;
import jakarta.servlet.ReadListener;
import jakarta.servlet.ServletInputStream;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletRequestWrapper;
import org.springframework.util.StreamUtils;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
/**
* @author hulei
* @date 2024/1/11 19:48
* @Description getReader() and getInputStream() of a request can only be consumed once, so after the body has been read
* (e.g. for parameter validation), a controller using @RequestBody gets null or a "Stream closed" error.
* This wrapper builds a request whose input stream can be read repeatedly.
*/
public class RequestWrapper extends HttpServletRequestWrapper {
// cached copy of the raw request body
private final byte[] requestBody;
public RequestWrapper(HttpServletRequest request) throws IOException {
super(request);
// read the raw bytes exactly once; unlike readLine(), this keeps line breaks and the original encoding
requestBody = StreamUtils.copyToByteArray(request.getInputStream());
}
@Override
public ServletInputStream getInputStream() {
// every call returns a fresh stream over the cached bytes, so the body can be read again
final ByteArrayInputStream basic = new ByteArrayInputStream(requestBody);
return new ServletInputStream() {
@Override
public int read() {
return basic.read();
}
@Override
public boolean isFinished() {
return basic.available() == 0;
}
@Override
public boolean isReady() {
// the data is already in memory, so reading never blocks
return true;
}
@Override
public void setReadListener(ReadListener readListener) {
// non-blocking (async) reading is not supported by this wrapper
}
};
}
@Override
public BufferedReader getReader() {
// decode with the request's character encoding, defaulting to UTF-8
String encoding = getCharacterEncoding();
Charset charset = encoding != null ? Charset.forName(encoding) : StandardCharsets.UTF_8;
return new BufferedReader(new InputStreamReader(getInputStream(), charset));
}
}
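4. Request filter: RequestFilter
A custom request filter that wraps each request in the RequestWrapper defined above and passes the wrapper down the filter chain, again so that the request body can be read more than once.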
package com.example.demo.interceptorToken;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.springframework.core.annotation.Order;
import org.springframework.lang.NonNull;
import org.springframework.web.filter.OncePerRequestFilter;
import org.springframework.web.multipart.MultipartHttpServletRequest;
import org.springframework.web.multipart.support.StandardServletMultipartResolver;
import java.io.IOException;
/**
* @author hulei
* @date 2024/1/11 19:48
* Custom request filter
*/
// intended to run first; when the filter is registered through a FilterRegistrationBean, its setOrder() decides the order
@Order(0)
public class RequestFilter extends OncePerRequestFilter {
// CommonsMultipartResolver was removed in Spring 6.0; use StandardServletMultipartResolver instead.
// On versions before Spring 6.0, where the following line still compiles, you can use it instead:
// private CommonsMultipartResolver multipartResolver = new CommonsMultipartResolver();
private final StandardServletMultipartResolver multipartResolver = new StandardServletMultipartResolver();
/**
* Wraps every non-multipart request in a RequestWrapper so that its body can be read more than once
*/
@Override
protected void doFilterInternal(@NonNull HttpServletRequest request, @NonNull HttpServletResponse response, @NonNull FilterChain filterChain) throws ServletException, IOException {
String contentType = request.getContentType();
// multipart/form-data requests are resolved into parts here instead of caching the raw body;
// requests without a body (e.g. most GETs) have no Content-Type, hence the null check
if (contentType != null && contentType.contains("multipart/form-data")) {
MultipartHttpServletRequest multiReq = multipartResolver.resolveMultipart(request);
filterChain.doFilter(multiReq, response);
} else {
filterChain.doFilter(new RequestWrapper(request), response);
}
}
}
5. Filter configuration class: TokenFilterConfig
This one is straightforward: it registers the custom filter with the Spring container.
package com.example.demo.interceptorToken;
import org.springframework.boot.web.servlet.FilterRegistrationBean;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
/**
* @author hulei
* @date 2024/1/11 19:48
* Registers the filter with the Spring container
*/
@Configuration
public class TokenFilterConfig {
@Bean
public FilterRegistrationBean<RequestFilter> filters() {
FilterRegistrationBean<RequestFilter> filterRegistrationBean = new FilterRegistrationBean<>();
filterRegistrationBean.setFilter(new RequestFilter());
filterRegistrationBean.addUrlPatterns("/*");
filterRegistrationBean.setName("requestFilter");
//with several filters, a smaller order value means a higher priority
//filterRegistrationBean.setOrder(0);
return filterRegistrationBean;
}
}
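6. The core class: the RequestInterceptor interceptor
Note: if your Spring Boot version is below 3.0, switch the imports to javax. You can also extend HandlerInterceptorAdapter by changing implements HandlerInterceptor to extends HandlerInterceptorAdapter; the rest of the class stays almost the same. That adapter has been deprecated since Spring 5.3 and was removed in Spring 6, and because HandlerInterceptor has had default methods since Spring 5.0, implementing it directly works on Spring Boot 2.x as well. On Java versions older than 16, also replace the pattern-matching instanceof checks with ordinary casts.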
package com.example.demo.interceptorToken;
import cn.hutool.json.JSONArray;
import cn.hutool.json.JSONObject;
import cn.hutool.json.JSONUtil;
import jakarta.servlet.ServletInputStream;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.springframework.lang.NonNull;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.StreamUtils;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.multipart.MultipartHttpServletRequest;
import org.springframework.web.multipart.support.StandardServletMultipartResolver;
import org.springframework.web.servlet.HandlerInterceptor;
import org.springframework.web.servlet.ModelAndView;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.UnsupportedEncodingException;
import java.lang.reflect.Method;
import java.net.URLDecoder;
import java.nio.charset.Charset;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
/**
* @author hulei
* @date 2024/1/11 19:48
* Custom request interceptor (on Spring Boot versions below 3.0 it can extend HandlerInterceptorAdapter instead)
*/
public class RequestInterceptor implements HandlerInterceptor {
/**
* Name of the request parameter that carries the token
*/
private static final String TOKEN_STR = "token";
/**
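* Called before the handler method runs.
* For annotated handlers, looks for a token in the request and checks it against the generation rule: valid requests pass, all others are blocked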
*/
@Override
public boolean preHandle(@NonNull HttpServletRequest request,@NonNull HttpServletResponse response,@NonNull Object handler) throws Exception {
if(handler instanceof HandlerMethod handlerMethod){
//look up the @TokenCheck annotation (class first, then method)
TokenCheck tokenCheck = getTokenCheck(handlerMethod);
//multipart/form-data requests are handled separately: their body is parsed into parts instead of being read as a raw stream
if( request.getContentType() != null && request.getContentType().contains("multipart/form-data")){
//only check when the annotation is present
if(tokenCheck != null){
final StandardServletMultipartResolver multipartResolver = new StandardServletMultipartResolver();
MultipartHttpServletRequest multipartHttpServletRequest = multipartResolver.resolveMultipart(request);
//to get all parameters, whether from the URL or from the form data:
//Map<String,String[]> bodyParam = multipartHttpServletRequest.getParameterMap();
//read the token parameter directly
String token = multipartHttpServletRequest.getParameter(TOKEN_STR);
if(!StringUtils.isEmpty(token)){
boolean tokenRuleValidation = tokenRuleValidation(token);
if(!tokenRuleValidation){
returnJson(response, "token validation failed");
return false;
}
return true;
}
returnJson(response, "token validation failed");
return false;
}
}else{
//only check when the annotation is present
if (tokenCheck != null) {
//request method, if needed:
//String requestMethod = request.getMethod();
//token from the request body; getPostParam reads it through a RequestWrapper, so the body stays readable
String bodyParamsStr = this.getPostParam(request);
String tokenFromBody = getTokenFromBody(bodyParamsStr, "");
//token from the URL query string
String tokenFromUrl = getUrlQueryMap(request).get(TOKEN_STR);
//let the request through if either token follows the rule
if(tokenRuleValidation(tokenFromUrl)|| tokenRuleValidation(tokenFromBody)){
return true;
}else {
returnJson(response, "token validation failed");
return false;
}
}
}
return true;
}
return true;
}
private static TokenCheck getTokenCheck(HandlerMethod handler) {
Method method = handler.getMethod();
//check the controller class first; getBeanType() returns the actual controller class,
//which also covers handler methods inherited from a base class
Class<?> clazz = handler.getBeanType();
TokenCheck tokenCheck = clazz.getAnnotation(TokenCheck.class);
//no annotation on the class: look for @TokenCheck on the method
return tokenCheck != null ? tokenCheck : method.getAnnotation(TokenCheck.class);
}
/**
* Extracts the token parameter from the request body (a JSON object, or a JSON array of objects)
*/
private String getTokenFromBody(String bodyParamsStr,String tokenFromBody){
if(JSONUtil.isTypeJSONObject(bodyParamsStr)){
tokenFromBody = JSONUtil.parseObj(bodyParamsStr).getStr(TOKEN_STR);
}else if(JSONUtil.isTypeJSONArray(bodyParamsStr)){
JSONArray jsonArray = JSONUtil.parseArray(bodyParamsStr);
Set<String> tokenSet = new HashSet<>();
for (Object element : jsonArray) {
//skip array elements that are not JSON objects
if(element instanceof JSONObject jsonObject && StringUtils.isNotEmpty(jsonObject.getStr(TOKEN_STR))){
tokenSet.add(jsonObject.getStr(TOKEN_STR));
}
}
//use the first token that passes the rule, if any
if(!tokenSet.isEmpty()){
tokenFromBody = tokenSet.stream().filter(this::tokenRuleValidation).findFirst().orElse("");
}
}
//any other body (form-urlencoded, plain text, ...) is not parsed as JSON, so it no longer causes a parse exception
return tokenFromBody;
}
/**
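* Token rule validation (placeholder: only the fixed value AAABBB is accepted; replace it with your real rule)
* @param token the token value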
*/
private boolean tokenRuleValidation(String token){
return "AAABBB".equals(token);
}
/**
* Parses the query string of the URL (e.g. of a GET request) into a map
*/
public static Map<String, String> getUrlQueryMap(HttpServletRequest request) throws UnsupportedEncodingException {
//decode values with the request's character encoding, falling back to UTF-8 if none is set
String encoding = request.getCharacterEncoding() != null ? request.getCharacterEncoding() : "UTF-8";
String urlQueryString = request.getQueryString();
Map<String, String> queryMap = new HashMap<>();
if (urlQueryString == null) {
return queryMap;
}
//key=value pairs are separated by &
for (String strSplit : urlQueryString.split("&")) {
//limit 2: keep any '=' inside the value (e.g. Base64 padding in a token)
String[] arrSplitEqual = strSplit.split("=", 2);
if (arrSplitEqual.length > 1) {
queryMap.put(arrSplitEqual[0], URLDecoder.decode(arrSplitEqual[1], encoding));
} else if (!"".equals(arrSplitEqual[0])) {
queryMap.put(arrSplitEqual[0], "");
}
}
return queryMap;
}
/**
* Called after the handler method has run (not used here)
*/
@Override
public void postHandle(@NonNull HttpServletRequest request,@NonNull HttpServletResponse response,@NonNull Object handler, ModelAndView modelAndView) {
}
/**
* Writes the error message to the response
*/
private void returnJson(HttpServletResponse response, String json) throws IOException {
response.setCharacterEncoding("UTF-8");
response.setContentType("text/html; charset=utf-8");
try (PrintWriter writer = response.getWriter()) {
writer.print(json);
}
}
private String getPostParam(HttpServletRequest request) throws Exception{
RequestWrapper readerWrapper = new RequestWrapper(request);
//read the body once; an empty body is treated as an empty JSON object
String body = getBodyParams(readerWrapper.getInputStream(), request.getCharacterEncoding());
return StringUtils.isEmpty(body) ? "{}" : body;
}
/**
* Reads the body of a POST, PUT or DELETE request as a string
*/
private String getBodyParams(ServletInputStream inputStream, String charset) throws Exception {
//fall back to UTF-8 when the request declares no character encoding
return StreamUtils.copyToString(inputStream, Charset.forName(charset != null ? charset : "UTF-8"));
}
}
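getUrlQueryMap parses the query string by hand. The standalone variant below (QueryStrings is a hypothetical name) shows the same idea without a servlet container, with a few deliberate differences: it splits each pair on the first '=' only, so a value that itself contains '=' (such as Base64 padding in a token) is kept whole; it always decodes as UTF-8; and it ignores pairs with an empty key. It assumes Java 10 or later for URLDecoder.decode(String, Charset).

```java
import java.net.URLDecoder;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;

class QueryStrings {
    // parse "a=1&token=AAABBB" into a map, decoding values as UTF-8
    static Map<String, String> parse(String query) {
        Map<String, String> result = new HashMap<>();
        if (query == null || query.isEmpty()) {
            return result;
        }
        for (String pair : query.split("&")) {
            // limit 2: everything after the first '=' belongs to the value
            String[] kv = pair.split("=", 2);
            if (kv[0].isEmpty()) {
                continue; // skip fragments such as "&&" or "=x"
            }
            String value = kv.length > 1 ? URLDecoder.decode(kv[1], StandardCharsets.UTF_8) : "";
            result.put(kv[0], value);
        }
        return result;
    }

    public static void main(String[] args) {
        Map<String, String> params = parse("page=1&token=AB%2BC%3D%3D&flag");
        System.out.println(params.get("token")); // AB+C==
        System.out.println(params.get("flag"));  // prints an empty line: key without a value
    }
}
```

Inside a servlet container, request.getParameter("token") already returns the decoded query parameter, which may be simpler than parsing the query string yourself.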
7. Registering the interceptor: InterceptorRegister
A configuration class that registers the custom interceptor with Spring.
package com.example.demo.interceptorToken;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
/**
* @author hulei
* @date 2024/1/11 19:48
* Registers the interceptor with the Spring container
*/
@Configuration
public class InterceptorRegister implements WebMvcConfigurer {
@Bean
public RequestInterceptor tokenInterceptor() {
return new RequestInterceptor();
}
@Override
public void addInterceptors(InterceptorRegistry registry) {
registry.addInterceptor(tokenInterceptor());
}
}
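8. Summary
This example uses a custom annotation to intercept and validate request parameters. In practice you can adapt it to your own needs, for example to write logs or to validate other parameters, simply by changing the logic of the pre-handle and post-handle methods in RequestInterceptor.
Gitee repository: Token-Check-Demo: intercepting requests with a custom annotation
Note: writing this took real effort; if you repost it, please credit the original article.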
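As one example of such a change, tokenRuleValidation currently accepts only the fixed string AAABBB. One possible real generation rule (an assumption, not the author's) is sketched below: the hypothetical class HmacTokenRule issues tokens of the form payload.signature, where the signature is an HMAC-SHA256 of the payload encoded as URL-safe Base64, and verifies them with a constant-time comparison. The secret and the token format are assumptions; in a real application the key would come from configuration.

```java
import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Base64;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;

// hypothetical token format: "<payload>.<Base64url(HMAC-SHA256(payload))>"
class HmacTokenRule {
    private final byte[] secret;

    HmacTokenRule(String secret) {
        this.secret = secret.getBytes(StandardCharsets.UTF_8);
    }

    String issue(String payload) {
        return payload + "." + sign(payload);
    }

    // same shape as tokenRuleValidation(String): true only for tokens issued with this secret
    boolean isValid(String token) {
        if (token == null) {
            return false;
        }
        // URL-safe Base64 never contains '.', so the last dot separates payload and signature
        int dot = token.lastIndexOf('.');
        if (dot <= 0 || dot == token.length() - 1) {
            return false;
        }
        byte[] expected = sign(token.substring(0, dot)).getBytes(StandardCharsets.UTF_8);
        byte[] actual = token.substring(dot + 1).getBytes(StandardCharsets.UTF_8);
        // constant-time comparison, so the check does not reveal how many characters matched
        return MessageDigest.isEqual(expected, actual);
    }

    private String sign(String payload) {
        try {
            Mac mac = Mac.getInstance("HmacSHA256");
            mac.init(new SecretKeySpec(secret, "HmacSHA256"));
            byte[] sig = mac.doFinal(payload.getBytes(StandardCharsets.UTF_8));
            return Base64.getUrlEncoder().withoutPadding().encodeToString(sig);
        } catch (NoSuchAlgorithmException | InvalidKeyException e) {
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        HmacTokenRule rule = new HmacTokenRule("change-me");
        String token = rule.issue("user42");
        System.out.println(rule.isValid(token));           // true
        System.out.println(rule.isValid("user42.forged")); // false
    }
}
```

Because isValid takes the token string and returns a boolean, it could replace the fixed comparison inside tokenRuleValidation without touching the rest of the interceptor.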