SamplingClient for Tinker API.
SamplingClient Objects
class SamplingClient(TelemetryProvider, QueueStateObserver)
Client for text generation and inference from trained or base models.
The SamplingClient lets you generate text tokens from either a base model or from weights
you've saved using a TrainingClient. You typically get one by calling
service_client.create_sampling_client() or training_client.save_weights_and_get_sampling_client().
Key methods:
- sample() - generate text completions with customizable parameters
- compute_logprobs() - get log probabilities for prompt tokens
Args:
- holder: Internal client managing HTTP connections and async operations
- model_path: Path to saved model weights (starts with 'tinker://')
- base_model: Name of base model to use for inference
- retry_config: Configuration for retrying failed requests
Example:
sampling_client = service_client.create_sampling_client(base_model="Qwen/Qwen3-8B")
prompt = types.ModelInput.from_ints(tokenizer.encode("The weather today is"))
params = types.SamplingParams(max_tokens=20, temperature=0.7)
future = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1)
result = future.result()
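You can also create a sampling client from weights you've trained, via the save_weights_and_get_sampling_client() path mentioned above. A minimal sketch, assuming a training_client created earlier from the same service_client and that the method accepts a checkpoint name (the name argument here is illustrative, not confirmed by this page):
# training_client was created earlier from the same service_client
sampling_client = training_client.save_weights_and_get_sampling_client(name="my-checkpoint")
prompt = types.ModelInput.from_ints(tokenizer.encode("The weather today is"))
params = types.SamplingParams(max_tokens=20, temperature=0.7)
result = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1).result()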
sample
def sample(
    prompt: types.ModelInput,
    num_samples: int,
    sampling_params: types.SamplingParams,
    include_prompt_logprobs: bool = False,
    topk_prompt_logprobs: int = 0
) -> ConcurrentFuture[types.SampleResponse]
Generate text completions from the model.
Args:
- prompt: The input tokens as ModelInput
- num_samples: Number of independent samples to generate
- sampling_params: Parameters controlling generation (temperature, max_tokens, etc.)
- include_prompt_logprobs: Whether to include log probabilities for prompt tokens
- topk_prompt_logprobs: Number of top token log probabilities to return per position
Returns:
- A Future containing the SampleResponse with generated text
Example:
prompt = types.ModelInput.from_ints(tokenizer.encode("The weather today is"))
params = types.SamplingParams(max_tokens=20, temperature=0.7)
future = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1)
result = future.result()
for sample in result.samples:
    print(tokenizer.decode(sample.tokens))
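The prompt-logprob flags described under Args compose with the same call. A short sketch of the request side only, since the shape of the returned prompt logprobs on SampleResponse isn't shown in this section:
future = sampling_client.sample(
    prompt=prompt,
    sampling_params=params,
    num_samples=1,
    include_prompt_logprobs=True,  # also return logprobs for the prompt tokens
    topk_prompt_logprobs=5,        # top-5 alternative tokens per prompt position
)
result = future.result()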
sample_async
async def sample_async(
    prompt: types.ModelInput,
    num_samples: int,
    sampling_params: types.SamplingParams,
    include_prompt_logprobs: bool = False,
    topk_prompt_logprobs: int = 0
) -> types.SampleResponse
Async version of sample.
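A minimal usage sketch for the async variant (the asyncio scaffolding around the call is illustrative, not part of the API):
import asyncio

async def generate():
    prompt = types.ModelInput.from_ints(tokenizer.encode("The weather today is"))
    params = types.SamplingParams(max_tokens=20, temperature=0.7)
    result = await sampling_client.sample_async(
        prompt=prompt, sampling_params=params, num_samples=1)
    for sample in result.samples:
        print(tokenizer.decode(sample.tokens))

asyncio.run(generate())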
compute_logprobs
def compute_logprobs(
    prompt: types.ModelInput
) -> ConcurrentFuture[list[float | None]]
Compute log probabilities for prompt tokens.
Args:
- prompt: The input tokens as ModelInput
Returns:
- A Future containing a list of log probabilities for each token in the prompt. None values indicate tokens where log probabilities couldn't be computed.
Example:
prompt = types.ModelInput.from_ints(tokenizer.encode("Hello world"))
future = sampling_client.compute_logprobs(prompt)
logprobs = future.result()
for i, logprob in enumerate(logprobs):
    if logprob is not None:
        print(f"Token {i}: logprob = {logprob:.4f}")
compute_logprobs_async
async def compute_logprobs_async(
    prompt: types.ModelInput
) -> list[float | None]
Async version of compute_logprobs.
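A minimal usage sketch mirroring the synchronous example above (the asyncio scaffolding is illustrative):
import asyncio

async def score_prompt():
    prompt = types.ModelInput.from_ints(tokenizer.encode("Hello world"))
    logprobs = await sampling_client.compute_logprobs_async(prompt)
    for i, logprob in enumerate(logprobs):
        if logprob is not None:
            print(f"Token {i}: logprob = {logprob:.4f}")

asyncio.run(score_prompt())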