深度学习之PyTorch实战(1)——手写数字识别(LSTM)

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dsets
from torch.autograd import Variable

input_size = 28
sequence_length = 28
hidden_size = 128
num_layers = 2
num_classes = 10
batch_size = 100
num_epochs = 1
learning_rate = 0.01

train_datasets = dsets.MNIST(root='./data',
                             download=True,
                             train=True,
                             transform=transforms.ToTensor())
test_datasets = dsets.MNIST(root='./data',
                            download=False,
                            train=False,
                            transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(dataset=train_datasets,
                                           batch_size=batch_size,shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_datasets,
                                          batch_size=batch_size,
                                          shuffle=False)
class RNN(nn.Module):
    def __init__(self,input_size,hidden_size,num_layers,num_classes):
        super(RNN,self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        #https://pytorch.org/docs/master/nn.html?highlight=lstm#torch.nn.LSTM
        #参考官方文档
        self.lstm = nn.LSTM(input_size,hidden_size,num_layers,batch_first=True)
        self.fc = nn.Linear(hidden_size,num_classes)
    def forward(self, x):
        h0 = Variable(torch.zeros(self.num_layers, x.size(0), self.hidden_size))
        c0 = Variable(torch.zeros(self.num_layers, x.size(0), self.hidden_size))
        out,_ = self.lstm(x,(h0,c0))
        #选择最后一个时间点的output
        out = self.fc(out[:,-1,:])

        return out

rnn = RNN(input_size,hidden_size,num_layers,num_classes)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(rnn.parameters(),lr=learning_rate)

for epoch in range(num_epochs):
    for i,(images,labels) in enumerate(train_loader):
        images = Variable(images.view(-1,sequence_length,input_size))
        labels = Variable(labels)
        optimizer.zero_grad()
        outputs = rnn(images)
        loss = criterion(outputs,labels)
        loss.backward()
        optimizer.step()
        if (i+1) % 2 == 0:
            print('Epoch [%d/%d], Step [%d/%d], Loss: %.4f'
                  % (epoch + 1, num_epochs, i + 1, len(train_loader), loss.item()))

# Test the Model
correct = 0
total = 0
for images, labels in test_loader:
    images = Variable(images.view(-1, sequence_length, input_size))
    outputs = rnn(images)
    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum()

print('Test Accuracy of the model on the 10000 test images: %d %%' % (100 * correct / total))

# Save the Model
torch.save(rnn.state_dict(), 'rnn.pkl')
全部评论

相关推荐

nus22016021404:兄弟,你这个简历撕了丢了吧,就是一坨,去找几个项目,理解项目流程,看几遍就是你的了,看看八股就去干了,多看看牛客里别人发出来的简历,对着写,你这写的啥啊,纯一坨
点赞 评论 收藏
分享
05-11 11:48
河南大学 Java
程序员牛肉:我是26届的双非。目前有两段实习经历,大三上去的美团,现在来字节了,做的是国际电商的营销业务。希望我的经历对你有用。 1.好好做你的CSDN,最好是直接转微信公众号。因为这本质上是一个很好的展示自己技术热情的证据。我当时也是烂大街项目(网盘+鱼皮的一个项目)+零实习去面试美团,但是当时我的CSDN阅读量超百万,微信公众号阅读量40万。面试的时候面试官就告诉我说觉得我对技术挺有激情的。可以看看我主页的美团面试面经。 因此花点时间好好做这个知识分享,最好是单拉出来搞一个板块。各大公司都极其看中知识落地的能力。 可以看看我的简历对于博客的描述。这个帖子里面有:https://www.nowcoder.com/discuss/745348200596324352?sourceSSR=users 2.实习经历有一些东西删除了,目前看来你的产出其实很少。有些内容其实很扯淡,最好不要保留。有一些点你可能觉得很牛逼,但是面试官眼里是减分的。 你还能负责数据库表的设计?这个公司得垃圾成啥样子,才能让一个实习生介入数据库表的设计,不要写这种东西。 一个公司的财务审批系统应该是很稳定的吧?为什么你去了才有RBAC权限设计?那这个公司之前是怎么处理权限分离的?这些东西看着都有点扯淡了。 还有就是使用Redis实现轻量级的消息队列?那为什么这一块不使用专业的MQ呢?为什么要使用redis,这些一定要清楚, 就目前看来,其实你的这个实习技术还不错。不要太焦虑。就是有一些内容有点虚了。可以考虑从PR中再投一点产出
点赞 评论 收藏
分享
评论
点赞
收藏
分享

创作者周榜

更多
牛客网
牛客网在线编程
牛客网题解
牛客企业服务