# 训练时(全量计算) # 为每个训练样本前向传播计算完整的注意力 # 序列长度 L, 头数 H, 维度 D attention_scores = Q @ K.T # [B, H, L, L] # 需要存储完整注意力矩阵用于反向传播
# 推理时(增量解码)
# 假设已生成前 t-1 个 token
# 只计算新 token 的 query,复用之前的 KV
new_Q = W_q @ x_t # [B, H, 1, D]
# KV Cache 存储了前 t-1 个 token 的 K, V
K_cache = [K_1, K_2, ..., K_{t-1}] # [B, H, t-1, D]
V_cache = [V_1, V_2, ..., V_{t-1}] # [B, H, t-1, D]
# 只计算新 token 与前 t-1 个 token 的注意力
attention = softmax(new_Q @ K_cache.T / sqrt(D))