atropos/atroposlib/tests/test_api_legacy.py
Dakota e13526d308 Fix API to accept messages without reward field + comprehensive tests
- Made reward field truly optional in messages (no auto-addition)
- Accept custom roles (dog, cat, etc.) beyond standard ones
- Added 24 new tests for edge cases (tuples, unicode, large content)
- Reorganized test structure: moved from testing/ to atroposlib/tests/
- Fixed legacy API tests and removed tests requiring missing data files

All 43 tests pass! Fixes message handling for SFT use cases.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-06-09 14:03:08 -05:00

89 lines
2.3 KiB
Python

import pytest
import requests
from atroposlib.tests.api_test_utils import launch_api_for_testing
def register_data(
    group="test",
    proj="test",
    batch_size=32,
    *,
    base_url="http://localhost:8000",
    timeout=10.0,
) -> requests.Response:
    """Register a trainer with the test API server via ``POST /register``.

    Sends a fixed training configuration together with the given W&B
    group/project and batch size. The fixed values are: ``max_token_len=512``,
    ``checkpoint_dir="/tmp/test"``, ``save_checkpoint_interval=100``,
    ``starting_step=0`` and ``num_steps=1000``.

    Args:
        group: Value sent as ``wandb_group``.
        proj: Value sent as ``wandb_project``.
        batch_size: Value sent as ``batch_size``.
        base_url: Root URL of the API server under test.
        timeout: Seconds ``requests`` waits for the server before raising.
            A hung server then fails the test instead of stalling the suite.

    Returns:
        The raw HTTP response. Callers assert on its status code and body.
    """
    payload = {
        "wandb_group": group,
        "wandb_project": proj,
        "batch_size": batch_size,
        "max_token_len": 512,
        "checkpoint_dir": "/tmp/test",
        "save_checkpoint_interval": 100,
        "starting_step": 0,
        "num_steps": 1000,
    }
    return requests.post(f"{base_url}/register", json=payload, timeout=timeout)
def post_scored_data(
    tokens=((0,),),
    masks=((0,),),
    scores=(0,),
    ref_logprobs=((0,),),
    *,
    base_url="http://localhost:8000",
    timeout=10.0,
) -> requests.Response:
    """Submit one scored-data payload to the test API via ``POST /scored_data``.

    The defaults form a minimal one-item payload. Tuples are serialized as JSON
    arrays by ``requests``.

    Args:
        tokens: Sent as ``tokens``. Presumably one inner sequence per item in
            the group; confirm against the server schema.
        masks: Sent as ``masks``; same nesting as ``tokens``.
        scores: Sent as ``scores``.
        ref_logprobs: Sent as ``ref_logprobs``. When ``None``, the key is
            omitted from the payload entirely rather than sent as null.
        base_url: Root URL of the API server under test.
        timeout: Seconds ``requests`` waits for the server before raising.
            A hung server then fails the test instead of stalling the suite.

    Returns:
        The raw HTTP response.
    """
    payload = {
        "tokens": tokens,
        "masks": masks,
        "scores": scores,
    }
    # Omit the key (rather than sending null) so the "field absent" path
    # can be exercised.
    if ref_logprobs is not None:
        payload["ref_logprobs"] = ref_logprobs
    return requests.post(f"{base_url}/scored_data", json=payload, timeout=timeout)
def reset(*, base_url="http://localhost:8000", timeout=10.0) -> requests.Response:
    """Ask the test API server to clear its state via ``GET /reset_data``.

    Args:
        base_url: Root URL of the API server under test.
        timeout: Seconds ``requests`` waits for the server before raising.
            A hung server then fails the test instead of stalling the suite.

    Returns:
        The raw HTTP response.
    """
    return requests.get(f"{base_url}/reset_data", timeout=timeout)
@pytest.fixture(scope="session")
def api():
    """Start the API server once per test session and stop it afterwards.

    The fixture yields ``None``. Tests depend on it only for the side effect of
    having the server running.
    """
    server_process = launch_api_for_testing()
    yield
    # Teardown: request shutdown, then block until the process has exited.
    # NOTE(review): server_process is presumably a subprocess.Popen. Its
    # wait() has no timeout here, so a server that ignores terminate() would
    # hang teardown; confirm against launch_api_for_testing.
    server_process.terminate()
    server_process.wait()
def test_register(api):
    """Registering a trainer succeeds and returns a JSON body with a ``uuid``."""
    response = register_data()
    assert response.status_code == 200, response.text
    assert "uuid" in response.json()
def test_reset(api):
    """``/reset_data`` returns the server's reported state to its defaults.

    Steps:

    1. Register a trainer.
    2. Queue one scored-data item.
    3. Reset the server.
    4. Check that ``/info``, ``/status`` and ``/wandb_info`` report the
       expected post-reset values.
    """
    base_url = "http://localhost:8000"
    timeout = 10.0

    registration = register_data()
    assert registration.status_code == 200, registration.text
    assert "uuid" in registration.json()

    posted = post_scored_data()
    assert posted.status_code == 200, posted.text

    # The response text is already shown on failure via the assertion message,
    # so no debug printing is needed here.
    reset_response = reset()
    assert reset_response.status_code == 200, reset_response.text

    # After a reset the server is expected to report batch_size == -1.
    info = requests.get(f"{base_url}/info", timeout=timeout)
    assert info.status_code == 200, info.text
    assert info.json()["batch_size"] == -1

    status = requests.get(f"{base_url}/status", timeout=timeout)
    assert status.status_code == 200, status.text
    status_data = status.json()
    assert status_data["current_step"] == 0
    assert status_data["queue_size"] == 0

    wandb_info = requests.get(f"{base_url}/wandb_info", timeout=timeout)
    assert wandb_info.status_code == 200, wandb_info.text
    wandb_data = wandb_info.json()
    assert wandb_data["group"] is None
    assert wandb_data["project"] is None
def test_batch_size(api):
x = register_data()
assert x.status_code == 200, x.text
# get the batch size
x = requests.get("http://localhost:8000/info")