在前面的章节中,我们学习了什么是深度学习,如何构建不同的深度学习的神经网络,如何优化网络,优化模型,如何对网络中的参数实现可视化,深度学习的基础知识我们都已经学习,现在需要的就是加强我们对知识的掌握。   这一节,我们开启一个任务图像分类任务的实战练习,同样是我们熟悉的CIFAR-10数据集,但不同的是我们将完整走一遍使用深度学习进行项目实战的流程。   1. 数据获取与分析   CIFAR-10数据集包含10个不同类别的图像,每个类别有6000张32x32像素的彩色图像。我们的目标是训练一个模型,能够准确地对这些图像进行分类。我们首先需要下载数据集,它是torchvision的内置数据集,可以使用代码直接下载:   import torchimport torchvisionimport torchvision.transforms as transformsimport matplotlib.pyplot as pltimport numpy as np# 加载CIFAR-10数据集transform = transforms.Compose(    [transforms.ToTensor(),     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])trainset = torchvision.datasets.CIFAR10(root='./data', train=True,                                        download=True, transform=transform)trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,                                          shuffle=True, num_workers=2)# 类别名称classes = ('plane', 'car', 'bird', 'cat',           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')   然后,我们其实查看一下数据集,看看数据集到底长什么样子,有没有什么需要处理的,或者了解数据集什么样之后便与我们后续调优:   # 显示一些训练图像def imshow(img):    img = img / 2 + 0.5     # 反归一化    npimg = img.numpy()    plt.imshow(np.transpose(npimg, (1, 2, 0)))    plt.show()# 随机获取一些训练图像dataiter = iter(trainloader)images, labels = next(dataiter)# 显示图像及其标签imshow(torchvision.utils.make_grid(images))print(' '.join('%5s' % classes[labels[j]] for j in range(4)))      从上图可以看出,整体上图像质量不高,属于低分辨率的小图像,对硬件的要求也不会太高。   2. 数据预处理   一般进行图像任务的深度学习之前,我们会考虑将图像进行标准化,同时为了提升模型的鲁棒性与泛化性,防止过拟合,我们也会考虑给数据进行一些翻转、裁剪的增强:   # 数据预处理transform_train = transforms.Compose([    transforms.RandomCrop(32, padding=4),    transforms.RandomHorizontalFlip(),    transforms.ToTensor(),    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])transform_test = transforms.Compose([    transforms.ToTensor(),    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 加载训练集和测试集trainset = torchvision.datasets.CIFAR10(root='./data', train=True,                                        download=True, transform=transform_train)trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,                                          shuffle=True, num_workers=2)testset = torchvision.datasets.CIFAR10(root='./data', train=False,                                       download=True, transform=transform_test)testloader = torch.utils.data.DataLoader(testset, batch_size=100,                                         shuffle=False, num_workers=2)   在上述代码中,我们定义了用于训练集和测试集的数据预处理操作。训练集使用了数据增强技术,包括随机裁剪和水平翻转,以增加数据的多样性。测试集是我们用于验证模型性能的,因此要保证完整性和标准,不需要进行数据增强,但是要和训练集保持一致的标准化。   3. 模型、损失函数、优化器   在这一步中,我们不再自己定义模型。我们使用前人实践经验出来的好模型,比如预训练的ResNet-18模型,这些模型与预训练参数都在torchvision库中,可以直接加载。   import torch.nn as nnimport torchvision.mode                       
点赞 1
评论 2
全部评论

相关推荐

评论
点赞
收藏
分享

创作者周榜

更多
牛客网
牛客网在线编程
牛客网题解
牛客企业服务