pytorch的nn.CrossEntropyLoss()函数使用方法

nn.CrossEntropyLoss()函数计算交叉熵损失

用法:

# output是网络的输出,size=[batch_size, class]
#如网络的batch size为128,数据分为10类,则size=[128, 10]

# target是数据的真实标签,是标量,size=[batch_size]
#如网络的batch size为128,则size=[128]

crossentropyloss=nn.CrossEntropyLoss()
crossentropyloss_output=crossentropyloss(output,target)

注意,使用nn.CrossEntropyLoss()时,不需要现将输出经过softmax层,否则计算的损失会有误,即直接将网络输出用来计算损失即可

nn.CrossEntropyLoss()的计算公式为:
图片说明

其中x是网络的输出向量,class是真实标签

举个例子,一个三分类网络对某个输入样本的输出为[-0.7715, -0.6205,-0.2562],该样本的真实标签为0,则用nn.CrossEntropyLoss()计算的损失为:
图片说明

全部评论

相关推荐

10-21 00:37
已编辑
门头沟学院 C++
小浪_Coding:你问别人,本来就是有求于人,别人肯定没有义务免费回答你丫, 有点流量每天私信可能都十几,几十条的,大家都有工作和自己的事情, 付费也是正常的, 就像你请别人搭把手, 总得给人家买瓶水喝吧
点赞 评论 收藏
分享
评论
点赞
收藏
分享

创作者周榜

更多
牛客网
牛客网在线编程
牛客网题解
牛客企业服务