# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Generic, Optional, TypeVar

import pydantic

from oumi.core.configs import BaseConfig
from oumi.core.configs.inference_engine_type import InferenceEngineType
from oumi.core.configs.params.generation_params import GenerationParams
from oumi.core.configs.params.model_params import ModelParams
from oumi.core.configs.params.remote_params import RemoteParams
from oumi.core.types.conversation import Conversation, Message, Role, TemplatedMessage
class JudgeAttributeValueType(str, Enum):
    """Kinds of value a judge attribute can take.

    Members are also ``str`` instances, so each one compares equal to its
    raw string value (e.g. ``JudgeAttributeValueType.BOOL == "bool"``).
    """

    BOOL = "bool"
    """The judgement is a boolean."""

    CATEGORICAL = "categorical"
    """The judgement is one of a set of categories."""

    LIKERT_5 = "likert-5"
    """The judgement is a rating on a 5-point Likert scale."""
# Generic parameter for the few-shot example type stored by ``JudgeAttribute``;
# bound to ``TemplatedMessage`` so only that class or its subclasses qualify.
T = TypeVar("T", bound=TemplatedMessage)
class JudgeAttribute(pydantic.BaseModel, Generic[T]):
    """A single attribute for the judge to evaluate.

    Bundles the judge's system prompt with optional few-shot examples and
    exposes them as oumi ``Message`` / ``Conversation`` objects.

    Example:
        >>> attribute = JudgeAttribute( # doctest: +SKIP
        ...     name="helpful",
        ...     system_prompt="You are an impartial judge.",
        ...     examples=[
        ...         TemplatedMessage(
        ...             role=Role.USER,
        ...             request="What is the capital of France?",
        ...             response="The capital of France is Paris.",
        ...         ),
        ...         TemplatedMessage(
        ...             role=Role.ASSISTANT,
        ...             response="True",
        ...         ),
        ...     ],
        ...     value_type=JudgeAttributeValueType.BOOL,
        ...     limit_examples=5,
        ...     )
        >>> print(attribute.name) # doctest: +SKIP
        helpful
    """

    name: str
    """The name of the attribute being judged."""

    system_prompt: str
    """The system prompt for the judge."""

    # Use pydantic's own ``Field`` rather than ``dataclasses.field``: this is a
    # pydantic model, and ``Field(default_factory=...)`` is its native way to
    # give every instance a fresh empty list.
    examples: list[T] = pydantic.Field(default_factory=list)
    """A list of few-shot example inputs and judgements."""

    value_type: JudgeAttributeValueType = JudgeAttributeValueType.BOOL
    """The type of value for the attribute. Defaults to ``BOOL``."""

    limit_examples: Optional[int] = 5
    """The maximum number of examples to use.

    This is an optional parameter that limits the number of examples to be used
    for judging the attribute. If not specified, the default is 5.

    NOTE(review): nothing in this class applies the limit (``messages`` returns
    every example), so it is presumably enforced by the consumer, and ``None``
    presumably means "no limit" -- confirm against the caller.
    """

    @property
    def conversation(self) -> Conversation:
        """Returns the judgement conversation in oumi format.

        This wraps ``self.messages``, i.e. the judge system prompt followed by
        any few-shot examples.
        """
        return Conversation(messages=self.messages)

    @property
    def messages(self) -> list[Message]:
        """Returns the messages in oumi format.

        The list starts with the judge system prompt (as a ``Role.SYSTEM``
        message), followed by the ``message`` of each few-shot example, in
        order. A new list is built on every access.
        """
        system_message = Message(content=self.system_prompt, role=Role.SYSTEM)
        return [system_message, *(example.message for example in self.examples)]

    @classmethod
    def load(cls: type, filename: str) -> "JudgeAttribute[T]":
        """Loads the judge attribute config from a JSON file.

        Args:
            filename: Path of the JSON file to read.

        Returns:
            The instance produced by validating the file's contents.

        Raises:
            FileNotFoundError: If ``filename`` does not exist.
            pydantic.ValidationError: If the contents fail model validation.
        """
        path = Path(filename)
        if not path.exists():
            raise FileNotFoundError(path)
        # JSON is UTF-8 by specification; decode explicitly so non-ASCII prompts
        # load the same way regardless of the platform's locale encoding.
        return cls.model_validate_json(path.read_text(encoding="utf-8"))
@dataclass
class JudgeConfig(BaseConfig):
    """Configuration for the Judge.

    Groups everything the Judge is configured with: the attributes to judge,
    the model parameters, the text generation parameters, the inference engine
    and, optionally, remote API parameters.

    Examples:
        >>> attributes = {
        ...     "helpful": JudgeAttribute( # doctest: +SKIP
        ...         name="helpful",
        ...         system_prompt="Is this answer helpful?",
        ...         examples=[
        ...             TemplatedMessage(
        ...                 role=Role.USER,
        ...                 request="What is the capital of France?",
        ...                 response="The capital of France is Paris.",
        ...             ),
        ...             TemplatedMessage(
        ...                 role=Role.ASSISTANT,
        ...                 response="True",
        ...             ),
        ...         ],
        ...     ),
        ...     "honest": JudgeAttribute(
        ...         name="honest",
        ...         system_prompt="Is this answer honest?",
        ...         examples=[]
        ...     )
        ... }
        >>> model_params = ModelParams(model_name="example-model")
        >>> generation_params = GenerationParams(max_new_tokens=100) # doctest: +SKIP
        >>> judge_config = JudgeConfig( # doctest: +SKIP
        ...     attributes=attributes,
        ...     model=model_params,
        ...     generation=generation_params
        ... )
    """

    attributes: dict[str, JudgeAttribute] = field(default_factory=dict)
    """The attributes to judge, keyed by string. Empty by default."""

    model: ModelParams = field(default_factory=ModelParams)
    """Parameters for the model used in inference."""

    generation: GenerationParams = field(default_factory=GenerationParams)
    """Parameters for text generation during inference."""

    engine: InferenceEngineType = InferenceEngineType.NATIVE
    """The inference engine to use for generation. Defaults to ``NATIVE``."""

    remote_params: Optional[RemoteParams] = None
    """Parameters for running inference against a remote API, if any."""