Source code for oumi.analyze.analyzers.length

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Length analyzer implementation and result model."""

from typing import Protocol, runtime_checkable

import tiktoken
from pydantic import BaseModel, Field

from oumi.analyze.base import ConversationAnalyzer
from oumi.core.registry import register_sample_analyzer
from oumi.core.types.conversation import Conversation, Role

__all__ = ["LengthMetrics", "LengthAnalyzer", "Tokenizer", "default_tokenizer"]


[docs] @runtime_checkable class Tokenizer(Protocol): """Protocol for tokenizers used by LengthAnalyzer."""
[docs] def encode(self, text: str) -> list[int]: """Encode text to token IDs.""" ...
[docs] def default_tokenizer(encoding: str = "cl100k_base") -> tiktoken.Encoding: """Get the default tiktoken tokenizer. Args: encoding: Tiktoken encoding name. Defaults to "cl100k_base" (GPT-4). Returns: Tiktoken encoder instance. """ return tiktoken.get_encoding(encoding)
[docs] class LengthMetrics(BaseModel): """Result model for length analysis of conversations. Example: >>> result = LengthMetrics( ... total_tokens=25, ... avg_tokens_per_message=12.5, ... message_token_counts=[10, 15], ... num_messages=2, ... ) >>> print(result.total_tokens) 25 """ total_tokens: int = Field(description="Total number of tokens across all messages") rendered_tokens: int | None = Field( default=None, description="Token count of the full conversation rendered with chat template. " "None if tokenizer doesn't support apply_chat_template.", ) avg_tokens_per_message: float = Field(description="Average tokens per message") message_token_counts: list[int] = Field( description="Token count for each message in order" ) num_messages: int = Field(description="Number of messages in the conversation") user_total_tokens: int = Field( default=0, description="Total tokens in user messages" ) assistant_total_tokens: int = Field( default=0, description="Total tokens in assistant messages" ) system_total_tokens: int = Field( default=0, description="Total tokens in system messages" ) tool_total_tokens: int = Field( default=0, description="Total tokens in tool messages" )
[docs] @register_sample_analyzer("length") class LengthAnalyzer(ConversationAnalyzer[LengthMetrics]): """Analyzer for computing token length metrics of conversations. Computes token counts for conversations using a provided tokenizer. Provides both conversation-level totals and per-message breakdowns. Example: >>> from oumi.analyze.analyzers.length import LengthAnalyzer, default_tokenizer >>> from oumi.core.types.conversation import Conversation, Message, Role >>> >>> analyzer = LengthAnalyzer(tokenizer=default_tokenizer()) >>> conversation = Conversation(messages=[ ... Message(role=Role.USER, content="Hello, how are you?"), ... Message(role=Role.ASSISTANT, content="I'm doing well, thanks!"), ... ]) >>> result = analyzer.analyze(conversation) >>> print(f"Total tokens: {result.total_tokens}") Total tokens: 12 Args: tokenizer: Tokenizer instance for token counting. Must have an `encode(text) -> list` method. Use `default_tokenizer()` for tiktoken, or pass a HuggingFace tokenizer for model-specific counts. """ def __init__(self, tokenizer: Tokenizer | None = None): """Initialize the analyzer.""" self.tokenizer = tokenizer def _count_tokens(self, text: str) -> int: if self.tokenizer is None: raise RuntimeError( "No tokenizer configured. Either pass a tokenizer to __init__ " "or use default_tokenizer()." ) tokens = self.tokenizer.encode(text) return len(tokens) def _count_rendered_tokens(self, conversation: Conversation) -> int | None: """Count tokens in the chat-template-rendered conversation. This gives the actual token count the model sees during training/inference, including special tokens added by the chat template. Args: conversation: The conversation to render and tokenize. Returns: Token count of rendered conversation, or None if tokenizer doesn't support chat templates. """ if self.tokenizer is None: return None # Check if tokenizer has a chat template before proceeding if getattr(self.tokenizer, "chat_template", None) is None: return None if not conversation.messages: return 0 try: # Use base class method to render conversation with chat template # Type ignore: we've verified tokenizer has chat_template attribute above rendered_text = self.get_conversation_text(conversation, self.tokenizer) # type: ignore[arg-type] return self._count_tokens(rendered_text) except (ValueError, AttributeError): # Unexpected error during rendering return None
[docs] def analyze(self, conversation: Conversation) -> LengthMetrics: """Analyze token length metrics for a conversation. Args: conversation: The conversation to analyze. Returns: LengthMetrics containing token counts. """ message_token_counts: list[int] = [] role_token_counts: dict[Role, int] = {role: 0 for role in Role} for message in conversation.messages: text = self.get_text_content(message) token_count = self._count_tokens(text) message_token_counts.append(token_count) if message.role in role_token_counts: role_token_counts[message.role] += token_count total_tokens = sum(message_token_counts) num_messages = len(conversation.messages) avg_tokens = total_tokens / num_messages if num_messages > 0 else 0.0 rendered_tokens = self._count_rendered_tokens(conversation) return LengthMetrics( total_tokens=total_tokens, rendered_tokens=rendered_tokens, avg_tokens_per_message=avg_tokens, message_token_counts=message_token_counts, num_messages=num_messages, user_total_tokens=role_token_counts[Role.USER], assistant_total_tokens=role_token_counts[Role.ASSISTANT], system_total_tokens=role_token_counts[Role.SYSTEM], tool_total_tokens=role_token_counts[Role.TOOL], )
[docs] def analyze_text(self, text: str) -> LengthMetrics: """Analyze token length metrics for a single text string. Convenience method for analyzing text without creating a Conversation. Args: text: The text to analyze. Returns: LengthMetrics for the text (treated as a single message). """ token_count = self._count_tokens(text) return LengthMetrics( total_tokens=token_count, avg_tokens_per_message=float(token_count), message_token_counts=[token_count], num_messages=1, )