mirror of
https://github.com/thinking-machines-lab/tinker.git
synced 2026-04-19 12:58:01 +00:00
making torch an optional dependency (#15)
This commit is contained in:
parent
8534f4b4b8
commit
0e726a2c70
4 changed files with 1437 additions and 932 deletions
|
|
@ -16,11 +16,11 @@ dependencies = [
|
|||
"distro>=1.7.0, <2",
|
||||
"sniffio",
|
||||
"numpy",
|
||||
"torch",
|
||||
"transformers",
|
||||
"rich>=13.0.0",
|
||||
"click>=8.0.0",
|
||||
]
|
||||
|
||||
requires-python = ">= 3.11"
|
||||
classifiers = [
|
||||
"Typing :: Typed",
|
||||
|
|
@ -49,6 +49,7 @@ Documentation = "https://tinker-docs.thinkingmachines.ai/"
|
|||
|
||||
[project.optional-dependencies]
|
||||
aiohttp = ["aiohttp", "httpx_aiohttp>=0.1.8"]
|
||||
torch = ["torch"]
|
||||
|
||||
[tool.uv]
|
||||
managed = true
|
||||
|
|
|
|||
|
|
@ -26,7 +26,16 @@ from ..retry_handler import RetryConfig
|
|||
from ..sync_only import sync_only
|
||||
from .sampling_client import SamplingClient, _load_tokenizer_from_model_info
|
||||
|
||||
try:
|
||||
import torch # type: ignore[import-not-found]
|
||||
|
||||
_HAVE_TORCH = True
|
||||
except ImportError:
|
||||
_HAVE_TORCH = False
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
from ..internal_client_holder import InternalClientHolder
|
||||
|
|
@ -363,12 +372,13 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
|
|||
self, data: List[types.Datum], loss_fn: CustomLossFnV1
|
||||
) -> APIFuture[types.ForwardBackwardOutput]:
|
||||
"""Async version of forward_backward_custom."""
|
||||
import torch
|
||||
if not _HAVE_TORCH:
|
||||
raise ImportError("PyTorch is not installed. Cannot run custom forward_backward.")
|
||||
|
||||
# First do a forward pass and get logprobs
|
||||
forward_future = await self.forward_async(data, "cross_entropy")
|
||||
forward_result = await forward_future.result_async()
|
||||
logprobs_list: List[torch.Tensor] = []
|
||||
logprobs_list: List["torch.Tensor"] = []
|
||||
for out in forward_result.loss_fn_outputs:
|
||||
logprob = torch.tensor(out["logprobs"].data).clone().detach().requires_grad_(True)
|
||||
logprobs_list.append(logprob)
|
||||
|
|
|
|||
|
|
@ -86,7 +86,6 @@ def _convert_tensor_dtype_to_torch(dtype: TensorDtype) -> "torch.dtype":
|
|||
"""Convert TensorDtype to torch dtype."""
|
||||
if not _HAVE_TORCH:
|
||||
raise ImportError("PyTorch is not installed. Cannot convert to torch dtype.")
|
||||
import torch
|
||||
|
||||
if dtype == "float32":
|
||||
return torch.float32
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue