RFT自定义奖励规则
更新时间:2025-06-04
RFT训练方法中预置了五种奖励规则。奖励规则定义了如何评估模型输出的好坏,以下给出各规则的代码,可供查看或修改后使用。
1.字符串比较(相等)
from typing import List
import re
import os
import torch
def reward_func(queries: List[str], prompts: List[str], labels: List[str]) -> torch.Tensor:
    """
    Rule-based RLHF reward function built on exact string comparison.

    Args:
        queries (List[str]): List of concatenated prompt and content strings.
        prompts (List[str]): List of prompt strings.
        labels (List[str]): List of desired response strings.

    Returns:
        torch.Tensor: 1-D float tensor with one reward per query
        (1.0 exact match, 0.5 keyword overlap, 0.0 otherwise).
    """
    # The prompt is removed by dropping min(MAX_PROMPT_LEN, len(prompt))
    # leading characters. The default is 1024, like the other preset rules;
    # a default of 1 would strip a single character and leave the prompt
    # inside the content being scored.
    max_prompt_len = int(os.environ.get('MAX_PROMPT_LEN', '1024'))

    rewards = []
    for query, prompt, label in zip(queries, prompts, labels):
        # Extract content by removing the prompt from the query
        content = query[min(max_prompt_len, len(prompt)):]

        # Reward tiers. An additional tier (for example a semantic-similarity
        # check worth 0.7) can be inserted between the two matches below.
        if content == label:
            reward = 1.0  # Exact match
        elif has_keyword_overlap(content, label):
            reward = 0.5  # Partial match based on keyword overlap
        else:
            reward = 0.0  # No meaningful match
        rewards.append(reward)

    return torch.tensor(rewards, dtype=torch.float)
def has_keyword_overlap(text1: str, text2: str) -> bool:
    """
    Check if there is significant keyword overlap between two texts.

    Keywords are the distinct lowercase word tokens of each text. The overlap
    ratio is the number of shared keywords divided by the size of the larger
    keyword set.

    Args:
        text1 (str): First text.
        text2 (str): Second text.

    Returns:
        bool: True if the overlap ratio exceeds 0.3, False otherwise. Texts
        without any word token never overlap.
    """
    # Simple keyword extraction: lowercase and split by non-word characters
    keywords1 = set(re.findall(r'\b\w+\b', text1.lower()))
    keywords2 = set(re.findall(r'\b\w+\b', text2.lower()))

    # Guard against dividing by zero when neither text contains a word
    # (e.g. empty or punctuation-only strings).
    if not keywords1 or not keywords2:
        return False

    overlap_ratio = len(keywords1 & keywords2) / max(len(keywords1), len(keywords2))
    return overlap_ratio > 0.3  # Threshold can be adjusted
2.字符串比较(包含)
from typing import List
import os
import re
import torch
def reward_func(queries: List[str], prompts: List[str], labels: List[str]) -> torch.Tensor:
    """
    RLHF reward function based on set (Jaccard) similarity.

    The response content and the label are tokenized into sets of lowercase
    words; the reward is |intersection| / |union| of the two sets.

    Args:
        queries (List[str]): List of concatenated prompt and content strings.
        prompts (List[str]): List of prompt strings.
        labels (List[str]): List of desired response strings.

    Returns:
        torch.Tensor: 1-D float tensor of reward scores (0 to 1), one per query.
    """
    # The prompt is removed by dropping min(MAX_PROMPT_LEN, len(prompt))
    # leading characters. The default is 1024, like the other preset rules;
    # a default of 1 would leave the prompt words inside the scored content.
    max_prompt_len = int(os.environ.get('MAX_PROMPT_LEN', '1024'))

    rewards = []
    for query, prompt, label in zip(queries, prompts, labels):
        # Extract content by removing the prompt from the query
        content = query[min(max_prompt_len, len(prompt)):]
        set_content = set(tokenize(content))
        set_label = set(tokenize(label))

        union = set_content | set_label
        if not union:
            # Both sides contain no words at all: treat as a perfect match.
            similarity = 1.0
        else:
            similarity = len(set_content & set_label) / len(union)

        # Ensure similarity is between 0 and 1
        similarity = max(0.0, min(similarity, 1.0))
        rewards.append(similarity)

    return torch.tensor(rewards, dtype=torch.float)
def tokenize(text: str) -> List[str]:
    """
    Split text into lowercase word tokens.

    Args:
        text (str): Input text.

    Returns:
        List[str]: Word tokens in order of appearance; punctuation and
        whitespace are dropped.
    """
    word_pattern = re.compile(r'\b\w+\b')
    lowered = text.lower()
    return word_pattern.findall(lowered)
# Example Usage
# Each query is its prompt followed by a newline and the response text, and
# each label is identical to that response text.
if __name__ == "__main__":
    queries = [
        "Prompt: How are you?\nI'm doing well, thank you!",
        "Prompt: Tell me a joke.\nWhy did the chicken cross the road?",
        "Prompt: What's the weather today?\nIt's sunny and warm."
    ]
    prompts = [
        "Prompt: How are you?",
        "Prompt: Tell me a joke.",
        "Prompt: What's the weather today?"
    ]
    labels = [
        "I'm doing well, thank you!",
        "Why did the chicken cross the road?",
        "It's sunny and warm."
    ]
    rewards_set_matching = reward_func(queries, prompts, labels)
    print("Set Matching-Based Rewards:", rewards_set_matching)
    # Every sample scores 1.0 when the whole prompt is stripped, i.e. when
    # MAX_PROMPT_LEN is at least the prompt length. The value printed is the
    # torch tensor returned by reward_func, not a plain list.
3.字符串相似度对比
from typing import List
import re
import numpy as np
import os
import torch
def reward_func(queries: List[str], prompts: List[str], labels: List[str]) -> torch.Tensor:
    """
    RLHF reward function based on normalized edit distance.

    The reward is 1 - distance / max(len(content), len(label)), where distance
    is the Levenshtein distance between the extracted content and the label.

    Args:
        queries (List[str]): List of concatenated prompt and content strings.
        prompts (List[str]): List of prompt strings.
        labels (List[str]): List of desired response strings.

    Returns:
        torch.Tensor: 1-D float tensor of reward scores (0 to 1), one per query.
    """
    rewards = []
    for query, prompt, label in zip(queries, prompts, labels):
        content = extract_content(query, prompt)
        longest = max(len(content), len(label))

        if longest == 0:
            # Two empty strings are identical.
            similarity = 1.0
        else:
            similarity = 1 - (levenshtein_distance(content, label) / longest)

        # Ensure similarity is between 0 and 1
        similarity = max(0.0, min(similarity, 1.0))
        rewards.append(similarity)

    return torch.tensor(rewards, dtype=torch.float)
def extract_content(query: str, prompt: str) -> str:
    """
    Return the part of the query that follows the prompt.

    The number of leading characters removed is the smaller of the prompt
    length and the MAX_PROMPT_LEN environment variable (default 1024).

    Args:
        query (str): The concatenated prompt and content.
        prompt (str): The prompt part.

    Returns:
        str: The extracted content (not stripped of surrounding whitespace).
    """
    prompt_limit = int(os.environ.get('MAX_PROMPT_LEN', '1024'))
    cut_at = min(prompt_limit, len(prompt))
    return query[cut_at:]
def levenshtein_distance(s1: str, s2: str) -> int:
    """
    Compute the Levenshtein (edit) distance between two strings.

    Uses the two-row dynamic-programming formulation: only the previous and
    current rows of the distance table are kept, sized by the shorter string.

    Args:
        s1 (str): First string.
        s2 (str): Second string.

    Returns:
        int: Minimum number of single-character insertions, deletions and
        substitutions needed to turn one string into the other.
    """
    # Iterate over the longer string so each row is as short as possible.
    longer, shorter = (s1, s2) if len(s1) >= len(s2) else (s2, s1)

    prev_row = list(range(len(shorter) + 1))
    for row_no, ch_long in enumerate(longer, start=1):
        curr_row = [row_no]
        for col, ch_short in enumerate(shorter):
            via_insert = prev_row[col + 1] + 1
            via_delete = curr_row[col] + 1
            via_substitute = prev_row[col] + (ch_long != ch_short)
            curr_row.append(min(via_insert, via_delete, via_substitute))
        prev_row = curr_row
    return prev_row[-1]
# Example Usage
# Each query is its prompt followed by a newline and the response text, and
# each label is identical to that response text.
if __name__ == "__main__":
    queries = [
        "Prompt: How are you?\nI'm doing well, thank you!",
        "Prompt: Tell me a joke.\nWhy did the chicken cross the road?",
        "Prompt: What's the weather today?\nIt's sunny and warm."
    ]
    prompts = [
        "Prompt: How are you?",
        "Prompt: Tell me a joke.",
        "Prompt: What's the weather today?"
    ]
    labels = [
        "I'm doing well, thank you!",
        "Why did the chicken cross the road?",
        "It's sunny and warm."
    ]
    rewards_edit_distance = reward_func(queries, prompts, labels)
    print("Edit Distance-Based Rewards:", rewards_edit_distance)
    # With the default MAX_PROMPT_LEN the extracted content still starts with
    # the newline that follows the prompt, so each sample is one edit away
    # from its label and scores slightly below 1.0. The value printed is the
    # torch tensor returned by reward_func, not a plain list.
4.数学答案匹配
import asyncio
import re
from itertools import islice, zip_longest
import ray
from sympy.parsing.latex import parse_latex
from typing import Any, Awaitable, Callable, List, Optional, Tuple
import torch
import os
try:
from math_verify import parse, verify
except ImportError:
print("math_verify is not installed in this environment")
parse = None
verify = None
def repeatness(s: str):
    """Heuristically flag strings that consist largely of repeated substrings.

    Builds a suffix ordering of the string, sums the longest-common-prefix
    lengths between neighbouring suffixes in that ordering, normalises the sum
    by n * (n + 1) / 2 and compares the ratio with 0.2.

    Returns:
        0 when the string has at most one character, otherwise a bool that is
        True when the normalised ratio exceeds 0.2.
    """

    def ranks(l):
        # Map every element to its position among the sorted distinct values.
        index = {v: i for i, v in enumerate(sorted(set(l)))}
        return [index[v] for v in l]

    def suffixArray(s):
        # Rank suffixes by doubling: each round re-ranks positions by the pair
        # (rank at i, rank at i + k), padding past the end with -1, then
        # doubles k.
        line = ranks(s)
        n, k, ans, sa = len(s), 1, line, [0] * len(s)
        while k < n - 1:
            line = ranks(list(zip_longest(line, islice(line, k, None), fillvalue=-1)))
            ans, k = line, k << 1
        # Invert the rank array: sa[rank] = starting position.
        # (k is reused here as the rank of position i.)
        for i, k in enumerate(ans):
            sa[k] = i
        return ans, sa

    def lcp(arr, suffixArr, inv_suff):
        # For each position i, count how many leading elements its suffix
        # shares with the suffix that comes next in the ordering, carrying
        # k - 1 over to the next position instead of restarting from zero.
        n, ans, k = len(arr), [0] * len(arr), 0
        for i in range(n):
            if inv_suff[i] == n - 1:
                # Last suffix in the ordering has no successor to compare with.
                k = 0
                continue
            j = suffixArr[inv_suff[i] + 1]
            while i + k < n and j + k < n and arr[i + k] == arr[j + k]:
                k += 1
            ans[inv_suff[i]] = k
            if k > 0:
                k -= 1
        return ans

    arr = [ord(i) for i in s]
    n = len(arr)
    if n <= 1:
        return 0
    # c holds the rank of each position, sa the positions in rank order.
    c, sa = suffixArray(arr)
    cnt = sum(lcp(arr, sa, c))
    return (cnt * 2 / (n * (n + 1))) > 0.2
# (before, after) pairs applied in order with str.replace by
# normalize_final_answer before any other normalization step.
SUBSTITUTIONS = [
    ("an ", ""),
    ("a ", ""),
    (".$", "$"),
    ("\\$", ""),
    (r"\ ", ""),
    (" ", ""),
    ("mbox", "text"),
    (",\\text{and}", ","),
    ("\\text{and}", ","),
    ("\\text{m}", "\\text{}"),
]
# Substrings deleted (in order) by normalize_final_answer after the
# SUBSTITUTIONS pass: mostly unit words and LaTeX decoration.
# NOTE(review): "\\text{\ns}" and "\\text{\n}" contain a real newline escape,
# not a backslash followed by "n" — confirm this is intended.
REMOVED_EXPRESSIONS = [
    "square",
    "ways",
    "integers",
    "dollars",
    "mph",
    "inches",
    "ft",
    "hours",
    "km",
    "units",
    "\\ldots",
    "sue",
    "points",
    "feet",
    "minutes",
    "digits",
    "cents",
    "degrees",
    "cm",
    "gm",
    "pounds",
    "meters",
    "meals",
    "edges",
    "students",
    "childrentickets",
    "multiples",
    "\\text{s}",
    "\\text{.}",
    "\\text{\ns}",
    "\\text{}^2",
    "\\text{}^3",
    "\\text{\n}",
    "\\text{}",
    r"\mathrm{th}",
    r"^\circ",
    r"^{\circ}",
    r"\;",
    r",\!",
    "{,}",
    '"',
    "\\dots",
]
def normalize_final_answer(final_answer: str) -> str:
    """
    Normalize a final answer to a quantitative reasoning question.

    Applies the SUBSTITUTIONS and REMOVED_EXPRESSIONS tables, unwraps LaTeX
    containers, expands shorthand fractions and roots, and drops thousands
    separators from plain integers.
    This code comes from https://arxiv.org/pdf/2206.14858.pdf, page18.
    """
    normalized = final_answer
    for old, new in SUBSTITUTIONS:
        normalized = normalized.replace(old, new)
    for unwanted in REMOVED_EXPRESSIONS:
        normalized = normalized.replace(unwanted, "")

    # Regex rewrites, applied strictly in this order.
    rewrite_rules = (
        # Extract answer that is in LaTeX math, is bold,
        # is surrounded by a box, etc.
        (r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$"),
        (r"(\\text\{)(.*?)(\})", "\\2"),
        (r"(\\textbf\{)(.*?)(\})", "\\2"),
        (r"(\\overline\{)(.*?)(\})", "\\2"),
        (r"(\\boxed\{)(.*)(\})", "\\2"),
        # Normalize shorthand TeX:
        # \fracab -> \frac{a}{b}
        # \frac{abc}{bef} -> \frac{abc}{bef}
        # \fracabc -> \frac{a}{b}c
        # \sqrta -> \sqrt{a}
        # \sqrtab -> sqrt{a}b
        (r"(frac)([^{])(.)", "frac{\\2}{\\3}"),
        (r"(sqrt)([^{])", "sqrt{\\2}"),
    )
    for pattern, replacement in rewrite_rules:
        normalized = re.sub(pattern, replacement, normalized)

    normalized = normalized.replace("$", "")

    # Normalize 100,000 -> 100000
    without_commas = normalized.replace(",", "")
    if without_commas.isdigit():
        normalized = without_commas
    return normalized
def latex_eval(latex):
    """Parse a LaTeX string with sympy and evaluate it numerically.

    Returns:
        A pair (expression, value): the parsed sympy expression and the
        result of calling evalf() on it.
    """
    expression = parse_latex(latex)
    return expression, expression.evalf()
def _is_latex_equal(str1, str2):
    """Compare two LaTeX answers, first as given and then after normalization.

    Stage 1 parses both strings as they are and accepts them when either the
    parsed expressions or their numeric values are equal. If that fails or
    raises, stage 2 repeats the comparison on the normalized strings and
    falls back to plain string equality of the normalized forms when parsing
    raises.

    Returns:
        bool: True when the answers are considered equal.
    """
    try:
        sym1, val1 = latex_eval(str1)
        sym2, val2 = latex_eval(str2)
        if sym1 == sym2 or val1 == val2:
            return True
    except Exception:  # noqa
        # Raw parsing failed; continue with the normalized comparison.
        pass

    # Normalize outside the try block so the handler below never reads
    # unbound names if normalization itself raises.
    norm1, norm2 = normalize_final_answer(str1), normalize_final_answer(str2)
    try:
        sym1, val1 = latex_eval(norm1)
        sym2, val2 = latex_eval(norm2)
        if sym1 == sym2 or val1 == val2:
            return True
    except Exception:  # noqa
        return norm1 == norm2
    return False
def _fix_fracs(string):
substrs = string.split("\\frac")
new_str = substrs[0]
if len(substrs) > 1:
substrs = substrs[1:]
for substr in substrs:
new_str += "\\frac"
if substr[0] == "{":
new_str += substr
else:
try:
assert len(substr) >= 2
except Exception: # noqa
return string
a = substr[0]
b = substr[1]
if b != "{":
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}{" + b + "}" + post_substr
else:
new_str += "{" + a + "}{" + b + "}"
else:
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}" + b + post_substr
else:
new_str += "{" + a + "}" + b
string = new_str
return string
def _fix_a_slash_b(string):
if len(string.split("/")) != 2:
return string
a = string.split("/")[0]
b = string.split("/")[1]
try:
a = int(a)
b = int(b)
assert string == "{}/{}".format(a, b)
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
return new_string
except Exception: # noqa
return string
def _remove_right_units(string):
# "\\text{ " only ever occurs (at least in the val set) when describing units
if "\\text{ " in string:
splits = string.split("\\text{ ")
assert len(splits) == 2
return splits[0]
else:
return string
def _fix_sqrt(string):
if "\\sqrt" not in string:
return string
splits = string.split("\\sqrt")
new_string = splits[0]
for split in splits[1:]:
if split[0] != "{":
a = split[0]
new_substr = "\\sqrt{" + a + "}" + split[1:]
else:
new_substr = "\\sqrt" + split
new_string += new_substr
return new_string
def _strip_string(string):
    """Canonicalize a math answer string before comparison.

    Removes layout and decoration (line breaks, spacing commands, left/right
    delimiters, degree marks, dollar signs, commas, trailing units, escaped
    percent signs), fixes leading decimal points, drops a short "x =" prefix
    and normalizes square roots and fractions. The steps are order-dependent.
    """
    # linebreaks
    string = string.replace("\n", "")
    # remove inverse spaces
    string = string.replace("\\!", "")
    # replace \\ with \
    string = string.replace("\\\\", "\\")
    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")
    # remove \left and \right
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")
    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")
    # remove dollar signs
    string = string.replace("\\$", "")
    string = string.replace("$", "")
    string = string.replace(",", "")
    # remove units (on the right)
    string = _remove_right_units(string)
    # remove percentage (escaped form only). A second replace spelled with an
    # invalid escape sequence denoted the same two characters, so it was a
    # no-op and has been dropped.
    string = string.replace("\\%", "")
    # " 0." equivalent to " ." and "{0." equivalent to "{."
    # Alternatively, add "0" if "." is the start of the string
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")
    # if empty, return empty string
    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string
    # to consider: get rid of e.g. "k = " or "q = " at beginning
    sides = string.split("=")
    if len(sides) == 2 and len(sides[0]) <= 2:
        string = sides[1]
    # fix sqrt3 --> sqrt{3}
    string = _fix_sqrt(string)
    # remove spaces
    string = string.replace(" ", "")
    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc.
    # Even works with \frac1{72} (but not \frac{72}1).
    string = _fix_fracs(string)
    # manually change 0.5 --> \frac{1}{2}
    if string == "0.5":
        string = "\\frac{1}{2}"
    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in
    # case the model output is X/Y
    string = _fix_a_slash_b(string)
    return string
def is_equiv(str1, str2, verbose=False) -> bool:
    """Decide whether two answer strings are equivalent after canonicalization.

    Two None values are equal (a warning is printed); a single None never is.
    Otherwise both strings are canonicalized and compared as floats when both
    parse as numbers, else as strings. If canonicalization raises, the raw
    inputs are compared instead.
    """
    if str1 is None or str2 is None:
        if str1 is None and str2 is None:
            print("WARNING: Both None")
            return True
        return False

    try:
        left = _strip_string(str1)
        right = _strip_string(str2)
        if verbose:
            print(left, right)
        try:
            return float(left) == (float(right))
        except Exception:  # noqa
            return left == right
    except Exception:  # noqa
        return str1 == str2
def last_boxed_only_string(string):
    r"""Return the last \boxed{...} (or \fbox{...}) expression in the text.

    The search starts at the last occurrence of \boxed, or of \fbox when no
    \boxed exists, and tracks brace depth until the opening brace is closed.

    Returns:
        The substring from the command through its closing brace, or None if
        neither command occurs or the braces never balance.
    """
    start = string.rfind("\\boxed")
    if start < 0:
        start = string.rfind("\\fbox")
    if start < 0:
        return None

    depth = 0
    for pos in range(start, len(string)):
        char = string[pos]
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                return string[start : pos + 1]
    return None
def remove_boxed(s):
    r"""Strip the surrounding \boxed{ ... } from a string.

    Returns:
        The text between "\boxed{" and the final "}", or None when the input
        is not a string of exactly that shape (including None).
    """
    prefix = "\\boxed{"
    if isinstance(s, str) and s.startswith(prefix) and s.endswith("}"):
        return s[len(prefix) : -1]
    return None
def get_answer_str(s: str) -> str:
    r"""Return the content of the last \boxed{...} in s, or s itself.

    The input is returned unchanged when it holds no boxed expression or when
    the expression found cannot be unwrapped.
    """
    boxed = last_boxed_only_string(s)
    inner = remove_boxed(boxed)
    return s if inner is None else inner
def is_equal(str1, str2, math_mode="legacy"):
    """Compare two answers, cheaply first and symbolically second.

    The string-level check runs first; only when it reports a mismatch is the
    LaTeX-aware comparison invoked with the given math_mode.
    """
    if is_equiv(str1, str2):
        return True
    return is_latex_equal(str1, str2, math_mode)
def solution2answer(solution: str, math_mode="eval_peeking") -> str:
    """Reduce a full solution text to its final answer string.

    Args:
        solution: Solution text, possibly containing a boxed answer.
        math_mode: Only "eval_peeking" is supported.

    Raises:
        ValueError: If math_mode is anything other than "eval_peeking".
    """
    if math_mode != "eval_peeking":
        raise ValueError(f"Invalid math_mode: {math_mode}")
    return get_answer_str(solution)
def get_final_answer(output: str) -> str:
    """Extract the final answer from free-form model output.

    The text is normalized ("is:" and "answer:" spellings, a trailing period,
    a period before a closing dollar sign), then searched for a trailing
    "answer is ..." phrase, numeric form first. Whatever is found, or the
    whole text when nothing matches, is passed through get_answer_str.
    """
    text = output.replace("is:", "is").replace("answer:", "answer is").strip()
    if text.endswith("."):
        text = text[:-1]
    text = text.replace(".$", "$")

    answer_patterns = (
        r"answer is (-?\d+\.?\d*)$",
        r"answer is (.+?)$",
    )
    for pattern in answer_patterns:
        found = re.findall(pattern, text, re.S)
        if found:
            return get_answer_str(found[0])
    return get_answer_str(text)
@ray.remote(num_cpus=1)
def extract_final_answers_batch(responses: List[str]) -> List[str]:
    r"""Ray task: pull the boxed answer out of each response's <answer> tags.

    Returns:
        For every response, the last \boxed{...} capture found inside an
        <answer>...</answer> pair, or "" when there is none.
    """
    boxed_in_answer = re.compile(r"<answer>.*?(\\boxed{.*}).*?</answer>", re.DOTALL)

    def last_capture(response: str) -> str:
        captures = boxed_in_answer.findall(response)
        return captures[-1] if captures else ""

    return [last_capture(response) for response in responses]
def is_latex_equal(str1: str, str2: str, math_mode: str = "legacy") -> bool:
    """
    Synchronously check whether two LaTeX strings are mathematically equivalent.

    Args:
        str1: First LaTeX string.
        str2: Second LaTeX string.
        math_mode: "legacy" for the sympy-based comparison, "math_verify" for
            the math_verify package.

    Returns:
        bool: True when the strings are judged equivalent; False on mismatch
        or when the comparison raises.

    Raises:
        NotImplementedError: For any other math_mode.
    """
    if math_mode == "math_verify":
        try:
            # Call the synchronous comparison directly
            return verify(parse(str1), parse(str2))
        except Exception:
            return False

    if math_mode != "legacy":
        raise NotImplementedError(f"Math mode {math_mode} is not implemented")

    # Repetition check: long, highly repetitive strings are rejected outright
    if any(len(text) > 128 and repeatness(text) for text in (str1, str2)):
        return False
    try:
        # Call the synchronous comparison directly
        return _is_latex_equal(str1, str2)
    except Exception:
        return False
def reward_func(queries, prompts, labels):
    """Score math responses: 1.0 when the boxed answer matches the label.

    Args:
        queries: Prompts concatenated with the model responses.
        prompts: The prompts alone.
        labels: Reference answers.

    Returns:
        torch.Tensor: 1-D float tensor with 1.0 for a match and 0.0 otherwise.
    """
    raw_limit = os.environ.get('MAX_PROMPT_LEN', '1024')

    # Extract content by removing the prompt from the query
    outputs = [
        query[min(len(prompt), int(raw_limit)):].strip()
        for query, prompt in zip(queries, prompts)
    ]

    # Extract the final answers in a Ray task
    final_answers = ray.get(extract_final_answers_batch.remote(outputs))

    rewards = [
        1.0 if is_equal(solution2answer(label), solution2answer(final_answer)) else 0.0
        for label, final_answer in zip(labels, final_answers)
    ]
    return torch.tensor(rewards, dtype=torch.float)
5.逻辑推理匹配
import re
from typing import Dict, Tuple, Optional
import torch
import os
def extract_solution(solution_str: str) -> Optional[str]:
    """Extracts the final answer from the model's response string.

    Args:
        solution_str: Raw response string from the language model

    Returns:
        The stripped text of the last <answer>...</answer> pair, or None when
        the string contains no such pair.
    """
    # Extract final answer using XML-style tags
    answers = re.findall(r'<answer>(.*?)</answer>', solution_str, re.DOTALL)
    if not answers:
        print("[Error] No valid answer tags found")
        return None
    return answers[-1].strip()
def parse_solution_text_format(solution_text: str) -> Dict[str, str]:
    """Parses ground truth solution text into status dictionary.

    Each non-empty line is searched for the first alphabetic word followed
    later on the same line by "knight" or "knave" (case-insensitive).

    Args:
        solution_text: Formatted solution text from dataset

    Returns:
        Dictionary mapping character names to their roles (knight/knave),
        with roles lowercased. Lines that do not match are reported and
        skipped.
    """
    role_pattern = re.compile(r'\b([A-Za-z]+)\b.*?\b(knight|knave)\b', re.IGNORECASE)
    roles_by_name = {}
    print("\n[Ground Truth Parsing]")

    for raw_line in solution_text.split('\n'):
        stripped = raw_line.strip()
        if not stripped:
            continue
        found = role_pattern.search(stripped)
        if found is None:
            print(f" [Warning] Unparseable line: '{stripped}'")
            continue
        name, role = found.groups()
        roles_by_name[name] = role.lower()
        print(f" Found: {name} → {role}")

    return roles_by_name
def parse_model_answer(answer_text: str, expected_names: list) -> Optional[Dict[str, str]]:
    """Parses model's answer text into status dictionary.

    The answer must mention "knight"/"knave" exactly as many times as there
    are expected names, and must contain "<name> is a knight|knave" for every
    expected name.

    Args:
        answer_text: Text extracted from model's <answer> tags
        expected_names: List of character names requiring identification

    Returns:
        Dictionary mapping character names to predicted roles, or None if
        the role count is wrong or any name is missing
    """
    print("\n[Model Answer Parsing]")
    print(f" Expected characters: {expected_names}")

    lowered = answer_text.lower()
    predicted_total = lowered.count('knight') + lowered.count('knave')
    print(f" Number of predicted roles: {predicted_total}")
    if predicted_total != len(expected_names):
        print(f" [Error] Number of characters mismatch: {predicted_total} != {len(expected_names)}")
        return None

    predictions = {}
    for name in expected_names:
        found = re.search(
            rf'\b{re.escape(name)}\b\s+is\s+a\s+\b(knight|knave)\b',
            answer_text,
            re.IGNORECASE,
        )
        if found is None:
            print(f" [Error] Missing identification for {name}")
            return None
        role = found.group(1).lower()
        predictions[name] = role
        print(f" Found: {name} → {role}")
    return predictions
def validate_response_structure(processed_str: str) -> bool:
    """Performs comprehensive validation of response structure.

    Requires exactly one </think>, <answer> and </answer> tag each, with
    their first occurrences in that order.

    Args:
        processed_str: Processed response string from the model

    Returns:
        Boolean indicating whether all formatting requirements are met
    """
    print("\n[Structure Validation]")

    # (key, tag text, required count); the opening <think> tag is not checked.
    required_tags = (
        ('think_end', '</think>', 1),
        ('answer_start', '<answer>', 1),
        ('answer_end', '</answer>', 1),
    )

    is_valid = True
    first_index = {}
    for key, tag, wanted in required_tags:
        occurrences = processed_str.count(tag)
        index = processed_str.find(tag)
        first_index[key] = index
        print(f" {tag}: count={occurrences}, position={index}")
        if occurrences != wanted:
            print(f" [Error] {tag} appears {occurrences} times (expected {wanted})")
            is_valid = False

    # Verify tag order
    in_order = first_index['think_end'] <= first_index['answer_start'] <= first_index['answer_end']
    if in_order:
        print(" Tag sequence validation passed")
    else:
        print(" [Error] Incorrect tag order: Expected <think>...</think><answer>...</answer>")
        is_valid = False
    return is_valid
def reward_func(queries, prompts, labels):
    """Scores knight/knave answers against the ground truth.

    Args:
        queries: Prompts concatenated with the model responses.
        prompts: The prompts alone.
        labels: Ground truth solution texts, one "name ... knight|knave"
            statement per line.

    Returns:
        torch.Tensor: 1-D float tensor with 1.0 where the predicted roles
        equal the ground truth and 0.0 otherwise.
    """
    rewards = []
    max_prompt_len = int(os.environ.get('MAX_PROMPT_LEN', '1024'))

    for query, prompt, label in zip(queries, prompts, labels):
        # Model response: the query with the prompt removed
        output = query[min(len(prompt), max_prompt_len):].strip()

        print("\n" + "="*80)
        print(" Processing New Sample ".center(80, '='))

        # Parse ground truth data
        gt_status = parse_solution_text_format(label)
        expected_names = list(gt_status.keys())
        print(f"[Ground Truth] Final identities: {gt_status}")

        # Extract model answer from the response only. Searching the whole
        # query would also pick up <answer> tags that appear in the prompt
        # whenever the response itself contains none.
        answer_text = extract_solution(output)
        print(f"\n[Model Response]\n{output}")

        answer_score = 0
        if answer_text:
            pred_status = parse_model_answer(answer_text, expected_names)
            if pred_status:
                if pred_status == gt_status:
                    answer_score = 1
                    print(" Content validation: FULL MATCH")
                else:
                    print(" Content validation: MISMATCH")
        rewards.append(answer_score)

    return torch.tensor(rewards, dtype=torch.float)
自定义奖励规则
如果开发场景比较复杂,或者预置的规则无法满足需求,您可以参考下述格式,自定义奖励规则。
import torch
import os
def reward_func(queries, prompts, labels):
    """
    Calculate rewards based on queries, prompts, and labels.

    Args:
        queries (list of str): Prompts + responses, i.e. the model's actual
            input and output.
        prompts (list of str): Input prompts given to the model.
        labels (list of str): Ground truth answers, i.e. the annotated
            model outputs.

    Returns:
        torch.Tensor: A tensor of rewards.
    """
    prompt_limit = int(os.environ.get('MAX_PROMPT_LEN', '1024'))

    # Extract content by removing the prompt from the query
    outputs = [
        query[min(len(prompt), prompt_limit):].strip()
        for query, prompt in zip(queries, prompts)
    ]

    # Rule-based reward process here
    # Ensure process() is defined and returns a list of rewards
    rewards = process(outputs, labels)

    # Convert rewards to a tensor
    return torch.tensor(rewards, dtype=torch.float)
函数名、输入、输出需要按照上述的规则定义。
reward_func代码定义后,可以通过下述的代码测试是否可用:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import sys
import shutil
import json
import glob
import random
import string
import importlib.util
from typing import List
import time
import torch
import asyncio
import signal
# Error payload written to error.json when validation fails. The functions
# below overwrite 'error_msg' with a more specific message before raising.
ERROR_CODE = {
    'error_code': 4103,
    'error_msg': '自定义奖励规则校验失败,请检查奖励规则是否正确。'
}
def write_error(workspace: str, error_code: dict):
    """
    Write the error payload to error.json inside the workspace directory.

    Failures while writing are reported on stderr and not raised.
    """
    target = os.path.join(workspace, 'error.json')
    try:
        with open(target, 'w', encoding='utf-8') as handle:
            json.dump(error_code, handle, ensure_ascii=False, indent=4)
        print(f"错误信息已写入 {target}")
    except Exception as e:
        print(f"无法写入错误信息到 {target}: {e}", file=sys.stderr)
def get_environment_variable(var_name: str) -> str:
    """
    Return the value of an environment variable.

    Raises:
        EnvironmentError: If the variable is unset or empty; the exception
            carries the ERROR_CODE dict with an updated message.
    """
    value = os.environ.get(var_name)
    if value:
        return value
    ERROR_CODE['error_msg'] = f'环境变量 {var_name} 未设置。'
    print(ERROR_CODE['error_msg'])
    raise EnvironmentError(ERROR_CODE)
def check_single_python_file(source_dir: str) -> str:
    """
    Check that the directory directly contains exactly one .py file.

    Returns:
        The path of that file.

    Raises:
        FileNotFoundError: If there are zero or several .py files; the
            exception carries the ERROR_CODE dict with an updated message.
    """
    candidates = glob.glob(os.path.join(source_dir, '*.py'))
    if len(candidates) == 1:
        return candidates[0]
    ERROR_CODE['error_msg'] = f'指定目录下有 {len(candidates)} 个 Python 文件。需要且只能有一个。'
    print(ERROR_CODE['error_msg'])
    raise FileNotFoundError(ERROR_CODE)
def move_file(src: str, dst: str):
    """
    Copy the file at src to dst.

    Despite the name, the source file is left in place: the contents are
    copied with shutil.copyfile.

    Raises:
        shutil.Error: If copying fails; the exception carries the ERROR_CODE
            dict with an updated message.
    """
    try:
        shutil.copyfile(src, dst)
    except Exception as e:
        ERROR_CODE['error_msg'] = f'移动文件失败: {e}'
        print(ERROR_CODE['error_msg'])
        raise shutil.Error(ERROR_CODE)
    else:
        print(f"文件已从 {src} 复制到 {dst}")
def load_test_data(test_data_path: str, num_samples: int = 20):
    """
    Load test samples from the first num_samples lines of a JSONL file.

    Each record is expected to hold non-empty lists under "src" and "tgt";
    only the first element of each list is used, stripped of surrounding
    whitespace. Blank lines and records without usable "src"/"tgt" lists are
    skipped.

    Returns:
        A (prompts, labels) pair of equally long lists of strings.

    Raises:
        ValueError: If the file cannot be read or parsed, or yields no usable
            sample; the exception carries the ERROR_CODE dict.
    """
    prompts = []
    labels = []
    try:
        with open(test_data_path, 'r', encoding='utf-8') as f:
            for _ in range(num_samples):
                line = f.readline()
                if not line:
                    break
                if not line.strip():
                    # A blank line (e.g. a trailing newline) is not a record.
                    continue
                data = json.loads(line)
                src = data.get('src', [])
                tgt = data.get('tgt', [])
                if isinstance(src, list) and isinstance(tgt, list) and src and tgt:
                    prompts.append(src[0].strip())
                    labels.append(tgt[0].strip())
        if not prompts or not labels:
            ERROR_CODE['error_msg'] = '测试数据中缺少 "src" 或 "tgt" 字段,或它们不是非空的列表。'
            raise ValueError(ERROR_CODE)
        print(f"已加载 {len(prompts)} 条测试数据。")
        return prompts, labels
    except Exception as e:
        # Errors that already carry the ERROR_CODE dict pass through as-is.
        # e.args may be empty, so check before indexing.
        if e.args and isinstance(e.args[0], dict):
            raise
        ERROR_CODE['error_msg'] = f'加载测试数据失败: {e}'
        print(ERROR_CODE['error_msg'])
        raise ValueError(ERROR_CODE) from e
def generate_queries(prompts: List[str]) -> List[str]:
    """
    Build one query per prompt by appending random alphanumeric characters.

    A prompt shorter than 100 characters gets between 1 and (100 - length)
    random characters appended; a longer prompt is cut to its first 100
    characters. No query exceeds 100 characters.
    """
    alphabet = string.ascii_letters + string.digits
    queries = []
    for prompt in prompts:
        room = 100 - len(prompt)
        if room > 0:
            suffix_len = random.randint(1, room)
            suffix = ''.join(random.choices(alphabet, k=suffix_len))
            queries.append(prompt + suffix)
        else:
            queries.append(prompt[:100])
    print("已生成 queries。")
    return queries
def import_reward_func(auto_py_path: str):
    """
    Dynamically import auto.py and return its reward_func.

    Raises:
        ImportError: If the module cannot be loaded or executed.
        AttributeError: If the module defines no reward_func.
        TypeError: If reward_func is not callable.
        Each exception carries the ERROR_CODE dict with a specific message.
    """
    try:
        spec = importlib.util.spec_from_file_location("auto", auto_py_path)
        auto = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(auto)
    except Exception as e:
        ERROR_CODE['error_msg'] = f'导入 reward_func 函数失败: {e}'
        print(ERROR_CODE['error_msg'])
        raise ImportError(ERROR_CODE) from e

    # The checks below run outside the try block so their specific errors are
    # not re-wrapped as a generic import failure. The earlier pass-through
    # test, isinstance(e, dict), could never be true for an exception.
    if not hasattr(auto, 'reward_func'):
        ERROR_CODE['error_msg'] = 'auto.py 中未找到 reward_func 函数。'
        print(ERROR_CODE['error_msg'])
        raise AttributeError(ERROR_CODE)

    reward_func = auto.reward_func
    if not callable(reward_func):
        ERROR_CODE['error_msg'] = 'reward_func 不是可调用的函数。'
        print(ERROR_CODE['error_msg'])
        raise TypeError(ERROR_CODE)

    print("成功导入 reward_func 函数。")
    return reward_func
def run_reward_func_sync(reward_func, queries: List[str], prompts: List[str], labels: List[str]):
    """Call reward_func with the three lists and return whatever it returns."""
    result = reward_func(queries, prompts, labels)
    return result
async def run_reward_func_async(reward_func, queries: List[str], prompts: List[str], labels: List[str], timeout_sec: int):
    """
    Run reward_func in the default executor with a timeout and validate its result.

    Returns:
        The value returned by reward_func.

    Raises:
        TimeoutError: If reward_func does not finish within timeout_sec
            seconds. The worker thread itself cannot be interrupted and keeps
            running in the background.
        RuntimeError: If reward_func raises.
        ValueError: If the result is not an iterable of float torch tensors
            (a 1-D float tensor satisfies this).
    """
    # Inside a coroutine the running loop is the one to use.
    loop = asyncio.get_running_loop()
    try:
        rewards = await asyncio.wait_for(
            loop.run_in_executor(None, reward_func, queries, prompts, labels),
            timeout=timeout_sec,
        )
    except asyncio.TimeoutError as e:
        ERROR_CODE['error_msg'] = f'reward_func 运行超过 {timeout_sec} 秒,已终止。'
        print(ERROR_CODE['error_msg'])
        raise TimeoutError(ERROR_CODE['error_msg']) from e
    except Exception as e:
        ERROR_CODE['error_msg'] = f'reward_func 运行失败: {e}'
        print(ERROR_CODE['error_msg'])
        raise RuntimeError(ERROR_CODE['error_msg']) from e

    # Validate the return value. Iterating a 1-D tensor yields 0-d tensors,
    # so a float tensor passes; a non-iterable result (e.g. a 0-d tensor or a
    # bare number) is reported as a format error instead of a raw TypeError.
    try:
        valid = all(isinstance(r, torch.Tensor) and r.dtype == torch.float for r in rewards)
    except TypeError:
        valid = False
    if not valid:
        ERROR_CODE['error_msg'] = 'reward_func 返回的结果格式不正确。'
        print(ERROR_CODE['error_msg'])
        raise ValueError(ERROR_CODE['error_msg'])

    print("reward_func 成功运行。")
    return rewards
def main():
    """
    Validate a user-supplied reward rule end to end.

    Steps: locate the single .py file under $WORKSPACE/rft_reward_func, copy
    it to the graders directory as auto.py, load up to 20 samples from the
    training data, build synthetic queries, import reward_func and run it
    with a 20 second timeout. Exits with status 0 on success; on any failure
    writes error.json into the workspace and exits with status 1.
    """
    # Fallback location for error.json in case WORKSPACE itself is missing.
    workspace = os.environ.get('WORKSPACE', '.')
    try:
        # Read environment variables
        workspace = get_environment_variable('WORKSPACE')
        source_dir = os.path.join(workspace, 'rft_reward_func')
        test_data_path = os.path.join(workspace, 'train_eval_data', 'sft_train.jsonl')
        destination_dir = '/qianfan/rudder_rl/openrlhf/graders'
        destination_path = os.path.join(destination_dir, 'auto.py')

        # Step 1: make sure the source directory holds exactly one Python file
        python_file = check_single_python_file(source_dir)

        # Step 2: copy the file to the graders directory, creating it if needed
        if not os.path.exists(destination_dir):
            os.makedirs(destination_dir)
            print(f"已创建目标目录 {destination_dir}")
        move_file(python_file, destination_path)

        # Step 3: load test data
        prompts, labels = load_test_data(test_data_path, num_samples=20)

        # Step 4: generate queries
        queries = generate_queries(prompts)

        # Step 5: import reward_func
        reward_func = import_reward_func(destination_path)

        # Step 6: run reward_func with a timeout
        asyncio.run(run_reward_func_async(reward_func, queries, prompts, labels, timeout_sec=20))

        # Everything succeeded
        print("自定义奖励规则校验成功。")
        sys.exit(0)
    except Exception as e:
        # Catch every failure and record the error code. Exceptions raised by
        # the helpers carry the ERROR_CODE dict as their first argument;
        # e.args may be empty for other exceptions, so check before indexing.
        if e.args and isinstance(e.args[0], dict):
            write_error(workspace, e.args[0])
        else:
            write_error(workspace, ERROR_CODE)
        sys.exit(1)


if __name__ == '__main__':
    main()
需要注意的是,当前支持 Python 3.10 版本的 .py 文件。奖励规则自定义完成后,将脚本上传至 BOS 存储路径,并在配置时选择该 .py 文件即可。