# Copyright 2025 - Oumi## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.importdataclassesimportjsonfromtypingimportAnyimportnumpyasnpimporttorchfromoumi.utils.loggingimportloggerJSON_FILE_INDENT=2
def default(self, obj):
    """Converts objects the stock JSON encoder rejects into serializable values.

    ``json.JSONEncoder`` calls this hook for every object it cannot encode
    natively; the returned value is encoded in its place.

    Args:
        obj: The object that could not be encoded natively.

    Returns:
        ``""`` for ``None``; ``str(obj)`` for a ``torch.dtype``; a Python
        ``bool``/``int``/``float`` for numpy scalars; a (nested) list for a
        numpy array; otherwise whatever the parent encoder returns, falling
        back to ``str(obj)`` (with a warning) when the parent raises.
    """
    # NOTE: the stock encoder writes ``None`` as ``null`` on its own and does
    # not pass it to ``default``; this branch only affects direct calls.
    if obj is None:
        return ""
    # JSON does NOT natively support any torch types.
    if isinstance(obj, torch.dtype):
        return str(obj)
    # JSON does NOT natively support numpy types.
    # ``np.bool_`` is not a subclass of ``np.integer``; without this branch it
    # would reach the fallback and be written as the string "True"/"False".
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    try:
        return super().default(obj)
    except Exception:
        # Best effort: degrade to a string so the surrounding dump succeeds.
        logger.warning(f"Non-serializable value `{obj}` of type `{type(obj)}`.")
        return str(obj)
def convert_all_keys_to_serializable_types(dictionary: dict) -> None:
    """Converts all keys in a hierarchical dictionary to serializable types.

    The dictionary is modified in place. Keys whose exact type is ``str``,
    ``int``, ``float``, ``bool`` or ``NoneType`` are kept; every other key is
    replaced by ``str(key)``. Nested dictionaries are processed recursively,
    including dictionaries that appear inside list or tuple values.

    Args:
        dictionary: The dictionary whose keys are converted in place.
    """
    # ``type(key)`` of a ``None`` key is ``NoneType``, so the set must hold
    # ``type(None)`` rather than the value ``None`` for such keys to be kept.
    # Exact-type membership (not ``isinstance``) means subclasses such as
    # enum members are stringified as well.
    serializable_key_types = {str, int, float, bool, type(None)}
    non_serializable_keys = [
        key for key in dictionary if type(key) not in serializable_key_types
    ]
    for key in non_serializable_keys:
        dictionary[str(key)] = dictionary.pop(key)

    def _visit(value) -> None:
        # Recurse into nested dicts, also looking through lists/tuples since
        # ``json.dumps`` would otherwise still fail on keys of dicts in them.
        if isinstance(value, dict):
            convert_all_keys_to_serializable_types(value)
        elif isinstance(value, (list, tuple)):
            for item in value:
                _visit(item)

    for value in dictionary.values():
        _visit(value)
def flatten_config(
    config: Any, prefix: str = "", separator: str = "."
) -> dict[str, Any]:
    """Flattens a nested config object into a flat dictionary with dot notation keys.

    Nested dicts and dataclass instances are expanded recursively, joining key
    segments with ``separator``. A list/tuple whose first element is a dict or
    dataclass instance is expanded element by element, using each element's
    index as a key segment; any other list/tuple is stored as ``str(value)``.
    ``str``/``int``/``float``/``bool``/``None`` leaves are kept as-is and any
    other leaf is stored as ``str(value)``. Empty nested dicts produce no keys.

    Args:
        config: The config object to flatten (dataclass instance, dict, or
            other). Anything else yields ``{prefix or "value": str(config)}``.
        prefix: The prefix to prepend to keys.
        separator: The separator to use between nested keys.

    Returns:
        A flattened dictionary with string keys.

    Examples:
        >>> config = TrainingConfig(
        ...     model=ModelParams(
        ...         model_name="gpt2",
        ...     ),
        ...     training=TrainingParams(
        ...         batch_size=16,
        ...     ),
        ... )
        >>> flatten_config(config)  # other fields omitted
        {'model.model_name': 'gpt2', 'training.batch_size': 16, ...}
    """
    if dataclasses.is_dataclass(config) and not isinstance(config, type):
        config_dict = dataclasses.asdict(config)
    elif isinstance(config, dict):
        config_dict = config
    else:
        # For non-dict/dataclass objects, convert to string representation.
        return {prefix or "value": str(config)}

    def _is_nested(obj: Any) -> bool:
        # Dicts and dataclass *instances* (not dataclass classes) are expanded.
        return isinstance(obj, dict) or (
            dataclasses.is_dataclass(obj) and not isinstance(obj, type)
        )

    flattened: dict[str, Any] = {}
    for key, value in config_dict.items():
        # Format the key even without a prefix so non-string top-level keys
        # (e.g. ints) become strings, exactly like nested keys already do.
        new_key = f"{prefix}{separator}{key}" if prefix else f"{key}"
        if _is_nested(value):
            flattened.update(flatten_config(value, new_key, separator))
        elif isinstance(value, (list, tuple)):
            # Only the first element decides whether the sequence is expanded.
            if value and _is_nested(value[0]):
                for i, item in enumerate(value):
                    flattened.update(
                        flatten_config(item, f"{new_key}{separator}{i}", separator)
                    )
            else:
                flattened[new_key] = str(value)
        elif isinstance(value, (str, int, float, bool)) or value is None:
            flattened[new_key] = value
        else:
            flattened[new_key] = str(value)
    return flattened
def json_serializer(obj: Any) -> str:
    """Serializes a Python obj to a JSON formatted string.

    Non-serializable dictionary keys are converted with
    ``convert_all_keys_to_serializable_types``. Values are encoded with
    ``TorchJsonEncoder`` and indented by ``JSON_FILE_INDENT`` spaces. A
    ``dict`` argument is left unmodified: keys are converted on a copy of its
    dict/list/tuple structure (leaf values are shared, not copied).

    Args:
        obj: A dataclass instance or a dictionary.

    Returns:
        The JSON-formatted string.

    Raises:
        ValueError: If ``obj`` is neither a dataclass instance nor a dict.
        Exception: If JSON encoding fails; the original error is chained as
            the cause.
    """

    def _copy_containers(value: Any) -> Any:
        # Rebuild dicts/lists/tuples so key conversion below cannot rewrite
        # the caller's (possibly nested) dictionaries; leaves are reused as-is.
        if isinstance(value, dict):
            return {key: _copy_containers(item) for key, item in value.items()}
        if isinstance(value, list):
            return [_copy_containers(item) for item in value]
        if isinstance(value, tuple):
            return tuple(_copy_containers(item) for item in value)
        return value

    if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
        # ``asdict`` already returns freshly built containers.
        dict_to_serialize = dataclasses.asdict(obj)
    elif isinstance(obj, dict):
        dict_to_serialize = _copy_containers(obj)
    else:
        raise ValueError(f"Cannot serialize object of type {type(obj)} to JSON.")

    # Ensure all (nested) dictionary keys are serializable.
    convert_all_keys_to_serializable_types(dict_to_serialize)

    # Attempt to serialize the dictionary to JSON.
    try:
        return json.dumps(
            dict_to_serialize, cls=TorchJsonEncoder, indent=JSON_FILE_INDENT
        )
    except Exception as e:
        error_str = "Non-serializable dict:\n"
        for key, value in dict_to_serialize.items():
            error_str += f" - {key}: {value} (type: {type(value)})\n"
        logger.error(error_str)
        raise Exception(f"Failed to serialize dict to JSON: {e}") from e