mirror of
https://github.com/thinking-machines-lab/tinker.git
synced 2026-04-19 12:58:01 +00:00
Sync contents
This commit is contained in:
parent
cb07a58315
commit
e2afb305fc
5 changed files with 118 additions and 24 deletions
|
|
@ -1,4 +1,4 @@
|
|||
{
|
||||
"last_synced_sha": "d2a7909ec8644844287a7f2cd210624364627ac4",
|
||||
"last_sync_time": "2026-02-05T22:04:43.526878"
|
||||
"last_synced_sha": "b458e037671b2467e93f3976d75056753e552ca0",
|
||||
"last_sync_time": "2026-02-06T06:27:03.455957"
|
||||
}
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
[project]
|
||||
name = "tinker"
|
||||
version = "0.11.0"
|
||||
version = "0.12.0"
|
||||
description = "The official Python SDK for the tinker API"
|
||||
readme = "README.md"
|
||||
license = "Apache-2.0"
|
||||
|
|
|
|||
|
|
@ -4,11 +4,13 @@ from __future__ import annotations
|
|||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
import weakref
|
||||
from collections.abc import Coroutine, Generator
|
||||
from contextlib import AbstractContextManager, asynccontextmanager, contextmanager
|
||||
from typing import Any, Awaitable, Callable, TypeVar
|
||||
|
|
@ -97,6 +99,29 @@ class InternalClientHolderThreadSingleton:
|
|||
_internal_client_holder_thread_singleton = InternalClientHolderThreadSingleton()
|
||||
|
||||
|
||||
class _ShadowHolderSingleton:
|
||||
"""Singleton to cache shadow InternalClientHolders by constructor args."""
|
||||
|
||||
def __init__(self):
|
||||
self._lock: threading.Lock = threading.Lock()
|
||||
# Key is (session_id, json-serialized kwargs)
|
||||
self._cache: dict[tuple[str, str], weakref.ref[InternalClientHolder]] = {}
|
||||
|
||||
def get_or_create(self, session_id: str, kwargs: dict[str, Any]) -> InternalClientHolder:
|
||||
key = (session_id, json.dumps(kwargs, sort_keys=True))
|
||||
with self._lock:
|
||||
if key in self._cache:
|
||||
holder = self._cache[key]()
|
||||
if holder is not None:
|
||||
return holder
|
||||
holder = InternalClientHolder(session_id=session_id, **kwargs)
|
||||
self._cache[key] = weakref.ref(holder)
|
||||
return holder
|
||||
|
||||
|
||||
_shadow_holder_singleton = _ShadowHolderSingleton()
|
||||
|
||||
|
||||
class BytesSemaphore:
|
||||
def __init__(self, max_bytes: int):
|
||||
self._bytes: int = max_bytes
|
||||
|
|
@ -123,29 +148,44 @@ class BytesSemaphore:
|
|||
|
||||
|
||||
class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
|
||||
def __init__(self, user_metadata: dict[str, str] | None = None, **kwargs: Any) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
user_metadata: dict[str, str] | None = None,
|
||||
*,
|
||||
session_id: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self._constructor_kwargs = kwargs
|
||||
# So we can use async eventloop for parallel sampling requests
|
||||
# in sync code.
|
||||
self._loop: asyncio.AbstractEventLoop = _internal_client_holder_thread_singleton.get_loop()
|
||||
self._client_pools: dict[ClientConnectionPoolType, ClientConnectionPool] = {}
|
||||
self._sample_backoff_until: float | None = None
|
||||
self._sample_dispatch_semaphore: asyncio.Semaphore = asyncio.Semaphore(400)
|
||||
self._sample_dispatch_throttled_semaphore: asyncio.Semaphore = asyncio.Semaphore(10)
|
||||
self._sample_dispatch_bytes_semaphore: BytesSemaphore = BytesSemaphore(5 * 1024 * 1024)
|
||||
self._telemetry: Telemetry | None = None
|
||||
self._session_heartbeat_task: asyncio.Task[None] | None = None
|
||||
session_id, session_heartbeat_task = self.run_coroutine_threadsafe(
|
||||
self._create_session(user_metadata)
|
||||
).result()
|
||||
self._session_heartbeat_task = session_heartbeat_task
|
||||
self._session_id: str = session_id
|
||||
self._telemetry = init_telemetry(self, session_id=self._session_id)
|
||||
|
||||
self._training_client_counter: int = 0
|
||||
self._training_client_lock: threading.Lock = threading.Lock()
|
||||
|
||||
self._sampling_client_counter: int = 0
|
||||
if session_id is not None:
|
||||
# Shadow mode: reuse existing session, can't create new clients
|
||||
self._session_id: str = session_id
|
||||
self._training_client_counter: int | None = None
|
||||
self._sampling_client_counter: int | None = None
|
||||
else:
|
||||
# Normal mode: create new session
|
||||
self._session_id = self.run_coroutine_threadsafe(
|
||||
self._create_session(user_metadata)
|
||||
).result()
|
||||
self._training_client_counter = 0
|
||||
self._sampling_client_counter = 0
|
||||
|
||||
self._session_heartbeat_task: asyncio.Task[None] = self.run_coroutine_threadsafe(
|
||||
self._start_heartbeat()
|
||||
).result()
|
||||
self._telemetry: Telemetry | None = init_telemetry(self, session_id=self._session_id)
|
||||
|
||||
@classmethod
|
||||
def get_shadow_holder(cls, session_id: str, kwargs: dict[str, Any]) -> InternalClientHolder:
|
||||
"""Get or create a shadow holder from the singleton cache."""
|
||||
return _shadow_holder_singleton.get_or_create(session_id, kwargs)
|
||||
|
||||
@asynccontextmanager
|
||||
async def _sample_dispatch_count_rate_limit(self):
|
||||
|
|
@ -214,6 +254,9 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
|
|||
) -> str:
|
||||
if model_path and not model_path.startswith("tinker://"):
|
||||
raise ValueError("model_path must start with 'tinker://'")
|
||||
# _create_sampling_session can only be called via a ServiceClient.
|
||||
# ServiceClient will never have a shadow holder, so we can safely assert.
|
||||
assert self._sampling_client_counter is not None
|
||||
sampling_session_seq_id = self._sampling_client_counter
|
||||
self._sampling_client_counter += 1
|
||||
with self.aclient(ClientConnectionPoolType.SESSION) as client:
|
||||
|
|
@ -226,9 +269,11 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
|
|||
result = await client.service.create_sampling_session(request=request)
|
||||
return result.sampling_session_id
|
||||
|
||||
async def _create_session(
|
||||
self, user_metadata: dict[str, str] | None = None
|
||||
) -> tuple[str, asyncio.Task[None]]:
|
||||
async def _start_heartbeat(self) -> asyncio.Task[None]:
|
||||
"""Start the session heartbeat task."""
|
||||
return asyncio.create_task(self._session_heartbeat(self._session_id))
|
||||
|
||||
async def _create_session(self, user_metadata: dict[str, str] | None = None) -> str:
|
||||
if (tags_str := os.environ.get("TINKER_TAGS")) is not None:
|
||||
tags: set[str] = set(tags_str.split(","))
|
||||
else:
|
||||
|
|
@ -244,9 +289,7 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
|
|||
logger.warning(result.warning_message)
|
||||
if result.error_message:
|
||||
logger.error(result.error_message)
|
||||
session_id = result.session_id
|
||||
session_heartbeat_task = asyncio.create_task(self._session_heartbeat(session_id))
|
||||
return session_id, session_heartbeat_task
|
||||
return result.session_id
|
||||
|
||||
def _get_client_connection_pool(
|
||||
self, client_pool_type: ClientConnectionPoolType
|
||||
|
|
@ -266,6 +309,9 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
|
|||
return self._session_id
|
||||
|
||||
def get_training_client_id(self) -> int:
|
||||
# get_training_client_id can only be called via a ServiceClient.
|
||||
# ServiceClient will never have a shadow holder, so we can safely assert.
|
||||
assert self._training_client_counter is not None
|
||||
with self._training_client_lock:
|
||||
training_client_id = self._training_client_counter
|
||||
self._training_client_counter += 1
|
||||
|
|
|
|||
|
|
@ -5,10 +5,11 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from concurrent.futures import Future as ConcurrentFuture
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING, TypeVar, cast
|
||||
from typing import TYPE_CHECKING, Any, TypeVar, cast
|
||||
|
||||
import tinker
|
||||
from tinker import types
|
||||
|
|
@ -56,6 +57,14 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
|
|||
future = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1)
|
||||
result = future.result()
|
||||
```
|
||||
|
||||
Multi-processing support:
|
||||
This class is picklable, so it can be passed to a separate process/worker to sample. It is also
|
||||
safe to pass the same instance of SamplingClient to multiple processes/workers.
|
||||
|
||||
If you are using Tinker SDK with more than one process you should always create SamplingClient from
|
||||
the main process and then pass it to the other processes/workers.
|
||||
ServiceClient and TrainingClient should always be managed from the main process.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -63,6 +72,7 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
|
|||
holder: InternalClientHolder,
|
||||
*,
|
||||
sampling_session_id: str,
|
||||
shadow: bool = False,
|
||||
retry_config: RetryConfig | None = None,
|
||||
):
|
||||
self.holder = holder
|
||||
|
|
@ -81,6 +91,10 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
|
|||
self._sampling_session_id: str = sampling_session_id
|
||||
|
||||
self._request_id_counter: int = 0
|
||||
if shadow:
|
||||
# Start request_id_counter at a random high value to avoid collisions
|
||||
# with the original client or other unpickled copies
|
||||
self._request_id_counter = 1_000_000_000 * random.randint(1, 1_000_000)
|
||||
|
||||
@staticmethod
|
||||
async def _create_impl(
|
||||
|
|
@ -323,6 +337,21 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
|
|||
def get_telemetry(self) -> Telemetry | None:
|
||||
return self.holder.get_telemetry()
|
||||
|
||||
def __reduce__(self) -> tuple[Any, tuple[str, str, dict[str, Any]]]:
|
||||
"""Enable pickling of SamplingClient for subprocess use.
|
||||
|
||||
Stores the sampling_session_id and holder constructor kwargs.
|
||||
On unpickle, creates a shadow holder and reconstructs the client.
|
||||
"""
|
||||
return (
|
||||
_unpickle_sampling_client,
|
||||
(
|
||||
self.holder.get_session_id(),
|
||||
self._sampling_session_id,
|
||||
self.holder._constructor_kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
def on_queue_state_change(
|
||||
self, queue_state: QueueState, queue_state_reason: str | None
|
||||
) -> None:
|
||||
|
|
@ -345,6 +374,23 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
|
|||
)
|
||||
|
||||
|
||||
def _unpickle_sampling_client(
|
||||
session_id: str,
|
||||
sampling_session_id: str,
|
||||
constructor_kwargs: dict[str, Any],
|
||||
) -> SamplingClient:
|
||||
"""Reconstruct a SamplingClient from pickled data.
|
||||
|
||||
Creates a shadow InternalClientHolder and builds a new SamplingClient.
|
||||
The request_id_counter starts at a random high value to avoid collisions.
|
||||
"""
|
||||
from ..internal_client_holder import InternalClientHolder
|
||||
|
||||
holder = InternalClientHolder.get_shadow_holder(session_id, constructor_kwargs)
|
||||
client = SamplingClient(holder, sampling_session_id=sampling_session_id, shadow=True)
|
||||
return client
|
||||
|
||||
|
||||
@lru_cache(maxsize=100)
|
||||
def _get_retry_handler(
|
||||
name: str, retry_config: RetryConfig | None = None, telemetry: Telemetry | None = None
|
||||
|
|
|
|||
|
|
@ -644,6 +644,8 @@ class TrainingClient(TelemetryProvider):
|
|||
ttl_seconds=ttl_seconds,
|
||||
)
|
||||
else:
|
||||
# Training client can never be created from a shadow holder, so we can safely assert
|
||||
assert self.holder._sampling_client_counter is not None
|
||||
sampling_session_seq_id = self.holder._sampling_client_counter
|
||||
self.holder._sampling_client_counter += 1
|
||||
request = types.SaveWeightsForSamplerRequest(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue