Source code for oumi.inference.bedrock_inference_engine

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import os
from typing import Any, Optional

from tqdm.asyncio import tqdm
from typing_extensions import override

from oumi.core.configs import GenerationParams, ModelParams, RemoteParams
from oumi.core.types.conversation import Conversation, Message, Role
from oumi.inference.adaptive_semaphore import PoliteAdaptiveSemaphore
from oumi.inference.remote_inference_engine import RemoteInferenceEngine
from oumi.utils.logging import logger

try:
    import boto3  # pyright: ignore[reportMissingImports]
    from botocore.exceptions import ClientError  # pyright: ignore[reportMissingImports]
except ModuleNotFoundError:
    boto3 = None  # type: ignore
    ClientError = None  # type: ignore

_CONTENT_KEY: str = "content"
_ROLE_KEY: str = "role"
_AWS_REGION_ENV_VAR: str = "AWS_REGION"


class BedrockInferenceEngine(RemoteInferenceEngine):
    """Inference engine for running inference against the Bedrock API.

    This class extends RemoteInferenceEngine to provide specific functionality
    for interacting with Bedrock's language models via their API. It handles
    the conversion of Oumi's Conversation objects to Bedrock's expected input
    format, as well as parsing the API responses back into Conversation objects.

    Note:
        This engine requires the boto3 package to be installed.
        If not installed, it will raise a RuntimeError.
    """

    def __init__(
        self,
        model_params: ModelParams,
        *,
        generation_params: Optional[GenerationParams] = None,
        remote_params: Optional[RemoteParams] = None,
    ):
        """Initializes the BedrockInferenceEngine.

        Args:
            model_params: Parameters for the model.
            generation_params: Parameters for generation.
            remote_params: Parameters for remote inference.

        Raises:
            RuntimeError: If the boto3 package is not installed.
        """
        if not boto3:
            raise RuntimeError(
                "boto3 is not installed. Please install it with 'pip install boto3'."
            )
        super().__init__(
            model_params=model_params,
            generation_params=generation_params,
            remote_params=remote_params,
        )

    @property
    @override
    def base_url(self) -> Optional[str]:
        """Return the default base URL for the Bedrock API."""
        return None

    @property
    @override
    def api_key_env_varname(self) -> Optional[str]:
        """Return the default environment variable name for the Bedrock API key."""
        return None

    def _bedrock_client(self, remote_params: RemoteParams) -> Any:
        region = os.getenv(_AWS_REGION_ENV_VAR)
        if not region:
            raise ValueError(f"Environment variable {_AWS_REGION_ENV_VAR} not set.")
        return boto3.client("bedrock-runtime", region_name=region)  # type: ignore

    @override
    def _convert_conversation_to_api_input(
        self,
        conversation: Conversation,
        generation_params: GenerationParams,
        model_params: ModelParams,
    ) -> dict[str, Any]:
        """Converts a conversation to a Bedrock API input.

        Args:
            conversation: The conversation to convert.
            generation_params: Parameters for text generation.
            model_params: Model parameters to use during inference.

        Returns:
            Dict[str, Any]: A dictionary containing the formatted input for the
            Bedrock API, including the model, messages, and generation parameters.
        """
        system_messages = [
            message for message in conversation.messages if message.role == Role.SYSTEM
        ]
        if len(system_messages) > 0:
            system_message = system_messages[0].content
            if len(system_messages) > 1:
                logger.warning(
                    "Multiple system messages found in conversation. "
                    "Only using the first one."
                )
        else:
            system_message = None

        messages = [
            message for message in conversation.messages if message.role != Role.SYSTEM
        ]

        body = {
            "inferenceConfig": {
                "maxTokens": generation_params.max_new_tokens,
                "temperature": generation_params.temperature,
                "topP": generation_params.top_p,
            },
            "messages": self._to_bedrock_messages(messages),
        }
        if system_message:
            body["system"] = [{"text": system_message}]
        if model_params.model_kwargs:
            body["additionalModelRequestFields"] = model_params.model_kwargs
        if generation_params.stop_strings:
            body["inferenceConfig"]["stopSequences"] = generation_params.stop_strings
        return body

    def _to_bedrock_messages(self, messages: list[Message]) -> list[dict]:
        result = []
        for m in messages:
            # Map Oumi roles to Bedrock roles
            if m.role == Role.USER:
                bedrock_role = "user"
            elif m.role == Role.ASSISTANT:
                bedrock_role = "assistant"
            else:
                logger.warning(f"Skipping message with unsupported role: {m.role}")
                continue

            content_blocks: list[dict] = []
            if m.contains_single_text_content_item_only():
                content_blocks = [{"text": m.text_content_items[0].content}]
            elif m.contains_images():
                img = m.image_content_items[0]
                if getattr(img, "binary", None):
                    content_blocks = [{"image": {"source": {"bytes": img.binary}}}]
                else:
                    uri = img.content or ""
                    if uri.startswith("s3://"):
                        content_blocks = [
                            {"image": {"source": {"s3Location": {"uri": uri}}}}
                        ]
                    else:
                        content_blocks = [{"image": {"source": {"url": uri}}}]
            else:
                content_blocks = [{"text": m.compute_flattened_text_content()}]

            result.append({"role": bedrock_role, "content": content_blocks})
        return result

    @override
    def _default_remote_params(self) -> RemoteParams:
        """Returns the default remote parameters."""
        return RemoteParams()

    @override
    def _set_required_fields_for_inference(self, remote_params: RemoteParams):
        """Override to skip API key validation since Bedrock uses AWS credentials."""
        pass

    async def _infer(
        self,
        input: list[Conversation],
        inference_config: Optional[Any] = None,
    ) -> list[Conversation]:
        """Async inference implementation that doesn't use HTTP sessions."""
        semaphore = PoliteAdaptiveSemaphore(
            capacity=self._remote_params.num_workers,
            politeness_policy=self._remote_params.politeness_policy,
        )
        # Create tasks for all conversations
        tasks = [
            self._query_api(
                conversation,
                semaphore,
                None,  # No HTTP session needed for boto3
                inference_config=inference_config,
            )
            for conversation in input
        ]
        disable_tqdm = len(tasks) < 2
        results = await tqdm.gather(*tasks, disable=disable_tqdm)
        return results

    def _call_bedrock_converse(
        self,
        remote_params: RemoteParams,
        model_params: ModelParams,
        body: dict[str, Any],
    ) -> dict[str, Any]:
        """Synchronously invokes Bedrock Converse via boto3."""
        client = self._bedrock_client(remote_params)
        kwargs: dict[str, Any] = {
            "modelId": model_params.model_name,
            "messages": body["messages"],
        }
        if "system" in body:
            kwargs["system"] = body["system"]
        if "inferenceConfig" in body:
            kwargs["inferenceConfig"] = body["inferenceConfig"]
        if "additionalModelRequestFields" in body:
            kwargs["additionalModelRequestFields"] = body[
                "additionalModelRequestFields"
            ]
        return client.converse(**kwargs)

    @override
    async def _query_api(
        self,
        conversation: Conversation,
        semaphore: PoliteAdaptiveSemaphore,
        session: Any,
        inference_config: Optional[Any] = None,
    ) -> Conversation:
        """Queries Bedrock Converse using boto3 instead of HTTP."""
        if inference_config is None:
            remote_params = self._remote_params
            generation_params = self._generation_params
            model_params = self._model_params
            output_path = None
        else:
            remote_params = inference_config.remote_params or self._remote_params
            generation_params = inference_config.generation or self._generation_params
            model_params = inference_config.model or self._model_params
            output_path = inference_config.output_path

        semaphore_or_controller = (
            self._adaptive_concurrency_controller
            if self._remote_params.use_adaptive_concurrency
            else semaphore
        )
        async with semaphore_or_controller:
            api_input = self._convert_conversation_to_api_input(
                conversation, generation_params, model_params
            )
            failure_reason: Optional[str] = None
            for attempt in range(remote_params.max_retries + 1):
                try:
                    if attempt > 0:
                        delay = min(
                            remote_params.retry_backoff_base * (2 ** (attempt - 1)),
                            remote_params.retry_backoff_max,
                        )
                        await asyncio.sleep(delay)
                    response = await asyncio.to_thread(
                        self._call_bedrock_converse,
                        remote_params,
                        model_params,
                        api_input,
                    )
                    result = self._convert_api_output_to_conversation(
                        response, conversation
                    )
                    if output_path:
                        self._save_conversation_to_scratch(result, output_path)
                    await self._try_record_success()
                    return result
                except ClientError as e:  # type: ignore
                    # Capture AWS error message for logging/propagation
                    failure_reason = e.response.get("Error", {}).get("Message") or str(
                        e
                    )
                    await self._try_record_error()
                    if attempt >= remote_params.max_retries:
                        raise RuntimeError(failure_reason) from e
                    continue
                except RuntimeError:
                    raise
                except Exception as e:
                    failure_reason = f"Unexpected error: {str(e)}"
                    await self._try_record_error()
                    if attempt >= remote_params.max_retries:
                        raise RuntimeError(failure_reason) from e
                    continue
            raise RuntimeError(
                f"Failed to query Bedrock after {remote_params.max_retries} retries. "
                + (f"Reason: {failure_reason}" if failure_reason else "")
            )

    @override
    def _convert_api_output_to_conversation(
        self, response: dict[str, Any], original: Conversation
    ) -> Conversation:
        text = ""
        msg = response.get("output", {}).get("message", {})
        for block in msg.get("content", []):
            if "text" in block:
                text += block["text"]
        new_message = Message(content=text, role=Role.ASSISTANT)
        return Conversation(
            messages=[*original.messages, new_message],
            metadata=original.metadata,
            conversation_id=original.conversation_id,
        )

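    # Illustrative sketch of the Converse response shape consumed by
    # `_convert_api_output_to_conversation` above. Only the text blocks under
    # `output.message.content` are read; the other fields are shown for context
    # only and are ignored by this engine.
    #
    #   {
    #       "output": {
    #           "message": {
    #               "role": "assistant",
    #               "content": [{"text": "Hello! How can I help?"}],
    #           }
    #       },
    #       "stopReason": "end_turn",
    #       "usage": {"inputTokens": 10, "outputTokens": 8, "totalTokens": 18},
    #   }
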
    @override
    def get_supported_params(self) -> set[str]:
        """Returns a set of supported generation parameters for this engine."""
        return {
            "max_new_tokens",
            "stop_strings",
            "temperature",
            "top_p",
        }

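    # Generation parameters outside the set above are ignored by the request
    # builder; model-specific options can still be passed via
    # `ModelParams.model_kwargs`, which `_convert_conversation_to_api_input`
    # forwards verbatim as `additionalModelRequestFields`.
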
    # Override batch inference methods to indicate they're not supported.
    @override
    def infer_batch(
        self,
        conversations: list[Conversation],
        inference_config: Optional[Any] = None,
    ) -> str:
        """Bedrock does not support batch inference via OpenAI-style batch API."""
        raise NotImplementedError("Batch inference is not supported for Bedrock API.")

    @override
    def get_batch_status(self, batch_id: str) -> Any:
        """Bedrock does not support batch inference via OpenAI-style batch API."""
        raise NotImplementedError("Batch inference is not supported for Bedrock API.")

    @override
    def list_batches(
        self,
        after: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> Any:
        """Bedrock does not support batch inference via OpenAI-style batch API."""
        raise NotImplementedError("Batch inference is not supported for Bedrock API.")

    @override
    def get_batch_results(
        self,
        batch_id: str,
        conversations: list[Conversation],
    ) -> list[Conversation]:
        """Bedrock does not support batch inference via OpenAI-style batch API."""
        raise NotImplementedError("Batch inference is not supported for Bedrock API.")
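
# Example usage (illustrative sketch): assumes the standard Oumi engine entry
# point `infer()` inherited from the base inference engine, AWS credentials and
# AWS_REGION configured in the environment, and a placeholder Bedrock model id.
if __name__ == "__main__":
    engine = BedrockInferenceEngine(
        model_params=ModelParams(
            model_name="us.anthropic.claude-3-5-haiku-20241022-v1:0"  # placeholder id
        ),
    )
    conversation = Conversation(
        messages=[Message(role=Role.USER, content="Name one AWS region in Europe.")]
    )
    results = engine.infer([conversation])
    print(results[0].messages[-1].content)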