题解 | 结构化剪枝后的分类预测
结构化剪枝后的分类预测
https://www.nowcoder.com/practice/e00ba135dfa24b6ea8c2aa3bb3cdd67f
import sys
import math
def _l1_norm(row, width):
    """Return the sum of absolute values of the first *width* entries of *row*."""
    # Accumulate left to right by hand: built-in sum() applies compensated
    # float summation on newer Python versions, which could change how
    # near-equal norms compare and therefore which row gets pruned.
    total = 0
    for col in range(width):
        total += abs(row[col])
    return total


def _prune_count(ratio, d):
    """Return how many of the *d* feature rows to prune for *ratio*.

    The count is floor(ratio * d), raised to 1 whenever ratio is positive
    (a positive ratio always prunes something), then clamped to [0, d] so an
    out-of-range ratio can neither over-prune nor produce a negative count.
    """
    count = math.floor(ratio * d)
    if ratio > 0 and count == 0:
        count = 1
    return max(0, min(count, d))


def _kept_feature_indices(W, ratio, c):
    """Return, in ascending order, the row indices of *W* that survive pruning.

    Rows with the smallest L1 norm (over their first *c* entries) are pruned
    first; among equal norms the lowest index is pruned first.
    """
    d = len(W)
    norms = [_l1_norm(row, c) for row in W]
    # sorted() is stable, so equal norms stay in index order and the
    # lowest-index row among ties is the one that lands in the pruned prefix.
    by_norm = sorted(range(d), key=norms.__getitem__)
    pruned = set(by_norm[:_prune_count(ratio, d)])
    return [idx for idx in range(d) if idx not in pruned]


def predict_after_pruning(X, W, ratio, c):
    """Classify each sample after pruning low-magnitude feature rows of W.

    Args:
        X: list of samples, each a list of d feature values.
        W: list of d weight rows, each a list of at least c values.
        ratio: fraction of feature rows to prune (see _prune_count).
        c: number of classes (columns of W that are used).

    Returns:
        A list with one predicted class index per sample: the position of the
        largest score, the first such position on ties.  When every feature is
        pruned all scores are 0 and the prediction is class 0.

    Raises:
        ValueError: if c is 0 (there is no class to pick).
    """
    kept = _kept_feature_indices(W, ratio, c)
    predictions = []
    for sample in X:
        scores = []
        for cls in range(c):
            # Explicit accumulation over surviving features in index order,
            # for the same float-rounding reason as in _l1_norm.
            score = 0
            for feat in kept:
                score += sample[feat] * W[feat][cls]
            scores.append(score)
        predictions.append(scores.index(max(scores)))
    return predictions


def main():
    """Read the problem input from stdin and print the predicted classes.

    Input layout: a line "n d c", then n lines of sample features, then d
    lines of weights, then one line holding the pruning ratio.
    """
    readline = sys.stdin.readline
    n, d, c = map(int, readline().split())
    X = [list(map(float, readline().split())) for _ in range(n)]
    W = [list(map(float, readline().split())) for _ in range(d)]
    ratio = float(readline())
    print(*predict_after_pruning(X, W, ratio, c))


if __name__ == "__main__":
    main()
查看13道真题和解析