Sync contents

This commit is contained in:
Andrii Grynenko 2026-02-06 06:27:03 +00:00
parent cb07a58315
commit e2afb305fc
5 changed files with 118 additions and 24 deletions

View file

@ -1,4 +1,4 @@
{
"last_synced_sha": "d2a7909ec8644844287a7f2cd210624364627ac4",
"last_sync_time": "2026-02-05T22:04:43.526878"
"last_synced_sha": "b458e037671b2467e93f3976d75056753e552ca0",
"last_sync_time": "2026-02-06T06:27:03.455957"
}

View file

@ -1,6 +1,6 @@
[project]
name = "tinker"
version = "0.11.0"
version = "0.12.0"
description = "The official Python SDK for the tinker API"
readme = "README.md"
license = "Apache-2.0"

View file

@ -4,11 +4,13 @@ from __future__ import annotations
import asyncio
import contextlib
import json
import logging
import os
import threading
import time
import traceback
import weakref
from collections.abc import Coroutine, Generator
from contextlib import AbstractContextManager, asynccontextmanager, contextmanager
from typing import Any, Awaitable, Callable, TypeVar
@ -97,6 +99,29 @@ class InternalClientHolderThreadSingleton:
_internal_client_holder_thread_singleton = InternalClientHolderThreadSingleton()
class _ShadowHolderSingleton:
"""Singleton to cache shadow InternalClientHolders by constructor args."""
def __init__(self):
self._lock: threading.Lock = threading.Lock()
# Key is (session_id, json-serialized kwargs)
self._cache: dict[tuple[str, str], weakref.ref[InternalClientHolder]] = {}
def get_or_create(self, session_id: str, kwargs: dict[str, Any]) -> InternalClientHolder:
key = (session_id, json.dumps(kwargs, sort_keys=True))
with self._lock:
if key in self._cache:
holder = self._cache[key]()
if holder is not None:
return holder
holder = InternalClientHolder(session_id=session_id, **kwargs)
self._cache[key] = weakref.ref(holder)
return holder
_shadow_holder_singleton = _ShadowHolderSingleton()
class BytesSemaphore:
def __init__(self, max_bytes: int):
self._bytes: int = max_bytes
@ -123,29 +148,44 @@ class BytesSemaphore:
class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
def __init__(self, user_metadata: dict[str, str] | None = None, **kwargs: Any) -> None:
def __init__(
self,
user_metadata: dict[str, str] | None = None,
*,
session_id: str | None = None,
**kwargs: Any,
) -> None:
self._constructor_kwargs = kwargs
# So we can use async eventloop for parallel sampling requests
# in sync code.
self._loop: asyncio.AbstractEventLoop = _internal_client_holder_thread_singleton.get_loop()
self._client_pools: dict[ClientConnectionPoolType, ClientConnectionPool] = {}
self._sample_backoff_until: float | None = None
self._sample_dispatch_semaphore: asyncio.Semaphore = asyncio.Semaphore(400)
self._sample_dispatch_throttled_semaphore: asyncio.Semaphore = asyncio.Semaphore(10)
self._sample_dispatch_bytes_semaphore: BytesSemaphore = BytesSemaphore(5 * 1024 * 1024)
self._telemetry: Telemetry | None = None
self._session_heartbeat_task: asyncio.Task[None] | None = None
session_id, session_heartbeat_task = self.run_coroutine_threadsafe(
self._create_session(user_metadata)
).result()
self._session_heartbeat_task = session_heartbeat_task
self._session_id: str = session_id
self._telemetry = init_telemetry(self, session_id=self._session_id)
self._training_client_counter: int = 0
self._training_client_lock: threading.Lock = threading.Lock()
self._sampling_client_counter: int = 0
if session_id is not None:
# Shadow mode: reuse existing session, can't create new clients
self._session_id: str = session_id
self._training_client_counter: int | None = None
self._sampling_client_counter: int | None = None
else:
# Normal mode: create new session
self._session_id = self.run_coroutine_threadsafe(
self._create_session(user_metadata)
).result()
self._training_client_counter = 0
self._sampling_client_counter = 0
self._session_heartbeat_task: asyncio.Task[None] = self.run_coroutine_threadsafe(
self._start_heartbeat()
).result()
self._telemetry: Telemetry | None = init_telemetry(self, session_id=self._session_id)
@classmethod
def get_shadow_holder(cls, session_id: str, kwargs: dict[str, Any]) -> InternalClientHolder:
"""Get or create a shadow holder from the singleton cache."""
return _shadow_holder_singleton.get_or_create(session_id, kwargs)
@asynccontextmanager
async def _sample_dispatch_count_rate_limit(self):
@ -214,6 +254,9 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
) -> str:
if model_path and not model_path.startswith("tinker://"):
raise ValueError("model_path must start with 'tinker://'")
# _create_sampling_session can only be called via a ServiceClient.
# ServiceClient will never have a shadow holder, so we can safely assert.
assert self._sampling_client_counter is not None
sampling_session_seq_id = self._sampling_client_counter
self._sampling_client_counter += 1
with self.aclient(ClientConnectionPoolType.SESSION) as client:
@ -226,9 +269,11 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
result = await client.service.create_sampling_session(request=request)
return result.sampling_session_id
async def _create_session(
self, user_metadata: dict[str, str] | None = None
) -> tuple[str, asyncio.Task[None]]:
async def _start_heartbeat(self) -> asyncio.Task[None]:
"""Start the session heartbeat task."""
return asyncio.create_task(self._session_heartbeat(self._session_id))
async def _create_session(self, user_metadata: dict[str, str] | None = None) -> str:
if (tags_str := os.environ.get("TINKER_TAGS")) is not None:
tags: set[str] = set(tags_str.split(","))
else:
@ -244,9 +289,7 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
logger.warning(result.warning_message)
if result.error_message:
logger.error(result.error_message)
session_id = result.session_id
session_heartbeat_task = asyncio.create_task(self._session_heartbeat(session_id))
return session_id, session_heartbeat_task
return result.session_id
def _get_client_connection_pool(
self, client_pool_type: ClientConnectionPoolType
@ -266,6 +309,9 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
return self._session_id
def get_training_client_id(self) -> int:
# get_training_client_id can only be called via a ServiceClient.
# ServiceClient will never have a shadow holder, so we can safely assert.
assert self._training_client_counter is not None
with self._training_client_lock:
training_client_id = self._training_client_counter
self._training_client_counter += 1

View file

@ -5,10 +5,11 @@ from __future__ import annotations
import asyncio
import logging
import os
import random
import time
from concurrent.futures import Future as ConcurrentFuture
from functools import lru_cache
from typing import TYPE_CHECKING, TypeVar, cast
from typing import TYPE_CHECKING, Any, TypeVar, cast
import tinker
from tinker import types
@ -56,6 +57,14 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
future = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1)
result = future.result()
```
Multi-processing support:
This class is picklable, so it can be passed to a separate process/worker to sample. It is also
safe to pass the same instance of SamplingClient to multiple processes/workers.
If you are using Tinker SDK with more than one process you should always create SamplingClient from
the main process and then pass it to the other processes/workers.
ServiceClient and TrainingClient should always be managed from the main process.
"""
def __init__(
@ -63,6 +72,7 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
holder: InternalClientHolder,
*,
sampling_session_id: str,
shadow: bool = False,
retry_config: RetryConfig | None = None,
):
self.holder = holder
@ -81,6 +91,10 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
self._sampling_session_id: str = sampling_session_id
self._request_id_counter: int = 0
if shadow:
# Start request_id_counter at a random high value to avoid collisions
# with the original client or other unpickled copies
self._request_id_counter = 1_000_000_000 * random.randint(1, 1_000_000)
@staticmethod
async def _create_impl(
@ -323,6 +337,21 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
def get_telemetry(self) -> Telemetry | None:
return self.holder.get_telemetry()
def __reduce__(self) -> tuple[Any, tuple[str, str, dict[str, Any]]]:
"""Enable pickling of SamplingClient for subprocess use.
Stores the sampling_session_id and holder constructor kwargs.
On unpickle, creates a shadow holder and reconstructs the client.
"""
return (
_unpickle_sampling_client,
(
self.holder.get_session_id(),
self._sampling_session_id,
self.holder._constructor_kwargs,
),
)
def on_queue_state_change(
self, queue_state: QueueState, queue_state_reason: str | None
) -> None:
@ -345,6 +374,23 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
)
def _unpickle_sampling_client(
session_id: str,
sampling_session_id: str,
constructor_kwargs: dict[str, Any],
) -> SamplingClient:
"""Reconstruct a SamplingClient from pickled data.
Creates a shadow InternalClientHolder and builds a new SamplingClient.
The request_id_counter starts at a random high value to avoid collisions.
"""
from ..internal_client_holder import InternalClientHolder
holder = InternalClientHolder.get_shadow_holder(session_id, constructor_kwargs)
client = SamplingClient(holder, sampling_session_id=sampling_session_id, shadow=True)
return client
@lru_cache(maxsize=100)
def _get_retry_handler(
name: str, retry_config: RetryConfig | None = None, telemetry: Telemetry | None = None

View file

@ -644,6 +644,8 @@ class TrainingClient(TelemetryProvider):
ttl_seconds=ttl_seconds,
)
else:
# Training client can never be created from a shadow holder, so we can safely assert
assert self.holder._sampling_client_counter is not None
sampling_session_seq_id = self.holder._sampling_client_counter
self.holder._sampling_client_counter += 1
request = types.SaveWeightsForSamplerRequest(