# Source code for oumi.datasets.grpo.rewards.countdown_rewards

# Copyright 2025 - Jiayi Pan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Derived from https://github.com/Jiayi-Pan/TinyZero/blob/main/verl/utils/reward_score/countdown.py.

This file was slightly modified to be an Oumi reward registry function.
"""

import re
from typing import Any, Optional

from oumi.core.registry import RegistryType, register


def _extract_solution(solution_str: str) -> Optional[str]:
    """Extracts the equation from the solution string.

    Args:
        solution_str: The response from the LLM.

    Returns:
        The equation from the solution string, or None if not found.
    """
    solution_str = solution_str.split("\n")[-1]

    answer_pattern = r"<answer>(.*?)</answer>"
    match = re.finditer(answer_pattern, solution_str)
    matches = list(match)
    if matches:
        final_answer = matches[-1].group(1).strip()
    else:
        final_answer = None
    return final_answer


def _validate_equation(equation_str: str, available_numbers: list[int]) -> bool:
    """Validates that equation only uses available numbers and each number once.

    Args:
        equation_str: The equation to validate.
        available_numbers: The list of available numbers.

    Returns:
        True if the equation uses each available number exactly once, else False.
    """
    try:
        # Extract all numbers from the equation
        numbers_in_eq = [int(n) for n in re.findall(r"\d+", equation_str)]

        # Check if all numbers in equation are available
        available_numbers = sorted(available_numbers)
        numbers_in_eq = sorted(numbers_in_eq)

        # Each number should be used exactly once
        return numbers_in_eq == available_numbers
    except Exception:
        return False


def _evaluate_equation(equation_str: str) -> Optional[float]:
    """Safely evaluates the arithmetic equation using eval() with precautions."""
    try:
        # Regex that only allows numbers, operators, parentheses and whitespace
        allowed_pattern = r"^[\d+\-*/().\s]+$"
        if not re.match(allowed_pattern, equation_str):
            raise ValueError("Invalid characters in equation.")

        # Evaluate the equation with restricted globals and locals
        result = eval(equation_str, {"__builtins__": None}, {})
        return result
    except Exception:
        return None


@register("countdown", RegistryType.REWARD_FUNCTION)
def countdown_reward(
    data_source: str,
    solution_str: str,
    ground_truth: dict[str, Any],
    extra_info: dict[str, Any],
    format_score: float = 0.0,
    score: float = 1.0,
) -> float:
    """Custom reward function for the Countdown task.

    Currently, this function only works with the VERL_GRPO trainer.

    Args:
        data_source: The data source. Not used by this function.
        solution_str: The response from the LLM.
        ground_truth: Dictionary containing the target number under
            ``"target"`` and the available numbers under ``"numbers"``.
        extra_info: Extra information about the sample. Not used by this
            function.
        format_score: The score for correct format but wrong answer.
        score: The score for the correct answer.

    Returns:
        `score` if the equation is valid and correct, `format_score` if the
        answer was parsed properly but the equation is incorrect, `0.0` if the
        answer was not parsed properly.

    Raises:
        KeyError: If ``ground_truth`` lacks the ``"target"`` or ``"numbers"``
            key.
    """
    target = ground_truth["target"]
    numbers = ground_truth["numbers"]

    equation = _extract_solution(solution_str=solution_str)
    if equation is None:
        # No <answer> tag could be parsed: no credit at all.
        return 0.0

    # The equation must use exactly the provided numbers.
    if not _validate_equation(equation, numbers):
        return format_score

    try:
        result = _evaluate_equation(equation)
        if result is None:
            return format_score

        # Tolerance accounts for floating point error introduced by division.
        if abs(result - target) < 1e-5:
            return score
        return format_score
    except Exception:
        # A result or target that cannot be compared numerically still counts
        # as a well-formatted but wrong answer.
        return format_score