Source code for oumi.analyze.testing.engine

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test engine for validating typed analysis results.

This module provides a test engine that operates on typed Pydantic results
instead of DataFrames. Tests are pure validation - no computation allowed.
"""

import logging
import operator
from collections.abc import Callable
from typing import Any

from pydantic import BaseModel

from oumi.analyze.testing.results import TestResult, TestSeverity, TestSummary
from oumi.core.configs.params.test_params import TestParams, TestType

logger = logging.getLogger(__name__)


OPERATORS: dict[str, Callable[[Any, Any], bool]] = {
    "<": operator.lt,
    ">": operator.gt,
    "<=": operator.le,
    ">=": operator.ge,
    "==": operator.eq,
    "!=": operator.ne,
}
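
# For example, a threshold test configured with operator=">" and value=10000
# evaluates each metric value v as OPERATORS[">"](v, 10000), i.e.
# operator.gt(v, 10000): v=12500 matches, v=800 does not.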


class TestEngine:
    """Engine for running tests on typed analysis results.

    Tests operate on typed Pydantic results, not DataFrames. This ensures
    tests are pure validation with no computation - all metrics must be
    pre-computed by analyzers.

    Example:
        >>> from oumi.analyze.testing import TestEngine, TestParams, TestType
        >>> from oumi.analyze.testing.results import TestSeverity
        >>>
        >>> tests = [
        ...     TestParams(
        ...         id="max_words",
        ...         type=TestType.THRESHOLD,
        ...         metric="LengthAnalyzer.total_words",
        ...         operator=">",
        ...         value=10000,
        ...         max_percentage=5.0,
        ...         severity=TestSeverity.MEDIUM,
        ...     ),
        ... ]
        >>> engine = TestEngine(tests)
        >>> summary = engine.run(results)
        >>> print(f"Pass rate: {summary.pass_rate}%")

    Args:
        tests: List of test configurations.
    """

    def __init__(self, tests: list[TestParams]):
        """Initialize the test engine with test configurations."""
        self.tests = tests

    def _create_error_result(self, test: TestParams, error: str) -> TestResult:
        """Create a TestResult for an error condition.

        Args:
            test: Test configuration.
            error: Error message.

        Returns:
            TestResult with passed=False and error set.
        """
        return TestResult(
            test_id=test.id,
            passed=False,
            severity=TestSeverity(test.severity),
            title=test.title or test.id,
            description=test.description or "",
            metric=test.metric or "",
            error=error,
        )

    def _calculate_percentage(self, count: int, total: int) -> float:
        """Calculate percentage, handling division by zero.

        Args:
            count: Numerator.
            total: Denominator.

        Returns:
            Percentage (0.0 to 100.0).
        """
        return 100.0 * count / total if total > 0 else 0.0

    def _build_test_result(
        self,
        test: TestParams,
        passed: bool,
        total_count: int,
        affected_indices: list[int],
        affected_pct: float,
        details: dict[str, Any],
        actual_value: float | None = None,
    ) -> TestResult:
        """Build a TestResult from common fields.

        Args:
            test: Test configuration.
            passed: Whether the test passed.
            total_count: Total number of values tested.
            affected_indices: Indices of affected samples.
            affected_pct: Percentage of affected samples.
            details: Test-specific details.
            actual_value: Actual metric value for single-value tests.

        Returns:
            TestResult instance.
        """
        return TestResult(
            test_id=test.id,
            passed=passed,
            severity=TestSeverity(test.severity),
            title=test.title or test.id,
            description=test.description or "",
            metric=test.metric or "",
            affected_count=len(affected_indices),
            total_count=total_count,
            affected_percentage=round(affected_pct, 2),
            threshold=test.max_percentage or test.min_percentage,
            actual_value=actual_value,
            sample_indices=affected_indices[:50],
            details=details,
        )

    def _get_actual_value(self, values: list[Any]) -> float | None:
        """Extract actual value for single-value metrics.

        Args:
            values: List of metric values.

        Returns:
            Float value if this is a single numeric value, None otherwise.
        """
        if len(values) == 1:
            val = values[0]
            # bool is a subclass of int, so check it first; otherwise the
            # numeric branch below would shadow it.
            if isinstance(val, bool):
                return 1.0 if val else 0.0
            if isinstance(val, int | float):
                return float(val)
        return None
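
    # For instance, _calculate_percentage(3, 200) returns 1.5, and
    # _get_actual_value([42]) returns 42.0 while _get_actual_value([1, 2])
    # returns None, since the actual value is only reported when the metric
    # is a single numeric value.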

    def run(
        self,
        results: dict[str, list[BaseModel] | BaseModel],
    ) -> TestSummary:
        """Run all tests on the analysis results.

        Args:
            results: Dictionary mapping analyzer names to results.

        Returns:
            TestSummary containing all test results.
        """
        test_results: list[TestResult] = []

        logger.info(f"Running {len(self.tests)} tests...")

        for test in self.tests:
            try:
                result = self._run_single_test(test, results)
                test_results.append(result)

                status = "PASSED" if result.passed else "FAILED"
                logger.debug(
                    f"  Test '{test.id}': {status} "
                    f"({result.affected_count}/{result.total_count} affected)"
                )
            except Exception as e:
                error_result = self._create_error_result(
                    test, f"Test execution failed: {e}"
                )
                test_results.append(error_result)
                logger.warning(f"  Test '{test.id}': ERROR - {e}")

        summary = TestSummary.from_results(test_results)

        logger.info(
            f"Test results: {summary.passed_tests}/{summary.total_tests} passed "
            f"({summary.pass_rate}%)"
        )
        if summary.high_severity_failures > 0:
            logger.warning(
                f"  {summary.high_severity_failures} high severity failures"
            )

        return summary
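
    # A failure inside a single test (for example a metric path that hits an
    # unexpected type) is converted into an error TestResult and logged, so
    # the remaining tests still run.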

    def _run_single_test(
        self,
        test: TestParams,
        results: dict[str, list[BaseModel] | BaseModel],
    ) -> TestResult:
        """Run a single test.

        Args:
            test: Test configuration.
            results: Analysis results.

        Returns:
            TestResult for this test.
        """
        if not test.metric:
            return self._create_error_result(test, "Test requires 'metric' field")

        values = self._extract_metric_values(test.metric, results)
        if not values:
            return self._create_error_result(
                test, f"Metric '{test.metric}' not found in results"
            )

        if test.type == TestType.THRESHOLD:
            return self._run_threshold_test(test, values)
        else:
            return self._create_error_result(
                test, f"Unknown test type: {test.type}"
            )

    def _extract_metric_values(
        self,
        metric: str,
        results: dict[str, list[BaseModel] | BaseModel],
    ) -> list[Any]:
        """Extract metric values from results.

        Metric format: "AnalyzerName.field_name" or "AnalyzerName.nested.field"

        Args:
            metric: Metric path string.
            results: Analysis results.

        Returns:
            List of values for the metric.
        """
        parts = metric.split(".")
        if len(parts) < 2:
            return []

        analyzer_name = parts[0]
        field_path = parts[1:]

        if analyzer_name not in results:
            return []

        analyzer_results = results[analyzer_name]

        if isinstance(analyzer_results, BaseModel):
            value = self._get_nested_value(analyzer_results, field_path)
            return [value] if value is not None else []

        values = []
        for result in analyzer_results:
            value = self._get_nested_value(result, field_path)
            if value is not None:
                values.append(value)
        return values

    def _get_nested_value(
        self,
        obj: Any,
        field_path: list[str],
    ) -> Any | None:
        """Get a nested field value from a Pydantic model or dict.

        Args:
            obj: Pydantic model instance or dict.
            field_path: List of field names to traverse.

        Returns:
            Field value or None if not found.
        """
        current: Any = obj
        for i, field in enumerate(field_path):
            if isinstance(current, BaseModel):
                # Pydantic model - check if field exists in model_fields
                if field in type(current).model_fields:
                    current = getattr(current, field)
                else:
                    # Check for CustomMetricResult with values dict
                    values = getattr(current, "values", None)
                    if isinstance(values, dict):
                        return self._traverse_dict(values, field_path[i:])
                    return None
            elif isinstance(current, dict):
                if field in current:
                    current = current[field]
                else:
                    return None
            else:
                raise TypeError(
                    f"Cannot traverse type {type(current).__name__}. "
                    f"Expected BaseModel or dict, got {current!r}"
                )
        return current

    def _traverse_dict(self, d: dict, path: list[str]) -> Any | None:
        """Traverse a dict using a field path."""
        current: Any = d
        for field in path:
            if isinstance(current, dict) and field in current:
                current = current[field]
            else:
                return None
        return current
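
    # For example, a metric path "LengthAnalyzer.stats.total_words" (field
    # names illustrative) selects the "LengthAnalyzer" entry from the results
    # dict and then walks the "stats" and "total_words" fields on each result
    # model, falling back to a CustomMetricResult-style `values` dict when a
    # segment is not a declared model field.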
""" if test.operator is None or test.value is None: return self._create_error_result( test, "Threshold test requires 'operator' and 'value'" ) op_func = OPERATORS.get(test.operator) if op_func is None: return self._create_error_result(test, f"Unknown operator: {test.operator}") matching_indices = [] non_matching_indices = [] matching_reasons: dict[int, str] = {} non_matching_reasons: dict[int, str] = {} for i, value in enumerate(values): try: if op_func(value, test.value): matching_indices.append(i) matching_reasons[i] = f"{value} {test.operator} {test.value}" else: non_matching_indices.append(i) non_matching_reasons[i] = ( f"{value} does not satisfy {test.operator} {test.value}" ) except (TypeError, ValueError): non_matching_indices.append(i) non_matching_reasons[i] = f"Cannot evaluate: {value}" total_count = len(values) matching_count = len(matching_indices) non_matching_count = len(non_matching_indices) matching_pct = self._calculate_percentage(matching_count, total_count) non_matching_pct = self._calculate_percentage(non_matching_count, total_count) # Determine pass/fail based on percentage thresholds passed = True affected_indices = [] affected_pct = 0.0 failure_reasons: dict[int, str] = {} if test.max_percentage is not None and matching_pct > test.max_percentage: passed = False affected_indices = matching_indices affected_pct = matching_pct failure_reasons = matching_reasons if test.min_percentage is not None and matching_pct < test.min_percentage: passed = False # If max also failed, combine; otherwise use non-matching if not affected_indices: affected_indices = non_matching_indices affected_pct = non_matching_pct failure_reasons = non_matching_reasons # Default case: no percentage thresholds, all must match if test.max_percentage is None and test.min_percentage is None: passed = non_matching_count == 0 affected_indices = non_matching_indices affected_pct = non_matching_pct failure_reasons = non_matching_reasons return self._build_test_result( test=test, passed=passed, total_count=total_count, affected_indices=affected_indices, affected_pct=affected_pct, actual_value=self._get_actual_value(values), details={ "operator": test.operator, "value": test.value, "max_percentage": test.max_percentage, "min_percentage": test.min_percentage, "matching_count": matching_count, "matching_percentage": round(matching_pct, 2), "failure_reasons": { k: v for k, v in list(failure_reasons.items())[:50] }, }, )