diff --git a/.sync_state b/.sync_state index a6ae7d5..9ea7d8c 100644 --- a/.sync_state +++ b/.sync_state @@ -1,4 +1,4 @@ { - "last_synced_sha": "d2a7909ec8644844287a7f2cd210624364627ac4", - "last_sync_time": "2026-02-05T22:04:43.526878" + "last_synced_sha": "b458e037671b2467e93f3976d75056753e552ca0", + "last_sync_time": "2026-02-06T06:27:03.455957" } \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index af04ae9..6bb8029 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "tinker" -version = "0.11.0" +version = "0.12.0" description = "The official Python SDK for the tinker API" readme = "README.md" license = "Apache-2.0" diff --git a/src/tinker/lib/internal_client_holder.py b/src/tinker/lib/internal_client_holder.py index bc8e9a6..b4769a9 100644 --- a/src/tinker/lib/internal_client_holder.py +++ b/src/tinker/lib/internal_client_holder.py @@ -4,11 +4,13 @@ from __future__ import annotations import asyncio import contextlib +import json import logging import os import threading import time import traceback +import weakref from collections.abc import Coroutine, Generator from contextlib import AbstractContextManager, asynccontextmanager, contextmanager from typing import Any, Awaitable, Callable, TypeVar @@ -97,6 +99,29 @@ class InternalClientHolderThreadSingleton: _internal_client_holder_thread_singleton = InternalClientHolderThreadSingleton() +class _ShadowHolderSingleton: + """Singleton to cache shadow InternalClientHolders by constructor args.""" + + def __init__(self): + self._lock: threading.Lock = threading.Lock() + # Key is (session_id, json-serialized kwargs) + self._cache: dict[tuple[str, str], weakref.ref[InternalClientHolder]] = {} + + def get_or_create(self, session_id: str, kwargs: dict[str, Any]) -> InternalClientHolder: + key = (session_id, json.dumps(kwargs, sort_keys=True)) + with self._lock: + if key in self._cache: + holder = self._cache[key]() + if holder is not None: + return holder + holder = InternalClientHolder(session_id=session_id, **kwargs) + self._cache[key] = weakref.ref(holder) + return holder + + +_shadow_holder_singleton = _ShadowHolderSingleton() + + class BytesSemaphore: def __init__(self, max_bytes: int): self._bytes: int = max_bytes @@ -123,29 +148,44 @@ class BytesSemaphore: class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider): - def __init__(self, user_metadata: dict[str, str] | None = None, **kwargs: Any) -> None: + def __init__( + self, + user_metadata: dict[str, str] | None = None, + *, + session_id: str | None = None, + **kwargs: Any, + ) -> None: self._constructor_kwargs = kwargs - # So we can use async eventloop for parallel sampling requests - # in sync code. self._loop: asyncio.AbstractEventLoop = _internal_client_holder_thread_singleton.get_loop() self._client_pools: dict[ClientConnectionPoolType, ClientConnectionPool] = {} self._sample_backoff_until: float | None = None self._sample_dispatch_semaphore: asyncio.Semaphore = asyncio.Semaphore(400) self._sample_dispatch_throttled_semaphore: asyncio.Semaphore = asyncio.Semaphore(10) self._sample_dispatch_bytes_semaphore: BytesSemaphore = BytesSemaphore(5 * 1024 * 1024) - self._telemetry: Telemetry | None = None - self._session_heartbeat_task: asyncio.Task[None] | None = None - session_id, session_heartbeat_task = self.run_coroutine_threadsafe( - self._create_session(user_metadata) - ).result() - self._session_heartbeat_task = session_heartbeat_task - self._session_id: str = session_id - self._telemetry = init_telemetry(self, session_id=self._session_id) - - self._training_client_counter: int = 0 self._training_client_lock: threading.Lock = threading.Lock() - self._sampling_client_counter: int = 0 + if session_id is not None: + # Shadow mode: reuse existing session, can't create new clients + self._session_id: str = session_id + self._training_client_counter: int | None = None + self._sampling_client_counter: int | None = None + else: + # Normal mode: create new session + self._session_id = self.run_coroutine_threadsafe( + self._create_session(user_metadata) + ).result() + self._training_client_counter = 0 + self._sampling_client_counter = 0 + + self._session_heartbeat_task: asyncio.Task[None] = self.run_coroutine_threadsafe( + self._start_heartbeat() + ).result() + self._telemetry: Telemetry | None = init_telemetry(self, session_id=self._session_id) + + @classmethod + def get_shadow_holder(cls, session_id: str, kwargs: dict[str, Any]) -> InternalClientHolder: + """Get or create a shadow holder from the singleton cache.""" + return _shadow_holder_singleton.get_or_create(session_id, kwargs) @asynccontextmanager async def _sample_dispatch_count_rate_limit(self): @@ -214,6 +254,9 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider): ) -> str: if model_path and not model_path.startswith("tinker://"): raise ValueError("model_path must start with 'tinker://'") + # _create_sampling_session can only be called via a ServiceClient. + # ServiceClient will never have a shadow holder, so we can safely assert. + assert self._sampling_client_counter is not None sampling_session_seq_id = self._sampling_client_counter self._sampling_client_counter += 1 with self.aclient(ClientConnectionPoolType.SESSION) as client: @@ -226,9 +269,11 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider): result = await client.service.create_sampling_session(request=request) return result.sampling_session_id - async def _create_session( - self, user_metadata: dict[str, str] | None = None - ) -> tuple[str, asyncio.Task[None]]: + async def _start_heartbeat(self) -> asyncio.Task[None]: + """Start the session heartbeat task.""" + return asyncio.create_task(self._session_heartbeat(self._session_id)) + + async def _create_session(self, user_metadata: dict[str, str] | None = None) -> str: if (tags_str := os.environ.get("TINKER_TAGS")) is not None: tags: set[str] = set(tags_str.split(",")) else: @@ -244,9 +289,7 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider): logger.warning(result.warning_message) if result.error_message: logger.error(result.error_message) - session_id = result.session_id - session_heartbeat_task = asyncio.create_task(self._session_heartbeat(session_id)) - return session_id, session_heartbeat_task + return result.session_id def _get_client_connection_pool( self, client_pool_type: ClientConnectionPoolType @@ -266,6 +309,9 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider): return self._session_id def get_training_client_id(self) -> int: + # get_training_client_id can only be called via a ServiceClient. + # ServiceClient will never have a shadow holder, so we can safely assert. + assert self._training_client_counter is not None with self._training_client_lock: training_client_id = self._training_client_counter self._training_client_counter += 1 diff --git a/src/tinker/lib/public_interfaces/sampling_client.py b/src/tinker/lib/public_interfaces/sampling_client.py index cddca92..7e82492 100644 --- a/src/tinker/lib/public_interfaces/sampling_client.py +++ b/src/tinker/lib/public_interfaces/sampling_client.py @@ -5,10 +5,11 @@ from __future__ import annotations import asyncio import logging import os +import random import time from concurrent.futures import Future as ConcurrentFuture from functools import lru_cache -from typing import TYPE_CHECKING, TypeVar, cast +from typing import TYPE_CHECKING, Any, TypeVar, cast import tinker from tinker import types @@ -56,6 +57,14 @@ class SamplingClient(TelemetryProvider, QueueStateObserver): future = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1) result = future.result() ``` + + Multi-processing support: + This class is picklable, so it can be passed to a separate process/worker to sample. It is also + safe to pass the same instance of SamplingClient to multiple processes/workers. + + If you are using Tinker SDK with more than one process you should always create SamplingClient from + the main process and then pass it to the other processes/workers. + ServiceClient and TrainingClient should always be managed from the main process. """ def __init__( @@ -63,6 +72,7 @@ class SamplingClient(TelemetryProvider, QueueStateObserver): holder: InternalClientHolder, *, sampling_session_id: str, + shadow: bool = False, retry_config: RetryConfig | None = None, ): self.holder = holder @@ -81,6 +91,10 @@ class SamplingClient(TelemetryProvider, QueueStateObserver): self._sampling_session_id: str = sampling_session_id self._request_id_counter: int = 0 + if shadow: + # Start request_id_counter at a random high value to avoid collisions + # with the original client or other unpickled copies + self._request_id_counter = 1_000_000_000 * random.randint(1, 1_000_000) @staticmethod async def _create_impl( @@ -323,6 +337,21 @@ class SamplingClient(TelemetryProvider, QueueStateObserver): def get_telemetry(self) -> Telemetry | None: return self.holder.get_telemetry() + def __reduce__(self) -> tuple[Any, tuple[str, str, dict[str, Any]]]: + """Enable pickling of SamplingClient for subprocess use. + + Stores the sampling_session_id and holder constructor kwargs. + On unpickle, creates a shadow holder and reconstructs the client. + """ + return ( + _unpickle_sampling_client, + ( + self.holder.get_session_id(), + self._sampling_session_id, + self.holder._constructor_kwargs, + ), + ) + def on_queue_state_change( self, queue_state: QueueState, queue_state_reason: str | None ) -> None: @@ -345,6 +374,23 @@ class SamplingClient(TelemetryProvider, QueueStateObserver): ) +def _unpickle_sampling_client( + session_id: str, + sampling_session_id: str, + constructor_kwargs: dict[str, Any], +) -> SamplingClient: + """Reconstruct a SamplingClient from pickled data. + + Creates a shadow InternalClientHolder and builds a new SamplingClient. + The request_id_counter starts at a random high value to avoid collisions. + """ + from ..internal_client_holder import InternalClientHolder + + holder = InternalClientHolder.get_shadow_holder(session_id, constructor_kwargs) + client = SamplingClient(holder, sampling_session_id=sampling_session_id, shadow=True) + return client + + @lru_cache(maxsize=100) def _get_retry_handler( name: str, retry_config: RetryConfig | None = None, telemetry: Telemetry | None = None diff --git a/src/tinker/lib/public_interfaces/training_client.py b/src/tinker/lib/public_interfaces/training_client.py index aa40c78..67c94cd 100644 --- a/src/tinker/lib/public_interfaces/training_client.py +++ b/src/tinker/lib/public_interfaces/training_client.py @@ -644,6 +644,8 @@ class TrainingClient(TelemetryProvider): ttl_seconds=ttl_seconds, ) else: + # Training client can never be created from a shadow holder, so we can safely assert + assert self.holder._sampling_client_counter is not None sampling_session_seq_id = self.holder._sampling_client_counter self.holder._sampling_client_counter += 1 request = types.SaveWeightsForSamplerRequest(