# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Copied from the Verl script:
# volcengine/verl/scripts/model_merger.py
# with minor modifications, e.g., `import verl` is wrapped in a try-except block
# and some verl-related imports are moved down into the functions that use them.
"""This script merges verl checkpoints from the FSDP and Megatron backends into a
Hugging Face format model, and can test merged checkpoints against a reference
Hugging Face model.

To merge FSDP checkpoints:
```sh
python scripts/model_merger.py merge \
    --backend fsdp \
    --local_dir checkpoints/verl_fsdp_gsm8k_examples/qwen2_5_0b5_fsdp_saveload/global_step_1/actor \
    --target_dir /path/to/merged_hf_model
```

To merge Megatron checkpoints:
```sh
python scripts/model_merger.py merge \
    --backend megatron \
    --tie-word-embedding \
    --local_dir checkpoints/verl_megatron_gsm8k_examples/qwen2_5_0b5_megatron_saveload/global_step_1/actor \
    --target_dir /path/to/merged_hf_model
```

For more details, please refer to the documentation:
https://verl.readthedocs.io/en/latest/advance/checkpoint.html#convert-fsdp-and-megatron-checkpoints-to-huggingface-format-model
"""

import argparse
import os
import re
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Union

import numpy as np
import torch
from accelerate import init_empty_weights
from safetensors.torch import load_file
from torch.distributed._tensor import Placement, Shard
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForTokenClassification,
    AutoModelForVision2Seq,
    GenerationConfig,
    PretrainedConfig,
)

try:
    # for torch 2.5+
    from torch.distributed.tensor import DTensor
except ImportError:
    from torch.distributed._tensor import DTensor

from tqdm import tqdm

try:
    import verl  # pyright: ignore[reportMissingImports]
except ModuleNotFoundError:
    verl = None

@dataclass
class ModelMergerConfig:
    operation: str  # 'merge' or 'test'
    backend: str
    local_dir: str
    hf_model_config_path: str
    target_dir: Optional[str] = "tmp"
    hf_upload_path: Optional[str] = None
    private: bool = False
    test_hf_dir: Optional[str] = None
    tie_word_embedding: bool = False
    is_value_model: bool = False
    hf_model_path: Optional[str] = None
    hf_upload: bool = field(init=False)

    def __post_init__(self):
        self.hf_upload = self.operation == "merge" and bool(self.hf_upload_path)
        if self.operation == "test":
            self.target_dir = None
            self.hf_upload_path = None
            self.private = False
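
# A hedged usage sketch (paths are hypothetical placeholders): main() below builds
# this config from CLI flags, so the programmatic equivalent of the FSDP merge
# example in the module docstring is roughly:
#
#   config = ModelMergerConfig(
#       operation="merge",
#       backend="fsdp",
#       local_dir="checkpoints/my_run/global_step_1/actor",
#       hf_model_config_path="checkpoints/my_run/global_step_1/actor",
#       target_dir="/path/to/merged_hf_model",
#   )
#   FSDPModelMerger(config).merge_and_save()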

class BaseModelMerger(ABC):
    def __init__(self, config: ModelMergerConfig):
        self.config = config
        self.hf_model_config_path = config.hf_model_config_path

        if config.hf_model_path:
            print(
                "Warning: --hf_model_path is deprecated and will be removed in a future version. "
                "verl now saves the huggingface model configuration files into the checkpoint "
                "directories, so there is no need to provide --hf_model_path."
            )
            self.hf_model_config_path = config.hf_model_path

        self.model_config = AutoConfig.from_pretrained(self.hf_model_config_path)
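
    # Minimal sketch of the helper referenced by save_hf_model_and_tokenizer() and
    # the subclasses' _test_state_dict(); its body was not present here. It assumes
    # the three Auto classes imported at the top of the file and selects one based
    # on the architecture name recorded in the model config.
    def get_transformers_auto_model_class(self):
        architectures = self.model_config.architectures or []
        if any("ForTokenClassification" in arch for arch in architectures):
            return AutoModelForTokenClassification
        if any("ForCausalLM" in arch for arch in architectures):
            return AutoModelForCausalLM
        if any("ForConditionalGeneration" in arch for arch in architectures):
            return AutoModelForVision2Seq
        raise NotImplementedError(f"Unknown architecture {architectures}")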

    def patch_model_generation_config(self, model):
        """The generation_config created from the model config may differ from the
        pretrained model's, which can lead to errors when generating:
        https://github.com/volcengine/verl/issues/1246

        This function patches the generation_config created from the model config
        with the one saved alongside the pretrained model.
        """
        if model.can_generate():
            try:
                model.generation_config = GenerationConfig.from_pretrained(self.hf_model_config_path)
            except OSError:
                print(
                    f"Warning: Generation config file not found in {self.hf_model_config_path}, "
                    "using a generation config created from the model config."
                )
        return model

    def save_hf_model_and_tokenizer(self, state_dict: dict[str, torch.Tensor]):
        if verl is None:
            raise RuntimeError("verl is not installed. Please install it with `pip install oumi[gpu]`.")
        from verl.utils import hf_processor, hf_tokenizer

        auto_model_class = self.get_transformers_auto_model_class()
        with init_empty_weights():
            model = auto_model_class.from_config(self.model_config, torch_dtype=torch.bfloat16)
        model.to_empty(device="cpu")
        model = self.patch_model_generation_config(model)

        print(f"Saving model to {self.config.target_dir}")
        model.save_pretrained(self.config.target_dir, state_dict=state_dict)
        del state_dict
        del model

        processor = hf_processor(self.hf_model_config_path)
        tokenizer = hf_tokenizer(self.hf_model_config_path)
        if processor is not None:
            print(f"Saving processor to {self.config.target_dir}")
            processor.save_pretrained(self.config.target_dir)
        if tokenizer is not None:
            print(f"Saving tokenizer to {self.config.target_dir}")
            tokenizer.save_pretrained(self.config.target_dir)
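
    # Hedged sketch of the uploader invoked by merge_and_save() in both subclasses;
    # its body was not present here. It assumes the standard huggingface_hub.HfApi
    # create_repo/upload_folder calls and the hf_upload_path/private config fields.
    def upload_to_huggingface(self):
        from huggingface_hub import HfApi

        api = HfApi()
        api.create_repo(repo_id=self.config.hf_upload_path, private=self.config.private, exist_ok=True)
        api.upload_folder(
            folder_path=self.config.target_dir,
            repo_id=self.config.hf_upload_path,
            repo_type="model",
        )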

    @abstractmethod
    def merge_and_save(self):
        raise NotImplementedError("Subclasses should implement this method")

class FSDPModelMerger(BaseModelMerger):
    def _get_world_size(self) -> int:
        """Extracts the FSDP world_size from checkpoint filenames (e.g., 'model_world_size_8_rank_0.pt')."""
        for filename in os.listdir(self.config.local_dir):
            match = re.match(r"model_world_size_(\d+)_rank_0\.pt", filename)
            if match:
                return int(match.group(1))
        raise FileNotFoundError(
            rf"Could not determine world size. No file matching "
            rf"'model_world_size_(\d+)_rank_0.pt' found in {self.config.local_dir}"
        )

    def _load_rank_zero_state_dict(self, world_size: int) -> dict:
        return torch.load(
            Path(self.config.local_dir) / f"model_world_size_{world_size}_rank_0.pt",
            map_location="cpu",
            weights_only=False,
        )

    def _extract_device_mesh_info(self, state_dict: dict, world_size: int) -> tuple[np.ndarray, tuple[str, ...]]:
        """Retrieves sharding information (device_mesh, mesh_dim_names) from a DTensor in the state_dict.

        If no DTensor is found, infers a simple FSDP mesh based on world_size.
        """
        pivot_key = sorted(list(state_dict.keys()))[0]
        weight = state_dict[pivot_key]

        if isinstance(weight, DTensor):
            # get sharding info
            device_mesh = weight.device_mesh
            mesh = device_mesh.mesh
            mesh_dim_names = device_mesh.mesh_dim_names
        else:
            # for non-DTensor
            mesh = np.array([world_size], dtype=np.int64)
            mesh_dim_names = ("fsdp",)

        return mesh, mesh_dim_names

    def _calculate_shard_configuration(
        self, mesh: np.ndarray, mesh_dim_names: tuple[str, ...]
    ) -> tuple[int, tuple[int, ...]]:
        """Calculates the total number of shards and the shape of the device mesh."""
        assert mesh_dim_names in (
            ("fsdp",),
            ("ddp", "fsdp"),
        ), f"Unsupported mesh_dim_names {mesh_dim_names}"

        if "tp" in mesh_dim_names:
            # TODO: "tp" is not supported yet due to the above assert
            total_shards = mesh.shape[-1] * mesh.shape[-2]
            mesh_shape = (mesh.shape[-2], mesh.shape[-1])
        else:
            total_shards = mesh.shape[-1]
            mesh_shape = (mesh.shape[-1],)

        return total_shards, mesh_shape

    def _merge_by_placement(self, tensors: list[torch.Tensor], placement: Placement) -> torch.Tensor:
        """Merges a list of tensors based on their DTensor placement."""
        if placement.is_replicate():
            return tensors[0]
        elif placement.is_partial():
            raise NotImplementedError("Partial placement is not supported yet")
        elif placement.is_shard():
            return torch.cat(tensors, dim=placement.dim).contiguous()

        raise NotImplementedError(f"Unsupported placement: {placement}")

    def _load_and_merge_state_dicts(
        self,
        world_size: int,
        total_shards: int,
        mesh_shape: tuple[int, ...],
        mesh_dim_names: tuple[str, ...],
    ) -> dict[str, torch.Tensor]:
        model_state_dict_lst = [None] * total_shards

        def process_one_shard(rank: int, model_state_dict_lst: list):
            model_path = Path(self.config.local_dir) / f"model_world_size_{world_size}_rank_{rank}.pt"
            state_dict = torch.load(model_path, map_location="cpu", weights_only=False)
            model_state_dict_lst[rank] = state_dict
            return state_dict

        with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor:
            futures = [
                executor.submit(process_one_shard, rank, model_state_dict_lst) for rank in range(total_shards)
            ]
            for future in tqdm(futures, desc=f"Loading {total_shards} FSDP shards", total=total_shards):
                future.result()

        # Merge state dicts from all shards
        state_dict = {}
        param_placements: dict[str, list] = {}

        for key in set(model_state_dict_lst[0].keys()):
            state_dict[key] = []
            for model_state_shard in model_state_dict_lst:
                # add tensor shard in order of rank to state_dict[key]
                tensor = model_state_shard.pop(key)
                if isinstance(tensor, DTensor):
                    state_dict[key].append(tensor._local_tensor.bfloat16())

                    placements = tuple(tensor.placements)
                    # replicated placement at dp dimension can be discarded
                    if mesh_dim_names[0] in ("dp", "ddp"):
                        placements = placements[1:]

                    if key not in param_placements:
                        param_placements[key] = placements
                    else:
                        assert param_placements[key] == placements
                else:
                    state_dict[key].append(tensor.bfloat16())

        del model_state_dict_lst

        # Merge tensors
        for key in sorted(state_dict):
            if not isinstance(state_dict[key], list):
                print(f"No need to merge key {key}")
                continue
            if key in param_placements:
                # merge shards
                placements: tuple[Shard] = param_placements[key]
                if len(mesh_shape) == 1:
                    # 1-D list, FSDP without TP
                    assert len(placements) == 1
                    shards = state_dict[key]
                    state_dict[key] = self._merge_by_placement(shards, placements[0])
                else:
                    # 2-D list, FSDP + TP
                    raise NotImplementedError("FSDP + TP is not supported yet")
            else:
                state_dict[key] = torch.cat(state_dict[key], dim=0)

        return state_dict

    def merge_and_save(self):
        world_size = self._get_world_size()
        rank_zero_state_dict = self._load_rank_zero_state_dict(world_size)

        mesh, mesh_dim_names = self._extract_device_mesh_info(rank_zero_state_dict, world_size)
        print(f"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}")

        total_shards, mesh_shape = self._calculate_shard_configuration(mesh, mesh_dim_names)
        print(f"Processing {total_shards} model shards in total, mesh shape {mesh_shape}")

        merged_state_dict = self._load_and_merge_state_dicts(world_size, total_shards, mesh_shape, mesh_dim_names)

        if self.config.operation == "test":
            if not self.config.test_hf_dir:
                raise ValueError("test_hf_dir must be provided for test operation")
            self._test_state_dict(merged_state_dict)
        elif self.config.operation == "merge":
            self.save_hf_model_and_tokenizer(merged_state_dict)
            if self.config.hf_upload:
                self.upload_to_huggingface()
        else:
            raise ValueError(f"Unknown operation: {self.config.operation}")

    def _test_state_dict(self, state_dict: dict[str, torch.Tensor]):
        auto_model_class = self.get_transformers_auto_model_class()
        hf_model = auto_model_class.from_pretrained(self.config.test_hf_dir, torch_dtype=torch.bfloat16)
        hf_state_dict = hf_model.state_dict()
        del hf_model

        hf_model_keys = set(hf_state_dict.keys())
        collected_keys = set(state_dict.keys())

        missing_keys = hf_model_keys - collected_keys
        assert len(missing_keys) == 0, f"Missing keys in collected state dict: {sorted(missing_keys)}"

        extra_keys = collected_keys - hf_model_keys
        assert len(extra_keys) == 0, f"Extra keys in collected state dict: {sorted(extra_keys)}"

        for key in hf_model_keys:
            hf_shape = hf_state_dict[key].shape
            collected_shape = state_dict[key].shape
            assert hf_shape == collected_shape, (
                f"Shape mismatch for key '{key}': original {hf_shape} vs collected {collected_shape}"
            )

            hf_dtype = hf_state_dict[key].dtype
            collected_dtype = state_dict[key].dtype
            assert hf_dtype == collected_dtype, (
                f"Dtype mismatch for key '{key}': original {hf_dtype} vs collected {collected_dtype}"
            )

            torch.testing.assert_close(hf_state_dict[key], state_dict[key], atol=1e-6, rtol=1e-6)

        print("FSDP checks passed: the merged state_dict matches the HF model saved by FSDPCheckpointManager.")

class MegatronModelMerger(BaseModelMerger):
    def __init__(self, config: ModelMergerConfig):
        from verl.utils.megatron_utils import get_hf_config_and_tokenizer_checkpoint_path

        config.hf_model_config_path = get_hf_config_and_tokenizer_checkpoint_path(config.local_dir)
        super().__init__(config)

    def _get_tp_pp_rank_from_sharded_dir(self, sharded_dir: str) -> tuple[int, int]:
        match = re.match(r"mp_rank_(\d\d)_(\d\d\d)", sharded_dir)
        assert match, f"Invalid sharded dir {sharded_dir}"
        tp_rank = int(match.group(1))
        pp_rank = int(match.group(2))
        return tp_rank, pp_rank

    def _check_megatron_checkpoint_path(self, model_path: str) -> tuple[list[str], int, int]:
        """Validates the Megatron checkpoint structure (presence of 'model.pt' in sharded directories).

        Determines TP and PP sizes from directory names.
        """
        tp_size = 0
        pp_size = 0
        sharded_dirs = sorted(os.listdir(model_path))
        for sharded_dir in sharded_dirs:
            assert "model.pt" in os.listdir(Path(model_path) / sharded_dir), f"model.pt not found in {sharded_dir}"
            tp_rank, pp_rank = self._get_tp_pp_rank_from_sharded_dir(sharded_dir)
            tp_size = max(tp_size, tp_rank + 1)
            pp_size = max(pp_size, pp_rank + 1)
        return sharded_dirs, tp_size, pp_size

    def _merge_across_tp(
        self,
        key: str,
        tp_data: list[torch.Tensor],
        config: PretrainedConfig,
        tp_size: int,
        is_value_model: bool = False,
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        if "linear_fc1.weight" in key:
            # the tensor holds fused gate and up projections: split each TP shard,
            # then concatenate the gate parts and the up parts separately
            gate_lst = []
            up_lst = []
            for infer_param in tp_data:
                gate, up = infer_param.chunk(2)
                gate_lst.append(gate)
                up_lst.append(up)
            gate = torch.cat(gate_lst, dim=0)
            up = torch.cat(up_lst, dim=0)
            return [gate, up]

        elif "self_attention.linear_qkv." in key and "layer_norm" not in key:
            # the tensor is fused qkv: split each TP shard into q, k, v,
            # then concatenate q, k, v separately
            q_lst = []
            k_lst = []
            v_lst = []
            assert config.num_attention_heads % config.num_key_value_heads == 0
            num_q_per_kv = config.num_attention_heads // config.num_key_value_heads
            assert tp_data[0].shape[0] % (num_q_per_kv + 2) == 0
            kv_size_per_tp = tp_data[0].shape[0] // (num_q_per_kv + 2)

            for infer_param in tp_data:
                num_query_groups_per_partition = config.num_key_value_heads // tp_size
                for chunk in infer_param.chunk(num_query_groups_per_partition):
                    split_size = [
                        kv_size_per_tp * num_q_per_kv // num_query_groups_per_partition,
                        kv_size_per_tp // num_query_groups_per_partition,
                        kv_size_per_tp // num_query_groups_per_partition,
                    ]
                    q, k, v = chunk.split(split_size)
                    q_lst.append(q)
                    k_lst.append(k)
                    v_lst.append(v)

            q = torch.cat(q_lst, dim=0)
            k = torch.cat(k_lst, dim=0)
            v = torch.cat(v_lst, dim=0)
            return [q, k, v]

        elif "layer_norm" in key or "layernorm" in key or ("output_layer" in key and is_value_model):
            return tp_data[0]
        else:
            dim = 0
            if "linear_fc2.weight" in key or "self_attention.linear_proj" in key:
                dim = 1
            return torch.cat(tp_data, dim=dim)

    def _load_state_dicts(
        self, model_ckpt_path: str, sharded_dirs: list[str], tp_size: int, pp_size: int
    ) -> list[list[dict]]:
        model_state_dict_lst = [[None for _ in range(tp_size)] for _ in range(pp_size)]

        def _process_one_megatron_shard(sharded_dir: str):
            model_file_path = Path(model_ckpt_path) / sharded_dir / "model.pt"
            state_dict = torch.load(model_file_path, map_location="cpu", weights_only=False)
            tp_rank, pp_rank = self._get_tp_pp_rank_from_sharded_dir(sharded_dir)
            model_state_dict_lst[pp_rank][tp_rank] = state_dict

        with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor:
            futures = [executor.submit(_process_one_megatron_shard, sharded_dir) for sharded_dir in sharded_dirs]
            for future in tqdm(
                futures,
                desc=f"Loading {len(sharded_dirs)} Megatron shards",
                total=len(sharded_dirs),
            ):
                future.result()

        return model_state_dict_lst

    def _merge_state_dicts(
        self, model_state_dict_lst: list[list[dict]], tp_size: int, pp_size: int
    ) -> dict[str, torch.Tensor]:
        state_dict = {}
        vpp_size = len(model_state_dict_lst[0][0])
        layers_cum = 0

        for vpp_rank in range(vpp_size):
            for pp_rank in range(pp_size):
                layers_handled = 0
                keys = model_state_dict_lst[pp_rank][0][vpp_rank].keys()
                for key in keys:
                    if "extra_state" in key:
                        continue
                    if self.config.tie_word_embedding and ("output_layer" in key):
                        print("skip lm_head and reward_head loading because of tie_word_embeddings")
                        continue

                    new_key = key
                    if "decoder.layers." in key:
                        local_layer_no = int(key.split(".")[2])
                        layers_handled = max(local_layer_no, layers_handled)
                        global_layer_no = local_layer_no + layers_cum
                        new_key_list = key.split(".")
                        new_key_list[2] = str(global_layer_no)
                        new_key = ".".join(new_key_list)

                    tp_data = [model_state_dict_lst[pp_rank][tp_rank][vpp_rank][key] for tp_rank in range(tp_size)]
                    merged = self._merge_across_tp(
                        new_key,
                        tp_data,
                        self.model_config,
                        tp_size,
                        self.config.is_value_model,
                    )
                    if not isinstance(merged, list):
                        state_dict[new_key] = merged
                    elif len(merged) == 3:
                        # split qkv
                        for n, d in zip(["q", "k", "v"], merged):
                            state_dict[new_key.replace("linear_qkv", f"linear_{n}")] = d
                    elif len(merged) == 2:
                        # split gate up
                        state_dict[new_key.replace("linear_fc1", "gate_proj")] = merged[0]
                        state_dict[new_key.replace("linear_fc1", "up_proj")] = merged[1]

                layers_cum += layers_handled + 1  # layer numbering is zero-based

        return state_dict

    def merge_and_save(self):
        from verl.utils.megatron_utils import get_model_checkpoint_path

        model_ckpt_path = get_model_checkpoint_path(self.config.local_dir)
        sharded_dirs, tp_size, pp_size = self._check_megatron_checkpoint_path(model_ckpt_path)
        print(
            f"sharded_dirs: {sharded_dirs}, tp_size: {tp_size}, pp_size: {pp_size}, mp_size: {len(sharded_dirs)}"
        )

        model_state_dict_lst = self._load_state_dicts(model_ckpt_path, sharded_dirs, tp_size, pp_size)
        merged_state_dict = self._merge_state_dicts(model_state_dict_lst, tp_size, pp_size)
        del model_state_dict_lst

        if self.config.operation == "test":
            if not self.config.test_hf_dir:
                raise ValueError("test_hf_dir must be provided for test operation")
            self._test_state_dict(merged_state_dict)
        elif self.config.operation == "merge":
            self.save_hf_model_and_tokenizer(merged_state_dict)
            if self.config.hf_upload:
                self.upload_to_huggingface()
        else:
            raise ValueError(f"Unknown operation: {self.config.operation}")

    def _test_state_dict(self, state_dict: dict[str, torch.Tensor]):
        """Compares the merged Megatron state_dict against a reference safetensors model.

        Applies the necessary name mappings from Megatron to Hugging Face conventions
        using _replace_name.
        """
        ref_state_dict = load_file(Path(self.config.test_hf_dir) / "model.safetensors")

        params_mapping = [
            # (megatron core gpt model name, vllm model name)
            ("self_attention.linear_qkv.layer_norm_weight", "input_layernorm.weight"),
            ("self_attention.linear_qkv.layer_norm_bias", "input_layernorm.bias"),
            ("embedding.word_embeddings", "model.embed_tokens"),
            ("self_attention.linear_qkv", "self_attn.qkv_proj"),
            ("self_attention.linear_proj", "self_attn.o_proj"),
            ("pre_mlp_layernorm", "post_attention_layernorm"),
            ("mlp.linear_fc1.layer_norm_weight", "post_attention_layernorm.weight"),
            ("mlp.linear_fc1.layer_norm_bias", "post_attention_layernorm.bias"),
            ("mlp.linear_fc1", "mlp.gate_up_proj"),
            ("mlp.linear_fc2", "mlp.down_proj"),
            ("decoder.final_layernorm", "model.norm"),
            ("output_layer", "lm_head"),
            ("self_attention.linear_q", "self_attn.q_proj"),
            ("self_attention.linear_k", "self_attn.k_proj"),
            ("self_attention.linear_v", "self_attn.v_proj"),
        ]

        for original_name, loaded_weight in state_dict.items():
            name = self._replace_name(original_name, params_mapping)
            if not name or (name.endswith(".bias") and name not in ref_state_dict):
                continue
            if "rotary_emb.inv_freq" in name:
                continue
            if self.config.tie_word_embedding and "lm_head.weight" in name:
                continue
            if name not in ref_state_dict:
                raise RuntimeError(f"key: {name} does not exist in ref_state_dict")
            param = ref_state_dict[name]
            assert loaded_weight.dtype == param.dtype
            torch.testing.assert_close(loaded_weight, param, atol=1e-2, rtol=5e-2)

    def _replace_name(self, megatron_name: str, name_mapping: list[tuple[str, str]]) -> Optional[str]:
        for m_name, v_name in name_mapping:
            if m_name not in megatron_name:
                continue
            if "layers" in megatron_name:
                # deal with decoder layers, e.g.
                # "decoder.layers.0.self_attention.linear_proj.weight"
                #   -> "model.layers.0.self_attn.o_proj.weight"
                megatron_name = megatron_name.replace("decoder", "model")
                megatron_name_list = megatron_name.split(".")
                if "layer_norm_weight" in megatron_name_list or "layer_norm_bias" in megatron_name_list:
                    param_name_list = megatron_name_list[:3]
                    param_name_list.append(v_name)
                    param_name = ".".join(param_name_list)
                else:
                    param_name_list = megatron_name_list[:3]
                    weight_or_bias = megatron_name_list[-1]
                    param_name_list.append(v_name)
                    param_name_list.append(weight_or_bias)
                    param_name = ".".join(param_name_list)
                return param_name
            else:
                param_name = megatron_name.replace(m_name, v_name)
                return param_name
        return None  # Return None if no mapping found

def main():
    parser = argparse.ArgumentParser(description="verl model merger")
    subparsers = parser.add_subparsers(dest="operation", required=True, help="Specify 'merge' or 'test' operation.")

    base_op_parser = argparse.ArgumentParser(add_help=False)
    base_op_parser.add_argument(
        "--backend",
        type=str,
        required=True,
        choices=["fsdp", "megatron"],
        help="The backend of the model",
    )
    base_op_parser.add_argument(
        "--local_dir",
        type=str,
        required=True,
        help="Path to the saved model checkpoints",
    )
    base_op_parser.add_argument(
        "--hf_model_path",
        type=str,
        default=None,
        help="(Deprecated) Path to the original Hugging Face model for config.",
    )
    base_op_parser.add_argument(
        "--tie-word-embedding",
        action="store_true",
        help="Whether to tie word embedding weights (currently only Megatron supported)",
    )
    base_op_parser.add_argument(
        "--is-value-model",
        action="store_true",
        help="Whether the model is a value model (currently only Megatron supported)",
    )

    merge_parser = subparsers.add_parser("merge", parents=[base_op_parser], help="Merge model checkpoints and save.")
    merge_parser.add_argument(
        "--target_dir",
        default="tmp",
        type=str,
        help="Directory to save the merged huggingface model",
    )
    merge_parser.add_argument(
        "--hf_upload_path",
        default=None,
        type=str,
        help="Hugging Face repository ID to upload the model",
    )
    merge_parser.add_argument(
        "--private",
        action="store_true",
        help="Whether to upload the model to a private Hugging Face repository",
    )

    test_parser = subparsers.add_parser(
        "test",
        parents=[base_op_parser],
        help="Test merged model against a reference Hugging Face model",
    )
    test_parser.add_argument(
        "--test_hf_dir",
        type=str,
        required=True,
        help="Path to the reference Hugging Face model directory for testing",
    )

    args = parser.parse_args()

    common_config_args = {
        "operation": args.operation,
        "backend": args.backend,
        "tie_word_embedding": args.tie_word_embedding,
        "is_value_model": args.is_value_model,
        "local_dir": args.local_dir,
        "hf_model_path": args.hf_model_path,
        "hf_model_config_path": args.local_dir,
    }

    if args.operation == "merge":
        config = ModelMergerConfig(
            **common_config_args,
            target_dir=args.target_dir,
            hf_upload_path=args.hf_upload_path,
            private=args.private,
            test_hf_dir=None,
        )
        os.makedirs(config.target_dir, exist_ok=True)
    elif args.operation == "test":
        config = ModelMergerConfig(
            **common_config_args,
            test_hf_dir=args.test_hf_dir,
            # the following args are not used by the test operation
            target_dir=None,
            hf_upload_path=None,
            private=False,
        )
    else:
        raise NotImplementedError(f"Unknown operation: {args.operation}")

    if config.backend == "fsdp":
        merger = FSDPModelMerger(config)
    elif config.backend == "megatron":
        merger = MegatronModelMerger(config)
    else:
        raise NotImplementedError(f"Unknown backend: {config.backend}")

    merger.merge_and_save()
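
# Standard entry-point guard, assumed from the CLI invocations shown in the module
# docstring (e.g., `python scripts/model_merger.py merge ...`), which require main()
# to run when the script is executed directly.
if __name__ == "__main__":
    main()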