mirror of
https://github.com/thinking-machines-lab/tinker.git
synced 2026-04-19 12:58:01 +00:00
Sync contents
This commit is contained in:
parent
0e726a2c70
commit
ad03d44978
8 changed files with 957 additions and 1439 deletions
|
|
@ -1,4 +1,4 @@
|
|||
{
|
||||
"last_synced_sha": "896d014e55e38b513967cd3cb9672200a421ff9f",
|
||||
"last_sync_time": "2026-01-16T05:35:31.063947"
|
||||
"last_synced_sha": "e99a92629bc1b7df95b2ffc0b60a60c8eecddee1",
|
||||
"last_sync_time": "2026-01-21T23:13:02.930293"
|
||||
}
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
[project]
|
||||
name = "tinker"
|
||||
version = "0.8.0"
|
||||
version = "0.8.1"
|
||||
description = "The official Python SDK for the tinker API"
|
||||
readme = "README.md"
|
||||
license = "Apache-2.0"
|
||||
|
|
|
|||
|
|
@ -116,10 +116,12 @@ def get_model_fields(model: type[pydantic.BaseModel]) -> dict[str, FieldInfo]:
|
|||
return model.__fields__ # type: ignore
|
||||
|
||||
|
||||
def model_copy(model: _ModelT, *, deep: bool = False) -> _ModelT:
|
||||
def model_copy(
|
||||
model: _ModelT, *, deep: bool = False, update: dict[str, Any] | None = None
|
||||
) -> _ModelT:
|
||||
if PYDANTIC_V2:
|
||||
return model.model_copy(deep=deep)
|
||||
return model.copy(deep=deep) # type: ignore
|
||||
return model.model_copy(deep=deep, update=update) # type: ignore
|
||||
return model.copy(deep=deep, update=update) # type: ignore
|
||||
|
||||
|
||||
def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str:
|
||||
|
|
|
|||
|
|
@ -370,6 +370,8 @@ def _load_tokenizer_from_model_info(
|
|||
"""
|
||||
from transformers.models.auto.tokenization_auto import AutoTokenizer
|
||||
|
||||
model_name = model_name.split(":")[0]
|
||||
|
||||
# Use tokenizer_id if provided, otherwise fall back to heuristic logic
|
||||
kwargs = {}
|
||||
if tokenizer_id is None:
|
||||
|
|
|
|||
|
|
@ -27,15 +27,12 @@ from ..sync_only import sync_only
|
|||
from .sampling_client import SamplingClient, _load_tokenizer_from_model_info
|
||||
|
||||
try:
|
||||
import torch # type: ignore[import-not-found]
|
||||
|
||||
_HAVE_TORCH = True
|
||||
import torch
|
||||
except ImportError:
|
||||
_HAVE_TORCH = False
|
||||
torch = None
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
from ..internal_client_holder import InternalClientHolder
|
||||
|
|
@ -372,13 +369,13 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
|
|||
self, data: List[types.Datum], loss_fn: CustomLossFnV1
|
||||
) -> APIFuture[types.ForwardBackwardOutput]:
|
||||
"""Async version of forward_backward_custom."""
|
||||
if not _HAVE_TORCH:
|
||||
if torch is None:
|
||||
raise ImportError("PyTorch is not installed. Cannot run custom forward_backward.")
|
||||
|
||||
# First do a forward pass and get logprobs
|
||||
forward_future = await self.forward_async(data, "cross_entropy")
|
||||
forward_result = await forward_future.result_async()
|
||||
logprobs_list: List["torch.Tensor"] = []
|
||||
logprobs_list = []
|
||||
for out in forward_result.loss_fn_outputs:
|
||||
logprob = torch.tensor(out["logprobs"].data).clone().detach().requires_grad_(True)
|
||||
logprobs_list.append(logprob)
|
||||
|
|
@ -423,6 +420,11 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
|
|||
def optim_step(self, adam_params: types.AdamParams) -> APIFuture[types.OptimStepResponse]:
|
||||
"""Update model parameters using Adam optimizer.
|
||||
|
||||
The Adam optimizer used by tinker is identical
|
||||
to [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html).
|
||||
Note that unlike PyTorch, Tinker's default weight decay value is 0.0 (no weight decay).
|
||||
|
||||
|
||||
Args:
|
||||
- `adam_params`: Adam optimizer parameters (learning_rate, betas, eps, weight_decay)
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||
import httpx
|
||||
|
||||
from .._base_client import make_request_options
|
||||
from .._compat import model_dump
|
||||
from .._compat import model_copy, model_dump
|
||||
from .._resource import AsyncAPIResource
|
||||
from .._types import NOT_GIVEN, Body, Headers, NotGiven, Query
|
||||
from ..types.create_model_request import CreateModelRequest
|
||||
|
|
@ -104,12 +104,19 @@ class AsyncModelsResource(AsyncAPIResource):
|
|||
if max_retries is not NOT_GIVEN:
|
||||
options["max_retries"] = max_retries
|
||||
|
||||
return await self._post(
|
||||
result = await self._post(
|
||||
"/api/v1/get_info",
|
||||
body=model_dump(request, exclude_unset=True, mode="json"),
|
||||
options=options,
|
||||
cast_to=GetInfoResponse,
|
||||
)
|
||||
if result.model_data.tokenizer_id:
|
||||
tokenizer_id = result.model_data.tokenizer_id.split(":")[0]
|
||||
updated_model_data = model_copy(
|
||||
result.model_data, update={"tokenizer_id": tokenizer_id}
|
||||
)
|
||||
result = model_copy(result, update={"model_data": updated_model_data})
|
||||
return result
|
||||
|
||||
async def unload(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ class AdamParams(StrictBase):
|
|||
"""Weight decay for the optimizer. Uses decoupled weight decay."""
|
||||
|
||||
grad_clip_norm: float = 0.0
|
||||
"""Gradient clip norm for the optimizer. 0.0 means no clipping."""
|
||||
"""Maximum global gradient norm. If the global gradient norm is greater than this value, it will be clipped to this value. 0.0 means no clipping."""
|
||||
|
||||
|
||||
class OptimStepRequest(StrictBase):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue