题解 | #小红的语言模型推理耗时预测#

小红的语言模型推理耗时预测

https://www.nowcoder.com/practice/6cfc226fdfd34dd2a8e4ee2dcc2bb419

题目链接

小红的语言模型推理耗时预测

题目描述

小红计划构建一个线性回归模型来预估语言模型的推理耗时。模型包含三个特征:协议连接数 $x_1$、包转发率 $x_2$ 和内存占用百分比 $x_3$,预测目标为推理耗时 $y$,模型形式为 $\hat{y} = w_0 + w_1 x_1 + w_2 x_2 + w_3 x_3$。训练过程如下:

  1. 特征归一化:对每一列特征进行 Min-Max 归一化,即 $x' = \dfrac{x - \min(x)}{\max(x) - \min(x)}$;若 $\max(x) = \min(x)$,则归一化值为 $0$。
  2. 权重训练:初始化权重 $w_0 = w_1 = w_2 = w_3 = 0$,以学习率 $\alpha$ 进行 $n$ 轮迭代,使用批量梯度下降(BGD)。
  3. 权重还原:将归一化空间下的权重还原到原始量纲。

输出还原后的 $w_0, w_1, w_2, w_3$,结果使用银行家舍入法保留 2 位小数。

解题思路

本题可以通过数学推导对梯度下降过程进行算法优化

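  1. 梯度公式优化: 传统的 BGD 每一轮都需要遍历全部 $m$ 个样本,以 $O(m)$ 的代价计算梯度。展开梯度公式(记 $x'_{i0} = 1$):$\dfrac{\partial J}{\partial w_j} = \dfrac{1}{m}\sum_{i=1}^{m}\Big(\sum_{k=0}^{3} w_k x'_{ik} - y_i\Big)x'_{ij} = \dfrac{1}{m}\Big(\sum_{k=0}^{3} w_k \sum_{i=1}^{m} x'_{ik}x'_{ij} - \sum_{i=1}^{m} y_i x'_{ij}\Big)$。我们可以预先计算特征的各项累加和(如 $\sum_i x'_{ij}$、$\sum_i x'^2_{ij}$、$\sum_i x'_{ij}x'_{ik}$、$\sum_i y_i x'_{ij}$、$\sum_i y_i$),使每轮迭代的复杂度从 $O(m)$ 降至 $O(1)$。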

  2. 复杂度优化: 优化后的总时间复杂度为 $O(m + n)$,相比原先的 $O(mn)$ 在大规模数据下更具优势。

  3. 银行家舍入: 严格执行“四舍六入五取偶”。在 Java 中使用 BigDecimal,在 Python 中使用 decimal 模块;在 C++ 中若使用 nearbyint 配合 FE_TONEAREST,需注意 val*100 本身存在浮点误差(例如 1.015*100 的计算结果略小于 101.5),对末位恰为 5 的值可能得到与十进制舍入不同的结果。

代码

#include <algorithm>
#include <cfenv>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <limits>
#include <string>
#include <vector>

using namespace std;

// Rounds val to two decimal places using banker's rounding (round half to even).
//
// The tie decision is made on the shortest decimal string that parses back to
// val, not on val * 100.0: that product is itself rounded to a double, so e.g.
// 1.015 * 100.0 evaluates to 101.49999999999999 and nearbyint() would give
// 1.01 instead of the half-even result 1.02.
//
// Non-finite inputs are returned unchanged. A negative value that rounds to
// zero keeps its sign (e.g. -0.004 -> -0.0).
double banker_round_2(double val) {
    if (!std::isfinite(val)) return val;

    // Shortest "%.*e" rendering ("[-]d[.ddd]e[+-]XX") that round-trips to val;
    // 17 significant digits (prec = 16) always round-trip.
    char buf[64];
    for (int prec = 0; prec <= 16; ++prec) {
        std::snprintf(buf, sizeof buf, "%.*e", prec, val);
        if (std::strtod(buf, nullptr) == val) break;
    }

    // Split into sign, significant digits and decimal exponent:
    // |val| = 0.d1 d2 ... dk * 10^(exp10 + 1).
    const char* p = buf;
    const bool negative = (*p == '-');
    if (negative) ++p;
    std::string digits;
    for (; *p != 'e'; ++p) {
        if (*p != '.') digits += *p;
    }
    const int exp10 = std::atoi(p + 1);

    // Number of leading digits at or above the 0.01 place.
    const int k = static_cast<int>(digits.size());
    const int keep = exp10 + 3;
    if (keep >= k) return val;  // already has at most two decimals

    // |val| in hundredths, truncated (keep <= 16 digits, fits in long long).
    long long units = 0;
    for (int i = 0; i < keep; ++i) units = units * 10 + (digits[i] - '0');

    // keep < 0 means the first dropped digit is an implicit leading 0: round down.
    if (keep >= 0) {
        const char first = digits[keep];
        const bool rest_nonzero =
            digits.find_first_not_of('0', keep + 1) != std::string::npos;
        if (first > '5' || (first == '5' && (rest_nonzero || units % 2 == 1))) {
            ++units;
        }
    }

    const double magnitude = static_cast<double>(units) / 100.0;
    return negative ? -magnitude : magnitude;
}

int main() {
    int m, n;
    double alpha;
    cin >> m >> n >> alpha;

    vector<vector<double>> x(m, vector<double>(3));
    vector<double> y(m);
    vector<double> min_v(3, 1e18), max_v(3, -1e18);

    for (int i = 0; i < m; ++i) {
        for (int j = 0; j < 3; ++j) {
            cin >> x[i][j];
            min_v[j] = min(min_v[j], x[i][j]);
            max_v[j] = max(max_v[j], x[i][j]);
        }
        cin >> y[i];
    }

    // 预计算归一化特征及其各项乘积之和
    double s[3] = {0}, s2[3] = {0}, s_cross[3] = {0}, sy = 0, syx[3] = {0};
    // s_cross[0]: x1*x2, s_cross[1]: x1*x3, s_cross[2]: x2*x3
    for (int i = 0; i < m; ++i) {
        double xn[3];
        for (int j = 0; j < 3; ++j) {
            double range = max_v[j] - min_v[j];
            xn[j] = (range == 0) ? 0 : (x[i][j] - min_v[j]) / range;
            s[j] += xn[j];
            s2[j] += xn[j] * xn[j];
            syx[j] += y[i] * xn[j];
        }
        sy += y[i];
        s_cross[0] += xn[0] * xn[1];
        s_cross[1] += xn[0] * xn[2];
        s_cross[2] += xn[1] * xn[2];
    }

    double w[4] = {0};
    for (int it = 0; it < n; ++it) {
        double g[4];
        g[0] = (w[0] * m + w[1] * s[0] + w[2] * s[1] + w[3] * s[2] - sy) / m;
        g[1] = (w[0] * s[0] + w[1] * s2[0] + w[2] * s_cross[0] + w[3] * s_cross[1] - syx[0]) / m;
        g[2] = (w[0] * s[1] + w[1] * s_cross[0] + w[2] * s2[1] + w[3] * s_cross[2] - syx[1]) / m;
        g[3] = (w[0] * s[2] + w[1] * s_cross[1] + w[2] * s_cross[2] + w[3] * s2[2] - syx[2]) / m;
        
        for (int k = 0; k < 4; ++k) w[k] -= alpha * g[k];
    }

    double wf[4], sum_t = 0;
    for (int k = 1; k <= 3; ++k) {
        double r = max_v[k - 1] - min_v[k - 1];
        if (r == 0) wf[k] = 0;
        else {
            wf[k] = w[k] / r;
            sum_t += (w[k] * min_v[k - 1]) / r;
        }
    }
    wf[0] = w[0] - sum_t;

    for (int i = 0; i < 4; ++i) {
        cout << fixed << setprecision(2) << banker_round_2(wf[i]) << (i == 3 ? "" : " ");
    }
    cout << endl;
    return 0;
}
import java.util.*;
import java.math.BigDecimal;
import java.math.RoundingMode;

/**
 * Linear regression on Min-Max normalized features, trained with batch gradient descent.
 *
 * <p>Input is a stream of whitespace-separated tokens (line breaks are irrelevant): {@code m n
 * alpha}, then {@code m} rows of {@code x1 x2 x3 y}. Output is one line {@code w0 w1 w2 w3}: the
 * trained weights mapped back to the original feature scale, each rounded half-to-even to two
 * decimals.
 */
public class Main {
    /** Number of input features per sample. */
    private static final int FEATURES = 3;

    public static void main(String[] args) {
        try {
            run(new java.io.BufferedReader(new java.io.InputStreamReader(System.in)));
        } catch (java.io.IOException e) {
            throw new java.io.UncheckedIOException(e);
        }
    }

    /** Reads the problem from {@code reader}, trains the model and prints the restored weights. */
    private static void run(java.io.BufferedReader reader) throws java.io.IOException {
        TokenReader in = new TokenReader(reader);
        int m = Integer.parseInt(in.next());
        int n = Integer.parseInt(in.next());
        double alpha = Double.parseDouble(in.next());

        double[][] x = new double[m][FEATURES];
        double[] y = new double[m];
        double[] minV = new double[FEATURES];
        double[] maxV = new double[FEATURES];
        Arrays.fill(minV, Double.POSITIVE_INFINITY);
        Arrays.fill(maxV, Double.NEGATIVE_INFINITY);
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < FEATURES; j++) {
                x[i][j] = Double.parseDouble(in.next());
                minV[j] = Math.min(minV[j], x[i][j]);
                maxV[j] = Math.max(maxV[j], x[i][j]);
            }
            y[i] = Double.parseDouble(in.next());
        }

        // Per-column range; a constant column (range 0) is normalized to 0.
        double[] range = new double[FEATURES];
        for (int j = 0; j < FEATURES; j++) {
            range[j] = maxV[j] - minV[j];
        }

        double[] w = train(x, y, minV, range, n, alpha);
        double[] wf = restore(w, minV, range);

        StringBuilder out = new StringBuilder();
        for (int i = 0; i < wf.length; i++) {
            if (i > 0) {
                out.append(' ');
            }
            out.append(formatHalfEven2(wf[i]));
        }
        System.out.println(out);
    }

    /**
     * Runs {@code n} batch-gradient-descent steps from zero weights on the normalized data.
     *
     * <p>The gradient {@code (1/m) * sum (prediction - y) * x'_k} is expanded into sums that are
     * computed once up front, so each iteration costs O(1) instead of O(m).
     *
     * @return the weights {@code [w0, w1, w2, w3]} in normalized space
     */
    private static double[] train(
            double[][] x, double[] y, double[] minV, double[] range, int n, double alpha) {
        int m = y.length;
        // s[j] = sum x'_j, s2[j] = sum x'_j^2, syx[j] = sum y*x'_j, sy = sum y;
        // scross[0] = sum x'_1 x'_2, scross[1] = sum x'_1 x'_3, scross[2] = sum x'_2 x'_3.
        double[] s = new double[FEATURES];
        double[] s2 = new double[FEATURES];
        double[] scross = new double[FEATURES];
        double[] syx = new double[FEATURES];
        double sy = 0;

        for (int i = 0; i < m; i++) {
            double[] xn = new double[FEATURES];
            for (int j = 0; j < FEATURES; j++) {
                xn[j] = (range[j] == 0) ? 0 : (x[i][j] - minV[j]) / range[j];
                s[j] += xn[j];
                s2[j] += xn[j] * xn[j];
                syx[j] += y[i] * xn[j];
            }
            sy += y[i];
            scross[0] += xn[0] * xn[1];
            scross[1] += xn[0] * xn[2];
            scross[2] += xn[1] * xn[2];
        }

        double[] w = new double[FEATURES + 1];
        for (int it = 0; it < n; it++) {
            double g0 = (w[0] * m + w[1] * s[0] + w[2] * s[1] + w[3] * s[2] - sy) / m;
            double g1 = (w[0] * s[0] + w[1] * s2[0] + w[2] * scross[0] + w[3] * scross[1] - syx[0]) / m;
            double g2 = (w[0] * s[1] + w[1] * scross[0] + w[2] * s2[1] + w[3] * scross[2] - syx[1]) / m;
            double g3 = (w[0] * s[2] + w[1] * scross[1] + w[2] * scross[2] + w[3] * s2[2] - syx[2]) / m;
            w[0] -= alpha * g0;
            w[1] -= alpha * g1;
            w[2] -= alpha * g2;
            w[3] -= alpha * g3;
        }
        return w;
    }

    /**
     * Maps normalized-space weights back to the original feature scale.
     *
     * <p>{@code w_k * (x - min) / r = (w_k / r) * x - w_k * min / r}, so the constant parts are
     * folded into the bias. Constant columns ({@code r == 0}) get weight 0.
     */
    private static double[] restore(double[] w, double[] minV, double[] range) {
        double[] wf = new double[FEATURES + 1];
        double shift = 0;
        for (int k = 1; k <= FEATURES; k++) {
            double r = range[k - 1];
            if (r == 0) {
                wf[k] = 0;
            } else {
                wf[k] = w[k] / r;
                shift += (w[k] * minV[k - 1]) / r;
            }
        }
        wf[0] = w[0] - shift;
        return wf;
    }

    /**
     * Formats {@code v} with exactly two decimals using round-half-even.
     *
     * <p>{@link BigDecimal#valueOf(double)} starts from {@code Double.toString(v)}, so ties such as
     * 1.015 are decided on those decimal digits. After {@code setScale(2, ...)} the plain string
     * always has two fraction digits, so no manual padding is needed. Throws
     * {@code NumberFormatException} for NaN or infinite {@code v}.
     */
    private static String formatHalfEven2(double v) {
        return BigDecimal.valueOf(v).setScale(2, RoundingMode.HALF_EVEN).toPlainString();
    }

    /** Splits a character stream into whitespace-separated tokens across line boundaries. */
    private static final class TokenReader {
        private final java.io.BufferedReader reader;
        private StringTokenizer current = new StringTokenizer("");

        TokenReader(java.io.BufferedReader reader) {
            this.reader = reader;
        }

        /** Returns the next token, or throws {@link NoSuchElementException} at end of input. */
        String next() throws java.io.IOException {
            while (!current.hasMoreTokens()) {
                String line = reader.readLine();
                if (line == null) {
                    throw new NoSuchElementException("unexpected end of input");
                }
                current = new StringTokenizer(line);
            }
            return current.nextToken();
        }
    }
}
from decimal import Decimal, ROUND_HALF_EVEN

def solve() -> None:
    """Train the regression model on stdin data and print the restored weights.

    Input is consumed as whitespace-separated tokens: ``m n alpha`` followed by
    ``m`` rows of ``x1 x2 x3 y``.  Reading tokens instead of one ``input()``
    call per value means the header may be on one line or several, and a
    sample row may be wrapped.

    The model ``y ~ w0 + w1*x1 + w2*x2 + w3*x3`` is trained on Min-Max
    normalized features with ``n`` steps of batch gradient descent (learning
    rate ``alpha``, all weights starting at 0).  The weights are mapped back to
    the original feature scale and printed as ``w0 w1 w2 w3``, each rounded
    half-to-even to two decimals.

    Raises:
        ValueError: if fewer than ``4 * m`` sample values follow the header.
    """
    import sys

    tokens = sys.stdin.read().split()
    m, n, alpha = int(tokens[0]), int(tokens[1]), float(tokens[2])
    values = [float(t) for t in tokens[3:3 + 4 * m]]
    if len(values) < 4 * m:
        raise ValueError(f"expected {4 * m} sample values, got {len(values)}")
    # Each row is [x1, x2, x3, y].
    data = [values[4 * i:4 * i + 4] for i in range(m)]

    min_v = [min(row[j] for row in data) for j in range(3)]
    max_v = [max(row[j] for row in data) for j in range(3)]

    # Sums over the normalized features x':
    #   s[j] = sum x'_j, s2[j] = sum x'_j^2, syx[j] = sum y*x'_j, sy = sum y,
    #   scross = [sum x'_1 x'_2, sum x'_1 x'_3, sum x'_2 x'_3].
    s, s2, scross = [0.0] * 3, [0.0] * 3, [0.0] * 3
    sy, syx = 0.0, [0.0] * 3

    for row in data:
        xn = []
        for j in range(3):
            rng = max_v[j] - min_v[j]
            # A constant column (range 0) is normalized to 0.
            val = (row[j] - min_v[j]) / rng if rng != 0 else 0.0
            xn.append(val)
            s[j] += val
            s2[j] += val * val
            syx[j] += row[3] * val
        sy += row[3]
        scross[0] += xn[0] * xn[1]
        scross[1] += xn[0] * xn[2]
        scross[2] += xn[1] * xn[2]

    # Batch gradient descent; each gradient (1/m) * sum (prediction - y) * x'_k
    # is expanded in terms of the precomputed sums, so one step is O(1).
    w = [0.0] * 4
    for _ in range(n):
        g0 = (w[0] * m + w[1] * s[0] + w[2] * s[1] + w[3] * s[2] - sy) / m
        g1 = (w[0] * s[0] + w[1] * s2[0] + w[2] * scross[0] + w[3] * scross[1] - syx[0]) / m
        g2 = (w[0] * s[1] + w[1] * scross[0] + w[2] * s2[1] + w[3] * scross[2] - syx[1]) / m
        g3 = (w[0] * s[2] + w[1] * scross[1] + w[2] * scross[2] + w[3] * s2[2] - syx[2]) / m
        w[0] -= alpha * g0
        w[1] -= alpha * g1
        w[2] -= alpha * g2
        w[3] -= alpha * g3

    # Undo normalization: w_k * (x - min) / r = (w_k / r) * x - w_k * min / r,
    # folding the constant parts into the bias. Constant columns get 0.
    wf = [0.0] * 4
    sum_t = 0.0
    for k in range(1, 4):
        rng = max_v[k - 1] - min_v[k - 1]
        if rng == 0:
            wf[k] = 0.0
        else:
            wf[k] = w[k] / rng
            sum_t += (w[k] * min_v[k - 1]) / rng
    wf[0] = w[0] - sum_t

    def banker_round(val: float) -> str:
        """Format val with two decimals, rounding half-to-even on str(val)'s digits."""
        d = Decimal(str(val)).quantize(Decimal("0.00"), rounding=ROUND_HALF_EVEN)
        return f"{d:.2f}"

    print(" ".join(banker_round(v) for v in wf))

# Run only when executed as a script, so importing this module does not consume stdin.
if __name__ == "__main__":
    solve()

算法及复杂度

  • 算法:线性回归 + BGD 优化(预计算特征项和)。
  • 时间复杂度:$O(m + n)$。预处理扫描一次数据为 $O(m)$,后续迭代由于使用预计算的和,单次迭代为 $O(1)$。
  • 空间复杂度:$O(m)$(需要保存全部输入数据,以便求出每列最值后再做归一化)。
全部评论

相关推荐

02-26 13:56
已编辑
重庆财经学院 Java
King987:你有实习经历,但是写的也太简单了,这肯定是不行的,你主要要包装实习经历这一块,看我的作品,你自己包装一下吧,或者发我,我给你出一期作品
点赞 评论 收藏
分享
评论
点赞
收藏
分享

创作者周榜

更多
牛客网
牛客网在线编程
牛客网题解
牛客企业服务