mirror of
https://github.com/thinking-machines-lab/tinker.git
synced 2026-04-19 12:58:01 +00:00
Sync contents
This commit is contained in:
parent
3e368dc0df
commit
3e4e4e3560
8 changed files with 81 additions and 35 deletions
|
|
@ -1,4 +1,4 @@
|
||||||
{
|
{
|
||||||
"last_synced_sha": "61084a6cd11cc08c063fce7ff415b18dda9890f0",
|
"last_synced_sha": "a4665d66b6c66527095652259bc8810cf843f425",
|
||||||
"last_sync_time": "2025-11-20T21:08:04.245668"
|
"last_sync_time": "2025-11-21T01:42:38.758898"
|
||||||
}
|
}
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
[project]
|
[project]
|
||||||
name = "tinker"
|
name = "tinker"
|
||||||
version = "0.5.0"
|
version = "0.5.1"
|
||||||
description = "The official Python SDK for the tinker API"
|
description = "The official Python SDK for the tinker API"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = "Apache-2.0"
|
license = "Apache-2.0"
|
||||||
|
|
|
||||||
|
|
@ -501,14 +501,15 @@ def download(
|
||||||
unless --force is specified.
|
unless --force is specified.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
|
||||||
|
# Creates ./run-123_weights_final/ with checkpoint files
|
||||||
tinker checkpoint download tinker://run-123/weights/final
|
tinker checkpoint download tinker://run-123/weights/final
|
||||||
# Creates ./run-123_weights_final/ with checkpoint files
|
|
||||||
|
|
||||||
|
# Creates ./models/run-123_weights_final/ with checkpoint files
|
||||||
tinker checkpoint download tinker://run-123/weights/final --output ./models/
|
tinker checkpoint download tinker://run-123/weights/final --output ./models/
|
||||||
# Creates ./models/run-123_weights_final/ with checkpoint files
|
|
||||||
|
|
||||||
|
# Overwrites existing ./run-123_weights_final/ directory
|
||||||
tinker checkpoint download tinker://run-123/weights/final --force
|
tinker checkpoint download tinker://run-123/weights/final --force
|
||||||
# Overwrites existing ./run-123_weights_final/ directory
|
|
||||||
"""
|
"""
|
||||||
# Lazy imports to maintain fast CLI startup
|
# Lazy imports to maintain fast CLI startup
|
||||||
import urllib.error
|
import urllib.error
|
||||||
|
|
|
||||||
|
|
@ -68,7 +68,7 @@ class RunListOutput(OutputBase):
|
||||||
|
|
||||||
def get_table_columns(self) -> List[str]:
|
def get_table_columns(self) -> List[str]:
|
||||||
"""Return column headers for table output."""
|
"""Return column headers for table output."""
|
||||||
return ["Run ID", "Base Model", "Owner", "LoRA", "Last Update", "Corrupted"]
|
return ["Run ID", "Base Model", "Owner", "LoRA", "Last Update", "Status"]
|
||||||
|
|
||||||
def get_table_rows(self) -> List[List[str]]:
|
def get_table_rows(self) -> List[List[str]]:
|
||||||
"""Return rows for table output."""
|
"""Return rows for table output."""
|
||||||
|
|
@ -89,7 +89,7 @@ class RunListOutput(OutputBase):
|
||||||
run.model_owner,
|
run.model_owner,
|
||||||
lora_info,
|
lora_info,
|
||||||
format_timestamp(run.last_request_time),
|
format_timestamp(run.last_request_time),
|
||||||
str(run.corrupted),
|
"Failed" if run.corrupted else "Active",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -142,7 +142,7 @@ class RunInfoOutput(OutputBase):
|
||||||
rows.append(["Last Update", format_timestamp(self.run.last_request_time)])
|
rows.append(["Last Update", format_timestamp(self.run.last_request_time)])
|
||||||
|
|
||||||
# Corruption status
|
# Corruption status
|
||||||
rows.append(["Status", "Corrupted" if self.run.corrupted else "Active"])
|
rows.append(["Status", "Failed" if self.run.corrupted else "Active"])
|
||||||
|
|
||||||
# Last checkpoints
|
# Last checkpoints
|
||||||
if self.run.last_checkpoint:
|
if self.run.last_checkpoint:
|
||||||
|
|
|
||||||
|
|
@ -3,12 +3,12 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import contextlib
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
|
import contextlib
|
||||||
from collections.abc import Coroutine, Generator
|
from collections.abc import Coroutine, Generator
|
||||||
from contextlib import AbstractContextManager, contextmanager
|
from contextlib import AbstractContextManager, contextmanager
|
||||||
from typing import Any, Awaitable, Callable, TypeVar
|
from typing import Any, Awaitable, Callable, TypeVar
|
||||||
|
|
@ -107,11 +107,12 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
|
||||||
self._sample_backoff_until: float | None = None
|
self._sample_backoff_until: float | None = None
|
||||||
self._sample_dispatch_semaphore: asyncio.Semaphore = asyncio.Semaphore(400)
|
self._sample_dispatch_semaphore: asyncio.Semaphore = asyncio.Semaphore(400)
|
||||||
self._telemetry: Telemetry | None = None
|
self._telemetry: Telemetry | None = None
|
||||||
|
self._session_heartbeat_task: asyncio.Task[None] | None = None
|
||||||
session_id, session_heartbeat_task = self.run_coroutine_threadsafe(
|
session_id, session_heartbeat_task = self.run_coroutine_threadsafe(
|
||||||
self._create_session(user_metadata)
|
self._create_session(user_metadata)
|
||||||
).result()
|
).result()
|
||||||
|
self._session_heartbeat_task = session_heartbeat_task
|
||||||
self._session_id: str = session_id
|
self._session_id: str = session_id
|
||||||
self._session_heartbeat_task: asyncio.Task[None] = session_heartbeat_task
|
|
||||||
self._telemetry = init_telemetry(self, session_id=self._session_id)
|
self._telemetry = init_telemetry(self, session_id=self._session_id)
|
||||||
|
|
||||||
self._training_client_counter: int = 0
|
self._training_client_counter: int = 0
|
||||||
|
|
@ -148,6 +149,8 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
|
||||||
async def _create_sampling_session(
|
async def _create_sampling_session(
|
||||||
self, model_path: str | None = None, base_model: str | None = None
|
self, model_path: str | None = None, base_model: str | None = None
|
||||||
) -> str:
|
) -> str:
|
||||||
|
if model_path and not model_path.startswith("tinker://"):
|
||||||
|
raise ValueError("model_path must start with 'tinker://'")
|
||||||
sampling_session_seq_id = self._sampling_client_counter
|
sampling_session_seq_id = self._sampling_client_counter
|
||||||
self._sampling_client_counter += 1
|
self._sampling_client_counter += 1
|
||||||
with self.aclient(ClientConnectionPoolType.SESSION) as client:
|
with self.aclient(ClientConnectionPoolType.SESSION) as client:
|
||||||
|
|
@ -223,7 +226,7 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
|
||||||
return AwaitableConcurrentFuture(asyncio.run_coroutine_threadsafe(coro, self.get_loop()))
|
return AwaitableConcurrentFuture(asyncio.run_coroutine_threadsafe(coro, self.get_loop()))
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
self.run_coroutine_threadsafe(self._async_cleanup()).result()
|
self.run_coroutine_threadsafe(self._async_cleanup())
|
||||||
if telemetry := self._telemetry:
|
if telemetry := self._telemetry:
|
||||||
telemetry.stop()
|
telemetry.stop()
|
||||||
|
|
||||||
|
|
@ -233,8 +236,8 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
|
||||||
async def _async_cleanup(self):
|
async def _async_cleanup(self):
|
||||||
if self._session_heartbeat_task:
|
if self._session_heartbeat_task:
|
||||||
self._session_heartbeat_task.cancel()
|
self._session_heartbeat_task.cancel()
|
||||||
with contextlib.suppress(asyncio.CancelledError):
|
with contextlib.suppress(asyncio.CancelledError):
|
||||||
await self._session_heartbeat_task
|
await self._session_heartbeat_task
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _is_retryable_status_code(status_code: int) -> bool:
|
def _is_retryable_status_code(status_code: int) -> bool:
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ from typing import TYPE_CHECKING, TypeVar, cast
|
||||||
import tinker
|
import tinker
|
||||||
from tinker import types
|
from tinker import types
|
||||||
from tinker.lib.client_connection_pool_type import ClientConnectionPoolType
|
from tinker.lib.client_connection_pool_type import ClientConnectionPoolType
|
||||||
from tinker.lib.public_interfaces.api_future import AwaitableConcurrentFuture
|
from tinker.lib.public_interfaces.api_future import APIFuture, AwaitableConcurrentFuture
|
||||||
from tinker.lib.telemetry import Telemetry, capture_exceptions
|
from tinker.lib.telemetry import Telemetry, capture_exceptions
|
||||||
from tinker.lib.telemetry_provider import TelemetryProvider
|
from tinker.lib.telemetry_provider import TelemetryProvider
|
||||||
|
|
||||||
|
|
@ -58,21 +58,14 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
|
||||||
self,
|
self,
|
||||||
holder: InternalClientHolder,
|
holder: InternalClientHolder,
|
||||||
*,
|
*,
|
||||||
model_path: str | None = None,
|
sampling_session_id: str,
|
||||||
base_model: str | None = None,
|
|
||||||
sampling_session_id: str | None = None,
|
|
||||||
retry_config: RetryConfig | None = None,
|
retry_config: RetryConfig | None = None,
|
||||||
):
|
):
|
||||||
if model_path and not model_path.startswith("tinker://"):
|
|
||||||
raise ValueError("model_path must start with 'tinker://'")
|
|
||||||
|
|
||||||
self.holder = holder
|
self.holder = holder
|
||||||
self.model_path = model_path
|
|
||||||
self.base_model = base_model
|
|
||||||
|
|
||||||
# Create retry handler with the provided configuration
|
# Create retry handler with the provided configuration
|
||||||
self.retry_handler = _get_retry_handler(
|
self.retry_handler = _get_retry_handler(
|
||||||
model_path or base_model, retry_config=retry_config, telemetry=holder.get_telemetry()
|
sampling_session_id, retry_config=retry_config, telemetry=holder.get_telemetry()
|
||||||
)
|
)
|
||||||
|
|
||||||
self.feature_gates = set(
|
self.feature_gates = set(
|
||||||
|
|
@ -81,15 +74,34 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
|
||||||
|
|
||||||
self._last_queue_state_logged: float = 0
|
self._last_queue_state_logged: float = 0
|
||||||
|
|
||||||
self._sampling_session_id: str = (
|
self._sampling_session_id: str = sampling_session_id
|
||||||
sampling_session_id
|
|
||||||
or holder.run_coroutine_threadsafe(
|
|
||||||
holder._create_sampling_session(model_path=model_path, base_model=base_model)
|
|
||||||
).result()
|
|
||||||
)
|
|
||||||
|
|
||||||
self._request_id_counter: int = 0
|
self._request_id_counter: int = 0
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _create_impl(
|
||||||
|
holder: InternalClientHolder,
|
||||||
|
*,
|
||||||
|
model_path: str | None,
|
||||||
|
base_model: str | None,
|
||||||
|
sampling_session_id: str | None,
|
||||||
|
retry_config: RetryConfig | None,
|
||||||
|
) -> SamplingClient:
|
||||||
|
if sampling_session_id is None:
|
||||||
|
sampling_session_id = await holder._create_sampling_session(model_path=model_path, base_model=base_model)
|
||||||
|
return SamplingClient(holder, sampling_session_id=sampling_session_id, retry_config=retry_config)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create(
|
||||||
|
holder: InternalClientHolder,
|
||||||
|
*,
|
||||||
|
model_path: str | None = None,
|
||||||
|
base_model: str | None = None,
|
||||||
|
sampling_session_id: str | None = None,
|
||||||
|
retry_config: RetryConfig | None = None,
|
||||||
|
) -> APIFuture[SamplingClient]:
|
||||||
|
return holder.run_coroutine_threadsafe(SamplingClient._create_impl(holder, model_path=model_path, base_model=base_model, sampling_session_id=sampling_session_id, retry_config=retry_config))
|
||||||
|
|
||||||
async def _send_asample_request(
|
async def _send_asample_request(
|
||||||
self,
|
self,
|
||||||
num_samples: int,
|
num_samples: int,
|
||||||
|
|
@ -243,7 +255,7 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
|
||||||
reason = "unknown"
|
reason = "unknown"
|
||||||
self._last_queue_state_logged = time.time()
|
self._last_queue_state_logged = time.time()
|
||||||
|
|
||||||
logger.warning(f"Sampling is paused for {self.model_path}. Reason: {reason}")
|
logger.warning(f"Sampling is paused for sampler {self._sampling_session_id}. Reason: {reason}")
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=100)
|
@lru_cache(maxsize=100)
|
||||||
|
|
|
||||||
|
|
@ -226,7 +226,25 @@ class ServiceClient(TelemetryProvider):
|
||||||
|
|
||||||
if model_path is None and base_model is None:
|
if model_path is None and base_model is None:
|
||||||
raise ValueError("Either model_path or base_model must be provided")
|
raise ValueError("Either model_path or base_model must be provided")
|
||||||
return SamplingClient(
|
return SamplingClient.create(
|
||||||
|
self.holder,
|
||||||
|
model_path=model_path,
|
||||||
|
base_model=base_model,
|
||||||
|
retry_config=retry_config,
|
||||||
|
).result()
|
||||||
|
|
||||||
|
@capture_exceptions(fatal=True)
|
||||||
|
async def create_sampling_client_async(
|
||||||
|
self,
|
||||||
|
model_path: str | None = None,
|
||||||
|
base_model: str | None = None,
|
||||||
|
retry_config: RetryConfig | None = None,
|
||||||
|
) -> SamplingClient:
|
||||||
|
from .sampling_client import SamplingClient
|
||||||
|
|
||||||
|
if model_path is None and base_model is None:
|
||||||
|
raise ValueError("Either model_path or base_model must be provided")
|
||||||
|
return await SamplingClient.create(
|
||||||
self.holder,
|
self.holder,
|
||||||
model_path=model_path,
|
model_path=model_path,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,7 @@ from ..retry_handler import RetryConfig
|
||||||
from ..sync_only import sync_only
|
from ..sync_only import sync_only
|
||||||
from .sampling_client import SamplingClient
|
from .sampling_client import SamplingClient
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||||
|
|
||||||
|
|
@ -522,9 +523,13 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
|
||||||
def create_sampling_client(
|
def create_sampling_client(
|
||||||
self, model_path: str, retry_config: RetryConfig | None = None
|
self, model_path: str, retry_config: RetryConfig | None = None
|
||||||
) -> SamplingClient:
|
) -> SamplingClient:
|
||||||
from .sampling_client import SamplingClient
|
return SamplingClient.create(self.holder, model_path=model_path, retry_config=retry_config).result()
|
||||||
|
|
||||||
return SamplingClient(self.holder, model_path=model_path, retry_config=retry_config)
|
@capture_exceptions(fatal=True)
|
||||||
|
async def create_sampling_client_async(
|
||||||
|
self, model_path: str, retry_config: RetryConfig | None = None
|
||||||
|
) -> SamplingClient:
|
||||||
|
return await SamplingClient.create(self.holder, model_path=model_path, retry_config=retry_config)
|
||||||
|
|
||||||
def save_weights_and_get_sampling_client_submit(
|
def save_weights_and_get_sampling_client_submit(
|
||||||
self, retry_config: RetryConfig | None = None
|
self, retry_config: RetryConfig | None = None
|
||||||
|
|
@ -535,7 +540,7 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
|
||||||
result = await self._save_weights_for_sampler_impl(request_id, None)
|
result = await self._save_weights_for_sampler_impl(request_id, None)
|
||||||
assert result.path is None
|
assert result.path is None
|
||||||
assert result.sampling_session_id is not None
|
assert result.sampling_session_id is not None
|
||||||
return SamplingClient(
|
return await SamplingClient.create(
|
||||||
self.holder,
|
self.holder,
|
||||||
sampling_session_id=result.sampling_session_id,
|
sampling_session_id=result.sampling_session_id,
|
||||||
retry_config=retry_config,
|
retry_config=retry_config,
|
||||||
|
|
@ -582,6 +587,11 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
|
||||||
def _get_tokenizer(model_id: types.ModelID, holder: InternalClientHolder) -> PreTrainedTokenizer:
|
def _get_tokenizer(model_id: types.ModelID, holder: InternalClientHolder) -> PreTrainedTokenizer:
|
||||||
# call get_info on model_id
|
# call get_info on model_id
|
||||||
from transformers.models.auto.tokenization_auto import AutoTokenizer
|
from transformers.models.auto.tokenization_auto import AutoTokenizer
|
||||||
|
try:
|
||||||
|
from tml_tokenizers import get_tinker_tokenizer
|
||||||
|
except ImportError:
|
||||||
|
def get_tinker_tokenizer(model_id: str) -> PreTrainedTokenizer | None:
|
||||||
|
return None
|
||||||
|
|
||||||
async def _get_info_async():
|
async def _get_info_async():
|
||||||
with holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
with holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
||||||
|
|
@ -614,5 +624,7 @@ def _get_tokenizer(model_id: types.ModelID, holder: InternalClientHolder) -> Pre
|
||||||
"trust_remote_code": True,
|
"trust_remote_code": True,
|
||||||
"revision": "612681931a8c906ddb349f8ad0f582cb552189cd",
|
"revision": "612681931a8c906ddb349f8ad0f582cb552189cd",
|
||||||
}
|
}
|
||||||
|
if (tokenizer := get_tinker_tokenizer(tokenizer_id)) is not None:
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
return AutoTokenizer.from_pretrained(tokenizer_id, fast=True, **kwargs)
|
return AutoTokenizer.from_pretrained(tokenizer_id, fast=True, **kwargs)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue