import torch
from torch import nn
import logging
from pathlib import Path
from datetime import datetime
from mltools import utils, draw
[文档]
class Epoch:
"""
机器学习 Epoch,用于管理训练轮数,支持保存和加载总训练轮数。
"""
def __init__(self, parent: object):
"""
初始化
Args:
parent (object): 父对象,用于访问日志记录器。
"""
self._totol_epoch = 0
self.parent = parent
def __call__(self, num_epochs: int) -> int:
"""
返回迭代轮数。
Args:
num_epochs (int): 期望的训练轮数。
Returns:
int: 本次需要训练的轮数。
"""
num_epoch = num_epochs - self.totol_epoch if num_epochs > self.totol_epoch else 0 # 计算迭代次数
self._totol_epoch = max(self.totol_epoch, num_epochs) # 计算总迭代次数
# 根据迭代次数产生日志
self.parent.logger.debug(f"total training epochs {self.totol_epoch}")
if num_epoch:
self.parent.logger.debug(f"trained {num_epoch} epochs")
else:
self.parent.logger.warning(
f"num_epochs is {num_epochs}, less than totol training epoch {self.totol_epoch}, the model won't be trained."
)
return num_epoch
@property
def totol_epoch(self) -> int:
"""
返回总迭代次数。
Returns:
int: 总训练轮数。
"""
return self._totol_epoch
[文档]
def save(self, path: str, label: str = "epoch"):
"""
保存总训练轮数到 JSON 文件。
Args:
path (str): JSON 文件的保存路径。
label (str, optional): 数据在 JSON 文件中的键名。默认值为 'epoch'。
"""
utils.DataSaveToJson.save_data(path, label, self.totol_epoch)
[文档]
def load(self, path: str, label: str = "epoch"):
"""
从 JSON 文件中加载总训练轮数。
Args:
path (str): JSON 文件的路径。
label (str, optional): 数据在 JSON 文件中的键名。默认值为 'epoch'。
"""
self._totol_epoch = utils.DataSaveToJson.load_data(path, label)
[文档]
class AutoSaveLoader:
"""
自动保存加载器,将多个数据的保存和加载功能整合在一起,
支持添加自定义的保存和加载函数。
"""
def __init__(self):
"""
初始化函数,创建保存和加载函数列表。
"""
self.save_func = [] # 保存函数
self.load_func = [] # 加载函数
[文档]
def add_save_func(self, func: callable):
"""
添加保存函数。
Args:
func (callable): 保存函数。
"""
self.save_func.append(func)
[文档]
def save(self, dir_path: str):
"""
保存数据。
Args:
dir_path (str): 数据保存的目录路径。
"""
for func in self.save_func:
func(dir_path)
[文档]
def add_load_func(self, func: callable):
"""
添加加载函数。
Args:
func (callable): 加载函数。
"""
self.load_func.append(func)
[文档]
def load(self, dir_path: str):
"""
加载数据。
Args:
dir_path (str): 数据加载的目录路径。
"""
for func in self.load_func:
func(dir_path)
[文档]
class MachineLearning:
"""
机器学习工具类,提供批量创建训练辅助对象、管理模型和数据的保存与加载等功能。
"""
def __init__(self, file_name: str):
"""
初始化函数。
Args:
file_name (str): 文件名。
"""
# 创建目录
Path("../data").mkdir(parents=True, exist_ok=True)
utils.add_ignore_file("../data")
Path("../results").mkdir(parents=True, exist_ok=True)
utils.add_ignore_file("../results")
# 定义时间字符串和文件名
time_str = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
self.dir_path = f"../results/{time_str}-{file_name}"
self.file_name = file_name
# 创建目录
Path(self.dir_path).mkdir()
# 设置日志
self.logger = logging.getLogger("mylog")
self.logger.setLevel(logging.DEBUG)
# 定义日志格式
formatter = logging.Formatter("%(asctime)s - %(levelname)s: %(message)s")
# 创建文件处理器
file_handler = logging.FileHandler(f"{self.dir_path}/{self.file_name}.log")
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(formatter)
self.logger.addHandler(file_handler)
# 创建控制台处理器
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(formatter)
self.logger.addHandler(console_handler)
# 创建自动保存加载器
self.data_manager = AutoSaveLoader()
[文档]
def batch_create(self, create_epoch: bool = True, create_timer: bool = True, create_recorder: bool = True) -> tuple:
"""
批量创建 Epoch、Timer 和 Recorder 对象。
Args:
create_epoch (bool, optional): 是否创建 Epoch 对象。默认值为 True。
create_timer (bool, optional): 是否创建计时器对象。默认值为 True。
create_recorder (bool, optional): 是否创建记录器对象。默认值为 True。
Returns:
tuple: 包含创建的 Epoch、Timer 和 Recorder 对象的元组,不包含 None 值。
"""
epoch = self.create_epoch() if create_epoch else None
timer = self.create_timer() if create_timer else None
recorder = self.create_recorder(3) if create_recorder else None
return (item for item in (epoch, timer, recorder) if item is not None)
[文档]
def save(self, dir_name: str = None):
"""
保存数据。
Args:
dir_name (str, optional): 数据保存的目录名。默认值为 None。
"""
dir_path = f"../results/{dir_name}" if dir_name else self.dir_path
self.data_manager.save(dir_path)
[文档]
def load(self, dir_name: str = None):
"""
加载数据。
Args:
dir_name (str, optional): 数据加载的目录名。默认值为 None。
"""
dir_path = f"../results/{dir_name}" if dir_name else self.dir_path
self.data_manager.load(dir_path)
[文档]
def create_epoch(self, label: str = "num_epochs") -> Epoch:
"""
创建 Epoch 参数。
Args:
label (str, optional): Epoch 的标签,建议和被赋值变量名相同。默认值为 'num_epochs'。
Returns:
Epoch: 创建的 Epoch 对象。
"""
epoch = Epoch(self)
def save(dir_path):
epoch.save(f"{dir_path}/{self.file_name}.json", label)
self.logger.debug(f"save Epoch({label}) to {dir_path}/{self.file_name}.json")
self.data_manager.add_save_func(save)
def load(dir_path):
epoch.load(f"{dir_path}/{self.file_name}.json", label)
self.logger.debug(f"load Epoch({label}) from {dir_path}/{self.file_name}.json")
self.data_manager.add_load_func(load)
self.logger.debug(f"create Epoch({label})")
return epoch
[文档]
def create_timer(self, label: str = "timer") -> utils.Timer:
"""
创建计时器。
Args:
label (str, optional): 计时器的标签,建议和被赋值变量名相同。默认值为 'timer'。
Returns:
Timer: 创建的计时器对象。
"""
timer = utils.Timer()
def save(dir_path):
timer.save(f"{dir_path}/{self.file_name}.json", label)
self.logger.debug(f"save Timer({label}) to {dir_path}/{self.file_name}.json")
self.data_manager.add_save_func(save)
def load(dir_path):
timer.load(f"{dir_path}/{self.file_name}.json", label)
self.logger.debug(f"load Timer({label}) from {dir_path}/{self.file_name}.json")
self.data_manager.add_load_func(load)
self.logger.debug(f"create Timer({label})")
return timer
[文档]
def create_recorder(self, recorder_num: int, label: str = "recorder") -> utils.Recorder:
"""
创建记录器。
Args:
recorder_num (int): 记录器的数量。
label (str, optional): 记录器的标签,建议和被赋值变量名相同。默认值为 'recorder'。
Returns:
Recorder: 创建的记录器对象。
"""
recorder = utils.Recorder(recorder_num)
def save(dir_path):
recorder.save(f"{dir_path}/{self.file_name}.json", label)
self.logger.debug(f"save Recorder({label}) to {dir_path}/{self.file_name}.json")
self.data_manager.add_save_func(save)
def load(dir_path):
recorder.load(f"{dir_path}/{self.file_name}.json", label)
self.logger.debug(f"load Recorder({label}) from {dir_path}/{self.file_name}.json")
self.data_manager.add_load_func(load)
self.logger.debug(f"create Recorder({label})")
return recorder
[文档]
def create_animator(
self,
xlabel: str = None,
ylabel: str = None,
xlim: tuple = None,
ylim: tuple = None,
legend: list = None,
fmts: list = None,
label: str = "animator",
) -> draw.Animator:
"""
创建动画器。
Args:
xlabel (str, optional): x 轴标签。默认值为 None。
ylabel (str, optional): y 轴标签。默认值为 None。
xlim (tuple, optional): x 轴范围。默认值为 None。
ylim (tuple, optional): y 轴范围。默认值为 None。
legend (list, optional): 图例。默认值为 None。
fmts (list, optional): 格式。默认值为 None。
label (str, optional): 动画器的标签,建议和被赋值变量名相同。默认值为 'animator'。
Returns:
Animator: 创建的动画器对象。
"""
animator = draw.Animator(xlabel, ylabel, xlim, ylim, legend, fmts)
def save(dir_path):
animator.save(f"{dir_path}/{self.file_name}.png")
self.logger.debug(f"save Animator({label}) to {dir_path}/{self.file_name}.png")
self.data_manager.add_save_func(save)
self.logger.debug(f"create Animator({label})")
return animator
[文档]
def add_model(self, model: nn.Module, label: str = "model"):
"""
添加模型。
Args:
model: 模型。
label (str, optional): 模型的标签,建议和模型变量名相同。默认值为 'model'。
Raises:
RuntimeError: 如果模型不是 nn.Module 类型。
"""
if not isinstance(model, nn.Module):
raise RuntimeError(f"model({label}) must be a nn.Module")
def save(dir_path):
torch.save(model.state_dict(), f"{dir_path}/{self.file_name}.pt")
self.logger.debug(f"save model({label}) to {dir_path}/{self.file_name}.pt")
self.data_manager.add_save_func(save)
def load(dir_path):
model.load_state_dict(torch.load(f"{dir_path}/{self.file_name}.pt"))
self.logger.debug(f"load model({label}) from {dir_path}/{self.file_name}.pt")
self.data_manager.add_load_func(load)
self.logger.debug(f"add model({label})")
self.logger.debug(f"model({label}) is {model}")
[文档]
def print_training_time_massage(self, timer: utils.Timer, num_epochs: int, current_epoch: int):
"""
打印模型训练时间相关信息,包括已训练时长、平均训练时长和预估剩余训练时长。
Args:
timer (Timer): 计时器对象,用于获取训练时间数据。
num_epochs (int): 总训练轮数。
current_epoch (int): 当前训练到的轮数。
"""
# 计算已训练的总时长,并转换为 HH:MM:SS 格式
trained_duration = utils.Timer.str(timer.sum())
# 计算每轮的平均训练时长,并转换为 HH:MM:SS 格式
average_duration = utils.Timer.str(timer.avg())
# 计算预估的剩余训练时长,并转换为 HH:MM:SS 格式
estimated_duration = utils.Timer.str((num_epochs - current_epoch) * timer.avg())
# 打印训练时间相关信息
self.logger.info(
f"Trained duration: {trained_duration}, Average training duration: {average_duration}, Estimated training duration:{estimated_duration}"
)
[文档]
def model_params(self, model: nn.Module, label: str = "model"):
"""
打印模型参数数量。
Args:
model: 模型对象。
Raises:
RuntimeError: 如果模型不是 nn.Module 类型。
"""
if not isinstance(model, nn.Module):
raise RuntimeError(f"model({label}) must be a nn.Module")
# 统计模型参数数量
num_params = sum([param.numel() for param in model.parameters()])
# 打印模型参数数量
self.logger.info(f"Number of model({label}) parameters: {num_params / (1000 * 1000):.2f}M")