简单记录做cs231n作业学到的pytorch的小技巧

简单记录做cs231n作业学到的pytorch的小技巧

transform的输入必须是PIL image,并且一次只支持一张图像的增强
def preprocess(img, size=224):
    transform = T.Compose([
        T.Resize(size),
        T.ToTensor(),
        T.Normalize(mean=SQUEEZENET_MEAN.tolist(),
                    std=SQUEEZENET_STD.tolist()),
        T.Lambda(lambda x: x[None]), # ???
    ])
    return transform(img)
torch gather用法
  • 要求s, y都是二维tensor或者三维tensor,shape可以不同
# Example of using gather to select one entry from each row in PyTorch
def gather_example():
    N, C = 4, 5
    s = torch.randn(N, C)
    y = torch.LongTensor([1, 2, 1, 3])
    print(s)
    print(y)
    print(s.gather(1, y.view(-1, 1)).squeeze())
#     print(s.gather(1, y.view(-1, 1)).squeeze()) 
gather_example()
torch max 的常见用法
x = torch.randn(1,2)
print(x.max(1)) # max(values, index)二元组
print()
print(x.max(1)[0], x.max(1)[1])

# 输出
"""
torch.return_types.max(
values=tensor([-0.0182]),
indices=tensor([0]))

tensor([-0.0182]) tensor([0])"""

# 想要得到index 直接用.argmax(dim=1)
a = torch.arange(0, 6).view(2,3)
print(a)
a.argmax(1)
PIL.Image.fromarray 用法

[图片上传失败...(image-b39634-1575038993735)]

PIL.Image.open 用法

[图片上传失败...(image-8edb96-1575038993735)]

torch 需要求导的变量的原地操作
x.data.copy_(jitter(img.data, ox, oy))

需要使用x.data.copy_()方法实现

全部评论

相关推荐

点赞 收藏 评论
分享
牛客网
牛客企业服务