题解 | 对称INT8量化方案
对称INT8量化方案
https://www.nowcoder.com/practice/0b25a6d7787a4d9cab40c7cc26d5eee2
超级大模拟，务必仔细实现。
import math
def read():
    """Read one matrix from standard input.

    The first line holds the dimensions ``M N``; each of the next M lines
    holds one row of whitespace-separated numbers, parsed as floats.
    Row lengths are not checked against N here.

    Returns:
        (rows, M, N): the matrix as a list of float lists, followed by the
        two dimensions exactly as read from the header line.
    """
    M, N = map(int, input().split())
    rows = []
    for _ in range(M):
        rows.append([float(tok) for tok in input().split()])
    return rows, M, N
def clip(v, low, up):
    """Clamp *v* into the closed interval [low, up].

    The lower bound is tested first, so if ``low > up`` a value below
    ``low`` yields ``low``, and any other value above ``up`` yields ``up``.
    """
    if v < low:
        return low
    return up if v > up else v
def transpose(m):
    """Return the transpose of *m* as a list of tuples, one per column.

    Rows are zipped together, so ragged input is truncated to the length
    of its shortest row, and an empty matrix transposes to ``[]``.
    """
    columns = zip(*m)
    return list(columns)
# Read the left operand A (M rows, K columns per its header) and then the
# right operand B (K2 rows, N columns); the product A @ B needs K == K2.
a1, M, K = read()
a2, K2, N = read()
if K2 != K:
    raise ValueError("1?")
def quant(arr):
    """Symmetrically quantize each row of *arr* to INT8 with its own scale.

    For a row r the scale is ``s = max(|x| for x in r) / 127`` and each
    element x becomes ``round(x / s)`` clamped to [-127, 127], so the row's
    largest magnitude maps to +/-127 and 0.0 maps exactly to 0.

    Args:
        arr: an iterable of rows, each an iterable of numbers.

    Returns:
        (res, sas): ``res`` is a list of quantized rows (lists of ints in
        [-127, 127]); ``sas`` holds the matching per-row scales, so that
        ``res[i][j] * sas[i]`` approximates ``arr[i][j]``.
    """
    res = []
    sas = []
    for row in arr:
        row = list(row)
        # default=0.0: an empty row gets scale 0 instead of max() raising.
        sa = max((abs(e) for e in row), default=0.0) / 127
        if sa == 0:
            # All-zero (or empty) row: every element quantizes to 0. Handled
            # explicitly so no epsilon-padded divisor is needed.
            nrow = [0] * len(row)
        else:
            # Divide by the exact scale (no epsilon padding), so rows with
            # very small magnitudes still map their largest element to
            # +/-127. The clamp (same rule as clip()) only guards against
            # float rounding pushing |x / s| a hair past 127.
            # NOTE(review): Python's round() rounds halves to even
            # (round(-63.5) == -64, round(0.5) == 0); confirm this matches
            # the problem's rounding rule.
            nrow = [min(127, max(-127, round(e / sa))) for e in row]
        res.append(nrow)
        sas.append(sa)
    return res, sas
# Quantize A row-wise and B column-wise (each column of B is a row of B^T),
# so every output cell C[i][j] carries exactly one A-scale (sa1[i]) and one
# B-scale (sa2[j]).
qa1, sa1 = quant(a1)
# Keep B's quantized columns as rows of B^T: the dot products below consume
# them directly. Transposing back is unnecessary, and it breaks when N == 0,
# because the transpose of an empty list is [] rather than K empty rows.
qcols2, sa2 = quant(transpose(a2))
if len(sa1) != M or len(sa2) != N:
    raise ValueError("2?")

# Integer matrix product of the quantized operands. quant() yields ints
# clamped to [-127, 127], so each accumulated dot product is an exact int.
resm = [
    [sum(x * y for x, y in zip(qrow, qcol)) for qcol in qcols2]
    for qrow in qa1
]
# Check every row, not just resm[0]; this also avoids indexing an empty
# result when M == 0.
if len(resm) != M or any(len(row) != N for row in resm):
    raise ValueError("3?")

# Dequantize by applying A's row scale and B's column scale.
resm = [
    [v * sa1[i] * sa2[j] for j, v in enumerate(row)]
    for i, row in enumerate(resm)
]
for row in resm:
    print(' '.join(format(e, '.2f') for e in row))
查看13道真题和解析