Sync contents

This commit is contained in:
Andrii Grynenko 2026-04-17 00:45:03 +00:00
parent 30517b667f
commit 2a1f5e3240
10 changed files with 384 additions and 7 deletions

View file

@ -1,5 +1,9 @@
line-length = 100
force-exclude = true
exclude = [
# Generated protobuf code
"src/tinker/proto",
]
[lint]

View file

@ -1,4 +1,4 @@
{
"last_synced_sha": "d117d1692821faa297ea5d2ee7e4dc21b5c8bd0a",
"last_sync_time": "2026-04-14T00:00:48.831738"
"last_synced_sha": "774d92ed933004d66c3e86070e0a0d2e713a0a31",
"last_sync_time": "2026-04-17T00:45:03.498497"
}

View file

@ -1,6 +1,6 @@
[project]
name = "tinker"
version = "0.18.0"
version = "0.18.1"
description = "The official Python SDK for the tinker API"
readme = "README.md"
license = "Apache-2.0"
@ -16,6 +16,7 @@ dependencies = [
"distro>=1.7.0, <2",
"sniffio",
"numpy",
"protobuf>=4.21",
"transformers",
"rich>=13.0.0",
"click>=8.0.0",

View file

@ -26,6 +26,8 @@ from .sync_only import sync_only
if TYPE_CHECKING:
from tinker.lib.internal_client_holder import InternalClientHolder
from tinker.proto.response_conv import PROTO_SUPPORTED_TYPES, deserialize_proto_response
logger = logging.getLogger(__name__)
T = TypeVar("T")
@ -115,6 +117,8 @@ class _APIFuture(APIFuture[T]): # pyright: ignore[reportUnusedClass]
"X-Tinker-Request-Iteration": str(iteration),
"X-Tinker-Request-Type": self.request_type,
}
if self.model_cls in PROTO_SUPPORTED_TYPES:
headers["Accept"] = "application/x-protobuf, application/json"
if iteration == 0:
headers["X-Tinker-Create-Promise-Roundtrip-Time"] = str(
self.request_queue_roundtrip_time
@ -213,6 +217,37 @@ class _APIFuture(APIFuture[T]): # pyright: ignore[reportUnusedClass]
connection_error_retries += 1
continue
# Proto response path: server returned protobuf bytes
content_type = response.headers.get("content-type", "")
if "application/x-protobuf" in content_type:
proto_bytes = response.http_response.content
try:
self._cached_result = deserialize_proto_response(
proto_bytes, self.model_cls
)
return cast(T, self._cached_result)
except Exception as e:
if telemetry := self.get_telemetry():
current_time = time.time()
telemetry.log(
"APIFuture.result_async.proto_deserialization_error",
event_data={
"request_id": self.request_id,
"request_type": self.request_type,
"exception": str(e),
"exception_type": type(e).__name__,
"proto_bytes_len": len(proto_bytes),
"model_cls": str(self.model_cls),
"iteration": iteration,
"elapsed_time": current_time - start_time,
},
severity="ERROR",
)
raise ValueError(
f"Proto deserialization failed: {e} for {self.request_id=} and expected type {self.model_cls=}"
) from e
# JSON response path (existing)
result_dict: Any = await response.json()
if "type" in result_dict and result_dict["type"] == "try_again":

View file

@ -191,7 +191,9 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
_jwt_auth_seed: str | None = None,
**kwargs: Any,
) -> None:
self._api_key = api_key
# Resolve from env now so shadow_kwargs carries the actual credential
# across pickle boundaries (workers may not have the env var set).
self._api_key = api_key or os.environ.get("TINKER_API_KEY")
self._constructor_kwargs = dict(kwargs)
self._loop: asyncio.AbstractEventLoop = _internal_client_holder_thread_singleton.get_loop()
self._client_pools: dict[ClientConnectionPoolType, ClientConnectionPool] = {}
@ -224,7 +226,7 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
if not self._client_config.pjwt_auth_enabled:
# Without JWT exchange, only API keys are accepted by the server.
# Replace any cmd-based provider with a plain API key provider.
self._default_auth = ApiKeyAuthProvider(api_key=api_key)
self._default_auth = ApiKeyAuthProvider(api_key=self._api_key)
else:
# Create a dedicated pool for JWT exchange with the appropriate
# credential provider. The lambda captures the pool so it stays alive.

View file

@ -4,11 +4,11 @@ from __future__ import annotations
import asyncio
from typing import Any
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from tinker.lib._auth_token_provider import AuthTokenProvider
from tinker.lib._auth_token_provider import ApiKeyAuthProvider, AuthTokenProvider
from tinker.lib.internal_client_holder import ClientConnectionPool, InternalClientHolder
from tinker.types.client_config_response import ClientConfigResponse as _ClientConfigResponse
@ -94,3 +94,65 @@ async def test_fetch_client_config_passes_sdk_version(
call_kwargs = holder._cm.__enter__.return_value.service.client_config.call_args
assert call_kwargs.kwargs["request"].sdk_version == tinker_sdk_version
# ---------------------------------------------------------------------------
# Pickle round-trip: ambient TINKER_API_KEY must travel with the pickle
# ---------------------------------------------------------------------------
def _make_holder(api_key: str | None = None) -> InternalClientHolder:
"""Build a primary InternalClientHolder with server calls stubbed out."""
with (
patch.object(
InternalClientHolder,
"_fetch_client_config",
new_callable=AsyncMock,
# pjwt_auth_enabled=False → plain API-key auth path
return_value=_ClientConfigResponse(pjwt_auth_enabled=False),
),
patch.object(
InternalClientHolder,
"_create_session",
new_callable=AsyncMock,
return_value="sess-pickle-test",
),
patch.object(
InternalClientHolder,
"_start_heartbeat",
new_callable=AsyncMock,
),
):
holder = InternalClientHolder(api_key=api_key)
holder._session_heartbeat_task = MagicMock()
return holder
def test_sampling_client_pickle_roundtrip_without_env_var(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Pickle a SamplingClient created with ambient TINKER_API_KEY, then
unpickle it in an environment without the env var (simulating a worker
process). The credential must travel inside the pickle payload."""
import pickle
from tinker.lib.public_interfaces.sampling_client import SamplingClient
monkeypatch.setenv("TINKER_API_KEY", "tml-key-from-env")
holder = _make_holder(api_key=None)
client = SamplingClient(holder, sampling_session_id="samp-1")
payload = pickle.dumps(client)
# Simulate the worker: no TINKER_API_KEY available.
monkeypatch.delenv("TINKER_API_KEY")
with patch.object(
InternalClientHolder,
"_start_heartbeat",
new_callable=AsyncMock,
):
restored = pickle.loads(payload)
assert isinstance(restored.holder._default_auth, ApiKeyAuthProvider)
assert restored.holder._default_auth._token == "tml-key-from-env"

View file

View file

@ -0,0 +1,110 @@
"""Conversion helpers for proto responses to Pydantic models.
Deserializes proto wire format into SDK Pydantic types (SampleResponse, etc.).
"""
from __future__ import annotations
import numpy as np
from tinker.proto import tinker_public_pb2 as public_pb
from tinker.types.sample_response import SampleResponse
from tinker.types.sampled_sequence import SampledSequence
from tinker.types.stop_reason import StopReason
# Set of model classes that support proto deserialization.
# Used by api_future_impl to decide whether to send Accept: application/x-protobuf.
PROTO_SUPPORTED_TYPES: set[type] = {SampleResponse}
MASK_LOGPROB = -99999.0
# Proto enum -> SDK string mapping
_STOP_REASON_TO_STR: dict[int, StopReason] = {
public_pb.STOP_REASON_STOP: "stop",
public_pb.STOP_REASON_LENGTH: "length",
}
def deserialize_sample_response(proto_bytes: bytes) -> SampleResponse:
"""Deserialize proto bytes into a Pydantic SampleResponse."""
proto = public_pb.SampleResponse()
proto.ParseFromString(proto_bytes)
sequences = []
for seq in proto.sequences:
stop_reason = _STOP_REASON_TO_STR.get(seq.stop_reason)
if stop_reason is None:
raise ValueError(
f"Unknown stop_reason enum value {seq.stop_reason} in proto SampleResponse"
)
tokens = np.frombuffer(seq.tokens, dtype=np.int32).tolist()
logprobs = np.frombuffer(seq.logprobs, dtype=np.float32).tolist() if seq.logprobs else None
sequences.append(
SampledSequence.model_construct(
stop_reason=stop_reason,
tokens=tokens,
logprobs=logprobs,
)
)
prompt_logprobs: list[float | None] | None = None
if proto.prompt_logprobs:
arr = np.frombuffer(proto.prompt_logprobs, dtype=np.float32)
prompt_logprobs_list: list[float | None] = arr.tolist()
for i in np.flatnonzero(np.isnan(arr)):
prompt_logprobs_list[i] = None
prompt_logprobs = prompt_logprobs_list
topk_prompt_logprobs: list[list[tuple[int, float]] | None] | None = None
if proto.HasField("topk_prompt_logprobs"):
topk_prompt_logprobs = _topk_from_proto(proto.topk_prompt_logprobs)
return SampleResponse.model_construct(
sequences=sequences,
prompt_logprobs=prompt_logprobs,
topk_prompt_logprobs=topk_prompt_logprobs,
)
def _topk_from_proto(
topk: public_pb.TopkPromptLogprobs,
) -> list[list[tuple[int, float]] | None]:
"""Convert dense N×K TopkPromptLogprobs to Python list format."""
n = topk.prompt_length
k = topk.k
if n == 0 or k == 0:
return []
token_ids = np.ndarray((n, k), dtype=np.int32, buffer=topk.token_ids)
logprobs = np.ndarray((n, k), dtype=np.float32, buffer=topk.logprobs)
# Single flat zip (faster than 32K per-row zips), then slice per row
tid_flat = token_ids.ravel().tolist()
lp_flat = logprobs.ravel().tolist()
all_tuples = list(zip(tid_flat, lp_flat))
mask_lp = MASK_LOGPROB
result: list[list[tuple[int, float]] | None] = []
for i in range(n):
start = i * k
# First-element sentinel check: if first entry is sentinel, whole row is None
if tid_flat[start] == 0 and lp_flat[start] == mask_lp:
result.append(None)
else:
end = start + k
while end > start and tid_flat[end - 1] == 0 and lp_flat[end - 1] == mask_lp:
end -= 1
result.append(all_tuples[start:end])
return result
def deserialize_proto_response(proto_bytes: bytes, model_cls: type) -> object:
"""Deserialize a proto response based on the expected model class.
Dispatches to the appropriate deserializer based on model_cls.
Raises ValueError for unsupported types.
"""
if model_cls is SampleResponse:
return deserialize_sample_response(proto_bytes)
raise ValueError(f"Proto deserialization not supported for {model_cls}")

View file

@ -0,0 +1,31 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: tinker_public.proto
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x13tinker_public.proto\x12\rtinker_public\"\xd4\x01\n\x0eSampleResponse\x12\x31\n\tsequences\x18\x01 \x03(\x0b\x32\x1e.tinker_public.SampledSequence\x12\x1c\n\x0fprompt_logprobs\x18\x02 \x01(\x0cH\x00\x88\x01\x01\x12\x44\n\x14topk_prompt_logprobs\x18\x03 \x01(\x0b\x32!.tinker_public.TopkPromptLogprobsH\x01\x88\x01\x01\x42\x12\n\x10_prompt_logprobsB\x17\n\x15_topk_prompt_logprobs\"u\n\x0fSampledSequence\x12.\n\x0bstop_reason\x18\x01 \x01(\x0e\x32\x19.tinker_public.StopReason\x12\x0e\n\x06tokens\x18\x02 \x01(\x0c\x12\x15\n\x08logprobs\x18\x03 \x01(\x0cH\x00\x88\x01\x01\x42\x0b\n\t_logprobs\"[\n\x12TopkPromptLogprobs\x12\x11\n\ttoken_ids\x18\x01 \x01(\x0c\x12\x10\n\x08logprobs\x18\x02 \x01(\x0c\x12\t\n\x01k\x18\x03 \x01(\r\x12\x15\n\rprompt_length\x18\x04 \x01(\r*:\n\nStopReason\x12\x14\n\x10STOP_REASON_STOP\x10\x00\x12\x16\n\x12STOP_REASON_LENGTH\x10\x01\x62\x06proto3')
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tinker_public_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_STOPREASON._serialized_start=465
_STOPREASON._serialized_end=523
_SAMPLERESPONSE._serialized_start=39
_SAMPLERESPONSE._serialized_end=251
_SAMPLEDSEQUENCE._serialized_start=253
_SAMPLEDSEQUENCE._serialized_end=370
_TOPKPROMPTLOGPROBS._serialized_start=372
_TOPKPROMPTLOGPROBS._serialized_end=463
# @@protoc_insertion_point(module_scope)

View file

@ -0,0 +1,132 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
"""
import builtins
import collections.abc
import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.internal.enum_type_wrapper
import google.protobuf.message
import sys
import typing
if sys.version_info >= (3, 10):
import typing as typing_extensions
else:
import typing_extensions
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
class _StopReason:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType
class _StopReasonEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[_StopReason.ValueType], builtins.type):
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
STOP_REASON_STOP: _StopReason.ValueType # 0
STOP_REASON_LENGTH: _StopReason.ValueType # 1
class StopReason(_StopReason, metaclass=_StopReasonEnumTypeWrapper): ...
STOP_REASON_STOP: StopReason.ValueType # 0
STOP_REASON_LENGTH: StopReason.ValueType # 1
Global___StopReason: typing_extensions.TypeAlias = StopReason
@typing.final
class SampleResponse(google.protobuf.message.Message):
"""Public SampleResponse returned to SDK clients.
Field numbers are compatible with tinker_internal.SampleResponse.
"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
SEQUENCES_FIELD_NUMBER: builtins.int
PROMPT_LOGPROBS_FIELD_NUMBER: builtins.int
TOPK_PROMPT_LOGPROBS_FIELD_NUMBER: builtins.int
prompt_logprobs: builtins.bytes
"""np.array(prompt_logprobs, dtype=np.float32).tobytes()
NaN for missing positions (e.g. first prompt token)
"""
@property
def sequences(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[Global___SampledSequence]: ...
@property
def topk_prompt_logprobs(self) -> Global___TopkPromptLogprobs:
"""Optional: absent means no topk was requested."""
def __init__(
self,
*,
sequences: collections.abc.Iterable[Global___SampledSequence] | None = ...,
prompt_logprobs: builtins.bytes | None = ...,
topk_prompt_logprobs: Global___TopkPromptLogprobs | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["_prompt_logprobs", b"_prompt_logprobs", "_topk_prompt_logprobs", b"_topk_prompt_logprobs", "prompt_logprobs", b"prompt_logprobs", "topk_prompt_logprobs", b"topk_prompt_logprobs"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["_prompt_logprobs", b"_prompt_logprobs", "_topk_prompt_logprobs", b"_topk_prompt_logprobs", "prompt_logprobs", b"prompt_logprobs", "sequences", b"sequences", "topk_prompt_logprobs", b"topk_prompt_logprobs"]) -> None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing.Literal["_prompt_logprobs", b"_prompt_logprobs"]) -> typing.Literal["prompt_logprobs"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing.Literal["_topk_prompt_logprobs", b"_topk_prompt_logprobs"]) -> typing.Literal["topk_prompt_logprobs"] | None: ...
Global___SampleResponse: typing_extensions.TypeAlias = SampleResponse
@typing.final
class SampledSequence(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
STOP_REASON_FIELD_NUMBER: builtins.int
TOKENS_FIELD_NUMBER: builtins.int
LOGPROBS_FIELD_NUMBER: builtins.int
stop_reason: Global___StopReason.ValueType
tokens: builtins.bytes
"""np.array(tokens, dtype=np.int32).tobytes()"""
logprobs: builtins.bytes
"""np.array(logprobs, dtype=np.float32).tobytes()
Optional, same length as tokens when present
"""
def __init__(
self,
*,
stop_reason: Global___StopReason.ValueType = ...,
tokens: builtins.bytes = ...,
logprobs: builtins.bytes | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["_logprobs", b"_logprobs", "logprobs", b"logprobs"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["_logprobs", b"_logprobs", "logprobs", b"logprobs", "stop_reason", b"stop_reason", "tokens", b"tokens"]) -> None: ...
def WhichOneof(self, oneof_group: typing.Literal["_logprobs", b"_logprobs"]) -> typing.Literal["logprobs"] | None: ...
Global___SampledSequence: typing_extensions.TypeAlias = SampledSequence
@typing.final
class TopkPromptLogprobs(google.protobuf.message.Message):
"""Dense N×K matrices for top-k prompt logprobs.
Both token_ids and logprobs are row-major flattened: position i, rank j
is at index i*k + j. Empty cells use token_id=0 and logprob=-99999.0.
"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
TOKEN_IDS_FIELD_NUMBER: builtins.int
LOGPROBS_FIELD_NUMBER: builtins.int
K_FIELD_NUMBER: builtins.int
PROMPT_LENGTH_FIELD_NUMBER: builtins.int
token_ids: builtins.bytes
"""Row-major N×K matrix of int32 token IDs"""
logprobs: builtins.bytes
"""Row-major N×K matrix of float32 logprobs"""
k: builtins.int
"""Number of top-k entries per position"""
prompt_length: builtins.int
"""Number of prompt positions"""
def __init__(
self,
*,
token_ids: builtins.bytes = ...,
logprobs: builtins.bytes = ...,
k: builtins.int = ...,
prompt_length: builtins.int = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["k", b"k", "logprobs", b"logprobs", "prompt_length", b"prompt_length", "token_ids", b"token_ids"]) -> None: ...
Global___TopkPromptLogprobs: typing_extensions.TypeAlias = TopkPromptLogprobs