mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
Update server.py
This commit is contained in:
parent
c96b8a1255
commit
a5a8b07848
1 changed files with 58 additions and 0 deletions
|
|
@ -165,6 +165,64 @@ class ScoredData(BaseModel):
|
|||
|
||||
return v
|
||||
|
||||
def _scored_data_to_dict(scored_data: ScoredData) -> Dict[str, Any]:
|
||||
"""Convert a `ScoredData` pydantic model into a plain dictionary."""
|
||||
|
||||
return {
|
||||
"tokens": scored_data.tokens,
|
||||
"masks": scored_data.masks,
|
||||
"scores": scored_data.scores,
|
||||
"advantages": scored_data.advantages,
|
||||
"ref_logprobs": scored_data.ref_logprobs,
|
||||
"messages": scored_data.messages,
|
||||
"generation_params": scored_data.generation_params,
|
||||
"inference_logprobs": scored_data.inference_logprobs,
|
||||
"overrides": scored_data.overrides,
|
||||
"group_overrides": scored_data.group_overrides,
|
||||
"images": scored_data.images,
|
||||
"env_id": scored_data.env_id,
|
||||
}
|
||||
|
||||
|
||||
def _process_scored_data(scored_data: ScoredData) -> Dict[str, Any]:
|
||||
"""Normalize buffering/queueing logic for scored data submissions."""
|
||||
|
||||
if not hasattr(app.state, "queue"):
|
||||
app.state.queue = []
|
||||
if not hasattr(app.state, "buffer"):
|
||||
app.state.buffer = {}
|
||||
|
||||
data_dict = _scored_data_to_dict(scored_data)
|
||||
env_id = data_dict.get("env_id")
|
||||
envs = getattr(app.state, "envs", [])
|
||||
|
||||
if env_id is not None and env_id < len(envs):
|
||||
expected_group_size = envs[env_id].get("group_size", 1)
|
||||
actual_group_size = len(scored_data.tokens)
|
||||
|
||||
if actual_group_size != expected_group_size:
|
||||
buffer = app.state.buffer.setdefault(env_id, [])
|
||||
buffer.append(data_dict)
|
||||
|
||||
indices = find_groups_summing_to_target(buffer, expected_group_size)
|
||||
|
||||
if indices:
|
||||
groups_to_add = []
|
||||
for idx in sorted(indices, reverse=True):
|
||||
groups_to_add.append(buffer.pop(idx))
|
||||
|
||||
for group in reversed(groups_to_add):
|
||||
app.state.queue.append(group)
|
||||
app.state.latest = group
|
||||
|
||||
return {
|
||||
"status": "buffered",
|
||||
"buffer_size": sum(len(group["tokens"]) for group in app.state.buffer.get(env_id, [])),
|
||||
}
|
||||
|
||||
app.state.queue.append(data_dict)
|
||||
app.state.latest = data_dict
|
||||
return {"status": "received"}
|
||||
|
||||
class Status(BaseModel):
|
||||
"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue