mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
112 lines
3.2 KiB
Python
112 lines
3.2 KiB
Python
"""
|
|
Atropos API communication utilities.
|
|
|
|
Handles communication with the Atropos API server for:
|
|
- Server health checks
|
|
- Trainer registration
|
|
- Batch retrieval
|
|
"""
|
|
|
|
import time as _time
|
|
|
|
import requests
|
|
from tenacity import retry, stop_after_attempt, wait_exponential
|
|
|
|
from .config import TrainingConfig
|
|
|
|
|
|
def check_atropos_api(
    url: str = "http://localhost:8000", timeout: float = 30.0
) -> bool:
    """
    Check if the Atropos API server is reachable.

    Polls ``GET {url}/info`` roughly once per second until it answers with
    HTTP 200 or ``timeout`` seconds have elapsed. A non-positive ``timeout``
    returns False without making any request.

    Args:
        url: Base URL of the Atropos API server
        timeout: Maximum time (seconds) to keep polling for the server

    Returns:
        True if server is reachable, False if the deadline passed first
    """
    # Use the monotonic clock: wall-clock time can jump (NTP sync, manual
    # adjustment), which would end the wait early or stretch it indefinitely.
    deadline = _time.monotonic() + timeout
    while _time.monotonic() < deadline:
        try:
            # Short per-request timeout so one hung connection cannot
            # consume the whole polling budget.
            response = requests.get(f"{url}/info", timeout=2)
            if response.status_code == 200:
                print(f"[Trainer] ✓ Atropos API server is reachable at {url}")
                return True
        except requests.exceptions.ConnectionError:
            # Expected while the server is still starting up; retry quietly.
            pass
        except Exception as e:
            # Anything else (e.g. a read timeout) is reported but is still
            # treated as "not ready yet" rather than a fatal error.
            print(f"[Trainer] Waiting for Atropos API at {url}... ({e})")
        _time.sleep(1)

    print(f"[Trainer] ⚠ Warning: Atropos API server not reachable at {url}")
    return False
|
|
|
|
|
|
@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=2, max=30))
def register_trainer(config: TrainingConfig):
    """
    Announce this trainer to the Atropos API and confirm it was accepted.

    Posts the training parameters taken from ``config`` to
    ``{config.atropos_url}/register`` and requires the JSON reply to carry a
    ``uuid`` field. Any exception raised here is subject to the ``@retry``
    policy on this function (up to 5 attempts, exponential backoff clamped to
    2-30 seconds).

    Args:
        config: Training configuration supplying the API URL, wandb names,
            batch geometry, checkpoint settings and step count

    Raises:
        requests.exceptions.HTTPError: If the server answers with a 4xx/5xx status
        RuntimeError: If the reply does not contain a ``uuid``
    """
    url = config.atropos_url

    # A non-positive checkpoint interval falls back to the total step count.
    if config.checkpoint_interval <= 0:
        save_checkpoint_interval = config.training_steps
    else:
        save_checkpoint_interval = config.checkpoint_interval

    payload = {
        # wandb fields are required strings - use empty string if None
        "wandb_group": config.wandb_group or "",
        "wandb_project": config.wandb_project or "",
        # Batch size reported to the API includes gradient accumulation.
        "batch_size": config.batch_size * config.gradient_accumulation_steps,
        "max_token_len": config.seq_len,
        "starting_step": 0,
        "checkpoint_dir": config.save_path,
        "save_checkpoint_interval": save_checkpoint_interval,
        "num_steps": config.training_steps,
    }

    response = requests.post(f"{url}/register", json=payload, timeout=10)

    # Surface HTTP-level failures before looking at the body.
    response.raise_for_status()

    # A successful registration is identified by a uuid in the reply.
    data = response.json()
    if "uuid" in data:
        print(f"[Trainer] ✓ Registered with Atropos API at {url} (uuid: {data['uuid']})")
        return

    raise RuntimeError(f"Registration failed: {data}")
|
|
|
|
|
|
@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=2, max=30))
def get_batch(url: str = "http://localhost:8000"):
    """
    Get a batch of training data from the Atropos API.

    Any exception raised here is subject to the ``@retry`` policy on this
    function (up to 5 attempts, exponential backoff clamped to 2-30 seconds).

    Args:
        url: Base URL of the Atropos API server

    Returns:
        Batch data dictionary containing tokens, masks, scores, etc.

    Raises:
        RuntimeError: If trainer is not registered or other API error
        requests.exceptions.HTTPError: If the server answers with a 4xx/5xx status
        ValueError: If a successful response does not contain valid JSON
    """
    response = requests.get(
        f"{url}/batch",
        timeout=10,
    )

    try:
        data = response.json()
    except ValueError:
        # Body is not JSON. If the HTTP status is the real problem, report
        # that instead of a decode error; otherwise re-raise the decode error.
        response.raise_for_status()
        raise

    # Check if there was an error (trainer not registered). This runs before
    # the status check so the API's own message is preserved when available.
    if data.get("status") == "error":
        raise RuntimeError(f"Atropos API error: {data.get('message', 'Unknown error')}")

    # Never hand a 4xx/5xx payload back to the caller as if it were a batch
    # (mirrors the HTTP status check done in register_trainer).
    response.raise_for_status()

    return data
|