mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
gzip compression for atropos api
This commit is contained in:
parent
36243bd3f4
commit
baf4b2d8a8
4 changed files with 528 additions and 2 deletions
|
|
@ -1,11 +1,15 @@
|
|||
import gzip
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import FastAPI, status
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.middleware.gzip import GZipMiddleware
|
||||
from fastapi.responses import PlainTextResponse
|
||||
from pydantic import BaseModel, field_validator
|
||||
from starlette.datastructures import MutableHeaders
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
from atroposlib.api.utils import (
|
||||
find_groups_summing_to_target,
|
||||
|
|
@ -31,6 +35,70 @@ app.add_middleware(
|
|||
)
|
||||
|
||||
|
||||
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
||||
|
||||
|
||||
class GZipRequestMiddleware:
|
||||
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
||||
if scope["type"] != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
headers = MutableHeaders(scope=scope)
|
||||
content_encoding = headers.get("content-encoding", "")
|
||||
if "gzip" not in content_encoding.lower():
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
body_chunks = []
|
||||
more_body = True
|
||||
while more_body:
|
||||
message = await receive()
|
||||
body_chunks.append(message.get("body", b""))
|
||||
more_body = message.get("more_body", False)
|
||||
|
||||
body = b"".join(body_chunks)
|
||||
if body:
|
||||
try:
|
||||
decompressed = gzip.decompress(body)
|
||||
except OSError:
|
||||
response = PlainTextResponse(
|
||||
"Invalid gzip payload",
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
else:
|
||||
decompressed = b""
|
||||
|
||||
mutable_headers = MutableHeaders(scope=scope)
|
||||
mutable_headers["content-length"] = str(len(decompressed))
|
||||
if "content-encoding" in mutable_headers:
|
||||
del mutable_headers["content-encoding"]
|
||||
|
||||
sent = False
|
||||
|
||||
async def new_receive():
|
||||
nonlocal sent
|
||||
if sent:
|
||||
return {"type": "http.request", "body": b"", "more_body": False}
|
||||
sent = True
|
||||
return {
|
||||
"type": "http.request",
|
||||
"body": decompressed,
|
||||
"more_body": False,
|
||||
}
|
||||
|
||||
await self.app(scope, new_receive, send)
|
||||
|
||||
|
||||
app.add_middleware(GZipRequestMiddleware)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {"message": "AtroposLib API"}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue