Source code for oumi.core.configs.params.rule_judge_params
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Any
from oumi.core.configs.params.base_params import BaseParams
from oumi.core.configs.params.judge_params import (
JudgeOutputType,
JudgeResponseFormat,
)
[docs]
@dataclass
class RuleJudgeParams(BaseParams):
r"""Parameters for rule-based judge evaluation.
This class defines the configuration for a rule-based judge that uses
deterministic rules.
Examples:
Regex pattern matching:
>>> rule_params = RuleJudgeParams( # doctest: +SKIP
... rule_type="regex_match",
... input_fields=["text"],
... rule_config={"pattern": r"\\d{3}-\\d{4}", "match_mode": "contains"},
... response_format=JudgeResponseFormat.XML,
... judgment_type=JudgeOutputType.BOOL
... )
"""
rule_type: str
"""Type of rule to apply (e.g., 'exact_match', 'regex_match', 'contains', etc.)"""
input_fields: list[str]
"""List of input field names that the rule will operate on.
These fields must be present in the input data passed to the judge.
Example: ["expected_answer", "actual_answer"] for comparison rules.
"""
rule_config: dict[str, Any] = field(default_factory=dict)
"""Configuration specific to the rule type.
Different rule types require different configuration parameters.
Examples:
- regex_match: {"pattern": r"\\d{3}-\\d{4}", "match_mode": "fullmatch"}
"""
response_format: JudgeResponseFormat = field(default=JudgeResponseFormat.XML)
"""The format in which the judge output should be formatted."""
judgment_type: JudgeOutputType = field(default=JudgeOutputType.BOOL)
"""The type of output that the judgment produces."""
judgment_scores: dict[str, float] | None = field(default=None)
"""For ENUM judgment_type, the mapping from category names to numeric scores.
Example:
{"excellent": 1.0, "good": 0.7, "poor": 0.3}
"""
[docs]
def __post_init__(self):
"""Validate the parameters after initialization."""
self._validate_params()
def _validate_params(self):
"""Validate the parameters for consistency and completeness.
Raises:
ValueError: If parameters are invalid
"""
if not self.rule_type or not self.rule_type.strip():
raise ValueError("rule_type cannot be empty")
if not self.input_fields:
raise ValueError("input_fields cannot be empty")
if not all(
isinstance(field, str) and field.strip() for field in self.input_fields
):
raise ValueError("All input_fields must be non-empty strings")
if self.judgment_type == JudgeOutputType.ENUM and not self.judgment_scores:
raise ValueError("judgment_scores must be provided for ENUM judgment_type")
if self.judgment_scores:
if not all(
isinstance(score, int | float)
for score in self.judgment_scores.values()
):
raise ValueError("All judgment_scores values must be numeric")
if len(self.judgment_scores) == 0:
raise ValueError("judgment_scores cannot be empty when provided")