mirror of
https://github.com/thinking-machines-lab/tinker.git
synced 2026-04-19 12:58:01 +00:00
Sync contents
This commit is contained in:
parent
30517b667f
commit
2a1f5e3240
10 changed files with 384 additions and 7 deletions
|
|
@ -1,5 +1,9 @@
|
|||
line-length = 100
|
||||
force-exclude = true
|
||||
exclude = [
|
||||
# Generated protobuf code
|
||||
"src/tinker/proto",
|
||||
]
|
||||
|
||||
[lint]
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
{
|
||||
"last_synced_sha": "d117d1692821faa297ea5d2ee7e4dc21b5c8bd0a",
|
||||
"last_sync_time": "2026-04-14T00:00:48.831738"
|
||||
"last_synced_sha": "774d92ed933004d66c3e86070e0a0d2e713a0a31",
|
||||
"last_sync_time": "2026-04-17T00:45:03.498497"
|
||||
}
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
[project]
|
||||
name = "tinker"
|
||||
version = "0.18.0"
|
||||
version = "0.18.1"
|
||||
description = "The official Python SDK for the tinker API"
|
||||
readme = "README.md"
|
||||
license = "Apache-2.0"
|
||||
|
|
@ -16,6 +16,7 @@ dependencies = [
|
|||
"distro>=1.7.0, <2",
|
||||
"sniffio",
|
||||
"numpy",
|
||||
"protobuf>=4.21",
|
||||
"transformers",
|
||||
"rich>=13.0.0",
|
||||
"click>=8.0.0",
|
||||
|
|
|
|||
|
|
@ -26,6 +26,8 @@ from .sync_only import sync_only
|
|||
if TYPE_CHECKING:
|
||||
from tinker.lib.internal_client_holder import InternalClientHolder
|
||||
|
||||
from tinker.proto.response_conv import PROTO_SUPPORTED_TYPES, deserialize_proto_response
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
|
@ -115,6 +117,8 @@ class _APIFuture(APIFuture[T]): # pyright: ignore[reportUnusedClass]
|
|||
"X-Tinker-Request-Iteration": str(iteration),
|
||||
"X-Tinker-Request-Type": self.request_type,
|
||||
}
|
||||
if self.model_cls in PROTO_SUPPORTED_TYPES:
|
||||
headers["Accept"] = "application/x-protobuf, application/json"
|
||||
if iteration == 0:
|
||||
headers["X-Tinker-Create-Promise-Roundtrip-Time"] = str(
|
||||
self.request_queue_roundtrip_time
|
||||
|
|
@ -213,6 +217,37 @@ class _APIFuture(APIFuture[T]): # pyright: ignore[reportUnusedClass]
|
|||
connection_error_retries += 1
|
||||
continue
|
||||
|
||||
# Proto response path: server returned protobuf bytes
|
||||
content_type = response.headers.get("content-type", "")
|
||||
if "application/x-protobuf" in content_type:
|
||||
proto_bytes = response.http_response.content
|
||||
try:
|
||||
self._cached_result = deserialize_proto_response(
|
||||
proto_bytes, self.model_cls
|
||||
)
|
||||
return cast(T, self._cached_result)
|
||||
except Exception as e:
|
||||
if telemetry := self.get_telemetry():
|
||||
current_time = time.time()
|
||||
telemetry.log(
|
||||
"APIFuture.result_async.proto_deserialization_error",
|
||||
event_data={
|
||||
"request_id": self.request_id,
|
||||
"request_type": self.request_type,
|
||||
"exception": str(e),
|
||||
"exception_type": type(e).__name__,
|
||||
"proto_bytes_len": len(proto_bytes),
|
||||
"model_cls": str(self.model_cls),
|
||||
"iteration": iteration,
|
||||
"elapsed_time": current_time - start_time,
|
||||
},
|
||||
severity="ERROR",
|
||||
)
|
||||
raise ValueError(
|
||||
f"Proto deserialization failed: {e} for {self.request_id=} and expected type {self.model_cls=}"
|
||||
) from e
|
||||
|
||||
# JSON response path (existing)
|
||||
result_dict: Any = await response.json()
|
||||
|
||||
if "type" in result_dict and result_dict["type"] == "try_again":
|
||||
|
|
|
|||
|
|
@ -191,7 +191,9 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
|
|||
_jwt_auth_seed: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self._api_key = api_key
|
||||
# Resolve from env now so shadow_kwargs carries the actual credential
|
||||
# across pickle boundaries (workers may not have the env var set).
|
||||
self._api_key = api_key or os.environ.get("TINKER_API_KEY")
|
||||
self._constructor_kwargs = dict(kwargs)
|
||||
self._loop: asyncio.AbstractEventLoop = _internal_client_holder_thread_singleton.get_loop()
|
||||
self._client_pools: dict[ClientConnectionPoolType, ClientConnectionPool] = {}
|
||||
|
|
@ -224,7 +226,7 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
|
|||
if not self._client_config.pjwt_auth_enabled:
|
||||
# Without JWT exchange, only API keys are accepted by the server.
|
||||
# Replace any cmd-based provider with a plain API key provider.
|
||||
self._default_auth = ApiKeyAuthProvider(api_key=api_key)
|
||||
self._default_auth = ApiKeyAuthProvider(api_key=self._api_key)
|
||||
else:
|
||||
# Create a dedicated pool for JWT exchange with the appropriate
|
||||
# credential provider. The lambda captures the pool so it stays alive.
|
||||
|
|
|
|||
|
|
@ -4,11 +4,11 @@ from __future__ import annotations
|
|||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tinker.lib._auth_token_provider import AuthTokenProvider
|
||||
from tinker.lib._auth_token_provider import ApiKeyAuthProvider, AuthTokenProvider
|
||||
from tinker.lib.internal_client_holder import ClientConnectionPool, InternalClientHolder
|
||||
from tinker.types.client_config_response import ClientConfigResponse as _ClientConfigResponse
|
||||
|
||||
|
|
@ -94,3 +94,65 @@ async def test_fetch_client_config_passes_sdk_version(
|
|||
|
||||
call_kwargs = holder._cm.__enter__.return_value.service.client_config.call_args
|
||||
assert call_kwargs.kwargs["request"].sdk_version == tinker_sdk_version
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pickle round-trip: ambient TINKER_API_KEY must travel with the pickle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_holder(api_key: str | None = None) -> InternalClientHolder:
|
||||
"""Build a primary InternalClientHolder with server calls stubbed out."""
|
||||
with (
|
||||
patch.object(
|
||||
InternalClientHolder,
|
||||
"_fetch_client_config",
|
||||
new_callable=AsyncMock,
|
||||
# pjwt_auth_enabled=False → plain API-key auth path
|
||||
return_value=_ClientConfigResponse(pjwt_auth_enabled=False),
|
||||
),
|
||||
patch.object(
|
||||
InternalClientHolder,
|
||||
"_create_session",
|
||||
new_callable=AsyncMock,
|
||||
return_value="sess-pickle-test",
|
||||
),
|
||||
patch.object(
|
||||
InternalClientHolder,
|
||||
"_start_heartbeat",
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
):
|
||||
holder = InternalClientHolder(api_key=api_key)
|
||||
holder._session_heartbeat_task = MagicMock()
|
||||
return holder
|
||||
|
||||
|
||||
def test_sampling_client_pickle_roundtrip_without_env_var(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Pickle a SamplingClient created with ambient TINKER_API_KEY, then
|
||||
unpickle it in an environment without the env var (simulating a worker
|
||||
process). The credential must travel inside the pickle payload."""
|
||||
import pickle
|
||||
|
||||
from tinker.lib.public_interfaces.sampling_client import SamplingClient
|
||||
|
||||
monkeypatch.setenv("TINKER_API_KEY", "tml-key-from-env")
|
||||
holder = _make_holder(api_key=None)
|
||||
client = SamplingClient(holder, sampling_session_id="samp-1")
|
||||
|
||||
payload = pickle.dumps(client)
|
||||
|
||||
# Simulate the worker: no TINKER_API_KEY available.
|
||||
monkeypatch.delenv("TINKER_API_KEY")
|
||||
|
||||
with patch.object(
|
||||
InternalClientHolder,
|
||||
"_start_heartbeat",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
restored = pickle.loads(payload)
|
||||
|
||||
assert isinstance(restored.holder._default_auth, ApiKeyAuthProvider)
|
||||
assert restored.holder._default_auth._token == "tml-key-from-env"
|
||||
|
|
|
|||
0
src/tinker/proto/__init__.py
Normal file
0
src/tinker/proto/__init__.py
Normal file
110
src/tinker/proto/response_conv.py
Normal file
110
src/tinker/proto/response_conv.py
Normal file
|
|
@ -0,0 +1,110 @@
|
|||
"""Conversion helpers for proto responses to Pydantic models.
|
||||
|
||||
Deserializes proto wire format into SDK Pydantic types (SampleResponse, etc.).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tinker.proto import tinker_public_pb2 as public_pb
|
||||
from tinker.types.sample_response import SampleResponse
|
||||
from tinker.types.sampled_sequence import SampledSequence
|
||||
from tinker.types.stop_reason import StopReason
|
||||
|
||||
# Set of model classes that support proto deserialization.
|
||||
# Used by api_future_impl to decide whether to send Accept: application/x-protobuf.
|
||||
PROTO_SUPPORTED_TYPES: set[type] = {SampleResponse}
|
||||
|
||||
MASK_LOGPROB = -99999.0
|
||||
|
||||
# Proto enum -> SDK string mapping
|
||||
_STOP_REASON_TO_STR: dict[int, StopReason] = {
|
||||
public_pb.STOP_REASON_STOP: "stop",
|
||||
public_pb.STOP_REASON_LENGTH: "length",
|
||||
}
|
||||
|
||||
|
||||
def deserialize_sample_response(proto_bytes: bytes) -> SampleResponse:
|
||||
"""Deserialize proto bytes into a Pydantic SampleResponse."""
|
||||
proto = public_pb.SampleResponse()
|
||||
proto.ParseFromString(proto_bytes)
|
||||
|
||||
sequences = []
|
||||
for seq in proto.sequences:
|
||||
stop_reason = _STOP_REASON_TO_STR.get(seq.stop_reason)
|
||||
if stop_reason is None:
|
||||
raise ValueError(
|
||||
f"Unknown stop_reason enum value {seq.stop_reason} in proto SampleResponse"
|
||||
)
|
||||
tokens = np.frombuffer(seq.tokens, dtype=np.int32).tolist()
|
||||
logprobs = np.frombuffer(seq.logprobs, dtype=np.float32).tolist() if seq.logprobs else None
|
||||
sequences.append(
|
||||
SampledSequence.model_construct(
|
||||
stop_reason=stop_reason,
|
||||
tokens=tokens,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
)
|
||||
|
||||
prompt_logprobs: list[float | None] | None = None
|
||||
if proto.prompt_logprobs:
|
||||
arr = np.frombuffer(proto.prompt_logprobs, dtype=np.float32)
|
||||
prompt_logprobs_list: list[float | None] = arr.tolist()
|
||||
for i in np.flatnonzero(np.isnan(arr)):
|
||||
prompt_logprobs_list[i] = None
|
||||
prompt_logprobs = prompt_logprobs_list
|
||||
|
||||
topk_prompt_logprobs: list[list[tuple[int, float]] | None] | None = None
|
||||
if proto.HasField("topk_prompt_logprobs"):
|
||||
topk_prompt_logprobs = _topk_from_proto(proto.topk_prompt_logprobs)
|
||||
|
||||
return SampleResponse.model_construct(
|
||||
sequences=sequences,
|
||||
prompt_logprobs=prompt_logprobs,
|
||||
topk_prompt_logprobs=topk_prompt_logprobs,
|
||||
)
|
||||
|
||||
|
||||
def _topk_from_proto(
|
||||
topk: public_pb.TopkPromptLogprobs,
|
||||
) -> list[list[tuple[int, float]] | None]:
|
||||
"""Convert dense N×K TopkPromptLogprobs to Python list format."""
|
||||
n = topk.prompt_length
|
||||
k = topk.k
|
||||
|
||||
if n == 0 or k == 0:
|
||||
return []
|
||||
|
||||
token_ids = np.ndarray((n, k), dtype=np.int32, buffer=topk.token_ids)
|
||||
logprobs = np.ndarray((n, k), dtype=np.float32, buffer=topk.logprobs)
|
||||
|
||||
# Single flat zip (faster than 32K per-row zips), then slice per row
|
||||
tid_flat = token_ids.ravel().tolist()
|
||||
lp_flat = logprobs.ravel().tolist()
|
||||
all_tuples = list(zip(tid_flat, lp_flat))
|
||||
|
||||
mask_lp = MASK_LOGPROB
|
||||
result: list[list[tuple[int, float]] | None] = []
|
||||
for i in range(n):
|
||||
start = i * k
|
||||
# First-element sentinel check: if first entry is sentinel, whole row is None
|
||||
if tid_flat[start] == 0 and lp_flat[start] == mask_lp:
|
||||
result.append(None)
|
||||
else:
|
||||
end = start + k
|
||||
while end > start and tid_flat[end - 1] == 0 and lp_flat[end - 1] == mask_lp:
|
||||
end -= 1
|
||||
result.append(all_tuples[start:end])
|
||||
return result
|
||||
|
||||
|
||||
def deserialize_proto_response(proto_bytes: bytes, model_cls: type) -> object:
|
||||
"""Deserialize a proto response based on the expected model class.
|
||||
|
||||
Dispatches to the appropriate deserializer based on model_cls.
|
||||
Raises ValueError for unsupported types.
|
||||
"""
|
||||
if model_cls is SampleResponse:
|
||||
return deserialize_sample_response(proto_bytes)
|
||||
raise ValueError(f"Proto deserialization not supported for {model_cls}")
|
||||
31
src/tinker/proto/tinker_public_pb2.py
Normal file
31
src/tinker/proto/tinker_public_pb2.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: tinker_public.proto
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf.internal import builder as _builder
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x13tinker_public.proto\x12\rtinker_public\"\xd4\x01\n\x0eSampleResponse\x12\x31\n\tsequences\x18\x01 \x03(\x0b\x32\x1e.tinker_public.SampledSequence\x12\x1c\n\x0fprompt_logprobs\x18\x02 \x01(\x0cH\x00\x88\x01\x01\x12\x44\n\x14topk_prompt_logprobs\x18\x03 \x01(\x0b\x32!.tinker_public.TopkPromptLogprobsH\x01\x88\x01\x01\x42\x12\n\x10_prompt_logprobsB\x17\n\x15_topk_prompt_logprobs\"u\n\x0fSampledSequence\x12.\n\x0bstop_reason\x18\x01 \x01(\x0e\x32\x19.tinker_public.StopReason\x12\x0e\n\x06tokens\x18\x02 \x01(\x0c\x12\x15\n\x08logprobs\x18\x03 \x01(\x0cH\x00\x88\x01\x01\x42\x0b\n\t_logprobs\"[\n\x12TopkPromptLogprobs\x12\x11\n\ttoken_ids\x18\x01 \x01(\x0c\x12\x10\n\x08logprobs\x18\x02 \x01(\x0c\x12\t\n\x01k\x18\x03 \x01(\r\x12\x15\n\rprompt_length\x18\x04 \x01(\r*:\n\nStopReason\x12\x14\n\x10STOP_REASON_STOP\x10\x00\x12\x16\n\x12STOP_REASON_LENGTH\x10\x01\x62\x06proto3')
|
||||
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tinker_public_pb2', globals())
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
|
||||
DESCRIPTOR._options = None
|
||||
_STOPREASON._serialized_start=465
|
||||
_STOPREASON._serialized_end=523
|
||||
_SAMPLERESPONSE._serialized_start=39
|
||||
_SAMPLERESPONSE._serialized_end=251
|
||||
_SAMPLEDSEQUENCE._serialized_start=253
|
||||
_SAMPLEDSEQUENCE._serialized_end=370
|
||||
_TOPKPROMPTLOGPROBS._serialized_start=372
|
||||
_TOPKPROMPTLOGPROBS._serialized_end=463
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
132
src/tinker/proto/tinker_public_pb2.pyi
Normal file
132
src/tinker/proto/tinker_public_pb2.pyi
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
"""
|
||||
@generated by mypy-protobuf. Do not edit manually!
|
||||
isort:skip_file
|
||||
"""
|
||||
|
||||
import builtins
|
||||
import collections.abc
|
||||
import google.protobuf.descriptor
|
||||
import google.protobuf.internal.containers
|
||||
import google.protobuf.internal.enum_type_wrapper
|
||||
import google.protobuf.message
|
||||
import sys
|
||||
import typing
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
import typing as typing_extensions
|
||||
else:
|
||||
import typing_extensions
|
||||
|
||||
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
|
||||
|
||||
class _StopReason:
|
||||
ValueType = typing.NewType("ValueType", builtins.int)
|
||||
V: typing_extensions.TypeAlias = ValueType
|
||||
|
||||
class _StopReasonEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[_StopReason.ValueType], builtins.type):
|
||||
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
|
||||
STOP_REASON_STOP: _StopReason.ValueType # 0
|
||||
STOP_REASON_LENGTH: _StopReason.ValueType # 1
|
||||
|
||||
class StopReason(_StopReason, metaclass=_StopReasonEnumTypeWrapper): ...
|
||||
|
||||
STOP_REASON_STOP: StopReason.ValueType # 0
|
||||
STOP_REASON_LENGTH: StopReason.ValueType # 1
|
||||
Global___StopReason: typing_extensions.TypeAlias = StopReason
|
||||
|
||||
@typing.final
|
||||
class SampleResponse(google.protobuf.message.Message):
|
||||
"""Public SampleResponse returned to SDK clients.
|
||||
Field numbers are compatible with tinker_internal.SampleResponse.
|
||||
"""
|
||||
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
SEQUENCES_FIELD_NUMBER: builtins.int
|
||||
PROMPT_LOGPROBS_FIELD_NUMBER: builtins.int
|
||||
TOPK_PROMPT_LOGPROBS_FIELD_NUMBER: builtins.int
|
||||
prompt_logprobs: builtins.bytes
|
||||
"""np.array(prompt_logprobs, dtype=np.float32).tobytes()
|
||||
NaN for missing positions (e.g. first prompt token)
|
||||
"""
|
||||
@property
|
||||
def sequences(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[Global___SampledSequence]: ...
|
||||
@property
|
||||
def topk_prompt_logprobs(self) -> Global___TopkPromptLogprobs:
|
||||
"""Optional: absent means no topk was requested."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
sequences: collections.abc.Iterable[Global___SampledSequence] | None = ...,
|
||||
prompt_logprobs: builtins.bytes | None = ...,
|
||||
topk_prompt_logprobs: Global___TopkPromptLogprobs | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["_prompt_logprobs", b"_prompt_logprobs", "_topk_prompt_logprobs", b"_topk_prompt_logprobs", "prompt_logprobs", b"prompt_logprobs", "topk_prompt_logprobs", b"topk_prompt_logprobs"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["_prompt_logprobs", b"_prompt_logprobs", "_topk_prompt_logprobs", b"_topk_prompt_logprobs", "prompt_logprobs", b"prompt_logprobs", "sequences", b"sequences", "topk_prompt_logprobs", b"topk_prompt_logprobs"]) -> None: ...
|
||||
@typing.overload
|
||||
def WhichOneof(self, oneof_group: typing.Literal["_prompt_logprobs", b"_prompt_logprobs"]) -> typing.Literal["prompt_logprobs"] | None: ...
|
||||
@typing.overload
|
||||
def WhichOneof(self, oneof_group: typing.Literal["_topk_prompt_logprobs", b"_topk_prompt_logprobs"]) -> typing.Literal["topk_prompt_logprobs"] | None: ...
|
||||
|
||||
Global___SampleResponse: typing_extensions.TypeAlias = SampleResponse
|
||||
|
||||
@typing.final
|
||||
class SampledSequence(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
STOP_REASON_FIELD_NUMBER: builtins.int
|
||||
TOKENS_FIELD_NUMBER: builtins.int
|
||||
LOGPROBS_FIELD_NUMBER: builtins.int
|
||||
stop_reason: Global___StopReason.ValueType
|
||||
tokens: builtins.bytes
|
||||
"""np.array(tokens, dtype=np.int32).tobytes()"""
|
||||
logprobs: builtins.bytes
|
||||
"""np.array(logprobs, dtype=np.float32).tobytes()
|
||||
Optional, same length as tokens when present
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
stop_reason: Global___StopReason.ValueType = ...,
|
||||
tokens: builtins.bytes = ...,
|
||||
logprobs: builtins.bytes | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["_logprobs", b"_logprobs", "logprobs", b"logprobs"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["_logprobs", b"_logprobs", "logprobs", b"logprobs", "stop_reason", b"stop_reason", "tokens", b"tokens"]) -> None: ...
|
||||
def WhichOneof(self, oneof_group: typing.Literal["_logprobs", b"_logprobs"]) -> typing.Literal["logprobs"] | None: ...
|
||||
|
||||
Global___SampledSequence: typing_extensions.TypeAlias = SampledSequence
|
||||
|
||||
@typing.final
|
||||
class TopkPromptLogprobs(google.protobuf.message.Message):
|
||||
"""Dense N×K matrices for top-k prompt logprobs.
|
||||
Both token_ids and logprobs are row-major flattened: position i, rank j
|
||||
is at index i*k + j. Empty cells use token_id=0 and logprob=-99999.0.
|
||||
"""
|
||||
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
TOKEN_IDS_FIELD_NUMBER: builtins.int
|
||||
LOGPROBS_FIELD_NUMBER: builtins.int
|
||||
K_FIELD_NUMBER: builtins.int
|
||||
PROMPT_LENGTH_FIELD_NUMBER: builtins.int
|
||||
token_ids: builtins.bytes
|
||||
"""Row-major N×K matrix of int32 token IDs"""
|
||||
logprobs: builtins.bytes
|
||||
"""Row-major N×K matrix of float32 logprobs"""
|
||||
k: builtins.int
|
||||
"""Number of top-k entries per position"""
|
||||
prompt_length: builtins.int
|
||||
"""Number of prompt positions"""
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
token_ids: builtins.bytes = ...,
|
||||
logprobs: builtins.bytes = ...,
|
||||
k: builtins.int = ...,
|
||||
prompt_length: builtins.int = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing.Literal["k", b"k", "logprobs", b"logprobs", "prompt_length", b"prompt_length", "token_ids", b"token_ids"]) -> None: ...
|
||||
|
||||
Global___TopkPromptLogprobs: typing_extensions.TypeAlias = TopkPromptLogprobs
|
||||
Loading…
Add table
Add a link
Reference in a new issue