mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-25 17:10:42 +00:00
Enhance ScoredData model and API documentation
- Added optional fields: advantages, messages, and images to the ScoredData model. - Updated API responses to include these new fields when no data is available. - Revised README.md to reflect changes in the API structure and response format.
This commit is contained in:
parent
46a43a89bf
commit
4a21ed0891
3 changed files with 18 additions and 1 deletions
|
|
@ -129,9 +129,12 @@ The API documentation (Swagger UI) will be available at `http://<your-server-ip>
|
|||
tokens: List[List[int]]
|
||||
masks: List[List[int]]
|
||||
scores: List[float]
|
||||
advantages: Optional[List[List[float]]] = None
|
||||
ref_logprobs: Optional[List[List[float]]] = None
|
||||
messages: Optional[List[List[Message]]] = None
|
||||
overrides: Optional[List[dict]] = None # Per-item logging overrides
|
||||
group_overrides: Optional[dict] = None # Group logging overrides
|
||||
images: Optional[Any] = None # Image data (if applicable)
|
||||
```
|
||||
* **Response:** `{"status": "received"}`
|
||||
* `POST /scored_data_list`
|
||||
|
|
@ -145,7 +148,7 @@ The API documentation (Swagger UI) will be available at `http://<your-server-ip>
|
|||
* Not enough data: `{"batch": null}`
|
||||
* `GET /latest_example`
|
||||
* **Description:** Debug endpoint to retrieve the most recently added `ScoredData` item.
|
||||
* **Response:** The last `ScoredData` dictionary pushed, or empty lists if none yet.
|
||||
* **Response:** The last `ScoredData` dictionary pushed, or empty lists for tokens, masks, scores, advantages, ref_logprobs, messages, and images if none yet.
|
||||
|
||||
### Debugging
|
||||
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from fastapi.responses import PlainTextResponse
|
|||
from pydantic import BaseModel
|
||||
|
||||
from atroposlib.api.utils import grab_exact_from_heterogeneous_queue
|
||||
from atroposlib.type_definitions import Message
|
||||
|
||||
app = FastAPI(title="AtroposLib API")
|
||||
|
||||
|
|
@ -50,7 +51,9 @@ class ScoredData(BaseModel):
|
|||
tokens: List[List[int]]
|
||||
masks: List[List[int]]
|
||||
scores: List[float]
|
||||
advantages: Optional[List[List[float]]] = None
|
||||
ref_logprobs: Optional[List[List[float]]] = None
|
||||
messages: Optional[List[List[Message]]] = None
|
||||
overrides: Optional[List[dict]] = None
|
||||
group_overrides: Optional[dict] = None
|
||||
images: Optional[Any] = None
|
||||
|
|
@ -212,7 +215,9 @@ async def get_latest_example():
|
|||
"tokens": [],
|
||||
"masks": [],
|
||||
"scores": [],
|
||||
"advantages": [],
|
||||
"ref_logprobs": [],
|
||||
"messages": [],
|
||||
"images": [],
|
||||
}
|
||||
|
||||
|
|
@ -224,7 +229,9 @@ async def scored_data(scored_data: ScoredData):
|
|||
"tokens": scored_data.tokens,
|
||||
"masks": scored_data.masks,
|
||||
"scores": scored_data.scores,
|
||||
"advantages": scored_data.advantages,
|
||||
"ref_logprobs": scored_data.ref_logprobs,
|
||||
"messages": scored_data.messages,
|
||||
"overrides": scored_data.overrides,
|
||||
"group_overrides": scored_data.group_overrides,
|
||||
"images": scored_data.images,
|
||||
|
|
@ -245,8 +252,10 @@ async def scored_data_list(scored_data_list: List[ScoredData]):
|
|||
"tokens": scored_data.tokens,
|
||||
"masks": scored_data.masks,
|
||||
"scores": scored_data.scores,
|
||||
"advantages": scored_data.advantages,
|
||||
"ref_logprobs": scored_data.ref_logprobs,
|
||||
"images": scored_data.images,
|
||||
"messages": scored_data.messages,
|
||||
"overrides": scored_data.overrides,
|
||||
"group_overrides": scored_data.group_overrides,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -60,6 +60,7 @@ class ScoredDataGroup(TypedDict):
|
|||
messages: Optional[List[List[Message]]]
|
||||
group_overrides: Optional[Dict]
|
||||
overrides: Optional[List[Dict]]
|
||||
images: Optional[Any]
|
||||
|
||||
|
||||
class ScoredDataItem(TypedDict):
|
||||
|
|
@ -71,6 +72,7 @@ class ScoredDataItem(TypedDict):
|
|||
messages: Optional[List[Message]]
|
||||
group_overrides: Optional[Dict]
|
||||
overrides: Optional[Dict]
|
||||
images: Optional[Any]
|
||||
|
||||
|
||||
class EvalHandlingEnum(Enum):
|
||||
|
|
@ -282,6 +284,7 @@ class BaseEnv(ABC):
|
|||
to_postprocess["messages"] = []
|
||||
to_postprocess["group_overrides"] = {}
|
||||
to_postprocess["overrides"] = []
|
||||
to_postprocess["images"] = []
|
||||
print("Processing results")
|
||||
for result in results:
|
||||
to_postprocess["tokens"].append(result[0]["tokens"])
|
||||
|
|
@ -297,6 +300,8 @@ class BaseEnv(ABC):
|
|||
to_postprocess["group_overrides"].update(result[0]["group_overrides"])
|
||||
if result[0].get("overrides", None) is not None:
|
||||
to_postprocess["overrides"].append(result[0]["overrides"])
|
||||
if result[0].get("images", None) is not None:
|
||||
to_postprocess["images"].append(result[0]["images"])
|
||||
backlog.extend(result[1])
|
||||
return to_postprocess, backlog
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue