atropos/atroposlib/tests/test_api_compression.py
0xbyt4 16ac6f0936 fix: use sys.executable instead of hardcoded "python" in tests
Tests that launch the API server via subprocess used a hardcoded
"python" command which fails on systems where only "python3" is
available (e.g. macOS). Using sys.executable ensures the same
interpreter running pytest is used for subprocesses.

Fixes 36 test errors on macOS environments.
2026-02-26 21:48:56 +03:00

354 lines
11 KiB
Python

"""Tests covering gzip compression of API responses and requests."""
import gzip
import json
import os
import signal
import subprocess
import sys
import time
import pytest
import requests
def wait_for_api_server(max_wait=10):
for _ in range(max_wait):
try:
response = requests.get("http://localhost:8000/info")
if response.status_code == 200:
return True
except requests.exceptions.ConnectionError:
pass
time.sleep(1)
return False
@pytest.fixture(scope="module")
def api_server():
proc = subprocess.Popen(
[
sys.executable,
"-m",
"atroposlib.cli.run_api",
"--host",
"localhost",
"--port",
"8000",
],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
preexec_fn=os.setsid,
)
if not wait_for_api_server():
proc.terminate()
raise RuntimeError("API server failed to start")
yield
os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
proc.wait()
try:
requests.get("http://localhost:8000/reset_data")
except Exception:
pass
@pytest.fixture(autouse=True)
def reset_api_state():
"""Reset API state before each test."""
try:
requests.get("http://localhost:8000/reset_data")
except Exception:
pass
yield
try:
requests.get("http://localhost:8000/reset_data")
except Exception:
pass
class TestAPICompression:
"""Test class for API compression functionality."""
def test_small_response_not_compressed(self, api_server):
"""Small payloads bypass gzip."""
response = requests.get("http://localhost:8000/info")
assert response.status_code == 200, response.text
assert response.headers.get("Content-Encoding") != "gzip"
data = response.json()
assert "batch_size" in data
def test_large_response_compressed_automatically(self, api_server):
"""Large batches are gzipped and transparently decoded by clients."""
requests.post(
"http://localhost:8000/register",
json={
"wandb_group": "test_group",
"wandb_project": "test_project",
"batch_size": 16,
"max_token_len": 2048,
"checkpoint_dir": "/tmp/test",
"save_checkpoint_interval": 100,
"starting_step": 0,
"num_steps": 1000,
},
)
large_scored_data = {
"tokens": [[i for i in range(512)] for _ in range(16)],
"masks": [[1 for _ in range(512)] for _ in range(16)],
"scores": [0.5 for _ in range(16)],
"advantages": [[0.1 for _ in range(512)] for _ in range(16)],
"ref_logprobs": [[0.2 for _ in range(512)] for _ in range(16)],
"messages": [[{"role": "user", "content": "test" * 50}] for _ in range(16)],
}
post_response = requests.post(
"http://localhost:8000/scored_data",
json=large_scored_data,
)
assert post_response.status_code == 200
response = requests.get("http://localhost:8000/batch")
assert response.status_code == 200
data = response.json()
assert "batch" in data
assert data["batch"] is not None
batch = data["batch"][0]
assert len(batch["tokens"]) == 16
assert len(batch["tokens"][0]) == 512
assert batch["tokens"][0][0] == 0
def test_compression_with_raw_headers(self, api_server):
"""Explicit Accept-Encoding still yields usable decoded responses."""
requests.post(
"http://localhost:8000/register",
json={
"wandb_group": "test_group",
"wandb_project": "test_project",
"batch_size": 32,
"max_token_len": 2048,
"checkpoint_dir": "/tmp/test",
"save_checkpoint_interval": 100,
"starting_step": 0,
"num_steps": 1000,
},
)
# Post large data
large_scored_data = {
"tokens": [[i for i in range(512)] for _ in range(16)],
"masks": [[1 for _ in range(512)] for _ in range(16)],
"scores": [0.5 for _ in range(16)],
}
requests.post("http://localhost:8000/scored_data", json=large_scored_data)
session = requests.Session()
response = session.get(
"http://localhost:8000/batch",
headers={"Accept-Encoding": "gzip"},
stream=True,
)
assert response.status_code == 200
data = response.json()
assert "batch" in data
if data["batch"] is not None:
batch = data["batch"][0]
assert "tokens" in batch
assert len(batch["tokens"]) > 0
def test_compression_ratio_estimation(self, api_server):
"""Produce a rough before/after size estimate for visibility."""
requests.post(
"http://localhost:8000/register",
json={
"wandb_group": "test_group",
"wandb_project": "test_project",
"batch_size": 32,
"max_token_len": 2048,
"checkpoint_dir": "/tmp/test",
"save_checkpoint_interval": 100,
"starting_step": 0,
"num_steps": 1000,
},
)
large_scored_data = {
"tokens": [[i for i in range(1024)] for _ in range(32)],
"masks": [[1 for _ in range(1024)] for _ in range(32)],
"scores": [0.5 for _ in range(32)],
"advantages": [[0.1 for _ in range(1024)] for _ in range(32)],
}
requests.post("http://localhost:8000/scored_data", json=large_scored_data)
response = requests.get("http://localhost:8000/batch")
assert response.status_code == 200
data = response.json()
uncompressed_json = json.dumps(data)
uncompressed_size = len(uncompressed_json.encode("utf-8"))
assert data["batch"] is not None
batch = data["batch"][0]
assert len(batch["tokens"]) == 32
assert len(batch["tokens"][0]) == 1024
print(f"\nEstimated uncompressed size: {uncompressed_size:,} bytes")
print("With gzip compression, actual transfer would be ~15-20% of this size")
def test_environment_client_compatibility(self, api_server):
"""Simulate the common trainer + env flow."""
requests.post(
"http://localhost:8000/register",
json={
"wandb_group": "test_group",
"wandb_project": "test_project",
"batch_size": 32,
"max_token_len": 2048,
"checkpoint_dir": "/tmp/test",
"save_checkpoint_interval": 100,
"starting_step": 0,
"num_steps": 1000,
},
)
requests.get("http://localhost:8000/batch")
env_response = requests.post(
"http://localhost:8000/register-env",
json={
"max_token_length": 2048,
"desired_name": "test_env",
"weight": 1.0,
"group_size": 4,
},
)
assert env_response.status_code == 200
env_data = env_response.json()
assert env_data["status"] == "success"
env_id = env_data["env_id"]
scored_data = {
"tokens": [[i for i in range(256)] for _ in range(4)],
"masks": [[1 for _ in range(256)] for _ in range(4)],
"scores": [0.8, 0.6, 0.4, 0.2],
"env_id": env_id,
}
post_response = requests.post(
"http://localhost:8000/scored_data",
json=scored_data,
)
assert post_response.status_code == 200
status_response = requests.get(
"http://localhost:8000/status-env", json={"env_id": env_id}
)
assert status_response.status_code == 200
status_data = status_response.json()
assert "queue_size" in status_data
def test_server_accepts_gzipped_scored_data(self, api_server):
"""Server inflates gzipped POST bodies."""
requests.post(
"http://localhost:8000/register",
json={
"wandb_group": "test_group",
"wandb_project": "test_project",
"batch_size": 8,
"max_token_len": 2048,
"checkpoint_dir": "/tmp/test",
"save_checkpoint_interval": 100,
"starting_step": 0,
"num_steps": 1000,
},
)
scored_data = {
"tokens": [[i for i in range(256)] for _ in range(8)],
"masks": [[1 for _ in range(256)] for _ in range(8)],
"scores": [0.1 for _ in range(8)],
}
payload = json.dumps(scored_data).encode("utf-8")
compressed = gzip.compress(payload)
response = requests.post(
"http://localhost:8000/scored_data",
data=compressed,
headers={
"Content-Type": "application/json",
"Content-Encoding": "gzip",
},
)
assert response.status_code == 200
response_data = response.json()
assert response_data["status"] == "received"
batch_response = requests.get("http://localhost:8000/batch")
assert batch_response.status_code == 200
batch_data = batch_response.json()
assert batch_data["batch"] is not None
def test_scored_data_list_compression(self, api_server):
"""Multi-item submissions still round-trip correctly."""
requests.post(
"http://localhost:8000/register",
json={
"wandb_group": "test_group",
"wandb_project": "test_project",
"batch_size": 32,
"max_token_len": 2048,
"checkpoint_dir": "/tmp/test",
"save_checkpoint_interval": 100,
"starting_step": 0,
"num_steps": 1000,
},
)
scored_data_list = [
{
"tokens": [[i for i in range(512)] for _ in range(8)],
"masks": [[1 for _ in range(512)] for _ in range(8)],
"scores": [0.5 for _ in range(8)],
}
for _ in range(4)
]
response = requests.post(
"http://localhost:8000/scored_data_list",
json=scored_data_list,
)
assert response.status_code == 200
data = response.json()
assert data["status"] == "received"
assert data["groups_processed"] == 4
batch_response = requests.get("http://localhost:8000/batch")
assert batch_response.status_code == 200
batch_data = batch_response.json()
assert batch_data["batch"] is not None
batch = batch_data["batch"]
total_sequences = sum(len(item["tokens"]) for item in batch)
assert total_sequences == 32
if __name__ == "__main__":
pytest.main([__file__, "-v", "-s"])