mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
run pre-commit on all files
This commit is contained in:
parent
b959c30ebf
commit
40b12dae60
17 changed files with 169 additions and 118 deletions
|
|
@ -1 +1 @@
|
||||||
OPENAI_API_KEY=
|
OPENAI_API_KEY=
|
||||||
|
|
|
||||||
2
.github/pull_request_template.md
vendored
2
.github/pull_request_template.md
vendored
|
|
@ -66,4 +66,4 @@
|
||||||
- [ ] My changes generate no new warnings
|
- [ ] My changes generate no new warnings
|
||||||
- [ ] New and existing unit tests pass locally with my changes
|
- [ ] New and existing unit tests pass locally with my changes
|
||||||
- [ ] Docstrings added for all new public classes / functions
|
- [ ] Docstrings added for all new public classes / functions
|
||||||
- [ ] If .env vars required, did you add it to the .env.example in repo root?
|
- [ ] If .env vars required, did you add it to the .env.example in repo root?
|
||||||
|
|
|
||||||
|
|
@ -41,4 +41,4 @@ Project maintainers are obligated to respect the privacy and security of the rep
|
||||||
|
|
||||||
This Code of Conduct is adapted from general open source community standards and GitHub's community guidelines.
|
This Code of Conduct is adapted from general open source community standards and GitHub's community guidelines.
|
||||||
|
|
||||||
Remember: Respect each other, collaborate constructively, and focus on making Atropos better for everyone.
|
Remember: Respect each other, collaborate constructively, and focus on making Atropos better for everyone.
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,6 @@ import argparse
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -387,8 +387,10 @@ class BaseEnv(ABC):
|
||||||
# Now register the env...
|
# Now register the env...
|
||||||
while True:
|
while True:
|
||||||
data = await self._register_env()
|
data = await self._register_env()
|
||||||
if data['status'] != "success":
|
if data["status"] != "success":
|
||||||
logging.warning(f"Waiting to register the env due to status {data['status']}")
|
logging.warning(
|
||||||
|
f"Waiting to register the env due to status {data['status']}"
|
||||||
|
)
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
continue
|
continue
|
||||||
self.env_id = data["env_id"]
|
self.env_id = data["env_id"]
|
||||||
|
|
|
||||||
|
|
@ -181,4 +181,4 @@ class ConfigHandler:
|
||||||
# Add slurm flag to config if running in a Slurm environment
|
# Add slurm flag to config if running in a Slurm environment
|
||||||
config["use_slurm"] = "SLURM_JOB_ID" in os.environ
|
config["use_slurm"] = "SLURM_JOB_ID" in os.environ
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
|
||||||
|
|
@ -152,4 +152,4 @@ server_configs:
|
||||||
|
|
||||||
If you encounter issues with reward functions, make sure they are properly registered in the registry.
|
If you encounter issues with reward functions, make sure they are properly registered in the registry.
|
||||||
|
|
||||||
For dataset-related issues, verify that the dataset exists on HuggingFace and that the specified fields exist in the dataset.
|
For dataset-related issues, verify that the dataset exists on HuggingFace and that the specified fields exist in the dataset.
|
||||||
|
|
|
||||||
|
|
@ -49,4 +49,4 @@ dataset:
|
||||||
server_configs:
|
server_configs:
|
||||||
- model_name: "gpt-4.1-nano"
|
- model_name: "gpt-4.1-nano"
|
||||||
api_key: ${OPENAI_API_KEY}
|
api_key: ${OPENAI_API_KEY}
|
||||||
timeout: 600
|
timeout: 600
|
||||||
|
|
|
||||||
|
|
@ -70,4 +70,4 @@ dataset:
|
||||||
|
|
||||||
eval_dataset_name: "gsm8k"
|
eval_dataset_name: "gsm8k"
|
||||||
eval_dataset_config: "main"
|
eval_dataset_config: "main"
|
||||||
eval_split: "test"
|
eval_split: "test"
|
||||||
|
|
|
||||||
|
|
@ -27,4 +27,4 @@ dataset:
|
||||||
|
|
||||||
max_tokens: 4096
|
max_tokens: 4096
|
||||||
length_warmup_steps: 0
|
length_warmup_steps: 0
|
||||||
min_tokens: 200
|
min_tokens: 200
|
||||||
|
|
|
||||||
|
|
@ -32,7 +32,9 @@ class DatasetEnvConfig(BaseEnvConfig):
|
||||||
None, description="Field in dataset containing canonical correct answer"
|
None, description="Field in dataset containing canonical correct answer"
|
||||||
)
|
)
|
||||||
system_prompt: Optional[str] = Field(None, description="System prompt to use")
|
system_prompt: Optional[str] = Field(None, description="System prompt to use")
|
||||||
prefill: Optional[str] = Field(None, description="Text to prefill the completion with (e.g. '<think>')")
|
prefill: Optional[str] = Field(
|
||||||
|
None, description="Text to prefill the completion with (e.g. '<think>')"
|
||||||
|
)
|
||||||
shuffle_dataset: bool = Field(True, description="Whether to shuffle the dataset")
|
shuffle_dataset: bool = Field(True, description="Whether to shuffle the dataset")
|
||||||
max_generations_per_prompt: int = Field(
|
max_generations_per_prompt: int = Field(
|
||||||
1, description="Number of generations per prompt for collection"
|
1, description="Number of generations per prompt for collection"
|
||||||
|
|
@ -137,21 +139,21 @@ class DatasetEnv(BaseEnv):
|
||||||
# Extract user prompt and answer from item
|
# Extract user prompt and answer from item
|
||||||
user_content = dict(item[0][0])["content"]
|
user_content = dict(item[0][0])["content"]
|
||||||
answer = item[1] if len(item) > 1 else None
|
answer = item[1] if len(item) > 1 else None
|
||||||
|
|
||||||
# Create messages list
|
# Create messages list
|
||||||
messages = []
|
messages = []
|
||||||
if self.config.system_prompt:
|
if self.config.system_prompt:
|
||||||
messages.append({"role": "system", "content": self.config.system_prompt})
|
messages.append({"role": "system", "content": self.config.system_prompt})
|
||||||
|
|
||||||
messages.append({"role": "user", "content": user_content})
|
messages.append({"role": "user", "content": user_content})
|
||||||
|
|
||||||
# Add prefill as assistant message if configured
|
# Add prefill as assistant message if configured
|
||||||
if self.config.prefill:
|
if self.config.prefill:
|
||||||
messages.append({"role": "assistant", "content": self.config.prefill})
|
messages.append({"role": "assistant", "content": self.config.prefill})
|
||||||
|
|
||||||
# Convert messages to a prompt string using the tokenizer
|
# Convert messages to a prompt string using the tokenizer
|
||||||
prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
|
prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
|
||||||
|
|
||||||
# Calculate max tokens for generation (with optional warmup)
|
# Calculate max tokens for generation (with optional warmup)
|
||||||
max_tokens = self.config.max_tokens
|
max_tokens = self.config.max_tokens
|
||||||
if self.config.length_warmup_steps > 0:
|
if self.config.length_warmup_steps > 0:
|
||||||
|
|
@ -160,7 +162,7 @@ class DatasetEnv(BaseEnv):
|
||||||
self.config.min_tokens
|
self.config.min_tokens
|
||||||
+ warmup_progress * (self.config.max_tokens - self.config.min_tokens)
|
+ warmup_progress * (self.config.max_tokens - self.config.min_tokens)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Generate completion using completions API
|
# Generate completion using completions API
|
||||||
completions = await self.server.completion(
|
completions = await self.server.completion(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
|
|
@ -169,34 +171,38 @@ class DatasetEnv(BaseEnv):
|
||||||
temperature=self.config.temperature,
|
temperature=self.config.temperature,
|
||||||
top_p=self.config.top_p,
|
top_p=self.config.top_p,
|
||||||
)
|
)
|
||||||
|
|
||||||
to_score = []
|
to_score = []
|
||||||
to_backlog = []
|
to_backlog = []
|
||||||
|
|
||||||
# Process completions
|
# Process completions
|
||||||
for completion in completions.choices:
|
for completion in completions.choices:
|
||||||
# Get the completion text
|
# Get the completion text
|
||||||
completion_text = completion.text if hasattr(completion, "text") else completion.message.content
|
completion_text = (
|
||||||
|
completion.text
|
||||||
|
if hasattr(completion, "text")
|
||||||
|
else completion.message.content
|
||||||
|
)
|
||||||
|
|
||||||
# Build full message sequence for scoring
|
# Build full message sequence for scoring
|
||||||
full_messages = []
|
full_messages = []
|
||||||
if self.config.system_prompt:
|
if self.config.system_prompt:
|
||||||
full_messages.append({"role": "system", "content": self.config.system_prompt})
|
full_messages.append(
|
||||||
|
{"role": "system", "content": self.config.system_prompt}
|
||||||
|
)
|
||||||
|
|
||||||
full_messages.append({"role": "user", "content": user_content})
|
full_messages.append({"role": "user", "content": user_content})
|
||||||
|
|
||||||
# Combine prefill with completion if prefill was used
|
# Combine prefill with completion if prefill was used
|
||||||
response_content = completion_text
|
response_content = completion_text
|
||||||
if self.config.prefill:
|
if self.config.prefill:
|
||||||
response_content = self.config.prefill + completion_text
|
response_content = self.config.prefill + completion_text
|
||||||
|
|
||||||
full_messages.append({"role": "assistant", "content": response_content})
|
full_messages.append({"role": "assistant", "content": response_content})
|
||||||
|
|
||||||
# Add to scoring list with answer and ground truth
|
# Add to scoring list with answer and ground truth
|
||||||
to_score.append(
|
to_score.append((full_messages, answer, item[2] if len(item) > 2 else None))
|
||||||
(full_messages, answer, item[2] if len(item) > 2 else None)
|
|
||||||
)
|
|
||||||
|
|
||||||
return to_score, to_backlog
|
return to_score, to_backlog
|
||||||
|
|
||||||
async def postprocess_histories(self, trajectories: List) -> Tuple[List, List]:
|
async def postprocess_histories(self, trajectories: List) -> Tuple[List, List]:
|
||||||
|
|
@ -204,27 +210,27 @@ class DatasetEnv(BaseEnv):
|
||||||
|
|
||||||
async def collect_trajectories(self, item: Item) -> Tuple[List, List]:
|
async def collect_trajectories(self, item: Item) -> Tuple[List, List]:
|
||||||
self.current_item = item
|
self.current_item = item
|
||||||
|
|
||||||
# Extract user prompt from item
|
# Extract user prompt from item
|
||||||
user_content = dict(item[0][0])["content"]
|
user_content = dict(item[0][0])["content"]
|
||||||
|
|
||||||
# Create messages list
|
# Create messages list
|
||||||
messages = []
|
messages = []
|
||||||
if self.config.system_prompt:
|
if self.config.system_prompt:
|
||||||
messages.append({"role": "system", "content": self.config.system_prompt})
|
messages.append({"role": "system", "content": self.config.system_prompt})
|
||||||
|
|
||||||
messages.append({"role": "user", "content": user_content})
|
messages.append({"role": "user", "content": user_content})
|
||||||
|
|
||||||
# Add prefill as assistant message if configured
|
# Add prefill as assistant message if configured
|
||||||
if self.config.prefill:
|
if self.config.prefill:
|
||||||
messages.append({"role": "assistant", "content": self.config.prefill})
|
messages.append({"role": "assistant", "content": self.config.prefill})
|
||||||
|
|
||||||
# Convert messages to a prompt string using the tokenizer
|
# Convert messages to a prompt string using the tokenizer
|
||||||
prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
|
prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
|
||||||
|
|
||||||
# Calculate max tokens for generation (with optional warmup)
|
# Calculate max tokens for generation (with optional warmup)
|
||||||
max_tokens = self.config.max_tokens
|
max_tokens = self.config.max_tokens
|
||||||
|
|
||||||
# Generate completions
|
# Generate completions
|
||||||
completions = await self.server.completion(
|
completions = await self.server.completion(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
|
|
@ -233,30 +239,36 @@ class DatasetEnv(BaseEnv):
|
||||||
temperature=self.config.temperature,
|
temperature=self.config.temperature,
|
||||||
top_p=self.config.top_p,
|
top_p=self.config.top_p,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"Completions: {completions}")
|
print(f"Completions: {completions}")
|
||||||
# Process completions
|
# Process completions
|
||||||
trajectories = []
|
trajectories = []
|
||||||
for completion in completions.choices:
|
for completion in completions.choices:
|
||||||
# Get the completion text
|
# Get the completion text
|
||||||
completion_text = completion.text if hasattr(completion, "text") else completion.message.content
|
completion_text = (
|
||||||
|
completion.text
|
||||||
|
if hasattr(completion, "text")
|
||||||
|
else completion.message.content
|
||||||
|
)
|
||||||
|
|
||||||
# Build complete message sequence
|
# Build complete message sequence
|
||||||
full_messages = []
|
full_messages = []
|
||||||
if self.config.system_prompt:
|
if self.config.system_prompt:
|
||||||
full_messages.append({"role": "system", "content": self.config.system_prompt})
|
full_messages.append(
|
||||||
|
{"role": "system", "content": self.config.system_prompt}
|
||||||
|
)
|
||||||
|
|
||||||
full_messages.append({"role": "user", "content": user_content})
|
full_messages.append({"role": "user", "content": user_content})
|
||||||
|
|
||||||
# Combine prefill with completion if prefill was used
|
# Combine prefill with completion if prefill was used
|
||||||
response_content = completion_text
|
response_content = completion_text
|
||||||
if self.config.prefill:
|
if self.config.prefill:
|
||||||
response_content = self.config.prefill + completion_text
|
response_content = self.config.prefill + completion_text
|
||||||
|
|
||||||
full_messages.append({"role": "assistant", "content": response_content})
|
full_messages.append({"role": "assistant", "content": response_content})
|
||||||
|
|
||||||
trajectories.append(full_messages)
|
trajectories.append(full_messages)
|
||||||
|
|
||||||
return trajectories, []
|
return trajectories, []
|
||||||
|
|
||||||
async def score(self, rollout_group_data: List) -> Optional[ScoredDataGroup]:
|
async def score(self, rollout_group_data: List) -> Optional[ScoredDataGroup]:
|
||||||
|
|
@ -402,6 +414,7 @@ class DatasetEnv(BaseEnv):
|
||||||
|
|
||||||
await super().wandb_log(metrics)
|
await super().wandb_log(metrics)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Launch the DatasetEnv via the BaseEnv CLI (serve or process)
|
# Launch the DatasetEnv via the BaseEnv CLI (serve or process)
|
||||||
DatasetEnv.cli()
|
DatasetEnv.cli()
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,8 @@ from dotenv import load_dotenv
|
||||||
|
|
||||||
from atroposlib.envs.base import OpenaiConfig
|
from atroposlib.envs.base import OpenaiConfig
|
||||||
from atroposlib.envs.reward_fns import registry
|
from atroposlib.envs.reward_fns import registry
|
||||||
from atroposlib.utils.config_handler import ConfigHandler
|
|
||||||
|
# from atroposlib.utils.config_handler import ConfigHandler
|
||||||
from environments.dataset_environment.dataset_env import DatasetEnv, DatasetEnvConfig
|
from environments.dataset_environment.dataset_env import DatasetEnv, DatasetEnvConfig
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
@ -23,7 +24,8 @@ def parse_arguments():
|
||||||
"--config",
|
"--config",
|
||||||
type=str,
|
type=str,
|
||||||
default="dataset_local",
|
default="dataset_local",
|
||||||
help="Configuration file name (without .yaml extension) relative to environments/dataset_environment/configs/, or full path to a YAML file.",
|
help="Configuration file name (without .yaml extension) relative to environments/dataset_environment/configs/,"
|
||||||
|
" or full path to a YAML file.",
|
||||||
)
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
@ -35,7 +37,7 @@ async def main():
|
||||||
args = parse_arguments()
|
args = parse_arguments()
|
||||||
|
|
||||||
# Initialize config handler
|
# Initialize config handler
|
||||||
config_handler = ConfigHandler()
|
# config_handler = ConfigHandler()
|
||||||
|
|
||||||
# Determine config path
|
# Determine config path
|
||||||
if (
|
if (
|
||||||
|
|
|
||||||
|
|
@ -14,16 +14,16 @@ Requirements:
|
||||||
- Run from project root so example_trainer is on PYTHONPATH
|
- Run from project root so example_trainer is on PYTHONPATH
|
||||||
- example_trainer/ is a valid Python package (with __init__.py)
|
- example_trainer/ is a valid Python package (with __init__.py)
|
||||||
"""
|
"""
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import subprocess
|
|
||||||
import time
|
|
||||||
import atexit
|
import atexit
|
||||||
|
import os
|
||||||
import signal
|
import signal
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
# Ensure project root is on PYTHONPATH
|
# Ensure project root is on PYTHONPATH
|
||||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
|
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||||
if project_root not in sys.path:
|
if project_root not in sys.path:
|
||||||
sys.path.insert(0, project_root)
|
sys.path.insert(0, project_root)
|
||||||
|
|
||||||
|
|
@ -32,56 +32,73 @@ try:
|
||||||
from example_trainer.grpo import TrainingConfig, train
|
from example_trainer.grpo import TrainingConfig, train
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
print(f"Error importing example_trainer.grpo: {e}")
|
print(f"Error importing example_trainer.grpo: {e}")
|
||||||
print("Ensure you're running from project root and that example_trainer/ is a package.")
|
print(
|
||||||
|
"Ensure you're running from project root and that example_trainer/ is a package."
|
||||||
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Configuration
|
# Configuration
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
API_HOST = '127.0.0.1'
|
API_HOST = "127.0.0.1"
|
||||||
API_PORT = 8000
|
API_PORT = 8000
|
||||||
|
|
||||||
VLLM_HOST = '127.0.0.1'
|
VLLM_HOST = "127.0.0.1"
|
||||||
VLLM_PORT = 9001
|
VLLM_PORT = 9001
|
||||||
|
|
||||||
MODEL_NAME = 'Qwen/Qwen2.5-1.5B-Instruct'
|
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
|
||||||
TOKENIZER_NAME = MODEL_NAME
|
TOKENIZER_NAME = MODEL_NAME
|
||||||
|
|
||||||
TRAINER_CONFIG = {
|
TRAINER_CONFIG = {
|
||||||
'model_name': MODEL_NAME,
|
"model_name": MODEL_NAME,
|
||||||
'training_steps': 20,
|
"training_steps": 20,
|
||||||
'batch_size': 2,
|
"batch_size": 2,
|
||||||
'gradient_accumulation_steps': 2,
|
"gradient_accumulation_steps": 2,
|
||||||
'seq_len': 512,
|
"seq_len": 512,
|
||||||
'vllm_port': VLLM_PORT,
|
"vllm_port": VLLM_PORT,
|
||||||
'vllm_restart_interval': 10,
|
"vllm_restart_interval": 10,
|
||||||
'use_wandb': False,
|
"use_wandb": False,
|
||||||
'wandb_project': '',
|
"wandb_project": "",
|
||||||
'wandb_group': '',
|
"wandb_group": "",
|
||||||
'save_path': './trained_model_checkpoints_local_test',
|
"save_path": "./trained_model_checkpoints_local_test",
|
||||||
}
|
}
|
||||||
|
|
||||||
# Flags for launching DatasetEnv serve
|
# Flags for launching DatasetEnv serve
|
||||||
DATASET_FLAGS = [
|
DATASET_FLAGS = [
|
||||||
'--group_size', '4',
|
"--group_size",
|
||||||
'--max_num_workers', '2',
|
"4",
|
||||||
'--rollout_server_url', f"http://{API_HOST}:{API_PORT}",
|
"--max_num_workers",
|
||||||
'--tokenizer_name', TOKENIZER_NAME,
|
"2",
|
||||||
'--use_wandb',
|
"--rollout_server_url",
|
||||||
'--wandb_name', 'dataset_env_local_test',
|
f"http://{API_HOST}:{API_PORT}",
|
||||||
'--max_token_length', str(TRAINER_CONFIG['seq_len']),
|
"--tokenizer_name",
|
||||||
'--ensure_scores_are_not_same',
|
TOKENIZER_NAME,
|
||||||
'--dataset_name', 'HuggingFaceH4/testing_self_instruct_process_essays',
|
"--use_wandb",
|
||||||
'--split', 'train[:100]',
|
"--wandb_name",
|
||||||
'--prompt_field', 'prompt',
|
"dataset_env_local_test",
|
||||||
'--answer_field', 'answer',
|
"--max_token_length",
|
||||||
'--reward_functions', 'length',
|
str(TRAINER_CONFIG["seq_len"]),
|
||||||
'--max_tokens', '128',
|
"--ensure_scores_are_not_same",
|
||||||
'--temperature', '0.7',
|
"--dataset_name",
|
||||||
'--model_name', MODEL_NAME,
|
"HuggingFaceH4/testing_self_instruct_process_essays",
|
||||||
'--base_url', f"http://{VLLM_HOST}:{VLLM_PORT}",
|
"--split",
|
||||||
'--slurm',
|
"train[:100]",
|
||||||
'--testing',
|
"--prompt_field",
|
||||||
|
"prompt",
|
||||||
|
"--answer_field",
|
||||||
|
"answer",
|
||||||
|
"--reward_functions",
|
||||||
|
"length",
|
||||||
|
"--max_tokens",
|
||||||
|
"128",
|
||||||
|
"--temperature",
|
||||||
|
"0.7",
|
||||||
|
"--model_name",
|
||||||
|
MODEL_NAME,
|
||||||
|
"--base_url",
|
||||||
|
f"http://{VLLM_HOST}:{VLLM_PORT}",
|
||||||
|
"--slurm",
|
||||||
|
"--testing",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Track background processes for cleanup
|
# Track background processes for cleanup
|
||||||
|
|
@ -106,6 +123,7 @@ def cleanup_processes():
|
||||||
print(f"PID {p.pid} already exited.")
|
print(f"PID {p.pid} already exited.")
|
||||||
print("Cleanup complete.")
|
print("Cleanup complete.")
|
||||||
|
|
||||||
|
|
||||||
atexit.register(cleanup_processes)
|
atexit.register(cleanup_processes)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -113,6 +131,7 @@ def handle_signal(sig, frame):
|
||||||
print(f"\nSignal {sig} received; exiting.")
|
print(f"\nSignal {sig} received; exiting.")
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
|
|
||||||
signal.signal(signal.SIGINT, handle_signal)
|
signal.signal(signal.SIGINT, handle_signal)
|
||||||
signal.signal(signal.SIGTERM, handle_signal)
|
signal.signal(signal.SIGTERM, handle_signal)
|
||||||
|
|
||||||
|
|
@ -121,10 +140,12 @@ def main():
|
||||||
# 1) Start the API server
|
# 1) Start the API server
|
||||||
print("--- Starting Trajectory Handler API Server ---")
|
print("--- Starting Trajectory Handler API Server ---")
|
||||||
api_cmd = [
|
api_cmd = [
|
||||||
'uvicorn',
|
"uvicorn",
|
||||||
'atroposlib.api.server:app',
|
"atroposlib.api.server:app",
|
||||||
'--host', API_HOST,
|
"--host",
|
||||||
'--port', str(API_PORT),
|
API_HOST,
|
||||||
|
"--port",
|
||||||
|
str(API_PORT),
|
||||||
]
|
]
|
||||||
print(f"$ {' '.join(api_cmd)}")
|
print(f"$ {' '.join(api_cmd)}")
|
||||||
api_proc = subprocess.Popen(api_cmd)
|
api_proc = subprocess.Popen(api_cmd)
|
||||||
|
|
@ -133,7 +154,12 @@ def main():
|
||||||
|
|
||||||
# 2) Start the dataset environment
|
# 2) Start the dataset environment
|
||||||
print("\n--- Starting Dataset Environment ---")
|
print("\n--- Starting Dataset Environment ---")
|
||||||
env_cmd = ['python', '-m', 'environments.dataset_environment.dataset_env', 'serve'] + DATASET_FLAGS
|
env_cmd = [
|
||||||
|
"python",
|
||||||
|
"-m",
|
||||||
|
"environments.dataset_environment.dataset_env",
|
||||||
|
"serve",
|
||||||
|
] + DATASET_FLAGS
|
||||||
print(f"$ {' '.join(env_cmd)}")
|
print(f"$ {' '.join(env_cmd)}")
|
||||||
env_proc = subprocess.Popen(env_cmd)
|
env_proc = subprocess.Popen(env_cmd)
|
||||||
processes.append(env_proc)
|
processes.append(env_proc)
|
||||||
|
|
@ -150,5 +176,5 @@ def main():
|
||||||
print("--- Training complete ---")
|
print("--- Training complete ---")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,10 @@
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
from typing import List, Optional, Tuple, Union, Dict
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import wandb
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from tqdm.asyncio import tqdm_asyncio
|
from tqdm.asyncio import tqdm_asyncio
|
||||||
import wandb
|
|
||||||
|
|
||||||
from atroposlib.envs.base import (
|
from atroposlib.envs.base import (
|
||||||
BaseEnv,
|
BaseEnv,
|
||||||
|
|
@ -508,23 +508,32 @@ class FundamentalPredictionEnv(BaseEnv):
|
||||||
|
|
||||||
# Calculate and log training direction accuracy
|
# Calculate and log training direction accuracy
|
||||||
try:
|
try:
|
||||||
direction_accuracy = sum(self.percent_correct_buffer) / len(self.percent_correct_buffer)
|
direction_accuracy = sum(self.percent_correct_buffer) / len(
|
||||||
|
self.percent_correct_buffer
|
||||||
|
)
|
||||||
wandb_metrics["train/direction_accuracy"] = direction_accuracy
|
wandb_metrics["train/direction_accuracy"] = direction_accuracy
|
||||||
except ZeroDivisionError:
|
except ZeroDivisionError:
|
||||||
pass # Skip if buffer is empty
|
pass # Skip if buffer is empty
|
||||||
|
|
||||||
# Calculate and log training magnitude accuracy
|
# Calculate and log training magnitude accuracy
|
||||||
try:
|
try:
|
||||||
magnitude_accuracy = sum(self.magnitude_accuracy_buffer) / len(self.magnitude_accuracy_buffer)
|
magnitude_accuracy = sum(self.magnitude_accuracy_buffer) / len(
|
||||||
|
self.magnitude_accuracy_buffer
|
||||||
|
)
|
||||||
wandb_metrics["train/magnitude_accuracy"] = magnitude_accuracy
|
wandb_metrics["train/magnitude_accuracy"] = magnitude_accuracy
|
||||||
except ZeroDivisionError:
|
except ZeroDivisionError:
|
||||||
pass # Skip if buffer is empty
|
pass # Skip if buffer is empty
|
||||||
|
|
||||||
# Calculate combined training score (direction + magnitude)
|
# Calculate combined training score (direction + magnitude)
|
||||||
try:
|
try:
|
||||||
combined_score = direction_accuracy + magnitude_accuracy if 'direction_accuracy' in wandb_metrics else 0
|
combined_score = (
|
||||||
|
direction_accuracy + magnitude_accuracy
|
||||||
|
if "direction_accuracy" in wandb_metrics
|
||||||
|
else 0
|
||||||
|
)
|
||||||
wandb_metrics["train/combined_score"] = combined_score
|
wandb_metrics["train/combined_score"] = combined_score
|
||||||
except:
|
except Exception as e:
|
||||||
|
print(f"Error calculating combined score: {e}")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Clear the buffers after logging
|
# Clear the buffers after logging
|
||||||
|
|
@ -549,7 +558,7 @@ class FundamentalPredictionEnv(BaseEnv):
|
||||||
|
|
||||||
# Get number of examples to keep
|
# Get number of examples to keep
|
||||||
num_keep = getattr(self.config, "num_rollouts_per_group_for_logging", -1)
|
num_keep = getattr(self.config, "num_rollouts_per_group_for_logging", -1)
|
||||||
|
|
||||||
if num_keep == -1:
|
if num_keep == -1:
|
||||||
num_keep = self.config.group_size
|
num_keep = self.config.group_size
|
||||||
|
|
||||||
|
|
@ -577,23 +586,25 @@ class FundamentalPredictionEnv(BaseEnv):
|
||||||
|
|
||||||
async def create_rollout_table(self, wandb_metrics):
|
async def create_rollout_table(self, wandb_metrics):
|
||||||
if hasattr(self, "rollouts_for_wandb") and len(self.rollouts_for_wandb) > 0:
|
if hasattr(self, "rollouts_for_wandb") and len(self.rollouts_for_wandb) > 0:
|
||||||
table = wandb.Table(columns=[
|
table = wandb.Table(
|
||||||
"text",
|
columns=[
|
||||||
"score",
|
"text",
|
||||||
"expected_direction",
|
"score",
|
||||||
"expected_magnitude",
|
"expected_direction",
|
||||||
"fundamental_metric"
|
"expected_magnitude",
|
||||||
])
|
"fundamental_metric",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
for group in self.rollouts_for_wandb:
|
for group in self.rollouts_for_wandb:
|
||||||
for item in group:
|
for item in group:
|
||||||
table.add_data(item[0], item[1], item[2], item[3], item[4])
|
table.add_data(item[0], item[1], item[2], item[3], item[4])
|
||||||
|
|
||||||
wandb_metrics["train/rollouts"] = table
|
wandb_metrics["train/rollouts"] = table
|
||||||
|
|
||||||
# Clear rollouts after logging
|
# Clear rollouts after logging
|
||||||
self.rollouts_for_wandb = []
|
self.rollouts_for_wandb = []
|
||||||
|
|
||||||
return wandb_metrics
|
return wandb_metrics
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ from concurrent.futures import ProcessPoolExecutor
|
||||||
from difflib import SequenceMatcher
|
from difflib import SequenceMatcher
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import wandb
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from latex2sympy2_extended import NormalizationConfig
|
from latex2sympy2_extended import NormalizationConfig
|
||||||
from math_verify import LatexExtractionConfig, parse, verify
|
from math_verify import LatexExtractionConfig, parse, verify
|
||||||
|
|
@ -12,7 +13,6 @@ from math_verify.errors import TimeoutException
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from tqdm.asyncio import tqdm_asyncio
|
from tqdm.asyncio import tqdm_asyncio
|
||||||
|
|
||||||
import wandb
|
|
||||||
from atroposlib.envs.base import (
|
from atroposlib.envs.base import (
|
||||||
BaseEnv,
|
BaseEnv,
|
||||||
BaseEnvConfig,
|
BaseEnvConfig,
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,6 @@ from datasets import load_dataset
|
||||||
|
|
||||||
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
|
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
|
||||||
from atroposlib.type_definitions import GameHistory, Item
|
from atroposlib.type_definitions import GameHistory, Item
|
||||||
|
|
||||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -13,13 +13,12 @@ import numpy as np
|
||||||
import requests
|
import requests
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
import wandb # Added for logging
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||||
from torch.optim import AdamW
|
from torch.optim import AdamW
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
import wandb # Added for logging
|
|
||||||
|
|
||||||
# Global variable to keep track of the vLLM process
|
# Global variable to keep track of the vLLM process
|
||||||
vllm_process = None
|
vllm_process = None
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue