mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
Feat/unsloth example (#482)
* cleaned up examples * updated failing hooks * updated readme * corrected linting checks
This commit is contained in:
parent
d9cd20c174
commit
1c98584f28
29 changed files with 122 additions and 2857 deletions
2
examples/OpenRLHF/.gitignore
vendored
2
examples/OpenRLHF/.gitignore
vendored
|
|
@ -1,2 +0,0 @@
|
|||
checkpoint/
|
||||
wandb/
|
||||
|
|
@ -1,736 +0,0 @@
|
|||
# tested with OpenRLHF 707b970e992154952a91607ca5491cc49b8665c3
|
||||
|
||||
import argparse
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from openrlhf.datasets import PromptDataset, SFTDataset
|
||||
from openrlhf.models import Actor, get_llm_for_sequence_regression
|
||||
from openrlhf.models.utils import compute_approx_kl, masked_mean
|
||||
from openrlhf.trainer import PPOTrainer
|
||||
from openrlhf.trainer.ppo_utils.experience_maker import Experience, NaiveExperienceMaker, Samples
|
||||
from openrlhf.utils import blending_datasets, get_strategy, get_tokenizer
|
||||
from openrlhf.utils.logging_utils import init_logger
|
||||
from torch.utils.data import Dataset
|
||||
from transformers.trainer import get_scheduler
|
||||
|
||||
import reasoning_gym
|
||||
from reasoning_gym.dataset import ProceduralDataset
|
||||
from reasoning_gym.utils import extract_answer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
DEBUG = False
|
||||
|
||||
|
||||
def preprocess_data(data, input_template=None, input_key="input", apply_chat_template=None) -> str:
|
||||
if apply_chat_template:
|
||||
chat = data[input_key]
|
||||
if isinstance(chat, str):
|
||||
chat = [{"role": "user", "content": chat}]
|
||||
prompt = apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
|
||||
else:
|
||||
prompt = data[input_key]
|
||||
if input_template:
|
||||
prompt = input_template.format(prompt)
|
||||
return prompt
|
||||
|
||||
|
||||
class ReasoningGymDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
dataset,
|
||||
tokenizer,
|
||||
developer_prompt: Optional[str] = None,
|
||||
developer_role: str = "system",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.dataset = dataset
|
||||
self.tokenizer = tokenizer
|
||||
self.developer_prompt = developer_prompt
|
||||
self.developer_role = developer_role
|
||||
|
||||
def __len__(self):
|
||||
length = len(self.dataset)
|
||||
return length
|
||||
|
||||
def __getitem__(self, idx: int) -> tuple[str, dict]:
|
||||
x = self.dataset[idx]
|
||||
|
||||
q = x["question"]
|
||||
chat = []
|
||||
if self.developer_prompt is not None:
|
||||
chat.append({"role": self.developer_role, "content": self.developer_prompt})
|
||||
chat.append({"role": "user", "content": q})
|
||||
prompt = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
|
||||
|
||||
return prompt, x
|
||||
|
||||
|
||||
@dataclass
|
||||
class SamplesWithMetadata(Samples):
|
||||
metadata: Optional[dict]
|
||||
|
||||
|
||||
class AlgorithmicRewardExperienceMaker(NaiveExperienceMaker):
|
||||
def __init__(
|
||||
self,
|
||||
dataset: ProceduralDataset,
|
||||
actor: Actor,
|
||||
critic: nn.Module,
|
||||
reward_model: nn.Module,
|
||||
initial_model: Actor,
|
||||
tokenizer,
|
||||
prompt_max_len: int,
|
||||
kl_controller,
|
||||
strategy=None,
|
||||
remote_rm_url: str = None,
|
||||
reward_fn=None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
actor=actor,
|
||||
critic=critic,
|
||||
reward_model=reward_model,
|
||||
initial_model=initial_model,
|
||||
tokenizer=tokenizer,
|
||||
prompt_max_len=prompt_max_len,
|
||||
kl_controller=kl_controller,
|
||||
strategy=strategy,
|
||||
remote_rm_url=remote_rm_url,
|
||||
reward_fn=reward_fn,
|
||||
)
|
||||
self.dataset = dataset
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_samples(self, all_prompts: list[tuple[str, Any]], **generate_kwargs) -> list[Samples]:
|
||||
"""
|
||||
Generate samples and return in batches.
|
||||
"""
|
||||
assert not getattr(self, "packing_samples", False)
|
||||
args = self.strategy.args
|
||||
self.actor.eval()
|
||||
|
||||
# prepare inputs to sample multiple response
|
||||
repeated_prompts = []
|
||||
repeated_metadata = []
|
||||
for prompt, metadata in all_prompts:
|
||||
for _ in range(args.n_samples_per_prompt):
|
||||
repeated_prompts.append(prompt)
|
||||
repeated_metadata.append(metadata)
|
||||
|
||||
samples_list = []
|
||||
for i in range(0, len(repeated_prompts), args.micro_rollout_batch_size):
|
||||
prompts = repeated_prompts[i : i + args.micro_rollout_batch_size]
|
||||
metadata = repeated_metadata[i : i + args.micro_rollout_batch_size]
|
||||
inputs = self.tokenize_fn(prompts, self.prompt_max_len, device="cuda")
|
||||
sequences, attention_mask, action_mask = self.actor.generate(**inputs, **generate_kwargs)
|
||||
samples = SamplesWithMetadata(
|
||||
sequences=sequences,
|
||||
attention_mask=attention_mask,
|
||||
action_mask=action_mask,
|
||||
num_actions=action_mask.size(1),
|
||||
packed_seq_lens=None,
|
||||
response_length=action_mask.float().sum(dim=-1),
|
||||
total_length=attention_mask.float().sum(dim=-1),
|
||||
metadata=metadata,
|
||||
)
|
||||
samples_list.append(samples)
|
||||
return samples_list
|
||||
|
||||
@torch.no_grad()
|
||||
def make_experience(self, samples: Samples) -> Experience:
|
||||
"""
|
||||
Turn samples into experience by calculating logprobs, values, rewards, and kl divergence.
|
||||
"""
|
||||
self.actor.eval()
|
||||
self.initial_model.eval()
|
||||
if self.reward_model is not None:
|
||||
self.reward_model.eval()
|
||||
if self.critic is not None:
|
||||
self.critic.eval()
|
||||
|
||||
# extract values from samples
|
||||
sequences = samples.sequences
|
||||
attention_mask = samples.attention_mask
|
||||
action_mask = samples.action_mask
|
||||
num_actions = samples.num_actions
|
||||
if isinstance(samples, SamplesWithMetadata):
|
||||
metadata = samples.metadata
|
||||
else:
|
||||
metadata = None
|
||||
|
||||
# log probs
|
||||
action_log_probs = self.actor(sequences, num_actions, attention_mask)
|
||||
|
||||
# init log probs
|
||||
base_action_log_probs = self.initial_model(sequences, num_actions, attention_mask)
|
||||
|
||||
# values
|
||||
if self.critic is not None:
|
||||
value = self.critic(sequences, num_actions, attention_mask)
|
||||
else:
|
||||
value = None
|
||||
|
||||
# determine outcome reward
|
||||
completions = sequences[:, -action_mask.size(1) :].cpu()
|
||||
completions = self.tokenizer.batch_decode(completions, skip_special_tokens=True)
|
||||
returns = [
|
||||
self.dataset.score_answer(extract_answer(c, tag_name="answer"), entry=m)
|
||||
for c, m in zip(completions, metadata)
|
||||
]
|
||||
r = torch.tensor(returns, dtype=torch.float32, device=sequences.device)
|
||||
|
||||
kl = compute_approx_kl(
|
||||
action_log_probs,
|
||||
base_action_log_probs,
|
||||
action_mask=action_mask,
|
||||
use_kl_estimator_k3=self.strategy.args.use_kl_estimator_k3,
|
||||
)
|
||||
|
||||
info = {
|
||||
"kl": masked_mean(kl, action_mask, dim=-1),
|
||||
"reward": r,
|
||||
"response_length": samples.response_length,
|
||||
"total_length": samples.total_length,
|
||||
"num_actions": num_actions,
|
||||
}
|
||||
|
||||
logger.info(f"info={info}")
|
||||
|
||||
# reset model state
|
||||
self.actor.train()
|
||||
if self.critic is not None:
|
||||
self.critic.train()
|
||||
|
||||
return Experience(
|
||||
sequences,
|
||||
action_log_probs,
|
||||
value,
|
||||
None,
|
||||
None,
|
||||
attention_mask,
|
||||
action_mask,
|
||||
info,
|
||||
kl,
|
||||
)
|
||||
|
||||
|
||||
def train(args):
|
||||
# configure strategy
|
||||
strategy = get_strategy(args)
|
||||
strategy.setup_distributed()
|
||||
|
||||
# configure model
|
||||
# load huggingface model
|
||||
actor = Actor(
|
||||
args.pretrain,
|
||||
use_flash_attention_2=args.flash_attn,
|
||||
bf16=args.bf16,
|
||||
load_in_4bit=args.load_in_4bit,
|
||||
lora_rank=args.lora_rank,
|
||||
lora_alpha=args.lora_alpha,
|
||||
target_modules=args.target_modules,
|
||||
lora_dropout=args.lora_dropout,
|
||||
ds_config=strategy.get_ds_train_config(is_actor=True),
|
||||
)
|
||||
|
||||
if args.actor_init_on_gpu:
|
||||
actor = actor.to(torch.cuda.current_device())
|
||||
|
||||
if args.critic_pretrain:
|
||||
critic = get_llm_for_sequence_regression(
|
||||
args.critic_pretrain,
|
||||
"critic",
|
||||
normalize_reward=args.normalize_reward,
|
||||
use_flash_attention_2=args.flash_attn,
|
||||
bf16=args.bf16,
|
||||
load_in_4bit=args.load_in_4bit,
|
||||
lora_rank=args.lora_rank,
|
||||
lora_alpha=args.lora_alpha,
|
||||
target_modules=args.target_modules,
|
||||
lora_dropout=args.lora_dropout,
|
||||
ds_config=strategy.get_ds_train_config(is_actor=False),
|
||||
value_head_prefix=args.value_head_prefix,
|
||||
init_value_head=strategy.args.pretrain == strategy.args.critic_pretrain,
|
||||
)
|
||||
else:
|
||||
critic = None
|
||||
|
||||
reward_model = None
|
||||
|
||||
strategy.print(actor)
|
||||
strategy.print(critic)
|
||||
|
||||
# configure tokenizer
|
||||
tokenizer = get_tokenizer(
|
||||
args.pretrain,
|
||||
actor.model,
|
||||
"left",
|
||||
strategy,
|
||||
use_fast=not args.disable_fast_tokenizer,
|
||||
)
|
||||
|
||||
# load weights for reference actor
|
||||
initial_model = Actor(
|
||||
args.pretrain,
|
||||
use_flash_attention_2=args.flash_attn,
|
||||
bf16=args.bf16,
|
||||
load_in_4bit=args.load_in_4bit,
|
||||
ds_config=strategy.get_ds_eval_config(offload=False),
|
||||
)
|
||||
|
||||
if args.enable_ema:
|
||||
ema_model = Actor(
|
||||
args.pretrain,
|
||||
use_flash_attention_2=args.flash_attn,
|
||||
bf16=args.bf16,
|
||||
load_in_4bit=args.load_in_4bit,
|
||||
ds_config=strategy.get_ds_eval_config(offload=True),
|
||||
)
|
||||
else:
|
||||
ema_model = None
|
||||
|
||||
# gradient_checkpointing
|
||||
if args.gradient_checkpointing:
|
||||
actor.gradient_checkpointing_enable(
|
||||
gradient_checkpointing_kwargs={"use_reentrant": args.gradient_checkpointing_use_reentrant}
|
||||
)
|
||||
if critic is not None:
|
||||
critic.gradient_checkpointing_enable(
|
||||
gradient_checkpointing_kwargs={"use_reentrant": args.gradient_checkpointing_use_reentrant}
|
||||
)
|
||||
|
||||
# configure optimizer
|
||||
actor_optim = strategy.create_optimizer(
|
||||
actor, lr=args.actor_learning_rate, betas=args.adam_betas, weight_decay=args.l2
|
||||
)
|
||||
if args.critic_pretrain:
|
||||
critic_optim = strategy.create_optimizer(
|
||||
critic,
|
||||
lr=args.critic_learning_rate,
|
||||
betas=args.adam_betas,
|
||||
weight_decay=args.l2,
|
||||
)
|
||||
else:
|
||||
critic_optim = None
|
||||
|
||||
# prepare datasets
|
||||
print("prompt_data", args.prompt_data)
|
||||
|
||||
# DeepSeek Zero system prompt
|
||||
system_prompt = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
|
||||
The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think>
|
||||
<answer> answer here </answer>
|
||||
"""
|
||||
|
||||
prompts_data = reasoning_gym.create_dataset(args.prompt_data, size=args.max_samples)
|
||||
prompts_dataset = ReasoningGymDataset(prompts_data, tokenizer, developer_prompt=system_prompt)
|
||||
|
||||
if args.pretrain_data:
|
||||
pretrain_data = blending_datasets(
|
||||
args.pretrain_data,
|
||||
args.pretrain_data_probs,
|
||||
strategy,
|
||||
args.seed,
|
||||
return_eval=False,
|
||||
train_split=args.pretrain_split,
|
||||
)
|
||||
pretrain_max_len = args.max_len if args.max_len else args.prompt_max_len + args.generate_max_len
|
||||
pretrain_dataset = SFTDataset(
|
||||
pretrain_data.select(
|
||||
range(
|
||||
min(
|
||||
len(pretrain_data),
|
||||
args.max_epochs * len(prompts_dataset) * args.n_samples_per_prompt,
|
||||
)
|
||||
)
|
||||
),
|
||||
tokenizer,
|
||||
pretrain_max_len,
|
||||
strategy,
|
||||
pretrain_mode=True,
|
||||
)
|
||||
|
||||
def collate_prompt_and_metadata(xs: list[tuple]) -> list[tuple]:
|
||||
"""dummy collate function to pass on the metadata dict unchanged"""
|
||||
return xs
|
||||
|
||||
# prepare dataloader
|
||||
prompts_dataloader = strategy.setup_dataloader(
|
||||
prompts_dataset,
|
||||
args.rollout_batch_size // strategy.world_size,
|
||||
True,
|
||||
True,
|
||||
collate_fn=collate_prompt_and_metadata,
|
||||
)
|
||||
if args.pretrain_data:
|
||||
pretrain_dataloader = itertools.cycle(
|
||||
iter(
|
||||
strategy.setup_dataloader(
|
||||
pretrain_dataset,
|
||||
args.micro_train_batch_size,
|
||||
True,
|
||||
True,
|
||||
pretrain_dataset.collate_fn,
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
pretrain_dataloader = None
|
||||
|
||||
# configure scheduler
|
||||
num_update_steps_per_episodes = (
|
||||
len(prompts_dataset) * args.n_samples_per_prompt // args.train_batch_size * args.max_epochs
|
||||
)
|
||||
max_steps = math.ceil(args.num_episodes * num_update_steps_per_episodes)
|
||||
|
||||
actor_scheduler = get_scheduler(
|
||||
"cosine_with_min_lr",
|
||||
actor_optim,
|
||||
num_warmup_steps=math.ceil(max_steps * args.lr_warmup_ratio),
|
||||
num_training_steps=max_steps,
|
||||
scheduler_specific_kwargs={"min_lr": args.actor_learning_rate * 0.1},
|
||||
)
|
||||
|
||||
if args.critic_pretrain:
|
||||
critic_scheduler = get_scheduler(
|
||||
"cosine_with_min_lr",
|
||||
critic_optim,
|
||||
num_warmup_steps=math.ceil(max_steps * args.lr_warmup_ratio),
|
||||
num_training_steps=max_steps,
|
||||
scheduler_specific_kwargs={"min_lr": args.critic_learning_rate * 0.1},
|
||||
)
|
||||
else:
|
||||
critic_scheduler = None
|
||||
|
||||
# prepare models/optimizers...
|
||||
(
|
||||
(actor, actor_optim, actor_scheduler),
|
||||
(critic, critic_optim, critic_scheduler),
|
||||
reward_model,
|
||||
initial_model,
|
||||
) = strategy.prepare(
|
||||
(actor, actor_optim, actor_scheduler),
|
||||
(critic, critic_optim, critic_scheduler),
|
||||
reward_model,
|
||||
initial_model,
|
||||
is_rlhf=True,
|
||||
)
|
||||
|
||||
if ema_model:
|
||||
ema_model._offload = True
|
||||
ema_model = strategy.prepare(ema_model, is_rlhf=True)
|
||||
|
||||
# load checkpoint
|
||||
consumed_samples = 0
|
||||
if args.load_checkpoint and os.path.exists(os.path.join(args.ckpt_path, "_actor")):
|
||||
_, states = strategy.load_ckpt(actor.model, os.path.join(args.ckpt_path, "_actor"))
|
||||
if args.critic_pretrain:
|
||||
strategy.load_ckpt(critic, os.path.join(args.ckpt_path, "_critic"))
|
||||
consumed_samples = states["consumed_samples"]
|
||||
strategy.print(f"Loaded the checkpoint: {args.ckpt_path}, consumed_samples: {consumed_samples}")
|
||||
|
||||
os.makedirs(args.save_path, exist_ok=True)
|
||||
|
||||
# configure Trainer
|
||||
trainer = PPOTrainer(
|
||||
strategy,
|
||||
actor,
|
||||
critic,
|
||||
reward_model,
|
||||
initial_model,
|
||||
ema_model,
|
||||
actor_optim,
|
||||
critic_optim,
|
||||
actor_scheduler,
|
||||
critic_scheduler,
|
||||
max_epochs=args.max_epochs,
|
||||
micro_train_batch_size=args.micro_train_batch_size,
|
||||
micro_rollout_batch_size=args.micro_rollout_batch_size,
|
||||
gradient_checkpointing=args.gradient_checkpointing,
|
||||
tokenizer=tokenizer,
|
||||
prompt_max_len=args.prompt_max_len,
|
||||
value_clip=args.value_clip,
|
||||
eps_clip=args.eps_clip,
|
||||
gamma=args.gamma,
|
||||
lambd=args.lambd,
|
||||
init_kl_coef=args.init_kl_coef,
|
||||
kl_target=args.kl_target,
|
||||
ema_beta=0.992,
|
||||
ptx_coef=args.ptx_coef,
|
||||
max_norm=args.max_norm,
|
||||
# fro GPT generation
|
||||
do_sample=True,
|
||||
max_new_tokens=args.generate_max_len,
|
||||
max_length=args.max_len,
|
||||
temperature=args.temperature,
|
||||
top_p=args.top_p,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
# remote reward model
|
||||
remote_rm_url=args.remote_rm_url,
|
||||
save_hf_ckpt=args.save_hf_ckpt,
|
||||
disable_ds_ckpt=args.disable_ds_ckpt,
|
||||
)
|
||||
|
||||
# patch experience maker ..
|
||||
xp = trainer.experience_maker
|
||||
trainer.experience_maker = AlgorithmicRewardExperienceMaker(
|
||||
dataset=prompts_dataset.dataset,
|
||||
actor=xp.actor,
|
||||
critic=xp.critic,
|
||||
reward_model=xp.reward_model,
|
||||
initial_model=xp.initial_model,
|
||||
tokenizer=xp.tokenizer,
|
||||
prompt_max_len=xp.prompt_max_len,
|
||||
kl_controller=xp.kl_ctl,
|
||||
strategy=xp.strategy,
|
||||
remote_rm_url=xp.remote_rm_url,
|
||||
reward_fn=xp.reward_fn,
|
||||
)
|
||||
xp = None
|
||||
|
||||
trainer.fit(
|
||||
args,
|
||||
prompts_dataloader,
|
||||
pretrain_dataloader,
|
||||
consumed_samples,
|
||||
num_update_steps_per_episodes,
|
||||
)
|
||||
|
||||
# save model checkpoint after fitting on only rank0
|
||||
strategy.save_model(
|
||||
ema_model if args.enable_ema else actor,
|
||||
tokenizer,
|
||||
args.save_path,
|
||||
)
|
||||
|
||||
if args.critic_pretrain and args.save_value_network:
|
||||
strategy.save_model(
|
||||
critic,
|
||||
tokenizer,
|
||||
args.save_path + "_critic",
|
||||
)
|
||||
|
||||
|
||||
import os
|
||||
|
||||
import debugpy
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# Only attach debugger on rank 0 process
|
||||
if DEBUG and int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||
debugpy.listen(("localhost", 5678))
|
||||
print("Waiting for debugger attach on localhost:5678")
|
||||
debugpy.wait_for_client()
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
# Checkpoint
|
||||
parser.add_argument("--save_path", type=str, default="./ckpt")
|
||||
parser.add_argument("--save_steps", type=int, default=-1)
|
||||
parser.add_argument("--save_hf_ckpt", action="store_true", default=False)
|
||||
parser.add_argument("--disable_ds_ckpt", action="store_true", default=False)
|
||||
parser.add_argument("--logging_steps", type=int, default=1)
|
||||
parser.add_argument("--eval_steps", type=int, default=-1)
|
||||
parser.add_argument("--ckpt_path", type=str, default="./ckpt/checkpoints_ppo")
|
||||
parser.add_argument("--max_ckpt_num", type=int, default=3)
|
||||
parser.add_argument("--max_ckpt_mem", type=int, default=1e8)
|
||||
parser.add_argument("--load_checkpoint", action="store_true", default=False)
|
||||
|
||||
# PPO
|
||||
parser.add_argument("--num_episodes", type=int, default=1)
|
||||
parser.add_argument("--rollout_batch_size", type=int, default=512)
|
||||
parser.add_argument("--micro_rollout_batch_size", type=int, default=8)
|
||||
parser.add_argument("--max_epochs", type=int, default=1)
|
||||
parser.add_argument("--prompt_max_len", type=int, default=1024, help="Max tokens for each prompt")
|
||||
parser.add_argument(
|
||||
"--generate_max_len",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="Max tokens to generate in PPO",
|
||||
)
|
||||
parser.add_argument("--max_len", type=int, default=None, help="deprecated max_len")
|
||||
parser.add_argument("--max_samples", type=int, default=1000000)
|
||||
parser.add_argument("--max_norm", type=float, default=1.0, help="Gradient clipping")
|
||||
parser.add_argument("--l2", type=float, default=0.0, help="weight decay loss")
|
||||
parser.add_argument("--ptx_coef", type=float, default=0.05, help="PPO-ptx loss coef")
|
||||
parser.add_argument("--eps_clip", type=float, default=0.2, help="PPO clip range")
|
||||
parser.add_argument("--value_clip", type=float, default=0.2, help="PPO value clip range")
|
||||
parser.add_argument("--lambd", type=float, default=1.0, help="PPO GAE lambd")
|
||||
parser.add_argument("--gamma", type=float, default=1, help="PPO GAE gamma")
|
||||
parser.add_argument("--micro_train_batch_size", type=int, default=4, help="batch size per GPU")
|
||||
parser.add_argument("--train_batch_size", type=int, default=128, help="Global training batch size")
|
||||
parser.add_argument(
|
||||
"--normalize_reward",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Enable Reward Normazation",
|
||||
)
|
||||
parser.add_argument("--top_p", type=float, default=1.0)
|
||||
parser.add_argument("--temperature", type=float, default=1.0)
|
||||
parser.add_argument(
|
||||
"--freezing_actor_steps",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="Used for critic initialization",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_samples_per_prompt",
|
||||
type=int,
|
||||
default=1,
|
||||
help="number of responses for each prompt in generation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_value_network",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Save critic model",
|
||||
)
|
||||
parser.add_argument("--actor_learning_rate", type=float, default=1e-6)
|
||||
parser.add_argument("--critic_learning_rate", type=float, default=9e-6)
|
||||
parser.add_argument("--lr_warmup_ratio", type=float, default=0.03)
|
||||
parser.add_argument("--kl_target", type=float, default=None)
|
||||
parser.add_argument("--init_kl_coef", type=float, default=0.01, help="KL penalty in PPO")
|
||||
parser.add_argument(
|
||||
"--use_kl_estimator_k3",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help=(
|
||||
"Use the k3 estimator in http://joschu.net/blog/kl-approx.html"
|
||||
"to ensure the KL divergence calculated is non-negative"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--adam_betas",
|
||||
type=float,
|
||||
nargs=2,
|
||||
default=(0.9, 0.95),
|
||||
help="Betas for Adam optimizer",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reward_clip_range",
|
||||
type=float,
|
||||
nargs=2,
|
||||
default=(-10, 10),
|
||||
help="Reward clip range",
|
||||
)
|
||||
|
||||
# DeepSpeed
|
||||
parser.add_argument("--seed", type=int, default=42)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for deepspeed")
|
||||
parser.add_argument("--zero_stage", type=int, default=2, help="DeepSpeed ZeRO stage")
|
||||
parser.add_argument("--gradient_checkpointing", action="store_true", default=False)
|
||||
parser.add_argument("--bf16", action="store_true", default=False, help="Enable bfloat16")
|
||||
parser.add_argument("--enable_ema", action="store_true", help="Enable EMA checkpoint for the model.")
|
||||
parser.add_argument("--zpg", type=int, default=1, help="ZeRO++ max partition size")
|
||||
parser.add_argument(
|
||||
"--adam_offload",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Offload Adam Optimizer",
|
||||
)
|
||||
parser.add_argument("--actor_init_on_gpu", action="store_true", default=False)
|
||||
parser.add_argument(
|
||||
"--flash_attn",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Enable FlashAttention2",
|
||||
)
|
||||
parser.add_argument("--aux_loss_coef", type=float, default=0, help="MoE balancing loss")
|
||||
parser.add_argument("--grad_accum_dtype", type=str, default=None, help="Adam grad accum data type")
|
||||
parser.add_argument("--overlap_comm", action="store_true", default=False)
|
||||
parser.add_argument("--gradient_checkpointing_use_reentrant", action="store_true", default=False)
|
||||
parser.add_argument("--disable_fast_tokenizer", action="store_true", default=False)
|
||||
|
||||
# Reinforce
|
||||
parser.add_argument(
|
||||
"--advantage_estimator",
|
||||
type=str,
|
||||
choices=["gae", "reinforce", "rloo"],
|
||||
default="gae",
|
||||
help="Choose advantage estimation method: gae, reinforce, rloo",
|
||||
)
|
||||
|
||||
# LoRA
|
||||
parser.add_argument("--load_in_4bit", action="store_true", default=False)
|
||||
parser.add_argument("--lora_rank", type=int, default=0)
|
||||
parser.add_argument("--lora_alpha", type=int, default=16)
|
||||
parser.add_argument("--target_modules", type=str, nargs="*", default="all-linear")
|
||||
parser.add_argument("--lora_dropout", type=float, default=0)
|
||||
|
||||
# Models
|
||||
parser.add_argument("--pretrain", type=str, default=None, help="HF model name or path")
|
||||
parser.add_argument("--reward_pretrain", type=str, default=None, help="HF model name or path")
|
||||
parser.add_argument("--remote_rm_url", type=str, default=None, help="remote RM API")
|
||||
parser.add_argument("--critic_pretrain", type=str, default=None, help="HF model name or path")
|
||||
parser.add_argument("--value_head_prefix", type=str, default="score")
|
||||
|
||||
# Custom dataset
|
||||
parser.add_argument("--prompt_data", type=str, default="chain_sum", help="HF dataset name or path")
|
||||
parser.add_argument(
|
||||
"--prompt_data_probs",
|
||||
type=str,
|
||||
default="1.0",
|
||||
help="sampling probs for datasets",
|
||||
)
|
||||
parser.add_argument("--prompt_split", type=str, default="train")
|
||||
parser.add_argument("--pretrain_data", type=str, default=None, help="HF dataset name or path")
|
||||
parser.add_argument(
|
||||
"--pretrain_data_probs",
|
||||
type=str,
|
||||
default="1.0",
|
||||
help="sampling probs for datasets",
|
||||
)
|
||||
parser.add_argument("--pretrain_split", type=str, default="train")
|
||||
parser.add_argument("--input_key", type=str, default="question", help="JSON dataset key")
|
||||
parser.add_argument("--input_template", type=str, default=None)
|
||||
parser.add_argument(
|
||||
"--apply_chat_template",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use HF tokenizer chat template",
|
||||
)
|
||||
|
||||
# wandb parameters
|
||||
parser.add_argument("--use_wandb", type=str, default=None)
|
||||
parser.add_argument("--wandb_org", type=str, default=None)
|
||||
parser.add_argument("--wandb_group", type=str, default=None)
|
||||
parser.add_argument("--wandb_project", type=str, default="openrlhf_train_ppo")
|
||||
parser.add_argument(
|
||||
"--wandb_run_name",
|
||||
type=str,
|
||||
default="ppo_%s" % datetime.now().strftime("%m%dT%H:%M"),
|
||||
)
|
||||
|
||||
# TensorBoard parameters
|
||||
parser.add_argument("--use_tensorboard", type=str, default=None, help="TensorBoard logging path")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.advantage_estimator not in ["gae"]:
|
||||
args.critic_pretrain = None
|
||||
elif args.critic_pretrain is None:
|
||||
args.critic_pretrain = args.pretrain ## temp
|
||||
|
||||
if args.advantage_estimator == "rloo":
|
||||
assert args.n_samples_per_prompt > 1, "RLOO requires n_samples_per_prompt > 1"
|
||||
|
||||
if args.input_template and "{}" not in args.input_template:
|
||||
print("[Warning] {} not in args.input_template, set to None")
|
||||
args.input_template = None
|
||||
|
||||
if args.input_template and "\\n" in args.input_template:
|
||||
print(
|
||||
"[Warning] input_template contains \\n chracters instead of newline. "
|
||||
"You likely want to pass $'\\n' in Bash or \"`n\" in PowerShell."
|
||||
)
|
||||
|
||||
train(args)
|
||||
|
|
@ -1,34 +0,0 @@
|
|||
#!/bin/bash
|
||||
args=(
|
||||
custom_reward.py
|
||||
--pretrain meta-llama/Llama-3.2-3B-Instruct # meta-llama/Llama-3.2-1B-Instruct
|
||||
--save_path ./checkpoint/Llama-3.2-3B-rl # ./checkpoint/Llama-3.2-1b-lr
|
||||
--save_steps -1
|
||||
--logging_steps 1
|
||||
--eval_steps -1
|
||||
--micro_train_batch_size 2
|
||||
--train_batch_size 128
|
||||
--micro_rollout_batch_size 16
|
||||
--rollout_batch_size 1024
|
||||
--max_epochs 1
|
||||
--prompt_max_len 1024
|
||||
--generate_max_len 1024
|
||||
--zero_stage 2
|
||||
--bf16
|
||||
--actor_learning_rate 5e-7
|
||||
--init_kl_coef 0.01
|
||||
--prompt_data chain_sum # leg_counting
|
||||
--input_key question
|
||||
--apply_chat_template
|
||||
--normalize_reward
|
||||
--adam_offload
|
||||
--flash_attn
|
||||
--gradient_checkpointing
|
||||
--max_samples 100000
|
||||
--critic_learning_rate 9e-6
|
||||
)
|
||||
# Add wandb argument only if wandb_token is set
|
||||
if [[ -n "${wandb_token}" ]]; then
|
||||
args+=(--use_wandb "${wandb_token}")
|
||||
fi
|
||||
deepspeed ${args[@]}
|
||||
|
|
@ -1,34 +0,0 @@
|
|||
# Reasoning Gym Open-Instruct Example
|
||||
|
||||
This example demonstrates how to use [open-instruct](https://github.com/allenai/open-instruct) with **reasoning-gym** for training language models through reinforcement learning.
|
||||
|
||||
## Environment Setup
|
||||
|
||||
Before running the training script, you may want to set the following environment variables:
|
||||
|
||||
```bash
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
|
||||
```
|
||||
|
||||
- `CUDA_VISIBLE_DEVICES=0` specifies which GPUs to use (in this case, GPUs 0 )
|
||||
- `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` enables dynamic GPU memory allocation
|
||||
|
||||
## Running the Training Script
|
||||
|
||||
To start training, run:
|
||||
|
||||
```bash
|
||||
bash grpo_config.sh
|
||||
```
|
||||
|
||||
This will train a DeepSeek-R1-Distill-Qwen-1.5B model using GRPO (Group Relative Policy Optimization) on reasoning tasks.
|
||||
|
||||
## Key Features
|
||||
|
||||
- Uses vLLM for efficient inference
|
||||
- Supports multi-GPU training
|
||||
- Implements reward functions for both answer correctness and response formatting
|
||||
- Includes evaluation during training
|
||||
- Supports model checkpointing and pushing to Hugging Face Hub
|
||||
- Integrates with Weights & Biases for experiment tracking
|
||||
|
|
@ -1,34 +0,0 @@
|
|||
exp_name="0302_qwen2.5_math_grpo_fast1_${RANDOM}"
|
||||
python src/grpo_trainer.py \
|
||||
--model_name_or_path deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
|
||||
--beta 0.0 \
|
||||
--num_unique_prompts_rollout 16 \
|
||||
--num_samples_per_prompt_rollout 8 \
|
||||
--output_dir ./open_instruct_checkpoints/$exp_name \
|
||||
--save_freq 40 \
|
||||
--kl_estimator kl3 \
|
||||
--learning_rate 5e-7 \
|
||||
--max_token_length 256 \
|
||||
--max_prompt_token_length 128 \
|
||||
--response_length 256 \
|
||||
--pack_length 512 \
|
||||
--stop_strings '"</answer>"' \
|
||||
--chat_template_name r1_simple_chat_postpend_think \
|
||||
--temperature 1.0 \
|
||||
--total_episodes 100000 \
|
||||
--deepspeed_stage 3 \
|
||||
--per_device_train_batch_size 1 \
|
||||
--num_mini_batches 1\
|
||||
--num_learners_per_node 1 \
|
||||
--num_epochs 1 \
|
||||
--vllm_tensor_parallel_size 1 \
|
||||
--vllm_num_engines 1 \
|
||||
--vllm_enforce_eager true \
|
||||
--lr_scheduler_type linear \
|
||||
--seed 1 \
|
||||
--num_evals 100 \
|
||||
--dataset_name chain_sum \
|
||||
--gradient_checkpointing \
|
||||
--with_tracking \
|
||||
--single_gpu_mode false \
|
||||
--gather_whole_model false
|
||||
|
|
@ -1,457 +0,0 @@
|
|||
import json
|
||||
import os
|
||||
import shutil
|
||||
import threading
|
||||
import time
|
||||
from argparse import Namespace
|
||||
from dataclasses import asdict, dataclass
|
||||
from queue import Queue
|
||||
from typing import Callable, List
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import ray
|
||||
import torch
|
||||
from open_instruct.ground_truth_utils import soft_format_reward_func
|
||||
from open_instruct.grpo_fast import (
|
||||
Args,
|
||||
ModelGroup,
|
||||
PolicyTrainerRayProcess,
|
||||
Timer,
|
||||
calculate_runtime_args,
|
||||
collate_fn,
|
||||
vllm_generate_thread,
|
||||
)
|
||||
from open_instruct.model_utils import ModelConfig, push_folder_to_hub
|
||||
from open_instruct.rl_utils2 import PackedSequences, reset_position_ids
|
||||
from open_instruct.utils import ArgumentParserPlus, get_wandb_tags, is_beaker_job, maybe_get_beaker_config
|
||||
from open_instruct.vllm_utils2 import create_vllm_engines
|
||||
from ray.util.placement_group import placement_group
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizer
|
||||
from utils import ReasoningGymDataset, list_preserving_collate, pack_sequences
|
||||
from vllm import SamplingParams
|
||||
|
||||
from reasoning_gym.utils import SYSTEM_PROMPTS, extract_answer
|
||||
|
||||
|
||||
@dataclass
|
||||
class GRPOArgs(Args):
|
||||
dataset_name: str = "chain_sum"
|
||||
seed: int = 42
|
||||
size: int = 10000
|
||||
eval_seed: int = 42
|
||||
eval_size: int = 100
|
||||
|
||||
|
||||
def data_preparation_thread(
|
||||
reward_fn: Callable,
|
||||
inference_results_Q: Queue,
|
||||
packed_sequences_Q: Queue,
|
||||
queries_prompt_Q: Queue,
|
||||
args: Args,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
num_training_steps: int,
|
||||
reasoning_gym_dataset: ReasoningGymDataset,
|
||||
):
|
||||
for training_step in range(1, num_training_steps + 1):
|
||||
# Get next batch of prompts and responses
|
||||
observations = queries_prompt_Q.get()
|
||||
queries, items = observations
|
||||
|
||||
if args.num_samples_per_prompt_rollout > 1:
|
||||
queries = [item for item in queries for _ in range(args.num_samples_per_prompt_rollout)]
|
||||
with Timer("🚀 [Data Preparation Thread] Getting response ids"):
|
||||
responses, finish_reasons = inference_results_Q.get()
|
||||
for i in range(len(finish_reasons)):
|
||||
if finish_reasons[i] == "stop" and responses[i][-1] != tokenizer.eos_token_id:
|
||||
responses[i].append(tokenizer.eos_token_id)
|
||||
|
||||
with Timer("📦 [Data Preparation Thread] Packing sequences"):
|
||||
packed_sequences = pack_sequences(
|
||||
queries=queries,
|
||||
responses=responses,
|
||||
pack_length=args.pack_length,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
)
|
||||
num_new_tokens = sum(len(seq) for seq in packed_sequences.query_responses)
|
||||
|
||||
with Timer("🔥 [Data Preparation Thread] Decoding responses", noop=True):
|
||||
decoded_responses = tokenizer.batch_decode(responses, skip_special_tokens=True)
|
||||
stop_rate = sum(int(finish_reason == "stop") for finish_reason in finish_reasons) / len(finish_reasons)
|
||||
|
||||
with Timer("💰 [Data Preparation Thread] Calculating rewards"):
|
||||
scores, reward_metrics = reward_fn(decoded_responses, items, reasoning_gym_dataset)
|
||||
|
||||
with Timer("🎆 [Data Preparation Thread] Calculating advantages"):
|
||||
# Calculate advantages
|
||||
scores = np.array(scores)
|
||||
print(f"[Data Preparation Thread] {len(scores)=}")
|
||||
scores_per_prompt = scores.reshape(-1, args.num_samples_per_prompt_rollout)
|
||||
global_mean_grouped_rewards = scores_per_prompt.mean(axis=-1)
|
||||
global_mean_grouped_rewards = np.repeat(
|
||||
global_mean_grouped_rewards, args.num_samples_per_prompt_rollout, axis=0
|
||||
)
|
||||
global_std_grouped_rewards = scores_per_prompt.std(axis=-1)
|
||||
global_std_grouped_rewards = np.repeat(
|
||||
global_std_grouped_rewards, args.num_samples_per_prompt_rollout, axis=0
|
||||
)
|
||||
global_advantages = (scores - global_mean_grouped_rewards) / (global_std_grouped_rewards + 1e-8)
|
||||
|
||||
# Vectorized advantage calculation: create a lookup array where each index corresponds to a response mask value
|
||||
# and each value is the corresponding advantage score: index 0 is set to 0 since response masks start from 1 (1-indexed)
|
||||
lookup_advantages = np.zeros(len(global_advantages) + 1, dtype=np.float32)
|
||||
lookup_advantages[1:] = global_advantages
|
||||
packed_advantages = [
|
||||
torch.tensor(lookup_advantages[packed_mask], dtype=torch.float32)
|
||||
for packed_mask in packed_sequences.response_masks
|
||||
]
|
||||
packed_sequences.advantages = packed_advantages
|
||||
|
||||
with Timer("🔄 [Data Preparation Thread] Prepare collated data for each worker"):
|
||||
B = (
|
||||
len(packed_sequences.query_responses) // args.world_size
|
||||
) # essentially doing `drop_last=True`, which is fine.
|
||||
collated_data = []
|
||||
for i in range(args.world_size):
|
||||
per_device_packed_query_responses = packed_sequences.query_responses[B * i : B * (i + 1)]
|
||||
per_device_packed_attention_masks = packed_sequences.attention_masks[B * i : B * (i + 1)]
|
||||
per_device_packed_position_ids = packed_sequences.position_ids[B * i : B * (i + 1)]
|
||||
per_device_packed_advantages = packed_sequences.advantages[B * i : B * (i + 1)]
|
||||
per_device_packed_response_masks = packed_sequences.response_masks[B * i : B * (i + 1)]
|
||||
|
||||
# Shuffle the batch and collate the data
|
||||
b_inds = np.random.permutation(len(per_device_packed_query_responses))
|
||||
collated_query_responses = []
|
||||
collated_attention_masks = []
|
||||
collated_position_ids = []
|
||||
collated_response_masks = []
|
||||
collated_advantages = []
|
||||
for j in range(0, len(per_device_packed_query_responses), args.per_device_train_batch_size):
|
||||
micro_range = b_inds[j : j + args.per_device_train_batch_size]
|
||||
collated_query_responses.append(
|
||||
collate_fn(
|
||||
[per_device_packed_query_responses[idx] for idx in micro_range], tokenizer.pad_token_id
|
||||
)
|
||||
)
|
||||
collated_attention_masks.append(
|
||||
collate_fn([per_device_packed_attention_masks[idx] for idx in micro_range], 0)
|
||||
)
|
||||
collated_position_ids.append(
|
||||
collate_fn([per_device_packed_position_ids[idx] for idx in micro_range], 0)
|
||||
)
|
||||
collated_response_masks.append(
|
||||
collate_fn([per_device_packed_response_masks[idx] for idx in micro_range], 0)
|
||||
)
|
||||
collated_advantages.append(
|
||||
collate_fn([per_device_packed_advantages[idx] for idx in micro_range], 0)
|
||||
)
|
||||
collated_data.append(
|
||||
{
|
||||
"collated_query_responses": collated_query_responses,
|
||||
"collated_attention_masks": collated_attention_masks,
|
||||
"collated_position_ids": collated_position_ids,
|
||||
"collated_advantages": collated_advantages,
|
||||
"collated_response_masks": collated_response_masks,
|
||||
}
|
||||
)
|
||||
sequence_lengths = np.array([len(response) for response in responses])
|
||||
metrics = {
|
||||
"scores": np.array(scores).mean(),
|
||||
"val/sequence_lengths": sequence_lengths.mean(),
|
||||
"val/sequence_lengths_min": sequence_lengths.min(),
|
||||
"val/sequence_lengths_max": sequence_lengths.max(),
|
||||
"val/stop_rate": stop_rate,
|
||||
**reward_metrics,
|
||||
}
|
||||
|
||||
if args.save_traces:
|
||||
traces = {
|
||||
"scores": scores.tolist(),
|
||||
"finish_reasons": finish_reasons,
|
||||
"responses": responses,
|
||||
"queries": queries,
|
||||
"ground_truths": items["answer"],
|
||||
"datasets": reasoning_gym_dataset.dataset_name,
|
||||
"training_step": training_step,
|
||||
**reward_metrics,
|
||||
}
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
with open(f"{args.output_dir}/traces_{args.run_name}.jsonl", "a") as f:
|
||||
json.dump(traces, f)
|
||||
f.write("\n")
|
||||
packed_sequences_Q.put(
|
||||
{
|
||||
"packed_sequences": packed_sequences, # for debugging purposes
|
||||
"collated_data": collated_data,
|
||||
"metrics": metrics,
|
||||
"responses_count": len(responses),
|
||||
"num_new_tokens": num_new_tokens,
|
||||
"B": B,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def reward_fn(responses, items, reasoning_gym_dataset):
|
||||
metrics = {}
|
||||
formatted_responses = [extract_answer(response) for response in responses]
|
||||
correctness_rewards = [
|
||||
reasoning_gym_dataset.data.score_answer(formatted_response, entry)
|
||||
for formatted_response, entry in zip(formatted_responses, items)
|
||||
]
|
||||
format_rewards = soft_format_reward_func(responses)
|
||||
rewards = correctness_rewards + format_rewards
|
||||
metrics["correctness_rewards"] = correctness_rewards
|
||||
metrics["format_rewards"] = format_rewards
|
||||
metrics["rewards"] = rewards
|
||||
return rewards, metrics
|
||||
|
||||
|
||||
def main(args: Args, model_config: ModelConfig, reward_fn: Callable):
|
||||
calculate_runtime_args(args)
|
||||
|
||||
# Setup experiment tracking and seeds
|
||||
all_configs = {}
|
||||
beaker_config = None
|
||||
if is_beaker_job():
|
||||
beaker_config = maybe_get_beaker_config()
|
||||
all_configs.update(vars(beaker_config))
|
||||
all_configs.update(**asdict(args), **asdict(model_config))
|
||||
if args.with_tracking:
|
||||
import wandb
|
||||
|
||||
wandb.init(
|
||||
project=args.wandb_project_name,
|
||||
entity=args.wandb_entity,
|
||||
sync_tensorboard=True,
|
||||
config=all_configs,
|
||||
name=args.run_name,
|
||||
save_code=True,
|
||||
tags=[args.exp_name] + get_wandb_tags(),
|
||||
)
|
||||
|
||||
# Setup tokenizer and get datasets
|
||||
tokenizer_revision = (
|
||||
model_config.model_revision if model_config.tokenizer_revision is None else model_config.tokenizer_revision
|
||||
)
|
||||
tokenizer_name = (
|
||||
model_config.tokenizer_name if model_config.tokenizer_name is not None else model_config.model_name_or_path
|
||||
)
|
||||
if tokenizer_revision != model_config.model_revision:
|
||||
# Warn user if tokenizer and model use different revisions; this is an unusual
|
||||
# use case.
|
||||
warning = f"""Requested tokenizer revision `{tokenizer_revision}` is different
|
||||
from the model revision `{model_config.model_revision}`."""
|
||||
print(warning)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, revision=tokenizer_revision)
|
||||
reasoning_gym_dataset = ReasoningGymDataset(
|
||||
args.dataset_name, args.seed, args.size, tokenizer, "system", SYSTEM_PROMPTS["DeepSeekZero"]
|
||||
)
|
||||
eval_dataset = ReasoningGymDataset(
|
||||
args.dataset_name, args.eval_seed, args.eval_size, tokenizer, "system", SYSTEM_PROMPTS["DeepSeekZero"]
|
||||
)
|
||||
|
||||
bundles = [{"GPU": actor_num_gpus, "CPU": actor_num_gpus * 10} for actor_num_gpus in args.num_learners_per_node]
|
||||
pg = placement_group(bundles, strategy="PACK")
|
||||
ray.get(pg.ready())
|
||||
inits = []
|
||||
policy_group = ModelGroup(
|
||||
pg,
|
||||
PolicyTrainerRayProcess,
|
||||
args.num_learners_per_node,
|
||||
args.single_gpu_mode,
|
||||
)
|
||||
|
||||
wandb_url = wandb.run.get_url() if args.with_tracking else None
|
||||
inits.extend(
|
||||
model.from_pretrained.remote(args, model_config, beaker_config, wandb_url, tokenizer)
|
||||
for model in policy_group.models
|
||||
)
|
||||
max_len = args.max_prompt_token_length + args.response_length
|
||||
vllm_engines = create_vllm_engines(
|
||||
args.vllm_num_engines,
|
||||
args.vllm_tensor_parallel_size,
|
||||
args.vllm_enforce_eager,
|
||||
model_config.model_name_or_path,
|
||||
model_config.model_revision,
|
||||
args.seed,
|
||||
args.vllm_enable_prefix_caching,
|
||||
max_len,
|
||||
args.vllm_gpu_memory_utilization,
|
||||
args.single_gpu_mode,
|
||||
pg=pg if args.single_gpu_mode else None,
|
||||
)
|
||||
ray.get(inits)
|
||||
print("======== ✅ all models and vLLM engines initialized =========")
|
||||
|
||||
ray.get([m.setup_model_update_group.remote(vllm_engines=vllm_engines) for m in policy_group.models])
|
||||
print("======== ✅ model update group setup successfully =========")
|
||||
|
||||
# Setup training
|
||||
generation_config = SamplingParams(
|
||||
temperature=args.temperature,
|
||||
top_p=1.0,
|
||||
max_tokens=args.response_length,
|
||||
include_stop_str_in_output=True,
|
||||
n=args.num_samples_per_prompt_rollout,
|
||||
stop=args.stop_strings,
|
||||
)
|
||||
eval_generation_config = SamplingParams(
|
||||
temperature=0.0,
|
||||
top_p=1.0,
|
||||
max_tokens=args.response_length,
|
||||
include_stop_str_in_output=True,
|
||||
n=1, # since we are doing greedy sampling, don't need to generate more
|
||||
stop=args.stop_strings,
|
||||
)
|
||||
|
||||
iter_dataloader = DataLoader(
|
||||
reasoning_gym_dataset, batch_size=args.num_unique_prompts_rollout, collate_fn=list_preserving_collate
|
||||
)
|
||||
inference_results_Q = Queue(maxsize=1)
|
||||
param_prompt_Q = Queue(maxsize=1)
|
||||
evaluation_inference_results_Q = Queue(maxsize=1)
|
||||
packed_sequences_Q = Queue(maxsize=1)
|
||||
queries_prompt_Q = Queue(maxsize=1)
|
||||
num_eval_samples = 50
|
||||
|
||||
eval_prompt_token_ids = None
|
||||
eval_ground_truths = None
|
||||
if eval_dataset is not None:
|
||||
model_eval_input, eval_items = next(iter(eval_dataset))
|
||||
eval_prompt_token_ids = model_eval_input
|
||||
resume_training_step = 1
|
||||
thread = threading.Thread(
|
||||
target=vllm_generate_thread,
|
||||
args=(
|
||||
vllm_engines,
|
||||
generation_config,
|
||||
eval_generation_config,
|
||||
inference_results_Q,
|
||||
param_prompt_Q,
|
||||
args.num_training_steps,
|
||||
eval_prompt_token_ids,
|
||||
evaluation_inference_results_Q,
|
||||
args.eval_freq,
|
||||
resume_training_step,
|
||||
),
|
||||
)
|
||||
|
||||
thread.start()
|
||||
print("======== ✅ vllm generate thread starts =========")
|
||||
|
||||
packing_thread = threading.Thread(
|
||||
target=data_preparation_thread,
|
||||
args=(
|
||||
reward_fn,
|
||||
inference_results_Q,
|
||||
packed_sequences_Q,
|
||||
queries_prompt_Q,
|
||||
args,
|
||||
tokenizer,
|
||||
args.num_training_steps,
|
||||
reasoning_gym_dataset,
|
||||
),
|
||||
)
|
||||
|
||||
packing_thread.start()
|
||||
print("======== ✅ data preparation thread starts =========")
|
||||
model_inputs, items = next(iter(iter_dataloader))
|
||||
queries_prompt_Q.put((model_inputs, items))
|
||||
param_prompt_Q.put((None, model_inputs))
|
||||
|
||||
episode = 0
|
||||
num_total_tokens = 0
|
||||
start_time = time.time()
|
||||
for training_step in range(resume_training_step, args.num_training_steps + 1):
|
||||
episode += args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout
|
||||
|
||||
if training_step != 1:
|
||||
model_inputs, items = next(iter(iter_dataloader))
|
||||
with Timer("🔄 Loading weights using shared memory"):
|
||||
ray.get([m.broadcast_to_vllm.remote() for m in policy_group.models])
|
||||
|
||||
queries_prompt_Q.put((model_inputs, items))
|
||||
param_prompt_Q.put((None, model_inputs))
|
||||
|
||||
with Timer("[Main Thread] 📦 Getting packed sequences from thread"):
|
||||
packed_data = packed_sequences_Q.get()
|
||||
data_thread_metrics = packed_data["metrics"]
|
||||
B = packed_data["B"]
|
||||
collated_data = packed_data["collated_data"]
|
||||
num_total_tokens += packed_data["num_new_tokens"]
|
||||
|
||||
if B == 0:
|
||||
print("[Main Thread] 🤡 After packing, there is not enough data to train")
|
||||
continue
|
||||
|
||||
# ------------------------------------------------------------------------------------------------
|
||||
# Train the model
|
||||
with Timer("[Main Thread] 🗡️ Training"):
|
||||
metrics_list: List[dict[str, float]] = ray.get(
|
||||
[
|
||||
policy_group.models[i].train.remote(
|
||||
**collated_data[i],
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
num_mini_batches=args.num_mini_batches,
|
||||
)
|
||||
for i in range(args.world_size)
|
||||
]
|
||||
)
|
||||
average_metrics = {k: sum(m[k] for m in metrics_list) / len(metrics_list) for k in metrics_list[0]}
|
||||
metrics = {
|
||||
"episode": episode,
|
||||
"training_step": training_step,
|
||||
"val/num_total_tokens": num_total_tokens,
|
||||
"epoch": episode / len(reasoning_gym_dataset),
|
||||
"tokens_per_second": num_total_tokens / (time.time() - start_time),
|
||||
**data_thread_metrics,
|
||||
**average_metrics,
|
||||
}
|
||||
|
||||
if args.save_freq > 0 and training_step % args.save_freq == 0:
|
||||
with Timer("[Main Thread] 🗡️ Saving model"):
|
||||
checkpoint_dir = f"{args.output_dir}_checkpoints"
|
||||
step_dir = os.path.join(checkpoint_dir, f"step_{training_step}")
|
||||
print(f"Saving model at step {training_step} to {step_dir}")
|
||||
ray.get([policy_group.models[i].save_model.remote(step_dir) for i in range(args.world_size)])
|
||||
|
||||
print(f"Saving final model at step {training_step} to {args.output_dir}")
|
||||
with Timer("[Main Thread] 🗡️ Saving model"):
|
||||
ray.get([policy_group.models[i].save_model.remote(args.output_dir) for i in range(args.world_size)])
|
||||
|
||||
thread.join()
|
||||
print("======== ✅ vllm generate thread ends =========")
|
||||
packing_thread.join()
|
||||
print("======== ✅ data preparation thread ends =========")
|
||||
ray.shutdown()
|
||||
|
||||
if (
|
||||
args.try_auto_save_to_beaker
|
||||
and is_beaker_job()
|
||||
and len(beaker_config.beaker_dataset_id_urls) > 0
|
||||
and args.output_dir.rstrip("/") != "/output"
|
||||
):
|
||||
shutil.copytree(args.output_dir, "/output", dirs_exist_ok=True)
|
||||
print("finished training")
|
||||
|
||||
accelerator = Namespace()
|
||||
accelerator.is_main_process = True # hack
|
||||
if args.push_to_hub:
|
||||
print("Pushing model to hub")
|
||||
push_folder_to_hub(
|
||||
accelerator,
|
||||
args.output_dir,
|
||||
args.hf_repo_id,
|
||||
args.hf_repo_revision,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = ArgumentParserPlus((GRPOArgs, ModelConfig))
|
||||
args, model_config = parser.parse_args_into_dataclasses()
|
||||
assert isinstance(args, GRPOArgs)
|
||||
assert isinstance(model_config, ModelConfig)
|
||||
main(args, model_config, reward_fn)
|
||||
|
|
@ -1,104 +0,0 @@
|
|||
from typing import List
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
import reasoning_gym
|
||||
from reasoning_gym.utils import SYSTEM_PROMPTS
|
||||
|
||||
|
||||
def list_preserving_collate(batch):
|
||||
"""
|
||||
Custom collate function that preserves lists instead of converting to tensors.
|
||||
"""
|
||||
token_ids = [item[0] for item in batch]
|
||||
items = [item[1] for item in batch]
|
||||
return token_ids, items
|
||||
|
||||
|
||||
class ReasoningGymDataset(Dataset):
|
||||
def __init__(self, dataset_name, seed, size, tokenizer, developer_role, developer_prompt):
|
||||
self.data = reasoning_gym.create_dataset(dataset_name, seed=seed, size=size)
|
||||
self.tokenizer = tokenizer
|
||||
self.developer_role = developer_role
|
||||
self.developer_prompt = developer_prompt
|
||||
|
||||
def __len__(self):
|
||||
return self.data.size
|
||||
|
||||
def __getitem__(self, index):
|
||||
chat_message = [{"role": self.developer_role, "content": self.developer_prompt}]
|
||||
item = self.data[index]
|
||||
chat_message.append({"role": "user", "content": item["question"]})
|
||||
prompt_text = self.tokenizer.apply_chat_template(chat_message, tokenize=True, add_generation_prompt=True)
|
||||
prompt_text = [token for token in prompt_text if token != self.tokenizer.pad_token_id]
|
||||
return prompt_text, item
|
||||
|
||||
|
||||
def pack_sequences(
|
||||
queries: List[List[int]],
|
||||
responses: List[List[int]],
|
||||
pack_length: int,
|
||||
pad_token_id: int,
|
||||
) -> "PackedSequences":
|
||||
# assert padding token does not exist in queries and responses
|
||||
query_responses = []
|
||||
attention_masks = []
|
||||
response_masks = []
|
||||
num_actions = []
|
||||
packed_seq_lens = []
|
||||
cur_data = []
|
||||
cur_response_mask = []
|
||||
cur_num_actions = []
|
||||
cur_packed_seq_lens = []
|
||||
cur_attention_mask = []
|
||||
offset = 0
|
||||
|
||||
for i in range(len(queries)):
|
||||
query = queries[i]
|
||||
response = responses[i]
|
||||
# remove padding (but using vllm so this should not be needed, but just in case)
|
||||
query = [t for t in query if t != pad_token_id]
|
||||
response = [t for t in response if t != pad_token_id]
|
||||
query_response = query + response
|
||||
|
||||
if len(query_response) + len(cur_data) > pack_length:
|
||||
query_responses.append(cur_data)
|
||||
response_masks.append(cur_response_mask)
|
||||
attention_masks.append(cur_attention_mask)
|
||||
num_actions.append(cur_num_actions)
|
||||
packed_seq_lens.append(cur_packed_seq_lens)
|
||||
cur_data = []
|
||||
cur_response_mask = []
|
||||
cur_attention_mask = []
|
||||
cur_num_actions = []
|
||||
cur_packed_seq_lens = []
|
||||
offset = i
|
||||
|
||||
cur_data.extend(query_response)
|
||||
cur_num_actions.append(len(response))
|
||||
cur_packed_seq_lens.append(len(query_response))
|
||||
cur_response_mask.extend([0 for _ in range(len(query))] + [i + 1 for _ in range(len(response))])
|
||||
cur_attention_mask.extend([i + 1 - offset for _ in range(len(query_response))])
|
||||
|
||||
if len(cur_data) > 0:
|
||||
query_responses.append(cur_data)
|
||||
response_masks.append(cur_response_mask)
|
||||
attention_masks.append(cur_attention_mask)
|
||||
num_actions.append(cur_num_actions)
|
||||
packed_seq_lens.append(cur_packed_seq_lens)
|
||||
|
||||
attention_masks_list = [torch.tensor(t) for t in attention_masks]
|
||||
|
||||
from open_instruct.rl_utils2 import PackedSequences, reset_position_ids
|
||||
|
||||
return PackedSequences(
|
||||
query_responses=[torch.tensor(t) for t in query_responses],
|
||||
attention_masks=attention_masks_list,
|
||||
position_ids=[reset_position_ids(t.unsqueeze(0)).squeeze(0) for t in attention_masks_list],
|
||||
response_masks=[torch.tensor(t) for t in response_masks],
|
||||
original_responses=responses,
|
||||
num_actions=[torch.tensor(t) for t in num_actions],
|
||||
packed_seq_lens=[torch.tensor(t) for t in packed_seq_lens],
|
||||
)
|
||||
|
|
@ -1,15 +1,10 @@
|
|||
# Reasoning Gym configs
|
||||
dataset_size: 20000
|
||||
dataset_size: 10000
|
||||
developer_prompt: DeepSeekZero
|
||||
developer_role: system
|
||||
datasets:
|
||||
simple_equations:
|
||||
chain_sum:
|
||||
weight: 1
|
||||
complex_arithmetic:
|
||||
weight: 1
|
||||
config:
|
||||
min_real: -20
|
||||
max_real: 20
|
||||
|
||||
|
||||
# Model configs from trl
|
||||
|
|
@ -34,9 +29,9 @@ lr_scheduler_type: constant_with_warmup
|
|||
lr_scheduler_kwargs:
|
||||
num_warmup_steps: 10
|
||||
max_prompt_length: 512
|
||||
max_completion_length: 2048
|
||||
max_completion_length: 512
|
||||
max_steps: 100
|
||||
num_generations: 8
|
||||
num_generations: 16
|
||||
num_train_epochs: 1
|
||||
overwrite_output_dir: true
|
||||
per_device_train_batch_size: 8
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ class DatasetConfigItem:
|
|||
|
||||
@dataclass
|
||||
class DatasetConfig:
|
||||
dataset_size: int = field(default=1000)
|
||||
dataset_size: int = field(default=10000)
|
||||
developer_prompt: str = field(default="DeepSeekZero")
|
||||
developer_role: str = field(default="system")
|
||||
datasets: dict[str, DatasetConfigItem] = field(default=None)
|
||||
|
|
|
|||
30
examples/unsloth/README.md
Normal file
30
examples/unsloth/README.md
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
# Chain Sum LORA Training with unsloth
|
||||
|
||||
This example demonstrates how to fine-tune an LLM with RL on a reasoning gym environment using the **unsloth** framework. Unsloth is a efficient open-source library for fine-tuning & RL. Unsloths default training path uses quantised low rank adaption (QLORA) which results in a signficantly lower memory footprint ($\approx 3x$) and means you can significantly increase batch sizes and context length without risking OOM errors.
|
||||
|
||||
Requirements:
|
||||
|
||||
python >= 3.10
|
||||
|
||||
## Installation
|
||||
|
||||
1. **Install reasoning-gym**:
|
||||
```bash
|
||||
pip install reasoning-gym
|
||||
```
|
||||
2. **Install unsloth dependencies**:
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
3. **Run training script**
|
||||
To start training with unsloth with RG environments using default arguments run the following:
|
||||
|
||||
```bash
|
||||
python train_grpo_lora.py
|
||||
```
|
||||
|
||||
To customise/override any default arguments you can simply:
|
||||
```bash
|
||||
python train_grpo_lora.py --dataset-name chain_sum --max-seq-length 512 --model-id Qwen/Qwen2.5-7B-Instruct
|
||||
|
||||
**Note** the free open-source version of unsloth is currently built to train models in single GPU environments only.
|
||||
|
|
@ -147,22 +147,24 @@ def main(args):
|
|||
training_args = GRPOConfig(
|
||||
output_dir="outputs",
|
||||
use_vllm=True,
|
||||
learning_rate=5e-6,
|
||||
learning_rate=1e-6,
|
||||
adam_beta1=0.9,
|
||||
adam_beta2=0.99,
|
||||
weight_decay=0.1,
|
||||
warmup_ratio=0.1,
|
||||
lr_scheduler_type="cosine",
|
||||
weight_decay=0.0,
|
||||
warmup_ratio=0.0,
|
||||
lr_scheduler_type="constant",
|
||||
optim="adamw_8bit",
|
||||
logging_steps=1,
|
||||
bf16=is_bfloat16_supported(),
|
||||
fp16=not is_bfloat16_supported(),
|
||||
per_device_train_batch_size=args.train_batch_size,
|
||||
gradient_accumulation_steps=1,
|
||||
gradient_accumulation_steps=4,
|
||||
num_generations=args.num_generations,
|
||||
num_train_epochs=args.train_epochs,
|
||||
max_prompt_length=512,
|
||||
max_completion_length=512,
|
||||
save_steps=100,
|
||||
max_grad_norm=0.1,
|
||||
max_grad_norm=1.0,
|
||||
)
|
||||
|
||||
train(model, tokenizer, dataset, training_args)
|
||||
|
|
@ -187,17 +189,17 @@ if __name__ == "__main__":
|
|||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--model-id", type=str, default="Qwen/Qwen2.5-1.5B-Instruct")
|
||||
parser.add_argument("--dataset-name", type=str)
|
||||
parser.add_argument("--dataset-name", type=str, default="chain_sum")
|
||||
|
||||
parser.add_argument("--max-seq-length", type=int, default=1024)
|
||||
parser.add_argument("--lora-rank", type=int, default=64)
|
||||
parser.add_argument("--quantize", action="store_true")
|
||||
parser.add_argument("--num-generations", type=int, default=8)
|
||||
parser.add_argument("--num-generations", type=int, default=16)
|
||||
parser.add_argument("--train-epochs", type=int, default=1)
|
||||
parser.add_argument("--train-batch-size", type=int, default=8)
|
||||
parser.add_argument("--train-batch-size", type=int, default=16)
|
||||
|
||||
parser.add_argument("--dataset-seed", type=int, default=42)
|
||||
parser.add_argument("--dataset-size", type=int, default=1000)
|
||||
parser.add_argument("--dataset-size", type=int, default=10000)
|
||||
|
||||
parser.add_argument("--eval-seed", type=int, default=42)
|
||||
parser.add_argument("--eval-size", type=int, default=100)
|
||||
|
|
|
|||
3
examples/veRL/.gitignore
vendored
3
examples/veRL/.gitignore
vendored
|
|
@ -1,3 +0,0 @@
|
|||
outputs/
|
||||
wandb/
|
||||
verl_output.log
|
||||
|
|
@ -1,32 +1,74 @@
|
|||
### env setup
|
||||
# Chain Sum Training with veRL
|
||||
|
||||
```
|
||||
conda create --name verl python=3.11 -y
|
||||
conda activate verl
|
||||
This example demonstrates how to train a language model using veRL (Volcano Engine Reinforcement Learning) with the reasoning-gym environment for chain sum problems.
|
||||
|
||||
pip install flash-attn --no-build-isolation
|
||||
pip install ray wandb
|
||||
# pip3 install vllm==0.7.0
|
||||
pip3 install vllm --pre --extra-index-url https://wheels.vllm.ai/nightly
|
||||
Requirements:
|
||||
|
||||
python >= 3.10
|
||||
|
||||
## Installation
|
||||
|
||||
1. **Install veRL**: Follow the installation instructions at [veRL repository](https://github.com/volcengine/verl)
|
||||
|
||||
2. **Install reasoning-gym**:
|
||||
```bash
|
||||
pip install reasoning-gym
|
||||
```
|
||||
|
||||
Regarding vllm>0.7 see: [docs](https://verl.readthedocs.io/en/latest/README_vllm0.7.html)
|
||||
## Training
|
||||
|
||||
To start training the model on chain sum problems:
|
||||
|
||||
### clone and install veRL
|
||||
|
||||
tested with verl HEAD c34206925e2a50fd452e474db857b4d488f8602d
|
||||
|
||||
```
|
||||
git clone https://github.com/volcengine/verl.git
|
||||
cd verl
|
||||
pip install -e .
|
||||
```bash
|
||||
python grpo_train.py --config-path config --config-name grpo_trainer
|
||||
```
|
||||
|
||||
### Configuration
|
||||
|
||||
Optionally log in to huggingface hub and wandb with your keys:
|
||||
You can modify the training by editing the configuration file or overriding arguments in the shell scripts directly
|
||||
|
||||
```bash
|
||||
# Change dataset
|
||||
Here it is easiest to modify the `config/grpo_trainer.yaml` file with a custom training composite. Here is an example experiment which uses a composite of algorithmic training tasks
|
||||
```yaml
|
||||
reasoning_gym:
|
||||
dataset_size: 20000
|
||||
developer_prompt: DeepSeekZero
|
||||
datasets:
|
||||
ab:
|
||||
weight: 1
|
||||
base_conversion:
|
||||
weight: 1
|
||||
binary_alternation:
|
||||
weight: 1
|
||||
config:
|
||||
p_solvable: 0.9
|
||||
binary_matrix:
|
||||
weight: 1
|
||||
config:
|
||||
min_n: 2
|
||||
max_n: 6
|
||||
caesar_cipher:
|
||||
weight: 1
|
||||
config:
|
||||
max_words: 10
|
||||
cryptarithm:
|
||||
weight: 1
|
||||
isomorphic_strings:
|
||||
weight: 1
|
||||
config:
|
||||
max_string_length: 8
|
||||
```
|
||||
huggingface-cli login
|
||||
wandb login
|
||||
```
|
||||
|
||||
**Note**: In `config/grpo_trainer.yaml` we specify default arguments to be read from `file:///workspace/verl/verl/trainer/config`. Modify this accordingly if you have a different local folder setup.
|
||||
|
||||
# Change configuration Set project_name and experiment_name if logging your runs to W&B. T
|
||||
This config assumes a single GPU node, but you can configure this too. The following command would be for 2 GPUs, with 1 used for vLLM rollouts:
|
||||
|
||||
python3 -u train_grpo.py --config-paths configs/inter_generalisation --config-name algorithmic_qwen_3b \
|
||||
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
|
||||
trainer.n_gpus_per_node=2 \
|
||||
trainer.project_name=rg-grpo \
|
||||
trainer.experiment_name=algorithmic_qwen2.5_3b
|
||||
|
||||
Or similarly you could define this in a config file directly
|
||||
|
|
|
|||
|
|
@ -1,12 +1,15 @@
|
|||
hydra:
|
||||
searchpath:
|
||||
- file:///workspace/verl/verl/trainer/config
|
||||
defaults:
|
||||
- ppo_trainer
|
||||
- _self_
|
||||
|
||||
reasoning_gym:
|
||||
dataset_size: 20000
|
||||
dataset_size: 10000
|
||||
developer_prompt: DeepSeekZero
|
||||
datasets:
|
||||
ab:
|
||||
chain_sum:
|
||||
weight: 1
|
||||
|
||||
data:
|
||||
|
|
@ -16,7 +19,7 @@ data:
|
|||
prompt_key: prompt
|
||||
max_prompt_length: 512
|
||||
max_response_length: 512
|
||||
train_batch_size: 16
|
||||
train_batch_size: 64
|
||||
val_batch_size: 1
|
||||
|
||||
actor_rollout_ref:
|
||||
|
|
@ -32,7 +35,7 @@ actor_rollout_ref:
|
|||
strategy: fsdp # This is for backward-compatibility
|
||||
ppo_mini_batch_size: 16
|
||||
ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
|
||||
ppo_micro_batch_size_per_gpu: 4
|
||||
ppo_micro_batch_size_per_gpu: 8
|
||||
use_dynamic_bsz: False
|
||||
ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
|
||||
grad_clip: 1.0
|
||||
|
|
@ -70,7 +73,7 @@ actor_rollout_ref:
|
|||
# transformer_layer_cls_to_wrap: None
|
||||
min_num_params: 0
|
||||
log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
|
||||
log_prob_micro_batch_size_per_gpu: 4
|
||||
log_prob_micro_batch_size_per_gpu: 8
|
||||
log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
|
||||
log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
|
||||
ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
|
||||
|
|
@ -90,7 +93,7 @@ actor_rollout_ref:
|
|||
enforce_eager: True
|
||||
free_cache_engine: True
|
||||
load_format: dummy_dtensor
|
||||
tensor_model_parallel_size: 1
|
||||
tensor_model_parallel_size: 2
|
||||
max_num_batched_tokens: 8192
|
||||
max_num_seqs: 1024
|
||||
log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
|
||||
|
|
@ -184,14 +187,14 @@ algorithm:
|
|||
|
||||
trainer:
|
||||
balance_batch: True
|
||||
total_epochs: 30
|
||||
total_epochs: 1
|
||||
total_training_steps: null
|
||||
project_name: verl_examples
|
||||
experiment_name: chain_sum
|
||||
logger: [ 'console', 'wandb' ]
|
||||
val_generations_to_log_to_wandb: 0
|
||||
nnodes: 1
|
||||
n_gpus_per_node: 1
|
||||
n_gpus_per_node: 2
|
||||
save_freq: -1
|
||||
# auto: find the last ckpt to resume. If can't find, start from scratch
|
||||
resume_mode: auto # or auto or resume_path if
|
||||
|
|
@ -1,72 +0,0 @@
|
|||
# Chain Sum Training with veRL
|
||||
|
||||
This example demonstrates how to train a language model using veRL (Volcano Engine Reinforcement Learning) with the reasoning-gym environment for chain sum problems.
|
||||
|
||||
Requirements:
|
||||
|
||||
python >= 3.10
|
||||
|
||||
## Installation
|
||||
|
||||
1. **Install veRL**: Follow the installation instructions at [veRL repository](https://github.com/volcengine/verl)
|
||||
|
||||
2. **Install reasoning-gym**:
|
||||
```bash
|
||||
pip install reasoning-gym
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
To start training the model on chain sum problems:
|
||||
|
||||
```bash
|
||||
python grpo_train.py --config-path config --config-name grpo_trainer
|
||||
```
|
||||
|
||||
### Configuration
|
||||
|
||||
You can modify the training by editing the configuration file or overriding arguments in the shell scripts directly
|
||||
|
||||
```bash
|
||||
# Change dataset
|
||||
Here it is easiest to modify the `config/grpo_trainer.yaml` file with a custom training composite. Here is an example experiment which uses a composite of algorithmic training tasks
|
||||
```yaml
|
||||
reasoning_gym:
|
||||
dataset_size: 20000
|
||||
developer_prompt: DeepSeekZero
|
||||
datasets:
|
||||
ab:
|
||||
weight: 1
|
||||
base_conversion:
|
||||
weight: 1
|
||||
binary_alternation:
|
||||
weight: 1
|
||||
config:
|
||||
p_solvable: 0.9
|
||||
binary_matrix:
|
||||
weight: 1
|
||||
config:
|
||||
min_n: 2
|
||||
max_n: 6
|
||||
caesar_cipher:
|
||||
weight: 1
|
||||
config:
|
||||
max_words: 10
|
||||
cryptarithm:
|
||||
weight: 1
|
||||
isomorphic_strings:
|
||||
weight: 1
|
||||
config:
|
||||
max_string_length: 8
|
||||
```
|
||||
|
||||
# Change configuration Set project_name and experiment_name if logging your runs to W&B. T
|
||||
This config assumes a single GPU node, but you can configure this too. The following command would be for 2 GPUs, with 1 used for vLLM rollouts:
|
||||
|
||||
python3 -u train_grpo.py --config-paths configs/inter_generalisation --config-name algorithmic_qwen_3b \
|
||||
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
|
||||
trainer.n_gpus_per_node=2 \
|
||||
trainer.project_name=rg-grpo \
|
||||
trainer.experiment_name=algorithmic_qwen2.5_3b
|
||||
|
||||
Or similarly you could define this in a config file directly
|
||||
|
|
@ -1,131 +0,0 @@
|
|||
# Word Ladder Puzzle Dataset Generator
|
||||
|
||||
## Overview
|
||||
|
||||
This project generates a dataset of word ladder puzzles and (optionally) submits chain-of-thought reasoning requests using Anthropic's Message Batches API. Each puzzle is stored as a JSON object with the following keys: `question`, `answer`, `metadata`, and `reasoning` (initially `null`).
|
||||
|
||||
The project consists of several key components:
|
||||
|
||||
- **`main.py`**:
|
||||
Orchestrates the overall flow. It performs the following tasks:
|
||||
1. Generates a dataset of word ladder puzzles by calling functions from `utils/create_word_ladders.py`.
|
||||
2. (Optionally) Triggers the reasoning request process to augment puzzles with chain-of-thought reasoning via `utils/generate_reasoning.py`.
|
||||
3. (Planned) Additional steps such as checking results or uploading the final dataset.
|
||||
|
||||
The configuration for the dataset parameters (e.g., word length, chain length, and dataset size) is centralized here, making it easy to adjust the settings as needed.
|
||||
|
||||
- **`utils/create_word_ladders.py`**:
|
||||
Contains functions to create and validate a word ladder dataset. It leverages underlying modules (e.g., `reasoning_gym`) to generate individual puzzles and ensures uniqueness across the dataset.
|
||||
|
||||
- **`utils/generate_reasoning.py`**:
|
||||
Reads the generated dataset (in JSONL format), then filters out puzzles that already have reasoning. For puzzles missing chain-of-thought data, it splits them into batches (with a default batch size that you can adjust) and submits each batch to Anthropic's Message Batches API. Each API request includes the puzzle along with a custom system prompt (read from `system_prompt.txt`), and the resulting metadata is stored for later retrieval and analysis.
|
||||
|
||||
- **`usage_stats.py`**:
|
||||
Analyzes API response files to compute detailed usage statistics. This script:
|
||||
- Extracts token usage metrics such as `input_tokens`, `cache_creation_input_tokens`, `cache_read_input_tokens`, and `output_tokens`.
|
||||
- Calculates costs based on pricing data and shows the savings achieved through prompt caching.
|
||||
- Forecasts costs for various quantities of jobs (e.g., 2,000, 4,000, 10,000, 20,000, and 50,000 puzzles) using the observed average token usage.
|
||||
This is especially useful for monitoring your API spend and ensuring that your usage stays within budget.
|
||||
|
||||
## Warning
|
||||
|
||||
**Caution:**
|
||||
Running large batches of requests via the Anthropic API (especially in `generate_reasoning.py`) can incur significant costs in Anthropic credits. **Please review and understand your API quota and budgeting before running the API call.** If you are just testing or working with a demo dataset, ensure you adjust the batch size or dataset size appropriately to avoid unexpected charges.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- **Python Version:** Python 3.7+
|
||||
- **Dependencies:**
|
||||
- `tqdm`
|
||||
- `anthropic`
|
||||
- `reasoning_gym`
|
||||
- **Environment Variables:**
|
||||
For generating reasoning batches, set your Anthropic API key:
|
||||
```bash
|
||||
export ANTROPIC_API_KEY=your_api_key_here
|
||||
```
|
||||
|
||||
## Directory Structure
|
||||
|
||||
```
|
||||
examples/word_ladder/
|
||||
├── main.py
|
||||
├── utils/
|
||||
│ ├── create_word_ladders.py
|
||||
│ ├── generate_reasoning.py
|
||||
│ └── system_prompt.txt
|
||||
├── usage_stats.py
|
||||
```
|
||||
|
||||
|
||||
## Configuration
|
||||
|
||||
The dataset generation parameters are centralized in `main.py` under the `config` dictionary. You can adjust settings like:
|
||||
|
||||
- **Word Length:**
|
||||
- `min_word_length`
|
||||
- `max_word_length`
|
||||
|
||||
- **Chain Length:**
|
||||
- `min_chain_length` (e.g., set to -1 for the shortest possible chain)
|
||||
- `max_chain_length`
|
||||
|
||||
- **Dataset Size:**
|
||||
- `size` — the number of puzzles to generate (e.g., `1000` for a demo)
|
||||
|
||||
## How to Run
|
||||
|
||||
1. **Generate the Dataset**
|
||||
|
||||
Run the main script:
|
||||
```bash
|
||||
python3 main.py
|
||||
```
|
||||
This does the following:
|
||||
- Generates a unique JSONL file containing the word ladder puzzles in the `output` folder.
|
||||
- Calls functions from `utils/create_word_ladders.py` to create the puzzles.
|
||||
- Optionally (if enabled), submits the puzzles for chain-of-thought reasoning via the API.
|
||||
|
||||
2. **Submit Reasoning Batches (Optional)**
|
||||
|
||||
To generate chain-of-thought reasoning for puzzles:
|
||||
- Verify that `ANTHROPIC_API_KEY` is set.
|
||||
- Confirm that `system_prompt.txt` is present in the `/examples/word_ladder` folder and contains the desired system prompt.
|
||||
- In `main.py`, uncomment the reasoning submission section to enable the API call, or run directly:
|
||||
```bash
|
||||
python3 utils/generate_reasoning.py
|
||||
```
|
||||
|
||||
**Warning:** Be aware that submitting large batches can quickly incur high costs in Anthropic credits.
|
||||
|
||||
3. **Compute Usage Statistics**
|
||||
|
||||
After running batches through the API, you can analyze the cost and token usage statistics with:
|
||||
```bash
|
||||
python3 usage_stats.py path/to/msgbatch_results.jsonl
|
||||
```
|
||||
This script provides detailed costing information, token usage per query, savings from caching, and forecasting for future job batches.
|
||||
|
||||
## Output
|
||||
|
||||
- All generated datasets and batch metadata files are stored in the `/examples/word_ladder/output` folder.
|
||||
- After submitting reasoning batches via Anthropic's API, you can monitor progress and download the batch results from the Anthropic web dashboard.
|
||||
- Use `usage_stats.py` to compute detailed statistics and forecast future costs based on your current usage and token pricing.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
- **File Paths:**
|
||||
Verify that `system_prompt.txt` is in the `/examples/word_ladder` folder as expected. The modules use paths relative to their location.
|
||||
|
||||
- **Environment Variables:**
|
||||
Make sure your `ANTHROPIC_API_KEY` is set correctly when submitting API requests.
|
||||
|
||||
- **Output Directory Permissions:**
|
||||
Ensure the `output` directory exists and is writable by your user.
|
||||
|
||||
- **Cost Monitoring:**
|
||||
Check your Anthropic API usage and account balance before running large batches to avoid unexpected costs.
|
||||
|
||||
## License
|
||||
|
||||
This project is licensed under the MIT License.
|
||||
|
|
@ -1,84 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
main.py – Orchestrates the overall flow:
|
||||
1. Generate word ladder sets
|
||||
2. Submit chain-of-thought reasoning requests in batches via the LLM
|
||||
3. Upload the final dataset to HuggingFace Hub (if needed)
|
||||
"""
|
||||
|
||||
import sys
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from examples.word_ladder.utils import create_word_ladders, generate_reasoning
|
||||
|
||||
|
||||
def create_dataset(jsonl_path: Path, config: dict[str, Any]) -> bool:
|
||||
"""
|
||||
Creates the word ladder dataset, handling potential exhaustion gracefully.
|
||||
|
||||
Returns:
|
||||
bool: True if dataset was created (even if truncated), False if creation failed
|
||||
"""
|
||||
try:
|
||||
print("Step 1: Algorithmically creating word ladder chains...")
|
||||
create_word_ladders.create_word_ladder_dataset(str(jsonl_path), config=config)
|
||||
return True
|
||||
|
||||
except IndexError as e:
|
||||
# Dataset was exhausted but some examples were generated
|
||||
print("\nNote: Dataset generation stopped early due to exhaustion of unique puzzles.")
|
||||
print(f"Reason: {str(e)}")
|
||||
if jsonl_path.exists():
|
||||
print("Continuing with the partial dataset that was successfully generated.")
|
||||
return True
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
# Unexpected error during dataset creation
|
||||
print(f"\nError: Failed to create dataset: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
def main():
|
||||
# Centralized configuration for the dataset
|
||||
config = {
|
||||
"dataset_name": "word_ladder",
|
||||
"dataset_config": {
|
||||
"min_word_length": 3,
|
||||
"max_word_length": 3,
|
||||
"min_chain_length": -1, # set to -1 for the shortest possible path
|
||||
"max_chain_length": 7,
|
||||
"size": 2000, # Generate a small-ish dataset for demonstration
|
||||
},
|
||||
}
|
||||
|
||||
# Generate a friendly unique identifier and compose the file path
|
||||
unique_id = uuid.uuid4().hex[:8]
|
||||
output_dir = Path(__file__).resolve().parent / "output"
|
||||
output_dir.mkdir(exist_ok=True) # Create output directory if it doesn't exist
|
||||
jsonl_path = output_dir / f"word_ladders_{unique_id}.jsonl"
|
||||
|
||||
# Step 1: Create the dataset
|
||||
if not create_dataset(jsonl_path, config):
|
||||
print("Exiting due to dataset creation failure.")
|
||||
sys.exit(1)
|
||||
|
||||
# Step 2: Generate reasoning
|
||||
|
||||
try:
|
||||
print("\nStep 2: Submitting reasoning batches for the dataset...")
|
||||
generate_reasoning.submit_reasoning_batches(input_path=str(jsonl_path))
|
||||
except Exception as e:
|
||||
print(f"\nError: Failed to submit reasoning batches: {str(e)}")
|
||||
sys.exit(1)
|
||||
|
||||
# Step 3: Check Anthropic batch results
|
||||
# Step 4: Upload to HuggingFace 🤗
|
||||
|
||||
print("\nComplete!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1 +0,0 @@
|
|||
{"batch_id": "msgbatch_01J7Dev6petQwq41gNpeGhUK", "api_response": {"id": "msgbatch_01J7Dev6petQwq41gNpeGhUK", "type": "message_batch", "processing_status": "in_progress", "request_counts": {"canceled": 0, "errored": 0, "expired": 0, "processing": 1000, "succeeded": 0}, "ended_at": null, "created_at": "2025-02-03T06:07:04.440375Z", "expires_at": "2025-02-04T06:07:04.440375Z", "cancel_initiated_at": null, "results_url": null}, "custom_ids": ["PAINT_LACKS_0", "CONKS_GRAYS_1", "WINGS_WORMY_2", "PINTO_RASPS_3", "GRACE_WAINS_4", "RACES_VAPES_5", "LAIRS_FUELS_6", "WINKS_TENTH_7", "VERBS_DATES_8", "SCENT_HEARS_9", "GRACE_CHIVE_10", "COOPS_SPARE_11", "ROOMY_PLUME_12", "PORKS_HOSTA_13", "DEARS_TABBY_14", "SNAIL_ADOBE_15", "SMALL_BLINK_16", "SOAPS_SPOTS_17", "HOMES_LOVEY_18", "SPURN_STEEP_19", "PLODS_SWIFT_20", "TACOS_SETTS_21", "BANKS_KICKY_22", "POUTS_CAKES_23", "RESTS_SWUNG_24", "PORKY_FISKS_25", "PUMAS_NAPES_26", "WANDS_PRICE_27", "DROLL_SORES_28", "JOWLS_PACTS_29", "ATONE_GOOFS_30", "BRIGS_SHUNT_31", "RAPID_DOXES_32", "CHESS_DOLED_33", "BUDGE_BOOBS_34", "ROPED_SAGAS_35", "TATTS_SPARE_36", "RALES_NATES_37", "RUFFS_GONGS_38", "VEGES_HUNTS_39", "SCULL_PAIRS_40", "DICKY_MOWED_41", "PANKO_ADDER_42", "TILLS_GAINS_43", "OASIS_DOXED_44", "BONED_MEALS_45", "LAMBS_TOADY_641", "PIKER_HEAVY_642", "BANKS_TRUGS_643", "VEEPS_DIVAS_644", "CHATS_CLEFS_645", "SNITS_PLAIN_646", "CHARY_SPITE_647", "EARLY_DUMMY_977", "MEDAL_COYER_978", "FEATS_MOOKS_979", "MALES_HEXES_980", "DOGGO_MUGGY_981", "GIRDS_SOAPS_982", "SLASH_TRYST_983", "LOOPS_FLOES_984", "SLATE_SMORE_985", "SINEW_SULKS_986", "SCORN_SHOUT_987", "MAMAS_VOIDS_988", "TWIGS_BLOKE_989", "KOOKY_SPINE_990", "FLUES_SMITE_991", "DRAWS_MAIMS_992", "SMITH_PLEAS_993", "BERGS_OPINE_994", "SNAKE_BLABS_995", "WADED_DRIBS_996", "YAPPY_PIPES_997", "STAGY_COONS_998", "PEAKY_MURAL_999"], "input_file": "word_ladders_029573f4.jsonl"}
|
||||
|
|
@ -1,5 +0,0 @@
|
|||
{"custom_id":"PAINT_LACKS_0","result":{"type":"succeeded","message":{"id":"msg_01Kif1LqT19GgPdSEQiuFjwu","type":"message","role":"assistant","model":"claude-3-5-sonnet-20241022","content":[{"type":"text","text":"<think>\nOkay, let me work on transforming PAINT into LACKS by changing one letter at a time. First, let me analyze both words to understand what needs to change.\n\nPAINT → LACKS\n\nLooking at the letters:\nP needs to become L\nA stays A\nI needs to become C\nN needs to become K\nT needs to become S\n\nSo four letters need to change (all except A). Let me start exploring possible paths by changing one letter at a time.\n\nStarting with PAINT. What changes can I make? Let's see:\n- Change P to something: FAINT, SAINT, TAINT (all valid)\n- Change I to something: PANT (valid)\n- Change N to something: PAINT to PAIT (not valid)\n- Change T to something: PAIN (valid, but loses a letter... no, need to keep 5 letters)\nWait, what about PAINS (changing T to S)? Yes, PAINS is a valid word! \n\nSo first step: PAINT → PAINS (changing T to S)\n\nNow from PAINS, what can I do? Need to eventually get to LACKS. Let's see possible changes:\n- Change P to something: GAINS, MAINS, RAINS, WAINS\n- Change I to something: PANS (loses a letter, not allowed)\n- Change N to something: PAIRS (valid, but moves away from target LACKS)\n\nWAINS looks interesting (changing P to W). Is WAINS a valid word? Yes, it refers to wagons or carts historically. \n\nStep 2: PAINS → WAINS\n\nNow from WAINS, what's next? Looking at LACKS, we still need to get that C and K in there. Let's see:\n- Change A to something: WINS (loses letter)\n- Change I to something: WANS (loses letter)\n- Change N to something: WAITS (valid but moves away from target)\n- Change S to something: WAINT (not a word)\nWait, what about WARNS (changing I to R)? Yes! WARNS is definitely a valid word.\n\nStep 3: WAINS → WARNS\n\nNow from WARNS to LACKS. Still need L, C, and K. From WARNS:\n- Change W to something: BARNS, EARNS, YARNS\nBARNS looks promising! Because B is closer to L in the alphabet than W. \n\nStep 4: WARNS → BARNS\n\nFrom BARNS to LACKS. Let's see:\n- Change N to something: BARS (loses letter)\n- Change S to something: BARN (loses S we need)\n- Change R to K perhaps? BANKS? No, need to keep R for now.\nWait! What about BARKS (changing N to K)? Yes! That works!\n\nStep 5: BARNS → BARKS\n\nNow BARKS. Getting closer! We need to change R to C and B to L to get to LACKS. Let's try:\n- Change B to something: LARKS, MARKS, PARKS\n- Change R to C: BACKS! Yes, that works!\n\nStep 6: BARKS → BACKS\n\nFinally, from BACKS to LACKS, just need to change B to L:\n\nStep 7: BACKS → LACKS\n\nLet me verify the entire sequence:\nPAINT → PAINS (T→S)\nPAINS → WAINS (P→W)\nWAINS → WARNS (I→R)\nWARNS → BARNS (W→B)\nBARNS → BARKS (N→K)\nBARKS → BACKS (R→C)\nBACKS → LACKS (B→L)\n\nChecking each word:\nPAINT: valid\nPAINS: valid (plural of pain)\nWAINS: valid (old wagons)\nWARNS: valid (gives warning)\nBARNS: valid (farm buildings)\nBARKS: valid (tree covering or dog sound)\nBACKS: valid (plural of back)\nLACKS: valid (is without)\n\nEach step changes exactly one letter and creates a valid English word. Seven steps total, which makes sense since we needed to change four letters to get from PAINT to LACKS, and we had to make some strategic intermediate changes to get there. This seems like a valid solution!\n\nLet me double-check if there might be a shorter path... Could we have gone:\nPAINT → PANT (remove I? No, must keep 5 letters)\nPAINT → PARTS (change N to R and T to S? No, that's two changes at once)\nPAINT → PANIC (T to C, but moves away from target)\n\nNo, I think the path we found is optimal because we needed those intermediate steps to gradually transform the word while keeping valid English words at each step. The sequence PAINT→PAINS→WAINS→WARNS→BARNS→BARKS→BACKS→LACKS works perfectly!\n</think>"}],"stop_reason":"end_turn","stop_sequence":null,"usage":{"input_tokens":63,"cache_creation_input_tokens":0,"cache_read_input_tokens":6434,"output_tokens":1149}}}}
|
||||
{"custom_id":"CONKS_GRAYS_1","result":{"type":"succeeded","message":{"id":"msg_01JrxHJyxFxsTVsn8huufv8T","type":"message","role":"assistant","model":"claude-3-5-sonnet-20241022","content":[{"type":"text","text":"<think>\nAlright, let me try to transform CONKS into GRAYS by changing one letter at a time. First, let me write down both words to analyze what needs to change:\n\nCONKS\nGRAYS\n\nLooking at the letters that need to change:\nC needs to become G\nO needs to become R\nN needs to become A\nK needs to become Y\nS stays S\n\nHmm, four letters need to change, but S stays the same at the end. Let me start by looking at possible changes from CONKS.\n\nFirst letter options from CONKS:\n- Change C to B: BONKS (valid)\n- Change C to D: DONKS (valid)\n- Change C to G: GONKS (is this valid? Not sure)\n- Change C to H: HONKS (valid)\n\nSecond letter options:\n- Change O to A: CANKS (is this valid? Not sure)\n- Change O to I: CINKS (probably not valid)\n- Change O to O: stays same\n- Change O to U: CUNKS (not sure if valid)\n\nWait, maybe instead of changing the first letter right away to G, let's try working with the O. What if we change O to another letter that might help us get closer? \n\nLet me try: CONKS → COOKS (changing N to O)\nYes! COOKS is definitely a valid word. That feels like progress.\n\nFrom COOKS, what can we do?\n- Change first C to G? GOOKS (probably not appropriate)\n- Change second O to R? No, need to keep both O's for now\n- Change K to something? COOLS? No, moving away from target\n- Change S? No, need to keep S\n\nWhat about changing one of the O's to something? COOKS → COOPS (changing K to P)\nYes! COOPS is valid (like chicken coops). This feels promising.\n\nNow from COOPS:\n- Change C to G? GOOPS (not sure if valid)\n- Change first O to R? CROPS! Yes, that works!\n\nCROPS is a valid word. Now we're getting somewhere. Let's see what we can do from CROPS:\n- Change O to A? CRAPS! That's valid and gets us closer to GRAYS\n\nFrom CRAPS:\n- Change P to Y? CRAYS! Perfect!\n\nAnd finally, from CRAYS:\n- Change C to G: GRAYS! We've reached the target!\n\nLet me verify this path:\n1. CONKS → COOKS (N to O)\n2. COOKS → COOPS (K to P)\n3. COOPS → CROPS (O to R)\n4. CROPS → CRAPS (O to A)\n5. CRAPS → CRAYS (P to Y)\n6. CRAYS → GRAYS (C to G)\n\nLet me check each word again:\nCONKS - valid (to hit on head)\nCOOKS - valid\nCOOPS - valid (enclosures)\nCROPS - valid\nCRAPS - valid\nCRAYS - valid (types of fish)\nGRAYS - valid\n\nThis seems to work perfectly! Each step changes exactly one letter, and all words are valid English words. Let me double-check if there might be a shorter path.\n\nCould we go CONKS → CONES? No, that doesn't help get to GRAYS.\nWhat about CONKS → CRANKS? No, that's two letter changes.\nOr CONKS → CORKS? Again, doesn't seem to lead anywhere useful.\n\nThe path we found:\nCONKS → COOKS → COOPS → CROPS → CRAPS → CRAYS → GRAYS\nseems optimal because:\n1. Each step is a valid word\n2. Each change moves us closer to our target\n3. We maintain the final S throughout\n4. We gradually transform the word while keeping it valid at each step\n\nYes, I think this is the solution. Six steps, changing one letter at a time, all valid words, and each change gets us closer to our target GRAYS.\n</think>"}],"stop_reason":"end_turn","stop_sequence":null,"usage":{"input_tokens":64,"cache_creation_input_tokens":0,"cache_read_input_tokens":6434,"output_tokens":956}}}}
|
||||
{"custom_id":"WINGS_WORMY_2","result":{"type":"succeeded","message":{"id":"msg_01KXJZXTZgQoTnLvMqHphEcS","type":"message","role":"assistant","model":"claude-3-5-sonnet-20241022","content":[{"type":"text","text":"<think>\nOkay, let me tackle this word ladder from WINGS to WORMY. First, I'll write both words down and analyze what letters need to change:\n\nWINGS\nWORMY\n\nLooking at the letters:\nW stays the same\nI needs to become O\nN needs to become R\nG needs to become M\nS needs to become Y\n\nSo four letters need to change, keeping only the W. Let me start exploring possible paths by changing one letter at a time.\n\nStarting with WINGS, what can I change? Let's look at each position:\n- First letter (W): stays the same, so no need to change\n- Second letter (I): could become A, E, O, U\n- Third letter (N): could become other consonants\n- Fourth letter (G): could become other consonants\n- Fifth letter (S): could become other letters\n\nLet me try changing I first since it needs to become O eventually:\nWINGS → WONGS? Is that a word? Not sure. \n\nMaybe change N first:\nWINGS → WIGS? No, that removes a letter. Wait, I need to keep it 5 letters.\nWINGS → WINGS (same)\nWINGS → WINKS (changing G to K)! Yes, WINKS is definitely a valid word.\n\nOkay, so WINGS → WINKS works. Now from WINKS, what's next? I still need to:\n- Change I to O\n- Change K to R\n- Change S to Y\n\nLet me try changing I to O in WINKS:\nWINKS → WONKS\nIs WONKS a word? I think it is - it's the plural of wonk, someone who studies a subject extensively. Yes, that works!\n\nSo now I have:\nWINGS → WINKS → WONKS\n\nFrom WONKS, I need to:\n- Change K to R\n- Change S to Y\n\nLet me try changing K to R:\nWONKS → WORKS\nPerfect! WORKS is definitely a valid word.\n\nSo now:\nWINGS → WINKS → WONKS → WORKS\n\nFrom WORKS, I still need to:\n- Change S to Y\n\nBut wait, looking at the target WORMY, I notice I need M before Y. So from WORKS, maybe I should change S to M first?\n\nWORKS → WORMS\nYes! WORMS is a valid word.\n\nNow I have:\nWINGS → WINKS → WONKS → WORKS → WORMS\n\nFinally, from WORMS to WORMY, just need to change S to Y:\nWORMS → WORMY\n\nLet me verify the complete path:\n1. WINGS → WINKS (change G to K)\n2. WINKS → WONKS (change I to O)\n3. WONKS → WORKS (change K to R)\n4. WORKS → WORMS (change K to M)\n5. WORMS → WORMY (change S to Y)\n\nLet me check if each word is valid:\n- WINGS: valid (plural of wing)\n- WINKS: valid (plural of wink)\n- WONKS: valid (plural of wonk)\n- WORKS: valid (plural of work)\n- WORMS: valid (plural of worm)\n- WORMY: valid (full of worms)\n\nCould there be a shorter path? Let me think...\nCould I go WINGS → WONGS (I to O)? But is WONGS a valid word? Not sure about that.\nOr WINGS → WRINGS (add R)? No, that would be adding a letter, not changing one.\nOr WINGS → WORMS directly? No, that would require changing multiple letters at once.\n\nThe path I found requires five steps, which makes sense because we needed to change four letters (I→O, N→R, G→M, S→Y) and couldn't combine any changes since we can only change one letter at a time.\n\nSo after exploring alternatives, I believe the path:\nWINGS → WINKS → WONKS → WORKS → WORMS → WORMY\nis optimal, with each step changing exactly one letter and forming a valid English word.\n</think>"}],"stop_reason":"end_turn","stop_sequence":null,"usage":{"input_tokens":58,"cache_creation_input_tokens":0,"cache_read_input_tokens":6434,"output_tokens":962}}}}
|
||||
{"custom_id":"PINTO_RASPS_3","result":{"type":"succeeded","message":{"id":"msg_01GJQuw2NwcJQpx2civyj8Fu","type":"message","role":"assistant","model":"claude-3-5-sonnet-20241022","content":[{"type":"text","text":"<think>\nAlright, I need to transform PINTO into RASPS by changing one letter at a time. Let me analyze both words first to understand what needs to change.\n\nPINTO\nRASPS\n\nLooking at the letters:\nP -> R\nI -> A\nN -> S\nT -> P\nO -> S\n\nHmm, all five letters need to change. This might be tricky. Let me start by exploring possible changes from PINTO.\n\nFirst letter (P) possibilities:\n- MINTO (is that a word? Maybe a proper noun, but that's not allowed)\n- RINTO (not a word)\n- TINTO (might be Spanish, not valid English)\n\nSecond letter (I) possibilities:\n- PANTO (yes, that's valid!)\n- PENTO (not sure if that's valid)\n- PUNTO (Spanish, not English)\n\nThird letter (N) possibilities:\n- PINTO -> PINTS (changing O to S)\nThat's interesting! PINTS is definitely valid. Maybe that's a good first step since we need to end with 'S' anyway.\n\nLet me try PINTO -> PINTS first.\nNow from PINTS, what can we do? We need to eventually get to RASPS, so maybe changing that I to A would help?\nPINTS -> PANTS (changing I to A)\nYes! PANTS is valid. Now we have:\nPINTO -> PINTS -> PANTS\n\nFrom PANTS, looking at the letters compared to RASPS:\nP -> R\nA (matches!)\nN -> S\nT -> P\nS (matches!)\n\nMaybe we can change that T to R? No, that wouldn't help because we still need P in that position eventually. What about changing N to R? No, we need S there. \n\nLet me try changing T to something... PANTS -> PARTS (changing N to R)\nYes! PARTS is valid. So now:\nPINTO -> PINTS -> PANTS -> PARTS\n\nFrom PARTS to RASPS, we need to:\n- Change P to R at the start\n- Change T to P\n- Keep the S\n\nWait, but if I change P to R now, how would I get the P back later? Maybe I need an intermediate step.\n\nWhat if... PARTS -> WARTS (changing P to W)?\nYes! That's valid. Then from WARTS, maybe... WARPS (changing T to P)?\nWARTS -> WARPS works! \n\nSo now I have:\nPINTO -> PINTS -> PANTS -> PARTS -> WARTS -> WARPS\n\nGetting closer! From WARPS to RASPS, I need to:\n- Change W to R\n- Change P to S in the middle\n\nBut wait, if I change W to R directly, would that work? Let me try another approach...\nWARPS -> WASPS (changing R to S)\nYes! That works. Then finally:\nWASPS -> RASPS (changing W to R)\n\nLet me verify the complete path:\n1. PINTO -> PINTS (O to S)\n2. PINTS -> PANTS (I to A)\n3. PANTS -> PARTS (N to R)\n4. PARTS -> WARTS (P to W)\n5. WARTS -> WARPS (T to P)\n6. WARPS -> WASPS (R to S)\n7. WASPS -> RASPS (W to R)\n\nChecking each word:\nPINTO - valid (type of horse/car)\nPINTS - valid (plural of pint)\nPANTS - valid (clothing)\nPARTS - valid (components)\nWARTS - valid (skin growths)\nWARPS - valid (bends/distorts)\nWASPS - valid (insects)\nRASPS - valid (coarse files/harsh sounds)\n\nEach step changes exactly one letter and creates a valid English word. Seven steps in total. Let me double-check if there might be a shorter path...\n\nCould I have gone from PINTS directly to PARTS? No, that would be changing two letters (I to A and N to R).\nCould I have skipped WARTS? If I went PARTS -> WARPS directly, that would change two letters (P to W and T to P).\nCould I have gone from WARPS to RASPS directly? No, that would change two letters (W to R and P to S).\n\nNo, each step seems necessary. Therefore, this seven-step solution appears to be valid and efficient:\nPINTO -> PINTS -> PANTS -> PARTS -> WARTS -> WARPS -> WASPS -> RASPS\n</think>"}],"stop_reason":"end_turn","stop_sequence":null,"usage":{"input_tokens":67,"cache_creation_input_tokens":6434,"cache_read_input_tokens":0,"output_tokens":1081}}}}
|
||||
{"custom_id":"GRACE_WAINS_4","result":{"type":"succeeded","message":{"id":"msg_01He4e1HXh1DXdxPz2j7K3GZ","type":"message","role":"assistant","model":"claude-3-5-sonnet-20241022","content":[{"type":"text","text":"<think>\nLet me tackle this word ladder puzzle. I need to transform GRACE into WAINS by changing just one letter at a time, making sure each step creates a valid English word.\n\nFirst, let me compare the starting and ending words:\nGRACE\nWAINS\n\nLooking at what needs to change:\nG needs to become W\nR needs to become A\nA needs to become I\nC needs to become N\nE needs to become S\n\nWow, all five letters need to change! This might be tricky. Let me start by looking at possible changes from GRACE.\n\nFrom GRACE, I could change:\n- First letter (G): BRACE, TRACE, SPACE... but I need to get to W eventually\n- Second letter (R): GRACE to GLACE? That's valid\n- Third letter (A): GRACE to GRECE? No, that's not a word\n- Fourth letter (C): GRACE to GRADE, GRAVE, GRAPE... wait, GRAPE is interesting\n\nLet me try GRACE → GRAPE (changing C to P)\nThat works! Now from GRAPE, I need to somehow get closer to WAINS.\n\nLooking at GRAPE vs WAINS:\nG still needs to become W\nR still needs to become A\nA needs to become I\nP needs to become N\nE needs to become S\n\nMaybe I can change the A to I next? GRAPE → GRIPE\nYes! GRIPE is a valid word. That's progress because now the middle letter matches our target word.\n\nFrom GRIPE:\nG still needs to become W\nR still needs to become A\nP needs to become N\nE needs to become S\n\nLooking at GRIPE... what can I change? The final E needs to become S eventually. But wait, looking at the P... if I change P to S, would that work? GRIPE → GRISE? No, that's not common. What about GRIPS? Yes! GRIPE → GRIPS works by changing E to S.\n\nNow I have GRIPS. Getting closer! Looking at what's left:\nG needs to become W\nR needs to become A\nI stays I\nP needs to become N\nS stays S\n\nFrom GRIPS, maybe I can change P to N? GRIPS → GRINS\nYes! That works. Now I have GRINS.\n\nLooking at GRINS vs WAINS:\nG needs to become W\nR needs to become A\nI stays I\nN stays N\nS stays S\n\nFrom GRINS, I need to change R to A. GRINS → GAINS\nPerfect! Now I just need to change G to W to reach the target.\n\nGAINS → WAINS\n\nLet me verify the complete path:\n1. GRACE → GRAPE (C→P)\n2. GRAPE → GRIPE (A→I)\n3. GRIPE → GRIPS (E→S)\n4. GRIPS → GRINS (P→N)\n5. GRINS → GAINS (R→A)\n6. GAINS → WAINS (G→W)\n\nLet me check each word:\nGRACE - valid\nGRAPE - valid (fruit)\nGRIPE - valid (complaint)\nGRIPS - valid (holds)\nGRINS - valid (smiles)\nGAINS - valid (increases)\nWAINS - valid (wagons)\n\nAnd each step only changes one letter. This works! Though let me just double-check if there might be a shorter path...\n\nCould I have gone from GRACE directly to WAINS in fewer steps? No, because five letters need to change, so minimum five steps needed. Could I have gone:\nGRACE → GRANS? (Not a word)\nGRACE → GRINS? (No, that would be two letters)\nGRACE → WRACE? (Not a word)\n\nNo, I think the path I found is optimal because:\n1. Each step is a valid word\n2. Each step changes only one letter\n3. We needed to change all five letters\n4. Each change gets us closer to the target\n\nTherefore, the solution path GRACE → GRAPE → GRIPE → GRIPS → GRINS → GAINS → WAINS is correct and efficient.\n</think>"}],"stop_reason":"end_turn","stop_sequence":null,"usage":{"input_tokens":61,"cache_creation_input_tokens":0,"cache_read_input_tokens":6434,"output_tokens":967}}}}
|
||||
|
|
@ -1,10 +0,0 @@
|
|||
{"question": "Transform the word ladder 'PAINT' to 'LACKS' by changing one letter at a time.", "answer": "PAINT,PAINS,WAINS,WARNS,BARNS,BARKS,BACKS,LACKS", "reasoning": null, "metadata": {"start_word": "PAINT", "end_word": "LACKS", "word_length": 5, "chain_length": 8}}
|
||||
{"question": "Transform the word ladder 'CONKS' to 'GRAYS' by changing one letter at a time.", "answer": "CONKS,COOKS,COOPS,CROPS,CRAPS,CRAYS,GRAYS", "reasoning": null, "metadata": {"start_word": "CONKS", "end_word": "GRAYS", "word_length": 5, "chain_length": 7}}
|
||||
{"question": "Transform the word ladder 'WINGS' to 'WORMY' by changing one letter at a time.", "answer": "WINGS,WINKS,WONKS,WORKS,WORMS,WORMY", "reasoning": null, "metadata": {"start_word": "WINGS", "end_word": "WORMY", "word_length": 5, "chain_length": 6}}
|
||||
{"question": "Transform the word ladder 'PINTO' to 'RASPS' by changing one letter at a time.", "answer": "PINTO,PINTS,PANTS,PARTS,WARTS,WARPS,WASPS,RASPS", "reasoning": null, "metadata": {"start_word": "PINTO", "end_word": "RASPS", "word_length": 5, "chain_length": 8}}
|
||||
{"question": "Transform the word ladder 'GRACE' to 'WAINS' by changing one letter at a time.", "answer": "GRACE,GRAPE,GRIPE,GRIPS,GRINS,GAINS,WAINS", "reasoning": null, "metadata": {"start_word": "GRACE", "end_word": "WAINS", "word_length": 5, "chain_length": 7}}
|
||||
{"question": "Transform the word ladder 'RACES' to 'VAPES' by changing one letter at a time.", "answer": "RACES,RAPES,VAPES", "reasoning": null, "metadata": {"start_word": "RACES", "end_word": "VAPES", "word_length": 5, "chain_length": 3}}
|
||||
{"question": "Transform the word ladder 'LAIRS' to 'FUELS' by changing one letter at a time.", "answer": "LAIRS,FAIRS,FAILS,FARLS,FURLS,FUELS", "reasoning": null, "metadata": {"start_word": "LAIRS", "end_word": "FUELS", "word_length": 5, "chain_length": 6}}
|
||||
{"question": "Transform the word ladder 'WINKS' to 'TENTH' by changing one letter at a time.", "answer": "WINKS,PINKS,PINTS,TINTS,TENTS,TENTH", "reasoning": null, "metadata": {"start_word": "WINKS", "end_word": "TENTH", "word_length": 5, "chain_length": 6}}
|
||||
{"question": "Transform the word ladder 'VERBS' to 'DATES' by changing one letter at a time.", "answer": "VERBS,HERBS,HERDS,HARDS,HARES,HATES,DATES", "reasoning": null, "metadata": {"start_word": "VERBS", "end_word": "DATES", "word_length": 5, "chain_length": 7}}
|
||||
{"question": "Transform the word ladder 'SCENT' to 'HEARS' by changing one letter at a time.", "answer": "SCENT,SCANT,SCART,SCARS,SEARS,HEARS", "reasoning": null, "metadata": {"start_word": "SCENT", "end_word": "HEARS", "word_length": 5, "chain_length": 6}}
|
||||
|
|
@ -1,5 +0,0 @@
|
|||
# Required dependencies for the Word Ladder Puzzle Dataset Generator
|
||||
# Python 3.7+
|
||||
|
||||
anthropic>=0.45.2 # Client library for interacting with Anthropic's API
|
||||
tqdm>=4.67.1 # For progress bars
|
||||
|
|
@ -1,288 +0,0 @@
|
|||
Word Ladder puzzles involve transforming a start word into an end word. You are allowed to change only one letter a time and you must keep the number of letters constant.
|
||||
Each time you change one letter the word in the chain must be forming a new word that's valid in English. Plurals are allowed, but not proper nouns.
|
||||
Your task will be to simulate a detailed step-by-step chain of thought reasoning of the transformation process from a start word to the end word of a word ladder.
|
||||
You will be given the word ladder question, as well as the correct solution path. So you don't need to solve the problem, you have the solution already, which should be your 'cheat sheet' to refer to as you go along.
|
||||
Your task is to provide a perfectly simulated chain of thought reasoning exactly in the style of the example below,
|
||||
including considering multiple possibilities, backtracking where needed, validating words, and showing the final path.
|
||||
Remember you're NOT trying to 'find' the shortest path necessarily. Make sure you're not lazy and show the user that you're working hard to find the intermediate words in the chain, and that you haven't gotten simply lucky in finding the steps.
|
||||
For example, you may go down a path that's no longer getting you closer to the solution, so you backtrack and try a different path, which ends up working.
|
||||
Use casual, free flowing verbose, human-like thinking language. You must convince the user that you're really 'REASONING'. You can see the real reasoning in the examples below, which may look convoluted, but it's true reasoning, that is what we need.
|
||||
It's ok to pretend catch yourself making a mistake, and then correct yourself, adds realism to the chain of thought.
|
||||
VERY IMPORTANT ***NEVER MENTION YOU HAVE BEEN PROVIDED WITH THE ANSWER, YOU NEED TO WRITE ONLY AS IF YOU'RE SOLVING THE QUESTION STEP BY STEP ***
|
||||
|
||||
Example chain of thought reasoning:
|
||||
<think>
|
||||
Okay, so I need to figure out how to transform the word "CARK" into "ZIPS" by changing one letter at a time, making sure each intermediate step is a valid English word. Let me start by breaking down the problem.
|
||||
|
||||
First, both words are four letters long, so each step must maintain four letters. Each change can only alter one letter, either adding, removing, or replacing, but since the length must stay the same, it has to be a replacement each time. So, I need to find a path from CARK to ZIPS where each intermediate word is valid.
|
||||
|
||||
Let me write down the starting and target words:
|
||||
|
||||
CARK
|
||||
|
||||
...
|
||||
|
||||
ZIPS
|
||||
|
||||
I need to see which letters need to change. Let's compare the letters:
|
||||
|
||||
C -> Z
|
||||
|
||||
A -> I
|
||||
|
||||
R -> P
|
||||
|
||||
K -> S
|
||||
|
||||
So all four letters need to change. But since each step can only change one letter, I need to find a sequence where each single-letter change results in a valid word. Maybe I can approach this by changing one letter at a time, checking possible words each step.
|
||||
|
||||
Starting with CARK. Let's look at possible changes for each letter:
|
||||
|
||||
First letter C: changing to other letters to form a word. Possible options could be BARK (C to B), DARK (C to D), LARK (C to L), MARK (C to M), PARK (C to P), etc. Wait, but CARK is a valid word? Let me confirm. CARK means to worry or burden, so yes, it's valid.
|
||||
|
||||
So, possible first steps from CARK:
|
||||
|
||||
Change C to B: BARK
|
||||
|
||||
Change C to D: DARK
|
||||
|
||||
Change C to L: LARK
|
||||
|
||||
Change C to M: MARK
|
||||
|
||||
Change C to P: PARK
|
||||
|
||||
Change C to W: WARK (is that a word? Wark might be a Scottish term for work, but I'm not sure if it's accepted in standard English. Maybe not.)
|
||||
|
||||
Other possibilities: changing the A, R, or K. Let's see:
|
||||
|
||||
Second letter A: changing to other vowels or consonants. For example, CORK (changing A to O), but CORK is valid. Or CIRK (changing A to I; cirk is a variant of circus, but I think "cirk" isn't standard. Maybe CIRC? No, that's not right. Wait, maybe CIRK is a valid word in some contexts, like a circular trench in fortification? Hmm, not sure. Better to stick with more common words.
|
||||
|
||||
Third letter R: changing to another letter. For example, CAKE (changing R to K), but that would require changing both the third letter and the fourth letter? Wait, no: CARK to CAKE would be changing R to E and K to E? Wait, no. Wait, CARK is C-A-R-K. To get CAKE, you'd need to change R to K and K to E, which is two changes. So not possible in one step. Maybe changing R to another letter. For example, CASH (changing R to S and K to H? No, that's two letters again. Wait, no. Let's see: CARK. Changing the third letter R to another letter, like CASK (changing R to S). But that would be changing the third letter from R to S, resulting in CASK. Yes, that's a valid word. So CARK -> CASK.
|
||||
|
||||
Fourth letter K: changing to another letter. For example, CARD (changing K to D), CARP (changing K to P), CARS (changing K to S), CART (changing K to T). All of these are valid.
|
||||
|
||||
So possible first steps from CARK:
|
||||
|
||||
- BARK, DARK, LARK, MARK, PARK, CASK, CARD, CARP, CARS, CART.
|
||||
|
||||
Now, which of these would lead towards ZIPS? Let's think about the target word. ZIPS has a Z, which is the first letter. So maybe we need to get to a word starting with Z eventually. But changing from C to Z in one step would be a big jump. Alternatively, maybe we can change other letters first to reach intermediate steps that can approach Z.
|
||||
|
||||
Alternatively, maybe work on changing the last letters first. Let's see. For example, changing K to S (CARS) might be useful since the target ends with S. Let's try that.
|
||||
|
||||
CARK -> CARS (changing K to S). Now, from CARS, we need to change the third letter R to P and the first letter C to Z, and the second A to I. Let's see.
|
||||
|
||||
From CARS, possible changes:
|
||||
|
||||
Change C to Z: ZARS (is that a word? ZAR is a currency, plural ZARS? Maybe, but not sure. Alternatively, ZAPS (changing R to P and S to something? Wait, no. Let's check each letter.
|
||||
|
||||
Alternatively, from CARS, change A to I: CIRS (is that a word? CIR is a three-letter word, but CIRS? Maybe not. Or CURS (changing A to U), which is valid. But how does that help?
|
||||
|
||||
Alternatively, from CARS, change R to P: CAPS. Yes, that's valid. So CARS -> CAPS (changing R to P). Then from CAPS, we can change C to Z: ZAPS. Yes, ZAPS is a word. Then ZAPS -> ZIPS by changing A to I. So that would be a path:
|
||||
|
||||
CARK -> CARS -> CAPS -> ZAPS -> ZIPS.
|
||||
|
||||
Wait, let's check each step:
|
||||
|
||||
1. CARK to CARS: change K to S. Valid.
|
||||
|
||||
2. CARS to CAPS: change R to P. Valid.
|
||||
|
||||
3. CAPS to ZAPS: change C to Z. Valid.
|
||||
|
||||
4. ZAPS to ZIPS: change A to I. Valid.
|
||||
|
||||
That's four steps. But wait, starting from CARK (step 0), then steps 1-4 would be four changes. But the problem says "transform CARK into ZIPS by changing one letter at a time", so the number of steps would be the number of letters changed. Since all four letters are different, at minimum it would take four steps. But perhaps the path is longer if there are no shorter paths.
|
||||
|
||||
Wait, let me check if this works:
|
||||
|
||||
CARK → CARS → CAPS → ZAPS → ZIPS.
|
||||
|
||||
Yes, that's four steps. Each step changes one letter, and all intermediate words are valid. Let me verify each word:
|
||||
|
||||
- CARK: valid.
|
||||
|
||||
- CARS: valid.
|
||||
|
||||
- CAPS: valid.
|
||||
|
||||
- ZAPS: valid.
|
||||
|
||||
- ZIPS: valid.
|
||||
|
||||
Yes, that seems to work. But let me check if there's a shorter path or if this is correct.
|
||||
|
||||
Alternatively, is there a three-step path? Probably not since four letters need to change. Each step can only change one letter, so at least four steps are needed. So the path above is four steps, which is the minimum. Let me confirm each step again.
|
||||
|
||||
Another possible path: changing the first letter earlier. For example:
|
||||
|
||||
CARK → PARK (change C to P). Then PARK → PARKS (but that's five letters; can't do that. Alternatively, changing K to S: PARS. Then PARS -> PAWS (changing R to W?), but not helpful. Or PARK → PARS (change K to S), then PARS → PAIS (change R to I?), but PAIS is valid. Then PAIS → ZIPS (change P to Z and A to I? No, that's two letters. Hmm, not helpful.
|
||||
|
||||
Alternatively, CARK → DARK → DANK → DINK → ZINK → ZIPS? Let's see. CARK to DARK (C→D). DARK to DANK (R→N). DANK to DINK (A→I). DINK to ZINK (D→Z). ZINK is a variant spelling of zinc, but ZINK is a valid word. Then ZINK to ZIPS (K→S and N→P? Wait, no. ZINK is Z-I-N-K. To get to ZIPS, we need Z-I-P-S. So need to change N to P and K to S. That's two steps. So that path would be longer.
|
||||
|
||||
Another path: CARK → CARD → ZARD? ZARD isn't a word. CARD → WARD → WARD → ... not helpful.
|
||||
|
||||
Alternatively, CARK → CARK to CIRK (A→I). CIRK is a variant of circus, but maybe not standard. Then CIRK → ZIRK (C→Z). ZIRK is a surname or a brand, not sure if it's a valid word. Then ZIRK → ZIPS (R→P and K→S). But again, two letters at once. Not helpful.
|
||||
|
||||
So the first path I thought of seems viable: CARK → CARS → CAPS → ZAPS → ZIPS.
|
||||
|
||||
Another check: From CARK, can we go to CASK? CARK → CASK (R→S). Then CASK → ZASK (C→Z). ZASK? Not a word. Alternatively, CASK → CASK → CASH (K→H). Then CASH → ZASH (C→Z). Not a word. Not helpful.
|
||||
|
||||
Alternatively, CARK → CARP (K→P). CARP is valid. Then CARP → CARP → ZARP (C→Z). ZARP? Not a word. Or CARP → HARp (C→H), but HARP is valid. Then HARP → HIPS (A→I and R→P). But that's two changes. Not allowed.
|
||||
|
||||
So the initial path seems best. Let me confirm each step again:
|
||||
|
||||
1. CARK → CARS: change K to S. Valid.
|
||||
|
||||
2. CARS → CAPS: change R to P. Valid.
|
||||
|
||||
3. CAPS → ZAPS: change C to Z. Valid.
|
||||
|
||||
4. ZAPS → ZIPS: change A to I. Valid.
|
||||
|
||||
Yes, that's four steps, each changing one letter, all words are valid. So that should be the correct sequence.
|
||||
</think>
|
||||
|
||||
alternative chain of thought:
|
||||
<think> Okay, so I need to figure out how to transform the word "BALL" into "POMP" by changing one letter at a time, with each intermediate step being a valid English word. Let me start by writing down both words to see the letters that need to change.
|
||||
BALL → POMP
|
||||
|
||||
First, let's look at each position:
|
||||
|
||||
B (1st letter) needs to become P.
|
||||
A (2nd letter) needs to become O.
|
||||
L (3rd letter) needs to become M.
|
||||
L (4th letter) needs to become P.
|
||||
|
||||
So each letter position needs to change. The challenge is to do this one letter at a time, and each step must result in a valid word. Let me think about possible steps.
|
||||
|
||||
Starting with BALL. Let's see what words can be formed by changing one letter.
|
||||
|
||||
For the first letter (B), changing it to other letters might give words like CALL, FALL, HALL, MALL, PALL, WALL. Those are all valid. Alternatively, changing the second letter (A) could lead to BELL, BILL, BOLL, BULL. Changing the third letter (L) might result in BALM, BALE, BALK, etc. Changing the fourth letter (L) could give BALD, BALE, BALM, etc.
|
||||
|
||||
Wait, but the target word is POMP. So maybe we need to adjust letters step by step. Let's try to approach this step by step.
|
||||
|
||||
Option 1: Change the first letter from B to P. That would give PALL. Is PALL a valid word? Yes. So step 1: BALL → PALL.
|
||||
|
||||
Now, PALL. Next, maybe change the second letter A to O. That would make POLL. POLL is a valid word. Step 2: PALL → POLL.
|
||||
|
||||
Now, POLL. Next, change the third letter L to M. That would make POMP? Wait, no. Because POLL has two Ls. If we change the third L to M, it would be POML, which isn't a word. Alternatively, change the fourth L to P? That would be POLP, which isn't a word. Hmm. Maybe this path isn't working.
|
||||
|
||||
Alternative approach. Let's try changing a different letter first. Let's see. Starting with BALL.
|
||||
|
||||
Another possibility: Change the second letter A to O first. Then we get BOLL. BOLL is a valid word (a type of seed capsule). Step 1: BALL → BOLL.
|
||||
|
||||
From BOLL, maybe change the first letter B to P. That gives POLL. Step 2: BOLL → POLL.
|
||||
|
||||
Now, POLL again. Then, similar problem as before. Need to get to M in the third position. Maybe change the third L to M. But that gives POML, which isn't valid. Alternatively, change the fourth L to P. POLP isn't a word. Hmm.
|
||||
|
||||
Alternative path: From BOLL, maybe change the third letter L to M. That would make BOLM. Is BOLM a word? I don't think so. Or change the fourth L to P: BOLP. Not a word.
|
||||
|
||||
Maybe another starting step. Let's try changing the third letter in BALL from L to M. So BALL → BALM. Step 1: BALL → BALM.
|
||||
|
||||
BALM is valid. Now, change the fourth letter M to P? Wait, BALM is B-A-L-M. Wait, no. Wait, original word is BALL (B-A-L-L). Changing the third L to M would make BALM. Then changing the fourth L (from original) to something else. Wait, after BALM, the fourth letter is M? Wait, no. Wait, BALL is B-A-L-L. If I change the third L to M, the word becomes BALM (B-A-L-M). Wait, but then the fourth letter is M, so BALM. Then the fourth letter is M. To get to POMP, which ends with P, maybe we can change that M to P. But BALM → BALP? BALP isn't a valid word. Hmm.
|
||||
|
||||
Alternatively, maybe change the fourth letter first. BALL → BALD (changing the last L to D). Then BALD. But I'm not sure if that helps. Let's see. BALD. Then maybe change A to O: BOLD. Then BOLD. Then change B to P: POLD. Not a word. Hmm.
|
||||
|
||||
Alternatively, BALL → BALE (changing the last L to E). BALE is valid. Then BALE → PALE (changing B to P). Then PALE → POLE (changing A to O). Then POLE → POLL (changing E to L). Then POLL → POMP (changing L to M and L to P? Wait, but that's two changes at once. Not allowed. So this path might not work.
|
||||
|
||||
Wait, perhaps another approach. Let's consider the target word POMP. The letters are P-O-M-P. So after changing some letters, maybe we can reach a word that has some of these letters, then proceed.
|
||||
|
||||
Let me try starting with BALL. Change the first letter B to P: PALL. Then change the second letter A to O: POLL. Then change the third letter L to M: POMP? Wait, no. POLL → changing third L to M would make POML. Not valid. But if we can change the third and fourth letters? Wait, but we can only change one letter each step.
|
||||
|
||||
Alternatively, from POLL, change the fourth L to P. That would make POLP, which isn't a word. Not helpful.
|
||||
|
||||
Hmm. Maybe a different path. Let's think of possible intermediate words that can connect these.
|
||||
|
||||
Another idea: BALL → BOLT (changing A to O and L to T? No, that's two changes. Not allowed. Wait, no. To get BOLT from BALL, you need to change the second letter A to O (making BOLL) and then the fourth L to T. But that's two steps. Alternatively, maybe BALL → BOLT by changing A to O and third L to T? No, that's two changes again.
|
||||
|
||||
Alternatively, BALL → BULK (changing A to U and third L to K. Again two changes. Not allowed.
|
||||
|
||||
Wait, maybe changing the third letter first. BALL → BALE (changing third L to E). Then BALE → BOLT (changing A to O and E to T. Two changes. Not allowed. Hmm.
|
||||
|
||||
Alternatively, BALL → BELL (changing A to E). Then BELL → BELL is same, then maybe BELL → BOLT (changing E to O and L to T. Again two steps.
|
||||
|
||||
This is tricky. Let me try another angle. Let's look for words that are one letter away from POMP. For example, possible previous steps before POMP could be:
|
||||
|
||||
POOP (changing M to O), but POOP is valid. Then from POOP to POMP by changing O to M. But that's a possible step. But how to get to POOP from previous steps.
|
||||
Alternatively, POMP could come from changing a letter in a word like POOP, POMY (not a word), POPS (changing S to M? No), etc. Maybe not helpful.
|
||||
|
||||
Alternatively, the step before POMP could be something like POMM (if that's a word; not sure). Probably not.
|
||||
|
||||
Alternatively, consider that the third letter needs to be M. So maybe a word like POM_ something. For example, POMP. So maybe from a word like POME (a type of fruit), changing E to P. But POME to POMP would require changing E to P. But POME is a valid word. So step: POME → POMP. Then, how to get to POME.
|
||||
|
||||
But how to get from BALL to POME. Let's see. Let's try:
|
||||
|
||||
BALL → BALE (change third L to E)
|
||||
BALE → PALE (change B to P)
|
||||
PALE → PALE (same)
|
||||
PALE → POLE (change A to O)
|
||||
POLE → POLL (change E to L)
|
||||
POLL → POLL (same)
|
||||
POLL → POMP (change L to M and L to P. Wait, that's two letters. Not allowed.
|
||||
|
||||
Alternatively, maybe a different path. Let's think of words with M in the third position. For example, BALM (which we considered earlier). BALM → BAMM (not a word). Hmm.
|
||||
|
||||
Wait, BALM is B-A-L-M. If we change the first letter B to P, we get PALM. PALM is a valid word. So BALL → BALM (changing third L to M), then BALM → PALM (changing B to P). Then PALM. Now, from PALM, maybe change the fourth M to P. That gives PALP. PALP is a valid word (to palpate). So PALM → PALP. Then PALP → POMP. How? Change A to O and L to M. But that's two letters. Alternatively, change the second letter A to O: POLP (not a word). Hmm.
|
||||
|
||||
Alternatively, from PALM, change the second letter A to O: POLM. Not a word. Not helpful.
|
||||
|
||||
Hmm. Maybe another path. Let's think of changing the second letter first. BALL → BOLL (A to O). Then BOLL → BOLT (changing L to T). Then BOLT → BOLT (same). Then BOLT → MOAT (changing B to M and L to A. Two changes. Not allowed.
|
||||
|
||||
Alternatively, BOLL → POLL (B to P). Then POLL → POLE (L to E). Then POLE → POME (E to M and L to E? No, that's two changes. Wait, POLE to POME would require changing L to M and E to E. Wait, no. POLE is P-O-L-E. To get to POME, you need to change L to M and E to E. Wait, no. Changing L to M would make POME? P-O-M-E. Yes, POME is a word. So POLE → POME (changing L to M). Then POME → POMP (changing E to P). So that's possible.
|
||||
|
||||
So let's outline this path:
|
||||
|
||||
BALL → BOLL (change A to O)
|
||||
BOLL → POLL (change B to P)
|
||||
POLL → POLE (change L to E)
|
||||
POLE → POME (change L to M)
|
||||
POME → POMP (change E to P)
|
||||
Wait, but step 3: POLL to POLE. POLL is P-O-L-L. Changing the fourth L to E gives POLE? Wait, no. If you change the fourth L (the second L) to E, then it would be POLE. Yes. So POLL → POLE (changing the fourth letter L to E). Then POLE → POME (changing the third letter L to M). Wait, POLE is P-O-L-E. Changing the third letter L to M gives P-O-M-E, which is POME. Then POME → POMP by changing the fourth letter E to P. That works.
|
||||
|
||||
So the steps would be:
|
||||
|
||||
BALL → BOLL
|
||||
BOLL → POLL
|
||||
POLL → POLE
|
||||
POLE → POME
|
||||
POME → POMP
|
||||
But wait, step 3: POLL to POLE. Is that valid? Yes, because changing the fourth L to E gives POLE. Then step 4: POLE to POME by changing the third L (but in POLE, the third letter is L? Wait, POLE is P-O-L-E. So third letter is L. Changing that to M gives P-O-M-E, which is POME. Then POME to POMP by changing the fourth E to P. That works. So that's five steps. But the problem says "change one letter at a time" without specifying the minimum number of steps, so as long as each step is valid, it's okay.
|
||||
|
||||
But let's check if all these words are valid:
|
||||
|
||||
BOLL: Yes, a type of seed capsule.
|
||||
POLL: Yes.
|
||||
POLE: Yes.
|
||||
POME: Yes, a fruit like an apple.
|
||||
POMP: Target.
|
||||
So that's a valid sequence. But wait, the original problem is to go from BALL to POMP. Let me count the steps. BALL (1) → BOLL (2) → POLL (3) → POLE (4) → POME (5) → POMP (6). Wait, that's five steps, but the count is six words including the start and end. Wait, the problem says "transform ... by changing one letter at a time. Each step must create a valid English word". So the sequence should be BALL → word1 → word2 → ... → POMP. So in this case, five changes (since each change is a step). But let me check each step again.
|
||||
|
||||
BALL → BOLL (change A to O)
|
||||
BOLL → POLL (change B to P)
|
||||
POLL → POLE (change fourth L to E)
|
||||
POLE → POME (change third L to M)
|
||||
POME → POMP (change fourth E to P)
|
||||
Yes, that's five steps. Each step changes one letter, and each intermediate word is valid. So this works.
|
||||
|
||||
Alternatively, is there a shorter path? Let's see. Let's try another path.
|
||||
|
||||
Starting with BALL → PALL (B to P). Then PALL → POLL (A to O). Then POLL → POMP. But how? POLL to POMP requires changing L to M and L to P. That's two changes. Not allowed. So that path would require two steps from POLL: first change third L to M (making POML, invalid), then fourth L to P (POLP, invalid). So that doesn't work.
|
||||
|
||||
Alternatively, another path: BALL → BALE (third L to E). BALE → PALE (B to P). PALE → POLE (A to O). POLE → POME (L to M). POME → POMP (E to P). Same as the previous path but starting differently. So that's also five steps.
|
||||
|
||||
Alternatively, BALL → BALM (third L to M). BALM → PALM (B to P). PALM → PALP (M to P). PALP → POMP (A to O and L to M). But changing two letters at once isn't allowed. So that doesn't work.
|
||||
|
||||
So the path through BOLL, POLL, POLE, POME seems valid. Let me confirm each word:
|
||||
|
||||
BOLL: valid.
|
||||
POLL: valid.
|
||||
POLE: valid.
|
||||
POME: valid.
|
||||
POMP: target.
|
||||
Yes, this works. Therefore, the sequence is BALL → BOLL → POLL → POLE → POME → POMP.
|
||||
</think>
|
||||
|
||||
|
||||
Finally, I've seen in your previous answers to similar questions that you sometimes get complacent and deviate from the hard work required. You must maintain the style and level effort throughout the entirety of your answer until the end.
|
||||
Now, the user will provide you with the word ladder question and the correct solution path. Provide the user with a very verbose, first person chain of thought reasoning that is simulating how a human would methodically think to solve the word ladder step by step.
|
||||
|
|
@ -1,254 +0,0 @@
|
|||
import datetime
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from examples.word_ladder import generate_reasoning
|
||||
|
||||
# We alias the functions and globals for easier usage in our tests.
|
||||
submit_reasoning_batches = generate_reasoning.submit_reasoning_batches
|
||||
_submit_single_batch = generate_reasoning._submit_single_batch
|
||||
DEFAULT_INPUT_JSONL = generate_reasoning.DEFAULT_INPUT_JSONL
|
||||
COMMON_UUID = generate_reasoning.COMMON_UUID
|
||||
BATCH_SIZE = generate_reasoning.BATCH_SIZE
|
||||
client = generate_reasoning.client
|
||||
|
||||
|
||||
# Define a mock batch response class mimicking Anthropic's API response.
|
||||
class MockBatchResponse:
|
||||
def __init__(self, batch_id="msgbatch_mock", processing_status="in_progress", fail=False):
|
||||
self.id = batch_id
|
||||
self.type = "message_batch"
|
||||
self.processing_status = processing_status
|
||||
# Make request_counts a SimpleNamespace object with the required attributes
|
||||
self.request_counts = SimpleNamespace(processing=0, succeeded=0, errored=0, canceled=0, expired=0)
|
||||
self.ended_at = None
|
||||
# Use datetime objects so that isoformat() is available
|
||||
self.created_at = datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc)
|
||||
self.expires_at = self.created_at + datetime.timedelta(seconds=86400)
|
||||
self.cancel_initiated_at = None
|
||||
self.results_url = None
|
||||
|
||||
|
||||
# Helper: Create a temporary system prompt file.
|
||||
@pytest.fixture
|
||||
def system_prompt_file(tmp_path, monkeypatch):
|
||||
prompt_text = "This is a system prompt."
|
||||
sys_file = tmp_path / "system_prompt.txt"
|
||||
sys_file.write_text(prompt_text, encoding="utf-8")
|
||||
|
||||
# Monkeypatch the system prompt path
|
||||
monkeypatch.setattr(generate_reasoning, "DEFAULT_SYSTEM_PROMPT", str(sys_file))
|
||||
return sys_file
|
||||
|
||||
|
||||
# Helper: Create necessary directories using a temporary location.
|
||||
@pytest.fixture
|
||||
def setup_directories(tmp_path, monkeypatch):
|
||||
# Create output directory in temporary path
|
||||
output_dir = tmp_path / "output"
|
||||
output_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Monkeypatch the DEFAULT_OUTPUT_DIR to be a Path (temporary directory).
|
||||
monkeypatch.setattr(generate_reasoning, "DEFAULT_OUTPUT_DIR", output_dir)
|
||||
|
||||
# Ensure we're working in the temporary directory
|
||||
monkeypatch.chdir(tmp_path)
|
||||
return output_dir
|
||||
|
||||
|
||||
# Helper: Create a temporary input JSONL file with given entries.
|
||||
@pytest.fixture
|
||||
def input_jsonl_file(tmp_path, setup_directories, monkeypatch):
|
||||
# Create input file in temporary directory
|
||||
file_path = setup_directories / "word_ladder_examples.jsonl"
|
||||
entries = [
|
||||
{
|
||||
"question": "Transform 'A' to 'B'",
|
||||
"answer": "A,X,B",
|
||||
"reasoning": None,
|
||||
"metadata": {"start_word": "A", "end_word": "B", "word_length": 1, "chain_length": 3},
|
||||
},
|
||||
{
|
||||
"question": "Transform 'C' to 'D'",
|
||||
"answer": "C,Y,D",
|
||||
"reasoning": "Some reasoning",
|
||||
"metadata": {"start_word": "C", "end_word": "D", "word_length": 1, "chain_length": 3},
|
||||
},
|
||||
{
|
||||
"question": "Transform 'E' to 'F'",
|
||||
"answer": "E,Z,F",
|
||||
"reasoning": None,
|
||||
"metadata": {"start_word": "E", "end_word": "F", "word_length": 1, "chain_length": 3},
|
||||
},
|
||||
]
|
||||
with file_path.open("w", encoding="utf-8") as f:
|
||||
for entry in entries:
|
||||
f.write(json.dumps(entry) + "\n")
|
||||
|
||||
# Monkeypatch DEFAULT_INPUT_JSONL to point to our temporary test file.
|
||||
monkeypatch.setattr(generate_reasoning, "DEFAULT_INPUT_JSONL", str(file_path))
|
||||
return file_path
|
||||
|
||||
|
||||
# Test that submit_reasoning_batches builds a batch skipping entries with existing reasoning.
|
||||
def test_submit_batches_success(system_prompt_file, input_jsonl_file, setup_directories, monkeypatch):
|
||||
def fake_create(requests):
|
||||
for req in requests:
|
||||
# Handle the case where req is a dictionary
|
||||
if isinstance(req, dict):
|
||||
params = req.get("params", {})
|
||||
custom_id = req.get("custom_id")
|
||||
# Check if params itself is a dictionary
|
||||
if isinstance(params, dict):
|
||||
model = params.get("model")
|
||||
temperature = params.get("temperature")
|
||||
else:
|
||||
model = params.model
|
||||
temperature = params.temperature
|
||||
else:
|
||||
# Else, req is an object with attributes.
|
||||
params = req.params
|
||||
custom_id = req.custom_id
|
||||
if isinstance(params, dict):
|
||||
model = params.get("model")
|
||||
temperature = params.get("temperature")
|
||||
else:
|
||||
model = params.model
|
||||
temperature = params.temperature
|
||||
assert model == "claude-3-5-sonnet-20241022", "Incorrect model version"
|
||||
assert temperature == 0.5, "Incorrect temperature"
|
||||
assert "C_D_" not in custom_id
|
||||
return MockBatchResponse(batch_id="msgbatch_test_success")
|
||||
|
||||
monkeypatch.setattr(client.messages.batches, "create", fake_create)
|
||||
|
||||
batch_metadata_prefix = "test_metadata"
|
||||
submit_reasoning_batches(input_path=str(input_jsonl_file), batch_metadata_prefix=batch_metadata_prefix)
|
||||
|
||||
metadata_filename = f"{COMMON_UUID}_{batch_metadata_prefix}.jsonl"
|
||||
meta_file_path = setup_directories / metadata_filename
|
||||
assert meta_file_path.exists(), "Metadata file was not created as expected."
|
||||
|
||||
with meta_file_path.open("r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
# Expecting only those entries that did not already have a reasoning.
|
||||
# (From our test input, 2 out of 3 entries qualify.)
|
||||
assert len(lines) > 0
|
||||
for line in lines:
|
||||
metadata = json.loads(line)
|
||||
api_response = metadata["api_response"]
|
||||
assert api_response["id"] == "msgbatch_test_success"
|
||||
assert api_response["processing_status"] == "in_progress"
|
||||
custom_ids = metadata["custom_ids"]
|
||||
assert len(custom_ids) == 2
|
||||
|
||||
|
||||
# Test that _submit_single_batch retries once and eventually succeeds.
|
||||
def test_retry_logic(system_prompt_file, setup_directories, monkeypatch):
|
||||
call_count = {"count": 0}
|
||||
|
||||
def fake_create_retry(requests):
|
||||
if call_count["count"] == 0:
|
||||
call_count["count"] += 1
|
||||
raise Exception("Temporary failure")
|
||||
return MockBatchResponse(batch_id="msgbatch_retry_success")
|
||||
|
||||
monkeypatch.setattr(client.messages.batches, "create", fake_create_retry)
|
||||
|
||||
dummy_request = type("DummyRequest", (), {"custom_id": "dummy_1"})()
|
||||
batch_requests = [dummy_request]
|
||||
custom_ids = ["dummy_1"]
|
||||
|
||||
_submit_single_batch(batch_requests, custom_ids, 0, "test_retry", "dummy_input.jsonl")
|
||||
|
||||
metadata_filename = f"{COMMON_UUID}_test_retry.jsonl"
|
||||
meta_file_path = setup_directories / metadata_filename
|
||||
assert meta_file_path.exists(), "Retry metadata file was not created."
|
||||
|
||||
with meta_file_path.open("r", encoding="utf-8") as f:
|
||||
metadata = json.loads(f.read())
|
||||
assert metadata["api_response"]["id"] == "msgbatch_retry_success"
|
||||
|
||||
assert call_count["count"] == 1
|
||||
|
||||
|
||||
# Test that when all attempts to submit a batch fail, the error is logged to the failed file.
|
||||
def test_failed_batch(system_prompt_file, setup_directories, monkeypatch):
|
||||
def fake_create_fail(requests):
|
||||
raise Exception("Permanent failure")
|
||||
|
||||
monkeypatch.setattr(client.messages.batches, "create", fake_create_fail)
|
||||
|
||||
dummy_request = type("DummyRequest", (), {"custom_id": "dummy_fail"})()
|
||||
batch_requests = [dummy_request]
|
||||
custom_ids = ["dummy_fail"]
|
||||
|
||||
_submit_single_batch(batch_requests, custom_ids, 0, "test_failed", "dummy_input.jsonl")
|
||||
|
||||
error_filename = f"{COMMON_UUID}_failed_batches.jsonl"
|
||||
error_file_path = setup_directories / error_filename
|
||||
assert error_file_path.exists(), "Failed batch log file was not created."
|
||||
|
||||
with error_file_path.open("r", encoding="utf-8") as f:
|
||||
error_entry = json.loads(f.readline())
|
||||
assert error_entry["batch_number"] == 0
|
||||
assert "Permanent failure" in error_entry["error"]
|
||||
assert error_entry["batch_requests"] == ["dummy_fail"]
|
||||
|
||||
|
||||
# Test batching behavior when multiple batches are needed.
|
||||
def test_multiple_batches(system_prompt_file, setup_directories, monkeypatch):
|
||||
test_batch_size = 2
|
||||
monkeypatch.setattr(generate_reasoning, "BATCH_SIZE", test_batch_size)
|
||||
|
||||
# Create input file
|
||||
input_file = setup_directories / "word_ladder_examples.jsonl"
|
||||
entries = [
|
||||
{
|
||||
"question": f"Transform word ladder {idx}",
|
||||
"answer": f"start,mid,end_{idx}",
|
||||
"reasoning": None,
|
||||
"metadata": {"start_word": f"start_{idx}", "end_word": f"end_{idx}"},
|
||||
}
|
||||
for idx in range(5)
|
||||
]
|
||||
|
||||
with input_file.open("w", encoding="utf-8") as f:
|
||||
for entry in entries:
|
||||
f.write(json.dumps(entry) + "\n")
|
||||
|
||||
# Monkeypatch DEFAULT_INPUT_JSONL.
|
||||
monkeypatch.setattr(generate_reasoning, "DEFAULT_INPUT_JSONL", str(input_file))
|
||||
|
||||
batch_ids = []
|
||||
|
||||
def fake_create(requests):
|
||||
new_id = f"msgbatch_batch_{len(batch_ids)}"
|
||||
batch_ids.append(new_id)
|
||||
return MockBatchResponse(batch_id=new_id)
|
||||
|
||||
monkeypatch.setattr(client.messages.batches, "create", fake_create)
|
||||
|
||||
batch_metadata_prefix = "test_multi"
|
||||
submit_reasoning_batches(input_path=str(input_file), batch_metadata_prefix=batch_metadata_prefix)
|
||||
|
||||
metadata_filename = f"{COMMON_UUID}_{batch_metadata_prefix}.jsonl"
|
||||
meta_file_path = setup_directories / metadata_filename
|
||||
assert meta_file_path.exists(), "Multiple batch metadata file was not created."
|
||||
|
||||
with meta_file_path.open("r", encoding="utf-8") as f:
|
||||
metadata_lines = f.readlines()
|
||||
# With 5 qualifying entries and a batch size of 2 we expect 3 batches.
|
||||
assert len(metadata_lines) == 3
|
||||
|
||||
seen_custom_ids = []
|
||||
for line in metadata_lines:
|
||||
metadata = json.loads(line)
|
||||
seen_custom_ids.extend(metadata["custom_ids"])
|
||||
assert metadata["api_response"]["id"] in batch_ids
|
||||
assert len(seen_custom_ids) == 5
|
||||
|
|
@ -1,133 +0,0 @@
|
|||
"""
|
||||
create_word_ladders.py – Generates the word ladder dataset as a JSONL file.
|
||||
Each line is a JSON object with keys: question, answer, metadata, and reasoning (set to None).
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import reasoning_gym
|
||||
|
||||
|
||||
def check_duplicates(jsonl_path: str) -> tuple[bool, dict]:
|
||||
"""
|
||||
Check for duplicate word pairs in a word ladder JSONL file.
|
||||
|
||||
Returns:
|
||||
tuple[bool, dict]: (has_duplicates, valid_entries) where:
|
||||
- has_duplicates: True if any duplicates were found
|
||||
- valid_entries: Dict mapping line_number -> data for non-duplicate entries
|
||||
|
||||
Note: A pair is considered duplicate if either (word1, word2) or (word2, word1)
|
||||
already exists, since word ladder paths are bidirectional.
|
||||
"""
|
||||
pairs_seen = {} # (start, end) -> (line_number, data)
|
||||
valid_entries = {}
|
||||
duplicates_found = False
|
||||
|
||||
with open(jsonl_path, "r", encoding="utf-8") as f:
|
||||
for line_num, line in enumerate(f):
|
||||
data = json.loads(line)
|
||||
metadata = data["metadata"]
|
||||
pair = (metadata["start_word"], metadata["end_word"])
|
||||
reverse_pair = (metadata["end_word"], metadata["start_word"])
|
||||
|
||||
# Check both orientations of the pair
|
||||
if pair in pairs_seen or reverse_pair in pairs_seen:
|
||||
duplicates_found = True
|
||||
# Skip this entry - it's a duplicate
|
||||
continue
|
||||
else:
|
||||
# Store both the line number and data for valid entries
|
||||
pairs_seen[pair] = (line_num, data)
|
||||
valid_entries[line_num] = data
|
||||
|
||||
return duplicates_found, valid_entries
|
||||
|
||||
|
||||
def create_word_ladder_dataset(jsonl_path: str = None, config: dict = None) -> None:
|
||||
"""
|
||||
Creates a word ladder dataset and writes each sample as a JSON line.
|
||||
Ensures no duplicate word pairs by regenerating as needed.
|
||||
"""
|
||||
if config is None:
|
||||
raise ValueError("Configuration (config) must be provided.")
|
||||
|
||||
# Create output directory if it doesn't exist.
|
||||
# Updated path to point to the parent folder's output.
|
||||
output_dir = Path(__file__).resolve().parent.parent / "output"
|
||||
output_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Determine the file path based on uuid when not provided
|
||||
if jsonl_path is None:
|
||||
unique_id = uuid.uuid4().hex[:8]
|
||||
jsonl_path = output_dir / f"word_ladders_{unique_id}.jsonl"
|
||||
else:
|
||||
jsonl_path = Path(jsonl_path)
|
||||
|
||||
target_size = config["dataset_config"]["size"]
|
||||
current_size = 0
|
||||
max_attempts = 3 # Limit total regeneration attempts
|
||||
attempt = 0
|
||||
|
||||
# Initial generation
|
||||
dataset = reasoning_gym.create_dataset(config["dataset_name"], **config["dataset_config"])
|
||||
with open(jsonl_path, "w", encoding="utf-8") as f:
|
||||
for item in tqdm(dataset, desc="Generating initial ladder examples"):
|
||||
row = {
|
||||
"question": item["question"],
|
||||
"answer": item["answer"],
|
||||
"reasoning": None,
|
||||
"metadata": item.get("metadata", {}),
|
||||
}
|
||||
f.write(json.dumps(row) + "\n")
|
||||
|
||||
while attempt < max_attempts:
|
||||
# Check entire file for duplicates
|
||||
has_duplicates, valid_entries = check_duplicates(jsonl_path)
|
||||
current_size = len(valid_entries)
|
||||
|
||||
if not has_duplicates and current_size == target_size:
|
||||
print(f"\nSuccessfully created dataset with {current_size} unique examples.")
|
||||
return
|
||||
|
||||
# If we have duplicates or not enough entries, regenerate the missing amount
|
||||
needed = target_size - current_size
|
||||
if needed > 0:
|
||||
print(f"\nAttempt {attempt + 1}: Regenerating {needed} examples to replace duplicates/missing entries...")
|
||||
|
||||
# Generate additional examples
|
||||
config["dataset_config"]["size"] = needed
|
||||
additional_dataset = reasoning_gym.create_dataset(config["dataset_name"], **config["dataset_config"])
|
||||
|
||||
# Write all entries to a temporary file
|
||||
temp_path = jsonl_path.with_suffix(".tmp")
|
||||
with open(temp_path, "w", encoding="utf-8") as f:
|
||||
# Write existing valid entries
|
||||
for data in valid_entries.values():
|
||||
f.write(json.dumps(data) + "\n")
|
||||
|
||||
# Write new entries
|
||||
for item in additional_dataset:
|
||||
row = {
|
||||
"question": item["question"],
|
||||
"answer": item["answer"],
|
||||
"reasoning": None,
|
||||
"metadata": item.get("metadata", {}),
|
||||
}
|
||||
f.write(json.dumps(row) + "\n")
|
||||
|
||||
# Replace original file with temporary file
|
||||
temp_path.replace(jsonl_path)
|
||||
|
||||
# Note: We'll check for duplicates again at the start of the next loop
|
||||
|
||||
attempt += 1
|
||||
|
||||
if current_size < target_size:
|
||||
print(f"\nWarning: Could only generate {current_size} unique examples after {max_attempts} attempts.")
|
||||
else:
|
||||
print(f"\nSuccessfully created dataset with {current_size} unique examples.")
|
||||
|
|
@ -1,214 +0,0 @@
|
|||
"""
|
||||
generate_reasoning.py – Reads the JSONL file containing ladder examples,
|
||||
creates batch requests of chain-of-thought prompts split into batches of 2,500,
|
||||
calls Anthropic's Message Batches API for each batch, and writes separate batch metadata
|
||||
files for later retrieval of the responses.
|
||||
|
||||
*** WARNING ***: Running large batches of requests via the Anthropic API (especially in generate_reasoning.py)
|
||||
can incur significant costs in Anthropic credits. Please review and understand your API quota and budgeting
|
||||
before running the API call. If you are testing or working with a demo dataset, adjust the batch size or dataset
|
||||
size appropriately to avoid unexpected charges.
|
||||
|
||||
Using Anthropic's Message Batches API with caching enabled for system prompt.
|
||||
In our informal testing, Sonnet was deemed best performance value.
|
||||
You can swap out to another API, but this will need a rewrite to remove anthropic-specific code.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import anthropic
|
||||
from anthropic.types.message_create_params import MessageCreateParamsNonStreaming
|
||||
from anthropic.types.messages.batch_create_params import Request
|
||||
from tqdm import tqdm
|
||||
|
||||
# Updated default output directory to use the parent directory.
|
||||
DEFAULT_OUTPUT_DIR = Path(__file__).resolve().parent.parent / "output"
|
||||
|
||||
# Add default constants at the top with other constants
|
||||
DEFAULT_INPUT_JSONL = "output/word_ladder_examples.jsonl"
|
||||
DEFAULT_SYSTEM_PROMPT = Path(__file__).resolve().parent.parent / "system_prompt.txt"
|
||||
BATCH_SIZE = 2500
|
||||
COMMON_UUID = uuid.uuid4().hex[:8]
|
||||
|
||||
# Set up the Anthropic client (ensure the API key is set in the environment)
|
||||
client = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])
|
||||
|
||||
|
||||
def submit_reasoning_batches(
|
||||
input_path: str = DEFAULT_INPUT_JSONL,
|
||||
batch_metadata_prefix: str = "batch_metadata",
|
||||
system_prompt_path: str = DEFAULT_SYSTEM_PROMPT,
|
||||
) -> None:
|
||||
"""
|
||||
Reads the input JSONL file of word ladder examples, builds batch requests for any example that
|
||||
does not have reasoning, splits them into groups of BATCH_SIZE, and submits each batch using
|
||||
Anthropic's Message Batches API.
|
||||
|
||||
Args:
|
||||
input_path: Path to input JSONL file
|
||||
batch_metadata_prefix: Prefix for batch metadata files
|
||||
system_prompt_path: Path to system prompt file
|
||||
"""
|
||||
# Create output directory if it doesn't exist
|
||||
output_dir = DEFAULT_OUTPUT_DIR
|
||||
output_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Read the system prompt from file (used as a preamble for every request)
|
||||
with open(system_prompt_path, "r", encoding="utf-8") as sys_file:
|
||||
system_message = [
|
||||
{
|
||||
"type": "text",
|
||||
"text": sys_file.read(),
|
||||
"cache_control": {"type": "ephemeral"}, # Enable anthropic prompt caching
|
||||
}
|
||||
]
|
||||
batch_requests = []
|
||||
custom_ids = [] # List of custom_ids for the current batch
|
||||
batch_num = 0
|
||||
|
||||
# Get the total number of lines in advance for tqdm progress bar.
|
||||
total_lines = sum(1 for _ in open(input_path))
|
||||
|
||||
with open(input_path, "r", encoding="utf-8") as infile:
|
||||
for idx, line in tqdm(enumerate(infile), desc="Preparing batch requests", total=total_lines):
|
||||
data = json.loads(line)
|
||||
|
||||
# Skip example if 'reasoning' already exists.
|
||||
if not data.get("reasoning"):
|
||||
# Build a custom id. Here we use the row position and the start/end words:
|
||||
metadata = data.get("metadata", {})
|
||||
start = metadata.get("start_word", "unknown")
|
||||
end = metadata.get("end_word", "unknown")
|
||||
custom_id = f"{start}_{end}_{idx}"
|
||||
custom_ids.append(custom_id)
|
||||
|
||||
# Build the prompt text exactly as before.
|
||||
prompt = f"{data['question']}. The correct solution is {data['answer']}. "
|
||||
|
||||
# Build the request payload using Request and MessageCreateParamsNonStreaming.
|
||||
request_payload = Request(
|
||||
custom_id=custom_id,
|
||||
params=MessageCreateParamsNonStreaming(
|
||||
model="claude-3-5-sonnet-20241022", # Or choose the appropriate model version
|
||||
max_tokens=8192,
|
||||
temperature=0.5,
|
||||
system=system_message,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
),
|
||||
)
|
||||
# Instead of wrapping in SimpleNamespace, simply ensure custom_id is set.
|
||||
if isinstance(request_payload, dict):
|
||||
request_payload["custom_id"] = custom_id
|
||||
else:
|
||||
request_payload.custom_id = custom_id
|
||||
batch_requests.append(request_payload)
|
||||
|
||||
# If we have reached our batch size limit, submit the current batch.
|
||||
if len(batch_requests) >= BATCH_SIZE:
|
||||
_submit_single_batch(batch_requests, custom_ids, batch_num, batch_metadata_prefix, input_path)
|
||||
batch_num += 1
|
||||
# Reset for the next batch
|
||||
batch_requests = []
|
||||
custom_ids = []
|
||||
|
||||
# Submit any remaining requests that didn't complete a full batch.
|
||||
if batch_requests:
|
||||
_submit_single_batch(batch_requests, custom_ids, batch_num, batch_metadata_prefix, input_path)
|
||||
|
||||
|
||||
def _submit_single_batch(batch_requests, custom_ids, batch_num, batch_metadata_prefix, input_path):
|
||||
"""
|
||||
Helper function to submit a single batch request, track the full API response,
|
||||
and write out the corresponding metadata including the list of custom_ids.
|
||||
"""
|
||||
# Use the default output directory
|
||||
output_dir = DEFAULT_OUTPUT_DIR
|
||||
output_dir.mkdir(exist_ok=True)
|
||||
|
||||
def serialize_datetime(dt):
|
||||
"""
|
||||
Convert a datetime object to ISO formatted string.
|
||||
If dt is None, returns None.
|
||||
"""
|
||||
if dt is None:
|
||||
return None
|
||||
iso_str = dt.isoformat() # e.g. "2024-08-20T18:37:24.100435+00:00"
|
||||
|
||||
if dt.tzinfo is not None and dt.tzinfo.utcoffset(dt) is not None:
|
||||
iso_str = iso_str.replace("+00:00", "Z")
|
||||
return iso_str
|
||||
|
||||
def extract_custom_id(req):
|
||||
# Safely extract the custom_id attribute whether req is an object or a dict.
|
||||
return req.custom_id if hasattr(req, "custom_id") else req.get("custom_id")
|
||||
|
||||
max_attempts = 2
|
||||
attempt = 0
|
||||
last_exception = None
|
||||
message_batch = None
|
||||
while attempt < max_attempts:
|
||||
try:
|
||||
print(f"Submitting batch {batch_num} with {len(batch_requests)} requests... (attempt {attempt+1})")
|
||||
message_batch = client.messages.batches.create(requests=batch_requests)
|
||||
time.sleep(1)
|
||||
print(f"Batch {batch_num} submitted with ID: {message_batch.id}")
|
||||
break # Success: exit the loop.
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
attempt += 1
|
||||
print(f"Error submitting batch {batch_num} on attempt {attempt}: {e}")
|
||||
if attempt < max_attempts:
|
||||
print("Retrying...")
|
||||
time.sleep(1)
|
||||
|
||||
if message_batch is None:
|
||||
error_filename = output_dir / f"{COMMON_UUID}_failed_batches.jsonl"
|
||||
error_msg = (
|
||||
f"{str(last_exception)} after {max_attempts} attempts"
|
||||
if last_exception
|
||||
else f"Failed after {max_attempts} attempts"
|
||||
)
|
||||
failed_info = {
|
||||
"batch_number": batch_num,
|
||||
"error": error_msg,
|
||||
"batch_requests": [extract_custom_id(req) for req in batch_requests],
|
||||
"input_file": input_path,
|
||||
}
|
||||
with open(error_filename, "a", encoding="utf-8") as error_file:
|
||||
error_file.write(json.dumps(failed_info) + "\n")
|
||||
print(f"Batch {batch_num} permanently failed. Logged to {error_filename}.")
|
||||
return
|
||||
|
||||
# Build a dictionary of the expected response fields.
|
||||
api_response = {
|
||||
"id": message_batch.id,
|
||||
"type": message_batch.type,
|
||||
"processing_status": message_batch.processing_status,
|
||||
"request_counts": vars(message_batch.request_counts),
|
||||
"ended_at": serialize_datetime(message_batch.ended_at),
|
||||
"created_at": serialize_datetime(message_batch.created_at),
|
||||
"expires_at": serialize_datetime(message_batch.expires_at),
|
||||
"cancel_initiated_at": serialize_datetime(message_batch.cancel_initiated_at),
|
||||
"results_url": message_batch.results_url,
|
||||
}
|
||||
|
||||
batch_metadata = {
|
||||
"batch_id": message_batch.id,
|
||||
"api_response": api_response,
|
||||
"custom_ids": custom_ids,
|
||||
"input_file": os.path.basename(input_path),
|
||||
}
|
||||
metadata_filename = output_dir / f"{COMMON_UUID}_{batch_metadata_prefix}.jsonl"
|
||||
with open(metadata_filename, "a", encoding="utf-8") as meta_file:
|
||||
meta_file.write(json.dumps(batch_metadata) + "\n")
|
||||
|
||||
print(f"Batch metadata for batch {batch_num} appended to {metadata_filename}.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# When running this module directly, submit the reasoning batches.
|
||||
submit_reasoning_batches()
|
||||
|
|
@ -1,206 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
This script reads a JSONL file that contains messages with usage statistics.
|
||||
For each JSON record, it expects to find the token usage information under:
|
||||
record["result"]["message"]["usage"]
|
||||
|
||||
It then calculates and prints statistics for each usage token field:
|
||||
- input_tokens
|
||||
- cache_creation_input_tokens
|
||||
- cache_read_input_tokens
|
||||
- output_tokens
|
||||
+pricing calculations
|
||||
+calculates the savings from caching (vs if we hadn't done any caching)
|
||||
+forecasts costs for 10,000, 20,000 and 50,000 jobs based on tokens per query
|
||||
|
||||
Usage:
|
||||
python usage_stats.py path/to/msgbatch_01X9LgZNVkLFhzrrBd9LNgWb_results.jsonl
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from statistics import mean
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Compute usage token statistics from a JSONL file.")
|
||||
parser.add_argument("file", help="Path to the JSONL file containing usage token data.")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Usage token fields that we want to track
|
||||
usage_fields = [
|
||||
"input_tokens",
|
||||
"cache_creation_input_tokens",
|
||||
"cache_read_input_tokens",
|
||||
"output_tokens",
|
||||
]
|
||||
|
||||
# Pricing for Sonnet, 2 Feb 2025
|
||||
base_input_rate = 1.50
|
||||
pricing = {
|
||||
"input_tokens": base_input_rate,
|
||||
"cache_creation_input_tokens": base_input_rate * 1.25, # More expensive for initial computation
|
||||
"cache_read_input_tokens": base_input_rate * 0.1, # Cheaper for cache-read tokens
|
||||
"output_tokens": 7.50,
|
||||
}
|
||||
|
||||
# A dictionary to store lists of values for each usage field
|
||||
usage_data = {key: [] for key in usage_fields}
|
||||
total_lines = 0
|
||||
error_count = 0
|
||||
|
||||
with open(args.file, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
total_lines += 1
|
||||
try:
|
||||
record = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
print(f"[Warning] Failed to parse JSON on line {total_lines}.")
|
||||
error_count += 1
|
||||
continue
|
||||
|
||||
# Navigate to the usage stats
|
||||
try:
|
||||
usage = record["result"]["message"]["usage"]
|
||||
except KeyError:
|
||||
print(f"[Warning] Missing usage field in line {total_lines}.")
|
||||
error_count += 1
|
||||
continue
|
||||
|
||||
# Extract token values from the usage data
|
||||
for key in usage_fields:
|
||||
# Defaulting to 0 if the token field is missing or non-numeric
|
||||
try:
|
||||
token_value = int(usage.get(key, 0))
|
||||
except (ValueError, TypeError):
|
||||
token_value = 0
|
||||
usage_data[key].append(token_value)
|
||||
|
||||
print(f"\nProcessed {total_lines} lines with {error_count} error(s).\n")
|
||||
print("Usage Tokens Statistics:")
|
||||
print("-" * 40)
|
||||
|
||||
grand_total_cost = 0.0
|
||||
# Calculate and print stats for each token type
|
||||
for key in usage_fields:
|
||||
values = usage_data[key]
|
||||
if values:
|
||||
total = sum(values)
|
||||
count = len(values)
|
||||
min_val = min(values)
|
||||
max_val = max(values)
|
||||
avg = mean(values)
|
||||
# Calculate pricing cost scaling by tokens per million
|
||||
cost = total / 1_000_000 * pricing[key]
|
||||
grand_total_cost += cost
|
||||
|
||||
print(f"{key}:")
|
||||
print(f" Total = {total}")
|
||||
print(f" Count = {count}")
|
||||
print(f" Min = {min_val}")
|
||||
print(f" Max = {max_val}")
|
||||
print(f" Mean = {avg:.2f}")
|
||||
print(f" Cost = ${cost:.2f}\n")
|
||||
else:
|
||||
print(f"{key}: No data found.\n")
|
||||
|
||||
print("-" * 40)
|
||||
print(f"Grand Total Estimated Cost: ${grand_total_cost:.2f}")
|
||||
|
||||
# -----------------------------------------------
|
||||
# Calculate caching savings (for input-related tokens)
|
||||
# Without caching, all tokens would have been charged at the standard input rate.
|
||||
#
|
||||
# Baseline cost (if no caching were used):
|
||||
# = (input_tokens + cache_creation_input_tokens + cache_read_input_tokens)
|
||||
# / 1_000_000 * base_input_rate
|
||||
#
|
||||
# Actual cost (with caching):
|
||||
# = input_tokens * base_input_rate +
|
||||
# cache_creation_input_tokens * (base_input_rate * 1.25) +
|
||||
# cache_read_input_tokens * (base_input_rate * 0.1)
|
||||
#
|
||||
# Savings from caching is then the difference.
|
||||
sum_input = sum(usage_data["input_tokens"])
|
||||
sum_cache_creation = sum(usage_data["cache_creation_input_tokens"])
|
||||
sum_cache_read = sum(usage_data["cache_read_input_tokens"])
|
||||
|
||||
baseline_input_cost = (sum_input + sum_cache_creation + sum_cache_read) / 1_000_000 * pricing["input_tokens"]
|
||||
actual_input_cost = (
|
||||
(sum_input) / 1_000_000 * pricing["input_tokens"]
|
||||
+ (sum_cache_creation) / 1_000_000 * pricing["cache_creation_input_tokens"]
|
||||
+ (sum_cache_read) / 1_000_000 * pricing["cache_read_input_tokens"]
|
||||
)
|
||||
caching_savings = baseline_input_cost - actual_input_cost
|
||||
|
||||
print(f"Caching Savings (input-related tokens): ${caching_savings:.2f}")
|
||||
|
||||
# -----------------------------------------------
|
||||
# Forecast future cost estimates based on the average tokens per job.
|
||||
#
|
||||
# We'll compute the average tokens per job (i.e. tokens per query) for:
|
||||
# - input_tokens
|
||||
# - cache_creation_input_tokens
|
||||
# - cache_read_input_tokens
|
||||
# - output_tokens
|
||||
#
|
||||
# Then we forecast, for example, for 10,000, 20,000, and 50,000 jobs:
|
||||
# - Apply the relevant pricing to compute the cost per token type.
|
||||
# - Also compute the baseline cost for input-related tokens and the savings
|
||||
# from caching.
|
||||
if usage_data["input_tokens"]:
|
||||
job_count = len(usage_data["input_tokens"])
|
||||
avg_input_tokens = sum(usage_data["input_tokens"]) / job_count
|
||||
avg_cache_creation_tokens = sum(usage_data["cache_creation_input_tokens"]) / job_count
|
||||
avg_cache_read_tokens = sum(usage_data["cache_read_input_tokens"]) / job_count
|
||||
avg_output_tokens = sum(usage_data["output_tokens"]) / job_count
|
||||
|
||||
print("\nAverage Tokens per Job:")
|
||||
print(f" input_tokens = {avg_input_tokens:.2f}")
|
||||
print(f" cache_creation_input_tokens = {avg_cache_creation_tokens:.2f}")
|
||||
print(f" cache_read_input_tokens = {avg_cache_read_tokens:.2f}")
|
||||
print(f" output_tokens = {avg_output_tokens:.2f}")
|
||||
|
||||
forecast_jobs = [2000, 4000, 10000, 20000, 50000]
|
||||
print("\nForecasting Future Job Costs:")
|
||||
for jobs in forecast_jobs:
|
||||
# Forecast token usage for the job count by multiplying the per-job averages.
|
||||
forecast_input = avg_input_tokens * jobs
|
||||
forecast_cache_creation = avg_cache_creation_tokens * jobs
|
||||
forecast_cache_read = avg_cache_read_tokens * jobs
|
||||
forecast_output = avg_output_tokens * jobs
|
||||
|
||||
# Forecast actual cost (with caching applied for input tokens):
|
||||
actual_input_cost_forecast = (
|
||||
(forecast_input) / 1_000_000 * pricing["input_tokens"]
|
||||
+ (forecast_cache_creation) / 1_000_000 * pricing["cache_creation_input_tokens"]
|
||||
+ (forecast_cache_read) / 1_000_000 * pricing["cache_read_input_tokens"]
|
||||
)
|
||||
|
||||
# Without caching, all input-related tokens would be at base_input_rate:
|
||||
baseline_input_cost_forecast = (
|
||||
(forecast_input + forecast_cache_creation + forecast_cache_read) / 1_000_000 * pricing["input_tokens"]
|
||||
)
|
||||
|
||||
caching_savings_forecast = baseline_input_cost_forecast - actual_input_cost_forecast
|
||||
|
||||
forecast_output_cost = forecast_output / 1_000_000 * pricing["output_tokens"]
|
||||
total_forecast_cost = actual_input_cost_forecast + forecast_output_cost
|
||||
|
||||
print(f"\nFor {jobs:,} jobs:")
|
||||
print(" Forecasted Token Usage:")
|
||||
print(f" input_tokens = {forecast_input:,.0f}")
|
||||
print(f" cache_creation_input_tokens = {forecast_cache_creation:,.0f}")
|
||||
print(f" cache_read_input_tokens = {forecast_cache_read:,.0f}")
|
||||
print(f" output_tokens = {forecast_output:,.0f}")
|
||||
print(" Estimated Costs:")
|
||||
print(f" Input Cost (with caching) = ${actual_input_cost_forecast:,.2f}")
|
||||
print(f" Output Cost = ${forecast_output_cost:,.2f}")
|
||||
print(f" Grand Total Cost = ${total_forecast_cost:,.2f}")
|
||||
print(f" Caching Savings (input) = ${caching_savings_forecast:,.2f}")
|
||||
else:
|
||||
print("No valid jobs to forecast future costs.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue