mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
36 lines
1.1 KiB
Python
36 lines
1.1 KiB
Python
#!/usr/bin/env python
|
|
# encoding: utf-8
|
|
from collections import defaultdict
|
|
from glob import glob
|
|
|
|
import fire
|
|
import torch
|
|
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
|
|
|
|
|
def main(fsdp_checkpoint_path, huggingface_model_path, output_path):
    """Merge sharded FSDP checkpoint files into a single HuggingFace checkpoint.

    Args:
        fsdp_checkpoint_path: Directory containing shard files named
            ``model_world_size_{W}_rank_{R}.pt``.
        huggingface_model_path: HF model name/path supplying the config and tokenizer.
        output_path: Directory where the merged model and tokenizer are saved.

    Raises:
        FileNotFoundError: If no shard files matching the expected pattern exist.
    """
    state_dict = defaultdict(list)

    # Infer the world size from the shard files present instead of hard-coding
    # it, so checkpoints saved with any number of ranks are handled.
    shard_files = glob(f"{fsdp_checkpoint_path}/model_world_size_*_rank_*.pt")
    if not shard_files:
        raise FileNotFoundError(
            f"no FSDP shards matching 'model_world_size_*_rank_*.pt' "
            f"in {fsdp_checkpoint_path}"
        )
    world_size = len(shard_files)

    for rank in range(world_size):
        filepath = f"{fsdp_checkpoint_path}/model_world_size_{world_size}_rank_{rank}.pt"
        print("loading", filepath)
        # map_location="cpu" lets the merge run on machines without the GPUs
        # the checkpoint was originally saved from.
        this_state_dict = torch.load(filepath, map_location="cpu")
        for key, value in this_state_dict.items():
            # Each value is a sharded tensor; to_local() extracts this rank's slice.
            state_dict[key].append(value.to_local())

    # Reassemble each full parameter by concatenating its per-rank shards.
    # NOTE(review): assumes FSDP sharded every parameter along dim 0 — holds
    # for flat-sharded checkpoints, verify for other sharding strategies.
    for key in state_dict:
        state_dict[key] = torch.cat(state_dict[key], dim=0)

    # Rebuild the model skeleton from the HF config, then fill in the merged weights.
    config = AutoConfig.from_pretrained(huggingface_model_path)
    model = AutoModelForCausalLM.from_config(config)
    model.load_state_dict(state_dict)

    model.save_pretrained(output_path, max_shard_size="10GB")

    # Ship the tokenizer alongside the weights so output_path is self-contained.
    tokenizer = AutoTokenizer.from_pretrained(huggingface_model_path)
    tokenizer.save_pretrained(output_path)
|
|
|
|
|
|
if __name__ == "__main__":
    # Expose main() as a CLI via python-fire:
    #   python <script> <fsdp_checkpoint_path> <huggingface_model_path> <output_path>
    fire.Fire(main)
|