mirror of https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00

Merge commit '71e7a5ca27' into add-support-for-custom-api-servers
This commit is contained in: commit 96be544228

45 changed files with 1605 additions and 494 deletions
23  atroposlib/tests/conftest.py  Normal file

@@ -0,0 +1,23 @@
import pytest


def pytest_addoption(parser):
    parser.addoption(
        "--runproviders", action="store_true", default=False, help="run provider tests"
    )


def pytest_configure(config):
    config.addinivalue_line(
        "markers", "providers: mark test as requiring provider API keys to run"
    )


def pytest_collection_modifyitems(config, items):
    if config.getoption("--runproviders"):
        # --runproviders given on the CLI: do not skip provider tests
        return
    skip_providers = pytest.mark.skip(reason="need --runproviders option to run")
    for item in items:
        if "providers" in item.keywords:
            item.add_marker(skip_providers)
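For illustration, a hypothetical test (not part of this commit) opts into these hooks simply by carrying the marker; it is skipped unless pytest is invoked with --runproviders:

import pytest


@pytest.mark.providers
def test_needs_provider_credentials():
    # Hypothetical placeholder: this body only runs under --runproviders.
    assert True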
@@ -1,7 +1,7 @@
 import math
 
+import numpy as np
 import pytest
-import torch
 
 # Adjust the import below if your functions are in a different module.
 from atroposlib.utils.advantages import (
@@ -23,9 +23,9 @@ def test_allclose_to_first_vector():
     """Test that return_vector=True returns a tensor of booleans."""
     values = [1.0, 1.000000001, 1.000000002]
     result = allclose_to_first(values, return_vector=True)
-    assert isinstance(result, torch.Tensor)
+    assert isinstance(result, np.ndarray)
     # All comparisons should be True.
-    assert torch.all(result)
+    assert np.all(result)
 
 
 def test_allclose_to_first_not_close():
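The hunk above only shows the test's expectations. As a rough numpy sketch consistent with them (assumed signature and tolerances; the real helper in atroposlib.utils.advantages may differ):

import numpy as np


def allclose_to_first(values, return_vector=False, rtol=1e-5, atol=1e-8):
    # Sketch only: element-wise closeness of every value to the first one.
    arr = np.asarray(values, dtype=np.float64)
    close = np.isclose(arr, arr[0], rtol=rtol, atol=atol)
    return close if return_vector else bool(np.all(close))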
@@ -74,15 +74,15 @@ def test_compute_stats_jagged():
 
 def test_compute_discounted_returns():
     """Test compute_discounted_returns with a tensor input."""
-    rewards = torch.tensor([1.0, 1.0, 1.0])
+    rewards = np.array([1.0, 1.0, 1.0])
     gamma = 0.9
     returns = compute_discounted_returns(rewards, gamma)
     # For a 3-element vector:
     # t=2: 1.0
     # t=1: 1.0 + 0.9*1.0 = 1.9
     # t=0: 1.0 + 0.9*1.9 = 2.71
-    expected = torch.tensor([2.71, 1.9, 1.0])
-    assert torch.allclose(returns, expected, rtol=1e-5, atol=1e-8)
+    expected = np.array([2.71, 1.9, 1.0])
+    assert np.allclose(returns, expected, rtol=1e-5, atol=1e-8)
 
 
 def test_compute_discounted_returns_list_input():
@@ -90,8 +90,8 @@ def test_compute_discounted_returns_list_input():
     rewards = [1, 1, 1]
     gamma = 0.0  # With gamma=0, the returns should equal the rewards.
     returns = compute_discounted_returns(rewards, gamma)
-    expected = torch.tensor([1.0, 1.0, 1.0])
-    assert torch.allclose(returns, expected, rtol=1e-5, atol=1e-8)
+    expected = np.array([1.0, 1.0, 1.0])
+    assert np.allclose(returns, expected, rtol=1e-5, atol=1e-8)
 
 
 def test_compute_grpo_process_supervision_advantages_cumsum():
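For reference, a backward-recursion sketch consistent with the expected values in these tests, assuming the usual definition G_t = r_t + gamma * G_{t+1} (the real compute_discounted_returns may be vectorised differently):

import numpy as np


def compute_discounted_returns(rewards, gamma):
    # Sketch only: accumulate discounted returns from the last step backwards.
    rewards = np.asarray(rewards, dtype=np.float64)
    returns = np.zeros_like(rewards)
    running = 0.0
    for t in range(len(rewards) - 1, -1, -1):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


# [1.0, 1.0, 1.0] with gamma=0.9 -> [2.71, 1.9, 1.0], matching the test above.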
110  atroposlib/tests/test_openai_api_workarounds.py  Normal file

@@ -0,0 +1,110 @@
import asyncio
import os

import dotenv
import pytest

from atroposlib.envs.server_handling.openai_server import APIServerConfig, OpenAIServer


@pytest.mark.providers
def test_openai_api_n_kwarg_ignore_discovery():
    dotenv.load_dotenv()
    openrouter_api_key = os.getenv("OPENROUTER_API_KEY")
    if not openrouter_api_key:
        pytest.skip("OPENROUTER_API_KEY not set")
    config = APIServerConfig(
        api_key=openrouter_api_key,
        base_url="https://openrouter.ai/api/v1",
        model_name="openai/gpt-4.1-nano",
        timeout=1200,
        num_max_requests_at_once=512,
        num_requests_for_eval=64,
        rolling_buffer_length=1024,
    )
    assert not config.n_kwarg_is_ignored, "n kwarg should not be flagged as ignored by default"
    n = 4
    server = OpenAIServer(
        config=config,
    )
    response = asyncio.run(
        server.chat_completion(
            messages=[
                {"role": "user", "content": "Hello, how are you?"},
            ],
            n=n,
        )
    )
    assert server.config.n_kwarg_is_ignored, "n_kwarg_is_ignored should be set after discovery"
    print(len(response.choices), n)
    assert (
        len(response.choices) == n
    ), f"Expected {n} responses, got {len(response.choices)}"


@pytest.mark.providers
def test_openai_api_n_kwarg_ignore_use():
    dotenv.load_dotenv()
    openrouter_api_key = os.getenv("OPENROUTER_API_KEY")
    if not openrouter_api_key:
        pytest.skip("OPENROUTER_API_KEY not set")
    config = APIServerConfig(
        api_key=openrouter_api_key,
        base_url="https://openrouter.ai/api/v1",
        model_name="openai/gpt-4.1-nano",
        timeout=1200,
        num_max_requests_at_once=512,
        num_requests_for_eval=64,
        rolling_buffer_length=1024,
        n_kwarg_is_ignored=True,
    )
    server = OpenAIServer(
        config=config,
    )
    n = 4
    response = asyncio.run(
        server.chat_completion(
            messages=[
                {"role": "user", "content": "Hello, how are you?"},
            ],
            n=n,
        )
    )
    assert server.config.n_kwarg_is_ignored, "n_kwarg_is_ignored should remain set"
    assert (
        len(response.choices) == n
    ), f"Expected {n} responses, got {len(response.choices)}"


@pytest.mark.providers
def test_openai_api_n_kwarg_supported():
    dotenv.load_dotenv()
    openai_api_key = os.getenv("OPENAI_API_KEY")
    if not openai_api_key:
        pytest.skip("OPENAI_API_KEY not set")
    config = APIServerConfig(
        model_name="gpt-4.1-nano",
        timeout=1200,
        num_max_requests_at_once=512,
        num_requests_for_eval=64,
        rolling_buffer_length=1024,
        n_kwarg_is_ignored=False,
    )
    server = OpenAIServer(
        config=config,
    )
    n = 4
    response = asyncio.run(
        server.chat_completion(
            messages=[
                {"role": "user", "content": "Hello, how are you?"},
            ],
            n=n,
        )
    )
    assert (
        not server.config.n_kwarg_is_ignored
    ), "n kwarg should be used with supported models"
    assert (
        len(response.choices) == n
    ), f"Expected {n} responses, got {len(response.choices)}"
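These tests exercise the server-side workaround for providers that silently ignore the n sampling kwarg. The actual OpenAIServer logic lives elsewhere in this commit; as a rough sketch of the pattern being tested (assumed helper name, standard openai async client):

import asyncio

from openai import AsyncOpenAI


async def chat_completion_with_n_fallback(config, messages, n):
    # Sketch only: detect providers that ignore `n` and fan out single requests.
    client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
    if not config.n_kwarg_is_ignored:
        response = await client.chat.completions.create(
            model=config.model_name, messages=messages, n=n
        )
        if len(response.choices) == n:
            return response
        # Fewer choices than requested: remember that this provider ignores `n`.
        config.n_kwarg_is_ignored = True
    # Fallback: one single-completion request per desired choice.
    responses = await asyncio.gather(
        *(
            client.chat.completions.create(model=config.model_name, messages=messages)
            for _ in range(n)
        )
    )
    merged = responses[0]
    merged.choices = [r.choices[0] for r in responses]
    return merged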