# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Deterministic environment with fixed lookup responses."""
from __future__ import annotations
import dataclasses
import json
import random
from dataclasses import dataclass, field
from typing import Any
from oumi.core.configs.params.base_params import BaseParams
from oumi.core.configs.params.environment_params import EnvironmentParams
from oumi.core.configs.params.grounding_params import GroundingFact
from oumi.core.configs.params.tool_params import ToolLookupError, ToolParams
from oumi.core.registry import register_environment
from oumi.core.types.tool_call import ToolResult
from oumi.environments.base_environment import BaseEnvironment
from oumi.utils.logging import logger
[docs]
@dataclass
class ToolLookupEntry(BaseParams):
"""One (input, output) pair in a deterministic env's lookup table."""
input: dict[str, Any] = field(default_factory=dict)
output: dict[str, Any] = field(default_factory=dict)
[docs]
def input_key(self) -> str:
"""Canonical JSON form of ``input`` for matching and dedup."""
return json.dumps(self.input, sort_keys=True)
[docs]
def matches(self, arguments: dict[str, Any]) -> bool:
"""Check if the input matches the given arguments."""
return self.input_key() == json.dumps(arguments, sort_keys=True)
[docs]
@dataclass
class DeterministicEnvironmentKwargs(BaseParams):
"""Type-specific kwargs for DeterministicEnvironment."""
lookup_table: dict[str, list[ToolLookupEntry]] = field(default_factory=dict)
"""Per-tool list of (input, output) entries, keyed by tool id."""
[docs]
def __post_init__(self) -> None:
"""Coerce raw entry dicts into ``ToolLookupEntry`` instances."""
self.lookup_table = {
tool_id: [
entry
if isinstance(entry, ToolLookupEntry)
else ToolLookupEntry(**entry)
for entry in entries
]
for tool_id, entries in self.lookup_table.items()
}
[docs]
@register_environment("deterministic")
class DeterministicEnvironment(BaseEnvironment):
"""Environment that resolves tools from a per-tool lookup table.
The env's ``env_kwargs.lookup_table`` is the source of truth for tool
behavior. Tools listed in ``params.tools`` declare contracts only;
their data lives on the env.
"""
tool_params_cls = ToolParams
def __init__(
self,
params: EnvironmentParams,
kwargs: DeterministicEnvironmentKwargs,
) -> None:
"""Initialize a DeterministicEnvironment."""
self._params = params
self._kwargs = kwargs
self._tool_ids = {tool.id for tool in params.tools}
self._validate_lookup_table()
[docs]
def step(self, calls: list[tuple[str, dict[str, Any]]]) -> list[ToolResult]:
"""Resolve a batch of deterministic tool calls to their outputs.
Raises:
ValueError: If any ``tool_id`` is not declared in this env's tools list.
ToolLookupError: If no entry in the env's lookup table matches
the provided arguments for any call.
"""
return [self._resolve_one(tool_id, args) for tool_id, args in calls]
def _resolve_one(self, tool_id: str, arguments: dict[str, Any]) -> ToolResult:
if tool_id not in self._tool_ids:
raise ValueError(
f"Tool '{tool_id}' not found in environment '{self._params.id}'. "
f"Available tools: {sorted(self._tool_ids)}"
)
entries = self._kwargs.lookup_table.get(tool_id, [])
for entry in entries:
if entry.matches(arguments):
return ToolResult(output=entry.output)
available = [entry.input for entry in entries]
raise ToolLookupError(
f"No deterministic output matches arguments "
f"{json.dumps(arguments, sort_keys=True)} for tool '{tool_id}'. "
f"Configured inputs: {json.dumps(available, sort_keys=True)}"
)
[docs]
def sample_grounding(
self,
n: int,
*,
rng: random.Random,
tool_ids: set[str] | None = None,
) -> list[GroundingFact]:
"""Sample grounding facts from per-tool projected pools.
Walks every tool that has a per-tool entry in
``params.grounding.tools``. Each entry in that tool's lookup table
is projected via ``{**input, **output}`` filtered through the
configured ``fields`` whitelist. Tools without a grounding entry
contribute nothing.
"""
grounding = self._params.grounding
if grounding is None or not grounding.tools:
return []
pool: list[GroundingFact] = []
for tool in self._params.tools:
tool_grounding = grounding.tools.get(tool.id)
if tool_grounding is None:
continue
if tool_ids is not None and tool.id not in tool_ids:
continue
whitelist = set(tool_grounding.fields)
for entry in self._kwargs.lookup_table.get(tool.id, []):
row = {**entry.input, **entry.output}
projected = {
key: value for key, value in row.items() if key in whitelist
}
pool.append(GroundingFact(data=projected))
return rng.sample(pool, min(n, len(pool)))
[docs]
@classmethod
def from_params(cls, params: EnvironmentParams) -> DeterministicEnvironment:
"""Build a DeterministicEnvironment from its params object."""
raw_kwargs = params.env_kwargs or {}
known = {f.name for f in dataclasses.fields(DeterministicEnvironmentKwargs)}
unknown = set(raw_kwargs) - known
if unknown:
raise ValueError(
f"DeterministicEnvironment got unknown env_kwargs: "
f"{sorted(unknown)}. Known: {sorted(known)}"
)
kwargs = DeterministicEnvironmentKwargs(**raw_kwargs)
kwargs.finalize_and_validate()
return cls(params, kwargs)
def _validate_lookup_table(self) -> None:
"""Validate the env's lookup_table against its tool list.
- Stale ``lookup_table`` keys (no matching tool): log a warning;
entries are dormant.
- Tools without entries: hard error.
- Duplicate inputs within a tool's entries: hard error.
"""
for tool_id in self._kwargs.lookup_table:
if tool_id not in self._tool_ids:
logger.warning(
"Environment '%s': lookup_table.'%s' references unknown "
"tool. Entries will be ignored.",
self._params.id,
tool_id,
)
for tool in self._params.tools:
entries = self._kwargs.lookup_table.get(tool.id, [])
if not entries:
raise ValueError(
f"Tool '{tool.id}' has no entries in lookup_table for "
f"environment '{self._params.id}'."
)
seen: set[str] = set()
for entry in entries:
key = entry.input_key()
if key in seen:
raise ValueError(
f"Tool '{tool.id}' has duplicate input entry: {entry.input}"
)
seen.add(key)