making torch an optional dependency (#15)

This commit is contained in:
Alec Gunny 2026-01-20 09:44:24 -08:00 committed by GitHub
parent 8534f4b4b8
commit 0e726a2c70
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 1437 additions and 932 deletions

View file

@ -16,11 +16,11 @@ dependencies = [
"distro>=1.7.0, <2",
"sniffio",
"numpy",
"torch",
"transformers",
"rich>=13.0.0",
"click>=8.0.0",
]
requires-python = ">= 3.11"
classifiers = [
"Typing :: Typed",
@ -49,6 +49,7 @@ Documentation = "https://tinker-docs.thinkingmachines.ai/"
[project.optional-dependencies]
aiohttp = ["aiohttp", "httpx_aiohttp>=0.1.8"]
torch = ["torch"]
[tool.uv]
managed = true

View file

@ -26,7 +26,16 @@ from ..retry_handler import RetryConfig
from ..sync_only import sync_only
from .sampling_client import SamplingClient, _load_tokenizer_from_model_info
try:
import torch # type: ignore[import-not-found]
_HAVE_TORCH = True
except ImportError:
_HAVE_TORCH = False
if TYPE_CHECKING:
import torch
from transformers.tokenization_utils import PreTrainedTokenizer
from ..internal_client_holder import InternalClientHolder
@ -363,12 +372,13 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
self, data: List[types.Datum], loss_fn: CustomLossFnV1
) -> APIFuture[types.ForwardBackwardOutput]:
"""Async version of forward_backward_custom."""
import torch
if not _HAVE_TORCH:
raise ImportError("PyTorch is not installed. Cannot run custom forward_backward.")
# First do a forward pass and get logprobs
forward_future = await self.forward_async(data, "cross_entropy")
forward_result = await forward_future.result_async()
logprobs_list: List[torch.Tensor] = []
logprobs_list: List["torch.Tensor"] = []
for out in forward_result.loss_fn_outputs:
logprob = torch.tensor(out["logprobs"].data).clone().detach().requires_grad_(True)
logprobs_list.append(logprob)

View file

@ -86,7 +86,6 @@ def _convert_tensor_dtype_to_torch(dtype: TensorDtype) -> "torch.dtype":
"""Convert TensorDtype to torch dtype."""
if not _HAVE_TORCH:
raise ImportError("PyTorch is not installed. Cannot convert to torch dtype.")
import torch
if dtype == "float32":
return torch.float32

2351
uv.lock generated

File diff suppressed because it is too large Load diff