目录
1. train 训练数据
训练的代码只是在之前图像分类的基础上做了一些更改,具体的可以看下面的文章
pytorch 搭建 LeNet 网络对 CIFAR-10 图片分类https://blog.csdn.net/qq_44886601/article/details/127498256
首先,导入之前定义的UNet 网络
然后,加载训练集和测试集
这里训练的时候可以将数据打乱,测试的时候没有必要,batch_size 因为电脑硬件的问题设置成2,再大的话这里内存就会不够了
然后定义优化器和损失函数,这里用的是BCE加上sigmoid的损失函数
训练的时候,要将模式改为train模式,然后训练的步骤很常规
梯度清零->前向传播->计算损失函数->反向传播->更新参数
这里测试的时候有些区别
因为这里UNet 网络的输出是一幅图像,而之前将label改为了二值图像(归一化后是0 1)。所以这里计算准确率的时候,将预测的图像也变为二值图像,计算准确率用的是对应图像像素点的灰度值是否相等的方法
最后保留最好准确率的那个参数就行了
2. Loss 值
这是跑了20 个epoch的输出
3. 完整代码
from model import UNet # 导入Unet 网络
from dataset import Data_Loader # 数据处理
from torch import optim
import torch.nn as nn
import torch
# 网络训练模块
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # GPU or CPU
print(device)
net = UNet(in_channels=1, num_classes=1) # 加载网络
net.to(device) # 将网络加载到device上
# 加载训练集
train_path = "./data/train/image"
trainset = Data_Loader(train_path)
train_loader = torch.utils.data.DataLoader(dataset=trainset,batch_size=2,shuffle=True)
# len(trainset) 样本总数:21
# 加载测试集
test_path = "./data/test/image"
testset = Data_Loader(test_path)
test_loader = torch.utils.data.DataLoader(dataset=testset,batch_size=2)
optimizer = optim.RMSprop(net.parameters(),lr = 0.000001,weight_decay=1e-8,momentum=0.9) # 定义优化器
criterion = nn.BCEWithLogitsLoss() # 定义损失函数
save_path = './UNet.pth' # 网络参数的保存路径
best_acc = 0.0 # 保存最好的准确率
for epoch in range(20):
net.train() # 训练模式
running_loss = 0.0
for image,label in train_loader: # 读取数据和label
optimizer.zero_grad() # 梯度清零
pred = net(image.to(device)) # 前向传播
loss = criterion(pred, label.to(device)) # 计算损失
loss.backward() # 反向传播
optimizer.step() # 梯度下降
running_loss += loss.item() # 计算损失和
net.eval() # 测试模式
acc = 0.0 # 正确率
total = 0
with torch.no_grad():
for test_image, test_label in test_loader:
outputs = net(test_image.to(device)) # 前向传播
outputs[outputs >= 0] = 1 # 将预测图片转为二值图片
outputs[outputs < 0] = 0
acc += (outputs == test_label.to(device)).sum().item() / (480*480) # 计算预测图片与真实图片像素点一致的精度:acc = 相同的 / 总个数
total += test_label.size(0)
accurate = acc / total # 计算整个test上面的正确率
print('[epoch %d] train_loss: %.3f test_accuracy: %.3f %%' %
(epoch + 1, running_loss, accurate*100))
if accurate > best_acc: # 保留最好的精度
best_acc = accurate
torch.save(net.state_dict(), save_path) # 保存网络参数