From 2a1f5e32402d805f580e8888c10f0dd9e1e9a599 Mon Sep 17 00:00:00 2001 From: Andrii Grynenko Date: Fri, 17 Apr 2026 00:45:03 +0000 Subject: [PATCH] Sync contents --- .ruff.toml | 4 + .sync_state | 4 +- pyproject.toml | 3 +- src/tinker/lib/api_future_impl.py | 35 +++++ src/tinker/lib/internal_client_holder.py | 6 +- src/tinker/lib/internal_client_holder_test.py | 66 ++++++++- src/tinker/proto/__init__.py | 0 src/tinker/proto/response_conv.py | 110 +++++++++++++++ src/tinker/proto/tinker_public_pb2.py | 31 ++++ src/tinker/proto/tinker_public_pb2.pyi | 132 ++++++++++++++++++ 10 files changed, 384 insertions(+), 7 deletions(-) create mode 100644 src/tinker/proto/__init__.py create mode 100644 src/tinker/proto/response_conv.py create mode 100644 src/tinker/proto/tinker_public_pb2.py create mode 100644 src/tinker/proto/tinker_public_pb2.pyi diff --git a/.ruff.toml b/.ruff.toml index fbc2dcd..3df701a 100644 --- a/.ruff.toml +++ b/.ruff.toml @@ -1,5 +1,9 @@ line-length = 100 force-exclude = true +exclude = [ + # Generated protobuf code + "src/tinker/proto", +] [lint] diff --git a/.sync_state b/.sync_state index 70ffc85..df43c8c 100644 --- a/.sync_state +++ b/.sync_state @@ -1,4 +1,4 @@ { - "last_synced_sha": "d117d1692821faa297ea5d2ee7e4dc21b5c8bd0a", - "last_sync_time": "2026-04-14T00:00:48.831738" + "last_synced_sha": "774d92ed933004d66c3e86070e0a0d2e713a0a31", + "last_sync_time": "2026-04-17T00:45:03.498497" } \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index bc3516e..5bff632 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "tinker" -version = "0.18.0" +version = "0.18.1" description = "The official Python SDK for the tinker API" readme = "README.md" license = "Apache-2.0" @@ -16,6 +16,7 @@ dependencies = [ "distro>=1.7.0, <2", "sniffio", "numpy", + "protobuf>=4.21", "transformers", "rich>=13.0.0", "click>=8.0.0", diff --git a/src/tinker/lib/api_future_impl.py b/src/tinker/lib/api_future_impl.py index 4f6ce50..de872e5 100644 --- a/src/tinker/lib/api_future_impl.py +++ b/src/tinker/lib/api_future_impl.py @@ -26,6 +26,8 @@ from .sync_only import sync_only if TYPE_CHECKING: from tinker.lib.internal_client_holder import InternalClientHolder +from tinker.proto.response_conv import PROTO_SUPPORTED_TYPES, deserialize_proto_response + logger = logging.getLogger(__name__) T = TypeVar("T") @@ -115,6 +117,8 @@ class _APIFuture(APIFuture[T]): # pyright: ignore[reportUnusedClass] "X-Tinker-Request-Iteration": str(iteration), "X-Tinker-Request-Type": self.request_type, } + if self.model_cls in PROTO_SUPPORTED_TYPES: + headers["Accept"] = "application/x-protobuf, application/json" if iteration == 0: headers["X-Tinker-Create-Promise-Roundtrip-Time"] = str( self.request_queue_roundtrip_time @@ -213,6 +217,37 @@ class _APIFuture(APIFuture[T]): # pyright: ignore[reportUnusedClass] connection_error_retries += 1 continue + # Proto response path: server returned protobuf bytes + content_type = response.headers.get("content-type", "") + if "application/x-protobuf" in content_type: + proto_bytes = response.http_response.content + try: + self._cached_result = deserialize_proto_response( + proto_bytes, self.model_cls + ) + return cast(T, self._cached_result) + except Exception as e: + if telemetry := self.get_telemetry(): + current_time = time.time() + telemetry.log( + "APIFuture.result_async.proto_deserialization_error", + event_data={ + "request_id": self.request_id, + "request_type": self.request_type, + "exception": str(e), + "exception_type": type(e).__name__, + "proto_bytes_len": len(proto_bytes), + "model_cls": str(self.model_cls), + "iteration": iteration, + "elapsed_time": current_time - start_time, + }, + severity="ERROR", + ) + raise ValueError( + f"Proto deserialization failed: {e} for {self.request_id=} and expected type {self.model_cls=}" + ) from e + + # JSON response path (existing) result_dict: Any = await response.json() if "type" in result_dict and result_dict["type"] == "try_again": diff --git a/src/tinker/lib/internal_client_holder.py b/src/tinker/lib/internal_client_holder.py index 94ef5cd..75ebde4 100644 --- a/src/tinker/lib/internal_client_holder.py +++ b/src/tinker/lib/internal_client_holder.py @@ -191,7 +191,9 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider): _jwt_auth_seed: str | None = None, **kwargs: Any, ) -> None: - self._api_key = api_key + # Resolve from env now so shadow_kwargs carries the actual credential + # across pickle boundaries (workers may not have the env var set). + self._api_key = api_key or os.environ.get("TINKER_API_KEY") self._constructor_kwargs = dict(kwargs) self._loop: asyncio.AbstractEventLoop = _internal_client_holder_thread_singleton.get_loop() self._client_pools: dict[ClientConnectionPoolType, ClientConnectionPool] = {} @@ -224,7 +226,7 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider): if not self._client_config.pjwt_auth_enabled: # Without JWT exchange, only API keys are accepted by the server. # Replace any cmd-based provider with a plain API key provider. - self._default_auth = ApiKeyAuthProvider(api_key=api_key) + self._default_auth = ApiKeyAuthProvider(api_key=self._api_key) else: # Create a dedicated pool for JWT exchange with the appropriate # credential provider. The lambda captures the pool so it stays alive. diff --git a/src/tinker/lib/internal_client_holder_test.py b/src/tinker/lib/internal_client_holder_test.py index b6bdfc9..9411008 100644 --- a/src/tinker/lib/internal_client_holder_test.py +++ b/src/tinker/lib/internal_client_holder_test.py @@ -4,11 +4,11 @@ from __future__ import annotations import asyncio from typing import Any -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock, MagicMock, patch import pytest -from tinker.lib._auth_token_provider import AuthTokenProvider +from tinker.lib._auth_token_provider import ApiKeyAuthProvider, AuthTokenProvider from tinker.lib.internal_client_holder import ClientConnectionPool, InternalClientHolder from tinker.types.client_config_response import ClientConfigResponse as _ClientConfigResponse @@ -94,3 +94,65 @@ async def test_fetch_client_config_passes_sdk_version( call_kwargs = holder._cm.__enter__.return_value.service.client_config.call_args assert call_kwargs.kwargs["request"].sdk_version == tinker_sdk_version + + +# --------------------------------------------------------------------------- +# Pickle round-trip: ambient TINKER_API_KEY must travel with the pickle +# --------------------------------------------------------------------------- + + +def _make_holder(api_key: str | None = None) -> InternalClientHolder: + """Build a primary InternalClientHolder with server calls stubbed out.""" + with ( + patch.object( + InternalClientHolder, + "_fetch_client_config", + new_callable=AsyncMock, + # pjwt_auth_enabled=False → plain API-key auth path + return_value=_ClientConfigResponse(pjwt_auth_enabled=False), + ), + patch.object( + InternalClientHolder, + "_create_session", + new_callable=AsyncMock, + return_value="sess-pickle-test", + ), + patch.object( + InternalClientHolder, + "_start_heartbeat", + new_callable=AsyncMock, + ), + ): + holder = InternalClientHolder(api_key=api_key) + holder._session_heartbeat_task = MagicMock() + return holder + + +def test_sampling_client_pickle_roundtrip_without_env_var( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Pickle a SamplingClient created with ambient TINKER_API_KEY, then + unpickle it in an environment without the env var (simulating a worker + process). The credential must travel inside the pickle payload.""" + import pickle + + from tinker.lib.public_interfaces.sampling_client import SamplingClient + + monkeypatch.setenv("TINKER_API_KEY", "tml-key-from-env") + holder = _make_holder(api_key=None) + client = SamplingClient(holder, sampling_session_id="samp-1") + + payload = pickle.dumps(client) + + # Simulate the worker: no TINKER_API_KEY available. + monkeypatch.delenv("TINKER_API_KEY") + + with patch.object( + InternalClientHolder, + "_start_heartbeat", + new_callable=AsyncMock, + ): + restored = pickle.loads(payload) + + assert isinstance(restored.holder._default_auth, ApiKeyAuthProvider) + assert restored.holder._default_auth._token == "tml-key-from-env" diff --git a/src/tinker/proto/__init__.py b/src/tinker/proto/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/tinker/proto/response_conv.py b/src/tinker/proto/response_conv.py new file mode 100644 index 0000000..0f97a98 --- /dev/null +++ b/src/tinker/proto/response_conv.py @@ -0,0 +1,110 @@ +"""Conversion helpers for proto responses to Pydantic models. + +Deserializes proto wire format into SDK Pydantic types (SampleResponse, etc.). +""" + +from __future__ import annotations + +import numpy as np + +from tinker.proto import tinker_public_pb2 as public_pb +from tinker.types.sample_response import SampleResponse +from tinker.types.sampled_sequence import SampledSequence +from tinker.types.stop_reason import StopReason + +# Set of model classes that support proto deserialization. +# Used by api_future_impl to decide whether to send Accept: application/x-protobuf. +PROTO_SUPPORTED_TYPES: set[type] = {SampleResponse} + +MASK_LOGPROB = -99999.0 + +# Proto enum -> SDK string mapping +_STOP_REASON_TO_STR: dict[int, StopReason] = { + public_pb.STOP_REASON_STOP: "stop", + public_pb.STOP_REASON_LENGTH: "length", +} + + +def deserialize_sample_response(proto_bytes: bytes) -> SampleResponse: + """Deserialize proto bytes into a Pydantic SampleResponse.""" + proto = public_pb.SampleResponse() + proto.ParseFromString(proto_bytes) + + sequences = [] + for seq in proto.sequences: + stop_reason = _STOP_REASON_TO_STR.get(seq.stop_reason) + if stop_reason is None: + raise ValueError( + f"Unknown stop_reason enum value {seq.stop_reason} in proto SampleResponse" + ) + tokens = np.frombuffer(seq.tokens, dtype=np.int32).tolist() + logprobs = np.frombuffer(seq.logprobs, dtype=np.float32).tolist() if seq.logprobs else None + sequences.append( + SampledSequence.model_construct( + stop_reason=stop_reason, + tokens=tokens, + logprobs=logprobs, + ) + ) + + prompt_logprobs: list[float | None] | None = None + if proto.prompt_logprobs: + arr = np.frombuffer(proto.prompt_logprobs, dtype=np.float32) + prompt_logprobs_list: list[float | None] = arr.tolist() + for i in np.flatnonzero(np.isnan(arr)): + prompt_logprobs_list[i] = None + prompt_logprobs = prompt_logprobs_list + + topk_prompt_logprobs: list[list[tuple[int, float]] | None] | None = None + if proto.HasField("topk_prompt_logprobs"): + topk_prompt_logprobs = _topk_from_proto(proto.topk_prompt_logprobs) + + return SampleResponse.model_construct( + sequences=sequences, + prompt_logprobs=prompt_logprobs, + topk_prompt_logprobs=topk_prompt_logprobs, + ) + + +def _topk_from_proto( + topk: public_pb.TopkPromptLogprobs, +) -> list[list[tuple[int, float]] | None]: + """Convert dense N×K TopkPromptLogprobs to Python list format.""" + n = topk.prompt_length + k = topk.k + + if n == 0 or k == 0: + return [] + + token_ids = np.ndarray((n, k), dtype=np.int32, buffer=topk.token_ids) + logprobs = np.ndarray((n, k), dtype=np.float32, buffer=topk.logprobs) + + # Single flat zip (faster than 32K per-row zips), then slice per row + tid_flat = token_ids.ravel().tolist() + lp_flat = logprobs.ravel().tolist() + all_tuples = list(zip(tid_flat, lp_flat)) + + mask_lp = MASK_LOGPROB + result: list[list[tuple[int, float]] | None] = [] + for i in range(n): + start = i * k + # First-element sentinel check: if first entry is sentinel, whole row is None + if tid_flat[start] == 0 and lp_flat[start] == mask_lp: + result.append(None) + else: + end = start + k + while end > start and tid_flat[end - 1] == 0 and lp_flat[end - 1] == mask_lp: + end -= 1 + result.append(all_tuples[start:end]) + return result + + +def deserialize_proto_response(proto_bytes: bytes, model_cls: type) -> object: + """Deserialize a proto response based on the expected model class. + + Dispatches to the appropriate deserializer based on model_cls. + Raises ValueError for unsupported types. + """ + if model_cls is SampleResponse: + return deserialize_sample_response(proto_bytes) + raise ValueError(f"Proto deserialization not supported for {model_cls}") diff --git a/src/tinker/proto/tinker_public_pb2.py b/src/tinker/proto/tinker_public_pb2.py new file mode 100644 index 0000000..dc9a017 --- /dev/null +++ b/src/tinker/proto/tinker_public_pb2.py @@ -0,0 +1,31 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: tinker_public.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x13tinker_public.proto\x12\rtinker_public\"\xd4\x01\n\x0eSampleResponse\x12\x31\n\tsequences\x18\x01 \x03(\x0b\x32\x1e.tinker_public.SampledSequence\x12\x1c\n\x0fprompt_logprobs\x18\x02 \x01(\x0cH\x00\x88\x01\x01\x12\x44\n\x14topk_prompt_logprobs\x18\x03 \x01(\x0b\x32!.tinker_public.TopkPromptLogprobsH\x01\x88\x01\x01\x42\x12\n\x10_prompt_logprobsB\x17\n\x15_topk_prompt_logprobs\"u\n\x0fSampledSequence\x12.\n\x0bstop_reason\x18\x01 \x01(\x0e\x32\x19.tinker_public.StopReason\x12\x0e\n\x06tokens\x18\x02 \x01(\x0c\x12\x15\n\x08logprobs\x18\x03 \x01(\x0cH\x00\x88\x01\x01\x42\x0b\n\t_logprobs\"[\n\x12TopkPromptLogprobs\x12\x11\n\ttoken_ids\x18\x01 \x01(\x0c\x12\x10\n\x08logprobs\x18\x02 \x01(\x0c\x12\t\n\x01k\x18\x03 \x01(\r\x12\x15\n\rprompt_length\x18\x04 \x01(\r*:\n\nStopReason\x12\x14\n\x10STOP_REASON_STOP\x10\x00\x12\x16\n\x12STOP_REASON_LENGTH\x10\x01\x62\x06proto3') + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tinker_public_pb2', globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + _STOPREASON._serialized_start=465 + _STOPREASON._serialized_end=523 + _SAMPLERESPONSE._serialized_start=39 + _SAMPLERESPONSE._serialized_end=251 + _SAMPLEDSEQUENCE._serialized_start=253 + _SAMPLEDSEQUENCE._serialized_end=370 + _TOPKPROMPTLOGPROBS._serialized_start=372 + _TOPKPROMPTLOGPROBS._serialized_end=463 +# @@protoc_insertion_point(module_scope) diff --git a/src/tinker/proto/tinker_public_pb2.pyi b/src/tinker/proto/tinker_public_pb2.pyi new file mode 100644 index 0000000..5bf050f --- /dev/null +++ b/src/tinker/proto/tinker_public_pb2.pyi @@ -0,0 +1,132 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +""" + +import builtins +import collections.abc +import google.protobuf.descriptor +import google.protobuf.internal.containers +import google.protobuf.internal.enum_type_wrapper +import google.protobuf.message +import sys +import typing + +if sys.version_info >= (3, 10): + import typing as typing_extensions +else: + import typing_extensions + +DESCRIPTOR: google.protobuf.descriptor.FileDescriptor + +class _StopReason: + ValueType = typing.NewType("ValueType", builtins.int) + V: typing_extensions.TypeAlias = ValueType + +class _StopReasonEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[_StopReason.ValueType], builtins.type): + DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor + STOP_REASON_STOP: _StopReason.ValueType # 0 + STOP_REASON_LENGTH: _StopReason.ValueType # 1 + +class StopReason(_StopReason, metaclass=_StopReasonEnumTypeWrapper): ... + +STOP_REASON_STOP: StopReason.ValueType # 0 +STOP_REASON_LENGTH: StopReason.ValueType # 1 +Global___StopReason: typing_extensions.TypeAlias = StopReason + +@typing.final +class SampleResponse(google.protobuf.message.Message): + """Public SampleResponse returned to SDK clients. + Field numbers are compatible with tinker_internal.SampleResponse. + """ + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SEQUENCES_FIELD_NUMBER: builtins.int + PROMPT_LOGPROBS_FIELD_NUMBER: builtins.int + TOPK_PROMPT_LOGPROBS_FIELD_NUMBER: builtins.int + prompt_logprobs: builtins.bytes + """np.array(prompt_logprobs, dtype=np.float32).tobytes() + NaN for missing positions (e.g. first prompt token) + """ + @property + def sequences(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[Global___SampledSequence]: ... + @property + def topk_prompt_logprobs(self) -> Global___TopkPromptLogprobs: + """Optional: absent means no topk was requested.""" + + def __init__( + self, + *, + sequences: collections.abc.Iterable[Global___SampledSequence] | None = ..., + prompt_logprobs: builtins.bytes | None = ..., + topk_prompt_logprobs: Global___TopkPromptLogprobs | None = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["_prompt_logprobs", b"_prompt_logprobs", "_topk_prompt_logprobs", b"_topk_prompt_logprobs", "prompt_logprobs", b"prompt_logprobs", "topk_prompt_logprobs", b"topk_prompt_logprobs"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["_prompt_logprobs", b"_prompt_logprobs", "_topk_prompt_logprobs", b"_topk_prompt_logprobs", "prompt_logprobs", b"prompt_logprobs", "sequences", b"sequences", "topk_prompt_logprobs", b"topk_prompt_logprobs"]) -> None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing.Literal["_prompt_logprobs", b"_prompt_logprobs"]) -> typing.Literal["prompt_logprobs"] | None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing.Literal["_topk_prompt_logprobs", b"_topk_prompt_logprobs"]) -> typing.Literal["topk_prompt_logprobs"] | None: ... + +Global___SampleResponse: typing_extensions.TypeAlias = SampleResponse + +@typing.final +class SampledSequence(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + STOP_REASON_FIELD_NUMBER: builtins.int + TOKENS_FIELD_NUMBER: builtins.int + LOGPROBS_FIELD_NUMBER: builtins.int + stop_reason: Global___StopReason.ValueType + tokens: builtins.bytes + """np.array(tokens, dtype=np.int32).tobytes()""" + logprobs: builtins.bytes + """np.array(logprobs, dtype=np.float32).tobytes() + Optional, same length as tokens when present + """ + def __init__( + self, + *, + stop_reason: Global___StopReason.ValueType = ..., + tokens: builtins.bytes = ..., + logprobs: builtins.bytes | None = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["_logprobs", b"_logprobs", "logprobs", b"logprobs"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["_logprobs", b"_logprobs", "logprobs", b"logprobs", "stop_reason", b"stop_reason", "tokens", b"tokens"]) -> None: ... + def WhichOneof(self, oneof_group: typing.Literal["_logprobs", b"_logprobs"]) -> typing.Literal["logprobs"] | None: ... + +Global___SampledSequence: typing_extensions.TypeAlias = SampledSequence + +@typing.final +class TopkPromptLogprobs(google.protobuf.message.Message): + """Dense N×K matrices for top-k prompt logprobs. + Both token_ids and logprobs are row-major flattened: position i, rank j + is at index i*k + j. Empty cells use token_id=0 and logprob=-99999.0. + """ + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + TOKEN_IDS_FIELD_NUMBER: builtins.int + LOGPROBS_FIELD_NUMBER: builtins.int + K_FIELD_NUMBER: builtins.int + PROMPT_LENGTH_FIELD_NUMBER: builtins.int + token_ids: builtins.bytes + """Row-major N×K matrix of int32 token IDs""" + logprobs: builtins.bytes + """Row-major N×K matrix of float32 logprobs""" + k: builtins.int + """Number of top-k entries per position""" + prompt_length: builtins.int + """Number of prompt positions""" + def __init__( + self, + *, + token_ids: builtins.bytes = ..., + logprobs: builtins.bytes = ..., + k: builtins.int = ..., + prompt_length: builtins.int = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["k", b"k", "logprobs", b"logprobs", "prompt_length", b"prompt_length", "token_ids", b"token_ids"]) -> None: ... + +Global___TopkPromptLogprobs: typing_extensions.TypeAlias = TopkPromptLogprobs