mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-30 17:40:36 +00:00
Merge commit '71e7a5ca27' into add-support-for-custom-api-servers
This commit is contained in:
commit
96be544228
45 changed files with 1605 additions and 494 deletions
|
|
@ -152,4 +152,4 @@ server_configs:
|
|||
|
||||
If you encounter issues with reward functions, make sure they are properly registered in the registry.
|
||||
|
||||
For dataset-related issues, verify that the dataset exists on HuggingFace and that the specified fields exist in the dataset.
|
||||
For dataset-related issues, verify that the dataset exists on HuggingFace and that the specified fields exist in the dataset.
|
||||
|
|
|
|||
|
|
@ -49,4 +49,4 @@ dataset:
|
|||
server_configs:
|
||||
- model_name: "gpt-4.1-nano"
|
||||
api_key: ${OPENAI_API_KEY}
|
||||
timeout: 600
|
||||
timeout: 600
|
||||
|
|
|
|||
|
|
@ -70,4 +70,4 @@ dataset:
|
|||
|
||||
eval_dataset_name: "gsm8k"
|
||||
eval_dataset_config: "main"
|
||||
eval_split: "test"
|
||||
eval_split: "test"
|
||||
|
|
|
|||
|
|
@ -27,4 +27,4 @@ dataset:
|
|||
|
||||
max_tokens: 4096
|
||||
length_warmup_steps: 0
|
||||
min_tokens: 200
|
||||
min_tokens: 200
|
||||
|
|
|
|||
|
|
@ -32,7 +32,9 @@ class DatasetEnvConfig(BaseEnvConfig):
|
|||
None, description="Field in dataset containing canonical correct answer"
|
||||
)
|
||||
system_prompt: Optional[str] = Field(None, description="System prompt to use")
|
||||
prefill: Optional[str] = Field(None, description="Text to prefill the completion with (e.g. '<think>')")
|
||||
prefill: Optional[str] = Field(
|
||||
None, description="Text to prefill the completion with (e.g. '<think>')"
|
||||
)
|
||||
shuffle_dataset: bool = Field(True, description="Whether to shuffle the dataset")
|
||||
max_generations_per_prompt: int = Field(
|
||||
1, description="Number of generations per prompt for collection"
|
||||
|
|
@ -137,21 +139,21 @@ class DatasetEnv(BaseEnv):
|
|||
# Extract user prompt and answer from item
|
||||
user_content = dict(item[0][0])["content"]
|
||||
answer = item[1] if len(item) > 1 else None
|
||||
|
||||
|
||||
# Create messages list
|
||||
messages = []
|
||||
if self.config.system_prompt:
|
||||
messages.append({"role": "system", "content": self.config.system_prompt})
|
||||
|
||||
|
||||
messages.append({"role": "user", "content": user_content})
|
||||
|
||||
|
||||
# Add prefill as assistant message if configured
|
||||
if self.config.prefill:
|
||||
messages.append({"role": "assistant", "content": self.config.prefill})
|
||||
|
||||
|
||||
# Convert messages to a prompt string using the tokenizer
|
||||
prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
|
||||
|
||||
|
||||
# Calculate max tokens for generation (with optional warmup)
|
||||
max_tokens = self.config.max_tokens
|
||||
if self.config.length_warmup_steps > 0:
|
||||
|
|
@ -160,7 +162,7 @@ class DatasetEnv(BaseEnv):
|
|||
self.config.min_tokens
|
||||
+ warmup_progress * (self.config.max_tokens - self.config.min_tokens)
|
||||
)
|
||||
|
||||
|
||||
# Generate completion using completions API
|
||||
completions = await self.server.completion(
|
||||
prompt=prompt,
|
||||
|
|
@ -169,34 +171,38 @@ class DatasetEnv(BaseEnv):
|
|||
temperature=self.config.temperature,
|
||||
top_p=self.config.top_p,
|
||||
)
|
||||
|
||||
|
||||
to_score = []
|
||||
to_backlog = []
|
||||
|
||||
|
||||
# Process completions
|
||||
for completion in completions.choices:
|
||||
# Get the completion text
|
||||
completion_text = completion.text if hasattr(completion, "text") else completion.message.content
|
||||
|
||||
completion_text = (
|
||||
completion.text
|
||||
if hasattr(completion, "text")
|
||||
else completion.message.content
|
||||
)
|
||||
|
||||
# Build full message sequence for scoring
|
||||
full_messages = []
|
||||
if self.config.system_prompt:
|
||||
full_messages.append({"role": "system", "content": self.config.system_prompt})
|
||||
|
||||
full_messages.append(
|
||||
{"role": "system", "content": self.config.system_prompt}
|
||||
)
|
||||
|
||||
full_messages.append({"role": "user", "content": user_content})
|
||||
|
||||
|
||||
# Combine prefill with completion if prefill was used
|
||||
response_content = completion_text
|
||||
if self.config.prefill:
|
||||
response_content = self.config.prefill + completion_text
|
||||
|
||||
|
||||
full_messages.append({"role": "assistant", "content": response_content})
|
||||
|
||||
|
||||
# Add to scoring list with answer and ground truth
|
||||
to_score.append(
|
||||
(full_messages, answer, item[2] if len(item) > 2 else None)
|
||||
)
|
||||
|
||||
to_score.append((full_messages, answer, item[2] if len(item) > 2 else None))
|
||||
|
||||
return to_score, to_backlog
|
||||
|
||||
async def postprocess_histories(self, trajectories: List) -> Tuple[List, List]:
|
||||
|
|
@ -204,27 +210,27 @@ class DatasetEnv(BaseEnv):
|
|||
|
||||
async def collect_trajectories(self, item: Item) -> Tuple[List, List]:
|
||||
self.current_item = item
|
||||
|
||||
|
||||
# Extract user prompt from item
|
||||
user_content = dict(item[0][0])["content"]
|
||||
|
||||
|
||||
# Create messages list
|
||||
messages = []
|
||||
if self.config.system_prompt:
|
||||
messages.append({"role": "system", "content": self.config.system_prompt})
|
||||
|
||||
|
||||
messages.append({"role": "user", "content": user_content})
|
||||
|
||||
|
||||
# Add prefill as assistant message if configured
|
||||
if self.config.prefill:
|
||||
messages.append({"role": "assistant", "content": self.config.prefill})
|
||||
|
||||
|
||||
# Convert messages to a prompt string using the tokenizer
|
||||
prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
|
||||
|
||||
|
||||
# Calculate max tokens for generation (with optional warmup)
|
||||
max_tokens = self.config.max_tokens
|
||||
|
||||
|
||||
# Generate completions
|
||||
completions = await self.server.completion(
|
||||
prompt=prompt,
|
||||
|
|
@ -233,30 +239,36 @@ class DatasetEnv(BaseEnv):
|
|||
temperature=self.config.temperature,
|
||||
top_p=self.config.top_p,
|
||||
)
|
||||
|
||||
|
||||
print(f"Completions: {completions}")
|
||||
# Process completions
|
||||
trajectories = []
|
||||
for completion in completions.choices:
|
||||
# Get the completion text
|
||||
completion_text = completion.text if hasattr(completion, "text") else completion.message.content
|
||||
|
||||
completion_text = (
|
||||
completion.text
|
||||
if hasattr(completion, "text")
|
||||
else completion.message.content
|
||||
)
|
||||
|
||||
# Build complete message sequence
|
||||
full_messages = []
|
||||
if self.config.system_prompt:
|
||||
full_messages.append({"role": "system", "content": self.config.system_prompt})
|
||||
|
||||
full_messages.append(
|
||||
{"role": "system", "content": self.config.system_prompt}
|
||||
)
|
||||
|
||||
full_messages.append({"role": "user", "content": user_content})
|
||||
|
||||
|
||||
# Combine prefill with completion if prefill was used
|
||||
response_content = completion_text
|
||||
if self.config.prefill:
|
||||
response_content = self.config.prefill + completion_text
|
||||
|
||||
|
||||
full_messages.append({"role": "assistant", "content": response_content})
|
||||
|
||||
|
||||
trajectories.append(full_messages)
|
||||
|
||||
|
||||
return trajectories, []
|
||||
|
||||
async def score(self, rollout_group_data: List) -> Optional[ScoredDataGroup]:
|
||||
|
|
@ -402,6 +414,7 @@ class DatasetEnv(BaseEnv):
|
|||
|
||||
await super().wandb_log(metrics)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Launch the DatasetEnv via the BaseEnv CLI (serve or process)
|
||||
DatasetEnv.cli()
|
||||
|
|
|
|||
|
|
@ -8,7 +8,8 @@ from dotenv import load_dotenv
|
|||
|
||||
from atroposlib.envs.base import APIServerConfig
|
||||
from atroposlib.envs.reward_fns import registry
|
||||
from atroposlib.utils.config_handler import ConfigHandler
|
||||
|
||||
# from atroposlib.utils.config_handler import ConfigHandler
|
||||
from environments.dataset_environment.dataset_env import DatasetEnv, DatasetEnvConfig
|
||||
|
||||
load_dotenv()
|
||||
|
|
@ -23,7 +24,8 @@ def parse_arguments():
|
|||
"--config",
|
||||
type=str,
|
||||
default="dataset_local",
|
||||
help="Configuration file name (without .yaml extension) relative to environments/dataset_environment/configs/, or full path to a YAML file.",
|
||||
help="Configuration file name (without .yaml extension) relative to environments/dataset_environment/configs/,"
|
||||
" or full path to a YAML file.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
|
@ -35,7 +37,7 @@ async def main():
|
|||
args = parse_arguments()
|
||||
|
||||
# Initialize config handler
|
||||
config_handler = ConfigHandler()
|
||||
# config_handler = ConfigHandler()
|
||||
|
||||
# Determine config path
|
||||
if (
|
||||
|
|
|
|||
|
|
@ -14,16 +14,16 @@ Requirements:
|
|||
- Run from project root so example_trainer is on PYTHONPATH
|
||||
- example_trainer/ is a valid Python package (with __init__.py)
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
import time
|
||||
import atexit
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
|
||||
# Ensure project root is on PYTHONPATH
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
if project_root not in sys.path:
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
|
|
@ -32,56 +32,73 @@ try:
|
|||
from example_trainer.grpo import TrainingConfig, train
|
||||
except ImportError as e:
|
||||
print(f"Error importing example_trainer.grpo: {e}")
|
||||
print("Ensure you're running from project root and that example_trainer/ is a package.")
|
||||
print(
|
||||
"Ensure you're running from project root and that example_trainer/ is a package."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Configuration
|
||||
# -----------------------------------------------------------------------------
|
||||
API_HOST = '127.0.0.1'
|
||||
API_HOST = "127.0.0.1"
|
||||
API_PORT = 8000
|
||||
|
||||
VLLM_HOST = '127.0.0.1'
|
||||
VLLM_HOST = "127.0.0.1"
|
||||
VLLM_PORT = 9001
|
||||
|
||||
MODEL_NAME = 'Qwen/Qwen2.5-1.5B-Instruct'
|
||||
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
TOKENIZER_NAME = MODEL_NAME
|
||||
|
||||
TRAINER_CONFIG = {
|
||||
'model_name': MODEL_NAME,
|
||||
'training_steps': 20,
|
||||
'batch_size': 2,
|
||||
'gradient_accumulation_steps': 2,
|
||||
'seq_len': 512,
|
||||
'vllm_port': VLLM_PORT,
|
||||
'vllm_restart_interval': 10,
|
||||
'use_wandb': False,
|
||||
'wandb_project': '',
|
||||
'wandb_group': '',
|
||||
'save_path': './trained_model_checkpoints_local_test',
|
||||
"model_name": MODEL_NAME,
|
||||
"training_steps": 20,
|
||||
"batch_size": 2,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"seq_len": 512,
|
||||
"vllm_port": VLLM_PORT,
|
||||
"vllm_restart_interval": 10,
|
||||
"use_wandb": False,
|
||||
"wandb_project": "",
|
||||
"wandb_group": "",
|
||||
"save_path": "./trained_model_checkpoints_local_test",
|
||||
}
|
||||
|
||||
# Flags for launching DatasetEnv serve
|
||||
DATASET_FLAGS = [
|
||||
'--group_size', '4',
|
||||
'--max_num_workers', '2',
|
||||
'--rollout_server_url', f"http://{API_HOST}:{API_PORT}",
|
||||
'--tokenizer_name', TOKENIZER_NAME,
|
||||
'--use_wandb',
|
||||
'--wandb_name', 'dataset_env_local_test',
|
||||
'--max_token_length', str(TRAINER_CONFIG['seq_len']),
|
||||
'--ensure_scores_are_not_same',
|
||||
'--dataset_name', 'HuggingFaceH4/testing_self_instruct_process_essays',
|
||||
'--split', 'train[:100]',
|
||||
'--prompt_field', 'prompt',
|
||||
'--answer_field', 'answer',
|
||||
'--reward_functions', 'length',
|
||||
'--max_tokens', '128',
|
||||
'--temperature', '0.7',
|
||||
'--model_name', MODEL_NAME,
|
||||
'--base_url', f"http://{VLLM_HOST}:{VLLM_PORT}",
|
||||
'--slurm',
|
||||
'--testing',
|
||||
"--group_size",
|
||||
"4",
|
||||
"--max_num_workers",
|
||||
"2",
|
||||
"--rollout_server_url",
|
||||
f"http://{API_HOST}:{API_PORT}",
|
||||
"--tokenizer_name",
|
||||
TOKENIZER_NAME,
|
||||
"--use_wandb",
|
||||
"--wandb_name",
|
||||
"dataset_env_local_test",
|
||||
"--max_token_length",
|
||||
str(TRAINER_CONFIG["seq_len"]),
|
||||
"--ensure_scores_are_not_same",
|
||||
"--dataset_name",
|
||||
"HuggingFaceH4/testing_self_instruct_process_essays",
|
||||
"--split",
|
||||
"train[:100]",
|
||||
"--prompt_field",
|
||||
"prompt",
|
||||
"--answer_field",
|
||||
"answer",
|
||||
"--reward_functions",
|
||||
"length",
|
||||
"--max_tokens",
|
||||
"128",
|
||||
"--temperature",
|
||||
"0.7",
|
||||
"--model_name",
|
||||
MODEL_NAME,
|
||||
"--base_url",
|
||||
f"http://{VLLM_HOST}:{VLLM_PORT}",
|
||||
"--slurm",
|
||||
"--testing",
|
||||
]
|
||||
|
||||
# Track background processes for cleanup
|
||||
|
|
@ -106,6 +123,7 @@ def cleanup_processes():
|
|||
print(f"PID {p.pid} already exited.")
|
||||
print("Cleanup complete.")
|
||||
|
||||
|
||||
atexit.register(cleanup_processes)
|
||||
|
||||
|
||||
|
|
@ -113,6 +131,7 @@ def handle_signal(sig, frame):
|
|||
print(f"\nSignal {sig} received; exiting.")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
signal.signal(signal.SIGINT, handle_signal)
|
||||
signal.signal(signal.SIGTERM, handle_signal)
|
||||
|
||||
|
|
@ -121,10 +140,12 @@ def main():
|
|||
# 1) Start the API server
|
||||
print("--- Starting Trajectory Handler API Server ---")
|
||||
api_cmd = [
|
||||
'uvicorn',
|
||||
'atroposlib.api.server:app',
|
||||
'--host', API_HOST,
|
||||
'--port', str(API_PORT),
|
||||
"uvicorn",
|
||||
"atroposlib.api.server:app",
|
||||
"--host",
|
||||
API_HOST,
|
||||
"--port",
|
||||
str(API_PORT),
|
||||
]
|
||||
print(f"$ {' '.join(api_cmd)}")
|
||||
api_proc = subprocess.Popen(api_cmd)
|
||||
|
|
@ -133,7 +154,12 @@ def main():
|
|||
|
||||
# 2) Start the dataset environment
|
||||
print("\n--- Starting Dataset Environment ---")
|
||||
env_cmd = ['python', '-m', 'environments.dataset_environment.dataset_env', 'serve'] + DATASET_FLAGS
|
||||
env_cmd = [
|
||||
"python",
|
||||
"-m",
|
||||
"environments.dataset_environment.dataset_env",
|
||||
"serve",
|
||||
] + DATASET_FLAGS
|
||||
print(f"$ {' '.join(env_cmd)}")
|
||||
env_proc = subprocess.Popen(env_cmd)
|
||||
processes.append(env_proc)
|
||||
|
|
@ -150,5 +176,5 @@ def main():
|
|||
print("--- Training complete ---")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue