Source code for oumi.core.synthesis.tool_router

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Routing layer between LLM tool calls and environment-owned tools."""

from __future__ import annotations

import json
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any

from oumi.builders.environments import build_environment
from oumi.core.configs.environment_config import EnvironmentConfig
from oumi.core.configs.params.environment_params import EnvironmentParams
from oumi.core.configs.params.tool_params import (
    ToolArgumentError,
    ToolLookupError,
    ToolParams,
)
from oumi.core.types.tool_call import ToolDefinition, ToolResult
from oumi.environments.base_environment import BaseEnvironment


@dataclass
class ToolRouter:
    """Routes LLM tool calls to environment-owned tools.

    Use :meth:`for_sample` to obtain a per-sample clone whose envs have
    independent state; the synthesizer builds one clone per sample at batch
    entry so tool mutations don't leak across samples.

    Attributes:
        tool_specs: Tool definitions produced by ``to_tool_definition()`` for
            every tool in the source config.
        tools_by_id: Tool params keyed by tool id; used for argument validation.
        env_by_id: Built environment instances keyed by environment id.
        tool_to_env: Environment instance that owns each tool, keyed by tool id.
        env_params_by_id: Params used to build each environment; reused by
            :meth:`for_sample` to rebuild isolated environments.
        tool_env_map: Tool id -> environment id mapping from the config.
        on_env_built: Optional hook invoked with every freshly built
            environment.
    """

    tool_specs: list[ToolDefinition]
    tools_by_id: dict[str, ToolParams]
    env_by_id: dict[str, BaseEnvironment]
    tool_to_env: dict[str, BaseEnvironment]
    env_params_by_id: dict[str, EnvironmentParams]
    tool_env_map: dict[str, str]
    on_env_built: Callable[[BaseEnvironment], None] | None = None

    @classmethod
    def from_environment_config(
        cls,
        env_config: EnvironmentConfig,
        on_env_built: Callable[[BaseEnvironment], None] | None = None,
    ) -> ToolRouter:
        """Build a router from an env config.

        Only environments that own at least one tool (per
        ``env_config.tool_environment_map``) or that declare ``grounding`` are
        built; all others are skipped.

        Args:
            env_config: Environment configuration listing environments, tools,
                and the tool -> environment mapping.
            on_env_built: Optional hook called once with each environment right
                after it is built. It is stored on the router so that
                :meth:`for_sample` can re-run it on rebuilt environments.

        Returns:
            A router holding the built environments and tool lookup tables.
        """
        tool_env_map = env_config.tool_environment_map
        included_env_ids = set(tool_env_map.values()) | {
            env_params.id
            for env_params in env_config.environments
            if env_params.grounding is not None
        }

        env_params_by_id: dict[str, EnvironmentParams] = {}
        env_by_id: dict[str, BaseEnvironment] = {}
        for env_params in env_config.environments:
            if env_params.id not in included_env_ids:
                continue
            env_params_by_id[env_params.id] = env_params
            env = build_environment(env_params)
            if on_env_built is not None:
                on_env_built(env)
            env_by_id[env_params.id] = env

        tools_by_id = {tool.id: tool for tool in env_config.all_tools}
        tool_to_env = {
            tool_id: env_by_id[env_id] for tool_id, env_id in tool_env_map.items()
        }
        tool_specs = [tool.to_tool_definition() for tool in env_config.all_tools]

        return cls(
            tool_specs=tool_specs,
            tools_by_id=tools_by_id,
            env_by_id=env_by_id,
            tool_to_env=tool_to_env,
            env_params_by_id=env_params_by_id,
            tool_env_map=tool_env_map,
            on_env_built=on_env_built,
        )

    def for_sample(self) -> ToolRouter:
        """Return a router safe to use for one sample.

        Envs whose ``requires_isolation()`` returns ``True`` are rebuilt via
        ``build_environment`` so their mutable state stays independent across
        samples; ``on_env_built`` re-runs on those fresh instances. Envs that
        don't require isolation (e.g. ``DeterministicEnvironment`` and stateless
        ``SyntheticEnvironment``) are shared with the parent router to avoid the
        per-sample build + inference-engine attach cost.

        Returns:
            A new router. The tool tables (``tool_specs``, ``tools_by_id``,
            ``env_params_by_id``, ``tool_env_map``) are shared with this router;
            only the environment mappings are rebuilt.
        """
        env_by_id_new: dict[str, BaseEnvironment] = {}
        for env_id, parent_env in self.env_by_id.items():
            if not parent_env.requires_isolation():
                env_by_id_new[env_id] = parent_env
                continue
            fresh = build_environment(self.env_params_by_id[env_id])
            if self.on_env_built is not None:
                self.on_env_built(fresh)
            env_by_id_new[env_id] = fresh

        return ToolRouter(
            tool_specs=self.tool_specs,
            tools_by_id=self.tools_by_id,
            env_by_id=env_by_id_new,
            # Re-point every tool at the (possibly rebuilt) env for this sample.
            tool_to_env={
                tool_id: env_by_id_new[env_id]
                for tool_id, env_id in self.tool_env_map.items()
            },
            env_params_by_id=self.env_params_by_id,
            tool_env_map=self.tool_env_map,
            on_env_built=self.on_env_built,
        )

    def parse_and_validate_arguments(
        self, tool_id: str, raw_arguments: str
    ) -> dict[str, Any]:
        """Parse wire JSON args and validate against the tool's parameters schema.

        Args:
            tool_id: Id of the tool the arguments are intended for.
            raw_arguments: JSON-encoded arguments. An empty string (or other
                falsy value) is treated as ``"{}"``.

        Returns:
            The parsed arguments as a dict.

        Raises:
            ToolLookupError: If ``tool_id`` is not a known tool.
            ToolArgumentError: If the payload is not valid JSON or does not
                decode to a JSON object. ``tool.validate_arguments`` may raise
                its own errors for schema violations.
        """
        if tool_id not in self.tools_by_id:
            raise ToolLookupError(
                f"Unknown tool '{tool_id}'. Known: {sorted(self.tools_by_id)}"
            )
        tool = self.tools_by_id[tool_id]

        try:
            parsed = json.loads(raw_arguments or "{}")
        except json.JSONDecodeError as e:
            raise ToolArgumentError(
                f"Tool '{tool_id}' arguments are not valid JSON: {e}"
            ) from e

        if not isinstance(parsed, dict):
            raise ToolArgumentError(
                f"Tool '{tool_id}' arguments must be a JSON object, got "
                f"{type(parsed).__name__}."
            )

        tool.validate_arguments(parsed)
        return parsed

    def route_batch(self, calls: list[tuple[str, dict[str, Any]]]) -> list[ToolResult]:
        """Dispatch a batch of (tool_id, args) pairs; preserves call order.

        Calls are grouped by owning environment instance so each environment
        receives a single ``step`` invocation containing all of its calls, in
        their original relative order.

        Args:
            calls: ``(tool_id, args)`` pairs to execute.

        Returns:
            One result per call, positioned to match ``calls``.

        Raises:
            ToolLookupError: If any ``tool_id`` is unknown. All ids are checked
                before any environment is stepped.
            RuntimeError: If an environment returns a different number of
                results than the number of calls it was given.
        """
        if not calls:
            return []

        # Keyed by id(env): environments are grouped by identity, so no
        # assumption is made about them being hashable.
        groups: dict[
            int, tuple[BaseEnvironment, list[tuple[int, str, dict[str, Any]]]]
        ] = {}
        for idx, (tool_id, args) in enumerate(calls):
            if tool_id not in self.tool_to_env:
                raise ToolLookupError(
                    f"Unknown tool '{tool_id}'. Known: {sorted(self.tool_to_env)}"
                )
            env = self.tool_to_env[tool_id]
            groups.setdefault(id(env), (env, []))[1].append((idx, tool_id, args))

        results: list[Any] = [None] * len(calls)
        for env, group in groups.values():
            outputs = list(env.step([(tid, args) for _, tid, args in group]))
            # Validate explicitly instead of relying on ``assert`` (stripped
            # under ``python -O``) or on ``zip`` silently truncating.
            if len(outputs) != len(group):
                raise RuntimeError(
                    f"Environment {type(env).__name__} returned {len(outputs)} "
                    f"result(s) for {len(group)} tool call(s)."
                )
            for (idx, _, _), out in zip(group, outputs):
                results[idx] = out

        # Every index belongs to exactly one group and each group was fully
        # populated above, so no placeholder remains.
        return results