RFT Custom Reward Rules

The RFT training method ships with five preset reward rules. A reward rule defines how the quality of a model's output is scored. The code for each rule is given below for reference or modification.

1. String comparison (exact match)

from typing import List
import re
import os
import torch

def reward_func(queries: List[str], prompts: List[str], labels: List[str]) -> torch.Tensor:
    """
    Rule-based RLHF reward function.
    
    Args:
        queries (List[str]): List of concatenated prompt and content strings.
        prompts (List[str]): List of prompt strings.
        labels (List[str]): List of desired response strings.
    
    Returns:
        torch.Tensor: Tensor of reward scores corresponding to each query.
    """
    rewards = []
    max_prompt_len = os.environ.get('MAX_PROMPT_LEN', '1024')
    for query, prompt, label in zip(queries, prompts, labels):
        # Extract content by removing the prompt from the query
        content = query[min(int(max_prompt_len), len(prompt)):]
        
        # Define reward rules
        if content == label:
            reward = 1.0  # Exact match
        # elif is_semantically_similar(content, label):
        #     reward = 0.7  # Partial match based on semantic similarity
        elif has_keyword_overlap(content, label):
            reward = 0.5  # Partial match based on keyword overlap
        else:
            reward = 0.0  # No meaningful match
        
        rewards.append(reward)
    
    return torch.tensor(rewards, dtype=torch.float)

def has_keyword_overlap(text1: str, text2: str) -> bool:
    """
    Check if there is significant keyword overlap between two texts.
    
    Args:
        text1 (str): First text.
        text2 (str): Second text.
    
    Returns:
        bool: True if there's keyword overlap, False otherwise.
    """
    # Simple keyword extraction: lowercase and split by non-word characters
    keywords1 = set(re.findall(r'\b\w+\b', text1.lower()))
    keywords2 = set(re.findall(r'\b\w+\b', text2.lower()))
    
    overlap = keywords1.intersection(keywords2)
    overlap_ratio = len(overlap) / max(len(keywords1), len(keywords2))
    
    return overlap_ratio > 0.3  # Threshold can be adjusted
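
For a quick local sanity check of this rule, a usage sketch like the one below can be appended. It assumes MAX_PROMPT_LEN is set to a value no smaller than the prompt length (e.g. 1024) so that exactly the prompt is stripped from each query; the sample prompt and label are illustrative only.

# Example usage (illustrative sketch; assumes MAX_PROMPT_LEN >= prompt length)
if __name__ == "__main__":
    os.environ.setdefault('MAX_PROMPT_LEN', '1024')
    prompts = ["Translate 'bonjour' to English: "]
    labels = ["hello"]
    queries = [prompts[0] + "hello"]              # response exactly equals the label
    print(reward_func(queries, prompts, labels))  # tensor([1.])
    queries = [prompts[0] + "hello there"]        # only keyword overlap with the label
    print(reward_func(queries, prompts, labels))  # tensor([0.5000])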

2. String comparison (containment)

from typing import List
import os
import re
import torch

def reward_func(queries: List[str], prompts: List[str], labels: List[str]) -> torch.Tensor:
    """
    RLHF reward function based on set (Jaccard) similarity.
    
    Args:
        queries (List[str]): List of concatenated prompt and content strings.
        prompts (List[str]): List of prompt strings.
        labels (List[str]): List of desired response strings.
    
    Returns:
        torch.Tensor: Tensor of reward scores (0 to 1) corresponding to each query.
    """
    rewards = []
    max_prompt_len = os.environ.get('MAX_PROMPT_LEN', '1024')
    
    for query, prompt, label in zip(queries, prompts, labels):
        # Extract content by removing the prompt from the query
        content = query[min(int(max_prompt_len), len(prompt)):]
        
        set_content = set(tokenize(content))
        set_label = set(tokenize(label))
        if not set_content and not set_label:
            similarity = 1.0
        else:
            intersection = set_content.intersection(set_label)
            union = set_content.union(set_label)
            similarity = len(intersection) / len(union)
        # Ensure similarity is between 0 and 1
        similarity = max(0.0, min(similarity, 1.0))
        rewards.append(similarity)
    
    return torch.tensor(rewards, dtype=torch.float)

def tokenize(text: str) -> List[str]:
    """
    Tokenize the input text into words.
    
    Args:
        text (str): Input text.
    
    Returns:
        List[str]: List of word tokens.
    """
    return re.findall(r'\b\w+\b', text.lower())


# Example Usage
if __name__ == "__main__":
    queries = [
        "Prompt: How are you?\nI'm doing well, thank you!",
        "Prompt: Tell me a joke.\nWhy did the chicken cross the road?",
        "Prompt: What's the weather today?\nIt's sunny and warm."
    ]
    
    prompts = [
        "Prompt: How are you?",
        "Prompt: Tell me a joke.",
        "Prompt: What's the weather today?"
    ]
    
    labels = [
        "I'm doing well, thank you!",
        "Why did the chicken cross the road?",
        "It's sunny and warm."
    ]
    
    rewards_set_matching = reward_func(queries, prompts, labels)
    print("Set Matching-Based Rewards:", rewards_set_matching)
    # Output: [1.0, 1.0, 1.0]

3. String similarity comparison

from typing import List
import re
import os
import torch

def reward_func(queries: List[str], prompts: List[str], labels: List[str]) -> torch.Tensor:
    """
    RLHF reward function based on normalized edit distance.
    
    Args:
        queries (List[str]): List of concatenated prompt and content strings.
        prompts (List[str]): List of prompt strings.
        labels (List[str]): List of desired response strings.
    
    Returns:
        torch.Tensor: Tensor of reward scores (0 to 1) corresponding to each query.
    """
    rewards = []
    
    for query, prompt, label in zip(queries, prompts, labels):
        content = extract_content(query, prompt)
        distance = levenshtein_distance(content, label)
        max_len = max(len(content), len(label))
        if max_len == 0:
            similarity = 1.0
        else:
            similarity = 1 - (distance / max_len)
        # Ensure similarity is between 0 and 1
        similarity = max(0.0, min(similarity, 1.0))
        rewards.append(similarity)
    
    return torch.tensor(rewards, dtype=torch.float)

def extract_content(query: str, prompt: str) -> str:
    """
    Extract content from query by removing the prompt.
    
    Args:
        query (str): The concatenated prompt and content.
        prompt (str): The prompt part.
    
    Returns:
        str: The extracted content.
    """
    max_prompt_len = os.environ.get('MAX_PROMPT_LEN', '1024')
    # Extract content by removing the prompt from the query
    return query[min(int(max_prompt_len), len(prompt)):]

def levenshtein_distance(s1: str, s2: str) -> int:
    """
    Compute the Levenshtein distance between two strings.
    
    Args:
        s1 (str): First string.
        s2 (str): Second string.
    
    Returns:
        int: The Levenshtein distance.
    """
    if len(s1) < len(s2):
        return levenshtein_distance(s2, s1)

    # len(s1) >= len(s2)
    previous_row = list(range(len(s2) + 1))
    for i, c1 in enumerate(s1):
        current_row = [i + 1]
        for j, c2 in enumerate(s2):
            insertions = previous_row[j + 1] + 1  # insertion
            deletions = current_row[j] + 1        # deletion
            substitutions = previous_row[j] + (c1 != c2)  # substitution
            current_row.append(min(insertions, deletions, substitutions))
        previous_row = current_row
    
    return previous_row[-1]

# Example Usage
if __name__ == "__main__":
    queries = [
        "Prompt: How are you?\nI'm doing well, thank you!",
        "Prompt: Tell me a joke.\nWhy did the chicken cross the road?",
        "Prompt: What's the weather today?\nIt's sunny and warm."
    ]
    
    prompts = [
        "Prompt: How are you?",
        "Prompt: Tell me a joke.",
        "Prompt: What's the weather today?"
    ]
    
    labels = [
        "I'm doing well, thank you!",
        "Why did the chicken cross the road?",
        "It's sunny and warm."
    ]
    
    rewards_edit_distance = reward_func(queries, prompts, labels)
    print("Edit Distance-Based Rewards:", rewards_edit_distance)
    # Output: tensor([0.9630, 0.9722, 0.9524]); the leading newline kept in each
    # extracted response keeps the scores just below 1.0

4. Math answer matching

import asyncio
import re
from itertools import islice, zip_longest
import ray
from sympy.parsing.latex import parse_latex
from typing import Any, Awaitable, Callable, List, Optional, Tuple
import torch
import os
try:
    from math_verify import parse, verify
except ImportError:
    print("math_verify is not installed in this environment")
    parse = None
    verify = None


def repeatness(s: str):
    def ranks(l):
        index = {v: i for i, v in enumerate(sorted(set(l)))}
        return [index[v] for v in l]

    def suffixArray(s):
        line = ranks(s)
        n, k, ans, sa = len(s), 1, line, [0] * len(s)
        while k < n - 1:
            line = ranks(list(zip_longest(line, islice(line, k, None), fillvalue=-1)))
            ans, k = line, k << 1
        for i, k in enumerate(ans):
            sa[k] = i
        return ans, sa

    def lcp(arr, suffixArr, inv_suff):
        n, ans, k = len(arr), [0] * len(arr), 0

        for i in range(n):
            if inv_suff[i] == n - 1:
                k = 0
                continue

            j = suffixArr[inv_suff[i] + 1]
            while i + k < n and j + k < n and arr[i + k] == arr[j + k]:
                k += 1

            ans[inv_suff[i]] = k
            if k > 0:
                k -= 1

        return ans

    arr = [ord(i) for i in s]
    n = len(arr)
    if n <= 1:
        return 0
    c, sa = suffixArray(arr)
    cnt = sum(lcp(arr, sa, c))

    return (cnt * 2 / (n * (n + 1))) > 0.2


SUBSTITUTIONS = [
    ("an ", ""),
    ("a ", ""),
    (".$", "$"),
    ("\\$", ""),
    (r"\ ", ""),
    (" ", ""),
    ("mbox", "text"),
    (",\\text{and}", ","),
    ("\\text{and}", ","),
    ("\\text{m}", "\\text{}"),
]


REMOVED_EXPRESSIONS = [
    "square",
    "ways",
    "integers",
    "dollars",
    "mph",
    "inches",
    "ft",
    "hours",
    "km",
    "units",
    "\\ldots",
    "sue",
    "points",
    "feet",
    "minutes",
    "digits",
    "cents",
    "degrees",
    "cm",
    "gm",
    "pounds",
    "meters",
    "meals",
    "edges",
    "students",
    "childrentickets",
    "multiples",
    "\\text{s}",
    "\\text{.}",
    "\\text{\ns}",
    "\\text{}^2",
    "\\text{}^3",
    "\\text{\n}",
    "\\text{}",
    r"\mathrm{th}",
    r"^\circ",
    r"^{\circ}",
    r"\;",
    r",\!",
    "{,}",
    '"',
    "\\dots",
]


def normalize_final_answer(final_answer: str) -> str:
    """
    Normalize a final answer to a quantitative reasoning question.
    This code comes from https://arxiv.org/pdf/2206.14858.pdf, page18.
    """
    # final_answer = final_answer.split("=")[-1]

    for before, after in SUBSTITUTIONS:
        final_answer = final_answer.replace(before, after)
    for expr in REMOVED_EXPRESSIONS:
        final_answer = final_answer.replace(expr, "")

    # Extract answer that is in LaTeX math, is bold,
    # is surrounded by a box, etc.
    final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer)
    final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer)
    final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer)
    final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer)
    final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer)

    # Normalize shorthand TeX:
    # \fracab -> \frac{a}{b}
    # \frac{abc}{bef} -> \frac{abc}{bef}
    # \fracabc -> \frac{a}{b}c
    # \sqrta -> \sqrt{a}
    # \sqrtab -> sqrt{a}b
    final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer)
    final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer)
    final_answer = final_answer.replace("$", "")

    # Normalize 100,000 -> 100000
    if final_answer.replace(",", "").isdigit():
        final_answer = final_answer.replace(",", "")

    return final_answer


def latex_eval(latex):
    sym = parse_latex(latex)
    val = sym.evalf()
    return sym, val


def _is_latex_equal(str1, str2):
    try:
        sym1, val1 = latex_eval(str1)
        sym2, val2 = latex_eval(str2)
        if sym1 == sym2 or val1 == val2:
            return True
        else:
            raise ValueError
    except Exception:  # noqa
        try:
            norm1, norm2 = normalize_final_answer(str1), normalize_final_answer(str2)
            sym1, val1 = latex_eval(norm1)
            sym2, val2 = latex_eval(norm2)
            if sym1 == sym2 or val1 == val2:
                return True
        except Exception:  # noqa
            return norm1 == norm2
    return False

def _fix_fracs(string):
    substrs = string.split("\\frac")
    new_str = substrs[0]
    if len(substrs) > 1:
        substrs = substrs[1:]
        for substr in substrs:
            new_str += "\\frac"
            if substr[0] == "{":
                new_str += substr
            else:
                try:
                    assert len(substr) >= 2
                except Exception:  # noqa
                    return string
                a = substr[0]
                b = substr[1]
                if b != "{":
                    if len(substr) > 2:
                        post_substr = substr[2:]
                        new_str += "{" + a + "}{" + b + "}" + post_substr
                    else:
                        new_str += "{" + a + "}{" + b + "}"
                else:
                    if len(substr) > 2:
                        post_substr = substr[2:]
                        new_str += "{" + a + "}" + b + post_substr
                    else:
                        new_str += "{" + a + "}" + b
    string = new_str
    return string


def _fix_a_slash_b(string):
    if len(string.split("/")) != 2:
        return string
    a = string.split("/")[0]
    b = string.split("/")[1]
    try:
        a = int(a)
        b = int(b)
        assert string == "{}/{}".format(a, b)
        new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
        return new_string
    except Exception:  # noqa
        return string


def _remove_right_units(string):
    # "\\text{ " only ever occurs (at least in the val set) when describing units
    if "\\text{ " in string:
        splits = string.split("\\text{ ")
        assert len(splits) == 2
        return splits[0]
    else:
        return string


def _fix_sqrt(string):
    if "\\sqrt" not in string:
        return string
    splits = string.split("\\sqrt")
    new_string = splits[0]
    for split in splits[1:]:
        if split[0] != "{":
            a = split[0]
            new_substr = "\\sqrt{" + a + "}" + split[1:]
        else:
            new_substr = "\\sqrt" + split
        new_string += new_substr
    return new_string


def _strip_string(string):
    # linebreaks
    string = string.replace("\n", "")
    # print(string)

    # remove inverse spaces
    string = string.replace("\\!", "")
    # print(string)

    # replace \\ with \
    string = string.replace("\\\\", "\\")
    # print(string)

    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")
    # print(string)

    # remove \left and \right
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")
    # print(string)

    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # remove dollar signs
    string = string.replace("\\$", "")
    string = string.replace("$", "")
    string = string.replace(",", "")

    # remove units (on the right)
    string = _remove_right_units(string)

    # remove percentage
    string = string.replace("\\%", "")
    string = string.replace(r"\%", "")

    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")
    # if empty, return empty string
    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string

    # to consider: get rid of e.g. "k = " or "q = " at beginning
    if len(string.split("=")) == 2:
        if len(string.split("=")[0]) <= 2:
            string = string.split("=")[1]

    # fix sqrt3 --> sqrt{3}
    string = _fix_sqrt(string)

    # remove spaces
    string = string.replace(" ", "")

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
    string = _fix_fracs(string)

    # manually change 0.5 --> \frac{1}{2}
    if string == "0.5":
        string = "\\frac{1}{2}"

    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
    string = _fix_a_slash_b(string)

    return string


def is_equiv(str1, str2, verbose=False) -> bool:
    if str1 is None and str2 is None:
        print("WARNING: Both None")
        return True
    if str1 is None or str2 is None:
        return False

    try:
        ss1 = _strip_string(str1)
        ss2 = _strip_string(str2)
        if verbose:
            print(ss1, ss2)
        try:
            return float(ss1) == (float(ss2))
        except Exception:  # noqa
            return ss1 == ss2
    except Exception:  # noqa
        return str1 == str2


def last_boxed_only_string(string):
    idx = string.rfind("\\boxed")
    if idx < 0:
        idx = string.rfind("\\fbox")
        if idx < 0:
            return None

    i = idx
    right_brace_idx = None
    num_left_braces_open = 0
    while i < len(string):
        if string[i] == "{":
            num_left_braces_open += 1
        if string[i] == "}":
            num_left_braces_open -= 1
            if num_left_braces_open == 0:
                right_brace_idx = i
                break
        i += 1

    if right_brace_idx is None:
        retval = None
    else:
        retval = string[idx : right_brace_idx + 1]

    return retval


def remove_boxed(s):
    left = "\\boxed{"
    try:
        assert s[: len(left)] == left
        assert s[-1] == "}"
        return s[len(left) : -1]
    except Exception:
        return None


def get_answer_str(s: str) -> str:
    res = remove_boxed(last_boxed_only_string(s))
    if res is not None:
        return res
    return s


def is_equal(str1, str2, math_mode="legacy"):
    first_equal = is_equiv(str1, str2)
    if first_equal:
        return True
    return is_latex_equal(str1, str2, math_mode)


def solution2answer(solution: str, math_mode="eval_peeking") -> str:
    answer = solution
    if math_mode == "eval_peeking":
        answer = get_answer_str(solution)
    else:
        raise ValueError(f"Invalid math_mode: {math_mode}")
    return answer


def get_final_answer(output: str) -> str:
    output = output.replace("is:", "is").replace("answer:", "answer is").strip()
    if output.endswith("."):
        output = output[:-1]
    if ".$" in output:
        output = output.replace(".$", "$")
    pattern_list = [
        r"answer is (-?\d+\.?\d*)$",
        r"answer is (.+?)$",
    ]
    matches = []
    for pat in pattern_list:
        matches = re.findall(pat, output, re.S)
        if matches:
            return get_answer_str(matches[0])

    return get_answer_str(output)

@ray.remote(num_cpus=1)
def extract_final_answers_batch(responses: List[str]) -> List[str]:
    # pattern = re.compile(r"(\\boxed{.*})")
    pattern = re.compile(r"<answer>.*?(\\boxed{.*}).*?</answer>", re.DOTALL)
    results = []
    for response in responses:
        matches = re.findall(pattern, response)
        results.append(matches[-1] if matches else "")
    return results


def is_latex_equal(str1: str, str2: str, math_mode: str = "legacy") -> bool:
    """
    Synchronously check whether two LaTeX strings are mathematically equivalent.
    """
    if math_mode == "legacy":
        # Reject highly repetitive (degenerate) outputs
        if (len(str1) > 128 and repeatness(str1)) or (len(str2) > 128 and repeatness(str2)):
            return False

        try:
            # Call the synchronous comparison helper directly
            return _is_latex_equal(str1, str2)
        except Exception:
            return False
    elif math_mode == "math_verify":
        try:
            # Compare via math_verify's parse/verify
            return verify(parse(str1), parse(str2))
        except Exception:
            return False
    else:
        raise NotImplementedError(f"Math mode {math_mode} is not implemented")

def reward_func(queries, prompts, labels):
    # queries is prompts + responses
    # labels is answers
    rewards = []
    outputs = []
    max_prompt_len = os.environ.get('MAX_PROMPT_LEN', '1024')
    for query, prompt in zip(queries, prompts):
        # Extract content by removing the prompt from the query
        max_len = min(len(prompt), int(max_prompt_len))
        output = query[max_len:].strip()
        outputs.append(output)
        
    # Extract the final answers in a distributed fashion via ray
    final_answers = ray.get(extract_final_answers_batch.remote(outputs))


    for label, final_answer in zip(labels, final_answers):
        result = is_equal(solution2answer(label), solution2answer(final_answer))
        score = 1.0 if result else 0.0
        rewards.append(score)        

    # print('rewards are', rewards)
    return torch.tensor(rewards, dtype=torch.float)
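
For reference, a minimal local invocation of this rule might look like the sketch below. It assumes ray and sympy are installed (Ray starts a local instance automatically on the first remote call) and that MAX_PROMPT_LEN is set; the sample question is illustrative only.

# Example usage (illustrative sketch; assumes ray and sympy are available)
if __name__ == "__main__":
    os.environ.setdefault('MAX_PROMPT_LEN', '1024')
    prompts = ["What is 2 + 3? Put the final answer in \\boxed{}."]
    queries = [prompts[0] + " <answer>The sum is \\boxed{5}.</answer>"]
    labels = ["\\boxed{5}"]
    print(reward_func(queries, prompts, labels))  # tensor([1.])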

5. Logical reasoning matching

import re
from typing import Dict, Tuple, Optional
import torch
import os

def extract_solution(solution_str: str) -> Optional[str]:
    """Extracts the final answer from the model's response string.
    
    Args:
        solution_str: Raw response string from the language model
        
    Returns:
        The extracted answer string, or None if no valid answer tags are found
    """

    # Extract final answer using XML-style tags
    answer_pattern = r'<answer>(.*?)</answer>'
    matches = list(re.finditer(answer_pattern, solution_str, re.DOTALL))
    
    if not matches:
        print("[Error] No valid answer tags found")
        return None
        
    final_answer = matches[-1].group(1).strip()
    return final_answer

def parse_solution_text_format(solution_text: str) -> Dict[str, str]:
    """Parses ground truth solution text into status dictionary.
    
    Args:
        solution_text: Formatted solution text from dataset
        
    Returns:
        Dictionary mapping character names to their roles (knight/knave)
    """
    status_dict = {}
    print("\n[Ground Truth Parsing]")
    
    for line in solution_text.split('\n'):
        line = line.strip()
        if not line:
            continue
            
        match = re.search(r'\b([A-Za-z]+)\b.*?\b(knight|knave)\b', line, re.IGNORECASE)
        if match:
            name, role = match.groups()
            status_dict[name] = role.lower()
            print(f"  Found: {name} → {role}")
        else:
            print(f"  [Warning] Unparseable line: '{line}'")
    
    return status_dict

def parse_model_answer(answer_text: str, expected_names: list) -> Optional[Dict[str, str]]:
    """Parses model's answer text into status dictionary.
    
    Args:
        answer_text: Text extracted from model's <answer> tags
        expected_names: List of character names requiring identification
        
    Returns:
        Dictionary mapping character names to predicted roles, or None if incomplete
    """
    status_dict = {}
    print("\n[Model Answer Parsing]")
    print(f"  Expected characters: {expected_names}")

    knight_count = answer_text.lower().count('knight')
    knave_count = answer_text.lower().count('knave')

    print(f"  Number of predicted roles: {knight_count + knave_count}")
    if knight_count + knave_count != len(expected_names):
        print(f"  [Error] Number of characters mismatch: {knight_count + knave_count} != {len(expected_names)}")
        return None

    for name in expected_names:
        pattern = re.compile(
            rf'\b{re.escape(name)}\b\s+is\s+a\s+\b(knight|knave)\b', 
            re.IGNORECASE
        )
        match = pattern.search(answer_text)
        
        if match:
            role = match.group(1).lower()
            status_dict[name] = role
            print(f"  Found: {name} → {role}")
        else:
            print(f"  [Error] Missing identification for {name}")
            return None
    
    return status_dict

def validate_response_structure(processed_str: str) -> bool:
    """Performs comprehensive validation of response structure.
    
    Args:
        processed_str: Processed response string from the model
        
    Returns:
        Boolean indicating whether all formatting requirements are met
    """
    print("\n[Structure Validation]")
    validation_passed = True

    # Check required tags
    tags = {
        # 'think_start': ('<think>', 1),
        'think_end': ('</think>', 1),
        'answer_start': ('<answer>', 1),
        'answer_end': ('</answer>', 1)
    }

    positions = {}
    for tag_name, (tag_str, expected_count) in tags.items():
        count = processed_str.count(tag_str)
        positions[tag_name] = pos = processed_str.find(tag_str)
        
        print(f"  {tag_str}: count={count}, position={pos}")
        
        if count != expected_count:
            print(f"  [Error] {tag_str} appears {count} times (expected {expected_count})")
            validation_passed = False

    # Verify tag order
    if (positions['think_end'] > positions['answer_start'] or
        positions['answer_start'] > positions['answer_end']):
        print("  [Error] Incorrect tag order: Expected <think>...</think><answer>...</answer>")
        validation_passed = False
    else:
        print("  Tag sequence validation passed")

    return validation_passed

def reward_func(queries, prompts, labels):
    """Computes the answer reward for each model response.
    
    Args:
        queries: Prompts concatenated with model responses
        prompts: Input prompts
        labels: Ground truth solution texts
    Returns:
        torch.Tensor of per-sample rewards (1.0 for a full match, otherwise 0.0)
    """
    rewards = []
    max_prompt_len = os.environ.get('MAX_PROMPT_LEN', '1024')
    for query, prompt, label in zip(queries, prompts, labels):
        # format_reward: int = 1
        max_len = min(len(prompt), int(max_prompt_len))
        output = query[max_len:].strip()
        print("\n" + "="*80)
        print(" Processing New Sample ".center(80, '='))
        
        # Parse ground truth data
        solution_text = label
        gt_status = parse_solution_text_format(solution_text)
        expected_names = list(gt_status.keys())
        print(f"[Ground Truth] Final identities: {gt_status}")

        # Extract model answer from the response portion
        answer_text = extract_solution(output)
        print(f"\n[Model Response]\n{output}")

        answer_score = 0
        if answer_text:
            pred_status = parse_model_answer(answer_text, expected_names)
            if pred_status:
                if pred_status == gt_status:
                    answer_score = 1
                    print("  Content validation: FULL MATCH")
                else:
                    answer_score = 0
                    print("  Content validation: MISMATCH")
        rewards.append(answer_score)

    return torch.tensor(rewards, dtype=torch.float)
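
A minimal local check of this rule could look like the following sketch; the puzzle text and the </think><answer>...</answer> response format are illustrative assumptions.

# Example usage (illustrative sketch; puzzle and answer format are assumptions)
if __name__ == "__main__":
    os.environ.setdefault('MAX_PROMPT_LEN', '1024')
    prompts = ["Alice claims Bob is a knave. Bob claims they are both knights. Who is the knight and who is the knave?"]
    queries = [prompts[0] + " ...reasoning...</think>"
               "<answer>Alice is a knight. Bob is a knave.</answer>"]
    labels = ["Alice is a knight.\nBob is a knave."]
    print(reward_func(queries, prompts, labels))  # tensor([1.])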

Custom Reward Rules

If your scenario is more complex, or the preset rules do not meet your needs, you can define a custom reward rule following the format below.

import torch
import os

def reward_func(queries, prompts, labels):
    """
    Calculate rewards based on queries, prompts, and labels.
    
    Args:
        queries (list of str): Prompts + responses, i.e. the model's actual input concatenated with its output.
        prompts (list of str): Input prompts, i.e. the model's input.
        labels (list of str): Ground truth answers, i.e. the annotated target outputs.
    
    Returns:
        torch.Tensor: A tensor of rewards.
    """
    rewards = []
    outputs = []
    max_prompt_len = int(os.environ.get('MAX_PROMPT_LEN', '1024'))
    
    for query, prompt in zip(queries, prompts):        
        # Extract content by removing the prompt from the query
        max_len = min(len(prompt), max_prompt_len)
        output = query[max_len:].strip()
        outputs.append(output)
    
    # Rule-based reward process here
    # Ensure process() is defined and returns a list of rewards
    rewards = process(outputs, labels) 
    
    # Convert rewards to a tensor
    return torch.tensor(rewards, dtype=torch.float)

The function name, inputs, and outputs must follow the conventions above.
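
As a concrete illustration of this contract, the sketch below defines a hypothetical rule that gives full credit when the label appears verbatim in the response; it is only an example, not one of the preset rules.

import torch
import os

def reward_func(queries, prompts, labels):
    """Hypothetical example rule: reward responses that contain the label verbatim."""
    rewards = []
    max_prompt_len = int(os.environ.get('MAX_PROMPT_LEN', '1024'))
    for query, prompt, label in zip(queries, prompts, labels):
        # Extract the response by removing the prompt from the query
        output = query[min(len(prompt), max_prompt_len):].strip()
        if label.strip() and label.strip() in output:
            # Full credit when the label appears; mild penalty for overly long responses
            reward = 1.0 if len(output) <= 4 * len(label) else 0.8
        else:
            reward = 0.0
        rewards.append(reward)
    return torch.tensor(rewards, dtype=torch.float)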

After reward_func is defined, you can use the following script to check that it works:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import sys
import shutil
import json
import glob
import random
import string
import importlib.util
from typing import List
import time
import torch
import asyncio
import signal

# Error code definition
ERROR_CODE = {
    'error_code': 4103,
    'error_msg': 'Custom reward rule validation failed. Please check whether the reward rule is correct.'
}

def write_error(workspace: str, error_code: dict):
    """
    Write the error code to error.json.
    """
    error_path = os.path.join(workspace, 'error.json')
    try:
        with open(error_path, 'w', encoding='utf-8') as f:
            json.dump(error_code, f, ensure_ascii=False, indent=4)
        print(f"Error info written to {error_path}")
    except Exception as e:
        print(f"Failed to write error info to {error_path}: {e}", file=sys.stderr)

def get_environment_variable(var_name: str) -> str:
    """
    Get the value of an environment variable; raise if it is not set.
    """
    value = os.environ.get(var_name)
    if not value:
        ERROR_CODE['error_msg'] = f'Environment variable {var_name} is not set.'
        print(ERROR_CODE['error_msg'])
        raise EnvironmentError(ERROR_CODE)
    return value

def check_single_python_file(source_dir: str) -> str:
    """
    Check that the given directory contains exactly one Python file.
    Returns the path of that file.
    """
    python_files = glob.glob(os.path.join(source_dir, '*.py'))
    if len(python_files) != 1:
        ERROR_CODE['error_msg'] = f'Found {len(python_files)} Python files in the directory; exactly one is required.'
        print(ERROR_CODE['error_msg'])
        raise FileNotFoundError(ERROR_CODE)
    return python_files[0]

def move_file(src: str, dst: str):
    """
    Copy the file from src to dst.
    """
    try:
        shutil.copyfile(src, dst)
        print(f"File copied from {src} to {dst}")
    except Exception as e:
        ERROR_CODE['error_msg'] = f'Failed to copy file: {e}'
        print(ERROR_CODE['error_msg'])
        raise shutil.Error(ERROR_CODE)

def load_test_data(test_data_path: str, num_samples: int = 20):
    """
    Load the first num_samples entries of test data from a JSONL file.
    Returns lists of prompts and labels, taking only the first element of each list field.
    """
    prompts = []
    labels = []
    try:
        with open(test_data_path, 'r', encoding='utf-8') as f:
            for _ in range(num_samples):
                line = f.readline()
                if not line:
                    break
                data = json.loads(line)
                src = data.get('src', [])
                tgt = data.get('tgt', [])
                if isinstance(src, list) and isinstance(tgt, list) and src and tgt:
                    prompt = src[0].strip()
                    label = tgt[0].strip()
                    prompts.append(prompt)
                    labels.append(label)
        if not prompts or not labels:
            ERROR_CODE['error_msg'] = 'Test data is missing the "src" or "tgt" field, or they are not non-empty lists.'
            raise ValueError(ERROR_CODE)
        print(f"Loaded {len(prompts)} test samples.")
        return prompts, labels
    except Exception as e:
        if e.args and isinstance(e.args[0], dict):
            raise e
        ERROR_CODE['error_msg'] = f'Failed to load test data: {e}'
        print(ERROR_CODE['error_msg'])
        raise ValueError(ERROR_CODE)

def generate_queries(prompts: List[str]) -> List[str]:
    """
    Generate a query for each prompt by appending random characters,
    keeping the total length within 100 characters.
    """
    queries = []
    for prompt in prompts:
        max_extra_length = 100 - len(prompt)
        if max_extra_length <= 0:
            query = prompt[:100]
        else:
            random_length = random.randint(1, max_extra_length)
            random_chars = ''.join(random.choices(string.ascii_letters + string.digits, k=random_length))
            query = prompt + random_chars
        queries.append(query)
    print("Queries generated.")
    return queries

def import_reward_func(auto_py_path: str):
    """
    Dynamically import auto.py and fetch the reward_func function.
    """
    try:
        spec = importlib.util.spec_from_file_location("auto", auto_py_path)
        auto = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(auto)
        if not hasattr(auto, 'reward_func'):
            ERROR_CODE['error_msg'] = 'reward_func was not found in auto.py.'
            raise AttributeError(ERROR_CODE)
        reward_func = auto.reward_func
        if not callable(reward_func):
            ERROR_CODE['error_msg'] = 'reward_func is not callable.'
            raise TypeError(ERROR_CODE)
        print("Successfully imported reward_func.")
        return reward_func
    except Exception as e:
        if e.args and isinstance(e.args[0], dict):
            raise e
        ERROR_CODE['error_msg'] = f'Failed to import reward_func: {e}'
        print(ERROR_CODE['error_msg'])
        raise ImportError(ERROR_CODE)

def run_reward_func_sync(reward_func, queries: List[str], prompts: List[str], labels: List[str]):
    return reward_func(queries, prompts, labels)

async def run_reward_func_async(reward_func, queries: List[str], prompts: List[str], labels: List[str], timeout_sec: int):
    loop = asyncio.get_event_loop()
    try:
        rewards = await asyncio.wait_for(loop.run_in_executor(None, run_reward_func_sync, reward_func, queries, prompts, labels), timeout=timeout_sec)
    except asyncio.TimeoutError:
        ERROR_CODE['error_msg'] = f'reward_func ran for more than {timeout_sec} seconds and was terminated.'
        print(ERROR_CODE['error_msg'])
        raise TimeoutError(ERROR_CODE['error_msg'])
    except Exception as e:
        ERROR_CODE['error_msg'] = f'reward_func failed to run: {e}'
        print(ERROR_CODE['error_msg'])
        raise RuntimeError(ERROR_CODE['error_msg'])

    # Check the returned value
    if not all(isinstance(r, torch.Tensor) and r.dtype == torch.float for r in rewards):
        ERROR_CODE['error_msg'] = 'The result returned by reward_func has an incorrect format.'
        print(ERROR_CODE['error_msg'])
        raise ValueError(ERROR_CODE['error_msg'])

    print("reward_func ran successfully.")
    return rewards

def main():
    try:
        # Get environment variables
        workspace = get_environment_variable('WORKSPACE')
        source_dir = os.path.join(workspace, 'rft_reward_func')
        test_data_path = os.path.join(workspace, 'train_eval_data', 'sft_train.jsonl')
        error_output_path = os.path.join(workspace, 'error.json')
        destination_dir = '/qianfan/rudder_rl/openrlhf/graders'
        destination_path = os.path.join(destination_dir, 'auto.py')
        
        # Step 1: check that the source directory contains exactly one Python file
        python_file = check_single_python_file(source_dir)
        
        # Step 2: copy the file to the target directory,
        # making sure the target directory exists first
        if not os.path.exists(destination_dir):
            os.makedirs(destination_dir)
            print(f"Created target directory {destination_dir}")
        
        move_file(python_file, destination_path)
        
        # Step 3: load the test data
        prompts, labels = load_test_data(test_data_path, num_samples=20)
        
        # Step 4: generate queries
        queries = generate_queries(prompts)
        
        # Step 5: import the reward_func function
        reward_func = import_reward_func(destination_path)
        
        # Step 6: run reward_func with a timeout
        asyncio.run(run_reward_func_async(reward_func, queries, prompts, labels, timeout_sec=20))

        # Everything succeeded; exit normally
        print("Custom reward rule validation succeeded.")
        sys.exit(0)
        
    except Exception as e:
        # Catch all exceptions and write the error code
        try:
            workspace  # make sure the workspace variable is defined
        except NameError:
            # If fetching the environment variable failed, fall back to WORKSPACE or the current directory
            workspace = os.environ.get('WORKSPACE', '.')
        if e.args and isinstance(e.args[0], dict):
            write_error(workspace, e.args[0])
        else:
            write_error(workspace, ERROR_CODE)
        sys.exit(1)

if __name__ == '__main__':
    main()
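
Judging from how the script reads its inputs, the directory pointed to by the WORKSPACE environment variable is expected to contain rft_reward_func/ with exactly one .py file and train_eval_data/sft_train.jsonl, where each line is a JSON object whose "src" and "tgt" fields are non-empty lists, for example:

{"src": ["What is 1 + 1?"], "tgt": ["2"]}

With WORKSPACE set accordingly, running the script imports the reward function, feeds it 20 sampled prompts with randomly padded queries, and writes error.json if validation fails.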

Note that only .py files targeting Python 3.10 are currently supported. After customizing the reward rule, upload the script to a BOS storage path and select the designated .py file.
