redis 限流


redis 限流

redis 令牌桶限流

@dataclass(frozen=True)
class RateLimitResult:
    allowed: bool
    remaining_tokens: float
    retry_after_seconds: float
    now_ms: int


class RedisTokenBucketRateLimiter:
    _TOKEN_BUCKET_LUA = r"""
        local key = KEYS[1]                         -- 令牌桶状态存储的 Redis Key(Hash)
                                                   
        local capacity = tonumber(ARGV[1])          -- 桶容量:最多可累积的令牌数
        local rate = tonumber(ARGV[2])              -- 补充速率:每秒补充多少令牌(tokens/s)
        local requested = tonumber(ARGV[3])         -- 本次请求要消耗的令牌数
                                                   
        local ttl_ms = tonumber(ARGV[4])            -- 令牌桶 Key 的过期时间(毫秒),用于回收冷 Key
                                                   
        local t = redis.call('TIME')                -- 读取 Redis 服务器时间:{秒, 微秒}
                                                   
        local now = tonumber(t[1]) * 1000 + math.floor(tonumber(t[2]) / 1000)  -- 服务器时间换算为毫秒时间戳
                                                   
        local data = redis.call('HMGET', key, 'tokens', 'ts') -- 读取当前令牌数(tokens)与上次更新时间(ts)
        local tokens = tonumber(data[1])            -- 当前令牌数(可能为 nil)
        local ts = tonumber(data[2])                -- 上次更新时间戳(毫秒,可能为 nil)
                                                   
        if tokens == nil then tokens = capacity end -- 初始化:若不存在,则认为桶是满的
        if ts == nil then ts = now end              -- 初始化:若不存在,则将上次更新时间设为当前
        if now < ts then ts = now end               -- 防御:若时钟回拨导致 now < ts,则强制对齐
                                                   
        local delta_ms = now - ts                   -- 距离上次更新过去了多少毫秒
        local refill = (delta_ms / 1000.0) * rate   -- 这段时间应补充的令牌数量(可为小数)
        tokens = math.min(capacity, tokens + refill) -- 将令牌补充后并截断到容量上限
                                                   
        local allowed = 0                           -- 是否允许:0/1(Lua 没有布尔返回给 Python 的统一类型)
        local retry_after_ms = 0                    -- 如果不允许,建议等待多少毫秒再试
                                                   
        if tokens >= requested then                 -- 若令牌充足
        allowed = 1                                 -- 标记允许
        tokens = tokens - requested                 -- 扣除本次消耗
        else                                        -- 否则令牌不足
        allowed = 0                                 -- 标记拒绝
        if rate > 0 then                            -- 若补充速率 > 0,可计算需要等待的时间
            local missing = requested - tokens      -- 还差多少令牌
            retry_after_ms = math.ceil((missing / rate) * 1000) -- 差额按速率换算为等待毫秒数(向上取整)
        else                                        -- 若 rate=0,则永远补不回令牌
            retry_after_ms = -1                     -- 用 -1 表示无法通过等待获得令牌
        end                                         -- 结束 rate>0 分支
        end                                         -- 结束 tokens>=requested 分支
                                                   
        redis.call('HMSET', key, 'tokens', tokens, 'ts', now) -- 写回最新 tokens 与 ts(原子更新)
        if ttl_ms ~= nil and ttl_ms > 0 then        -- 若配置了过期时间且 >0
        redis.call('PEXPIRE', key, ttl_ms)          -- 给令牌桶 Key 设置过期(毫秒)
        end                                         -- 结束 ttl 分支
                                                   
        return {allowed, tokens, retry_after_ms, now} -- 返回:是否允许、剩余令牌、建议等待(ms)、当前时间(ms)
        """

    def __init__(
        self,
        redis: Optional[AsyncRedis] = None,
        *,
        capacity: Union[int, float],
        refill_rate: Union[int, float],
        requested: Union[int, float] = 1,
        ttl_ms: Optional[int] = None,
        key_prefix: str = "rate_limit:token_bucket:",
    ) -> None:
        capacity_f = float(capacity)
        rate_f = float(refill_rate)
        req_f = float(requested)

        if capacity_f <= 0:
            raise ValueError("capacity 必须 > 0")
        if rate_f < 0:
            raise ValueError("refill_rate 必须 >= 0")
        if req_f <= 0:
            raise ValueError("requested 必须 > 0")
        if req_f > capacity_f:
            raise ValueError("requested 不能大于 capacity")

        self._redis = redis or async_redis_client
        self._key_prefix = key_prefix
        self._capacity = capacity_f
        self._refill_rate = rate_f
        self._requested = req_f
        self._ttl_ms = int(ttl_ms) if ttl_ms is not None else None

    def _full_key(self, key: str) -> str:
        return f"{self._key_prefix}{key}"

    @staticmethod
    def _default_ttl_ms(
        capacity: Union[int, float], refill_rate: Union[int, float]
    ) -> int:
        # 让桶在“完全补满所需时间”的 2 倍后过期,避免大量冷 key 常驻
        # 至少 5 秒,防止极小桶/极大 rate 造成频繁抖动
        if refill_rate <= 0:
            return 60_000
        seconds = max(5.0, float(capacity) / float(refill_rate) * 2.0)
        return int(seconds * 1000)

    async def allow(
        self,
        key: str,
        *,
        requested: Optional[Union[int, float]] = None,
        ttl_ms: Optional[int] = None,
    ) -> RateLimitResult:
        """
        尝试拿令牌(令牌桶)。
        - capacity/refill_rate/requested 默认来自 __init__
        - requested/ttl_ms 允许在单次调用中覆盖
        """
        capacity_f = self._capacity
        rate_f = self._refill_rate
        req_f = float(requested) if requested is not None else self._requested

        if req_f <= 0:
            raise ValueError("requested 必须 > 0")
        if req_f > capacity_f:
            raise ValueError("requested 不能大于 capacity")

        ttl_ms_i = (
            int(ttl_ms)
            if ttl_ms is not None
            else (
                self._ttl_ms
                if self._ttl_ms is not None
                else self._default_ttl_ms(capacity_f, rate_f)
            )
        )

        full_key = self._full_key(key)
        res = await self._redis.eval(
            self._TOKEN_BUCKET_LUA,
            1,
            full_key,
            capacity_f,
            rate_f,
            req_f,
            ttl_ms_i,
        )

        allowed = bool(int(res[0]))
        remaining = float(res[1])
        retry_after_ms = int(res[2])
        now_ms = int(res[3])

        retry_after_seconds = 0.0
        if not allowed:
            retry_after_seconds = (
                0.0 if retry_after_ms <= 0 else retry_after_ms / 1000.0
            )

        return RateLimitResult(
            allowed=allowed,
            remaining_tokens=remaining,
            retry_after_seconds=retry_after_seconds,
            now_ms=now_ms,
        )

    async def wait(
        self,
        key: str,
        *,
        requested: Optional[Union[int, float]] = None,
        ttl_ms: Optional[int] = None,
        max_wait_seconds: Optional[float] = None,
        poll_min_sleep: float = 0.01,
    ) -> RateLimitResult:
        """
        一直等到拿到令牌为止(适合 Activity 里对外部依赖做限流)。
        max_wait_seconds: 超过则抛 TimeoutError
        """
        start = time.perf_counter()

        while True:
            r = await self.allow(
                key,
                requested=requested,
                ttl_ms=ttl_ms,
            )
            if r.allowed:
                return r

            sleep_s = max(poll_min_sleep, r.retry_after_seconds)
            if max_wait_seconds is not None:
                elapsed = time.perf_counter() - start
                if elapsed + sleep_s > max_wait_seconds:
                    raise TimeoutError(
                        f"rate limit wait timeout: key={key}, waited={elapsed:.3f}s, next_sleep={sleep_s:.3f}s"
                    )
            await asyncio.sleep(sleep_s)

redis 滑动窗口限流

@dataclass(frozen=True)
class SlidingWindowResult:
    allowed: bool
    current_count: int
    retry_after_seconds: float
    now_ms: int


class RedisSlidingWindowRateLimiter:
    """
    Redis 滑动窗口限流(ZSET)。
    语义:任意连续 window_seconds 秒内最多 limit 次通过。
    """

    _SLIDING_WINDOW_LUA = r"""
    local key = KEYS[1]                              -- 滑动窗口限流使用的 ZSET Key
                                                    
    local window_ms = tonumber(ARGV[1])              -- 窗口长度(毫秒)
    local limit = tonumber(ARGV[2])                  -- 窗口内允许的最大通过次数
    local ttl_ms = tonumber(ARGV[3])                 -- Key 过期时间(毫秒),用于回收冷 Key
                                                    
    -- Redis server time (ms)                        -- 注释:以下用 Redis TIME 作为统一时钟,避免多机时间不一致
    local t = redis.call('TIME')                     -- 读取 Redis 服务器时间:{秒, 微秒}
    local now = tonumber(t[1]) * 1000 + math.floor(tonumber(t[2]) / 1000) -- 转为毫秒时间戳
                                                    
    -- window cleanup                                -- 注释:删除窗口之外的旧请求记录
    local start = now - window_ms                    -- 窗口起点(毫秒)
    redis.call('ZREMRANGEBYSCORE', key, '-inf', start) -- 删除 score <= start 的成员(窗口外)
                                                    
    local cnt = tonumber(redis.call('ZCARD', key))    -- 统计窗口内当前已有多少次请求
                                                    
    if cnt < limit then                              -- 若窗口内次数仍未达到上限(严格 <,避免多放 1 次)
      -- unique member to avoid overwrite under same ms -- 注释:member 唯一化,避免同毫秒覆盖导致少计数
      local seq = redis.call('INCR', key .. ':seq')   -- 递增序列号(辅助唯一 member)
      local member = tostring(now) .. ':' .. tostring(seq) -- member = now:seq
      redis.call('ZADD', key, now, member)            -- 写入本次请求记录:score=now(member 的时间)
      if ttl_ms ~= nil and ttl_ms > 0 then            -- 若配置了过期时间且 >0
        redis.call('PEXPIRE', key, ttl_ms)            -- 给 ZSET Key 设置过期时间
        redis.call('PEXPIRE', key .. ':seq', ttl_ms)  -- 给序列号 Key 也设置过期时间
      end                                             -- 结束 ttl 分支
      return {1, cnt + 1, 0, now}                     -- 返回:允许(1)、窗口内计数(含本次)、建议等待(ms=0)、now(ms)
    else                                              -- 否则窗口内次数已满,需要拒绝
      -- compute retry_after_ms: when the oldest entry exits window -- 注释:计算何时最早记录滑出窗口
      local oldest = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES') -- 取窗口内最早的一条(score 最小)
      local oldest_ts = nil                           -- 最早记录的时间戳(毫秒)
      if oldest ~= nil and #oldest >= 2 then          -- 若存在成员且带 score
        oldest_ts = tonumber(oldest[2])               -- oldest[2] 为 score(ZSET 的时间戳)
      end                                             -- 结束 oldest 解析
      local retry_after = 0                           -- 建议等待的毫秒数
      if oldest_ts ~= nil then                        -- 若能拿到最早时间戳
        retry_after = math.max(0, math.ceil((oldest_ts + window_ms) - now)) -- 等到 oldest_ts+window_ms 才会滑出
      else                                            -- 极端情况:ZSET 为空但 cnt>=limit(理论不应发生)
        retry_after = math.ceil(window_ms)            -- 保守等待一个窗口长度
      end                                             -- 结束 oldest_ts 分支
      if ttl_ms ~= nil and ttl_ms > 0 then            -- 若配置了过期时间且 >0
        redis.call('PEXPIRE', key, ttl_ms)            -- 维持 key 的过期时间(避免热 key 被误删)
        redis.call('PEXPIRE', key .. ':seq', ttl_ms)  -- 维持 seq key 的过期时间
      end                                             -- 结束 ttl 分支
      return {0, cnt, retry_after, now}               -- 返回:拒绝(0)、当前窗口计数、建议等待(ms)、now(ms)
    end                                               -- 结束 if cnt < limit 分支
    """

    def __init__(
        self,
        redis: Optional[AsyncRedis] = None,
        *,
        limit: int,
        window_seconds: float,
        ttl_ms: Optional[int] = None,
        key_prefix: str = "rate_limit:sliding_window:",
    ) -> None:
        if limit <= 0:
            raise ValueError("limit 必须 > 0")
        if window_seconds <= 0:
            raise ValueError("window_seconds 必须 > 0")

        self._redis = redis or async_redis_client
        self._key_prefix = key_prefix
        self._limit = int(limit)
        self._window_seconds = float(window_seconds)
        self._ttl_ms = int(ttl_ms) if ttl_ms is not None else None

    def _full_key(self, key: str) -> str:
        return f"{self._key_prefix}{key}"

    @staticmethod
    def _default_ttl_ms(window_seconds: float) -> int:
        # key 存活时间略大于窗口,方便清理且避免冷 key 常驻
        return int(max(5.0, window_seconds * 2.0) * 1000)

    async def allow(
        self,
        key: str,
        *,
        ttl_ms: Optional[int] = None,
    ) -> SlidingWindowResult:
        window_ms = int(self._window_seconds * 1000)
        ttl_ms_i = (
            int(ttl_ms)
            if ttl_ms is not None
            else (
                self._ttl_ms
                if self._ttl_ms is not None
                else self._default_ttl_ms(self._window_seconds)
            )
        )

        res = await self._redis.eval(
            self._SLIDING_WINDOW_LUA,
            1,
            self._full_key(key),
            window_ms,
            self._limit,
            ttl_ms_i,
        )

        allowed = bool(int(res[0]))
        current_count = int(res[1])
        retry_after_ms = int(res[2])
        now_ms = int(res[3])

        return SlidingWindowResult(
            allowed=allowed,
            current_count=current_count,
            retry_after_seconds=retry_after_ms / 1000.0 if retry_after_ms > 0 else 0.0,
            now_ms=now_ms,
        )

    async def wait(
        self,
        key: str,
        *,
        ttl_ms: Optional[int] = None,
        max_wait_seconds: Optional[float] = None,
        poll_min_sleep: float = 0.01,
    ) -> SlidingWindowResult:
        start = asyncio.get_running_loop().time()
        while True:
            r = await self.allow(
                key,
                ttl_ms=ttl_ms,
            )
            if r.allowed:
                return r

            sleep_s = max(poll_min_sleep, r.retry_after_seconds)
            if max_wait_seconds is not None:
                elapsed = asyncio.get_running_loop().time() - start
                if elapsed + sleep_s > max_wait_seconds:
                    raise TimeoutError(
                        f"sliding window wait timeout: key={key}, waited={elapsed:.3f}s, next_sleep={sleep_s:.3f}s"
                    )
            await asyncio.sleep(sleep_s)

声明:Hello World|版权所有,违者必究|如未注明,均为原创|本网站采用BY-NC-SA协议进行授权

转载:转载请注明原文链接 - redis 限流


我的朋友,理论是灰色的,而生活之树是常青的!