mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-25 17:10:42 +00:00
Update server.py
This commit is contained in:
parent
a5a8b07848
commit
97107ca868
1 changed files with 0 additions and 54 deletions
|
|
@ -404,61 +404,7 @@ async def get_latest_example():
|
|||
|
||||
|
||||
@app.post("/scored_data")
|
||||
async def scored_data(scored_data: ScoredData):
|
||||
data_dict = {
|
||||
"tokens": scored_data.tokens,
|
||||
"masks": scored_data.masks,
|
||||
"scores": scored_data.scores,
|
||||
"advantages": scored_data.advantages,
|
||||
"ref_logprobs": scored_data.ref_logprobs,
|
||||
"messages": scored_data.messages,
|
||||
"generation_params": scored_data.generation_params,
|
||||
"inference_logprobs": scored_data.inference_logprobs,
|
||||
"overrides": scored_data.overrides,
|
||||
"group_overrides": scored_data.group_overrides,
|
||||
"images": scored_data.images,
|
||||
"env_id": scored_data.env_id,
|
||||
}
|
||||
|
||||
# Check if this is a mixed-size group
|
||||
env_id = scored_data.env_id
|
||||
if env_id is not None and env_id < len(app.state.envs):
|
||||
expected_group_size = app.state.envs[env_id].get("group_size", 1)
|
||||
actual_group_size = len(scored_data.tokens)
|
||||
|
||||
if actual_group_size != expected_group_size:
|
||||
# Mixed size group - add to buffer
|
||||
if env_id not in app.state.buffer:
|
||||
app.state.buffer[env_id] = []
|
||||
|
||||
app.state.buffer[env_id].append(data_dict)
|
||||
|
||||
# Try to find groups that sum to expected_group_size
|
||||
indices = find_groups_summing_to_target(
|
||||
app.state.buffer[env_id], expected_group_size
|
||||
)
|
||||
|
||||
if indices:
|
||||
# Add these groups to queue in order
|
||||
groups_to_add = []
|
||||
for idx in sorted(indices, reverse=True):
|
||||
groups_to_add.append(app.state.buffer[env_id].pop(idx))
|
||||
|
||||
# Add in FIFO order
|
||||
for group in reversed(groups_to_add):
|
||||
app.state.queue.append(group)
|
||||
app.state.latest = group
|
||||
|
||||
return {
|
||||
"status": "buffered",
|
||||
"buffer_size": sum(
|
||||
len(g["tokens"]) for g in app.state.buffer.get(env_id, [])
|
||||
),
|
||||
}
|
||||
|
||||
# Normal path - correct size or no env info
|
||||
app.state.queue.append(data_dict)
|
||||
app.state.latest = data_dict
|
||||
return {"status": "received"}
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue