Spring Boot + Bucket4j 实现API请求限流

保护你的 API 免受滥用至关重要。速率限制是 API 安全的关键。它可以防止拒绝服务攻击、管理资源并确保客户端之间的公平使用。Spring Boot 3 和 Bucket4j 结合提供了一个强大且灵活的方式来为你的应用程序添加速率限制。

在本文中,我们将探讨如何在 Spring Boot 3 应用程序中使用 Bucket4j 开发速率限制功能。我们将介绍不同的方法,并提供实用的示例,供你根据需求进行调整。

先决条件

在开始之前,请确保你具备以下条件:

  • • Java 17 或更高版本。
  • • 对 Java、Spring Boot 和 API 开发有基本了解。

实现

第一步是将所需的依赖项添加到你的 pom.xml 或 build.gradle 中。

<dependency>
    <groupId>com.bucket4j</groupId>
    <artifactId>bucket4j-core</artifactId>
    <version>8.3.0</version>
</dependency>
<dependency>
    <groupId>com.bucket4j</groupId>
    <artifactId>bucket4j-caffeine</artifactId>
    <version>8.3.0</version>
</dependency>
<dependency>
    <groupId>com.github.ben-manes.caffeine</groupId>
    <artifactId>caffeine</artifactId>
    <version>3.1.8</version>
</dependency>

我们不会直接跳到最终代码,而是逐步构建速率限制功能。让我们从创建一个基本的 REST 控制器开始。

@RestController
@RequestMapping("/api")
public class RateLimitedController {

    @GetMapping("/greeting")
    public String getGreeting() {
        return "Hello, World!";
    }
}

接下来,我们需要配置速率限制。

@Configuration
public class RateLimitConfig {

    @Bean
    public Bucket createNewBucket() {
        long overdraft = 50;
        Refill refill = Refill.intervally(40, Duration.ofMinutes(1));
        Bandwidth limit = Bandwidth.classic(overdraft, refill);
        return Bucket.builder()
                .addLimit(limit)
                .build();
    }
}

现在,我们需要设置一个速率限制拦截器。

@Component
@RequiredArgsConstructor
public class RateLimitInterceptor implements HandlerInterceptor {

    private final Bucket bucket;

    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        ConsumptionProbe probe = bucket.tryConsumeAndReturnRemaining(1);
        if (probe.isConsumed()) {
            response.addHeader("X-Rate-Limit-Remaining", String.valueOf(probe.getRemainingTokens()));
            return true;
        }

        long waitForRefill = probe.getNanosToWaitForRefill() / 1_000_000_000;
        response.addHeader("X-Rate-Limit-Retry-After-Seconds", String.valueOf(waitForRefill));
        response.sendError(HttpStatus.TOO_MANY_REQUESTS.value(),
                "You have exhausted your API Request Quota");
        return false;
    }
}

目前,我们还没有注册我们的拦截器,让我们来解决这个问题。

@Configuration
public class WebMvcConfig implements WebMvcConfigurer {

    private final RateLimitInterceptor interceptor;

    public WebMvcConfig(RateLimitInterceptor interceptor) {
        this.interceptor = interceptor;
    }

    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        registry.addInterceptor(interceptor)
                .addPathPatterns("/api/**");
    }
}

我们已经实现了一个基本的速率限制器。这个基本版本并不适合实际生产环境。

IP 基础的速率限制

IP 基础的速率限制更接近实际生产场景。IP 限制提供了更细粒度的控制。

@Component
public class IpBasedRateLimitInterceptor implements HandlerInterceptor {

    private final Cache<String, Bucket> cache;

    public IpBasedRateLimitInterceptor() {
        this.cache = Caffeine.newBuilder()
                .expireAfterWrite(1, TimeUnit.SECONDS)
                .build();
    }

    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        String ip = getClientIP(request);
        Bucket bucket = cache.get(ip, this::newBucket);

        ConsumptionProbe probe = bucket.tryConsumeAndReturnRemaining(1);
        if (probe.isConsumed()) {
            response.addHeader("X-Rate-Limit-Remaining", String.valueOf(probe.getRemainingTokens()));
            return true;
        }

        long waitForRefill = probe.getNanosToWaitForRefill() / 1_000_000_000;
        response.addHeader("X-Rate-Limit-Retry-After-Seconds", String.valueOf(waitForRefill));
        response.sendError(HttpStatus.TOO_MANY_REQUESTS.value(),
                "Rate limit exceeded. Try again in " + waitForRefill + " seconds");
        return false;
    }

    private String getClientIP(HttpServletRequest request) {
        String xfHeader = request.getHeader("X-Forwarded-For");
        if (xfHeader == null) {
            return request.getRemoteAddr();
        }
        return xfHeader.split(",")[0];
    }

    private Bucket newBucket(String ip) {
        return Bucket.builder()
                .addLimit(Bandwidth.classic(10, Refill.intervally(10, Duration.ofMinutes(1))))
                .build();
    }
}

当然,我们需要单元测试来验证我们的实现是否有效。

@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
class RateLimitedControllerTest {

    @LocalServerPort
    private int port;

    @Autowired
    private TestRestTemplate restTemplate;

    @Test
    void whenExceedingRateLimit_thenReceive429() {
        String url = "http://localhost:" + port + "/api/greeting";
        
        // 发送 10 个请求(超过我们的限制 9 次)
        for (int i = 0; i < 10; i++) {
            ResponseEntity<String> response = restTemplate.getForEntity(url, String.class);
            
            if (i < 10) {
                assertEquals(HttpStatus.OK, response.getStatusCode());
            } else {
                assertEquals(HttpStatus.TOO_MANY_REQUESTS, response.getStatusCode());
            }
        }
    }
}

基于系统负载的动态速率限制

最后但同样重要的是,让我们再构建一个速率限制器。这个速率限制器将根据应用程序的负载来限制请求。

@Slf4j
@Component
public class SystemMetricsCollector {
    private final OperatingSystemMXBean osBean;

    public SystemMetricsCollector() {
        this.osBean = ManagementFactory.getOperatingSystemMXBean();
    }

    public SystemMetrics collectMetrics() {
        double cpuLoad = getProcessCpuLoad();
        long freeMemory = Runtime.getRuntime().freeMemory();
        long totalMemory = Runtime.getRuntime().totalMemory();
        double memoryUsage = 1.0 - (double) freeMemory / totalMemory;

        return new SystemMetrics(cpuLoad, memoryUsage);
    }

    private double getProcessCpuLoad() {
        if (osBean instanceof com.sun.management.OperatingSystemMXBean) {
            return ((com.sun.management.OperatingSystemMXBean) osBean)
                    .getProcessCpuLoad();
        }
        return osBean.getSystemLoadAverage();
    }
}

以及:

@Data
@AllArgsConstructor
public class SystemMetrics {
    private double cpuLoad;
    private double memoryUsage;
}

然后,我们需要创建速率限制的计算组件。

@Component
@Slf4j
public class DynamicRateLimitCalculator {
    private static final int BASE_LIMIT = 100;
    private static final double CPU_THRESHOLD_HIGH = 0.8;
    private static final double CPU_THRESHOLD_MEDIUM = 0.5;
    private static final double MEMORY_THRESHOLD_HIGH = 0.8;
    private static final double MEMORY_THRESHOLD_MEDIUM = 0.5;

    public RateLimitConfig calculateLimit(SystemMetrics metrics) {
        int limit = BASE_LIMIT;
        
        // 根据 CPU 负载调整限制
        limit = adjustLimitBasedOnCpu(limit, metrics.getCpuLoad());
        
        // 根据内存使用率调整限制
        limit = adjustLimitBasedOnMemory(limit, metrics.getMemoryUsage());
        
        Duration refillDuration = calculateRefillDuration(metrics);
        
        log.debug("Calculated rate limit: {}/{}s", limit, 
                  refillDuration.getSeconds());
        
        return new RateLimitConfig(limit, refillDuration);
    }

    private int adjustLimitBasedOnCpu(int currentLimit, double cpuLoad) {
        if (cpuLoad > CPU_THRESHOLD_HIGH) {
            return (int) (currentLimit * 0.3); // 严重减少
        } else if (cpuLoad > CPU_THRESHOLD_MEDIUM) {
            return (int) (currentLimit * 0.6); // 适度减少
        }
        return currentLimit;
    }

    private int adjustLimitBasedOnMemory(int currentLimit, 
                                       double memoryUsage) {
        if (memoryUsage > MEMORY_THRESHOLD_HIGH) {
            return (int) (currentLimit * 0.4);
        } else if (memoryUsage > MEMORY_THRESHOLD_MEDIUM) {
            return (int) (currentLimit * 0.7);
        }
        return currentLimit;
    }

    private Duration calculateRefillDuration(SystemMetrics metrics) {
        double maxLoad = Math.max(metrics.getCpuLoad(), 
                                metrics.getMemoryUsage());
        if (maxLoad > 0.8) {
            return Duration.ofMinutes(2);
        } else if (maxLoad > 0.5) {
            return Duration.ofMinutes(1);
        }
        return Duration.ofSeconds(30);
    }
}

@Data
@AllArgsConstructor
public class RateLimitConfig {
    private int limit;
    private Duration refillDuration;
}

让我们创建一个灵活的速率限制器,它将作为处理程序拦截器。

@Slf4j
@Component
public class DynamicRateLimitInterceptor implements HandlerInterceptor, RateLimitConfigProvider {
    private final Cache<String, Bucket> bucketCache;
    private final SystemMetricsCollector metricsCollector;
    private final DynamicRateLimitCalculator calculator;
    private final AtomicReference<RateLimitConfig> currentConfig;
    private final ScheduledExecutorService scheduler;
    private final RateLimitMetrics metrics;

    public DynamicRateLimitInterceptor(SystemMetricsCollector metricsCollector,
                              DynamicRateLimitCalculator calculator, MeterRegistry meterRegistry) {
        this.metricsCollector = metricsCollector;
        this.calculator = calculator;
        this.currentConfig = new AtomicReference<>(
                new RateLimitConfig(100, Duration.ofMinutes(1))
        );
        this.bucketCache = Caffeine.newBuilder()
                .expireAfterWrite(1, TimeUnit.HOURS)
                .build();
        this.scheduler = Executors.newSingleThreadScheduledExecutor();
        this.metrics = new RateLimitMetrics(meterRegistry, this);
        startMetricsUpdateTask();
    }

    private void startMetricsUpdateTask() {
        scheduler.scheduleAtFixedRate(
                this::updateRateLimitConfig,
                0,
                10,
                TimeUnit.SECONDS
        );
    }

    private void updateRateLimitConfig() {
        try {
            SystemMetrics metrics = metricsCollector.collectMetrics();
            RateLimitConfig newConfig = calculator.calculateLimit(metrics);

            RateLimitConfig oldConfig = currentConfig.get();
            if (hasSignificantChange(oldConfig, newConfig)) {
                currentConfig.set(newConfig);
                log.info("Rate limit updated: {}/{}s",
                        newConfig.getLimit(),
                        newConfig.getRefillDuration().getSeconds());

                // Clear cache to force bucket recreation with new limits
                bucketCache.invalidateAll();
            }
        } catch (Exception e) {
            log.error("Error updating rate limit config", e);
        }
    }

    private boolean hasSignificantChange(RateLimitConfig oldConfig,
                                         RateLimitConfig newConfig) {
        double limitChange = Math.abs(1.0 -
                (double) newConfig.getLimit() / oldConfig.getLimit());
        return limitChange > 0.2; // 20% change threshold
    }

    public RateLimitConfig getRateLimitConfig() {
        return this.currentConfig.get();
    }

    @Override
    public boolean preHandle(HttpServletRequest request,
                             HttpServletResponse response,
                             Object handler) throws Exception {
        String path = request.getRequestURI();
        String method = request.getMethod();

        Timer.Sample timerSample = metrics.startTimer();
        boolean rateLimited = false;
        try {
            metrics.recordRequest();

            String clientId = getClientIdentifier(request);
            Bucket bucket = bucketCache.get(clientId, this::createBucket);

            ConsumptionProbe probe = bucket.tryConsumeAndReturnRemaining(1);

            if (probe.isConsumed()) {
                addRateLimitHeaders(response, probe);
                return true;
            }

            metrics.incrementRateLimitExceeded();
            handleRateLimitExceeded(response, probe);
            return false;
        } finally {
            metrics.stopTimer(timerSample, path, method, rateLimited);
        }
    }

    private Bucket createBucket(String clientId) {
        RateLimitConfig config = currentConfig.get();
        return Bucket.builder()
                .addLimit(Bandwidth.classic(
                        config.getLimit(),
                        Refill.intervally(config.getLimit(),
                                config.getRefillDuration())
                ))
                .build();
    }

    private String getClientIdentifier(HttpServletRequest request) {
        // Could combine multiple factors: IP, user ID, API key, etc.
        return request.getRemoteAddr();
    }

    private void addRateLimitHeaders(HttpServletResponse response,
                                     ConsumptionProbe probe) {
        RateLimitConfig config = currentConfig.get();
        response.addHeader("X-Rate-Limit-Limit",
                String.valueOf(config.getLimit()));
        response.addHeader("X-Rate-Limit-Remaining",
                String.valueOf(probe.getRemainingTokens()));
        response.addHeader("X-Rate-Limit-Reset",
                String.valueOf(probe.getNanosToWaitForRefill() /
                        1_000_000_000));
    }

    private void handleRateLimitExceeded(HttpServletResponse response,
                                         ConsumptionProbe probe)
            throws IOException {
        response.setStatus(HttpStatus.TOO_MANY_REQUESTS.value());
        response.setContentType(MediaType.APPLICATION_JSON_VALUE);

        String errorMessage = String.format(
                "Rate limit exceeded. Try again in %d seconds",
                probe.getNanosToWaitForRefill() / 1_000_000_000
        );

        response.getWriter().write(
                String.format(
                        "{\"error\": \"%s\", \"retryAfter\": %d}",
                        errorMessage,
                        probe.getNanosToWaitForRefill() / 1_000_000_000
                )
        );
    }

    @PreDestroy
    public void shutdown() {
        scheduler.shutdown();
    }
}

配置 Spring Boot 应用程序以使用速率限制器

现在我们需要配置 Spring Boot 应用程序以使用我们实现的速率限制器。

@Configuration
public class RateLimitConfig implements WebMvcConfigurer {
    
    @Autowired
    private DynamicRateLimiter rateLimiter;
    
    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        registry.addInterceptor(rateLimiter)
               .addPathPatterns("/api/**");
    }
}

通过上述配置,我们将 DynamicRateLimiter 注册为一个拦截器,并将其应用于所有以 /api 开头的请求路径。

创建自定义指标以监控应用程序

为了跟踪性能、负载、内存消耗等应用程序的各个方面,创建自定义指标是一个好主意。以下是实现自定义指标的代码:

public class RateLimitMetrics {
    private final MeterRegistry meterRegistry;
    private final Counter rateLimitExceeded;
    private final Counter requestsTotal;
    private final Gauge currentLimit;

    public RateLimitMetrics(MeterRegistry registry,
                            RateLimitConfigProvider configProvider) {
        this.meterRegistry = registry;

        this.rateLimitExceeded = Counter.builder("rate_limit.exceeded")
                .description("Number of rate limit exceeded events")
                .tag("type", "exceeded")
                .register(registry);

        this.requestsTotal = Counter.builder("rate_limit.requests")
                .description("Total number of requests processed")
                .tag("type", "total")
                .register(registry);

        this.currentLimit = Gauge.builder("rate_limit.current",
                        configProvider,
                        this::getCurrentLimit)
                .description("Current rate limit value")
                .tag("type", "limit")
                .register(registry);
    }

    public Timer.Sample startTimer() {
        return Timer.start();
    }

    public void stopTimer(Timer.Sample sample, String path, String method, boolean rateLimited) {
        Timer timer = Timer.builder("rate_limit.request.duration")
                .description("Request duration through rate limiter")
                .tags(
                        "path", path,
                        "method", method,
                        "rate_limited", String.valueOf(rateLimited),
                        "component", "rate_limiter"
                )
                .register(meterRegistry);
        sample.stop(timer);
    }

    public void incrementRateLimitExceeded() {
        rateLimitExceeded.increment();
    }

    public void recordRequest() {
        requestsTotal.increment();
    }

    private double getCurrentLimit(RateLimitConfigProvider provider) {
        return provider.getRateLimitConfig().getLimit();
    }

    public Map<String, Number> getCurrentMetrics() {
        return Map.of(
                "rateLimitExceeded", rateLimitExceeded.count(),
                "totalRequests", requestsTotal.count(),
                "currentLimit", currentLimit.value()
        );
    }
}

通过上述代码,我们创建了以下指标:

  1. rate_limit.exceeded:记录速率限制被触发的次数。
  2. rate_limit.requests:记录处理的请求总数。
  3. rate_limit.current:显示当前的速率限制值。

最佳实践和注意事项

  1. 缓存实现:在生产环境中,使用分布式缓存(如 Redis)来实现集群环境中的速率限制。
  2. 响应头:始终在响应头中包含速率限制信息,以帮助客户端管理其请求速率。常见的头信息包括:
    • X-Rate-Limit-Remaining:剩余的请求次数。
    • X-Rate-Limit-Retry-After-Seconds:需要等待的秒数。
  3. 错误处理:当用户超出速率限制时,提供清晰的错误信息。
  4. 监控:设置指标以跟踪速率限制事件,并根据使用模式调整限制。

结论

本文展示了如何在 Spring Boot 3 应用程序中使用 Bucket4j 实现速率限制。我们介绍了三种方法:基于时间、基于 IP 地址和基于系统负载的速率限制。实际场景可能与本文中的示例有所不同。

速率限制是 API 安全策略的一部分。通过将其与其他安全措施结合使用,你可以构建强大的 API 保护机制。

 

请登录后发表评论

    没有回复内容