题解 | Log Softmax函数的实现(含推导)

Log Softmax函数的实现

https://www.nowcoder.com/practice/a8a0934f25f04c7e97d64d3e1b77219a

第一个答案的推导不知道在写什么东西,这里给出比较简单的推导思路:

首先,标准的Softmax函数是以下形式:

由于实际实现过程中会出现溢出问题,来源就是函数内存在指数运算,一旦a的取值太大(如exp(1000)等等...),就会出现无法表示的问题,所以通过以下方式改进:

为了更好地防止溢出,,一般会使用输入信号中的最大值,即C' = -max(a)

那么log_max函数的改进形式就可以从上述式子中推出:

对应代码实现如下:

import numpy as np

def log_softmax(scores: list) -> np.ndarray:
    c = np.max(scores)
    exp_val = np.exp(scores - c)
    sum_val = np.sum(exp_val)
    result = scores - c - np.log(sum_val)
    return result

if __name__ == "__main__":
    scores = eval(input())
    print(log_softmax(scores))

全部评论

相关推荐

评论
1
收藏
分享

创作者周榜

更多
牛客网
牛客网在线编程
牛客网题解
牛客企业服务