在机器学习项目中,特别是涉及到图像识别和分类的领域,经常需要对大量数据进行预处理。这些数据预处理可能包括图像转换、格式化标签、数据集划分等。本文将介绍一个基于Python的脚本,该脚本能够自动化这些常见任务,并且还支持多进程处理以加速这些操作。

脚本核心功能

  该脚本具有以下核心功能:
  (1)读取XML标注文件,提取类标签,确保标注的一致性。
  (2)将XML标注文件转换为YOLO格式。
  (3)对图像数据进行处理,如移动和重命名。
  (4)划分数据集为训练集和验证集。
  (5)支持多进程处理,提高处理效率。

主要函数解析

  以下是一些核心函数的解析,它们构成了脚本的主要功能:
  (1)revise_data(path)
  该函数遍历指定路径下的所有XML标注文件,提取其中的类标签,并确保类的一致性。这对于数据清洗和一致性核对是非常有用的。
  (2)convert_annotation(image_id, father_dir, classes)
  通过这个函数,可以将指定图像的XML标注文件转换为YOLO格式的标签文件。YOLO格式是深度学习中常用的,适用于目标检测任务。
  (3)process_image(sub_dir, data_dir, temp_dir)
  此函数用于处理图像数据,主要包括复制图像到临时文件夹,对图像进行必要的预处理,之后清除原始图像文件夹,并将处理后的图像移动到原始位置。
  (4)split_train_val_datasets_multiprocessing(all_datasets_image_path, Number_of_intervals, num_processes)
  数据集划分非常重要,它决定了训练集和验证集的数据。这个函数可以自动化这一过程,并且支持多进程,以便在处理大规模数据集时加快速度。
  (5)transfer_labels_multiprocessing(image_path, classes, num_processes)
  当需要将XML格式的标注数据转换为YOLO格式时,这个函数可以把该过程多进程化,显著提升效率。

使用案例

  如果你有一个包含图片和XML标注文件的大型数据集,并想要转换标注格式,清洗数据,并分割数据集,你可以这样使用脚本:
  (1)设置dataDir, temp_dir, saveDir, sub_list等变量,确保它们指向正确的文件夹路径。
  (2)调用revise_data()来清洗和核对XML标注。
  (3)然后使用transfer_labels_multiprocessing()函数批量转换标注格式。
  (4)利用split_train_val_datasets_multiprocessing()函数划分数据集。
  (5)最后通过process_image_multiprocessing()处理图像数据。

代码

import cv2
import os
import glob
import shutil
import os.path
import os.path
import numpy as np
from tqdm import tqdm
from pathlib import Path
from multiprocessing import Pool
import xml.etree.ElementTree as ET

global names


def revise_data(path):
    class_label_list = []
    for xml_file in tqdm(glob.glob(path + '/*.xml')):
        tree = ET.parse(xml_file)
        root = tree.getroot()
        for member in root.findall('object'):
            objectname = member.find('name').text
            class_label_list.append(objectname)

        tree.write(xml_file, encoding="UTF-8")
    list_class = list(set(class_label_list))
    print(len(list_class))
    print(list_class)

    return list_class


def get_file_list(file_path):
    dir_list = os.listdir(file_path)
    if not dir_list:
        return
    else:

        dir_list = sorted(dir_list, key=lambda x: os.path.getmtime(os.path.join(file_path, x)))
        return dir_list


def bnd_box_to_yolo_line(size, box):
    x_min = max(0, min(box[0], size[0]))
    x_max = max(0, min(box[1], size[0]))
    y_min = max(0, min(box[2], size[1]))
    y_max = max(0, min(box[3], size[1]))

    x_center = float((x_min + x_max)) / 2 / size[0]
    y_center = float((y_min + y_max)) / 2 / size[1]

    w = float((x_max - x_min)) / size[0]
    h = float((y_max - y_min)) / size[1]

    return x_center, y_center, w, h


def convert_annotation(image_id, father_dir, classes):
    in_file = open(father_dir + '/annotations/%s.xml' % (image_id))
    out_file = open(father_dir + '/labels/%s.txt' % (image_id), 'w')
    tree = ET.parse(in_file)
    root = tree.getroot()
    size = root.find('size')
    w = int(size.find('width').text)
    h = int(size.find('height').text)
    for obj in root.iter('object'):
        difficult = obj.find('difficult').text
        cls = obj.find('name').text
        if cls not in classes or int(difficult) == 1:
            continue
        cls_id = classes.index(cls)
        xmlbox = obj.find('bndbox')
        b = (float(xmlbox.find('
04-13 07:55