mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
group temperatures, sample temperatures, and logprob API params
This commit is contained in:
parent
efc6b55f0a
commit
c3fc68879c
4 changed files with 65 additions and 12 deletions
|
|
@ -140,12 +140,16 @@ class TestAPIMessagesHandling:
|
|||
"messages": [messages],
|
||||
"advantages": [[0.5, 0.5, 0.5, 0.5, 0.5]],
|
||||
"ref_logprobs": [[-0.1, -0.2, -0.3, -0.4, -0.5]],
|
||||
"generation_params": {"temperature": 0.7},
|
||||
},
|
||||
)
|
||||
if response.status_code != 200:
|
||||
print(f"Error response: {response.text}")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["status"] == "received"
|
||||
# batches
|
||||
latest = requests.get("http://localhost:8000/latest_example").json()
|
||||
assert latest.get("generation_params", {}).get("temperature") == 0.7
|
||||
|
||||
def test_scored_data_list_with_messages(self, api_server):
|
||||
"""Test posting a list of scored data with messages."""
|
||||
|
|
@ -307,20 +311,25 @@ class TestAPIMessagesHandling:
|
|||
)
|
||||
assert register_response.status_code == 200
|
||||
|
||||
# Post two items to make a full batch
|
||||
for i in range(2):
|
||||
messages = [
|
||||
{"role": "user", "content": f"Test message {i}", "reward": None},
|
||||
{"role": "assistant", "content": f"Response {i}", "reward": None},
|
||||
]
|
||||
payload = {
|
||||
"tokens": [[i * 10 + j for j in range(5)]],
|
||||
"masks": [[1] * 5],
|
||||
"scores": [float(i)],
|
||||
"messages": [messages],
|
||||
}
|
||||
if i == 0:
|
||||
payload["overrides"] = [{"temperature": 0.5}]
|
||||
else:
|
||||
payload["generation_params"] = {"temperature": 0.8}
|
||||
|
||||
response = requests.post(
|
||||
"http://localhost:8000/scored_data",
|
||||
json={
|
||||
"tokens": [[i * 10 + j for j in range(5)]],
|
||||
"masks": [[1] * 5],
|
||||
"scores": [float(i)],
|
||||
"messages": [messages],
|
||||
},
|
||||
json=payload,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
|
@ -339,6 +348,13 @@ class TestAPIMessagesHandling:
|
|||
assert len(item["messages"][0]) == 2
|
||||
assert item["messages"][0][0]["role"] == "user"
|
||||
assert item["messages"][0][1]["role"] == "assistant"
|
||||
# temp passthroughs
|
||||
if i == 0:
|
||||
assert item.get("overrides") is not None
|
||||
assert item["overrides"][0].get("temperature") == 0.5
|
||||
else:
|
||||
assert item.get("generation_params") is not None
|
||||
assert item["generation_params"].get("temperature") == 0.8
|
||||
|
||||
def test_latest_example_with_messages(self, api_server):
|
||||
"""Test that latest example endpoint includes messages."""
|
||||
|
|
@ -376,6 +392,7 @@ class TestAPIMessagesHandling:
|
|||
"masks": [[1, 1, 1]],
|
||||
"scores": [0.95],
|
||||
"messages": [messages],
|
||||
"inference_logprobs": [[-1.0, -0.7, -0.2]],
|
||||
},
|
||||
)
|
||||
if response.status_code != 200:
|
||||
|
|
@ -390,6 +407,7 @@ class TestAPIMessagesHandling:
|
|||
assert "messages" in latest_data
|
||||
assert latest_data["messages"] == [messages]
|
||||
assert len(latest_data["messages"][0]) == 3
|
||||
assert latest_data.get("inference_logprobs") == [[-1.0, -0.7, -0.2]]
|
||||
|
||||
def test_empty_messages_handling(self, api_server):
|
||||
"""Test handling of empty or None messages."""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue