# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import Optional, cast
from tqdm.auto import tqdm
from typing_extensions import override
from oumi.core.configs import GenerationParams, InferenceConfig, ModelParams
from oumi.core.inference import BaseInferenceEngine
from oumi.core.types.conversation import Conversation, Message, Role
from oumi.utils.logging import logger
try:
from llama_cpp import Llama # pyright: ignore[reportMissingImports]
except ModuleNotFoundError:
Llama = None
class LlamaCppInferenceEngine(BaseInferenceEngine):
"""Engine for running llama.cpp inference locally.
This class provides an interface for running inference using the llama.cpp library
on local hardware. It allows for efficient execution of large language models
    with quantization, KV caching, and prefix filling.
Note:
This engine requires the llama-cpp-python package to be installed.
If not installed, it will raise a RuntimeError.
Example:
>>> from oumi.core.configs import ModelParams
>>> from oumi.inference import LlamaCppInferenceEngine
>>> model_params = ModelParams(
... model_name="path/to/model.gguf",
... model_kwargs={
... "n_gpu_layers": -1,
... "n_threads": 8,
... "flash_attn": True
... }
... )
>>> engine = LlamaCppInferenceEngine(model_params) # doctest: +SKIP
>>> # Use the engine for inference
"""
def __init__(
self,
model_params: ModelParams,
*,
generation_params: Optional[GenerationParams] = None,
):
"""Initializes the LlamaCppInferenceEngine.
This method sets up the engine for running inference using llama.cpp.
It loads the specified model and configures the inference parameters.
Documentation: https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_completion
Args:
model_params (ModelParams): Parameters for the model, including the model
name, maximum length, and any additional keyword arguments for model
initialization.
            generation_params (Optional[GenerationParams]): Parameters for generation.
                If not provided, default generation parameters are used.
Raises:
RuntimeError: If the llama-cpp-python package is not installed.
ValueError: If the specified model file is not found.
Note:
This method automatically sets some default values for model initialization:
- verbose: False (reduces log output for bulk inference)
- n_gpu_layers: -1 (uses GPU acceleration for all layers if available)
- n_threads: 4
            - filename: "*8_0.gguf" (selects Q8_0-quantized GGUF weights by default)
- flash_attn: True
These defaults can be overridden by specifying them in
`model_params.model_kwargs`.
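        Example:
            A minimal sketch of overriding these defaults via `model_kwargs`
            (the repo id below is hypothetical):

            >>> from oumi.core.configs import ModelParams
            >>> model_params = ModelParams(
            ...     model_name="some-org/some-model-GGUF",
            ...     model_max_length=8192,
            ...     model_kwargs={"n_threads": 8, "n_gpu_layers": 0},
            ... )
            >>> engine = LlamaCppInferenceEngine(model_params) # doctest: +SKIP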
"""
super().__init__(model_params=model_params, generation_params=generation_params)
if not Llama:
raise RuntimeError(
"llama-cpp-python is not installed. "
"Please install it with 'pip install llama-cpp-python'."
)
# `model_max_length` is required by llama-cpp, but optional in our config
# Use a default value if not set.
if model_params.model_max_length is None:
model_max_length = 4096
logger.warning(
"model_max_length is not set. "
f"Using default value of {model_max_length}."
)
else:
model_max_length = model_params.model_max_length
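        # `model_max_length` is passed to llama.cpp as `n_ctx` (the context window
        # size, in tokens) when the model is constructed below.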
        # Set some reasonable defaults. These will be overridden by any values the
        # user sets in the config.
kwargs = {
# llama-cpp logs a lot of useful information,
# but it's too verbose by default for bulk inference.
"verbose": False,
# Put all layers on GPU / MPS if available. Otherwise, will use CPU.
"n_gpu_layers": -1,
        # Set the number of threads explicitly; too many can cause deadlocks.
"n_threads": 4,
# Use Q8 quantization by default.
"filename": "*8_0.gguf",
"flash_attn": True,
}
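        # User-supplied model_kwargs take precedence over these defaults; e.g.
        # passing model_kwargs={"n_gpu_layers": 0} keeps all layers on the CPU.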
model_kwargs = model_params.model_kwargs.copy()
kwargs.update(model_kwargs)
# Load model
if Path(model_params.model_name).exists():
logger.info(f"Loading model from disk: {model_params.model_name}.")
kwargs.pop("filename", None) # only needed if downloading from hub
self._llm = Llama(
model_path=model_params.model_name, n_ctx=model_max_length, **kwargs
)
else:
logger.info(
f"Loading model from Huggingface Hub: {model_params.model_name}."
)
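            # `filename` (set above; default "*8_0.gguf") is forwarded via kwargs to
            # Llama.from_pretrained as a glob that selects which GGUF file to
            # download from the repo.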
self._llm = Llama.from_pretrained(
repo_id=model_params.model_name, n_ctx=model_max_length, **kwargs
)
def _convert_conversation_to_llama_input(
self, conversation: Conversation
) -> list[dict[str, str]]:
"""Converts a conversation to a list of llama.cpp input messages."""
        # FIXME: Handle multimodal inputs, e.g., by raising an error.
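        # Illustrative output shape for a two-turn conversation:
        #   [{"content": "Hello!", "role": "user"},
        #    {"content": "Hi! How can I help?", "role": "assistant"}]
        # Note that every non-user role (system, tool) is mapped to "assistant".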
return [
{
"content": message.compute_flattened_text_content(),
"role": "user" if message.role == Role.USER else "assistant",
}
for message in conversation.messages
]
def _infer(
self,
input: list[Conversation],
inference_config: Optional[InferenceConfig] = None,
) -> list[Conversation]:
"""Runs model inference on the provided input using llama.cpp.
Args:
input: A list of conversations to run inference on.
Each conversation should contain at least one message.
inference_config: Parameters for inference.
Returns:
List[Conversation]: A list of conversations with the model's responses
appended. Each conversation in the output list corresponds to an input
conversation, with an additional message from the assistant (model) added.
"""
generation_params = (
inference_config.generation
if inference_config and inference_config.generation
else self._generation_params
)
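        # At this point, `generation_params` holds the per-request parameters if an
        # inference config was provided, and the engine-level defaults otherwise.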
output_conversations = []
        # Skip the progress bar when there is only a single conversation.
        disable_tqdm = len(input) < 2
        for conversation in tqdm(input, disable=disable_tqdm):
if not conversation.messages:
logger.warning("Conversation must have at least one message.")
# add the conversation to keep input and output the same length.
output_conversations.append(conversation)
continue
llama_input = self._convert_conversation_to_llama_input(conversation)
response = self._llm.create_chat_completion(
messages=llama_input, # type: ignore
max_tokens=generation_params.max_new_tokens,
temperature=generation_params.temperature,
top_p=generation_params.top_p,
frequency_penalty=generation_params.frequency_penalty,
presence_penalty=generation_params.presence_penalty,
stop=generation_params.stop_strings,
logit_bias=generation_params.logit_bias,
min_p=generation_params.min_p,
)
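            # With the default stream=False, create_chat_completion returns a single
            # completion dict rather than an iterator of chunks, so the cast below
            # is safe.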
response = cast(dict, response)
new_message = Message(
content=response["choices"][0]["message"]["content"],
role=Role.ASSISTANT,
)
messages = [
*conversation.messages,
new_message,
]
new_conversation = Conversation(
messages=messages,
metadata=conversation.metadata,
conversation_id=conversation.conversation_id,
)
output_conversations.append(new_conversation)
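            # If an output path is configured, persist each conversation as soon as
            # it is generated rather than waiting for the whole batch to finish.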
if inference_config and inference_config.output_path:
self._save_conversation(
new_conversation,
inference_config.output_path,
)
return output_conversations
@override
def get_supported_params(self) -> set[str]:
"""Returns a set of supported generation parameters for this engine."""
return {
"frequency_penalty",
"logit_bias",
"max_new_tokens",
"min_p",
"presence_penalty",
"stop_strings",
"temperature",
"top_p",
}
@override
def infer_online(
self,
input: list[Conversation],
inference_config: Optional[InferenceConfig] = None,
) -> list[Conversation]:
"""Runs model inference online.
Args:
input: A list of conversations to run inference on.
inference_config: Parameters for inference.
Returns:
List[Conversation]: Inference output.
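        Example:
            A minimal sketch of a single online call (the model path and prompt
            are illustrative):

            >>> from oumi.core.configs import ModelParams
            >>> from oumi.core.types.conversation import Conversation, Message, Role
            >>> model_params = ModelParams(model_name="path/to/model.gguf")
            >>> engine = LlamaCppInferenceEngine(model_params) # doctest: +SKIP
            >>> conversation = Conversation(
            ...     messages=[Message(role=Role.USER, content="Hello!")]
            ... )
            >>> results = engine.infer_online([conversation]) # doctest: +SKIP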
"""
return self._infer(input, inference_config)
@override
def infer_from_file(
self,
input_filepath: str,
inference_config: Optional[InferenceConfig] = None,
) -> list[Conversation]:
"""Runs model inference on inputs in the provided file.
        This is a convenience method that reads conversations from `input_filepath`
        and runs inference on them, sparing callers the boilerplate of loading the
        file themselves.
Args:
input_filepath: Path to the input file containing prompts for generation.
inference_config: Parameters for inference.
Returns:
List[Conversation]: Inference output.
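        Example:
            A sketch using an already-constructed engine (the input path is
            hypothetical):

            >>> results = engine.infer_from_file("prompts.jsonl") # doctest: +SKIP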
"""
input = self._read_conversations(input_filepath)
return self._infer(input, inference_config)