目的

使用卷积网络、全连接网络,对MINIST数据进行分类

形状说明

卷积网络:
通常,卷积神经网络都是一个4D的形状输入,(batch,channel,行列值,行列值),如代码中的(64,1,28,28)
通常,卷积神经网络都是一个4D的形状输出,(batch,out_channel,行列值,行列值),如代码中的(64,32,7,7)

输入到全连接网络:
输入值应当为图片信息,输入形状需要根据图像大小展平reshape(-1, 32×7×7),即为(batch=64,32×7×7)
输出形状根据nn.Linear(256, 10)的10个分类可知,形状为(batch=64,10)

形状说明-计算loss

y_pred = model(x)
loss = loss_fn(y_pred, y_label)

训练中,一次batch=64的数据中:

  • 网络的输出y_pred.shape为(64,10)
  • minist的label是一个数值,并不是独热编码,因此,学习目标y_label.shape为(64)
12-08 06:27