diff --git a/atroposlib/api/README.md b/atroposlib/api/README.md index 3e719ee3..6430ab0a 100644 --- a/atroposlib/api/README.md +++ b/atroposlib/api/README.md @@ -129,9 +129,12 @@ The API documentation (Swagger UI) will be available at `http:// tokens: List[List[int]] masks: List[List[int]] scores: List[float] + advantages: Optional[List[List[float]]] = None ref_logprobs: Optional[List[List[float]]] = None + messages: Optional[List[List[Message]]] = None overrides: Optional[List[dict]] = None # Per-item logging overrides group_overrides: Optional[dict] = None # Group logging overrides + images: Optional[Any] = None # Image data (if applicable) ``` * **Response:** `{"status": "received"}` * `POST /scored_data_list` @@ -145,7 +148,7 @@ The API documentation (Swagger UI) will be available at `http:// * Not enough data: `{"batch": null}` * `GET /latest_example` * **Description:** Debug endpoint to retrieve the most recently added `ScoredData` item. - * **Response:** The last `ScoredData` dictionary pushed, or empty lists if none yet. + * **Response:** The last `ScoredData` dictionary pushed, or empty lists for tokens, masks, scores, advantages, ref_logprobs, messages, and images if none yet. ### Debugging diff --git a/atroposlib/api/server.py b/atroposlib/api/server.py index a6dcb5fc..0250399b 100644 --- a/atroposlib/api/server.py +++ b/atroposlib/api/server.py @@ -8,6 +8,7 @@ from fastapi.responses import PlainTextResponse from pydantic import BaseModel from atroposlib.api.utils import grab_exact_from_heterogeneous_queue +from atroposlib.type_definitions import Message app = FastAPI(title="AtroposLib API") @@ -50,7 +51,9 @@ class ScoredData(BaseModel): tokens: List[List[int]] masks: List[List[int]] scores: List[float] + advantages: Optional[List[List[float]]] = None ref_logprobs: Optional[List[List[float]]] = None + messages: Optional[List[List[Message]]] = None overrides: Optional[List[dict]] = None group_overrides: Optional[dict] = None images: Optional[Any] = None @@ -212,7 +215,9 @@ async def get_latest_example(): "tokens": [], "masks": [], "scores": [], + "advantages": [], "ref_logprobs": [], + "messages": [], "images": [], } @@ -224,7 +229,9 @@ async def scored_data(scored_data: ScoredData): "tokens": scored_data.tokens, "masks": scored_data.masks, "scores": scored_data.scores, + "advantages": scored_data.advantages, "ref_logprobs": scored_data.ref_logprobs, + "messages": scored_data.messages, "overrides": scored_data.overrides, "group_overrides": scored_data.group_overrides, "images": scored_data.images, @@ -245,8 +252,10 @@ async def scored_data_list(scored_data_list: List[ScoredData]): "tokens": scored_data.tokens, "masks": scored_data.masks, "scores": scored_data.scores, + "advantages": scored_data.advantages, "ref_logprobs": scored_data.ref_logprobs, "images": scored_data.images, + "messages": scored_data.messages, "overrides": scored_data.overrides, "group_overrides": scored_data.group_overrides, } diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py index 03f0a962..595a8d83 100644 --- a/atroposlib/envs/base.py +++ b/atroposlib/envs/base.py @@ -60,6 +60,7 @@ class ScoredDataGroup(TypedDict): messages: Optional[List[List[Message]]] group_overrides: Optional[Dict] overrides: Optional[List[Dict]] + images: Optional[Any] class ScoredDataItem(TypedDict): @@ -71,6 +72,7 @@ class ScoredDataItem(TypedDict): messages: Optional[List[Message]] group_overrides: Optional[Dict] overrides: Optional[Dict] + images: Optional[Any] class EvalHandlingEnum(Enum): @@ -282,6 +284,7 @@ class BaseEnv(ABC): to_postprocess["messages"] = [] to_postprocess["group_overrides"] = {} to_postprocess["overrides"] = [] + to_postprocess["images"] = [] print("Processing results") for result in results: to_postprocess["tokens"].append(result[0]["tokens"]) @@ -297,6 +300,8 @@ class BaseEnv(ABC): to_postprocess["group_overrides"].update(result[0]["group_overrides"]) if result[0].get("overrides", None) is not None: to_postprocess["overrides"].append(result[0]["overrides"]) + if result[0].get("images", None) is not None: + to_postprocess["images"].append(result[0]["images"]) backlog.extend(result[1]) return to_postprocess, backlog