Add group-level temperature (generation_params), per-sample temperature (overrides), and inference logprob API params

This commit is contained in:
ropresearch 2025-09-25 16:41:58 -04:00
parent efc6b55f0a
commit c3fc68879c
4 changed files with 65 additions and 12 deletions

View file

@ -140,12 +140,16 @@ class TestAPIMessagesHandling:
"messages": [messages],
"advantages": [[0.5, 0.5, 0.5, 0.5, 0.5]],
"ref_logprobs": [[-0.1, -0.2, -0.3, -0.4, -0.5]],
"generation_params": {"temperature": 0.7},
},
)
if response.status_code != 200:
print(f"Error response: {response.text}")
assert response.status_code == 200
assert response.json()["status"] == "received"
# batches
latest = requests.get("http://localhost:8000/latest_example").json()
assert latest.get("generation_params", {}).get("temperature") == 0.7
def test_scored_data_list_with_messages(self, api_server):
"""Test posting a list of scored data with messages."""
@ -307,20 +311,25 @@ class TestAPIMessagesHandling:
)
assert register_response.status_code == 200
# Post two items to make a full batch
for i in range(2):
messages = [
{"role": "user", "content": f"Test message {i}", "reward": None},
{"role": "assistant", "content": f"Response {i}", "reward": None},
]
payload = {
"tokens": [[i * 10 + j for j in range(5)]],
"masks": [[1] * 5],
"scores": [float(i)],
"messages": [messages],
}
if i == 0:
payload["overrides"] = [{"temperature": 0.5}]
else:
payload["generation_params"] = {"temperature": 0.8}
response = requests.post(
"http://localhost:8000/scored_data",
json={
"tokens": [[i * 10 + j for j in range(5)]],
"masks": [[1] * 5],
"scores": [float(i)],
"messages": [messages],
},
json=payload,
)
assert response.status_code == 200
@ -339,6 +348,13 @@ class TestAPIMessagesHandling:
assert len(item["messages"][0]) == 2
assert item["messages"][0][0]["role"] == "user"
assert item["messages"][0][1]["role"] == "assistant"
# temp passthroughs
if i == 0:
assert item.get("overrides") is not None
assert item["overrides"][0].get("temperature") == 0.5
else:
assert item.get("generation_params") is not None
assert item["generation_params"].get("temperature") == 0.8
def test_latest_example_with_messages(self, api_server):
"""Test that latest example endpoint includes messages."""
@ -376,6 +392,7 @@ class TestAPIMessagesHandling:
"masks": [[1, 1, 1]],
"scores": [0.95],
"messages": [messages],
"inference_logprobs": [[-1.0, -0.7, -0.2]],
},
)
if response.status_code != 200:
@ -390,6 +407,7 @@ class TestAPIMessagesHandling:
assert "messages" in latest_data
assert latest_data["messages"] == [messages]
assert len(latest_data["messages"][0]) == 3
assert latest_data.get("inference_logprobs") == [[-1.0, -0.7, -0.2]]
def test_empty_messages_handling(self, api_server):
"""Test handling of empty or None messages."""