# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from transformers.optimization import Adafactor

from oumi.core.configs import TrainingParams
from oumi.utils.torch_naming_heuristics import group_trainable_params
def build_optimizer(
    model: torch.nn.Module, config: TrainingParams
) -> torch.optim.Optimizer:
    """Builds and returns a PyTorch optimizer based on the provided configuration.

    See pytorch documentation for more information on available optimizers:
    https://pytorch.org/docs/stable/optim.html

    Args:
        model: The model whose parameters will be optimized.
        config: The configuration object containing optimizer parameters.

    Returns:
        Optimizer: The constructed PyTorch optimizer.
    """
    optimizer_name = config.optimizer.lower()

    # Get parameters that require optimization, grouped by weight decay.
    trainable_param_groups = group_trainable_params(model, config.weight_decay)

    fused_available = torch.cuda.is_available()

    if optimizer_name == "adam":
        return torch.optim.Adam(
            trainable_param_groups,
            lr=config.learning_rate,
            betas=(config.adam_beta1, config.adam_beta2),
            eps=config.adam_epsilon,
            fused=fused_available,
        )
    elif optimizer_name in ("adamw", "adamw_torch", "adamw_torch_fused"):
        return torch.optim.AdamW(
            trainable_param_groups,
            lr=config.learning_rate,
            betas=(config.adam_beta1, config.adam_beta2),
            eps=config.adam_epsilon,
            fused=fused_available,
        )
    elif optimizer_name in (
        "adamw_8bit",
        "paged_adamw_8bit",
        "paged_adamw",
        "paged_adamw_32bit",
    ):
        try:
            import bitsandbytes  # pyright: ignore[reportMissingImports]
        except ImportError:
            raise ImportError(
                "bitsandbytes is not installed. "
                "Please install it with `pip install bitsandbytes` "
                "to use 8-bit or paged optimizers."
            )

        if optimizer_name in ("adamw_8bit", "paged_adamw_8bit"):
            return bitsandbytes.optim.AdamW(
                trainable_param_groups,
                lr=config.learning_rate,
                betas=(config.adam_beta1, config.adam_beta2),
                eps=config.adam_epsilon,
                weight_decay=config.weight_decay,
                optim_bits=8,
                is_paged=optimizer_name == "paged_adamw_8bit",
            )
        else:  # paged_adamw or paged_adamw_32bit
            return bitsandbytes.optim.PagedAdamW(
                trainable_param_groups,
                lr=config.learning_rate,
                betas=(config.adam_beta1, config.adam_beta2),
                eps=config.adam_epsilon,
                weight_decay=config.weight_decay,
            )
    elif optimizer_name == "sgd":
        return torch.optim.SGD(
            trainable_param_groups,
            lr=config.learning_rate,
            momentum=config.sgd_momentum,
            fused=fused_available,
        )
    elif optimizer_name == "adafactor":
        return Adafactor(
            trainable_param_groups,
            lr=config.learning_rate,
            relative_step=False,
            scale_parameter=False,
        )
    else:
        raise ValueError(f"Unsupported optimizer: {optimizer_name}")
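

# Example usage (a minimal sketch, not part of the original module): build an
# AdamW optimizer for a small model. This assumes TrainingParams exposes the
# fields read above (optimizer, learning_rate, weight_decay, and the adam_*
# hyperparameters) as constructor keyword arguments; the actual dataclass
# signature may differ.
#
#   model = torch.nn.Linear(16, 4)
#   params = TrainingParams(
#       optimizer="adamw",
#       learning_rate=1e-4,
#       weight_decay=0.01,
#   )
#   optimizer = build_optimizer(model, params)
#   assert isinstance(optimizer, torch.optim.AdamW)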