mirror of
https://github.com/thinking-machines-lab/tinker.git
synced 2026-04-19 12:58:01 +00:00
Sync contents
This commit is contained in:
parent
30517b667f
commit
2a1f5e3240
10 changed files with 384 additions and 7 deletions
|
|
@ -1,5 +1,9 @@
|
||||||
line-length = 100
|
line-length = 100
|
||||||
force-exclude = true
|
force-exclude = true
|
||||||
|
exclude = [
|
||||||
|
# Generated protobuf code
|
||||||
|
"src/tinker/proto",
|
||||||
|
]
|
||||||
|
|
||||||
[lint]
|
[lint]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
{
|
{
|
||||||
"last_synced_sha": "d117d1692821faa297ea5d2ee7e4dc21b5c8bd0a",
|
"last_synced_sha": "774d92ed933004d66c3e86070e0a0d2e713a0a31",
|
||||||
"last_sync_time": "2026-04-14T00:00:48.831738"
|
"last_sync_time": "2026-04-17T00:45:03.498497"
|
||||||
}
|
}
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
[project]
|
[project]
|
||||||
name = "tinker"
|
name = "tinker"
|
||||||
version = "0.18.0"
|
version = "0.18.1"
|
||||||
description = "The official Python SDK for the tinker API"
|
description = "The official Python SDK for the tinker API"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = "Apache-2.0"
|
license = "Apache-2.0"
|
||||||
|
|
@ -16,6 +16,7 @@ dependencies = [
|
||||||
"distro>=1.7.0, <2",
|
"distro>=1.7.0, <2",
|
||||||
"sniffio",
|
"sniffio",
|
||||||
"numpy",
|
"numpy",
|
||||||
|
"protobuf>=4.21",
|
||||||
"transformers",
|
"transformers",
|
||||||
"rich>=13.0.0",
|
"rich>=13.0.0",
|
||||||
"click>=8.0.0",
|
"click>=8.0.0",
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,8 @@ from .sync_only import sync_only
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from tinker.lib.internal_client_holder import InternalClientHolder
|
from tinker.lib.internal_client_holder import InternalClientHolder
|
||||||
|
|
||||||
|
from tinker.proto.response_conv import PROTO_SUPPORTED_TYPES, deserialize_proto_response
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
@ -115,6 +117,8 @@ class _APIFuture(APIFuture[T]): # pyright: ignore[reportUnusedClass]
|
||||||
"X-Tinker-Request-Iteration": str(iteration),
|
"X-Tinker-Request-Iteration": str(iteration),
|
||||||
"X-Tinker-Request-Type": self.request_type,
|
"X-Tinker-Request-Type": self.request_type,
|
||||||
}
|
}
|
||||||
|
if self.model_cls in PROTO_SUPPORTED_TYPES:
|
||||||
|
headers["Accept"] = "application/x-protobuf, application/json"
|
||||||
if iteration == 0:
|
if iteration == 0:
|
||||||
headers["X-Tinker-Create-Promise-Roundtrip-Time"] = str(
|
headers["X-Tinker-Create-Promise-Roundtrip-Time"] = str(
|
||||||
self.request_queue_roundtrip_time
|
self.request_queue_roundtrip_time
|
||||||
|
|
@ -213,6 +217,37 @@ class _APIFuture(APIFuture[T]): # pyright: ignore[reportUnusedClass]
|
||||||
connection_error_retries += 1
|
connection_error_retries += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Proto response path: server returned protobuf bytes
|
||||||
|
content_type = response.headers.get("content-type", "")
|
||||||
|
if "application/x-protobuf" in content_type:
|
||||||
|
proto_bytes = response.http_response.content
|
||||||
|
try:
|
||||||
|
self._cached_result = deserialize_proto_response(
|
||||||
|
proto_bytes, self.model_cls
|
||||||
|
)
|
||||||
|
return cast(T, self._cached_result)
|
||||||
|
except Exception as e:
|
||||||
|
if telemetry := self.get_telemetry():
|
||||||
|
current_time = time.time()
|
||||||
|
telemetry.log(
|
||||||
|
"APIFuture.result_async.proto_deserialization_error",
|
||||||
|
event_data={
|
||||||
|
"request_id": self.request_id,
|
||||||
|
"request_type": self.request_type,
|
||||||
|
"exception": str(e),
|
||||||
|
"exception_type": type(e).__name__,
|
||||||
|
"proto_bytes_len": len(proto_bytes),
|
||||||
|
"model_cls": str(self.model_cls),
|
||||||
|
"iteration": iteration,
|
||||||
|
"elapsed_time": current_time - start_time,
|
||||||
|
},
|
||||||
|
severity="ERROR",
|
||||||
|
)
|
||||||
|
raise ValueError(
|
||||||
|
f"Proto deserialization failed: {e} for {self.request_id=} and expected type {self.model_cls=}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
# JSON response path (existing)
|
||||||
result_dict: Any = await response.json()
|
result_dict: Any = await response.json()
|
||||||
|
|
||||||
if "type" in result_dict and result_dict["type"] == "try_again":
|
if "type" in result_dict and result_dict["type"] == "try_again":
|
||||||
|
|
|
||||||
|
|
@ -191,7 +191,9 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
|
||||||
_jwt_auth_seed: str | None = None,
|
_jwt_auth_seed: str | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._api_key = api_key
|
# Resolve from env now so shadow_kwargs carries the actual credential
|
||||||
|
# across pickle boundaries (workers may not have the env var set).
|
||||||
|
self._api_key = api_key or os.environ.get("TINKER_API_KEY")
|
||||||
self._constructor_kwargs = dict(kwargs)
|
self._constructor_kwargs = dict(kwargs)
|
||||||
self._loop: asyncio.AbstractEventLoop = _internal_client_holder_thread_singleton.get_loop()
|
self._loop: asyncio.AbstractEventLoop = _internal_client_holder_thread_singleton.get_loop()
|
||||||
self._client_pools: dict[ClientConnectionPoolType, ClientConnectionPool] = {}
|
self._client_pools: dict[ClientConnectionPoolType, ClientConnectionPool] = {}
|
||||||
|
|
@ -224,7 +226,7 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
|
||||||
if not self._client_config.pjwt_auth_enabled:
|
if not self._client_config.pjwt_auth_enabled:
|
||||||
# Without JWT exchange, only API keys are accepted by the server.
|
# Without JWT exchange, only API keys are accepted by the server.
|
||||||
# Replace any cmd-based provider with a plain API key provider.
|
# Replace any cmd-based provider with a plain API key provider.
|
||||||
self._default_auth = ApiKeyAuthProvider(api_key=api_key)
|
self._default_auth = ApiKeyAuthProvider(api_key=self._api_key)
|
||||||
else:
|
else:
|
||||||
# Create a dedicated pool for JWT exchange with the appropriate
|
# Create a dedicated pool for JWT exchange with the appropriate
|
||||||
# credential provider. The lambda captures the pool so it stays alive.
|
# credential provider. The lambda captures the pool so it stays alive.
|
||||||
|
|
|
||||||
|
|
@ -4,11 +4,11 @@ from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from tinker.lib._auth_token_provider import AuthTokenProvider
|
from tinker.lib._auth_token_provider import ApiKeyAuthProvider, AuthTokenProvider
|
||||||
from tinker.lib.internal_client_holder import ClientConnectionPool, InternalClientHolder
|
from tinker.lib.internal_client_holder import ClientConnectionPool, InternalClientHolder
|
||||||
from tinker.types.client_config_response import ClientConfigResponse as _ClientConfigResponse
|
from tinker.types.client_config_response import ClientConfigResponse as _ClientConfigResponse
|
||||||
|
|
||||||
|
|
@ -94,3 +94,65 @@ async def test_fetch_client_config_passes_sdk_version(
|
||||||
|
|
||||||
call_kwargs = holder._cm.__enter__.return_value.service.client_config.call_args
|
call_kwargs = holder._cm.__enter__.return_value.service.client_config.call_args
|
||||||
assert call_kwargs.kwargs["request"].sdk_version == tinker_sdk_version
|
assert call_kwargs.kwargs["request"].sdk_version == tinker_sdk_version
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Pickle round-trip: ambient TINKER_API_KEY must travel with the pickle
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_holder(api_key: str | None = None) -> InternalClientHolder:
    """Construct an InternalClientHolder whose server-facing calls are stubbed.

    ``_fetch_client_config``, ``_create_session`` and ``_start_heartbeat`` are
    replaced with ``AsyncMock`` for the duration of construction, and the
    heartbeat task attribute is swapped for a ``MagicMock`` afterwards.

    Args:
        api_key: Forwarded unchanged to ``InternalClientHolder``.

    Returns:
        The constructed holder.
    """
    # pjwt_auth_enabled=False → plain API-key auth path
    config_stub = patch.object(
        InternalClientHolder,
        "_fetch_client_config",
        new_callable=AsyncMock,
        return_value=_ClientConfigResponse(pjwt_auth_enabled=False),
    )
    session_stub = patch.object(
        InternalClientHolder,
        "_create_session",
        new_callable=AsyncMock,
        return_value="sess-pickle-test",
    )
    heartbeat_stub = patch.object(
        InternalClientHolder,
        "_start_heartbeat",
        new_callable=AsyncMock,
    )

    with config_stub, session_stub, heartbeat_stub:
        holder = InternalClientHolder(api_key=api_key)

    holder._session_heartbeat_task = MagicMock()
    return holder
|
||||||
|
|
||||||
|
|
||||||
|
def test_sampling_client_pickle_roundtrip_without_env_var(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """The API key picked up from the environment must survive pickling.

    A SamplingClient is built while TINKER_API_KEY is set (and no explicit
    api_key is passed), pickled, and then unpickled after the variable has
    been removed, which stands in for a worker process. The restored holder
    must still authenticate with the original key.
    """
    import pickle

    from tinker.lib.public_interfaces.sampling_client import SamplingClient

    # "Parent" side: the credential exists only as an ambient env var.
    monkeypatch.setenv("TINKER_API_KEY", "tml-key-from-env")
    original_client = SamplingClient(_make_holder(api_key=None), sampling_session_id="samp-1")
    serialized = pickle.dumps(original_client)

    # "Worker" side: no TINKER_API_KEY available.
    monkeypatch.delenv("TINKER_API_KEY")

    heartbeat_stub = patch.object(
        InternalClientHolder,
        "_start_heartbeat",
        new_callable=AsyncMock,
    )
    with heartbeat_stub:
        revived = pickle.loads(serialized)

    auth = revived.holder._default_auth
    assert isinstance(auth, ApiKeyAuthProvider)
    assert auth._token == "tml-key-from-env"
|
||||||
|
|
|
||||||
0
src/tinker/proto/__init__.py
Normal file
0
src/tinker/proto/__init__.py
Normal file
110
src/tinker/proto/response_conv.py
Normal file
110
src/tinker/proto/response_conv.py
Normal file
|
|
@ -0,0 +1,110 @@
|
||||||
|
"""Conversion helpers for proto responses to Pydantic models.
|
||||||
|
|
||||||
|
Deserializes proto wire format into SDK Pydantic types (SampleResponse, etc.).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from tinker.proto import tinker_public_pb2 as public_pb
|
||||||
|
from tinker.types.sample_response import SampleResponse
|
||||||
|
from tinker.types.sampled_sequence import SampledSequence
|
||||||
|
from tinker.types.stop_reason import StopReason
|
||||||
|
|
||||||
|
# Set of model classes that support proto deserialization.
|
||||||
|
# Used by api_future_impl to decide whether to send Accept: application/x-protobuf.
|
||||||
|
PROTO_SUPPORTED_TYPES: set[type] = {SampleResponse}
|
||||||
|
|
||||||
|
MASK_LOGPROB = -99999.0
|
||||||
|
|
||||||
|
# Proto enum -> SDK string mapping
|
||||||
|
_STOP_REASON_TO_STR: dict[int, StopReason] = {
|
||||||
|
public_pb.STOP_REASON_STOP: "stop",
|
||||||
|
public_pb.STOP_REASON_LENGTH: "length",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def deserialize_sample_response(proto_bytes: bytes) -> SampleResponse:
    """Decode a serialized proto ``SampleResponse`` into the SDK model.

    Tokens and logprobs travel as raw byte strings inside the proto; they are
    reinterpreted here as int32 / float32 buffers and converted to Python lists.
    NaN entries in ``prompt_logprobs`` become ``None``.

    Args:
        proto_bytes: Wire-format bytes of a ``tinker_public.SampleResponse``.

    Returns:
        A ``SampleResponse`` assembled via ``model_construct``.

    Raises:
        ValueError: If a sequence carries a stop_reason enum value that has no
            entry in ``_STOP_REASON_TO_STR``.
    """
    message = public_pb.SampleResponse()
    message.ParseFromString(proto_bytes)

    decoded_sequences = []
    for seq in message.sequences:
        reason = _STOP_REASON_TO_STR.get(seq.stop_reason)
        if reason is None:
            raise ValueError(
                f"Unknown stop_reason enum value {seq.stop_reason} in proto SampleResponse"
            )

        token_list = np.frombuffer(seq.tokens, dtype=np.int32).tolist()
        # An empty/unset logprobs payload is reported as None, not [].
        if seq.logprobs:
            logprob_list = np.frombuffer(seq.logprobs, dtype=np.float32).tolist()
        else:
            logprob_list = None

        decoded_sequences.append(
            SampledSequence.model_construct(
                stop_reason=reason,
                tokens=token_list,
                logprobs=logprob_list,
            )
        )

    prompt_logprobs: list[float | None] | None = None
    if message.prompt_logprobs:
        raw = np.frombuffer(message.prompt_logprobs, dtype=np.float32)
        # NaN marks a position with no logprob; surface it as None.
        prompt_logprobs = [
            None if is_missing else value
            for value, is_missing in zip(raw.tolist(), np.isnan(raw).tolist())
        ]

    topk_prompt_logprobs: list[list[tuple[int, float]] | None] | None = (
        _topk_from_proto(message.topk_prompt_logprobs)
        if message.HasField("topk_prompt_logprobs")
        else None
    )

    return SampleResponse.model_construct(
        sequences=decoded_sequences,
        prompt_logprobs=prompt_logprobs,
        topk_prompt_logprobs=topk_prompt_logprobs,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _topk_from_proto(
    topk: public_pb.TopkPromptLogprobs,
) -> list[list[tuple[int, float]] | None]:
    """Unpack a dense ``prompt_length`` × ``k`` top-k block into per-position lists.

    Each position yields a list of ``(token_id, logprob)`` pairs. A cell holding
    token id 0 together with ``MASK_LOGPROB`` is treated as padding: a row whose
    first cell is padding becomes ``None``, and trailing padding cells are
    trimmed from every other row.

    Args:
        topk: Proto message carrying the flattened int32 token ids and float32
            logprobs plus the ``prompt_length`` and ``k`` dimensions.

    Returns:
        One entry per prompt position, or an empty list when either dimension
        is zero.
    """
    rows, width = topk.prompt_length, topk.k
    if rows == 0 or width == 0:
        return []

    shape = (rows, width)
    id_matrix = np.ndarray(shape, dtype=np.int32, buffer=topk.token_ids)
    lp_matrix = np.ndarray(shape, dtype=np.float32, buffer=topk.logprobs)

    # Pair everything up once over the flattened data instead of zipping each
    # row separately; rows are then carved out by slicing.
    pairs = list(zip(id_matrix.ravel().tolist(), lp_matrix.ravel().tolist()))
    padding = (0, MASK_LOGPROB)

    per_position: list[list[tuple[int, float]] | None] = []
    for lo in range(0, rows * width, width):
        # Padding in the leading cell means the whole position is absent.
        if pairs[lo] == padding:
            per_position.append(None)
            continue
        hi = lo + width
        while hi > lo and pairs[hi - 1] == padding:
            hi -= 1
        per_position.append(pairs[lo:hi])
    return per_position
|
||||||
|
|
||||||
|
|
||||||
|
def deserialize_proto_response(proto_bytes: bytes, model_cls: type) -> object:
    """Turn proto wire bytes into the SDK model requested by the caller.

    Args:
        proto_bytes: Serialized proto message.
        model_cls: The model class the caller expects back; selects the
            deserializer to run.

    Returns:
        The deserialized model instance.

    Raises:
        ValueError: If no proto deserializer exists for ``model_cls``.
    """
    # SampleResponse is the only type with a proto deserializer here.
    if model_cls is not SampleResponse:
        raise ValueError(f"Proto deserialization not supported for {model_cls}")
    return deserialize_sample_response(proto_bytes)
|
||||||
31
src/tinker/proto/tinker_public_pb2.py
Normal file
31
src/tinker/proto/tinker_public_pb2.py
Normal file
|
|
@ -0,0 +1,31 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||||
|
# source: tinker_public.proto
|
||||||
|
"""Generated protocol buffer code."""
|
||||||
|
from google.protobuf.internal import builder as _builder
|
||||||
|
from google.protobuf import descriptor as _descriptor
|
||||||
|
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||||
|
from google.protobuf import symbol_database as _symbol_database
|
||||||
|
# @@protoc_insertion_point(imports)
|
||||||
|
|
||||||
|
_sym_db = _symbol_database.Default()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x13tinker_public.proto\x12\rtinker_public\"\xd4\x01\n\x0eSampleResponse\x12\x31\n\tsequences\x18\x01 \x03(\x0b\x32\x1e.tinker_public.SampledSequence\x12\x1c\n\x0fprompt_logprobs\x18\x02 \x01(\x0cH\x00\x88\x01\x01\x12\x44\n\x14topk_prompt_logprobs\x18\x03 \x01(\x0b\x32!.tinker_public.TopkPromptLogprobsH\x01\x88\x01\x01\x42\x12\n\x10_prompt_logprobsB\x17\n\x15_topk_prompt_logprobs\"u\n\x0fSampledSequence\x12.\n\x0bstop_reason\x18\x01 \x01(\x0e\x32\x19.tinker_public.StopReason\x12\x0e\n\x06tokens\x18\x02 \x01(\x0c\x12\x15\n\x08logprobs\x18\x03 \x01(\x0cH\x00\x88\x01\x01\x42\x0b\n\t_logprobs\"[\n\x12TopkPromptLogprobs\x12\x11\n\ttoken_ids\x18\x01 \x01(\x0c\x12\x10\n\x08logprobs\x18\x02 \x01(\x0c\x12\t\n\x01k\x18\x03 \x01(\r\x12\x15\n\rprompt_length\x18\x04 \x01(\r*:\n\nStopReason\x12\x14\n\x10STOP_REASON_STOP\x10\x00\x12\x16\n\x12STOP_REASON_LENGTH\x10\x01\x62\x06proto3')
|
||||||
|
|
||||||
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
||||||
|
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tinker_public_pb2', globals())
|
||||||
|
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||||
|
|
||||||
|
DESCRIPTOR._options = None
|
||||||
|
_STOPREASON._serialized_start=465
|
||||||
|
_STOPREASON._serialized_end=523
|
||||||
|
_SAMPLERESPONSE._serialized_start=39
|
||||||
|
_SAMPLERESPONSE._serialized_end=251
|
||||||
|
_SAMPLEDSEQUENCE._serialized_start=253
|
||||||
|
_SAMPLEDSEQUENCE._serialized_end=370
|
||||||
|
_TOPKPROMPTLOGPROBS._serialized_start=372
|
||||||
|
_TOPKPROMPTLOGPROBS._serialized_end=463
|
||||||
|
# @@protoc_insertion_point(module_scope)
|
||||||
132
src/tinker/proto/tinker_public_pb2.pyi
Normal file
132
src/tinker/proto/tinker_public_pb2.pyi
Normal file
|
|
@ -0,0 +1,132 @@
|
||||||
|
"""
|
||||||
|
@generated by mypy-protobuf. Do not edit manually!
|
||||||
|
isort:skip_file
|
||||||
|
"""
|
||||||
|
|
||||||
|
import builtins
|
||||||
|
import collections.abc
|
||||||
|
import google.protobuf.descriptor
|
||||||
|
import google.protobuf.internal.containers
|
||||||
|
import google.protobuf.internal.enum_type_wrapper
|
||||||
|
import google.protobuf.message
|
||||||
|
import sys
|
||||||
|
import typing
|
||||||
|
|
||||||
|
if sys.version_info >= (3, 10):
|
||||||
|
import typing as typing_extensions
|
||||||
|
else:
|
||||||
|
import typing_extensions
|
||||||
|
|
||||||
|
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
|
||||||
|
|
||||||
|
class _StopReason:
|
||||||
|
ValueType = typing.NewType("ValueType", builtins.int)
|
||||||
|
V: typing_extensions.TypeAlias = ValueType
|
||||||
|
|
||||||
|
class _StopReasonEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[_StopReason.ValueType], builtins.type):
|
||||||
|
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
|
||||||
|
STOP_REASON_STOP: _StopReason.ValueType # 0
|
||||||
|
STOP_REASON_LENGTH: _StopReason.ValueType # 1
|
||||||
|
|
||||||
|
class StopReason(_StopReason, metaclass=_StopReasonEnumTypeWrapper): ...
|
||||||
|
|
||||||
|
STOP_REASON_STOP: StopReason.ValueType # 0
|
||||||
|
STOP_REASON_LENGTH: StopReason.ValueType # 1
|
||||||
|
Global___StopReason: typing_extensions.TypeAlias = StopReason
|
||||||
|
|
||||||
|
@typing.final
|
||||||
|
class SampleResponse(google.protobuf.message.Message):
|
||||||
|
"""Public SampleResponse returned to SDK clients.
|
||||||
|
Field numbers are compatible with tinker_internal.SampleResponse.
|
||||||
|
"""
|
||||||
|
|
||||||
|
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||||
|
|
||||||
|
SEQUENCES_FIELD_NUMBER: builtins.int
|
||||||
|
PROMPT_LOGPROBS_FIELD_NUMBER: builtins.int
|
||||||
|
TOPK_PROMPT_LOGPROBS_FIELD_NUMBER: builtins.int
|
||||||
|
prompt_logprobs: builtins.bytes
|
||||||
|
"""np.array(prompt_logprobs, dtype=np.float32).tobytes()
|
||||||
|
NaN for missing positions (e.g. first prompt token)
|
||||||
|
"""
|
||||||
|
@property
|
||||||
|
def sequences(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[Global___SampledSequence]: ...
|
||||||
|
@property
|
||||||
|
def topk_prompt_logprobs(self) -> Global___TopkPromptLogprobs:
|
||||||
|
"""Optional: absent means no topk was requested."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
sequences: collections.abc.Iterable[Global___SampledSequence] | None = ...,
|
||||||
|
prompt_logprobs: builtins.bytes | None = ...,
|
||||||
|
topk_prompt_logprobs: Global___TopkPromptLogprobs | None = ...,
|
||||||
|
) -> None: ...
|
||||||
|
def HasField(self, field_name: typing.Literal["_prompt_logprobs", b"_prompt_logprobs", "_topk_prompt_logprobs", b"_topk_prompt_logprobs", "prompt_logprobs", b"prompt_logprobs", "topk_prompt_logprobs", b"topk_prompt_logprobs"]) -> builtins.bool: ...
|
||||||
|
def ClearField(self, field_name: typing.Literal["_prompt_logprobs", b"_prompt_logprobs", "_topk_prompt_logprobs", b"_topk_prompt_logprobs", "prompt_logprobs", b"prompt_logprobs", "sequences", b"sequences", "topk_prompt_logprobs", b"topk_prompt_logprobs"]) -> None: ...
|
||||||
|
@typing.overload
|
||||||
|
def WhichOneof(self, oneof_group: typing.Literal["_prompt_logprobs", b"_prompt_logprobs"]) -> typing.Literal["prompt_logprobs"] | None: ...
|
||||||
|
@typing.overload
|
||||||
|
def WhichOneof(self, oneof_group: typing.Literal["_topk_prompt_logprobs", b"_topk_prompt_logprobs"]) -> typing.Literal["topk_prompt_logprobs"] | None: ...
|
||||||
|
|
||||||
|
Global___SampleResponse: typing_extensions.TypeAlias = SampleResponse
|
||||||
|
|
||||||
|
@typing.final
|
||||||
|
class SampledSequence(google.protobuf.message.Message):
|
||||||
|
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||||
|
|
||||||
|
STOP_REASON_FIELD_NUMBER: builtins.int
|
||||||
|
TOKENS_FIELD_NUMBER: builtins.int
|
||||||
|
LOGPROBS_FIELD_NUMBER: builtins.int
|
||||||
|
stop_reason: Global___StopReason.ValueType
|
||||||
|
tokens: builtins.bytes
|
||||||
|
"""np.array(tokens, dtype=np.int32).tobytes()"""
|
||||||
|
logprobs: builtins.bytes
|
||||||
|
"""np.array(logprobs, dtype=np.float32).tobytes()
|
||||||
|
Optional, same length as tokens when present
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
stop_reason: Global___StopReason.ValueType = ...,
|
||||||
|
tokens: builtins.bytes = ...,
|
||||||
|
logprobs: builtins.bytes | None = ...,
|
||||||
|
) -> None: ...
|
||||||
|
def HasField(self, field_name: typing.Literal["_logprobs", b"_logprobs", "logprobs", b"logprobs"]) -> builtins.bool: ...
|
||||||
|
def ClearField(self, field_name: typing.Literal["_logprobs", b"_logprobs", "logprobs", b"logprobs", "stop_reason", b"stop_reason", "tokens", b"tokens"]) -> None: ...
|
||||||
|
def WhichOneof(self, oneof_group: typing.Literal["_logprobs", b"_logprobs"]) -> typing.Literal["logprobs"] | None: ...
|
||||||
|
|
||||||
|
Global___SampledSequence: typing_extensions.TypeAlias = SampledSequence
|
||||||
|
|
||||||
|
@typing.final
|
||||||
|
class TopkPromptLogprobs(google.protobuf.message.Message):
|
||||||
|
"""Dense N×K matrices for top-k prompt logprobs.
|
||||||
|
Both token_ids and logprobs are row-major flattened: position i, rank j
|
||||||
|
is at index i*k + j. Empty cells use token_id=0 and logprob=-99999.0.
|
||||||
|
"""
|
||||||
|
|
||||||
|
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||||
|
|
||||||
|
TOKEN_IDS_FIELD_NUMBER: builtins.int
|
||||||
|
LOGPROBS_FIELD_NUMBER: builtins.int
|
||||||
|
K_FIELD_NUMBER: builtins.int
|
||||||
|
PROMPT_LENGTH_FIELD_NUMBER: builtins.int
|
||||||
|
token_ids: builtins.bytes
|
||||||
|
"""Row-major N×K matrix of int32 token IDs"""
|
||||||
|
logprobs: builtins.bytes
|
||||||
|
"""Row-major N×K matrix of float32 logprobs"""
|
||||||
|
k: builtins.int
|
||||||
|
"""Number of top-k entries per position"""
|
||||||
|
prompt_length: builtins.int
|
||||||
|
"""Number of prompt positions"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
token_ids: builtins.bytes = ...,
|
||||||
|
logprobs: builtins.bytes = ...,
|
||||||
|
k: builtins.int = ...,
|
||||||
|
prompt_length: builtins.int = ...,
|
||||||
|
) -> None: ...
|
||||||
|
def ClearField(self, field_name: typing.Literal["k", b"k", "logprobs", b"logprobs", "prompt_length", b"prompt_length", "token_ids", b"token_ids"]) -> None: ...
|
||||||
|
|
||||||
|
Global___TopkPromptLogprobs: typing_extensions.TypeAlias = TopkPromptLogprobs
|
||||||
Loading…
Add table
Add a link
Reference in a new issue