# Source code for oumi.inference.gcp_inference_engine

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from typing import Any, Optional

import pydantic
from typing_extensions import override

from oumi.core.configs import GenerationParams, RemoteParams
from oumi.core.configs.params.guided_decoding_params import GuidedDecodingParams
from oumi.core.types.conversation import Conversation
from oumi.inference.remote_inference_engine import RemoteInferenceEngine


class GoogleVertexInferenceEngine(RemoteInferenceEngine):
    """Engine for running inference against Google Vertex AI."""

    @override
    def _get_api_key(self, remote_params: RemoteParams) -> str:
        """Gets the authentication token for GCP.

        Args:
            remote_params: Remote parameters. When ``api_key`` is set it is used
                as the filename of a service account file; otherwise the
                credentials returned by ``google.auth.default`` are used.

        Returns:
            str: The token of the freshly refreshed credentials.

        Raises:
            RuntimeError: If the ``google-auth`` package is not installed.
        """
        # Imported lazily so that `google-auth` stays an optional dependency.
        try:
            from google.auth import default  # pyright: ignore[reportMissingImports]
            from google.auth.transport.requests import (  # pyright: ignore[reportMissingImports]
                Request,
            )
            from google.oauth2 import (  # pyright: ignore[reportMissingImports]
                service_account,
            )
        except ModuleNotFoundError as err:
            # Chain the original error so the missing module name is preserved
            # in the traceback.
            raise RuntimeError(
                "Google-auth is not installed. "
                "Please install oumi with GCP extra:`pip install oumi[gcp]`, "
                "or install google-auth with `pip install google-auth`."
            ) from err

        scopes = ["https://www.googleapis.com/auth/cloud-platform"]
        if remote_params.api_key:
            credentials = service_account.Credentials.from_service_account_file(
                filename=remote_params.api_key,
                scopes=scopes,
            )
        else:
            credentials, _ = default(scopes=scopes)

        # The token is only populated after an explicit refresh.
        credentials.refresh(Request())  # type: ignore
        return credentials.token  # type: ignore

    @override
    def _get_request_headers(
        self, remote_params: Optional[RemoteParams]
    ) -> dict[str, str]:
        """Gets the request headers for GCP.

        Args:
            remote_params: Remote parameters used to obtain the bearer token.

        Returns:
            dict[str, str]: The ``Authorization`` and ``Content-Type`` headers.

        Raises:
            ValueError: If ``remote_params`` is missing.
        """
        if not remote_params:
            raise ValueError("Remote params are required for GCP inference.")

        headers = {
            "Authorization": f"Bearer {self._get_api_key(remote_params)}",
            "Content-Type": "application/json",
        }
        return headers

    @override
    def _convert_conversation_to_api_input(
        self, conversation: Conversation, generation_params: GenerationParams
    ) -> dict[str, Any]:
        """Converts a conversation to an OpenAI-style request body for Vertex.

        Documentation: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/call-vertex-using-openai-library

        Args:
            conversation: The conversation to convert.
            generation_params: Parameters for generation during inference.

        Returns:
            Dict[str, Any]: A dictionary representing the Vertex input.
        """
        api_input = {
            "model": self._model,
            "messages": self._get_list_of_message_json_dicts(
                conversation.messages, group_adjacent_same_role_turns=True
            ),
            "max_completion_tokens": generation_params.max_new_tokens,
            "temperature": generation_params.temperature,
            "top_p": generation_params.top_p,
            "n": 1,  # Number of completions to generate for each prompt.
            "seed": generation_params.seed,
            "logit_bias": generation_params.logit_bias,
        }

        # Optional fields are only sent when they are actually configured.
        if generation_params.stop_strings:
            api_input["stop"] = generation_params.stop_strings

        if generation_params.guided_decoding:
            api_input["response_format"] = _convert_guided_decoding_config_to_api_input(
                generation_params.guided_decoding
            )

        return api_input

    @override
    def get_supported_params(self) -> set[str]:
        """Returns a set of supported generation parameters for this engine."""
        return {
            "guided_decoding",
            "logit_bias",
            "max_new_tokens",
            "seed",
            "stop_strings",
            "temperature",
            "top_p",
        }
#
# Helper functions
#
def _convert_guided_decoding_config_to_api_input(
    guided_config: GuidedDecodingParams,
) -> dict:
    """Converts a guided decoding configuration to an API input.

    Args:
        guided_config: The guided decoding configuration. Only its ``json``
            field is used; it may be a Pydantic model class, a JSON schema
            dict, or a JSON schema serialized as a string.

    Returns:
        dict: A ``response_format`` payload of type ``json_schema`` whose schema
        has all ``$ref`` references inlined.

    Raises:
        ValueError: If no JSON schema is configured, if the schema has an
            unsupported type, or if a schema string is not a JSON object.
    """
    if guided_config.json is None:
        raise ValueError(
            f"Only JSON schema guided decoding is supported, got '{guided_config}'"
        )

    json_schema = guided_config.json

    if isinstance(json_schema, type) and issubclass(json_schema, pydantic.BaseModel):
        schema_name = json_schema.__name__
        schema_value = json_schema.model_json_schema()
    elif isinstance(json_schema, dict):
        # A raw schema carries no name of its own, so use a generic one.
        schema_name = "Response"
        schema_value = json_schema
    elif isinstance(json_schema, str):
        # A raw schema carries no name of its own, so use a generic one.
        schema_name = "Response"
        # Invalid JSON raises json.JSONDecodeError, a ValueError subclass.
        schema_value = json.loads(json_schema)
        if not isinstance(schema_value, dict):
            raise ValueError(
                "JSON schema string must decode to a JSON object, got "
                f"{type(schema_value)}"
            )
    else:
        raise ValueError(
            f"Got unsupported JSON schema type: {type(json_schema)}. "
            "Please provide a Pydantic model or a JSON schema as a "
            "string or dict."
        )

    return {
        "type": "json_schema",
        "json_schema": {
            "name": schema_name,
            "schema": _replace_refs_in_schema(schema_value),
        },
    }


def _replace_refs_in_schema(schema: dict) -> dict:
    """Replace $ref references in a JSON schema with their actual definitions.

    Keywords that sit next to a ``$ref`` (for example ``description`` or
    ``default``) are kept and take precedence over the same keyword in the
    referenced definition. The input schema is not modified.

    Args:
        schema: The JSON schema dictionary

    Returns:
        dict: Schema with all references replaced by their definitions and
        $defs removed

    Raises:
        ValueError: If a reference is circular and therefore cannot be inlined.
        KeyError: If a reference points to a location missing from ``schema``.
    """

    def _lookup(ref: str) -> Any:
        # Drop the leading '#/' and walk the pointer through the root schema.
        pointer = ref[2:] if ref.startswith("#/") else ref
        node = schema
        for token in pointer.split("/"):
            # Decode JSON Pointer escapes ('~1' -> '/', then '~0' -> '~').
            node = node[token.replace("~1", "/").replace("~0", "~")]
        return node

    def _resolve(node: Any, active: tuple) -> Any:
        if isinstance(node, list):
            return [_resolve(item, active) for item in node]
        if not isinstance(node, dict):
            return node

        ref = node.get("$ref")
        # Only a string value is a reference; a dict value means "$ref" is
        # merely the name of a property.
        is_reference = isinstance(ref, str)

        merged: dict = {}
        if is_reference:
            if ref in active:
                raise ValueError(
                    f"Circular reference '{ref}' cannot be inlined in JSON schema."
                )
            target = _resolve(_lookup(ref), active + (ref,))
            if not isinstance(target, dict):
                return target
            merged.update(target)

        for key, value in node.items():
            if is_reference and key == "$ref":
                continue
            merged[key] = _resolve(value, active)
        return merged

    # $defs is dropped from the output, so there is no need to resolve it;
    # definitions are looked up in the untouched input schema instead.
    top_level = {key: value for key, value in schema.items() if key != "$defs"}
    resolved = _resolve(top_level, ())
    if isinstance(resolved, dict):
        resolved.pop("$defs", None)
    return resolved