gzip compression for atropos api

This commit is contained in:
ropresearch 2025-10-10 01:26:52 -04:00
parent 36243bd3f4
commit baf4b2d8a8
4 changed files with 528 additions and 2 deletions

View file

@ -1,11 +1,15 @@
import gzip
import time
import uuid
from typing import Any, Dict, List, Optional
from fastapi import FastAPI, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.responses import PlainTextResponse
from pydantic import BaseModel, field_validator
from starlette.datastructures import MutableHeaders
from starlette.types import Receive, Scope, Send
from atroposlib.api.utils import (
find_groups_summing_to_target,
@ -31,6 +35,70 @@ app.add_middleware(
)
app.add_middleware(GZipMiddleware, minimum_size=1000)
class GZipRequestMiddleware:
def __init__(self, app):
self.app = app
async def __call__(self, scope: Scope, receive: Receive, send: Send):
if scope["type"] != "http":
await self.app(scope, receive, send)
return
headers = MutableHeaders(scope=scope)
content_encoding = headers.get("content-encoding", "")
if "gzip" not in content_encoding.lower():
await self.app(scope, receive, send)
return
body_chunks = []
more_body = True
while more_body:
message = await receive()
body_chunks.append(message.get("body", b""))
more_body = message.get("more_body", False)
body = b"".join(body_chunks)
if body:
try:
decompressed = gzip.decompress(body)
except OSError:
response = PlainTextResponse(
"Invalid gzip payload",
status_code=status.HTTP_400_BAD_REQUEST,
)
await response(scope, receive, send)
return
else:
decompressed = b""
mutable_headers = MutableHeaders(scope=scope)
mutable_headers["content-length"] = str(len(decompressed))
if "content-encoding" in mutable_headers:
del mutable_headers["content-encoding"]
sent = False
async def new_receive():
nonlocal sent
if sent:
return {"type": "http.request", "body": b"", "more_body": False}
sent = True
return {
"type": "http.request",
"body": decompressed,
"more_body": False,
}
await self.app(scope, new_receive, send)
app.add_middleware(GZipRequestMiddleware)
@app.get("/")
async def root():
return {"message": "AtroposLib API"}