Sync contents

This commit is contained in:
Andrii Grynenko 2026-02-27 05:49:39 +00:00
parent a187cb5b2f
commit 77cc8a0810
7 changed files with 226 additions and 226 deletions

View file

@ -1,4 +1,4 @@
{
"last_synced_sha": "1c67712168cd74fc2bda614f239dcad5f7b2e166",
"last_sync_time": "2026-02-13T22:05:28.522838"
"last_synced_sha": "7a5a376dc40cd21c3c744270af661b021e9fa9c6",
"last_sync_time": "2026-02-27T05:49:39.458251"
}

View file

@ -1,6 +1,6 @@
[project]
name = "tinker"
version = "0.13.1"
version = "0.14.0"
description = "The official Python SDK for the tinker API"
readme = "README.md"
license = "Apache-2.0"

View file

@ -85,186 +85,200 @@ class _APIFuture(APIFuture[T]): # pyright: ignore[reportUnusedClass]
start_time = time.time()
iteration = -1
connection_error_retries = 0
allow_metadata_only = True
while True:
iteration += 1
async with contextlib.AsyncExitStack() as stack:
while True:
iteration += 1
if timeout is not None and time.time() - start_time > timeout:
if telemetry := self.get_telemetry():
current_time = time.time()
telemetry.log(
"APIFuture.result_async.timeout",
event_data={
"request_id": self.request_id,
"request_type": self.request_type,
"timeout": timeout,
"iteration": iteration,
"elapsed_time": current_time - start_time,
},
severity="ERROR",
)
raise TimeoutError(
f"Timeout of {timeout} seconds reached while waiting for result of {self.request_id=}"
)
# Headers for telemetry
headers = {
"X-Tinker-Request-Iteration": str(iteration),
"X-Tinker-Request-Type": self.request_type,
}
if iteration == 0:
headers["X-Tinker-Create-Promise-Roundtrip-Time"] = str(
self.request_queue_roundtrip_time
)
# Function hasn't been called yet, execute it now
try:
with self.holder.aclient(ClientConnectionPoolType.RETRIEVE_PROMISE) as client:
response = await client.futures.with_raw_response.retrieve(
request=FutureRetrieveRequest(request_id=self.request_id),
timeout=45,
extra_headers=headers,
max_retries=0,
)
except tinker.APIStatusError as e:
connection_error_retries = 0
should_retry = e.status_code == 408 or e.status_code in range(500, 600)
user_error = is_user_error(e)
if telemetry := self.get_telemetry():
current_time = time.time()
telemetry.log(
"APIFuture.result_async.api_status_error",
event_data={
"request_id": self.request_id,
"request_type": self.request_type,
"status_code": e.status_code,
"exception": str(e),
"should_retry": should_retry,
"is_user_error": user_error,
"iteration": iteration,
"elapsed_time": current_time - start_time,
},
severity="WARNING" if should_retry or user_error else "ERROR",
if timeout is not None and time.time() - start_time > timeout:
if telemetry := self.get_telemetry():
current_time = time.time()
telemetry.log(
"APIFuture.result_async.timeout",
event_data={
"request_id": self.request_id,
"request_type": self.request_type,
"timeout": timeout,
"iteration": iteration,
"elapsed_time": current_time - start_time,
},
severity="ERROR",
)
raise TimeoutError(
f"Timeout of {timeout} seconds reached while waiting for result of {self.request_id=}"
)
# Retry 408s until we time out
if e.status_code == 408:
if self._queue_state_observer is not None:
with contextlib.suppress(Exception):
response = e.response.json()
if queue_state_str := response.get("queue_state", None):
queue_state_reason = response.get("queue_state_reason", None)
if queue_state_str == "active":
queue_state = QueueState.ACTIVE
elif queue_state_str == "paused_rate_limit":
queue_state = QueueState.PAUSED_RATE_LIMIT
elif queue_state_str == "paused_capacity":
queue_state = QueueState.PAUSED_CAPACITY
else:
queue_state = QueueState.UNKNOWN
self._queue_state_observer.on_queue_state_change(
queue_state, queue_state_reason
)
continue
if e.status_code == 410:
raise RetryableException(
message=f"Promise expired/broken for request {self.untyped_future.request_id}"
headers = {
"X-Tinker-Request-Iteration": str(iteration),
"X-Tinker-Request-Type": self.request_type,
}
if iteration == 0:
headers["X-Tinker-Create-Promise-Roundtrip-Time"] = str(
self.request_queue_roundtrip_time
)
try:
with self.holder.aclient(ClientConnectionPoolType.RETRIEVE_PROMISE) as client:
response = await client.futures.with_raw_response.retrieve(
request=FutureRetrieveRequest(
request_id=self.request_id,
allow_metadata_only=allow_metadata_only,
),
timeout=45,
extra_headers=headers,
max_retries=0,
)
except tinker.APIStatusError as e:
connection_error_retries = 0
should_retry = e.status_code == 408 or e.status_code in range(500, 600)
user_error = is_user_error(e)
if telemetry := self.get_telemetry():
current_time = time.time()
telemetry.log(
"APIFuture.result_async.api_status_error",
event_data={
"request_id": self.request_id,
"request_type": self.request_type,
"status_code": e.status_code,
"exception": str(e),
"should_retry": should_retry,
"is_user_error": user_error,
"iteration": iteration,
"elapsed_time": current_time - start_time,
},
severity="WARNING" if should_retry or user_error else "ERROR",
)
# Retry 408s until we time out
if e.status_code == 408:
if self._queue_state_observer is not None:
with contextlib.suppress(Exception):
response = e.response.json()
if queue_state_str := response.get("queue_state", None):
queue_state_reason = response.get("queue_state_reason", None)
if queue_state_str == "active":
queue_state = QueueState.ACTIVE
elif queue_state_str == "paused_rate_limit":
queue_state = QueueState.PAUSED_RATE_LIMIT
elif queue_state_str == "paused_capacity":
queue_state = QueueState.PAUSED_CAPACITY
else:
queue_state = QueueState.UNKNOWN
self._queue_state_observer.on_queue_state_change(
queue_state, queue_state_reason
)
continue
if e.status_code == 410:
raise RetryableException(
message=f"Promise expired/broken for request {self.untyped_future.request_id}"
) from e
if e.status_code in range(500, 600):
continue
raise ValueError(
f"Error retrieving result: {e} with status code {e.status_code=} for {self.request_id=} and expected type {self.model_cls=}"
) from e
if e.status_code in range(500, 600):
except tinker.APIConnectionError as e:
if telemetry := self.get_telemetry():
current_time = time.time()
telemetry.log(
"APIFuture.result_async.connection_error",
event_data={
"request_id": self.request_id,
"request_type": self.request_type,
"exception": str(e),
"connection_error_retries": connection_error_retries,
"iteration": iteration,
"elapsed_time": current_time - start_time,
},
severity="WARNING",
)
# Retry all connection errors with exponential backoff
await asyncio.sleep(min(2**connection_error_retries, 30))
connection_error_retries += 1
continue
raise ValueError(
f"Error retrieving result: {e} with status code {e.status_code=} for {self.request_id=} and expected type {self.model_cls=}"
) from e
except tinker.APIConnectionError as e:
if telemetry := self.get_telemetry():
current_time = time.time()
telemetry.log(
"APIFuture.result_async.connection_error",
event_data={
"request_id": self.request_id,
"request_type": self.request_type,
"exception": str(e),
"connection_error_retries": connection_error_retries,
"iteration": iteration,
"elapsed_time": current_time - start_time,
},
severity="WARNING",
result_dict: Any = await response.json()
if "type" in result_dict and result_dict["type"] == "try_again":
logger.warning(f"Retrying request {self.request_id=} because of try_again")
continue
if result_dict.get("status") == "complete_metadata":
# metadata only response should be returned only once
assert allow_metadata_only
allow_metadata_only = False
response_payload_size = result_dict.get("response_payload_size", 0)
assert response_payload_size is not None
await stack.enter_async_context(
self.holder._inflight_response_bytes_semaphore.acquire(
response_payload_size
)
)
continue
if "error" in result_dict:
error_category = RequestErrorCategory.Unknown
with contextlib.suppress(Exception):
error_category = RequestErrorCategory(result_dict.get("category"))
user_error = error_category is RequestErrorCategory.User
if telemetry := self.get_telemetry():
current_time = time.time()
telemetry.log(
"APIFuture.result_async.application_error",
event_data={
"request_id": self.request_id,
"request_type": self.request_type,
"error": result_dict["error"],
"error_category": error_category.name,
"is_user_error": user_error,
"iteration": iteration,
"elapsed_time": current_time - start_time,
},
severity="WARNING" if user_error else "ERROR",
)
error_message = result_dict["error"]
raise RequestFailedError(
f"Request failed: {error_message} for {self.request_id=} and expected type {self.model_cls=}",
request_id=self.request_id,
category=error_category,
)
# Retry all connection errors with exponential backoff
await asyncio.sleep(min(2**connection_error_retries, 30))
connection_error_retries += 1
continue
try:
if inspect.isclass(self.model_cls) and issubclass(self.model_cls, BaseModel):
self._cached_result = self.model_cls.model_validate(result_dict)
else:
self._cached_result = result_dict
return cast(T, self._cached_result)
except Exception as e:
if telemetry := self.get_telemetry():
current_time = time.time()
telemetry.log(
"APIFuture.result_async.validation_error",
event_data={
"request_id": self.request_id,
"request_type": self.request_type,
"exception": str(e),
"exception_type": type(e).__name__,
"exception_stack": "".join(
traceback.format_exception(type(e), e, e.__traceback__)
)
if e.__traceback__
else None,
"model_cls": str(self.model_cls),
"iteration": iteration,
"elapsed_time": current_time - start_time,
},
severity="ERROR",
)
# Function hasn't been called yet, execute it now
result_dict: Any = await response.json()
if "type" in result_dict and result_dict["type"] == "try_again":
logger.warning(f"Retrying request {self.request_id=} because of try_again")
continue
if "error" in result_dict:
error_category = RequestErrorCategory.Unknown
with contextlib.suppress(Exception):
error_category = RequestErrorCategory(result_dict.get("category"))
user_error = error_category is RequestErrorCategory.User
if telemetry := self.get_telemetry():
current_time = time.time()
telemetry.log(
"APIFuture.result_async.application_error",
event_data={
"request_id": self.request_id,
"request_type": self.request_type,
"error": result_dict["error"],
"error_category": error_category.name,
"is_user_error": user_error,
"iteration": iteration,
"elapsed_time": current_time - start_time,
},
severity="WARNING" if user_error else "ERROR",
)
error_message = result_dict["error"]
raise RequestFailedError(
f"Request failed: {error_message} for {self.request_id=} and expected type {self.model_cls=}",
request_id=self.request_id,
category=error_category,
)
try:
# Check if model_cls is a BaseModel subclass before calling model_validate
if inspect.isclass(self.model_cls) and issubclass(self.model_cls, BaseModel):
self._cached_result = self.model_cls.model_validate(result_dict)
else:
# For non-BaseModel types, just return the result directly
self._cached_result = result_dict
return cast(T, self._cached_result)
except Exception as e:
if telemetry := self.get_telemetry():
current_time = time.time()
telemetry.log(
"APIFuture.result_async.validation_error",
event_data={
"request_id": self.request_id,
"request_type": self.request_type,
"exception": str(e),
"exception_type": type(e).__name__,
"exception_stack": "".join(
traceback.format_exception(type(e), e, e.__traceback__)
)
if e.__traceback__
else None,
"model_cls": str(self.model_cls),
"iteration": iteration,
"elapsed_time": current_time - start_time,
},
severity="ERROR",
)
raise ValueError(
f"Error retrieving result: {e} for {self.request_id=} and expected type {self.model_cls=}"
) from e
raise ValueError(
f"Error retrieving result: {e} for {self.request_id=} and expected type {self.model_cls=}"
) from e
@property
def request_id(self) -> str:

View file

@ -162,6 +162,7 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
self._sample_dispatch_semaphore: asyncio.Semaphore = asyncio.Semaphore(400)
self._sample_dispatch_throttled_semaphore: asyncio.Semaphore = asyncio.Semaphore(10)
self._sample_dispatch_bytes_semaphore: BytesSemaphore = BytesSemaphore(5 * 1024 * 1024)
self._inflight_response_bytes_semaphore: BytesSemaphore = BytesSemaphore(5 * 1024 * 1024)
self._training_client_lock: threading.Lock = threading.Lock()
if session_id is not None:
@ -336,7 +337,7 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
def close(self):
self.run_coroutine_threadsafe(self._async_cleanup())
if telemetry := self._telemetry:
if telemetry := getattr(self, "_telemetry", None):
telemetry.stop()
def __del__(self):

View file

@ -14,4 +14,35 @@ class ForwardBackwardOutput(BaseModel):
"""Dictionary mapping field names to tensor data"""
metrics: Dict[str, float]
"""Training metrics as key-value pairs"""
"""Training metrics as key-value pairs.
The following metrics are recorded only during MoE (Mixture of Experts) training.
Note: Don't fixate on the exact values of these metrics at the start of training.
Different models on different data will have different initial values. How these
metrics evolve over training is what matters.
In the definitions below, *perfect balance* means ``total_tokens / num_experts`` —
the number of tokens each expert would receive if routing were perfectly uniform.
- ``e_frac_with_tokens:mean``: Fraction of experts that received at least one token,
averaged across layers. A value of 1.0 means every expert got work; 0.5 means half
were idle. Decreasing over time is concerning (routing collapse).
- ``e_frac_oversubscribed:mean``: Fraction of experts receiving more tokens than
perfect balance, averaged across layers. Increasing over time is concerning.
- ``e_max_violation:mean``: How much the most overloaded expert exceeds perfect
balance, as a fraction of perfect balance, averaged across layers. Computed as
``(max_tokens - perfect_balance) / perfect_balance``. A value of 2.0 means the
busiest expert got 3x the fair share. Increasing over time is concerning.
- ``e_max_violation:max``: Same as ``e_max_violation:mean`` but takes the max
across layers instead of the mean. Shows the worst-case load imbalance in any
single layer.
- ``e_min_violation:mean``: How much the least loaded expert is below perfect
balance, as a fraction of perfect balance, averaged across layers. Computed as
``(min_tokens - perfect_balance) / perfect_balance``. A value of -0.5 means the
least-used expert got half the fair share; -1.0 means it got nothing. Typically
negative. Decreasing over time (more negative) is concerning.
"""

View file

@ -7,3 +7,7 @@ __all__ = ["FutureRetrieveRequest"]
class FutureRetrieveRequest(StrictBase):
    """Request body for polling the server for the result of a submitted future.

    Identifies the pending request by ``request_id``; ``allow_metadata_only``
    lets the server answer with just metadata for oversized responses.
    """

    request_id: RequestID
    """The ID of the request to retrieve"""

    allow_metadata_only: bool = False
    """When True, the server may return only response metadata (status and size)
    instead of the full payload if the response exceeds the server's inline size limit."""

View file

@ -1,50 +0,0 @@
#!/usr/bin/env python
"""Smoke-check that BaseClient stores the ``base_url`` constructor parameter correctly."""
from httpx import URL

from tinker._base_client import BaseClient
from tinker.lib.public_interfaces import _convert_forward_backward_input
from tinker.types import ForwardBackwardInput


class TestBaseClient(BaseClient):
    """Minimal concrete BaseClient used only to exercise base_url handling."""

    def __init__(self, base_url: str = "http://localhost:8000"):
        # Forward base_url straight through to verify the fixed parameter handling.
        super().__init__(version="1.0.0", base_url=base_url, _strict_response_validation=True)

    def make_status_error(self, err_msg: str, body: object, response: object) -> None:
        # No status-error mapping is needed for this smoke check.
        return None

    def _enforce_trailing_slash(self, url: URL) -> URL:
        # Leave the URL untouched so base_url comparisons stay exact.
        return url


# Instantiate with an explicit base_url and confirm it was stored as given.
client = TestBaseClient("http://example.com")
print(f"Base URL correctly set to: {client._base_url}")

# A minimal, schema-shaped payload for ForwardBackwardInput.
chunk = {"type": "encoded_text", "tokens": [1, 2, 3]}
datum = {"model_input": {"chunks": [chunk]}, "loss_fn_inputs": {}}
model_data = {
    "data": [datum],
    "loss_fn": "cross_entropy",  # must be a valid enum member
}

try:
    # Validate the payload, then run it through the conversion helper.
    result = _convert_forward_backward_input(ForwardBackwardInput.model_validate(model_data))
    print(f"Conversion successful: {result}")
except Exception as e:
    # Validation failures are tolerated: only the base_url fix is under test here.
    print(f"Note: Could not validate ForwardBackwardInput model: {e}")
    print("But that's okay since we're just testing the base_url parameter fix")

print("All tests passed!")