mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
Fix API to accept messages without reward field + comprehensive tests
- Made reward field truly optional in messages (no auto-addition) - Accept custom roles (dog, cat, etc.) beyond standard ones - Added 24 new tests for edge cases (tuples, unicode, large content) - Reorganized test structure: moved from testing/ to atroposlib/tests/ - Fixed legacy API tests and removed tests requiring missing data files All 43 tests pass\! Fixes message handling for SFT use cases. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
24dd0a71b4
commit
e13526d308
11 changed files with 1434 additions and 46 deletions
89
atroposlib/tests/test_api_legacy.py
Normal file
89
atroposlib/tests/test_api_legacy.py
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
import pytest
|
||||
import requests
|
||||
|
||||
from atroposlib.tests.api_test_utils import launch_api_for_testing
|
||||
|
||||
|
||||
def register_data(group="test", proj="test", batch_size=32) -> requests.Response:
|
||||
x = requests.post(
|
||||
"http://localhost:8000/register",
|
||||
json={
|
||||
"wandb_group": group,
|
||||
"wandb_project": proj,
|
||||
"batch_size": batch_size,
|
||||
"max_token_len": 512,
|
||||
"checkpoint_dir": "/tmp/test",
|
||||
"save_checkpoint_interval": 100,
|
||||
"starting_step": 0,
|
||||
"num_steps": 1000,
|
||||
},
|
||||
)
|
||||
return x
|
||||
|
||||
|
||||
def post_scored_data(
|
||||
tokens=((0,),), masks=((0,),), scores=(0,), ref_logprobs=((0,),)
|
||||
) -> requests.Response:
|
||||
data = {
|
||||
"tokens": tokens,
|
||||
"masks": masks,
|
||||
"scores": scores,
|
||||
}
|
||||
if ref_logprobs is not None:
|
||||
data["ref_logprobs"] = ref_logprobs
|
||||
x = requests.post("http://localhost:8000/scored_data", json=data)
|
||||
return x
|
||||
|
||||
|
||||
def reset() -> requests.Response:
|
||||
x = requests.get("http://localhost:8000/reset_data")
|
||||
return x
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def api():
|
||||
proc = launch_api_for_testing()
|
||||
yield
|
||||
proc.terminate()
|
||||
proc.wait() # Wait for clean shutdown
|
||||
|
||||
|
||||
def test_register(api):
|
||||
x = register_data()
|
||||
assert x.status_code == 200, x.text
|
||||
data = x.json()
|
||||
assert "uuid" in data
|
||||
|
||||
|
||||
def test_reset(api):
|
||||
x = register_data()
|
||||
assert x.status_code == 200, x.text
|
||||
data = x.json()
|
||||
assert "uuid" in data
|
||||
x = post_scored_data()
|
||||
assert x.status_code == 200, x.text
|
||||
x = reset()
|
||||
print("0-0-0-0-0-0-0-0", flush=True)
|
||||
print(x.text, flush=True)
|
||||
print("0-0-0-0-0-0-0-0", flush=True)
|
||||
assert x.status_code == 200, x.text
|
||||
x = requests.get("http://localhost:8000/info")
|
||||
assert x.status_code == 200
|
||||
assert x.json()["batch_size"] == -1
|
||||
x = requests.get("http://localhost:8000/status")
|
||||
assert x.status_code == 200, x.text
|
||||
data = x.json()
|
||||
assert data["current_step"] == 0
|
||||
assert data["queue_size"] == 0
|
||||
x = requests.get("http://localhost:8000/wandb_info")
|
||||
assert x.status_code == 200, x.text
|
||||
data = x.json()
|
||||
assert data["group"] is None
|
||||
assert data["project"] is None
|
||||
|
||||
|
||||
def test_batch_size(api):
|
||||
x = register_data()
|
||||
assert x.status_code == 200, x.text
|
||||
# get the batch size
|
||||
x = requests.get("http://localhost:8000/info")
|
||||
Loading…
Add table
Add a link
Reference in a new issue