- 在训练神经网络的时候通常都会写一个训练代码块,通过这个代码块的执行开始训练网络
- 学习神经网络模型时写的训练代码块, 整个流程就是在写一个脚本文件:
- 定义数据加载器
- 定义优化器
- 定义损失函数
- 模型实例化
- 循环读取数据,开始迭代训练
- 形成一个抽象的类Trainer,这样能够在一定程度上提高代码的复用性、可读性以及可扩展性
- 动态的将需要的功能模块“注册”到Trainer类中,而不需要去修改Trainer最原始的定义,实现可扩展的功能,对Trainer类进行升级,使他能够具备插件化处理的功能。
- 定义了一个插件队列的字典, 它保存不同时机调用的插件序列, 那么问题来了,一般会在什么时候调用这些插件呢?
- 在每次获取到数据之后,训练之前 对数据进行不同处理?
- 在完成一次backward操作之后 显示当前loss 或者accuracy
- 在完成每次batch or epoch 后保存模型或者修改学习率?
- 定义了四种类别的插件:
- iteration:一般是在完成一个batch 训练之后进行的事件调用序列(一般不改动网络或者优化器,如:计算准确率)调用序列;
- batch 在进行batch 训练之前需要进行的事件调用序列
- epoch 完成一个epoch 训练之后进行的事件调用序列
- update 完成一个batch训练之后进行的事件(涉及到对网络或者优化器的改动,如:学习率的调整)
- 注意,iteration 跟update 两种插件调用的时候传入的参数不一样,iteration 会传入batch output,loss 等训练过程中的数据, 而update传入的的model,方便对网络的修改
-
pytorch lightning:
-
自动early stopping,自动batch_size, leaning rate搜索
-
当模型比较复杂时,为了提高代码的可读性,建议先使用torch.nn.Module构建网络,再使用pytorch_lightning.LightningModule对其进行包装
-
使用Pytorch进行开发时,到这里模型定义就结束了,其余的训练、验证的具体实现会放在train.py等训练代码中;但Pytorch Lightning模型实现的是整个系统,所以训练的细节也会在这个类中实现
-
当模型比较复杂时,为了提高代码的可读性,建议先使用torch.nn.Module构建网络,再使用pytorch_lightning.LightningModule对其进行包装
-