mltools.data 源代码

from torch.utils import data
import re
import httpx
import numpy as np
from tqdm import tqdm
from pathlib import Path


[文档] class MyDataset(data.Dataset): """ 自定义数据集类,继承自 torch.utils.data.Dataset,用于管理机器学习任务中的数据。 """ def __init__(self, datas: list): """ 初始化数据集 Args: datas (list): 数据集内容,可以是任何格式的数据 """ data.Dataset.__init__(self) self.data = datas def __len__(self) -> int: """ 返回数据集的样本数量 Returns: int: 数据集中的样本总数 """ return len(self.data) def __getitem__(self, idx: int) -> any: """ 根据索引获取单个数据样本 Args: idx (int): 数据索引 Returns: 对应索引的数据样本 """ return self.data[idx]
[文档] def split_data(datas: list, ratio: list) -> list: """ 划分数据集 Args: datas (list): 数据集内容,可以是任何格式的数据 ratio (list): 划分比例,例如 [0.8, 0.2] 表示划分成 80% 训练集和 20% 测试集 Returns: list: 划分后的数据集,每个元素都是一个数据集 """ ratio = [r / sum(ratio) for r in ratio] nums = [int(len(datas) * r) for r in ratio] nums[-1] = len(datas) - sum(nums[:-1]) return data.random_split(datas, nums)
[文档] def iter_data( datas: list, *, batch_size: int, shuffle: bool = True, num_workers: int = 0, pin_memory: bool = False, drop_last: bool = False, ) -> data.DataLoader: """ 迭代数据集 Args: datas (list): 数据集内容,可以是任何格式的数据 batch_size (int): 每个批次的样本数量 shuffle (bool, optional): 是否在每个 epoch 开始时打乱数据. 默认值为 True. num_workers (int, optional): 用于数据加载的子进程数量. 默认值为 0. pin_memory (bool, optional): 是否将数据加载到 CUDA 固定内存中. 默认值为 False. drop_last (bool, optional): 是否丢弃最后一个批次, 如果数据集大小不能被批次大小整除. 默认值为 False. Returns: data.DataLoader: 数据加载器,用于迭代数据集 """ return ( data.DataLoader( _data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory, drop_last=drop_last, ) for _data in datas )
[文档] def download_file(url: str, *, save_path: str) -> str: """ 下载文件 Args: url (str): 文件的 URL 地址 save_path (str): 保存文件的路径 Returns: str: 下载的文件名 """ file_name = re.search(r"(?<=/)[^/]+$", url).group() # 从url中提取文件名 if not Path(f"{save_path}/{file_name}").exists(): # 如果文件不存在则下载 Path(save_path).mkdir(parents=True, exist_ok=True) # 创建保存路径 with httpx.Client() as client: with client.stream("GET", url) as response: response.raise_for_status() # 检查响应状态码 total_size = int(response.headers.get("Content-Length", 0)) # 获取文件大小 with ( open(f"{save_path}/{file_name}", "wb") as f, tqdm(desc=file_name, total=total_size, unit="B", unit_scale=True, unit_divisor=1024) as pbar, ): for chuck in response.iter_bytes(): f.write(chuck) pbar.update(len(chuck)) return file_name
[文档] class BaseBbox: """ 基础边界框类,用于表示物体的边界框。 """ def __init__(self, bbox: list, *, bbox_type: str = "xmin_ymin_xmax_ymax"): """ 初始化边界框 Args: bbox (list): 边界框参数,格式根据 bbox_type 不同而不同 bbox_type (str, optional): 边界框格式,可选值为 "xmin_ymin_xmax_ymax"、"xmin_ymin_w_h"、"center_w_h". 默认值为 "xmin_ymin_xmax_ymax". Raises: ValueError: 如果 bbox 参数长度不是 5 个元素 ValueError: 如果类别不是整数 ValueError: 如果 bbox 参数不归一化 ValueError: 如果 bbox_type 不是 'xmin_ymin_xmax_ymax'、'xmin_ymin_w_h' 或 'center_w_h' """ if len(bbox) != 5: raise ValueError("bbox 参数必须是 5 个元素") if isinstance(bbox[0], int): self.class_id = bbox[0] else: raise ValueError("类别必须是整数") if not all(isinstance(item, float) for item in bbox[1:]): raise ValueError("bbox 参数必须归一化") if bbox_type == "xmin_ymin_xmax_ymax": self.x_min, self.y_min, self.x_max, self.y_max = bbox[1:] elif bbox_type == "xmin_ymin_w_h": self.x_min, self.y_min, self.x_max, self.y_max = (bbox[1], bbox[2], bbox[1] + bbox[3], bbox[2] + bbox[4]) elif bbox_type == "center_w_h": self.x_min, self.y_min, self.x_max, self.y_max = ( bbox[1] - bbox[3] / 2, bbox[2] - bbox[4] / 2, bbox[1] + bbox[3] / 2, bbox[2] + bbox[4] / 2, ) else: raise ValueError("bbox_type 必须是 'xmin_ymin_xmax_ymax'、'xmin_ymin_w_h' 或 'center_w_h'") def __str__(self) -> str: """ 返回边界框的字符串表示 Returns: str: 边界框的字符串表示,格式为 "class_id x_min y_min x_max y_max" """ return f"{self.class_id} {self.x_min} {self.y_min} {self.x_max} {self.y_max}" def __repr__(self) -> str: """ 返回边界框的字符串表示 Returns: str: 边界框的字符串表示,格式为 "BaseBbox(class_id={self.class_id}, bbox=[{self.x_min}, {self.y_min}, {self.x_max}, {self.y_max}])" """ return f"BaseBbox(class_id={self.class_id}, bbox=[{self.x_min}, {self.y_min}, {self.x_max}, {self.y_max}])"
[文档] def xmin_ymin_xmax_ymax(self) -> list: """ 返回边界框的坐标表示 Returns: list: 边界框的坐标表示,格式为 [class_id, x_min, y_min, x_max, y_max] """ return [self.class_id, self.x_min, self.y_min, self.x_max, self.y_max]
[文档] def xmin_ymin_w_h(self) -> list: """ 返回边界框的坐标表示 Returns: list: 边界框的坐标表示,格式为 [class_id, x_min, y_min, x_max - x_min, y_max - y_min] """ return [self.class_id, self.x_min, self.y_min, self.x_max - self.x_min, self.y_max - self.y_min]
[文档] def center_w_h(self) -> list: """ 返回边界框的坐标表示 Returns: list: 边界框的坐标表示,格式为 [class_id, x_min + (x_max - x_min) / 2, y_min + (y_max - y_min) / 2, x_max - x_min, y_max - y_min] """ return [ self.class_id, self.x_min + (self.x_max - self.x_min) / 2, self.y_min + (self.y_max - self.y_min) / 2, self.x_max - self.x_min, self.y_max - self.y_min, ]
[文档] @staticmethod def normalize(bbox: list, *, width: int, height: int) -> list: """ 归一化边界框坐标 Args: bbox (list): 边界框参数,格式为 [class_id, x_min, y_min, x_max, y_max] width (int): 图片宽度 height (int): 图片高度 Returns: list: 归一化后的边界框参数,格式为 [class_id, x_min / width, y_min / height, x_max / width, y_max / height] """ return [bbox[0], bbox[1] / width, bbox[2] / height, bbox[3] / width, bbox[4] / height]
[文档] @staticmethod def unnormalize(bbox: list, *, width: int, height: int) -> list: """ 反归一化边界框坐标 Args: bbox (list): 归一化后的边界框参数,格式为 [class_id, x_min / width, y_min / height, x_max / width, y_max / height] width (int): 图片宽度 height (int): 图片高度 Returns: list: 反归一化后的边界框参数,格式为 [class_id, int(x_min * width), int(y_min * height), int(x_max * width), int(y_max * height)] """ return [bbox[0], int(bbox[1] * width), int(bbox[2] * height), int(bbox[3] * width), int(bbox[4] * height)]
[文档] class Bbox: """ 边界框容器类 """ def __init__(self, bboxes: list, *, bbox_type: str = "xmin_ymin_xmax_ymax"): """ 初始化 Bbox 实例 Args: bboxes (list): 边界框列表,每个元素为 BaseBbox 实例 bbox_type (str, optional): 边界框类型,可选值为 "xmin_ymin_xmax_ymax"、"xmin_ymin_w_h" 或 "center_w_h",默认为 "xmin_ymin_xmax_ymax" Raises: ValueError: 如果 bboxes 参数不是列表 ValueError: 如果 bboxes 列表元素不是列表 """ if not isinstance(bboxes, list): raise ValueError("bboxes 参数必须是列表") if not all(isinstance(bbox, list) for bbox in bboxes): raise ValueError("bboxes 列表元素必须是列表") self.bboxes = [BaseBbox(bbox, bbox_type=bbox_type) for bbox in bboxes] def __getitem__(self, index: int) -> BaseBbox: """ 获取指定索引的边界框 Args: index (int): 边界框索引 Returns: BaseBbox: 指定索引的边界框实例 """ return self.bboxes[index] def __len__(self) -> int: """ 返回边界框列表的长度 Returns: int: 边界框列表的长度 """ return len(self.bboxes) def __str__(self) -> str: """ 返回边界框列表的字符串表示 Returns: str: 边界框列表的字符串表示 """ return "\n".join(str(bbox) for bbox in self.bboxes) def __repr__(self) -> str: """ 返回边界框列表的字符串表示 Returns: str: 边界框列表的字符串表示 """ return "Bbox([\n" + ",\n".join(str(bbox.__repr__()) for bbox in self.bboxes) + ",\n])"
[文档] def xmin_ymin_xmax_ymax(self) -> list: """ 返回边界框的坐标表示 Returns: list: 边界框的坐标表示,格式为 [class_id, x_min, y_min, x_max, y_max] """ return [bbox.xmin_ymin_xmax_ymax() for bbox in self.bboxes]
[文档] def xmin_ymin_w_h(self) -> list: """ 返回边界框的坐标表示 Returns: list: 边界框的坐标表示,格式为 [class_id, x_min, y_min, x_max - x_min, y_max - y_min] """ return [bbox.xmin_ymin_w_h() for bbox in self.bboxes]
[文档] def center_w_h(self) -> list: """ 返回边界框的坐标表示 Returns: list: 边界框的坐标表示,格式为 [class_id, x_min + (x_max - x_min) / 2, y_min + (y_max - y_min) / 2, x_max - x_min, y_max - y_min] """ return [bbox.center_w_h() for bbox in self.bboxes]
[文档] @staticmethod def normalize(bboxes: list, *, width: int, height: int) -> list: """ 归一化边界框坐标 Args: bboxes (list): 边界框列表,每个元素为 BaseBbox 实例 width (int): 图片宽度 height (int): 图片高度 Returns: list: 归一化后的边界框列表,每个元素为 BaseBbox 实例 """ return [BaseBbox.normalize(bbox, width=width, height=height) for bbox in bboxes]
[文档] @staticmethod def unnormalize(bboxes: list, *, width: int, height: int) -> list: """ 反归一化边界框坐标 Args: bboxes (list): 归一化后的边界框列表,每个元素为 BaseBbox 实例 width (int): 图片宽度 height (int): 图片高度 Returns: list: 反归一化后的边界框列表,每个元素为 BaseBbox 实例 """ return [BaseBbox.unnormalize(bbox, width=width, height=height) for bbox in bboxes]
[文档] def bbox( bboxes: list, *, bbox_type: str = "xmin_ymin_xmax_ymax", normalize: bool = True, width: int = None, height: int = None, ) -> Bbox: """ 创建 Bbox 实例 Args: bboxes (list): 边界框列表,每个元素为 BaseBbox 实例 bbox_type (str, optional): 边界框类型,可选值为 "xmin_ymin_xmax_ymax"、"xmin_ymin_w_h" 或 "center_w_h",默认为 "xmin_ymin_xmax_ymax" normalize (bool, optional): 是否归一化边界框坐标,默认为 True width (int, optional): 图片宽度,默认为 None height (int, optional): 图片高度,默认为 None Returns: Bbox: Bbox 实例 Raises: ValueError: 如果 normalize 为 False 时,width 和 height 未提供 """ if normalize: return Bbox(bboxes, bbox_type=bbox_type) else: if width is None or height is None: raise ValueError("normalize 为 False 时,width 和 height 必须提供") return Bbox(Bbox.normalize(bboxes, width=width, height=height), bbox_type=bbox_type)
[文档] def read_label_file(label_file_path: str, bbox_type: str = "xmin_ymin_xmax_ymax") -> Bbox: """ 读取标签文件并返回边界框实例 Args: label_file_path (str): 标签文件路径 bbox_type (str, optional): 边界框类型,可选值为 "xmin_ymin_xmax_ymax"、"xmin_ymin_w_h" 或 "center_w_h",默认为 "xmin_ymin_xmax_ymax" Returns: Bbox: 边界框实例 """ lines = [] with open(label_file_path, "r") as file: for line in file.readlines(): line = line.strip().split() line[0] = int(line[0]) line[1:] = [float(x) for x in line[1:]] lines.append(line) return bbox(lines, bbox_type=bbox_type)
[文档] def mask_to_bbox(mask: np.ndarray, mask_type: str = "gray") -> Bbox: """ 将二值掩码转换为边界框 Args: np_mask (np.ndarray): 二值掩码数组 mask_type (str, optional): 掩码类型,可选值为 "gray",默认为 "gray" Returns: Bbox: 边界框实例 Raises: ValueError: 如果 np_mask 不是 2 维数组 ValueError: 如果 mask_type 不是 'gray' """ if mask.ndim != 2: raise ValueError("np_mask 必须是 2 维数组") if mask_type == "gray": _mask = mask != 0 (y_indices,) = np.nonzero(np.any(_mask == 1, axis=1)) (x_indices,) = np.nonzero(np.any(_mask == 1, axis=0)) y_min, y_max = y_indices.min().item(), y_indices.max().item() x_min, x_max = x_indices.min().item(), x_indices.max().item() else: raise ValueError("mask_type 必须是 'gray'") return bbox([[0, x_min, y_min, x_max, y_max]], normalize=False, width=_mask.shape[1], height=_mask.shape[0])
[文档] def rename_file(file_path: str, new_name: str): """ 重命名文件 Args: file_path (str): 文件路径 new_name (str): 新文件名 Raises: FileExistsError: 如果新文件名已存在 """ _file_path = Path(file_path) file_new_name = _file_path.name.replace(_file_path.stem, new_name) file_new_path = _file_path.parent / file_new_name if file_new_path.exists(): raise FileExistsError(f"文件 {file_new_path} 已存在") _file_path.rename(file_new_path)
[文档] def batch_rename(image_dir_path: str, label_dir_path: str, *, prefix: str, offset: int = 0): """ 批量重命名图片和标签文件 Args: image_dir_path (str): 图片目录路径 label_dir_path (str): 标签目录路径 prefix (str): 文件名前缀 offset (int, optional): 文件名偏移量,默认为 0 """ print(f"以 {prefix} 为前缀重命名文件") _image_dir_path, _label_dir_path = Path(image_dir_path), Path(label_dir_path) for index, image_path in enumerate(_image_dir_path.iterdir()): label_path = _label_dir_path / (image_path.stem + ".txt") try: rename_file(str(image_path), f"{prefix}_{index + offset:010d}") rename_file(str(label_path), f"{prefix}_{index + offset:010d}") except FileExistsError as e: print(e) batch_rename(image_dir_path, label_dir_path, prefix="temp") batch_rename(image_dir_path, label_dir_path, prefix=prefix) break print("重命名完成")