Source code for oumi.core.callbacks.bitnet_callback
# Copyright 2025 - Oumi## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License."""Simple Bitnet model saving callback."""frompathlibimportPathfromtypingimportOptional,Unionimporttransformersfromtransformers.trainer_utilsimportPREFIX_CHECKPOINT_DIRfromoumi.core.callbacks.base_trainer_callbackimportBaseTrainerCallbackfromoumi.core.configsimportTrainingParams# Import `onebitllms` utils methodstry:importonebitllms# type: ignorefromonebitllmsimportquantize_to_1bit# type: ignoreexceptImportError:onebitllms=None
class BitNetCallback(BaseTrainerCallback):
    """BitNet model saving callback.

    Simple callback that saves the model into BitNet quantized format during
    training.
    """

    def on_save(
        self,
        args: Union[transformers.TrainingArguments, TrainingParams],
        state: Optional[transformers.TrainerState] = None,
        control: Optional[transformers.TrainerControl] = None,
        **kwargs,
    ):
        """Saving callback.

        Gets triggered at each saving step to quantize trained models in 1bit
        precision. The checkpoint directory for the current step is passed to
        ``quantize_to_1bit`` together with a sibling ``...-quantized``
        directory as the destination.

        Args:
            args: Training arguments; only ``output_dir`` is read.
            state: Trainer state; ``global_step`` selects the checkpoint
                directory to quantize. Required despite the ``None`` default.
            control: Not used by this callback.
            **kwargs: Not used by this callback.

        Raises:
            ValueError: If ``onebitllms`` is not installed, or if ``state`` is
                ``None``.
        """
        if onebitllms is None:
            raise ValueError(
                "You need `onebitllms` to be installed in order to save "
                "correctly BitNet models - `pip install onebitllms`"
            )
        # The signature allows `state=None`, but the checkpoint name cannot be
        # derived without the step counter. Fail with an explicit message
        # rather than an AttributeError on `None.global_step`.
        if state is None:
            raise ValueError(
                "`state` is required to locate the checkpoint to quantize, "
                "but got None."
            )

        output_dir = Path(args.output_dir)  # type: ignore
        # Build the step-specific directory name once; the quantized output
        # uses the same name with a "-quantized" suffix.
        checkpoint_name = f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
        checkpoint_folder = output_dir / checkpoint_name
        quantized_checkpoint_folder = output_dir / f"{checkpoint_name}-quantized"

        # NOTE(review): assumes the trainer has already written
        # `checkpoint_folder` by the time this hook runs - confirm.
        quantize_to_1bit(str(checkpoint_folder), str(quantized_checkpoint_folder))