资讯 社区 文档
技术能力
语音技术
文字识别
人脸与人体
图像技术
语言与知识
视频技术

模型训练Trainer使用说明

功能介绍

千帆ModelBuilder Python SDK支持调用Trainer相关API,支持对数据集进行自定义训练。本文使用千帆ModelBuilder SFT语言大模型为例介绍。

注意事项

  • 调用本文API,需使用安全认证AK/SK鉴权,调用流程及鉴权介绍详见SDK安装及使用流程
  • 本文涉及以下函数列表

    • 加载数据集Dataset.load()
    • 创建Trainer LLMFinetune()
    • 配置自定义训练参数TrainConfig()
    • 查询训练参数ModelInfoMapping()
    • 查询训练参数默认值DefaultTrainConfigMapping()
    • 启动训练任务run()
    • 重启训练任务resume()
    • 日志打印enable_log()

调用流程简介

(1)打印日志

注意:如果无需打印日志,可跳过此步骤。

如果需打印过程日志,通过调用enable_log()实现。

(2)准备并加载数据集

注意:数据集要求,必须是有标注的非排序对话数据集

加载千帆ModelBuilder上的数据集,通过调用Dataset.load()实现,详见参数说明。

(3)创建Trainer

调用LLMFInetune()创建Trainer对象,此步骤也会初步校验TrainConfig参数,如果有不符合的字段会打印warning日志。

注意:如果需要自定义训练参数,可通过调用TrainConfig()实现,详见参数说明。

(4)启动训练任务

通过调用run()实现。

(5)重启训练任务

注意:如果无需重启任务,可跳过此步骤。

如果突发断电或者任务停止,可以使用resume()重启任务。

调用示例

未自定义训练参数

import os 
import qianfan

# 使用安全认证AK/SK鉴权,通过环境变量方式初始化;替换下列示例中参数,安全认证Access Key替换your_iam_ak,Secret Key替换your_iam_sk
os.environ["QIANFAN_ACCESS_KEY"] = "your_iam_ak"
os.environ["QIANFAN_SECRET_KEY"] = "your_iam_sk"


# 如果希望打印过程日志,通过调用enable_log(logging.INFO)启用打印日志功能
#from qianfan.utils import enable_log
#import logging
#enable_log(logging.INFO)  # 设置打印日志的最低级别

from qianfan.dataset import Dataset
from qianfan.trainer import LLMFinetune

# 加载千帆ModelBuilder数据集,is_download_to_local=False表示不下载数据集到本地,而是直接使用
ds: Dataset = Dataset.load(qianfan_dataset_id="your_dataset_id", is_download_to_local=False)

# 新建trainer LLMFinetune,需最少传入train_type和dataset
# 注意fine-tune任务需要指定的数据集类型要求为有标注的非排序对话数据集。
trainer = LLMFinetune(
    train_type="ERNIE-xx",
    dataset=ds, 
)

trainer.run()

# 如果突发断电或者任务停止,可以使用resume函数重启任务
# trainer.resume()

自定义训练参数

import os 
import qianfan

# 使用安全认证AK/SK鉴权,通过环境变量方式初始化;替换下列示例中参数,安全认证Access Key替换your_iam_ak,Secret Key替换your_iam_sk
os.environ["QIANFAN_ACCESS_KEY"] = "your_iam_ak"
os.environ["QIANFAN_SECRET_KEY"] = "your_iam_sk"

# 如果希望打印过程日志,通过调用enable_log(logging.INFO)启用打印日志功能
#from qianfan.utils import enable_log
#import logging
#enable_log(logging.INFO)  # 设置打印日志的最低级别

from qianfan.dataset import Dataset
from qianfan.trainer import LLMFinetune
from qianfan.trainer.configs import TrainConfig

# 加载千帆ModelBuilder的数据集。
# qianfan_dataset_id是数据集id,类型要求为有标注的非排序对话数据集is_download_to_local=False表示不下载数据集到本地,直接使用
ds: Dataset = Dataset.load(qianfan_dataset_id="your_dataset_id", is_download_to_local=False)

# 发起训练任务。以基础模型ERNIE-xx为例,需要指定的数据集类型要求为有标注的非排序对话数据集。
trainer = LLMFinetune(
    train_type="ERNIE-xx",
    dataset=ds,
    peft_type="LoRA",
    # 自定义训练参数
    train_config=TrainConfig(
        epochs=1, # 迭代轮次(Epoch),控制训练过程中的迭代轮数。
        # batch_size=32, # 批处理大小(BatchSize)表示在每次训练迭代中使用的样本数。较大的批处理大小可以加速训练.部分模型可能无需填写该字段
        learning_rate=0.00004, # 学习率(LearningRate)是在梯度下降的过程中更新权重时的超参数,过高会导致模型难以收敛,过低则会导致模型收敛速度过慢,
    )
)

trainer.run()

# 如果突发断电或者任务停止,可以使用resume函数重启任务
# trainer.resume()

函数列表

模型训练Trainer需使用的部分函数如下。

  • 加载数据集
  • 创建Trainer
  • 配置自定义训练参数
  • 查询训练参数
  • 查询训练参数默认值

加载数据集

加载千帆ModelBuilder的数据集。

示例

Dataset.load(qianfan_dataset_id="your_dataset_id", is_download_to_local=False)

请求参数

名称 类型 必填 描述
qianfan_dataset_id string 要导入的数据集版本ID,说明:
(1)可以通过以下任一方式获取该字段值:
· 方式一,通过调用创建数据集接口,返回的datasetId字段获取。
· 方式二,在千帆ModelBuilder控制台-数据集管理列表页面,点击详情,在版本信息页查看,如下图所示: img
(2)数据集类型,要求为有标注的非排序对话数据集
is_download_to_local bool 是否下载数据到本地。
True:下载数据集到本地
False:不下载数据集到本地,而是直接使用

创建Trainer

调用LLMFInetune()创建Trainer对象,此步骤也会初步校验TrainConfig参数,如果有不符合的字段会打印warning日志。

示例

LLMFinetune(
    train_type="ERNIE-xx" 
)

请求参数

参数名 数据类型 必填 描述
train_type String 模型版本,示例:ERNIE-Lite-8K-0922,可以通过以下方法获取具体值:
千帆ModelBuilder控制台-模型调优-SFT页面-点击创建训练任务,选择基础模型,查看模型版本,如下图所示:image.png
dataset Optional[Any] 一个数据集实例。说明:
数据集dataset和此参数,至少填写一个
train_config Union[TrainConfig, string] 用于微调训练参数的TrainConfig。说明:
如果不填写此参数,将使用不同模型的默认参数。
deploy_config DeployConfig 用于模型服务部署参数的DeployConfig。说明:
如果需要部署服务,此参数必填。
event_handler EventHandler 用于接收训练过程中事件处理的EventHandler实例。
base_model String 基础模型,示例:ChatGLM2
eval_dataset Optional[Any] 可选的评价数据集
evaluators List[Evaluator] 用于评估的评估器列表
dataset_bos_path String 训练用的 bos 路径,说明:
数据集dataset和此参数,至少填写一个

配置自定义训练参数

示例

TrainConfig(
        epochs=1, 
        batch_size=32, 
        learning_rate=0.00004,
    )

请求参数

模型不同,训练配置使用的参数不同。可以通过以下任一方式查询请求参数:

  • 方式一:通过提供的请求参数列表
  • 方式二:通过调用查询训练参数ModelInfoMapping(),获取参数列表
  • 方式一:请求参数列表

说明:下列表格中的模型支持情况,请参考模型支持情况

名称 类型 必填 描述
epoch int 迭代轮次,说明:该字段取值详情参考模型支持情况
learningRate float 学习率,说明:说明:该字段取值详情参考模型支持情况
batchSize int 批处理大小,说明:该字段取值更多详情参考模型支持情况
maxSeqLen int 序列长度,说明:该字段取值详情参考模型支持情况
loggingSteps int 保存日志间隔,说明:
(1)当为以下情况,该字段必填
· model为ERNIE-Speed-8K,且trainMode为SFT
· model为ERNIE-Lite-8K-0922,且trainMode为SFT
· model为ERNIE-Lite-8K-0308,且trainMode为SFT
· model为ERNIE-Tiny-8K,且trainMode为SFT
(2)取值范围[1, 100],默认值为1
warmupRatio float 预热比例,说明:该字段取值详情参考模型支持情况
weightDecay float 正则化系数,说明:该字段取值详情参考模型支持情况
loraRank int LoRA 策略中的秩,说明:该字段取值详情参考模型支持情况
loraAlpha int 说明:说明:该字段取值更多详情参考模型支持情况
loraAllLinear string LoRA 所有线性层,说明:该字段取值详情参考模型支持情况
loraTargetModules string[] 说明:该字段取值详情参考模型支持情况
loraDropout float 说明:该字段取值更多详情参考模型支持情况
schedulerName string 说明:该字段取值详情参考模型支持情况
Packing bool 可选值:true 或 false,默认值false,说明:该字段取值详情参考模型支持情况
extras Dict[str, Any] {} 其他参数字典,保留值
  • 方式二:通过接口查询参数

详见查询训练参数介绍

查询训练参数

请求示例

from qianfan.trainer.configs import ModelInfoMapping

print(ModelInfoMapping['ERNIE-xx'])

返回示例

short_name='xxx'
base_model_type='ERNIE-Lite-8K-0922'
support_peft_types=[<PeftType.ALL: 'ALL'>, <PeftType.LoRA: 'LoRA'>]
common_params_limit=TrainLimit(
    batch_size_limit=(1, 4),
    max_seq_len_options=[4096, 8192], epoch_limit=(1, 50),
    learning_rate_limit=(2e-07, 0.0002),
    log_steps_limit=None,
    warmup_ratio_limit=None,
    weight_decay_limit=None,
    lora_rank_options=None,
    lora_alpha_options=None,
    lora_dropout_limit=None,
    scheduler_name_options=None
)
specific_peft_types_params_limit={
    'ALL': TrainLimit(
        batch_size_limit=None,
        max_seq_len_options=None,
        epoch_limit=None,
        learning_rate_limit=(1e-05, 4e-05),
        log_steps_limit=None,
        warmup_ratio_limit=None,
        weight_decay_limit=None,
        lora_rank_options=None,
        lora_alpha_options=None,
        lora_dropout_limit=None,
        scheduler_name_options=None
    ),
    'LoRA': TrainLimit(
        batch_size_limit=None,
        max_seq_len_options=None,
        epoch_limit=None,
        learning_rate_limit=(3e-05, 0.001),
        log_steps_limit=None,
        warmup_ratio_limit=None,
        weight_decay_limit=None,
        lora_rank_options=None,
        lora_alpha_options=None,
        lora_dropout_limit=None,
        scheduler_name_options=None
    )
}

请求参数

名称 类型 描述
train_type string 模型版本,可以通过以下方法获取具体值:
千帆ModelBuilder控制台-模型调优-SFT页面-点击创建训练任务,选择基础模型,查看模型版本,如下图所示:image.png

返回参数

说明:下列表格中的模型支持情况,请参考模型支持情况

名称 类型 必填 描述
epoch int 迭代轮次,说明:该字段取值详情参考模型支持情况
learningRate float 学习率,说明:说明:该字段取值详情参考模型支持情况
batchSize int 批处理大小,说明:该字段取值更多详情参考模型支持情况
maxSeqLen int 序列长度,说明:该字段取值详情参考模型支持情况
loggingSteps int 保存日志间隔,说明:
(1)当为以下情况,该字段必填
· model为ERNIE-Speed-8K,且trainMode为SFT
· model为ERNIE-Lite-8K-0922,且trainMode为SFT
· model为ERNIE-Lite-8K-0308,且trainMode为SFT
· model为ERNIE-Tiny-8K,且trainMode为SFT
(2)取值范围[1, 100],默认值为1
warmupRatio float 预热比例,说明:该字段取值详情参考模型支持情况
weightDecay float 正则化系数,说明:该字段取值详情参考模型支持情况
loraRank int LoRA 策略中的秩,说明:该字段取值详情参考模型支持情况
loraAlpha int 说明:说明:该字段取值更多详情参考模型支持情况
loraAllLinear string LoRA 所有线性层,说明:该字段取值详情参考模型支持情况
loraTargetModules string[] 说明:该字段取值详情参考模型支持情况
loraDropout float 说明:该字段取值更多详情参考模型支持情况
schedulerName string 说明:该字段取值详情参考模型支持情况
Packing bool 可选值:true 或 false,默认值false,说明:该字段取值详情参考模型支持情况
extras Dict[str, Any] {} 其他参数字典,保留值

查询训练参数默认值

请求示例

from qianfan.trainer.configs import DefaultTrainConfigMapping

print(DefaultTrainConfigMapping['ERNIE-xx'])

返回示例

epoch=1
batch_size=None
learning_rate=3e-05
max_seq_len=4096
peft_type='LoRA'
trainset_rate=20
logging_steps=None
warmup_ratio=None
weight_decay=None
lora_rank=None
lora_all_linear=None
scheduler_name=None
lora_alpha=None
lora_dropout=None
extras={}

请求参数

名称 类型 描述
train_type string 模型版本,示例:BLOOMZ_7B,可以通过以下方法获取具体值:
千帆ModelBuilder控制台-模型调优-SFT页面-点击创建训练任务,选择基础模型,查看模型版本,如下图所示:image.png

返回参数

说明:下列表格中的模型支持情况,请参考模型支持情况

名称 类型 必填 描述
epoch int 迭代轮次,说明:该字段取值详情参考模型支持情况
learningRate float 学习率,说明:说明:该字段取值详情参考模型支持情况
batchSize int 批处理大小,说明:该字段取值更多详情参考模型支持情况
maxSeqLen int 序列长度,说明:该字段取值详情参考模型支持情况
loggingSteps int 保存日志间隔,说明:
(1)当为以下情况,该字段必填
· model为ERNIE-Speed-8K,且trainMode为SFT
· model为ERNIE-Lite-8K-0922,且trainMode为SFT
· model为ERNIE-Lite-8K-0308,且trainMode为SFT
· model为ERNIE-Tiny-8K,且trainMode为SFT
(2)取值范围[1, 100],默认值为1
warmupRatio float 预热比例,说明:该字段取值详情参考模型支持情况
weightDecay float 正则化系数,说明:该字段取值详情参考模型支持情况
loraRank int LoRA 策略中的秩,说明:该字段取值详情参考模型支持情况
loraAlpha int 说明:说明:该字段取值更多详情参考模型支持情况
loraAllLinear string LoRA 所有线性层,说明:该字段取值详情参考模型支持情况
loraTargetModules string[] 说明:该字段取值详情参考模型支持情况
loraDropout float 说明:该字段取值更多详情参考模型支持情况
schedulerName string 说明:该字段取值详情参考模型支持情况
Packing bool 可选值:true 或 false,默认值false,说明:该字段取值详情参考模型支持情况
extras Dict[str, Any] {} 其他参数字典,保留值
上一篇
命令行CLI工具
下一篇
Prompt对象