torch.optim是里面是和优化算法相关的类。比如使用SGD算法用

optimizer = optim.SGD(net.parameters(),lr=1e-3)

这个地方要注意的是传入的第一个参数是网络的parameters。

这个类里面有param_group,是一个字典,里面包括:

params: 网路可学习权重

lr: 学习率

weight_decay:权重衰减

等等私有成员。

这个里面的weight_decay有两个地方需要注意一下,一个是这里面的权重衰减是默认是L2正则化,另外一点是,这个正则化是对于weight和bias都进行正则化,按照《深度学习》里面讲的是,最好仅对于weight进行正则化,对对于bias进行正则化极有可能会造成欠拟合。

05-07 15:44