# atropos/atroposlib/tests/test_api_messages_handling.py
# (source listing: 1331 lines, 48 KiB, Python)

"""
Tests for API server message handling, particularly for SFT (Supervised Fine-Tuning) scenarios.
"""
import os
import signal
import subprocess
import sys
import tempfile
import time

import pytest
import requests
def wait_for_api_server(max_wait=10):
    """Poll the API server's ``/info`` endpoint until it answers with HTTP 200.

    Args:
        max_wait: Number of polling attempts, spaced one second apart
            (roughly the number of seconds to wait for the server).

    Returns:
        True as soon as ``/info`` returns 200; False if it never does within
        ``max_wait`` attempts.
    """
    for _ in range(max_wait):
        try:
            # Without a timeout, requests waits forever on a connection that
            # is accepted but never answered, which would stall startup.
            response = requests.get("http://localhost:8000/info", timeout=1)
            if response.status_code == 200:
                return True
        except requests.exceptions.RequestException:
            # Not listening yet (ConnectionError) or too slow to answer (Timeout).
            pass
        time.sleep(1)
    return False
@pytest.fixture(scope="module")
def api_server():
    """Launch the API server in a subprocess for the tests in this module.

    The server runs ``atroposlib.cli.run_api`` on localhost:8000 under the
    same interpreter as pytest (``sys.executable``), in a new session, so it
    leads its own process group and teardown can signal every process it
    spawned. Its combined stdout/stderr goes to a temporary file rather than
    to unread pipes. An unread ``subprocess.PIPE`` eventually fills and blocks
    the server; the file also lets a startup failure report the server output.

    Raises:
        RuntimeError: if ``/info`` never answers; the message includes the
            server's captured output.
    """
    log_file = tempfile.TemporaryFile()
    proc = subprocess.Popen(
        [
            sys.executable,
            "-m",
            "atroposlib.cli.run_api",
            "--host",
            "localhost",
            "--port",
            "8000",
        ],
        stdout=log_file,
        stderr=subprocess.STDOUT,
        # Thread-safe replacement for preexec_fn=os.setsid: the child calls
        # setsid(), so its pid is also its process-group id.
        start_new_session=True,
    )

    def _stop_server():
        """Signal the server's whole process group, escalating to SIGKILL."""
        for sig in (signal.SIGTERM, signal.SIGKILL):
            try:
                os.killpg(proc.pid, sig)
            except ProcessLookupError:
                pass  # every process in the group has already exited
            try:
                proc.wait(timeout=10)
                return
            except subprocess.TimeoutExpired:
                continue

    try:
        if not wait_for_api_server():
            _stop_server()
            log_file.seek(0)
            output = log_file.read().decode("utf-8", errors="replace")
            raise RuntimeError(f"API server failed to start. Server output:\n{output}")
        yield
        # Reset while the server is still up; once it is stopped this request
        # can no longer succeed.
        try:
            requests.get("http://localhost:8000/reset_data", timeout=10)
        except requests.exceptions.RequestException:
            pass  # best-effort cleanup; the server is about to be stopped
    finally:
        # returncode is set once the process has been reaped by _stop_server.
        if proc.returncode is None:
            _stop_server()
        log_file.close()
@pytest.fixture(autouse=True)
def reset_api_state(api_server):
    """Clear the API server's stored data before and after every test.

    Takes ``api_server`` explicitly so the server is guaranteed to be running
    before the first reset, instead of relying on every test to request it.
    A failed setup reset raises, because a test run on stale server state
    would be unreliable. The teardown reset is best-effort.
    """
    response = requests.get("http://localhost:8000/reset_data", timeout=10)
    response.raise_for_status()
    yield
    requests.get("http://localhost:8000/reset_data", timeout=10)
class TestAPIMessagesHandling:
    """Test class for API messages handling."""

    # Base URL of the server launched by the ``api_server`` fixture.
    BASE_URL = "http://localhost:8000"
    # Per-request timeout (seconds): requests waits forever without one, so a
    # wedged server would hang the run instead of failing the test.
    TIMEOUT = 10

    def _get(self, path):
        """Send a GET request for *path* to the test server."""
        return requests.get(f"{self.BASE_URL}{path}", timeout=self.TIMEOUT)

    def _post(self, path, payload):
        """POST *payload* as JSON to *path* on the test server."""
        return requests.post(
            f"{self.BASE_URL}{path}", json=payload, timeout=self.TIMEOUT
        )

    def _register(self, name, batch_size=1, max_token_len=512):
        """Register a trainer whose wandb group and project are both *name*.

        The remaining registration fields use the small values shared by the
        tests below. Registration must succeed for the test to continue.
        """
        response = self._post(
            "/register",
            {
                "wandb_group": name,
                "wandb_project": name,
                "batch_size": batch_size,
                "max_token_len": max_token_len,
                "checkpoint_dir": "/tmp",
                "save_checkpoint_interval": 10,
                "starting_step": 0,
                "num_steps": 100,
            },
        )
        assert response.status_code == 200, response.text

    def test_register_trainer(self, api_server):
        """Registering a trainer returns an integer ``uuid``."""
        response = self._post(
            "/register",
            {
                "wandb_group": "test_group",
                "wandb_project": "test_project",
                "batch_size": 32,
                "max_token_len": 2048,
                "checkpoint_dir": "/tmp/test_checkpoint",
                "save_checkpoint_interval": 100,
                "starting_step": 0,
                "num_steps": 1000,
            },
        )
        assert response.status_code == 200, response.text
        data = response.json()
        assert "uuid" in data
        assert isinstance(data["uuid"], int)

    def test_scored_data_with_messages(self, api_server):
        """OpenAI-format messages and generation params are accepted and stored."""
        self._register("test", batch_size=2)
        messages = [
            {
                "role": "system",
                "content": "You are a helpful assistant.",
                "reward": None,
            },
            {
                "role": "user",
                "content": "What is the capital of France?",
                "reward": None,
            },
            {
                "role": "assistant",
                "content": "The capital of France is Paris.",
                "reward": None,
            },
        ]
        response = self._post(
            "/scored_data",
            {
                "tokens": [[1, 2, 3, 4, 5]],
                "masks": [[1, 1, 1, 1, 1]],
                "scores": [1.0],
                "messages": [messages],
                "advantages": [[0.5, 0.5, 0.5, 0.5, 0.5]],
                "ref_logprobs": [[-0.1, -0.2, -0.3, -0.4, -0.5]],
                "generation_params": {"temperature": 0.7},
            },
        )
        assert response.status_code == 200, response.text
        assert response.json()["status"] == "received"

        # generation_params must be passed through to the stored example.
        latest_response = self._get("/latest_example")
        assert latest_response.status_code == 200, latest_response.text
        # `or {}` also covers a key that is present but null.
        generation_params = latest_response.json().get("generation_params") or {}
        assert generation_params.get("temperature") == 0.7

    def test_scored_data_list_with_messages(self, api_server):
        """A list of scored-data groups with messages is accepted in one call."""
        self._register("test", batch_size=4)
        scored_data_list = [
            {
                "tokens": [[i + 1, i + 2, i + 3]],
                "masks": [[1, 1, 1]],
                "scores": [float(i)],
                "messages": [
                    [
                        {"role": "user", "content": f"Question {i}", "reward": None},
                        {"role": "assistant", "content": f"Answer {i}", "reward": None},
                    ]
                ],
            }
            for i in range(3)
        ]
        response = self._post("/scored_data_list", scored_data_list)
        assert response.status_code == 200, response.text
        data = response.json()
        assert data["status"] == "received"
        assert data["groups_processed"] == 3

    def test_sft_style_messages(self, api_server):
        """Test SFT-style message handling with group overrides."""
        self._register("sft_test", max_token_len=1024)
        # SFT-style data with ShareGPT format messages
        sharegpt_messages = [
            {"role": "user", "content": "Explain quantum computing", "reward": None},
            {
                "role": "assistant",
                "content": "Quantum computing is a type of computing...",
                "reward": None,
            },
        ]
        response = self._post(
            "/scored_data",
            {
                "tokens": [[100, 101, 102, 103, 104, 105]],
                "masks": [[-100, -100, 102, 103, 104, 105]],  # Masked prefix
                "scores": [1.0],
                "messages": [sharegpt_messages],
                "advantages": [[1.0, 1.0, 1.0, 1.0, 1.0, 1.0]],
                "group_overrides": {"sft": True},
                "overrides": [{"sft": True}],
            },
        )
        assert response.status_code == 200, response.text
        assert response.json()["status"] == "received"

    def test_multimodal_messages_with_images(self, api_server):
        """Messages with list-of-parts content and an ``images`` field are accepted."""
        self._register("multimodal_test", max_token_len=2048)
        # Multimodal message with image reference
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What's in this image?"},
                    {
                        "type": "image_url",
                        "image_url": {"url": "data:image/png;base64,..."},
                    },
                ],
                "reward": None,
            },
            {
                "role": "assistant",
                "content": "I can see a cat in the image.",
                "reward": None,
            },
        ]
        response = self._post(
            "/scored_data",
            {
                "tokens": [[200, 201, 202, 203]],
                "masks": [[1, 1, 1, 1]],
                "scores": [0.9],
                "messages": [messages],
                "images": ["base64_encoded_image_data"],
            },
        )
        assert response.status_code == 200, response.text
        assert response.json()["status"] == "received"

    def test_batch_retrieval_with_messages(self, api_server):
        """Messages and temperature settings survive the round trip through /batch."""
        self._register("batch_test", batch_size=2)
        for i in range(2):
            messages = [
                {"role": "user", "content": f"Test message {i}", "reward": None},
                {"role": "assistant", "content": f"Response {i}", "reward": None},
            ]
            payload = {
                "tokens": [[i * 10 + j for j in range(5)]],
                "masks": [[1] * 5],
                "scores": [float(i)],
                "messages": [messages],
            }
            # Group 0 carries its temperature in per-item overrides, group 1 in
            # group-level generation_params; both must be passed through.
            if i == 0:
                payload["overrides"] = [{"temperature": 0.5}]
            else:
                payload["generation_params"] = {"temperature": 0.8}
            response = self._post("/scored_data", payload)
            assert response.status_code == 200, response.text

        batch_response = self._get("/batch")
        assert batch_response.status_code == 200, batch_response.text
        batch = batch_response.json()["batch"]
        assert batch is not None
        assert len(batch) == 2

        # The checks below expect the groups back in submission order.
        for i, item in enumerate(batch):
            assert "messages" in item
            assert item["messages"] is not None
            assert len(item["messages"]) == 1
            assert len(item["messages"][0]) == 2
            assert item["messages"][0][0]["role"] == "user"
            assert item["messages"][0][1]["role"] == "assistant"
            if i == 0:
                assert item.get("overrides") is not None
                assert item["overrides"][0].get("temperature") == 0.5
            else:
                assert item.get("generation_params") is not None
                assert item["generation_params"].get("temperature") == 0.8

    def test_latest_example_with_messages(self, api_server):
        """/latest_example returns the posted messages and inference logprobs."""
        self._register("latest_test")
        messages = [
            {
                "role": "system",
                "content": "You are a coding assistant.",
                "reward": None,
            },
            {"role": "user", "content": "Write a Python hello world.", "reward": None},
            {"role": "assistant", "content": "print('Hello, World!')", "reward": None},
        ]
        response = self._post(
            "/scored_data",
            {
                "tokens": [[42, 43, 44]],
                "masks": [[1, 1, 1]],
                "scores": [0.95],
                "messages": [messages],
                "inference_logprobs": [[-1.0, -0.7, -0.2]],
            },
        )
        assert response.status_code == 200, response.text

        latest_response = self._get("/latest_example")
        assert latest_response.status_code == 200, latest_response.text
        latest_data = latest_response.json()
        assert "messages" in latest_data
        assert latest_data["messages"] == [messages]
        assert len(latest_data["messages"][0]) == 3
        assert latest_data.get("inference_logprobs") == [[-1.0, -0.7, -0.2]]

    def test_empty_messages_handling(self, api_server):
        """Scored data is accepted with ``messages`` omitted or empty."""
        self._register("empty_test")

        # messages is optional, so the field can be omitted entirely.
        response = self._post(
            "/scored_data",
            {
                "tokens": [[1, 2, 3]],
                "masks": [[1, 1, 1]],
                "scores": [1.0],
            },
        )
        assert response.status_code == 200, response.text
        assert response.json()["status"] == "received"

        # An explicitly empty messages list is accepted too.
        response = self._post(
            "/scored_data",
            {
                "tokens": [[4, 5, 6]],
                "masks": [[1, 1, 1]],
                "scores": [1.0],
                "messages": [],
            },
        )
        assert response.status_code == 200, response.text
        assert response.json()["status"] == "received"

    def test_complex_message_structures(self, api_server):
        """Test handling of complex message structures with tool calls."""
        self._register("complex_test", max_token_len=2048)
        # Conversation that includes a "tool" role turn.
        messages = [
            {
                "role": "system",
                "content": "You have access to calculation tools.",
                "reward": None,
            },
            {"role": "user", "content": "What is 15 * 23?", "reward": None},
            {
                "role": "assistant",
                "content": "I'll calculate that for you.",
                "reward": None,
            },
            {"role": "tool", "content": "Result: 345", "reward": None},
            {"role": "assistant", "content": "15 * 23 = 345", "reward": None},
        ]
        response = self._post(
            "/scored_data",
            {
                "tokens": [[300, 301, 302, 303, 304]],
                "masks": [[1, 1, 1, 1, 1]],
                "scores": [0.85],
                "messages": [messages],
                "advantages": [[0.1, 0.2, 0.3, 0.4, 0.5]],
            },
        )
        assert response.status_code == 200, response.text
        assert response.json()["status"] == "received"

    def test_message_reward_field(self, api_server):
        """Test messages with reward field as defined in Message TypedDict."""
        self._register("reward_test")
        messages = [
            {"role": "user", "content": "Solve: 2+2", "reward": None},
            {"role": "assistant", "content": "2+2 = 4", "reward": 1.0},
        ]
        response = self._post(
            "/scored_data",
            {
                "tokens": [[400, 401, 402]],
                "masks": [[1, 1, 1]],
                "scores": [1.0],
                "messages": [messages],
            },
        )
        assert response.status_code == 200, response.text
        assert response.json()["status"] == "received"
class TestSFTIntegration:
    """Test SFT-specific integration scenarios."""

    # Base URL of the server launched by the ``api_server`` fixture.
    BASE_URL = "http://localhost:8000"
    # Per-request timeout (seconds) so a wedged server fails the test
    # instead of hanging the run.
    TIMEOUT = 10

    def _get(self, path):
        """Send a GET request for *path* to the test server."""
        return requests.get(f"{self.BASE_URL}{path}", timeout=self.TIMEOUT)

    def _post(self, path, payload):
        """POST *payload* as JSON to *path* on the test server."""
        return requests.post(
            f"{self.BASE_URL}{path}", json=payload, timeout=self.TIMEOUT
        )

    def _register(self, name, batch_size=1, max_token_len=512):
        """Register a trainer whose wandb group and project are both *name*."""
        response = self._post(
            "/register",
            {
                "wandb_group": name,
                "wandb_project": name,
                "batch_size": batch_size,
                "max_token_len": max_token_len,
                "checkpoint_dir": "/tmp",
                "save_checkpoint_interval": 10,
                "starting_step": 0,
                "num_steps": 100,
            },
        )
        assert response.status_code == 200, response.text

    def test_sft_completion_format(self, api_server):
        """Test SFT with completion format."""
        self._register("sft_completion")
        # Completion format: only tokens/masks, no Message objects.
        response = self._post(
            "/scored_data",
            {
                "tokens": [[500, 501, 502, 503]],
                "masks": [[500, 501, 502, 503]],
                "scores": [1],
                "advantages": [[1, 1, 1, 1]],
                "group_overrides": {"sft": True},
                "overrides": [{"sft": True}],
            },
        )
        assert response.status_code == 200, response.text

    def test_sft_prefixed_completion(self, api_server):
        """Test SFT with prefixed completion format."""
        self._register("sft_prefixed")
        # Prefixed completion: the first 3 tokens are masked with -100 and no
        # Message objects are sent.
        response = self._post(
            "/scored_data",
            {
                "tokens": [[600, 601, 602, 603, 604, 605]],
                "masks": [[-100, -100, -100, 603, 604, 605]],
                "scores": [1],
                "advantages": [[1, 1, 1, 1, 1, 1]],
                "group_overrides": {"sft": True},
                "overrides": [{"sft": True}],
            },
        )
        assert response.status_code == 200, response.text

    def test_sft_batch_processing(self, api_server):
        """Four SFT groups fill one batch and keep their SFT overrides."""
        self._register("sft_batch", batch_size=4)
        for i in range(4):
            messages = [
                {"role": "user", "content": f"Question {i}", "reward": None},
                {"role": "assistant", "content": f"Answer {i}", "reward": None},
            ]
            response = self._post(
                "/scored_data",
                {
                    "tokens": [[700 + i, 701 + i, 702 + i]],
                    "masks": [[-100, 701 + i, 702 + i]],
                    "scores": [1],
                    # One advantage per token, like every other payload here.
                    "advantages": [[1, 1, 1]],
                    "messages": [messages],
                    "group_overrides": {"sft": True},
                    "overrides": [{"sft": True}],
                },
            )
            assert response.status_code == 200, response.text

        status_response = self._get("/status")
        assert status_response.status_code == 200, status_response.text
        assert status_response.json()["queue_size"] == 4

        batch_response = self._get("/batch")
        assert batch_response.status_code == 200, batch_response.text
        batch = batch_response.json()["batch"]
        assert len(batch) == 4

        # `or` fallbacks also cover keys that are present but null.
        for item in batch:
            assert (item.get("group_overrides") or {}).get("sft") is True
            assert (item.get("overrides") or [{}])[0].get("sft") is True
class TestMessageRewardHandling:
    """Test different scenarios with reward field in messages."""

    # Base URL of the server launched by the ``api_server`` fixture.
    BASE_URL = "http://localhost:8000"
    # Per-request timeout (seconds) so a wedged server fails the test
    # instead of hanging the run.
    TIMEOUT = 10

    def _post(self, path, payload):
        """POST *payload* as JSON to *path* on the test server."""
        return requests.post(
            f"{self.BASE_URL}{path}", json=payload, timeout=self.TIMEOUT
        )

    def _register(self, name, batch_size=1, max_token_len=512):
        """Register a trainer whose wandb group and project are both *name*."""
        response = self._post(
            "/register",
            {
                "wandb_group": name,
                "wandb_project": name,
                "batch_size": batch_size,
                "max_token_len": max_token_len,
                "checkpoint_dir": "/tmp",
                "save_checkpoint_interval": 10,
                "starting_step": 0,
                "num_steps": 100,
            },
        )
        assert response.status_code == 200, response.text

    def test_messages_without_reward_field(self, api_server):
        """Test messages without the reward field - should be accepted."""
        self._register("no_reward_test")
        response = self._post(
            "/scored_data",
            {
                "tokens": [[800, 801, 802]],
                "masks": [[1, 1, 1]],
                "scores": [0.5],
                "messages": [
                    [
                        # Neither message carries a reward field.
                        {
                            "role": "user",
                            "content": "Write a poem about the ocean",
                        },
                        {
                            "role": "assistant",
                            "content": "Waves crash upon the shore,\nEndless blue forevermore.",
                        },
                    ]
                ],
                "advantages": [[0.5, 0.3, 0.2]],
            },
        )
        assert (
            response.status_code == 200
        ), f"Messages without reward field should be accepted: {response.text}"

    def test_mixed_reward_presence(self, api_server):
        """Test messages with inconsistent reward field presence - should be accepted."""
        self._register("mixed_reward_test")
        response = self._post(
            "/scored_data",
            {
                "tokens": [[810, 811, 812]],
                "masks": [[1, 1, 1]],
                "scores": [0.7],
                "messages": [
                    [
                        {"role": "user", "content": "Hello"},  # No reward field
                        {
                            "role": "assistant",
                            "content": "Hi!",
                            "reward": 0.9,
                        },  # Has reward
                    ]
                ],
                "advantages": [[0.5, 0.3, 0.2]],
            },
        )
        assert (
            response.status_code == 200
        ), f"Mixed reward presence should be accepted: {response.text}"

    def test_reward_none_vs_missing(self, api_server):
        """Both an explicit ``reward: None`` and a missing reward are accepted."""
        self._register("reward_none_test")

        # Every message has reward=None.
        response1 = self._post(
            "/scored_data",
            {
                "tokens": [[820, 821, 822]],
                "masks": [[1, 1, 1]],
                "scores": [0.8],
                "messages": [
                    [
                        {"role": "user", "content": "Test with None", "reward": None},
                        {"role": "assistant", "content": "Response", "reward": None},
                    ]
                ],
                "advantages": [[0.5, 0.3, 0.2]],
            },
        )
        # .text rather than .json(): an error body may not be JSON.
        assert (
            response1.status_code == 200
        ), f"None reward should work: {response1.text}"

        # No message has a reward field at all.
        messages_missing_reward = [
            {"role": "user", "content": "Test without reward"},
            {"role": "assistant", "content": "Response without reward"},
        ]
        response2 = self._post(
            "/scored_data",
            {
                "tokens": [[830, 831, 832]],
                "masks": [[1, 1, 1]],
                "scores": [0.6],
                "messages": [messages_missing_reward],
                "advantages": [[0.5, 0.3, 0.2]],
            },
        )
        assert (
            response2.status_code == 200
        ), f"Messages without reward field should be accepted: {response2.text}"

    def test_extra_fields_in_messages(self, api_server):
        """Messages with fields beyond the Message TypedDict are accepted.

        This matches ``test_extra_fields_without_reward``, which requires the
        same kind of extra fields to be accepted.
        """
        self._register("extra_fields_test")
        response = self._post(
            "/scored_data",
            {
                "tokens": [[840, 841, 842]],
                "masks": [[1, 1, 1]],
                "scores": [0.7],
                "messages": [
                    [
                        {
                            "role": "user",
                            "content": "Hello AI",
                            "reward": None,
                            "definitely_not_in_typeddict_kwarg": "surprise!",
                            "another_extra_field": 42,
                            "yet_another_field": {"nested": "data"},
                        },
                        {
                            "role": "assistant",
                            "content": "Hello human!",
                            "reward": 0.5,
                            "definitely_not_in_typeddict_kwarg": "another surprise!",
                            "random_metadata": ["list", "of", "things"],
                        },
                    ]
                ],
                "advantages": [[0.5, 0.3, 0.2]],
            },
        )
        assert (
            response.status_code == 200
        ), f"Messages with extra fields should be accepted: {response.text}"

    def test_extra_fields_without_reward(self, api_server):
        """Test messages with extra fields but missing the reward field - should be accepted."""
        self._register("extra_no_reward_test")
        response = self._post(
            "/scored_data",
            {
                "tokens": [[850, 851, 852]],
                "masks": [[1, 1, 1]],
                "scores": [0.8],
                "messages": [
                    [
                        # Extra fields present, reward field absent.
                        {
                            "role": "user",
                            "content": "Test message",
                            "definitely_not_in_typeddict_kwarg": "I'm here but reward isn't!",
                            "extra_metadata": {"key": "value"},
                            "priority": 10,
                        },
                        {
                            "role": "assistant",
                            "content": "Response message",
                            "definitely_not_in_typeddict_kwarg": "Still no reward field",
                            "completion_tokens": 42,
                            "model": "gpt-4",
                        },
                    ]
                ],
                "advantages": [[0.5, 0.3, 0.2]],
            },
        )
        assert (
            response.status_code == 200
        ), f"Messages with extra fields but no reward should be accepted: {response.text}"
class TestWeirdMessageFormats:
    """Test edge cases with unusual data structures that users might try."""

    # Base URL of the server launched by the ``api_server`` fixture.
    BASE_URL = "http://localhost:8000"
    # Per-request timeout (seconds) so a wedged server fails the test
    # instead of hanging the run.
    TIMEOUT = 10

    def _post(self, path, payload):
        """POST *payload* as JSON to *path* on the test server."""
        return requests.post(
            f"{self.BASE_URL}{path}", json=payload, timeout=self.TIMEOUT
        )

    def _register(self, name, batch_size=1, max_token_len=512):
        """Register a trainer whose wandb group and project are both *name*."""
        response = self._post(
            "/register",
            {
                "wandb_group": name,
                "wandb_project": name,
                "batch_size": batch_size,
                "max_token_len": max_token_len,
                "checkpoint_dir": "/tmp",
                "save_checkpoint_interval": 10,
                "starting_step": 0,
                "num_steps": 100,
            },
        )
        assert response.status_code == 200, response.text

    def test_messages_as_tuples(self, api_server):
        """Test sending messages as tuples instead of lists."""
        self._register("tuple_test")
        response = self._post(
            "/scored_data",
            {
                "tokens": [[900, 901, 902]],
                "masks": [[1, 1, 1]],
                "scores": [0.7],
                # JSON serialization turns the tuple into a list, so this must
                # be accepted like any other message list.
                "messages": [
                    (
                        {"role": "user", "content": "Hello"},
                        {"role": "assistant", "content": "Hi!"},
                    )
                ],
                "advantages": [[0.5, 0.3, 0.2]],
            },
        )
        assert response.status_code == 200, response.text

    def test_messages_with_nested_weird_types(self, api_server):
        """Message content may be a nested dict or a list of strings."""
        self._register("nested_weird_test")
        response = self._post(
            "/scored_data",
            {
                "tokens": [[910, 911, 912]],
                "masks": [[1, 1, 1]],
                "scores": [0.8],
                "messages": [
                    [
                        {
                            "role": "user",
                            "content": {
                                "text": "Complex content",
                                "metadata": {
                                    "nested": {"deeply": ["list", "of", "things"]}
                                },
                                "numbers": (1, 2, 3),  # Tuple becomes list in JSON
                            },
                        },
                        {
                            "role": "assistant",
                            "content": ["This", "is", "a", "list", "content"],
                            "reward": 0.5,
                        },
                    ]
                ],
                "advantages": [[0.5, 0.3, 0.2]],
            },
        )
        assert response.status_code == 200, response.text

    def test_messages_with_numeric_strings_as_content(self, api_server):
        """Test messages with numeric strings and edge case content."""
        self._register("numeric_content_test")
        response = self._post(
            "/scored_data",
            {
                "tokens": [[920, 921, 922]],
                "masks": [[1, 1, 1]],
                "scores": [0.6],
                "messages": [
                    [
                        {"role": "user", "content": "12345"},  # Numeric string
                        {"role": "assistant", "content": ""},  # Empty string
                        {"role": "user", "content": " "},  # Whitespace only
                        {"role": "assistant", "content": "0"},  # Zero as string
                    ]
                ],
                # One advantage per token (3 tokens), not one per message.
                "advantages": [[0.5, 0.3, 0.2]],
            },
        )
        assert response.status_code == 200, response.text

    def test_messages_with_boolean_and_null_values(self, api_server):
        """Test messages with boolean and null values in various fields."""
        self._register("bool_null_test")
        response = self._post(
            "/scored_data",
            {
                "tokens": [[930, 931, 932]],
                "masks": [[1, 1, 1]],
                "scores": [0.5],
                "messages": [
                    [
                        {
                            "role": "user",
                            "content": "Normal message",
                            "reward": None,  # Explicit None
                            "extra_bool": True,
                            "extra_null": None,
                        },
                        {
                            "role": "assistant",
                            "content": "Response",
                            "reward": 0.0,  # Zero reward
                            "metadata": {"flag": False, "value": None},
                        },
                    ]
                ],
                "advantages": [[0.5, 0.3, 0.2]],
            },
        )
        assert response.status_code == 200, response.text

    def test_messages_with_very_large_content(self, api_server):
        """Test messages with very large content strings."""
        self._register("large_content_test", max_token_len=10000)
        large_content = "A" * 10000  # 10k character string
        response = self._post(
            "/scored_data",
            {
                "tokens": [[940, 941, 942]],
                "masks": [[1, 1, 1]],
                "scores": [0.7],
                "messages": [
                    [
                        {"role": "user", "content": large_content},
                        {"role": "assistant", "content": "Short response"},
                    ]
                ],
                "advantages": [[0.5, 0.3, 0.2]],
            },
        )
        assert response.status_code == 200, response.text

    def test_messages_with_unicode_and_special_chars(self, api_server):
        """Test messages with unicode, emojis, and special characters."""
        self._register("unicode_test")
        response = self._post(
            "/scored_data",
            {
                "tokens": [[950, 951, 952]],
                "masks": [[1, 1, 1]],
                "scores": [0.8],
                "messages": [
                    [
                        {
                            "role": "user",
                            "content": "Hello 👋 世界 🌍 مرحبا 🎉 Здравствуй",
                        },
                        {
                            "role": "assistant",
                            "content": "Special chars: \n\t\r \" ' \\ / 🚀",
                            "reward": 0.9,
                        },
                    ]
                ],
                "advantages": [[0.5, 0.3, 0.2]],
            },
        )
        assert response.status_code == 200, response.text

    def test_messages_with_custom_roles(self, api_server):
        """Test messages with custom role values like 'dog'."""
        self._register("custom_role_test")
        response = self._post(
            "/scored_data",
            {
                "tokens": [[960, 961, 962, 963, 964]],
                "masks": [[1, 1, 1, 1, 1]],
                "scores": [0.5],
                "messages": [
                    [
                        {"role": "dog", "content": "Woof woof!"},  # Custom role
                        {"role": "cat", "content": "Meow"},  # Another custom role
                        {"role": "narrator", "content": "The animals were talking"},
                        {"role": "USER", "content": "Wrong case but should work"},
                        {"role": "robot", "content": "Beep boop"},
                    ]
                ],
                "advantages": [[0.5, 0.4, 0.3, 0.2, 0.1]],
            },
        )
        # Roles outside the usual set must still be accepted.
        assert response.status_code == 200, response.text

    def test_messages_missing_required_fields(self, api_server):
        """Messages missing ``role`` or ``content`` are rejected with 422."""
        self._register("missing_fields_test")

        # First message has no role.
        response1 = self._post(
            "/scored_data",
            {
                "tokens": [[970, 971, 972]],
                "masks": [[1, 1, 1]],
                "scores": [0.5],
                "messages": [
                    [
                        {"content": "Missing role field"},
                        {"role": "assistant", "content": "This one is OK"},
                    ]
                ],
                "advantages": [[0.5, 0.3, 0.2]],
            },
        )
        assert response1.status_code == 422, response1.text

        # First message has no content.
        response2 = self._post(
            "/scored_data",
            {
                "tokens": [[980, 981, 982]],
                "masks": [[1, 1, 1]],
                "scores": [0.5],
                "messages": [
                    [
                        {"role": "user"},
                        {"role": "assistant", "content": "This one is OK"},
                    ]
                ],
                "advantages": [[0.5, 0.3, 0.2]],
            },
        )
        assert response2.status_code == 422, response2.text