warp loss

WARP loss(Weighted Approximate-Rank Pairwise loss)是推荐系统中用于 pairwise 排序学习的损失函数。


下面是 PaddlePaddle 的 TagSpace 实现,其中包含了 warp loss 损失函数的实现:代码先在所有负样本中取余弦相似度最大的一个,再与正样本的余弦相似度一起计算带 margin 的 hinge loss。

def network(
        vocab_text_size,
        vocab_tag_size,
        emb_dim=10,
        hid_dim=1000,
        win_size=5,
        margin=0.1,
        neg_size=5):
    """Build the TagSpace graph with a max-negative hinge ranking loss.

    The text sequence is embedded, run through a sequence convolution with
    max pooling and an fc layer, and then compared by cosine similarity with
    the positive tag embedding and with the negative tag embeddings.  The
    cost is ``mean(max(0, margin - cos_pos + max(cos_neg)))``.

    Args:
        vocab_text_size: Row count of the text embedding table.
        vocab_tag_size: Row count of the tag embedding table; positive and
            negative tags share the same ``"tag_emb"`` parameter.
        emb_dim: Embedding width, also the output size of the fc layer.
        hid_dim: Number of filters of the sequence convolution.
        win_size: Filter size of the sequence convolution.
        margin: Margin constant of the hinge loss.
        neg_size: ``new_dim`` used when regrouping the negative cosines
            (presumably the number of negative tags per example -- confirm
            against the data reader).

    Returns:
        A tuple ``(avg_cost, correct, cos_pos)``: the mean hinge loss, the
        sum of the ``max(cos_neg) < cos_pos`` indicator, and the cosine
        similarity between the positive tag and the text representation.
    """

    def _const_like(ref, value):
        # Constant float32 column whose batch dimension follows ``ref``.
        return tensor.fill_constant_batch_size_like(
            input=ref, shape=[-1, 1], value=value, dtype='float32')

    # Input placeholders.
    text_in = io.data(name="text", shape=[1], lod_level=1, dtype='int64')
    pos_in = io.data(name="pos_tag", shape=[1], lod_level=1, dtype='int64')
    neg_in = io.data(name="neg_tag", shape=[1], lod_level=1, dtype='int64')

    # Embedding lookups; both tag lookups use the "tag_emb" parameter.
    tag_table_shape = [vocab_tag_size, emb_dim]
    text_vec = nn.embedding(
        input=text_in,
        size=[vocab_text_size, emb_dim],
        param_attr="text_emb")
    pos_vec = nn.embedding(
        input=pos_in, size=tag_table_shape, param_attr="tag_emb")
    neg_vec = nn.embedding(
        input=neg_in, size=tag_table_shape, param_attr="tag_emb")

    # Text encoder: sequence conv + max pool, then project to emb_dim.
    pooled = fluid.nets.sequence_conv_pool(
        input=text_vec,
        num_filters=hid_dim,
        filter_size=win_size,
        act="tanh",
        pool_type="max",
        param_attr="cnn")
    text_repr = fluid.layers.fc(
        input=pooled, size=emb_dim, param_attr="text_hid")

    # Cosine similarity with the positive tag.
    cos_pos = nn.cos_sim(pos_vec, text_repr)

    # Cosine similarity with each negative tag: expand the text
    # representation to line up with the negative embeddings, then regroup
    # the scores into rows of width neg_size.
    text_repr_expanded = fluid.layers.sequence_expand_as(
        x=text_repr, y=neg_vec)
    neg_scores = nn.cos_sim(neg_vec, text_repr_expanded)
    neg_scores_grouped = fluid.layers.sequence_reshape(
        input=neg_scores, new_dim=neg_size)

    # Keep only the maximum negative cosine.
    cos_neg = nn.reduce_max(neg_scores_grouped, dim=1, keep_dim=True)

    # Hinge loss: max(0, margin - cos_pos + cos_neg), averaged.
    margin_minus_pos = nn.elementwise_sub(_const_like(cos_pos, margin), cos_pos)
    raw_hinge = nn.elementwise_add(margin_minus_pos, cos_neg)
    hinge = nn.elementwise_max(_const_like(raw_hinge, 0.0), raw_hinge)
    avg_cost = nn.mean(hinge)

    # Sum of the indicator that the positive tag beats the best negative.
    pos_wins = tensor.cast(cf.less_than(cos_neg, cos_pos), dtype='float32')
    correct = nn.reduce_sum(pos_wins)

    return avg_cost, correct, cos_pos
全部评论

相关推荐

牛客60022193...:大厂都招前端,他们觉得AI能替代前端,可能他们公司吊打btaj吧
点赞 评论 收藏
分享
2025-12-05 10:31
门头沟学院 Java
点赞 评论 收藏
分享
评论
点赞
收藏
分享

创作者周榜

更多
牛客网
牛客网在线编程
牛客网题解
牛客企业服务