# Source code for oumi.core.configs.analyze_config

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Optional

from omegaconf import MISSING

from oumi.core.configs.base_config import BaseConfig
from oumi.core.configs.params.base_params import BaseParams


class DatasetSource(Enum):
    """Where the dataset that is going to be analyzed comes from."""

    CONFIG = "config"
    """Build the dataset from the config's own parameters (dataset_name,
    dataset_path, etc.)."""

    DIRECT = "direct"
    """The dataset object is handed straight to DatasetAnalyzer.__init__()."""
@dataclass
class SampleAnalyzerParams(BaseParams):
    """Parameters describing one sample-analyzer plugin entry."""

    id: str = MISSING
    """Identifier of this analyzer; expected to be unique across the analyzers
    of a configuration."""

    params: dict[str, Any] = field(default_factory=dict)
    """Keyword parameters specific to this analyzer, forwarded to its
    constructor."""
@dataclass
class AnalyzeConfig(BaseConfig):
    """Configuration for dataset analysis and aggregation."""

    # Required field - must come first
    dataset_source: DatasetSource = MISSING
    """Source of the dataset for analysis.

    Use CONFIG to load from config parameters or DIRECT to pass dataset directly
    to DatasetAnalyzer.__init__().

    This field is required and must be explicitly set. A plain string equal to
    one of the enum values ("config" or "direct") is converted to the matching
    DatasetSource member during validation.
    """

    # Simple fields for common use cases
    dataset_name: Optional[str] = None
    """Dataset name."""

    dataset_path: Optional[str] = None
    """Path to a custom dataset file (JSON or JSONL format).

    If provided, this takes precedence over dataset_name for loading custom
    datasets.
    """

    dataset_format: Optional[str] = None
    """Format of the custom dataset.

    Either 'oumi' (conversation format) or 'alpaca'. Only used when dataset_path
    is provided.
    """

    split: str = "train"
    """The split of the dataset to load.

    This is typically one of "train", "test", or "validation". Defaults to
    "train".
    """

    subset: Optional[str] = None
    """The subset of the dataset to load. If None, uses the base dataset."""

    sample_count: Optional[int] = None
    """The number of examples to sample from the dataset.

    If None, uses the full dataset. If specified, must be greater than 0.
    """

    output_path: str = "."
    """Directory path where output files will be saved.

    Defaults to current directory ('.').
    """

    analyzers: list[SampleAnalyzerParams] = field(default_factory=list)
    """List of analyzer configurations (plugin-style).

    Every analyzer must have a non-empty 'id', and ids must be unique.
    """

    tokenizer_config: Optional[dict[str, Any]] = None
    """Tokenizer configuration for building a tokenizer.

    If None, no tokenizer will be used.

    Expected format:
    {
        "model_name": "gpt2",  # Required: model name for tokenizer
        "tokenizer_kwargs": {},  # Optional: additional tokenizer parameters
        "trust_remote_code": False  # Optional: whether to trust remote code
    }
    """

    # Add processor parameters for vision-language datasets
    processor_name: Optional[str] = None
    """Processor name for vision-language datasets.

    Required when a custom dataset_path is used with is_multimodal=True.
    """

    processor_kwargs: dict[str, Any] = field(default_factory=dict)
    """Processor-specific parameters."""

    trust_remote_code: bool = False
    """Whether to trust remote code for processor loading."""

    is_multimodal: Optional[bool] = None
    """If True, treat the dataset as multimodal (vision-language) when using a
    custom dataset_path. If False, treat as text-only.

    Must be set explicitly whenever dataset_path is provided; multimodal
    datasets additionally require dataset_format='oumi' and a processor_name.
    """

    def __post_init__(self):
        """Validates (and lightly normalizes) the configuration parameters.

        Normalization performed:
            - A string ``dataset_source`` is converted to its DatasetSource
              member.
            - With DatasetSource.DIRECT, an empty ``dataset_name`` defaults to
              "Custom Dataset".

        Raises:
            ValueError: If ``dataset_source`` is unset or not a valid
                DatasetSource; if neither ``dataset_name`` nor ``dataset_path``
                is given for DatasetSource.CONFIG; if ``dataset_path`` is given
                without a valid ``dataset_format`` or an explicit
                ``is_multimodal``; if a multimodal custom dataset lacks
                dataset_format='oumi' or a ``processor_name``; if
                ``sample_count`` is not positive; or if an analyzer id is
                missing or duplicated.
        """
        # Normalize dataset_source first. Without this, any value other than
        # DatasetSource.CONFIG -- including the unset MISSING sentinel ("???")
        # or the plain string "config" -- would silently take the DIRECT
        # branch below.
        if not isinstance(self.dataset_source, DatasetSource):
            try:
                self.dataset_source = DatasetSource(self.dataset_source)
            except (ValueError, TypeError):
                valid_values = ", ".join(repr(s.value) for s in DatasetSource)
                raise ValueError(
                    "'dataset_source' must be explicitly set to a DatasetSource "
                    f"(one of {valid_values}); got {self.dataset_source!r}."
                ) from None

        if self.dataset_source is DatasetSource.CONFIG:
            # Only require dataset info when loading from config
            if not self.dataset_name and not self.dataset_path:
                raise ValueError(
                    "Either 'dataset_name' or 'dataset_path' must be provided when "
                    "dataset_source=DatasetSource.CONFIG"
                )
        elif not self.dataset_name:
            # When using direct dataset, dataset_name is optional but recommended
            self.dataset_name = "Custom Dataset"

        # Validate dataset_format requirements
        if self.dataset_path is not None:
            if self.dataset_format is None:
                raise ValueError(
                    "'dataset_format' must be specified when using 'dataset_path'. "
                    "Use 'oumi' for conversation format or 'alpaca' for instruction "
                    "format."
                )
            if self.dataset_format not in ("oumi", "alpaca"):
                raise ValueError("'dataset_format' must be either 'oumi' or 'alpaca'")

            # Require explicit is_multimodal setting for custom datasets
            if self.is_multimodal is None:
                raise ValueError(
                    "'is_multimodal' must be specified when using 'dataset_path'. "
                    "Set to 'True' for vision-language datasets or 'False' for "
                    "text-only datasets."
                )

            # Additional validation for multimodal
            if self.is_multimodal is True:
                # Currently VLJsonlinesDataset expects oumi conversation format
                if self.dataset_format != "oumi":
                    raise ValueError(
                        "Multimodal datasets require dataset_format='oumi'"
                    )
                if not self.processor_name:
                    raise ValueError(
                        "'processor_name' must be specified when 'is_multimodal' "
                        "is True"
                    )

        # Validate sample_count
        if self.sample_count is not None and self.sample_count <= 0:
            raise ValueError("`sample_count` must be greater than 0.")

        # Validate analyzer configurations
        seen_ids: set[str] = set()
        for analyzer in self.analyzers:
            # MISSING ("???") is truthy, so an unset id must be checked
            # explicitly in addition to the empty-string case.
            if not analyzer.id or analyzer.id == MISSING:
                raise ValueError("Analyzer 'id' must be provided")
            if analyzer.id in seen_ids:
                raise ValueError(f"Duplicate analyzer ID found: '{analyzer.id}'")
            seen_ids.add(analyzer.id)