# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Batch-aware test engine for validating analysis results incrementally.

Processes results one batch at a time without holding the full result set
in memory, accumulating only lightweight counters, a capped number of
reason strings, and the conversation IDs seen by each test. Call
``process_batch()`` for each batch, then ``finalize()`` to compute the
final TestSummary.
"""
import logging
from dataclasses import dataclass, field
from typing import Any
from pydantic import BaseModel
from oumi.analyze.testing.engine import (
MAX_FAILURE_REASONS,
MAX_SAMPLE_INDICES,
OPERATORS,
)
from oumi.analyze.testing.results import TestResult, TestSeverity, TestSummary
from oumi.core.configs.params.test_params import TestParams, TestType
logger = logging.getLogger(__name__)
@dataclass
class _TestAccumulator:
    """Running state for a single test, carried from one batch to the next.

    Attributes:
        test: The test configuration this accumulator belongs to.
        matching_count: Number of evaluated values for which the test's
            operator comparison was truthy.
        non_matching_count: Number of evaluated values for which the
            comparison was falsy or could not be evaluated.
        total_count: Number of non-``None`` metric values seen so far.
        matching_conversation_ids: One entry per matching value, in
            processing order. An entry is ``None`` when no conversation ID
            was available for that value.
        non_matching_conversation_ids: Same as above, for non-matching
            values.
        matching_reasons: Human-readable explanations keyed by the value's
            position in the matching sequence. Only a capped number of
            entries is kept.
        non_matching_reasons: Same as above, for non-matching values.
        error: Error message recorded for this test, or ``None``. Once set,
            the engine skips the test in later batches.
    """

    test: TestParams
    matching_count: int = 0
    non_matching_count: int = 0
    total_count: int = 0
    matching_conversation_ids: list[str | None] = field(default_factory=list)
    non_matching_conversation_ids: list[str | None] = field(default_factory=list)
    matching_reasons: dict[int, str] = field(default_factory=dict)
    non_matching_reasons: dict[int, str] = field(default_factory=dict)
    error: str | None = None
class BatchTestEngine:
    """Engine for running tests on analysis results incrementally.

    Unlike ``TestEngine`` which requires the full dataset in memory,
    ``BatchTestEngine`` processes one batch at a time and accumulates
    only counters and affected conversation IDs.

    Per-test state is keyed by ``test.id``, so tests that share an ID also
    share a single accumulator.

    Example:
        >>> engine = BatchTestEngine(tests)
        >>> for batch_results, batch_ids in batches:
        ...     engine.process_batch(batch_results, batch_ids)
        >>> summary = engine.finalize()

    Args:
        tests: List of test configurations.
    """

    def __init__(self, tests: list[TestParams]):
        """Initialize the batch test engine with test configurations."""
        self.tests = tests
        self._accumulators: dict[str, _TestAccumulator] = {
            test.id: _TestAccumulator(test=test) for test in tests
        }

    def process_batch(
        self,
        results: dict[str, list[BaseModel] | BaseModel],
        conversation_ids: list[str | None],
    ) -> None:
        """Process one batch of analysis results.

        A test whose accumulator already carries an error is skipped. Any
        exception raised while processing a test is recorded on that test's
        accumulator and logged as a warning; it does not stop the remaining
        tests from running.

        Args:
            results: Analyzer results for this batch (same format as
                ``TestEngine.run()``).
            conversation_ids: Conversation IDs for each item in this batch,
                aligned by index with the per-conversation result lists.
        """
        for test in self.tests:
            acc = self._accumulators[test.id]
            if acc.error is not None:
                continue
            try:
                self._process_test_batch(acc, test, results, conversation_ids)
            except Exception as e:
                # Per-test boundary: one failing test must not abort the batch.
                acc.error = f"Test execution failed: {e}"
                logger.warning(f" Test '{test.id}': ERROR - {e}")

    def _create_error_result(self, test: TestParams, error: str) -> TestResult:
        """Create a failed TestResult carrying ``error`` for ``test``."""
        return TestResult(
            test_id=test.id,
            passed=False,
            severity=TestSeverity(test.severity),
            title=test.title or test.id,
            description=test.description or "",
            metric=test.metric or "",
            error=error,
        )

    def finalize(self) -> TestSummary:
        """Compute final test results from accumulated batch data.

        A test produces an error result when an error was recorded during
        batch processing, when it has no ``metric`` configured, or when no
        metric values were accumulated for it.

        Returns:
            TestSummary with pass/fail for each test.
        """
        test_results: list[TestResult] = []
        for test in self.tests:
            acc = self._accumulators[test.id]
            if acc.error is not None:
                test_results.append(self._create_error_result(test, acc.error))
            elif not test.metric:
                # Only reachable when no batch was ever processed; without this
                # guard a metric-less test would be reported as a normal result.
                test_results.append(
                    self._create_error_result(test, "Test requires 'metric' field")
                )
            elif acc.total_count == 0:
                test_results.append(
                    self._create_error_result(
                        test, f"Metric '{test.metric}' not found in results"
                    )
                )
            else:
                test_results.append(self._build_final_result(acc))

        summary = TestSummary.from_results(test_results)
        logger.info(
            f"Test results: {summary.passed_tests}/{summary.total_tests} passed "
            f"({summary.pass_rate}%)"
        )
        if summary.high_severity_failures > 0:
            logger.warning(f" {summary.high_severity_failures} high severity failures")
        return summary

    def get_affected_conversation_ids(self) -> dict[str, list[str | None]]:
        """Return affected conversation IDs per test.

        Call after ``finalize()`` to get the full mapping for persistence
        (e.g. ``test_affected_rows.json``).

        Tests with a recorded error map to an empty list. The returned lists
        are the engine's own accumulated lists, not copies.
        """
        result: dict[str, list[str | None]] = {}
        for test in self.tests:
            acc = self._accumulators[test.id]
            if acc.error is not None:
                result[test.id] = []
            else:
                result[test.id] = self._get_affected_ids(acc)
        return result

    def _process_test_batch(
        self,
        acc: _TestAccumulator,
        test: TestParams,
        results: dict[str, list[BaseModel] | BaseModel],
        conversation_ids: list[str | None],
    ) -> None:
        """Process a single test against one batch of results.

        Configuration problems (missing metric, unsupported type, missing or
        unknown operator) are stored in ``acc.error`` instead of being raised.
        A batch that yields no values for the metric leaves ``acc`` untouched.
        """
        if not test.metric:
            acc.error = "Test requires 'metric' field"
            return
        values = self._extract_metric_values(test.metric, results)
        if not values:
            return
        if test.type != TestType.THRESHOLD:
            acc.error = f"Unknown test type: {test.type}"
            return
        if test.operator is None or test.value is None:
            acc.error = "Threshold test requires 'operator' and 'value'"
            return
        op_func = OPERATORS.get(test.operator)
        if op_func is None:
            acc.error = f"Unknown operator: {test.operator}"
            return

        num_ids = len(conversation_ids)
        for orig_idx, value in values:
            conv_id = conversation_ids[orig_idx] if orig_idx < num_ids else None
            # Guard only the comparison, so a failure during bookkeeping can
            # never count the same value as both matching and non-matching.
            try:
                matched = bool(op_func(value, test.value))
            except (TypeError, ValueError):
                self._record_value(
                    acc, test, conv_id, value, matched=False, evaluable=False
                )
                continue
            self._record_value(acc, test, conv_id, value, matched=matched)
        acc.total_count += len(values)

    @staticmethod
    def _record_value(
        acc: _TestAccumulator,
        test: TestParams,
        conv_id: str | None,
        value: Any,
        *,
        matched: bool,
        evaluable: bool = True,
    ) -> None:
        """Record one evaluated value in the matching or non-matching bucket.

        Increments the bucket's counter, appends ``conv_id`` to its ID list
        and, while fewer than ``MAX_FAILURE_REASONS`` reasons are stored,
        saves a reason keyed by the value's position within the bucket.
        """
        if matched:
            position = acc.matching_count
            acc.matching_count += 1
            ids = acc.matching_conversation_ids
            reasons = acc.matching_reasons
        else:
            position = acc.non_matching_count
            acc.non_matching_count += 1
            ids = acc.non_matching_conversation_ids
            reasons = acc.non_matching_reasons
        ids.append(conv_id)

        # The reason text is only built while a slot is still available.
        if len(reasons) >= MAX_FAILURE_REASONS:
            return
        if not evaluable:
            reasons[position] = f"Cannot evaluate: {value}"
        elif matched:
            reasons[position] = (
                f"Flagged: {test.metric} {test.operator} {test.value}"
                f" (value={value})"
            )
        else:
            reasons[position] = (
                f"Not flagged: {test.metric} {test.operator} {test.value}"
                f" (value={value})"
            )

    def _determine_outcome(
        self, acc: _TestAccumulator
    ) -> tuple[bool, list[str | None], float, dict[int, str]]:
        """Determine pass/fail and select the affected set.

        With ``max_percentage`` set, the test fails when the matching
        percentage exceeds it, and the matching values are reported as
        affected. With ``min_percentage`` set, the test fails when the
        matching percentage falls below it, and the non-matching values are
        reported unless an affected set was already chosen. With neither
        set, the test passes only when nothing matched, and the matching
        values are reported.

        Returns:
            (passed, affected_ids, affected_pct, failure_reasons)
        """
        test = acc.test
        total_count = acc.total_count
        matching_count = acc.matching_count
        if total_count > 0:
            matching_pct = 100.0 * matching_count / total_count
            non_matching_pct = 100.0 * acc.non_matching_count / total_count
        else:
            matching_pct = 0.0
            non_matching_pct = 0.0

        passed = True
        affected_ids: list[str | None] = []
        affected_pct = 0.0
        failure_reasons: dict[int, str] = {}

        if test.max_percentage is not None and matching_pct > test.max_percentage:
            passed = False
            affected_ids = acc.matching_conversation_ids
            affected_pct = matching_pct
            failure_reasons = acc.matching_reasons

        if test.min_percentage is not None and matching_pct < test.min_percentage:
            passed = False
            if not affected_ids:
                affected_ids = acc.non_matching_conversation_ids
                affected_pct = non_matching_pct
                failure_reasons = acc.non_matching_reasons

        if test.max_percentage is None and test.min_percentage is None:
            passed = matching_count == 0
            affected_ids = acc.matching_conversation_ids
            affected_pct = matching_pct
            failure_reasons = acc.matching_reasons

        return passed, affected_ids, affected_pct, failure_reasons

    def _build_final_result(self, acc: _TestAccumulator) -> TestResult:
        """Build the final TestResult from accumulated data."""
        test = acc.test
        passed, affected_ids, affected_pct, failure_reasons = self._determine_outcome(
            acc
        )
        # Explicit None check: a max_percentage of 0 is a valid threshold and
        # must not be skipped in favour of min_percentage.
        threshold = (
            test.max_percentage
            if test.max_percentage is not None
            else test.min_percentage
        )
        if acc.total_count > 0:
            matching_percentage = 100.0 * acc.matching_count / acc.total_count
        else:
            matching_percentage = 0.0

        return TestResult(
            test_id=test.id,
            passed=passed,
            severity=TestSeverity(test.severity),
            title=test.title or test.id,
            description=test.description or "",
            metric=test.metric or "",
            affected_count=len(affected_ids),
            total_count=acc.total_count,
            affected_percentage=round(affected_pct, 2),
            threshold=threshold,
            actual_value=None,
            sample_indices=[],  # Not meaningful for batch mode
            all_affected_indices=[],  # Not meaningful for batch mode
            details={
                "operator": test.operator,
                "value": test.value,
                "max_percentage": test.max_percentage,
                "min_percentage": test.min_percentage,
                "matching_count": acc.matching_count,
                "matching_percentage": round(matching_percentage, 2),
                "failure_reasons": {
                    str(k): v
                    for k, v in list(failure_reasons.items())[:MAX_FAILURE_REASONS]
                },
                "sample_conversation_ids": affected_ids[:MAX_SAMPLE_INDICES],
            },
        )

    def _get_affected_ids(self, acc: _TestAccumulator) -> list[str | None]:
        """Return all affected conversation IDs based on test outcome."""
        _, affected_ids, _, _ = self._determine_outcome(acc)
        return affected_ids

    def _extract_metric_values(
        self,
        metric: str,
        results: dict[str, list[BaseModel] | BaseModel],
    ) -> list[tuple[int, Any]]:
        """Extract values for a metric path like 'analyzer_name.field_name'.

        Returns a list of (original_index, value) tuples so callers can
        correctly align values with conversation IDs even when some items
        have ``None`` metric values. ``None`` values are omitted. A single
        (non-list) analyzer result is reported at index 0. An empty list is
        returned when the path has fewer than two dot-separated parts or the
        analyzer name is not a key of ``results``.
        """
        parts = metric.split(".")
        if len(parts) < 2:
            return []
        analyzer_name = parts[0]
        field_path = parts[1:]
        if analyzer_name not in results:
            return []

        analyzer_results = results[analyzer_name]
        if isinstance(analyzer_results, BaseModel):
            value = self._get_nested_value(analyzer_results, field_path)
            return [(0, value)] if value is not None else []

        values: list[tuple[int, Any]] = []
        for idx, result in enumerate(analyzer_results):
            value = self._get_nested_value(result, field_path)
            if value is not None:
                values.append((idx, value))
        return values

    def _get_nested_value(self, obj: Any, field_path: list[str]) -> Any:
        """Get a nested field value from a Pydantic model or dict.

        When a model has no declared field with the requested name, the rest
        of the path is looked up in the model's ``values`` attribute if that
        attribute is a dict.

        Returns:
            The value at ``field_path``, or ``None`` when a path segment is
            missing.

        Raises:
            TypeError: If the path continues through an object that is
                neither a ``BaseModel`` nor a dict.
        """
        current: Any = obj
        for i, field_name in enumerate(field_path):
            if isinstance(current, BaseModel):
                if field_name in type(current).model_fields:
                    current = getattr(current, field_name)
                else:
                    values = getattr(current, "values", None)
                    if isinstance(values, dict):
                        return self._traverse_dict(values, field_path[i:])
                    return None
            elif isinstance(current, dict):
                if field_name in current:
                    current = current[field_name]
                else:
                    return None
            else:
                raise TypeError(
                    f"Cannot traverse type {type(current).__name__}. "
                    f"Expected BaseModel or dict, got {current!r}"
                )
        return current

    def _traverse_dict(self, d: dict, path: list[str]) -> Any | None:
        """Traverse a dict using a field path.

        Returns ``None`` as soon as a path segment is missing or the current
        value is not a dict.
        """
        current: Any = d
        for field_name in path:
            if isinstance(current, dict) and field_name in current:
                current = current[field_name]
            else:
                return None
        return current