mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
Tests that launch the API server via subprocess used a hardcoded "python" command which fails on systems where only "python3" is available (e.g. macOS). Using sys.executable ensures the same interpreter running pytest is used for subprocesses. Fixes 36 test errors on macOS environments.
354 lines
11 KiB
Python
354 lines
11 KiB
Python
"""Tests covering gzip compression of API responses and requests."""
|
|
|
|
import gzip
|
|
import json
|
|
import os
|
|
import signal
|
|
import subprocess
|
|
import sys
|
|
import time
|
|
|
|
import pytest
|
|
import requests
|
|
|
|
|
|
def wait_for_api_server(max_wait=10):
    """Poll the local API server until it answers or the attempts run out.

    Issues ``GET http://localhost:8000/info`` up to ``max_wait`` times and
    sleeps one second after every unsuccessful attempt.

    Args:
        max_wait: Maximum number of polling attempts.

    Returns:
        True as soon as the endpoint replies with HTTP 200, False if no
        attempt succeeded.
    """
    for _ in range(max_wait):
        try:
            # Bound each probe: without a timeout, a socket that accepts the
            # connection but never replies would stall the poll forever.
            response = requests.get("http://localhost:8000/info", timeout=2)
        except (
            requests.exceptions.ConnectionError,
            requests.exceptions.Timeout,
        ):
            # Server not reachable (or not responsive) yet; retry below.
            pass
        else:
            if response.status_code == 200:
                return True
        time.sleep(1)
    return False
|
|
|
|
|
|
def _stop_process_group(proc, grace=10):
    """Terminate the process group led by ``proc`` and reap the child.

    Sends SIGTERM to the whole group, waits up to ``grace`` seconds for the
    child to exit, and escalates to SIGKILL if it has not.
    """
    try:
        os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
    except ProcessLookupError:
        # The process (group) is already gone; nothing left to signal.
        pass

    try:
        proc.wait(timeout=grace)
    except subprocess.TimeoutExpired:
        try:
            os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
        except ProcessLookupError:
            pass
        proc.wait()


@pytest.fixture(scope="module")
def api_server():
    """Run the API server in a subprocess for the duration of the module.

    Launches ``atroposlib.cli.run_api`` on ``localhost:8000`` with the same
    interpreter that runs pytest, waits for it to answer, yields control to
    the tests, and shuts the server's process group down afterwards.

    Raises:
        RuntimeError: If the server does not become reachable in time.
    """
    proc = subprocess.Popen(
        [
            sys.executable,
            "-m",
            "atroposlib.cli.run_api",
            "--host",
            "localhost",
            "--port",
            "8000",
        ],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        # Same effect as preexec_fn=os.setsid (the child leads its own
        # session/process group), without preexec_fn's thread-safety caveat.
        start_new_session=True,
    )

    if not wait_for_api_server():
        # Signal the whole group and reap the child, exactly as the normal
        # teardown does, so a failed start leaves no stray processes behind.
        _stop_process_group(proc)
        raise RuntimeError("API server failed to start")

    yield

    _stop_process_group(proc)

    # Best-effort reset. The server started above has just been stopped, so
    # this request is expected to fail to connect and the failure is ignored.
    try:
        requests.get("http://localhost:8000/reset_data", timeout=2)
    except requests.exceptions.RequestException:
        pass
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def reset_api_state():
    """Reset API state before and after each test.

    Calls ``GET http://localhost:8000/reset_data`` on both sides of the
    test. The reset is best-effort: request failures (for example, the
    server not being reachable) are ignored.
    """

    def _reset():
        try:
            # Bounded so an unresponsive server cannot hang every test.
            requests.get("http://localhost:8000/reset_data", timeout=5)
        except requests.exceptions.RequestException:
            pass

    _reset()
    yield
    _reset()
|
|
|
|
|
|
class TestAPICompression:
    """Test class for API compression functionality.

    Every test requests the ``api_server`` fixture and talks to the server
    at ``http://localhost:8000``.
    """

    @staticmethod
    def _register_trainer(batch_size):
        """POST the trainer registration payload shared by the tests.

        Only ``batch_size`` varies between tests. The response is returned
        without being checked; callers that do not need it may ignore it.
        """
        return requests.post(
            "http://localhost:8000/register",
            json={
                "wandb_group": "test_group",
                "wandb_project": "test_project",
                "batch_size": batch_size,
                "max_token_len": 2048,
                "checkpoint_dir": "/tmp/test",
                "save_checkpoint_interval": 100,
                "starting_step": 0,
                "num_steps": 1000,
            },
        )

    def test_small_response_not_compressed(self, api_server):
        """Small payloads bypass gzip."""
        response = requests.get("http://localhost:8000/info")

        assert response.status_code == 200, response.text

        assert response.headers.get("Content-Encoding") != "gzip"

        data = response.json()
        assert "batch_size" in data

    def test_large_response_compressed_automatically(self, api_server):
        """Large batches are gzipped and transparently decoded by clients.

        Only the decoded JSON body is verified here; the response headers
        are not inspected.
        """
        self._register_trainer(16)

        large_scored_data = {
            "tokens": [list(range(512)) for _ in range(16)],
            "masks": [[1] * 512 for _ in range(16)],
            "scores": [0.5] * 16,
            "advantages": [[0.1] * 512 for _ in range(16)],
            "ref_logprobs": [[0.2] * 512 for _ in range(16)],
            "messages": [
                [{"role": "user", "content": "test" * 50}] for _ in range(16)
            ],
        }

        post_response = requests.post(
            "http://localhost:8000/scored_data",
            json=large_scored_data,
        )
        assert post_response.status_code == 200

        response = requests.get("http://localhost:8000/batch")

        assert response.status_code == 200

        data = response.json()
        assert "batch" in data
        assert data["batch"] is not None

        batch = data["batch"][0]
        assert len(batch["tokens"]) == 16
        assert len(batch["tokens"][0]) == 512
        assert batch["tokens"][0][0] == 0

    def test_compression_with_raw_headers(self, api_server):
        """Explicit Accept-Encoding still yields usable decoded responses."""
        self._register_trainer(32)

        # Post large data
        large_scored_data = {
            "tokens": [list(range(512)) for _ in range(16)],
            "masks": [[1] * 512 for _ in range(16)],
            "scores": [0.5] * 16,
        }
        requests.post("http://localhost:8000/scored_data", json=large_scored_data)

        # The context manager closes the session (and its pooled connections)
        # even if an assertion below fails. The body is read inside the block
        # because the response is streamed.
        with requests.Session() as session:
            response = session.get(
                "http://localhost:8000/batch",
                headers={"Accept-Encoding": "gzip"},
                stream=True,
            )

            assert response.status_code == 200

            data = response.json()

        assert "batch" in data

        # Only 16 sequences were posted against a registered batch_size of
        # 32, so a missing batch is tolerated here.
        if data["batch"] is not None:
            batch = data["batch"][0]
            assert "tokens" in batch
            assert len(batch["tokens"]) > 0

    def test_compression_ratio_estimation(self, api_server):
        """Produce a rough before/after size estimate for visibility."""
        self._register_trainer(32)

        large_scored_data = {
            "tokens": [list(range(1024)) for _ in range(32)],
            "masks": [[1] * 1024 for _ in range(32)],
            "scores": [0.5] * 32,
            "advantages": [[0.1] * 1024 for _ in range(32)],
        }

        requests.post("http://localhost:8000/scored_data", json=large_scored_data)

        response = requests.get("http://localhost:8000/batch")

        assert response.status_code == 200
        data = response.json()

        # Size of the decoded body when re-serialised as JSON.
        uncompressed_json = json.dumps(data)
        uncompressed_size = len(uncompressed_json.encode("utf-8"))

        assert data["batch"] is not None
        batch = data["batch"][0]
        assert len(batch["tokens"]) == 32
        assert len(batch["tokens"][0]) == 1024

        print(f"\nEstimated uncompressed size: {uncompressed_size:,} bytes")
        print("With gzip compression, actual transfer would be ~15-20% of this size")

    def test_environment_client_compatibility(self, api_server):
        """Simulate the common trainer + env flow."""
        self._register_trainer(32)

        requests.get("http://localhost:8000/batch")

        env_response = requests.post(
            "http://localhost:8000/register-env",
            json={
                "max_token_length": 2048,
                "desired_name": "test_env",
                "weight": 1.0,
                "group_size": 4,
            },
        )
        assert env_response.status_code == 200
        env_data = env_response.json()
        assert env_data["status"] == "success"
        env_id = env_data["env_id"]

        scored_data = {
            "tokens": [list(range(256)) for _ in range(4)],
            "masks": [[1] * 256 for _ in range(4)],
            "scores": [0.8, 0.6, 0.4, 0.2],
            "env_id": env_id,
        }

        post_response = requests.post(
            "http://localhost:8000/scored_data",
            json=scored_data,
        )
        assert post_response.status_code == 200

        status_response = requests.get(
            "http://localhost:8000/status-env", json={"env_id": env_id}
        )
        assert status_response.status_code == 200
        status_data = status_response.json()
        assert "queue_size" in status_data

    def test_server_accepts_gzipped_scored_data(self, api_server):
        """Server inflates gzipped POST bodies."""
        self._register_trainer(8)

        scored_data = {
            "tokens": [list(range(256)) for _ in range(8)],
            "masks": [[1] * 256 for _ in range(8)],
            "scores": [0.1] * 8,
        }

        # Compress the JSON body client-side and label it via Content-Encoding.
        payload = json.dumps(scored_data).encode("utf-8")
        compressed = gzip.compress(payload)

        response = requests.post(
            "http://localhost:8000/scored_data",
            data=compressed,
            headers={
                "Content-Type": "application/json",
                "Content-Encoding": "gzip",
            },
        )

        assert response.status_code == 200
        response_data = response.json()
        assert response_data["status"] == "received"

        batch_response = requests.get("http://localhost:8000/batch")
        assert batch_response.status_code == 200
        batch_data = batch_response.json()
        assert batch_data["batch"] is not None

    def test_scored_data_list_compression(self, api_server):
        """Multi-item submissions still round-trip correctly."""
        self._register_trainer(32)

        # Four groups of eight sequences each: 32 sequences in total.
        scored_data_list = [
            {
                "tokens": [list(range(512)) for _ in range(8)],
                "masks": [[1] * 512 for _ in range(8)],
                "scores": [0.5] * 8,
            }
            for _ in range(4)
        ]

        response = requests.post(
            "http://localhost:8000/scored_data_list",
            json=scored_data_list,
        )

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "received"
        assert data["groups_processed"] == 4

        batch_response = requests.get("http://localhost:8000/batch")
        assert batch_response.status_code == 200
        batch_data = batch_response.json()
        assert batch_data["batch"] is not None

        batch = batch_data["batch"]
        total_sequences = sum(len(item["tokens"]) for item in batch)
        assert total_sequences == 32
|
|
|
|
|
|
if __name__ == "__main__":
    # Propagate pytest's exit status so running this file directly exits
    # non-zero when a test fails, instead of always exiting 0.
    sys.exit(pytest.main([__file__, "-v", "-s"]))
|