Sync contents

This commit is contained in:
Andrii Grynenko 2026-02-27 05:49:39 +00:00
parent a187cb5b2f
commit 77cc8a0810
7 changed files with 226 additions and 226 deletions

View file

@ -1,4 +1,4 @@
{ {
"last_synced_sha": "1c67712168cd74fc2bda614f239dcad5f7b2e166", "last_synced_sha": "7a5a376dc40cd21c3c744270af661b021e9fa9c6",
"last_sync_time": "2026-02-13T22:05:28.522838" "last_sync_time": "2026-02-27T05:49:39.458251"
} }

View file

@ -1,6 +1,6 @@
[project] [project]
name = "tinker" name = "tinker"
version = "0.13.1" version = "0.14.0"
description = "The official Python SDK for the tinker API" description = "The official Python SDK for the tinker API"
readme = "README.md" readme = "README.md"
license = "Apache-2.0" license = "Apache-2.0"

View file

@ -85,186 +85,200 @@ class _APIFuture(APIFuture[T]): # pyright: ignore[reportUnusedClass]
start_time = time.time() start_time = time.time()
iteration = -1 iteration = -1
connection_error_retries = 0 connection_error_retries = 0
allow_metadata_only = True
while True: async with contextlib.AsyncExitStack() as stack:
iteration += 1 while True:
iteration += 1
if timeout is not None and time.time() - start_time > timeout: if timeout is not None and time.time() - start_time > timeout:
if telemetry := self.get_telemetry(): if telemetry := self.get_telemetry():
current_time = time.time() current_time = time.time()
telemetry.log( telemetry.log(
"APIFuture.result_async.timeout", "APIFuture.result_async.timeout",
event_data={ event_data={
"request_id": self.request_id, "request_id": self.request_id,
"request_type": self.request_type, "request_type": self.request_type,
"timeout": timeout, "timeout": timeout,
"iteration": iteration, "iteration": iteration,
"elapsed_time": current_time - start_time, "elapsed_time": current_time - start_time,
}, },
severity="ERROR", severity="ERROR",
) )
raise TimeoutError( raise TimeoutError(
f"Timeout of {timeout} seconds reached while waiting for result of {self.request_id=}" f"Timeout of {timeout} seconds reached while waiting for result of {self.request_id=}"
)
# Headers for telemetry
headers = {
"X-Tinker-Request-Iteration": str(iteration),
"X-Tinker-Request-Type": self.request_type,
}
if iteration == 0:
headers["X-Tinker-Create-Promise-Roundtrip-Time"] = str(
self.request_queue_roundtrip_time
)
# Function hasn't been called yet, execute it now
try:
with self.holder.aclient(ClientConnectionPoolType.RETRIEVE_PROMISE) as client:
response = await client.futures.with_raw_response.retrieve(
request=FutureRetrieveRequest(request_id=self.request_id),
timeout=45,
extra_headers=headers,
max_retries=0,
)
except tinker.APIStatusError as e:
connection_error_retries = 0
should_retry = e.status_code == 408 or e.status_code in range(500, 600)
user_error = is_user_error(e)
if telemetry := self.get_telemetry():
current_time = time.time()
telemetry.log(
"APIFuture.result_async.api_status_error",
event_data={
"request_id": self.request_id,
"request_type": self.request_type,
"status_code": e.status_code,
"exception": str(e),
"should_retry": should_retry,
"is_user_error": user_error,
"iteration": iteration,
"elapsed_time": current_time - start_time,
},
severity="WARNING" if should_retry or user_error else "ERROR",
) )
# Retry 408s until we time out headers = {
if e.status_code == 408: "X-Tinker-Request-Iteration": str(iteration),
if self._queue_state_observer is not None: "X-Tinker-Request-Type": self.request_type,
with contextlib.suppress(Exception): }
response = e.response.json() if iteration == 0:
if queue_state_str := response.get("queue_state", None): headers["X-Tinker-Create-Promise-Roundtrip-Time"] = str(
queue_state_reason = response.get("queue_state_reason", None) self.request_queue_roundtrip_time
if queue_state_str == "active": )
queue_state = QueueState.ACTIVE
elif queue_state_str == "paused_rate_limit": try:
queue_state = QueueState.PAUSED_RATE_LIMIT with self.holder.aclient(ClientConnectionPoolType.RETRIEVE_PROMISE) as client:
elif queue_state_str == "paused_capacity": response = await client.futures.with_raw_response.retrieve(
queue_state = QueueState.PAUSED_CAPACITY request=FutureRetrieveRequest(
else: request_id=self.request_id,
queue_state = QueueState.UNKNOWN allow_metadata_only=allow_metadata_only,
self._queue_state_observer.on_queue_state_change( ),
queue_state, queue_state_reason timeout=45,
) extra_headers=headers,
continue max_retries=0,
if e.status_code == 410: )
raise RetryableException( except tinker.APIStatusError as e:
message=f"Promise expired/broken for request {self.untyped_future.request_id}" connection_error_retries = 0
should_retry = e.status_code == 408 or e.status_code in range(500, 600)
user_error = is_user_error(e)
if telemetry := self.get_telemetry():
current_time = time.time()
telemetry.log(
"APIFuture.result_async.api_status_error",
event_data={
"request_id": self.request_id,
"request_type": self.request_type,
"status_code": e.status_code,
"exception": str(e),
"should_retry": should_retry,
"is_user_error": user_error,
"iteration": iteration,
"elapsed_time": current_time - start_time,
},
severity="WARNING" if should_retry or user_error else "ERROR",
)
# Retry 408s until we time out
if e.status_code == 408:
if self._queue_state_observer is not None:
with contextlib.suppress(Exception):
response = e.response.json()
if queue_state_str := response.get("queue_state", None):
queue_state_reason = response.get("queue_state_reason", None)
if queue_state_str == "active":
queue_state = QueueState.ACTIVE
elif queue_state_str == "paused_rate_limit":
queue_state = QueueState.PAUSED_RATE_LIMIT
elif queue_state_str == "paused_capacity":
queue_state = QueueState.PAUSED_CAPACITY
else:
queue_state = QueueState.UNKNOWN
self._queue_state_observer.on_queue_state_change(
queue_state, queue_state_reason
)
continue
if e.status_code == 410:
raise RetryableException(
message=f"Promise expired/broken for request {self.untyped_future.request_id}"
) from e
if e.status_code in range(500, 600):
continue
raise ValueError(
f"Error retrieving result: {e} with status code {e.status_code=} for {self.request_id=} and expected type {self.model_cls=}"
) from e ) from e
if e.status_code in range(500, 600): except tinker.APIConnectionError as e:
if telemetry := self.get_telemetry():
current_time = time.time()
telemetry.log(
"APIFuture.result_async.connection_error",
event_data={
"request_id": self.request_id,
"request_type": self.request_type,
"exception": str(e),
"connection_error_retries": connection_error_retries,
"iteration": iteration,
"elapsed_time": current_time - start_time,
},
severity="WARNING",
)
# Retry all connection errors with exponential backoff
await asyncio.sleep(min(2**connection_error_retries, 30))
connection_error_retries += 1
continue continue
raise ValueError(
f"Error retrieving result: {e} with status code {e.status_code=} for {self.request_id=} and expected type {self.model_cls=}" result_dict: Any = await response.json()
) from e
except tinker.APIConnectionError as e: if "type" in result_dict and result_dict["type"] == "try_again":
if telemetry := self.get_telemetry(): logger.warning(f"Retrying request {self.request_id=} because of try_again")
current_time = time.time() continue
telemetry.log(
"APIFuture.result_async.connection_error", if result_dict.get("status") == "complete_metadata":
event_data={ # metadata only response should be returned only once
"request_id": self.request_id, assert allow_metadata_only
"request_type": self.request_type, allow_metadata_only = False
"exception": str(e),
"connection_error_retries": connection_error_retries, response_payload_size = result_dict.get("response_payload_size", 0)
"iteration": iteration, assert response_payload_size is not None
"elapsed_time": current_time - start_time, await stack.enter_async_context(
}, self.holder._inflight_response_bytes_semaphore.acquire(
severity="WARNING", response_payload_size
)
)
continue
if "error" in result_dict:
error_category = RequestErrorCategory.Unknown
with contextlib.suppress(Exception):
error_category = RequestErrorCategory(result_dict.get("category"))
user_error = error_category is RequestErrorCategory.User
if telemetry := self.get_telemetry():
current_time = time.time()
telemetry.log(
"APIFuture.result_async.application_error",
event_data={
"request_id": self.request_id,
"request_type": self.request_type,
"error": result_dict["error"],
"error_category": error_category.name,
"is_user_error": user_error,
"iteration": iteration,
"elapsed_time": current_time - start_time,
},
severity="WARNING" if user_error else "ERROR",
)
error_message = result_dict["error"]
raise RequestFailedError(
f"Request failed: {error_message} for {self.request_id=} and expected type {self.model_cls=}",
request_id=self.request_id,
category=error_category,
) )
# Retry all connection errors with exponential backoff try:
await asyncio.sleep(min(2**connection_error_retries, 30)) if inspect.isclass(self.model_cls) and issubclass(self.model_cls, BaseModel):
connection_error_retries += 1 self._cached_result = self.model_cls.model_validate(result_dict)
continue else:
self._cached_result = result_dict
return cast(T, self._cached_result)
except Exception as e:
if telemetry := self.get_telemetry():
current_time = time.time()
telemetry.log(
"APIFuture.result_async.validation_error",
event_data={
"request_id": self.request_id,
"request_type": self.request_type,
"exception": str(e),
"exception_type": type(e).__name__,
"exception_stack": "".join(
traceback.format_exception(type(e), e, e.__traceback__)
)
if e.__traceback__
else None,
"model_cls": str(self.model_cls),
"iteration": iteration,
"elapsed_time": current_time - start_time,
},
severity="ERROR",
)
# Function hasn't been called yet, execute it now raise ValueError(
result_dict: Any = await response.json() f"Error retrieving result: {e} for {self.request_id=} and expected type {self.model_cls=}"
) from e
if "type" in result_dict and result_dict["type"] == "try_again":
logger.warning(f"Retrying request {self.request_id=} because of try_again")
continue
if "error" in result_dict:
error_category = RequestErrorCategory.Unknown
with contextlib.suppress(Exception):
error_category = RequestErrorCategory(result_dict.get("category"))
user_error = error_category is RequestErrorCategory.User
if telemetry := self.get_telemetry():
current_time = time.time()
telemetry.log(
"APIFuture.result_async.application_error",
event_data={
"request_id": self.request_id,
"request_type": self.request_type,
"error": result_dict["error"],
"error_category": error_category.name,
"is_user_error": user_error,
"iteration": iteration,
"elapsed_time": current_time - start_time,
},
severity="WARNING" if user_error else "ERROR",
)
error_message = result_dict["error"]
raise RequestFailedError(
f"Request failed: {error_message} for {self.request_id=} and expected type {self.model_cls=}",
request_id=self.request_id,
category=error_category,
)
try:
# Check if model_cls is a BaseModel subclass before calling model_validate
if inspect.isclass(self.model_cls) and issubclass(self.model_cls, BaseModel):
self._cached_result = self.model_cls.model_validate(result_dict)
else:
# For non-BaseModel types, just return the result directly
self._cached_result = result_dict
return cast(T, self._cached_result)
except Exception as e:
if telemetry := self.get_telemetry():
current_time = time.time()
telemetry.log(
"APIFuture.result_async.validation_error",
event_data={
"request_id": self.request_id,
"request_type": self.request_type,
"exception": str(e),
"exception_type": type(e).__name__,
"exception_stack": "".join(
traceback.format_exception(type(e), e, e.__traceback__)
)
if e.__traceback__
else None,
"model_cls": str(self.model_cls),
"iteration": iteration,
"elapsed_time": current_time - start_time,
},
severity="ERROR",
)
raise ValueError(
f"Error retrieving result: {e} for {self.request_id=} and expected type {self.model_cls=}"
) from e
@property @property
def request_id(self) -> str: def request_id(self) -> str:

View file

@ -162,6 +162,7 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
self._sample_dispatch_semaphore: asyncio.Semaphore = asyncio.Semaphore(400) self._sample_dispatch_semaphore: asyncio.Semaphore = asyncio.Semaphore(400)
self._sample_dispatch_throttled_semaphore: asyncio.Semaphore = asyncio.Semaphore(10) self._sample_dispatch_throttled_semaphore: asyncio.Semaphore = asyncio.Semaphore(10)
self._sample_dispatch_bytes_semaphore: BytesSemaphore = BytesSemaphore(5 * 1024 * 1024) self._sample_dispatch_bytes_semaphore: BytesSemaphore = BytesSemaphore(5 * 1024 * 1024)
self._inflight_response_bytes_semaphore: BytesSemaphore = BytesSemaphore(5 * 1024 * 1024)
self._training_client_lock: threading.Lock = threading.Lock() self._training_client_lock: threading.Lock = threading.Lock()
if session_id is not None: if session_id is not None:
@ -336,7 +337,7 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
def close(self): def close(self):
self.run_coroutine_threadsafe(self._async_cleanup()) self.run_coroutine_threadsafe(self._async_cleanup())
if telemetry := self._telemetry: if telemetry := getattr(self, "_telemetry", None):
telemetry.stop() telemetry.stop()
def __del__(self): def __del__(self):

View file

@ -14,4 +14,35 @@ class ForwardBackwardOutput(BaseModel):
"""Dictionary mapping field names to tensor data""" """Dictionary mapping field names to tensor data"""
metrics: Dict[str, float] metrics: Dict[str, float]
"""Training metrics as key-value pairs""" """Training metrics as key-value pairs.
The following metrics are recorded only during MoE (Mixture of Experts) training.
Note: Don't fixate on the exact values of these metrics at the start of training.
Different models on different data will have different initial values. How these
metrics evolve over training is what matters.
In the definitions below, *perfect balance* means ``total_tokens / num_experts`` —
the number of tokens each expert would receive if routing were perfectly uniform.
- ``e_frac_with_tokens:mean``: Fraction of experts that received at least one token,
averaged across layers. A value of 1.0 means every expert got work; 0.5 means half
were idle. Decreasing over time is concerning (routing collapse).
- ``e_frac_oversubscribed:mean``: Fraction of experts receiving more tokens than
perfect balance, averaged across layers. Increasing over time is concerning.
- ``e_max_violation:mean``: How much the most overloaded expert exceeds perfect
balance, as a fraction of perfect balance, averaged across layers. Computed as
``(max_tokens - perfect_balance) / perfect_balance``. A value of 2.0 means the
busiest expert got 3x the fair share. Increasing over time is concerning.
- ``e_max_violation:max``: Same as ``e_max_violation:mean`` but takes the max
across layers instead of the mean. Shows the worst-case load imbalance in any
single layer.
- ``e_min_violation:mean``: How much the least loaded expert is below perfect
balance, as a fraction of perfect balance, averaged across layers. Computed as
``(min_tokens - perfect_balance) / perfect_balance``. A value of -0.5 means the
least-used expert got half the fair share; -1.0 means it got nothing. Typically
negative. Decreasing over time (more negative) is concerning.
"""

View file

@ -7,3 +7,7 @@ __all__ = ["FutureRetrieveRequest"]
class FutureRetrieveRequest(StrictBase): class FutureRetrieveRequest(StrictBase):
request_id: RequestID request_id: RequestID
"""The ID of the request to retrieve""" """The ID of the request to retrieve"""
allow_metadata_only: bool = False
"""When True, the server may return only response metadata (status and size)
instead of the full payload if the response exceeds the server's inline size limit."""

View file

@ -1,50 +0,0 @@
#!/usr/bin/env python
from httpx import URL
from tinker._base_client import BaseClient
from tinker.lib.public_interfaces import _convert_forward_backward_input
from tinker.types import ForwardBackwardInput
class TestBaseClient(BaseClient):
    """Bare-bones ``BaseClient`` subclass used by this script to exercise ``base_url`` handling."""

    def __init__(self, base_url: str = "http://localhost:8000"):
        # Forward base_url by keyword so the fixed parameter handling is what gets exercised.
        super().__init__(
            base_url=base_url,
            version="1.0.0",
            _strict_response_validation=True,
        )

    def make_status_error(self, err_msg: str, body: object, response: object) -> None:
        """Stub that builds no error object; every call yields ``None``."""
        return None

    def _enforce_trailing_slash(self, url: URL) -> URL:
        """Hand ``url`` back untouched."""
        # NOTE(review): presumably replaces a BaseClient hook of the same name - confirm.
        return url
# Build a client with an explicit base_url and show what ended up stored on it.
client = TestBaseClient("http://example.com")
print(f"Base URL correctly set to: {client._base_url}")

# Assemble the raw payload for ForwardBackwardInput piece by piece.
encoded_chunk = {"type": "encoded_text", "tokens": [1, 2, 3]}
datum = {
    "model_input": {"chunks": [encoded_chunk]},
    "loss_fn_inputs": {},
}
model_data = {
    "data": [datum],
    "loss_fn": "cross_entropy",  # Use a valid value from the enum
}

# Validation and conversion are best-effort here: the script is only about base_url,
# so any failure below is reported and then deliberately ignored.
try:
    input_obj = ForwardBackwardInput.model_validate(model_data)
    result = _convert_forward_backward_input(input_obj)
    print(f"Conversion successful: {result}")
except Exception as e:
    print(f"Note: Could not validate ForwardBackwardInput model: {e}")
    print("But that's okay since we're just testing the base_url parameter fix")

# Printed unconditionally, whether or not the conversion above succeeded.
print("All tests passed!")