Source code for oumi.core.analyze.dataset_analyzer

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from dataclasses import asdict, dataclass
from typing import Any, Optional, Union

import pandas as pd
from tqdm import tqdm

from oumi.core.configs import AnalyzeConfig, DatasetSource
from oumi.core.datasets import BaseMapDataset
from oumi.core.registry import REGISTRY
from oumi.utils.analysis_utils import (
    build_tokenizer_from_config,
    load_dataset_from_config,
)
from oumi.utils.logging import logger


@dataclass
class MessageAnalysisResult:
    """Result of analyzing a single message in a conversation.

    Attributes:
        message_index: Index of the message within the conversation
        role: Role of the message sender (e.g., 'user', 'assistant')
        message_id: Unique identifier for the message
        text_content: The text content of the message
        analyzer_metrics: Dictionary containing analyzer metrics for this message
    """

    message_index: int
    role: str
    message_id: str
    text_content: str
    analyzer_metrics: dict[str, Any]

    def to_dict(self) -> dict[str, Any]:
        """Convert the analysis result to a dictionary with flattened analyzer metrics.

        Returns:
            Dictionary representation of the analysis result
        """
        return asdict(self)


@dataclass
class ConversationAnalysisResult:
    """Result of analyzing a conversation as a whole.

    Attributes:
        analyzer_metrics: Dictionary containing analyzer metrics for the conversation
    """

    analyzer_metrics: dict[str, Any]

    def to_dict(self) -> dict[str, Any]:
        """Convert the analysis result to a dictionary.

        Returns:
            Dictionary representation of the analysis result
        """
        return asdict(self)


@dataclass
class DatasetAnalysisResult:
    """Complete result of dataset analysis.

    Attributes:
        dataset_name: Name of the analyzed dataset
        total_conversations: Total number of conversations in the dataset
        conversations_analyzed: Number of conversations actually analyzed
    """

    dataset_name: str
    total_conversations: int
    conversations_analyzed: int

    def to_dict(self) -> dict[str, Any]:
        """Convert the analysis result to a dictionary.

        Returns:
            Dictionary representation of the analysis result
        """
        return asdict(self)

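# Illustrative example (not part of the original module): how a single message
# result serializes via to_dict(). The metric key shown is hypothetical.
#
#   result = MessageAnalysisResult(
#       message_index=0,
#       role="user",
#       message_id="msg_0",
#       text_content="Hello!",
#       analyzer_metrics={"length_char_count": 6},
#   )
#   result.to_dict()
#   # -> {"message_index": 0, "role": "user", "message_id": "msg_0",
#   #     "text_content": "Hello!", "analyzer_metrics": {"length_char_count": 6}}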

class DatasetAnalyzer:
    """Orchestrates the analysis of datasets using multiple sample analyzers."""

    def __init__(
        self, config: AnalyzeConfig, dataset: Optional[BaseMapDataset] = None
    ):
        """Initialize the dataset analyzer with configuration.

        Args:
            config: AnalyzeConfig object containing all analysis parameters
            dataset: Optional pre-loaded dataset. If provided, this dataset will be
                used instead of loading from the config.
        """
        self.config = config
        self.dataset_name = config.dataset_name
        self.split = config.split

        # Build tokenizer from config if provided
        self.tokenizer = build_tokenizer_from_config(config.tokenizer_config)

        # Use provided dataset or load from config based on dataset_source
        if config.dataset_source == DatasetSource.DIRECT:
            # Direct mode: must provide dataset
            if dataset is None:
                raise ValueError(
                    "Config specifies dataset_source=DatasetSource.DIRECT but no "
                    "dataset was provided. Either pass a dataset to "
                    "DatasetAnalyzer.__init__() or "
                    "set dataset_source=DatasetSource.CONFIG.value."
                )
            self.dataset = dataset
            # Use the provided dataset name if config doesn't have one
            if not self.dataset_name:
                self.dataset_name = getattr(dataset, "dataset_name", "Custom Dataset")
            logger.info(
                f"Using provided dataset '{self.dataset_name}' with "
                f"{len(dataset)} conversations"
            )
        elif config.dataset_source == DatasetSource.CONFIG:
            # Config mode: load dataset from config parameters
            if dataset is not None:
                raise ValueError(
                    f"Dataset provided but config.dataset_source is "
                    f"'{config.dataset_source.value}'. When using "
                    f"DatasetSource.CONFIG, do not pass a dataset to the "
                    f"constructor. Set dataset_source=DatasetSource.DIRECT "
                    f"if you want to use the provided dataset."
                )
            # Load dataset with the tokenizer
            self.dataset = load_dataset_from_config(config, self.tokenizer)
            logger.info(f"Loaded dataset from config: {self.dataset_name}")
        else:
            raise ValueError(f"Invalid dataset_source: {config.dataset_source}")

        self.sample_analyzers = self._initialize_sample_analyzers()

        # Initialize analysis results as None
        self._analysis_results: Optional[DatasetAnalysisResult] = None
        self._merged_df: Optional[pd.DataFrame] = None
        self._message_df: Optional[pd.DataFrame] = None
        self._conversation_df: Optional[pd.DataFrame] = None
        self._analysis_summary: Optional[dict[str, Any]] = None

        # Decimal precision for rounding metrics
        self._decimal_precision = 2

    def _initialize_sample_analyzers(self) -> dict[str, Any]:
        """Initialize sample analyzer plugins from configuration.

        Returns:
            Dictionary mapping analyzer IDs to analyzer instances
        """
        sample_analyzers = {}
        for analyzer_params in self.config.analyzers:
            try:
                # Get the analyzer class from the registry
                analyzer_class = REGISTRY.get_sample_analyzer(analyzer_params.id)
                if analyzer_class is None:
                    raise ValueError(
                        f"Sample analyzer '{analyzer_params.id}' not found in registry"
                    )

                # Prepare parameters for analyzer constructor
                analyzer_kwargs = dict(analyzer_params.params)
                if self.tokenizer is not None:
                    analyzer_kwargs["tokenizer"] = self.tokenizer

                # Create analyzer instance with keyword arguments
                sample_analyzer = analyzer_class(**analyzer_kwargs)
                sample_analyzers[analyzer_params.id] = sample_analyzer
                logger.info(f"Initialized sample analyzer: {analyzer_params.id}")
            except Exception as e:
                logger.error(
                    f"Failed to initialize sample analyzer {analyzer_params.id}: {e}"
                )
                logger.error(f"Analyzer configuration: {analyzer_params}")
        return sample_analyzers
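
    # Illustrative sketch (not part of the original module): the two ways a
    # DatasetAnalyzer can obtain its dataset, per dataset_source above. The
    # AnalyzeConfig constructor arguments shown are assumptions about a typical
    # config, not a confirmed API.
    #
    #   # CONFIG mode: the analyzer loads the dataset itself
    #   config = AnalyzeConfig(
    #       dataset_name="my_dataset",            # hypothetical dataset name
    #       split="train",
    #       dataset_source=DatasetSource.CONFIG,
    #       analyzers=[...],
    #   )
    #   analyzer = DatasetAnalyzer(config)
    #
    #   # DIRECT mode: pass an already-loaded BaseMapDataset
    #   config = AnalyzeConfig(dataset_source=DatasetSource.DIRECT, analyzers=[...])
    #   analyzer = DatasetAnalyzer(config, dataset=my_dataset)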

    def analyze_dataset(self) -> None:
        """Analyze the dataset and store results internally.

        This method performs both message-level and conversation-level analysis
        using the configured sample analyzers. Each analyzer processes entire
        conversations and returns metrics for both individual messages and
        conversations as a whole.

        Results are stored internally and can be accessed via the query() method.

        Raises:
            ValueError: If no analyzers are configured for analysis.
        """
        if not self.sample_analyzers:
            raise ValueError(
                "No analyzers configured for analysis. Please add at least one "
                "analyzer to the configuration before calling analyze_dataset()."
            )

        logger.info(f"Starting analysis of dataset: {self.dataset_name}")
        logger.info(
            f"Using {len(self.sample_analyzers)} sample analyzers: "
            f"{list(self.sample_analyzers.keys())}"
        )

        total_conversations = len(self.dataset)
        conversations_to_analyze = min(
            total_conversations, self.config.sample_count or total_conversations
        )
        logger.info(f"Analyzing {conversations_to_analyze} conversations")

        self._compute_conversation_metrics()

        # Generate and store the analysis summary after metrics are computed
        self._analysis_summary = self._generate_analysis_summary()
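
    # Illustrative usage sketch (not part of the original module): run the
    # analysis once, then read results through the accessors defined below.
    #
    #   analyzer.analyze_dataset()
    #   summary = analyzer.analysis_summary      # aggregated statistics
    #   df = analyzer.analysis_df                # merged per-message DataFrame
    #   subset = analyzer.query("role == 'assistant'")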

    @property
    def analysis_results(self) -> Optional[DatasetAnalysisResult]:
        """Get the analysis results if available.

        Returns:
            DatasetAnalysisResult if analysis has been run, None otherwise
        """
        return self._analysis_results

    def _compute_conversation_metrics(self) -> None:
        """Compute metrics for all conversations in the dataset.

        This method processes each conversation and creates DataFrames with
        prefixed columns for each analyzer's metrics.
        """
        total_conversations = len(self.dataset)

        # Apply conversation limit if specified
        max_conversations = self.config.sample_count
        if max_conversations is not None:
            # AnalyzeConfig ensures sample_count is greater than 0
            conversations_to_analyze = min(total_conversations, max_conversations)
            logger.info(
                f"Limiting analysis to first {max_conversations} "
                f"conversations (dataset has {total_conversations} total)"
            )
        else:
            conversations_to_analyze = total_conversations

        logger.info(
            "Analyzing %d conversations for both message-level and "
            "conversation-level metrics",
            conversations_to_analyze,
        )

        # Collect DataFrames for messages and conversations
        message_dfs = []
        conversation_dfs = []

        # Use tqdm for progress monitoring
        for conv_idx in tqdm(
            range(conversations_to_analyze),
            desc=f"Analyzing conversations in {self.dataset_name}",
            unit="conv",
        ):
            conversation = self.dataset.conversation(conv_idx)
            conversation_id = conversation.conversation_id or f"conv_{conv_idx}"

            # Process each analyzer for this conversation
            conversation_has_data = False
            for analyzer_id, analyzer in self.sample_analyzers.items():
                try:
                    message_results, conversation_result = analyzer.analyze_sample(
                        conversation, self.tokenizer
                    )

                    # Convert to DataFrames with prefixed columns
                    message_df = self._convert_messages_to_df(
                        message_results, analyzer_id, conversation_id, conv_idx
                    )
                    conversation_df = self._convert_conversation_to_df(
                        conversation_result,
                        analyzer_id,
                        conversation_id,
                        conv_idx,
                    )

                    # Always add conversation_df (even if empty) to ensure
                    # conversation is represented
                    conversation_dfs.append(conversation_df)

                    # Only add message_df if it has data
                    if not message_df.empty:
                        message_dfs.append(message_df)
                        conversation_has_data = True
                except Exception as e:
                    logger.warning(
                        f"Analyzer {analyzer_id} failed for conversation "
                        f"{conv_idx}: {e}"
                    )

            # If no analyzers succeeded, add a placeholder row for this conversation
            if not conversation_has_data:
                # Create a placeholder row with only basic columns
                # (no analyzer columns)
                placeholder_row = {
                    "conversation_id": conversation_id,
                    "conversation_index": conv_idx,
                    "message_index": 0,
                    # Add required message columns
                    "role": "system",  # Default role
                    "message_id": f"placeholder_{conv_idx}_0",
                    "text_content": "",  # Empty content
                }
                placeholder_df = pd.DataFrame([placeholder_row])
                message_dfs.append(placeholder_df)  # Add to message_dfs instead

        # Create final DataFrames
        if message_dfs:
            self._message_df = pd.concat(message_dfs, ignore_index=True)
        else:
            self._message_df = pd.DataFrame()

        if conversation_dfs:
            self._conversation_df = pd.concat(conversation_dfs, ignore_index=True)
        else:
            self._conversation_df = pd.DataFrame()

        # Create merged DataFrame with both message and conversation metrics
        if not self._message_df.empty and not self._conversation_df.empty:
            self._merged_df = self._message_df.merge(
                self._conversation_df,
                on=["conversation_id", "conversation_index"],
                how="left",
            )
        elif not self._message_df.empty:
            self._merged_df = self._message_df.copy()
        elif not self._conversation_df.empty:
            self._merged_df = self._conversation_df.copy()
        else:
            self._merged_df = pd.DataFrame()

        # Store metadata
        self._analysis_results = DatasetAnalysisResult(
            dataset_name=self.dataset_name or "",
            total_conversations=total_conversations,
            conversations_analyzed=conversations_to_analyze,
        )

    def _convert_messages_to_df(
        self,
        messages: list[MessageAnalysisResult],
        analyzer_id: str,
        conversation_id: str,
        conversation_index: int,
    ) -> pd.DataFrame:
        """Convert message results to DataFrame with prefixed columns."""
        if not messages:
            return pd.DataFrame()

        rows = []
        for message in messages:
            row = {
                "conversation_id": conversation_id,
                "conversation_index": conversation_index,
                "message_index": message.message_index,
                "role": message.role,
                "message_id": message.message_id,
                "text_content": message.text_content,
            }
            # Add analyzer metrics with message_ prefix
            for key, value in message.analyzer_metrics.items():
                row[f"message_{analyzer_id}_{key}"] = value
            rows.append(row)

        return pd.DataFrame(rows)

    def _convert_conversation_to_df(
        self,
        conversation: ConversationAnalysisResult,
        analyzer_id: str,
        conversation_id: str,
        conversation_index: int,
    ) -> pd.DataFrame:
        """Convert conversation result to DataFrame with prefixed columns."""
        row = {
            "conversation_id": conversation_id,
            "conversation_index": conversation_index,
        }
        # Add analyzer metrics with conversation_ prefix
        for key, value in conversation.analyzer_metrics.items():
            row[f"conversation_{analyzer_id}_{key}"] = value

        return pd.DataFrame([row])
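
    # Column naming produced by the two converters above (illustrative; the
    # analyzer id "length" and metric "token_count" are hypothetical):
    #
    #   message-level rows:      message_length_token_count
    #   conversation-level rows: conversation_length_token_count
    #
    # plus the shared identifier columns conversation_id and conversation_index,
    # and (for messages) message_index, role, message_id, text_content.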

    def query(self, query_expression: str) -> pd.DataFrame:
        """Query the analysis results using pandas query syntax.

        Args:
            query_expression: Pandas query expression (e.g., "char_count > 10")

        Returns:
            DataFrame containing rows that match the query expression

        Raises:
            RuntimeError: If analysis has not been run yet.
        """
        # Check if analysis has been run
        if self._merged_df is None:
            raise RuntimeError(
                "Analysis has not been run yet. Please call analyze_dataset() first "
                "to query the analysis results."
            )

        # Apply the query filter
        try:
            filtered_df = self._merged_df.query(query_expression)
            logger.info(f"Query '{query_expression}' returned {len(filtered_df)} rows")
        except Exception as e:
            logger.error(f"Query failed: {e}")
            raise ValueError(f"Invalid query expression: {query_expression}") from e

        return filtered_df
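
    # Illustrative query sketch (not part of the original module); the metric
    # column name assumes a hypothetical "length" analyzer:
    #
    #   long_user_msgs = analyzer.query(
    #       "role == 'user' and message_length_token_count > 256"
    #   )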

    @property
    def analysis_df(self) -> Union[pd.DataFrame, None]:
        """Get the merged analysis DataFrame with both message and conversation metrics.

        Returns:
            DataFrame with columns prefixed by message_ and conversation_ for
            each analyzer

        Raises:
            RuntimeError: If analysis has not been run yet.
        """
        if self._merged_df is None:
            raise RuntimeError(
                "Analysis has not been run yet. Please call analyze_dataset() first "
                "to access the analysis DataFrame."
            )
        return self._merged_df

    @property
    def message_df(self) -> Union[pd.DataFrame, None]:
        """Get the message-level analysis DataFrame.

        Returns:
            DataFrame with message-level metrics prefixed by message_

        Raises:
            RuntimeError: If analysis has not been run yet.
        """
        if self._message_df is None:
            raise RuntimeError(
                "Analysis has not been run yet. Please call analyze_dataset() first "
                "to access the message DataFrame."
            )
        return self._message_df

    @property
    def conversation_df(self) -> Union[pd.DataFrame, None]:
        """Get the conversation-level analysis DataFrame.

        Returns:
            DataFrame with conversation-level metrics prefixed by conversation_

        Raises:
            RuntimeError: If analysis has not been run yet.
        """
        if self._conversation_df is None:
            raise RuntimeError(
                "Analysis has not been run yet. Please call analyze_dataset() first "
                "to access the conversation DataFrame."
            )
        return self._conversation_df
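
    # Quick orientation for the three DataFrames above (illustrative note, not
    # part of the original module):
    #
    #   analyzer.message_df        # one row per message, message_* metric columns
    #   analyzer.conversation_df   # one row per conversation, conversation_* columns
    #   analyzer.analysis_df       # message rows left-joined with their
    #                              # conversation-level metrics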

    def query_conversations(
        self,
        query_expression: str,
    ) -> pd.DataFrame:
        """Query conversation-level analysis results using pandas query expression.

        Args:
            query_expression: Pandas query expression to filter conversation
                analysis results

        Returns:
            DataFrame with filtered conversation analysis results

        Raises:
            RuntimeError: If analysis has not been run yet.

        Examples:
            # Filter for long conversations
            long_conversations = analyzer.query_conversations(
                "length_token_count > 1000"
            )
        """
        # Check if analysis has been run
        if self._conversation_df is None:
            raise RuntimeError(
                "Analysis has not been run yet. Please call analyze_dataset() first "
                "to query conversation results."
            )

        # Apply the query filter
        try:
            filtered_df = self._conversation_df.query(query_expression)
            logger.info(f"Query '{query_expression}' returned {len(filtered_df)} rows")
        except Exception as e:
            logger.error(f"Query failed: {e}")
            raise ValueError(
                f"Invalid query expression '{query_expression}': {e}"
            ) from e

        return filtered_df
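
    # Illustrative note (not part of the original module): query_conversations()
    # filters the conversation-level DataFrame, whose metric columns carry the
    # conversation_{analyzer_id}_{metric} prefix, e.g. with a hypothetical
    # "length" analyzer:
    #
    #   long_convs = analyzer.query_conversations(
    #       "conversation_length_token_count > 1000"
    #   )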

    def filter(
        self,
        query_expression: str,
    ) -> BaseMapDataset:
        """Filter the original dataset based on analysis results.

        This method uses analysis results to filter the original dataset,
        returning a new dataset object containing only the conversations that
        match the query.

        Args:
            query_expression: Pandas query expression to filter analysis results

        Returns:
            A new dataset object containing only the filtered conversations

        Raises:
            RuntimeError: If analysis has not been run yet.

        Examples:
            # Filter for conversations with short messages
            short_dataset = analyzer.filter("length_word_count < 10")

            # Filter for conversations with assistant messages
            assistant_dataset = analyzer.filter("role == 'assistant'")

            # Filter for conversations with long user messages
            long_user_dataset = analyzer.filter(
                "role == 'user' and length_word_count > 100"
            )
        """
        # Get filtered analysis results
        filtered_df = self.query(query_expression)

        # Get unique conversation indices from filtered results
        conversation_indices = filtered_df.conversation_index.unique().tolist()

        # Create a new dataset with only the filtered conversations
        filtered_dataset = self._create_filtered_dataset(conversation_indices)

        logger.info(
            f"Filtered dataset: {len(conversation_indices)} conversations "
            f"out of {len(self.dataset)} total"
        )

        return filtered_dataset
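
    # Illustrative sketch (not part of the original module): filter() runs
    # query() on the merged DataFrame and keeps every conversation with at
    # least one matching row, returning a dataset of the same type as the
    # original. The metric column name below is hypothetical.
    #
    #   short_dataset = analyzer.filter("message_length_word_count < 10")
    #   print(len(short_dataset))   # number of conversations kept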

    def _create_filtered_dataset(
        self, conversation_indices: list[int]
    ) -> BaseMapDataset:
        """Create a new dataset containing only the specified conversations.

        Args:
            conversation_indices: List of conversation indices to include

        Returns:
            A new dataset object with the same format as the original
        """
        # Deep copy the original dataset to preserve all attributes and methods
        filtered_dataset = copy.deepcopy(self.dataset)

        # Filter the DataFrame to only include the specified conversations
        original_df = self.dataset.data
        filtered_dataset._data = original_df.iloc[conversation_indices].copy()

        # Update the dataset name to indicate it's filtered
        filtered_dataset.dataset_name = f"{self.dataset.dataset_name}_filtered"

        return filtered_dataset

    def _generate_analysis_summary(self) -> dict[str, Any]:
        """Generate a comprehensive summary of dataset analysis results.

        This method aggregates metrics from all analyzers to provide insights
        useful for assessing datasets. It computes statistics like averages,
        standard deviations, min/max values, and efficiency metrics.

        Returns:
            Dictionary containing comprehensive dataset analysis summary with:
            - Dataset overview statistics
            - Message-level aggregated metrics
            - Conversation-level aggregated metrics
        """
        # Check if we have data to analyze
        if self._merged_df is None or self._merged_df.empty:
            return {"error": "No analysis data available"}

        summary = {
            "dataset_overview": self._get_dataset_overview(),
            "message_level_summary": self._get_message_level_summary(),
            "conversation_level_summary": self._get_conversation_level_summary(),
        }

        return summary

    @property
    def analysis_summary(self) -> dict[str, Any]:
        """Get the comprehensive analysis summary.

        Returns:
            Dictionary containing comprehensive dataset analysis summary

        Raises:
            RuntimeError: If analysis has not been run yet.
        """
        if self._analysis_summary is None:
            raise RuntimeError(
                "Analysis has not been run yet. Please call analyze_dataset() first "
                "to generate the analysis summary."
            )
        return self._analysis_summary

    def _get_dataset_overview(self) -> dict[str, Any]:
        """Get basic dataset overview statistics."""
        if self._analysis_results is None:
            return {}

        return {
            "dataset_name": self._analysis_results.dataset_name,
            "total_conversations": self._analysis_results.total_conversations,
            "conversations_analyzed": self._analysis_results.conversations_analyzed,
            "dataset_coverage_percentage": round(
                100.0
                * self._analysis_results.conversations_analyzed
                / self._analysis_results.total_conversations
                if self._analysis_results.total_conversations > 0
                else 0,
                self._decimal_precision,
            ),
            "total_messages": len(self._message_df)
            if self._message_df is not None
            else 0,
            "analyzers_used": list(self.sample_analyzers.keys()),
        }

    def _get_message_level_summary(self) -> dict[str, Any]:
        """Get aggregated message-level metrics across all analyzers."""
        if self._message_df is None or self._message_df.empty:
            return {}

        # Get all message-level analyzer columns
        message_columns = [
            col for col in self._message_df.columns if col.startswith("message_")
        ]

        summary = {}
        for col in message_columns:
            if col in [
                "message_index",
                "role",
                "message_id",
                "text_content",
                "conversation_id",
                "conversation_index",
            ]:
                continue

            # Extract analyzer name and metric from column
            # Format: message_{analyzer}_{metric}
            parts = col.split("_", 2)
            if len(parts) >= 3:
                analyzer_name = parts[1]
                metric_name = "_".join(parts[2:])

                if analyzer_name not in summary:
                    summary[analyzer_name] = {}

                # Compute statistics for numeric columns
                if pd.api.types.is_numeric_dtype(self._message_df[col]):
                    values = self._message_df[col].dropna()
                    if len(values) > 0:
                        summary[analyzer_name][metric_name] = {
                            "count": len(values),
                            "mean": round(
                                float(values.mean()), self._decimal_precision
                            ),
                            "std": round(float(values.std()), self._decimal_precision),
                            "min": float(values.min()),
                            "max": float(values.max()),
                            "median": round(
                                float(values.median()), self._decimal_precision
                            ),
                        }

        return summary

    def _get_conversation_level_summary(self) -> dict[str, Any]:
        """Get aggregated conversation-level metrics across all analyzers."""
        if self._conversation_df is None or self._conversation_df.empty:
            return {}

        # Get all conversation-level analyzer columns
        conversation_columns = [
            col
            for col in self._conversation_df.columns
            if col.startswith("conversation_")
        ]

        summary = {}
        for col in conversation_columns:
            if col in ["conversation_id", "conversation_index"]:
                continue

            # Extract analyzer name and metric from column
            # Format: conversation_{analyzer}_{metric}
            parts = col.split("_", 2)
            if len(parts) >= 3:
                analyzer_name = parts[1]
                metric_name = "_".join(parts[2:])

                if analyzer_name not in summary:
                    summary[analyzer_name] = {}

                # Compute statistics for numeric columns
                if pd.api.types.is_numeric_dtype(self._conversation_df[col]):
                    values = self._conversation_df[col].dropna()
                    if len(values) > 0:
                        summary[analyzer_name][metric_name] = {
                            "count": len(values),
                            "mean": round(
                                float(values.mean()), self._decimal_precision
                            ),
                            "std": round(float(values.std()), self._decimal_precision),
                            "min": float(values.min()),
                            "max": float(values.max()),
                            "median": round(
                                float(values.median()), self._decimal_precision
                            ),
                        }

        # Add conversation turn statistics if available
        if self._message_df is not None and not self._message_df.empty:
            turns_per_conversation = self._message_df.groupby("conversation_id").size()
            # Handle pandas Series operations with proper type conversion
            mean_val = turns_per_conversation.mean()
            std_val = turns_per_conversation.std()
            min_val = turns_per_conversation.min()
            max_val = turns_per_conversation.max()
            median_val = turns_per_conversation.median()

            summary["conversation_turns"] = {
                "count": len(turns_per_conversation),
                "mean": round(float(mean_val), self._decimal_precision)  # type: ignore
                if mean_val is not None
                else 0.0,
                "std": round(float(std_val), self._decimal_precision)  # type: ignore
                if std_val is not None
                else 0.0,
                "min": int(min_val) if min_val is not None else 0,  # type: ignore
                "max": int(max_val) if max_val is not None else 0,  # type: ignore
                "median": round(float(median_val), self._decimal_precision)  # type: ignore
                if median_val is not None
                else 0.0,
            }

        return summary
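
# Illustrative shape of analyzer.analysis_summary (not part of the original
# module); the analyzer name "length", its metric names, and all values are
# hypothetical placeholders:
#
#   {
#       "dataset_overview": {
#           "dataset_name": "...",
#           "total_conversations": 1000,
#           "conversations_analyzed": 100,
#           "dataset_coverage_percentage": 10.0,
#           "total_messages": 450,
#           "analyzers_used": ["length"],
#       },
#       "message_level_summary": {
#           "length": {"token_count": {"count": 450, "mean": ..., "std": ...,
#                                      "min": ..., "max": ..., "median": ...}},
#       },
#       "conversation_level_summary": {
#           "length": {"token_count": {...}},
#           "conversation_turns": {"count": 100, "mean": ..., "std": ...,
#                                  "min": ..., "max": ..., "median": ...},
#       },
#   }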