首页 > 试题广场 >

实现Masked Multi-Head Self-Atten

[编程题]实现Masked Multi-Head Self-Atten
  • 热度指数:153 时间限制:C/C++ 1秒,其他语言2秒 空间限制:C/C++ 256M,其他语言512M
  • 算法知识视频讲解
给定批量序列表示 X(形状:[batch, seq, d_model])与权重矩阵 W_Q、W_K、W_V、W_O(均为 d_model×d_model),实现 Masked Multi-Head Self-Attention。  
将最后一维按头数 num_heads 均分,每头维度 d_k = d_model / num_heads。  
计算步骤:  
  1) Q = X @ W_Q,K = X @ W_K,V = X @ W_V。  
  2) 将 Q、K、V reshape 为 [batch, num_heads, seq, d_k]。  
  3) 计算注意力分数 scores = (Q @ K^T) / sqrt(d_k),其中 K^T 表示每头在最后两维做转置得到 [batch, num_heads, seq, seq]。  
  4) 使用下三角因果掩码(只能看见当前及更早位置):掩掉上三角元素(置为一个很小的负数)。  
  5) 在最后一维做 softmax 得到权重,注意数值稳定性(减去每行最大值再做 exp)。  
  6) attention = softmax @ V(形状 [batch, num_heads, seq, d_k])。  
  7) 拼回 [batch, seq, d_model] 后,再右乘 W_O。  
输出保留两位小数,结果需转换为 Python List。

输入描述:
以分号分隔的 6 个参数:num_heads; X; W_Q; W_K; W_V; W_O  
其中 X、W_Q、W_K、W_V、W_O 用 Python 风格的嵌套列表表示。


输出描述:
最终输出张量(形状 [batch, seq, d_model]),四舍五入到小数点后两位,类型为 List。
示例1

输入

2; [[[1, 1], [1, 1], [1, 1]]]; [[1, 0], [0, 1]]; [[1, 0], [0, 1]]; [[1, 0], [0, 1]]; [[1, 0], [0, 1]]

输出

[[[1.00, 1.00], [1.00, 1.00], [1.00, 1.00]]]

说明

权重为单位矩阵,Q=K=V=X。因果掩码使第 i 个位置只看见前 i+1 个位置;由于各位置完全相同,softmax 权重在可见范围内均匀分布,输出与输入一致;乘 W_O(单位)后不变。

备注:
本题由牛友@Charles 整理上传
import numpy as np

readin = input().split(';')

num_heads = int(readin[0])
X = np.array(eval(readin[1]))
W_Q = np.array(eval(readin[2]))
W_K = np.array(eval(readin[3]))
W_V = np.array(eval(readin[4]))
W_O = np.array(eval(readin[5]))

batch_size, seq, d_model = X.shape
d_k = d_model // num_heads

result = []

mask = np.zeros([seq, seq])
for ii in range(seq):
    for jj in range(seq):
        if ii < jj:
            mask[ii, jj] = -1e9

for ii in range(batch_size):
    Q1 = X[ii] @ W_Q
    K1 = X[ii] @ W_K
    V1 = X[ii] @ W_V

    attention = np.zeros([seq, d_model])
    for jj in range(num_heads):
        Q = Q1[:, (jj*d_k):((jj+1)*d_k)]
        K = K1[:, (jj*d_k):((jj+1)*d_k)]
        V = V1[:, (jj*d_k):((jj+1)*d_k)]

        scores = (Q @ K.T + mask) / np.sqrt(d_k)
        max_in_line = np.max(scores, axis=1)
        scores = scores - max_in_line[:, np.newaxis]
        softmax = np.exp(scores)
        softmax = softmax / np.sum(softmax, axis=1)[:, np.newaxis]
        attention[:, (jj*d_k):((jj+1)*d_k)] = softmax @ V

    out = attention @ W_O

    result.append(out)

output = '['
for ii in range(batch_size):
    output += '['
    for jj in range(seq):
        output += '['
        for kk in range(d_model):
            if kk < d_model - 1:
                output += '%.2f, ' % result[ii][jj, kk]
            else:
                output += '%.2f' % result[ii][jj, kk]
        if jj < seq - 1:
            output += '], '
        else:
            output += ']'
    if ii < batch_size - 1:
        output += '], '
    else:
        output += ']'
output += ']'
print(output)

发表于 2025-10-09 18:27:51 回复(0)