Sync contents

This commit is contained in:
Andrii Grynenko 2025-11-21 01:42:38 +00:00
parent 3e368dc0df
commit 3e4e4e3560
8 changed files with 81 additions and 35 deletions

View file

@ -1,4 +1,4 @@
{ {
"last_synced_sha": "61084a6cd11cc08c063fce7ff415b18dda9890f0", "last_synced_sha": "a4665d66b6c66527095652259bc8810cf843f425",
"last_sync_time": "2025-11-20T21:08:04.245668" "last_sync_time": "2025-11-21T01:42:38.758898"
} }

View file

@ -1,6 +1,6 @@
[project] [project]
name = "tinker" name = "tinker"
version = "0.5.0" version = "0.5.1"
description = "The official Python SDK for the tinker API" description = "The official Python SDK for the tinker API"
readme = "README.md" readme = "README.md"
license = "Apache-2.0" license = "Apache-2.0"

View file

@ -501,14 +501,15 @@ def download(
unless --force is specified. unless --force is specified.
Examples: Examples:
# Creates ./run-123_weights_final/ with checkpoint files
tinker checkpoint download tinker://run-123/weights/final tinker checkpoint download tinker://run-123/weights/final
# Creates ./run-123_weights_final/ with checkpoint files
# Creates ./models/run-123_weights_final/ with checkpoint files
tinker checkpoint download tinker://run-123/weights/final --output ./models/ tinker checkpoint download tinker://run-123/weights/final --output ./models/
# Creates ./models/run-123_weights_final/ with checkpoint files
# Overwrites existing ./run-123_weights_final/ directory
tinker checkpoint download tinker://run-123/weights/final --force tinker checkpoint download tinker://run-123/weights/final --force
# Overwrites existing ./run-123_weights_final/ directory
""" """
# Lazy imports to maintain fast CLI startup # Lazy imports to maintain fast CLI startup
import urllib.error import urllib.error

View file

@ -68,7 +68,7 @@ class RunListOutput(OutputBase):
def get_table_columns(self) -> List[str]: def get_table_columns(self) -> List[str]:
"""Return column headers for table output.""" """Return column headers for table output."""
return ["Run ID", "Base Model", "Owner", "LoRA", "Last Update", "Corrupted"] return ["Run ID", "Base Model", "Owner", "LoRA", "Last Update", "Status"]
def get_table_rows(self) -> List[List[str]]: def get_table_rows(self) -> List[List[str]]:
"""Return rows for table output.""" """Return rows for table output."""
@ -89,7 +89,7 @@ class RunListOutput(OutputBase):
run.model_owner, run.model_owner,
lora_info, lora_info,
format_timestamp(run.last_request_time), format_timestamp(run.last_request_time),
str(run.corrupted), "Failed" if run.corrupted else "Active",
] ]
) )
@ -142,7 +142,7 @@ class RunInfoOutput(OutputBase):
rows.append(["Last Update", format_timestamp(self.run.last_request_time)]) rows.append(["Last Update", format_timestamp(self.run.last_request_time)])
# Corruption status # Corruption status
rows.append(["Status", "Corrupted" if self.run.corrupted else "Active"]) rows.append(["Status", "Failed" if self.run.corrupted else "Active"])
# Last checkpoints # Last checkpoints
if self.run.last_checkpoint: if self.run.last_checkpoint:

View file

@ -3,12 +3,12 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import contextlib
import logging import logging
import os import os
import threading import threading
import time import time
import traceback import traceback
import contextlib
from collections.abc import Coroutine, Generator from collections.abc import Coroutine, Generator
from contextlib import AbstractContextManager, contextmanager from contextlib import AbstractContextManager, contextmanager
from typing import Any, Awaitable, Callable, TypeVar from typing import Any, Awaitable, Callable, TypeVar
@ -107,11 +107,12 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
self._sample_backoff_until: float | None = None self._sample_backoff_until: float | None = None
self._sample_dispatch_semaphore: asyncio.Semaphore = asyncio.Semaphore(400) self._sample_dispatch_semaphore: asyncio.Semaphore = asyncio.Semaphore(400)
self._telemetry: Telemetry | None = None self._telemetry: Telemetry | None = None
self._session_heartbeat_task: asyncio.Task[None] | None = None
session_id, session_heartbeat_task = self.run_coroutine_threadsafe( session_id, session_heartbeat_task = self.run_coroutine_threadsafe(
self._create_session(user_metadata) self._create_session(user_metadata)
).result() ).result()
self._session_heartbeat_task = session_heartbeat_task
self._session_id: str = session_id self._session_id: str = session_id
self._session_heartbeat_task: asyncio.Task[None] = session_heartbeat_task
self._telemetry = init_telemetry(self, session_id=self._session_id) self._telemetry = init_telemetry(self, session_id=self._session_id)
self._training_client_counter: int = 0 self._training_client_counter: int = 0
@ -148,6 +149,8 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
async def _create_sampling_session( async def _create_sampling_session(
self, model_path: str | None = None, base_model: str | None = None self, model_path: str | None = None, base_model: str | None = None
) -> str: ) -> str:
if model_path and not model_path.startswith("tinker://"):
raise ValueError("model_path must start with 'tinker://'")
sampling_session_seq_id = self._sampling_client_counter sampling_session_seq_id = self._sampling_client_counter
self._sampling_client_counter += 1 self._sampling_client_counter += 1
with self.aclient(ClientConnectionPoolType.SESSION) as client: with self.aclient(ClientConnectionPoolType.SESSION) as client:
@ -223,7 +226,7 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
return AwaitableConcurrentFuture(asyncio.run_coroutine_threadsafe(coro, self.get_loop())) return AwaitableConcurrentFuture(asyncio.run_coroutine_threadsafe(coro, self.get_loop()))
def close(self): def close(self):
self.run_coroutine_threadsafe(self._async_cleanup()).result() self.run_coroutine_threadsafe(self._async_cleanup())
if telemetry := self._telemetry: if telemetry := self._telemetry:
telemetry.stop() telemetry.stop()
@ -233,8 +236,8 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
async def _async_cleanup(self): async def _async_cleanup(self):
if self._session_heartbeat_task: if self._session_heartbeat_task:
self._session_heartbeat_task.cancel() self._session_heartbeat_task.cancel()
with contextlib.suppress(asyncio.CancelledError): with contextlib.suppress(asyncio.CancelledError):
await self._session_heartbeat_task await self._session_heartbeat_task
@staticmethod @staticmethod
def _is_retryable_status_code(status_code: int) -> bool: def _is_retryable_status_code(status_code: int) -> bool:

View file

@ -13,7 +13,7 @@ from typing import TYPE_CHECKING, TypeVar, cast
import tinker import tinker
from tinker import types from tinker import types
from tinker.lib.client_connection_pool_type import ClientConnectionPoolType from tinker.lib.client_connection_pool_type import ClientConnectionPoolType
from tinker.lib.public_interfaces.api_future import AwaitableConcurrentFuture from tinker.lib.public_interfaces.api_future import APIFuture, AwaitableConcurrentFuture
from tinker.lib.telemetry import Telemetry, capture_exceptions from tinker.lib.telemetry import Telemetry, capture_exceptions
from tinker.lib.telemetry_provider import TelemetryProvider from tinker.lib.telemetry_provider import TelemetryProvider
@ -58,21 +58,14 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
self, self,
holder: InternalClientHolder, holder: InternalClientHolder,
*, *,
model_path: str | None = None, sampling_session_id: str,
base_model: str | None = None,
sampling_session_id: str | None = None,
retry_config: RetryConfig | None = None, retry_config: RetryConfig | None = None,
): ):
if model_path and not model_path.startswith("tinker://"):
raise ValueError("model_path must start with 'tinker://'")
self.holder = holder self.holder = holder
self.model_path = model_path
self.base_model = base_model
# Create retry handler with the provided configuration # Create retry handler with the provided configuration
self.retry_handler = _get_retry_handler( self.retry_handler = _get_retry_handler(
model_path or base_model, retry_config=retry_config, telemetry=holder.get_telemetry() sampling_session_id, retry_config=retry_config, telemetry=holder.get_telemetry()
) )
self.feature_gates = set( self.feature_gates = set(
@ -81,15 +74,34 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
self._last_queue_state_logged: float = 0 self._last_queue_state_logged: float = 0
self._sampling_session_id: str = ( self._sampling_session_id: str = sampling_session_id
sampling_session_id
or holder.run_coroutine_threadsafe(
holder._create_sampling_session(model_path=model_path, base_model=base_model)
).result()
)
self._request_id_counter: int = 0 self._request_id_counter: int = 0
@staticmethod
async def _create_impl(
holder: InternalClientHolder,
*,
model_path: str | None,
base_model: str | None,
sampling_session_id: str | None,
retry_config: RetryConfig | None,
) -> SamplingClient:
if sampling_session_id is None:
sampling_session_id = await holder._create_sampling_session(model_path=model_path, base_model=base_model)
return SamplingClient(holder, sampling_session_id=sampling_session_id, retry_config=retry_config)
@staticmethod
def create(
holder: InternalClientHolder,
*,
model_path: str | None = None,
base_model: str | None = None,
sampling_session_id: str | None = None,
retry_config: RetryConfig | None = None,
) -> APIFuture[SamplingClient]:
return holder.run_coroutine_threadsafe(SamplingClient._create_impl(holder, model_path=model_path, base_model=base_model, sampling_session_id=sampling_session_id, retry_config=retry_config))
async def _send_asample_request( async def _send_asample_request(
self, self,
num_samples: int, num_samples: int,
@ -243,7 +255,7 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
reason = "unknown" reason = "unknown"
self._last_queue_state_logged = time.time() self._last_queue_state_logged = time.time()
logger.warning(f"Sampling is paused for {self.model_path}. Reason: {reason}") logger.warning(f"Sampling is paused for sampler {self._sampling_session_id}. Reason: {reason}")
@lru_cache(maxsize=100) @lru_cache(maxsize=100)

View file

@ -226,7 +226,25 @@ class ServiceClient(TelemetryProvider):
if model_path is None and base_model is None: if model_path is None and base_model is None:
raise ValueError("Either model_path or base_model must be provided") raise ValueError("Either model_path or base_model must be provided")
return SamplingClient( return SamplingClient.create(
self.holder,
model_path=model_path,
base_model=base_model,
retry_config=retry_config,
).result()
@capture_exceptions(fatal=True)
async def create_sampling_client_async(
self,
model_path: str | None = None,
base_model: str | None = None,
retry_config: RetryConfig | None = None,
) -> SamplingClient:
from .sampling_client import SamplingClient
if model_path is None and base_model is None:
raise ValueError("Either model_path or base_model must be provided")
return await SamplingClient.create(
self.holder, self.holder,
model_path=model_path, model_path=model_path,
base_model=base_model, base_model=base_model,

View file

@ -29,6 +29,7 @@ from ..retry_handler import RetryConfig
from ..sync_only import sync_only from ..sync_only import sync_only
from .sampling_client import SamplingClient from .sampling_client import SamplingClient
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers.tokenization_utils import PreTrainedTokenizer from transformers.tokenization_utils import PreTrainedTokenizer
@ -522,9 +523,13 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
def create_sampling_client( def create_sampling_client(
self, model_path: str, retry_config: RetryConfig | None = None self, model_path: str, retry_config: RetryConfig | None = None
) -> SamplingClient: ) -> SamplingClient:
from .sampling_client import SamplingClient return SamplingClient.create(self.holder, model_path=model_path, retry_config=retry_config).result()
return SamplingClient(self.holder, model_path=model_path, retry_config=retry_config) @capture_exceptions(fatal=True)
async def create_sampling_client_async(
self, model_path: str, retry_config: RetryConfig | None = None
) -> SamplingClient:
return await SamplingClient.create(self.holder, model_path=model_path, retry_config=retry_config)
def save_weights_and_get_sampling_client_submit( def save_weights_and_get_sampling_client_submit(
self, retry_config: RetryConfig | None = None self, retry_config: RetryConfig | None = None
@ -535,7 +540,7 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
result = await self._save_weights_for_sampler_impl(request_id, None) result = await self._save_weights_for_sampler_impl(request_id, None)
assert result.path is None assert result.path is None
assert result.sampling_session_id is not None assert result.sampling_session_id is not None
return SamplingClient( return await SamplingClient.create(
self.holder, self.holder,
sampling_session_id=result.sampling_session_id, sampling_session_id=result.sampling_session_id,
retry_config=retry_config, retry_config=retry_config,
@ -582,6 +587,11 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
def _get_tokenizer(model_id: types.ModelID, holder: InternalClientHolder) -> PreTrainedTokenizer: def _get_tokenizer(model_id: types.ModelID, holder: InternalClientHolder) -> PreTrainedTokenizer:
# call get_info on model_id # call get_info on model_id
from transformers.models.auto.tokenization_auto import AutoTokenizer from transformers.models.auto.tokenization_auto import AutoTokenizer
try:
from tml_tokenizers import get_tinker_tokenizer
except ImportError:
def get_tinker_tokenizer(model_id: str) -> PreTrainedTokenizer | None:
return None
async def _get_info_async(): async def _get_info_async():
with holder.aclient(ClientConnectionPoolType.TRAIN) as client: with holder.aclient(ClientConnectionPoolType.TRAIN) as client:
@ -614,5 +624,7 @@ def _get_tokenizer(model_id: types.ModelID, holder: InternalClientHolder) -> Pre
"trust_remote_code": True, "trust_remote_code": True,
"revision": "612681931a8c906ddb349f8ad0f582cb552189cd", "revision": "612681931a8c906ddb349f8ad0f582cb552189cd",
} }
if (tokenizer := get_tinker_tokenizer(tokenizer_id)) is not None:
return tokenizer
return AutoTokenizer.from_pretrained(tokenizer_id, fast=True, **kwargs) return AutoTokenizer.from_pretrained(tokenizer_id, fast=True, **kwargs)