package com.nuvole.rateLimit;
|
|
import com.alibaba.fastjson.JSONObject;
|
import com.google.common.collect.Maps;
|
import com.google.common.util.concurrent.RateLimiter;
|
import com.nuvole.common.domain.emnu.CommonResultEmnu;
|
import com.nuvole.common.domain.result.CommonResult;
|
import jakarta.servlet.http.HttpServletResponse;
|
import lombok.extern.slf4j.Slf4j;
|
import org.aspectj.lang.ProceedingJoinPoint;
|
import org.aspectj.lang.annotation.Around;
|
import org.aspectj.lang.annotation.Aspect;
|
import org.aspectj.lang.reflect.MethodSignature;
|
import org.springframework.stereotype.Component;
|
import org.springframework.web.context.request.RequestContextHolder;
|
import org.springframework.web.context.request.ServletRequestAttributes;
|
|
import java.io.IOException;
|
import java.io.PrintWriter;
|
import java.lang.reflect.Method;
|
import java.util.Map;
|
|
|
@Slf4j
|
@Aspect
|
@Component
|
public class ApiRateLimitAop {
|
/**
|
* 不同的接口,不同的流量控制
|
* map的key为 Limiter.key
|
*/
|
private final Map<String, RateLimiter> limitMap = Maps.newConcurrentMap();
|
|
@Around("@annotation(com.nuvole.rateLimit.ApiRateLimit)")
|
public Object around(ProceedingJoinPoint joinPoint) throws Throwable {
|
MethodSignature signature = (MethodSignature) joinPoint.getSignature();
|
Method method = signature.getMethod();
|
//拿limit的注解
|
ApiRateLimit limit = method.getAnnotation(ApiRateLimit.class);
|
if (limit != null) {
|
//key作用:不同的接口,不同的流量控制
|
String key = limit.key();
|
RateLimiter rateLimiter = null;
|
//验证缓存是否有命中key
|
if (!limitMap.containsKey(key)) {
|
// 创建令牌桶
|
rateLimiter = RateLimiter.create(limit.permitsPerSecond());
|
limitMap.put(key, rateLimiter);
|
log.info("新建了令牌桶={},容量={}", key, limit.permitsPerSecond());
|
}
|
rateLimiter = limitMap.get(key);
|
// 拿令牌
|
boolean acquire = rateLimiter.tryAcquire(limit.timeout(), limit.timeunit());
|
// 拿不到命令,直接返回异常提示
|
if (!acquire) {
|
log.debug("令牌桶={},获取令牌失败", key);
|
this.responseFail(limit.msg());
|
return null;
|
}
|
}
|
return joinPoint.proceed();
|
}
|
|
/**
|
* @param msg 提示信息
|
*/
|
private void responseFail(String msg) {
|
HttpServletResponse response = ((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getResponse();
|
CommonResult resultData = new CommonResult(CommonResultEmnu.REQUEST_LIMIT_ERR, msg);
|
response.setContentType("application/json;charset=UTF-8");
|
PrintWriter out = null;
|
try {
|
out = response.getWriter();
|
} catch (IOException e) {
|
e.printStackTrace();
|
}
|
out.print(JSONObject.toJSONString(resultData));
|
out.flush();
|
out.close();
|
}
|
}
|