Source code for oumi.analyze.config

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Configuration for the typed analyzer framework."""

from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any

from oumi.core.configs.params.test_params import TestParams


class AnalyzerType(str, Enum):
    """Identifiers for the analyzers that ship with the framework.

    Members subclass ``str``, so each one compares equal to its raw value
    (e.g. ``AnalyzerType.LENGTH == "length"``). The values match the ``id``
    strings used in analyzer configurations.
    """

    # Text/structure statistics
    LENGTH = "length"
    QUALITY = "quality"
    FORMAT = "format"
    # Dataset-level and representation-based analyses
    DIVERSITY = "diversity"
    EMBEDDING = "embedding"
    # Model-judged analysis
    LLM_JUDGE = "llm_judge"


[docs] @dataclass class AnalyzerConfig: """Configuration for a single analyzer. Attributes: id: Analyzer type identifier (e.g., "length", "quality"). instance_id: Optional unique instance ID for multiple analyzers of same type. params: Analyzer-specific parameters. """ id: str instance_id: str | None = None params: dict[str, Any] = field(default_factory=dict)
[docs] def __post_init__(self): """Auto-populate instance_id if not provided.""" if self.instance_id is None: self.instance_id = self.id
@dataclass class OutputFieldSchema: """Schema definition for a single output field. Attributes: name: Field name (key in the returned dict). type: Field type ("int", "float", "bool", "str", "list"). description: Description of the field. """ name: str type: str = "float" description: str = "" @dataclass class CustomMetricConfig: """Configuration for a custom user-defined metric. Custom metrics allow users to define Python functions that compute additional metrics. These are executed during the analysis phase and their results are cached. .. warning:: **Security Warning**: The ``function`` field contains arbitrary Python code that is executed dynamically. Only load configurations from trusted sources. Never load YAML configs from untrusted users or external sources without review, as they could execute malicious code. Example YAML:: custom_metrics: - id: word_to_char_ratio scope: conversation description: "Ratio of words to characters" output_schema: - name: ratio type: float description: "Words divided by characters (0.15-0.20 is typical)" function: | def compute(conversation): chars = sum(len(m.content) for m in conversation.messages) words = sum(len(m.content.split()) for m in conversation.messages) return {"ratio": words / chars if chars > 0 else 0.0} Attributes: id: Unique identifier for the metric. scope: Scope of the metric ("message", "conversation", or "dataset"). function: Python code defining a compute() function. description: Description of what the metric computes. output_schema: List of output field definitions. """ id: str scope: str = "conversation" # "message", "conversation", or "dataset" function: str = "" description: str | None = None output_schema: list[OutputFieldSchema] = field(default_factory=list) depends_on: list[str] = field(default_factory=list) def __post_init__(self): """Validate the configuration.""" if self.scope not in ("message", "conversation", "dataset"): raise ValueError( f"Invalid scope '{self.scope}'. 
" "Must be 'message', 'conversation', or 'dataset'." ) def get_metric_paths(self) -> list[str]: """Get full metric paths for all output fields. Returns: List of metric paths like ["metric_id.field_name", ...]. """ if self.output_schema: return [f"{self.id}.{f.name}" for f in self.output_schema] return [f"{self.id}.<field>"] def get_field_info(self) -> dict[str, dict[str, str]]: """Get field information for display. Returns: Dict mapping field names to {"type": ..., "description": ...}. """ return { f.name: {"type": f.type, "description": f.description} for f in self.output_schema }
@dataclass
class TypedAnalyzeConfig:
    """Configuration for the typed analyzer pipeline.

    This is the main configuration class for the new typed analyzer
    architecture. It supports both programmatic construction and loading from
    YAML files.

    Example YAML::

        dataset_path: /path/to/data.jsonl
        sample_count: 1000
        output_path: ./analysis_output

        analyzers:
          - id: length
            params:
              count_tokens: true
          - id: quality

        custom_metrics:
          - id: turn_pattern
            scope: conversation
            function: |
              def compute(conversation):
                  ...

        tests:
          - id: max_words
            type: threshold
            metric: LengthAnalyzer.total_words
            operator: ">"
            value: 10000
            max_percentage: 5.0

    Attributes:
        eval_name: Optional eval name (for the web viewer).
        parent_eval_id: Optional parent eval ID (for linking derived analyses).
        dataset_name: Name of the dataset (HuggingFace identifier).
        dataset_path: Path to local dataset file.
        split: Dataset split to use.
        subset: Optional dataset subset.
        sample_count: Number of samples to analyze.
        output_path: Directory for output artifacts.
        analyzers: List of analyzer configurations.
        custom_metrics: List of custom metric configurations.
        tests: List of test configurations.
        tokenizer_name: Tokenizer for token counting.
        tokenizer_kwargs: Additional tokenizer options (presumably forwarded
            to the tokenizer loader -- confirm with the consumer).
        generate_report: Whether to generate HTML report.
        report_title: Custom title for the report.
    """

    # Eval name (optional, for web viewer)
    eval_name: str | None = None
    # Parent eval ID (for linking derived analyses)
    parent_eval_id: str | None = None

    # Dataset source
    dataset_name: str | None = None
    dataset_path: str | None = None
    split: str = "train"
    subset: str | None = None
    sample_count: int | None = None

    # Output
    output_path: str = "."

    # Analyzers
    analyzers: list[AnalyzerConfig] = field(default_factory=list)

    # Custom metrics
    custom_metrics: list[CustomMetricConfig] = field(default_factory=list)

    # Tests
    tests: list[TestParams] = field(default_factory=list)

    # Tokenizer
    tokenizer_name: str | None = None
    tokenizer_kwargs: dict[str, Any] = field(default_factory=dict)

    # Report
    generate_report: bool = False
    report_title: str | None = None

    @classmethod
    def from_yaml(
        cls, path: str | Path, allow_custom_code: bool = False
    ) -> "TypedAnalyzeConfig":
        """Load configuration from a YAML file.

        .. warning::
            **Security Warning**: If the YAML file contains ``custom_metrics``
            with ``function`` fields, arbitrary Python code will be loaded.
            Only load configurations from trusted sources. Set
            ``allow_custom_code=True`` to explicitly acknowledge this risk.

        Args:
            path: Path to YAML configuration file (read as UTF-8).
            allow_custom_code: If True, allow loading custom_metrics with
                function code. If False (default) and the config contains
                custom metrics with code, raises ValueError.

        Returns:
            TypedAnalyzeConfig instance. An empty file yields the default
            configuration.

        Raises:
            ValueError: If the top level of the file is not a mapping, or if
                config contains custom code but allow_custom_code=False.
        """
        import yaml

        # Explicit encoding: the platform default is locale-dependent.
        with open(path, encoding="utf-8") as f:
            # safe_load only builds plain Python data; custom metric code is
            # loaded as a string and is never executed here.
            data = yaml.safe_load(f)

        # An empty YAML document loads as None; treat it as an empty config.
        if data is None:
            data = {}
        return cls.from_dict(data, allow_custom_code=allow_custom_code)

    @classmethod
    def from_dict(
        cls, data: dict[str, Any], allow_custom_code: bool = False
    ) -> "TypedAnalyzeConfig":
        """Create configuration from a dictionary.

        The input dictionary and its nested dictionaries are not modified.
        Keys whose value is ``None`` for list-valued sections (``analyzers``,
        ``custom_metrics``, ``tests``, ``output_schema``) are treated as empty.

        Args:
            data: Configuration dictionary.
            allow_custom_code: If True, allow custom_metrics with function code.
                If False (default) and the config contains custom metrics with
                code, raises ValueError.

        Returns:
            TypedAnalyzeConfig instance.

        Raises:
            ValueError: If ``data`` is not a dict, an analyzer / custom metric /
                output field entry has an unsupported type, analyzer
                instance_ids collide, or config contains custom code but
                allow_custom_code=False.
        """
        if not isinstance(data, dict):
            raise ValueError(
                f"Configuration must be a mapping, got {type(data).__name__}."
            )

        # Parse analyzers: accepts dicts, bare id strings, or ready instances.
        analyzers: list[AnalyzerConfig] = []
        for analyzer_data in data.get("analyzers") or []:
            if isinstance(analyzer_data, AnalyzerConfig):
                analyzers.append(analyzer_data)
            elif isinstance(analyzer_data, dict):
                analyzers.append(AnalyzerConfig(**analyzer_data))
            elif isinstance(analyzer_data, str):
                analyzers.append(AnalyzerConfig(id=analyzer_data))
            else:
                raise ValueError(
                    f"Invalid analyzer entry {analyzer_data!r}: expected a "
                    "mapping or an analyzer id string."
                )

        # Validate unique instance_ids (reported in first-seen order).
        seen_ids: set = set()
        duplicates: list = []
        for analyzer in analyzers:
            instance_id = analyzer.instance_id
            if instance_id in seen_ids and instance_id not in duplicates:
                duplicates.append(instance_id)
            seen_ids.add(instance_id)
        if duplicates:
            raise ValueError(
                f"Duplicate analyzer instance_id values: {duplicates}. "
                "Each analyzer must have a unique instance_id to avoid collisions."
            )

        # Parse custom metrics.
        custom_metrics: list[CustomMetricConfig] = []
        for metric_data in data.get("custom_metrics") or []:
            if isinstance(metric_data, CustomMetricConfig):
                custom_metrics.append(metric_data)
                continue
            if not isinstance(metric_data, dict):
                raise ValueError(
                    f"Invalid custom metric entry {metric_data!r}: "
                    "expected a mapping."
                )
            # Work on a copy so removing output_schema does not mutate the
            # caller's dictionary.
            metric_kwargs = dict(metric_data)
            output_schema: list[OutputFieldSchema] = []
            for field_data in metric_kwargs.pop("output_schema", None) or []:
                if isinstance(field_data, OutputFieldSchema):
                    output_schema.append(field_data)
                elif isinstance(field_data, dict):
                    output_schema.append(OutputFieldSchema(**field_data))
                else:
                    raise ValueError(
                        f"Invalid output_schema entry {field_data!r} in custom "
                        f"metric {metric_kwargs.get('id')!r}: expected a mapping."
                    )
            custom_metrics.append(
                CustomMetricConfig(**metric_kwargs, output_schema=output_schema)
            )

        # Security check: reject custom code unless explicitly allowed.
        # ``function`` may be None when the YAML key has no value.
        if not allow_custom_code:
            metrics_with_code = [
                m.id for m in custom_metrics if (m.function or "").strip()
            ]
            if metrics_with_code:
                raise ValueError(
                    "Configuration contains custom metrics with executable code: "
                    f"{metrics_with_code}. This is a security risk if loading from "
                    "untrusted sources. Set allow_custom_code=True to explicitly "
                    "allow code execution, or remove the 'function' fields."
                )

        # Parse tests.
        tests: list[TestParams] = []
        for test_data in data.get("tests") or []:
            test_params = TestParams(**test_data)
            test_params.finalize_and_validate()
            tests.append(test_params)

        return cls(
            eval_name=data.get("eval_name"),
            parent_eval_id=data.get("parent_eval_id"),
            dataset_name=data.get("dataset_name"),
            dataset_path=data.get("dataset_path"),
            split=data.get("split", "train"),
            subset=data.get("subset"),
            sample_count=data.get("sample_count"),
            output_path=data.get("output_path", "."),
            analyzers=analyzers,
            custom_metrics=custom_metrics,
            tests=tests,
            tokenizer_name=data.get("tokenizer_name"),
            # Copy so the config does not share a mutable dict with the caller.
            tokenizer_kwargs=dict(data.get("tokenizer_kwargs") or {}),
            generate_report=data.get("generate_report", False),
            report_title=data.get("report_title"),
        )

    def to_dict(self) -> dict[str, Any]:
        """Convert configuration to a dictionary.

        Emits every top-level key that :meth:`from_dict` reads, including
        custom metric ``output_schema`` and ``depends_on``, so the result is
        intended to round-trip through ``from_dict`` (custom metric code
        requires ``allow_custom_code=True``). Nested dicts and lists are
        shallow copies, so mutating the result does not alter this config.

        Returns:
            Configuration as dictionary.
        """

        def _field_to_dict(schema_field: Any) -> dict[str, Any]:
            # Entries are normally OutputFieldSchema; tolerate plain dicts
            # from programmatic construction.
            if isinstance(schema_field, dict):
                return dict(schema_field)
            return {
                "name": schema_field.name,
                "type": schema_field.type,
                "description": schema_field.description,
            }

        return {
            "eval_name": self.eval_name,
            "parent_eval_id": self.parent_eval_id,
            "dataset_name": self.dataset_name,
            "dataset_path": self.dataset_path,
            "split": self.split,
            "subset": self.subset,
            "sample_count": self.sample_count,
            "output_path": self.output_path,
            "analyzers": [
                {
                    "id": a.id,
                    "instance_id": a.instance_id,
                    "params": dict(a.params),
                }
                for a in self.analyzers
            ],
            "custom_metrics": [
                {
                    "id": m.id,
                    "scope": m.scope,
                    "function": m.function,
                    "description": m.description,
                    "output_schema": [
                        _field_to_dict(f) for f in (m.output_schema or [])
                    ],
                    "depends_on": list(m.depends_on or []),
                }
                for m in self.custom_metrics
            ],
            "tests": [
                {
                    "id": t.id,
                    "type": t.type,
                    "metric": t.metric,
                    "severity": t.severity,
                    "title": t.title,
                    "description": t.description,
                    "operator": t.operator,
                    "value": t.value,
                    "condition": t.condition,
                    "max_percentage": t.max_percentage,
                    "min_percentage": t.min_percentage,
                }
                for t in self.tests
            ],
            "tokenizer_name": self.tokenizer_name,
            "tokenizer_kwargs": dict(self.tokenizer_kwargs),
            "generate_report": self.generate_report,
            "report_title": self.report_title,
        }

    def get_test_configs(self) -> list[TestParams]:
        """Get test configurations for the test engine.

        Returns:
            List of TestParams instances (the same list object held by this
            config, not a copy).
        """
        return self.tests