华为ai算法笔试 华为笔试 华为秋招 1010
笔试时间:2025年10月10日
往年笔试合集:
第一题:经典LSTM模型结构实现
长短期记忆网络(Long Short-Term Memory, LSTM)是一种特殊的循环神经网络(RNN),旨在解决传统RNN中存在的梯度消失和梯度爆炸问题,使其能够有效地学习长期依赖关系。
一个LSTM单元(Cell)的核心由三个关键的门和一个细胞状态(Cell State)组成:
- 细胞状态(Cell State):这是LSTM的"记忆高速公路",信息沿着这条路径从一个时间步传递到下一个。
- 遗忘门(Forget Gate):决定从上一时间步的细胞状态中丢弃哪些信息。通过Sigmoid激活函数,输出介于0和1之间的向量。0表示完全遗忘,1表示完全保留。
- 输入门(Input Gate):决定将哪些新的信息添加到细胞状态中。包含两个部分:一个Sigmoid层,用于决定哪些信息需要更新一个Tanh层,用于创建一个新的候选细胞状态
- 输出门(Output Gate):决定当前时间步的隐藏状态(Hidden State)将输出哪些信息。
根据给定的LSTM计算结构示意图,实现一个LSTM模型的关键函数。该LSTM模型包含5个LSTM Cell(每个Cell权重参数为wf、wi、wg、wo,对应偏置为bf、bi、bg、bo)。模型会依次处理输入序列的每个时间步(t=1到t=4),每个时间步的计算都会产生一个5维的隐藏状态h。
需要针对不同输入矩阵运行LSTM,并输出每个时间步隐藏状态h的首元素,按时间步顺序组成。
每个输入矩阵形状为4×7,即:
- 时间步长(sequence_length)= 4
- 输入数据维度(x_dim)= 7
输入描述
一行数据,描述输入矩阵:
- 前2个数据为整型:sequence_length和x_dim
- 后续为输入矩阵的浮点数据,按行平铺(flatten)为一维序列,数据用空格分隔
输出描述
一行数据,表示每个时间步隐藏层状态h的首元素,按时间步顺序输出,数据间用空格分隔。
精度要求:
- 四舍五入到小数点后3位
- 尾部多余的0去除,例如:0.200输出为0.2
- 特殊情况:如果值为0、0.000、0.00、0.0,统一输出为0.0
样例输入
4 7 -1.153285 -0.081943 0.464549 3.411137 0.594197 1.21088 -0.234899 -0.272196 0.27
样例输出
0.01 0.003 0.009 -0.019
参考题解
解题思路:
LSTM通过三个门控机制(遗忘门、输入门、输出门)和一个细胞状态来维护长期记忆。每个时间步的计算都会更新细胞状态C和隐藏状态h。
计算步骤:
- 初始化状态:隐藏状态h和细胞状态C(5维零向量)
- 每个时间步的计算流程:计算遗忘门:f_t = sigmoid(W_f · [h_{t-1}, x_t] + b_f)计算输入门和候选状态:i_t = sigmoid(W_i · [h_{t-1}, x_t] + b_i),g_t = tanh(W_g · [h_{t-1}, x_t] + b_g)更新细胞状态:C_t = f_t * C_{t-1} + i_t * g_t计算输出门和隐藏状态:o_t = sigmoid(W_o · [h_{t-1}, x_t] + b_o),h_t = o_t * tanh(C_t)
- 输出处理:只取每个时间步隐藏状态h的第一个元素,进行格式化处理
Python:
import numpy as np def f1(x): return 1 / (1 + np.exp(-x)) def f2(x): return np.tanh(x) np.random.seed(42) d1 = 5 d2 = 7 M1 = np.random.randn(d1, d1 + d2) * 0.01 M2 = np.random.randn(d1, d1 + d2) * 0.01 M3 = np.random.randn(d1, d1 + d2) * 0.01 M4 = np.random.randn(d1, d1 + d2) * 0.01 v1 = np.random.randn(d1) * 0.01 v2 = np.random.randn(d1) * 0.01 v3 = np.random.randn(d1) * 0.01 v4 = np.random.randn(d1) * 0.01 def calc(a, b, c, d): return f1(np.dot(a, np.concatenate((b, c))) + d) def calc2(a, b, c, d): return f2(np.dot(a, np.concatenate((b, c))) + d) data = input().strip().split() n = int(data[0]) d2 = int(data[1]) vals = list(
剩余60%内容,订阅专栏后可继续查看/也可单篇购买
2025打怪升级记录,大厂笔试合集 C++, Java, Python等多种语言做法集合指南