Feat/unsloth example (#482)

* cleaned up examples

* updated failing hooks

* updated readme

* corrected linting checks
This commit is contained in:
joesharratt1229 2025-06-28 17:04:38 +01:00 committed by GitHub
parent d9cd20c174
commit 1c98584f28
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
29 changed files with 122 additions and 2857 deletions

View file

@ -1,2 +0,0 @@
checkpoint/
wandb/

View file

@ -1,736 +0,0 @@
# tested with OpenRLHF 707b970e992154952a91607ca5491cc49b8665c3
import argparse
import itertools
import math
import os
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Optional
import torch
import torch.nn as nn
from openrlhf.datasets import PromptDataset, SFTDataset
from openrlhf.models import Actor, get_llm_for_sequence_regression
from openrlhf.models.utils import compute_approx_kl, masked_mean
from openrlhf.trainer import PPOTrainer
from openrlhf.trainer.ppo_utils.experience_maker import Experience, NaiveExperienceMaker, Samples
from openrlhf.utils import blending_datasets, get_strategy, get_tokenizer
from openrlhf.utils.logging_utils import init_logger
from torch.utils.data import Dataset
from transformers.trainer import get_scheduler
import reasoning_gym
from reasoning_gym.dataset import ProceduralDataset
from reasoning_gym.utils import extract_answer
logger = init_logger(__name__)
DEBUG = False
def preprocess_data(data, input_template=None, input_key="input", apply_chat_template=None) -> str:
    """Build the final prompt string for one raw data record.

    When *apply_chat_template* is supplied, the record's value is rendered
    through it (a bare string is first wrapped as a single user turn).
    Otherwise the raw value is returned, optionally formatted into
    *input_template* (which must contain a ``{}`` placeholder).
    """
    raw = data[input_key]
    if apply_chat_template:
        messages = [{"role": "user", "content": raw}] if isinstance(raw, str) else raw
        return apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    if input_template:
        return input_template.format(raw)
    return raw
class ReasoningGymDataset(Dataset):
    """Map-style dataset that renders reasoning-gym entries as chat prompts.

    Each item is returned as ``(prompt, entry)``: the prompt is the chat
    template applied to an optional developer/system message plus the
    entry's question, and the raw entry is kept so rewards can be scored
    against it later.
    """

    def __init__(
        self,
        dataset,
        tokenizer,
        developer_prompt: Optional[str] = None,
        developer_role: str = "system",
    ) -> None:
        super().__init__()
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.developer_prompt = developer_prompt
        self.developer_role = developer_role

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx: int) -> tuple[str, dict]:
        entry = self.dataset[idx]
        messages = []
        if self.developer_prompt is not None:
            messages.append({"role": self.developer_role, "content": self.developer_prompt})
        messages.append({"role": "user", "content": entry["question"]})
        rendered = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        return rendered, entry
@dataclass
class SamplesWithMetadata(Samples):
    """A ``Samples`` batch extended with the reasoning-gym entries that
    produced each prompt, so completions can be scored during
    ``make_experience`` without a learned reward model."""

    # Per-sample raw dataset entries, aligned with the batch's sequences;
    # None when no metadata is available.
    # NOTE(review): assumes the parent ``Samples`` dataclass declares no
    # defaulted fields before this one — confirm against OpenRLHF.
    metadata: Optional[dict]
class AlgorithmicRewardExperienceMaker(NaiveExperienceMaker):
    """Experience maker that scores rollouts programmatically with a
    reasoning-gym ``ProceduralDataset`` instead of a learned reward model.

    ``generate_samples`` carries each prompt's dataset entry alongside the
    generated sequences (via ``SamplesWithMetadata``) so that
    ``make_experience`` can grade the decoded completion with
    ``dataset.score_answer``.
    """

    def __init__(
        self,
        dataset: ProceduralDataset,
        actor: Actor,
        critic: nn.Module,
        reward_model: nn.Module,
        initial_model: Actor,
        tokenizer,
        prompt_max_len: int,
        kl_controller,
        strategy=None,
        remote_rm_url: str = None,
        reward_fn=None,
    ) -> None:
        super().__init__(
            actor=actor,
            critic=critic,
            reward_model=reward_model,
            initial_model=initial_model,
            tokenizer=tokenizer,
            prompt_max_len=prompt_max_len,
            kl_controller=kl_controller,
            strategy=strategy,
            remote_rm_url=remote_rm_url,
            reward_fn=reward_fn,
        )
        # Procedural dataset whose score_answer() supplies the outcome reward.
        self.dataset = dataset

    @torch.no_grad()
    def generate_samples(self, all_prompts: list[tuple[str, Any]], **generate_kwargs) -> list[Samples]:
        """
        Generate samples and return in batches.

        *all_prompts* is a list of ``(prompt, metadata)`` pairs. Every prompt
        is repeated ``args.n_samples_per_prompt`` times (metadata kept
        aligned) and generation runs in chunks of
        ``args.micro_rollout_batch_size``.
        """
        # Packed-sample mode is not supported by this implementation.
        assert not getattr(self, "packing_samples", False)
        args = self.strategy.args
        self.actor.eval()
        # prepare inputs to sample multiple responses per prompt
        repeated_prompts = []
        repeated_metadata = []
        for prompt, metadata in all_prompts:
            for _ in range(args.n_samples_per_prompt):
                repeated_prompts.append(prompt)
                repeated_metadata.append(metadata)
        samples_list = []
        for i in range(0, len(repeated_prompts), args.micro_rollout_batch_size):
            prompts = repeated_prompts[i : i + args.micro_rollout_batch_size]
            metadata = repeated_metadata[i : i + args.micro_rollout_batch_size]
            inputs = self.tokenize_fn(prompts, self.prompt_max_len, device="cuda")
            sequences, attention_mask, action_mask = self.actor.generate(**inputs, **generate_kwargs)
            samples = SamplesWithMetadata(
                sequences=sequences,
                attention_mask=attention_mask,
                action_mask=action_mask,
                num_actions=action_mask.size(1),
                packed_seq_lens=None,
                response_length=action_mask.float().sum(dim=-1),
                total_length=attention_mask.float().sum(dim=-1),
                metadata=metadata,
            )
            samples_list.append(samples)
        return samples_list

    @torch.no_grad()
    def make_experience(self, samples: Samples) -> Experience:
        """
        Turn samples into experience by calculating logprobs, values, rewards, and kl divergence.
        """
        self.actor.eval()
        self.initial_model.eval()
        if self.reward_model is not None:
            self.reward_model.eval()
        if self.critic is not None:
            self.critic.eval()

        # extract values from samples
        sequences = samples.sequences
        attention_mask = samples.attention_mask
        action_mask = samples.action_mask
        num_actions = samples.num_actions
        if isinstance(samples, SamplesWithMetadata):
            metadata = samples.metadata
        else:
            # NOTE(review): a plain Samples batch has no metadata; the zip
            # in the reward computation below would then raise. In practice
            # generate_samples() above always supplies metadata.
            metadata = None

        # log probs of the policy being trained
        action_log_probs = self.actor(sequences, num_actions, attention_mask)

        # log probs of the frozen reference model (for the KL penalty)
        base_action_log_probs = self.initial_model(sequences, num_actions, attention_mask)

        # value estimates (only when a critic is configured, i.e. GAE)
        if self.critic is not None:
            value = self.critic(sequences, num_actions, attention_mask)
        else:
            value = None

        # determine outcome reward: decode only the generated action tokens
        # and score the text inside the <answer> tags against the entry
        completions = sequences[:, -action_mask.size(1) :].cpu()
        completions = self.tokenizer.batch_decode(completions, skip_special_tokens=True)
        returns = [
            self.dataset.score_answer(extract_answer(c, tag_name="answer"), entry=m)
            for c, m in zip(completions, metadata)
        ]
        r = torch.tensor(returns, dtype=torch.float32, device=sequences.device)

        kl = compute_approx_kl(
            action_log_probs,
            base_action_log_probs,
            action_mask=action_mask,
            use_kl_estimator_k3=self.strategy.args.use_kl_estimator_k3,
        )
        info = {
            "kl": masked_mean(kl, action_mask, dim=-1),
            "reward": r,
            "response_length": samples.response_length,
            "total_length": samples.total_length,
            "num_actions": num_actions,
        }
        logger.info(f"info={info}")

        # reset model state back to training mode
        self.actor.train()
        if self.critic is not None:
            self.critic.train()

        return Experience(
            sequences,
            action_log_probs,
            value,
            None,
            None,
            attention_mask,
            action_mask,
            info,
            kl,
        )
def train(args):
    """PPO training entry point.

    Builds the actor/critic/reference models, the reasoning-gym prompt
    dataset (and optional SFT pretrain mixture), optimizers and schedulers,
    then constructs an OpenRLHF ``PPOTrainer``, swaps its experience maker
    for ``AlgorithmicRewardExperienceMaker``, runs ``fit``, and saves the
    final weights.
    """
    # configure strategy (DeepSpeed / distributed setup)
    strategy = get_strategy(args)
    strategy.setup_distributed()

    # configure model
    # load huggingface model
    actor = Actor(
        args.pretrain,
        use_flash_attention_2=args.flash_attn,
        bf16=args.bf16,
        load_in_4bit=args.load_in_4bit,
        lora_rank=args.lora_rank,
        lora_alpha=args.lora_alpha,
        target_modules=args.target_modules,
        lora_dropout=args.lora_dropout,
        ds_config=strategy.get_ds_train_config(is_actor=True),
    )
    if args.actor_init_on_gpu:
        actor = actor.to(torch.cuda.current_device())

    # critic is only created when configured (GAE advantage estimation)
    if args.critic_pretrain:
        critic = get_llm_for_sequence_regression(
            args.critic_pretrain,
            "critic",
            normalize_reward=args.normalize_reward,
            use_flash_attention_2=args.flash_attn,
            bf16=args.bf16,
            load_in_4bit=args.load_in_4bit,
            lora_rank=args.lora_rank,
            lora_alpha=args.lora_alpha,
            target_modules=args.target_modules,
            lora_dropout=args.lora_dropout,
            ds_config=strategy.get_ds_train_config(is_actor=False),
            value_head_prefix=args.value_head_prefix,
            # only initialize the value head fresh when critic starts from
            # the same base weights as the actor
            init_value_head=strategy.args.pretrain == strategy.args.critic_pretrain,
        )
    else:
        critic = None
    # no learned reward model: rewards come from the reasoning-gym dataset
    reward_model = None

    strategy.print(actor)
    strategy.print(critic)

    # configure tokenizer (left padding for generation)
    tokenizer = get_tokenizer(
        args.pretrain,
        actor.model,
        "left",
        strategy,
        use_fast=not args.disable_fast_tokenizer,
    )

    # load weights for reference actor (frozen model for the KL penalty)
    initial_model = Actor(
        args.pretrain,
        use_flash_attention_2=args.flash_attn,
        bf16=args.bf16,
        load_in_4bit=args.load_in_4bit,
        ds_config=strategy.get_ds_eval_config(offload=False),
    )

    if args.enable_ema:
        ema_model = Actor(
            args.pretrain,
            use_flash_attention_2=args.flash_attn,
            bf16=args.bf16,
            load_in_4bit=args.load_in_4bit,
            ds_config=strategy.get_ds_eval_config(offload=True),
        )
    else:
        ema_model = None

    # gradient_checkpointing
    if args.gradient_checkpointing:
        actor.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs={"use_reentrant": args.gradient_checkpointing_use_reentrant}
        )
        if critic is not None:
            critic.gradient_checkpointing_enable(
                gradient_checkpointing_kwargs={"use_reentrant": args.gradient_checkpointing_use_reentrant}
            )

    # configure optimizer
    actor_optim = strategy.create_optimizer(
        actor, lr=args.actor_learning_rate, betas=args.adam_betas, weight_decay=args.l2
    )
    if args.critic_pretrain:
        critic_optim = strategy.create_optimizer(
            critic,
            lr=args.critic_learning_rate,
            betas=args.adam_betas,
            weight_decay=args.l2,
        )
    else:
        critic_optim = None

    # prepare datasets
    print("prompt_data", args.prompt_data)

    # DeepSeek Zero system prompt
    system_prompt = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think>
<answer> answer here </answer>
"""
    prompts_data = reasoning_gym.create_dataset(args.prompt_data, size=args.max_samples)
    prompts_dataset = ReasoningGymDataset(prompts_data, tokenizer, developer_prompt=system_prompt)

    # optional SFT mixture for the PPO-ptx auxiliary loss
    if args.pretrain_data:
        pretrain_data = blending_datasets(
            args.pretrain_data,
            args.pretrain_data_probs,
            strategy,
            args.seed,
            return_eval=False,
            train_split=args.pretrain_split,
        )
        pretrain_max_len = args.max_len if args.max_len else args.prompt_max_len + args.generate_max_len
        pretrain_dataset = SFTDataset(
            # cap the SFT data at the number of samples PPO will consume
            pretrain_data.select(
                range(
                    min(
                        len(pretrain_data),
                        args.max_epochs * len(prompts_dataset) * args.n_samples_per_prompt,
                    )
                )
            ),
            tokenizer,
            pretrain_max_len,
            strategy,
            pretrain_mode=True,
        )

    def collate_prompt_and_metadata(xs: list[tuple]) -> list[tuple]:
        """dummy collate function to pass on the metadata dict unchanged"""
        return xs

    # prepare dataloader
    prompts_dataloader = strategy.setup_dataloader(
        prompts_dataset,
        args.rollout_batch_size // strategy.world_size,
        True,
        True,
        collate_fn=collate_prompt_and_metadata,
    )
    if args.pretrain_data:
        # cycle forever so the ptx loader never runs out mid-training
        pretrain_dataloader = itertools.cycle(
            iter(
                strategy.setup_dataloader(
                    pretrain_dataset,
                    args.micro_train_batch_size,
                    True,
                    True,
                    pretrain_dataset.collate_fn,
                )
            )
        )
    else:
        pretrain_dataloader = None

    # configure scheduler
    num_update_steps_per_episodes = (
        len(prompts_dataset) * args.n_samples_per_prompt // args.train_batch_size * args.max_epochs
    )
    max_steps = math.ceil(args.num_episodes * num_update_steps_per_episodes)

    actor_scheduler = get_scheduler(
        "cosine_with_min_lr",
        actor_optim,
        num_warmup_steps=math.ceil(max_steps * args.lr_warmup_ratio),
        num_training_steps=max_steps,
        scheduler_specific_kwargs={"min_lr": args.actor_learning_rate * 0.1},
    )
    if args.critic_pretrain:
        critic_scheduler = get_scheduler(
            "cosine_with_min_lr",
            critic_optim,
            num_warmup_steps=math.ceil(max_steps * args.lr_warmup_ratio),
            num_training_steps=max_steps,
            scheduler_specific_kwargs={"min_lr": args.critic_learning_rate * 0.1},
        )
    else:
        critic_scheduler = None

    # prepare models/optimizers...
    (
        (actor, actor_optim, actor_scheduler),
        (critic, critic_optim, critic_scheduler),
        reward_model,
        initial_model,
    ) = strategy.prepare(
        (actor, actor_optim, actor_scheduler),
        (critic, critic_optim, critic_scheduler),
        reward_model,
        initial_model,
        is_rlhf=True,
    )

    if ema_model:
        ema_model._offload = True
        ema_model = strategy.prepare(ema_model, is_rlhf=True)

    # load checkpoint to resume a previous run, if present
    consumed_samples = 0
    if args.load_checkpoint and os.path.exists(os.path.join(args.ckpt_path, "_actor")):
        _, states = strategy.load_ckpt(actor.model, os.path.join(args.ckpt_path, "_actor"))
        if args.critic_pretrain:
            strategy.load_ckpt(critic, os.path.join(args.ckpt_path, "_critic"))
        consumed_samples = states["consumed_samples"]
        strategy.print(f"Loaded the checkpoint: {args.ckpt_path}, consumed_samples: {consumed_samples}")

    os.makedirs(args.save_path, exist_ok=True)

    # configure Trainer
    trainer = PPOTrainer(
        strategy,
        actor,
        critic,
        reward_model,
        initial_model,
        ema_model,
        actor_optim,
        critic_optim,
        actor_scheduler,
        critic_scheduler,
        max_epochs=args.max_epochs,
        micro_train_batch_size=args.micro_train_batch_size,
        micro_rollout_batch_size=args.micro_rollout_batch_size,
        gradient_checkpointing=args.gradient_checkpointing,
        tokenizer=tokenizer,
        prompt_max_len=args.prompt_max_len,
        value_clip=args.value_clip,
        eps_clip=args.eps_clip,
        gamma=args.gamma,
        lambd=args.lambd,
        init_kl_coef=args.init_kl_coef,
        kl_target=args.kl_target,
        ema_beta=0.992,
        ptx_coef=args.ptx_coef,
        max_norm=args.max_norm,
        # for GPT generation
        do_sample=True,
        max_new_tokens=args.generate_max_len,
        max_length=args.max_len,
        temperature=args.temperature,
        top_p=args.top_p,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        # remote reward model
        remote_rm_url=args.remote_rm_url,
        save_hf_ckpt=args.save_hf_ckpt,
        disable_ds_ckpt=args.disable_ds_ckpt,
    )

    # patch experience maker: swap in the reasoning-gym scoring maker,
    # reusing every component the default maker was constructed with
    xp = trainer.experience_maker
    trainer.experience_maker = AlgorithmicRewardExperienceMaker(
        dataset=prompts_dataset.dataset,
        actor=xp.actor,
        critic=xp.critic,
        reward_model=xp.reward_model,
        initial_model=xp.initial_model,
        tokenizer=xp.tokenizer,
        prompt_max_len=xp.prompt_max_len,
        kl_controller=xp.kl_ctl,
        strategy=xp.strategy,
        remote_rm_url=xp.remote_rm_url,
        reward_fn=xp.reward_fn,
    )
    xp = None

    trainer.fit(
        args,
        prompts_dataloader,
        pretrain_dataloader,
        consumed_samples,
        num_update_steps_per_episodes,
    )

    # save model checkpoint after fitting on only rank0
    strategy.save_model(
        ema_model if args.enable_ema else actor,
        tokenizer,
        args.save_path,
    )

    if args.critic_pretrain and args.save_value_network:
        strategy.save_model(
            critic,
            tokenizer,
            args.save_path + "_critic",
        )
import os

if __name__ == "__main__":
    # Attach a remote debugger on the rank-0 process only. debugpy is
    # imported lazily so the package is only required when DEBUG is on
    # (previously it was imported unconditionally at module level).
    if DEBUG and int(os.getenv("LOCAL_RANK", "0")) == 0:
        import debugpy

        debugpy.listen(("localhost", 5678))
        print("Waiting for debugger attach on localhost:5678")
        debugpy.wait_for_client()

    parser = argparse.ArgumentParser()
    # Checkpoint
    parser.add_argument("--save_path", type=str, default="./ckpt")
    parser.add_argument("--save_steps", type=int, default=-1)
    parser.add_argument("--save_hf_ckpt", action="store_true", default=False)
    parser.add_argument("--disable_ds_ckpt", action="store_true", default=False)
    parser.add_argument("--logging_steps", type=int, default=1)
    parser.add_argument("--eval_steps", type=int, default=-1)
    parser.add_argument("--ckpt_path", type=str, default="./ckpt/checkpoints_ppo")
    parser.add_argument("--max_ckpt_num", type=int, default=3)
    # int literal so the default matches type=int (1e8 is a float and
    # argparse does not convert defaults)
    parser.add_argument("--max_ckpt_mem", type=int, default=100_000_000)
    parser.add_argument("--load_checkpoint", action="store_true", default=False)

    # PPO
    parser.add_argument("--num_episodes", type=int, default=1)
    parser.add_argument("--rollout_batch_size", type=int, default=512)
    parser.add_argument("--micro_rollout_batch_size", type=int, default=8)
    parser.add_argument("--max_epochs", type=int, default=1)
    parser.add_argument("--prompt_max_len", type=int, default=1024, help="Max tokens for each prompt")
    parser.add_argument(
        "--generate_max_len",
        type=int,
        default=1024,
        help="Max tokens to generate in PPO",
    )
    parser.add_argument("--max_len", type=int, default=None, help="deprecated max_len")
    parser.add_argument("--max_samples", type=int, default=1000000)
    parser.add_argument("--max_norm", type=float, default=1.0, help="Gradient clipping")
    parser.add_argument("--l2", type=float, default=0.0, help="weight decay loss")
    parser.add_argument("--ptx_coef", type=float, default=0.05, help="PPO-ptx loss coef")
    parser.add_argument("--eps_clip", type=float, default=0.2, help="PPO clip range")
    parser.add_argument("--value_clip", type=float, default=0.2, help="PPO value clip range")
    parser.add_argument("--lambd", type=float, default=1.0, help="PPO GAE lambd")
    parser.add_argument("--gamma", type=float, default=1, help="PPO GAE gamma")
    parser.add_argument("--micro_train_batch_size", type=int, default=4, help="batch size per GPU")
    parser.add_argument("--train_batch_size", type=int, default=128, help="Global training batch size")
    parser.add_argument(
        "--normalize_reward",
        action="store_true",
        default=False,
        help="Enable Reward Normazation",
    )
    parser.add_argument("--top_p", type=float, default=1.0)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument(
        "--freezing_actor_steps",
        type=int,
        default=-1,
        help="Used for critic initialization",
    )
    parser.add_argument(
        "--n_samples_per_prompt",
        type=int,
        default=1,
        help="number of responses for each prompt in generation",
    )
    parser.add_argument(
        "--save_value_network",
        action="store_true",
        default=False,
        help="Save critic model",
    )
    parser.add_argument("--actor_learning_rate", type=float, default=1e-6)
    parser.add_argument("--critic_learning_rate", type=float, default=9e-6)
    parser.add_argument("--lr_warmup_ratio", type=float, default=0.03)
    parser.add_argument("--kl_target", type=float, default=None)
    parser.add_argument("--init_kl_coef", type=float, default=0.01, help="KL penalty in PPO")
    parser.add_argument(
        "--use_kl_estimator_k3",
        action="store_true",
        default=False,
        help=(
            "Use the k3 estimator in http://joschu.net/blog/kl-approx.html"
            "to ensure the KL divergence calculated is non-negative"
        ),
    )
    parser.add_argument(
        "--adam_betas",
        type=float,
        nargs=2,
        default=(0.9, 0.95),
        help="Betas for Adam optimizer",
    )
    parser.add_argument(
        "--reward_clip_range",
        type=float,
        nargs=2,
        default=(-10, 10),
        help="Reward clip range",
    )

    # DeepSpeed
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for deepspeed")
    parser.add_argument("--zero_stage", type=int, default=2, help="DeepSpeed ZeRO stage")
    parser.add_argument("--gradient_checkpointing", action="store_true", default=False)
    parser.add_argument("--bf16", action="store_true", default=False, help="Enable bfloat16")
    parser.add_argument("--enable_ema", action="store_true", help="Enable EMA checkpoint for the model.")
    parser.add_argument("--zpg", type=int, default=1, help="ZeRO++ max partition size")
    parser.add_argument(
        "--adam_offload",
        action="store_true",
        default=False,
        help="Offload Adam Optimizer",
    )
    parser.add_argument("--actor_init_on_gpu", action="store_true", default=False)
    parser.add_argument(
        "--flash_attn",
        action="store_true",
        default=False,
        help="Enable FlashAttention2",
    )
    parser.add_argument("--aux_loss_coef", type=float, default=0, help="MoE balancing loss")
    parser.add_argument("--grad_accum_dtype", type=str, default=None, help="Adam grad accum data type")
    parser.add_argument("--overlap_comm", action="store_true", default=False)
    parser.add_argument("--gradient_checkpointing_use_reentrant", action="store_true", default=False)
    parser.add_argument("--disable_fast_tokenizer", action="store_true", default=False)

    # Reinforce
    parser.add_argument(
        "--advantage_estimator",
        type=str,
        choices=["gae", "reinforce", "rloo"],
        default="gae",
        help="Choose advantage estimation method: gae, reinforce, rloo",
    )

    # LoRA
    parser.add_argument("--load_in_4bit", action="store_true", default=False)
    parser.add_argument("--lora_rank", type=int, default=0)
    parser.add_argument("--lora_alpha", type=int, default=16)
    parser.add_argument("--target_modules", type=str, nargs="*", default="all-linear")
    parser.add_argument("--lora_dropout", type=float, default=0)

    # Models
    parser.add_argument("--pretrain", type=str, default=None, help="HF model name or path")
    parser.add_argument("--reward_pretrain", type=str, default=None, help="HF model name or path")
    parser.add_argument("--remote_rm_url", type=str, default=None, help="remote RM API")
    parser.add_argument("--critic_pretrain", type=str, default=None, help="HF model name or path")
    parser.add_argument("--value_head_prefix", type=str, default="score")

    # Custom dataset
    parser.add_argument("--prompt_data", type=str, default="chain_sum", help="HF dataset name or path")
    parser.add_argument(
        "--prompt_data_probs",
        type=str,
        default="1.0",
        help="sampling probs for datasets",
    )
    parser.add_argument("--prompt_split", type=str, default="train")
    parser.add_argument("--pretrain_data", type=str, default=None, help="HF dataset name or path")
    parser.add_argument(
        "--pretrain_data_probs",
        type=str,
        default="1.0",
        help="sampling probs for datasets",
    )
    parser.add_argument("--pretrain_split", type=str, default="train")
    parser.add_argument("--input_key", type=str, default="question", help="JSON dataset key")
    parser.add_argument("--input_template", type=str, default=None)
    parser.add_argument(
        "--apply_chat_template",
        action="store_true",
        default=False,
        help="Use HF tokenizer chat template",
    )

    # wandb parameters
    parser.add_argument("--use_wandb", type=str, default=None)
    parser.add_argument("--wandb_org", type=str, default=None)
    parser.add_argument("--wandb_group", type=str, default=None)
    parser.add_argument("--wandb_project", type=str, default="openrlhf_train_ppo")
    parser.add_argument(
        "--wandb_run_name",
        type=str,
        default="ppo_%s" % datetime.now().strftime("%m%dT%H:%M"),
    )

    # TensorBoard parameters
    parser.add_argument("--use_tensorboard", type=str, default=None, help="TensorBoard logging path")

    args = parser.parse_args()

    # only GAE needs a critic; the other estimators are critic-free
    if args.advantage_estimator not in ["gae"]:
        args.critic_pretrain = None
    elif args.critic_pretrain is None:
        args.critic_pretrain = args.pretrain  # temp: init the critic from the actor weights

    if args.advantage_estimator == "rloo":
        assert args.n_samples_per_prompt > 1, "RLOO requires n_samples_per_prompt > 1"

    # the "{}" below is the literal placeholder the template must contain
    if args.input_template and "{}" not in args.input_template:
        print("[Warning] {} not in args.input_template, set to None")
        args.input_template = None

    if args.input_template and "\\n" in args.input_template:
        print(
            "[Warning] input_template contains \\n characters instead of newline. "
            "You likely want to pass $'\\n' in Bash or \"`n\" in PowerShell."
        )

    train(args)

View file

@ -1,34 +0,0 @@
#!/bin/bash
# Launch PPO training (custom_reward.py) with DeepSpeed on a reasoning-gym task.
args=(
    custom_reward.py
    --pretrain meta-llama/Llama-3.2-3B-Instruct # meta-llama/Llama-3.2-1B-Instruct
    --save_path ./checkpoint/Llama-3.2-3B-rl # ./checkpoint/Llama-3.2-1b-lr
    --save_steps -1
    --logging_steps 1
    --eval_steps -1
    --micro_train_batch_size 2
    --train_batch_size 128
    --micro_rollout_batch_size 16
    --rollout_batch_size 1024
    --max_epochs 1
    --prompt_max_len 1024
    --generate_max_len 1024
    --zero_stage 2
    --bf16
    --actor_learning_rate 5e-7
    --init_kl_coef 0.01
    --prompt_data chain_sum # leg_counting
    --input_key question
    --apply_chat_template
    --normalize_reward
    --adam_offload
    --flash_attn
    --gradient_checkpointing
    --max_samples 100000
    --critic_learning_rate 9e-6
)

# Add wandb argument only if wandb_token is set
if [[ -n "${wandb_token}" ]]; then
    args+=(--use_wandb "${wandb_token}")
fi

# Quote the array expansion so each element survives word splitting/globbing.
deepspeed "${args[@]}"

View file

@ -1,34 +0,0 @@
# Reasoning Gym Open-Instruct Example
This example demonstrates how to use [open-instruct](https://github.com/allenai/open-instruct) with **reasoning-gym** for training language models through reinforcement learning.
## Environment Setup
Before running the training script, you may want to set the following environment variables:
```bash
export CUDA_VISIBLE_DEVICES=0
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
```
- `CUDA_VISIBLE_DEVICES=0` specifies which GPU to use (in this case, GPU 0)
- `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` enables dynamic GPU memory allocation
## Running the Training Script
To start training, run:
```bash
bash grpo_config.sh
```
This will train a DeepSeek-R1-Distill-Qwen-1.5B model using GRPO (Group Relative Policy Optimization) on reasoning tasks.
## Key Features
- Uses vLLM for efficient inference
- Supports multi-GPU training
- Implements reward functions for both answer correctness and response formatting
- Includes evaluation during training
- Supports model checkpointing and pushing to Hugging Face Hub
- Integrates with Weights & Biases for experiment tracking

View file

@ -1,34 +0,0 @@
#!/bin/bash
# GRPO training of DeepSeek-R1-Distill-Qwen-1.5B on a reasoning-gym task
# via the open-instruct fast GRPO trainer.
exp_name="0302_qwen2.5_math_grpo_fast1_${RANDOM}"

python src/grpo_trainer.py \
    --model_name_or_path deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
    --beta 0.0 \
    --num_unique_prompts_rollout 16 \
    --num_samples_per_prompt_rollout 8 \
    --output_dir "./open_instruct_checkpoints/${exp_name}" \
    --save_freq 40 \
    --kl_estimator kl3 \
    --learning_rate 5e-7 \
    --max_token_length 256 \
    --max_prompt_token_length 128 \
    --response_length 256 \
    --pack_length 512 \
    --stop_strings '"</answer>"' \
    --chat_template_name r1_simple_chat_postpend_think \
    --temperature 1.0 \
    --total_episodes 100000 \
    --deepspeed_stage 3 \
    --per_device_train_batch_size 1 \
    --num_mini_batches 1 \
    --num_learners_per_node 1 \
    --num_epochs 1 \
    --vllm_tensor_parallel_size 1 \
    --vllm_num_engines 1 \
    --vllm_enforce_eager true \
    --lr_scheduler_type linear \
    --seed 1 \
    --num_evals 100 \
    --dataset_name chain_sum \
    --gradient_checkpointing \
    --with_tracking \
    --single_gpu_mode false \
    --gather_whole_model false

View file

@ -1,457 +0,0 @@
import json
import os
import shutil
import threading
import time
from argparse import Namespace
from dataclasses import asdict, dataclass
from queue import Queue
from typing import Callable, List
import numpy as np
import pandas as pd
import ray
import torch
from open_instruct.ground_truth_utils import soft_format_reward_func
from open_instruct.grpo_fast import (
Args,
ModelGroup,
PolicyTrainerRayProcess,
Timer,
calculate_runtime_args,
collate_fn,
vllm_generate_thread,
)
from open_instruct.model_utils import ModelConfig, push_folder_to_hub
from open_instruct.rl_utils2 import PackedSequences, reset_position_ids
from open_instruct.utils import ArgumentParserPlus, get_wandb_tags, is_beaker_job, maybe_get_beaker_config
from open_instruct.vllm_utils2 import create_vllm_engines
from ray.util.placement_group import placement_group
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from transformers import AutoTokenizer, PreTrainedTokenizer
from utils import ReasoningGymDataset, list_preserving_collate, pack_sequences
from vllm import SamplingParams
from reasoning_gym.utils import SYSTEM_PROMPTS, extract_answer
@dataclass
class GRPOArgs(Args):
    """open-instruct GRPO ``Args`` extended with reasoning-gym dataset options."""

    # reasoning-gym task name used to build the train/eval datasets.
    dataset_name: str = "chain_sum"
    # Seed and size of the procedural training dataset.
    seed: int = 42
    size: int = 10000
    # Seed and size of the held-out evaluation dataset.
    eval_seed: int = 42
    eval_size: int = 100
def data_preparation_thread(
    reward_fn: Callable,
    inference_results_Q: Queue,
    packed_sequences_Q: Queue,
    queries_prompt_Q: Queue,
    args: Args,
    tokenizer: PreTrainedTokenizer,
    num_training_steps: int,
    reasoning_gym_dataset: ReasoningGymDataset,
):
    """Background thread that turns vLLM rollouts into packed training batches.

    Per training step it pulls prompts/items and generated responses off the
    queues, packs sequences to ``args.pack_length``, scores them with
    *reward_fn*, computes group-normalized (GRPO) advantages, collates
    per-worker micro-batches, and pushes the result plus metrics onto
    *packed_sequences_Q*.
    """
    for training_step in range(1, num_training_steps + 1):
        # Get next batch of prompts and responses
        observations = queries_prompt_Q.get()
        queries, items = observations
        if args.num_samples_per_prompt_rollout > 1:
            # repeat each query so it lines up 1:1 with its sampled responses
            queries = [item for item in queries for _ in range(args.num_samples_per_prompt_rollout)]
        with Timer("🚀 [Data Preparation Thread] Getting response ids"):
            responses, finish_reasons = inference_results_Q.get()
            for i in range(len(finish_reasons)):
                # re-append EOS when generation stopped but the token was dropped
                if finish_reasons[i] == "stop" and responses[i][-1] != tokenizer.eos_token_id:
                    responses[i].append(tokenizer.eos_token_id)
        with Timer("📦 [Data Preparation Thread] Packing sequences"):
            packed_sequences = pack_sequences(
                queries=queries,
                responses=responses,
                pack_length=args.pack_length,
                pad_token_id=tokenizer.pad_token_id,
            )
            num_new_tokens = sum(len(seq) for seq in packed_sequences.query_responses)
        with Timer("🔥 [Data Preparation Thread] Decoding responses", noop=True):
            decoded_responses = tokenizer.batch_decode(responses, skip_special_tokens=True)
            stop_rate = sum(int(finish_reason == "stop") for finish_reason in finish_reasons) / len(finish_reasons)
        with Timer("💰 [Data Preparation Thread] Calculating rewards"):
            scores, reward_metrics = reward_fn(decoded_responses, items, reasoning_gym_dataset)
        with Timer("🎆 [Data Preparation Thread] Calculating advantages"):
            # Calculate advantages: normalize each response's score against the
            # mean/std of its prompt group (num_samples_per_prompt_rollout samples)
            scores = np.array(scores)
            print(f"[Data Preparation Thread] {len(scores)=}")
            scores_per_prompt = scores.reshape(-1, args.num_samples_per_prompt_rollout)
            global_mean_grouped_rewards = scores_per_prompt.mean(axis=-1)
            global_mean_grouped_rewards = np.repeat(
                global_mean_grouped_rewards, args.num_samples_per_prompt_rollout, axis=0
            )
            global_std_grouped_rewards = scores_per_prompt.std(axis=-1)
            global_std_grouped_rewards = np.repeat(
                global_std_grouped_rewards, args.num_samples_per_prompt_rollout, axis=0
            )
            global_advantages = (scores - global_mean_grouped_rewards) / (global_std_grouped_rewards + 1e-8)
            # Vectorized advantage calculation: create a lookup array where each index corresponds to a response mask value
            # and each value is the corresponding advantage score: index 0 is set to 0 since response masks start from 1 (1-indexed)
            lookup_advantages = np.zeros(len(global_advantages) + 1, dtype=np.float32)
            lookup_advantages[1:] = global_advantages
            packed_advantages = [
                torch.tensor(lookup_advantages[packed_mask], dtype=torch.float32)
                for packed_mask in packed_sequences.response_masks
            ]
            packed_sequences.advantages = packed_advantages
        with Timer("🔄 [Data Preparation Thread] Prepare collated data for each worker"):
            B = (
                len(packed_sequences.query_responses) // args.world_size
            )  # essentially doing `drop_last=True`, which is fine.
            collated_data = []
            for i in range(args.world_size):
                # slice out this worker's shard of the packed batch
                per_device_packed_query_responses = packed_sequences.query_responses[B * i : B * (i + 1)]
                per_device_packed_attention_masks = packed_sequences.attention_masks[B * i : B * (i + 1)]
                per_device_packed_position_ids = packed_sequences.position_ids[B * i : B * (i + 1)]
                per_device_packed_advantages = packed_sequences.advantages[B * i : B * (i + 1)]
                per_device_packed_response_masks = packed_sequences.response_masks[B * i : B * (i + 1)]
                # Shuffle the batch and collate the data
                b_inds = np.random.permutation(len(per_device_packed_query_responses))
                collated_query_responses = []
                collated_attention_masks = []
                collated_position_ids = []
                collated_response_masks = []
                collated_advantages = []
                for j in range(0, len(per_device_packed_query_responses), args.per_device_train_batch_size):
                    micro_range = b_inds[j : j + args.per_device_train_batch_size]
                    collated_query_responses.append(
                        collate_fn(
                            [per_device_packed_query_responses[idx] for idx in micro_range], tokenizer.pad_token_id
                        )
                    )
                    collated_attention_masks.append(
                        collate_fn([per_device_packed_attention_masks[idx] for idx in micro_range], 0)
                    )
                    collated_position_ids.append(
                        collate_fn([per_device_packed_position_ids[idx] for idx in micro_range], 0)
                    )
                    collated_response_masks.append(
                        collate_fn([per_device_packed_response_masks[idx] for idx in micro_range], 0)
                    )
                    collated_advantages.append(
                        collate_fn([per_device_packed_advantages[idx] for idx in micro_range], 0)
                    )
                collated_data.append(
                    {
                        "collated_query_responses": collated_query_responses,
                        "collated_attention_masks": collated_attention_masks,
                        "collated_position_ids": collated_position_ids,
                        "collated_advantages": collated_advantages,
                        "collated_response_masks": collated_response_masks,
                    }
                )
        sequence_lengths = np.array([len(response) for response in responses])
        metrics = {
            "scores": np.array(scores).mean(),
            "val/sequence_lengths": sequence_lengths.mean(),
            "val/sequence_lengths_min": sequence_lengths.min(),
            "val/sequence_lengths_max": sequence_lengths.max(),
            "val/stop_rate": stop_rate,
            **reward_metrics,
        }
        if args.save_traces:
            traces = {
                "scores": scores.tolist(),
                "finish_reasons": finish_reasons,
                "responses": responses,
                "queries": queries,
                # NOTE(review): items is indexed with "answer" here although it is
                # passed to reward_fn as per-response entries — confirm its actual
                # shape against the queries producer.
                "ground_truths": items["answer"],
                "datasets": reasoning_gym_dataset.dataset_name,
                "training_step": training_step,
                **reward_metrics,
            }
            os.makedirs(args.output_dir, exist_ok=True)
            with open(f"{args.output_dir}/traces_{args.run_name}.jsonl", "a") as f:
                json.dump(traces, f)
                f.write("\n")
        packed_sequences_Q.put(
            {
                "packed_sequences": packed_sequences,  # for debugging purposes
                "collated_data": collated_data,
                "metrics": metrics,
                "responses_count": len(responses),
                "num_new_tokens": num_new_tokens,
                "B": B,
            }
        )
def reward_fn(responses, items, reasoning_gym_dataset):
    """Score a batch of model responses against their reasoning-gym entries.

    Args:
        responses: raw model completions, one per rollout.
        items: the dataset entries the responses answer, parallel to ``responses``.
        reasoning_gym_dataset: dataset whose ``.data.score_answer`` grades answers.

    Returns:
        Tuple of (rewards, metrics) where ``rewards`` has exactly one combined
        reward per response and ``metrics`` holds the individual components.
    """
    metrics = {}
    formatted_responses = [extract_answer(response) for response in responses]
    correctness_rewards = [
        reasoning_gym_dataset.data.score_answer(formatted_response, entry)
        for formatted_response, entry in zip(formatted_responses, items)
    ]
    format_rewards = soft_format_reward_func(responses)
    # Bug fix: `correctness_rewards + format_rewards` concatenated the two
    # lists, yielding twice as many rewards as responses. Sum element-wise so
    # each response gets exactly one combined reward.
    rewards = [c + f for c, f in zip(correctness_rewards, format_rewards)]
    metrics["correctness_rewards"] = correctness_rewards
    metrics["format_rewards"] = format_rewards
    metrics["rewards"] = rewards
    return rewards, metrics
def main(args: Args, model_config: ModelConfig, reward_fn: Callable):
    """Run distributed GRPO training over a reasoning-gym dataset.

    Orchestrates three cooperating components:
      1. a vLLM generation thread producing rollouts,
      2. a data-preparation thread that scores and packs them,
      3. this main loop, which broadcasts fresh weights to vLLM, dispatches
         packed batches to the Ray policy learners, and saves checkpoints.

    Args:
        args: training hyperparameters and runtime configuration.
        model_config: model/tokenizer names and revisions.
        reward_fn: callable scoring (responses, items, dataset) -> (rewards, metrics).
    """
    calculate_runtime_args(args)

    # Setup experiment tracking and seeds
    all_configs = {}
    beaker_config = None
    if is_beaker_job():
        beaker_config = maybe_get_beaker_config()
        all_configs.update(vars(beaker_config))
    all_configs.update(**asdict(args), **asdict(model_config))
    if args.with_tracking:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=all_configs,
            name=args.run_name,
            save_code=True,
            tags=[args.exp_name] + get_wandb_tags(),
        )

    # Setup tokenizer and get datasets
    tokenizer_revision = (
        model_config.model_revision if model_config.tokenizer_revision is None else model_config.tokenizer_revision
    )
    tokenizer_name = (
        model_config.tokenizer_name if model_config.tokenizer_name is not None else model_config.model_name_or_path
    )
    if tokenizer_revision != model_config.model_revision:
        # Warn user if tokenizer and model use different revisions; this is an unusual
        # use case.
        warning = f"""Requested tokenizer revision `{tokenizer_revision}` is different
from the model revision `{model_config.model_revision}`."""
        print(warning)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, revision=tokenizer_revision)
    reasoning_gym_dataset = ReasoningGymDataset(
        args.dataset_name, args.seed, args.size, tokenizer, "system", SYSTEM_PROMPTS["DeepSeekZero"]
    )
    eval_dataset = ReasoningGymDataset(
        args.dataset_name, args.eval_seed, args.eval_size, tokenizer, "system", SYSTEM_PROMPTS["DeepSeekZero"]
    )

    # Reserve learner GPUs via a Ray placement group and spawn the policy actors.
    bundles = [{"GPU": actor_num_gpus, "CPU": actor_num_gpus * 10} for actor_num_gpus in args.num_learners_per_node]
    pg = placement_group(bundles, strategy="PACK")
    ray.get(pg.ready())
    inits = []
    policy_group = ModelGroup(
        pg,
        PolicyTrainerRayProcess,
        args.num_learners_per_node,
        args.single_gpu_mode,
    )
    wandb_url = wandb.run.get_url() if args.with_tracking else None
    inits.extend(
        model.from_pretrained.remote(args, model_config, beaker_config, wandb_url, tokenizer)
        for model in policy_group.models
    )
    max_len = args.max_prompt_token_length + args.response_length
    vllm_engines = create_vllm_engines(
        args.vllm_num_engines,
        args.vllm_tensor_parallel_size,
        args.vllm_enforce_eager,
        model_config.model_name_or_path,
        model_config.model_revision,
        args.seed,
        args.vllm_enable_prefix_caching,
        max_len,
        args.vllm_gpu_memory_utilization,
        args.single_gpu_mode,
        pg=pg if args.single_gpu_mode else None,
    )
    ray.get(inits)
    print("======== ✅ all models and vLLM engines initialized =========")

    ray.get([m.setup_model_update_group.remote(vllm_engines=vllm_engines) for m in policy_group.models])
    print("======== ✅ model update group setup successfully =========")

    # Setup training
    generation_config = SamplingParams(
        temperature=args.temperature,
        top_p=1.0,
        max_tokens=args.response_length,
        include_stop_str_in_output=True,
        n=args.num_samples_per_prompt_rollout,
        stop=args.stop_strings,
    )
    eval_generation_config = SamplingParams(
        temperature=0.0,
        top_p=1.0,
        max_tokens=args.response_length,
        include_stop_str_in_output=True,
        n=1,  # since we are doing greedy sampling, don't need to generate more
        stop=args.stop_strings,
    )
    iter_dataloader = DataLoader(
        reasoning_gym_dataset, batch_size=args.num_unique_prompts_rollout, collate_fn=list_preserving_collate
    )
    # Bug fix: create the dataloader iterator ONCE. The previous code called
    # `next(iter(iter_dataloader))` on every step, which restarted the
    # dataloader each time and trained on the same first batch forever.
    data_iter = iter(iter_dataloader)

    # Bounded queues give backpressure between the generation, packing, and
    # training loops.
    inference_results_Q = Queue(maxsize=1)
    param_prompt_Q = Queue(maxsize=1)
    evaluation_inference_results_Q = Queue(maxsize=1)
    packed_sequences_Q = Queue(maxsize=1)
    queries_prompt_Q = Queue(maxsize=1)
    num_eval_samples = 50  # NOTE(review): currently unused — eval below takes one sample; confirm intent
    eval_prompt_token_ids = None
    eval_ground_truths = None
    if eval_dataset is not None:
        model_eval_input, eval_items = next(iter(eval_dataset))
        eval_prompt_token_ids = model_eval_input
    resume_training_step = 1
    thread = threading.Thread(
        target=vllm_generate_thread,
        args=(
            vllm_engines,
            generation_config,
            eval_generation_config,
            inference_results_Q,
            param_prompt_Q,
            args.num_training_steps,
            eval_prompt_token_ids,
            evaluation_inference_results_Q,
            args.eval_freq,
            resume_training_step,
        ),
    )
    thread.start()
    print("======== ✅ vllm generate thread starts =========")
    packing_thread = threading.Thread(
        target=data_preparation_thread,
        args=(
            reward_fn,
            inference_results_Q,
            packed_sequences_Q,
            queries_prompt_Q,
            args,
            tokenizer,
            args.num_training_steps,
            reasoning_gym_dataset,
        ),
    )
    packing_thread.start()
    print("======== ✅ data preparation thread starts =========")

    # Prime the pipeline with the first batch of prompts.
    model_inputs, items = next(data_iter)
    queries_prompt_Q.put((model_inputs, items))
    param_prompt_Q.put((None, model_inputs))

    episode = 0
    num_total_tokens = 0
    start_time = time.time()
    for training_step in range(resume_training_step, args.num_training_steps + 1):
        episode += args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout
        if training_step != 1:
            try:
                model_inputs, items = next(data_iter)
            except StopIteration:
                # Epoch exhausted: restart the dataloader and keep training.
                data_iter = iter(iter_dataloader)
                model_inputs, items = next(data_iter)
            with Timer("🔄 Loading weights using shared memory"):
                ray.get([m.broadcast_to_vllm.remote() for m in policy_group.models])
            queries_prompt_Q.put((model_inputs, items))
            param_prompt_Q.put((None, model_inputs))
        with Timer("[Main Thread] 📦 Getting packed sequences from thread"):
            packed_data = packed_sequences_Q.get()
            data_thread_metrics = packed_data["metrics"]
            B = packed_data["B"]
            collated_data = packed_data["collated_data"]
            num_total_tokens += packed_data["num_new_tokens"]
        if B == 0:
            print("[Main Thread] 🤡 After packing, there is not enough data to train")
            continue
        # ------------------------------------------------------------------------------------------------
        # Train the model
        with Timer("[Main Thread] 🗡️ Training"):
            metrics_list: List[dict[str, float]] = ray.get(
                [
                    policy_group.models[i].train.remote(
                        **collated_data[i],
                        pad_token_id=tokenizer.pad_token_id,
                        num_mini_batches=args.num_mini_batches,
                    )
                    for i in range(args.world_size)
                ]
            )
            average_metrics = {k: sum(m[k] for m in metrics_list) / len(metrics_list) for k in metrics_list[0]}
            # NOTE(review): `metrics` is assembled but not explicitly logged here;
            # presumably surfaced via wandb's sync_tensorboard — confirm.
            metrics = {
                "episode": episode,
                "training_step": training_step,
                "val/num_total_tokens": num_total_tokens,
                "epoch": episode / len(reasoning_gym_dataset),
                "tokens_per_second": num_total_tokens / (time.time() - start_time),
                **data_thread_metrics,
                **average_metrics,
            }
        if args.save_freq > 0 and training_step % args.save_freq == 0:
            with Timer("[Main Thread] 🗡️ Saving model"):
                checkpoint_dir = f"{args.output_dir}_checkpoints"
                step_dir = os.path.join(checkpoint_dir, f"step_{training_step}")
                print(f"Saving model at step {training_step} to {step_dir}")
                ray.get([policy_group.models[i].save_model.remote(step_dir) for i in range(args.world_size)])

    print(f"Saving final model at step {training_step} to {args.output_dir}")
    with Timer("[Main Thread] 🗡️ Saving model"):
        ray.get([policy_group.models[i].save_model.remote(args.output_dir) for i in range(args.world_size)])

    # Tear down the worker threads and the Ray cluster.
    thread.join()
    print("======== ✅ vllm generate thread ends =========")
    packing_thread.join()
    print("======== ✅ data preparation thread ends =========")
    ray.shutdown()

    if (
        args.try_auto_save_to_beaker
        and is_beaker_job()
        and len(beaker_config.beaker_dataset_id_urls) > 0
        and args.output_dir.rstrip("/") != "/output"
    ):
        shutil.copytree(args.output_dir, "/output", dirs_exist_ok=True)
    print("finished training")

    accelerator = Namespace()
    accelerator.is_main_process = True  # hack
    if args.push_to_hub:
        print("Pushing model to hub")
        push_folder_to_hub(
            accelerator,
            args.output_dir,
            args.hf_repo_id,
            args.hf_repo_revision,
        )
if __name__ == "__main__":
    # Parse CLI flags into the two config dataclasses and launch training.
    cli_parser = ArgumentParserPlus((GRPOArgs, ModelConfig))
    parsed_args, parsed_model_config = cli_parser.parse_args_into_dataclasses()
    assert isinstance(parsed_args, GRPOArgs)
    assert isinstance(parsed_model_config, ModelConfig)
    main(parsed_args, parsed_model_config, reward_fn)

View file

@ -1,104 +0,0 @@
from typing import List
import torch
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer
import reasoning_gym
from reasoning_gym.utils import SYSTEM_PROMPTS
def list_preserving_collate(batch):
    """Collate a batch of (token_ids, item) pairs into two parallel lists.

    Unlike the default collate, this keeps plain Python lists intact instead
    of converting them to tensors.
    """
    if not batch:
        return [], []
    token_ids, items = zip(*batch)
    return list(token_ids), list(items)
class ReasoningGymDataset(Dataset):
    """Torch dataset over a procedurally generated reasoning-gym dataset.

    Each item is tokenized as a chat prompt — a developer/system message
    followed by the puzzle question — ready for generation.
    """

    def __init__(self, dataset_name, seed, size, tokenizer, developer_role, developer_prompt):
        self.data = reasoning_gym.create_dataset(dataset_name, seed=seed, size=size)
        self.tokenizer = tokenizer
        self.developer_role = developer_role
        self.developer_prompt = developer_prompt

    def __len__(self):
        return self.data.size

    def __getitem__(self, index):
        item = self.data[index]
        messages = [
            {"role": self.developer_role, "content": self.developer_prompt},
            {"role": "user", "content": item["question"]},
        ]
        token_ids = self.tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True)
        # Strip any padding tokens the chat template may have introduced.
        pad_id = self.tokenizer.pad_token_id
        token_ids = [tok for tok in token_ids if tok != pad_id]
        return token_ids, item
def pack_sequences(
    queries: List[List[int]],
    responses: List[List[int]],
    pack_length: int,
    pad_token_id: int,
) -> "PackedSequences":
    """Greedily pack query+response token sequences into buffers of at most
    ``pack_length`` tokens for packed-sequence training.

    For every packed buffer this also builds, token-aligned:
      - an attention mask of 1-based segment ids (restarting at 1 per buffer)
        so attention cannot cross sequence boundaries,
      - a response mask that is 0 over query tokens and ``i + 1`` (1-based
        global sample index) over response tokens,
      - per-sample response lengths (``num_actions``) and per-sample total
        lengths (``packed_seq_lens``).

    NOTE(review): a single query+response longer than ``pack_length`` still
    becomes its own oversized buffer — presumably inputs are truncated
    upstream; confirm.
    """
    # assert padding token does not exist in queries and responses
    query_responses = []
    attention_masks = []
    response_masks = []
    num_actions = []
    packed_seq_lens = []
    # `cur_*` accumulate the buffer currently being filled.
    cur_data = []
    cur_response_mask = []
    cur_num_actions = []
    cur_packed_seq_lens = []
    cur_attention_mask = []
    # Index of the first sample in the current buffer; used below so the
    # attention-mask segment ids restart at 1 for each new buffer.
    offset = 0
    for i in range(len(queries)):
        query = queries[i]
        response = responses[i]
        # remove padding (but using vllm so this should not be needed, but just in case)
        query = [t for t in query if t != pad_token_id]
        response = [t for t in response if t != pad_token_id]
        query_response = query + response
        if len(query_response) + len(cur_data) > pack_length:
            # This sample does not fit: flush the current buffer and start a new one.
            query_responses.append(cur_data)
            response_masks.append(cur_response_mask)
            attention_masks.append(cur_attention_mask)
            num_actions.append(cur_num_actions)
            packed_seq_lens.append(cur_packed_seq_lens)
            cur_data = []
            cur_response_mask = []
            cur_attention_mask = []
            cur_num_actions = []
            cur_packed_seq_lens = []
            offset = i
        cur_data.extend(query_response)
        cur_num_actions.append(len(response))
        cur_packed_seq_lens.append(len(query_response))
        # 0 marks query tokens (no loss); i + 1 tags response tokens with the
        # 1-based global sample index they belong to.
        cur_response_mask.extend([0 for _ in range(len(query))] + [i + 1 for _ in range(len(response))])
        # Per-buffer segment id (1-based within the buffer) shared by all
        # tokens of this sample.
        cur_attention_mask.extend([i + 1 - offset for _ in range(len(query_response))])
    if len(cur_data) > 0:
        # Flush the final, partially filled buffer.
        query_responses.append(cur_data)
        response_masks.append(cur_response_mask)
        attention_masks.append(cur_attention_mask)
        num_actions.append(cur_num_actions)
        packed_seq_lens.append(cur_packed_seq_lens)
    attention_masks_list = [torch.tensor(t) for t in attention_masks]
    # Local import: PackedSequences/reset_position_ids are resolved at call time.
    from open_instruct.rl_utils2 import PackedSequences, reset_position_ids

    return PackedSequences(
        query_responses=[torch.tensor(t) for t in query_responses],
        attention_masks=attention_masks_list,
        position_ids=[reset_position_ids(t.unsqueeze(0)).squeeze(0) for t in attention_masks_list],
        response_masks=[torch.tensor(t) for t in response_masks],
        original_responses=responses,
        num_actions=[torch.tensor(t) for t in num_actions],
        packed_seq_lens=[torch.tensor(t) for t in packed_seq_lens],
    )

View file

@ -1,15 +1,10 @@
# Reasoning Gym configs
dataset_size: 20000
dataset_size: 10000
developer_prompt: DeepSeekZero
developer_role: system
datasets:
simple_equations:
chain_sum:
weight: 1
complex_arithmetic:
weight: 1
config:
min_real: -20
max_real: 20
# Model configs from trl
@ -34,9 +29,9 @@ lr_scheduler_type: constant_with_warmup
lr_scheduler_kwargs:
num_warmup_steps: 10
max_prompt_length: 512
max_completion_length: 2048
max_completion_length: 512
max_steps: 100
num_generations: 8
num_generations: 16
num_train_epochs: 1
overwrite_output_dir: true
per_device_train_batch_size: 8

View file

@ -27,7 +27,7 @@ class DatasetConfigItem:
@dataclass
class DatasetConfig:
dataset_size: int = field(default=1000)
dataset_size: int = field(default=10000)
developer_prompt: str = field(default="DeepSeekZero")
developer_role: str = field(default="system")
datasets: dict[str, DatasetConfigItem] = field(default=None)

View file

@ -0,0 +1,30 @@
# Chain Sum LORA Training with unsloth
This example demonstrates how to fine-tune an LLM with RL on a reasoning gym environment using the **unsloth** framework. Unsloth is an efficient open-source library for fine-tuning & RL. Unsloth's default training path uses quantised low-rank adaptation (QLoRA), which results in a significantly lower memory footprint ($\approx 3\times$ smaller) and means you can significantly increase batch sizes and context length without risking OOM errors.
Requirements:
python >= 3.10
## Installation
1. **Install reasoning-gym**:
```bash
pip install reasoning-gym
```
2. **Install unsloth dependencies**:
```bash
pip install -r requirements.txt
```
3. **Run training script**
To start training with unsloth with RG environments using default arguments run the following:
```bash
python train_grpo_lora.py
```
To customise/override any default arguments you can simply:
```bash
python train_grpo_lora.py --dataset-name chain_sum --max-seq-length 512 --model-id Qwen/Qwen2.5-7B-Instruct
```
**Note** the free open-source version of unsloth is currently built to train models in single GPU environments only.

View file

@ -147,22 +147,24 @@ def main(args):
training_args = GRPOConfig(
output_dir="outputs",
use_vllm=True,
learning_rate=5e-6,
learning_rate=1e-6,
adam_beta1=0.9,
adam_beta2=0.99,
weight_decay=0.1,
warmup_ratio=0.1,
lr_scheduler_type="cosine",
weight_decay=0.0,
warmup_ratio=0.0,
lr_scheduler_type="constant",
optim="adamw_8bit",
logging_steps=1,
bf16=is_bfloat16_supported(),
fp16=not is_bfloat16_supported(),
per_device_train_batch_size=args.train_batch_size,
gradient_accumulation_steps=1,
gradient_accumulation_steps=4,
num_generations=args.num_generations,
num_train_epochs=args.train_epochs,
max_prompt_length=512,
max_completion_length=512,
save_steps=100,
max_grad_norm=0.1,
max_grad_norm=1.0,
)
train(model, tokenizer, dataset, training_args)
@ -187,17 +189,17 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-id", type=str, default="Qwen/Qwen2.5-1.5B-Instruct")
parser.add_argument("--dataset-name", type=str)
parser.add_argument("--dataset-name", type=str, default="chain_sum")
parser.add_argument("--max-seq-length", type=int, default=1024)
parser.add_argument("--lora-rank", type=int, default=64)
parser.add_argument("--quantize", action="store_true")
parser.add_argument("--num-generations", type=int, default=8)
parser.add_argument("--num-generations", type=int, default=16)
parser.add_argument("--train-epochs", type=int, default=1)
parser.add_argument("--train-batch-size", type=int, default=8)
parser.add_argument("--train-batch-size", type=int, default=16)
parser.add_argument("--dataset-seed", type=int, default=42)
parser.add_argument("--dataset-size", type=int, default=1000)
parser.add_argument("--dataset-size", type=int, default=10000)
parser.add_argument("--eval-seed", type=int, default=42)
parser.add_argument("--eval-size", type=int, default=100)

View file

@ -1,3 +0,0 @@
outputs/
wandb/
verl_output.log

View file

@ -1,32 +1,74 @@
### env setup
# Chain Sum Training with veRL
```
conda create --name verl python=3.11 -y
conda activate verl
This example demonstrates how to train a language model using veRL (Volcano Engine Reinforcement Learning) with the reasoning-gym environment for chain sum problems.
pip install flash-attn --no-build-isolation
pip install ray wandb
# pip3 install vllm==0.7.0
pip3 install vllm --pre --extra-index-url https://wheels.vllm.ai/nightly
Requirements:
python >= 3.10
## Installation
1. **Install veRL**: Follow the installation instructions at [veRL repository](https://github.com/volcengine/verl)
2. **Install reasoning-gym**:
```bash
pip install reasoning-gym
```
## Training
To start training the model on chain sum problems:
```bash
python grpo_train.py --config-path config --config-name grpo_trainer
```
Regarding vllm>0.7 see: [docs](https://verl.readthedocs.io/en/latest/README_vllm0.7.html)
### Configuration
You can modify the training by editing the configuration file or overriding arguments in the shell scripts directly
### clone and install veRL
tested with verl HEAD c34206925e2a50fd452e474db857b4d488f8602d
```
git clone https://github.com/volcengine/verl.git
cd verl
pip install -e .
```bash
# Change dataset
Here it is easiest to modify the `config/grpo_trainer.yaml` file with a custom training composite. Here is an example experiment which uses a composite of algorithmic training tasks
```yaml
reasoning_gym:
dataset_size: 20000
developer_prompt: DeepSeekZero
datasets:
ab:
weight: 1
base_conversion:
weight: 1
binary_alternation:
weight: 1
config:
p_solvable: 0.9
binary_matrix:
weight: 1
config:
min_n: 2
max_n: 6
caesar_cipher:
weight: 1
config:
max_words: 10
cryptarithm:
weight: 1
isomorphic_strings:
weight: 1
config:
max_string_length: 8
```
**Note**: In `config/grpo_trainer.yaml` we specify default arguments to be read from `file:///workspace/verl/verl/trainer/config`. Modify this accordingly if you have a different local folder setup.
Optionally log in to huggingface hub and wandb with your keys:
# Change configuration
Set `project_name` and `experiment_name` if logging your runs to W&B.
This config assumes a single GPU node, but you can configure this too. The following command would be for 2 GPUs, with 1 used for vLLM rollouts:
```
huggingface-cli login
wandb login
```
python3 -u train_grpo.py --config-paths configs/inter_generalisation --config-name algorithmic_qwen_3b \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
trainer.n_gpus_per_node=2 \
trainer.project_name=rg-grpo \
trainer.experiment_name=algorithmic_qwen2.5_3b
Or similarly you could define this in a config file directly

View file

@ -1,12 +1,15 @@
hydra:
searchpath:
- file:///workspace/verl/verl/trainer/config
defaults:
- ppo_trainer
- _self_
reasoning_gym:
dataset_size: 20000
dataset_size: 10000
developer_prompt: DeepSeekZero
datasets:
ab:
chain_sum:
weight: 1
data:
@ -16,7 +19,7 @@ data:
prompt_key: prompt
max_prompt_length: 512
max_response_length: 512
train_batch_size: 16
train_batch_size: 64
val_batch_size: 1
actor_rollout_ref:
@ -32,7 +35,7 @@ actor_rollout_ref:
strategy: fsdp # This is for backward-compatibility
ppo_mini_batch_size: 16
ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
ppo_micro_batch_size_per_gpu: 4
ppo_micro_batch_size_per_gpu: 8
use_dynamic_bsz: False
ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
grad_clip: 1.0
@ -70,7 +73,7 @@ actor_rollout_ref:
# transformer_layer_cls_to_wrap: None
min_num_params: 0
log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
log_prob_micro_batch_size_per_gpu: 4
log_prob_micro_batch_size_per_gpu: 8
log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
@ -90,7 +93,7 @@ actor_rollout_ref:
enforce_eager: True
free_cache_engine: True
load_format: dummy_dtensor
tensor_model_parallel_size: 1
tensor_model_parallel_size: 2
max_num_batched_tokens: 8192
max_num_seqs: 1024
log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
@ -184,14 +187,14 @@ algorithm:
trainer:
balance_batch: True
total_epochs: 30
total_epochs: 1
total_training_steps: null
project_name: verl_examples
experiment_name: chain_sum
logger: [ 'console', 'wandb' ]
val_generations_to_log_to_wandb: 0
nnodes: 1
n_gpus_per_node: 1
n_gpus_per_node: 2
save_freq: -1
# auto: find the last ckpt to resume. If can't find, start from scratch
resume_mode: auto # or auto or resume_path if

View file

@ -1,72 +0,0 @@
# Chain Sum Training with veRL
This example demonstrates how to train a language model using veRL (Volcano Engine Reinforcement Learning) with the reasoning-gym environment for chain sum problems.
Requirements:
python >= 3.10
## Installation
1. **Install veRL**: Follow the installation instructions at [veRL repository](https://github.com/volcengine/verl)
2. **Install reasoning-gym**:
```bash
pip install reasoning-gym
```
## Training
To start training the model on chain sum problems:
```bash
python grpo_train.py --config-path config --config-name grpo_trainer
```
### Configuration
You can modify the training by editing the configuration file or overriding arguments in the shell scripts directly
```bash
# Change dataset
Here it is easiest to modify the `config/grpo_trainer.yaml` file with a custom training composite. Here is an example experiment which uses a composite of algorithmic training tasks
```yaml
reasoning_gym:
dataset_size: 20000
developer_prompt: DeepSeekZero
datasets:
ab:
weight: 1
base_conversion:
weight: 1
binary_alternation:
weight: 1
config:
p_solvable: 0.9
binary_matrix:
weight: 1
config:
min_n: 2
max_n: 6
caesar_cipher:
weight: 1
config:
max_words: 10
cryptarithm:
weight: 1
isomorphic_strings:
weight: 1
config:
max_string_length: 8
```
# Change configuration
Set `project_name` and `experiment_name` if logging your runs to W&B.
This config assumes a single GPU node, but you can configure this too. The following command would be for 2 GPUs, with 1 used for vLLM rollouts:
python3 -u train_grpo.py --config-paths configs/inter_generalisation --config-name algorithmic_qwen_3b \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
trainer.n_gpus_per_node=2 \
trainer.project_name=rg-grpo \
trainer.experiment_name=algorithmic_qwen2.5_3b
Or similarly you could define this in a config file directly

View file

@ -1,131 +0,0 @@
# Word Ladder Puzzle Dataset Generator
## Overview
This project generates a dataset of word ladder puzzles and (optionally) submits chain-of-thought reasoning requests using Anthropic's Message Batches API. Each puzzle is stored as a JSON object with the following keys: `question`, `answer`, `metadata`, and `reasoning` (initially `null`).
The project consists of several key components:
- **`main.py`**:
Orchestrates the overall flow. It performs the following tasks:
1. Generates a dataset of word ladder puzzles by calling functions from `utils/create_word_ladders.py`.
2. (Optionally) Triggers the reasoning request process to augment puzzles with chain-of-thought reasoning via `utils/generate_reasoning.py`.
3. (Planned) Additional steps such as checking results or uploading the final dataset.
The configuration for the dataset parameters (e.g., word length, chain length, and dataset size) is centralized here, making it easy to adjust the settings as needed.
- **`utils/create_word_ladders.py`**:
Contains functions to create and validate a word ladder dataset. It leverages underlying modules (e.g., `reasoning_gym`) to generate individual puzzles and ensures uniqueness across the dataset.
- **`utils/generate_reasoning.py`**:
Reads the generated dataset (in JSONL format), then filters out puzzles that already have reasoning. For puzzles missing chain-of-thought data, it splits them into batches (with a default batch size that you can adjust) and submits each batch to Anthropic's Message Batches API. Each API request includes the puzzle along with a custom system prompt (read from `system_prompt.txt`), and the resulting metadata is stored for later retrieval and analysis.
- **`usage_stats.py`**:
Analyzes API response files to compute detailed usage statistics. This script:
- Extracts token usage metrics such as `input_tokens`, `cache_creation_input_tokens`, `cache_read_input_tokens`, and `output_tokens`.
- Calculates costs based on pricing data and shows the savings achieved through prompt caching.
- Forecasts costs for various quantities of jobs (e.g., 2,000, 4,000, 10,000, 20,000, and 50,000 puzzles) using the observed average token usage.
This is especially useful for monitoring your API spend and ensuring that your usage stays within budget.
## Warning
**Caution:**
Running large batches of requests via the Anthropic API (especially in `generate_reasoning.py`) can incur significant costs in Anthropic credits. **Please review and understand your API quota and budgeting before running the API call.** If you are just testing or working with a demo dataset, ensure you adjust the batch size or dataset size appropriately to avoid unexpected charges.
## Prerequisites
- **Python Version:** Python 3.7+
- **Dependencies:**
- `tqdm`
- `anthropic`
- `reasoning_gym`
- **Environment Variables:**
For generating reasoning batches, set your Anthropic API key:
```bash
export ANTHROPIC_API_KEY=your_api_key_here
```
## Directory Structure
```
examples/word_ladder/
├── main.py
├── utils/
│ ├── create_word_ladders.py
│ ├── generate_reasoning.py
│ └── system_prompt.txt
├── usage_stats.py
```
## Configuration
The dataset generation parameters are centralized in `main.py` under the `config` dictionary. You can adjust settings like:
- **Word Length:**
- `min_word_length`
- `max_word_length`
- **Chain Length:**
- `min_chain_length` (e.g., set to -1 for the shortest possible chain)
- `max_chain_length`
- **Dataset Size:**
- `size` — the number of puzzles to generate (e.g., `1000` for a demo)
## How to Run
1. **Generate the Dataset**
Run the main script:
```bash
python3 main.py
```
This does the following:
- Generates a unique JSONL file containing the word ladder puzzles in the `output` folder.
- Calls functions from `utils/create_word_ladders.py` to create the puzzles.
- Optionally (if enabled), submits the puzzles for chain-of-thought reasoning via the API.
2. **Submit Reasoning Batches (Optional)**
To generate chain-of-thought reasoning for puzzles:
- Verify that `ANTHROPIC_API_KEY` is set.
- Confirm that `system_prompt.txt` is present in the `/examples/word_ladder` folder and contains the desired system prompt.
- In `main.py`, uncomment the reasoning submission section to enable the API call, or run directly:
```bash
python3 utils/generate_reasoning.py
```
**Warning:** Be aware that submitting large batches can quickly incur high costs in Anthropic credits.
3. **Compute Usage Statistics**
After running batches through the API, you can analyze the cost and token usage statistics with:
```bash
python3 usage_stats.py path/to/msgbatch_results.jsonl
```
This script provides detailed costing information, token usage per query, savings from caching, and forecasting for future job batches.
## Output
- All generated datasets and batch metadata files are stored in the `/examples/word_ladder/output` folder.
- After submitting reasoning batches via Anthropic's API, you can monitor progress and download the batch results from the Anthropic web dashboard.
- Use `usage_stats.py` to compute detailed statistics and forecast future costs based on your current usage and token pricing.
## Troubleshooting
- **File Paths:**
Verify that `system_prompt.txt` is in the `/examples/word_ladder` folder as expected. The modules use paths relative to their location.
- **Environment Variables:**
Make sure your `ANTHROPIC_API_KEY` is set correctly when submitting API requests.
- **Output Directory Permissions:**
Ensure the `output` directory exists and is writable by your user.
- **Cost Monitoring:**
Check your Anthropic API usage and account balance before running large batches to avoid unexpected costs.
## License
This project is licensed under the MIT License.

View file

@ -1,84 +0,0 @@
#!/usr/bin/env python3
"""
main.py Orchestrates the overall flow:
1. Generate word ladder sets
2. Submit chain-of-thought reasoning requests in batches via the LLM
3. Upload the final dataset to HuggingFace Hub (if needed)
"""
import sys
import uuid
from pathlib import Path
from typing import Any
from examples.word_ladder.utils import create_word_ladders, generate_reasoning
def create_dataset(jsonl_path: Path, config: dict[str, Any]) -> bool:
    """Generate the word ladder dataset, tolerating puzzle exhaustion.

    Returns:
        bool: True if a dataset was produced (even if truncated), False if
        creation failed outright.
    """
    try:
        print("Step 1: Algorithmically creating word ladder chains...")
        create_word_ladders.create_word_ladder_dataset(str(jsonl_path), config=config)
    except IndexError as err:
        # The generator ran out of unique puzzles; keep whatever was written.
        print("\nNote: Dataset generation stopped early due to exhaustion of unique puzzles.")
        print(f"Reason: {str(err)}")
        if not jsonl_path.exists():
            return False
        print("Continuing with the partial dataset that was successfully generated.")
        return True
    except Exception as err:
        # Any other failure is fatal for dataset creation.
        print(f"\nError: Failed to create dataset: {str(err)}")
        return False
    return True
def main():
    """Drive the full pipeline: generate puzzles, then submit reasoning batches."""
    # Centralized configuration for the dataset
    cfg = {
        "dataset_name": "word_ladder",
        "dataset_config": {
            "min_word_length": 3,
            "max_word_length": 3,
            "min_chain_length": -1,  # set to -1 for the shortest possible path
            "max_chain_length": 7,
            "size": 2000,  # Generate a small-ish dataset for demonstration
        },
    }

    # Compose a unique output path under ./output next to this script.
    run_id = uuid.uuid4().hex[:8]
    out_dir = Path(__file__).resolve().parent / "output"
    out_dir.mkdir(exist_ok=True)  # Create output directory if it doesn't exist
    dataset_path = out_dir / f"word_ladders_{run_id}.jsonl"

    # Step 1: Create the dataset
    if not create_dataset(dataset_path, cfg):
        print("Exiting due to dataset creation failure.")
        sys.exit(1)

    # Step 2: Generate reasoning
    try:
        print("\nStep 2: Submitting reasoning batches for the dataset...")
        generate_reasoning.submit_reasoning_batches(input_path=str(dataset_path))
    except Exception as err:
        print(f"\nError: Failed to submit reasoning batches: {str(err)}")
        sys.exit(1)

    # Step 3: Check Anthropic batch results
    # Step 4: Upload to HuggingFace 🤗
    print("\nComplete!")
# Script entry point: run the full generation + reasoning pipeline.
if __name__ == "__main__":
    main()

View file

@ -1 +0,0 @@
{"batch_id": "msgbatch_01J7Dev6petQwq41gNpeGhUK", "api_response": {"id": "msgbatch_01J7Dev6petQwq41gNpeGhUK", "type": "message_batch", "processing_status": "in_progress", "request_counts": {"canceled": 0, "errored": 0, "expired": 0, "processing": 1000, "succeeded": 0}, "ended_at": null, "created_at": "2025-02-03T06:07:04.440375Z", "expires_at": "2025-02-04T06:07:04.440375Z", "cancel_initiated_at": null, "results_url": null}, "custom_ids": ["PAINT_LACKS_0", "CONKS_GRAYS_1", "WINGS_WORMY_2", "PINTO_RASPS_3", "GRACE_WAINS_4", "RACES_VAPES_5", "LAIRS_FUELS_6", "WINKS_TENTH_7", "VERBS_DATES_8", "SCENT_HEARS_9", "GRACE_CHIVE_10", "COOPS_SPARE_11", "ROOMY_PLUME_12", "PORKS_HOSTA_13", "DEARS_TABBY_14", "SNAIL_ADOBE_15", "SMALL_BLINK_16", "SOAPS_SPOTS_17", "HOMES_LOVEY_18", "SPURN_STEEP_19", "PLODS_SWIFT_20", "TACOS_SETTS_21", "BANKS_KICKY_22", "POUTS_CAKES_23", "RESTS_SWUNG_24", "PORKY_FISKS_25", "PUMAS_NAPES_26", "WANDS_PRICE_27", "DROLL_SORES_28", "JOWLS_PACTS_29", "ATONE_GOOFS_30", "BRIGS_SHUNT_31", "RAPID_DOXES_32", "CHESS_DOLED_33", "BUDGE_BOOBS_34", "ROPED_SAGAS_35", "TATTS_SPARE_36", "RALES_NATES_37", "RUFFS_GONGS_38", "VEGES_HUNTS_39", "SCULL_PAIRS_40", "DICKY_MOWED_41", "PANKO_ADDER_42", "TILLS_GAINS_43", "OASIS_DOXED_44", "BONED_MEALS_45", "LAMBS_TOADY_641", "PIKER_HEAVY_642", "BANKS_TRUGS_643", "VEEPS_DIVAS_644", "CHATS_CLEFS_645", "SNITS_PLAIN_646", "CHARY_SPITE_647", "EARLY_DUMMY_977", "MEDAL_COYER_978", "FEATS_MOOKS_979", "MALES_HEXES_980", "DOGGO_MUGGY_981", "GIRDS_SOAPS_982", "SLASH_TRYST_983", "LOOPS_FLOES_984", "SLATE_SMORE_985", "SINEW_SULKS_986", "SCORN_SHOUT_987", "MAMAS_VOIDS_988", "TWIGS_BLOKE_989", "KOOKY_SPINE_990", "FLUES_SMITE_991", "DRAWS_MAIMS_992", "SMITH_PLEAS_993", "BERGS_OPINE_994", "SNAKE_BLABS_995", "WADED_DRIBS_996", "YAPPY_PIPES_997", "STAGY_COONS_998", "PEAKY_MURAL_999"], "input_file": "word_ladders_029573f4.jsonl"}

View file

@ -1,5 +0,0 @@
{"custom_id":"PAINT_LACKS_0","result":{"type":"succeeded","message":{"id":"msg_01Kif1LqT19GgPdSEQiuFjwu","type":"message","role":"assistant","model":"claude-3-5-sonnet-20241022","content":[{"type":"text","text":"<think>\nOkay, let me work on transforming PAINT into LACKS by changing one letter at a time. First, let me analyze both words to understand what needs to change.\n\nPAINT → LACKS\n\nLooking at the letters:\nP needs to become L\nA stays A\nI needs to become C\nN needs to become K\nT needs to become S\n\nSo four letters need to change (all except A). Let me start exploring possible paths by changing one letter at a time.\n\nStarting with PAINT. What changes can I make? Let's see:\n- Change P to something: FAINT, SAINT, TAINT (all valid)\n- Change I to something: PANT (valid)\n- Change N to something: PAINT to PAIT (not valid)\n- Change T to something: PAIN (valid, but loses a letter... no, need to keep 5 letters)\nWait, what about PAINS (changing T to S)? Yes, PAINS is a valid word! \n\nSo first step: PAINT → PAINS (changing T to S)\n\nNow from PAINS, what can I do? Need to eventually get to LACKS. Let's see possible changes:\n- Change P to something: GAINS, MAINS, RAINS, WAINS\n- Change I to something: PANS (loses a letter, not allowed)\n- Change N to something: PAIRS (valid, but moves away from target LACKS)\n\nWAINS looks interesting (changing P to W). Is WAINS a valid word? Yes, it refers to wagons or carts historically. \n\nStep 2: PAINS → WAINS\n\nNow from WAINS, what's next? Looking at LACKS, we still need to get that C and K in there. Let's see:\n- Change A to something: WINS (loses letter)\n- Change I to something: WANS (loses letter)\n- Change N to something: WAITS (valid but moves away from target)\n- Change S to something: WAINT (not a word)\nWait, what about WARNS (changing I to R)? Yes! WARNS is definitely a valid word.\n\nStep 3: WAINS → WARNS\n\nNow from WARNS to LACKS. Still need L, C, and K. 
From WARNS:\n- Change W to something: BARNS, EARNS, YARNS\nBARNS looks promising! Because B is closer to L in the alphabet than W. \n\nStep 4: WARNS → BARNS\n\nFrom BARNS to LACKS. Let's see:\n- Change N to something: BARS (loses letter)\n- Change S to something: BARN (loses S we need)\n- Change R to K perhaps? BANKS? No, need to keep R for now.\nWait! What about BARKS (changing N to K)? Yes! That works!\n\nStep 5: BARNS → BARKS\n\nNow BARKS. Getting closer! We need to change R to C and B to L to get to LACKS. Let's try:\n- Change B to something: LARKS, MARKS, PARKS\n- Change R to C: BACKS! Yes, that works!\n\nStep 6: BARKS → BACKS\n\nFinally, from BACKS to LACKS, just need to change B to L:\n\nStep 7: BACKS → LACKS\n\nLet me verify the entire sequence:\nPAINT → PAINS (T→S)\nPAINS → WAINS (P→W)\nWAINS → WARNS (I→R)\nWARNS → BARNS (W→B)\nBARNS → BARKS (N→K)\nBARKS → BACKS (R→C)\nBACKS → LACKS (B→L)\n\nChecking each word:\nPAINT: valid\nPAINS: valid (plural of pain)\nWAINS: valid (old wagons)\nWARNS: valid (gives warning)\nBARNS: valid (farm buildings)\nBARKS: valid (tree covering or dog sound)\nBACKS: valid (plural of back)\nLACKS: valid (is without)\n\nEach step changes exactly one letter and creates a valid English word. Seven steps total, which makes sense since we needed to change four letters to get from PAINT to LACKS, and we had to make some strategic intermediate changes to get there. This seems like a valid solution!\n\nLet me double-check if there might be a shorter path... Could we have gone:\nPAINT → PANT (remove I? No, must keep 5 letters)\nPAINT → PARTS (change N to R and T to S? No, that's two changes at once)\nPAINT → PANIC (T to C, but moves away from target)\n\nNo, I think the path we found is optimal because we needed those intermediate steps to gradually transform the word while keeping valid English words at each step. 
The sequence PAINT→PAINS→WAINS→WARNS→BARNS→BARKS→BACKS→LACKS works perfectly!\n</think>"}],"stop_reason":"end_turn","stop_sequence":null,"usage":{"input_tokens":63,"cache_creation_input_tokens":0,"cache_read_input_tokens":6434,"output_tokens":1149}}}}
{"custom_id":"CONKS_GRAYS_1","result":{"type":"succeeded","message":{"id":"msg_01JrxHJyxFxsTVsn8huufv8T","type":"message","role":"assistant","model":"claude-3-5-sonnet-20241022","content":[{"type":"text","text":"<think>\nAlright, let me try to transform CONKS into GRAYS by changing one letter at a time. First, let me write down both words to analyze what needs to change:\n\nCONKS\nGRAYS\n\nLooking at the letters that need to change:\nC needs to become G\nO needs to become R\nN needs to become A\nK needs to become Y\nS stays S\n\nHmm, four letters need to change, but S stays the same at the end. Let me start by looking at possible changes from CONKS.\n\nFirst letter options from CONKS:\n- Change C to B: BONKS (valid)\n- Change C to D: DONKS (valid)\n- Change C to G: GONKS (is this valid? Not sure)\n- Change C to H: HONKS (valid)\n\nSecond letter options:\n- Change O to A: CANKS (is this valid? Not sure)\n- Change O to I: CINKS (probably not valid)\n- Change O to O: stays same\n- Change O to U: CUNKS (not sure if valid)\n\nWait, maybe instead of changing the first letter right away to G, let's try working with the O. What if we change O to another letter that might help us get closer? \n\nLet me try: CONKS → COOKS (changing N to O)\nYes! COOKS is definitely a valid word. That feels like progress.\n\nFrom COOKS, what can we do?\n- Change first C to G? GOOKS (probably not appropriate)\n- Change second O to R? No, need to keep both O's for now\n- Change K to something? COOLS? No, moving away from target\n- Change S? No, need to keep S\n\nWhat about changing one of the O's to something? COOKS → COOPS (changing K to P)\nYes! COOPS is valid (like chicken coops). This feels promising.\n\nNow from COOPS:\n- Change C to G? GOOPS (not sure if valid)\n- Change first O to R? CROPS! Yes, that works!\n\nCROPS is a valid word. Now we're getting somewhere. Let's see what we can do from CROPS:\n- Change O to A? CRAPS! 
That's valid and gets us closer to GRAYS\n\nFrom CRAPS:\n- Change P to Y? CRAYS! Perfect!\n\nAnd finally, from CRAYS:\n- Change C to G: GRAYS! We've reached the target!\n\nLet me verify this path:\n1. CONKS → COOKS (N to O)\n2. COOKS → COOPS (K to P)\n3. COOPS → CROPS (O to R)\n4. CROPS → CRAPS (O to A)\n5. CRAPS → CRAYS (P to Y)\n6. CRAYS → GRAYS (C to G)\n\nLet me check each word again:\nCONKS - valid (to hit on head)\nCOOKS - valid\nCOOPS - valid (enclosures)\nCROPS - valid\nCRAPS - valid\nCRAYS - valid (types of fish)\nGRAYS - valid\n\nThis seems to work perfectly! Each step changes exactly one letter, and all words are valid English words. Let me double-check if there might be a shorter path.\n\nCould we go CONKS → CONES? No, that doesn't help get to GRAYS.\nWhat about CONKS → CRANKS? No, that's two letter changes.\nOr CONKS → CORKS? Again, doesn't seem to lead anywhere useful.\n\nThe path we found:\nCONKS → COOKS → COOPS → CROPS → CRAPS → CRAYS → GRAYS\nseems optimal because:\n1. Each step is a valid word\n2. Each change moves us closer to our target\n3. We maintain the final S throughout\n4. We gradually transform the word while keeping it valid at each step\n\nYes, I think this is the solution. Six steps, changing one letter at a time, all valid words, and each change gets us closer to our target GRAYS.\n</think>"}],"stop_reason":"end_turn","stop_sequence":null,"usage":{"input_tokens":64,"cache_creation_input_tokens":0,"cache_read_input_tokens":6434,"output_tokens":956}}}}
{"custom_id":"WINGS_WORMY_2","result":{"type":"succeeded","message":{"id":"msg_01KXJZXTZgQoTnLvMqHphEcS","type":"message","role":"assistant","model":"claude-3-5-sonnet-20241022","content":[{"type":"text","text":"<think>\nOkay, let me tackle this word ladder from WINGS to WORMY. First, I'll write both words down and analyze what letters need to change:\n\nWINGS\nWORMY\n\nLooking at the letters:\nW stays the same\nI needs to become O\nN needs to become R\nG needs to become M\nS needs to become Y\n\nSo four letters need to change, keeping only the W. Let me start exploring possible paths by changing one letter at a time.\n\nStarting with WINGS, what can I change? Let's look at each position:\n- First letter (W): stays the same, so no need to change\n- Second letter (I): could become A, E, O, U\n- Third letter (N): could become other consonants\n- Fourth letter (G): could become other consonants\n- Fifth letter (S): could become other letters\n\nLet me try changing I first since it needs to become O eventually:\nWINGS → WONGS? Is that a word? Not sure. \n\nMaybe change N first:\nWINGS → WIGS? No, that removes a letter. Wait, I need to keep it 5 letters.\nWINGS → WINGS (same)\nWINGS → WINKS (changing G to K)! Yes, WINKS is definitely a valid word.\n\nOkay, so WINGS → WINKS works. Now from WINKS, what's next? I still need to:\n- Change I to O\n- Change K to R\n- Change S to Y\n\nLet me try changing I to O in WINKS:\nWINKS → WONKS\nIs WONKS a word? I think it is - it's the plural of wonk, someone who studies a subject extensively. Yes, that works!\n\nSo now I have:\nWINGS → WINKS → WONKS\n\nFrom WONKS, I need to:\n- Change K to R\n- Change S to Y\n\nLet me try changing K to R:\nWONKS → WORKS\nPerfect! WORKS is definitely a valid word.\n\nSo now:\nWINGS → WINKS → WONKS → WORKS\n\nFrom WORKS, I still need to:\n- Change S to Y\n\nBut wait, looking at the target WORMY, I notice I need M before Y. So from WORKS, maybe I should change S to M first?\n\nWORKS → WORMS\nYes! 
WORMS is a valid word.\n\nNow I have:\nWINGS → WINKS → WONKS → WORKS → WORMS\n\nFinally, from WORMS to WORMY, just need to change S to Y:\nWORMS → WORMY\n\nLet me verify the complete path:\n1. WINGS → WINKS (change G to K)\n2. WINKS → WONKS (change I to O)\n3. WONKS → WORKS (change K to R)\n4. WORKS → WORMS (change K to M)\n5. WORMS → WORMY (change S to Y)\n\nLet me check if each word is valid:\n- WINGS: valid (plural of wing)\n- WINKS: valid (plural of wink)\n- WONKS: valid (plural of wonk)\n- WORKS: valid (plural of work)\n- WORMS: valid (plural of worm)\n- WORMY: valid (full of worms)\n\nCould there be a shorter path? Let me think...\nCould I go WINGS → WONGS (I to O)? But is WONGS a valid word? Not sure about that.\nOr WINGS → WRINGS (add R)? No, that would be adding a letter, not changing one.\nOr WINGS → WORMS directly? No, that would require changing multiple letters at once.\n\nThe path I found requires five steps, which makes sense because we needed to change four letters (I→O, N→R, G→M, S→Y) and couldn't combine any changes since we can only change one letter at a time.\n\nSo after exploring alternatives, I believe the path:\nWINGS → WINKS → WONKS → WORKS → WORMS → WORMY\nis optimal, with each step changing exactly one letter and forming a valid English word.\n</think>"}],"stop_reason":"end_turn","stop_sequence":null,"usage":{"input_tokens":58,"cache_creation_input_tokens":0,"cache_read_input_tokens":6434,"output_tokens":962}}}}
{"custom_id":"PINTO_RASPS_3","result":{"type":"succeeded","message":{"id":"msg_01GJQuw2NwcJQpx2civyj8Fu","type":"message","role":"assistant","model":"claude-3-5-sonnet-20241022","content":[{"type":"text","text":"<think>\nAlright, I need to transform PINTO into RASPS by changing one letter at a time. Let me analyze both words first to understand what needs to change.\n\nPINTO\nRASPS\n\nLooking at the letters:\nP -> R\nI -> A\nN -> S\nT -> P\nO -> S\n\nHmm, all five letters need to change. This might be tricky. Let me start by exploring possible changes from PINTO.\n\nFirst letter (P) possibilities:\n- MINTO (is that a word? Maybe a proper noun, but that's not allowed)\n- RINTO (not a word)\n- TINTO (might be Spanish, not valid English)\n\nSecond letter (I) possibilities:\n- PANTO (yes, that's valid!)\n- PENTO (not sure if that's valid)\n- PUNTO (Spanish, not English)\n\nThird letter (N) possibilities:\n- PINTO -> PINTS (changing O to S)\nThat's interesting! PINTS is definitely valid. Maybe that's a good first step since we need to end with 'S' anyway.\n\nLet me try PINTO -> PINTS first.\nNow from PINTS, what can we do? We need to eventually get to RASPS, so maybe changing that I to A would help?\nPINTS -> PANTS (changing I to A)\nYes! PANTS is valid. Now we have:\nPINTO -> PINTS -> PANTS\n\nFrom PANTS, looking at the letters compared to RASPS:\nP -> R\nA (matches!)\nN -> S\nT -> P\nS (matches!)\n\nMaybe we can change that T to R? No, that wouldn't help because we still need P in that position eventually. What about changing N to R? No, we need S there. \n\nLet me try changing T to something... PANTS -> PARTS (changing N to R)\nYes! PARTS is valid. So now:\nPINTO -> PINTS -> PANTS -> PARTS\n\nFrom PARTS to RASPS, we need to:\n- Change P to R at the start\n- Change T to P\n- Keep the S\n\nWait, but if I change P to R now, how would I get the P back later? Maybe I need an intermediate step.\n\nWhat if... PARTS -> WARTS (changing P to W)?\nYes! That's valid. 
Then from WARTS, maybe... WARPS (changing T to P)?\nWARTS -> WARPS works! \n\nSo now I have:\nPINTO -> PINTS -> PANTS -> PARTS -> WARTS -> WARPS\n\nGetting closer! From WARPS to RASPS, I need to:\n- Change W to R\n- Change P to S in the middle\n\nBut wait, if I change W to R directly, would that work? Let me try another approach...\nWARPS -> WASPS (changing R to S)\nYes! That works. Then finally:\nWASPS -> RASPS (changing W to R)\n\nLet me verify the complete path:\n1. PINTO -> PINTS (O to S)\n2. PINTS -> PANTS (I to A)\n3. PANTS -> PARTS (N to R)\n4. PARTS -> WARTS (P to W)\n5. WARTS -> WARPS (T to P)\n6. WARPS -> WASPS (R to S)\n7. WASPS -> RASPS (W to R)\n\nChecking each word:\nPINTO - valid (type of horse/car)\nPINTS - valid (plural of pint)\nPANTS - valid (clothing)\nPARTS - valid (components)\nWARTS - valid (skin growths)\nWARPS - valid (bends/distorts)\nWASPS - valid (insects)\nRASPS - valid (coarse files/harsh sounds)\n\nEach step changes exactly one letter and creates a valid English word. Seven steps in total. Let me double-check if there might be a shorter path...\n\nCould I have gone from PINTS directly to PARTS? No, that would be changing two letters (I to A and N to R).\nCould I have skipped WARTS? If I went PARTS -> WARPS directly, that would change two letters (P to W and T to P).\nCould I have gone from WARPS to RASPS directly? No, that would change two letters (W to R and P to S).\n\nNo, each step seems necessary. Therefore, this seven-step solution appears to be valid and efficient:\nPINTO -> PINTS -> PANTS -> PARTS -> WARTS -> WARPS -> WASPS -> RASPS\n</think>"}],"stop_reason":"end_turn","stop_sequence":null,"usage":{"input_tokens":67,"cache_creation_input_tokens":6434,"cache_read_input_tokens":0,"output_tokens":1081}}}}
{"custom_id":"GRACE_WAINS_4","result":{"type":"succeeded","message":{"id":"msg_01He4e1HXh1DXdxPz2j7K3GZ","type":"message","role":"assistant","model":"claude-3-5-sonnet-20241022","content":[{"type":"text","text":"<think>\nLet me tackle this word ladder puzzle. I need to transform GRACE into WAINS by changing just one letter at a time, making sure each step creates a valid English word.\n\nFirst, let me compare the starting and ending words:\nGRACE\nWAINS\n\nLooking at what needs to change:\nG needs to become W\nR needs to become A\nA needs to become I\nC needs to become N\nE needs to become S\n\nWow, all five letters need to change! This might be tricky. Let me start by looking at possible changes from GRACE.\n\nFrom GRACE, I could change:\n- First letter (G): BRACE, TRACE, SPACE... but I need to get to W eventually\n- Second letter (R): GRACE to GLACE? That's valid\n- Third letter (A): GRACE to GRECE? No, that's not a word\n- Fourth letter (C): GRACE to GRADE, GRAVE, GRAPE... wait, GRAPE is interesting\n\nLet me try GRACE → GRAPE (changing C to P)\nThat works! Now from GRAPE, I need to somehow get closer to WAINS.\n\nLooking at GRAPE vs WAINS:\nG still needs to become W\nR still needs to become A\nA needs to become I\nP needs to become N\nE needs to become S\n\nMaybe I can change the A to I next? GRAPE → GRIPE\nYes! GRIPE is a valid word. That's progress because now the middle letter matches our target word.\n\nFrom GRIPE:\nG still needs to become W\nR still needs to become A\nP needs to become N\nE needs to become S\n\nLooking at GRIPE... what can I change? The final E needs to become S eventually. But wait, looking at the P... if I change P to S, would that work? GRIPE → GRISE? No, that's not common. What about GRIPS? Yes! GRIPE → GRIPS works by changing E to S.\n\nNow I have GRIPS. Getting closer! Looking at what's left:\nG needs to become W\nR needs to become A\nI stays I\nP needs to become N\nS stays S\n\nFrom GRIPS, maybe I can change P to N? 
GRIPS → GRINS\nYes! That works. Now I have GRINS.\n\nLooking at GRINS vs WAINS:\nG needs to become W\nR needs to become A\nI stays I\nN stays N\nS stays S\n\nFrom GRINS, I need to change R to A. GRINS → GAINS\nPerfect! Now I just need to change G to W to reach the target.\n\nGAINS → WAINS\n\nLet me verify the complete path:\n1. GRACE → GRAPE (C→P)\n2. GRAPE → GRIPE (A→I)\n3. GRIPE → GRIPS (E→S)\n4. GRIPS → GRINS (P→N)\n5. GRINS → GAINS (R→A)\n6. GAINS → WAINS (G→W)\n\nLet me check each word:\nGRACE - valid\nGRAPE - valid (fruit)\nGRIPE - valid (complaint)\nGRIPS - valid (holds)\nGRINS - valid (smiles)\nGAINS - valid (increases)\nWAINS - valid (wagons)\n\nAnd each step only changes one letter. This works! Though let me just double-check if there might be a shorter path...\n\nCould I have gone from GRACE directly to WAINS in fewer steps? No, because five letters need to change, so minimum five steps needed. Could I have gone:\nGRACE → GRANS? (Not a word)\nGRACE → GRINS? (No, that would be two letters)\nGRACE → WRACE? (Not a word)\n\nNo, I think the path I found is optimal because:\n1. Each step is a valid word\n2. Each step changes only one letter\n3. We needed to change all five letters\n4. Each change gets us closer to the target\n\nTherefore, the solution path GRACE → GRAPE → GRIPE → GRIPS → GRINS → GAINS → WAINS is correct and efficient.\n</think>"}],"stop_reason":"end_turn","stop_sequence":null,"usage":{"input_tokens":61,"cache_creation_input_tokens":0,"cache_read_input_tokens":6434,"output_tokens":967}}}}

View file

@ -1,10 +0,0 @@
{"question": "Transform the word ladder 'PAINT' to 'LACKS' by changing one letter at a time.", "answer": "PAINT,PAINS,WAINS,WARNS,BARNS,BARKS,BACKS,LACKS", "reasoning": null, "metadata": {"start_word": "PAINT", "end_word": "LACKS", "word_length": 5, "chain_length": 8}}
{"question": "Transform the word ladder 'CONKS' to 'GRAYS' by changing one letter at a time.", "answer": "CONKS,COOKS,COOPS,CROPS,CRAPS,CRAYS,GRAYS", "reasoning": null, "metadata": {"start_word": "CONKS", "end_word": "GRAYS", "word_length": 5, "chain_length": 7}}
{"question": "Transform the word ladder 'WINGS' to 'WORMY' by changing one letter at a time.", "answer": "WINGS,WINKS,WONKS,WORKS,WORMS,WORMY", "reasoning": null, "metadata": {"start_word": "WINGS", "end_word": "WORMY", "word_length": 5, "chain_length": 6}}
{"question": "Transform the word ladder 'PINTO' to 'RASPS' by changing one letter at a time.", "answer": "PINTO,PINTS,PANTS,PARTS,WARTS,WARPS,WASPS,RASPS", "reasoning": null, "metadata": {"start_word": "PINTO", "end_word": "RASPS", "word_length": 5, "chain_length": 8}}
{"question": "Transform the word ladder 'GRACE' to 'WAINS' by changing one letter at a time.", "answer": "GRACE,GRAPE,GRIPE,GRIPS,GRINS,GAINS,WAINS", "reasoning": null, "metadata": {"start_word": "GRACE", "end_word": "WAINS", "word_length": 5, "chain_length": 7}}
{"question": "Transform the word ladder 'RACES' to 'VAPES' by changing one letter at a time.", "answer": "RACES,RAPES,VAPES", "reasoning": null, "metadata": {"start_word": "RACES", "end_word": "VAPES", "word_length": 5, "chain_length": 3}}
{"question": "Transform the word ladder 'LAIRS' to 'FUELS' by changing one letter at a time.", "answer": "LAIRS,FAIRS,FAILS,FARLS,FURLS,FUELS", "reasoning": null, "metadata": {"start_word": "LAIRS", "end_word": "FUELS", "word_length": 5, "chain_length": 6}}
{"question": "Transform the word ladder 'WINKS' to 'TENTH' by changing one letter at a time.", "answer": "WINKS,PINKS,PINTS,TINTS,TENTS,TENTH", "reasoning": null, "metadata": {"start_word": "WINKS", "end_word": "TENTH", "word_length": 5, "chain_length": 6}}
{"question": "Transform the word ladder 'VERBS' to 'DATES' by changing one letter at a time.", "answer": "VERBS,HERBS,HERDS,HARDS,HARES,HATES,DATES", "reasoning": null, "metadata": {"start_word": "VERBS", "end_word": "DATES", "word_length": 5, "chain_length": 7}}
{"question": "Transform the word ladder 'SCENT' to 'HEARS' by changing one letter at a time.", "answer": "SCENT,SCANT,SCART,SCARS,SEARS,HEARS", "reasoning": null, "metadata": {"start_word": "SCENT", "end_word": "HEARS", "word_length": 5, "chain_length": 6}}

View file

@ -1,5 +0,0 @@
# Required dependencies for the Word Ladder Puzzle Dataset Generator
# Python 3.7+
anthropic>=0.45.2 # Client library for interacting with Anthropic's API
tqdm>=4.67.1 # For progress bars

View file

@ -1,288 +0,0 @@
Word Ladder puzzles involve transforming a start word into an end word. You are allowed to change only one letter at a time, and you must keep the number of letters constant.
Each time you change one letter, the resulting word in the chain must be a new word that's valid in English. Plurals are allowed, but not proper nouns.
Your task will be to simulate a detailed step-by-step chain of thought reasoning of the transformation process from a start word to the end word of a word ladder.
You will be given the word ladder question, as well as the correct solution path. So you don't need to solve the problem, you have the solution already, which should be your 'cheat sheet' to refer to as you go along.
Your task is to provide a perfectly simulated chain of thought reasoning exactly in the style of the example below,
including considering multiple possibilities, backtracking where needed, validating words, and showing the final path.
Remember you're NOT trying to 'find' the shortest path necessarily. Make sure you're not lazy and show the user that you're working hard to find the intermediate words in the chain, and that you haven't gotten simply lucky in finding the steps.
For example, you may go down a path that's no longer getting you closer to the solution, so you backtrack and try a different path, which ends up working.
Use casual, free-flowing, verbose, human-like thinking language. You must convince the user that you're really 'REASONING'. You can see the real reasoning in the examples below, which may look convoluted, but it's true reasoning, and that is what we need.
It's ok to pretend to catch yourself making a mistake and then correct yourself; it adds realism to the chain of thought.
VERY IMPORTANT ***NEVER MENTION YOU HAVE BEEN PROVIDED WITH THE ANSWER, YOU NEED TO WRITE ONLY AS IF YOU'RE SOLVING THE QUESTION STEP BY STEP ***
Example chain of thought reasoning:
<think>
Okay, so I need to figure out how to transform the word "CARK" into "ZIPS" by changing one letter at a time, making sure each intermediate step is a valid English word. Let me start by breaking down the problem.
First, both words are four letters long, so each step must maintain four letters. Each change can only alter one letter, either adding, removing, or replacing, but since the length must stay the same, it has to be a replacement each time. So, I need to find a path from CARK to ZIPS where each intermediate word is valid.
Let me write down the starting and target words:
CARK
...
ZIPS
I need to see which letters need to change. Let's compare the letters:
C -> Z
A -> I
R -> P
K -> S
So all four letters need to change. But since each step can only change one letter, I need to find a sequence where each single-letter change results in a valid word. Maybe I can approach this by changing one letter at a time, checking possible words each step.
Starting with CARK. Let's look at possible changes for each letter:
First letter C: changing to other letters to form a word. Possible options could be BARK (C to B), DARK (C to D), LARK (C to L), MARK (C to M), PARK (C to P), etc. Wait, but CARK is a valid word? Let me confirm. CARK means to worry or burden, so yes, it's valid.
So, possible first steps from CARK:
Change C to B: BARK
Change C to D: DARK
Change C to L: LARK
Change C to M: MARK
Change C to P: PARK
Change C to W: WARK (is that a word? Wark might be a Scottish term for work, but I'm not sure if it's accepted in standard English. Maybe not.)
Other possibilities: changing the A, R, or K. Let's see:
Second letter A: changing to other vowels or consonants. For example, CORK (changing A to O), but CORK is valid. Or CIRK (changing A to I; cirk is a variant of circus, but I think "cirk" isn't standard. Maybe CIRC? No, that's not right. Wait, maybe CIRK is a valid word in some contexts, like a circular trench in fortification? Hmm, not sure. Better to stick with more common words.
Third letter R: changing to another letter. For example, CAKE (changing R to K), but that would require changing both the third letter and the fourth letter? Wait, no: CARK to CAKE would be changing R to E and K to E? Wait, no. Wait, CARK is C-A-R-K. To get CAKE, you'd need to change R to K and K to E, which is two changes. So not possible in one step. Maybe changing R to another letter. For example, CASH (changing R to S and K to H? No, that's two letters again. Wait, no. Let's see: CARK. Changing the third letter R to another letter, like CASK (changing R to S). But that would be changing the third letter from R to S, resulting in CASK. Yes, that's a valid word. So CARK -> CASK.
Fourth letter K: changing to another letter. For example, CARD (changing K to D), CARP (changing K to P), CARS (changing K to S), CART (changing K to T). All of these are valid.
So possible first steps from CARK:
- BARK, DARK, LARK, MARK, PARK, CASK, CARD, CARP, CARS, CART.
Now, which of these would lead towards ZIPS? Let's think about the target word. ZIPS has a Z, which is the first letter. So maybe we need to get to a word starting with Z eventually. But changing from C to Z in one step would be a big jump. Alternatively, maybe we can change other letters first to reach intermediate steps that can approach Z.
Alternatively, maybe work on changing the last letters first. Let's see. For example, changing K to S (CARS) might be useful since the target ends with S. Let's try that.
CARK -> CARS (changing K to S). Now, from CARS, we need to change the third letter R to P and the first letter C to Z, and the second A to I. Let's see.
From CARS, possible changes:
Change C to Z: ZARS (is that a word? ZAR is a currency, plural ZARS? Maybe, but not sure. Alternatively, ZAPS (changing R to P and S to something? Wait, no. Let's check each letter.
Alternatively, from CARS, change A to I: CIRS (is that a word? CIR is a three-letter word, but CIRS? Maybe not. Or CURS (changing A to U), which is valid. But how does that help?
Alternatively, from CARS, change R to P: CAPS. Yes, that's valid. So CARS -> CAPS (changing R to P). Then from CAPS, we can change C to Z: ZAPS. Yes, ZAPS is a word. Then ZAPS -> ZIPS by changing A to I. So that would be a path:
CARK -> CARS -> CAPS -> ZAPS -> ZIPS.
Wait, let's check each step:
1. CARK to CARS: change K to S. Valid.
2. CARS to CAPS: change R to P. Valid.
3. CAPS to ZAPS: change C to Z. Valid.
4. ZAPS to ZIPS: change A to I. Valid.
That's four steps. But wait, starting from CARK (step 0), then steps 1-4 would be four changes. But the problem says "transform CARK into ZIPS by changing one letter at a time", so the number of steps would be the number of letters changed. Since all four letters are different, at minimum it would take four steps. But perhaps the path is longer if there are no shorter paths.
Wait, let me check if this works:
CARK → CARS → CAPS → ZAPS → ZIPS.
Yes, that's four steps. Each step changes one letter, and all intermediate words are valid. Let me verify each word:
- CARK: valid.
- CARS: valid.
- CAPS: valid.
- ZAPS: valid.
- ZIPS: valid.
Yes, that seems to work. But let me check if there's a shorter path or if this is correct.
Alternatively, is there a three-step path? Probably not since four letters need to change. Each step can only change one letter, so at least four steps are needed. So the path above is four steps, which is the minimum. Let me confirm each step again.
Another possible path: changing the first letter earlier. For example:
CARK → PARK (change C to P). Then PARK → PARKS (but that's five letters; can't do that. Alternatively, changing K to S: PARS. Then PARS -> PAWS (changing R to W?), but not helpful. Or PARK → PARS (change K to S), then PARS → PAIS (change R to I?), but PAIS is valid. Then PAIS → ZIPS (change P to Z and A to I? No, that's two letters. Hmm, not helpful.
Alternatively, CARK → DARK → DANK → DINK → ZINK → ZIPS? Let's see. CARK to DARK (C→D). DARK to DANK (R→N). DANK to DINK (A→I). DINK to ZINK (D→Z). ZINK is a variant spelling of zinc, but ZINK is a valid word. Then ZINK to ZIPS (K→S and N→P? Wait, no. ZINK is Z-I-N-K. To get to ZIPS, we need Z-I-P-S. So need to change N to P and K to S. That's two steps. So that path would be longer.
Another path: CARK → CARD → ZARD? ZARD isn't a word. CARD → WARD → WARD → ... not helpful.
Alternatively, CARK → CARK to CIRK (A→I). CIRK is a variant of circus, but maybe not standard. Then CIRK → ZIRK (C→Z). ZIRK is a surname or a brand, not sure if it's a valid word. Then ZIRK → ZIPS (R→P and K→S). But again, two letters at once. Not helpful.
So the first path I thought of seems viable: CARK → CARS → CAPS → ZAPS → ZIPS.
Another check: From CARK, can we go to CASK? CARK → CASK (R→S). Then CASK → ZASK (C→Z). ZASK? Not a word. Alternatively, CASK → CASK → CASH (K→H). Then CASH → ZASH (C→Z). Not a word. Not helpful.
Alternatively, CARK → CARP (K→P). CARP is valid. Then CARP → CARP → ZARP (C→Z). ZARP? Not a word. Or CARP → HARp (C→H), but HARP is valid. Then HARP → HIPS (A→I and R→P). But that's two changes. Not allowed.
So the initial path seems best. Let me confirm each step again:
1. CARK → CARS: change K to S. Valid.
2. CARS → CAPS: change R to P. Valid.
3. CAPS → ZAPS: change C to Z. Valid.
4. ZAPS → ZIPS: change A to I. Valid.
Yes, that's four steps, each changing one letter, all words are valid. So that should be the correct sequence.
</think>
alternative chain of thought:
<think> Okay, so I need to figure out how to transform the word "BALL" into "POMP" by changing one letter at a time, with each intermediate step being a valid English word. Let me start by writing down both words to see the letters that need to change.
BALL → POMP
First, let's look at each position:
B (1st letter) needs to become P.
A (2nd letter) needs to become O.
L (3rd letter) needs to become M.
L (4th letter) needs to become P.
So each letter position needs to change. The challenge is to do this one letter at a time, and each step must result in a valid word. Let me think about possible steps.
Starting with BALL. Let's see what words can be formed by changing one letter.
For the first letter (B), changing it to other letters might give words like CALL, FALL, HALL, MALL, PALL, WALL. Those are all valid. Alternatively, changing the second letter (A) could lead to BELL, BILL, BOLL, BULL. Changing the third letter (L) might result in BALM, BALE, BALK, etc. Changing the fourth letter (L) could give BALD, BALE, BALM, etc.
Wait, but the target word is POMP. So maybe we need to adjust letters step by step. Let's try to approach this step by step.
Option 1: Change the first letter from B to P. That would give PALL. Is PALL a valid word? Yes. So step 1: BALL → PALL.
Now, PALL. Next, maybe change the second letter A to O. That would make POLL. POLL is a valid word. Step 2: PALL → POLL.
Now, POLL. Next, change the third letter L to M. That would make POMP? Wait, no. Because POLL has two Ls. If we change the third L to M, it would be POML, which isn't a word. Alternatively, change the fourth L to P? That would be POLP, which isn't a word. Hmm. Maybe this path isn't working.
Alternative approach. Let's try changing a different letter first. Let's see. Starting with BALL.
Another possibility: Change the second letter A to O first. Then we get BOLL. BOLL is a valid word (a type of seed capsule). Step 1: BALL → BOLL.
From BOLL, maybe change the first letter B to P. That gives POLL. Step 2: BOLL → POLL.
Now, POLL again. Then, similar problem as before. Need to get to M in the third position. Maybe change the third L to M. But that gives POML, which isn't valid. Alternatively, change the fourth L to P. POLP isn't a word. Hmm.
Alternative path: From BOLL, maybe change the third letter L to M. That would make BOLM. Is BOLM a word? I don't think so. Or change the fourth L to P: BOLP. Not a word.
Maybe another starting step. Let's try changing the third letter in BALL from L to M. So BALL → BALM. Step 1: BALL → BALM.
BALM is valid. Now, change the fourth letter M to P? Wait, BALM is B-A-L-M. Wait, no. Wait, original word is BALL (B-A-L-L). Changing the third L to M would make BALM. Then changing the fourth L (from original) to something else. Wait, after BALM, the fourth letter is M? Wait, no. Wait, BALL is B-A-L-L. If I change the third L to M, the word becomes BALM (B-A-L-M). Wait, but then the fourth letter is M, so BALM. Then the fourth letter is M. To get to POMP, which ends with P, maybe we can change that M to P. But BALM → BALP? BALP isn't a valid word. Hmm.
Alternatively, maybe change the fourth letter first. BALL → BALD (changing the last L to D). Then BALD. But I'm not sure if that helps. Let's see. BALD. Then maybe change A to O: BOLD. Then BOLD. Then change B to P: POLD. Not a word. Hmm.
Alternatively, BALL → BALE (changing the last L to E). BALE is valid. Then BALE → PALE (changing B to P). Then PALE → POLE (changing A to O). Then POLE → POLL (changing E to L). Then POLL → POMP (changing L to M and L to P? Wait, but that's two changes at once. Not allowed. So this path might not work.
Wait, perhaps another approach. Let's consider the target word POMP. The letters are P-O-M-P. So after changing some letters, maybe we can reach a word that has some of these letters, then proceed.
Let me try starting with BALL. Change the first letter B to P: PALL. Then change the second letter A to O: POLL. Then change the third letter L to M: POMP? Wait, no. POLL → changing third L to M would make POML. Not valid. But if we can change the third and fourth letters? Wait, but we can only change one letter each step.
Alternatively, from POLL, change the fourth L to P. That would make POLP, which isn't a word. Not helpful.
Hmm. Maybe a different path. Let's think of possible intermediate words that can connect these.
Another idea: BALL → BOLT (changing A to O and L to T? No, that's two changes. Not allowed. Wait, no. To get BOLT from BALL, you need to change the second letter A to O (making BOLL) and then the fourth L to T. But that's two steps. Alternatively, maybe BALL → BOLT by changing A to O and third L to T? No, that's two changes again.
Alternatively, BALL → BULK (changing A to U and third L to K. Again two changes. Not allowed.
Wait, maybe changing the third letter first. BALL → BALE (changing third L to E). Then BALE → BOLT (changing A to O and E to T. Two changes. Not allowed. Hmm.
Alternatively, BALL → BELL (changing A to E). Then BELL → BELL is same, then maybe BELL → BOLT (changing E to O and L to T. Again two steps.
This is tricky. Let me try another angle. Let's look for words that are one letter away from POMP. For example, possible previous steps before POMP could be:
POOP (changing M to O), but POOP is valid. Then from POOP to POMP by changing O to M. But that's a possible step. But how to get to POOP from previous steps.
Alternatively, POMP could come from changing a letter in a word like POOP, POMY (not a word), POPS (changing S to M? No), etc. Maybe not helpful.
Alternatively, the step before POMP could be something like POMM (if that's a word; not sure). Probably not.
Alternatively, consider that the third letter needs to be M. So maybe a word like POM_ something. For example, POMP. So maybe from a word like POME (a type of fruit), changing E to P. But POME to POMP would require changing E to P. But POME is a valid word. So step: POME → POMP. Then, how to get to POME.
But how to get from BALL to POME. Let's see. Let's try:
BALL → BALE (change third L to E)
BALE → PALE (change B to P)
PALE → PALE (same)
PALE → POLE (change A to O)
POLE → POLL (change E to L)
POLL → POLL (same)
POLL → POMP (change L to M and L to P. Wait, that's two letters. Not allowed.
Alternatively, maybe a different path. Let's think of words with M in the third position. For example, BALM (which we considered earlier). BALM → BAMM (not a word). Hmm.
Wait, BALM is B-A-L-M. If we change the first letter B to P, we get PALM. PALM is a valid word. So BALL → BALM (changing third L to M), then BALM → PALM (changing B to P). Then PALM. Now, from PALM, maybe change the fourth M to P. That gives PALP. PALP is a valid word (to palpate). So PALM → PALP. Then PALP → POMP. How? Change A to O and L to M. But that's two letters. Alternatively, change the second letter A to O: POLP (not a word). Hmm.
Alternatively, from PALM, change the second letter A to O: POLM. Not a word. Not helpful.
Hmm. Maybe another path. Let's think of changing the second letter first. BALL → BOLL (A to O). Then BOLL → BOLT (changing L to T). Then BOLT → BOLT (same). Then BOLT → MOAT (changing B to M and L to A. Two changes. Not allowed.
Alternatively, BOLL → POLL (B to P). Then POLL → POLE (L to E). Then POLE → POME (E to M and L to E? No, that's two changes. Wait, POLE to POME would require changing L to M and E to E. Wait, no. POLE is P-O-L-E. To get to POME, you need to change L to M and E to E. Wait, no. Changing L to M would make POME? P-O-M-E. Yes, POME is a word. So POLE → POME (changing L to M). Then POME → POMP (changing E to P). So that's possible.
So let's outline this path:
BALL → BOLL (change A to O)
BOLL → POLL (change B to P)
POLL → POLE (change L to E)
POLE → POME (change L to M)
POME → POMP (change E to P)
Wait, but step 3: POLL to POLE. POLL is P-O-L-L. Changing the fourth L to E gives POLE? Wait, no. If you change the fourth L (the second L) to E, then it would be POLE. Yes. So POLL → POLE (changing the fourth letter L to E). Then POLE → POME (changing the third letter L to M). Wait, POLE is P-O-L-E. Changing the third letter L to M gives P-O-M-E, which is POME. Then POME → POMP by changing the fourth letter E to P. That works.
So the steps would be:
BALL → BOLL
BOLL → POLL
POLL → POLE
POLE → POME
POME → POMP
But wait, step 3: POLL to POLE. Is that valid? Yes, because changing the fourth L to E gives POLE. Then step 4: POLE to POME by changing the third L (but in POLE, the third letter is L? Wait, POLE is P-O-L-E. So third letter is L. Changing that to M gives P-O-M-E, which is POME. Then POME to POMP by changing the fourth E to P. That works. So that's five steps. But the problem says "change one letter at a time" without specifying the minimum number of steps, so as long as each step is valid, it's okay.
But let's check if all these words are valid:
BOLL: Yes, a type of seed capsule.
POLL: Yes.
POLE: Yes.
POME: Yes, a fruit like an apple.
POMP: Target.
So that's a valid sequence. But wait, the original problem is to go from BALL to POMP. Let me count the steps. BALL (1) → BOLL (2) → POLL (3) → POLE (4) → POME (5) → POMP (6). Wait, that's five steps, but the count is six words including the start and end. Wait, the problem says "transform ... by changing one letter at a time. Each step must create a valid English word". So the sequence should be BALL → word1 → word2 → ... → POMP. So in this case, five changes (since each change is a step). But let me check each step again.
BALL → BOLL (change A to O)
BOLL → POLL (change B to P)
POLL → POLE (change fourth L to E)
POLE → POME (change third L to M)
POME → POMP (change fourth E to P)
Yes, that's five steps. Each step changes one letter, and each intermediate word is valid. So this works.
Alternatively, is there a shorter path? Let's see. Let's try another path.
Starting with BALL → PALL (B to P). Then PALL → POLL (A to O). Then POLL → POMP. But how? POLL to POMP requires changing L to M and L to P. That's two changes. Not allowed. So that path would require two steps from POLL: first change third L to M (making POML, invalid), then fourth L to P (POLP, invalid). So that doesn't work.
Alternatively, another path: BALL → BALE (third L to E). BALE → PALE (B to P). PALE → POLE (A to O). POLE → POME (L to M). POME → POMP (E to P). Same as the previous path but starting differently. So that's also five steps.
Alternatively, BALL → BALM (third L to M). BALM → PALM (B to P). PALM → PALP (M to P). PALP → POMP (A to O and L to M). But changing two letters at once isn't allowed. So that doesn't work.
So the path through BOLL, POLL, POLE, POME seems valid. Let me confirm each word:
BOLL: valid.
POLL: valid.
POLE: valid.
POME: valid.
POMP: target.
Yes, this works. Therefore, the sequence is BALL → BOLL → POLL → POLE → POME → POMP.
</think>
Finally, I've seen in your previous answers to similar questions that you sometimes get complacent and deviate from the hard work required. You must maintain the style and level of effort throughout the entirety of your answer until the end.
Now, the user will provide you with the word ladder question and the correct solution path. Provide the user with a very verbose, first person chain of thought reasoning that is simulating how a human would methodically think to solve the word ladder step by step.

View file

@ -1,254 +0,0 @@
import datetime
import json
import os
import time
from pathlib import Path
from types import SimpleNamespace
import pytest
from examples.word_ladder import generate_reasoning
# We alias the functions and globals for easier usage in our tests.
submit_reasoning_batches = generate_reasoning.submit_reasoning_batches
_submit_single_batch = generate_reasoning._submit_single_batch
DEFAULT_INPUT_JSONL = generate_reasoning.DEFAULT_INPUT_JSONL
COMMON_UUID = generate_reasoning.COMMON_UUID
BATCH_SIZE = generate_reasoning.BATCH_SIZE
client = generate_reasoning.client
# Define a mock batch response class mimicking Anthropic's API response.
class MockBatchResponse:
    """Minimal stand-in for an Anthropic Message Batches API response.

    Only the attributes the code under test reads are modelled. The `fail`
    parameter is kept for interface compatibility but is unused here.
    """

    def __init__(self, batch_id="msgbatch_mock", processing_status="in_progress", fail=False):
        self.id = batch_id
        self.type = "message_batch"
        self.processing_status = processing_status
        # request_counts must support attribute access like the real response object.
        self.request_counts = SimpleNamespace(processing=0, succeeded=0, errored=0, canceled=0, expired=0)
        self.ended_at = None
        # Timezone-aware datetimes so isoformat() is available and carries an offset.
        # datetime.utcnow() is deprecated (Python 3.12+); now(timezone.utc) is the
        # modern, equivalent way to get an aware UTC timestamp.
        self.created_at = datetime.datetime.now(datetime.timezone.utc)
        self.expires_at = self.created_at + datetime.timedelta(seconds=86400)
        self.cancel_initiated_at = None
        self.results_url = None
# Helper: Create a temporary system prompt file.
@pytest.fixture
def system_prompt_file(tmp_path, monkeypatch):
    """Write a throwaway system prompt file and point the module at it."""
    prompt_path = tmp_path / "system_prompt.txt"
    prompt_path.write_text("This is a system prompt.", encoding="utf-8")
    # Redirect the module-level constant to the temporary file.
    monkeypatch.setattr(generate_reasoning, "DEFAULT_SYSTEM_PROMPT", str(prompt_path))
    return prompt_path
# Helper: Create necessary directories using a temporary location.
@pytest.fixture
def setup_directories(tmp_path, monkeypatch):
    """Create a temporary output directory and run the test from tmp_path."""
    out_dir = tmp_path / "output"
    out_dir.mkdir(exist_ok=True)
    # Point the module's output location (a Path) at the temporary directory.
    monkeypatch.setattr(generate_reasoning, "DEFAULT_OUTPUT_DIR", out_dir)
    # chdir so any relative paths created by the code under test stay sandboxed.
    monkeypatch.chdir(tmp_path)
    return out_dir
# Helper: Create a temporary input JSONL file with given entries.
@pytest.fixture
def input_jsonl_file(tmp_path, setup_directories, monkeypatch):
    """Write a three-entry word-ladder JSONL input file (one entry pre-reasoned)."""

    def _entry(question, answer, reasoning, start, end):
        # One JSONL row in the shape produced by the dataset generator.
        return {
            "question": question,
            "answer": answer,
            "reasoning": reasoning,
            "metadata": {"start_word": start, "end_word": end, "word_length": 1, "chain_length": 3},
        }

    rows = [
        _entry("Transform 'A' to 'B'", "A,X,B", None, "A", "B"),
        _entry("Transform 'C' to 'D'", "C,Y,D", "Some reasoning", "C", "D"),
        _entry("Transform 'E' to 'F'", "E,Z,F", None, "E", "F"),
    ]
    jsonl_path = setup_directories / "word_ladder_examples.jsonl"
    jsonl_path.write_text("".join(json.dumps(row) + "\n" for row in rows), encoding="utf-8")
    # Point the module at our temporary input file.
    monkeypatch.setattr(generate_reasoning, "DEFAULT_INPUT_JSONL", str(jsonl_path))
    return jsonl_path
# Test that submit_reasoning_batches builds a batch skipping entries with existing reasoning.
def test_submit_batches_success(system_prompt_file, input_jsonl_file, setup_directories, monkeypatch):
    """submit_reasoning_batches builds one batch, skipping entries that already have reasoning."""

    def _field(obj, name, default=None):
        # Batch requests (and their params) may be plain dicts or attribute
        # objects depending on the client version; read either uniformly.
        # This replaces the previous verbatim duplication of the dict/object
        # extraction logic across both branches.
        if isinstance(obj, dict):
            return obj.get(name, default)
        return getattr(obj, name)

    def fake_create(requests):
        for req in requests:
            params = _field(req, "params", {})
            custom_id = _field(req, "custom_id")
            assert _field(params, "model") == "claude-3-5-sonnet-20241022", "Incorrect model version"
            assert _field(params, "temperature") == 0.5, "Incorrect temperature"
            # The C->D entry already has reasoning and must have been skipped.
            assert "C_D_" not in custom_id
        return MockBatchResponse(batch_id="msgbatch_test_success")

    monkeypatch.setattr(client.messages.batches, "create", fake_create)
    batch_metadata_prefix = "test_metadata"
    submit_reasoning_batches(input_path=str(input_jsonl_file), batch_metadata_prefix=batch_metadata_prefix)

    metadata_filename = f"{COMMON_UUID}_{batch_metadata_prefix}.jsonl"
    meta_file_path = setup_directories / metadata_filename
    assert meta_file_path.exists(), "Metadata file was not created as expected."
    with meta_file_path.open("r", encoding="utf-8") as f:
        lines = f.readlines()
    # Expecting only those entries that did not already have a reasoning.
    # (From our test input, 2 out of 3 entries qualify.)
    assert len(lines) > 0
    for line in lines:
        metadata = json.loads(line)
        api_response = metadata["api_response"]
        assert api_response["id"] == "msgbatch_test_success"
        assert api_response["processing_status"] == "in_progress"
        custom_ids = metadata["custom_ids"]
        assert len(custom_ids) == 2
# Test that _submit_single_batch retries once and eventually succeeds.
def test_retry_logic(system_prompt_file, setup_directories, monkeypatch):
    """_submit_single_batch retries after one create() failure and then succeeds."""
    attempts = {"count": 0}

    def flaky_create(requests):
        # Fail exactly once, then hand back a successful mock response.
        if attempts["count"] == 0:
            attempts["count"] += 1
            raise Exception("Temporary failure")
        return MockBatchResponse(batch_id="msgbatch_retry_success")

    monkeypatch.setattr(client.messages.batches, "create", flaky_create)
    request = type("DummyRequest", (), {"custom_id": "dummy_1"})()
    _submit_single_batch([request], ["dummy_1"], 0, "test_retry", "dummy_input.jsonl")

    meta_file_path = setup_directories / f"{COMMON_UUID}_test_retry.jsonl"
    assert meta_file_path.exists(), "Retry metadata file was not created."
    metadata = json.loads(meta_file_path.read_text(encoding="utf-8"))
    assert metadata["api_response"]["id"] == "msgbatch_retry_success"
    assert attempts["count"] == 1
# Test that when all attempts to submit a batch fail, the error is logged to the failed file.
def test_failed_batch(system_prompt_file, setup_directories, monkeypatch):
    """A batch whose submission always fails is recorded in the failed-batches log."""

    def always_fail(requests):
        raise Exception("Permanent failure")

    monkeypatch.setattr(client.messages.batches, "create", always_fail)
    request = type("DummyRequest", (), {"custom_id": "dummy_fail"})()
    _submit_single_batch([request], ["dummy_fail"], 0, "test_failed", "dummy_input.jsonl")

    error_file_path = setup_directories / f"{COMMON_UUID}_failed_batches.jsonl"
    assert error_file_path.exists(), "Failed batch log file was not created."
    with error_file_path.open("r", encoding="utf-8") as f:
        error_entry = json.loads(f.readline())
    assert error_entry["batch_number"] == 0
    assert "Permanent failure" in error_entry["error"]
    assert error_entry["batch_requests"] == ["dummy_fail"]
# Test batching behavior when multiple batches are needed.
def test_multiple_batches(system_prompt_file, setup_directories, monkeypatch):
    """Five qualifying entries with BATCH_SIZE=2 should be split into three batches."""
    monkeypatch.setattr(generate_reasoning, "BATCH_SIZE", 2)

    # Build the input file with five entries that all lack reasoning.
    input_file = setup_directories / "word_ladder_examples.jsonl"
    rows = []
    for idx in range(5):
        rows.append(
            {
                "question": f"Transform word ladder {idx}",
                "answer": f"start,mid,end_{idx}",
                "reasoning": None,
                "metadata": {"start_word": f"start_{idx}", "end_word": f"end_{idx}"},
            }
        )
    input_file.write_text("".join(json.dumps(row) + "\n" for row in rows), encoding="utf-8")
    monkeypatch.setattr(generate_reasoning, "DEFAULT_INPUT_JSONL", str(input_file))

    batch_ids = []

    def fake_create(requests):
        # Hand out a distinct id per submitted batch so we can count them.
        batch_id = f"msgbatch_batch_{len(batch_ids)}"
        batch_ids.append(batch_id)
        return MockBatchResponse(batch_id=batch_id)

    monkeypatch.setattr(client.messages.batches, "create", fake_create)
    submit_reasoning_batches(input_path=str(input_file), batch_metadata_prefix="test_multi")

    meta_file_path = setup_directories / f"{COMMON_UUID}_test_multi.jsonl"
    assert meta_file_path.exists(), "Multiple batch metadata file was not created."
    with meta_file_path.open("r", encoding="utf-8") as f:
        metadata_lines = f.readlines()
    # With 5 qualifying entries and a batch size of 2 we expect 3 batches.
    assert len(metadata_lines) == 3
    seen_custom_ids = []
    for line in metadata_lines:
        metadata = json.loads(line)
        seen_custom_ids.extend(metadata["custom_ids"])
        assert metadata["api_response"]["id"] in batch_ids
    assert len(seen_custom_ids) == 5

View file

@ -1,133 +0,0 @@
"""
create_word_ladders.py Generates the word ladder dataset as a JSONL file.
Each line is a JSON object with keys: question, answer, metadata, and reasoning (set to None).
"""
import json
import uuid
from pathlib import Path
from tqdm import tqdm
import reasoning_gym
def check_duplicates(jsonl_path: str) -> tuple[bool, dict]:
    """
    Check for duplicate word pairs in a word ladder JSONL file.

    Args:
        jsonl_path: Path to a JSONL file where each line is a JSON object whose
            "metadata" holds "start_word" and "end_word".

    Returns:
        tuple[bool, dict]: (has_duplicates, valid_entries) where:
            - has_duplicates: True if any duplicates were found
            - valid_entries: Dict mapping line_number -> data for non-duplicate entries

    Note: A pair is considered duplicate if either (word1, word2) or (word2, word1)
    already exists, since word ladder paths are bidirectional.
    """
    # A frozenset compares equal regardless of word order, so a single
    # membership test covers both (start, end) and (end, start). The previous
    # version stored (line_num, data) values in a dict that were never read.
    seen_pairs: set[frozenset] = set()
    valid_entries: dict = {}
    duplicates_found = False
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for line_num, line in enumerate(f):
            data = json.loads(line)
            metadata = data["metadata"]
            pair = frozenset((metadata["start_word"], metadata["end_word"]))
            if pair in seen_pairs:
                # Skip this entry - it's a duplicate (in either orientation).
                duplicates_found = True
                continue
            seen_pairs.add(pair)
            valid_entries[line_num] = data
    return duplicates_found, valid_entries
def _ladder_row(item: dict) -> dict:
    """Convert one reasoning-gym sample into the JSONL row schema."""
    return {
        "question": item["question"],
        "answer": item["answer"],
        "reasoning": None,
        "metadata": item.get("metadata", {}),
    }


def create_word_ladder_dataset(jsonl_path: str = None, config: dict = None) -> None:
    """
    Creates a word ladder dataset and writes each sample as a JSON line.
    Ensures no duplicate word pairs by regenerating as needed.

    The file left on disk is always duplicate-free, even when the target
    size could not be reached within the attempt limit. The caller's
    ``config`` dict is never mutated.

    Args:
        jsonl_path: Output path; defaults to a uuid-suffixed file in the
            parent folder's ``output`` directory.
        config: Must contain "dataset_name" and "dataset_config" (with a
            "size" key) for reasoning_gym.create_dataset.

    Raises:
        ValueError: If config is not provided.
    """
    if config is None:
        raise ValueError("Configuration (config) must be provided.")
    # Create output directory if it doesn't exist.
    # Updated path to point to the parent folder's output.
    output_dir = Path(__file__).resolve().parent.parent / "output"
    output_dir.mkdir(exist_ok=True)
    # Determine the file path based on uuid when not provided.
    if jsonl_path is None:
        unique_id = uuid.uuid4().hex[:8]
        jsonl_path = output_dir / f"word_ladders_{unique_id}.jsonl"
    else:
        jsonl_path = Path(jsonl_path)

    # Work on a copy so resizing for regeneration never mutates the
    # caller's config in place (the original code clobbered "size").
    dataset_config = dict(config["dataset_config"])
    target_size = dataset_config["size"]
    max_attempts = 3  # Limit total regeneration attempts

    # Initial generation.
    dataset = reasoning_gym.create_dataset(config["dataset_name"], **dataset_config)
    with open(jsonl_path, "w", encoding="utf-8") as f:
        for item in tqdm(dataset, desc="Generating initial ladder examples"):
            f.write(json.dumps(_ladder_row(item)) + "\n")

    for attempt in range(max_attempts):
        # Check entire file for duplicates.
        has_duplicates, valid_entries = check_duplicates(jsonl_path)
        current_size = len(valid_entries)
        if not has_duplicates and current_size == target_size:
            print(f"\nSuccessfully created dataset with {current_size} unique examples.")
            return
        # If we have duplicates or not enough entries, regenerate the missing amount.
        needed = target_size - current_size
        if needed > 0:
            print(f"\nAttempt {attempt + 1}: Regenerating {needed} examples to replace duplicates/missing entries...")
            dataset_config["size"] = needed
            additional_dataset = reasoning_gym.create_dataset(config["dataset_name"], **dataset_config)
            # Write valid entries plus the fresh batch to a temporary file,
            # then atomically replace the original.
            temp_path = jsonl_path.with_suffix(".tmp")
            with open(temp_path, "w", encoding="utf-8") as f:
                for data in valid_entries.values():
                    f.write(json.dumps(data) + "\n")
                for item in additional_dataset:
                    f.write(json.dumps(_ladder_row(item)) + "\n")
            temp_path.replace(jsonl_path)

    # Final pass: the last regeneration above was never re-checked in the
    # original code, so the file could be left containing duplicates.
    # Strip any that remain before reporting the outcome.
    has_duplicates, valid_entries = check_duplicates(jsonl_path)
    current_size = len(valid_entries)
    if has_duplicates:
        temp_path = jsonl_path.with_suffix(".tmp")
        with open(temp_path, "w", encoding="utf-8") as f:
            for data in valid_entries.values():
                f.write(json.dumps(data) + "\n")
        temp_path.replace(jsonl_path)
    if current_size < target_size:
        print(f"\nWarning: Could only generate {current_size} unique examples after {max_attempts} attempts.")
    else:
        print(f"\nSuccessfully created dataset with {current_size} unique examples.")

View file

@ -1,214 +0,0 @@
"""
generate_reasoning.py Reads the JSONL file containing ladder examples,
creates batch requests of chain-of-thought prompts split into batches of 2,500,
calls Anthropic's Message Batches API for each batch, and writes separate batch metadata
files for later retrieval of the responses.
*** WARNING ***: Running large batches of requests via the Anthropic API (especially in generate_reasoning.py)
can incur significant costs in Anthropic credits. Please review and understand your API quota and budgeting
before running the API call. If you are testing or working with a demo dataset, adjust the batch size or dataset
size appropriately to avoid unexpected charges.
Using Anthropic's Message Batches API with caching enabled for system prompt.
In our informal testing, Sonnet was deemed best performance value.
You can swap out to another API, but this will need a rewrite to remove anthropic-specific code.
"""
import json
import os
import time
import uuid
from pathlib import Path
import anthropic
from anthropic.types.message_create_params import MessageCreateParamsNonStreaming
from anthropic.types.messages.batch_create_params import Request
from tqdm import tqdm
# Updated default output directory to use the parent directory.
DEFAULT_OUTPUT_DIR = Path(__file__).resolve().parent.parent / "output"
# Add default constants at the top with other constants
DEFAULT_INPUT_JSONL = "output/word_ladder_examples.jsonl"
DEFAULT_SYSTEM_PROMPT = Path(__file__).resolve().parent.parent / "system_prompt.txt"
# Maximum requests per submitted batch; also bounds per-batch API spend.
BATCH_SIZE = 2500
# One shared run id so all metadata/error files from a single run group together.
COMMON_UUID = uuid.uuid4().hex[:8]
# Set up the Anthropic client (ensure the API key is set in the environment)
# NOTE(review): raises KeyError at import time when ANTHROPIC_API_KEY is unset.
client = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])
def submit_reasoning_batches(
    input_path: str = DEFAULT_INPUT_JSONL,
    batch_metadata_prefix: str = "batch_metadata",
    system_prompt_path: str = DEFAULT_SYSTEM_PROMPT,
) -> None:
    """
    Reads the input JSONL file of word ladder examples, builds batch requests for any example that
    does not have reasoning, splits them into groups of BATCH_SIZE, and submits each batch using
    Anthropic's Message Batches API.

    Args:
        input_path: Path to input JSONL file
        batch_metadata_prefix: Prefix for batch metadata files
        system_prompt_path: Path to system prompt file
    """
    # Create output directory if it doesn't exist
    output_dir = DEFAULT_OUTPUT_DIR
    output_dir.mkdir(exist_ok=True)
    # Read the system prompt from file (used as a preamble for every request)
    with open(system_prompt_path, "r", encoding="utf-8") as sys_file:
        system_message = [
            {
                "type": "text",
                "text": sys_file.read(),
                "cache_control": {"type": "ephemeral"},  # Enable anthropic prompt caching
            }
        ]
    batch_requests = []
    custom_ids = []  # List of custom_ids for the current batch
    batch_num = 0
    # Get the total number of lines in advance for tqdm progress bar.
    # Fix: use a context manager so the counting file handle is closed
    # instead of leaked (the original used a bare open() in a genexp).
    with open(input_path, "r", encoding="utf-8") as count_file:
        total_lines = sum(1 for _ in count_file)
    with open(input_path, "r", encoding="utf-8") as infile:
        for idx, line in tqdm(enumerate(infile), desc="Preparing batch requests", total=total_lines):
            data = json.loads(line)
            # Skip example if 'reasoning' already exists.
            if not data.get("reasoning"):
                # Build a custom id. Here we use the row position and the start/end words:
                metadata = data.get("metadata", {})
                start = metadata.get("start_word", "unknown")
                end = metadata.get("end_word", "unknown")
                custom_id = f"{start}_{end}_{idx}"
                custom_ids.append(custom_id)
                # Build the prompt text exactly as before.
                prompt = f"{data['question']}. The correct solution is {data['answer']}. "
                # Build the request payload using Request and MessageCreateParamsNonStreaming.
                request_payload = Request(
                    custom_id=custom_id,
                    params=MessageCreateParamsNonStreaming(
                        model="claude-3-5-sonnet-20241022",  # Or choose the appropriate model version
                        max_tokens=8192,
                        temperature=0.5,
                        system=system_message,
                        messages=[{"role": "user", "content": prompt}],
                    ),
                )
                # Instead of wrapping in SimpleNamespace, simply ensure custom_id is set.
                if isinstance(request_payload, dict):
                    request_payload["custom_id"] = custom_id
                else:
                    request_payload.custom_id = custom_id
                batch_requests.append(request_payload)
                # If we have reached our batch size limit, submit the current batch.
                if len(batch_requests) >= BATCH_SIZE:
                    _submit_single_batch(batch_requests, custom_ids, batch_num, batch_metadata_prefix, input_path)
                    batch_num += 1
                    # Reset for the next batch
                    batch_requests = []
                    custom_ids = []
    # Submit any remaining requests that didn't complete a full batch.
    if batch_requests:
        _submit_single_batch(batch_requests, custom_ids, batch_num, batch_metadata_prefix, input_path)
def _submit_single_batch(batch_requests, custom_ids, batch_num, batch_metadata_prefix, input_path):
    """
    Helper function to submit a single batch request, track the full API response,
    and write out the corresponding metadata including the list of custom_ids.

    Retries the submission once on any exception; a batch that still fails is
    appended to ``<run-id>_failed_batches.jsonl`` (with its custom_ids) instead
    of raising, so a run can continue past a bad batch.

    Args:
        batch_requests: Request payloads for this batch (objects or dicts).
        custom_ids: custom_id of every request in this batch, in order.
        batch_num: Zero-based batch index, used only in log/error output.
        batch_metadata_prefix: Prefix of the shared batch-metadata file.
        input_path: Source JSONL path, recorded in the metadata.
    """
    # Use the default output directory
    output_dir = DEFAULT_OUTPUT_DIR
    output_dir.mkdir(exist_ok=True)

    def serialize_datetime(dt):
        """
        Convert a datetime object to ISO formatted string.
        If dt is None, returns None.
        """
        if dt is None:
            return None
        iso_str = dt.isoformat()  # e.g. "2024-08-20T18:37:24.100435+00:00"
        # Prefer the compact "Z" suffix for aware UTC timestamps; for non-UTC
        # offsets the replace() is a no-op.
        if dt.tzinfo is not None and dt.tzinfo.utcoffset(dt) is not None:
            iso_str = iso_str.replace("+00:00", "Z")
        return iso_str

    def extract_custom_id(req):
        # Safely extract the custom_id attribute whether req is an object or a dict.
        return req.custom_id if hasattr(req, "custom_id") else req.get("custom_id")

    max_attempts = 2
    attempt = 0
    last_exception = None
    message_batch = None
    while attempt < max_attempts:
        try:
            print(f"Submitting batch {batch_num} with {len(batch_requests)} requests... (attempt {attempt+1})")
            message_batch = client.messages.batches.create(requests=batch_requests)
            time.sleep(1)
            print(f"Batch {batch_num} submitted with ID: {message_batch.id}")
            break  # Success: exit the loop.
        except Exception as e:
            last_exception = e
            attempt += 1
            print(f"Error submitting batch {batch_num} on attempt {attempt}: {e}")
            if attempt < max_attempts:
                print("Retrying...")
                time.sleep(1)
    if message_batch is None:
        # Both attempts failed: record the batch (by custom_ids) for later
        # inspection or resubmission, then return without raising.
        error_filename = output_dir / f"{COMMON_UUID}_failed_batches.jsonl"
        error_msg = (
            f"{str(last_exception)} after {max_attempts} attempts"
            if last_exception
            else f"Failed after {max_attempts} attempts"
        )
        failed_info = {
            "batch_number": batch_num,
            "error": error_msg,
            "batch_requests": [extract_custom_id(req) for req in batch_requests],
            "input_file": input_path,
        }
        with open(error_filename, "a", encoding="utf-8") as error_file:
            error_file.write(json.dumps(failed_info) + "\n")
        print(f"Batch {batch_num} permanently failed. Logged to {error_filename}.")
        return
    # Build a dictionary of the expected response fields.
    api_response = {
        "id": message_batch.id,
        "type": message_batch.type,
        "processing_status": message_batch.processing_status,
        "request_counts": vars(message_batch.request_counts),
        "ended_at": serialize_datetime(message_batch.ended_at),
        "created_at": serialize_datetime(message_batch.created_at),
        "expires_at": serialize_datetime(message_batch.expires_at),
        "cancel_initiated_at": serialize_datetime(message_batch.cancel_initiated_at),
        "results_url": message_batch.results_url,
    }
    batch_metadata = {
        "batch_id": message_batch.id,
        "api_response": api_response,
        "custom_ids": custom_ids,
        "input_file": os.path.basename(input_path),
    }
    # Append (not overwrite): one shared metadata file accumulates every batch
    # of the run, keyed by the run-wide COMMON_UUID.
    metadata_filename = output_dir / f"{COMMON_UUID}_{batch_metadata_prefix}.jsonl"
    with open(metadata_filename, "a", encoding="utf-8") as meta_file:
        meta_file.write(json.dumps(batch_metadata) + "\n")
    print(f"Batch metadata for batch {batch_num} appended to {metadata_filename}.")
if __name__ == "__main__":
    # When running this module directly, submit the reasoning batches.
    # WARNING: this submits real API batches and can incur Anthropic credit
    # costs (see the module docstring).
    submit_reasoning_batches()

View file

@ -1,206 +0,0 @@
#!/usr/bin/env python3
"""
This script reads a JSONL file that contains messages with usage statistics.
For each JSON record, it expects to find the token usage information under:
record["result"]["message"]["usage"]
It then calculates and prints statistics for each usage token field:
- input_tokens
- cache_creation_input_tokens
- cache_read_input_tokens
- output_tokens
+pricing calculations
+calculates the savings from caching (vs if we hadn't done any caching)
+forecasts costs for 10,000, 20,000 and 50,000 jobs based on tokens per query
Usage:
python usage_stats.py path/to/msgbatch_01X9LgZNVkLFhzrrBd9LNgWb_results.jsonl
"""
import argparse
import json
from statistics import mean
def main():
    """
    Compute per-field token-usage statistics from a batch-results JSONL file,
    estimate total cost, the savings obtained from prompt caching, and
    forecast costs for larger job counts.

    Expects each JSON record to carry usage data at
    record["result"]["message"]["usage"]; malformed or incomplete lines are
    counted as errors and skipped.
    """
    parser = argparse.ArgumentParser(description="Compute usage token statistics from a JSONL file.")
    parser.add_argument("file", help="Path to the JSONL file containing usage token data.")
    args = parser.parse_args()
    # Usage token fields that we want to track
    usage_fields = [
        "input_tokens",
        "cache_creation_input_tokens",
        "cache_read_input_tokens",
        "output_tokens",
    ]
    # Pricing for Sonnet, 2 Feb 2025 (rates are per million tokens)
    base_input_rate = 1.50
    pricing = {
        "input_tokens": base_input_rate,
        "cache_creation_input_tokens": base_input_rate * 1.25,  # More expensive for initial computation
        "cache_read_input_tokens": base_input_rate * 0.1,  # Cheaper for cache-read tokens
        "output_tokens": 7.50,
    }
    # A dictionary to store lists of values for each usage field
    usage_data = {key: [] for key in usage_fields}
    total_lines = 0
    error_count = 0
    with open(args.file, "r", encoding="utf-8") as f:
        for line in f:
            total_lines += 1
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                print(f"[Warning] Failed to parse JSON on line {total_lines}.")
                error_count += 1
                continue
            # Navigate to the usage stats
            try:
                usage = record["result"]["message"]["usage"]
            except KeyError:
                print(f"[Warning] Missing usage field in line {total_lines}.")
                error_count += 1
                continue
            # Extract token values from the usage data
            for key in usage_fields:
                # Defaulting to 0 if the token field is missing or non-numeric
                try:
                    token_value = int(usage.get(key, 0))
                except (ValueError, TypeError):
                    token_value = 0
                usage_data[key].append(token_value)
    print(f"\nProcessed {total_lines} lines with {error_count} error(s).\n")
    print("Usage Tokens Statistics:")
    print("-" * 40)
    grand_total_cost = 0.0
    # Calculate and print stats for each token type
    for key in usage_fields:
        values = usage_data[key]
        if values:
            total = sum(values)
            count = len(values)
            min_val = min(values)
            max_val = max(values)
            avg = mean(values)
            # Calculate pricing cost scaling by tokens per million
            cost = total / 1_000_000 * pricing[key]
            grand_total_cost += cost
            print(f"{key}:")
            print(f" Total = {total}")
            print(f" Count = {count}")
            print(f" Min = {min_val}")
            print(f" Max = {max_val}")
            print(f" Mean = {avg:.2f}")
            print(f" Cost = ${cost:.2f}\n")
        else:
            print(f"{key}: No data found.\n")
    print("-" * 40)
    print(f"Grand Total Estimated Cost: ${grand_total_cost:.2f}")
    # -----------------------------------------------
    # Calculate caching savings (for input-related tokens)
    # Without caching, all tokens would have been charged at the standard input rate.
    #
    # Baseline cost (if no caching were used):
    #   = (input_tokens + cache_creation_input_tokens + cache_read_input_tokens)
    #     / 1_000_000 * base_input_rate
    #
    # Actual cost (with caching):
    #   = input_tokens * base_input_rate +
    #     cache_creation_input_tokens * (base_input_rate * 1.25) +
    #     cache_read_input_tokens * (base_input_rate * 0.1)
    #
    # Savings from caching is then the difference.
    sum_input = sum(usage_data["input_tokens"])
    sum_cache_creation = sum(usage_data["cache_creation_input_tokens"])
    sum_cache_read = sum(usage_data["cache_read_input_tokens"])
    baseline_input_cost = (sum_input + sum_cache_creation + sum_cache_read) / 1_000_000 * pricing["input_tokens"]
    actual_input_cost = (
        (sum_input) / 1_000_000 * pricing["input_tokens"]
        + (sum_cache_creation) / 1_000_000 * pricing["cache_creation_input_tokens"]
        + (sum_cache_read) / 1_000_000 * pricing["cache_read_input_tokens"]
    )
    caching_savings = baseline_input_cost - actual_input_cost
    print(f"Caching Savings (input-related tokens): ${caching_savings:.2f}")
    # -----------------------------------------------
    # Forecast future cost estimates based on the average tokens per job.
    #
    # We'll compute the average tokens per job (i.e. tokens per query) for:
    #   - input_tokens
    #   - cache_creation_input_tokens
    #   - cache_read_input_tokens
    #   - output_tokens
    #
    # Then we forecast, for example, for 10,000, 20,000, and 50,000 jobs:
    #   - Apply the relevant pricing to compute the cost per token type.
    #   - Also compute the baseline cost for input-related tokens and the savings
    #     from caching.
    if usage_data["input_tokens"]:
        # Every parsed line contributes one entry per field, so this list's
        # length is the number of successfully processed jobs.
        job_count = len(usage_data["input_tokens"])
        avg_input_tokens = sum(usage_data["input_tokens"]) / job_count
        avg_cache_creation_tokens = sum(usage_data["cache_creation_input_tokens"]) / job_count
        avg_cache_read_tokens = sum(usage_data["cache_read_input_tokens"]) / job_count
        avg_output_tokens = sum(usage_data["output_tokens"]) / job_count
        print("\nAverage Tokens per Job:")
        print(f" input_tokens = {avg_input_tokens:.2f}")
        print(f" cache_creation_input_tokens = {avg_cache_creation_tokens:.2f}")
        print(f" cache_read_input_tokens = {avg_cache_read_tokens:.2f}")
        print(f" output_tokens = {avg_output_tokens:.2f}")
        # NOTE(review): the docstring mentions 10k/20k/50k, but 2k and 4k
        # forecasts are also produced here.
        forecast_jobs = [2000, 4000, 10000, 20000, 50000]
        print("\nForecasting Future Job Costs:")
        for jobs in forecast_jobs:
            # Forecast token usage for the job count by multiplying the per-job averages.
            forecast_input = avg_input_tokens * jobs
            forecast_cache_creation = avg_cache_creation_tokens * jobs
            forecast_cache_read = avg_cache_read_tokens * jobs
            forecast_output = avg_output_tokens * jobs
            # Forecast actual cost (with caching applied for input tokens):
            actual_input_cost_forecast = (
                (forecast_input) / 1_000_000 * pricing["input_tokens"]
                + (forecast_cache_creation) / 1_000_000 * pricing["cache_creation_input_tokens"]
                + (forecast_cache_read) / 1_000_000 * pricing["cache_read_input_tokens"]
            )
            # Without caching, all input-related tokens would be at base_input_rate:
            baseline_input_cost_forecast = (
                (forecast_input + forecast_cache_creation + forecast_cache_read) / 1_000_000 * pricing["input_tokens"]
            )
            caching_savings_forecast = baseline_input_cost_forecast - actual_input_cost_forecast
            forecast_output_cost = forecast_output / 1_000_000 * pricing["output_tokens"]
            total_forecast_cost = actual_input_cost_forecast + forecast_output_cost
            print(f"\nFor {jobs:,} jobs:")
            print(" Forecasted Token Usage:")
            print(f" input_tokens = {forecast_input:,.0f}")
            print(f" cache_creation_input_tokens = {forecast_cache_creation:,.0f}")
            print(f" cache_read_input_tokens = {forecast_cache_read:,.0f}")
            print(f" output_tokens = {forecast_output:,.0f}")
            print(" Estimated Costs:")
            print(f" Input Cost (with caching) = ${actual_input_cost_forecast:,.2f}")
            print(f" Output Cost = ${forecast_output_cost:,.2f}")
            print(f" Grand Total Cost = ${total_forecast_cost:,.2f}")
            print(f" Caching Savings (input) = ${caching_savings_forecast:,.2f}")
    else:
        print("No valid jobs to forecast future costs.")
if __name__ == "__main__":
    # Entry point: parse CLI args and print usage/cost statistics.
    main()