题解 | 对称INT8量化方案

对称INT8量化方案

https://www.nowcoder.com/practice/0b25a6d7787a4d9cab40c7cc26d5eee2

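超级大模拟题，务必小心书写。思路：对 A 逐行、对 B 逐列做对称量化（scale = max|x| / 127，q = clip(round(x / scale), -127, 127)），用整数完成矩阵乘法，再将每个结果乘以对应行、列的 scale 完成反量化，保留两位小数输出。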

import math

def read():
    """Read one matrix from standard input.

    The first line holds two integers: the row count and the column count.
    Each of the following ``row count`` lines holds one row of
    whitespace-separated numbers, parsed as floats.

    Returns:
        (rows, row_count, col_count): ``rows`` is a list of float lists.
        ``col_count`` is returned as declared; the row lengths are not
        checked against it.
    """
    rows_n, cols_n = (int(tok) for tok in input().split())
    rows = [[float(tok) for tok in input().split()] for _ in range(rows_n)]
    return rows, rows_n, cols_n

def clip(v, low, up):
    """Clamp ``v`` into the closed interval ``[low, up]``.

    The lower bound is tested first, so ``low`` wins when ``v`` is below it.
    Otherwise ``up`` is returned when ``v`` exceeds it, and ``v`` itself is
    returned when it is already in range.
    """
    return low if v < low else (up if v > up else v)

def transpose(m):
    """Return the transpose of ``m`` as a list of tuples (one per column).

    ``zip`` stops at the shortest row, so ragged input is silently truncated
    to the shortest row's length.
    """
    return [*zip(*m)]

# Input: matrix A (M x K), then matrix B (K2 x N). The product A @ B
# requires A's column count to equal B's row count.
a1, M, K = read()
a2, K2, N = read()

if K != K2:
    raise ValueError(
        "inner dimensions differ: A has {} columns but B has {} rows".format(K, K2))

def quant(arr):
    """Symmetric per-row INT8 quantization.

    For each row, the scale is ``max(|x|) / 127``. Every element is mapped
    to ``round(x / scale)``, clipped into ``[-127, 127]``.

    Args:
        arr: iterable of rows of numbers.

    Returns:
        (res, sas): ``res`` is the list of quantized rows (lists of ints) and
        ``sas`` is the list of per-row scales, both in the order of ``arr``.

    Raises:
        ValueError: if a row is empty (``max`` of an empty sequence).
    """
    res = []
    sas = []
    for row in arr:
        sa = max(map(abs, row)) / 127
        if sa == 0:
            # All-zero row: every element quantizes to 0. Guarding here avoids
            # a division by zero without adding an epsilon that would perturb
            # the divisor of every nonzero row.
            nrow = [0] * len(row)
        else:
            # NOTE(review): Python's round() rounds halves to even
            # (round(2.5) == 2); confirm the problem does not require
            # round-half-away-from-zero.
            nrow = [clip(round(e / sa), -127, 127) for e in row]
        res.append(nrow)
        sas.append(sa)
    return res, sas

# A is quantized per row. B is quantized per column: quantize the rows of
# B^T, then transpose the int matrix back so qa2 has B's K x N layout.
qa1, sa1 = quant(a1)
qa2, sa2 = quant(transpose(a2))
qa2 = transpose(qa2)

# Expect one scale per row of A and one per column of B. transpose() uses
# zip, which truncates to the shortest row, so a short row in B shows up here.
if len(sa1) != M or len(sa2) != N:
    raise ValueError(
        "expected {} row scales and {} column scales, got {} and {}".format(
            M, N, len(sa1), len(sa2)))

# Integer matrix product of the quantized operands:
#     resm[i][j] = sum_k qa1[i][k] * qa2[k][j]
# Python ints are unbounded, so the accumulation cannot overflow.
if any(len(r1) != len(qa2) for r1 in qa1):
    # Reject malformed rows of A. Otherwise a long row would raise a bare
    # IndexError and a short row would silently yield a partial sum.
    raise ValueError("every row of A must have as many entries as B has rows")
cols = list(zip(*qa2))  # cols[j] is column j of qa2, built once
resm = [[sum(x * y for x, y in zip(r1, col)) for col in cols] for r1 in qa1]

# Check every row's width, not just resm[0] (indexing it would crash when M == 0).
if len(resm) != M or any(len(row) != N for row in resm):
    raise ValueError("product has the wrong shape: expected {} x {}".format(M, N))

# Dequantize in place: multiply each integer accumulator by its row scale
# (from A) and then by its column scale (from B).
for row, row_scale in zip(resm, sa1):
    for j, acc in enumerate(row):
        row[j] = acc * row_scale * sa2[j]

# One output line per row, values with two decimals separated by spaces.
# NOTE(review): small negative values print as '-0.00'; confirm the judge
# accepts that.
for row in resm:
    print(*(format(x, '.2f') for x in row))




全部评论

相关推荐

评论
点赞
收藏
分享

创作者周榜

更多
牛客网
牛客网在线编程
牛客网题解
牛客企业服务