# Source code for oumi.judges.base_judge

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import re
from abc import ABC, abstractmethod
from typing import Any, Optional, Union

from tqdm.auto import tqdm
from typing_extensions import Self

from oumi.core.configs import InferenceConfig, JudgeConfig
from oumi.core.inference import BaseInferenceEngine
from oumi.core.types.conversation import Conversation, Message, Role, TemplatedMessage
from oumi.utils.logging import logger


class BaseJudgeOutput(ABC, TemplatedMessage):
    """Base class for outputs parsed from a judge model's raw judgement.

    Subclasses declare the attribute fields the judge is expected to emit;
    this base provides parsers for XML-tag and JSON formatted raw output,
    plus accessors over the parsed fields.
    """

    # The unparsed judge model output; None until a judgement is attached.
    raw_judgement: Optional[str] = None

    @classmethod
    def from_xml_output(cls, raw_judgement: Optional[str]) -> Optional[Self]:
        """Parses the judgement from XML-like tags in the raw output.

        Args:
            raw_judgement: The raw judgement string to parse.

        Returns:
            Optional[Self]: An instance of the class with parsed attributes,
                or None if parsing fails.
        """
        if not raw_judgement:
            return None

        # Regex pattern to match XML-like tags and their content.
        # Captures the tag name in group 1 and the content between tags in
        # group 2. For example, "<label>True</label>" matches
        # ("label", "True"). The backreference \1 requires matching
        # open/close tags.
        pattern = r"<(\w+)>(.*?)</\1>"
        matches = re.findall(pattern, raw_judgement, re.DOTALL)

        attributes = {
            attr_name: attr_value.strip() for attr_name, attr_value in matches
        }
        return cls(**attributes, raw_judgement=raw_judgement)

    @classmethod
    def from_json_output(cls, raw_judgement: Optional[str]) -> Optional[Self]:
        """Parses the judgement from JSON.

        Args:
            raw_judgement: The raw judgement string to parse.

        Returns:
            Optional[Self]: An instance of the class with parsed attributes,
                or None if parsing fails.
        """
        if not raw_judgement:
            return None

        try:
            judgement_data = json.loads(raw_judgement)
        except json.JSONDecodeError:
            return None

        # BUG FIX: valid JSON is not necessarily a JSON object (e.g. a list
        # or a bare string). Previously cls(**judgement_data, ...) raised a
        # TypeError in that case instead of honoring the "None if parsing
        # fails" contract.
        if not isinstance(judgement_data, dict):
            return None

        return cls(**judgement_data, raw_judgement=raw_judgement)

    @property
    def label(self):
        """Convert the judgement to a boolean or Likert scale label.

        This method should be overridden by subclasses to provide the actual
        conversion logic.
        """
        return self.raw_judgement

    @property
    def fields(self):
        """Return the fields of the judgement.

        Excludes bookkeeping keys (raw_judgement, template, role) so only the
        judge's parsed attribute fields remain.
        """
        fields = self.model_dump()
        fields.pop("raw_judgement", None)
        fields.pop("template", None)
        fields.pop("role", None)
        return fields
class BaseJudge(ABC):
    def __init__(
        self,
        config: JudgeConfig,
        inference_engine: Optional[BaseInferenceEngine] = None,
    ):
        """Initialize the Judge.

        Args:
            config: Judge configuration, including the attributes to judge.
            inference_engine: Optional pre-built engine; when omitted, one is
                constructed from the config.

        Raises:
            ValueError: If the config declares no attributes to judge.
        """
        # Guard clause: a judge with nothing to judge is a config error.
        if not config.attributes:
            raise ValueError(
                "At least one attribute must be specified in the judge configuration."
            )

        self._config = config
        self._attributes = config.attributes

        if inference_engine is None:
            logger.debug("Initializing inference engine.")
            self.inference_engine = self._create_inference_engine(config)
        else:
            logger.debug("Using provided inference engine.")
            self.inference_engine = inference_engine
[docs] def judge( self, raw_inputs: Union[list[Conversation], list[dict], list[Message]], ) -> list[dict[str, BaseJudgeOutput]]: """Judge the given conversations.""" # Convert the raw user inputs into a list of JudgeInput classes # A JudgeInput is the unit of what needs to be judged, and could be a # prompt, request/answer pair or a full conversation judge_inputs = [] for raw_input in raw_inputs: if isinstance(raw_input, dict): judge_input = self._transform_dict_input(raw_input) elif isinstance(raw_input, TemplatedMessage): judge_input = raw_input elif isinstance(raw_input, Conversation): judge_input = self._transform_conversation_input(raw_input) else: raise ValueError(f"Unsupported conversation type: {type(raw_input)}") judge_inputs.append(judge_input) results = {} for attribute_name in self._attributes.keys(): # Generate the full judging prompt for each attribute # This includes the judge system prompt, and any few shot examples # That are included in the judge config. judgement_prompts = [ self.build_judgement_prompt(judge_input, attribute_name=attribute_name) for judge_input in tqdm(judge_inputs) ] # Run inference for the attribute's prompt # We batch the inference for a single attribute together to maximally # benefit from kv prefix caching (system prompt, few shot examples) raw_judgements = self._infer(judgement_prompts) # Parse the raw judge output (a string) into a JudgeOutput object judgements = [] for conversation in raw_judgements: judgement = conversation.messages[-1].content parsed_judgement = self._transform_model_output(judgement) judgements.append( { "raw_judgement": judgement, "fields": parsed_judgement.fields, "label": parsed_judgement.label, } ) results[attribute_name] = judgements # Results are now in the form # {attribute: judgements for attribute in attributes} # Transform to # [{attribute: judgement} for judgement in judgements] outputs = [] for idx in range(len(raw_inputs)): output_dict = {} for attribute_name in self._attributes.keys(): 
output_dict[attribute_name] = results[attribute_name][idx] outputs.append(output_dict) return outputs
[docs] def build_judgement_prompt( self, judge_input: Message, attribute_name: Optional[str] ) -> Conversation: """Generate judge prompts for a dataset.""" if attribute_name is None: if len(self._attributes) > 0: raise ValueError( "attribute_name must be specified when there are multiple" " attributes to judge." ) else: # If there's only one attribute, use it attribute_name = next(iter(self._attributes)) if attribute_name not in self._attributes: raise KeyError( f"Attribute '{attribute_name}' not found in config.attributes" ) attribute = self._attributes[attribute_name] messages = attribute.messages.copy() messages.append(Message(content=judge_input.content, role=Role.USER)) return Conversation( messages=messages, metadata={ "judge_attribute_name": attribute_name, }, )
def _infer(self, conversations: list[Conversation]) -> list[Conversation]: """Judge a single attribute.""" metadatas = [convo.metadata for convo in conversations] # Wrap the generation params in an inference config for inference. # Only the generations params are used by the inference engine. inference_config = InferenceConfig( model=self._config.model, generation=self._config.generation, remote_params=self._config.remote_params, ) responses = self.inference_engine.infer( input=conversations, inference_config=inference_config ) assert len(responses) == len(metadatas) for response, metadata in zip(responses, metadatas): response.metadata.update(metadata) return responses def _create_inference_engine(self, config: JudgeConfig) -> BaseInferenceEngine: """Create the inference engine.""" from oumi.builders.inference_engines import build_inference_engine return build_inference_engine( engine_type=config.engine, model_params=config.model, remote_params=config.remote_params, ) @abstractmethod def _transform_conversation_input(self, conversation: Conversation) -> Message: raise NotImplementedError @abstractmethod def _transform_dict_input(self, raw_input: dict[str, Any]) -> Message: raise NotImplementedError @abstractmethod def _transform_model_output(self, model_output) -> BaseJudgeOutput: raise NotImplementedError