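Protecting your API from abuse is essential, and rate limiting is a key part of API security: it helps prevent denial-of-service attacks, manages resources, and ensures fair usage across clients. Together, Spring Boot 3 and Bucket4j offer a powerful and flexible way to add rate limiting to your application.
In this article, we'll look at how to implement rate limiting in a Spring Boot 3 application with Bucket4j. We'll cover several approaches and provide practical examples that you can adapt to your needs.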
Prerequisites
Before you start, make sure you have:
- Java 17 or later (the minimum version for Spring Boot 3).
- A basic understanding of Java, Spring Boot, and API development.
Implementation
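The first step is to add the required dependencies to your pom.xml (or their equivalents to build.gradle). The Maven coordinates are shown below. The later examples also use Spring Web, Lombok, and Micrometer (for example through spring-boot-starter-actuator), and the test uses spring-boot-starter-test, so those need to be on the classpath as well.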
```xml
<dependency>
    <groupId>com.bucket4j</groupId>
    <artifactId>bucket4j-core</artifactId>
    <version>8.3.0</version>
</dependency>
<dependency>
    <groupId>com.bucket4j</groupId>
    <artifactId>bucket4j-caffeine</artifactId>
    <version>8.3.0</version>
</dependency>
<dependency>
    <groupId>com.github.ben-manes.caffeine</groupId>
    <artifactId>caffeine</artifactId>
    <version>3.1.8</version>
</dependency>
```
Rather than jumping straight to the final code, we'll build the rate limiting step by step, starting with a basic REST controller.
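```java
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;

@RestController
@RequestMapping("/api")
public class RateLimitedController {

    @GetMapping("/greeting")
    public String getGreeting() {
        return "Hello, World!";
    }
}
```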
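Next, we configure the rate limit: a single bucket that holds up to 50 tokens and is refilled with 40 tokens every minute.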
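```java
import io.github.bucket4j.Bandwidth;
import io.github.bucket4j.Bucket;
import io.github.bucket4j.Refill;
import java.time.Duration;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

// Not named RateLimitConfig: the dynamic example later uses that name for a value class
@Configuration
public class BasicRateLimitConfig {

    @Bean
    public Bucket createNewBucket() {
        // Capacity: up to 50 requests can be served in a burst
        long capacity = 50;
        // After each full minute, 40 tokens are added at once (never above the capacity)
        Refill refill = Refill.intervally(40, Duration.ofMinutes(1));
        Bandwidth limit = Bandwidth.classic(capacity, refill);
        return Bucket.builder()
                .addLimit(limit)
                .build();
    }
}
```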
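To make these numbers concrete, here is a deliberately simplified, JDK-only model of what a bucket with capacity 50 and Refill.intervally(40, Duration.ofMinutes(1)) expresses: the bucket starts full, and each time a whole minute elapses, 40 tokens are added at once, never exceeding the capacity. IntervalBucketModel is a hypothetical class written only for illustration, not Bucket4j's implementation, and time is passed in as a nanosecond value so the behaviour is easy to trace.

```java
import java.time.Duration;

// Hypothetical, simplified model of a bucket with "interval" refill.
// It is NOT Bucket4j's implementation; it only illustrates the semantics.
class IntervalBucketModel {
    private final long capacity;
    private final long refillTokens;
    private final long refillPeriodNanos;
    private long available;
    private long lastRefillNanos;

    IntervalBucketModel(long capacity, long refillTokens, Duration period, long startNanos) {
        this.capacity = capacity;
        this.refillTokens = refillTokens;
        this.refillPeriodNanos = period.toNanos();
        this.available = capacity;          // the bucket starts full
        this.lastRefillNanos = startNanos;
    }

    // Adds refillTokens for every complete period that has elapsed, capped at capacity.
    private void refill(long nowNanos) {
        long periods = (nowNanos - lastRefillNanos) / refillPeriodNanos;
        if (periods > 0) {
            available = Math.min(capacity, available + periods * refillTokens);
            lastRefillNanos += periods * refillPeriodNanos;
        }
    }

    // Tries to take one token at the given time; returns true if one was available.
    boolean tryConsume(long nowNanos) {
        refill(nowNanos);
        if (available > 0) {
            available--;
            return true;
        }
        return false;
    }

    // Number of tokens left at the given time.
    long available(long nowNanos) {
        refill(nowNanos);
        return available;
    }
}
```

With these parameters, a client that bursts all 50 requests at once is rejected until the next full minute, after which 40 more requests are allowed.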
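Now we need a rate-limiting interceptor. Each request tries to take one token from the bucket; if none is left, the request is rejected with HTTP 429 (Too Many Requests).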
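```java
import io.github.bucket4j.Bucket;
import io.github.bucket4j.ConsumptionProbe;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.RequiredArgsConstructor;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.HandlerInterceptor;

@Component
@RequiredArgsConstructor
public class RateLimitInterceptor implements HandlerInterceptor {

    // The single Bucket bean defined above, shared by every client
    private final Bucket bucket;

    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        ConsumptionProbe probe = bucket.tryConsumeAndReturnRemaining(1);
        if (probe.isConsumed()) {
            response.addHeader("X-Rate-Limit-Remaining", String.valueOf(probe.getRemainingTokens()));
            return true;
        }
        // Round up so that a wait shorter than one second is not reported as 0
        long waitForRefill = (probe.getNanosToWaitForRefill() + 999_999_999L) / 1_000_000_000L;
        response.addHeader("X-Rate-Limit-Retry-After-Seconds", String.valueOf(waitForRefill));
        response.sendError(HttpStatus.TOO_MANY_REQUESTS.value(),
                "You have exhausted your API Request Quota");
        return false;
    }
}
```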
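Converting getNanosToWaitForRefill() into seconds deserves care: plain integer division by 1_000_000_000 truncates, so a wait of 0.4 seconds becomes 0 and a client that honours the header would retry immediately. Rounding up avoids this. The helper below is a hypothetical, JDK-only sketch; RetryAfter and toRetryAfterSeconds are illustrative names, not part of Bucket4j or Spring.

```java
import java.util.concurrent.TimeUnit;

// Hypothetical helper: converts a wait in nanoseconds to whole seconds, rounding up,
// so any non-zero wait is reported as at least 1 second.
final class RetryAfter {
    private static final long NANOS_PER_SECOND = TimeUnit.SECONDS.toNanos(1);

    private RetryAfter() {
    }

    static long toRetryAfterSeconds(long nanosToWait) {
        if (nanosToWait <= 0) {
            return 0;
        }
        // Integer ceiling division: add one second only when there is a remainder
        return nanosToWait / NANOS_PER_SECOND + (nanosToWait % NANOS_PER_SECOND == 0 ? 0 : 1);
    }
}
```

In an interceptor, String.valueOf(RetryAfter.toRetryAfterSeconds(probe.getNanosToWaitForRefill())) would then be the header value.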
At this point the interceptor is not registered, so it never runs. Let's fix that.
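```java
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;

@Configuration
public class WebMvcConfig implements WebMvcConfigurer {

    private final RateLimitInterceptor interceptor;

    public WebMvcConfig(RateLimitInterceptor interceptor) {
        this.interceptor = interceptor;
    }

    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        // Apply the limiter to every endpoint under /api
        registry.addInterceptor(interceptor)
                .addPathPatterns("/api/**");
    }
}
```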
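We now have a basic rate limiter, but it is not suitable for production: every client shares the single Bucket bean, so one busy client can use up the whole quota and lock everyone else out.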
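IP-based rate limiting
Limiting per client IP address is closer to what production systems need and gives finer-grained control: each IP gets its own bucket, kept in a Caffeine cache, so one client exhausting its quota no longer affects the others.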
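```java
import com.github.benmanes.caffeine.cache.Cache;
import com.github.benmanes.caffeine.cache.Caffeine;
import io.github.bucket4j.Bandwidth;
import io.github.bucket4j.Bucket;
import io.github.bucket4j.ConsumptionProbe;
import io.github.bucket4j.Refill;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.time.Duration;
import java.util.concurrent.TimeUnit;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.HandlerInterceptor;

@Component
public class IpBasedRateLimitInterceptor implements HandlerInterceptor {

    // One bucket per client IP
    private final Cache<String, Bucket> cache;

    public IpBasedRateLimitInterceptor() {
        // Buckets must outlive their 1-minute refill period; a very short expiry would
        // recreate a full bucket on every eviction and the limit would never be reached.
        this.cache = Caffeine.newBuilder()
                .expireAfterAccess(1, TimeUnit.HOURS)
                .build();
    }

    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        String ip = getClientIP(request);
        Bucket bucket = cache.get(ip, this::newBucket);
        ConsumptionProbe probe = bucket.tryConsumeAndReturnRemaining(1);
        if (probe.isConsumed()) {
            response.addHeader("X-Rate-Limit-Remaining", String.valueOf(probe.getRemainingTokens()));
            return true;
        }
        // Round up so that a sub-second wait is not reported as 0
        long waitForRefill = (probe.getNanosToWaitForRefill() + 999_999_999L) / 1_000_000_000L;
        response.addHeader("X-Rate-Limit-Retry-After-Seconds", String.valueOf(waitForRefill));
        response.sendError(HttpStatus.TOO_MANY_REQUESTS.value(),
                "Rate limit exceeded. Try again in " + waitForRefill + " seconds");
        return false;
    }

    private String getClientIP(HttpServletRequest request) {
        // Any client can send X-Forwarded-For; only rely on it behind a trusted proxy
        // that overwrites the header.
        String xfHeader = request.getHeader("X-Forwarded-For");
        if (xfHeader == null || xfHeader.isBlank()) {
            return request.getRemoteAddr();
        }
        return xfHeader.split(",")[0].trim();
    }

    private Bucket newBucket(String ip) {
        // 10 requests per minute for each IP
        return Bucket.builder()
                .addLimit(Bandwidth.classic(10, Refill.intervally(10, Duration.ofMinutes(1))))
                .build();
    }
}
```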
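The two ideas in this interceptor, resolving a client key and keeping one bucket per key, can be sketched without Spring, Caffeine, or Bucket4j. In the hypothetical sketch below, a ConcurrentHashMap stands in for the Caffeine cache (with no eviction) and a fixed per-key allowance stands in for a Bucket4j bucket (with no refill), so only the keying behaviour is shown. X-Forwarded-For is honoured only when the direct peer is a trusted proxy, because any client can send that header with an arbitrary value; the trusted proxy address used in the example is an assumption.

```java
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;

// Hypothetical sketch of per-client limiting; not the article's classes.
class PerClientLimiter {
    private final long allowancePerClient;
    private final Set<String> trustedProxies;
    // One counter per client key: a stand-in for Cache<String, Bucket> without refill or eviction
    private final Map<String, AtomicLong> remaining = new ConcurrentHashMap<>();

    PerClientLimiter(long allowancePerClient, Set<String> trustedProxies) {
        this.allowancePerClient = allowancePerClient;
        this.trustedProxies = trustedProxies;
    }

    // Uses the first X-Forwarded-For entry only when the request came through a trusted proxy.
    // Assumes that proxy replaces (rather than appends to) any client-supplied header.
    String resolveClientIp(String remoteAddr, String xForwardedFor) {
        if (xForwardedFor == null || xForwardedFor.isBlank() || !trustedProxies.contains(remoteAddr)) {
            return remoteAddr;
        }
        return xForwardedFor.split(",")[0].trim();
    }

    // Returns true if the client identified by the key still has allowance left.
    boolean tryConsume(String clientKey) {
        AtomicLong left = remaining.computeIfAbsent(clientKey, k -> new AtomicLong(allowancePerClient));
        // Decrement only while positive, so concurrent callers cannot push the counter below zero
        return left.getAndUpdate(v -> v > 0 ? v - 1 : v) > 0;
    }
}
```

Because each key has its own counter, a client that exhausts its allowance leaves every other client's allowance untouched, which is exactly what the single shared bucket could not guarantee.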
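To try this version, register IpBasedRateLimitInterceptor in WebMvcConfig in place of RateLimitInterceptor. We also need an integration test to verify that the limit is actually enforced.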
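```java
import static org.junit.jupiter.api.Assertions.assertEquals;

import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.web.client.TestRestTemplate;
import org.springframework.boot.test.web.server.LocalServerPort;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;

@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
class RateLimitedControllerTest {

    @LocalServerPort
    private int port;

    @Autowired
    private TestRestTemplate restTemplate;

    // Assumes IpBasedRateLimitInterceptor (10 requests per minute per IP) is registered
    @Test
    void whenExceedingRateLimit_thenReceive429() {
        String url = "http://localhost:" + port + "/api/greeting";
        // The first 10 requests fit within the limit
        for (int i = 0; i < 10; i++) {
            ResponseEntity<String> response = restTemplate.getForEntity(url, String.class);
            assertEquals(HttpStatus.OK, response.getStatusCode());
        }
        // The 11th request exceeds it
        ResponseEntity<String> response = restTemplate.getForEntity(url, String.class);
        assertEquals(HttpStatus.TOO_MANY_REQUESTS, response.getStatusCode());
    }
}
```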
Dynamic rate limiting based on system load
Last but not least, let's build one more rate limiter, one that adjusts the limit to the application's current load.
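```java
import java.lang.management.ManagementFactory;
import java.lang.management.OperatingSystemMXBean;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

@Slf4j
@Component
public class SystemMetricsCollector {

    private final OperatingSystemMXBean osBean;

    public SystemMetricsCollector() {
        this.osBean = ManagementFactory.getOperatingSystemMXBean();
    }

    public SystemMetrics collectMetrics() {
        double cpuLoad = getProcessCpuLoad();
        Runtime runtime = Runtime.getRuntime();
        long usedMemory = runtime.totalMemory() - runtime.freeMemory();
        // Heap in use relative to the maximum heap (-Xmx), not only the committed heap
        double memoryUsage = (double) usedMemory / runtime.maxMemory();
        return new SystemMetrics(cpuLoad, memoryUsage);
    }

    // Returns a CPU load between 0.0 and 1.0
    private double getProcessCpuLoad() {
        if (osBean instanceof com.sun.management.OperatingSystemMXBean sunOsBean) {
            double load = sunOsBean.getProcessCpuLoad();
            if (load >= 0) {
                return load; // a negative value means "not available"
            }
        }
        // Fallback: the load average is not a 0..1 fraction, so divide by the core count.
        // It is negative where unsupported (for example on Windows).
        double loadAverage = osBean.getSystemLoadAverage();
        return loadAverage < 0 ? 0.0 : Math.min(1.0, loadAverage / osBean.getAvailableProcessors());
    }
}
```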
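The two JVM sources used for CPU load report different quantities: getProcessCpuLoad() returns a fraction between 0.0 and 1.0 (or a negative value when unavailable), whereas getSystemLoadAverage() returns the load average, which can exceed 1 on a multi-core machine and is negative when unavailable. Similarly, freeMemory() and totalMemory() only describe the heap committed so far, not the maximum the JVM may grow to. The hypothetical LoadNormalizer below sketches how raw readings could be turned into the 0 to 1 range that the calculator's thresholds assume; the readings are passed in rather than read from the JVM so the arithmetic can be checked, and treating "unavailable" as zero load is one choice among several.

```java
// Hypothetical helpers that turn raw JVM readings into fractions between 0 and 1.
final class LoadNormalizer {
    private LoadNormalizer() {
    }

    // Load average is roughly "runnable threads"; dividing by the core count gives a utilisation-like ratio.
    static double normalizeLoadAverage(double loadAverage, int availableProcessors) {
        if (loadAverage < 0 || availableProcessors <= 0) {
            return 0.0; // unavailable: treat as no load instead of propagating a negative value
        }
        return clamp(loadAverage / availableProcessors);
    }

    // Heap usage relative to the maximum heap; Runtime.maxMemory() returns Long.MAX_VALUE when there is no limit.
    static double heapUsage(long totalMemory, long freeMemory, long maxMemory) {
        long used = totalMemory - freeMemory;
        long limit = maxMemory == Long.MAX_VALUE ? totalMemory : maxMemory;
        return limit <= 0 ? 0.0 : clamp((double) used / limit);
    }

    private static double clamp(double value) {
        return Math.max(0.0, Math.min(1.0, value));
    }
}
```

In collectMetrics(), the inputs would come from getSystemLoadAverage(), getAvailableProcessors(), and Runtime.getRuntime().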
The SystemMetrics value class it returns:
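```java
import lombok.AllArgsConstructor;
import lombok.Data;

@Data
@AllArgsConstructor
public class SystemMetrics {
    private double cpuLoad;     // process CPU load, 0.0 to 1.0
    private double memoryUsage; // fraction of the heap in use, 0.0 to 1.0
}
```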
Next, we create the component that calculates the rate limit.
```java
import java.time.Duration;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

@Component
@Slf4j
public class DynamicRateLimitCalculator {

    private static final int BASE_LIMIT = 100;
    private static final double CPU_THRESHOLD_HIGH = 0.8;
    private static final double CPU_THRESHOLD_MEDIUM = 0.5;
    private static final double MEMORY_THRESHOLD_HIGH = 0.8;
    private static final double MEMORY_THRESHOLD_MEDIUM = 0.5;

    public RateLimitConfig calculateLimit(SystemMetrics metrics) {
        int limit = BASE_LIMIT;
        // Adjust the limit based on CPU load
        limit = adjustLimitBasedOnCpu(limit, metrics.getCpuLoad());
        // Then adjust the (possibly already reduced) limit based on memory usage
        limit = adjustLimitBasedOnMemory(limit, metrics.getMemoryUsage());
        Duration refillDuration = calculateRefillDuration(metrics);
        log.debug("Calculated rate limit: {}/{}s", limit, refillDuration.getSeconds());
        return new RateLimitConfig(limit, refillDuration);
    }

    private int adjustLimitBasedOnCpu(int currentLimit, double cpuLoad) {
        if (cpuLoad > CPU_THRESHOLD_HIGH) {
            return (int) (currentLimit * 0.3); // severe reduction
        } else if (cpuLoad > CPU_THRESHOLD_MEDIUM) {
            return (int) (currentLimit * 0.6); // moderate reduction
        }
        return currentLimit;
    }

    private int adjustLimitBasedOnMemory(int currentLimit, double memoryUsage) {
        if (memoryUsage > MEMORY_THRESHOLD_HIGH) {
            return (int) (currentLimit * 0.4); // severe reduction
        } else if (memoryUsage > MEMORY_THRESHOLD_MEDIUM) {
            return (int) (currentLimit * 0.7); // moderate reduction
        }
        return currentLimit;
    }

    // Under heavier load, tokens are refilled over a longer period
    private Duration calculateRefillDuration(SystemMetrics metrics) {
        double maxLoad = Math.max(metrics.getCpuLoad(), metrics.getMemoryUsage());
        if (maxLoad > 0.8) {
            return Duration.ofMinutes(2);
        } else if (maxLoad > 0.5) {
            return Duration.ofMinutes(1);
        }
        return Duration.ofSeconds(30);
    }
}
```

The calculator returns a small value object holding the limit and its refill period:

```java
import java.time.Duration;
import lombok.AllArgsConstructor;
import lombok.Data;

// "limit" requests allowed per "refillDuration"
@Data
@AllArgsConstructor
public class RateLimitConfig {
    private int limit;
    private Duration refillDuration;
}
```
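Because the memory factor is applied to the already CPU-adjusted value, the reductions compound. The JDK-only sketch below copies that arithmetic into a pure function so a few inputs can be checked by hand; LimitMath and its Result record are hypothetical names. For example, at 90% CPU and 90% memory the limit drops from 100 to (int) (100 × 0.3) = 30 and then to (int) (30 × 0.4) = 12 tokens, refilled every 2 minutes.

```java
import java.time.Duration;

// Hypothetical pure-function copy of DynamicRateLimitCalculator's arithmetic, for checking values by hand.
final class LimitMath {
    record Result(int limit, Duration refill) {
    }

    private LimitMath() {
    }

    static Result calculate(double cpuLoad, double memoryUsage) {
        int limit = 100;                       // BASE_LIMIT
        if (cpuLoad > 0.8) {
            limit = (int) (limit * 0.3);       // severe CPU reduction
        } else if (cpuLoad > 0.5) {
            limit = (int) (limit * 0.6);       // moderate CPU reduction
        }
        if (memoryUsage > 0.8) {
            limit = (int) (limit * 0.4);       // applied to the CPU-adjusted value
        } else if (memoryUsage > 0.5) {
            limit = (int) (limit * 0.7);
        }
        double maxLoad = Math.max(cpuLoad, memoryUsage);
        Duration refill = maxLoad > 0.8 ? Duration.ofMinutes(2)
                : maxLoad > 0.5 ? Duration.ofMinutes(1)
                : Duration.ofSeconds(30);
        return new Result(limit, refill);
    }
}
```

The lowest limit this arithmetic can produce is 12, so the limit never reaches zero.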
Now let's build the adaptive rate limiter itself, again as a HandlerInterceptor. It also implements RateLimitConfigProvider, a one-method interface through which RateLimitMetrics reads the current limit:

```java
public interface RateLimitConfigProvider {
    RateLimitConfig getRateLimitConfig();
}
```

And the interceptor:

```java
import com.github.benmanes.caffeine.cache.Cache;
import com.github.benmanes.caffeine.cache.Caffeine;
import io.github.bucket4j.Bandwidth;
import io.github.bucket4j.Bucket;
import io.github.bucket4j.ConsumptionProbe;
import io.github.bucket4j.Refill;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Timer;
import jakarta.annotation.PreDestroy;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.time.Duration;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.HandlerInterceptor;

@Slf4j
@Component
public class DynamicRateLimitInterceptor implements HandlerInterceptor, RateLimitConfigProvider {

    private final Cache<String, Bucket> bucketCache;
    private final SystemMetricsCollector metricsCollector;
    private final DynamicRateLimitCalculator calculator;
    private final AtomicReference<RateLimitConfig> currentConfig;
    private final ScheduledExecutorService scheduler;
    private final RateLimitMetrics metrics;

    public DynamicRateLimitInterceptor(SystemMetricsCollector metricsCollector,
                                       DynamicRateLimitCalculator calculator,
                                       MeterRegistry meterRegistry) {
        this.metricsCollector = metricsCollector;
        this.calculator = calculator;
        // Start with 100 requests per minute until the first measurement arrives
        this.currentConfig = new AtomicReference<>(new RateLimitConfig(100, Duration.ofMinutes(1)));
        this.bucketCache = Caffeine.newBuilder()
                .expireAfterWrite(1, TimeUnit.HOURS)
                .build();
        this.scheduler = Executors.newSingleThreadScheduledExecutor();
        this.metrics = new RateLimitMetrics(meterRegistry, this);
        startMetricsUpdateTask();
    }

    // Re-evaluates the limit every 10 seconds
    private void startMetricsUpdateTask() {
        scheduler.scheduleAtFixedRate(this::updateRateLimitConfig, 0, 10, TimeUnit.SECONDS);
    }

    private void updateRateLimitConfig() {
        try {
            SystemMetrics systemMetrics = metricsCollector.collectMetrics();
            RateLimitConfig newConfig = calculator.calculateLimit(systemMetrics);
            RateLimitConfig oldConfig = currentConfig.get();
            if (hasSignificantChange(oldConfig, newConfig)) {
                currentConfig.set(newConfig);
                log.info("Rate limit updated: {}/{}s",
                        newConfig.getLimit(),
                        newConfig.getRefillDuration().getSeconds());
                // Clear the cache so buckets are recreated with the new limits
                // (this also gives every client a fresh, full bucket)
                bucketCache.invalidateAll();
            }
        } catch (Exception e) {
            // Catch everything: an uncaught exception would cancel the scheduled task
            log.error("Error updating rate limit config", e);
        }
    }

    private boolean hasSignificantChange(RateLimitConfig oldConfig, RateLimitConfig newConfig) {
        double limitChange = Math.abs(1.0 - (double) newConfig.getLimit() / oldConfig.getLimit());
        return limitChange > 0.2; // 20% change threshold
    }

    @Override
    public RateLimitConfig getRateLimitConfig() {
        return currentConfig.get();
    }

    @Override
    public boolean preHandle(HttpServletRequest request,
                             HttpServletResponse response,
                             Object handler) throws Exception {
        String path = request.getRequestURI();
        String method = request.getMethod();
        Timer.Sample timerSample = metrics.startTimer();
        boolean rateLimited = false;
        try {
            metrics.recordRequest();
            String clientId = getClientIdentifier(request);
            Bucket bucket = bucketCache.get(clientId, this::createBucket);
            ConsumptionProbe probe = bucket.tryConsumeAndReturnRemaining(1);
            if (probe.isConsumed()) {
                addRateLimitHeaders(response, probe);
                return true;
            }
            rateLimited = true; // recorded as a tag on the timer in the finally block
            metrics.incrementRateLimitExceeded();
            handleRateLimitExceeded(response, probe);
            return false;
        } finally {
            metrics.stopTimer(timerSample, path, method, rateLimited);
        }
    }

    private Bucket createBucket(String clientId) {
        RateLimitConfig config = currentConfig.get();
        return Bucket.builder()
                .addLimit(Bandwidth.classic(
                        config.getLimit(),
                        Refill.intervally(config.getLimit(), config.getRefillDuration())))
                .build();
    }

    private String getClientIdentifier(HttpServletRequest request) {
        // Could combine multiple factors: IP, user ID, API key, etc.
        return request.getRemoteAddr();
    }

    private void addRateLimitHeaders(HttpServletResponse response, ConsumptionProbe probe) {
        RateLimitConfig config = currentConfig.get();
        response.addHeader("X-Rate-Limit-Limit", String.valueOf(config.getLimit()));
        response.addHeader("X-Rate-Limit-Remaining", String.valueOf(probe.getRemainingTokens()));
        // No reset header here: getNanosToWaitForRefill() is always 0 for a consumed probe
    }

    private void handleRateLimitExceeded(HttpServletResponse response, ConsumptionProbe probe)
            throws IOException {
        // Round up so that a sub-second wait is not reported as 0
        long retryAfterSeconds = (probe.getNanosToWaitForRefill() + 999_999_999L) / 1_000_000_000L;
        response.setStatus(HttpStatus.TOO_MANY_REQUESTS.value());
        response.setContentType(MediaType.APPLICATION_JSON_VALUE);
        String errorMessage = String.format(
                "Rate limit exceeded. Try again in %d seconds", retryAfterSeconds);
        response.getWriter().write(String.format(
                "{\"error\": \"%s\", \"retryAfter\": %d}", errorMessage, retryAfterSeconds));
    }

    @PreDestroy
    public void shutdown() {
        scheduler.shutdown();
    }
}
```
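The scheduled update only swaps the configuration when the limit moves by more than 20%, which avoids clearing every client's bucket on small fluctuations. Here is a JDK-only sketch of that thresholded update, keeping the value in an atomic holder much as the interceptor keeps its configuration in an AtomicReference. ThresholdedLimit is a hypothetical name, and unlike hasSignificantChange it tracks only the integer limit.

```java
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical sketch: keep the current limit and replace it only on a significant change.
class ThresholdedLimit {
    private static final double CHANGE_THRESHOLD = 0.2; // 20%, as in hasSignificantChange
    private final AtomicInteger current;

    ThresholdedLimit(int initialLimit) {
        this.current = new AtomicInteger(initialLimit);
    }

    // Returns true if the new limit replaced the old one.
    boolean offer(int newLimit) {
        int old = current.get();
        double change = Math.abs(1.0 - (double) newLimit / old);
        if (change <= CHANGE_THRESHOLD) {
            return false;
        }
        // compareAndSet keeps the update safe if several threads call offer concurrently
        return current.compareAndSet(old, newLimit);
    }

    int get() {
        return current.get();
    }
}
```

Because the comparison is relative, a drop from 100 to 85 is ignored while a drop from 100 to 60 is applied; once at 60, a further move to 50 (about 17%) is ignored again.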
Configuring the Spring Boot application to use the rate limiter
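Now we configure the Spring Boot application to use the dynamic rate limiter. Use this configuration instead of the earlier WebMvcConfig; otherwise both interceptors would run for the same requests.
```java
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;

// Not named RateLimitConfig, which is already the value class returned by the calculator
@Configuration
public class DynamicRateLimitWebConfig implements WebMvcConfigurer {
    @Autowired
    private DynamicRateLimitInterceptor rateLimiter;

    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        registry.addInterceptor(rateLimiter)
                .addPathPatterns("/api/**");
    }
}
```
With this configuration, DynamicRateLimitInterceptor is registered as an interceptor and applied to every request path under /api.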
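Creating custom metrics to monitor the application
To see how the rate limiter behaves in practice (how many requests it handles, how many it rejects, and what the current limit is), it helps to publish custom metrics. RateLimitMetrics, which DynamicRateLimitInterceptor creates in its constructor, does this with Micrometer: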
```java
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Gauge;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Timer;
import java.util.Map;

public class RateLimitMetrics {

    private final MeterRegistry meterRegistry;
    private final Counter rateLimitExceeded;
    private final Counter requestsTotal;
    private final Gauge currentLimit;

    public RateLimitMetrics(MeterRegistry registry, RateLimitConfigProvider configProvider) {
        this.meterRegistry = registry;
        this.rateLimitExceeded = Counter.builder("rate_limit.exceeded")
                .description("Number of rate limit exceeded events")
                .tag("type", "exceeded")
                .register(registry);
        this.requestsTotal = Counter.builder("rate_limit.requests")
                .description("Total number of requests processed")
                .tag("type", "total")
                .register(registry);
        // The gauge reads the limit from the provider each time it is sampled
        this.currentLimit = Gauge.builder("rate_limit.current", configProvider, this::getCurrentLimit)
                .description("Current rate limit value")
                .tag("type", "limit")
                .register(registry);
    }

    public Timer.Sample startTimer() {
        return Timer.start();
    }

    public void stopTimer(Timer.Sample sample, String path, String method, boolean rateLimited) {
        // register() returns the existing timer when name and tags match. Tagging with the raw
        // URI creates one timer per distinct path (e.g. /api/users/42), so prefer a templated
        // path when URIs contain IDs.
        Timer timer = Timer.builder("rate_limit.request.duration")
                .description("Request duration through rate limiter")
                .tags(
                        "path", path,
                        "method", method,
                        "rate_limited", String.valueOf(rateLimited),
                        "component", "rate_limiter")
                .register(meterRegistry);
        sample.stop(timer);
    }

    public void incrementRateLimitExceeded() {
        rateLimitExceeded.increment();
    }

    public void recordRequest() {
        requestsTotal.increment();
    }

    private double getCurrentLimit(RateLimitConfigProvider provider) {
        return provider.getRateLimitConfig().getLimit();
    }

    public Map<String, Number> getCurrentMetrics() {
        return Map.of(
                "rateLimitExceeded", rateLimitExceeded.count(),
                "totalRequests", requestsTotal.count(),
                "currentLimit", currentLimit.value());
    }
}
```
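The code above registers the following metrics:
- rate_limit.exceeded: the number of requests rejected by the rate limit.
- rate_limit.requests: the total number of requests processed.
- rate_limit.current: a gauge showing the current rate limit.
- rate_limit.request.duration: a timer for the time spent in the rate limiter, tagged with path, method, and rate_limited.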
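Best practices and considerations
- Cache implementation: with in-memory buckets, each instance in a cluster enforces its own limit, so a client spreading requests across N instances gets up to N times the intended quota. In production, keep buckets in a distributed cache such as Redis so all instances share them.
- Response headers: always include rate-limit information in responses so clients can manage their request rate. Common headers include:
  - X-Rate-Limit-Remaining: the number of requests left.
  - X-Rate-Limit-Retry-After-Seconds: how many seconds to wait before retrying (the standard HTTP Retry-After header serves the same purpose).
- Error handling: return a clear error message when a client exceeds the rate limit.
- Monitoring: set up metrics to track rate-limit events, and tune the limits based on observed usage patterns.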
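Conclusion
This article showed how to implement rate limiting in a Spring Boot 3 application with Bucket4j. We covered three approaches: a single global bucket, per-IP buckets, and limits that adapt to system load. Real-world requirements will likely differ from these examples, so treat them as a starting point.
Rate limiting is only one part of an API security strategy. Combined with other security measures, it helps you build robust protection for your API.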