diff --git a/atroposlib/api/server.py b/atroposlib/api/server.py index 705dd391..85290fce 100644 --- a/atroposlib/api/server.py +++ b/atroposlib/api/server.py @@ -416,20 +416,15 @@ async def scored_data_list(scored_data_list: List[ScoredData]): buffered_count = 0 last_buffer_size: Optional[int] = None for scored_data in scored_data_list: - data_dict = { - "tokens": scored_data.tokens, - "masks": scored_data.masks, - "scores": scored_data.scores, - "advantages": scored_data.advantages, - "ref_logprobs": scored_data.ref_logprobs, - "images": scored_data.images, - "messages": scored_data.messages, - "generation_params": scored_data.generation_params, - "inference_logprobs": scored_data.inference_logprobs, - "overrides": scored_data.overrides, - "group_overrides": scored_data.group_overrides, - "env_id": scored_data.env_id, - } + result = _process_scored_data(scored_data) + if result.get("status") == "buffered": + buffered_count += 1 + last_buffer_size = result.get("buffer_size", last_buffer_size) + + response: Dict[str, Any] = { + "status": "received", + "groups_processed": len(scored_data_list), + } # Check if this is a mixed-size group env_id = scored_data.env_id