diff --git a/.sync_state b/.sync_state index e8a20d3..8e48452 100644 --- a/.sync_state +++ b/.sync_state @@ -1,4 +1,4 @@ { - "last_synced_sha": "ffa781c9432bd1afb73c278d253d0a7cd1b333d8", - "last_sync_time": "2025-12-03T01:00:03.831979" + "last_synced_sha": "2f8a0c29b227897c25451163c8fa90adb2daa6e6", + "last_sync_time": "2025-12-08T00:05:28.592609" } \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 06bc56c..b1e1169 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "tinker" -version = "0.6.3" +version = "0.7.0" description = "The official Python SDK for the tinker API" readme = "README.md" license = "Apache-2.0" diff --git a/src/tinker/_exceptions.py b/src/tinker/_exceptions.py index abcc17f..909872c 100644 --- a/src/tinker/_exceptions.py +++ b/src/tinker/_exceptions.py @@ -57,7 +57,7 @@ class APIResponseValidationError(APIError): status_code: int def __init__( - self, response: httpx.Response, body: object | None, *, message: str | None = None + self, response: httpx.Response, body: object | None, message: str | None = None ) -> None: super().__init__( message or "Data returned by API invalid for expected schema.", @@ -67,6 +67,14 @@ class APIResponseValidationError(APIError): self.response = response self.status_code = response.status_code + def __reduce__(self): + # Return a tuple of (callable, args) to recreate the exception + return ( + self.__class__, + (self.response, self.body, self.message), # positional args + None, + ) + class APIStatusError(APIError): """Raised when an API response has a status code of 4xx or 5xx.""" @@ -74,24 +82,41 @@ class APIStatusError(APIError): response: httpx.Response status_code: int - def __init__(self, message: str, *, response: httpx.Response, body: object | None) -> None: + def __init__(self, message: str, response: httpx.Response, body: object | None) -> None: super().__init__(message, response.request, body=body) self.response = response self.status_code = response.status_code + def __reduce__(self): + # Return a tuple of (callable, args) to recreate the exception + return ( + self.__class__, + (self.message, self.response, self.body), # positional args + None, + ) + + class APIConnectionError(APIError): """Raised when a connection error occurs while making an API request.""" - def __init__(self, *, message: str = "Connection error.", request: httpx.Request) -> None: + def __init__(self, request: httpx.Request, message: str = "Connection error.") -> None: super().__init__(message, request, body=None) + def __reduce__(self): + # Return a tuple of (callable, args) to recreate the exception + return ( + self.__class__, + (self.request, self.message), # positional args + None, + ) + class APITimeoutError(APIConnectionError): """Raised when an API request times out.""" def __init__(self, request: httpx.Request) -> None: - super().__init__(message="Request timed out.", request=request) + super().__init__(request=request, message="Request timed out.", ) class BadRequestError(APIStatusError): @@ -148,10 +173,18 @@ class RequestFailedError(TinkerError): def __init__( self, message: str, - *, request_id: str, category: "RequestErrorCategory", ) -> None: super().__init__(message) + self.message: str = message self.request_id: str = request_id self.category: RequestErrorCategory = category + + def __reduce__(self): + # Return a tuple of (callable, args) to recreate the exception + return ( + self.__class__, + (self.message, self.request_id, self.category), # positional args + None, + ) diff --git a/src/tinker/lib/api_future_impl.py b/src/tinker/lib/api_future_impl.py index 03cf1d7..7be5627 100644 --- a/src/tinker/lib/api_future_impl.py +++ b/src/tinker/lib/api_future_impl.py @@ -44,7 +44,7 @@ class QueueState(Enum): class QueueStateObserver(ABC): @abstractmethod - def on_queue_state_change(self, queue_state: QueueState) -> None: + def on_queue_state_change(self, queue_state: QueueState, queue_state_reason: str | None) -> None: raise NotImplementedError @@ -151,6 +151,7 @@ class _APIFuture(APIFuture[T]): # pyright: ignore[reportUnusedClass] with contextlib.suppress(Exception): response = e.response.json() if queue_state_str := response.get("queue_state", None): + queue_state_reason = response.get("queue_state_reason", None) if queue_state_str == "active": queue_state = QueueState.ACTIVE elif queue_state_str == "paused_rate_limit": @@ -159,7 +160,7 @@ class _APIFuture(APIFuture[T]): # pyright: ignore[reportUnusedClass] queue_state = QueueState.PAUSED_CAPACITY else: queue_state = QueueState.UNKNOWN - self._queue_state_observer.on_queue_state_change(queue_state) + self._queue_state_observer.on_queue_state_change(queue_state, queue_state_reason) continue if e.status_code == 410: raise RetryableException( diff --git a/src/tinker/lib/internal_client_holder.py b/src/tinker/lib/internal_client_holder.py index f944d6a..11c7b21 100644 --- a/src/tinker/lib/internal_client_holder.py +++ b/src/tinker/lib/internal_client_holder.py @@ -10,7 +10,7 @@ import time import traceback import contextlib from collections.abc import Coroutine, Generator -from contextlib import AbstractContextManager, contextmanager +from contextlib import AbstractContextManager, asynccontextmanager, contextmanager from typing import Any, Awaitable, Callable, TypeVar import httpx @@ -97,6 +97,31 @@ class InternalClientHolderThreadSingleton: _internal_client_holder_thread_singleton = InternalClientHolderThreadSingleton() +class BytesSemaphore: + def __init__(self, max_bytes: int): + self._bytes: int = max_bytes + self._condition: asyncio.Condition = asyncio.Condition() + self._release_task: asyncio.Task[None] | None = None + + async def _release(self): + async with self._condition: + self._condition.notify_all() + + @asynccontextmanager + async def acquire(self, bytes: int): + async with self._condition: + while self._bytes < 0: + await self._condition.wait() + self._bytes -= bytes + + try: + yield + finally: + self._bytes += bytes + # Make sure the release task is never cancelled. + self._release_task = asyncio.create_task(self._release()) + + class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider): def __init__(self, user_metadata: dict[str, str] | None = None, **kwargs: Any) -> None: self._constructor_kwargs = kwargs @@ -106,6 +131,8 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider): self._client_pools: dict[ClientConnectionPoolType, ClientConnectionPool] = {} self._sample_backoff_until: float | None = None self._sample_dispatch_semaphore: asyncio.Semaphore = asyncio.Semaphore(400) + self._sample_dispatch_throttled_semaphore: asyncio.Semaphore = asyncio.Semaphore(10) + self._sample_dispatch_bytes_semaphore: BytesSemaphore = BytesSemaphore(5 * 1024 * 1024) self._telemetry: Telemetry | None = None self._session_heartbeat_task: asyncio.Task[None] | None = None session_id, session_heartbeat_task = self.run_coroutine_threadsafe( @@ -120,6 +147,37 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider): self._sampling_client_counter: int = 0 + @asynccontextmanager + async def _sample_dispatch_count_rate_limit(self): + async with self._sample_dispatch_semaphore: + yield + + @asynccontextmanager + async def _sample_dispatch_count_throttled_rate_limit(self): + async with self._sample_dispatch_throttled_semaphore: + yield + + def _sample_backoff_requested_recently(self) -> bool: + return self._sample_backoff_until is not None and time.monotonic() - self._sample_backoff_until < 10 + + @asynccontextmanager + async def _sample_dispatch_bytes_rate_limit(self, bytes: int): + if self._sample_backoff_requested_recently(): + # Rate limit more aggressively if we received backoff response recently + bytes *= 20 + async with self._sample_dispatch_bytes_semaphore.acquire(bytes): + yield + + @asynccontextmanager + async def sample_dispatch_rate_limit(self, estimated_bytes_count: int): + async with contextlib.AsyncExitStack() as stack: + await stack.enter_async_context(self._sample_dispatch_count_rate_limit()) + if self._sample_backoff_requested_recently(): + await stack.enter_async_context(self._sample_dispatch_count_throttled_rate_limit()) + await stack.enter_async_context(self._sample_dispatch_bytes_rate_limit(estimated_bytes_count)) + + yield + async def _session_heartbeat(self, session_id: str): SESSION_HEARTBEAT_PERIOD_SEC = 10 SESSION_MISSED_HEARTBEAT_WARNING_THRESHOLD_SEC = 60 * 2 @@ -305,6 +363,16 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider): raise e + def estimate_bytes_count_in_chunk(self, chunk: types.ModelInputChunk) -> int: + if isinstance(chunk, types.ImageChunk): + return len(chunk.data) + if isinstance(chunk, types.ImageAssetPointerChunk): + return len(chunk.location) + return chunk.length * 10 + + def estimate_bytes_count_in_model_input(self, model_input: types.ModelInput) -> int: + return sum(self.estimate_bytes_count_in_chunk(chunk) for chunk in model_input.chunks) + def _current_loop() -> asyncio.AbstractEventLoop | None: try: diff --git a/src/tinker/lib/public_interfaces/sampling_client.py b/src/tinker/lib/public_interfaces/sampling_client.py index dd0b1e2..fd913fe 100644 --- a/src/tinker/lib/public_interfaces/sampling_client.py +++ b/src/tinker/lib/public_interfaces/sampling_client.py @@ -155,11 +155,12 @@ class SamplingClient(TelemetryProvider, QueueStateObserver): include_prompt_logprobs: bool, topk_prompt_logprobs: int = 0, ) -> types.SampleResponse: - async with self.holder._sample_dispatch_semaphore: + estimated_bytes_count = self.holder.estimate_bytes_count_in_model_input(prompt) + async with self.holder.sample_dispatch_rate_limit(estimated_bytes_count): while True: if ( self.holder._sample_backoff_until is not None - and time.time() < self.holder._sample_backoff_until + and time.monotonic() < self.holder._sample_backoff_until ): await asyncio.sleep(1) continue @@ -175,7 +176,8 @@ class SamplingClient(TelemetryProvider, QueueStateObserver): if untyped_future is not None: break # Handle backoff - self.holder._sample_backoff_until = time.time() + 1 + backoff_duration = 1 if estimated_bytes_count <= 128 * 1024 else 5 + self.holder._sample_backoff_until = time.monotonic() + backoff_duration continue return await _APIFuture( @@ -298,22 +300,23 @@ class SamplingClient(TelemetryProvider, QueueStateObserver): def get_telemetry(self) -> Telemetry | None: return self.holder.get_telemetry() - def on_queue_state_change(self, queue_state: QueueState) -> None: + def on_queue_state_change(self, queue_state: QueueState, queue_state_reason: str | None) -> None: QUEUE_STATE_LOG_INTERVAL = 60 if queue_state == QueueState.ACTIVE: return if time.time() - self._last_queue_state_logged < QUEUE_STATE_LOG_INTERVAL: return - if queue_state == QueueState.PAUSED_RATE_LIMIT: - reason = "concurrent LoRA rate limit hit" - elif queue_state == QueueState.PAUSED_CAPACITY: - reason = "out of capacity" - else: - reason = "unknown" + if not queue_state_reason: + if queue_state == QueueState.PAUSED_RATE_LIMIT: + queue_state_reason = "concurrent sampler weights limit hit" + elif queue_state == QueueState.PAUSED_CAPACITY: + queue_state_reason = "Tinker backend is running short on capacity, please wait" + else: + queue_state_reason = "unknown" self._last_queue_state_logged = time.time() logger.warning( - f"Sampling is paused for sampler {self._sampling_session_id}. Reason: {reason}" + f"Sampling is paused for sampler {self._sampling_session_id}. Reason: {queue_state_reason}" ) diff --git a/src/tinker/lib/public_interfaces/service_client.py b/src/tinker/lib/public_interfaces/service_client.py index c1893bd..72ee967 100644 --- a/src/tinker/lib/public_interfaces/service_client.py +++ b/src/tinker/lib/public_interfaces/service_client.py @@ -141,10 +141,11 @@ class ServiceClient(TelemetryProvider): request_type="CreateModel", ).result_async() model_id = create_model_response.model_id - logger.info(f"Creating TrainingClient for {model_id=}") from .training_client import TrainingClient - return TrainingClient(self.holder, model_seq_id=model_seq_id, model_id=model_id) + training_client = TrainingClient(self.holder, model_seq_id=model_seq_id, model_id=model_id) + logger.info(f"TrainingClient initialized for model {model_id}") + return training_client return self.holder.run_coroutine_threadsafe(_create_lora_training_client_async()) diff --git a/src/tinker/lib/public_interfaces/training_client.py b/src/tinker/lib/public_interfaces/training_client.py index a15b6bf..78d8296 100644 --- a/src/tinker/lib/public_interfaces/training_client.py +++ b/src/tinker/lib/public_interfaces/training_client.py @@ -40,8 +40,8 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) # FwdBwdChunkSize -MAX_CHUNK_LEN = 128 -MAX_CHUNK_NUMBER_COUNT = 500000 +MAX_CHUNK_LEN = 1024 +MAX_CHUNK_BYTES_COUNT = 5000000 MODEL_ID_NOT_SET_ERROR = "model_id must be set before calling forward. Try initializing the TrainingClient with a model_id by either calling create_lora_training_client on the ServiceClient, or initiliazing the TrainingClient with an existing model_id." CustomLossFnV1 = Callable[[List[types.Datum], List[Any]], Tuple[Any, Dict[str, float]]] @@ -121,36 +121,27 @@ class TrainingClient(TelemetryProvider, QueueStateObserver): assert self.model_id is not None, MODEL_ID_NOT_SET_ERROR return self.model_id - def _estimate_number_count_in_chunk(self, chunk: types.ModelInputChunk) -> int: - if isinstance(chunk, types.ImageChunk): - return len(chunk.data) - if isinstance(chunk, types.ImageAssetPointerChunk): - return len(chunk.location) - return chunk.length - - def _estimate_number_count(self, datum: types.Datum) -> int: - return sum( - self._estimate_number_count_in_chunk(chunk) for chunk in datum.model_input.chunks - ) + sum(len(value.data) for _, value in datum.loss_fn_inputs.items()) + def _estimate_bytes_count(self, datum: types.Datum) -> int: + return self.holder.estimate_bytes_count_in_model_input(datum.model_input) + sum(len(value.data) * 10 for _, value in datum.loss_fn_inputs.items()) def _chunked_requests_generator( self, data: List[types.Datum] ) -> Generator[List[types.Datum], None, None]: current_chunk: List[types.Datum] = [] - current_chunk_number_count = 0 + current_chunk_bytes_count = 0 for datum in data: - estimated_number_count = self._estimate_number_count(datum) + estimated_bytes_count = self._estimate_bytes_count(datum) if ( len(current_chunk) > 0 - and current_chunk_number_count + estimated_number_count > MAX_CHUNK_NUMBER_COUNT + and current_chunk_bytes_count + estimated_bytes_count > MAX_CHUNK_BYTES_COUNT ) or (len(current_chunk) == MAX_CHUNK_LEN): yield current_chunk current_chunk = [] - current_chunk_number_count = 0 + current_chunk_bytes_count = 0 current_chunk.append(datum) - current_chunk_number_count += estimated_number_count + current_chunk_bytes_count += estimated_bytes_count if len(current_chunk) > 0: yield current_chunk @@ -840,7 +831,7 @@ class TrainingClient(TelemetryProvider, QueueStateObserver): def get_telemetry(self) -> Telemetry | None: return self.holder.get_telemetry() - def on_queue_state_change(self, queue_state: QueueState) -> None: + def on_queue_state_change(self, queue_state: QueueState, queue_state_reason: str | None) -> None: QUEUE_STATE_LOG_INTERVAL = 60 if queue_state == QueueState.ACTIVE: return @@ -848,13 +839,14 @@ class TrainingClient(TelemetryProvider, QueueStateObserver): return self._last_queue_state_logged = time.time() - if queue_state == QueueState.PAUSED_RATE_LIMIT: - reason = "concurrent models rate limit hit" - elif queue_state == QueueState.PAUSED_CAPACITY: - reason = "out of capacity" - else: - reason = "unknown" - logger.warning(f"Training is paused for {self.model_id}. Reason: {reason}") + if not queue_state_reason: + if queue_state == QueueState.PAUSED_RATE_LIMIT: + queue_state_reason = "concurrent training clients rate limit hit" + elif queue_state == QueueState.PAUSED_CAPACITY: + queue_state_reason = "Tinker backend is running short on capacity, please wait" + else: + queue_state_reason = "unknown" + logger.warning(f"Training is paused for {self.model_id}. Reason: {queue_state_reason}") def _get_tokenizer(model_id: types.ModelID, holder: InternalClientHolder) -> PreTrainedTokenizer: diff --git a/src/tinker/lib/retry_handler.py b/src/tinker/lib/retry_handler.py index ef0197d..887abe3 100644 --- a/src/tinker/lib/retry_handler.py +++ b/src/tinker/lib/retry_handler.py @@ -142,8 +142,8 @@ class RetryHandler(Generic[T]): # noqa: UP046 # Create a dummy request for the exception (required by APIConnectionError) dummy_request = httpx.Request("GET", "http://localhost") raise tinker.APIConnectionError( - message=f"No progress made in {self.config.progress_timeout}s. Requests appear to be stuck.", request=dummy_request, + message=f"No progress made in {self.config.progress_timeout}s. Requests appear to be stuck.", ) raise finally: diff --git a/src/tinker/types/optim_step_request.py b/src/tinker/types/optim_step_request.py index 1c207b1..cf537e3 100644 --- a/src/tinker/types/optim_step_request.py +++ b/src/tinker/types/optim_step_request.py @@ -21,6 +21,12 @@ class AdamParams(StrictBase): eps: float = 1e-12 """Term added to the denominator to improve numerical stability""" + weight_decay: float = 0.0 + """Weight decay for the optimizer. Uses decoupled weight decay.""" + + grad_clip_norm: float = 0.0 + """Gradient clip norm for the optimizer. 0.0 means no clipping.""" + class OptimStepRequest(StrictBase): adam_params: AdamParams