Source code for oumi.evaluation.registry.count_letters_task
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import Any, Optional

from oumi.core.configs.params.evaluation_params import EvaluationTaskParams
from oumi.core.inference.base_inference_engine import BaseInferenceEngine
from oumi.core.registry import register_evaluation_function
from oumi.datasets.grpo.letter_count import LetterCountGrpoDataset
from oumi.utils.logging import logger


def _extract_prediction(response: str) -> Optional[int]:
    r"""Returns the numeric answer extracted from `\boxed{...}`, or None otherwise."""
    regex_result = re.findall(r"\\boxed\{([-+]?\d+)\}", response)
    if not regex_result or len(regex_result) != 1:
        return None
    number_str = regex_result[0]
    # Except clause shouldn't trigger because the regex should only find ints.
    try:
        return int(number_str)
    except ValueError:
        return None
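As a quick illustration (not part of the original module), `_extract_prediction` accepts a response only when it contains exactly one `\boxed{...}` span wrapping an optionally signed integer; everything else yields None. A minimal sketch with made-up responses:

# Illustrative sketch (not part of the module): expected behavior of
# _extract_prediction on a few hypothetical model responses.
assert _extract_prediction(r"The answer is \boxed{3}.") == 3
assert _extract_prediction(r"\boxed{-12}") == -12
assert _extract_prediction("No boxed answer here.") is None
# Exactly one \boxed{...} match is required; multiple matches are rejected.
assert _extract_prediction(r"\boxed{1} and \boxed{2}") is None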
@register_evaluation_function("count_letters")
def count_letters(
    task_params: EvaluationTaskParams,
    inference_engine: BaseInferenceEngine,
) -> dict[str, Any]:
    """Custom evaluation function registered as `count_letters`."""
    dataset = LetterCountGrpoDataset(
        dataset="oumi-ai/oumi-letter-count-clean", split="test"
    )
    # TODO: OPE-1155: Add support for using Oumi dataset code to create the dataset.
    # dataset = build_dataset("oumi-ai/oumi-letter-count", tokenizer=None, sample_count=10)  # noqa: E501

    num_samples = task_params.num_samples
    if num_samples is None:
        num_samples = len(dataset)
    input_conversations = [dataset.conversation(i) for i in range(num_samples)]

    conversations = inference_engine.infer(input_conversations)
    logger.info(f"Finished inference on {len(conversations)} conversations!")
    if len(conversations) > 0:
        logger.info(f"Sample conversation: {conversations[0]}")

    count = 0  # The number of examples with correct answers extracted.
    total = 0  # All examples.
    valid_count = 0  # The number of examples with valid answers extracted.
    for i, conversation in enumerate(conversations):
        total += 1

        # Grab the model's response.
        response = conversation.last_message()

        # Ignore cases where the model didn't respond or the response is
        # multimodal. For now, we focus on text-only responses.
        if not response or not isinstance(response.content, str):
            continue

        # Count the example as correct if the extracted prediction is correct.
        prediction = _extract_prediction(response.content)
        if prediction is None:
            continue
        valid_count += 1
        if prediction == conversation.metadata["letter_count_integer"]:
            count += 1

    return {
        # Accuracy across all examples.
        "accuracy": count / total if total > 0 else 0,
        # Accuracy when only counting examples with properly extracted answers.
        "properly_extracted_accuracy": count / valid_count if valid_count > 0 else 0,
        "num_samples": num_samples,
        # These three values sum up to num_samples.
        "num_correct_answers": count,
        "num_incorrect_answers": valid_count - count,
        "num_invalid_answers": total - valid_count,
    }
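For reference, a minimal sketch (with made-up numbers, not output from a real run) of the invariant among the returned counters: every evaluated sample is exactly one of correct, parsable-but-incorrect, or invalid, so the three counts partition num_samples.

# Minimal sketch with hypothetical numbers: the three counters partition the samples.
metrics = {
    "num_samples": 10,
    "num_correct_answers": 6,    # count
    "num_incorrect_answers": 2,  # valid_count - count
    "num_invalid_answers": 2,    # total - valid_count
}
assert (
    metrics["num_correct_answers"]
    + metrics["num_incorrect_answers"]
    + metrics["num_invalid_answers"]
) == metrics["num_samples"]
# With these numbers: accuracy = 6 / 10 = 0.6, and
# properly_extracted_accuracy = 6 / 8 = 0.75 (invalid answers excluded).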