Source code for oumi.inference.gcp_inference_engine
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from typing import Any, Optional

import pydantic
from typing_extensions import override

from oumi.core.configs import GenerationParams, ModelParams, RemoteParams
from oumi.core.configs.params.guided_decoding_params import GuidedDecodingParams
from oumi.core.types.conversation import Conversation
from oumi.inference.remote_inference_engine import RemoteInferenceEngine
class GoogleVertexInferenceEngine(RemoteInferenceEngine):
    """Engine for running inference against Google Vertex AI."""

    _API_URL_TEMPLATE = (
        "https://{region}-aiplatform.googleapis.com/v1beta1/projects/"
        "{project_id}/locations/{region}/endpoints/openapi/chat/completions"
    )
    """The API URL template for the GCP project. Used when no `api_url` is provided."""

    _DEFAULT_PROJECT_ID_ENV_KEY: str = "PROJECT_ID"
    """The default project ID environment key for the GCP project."""

    _DEFAULT_REGION_ENV_KEY: str = "REGION"
    """The default region environment key for the GCP project."""

    _project_id: Optional[str] = None
    """The project ID for the GCP project."""

    _region: Optional[str] = None
    """The region for the GCP project."""

    def __init__(
        self,
        model_params: ModelParams,
        *,
        generation_params: Optional[GenerationParams] = None,
        remote_params: Optional[RemoteParams] = None,
        project_id_env_key: Optional[str] = None,
        region_env_key: Optional[str] = None,
        project_id: Optional[str] = None,
        region: Optional[str] = None,
    ):
        """Initializes the inference Engine.

        Args:
            model_params: The model parameters to use for inference.
            generation_params: The generation parameters to use for inference.
            remote_params: The remote parameters to use for inference.
            project_id_env_key: The environment variable key name for the project ID.
            region_env_key: The environment variable key name for the region.
            project_id: The project ID to use for inference.
            region: The region to use for inference.
        """
        super().__init__(
            model_params=model_params,
            generation_params=generation_params,
            remote_params=remote_params,
        )
        if project_id and project_id_env_key:
            raise ValueError(
                "You cannot set both `project_id` and `project_id_env_key`."
            )
        if region and region_env_key:
            raise ValueError("You cannot set both `region` and `region_env_key`.")
        self._project_id_env_key = (
            project_id_env_key or self._DEFAULT_PROJECT_ID_ENV_KEY
        )
        self._region_env_key = region_env_key or self._DEFAULT_REGION_ENV_KEY
        self._project_id = project_id
        self._region = region

    @override
    def _set_required_fields_for_inference(self, remote_params: RemoteParams) -> None:
        """Set required fields for inference."""
        if (
            not remote_params.api_url
            and not self._remote_params.api_url
            and not self.base_url
        ):
            if self._project_id and self._region:
                project_id = self._project_id
                region = self._region
            elif os.getenv(self._project_id_env_key) and os.getenv(
                self._region_env_key
            ):
                project_id = os.getenv(self._project_id_env_key)
                region = os.getenv(self._region_env_key)
            else:
                raise ValueError(
                    "This inference engine requires that either `api_url` is set in "
                    "`RemoteParams` or that both `project_id` and `region` are set. "
                    "You can set the `project_id` and `region` when "
                    "constructing a GoogleVertexInferenceEngine, "
                    f"or as environment variables: `{self._project_id_env_key}` and "
                    f"`{self._region_env_key}`."
                )
            remote_params.api_url = self._API_URL_TEMPLATE.format(
                project_id=project_id,
                region=region,
            )
        super()._set_required_fields_for_inference(remote_params)

    @override
    def _get_api_key(self, remote_params: RemoteParams) -> str:
        """Gets the authentication token for GCP."""
        try:
            from google.auth import default  # pyright: ignore[reportMissingImports]
            from google.auth.transport.requests import (  # pyright: ignore[reportMissingImports]
                Request,
            )
            from google.oauth2 import (  # pyright: ignore[reportMissingImports]
                service_account,
            )
        except ModuleNotFoundError:
            raise RuntimeError(
                "Google-auth is not installed. "
""Please install oumi with GCP extra:`pip install oumi[gcp]`, ""or install google-auth with `pip install google-auth`.")ifremote_params.api_key:credentials=service_account.Credentials.from_service_account_file(filename=remote_params.api_key,scopes=["https://www.googleapis.com/auth/cloud-platform"],)else:credentials,_=default(scopes=["https://www.googleapis.com/auth/cloud-platform"])credentials.refresh(Request())# type: ignorereturncredentials.token# type: ignore@overridedef_get_request_headers(self,remote_params:Optional[RemoteParams])->dict[str,str]:"""Gets the request headers for GCP."""ifnotremote_params:raiseValueError("Remote params are required for GCP inference.")headers={"Authorization":f"Bearer {self._get_api_key(remote_params)}","Content-Type":"application/json",}returnheaders@overridedef_default_remote_params(self)->RemoteParams:"""Returns the default remote parameters."""returnRemoteParams(num_workers=10,politeness_policy=60.0)@overridedef_convert_conversation_to_api_input(self,conversation:Conversation,generation_params:GenerationParams,model_params:ModelParams,)->dict[str,Any]:"""Converts a conversation to an OpenAI input. Documentation: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/call-vertex-using-openai-library Args: conversation: The conversation to convert. generation_params: Parameters for generation during inference. model_params: Model parameters to use during inference. Returns: Dict[str, Any]: A dictionary representing the Vertex input. """api_input={"model":model_params.model_name,"messages":self._get_list_of_message_json_dicts(conversation.messages,group_adjacent_same_role_turns=True),"max_completion_tokens":generation_params.max_new_tokens,"temperature":generation_params.temperature,"top_p":generation_params.top_p,"n":1,# Number of completions to generate for each prompt."seed":generation_params.seed,"logit_bias":generation_params.logit_bias,}ifgeneration_params.stop_strings:api_input["stop"]=generation_params.stop_stringsifgeneration_params.guided_decoding:api_input["response_format"]=_convert_guided_decoding_config_to_api_input(generation_params.guided_decoding)returnapi_input
    @override
    def get_supported_params(self) -> set[str]:
        """Returns a set of supported generation parameters for this engine."""
        return {
            "guided_decoding",
            "logit_bias",
            "max_new_tokens",
            "seed",
            "stop_strings",
            "temperature",
            "top_p",
        }
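
# Illustrative usage sketch (not part of the module): constructing the engine with
# an explicit project and region, based on the `__init__` signature above. The model
# name, project ID, and region values below are placeholders.
#
#     from oumi.core.configs import ModelParams
#     from oumi.inference.gcp_inference_engine import GoogleVertexInferenceEngine
#
#     engine = GoogleVertexInferenceEngine(
#         model_params=ModelParams(model_name="google/gemini-1.5-pro"),
#         project_id="my-gcp-project",
#         region="us-central1",
#     )
#
# Alternatively, omitting `project_id`/`region` makes the engine read the
# `PROJECT_ID` and `REGION` environment variables (or the custom keys passed via
# `project_id_env_key`/`region_env_key`) when building the API URL.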
#
# Helper functions
#
def _convert_guided_decoding_config_to_api_input(
    guided_config: GuidedDecodingParams,
) -> dict:
    """Converts a guided decoding configuration to an API input."""
    if guided_config.json is None:
        raise ValueError(
            "Only JSON schema guided decoding is supported, got '%s'",
            guided_config,
        )

    json_schema = guided_config.json

    if isinstance(json_schema, type) and issubclass(json_schema, pydantic.BaseModel):
        schema_name = json_schema.__name__
        schema_value = json_schema.model_json_schema()
    elif isinstance(json_schema, dict):
        # Use a generic name if no schema is provided.
        schema_name = "Response"
        schema_value = json_schema
    elif isinstance(json_schema, str):
        # Use a generic name if no schema is provided.
        schema_name = "Response"
        # Try to parse as JSON string
        schema_value = json.loads(json_schema)
    else:
        raise ValueError(
            f"Got unsupported JSON schema type: {type(json_schema)}"
            "Please provide a Pydantic model or a JSON schema as a "
            "string or dict."
        )

    return {
        "type": "json_schema",
        "json_schema": {
            "name": schema_name,
            "schema": _replace_refs_in_schema(schema_value),
        },
    }


def _replace_refs_in_schema(schema: dict) -> dict:
    """Replace $ref references in a JSON schema with their actual definitions.

    Args:
        schema: The JSON schema dictionary

    Returns:
        dict: Schema with all references replaced by their definitions and $defs removed
    """

    def _get_ref_value(ref: str) -> dict:
        # Remove the '#/' prefix and split into parts
        parts = ref.replace("#/", "").split("/")
        # Navigate through the schema to get the referenced value
        current = schema
        for part in parts:
            current = current[part]
        return current.copy()  # Return a copy to avoid modifying the original

    def _replace_refs(obj: dict) -> dict:
        if not isinstance(obj, dict):
            return obj

        result = {}
        for key, value in obj.items():
            if key == "$ref":
                # If we find a $ref, replace it with the actual value
                return _replace_refs(_get_ref_value(value))
            elif isinstance(value, dict):
                result[key] = _replace_refs(value)
            elif isinstance(value, list):
                result[key] = [
                    _replace_refs(item) if isinstance(item, dict) else item
                    for item in value
                ]
            else:
                result[key] = value
        return result

    # Replace all references first
    resolved = _replace_refs(schema.copy())

    # Remove the $defs key if it exists
    if "$defs" in resolved:
        del resolved["$defs"]

    return resolved
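
# Illustrative sketch (not part of the module): how `_replace_refs_in_schema`
# resolves the `$ref` entries that `pydantic.BaseModel.model_json_schema()` emits
# for nested models. The `Address`/`Person` models are hypothetical examples.
#
#     import pydantic
#
#     class Address(pydantic.BaseModel):
#         city: str
#
#     class Person(pydantic.BaseModel):
#         name: str
#         address: Address
#
#     schema = Person.model_json_schema()
#     # schema["properties"]["address"] == {"$ref": "#/$defs/Address"}
#     resolved = _replace_refs_in_schema(schema)
#     # resolved["properties"]["address"] now inlines the Address schema,
#     # and the top-level "$defs" key has been removed.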