Update server.py

This commit is contained in:
Nina 2025-11-07 19:00:32 +01:00 committed by GitHub
parent c96b8a1255
commit a5a8b07848
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -165,6 +165,64 @@ class ScoredData(BaseModel):
return v
def _scored_data_to_dict(scored_data: ScoredData) -> Dict[str, Any]:
"""Convert a `ScoredData` pydantic model into a plain dictionary."""
return {
"tokens": scored_data.tokens,
"masks": scored_data.masks,
"scores": scored_data.scores,
"advantages": scored_data.advantages,
"ref_logprobs": scored_data.ref_logprobs,
"messages": scored_data.messages,
"generation_params": scored_data.generation_params,
"inference_logprobs": scored_data.inference_logprobs,
"overrides": scored_data.overrides,
"group_overrides": scored_data.group_overrides,
"images": scored_data.images,
"env_id": scored_data.env_id,
}
def _process_scored_data(scored_data: ScoredData) -> Dict[str, Any]:
"""Normalize buffering/queueing logic for scored data submissions."""
if not hasattr(app.state, "queue"):
app.state.queue = []
if not hasattr(app.state, "buffer"):
app.state.buffer = {}
data_dict = _scored_data_to_dict(scored_data)
env_id = data_dict.get("env_id")
envs = getattr(app.state, "envs", [])
if env_id is not None and env_id < len(envs):
expected_group_size = envs[env_id].get("group_size", 1)
actual_group_size = len(scored_data.tokens)
if actual_group_size != expected_group_size:
buffer = app.state.buffer.setdefault(env_id, [])
buffer.append(data_dict)
indices = find_groups_summing_to_target(buffer, expected_group_size)
if indices:
groups_to_add = []
for idx in sorted(indices, reverse=True):
groups_to_add.append(buffer.pop(idx))
for group in reversed(groups_to_add):
app.state.queue.append(group)
app.state.latest = group
return {
"status": "buffered",
"buffer_size": sum(len(group["tokens"]) for group in app.state.buffer.get(env_id, [])),
}
app.state.queue.append(data_dict)
app.state.latest = data_dict
return {"status": "received"}
class Status(BaseModel):
"""