Sync contents

Andrii Grynenko 2026-01-21 23:13:02 +00:00
parent 0e726a2c70
commit ad03d44978
8 changed files with 957 additions and 1439 deletions

View file

@@ -1,4 +1,4 @@
 {
-    "last_synced_sha": "896d014e55e38b513967cd3cb9672200a421ff9f",
-    "last_sync_time": "2026-01-16T05:35:31.063947"
+    "last_synced_sha": "e99a92629bc1b7df95b2ffc0b60a60c8eecddee1",
+    "last_sync_time": "2026-01-21T23:13:02.930293"
 }

View file

@ -1,6 +1,6 @@
[project] [project]
name = "tinker" name = "tinker"
version = "0.8.0" version = "0.8.1"
description = "The official Python SDK for the tinker API" description = "The official Python SDK for the tinker API"
readme = "README.md" readme = "README.md"
license = "Apache-2.0" license = "Apache-2.0"

View file

@@ -116,10 +116,12 @@ def get_model_fields(model: type[pydantic.BaseModel]) -> dict[str, FieldInfo]:
     return model.__fields__  # type: ignore


-def model_copy(model: _ModelT, *, deep: bool = False) -> _ModelT:
+def model_copy(
+    model: _ModelT, *, deep: bool = False, update: dict[str, Any] | None = None
+) -> _ModelT:
     if PYDANTIC_V2:
-        return model.model_copy(deep=deep)
-    return model.copy(deep=deep)  # type: ignore
+        return model.model_copy(deep=deep, update=update)  # type: ignore
+    return model.copy(deep=deep, update=update)  # type: ignore


 def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str:
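For reference, a minimal sketch of what the extended model_copy helper enables on the pydantic v2 path (Point and its fields are made-up examples, not SDK types):

import pydantic

class Point(pydantic.BaseModel):
    x: int
    y: int

p = Point(x=1, y=2)
# update= replaces selected fields on the copy; the original is untouched
q = p.model_copy(update={"y": 5})
assert (q.x, q.y) == (1, 5) and p.y == 2

On pydantic v1 the helper routes the same update dict through BaseModel.copy(update=...), so callers get one interface for both major versions.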

View file

@@ -370,6 +370,8 @@ def _load_tokenizer_from_model_info(
     """
     from transformers.models.auto.tokenization_auto import AutoTokenizer

+    model_name = model_name.split(":")[0]
+
     # Use tokenizer_id if provided, otherwise fall back to heuristic logic
     kwargs = {}
     if tokenizer_id is None:
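The inserted line strips any colon-delimited suffix from the model name before the tokenizer is resolved; a small illustration (the model name is hypothetical):

model_name = "org/base-model:some-variant"  # hypothetical qualified name
base_name = model_name.split(":")[0]
assert base_name == "org/base-model"
assert "plain-name".split(":")[0] == "plain-name"  # names without ":" pass through unchanged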

View file

@@ -27,15 +27,12 @@ from ..sync_only import sync_only
 from .sampling_client import SamplingClient, _load_tokenizer_from_model_info

 try:
-    import torch  # type: ignore[import-not-found]
-
-    _HAVE_TORCH = True
+    import torch
 except ImportError:
-    _HAVE_TORCH = False
+    torch = None

 if TYPE_CHECKING:
-    import torch
     from transformers.tokenization_utils import PreTrainedTokenizer

     from ..internal_client_holder import InternalClientHolder
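This hunk replaces the separate _HAVE_TORCH boolean with the module-or-None sentinel pattern; a self-contained sketch of that pattern (the guard function name is illustrative), which runs whether or not torch is installed:

try:
    import torch
except ImportError:
    torch = None  # sentinel: torch-dependent code paths check this name

def require_torch() -> None:
    # one binding carries both the availability check and the module itself
    if torch is None:
        raise ImportError("PyTorch is not installed.")

Compared with a flag, the sentinel keeps availability and access in a single name, so a stale flag can never disagree with the import state.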
@@ -372,13 +369,13 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
         self, data: List[types.Datum], loss_fn: CustomLossFnV1
     ) -> APIFuture[types.ForwardBackwardOutput]:
         """Async version of forward_backward_custom."""
-        if not _HAVE_TORCH:
+        if torch is None:
             raise ImportError("PyTorch is not installed. Cannot run custom forward_backward.")

         # First do a forward pass and get logprobs
         forward_future = await self.forward_async(data, "cross_entropy")
         forward_result = await forward_future.result_async()
-        logprobs_list: List["torch.Tensor"] = []
+        logprobs_list = []
         for out in forward_result.loss_fn_outputs:
             logprob = torch.tensor(out["logprobs"].data).clone().detach().requires_grad_(True)
             logprobs_list.append(logprob)
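The loop above turns the returned logprob values into differentiable leaf tensors; a standalone sketch of that pattern (requires torch; the numeric values and the stand-in loss are made up):

import torch

raw_logprobs = [[-0.1, -2.3], [-0.7]]  # made-up per-datum logprobs
logprobs_list = []
for row in raw_logprobs:
    # clone().detach().requires_grad_(True) yields a leaf tensor that a
    # custom loss can backpropagate into
    logprobs_list.append(torch.tensor(row).clone().detach().requires_grad_(True))

loss = sum(lp.sum() for lp in logprobs_list)  # stand-in for a custom loss
loss.backward()
assert all(lp.grad is not None for lp in logprobs_list)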
@@ -423,6 +420,11 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
     def optim_step(self, adam_params: types.AdamParams) -> APIFuture[types.OptimStepResponse]:
         """Update model parameters using Adam optimizer.

+        The Adam optimizer used by tinker is identical
+        to [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html).
+        Note that unlike PyTorch, Tinker's default weight decay value is 0.0 (no weight decay).
+
         Args:
         - `adam_params`: Adam optimizer parameters (learning_rate, betas, eps, weight_decay)
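Given the new docstring's claim of AdamW equivalence, matching tinker's default weight decay on the PyTorch side would look like this (a hedged sketch; lr, betas, and eps here are arbitrary example values, not tinker defaults):

import torch

model = torch.nn.Linear(16, 16)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.95),
    eps=1e-8,
    weight_decay=0.0,  # tinker's default; PyTorch's own default is 0.01
)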

View file

@@ -3,7 +3,7 @@ from __future__ import annotations
 import httpx

 from .._base_client import make_request_options
-from .._compat import model_dump
+from .._compat import model_copy, model_dump
 from .._resource import AsyncAPIResource
 from .._types import NOT_GIVEN, Body, Headers, NotGiven, Query
 from ..types.create_model_request import CreateModelRequest
@@ -104,12 +104,19 @@ class AsyncModelsResource(AsyncAPIResource):
         if max_retries is not NOT_GIVEN:
             options["max_retries"] = max_retries

-        return await self._post(
+        result = await self._post(
             "/api/v1/get_info",
             body=model_dump(request, exclude_unset=True, mode="json"),
             options=options,
             cast_to=GetInfoResponse,
         )
+        if result.model_data.tokenizer_id:
+            tokenizer_id = result.model_data.tokenizer_id.split(":")[0]
+            updated_model_data = model_copy(
+                result.model_data, update={"tokenizer_id": tokenizer_id}
+            )
+            result = model_copy(result, update={"model_data": updated_model_data})
+        return result

     async def unload(
         self,
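A standalone illustration of the nested copy-and-update idiom used above, assuming pydantic v2 and Python 3.10+ (Outer and Inner are toy models standing in for GetInfoResponse and its model_data):

import pydantic

class Inner(pydantic.BaseModel):
    tokenizer_id: str | None = None

class Outer(pydantic.BaseModel):
    model_data: Inner

result = Outer(model_data=Inner(tokenizer_id="base:variant"))
if result.model_data.tokenizer_id:
    base = result.model_data.tokenizer_id.split(":")[0]
    # copy the inner model first, then splice it into a copy of the outer one
    inner = result.model_data.model_copy(update={"tokenizer_id": base})
    result = result.model_copy(update={"model_data": inner})
assert result.model_data.tokenizer_id == "base"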

View file

@@ -26,7 +26,7 @@ class AdamParams(StrictBase):
     """Weight decay for the optimizer. Uses decoupled weight decay."""

     grad_clip_norm: float = 0.0
-    """Gradient clip norm for the optimizer. 0.0 means no clipping."""
+    """Maximum global gradient norm. If the global gradient norm is greater than this value, it will be clipped to this value. 0.0 means no clipping."""


 class OptimStepRequest(StrictBase):
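The reworded docstring describes standard global-norm gradient clipping; the PyTorch equivalent, for reference (the threshold of 1.0 is an arbitrary example value):

import torch

model = torch.nn.Linear(8, 1)
model(torch.randn(4, 8)).sum().backward()
# If the global (all-parameter) gradient norm exceeds max_norm, every
# gradient is rescaled so the global norm equals max_norm.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)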

uv.lock (generated)

File diff suppressed because it is too large.