文章目录
实验环境
torch1.8.0+torchvision0.9.0
import torch
import torchvision
print(torch.__version__)
print(torchvision.__version__)
1.8.0
0.9.0+cpu
1.PyTorch数据加载
import torchvision.transforms as tfm
from PIL import Image
img = Image.open('volleyball.png')
img_1 = tfm.RandomCrop(200, padding=50)(img) #随机裁剪图片
img_1.show()
img_1.save('crop.png')
img_2 = tfm.RandomHorizontalFlip()(img) #随机水平翻转图片
img_2.show()
img_2.save('flip.png')
1.1 数据预处理
torchvision.transforms
transfrom_train = tfm.Compose([
tfm.RandomCrop(32, padding=4),
tfm.RandomHorizontalFlip(),
tfm.ToTensor(), #将图片转换为Tensor张量
tfm.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5)) #标准化
])
1.2 数据加载
torch.utils.data
loader = torch.utils.data.DataLoader(
datasets, batch_size=32, shuffle=True, sampler=None,
num_workers=2, collate_fn=None, pin_memory=True, drop_last=False
)
- datasets:传入的数据集,可以是自定义的dataset对象或者torchvision中的预定义数据集对象。
- batch_size:每个batch中包含的样本数量。
- shuffle:是否打乱数据集。
- sampler:样本抽样器,如果指定了sampler,则忽略shuffle参数。
- num_workers:用于数据加载的子进程数量。
- collate_fn:对样本进行批处理前的预处理函数,可用于对样本进行排序、padding等操作。
- pin_memory:是否将数据加载到GPU的显存中。
- drop_last:如果数据集样本数量不能被batch_size整除,则是否舍弃剩余的不足一个batch的样本。
2.PyTorch模型搭建
2.1 经典模型
torchvision.models
from torchvision import models
net1 = models.resnet50()
net2 = models.resnet50(pretrained=True)
2.2 模型加载与保存
model.load_state_dict(torch.load('pretrained_weights.pth'))
torch.save(model.state_dict(), 'model_weights.pth')
3.PyTorch优化器
3.1 torch.optim
optimizer = optim.SGD([ #SGD随机梯度下降算法
{'params':model.base.parameters()},
{'params':model.classifier.parameters(), 'lr': 1e-3}
], lr=1e-2, momentum=0.9)
# 训练过程
model = init_model_function() #模型构建
optimizer = optim.SomeOptimizer( #设置优化器
model.parameters(), lr, mm
)
for data, label in train_dataloader:
optimizer.zero_grad() #前向计算前,清空原有梯度
output = model(data) #前向计算
loss = loss_function(output, label) #损失函数
loss.backward() #反向传播
optimizer.step() #更新参数
3.2 学习率调整
scheduler = optim.lr_scheduler.SomeScheduler(optimizer, *args)
for epoch in range(epochs):
train()
test()
scheduler.step()
常见函数
激活单元类型
损失函数层类型
优化器类型
变换操作类型
数据集名称
torchvision.models中所有实现的分类模型