atropos/example_trainer/api.py
2026-03-13 11:06:02 -04:00

116 lines
3.4 KiB
Python

"""
Atropos API communication utilities.
Handles communication with the Atropos API server for:
- Server health checks
- Trainer registration
- Batch retrieval
"""
import os
import time as _time
import requests
from tenacity import retry, stop_after_attempt, wait_exponential
from .config import TrainingConfig
def check_atropos_api(
    url: str = "http://localhost:8000", timeout: float = 30.0
) -> bool:
    """
    Check if the Atropos API server is reachable.

    Polls ``{url}/info`` about once per second until it answers with HTTP 200
    or ``timeout`` seconds have elapsed.

    Args:
        url: Base URL of the Atropos API server
        timeout: Maximum time (in seconds) to wait for the server. A value of
            zero or less returns False without contacting the server.

    Returns:
        True if the server answered with HTTP 200 before the deadline,
        False otherwise
    """
    # Use the monotonic clock so wall-clock adjustments (NTP sync, manual
    # changes) cannot shorten or extend the wait.
    deadline = _time.monotonic() + timeout
    while _time.monotonic() < deadline:
        try:
            response = requests.get(f"{url}/info", timeout=2)
            if response.status_code == 200:
                print(f"[Trainer] ✓ Atropos API server is reachable at {url}")
                return True
            # Any other status code is treated as "not ready yet" and retried.
        except requests.exceptions.ConnectionError:
            # Nothing is accepting connections yet; keep polling quietly.
            pass
        except Exception as e:
            print(f"[Trainer] Waiting for Atropos API at {url}... ({e})")
        # Never sleep past the deadline: cap the pause to the time remaining.
        remaining = deadline - _time.monotonic()
        if remaining <= 0:
            break
        _time.sleep(min(1.0, remaining))
    print(f"[Trainer] ⚠ Warning: Atropos API server not reachable at {url}")
    return False
@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=2, max=30))
def register_trainer(config: TrainingConfig):
    """
    Register the trainer with the Atropos API.

    Posts the training parameters to ``{config.atropos_url}/register`` and
    confirms that the JSON reply carries a ``uuid`` before returning.
    The surrounding ``retry`` decorator re-runs the call when it raises
    (at most 5 attempts, exponential wait bounded to 2-30 seconds).

    Args:
        config: Training configuration supplying the API URL, batch sizing,
            sequence length, checkpoint settings and wandb names

    Raises:
        requests.exceptions.HTTPError: If the server answers with an HTTP
            error status
        RuntimeError: If the reply does not contain a ``uuid``
    """
    base_url = config.atropos_url

    # A non-positive interval falls back to the total number of training steps.
    if config.checkpoint_interval <= 0:
        checkpoint_every = config.training_steps
    else:
        checkpoint_every = config.checkpoint_interval

    registration = {
        # wandb fields are required strings - use empty string if None
        "wandb_group": config.wandb_group or "",
        "wandb_project": config.wandb_project or "",
        "batch_size": config.batch_size * config.gradient_accumulation_steps,
        "max_token_len": config.seq_len,
        "starting_step": 0,
        "checkpoint_dir": config.save_path,
        "save_checkpoint_interval": checkpoint_every,
        "num_steps": config.training_steps,
    }

    reply = requests.post(f"{base_url}/register", json=registration, timeout=10)

    # Surface HTTP-level failures before looking at the body
    reply.raise_for_status()

    # A successful registration must hand back a UUID
    body = reply.json()
    if "uuid" not in body:
        raise RuntimeError(f"Registration failed: {body}")

    trainer_uuid = body["uuid"]
    print(
        f"[Trainer] ✓ Registered with Atropos API at {base_url} (uuid: {trainer_uuid})"
    )
@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=2, max=30))
def get_batch(url: str = "http://localhost:8000"):
    """
    Get a batch of training data from the Atropos API.

    Sends a GET to ``{url}/batch``, identifying the caller as the trainer and
    including the current process id in the request headers.

    Args:
        url: Base URL of the Atropos API server

    Returns:
        Batch data dictionary containing tokens, masks, scores, etc.

    Raises:
        requests.exceptions.HTTPError: If the server answers with an HTTP
            error status
        RuntimeError: If trainer is not registered or other API error

    NOTE(review): the ``retry`` decorator re-runs this call on any exception
    (at most 5 attempts) and does not set ``reraise=True``, so after the last
    attempt the failure presumably surfaces as ``tenacity.RetryError`` wrapping
    the exceptions above - confirm callers handle that.
    """
    response = requests.get(
        f"{url}/batch",
        headers={
            "X-Atropos-Client": "trainer",
            "X-Atropos-Pid": str(os.getpid()),
        },
        timeout=10,
    )
    # Check for HTTP errors before parsing, matching register_trainer; otherwise
    # a 4xx/5xx with a non-JSON body shows up as an opaque JSON decode error.
    response.raise_for_status()
    data = response.json()
    # Check if there was an error (trainer not registered)
    if data.get("status") == "error":
        raise RuntimeError(f"Atropos API error: {data.get('message', 'Unknown error')}")
    return data