实验环境

torch1.8.0+torchvision0.9.0

import torch
import torchvision
print(torch.__version__)
print(torchvision.__version__)
1.8.0
0.9.0+cpu

1.PyTorch数据加载

import torchvision.transforms as tfm
from PIL import Image
img = Image.open('volleyball.png')
img_1 = tfm.RandomCrop(200, padding=50)(img)  #随机裁剪图片
img_1.show()
img_1.save('crop.png')
img_2 = tfm.RandomHorizontalFlip()(img)       #随机水平翻转图片
img_2.show()
img_2.save('flip.png')

1.1 数据预处理

torchvision.transforms

transfrom_train = tfm.Compose([
    tfm.RandomCrop(32, padding=4),
    tfm.RandomHorizontalFlip(),   
    tfm.ToTensor(),     #将图片转换为Tensor张量                       
    tfm.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))  #标准化
])

1.2 数据加载

torch.utils.data

loader = torch.utils.data.DataLoader(
    datasets, batch_size=32, shuffle=True, sampler=None,
    num_workers=2, collate_fn=None, pin_memory=True, drop_last=False
)
  • datasets:传入的数据集,可以是自定义的dataset对象或者torchvision中的预定义数据集对象。
  • batch_size:每个batch中包含的样本数量。
  • shuffle:是否打乱数据集。
  • sampler:样本抽样器,如果指定了sampler,则忽略shuffle参数。
  • num_workers:用于数据加载的子进程数量。
  • collate_fn:对样本进行批处理前的预处理函数,可用于对样本进行排序、padding等操作。
  • pin_memory:是否将数据加载到GPU的显存中。
  • drop_last:如果数据集样本数量不能被batch_size整除,则是否舍弃剩余的不足一个batch的样本。

2.PyTorch模型搭建

2.1 经典模型

torchvision.models

from torchvision import models
net1 = models.resnet50()
net2 = models.resnet50(pretrained=True)

2.2 模型加载与保存

model.load_state_dict(torch.load('pretrained_weights.pth'))
torch.save(model.state_dict(), 'model_weights.pth')

3.PyTorch优化器

3.1 torch.optim

optimizer = optim.SGD([       #SGD随机梯度下降算法
    {'params':model.base.parameters()},
    {'params':model.classifier.parameters(), 'lr': 1e-3}
], lr=1e-2, momentum=0.9)
# 训练过程 
model = init_model_function()               #模型构建
optimizer = optim.SomeOptimizer(            #设置优化器
    model.parameters(), lr, mm
)

for data, label in train_dataloader:
    optimizer.zero_grad()                #前向计算前,清空原有梯度
    output = model(data)                 #前向计算
    loss = loss_function(output, label)  #损失函数
    loss.backward()                      #反向传播 
    optimizer.step()                     #更新参数

3.2 学习率调整

scheduler = optim.lr_scheduler.SomeScheduler(optimizer, *args)
for epoch in range(epochs):
    train()
    test()
    scheduler.step()

常见函数

激活单元类型

损失函数层类型

优化器类型

变换操作类型

数据集名称

torchvision.models中所有实现的分类模型

附:系列文章

10-08 10:38