Source code for oumi.utils.torch_naming_heuristics
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility functions which use detect-by-name heuristics.

# TODO(OPE-303): These should be replaced with something more robust.
"""

import importlib
from typing import Any

import torch
import torch.nn as nn
import transformers

from oumi.utils.logging import logger
from oumi.utils.torch_utils import _get_parameter_names

_PARAMS_KEY = "params"
_WEIGHT_DECAY_KEY = "weight_decay"
def disable_dropout(hf_config: transformers.PretrainedConfig) -> None:
    """Detects dropout probabilities in config and sets them to 0.0.

    This essentially removes the dropout layer, which can aid the compiled model's
    speed. Dropout is normally not used for LLM training, and also hinders the
    effectiveness of model compilation. We assume any attribute with "drop" in the
    name and a float value is a dropout param. For example, this includes
    `attn_pdrop` and `summary_first_dropout` for GPT2.

    Args:
        hf_config: The HuggingFace model config.
    """
    drop_attrs = []
    for k, v in vars(hf_config).items():
        if "drop" in k and isinstance(v, float):
            setattr(hf_config, k, 0.0)
            drop_attrs.append(k)
    logger.info(
        f"Found these dropout attributes and set their values to 0.0: {drop_attrs}"
    )
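A minimal usage sketch (illustrative, not part of the module source), assuming `transformers` is installed; `GPT2Config` stores its dropout probabilities (`attn_pdrop`, `resid_pdrop`, `embd_pdrop`, ...) as plain floats, so all of them are zeroed:

from transformers import GPT2Config

from oumi.utils.torch_naming_heuristics import disable_dropout

config = GPT2Config()
disable_dropout(config)  # Zeroes attn_pdrop, resid_pdrop, embd_pdrop, ...
assert config.attn_pdrop == 0.0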
def group_trainable_params(
    model: torch.nn.Module, weight_decay: float
) -> list[dict[str, Any]]:
    """Groups trainable params by weight decay for optimization.

    As a rule of thumb, we generally want to weight decay all 2d matrices, i.e.
    weight tensors for matmuls/embeddings, and not biases/layernorms.

    Args:
        model: The model whose parameters will be optimized.
        weight_decay: The weight decay to apply to the appropriate parameters.

    Returns:
        List[Dict[str, Any]]: A list containing two dictionaries: the first with
            parameters that should be weight decayed and the second with parameters
            that shouldn't.
    """
    # Exclude layernorm and bias tensors.
    decay_parameters = _get_parameter_names(
        model, forbidden_layer_types=[torch.nn.LayerNorm]
    )
    decay_parameters = [name for name in decay_parameters if "bias" not in name]
    # Only include trainable params.
    trainable_params = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
    # Group by weight decay.
    return [
        {
            _PARAMS_KEY: [p for n, p in trainable_params if n in decay_parameters],
            _WEIGHT_DECAY_KEY: weight_decay,
        },
        {
            _PARAMS_KEY: [p for n, p in trainable_params if n not in decay_parameters],
            _WEIGHT_DECAY_KEY: 0.0,
        },
    ]
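The returned list uses the parameter-group format accepted by `torch.optim` optimizers. An illustrative sketch with a toy model (not part of the module source):

import torch

from oumi.utils.torch_naming_heuristics import group_trainable_params

model = torch.nn.Sequential(
    torch.nn.Linear(16, 32),
    torch.nn.LayerNorm(32),
    torch.nn.Linear(32, 4),
)
# First group: Linear weight matrices (decayed); second group: biases and
# LayerNorm parameters (no decay).
param_groups = group_trainable_params(model, weight_decay=0.01)
optimizer = torch.optim.AdamW(param_groups, lr=1e-4)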
def guess_transformer_layer_cls(model: nn.Module) -> type[nn.Module]:
    """Guess the transformer layer class based on the model architecture."""
    for module in model.modules():
        for layer_pattern in ["layer", "block", "transformerlayer"]:
            layer_name = str(type(module)).lower()
            if layer_pattern in layer_name and "layernorm" not in layer_name:
                return type(module)
    raise ValueError(
        "Unable to guess transformer layer class. Please specify it explicitly."
    )
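A hypothetical sketch of using the guess, e.g. to pick a layer class for FSDP-style auto-wrapping; for a GPT-2 checkpoint the heuristic should match `GPT2Block`, since that class name contains "block":

import transformers

from oumi.utils.torch_naming_heuristics import guess_transformer_layer_cls

model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
layer_cls = guess_transformer_layer_cls(model)  # Expected: GPT2Block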
def resolve_transformer_layer_cls_string_as_module_set(
    class_names: str,
) -> set[type[nn.Module]]:
    """Gets the set of module classes named by a comma-separated string."""
    result: set[type[nn.Module]] = set()
    for class_name in _parse_transformer_layer_cls_string(class_names):
        parts = class_name.rsplit(".", maxsplit=1)
        if len(parts) == 1:
            # Assume `transformers` by default.
            module_name = "transformers"
        else:
            module_name, class_name = parts
        module = importlib.import_module(module_name)
        transformer_cls = getattr(module, class_name)
        result.add(transformer_cls)
    return result
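A usage sketch, assuming the private helper `_parse_transformer_layer_cls_string` (defined elsewhere in this module) splits the string on commas: bare names are looked up on the top-level `transformers` package, while dotted names are imported from the named module.

from oumi.utils.torch_naming_heuristics import (
    resolve_transformer_layer_cls_string_as_module_set,
)

classes = resolve_transformer_layer_cls_string_as_module_set(
    "GPT2Model,transformers.models.gpt2.modeling_gpt2.GPT2Block"
)
# classes == {transformers.GPT2Model,
#             transformers.models.gpt2.modeling_gpt2.GPT2Block}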
def simplify_transformer_layer_cls_string(class_names: str) -> str:
    """Replaces fully-qualified class names with pure class names.

    For example, converts 'foo.Block,foo.util.Decoder' to 'Block,Decoder'.

    The `accelerate` library expects the simplified format, while the OUMI trainer
    requires fully-qualified class names.
    """
    result = []
    for class_name in _parse_transformer_layer_cls_string(class_names):
        parts = class_name.rsplit(".")
        result.append(parts[-1])
    return ",".join(result)
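A matching sketch for the simplifier, under the same comma-splitting assumption about `_parse_transformer_layer_cls_string`:

from oumi.utils.torch_naming_heuristics import simplify_transformer_layer_cls_string

simplified = simplify_transformer_layer_cls_string(
    "transformers.models.gpt2.modeling_gpt2.GPT2Block,GPT2Model"
)
assert simplified == "GPT2Block,GPT2Model"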