mirror of
https://github.com/thinking-machines-lab/tinker.git
synced 2026-04-19 12:58:01 +00:00
Sync contents
This commit is contained in:
parent
0622760140
commit
5ad4282c96
10 changed files with 155 additions and 51 deletions
|
|
@ -1,4 +1,4 @@
|
|||
{
|
||||
"last_synced_sha": "ffa781c9432bd1afb73c278d253d0a7cd1b333d8",
|
||||
"last_sync_time": "2025-12-03T01:00:03.831979"
|
||||
"last_synced_sha": "2f8a0c29b227897c25451163c8fa90adb2daa6e6",
|
||||
"last_sync_time": "2025-12-08T00:05:28.592609"
|
||||
}
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
[project]
|
||||
name = "tinker"
|
||||
version = "0.6.3"
|
||||
version = "0.7.0"
|
||||
description = "The official Python SDK for the tinker API"
|
||||
readme = "README.md"
|
||||
license = "Apache-2.0"
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ class APIResponseValidationError(APIError):
|
|||
status_code: int
|
||||
|
||||
def __init__(
|
||||
self, response: httpx.Response, body: object | None, *, message: str | None = None
|
||||
self, response: httpx.Response, body: object | None, message: str | None = None
|
||||
) -> None:
|
||||
super().__init__(
|
||||
message or "Data returned by API invalid for expected schema.",
|
||||
|
|
@ -67,6 +67,14 @@ class APIResponseValidationError(APIError):
|
|||
self.response = response
|
||||
self.status_code = response.status_code
|
||||
|
||||
def __reduce__(self):
|
||||
# Return a tuple of (callable, args) to recreate the exception
|
||||
return (
|
||||
self.__class__,
|
||||
(self.response, self.body, self.message), # positional args
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
class APIStatusError(APIError):
|
||||
"""Raised when an API response has a status code of 4xx or 5xx."""
|
||||
|
|
@ -74,24 +82,41 @@ class APIStatusError(APIError):
|
|||
response: httpx.Response
|
||||
status_code: int
|
||||
|
||||
def __init__(self, message: str, *, response: httpx.Response, body: object | None) -> None:
|
||||
def __init__(self, message: str, response: httpx.Response, body: object | None) -> None:
|
||||
super().__init__(message, response.request, body=body)
|
||||
self.response = response
|
||||
self.status_code = response.status_code
|
||||
|
||||
|
||||
def __reduce__(self):
|
||||
# Return a tuple of (callable, args) to recreate the exception
|
||||
return (
|
||||
self.__class__,
|
||||
(self.message, self.response, self.body), # positional args
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
class APIConnectionError(APIError):
|
||||
"""Raised when a connection error occurs while making an API request."""
|
||||
|
||||
def __init__(self, *, message: str = "Connection error.", request: httpx.Request) -> None:
|
||||
def __init__(self, request: httpx.Request, message: str = "Connection error.") -> None:
|
||||
super().__init__(message, request, body=None)
|
||||
|
||||
def __reduce__(self):
|
||||
# Return a tuple of (callable, args) to recreate the exception
|
||||
return (
|
||||
self.__class__,
|
||||
(self.request, self.message), # positional args
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
class APITimeoutError(APIConnectionError):
|
||||
"""Raised when an API request times out."""
|
||||
|
||||
def __init__(self, request: httpx.Request) -> None:
|
||||
super().__init__(message="Request timed out.", request=request)
|
||||
super().__init__(request=request, message="Request timed out.", )
|
||||
|
||||
|
||||
class BadRequestError(APIStatusError):
|
||||
|
|
@ -148,10 +173,18 @@ class RequestFailedError(TinkerError):
|
|||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
request_id: str,
|
||||
category: "RequestErrorCategory",
|
||||
) -> None:
|
||||
super().__init__(message)
|
||||
self.message: str = message
|
||||
self.request_id: str = request_id
|
||||
self.category: RequestErrorCategory = category
|
||||
|
||||
def __reduce__(self):
|
||||
# Return a tuple of (callable, args) to recreate the exception
|
||||
return (
|
||||
self.__class__,
|
||||
(self.message, self.request_id, self.category), # positional args
|
||||
None,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ class QueueState(Enum):
|
|||
|
||||
class QueueStateObserver(ABC):
|
||||
@abstractmethod
|
||||
def on_queue_state_change(self, queue_state: QueueState) -> None:
|
||||
def on_queue_state_change(self, queue_state: QueueState, queue_state_reason: str | None) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
|
@ -151,6 +151,7 @@ class _APIFuture(APIFuture[T]): # pyright: ignore[reportUnusedClass]
|
|||
with contextlib.suppress(Exception):
|
||||
response = e.response.json()
|
||||
if queue_state_str := response.get("queue_state", None):
|
||||
queue_state_reason = response.get("queue_state_reason", None)
|
||||
if queue_state_str == "active":
|
||||
queue_state = QueueState.ACTIVE
|
||||
elif queue_state_str == "paused_rate_limit":
|
||||
|
|
@ -159,7 +160,7 @@ class _APIFuture(APIFuture[T]): # pyright: ignore[reportUnusedClass]
|
|||
queue_state = QueueState.PAUSED_CAPACITY
|
||||
else:
|
||||
queue_state = QueueState.UNKNOWN
|
||||
self._queue_state_observer.on_queue_state_change(queue_state)
|
||||
self._queue_state_observer.on_queue_state_change(queue_state, queue_state_reason)
|
||||
continue
|
||||
if e.status_code == 410:
|
||||
raise RetryableException(
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import time
|
|||
import traceback
|
||||
import contextlib
|
||||
from collections.abc import Coroutine, Generator
|
||||
from contextlib import AbstractContextManager, contextmanager
|
||||
from contextlib import AbstractContextManager, asynccontextmanager, contextmanager
|
||||
from typing import Any, Awaitable, Callable, TypeVar
|
||||
|
||||
import httpx
|
||||
|
|
@ -97,6 +97,31 @@ class InternalClientHolderThreadSingleton:
|
|||
_internal_client_holder_thread_singleton = InternalClientHolderThreadSingleton()
|
||||
|
||||
|
||||
class BytesSemaphore:
    """Asyncio limiter that meters concurrent work by a byte budget.

    The budget starts at ``max_bytes``. ``acquire(n)`` blocks only while the
    remaining budget is already negative, then subtracts ``n``; leaving the
    context adds ``n`` back and wakes all waiters. Because admission checks
    only the sign of the remaining budget, a single request larger than the
    whole budget is still admitted (it drives the budget negative and holds
    back later callers until it finishes).
    """

    def __init__(self, max_bytes: int):
        # Remaining budget; may legitimately drop below zero (see class docstring).
        self._bytes: int = max_bytes
        self._condition: asyncio.Condition = asyncio.Condition()
        # Most recently scheduled wake-up task.
        self._release_task: asyncio.Task[None] | None = None
        # Strong references to every wake-up task still pending. The event
        # loop only keeps weak references to tasks, so overwriting a single
        # attribute would leave earlier, unfinished tasks unreferenced.
        self._release_tasks: set[asyncio.Task[None]] = set()

    async def _release(self):
        """Wake every coroutine waiting for budget so it can re-check it."""
        async with self._condition:
            self._condition.notify_all()

    @asynccontextmanager
    async def acquire(self, bytes: int):
        """Reserve ``bytes`` of budget for the duration of the ``async with``.

        Args:
            bytes: Amount to subtract from the budget while the context is
                held; the same amount is returned on exit.
        """
        async with self._condition:
            while self._bytes < 0:
                await self._condition.wait()
            self._bytes -= bytes

        try:
            yield
        finally:
            self._bytes += bytes
            # Notify from a separate task rather than awaiting the condition
            # here, so the wake-up is not tied to the (possibly cancelled)
            # caller that is leaving the context.
            wakeup = asyncio.create_task(self._release())
            self._release_task = wakeup
            self._release_tasks.add(wakeup)
            wakeup.add_done_callback(self._release_tasks.discard)
|
||||
|
||||
|
||||
class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
|
||||
def __init__(self, user_metadata: dict[str, str] | None = None, **kwargs: Any) -> None:
|
||||
self._constructor_kwargs = kwargs
|
||||
|
|
@ -106,6 +131,8 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
|
|||
self._client_pools: dict[ClientConnectionPoolType, ClientConnectionPool] = {}
|
||||
self._sample_backoff_until: float | None = None
|
||||
self._sample_dispatch_semaphore: asyncio.Semaphore = asyncio.Semaphore(400)
|
||||
self._sample_dispatch_throttled_semaphore: asyncio.Semaphore = asyncio.Semaphore(10)
|
||||
self._sample_dispatch_bytes_semaphore: BytesSemaphore = BytesSemaphore(5 * 1024 * 1024)
|
||||
self._telemetry: Telemetry | None = None
|
||||
self._session_heartbeat_task: asyncio.Task[None] | None = None
|
||||
session_id, session_heartbeat_task = self.run_coroutine_threadsafe(
|
||||
|
|
@ -120,6 +147,37 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
|
|||
|
||||
self._sampling_client_counter: int = 0
|
||||
|
||||
@asynccontextmanager
|
||||
async def _sample_dispatch_count_rate_limit(self):
|
||||
async with self._sample_dispatch_semaphore:
|
||||
yield
|
||||
|
||||
@asynccontextmanager
|
||||
async def _sample_dispatch_count_throttled_rate_limit(self):
|
||||
async with self._sample_dispatch_throttled_semaphore:
|
||||
yield
|
||||
|
||||
def _sample_backoff_requested_recently(self) -> bool:
|
||||
return self._sample_backoff_until is not None and time.monotonic() - self._sample_backoff_until < 10
|
||||
|
||||
@asynccontextmanager
|
||||
async def _sample_dispatch_bytes_rate_limit(self, bytes: int):
|
||||
if self._sample_backoff_requested_recently():
|
||||
# Rate limit more aggressively if we received backoff response recently
|
||||
bytes *= 20
|
||||
async with self._sample_dispatch_bytes_semaphore.acquire(bytes):
|
||||
yield
|
||||
|
||||
@asynccontextmanager
async def sample_dispatch_rate_limit(self, estimated_bytes_count: int):
    """Apply all sampling dispatch limits around a single request.

    Limiters are entered in order: the request-count limit, then (only when
    a backoff was requested recently) the throttled request-count limit,
    then the byte-budget limit. They are released in reverse order when the
    context exits.

    Args:
        estimated_bytes_count: Estimated request size passed to the
            byte-budget limiter.
    """
    async with contextlib.AsyncExitStack() as limiters:
        enter = limiters.enter_async_context
        await enter(self._sample_dispatch_count_rate_limit())
        if self._sample_backoff_requested_recently():
            await enter(self._sample_dispatch_count_throttled_rate_limit())
        await enter(self._sample_dispatch_bytes_rate_limit(estimated_bytes_count))
        yield
|
||||
|
||||
async def _session_heartbeat(self, session_id: str):
|
||||
SESSION_HEARTBEAT_PERIOD_SEC = 10
|
||||
SESSION_MISSED_HEARTBEAT_WARNING_THRESHOLD_SEC = 60 * 2
|
||||
|
|
@ -305,6 +363,16 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
|
|||
|
||||
raise e
|
||||
|
||||
def estimate_bytes_count_in_chunk(self, chunk: types.ModelInputChunk) -> int:
    """Estimate the byte size of one model-input chunk.

    Returns:
        ``len(chunk.data)`` for an ``ImageChunk``, ``len(chunk.location)``
        for an ``ImageAssetPointerChunk``, and ``chunk.length * 10`` for any
        other chunk.
    """
    if isinstance(chunk, types.ImageChunk):
        payload = chunk.data
    elif isinstance(chunk, types.ImageAssetPointerChunk):
        payload = chunk.location
    else:
        # NOTE(review): the factor of 10 is presumably a rough bytes-per-token
        # estimate for the serialized request — confirm.
        return chunk.length * 10
    return len(payload)
|
||||
|
||||
def estimate_bytes_count_in_model_input(self, model_input: types.ModelInput) -> int:
    """Estimate the byte size of a model input as the sum of its chunk estimates."""
    per_chunk_estimates = map(self.estimate_bytes_count_in_chunk, model_input.chunks)
    return sum(per_chunk_estimates)
|
||||
|
||||
|
||||
def _current_loop() -> asyncio.AbstractEventLoop | None:
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -155,11 +155,12 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
|
|||
include_prompt_logprobs: bool,
|
||||
topk_prompt_logprobs: int = 0,
|
||||
) -> types.SampleResponse:
|
||||
async with self.holder._sample_dispatch_semaphore:
|
||||
estimated_bytes_count = self.holder.estimate_bytes_count_in_model_input(prompt)
|
||||
async with self.holder.sample_dispatch_rate_limit(estimated_bytes_count):
|
||||
while True:
|
||||
if (
|
||||
self.holder._sample_backoff_until is not None
|
||||
and time.time() < self.holder._sample_backoff_until
|
||||
and time.monotonic() < self.holder._sample_backoff_until
|
||||
):
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
|
@ -175,7 +176,8 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
|
|||
if untyped_future is not None:
|
||||
break
|
||||
# Handle backoff
|
||||
self.holder._sample_backoff_until = time.time() + 1
|
||||
backoff_duration = 1 if estimated_bytes_count <= 128 * 1024 else 5
|
||||
self.holder._sample_backoff_until = time.monotonic() + backoff_duration
|
||||
continue
|
||||
|
||||
return await _APIFuture(
|
||||
|
|
@ -298,22 +300,23 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
|
|||
def get_telemetry(self) -> Telemetry | None:
|
||||
return self.holder.get_telemetry()
|
||||
|
||||
def on_queue_state_change(self, queue_state: QueueState) -> None:
|
||||
def on_queue_state_change(self, queue_state: QueueState, queue_state_reason: str | None) -> None:
|
||||
QUEUE_STATE_LOG_INTERVAL = 60
|
||||
if queue_state == QueueState.ACTIVE:
|
||||
return
|
||||
if time.time() - self._last_queue_state_logged < QUEUE_STATE_LOG_INTERVAL:
|
||||
return
|
||||
if queue_state == QueueState.PAUSED_RATE_LIMIT:
|
||||
reason = "concurrent LoRA rate limit hit"
|
||||
elif queue_state == QueueState.PAUSED_CAPACITY:
|
||||
reason = "out of capacity"
|
||||
else:
|
||||
reason = "unknown"
|
||||
if not queue_state_reason:
|
||||
if queue_state == QueueState.PAUSED_RATE_LIMIT:
|
||||
queue_state_reason = "concurrent sampler weights limit hit"
|
||||
elif queue_state == QueueState.PAUSED_CAPACITY:
|
||||
queue_state_reason = "Tinker backend is running short on capacity, please wait"
|
||||
else:
|
||||
queue_state_reason = "unknown"
|
||||
self._last_queue_state_logged = time.time()
|
||||
|
||||
logger.warning(
|
||||
f"Sampling is paused for sampler {self._sampling_session_id}. Reason: {reason}"
|
||||
f"Sampling is paused for sampler {self._sampling_session_id}. Reason: {queue_state_reason}"
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -141,10 +141,11 @@ class ServiceClient(TelemetryProvider):
|
|||
request_type="CreateModel",
|
||||
).result_async()
|
||||
model_id = create_model_response.model_id
|
||||
logger.info(f"Creating TrainingClient for {model_id=}")
|
||||
from .training_client import TrainingClient
|
||||
|
||||
return TrainingClient(self.holder, model_seq_id=model_seq_id, model_id=model_id)
|
||||
training_client = TrainingClient(self.holder, model_seq_id=model_seq_id, model_id=model_id)
|
||||
logger.info(f"TrainingClient initialized for model {model_id}")
|
||||
return training_client
|
||||
|
||||
return self.holder.run_coroutine_threadsafe(_create_lora_training_client_async())
|
||||
|
||||
|
|
|
|||
|
|
@ -40,8 +40,8 @@ if TYPE_CHECKING:
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
# FwdBwdChunkSize
|
||||
MAX_CHUNK_LEN = 128
|
||||
MAX_CHUNK_NUMBER_COUNT = 500000
|
||||
MAX_CHUNK_LEN = 1024
|
||||
MAX_CHUNK_BYTES_COUNT = 5000000
|
||||
MODEL_ID_NOT_SET_ERROR = "model_id must be set before calling forward. Try initializing the TrainingClient with a model_id by either calling create_lora_training_client on the ServiceClient, or initiliazing the TrainingClient with an existing model_id."
|
||||
|
||||
CustomLossFnV1 = Callable[[List[types.Datum], List[Any]], Tuple[Any, Dict[str, float]]]
|
||||
|
|
@ -121,36 +121,27 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
|
|||
assert self.model_id is not None, MODEL_ID_NOT_SET_ERROR
|
||||
return self.model_id
|
||||
|
||||
def _estimate_number_count_in_chunk(self, chunk: types.ModelInputChunk) -> int:
|
||||
if isinstance(chunk, types.ImageChunk):
|
||||
return len(chunk.data)
|
||||
if isinstance(chunk, types.ImageAssetPointerChunk):
|
||||
return len(chunk.location)
|
||||
return chunk.length
|
||||
|
||||
def _estimate_number_count(self, datum: types.Datum) -> int:
|
||||
return sum(
|
||||
self._estimate_number_count_in_chunk(chunk) for chunk in datum.model_input.chunks
|
||||
) + sum(len(value.data) for _, value in datum.loss_fn_inputs.items())
|
||||
def _estimate_bytes_count(self, datum: types.Datum) -> int:
    """Estimate the byte size of one datum.

    Returns:
        The holder's estimate for ``datum.model_input`` plus 10 per element
        of every ``loss_fn_inputs`` value's ``data``.
    """
    model_input_bytes = self.holder.estimate_bytes_count_in_model_input(datum.model_input)
    loss_input_bytes = sum(len(value.data) * 10 for value in datum.loss_fn_inputs.values())
    return model_input_bytes + loss_input_bytes
|
||||
|
||||
def _chunked_requests_generator(
|
||||
self, data: List[types.Datum]
|
||||
) -> Generator[List[types.Datum], None, None]:
|
||||
current_chunk: List[types.Datum] = []
|
||||
current_chunk_number_count = 0
|
||||
current_chunk_bytes_count = 0
|
||||
|
||||
for datum in data:
|
||||
estimated_number_count = self._estimate_number_count(datum)
|
||||
estimated_bytes_count = self._estimate_bytes_count(datum)
|
||||
if (
|
||||
len(current_chunk) > 0
|
||||
and current_chunk_number_count + estimated_number_count > MAX_CHUNK_NUMBER_COUNT
|
||||
and current_chunk_bytes_count + estimated_bytes_count > MAX_CHUNK_BYTES_COUNT
|
||||
) or (len(current_chunk) == MAX_CHUNK_LEN):
|
||||
yield current_chunk
|
||||
current_chunk = []
|
||||
current_chunk_number_count = 0
|
||||
current_chunk_bytes_count = 0
|
||||
|
||||
current_chunk.append(datum)
|
||||
current_chunk_number_count += estimated_number_count
|
||||
current_chunk_bytes_count += estimated_bytes_count
|
||||
|
||||
if len(current_chunk) > 0:
|
||||
yield current_chunk
|
||||
|
|
@ -840,7 +831,7 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
|
|||
def get_telemetry(self) -> Telemetry | None:
|
||||
return self.holder.get_telemetry()
|
||||
|
||||
def on_queue_state_change(self, queue_state: QueueState) -> None:
|
||||
def on_queue_state_change(self, queue_state: QueueState, queue_state_reason: str | None) -> None:
|
||||
QUEUE_STATE_LOG_INTERVAL = 60
|
||||
if queue_state == QueueState.ACTIVE:
|
||||
return
|
||||
|
|
@ -848,13 +839,14 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
|
|||
return
|
||||
self._last_queue_state_logged = time.time()
|
||||
|
||||
if queue_state == QueueState.PAUSED_RATE_LIMIT:
|
||||
reason = "concurrent models rate limit hit"
|
||||
elif queue_state == QueueState.PAUSED_CAPACITY:
|
||||
reason = "out of capacity"
|
||||
else:
|
||||
reason = "unknown"
|
||||
logger.warning(f"Training is paused for {self.model_id}. Reason: {reason}")
|
||||
if not queue_state_reason:
|
||||
if queue_state == QueueState.PAUSED_RATE_LIMIT:
|
||||
queue_state_reason = "concurrent training clients rate limit hit"
|
||||
elif queue_state == QueueState.PAUSED_CAPACITY:
|
||||
queue_state_reason = "Tinker backend is running short on capacity, please wait"
|
||||
else:
|
||||
queue_state_reason = "unknown"
|
||||
logger.warning(f"Training is paused for {self.model_id}. Reason: {queue_state_reason}")
|
||||
|
||||
|
||||
def _get_tokenizer(model_id: types.ModelID, holder: InternalClientHolder) -> PreTrainedTokenizer:
|
||||
|
|
|
|||
|
|
@ -142,8 +142,8 @@ class RetryHandler(Generic[T]): # noqa: UP046
|
|||
# Create a dummy request for the exception (required by APIConnectionError)
|
||||
dummy_request = httpx.Request("GET", "http://localhost")
|
||||
raise tinker.APIConnectionError(
|
||||
message=f"No progress made in {self.config.progress_timeout}s. Requests appear to be stuck.",
|
||||
request=dummy_request,
|
||||
message=f"No progress made in {self.config.progress_timeout}s. Requests appear to be stuck.",
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
|
|
|
|||
|
|
@ -21,6 +21,12 @@ class AdamParams(StrictBase):
|
|||
eps: float = 1e-12
|
||||
"""Term added to the denominator to improve numerical stability"""
|
||||
|
||||
weight_decay: float = 0.0
|
||||
"""Weight decay for the optimizer. Uses decoupled weight decay."""
|
||||
|
||||
grad_clip_norm: float = 0.0
|
||||
"""Gradient clip norm for the optimizer. 0.0 means no clipping."""
|
||||
|
||||
|
||||
class OptimStepRequest(StrictBase):
|
||||
adam_params: AdamParams
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue