mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-25 17:10:51 +00:00
Feat/unsloth example (#482)
* cleaned up examples * updated failing hooks * updated readme * corrected linting checks
This commit is contained in:
parent
d9cd20c174
commit
1c98584f28
29 changed files with 122 additions and 2857 deletions
30
examples/unsloth/README.md
Normal file
30
examples/unsloth/README.md
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
# Chain Sum LoRA Training with unsloth
|
||||
|
||||
This example demonstrates how to fine-tune an LLM with RL on a reasoning gym environment using the **unsloth** framework. Unsloth is an efficient open-source library for fine-tuning & RL. Unsloth's default training path uses quantised low-rank adaptation (QLoRA), which results in a significantly lower memory footprint ($\approx 3\times$) and means you can significantly increase batch sizes and context length without risking OOM errors.
|
||||
|
||||
Requirements:
|
||||
|
||||
python >= 3.10
|
||||
|
||||
## Installation
|
||||
|
||||
1. **Install reasoning-gym**:
|
||||
```bash
|
||||
pip install reasoning-gym
|
||||
```
|
||||
2. **Install unsloth dependencies**:
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
3. **Run training script**:
|
||||
To start training with unsloth on reasoning gym (RG) environments using the default arguments, run the following:
|
||||
|
||||
```bash
|
||||
python train_grpo_lora.py
|
||||
```
|
||||
|
||||
To customise or override any of the default arguments, pass them on the command line, for example:
|
||||
```bash
|
||||
python train_grpo_lora.py --dataset-name chain_sum --max-seq-length 512 --model-id Qwen/Qwen2.5-7B-Instruct
```
|
||||
|
||||
**Note:** The free open-source version of unsloth is currently built to train models in single-GPU environments only.
|
||||
|
|
@ -147,22 +147,24 @@ def main(args):
|
|||
training_args = GRPOConfig(
|
||||
output_dir="outputs",
|
||||
use_vllm=True,
|
||||
learning_rate=5e-6,
|
||||
learning_rate=1e-6,
|
||||
adam_beta1=0.9,
|
||||
adam_beta2=0.99,
|
||||
weight_decay=0.1,
|
||||
warmup_ratio=0.1,
|
||||
lr_scheduler_type="cosine",
|
||||
weight_decay=0.0,
|
||||
warmup_ratio=0.0,
|
||||
lr_scheduler_type="constant",
|
||||
optim="adamw_8bit",
|
||||
logging_steps=1,
|
||||
bf16=is_bfloat16_supported(),
|
||||
fp16=not is_bfloat16_supported(),
|
||||
per_device_train_batch_size=args.train_batch_size,
|
||||
gradient_accumulation_steps=1,
|
||||
gradient_accumulation_steps=4,
|
||||
num_generations=args.num_generations,
|
||||
num_train_epochs=args.train_epochs,
|
||||
max_prompt_length=512,
|
||||
max_completion_length=512,
|
||||
save_steps=100,
|
||||
max_grad_norm=0.1,
|
||||
max_grad_norm=1.0,
|
||||
)
|
||||
|
||||
train(model, tokenizer, dataset, training_args)
|
||||
|
|
@ -187,17 +189,17 @@ if __name__ == "__main__":
|
|||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--model-id", type=str, default="Qwen/Qwen2.5-1.5B-Instruct")
|
||||
parser.add_argument("--dataset-name", type=str)
|
||||
parser.add_argument("--dataset-name", type=str, default="chain_sum")
|
||||
|
||||
parser.add_argument("--max-seq-length", type=int, default=1024)
|
||||
parser.add_argument("--lora-rank", type=int, default=64)
|
||||
parser.add_argument("--quantize", action="store_true")
|
||||
parser.add_argument("--num-generations", type=int, default=8)
|
||||
parser.add_argument("--num-generations", type=int, default=16)
|
||||
parser.add_argument("--train-epochs", type=int, default=1)
|
||||
parser.add_argument("--train-batch-size", type=int, default=8)
|
||||
parser.add_argument("--train-batch-size", type=int, default=16)
|
||||
|
||||
parser.add_argument("--dataset-seed", type=int, default=42)
|
||||
parser.add_argument("--dataset-size", type=int, default=1000)
|
||||
parser.add_argument("--dataset-size", type=int, default=10000)
|
||||
|
||||
parser.add_argument("--eval-seed", type=int, default=42)
|
||||
parser.add_argument("--eval-size", type=int, default=100)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue