Sync contents

This commit is contained in:
Andrii Grynenko 2025-11-21 01:42:38 +00:00
parent 3e368dc0df
commit 3e4e4e3560
8 changed files with 81 additions and 35 deletions

View file

@@ -1,4 +1,4 @@
{
"last_synced_sha": "61084a6cd11cc08c063fce7ff415b18dda9890f0",
"last_sync_time": "2025-11-20T21:08:04.245668"
"last_synced_sha": "a4665d66b6c66527095652259bc8810cf843f425",
"last_sync_time": "2025-11-21T01:42:38.758898"
}

View file

@@ -1,6 +1,6 @@
[project]
name = "tinker"
version = "0.5.0"
version = "0.5.1"
description = "The official Python SDK for the tinker API"
readme = "README.md"
license = "Apache-2.0"

View file

@@ -501,14 +501,15 @@ def download(
unless --force is specified.
Examples:
tinker checkpoint download tinker://run-123/weights/final
# Creates ./run-123_weights_final/ with checkpoint files
tinker checkpoint download tinker://run-123/weights/final
tinker checkpoint download tinker://run-123/weights/final --output ./models/
# Creates ./models/run-123_weights_final/ with checkpoint files
tinker checkpoint download tinker://run-123/weights/final --output ./models/
tinker checkpoint download tinker://run-123/weights/final --force
# Overwrites existing ./run-123_weights_final/ directory
tinker checkpoint download tinker://run-123/weights/final --force
"""
# Lazy imports to maintain fast CLI startup
import urllib.error

View file

@@ -68,7 +68,7 @@ class RunListOutput(OutputBase):
def get_table_columns(self) -> List[str]:
"""Return column headers for table output."""
return ["Run ID", "Base Model", "Owner", "LoRA", "Last Update", "Corrupted"]
return ["Run ID", "Base Model", "Owner", "LoRA", "Last Update", "Status"]
def get_table_rows(self) -> List[List[str]]:
"""Return rows for table output."""
@@ -89,7 +89,7 @@ class RunListOutput(OutputBase):
run.model_owner,
lora_info,
format_timestamp(run.last_request_time),
str(run.corrupted),
"Failed" if run.corrupted else "Active",
]
)
@@ -142,7 +142,7 @@ class RunInfoOutput(OutputBase):
rows.append(["Last Update", format_timestamp(self.run.last_request_time)])
# Corruption status
rows.append(["Status", "Corrupted" if self.run.corrupted else "Active"])
rows.append(["Status", "Failed" if self.run.corrupted else "Active"])
# Last checkpoints
if self.run.last_checkpoint:

View file

@@ -3,12 +3,12 @@
from __future__ import annotations
import asyncio
import contextlib
import logging
import os
import threading
import time
import traceback
import contextlib
from collections.abc import Coroutine, Generator
from contextlib import AbstractContextManager, contextmanager
from typing import Any, Awaitable, Callable, TypeVar
@@ -107,11 +107,12 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
self._sample_backoff_until: float | None = None
self._sample_dispatch_semaphore: asyncio.Semaphore = asyncio.Semaphore(400)
self._telemetry: Telemetry | None = None
self._session_heartbeat_task: asyncio.Task[None] | None = None
session_id, session_heartbeat_task = self.run_coroutine_threadsafe(
self._create_session(user_metadata)
).result()
self._session_heartbeat_task = session_heartbeat_task
self._session_id: str = session_id
self._session_heartbeat_task: asyncio.Task[None] = session_heartbeat_task
self._telemetry = init_telemetry(self, session_id=self._session_id)
self._training_client_counter: int = 0
@@ -148,6 +149,8 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
async def _create_sampling_session(
self, model_path: str | None = None, base_model: str | None = None
) -> str:
if model_path and not model_path.startswith("tinker://"):
raise ValueError("model_path must start with 'tinker://'")
sampling_session_seq_id = self._sampling_client_counter
self._sampling_client_counter += 1
with self.aclient(ClientConnectionPoolType.SESSION) as client:
@@ -223,7 +226,7 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
return AwaitableConcurrentFuture(asyncio.run_coroutine_threadsafe(coro, self.get_loop()))
def close(self):
self.run_coroutine_threadsafe(self._async_cleanup()).result()
self.run_coroutine_threadsafe(self._async_cleanup())
if telemetry := self._telemetry:
telemetry.stop()

View file

@@ -13,7 +13,7 @@ from typing import TYPE_CHECKING, TypeVar, cast
import tinker
from tinker import types
from tinker.lib.client_connection_pool_type import ClientConnectionPoolType
from tinker.lib.public_interfaces.api_future import AwaitableConcurrentFuture
from tinker.lib.public_interfaces.api_future import APIFuture, AwaitableConcurrentFuture
from tinker.lib.telemetry import Telemetry, capture_exceptions
from tinker.lib.telemetry_provider import TelemetryProvider
@@ -58,21 +58,14 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
self,
holder: InternalClientHolder,
*,
model_path: str | None = None,
base_model: str | None = None,
sampling_session_id: str | None = None,
sampling_session_id: str,
retry_config: RetryConfig | None = None,
):
if model_path and not model_path.startswith("tinker://"):
raise ValueError("model_path must start with 'tinker://'")
self.holder = holder
self.model_path = model_path
self.base_model = base_model
# Create retry handler with the provided configuration
self.retry_handler = _get_retry_handler(
model_path or base_model, retry_config=retry_config, telemetry=holder.get_telemetry()
sampling_session_id, retry_config=retry_config, telemetry=holder.get_telemetry()
)
self.feature_gates = set(
@@ -81,15 +74,34 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
self._last_queue_state_logged: float = 0
self._sampling_session_id: str = (
sampling_session_id
or holder.run_coroutine_threadsafe(
holder._create_sampling_session(model_path=model_path, base_model=base_model)
).result()
)
self._sampling_session_id: str = sampling_session_id
self._request_id_counter: int = 0
@staticmethod
async def _create_impl(
holder: InternalClientHolder,
*,
model_path: str | None,
base_model: str | None,
sampling_session_id: str | None,
retry_config: RetryConfig | None,
) -> SamplingClient:
if sampling_session_id is None:
sampling_session_id = await holder._create_sampling_session(model_path=model_path, base_model=base_model)
return SamplingClient(holder, sampling_session_id=sampling_session_id, retry_config=retry_config)
@staticmethod
def create(
holder: InternalClientHolder,
*,
model_path: str | None = None,
base_model: str | None = None,
sampling_session_id: str | None = None,
retry_config: RetryConfig | None = None,
) -> APIFuture[SamplingClient]:
return holder.run_coroutine_threadsafe(SamplingClient._create_impl(holder, model_path=model_path, base_model=base_model, sampling_session_id=sampling_session_id, retry_config=retry_config))
async def _send_asample_request(
self,
num_samples: int,
@@ -243,7 +255,7 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
reason = "unknown"
self._last_queue_state_logged = time.time()
logger.warning(f"Sampling is paused for {self.model_path}. Reason: {reason}")
logger.warning(f"Sampling is paused for sampler {self._sampling_session_id}. Reason: {reason}")
@lru_cache(maxsize=100)

View file

@@ -226,7 +226,25 @@ class ServiceClient(TelemetryProvider):
if model_path is None and base_model is None:
raise ValueError("Either model_path or base_model must be provided")
return SamplingClient(
return SamplingClient.create(
self.holder,
model_path=model_path,
base_model=base_model,
retry_config=retry_config,
).result()
@capture_exceptions(fatal=True)
async def create_sampling_client_async(
self,
model_path: str | None = None,
base_model: str | None = None,
retry_config: RetryConfig | None = None,
) -> SamplingClient:
from .sampling_client import SamplingClient
if model_path is None and base_model is None:
raise ValueError("Either model_path or base_model must be provided")
return await SamplingClient.create(
self.holder,
model_path=model_path,
base_model=base_model,

View file

@@ -29,6 +29,7 @@ from ..retry_handler import RetryConfig
from ..sync_only import sync_only
from .sampling_client import SamplingClient
if TYPE_CHECKING:
from transformers.tokenization_utils import PreTrainedTokenizer
@@ -522,9 +523,13 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
def create_sampling_client(
self, model_path: str, retry_config: RetryConfig | None = None
) -> SamplingClient:
from .sampling_client import SamplingClient
return SamplingClient.create(self.holder, model_path=model_path, retry_config=retry_config).result()
return SamplingClient(self.holder, model_path=model_path, retry_config=retry_config)
@capture_exceptions(fatal=True)
async def create_sampling_client_async(
self, model_path: str, retry_config: RetryConfig | None = None
) -> SamplingClient:
return await SamplingClient.create(self.holder, model_path=model_path, retry_config=retry_config)
def save_weights_and_get_sampling_client_submit(
self, retry_config: RetryConfig | None = None
@@ -535,7 +540,7 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
result = await self._save_weights_for_sampler_impl(request_id, None)
assert result.path is None
assert result.sampling_session_id is not None
return SamplingClient(
return await SamplingClient.create(
self.holder,
sampling_session_id=result.sampling_session_id,
retry_config=retry_config,
@@ -582,6 +587,11 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
def _get_tokenizer(model_id: types.ModelID, holder: InternalClientHolder) -> PreTrainedTokenizer:
# call get_info on model_id
from transformers.models.auto.tokenization_auto import AutoTokenizer
try:
from tml_tokenizers import get_tinker_tokenizer
except ImportError:
def get_tinker_tokenizer(model_id: str) -> PreTrainedTokenizer | None:
return None
async def _get_info_async():
with holder.aclient(ClientConnectionPoolType.TRAIN) as client:
@@ -614,5 +624,7 @@ def _get_tokenizer(model_id: types.ModelID, holder: InternalClientHolder) -> Pre
"trust_remote_code": True,
"revision": "612681931a8c906ddb349f8ad0f582cb552189cd",
}
if (tokenizer := get_tinker_tokenizer(tokenizer_id)) is not None:
return tokenizer
return AutoTokenizer.from_pretrained(tokenizer_id, fast=True, **kwargs)