资讯 文档
技术能力
语音技术
文字识别
人脸与人体
图像技术
语言与知识
视频技术

跨模态检索任务简介

跨模态检索任务简介

  • 跨模态检索结合了视觉和语言的信息,输入是一张图片及一段文本,输出二者匹配的分值。
  • 常用的场景包括:图文匹配,短视频检索等。

ERNIE预训练模型选择

文心提供的ERNIE预训练模型的参数文件和配置文件在 wenxin_appzoo/wenxin_appzoo/models_hub目录下,使用对应的sh脚本,即可拉取对应的模型、字典、必要环境等文件。目前支持跨模态检索任务的预训练模型为:

模型名称 下载脚本 备注
ERNIE-Vil-2.0-twotower sh download_ernie_vil_2.0_twotower_ch.sh 参数、字典与ernie的配置文件存放于ernie_vil_2.0_twotower_ch目录

快速开始

#进入目录
cd ./wenxin_appzoo/tasks/text_vil_retrieval

代码结构

    text_vil_retrieval
    ├── data	# 示例数据目录
    │   ├── predict_data
    │   │   └── predict.txt
    │   ├── test_data
    │   │   └── test.txt
    │   └── train_data
    │       └── train.txt
    ├── examples	# 配置文件目录
    │   ├── cross_modal_retrieval_ernie_vil_2.0_ch_infer.json	# 预测json配置文件
    │   └── cross_modal_retrieval_ernie_vil_2.0_ch.json	# 训练json配置文件
    ├── inference	#推理相关类
    │   ├── custom_inference.py
    │   └── __init__.py
    ├── model	# 模型文件目录
    │   ├── ernie_vil_twotower_retrieval_v2.py
    │   └── __init__.py
    ├── reader	# 数据加载代码目录
    │   └── ernie_vil_retrieval_dataset_reader.py
    ├── run_infer.py	# 预测入口文件
    ├── run_trainer.py	# 训练入口文件
    └── trainer
        ├── custom_trainer.py
        └── __init__.py

数据准备

  • 文心中的训练集、验证集、测试集分别存放在./data目录下的train_data, dev_data, test_data文件夹下。
  • ERNIE-VIL能够获取图片和文本的匹配分值,我们准备的预测数据格式为两列,由\t分开,第一列是文本,第二列是由base64编码的图片。
  • 注:数据集为gb18030编码,需要设置环境变量export LANG=zh_CN.UTF-8
  • 将图片文件转为base64加密的字符串的示例代码:
import cv2
import numpy as np
import base64

def image_file_to_str(filename):
    img = cv2.imread(filename)
    img_encode = cv2.imencode(".jpg", img)[1]
    data_encode = np.array(img_encode)
    str_encode = data_encode.tostring()
    return base64.b64encode(str_encode)

开始训练

进入指定任务的目录,跨模型检索任务的目录为wenxin_appzoo/wenxin_appzoo/tasks/text_vil_retrieval, 运行

cd ./wenxin_appzoo/tasks/text_vil_retrieval
python run_trainer.py --param_path ./examples/cross_modal_retrieval_ernie_vil_2.0_ch.json
  • 通过上述脚本调用json文件开启训练。
  • 训练运行的日志会自动保存在./log/test.log文件中。
  • 训练模型保存于./output/ernie_vil_2.0_twotower_ch文件夹下。

开始预测

使用run_infer.py入口脚本,传入预测配置文件(./examples/cross_modal_retrieval_ernie_vil_2.0_ch_infer.json)进行预测:

python run_infer.py --param_path ./examples/cross_modal_retrieval_ernie_vil_2.0_ch_infer.json

预测运行的日志会自动保存在./output/predict_reuslt.txt文件中,预测结果的格式为:匹配分值'\t'文本embedding'\t'图片embedding

上一篇
(升级)模型蒸馏任务
下一篇
数据蒸馏任务