mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
investigating weird training issue
This commit is contained in:
parent
690e670e64
commit
3df0e45659
1 changed files with 28 additions and 8 deletions
|
|
@ -100,14 +100,34 @@ def get_batch(url: str = "http://localhost:8000"):
|
|||
Raises:
|
||||
RuntimeError: If trainer is not registered or other API error
|
||||
"""
|
||||
data = requests.get(
|
||||
f"{url}/batch",
|
||||
headers={
|
||||
"X-Atropos-Client": "trainer",
|
||||
"X-Atropos-Pid": str(os.getpid()),
|
||||
},
|
||||
timeout=10,
|
||||
).json()
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{url}/batch",
|
||||
headers={
|
||||
"X-Atropos-Client": "trainer",
|
||||
"X-Atropos-Pid": str(os.getpid()),
|
||||
},
|
||||
timeout=10,
|
||||
)
|
||||
print(
|
||||
f" [Trainer/API] GET /batch status={response.status_code}",
|
||||
flush=True,
|
||||
)
|
||||
data = response.json()
|
||||
batch = data.get("batch")
|
||||
if batch is None:
|
||||
print(" [Trainer/API] parsed batch=None", flush=True)
|
||||
else:
|
||||
num_groups = len(batch)
|
||||
num_sequences = sum(len(item["tokens"]) for item in batch)
|
||||
print(
|
||||
" [Trainer/API] parsed batch payload: "
|
||||
f"groups={num_groups} sequences={num_sequences}",
|
||||
flush=True,
|
||||
)
|
||||
except Exception as exc:
|
||||
print(f" [Trainer/API] GET /batch failed: {exc!r}", flush=True)
|
||||
raise
|
||||
|
||||
# Check if there was an error (trainer not registered)
|
||||
if data.get("status") == "error":
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue