• 微信公众号:美女很有趣。 工作之余,放松一下,关注即送10G+美女照片!

第3次作业:卷积神经网络

开发技术 开发技术 5小时前 2次浏览

卷积神经网络

1.加载数据(MNIST)

PyTorch里包含了 MNIST, CIFAR10 等常用数据集,调用 torchvision.datasets 即可把这些数据由远程下载到本地。

Dataset是一个包装类,用来将数据包装为Dataset类,然后传入DataLoader中,再使用DataLoader这个类来更加快捷的对数据进行操作。其中getitem方法支持从0到len(self)的索引

 

第3次作业:卷积神经网络

显示数据中部分图像:

plt.subplot(x,y,i)让图像呈矩阵显示,x表示行,y表示列,i表示第几个图像。

plt.axis(‘off’)表示不显示坐标轴

第3次作业:卷积神经网络

2.创建网络

nn.Sequential是一个有序的容器,神经网络模块将按照在传入构造器的顺序依次被添加到计算图中执行,同时以神经网络模块为元素的有序字典也可以作为传入参数,将网络的层组合在一起。

nn.Linear(a,b)用于设置网络中的全连接层,第一个参数表示输入的二维张量的大小,第二个参数表示输出的二维张量的大小,同时也代表全连接层的神经元个数

nn.ReLu()使用ReLu激活函数,加快收敛,防止梯度消失。

下图是全连接神经网络的实现。

第3次作业:卷积神经网络

forward方法的具体流程,以一个Module为例:
1. 调用module的call方法
2. module的call里面调用module的forward方法
3. forward里面如果碰到Module的子类,回到第1步,如果碰到的是Function的子类,继续往下
4. 调用Function的call方法
5. Function的call方法调用了Function的forward方法。
6. Function的forward返回值
7. module的forward返回值
8. 在module的call进行forward_hook操作,然后返回值。

下图是卷积神经网络的实现。

 第3次作业:卷积神经网络

定义训练函数和测试函数,训练使用BP算法

第3次作业:卷积神经网络

 第3次作业:卷积神经网络

 3.在小型全连接网络上训练

第3次作业:卷积神经网络

4.在卷积神经网络上训练

 第3次作业:卷积神经网络

 对比两个结果会发现,在参数相同的情况下,CNN能够更准确的对图片中的数字进行识别,这是因为CNN利用卷积和池化能够更好的挖掘图片中的有效信息,CNN可以更好地利用图片的局部信息,但是全连接网络却做不到。为了验证这个结论,接下来将图片的像素顺序打乱,训练和测试函数也基本相同,只是对data加入了打乱顺序操作,然后会发现CNN的准确率大幅度下降,而全连接神经网络的准确率没有变化。

 


程序员灯塔
转载请注明原文链接:第3次作业:卷积神经网络
喜欢 (0)