题解 | #小红的语言模型推理耗时预测#
小红的语言模型推理耗时预测
https://www.nowcoder.com/practice/6cfc226fdfd34dd2a8e4ee2dcc2bb419
题目链接
题目描述
小红计划构建一个线性回归模型来预估语言模型的推理耗时。模型包含三个特征:协议连接数 $x_1$、包转发率 $x_2$ 和内存占用百分比 $x_3$,预测值为 $\hat{y} = w_0 + w_1 x_1 + w_2 x_2 + w_3 x_3$。
训练过程如下:
- 特征归一化:对每一列特征进行 Min-Max 归一化,即 $x' = \dfrac{x - \min}{\max - \min}$。若 $\max = \min$,则归一化值为 $0$。
- 权重训练:初始化权重 $w_0 = w_1 = w_2 = w_3 = 0$。进行 $n$ 轮迭代,使用批量梯度下降(学习率为 $\alpha$)。
- 权重还原:将归一化空间下的权重还原到原始量纲。
输出还原后的 $w_0, w_1, w_2, w_3$,结果使用银行家舍入法保留 2 位小数。
解题思路
本题可以通过数学推导对梯度下降过程进行算法优化:
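- 梯度公式优化: 传统的 BGD 每一轮需要 $O(m)$ 计算梯度。记 $x_{i0} = 1$,展开梯度公式:
  $$\frac{\partial L}{\partial w_j} = \frac{1}{m}\sum_{i=1}^{m}\Big(\sum_{k=0}^{3} w_k x_{ik} - y_i\Big)x_{ij} = \frac{1}{m}\Big(\sum_{k=0}^{3} w_k \sum_{i=1}^{m} x_{ik}x_{ij} - \sum_{i=1}^{m} y_i x_{ij}\Big)$$
  我们可以预先计算特征的各项累加和(如 $\sum_i x_{ij}$、$\sum_i x_{ij}x_{ik}$、$\sum_i y_i x_{ij}$),使每轮迭代的复杂度从 $O(m)$ 降至 $O(1)$。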
- 复杂度提升: 优化后的总时间复杂度为 $O(m + n)$,相比原先的 $O(m \cdot n)$ 在大规模数据下更具优势。
- 银行家舍入: 严格执行“四舍六入五取偶”。在 Java 中使用 `BigDecimal`,在 Python 中使用 `decimal` 模块,在 C++ 中使用 `nearbyint` 配合 `FE_TONEAREST`。
代码
#include <iostream>
#include <vector>
#include <algorithm>
#include <iomanip>
#include <cmath>
#include <cfenv>
using namespace std;
// Rounds `val` to two decimal places using the floating-point environment's
// round-to-nearest mode (ties go to the even neighbour).
// The value is scaled by 100, rounded to an integer and scaled back, so ties
// are judged on the binary double `val * 100.0`, not on a decimal string.
// Side effect: sets the current rounding mode to FE_TONEAREST.
double banker_round_2(double val) {
    std::fesetround(FE_TONEAREST);
    const double scaled = val * 100.0;
    const double nearest = std::nearbyint(scaled);
    return nearest / 100.0;
}
// Reads `m n alpha` followed by m samples (three features and a target) from
// stdin, runs n rounds of batch gradient descent on Min-Max normalized
// features, maps the weights back to the original feature scale and prints
// the four weights (bias first) with two decimals, separated by spaces.
int main() {
    int sampleCount, rounds;
    double rate;
    std::cin >> sampleCount >> rounds >> rate;

    std::vector<std::vector<double>> feat(sampleCount, std::vector<double>(3));
    std::vector<double> target(sampleCount);
    std::vector<double> lo(3, 1e18), hi(3, -1e18);
    for (int i = 0; i < sampleCount; ++i) {
        for (int j = 0; j < 3; ++j) {
            std::cin >> feat[i][j];
            // Track the per-column minimum and maximum while reading.
            if (feat[i][j] < lo[j]) lo[j] = feat[i][j];
            if (hi[j] < feat[i][j]) hi[j] = feat[i][j];
        }
        std::cin >> target[i];
    }

    // Pre-computed sums over the normalized data, laid out as a symmetric
    // 4x4 matrix `normal` (index 0 is the bias column of ones) plus `rhs`:
    //   normal[0][j+1]   = sum of feature j
    //   normal[j+1][j+1] = sum of feature j squared
    //   normal[a+1][b+1] = sum of feature a * feature b
    //   rhs[0] = sum of y, rhs[j+1] = sum of y * feature j
    double normal[4][4] = {{0}};
    double rhs[4] = {0};
    for (int i = 0; i < sampleCount; ++i) {
        double z[3];
        for (int j = 0; j < 3; ++j) {
            const double span = hi[j] - lo[j];
            // A constant column normalizes to 0 instead of dividing by zero.
            if (span == 0) {
                z[j] = 0;
            } else {
                z[j] = (feat[i][j] - lo[j]) / span;
            }
            normal[0][j + 1] += z[j];
            normal[j + 1][j + 1] += z[j] * z[j];
            rhs[j + 1] += target[i] * z[j];
        }
        rhs[0] += target[i];
        normal[1][2] += z[0] * z[1];
        normal[1][3] += z[0] * z[2];
        normal[2][3] += z[1] * z[2];
    }
    normal[0][0] = sampleCount;
    // Mirror the upper triangle into the lower one.
    for (int r = 1; r < 4; ++r) {
        for (int c = 0; c < r; ++c) {
            normal[r][c] = normal[c][r];
        }
    }

    double w[4] = {0};
    for (int round = 0; round < rounds; ++round) {
        // All four gradients are evaluated from the same weights before any
        // weight is updated.
        double grad[4];
        for (int k = 0; k < 4; ++k) {
            grad[k] = (w[0] * normal[k][0] + w[1] * normal[k][1] + w[2] * normal[k][2] +
                       w[3] * normal[k][3] - rhs[k]) / sampleCount;
        }
        for (int k = 0; k < 4; ++k) {
            w[k] -= rate * grad[k];
        }
    }

    // Undo the normalization: each slope is divided by its column range and
    // the bias absorbs the shift introduced by subtracting the minimum.
    double restored[4];
    double shift = 0;
    for (int j = 0; j < 3; ++j) {
        const double span = hi[j] - lo[j];
        if (span == 0) {
            restored[j + 1] = 0;
            continue;
        }
        restored[j + 1] = w[j + 1] / span;
        shift += (w[j + 1] * lo[j]) / span;
    }
    restored[0] = w[0] - shift;

    std::cout << std::fixed << std::setprecision(2);
    for (int i = 0; i < 4; ++i) {
        std::cout << banker_round_2(restored[i]);
        if (i != 3) std::cout << " ";
    }
    std::cout << std::endl;
    return 0;
}
import java.util.*;
import java.math.BigDecimal;
import java.math.RoundingMode;
/**
 * Fits a three-feature linear regression with batch gradient descent on
 * Min-Max normalized features and prints the weights restored to the
 * original feature scale.
 *
 * <p>Input (stdin, whitespace separated): {@code m n alpha}, then {@code m}
 * samples of {@code x1 x2 x3 y}. Output (stdout): {@code w0 w1 w2 w3}, each
 * rounded half-even to two decimals, separated by single spaces.
 */
public class Main {
    /** Number of input features per sample. */
    private static final int FEATURES = 3;

    /**
     * Reads the problem input, trains the model and prints the restored weights.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        // try-with-resources releases the scanner on every exit path.
        try (Scanner in = new Scanner(System.in).useLocale(Locale.US)) {
            int m = in.nextInt();
            int n = in.nextInt();
            double alpha = in.nextDouble();
            double[][] x = new double[m][FEATURES];
            double[] y = new double[m];
            for (int i = 0; i < m; i++) {
                for (int j = 0; j < FEATURES; j++) {
                    x[i][j] = in.nextDouble();
                }
                y[i] = in.nextDouble();
            }

            double[] minV = new double[FEATURES];
            double[] maxV = new double[FEATURES];
            columnBounds(x, minV, maxV);
            double[] w = train(x, y, minV, maxV, n, alpha);
            System.out.println(format(denormalize(w, minV, maxV)));
        }
    }

    /** Fills {@code minV}/{@code maxV} with the per-column minimum and maximum of {@code x}. */
    private static void columnBounds(double[][] x, double[] minV, double[] maxV) {
        Arrays.fill(minV, Double.MAX_VALUE);
        Arrays.fill(maxV, -Double.MAX_VALUE);
        for (double[] row : x) {
            for (int j = 0; j < FEATURES; j++) {
                minV[j] = Math.min(minV[j], row[j]);
                maxV[j] = Math.max(maxV[j], row[j]);
            }
        }
    }

    /**
     * Runs {@code n} rounds of batch gradient descent in normalized space and
     * returns the weights {@code {w0, w1, w2, w3}}.
     *
     * <p>The sums over the data are accumulated once, so each round only
     * combines those sums with the current weights.
     */
    private static double[] train(
            double[][] x, double[] y, double[] minV, double[] maxV, int n, double alpha) {
        int m = x.length;
        double[] s = new double[FEATURES];      // sum of each normalized feature
        double[] s2 = new double[FEATURES];     // sum of squares of each feature
        double[] scross = new double[FEATURES]; // [0]: x1*x2, [1]: x1*x3, [2]: x2*x3
        double[] syx = new double[FEATURES];    // sum of y * feature
        double sy = 0;                          // sum of y
        for (int i = 0; i < m; i++) {
            double[] xn = new double[FEATURES];
            for (int j = 0; j < FEATURES; j++) {
                double r = maxV[j] - minV[j];
                // A constant column normalizes to 0 instead of dividing by zero.
                xn[j] = (r == 0) ? 0 : (x[i][j] - minV[j]) / r;
                s[j] += xn[j];
                s2[j] += xn[j] * xn[j];
                syx[j] += y[i] * xn[j];
            }
            sy += y[i];
            scross[0] += xn[0] * xn[1];
            scross[1] += xn[0] * xn[2];
            scross[2] += xn[1] * xn[2];
        }

        double[] w = new double[FEATURES + 1];
        for (int it = 0; it < n; it++) {
            // All gradients use the weights from the start of the round.
            double g0 = (w[0] * m + w[1] * s[0] + w[2] * s[1] + w[3] * s[2] - sy) / m;
            double g1 = (w[0] * s[0] + w[1] * s2[0] + w[2] * scross[0] + w[3] * scross[1] - syx[0]) / m;
            double g2 = (w[0] * s[1] + w[1] * scross[0] + w[2] * s2[1] + w[3] * scross[2] - syx[1]) / m;
            double g3 = (w[0] * s[2] + w[1] * scross[1] + w[2] * scross[2] + w[3] * s2[2] - syx[2]) / m;
            w[0] -= alpha * g0;
            w[1] -= alpha * g1;
            w[2] -= alpha * g2;
            w[3] -= alpha * g3;
        }
        return w;
    }

    /**
     * Maps normalized-space weights back to the original feature scale: each
     * slope is divided by its column range and the bias absorbs the shift
     * caused by subtracting the column minimum. Constant columns get weight 0.
     */
    private static double[] denormalize(double[] w, double[] minV, double[] maxV) {
        double[] wf = new double[FEATURES + 1];
        double shift = 0;
        for (int k = 1; k <= FEATURES; k++) {
            double r = maxV[k - 1] - minV[k - 1];
            if (r == 0) {
                wf[k] = 0;
            } else {
                wf[k] = w[k] / r;
                shift += (w[k] * minV[k - 1]) / r;
            }
        }
        wf[0] = w[0] - shift;
        return wf;
    }

    /**
     * Formats the weights with half-even rounding to two decimals, joined by
     * single spaces. A scale-2 {@code BigDecimal} always prints exactly two
     * fraction digits, so no manual zero padding is needed.
     */
    private static String format(double[] wf) {
        StringJoiner line = new StringJoiner(" ");
        for (double v : wf) {
            line.add(BigDecimal.valueOf(v).setScale(2, RoundingMode.HALF_EVEN).toPlainString());
        }
        return line.toString();
    }
}
from decimal import Decimal, ROUND_HALF_EVEN
def solve():
    """Fit the linear model on stdin data and print the restored weights.

    Input is read as whitespace-separated tokens: ``m n alpha`` followed by
    ``m`` samples of ``x1 x2 x3 y``. Reading tokens instead of whole lines
    lets the header arrive either on one line or one value per line.

    Prints ``w0 w1 w2 w3`` rounded half-even to two decimals, space separated.
    """
    import sys  # local import so the function does not rely on file-level imports

    tokens = sys.stdin.read().split()
    m, n, alpha = int(tokens[0]), int(tokens[1]), float(tokens[2])
    flat = [float(t) for t in tokens[3:3 + 4 * m]]
    data = [flat[4 * i:4 * i + 4] for i in range(m)]

    min_v = [min(row[j] for row in data) for j in range(3)]
    max_v = [max(row[j] for row in data) for j in range(3)]

    # Sums over the normalized data, accumulated once so that each descent
    # round only combines them with the current weights.
    s, s2, scross = [0.0] * 3, [0.0] * 3, [0.0] * 3  # scross: x1*x2, x1*x3, x2*x3
    sy, syx = 0.0, [0.0] * 3
    for row in data:
        xn = []
        for j in range(3):
            rng = max_v[j] - min_v[j]
            # A constant column normalizes to 0 instead of dividing by zero.
            val = (row[j] - min_v[j]) / rng if rng != 0 else 0.0
            xn.append(val)
            s[j] += val
            s2[j] += val * val
            syx[j] += row[3] * val
        sy += row[3]
        scross[0] += xn[0] * xn[1]
        scross[1] += xn[0] * xn[2]
        scross[2] += xn[1] * xn[2]

    w = [0.0] * 4
    for _ in range(n):
        # All gradients use the weights from the start of the round.
        g0 = (w[0]*m + w[1]*s[0] + w[2]*s[1] + w[3]*s[2] - sy) / m
        g1 = (w[0]*s[0] + w[1]*s2[0] + w[2]*scross[0] + w[3]*scross[1] - syx[0]) / m
        g2 = (w[0]*s[1] + w[1]*scross[0] + w[2]*s2[1] + w[3]*scross[2] - syx[1]) / m
        g3 = (w[0]*s[2] + w[1]*scross[1] + w[2]*scross[2] + w[3]*s2[2] - syx[2]) / m
        w[0] -= alpha * g0
        w[1] -= alpha * g1
        w[2] -= alpha * g2
        w[3] -= alpha * g3

    # Undo the normalization: slopes are divided by the column range and the
    # bias absorbs the shift caused by subtracting the column minimum.
    wf = [0.0] * 4
    sum_t = 0.0
    for k in range(1, 4):
        rng = max_v[k-1] - min_v[k-1]
        if rng == 0:
            wf[k] = 0.0
        else:
            wf[k] = w[k] / rng
            sum_t += (w[k] * min_v[k-1]) / rng
    wf[0] = w[0] - sum_t

    def banker_round(val):
        """Return ``val`` as a string rounded half-even to two decimals."""
        d = Decimal(str(val)).quantize(Decimal("0.00"), rounding=ROUND_HALF_EVEN)
        return f"{d:.2f}"

    print(" ".join(banker_round(v) for v in wf))
# Run the solver only when executed as a script, so importing this module
# does not block on stdin.
if __name__ == "__main__":
    solve()
算法及复杂度
- 算法:线性回归 + BGD 优化(预计算特征项和)。
- 时间复杂度:$O(m + n)$。预处理扫描一次数据 $O(m)$,后续迭代由于使用预计算的和,单次迭代 $O(1)$。
- 空间复杂度:$O(m)$(需要保存全部样本,以便在求出各列最值后再做归一化)。
查看10道真题和解析