CFC Tutorial

How do I create a reward rule with CFC?

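When creating tasks with CFC, users often ask how to build a reward rule with CFC and how to fill in the URL path the platform requires. The steps below walk through the process.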

Step 1: Create a CFC function

Log in to the platform, choose Create Function, and create a blank function.

1.png

2.png

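The memory setting under Basic Information depends on how complex the reward function is; test it in the CFC module against your own business scenario and choose a suitable value. If you need log storage, configure BOS logging as described in the CFC help documentation so that the reward function's runtime output is stored in logs for later analysis and tuning. Click "Next" to configure the trigger.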
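
Once logging is enabled, a compact per-batch summary is easier to analyze than the full event that the sample handler prints. The sketch below is a hypothetical helper, `log_reward_summary`, with made-up JSON field names; it assumes that whatever the function writes to stdout is captured by the log configuration.

```python
import json
from typing import List


def log_reward_summary(rewards: List[float]) -> str:
    """Print one JSON line summarizing a batch of rewards (hypothetical helper).

    Assumes stdout is captured by the function's log configuration
    (e.g. BOS logging), so one short line per call is easy to search.
    """
    summary = {
        "count": len(rewards),
        "mean": sum(rewards) / len(rewards) if rewards else None,
        "min": min(rewards) if rewards else None,
        "max": max(rewards) if rewards else None,
    }
    line = json.dumps(summary, ensure_ascii=False)
    print(line)
    return line
```

For example, `log_reward_summary([0.5, 1.0, 0.0])` prints `{"count": 3, "mean": 0.5, "min": 0.0, "max": 1.0}`. Logging only these aggregates keeps log volume small even when each call carries hundreds of samples.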

3.png

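Trigger: choose an HTTP trigger and fill in a custom URL path. The path must be a match path that starts with "/". The full URL can be viewed once the trigger has been created, and path parameters can be specified with (param). During training, the remote service is called with POST requests to obtain the rewards. For authentication, choosing no authentication is recommended; this avoids reward rules failing to load because of IAM permission configuration.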
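
The POST body sent to the trigger carries three parallel lists. The key names `query`, `prompts`, and `labels` below are taken from the sample handler and the local test script in this document; `build_reward_body` is a hypothetical helper that checks the lists line up before serializing them.

```python
import json
from typing import List


def build_reward_body(queries: List[str], prompts: List[str], labels: List[str]) -> str:
    """Serialize one reward request body (hypothetical helper).

    Key names follow the sample handler: "query", "prompts", "labels".
    All three lists must have the same length, one entry per sample.
    """
    if not (len(queries) == len(prompts) == len(labels)):
        raise ValueError(
            f"length mismatch: {len(queries)} queries, "
            f"{len(prompts)} prompts, {len(labels)} labels"
        )
    return json.dumps({"query": queries, "prompts": prompts, "labels": labels})
```

The length check matters because the sample `reward_func` iterates with `zip`, which silently stops at the shortest list, so a mismatch would otherwise come back as too few rewards rather than as an error.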

4.png

Step 2: Manage the function

After submitting, the new CFC function appears in the function list. Open it and click Function Code to start editing.

5.png

6.png

Sample code for a reward function on CFC

Here is a sample implementation:

# -*- coding: utf-8 -*-

import json
import re
from typing import List

def reward_func(queries: List[str], prompts: List[str], labels: List[str]) -> List[float]:
    """
    RLHF reward function based on set (Jaccard) similarity.
    
    Args:
        queries (List[str]): List of concatenated prompt and content strings.
        prompts (List[str]): List of prompt strings.
        labels (List[str]): List of desired response strings.
    
    Returns:
        List[float]: List of reward scores (0 to 1) corresponding to each query.
    """
    rewards = []
    max_prompt_len = 160000
    
    for query, prompt, label in zip(queries, prompts, labels):
        # Extract content by removing the prompt from the query
        content = query[min(int(max_prompt_len), len(prompt)):]
        
        set_content = set(tokenize(content))
        set_label = set(tokenize(label))
        if not set_content and not set_label:
            similarity = 1.0
        else:
            intersection = set_content.intersection(set_label)
            union = set_content.union(set_label)
            similarity = len(intersection) / len(union)
        # Ensure similarity is between 0 and 1
        similarity = max(0.0, min(similarity, 1.0))
        rewards.append(similarity)
    # Note: return plain Python floats; converting rewards to torch tensors is done in the training code.
    return rewards

def tokenize(text: str) -> List[str]:
    """
    Tokenize the input text into words.
    
    Args:
        text (str): Input text.
    
    Returns:
        List[str]: List of word tokens.
    """
    return re.findall(r'\b\w+\b', text.lower())

def handler(event, context): 
    """
    Example event received by the server from an RFT training job:
    {
        'resource': '/get_rewards',
        'path': '/get_rewards',
        'httpMethod': 'POST',
        'headers': {
            'Accept': '*/*',
            'Accept-Encoding': 'gzip, deflate, br',
            'Connection': 'close',
            'Content-Length': '392',
            'Content-Type': 'application/json',
            'User-Agent': 'python-requests/2.31.0',
            'X-Bce-Request-Id': '44dbabc1-8a24-4b10-aa15-8f0933c905a4'
        },
        'queryStringParameters': {},
        'pathParameters': {},
        'requestContext': {
            'stage': 'cfc',
            'requestId': '44dbabc1-8a24-4b10-aa15-8f0933c905a4',
            'resourcePath': '/get_rewards',
            'httpMethod': 'POST',
            'apiId': '7eada4qvyq9vk',
            'sourceIp': '220.181.3.189'
        },
        'body': '{"query": ["Prompt: How are you?\\nI\'m doing well, thank you!", "Prompt: Tell me a joke.\\nWhy did the chicken cross the road?", "Prompt: What\'s the weather today?\\nIt\'s sunny and warm."], "prompts": ["Prompt: How are you?", "Prompt: Tell me a joke.", "Prompt: What\'s the weather today?"], "labels": ["I\'m doing well, thank you!", "Why did the chicken cross the road?", "It\'s sunny and warm."]}',
        'isBase64Encoded': False
    }
    The event is a dict; the POST request payload is a JSON string in event["body"].
    """
    print(event)
    data = event.get("body")
    try:
        data = json.loads(data)
        queries = data.get("query")
        prompts = data.get("prompts")
        labels = data.get("labels")
        rewards = reward_func(queries, prompts, labels)
        results = {
            "rewards": rewards
        }
    except Exception as e:
        results = {
            "error": str(e),
            "input_events": str(event),
        }
    return results
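
The handler returns either `{"rewards": [...]}` or, on failure, `{"error": ..., "input_events": ...}`. A caller-side check such as the hypothetical `validate_reward_response` below catches the error shape and a length mismatch before the scores are used. Requiring values in [0, 1] matches this sample's Jaccard reward and is an assumption, not a documented platform requirement.

```python
from typing import Any, Dict, List


def validate_reward_response(response: Dict[str, Any], expected_len: int) -> List[float]:
    """Check a response against the sample handler's contract (hypothetical helper)."""
    if "error" in response:
        raise RuntimeError(f"reward service error: {response['error']}")
    rewards = response.get("rewards")
    if not isinstance(rewards, list):
        raise ValueError("response has no 'rewards' list")
    if len(rewards) != expected_len:
        raise ValueError(f"expected {expected_len} rewards, got {len(rewards)}")
    for r in rewards:
        # bool is a subclass of int, so reject it explicitly.
        if isinstance(r, bool) or not isinstance(r, (int, float)):
            raise ValueError(f"non-numeric reward: {r!r}")
        # Range check assumed from the Jaccard reward used in this sample.
        if not 0.0 <= r <= 1.0:
            raise ValueError(f"reward out of [0, 1]: {r}")
    return [float(r) for r in rewards]
```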

Code editing tips:

7.png

Scroll the rightmost scrollbar down past the code block and click Save; your latest changes take effect only after saving.

8.png

For the specific testing procedure, refer to the CFC help documentation.
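
One detail worth getting right when testing: in the example event in the handler's docstring, `body` is a JSON string rather than a nested object, which is why the handler calls `json.loads` on it. The hypothetical `make_test_event` below builds an event of that shape using a subset of the docstring's fields; whether the CFC console's test feature accepts exactly this shape is an assumption to confirm against the CFC help documentation.

```python
import json
from typing import Any, Dict, List


def make_test_event(queries: List[str], prompts: List[str], labels: List[str],
                    path: str = "/get_rewards") -> Dict[str, Any]:
    """Build an HTTP-trigger-style test event (hypothetical helper).

    Field names mirror the example event in the sample handler's docstring.
    "body" is deliberately a JSON string inside the event, not a dict.
    """
    payload = {"query": queries, "prompts": prompts, "labels": labels}
    return {
        "resource": path,
        "path": path,
        "httpMethod": "POST",
        "headers": {"Content-Type": "application/json"},
        "queryStringParameters": {},
        "pathParameters": {},
        "body": json.dumps(payload, ensure_ascii=False),
        "isBase64Encoded": False,
    }


if __name__ == "__main__":
    event = make_test_event(
        ["Prompt: How are you?\nI'm doing well, thank you!"],
        ["Prompt: How are you?"],
        ["I'm doing well, thank you!"],
    )
    # Print the event as JSON so it can be pasted into a test-event editor.
    print(json.dumps(event, indent=2, ensure_ascii=False))
```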

9.png

Step 3: Copy the final URL into the Qianfan platform

10.png

‼️ Note

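During training, the reward function service is called once per step, and each call sends rollout_batch_size * numSamplesPerPrompt samples to the service. In the example shown in the figure, that is 64 * 8 = 512 sets of query, prompt, and label.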
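
Because one call carries rollout_batch_size * numSamplesPerPrompt samples, it can help to time the reward logic locally on a batch of that size before deploying. The sketch below re-implements the sample's Jaccard scoring in compact form; `samples_per_call`, `synthetic_batch`, and the generated texts are illustrative assumptions, not platform APIs.

```python
import re
import time
from typing import List, Tuple


def samples_per_call(rollout_batch_size: int, num_samples_per_prompt: int) -> int:
    """Number of (query, prompt, label) triples sent per training step."""
    return rollout_batch_size * num_samples_per_prompt


def tokenize(text: str) -> List[str]:
    return re.findall(r"\b\w+\b", text.lower())


def jaccard(content: str, label: str) -> float:
    """Set (Jaccard) similarity of word tokens, as in the sample reward_func."""
    a, b = set(tokenize(content)), set(tokenize(label))
    if not a and not b:
        return 1.0
    return len(a & b) / len(a | b)


def synthetic_batch(n: int) -> Tuple[List[str], List[str], List[str]]:
    """Build n made-up triples for timing only."""
    prompts = [f"Prompt: question {i}?" for i in range(n)]
    labels = [f"answer number {i}" for i in range(n)]
    queries = [p + "\n" + label for p, label in zip(prompts, labels)]
    return queries, prompts, labels


if __name__ == "__main__":
    n = samples_per_call(64, 8)  # the 64 * 8 example from the text
    queries, prompts, labels = synthetic_batch(n)
    start = time.perf_counter()
    rewards = [jaccard(q[len(p):], label) for q, p, label in zip(queries, prompts, labels)]
    elapsed = time.perf_counter() - start
    print(f"{n} samples scored in {elapsed:.4f}s")
```

Real responses are usually longer than these synthetic strings, so treat the measured time as a lower bound and compare it with the client-side timeout (180 seconds in the appendix script).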

11.png

Appendix

To verify from your local machine that the service works end to end, run the following code:

import time
import requests


def request_api_wrapper(url, data, score_key="rewards", try_max_times=5):
    """Synchronous request API wrapper"""
    headers = {
        "Content-Type": "application/json",
    }
    for _ in range(try_max_times):
        try:
            response = requests.post(url=url, json=data, headers=headers, timeout=180)
            response.raise_for_status()
            response = response.json()
            assert score_key in response, f"{score_key} not in {response}"
            return response.get(score_key)
        except requests.RequestException as e:
            print(f"Request error, please check: {e}")
        except Exception as e:
            print(f"Unexpected error, please check: {e}")
        time.sleep(1)

    raise RuntimeError(f"Request failed after {try_max_times} attempts. Please check the API server.")


def remote_rm_fn(api_url, queries, prompts, labels, score_key="rewards"):
    """Call the remote reward service.
    api_url: reward service URL (the CFC HTTP trigger URL)
    queries: prompt + response strings, built with the prompt template
    prompts: prompt strings
    labels: reference (desired) responses
    score_key: key of the scores in the JSON response
    """
    scores = request_api_wrapper(api_url, {"query": queries, "prompts": prompts, "labels": labels}, score_key)
    return scores



if __name__ == "__main__":
    # test utils
    url = "https://XXXX/get_rewards"
    queries = [
        "Prompt: How are you?\nI'm doing well, thank you!",
        "Prompt: Tell me a joke.\nWhy did the chicken cross the road?",
        "Prompt: What's the weather today?\nIt's sunny and warm."
    ]
    
    prompts = [
        "Prompt: How are you?",
        "Prompt: Tell me a joke.",
        "Prompt: What's the weather today?"
    ]
    
    labels = [
        "I'm doing well, thank you!",
        "Why did the chicken cross the road?",
        "It's sunny and warm."
    ]
    score = remote_rm_fn(url, queries, prompts, labels)
    print(score)


Replace the url in the code with your own trigger URL. You can find the URL as shown below:

12.png
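
If the `requests` package is not installed on the machine used for verification, the same POST can be sent with the standard library. The sketch below is a minimal alternative built on `urllib.request`, assuming the same JSON contract as the script above (keys `query`, `prompts`, `labels` in, `rewards` out). `build_request` and `post_for_rewards` are hypothetical names, and unlike `request_api_wrapper` this version does not retry.

```python
import json
import urllib.request
from typing import List


def build_request(url: str, queries: List[str], prompts: List[str],
                  labels: List[str]) -> urllib.request.Request:
    """Build the JSON POST request expected by the sample handler."""
    data = json.dumps({"query": queries, "prompts": prompts, "labels": labels}).encode("utf-8")
    return urllib.request.Request(
        url,
        data=data,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def post_for_rewards(url: str, queries: List[str], prompts: List[str],
                     labels: List[str], timeout: float = 180) -> List[float]:
    """Send one request and return the "rewards" list (no retries)."""
    req = build_request(url, queries, prompts, labels)
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        result = json.loads(resp.read().decode("utf-8"))
    if "rewards" not in result:
        raise RuntimeError(f"no 'rewards' in response: {result}")
    return result["rewards"]
```

Usage: call `post_for_rewards("https://XXXX/get_rewards", queries, prompts, labels)` with the placeholder replaced by your trigger URL and the same test lists as in the script above.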
