Source code for oumi.datasets.grpo.rewards.countdown_rewards
# Copyright 2025 - Jiayi Pan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Derived from https://github.com/Jiayi-Pan/TinyZero/blob/main/verl/utils/reward_score/countdown.py.

This file was slightly modified to be an Oumi reward registry function.
"""

import re
from typing import Any, Optional

from oumi.core.registry import RegistryType, register

# Captures the text between <answer> and the next </answer> on one line.
# Non-greedy, so several tagged answers on a line yield several matches.
_ANSWER_PATTERN = re.compile(r"<answer>(.*?)</answer>")

# Whole numeric literals, including decimal forms such as "2.5", "3." and
# ".5". A decimal is therefore seen as one literal, not as separate integers.
_NUMBER_LITERAL_PATTERN = re.compile(r"\d+(?:\.\d*)?|\.\d+")

# Character whitelist for equations passed to eval(): digits, + - * /,
# parentheses, the decimal point and whitespace.
_ALLOWED_EQUATION_PATTERN = re.compile(r"^[\d+\-*/().\s]+$")


def _extract_solution(solution_str: str) -> Optional[str]:
    """Extracts the equation from the solution string.

    Only the text after the last newline is searched. The
    ``<answer>...</answer>`` tags must therefore sit on the final line.

    Args:
        solution_str: The response from the LLM.

    Returns:
        The whitespace-stripped contents of the last ``<answer>...</answer>``
        pair on the final line, or None if that line has no such pair.
    """
    # NOTE(review): a response ending with a newline after </answer> has an
    # empty final line, so this returns None. Confirm this strictness is
    # intended before relaxing it.
    last_line = solution_str.split("\n")[-1]
    answers = _ANSWER_PATTERN.findall(last_line)
    if not answers:
        return None
    return answers[-1].strip()


def _validate_equation(equation_str: str, available_numbers: list[int]) -> bool:
    """Validates that equation only uses available numbers and each number once.

    Numbers must be written as integer literals. Any decimal literal (e.g.
    ``2.5``) makes the equation invalid.

    Args:
        equation_str: The equation to validate.
        available_numbers: The list of available numbers.

    Returns:
        True if the equation's integer literals are exactly the available
        numbers (as a multiset), else False.
    """
    literals = _NUMBER_LITERAL_PATTERN.findall(equation_str)

    # Countdown numbers are integers. Splitting on digits alone would read
    # "2.5" as the numbers 2 and 5, letting "2.5*10" pass for [2, 5, 10].
    if any("." in literal for literal in literals):
        return False

    try:
        # Each number should be used exactly once: compare sorted multisets.
        numbers_in_eq = sorted(int(literal) for literal in literals)
        return numbers_in_eq == sorted(available_numbers)
    except (TypeError, ValueError):
        # TypeError: available_numbers is not a sortable iterable.
        # ValueError: a literal exceeds Python's int string-conversion limit.
        return False


def _evaluate_equation(equation_str: str) -> Optional[float]:
    """Safely evaluates the arithmetic equation using eval() with precautions.

    Args:
        equation_str: The arithmetic expression to evaluate.

    Returns:
        The numeric result (an int or float, depending on the operators used),
        or None if the expression contains disallowed characters or the
        exponent operator, or fails to evaluate (syntax error, division by
        zero, etc.).
    """
    try:
        # Only digits, operators, parentheses and whitespace may reach eval().
        if not _ALLOWED_EQUATION_PATTERN.match(equation_str):
            return None
        # "**" passes the character whitelist, but Countdown has no exponent
        # operator. Chained powers like "99**98**97" would never finish
        # evaluating and would stall reward computation.
        if "**" in "".join(equation_str.split()):
            return None
        # Builtins are removed as defence in depth on top of the whitelist.
        return eval(equation_str, {"__builtins__": None}, {})
    except Exception:
        # Untrusted model output: any evaluation failure means "no result".
        return None
@register("countdown", RegistryType.REWARD_FUNCTION)
def countdown_reward(
    data_source: str,
    solution_str: str,
    ground_truth: dict[str, Any],
    extra_info: dict[str, Any],
    format_score: float = 0.0,
    score: float = 1.0,
) -> float:
    """Custom reward function for the Countdown task.

    Currently, this function only works with the VERL_GRPO trainer.

    Args:
        data_source: The data source. Unused; part of the reward interface.
        solution_str: The response from the LLM.
        ground_truth: Dictionary containing the target number under
            ``"target"`` and the available numbers under ``"numbers"``.
        extra_info: Extra information about the sample. Unused; part of the
            reward interface.
        format_score: The score for correct format but wrong answer.
        score: The score for the correct answer.

    Returns:
        `score` if the equation is valid and correct, `format_score` if the
        answer was parsed properly but the equation is incorrect, and `0.0`
        if the answer was not parsed properly.

    Raises:
        KeyError: If ``ground_truth`` lacks ``"target"`` or ``"numbers"``.
    """
    target = ground_truth["target"]
    numbers = ground_truth["numbers"]

    equation = _extract_solution(solution_str=solution_str)
    if equation is None:
        return 0.0

    # Validate equation uses each available number exactly once.
    if not _validate_equation(equation, numbers):
        return format_score

    # _evaluate_equation never raises; it reports any failure as None.
    result = _evaluate_equation(equation)
    if result is None:
        return format_score

    try:
        # Tolerance accounts for floating point error from "/" division.
        is_correct = bool(abs(result - target) < 1e-5)
    except (TypeError, ValueError, ArithmeticError):
        # E.g. a non-numeric target, or a result too large to compare.
        return format_score
    return score if is_correct else format_score