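CFC Tutorial
Last updated: 2025-06-19
How do I create a reward rule with CFC?
When setting up tasks with CFC, users often ask how to create a reward rule with CFC and how to fill in the URL path the platform requires. The steps below walk through the process.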
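Step1: Create a CFC function
Log in to the platform, choose Create Function, and create a blank function.
In the basic settings, the right memory size depends on how complex the reward function is; test it in the CFC module against your own workload and pick a suitable value. If you need to keep logs, configure the BOS log feature as described in the CFC help documentation, so that the reward function's runtime output is stored in logs for later analysis and tuning. Click "Next" to configure the trigger.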
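Because the memory setting depends on how heavy the reward function is, it can help to measure one batch locally before choosing a value. The following is a minimal sketch using the standard library's tracemalloc, which counts Python-level allocations only (memory held by native extensions and the interpreter itself is not included), so treat the result as a lower bound. profile_reward_fn and dummy_reward are hypothetical names, and the 512-sample batch is only an illustrative size.

```python
import time
import tracemalloc
from typing import Callable, List, Tuple


def profile_reward_fn(fn: Callable[..., List[float]], *args) -> Tuple[float, int]:
    """Run fn(*args) once; return (wall-clock seconds, peak traced bytes)."""
    tracemalloc.start()
    start = time.perf_counter()
    try:
        fn(*args)
    finally:
        elapsed = time.perf_counter() - start
        _, peak = tracemalloc.get_traced_memory()  # (current, peak) in bytes
        tracemalloc.stop()
    return elapsed, peak


def dummy_reward(queries, prompts, labels):
    # Hypothetical placeholder: swap in your real reward_func here.
    return [float(len(q) > len(p)) for q, p in zip(queries, prompts)]


if __name__ == "__main__":
    n = 512  # illustrative batch size
    queries = ["Prompt: hi\nhello there"] * n
    prompts = ["Prompt: hi"] * n
    labels = ["hello there"] * n
    seconds, peak_bytes = profile_reward_fn(dummy_reward, queries, prompts, labels)
    print(f"{seconds:.4f}s, peak {peak_bytes / 1024:.1f} KiB")
```

Running the same measurement with realistic response lengths gives a better basis for the memory setting than the short placeholder strings above.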
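Trigger: choose an HTTP trigger and fill in a custom URL path, which must be a matching path starting with "/"; you can use (param) to specify path parameters. The full URL can be viewed once the trigger has been created. During training, the remote service is called with a POST request to obtain the rewards. For authentication, we recommend choosing no authentication, to avoid reward rules failing to load because of IAM permission configuration.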
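With an HTTP trigger, the POST body reaches the function as a JSON string in event["body"] (see the event example in the sample handler's docstring). The sketch below parses it defensively and checks that the three lists have equal length, because zip() in the sample reward function silently drops unmatched items. parse_reward_request is a hypothetical helper, and the base64 branch is an assumption: it supposes that when isBase64Encoded is True the body is the base64 encoding of the JSON text, which you should confirm against the CFC documentation.

```python
import base64
import json
from typing import Any, Dict, List, Tuple


def parse_reward_request(event: Dict[str, Any]) -> Tuple[List[str], List[str], List[str]]:
    """Extract (query, prompts, labels) from an HTTP-trigger event (sketch)."""
    body = event.get("body") or ""
    # Assumption: a base64-flagged body is the base64 encoding of the JSON text.
    if event.get("isBase64Encoded"):
        body = base64.b64decode(body).decode("utf-8")
    data = json.loads(body)
    queries, prompts, labels = data["query"], data["prompts"], data["labels"]
    # zip() would silently truncate to the shortest list, so fail loudly instead.
    if not (len(queries) == len(prompts) == len(labels)):
        raise ValueError(
            f"length mismatch: query={len(queries)}, "
            f"prompts={len(prompts)}, labels={len(labels)}"
        )
    return queries, prompts, labels


if __name__ == "__main__":
    body = json.dumps({"query": ["Prompt: hi\nhello"], "prompts": ["Prompt: hi"], "labels": ["hello"]})
    print(parse_reward_request({"body": body, "isBase64Encoded": False}))
```

A malformed body raises json.JSONDecodeError, which is a subclass of ValueError, so a caller can catch ValueError for both cases.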
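Step2: Function management
After submitting, you can see the CFC function you created. Open the function and click Function Code to start editing.
Sample code for developing a reward function on CFC
Here is a sample: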
# -*- coding: utf-8 -*-
from typing import List
import re
import json

def reward_func(queries: List[str], prompts: List[str], labels: List[str]) -> List[float]:
    """
    RLHF reward function based on set (Jaccard) similarity.

    Args:
        queries (List[str]): List of concatenated prompt and response strings.
        prompts (List[str]): List of prompt strings.
        labels (List[str]): List of desired response strings.

    Returns:
        List[float]: List of reward scores (0 to 1) corresponding to each query.
    """
    rewards = []
    max_prompt_len = 160000  # upper bound (in characters) on the prompt prefix to strip
    for query, prompt, label in zip(queries, prompts, labels):
        # Extract the response by removing the prompt prefix from the query
        content = query[min(max_prompt_len, len(prompt)):]
        set_content = set(tokenize(content))
        set_label = set(tokenize(label))
        if not set_content and not set_label:
            # Both sides empty: treat as a perfect match
            similarity = 1.0
        else:
            intersection = set_content.intersection(set_label)
            union = set_content.union(set_label)
            similarity = len(intersection) / len(union)
        # Ensure similarity is between 0 and 1
        similarity = max(0.0, min(similarity, 1.0))
        rewards.append(similarity)
    # Note: return plain Python floats; converting them to torch tensors is done in the training code
    return rewards

def tokenize(text: str) -> List[str]:
    """
    Tokenize the input text into words.

    Args:
        text (str): Input text.

    Returns:
        List[str]: List of word tokens.
    """
    return re.findall(r'\b\w+\b', text.lower())

def handler(event, context):
    """
    Example event the server receives from the RFT training job:
    {
        'resource': '/get_rewards',
        'path': '/get_rewards',
        'httpMethod': 'POST',
        'headers': {
            'Accept': '*/*',
            'Accept-Encoding': 'gzip, deflate, br',
            'Connection': 'close',
            'Content-Length': '392',
            'Content-Type': 'application/json',
            'User-Agent': 'python-requests/2.31.0',
            'X-Bce-Request-Id': '44dbabc1-8a24-4b10-aa15-8f0933c905a4'
        },
        'queryStringParameters': {},
        'pathParameters': {},
        'requestContext': {
            'stage': 'cfc',
            'requestId': '44dbabc1-8a24-4b10-aa15-8f0933c905a4',
            'resourcePath': '/get_rewards',
            'httpMethod': 'POST',
            'apiId': '7eada4qvyq9vk',
            'sourceIp': '220.181.3.189'
        },
        'body': '{"query": ["Prompt: How are you?\\nI\'m doing well, thank you!", "Prompt: Tell me a joke.\\nWhy did the chicken cross the road?", "Prompt: What\'s the weather today?\\nIt\'s sunny and warm."], "prompts": ["Prompt: How are you?", "Prompt: Tell me a joke.", "Prompt: What\'s the weather today?"], "labels": ["I\'m doing well, thank you!", "Why did the chicken cross the road?", "It\'s sunny and warm."]}',
        'isBase64Encoded': False
    }
    The event is a dict; the data sent in the POST request is in event["body"] as a JSON string.
    """
    # Printed output goes to the function log (e.g. BOS, if log storage is configured)
    print(event)
    data = event.get("body")
    try:
        data = json.loads(data)
        queries = data.get("query")
        prompts = data.get("prompts")
        labels = data.get("labels")
        rewards = reward_func(queries, prompts, labels)
        results = {
            "rewards": rewards
        }
    except Exception as e:
        # Return the error and the raw event so the caller can see what went wrong
        results = {
            "error": str(e),
            "input_events": str(event),
        }
    return results
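The reward in the sample is the Jaccard similarity of two token sets, J(A, B) = |A ∩ B| / |A ∪ B|, defined as 1.0 when both sets are empty. A small worked check, re-implementing the same tokenization rule in stand-alone form (jaccard and reward_for are illustrative names, not part of CFC or the platform):

```python
import re
from typing import Set


def tokens(text: str) -> Set[str]:
    # Same rule as the sample's tokenize(): lower-case, then \b\w+\b word runs.
    return set(re.findall(r"\b\w+\b", text.lower()))


def jaccard(a: str, b: str) -> float:
    sa, sb = tokens(a), tokens(b)
    if not sa and not sb:
        return 1.0  # both empty: treated as identical, as in the sample
    return len(sa & sb) / len(sa | sb)


def reward_for(query: str, prompt: str, label: str) -> float:
    # Strip the prompt prefix from the query, then compare with the label.
    return jaccard(query[len(prompt):], label)


if __name__ == "__main__":
    # "it's sunny." -> {it, s, sunny}; label -> {it, s, sunny, and, warm}: 3 / 5
    print(reward_for("Prompt: What's the weather today?\nIt's sunny.",
                     "Prompt: What's the weather today?",
                     "It's sunny and warm."))  # 0.6
    print(jaccard("", ""))  # 1.0
```

Because the score is set-based, word order and repeated words do not affect it: "sunny sunny" and "sunny" score the same against a given label.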
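Tips for editing code:
Scroll the rightmost scrollbar down below the code editor and click Save; your latest changes take effect only after you save.
For how to test the function, refer to the CFC help documentation.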
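To try the handler without deploying, one option is to build an event shaped like the docstring example and call handler(event, None) directly. In the sketch below, make_test_event is a hypothetical helper, and handler is a placeholder stand-in that returns a constant reward so the block runs on its own; in practice you would import handler from your function file instead. Only "body" is read by the sample handler; the other fields just mirror the docstring example.

```python
import json
from typing import Any, Dict, List


def make_test_event(queries: List[str], prompts: List[str], labels: List[str]) -> Dict[str, Any]:
    """Build a minimal event shaped like the docstring example (hypothetical helper)."""
    return {
        "path": "/get_rewards",
        "httpMethod": "POST",
        "headers": {"Content-Type": "application/json"},
        "queryStringParameters": {},
        "pathParameters": {},
        # The trigger delivers the body as a JSON *string*, not as a dict.
        "body": json.dumps({"query": queries, "prompts": prompts, "labels": labels}),
        "isBase64Encoded": False,
    }


def handler(event, context):
    # Placeholder stand-in; replace with the real handler from your function code.
    data = json.loads(event["body"])
    return {"rewards": [1.0] * len(data["query"])}


if __name__ == "__main__":
    event = make_test_event(
        ["Prompt: How are you?\nI'm doing well, thank you!"],
        ["Prompt: How are you?"],
        ["I'm doing well, thank you!"],
    )
    result = handler(event, None)
    # The sample handler returns {"error": ..., "input_events": ...} on failure.
    assert "rewards" in result, result
    print(result)
```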
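Step3: Copy the final URL path into the Qianfan platform
‼️ Note
During training, each step sends one request to the reward function service, carrying rollout_batch_size * numSamplesPerPrompt samples. In the example shown in the figure, that is 64 * 8 = 512 sets of query, prompt, and label.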
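The per-request sample count is a simple product, and it can be useful to estimate the size of the JSON body the function must parse for one step. A minimal sketch; samples_per_request and body_size_bytes are hypothetical helpers, and the payload keys mirror the sample handler (query, prompts, labels):

```python
import json
from typing import List


def samples_per_request(rollout_batch_size: int, num_samples_per_prompt: int) -> int:
    # Each training step sends all of its rollout samples in a single request.
    return rollout_batch_size * num_samples_per_prompt


def body_size_bytes(queries: List[str], prompts: List[str], labels: List[str]) -> int:
    # Size of the JSON body for one step, serialized with json.dumps defaults.
    payload = {"query": queries, "prompts": prompts, "labels": labels}
    return len(json.dumps(payload).encode("utf-8"))


if __name__ == "__main__":
    n = samples_per_request(64, 8)  # the values from the note's example
    print(n)  # 512
    # Rough estimate with short placeholder texts; real responses are usually much longer.
    size = body_size_bytes(["Prompt: hi\nhello"] * n, ["Prompt: hi"] * n, ["hello"] * n)
    print(f"{size / 1024:.1f} KiB per request")
```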
Appendix
To check from your local machine that the service works end to end, use the following code:
import time

import requests


def request_api_wrapper(url, data, score_key="rewards", try_max_times=5):
    """Send the payload to the reward service synchronously, retrying on failure."""
    headers = {
        "Content-Type": "application/json",
    }
    for _ in range(try_max_times):
        try:
            response = requests.post(url=url, json=data, headers=headers, timeout=180)
            response.raise_for_status()
            response = response.json()
            # On failure the sample handler returns {"error": ..., "input_events": ...}
            assert score_key in response, f"{score_key} not in {response}"
            return response.get(score_key)
        except requests.RequestException as e:
            print(f"Request error, please check: {e}")
        except Exception as e:
            print(f"Unexpected error, please check: {e}")
        time.sleep(1)
    raise Exception(f"Request failed {try_max_times} times. Please check the API server.")

def remote_rm_fn(api_url, queries, prompts, labels, score_key="rewards"):
    """Call the remote reward model (RM) API.

    api_url: URL of the reward service (the CFC HTTP trigger URL).
    queries: query + response strings, formatted with the template.
    prompts: prompt strings (the service strips them from the queries).
    labels: reference responses used to score each query.
    score_key: key of the score list in the JSON response.
    """
    scores = request_api_wrapper(api_url, {"query": queries, "prompts": prompts, "labels": labels}, score_key)
    return scores

if __name__ == "__main__":
    # test utils
    url = "https://XXXX/get_rewards"
    queries = [
        "Prompt: How are you?\nI'm doing well, thank you!",
        "Prompt: Tell me a joke.\nWhy did the chicken cross the road?",
        "Prompt: What's the weather today?\nIt's sunny and warm."
    ]
    prompts = [
        "Prompt: How are you?",
        "Prompt: Tell me a joke.",
        "Prompt: What's the weather today?"
    ]
    labels = [
        "I'm doing well, thank you!",
        "Why did the chicken cross the road?",
        "It's sunny and warm."
    ]
    score = remote_rm_fn(url, queries, prompts, labels)
    print(score)
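If the script cannot reach your function, it can help to rule out the client side first by exercising it against a local stand-in. The sketch below starts a hypothetical fake reward server on 127.0.0.1 (FakeRewardHandler returns 1.0 for every query; it is not the CFC runtime) and posts the same kind of payload with urllib, so it needs only the Python standard library.

```python
import json
import threading
from http.server import BaseHTTPRequestHandler, HTTPServer
from urllib.request import Request, urlopen


class FakeRewardHandler(BaseHTTPRequestHandler):
    """Hypothetical local stand-in for the reward service."""

    def do_POST(self):
        length = int(self.headers.get("Content-Length", 0))
        payload = json.loads(self.rfile.read(length))
        # Constant reward per query: enough to check the request/response plumbing.
        body = json.dumps({"rewards": [1.0] * len(payload["query"])}).encode("utf-8")
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def log_message(self, format, *args):
        pass  # silence per-request logging


def post_json(url, data, timeout=10):
    req = Request(url, data=json.dumps(data).encode("utf-8"),
                  headers={"Content-Type": "application/json"}, method="POST")
    with urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read())


def run_smoke_test():
    server = HTTPServer(("127.0.0.1", 0), FakeRewardHandler)  # port 0: pick a free port
    thread = threading.Thread(target=server.serve_forever, daemon=True)
    thread.start()
    try:
        url = f"http://127.0.0.1:{server.server_address[1]}/get_rewards"
        data = {"query": ["q1", "q2"], "prompts": ["p1", "p2"], "labels": ["l1", "l2"]}
        return post_json(url, data)["rewards"]
    finally:
        server.shutdown()
        server.server_close()


if __name__ == "__main__":
    print(run_smoke_test())  # [1.0, 1.0]
```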
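In the __main__ block of the test script, simply replace the url value with your function's URL. The URL can be obtained as follows: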