diff --git a/.sync_state b/.sync_state index 4ed173d..124df6f 100644 --- a/.sync_state +++ b/.sync_state @@ -1,4 +1,4 @@ { - "last_synced_sha": "1c67712168cd74fc2bda614f239dcad5f7b2e166", - "last_sync_time": "2026-02-13T22:05:28.522838" + "last_synced_sha": "7a5a376dc40cd21c3c744270af661b021e9fa9c6", + "last_sync_time": "2026-02-27T05:49:39.458251" } \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index dd16ad4..b44e15e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "tinker" -version = "0.13.1" +version = "0.14.0" description = "The official Python SDK for the tinker API" readme = "README.md" license = "Apache-2.0" diff --git a/src/tinker/lib/api_future_impl.py b/src/tinker/lib/api_future_impl.py index 79548aa..c209088 100644 --- a/src/tinker/lib/api_future_impl.py +++ b/src/tinker/lib/api_future_impl.py @@ -85,186 +85,200 @@ class _APIFuture(APIFuture[T]): # pyright: ignore[reportUnusedClass] start_time = time.time() iteration = -1 connection_error_retries = 0 + allow_metadata_only = True - while True: - iteration += 1 + async with contextlib.AsyncExitStack() as stack: + while True: + iteration += 1 - if timeout is not None and time.time() - start_time > timeout: - if telemetry := self.get_telemetry(): - current_time = time.time() - telemetry.log( - "APIFuture.result_async.timeout", - event_data={ - "request_id": self.request_id, - "request_type": self.request_type, - "timeout": timeout, - "iteration": iteration, - "elapsed_time": current_time - start_time, - }, - severity="ERROR", - ) - raise TimeoutError( - f"Timeout of {timeout} seconds reached while waiting for result of {self.request_id=}" - ) - - # Headers for telemetry - headers = { - "X-Tinker-Request-Iteration": str(iteration), - "X-Tinker-Request-Type": self.request_type, - } - if iteration == 0: - headers["X-Tinker-Create-Promise-Roundtrip-Time"] = str( - self.request_queue_roundtrip_time - ) - - # Function hasn't been called yet, execute it now - try: - with self.holder.aclient(ClientConnectionPoolType.RETRIEVE_PROMISE) as client: - response = await client.futures.with_raw_response.retrieve( - request=FutureRetrieveRequest(request_id=self.request_id), - timeout=45, - extra_headers=headers, - max_retries=0, - ) - except tinker.APIStatusError as e: - connection_error_retries = 0 - should_retry = e.status_code == 408 or e.status_code in range(500, 600) - user_error = is_user_error(e) - if telemetry := self.get_telemetry(): - current_time = time.time() - telemetry.log( - "APIFuture.result_async.api_status_error", - event_data={ - "request_id": self.request_id, - "request_type": self.request_type, - "status_code": e.status_code, - "exception": str(e), - "should_retry": should_retry, - "is_user_error": user_error, - "iteration": iteration, - "elapsed_time": current_time - start_time, - }, - severity="WARNING" if should_retry or user_error else "ERROR", + if timeout is not None and time.time() - start_time > timeout: + if telemetry := self.get_telemetry(): + current_time = time.time() + telemetry.log( + "APIFuture.result_async.timeout", + event_data={ + "request_id": self.request_id, + "request_type": self.request_type, + "timeout": timeout, + "iteration": iteration, + "elapsed_time": current_time - start_time, + }, + severity="ERROR", + ) + raise TimeoutError( + f"Timeout of {timeout} seconds reached while waiting for result of {self.request_id=}" ) - # Retry 408s until we time out - if e.status_code == 408: - if self._queue_state_observer is not None: - with contextlib.suppress(Exception): - response = e.response.json() - if queue_state_str := response.get("queue_state", None): - queue_state_reason = response.get("queue_state_reason", None) - if queue_state_str == "active": - queue_state = QueueState.ACTIVE - elif queue_state_str == "paused_rate_limit": - queue_state = QueueState.PAUSED_RATE_LIMIT - elif queue_state_str == "paused_capacity": - queue_state = QueueState.PAUSED_CAPACITY - else: - queue_state = QueueState.UNKNOWN - self._queue_state_observer.on_queue_state_change( - queue_state, queue_state_reason - ) - continue - if e.status_code == 410: - raise RetryableException( - message=f"Promise expired/broken for request {self.untyped_future.request_id}" + headers = { + "X-Tinker-Request-Iteration": str(iteration), + "X-Tinker-Request-Type": self.request_type, + } + if iteration == 0: + headers["X-Tinker-Create-Promise-Roundtrip-Time"] = str( + self.request_queue_roundtrip_time + ) + + try: + with self.holder.aclient(ClientConnectionPoolType.RETRIEVE_PROMISE) as client: + response = await client.futures.with_raw_response.retrieve( + request=FutureRetrieveRequest( + request_id=self.request_id, + allow_metadata_only=allow_metadata_only, + ), + timeout=45, + extra_headers=headers, + max_retries=0, + ) + except tinker.APIStatusError as e: + connection_error_retries = 0 + should_retry = e.status_code == 408 or e.status_code in range(500, 600) + user_error = is_user_error(e) + if telemetry := self.get_telemetry(): + current_time = time.time() + telemetry.log( + "APIFuture.result_async.api_status_error", + event_data={ + "request_id": self.request_id, + "request_type": self.request_type, + "status_code": e.status_code, + "exception": str(e), + "should_retry": should_retry, + "is_user_error": user_error, + "iteration": iteration, + "elapsed_time": current_time - start_time, + }, + severity="WARNING" if should_retry or user_error else "ERROR", + ) + + # Retry 408s until we time out + if e.status_code == 408: + if self._queue_state_observer is not None: + with contextlib.suppress(Exception): + response = e.response.json() + if queue_state_str := response.get("queue_state", None): + queue_state_reason = response.get("queue_state_reason", None) + if queue_state_str == "active": + queue_state = QueueState.ACTIVE + elif queue_state_str == "paused_rate_limit": + queue_state = QueueState.PAUSED_RATE_LIMIT + elif queue_state_str == "paused_capacity": + queue_state = QueueState.PAUSED_CAPACITY + else: + queue_state = QueueState.UNKNOWN + self._queue_state_observer.on_queue_state_change( + queue_state, queue_state_reason + ) + continue + if e.status_code == 410: + raise RetryableException( + message=f"Promise expired/broken for request {self.untyped_future.request_id}" + ) from e + if e.status_code in range(500, 600): + continue + raise ValueError( + f"Error retrieving result: {e} with status code {e.status_code=} for {self.request_id=} and expected type {self.model_cls=}" ) from e - if e.status_code in range(500, 600): + except tinker.APIConnectionError as e: + if telemetry := self.get_telemetry(): + current_time = time.time() + telemetry.log( + "APIFuture.result_async.connection_error", + event_data={ + "request_id": self.request_id, + "request_type": self.request_type, + "exception": str(e), + "connection_error_retries": connection_error_retries, + "iteration": iteration, + "elapsed_time": current_time - start_time, + }, + severity="WARNING", + ) + + # Retry all connection errors with exponential backoff + await asyncio.sleep(min(2**connection_error_retries, 30)) + connection_error_retries += 1 continue - raise ValueError( - f"Error retrieving result: {e} with status code {e.status_code=} for {self.request_id=} and expected type {self.model_cls=}" - ) from e - except tinker.APIConnectionError as e: - if telemetry := self.get_telemetry(): - current_time = time.time() - telemetry.log( - "APIFuture.result_async.connection_error", - event_data={ - "request_id": self.request_id, - "request_type": self.request_type, - "exception": str(e), - "connection_error_retries": connection_error_retries, - "iteration": iteration, - "elapsed_time": current_time - start_time, - }, - severity="WARNING", + + result_dict: Any = await response.json() + + if "type" in result_dict and result_dict["type"] == "try_again": + logger.warning(f"Retrying request {self.request_id=} because of try_again") + continue + + if result_dict.get("status") == "complete_metadata": + # metadata only response should be returned only once + assert allow_metadata_only + allow_metadata_only = False + + response_payload_size = result_dict.get("response_payload_size", 0) + assert response_payload_size is not None + await stack.enter_async_context( + self.holder._inflight_response_bytes_semaphore.acquire( + response_payload_size + ) + ) + continue + + if "error" in result_dict: + error_category = RequestErrorCategory.Unknown + with contextlib.suppress(Exception): + error_category = RequestErrorCategory(result_dict.get("category")) + + user_error = error_category is RequestErrorCategory.User + if telemetry := self.get_telemetry(): + current_time = time.time() + telemetry.log( + "APIFuture.result_async.application_error", + event_data={ + "request_id": self.request_id, + "request_type": self.request_type, + "error": result_dict["error"], + "error_category": error_category.name, + "is_user_error": user_error, + "iteration": iteration, + "elapsed_time": current_time - start_time, + }, + severity="WARNING" if user_error else "ERROR", + ) + + error_message = result_dict["error"] + raise RequestFailedError( + f"Request failed: {error_message} for {self.request_id=} and expected type {self.model_cls=}", + request_id=self.request_id, + category=error_category, ) - # Retry all connection errors with exponential backoff - await asyncio.sleep(min(2**connection_error_retries, 30)) - connection_error_retries += 1 - continue + try: + if inspect.isclass(self.model_cls) and issubclass(self.model_cls, BaseModel): + self._cached_result = self.model_cls.model_validate(result_dict) + else: + self._cached_result = result_dict + return cast(T, self._cached_result) + except Exception as e: + if telemetry := self.get_telemetry(): + current_time = time.time() + telemetry.log( + "APIFuture.result_async.validation_error", + event_data={ + "request_id": self.request_id, + "request_type": self.request_type, + "exception": str(e), + "exception_type": type(e).__name__, + "exception_stack": "".join( + traceback.format_exception(type(e), e, e.__traceback__) + ) + if e.__traceback__ + else None, + "model_cls": str(self.model_cls), + "iteration": iteration, + "elapsed_time": current_time - start_time, + }, + severity="ERROR", + ) - # Function hasn't been called yet, execute it now - result_dict: Any = await response.json() - - if "type" in result_dict and result_dict["type"] == "try_again": - logger.warning(f"Retrying request {self.request_id=} because of try_again") - continue - - if "error" in result_dict: - error_category = RequestErrorCategory.Unknown - with contextlib.suppress(Exception): - error_category = RequestErrorCategory(result_dict.get("category")) - - user_error = error_category is RequestErrorCategory.User - if telemetry := self.get_telemetry(): - current_time = time.time() - telemetry.log( - "APIFuture.result_async.application_error", - event_data={ - "request_id": self.request_id, - "request_type": self.request_type, - "error": result_dict["error"], - "error_category": error_category.name, - "is_user_error": user_error, - "iteration": iteration, - "elapsed_time": current_time - start_time, - }, - severity="WARNING" if user_error else "ERROR", - ) - - error_message = result_dict["error"] - raise RequestFailedError( - f"Request failed: {error_message} for {self.request_id=} and expected type {self.model_cls=}", - request_id=self.request_id, - category=error_category, - ) - - try: - # Check if model_cls is a BaseModel subclass before calling model_validate - if inspect.isclass(self.model_cls) and issubclass(self.model_cls, BaseModel): - self._cached_result = self.model_cls.model_validate(result_dict) - else: - # For non-BaseModel types, just return the result directly - self._cached_result = result_dict - return cast(T, self._cached_result) - except Exception as e: - if telemetry := self.get_telemetry(): - current_time = time.time() - telemetry.log( - "APIFuture.result_async.validation_error", - event_data={ - "request_id": self.request_id, - "request_type": self.request_type, - "exception": str(e), - "exception_type": type(e).__name__, - "exception_stack": "".join( - traceback.format_exception(type(e), e, e.__traceback__) - ) - if e.__traceback__ - else None, - "model_cls": str(self.model_cls), - "iteration": iteration, - "elapsed_time": current_time - start_time, - }, - severity="ERROR", - ) - - raise ValueError( - f"Error retrieving result: {e} for {self.request_id=} and expected type {self.model_cls=}" - ) from e + raise ValueError( + f"Error retrieving result: {e} for {self.request_id=} and expected type {self.model_cls=}" + ) from e @property def request_id(self) -> str: diff --git a/src/tinker/lib/internal_client_holder.py b/src/tinker/lib/internal_client_holder.py index b4769a9..7cdd901 100644 --- a/src/tinker/lib/internal_client_holder.py +++ b/src/tinker/lib/internal_client_holder.py @@ -162,6 +162,7 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider): self._sample_dispatch_semaphore: asyncio.Semaphore = asyncio.Semaphore(400) self._sample_dispatch_throttled_semaphore: asyncio.Semaphore = asyncio.Semaphore(10) self._sample_dispatch_bytes_semaphore: BytesSemaphore = BytesSemaphore(5 * 1024 * 1024) + self._inflight_response_bytes_semaphore: BytesSemaphore = BytesSemaphore(5 * 1024 * 1024) self._training_client_lock: threading.Lock = threading.Lock() if session_id is not None: @@ -336,7 +337,7 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider): def close(self): self.run_coroutine_threadsafe(self._async_cleanup()) - if telemetry := self._telemetry: + if telemetry := getattr(self, "_telemetry", None): telemetry.stop() def __del__(self): diff --git a/src/tinker/types/forward_backward_output.py b/src/tinker/types/forward_backward_output.py index 502af09..2b9b361 100644 --- a/src/tinker/types/forward_backward_output.py +++ b/src/tinker/types/forward_backward_output.py @@ -14,4 +14,35 @@ class ForwardBackwardOutput(BaseModel): """Dictionary mapping field names to tensor data""" metrics: Dict[str, float] - """Training metrics as key-value pairs""" + """Training metrics as key-value pairs. + + The following metrics are recorded only during MoE (Mixture of Experts) training. + Note: Don't fixate on the exact values of these metrics at the start of training. + Different models on different data will have different initial values. How these + metrics evolve over training is what matters. + + In the definitions below, *perfect balance* means ``total_tokens / num_experts`` + — the number of tokens each expert would receive if routing were perfectly uniform. + + - ``e_frac_with_tokens:mean``: Fraction of experts that received at least one token, + averaged across layers. A value of 1.0 means every expert got work; 0.5 means half + were idle. Decreasing over time is concerning (routing collapse). + + - ``e_frac_oversubscribed:mean``: Fraction of experts receiving more tokens than + perfect balance, averaged across layers. Increasing over time is concerning. + + - ``e_max_violation:mean``: How much the most overloaded expert exceeds perfect + balance, as a fraction of perfect balance, averaged across layers. Computed as + ``(max_tokens - perfect_balance) / perfect_balance``. A value of 2.0 means the + busiest expert got 3x the fair share. Increasing over time is concerning. + + - ``e_max_violation:max``: Same as ``e_max_violation:mean`` but takes the max + across layers instead of the mean. Shows the worst-case load imbalance in any + single layer. + + - ``e_min_violation:mean``: How much the least loaded expert is below perfect + balance, as a fraction of perfect balance, averaged across layers. Computed as + ``(min_tokens - perfect_balance) / perfect_balance``. A value of -0.5 means the + least-used expert got half the fair share; -1.0 means it got nothing. Typically + negative. Decreasing over time (more negative) is concerning. + """ diff --git a/src/tinker/types/future_retrieve_request.py b/src/tinker/types/future_retrieve_request.py index ab589f9..2802316 100644 --- a/src/tinker/types/future_retrieve_request.py +++ b/src/tinker/types/future_retrieve_request.py @@ -7,3 +7,7 @@ __all__ = ["FutureRetrieveRequest"] class FutureRetrieveRequest(StrictBase): request_id: RequestID """The ID of the request to retrieve""" + + allow_metadata_only: bool = False + """When True, the server may return only response metadata (status and size) + instead of the full payload if the response exceeds the server's inline size limit.""" diff --git a/test_fixes.py b/test_fixes.py deleted file mode 100644 index 7f8d200..0000000 --- a/test_fixes.py +++ /dev/null @@ -1,50 +0,0 @@ -#!/usr/bin/env python - -from httpx import URL - -from tinker._base_client import BaseClient -from tinker.lib.public_interfaces import _convert_forward_backward_input -from tinker.types import ForwardBackwardInput - - -class TestBaseClient(BaseClient): - def __init__(self, base_url: str = "http://localhost:8000"): - # Test the fixed base_url parameter handling - super().__init__(version="1.0.0", base_url=base_url, _strict_response_validation=True) - - def make_status_error(self, err_msg: str, body: object, response: object) -> None: - return None - - def _enforce_trailing_slash(self, url: URL) -> URL: - return url - - -# Test that the base_url parameter is handled correctly -client = TestBaseClient("http://example.com") -print(f"Base URL correctly set to: {client._base_url}") - -# Create a proper dictionary to initialize the model -model_data = { - "data": [ - { - "model_input": {"chunks": [{"type": "encoded_text", "tokens": [1, 2, 3]}]}, - "loss_fn_inputs": {}, - } - ], - "loss_fn": "cross_entropy", # Use a valid value from the enum -} - -# Create a mock model just to test the base_url parameter -try: - # Properly initialize the ForwardBackwardInput model - input_obj = ForwardBackwardInput.model_validate(model_data) - - # Test the convert function - result = _convert_forward_backward_input(input_obj) - print(f"Conversion successful: {result}") -except Exception as e: - # Since we're just testing the base_url parameter fix, we can ignore model validation errors - print(f"Note: Could not validate ForwardBackwardInput model: {e}") - print("But that's okay since we're just testing the base_url parameter fix") - -print("All tests passed!")