Mirror of https://github.com/thinking-machines-lab/tinker.git
Synced 2026-04-19 12:58:01 +00:00

Sync contents

Commit ad03d44978 (parent 0e726a2c70)
8 changed files with 957 additions and 1439 deletions
@@ -1,4 +1,4 @@
 {
-    "last_synced_sha": "896d014e55e38b513967cd3cb9672200a421ff9f",
-    "last_sync_time": "2026-01-16T05:35:31.063947"
+    "last_synced_sha": "e99a92629bc1b7df95b2ffc0b60a60c8eecddee1",
+    "last_sync_time": "2026-01-21T23:13:02.930293"
 }
@@ -1,6 +1,6 @@
 [project]
 name = "tinker"
-version = "0.8.0"
+version = "0.8.1"
 description = "The official Python SDK for the tinker API"
 readme = "README.md"
 license = "Apache-2.0"
@@ -116,10 +116,12 @@ def get_model_fields(model: type[pydantic.BaseModel]) -> dict[str, FieldInfo]:
     return model.__fields__  # type: ignore


-def model_copy(model: _ModelT, *, deep: bool = False) -> _ModelT:
+def model_copy(
+    model: _ModelT, *, deep: bool = False, update: dict[str, Any] | None = None
+) -> _ModelT:
     if PYDANTIC_V2:
-        return model.model_copy(deep=deep)
-    return model.copy(deep=deep)  # type: ignore
+        return model.model_copy(deep=deep, update=update)  # type: ignore
+    return model.copy(deep=deep, update=update)  # type: ignore


 def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str:
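The new `update` parameter mirrors pydantic's own copy semantics: the returned copy has the given fields overwritten, and the original model is untouched. A minimal sketch of that behavior on a toy model (`ServerInfo` is illustrative, not an SDK type):

    import pydantic

    class ServerInfo(pydantic.BaseModel):
        name: str
        version: str

    info = ServerInfo(name="tinker", version="0.8.0")
    # Copy with one field overridden; the original stays as-is.
    bumped = info.model_copy(update={"version": "0.8.1"})
    assert info.version == "0.8.0"
    assert bumped.version == "0.8.1"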
@@ -370,6 +370,8 @@ def _load_tokenizer_from_model_info(
     """
    from transformers.models.auto.tokenization_auto import AutoTokenizer

+    model_name = model_name.split(":")[0]
+
     # Use tokenizer_id if provided, otherwise fall back to heuristic logic
     kwargs = {}
     if tokenizer_id is None:
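The added line strips any colon-delimited suffix from `model_name` before the tokenizer lookup, matching the normalization applied to `tokenizer_id` elsewhere in this commit. A quick illustration (the names below are made up):

    # Colon-suffixed names collapse to the base name.
    assert "org/base-model:variant".split(":")[0] == "org/base-model"
    # Names without a colon pass through unchanged.
    assert "org/base-model".split(":")[0] == "org/base-model"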
@@ -27,15 +27,12 @@ from ..sync_only import sync_only
 from .sampling_client import SamplingClient, _load_tokenizer_from_model_info

 try:
-    import torch  # type: ignore[import-not-found]
-
-    _HAVE_TORCH = True
+    import torch
 except ImportError:
-    _HAVE_TORCH = False
+    torch = None


 if TYPE_CHECKING:
-    import torch
     from transformers.tokenization_utils import PreTrainedTokenizer

 from ..internal_client_holder import InternalClientHolder
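With `torch = None` as the fallback, the module object itself doubles as the availability flag, so the separate `_HAVE_TORCH` boolean and the `TYPE_CHECKING` re-import both become unnecessary. A self-contained sketch of the pattern:

    # Optional dependency: bind the name either to the module or to None.
    try:
        import torch
    except ImportError:
        torch = None

    def make_tensor():
        # The None check replaces the old `if not _HAVE_TORCH:` guard.
        if torch is None:
            raise ImportError("PyTorch is not installed.")
        return torch.zeros(1)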
@@ -372,13 +369,13 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
         self, data: List[types.Datum], loss_fn: CustomLossFnV1
     ) -> APIFuture[types.ForwardBackwardOutput]:
         """Async version of forward_backward_custom."""
-        if not _HAVE_TORCH:
+        if torch is None:
             raise ImportError("PyTorch is not installed. Cannot run custom forward_backward.")

         # First do a forward pass and get logprobs
         forward_future = await self.forward_async(data, "cross_entropy")
         forward_result = await forward_future.result_async()
-        logprobs_list: List["torch.Tensor"] = []
+        logprobs_list = []
         for out in forward_result.loss_fn_outputs:
             logprob = torch.tensor(out["logprobs"].data).clone().detach().requires_grad_(True)
             logprobs_list.append(logprob)
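The forward pass returns logprobs as plain data; the `.clone().detach().requires_grad_(True)` chain rebuilds them as fresh autograd leaves so the custom loss can be differentiated. A toy demonstration (the list literal stands in for `out["logprobs"].data`):

    import torch

    raw = [-0.5, -1.2, -0.1]  # stand-in for out["logprobs"].data
    logprob = torch.tensor(raw).clone().detach().requires_grad_(True)
    loss = -logprob.sum()  # a toy custom loss
    loss.backward()
    assert logprob.grad is not None  # gradients reach the leaf tensor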
@@ -423,6 +420,11 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
     def optim_step(self, adam_params: types.AdamParams) -> APIFuture[types.OptimStepResponse]:
         """Update model parameters using Adam optimizer.

+        The Adam optimizer used by tinker is identical
+        to [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html).
+        Note that unlike PyTorch, Tinker's default weight decay value is 0.0 (no weight decay).
+
         Args:
         - `adam_params`: Adam optimizer parameters (learning_rate, betas, eps, weight_decay)
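Per the new docstring, matching Tinker's optimizer locally with PyTorch means passing weight_decay=0.0 explicitly, since torch.optim.AdamW defaults it to 0.01. A hedged sketch (the hyperparameter values are placeholders):

    import torch

    params = [torch.nn.Parameter(torch.zeros(4))]
    # weight_decay=0.0 matches Tinker's default of no weight decay.
    opt = torch.optim.AdamW(
        params, lr=1e-4, betas=(0.9, 0.95), eps=1e-8, weight_decay=0.0
    )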
@@ -3,7 +3,7 @@ from __future__ import annotations
 import httpx

 from .._base_client import make_request_options
-from .._compat import model_dump
+from .._compat import model_copy, model_dump
 from .._resource import AsyncAPIResource
 from .._types import NOT_GIVEN, Body, Headers, NotGiven, Query
 from ..types.create_model_request import CreateModelRequest
@@ -104,12 +104,19 @@ class AsyncModelsResource(AsyncAPIResource):
         if max_retries is not NOT_GIVEN:
             options["max_retries"] = max_retries

-        return await self._post(
+        result = await self._post(
             "/api/v1/get_info",
             body=model_dump(request, exclude_unset=True, mode="json"),
             options=options,
             cast_to=GetInfoResponse,
         )
+        if result.model_data.tokenizer_id:
+            tokenizer_id = result.model_data.tokenizer_id.split(":")[0]
+            updated_model_data = model_copy(
+                result.model_data, update={"tokenizer_id": tokenizer_id}
+            )
+            result = model_copy(result, update={"model_data": updated_model_data})
+        return result

     async def unload(
         self,
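Rather than mutating `result.model_data.tokenizer_id` in place, the handler rebuilds the nested model and then the outer response via `model_copy(update=...)`, keeping the response objects immutable. A toy sketch of that inside-out update (the model classes here are stand-ins, not the SDK's generated types):

    import pydantic

    class ModelData(pydantic.BaseModel):
        tokenizer_id: str | None = None

    class GetInfoResponse(pydantic.BaseModel):  # stand-in class
        model_config = pydantic.ConfigDict(protected_namespaces=())
        model_data: ModelData

    result = GetInfoResponse(model_data=ModelData(tokenizer_id="base:variant"))
    if result.model_data.tokenizer_id:
        base_id = result.model_data.tokenizer_id.split(":")[0]
        updated = result.model_data.model_copy(update={"tokenizer_id": base_id})
        result = result.model_copy(update={"model_data": updated})
    assert result.model_data.tokenizer_id == "base"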
@@ -26,7 +26,7 @@ class AdamParams(StrictBase):
     """Weight decay for the optimizer. Uses decoupled weight decay."""

     grad_clip_norm: float = 0.0
-    """Gradient clip norm for the optimizer. 0.0 means no clipping."""
+    """Maximum global gradient norm. If the global gradient norm is greater than this value, it will be clipped to this value. 0.0 means no clipping."""


 class OptimStepRequest(StrictBase):
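The reworded docstring pins down the semantics: grad_clip_norm caps the global (all-parameter) gradient norm rather than clipping individual gradient entries. This is analogous to PyTorch's torch.nn.utils.clip_grad_norm_; a small sketch:

    import torch

    p = torch.nn.Parameter(torch.zeros(3))
    p.grad = torch.tensor([3.0, 4.0, 0.0])  # global grad norm = 5.0
    torch.nn.utils.clip_grad_norm_([p], max_norm=1.0)
    # The whole gradient is rescaled so its global norm equals the cap.
    assert torch.allclose(p.grad.norm(), torch.tensor(1.0))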