# Source code for oumi.core.configs.params.tool_params

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tool definitions and execution results shared by all environment types."""

from __future__ import annotations

from collections.abc import Mapping, Sequence
from dataclasses import dataclass, field
from typing import Any

import jsonschema

from oumi.core.configs.params.base_params import BaseParams
from oumi.core.types.tool_call import (
    FunctionDefinition,
    JSONSchema,
    ToolDefinition,
)


class ToolError(Exception):
    """Root of the tool-error hierarchy whose messages are returned to the LLM.

    The tool-call loop catches instances of this class (and of any subclass)
    and re-emits them as structured ``tool`` messages, which gives the model a
    chance to correct itself on its next iteration.
    """
class ToolArgumentError(ToolError):
    """Signals that the arguments of a tool call did not pass schema validation."""
class ToolLookupError(ToolError):
    """Signals that a tool has no output for the arguments it was called with.

    At present this is raised by ``DeterministicEnvironment`` when none of its
    configured ``LookupEntry`` items match the provided arguments.
    """
@dataclass
class ToolParams(BaseParams):
    """Tool schema owned by an environment.

    ``parameters`` and ``output_schema`` are stored as plain JSON-Schema dicts so
    OmegaConf can carry them through YAML round-trips. They are converted to a
    Pydantic ``JSONSchema`` only at the wire-format boundary in
    :meth:`to_tool_definition`.
    """

    # Tool identifier; exported as the function ``name`` to the LLM.
    id: str
    # Human-readable label; exported as ``display_name`` by ``to_llm_schema``.
    name: str
    # Description shown to the LLM; must be non-empty.
    description: str
    # JSON Schema for call-time arguments; defaults to an unconstrained object.
    parameters: dict[str, Any] = field(default_factory=lambda: {"type": "object"})
    # Optional JSON Schema for the tool's output (omitted from the OpenAI format).
    output_schema: dict[str, Any] | None = None
    # Chain-internal flag; not interpreted by any method of this class.
    read_only: bool = True

    @staticmethod
    def _to_plain_json(value: Any) -> Any:
        """Rebuild ``value`` so every container is a fresh ``dict`` or ``list``.

        Mappings become ``dict`` and non-string sequences become ``list``,
        recursively; every other value is returned unchanged. Used by
        :meth:`create` so a stored schema never aliases (possibly nested)
        containers owned by the caller's config object.
        """
        if isinstance(value, Mapping):
            return {
                key: ToolParams._to_plain_json(item) for key, item in value.items()
            }
        if isinstance(value, Sequence) and not isinstance(
            value, (str, bytes, bytearray)
        ):
            return [ToolParams._to_plain_json(item) for item in value]
        return value

    @classmethod
    def create(cls, raw: Any) -> ToolParams:
        """Create a tool from raw config data.

        Args:
            raw: Either an existing ``ToolParams`` (returned unchanged) or a
                mapping with required ``id``, ``name`` and ``description`` keys
                and optional ``parameters``, ``output_schema`` and ``read_only``.

        Returns:
            A ``ToolParams`` instance.

        Raises:
            TypeError: If ``raw`` is neither a ``ToolParams`` nor a mapping.
            KeyError: If ``id``, ``name`` or ``description`` is missing.
            ValueError: If ``id``, ``name`` or ``description`` is empty.
        """
        if isinstance(raw, ToolParams):
            return raw
        if not isinstance(raw, Mapping):
            raise TypeError(
                f"Tool definitions must be tool objects or mappings, got {type(raw)}"
            )
        raw_parameters = raw.get("parameters")
        raw_output_schema = raw.get("output_schema")
        return cls(
            id=raw["id"],
            name=raw["name"],
            description=raw["description"],
            # ``dict()`` keeps the top-level coercion; the recursive rebuild
            # detaches nested containers from ``raw``. An explicit null (e.g.
            # ``parameters:`` with no value in YAML) falls back to the default,
            # matching how ``output_schema`` treats null.
            parameters=(
                cls._to_plain_json(dict(raw_parameters))
                if raw_parameters is not None
                else {"type": "object"}
            ),
            output_schema=(
                cls._to_plain_json(dict(raw_output_schema))
                if raw_output_schema is not None
                else None
            ),
            read_only=raw.get("read_only", True),
        )

    def __post_init__(self):
        """Validate common tool fields and normalize schema types.

        Accepts ``JSONSchema`` instances on ``parameters`` / ``output_schema``
        for callers that build a tool with Pydantic types directly, and dumps
        them to dicts so the canonical in-memory shape stays JSON-Schema-shaped.

        Raises:
            ValueError: If ``id``, ``name`` or ``description`` is empty.
        """
        # Checked in this order so the first empty field is the one reported.
        for attr in ("id", "name", "description"):
            if not getattr(self, attr):
                raise ValueError(f"{type(self).__name__}.{attr} cannot be empty.")
        if isinstance(self.parameters, JSONSchema):
            self.parameters = self.parameters.model_dump(mode="json", exclude_none=True)
        if isinstance(self.output_schema, JSONSchema):
            self.output_schema = self.output_schema.model_dump(
                mode="json", exclude_none=True
            )

    def to_llm_schema(self) -> dict[str, Any]:
        """Export a provider-agnostic schema for LLM tool registration.

        The ``parameters`` / ``output_schema`` values in the returned dict are
        the tool's own objects (not copies); mutating them mutates the tool.
        ``output_schema`` is included only when it is set.
        """
        schema: dict[str, Any] = {
            "name": self.id,
            "display_name": self.name,
            "description": self.description,
            "parameters": self.parameters,
        }
        if self.output_schema is not None:
            schema["output_schema"] = self.output_schema
        return schema

    def to_tool_definition(self) -> ToolDefinition:
        """Project to OpenAI-wire-format ``ToolDefinition``.

        Drops chain-internal fields (``output_schema``, ``read_only``, ``name``
        display label) that have no slot in the OpenAI contract. Coerces
        ``parameters`` to ``JSONSchema`` at the boundary.
        """
        return ToolDefinition(
            function=FunctionDefinition(
                name=self.id,
                description=self.description,
                parameters=JSONSchema.model_validate(self.parameters),
            ),
        )

    def validate_arguments(self, arguments: dict[str, Any]) -> None:
        """Validate call-time arguments against this tool's ``parameters`` schema.

        Args:
            arguments: The decoded arguments of a tool call.

        Raises:
            ToolArgumentError: If ``arguments`` do not conform; the message is
                ``str()`` of the underlying ``jsonschema.ValidationError``.
            jsonschema.SchemaError: If ``parameters`` is not itself a valid
                schema (``jsonschema.validate`` checks it); not caught here.
        """
        try:
            jsonschema.validate(arguments, self.parameters)
        except jsonschema.ValidationError as e:
            raise ToolArgumentError(str(e)) from e