tinker/tests/test_subprocess_sampling_client.py
2026-03-19 00:10:49 +00:00

384 lines
14 KiB
Python

"""Tests for SamplingClient subprocess mode.
These tests use a picklable fake SamplingClient to verify that
subprocess mode correctly routes sample() and compute_logprobs()
through the sidecar subprocess.
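Test organization:
  TestRPCRouting — sample/compute_logprobs delegation through the sidecar,
    including opt-in via the constructor's subprocess_sampling flag
  TestErrorHandling — error propagation, sidecar death
  TestPickle — roundtrip with/without sidecar, re-enabling subprocess mode
  TestConcurrency — multithreaded, async, cancelled futures, mixed ops,
    clients sharing one sidecar, concurrent use after unpickling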
"""
from __future__ import annotations

import asyncio
import pickle
import threading
import time
from collections.abc import Callable
from concurrent.futures import Future as ConcurrentFuture
from typing import Any

import pytest

from tinker import types
from tinker._exceptions import SidecarDiedError
from tinker.lib.sidecar import create_sidecar_handle
# ---------------------------------------------------------------------------
# Picklable fake SamplingClient (must be module-level for pickling)
# ---------------------------------------------------------------------------
class _FakeSamplingClient:
    """A picklable fake that simulates SamplingClient for testing.

    This is NOT a real SamplingClient — it provides just enough interface
    to test the sidecar integration. Real SamplingClient requires an
    InternalClientHolder and API connection.

    Args:
        delay: Seconds a locally-served call waits (on a background daemon
            thread) before resolving its future. Applies to both ``sample()``
            and ``compute_logprobs()`` so mixed workloads really have both
            kinds of request in flight at once.
        fail: When True, locally-served calls resolve immediately with a
            ``RuntimeError`` instead of a result.
        subprocess_sampling: When True, and not already running inside the
            sidecar, attach a sidecar handle so calls are routed through the
            subprocess.
    """

    def __init__(self, delay: float = 0.0, fail: bool = False, subprocess_sampling: bool = False):
        self._delay = delay
        self._fail = fail
        # Tests may also assign this directly with create_sidecar_handle().
        self._sampling_client_sidecar_handle = None
        if subprocess_sampling:
            from tinker.lib.sidecar import _inside_sidecar

            # __reduce__ re-invokes this constructor with subprocess_sampling=True,
            # including (presumably) when the sidecar unpickles the client; the
            # guard keeps the subprocess from attaching a handle of its own.
            if not _inside_sidecar:
                self._sampling_client_sidecar_handle = create_sidecar_handle(self)

    def _resolve_locally(
        self, make_result: Callable[[], Any], error_message: str
    ) -> ConcurrentFuture[Any]:
        """Return a future resolved according to this fake's ``fail``/``delay``.

        Args:
            make_result: Zero-arg factory for the success value. It is invoked
                only when the future is resolved (after the delay, if any).
            error_message: Message of the ``RuntimeError`` used when ``fail``
                is set.
        """
        future: ConcurrentFuture[Any] = ConcurrentFuture()
        if self._fail:
            future.set_exception(RuntimeError(error_message))
        elif self._delay > 0:

            def _finish_later() -> None:
                time.sleep(self._delay)
                future.set_result(make_result())

            threading.Thread(target=_finish_later, daemon=True).start()
        else:
            future.set_result(make_result())
        return future

    def sample(
        self,
        prompt: types.ModelInput,
        num_samples: int,
        sampling_params: types.SamplingParams,
        include_prompt_logprobs: bool = False,
        topk_prompt_logprobs: int = 0,
    ) -> Any:
        """Return a future for a fake sample() call.

        With a sidecar handle attached, the call is submitted as a
        ``_SampleRPC``, mirroring the real SamplingClient. Otherwise it resolves
        locally to ``_make_sample_response()``, or fails with
        "Simulated sample failure".
        """
        if self._sampling_client_sidecar_handle is not None:
            from tinker.lib.public_interfaces.sampling_client import _SampleRPC

            return self._sampling_client_sidecar_handle.submit_rpc(
                _SampleRPC(
                    prompt,
                    num_samples,
                    sampling_params,
                    include_prompt_logprobs,
                    topk_prompt_logprobs,
                )
            )
        return self._resolve_locally(_make_sample_response, "Simulated sample failure")

    def compute_logprobs(self, prompt: types.ModelInput) -> Any:
        """Return a future for fake prompt logprobs.

        With a sidecar handle attached, the call is submitted as a
        ``_ComputeLogprobsRPC``. Otherwise it resolves locally to
        ``[0.1, 0.2, None]``, or fails with "Simulated logprobs failure".
        """
        if self._sampling_client_sidecar_handle is not None:
            from tinker.lib.public_interfaces.sampling_client import _ComputeLogprobsRPC

            return self._sampling_client_sidecar_handle.submit_rpc(_ComputeLogprobsRPC(prompt))
        # The lambda builds a fresh list on every call, so callers never share state.
        return self._resolve_locally(lambda: [0.1, 0.2, None], "Simulated logprobs failure")

    def __reduce__(self) -> tuple[type, tuple[float, bool, bool]]:
        """Pickle as constructor arguments rather than instance state.

        The third argument records whether a sidecar handle was attached. It is
        passed back as ``subprocess_sampling``, so unpickling creates a fresh
        handle instead of copying the old one.
        """
        return (
            _FakeSamplingClient,
            (self._delay, self._fail, self._sampling_client_sidecar_handle is not None),
        )
def _make_sample_response() -> types.SampleResponse:
    """Build the canned response the fake client returns from sample().

    It holds a single sequence with tokens ``[1, 2, 3]``, per-token logprobs
    ``[0.1, 0.2, 0.3]`` and stop reason ``"length"``.
    """
    only_sequence = types.SampledSequence(
        stop_reason="length",
        tokens=[1, 2, 3],
        logprobs=[0.1, 0.2, 0.3],
    )
    return types.SampleResponse(type="sample", sequences=[only_sequence])
def _create_proxy(delay: float = 0.0, fail: bool = False) -> _FakeSamplingClient:
    """Return a fake client whose calls are routed through a sidecar handle.

    Unlike ``subprocess_sampling=True``, this attaches the handle after
    construction without consulting ``_inside_sidecar``.
    """
    fake = _FakeSamplingClient(delay=delay, fail=fail)
    handle = create_sidecar_handle(fake)
    fake._sampling_client_sidecar_handle = handle
    return fake
# Shared inputs for every test: a three-token prompt and a 10-token sampling cap.
_PROMPT: types.ModelInput = types.ModelInput.from_ints([1, 2, 3])
_PARAMS: types.SamplingParams = types.SamplingParams(max_tokens=10)
# ===========================================================================
# Tests
# ===========================================================================
class TestRPCRouting:
    """Check that sample() and compute_logprobs() are served via the sidecar."""

    def test_sample(self):
        """A routed sample() call yields the fake's canned SampleResponse."""
        client = _create_proxy()
        response = client.sample(_PROMPT, 1, _PARAMS).result(timeout=10)
        assert isinstance(response, types.SampleResponse)
        assert response.sequences[0].tokens == [1, 2, 3]

    def test_constructor_enables_subprocess_mode(self):
        """Passing subprocess_sampling=True to the constructor attaches a handle."""
        client = _FakeSamplingClient(subprocess_sampling=True)
        assert client._sampling_client_sidecar_handle is not None
        response = client.sample(_PROMPT, 1, _PARAMS).result(timeout=10)
        assert isinstance(response, types.SampleResponse)

    def test_compute_logprobs(self):
        """A routed compute_logprobs() call yields the fake's logprob list."""
        client = _create_proxy()
        logprobs = client.compute_logprobs(_PROMPT).result(timeout=10)
        assert logprobs == [0.1, 0.2, None]

    def test_mixed_sample_and_logprobs(self):
        """Many outstanding sample() and compute_logprobs() calls all resolve."""
        client = _create_proxy(delay=0.01)
        sample_futures = [client.sample(_PROMPT, 1, _PARAMS) for _ in range(10)]
        logprob_futures = [client.compute_logprobs(_PROMPT) for _ in range(10)]
        for pending in sample_futures:
            response = pending.result(timeout=10)
            assert isinstance(response, types.SampleResponse)
            assert response.sequences[0].tokens == [1, 2, 3]
        for pending in logprob_futures:
            logprobs = pending.result(timeout=10)
            assert logprobs == [0.1, 0.2, None]
class TestErrorHandling:
    """Check that failures on the subprocess side surface to the caller."""

    def test_sample_error(self):
        """A RuntimeError raised by sample() in the subprocess reaches the caller."""
        failing_client = _create_proxy(fail=True)
        with pytest.raises(RuntimeError, match="Simulated sample failure"):
            failing_client.sample(_PROMPT, 1, _PARAMS).result(timeout=10)

    def test_compute_logprobs_error(self):
        """A RuntimeError raised by compute_logprobs() in the subprocess reaches the caller."""
        failing_client = _create_proxy(fail=True)
        with pytest.raises(RuntimeError, match="Simulated logprobs failure"):
            failing_client.compute_logprobs(_PROMPT).result(timeout=10)

    @pytest.mark.filterwarnings("ignore::pytest.PytestUnhandledThreadExceptionWarning")
    def test_sidecar_death_fails_pending_futures(self):
        """Killing the subprocess makes an in-flight future raise SidecarDiedError."""
        slow_client = _create_proxy(delay=0.5)
        in_flight = slow_client.sample(_PROMPT, 1, _PARAMS)

        # Reach past the handle to the sidecar object and kill its process.
        sidecar_obj = slow_client._sampling_client_sidecar_handle._sidecar
        assert sidecar_obj._process is not None
        sidecar_obj._process.kill()
        sidecar_obj._process.join(timeout=5)

        with pytest.raises(SidecarDiedError):
            in_flight.result(timeout=5)
class TestPickle:
    """Check that a pickle roundtrip keeps subprocess mode on or off as appropriate."""

    def test_roundtrip_preserves_subprocess_mode(self):
        """A sidecar-enabled client is restored from pickle with subprocess mode on."""
        original = _create_proxy()
        assert original._sampling_client_sidecar_handle is not None
        payload = pickle.dumps(original)
        clone = pickle.loads(payload)
        assert clone._sampling_client_sidecar_handle is not None
        response = clone.sample(_PROMPT, 1, _PARAMS).result(timeout=10)
        assert isinstance(response, types.SampleResponse)

    def test_roundtrip_without_sidecar(self):
        """A client without a sidecar handle is restored without one."""
        plain = _FakeSamplingClient()
        assert plain._sampling_client_sidecar_handle is None
        payload = pickle.dumps(plain)
        clone = pickle.loads(payload)
        assert clone._sampling_client_sidecar_handle is None

    def test_re_enable_subprocess_mode(self):
        """Attaching a fresh sidecar handle twice in a row keeps sampling working."""
        client = _FakeSamplingClient()
        for _ in range(2):
            # Each pass swaps in a brand-new handle, dropping the previous one.
            client._sampling_client_sidecar_handle = create_sidecar_handle(client)
            outcome = client.sample(_PROMPT, 1, _PARAMS).result(timeout=10)
            assert isinstance(outcome, types.SampleResponse)
class TestConcurrency:
    """Thread safety and concurrent operations through the sidecar."""

    @staticmethod
    def _run_threads(threads: list[threading.Thread], timeout: float = 30.0) -> None:
        """Start *threads*, wait for them, and fail if any is still running.

        All joins share one deadline, so the total wait is bounded by
        ``timeout`` instead of ``timeout`` per thread. The liveness check keeps
        a stuck worker from letting a test pass vacuously with an empty error
        list. Callers create the threads as daemons so a stuck worker cannot
        also block interpreter exit.
        """
        for t in threads:
            t.start()
        deadline = time.monotonic() + timeout
        for t in threads:
            t.join(timeout=max(0.0, deadline - time.monotonic()))
        stuck = [t.name for t in threads if t.is_alive()]
        assert not stuck, f"Threads did not finish within {timeout}s: {stuck}"

    def test_multithreaded_samples(self):
        """sample() from 20 threads all resolve correctly."""
        proxy = _create_proxy(delay=0.01)
        results: list[types.SampleResponse | None] = [None] * 20
        errors: list[Exception] = []

        def _worker(idx: int) -> None:
            try:
                results[idx] = proxy.sample(_PROMPT, 1, _PARAMS).result(timeout=30)
            except Exception as e:
                errors.append(e)

        self._run_threads(
            [threading.Thread(target=_worker, args=(i,), daemon=True) for i in range(20)]
        )
        assert not errors, f"Threads raised: {errors}"
        for r in results:
            assert isinstance(r, types.SampleResponse)
            assert r.sequences[0].tokens == [1, 2, 3]

    def test_multithreaded_mixed_operations(self):
        """sample() and compute_logprobs() from different threads simultaneously."""
        proxy = _create_proxy(delay=0.01)
        errors: list[Exception] = []

        def _sample_worker() -> None:
            try:
                for _ in range(10):
                    r = proxy.sample(_PROMPT, 1, _PARAMS).result(timeout=10)
                    assert isinstance(r, types.SampleResponse)
            except Exception as e:
                errors.append(e)

        def _logprobs_worker() -> None:
            try:
                for _ in range(10):
                    r = proxy.compute_logprobs(_PROMPT).result(timeout=10)
                    assert r == [0.1, 0.2, None]
            except Exception as e:
                errors.append(e)

        threads = [threading.Thread(target=_sample_worker, daemon=True) for _ in range(3)]
        threads += [threading.Thread(target=_logprobs_worker, daemon=True) for _ in range(3)]
        self._run_threads(threads)
        assert not errors, f"Errors: {errors}"

    def test_async_concurrent_samples(self):
        """Multiple async sample calls via asyncio.gather all resolve."""
        proxy = _create_proxy(delay=0.01)

        async def _run() -> list[types.SampleResponse]:
            from tinker.lib.public_interfaces.api_future import AwaitableConcurrentFuture

            awaitables = [
                AwaitableConcurrentFuture(proxy.sample(_PROMPT, 1, _PARAMS)) for _ in range(20)
            ]
            return await asyncio.gather(*awaitables)

        results = asyncio.run(_run())
        assert len(results) == 20
        for r in results:
            assert isinstance(r, types.SampleResponse)

    def test_cancelled_future_does_not_crash_collector(self):
        """Cancelling a future doesn't kill the collector thread."""
        proxy = _create_proxy(delay=0.5)
        future1 = proxy.sample(_PROMPT, 1, _PARAMS)
        future1.cancel()
        # A later call must still be serviced after the cancelled one.
        result = proxy.sample(_PROMPT, 1, _PARAMS).result(timeout=10)
        assert isinstance(result, types.SampleResponse)

    def test_multiple_clients_share_sidecar(self):
        """Two independent clients sharing the sidecar singleton work concurrently."""
        proxy1 = _create_proxy(delay=0.01)
        proxy2 = _create_proxy(delay=0.01)
        errors: list[Exception] = []

        def _worker1() -> None:
            try:
                for _ in range(10):
                    r = proxy1.sample(_PROMPT, 1, _PARAMS).result(timeout=10)
                    assert isinstance(r, types.SampleResponse)
            except Exception as e:
                errors.append(e)

        def _worker2() -> None:
            try:
                for _ in range(10):
                    r = proxy2.compute_logprobs(_PROMPT).result(timeout=10)
                    assert r == [0.1, 0.2, None]
            except Exception as e:
                errors.append(e)

        self._run_threads(
            [
                threading.Thread(target=_worker1, daemon=True),
                threading.Thread(target=_worker2, daemon=True),
            ]
        )
        assert not errors, f"Errors: {errors}"

    def test_pickle_roundtrip_then_concurrent_use(self):
        """Pickle a client, restore it, then use from multiple threads."""
        proxy = _create_proxy(delay=0.01)
        restored = pickle.loads(pickle.dumps(proxy))
        assert restored._sampling_client_sidecar_handle is not None
        errors: list[Exception] = []

        def _worker() -> None:
            try:
                for _ in range(10):
                    r = restored.sample(_PROMPT, 1, _PARAMS).result(timeout=10)
                    assert isinstance(r, types.SampleResponse)
            except Exception as e:
                errors.append(e)

        self._run_threads([threading.Thread(target=_worker, daemon=True) for _ in range(5)])
        assert not errors, f"Errors: {errors}"