# Source code for oumi.inference.gemini_inference_engine
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing_extensions import override
from oumi.core.configs import GenerationParams
from oumi.core.types.conversation import Conversation
from oumi.inference.gcp_inference_engine import (
_convert_guided_decoding_config_to_api_input,
)
from oumi.inference.remote_inference_engine import RemoteInferenceEngine
class GoogleGeminiInferenceEngine(RemoteInferenceEngine):
    """Engine for running inference against Gemini API."""

    base_url = (
        "https://generativelanguage.googleapis.com/v1beta/openai/chat/completions"
    )
    """The base URL for the Gemini API."""

    api_key_env_varname = "GEMINI_API_KEY"
    """The environment variable name for the Gemini API key."""

    @override
    def _convert_conversation_to_api_input(
        self, conversation: Conversation, generation_params: GenerationParams
    ) -> dict[str, Any]:
        """Converts a conversation to a Gemini API input.

        Documentation: https://ai.google.dev/docs

        Args:
            conversation: The conversation to convert.
            generation_params: Parameters for generation during inference.

        Returns:
            A dictionary representing the Gemini input.
        """
        # Annotated explicitly: the values are heterogeneous (strings, numbers,
        # lists, nested dicts) and more keys are added conditionally below.
        api_input: dict[str, Any] = {
            "model": self._model,
            "messages": self._get_list_of_message_json_dicts(
                conversation.messages, group_adjacent_same_role_turns=True
            ),
            "max_completion_tokens": generation_params.max_new_tokens,
            "temperature": generation_params.temperature,
            "top_p": generation_params.top_p,
            "n": 1,  # Number of completions to generate for each prompt.
        }

        # Optional fields are only included when set, so an empty or unset
        # value never reaches the request payload.
        if generation_params.stop_strings:
            api_input["stop"] = generation_params.stop_strings

        if generation_params.guided_decoding:
            api_input["response_format"] = _convert_guided_decoding_config_to_api_input(
                generation_params.guided_decoding
            )

        return api_input

    @override
    def get_supported_params(self) -> set[str]:
        """Returns a set of supported generation parameters for this engine."""
        return {
            "guided_decoding",
            "max_new_tokens",
            "stop_strings",
            "temperature",
            "top_p",
        }