# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test configuration parameters for dataset analysis.

This module provides dataclasses for configuring user-defined tests that run
on dataset analysis results. The design is inspired by promptfoo's declarative
assertion system.

Example:
    >>> from oumi.core.configs.params.test_params import TestParams
    >>> test = TestParams(
    ...     id="no_pii",
    ...     type="threshold",
    ...     metric="quality__has_pii",
    ...     operator="==",
    ...     value=True,
    ...     max_percentage=1.0,
    ...     severity="high",
    ...     title="PII detected in dataset",
    ... )
"""
from dataclasses import dataclass, field
from enum import Enum
from typing import Any
from oumi.core.configs.params.base_params import BaseParams
class TestType(str, Enum):
    """Kinds of checks that can be applied to analysis results.

    Only ``THRESHOLD`` (numeric comparison with an optional percentage
    tolerance) is implemented today. The other members are reserved for
    planned features:

    - ``REGEX``: pattern matching on text fields.
    - ``CONTAINS``: text containment checks (planned match modes any/all/exact).
    - ``OUTLIERS``: anomaly detection based on standard deviation.
    - ``COMPOSITE``: combining several tests with AND/OR logic.
    """

    # Implemented.
    THRESHOLD = "threshold"

    # Reserved for future use; not implemented yet.
    REGEX = "regex"
    CONTAINS = "contains"
    OUTLIERS = "outliers"
    COMPOSITE = "composite"
class TestSeverity(str, Enum):
    """How serious a failure of a given test is considered to be."""

    HIGH = "high"
    MEDIUM = "medium"
    LOW = "low"
class TestScope(str, Enum):
    """Granularity (message or conversation) at which a test is evaluated."""

    MESSAGE = "message"
    CONVERSATION = "conversation"
class CompositeOperator(str, Enum):
    """Named modes ("any" / "all") for combining the sub-tests of a composite test."""

    ANY = "any"
    ALL = "all"
# Declarative validation configuration
# Note: Only "threshold" is currently implemented. Others are planned for future.
TEST_VALIDATIONS = {
"threshold": {
"required": ["metric", "operator", "value"],
"valid_values": {"operator": ["<", ">", "<=", ">=", "==", "!="]},
},
# Not yet implemented - planned for future
"regex": {
"required": ["text_field", "pattern"],
},
"contains": {
"required": ["text_field"],
"custom": lambda self: (
None
if (self.value is not None or self.values)
else "requires 'value' or 'values'"
),
},
"outliers": {
"required": ["metric"],
"custom": lambda self: (
None if self.std_threshold > 0 else "'std_threshold' must be positive"
),
},
"composite": {
"required": ["tests"],
"custom": lambda self: (
None
if self.tests
and (
self.composite_operator in ["any", "all"]
or _try_parse_int(self.composite_operator)
)
else "requires at least one sub-test"
if not self.tests
else f"Invalid composite_operator '{self.composite_operator}'"
),
},
}
def _try_parse_int(value: str) -> bool:
"""Try to parse a string as an integer."""
try:
int(value)
return True
except (ValueError, TypeError):
return False
@dataclass
class TestParams(BaseParams):
    """Configuration for a single test on analysis results.

    This is a flexible dataclass that supports all test types. Which fields are
    relevant depends on ``type``. Validation is performed in
    ``__finalize_and_validate__`` using the per-type rules in
    ``TEST_VALIDATIONS``.

    Attributes:
        id: Unique identifier for this test.
        type: The type of test; must be a ``TestType`` value (threshold, regex,
            contains, outliers, composite). Only "threshold" is currently
            implemented.
        severity: How severe a failure of this test is (high, medium, low).
        title: Human-readable title for the test (shown in reports).
        description: Detailed description of what this test checks.
        scope: Whether to run on the message or conversation DataFrame.
        negate: If True, invert the test logic (pass becomes fail).
        metric: Column name to check (e.g., "length__token_count"). Used by
            threshold and outlier tests.
        operator: Comparison operator for threshold tests
            (<, >, <=, >=, ==, !=).
        value: Value to compare against for threshold tests.
        condition: Condition string (e.g., "== True", "> 0.5").
        max_percentage: Maximum percentage of samples that can match/fail.
            Must lie in [0, 100] when set.
        min_percentage: Minimum percentage of samples that must match.
            Must lie in [0, 100] and not exceed ``max_percentage`` when set.
        std_threshold: Standard deviations for outlier detection.
        text_field: Column name containing text to search
            (e.g., "text_content"). Used by regex and contains tests.
        pattern: Regex pattern for regex tests.
        values: List of substrings for contains-any/contains-all tests.
        case_sensitive: Whether text matching is case-sensitive.
        check: Type of distribution check (max_fraction, entropy, etc.).
            No ``TestType`` currently corresponds to distribution checks.
        threshold: Threshold value for distribution checks.
        expression: Pandas query expression string. No ``TestType`` currently
            corresponds to query tests.
        tests: List of sub-test configurations for composite tests.
        composite_operator: How to combine sub-tests ("any", "all", or a
            minimum count given as an integer string).
        function: Python function code as a string. No ``TestType`` currently
            corresponds to Python tests.
    """

    id: str = ""
    type: str = ""
    severity: str = "medium"
    title: str | None = None
    description: str | None = None
    scope: str = "message"
    negate: bool = False

    # Metric-based test fields.
    metric: str | None = None
    operator: str | None = None
    value: float | int | str | None = None
    condition: str | None = None
    max_percentage: float | None = None
    min_percentage: float | None = None
    std_threshold: float = 3.0

    # Text-based test fields.
    text_field: str | None = None
    pattern: str | None = None
    values: list[str] | None = None
    case_sensitive: bool = False

    # Distribution / query fields (no matching TestType yet).
    check: str | None = None
    threshold: float | None = None
    expression: str | None = None

    # Composite test fields.
    tests: list[dict[str, Any]] = field(default_factory=list)
    composite_operator: str = "any"

    # Python test fields (no matching TestType yet).
    function: str | None = None

    def __finalize_and_validate__(self) -> None:
        """Validate the test configuration.

        Raises:
            ValueError: If ``id`` or ``type`` is missing; if ``type``,
                ``severity`` or ``scope`` is not a recognised enum value; if a
                type-specific rule from ``TEST_VALIDATIONS`` is violated; or
                if a percentage bound is invalid.
        """
        if not self.id:
            raise ValueError("Test 'id' is required.")
        if not self.type:
            raise ValueError(f"Test 'type' is required for test '{self.id}'.")
        self._validate_enum_field("type", TestType, "test type")
        self._validate_enum_field("severity", TestSeverity, "severity")
        self._validate_enum_field("scope", TestScope, "scope")
        self._validate_by_type()
        # Runs last so errors for pre-existing checks are reported unchanged.
        self._validate_percentages()

    def _validate_enum_field(
        self, field_name: str, enum_class: type[Enum], label: str
    ) -> None:
        """Validate that a field matches an enum value.

        Args:
            field_name: Name of the field to validate.
            enum_class: Enum class whose member values are allowed.
            label: Human-readable label for error messages.

        Raises:
            ValueError: If the field's value is not one of the enum's values.
        """
        value = getattr(self, field_name)
        valid_values = [e.value for e in enum_class]
        if value not in valid_values:
            raise ValueError(
                f"Invalid {label} '{value}' for test '{self.id}'. "
                f"Valid values: {valid_values}"
            )

    @staticmethod
    def _is_unset(value: Any) -> bool:
        """Return True for values treated as "not provided": None or ""."""
        return value is None or (isinstance(value, str) and not value)

    def _validate_percentages(self) -> None:
        """Validate the optional ``min_percentage`` / ``max_percentage`` bounds.

        Raises:
            ValueError: If a bound is non-numeric or outside [0, 100] (NaN
                included), or if ``min_percentage`` exceeds ``max_percentage``.
        """
        for field_name in ("min_percentage", "max_percentage"):
            pct = getattr(self, field_name)
            if pct is None:
                continue
            # Negated range check so that NaN is rejected as well.
            if not isinstance(pct, (int, float)) or not 0 <= pct <= 100:
                raise ValueError(
                    f"Test '{self.id}': '{field_name}' must be between 0 and "
                    f"100, got {pct!r}."
                )
        low, high = self.min_percentage, self.max_percentage
        if low is not None and high is not None and low > high:
            raise ValueError(
                f"Test '{self.id}': 'min_percentage' ({low}) cannot exceed "
                f"'max_percentage' ({high})."
            )

    def _validate_by_type(self) -> None:
        """Validate fields based on test type using the rules in TEST_VALIDATIONS.

        Types without a rule set are accepted as-is.

        Raises:
            ValueError: If any declarative or custom rule for ``type`` fails.
        """
        rules = TEST_VALIDATIONS.get(self.type)
        if not rules:
            return

        # Every listed field must be provided.
        for field_name in rules.get("required", []):
            if self._is_unset(getattr(self, field_name)):
                raise ValueError(
                    f"Test '{self.id}': '{field_name}' is required for "
                    f"{self.type} tests."
                )

        # At least one field per group must be provided ("missing" uses the
        # same definition as "required" above).
        for field_group in rules.get("either_required", []):
            if all(self._is_unset(getattr(self, f)) for f in field_group):
                fields_str = "' or '".join(field_group)
                raise ValueError(
                    f"Test '{self.id}': Either '{fields_str}' "
                    f"is required for {self.type} tests."
                )

        # Set fields must be one of the allowed literal values.
        for field_name, valid_values in rules.get("valid_values", {}).items():
            value = getattr(self, field_name)
            if value and value not in valid_values:
                raise ValueError(
                    f"Test '{self.id}': Invalid {field_name} '{value}'. "
                    f"Valid values: {valid_values}"
                )

        # Set fields must be a value of the given enum. The enum may be given
        # by class name (looked up in this module) or as the class itself.
        for field_name, enum_ref in rules.get("valid_enums", {}).items():
            value = getattr(self, field_name)
            if not value:
                continue
            enum_class = globals()[enum_ref] if isinstance(enum_ref, str) else enum_ref
            valid_values = [e.value for e in enum_class]
            if value not in valid_values:
                raise ValueError(
                    f"Test '{self.id}': Invalid {field_name} '{value}'. "
                    f"Valid values: {valid_values}"
                )

        # Custom validators return an error message string, or None if valid.
        custom_validator = rules.get("custom")
        if custom_validator:
            result = custom_validator(self)
            if isinstance(result, str):
                raise ValueError(f"Test '{self.id}': {result}")

    def get_title(self) -> str:
        """Get the display title for this test.

        Returns:
            ``title`` if set, otherwise ``id`` with underscores replaced by
            spaces and converted to title case.
        """
        if self.title:
            return self.title
        return self.id.replace("_", " ").title()

    def get_description(self) -> str:
        """Get the description for this test.

        Returns:
            ``description`` if set, otherwise a generic sentence naming the
            test type.
        """
        if self.description:
            return self.description
        return f"Test of type '{self.type}'."