mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
- Made reward field truly optional in messages (no auto-addition) - Accept custom roles (dog, cat, etc.) beyond standard ones - Added 24 new tests for edge cases (tuples, unicode, large content) - Reorganized test structure: moved from testing/ to atroposlib/tests/ - Fixed legacy API tests and removed tests requiring missing data files All 43 tests pass! Fixes message handling for SFT use cases. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
89 lines
2.3 KiB
Python
89 lines
2.3 KiB
Python
import pytest
|
|
import requests
|
|
|
|
from atroposlib.tests.api_test_utils import launch_api_for_testing
|
|
|
|
|
|
def register_data(group="test", proj="test", batch_size=32) -> requests.Response:
    """POST a registration payload to the local API and return the raw response.

    Only the wandb group, wandb project and batch size are configurable; the
    remaining registration fields are fixed test values.
    """
    registration = {
        "wandb_group": group,
        "wandb_project": proj,
        "batch_size": batch_size,
        "max_token_len": 512,
        "checkpoint_dir": "/tmp/test",
        "save_checkpoint_interval": 100,
        "starting_step": 0,
        "num_steps": 1000,
    }
    return requests.post("http://localhost:8000/register", json=registration)
|
|
|
|
|
|
def post_scored_data(
    tokens=((0,),), masks=((0,),), scores=(0,), ref_logprobs=((0,),)
) -> requests.Response:
    """POST one scored-data payload to the local API and return the response.

    ``ref_logprobs`` is left out of the JSON body entirely when it is ``None``.
    """
    # Keep the optional key last so the body's key order is unchanged.
    optional = {} if ref_logprobs is None else {"ref_logprobs": ref_logprobs}
    payload = {"tokens": tokens, "masks": masks, "scores": scores, **optional}
    return requests.post("http://localhost:8000/scored_data", json=payload)
|
|
|
|
|
|
def reset() -> requests.Response:
    """GET ``/reset_data`` on the local API and return the response."""
    return requests.get("http://localhost:8000/reset_data")
|
|
|
|
|
|
@pytest.fixture(scope="session")
def api():
    """Session-scoped fixture: launch the API for testing, stop it at teardown."""
    server = launch_api_for_testing()
    yield
    # Teardown: ask the launched process to stop, then block until it exits.
    server.terminate()
    server.wait()
|
|
|
|
|
|
def test_register(api):
    """Registering with default values succeeds and the reply has a ``uuid``."""
    response = register_data()
    assert response.status_code == 200, response.text
    assert "uuid" in response.json()
|
|
|
|
|
|
def test_reset(api):
    """Check that ``/reset_data`` clears the state reported by the API.

    Registers, posts one scored-data item, resets, and then verifies the
    values reported by ``/info``, ``/status`` and ``/wandb_info``.
    """
    # Put some state on the server first so the reset has something to clear.
    response = register_data()
    assert response.status_code == 200, response.text
    assert "uuid" in response.json()

    response = post_scored_data()
    assert response.status_code == 200, response.text

    # The response body is surfaced through the assertion message on failure,
    # so no separate debug printing is needed.
    response = reset()
    assert response.status_code == 200, response.text

    # Batch size is expected to be reported as -1 after the reset.
    response = requests.get("http://localhost:8000/info")
    assert response.status_code == 200, response.text
    assert response.json()["batch_size"] == -1

    response = requests.get("http://localhost:8000/status")
    assert response.status_code == 200, response.text
    status = response.json()
    assert status["current_step"] == 0
    assert status["queue_size"] == 0

    response = requests.get("http://localhost:8000/wandb_info")
    assert response.status_code == 200, response.text
    wandb_info = response.json()
    assert wandb_info["group"] is None
    assert wandb_info["project"] is None
|
|
|
|
|
|
def test_batch_size(api):
|
|
x = register_data()
|
|
assert x.status_code == 200, x.text
|
|
# get the batch size
|
|
x = requests.get("http://localhost:8000/info")
|