mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-22 16:48:57 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
33505fe981
commit
11f495a381
19 changed files with 708 additions and 452 deletions
|
|
@ -15,7 +15,9 @@ from tenacity import retry, stop_after_attempt, wait_exponential
|
|||
from .config import TrainingConfig
|
||||
|
||||
|
||||
def check_atropos_api(url: str = "http://localhost:8000", timeout: float = 30.0) -> bool:
|
||||
def check_atropos_api(
|
||||
url: str = "http://localhost:8000", timeout: float = 30.0
|
||||
) -> bool:
|
||||
"""
|
||||
Check if the Atropos API server is reachable.
|
||||
|
||||
|
|
@ -82,13 +84,13 @@ def register_trainer(config: TrainingConfig):
|
|||
def get_batch(url: str = "http://localhost:8000"):
|
||||
"""
|
||||
Get a batch of training data from the Atropos API.
|
||||
|
||||
|
||||
Args:
|
||||
url: Base URL of the Atropos API server
|
||||
|
||||
|
||||
Returns:
|
||||
Batch data dictionary containing tokens, masks, scores, etc.
|
||||
|
||||
|
||||
Raises:
|
||||
RuntimeError: If trainer is not registered or other API error
|
||||
"""
|
||||
|
|
@ -99,4 +101,3 @@ def get_batch(url: str = "http://localhost:8000"):
|
|||
raise RuntimeError(f"Atropos API error: {data.get('message', 'Unknown error')}")
|
||||
|
||||
return data
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue