本文着重讲如何快速上手跑出结果。本项目基于resnet34识别四类示意图,由cat vs dog项目改写而来。

dataset  数据集预处理等工作。

rename  数据集图片重命名用,后面会讲

test_model是从checkpoints里取出来训练好的模型改个名,文件夹里是我们的模型

test  测试程序,train  训练程序。


运行rename.py会在image/processed生成重命名好的图片。格式为sketch1.0.jpg、sketch1.1.jpg、sketch1.2.jpg等。将这些图片按二八比例分开,分别放入datasets/test(约20%)和datasets/train(约80%)。

四类图片都要这样处理。

需要注意的是,最后无论是test文件夹还是train文件夹,图片的id要求是分类其他类型的图片,不是你给的示意图。怎么办?

②要求是图片的二/三分类,怎么修改代码?

答:以二分类为例。修改以下代码:

dataset.py:第60行

(学生快速上手向)python图片分类识别器-LMLPHP

 从四类改两类。

rename.py:重命名图片跟着上面步骤做。

test_modification.py:

29行的model.fc = nn.Linear(512, 4)   把4改成2.

48行(下图)改2类

(学生快速上手向)python图片分类识别器-LMLPHP

72行同理:

(学生快速上手向)python图片分类识别器-LMLPHP

train.py:

 30行model.fc = nn.Linear(512,4)   把4改成2

110行confusion_matrix = meter.ConfusionMeter(4)  把4改2

120行accuracy = 100.* (cm_value[0][0] + cm_value[1][1] + cm_value[2][2] + cm_value[3][3]) / (cm_value.sum())    把“ + cm_value[2][2] + cm_value[3][3]”删掉,只留两类,即改成accuracy = 100.* (cm_value[0][0] + cm_value[1][1]) / (cm_value.sum())。

应该就这些,改不好来评论区问。

③你这项目没做可视化啊?

答:确实。


本文结束


以下代码无关本文,仅充数用

# coding=utf-8

""" test
使用测试集测试模型结果
"""

from config import _setting_
import os
import torch as t
from dataset import NatureSketchClassification
from torch.utils.data import DataLoader
from torchnet import meter
from torch.autograd import Variable
from torchvision import models
from torch import nn
import time
import csv


""""""
def test(**kwargs):
	"""Evaluate the saved 4-class ResNet-34 on the test set and write a CSV report.

	The function loads ``./test_model.pth`` into a ResNet-34 whose final layer
	has one output per class, runs it over the test dataloader, and maps each
	arg-max class index to a label ("sketch1" .. "sketch4"). It then derives
	the true label of every file in ``_setting_.test_data_root`` from its file
	name, compares it with the prediction, writes the comparison and the
	overall accuracy through ``write_csv`` and prints the misclassified images.

	Args:
		**kwargs: Accepted for interface compatibility; not used.

	Returns:
		list: ``[image_id, predicted_label]`` pairs in dataloader order.
	"""
	# Index in this tuple == class index produced by the network.
	class_names = ("sketch1", "sketch2", "sketch3", "sketch4")

	# set data
	test_data = NatureSketchClassification(_setting_.test_data_root, test=True)
	test_dataloader = DataLoader(
		test_data,
		batch_size=_setting_.batch_size,
		shuffle=False,
		num_workers=_setting_.num_workers,
	)
	results = []

	# set model
	model = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)
	model.fc = nn.Linear(512, len(class_names))
	model.load_state_dict(t.load('./test_model.pth', map_location='cpu'))
	model.eval()

	# The forward pass itself has to run inside no_grad(); wrapping only the
	# input tensor (as before) still made autograd track the whole inference.
	with t.no_grad():
		# NOTE(review): the second item of each batch is assumed to be a tensor
		# of integer image ids (it is converted with .tolist()) - confirm
		# against the dataset class.
		for batch, batch_ids in test_dataloader:
			score = model(batch)
			print('score=',score)  # debug: raw class scores
			_, predicted = t.max(score, 1)
			predicted = predicted.cpu().tolist()
			print('predicted=',predicted)  # debug: predicted class indices
			for image_id, class_index in zip(batch_ids.tolist(), predicted):
				label = class_names[class_index]
				print('res=',label)  # debug: predicted label
				results.append([image_id, label])

	def id_of(path):
		# File names look like "<name>.<id>.<ext>"; the id is the second-to-last
		# dot-separated field. basename() keeps this working with either path
		# separator instead of assuming '/'.
		return int(os.path.basename(path).split('.')[-2])

	root = _setting_.test_data_root
	imgs = [os.path.join(root, name) for name in os.listdir(root)]  # every file under root
	imgs_num = len(imgs)

	res = []
	num_correct = 0
	for image in sorted(imgs, key=id_of):  # visit images in id order
		file_name = os.path.basename(image)
		# The first of sketch1..sketch3 found in the file name wins; anything
		# else is treated as sketch4 (same fallback as before).
		truth = next(
			(name for name in class_names[:-1] if name in file_name),
			class_names[-1],
		)
		print('truth=',truth)
		# NOTE(review): predictions are looked up by position, which assumes
		# ids run 1..N and the dataloader yields them in ascending order -
		# confirm against the dataset class.
		predicted_id, predicted_label = results[id_of(image) - 1]
		compare = 'true' if truth == predicted_label else 'false'
		if compare == 'true':
			num_correct += 1
		res.append([predicted_id, predicted_label, truth, compare])

	# Keep the rounded value (it used to be computed and thrown away) and
	# avoid dividing by zero when the test directory is empty.
	accuracy = round(num_correct / imgs_num * 100, 2) if imgs_num else 0.0
	write_csv(res, _setting_.result_file, accuracy)

	for image_id, label, truth, compare in res:
		if compare == 'false':
			print(f"number: {image_id}, res: {label}, truth: {truth}, IsCorrect: {compare}")
	print("Accuracy: " + str(accuracy))
	return results


""""""
def write_csv(results, file_name, acc):
	"""Write the per-image comparison rows and the overall accuracy to a CSV file.

	The file starts with the header ``id,label,truth,IsCorrect``, followed by
	one line per entry of ``results`` and a final ``" ",Accuracy,<acc>`` row.

	Args:
		results: Iterable of rows, each an iterable of cell values.
		file_name: Path of the CSV file; an existing file is overwritten.
		acc: Overall accuracy; written as ``str(acc)`` in the last row.
	"""
	# newline="" hands line-ending control to the csv module; without it every
	# row is followed by an empty line on Windows.
	with open(file_name, "w", newline="", encoding="utf-8") as f:
		writer = csv.writer(f)
		writer.writerow(['id', 'label', 'truth', 'IsCorrect'])
		writer.writerows(results)
		writer.writerow([" ", "Accuracy", str(acc)])

# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported.
if __name__ == '__main__':
	test()
10-30 13:40