模型能生成文本只是开始,能稳定快速地服务才是工程

项目目标

训练好的模型只输出 logits。产品需要的是文本、速度、成本、稳定性和可控性。

这一段路线处理推理系统:如何从概率分布采样,如何减少重复计算,如何应对长上下文,如何理解 FlashAttention、PagedAttention、量化、batching、硬件带宽这些工程词。

项目 9:做一个 sampling dashboard

生成文本时,模型每一步都会输出一个 logits 向量。采样策略决定如何从这个分布里选下一个 token。

需要支持:

参数作用
temperature控制分布尖锐程度
top-k只在概率最高的 k 个 token 中采样
top-p只在累计概率达到 p 的候选中采样
repetition penalty惩罚重复 token
greedy / argmax每次取最大概率

要画的图:

  • temperature 从低到高时的 entropy 曲线。
  • top-k/top-p 对候选集合大小的影响。
  • 同一个 prompt 下不同参数的输出样例。

破坏实验:temperature=0 或过低时,观察重复;temperature 过高时,观察语义漂移。

最小实现参考(一段就能演示全部采样策略):

import torch
import torch.nn.functional as F

def sample_next(logits, temperature=1.0, top_k=None, top_p=None, repetition_penalty=1.0, prev_tokens=None):
    """logits: (vocab_size,) 一步的 logits 向量"""
    # 1) repetition penalty:对历史出现过的 token 降权
    if repetition_penalty != 1.0 and prev_tokens is not None:
        for t in set(prev_tokens):
            logits[t] = logits[t] / repetition_penalty if logits[t] > 0 else logits[t] * repetition_penalty

    # 2) temperature
    logits = logits / max(temperature, 1e-5)

    # 3) top-k:只保留 logits 最高的 k 个,其它设 -inf
    if top_k is not None and top_k > 0:
        thresh = torch.topk(logits, top_k).values[-1]
        logits = torch.where(logits < thresh, torch.full_like(logits, float('-inf')), logits)

    # 4) top-p(nucleus sampling):累计概率达 p 之后截断
    if top_p is not None and 0 < top_p < 1:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        to_remove = cum_probs > top_p
        # 至少保留最高那个
        to_remove[..., 1:] = to_remove[..., :-1].clone()
        to_remove[..., 0] = False
        sorted_logits[to_remove] = float('-inf')
        # 还原到原顺序
        logits = torch.empty_like(logits).scatter_(0, sorted_idx, sorted_logits)

    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()

# entropy 监控(让 dashboard 上的曲线有数据来源)
def entropy(logits, temperature=1.0):
    p = torch.softmax(logits / temperature, dim=-1)
    return -(p * torch.log(p + 1e-10)).sum().item()

temperature 从 0.2 扫到 2.0,记录 entropy 和被采样 token 多样性,就是 sampling dashboard 的核心数据。

项目 10:实现 speculative decoding

自回归生成慢,是因为第 N 个 token 必须等第 N-1 个 token 出来。speculative decoding 用小模型先猜多个 token,再让大模型批量验证。

最小实现:

  1. draft model 一次生成 k 个候选 token。
  2. target model 对这些 token 做验证。
  3. 接受匹配的前缀。
  4. 第一个不匹配的位置重新采样。

观察指标:

  • 接受率。
  • tokens/sec。
  • draft model 越小是否越快。
  • draft model 太弱时接受率是否下降。

这个项目的重点不是追求极限性能,而是理解“用额外计算换并行验证”的思想。

项目 11:实现 KV cache

没有 KV cache 时,每生成一个新 token,都要重新计算前面所有 token 的 K/V。KV cache 把历史 K/V 存下来,新 token 只计算自己,然后和缓存拼接。

要做两个版本:

no_cache_generate(prompt)
cache_generate(prompt)

比较:

  • prefill 时间。
  • decode 单 token 时间。
  • 序列越长时延迟如何增长。
  • cache 显存占用如何随层数、head 数、上下文长度增长。

KV cache 是训练和推理的关键差异之一。训练时整段序列并行;推理时逐 token 生成,cache 决定能不能把历史计算复用起来。

最小 KV cache 实现,看 cache 怎么从无到有改变 attention 调用:

class CachedSelfAttention(nn.Module):
    def __init__(self, n_embd, n_head):
        super().__init__()
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.qkv = nn.Linear(n_embd, 3 * n_embd, bias=False)
        self.out = nn.Linear(n_embd, n_embd, bias=False)

    def forward(self, x, past_kv=None):
        """
        x: (B, T_new, C)
        past_kv: tuple (past_k, past_v),每个 shape (B, H, T_past, D)
        return: y, new_kv
        """
        B, T, C = x.shape
        qkv = self.qkv(x).chunk(3, dim=-1)
        q, k, v = [t.view(B, T, self.n_head, self.head_dim).transpose(1, 2) for t in qkv]
        # q: (B, H, T_new, D)

        if past_kv is not None:
            past_k, past_v = past_kv
            k = torch.cat([past_k, k], dim=2)   # 拼接历史
            v = torch.cat([past_v, v], dim=2)

        # attention over [past + new]
        att = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        att = F.softmax(att, dim=-1)
        y = (att @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.out(y), (k, v)

测试有无 cache 的延迟差异:

import time
prompt_len, gen_len = 256, 128

# 无 cache:每步都重新跑全 prompt + 已生成
t0 = time.time()
ids = prompt.clone()
for _ in range(gen_len):
    logits = model(ids)[:, -1, :]   # 整段重算
    next_id = sample_next(logits[0])
    ids = torch.cat([ids, torch.tensor([[next_id]])], dim=1)
no_cache_t = time.time() - t0

# 有 cache:第一次跑全 prompt,之后每步只跑一个 token
t0 = time.time()
ids = prompt.clone()
past = None
logits, past = model_with_cache(ids, past_kv=past)
for _ in range(gen_len):
    next_id = sample_next(logits[0, -1, :])
    logits, past = model_with_cache(
        torch.tensor([[next_id]]), past_kv=past
    )
cache_t = time.time() - t0

print(f"无 cache: {no_cache_t:.2f}s   有 cache: {cache_t:.2f}s   加速 {no_cache_t/cache_t:.1f}x")

序列越长加速越明显——256 prompt + 128 生成时 cache 通常 5~10x 加速,1K + 1K 时能到 20~30x。

项目 12:实现 MQA、GQA,并理解 MLA

标准 multi-head attention 每个 query head 都有自己的 K/V head。这样表达力强,但 KV cache 很大。

MQA 让所有 query head 共享一组 K/V。GQA 折中,让多个 query head 共享一组 K/V。MLA 进一步把 KV 压到低秩 latent 表示。

最小实验:

方案KV head 数观察
MHA等于 query headcache 最大,质量基准
GQA少于 query headcache 下降,质量通常较稳
MQA1cache 最小,可能损质量

要记录 tokens/sec、显存、验证 loss 或小任务准确率。

项目 13-14:长上下文实验

长上下文不是把 block_size 调大就结束。它会带来三类问题:

  • 计算和显存成本上升。
  • 模型可能丢失远处信息。
  • attention 可能被无关 token 稀释。

要做的实验:

  1. sliding-window attention:只看最近窗口。
  2. attention sink:保留开头少量 token。
  3. RoPE scaling / YaRN-style interpolation:扩展位置范围。
  4. 长文 needle-in-a-haystack:在长上下文中找特定事实。

观察图:

  • context length vs latency。
  • context length vs memory。
  • context length vs retrieval accuracy。
  • perplexity vs context length。

项目 15:对比 naive attention、SDPA、FlashAttention

数学上都是 attention,但工程性能不同。

naive attention 会显式构造 (T, T) attention matrix。序列长时,内存读写成为瓶颈。FlashAttention 的核心思路是优化内存访问,减少高带宽显存读写,而不是改变 attention 的数学结果。

实验对比 naive vs SDPA(PyTorch 内置,会自动选 FlashAttention 后端):

import torch
import torch.nn.functional as F
import time

B, H, T, D = 4, 12, 2048, 64       # batch / head / seq / head_dim
q = torch.randn(B, H, T, D, device='cuda', dtype=torch.float16)
k = torch.randn(B, H, T, D, device='cuda', dtype=torch.float16)
v = torch.randn(B, H, T, D, device='cuda', dtype=torch.float16)

# 1) naive attention(显式构造 T×T attention 矩阵)
def naive_attn(q, k, v):
    scores = (q @ k.transpose(-2, -1)) / (D ** 0.5)
    return F.softmax(scores, dim=-1) @ v

# 2) SDPA(CUDA 后端会自动用 FlashAttention v2 / mem-efficient)
def sdpa_attn(q, k, v):
    return F.scaled_dot_product_attention(q, k, v)

# warmup
for _ in range(3):
    naive_attn(q, k, v); sdpa_attn(q, k, v)
torch.cuda.synchronize()

# benchmark
for name, fn in [("naive", naive_attn), ("SDPA", sdpa_attn)]:
    torch.cuda.reset_peak_memory_stats()
    t0 = time.time()
    for _ in range(50):
        out = fn(q, k, v)
    torch.cuda.synchronize()
    dt = (time.time() - t0) / 50 * 1000
    mem = torch.cuda.max_memory_allocated() / 1024 / 1024
    print(f"{name:8s}  {dt:.2f} ms  峰值显存 {mem:.0f} MB")

T=2048 在 A100 上典型差距:

naive     12.50 ms  峰值显存 800 MB
SDPA       2.10 ms  峰值显存  60 MB

核心结论:两者数学结果相同,差距全在内存访问。naive 实现要落 T×T 的 attention 矩阵到 HBM,T=2048 就是 16 MB × batch × head 数,每步都要读写一遍;FlashAttention 用 tile + recompute,把中间结果留在 SRAM,HBM 访问次数降一个数量级。这就是"同一公式,工程实现差几倍"的来源

项目 16:建立硬件预算表

推理优化最后都要落到硬件预算:

维度问题
参数量权重能否放进显存
precisionFP16/BF16/INT8/INT4 对质量和速度的影响
memory bandwidth单 token 解码是否被带宽限制
KV cache并发和上下文长度的主要显存压力
batch size吞吐与首 token 延迟的取舍
interconnect多 GPU tensor parallel 是否被通信拖慢

这张表会直接服务于部署选型:llama.cpp、vLLM、TensorRT-LLM、SGLang、ExLlamaV2 并不是“谁更高级”,而是在不同硬件、模型、并发、量化格式下适合不同任务。

本项目交付物

  • 一个 sampling dashboard。
  • speculative decoding 最小实现。
  • 有/无 KV cache 的 benchmark。
  • MHA/GQA/MQA cache 对比表。
  • 长上下文压力测试脚本。
  • naive attention vs SDPA/FlashAttention 的性能图。
  • 一张自己的硬件预算表。

和本站已有内容的连接

  • inference-opt/01-推理为什么慢
  • inference-opt/02-模型量化INT8与INT4
  • inference-opt/06-推理引擎横评
  • vllm/vLLM高性能推理部署
  • vllm/vLLM-speculative-decoding
  • llm-app/长上下文处理
  • llm-app/成本优化

延伸阅读

版权声明: 如无特别声明,本文版权归 sshipanoo 所有,转载请注明本文链接。

(采用 CC BY-NC-SA 4.0 许可协议进行授权)

本文标题:项目 9-16:Decoding、KV Cache、长上下文与推理系统

本文链接:https://www.sshipanoo.com/blog/ai/llm-roadmap/04-DecodingKVCache与推理系统/

本文最后一次更新为 天前,文章中的某些内容可能已过时!