mirror of
https://github.com/thinking-machines-lab/tinker.git
synced 2026-04-19 12:58:01 +00:00
Sync contents
This commit is contained in:
parent
3e368dc0df
commit
3e4e4e3560
8 changed files with 81 additions and 35 deletions
|
|
@ -1,4 +1,4 @@
|
|||
{
|
||||
"last_synced_sha": "61084a6cd11cc08c063fce7ff415b18dda9890f0",
|
||||
"last_sync_time": "2025-11-20T21:08:04.245668"
|
||||
"last_synced_sha": "a4665d66b6c66527095652259bc8810cf843f425",
|
||||
"last_sync_time": "2025-11-21T01:42:38.758898"
|
||||
}
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
[project]
|
||||
name = "tinker"
|
||||
version = "0.5.0"
|
||||
version = "0.5.1"
|
||||
description = "The official Python SDK for the tinker API"
|
||||
readme = "README.md"
|
||||
license = "Apache-2.0"
|
||||
|
|
|
|||
|
|
@ -501,14 +501,15 @@ def download(
|
|||
unless --force is specified.
|
||||
|
||||
Examples:
|
||||
tinker checkpoint download tinker://run-123/weights/final
|
||||
|
||||
# Creates ./run-123_weights_final/ with checkpoint files
|
||||
tinker checkpoint download tinker://run-123/weights/final
|
||||
|
||||
tinker checkpoint download tinker://run-123/weights/final --output ./models/
|
||||
# Creates ./models/run-123_weights_final/ with checkpoint files
|
||||
tinker checkpoint download tinker://run-123/weights/final --output ./models/
|
||||
|
||||
tinker checkpoint download tinker://run-123/weights/final --force
|
||||
# Overwrites existing ./run-123_weights_final/ directory
|
||||
tinker checkpoint download tinker://run-123/weights/final --force
|
||||
"""
|
||||
# Lazy imports to maintain fast CLI startup
|
||||
import urllib.error
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ class RunListOutput(OutputBase):
|
|||
|
||||
def get_table_columns(self) -> List[str]:
|
||||
"""Return column headers for table output."""
|
||||
return ["Run ID", "Base Model", "Owner", "LoRA", "Last Update", "Corrupted"]
|
||||
return ["Run ID", "Base Model", "Owner", "LoRA", "Last Update", "Status"]
|
||||
|
||||
def get_table_rows(self) -> List[List[str]]:
|
||||
"""Return rows for table output."""
|
||||
|
|
@ -89,7 +89,7 @@ class RunListOutput(OutputBase):
|
|||
run.model_owner,
|
||||
lora_info,
|
||||
format_timestamp(run.last_request_time),
|
||||
str(run.corrupted),
|
||||
"Failed" if run.corrupted else "Active",
|
||||
]
|
||||
)
|
||||
|
||||
|
|
@ -142,7 +142,7 @@ class RunInfoOutput(OutputBase):
|
|||
rows.append(["Last Update", format_timestamp(self.run.last_request_time)])
|
||||
|
||||
# Corruption status
|
||||
rows.append(["Status", "Corrupted" if self.run.corrupted else "Active"])
|
||||
rows.append(["Status", "Failed" if self.run.corrupted else "Active"])
|
||||
|
||||
# Last checkpoints
|
||||
if self.run.last_checkpoint:
|
||||
|
|
|
|||
|
|
@ -3,12 +3,12 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
import contextlib
|
||||
from collections.abc import Coroutine, Generator
|
||||
from contextlib import AbstractContextManager, contextmanager
|
||||
from typing import Any, Awaitable, Callable, TypeVar
|
||||
|
|
@ -107,11 +107,12 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
|
|||
self._sample_backoff_until: float | None = None
|
||||
self._sample_dispatch_semaphore: asyncio.Semaphore = asyncio.Semaphore(400)
|
||||
self._telemetry: Telemetry | None = None
|
||||
self._session_heartbeat_task: asyncio.Task[None] | None = None
|
||||
session_id, session_heartbeat_task = self.run_coroutine_threadsafe(
|
||||
self._create_session(user_metadata)
|
||||
).result()
|
||||
self._session_heartbeat_task = session_heartbeat_task
|
||||
self._session_id: str = session_id
|
||||
self._session_heartbeat_task: asyncio.Task[None] = session_heartbeat_task
|
||||
self._telemetry = init_telemetry(self, session_id=self._session_id)
|
||||
|
||||
self._training_client_counter: int = 0
|
||||
|
|
@ -148,6 +149,8 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
|
|||
async def _create_sampling_session(
|
||||
self, model_path: str | None = None, base_model: str | None = None
|
||||
) -> str:
|
||||
if model_path and not model_path.startswith("tinker://"):
|
||||
raise ValueError("model_path must start with 'tinker://'")
|
||||
sampling_session_seq_id = self._sampling_client_counter
|
||||
self._sampling_client_counter += 1
|
||||
with self.aclient(ClientConnectionPoolType.SESSION) as client:
|
||||
|
|
@ -223,7 +226,7 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
|
|||
return AwaitableConcurrentFuture(asyncio.run_coroutine_threadsafe(coro, self.get_loop()))
|
||||
|
||||
def close(self):
|
||||
self.run_coroutine_threadsafe(self._async_cleanup()).result()
|
||||
self.run_coroutine_threadsafe(self._async_cleanup())
|
||||
if telemetry := self._telemetry:
|
||||
telemetry.stop()
|
||||
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from typing import TYPE_CHECKING, TypeVar, cast
|
|||
import tinker
|
||||
from tinker import types
|
||||
from tinker.lib.client_connection_pool_type import ClientConnectionPoolType
|
||||
from tinker.lib.public_interfaces.api_future import AwaitableConcurrentFuture
|
||||
from tinker.lib.public_interfaces.api_future import APIFuture, AwaitableConcurrentFuture
|
||||
from tinker.lib.telemetry import Telemetry, capture_exceptions
|
||||
from tinker.lib.telemetry_provider import TelemetryProvider
|
||||
|
||||
|
|
@ -58,21 +58,14 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
|
|||
self,
|
||||
holder: InternalClientHolder,
|
||||
*,
|
||||
model_path: str | None = None,
|
||||
base_model: str | None = None,
|
||||
sampling_session_id: str | None = None,
|
||||
sampling_session_id: str,
|
||||
retry_config: RetryConfig | None = None,
|
||||
):
|
||||
if model_path and not model_path.startswith("tinker://"):
|
||||
raise ValueError("model_path must start with 'tinker://'")
|
||||
|
||||
self.holder = holder
|
||||
self.model_path = model_path
|
||||
self.base_model = base_model
|
||||
|
||||
# Create retry handler with the provided configuration
|
||||
self.retry_handler = _get_retry_handler(
|
||||
model_path or base_model, retry_config=retry_config, telemetry=holder.get_telemetry()
|
||||
sampling_session_id, retry_config=retry_config, telemetry=holder.get_telemetry()
|
||||
)
|
||||
|
||||
self.feature_gates = set(
|
||||
|
|
@ -81,15 +74,34 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
|
|||
|
||||
self._last_queue_state_logged: float = 0
|
||||
|
||||
self._sampling_session_id: str = (
|
||||
sampling_session_id
|
||||
or holder.run_coroutine_threadsafe(
|
||||
holder._create_sampling_session(model_path=model_path, base_model=base_model)
|
||||
).result()
|
||||
)
|
||||
self._sampling_session_id: str = sampling_session_id
|
||||
|
||||
self._request_id_counter: int = 0
|
||||
|
||||
@staticmethod
|
||||
async def _create_impl(
|
||||
holder: InternalClientHolder,
|
||||
*,
|
||||
model_path: str | None,
|
||||
base_model: str | None,
|
||||
sampling_session_id: str | None,
|
||||
retry_config: RetryConfig | None,
|
||||
) -> SamplingClient:
|
||||
if sampling_session_id is None:
|
||||
sampling_session_id = await holder._create_sampling_session(model_path=model_path, base_model=base_model)
|
||||
return SamplingClient(holder, sampling_session_id=sampling_session_id, retry_config=retry_config)
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
holder: InternalClientHolder,
|
||||
*,
|
||||
model_path: str | None = None,
|
||||
base_model: str | None = None,
|
||||
sampling_session_id: str | None = None,
|
||||
retry_config: RetryConfig | None = None,
|
||||
) -> APIFuture[SamplingClient]:
|
||||
return holder.run_coroutine_threadsafe(SamplingClient._create_impl(holder, model_path=model_path, base_model=base_model, sampling_session_id=sampling_session_id, retry_config=retry_config))
|
||||
|
||||
async def _send_asample_request(
|
||||
self,
|
||||
num_samples: int,
|
||||
|
|
@ -243,7 +255,7 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
|
|||
reason = "unknown"
|
||||
self._last_queue_state_logged = time.time()
|
||||
|
||||
logger.warning(f"Sampling is paused for {self.model_path}. Reason: {reason}")
|
||||
logger.warning(f"Sampling is paused for sampler {self._sampling_session_id}. Reason: {reason}")
|
||||
|
||||
|
||||
@lru_cache(maxsize=100)
|
||||
|
|
|
|||
|
|
@ -226,7 +226,25 @@ class ServiceClient(TelemetryProvider):
|
|||
|
||||
if model_path is None and base_model is None:
|
||||
raise ValueError("Either model_path or base_model must be provided")
|
||||
return SamplingClient(
|
||||
return SamplingClient.create(
|
||||
self.holder,
|
||||
model_path=model_path,
|
||||
base_model=base_model,
|
||||
retry_config=retry_config,
|
||||
).result()
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
async def create_sampling_client_async(
|
||||
self,
|
||||
model_path: str | None = None,
|
||||
base_model: str | None = None,
|
||||
retry_config: RetryConfig | None = None,
|
||||
) -> SamplingClient:
|
||||
from .sampling_client import SamplingClient
|
||||
|
||||
if model_path is None and base_model is None:
|
||||
raise ValueError("Either model_path or base_model must be provided")
|
||||
return await SamplingClient.create(
|
||||
self.holder,
|
||||
model_path=model_path,
|
||||
base_model=base_model,
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ from ..retry_handler import RetryConfig
|
|||
from ..sync_only import sync_only
|
||||
from .sampling_client import SamplingClient
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
|
|
@ -522,9 +523,13 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
|
|||
def create_sampling_client(
|
||||
self, model_path: str, retry_config: RetryConfig | None = None
|
||||
) -> SamplingClient:
|
||||
from .sampling_client import SamplingClient
|
||||
return SamplingClient.create(self.holder, model_path=model_path, retry_config=retry_config).result()
|
||||
|
||||
return SamplingClient(self.holder, model_path=model_path, retry_config=retry_config)
|
||||
@capture_exceptions(fatal=True)
|
||||
async def create_sampling_client_async(
|
||||
self, model_path: str, retry_config: RetryConfig | None = None
|
||||
) -> SamplingClient:
|
||||
return await SamplingClient.create(self.holder, model_path=model_path, retry_config=retry_config)
|
||||
|
||||
def save_weights_and_get_sampling_client_submit(
|
||||
self, retry_config: RetryConfig | None = None
|
||||
|
|
@ -535,7 +540,7 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
|
|||
result = await self._save_weights_for_sampler_impl(request_id, None)
|
||||
assert result.path is None
|
||||
assert result.sampling_session_id is not None
|
||||
return SamplingClient(
|
||||
return await SamplingClient.create(
|
||||
self.holder,
|
||||
sampling_session_id=result.sampling_session_id,
|
||||
retry_config=retry_config,
|
||||
|
|
@ -582,6 +587,11 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
|
|||
def _get_tokenizer(model_id: types.ModelID, holder: InternalClientHolder) -> PreTrainedTokenizer:
|
||||
# call get_info on model_id
|
||||
from transformers.models.auto.tokenization_auto import AutoTokenizer
|
||||
try:
|
||||
from tml_tokenizers import get_tinker_tokenizer
|
||||
except ImportError:
|
||||
def get_tinker_tokenizer(model_id: str) -> PreTrainedTokenizer | None:
|
||||
return None
|
||||
|
||||
async def _get_info_async():
|
||||
with holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
||||
|
|
@ -614,5 +624,7 @@ def _get_tokenizer(model_id: types.ModelID, holder: InternalClientHolder) -> Pre
|
|||
"trust_remote_code": True,
|
||||
"revision": "612681931a8c906ddb349f8ad0f582cb552189cd",
|
||||
}
|
||||
if (tokenizer := get_tinker_tokenizer(tokenizer_id)) is not None:
|
||||
return tokenizer
|
||||
|
||||
return AutoTokenizer.from_pretrained(tokenizer_id, fast=True, **kwargs)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue