Sync contents

This commit is contained in:
Andrii Grynenko 2025-12-08 00:05:28 +00:00
parent 0622760140
commit 5ad4282c96
10 changed files with 155 additions and 51 deletions

View file

@ -1,4 +1,4 @@
{
"last_synced_sha": "ffa781c9432bd1afb73c278d253d0a7cd1b333d8",
"last_sync_time": "2025-12-03T01:00:03.831979"
"last_synced_sha": "2f8a0c29b227897c25451163c8fa90adb2daa6e6",
"last_sync_time": "2025-12-08T00:05:28.592609"
}

View file

@ -1,6 +1,6 @@
[project]
name = "tinker"
version = "0.6.3"
version = "0.7.0"
description = "The official Python SDK for the tinker API"
readme = "README.md"
license = "Apache-2.0"

View file

@ -57,7 +57,7 @@ class APIResponseValidationError(APIError):
status_code: int
def __init__(
self, response: httpx.Response, body: object | None, *, message: str | None = None
self, response: httpx.Response, body: object | None, message: str | None = None
) -> None:
super().__init__(
message or "Data returned by API invalid for expected schema.",
@ -67,6 +67,14 @@ class APIResponseValidationError(APIError):
self.response = response
self.status_code = response.status_code
def __reduce__(self):
# Return a tuple of (callable, args) to recreate the exception
return (
self.__class__,
(self.response, self.body, self.message), # positional args
None,
)
class APIStatusError(APIError):
"""Raised when an API response has a status code of 4xx or 5xx."""
@ -74,24 +82,41 @@ class APIStatusError(APIError):
response: httpx.Response
status_code: int
def __init__(self, message: str, *, response: httpx.Response, body: object | None) -> None:
def __init__(self, message: str, response: httpx.Response, body: object | None) -> None:
super().__init__(message, response.request, body=body)
self.response = response
self.status_code = response.status_code
def __reduce__(self):
# Return a tuple of (callable, args) to recreate the exception
return (
self.__class__,
(self.message, self.response, self.body), # positional args
None,
)
class APIConnectionError(APIError):
"""Raised when a connection error occurs while making an API request."""
def __init__(self, *, message: str = "Connection error.", request: httpx.Request) -> None:
def __init__(self, request: httpx.Request, message: str = "Connection error.") -> None:
super().__init__(message, request, body=None)
def __reduce__(self):
# Return a tuple of (callable, args) to recreate the exception
return (
self.__class__,
(self.request, self.message), # positional args
None,
)
class APITimeoutError(APIConnectionError):
"""Raised when an API request times out."""
def __init__(self, request: httpx.Request) -> None:
super().__init__(message="Request timed out.", request=request)
super().__init__(request=request, message="Request timed out.", )
class BadRequestError(APIStatusError):
@ -148,10 +173,18 @@ class RequestFailedError(TinkerError):
def __init__(
self,
message: str,
*,
request_id: str,
category: "RequestErrorCategory",
) -> None:
super().__init__(message)
self.message: str = message
self.request_id: str = request_id
self.category: RequestErrorCategory = category
def __reduce__(self):
# Return a tuple of (callable, args) to recreate the exception
return (
self.__class__,
(self.message, self.request_id, self.category), # positional args
None,
)

View file

@ -44,7 +44,7 @@ class QueueState(Enum):
class QueueStateObserver(ABC):
@abstractmethod
def on_queue_state_change(self, queue_state: QueueState) -> None:
def on_queue_state_change(self, queue_state: QueueState, queue_state_reason: str | None) -> None:
raise NotImplementedError
@ -151,6 +151,7 @@ class _APIFuture(APIFuture[T]): # pyright: ignore[reportUnusedClass]
with contextlib.suppress(Exception):
response = e.response.json()
if queue_state_str := response.get("queue_state", None):
queue_state_reason = response.get("queue_state_reason", None)
if queue_state_str == "active":
queue_state = QueueState.ACTIVE
elif queue_state_str == "paused_rate_limit":
@ -159,7 +160,7 @@ class _APIFuture(APIFuture[T]): # pyright: ignore[reportUnusedClass]
queue_state = QueueState.PAUSED_CAPACITY
else:
queue_state = QueueState.UNKNOWN
self._queue_state_observer.on_queue_state_change(queue_state)
self._queue_state_observer.on_queue_state_change(queue_state, queue_state_reason)
continue
if e.status_code == 410:
raise RetryableException(

View file

@ -10,7 +10,7 @@ import time
import traceback
import contextlib
from collections.abc import Coroutine, Generator
from contextlib import AbstractContextManager, contextmanager
from contextlib import AbstractContextManager, asynccontextmanager, contextmanager
from typing import Any, Awaitable, Callable, TypeVar
import httpx
@ -97,6 +97,31 @@ class InternalClientHolderThreadSingleton:
_internal_client_holder_thread_singleton = InternalClientHolderThreadSingleton()
class BytesSemaphore:
def __init__(self, max_bytes: int):
self._bytes: int = max_bytes
self._condition: asyncio.Condition = asyncio.Condition()
self._release_task: asyncio.Task[None] | None = None
async def _release(self):
async with self._condition:
self._condition.notify_all()
@asynccontextmanager
async def acquire(self, bytes: int):
async with self._condition:
while self._bytes < 0:
await self._condition.wait()
self._bytes -= bytes
try:
yield
finally:
self._bytes += bytes
# Make sure the release task is never cancelled.
self._release_task = asyncio.create_task(self._release())
class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
def __init__(self, user_metadata: dict[str, str] | None = None, **kwargs: Any) -> None:
self._constructor_kwargs = kwargs
@ -106,6 +131,8 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
self._client_pools: dict[ClientConnectionPoolType, ClientConnectionPool] = {}
self._sample_backoff_until: float | None = None
self._sample_dispatch_semaphore: asyncio.Semaphore = asyncio.Semaphore(400)
self._sample_dispatch_throttled_semaphore: asyncio.Semaphore = asyncio.Semaphore(10)
self._sample_dispatch_bytes_semaphore: BytesSemaphore = BytesSemaphore(5 * 1024 * 1024)
self._telemetry: Telemetry | None = None
self._session_heartbeat_task: asyncio.Task[None] | None = None
session_id, session_heartbeat_task = self.run_coroutine_threadsafe(
@ -120,6 +147,37 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
self._sampling_client_counter: int = 0
@asynccontextmanager
async def _sample_dispatch_count_rate_limit(self):
async with self._sample_dispatch_semaphore:
yield
@asynccontextmanager
async def _sample_dispatch_count_throttled_rate_limit(self):
async with self._sample_dispatch_throttled_semaphore:
yield
def _sample_backoff_requested_recently(self) -> bool:
return self._sample_backoff_until is not None and time.monotonic() - self._sample_backoff_until < 10
@asynccontextmanager
async def _sample_dispatch_bytes_rate_limit(self, bytes: int):
if self._sample_backoff_requested_recently():
# Rate limit more aggressively if we received backoff response recently
bytes *= 20
async with self._sample_dispatch_bytes_semaphore.acquire(bytes):
yield
@asynccontextmanager
async def sample_dispatch_rate_limit(self, estimated_bytes_count: int):
async with contextlib.AsyncExitStack() as stack:
await stack.enter_async_context(self._sample_dispatch_count_rate_limit())
if self._sample_backoff_requested_recently():
await stack.enter_async_context(self._sample_dispatch_count_throttled_rate_limit())
await stack.enter_async_context(self._sample_dispatch_bytes_rate_limit(estimated_bytes_count))
yield
async def _session_heartbeat(self, session_id: str):
SESSION_HEARTBEAT_PERIOD_SEC = 10
SESSION_MISSED_HEARTBEAT_WARNING_THRESHOLD_SEC = 60 * 2
@ -305,6 +363,16 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
raise e
def estimate_bytes_count_in_chunk(self, chunk: types.ModelInputChunk) -> int:
if isinstance(chunk, types.ImageChunk):
return len(chunk.data)
if isinstance(chunk, types.ImageAssetPointerChunk):
return len(chunk.location)
return chunk.length * 10
def estimate_bytes_count_in_model_input(self, model_input: types.ModelInput) -> int:
return sum(self.estimate_bytes_count_in_chunk(chunk) for chunk in model_input.chunks)
def _current_loop() -> asyncio.AbstractEventLoop | None:
try:

View file

@ -155,11 +155,12 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
include_prompt_logprobs: bool,
topk_prompt_logprobs: int = 0,
) -> types.SampleResponse:
async with self.holder._sample_dispatch_semaphore:
estimated_bytes_count = self.holder.estimate_bytes_count_in_model_input(prompt)
async with self.holder.sample_dispatch_rate_limit(estimated_bytes_count):
while True:
if (
self.holder._sample_backoff_until is not None
and time.time() < self.holder._sample_backoff_until
and time.monotonic() < self.holder._sample_backoff_until
):
await asyncio.sleep(1)
continue
@ -175,7 +176,8 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
if untyped_future is not None:
break
# Handle backoff
self.holder._sample_backoff_until = time.time() + 1
backoff_duration = 1 if estimated_bytes_count <= 128 * 1024 else 5
self.holder._sample_backoff_until = time.monotonic() + backoff_duration
continue
return await _APIFuture(
@ -298,22 +300,23 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
def get_telemetry(self) -> Telemetry | None:
return self.holder.get_telemetry()
def on_queue_state_change(self, queue_state: QueueState) -> None:
def on_queue_state_change(self, queue_state: QueueState, queue_state_reason: str | None) -> None:
QUEUE_STATE_LOG_INTERVAL = 60
if queue_state == QueueState.ACTIVE:
return
if time.time() - self._last_queue_state_logged < QUEUE_STATE_LOG_INTERVAL:
return
if not queue_state_reason:
if queue_state == QueueState.PAUSED_RATE_LIMIT:
reason = "concurrent LoRA rate limit hit"
queue_state_reason = "concurrent sampler weights limit hit"
elif queue_state == QueueState.PAUSED_CAPACITY:
reason = "out of capacity"
queue_state_reason = "Tinker backend is running short on capacity, please wait"
else:
reason = "unknown"
queue_state_reason = "unknown"
self._last_queue_state_logged = time.time()
logger.warning(
f"Sampling is paused for sampler {self._sampling_session_id}. Reason: {reason}"
f"Sampling is paused for sampler {self._sampling_session_id}. Reason: {queue_state_reason}"
)

View file

@ -141,10 +141,11 @@ class ServiceClient(TelemetryProvider):
request_type="CreateModel",
).result_async()
model_id = create_model_response.model_id
logger.info(f"Creating TrainingClient for {model_id=}")
from .training_client import TrainingClient
return TrainingClient(self.holder, model_seq_id=model_seq_id, model_id=model_id)
training_client = TrainingClient(self.holder, model_seq_id=model_seq_id, model_id=model_id)
logger.info(f"TrainingClient initialized for model {model_id}")
return training_client
return self.holder.run_coroutine_threadsafe(_create_lora_training_client_async())

View file

@ -40,8 +40,8 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# FwdBwdChunkSize
MAX_CHUNK_LEN = 128
MAX_CHUNK_NUMBER_COUNT = 500000
MAX_CHUNK_LEN = 1024
MAX_CHUNK_BYTES_COUNT = 5000000
MODEL_ID_NOT_SET_ERROR = "model_id must be set before calling forward. Try initializing the TrainingClient with a model_id by either calling create_lora_training_client on the ServiceClient, or initializing the TrainingClient with an existing model_id."
CustomLossFnV1 = Callable[[List[types.Datum], List[Any]], Tuple[Any, Dict[str, float]]]
@ -121,36 +121,27 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
assert self.model_id is not None, MODEL_ID_NOT_SET_ERROR
return self.model_id
def _estimate_number_count_in_chunk(self, chunk: types.ModelInputChunk) -> int:
if isinstance(chunk, types.ImageChunk):
return len(chunk.data)
if isinstance(chunk, types.ImageAssetPointerChunk):
return len(chunk.location)
return chunk.length
def _estimate_number_count(self, datum: types.Datum) -> int:
return sum(
self._estimate_number_count_in_chunk(chunk) for chunk in datum.model_input.chunks
) + sum(len(value.data) for _, value in datum.loss_fn_inputs.items())
def _estimate_bytes_count(self, datum: types.Datum) -> int:
return self.holder.estimate_bytes_count_in_model_input(datum.model_input) + sum(len(value.data) * 10 for _, value in datum.loss_fn_inputs.items())
def _chunked_requests_generator(
self, data: List[types.Datum]
) -> Generator[List[types.Datum], None, None]:
current_chunk: List[types.Datum] = []
current_chunk_number_count = 0
current_chunk_bytes_count = 0
for datum in data:
estimated_number_count = self._estimate_number_count(datum)
estimated_bytes_count = self._estimate_bytes_count(datum)
if (
len(current_chunk) > 0
and current_chunk_number_count + estimated_number_count > MAX_CHUNK_NUMBER_COUNT
and current_chunk_bytes_count + estimated_bytes_count > MAX_CHUNK_BYTES_COUNT
) or (len(current_chunk) == MAX_CHUNK_LEN):
yield current_chunk
current_chunk = []
current_chunk_number_count = 0
current_chunk_bytes_count = 0
current_chunk.append(datum)
current_chunk_number_count += estimated_number_count
current_chunk_bytes_count += estimated_bytes_count
if len(current_chunk) > 0:
yield current_chunk
@ -840,7 +831,7 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
def get_telemetry(self) -> Telemetry | None:
return self.holder.get_telemetry()
def on_queue_state_change(self, queue_state: QueueState) -> None:
def on_queue_state_change(self, queue_state: QueueState, queue_state_reason: str | None) -> None:
QUEUE_STATE_LOG_INTERVAL = 60
if queue_state == QueueState.ACTIVE:
return
@ -848,13 +839,14 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
return
self._last_queue_state_logged = time.time()
if not queue_state_reason:
if queue_state == QueueState.PAUSED_RATE_LIMIT:
reason = "concurrent models rate limit hit"
queue_state_reason = "concurrent training clients rate limit hit"
elif queue_state == QueueState.PAUSED_CAPACITY:
reason = "out of capacity"
queue_state_reason = "Tinker backend is running short on capacity, please wait"
else:
reason = "unknown"
logger.warning(f"Training is paused for {self.model_id}. Reason: {reason}")
queue_state_reason = "unknown"
logger.warning(f"Training is paused for {self.model_id}. Reason: {queue_state_reason}")
def _get_tokenizer(model_id: types.ModelID, holder: InternalClientHolder) -> PreTrainedTokenizer:

View file

@ -142,8 +142,8 @@ class RetryHandler(Generic[T]): # noqa: UP046
# Create a dummy request for the exception (required by APIConnectionError)
dummy_request = httpx.Request("GET", "http://localhost")
raise tinker.APIConnectionError(
message=f"No progress made in {self.config.progress_timeout}s. Requests appear to be stuck.",
request=dummy_request,
message=f"No progress made in {self.config.progress_timeout}s. Requests appear to be stuck.",
)
raise
finally:

View file

@ -21,6 +21,12 @@ class AdamParams(StrictBase):
eps: float = 1e-12
"""Term added to the denominator to improve numerical stability"""
weight_decay: float = 0.0
"""Weight decay for the optimizer. Uses decoupled weight decay."""
grad_clip_norm: float = 0.0
"""Gradient clip norm for the optimizer. 0.0 means no clipping."""
class OptimStepRequest(StrictBase):
adam_params: AdamParams