首页 > 试题广场 >

实现自注意力机制

[编程题]实现自注意力机制
  • 热度指数:1376 时间限制:C/C++ 1秒,其他语言2秒 空间限制:C/C++ 256M,其他语言512M
  • 算法知识视频讲解
实现自注意力机制,这是转换器模型的基本组成部分,广泛应用于自然语言处理和计算机视觉任务。自注意力机制允许模型在生成上下文表示时动态关注输入序列的不同部分。

输入描述:
输入为四个numpy数组,分别表示输入序列、查询权重矩阵、键权重矩阵和值权重矩阵。



输出描述:
返回一个numpy数组,表示自注意力的输出。

示例1

输入

[[1, 0], [0, 1]]
[[1, 0], [0, 1]]
[[1, 0], [0, 1]]
[[1, 2], [3, 4]]

输出

[[1.6604769 2.6604769]
 [2.3395231 3.3395231]]

备注:
1.对应的输入、输出已给出,您只用实现核心功能函数即可。
2.支持numpy、scipy、pandas、scikit-learn库。
import numpy as np
import sys

def compute_qkv(X, W_q, W_k, W_v):
    """
    计算查询矩阵Q、键矩阵K和值矩阵V
    :param X: 输入序列矩阵 (N, D)
    :param W_q: 查询权重矩阵 (D, D_k)
    :param W_k: 键权重矩阵 (D, D_k)
    :param W_v: 值权重矩阵 (D, D_v)
    :return: Q, K, V 矩阵
    """
    # 矩阵乘法计算 Q, K, V
    Q = np.dot(X, W_q)
    K = np.dot(X, W_k)
    V = np.dot(X, W_v)
    return Q, K, V

def self_attention(Q, K, V):
    """
    实现自注意力机制的核心计算逻辑
    :param Q: 查询矩阵
    :param K: 键矩阵
    :param V: 值矩阵
    :return: 注意力机制的输出
    """
    # 获取特征维度 d_k (K的最后一维)
    d_k = K.shape[-1]
    
    # 1. 计算注意力得分: Q * K^T
    # scores 形状为 (N, N)
    scores = np.dot(Q, K.T)
    
    # 2. 缩放: 除以 sqrt(d_k)
    scaled_scores = scores / np.sqrt(d_k)
    
    # 3. Softmax 归一化: 在每一行(最后一个轴)上应用
    # 为了数值稳定性,可以减去每行的最大值,但 numpy.exp 直接计算在此题范围内也可行
    exp_scores = np.exp(scaled_scores)
    # axis=-1 表示按行求和,keepdims=True 保证维度对齐以便广播
    attention_weights = exp_scores / np.sum(exp_scores, axis=-1, keepdims=True)
    
    # 4. 加权求和: 权重矩阵乘 V
    output = np.dot(attention_weights, V)
    
    return output

def main():
    # 使用标准输入读取数据
    # 根据题目描述,输入为四个 numpy 数组的字符串表示
    try:
        input_data = sys.stdin.read().splitlines()
        if not input_data:
            return
        
        # 解析输入序列 X, 权重矩阵 W_q, W_k, W_v
        # 使用 eval 将字符串形式的列表转为 python 列表再转为 numpy 数组
        X = np.array(eval(input_data[0]))
        W_q = np.array(eval(input_data[1]))
        W_k = np.array(eval(input_data[2]))
        W_v = np.array(eval(input_data[3]))
        
        # 第一步:计算 Q, K, V
        Q, K, V = compute_qkv(X, W_q, W_k, W_v)
        
        # 第二步:计算注意力输出
        result = self_attention(Q, K, V)
        
        # 第三步:输出结果
        print(result)
        
    except EOFError:
        pass

if __name__ == "__main__":
    main()


发表于 2026-03-16 21:05:09 回复(0)