Tdeer

重构TDEER

代码部分架构


flowchart LR
    subgraph 数据处理
        A[数据预处理] --> B[数据清洗]
        B --> C[数据标注]
        C --> D[数据存储]
    end
    subgraph 模型训练
        E[模型训练] --> F[模型评估]
        F --> G[模型调优]
        G --> H[模型部署]
    end
    subgraph 模型部署
        I[模型部署] --> J[模型监控]
graph TD
    A[Start] --> B{Is it sunny?}
    B -->|Yes| C[Go to the park]
    B -->|No| D[Stay home]

解析run代码

导入必要的库和模块:

os 用于操作系统功能,如文件路径操作。 argparse 用于解析命令行参数。 tokenizers.BertWordPieceTokenizer 用于处理文本数据,使其适用于 BERT 模型。 从其他 Python 文件导入的模块,如 DataGenerator, load_data, load_rel, build_model, Evaluator, Infer, compute_metrics,这些文件负责数据加载、模型构建、评估等功能。 设置命令行参数解析:

这部分代码定义了多个命令行参数,用于指定模型的训练和测试配置,如模型名、数据路径、学习率、批次大小、训练轮数等。 初始化模型和工具:

加载关系数据。 初始化 BERT 分词器,并设定最大序列长度。 构建模型,包括实体模型、关系模型、翻译模型和训练模型。 训练模型:

如果指定了 –do_train,程序将进入训练模式。 加载训练数据和验证数据,并进行相应的配置。 初始化数据生成器和评估器。 使用 train_model.fit() 方法训练模型,并在每个 epoch 结束时调用评估器。 测试模型:

如果指定了 –do_test,程序将进入测试模式。 加载测试数据。 加载训练好的模型权重。 计算测试数据上的精度、召回率和 F1 分数,并打印结果。 这段代码是一个完整的机器学习流程,包括参数解析、模型训练、模型评估和测试。它允许用户通过命令行接口灵活地配置训练和测试过程,适用于实验和研究不同配置对模型性能的影响。

dataloader

这段代码是一个Python脚本,主要用于加载和处理关系抽取的数据,并创建一个数据生成器用于神经网络训练。以下是代码各部分的详细解析:

导入模块

  • json: 用于处理JSON数据。
  • typing: 提供类型标注支持。
  • collections.defaultdict: 提供默认字典功能。
  • numpykeras.preprocessing.sequence.pad_sequences: 用于数学运算和数据预处理。
  • log: 可能是自定义的日志记录模块。

函数定义

  1. find_entity:
    • 功能:在给定的源列表中找到目标列表的起始索引。
    • 参数:source (整数列表), target (整数列表)。
    • 返回:目标列表在源列表中的起始索引,如果未找到,则返回-1。
  2. to_tuple:
    • 功能:将输入句子中的三元组列表转换成元组形式。
    • 参数:sent (包含三元组列表的字典)。
    • 修改是就地进行的,没有返回值。
  3. filter_data:
    • 功能:从指定文件路径加载数据,过滤掉不在关系ID映射中的三元组。
    • 参数:fpath (文件路径), rel2id (关系到ID的映射字典)。
    • 返回:过滤后的数据列表。
  4. load_rel:
    • 功能:从指定路径加载关系映射数据。
    • 参数:rel_path (文件路径)。
    • 返回:ID到关系的映射,关系到ID的映射,所有关系列表,关系的数量。
  5. load_data:
    • 功能:从指定路径加载数据,并进行进一步处理。
    • 参数:fpath (文件路径), rel2id (关系到ID的映射字典), is_train (是否为训练数据)。
    • 返回:处理后的数据列表。

类定义:DataGenerator

  • 构造函数:
    • 初始化数据生成器实例。
    • 参数:datas (数据列表), tokenizer (分词器对象), rel2id, all_rels, max_len (最大长度), batch_size (批次大小), max_sample_triples (最大采样三元组数量), neg_samples (负样本数量)。
  • __len__:
    • 返回生成器的步骤数。
  • __iter__:
    • 数据生成器,用于在训练中按批次生成数据。
    • 参数:random (是否随机化顺序)。
  • forfit:
    • 无限循环地调用数据迭代器,常用于模型训练。

数据处理和模型训练准备

这个脚本主要处理输入数据,准备用于关系抽取任务的数据集。它包括了从文本中识别实体和关系,将文本转换为模型可以处理的格式,并生成训练所需的正负样本。整个处理过程强调了在实体识别和关系抽取中的一些常见步骤,如编码、寻找实体位置、构建实体和关系的表示等。

model

这段代码定义了一个使用TensorFlow和Keras库构建的复杂关系抽取模型,利用BERT预训练模型进行特征提取,并设置了不同部分的网络结构用于实体和关系识别。以下是代码的主要组成部分和功能详解:

导入库

  • tensorflowkeras: 用于建立和训练深度学习模型。
  • langml.plm.bert: 加载预训练的BERT模型。
  • langml.layers: 用于特殊的层操作,如自注意力机制。
  • langml.tensor_typing: 用于类型标注。
  • utils.compute_metrics: 用于计算模型的精确度、召回率和F1分数。

build_model 函数

定义了一个模型构建函数,它包含以下几个步骤和特性:

  1. 载入BERT模型:从指定目录加载BERT的配置和权重。
  2. 输入层定义
    • 实体头尾的输入。
    • 关系的输入。
    • 样本的主题和客体头尾的索引。
  3. BERT模型输出处理
    • 使用BERT模型输出特征。
    • 定义实体头尾的预测。
    • 定义关系预测。
    • 使用自注意力机制处理对象特征。
  4. 损失函数定义
    • 实体头尾的二元交叉熵损失。
    • 关系的均方误差损失。
    • 客体头部的二元交叉熵损失。
    • 综合所有损失项形成最终的损失函数。
  5. 模型编译
    • 使用Adam优化器。
    • 设置学习率。

Evaluator

  • 这是一个回调类,用于在每个训练周期结束时评估模型,并在有更好的F1分数时保存模型权重。
  • on_train_begin: 在训练开始时设置最优F1分数为负无穷大。
  • on_batch_begin: 动态调整学习率,首先增加到最大值,然后减少到最小值。
  • on_epoch_end: 在每个周期结束时计算模型的精确度、召回率和F1分数,并在有提升时更新最佳模型。

模型的主要功能

这个脚本主要用于建立一个用于实体识别和关系抽取的深度学习模型。模型利用BERT的强大语言理解能力,通过自注意力机制进一步提取和处理文本特征,优化关系和实体的识别。通过精细的损失函数和动态学习率调整,模型能够更好地适应训练数据,提高识别的准确性。

utils

这段代码定义了一个用于关系抽取任务的推理系统,包括模型预测和性能评估的相关操作。以下是代码的主要组成部分和功能详解:

导入库

  • json: 处理JSON数据。
  • time: 获取时间数据,用于性能监控。
  • numpy: 数值计算。
  • tqdm: 进度条显示。
  • langml.tensor_typing.Models: 类型标注用。

功能函数

  1. rematch:
    • 功能:根据文本与token的对应关系,生成从token到文本的字符索引映射。
    • 参数:offsets,即单词或token在原始文本中的起止位置列表。
    • 返回:每个token对应的原文字符索引列表。

类定义:Infer

  • 构造函数:
    • 初始化推理系统,包括加载模型和分词器等。
  • decode_entity:
    • 从文本中解码实体,根据头尾索引返回实体字符串。
  • __call__:
    • 类的主要逻辑,用于生成文本的实体和关系三元组。包括实体识别和关系预测,以及客体头部预测。

辅助处理函数

  • partial_match:
    • 用于处理预测集和黄金标准集,确保它们的比较忽略某些不重要的字符,如空格。
  • remove_space:
    • 去除三元组中的所有空格,用于精确匹配场景。

性能评估函数:compute_metrics

  • 这个函数负责计算推理模型的性能指标,包括精确度、召回率和F1分数。
  • 通过遍历开发集数据,使用Infer类生成预测三元组,并与黄金标准进行比较。
  • 根据是否完全匹配,可能调用remove_spacepartial_match对数据进行预处理。
  • 将预测结果、缺失和错误的三元组写入输出文件,用于进一步分析。

总结

这段代码实现了一个完整的关系抽取推理和评估流程,包括文本的实体和关系抽取、三元组的生成、以及模型效果的评估。通过类和函数的细致设计,代码既考虑了模型的实际应用场景,也便于通过修改参数来适应不同的需求和数据集。