mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
basic changes
This commit is contained in:
parent
14ebf7a492
commit
80d2608c4e
2 changed files with 105 additions and 36 deletions
|
|
@ -195,8 +195,9 @@ def register_trainer(config: TrainingConfig):
|
|||
response = requests.post(
|
||||
"http://localhost:8000/register",
|
||||
json={
|
||||
"wandb_group": config.wandb_group,
|
||||
"wandb_project": config.wandb_project,
|
||||
# wandb fields are required strings - use empty string if None
|
||||
"wandb_group": config.wandb_group or "",
|
||||
"wandb_project": config.wandb_project or "",
|
||||
"batch_size": config.batch_size * config.gradient_accumulation_steps,
|
||||
"max_token_len": config.seq_len,
|
||||
"starting_step": 0,
|
||||
|
|
@ -1103,9 +1104,9 @@ def train_shared_vllm(config: TrainingConfig):
|
|||
model, tokenizer = load_model_and_tokenizer(config, bridge=bridge)
|
||||
optimizer = AdamW(model.parameters(), lr=config.lr)
|
||||
|
||||
# For NCCL mode, set param list from trainer's model
|
||||
# For NCCL mode, build mapping between trainer's and vLLM's param names
|
||||
if config.use_shared_memory:
|
||||
bridge.set_param_list_from_model(model)
|
||||
bridge.build_param_mapping(model)
|
||||
|
||||
print(f"[3/3] Starting training for {config.training_steps} steps")
|
||||
print("NOTE: vLLM sees weight updates immediately after each step!")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue