# Copyright 2025 - Oumi## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.importjsonimportrefromabcimportABC,abstractmethodfromtypingimportAny,Optional,Unionfromtqdm.autoimporttqdmfromtyping_extensionsimportSelffromoumi.core.configsimportInferenceConfig,JudgeConfigfromoumi.core.inferenceimportBaseInferenceEnginefromoumi.core.types.conversationimportConversation,Message,Role,TemplatedMessagefromoumi.utils.loggingimportlogger
@classmethod
def from_xml_output(cls, raw_judgement: Optional[str]) -> Optional[Self]:
    """Parses the judgement from XML-like tags in the raw output.

    Args:
        raw_judgement: The raw judgement string to parse.

    Returns:
        Optional[Self]: An instance of the class with parsed attributes,
            or None if parsing fails.
    """
    if not raw_judgement:
        return None

    # Each "<tag>content</tag>" pair becomes one attribute: the tag name is
    # captured in group 1, the enclosed content in group 2. For example,
    # "<label>True</label>" yields {"label": "True"}. DOTALL lets the
    # content span multiple lines; duplicate tags keep the last occurrence.
    tag_pattern = re.compile(r"<(\w+)>(.*?)</\1>", re.DOTALL)
    parsed_attributes = {
        tag_name: content.strip()
        for tag_name, content in tag_pattern.findall(raw_judgement)
    }

    return cls(**parsed_attributes, raw_judgement=raw_judgement)
@classmethod
def from_json_output(cls, raw_judgement: Optional[str]) -> Optional[Self]:
    """Parses the judgement from JSON.

    Args:
        raw_judgement: The raw judgement string to parse.

    Returns:
        Optional[Self]: An instance of the class with parsed attributes,
            or None if parsing fails.
    """
    if not raw_judgement:
        return None

    try:
        judgement_data = json.loads(raw_judgement)
    except json.JSONDecodeError:
        return None

    # Guard against valid JSON that is not an object (e.g. a list or a bare
    # string): only a mapping can be expanded into constructor kwargs.
    # Previously this fell through and raised TypeError instead of
    # returning None as documented.
    if not isinstance(judgement_data, dict):
        return None

    return cls(**judgement_data, raw_judgement=raw_judgement)
@property
def label(self):
    """Convert the judgement to a boolean or Likert scale label.

    This method should be overridden by subclasses to provide the actual
    conversion logic.
    """
    return self.raw_judgement

@property
def fields(self):
    """Return the fields of the judgement."""
    # Everything the model dump contains, minus bookkeeping keys that are
    # not judged attributes.
    excluded_keys = ("raw_judgement", "template", "role")
    return {
        name: value
        for name, value in self.model_dump().items()
        if name not in excluded_keys
    }
class BaseJudge(ABC):
    """Abstract base class for LLM-based judges.

    A judge scores one or more configured attributes for a batch of inputs by
    prompting an underlying inference engine and parsing the raw responses
    into structured judgements.
    """

    def __init__(
        self,
        config: JudgeConfig,
        inference_engine: Optional[BaseInferenceEngine] = None,
    ):
        """Initialize the Judge.

        Args:
            config: The judge configuration, including the attributes to judge
                and the model/generation parameters.
            inference_engine: Optional pre-constructed inference engine. When
                omitted, one is built from ``config``.

        Raises:
            ValueError: If the configuration defines no attributes to judge.
        """
        self._config = config
        self._attributes = config.attributes

        if len(config.attributes) < 1:
            raise ValueError(
                "At least one attribute must be specified in the judge configuration."
            )

        if inference_engine is not None:
            logger.debug("Using provided inference engine.")
            self.inference_engine = inference_engine
        else:
            logger.debug("Initializing inference engine.")
            self.inference_engine = self._create_inference_engine(config)

    def judge(
        self,
        raw_inputs: Union[list[Conversation], list[dict], list[Message]],
    ) -> list[dict[str, BaseJudgeOutput]]:
        """Judge the given conversations.

        Args:
            raw_inputs: The inputs to judge; each entry may be a dict, a
                templated message, or a full conversation.

        Returns:
            One dict per input, mapping each configured attribute name to its
            judgement (raw output, parsed fields, and label).

        Raises:
            ValueError: If an input is of an unsupported type.
        """
        # Convert the raw user inputs into a list of JudgeInput classes.
        # A JudgeInput is the unit of what needs to be judged, and could be a
        # prompt, request/answer pair or a full conversation.
        judge_inputs = []
        for raw_input in raw_inputs:
            if isinstance(raw_input, dict):
                judge_input = self._transform_dict_input(raw_input)
            elif isinstance(raw_input, TemplatedMessage):
                judge_input = raw_input
            elif isinstance(raw_input, Conversation):
                judge_input = self._transform_conversation_input(raw_input)
            else:
                raise ValueError(f"Unsupported conversation type: {type(raw_input)}")
            judge_inputs.append(judge_input)

        results = {}
        for attribute_name in self._attributes.keys():
            # Generate the full judging prompt for each attribute.
            # This includes the judge system prompt, and any few shot examples
            # that are included in the judge config.
            judgement_prompts = [
                self.build_judgement_prompt(judge_input, attribute_name=attribute_name)
                for judge_input in tqdm(judge_inputs)
            ]

            # Run inference for the attribute's prompt.
            # We batch the inference for a single attribute together to
            # maximally benefit from kv prefix caching (system prompt,
            # few shot examples).
            raw_judgements = self._infer(judgement_prompts)

            # Parse the raw judge output (a string) into a JudgeOutput object.
            judgements = []
            for conversation in raw_judgements:
                judgement = conversation.messages[-1].content
                parsed_judgement = self._transform_model_output(judgement)
                judgements.append(
                    {
                        "raw_judgement": judgement,
                        "fields": parsed_judgement.fields,
                        "label": parsed_judgement.label,
                    }
                )
            results[attribute_name] = judgements

        # Results are now in the form
        # {attribute: judgements for attribute in attributes}
        # Transform to
        # [{attribute: judgement} for judgement in judgements]
        outputs = []
        for idx in range(len(raw_inputs)):
            output_dict = {}
            for attribute_name in self._attributes.keys():
                output_dict[attribute_name] = results[attribute_name][idx]
            outputs.append(output_dict)
        return outputs

    def build_judgement_prompt(
        self, judge_input: Message, attribute_name: Optional[str]
    ) -> Conversation:
        """Generate judge prompts for a dataset.

        Args:
            judge_input: The message holding the content to be judged.
            attribute_name: The attribute to judge. May be None only when a
                single attribute is configured, in which case that attribute
                is selected automatically.

        Returns:
            A conversation containing the attribute's prompt messages followed
            by the input to judge.

        Raises:
            ValueError: If ``attribute_name`` is None while multiple
                attributes are configured.
            KeyError: If ``attribute_name`` is not a configured attribute.
        """
        if attribute_name is None:
            # BUGFIX: the check was previously `> 0`, which raised whenever
            # attribute_name was None (``__init__`` guarantees at least one
            # attribute), making the single-attribute fallback unreachable.
            if len(self._attributes) > 1:
                raise ValueError(
                    "attribute_name must be specified when there are multiple"
                    " attributes to judge."
                )
            else:
                # If there's only one attribute, use it
                attribute_name = next(iter(self._attributes))

        if attribute_name not in self._attributes:
            raise KeyError(
                f"Attribute '{attribute_name}' not found in config.attributes"
            )

        attribute = self._attributes[attribute_name]
        # Copy so appending the user input does not mutate the config's
        # few-shot message list across calls.
        messages = attribute.messages.copy()
        messages.append(Message(content=judge_input.content, role=Role.USER))

        return Conversation(
            messages=messages,
            metadata={
                "judge_attribute_name": attribute_name,
            },
        )

    def _infer(self, conversations: list[Conversation]) -> list[Conversation]:
        """Judge a single attribute.

        Runs batched inference over ``conversations`` and re-attaches each
        request's metadata (e.g. the judged attribute name) to its response.
        """
        metadatas = [convo.metadata for convo in conversations]

        # Wrap the generation params in an inference config for inference.
        # Only the generation params are used by the inference engine.
        inference_config = InferenceConfig(
            model=self._config.model,
            generation=self._config.generation,
            remote_params=self._config.remote_params,
        )
        responses = self.inference_engine.infer(
            input=conversations, inference_config=inference_config
        )
        assert len(responses) == len(metadatas)

        for response, metadata in zip(responses, metadatas):
            response.metadata.update(metadata)

        return responses

    def _create_inference_engine(self, config: JudgeConfig) -> BaseInferenceEngine:
        """Create the inference engine."""
        # Imported lazily to avoid a circular import at module load time.
        from oumi.builders.inference_engines import build_inference_engine

        return build_inference_engine(
            engine_type=config.engine,
            model_params=config.model,
            remote_params=config.remote_params,
        )

    @abstractmethod
    def _transform_conversation_input(self, conversation: Conversation) -> Message:
        """Convert a full conversation into a single judgeable message."""
        raise NotImplementedError

    @abstractmethod
    def _transform_dict_input(self, raw_input: dict[str, Any]) -> Message:
        """Convert a raw dict input into a single judgeable message."""
        raise NotImplementedError

    @abstractmethod
    def _transform_model_output(self, model_output) -> BaseJudgeOutput:
        """Parse the judge model's raw text output into a structured judgement."""
        raise NotImplementedError