import torch
from collections import Counter
from mltools import utils
[文档]
class Tokenizer:
"""
分词器,将文本数据转换为词元索引,支持词元与索引之间的相互转换,并提供保存和加载词表的功能。
"""
def __init__(self, datas: list[str], min_freq: int = 0):
"""
初始化分词器。
Args:
datas (list[str]): 数据集,包含文本数据。
min_freq (int, optional): 最小词频,低于该频率的词元将被过滤。默认值为 0。
"""
tokens = Counter() # 将文本拆分为词元并统计频率
for item in datas:
tokens.update(str(item))
self.unk = 0 # 未知词元索引为0
self.cls = 1 # 分类词元索引为1
self.sep = 2 # 分隔词元索引为2
self.pad = 3 # 填充词元索引为3
tokens = [item[0] for item in tokens.items() if item[1] > min_freq] # 删除低频词元
self.idx_to_token = ["[UNK]", "[CLS]", "[SEP]", "[PAD]"] + tokens # 建立词元列表
# 建立词元字典
tokens_dict = {value: index + 4 for index, value in enumerate(tokens)}
self.token_to_idx = {"[UNK]": 0, "[CLS]": 1, "[SEP]": 2, "[PAD]": 3}
self.token_to_idx.update(tokens_dict)
def __call__(self, tokens: str | list[str] | tuple[str], max_length: int = None) -> torch.Tensor:
"""
调用分词器,将词元转换为索引。
Args:
tokens (str 或 list[str] 或 tuple[str]): 输入的词元。
max_length (int, optional): 最大长度,用于填充或截断。默认值为 None。
Returns:
torch.Tensor: 转换后的词元索引。
"""
return self.encode(tokens, max_length)
def __len__(self) -> int:
"""
返回词表大小。
Returns:
int: 词表的长度。
"""
return len(self.idx_to_token)
[文档]
def decode(self, indices: torch.Tensor) -> str | list[str]:
"""
根据索引返回词元。
Args:
indices (torch.Tensor): 输入的词元索引。
Returns:
str 或 list[str]: 解码后的词元。
Raises:
TypeError: 如果输入的 indices 不是 torch.Tensor 类型。
"""
if isinstance(indices, torch.Tensor):
if indices.dim() == 0:
return []
elif indices.dim() == 1:
return "".join([self.idx_to_token[index] for index in indices.tolist()])
elif indices.dim() == 2:
return ["".join([self.idx_to_token[item] for item in index]) for index in indices.tolist()]
else:
raise TypeError("indices 必须是 torch.Tensor 类型")
[文档]
def encode(self, texts: str | list[str] | tuple[str], max_length: int = None) -> torch.Tensor:
"""
根据词元返回索引。
Args:
texts (str 或 list[str] 或 tuple[str]): 输入的词元。
max_length (int, optional): 最大长度,用于填充或截断。默认值为 None。
Returns:
torch.Tensor: 转换后的词元索引。
Raises:
TypeError: 如果输入的 texts 不是 str、list[str] 或 tuple[str] 类型。
"""
if isinstance(texts, str):
if max_length:
texts = (
list(texts)[:max_length]
if len(texts) > max_length
else list(texts) + ["[PAD]"] * (max_length - len(texts))
)
return torch.tensor([self.token_to_idx.get(token, self.unk) for token in texts])
elif isinstance(texts, (list, tuple)):
if not max_length:
max_length = max([len(text) for text in texts])
return torch.stack([self.encode(text, max_length) for text in texts])
else:
raise TypeError(
f"texts: {texts}\nThe type of texts is {type(texts)}, while texts must be of type str, tuple[str] or list[str]"
)
[文档]
def save(self, path: str, label: str = "tokenizer"):
"""
保存分词器的词表到 JSON 文件。
Args:
path (str): JSON 文件的保存路径。
label (str, optional): 数据在 JSON 文件中的键名。默认值为 'tokenizer'。
"""
utils.DataSaveToJson.save_data(path, label, [self.idx_to_token, self.token_to_idx])
[文档]
def load(self, path: str, label: str = "tokenizer"):
"""
从 JSON 文件中加载分词器的词表。
Args:
path (str): JSON 文件的路径。
label (str, optional): 数据在 JSON 文件中的键名。默认值为 'tokenizer'。
"""
self.idx_to_token, self.token_to_idx = utils.DataSaveToJson.load_data(path, label)