oumi.datasets.grpo.rewards#

GRPO reward functions module.

oumi.datasets.grpo.rewards.compute_letter_count_reward(completion: str, target_count: int) → float[source]#

Computes the reward for counting letters in a string.

The last group of consecutive digits in the completion is assumed to be the letter count, and the count is assumed to refer to the correct letter. The reward is the negative of the absolute difference between the count and the target count, plus 0.1 if the answer is properly formatted.

For example, for the string “There are 2 ‘r’s in strawberry” and a target count of 3, the reward is -1 (|2 - 3| = 1, with no formatting bonus).

Parameters:
  • completion – The completion string from the LLM.

  • target_count – The target count of letters.

Returns:

The reward value, calculated as the negative of the absolute difference between the count and the target count, plus 0.1 if the answer was properly formatted. The count is taken to be the last group of consecutive digits in the completion string.
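
A minimal sketch of the described scoring, assuming a regex over digit groups; the no-digits fallback is an assumption, and the formatting bonus is intentionally omitted because its exact criterion isn't documented:

```python
import re


def letter_count_reward_sketch(completion: str, target_count: int) -> float:
    """Illustrative sketch only; not the oumi implementation."""
    digit_groups = re.findall(r"\d+", completion)
    if not digit_groups:
        # Assumption: a completion with no digits earns no count-based credit.
        return -float(target_count)
    count = int(digit_groups[-1])  # last group of consecutive digits
    reward = -float(abs(count - target_count))
    # The real function also adds 0.1 when the answer is "properly formatted";
    # that check is omitted here.
    return reward


print(letter_count_reward_sketch("There are 2 'r's in strawberry", 3))  # -1.0
```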

oumi.datasets.grpo.rewards.compute_sharp_target_token_length_reward(num_tokens: int, *, target_tokens: int)[source]#

Returns the maximum reward for inputs that are exactly target_tokens long.

The reward reduces sharply if the actual number of tokens deviates from target_tokens.

The reward is computed as: -|num_tokens - target_tokens|, which penalizes token counts not equal to target_tokens.
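
Since the formula is stated explicitly, a one-line sketch (the function name is illustrative, not the library symbol):

```python
def sharp_target_token_length_reward_sketch(num_tokens: int, *, target_tokens: int) -> int:
    """Sketch of the stated formula: -|num_tokens - target_tokens|."""
    return -abs(num_tokens - target_tokens)


print(sharp_target_token_length_reward_sketch(95, target_tokens=100))   # -5
print(sharp_target_token_length_reward_sketch(100, target_tokens=100))  # 0 (the maximum)
```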

oumi.datasets.grpo.rewards.compute_soft_target_token_length_reward(num_tokens: int, *, target_tokens: int)[source]#

Returns the maximum reward for inputs that are exactly target_tokens long.

The reward is in the [0,1] range and reduces smoothly from the maximum value of 1.0 if the actual number of tokens deviates from target_tokens.

The reward is proportional to x*exp(-x), where x := num_tokens/target_tokens; equivalently, it equals x*exp(1 - x), which attains its maximum of 1.0 at x = 1.
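
A sketch of the stated formula; the exact normalization x*exp(1 - x) is inferred from the "maximum value of 1.0" statement rather than copied from the source:

```python
import math


def soft_target_token_length_reward_sketch(num_tokens: int, *, target_tokens: int) -> float:
    """Sketch: proportional to x*exp(-x), scaled so the maximum is 1.0 at x == 1."""
    x = num_tokens / target_tokens
    return x * math.exp(1.0 - x)


print(soft_target_token_length_reward_sketch(100, target_tokens=100))           # 1.0
print(round(soft_target_token_length_reward_sketch(50, target_tokens=100), 3))  # 0.824
print(round(soft_target_token_length_reward_sketch(200, target_tokens=100), 3)) # 0.736
```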

oumi.datasets.grpo.rewards.countdown_reward(data_source: str, solution_str: str, ground_truth: dict[str, Any], extra_info: dict[str, Any], format_score=0.1, score=1.0) → float[source]#

Custom reward function for the Countdown task.

Currently, this function only works with the VERL_PPO trainer.

Parameters:
  • data_source – The data source.

  • solution_str – The response from the LLM.

  • ground_truth – Dictionary containing the target number and the available numbers.

  • extra_info – Extra information about the sample.

  • format_score – The score for a correctly formatted but incorrect answer.

  • score – The score for the correct answer.

Returns:

score if the equation is valid and correct; format_score if the answer was parsed properly but the equation is incorrect; 0 if the answer was not parsed properly.
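
A hypothetical sketch of the scoring tiers in the Returns description. The <answer> tags, the ground-truth keys "target" and "numbers", and the equation validation below are assumptions about the Countdown task, not the trainer's actual parsing logic:

```python
import re
from typing import Any, Optional


def countdown_reward_sketch(
    data_source: str,
    solution_str: str,
    ground_truth: dict[str, Any],
    extra_info: dict[str, Any],
    format_score: float = 0.1,
    score: float = 1.0,
) -> float:
    """Illustrative sketch of the scoring tiers described above."""
    equation = _extract_equation(solution_str)
    if equation is None:
        return 0.0  # answer could not be parsed
    try:
        # Assumed ground-truth keys; the real dict layout may differ.
        target = ground_truth["target"]
        numbers = sorted(ground_truth["numbers"])
        used = sorted(int(n) for n in re.findall(r"\d+", equation))
        if used != numbers:
            return format_score  # parsed, but doesn't use the given numbers
        # Only evaluate simple arithmetic: digits, + - * / ( ), spaces, dots.
        if not re.fullmatch(r"[\d+\-*/() .]+", equation):
            return format_score
        value = eval(equation)  # character set restricted above
        return score if abs(value - target) < 1e-6 else format_score
    except Exception:
        return format_score


def _extract_equation(solution_str: str) -> Optional[str]:
    """Hypothetical parser: take the contents of the last <answer>...</answer> tag."""
    matches = re.findall(r"<answer>(.*?)</answer>", solution_str, re.DOTALL)
    return matches[-1].strip() if matches else None


print(countdown_reward_sketch(
    "countdown",
    "Let me try: <answer>(3 + 5) * 2</answer>",
    {"target": 16, "numbers": [2, 3, 5]},
    {},
))  # 1.0
```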