mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
run pre-commit on all files
This commit is contained in:
parent
b959c30ebf
commit
40b12dae60
17 changed files with 169 additions and 118 deletions
|
|
@ -1 +1 @@
|
|||
OPENAI_API_KEY=
|
||||
OPENAI_API_KEY=
|
||||
|
|
|
|||
2
.github/pull_request_template.md
vendored
2
.github/pull_request_template.md
vendored
|
|
@ -66,4 +66,4 @@
|
|||
- [ ] My changes generate no new warnings
|
||||
- [ ] New and existing unit tests pass locally with my changes
|
||||
- [ ] Docstrings added for all new public classes / functions
|
||||
- [ ] If .env vars required, did you add it to the .env.example in repo root?
|
||||
- [ ] If .env vars required, did you add it to the .env.example in repo root?
|
||||
|
|
|
|||
|
|
@ -41,4 +41,4 @@ Project maintainers are obligated to respect the privacy and security of the rep
|
|||
|
||||
This Code of Conduct is adapted from general open source community standards and GitHub's community guidelines.
|
||||
|
||||
Remember: Respect each other, collaborate constructively, and focus on making Atropos better for everyone.
|
||||
Remember: Respect each other, collaborate constructively, and focus on making Atropos better for everyone.
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ import argparse
|
|||
import time
|
||||
|
||||
import requests
|
||||
|
||||
import wandb
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -387,8 +387,10 @@ class BaseEnv(ABC):
|
|||
# Now register the env...
|
||||
while True:
|
||||
data = await self._register_env()
|
||||
if data['status'] != "success":
|
||||
logging.warning(f"Waiting to register the env due to status {data['status']}")
|
||||
if data["status"] != "success":
|
||||
logging.warning(
|
||||
f"Waiting to register the env due to status {data['status']}"
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
self.env_id = data["env_id"]
|
||||
|
|
|
|||
|
|
@ -181,4 +181,4 @@ class ConfigHandler:
|
|||
# Add slurm flag to config if running in a Slurm environment
|
||||
config["use_slurm"] = "SLURM_JOB_ID" in os.environ
|
||||
|
||||
return config
|
||||
return config
|
||||
|
|
|
|||
|
|
@ -152,4 +152,4 @@ server_configs:
|
|||
|
||||
If you encounter issues with reward functions, make sure they are properly registered in the registry.
|
||||
|
||||
For dataset-related issues, verify that the dataset exists on HuggingFace and that the specified fields exist in the dataset.
|
||||
For dataset-related issues, verify that the dataset exists on HuggingFace and that the specified fields exist in the dataset.
|
||||
|
|
|
|||
|
|
@ -49,4 +49,4 @@ dataset:
|
|||
server_configs:
|
||||
- model_name: "gpt-4.1-nano"
|
||||
api_key: ${OPENAI_API_KEY}
|
||||
timeout: 600
|
||||
timeout: 600
|
||||
|
|
|
|||
|
|
@ -70,4 +70,4 @@ dataset:
|
|||
|
||||
eval_dataset_name: "gsm8k"
|
||||
eval_dataset_config: "main"
|
||||
eval_split: "test"
|
||||
eval_split: "test"
|
||||
|
|
|
|||
|
|
@ -27,4 +27,4 @@ dataset:
|
|||
|
||||
max_tokens: 4096
|
||||
length_warmup_steps: 0
|
||||
min_tokens: 200
|
||||
min_tokens: 200
|
||||
|
|
|
|||
|
|
@ -32,7 +32,9 @@ class DatasetEnvConfig(BaseEnvConfig):
|
|||
None, description="Field in dataset containing canonical correct answer"
|
||||
)
|
||||
system_prompt: Optional[str] = Field(None, description="System prompt to use")
|
||||
prefill: Optional[str] = Field(None, description="Text to prefill the completion with (e.g. '<think>')")
|
||||
prefill: Optional[str] = Field(
|
||||
None, description="Text to prefill the completion with (e.g. '<think>')"
|
||||
)
|
||||
shuffle_dataset: bool = Field(True, description="Whether to shuffle the dataset")
|
||||
max_generations_per_prompt: int = Field(
|
||||
1, description="Number of generations per prompt for collection"
|
||||
|
|
@ -137,21 +139,21 @@ class DatasetEnv(BaseEnv):
|
|||
# Extract user prompt and answer from item
|
||||
user_content = dict(item[0][0])["content"]
|
||||
answer = item[1] if len(item) > 1 else None
|
||||
|
||||
|
||||
# Create messages list
|
||||
messages = []
|
||||
if self.config.system_prompt:
|
||||
messages.append({"role": "system", "content": self.config.system_prompt})
|
||||
|
||||
|
||||
messages.append({"role": "user", "content": user_content})
|
||||
|
||||
|
||||
# Add prefill as assistant message if configured
|
||||
if self.config.prefill:
|
||||
messages.append({"role": "assistant", "content": self.config.prefill})
|
||||
|
||||
|
||||
# Convert messages to a prompt string using the tokenizer
|
||||
prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
|
||||
|
||||
|
||||
# Calculate max tokens for generation (with optional warmup)
|
||||
max_tokens = self.config.max_tokens
|
||||
if self.config.length_warmup_steps > 0:
|
||||
|
|
@ -160,7 +162,7 @@ class DatasetEnv(BaseEnv):
|
|||
self.config.min_tokens
|
||||
+ warmup_progress * (self.config.max_tokens - self.config.min_tokens)
|
||||
)
|
||||
|
||||
|
||||
# Generate completion using completions API
|
||||
completions = await self.server.completion(
|
||||
prompt=prompt,
|
||||
|
|
@ -169,34 +171,38 @@ class DatasetEnv(BaseEnv):
|
|||
temperature=self.config.temperature,
|
||||
top_p=self.config.top_p,
|
||||
)
|
||||
|
||||
|
||||
to_score = []
|
||||
to_backlog = []
|
||||
|
||||
|
||||
# Process completions
|
||||
for completion in completions.choices:
|
||||
# Get the completion text
|
||||
completion_text = completion.text if hasattr(completion, "text") else completion.message.content
|
||||
|
||||
completion_text = (
|
||||
completion.text
|
||||
if hasattr(completion, "text")
|
||||
else completion.message.content
|
||||
)
|
||||
|
||||
# Build full message sequence for scoring
|
||||
full_messages = []
|
||||
if self.config.system_prompt:
|
||||
full_messages.append({"role": "system", "content": self.config.system_prompt})
|
||||
|
||||
full_messages.append(
|
||||
{"role": "system", "content": self.config.system_prompt}
|
||||
)
|
||||
|
||||
full_messages.append({"role": "user", "content": user_content})
|
||||
|
||||
|
||||
# Combine prefill with completion if prefill was used
|
||||
response_content = completion_text
|
||||
if self.config.prefill:
|
||||
response_content = self.config.prefill + completion_text
|
||||
|
||||
|
||||
full_messages.append({"role": "assistant", "content": response_content})
|
||||
|
||||
|
||||
# Add to scoring list with answer and ground truth
|
||||
to_score.append(
|
||||
(full_messages, answer, item[2] if len(item) > 2 else None)
|
||||
)
|
||||
|
||||
to_score.append((full_messages, answer, item[2] if len(item) > 2 else None))
|
||||
|
||||
return to_score, to_backlog
|
||||
|
||||
async def postprocess_histories(self, trajectories: List) -> Tuple[List, List]:
|
||||
|
|
@ -204,27 +210,27 @@ class DatasetEnv(BaseEnv):
|
|||
|
||||
async def collect_trajectories(self, item: Item) -> Tuple[List, List]:
|
||||
self.current_item = item
|
||||
|
||||
|
||||
# Extract user prompt from item
|
||||
user_content = dict(item[0][0])["content"]
|
||||
|
||||
|
||||
# Create messages list
|
||||
messages = []
|
||||
if self.config.system_prompt:
|
||||
messages.append({"role": "system", "content": self.config.system_prompt})
|
||||
|
||||
|
||||
messages.append({"role": "user", "content": user_content})
|
||||
|
||||
|
||||
# Add prefill as assistant message if configured
|
||||
if self.config.prefill:
|
||||
messages.append({"role": "assistant", "content": self.config.prefill})
|
||||
|
||||
|
||||
# Convert messages to a prompt string using the tokenizer
|
||||
prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
|
||||
|
||||
|
||||
# Calculate max tokens for generation (with optional warmup)
|
||||
max_tokens = self.config.max_tokens
|
||||
|
||||
|
||||
# Generate completions
|
||||
completions = await self.server.completion(
|
||||
prompt=prompt,
|
||||
|
|
@ -233,30 +239,36 @@ class DatasetEnv(BaseEnv):
|
|||
temperature=self.config.temperature,
|
||||
top_p=self.config.top_p,
|
||||
)
|
||||
|
||||
|
||||
print(f"Completions: {completions}")
|
||||
# Process completions
|
||||
trajectories = []
|
||||
for completion in completions.choices:
|
||||
# Get the completion text
|
||||
completion_text = completion.text if hasattr(completion, "text") else completion.message.content
|
||||
|
||||
completion_text = (
|
||||
completion.text
|
||||
if hasattr(completion, "text")
|
||||
else completion.message.content
|
||||
)
|
||||
|
||||
# Build complete message sequence
|
||||
full_messages = []
|
||||
if self.config.system_prompt:
|
||||
full_messages.append({"role": "system", "content": self.config.system_prompt})
|
||||
|
||||
full_messages.append(
|
||||
{"role": "system", "content": self.config.system_prompt}
|
||||
)
|
||||
|
||||
full_messages.append({"role": "user", "content": user_content})
|
||||
|
||||
|
||||
# Combine prefill with completion if prefill was used
|
||||
response_content = completion_text
|
||||
if self.config.prefill:
|
||||
response_content = self.config.prefill + completion_text
|
||||
|
||||
|
||||
full_messages.append({"role": "assistant", "content": response_content})
|
||||
|
||||
|
||||
trajectories.append(full_messages)
|
||||
|
||||
|
||||
return trajectories, []
|
||||
|
||||
async def score(self, rollout_group_data: List) -> Optional[ScoredDataGroup]:
|
||||
|
|
@ -402,6 +414,7 @@ class DatasetEnv(BaseEnv):
|
|||
|
||||
await super().wandb_log(metrics)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Launch the DatasetEnv via the BaseEnv CLI (serve or process)
|
||||
DatasetEnv.cli()
|
||||
|
|
|
|||
|
|
@ -8,7 +8,8 @@ from dotenv import load_dotenv
|
|||
|
||||
from atroposlib.envs.base import OpenaiConfig
|
||||
from atroposlib.envs.reward_fns import registry
|
||||
from atroposlib.utils.config_handler import ConfigHandler
|
||||
|
||||
# from atroposlib.utils.config_handler import ConfigHandler
|
||||
from environments.dataset_environment.dataset_env import DatasetEnv, DatasetEnvConfig
|
||||
|
||||
load_dotenv()
|
||||
|
|
@ -23,7 +24,8 @@ def parse_arguments():
|
|||
"--config",
|
||||
type=str,
|
||||
default="dataset_local",
|
||||
help="Configuration file name (without .yaml extension) relative to environments/dataset_environment/configs/, or full path to a YAML file.",
|
||||
help="Configuration file name (without .yaml extension) relative to environments/dataset_environment/configs/,"
|
||||
" or full path to a YAML file.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
|
@ -35,7 +37,7 @@ async def main():
|
|||
args = parse_arguments()
|
||||
|
||||
# Initialize config handler
|
||||
config_handler = ConfigHandler()
|
||||
# config_handler = ConfigHandler()
|
||||
|
||||
# Determine config path
|
||||
if (
|
||||
|
|
|
|||
|
|
@ -14,16 +14,16 @@ Requirements:
|
|||
- Run from project root so example_trainer is on PYTHONPATH
|
||||
- example_trainer/ is a valid Python package (with __init__.py)
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
import time
|
||||
import atexit
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
|
||||
# Ensure project root is on PYTHONPATH
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
if project_root not in sys.path:
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
|
|
@ -32,56 +32,73 @@ try:
|
|||
from example_trainer.grpo import TrainingConfig, train
|
||||
except ImportError as e:
|
||||
print(f"Error importing example_trainer.grpo: {e}")
|
||||
print("Ensure you're running from project root and that example_trainer/ is a package.")
|
||||
print(
|
||||
"Ensure you're running from project root and that example_trainer/ is a package."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Configuration
|
||||
# -----------------------------------------------------------------------------
|
||||
API_HOST = '127.0.0.1'
|
||||
API_HOST = "127.0.0.1"
|
||||
API_PORT = 8000
|
||||
|
||||
VLLM_HOST = '127.0.0.1'
|
||||
VLLM_HOST = "127.0.0.1"
|
||||
VLLM_PORT = 9001
|
||||
|
||||
MODEL_NAME = 'Qwen/Qwen2.5-1.5B-Instruct'
|
||||
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
TOKENIZER_NAME = MODEL_NAME
|
||||
|
||||
TRAINER_CONFIG = {
|
||||
'model_name': MODEL_NAME,
|
||||
'training_steps': 20,
|
||||
'batch_size': 2,
|
||||
'gradient_accumulation_steps': 2,
|
||||
'seq_len': 512,
|
||||
'vllm_port': VLLM_PORT,
|
||||
'vllm_restart_interval': 10,
|
||||
'use_wandb': False,
|
||||
'wandb_project': '',
|
||||
'wandb_group': '',
|
||||
'save_path': './trained_model_checkpoints_local_test',
|
||||
"model_name": MODEL_NAME,
|
||||
"training_steps": 20,
|
||||
"batch_size": 2,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"seq_len": 512,
|
||||
"vllm_port": VLLM_PORT,
|
||||
"vllm_restart_interval": 10,
|
||||
"use_wandb": False,
|
||||
"wandb_project": "",
|
||||
"wandb_group": "",
|
||||
"save_path": "./trained_model_checkpoints_local_test",
|
||||
}
|
||||
|
||||
# Flags for launching DatasetEnv serve
|
||||
DATASET_FLAGS = [
|
||||
'--group_size', '4',
|
||||
'--max_num_workers', '2',
|
||||
'--rollout_server_url', f"http://{API_HOST}:{API_PORT}",
|
||||
'--tokenizer_name', TOKENIZER_NAME,
|
||||
'--use_wandb',
|
||||
'--wandb_name', 'dataset_env_local_test',
|
||||
'--max_token_length', str(TRAINER_CONFIG['seq_len']),
|
||||
'--ensure_scores_are_not_same',
|
||||
'--dataset_name', 'HuggingFaceH4/testing_self_instruct_process_essays',
|
||||
'--split', 'train[:100]',
|
||||
'--prompt_field', 'prompt',
|
||||
'--answer_field', 'answer',
|
||||
'--reward_functions', 'length',
|
||||
'--max_tokens', '128',
|
||||
'--temperature', '0.7',
|
||||
'--model_name', MODEL_NAME,
|
||||
'--base_url', f"http://{VLLM_HOST}:{VLLM_PORT}",
|
||||
'--slurm',
|
||||
'--testing',
|
||||
"--group_size",
|
||||
"4",
|
||||
"--max_num_workers",
|
||||
"2",
|
||||
"--rollout_server_url",
|
||||
f"http://{API_HOST}:{API_PORT}",
|
||||
"--tokenizer_name",
|
||||
TOKENIZER_NAME,
|
||||
"--use_wandb",
|
||||
"--wandb_name",
|
||||
"dataset_env_local_test",
|
||||
"--max_token_length",
|
||||
str(TRAINER_CONFIG["seq_len"]),
|
||||
"--ensure_scores_are_not_same",
|
||||
"--dataset_name",
|
||||
"HuggingFaceH4/testing_self_instruct_process_essays",
|
||||
"--split",
|
||||
"train[:100]",
|
||||
"--prompt_field",
|
||||
"prompt",
|
||||
"--answer_field",
|
||||
"answer",
|
||||
"--reward_functions",
|
||||
"length",
|
||||
"--max_tokens",
|
||||
"128",
|
||||
"--temperature",
|
||||
"0.7",
|
||||
"--model_name",
|
||||
MODEL_NAME,
|
||||
"--base_url",
|
||||
f"http://{VLLM_HOST}:{VLLM_PORT}",
|
||||
"--slurm",
|
||||
"--testing",
|
||||
]
|
||||
|
||||
# Track background processes for cleanup
|
||||
|
|
@ -106,6 +123,7 @@ def cleanup_processes():
|
|||
print(f"PID {p.pid} already exited.")
|
||||
print("Cleanup complete.")
|
||||
|
||||
|
||||
atexit.register(cleanup_processes)
|
||||
|
||||
|
||||
|
|
@ -113,6 +131,7 @@ def handle_signal(sig, frame):
|
|||
print(f"\nSignal {sig} received; exiting.")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
signal.signal(signal.SIGINT, handle_signal)
|
||||
signal.signal(signal.SIGTERM, handle_signal)
|
||||
|
||||
|
|
@ -121,10 +140,12 @@ def main():
|
|||
# 1) Start the API server
|
||||
print("--- Starting Trajectory Handler API Server ---")
|
||||
api_cmd = [
|
||||
'uvicorn',
|
||||
'atroposlib.api.server:app',
|
||||
'--host', API_HOST,
|
||||
'--port', str(API_PORT),
|
||||
"uvicorn",
|
||||
"atroposlib.api.server:app",
|
||||
"--host",
|
||||
API_HOST,
|
||||
"--port",
|
||||
str(API_PORT),
|
||||
]
|
||||
print(f"$ {' '.join(api_cmd)}")
|
||||
api_proc = subprocess.Popen(api_cmd)
|
||||
|
|
@ -133,7 +154,12 @@ def main():
|
|||
|
||||
# 2) Start the dataset environment
|
||||
print("\n--- Starting Dataset Environment ---")
|
||||
env_cmd = ['python', '-m', 'environments.dataset_environment.dataset_env', 'serve'] + DATASET_FLAGS
|
||||
env_cmd = [
|
||||
"python",
|
||||
"-m",
|
||||
"environments.dataset_environment.dataset_env",
|
||||
"serve",
|
||||
] + DATASET_FLAGS
|
||||
print(f"$ {' '.join(env_cmd)}")
|
||||
env_proc = subprocess.Popen(env_cmd)
|
||||
processes.append(env_proc)
|
||||
|
|
@ -150,5 +176,5 @@ def main():
|
|||
print("--- Training complete ---")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
import random
|
||||
import re
|
||||
from typing import List, Optional, Tuple, Union, Dict
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import wandb
|
||||
from datasets import load_dataset
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
import wandb
|
||||
|
||||
from atroposlib.envs.base import (
|
||||
BaseEnv,
|
||||
|
|
@ -508,23 +508,32 @@ class FundamentalPredictionEnv(BaseEnv):
|
|||
|
||||
# Calculate and log training direction accuracy
|
||||
try:
|
||||
direction_accuracy = sum(self.percent_correct_buffer) / len(self.percent_correct_buffer)
|
||||
direction_accuracy = sum(self.percent_correct_buffer) / len(
|
||||
self.percent_correct_buffer
|
||||
)
|
||||
wandb_metrics["train/direction_accuracy"] = direction_accuracy
|
||||
except ZeroDivisionError:
|
||||
pass # Skip if buffer is empty
|
||||
|
||||
# Calculate and log training magnitude accuracy
|
||||
try:
|
||||
magnitude_accuracy = sum(self.magnitude_accuracy_buffer) / len(self.magnitude_accuracy_buffer)
|
||||
magnitude_accuracy = sum(self.magnitude_accuracy_buffer) / len(
|
||||
self.magnitude_accuracy_buffer
|
||||
)
|
||||
wandb_metrics["train/magnitude_accuracy"] = magnitude_accuracy
|
||||
except ZeroDivisionError:
|
||||
pass # Skip if buffer is empty
|
||||
|
||||
# Calculate combined training score (direction + magnitude)
|
||||
try:
|
||||
combined_score = direction_accuracy + magnitude_accuracy if 'direction_accuracy' in wandb_metrics else 0
|
||||
combined_score = (
|
||||
direction_accuracy + magnitude_accuracy
|
||||
if "direction_accuracy" in wandb_metrics
|
||||
else 0
|
||||
)
|
||||
wandb_metrics["train/combined_score"] = combined_score
|
||||
except:
|
||||
except Exception as e:
|
||||
print(f"Error calculating combined score: {e}")
|
||||
pass
|
||||
|
||||
# Clear the buffers after logging
|
||||
|
|
@ -549,7 +558,7 @@ class FundamentalPredictionEnv(BaseEnv):
|
|||
|
||||
# Get number of examples to keep
|
||||
num_keep = getattr(self.config, "num_rollouts_per_group_for_logging", -1)
|
||||
|
||||
|
||||
if num_keep == -1:
|
||||
num_keep = self.config.group_size
|
||||
|
||||
|
|
@ -577,23 +586,25 @@ class FundamentalPredictionEnv(BaseEnv):
|
|||
|
||||
async def create_rollout_table(self, wandb_metrics):
|
||||
if hasattr(self, "rollouts_for_wandb") and len(self.rollouts_for_wandb) > 0:
|
||||
table = wandb.Table(columns=[
|
||||
"text",
|
||||
"score",
|
||||
"expected_direction",
|
||||
"expected_magnitude",
|
||||
"fundamental_metric"
|
||||
])
|
||||
|
||||
table = wandb.Table(
|
||||
columns=[
|
||||
"text",
|
||||
"score",
|
||||
"expected_direction",
|
||||
"expected_magnitude",
|
||||
"fundamental_metric",
|
||||
]
|
||||
)
|
||||
|
||||
for group in self.rollouts_for_wandb:
|
||||
for item in group:
|
||||
table.add_data(item[0], item[1], item[2], item[3], item[4])
|
||||
|
||||
|
||||
wandb_metrics["train/rollouts"] = table
|
||||
|
||||
|
||||
# Clear rollouts after logging
|
||||
self.rollouts_for_wandb = []
|
||||
|
||||
|
||||
return wandb_metrics
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from concurrent.futures import ProcessPoolExecutor
|
|||
from difflib import SequenceMatcher
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import wandb
|
||||
from datasets import load_dataset
|
||||
from latex2sympy2_extended import NormalizationConfig
|
||||
from math_verify import LatexExtractionConfig, parse, verify
|
||||
|
|
@ -12,7 +13,6 @@ from math_verify.errors import TimeoutException
|
|||
from pydantic import Field
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
|
||||
import wandb
|
||||
from atroposlib.envs.base import (
|
||||
BaseEnv,
|
||||
BaseEnvConfig,
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ from datasets import load_dataset
|
|||
|
||||
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
|
||||
from atroposlib.type_definitions import GameHistory, Item
|
||||
|
||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -13,13 +13,12 @@ import numpy as np
|
|||
import requests
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import wandb # Added for logging
|
||||
from pydantic import BaseModel, Field
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
from torch.optim import AdamW
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
import wandb # Added for logging
|
||||
|
||||
# Global variable to keep track of the vLLM process
|
||||
vllm_process = None
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue