Sync contents

This commit is contained in:
Andrii Grynenko 2026-04-17 00:45:03 +00:00
parent 30517b667f
commit 2a1f5e3240
10 changed files with 384 additions and 7 deletions

View file

@ -1,5 +1,9 @@
line-length = 100 line-length = 100
force-exclude = true force-exclude = true
exclude = [
# Generated protobuf code
"src/tinker/proto",
]
[lint] [lint]

View file

@ -1,4 +1,4 @@
{ {
"last_synced_sha": "d117d1692821faa297ea5d2ee7e4dc21b5c8bd0a", "last_synced_sha": "774d92ed933004d66c3e86070e0a0d2e713a0a31",
"last_sync_time": "2026-04-14T00:00:48.831738" "last_sync_time": "2026-04-17T00:45:03.498497"
} }

View file

@ -1,6 +1,6 @@
[project] [project]
name = "tinker" name = "tinker"
version = "0.18.0" version = "0.18.1"
description = "The official Python SDK for the tinker API" description = "The official Python SDK for the tinker API"
readme = "README.md" readme = "README.md"
license = "Apache-2.0" license = "Apache-2.0"
@ -16,6 +16,7 @@ dependencies = [
"distro>=1.7.0, <2", "distro>=1.7.0, <2",
"sniffio", "sniffio",
"numpy", "numpy",
"protobuf>=4.21",
"transformers", "transformers",
"rich>=13.0.0", "rich>=13.0.0",
"click>=8.0.0", "click>=8.0.0",

View file

@ -26,6 +26,8 @@ from .sync_only import sync_only
if TYPE_CHECKING: if TYPE_CHECKING:
from tinker.lib.internal_client_holder import InternalClientHolder from tinker.lib.internal_client_holder import InternalClientHolder
from tinker.proto.response_conv import PROTO_SUPPORTED_TYPES, deserialize_proto_response
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
T = TypeVar("T") T = TypeVar("T")
@ -115,6 +117,8 @@ class _APIFuture(APIFuture[T]): # pyright: ignore[reportUnusedClass]
"X-Tinker-Request-Iteration": str(iteration), "X-Tinker-Request-Iteration": str(iteration),
"X-Tinker-Request-Type": self.request_type, "X-Tinker-Request-Type": self.request_type,
} }
if self.model_cls in PROTO_SUPPORTED_TYPES:
headers["Accept"] = "application/x-protobuf, application/json"
if iteration == 0: if iteration == 0:
headers["X-Tinker-Create-Promise-Roundtrip-Time"] = str( headers["X-Tinker-Create-Promise-Roundtrip-Time"] = str(
self.request_queue_roundtrip_time self.request_queue_roundtrip_time
@ -213,6 +217,37 @@ class _APIFuture(APIFuture[T]): # pyright: ignore[reportUnusedClass]
connection_error_retries += 1 connection_error_retries += 1
continue continue
# Proto response path: server returned protobuf bytes
content_type = response.headers.get("content-type", "")
if "application/x-protobuf" in content_type:
proto_bytes = response.http_response.content
try:
self._cached_result = deserialize_proto_response(
proto_bytes, self.model_cls
)
return cast(T, self._cached_result)
except Exception as e:
if telemetry := self.get_telemetry():
current_time = time.time()
telemetry.log(
"APIFuture.result_async.proto_deserialization_error",
event_data={
"request_id": self.request_id,
"request_type": self.request_type,
"exception": str(e),
"exception_type": type(e).__name__,
"proto_bytes_len": len(proto_bytes),
"model_cls": str(self.model_cls),
"iteration": iteration,
"elapsed_time": current_time - start_time,
},
severity="ERROR",
)
raise ValueError(
f"Proto deserialization failed: {e} for {self.request_id=} and expected type {self.model_cls=}"
) from e
# JSON response path (existing)
result_dict: Any = await response.json() result_dict: Any = await response.json()
if "type" in result_dict and result_dict["type"] == "try_again": if "type" in result_dict and result_dict["type"] == "try_again":

View file

@ -191,7 +191,9 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
_jwt_auth_seed: str | None = None, _jwt_auth_seed: str | None = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
self._api_key = api_key # Resolve from env now so shadow_kwargs carries the actual credential
# across pickle boundaries (workers may not have the env var set).
self._api_key = api_key or os.environ.get("TINKER_API_KEY")
self._constructor_kwargs = dict(kwargs) self._constructor_kwargs = dict(kwargs)
self._loop: asyncio.AbstractEventLoop = _internal_client_holder_thread_singleton.get_loop() self._loop: asyncio.AbstractEventLoop = _internal_client_holder_thread_singleton.get_loop()
self._client_pools: dict[ClientConnectionPoolType, ClientConnectionPool] = {} self._client_pools: dict[ClientConnectionPoolType, ClientConnectionPool] = {}
@ -224,7 +226,7 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
if not self._client_config.pjwt_auth_enabled: if not self._client_config.pjwt_auth_enabled:
# Without JWT exchange, only API keys are accepted by the server. # Without JWT exchange, only API keys are accepted by the server.
# Replace any cmd-based provider with a plain API key provider. # Replace any cmd-based provider with a plain API key provider.
self._default_auth = ApiKeyAuthProvider(api_key=api_key) self._default_auth = ApiKeyAuthProvider(api_key=self._api_key)
else: else:
# Create a dedicated pool for JWT exchange with the appropriate # Create a dedicated pool for JWT exchange with the appropriate
# credential provider. The lambda captures the pool so it stays alive. # credential provider. The lambda captures the pool so it stays alive.

View file

@ -4,11 +4,11 @@ from __future__ import annotations
import asyncio import asyncio
from typing import Any from typing import Any
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from tinker.lib._auth_token_provider import AuthTokenProvider from tinker.lib._auth_token_provider import ApiKeyAuthProvider, AuthTokenProvider
from tinker.lib.internal_client_holder import ClientConnectionPool, InternalClientHolder from tinker.lib.internal_client_holder import ClientConnectionPool, InternalClientHolder
from tinker.types.client_config_response import ClientConfigResponse as _ClientConfigResponse from tinker.types.client_config_response import ClientConfigResponse as _ClientConfigResponse
@ -94,3 +94,65 @@ async def test_fetch_client_config_passes_sdk_version(
call_kwargs = holder._cm.__enter__.return_value.service.client_config.call_args call_kwargs = holder._cm.__enter__.return_value.service.client_config.call_args
assert call_kwargs.kwargs["request"].sdk_version == tinker_sdk_version assert call_kwargs.kwargs["request"].sdk_version == tinker_sdk_version
# ---------------------------------------------------------------------------
# Pickle round-trip: ambient TINKER_API_KEY must travel with the pickle
# ---------------------------------------------------------------------------
def _make_holder(api_key: str | None = None) -> InternalClientHolder:
    """Construct a primary InternalClientHolder with all server I/O stubbed.

    The stubbed ``_fetch_client_config`` reports ``pjwt_auth_enabled=False``,
    which forces the plain API-key auth path.
    """
    stubbed_server_calls = patch.multiple(
        InternalClientHolder,
        _fetch_client_config=AsyncMock(
            return_value=_ClientConfigResponse(pjwt_auth_enabled=False)
        ),
        _create_session=AsyncMock(return_value="sess-pickle-test"),
        _start_heartbeat=AsyncMock(),
    )
    with stubbed_server_calls:
        holder = InternalClientHolder(api_key=api_key)
        # Replace the background heartbeat task so teardown never awaits it.
        holder._session_heartbeat_task = MagicMock()
    return holder
def test_sampling_client_pickle_roundtrip_without_env_var(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Pickle a SamplingClient created with ambient TINKER_API_KEY, then
    unpickle it in an environment without the env var (simulating a worker
    process). The credential must travel inside the pickle payload."""
    import pickle

    from tinker.lib.public_interfaces.sampling_client import SamplingClient

    monkeypatch.setenv("TINKER_API_KEY", "tml-key-from-env")
    pickled_client = pickle.dumps(
        SamplingClient(_make_holder(api_key=None), sampling_session_id="samp-1")
    )

    # Simulate the worker: no TINKER_API_KEY available.
    monkeypatch.delenv("TINKER_API_KEY")
    heartbeat_stub = patch.object(
        InternalClientHolder,
        "_start_heartbeat",
        new_callable=AsyncMock,
    )
    with heartbeat_stub:
        restored = pickle.loads(pickled_client)

    restored_auth = restored.holder._default_auth
    assert isinstance(restored_auth, ApiKeyAuthProvider)
    assert restored_auth._token == "tml-key-from-env"

View file

View file

@ -0,0 +1,110 @@
"""Conversion helpers for proto responses to Pydantic models.
Deserializes proto wire format into SDK Pydantic types (SampleResponse, etc.).
"""
from __future__ import annotations
import numpy as np
from tinker.proto import tinker_public_pb2 as public_pb
from tinker.types.sample_response import SampleResponse
from tinker.types.sampled_sequence import SampledSequence
from tinker.types.stop_reason import StopReason
# Set of model classes that support proto deserialization.
# Used by api_future_impl to decide whether to send Accept: application/x-protobuf.
PROTO_SUPPORTED_TYPES: set[type] = {SampleResponse}
MASK_LOGPROB = -99999.0
# Proto enum -> SDK string mapping
_STOP_REASON_TO_STR: dict[int, StopReason] = {
public_pb.STOP_REASON_STOP: "stop",
public_pb.STOP_REASON_LENGTH: "length",
}
def deserialize_sample_response(proto_bytes: bytes) -> SampleResponse:
    """Deserialize proto wire bytes into a Pydantic SampleResponse.

    Args:
        proto_bytes: serialized ``tinker_public.SampleResponse`` message.

    Raises:
        ValueError: if a sequence carries an unknown stop_reason enum value.
    """
    proto = public_pb.SampleResponse()
    proto.ParseFromString(proto_bytes)

    sequences = []
    for seq in proto.sequences:
        stop_reason = _STOP_REASON_TO_STR.get(seq.stop_reason)
        if stop_reason is None:
            raise ValueError(
                f"Unknown stop_reason enum value {seq.stop_reason} in proto SampleResponse"
            )
        tokens = np.frombuffer(seq.tokens, dtype=np.int32).tolist()
        # `logprobs` is a proto3 optional field: use HasField so a present-but-empty
        # payload (zero sampled tokens) is distinguished from an absent one.
        logprobs = (
            np.frombuffer(seq.logprobs, dtype=np.float32).tolist()
            if seq.HasField("logprobs")
            else None
        )
        sequences.append(
            SampledSequence.model_construct(
                stop_reason=stop_reason,
                tokens=tokens,
                logprobs=logprobs,
            )
        )

    prompt_logprobs: list[float | None] | None = None
    # Optional field: HasField (not bytes truthiness), consistent with the
    # topk_prompt_logprobs handling below.
    if proto.HasField("prompt_logprobs"):
        arr = np.frombuffer(proto.prompt_logprobs, dtype=np.float32)
        prompt_logprobs_list: list[float | None] = arr.tolist()
        # NaN is the wire encoding for "no logprob at this position"
        # (e.g. the first prompt token); map it back to None.
        for i in np.flatnonzero(np.isnan(arr)):
            prompt_logprobs_list[i] = None
        prompt_logprobs = prompt_logprobs_list

    topk_prompt_logprobs: list[list[tuple[int, float]] | None] | None = None
    if proto.HasField("topk_prompt_logprobs"):
        topk_prompt_logprobs = _topk_from_proto(proto.topk_prompt_logprobs)

    return SampleResponse.model_construct(
        sequences=sequences,
        prompt_logprobs=prompt_logprobs,
        topk_prompt_logprobs=topk_prompt_logprobs,
    )
def _topk_from_proto(
topk: public_pb.TopkPromptLogprobs,
) -> list[list[tuple[int, float]] | None]:
"""Convert dense N×K TopkPromptLogprobs to Python list format."""
n = topk.prompt_length
k = topk.k
if n == 0 or k == 0:
return []
token_ids = np.ndarray((n, k), dtype=np.int32, buffer=topk.token_ids)
logprobs = np.ndarray((n, k), dtype=np.float32, buffer=topk.logprobs)
# Single flat zip (faster than 32K per-row zips), then slice per row
tid_flat = token_ids.ravel().tolist()
lp_flat = logprobs.ravel().tolist()
all_tuples = list(zip(tid_flat, lp_flat))
mask_lp = MASK_LOGPROB
result: list[list[tuple[int, float]] | None] = []
for i in range(n):
start = i * k
# First-element sentinel check: if first entry is sentinel, whole row is None
if tid_flat[start] == 0 and lp_flat[start] == mask_lp:
result.append(None)
else:
end = start + k
while end > start and tid_flat[end - 1] == 0 and lp_flat[end - 1] == mask_lp:
end -= 1
result.append(all_tuples[start:end])
return result
def deserialize_proto_response(proto_bytes: bytes, model_cls: type) -> object:
    """Dispatch proto deserialization by the expected model class.

    Only types listed in PROTO_SUPPORTED_TYPES have a proto wire format;
    anything else raises ValueError.
    """
    if model_cls is not SampleResponse:
        raise ValueError(f"Proto deserialization not supported for {model_cls}")
    return deserialize_sample_response(proto_bytes)

View file

@ -0,0 +1,31 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: tinker_public.proto
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
# Serialized FileDescriptor for tinker_public.proto
# (messages: SampleResponse, SampledSequence, TopkPromptLogprobs; enum: StopReason).
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x13tinker_public.proto\x12\rtinker_public\"\xd4\x01\n\x0eSampleResponse\x12\x31\n\tsequences\x18\x01 \x03(\x0b\x32\x1e.tinker_public.SampledSequence\x12\x1c\n\x0fprompt_logprobs\x18\x02 \x01(\x0cH\x00\x88\x01\x01\x12\x44\n\x14topk_prompt_logprobs\x18\x03 \x01(\x0b\x32!.tinker_public.TopkPromptLogprobsH\x01\x88\x01\x01\x42\x12\n\x10_prompt_logprobsB\x17\n\x15_topk_prompt_logprobs\"u\n\x0fSampledSequence\x12.\n\x0bstop_reason\x18\x01 \x01(\x0e\x32\x19.tinker_public.StopReason\x12\x0e\n\x06tokens\x18\x02 \x01(\x0c\x12\x15\n\x08logprobs\x18\x03 \x01(\x0cH\x00\x88\x01\x01\x42\x0b\n\t_logprobs\"[\n\x12TopkPromptLogprobs\x12\x11\n\ttoken_ids\x18\x01 \x01(\x0c\x12\x10\n\x08logprobs\x18\x02 \x01(\x0c\x12\t\n\x01k\x18\x03 \x01(\r\x12\x15\n\rprompt_length\x18\x04 \x01(\r*:\n\nStopReason\x12\x14\n\x10STOP_REASON_STOP\x10\x00\x12\x16\n\x12STOP_REASON_LENGTH\x10\x01\x62\x06proto3')
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tinker_public_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:
    # Pure-Python descriptor path: byte offsets of each type inside the
    # serialized file descriptor above (emitted by protoc).
    DESCRIPTOR._options = None
    _STOPREASON._serialized_start=465
    _STOPREASON._serialized_end=523
    _SAMPLERESPONSE._serialized_start=39
    _SAMPLERESPONSE._serialized_end=251
    _SAMPLEDSEQUENCE._serialized_start=253
    _SAMPLEDSEQUENCE._serialized_end=370
    _TOPKPROMPTLOGPROBS._serialized_start=372
    _TOPKPROMPTLOGPROBS._serialized_end=463
# @@protoc_insertion_point(module_scope)

View file

@ -0,0 +1,132 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
"""
import builtins
import collections.abc
import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.internal.enum_type_wrapper
import google.protobuf.message
import sys
import typing
if sys.version_info >= (3, 10):
import typing as typing_extensions
else:
import typing_extensions
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
# Generated stub (mypy-protobuf) — mirrors the StopReason proto enum; regenerate, do not hand-edit.
class _StopReason:
    ValueType = typing.NewType("ValueType", builtins.int)
    V: typing_extensions.TypeAlias = ValueType

class _StopReasonEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[_StopReason.ValueType], builtins.type):
    DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
    STOP_REASON_STOP: _StopReason.ValueType  # 0
    STOP_REASON_LENGTH: _StopReason.ValueType  # 1

class StopReason(_StopReason, metaclass=_StopReasonEnumTypeWrapper): ...

STOP_REASON_STOP: StopReason.ValueType  # 0
STOP_REASON_LENGTH: StopReason.ValueType  # 1
Global___StopReason: typing_extensions.TypeAlias = StopReason
# Generated stub (mypy-protobuf) — regenerate, do not hand-edit.
@typing.final
class SampleResponse(google.protobuf.message.Message):
    """Public SampleResponse returned to SDK clients.
    Field numbers are compatible with tinker_internal.SampleResponse.
    """

    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    SEQUENCES_FIELD_NUMBER: builtins.int
    PROMPT_LOGPROBS_FIELD_NUMBER: builtins.int
    TOPK_PROMPT_LOGPROBS_FIELD_NUMBER: builtins.int
    prompt_logprobs: builtins.bytes
    """np.array(prompt_logprobs, dtype=np.float32).tobytes()
    NaN for missing positions (e.g. first prompt token)
    """
    @property
    def sequences(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[Global___SampledSequence]: ...
    @property
    def topk_prompt_logprobs(self) -> Global___TopkPromptLogprobs:
        """Optional: absent means no topk was requested."""

    def __init__(
        self,
        *,
        sequences: collections.abc.Iterable[Global___SampledSequence] | None = ...,
        prompt_logprobs: builtins.bytes | None = ...,
        topk_prompt_logprobs: Global___TopkPromptLogprobs | None = ...,
    ) -> None: ...
    def HasField(self, field_name: typing.Literal["_prompt_logprobs", b"_prompt_logprobs", "_topk_prompt_logprobs", b"_topk_prompt_logprobs", "prompt_logprobs", b"prompt_logprobs", "topk_prompt_logprobs", b"topk_prompt_logprobs"]) -> builtins.bool: ...
    def ClearField(self, field_name: typing.Literal["_prompt_logprobs", b"_prompt_logprobs", "_topk_prompt_logprobs", b"_topk_prompt_logprobs", "prompt_logprobs", b"prompt_logprobs", "sequences", b"sequences", "topk_prompt_logprobs", b"topk_prompt_logprobs"]) -> None: ...
    @typing.overload
    def WhichOneof(self, oneof_group: typing.Literal["_prompt_logprobs", b"_prompt_logprobs"]) -> typing.Literal["prompt_logprobs"] | None: ...
    @typing.overload
    def WhichOneof(self, oneof_group: typing.Literal["_topk_prompt_logprobs", b"_topk_prompt_logprobs"]) -> typing.Literal["topk_prompt_logprobs"] | None: ...

Global___SampleResponse: typing_extensions.TypeAlias = SampleResponse
# Generated stub (mypy-protobuf) — regenerate, do not hand-edit.
@typing.final
class SampledSequence(google.protobuf.message.Message):
    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    STOP_REASON_FIELD_NUMBER: builtins.int
    TOKENS_FIELD_NUMBER: builtins.int
    LOGPROBS_FIELD_NUMBER: builtins.int
    stop_reason: Global___StopReason.ValueType
    tokens: builtins.bytes
    """np.array(tokens, dtype=np.int32).tobytes()"""
    logprobs: builtins.bytes
    """np.array(logprobs, dtype=np.float32).tobytes()
    Optional, same length as tokens when present
    """
    def __init__(
        self,
        *,
        stop_reason: Global___StopReason.ValueType = ...,
        tokens: builtins.bytes = ...,
        logprobs: builtins.bytes | None = ...,
    ) -> None: ...
    def HasField(self, field_name: typing.Literal["_logprobs", b"_logprobs", "logprobs", b"logprobs"]) -> builtins.bool: ...
    def ClearField(self, field_name: typing.Literal["_logprobs", b"_logprobs", "logprobs", b"logprobs", "stop_reason", b"stop_reason", "tokens", b"tokens"]) -> None: ...
    def WhichOneof(self, oneof_group: typing.Literal["_logprobs", b"_logprobs"]) -> typing.Literal["logprobs"] | None: ...

Global___SampledSequence: typing_extensions.TypeAlias = SampledSequence
# Generated stub (mypy-protobuf) — regenerate, do not hand-edit.
@typing.final
class TopkPromptLogprobs(google.protobuf.message.Message):
    """Dense N×K matrices for top-k prompt logprobs.
    Both token_ids and logprobs are row-major flattened: position i, rank j
    is at index i*k + j. Empty cells use token_id=0 and logprob=-99999.0.
    """

    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    TOKEN_IDS_FIELD_NUMBER: builtins.int
    LOGPROBS_FIELD_NUMBER: builtins.int
    K_FIELD_NUMBER: builtins.int
    PROMPT_LENGTH_FIELD_NUMBER: builtins.int
    token_ids: builtins.bytes
    """Row-major N×K matrix of int32 token IDs"""
    logprobs: builtins.bytes
    """Row-major N×K matrix of float32 logprobs"""
    k: builtins.int
    """Number of top-k entries per position"""
    prompt_length: builtins.int
    """Number of prompt positions"""
    def __init__(
        self,
        *,
        token_ids: builtins.bytes = ...,
        logprobs: builtins.bytes = ...,
        k: builtins.int = ...,
        prompt_length: builtins.int = ...,
    ) -> None: ...
    def ClearField(self, field_name: typing.Literal["k", b"k", "logprobs", b"logprobs", "prompt_length", b"prompt_length", "token_ids", b"token_ids"]) -> None: ...

Global___TopkPromptLogprobs: typing_extensions.TypeAlias = TopkPromptLogprobs