• 如果您觉得本站非常有看点,那么赶紧使用Ctrl+D 收藏吧

如何入门Pytorch之四:搭建神经网络训练MNIST

开发技术 开发技术 3周前 (09-13) 20次浏览

       上一节我们学习了Pytorch优化网络的基本方法,本节我们将以MNIST数据集为例,通过搭建一个完整的神经网络,来加深对Pytorch的理解。

一、数据集

       MNIST是一个非常经典的数据集,下载链接:http://yann.lecun.com/exdb/mnist/

      下载下来的文件如下:

如何入门Pytorch之四:搭建神经网络训练MNIST

 

该手写数字数据库具有60,000个示例的训练集和10,000个示例的测试集。它是NIST提供的更大集合的子集。数字已经过尺寸标准化,并以固定尺寸的图像为中心。

手写数字识别是一个比较简单的任务,它是一个10分类问题,(0-9),之所以选这个数据集,是因为识别难度低,计算量小,数据容易获得。

二、模型搭建

    1、网络节点的确定

    对于不同的目的,网络的选择也是不一样的。一般来说,网络容量和数据集大小是对应的。一个小型数据集也只需要一个小型的网络。

这里有一个经验值:

      1)model_size=sqrt(in_size*out_size)

      2)model_size=log(in_size)

      3)  model_size=sqrt(in_size*out_size)

      model_size:网络的节点量

      in_size:输入的节点量

      out_size输出的节点量

     2、导入pytorch包

import torch
import torchvision
import trochvision import datasets
import trochvision import transforms
from torch.autograd import Variable

    3、获取训练集和测试集

#root用于指定数据集下载后的存放路径
#transform用于指定导入数据集需要对数据进行变换操作
#train指定在数据集下载后需要载入哪部分数据,true为训练集,false为测试集
data_train=datasets.MNIST(root="./data/",transform=transform,train=True,download=True) data_test=datasets.MNIST(root='./data/',transform=transform,train=False)

    4、数据预览和装载

#数据装载,可以理解为对图片的处理
#处理完成后,将图片送给模型训练,装载就是打包的过程
#dataset 用于指定载入的数据集名称
#batch_size设置了每个包的图片数据数据个数
#shuffle 装载过程将数据随机打乱并打包
data_loader_train=torch.utils.data.DataLoader(dataset=data_train,batch_size=64,shuffle=True)
data_loader_test=torch.utils.data.DataLoader(dataset=data_test,batch_size=64,shuffle=True)

 

  

 

   

 


喜欢 (0)