mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
major refactor
This commit is contained in:
parent
119721ef3d
commit
6833d4d820
13 changed files with 3268 additions and 3423 deletions
102
example_trainer/api.py
Normal file
102
example_trainer/api.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
"""
|
||||
Atropos API communication utilities.
|
||||
|
||||
Handles communication with the Atropos API server for:
|
||||
- Server health checks
|
||||
- Trainer registration
|
||||
- Batch retrieval
|
||||
"""
|
||||
|
||||
import time as _time
|
||||
|
||||
import requests
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
|
||||
from .config import TrainingConfig
|
||||
|
||||
|
||||
def check_atropos_api(url: str = "http://localhost:8000", timeout: float = 30.0) -> bool:
|
||||
"""
|
||||
Check if the Atropos API server is reachable.
|
||||
|
||||
Args:
|
||||
url: Base URL of the Atropos API server
|
||||
timeout: Maximum time to wait for the server
|
||||
|
||||
Returns:
|
||||
True if server is reachable
|
||||
"""
|
||||
start = _time.time()
|
||||
while _time.time() - start < timeout:
|
||||
try:
|
||||
response = requests.get(f"{url}/info", timeout=2)
|
||||
if response.status_code == 200:
|
||||
print(f"[Trainer] ✓ Atropos API server is reachable at {url}")
|
||||
return True
|
||||
except requests.exceptions.ConnectionError:
|
||||
pass
|
||||
except Exception as e:
|
||||
print(f"[Trainer] Waiting for Atropos API at {url}... ({e})")
|
||||
_time.sleep(1)
|
||||
|
||||
print(f"[Trainer] ⚠ Warning: Atropos API server not reachable at {url}")
|
||||
return False
|
||||
|
||||
|
||||
@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=2, max=30))
|
||||
def register_trainer(config: TrainingConfig):
|
||||
"""
|
||||
Register the trainer with the Atropos API.
|
||||
|
||||
Verifies registration succeeded before returning.
|
||||
"""
|
||||
url = config.atropos_url
|
||||
response = requests.post(
|
||||
f"{url}/register",
|
||||
json={
|
||||
# wandb fields are required strings - use empty string if None
|
||||
"wandb_group": config.wandb_group or "",
|
||||
"wandb_project": config.wandb_project or "",
|
||||
"batch_size": config.batch_size * config.gradient_accumulation_steps,
|
||||
"max_token_len": config.seq_len,
|
||||
"starting_step": 0,
|
||||
"checkpoint_dir": config.save_path,
|
||||
"save_checkpoint_interval": config.training_steps,
|
||||
"num_steps": config.training_steps,
|
||||
},
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
# Check for HTTP errors
|
||||
response.raise_for_status()
|
||||
|
||||
# Verify we got a valid response with UUID
|
||||
data = response.json()
|
||||
if "uuid" not in data:
|
||||
raise RuntimeError(f"Registration failed: {data}")
|
||||
|
||||
print(f"[Trainer] ✓ Registered with Atropos API at {url} (uuid: {data['uuid']})")
|
||||
|
||||
|
||||
@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=2, max=30))
|
||||
def get_batch(url: str = "http://localhost:8000"):
|
||||
"""
|
||||
Get a batch of training data from the Atropos API.
|
||||
|
||||
Args:
|
||||
url: Base URL of the Atropos API server
|
||||
|
||||
Returns:
|
||||
Batch data dictionary containing tokens, masks, scores, etc.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If trainer is not registered or other API error
|
||||
"""
|
||||
data = requests.get(f"{url}/batch", timeout=10).json()
|
||||
|
||||
# Check if there was an error (trainer not registered)
|
||||
if data.get("status") == "error":
|
||||
raise RuntimeError(f"Atropos API error: {data.get('message', 'Unknown error')}")
|
||||
|
||||
return data
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue