From 3e4e4e35601e70e677cd4c9c959e617bba806450 Mon Sep 17 00:00:00 2001 From: Andrii Grynenko Date: Fri, 21 Nov 2025 01:42:38 +0000 Subject: [PATCH] Sync contents --- .sync_state | 4 +- pyproject.toml | 2 +- src/tinker/cli/commands/checkpoint.py | 7 +-- src/tinker/cli/commands/run.py | 6 +-- src/tinker/lib/internal_client_holder.py | 13 ++++-- .../lib/public_interfaces/sampling_client.py | 46 ++++++++++++------- .../lib/public_interfaces/service_client.py | 20 +++++++- .../lib/public_interfaces/training_client.py | 18 ++++++-- 8 files changed, 81 insertions(+), 35 deletions(-) diff --git a/.sync_state b/.sync_state index ea0f064..19b7c46 100644 --- a/.sync_state +++ b/.sync_state @@ -1,4 +1,4 @@ { - "last_synced_sha": "61084a6cd11cc08c063fce7ff415b18dda9890f0", - "last_sync_time": "2025-11-20T21:08:04.245668" + "last_synced_sha": "a4665d66b6c66527095652259bc8810cf843f425", + "last_sync_time": "2025-11-21T01:42:38.758898" } \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 72e311c..3abefe1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "tinker" -version = "0.5.0" +version = "0.5.1" description = "The official Python SDK for the tinker API" readme = "README.md" license = "Apache-2.0" diff --git a/src/tinker/cli/commands/checkpoint.py b/src/tinker/cli/commands/checkpoint.py index 8d37c7d..5e99c74 100644 --- a/src/tinker/cli/commands/checkpoint.py +++ b/src/tinker/cli/commands/checkpoint.py @@ -501,14 +501,15 @@ def download( unless --force is specified. Examples: + + # Creates ./run-123_weights_final/ with checkpoint files tinker checkpoint download tinker://run-123/weights/final - # Creates ./run-123_weights_final/ with checkpoint files + # Creates ./models/run-123_weights_final/ with checkpoint files tinker checkpoint download tinker://run-123/weights/final --output ./models/ - # Creates ./models/run-123_weights_final/ with checkpoint files + # Overwrites existing ./run-123_weights_final/ directory tinker checkpoint download tinker://run-123/weights/final --force - # Overwrites existing ./run-123_weights_final/ directory """ # Lazy imports to maintain fast CLI startup import urllib.error diff --git a/src/tinker/cli/commands/run.py b/src/tinker/cli/commands/run.py index 68badbb..ce5f2b0 100644 --- a/src/tinker/cli/commands/run.py +++ b/src/tinker/cli/commands/run.py @@ -68,7 +68,7 @@ class RunListOutput(OutputBase): def get_table_columns(self) -> List[str]: """Return column headers for table output.""" - return ["Run ID", "Base Model", "Owner", "LoRA", "Last Update", "Corrupted"] + return ["Run ID", "Base Model", "Owner", "LoRA", "Last Update", "Status"] def get_table_rows(self) -> List[List[str]]: """Return rows for table output.""" @@ -89,7 +89,7 @@ class RunListOutput(OutputBase): run.model_owner, lora_info, format_timestamp(run.last_request_time), - str(run.corrupted), + "Failed" if run.corrupted else "Active", ] ) @@ -142,7 +142,7 @@ class RunInfoOutput(OutputBase): rows.append(["Last Update", format_timestamp(self.run.last_request_time)]) # Corruption status - rows.append(["Status", "Corrupted" if self.run.corrupted else "Active"]) + rows.append(["Status", "Failed" if self.run.corrupted else "Active"]) # Last checkpoints if self.run.last_checkpoint: diff --git a/src/tinker/lib/internal_client_holder.py b/src/tinker/lib/internal_client_holder.py index 4946caa..f944d6a 100644 --- a/src/tinker/lib/internal_client_holder.py +++ b/src/tinker/lib/internal_client_holder.py @@ -3,12 +3,12 @@ from __future__ import annotations import asyncio -import contextlib import logging import os import threading import time import traceback +import contextlib from collections.abc import Coroutine, Generator from contextlib import AbstractContextManager, contextmanager from typing import Any, Awaitable, Callable, TypeVar @@ -107,11 +107,12 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider): self._sample_backoff_until: float | None = None self._sample_dispatch_semaphore: asyncio.Semaphore = asyncio.Semaphore(400) self._telemetry: Telemetry | None = None + self._session_heartbeat_task: asyncio.Task[None] | None = None session_id, session_heartbeat_task = self.run_coroutine_threadsafe( self._create_session(user_metadata) ).result() + self._session_heartbeat_task = session_heartbeat_task self._session_id: str = session_id - self._session_heartbeat_task: asyncio.Task[None] = session_heartbeat_task self._telemetry = init_telemetry(self, session_id=self._session_id) self._training_client_counter: int = 0 @@ -148,6 +149,8 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider): async def _create_sampling_session( self, model_path: str | None = None, base_model: str | None = None ) -> str: + if model_path and not model_path.startswith("tinker://"): + raise ValueError("model_path must start with 'tinker://'") sampling_session_seq_id = self._sampling_client_counter self._sampling_client_counter += 1 with self.aclient(ClientConnectionPoolType.SESSION) as client: @@ -223,7 +226,7 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider): return AwaitableConcurrentFuture(asyncio.run_coroutine_threadsafe(coro, self.get_loop())) def close(self): - self.run_coroutine_threadsafe(self._async_cleanup()).result() + self.run_coroutine_threadsafe(self._async_cleanup()) if telemetry := self._telemetry: telemetry.stop() @@ -233,8 +236,8 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider): async def _async_cleanup(self): if self._session_heartbeat_task: self._session_heartbeat_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await self._session_heartbeat_task + with contextlib.suppress(asyncio.CancelledError): + await self._session_heartbeat_task @staticmethod def _is_retryable_status_code(status_code: int) -> bool: diff --git a/src/tinker/lib/public_interfaces/sampling_client.py b/src/tinker/lib/public_interfaces/sampling_client.py index 1fc86e0..98ac78e 100644 --- a/src/tinker/lib/public_interfaces/sampling_client.py +++ b/src/tinker/lib/public_interfaces/sampling_client.py @@ -13,7 +13,7 @@ from typing import TYPE_CHECKING, TypeVar, cast import tinker from tinker import types from tinker.lib.client_connection_pool_type import ClientConnectionPoolType -from tinker.lib.public_interfaces.api_future import AwaitableConcurrentFuture +from tinker.lib.public_interfaces.api_future import APIFuture, AwaitableConcurrentFuture from tinker.lib.telemetry import Telemetry, capture_exceptions from tinker.lib.telemetry_provider import TelemetryProvider @@ -58,21 +58,14 @@ class SamplingClient(TelemetryProvider, QueueStateObserver): self, holder: InternalClientHolder, *, - model_path: str | None = None, - base_model: str | None = None, - sampling_session_id: str | None = None, + sampling_session_id: str, retry_config: RetryConfig | None = None, ): - if model_path and not model_path.startswith("tinker://"): - raise ValueError("model_path must start with 'tinker://'") - self.holder = holder - self.model_path = model_path - self.base_model = base_model # Create retry handler with the provided configuration self.retry_handler = _get_retry_handler( - model_path or base_model, retry_config=retry_config, telemetry=holder.get_telemetry() + sampling_session_id, retry_config=retry_config, telemetry=holder.get_telemetry() ) self.feature_gates = set( @@ -81,15 +74,34 @@ class SamplingClient(TelemetryProvider, QueueStateObserver): self._last_queue_state_logged: float = 0 - self._sampling_session_id: str = ( - sampling_session_id - or holder.run_coroutine_threadsafe( - holder._create_sampling_session(model_path=model_path, base_model=base_model) - ).result() - ) + self._sampling_session_id: str = sampling_session_id self._request_id_counter: int = 0 + @staticmethod + async def _create_impl( + holder: InternalClientHolder, + *, + model_path: str | None, + base_model: str | None, + sampling_session_id: str | None, + retry_config: RetryConfig | None, + ) -> SamplingClient: + if sampling_session_id is None: + sampling_session_id = await holder._create_sampling_session(model_path=model_path, base_model=base_model) + return SamplingClient(holder, sampling_session_id=sampling_session_id, retry_config=retry_config) + + @staticmethod + def create( + holder: InternalClientHolder, + *, + model_path: str | None = None, + base_model: str | None = None, + sampling_session_id: str | None = None, + retry_config: RetryConfig | None = None, + ) -> APIFuture[SamplingClient]: + return holder.run_coroutine_threadsafe(SamplingClient._create_impl(holder, model_path=model_path, base_model=base_model, sampling_session_id=sampling_session_id, retry_config=retry_config)) + async def _send_asample_request( self, num_samples: int, @@ -243,7 +255,7 @@ class SamplingClient(TelemetryProvider, QueueStateObserver): reason = "unknown" self._last_queue_state_logged = time.time() - logger.warning(f"Sampling is paused for {self.model_path}. Reason: {reason}") + logger.warning(f"Sampling is paused for sampler {self._sampling_session_id}. Reason: {reason}") @lru_cache(maxsize=100) diff --git a/src/tinker/lib/public_interfaces/service_client.py b/src/tinker/lib/public_interfaces/service_client.py index 75e6efd..0cf5eaa 100644 --- a/src/tinker/lib/public_interfaces/service_client.py +++ b/src/tinker/lib/public_interfaces/service_client.py @@ -226,7 +226,25 @@ class ServiceClient(TelemetryProvider): if model_path is None and base_model is None: raise ValueError("Either model_path or base_model must be provided") - return SamplingClient( + return SamplingClient.create( + self.holder, + model_path=model_path, + base_model=base_model, + retry_config=retry_config, + ).result() + + @capture_exceptions(fatal=True) + async def create_sampling_client_async( + self, + model_path: str | None = None, + base_model: str | None = None, + retry_config: RetryConfig | None = None, + ) -> SamplingClient: + from .sampling_client import SamplingClient + + if model_path is None and base_model is None: + raise ValueError("Either model_path or base_model must be provided") + return await SamplingClient.create( self.holder, model_path=model_path, base_model=base_model, diff --git a/src/tinker/lib/public_interfaces/training_client.py b/src/tinker/lib/public_interfaces/training_client.py index a35d2bc..a8e7128 100644 --- a/src/tinker/lib/public_interfaces/training_client.py +++ b/src/tinker/lib/public_interfaces/training_client.py @@ -29,6 +29,7 @@ from ..retry_handler import RetryConfig from ..sync_only import sync_only from .sampling_client import SamplingClient + if TYPE_CHECKING: from transformers.tokenization_utils import PreTrainedTokenizer @@ -522,9 +523,13 @@ class TrainingClient(TelemetryProvider, QueueStateObserver): def create_sampling_client( self, model_path: str, retry_config: RetryConfig | None = None ) -> SamplingClient: - from .sampling_client import SamplingClient + return SamplingClient.create(self.holder, model_path=model_path, retry_config=retry_config).result() - return SamplingClient(self.holder, model_path=model_path, retry_config=retry_config) + @capture_exceptions(fatal=True) + async def create_sampling_client_async( + self, model_path: str, retry_config: RetryConfig | None = None + ) -> SamplingClient: + return await SamplingClient.create(self.holder, model_path=model_path, retry_config=retry_config) def save_weights_and_get_sampling_client_submit( self, retry_config: RetryConfig | None = None @@ -535,7 +540,7 @@ class TrainingClient(TelemetryProvider, QueueStateObserver): result = await self._save_weights_for_sampler_impl(request_id, None) assert result.path is None assert result.sampling_session_id is not None - return SamplingClient( + return await SamplingClient.create( self.holder, sampling_session_id=result.sampling_session_id, retry_config=retry_config, @@ -582,6 +587,11 @@ class TrainingClient(TelemetryProvider, QueueStateObserver): def _get_tokenizer(model_id: types.ModelID, holder: InternalClientHolder) -> PreTrainedTokenizer: # call get_info on model_id from transformers.models.auto.tokenization_auto import AutoTokenizer + try: + from tml_tokenizers import get_tinker_tokenizer + except ImportError: + def get_tinker_tokenizer(model_id: str) -> PreTrainedTokenizer | None: + return None async def _get_info_async(): with holder.aclient(ClientConnectionPoolType.TRAIN) as client: @@ -614,5 +624,7 @@ def _get_tokenizer(model_id: types.ModelID, holder: InternalClientHolder) -> Pre "trust_remote_code": True, "revision": "612681931a8c906ddb349f8ad0f582cb552189cd", } + if (tokenizer := get_tinker_tokenizer(tokenizer_id)) is not None: + return tokenizer return AutoTokenizer.from_pretrained(tokenizer_id, fast=True, **kwargs)