mirror of
https://github.com/thinking-machines-lab/tinker.git
synced 2026-05-02 17:45:52 +00:00
Sync contents
This commit is contained in:
parent
a187cb5b2f
commit
77cc8a0810
7 changed files with 226 additions and 226 deletions
|
|
@ -1,4 +1,4 @@
|
||||||
{
|
{
|
||||||
"last_synced_sha": "1c67712168cd74fc2bda614f239dcad5f7b2e166",
|
"last_synced_sha": "7a5a376dc40cd21c3c744270af661b021e9fa9c6",
|
||||||
"last_sync_time": "2026-02-13T22:05:28.522838"
|
"last_sync_time": "2026-02-27T05:49:39.458251"
|
||||||
}
|
}
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
[project]
|
[project]
|
||||||
name = "tinker"
|
name = "tinker"
|
||||||
version = "0.13.1"
|
version = "0.14.0"
|
||||||
description = "The official Python SDK for the tinker API"
|
description = "The official Python SDK for the tinker API"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = "Apache-2.0"
|
license = "Apache-2.0"
|
||||||
|
|
|
||||||
|
|
@ -85,7 +85,9 @@ class _APIFuture(APIFuture[T]): # pyright: ignore[reportUnusedClass]
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
iteration = -1
|
iteration = -1
|
||||||
connection_error_retries = 0
|
connection_error_retries = 0
|
||||||
|
allow_metadata_only = True
|
||||||
|
|
||||||
|
async with contextlib.AsyncExitStack() as stack:
|
||||||
while True:
|
while True:
|
||||||
iteration += 1
|
iteration += 1
|
||||||
|
|
||||||
|
|
@ -107,7 +109,6 @@ class _APIFuture(APIFuture[T]): # pyright: ignore[reportUnusedClass]
|
||||||
f"Timeout of {timeout} seconds reached while waiting for result of {self.request_id=}"
|
f"Timeout of {timeout} seconds reached while waiting for result of {self.request_id=}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Headers for telemetry
|
|
||||||
headers = {
|
headers = {
|
||||||
"X-Tinker-Request-Iteration": str(iteration),
|
"X-Tinker-Request-Iteration": str(iteration),
|
||||||
"X-Tinker-Request-Type": self.request_type,
|
"X-Tinker-Request-Type": self.request_type,
|
||||||
|
|
@ -117,11 +118,13 @@ class _APIFuture(APIFuture[T]): # pyright: ignore[reportUnusedClass]
|
||||||
self.request_queue_roundtrip_time
|
self.request_queue_roundtrip_time
|
||||||
)
|
)
|
||||||
|
|
||||||
# Function hasn't been called yet, execute it now
|
|
||||||
try:
|
try:
|
||||||
with self.holder.aclient(ClientConnectionPoolType.RETRIEVE_PROMISE) as client:
|
with self.holder.aclient(ClientConnectionPoolType.RETRIEVE_PROMISE) as client:
|
||||||
response = await client.futures.with_raw_response.retrieve(
|
response = await client.futures.with_raw_response.retrieve(
|
||||||
request=FutureRetrieveRequest(request_id=self.request_id),
|
request=FutureRetrieveRequest(
|
||||||
|
request_id=self.request_id,
|
||||||
|
allow_metadata_only=allow_metadata_only,
|
||||||
|
),
|
||||||
timeout=45,
|
timeout=45,
|
||||||
extra_headers=headers,
|
extra_headers=headers,
|
||||||
max_retries=0,
|
max_retries=0,
|
||||||
|
|
@ -196,13 +199,26 @@ class _APIFuture(APIFuture[T]): # pyright: ignore[reportUnusedClass]
|
||||||
connection_error_retries += 1
|
connection_error_retries += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Function hasn't been called yet, execute it now
|
|
||||||
result_dict: Any = await response.json()
|
result_dict: Any = await response.json()
|
||||||
|
|
||||||
if "type" in result_dict and result_dict["type"] == "try_again":
|
if "type" in result_dict and result_dict["type"] == "try_again":
|
||||||
logger.warning(f"Retrying request {self.request_id=} because of try_again")
|
logger.warning(f"Retrying request {self.request_id=} because of try_again")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if result_dict.get("status") == "complete_metadata":
|
||||||
|
# metadata only response should be returned only once
|
||||||
|
assert allow_metadata_only
|
||||||
|
allow_metadata_only = False
|
||||||
|
|
||||||
|
response_payload_size = result_dict.get("response_payload_size", 0)
|
||||||
|
assert response_payload_size is not None
|
||||||
|
await stack.enter_async_context(
|
||||||
|
self.holder._inflight_response_bytes_semaphore.acquire(
|
||||||
|
response_payload_size
|
||||||
|
)
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
if "error" in result_dict:
|
if "error" in result_dict:
|
||||||
error_category = RequestErrorCategory.Unknown
|
error_category = RequestErrorCategory.Unknown
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
|
|
@ -233,11 +249,9 @@ class _APIFuture(APIFuture[T]): # pyright: ignore[reportUnusedClass]
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Check if model_cls is a BaseModel subclass before calling model_validate
|
|
||||||
if inspect.isclass(self.model_cls) and issubclass(self.model_cls, BaseModel):
|
if inspect.isclass(self.model_cls) and issubclass(self.model_cls, BaseModel):
|
||||||
self._cached_result = self.model_cls.model_validate(result_dict)
|
self._cached_result = self.model_cls.model_validate(result_dict)
|
||||||
else:
|
else:
|
||||||
# For non-BaseModel types, just return the result directly
|
|
||||||
self._cached_result = result_dict
|
self._cached_result = result_dict
|
||||||
return cast(T, self._cached_result)
|
return cast(T, self._cached_result)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
|
|
@ -162,6 +162,7 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
|
||||||
self._sample_dispatch_semaphore: asyncio.Semaphore = asyncio.Semaphore(400)
|
self._sample_dispatch_semaphore: asyncio.Semaphore = asyncio.Semaphore(400)
|
||||||
self._sample_dispatch_throttled_semaphore: asyncio.Semaphore = asyncio.Semaphore(10)
|
self._sample_dispatch_throttled_semaphore: asyncio.Semaphore = asyncio.Semaphore(10)
|
||||||
self._sample_dispatch_bytes_semaphore: BytesSemaphore = BytesSemaphore(5 * 1024 * 1024)
|
self._sample_dispatch_bytes_semaphore: BytesSemaphore = BytesSemaphore(5 * 1024 * 1024)
|
||||||
|
self._inflight_response_bytes_semaphore: BytesSemaphore = BytesSemaphore(5 * 1024 * 1024)
|
||||||
self._training_client_lock: threading.Lock = threading.Lock()
|
self._training_client_lock: threading.Lock = threading.Lock()
|
||||||
|
|
||||||
if session_id is not None:
|
if session_id is not None:
|
||||||
|
|
@ -336,7 +337,7 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
self.run_coroutine_threadsafe(self._async_cleanup())
|
self.run_coroutine_threadsafe(self._async_cleanup())
|
||||||
if telemetry := self._telemetry:
|
if telemetry := getattr(self, "_telemetry", None):
|
||||||
telemetry.stop()
|
telemetry.stop()
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
|
|
|
||||||
|
|
@ -14,4 +14,35 @@ class ForwardBackwardOutput(BaseModel):
|
||||||
"""Dictionary mapping field names to tensor data"""
|
"""Dictionary mapping field names to tensor data"""
|
||||||
|
|
||||||
metrics: Dict[str, float]
|
metrics: Dict[str, float]
|
||||||
"""Training metrics as key-value pairs"""
|
"""Training metrics as key-value pairs.
|
||||||
|
|
||||||
|
The following metrics are recorded only during MoE (Mixture of Experts) training.
|
||||||
|
Note: Don't fixate on the exact values of these metrics at the start of training.
|
||||||
|
Different models on different data will have different initial values. How these
|
||||||
|
metrics evolve over training is what matters.
|
||||||
|
|
||||||
|
In the definitions below, *perfect balance* means ``total_tokens / num_experts``
|
||||||
|
— the number of tokens each expert would receive if routing were perfectly uniform.
|
||||||
|
|
||||||
|
- ``e_frac_with_tokens:mean``: Fraction of experts that received at least one token,
|
||||||
|
averaged across layers. A value of 1.0 means every expert got work; 0.5 means half
|
||||||
|
were idle. Decreasing over time is concerning (routing collapse).
|
||||||
|
|
||||||
|
- ``e_frac_oversubscribed:mean``: Fraction of experts receiving more tokens than
|
||||||
|
perfect balance, averaged across layers. Increasing over time is concerning.
|
||||||
|
|
||||||
|
- ``e_max_violation:mean``: How much the most overloaded expert exceeds perfect
|
||||||
|
balance, as a fraction of perfect balance, averaged across layers. Computed as
|
||||||
|
``(max_tokens - perfect_balance) / perfect_balance``. A value of 2.0 means the
|
||||||
|
busiest expert got 3x the fair share. Increasing over time is concerning.
|
||||||
|
|
||||||
|
- ``e_max_violation:max``: Same as ``e_max_violation:mean`` but takes the max
|
||||||
|
across layers instead of the mean. Shows the worst-case load imbalance in any
|
||||||
|
single layer.
|
||||||
|
|
||||||
|
- ``e_min_violation:mean``: How much the least loaded expert is below perfect
|
||||||
|
balance, as a fraction of perfect balance, averaged across layers. Computed as
|
||||||
|
``(min_tokens - perfect_balance) / perfect_balance``. A value of -0.5 means the
|
||||||
|
least-used expert got half the fair share; -1.0 means it got nothing. Typically
|
||||||
|
negative. Decreasing over time (more negative) is concerning.
|
||||||
|
"""
|
||||||
|
|
|
||||||
|
|
@ -7,3 +7,7 @@ __all__ = ["FutureRetrieveRequest"]
|
||||||
class FutureRetrieveRequest(StrictBase):
|
class FutureRetrieveRequest(StrictBase):
|
||||||
request_id: RequestID
|
request_id: RequestID
|
||||||
"""The ID of the request to retrieve"""
|
"""The ID of the request to retrieve"""
|
||||||
|
|
||||||
|
allow_metadata_only: bool = False
|
||||||
|
"""When True, the server may return only response metadata (status and size)
|
||||||
|
instead of the full payload if the response exceeds the server's inline size limit."""
|
||||||
|
|
|
||||||
|
|
@ -1,50 +0,0 @@
|
||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
from httpx import URL
|
|
||||||
|
|
||||||
from tinker._base_client import BaseClient
|
|
||||||
from tinker.lib.public_interfaces import _convert_forward_backward_input
|
|
||||||
from tinker.types import ForwardBackwardInput
|
|
||||||
|
|
||||||
|
|
||||||
class TestBaseClient(BaseClient):
|
|
||||||
def __init__(self, base_url: str = "http://localhost:8000"):
|
|
||||||
# Test the fixed base_url parameter handling
|
|
||||||
super().__init__(version="1.0.0", base_url=base_url, _strict_response_validation=True)
|
|
||||||
|
|
||||||
def make_status_error(self, err_msg: str, body: object, response: object) -> None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _enforce_trailing_slash(self, url: URL) -> URL:
|
|
||||||
return url
|
|
||||||
|
|
||||||
|
|
||||||
# Test that the base_url parameter is handled correctly
|
|
||||||
client = TestBaseClient("http://example.com")
|
|
||||||
print(f"Base URL correctly set to: {client._base_url}")
|
|
||||||
|
|
||||||
# Create a proper dictionary to initialize the model
|
|
||||||
model_data = {
|
|
||||||
"data": [
|
|
||||||
{
|
|
||||||
"model_input": {"chunks": [{"type": "encoded_text", "tokens": [1, 2, 3]}]},
|
|
||||||
"loss_fn_inputs": {},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"loss_fn": "cross_entropy", # Use a valid value from the enum
|
|
||||||
}
|
|
||||||
|
|
||||||
# Create a mock model just to test the base_url parameter
|
|
||||||
try:
|
|
||||||
# Properly initialize the ForwardBackwardInput model
|
|
||||||
input_obj = ForwardBackwardInput.model_validate(model_data)
|
|
||||||
|
|
||||||
# Test the convert function
|
|
||||||
result = _convert_forward_backward_input(input_obj)
|
|
||||||
print(f"Conversion successful: {result}")
|
|
||||||
except Exception as e:
|
|
||||||
# Since we're just testing the base_url parameter fix, we can ignore model validation errors
|
|
||||||
print(f"Note: Could not validate ForwardBackwardInput model: {e}")
|
|
||||||
print("But that's okay since we're just testing the base_url parameter fix")
|
|
||||||
|
|
||||||
print("All tests passed!")
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue