mirror of
https://github.com/thinking-machines-lab/tinker.git
synced 2026-04-19 12:58:01 +00:00
384 lines
14 KiB
Python
384 lines
14 KiB
Python
"""Tests for SamplingClient subprocess mode.

These tests use a picklable fake SamplingClient to verify that
subprocess mode correctly routes sample() and compute_logprobs()
through the sidecar subprocess.

Test organization:
    TestRPCRouting    — sample/compute_logprobs delegation through sidecar
    TestErrorHandling — error propagation, sidecar death
    TestPickle        — roundtrip with/without sidecar, re-enable mode
    TestConcurrency   — multithreaded, async, cancelled futures, mixed ops
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import pickle
|
|
import threading
|
|
import time
|
|
from concurrent.futures import Future as ConcurrentFuture
|
|
from typing import Any
|
|
|
|
import pytest
|
|
|
|
from tinker import types
|
|
from tinker._exceptions import SidecarDiedError
|
|
from tinker.lib.sidecar import create_sidecar_handle
|
|
|
|
# ---------------------------------------------------------------------------
# Picklable fake SamplingClient (must be module-level for pickling)
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class _FakeSamplingClient:
    """Picklable stand-in for SamplingClient used by the sidecar tests.

    This is NOT a real SamplingClient. It exposes only ``sample()`` and
    ``compute_logprobs()``, each returning a ``concurrent.futures.Future``
    when no sidecar handle is attached, and otherwise forwarding the call to
    the handle as an RPC object. A real SamplingClient requires an
    InternalClientHolder and an API connection.

    Args:
        delay: Seconds ``sample()`` waits (on a daemon thread) before
            resolving its future. Values that are not ``> 0`` resolve
            immediately. ``compute_logprobs()`` does not use it.
        fail: When true, both methods resolve their future with a
            ``RuntimeError`` instead of a result.
        subprocess_sampling: When true, and this process is not itself the
            sidecar, a sidecar handle is created in the constructor.
    """

    def __init__(self, delay: float = 0.0, fail: bool = False, subprocess_sampling: bool = False):
        self._delay = delay
        self._fail = fail
        # Stays None unless enabled below; tests may also assign a handle
        # directly via create_sidecar_handle().
        self._sampling_client_sidecar_handle = None
        if not subprocess_sampling:
            return

        from tinker.lib.sidecar import _inside_sidecar

        # Only attach a handle when this process is not the sidecar itself.
        if _inside_sidecar:
            return
        self._sampling_client_sidecar_handle = create_sidecar_handle(self)

    def sample(
        self,
        prompt: types.ModelInput,
        num_samples: int,
        sampling_params: types.SamplingParams,
        include_prompt_logprobs: bool = False,
        topk_prompt_logprobs: int = 0,
    ) -> Any:
        """Return a future for a canned sample response.

        With a sidecar handle attached, the arguments are packed into a
        ``_SampleRPC`` and whatever ``submit_rpc`` returns is handed back
        (mirrors real SamplingClient behavior). Otherwise the future is
        resolved locally according to ``fail`` and ``delay``.
        """
        handle = self._sampling_client_sidecar_handle
        if handle is not None:
            from tinker.lib.public_interfaces.sampling_client import _SampleRPC

            rpc = _SampleRPC(
                prompt,
                num_samples,
                sampling_params,
                include_prompt_logprobs,
                topk_prompt_logprobs,
            )
            return handle.submit_rpc(rpc)

        future: ConcurrentFuture[types.SampleResponse] = ConcurrentFuture()
        if self._fail:
            future.set_exception(RuntimeError("Simulated sample failure"))
            return future

        if self._delay > 0:

            def _resolve_later():
                # Runs on a daemon thread so the caller gets the future back
                # immediately and observes a genuinely pending result.
                time.sleep(self._delay)
                future.set_result(_make_sample_response())

            threading.Thread(target=_resolve_later, daemon=True).start()
            return future

        future.set_result(_make_sample_response())
        return future

    def compute_logprobs(self, prompt: types.ModelInput) -> Any:
        """Return a future for a fixed list of logprobs.

        With a sidecar handle attached, the prompt is wrapped in a
        ``_ComputeLogprobsRPC`` and submitted through the handle (mirrors
        real SamplingClient behavior). Otherwise the future is already
        resolved when returned: with ``[0.1, 0.2, None]``, or with a
        ``RuntimeError`` when ``fail`` is set.
        """
        handle = self._sampling_client_sidecar_handle
        if handle is not None:
            from tinker.lib.public_interfaces.sampling_client import _ComputeLogprobsRPC

            return handle.submit_rpc(_ComputeLogprobsRPC(prompt))

        future: ConcurrentFuture[list[float | None]] = ConcurrentFuture()
        if self._fail:
            future.set_exception(RuntimeError("Simulated logprobs failure"))
            return future

        future.set_result([0.1, 0.2, None])
        return future

    def __reduce__(self) -> tuple[type, tuple[float, bool, bool]]:
        """Pickle as a constructor call instead of copying instance state.

        The handle itself is not serialized; only whether one was attached
        is recorded, and it is passed back as ``subprocess_sampling`` when
        the object is rebuilt.
        """
        subprocess_enabled = self._sampling_client_sidecar_handle is not None
        ctor_args = (self._delay, self._fail, subprocess_enabled)
        return (_FakeSamplingClient, ctor_args)
|
|
|
|
|
|
def _make_sample_response() -> types.SampleResponse:
    """Build the canned response the fake client returns from ``sample()``.

    The response holds a single sequence with tokens ``[1, 2, 3]``, which the
    tests compare against to confirm a result made the round trip intact.
    """
    sequence = types.SampledSequence(
        stop_reason="length",
        tokens=[1, 2, 3],
        logprobs=[0.1, 0.2, 0.3],
    )
    return types.SampleResponse(sequences=[sequence], type="sample")
|
|
|
|
|
|
def _create_proxy(delay: float = 0.0, fail: bool = False) -> _FakeSamplingClient:
    """Return a fake client that already has a sidecar handle attached.

    The client is constructed without subprocess mode first, and the handle
    is created from that instance and assigned afterwards.
    """
    proxy = _FakeSamplingClient(delay=delay, fail=fail)
    handle = create_sidecar_handle(proxy)
    proxy._sampling_client_sidecar_handle = handle
    return proxy
|
|
|
|
|
|
# Shared inputs passed to every sample()/compute_logprobs() call below.
# The fake client never inspects them, so one fixed prompt and one fixed
# parameter set are reused across all tests.
_PROMPT_TOKEN_IDS = [1, 2, 3]
_PROMPT = types.ModelInput.from_ints(_PROMPT_TOKEN_IDS)
_PARAMS = types.SamplingParams(max_tokens=10)
|
|
|
|
|
|
# ===========================================================================
# Tests
# ===========================================================================
|
|
|
|
|
|
class TestRPCRouting:
    """Verify sample() and compute_logprobs() are routed through the sidecar."""

    def test_sample(self):
        """sample() → subprocess → SampleResponse."""
        response = _create_proxy().sample(_PROMPT, 1, _PARAMS).result(timeout=10)

        assert isinstance(response, types.SampleResponse)
        assert response.sequences[0].tokens == [1, 2, 3]

    def test_constructor_enables_subprocess_mode(self):
        """subprocess_sampling=True in __init__ creates the sidecar handle."""
        client = _FakeSamplingClient(subprocess_sampling=True)
        assert client._sampling_client_sidecar_handle is not None

        # The handle created by the constructor must be usable right away.
        response = client.sample(_PROMPT, 1, _PARAMS).result(timeout=10)
        assert isinstance(response, types.SampleResponse)

    def test_compute_logprobs(self):
        """compute_logprobs() → subprocess → list of logprobs."""
        logprobs = _create_proxy().compute_logprobs(_PROMPT).result(timeout=10)
        assert logprobs == [0.1, 0.2, None]

    def test_mixed_sample_and_logprobs(self):
        """Interleaved sample() and compute_logprobs() all resolve correctly."""
        proxy = _create_proxy(delay=0.01)

        # Submit everything before waiting on anything, so requests of both
        # kinds are outstanding at the same time.
        pending_samples = [proxy.sample(_PROMPT, 1, _PARAMS) for _ in range(10)]
        pending_logprobs = [proxy.compute_logprobs(_PROMPT) for _ in range(10)]

        for pending in pending_samples:
            response = pending.result(timeout=10)
            assert isinstance(response, types.SampleResponse)
            assert response.sequences[0].tokens == [1, 2, 3]

        for pending in pending_logprobs:
            assert pending.result(timeout=10) == [0.1, 0.2, None]
|
|
|
|
|
|
class TestErrorHandling:
    """Error propagation from subprocess to caller."""

    def test_sample_error(self):
        """Exceptions from sample() in the subprocess are propagated."""
        failing = _create_proxy(fail=True)
        pending = failing.sample(_PROMPT, 1, _PARAMS)

        with pytest.raises(RuntimeError, match="Simulated sample failure"):
            pending.result(timeout=10)

    def test_compute_logprobs_error(self):
        """Exceptions from compute_logprobs() in the subprocess are propagated."""
        failing = _create_proxy(fail=True)
        pending = failing.compute_logprobs(_PROMPT)

        with pytest.raises(RuntimeError, match="Simulated logprobs failure"):
            pending.result(timeout=10)

    @pytest.mark.filterwarnings("ignore::pytest.PytestUnhandledThreadExceptionWarning")
    def test_sidecar_death_fails_pending_futures(self):
        """When the subprocess is killed, pending futures get SidecarDiedError."""
        # The 0.5s delay keeps the request unresolved while the process is killed.
        proxy = _create_proxy(delay=0.5)
        pending = proxy.sample(_PROMPT, 1, _PARAMS)

        # Reach through the handle to the underlying subprocess and kill it directly.
        sidecar = proxy._sampling_client_sidecar_handle._sidecar
        assert sidecar._process is not None
        sidecar._process.kill()
        sidecar._process.join(timeout=5)

        with pytest.raises(SidecarDiedError):
            pending.result(timeout=5)
|
|
|
|
|
|
class TestPickle:
    """Pickle roundtrip preserves subprocess mode correctly."""

    def test_roundtrip_preserves_subprocess_mode(self):
        """Pickling a sidecar-enabled client re-enables subprocess mode on unpickle."""
        original = _create_proxy()
        assert original._sampling_client_sidecar_handle is not None

        payload = pickle.dumps(original)
        restored = pickle.loads(payload)
        assert restored._sampling_client_sidecar_handle is not None

        # The restored client must be usable, not merely flagged as enabled.
        response = restored.sample(_PROMPT, 1, _PARAMS).result(timeout=10)
        assert isinstance(response, types.SampleResponse)

    def test_roundtrip_without_sidecar(self):
        """Pickling a client without subprocess mode keeps it disabled."""
        original = _FakeSamplingClient()
        assert original._sampling_client_sidecar_handle is None

        payload = pickle.dumps(original)
        restored = pickle.loads(payload)
        assert restored._sampling_client_sidecar_handle is None

    def test_re_enable_subprocess_mode(self):
        """Replacing the sidecar handle works cleanly."""
        client = _FakeSamplingClient()

        # Round 1 attaches the first handle; round 2 replaces it with a new
        # one. The test drops its only reference to the old handle at that
        # point (presumably it is then GC'd and unregistered — not verified
        # here). Sampling must work after each assignment.
        for _ in range(2):
            client._sampling_client_sidecar_handle = create_sidecar_handle(client)
            response = client.sample(_PROMPT, 1, _PARAMS).result(timeout=10)
            assert isinstance(response, types.SampleResponse)
|
|
|
|
|
|
class TestConcurrency:
    """Thread safety and concurrent operations through the sidecar."""

    @staticmethod
    def _start_and_join(threads: list[threading.Thread], timeout: float = 30) -> None:
        """Start every thread, join each one, and fail if any is still running.

        ``Thread.join(timeout=...)`` returns ``None`` whether or not the thread
        finished, so liveness is checked explicitly. Without this check a
        worker that is still blocked when the join times out leaves the
        shared ``errors`` list empty and the test passes even though work
        never completed.

        Args:
            threads: Unstarted worker threads.
            timeout: Per-thread join timeout in seconds.

        Raises:
            AssertionError: If any thread is still alive after its join.
        """
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join(timeout=timeout)
        stuck = [thread.name for thread in threads if thread.is_alive()]
        assert not stuck, f"Threads still running after join timeout: {stuck}"

    def test_multithreaded_samples(self):
        """sample() from 20 threads all resolve correctly."""
        proxy = _create_proxy(delay=0.01)
        results: list[types.SampleResponse | None] = [None] * 20
        errors: list[Exception] = []

        def _worker(idx: int) -> None:
            # Exceptions raised on a worker thread do not reach the test
            # thread, so they are collected and asserted on afterwards.
            try:
                results[idx] = proxy.sample(_PROMPT, 1, _PARAMS).result(timeout=30)
            except Exception as e:
                errors.append(e)

        self._start_and_join(
            [threading.Thread(target=_worker, args=(i,)) for i in range(20)]
        )

        assert not errors, f"Threads raised: {errors}"
        for r in results:
            assert isinstance(r, types.SampleResponse)
            assert r.sequences[0].tokens == [1, 2, 3]

    def test_multithreaded_mixed_operations(self):
        """sample() and compute_logprobs() from different threads simultaneously."""
        proxy = _create_proxy(delay=0.01)
        errors: list[Exception] = []

        def _sample_worker() -> None:
            try:
                for _ in range(10):
                    r = proxy.sample(_PROMPT, 1, _PARAMS).result(timeout=10)
                    assert isinstance(r, types.SampleResponse)
            except Exception as e:
                # Also captures AssertionError from the check above.
                errors.append(e)

        def _logprobs_worker() -> None:
            try:
                for _ in range(10):
                    r = proxy.compute_logprobs(_PROMPT).result(timeout=10)
                    assert r == [0.1, 0.2, None]
            except Exception as e:
                errors.append(e)

        threads = [threading.Thread(target=_sample_worker) for _ in range(3)]
        threads += [threading.Thread(target=_logprobs_worker) for _ in range(3)]
        self._start_and_join(threads)

        assert not errors, f"Errors: {errors}"

    def test_async_concurrent_samples(self):
        """Multiple async sample calls via asyncio.gather all resolve."""
        proxy = _create_proxy(delay=0.01)

        async def _run() -> list[types.SampleResponse]:
            from tinker.lib.public_interfaces.api_future import AwaitableConcurrentFuture

            # Wrap each future returned by sample() so it can be awaited.
            coros = [
                AwaitableConcurrentFuture(proxy.sample(_PROMPT, 1, _PARAMS)) for _ in range(20)
            ]
            return await asyncio.gather(*coros)

        results = asyncio.run(_run())
        assert len(results) == 20
        for r in results:
            assert isinstance(r, types.SampleResponse)

    def test_cancelled_future_does_not_crash_collector(self):
        """Cancelling a future doesn't kill the collector thread."""
        # The 0.5s delay keeps the first request pending when cancel() is called.
        proxy = _create_proxy(delay=0.5)

        future1 = proxy.sample(_PROMPT, 1, _PARAMS)
        future1.cancel()

        # A later request on the same proxy must still resolve.
        result = proxy.sample(_PROMPT, 1, _PARAMS).result(timeout=10)
        assert isinstance(result, types.SampleResponse)

    def test_multiple_clients_share_sidecar(self):
        """Two independent clients sharing the sidecar singleton work concurrently."""
        proxy1 = _create_proxy(delay=0.01)
        proxy2 = _create_proxy(delay=0.01)
        errors: list[Exception] = []

        def _worker1() -> None:
            try:
                for _ in range(10):
                    r = proxy1.sample(_PROMPT, 1, _PARAMS).result(timeout=10)
                    assert isinstance(r, types.SampleResponse)
            except Exception as e:
                errors.append(e)

        def _worker2() -> None:
            try:
                for _ in range(10):
                    r = proxy2.compute_logprobs(_PROMPT).result(timeout=10)
                    assert r == [0.1, 0.2, None]
            except Exception as e:
                errors.append(e)

        self._start_and_join(
            [threading.Thread(target=_worker1), threading.Thread(target=_worker2)]
        )

        assert not errors, f"Errors: {errors}"

    def test_pickle_roundtrip_then_concurrent_use(self):
        """Pickle a client, restore it, then use from multiple threads."""
        proxy = _create_proxy(delay=0.01)
        restored = pickle.loads(pickle.dumps(proxy))
        assert restored._sampling_client_sidecar_handle is not None

        errors: list[Exception] = []

        def _worker() -> None:
            try:
                for _ in range(10):
                    r = restored.sample(_PROMPT, 1, _PARAMS).result(timeout=10)
                    assert isinstance(r, types.SampleResponse)
            except Exception as e:
                errors.append(e)

        self._start_and_join([threading.Thread(target=_worker) for _ in range(5)])

        assert not errors, f"Errors: {errors}"
|