run pre-commit on all files

This commit is contained in:
dmahan93 2025-05-09 09:54:20 -05:00
parent b959c30ebf
commit 40b12dae60
17 changed files with 169 additions and 118 deletions

View file

@ -1 +1 @@
OPENAI_API_KEY= OPENAI_API_KEY=

View file

@ -66,4 +66,4 @@
- [ ] My changes generate no new warnings - [ ] My changes generate no new warnings
- [ ] New and existing unit tests pass locally with my changes - [ ] New and existing unit tests pass locally with my changes
- [ ] Docstrings added for all new public classes / functions - [ ] Docstrings added for all new public classes / functions
- [ ] If .env vars required, did you add it to the .env.example in repo root? - [ ] If .env vars required, did you add it to the .env.example in repo root?

View file

@ -41,4 +41,4 @@ Project maintainers are obligated to respect the privacy and security of the rep
This Code of Conduct is adapted from general open source community standards and GitHub's community guidelines. This Code of Conduct is adapted from general open source community standards and GitHub's community guidelines.
Remember: Respect each other, collaborate constructively, and focus on making Atropos better for everyone. Remember: Respect each other, collaborate constructively, and focus on making Atropos better for everyone.

View file

@ -2,7 +2,6 @@ import argparse
import time import time
import requests import requests
import wandb import wandb

View file

@ -387,8 +387,10 @@ class BaseEnv(ABC):
# Now register the env... # Now register the env...
while True: while True:
data = await self._register_env() data = await self._register_env()
if data['status'] != "success": if data["status"] != "success":
logging.warning(f"Waiting to register the env due to status {data['status']}") logging.warning(
f"Waiting to register the env due to status {data['status']}"
)
await asyncio.sleep(1) await asyncio.sleep(1)
continue continue
self.env_id = data["env_id"] self.env_id = data["env_id"]

View file

@ -181,4 +181,4 @@ class ConfigHandler:
# Add slurm flag to config if running in a Slurm environment # Add slurm flag to config if running in a Slurm environment
config["use_slurm"] = "SLURM_JOB_ID" in os.environ config["use_slurm"] = "SLURM_JOB_ID" in os.environ
return config return config

View file

@ -152,4 +152,4 @@ server_configs:
If you encounter issues with reward functions, make sure they are properly registered in the registry. If you encounter issues with reward functions, make sure they are properly registered in the registry.
For dataset-related issues, verify that the dataset exists on HuggingFace and that the specified fields exist in the dataset. For dataset-related issues, verify that the dataset exists on HuggingFace and that the specified fields exist in the dataset.

View file

@ -49,4 +49,4 @@ dataset:
server_configs: server_configs:
- model_name: "gpt-4.1-nano" - model_name: "gpt-4.1-nano"
api_key: ${OPENAI_API_KEY} api_key: ${OPENAI_API_KEY}
timeout: 600 timeout: 600

View file

@ -70,4 +70,4 @@ dataset:
eval_dataset_name: "gsm8k" eval_dataset_name: "gsm8k"
eval_dataset_config: "main" eval_dataset_config: "main"
eval_split: "test" eval_split: "test"

View file

@ -27,4 +27,4 @@ dataset:
max_tokens: 4096 max_tokens: 4096
length_warmup_steps: 0 length_warmup_steps: 0
min_tokens: 200 min_tokens: 200

View file

@ -32,7 +32,9 @@ class DatasetEnvConfig(BaseEnvConfig):
None, description="Field in dataset containing canonical correct answer" None, description="Field in dataset containing canonical correct answer"
) )
system_prompt: Optional[str] = Field(None, description="System prompt to use") system_prompt: Optional[str] = Field(None, description="System prompt to use")
prefill: Optional[str] = Field(None, description="Text to prefill the completion with (e.g. '<think>')") prefill: Optional[str] = Field(
None, description="Text to prefill the completion with (e.g. '<think>')"
)
shuffle_dataset: bool = Field(True, description="Whether to shuffle the dataset") shuffle_dataset: bool = Field(True, description="Whether to shuffle the dataset")
max_generations_per_prompt: int = Field( max_generations_per_prompt: int = Field(
1, description="Number of generations per prompt for collection" 1, description="Number of generations per prompt for collection"
@ -137,21 +139,21 @@ class DatasetEnv(BaseEnv):
# Extract user prompt and answer from item # Extract user prompt and answer from item
user_content = dict(item[0][0])["content"] user_content = dict(item[0][0])["content"]
answer = item[1] if len(item) > 1 else None answer = item[1] if len(item) > 1 else None
# Create messages list # Create messages list
messages = [] messages = []
if self.config.system_prompt: if self.config.system_prompt:
messages.append({"role": "system", "content": self.config.system_prompt}) messages.append({"role": "system", "content": self.config.system_prompt})
messages.append({"role": "user", "content": user_content}) messages.append({"role": "user", "content": user_content})
# Add prefill as assistant message if configured # Add prefill as assistant message if configured
if self.config.prefill: if self.config.prefill:
messages.append({"role": "assistant", "content": self.config.prefill}) messages.append({"role": "assistant", "content": self.config.prefill})
# Convert messages to a prompt string using the tokenizer # Convert messages to a prompt string using the tokenizer
prompt = self.tokenizer.apply_chat_template(messages, tokenize=False) prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
# Calculate max tokens for generation (with optional warmup) # Calculate max tokens for generation (with optional warmup)
max_tokens = self.config.max_tokens max_tokens = self.config.max_tokens
if self.config.length_warmup_steps > 0: if self.config.length_warmup_steps > 0:
@ -160,7 +162,7 @@ class DatasetEnv(BaseEnv):
self.config.min_tokens self.config.min_tokens
+ warmup_progress * (self.config.max_tokens - self.config.min_tokens) + warmup_progress * (self.config.max_tokens - self.config.min_tokens)
) )
# Generate completion using completions API # Generate completion using completions API
completions = await self.server.completion( completions = await self.server.completion(
prompt=prompt, prompt=prompt,
@ -169,34 +171,38 @@ class DatasetEnv(BaseEnv):
temperature=self.config.temperature, temperature=self.config.temperature,
top_p=self.config.top_p, top_p=self.config.top_p,
) )
to_score = [] to_score = []
to_backlog = [] to_backlog = []
# Process completions # Process completions
for completion in completions.choices: for completion in completions.choices:
# Get the completion text # Get the completion text
completion_text = completion.text if hasattr(completion, "text") else completion.message.content completion_text = (
completion.text
if hasattr(completion, "text")
else completion.message.content
)
# Build full message sequence for scoring # Build full message sequence for scoring
full_messages = [] full_messages = []
if self.config.system_prompt: if self.config.system_prompt:
full_messages.append({"role": "system", "content": self.config.system_prompt}) full_messages.append(
{"role": "system", "content": self.config.system_prompt}
)
full_messages.append({"role": "user", "content": user_content}) full_messages.append({"role": "user", "content": user_content})
# Combine prefill with completion if prefill was used # Combine prefill with completion if prefill was used
response_content = completion_text response_content = completion_text
if self.config.prefill: if self.config.prefill:
response_content = self.config.prefill + completion_text response_content = self.config.prefill + completion_text
full_messages.append({"role": "assistant", "content": response_content}) full_messages.append({"role": "assistant", "content": response_content})
# Add to scoring list with answer and ground truth # Add to scoring list with answer and ground truth
to_score.append( to_score.append((full_messages, answer, item[2] if len(item) > 2 else None))
(full_messages, answer, item[2] if len(item) > 2 else None)
)
return to_score, to_backlog return to_score, to_backlog
async def postprocess_histories(self, trajectories: List) -> Tuple[List, List]: async def postprocess_histories(self, trajectories: List) -> Tuple[List, List]:
@ -204,27 +210,27 @@ class DatasetEnv(BaseEnv):
async def collect_trajectories(self, item: Item) -> Tuple[List, List]: async def collect_trajectories(self, item: Item) -> Tuple[List, List]:
self.current_item = item self.current_item = item
# Extract user prompt from item # Extract user prompt from item
user_content = dict(item[0][0])["content"] user_content = dict(item[0][0])["content"]
# Create messages list # Create messages list
messages = [] messages = []
if self.config.system_prompt: if self.config.system_prompt:
messages.append({"role": "system", "content": self.config.system_prompt}) messages.append({"role": "system", "content": self.config.system_prompt})
messages.append({"role": "user", "content": user_content}) messages.append({"role": "user", "content": user_content})
# Add prefill as assistant message if configured # Add prefill as assistant message if configured
if self.config.prefill: if self.config.prefill:
messages.append({"role": "assistant", "content": self.config.prefill}) messages.append({"role": "assistant", "content": self.config.prefill})
# Convert messages to a prompt string using the tokenizer # Convert messages to a prompt string using the tokenizer
prompt = self.tokenizer.apply_chat_template(messages, tokenize=False) prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
# Calculate max tokens for generation (with optional warmup) # Calculate max tokens for generation (with optional warmup)
max_tokens = self.config.max_tokens max_tokens = self.config.max_tokens
# Generate completions # Generate completions
completions = await self.server.completion( completions = await self.server.completion(
prompt=prompt, prompt=prompt,
@ -233,30 +239,36 @@ class DatasetEnv(BaseEnv):
temperature=self.config.temperature, temperature=self.config.temperature,
top_p=self.config.top_p, top_p=self.config.top_p,
) )
print(f"Completions: {completions}") print(f"Completions: {completions}")
# Process completions # Process completions
trajectories = [] trajectories = []
for completion in completions.choices: for completion in completions.choices:
# Get the completion text # Get the completion text
completion_text = completion.text if hasattr(completion, "text") else completion.message.content completion_text = (
completion.text
if hasattr(completion, "text")
else completion.message.content
)
# Build complete message sequence # Build complete message sequence
full_messages = [] full_messages = []
if self.config.system_prompt: if self.config.system_prompt:
full_messages.append({"role": "system", "content": self.config.system_prompt}) full_messages.append(
{"role": "system", "content": self.config.system_prompt}
)
full_messages.append({"role": "user", "content": user_content}) full_messages.append({"role": "user", "content": user_content})
# Combine prefill with completion if prefill was used # Combine prefill with completion if prefill was used
response_content = completion_text response_content = completion_text
if self.config.prefill: if self.config.prefill:
response_content = self.config.prefill + completion_text response_content = self.config.prefill + completion_text
full_messages.append({"role": "assistant", "content": response_content}) full_messages.append({"role": "assistant", "content": response_content})
trajectories.append(full_messages) trajectories.append(full_messages)
return trajectories, [] return trajectories, []
async def score(self, rollout_group_data: List) -> Optional[ScoredDataGroup]: async def score(self, rollout_group_data: List) -> Optional[ScoredDataGroup]:
@ -402,6 +414,7 @@ class DatasetEnv(BaseEnv):
await super().wandb_log(metrics) await super().wandb_log(metrics)
if __name__ == "__main__": if __name__ == "__main__":
# Launch the DatasetEnv via the BaseEnv CLI (serve or process) # Launch the DatasetEnv via the BaseEnv CLI (serve or process)
DatasetEnv.cli() DatasetEnv.cli()

View file

@ -8,7 +8,8 @@ from dotenv import load_dotenv
from atroposlib.envs.base import OpenaiConfig from atroposlib.envs.base import OpenaiConfig
from atroposlib.envs.reward_fns import registry from atroposlib.envs.reward_fns import registry
from atroposlib.utils.config_handler import ConfigHandler
# from atroposlib.utils.config_handler import ConfigHandler
from environments.dataset_environment.dataset_env import DatasetEnv, DatasetEnvConfig from environments.dataset_environment.dataset_env import DatasetEnv, DatasetEnvConfig
load_dotenv() load_dotenv()
@ -23,7 +24,8 @@ def parse_arguments():
"--config", "--config",
type=str, type=str,
default="dataset_local", default="dataset_local",
help="Configuration file name (without .yaml extension) relative to environments/dataset_environment/configs/, or full path to a YAML file.", help="Configuration file name (without .yaml extension) relative to environments/dataset_environment/configs/,"
" or full path to a YAML file.",
) )
return parser.parse_args() return parser.parse_args()
@ -35,7 +37,7 @@ async def main():
args = parse_arguments() args = parse_arguments()
# Initialize config handler # Initialize config handler
config_handler = ConfigHandler() # config_handler = ConfigHandler()
# Determine config path # Determine config path
if ( if (

View file

@ -14,16 +14,16 @@ Requirements:
- Run from project root so example_trainer is on PYTHONPATH - Run from project root so example_trainer is on PYTHONPATH
- example_trainer/ is a valid Python package (with __init__.py) - example_trainer/ is a valid Python package (with __init__.py)
""" """
import os
import sys
import subprocess
import time
import atexit import atexit
import os
import signal import signal
import subprocess
import sys
import time
import traceback import traceback
# Ensure project root is on PYTHONPATH # Ensure project root is on PYTHONPATH
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')) project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
if project_root not in sys.path: if project_root not in sys.path:
sys.path.insert(0, project_root) sys.path.insert(0, project_root)
@ -32,56 +32,73 @@ try:
from example_trainer.grpo import TrainingConfig, train from example_trainer.grpo import TrainingConfig, train
except ImportError as e: except ImportError as e:
print(f"Error importing example_trainer.grpo: {e}") print(f"Error importing example_trainer.grpo: {e}")
print("Ensure you're running from project root and that example_trainer/ is a package.") print(
"Ensure you're running from project root and that example_trainer/ is a package."
)
sys.exit(1) sys.exit(1)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Configuration # Configuration
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
API_HOST = '127.0.0.1' API_HOST = "127.0.0.1"
API_PORT = 8000 API_PORT = 8000
VLLM_HOST = '127.0.0.1' VLLM_HOST = "127.0.0.1"
VLLM_PORT = 9001 VLLM_PORT = 9001
MODEL_NAME = 'Qwen/Qwen2.5-1.5B-Instruct' MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
TOKENIZER_NAME = MODEL_NAME TOKENIZER_NAME = MODEL_NAME
TRAINER_CONFIG = { TRAINER_CONFIG = {
'model_name': MODEL_NAME, "model_name": MODEL_NAME,
'training_steps': 20, "training_steps": 20,
'batch_size': 2, "batch_size": 2,
'gradient_accumulation_steps': 2, "gradient_accumulation_steps": 2,
'seq_len': 512, "seq_len": 512,
'vllm_port': VLLM_PORT, "vllm_port": VLLM_PORT,
'vllm_restart_interval': 10, "vllm_restart_interval": 10,
'use_wandb': False, "use_wandb": False,
'wandb_project': '', "wandb_project": "",
'wandb_group': '', "wandb_group": "",
'save_path': './trained_model_checkpoints_local_test', "save_path": "./trained_model_checkpoints_local_test",
} }
# Flags for launching DatasetEnv serve # Flags for launching DatasetEnv serve
DATASET_FLAGS = [ DATASET_FLAGS = [
'--group_size', '4', "--group_size",
'--max_num_workers', '2', "4",
'--rollout_server_url', f"http://{API_HOST}:{API_PORT}", "--max_num_workers",
'--tokenizer_name', TOKENIZER_NAME, "2",
'--use_wandb', "--rollout_server_url",
'--wandb_name', 'dataset_env_local_test', f"http://{API_HOST}:{API_PORT}",
'--max_token_length', str(TRAINER_CONFIG['seq_len']), "--tokenizer_name",
'--ensure_scores_are_not_same', TOKENIZER_NAME,
'--dataset_name', 'HuggingFaceH4/testing_self_instruct_process_essays', "--use_wandb",
'--split', 'train[:100]', "--wandb_name",
'--prompt_field', 'prompt', "dataset_env_local_test",
'--answer_field', 'answer', "--max_token_length",
'--reward_functions', 'length', str(TRAINER_CONFIG["seq_len"]),
'--max_tokens', '128', "--ensure_scores_are_not_same",
'--temperature', '0.7', "--dataset_name",
'--model_name', MODEL_NAME, "HuggingFaceH4/testing_self_instruct_process_essays",
'--base_url', f"http://{VLLM_HOST}:{VLLM_PORT}", "--split",
'--slurm', "train[:100]",
'--testing', "--prompt_field",
"prompt",
"--answer_field",
"answer",
"--reward_functions",
"length",
"--max_tokens",
"128",
"--temperature",
"0.7",
"--model_name",
MODEL_NAME,
"--base_url",
f"http://{VLLM_HOST}:{VLLM_PORT}",
"--slurm",
"--testing",
] ]
# Track background processes for cleanup # Track background processes for cleanup
@ -106,6 +123,7 @@ def cleanup_processes():
print(f"PID {p.pid} already exited.") print(f"PID {p.pid} already exited.")
print("Cleanup complete.") print("Cleanup complete.")
atexit.register(cleanup_processes) atexit.register(cleanup_processes)
@ -113,6 +131,7 @@ def handle_signal(sig, frame):
print(f"\nSignal {sig} received; exiting.") print(f"\nSignal {sig} received; exiting.")
sys.exit(0) sys.exit(0)
signal.signal(signal.SIGINT, handle_signal) signal.signal(signal.SIGINT, handle_signal)
signal.signal(signal.SIGTERM, handle_signal) signal.signal(signal.SIGTERM, handle_signal)
@ -121,10 +140,12 @@ def main():
# 1) Start the API server # 1) Start the API server
print("--- Starting Trajectory Handler API Server ---") print("--- Starting Trajectory Handler API Server ---")
api_cmd = [ api_cmd = [
'uvicorn', "uvicorn",
'atroposlib.api.server:app', "atroposlib.api.server:app",
'--host', API_HOST, "--host",
'--port', str(API_PORT), API_HOST,
"--port",
str(API_PORT),
] ]
print(f"$ {' '.join(api_cmd)}") print(f"$ {' '.join(api_cmd)}")
api_proc = subprocess.Popen(api_cmd) api_proc = subprocess.Popen(api_cmd)
@ -133,7 +154,12 @@ def main():
# 2) Start the dataset environment # 2) Start the dataset environment
print("\n--- Starting Dataset Environment ---") print("\n--- Starting Dataset Environment ---")
env_cmd = ['python', '-m', 'environments.dataset_environment.dataset_env', 'serve'] + DATASET_FLAGS env_cmd = [
"python",
"-m",
"environments.dataset_environment.dataset_env",
"serve",
] + DATASET_FLAGS
print(f"$ {' '.join(env_cmd)}") print(f"$ {' '.join(env_cmd)}")
env_proc = subprocess.Popen(env_cmd) env_proc = subprocess.Popen(env_cmd)
processes.append(env_proc) processes.append(env_proc)
@ -150,5 +176,5 @@ def main():
print("--- Training complete ---") print("--- Training complete ---")
if __name__ == '__main__': if __name__ == "__main__":
main() main()

View file

@ -1,10 +1,10 @@
import random import random
import re import re
from typing import List, Optional, Tuple, Union, Dict from typing import Dict, List, Optional, Tuple, Union
import wandb
from datasets import load_dataset from datasets import load_dataset
from tqdm.asyncio import tqdm_asyncio from tqdm.asyncio import tqdm_asyncio
import wandb
from atroposlib.envs.base import ( from atroposlib.envs.base import (
BaseEnv, BaseEnv,
@ -508,23 +508,32 @@ class FundamentalPredictionEnv(BaseEnv):
# Calculate and log training direction accuracy # Calculate and log training direction accuracy
try: try:
direction_accuracy = sum(self.percent_correct_buffer) / len(self.percent_correct_buffer) direction_accuracy = sum(self.percent_correct_buffer) / len(
self.percent_correct_buffer
)
wandb_metrics["train/direction_accuracy"] = direction_accuracy wandb_metrics["train/direction_accuracy"] = direction_accuracy
except ZeroDivisionError: except ZeroDivisionError:
pass # Skip if buffer is empty pass # Skip if buffer is empty
# Calculate and log training magnitude accuracy # Calculate and log training magnitude accuracy
try: try:
magnitude_accuracy = sum(self.magnitude_accuracy_buffer) / len(self.magnitude_accuracy_buffer) magnitude_accuracy = sum(self.magnitude_accuracy_buffer) / len(
self.magnitude_accuracy_buffer
)
wandb_metrics["train/magnitude_accuracy"] = magnitude_accuracy wandb_metrics["train/magnitude_accuracy"] = magnitude_accuracy
except ZeroDivisionError: except ZeroDivisionError:
pass # Skip if buffer is empty pass # Skip if buffer is empty
# Calculate combined training score (direction + magnitude) # Calculate combined training score (direction + magnitude)
try: try:
combined_score = direction_accuracy + magnitude_accuracy if 'direction_accuracy' in wandb_metrics else 0 combined_score = (
direction_accuracy + magnitude_accuracy
if "direction_accuracy" in wandb_metrics
else 0
)
wandb_metrics["train/combined_score"] = combined_score wandb_metrics["train/combined_score"] = combined_score
except: except Exception as e:
print(f"Error calculating combined score: {e}")
pass pass
# Clear the buffers after logging # Clear the buffers after logging
@ -549,7 +558,7 @@ class FundamentalPredictionEnv(BaseEnv):
# Get number of examples to keep # Get number of examples to keep
num_keep = getattr(self.config, "num_rollouts_per_group_for_logging", -1) num_keep = getattr(self.config, "num_rollouts_per_group_for_logging", -1)
if num_keep == -1: if num_keep == -1:
num_keep = self.config.group_size num_keep = self.config.group_size
@ -577,23 +586,25 @@ class FundamentalPredictionEnv(BaseEnv):
async def create_rollout_table(self, wandb_metrics): async def create_rollout_table(self, wandb_metrics):
if hasattr(self, "rollouts_for_wandb") and len(self.rollouts_for_wandb) > 0: if hasattr(self, "rollouts_for_wandb") and len(self.rollouts_for_wandb) > 0:
table = wandb.Table(columns=[ table = wandb.Table(
"text", columns=[
"score", "text",
"expected_direction", "score",
"expected_magnitude", "expected_direction",
"fundamental_metric" "expected_magnitude",
]) "fundamental_metric",
]
)
for group in self.rollouts_for_wandb: for group in self.rollouts_for_wandb:
for item in group: for item in group:
table.add_data(item[0], item[1], item[2], item[3], item[4]) table.add_data(item[0], item[1], item[2], item[3], item[4])
wandb_metrics["train/rollouts"] = table wandb_metrics["train/rollouts"] = table
# Clear rollouts after logging # Clear rollouts after logging
self.rollouts_for_wandb = [] self.rollouts_for_wandb = []
return wandb_metrics return wandb_metrics

View file

@ -5,6 +5,7 @@ from concurrent.futures import ProcessPoolExecutor
from difflib import SequenceMatcher from difflib import SequenceMatcher
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
import wandb
from datasets import load_dataset from datasets import load_dataset
from latex2sympy2_extended import NormalizationConfig from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify from math_verify import LatexExtractionConfig, parse, verify
@ -12,7 +13,6 @@ from math_verify.errors import TimeoutException
from pydantic import Field from pydantic import Field
from tqdm.asyncio import tqdm_asyncio from tqdm.asyncio import tqdm_asyncio
import wandb
from atroposlib.envs.base import ( from atroposlib.envs.base import (
BaseEnv, BaseEnv,
BaseEnvConfig, BaseEnvConfig,

View file

@ -8,7 +8,6 @@ from datasets import load_dataset
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
from atroposlib.type_definitions import GameHistory, Item from atroposlib.type_definitions import GameHistory, Item
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer

View file

@ -13,13 +13,12 @@ import numpy as np
import requests import requests
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import wandb # Added for logging
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from tenacity import retry, stop_after_attempt, wait_exponential from tenacity import retry, stop_after_attempt, wait_exponential
from torch.optim import AdamW from torch.optim import AdamW
from transformers import AutoModelForCausalLM, AutoTokenizer from transformers import AutoModelForCausalLM, AutoTokenizer
import wandb # Added for logging
# Global variable to keep track of the vLLM process # Global variable to keep track of the vLLM process
vllm_process = None vllm_process = None