This commit is contained in:
ropresearch 2025-10-10 11:50:39 -04:00
parent baf4b2d8a8
commit e5b8fb8654
3 changed files with 37 additions and 90 deletions

View file

@ -1,12 +1,4 @@
"""
Tests for API server GZip compression.
These tests verify that:
1. GZip compression is enabled for large responses
2. Small responses are not compressed (below minimum_size threshold)
3. Clients automatically decompress responses (no code changes needed)
4. Both requests and raw HTTP clients work correctly
"""
"""Tests covering gzip compression of API responses and requests."""
import gzip
import json
@ -20,7 +12,6 @@ import requests
def wait_for_api_server(max_wait=10):
"""Wait for API server to be ready."""
for _ in range(max_wait):
try:
response = requests.get("http://localhost:8000/info")
@ -34,8 +25,6 @@ def wait_for_api_server(max_wait=10):
@pytest.fixture(scope="module")
def api_server():
"""Launch API server for testing."""
# Start the API server as a subprocess
proc = subprocess.Popen(
[
"python",
@ -48,21 +37,18 @@ def api_server():
],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
preexec_fn=os.setsid, # Create new process group
preexec_fn=os.setsid,
)
# Wait for server to be ready
if not wait_for_api_server():
proc.terminate()
raise RuntimeError("API server failed to start")
yield
# Kill the process group to ensure all child processes are terminated
os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
proc.wait()
# Clean up after tests
try:
requests.get("http://localhost:8000/reset_data")
except Exception:
@ -87,29 +73,24 @@ class TestAPICompression:
"""Test class for API compression functionality."""
def test_small_response_not_compressed(self, api_server):
    """Small payloads bypass gzip.

    Fetches ``/info`` from the locally running API server and checks that
    the reply carries no ``Content-Encoding: gzip`` header while still
    decoding as JSON that contains a ``batch_size`` key.

    Args:
        api_server: Fixture that is requested so the server is running for
            the duration of the test; its value is not used here.
    """
    # Bound the call so a wedged server fails the test instead of hanging
    # the whole run (requests waits indefinitely when no timeout is given).
    response = requests.get("http://localhost:8000/info", timeout=10)
    assert response.status_code == 200, response.text

    # NOTE(review): the earlier comment on this test cites minimum_size=1000
    # for the server's GZip middleware - confirm against the server config.
    # A body below that threshold is expected to be sent uncompressed.
    assert response.headers.get("Content-Encoding") != "gzip"

    # The uncompressed body must still parse as JSON with the expected key.
    data = response.json()
    assert "batch_size" in data
def test_large_response_compressed_automatically(self, api_server):
"""Test that large responses are automatically compressed and decompressed."""
# Register trainer first
"""Large batches are gzipped and transparently decoded by clients."""
requests.post(
"http://localhost:8000/register",
json={
"wandb_group": "test_group",
"wandb_project": "test_project",
"batch_size": 16, # Match the number of sequences we're sending
"batch_size": 16,
"max_token_len": 2048,
"checkpoint_dir": "/tmp/test",
"save_checkpoint_interval": 100,
@ -118,10 +99,9 @@ class TestAPICompression:
},
)
# Create large scored data (should exceed 1KB)
large_scored_data = {
"tokens": [[i for i in range(512)] for _ in range(16)], # Large token array
"masks": [[1 for _ in range(512)] for _ in range(16)], # Large mask array
"tokens": [[i for i in range(512)] for _ in range(16)],
"masks": [[1 for _ in range(512)] for _ in range(16)],
"scores": [0.5 for _ in range(16)],
"advantages": [[0.1 for _ in range(512)] for _ in range(16)],
"ref_logprobs": [[0.2 for _ in range(512)] for _ in range(16)],
@ -130,32 +110,27 @@ class TestAPICompression:
],
}
# Post the large data
post_response = requests.post(
"http://localhost:8000/scored_data",
json=large_scored_data,
)
assert post_response.status_code == 200
# Get batch (should be large and compressed)
response = requests.get("http://localhost:8000/batch")
assert response.status_code == 200
# The requests library automatically decompresses, so we get the data directly
data = response.json()
assert "batch" in data
assert data["batch"] is not None
# Verify we got the data back correctly (automatic decompression worked)
batch = data["batch"][0]
assert len(batch["tokens"]) == 16
assert len(batch["tokens"][0]) == 512
assert batch["tokens"][0][0] == 0 # First token should be 0
assert batch["tokens"][0][0] == 0
def test_compression_with_raw_headers(self, api_server):
"""Test compression using raw HTTP headers to verify server is actually compressing."""
# Register trainer
"""Explicit Accept-Encoding still yields usable decoded responses."""
requests.post(
"http://localhost:8000/register",
json={
@ -178,33 +153,25 @@ class TestAPICompression:
}
requests.post("http://localhost:8000/scored_data", json=large_scored_data)
# Make request with explicit Accept-Encoding and get raw response
session = requests.Session()
response = session.get(
"http://localhost:8000/batch",
headers={"Accept-Encoding": "gzip"},
stream=True
)
assert response.status_code == 200
# Check if the response was actually compressed by the server
# Note: requests automatically decompresses, but we can check the headers
# If compression happened, the raw response should have been compressed
# Get the actual response content
data = response.json()
assert "batch" in data
# Verify the data is correct (decompression worked automatically)
if data["batch"] is not None:
batch = data["batch"][0]
assert "tokens" in batch
assert len(batch["tokens"]) > 0
def test_compression_ratio_estimation(self, api_server):
"""Test to estimate actual compression ratio achieved."""
# Register trainer
"""Produce a rough before/after size estimate for visibility."""
requests.post(
"http://localhost:8000/register",
json={
@ -219,41 +186,33 @@ class TestAPICompression:
},
)
# Create large scored data
large_scored_data = {
"tokens": [[i for i in range(1024)] for _ in range(32)], # ~32K tokens
"tokens": [[i for i in range(1024)] for _ in range(32)],
"masks": [[1 for _ in range(1024)] for _ in range(32)],
"scores": [0.5 for _ in range(32)],
"advantages": [[0.1 for _ in range(1024)] for _ in range(32)],
}
# Post the data
requests.post("http://localhost:8000/scored_data", json=large_scored_data)
# Get the batch
response = requests.get("http://localhost:8000/batch")
assert response.status_code == 200
data = response.json()
# Calculate uncompressed size (rough estimate from JSON string)
uncompressed_json = json.dumps(data)
uncompressed_size = len(uncompressed_json.encode('utf-8'))
# The actual transmitted size would be much smaller due to gzip
# We can't easily measure it with requests (auto-decompresses)
# but we can verify the data is correct
assert data["batch"] is not None
batch = data["batch"][0]
assert len(batch["tokens"]) == 32
assert len(batch["tokens"][0]) == 1024
print(f"\nEstimated uncompressed size: {uncompressed_size:,} bytes")
print(f"With gzip compression, actual transfer would be ~15-20% of this size")
print("With gzip compression, actual transfer would be ~15-20% of this size")
def test_environment_client_compatibility(self, api_server):
"""Test that the compression works with typical environment usage patterns."""
# Register trainer
"""Simulate the common trainer + env flow."""
requests.post(
"http://localhost:8000/register",
json={
@ -268,10 +227,8 @@ class TestAPICompression:
},
)
# Trainer needs to call /batch first to mark as started
requests.get("http://localhost:8000/batch")
# Register environment
env_response = requests.post(
"http://localhost:8000/register-env",
json={
@ -286,23 +243,19 @@ class TestAPICompression:
assert env_data["status"] == "success"
env_id = env_data["env_id"]
# Post scored data as environment would
scored_data = {
"tokens": [[i for i in range(256)] for _ in range(4)],
"masks": [[1 for _ in range(256)] for _ in range(4)],
"scores": [0.8, 0.6, 0.4, 0.2],
"env_id": env_id,
}
# This should work without any client changes
post_response = requests.post(
"http://localhost:8000/scored_data",
json=scored_data,
)
assert post_response.status_code == 200
# Verify environment can check status
# Note: The API expects env_id as JSON body in GET request
status_response = requests.get(
"http://localhost:8000/status-env",
json={"env_id": env_id}
@ -312,7 +265,7 @@ class TestAPICompression:
assert "queue_size" in status_data
def test_server_accepts_gzipped_scored_data(self, api_server):
"""Ensure server middleware handles gzip-compressed request bodies."""
"""Server inflates gzipped POST bodies."""
requests.post(
"http://localhost:8000/register",
json={
@ -355,14 +308,13 @@ class TestAPICompression:
assert batch_data["batch"] is not None
def test_scored_data_list_compression(self, api_server):
"""Test that scored_data_list endpoint also benefits from compression."""
# Register trainer
"""Multi-item submissions still round-trip correctly."""
requests.post(
"http://localhost:8000/register",
json={
"wandb_group": "test_group",
"wandb_project": "test_project",
"batch_size": 32, # 4 groups * 8 sequences = 32 total
"batch_size": 32,
"max_token_len": 2048,
"checkpoint_dir": "/tmp/test",
"save_checkpoint_interval": 100,
@ -371,36 +323,33 @@ class TestAPICompression:
},
)
# Post multiple scored data items at once (as list)
scored_data_list = [
{
"tokens": [[i for i in range(512)] for _ in range(8)],
"masks": [[1 for _ in range(512)] for _ in range(8)],
"scores": [0.5 for _ in range(8)],
}
for _ in range(4) # 4 groups of 8 sequences each = 32 sequences total
for _ in range(4)
]
response = requests.post(
"http://localhost:8000/scored_data_list",
json=scored_data_list,
)
assert response.status_code == 200
data = response.json()
assert data["status"] == "received"
assert data["groups_processed"] == 4
# Verify we can get batches
batch_response = requests.get("http://localhost:8000/batch")
assert batch_response.status_code == 200
batch_data = batch_response.json()
assert batch_data["batch"] is not None
# Verify the batch contains the correct data
batch = batch_data["batch"]
total_sequences = sum(len(item["tokens"]) for item in batch)
assert total_sequences == 32 # Should have all 32 sequences
assert total_sequences == 32
if __name__ == "__main__":