欺诈文本分类检测(十七):支持分类原因训练

1. 引言

前文数据校正与增强进行了数据增强,本文将使用增强后的数据对模型进行进一步训练,以便得到能同时预测出分类标签、欺诈者、分类原因多个信息的模型。

为此,我们需要对整个训练过程进行调整,包括:

  1. 交叉训练逻辑封装
  2. 数据序列化的改造
  3. 评测方法改造

2. 交叉训练封装

首先,我们将前文 交叉训练验证的代码封装为一个脚本trainer_cross.py,方便复用。内容如下:

import glob
import gc
import numpy as np
from datasets import Dataset, concatenate_datasets
from sklearn.model_selection import KFold
from trainer import *

def find_last_checkpoint(output_dir):
    checkpoint_dirs = glob.glob(os.path.join(output_dir, 'checkpoint-*'))
    last_checkpoint_dir = max(checkpoint_dirs, key=os.path.getctime)
    return last_checkpoint_dir

def load_model_v2(model_path, checkpoint_path='', device='cuda'):
    # 加载模型
    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, trust_remote_code=True).to(device)
    # 加载lora权重
    if checkpoint_path: 
        model = PeftModel.from_pretrained(model, model_id=checkpoint_path).to(device)
    # 将基础模型的参数设置为不可训练
    for param in model.base_model.parameters():
        param.requires_grad = False
    
    # 将 LoRA 插入模块的参数设置为可训练
    for name, param in model.named_parameters():
        if 'lora' in name:
            param.requires_grad = True
    return model

def build_trainer_v2(model, tokenizer, train_args, train_dataset, eval_dataset):
    # 开启梯度检查点时,要执行该方法
    if train_args.gradient_checkpointing:
        model.enable_input_require_grads()
    return Trainer(
        model=model,
        args=train_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
        callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],  # 早停回调
    )

def train_kfold(model_path, output_base_path, datasets, build_args_func, fold_num=5, device='cuda', last_checkpoint_path=''):
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    kf = KFold(n_splits=fold_num, shuffle=True)
    results = []
    
    for fold, (train_index, val_index) in enumerate(kf.split(np.arange(len(datasets)))):
        print(f"fold={fold} start, train_index={train_index}, val_index={val_index}")
        train_dataset = datasets.select(train_index)
        eval_dataset = datasets.select(val_index)
        print(f"train data: {len(train_dataset)}, eval: {len(eval_dataset)}")
    
        output_path = f'{output_base_path}_{fold}'
        train_args, lora_config = build_args_func(output_path)
        if last_checkpoint_path:
            model = load_model_v2(model_path, last_checkpoint_path, device)
            print(f"fold={fold}, load model from checkpoint: {last_checkpoint_path}")
        else:
            model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16).to(device)
            model = get_peft_model(model, lora_config)
    
        model.print_trainable_parameters()
        trainer = build_trainer_v2(model, tokenizer, train_args, train_dataset, eval_dataset)
        train_result = trainer.train()
        print(f"fold={fold}, result = {train_result}")
        results.append(train_result)
        
        last_checkpoint_path = find_last_checkpoint(output_path)
        
    return results

其中,各个方法的作用释义如下:

  • find_last_checkpoint:用于从一个目录下查找最新的checkpoint。
  • load_model_v2:加载模型和微调的checkpoint,并将lora权重设置为可训练,非lora权重设置为不可训练。
  • build_trainer_v2:构造训练器
  • train_kfold:封装K折交叉训练验证的主循环逻辑,循环的每个批次为不同的数据集

train_kfold是此脚本最终对外公开的方法,它开放了如下参数以便灵活调整训练过程:

  • model_path:基座模型路径;
  • output_base_path:输出模型的基础路径,K折交叉训练会以此路径为基础,来构造每一折数据的输出路径;
  • datasets:经过预处理后的数据集;
  • build_args_func:构造训练参数的方法,根据output_path来构造训练参数和Lora参数;
  • fold_num: 数据集要分割的折数;
  • device: 训练的GPU设备;
  • last_checkpoint_path: 最近一次训练的checkpoint路径,当接着上一次的训练结果继续训练时传此参数。

3. 数据加载改造

当输出数据改变后,模型的预期输出不再仅仅是一个分类标签,还需要包括欺诈者和分类原因。因此,我们加载数据和数据序列化的方式需要作相应调整。

改造数据预处理函数,扩展with_reason参数,参数值定义:

  • true:表示预期结果除了is_fraud字段外,还包含fraud_speaker和reason字段。
  • false:表示预期结果不包含fraud_speaker和reason字段。

代码如下(有变化的仅仅是if with_reason的判断分支)。

def preprocess(item, tokenizer, with_reason=False, max_length=2048):
    system_message = "You are a helpful assistant."
    user_message = item['instruction'] + '\n' + item['input']
    if with_reason: 
        output = {"is_fraud":item["label"], "fraud_speaker":item["fraud_speaker"], "reason":item["reason"]}
    else:
        output = {"is_fraud":item["label"]}
        
    assistant_message = json.dumps(output, ensure_ascii=False)
    input_ids, attention_mask, labels = [], [], []
    instruction = tokenizer(f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{user_message}<|im_end|>\n<|im_start|>assistant\n", add_special_tokens=False)  
    response = tokenizer(assistant_message, add_special_tokens=False)
    input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.pad_token_id]
    attention_mask = instruction["attention_mask"] + response["attention_mask"] + [1]  
    # -100是一个特殊的标记,用于指示指令部分的token不应参与损失计算
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.pad_token_id]  
    
    # 对输入长度做一个限制保护,超出截断
    return {
        "input_ids": input_ids[:max_length],
        "attention_mask": attention_mask[:max_length],
        "labels": labels[:max_length]
    }

相应对外的load_dataset方法也扩展with_reason参数,目的兼容之前的单独分类标签训练,支持带原因和不带原因两种加载数据的模式。

def load_one_dataset(data_path, tokenizer, with_reason:bool):
    df = load_jsonl(data_path)
    ds = Dataset.from_pandas(df)
    return ds.map(
        lambda x: preprocess(x, tokenizer, with_reason=with_reason),
        remove_columns=ds.column_names)

def load_dataset(train_path, eval_path, tokenizer, with_reason=False):
    train_dataset = load_one_dataset(train_path, tokenizer, with_reason)
    eval_dataset = load_one_dataset(eval_path, tokenizer, with_reason)
    return train_dataset, eval_dataset

4. 开始训练

4.1 初始化

初始化改为引入新封装的脚本trainer_cross.py

%run trainer_cross.py

数据路径、模型路径、设备定义基本和之前保持一致。

traindata_path = '/data2/anti_fraud/dataset/train0902.jsonl'
evaldata_path = '/data2/anti_fraud/dataset/eval0902.jsonl'
model_path = '/data2/anti_fraud/models/modelscope/hub/Qwen/Qwen2-1___5B-Instruct'
output_base_path = '/data2/anti_fraud/models/Qwen2-1___5B-Instruct_ft_0913'
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
device = 'cuda'

加载数据集,使用concatenate_datasets方法将训练集和验证集合并为一个数据集。

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
train_dataset, eval_dataset = load_dataset(traindata_path, evaldata_path, tokenizer, with_reason=True, lazy=False)
datasets = concatenate_datasets([train_dataset, eval_dataset])

在这里插入图片描述

4.2 训练

定义参数构造的方法,用于构造训练参数和Lora参数,具体参数值保持与之前相同。

def build_arguments(output_path):
    train_args = build_train_arguments(output_path)
    train_args.eval_strategy='epoch'
    train_args.save_strategy='epoch'
    train_args.num_train_epochs = 2
    train_args.per_device_train_batch_size = 8
    
    lora_config = build_loraconfig()
    lora_config.lora_dropout = 0.2  
    lora_config.r = 16
    lora_config.lora_alpha = 32
    return train_args, lora_config

调用train_kfold方法开始训练:

results = train_kfold(model_path, output_base_path, datasets, build_args_func=build_arguments, fold_num=5, last_checkpoint_path=last_checkpoint_path)

总共进行了5折数据10轮训练,每折数据进行了两轮训练,相应的训练损失和验证损失数据如下:

EpochTraining LossValidation Loss
10.7801000.825167
20.6967000.813522
30.7854000.738886
40.6662000.731676
50.6794000.619393
60.5589000.610776
70.5821000.503672
80.4297000.490893
90.4833000.394778
100.3080000.372799
4.3 评测

根据验证损失数据,基于前文支持分类原因评测改造的脚本,采用微调效果最好的最后一轮checkpoint进行评测。

%run evaluate_v2.py
testdata_path = '/data2/anti_fraud/dataset/test0902.jsonl'
checkpoint_path = '/data2/anti_fraud/models/Qwen2-1___5B-Instruct_ft_0913_4/checkpoint-5454'
evaluate_v2(model_path, checkpoint_path, testdata_path, device, debug=True)

三个字段的评测指标分别如下:

字段指标
is_fraudprecision: 0.9422, recall: 0.9434, accuracy: 0.9419
fraud_speakeraccuracy: 0.9175
reasonprecision: 0.3596, recall: 0.3708, f1-score: 0.3571

经过训练后,三个字段的指标都有不同程度的提高,分别为:

  • is_fraud: precision从0.6232提升到0.9422,表明模型在欺诈文本分类任务上的精确率有明显提高,这能减少欺诈文本误报的次数;
  • fraud_speaker: accuracy从0.6327提升到0.9175,表明模型能有效的识别哪些说话者可能涉及欺诈;
  • reason: 召回率从0.2324提升到0.3708,f1-score从0.2638提升到0.3571

可以看到,is_fraud和fraud_speaker两个字段的准确率提升是比较明显的,而reason字段的召回率也有一定程度的提升,但分数没有那么高。

猜想原因可能在于训练不充分,因为从上面训练的损失数据中能看到一个现象:整个10轮训练下来,不论是训练损失还是验证损失都还在持续下降,这说明训练还未完成。

5. 再次训练

调整输出目录,并定义最近一次训练结构的checkpoint路径:

output_base_path = '/data2/anti_fraud/models/Qwen2-1___5B-Instruct_ft_0924'
last_checkpoint_path = '/data2/anti_fraud/models/Qwen2-1___5B-Instruct_ft_0913_4/checkpoint-5454'

将折子数调整为10, 其它都和上面相同,基于指定的checkpoint继续训练:

results = train_kfold(model_path, output_base_path, datasets, build_args_func=build_arguments, fold_num=10, last_checkpoint_path=last_checkpoint_path)

总共进行了10折数据20轮训练,每折数据进行了两轮训练,相应的训练损失和验证损失数据如下:

EpochTraining LossValidation Loss
10.4393000.365142
20.2041000.351331
30.3275000.268683
40.2024000.246474
50.2539000.192165
60.1332000.165728
70.16770.1145
80.15550.1024
90.14310.08527
100.13290.07053
110.12310.06223
120.11390.0571
130.10660.05086
140.10080.04274
150.094840.04154
160.0709000.038615
170.0697000.033552
180.0688000.029461
190.0593000.026729
200.0682000.023861

从训练结果来看,损失数据一直在不断的下降,这里用最后第20轮的checkpoint进行评测(代码省略),三个字段的评测指标分别如下:

字段指标
is_fraud0.9372, recall: 0.9414, accuracy: 0.9383
fraud_speakeraccuracy: 0.9152
reasonprecision: 0.3664, recall: 0.3705, f1-score: 0.3601

结果是有些失望的,虽然损失在不断下降,但各项评测指标基本都没有什么改善。

原因猜测可能有以下几个:

  1. 数据量不足,2-3万条训练数据相对于文本生成任务来说还是太少,可能还不足以明显改善生成文本的相似度。
  2. Lora参数矩阵(r=16)较小,不足以储存足够的信息特征。

补充:后续尝试过将Lora矩阵的秩r调到64,并训练了10折20轮,但测评结果依然与上面的结果相似,没有明显改善。

由于分类原因的相似度指标并不是我们必需的,这里受精力和数据量的限制,暂时没继续往下训练,但欺诈者字段的准确率已经达到预期,带欺诈者和分类原因的response格式已经能够正常生成,如下所示:

{"is_fraud": true, "fraud_speaker": "小灿", "reason": "小灿要求吴某某登录一个网站,并根据其指示进行国际黄金的操作,这种行为很可能是典型的投资诈骗手段。"}
{"is_fraud": true, "fraud_speaker": "李小龙", "reason": "李小龙要求支付一笔费用以帮助其亲戚获得释放,但没有提供任何具体的细节或证明其关系。使用 relative/convicted等模糊词汇来联系律师,并要求立即支付费用,这种行为具有明显的诈骗特征。"}
{"is_fraud": true, "fraud_speaker": "朱立", "reason": "朱立提供的投资方案中,明显存在通过吸引投资者大量投入资金来获取高额回报的行为。这种高回报承诺通常是不现实且具有欺骗性的,极有可能构成经济诈骗。具体表现为:- 98元、598元和998元的投资计划承诺回报远远超过正常市场回报率,这违反了普遍的投资原则和常识。\- 投资98元的方案回报5万到10万元,这样的回报率过高,容易让人怀疑其真实性。\- 同样,其他投资方案也承诺极高的回报,如598元投资回报50万元,998元投资回报200万元,这些回报率远超正常市场水平,极易诱导投资者上当。\n"}

小结:本文通过对数据加载的改造和交叉训练过程的封装,完成了一次针对带分类原因的欺诈文本分类任务的训练,并通过评测方法的改造实现了对不同类型字段的结果评测。从损失数据和评测结果来看,要改善生成文本的精确率和召回率,可能还需要更多更丰富的数据,后续腾出时间再研究。

参考阅读

  • Lora单卡二次调优
  • 交叉训练验证
  • 数据校正与增强
  • 支持分类原因评测改造

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/884321.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

3-1.Android Fragment 之创建 Fragment

Fragment Fragment 可以视为 Activity 的一个片段&#xff0c;它具有自己的生命周期和接收事件的能力&#xff0c;它有以下特点 Fragment 依赖于 Activity&#xff0c;不能独立存在&#xff0c;Fragment 的生命周期受 Activity 的生命周期影响 Fragment 将 Activity 的 UI 和…

信安 实验1 用Wireshark分析典型TCP/IP体系中的协议

我发现了有些人喜欢静静看博客不聊天呐&#xff0c; 但是ta会点赞。 这样的人呢帅气低调有内涵&#xff0c; 美丽大方很优雅。 说的就是你&#xff0c; 不用再怀疑哦 实验1 用Wireshark分析典型TCP/IP体系中的协议 实验目的 通过Wireshark软件分析典型网络协议数据包&a…

C++深入学习string类成员函数(3):访问与修饰

引言 在 C 中&#xff0c;std::string 提供了丰富的成员函数来访问和修改字符串中的字符。通过这些函数&#xff0c;程序员可以灵活地处理字符串中的各个元素&#xff0c;无论是读取特定位置的字符&#xff0c;还是修改字符串的内容。此外&#xff0c;std::string 类还确保了访…

FileZilla Server 黑白单移除

我使用FileZilla Server 搭建了一个FTP服务在内网使用&#xff0c;主要用于做数据备份的。 有一台服务器一直可以正常连接&#xff0c;突然有一天不能连接了。一开始我以为是FTP服务器出问题了&#xff0c;就一直没管。后来我测试了一下其他IP都可以正常连接FTP服务器&#xff…

STM32嵌入式编程学习到提高:【4】UART串口打印

------------------------------------------------------------------------------------------------------------------------- 工程文件&#xff1a;放在百度云盘里&#xff0c;需要的自行下载&#xff01;&#xff01;&#xff01; 链接: https://pan.baidu.com/s/14gRne…

Gartner 报告解读(二)| Open Telemetry可观测性解读与使用建议

上期跟大家解读了Gartner 成熟度曲线报告&#xff0c;主要分享了影响中国IT使用的4大因素--自主可控计划、AI发展趋势影响、降本增效、IT基础设施现代化程度。新来的朋友点这里&#xff0c;一键了解具体内容。 Gartner 成熟度曲线报告解读&#xff08;一&#xff09;| 2024中国…

Apifox 9月更新|「动态值」全新升级、跨团队引用接口和测试场景、测试报告交互优化

Apifox 新版本上线啦&#xff01;&#xff01;&#xff01; 看看本次版本更新主要涵盖的重点内容&#xff0c;有没有你所关注的功能特性&#xff1a; 「动态值」全新升级 更强大、更灵活的数据模拟能力 支持智能代码补全动态值 测试报告交互优化 支持跨团队引用接口和测试场…

Unity图形用户界面!*★,°*:.☆( ̄▽ ̄)/$:*.°★* 。(万字解析)

Unity 3D GUI 简介 游戏开发过程中&#xff0c;开发人员往往会通过制作大量的图形用户界面&#xff08; Graphical User Interface&#xff0c;GUI &#xff09;来增强游戏与玩家的交互性。 Unity 3D 中的图形系统分为 OnGUI、NGUI、UGUI等&#xff0c;这些类型的图形系统内容…

Django 数据库配置以及字段设置详解

配置PostGre 要在 Django 中配置连接 PostgreSQL 数据库&#xff0c;并创建一个包含“使用人”和“车牌号”等字段的 Car 表 1. 配置 PostgreSQL 数据库连接 首先&#xff0c;在 Django 项目的 settings.py 中配置 PostgreSQL 连接。 修改 settings.py 文件&#xff1a; …

数据定义语言CREATE的应用

新书速览|SQL Server 2022从入门到精通&#xff1a;视频教学超值版_sql server 2022 出版社-CSDN博客 《SQL Server 2022从入门到精通&#xff08;视频教学超值版&#xff09;&#xff08;数据库技术丛书&#xff09;》(王英英)【摘要 书评 试读】- 京东图书 (jd.com) SQL Se…

【Python】1.初始Python--打开Python的大门

&#x1f4da;博客主页&#xff1a;爱敲代码的小杨. ✨专栏&#xff1a;《Java SE语法》 | 《数据结构与算法》 | 《C生万物》 |《MySQL探索之旅》 |《Web世界探险家》 ❤️感谢大家点赞&#x1f44d;&#x1f3fb;收藏⭐评论✍&#x1f3fb;&#xff0c;您的三连就是我持续更…

Java Web —— 第十天(SpringBoot原理)

SpringBoot框架之所以使用起来更简单更快捷&#xff0c;是因为SpringBoot框架底层提供了两个非常重要的 功能&#xff1a;一个是起步依赖&#xff0c;一个是自动配置。 通过SpringBoot所提供的起步依赖&#xff0c;就可以大大的简化pom文件当中依赖的配置&#xff0c;从而解决…

游戏开发2025年最新版——八股文面试题(unity,虚幻,cocos都适用)

1.静态合批与动态合批的原理是什么&#xff1f;有什么限制条件&#xff1f;为什么&#xff1f;对CPU和GPU产生的影响分别是什么&#xff1f; 原理&#xff1a;Unity运行时可以将一些物体进行合并&#xff0c;从而用一个描绘调用来渲染他们&#xff0c;就是一个drawcall批次。 限…

微信占用空间太大,文件清理工具来了

今天分享几个安卓手机文件清理工具。 SD女佣 安卓经典系统清理利器&#xff0c;一键释放存储空间&#xff0c;能清理手机中的垃圾文件、临时文件和无用的应用程序数据&#xff0c;提升设备性能并节省存储空间&#xff0c;内置强大的文件浏览器&#xff0c;支持应用管理和系统…

LeetCode讲解篇之5. 最长回文子串

文章目录 题目描述题解思路题解代码 题目描述 题目链接 题解思路 从中心点先寻找和中心点相等的左右端点&#xff0c;在基于左右端点进行往外扩散&#xff0c;直至左右端点不相等或者越界&#xff0c;然后左右端点这个范围内就是我们找寻的回文串&#xff0c;我们遍历中心点…

AI 大模型浪潮下,大龄程序员怎样转型求变,攀登技术高峰?

前言 在信息技术迅猛发展的今天&#xff0c;程序员作为技术的创造者和实践者&#xff0c;正面临前所未有的挑战。技术的迭代速度日益加快&#xff0c;传统项目的生命周期不断缩短。同时&#xff0c;人工智能&#xff08;AI&#xff09;尤其是大模型技术的兴起&#xff0c;使得…

如何调整云桌面安装的虚拟机分辨率?

如何调整云桌面安装的虚拟机分辨率&#xff1f; 1. 编辑GRUB配置文件2. 修改分辨率3. 更新GRUB4. 重启虚拟机 &#x1f496;The Begin&#x1f496;点点关注&#xff0c;收藏不迷路&#x1f496; 在云桌面环境中&#xff0c;虚拟机分辨率过低且无法调整时&#xff0c;可以通过以…

影刀RPA实战:java结合影刀同步采购订单数据

1.实战目标 本次实战我们用java语言结合影刀&#xff0c;实现从自用ERP系统同步订单到旺店通中&#xff0c;在工作中&#xff0c;有时候我们的运营数据不是直接在旺店通ERP中操作&#xff0c;比如我们有自己的ERP&#xff0c;完成一些特定的内部工作后&#xff0c;再把数据同步…

[3]Opengl ES着色器

术语&#xff1a; VertexShader&#xff1a;顶点着色器&#xff0c;用来描述图形图像位置的顶点坐标&#xff1b; FragmentShader&#xff1a;片元着色器&#xff0c;用来给顶点指定的区域进行着色&#xff1b; Vertex&#xff1a;顶点 Texture&#xff1a;纹理…

云 安 全 (Cloud Security)

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 本人主要分享计算机核心技…