Sync contents

This commit is contained in:
Andrii Grynenko 2026-02-27 05:49:39 +00:00
parent a187cb5b2f
commit 77cc8a0810
7 changed files with 226 additions and 226 deletions

View file

@@ -1,4 +1,4 @@
{ {
"last_synced_sha": "1c67712168cd74fc2bda614f239dcad5f7b2e166", "last_synced_sha": "7a5a376dc40cd21c3c744270af661b021e9fa9c6",
"last_sync_time": "2026-02-13T22:05:28.522838" "last_sync_time": "2026-02-27T05:49:39.458251"
} }

View file

@@ -1,6 +1,6 @@
[project] [project]
name = "tinker" name = "tinker"
version = "0.13.1" version = "0.14.0"
description = "The official Python SDK for the tinker API" description = "The official Python SDK for the tinker API"
readme = "README.md" readme = "README.md"
license = "Apache-2.0" license = "Apache-2.0"

View file

@@ -85,7 +85,9 @@ class _APIFuture(APIFuture[T]): # pyright: ignore[reportUnusedClass]
start_time = time.time() start_time = time.time()
iteration = -1 iteration = -1
connection_error_retries = 0 connection_error_retries = 0
allow_metadata_only = True
async with contextlib.AsyncExitStack() as stack:
while True: while True:
iteration += 1 iteration += 1
@@ -107,7 +109,6 @@ class _APIFuture(APIFuture[T]): # pyright: ignore[reportUnusedClass]
f"Timeout of {timeout} seconds reached while waiting for result of {self.request_id=}" f"Timeout of {timeout} seconds reached while waiting for result of {self.request_id=}"
) )
# Headers for telemetry
headers = { headers = {
"X-Tinker-Request-Iteration": str(iteration), "X-Tinker-Request-Iteration": str(iteration),
"X-Tinker-Request-Type": self.request_type, "X-Tinker-Request-Type": self.request_type,
@@ -117,11 +118,13 @@ class _APIFuture(APIFuture[T]): # pyright: ignore[reportUnusedClass]
self.request_queue_roundtrip_time self.request_queue_roundtrip_time
) )
# Function hasn't been called yet, execute it now
try: try:
with self.holder.aclient(ClientConnectionPoolType.RETRIEVE_PROMISE) as client: with self.holder.aclient(ClientConnectionPoolType.RETRIEVE_PROMISE) as client:
response = await client.futures.with_raw_response.retrieve( response = await client.futures.with_raw_response.retrieve(
request=FutureRetrieveRequest(request_id=self.request_id), request=FutureRetrieveRequest(
request_id=self.request_id,
allow_metadata_only=allow_metadata_only,
),
timeout=45, timeout=45,
extra_headers=headers, extra_headers=headers,
max_retries=0, max_retries=0,
@@ -196,13 +199,26 @@ class _APIFuture(APIFuture[T]): # pyright: ignore[reportUnusedClass]
connection_error_retries += 1 connection_error_retries += 1
continue continue
# Function hasn't been called yet, execute it now
result_dict: Any = await response.json() result_dict: Any = await response.json()
if "type" in result_dict and result_dict["type"] == "try_again": if "type" in result_dict and result_dict["type"] == "try_again":
logger.warning(f"Retrying request {self.request_id=} because of try_again") logger.warning(f"Retrying request {self.request_id=} because of try_again")
continue continue
if result_dict.get("status") == "complete_metadata":
# a metadata-only response should be returned only once
assert allow_metadata_only
allow_metadata_only = False
response_payload_size = result_dict.get("response_payload_size", 0)
assert response_payload_size is not None
await stack.enter_async_context(
self.holder._inflight_response_bytes_semaphore.acquire(
response_payload_size
)
)
continue
if "error" in result_dict: if "error" in result_dict:
error_category = RequestErrorCategory.Unknown error_category = RequestErrorCategory.Unknown
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
@@ -233,11 +249,9 @@ class _APIFuture(APIFuture[T]): # pyright: ignore[reportUnusedClass]
) )
try: try:
# Check if model_cls is a BaseModel subclass before calling model_validate
if inspect.isclass(self.model_cls) and issubclass(self.model_cls, BaseModel): if inspect.isclass(self.model_cls) and issubclass(self.model_cls, BaseModel):
self._cached_result = self.model_cls.model_validate(result_dict) self._cached_result = self.model_cls.model_validate(result_dict)
else: else:
# For non-BaseModel types, just return the result directly
self._cached_result = result_dict self._cached_result = result_dict
return cast(T, self._cached_result) return cast(T, self._cached_result)
except Exception as e: except Exception as e:

View file

@@ -162,6 +162,7 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
self._sample_dispatch_semaphore: asyncio.Semaphore = asyncio.Semaphore(400) self._sample_dispatch_semaphore: asyncio.Semaphore = asyncio.Semaphore(400)
self._sample_dispatch_throttled_semaphore: asyncio.Semaphore = asyncio.Semaphore(10) self._sample_dispatch_throttled_semaphore: asyncio.Semaphore = asyncio.Semaphore(10)
self._sample_dispatch_bytes_semaphore: BytesSemaphore = BytesSemaphore(5 * 1024 * 1024) self._sample_dispatch_bytes_semaphore: BytesSemaphore = BytesSemaphore(5 * 1024 * 1024)
self._inflight_response_bytes_semaphore: BytesSemaphore = BytesSemaphore(5 * 1024 * 1024)
self._training_client_lock: threading.Lock = threading.Lock() self._training_client_lock: threading.Lock = threading.Lock()
if session_id is not None: if session_id is not None:
@@ -336,7 +337,7 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
def close(self): def close(self):
self.run_coroutine_threadsafe(self._async_cleanup()) self.run_coroutine_threadsafe(self._async_cleanup())
if telemetry := self._telemetry: if telemetry := getattr(self, "_telemetry", None):
telemetry.stop() telemetry.stop()
def __del__(self): def __del__(self):

View file

@@ -14,4 +14,35 @@ class ForwardBackwardOutput(BaseModel):
"""Dictionary mapping field names to tensor data""" """Dictionary mapping field names to tensor data"""
metrics: Dict[str, float] metrics: Dict[str, float]
"""Training metrics as key-value pairs""" """Training metrics as key-value pairs.
The following metrics are recorded only during MoE (Mixture of Experts) training.
Note: Don't fixate on the exact values of these metrics at the start of training.
Different models on different data will have different initial values. How these
metrics evolve over training is what matters.
In the definitions below, *perfect balance* means ``total_tokens / num_experts`` —
the number of tokens each expert would receive if routing were perfectly uniform.
- ``e_frac_with_tokens:mean``: Fraction of experts that received at least one token,
averaged across layers. A value of 1.0 means every expert got work; 0.5 means half
were idle. Decreasing over time is concerning (routing collapse).
- ``e_frac_oversubscribed:mean``: Fraction of experts receiving more tokens than
perfect balance, averaged across layers. Increasing over time is concerning.
- ``e_max_violation:mean``: How much the most overloaded expert exceeds perfect
balance, as a fraction of perfect balance, averaged across layers. Computed as
``(max_tokens - perfect_balance) / perfect_balance``. A value of 2.0 means the
busiest expert got 3x the fair share. Increasing over time is concerning.
- ``e_max_violation:max``: Same as ``e_max_violation:mean`` but takes the max
across layers instead of the mean. Shows the worst-case load imbalance in any
single layer.
- ``e_min_violation:mean``: How much the least loaded expert is below perfect
balance, as a fraction of perfect balance, averaged across layers. Computed as
``(min_tokens - perfect_balance) / perfect_balance``. A value of -0.5 means the
least-used expert got half the fair share; -1.0 means it got nothing. Typically
negative. Decreasing over time (more negative) is concerning.
"""

View file

@@ -7,3 +7,7 @@ __all__ = ["FutureRetrieveRequest"]
class FutureRetrieveRequest(StrictBase): class FutureRetrieveRequest(StrictBase):
request_id: RequestID request_id: RequestID
"""The ID of the request to retrieve""" """The ID of the request to retrieve"""
allow_metadata_only: bool = False
"""When True, the server may return only response metadata (status and size)
instead of the full payload if the response exceeds the server's inline size limit."""

View file

@@ -1,50 +0,0 @@
#!/usr/bin/env python
from httpx import URL
from tinker._base_client import BaseClient
from tinker.lib.public_interfaces import _convert_forward_backward_input
from tinker.types import ForwardBackwardInput
class TestBaseClient(BaseClient):
    """Minimal concrete BaseClient used to exercise base_url parameter handling.

    NOTE(review): BaseClient's abstract surface is not visible here — presumably
    make_status_error and _enforce_trailing_slash are the required overrides; confirm
    against tinker._base_client.
    """

    def __init__(self, base_url: str = "http://localhost:8000"):
        # Test the fixed base_url parameter handling
        super().__init__(version="1.0.0", base_url=base_url, _strict_response_validation=True)

    def make_status_error(self, err_msg: str, body: object, response: object) -> None:
        # Stub override: error construction is irrelevant to this base_url check.
        return None

    def _enforce_trailing_slash(self, url: URL) -> URL:
        # Identity override so the base_url passed in is preserved verbatim.
        return url
# Smoke-check that the base_url parameter is wired through correctly.
client = TestBaseClient("http://example.com")
print(f"Base URL correctly set to: {client._base_url}")

# Plain dict matching the ForwardBackwardInput schema.
payload = {
    "data": [
        {
            "model_input": {"chunks": [{"type": "encoded_text", "tokens": [1, 2, 3]}]},
            "loss_fn_inputs": {},
        }
    ],
    "loss_fn": "cross_entropy",  # Use a valid value from the enum
}

try:
    # Round-trip the dict through the model, then through the converter.
    parsed = ForwardBackwardInput.model_validate(payload)
    converted = _convert_forward_backward_input(parsed)
    print(f"Conversion successful: {converted}")
except Exception as exc:
    # Validation failures are tolerable here — only base_url handling is under test.
    print(f"Note: Could not validate ForwardBackwardInput model: {exc}")
    print("But that's okay since we're just testing the base_url parameter fix")

print("All tests passed!")