diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 00000000..dc612fdd --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,26 @@ +# Agent Guidelines for Atropos + +## Commands +- **Run all tests**: `pytest` +- **Run single test**: `pytest atroposlib/tests/test_file.py::test_function_name` +- **Lint/format**: `pre-commit run --all-files` (runs black, ruff, flake8) +- **Manual formatting**: `black .` and `isort .` + +## Code Style +- **Formatting**: Black (line length 120), enforced via pre-commit hooks +- **Linting**: Flake8 with `--max-line-length=120 --extend-ignore=E203,W503` +- **Import order**: Standard library → third-party → local, sorted by ruff/isort +- **Typing**: Use type hints from `typing` module (Dict, List, Optional, Tuple, etc.) +- **Classes**: Inherit from Pydantic BaseModel or TypedDict for data structures +- **Naming**: snake_case for functions/variables, PascalCase for classes +- **Docstrings**: Use for public methods; include description and parameter/return types +- **Error handling**: Use tenacity for retries, logging for debug info +- **Async**: Use async/await for I/O operations; asyncio for concurrency + +## Project Structure +- Core library: `atroposlib/` - base classes, API server, utilities +- Environments: `environments/` - training environments; community contributions go in `environments/community/` +- New environments should use direct imports from their directory root + +## Before Committing +Always run `pre-commit run --all-files` to ensure code passes black, isort, flake8, and other checks. 
class GZipRequestMiddleware:
    """ASGI middleware that inflates gzip-encoded HTTP request bodies.

    A request whose ``Content-Encoding`` header mentions ``gzip`` is buffered
    in full, decompressed, and replayed to the wrapped application as a single
    ``http.request`` message. ``Content-Length`` is rewritten to the inflated
    size and ``Content-Encoding`` is removed so downstream code sees an
    ordinary, uncompressed request. Non-HTTP scopes and requests without a
    gzip content encoding are passed through untouched.

    NOTE(review): the payload is inflated entirely in memory with no upper
    bound on the decompressed size; consider enforcing a cap if this server
    is ever exposed to untrusted clients.
    """

    def __init__(self, app):
        """Store the wrapped ASGI application."""
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send):
        """Handle one ASGI connection.

        Args:
            scope: ASGI connection scope; its headers are rewritten in place
                when the body is decompressed.
            receive: ASGI receive callable of the incoming connection.
            send: ASGI send callable of the incoming connection.

        Responds with ``400 Invalid gzip payload`` (without calling the
        wrapped app) when the body cannot be decompressed.
        """
        # Local import: only the error type is needed, and the module's
        # top-level import block is left untouched by this change.
        import zlib

        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        headers = MutableHeaders(scope=scope)
        if "gzip" not in headers.get("content-encoding", "").lower():
            await self.app(scope, receive, send)
            return

        # Drain the whole compressed body before touching it.
        body_chunks = []
        more_body = True
        while more_body:
            message = await receive()
            if message.get("type") == "http.disconnect":
                # The client went away mid-upload: there is nobody to answer
                # and the partial body must not be treated as complete.
                return
            body_chunks.append(message.get("body", b""))
            more_body = message.get("more_body", False)

        body = b"".join(body_chunks)
        decompressed = b""
        if body:
            try:
                decompressed = gzip.decompress(body)
            except (OSError, EOFError, zlib.error):
                # gzip.decompress signals a bad header/CRC with an OSError
                # subclass, a truncated stream with EOFError, and corrupt
                # deflate data with zlib.error; all are client errors.
                response = PlainTextResponse(
                    "Invalid gzip payload",
                    status_code=status.HTTP_400_BAD_REQUEST,
                )
                await response(scope, receive, send)
                return

        headers["content-length"] = str(len(decompressed))
        if "content-encoding" in headers:
            del headers["content-encoding"]

        body_delivered = False

        async def replay_receive():
            """Yield the inflated body once, then defer to the real channel."""
            nonlocal body_delivered
            if body_delivered:
                # Later calls must observe genuine events (e.g. a client
                # disconnect) instead of an endless stream of empty bodies.
                return await receive()
            body_delivered = True
            return {
                "type": "http.request",
                "body": decompressed,
                "more_body": False,
            }

        await self.app(scope, replay_receive, send)
# Address of the API server launched by the ``api_server`` fixture.
BASE_URL = "http://localhost:8000"
# Upper bound (seconds) for any single HTTP call so a stuck server fails the
# test instead of hanging the whole run.
REQUEST_TIMEOUT = 10


def _register_trainer(batch_size):
    """Register a trainer with the given batch size and return the response."""
    return requests.post(
        f"{BASE_URL}/register",
        json={
            "wandb_group": "test_group",
            "wandb_project": "test_project",
            "batch_size": batch_size,
            "max_token_len": 2048,
            "checkpoint_dir": "/tmp/test",
            "save_checkpoint_interval": 100,
            "starting_step": 0,
            "num_steps": 1000,
        },
        timeout=REQUEST_TIMEOUT,
    )


def _reset_server_data():
    """Best-effort call to ``/reset_data``; HTTP failures are ignored."""
    try:
        requests.get(f"{BASE_URL}/reset_data", timeout=REQUEST_TIMEOUT)
    except requests.exceptions.RequestException:
        pass


def _stop_process_group(proc):
    """Terminate the server's whole process group and reap the child."""
    try:
        os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
    except ProcessLookupError:
        # The process (group) is already gone.
        pass
    try:
        proc.wait(timeout=REQUEST_TIMEOUT)
    except subprocess.TimeoutExpired:
        # SIGTERM was ignored; force the group down so nothing leaks.
        os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
        proc.wait()


def wait_for_api_server(max_wait=10):
    """Poll ``/info`` once per second until it answers 200.

    Args:
        max_wait: Maximum number of polling attempts.

    Returns:
        True as soon as the server answers with status 200, False if it never
        does within ``max_wait`` attempts.
    """
    for _ in range(max_wait):
        try:
            response = requests.get(f"{BASE_URL}/info", timeout=1)
            if response.status_code == 200:
                return True
        except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
            # Not listening yet (or not answering in time); retry.
            pass
        time.sleep(1)
    return False


@pytest.fixture(scope="module")
def api_server():
    """Launch the API server as a subprocess for the duration of the module."""
    import sys  # local import: needed only to locate the running interpreter

    proc = subprocess.Popen(
        [
            sys.executable,  # same interpreter/venv as the test run
            "-m",
            "atroposlib.cli.run_api",
            "--host",
            "localhost",
            "--port",
            "8000",
        ],
        # Nothing reads the server's output, so discard it; an unread PIPE
        # can fill up and block the child.
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
        start_new_session=True,  # new process group, so children die with it
    )

    if not wait_for_api_server():
        _stop_process_group(proc)
        raise RuntimeError("API server failed to start")

    yield

    _stop_process_group(proc)


@pytest.fixture(autouse=True)
def reset_api_state():
    """Reset API state before and after each test."""
    _reset_server_data()
    yield
    _reset_server_data()


class TestAPICompression:
    """Test class for API compression functionality."""

    def test_small_response_not_compressed(self, api_server):
        """Small responses (< 1KB) are not compressed."""
        response = requests.get(f"{BASE_URL}/info", timeout=REQUEST_TIMEOUT)

        assert response.status_code == 200, response.text

        # Small responses should not be compressed
        # (FastAPI GZip middleware has minimum_size=1000)
        assert response.headers.get("Content-Encoding") != "gzip"

        # But the response should still be valid JSON
        data = response.json()
        assert "batch_size" in data

    def test_large_response_compressed_automatically(self, api_server):
        """Large responses are compressed by the server and decoded by the client."""
        _register_trainer(batch_size=16)  # matches the number of sequences sent

        # Create large scored data (should exceed 1KB)
        large_scored_data = {
            "tokens": [list(range(512)) for _ in range(16)],
            "masks": [[1] * 512 for _ in range(16)],
            "scores": [0.5] * 16,
            "advantages": [[0.1] * 512 for _ in range(16)],
            "ref_logprobs": [[0.2] * 512 for _ in range(16)],
            "messages": [[{"role": "user", "content": "test" * 50}] for _ in range(16)],
        }

        post_response = requests.post(
            f"{BASE_URL}/scored_data",
            json=large_scored_data,
            timeout=REQUEST_TIMEOUT,
        )
        assert post_response.status_code == 200

        response = requests.get(f"{BASE_URL}/batch", timeout=REQUEST_TIMEOUT)
        assert response.status_code == 200

        # requests keeps the original headers even though it decodes the body,
        # so this proves the server really compressed the payload.
        assert response.headers.get("Content-Encoding") == "gzip"

        data = response.json()
        assert "batch" in data
        assert data["batch"] is not None

        # Verify we got the data back correctly (automatic decompression worked)
        batch = data["batch"][0]
        assert len(batch["tokens"]) == 16
        assert len(batch["tokens"][0]) == 512
        assert batch["tokens"][0][0] == 0  # First token should be 0

    def test_compression_with_raw_headers(self, api_server):
        """An explicit ``Accept-Encoding: gzip`` request gets a gzip-encoded batch."""
        # Batch size equals the number of sequences posted below so that
        # /batch has a full (large) batch to return.
        _register_trainer(batch_size=16)

        large_scored_data = {
            "tokens": [list(range(512)) for _ in range(16)],
            "masks": [[1] * 512 for _ in range(16)],
            "scores": [0.5] * 16,
        }
        requests.post(
            f"{BASE_URL}/scored_data",
            json=large_scored_data,
            timeout=REQUEST_TIMEOUT,
        )

        with requests.Session() as session:
            response = session.get(
                f"{BASE_URL}/batch",
                headers={"Accept-Encoding": "gzip"},
                timeout=REQUEST_TIMEOUT,
            )

        assert response.status_code == 200
        assert response.headers.get("Content-Encoding") == "gzip"

        data = response.json()
        assert "batch" in data
        assert data["batch"] is not None

        batch = data["batch"][0]
        assert "tokens" in batch
        assert len(batch["tokens"]) > 0

    def test_compression_ratio_estimation(self, api_server):
        """Report how well a realistic batch payload compresses."""
        _register_trainer(batch_size=32)

        large_scored_data = {
            "tokens": [list(range(1024)) for _ in range(32)],  # ~32K tokens
            "masks": [[1] * 1024 for _ in range(32)],
            "scores": [0.5] * 32,
            "advantages": [[0.1] * 1024 for _ in range(32)],
        }
        requests.post(
            f"{BASE_URL}/scored_data",
            json=large_scored_data,
            timeout=REQUEST_TIMEOUT,
        )

        response = requests.get(f"{BASE_URL}/batch", timeout=REQUEST_TIMEOUT)
        assert response.status_code == 200
        data = response.json()

        assert data["batch"] is not None
        batch = data["batch"][0]
        assert len(batch["tokens"]) == 32
        assert len(batch["tokens"][0]) == 1024

        # requests hides the on-the-wire size, so estimate the ratio by
        # re-serializing and compressing the decoded payload locally.
        uncompressed = json.dumps(data).encode("utf-8")
        compressed = gzip.compress(uncompressed)
        assert len(compressed) < len(uncompressed)

        print(f"\nEstimated uncompressed size: {len(uncompressed):,} bytes")
        print(
            f"Locally gzip-compressed size: {len(compressed):,} bytes "
            f"({len(compressed) / len(uncompressed):.1%} of original)"
        )

    def test_environment_client_compatibility(self, api_server):
        """Compression works with typical environment usage patterns."""
        _register_trainer(batch_size=32)

        # Trainer needs to call /batch first to mark as started
        requests.get(f"{BASE_URL}/batch", timeout=REQUEST_TIMEOUT)

        env_response = requests.post(
            f"{BASE_URL}/register-env",
            json={
                "max_token_length": 2048,
                "desired_name": "test_env",
                "weight": 1.0,
                "group_size": 4,
            },
            timeout=REQUEST_TIMEOUT,
        )
        assert env_response.status_code == 200
        env_data = env_response.json()
        assert env_data["status"] == "success"
        env_id = env_data["env_id"]

        # Post scored data as an environment would
        scored_data = {
            "tokens": [list(range(256)) for _ in range(4)],
            "masks": [[1] * 256 for _ in range(4)],
            "scores": [0.8, 0.6, 0.4, 0.2],
            "env_id": env_id,
        }

        # This should work without any client changes
        post_response = requests.post(
            f"{BASE_URL}/scored_data",
            json=scored_data,
            timeout=REQUEST_TIMEOUT,
        )
        assert post_response.status_code == 200

        # Note: The API expects env_id as JSON body in GET request
        status_response = requests.get(
            f"{BASE_URL}/status-env",
            json={"env_id": env_id},
            timeout=REQUEST_TIMEOUT,
        )
        assert status_response.status_code == 200
        status_data = status_response.json()
        assert "queue_size" in status_data

    def test_server_accepts_gzipped_scored_data(self, api_server):
        """Server middleware handles gzip-compressed request bodies."""
        _register_trainer(batch_size=8)

        scored_data = {
            "tokens": [list(range(256)) for _ in range(8)],
            "masks": [[1] * 256 for _ in range(8)],
            "scores": [0.1] * 8,
        }
        compressed = gzip.compress(json.dumps(scored_data).encode("utf-8"))

        response = requests.post(
            f"{BASE_URL}/scored_data",
            data=compressed,
            headers={
                "Content-Type": "application/json",
                "Content-Encoding": "gzip",
            },
            timeout=REQUEST_TIMEOUT,
        )

        assert response.status_code == 200
        response_data = response.json()
        assert response_data["status"] == "received"

        batch_response = requests.get(f"{BASE_URL}/batch", timeout=REQUEST_TIMEOUT)
        assert batch_response.status_code == 200
        batch_data = batch_response.json()
        assert batch_data["batch"] is not None

    def test_scored_data_list_compression(self, api_server):
        """The scored_data_list endpoint round-trips a multi-group payload."""
        _register_trainer(batch_size=32)  # 4 groups * 8 sequences = 32 total

        # Post multiple scored data items at once (as list)
        scored_data_list = [
            {
                "tokens": [list(range(512)) for _ in range(8)],
                "masks": [[1] * 512 for _ in range(8)],
                "scores": [0.5] * 8,
            }
            for _ in range(4)  # 4 groups of 8 sequences each = 32 sequences total
        ]

        response = requests.post(
            f"{BASE_URL}/scored_data_list",
            json=scored_data_list,
            timeout=REQUEST_TIMEOUT,
        )

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "received"
        assert data["groups_processed"] == 4

        batch_response = requests.get(f"{BASE_URL}/batch", timeout=REQUEST_TIMEOUT)
        assert batch_response.status_code == 200
        batch_data = batch_response.json()
        assert batch_data["batch"] is not None

        # Verify the batch contains the correct data
        batch = batch_data["batch"]
        total_sequences = sum(len(item["tokens"]) for item in batch)
        assert total_sequences == 32  # Should have all 32 sequences


if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])