目的
使用卷积网络、全连接网络,对MINIST数据进行分类
形状说明
卷积网络:
通常,卷积神经网络都是一个4D的形状输入,(batch,channel,行列值,行列值),如代码中的(64,1,28,28)
通常,卷积神经网络都是一个4D的形状输出,(batch,out_channel,行列值,行列值),如代码中的(64,32,7,7)
输入到全连接网络:
输入值应当为图片信息,输入形状需要根据图像大小展平reshape(-1, 32×7×7),即为(batch=64,32×7×7)
输出形状根据nn.Linear(256, 10)的10个分类可知,形状为(batch=64,10)
形状说明-计算loss
y_pred = model(x)
loss = loss_fn(y_pred, y_label)
训练中,一次batch=64的数据中:
- 网络的输出y_pred.shape为(64,10)
- minist的label是一个数值,并不是独热编码,因此,学习目标y_label.shape为(64)