# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Synthetic environment backed by LLM-simulated tool execution."""
from __future__ import annotations
import copy
import dataclasses
import json
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
import jsonschema
from oumi.core.configs.inference_config import InferenceConfig
from oumi.core.configs.params.base_params import BaseParams
from oumi.core.configs.params.environment_params import EnvironmentParams
from oumi.core.configs.params.guided_decoding_params import GuidedDecodingParams
from oumi.core.configs.params.tool_params import ToolError, ToolParams
from oumi.core.registry import register_environment
from oumi.core.types.conversation import Conversation, Message, Role
from oumi.core.types.tool_call import ToolResult
from oumi.environments.base_environment import BaseEnvironment
from oumi.utils.str_utils import extract_json
if TYPE_CHECKING:
from oumi.core.inference.base_inference_engine import BaseInferenceEngine
@dataclass
class SyntheticStateParams(BaseParams):
"""Optional state configuration for a synthetic environment."""
state_schema: dict[str, Any] | None = None
initial_state: dict[str, Any] | None = None
def __post_init__(self):
"""Validate state config consistency."""
if self.state_schema is not None and self.initial_state is not None:
jsonschema.validate(self.initial_state, self.state_schema)
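# Illustrative schema/state pair (hypothetical fields) that passes the
# __post_init__ check above; a non-conforming initial_state raises
# jsonschema.ValidationError at construction time:
#
#   SyntheticStateParams(
#       state_schema={
#           "type": "object",
#           "properties": {"balance": {"type": "number"}},
#           "required": ["balance"],
#       },
#       initial_state={"balance": 100.0},
#   )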
@dataclass
class SyntheticEnvironmentKwargs(BaseParams):
"""Type-specific kwargs for SyntheticEnvironment."""
system_prompt: str = ""
state_params: SyntheticStateParams | None = None
cache_by_input: bool = True
def __post_init__(self) -> None:
"""Coerce state_params dict into SyntheticStateParams if needed."""
if isinstance(self.state_params, dict):
self.state_params = SyntheticStateParams(**self.state_params)
def __finalize_and_validate__(self) -> None:
"""Finalize and validate the kwargs."""
if not self.system_prompt:
raise ValueError(
"SyntheticEnvironmentKwargs.system_prompt cannot be empty."
)
if self.state_params is not None and self.cache_by_input:
raise ValueError(
"SyntheticEnvironmentKwargs.cache_by_input must be False when "
"state_params is provided."
)
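# Illustrative construction (values are hypothetical). Note the validation
# above: cache_by_input=True is rejected alongside state_params, presumably
# because cached outputs would go stale once the simulated state mutates.
#
#   SyntheticEnvironmentKwargs(
#       system_prompt="You simulate a flaky REST backend.",
#       state_params=SyntheticStateParams(initial_state={"uptime": 1.0}),
#       cache_by_input=False,
#   )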
@register_environment("synthetic")
class SyntheticEnvironment(BaseEnvironment):
"""LLM-simulated environment with optional mutable state."""
def __init__(
self,
params: EnvironmentParams,
kwargs: SyntheticEnvironmentKwargs,
) -> None:
"""Initialize a SyntheticEnvironment with the given params and kwargs."""
self._params = params
self._kwargs = kwargs
self._cache: dict[str, ToolResult] = {}
self._state: dict[str, Any] | None = (
copy.deepcopy(kwargs.state_params.initial_state)
if kwargs.state_params is not None
and kwargs.state_params.initial_state is not None
else None
)
self._engine: BaseInferenceEngine | None = None
self._base_inference_config: InferenceConfig | None = None
def attach_inference(
self,
engine: BaseInferenceEngine,
base_config: InferenceConfig,
) -> None:
"""Inject the orchestrator's inference engine + base config."""
self._engine = engine
self._base_inference_config = base_config
@classmethod
def from_params(cls, params: EnvironmentParams) -> SyntheticEnvironment:
"""Build a SyntheticEnvironment from its params object."""
kwargs = SyntheticEnvironmentKwargs(**(params.env_kwargs or {}))
kwargs.finalize_and_validate()
return cls(params, kwargs)
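    # Sketch of typical wiring; the EnvironmentParams field values below are
    # assumptions for illustration, not verified config:
    #
    #   params = EnvironmentParams(
    #       id="mock_search",
    #       env_kwargs={"system_prompt": "You simulate a search backend."},
    #       tools=[...],
    #   )
    #   env = SyntheticEnvironment.from_params(params)
    #   env.attach_inference(engine, base_config)  # required before step()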
@property
def current_state(self) -> dict[str, Any] | None:
"""Return the current in-memory state snapshot."""
if self._state is None:
return None
return copy.deepcopy(self._state)
@staticmethod
def _cache_key(tool_id: str, arguments: dict[str, Any]) -> str:
"""Build a stable cache key from tool id and arguments."""
return f"{tool_id}::{json.dumps(arguments, sort_keys=True)}"
def _resolve_cached(
self, tool_id: str, arguments: dict[str, Any]
) -> ToolResult | None:
"""Look up a cached result for the given tool call."""
if not self._kwargs.cache_by_input:
return None
result = self._cache.get(self._cache_key(tool_id, arguments))
if result is None:
return None
return ToolResult(
output=copy.deepcopy(result.output),
updated_state=copy.deepcopy(result.updated_state),
)
def _cache_result(
self, tool_id: str, arguments: dict[str, Any], result: ToolResult
) -> None:
"""Store a generated result in the cache."""
if not self._kwargs.cache_by_input:
return
self._cache[self._cache_key(tool_id, arguments)] = ToolResult(
output=copy.deepcopy(result.output),
updated_state=copy.deepcopy(result.updated_state),
)
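    # Both cache paths deep-copy results on the way in and out, so a caller
    # mutating a returned ToolResult cannot corrupt the cached entry.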
def _lookup_tool(self, tool_id: str) -> ToolParams:
for tool in self._params.tools:
if tool.id == tool_id:
return tool
raise ValueError(
f"Tool '{tool_id}' not found in environment '{self._params.id}'. "
f"Available tools: {[tool.id for tool in self._params.tools]}"
)
def step(self, calls: list[tuple[str, dict[str, Any]]]) -> list[ToolResult]:
"""Execute synthetic tool calls. Cache-misses batched per tool_id.
Raises:
RuntimeError: If ``attach_inference`` was not called.
ValueError: If any tool id is unknown.
ToolError: On simulator parse failure or output_schema mismatch.
"""
if not calls:
return []
for tool_id, _ in calls:
self._lookup_tool(tool_id)
if self._engine is None or self._base_inference_config is None:
raise RuntimeError(
"SyntheticEnvironment.step called before attach_inference(). "
"Wire the synthesizer's engine via attach_inference(engine, "
"base_config) before invoking step()."
)
results: list[ToolResult | None] = [None] * len(calls)
misses: list[tuple[int, str, dict[str, Any]]] = []
for i, (tool_id, args) in enumerate(calls):
cached = self._resolve_cached(tool_id, args)
if cached is not None:
results[i] = cached
else:
misses.append((i, tool_id, args))
groups: dict[str, list[tuple[int, dict[str, Any]]]] = {}
for i, tool_id, args in misses:
groups.setdefault(tool_id, []).append((i, args))
for tool_id, group in groups.items():
tool = self._lookup_tool(tool_id)
convs = [self._build_call_conv(tool, args) for _, args in group]
inferred = self._engine.infer(convs, self._simulator_inference_config(tool))
if len(inferred) != len(group):
raise RuntimeError(
f"Simulator returned {len(inferred)} responses for "
f"{len(group)} calls to '{tool_id}'."
)
for (idx, args), conv in zip(group, inferred):
raw = self._extract_text(conv)
result = self._parse_and_validate(raw, tool)
self._cache_result(tool_id, args, result)
results[idx] = result
assert all(r is not None for r in results), (
"every call must produce a ToolResult"
)
return results # type: ignore[return-value]
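    # Illustrative call (tool ids/arguments are hypothetical): both misses hit
    # the same tool, so they are grouped into one batched infer() call.
    #
    #   results = env.step([
    #       ("search", {"q": "cats"}),
    #       ("search", {"q": "dogs"}),
    #   ])
    #   assert len(results) == 2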
def _build_simulator_system_prompt(self, tool: ToolParams) -> str:
"""Compose the simulator system prompt: env persona + tool schema."""
return (
f"{self._kwargs.system_prompt}\n\n"
f"You are simulating the `{tool.id}` tool. Respond ONLY with a "
f"JSON object matching the tool's output schema. Do NOT include "
f"explanations, markdown, or surrounding prose.\n\n"
f"Tool schema:\n{json.dumps(tool.to_llm_schema(), indent=2)}"
)
def _build_call_conv(
self, tool: ToolParams, arguments: dict[str, Any]
) -> Conversation:
"""Build the simulator conversation for one tool call."""
user_payload = json.dumps(
{"tool": tool.id, "arguments": arguments}, sort_keys=True
)
return Conversation(
messages=[
Message(
role=Role.SYSTEM,
content=self._build_simulator_system_prompt(tool),
),
Message(role=Role.USER, content=user_payload),
]
)
def _simulator_inference_config(self, tool: ToolParams) -> InferenceConfig:
"""Overlay guided decoding for the tool's output_schema onto base_config.
Tools without ``output_schema`` get the permissive ``{"type": "object"}``
constraint. Mirrors ``ConversationSynthesizer._planner_inference_config``.
"""
assert self._base_inference_config is not None
schema = tool.output_schema or {"type": "object"}
sim_gen = dataclasses.replace(
self._base_inference_config.generation,
guided_decoding=GuidedDecodingParams(json=schema),
)
return dataclasses.replace(self._base_inference_config, generation=sim_gen)
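    # dataclasses.replace builds new (shallow-copied) objects, so the shared
    # base config is never mutated; the guided-decoding constraint lives only
    # in this call's per-tool overlay.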
@staticmethod
def _extract_text(conv: Conversation) -> str:
"""Pull the simulator's text response from an inferred conversation.
Returns ``""`` (which forces the ``ToolError`` path in
``_parse_and_validate``) when the last message is not an assistant
turn — guards against a passthrough/partial-failure path where the
engine returns ``convs`` unchanged and ``messages[-1]`` is still the
user payload (itself valid JSON of the form
``{"tool": ..., "arguments": ...}``).
"""
if not conv.messages:
return ""
last = conv.messages[-1]
if last.role != Role.ASSISTANT:
return ""
content = last.content
return content.strip() if isinstance(content, str) else ""
@staticmethod
def _parse_and_validate(raw: str, tool: ToolParams) -> ToolResult:
"""Parse simulator output and validate against ``tool.output_schema``."""
if not raw:
raise ToolError(f"Simulator returned empty response for '{tool.id}'.")
try:
parsed = json.loads(raw)
except json.JSONDecodeError:
extracted = extract_json(raw, expected_type=dict)
if extracted is None:
raise ToolError(
f"Simulator output for '{tool.id}' is not valid JSON: {raw[:200]!r}"
) from None
parsed = extracted
if not isinstance(parsed, dict):
raise ToolError(
f"Simulator output for '{tool.id}' must be a JSON object, "
f"got {type(parsed).__name__}."
)
if tool.output_schema is not None:
try:
jsonschema.validate(parsed, tool.output_schema)
except jsonschema.ValidationError as e:
raise ToolError(
f"Simulator output for '{tool.id}' failed schema validation: {e}"
) from e
return ToolResult(output=parsed)
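    # Fallback sketch: given raw = '```json\n{"status": "ok"}\n```',
    # json.loads fails, but extract_json (assumed to locate an embedded JSON
    # object inside prose or fences) recovers {"status": "ok"}; ToolError is
    # raised only when both strategies fail or the schema check rejects the
    # parsed result.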