Update server.py

This commit is contained in:
Nina 2025-11-07 19:01:09 +01:00 committed by GitHub
parent a5a8b07848
commit 97107ca868
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -404,61 +404,7 @@ async def get_latest_example():
@app.post("/scored_data")
async def scored_data(scored_data: ScoredData):
data_dict = {
"tokens": scored_data.tokens,
"masks": scored_data.masks,
"scores": scored_data.scores,
"advantages": scored_data.advantages,
"ref_logprobs": scored_data.ref_logprobs,
"messages": scored_data.messages,
"generation_params": scored_data.generation_params,
"inference_logprobs": scored_data.inference_logprobs,
"overrides": scored_data.overrides,
"group_overrides": scored_data.group_overrides,
"images": scored_data.images,
"env_id": scored_data.env_id,
}
# Check if this is a mixed-size group
env_id = scored_data.env_id
if env_id is not None and env_id < len(app.state.envs):
expected_group_size = app.state.envs[env_id].get("group_size", 1)
actual_group_size = len(scored_data.tokens)
if actual_group_size != expected_group_size:
# Mixed size group - add to buffer
if env_id not in app.state.buffer:
app.state.buffer[env_id] = []
app.state.buffer[env_id].append(data_dict)
# Try to find groups that sum to expected_group_size
indices = find_groups_summing_to_target(
app.state.buffer[env_id], expected_group_size
)
if indices:
# Add these groups to queue in order
groups_to_add = []
for idx in sorted(indices, reverse=True):
groups_to_add.append(app.state.buffer[env_id].pop(idx))
# Add in FIFO order
for group in reversed(groups_to_add):
app.state.queue.append(group)
app.state.latest = group
return {
"status": "buffered",
"buffer_size": sum(
len(g["tokens"]) for g in app.state.buffer.get(env_id, [])
),
}
# Normal path - correct size or no env info
app.state.queue.append(data_dict)
app.state.latest = data_dict
return {"status": "received"}