mirror of
https://github.com/thinking-machines-lab/tinker.git
synced 2026-04-19 12:58:01 +00:00
Sync contents
This commit is contained in:
parent
a187cb5b2f
commit
77cc8a0810
7 changed files with 226 additions and 226 deletions
|
|
@@ -1,4 +1,4 @@
|
|||
{
|
||||
"last_synced_sha": "1c67712168cd74fc2bda614f239dcad5f7b2e166",
|
||||
"last_sync_time": "2026-02-13T22:05:28.522838"
|
||||
"last_synced_sha": "7a5a376dc40cd21c3c744270af661b021e9fa9c6",
|
||||
"last_sync_time": "2026-02-27T05:49:39.458251"
|
||||
}
|
||||
|
|
@@ -1,6 +1,6 @@
|
|||
[project]
|
||||
name = "tinker"
|
||||
version = "0.13.1"
|
||||
version = "0.14.0"
|
||||
description = "The official Python SDK for the tinker API"
|
||||
readme = "README.md"
|
||||
license = "Apache-2.0"
|
||||
|
|
|
|||
|
|
@@ -85,186 +85,200 @@ class _APIFuture(APIFuture[T]): # pyright: ignore[reportUnusedClass]
|
|||
start_time = time.time()
|
||||
iteration = -1
|
||||
connection_error_retries = 0
|
||||
allow_metadata_only = True
|
||||
|
||||
while True:
|
||||
iteration += 1
|
||||
async with contextlib.AsyncExitStack() as stack:
|
||||
while True:
|
||||
iteration += 1
|
||||
|
||||
if timeout is not None and time.time() - start_time > timeout:
|
||||
if telemetry := self.get_telemetry():
|
||||
current_time = time.time()
|
||||
telemetry.log(
|
||||
"APIFuture.result_async.timeout",
|
||||
event_data={
|
||||
"request_id": self.request_id,
|
||||
"request_type": self.request_type,
|
||||
"timeout": timeout,
|
||||
"iteration": iteration,
|
||||
"elapsed_time": current_time - start_time,
|
||||
},
|
||||
severity="ERROR",
|
||||
)
|
||||
raise TimeoutError(
|
||||
f"Timeout of {timeout} seconds reached while waiting for result of {self.request_id=}"
|
||||
)
|
||||
|
||||
# Headers for telemetry
|
||||
headers = {
|
||||
"X-Tinker-Request-Iteration": str(iteration),
|
||||
"X-Tinker-Request-Type": self.request_type,
|
||||
}
|
||||
if iteration == 0:
|
||||
headers["X-Tinker-Create-Promise-Roundtrip-Time"] = str(
|
||||
self.request_queue_roundtrip_time
|
||||
)
|
||||
|
||||
# Function hasn't been called yet, execute it now
|
||||
try:
|
||||
with self.holder.aclient(ClientConnectionPoolType.RETRIEVE_PROMISE) as client:
|
||||
response = await client.futures.with_raw_response.retrieve(
|
||||
request=FutureRetrieveRequest(request_id=self.request_id),
|
||||
timeout=45,
|
||||
extra_headers=headers,
|
||||
max_retries=0,
|
||||
)
|
||||
except tinker.APIStatusError as e:
|
||||
connection_error_retries = 0
|
||||
should_retry = e.status_code == 408 or e.status_code in range(500, 600)
|
||||
user_error = is_user_error(e)
|
||||
if telemetry := self.get_telemetry():
|
||||
current_time = time.time()
|
||||
telemetry.log(
|
||||
"APIFuture.result_async.api_status_error",
|
||||
event_data={
|
||||
"request_id": self.request_id,
|
||||
"request_type": self.request_type,
|
||||
"status_code": e.status_code,
|
||||
"exception": str(e),
|
||||
"should_retry": should_retry,
|
||||
"is_user_error": user_error,
|
||||
"iteration": iteration,
|
||||
"elapsed_time": current_time - start_time,
|
||||
},
|
||||
severity="WARNING" if should_retry or user_error else "ERROR",
|
||||
if timeout is not None and time.time() - start_time > timeout:
|
||||
if telemetry := self.get_telemetry():
|
||||
current_time = time.time()
|
||||
telemetry.log(
|
||||
"APIFuture.result_async.timeout",
|
||||
event_data={
|
||||
"request_id": self.request_id,
|
||||
"request_type": self.request_type,
|
||||
"timeout": timeout,
|
||||
"iteration": iteration,
|
||||
"elapsed_time": current_time - start_time,
|
||||
},
|
||||
severity="ERROR",
|
||||
)
|
||||
raise TimeoutError(
|
||||
f"Timeout of {timeout} seconds reached while waiting for result of {self.request_id=}"
|
||||
)
|
||||
|
||||
# Retry 408s until we time out
|
||||
if e.status_code == 408:
|
||||
if self._queue_state_observer is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
response = e.response.json()
|
||||
if queue_state_str := response.get("queue_state", None):
|
||||
queue_state_reason = response.get("queue_state_reason", None)
|
||||
if queue_state_str == "active":
|
||||
queue_state = QueueState.ACTIVE
|
||||
elif queue_state_str == "paused_rate_limit":
|
||||
queue_state = QueueState.PAUSED_RATE_LIMIT
|
||||
elif queue_state_str == "paused_capacity":
|
||||
queue_state = QueueState.PAUSED_CAPACITY
|
||||
else:
|
||||
queue_state = QueueState.UNKNOWN
|
||||
self._queue_state_observer.on_queue_state_change(
|
||||
queue_state, queue_state_reason
|
||||
)
|
||||
continue
|
||||
if e.status_code == 410:
|
||||
raise RetryableException(
|
||||
message=f"Promise expired/broken for request {self.untyped_future.request_id}"
|
||||
headers = {
|
||||
"X-Tinker-Request-Iteration": str(iteration),
|
||||
"X-Tinker-Request-Type": self.request_type,
|
||||
}
|
||||
if iteration == 0:
|
||||
headers["X-Tinker-Create-Promise-Roundtrip-Time"] = str(
|
||||
self.request_queue_roundtrip_time
|
||||
)
|
||||
|
||||
try:
|
||||
with self.holder.aclient(ClientConnectionPoolType.RETRIEVE_PROMISE) as client:
|
||||
response = await client.futures.with_raw_response.retrieve(
|
||||
request=FutureRetrieveRequest(
|
||||
request_id=self.request_id,
|
||||
allow_metadata_only=allow_metadata_only,
|
||||
),
|
||||
timeout=45,
|
||||
extra_headers=headers,
|
||||
max_retries=0,
|
||||
)
|
||||
except tinker.APIStatusError as e:
|
||||
connection_error_retries = 0
|
||||
should_retry = e.status_code == 408 or e.status_code in range(500, 600)
|
||||
user_error = is_user_error(e)
|
||||
if telemetry := self.get_telemetry():
|
||||
current_time = time.time()
|
||||
telemetry.log(
|
||||
"APIFuture.result_async.api_status_error",
|
||||
event_data={
|
||||
"request_id": self.request_id,
|
||||
"request_type": self.request_type,
|
||||
"status_code": e.status_code,
|
||||
"exception": str(e),
|
||||
"should_retry": should_retry,
|
||||
"is_user_error": user_error,
|
||||
"iteration": iteration,
|
||||
"elapsed_time": current_time - start_time,
|
||||
},
|
||||
severity="WARNING" if should_retry or user_error else "ERROR",
|
||||
)
|
||||
|
||||
# Retry 408s until we time out
|
||||
if e.status_code == 408:
|
||||
if self._queue_state_observer is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
response = e.response.json()
|
||||
if queue_state_str := response.get("queue_state", None):
|
||||
queue_state_reason = response.get("queue_state_reason", None)
|
||||
if queue_state_str == "active":
|
||||
queue_state = QueueState.ACTIVE
|
||||
elif queue_state_str == "paused_rate_limit":
|
||||
queue_state = QueueState.PAUSED_RATE_LIMIT
|
||||
elif queue_state_str == "paused_capacity":
|
||||
queue_state = QueueState.PAUSED_CAPACITY
|
||||
else:
|
||||
queue_state = QueueState.UNKNOWN
|
||||
self._queue_state_observer.on_queue_state_change(
|
||||
queue_state, queue_state_reason
|
||||
)
|
||||
continue
|
||||
if e.status_code == 410:
|
||||
raise RetryableException(
|
||||
message=f"Promise expired/broken for request {self.untyped_future.request_id}"
|
||||
) from e
|
||||
if e.status_code in range(500, 600):
|
||||
continue
|
||||
raise ValueError(
|
||||
f"Error retrieving result: {e} with status code {e.status_code=} for {self.request_id=} and expected type {self.model_cls=}"
|
||||
) from e
|
||||
if e.status_code in range(500, 600):
|
||||
except tinker.APIConnectionError as e:
|
||||
if telemetry := self.get_telemetry():
|
||||
current_time = time.time()
|
||||
telemetry.log(
|
||||
"APIFuture.result_async.connection_error",
|
||||
event_data={
|
||||
"request_id": self.request_id,
|
||||
"request_type": self.request_type,
|
||||
"exception": str(e),
|
||||
"connection_error_retries": connection_error_retries,
|
||||
"iteration": iteration,
|
||||
"elapsed_time": current_time - start_time,
|
||||
},
|
||||
severity="WARNING",
|
||||
)
|
||||
|
||||
# Retry all connection errors with exponential backoff
|
||||
await asyncio.sleep(min(2**connection_error_retries, 30))
|
||||
connection_error_retries += 1
|
||||
continue
|
||||
raise ValueError(
|
||||
f"Error retrieving result: {e} with status code {e.status_code=} for {self.request_id=} and expected type {self.model_cls=}"
|
||||
) from e
|
||||
except tinker.APIConnectionError as e:
|
||||
if telemetry := self.get_telemetry():
|
||||
current_time = time.time()
|
||||
telemetry.log(
|
||||
"APIFuture.result_async.connection_error",
|
||||
event_data={
|
||||
"request_id": self.request_id,
|
||||
"request_type": self.request_type,
|
||||
"exception": str(e),
|
||||
"connection_error_retries": connection_error_retries,
|
||||
"iteration": iteration,
|
||||
"elapsed_time": current_time - start_time,
|
||||
},
|
||||
severity="WARNING",
|
||||
|
||||
result_dict: Any = await response.json()
|
||||
|
||||
if "type" in result_dict and result_dict["type"] == "try_again":
|
||||
logger.warning(f"Retrying request {self.request_id=} because of try_again")
|
||||
continue
|
||||
|
||||
if result_dict.get("status") == "complete_metadata":
|
||||
# metadata only response should be returned only once
|
||||
assert allow_metadata_only
|
||||
allow_metadata_only = False
|
||||
|
||||
response_payload_size = result_dict.get("response_payload_size", 0)
|
||||
assert response_payload_size is not None
|
||||
await stack.enter_async_context(
|
||||
self.holder._inflight_response_bytes_semaphore.acquire(
|
||||
response_payload_size
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
if "error" in result_dict:
|
||||
error_category = RequestErrorCategory.Unknown
|
||||
with contextlib.suppress(Exception):
|
||||
error_category = RequestErrorCategory(result_dict.get("category"))
|
||||
|
||||
user_error = error_category is RequestErrorCategory.User
|
||||
if telemetry := self.get_telemetry():
|
||||
current_time = time.time()
|
||||
telemetry.log(
|
||||
"APIFuture.result_async.application_error",
|
||||
event_data={
|
||||
"request_id": self.request_id,
|
||||
"request_type": self.request_type,
|
||||
"error": result_dict["error"],
|
||||
"error_category": error_category.name,
|
||||
"is_user_error": user_error,
|
||||
"iteration": iteration,
|
||||
"elapsed_time": current_time - start_time,
|
||||
},
|
||||
severity="WARNING" if user_error else "ERROR",
|
||||
)
|
||||
|
||||
error_message = result_dict["error"]
|
||||
raise RequestFailedError(
|
||||
f"Request failed: {error_message} for {self.request_id=} and expected type {self.model_cls=}",
|
||||
request_id=self.request_id,
|
||||
category=error_category,
|
||||
)
|
||||
|
||||
# Retry all connection errors with exponential backoff
|
||||
await asyncio.sleep(min(2**connection_error_retries, 30))
|
||||
connection_error_retries += 1
|
||||
continue
|
||||
try:
|
||||
if inspect.isclass(self.model_cls) and issubclass(self.model_cls, BaseModel):
|
||||
self._cached_result = self.model_cls.model_validate(result_dict)
|
||||
else:
|
||||
self._cached_result = result_dict
|
||||
return cast(T, self._cached_result)
|
||||
except Exception as e:
|
||||
if telemetry := self.get_telemetry():
|
||||
current_time = time.time()
|
||||
telemetry.log(
|
||||
"APIFuture.result_async.validation_error",
|
||||
event_data={
|
||||
"request_id": self.request_id,
|
||||
"request_type": self.request_type,
|
||||
"exception": str(e),
|
||||
"exception_type": type(e).__name__,
|
||||
"exception_stack": "".join(
|
||||
traceback.format_exception(type(e), e, e.__traceback__)
|
||||
)
|
||||
if e.__traceback__
|
||||
else None,
|
||||
"model_cls": str(self.model_cls),
|
||||
"iteration": iteration,
|
||||
"elapsed_time": current_time - start_time,
|
||||
},
|
||||
severity="ERROR",
|
||||
)
|
||||
|
||||
# Function hasn't been called yet, execute it now
|
||||
result_dict: Any = await response.json()
|
||||
|
||||
if "type" in result_dict and result_dict["type"] == "try_again":
|
||||
logger.warning(f"Retrying request {self.request_id=} because of try_again")
|
||||
continue
|
||||
|
||||
if "error" in result_dict:
|
||||
error_category = RequestErrorCategory.Unknown
|
||||
with contextlib.suppress(Exception):
|
||||
error_category = RequestErrorCategory(result_dict.get("category"))
|
||||
|
||||
user_error = error_category is RequestErrorCategory.User
|
||||
if telemetry := self.get_telemetry():
|
||||
current_time = time.time()
|
||||
telemetry.log(
|
||||
"APIFuture.result_async.application_error",
|
||||
event_data={
|
||||
"request_id": self.request_id,
|
||||
"request_type": self.request_type,
|
||||
"error": result_dict["error"],
|
||||
"error_category": error_category.name,
|
||||
"is_user_error": user_error,
|
||||
"iteration": iteration,
|
||||
"elapsed_time": current_time - start_time,
|
||||
},
|
||||
severity="WARNING" if user_error else "ERROR",
|
||||
)
|
||||
|
||||
error_message = result_dict["error"]
|
||||
raise RequestFailedError(
|
||||
f"Request failed: {error_message} for {self.request_id=} and expected type {self.model_cls=}",
|
||||
request_id=self.request_id,
|
||||
category=error_category,
|
||||
)
|
||||
|
||||
try:
|
||||
# Check if model_cls is a BaseModel subclass before calling model_validate
|
||||
if inspect.isclass(self.model_cls) and issubclass(self.model_cls, BaseModel):
|
||||
self._cached_result = self.model_cls.model_validate(result_dict)
|
||||
else:
|
||||
# For non-BaseModel types, just return the result directly
|
||||
self._cached_result = result_dict
|
||||
return cast(T, self._cached_result)
|
||||
except Exception as e:
|
||||
if telemetry := self.get_telemetry():
|
||||
current_time = time.time()
|
||||
telemetry.log(
|
||||
"APIFuture.result_async.validation_error",
|
||||
event_data={
|
||||
"request_id": self.request_id,
|
||||
"request_type": self.request_type,
|
||||
"exception": str(e),
|
||||
"exception_type": type(e).__name__,
|
||||
"exception_stack": "".join(
|
||||
traceback.format_exception(type(e), e, e.__traceback__)
|
||||
)
|
||||
if e.__traceback__
|
||||
else None,
|
||||
"model_cls": str(self.model_cls),
|
||||
"iteration": iteration,
|
||||
"elapsed_time": current_time - start_time,
|
||||
},
|
||||
severity="ERROR",
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
f"Error retrieving result: {e} for {self.request_id=} and expected type {self.model_cls=}"
|
||||
) from e
|
||||
raise ValueError(
|
||||
f"Error retrieving result: {e} for {self.request_id=} and expected type {self.model_cls=}"
|
||||
) from e
|
||||
|
||||
@property
|
||||
def request_id(self) -> str:
|
||||
|
|
|
|||
|
|
@@ -162,6 +162,7 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
|
|||
self._sample_dispatch_semaphore: asyncio.Semaphore = asyncio.Semaphore(400)
|
||||
self._sample_dispatch_throttled_semaphore: asyncio.Semaphore = asyncio.Semaphore(10)
|
||||
self._sample_dispatch_bytes_semaphore: BytesSemaphore = BytesSemaphore(5 * 1024 * 1024)
|
||||
self._inflight_response_bytes_semaphore: BytesSemaphore = BytesSemaphore(5 * 1024 * 1024)
|
||||
self._training_client_lock: threading.Lock = threading.Lock()
|
||||
|
||||
if session_id is not None:
|
||||
|
|
@@ -336,7 +337,7 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
|
|||
|
||||
def close(self):
|
||||
self.run_coroutine_threadsafe(self._async_cleanup())
|
||||
if telemetry := self._telemetry:
|
||||
if telemetry := getattr(self, "_telemetry", None):
|
||||
telemetry.stop()
|
||||
|
||||
def __del__(self):
|
||||
|
|
|
|||
|
|
@@ -14,4 +14,35 @@ class ForwardBackwardOutput(BaseModel):
|
|||
"""Dictionary mapping field names to tensor data"""
|
||||
|
||||
metrics: Dict[str, float]
|
||||
"""Training metrics as key-value pairs"""
|
||||
"""Training metrics as key-value pairs.
|
||||
|
||||
The following metrics are recorded only during MoE (Mixture of Experts) training.
|
||||
Note: Don't fixate on the exact values of these metrics at the start of training.
|
||||
Different models on different data will have different initial values. How these
|
||||
metrics evolve over training is what matters.
|
||||
|
||||
In the definitions below, *perfect balance* means ``total_tokens / num_experts``
|
||||
— the number of tokens each expert would receive if routing were perfectly uniform.
|
||||
|
||||
- ``e_frac_with_tokens:mean``: Fraction of experts that received at least one token,
|
||||
averaged across layers. A value of 1.0 means every expert got work; 0.5 means half
|
||||
were idle. Decreasing over time is concerning (routing collapse).
|
||||
|
||||
- ``e_frac_oversubscribed:mean``: Fraction of experts receiving more tokens than
|
||||
perfect balance, averaged across layers. Increasing over time is concerning.
|
||||
|
||||
- ``e_max_violation:mean``: How much the most overloaded expert exceeds perfect
|
||||
balance, as a fraction of perfect balance, averaged across layers. Computed as
|
||||
``(max_tokens - perfect_balance) / perfect_balance``. A value of 2.0 means the
|
||||
busiest expert got 3x the fair share. Increasing over time is concerning.
|
||||
|
||||
- ``e_max_violation:max``: Same as ``e_max_violation:mean`` but takes the max
|
||||
across layers instead of the mean. Shows the worst-case load imbalance in any
|
||||
single layer.
|
||||
|
||||
- ``e_min_violation:mean``: How much the least loaded expert is below perfect
|
||||
balance, as a fraction of perfect balance, averaged across layers. Computed as
|
||||
``(min_tokens - perfect_balance) / perfect_balance``. A value of -0.5 means the
|
||||
least-used expert got half the fair share; -1.0 means it got nothing. Typically
|
||||
negative. Decreasing over time (more negative) is concerning.
|
||||
"""
|
||||
|
|
|
|||
|
|
@@ -7,3 +7,7 @@ __all__ = ["FutureRetrieveRequest"]
|
|||
class FutureRetrieveRequest(StrictBase):
|
||||
request_id: RequestID
|
||||
"""The ID of the request to retrieve"""
|
||||
|
||||
allow_metadata_only: bool = False
|
||||
"""When True, the server may return only response metadata (status and size)
|
||||
instead of the full payload if the response exceeds the server's inline size limit."""
|
||||
|
|
|
|||
|
|
@@ -1,50 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
from httpx import URL
|
||||
|
||||
from tinker._base_client import BaseClient
|
||||
from tinker.lib.public_interfaces import _convert_forward_backward_input
|
||||
from tinker.types import ForwardBackwardInput
|
||||
|
||||
|
||||
class TestBaseClient(BaseClient):
|
||||
def __init__(self, base_url: str = "http://localhost:8000"):
|
||||
# Test the fixed base_url parameter handling
|
||||
super().__init__(version="1.0.0", base_url=base_url, _strict_response_validation=True)
|
||||
|
||||
def make_status_error(self, err_msg: str, body: object, response: object) -> None:
|
||||
return None
|
||||
|
||||
def _enforce_trailing_slash(self, url: URL) -> URL:
|
||||
return url
|
||||
|
||||
|
||||
# Test that the base_url parameter is handled correctly
|
||||
client = TestBaseClient("http://example.com")
|
||||
print(f"Base URL correctly set to: {client._base_url}")
|
||||
|
||||
# Create a proper dictionary to initialize the model
|
||||
model_data = {
|
||||
"data": [
|
||||
{
|
||||
"model_input": {"chunks": [{"type": "encoded_text", "tokens": [1, 2, 3]}]},
|
||||
"loss_fn_inputs": {},
|
||||
}
|
||||
],
|
||||
"loss_fn": "cross_entropy", # Use a valid value from the enum
|
||||
}
|
||||
|
||||
# Create a mock model just to test the base_url parameter
|
||||
try:
|
||||
# Properly initialize the ForwardBackwardInput model
|
||||
input_obj = ForwardBackwardInput.model_validate(model_data)
|
||||
|
||||
# Test the convert function
|
||||
result = _convert_forward_backward_input(input_obj)
|
||||
print(f"Conversion successful: {result}")
|
||||
except Exception as e:
|
||||
# Since we're just testing the base_url parameter fix, we can ignore model validation errors
|
||||
print(f"Note: Could not validate ForwardBackwardInput model: {e}")
|
||||
print("But that's okay since we're just testing the base_url parameter fix")
|
||||
|
||||
print("All tests passed!")
|
||||
Loading…
Add table
Add a link
Reference in a new issue