跨模态检索任务简介
更新时间:2022-07-28
跨模态检索任务简介
- 跨模态检索结合了视觉和语言的信息,输入是一张图片及一段文本,输出二者匹配的分值。
- 常用的场景包括:图文匹配,短视频检索等。
ERNIE预训练模型选择
文心提供的ERNIE预训练模型的参数文件和配置文件在 wenxin_appzoo/wenxin_appzoo/models_hub目录下,使用对应的sh脚本,即可拉取对应的模型、字典、必要环境等文件。目前支持跨模态检索任务的预训练模型为:
模型名称 | 下载脚本 | 备注 |
---|---|---|
ERNIE-Vil-2.0-twotower | sh download_ernie_vil_2.0_twotower_ch.sh | 参数、字典与ernie的配置文件存放于ernie_vil_2.0_twotower_ch目录 |
快速开始
#进入目录
cd ./wenxin_appzoo/tasks/text_vil_retrieval
代码结构
text_vil_retrieval
├── data # 示例数据目录
│ ├── predict_data
│ │ └── predict.txt
│ ├── test_data
│ │ └── test.txt
│ └── train_data
│ └── train.txt
├── examples # 配置文件目录
│ ├── cross_modal_retrieval_ernie_vil_2.0_ch_infer.json # 预测json配置文件
│ └── cross_modal_retrieval_ernie_vil_2.0_ch.json # 训练json配置文件
├── inference #推理相关类
│ ├── custom_inference.py
│ └── __init__.py
├── model # 模型文件目录
│ ├── ernie_vil_twotower_retrieval_v2.py
│ └── __init__.py
├── reader # 数据加载代码目录
│ └── ernie_vil_retrieval_dataset_reader.py
├── run_infer.py # 预测入口文件
├── run_trainer.py # 训练入口文件
└── trainer
├── custom_trainer.py
└── __init__.py
数据准备
- 文心中的训练集、验证集、测试集分别存放在./data目录下的train_data, dev_data, test_data文件夹下。
- ERNIE-VIL能够获取图片和文本的匹配分值,我们准备的预测数据格式为两列,由\t分开,第一列是文本,第二列是由base64编码的图片。
- 注:数据集为gb18030编码,需要设置环境变量export LANG=zh_CN.UTF-8
- 将图片文件转为base64加密的字符串的示例代码:
import cv2
import numpy as np
import base64
def image_file_to_str(filename):
img = cv2.imread(filename)
img_encode = cv2.imencode(".jpg", img)[1]
data_encode = np.array(img_encode)
str_encode = data_encode.tostring()
return base64.b64encode(str_encode)
开始训练
进入指定任务的目录,跨模型检索任务的目录为wenxin_appzoo/wenxin_appzoo/tasks/text_vil_retrieval, 运行
cd ./wenxin_appzoo/tasks/text_vil_retrieval
python run_trainer.py --param_path ./examples/cross_modal_retrieval_ernie_vil_2.0_ch.json
- 通过上述脚本调用json文件开启训练。
- 训练运行的日志会自动保存在./log/test.log文件中。
- 训练模型保存于./output/ernie_vil_2.0_twotower_ch文件夹下。
开始预测
使用run_infer.py入口脚本,传入预测配置文件(./examples/cross_modal_retrieval_ernie_vil_2.0_ch_infer.json)进行预测:
python run_infer.py --param_path ./examples/cross_modal_retrieval_ernie_vil_2.0_ch_infer.json
预测运行的日志会自动保存在./output/predict_reuslt.txt文件中,预测结果的格式为:匹配分值'\t'文本embedding'\t'图片embedding