diff --git a/.sync_state b/.sync_state index 7f5b4ca..bafc67d 100644 --- a/.sync_state +++ b/.sync_state @@ -1,4 +1,4 @@ { - "last_synced_sha": "b31ac99f31c26082a4076c3c79719034b2e0cab2", - "last_sync_time": "2026-03-09T00:49:41.558594" + "last_synced_sha": "2138c60c730c51b7bc19146a38320d631c38e0cc", + "last_sync_time": "2026-03-19T00:10:49.917056" } \ No newline at end of file diff --git a/docs/api/exceptions.md b/docs/api/exceptions.md index 97e9fdf..64a8111 100644 --- a/docs/api/exceptions.md +++ b/docs/api/exceptions.md @@ -121,6 +121,38 @@ class InternalServerError(APIStatusError) HTTP 500+: An error occurred on the server. +## `SidecarError` Objects + +```python +class SidecarError(TinkerError) +``` + +Base exception for subprocess sidecar errors. + +## `SidecarStartupError` Objects + +```python +class SidecarStartupError(SidecarError) +``` + +Raised when the sidecar subprocess fails to start or times out. + +## `SidecarDiedError` Objects + +```python +class SidecarDiedError(SidecarError) +``` + +Raised when the sidecar subprocess exits unexpectedly while requests are pending. + +## `SidecarIPCError` Objects + +```python +class SidecarIPCError(SidecarError) +``` + +Raised when communication with the sidecar subprocess fails. 
+ ## `RequestFailedError` Objects ```python diff --git a/docs/api/restclient.md b/docs/api/restclient.md index 0a1b89f..15e14b9 100644 --- a/docs/api/restclient.md +++ b/docs/api/restclient.md @@ -20,6 +20,7 @@ Key methods: - get_checkpoint_archive_url() - get signed URL to download checkpoint archive - publish_checkpoint_from_tinker_path() - publish a checkpoint to make it public - unpublish_checkpoint_from_tinker_path() - unpublish a checkpoint to make it private +- set_checkpoint_ttl_from_tinker_path() - set or remove TTL on a checkpoint Args: - `holder`: Internal client managing HTTP connections and async operations @@ -39,7 +40,9 @@ for checkpoint in checkpoints.checkpoints: ```python def get_training_run( - training_run_id: types.ModelID) -> ConcurrentFuture[types.TrainingRun] + training_run_id: types.ModelID, + access_scope: Literal["owned", "accessible"] = "owned" +) -> ConcurrentFuture[types.TrainingRun] ``` Get training run info. @@ -61,7 +64,9 @@ print(f"Training Run ID: {response.training_run_id}, Base: {response.base_model} ```python async def get_training_run_async( - training_run_id: types.ModelID) -> types.TrainingRun + training_run_id: types.ModelID, + access_scope: Literal["owned", + "accessible"] = "owned") -> types.TrainingRun ``` Async version of get_training_run. @@ -70,7 +75,9 @@ Async version of get_training_run. ```python def get_training_run_by_tinker_path( - tinker_path: str) -> ConcurrentFuture[types.TrainingRun] + tinker_path: str, + access_scope: Literal["owned", "accessible"] = "owned" +) -> ConcurrentFuture[types.TrainingRun] ``` Get training run info. @@ -92,7 +99,9 @@ print(f"Training Run ID: {response.training_run_id}, Base: {response.base_model} ```python async def get_training_run_by_tinker_path_async( - tinker_path: str) -> types.TrainingRun + tinker_path: str, + access_scope: Literal["owned", + "accessible"] = "owned") -> types.TrainingRun ``` Async version of get_training_run_by_tinker_path. 
@@ -123,8 +132,10 @@ print(f"Base Model: {response.base_model}, LoRA Rank: {response.lora_rank}") ```python def list_training_runs( - limit: int = 20, - offset: int = 0) -> ConcurrentFuture[types.TrainingRunsResponse] + limit: int = 20, + offset: int = 0, + access_scope: Literal["owned", "accessible"] = "owned" +) -> ConcurrentFuture[types.TrainingRunsResponse] ``` List training runs with pagination support. @@ -149,9 +160,11 @@ next_page = rest_client.list_training_runs(limit=50, offset=50) #### `list_training_runs_async` ```python -async def list_training_runs_async(limit: int = 20, - offset: int = 0 - ) -> types.TrainingRunsResponse +async def list_training_runs_async( + limit: int = 20, + offset: int = 0, + access_scope: Literal["owned", "accessible"] = "owned" +) -> types.TrainingRunsResponse ``` Async version of list_training_runs. @@ -367,6 +380,46 @@ async def unpublish_checkpoint_from_tinker_path_async( Async version of unpublish_checkpoint_from_tinker_path. +#### `set_checkpoint_ttl_from_tinker_path` + +```python +def set_checkpoint_ttl_from_tinker_path( + tinker_path: str, ttl_seconds: int | None) -> ConcurrentFuture[None] +``` + +Set or remove the TTL on a checkpoint referenced by a tinker path. + +If ttl_seconds is provided, the checkpoint will expire after that many seconds from now. +If ttl_seconds is None, any existing expiration will be removed. 
+ +Args: +- `tinker_path`: The tinker path to the checkpoint (e.g., "tinker://run-id/weights/0001") +- `ttl_seconds`: Number of seconds until expiration, or None to remove TTL + +Returns: +- A `Future` that completes when the TTL is set + +Raises: + HTTPException: 400 if checkpoint identifier is invalid or ttl_seconds <= 0 + HTTPException: 404 if checkpoint not found or user doesn't own the training run + HTTPException: 500 if there's an error setting the TTL + +Example: +```python +future = rest_client.set_checkpoint_ttl_from_tinker_path("tinker://run-id/weights/0001", 86400) +future.result() # Wait for completion +print("Checkpoint TTL set successfully") +``` + +#### `set_checkpoint_ttl_from_tinker_path_async` + +```python +async def set_checkpoint_ttl_from_tinker_path_async( + tinker_path: str, ttl_seconds: int | None) -> None +``` + +Async version of set_checkpoint_ttl_from_tinker_path. + #### `list_user_checkpoints` ```python @@ -414,7 +467,10 @@ Async version of list_user_checkpoints. #### `get_session` ```python -def get_session(session_id: str) -> ConcurrentFuture[types.GetSessionResponse] +def get_session( + session_id: str, + access_scope: Literal["owned", "accessible"] = "owned" +) -> ConcurrentFuture[types.GetSessionResponse] ``` Get session information including all training runs and samplers. @@ -436,7 +492,10 @@ print(f"Samplers: {len(response.sampler_ids)}") #### `get_session_async` ```python -async def get_session_async(session_id: str) -> types.GetSessionResponse +async def get_session_async( + session_id: str, + access_scope: Literal["owned", "accessible"] = "owned" +) -> types.GetSessionResponse ``` Async version of get_session. @@ -445,8 +504,10 @@ Async version of get_session. 
```python def list_sessions( - limit: int = 20, - offset: int = 0) -> ConcurrentFuture[types.ListSessionsResponse] + limit: int = 20, + offset: int = 0, + access_scope: Literal["owned", "accessible"] = "owned" +) -> ConcurrentFuture[types.ListSessionsResponse] ``` List sessions with pagination support. @@ -470,8 +531,11 @@ next_page = rest_client.list_sessions(limit=50, offset=50) #### `list_sessions_async` ```python -async def list_sessions_async(limit: int = 20, - offset: int = 0) -> types.ListSessionsResponse +async def list_sessions_async( + limit: int = 20, + offset: int = 0, + access_scope: Literal["owned", "accessible"] = "owned" +) -> types.ListSessionsResponse ``` Async version of list_sessions. diff --git a/docs/api/samplingclient.md b/docs/api/samplingclient.md index 080fd1a..b2576c5 100644 --- a/docs/api/samplingclient.md +++ b/docs/api/samplingclient.md @@ -30,6 +30,20 @@ future = sampling_client.sample(prompt=prompt, sampling_params=params, num_sampl result = future.result() ``` +Multi-processing support: +This class is picklable, so it can be passed to a separate process/worker to sample. It is also +safe to pass the same instance of SamplingClient to multiple processes/workers. + +If you are using Tinker SDK with more than one process you should always create SamplingClient from +the main process and then pass it to the other processes/workers. +ServiceClient and TrainingClient should always be managed from the main process. + +Subprocess isolation: +Set ``TINKER_SUBPROCESS_SAMPLING=1`` to run sample() and compute_logprobs() in a dedicated +subprocess, preventing GIL contention from CPU-heavy user code (grading, environment +interactions) from stalling networking IO and heartbeats. This is transparent — the same +API works with or without it. + #### `sample` ```python @@ -121,3 +135,33 @@ Get the tokenizer for the current model. 
Returns: - `PreTrainedTokenizer` compatible with the model + +#### `get_base_model` + +```python +def get_base_model() -> str +``` + +Get the base model name for the current sampling session. + +#### `get_base_model_async` + +```python +async def get_base_model_async() -> str +``` + +Async version of get_base_model. + +#### `__reduce__` + +```python +def __reduce__() -> tuple[Any, tuple[_SamplingClientPickleState]] +``` + +Enable pickling of SamplingClient for subprocess use. + +Serializes into a ``_SamplingClientPickleState`` dataclass. The +``_sampling_client_sidecar_handle`` handle is deliberately omitted — only a +bool flag is stored. The unpickled copy creates its own handle via +the per-process sidecar singleton. Do not add ``__getstate__`` +without preserving this behavior. diff --git a/docs/api/trainingclient.md b/docs/api/trainingclient.md index 564f113..b2c00d0 100644 --- a/docs/api/trainingclient.md +++ b/docs/api/trainingclient.md @@ -3,7 +3,7 @@ TrainingClient for Tinker API. ## `TrainingClient` Objects ```python -class TrainingClient(TelemetryProvider, QueueStateObserver) +class TrainingClient(TelemetryProvider) ``` Client for training ML models with forward/backward passes and optimization. @@ -127,8 +127,11 @@ Async version of forward_backward. ```python def forward_backward_custom( - data: List[types.Datum], - loss_fn: CustomLossFnV1) -> APIFuture[types.ForwardBackwardOutput] + data: List[types.Datum], + loss_fn: CustomLossFnV1, + *, + loss_type_input: Literal["logprobs"] = "logprobs" +) -> APIFuture[types.ForwardBackwardOutput] ``` Compute forward/backward with a custom loss function. @@ -139,6 +142,7 @@ The custom function receives logprobs and computes loss and gradients. Args: - `data`: List of training data samples - `loss_fn`: Custom loss function that takes (data, logprobs) and returns (loss, metrics) +- `loss_type_input`: Input space for `loss_fn`. Currently the only supported value is `"logprobs"`. 
Returns: - `APIFuture` containing the forward/backward outputs with custom loss @@ -161,8 +165,11 @@ print(f"Metrics: {result.metrics}") ```python async def forward_backward_custom_async( - data: List[types.Datum], - loss_fn: CustomLossFnV1) -> APIFuture[types.ForwardBackwardOutput] + data: List[types.Datum], + loss_fn: CustomLossFnV1, + *, + loss_type_input: Literal["logprobs"] = "logprobs" +) -> APIFuture[types.ForwardBackwardOutput] ``` Async version of forward_backward_custom. diff --git a/docs/api/types.md b/docs/api/types.md index 5e2aa5e..b3907d1 100644 --- a/docs/api/types.md +++ b/docs/api/types.md @@ -1,125 +1,12 @@ -## `AdamParams` Objects +## `LoadWeightsResponse` Objects ```python -class AdamParams(StrictBase) +class LoadWeightsResponse(BaseModel) ``` -#### `learning_rate` +#### `path` -Learning rate for the optimizer - -#### `beta1` - -Coefficient used for computing running averages of gradient - -#### `beta2` - -Coefficient used for computing running averages of gradient square - -#### `eps` - -Term added to the denominator to improve numerical stability - -#### `weight_decay` - -Weight decay for the optimizer. Uses decoupled weight decay. - -#### `grad_clip_norm` - -Maximum global gradient norm. If the global gradient norm is greater than this value, it will be clipped to this value. 0.0 means no clipping. - -## `SupportedModel` Objects - -```python -class SupportedModel(BaseModel) -``` - -Information about a model supported by the server. - -#### `model_name` - -The name of the supported model. - -## `GetServerCapabilitiesResponse` Objects - -```python -class GetServerCapabilitiesResponse(BaseModel) -``` - -Response containing the server's supported models and capabilities. - -#### `supported_models` - -List of models available on the server. 
- -## `OptimStepResponse` Objects - -```python -class OptimStepResponse(BaseModel) -``` - -#### `metrics` - -Optimization step metrics as key-value pairs - -## `ModelInput` Objects - -```python -class ModelInput(StrictBase) -``` - -#### `chunks` - -Sequence of input chunks (formerly TokenSequence) - -#### `from_ints` - -```python -def from_ints(cls, tokens: List[int]) -> "ModelInput" -``` - -Create a ModelInput from a list of ints (tokens). - -#### `to_ints` - -```python -def to_ints() -> List[int] -``` - -Convert the ModelInput to a list of ints (tokens) -Throws exception if there are any non-token chunks - -#### `length` - -```python -def length() -> int -``` - -Return the total context length used by this ModelInput. - -#### `empty` - -```python -def empty(cls) -> "ModelInput" -``` - -Create an empty ModelInput. - -#### `append` - -```python -def append(chunk: ModelInputChunk) -> "ModelInput" -``` - -Add a new chunk, return a new ModelInput. - -#### `append_int` - -```python -def append_int(token: int) -> "ModelInput" -``` - -Add a new token, return a new ModelInput. +A tinker URI for model weights at a specific step ## `WeightsInfoResponse` Objects @@ -129,6 +16,74 @@ class WeightsInfoResponse(BaseModel) Minimal information for loading public checkpoints. +## `LoadWeightsRequest` Objects + +```python +class LoadWeightsRequest(StrictBase) +``` + +#### `path` + +A tinker URI for model weights at a specific step + +#### `optimizer` + +Whether to load optimizer state along with model weights + +## `CreateModelRequest` Objects + +```python +class CreateModelRequest(StrictBase) +``` + +#### `base_model` + +The name of the base model to fine-tune (e.g., 'Qwen/Qwen3-8B'). + +#### `user_metadata` + +Optional metadata about this model/training run, set by the end-user. 
+ +#### `lora_config` + +LoRA configuration + +## `UnhandledExceptionEvent` Objects + +```python +class UnhandledExceptionEvent(BaseModel) +``` + +#### `event` + +Telemetry event type + +#### `severity` + +Log severity level + +#### `traceback` + +Optional Python traceback string + +## `Datum` Objects + +```python +class Datum(StrictBase) +``` + +#### `loss_fn_inputs` + +Dictionary mapping field names to tensor data + +#### `convert_tensors` + +```python +def convert_tensors(cls, data: Any) -> Any +``` + +Convert torch.Tensor and numpy arrays to TensorData in loss_fn_inputs during construction. + ## `Checkpoint` Objects ```python @@ -193,96 +148,135 @@ def from_tinker_path(cls, tinker_path: str) -> "ParsedCheckpointTinkerPath" Parse a tinker path to an instance of ParsedCheckpointTinkerPath -## `CheckpointArchiveUrlResponse` Objects +## `SamplingParams` Objects ```python -class CheckpointArchiveUrlResponse(BaseModel) +class SamplingParams(BaseModel) ``` -#### `url` +#### `max_tokens` -Signed URL to download the checkpoint archive +Maximum number of tokens to generate -#### `expires` +#### `seed` -Unix timestamp when the signed URL expires, if available +Random seed for reproducible generation -## `SampledSequence` Objects +#### `stop` + +Stop sequences for generation + +#### `temperature` + +Sampling temperature + +#### `top_k` + +Top-k sampling parameter (-1 for no limit) + +#### `top_p` + +Nucleus sampling probability + +## `SaveWeightsForSamplerRequest` Objects ```python -class SampledSequence(BaseModel) -``` - -#### `stop_reason` - -Reason why sampling stopped - -#### `tokens` - -List of generated token IDs - -#### `logprobs` - -Log probabilities for each token (optional) - -## `TryAgainResponse` Objects - -```python -class TryAgainResponse(BaseModel) -``` - -#### `request_id` - -Request ID that is still pending - -## `LoadWeightsRequest` Objects - -```python -class LoadWeightsRequest(StrictBase) +class SaveWeightsForSamplerRequest(StrictBase) ``` #### `path` 
-A tinker URI for model weights at a specific step +A file/directory name for the weights -#### `optimizer` +#### `ttl_seconds` -Whether to load optimizer state along with model weights - -## `TelemetrySendRequest` Objects +TTL in seconds for this checkpoint (None = never expires) +## `ModelInput` Objects ```python -class TelemetrySendRequest(StrictBase) +class ModelInput(StrictBase) ``` -#### `platform` +#### `chunks` -Host platform name +Sequence of input chunks (formerly TokenSequence) -#### `sdk_version` - -SDK version string - -## `ImageAssetPointerChunk` Objects +#### `from_ints` ```python -class ImageAssetPointerChunk(StrictBase) +def from_ints(cls, tokens: List[int]) -> "ModelInput" ``` -#### `format` +Create a ModelInput from a list of ints (tokens). -Image format +#### `to_ints` -#### `location` +```python +def to_ints() -> List[int] +``` -Path or URL to the image asset +Convert the ModelInput to a list of ints (tokens) +Throws exception if there are any non-token chunks -#### `expected_tokens` +#### `length` -Expected number of tokens this image represents. -This is only advisory: the tinker backend will compute the number of tokens -from the image, and we can fail requests quickly if the tokens does not -match expected_tokens. +```python +def length() -> int +``` + +Return the total context length used by this ModelInput. + +#### `empty` + +```python +def empty(cls) -> "ModelInput" +``` + +Create an empty ModelInput. + +#### `append` + +```python +def append(chunk: ModelInputChunk) -> "ModelInput" +``` + +Add a new chunk, return a new ModelInput. + +#### `append_int` + +```python +def append_int(token: int) -> "ModelInput" +``` + +Add a new token, return a new ModelInput. 
+ +## `SessionEndEvent` Objects + +```python +class SessionEndEvent(BaseModel) +``` + +#### `duration` + +ISO 8601 duration string + +#### `event` + +Telemetry event type + +#### `severity` + +Log severity level + +## `CreateSamplingSessionResponse` Objects + +```python +class CreateSamplingSessionResponse(BaseModel) +``` + +#### `sampling_session_id` + +The generated sampling session ID ## `CheckpointsListResponse` Objects @@ -298,159 +292,6 @@ List of available model checkpoints for the model Pagination cursor information (None for unpaginated responses) -## `GenericEvent` Objects - -```python -class GenericEvent(BaseModel) -``` - -#### `event` - -Telemetry event type - -#### `event_name` - -Low-cardinality event name - -#### `severity` - -Log severity level - -#### `event_data` - -Arbitrary structured JSON payload - -## `EncodedTextChunk` Objects - -```python -class EncodedTextChunk(StrictBase) -``` - -#### `tokens` - -Array of token IDs - -## `ForwardBackwardInput` Objects - -```python -class ForwardBackwardInput(StrictBase) -``` - -#### `data` - -Array of input data for the forward/backward pass - -#### `loss_fn` - -Fully qualified function path for the loss function - -#### `loss_fn_config` - -Optional configuration parameters for the loss function (e.g., PPO clip thresholds, DPO beta) - -## `SessionStartEvent` Objects - -```python -class SessionStartEvent(BaseModel) -``` - -#### `event` - -Telemetry event type - -#### `severity` - -Log severity level - -## `TrainingRunsResponse` Objects - -```python -class TrainingRunsResponse(BaseModel) -``` - -#### `training_runs` - -List of training runs - -#### `cursor` - -Pagination cursor information - -## `SaveWeightsResponse` Objects - -```python -class SaveWeightsResponse(BaseModel) -``` - -#### `path` - -A tinker URI for model weights at a specific step - -## `SampleRequest` Objects - -```python -class SampleRequest(StrictBase) -``` - -#### `num_samples` - -Number of samples to generate - -#### `base_model` - 
-Optional base model name to sample from. - -Is inferred from model_path, if provided. If sampling against a base model, this -is required. - -#### `model_path` - -Optional tinker:// path to your model weights or LoRA weights. - -If not provided, samples against the base model. - -#### `sampling_session_id` - -Optional sampling session ID to use instead of model_path/base_model. - -If provided along with seq_id, the model configuration will be loaded from the -sampling session. This is useful for multi-turn conversations. - -#### `seq_id` - -Sequence ID within the sampling session. - -Required when sampling_session_id is provided. Used to generate deterministic -request IDs for the sampling request. - -#### `prompt_logprobs` - -If set to `true`, computes and returns logprobs on the prompt tokens. - -Defaults to false. - -#### `topk_prompt_logprobs` - -If set to a positive integer, returns the top-k logprobs for each prompt token. - -## `ForwardBackwardOutput` Objects - -```python -class ForwardBackwardOutput(BaseModel) -``` - -#### `loss_fn_output_type` - -The class name of the loss function output records (e.g., 'TorchLossReturn', 'ArrayRecord'). - -#### `loss_fn_outputs` - -Dictionary mapping field names to tensor data - -#### `metrics` - -Training metrics as key-value pairs - ## `SampleResponse` Objects ```python @@ -470,15 +311,68 @@ the top-k logprobs are computed for every token in the prompt. The `topk_prompt_logprobs` response contains, for every token in the prompt, a list of up to k (token_id, logprob) tuples. 
-## `CreateSamplingSessionResponse` Objects +## `FutureRetrieveRequest` Objects ```python -class CreateSamplingSessionResponse(BaseModel) +class FutureRetrieveRequest(StrictBase) ``` -#### `sampling_session_id` +#### `request_id` -The generated sampling session ID +The ID of the request to retrieve + +#### `allow_metadata_only` + +When True, the server may return only response metadata (status and size) +instead of the full payload if the response exceeds the server's inline size limit. + +## `ForwardBackwardOutput` Objects + +```python +class ForwardBackwardOutput(BaseModel) +``` + +#### `loss_fn_output_type` + +The class name of the loss function output records (e.g., 'TorchLossReturn', 'ArrayRecord'). + +#### `loss_fn_outputs` + +Dictionary mapping field names to tensor data + +#### `metrics` + +Training metrics as key-value pairs. + +The following metrics are recorded only during MoE (Mixture of Experts) training. +Note: Don't fixate on the exact values of these metrics at the start of training. +Different models on different data will have different initial values. How these +metrics evolve over training is what matters. + +In the definitions below, *perfect balance* means ``total_tokens / num_experts`` +— the number of tokens each expert would receive if routing were perfectly uniform. + +- ``e_frac_with_tokens:mean``: Fraction of experts that received at least one token, + averaged across layers. A value of 1.0 means every expert got work; 0.5 means half + were idle. Decreasing over time is concerning (routing collapse). + +- ``e_frac_oversubscribed:mean``: Fraction of experts receiving more tokens than + perfect balance, averaged across layers. Increasing over time is concerning. + +- ``e_max_violation:mean``: How much the most overloaded expert exceeds perfect + balance, as a fraction of perfect balance, averaged across layers. Computed as + ``(max_tokens - perfect_balance) / perfect_balance``. 
A value of 2.0 means the + busiest expert got 3x the fair share. Increasing over time is concerning. + +- ``e_max_violation:max``: Same as ``e_max_violation:mean`` but takes the max + across layers instead of the mean. Shows the worst-case load imbalance in any + single layer. + +- ``e_min_violation:mean``: How much the least loaded expert is below perfect + balance, as a fraction of perfect balance, averaged across layers. Computed as + ``(min_tokens - perfect_balance) / perfect_balance``. A value of -0.5 means the + least-used expert got half the fair share; -1.0 means it got nothing. Typically + negative. Decreasing over time (more negative) is concerning. ## `ModelData` Objects @@ -532,59 +426,151 @@ The rank of the LoRA adaptation, if applicable. The name of the model. -## `Cursor` Objects +## `SaveWeightsResponse` Objects ```python -class Cursor(BaseModel) +class SaveWeightsResponse(BaseModel) ``` -#### `offset` +#### `path` -The offset used for pagination +A tinker URI for model weights at a specific step -#### `limit` - -The maximum number of items requested - -#### `total_count` - -The total number of items available - -## `CreateModelRequest` Objects +## `LoraConfig` Objects ```python -class CreateModelRequest(StrictBase) +class LoraConfig(StrictBase) ``` +#### `rank` + +LoRA rank (dimension of low-rank matrices) + +#### `seed` + +Seed used for initialization of LoRA weights. + +Useful if you need deterministic or reproducible initialization of weights. 
+ +#### `train_unembed` + +Whether to add lora to the unembedding layer + +#### `train_mlp` + +Whether to add loras to the MLP layers (including MoE layers) + +#### `train_attn` + +Whether to add loras to the attention layers + +## `SaveWeightsForSamplerResponseInternal` Objects + +```python +class SaveWeightsForSamplerResponseInternal(BaseModel) +``` + +#### `path` + +A tinker URI for model weights for sampling at a specific step + +#### `sampling_session_id` + +The generated sampling session ID + +## `SaveWeightsForSamplerResponse` Objects + +```python +class SaveWeightsForSamplerResponse(BaseModel) +``` + +#### `path` + +A tinker URI for model weights for sampling at a specific step + +## `CreateSamplingSessionRequest` Objects + +```python +class CreateSamplingSessionRequest(StrictBase) +``` + +#### `session_id` + +The session ID to create the sampling session within + +#### `sampling_session_seq_id` + +Sequence ID for the sampling session within the session + #### `base_model` -The name of the base model to fine-tune (e.g., 'Qwen/Qwen3-8B'). +Optional base model name to sample from. -#### `user_metadata` +Is inferred from model_path, if provided. If sampling against a base model, this +is required. -Optional metadata about this model/training run, set by the end-user. +#### `model_path` -#### `lora_config` +Optional tinker:// path to your model weights or LoRA weights. -LoRA configuration +If not provided, samples against the base model. -## `Datum` Objects +## `OptimStepResponse` Objects ```python -class Datum(StrictBase) +class OptimStepResponse(BaseModel) ``` -#### `loss_fn_inputs` +#### `metrics` -Dictionary mapping field names to tensor data +Optimization step metrics as key-value pairs -#### `convert_tensors` +## `SampleRequest` Objects ```python -def convert_tensors(cls, data: Any) -> Any +class SampleRequest(StrictBase) ``` -Convert torch.Tensor and numpy arrays to TensorData in loss_fn_inputs during construction. 
+#### `num_samples` + +Number of samples to generate + +#### `base_model` + +Optional base model name to sample from. + +Is inferred from model_path, if provided. If sampling against a base model, this +is required. + +#### `model_path` + +Optional tinker:// path to your model weights or LoRA weights. + +If not provided, samples against the base model. + +#### `sampling_session_id` + +Optional sampling session ID to use instead of model_path/base_model. + +If provided along with seq_id, the model configuration will be loaded from the +sampling session. This is useful for multi-turn conversations. + +#### `seq_id` + +Sequence ID within the sampling session. + +Required when sampling_session_id is provided. Used to generate deterministic +request IDs for the sampling request. + +#### `prompt_logprobs` + +If set to `true`, computes and returns logprobs on the prompt tokens. + +Defaults to false. + +#### `topk_prompt_logprobs` + +If set to a positive integer, returns the top-k logprobs for each prompt token. ## `TrainingRun` Objects @@ -632,15 +618,63 @@ The most recent sampler checkpoint, if available Optional metadata about this training run, set by the end-user -## `SessionEndEvent` Objects +## `TelemetrySendRequest` Objects ```python -class SessionEndEvent(BaseModel) +class TelemetrySendRequest(StrictBase) ``` -#### `duration` +#### `platform` -ISO 8601 duration string +Host platform name + +#### `sdk_version` + +SDK version string + +## `CheckpointArchiveUrlResponse` Objects + +```python +class CheckpointArchiveUrlResponse(BaseModel) +``` + +#### `url` + +Signed URL to download the checkpoint archive + +#### `expires` + +Unix timestamp when the signed URL expires, if available + +## `SupportedModel` Objects + +```python +class SupportedModel(BaseModel) +``` + +Information about a model supported by the server. + +#### `model_name` + +The name of the supported model. 
+ +## `GetServerCapabilitiesResponse` Objects + +```python +class GetServerCapabilitiesResponse(BaseModel) +``` + +Response containing the server's supported models and capabilities. + +#### `supported_models` + +List of models available on the server. + +## `SessionStartEvent` Objects + +```python +class SessionStartEvent(BaseModel) +``` #### `event` @@ -650,6 +684,91 @@ Telemetry event type Log severity level +## `GenericEvent` Objects + +```python +class GenericEvent(BaseModel) +``` + +#### `event` + +Telemetry event type + +#### `event_name` + +Low-cardinality event name + +#### `severity` + +Log severity level + +#### `event_data` + +Arbitrary structured JSON payload + +## `TryAgainResponse` Objects + +```python +class TryAgainResponse(BaseModel) +``` + +#### `request_id` + +Request ID that is still pending + +## `TrainingRunsResponse` Objects + +```python +class TrainingRunsResponse(BaseModel) +``` + +#### `training_runs` + +List of training runs + +#### `cursor` + +Pagination cursor information + +## `ForwardBackwardInput` Objects + +```python +class ForwardBackwardInput(StrictBase) +``` + +#### `data` + +Array of input data for the forward/backward pass + +#### `loss_fn` + +Fully qualified function path for the loss function + +#### `loss_fn_config` + +Optional configuration parameters for the loss function (e.g., PPO clip thresholds, DPO beta) + +## `ImageAssetPointerChunk` Objects + +```python +class ImageAssetPointerChunk(StrictBase) +``` + +#### `format` + +Image format + +#### `location` + +Path or URL to the image asset + +#### `expected_tokens` + +Expected number of tokens this image represents. +This is only advisory: the tinker backend will compute the number of tokens +from the image, and we can fail requests quickly if the tokens does not +match expected_tokens. 
+ ## `TelemetryBatch` Objects ```python @@ -664,23 +783,79 @@ Host platform name SDK version string -## `UnhandledExceptionEvent` Objects +## `TensorData` Objects ```python -class UnhandledExceptionEvent(BaseModel) +class TensorData(StrictBase) ``` -#### `event` +#### `data` -Telemetry event type +Flattened tensor data as array of numbers. -#### `severity` +#### `shape` -Log severity level +Optional. -#### `traceback` +The shape of the tensor (see PyTorch tensor.shape). The shape of a +one-dimensional list of length N is `(N,)`. Can usually be inferred if not +provided, and is generally inferred as a 1D tensor. -Optional Python traceback string +#### `to_numpy` + +```python +def to_numpy() -> npt.NDArray[Any] +``` + +Convert TensorData to numpy array. + +#### `to_torch` + +```python +def to_torch() -> "torch.Tensor" +``` + +Convert TensorData to torch tensor. + +## `EncodedTextChunk` Objects + +```python +class EncodedTextChunk(StrictBase) +``` + +#### `tokens` + +Array of token IDs + +## `AdamParams` Objects + +```python +class AdamParams(StrictBase) +``` + +#### `learning_rate` + +Learning rate for the optimizer + +#### `beta1` + +Coefficient used for computing running averages of gradient + +#### `beta2` + +Coefficient used for computing running averages of gradient square + +#### `eps` + +Term added to the denominator to improve numerical stability + +#### `weight_decay` + +Weight decay for the optimizer. Uses decoupled weight decay. + +#### `grad_clip_norm` + +Maximum global gradient norm. If the global gradient norm is greater than this value, it will be clipped to this value. 0.0 means no clipping. ## `ImageChunk` Objects @@ -719,6 +894,42 @@ def serialize_data(value: bytes) -> str Serialize bytes to base64 string for JSON. 
+## `SampledSequence` Objects + +```python +class SampledSequence(BaseModel) +``` + +#### `stop_reason` + +Reason why sampling stopped + +#### `tokens` + +List of generated token IDs + +#### `logprobs` + +Log probabilities for each token (optional) + +## `Cursor` Objects + +```python +class Cursor(BaseModel) +``` + +#### `offset` + +The offset used for pagination + +#### `limit` + +The maximum number of items requested + +#### `total_count` + +The total number of items available + ## `SaveWeightsRequest` Objects ```python @@ -732,180 +943,3 @@ A file/directory name for the weights #### `ttl_seconds` TTL in seconds for this checkpoint (None = never expires) - -## `LoraConfig` Objects - -```python -class LoraConfig(StrictBase) -``` - -#### `rank` - -LoRA rank (dimension of low-rank matrices) - -#### `seed` - -Seed used for initialization of LoRA weights. - -Useful if you need deterministic or reproducible initialization of weights. - -#### `train_unembed` - -Whether to add lora to the unembedding layer - -#### `train_mlp` - -Whether to add loras to the MLP layers (including MoE layers) - -#### `train_attn` - -Whether to add loras to the attention layers - -## `CreateSamplingSessionRequest` Objects - -```python -class CreateSamplingSessionRequest(StrictBase) -``` - -#### `session_id` - -The session ID to create the sampling session within - -#### `sampling_session_seq_id` - -Sequence ID for the sampling session within the session - -#### `base_model` - -Optional base model name to sample from. - -Is inferred from model_path, if provided. If sampling against a base model, this -is required. - -#### `model_path` - -Optional tinker:// path to your model weights or LoRA weights. - -If not provided, samples against the base model. 
- -## `FutureRetrieveRequest` Objects - -```python -class FutureRetrieveRequest(StrictBase) -``` - -#### `request_id` - -The ID of the request to retrieve - -## `TensorData` Objects - -```python -class TensorData(StrictBase) -``` - -#### `data` - -Flattened tensor data as array of numbers. - -#### `shape` - -Optional. - -The shape of the tensor (see PyTorch tensor.shape). The shape of a -one-dimensional list of length N is `(N,)`. Can usually be inferred if not -provided, and is generally inferred as a 1D tensor. - -#### `to_numpy` - -```python -def to_numpy() -> npt.NDArray[Any] -``` - -Convert TensorData to numpy array. - -#### `to_torch` - -```python -def to_torch() -> "torch.Tensor" -``` - -Convert TensorData to torch tensor. - -## `SaveWeightsForSamplerRequest` Objects - -```python -class SaveWeightsForSamplerRequest(StrictBase) -``` - -#### `path` - -A file/directory name for the weights - -#### `ttl_seconds` - -TTL in seconds for this checkpoint (None = never expires) - -## `SamplingParams` Objects - -```python -class SamplingParams(BaseModel) -``` - -#### `max_tokens` - -Maximum number of tokens to generate - -#### `seed` - -Random seed for reproducible generation - -#### `stop` - -Stop sequences for generation - -#### `temperature` - -Sampling temperature - -#### `top_k` - -Top-k sampling parameter (-1 for no limit) - -#### `top_p` - -Nucleus sampling probability - -## `SaveWeightsForSamplerResponseInternal` Objects - -```python -class SaveWeightsForSamplerResponseInternal(BaseModel) -``` - -#### `path` - -A tinker URI for model weights for sampling at a specific step - -#### `sampling_session_id` - -The generated sampling session ID - -## `SaveWeightsForSamplerResponse` Objects - -```python -class SaveWeightsForSamplerResponse(BaseModel) -``` - -#### `path` - -A tinker URI for model weights for sampling at a specific step - -## `LoadWeightsResponse` Objects - -```python -class LoadWeightsResponse(BaseModel) -``` - -#### `path` - -A tinker URI for model 
weights at a specific step diff --git a/pyproject.toml b/pyproject.toml index 51709ec..dae08ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "tinker" -version = "0.15.0" +version = "0.16.0" description = "The official Python SDK for the tinker API" readme = "README.md" license = "Apache-2.0" diff --git a/scripts/generate_docs.py b/scripts/generate_docs.py index 1ccd348..e5fd48d 100755 --- a/scripts/generate_docs.py +++ b/scripts/generate_docs.py @@ -10,6 +10,7 @@ import ast import json import os +import shutil import subprocess import sys from pathlib import Path @@ -95,8 +96,42 @@ class DocumentationGenerator: def run_pydoc_markdown(self, modules: List[str], output_file: Path) -> bool: """Run pydoc-markdown for specific modules.""" try: - # Build the command - cmd = ["pydoc-markdown", "pydoc-markdown.yml", "-I", "src"] + # Invoke pydoc-markdown via uvx/uv tool run so the required + # doc-generation dependencies are resolved consistently. + uvx_path = shutil.which("uvx") + if uvx_path: + cmd = [ + uvx_path, + "--from", + "pydoc-markdown>=4.8.0", + "--with", + "pyyaml>=6.0", + "--with", + "setuptools", + "pydoc-markdown", + "pydoc-markdown.yml", + "-I", + "src", + ] + else: + uv_path = shutil.which("uv") + if uv_path is None: + raise FileNotFoundError("Could not find `uvx` or `uv` on PATH") + cmd = [ + uv_path, + "tool", + "run", + "--from", + "pydoc-markdown>=4.8.0", + "--with", + "pyyaml>=6.0", + "--with", + "setuptools", + "pydoc-markdown", + "pydoc-markdown.yml", + "-I", + "src", + ] # Add modules for module in modules: diff --git a/src/tinker/__init__.py b/src/tinker/__init__.py index ea62d5d..512050a 100644 --- a/src/tinker/__init__.py +++ b/src/tinker/__init__.py @@ -16,6 +16,10 @@ from ._exceptions import ( PermissionDeniedError, RateLimitError, RequestFailedError, + SidecarDiedError, + SidecarError, + SidecarIPCError, + SidecarStartupError, TinkerError, UnprocessableEntityError, ) @@ -97,6 +101,10 @@ __all__ = [ 
"UnprocessableEntityError", "RateLimitError", "InternalServerError", + "SidecarError", + "SidecarStartupError", + "SidecarDiedError", + "SidecarIPCError", # Keep types module for advanced use "types", # Version info diff --git a/src/tinker/cli/commands/checkpoint.py b/src/tinker/cli/commands/checkpoint.py index af5b76d..a4c0c48 100644 --- a/src/tinker/cli/commands/checkpoint.py +++ b/src/tinker/cli/commands/checkpoint.py @@ -11,6 +11,8 @@ from typing import TYPE_CHECKING, Any, Dict, List import click if TYPE_CHECKING: + from datetime import datetime + from tinker.lib.public_interfaces.rest_client import RestClient from tinker.types import Checkpoint, TrainingRun @@ -838,53 +840,249 @@ def set_ttl(cli_context: CLIContext, checkpoint_path: str, ttl: int | None, remo client.set_checkpoint_ttl_from_tinker_path(checkpoint_path, ttl_seconds).result() +def _parse_date(value: str) -> "datetime": + """Parse an ISO 8601 date or datetime string to a timezone-aware datetime. + + Accepts: 2024-01-01, 2024-01-01T12:00:00, 2024-01-01T12:00:00Z, + 2024-01-01T12:00:00+00:00. Date-only values are treated as midnight UTC. 
+ """ + from datetime import UTC, datetime + + value = value.strip() + # Python < 3.11 doesn't handle trailing 'Z' in fromisoformat + if value.endswith("Z"): + value = value[:-1] + "+00:00" + try: + dt = datetime.fromisoformat(value) + except ValueError: + raise TinkerCliError( + f"Invalid date: {value}", + "Use ISO 8601 format: 2024-01-01, 2024-01-01T12:00:00Z", + ) + # Assume UTC if no timezone provided + if dt.tzinfo is None: + dt = dt.replace(tzinfo=UTC) + return dt + + +_CHECKPOINT_TYPE_MAP = { + "weights": "training", + "sampler_weights": "sampler", + "training": "training", + "sampler": "sampler", +} + + +def _filter_checkpoints( + checkpoints: "List[Checkpoint]", + checkpoint_type: str | None, + before: "datetime | None", + after: "datetime | None", +) -> "List[Checkpoint]": + """Filter checkpoints by type, before date, and/or after date.""" + filtered = checkpoints + if checkpoint_type: + mapped_type = _CHECKPOINT_TYPE_MAP.get(checkpoint_type) + if mapped_type is None: + raise TinkerCliError( + f"Invalid checkpoint type: {checkpoint_type}", + "Valid types: weights, sampler_weights", + ) + filtered = [c for c in filtered if c.checkpoint_type == mapped_type] + if before is not None: + filtered = [c for c in filtered if c.time < before] + if after is not None: + filtered = [c for c in filtered if c.time > after] + return filtered + + +def _confirm_deletion(paths: "List[str]", checkpoints: "List[Checkpoint] | None" = None) -> bool: + """Show deletion summary and prompt for confirmation. 
Returns True if confirmed.""" + count = len(paths) + if checkpoints is not None: + total_size = sum(c.size_bytes or 0 for c in checkpoints) + click.echo(f"Will delete {count} checkpoint(s):") + for ckpt in checkpoints: + size_str = format_size(ckpt.size_bytes) if ckpt.size_bytes is not None else "N/A" + time_str = format_timestamp(ckpt.time) + click.echo(f" - {ckpt.tinker_path} ({size_str}, created {time_str})") + click.echo(f"\nTotal size: {format_size(total_size)}") + else: + click.echo(f"Will delete {count} checkpoint(s):") + for path in paths: + click.echo(f" - {path}") + click.echo() + click.echo("WARNING: This action is permanent and cannot be undone.") + return click.confirm(f"Are you sure you want to delete {count} checkpoint(s)?") + + +_DELETE_CONCURRENCY = 32 + + +def _delete_one(client: "RestClient", path: str) -> "tuple[str, str] | None": + """Delete a single checkpoint. Returns (path, error) on failure, None on success.""" + try: + client.delete_checkpoint_from_tinker_path(path).result() + return None + except Exception as e: + return (path, str(e)) + + +def _delete_paths( + client: "RestClient", + paths: "List[str]", + format: str, +) -> None: + """Delete a list of tinker paths concurrently and print results.""" + from concurrent.futures import ThreadPoolExecutor, as_completed + + deleted_count = 0 + failed: "List[tuple[str, str]]" = [] + with ( + ThreadPoolExecutor(max_workers=_DELETE_CONCURRENCY) as pool, + click.progressbar( + length=len(paths), + label="Deleting checkpoints", + show_percent=True, + show_pos=True, + hidden=format != "table", + ) as bar, + ): + futures = {pool.submit(_delete_one, client, p): p for p in paths} + for future in as_completed(futures): + result = future.result() + if result is None: + deleted_count += 1 + else: + failed.append(result) + bar.update(1) + + if format == "json": + import json + + click.echo( + json.dumps( + { + "deleted_count": deleted_count, + "failed": [{"tinker_path": p, "error": e} for p, e in failed], + 
} + ) + ) + else: + click.echo(f"Deleted {deleted_count} checkpoint(s)") + if failed: + click.echo(f"Failed to delete {len(failed)} checkpoint(s):") + for path, error in failed: + click.echo(f" - {path}: {error}") + + @cli.command() -@click.argument("checkpoint_paths", nargs=-1, required=True) +@click.argument("checkpoint_paths", nargs=-1, required=False) +@click.option("--run-id", default=None, help="Delete all checkpoints for a training run") +@click.option( + "--type", "checkpoint_type", default=None, help="Filter by type: weights or sampler_weights" +) +@click.option( + "--before", + default=None, + help="Filter: created before date in UTC (ISO 8601, e.g. 2024-01-01, 2024-01-01T08:00:00Z)", +) +@click.option( + "--after", + default=None, + help="Filter: created after date in UTC (ISO 8601, e.g. 2024-01-01, 2024-01-01T08:00:00Z)", +) @click.option("--yes", "-y", is_flag=True, help="Skip confirmation prompt") @click.pass_obj @handle_api_errors -def delete(cli_context: CLIContext, checkpoint_paths: tuple[str, ...], yes: bool) -> None: +def delete( + cli_context: CLIContext, + checkpoint_paths: tuple[str, ...], + run_id: str | None, + checkpoint_type: str | None, + before: str | None, + after: str | None, + yes: bool, +) -> None: """Delete one or more checkpoints permanently. - CHECKPOINT_PATHS must be tinker paths (e.g., tinker://run-id/weights/0001). + Delete by explicit paths: + + tinker checkpoint delete tinker://run-id/weights/0001 tinker://run-id/weights/0002 + + Delete all checkpoints for a training run: + + tinker checkpoint delete --run-id + + Delete with filters: + + tinker checkpoint delete --run-id --type weights --before 2024-06-01 + + Delete checkpoints in a date range: + + tinker checkpoint delete --run-id --after 2024-01-01 --before 2024-02-01 + + Dates are interpreted as UTC. Use full ISO 8601 datetime for precision: + + tinker checkpoint delete --run-id --before 2024-06-01T08:00:00Z + Only the owner of the training run can delete checkpoints. 
WARNING: This action is permanent and cannot be undone. """ - # Validate all paths upfront - for path in checkpoint_paths: - if not path.startswith("tinker://"): - raise TinkerCliError( - f"Invalid checkpoint path: {path}", - "Checkpoint path must be in the format: tinker://run-id/weights/0001", - ) + if not checkpoint_paths and not run_id: + raise TinkerCliError( + "Must specify checkpoint paths or --run-id", + "Examples:\n" + " tinker checkpoint delete tinker://run-id/weights/0001\n" + " tinker checkpoint delete --run-id \n" + " tinker checkpoint delete --run-id --type weights --before 2024-06-01", + ) - # If not using --yes, show checkpoint list and prompt for confirmation - if not yes: - count = len(checkpoint_paths) - click.echo(f"Will delete {count} checkpoint(s):") + if checkpoint_paths and run_id: + raise TinkerCliError( + "Cannot specify both checkpoint paths and --run-id", + "Use either explicit paths or --run-id with optional filters", + ) + + has_filters = checkpoint_type or before or after + if has_filters and not run_id: + raise TinkerCliError( + "--type, --before, and --after require --run-id", + "Example: tinker checkpoint delete --run-id --type weights --before 2024-06-01", + ) + + client = create_rest_client() + + if run_id: + before_dt = _parse_date(before) if before else None + after_dt = _parse_date(after) if after else None + response = client.list_checkpoints(run_id).result() + checkpoints = _filter_checkpoints( + response.checkpoints, checkpoint_type, before_dt, after_dt + ) + + if not checkpoints: + click.echo(f"No checkpoints found for run {run_id} matching filters") + return + + paths_to_delete = [c.tinker_path for c in checkpoints] + if not yes and not _confirm_deletion(paths_to_delete, checkpoints): + click.echo("Deletion cancelled.") + return + else: for path in checkpoint_paths: - click.echo(f" - {path}") - click.echo() - - # Confirmation prompt - click.echo("WARNING: This action is permanent and cannot be undone.") - if not 
click.confirm(f"Are you sure you want to delete {count} checkpoint(s)?"): + if not path.startswith("tinker://"): + raise TinkerCliError( + f"Invalid checkpoint path: {path}", + "Checkpoint path must be in the format: tinker://run-id/weights/0001", + ) + paths_to_delete = list(checkpoint_paths) + if not yes and not _confirm_deletion(paths_to_delete): click.echo("Deletion cancelled.") return - # Create client and delete with progress bar - client = create_rest_client() - - with click.progressbar( - checkpoint_paths, - label="Deleting checkpoints", - show_percent=True, - show_pos=True, - hidden=cli_context.format != "table", - ) as bar: - for path in bar: - client.delete_checkpoint_from_tinker_path(path).result() + _delete_paths(client, paths_to_delete, cli_context.format) @cli.command() diff --git a/src/tinker/lib/public_interfaces/sampling_client.py b/src/tinker/lib/public_interfaces/sampling_client.py index 789c74f..0f3355c 100644 --- a/src/tinker/lib/public_interfaces/sampling_client.py +++ b/src/tinker/lib/public_interfaces/sampling_client.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +import dataclasses import logging import os import time @@ -15,6 +16,7 @@ import tinker from tinker import types from tinker.lib.client_connection_pool_type import ClientConnectionPoolType from tinker.lib.public_interfaces.api_future import APIFuture, AwaitableConcurrentFuture +from tinker.lib.sidecar import SidecarHandle, SidecarRPC, create_sidecar_handle from tinker.lib.telemetry import Telemetry, capture_exceptions from tinker.lib.telemetry_provider import TelemetryProvider @@ -33,6 +35,56 @@ logger = logging.getLogger(__name__) U = TypeVar("U") +# --------------------------------------------------------------------------- +# Pickle serialization state +# --------------------------------------------------------------------------- + + +@dataclasses.dataclass +class _SamplingClientPickleState: + """Serialized state for pickling SamplingClient across 
processes.""" + + session_id: str + sampling_session_id: str + constructor_kwargs: dict[str, Any] + subprocess_sampling: bool + + +# --------------------------------------------------------------------------- +# Typed RPCs for subprocess-isolated sampling +# --------------------------------------------------------------------------- + + +@dataclasses.dataclass +class _SampleRPC(SidecarRPC): + """Typed RPC for SamplingClient.sample().""" + + prompt: types.ModelInput + num_samples: int + sampling_params: types.SamplingParams + include_prompt_logprobs: bool + topk_prompt_logprobs: int + + async def execute(self, target: Any) -> Any: + return target.sample( + prompt=self.prompt, + num_samples=self.num_samples, + sampling_params=self.sampling_params, + include_prompt_logprobs=self.include_prompt_logprobs, + topk_prompt_logprobs=self.topk_prompt_logprobs, + ) + + +@dataclasses.dataclass +class _ComputeLogprobsRPC(SidecarRPC): + """Typed RPC for SamplingClient.compute_logprobs().""" + + prompt: types.ModelInput + + async def execute(self, target: Any) -> Any: + return target.compute_logprobs(prompt=self.prompt) + + class SamplingClient(TelemetryProvider, QueueStateObserver): """Client for text generation and inference from trained or base models. @@ -65,6 +117,12 @@ class SamplingClient(TelemetryProvider, QueueStateObserver): If you are using Tinker SDK with more than one process you should always create SamplingClient from the main process and then pass it to the other processes/workers. ServiceClient and TrainingClient should always be managed from the main process. + + Subprocess isolation: + Set ``TINKER_SUBPROCESS_SAMPLING=1`` to run sample() and compute_logprobs() in a dedicated + subprocess, preventing GIL contention from CPU-heavy user code (grading, environment + interactions) from stalling networking IO and heartbeats. This is transparent — the same + API works with or without it. 
""" def __init__( @@ -74,6 +132,7 @@ class SamplingClient(TelemetryProvider, QueueStateObserver): sampling_session_id: str, shadow: bool = False, retry_config: RetryConfig | None = None, + subprocess_sampling: bool | None = None, ): self.holder = holder @@ -97,6 +156,20 @@ class SamplingClient(TelemetryProvider, QueueStateObserver): # We use 1B as the base and mod for uuid because the maximum int value is 2^63-1 and 1B*1B is less than 2^63-1. self._request_id_counter = 1_000_000_000 * (int(uuid.uuid4()) % 1_000_000_000 + 1) + # Subprocess isolation: read env var if not explicitly set + if subprocess_sampling is None: + subprocess_sampling = os.environ.get("TINKER_SUBPROCESS_SAMPLING", "").lower() in ( + "1", + "true", + "yes", + ) + self._sampling_client_sidecar_handle: SidecarHandle | None = None + if subprocess_sampling: + from tinker.lib.sidecar import _inside_sidecar + + if not _inside_sidecar: + self._sampling_client_sidecar_handle = create_sidecar_handle(self) + @staticmethod async def _create_impl( holder: InternalClientHolder, @@ -237,6 +310,16 @@ class SamplingClient(TelemetryProvider, QueueStateObserver): print(tokenizer.decode(sample.tokens)) ``` """ + if self._sampling_client_sidecar_handle is not None: + return self._sampling_client_sidecar_handle.submit_rpc( + _SampleRPC( + prompt=prompt, + num_samples=num_samples, + sampling_params=sampling_params, + include_prompt_logprobs=include_prompt_logprobs, + topk_prompt_logprobs=topk_prompt_logprobs, + ) + ) async def _sample_async(): return await self._sample_async_impl( @@ -294,6 +377,10 @@ class SamplingClient(TelemetryProvider, QueueStateObserver): print(f"Token {i}: logprob = {logprob:.4f}") ``` """ + if self._sampling_client_sidecar_handle is not None: + return self._sampling_client_sidecar_handle.submit_rpc( + _ComputeLogprobsRPC(prompt=prompt) + ) async def _compute_logprobs_async() -> list[float | None]: sample_res = await self._sample_async_impl( @@ -349,18 +436,24 @@ class 
SamplingClient(TelemetryProvider, QueueStateObserver): def get_telemetry(self) -> Telemetry | None: return self.holder.get_telemetry() - def __reduce__(self) -> tuple[Any, tuple[str, str, dict[str, Any]]]: + def __reduce__(self) -> tuple[Any, tuple[_SamplingClientPickleState]]: """Enable pickling of SamplingClient for subprocess use. - Stores the sampling_session_id and holder constructor kwargs. - On unpickle, creates a shadow holder and reconstructs the client. + Serializes into a ``_SamplingClientPickleState`` dataclass. The + ``_sampling_client_sidecar_handle`` handle is deliberately omitted — only a + bool flag is stored. The unpickled copy creates its own handle via + the per-process sidecar singleton. Do not add ``__getstate__`` + without preserving this behavior. """ return ( _unpickle_sampling_client, ( - self.holder.get_session_id(), - self._sampling_session_id, - self.holder._constructor_kwargs, + _SamplingClientPickleState( + session_id=self.holder.get_session_id(), + sampling_session_id=self._sampling_session_id, + constructor_kwargs=self.holder._constructor_kwargs, + subprocess_sampling=self._sampling_client_sidecar_handle is not None, + ), ), ) @@ -386,21 +479,21 @@ class SamplingClient(TelemetryProvider, QueueStateObserver): ) -def _unpickle_sampling_client( - session_id: str, - sampling_session_id: str, - constructor_kwargs: dict[str, Any], -) -> SamplingClient: - """Reconstruct a SamplingClient from pickled data. +def _unpickle_sampling_client(state: _SamplingClientPickleState) -> SamplingClient: + """Reconstruct a SamplingClient from pickled state. Creates a shadow InternalClientHolder and builds a new SamplingClient. - The request_id_counter starts at a random high value to avoid collisions. + Subprocess enablement is handled by the constructor. 
""" from ..internal_client_holder import InternalClientHolder - holder = InternalClientHolder.get_shadow_holder(session_id, constructor_kwargs) - client = SamplingClient(holder, sampling_session_id=sampling_session_id, shadow=True) - return client + holder = InternalClientHolder.get_shadow_holder(state.session_id, state.constructor_kwargs) + return SamplingClient( + holder, + sampling_session_id=state.sampling_session_id, + shadow=True, + subprocess_sampling=state.subprocess_sampling, + ) @lru_cache(maxsize=100) diff --git a/src/tinker/lib/public_interfaces/training_client.py b/src/tinker/lib/public_interfaces/training_client.py index 67c94cd..b047e2d 100644 --- a/src/tinker/lib/public_interfaces/training_client.py +++ b/src/tinker/lib/public_interfaces/training_client.py @@ -7,7 +7,7 @@ import logging import threading import time from contextlib import asynccontextmanager -from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Tuple +from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Literal, Tuple from tinker import types from tinker.lib.client_connection_pool_type import ClientConnectionPoolType @@ -49,6 +49,11 @@ MODEL_ID_NOT_SET_ERROR = "model_id must be set before calling forward. Try initi # Args: (data: List[Datum], model_outputs: List[Any]) -> (loss: Any, metrics: Dict[str, float]) CustomLossFnV1 = Callable[[List[types.Datum], List[Any]], Tuple[Any, Dict[str, float]]] +_SUPPORTED_CUSTOM_BACKEND_LOSS_FNS = frozenset({"cross_entropy"}) +_CUSTOM_BACKEND_LOSS_FN_BY_INPUT_TYPE: dict[Literal["logprobs"], types.LossFnType] = { + "logprobs": "cross_entropy", +} + class TrainingClient(TelemetryProvider): """Client for training ML models with forward/backward passes and optimization. 
@@ -331,7 +336,11 @@ class TrainingClient(TelemetryProvider): @sync_only @capture_exceptions(fatal=True) def forward_backward_custom( - self, data: List[types.Datum], loss_fn: CustomLossFnV1 + self, + data: List[types.Datum], + loss_fn: CustomLossFnV1, + *, + loss_type_input: Literal["logprobs"] = "logprobs", ) -> APIFuture[types.ForwardBackwardOutput]: """Compute forward/backward with a custom loss function. @@ -341,6 +350,7 @@ class TrainingClient(TelemetryProvider): Args: - `data`: List of training data samples - `loss_fn`: Custom loss function that takes (data, logprobs) and returns (loss, metrics) + - `loss_type_input`: Input space for `loss_fn`. Currently the only supported value is `"logprobs"`. Returns: - `APIFuture` containing the forward/backward outputs with custom loss @@ -360,23 +370,49 @@ class TrainingClient(TelemetryProvider): ``` """ return self.holder.run_coroutine_threadsafe( - self.forward_backward_custom_async(data, loss_fn) + self.forward_backward_custom_async( + data, + loss_fn, + loss_type_input=loss_type_input, + ) ).result() @capture_exceptions(fatal=True) async def forward_backward_custom_async( - self, data: List[types.Datum], loss_fn: CustomLossFnV1 + self, + data: List[types.Datum], + loss_fn: CustomLossFnV1, + *, + loss_type_input: Literal["logprobs"] = "logprobs", ) -> APIFuture[types.ForwardBackwardOutput]: """Async version of forward_backward_custom.""" if torch is None: raise ImportError("PyTorch is not installed. Cannot run custom forward_backward.") + if loss_type_input not in _CUSTOM_BACKEND_LOSS_FN_BY_INPUT_TYPE: + supported = ", ".join(sorted(_CUSTOM_BACKEND_LOSS_FN_BY_INPUT_TYPE)) + raise ValueError( + f"Unsupported loss_type_input={loss_type_input!r}. 
" + f"Supported values are: {supported}" + ) + + surrogate_loss_fn = _CUSTOM_BACKEND_LOSS_FN_BY_INPUT_TYPE[loss_type_input] + + forward_data = self._get_custom_loss_forward_data(data, surrogate_loss_fn) + # First do a forward pass and get logprobs - forward_future = await self.forward_async(data, "cross_entropy") + forward_future = await self.forward_async( + forward_data, + surrogate_loss_fn, + None, + ) forward_result = await forward_future.result_async() logprobs_list = [] for out in forward_result.loss_fn_outputs: - logprob = torch.tensor(out["logprobs"].data).clone().detach().requires_grad_(True) + logprob = torch.tensor(out["logprobs"].data) + if out["logprobs"].shape is not None: + logprob = logprob.reshape(out["logprobs"].shape) + logprob = logprob.clone().detach().requires_grad_(True) logprobs_list.append(logprob) # Now apply user-provided function @@ -392,7 +428,9 @@ class TrainingClient(TelemetryProvider): for datum, grad in zip(data, grads, strict=True): loss_fn_inputs: Any = { "target_tokens": datum.loss_fn_inputs["target_tokens"], - "weights": -grad, # Pass PyTorch tensor directly (will be converted to TensorData) + # Backend CE is L = sum(-logprobs * weights), so to backpropagate a + # client-side custom loss C(logprobs) we must send weights = -dC/dlogprobs. + "weights": -grad, } linear_loss_data.append( types.Datum( @@ -402,7 +440,11 @@ class TrainingClient(TelemetryProvider): ) # Do the backward pass with the gradients - backward_future = await self.forward_backward_async(linear_loss_data, "cross_entropy") + backward_future = await self.forward_backward_async( + linear_loss_data, + surrogate_loss_fn, + None, + ) # We need to slightly modify the future to add the custom metrics, so we use _CombinedAPIFuture # to transform the future. 
@@ -415,6 +457,49 @@ class TrainingClient(TelemetryProvider): return _CombinedAPIFuture([backward_future], add_custom_metrics, self.holder) + def _get_custom_loss_forward_data( + self, + data: List[types.Datum], + surrogate_loss_fn: types.LossFnType, + ) -> List[types.Datum]: + assert surrogate_loss_fn in _SUPPORTED_CUSTOM_BACKEND_LOSS_FNS, ( + "forward_backward_custom_async should validate surrogate_loss_fn before " + "_get_custom_loss_forward_data is called" + ) + + forward_data = [] + for datum in data: + target_tokens = datum.loss_fn_inputs.get("target_tokens") + if target_tokens is None: + raise ValueError("target_tokens must be provided when using cross_entropy") + + unexpected_keys = sorted(set(datum.loss_fn_inputs) - {"target_tokens", "weights"}) + if unexpected_keys: + raise ValueError( + "forward_backward_custom only supports loss_fn_inputs keys " + "{'target_tokens', 'weights'}; " + f"found unexpected keys: {unexpected_keys}" + ) + + if "weights" in datum.loss_fn_inputs: + forward_data.append(datum) + continue + + forward_loss_fn_inputs = dict(datum.loss_fn_inputs) + forward_loss_fn_inputs["weights"] = types.TensorData( + data=[0.0] * len(target_tokens.data), + dtype="float32", + shape=target_tokens.shape, + ) + forward_data.append( + types.Datum( + model_input=datum.model_input, + loss_fn_inputs=forward_loss_fn_inputs, + ) + ) + + return forward_data + @capture_exceptions(fatal=True) def optim_step(self, adam_params: types.AdamParams) -> APIFuture[types.OptimStepResponse]: """Update model parameters using Adam optimizer. 
diff --git a/src/tinker/types/datum.py b/src/tinker/types/datum.py index 9cc448d..d6d772f 100644 --- a/src/tinker/types/datum.py +++ b/src/tinker/types/datum.py @@ -44,13 +44,30 @@ class Datum(StrictBase): @classmethod def _maybe_convert_array(cls, key: str, value: Any) -> Any: - """Convert torch.Tensor, numpy array, or 1-D list to TensorData if needed.""" + """Convert torch.Tensor, numpy array, or numeric lists to TensorData if needed.""" if _HAVE_TORCH and isinstance(value, torch.Tensor): return TensorData.from_torch(value) elif isinstance(value, np.ndarray): return TensorData.from_numpy(value) elif isinstance(value, list): - # assume it's 1d and infer the dtype from the key + try: + array = np.asarray(value) + except ValueError as exc: + if any(isinstance(item, list) for item in value): + raise ValueError( + f"{key} must be a rectangular numeric array; ragged nested lists are not supported" + ) from exc + raise + if array.dtype.kind in ("f", "i", "u"): + if _key_to_type[key] == "int64": + array = array.astype(np.int64) + else: + array = array.astype(np.float32) + return TensorData.from_numpy(array) + if any(isinstance(item, list) for item in value): + raise ValueError( + f"{key} must be a rectangular numeric array; ragged nested lists are not supported" + ) return TensorData(data=value, dtype=_key_to_type[key], shape=[len(value)]) else: return value diff --git a/src/tinker/types/loss_fn_type.py b/src/tinker/types/loss_fn_type.py index 0955375..1502013 100644 --- a/src/tinker/types/loss_fn_type.py +++ b/src/tinker/types/loss_fn_type.py @@ -2,4 +2,10 @@ from typing_extensions import Literal, TypeAlias __all__ = ["LossFnType"] -LossFnType: TypeAlias = Literal["cross_entropy", "importance_sampling", "ppo", "cispo", "dro"] +LossFnType: TypeAlias = Literal[ + "cross_entropy", + "importance_sampling", + "ppo", + "cispo", + "dro", +] diff --git a/tests/test_checkpoint_delete.py b/tests/test_checkpoint_delete.py new file mode 100644 index 0000000..65f4c9f --- /dev/null 
"""Tests for bulk checkpoint deletion: CLI flags and date parsing."""

from datetime import UTC, datetime, timedelta

import pytest
from click.testing import CliRunner

from tinker.cli.commands.checkpoint import _filter_checkpoints, _parse_date


class TestParseDate:
    """Exercise the ISO 8601 parsing helper `_parse_date`."""

    def test_date_only(self) -> None:
        parsed = _parse_date("2024-01-15")
        assert parsed.tzinfo is not None
        assert (parsed.year, parsed.month, parsed.day) == (2024, 1, 15)

    def test_datetime_with_z(self) -> None:
        parsed = _parse_date("2024-06-01T12:00:00Z")
        assert parsed.tzinfo is not None
        assert (parsed.year, parsed.month, parsed.hour) == (2024, 6, 12)

    def test_datetime_with_offset(self) -> None:
        parsed = _parse_date("2024-06-01T12:00:00+00:00")
        assert parsed.tzinfo is not None
        assert parsed.year == 2024

    def test_datetime_naive_gets_utc(self) -> None:
        # A naive timestamp must come back timezone-aware.
        assert _parse_date("2024-06-01T12:00:00").tzinfo is not None

    def test_whitespace_stripped(self) -> None:
        assert _parse_date(" 2024-01-01 ").year == 2024

    def test_invalid_raises(self) -> None:
        from tinker.cli.exceptions import TinkerCliError

        with pytest.raises(TinkerCliError):
            _parse_date("not-a-date")

    def test_invalid_format_raises(self) -> None:
        from tinker.cli.exceptions import TinkerCliError

        # Non-ISO (US-style) dates are rejected outright.
        with pytest.raises(TinkerCliError):
            _parse_date("01/15/2024")


class TestFilterCheckpoints:
    """Exercise `_filter_checkpoints` over type and date-window filters."""

    @pytest.fixture()
    def sample_checkpoints(self):
        from tinker.types.checkpoint import Checkpoint

        now = datetime.now(UTC)

        def make(kind: str, name: str, age: timedelta, size: int):
            # Small factory to keep the fixture table readable.
            return Checkpoint(
                checkpoint_id=name,
                checkpoint_type=kind,
                time=now - age,
                tinker_path=f"tinker://run-1/{name}",
                size_bytes=size,
            )

        # Three "training" checkpoints (10 days / 3 days / 1 hour old)
        # and one "sampler" checkpoint (10 days old).
        return [
            make("training", "weights/0001", timedelta(days=10), 1000),
            make("training", "weights/0002", timedelta(days=3), 2000),
            make("sampler", "sampler_weights/0001", timedelta(days=10), 500),
            make("training", "weights/0003", timedelta(hours=1), 3000),
        ]

    def test_no_filters(self, sample_checkpoints) -> None:
        kept = _filter_checkpoints(sample_checkpoints, None, None, None)
        assert len(kept) == 4

    def test_filter_by_weights_type(self, sample_checkpoints) -> None:
        # The "weights" CLI type selects checkpoints of type "training".
        kept = _filter_checkpoints(sample_checkpoints, "weights", None, None)
        assert len(kept) == 3
        assert all(c.checkpoint_type == "training" for c in kept)

    def test_filter_by_sampler_weights_type(self, sample_checkpoints) -> None:
        kept = _filter_checkpoints(sample_checkpoints, "sampler_weights", None, None)
        assert len(kept) == 1
        assert kept[0].checkpoint_type == "sampler"

    def test_filter_before(self, sample_checkpoints) -> None:
        # Only the two 10-day-old checkpoints predate the cutoff.
        seven_days_ago = datetime.now(UTC) - timedelta(days=7)
        kept = _filter_checkpoints(sample_checkpoints, None, seven_days_ago, None)
        assert len(kept) == 2
        assert all("0001" in c.checkpoint_id for c in kept)

    def test_filter_after(self, sample_checkpoints) -> None:
        # The 3-day-old and 1-hour-old checkpoints postdate the cutoff.
        seven_days_ago = datetime.now(UTC) - timedelta(days=7)
        kept = _filter_checkpoints(sample_checkpoints, None, None, seven_days_ago)
        assert len(kept) == 2
        kept_paths = {c.tinker_path for c in kept}
        assert "tinker://run-1/weights/0002" in kept_paths
        assert "tinker://run-1/weights/0003" in kept_paths

    def test_filter_date_range(self, sample_checkpoints) -> None:
        # A [5 days ago, 2 days ago] window matches only the 3-day-old one.
        window_start = datetime.now(UTC) - timedelta(days=5)
        window_end = datetime.now(UTC) - timedelta(days=2)
        kept = _filter_checkpoints(sample_checkpoints, None, window_end, window_start)
        assert len(kept) == 1
        assert kept[0].tinker_path == "tinker://run-1/weights/0002"

    def test_filter_by_type_and_before(self, sample_checkpoints) -> None:
        # Filters compose: both type AND before must match.
        seven_days_ago = datetime.now(UTC) - timedelta(days=7)
        kept = _filter_checkpoints(sample_checkpoints, "weights", seven_days_ago, None)
        assert len(kept) == 1
        assert kept[0].tinker_path == "tinker://run-1/weights/0001"

    def test_invalid_type_raises(self, sample_checkpoints) -> None:
        from tinker.cli.exceptions import TinkerCliError

        with pytest.raises(TinkerCliError):
            _filter_checkpoints(sample_checkpoints, "invalid_type", None, None)


class TestDeleteCLIValidation:
    """Argument-validation failures of the checkpoint `delete` command."""

    def _get_error_message(self, result) -> str:
        """Return the CLI error text, wherever Click surfaced it."""
        if result.output:
            return result.output
        return str(result.exception) if result.exception else ""

    def _run_delete(self, *args: str):
        """Invoke `checkpoint delete` with the given extra arguments."""
        from tinker.cli.commands.checkpoint import cli

        return CliRunner().invoke(cli, ["delete", *args])

    def test_no_args_shows_error(self) -> None:
        assert self._run_delete().exit_code != 0

    def test_paths_and_run_id_conflict(self) -> None:
        outcome = self._run_delete("tinker://run-1/weights/0001", "--run-id", "run-1")
        assert outcome.exit_code != 0
        assert "Cannot specify both" in self._get_error_message(outcome)

    def test_type_without_run_id_error(self) -> None:
        outcome = self._run_delete("tinker://run-1/weights/0001", "--type", "weights")
        assert outcome.exit_code != 0
        assert "--run-id" in self._get_error_message(outcome)

    def test_before_without_run_id_error(self) -> None:
        outcome = self._run_delete("tinker://run-1/weights/0001", "--before", "2024-01-01")
        assert outcome.exit_code != 0
        assert "--run-id" in self._get_error_message(outcome)

    def test_after_without_run_id_error(self) -> None:
        outcome = self._run_delete("tinker://run-1/weights/0001", "--after", "2024-01-01")
        assert outcome.exit_code != 0
        assert "--run-id" in self._get_error_message(outcome)
# --- file boundary: tests/test_subprocess_sampling_client.py (new file in this patch) ---

"""Tests for SamplingClient subprocess mode.

These tests use a picklable fake SamplingClient to verify that
subprocess mode correctly routes sample() and compute_logprobs()
through the sidecar subprocess.

Test organization:
    TestRPCRouting — sample/compute_logprobs delegation through sidecar
    TestErrorHandling — error propagation, sidecar death
    TestPickle — roundtrip with/without sidecar, re-enable mode
    TestConcurrency — multithreaded, async, cancelled futures, mixed ops
"""

from __future__ import annotations

import asyncio
import pickle
import threading
import time
from concurrent.futures import Future as ConcurrentFuture
from typing import Any

import pytest

from tinker import types
from tinker._exceptions import SidecarDiedError
from tinker.lib.sidecar import create_sidecar_handle

# ---------------------------------------------------------------------------
# Picklable fake SamplingClient (must be module-level for pickling)
# ---------------------------------------------------------------------------


class _FakeSamplingClient:
    """A picklable fake that simulates SamplingClient for testing.

    This is NOT a real SamplingClient — it provides just enough interface
    to test the sidecar integration. Real SamplingClient requires an
    InternalClientHolder and API connection.
    """

    def __init__(self, delay: float = 0.0, fail: bool = False, subprocess_sampling: bool = False):
        self._delay = delay
        self._fail = fail
        self._sampling_client_sidecar_handle = None  # set by create_sidecar_handle() in tests
        if subprocess_sampling:
            from tinker.lib.sidecar import _inside_sidecar

            # Guard: do not create a handle when this code already runs inside
            # the sidecar process (would recurse into a nested sidecar).
            if not _inside_sidecar:
                self._sampling_client_sidecar_handle = create_sidecar_handle(self)

    def sample(
        self,
        prompt: types.ModelInput,
        num_samples: int,
        sampling_params: types.SamplingParams,
        include_prompt_logprobs: bool = False,
        topk_prompt_logprobs: int = 0,
    ) -> Any:
        # Delegate through sidecar if enabled (mirrors real SamplingClient behavior)
        if self._sampling_client_sidecar_handle is not None:
            from tinker.lib.public_interfaces.sampling_client import _SampleRPC

            return self._sampling_client_sidecar_handle.submit_rpc(
                _SampleRPC(
                    prompt,
                    num_samples,
                    sampling_params,
                    include_prompt_logprobs,
                    topk_prompt_logprobs,
                )
            )

        # Local (non-sidecar) path: resolve a canned response, optionally
        # delayed on a background thread or failed with a canned exception.
        f: ConcurrentFuture[types.SampleResponse] = ConcurrentFuture()
        if self._fail:
            f.set_exception(RuntimeError("Simulated sample failure"))
        elif self._delay > 0:

            def _delayed():
                time.sleep(self._delay)
                f.set_result(_make_sample_response())

            threading.Thread(target=_delayed, daemon=True).start()
        else:
            f.set_result(_make_sample_response())
        return f

    def compute_logprobs(self, prompt: types.ModelInput) -> Any:
        # Delegate through sidecar if enabled (mirrors real SamplingClient behavior)
        if self._sampling_client_sidecar_handle is not None:
            from tinker.lib.public_interfaces.sampling_client import _ComputeLogprobsRPC

            return self._sampling_client_sidecar_handle.submit_rpc(_ComputeLogprobsRPC(prompt))

        f: ConcurrentFuture[list[float | None]] = ConcurrentFuture()
        if self._fail:
            f.set_exception(RuntimeError("Simulated logprobs failure"))
        else:
            f.set_result([0.1, 0.2, None])
        return f

    def __reduce__(self) -> tuple[type, tuple[float, bool, bool]]:
        # Pickle as (delay, fail, subprocess_sampling): a restored copy
        # re-creates its own sidecar handle when one was present, since the
        # live handle itself is not picklable.
        return (
            _FakeSamplingClient,
            (self._delay, self._fail, self._sampling_client_sidecar_handle is not None),
        )


def _make_sample_response() -> types.SampleResponse:
    # Canned single-sequence response used by every fake sample() call.
    return types.SampleResponse(
        sequences=[
            types.SampledSequence(
                stop_reason="length",
                tokens=[1, 2, 3],
                logprobs=[0.1, 0.2, 0.3],
            )
        ],
        type="sample",
    )


def _create_proxy(delay: float = 0.0, fail: bool = False) -> _FakeSamplingClient:
    """Create a fake client with sidecar handle for testing."""
    client = _FakeSamplingClient(delay=delay, fail=fail)
    client._sampling_client_sidecar_handle = create_sidecar_handle(client)
    return client


# Shared fixtures for all RPC calls below.
_PROMPT = types.ModelInput.from_ints([1, 2, 3])
_PARAMS = types.SamplingParams(max_tokens=10)


# ===========================================================================
# Tests
# ===========================================================================


class TestRPCRouting:
    """Verify sample() and compute_logprobs() are routed through the sidecar."""

    def test_sample(self):
        """sample() → subprocess → SampleResponse."""
        proxy = _create_proxy()
        result = proxy.sample(_PROMPT, 1, _PARAMS).result(timeout=10)
        assert isinstance(result, types.SampleResponse)
        assert result.sequences[0].tokens == [1, 2, 3]

    def test_constructor_enables_subprocess_mode(self):
        """subprocess_sampling=True in __init__ creates the sidecar handle."""
        client = _FakeSamplingClient(subprocess_sampling=True)
        assert client._sampling_client_sidecar_handle is not None
        result = client.sample(_PROMPT, 1, _PARAMS).result(timeout=10)
        assert isinstance(result, types.SampleResponse)

    def test_compute_logprobs(self):
        """compute_logprobs() → subprocess → list of logprobs."""
        proxy = _create_proxy()
        result = proxy.compute_logprobs(_PROMPT).result(timeout=10)
        assert result == [0.1, 0.2, None]

    def test_mixed_sample_and_logprobs(self):
        """Interleaved sample() and compute_logprobs() all resolve correctly."""
        proxy = _create_proxy(delay=0.01)

        futures_sample = [proxy.sample(_PROMPT, 1, _PARAMS) for _ in range(10)]
        futures_logprobs = [proxy.compute_logprobs(_PROMPT) for _ in range(10)]

        for f in futures_sample:
            result = f.result(timeout=10)
            assert isinstance(result, types.SampleResponse)
            assert result.sequences[0].tokens == [1, 2, 3]

        for f in futures_logprobs:
            assert f.result(timeout=10) == [0.1, 0.2, None]


class TestErrorHandling:
    """Error propagation from subprocess to caller."""

    def test_sample_error(self):
        """Exceptions from sample() in the subprocess are propagated."""
        proxy = _create_proxy(fail=True)
        with pytest.raises(RuntimeError, match="Simulated sample failure"):
            proxy.sample(_PROMPT, 1, _PARAMS).result(timeout=10)

    def test_compute_logprobs_error(self):
        """Exceptions from compute_logprobs() in the subprocess are propagated."""
        proxy = _create_proxy(fail=True)
        with pytest.raises(RuntimeError, match="Simulated logprobs failure"):
            proxy.compute_logprobs(_PROMPT).result(timeout=10)

    @pytest.mark.filterwarnings("ignore::pytest.PytestUnhandledThreadExceptionWarning")
    def test_sidecar_death_fails_pending_futures(self):
        """When the subprocess is killed, pending futures get SidecarDiedError."""
        # delay=0.5 keeps the request in flight long enough to kill the process.
        proxy = _create_proxy(delay=0.5)
        future = proxy.sample(_PROMPT, 1, _PARAMS)

        # Kill the underlying subprocess directly
        sidecar = proxy._sampling_client_sidecar_handle._sidecar
        assert sidecar._process is not None
        sidecar._process.kill()
        sidecar._process.join(timeout=5)

        with pytest.raises(SidecarDiedError):
            future.result(timeout=5)


class TestPickle:
    """Pickle roundtrip preserves subprocess mode correctly."""

    def test_roundtrip_preserves_subprocess_mode(self):
        """Pickling a sidecar-enabled client re-enables subprocess mode on unpickle."""
        proxy = _create_proxy()
        assert proxy._sampling_client_sidecar_handle is not None

        restored = pickle.loads(pickle.dumps(proxy))
        assert restored._sampling_client_sidecar_handle is not None

        result = restored.sample(_PROMPT, 1, _PARAMS).result(timeout=10)
        assert isinstance(result, types.SampleResponse)

    def test_roundtrip_without_sidecar(self):
        """Pickling a client without subprocess mode keeps it disabled."""
        client = _FakeSamplingClient()
        assert client._sampling_client_sidecar_handle is None
        restored = pickle.loads(pickle.dumps(client))
        assert restored._sampling_client_sidecar_handle is None

    def test_re_enable_subprocess_mode(self):
        """Replacing the sidecar handle works cleanly."""
        client = _FakeSamplingClient()
        client._sampling_client_sidecar_handle = create_sidecar_handle(client)

        # First handle works
        assert isinstance(
            client.sample(_PROMPT, 1, _PARAMS).result(timeout=10), types.SampleResponse
        )

        # Replace with a new handle (old one is GC'd and unregistered)
        client._sampling_client_sidecar_handle = create_sidecar_handle(client)

        # New handle also works
        assert isinstance(
            client.sample(_PROMPT, 1, _PARAMS).result(timeout=10), types.SampleResponse
        )


class TestConcurrency:
    """Thread safety and concurrent operations through the sidecar."""

    def test_multithreaded_samples(self):
        """sample() from 20 threads all resolve correctly."""
        proxy = _create_proxy(delay=0.01)
        results: list[types.SampleResponse | None] = [None] * 20
        errors: list[Exception] = []

        def _worker(idx: int) -> None:
            try:
                results[idx] = proxy.sample(_PROMPT, 1, _PARAMS).result(timeout=30)
            except Exception as e:
                errors.append(e)

        threads = [threading.Thread(target=_worker, args=(i,)) for i in range(20)]
        for t in threads:
            t.start()
        for t in threads:
            t.join(timeout=30)

        assert not errors, f"Threads raised: {errors}"
        for r in results:
            assert isinstance(r, types.SampleResponse)
            assert r.sequences[0].tokens == [1, 2, 3]

    def test_multithreaded_mixed_operations(self):
        """sample() and compute_logprobs() from different threads simultaneously."""
        proxy = _create_proxy(delay=0.01)
        errors: list[Exception] = []

        def _sample_worker() -> None:
            try:
                for _ in range(10):
                    r = proxy.sample(_PROMPT, 1, _PARAMS).result(timeout=10)
                    assert isinstance(r, types.SampleResponse)
            except Exception as e:
                errors.append(e)

        def _logprobs_worker() -> None:
            try:
                for _ in range(10):
                    r = proxy.compute_logprobs(_PROMPT).result(timeout=10)
                    assert r == [0.1, 0.2, None]
            except Exception as e:
                errors.append(e)

        threads = [threading.Thread(target=_sample_worker) for _ in range(3)]
        threads += [threading.Thread(target=_logprobs_worker) for _ in range(3)]
        for t in threads:
            t.start()
        for t in threads:
            t.join(timeout=30)

        assert not errors, f"Errors: {errors}"

    def test_async_concurrent_samples(self):
        """Multiple async sample calls via asyncio.gather all resolve."""
        proxy = _create_proxy(delay=0.01)

        async def _run() -> list[types.SampleResponse]:
            from tinker.lib.public_interfaces.api_future import AwaitableConcurrentFuture

            coros = [
                AwaitableConcurrentFuture(proxy.sample(_PROMPT, 1, _PARAMS)) for _ in range(20)
            ]
            return await asyncio.gather(*coros)

        results = asyncio.run(_run())
        assert len(results) == 20
        for r in results:
            assert isinstance(r, types.SampleResponse)

    def test_cancelled_future_does_not_crash_collector(self):
        """Cancelling a future doesn't kill the collector thread."""
        proxy = _create_proxy(delay=0.5)

        # NOTE(review): cancel() may be a no-op if the RPC is already running;
        # either way the follow-up request below must still succeed.
        future1 = proxy.sample(_PROMPT, 1, _PARAMS)
        future1.cancel()

        result = proxy.sample(_PROMPT, 1, _PARAMS).result(timeout=10)
        assert isinstance(result, types.SampleResponse)

    def test_multiple_clients_share_sidecar(self):
        """Two independent clients sharing the sidecar singleton work concurrently."""
        proxy1 = _create_proxy(delay=0.01)
        proxy2 = _create_proxy(delay=0.01)
        errors: list[Exception] = []

        def _worker1() -> None:
            try:
                for _ in range(10):
                    r = proxy1.sample(_PROMPT, 1, _PARAMS).result(timeout=10)
                    assert isinstance(r, types.SampleResponse)
            except Exception as e:
                errors.append(e)

        def _worker2() -> None:
            try:
                for _ in range(10):
                    r = proxy2.compute_logprobs(_PROMPT).result(timeout=10)
                    assert r == [0.1, 0.2, None]
            except Exception as e:
                errors.append(e)

        t1 = threading.Thread(target=_worker1)
        t2 = threading.Thread(target=_worker2)
        t1.start()
        t2.start()
        t1.join(timeout=30)
        t2.join(timeout=30)

        assert not errors, f"Errors: {errors}"

    def test_pickle_roundtrip_then_concurrent_use(self):
        """Pickle a client, restore it, then use from multiple threads."""
        proxy = _create_proxy(delay=0.01)
        restored = pickle.loads(pickle.dumps(proxy))
        assert restored._sampling_client_sidecar_handle is not None

        errors: list[Exception] = []

        def _worker() -> None:
            try:
                for _ in range(10):
                    r = restored.sample(_PROMPT, 1, _PARAMS).result(timeout=10)
                    assert isinstance(r, types.SampleResponse)
            except Exception as e:
                errors.append(e)

        threads = [threading.Thread(target=_worker) for _ in range(5)]
        for t in threads:
            t.start()
        for t in threads:
            t.join(timeout=30)

        assert not errors, f"Errors: {errors}"
# --- file boundary: tests/test_training_client_custom_loss.py (new file in this patch) ---

from __future__ import annotations

from concurrent.futures import Future

import pytest
import torch

from tinker import types
from tinker.lib.public_interfaces.api_future import AwaitableConcurrentFuture
from tinker.lib.public_interfaces.training_client import TrainingClient


class _DummyHolder:
    """Minimal stand-in for the client holder used by TrainingClient.

    run_coroutine_threadsafe() never schedules the coroutine — it just wraps
    it in an already-resolved Future. Sufficient for these tests, which drive
    the async methods directly.
    """

    def run_coroutine_threadsafe(self, coro):
        future: Future = Future()
        # NOTE(review): the coroutine object is stored un-awaited; acceptable
        # for a fake, but may emit "never awaited" warnings.
        future.set_result(coro)
        return future

    def get_telemetry(self):
        return None


class _FakeTrainingClient(TrainingClient):
    """TrainingClient fake that records calls and returns canned logprobs.

    forward_async() yields a single 2x2 "target_token_logprobs" output;
    forward_backward_async() records the call and returns empty outputs.
    """

    def __init__(self):
        self.holder = _DummyHolder()
        # Each entry is (data, loss_fn, loss_fn_config) as passed by the caller.
        self.forward_calls: list[
            tuple[list[types.Datum], types.LossFnType, dict[str, float] | None]
        ] = []
        self.backward_calls: list[
            tuple[list[types.Datum], types.LossFnType, dict[str, float] | None]
        ] = []

    async def forward_async(
        self,
        data: list[types.Datum],
        loss_fn: types.LossFnType,
        loss_fn_config: dict[str, float] | None = None,
    ):
        self.forward_calls.append((data, loss_fn, loss_fn_config))
        result = types.ForwardBackwardOutput(
            metrics={},
            loss_fn_output_type="target_token_logprobs",
            loss_fn_outputs=[
                {
                    "logprobs": types.TensorData(
                        data=[-3.0, -2.0, -1.0, 0.0],
                        dtype="float32",
                        shape=[2, 2],
                    ),
                }
            ],
        )
        future: Future = Future()
        future.set_result(result)
        return AwaitableConcurrentFuture(future)

    async def forward_backward_async(
        self,
        data: list[types.Datum],
        loss_fn: types.LossFnType,
        loss_fn_config: dict[str, float] | None = None,
    ):
        self.backward_calls.append((data, loss_fn, loss_fn_config))
        result = types.ForwardBackwardOutput(
            metrics={"base:sum": 1.0},
            loss_fn_output_type="target_token_logprobs",
            loss_fn_outputs=[],
        )
        future: Future = Future()
        future.set_result(result)
        return AwaitableConcurrentFuture(future)


@pytest.mark.asyncio
async def test_forward_backward_custom_supports_2d_cross_entropy_targets():
    client = _FakeTrainingClient()
    datum = types.Datum(
        model_input=types.ModelInput.from_ints([1, 2]),
        loss_fn_inputs={
            "target_tokens": [[101, 102], [201, 202]],
        },
    )

    # Nested list input must be normalized to a [2, 2] tensor.
    assert datum.loss_fn_inputs["target_tokens"].shape == [2, 2]

    def custom_loss(
        data: list[types.Datum], logprobs_list: list[torch.Tensor]
    ) -> tuple[torch.Tensor, dict[str, float]]:
        del data
        logprobs = logprobs_list[0]
        # The fake forward returns a 2x2 logprobs tensor; see _FakeTrainingClient.
        assert logprobs.shape == (2, 2)
        probs = torch.softmax(logprobs[1], dim=-1)
        target_distribution = torch.tensor([0.0, 1.0], dtype=torch.float32)
        loss = torch.sum((probs - target_distribution) ** 2)
        return loss, {"selected_prob:mean": float(probs[1].detach())}

    result_future = await client.forward_backward_custom_async(
        [datum],
        custom_loss,
        loss_type_input="logprobs",
    )
    result = await result_future.result_async()

    assert client.forward_calls[0][1] == "cross_entropy"
    forward_datum = client.forward_calls[0][0][0]
    assert forward_datum.loss_fn_inputs["weights"].shape == [2, 2]

    assert client.backward_calls[0][1] == "cross_entropy"
    backward_datum = client.backward_calls[0][0][0]
    assert backward_datum.loss_fn_inputs["target_tokens"].shape == [2, 2]
    assert backward_datum.loss_fn_inputs["weights"].shape == [2, 2]
    # The caller's datum must NOT be mutated — weights go on a cloned datum.
    assert "weights" not in datum.loss_fn_inputs
    assert result.metrics["selected_prob:mean"] > 0.0


@pytest.mark.asyncio
async def test_forward_backward_custom_preserves_1d_cross_entropy_targets():
    client = _FakeTrainingClient()
    datum = types.Datum(
        model_input=types.ModelInput.from_ints([1, 2]),
        loss_fn_inputs={"target_tokens": [101, 102]},
    )

    # Override the class fake: return a 1-D (shape [2]) logprobs tensor instead.
    async def forward_async_1d(
        data: list[types.Datum],
        loss_fn: types.LossFnType,
        loss_fn_config: dict[str, float] | None = None,
    ):
        client.forward_calls.append((data, loss_fn, loss_fn_config))
        result = types.ForwardBackwardOutput(
            metrics={},
            loss_fn_output_type="target_token_logprobs",
            loss_fn_outputs=[
                {
                    "logprobs": types.TensorData(
                        data=[-3.0, -1.0],
                        dtype="float32",
                        shape=[2],
                    ),
                }
            ],
        )
        future: Future = Future()
        future.set_result(result)
        return AwaitableConcurrentFuture(future)

    # Bind as a plain instance attribute (no self), shadowing the method.
    setattr(client, "forward_async", forward_async_1d)

    def custom_loss(
        data: list[types.Datum], logprobs_list: list[torch.Tensor]
    ) -> tuple[torch.Tensor, dict[str, float]]:
        del data
        logprobs = logprobs_list[0]
        assert logprobs.shape == (2,)
        loss = -logprobs[-1]
        return loss, {"selected_logprob:last": float(logprobs[-1].detach())}

    result_future = await client.forward_backward_custom_async(
        [datum],
        custom_loss,
        loss_type_input="logprobs",
    )
    result = await result_future.result_async()

    assert client.forward_calls[0][1] == "cross_entropy"
    forward_datum = client.forward_calls[0][0][0]
    assert forward_datum.loss_fn_inputs["weights"].shape == [2]

    assert client.backward_calls[0][1] == "cross_entropy"
    backward_datum = client.backward_calls[0][0][0]
    assert backward_datum.loss_fn_inputs["target_tokens"].shape == [2]
    assert backward_datum.loss_fn_inputs["weights"].shape == [2]
    # Gradient of -logprobs[-1] w.r.t. logprobs is [0, -1] → weights [0, 1].
    torch.testing.assert_close(
        torch.tensor(backward_datum.loss_fn_inputs["weights"].data).reshape(
            backward_datum.loss_fn_inputs["weights"].shape
        ),
        torch.tensor([0.0, 1.0], dtype=torch.float32),
    )
    assert result.metrics["selected_logprob:last"] < 0.0


@pytest.mark.asyncio
async def test_forward_backward_custom_rejects_unsupported_loss_type_input():
    client = _FakeTrainingClient()
    datum = types.Datum(
        model_input=types.ModelInput.from_ints([1, 2]),
        loss_fn_inputs={"target_tokens": [101, 102]},
    )

    def custom_loss(
        data: list[types.Datum], logprobs_list: list[torch.Tensor]
    ) -> tuple[torch.Tensor, dict[str, float]]:
        del data, logprobs_list
        return torch.tensor(0.0, requires_grad=True), {}

    with pytest.raises(ValueError, match="Unsupported loss_type_input"):
        await client.forward_backward_custom_async(
            [datum],
            custom_loss,
            loss_type_input="logits",  # type: ignore[arg-type]
        )


def test_datum_rejects_ragged_nested_target_tokens():
    # Inner lists of differing length cannot form a rectangular tensor.
    with pytest.raises(ValueError, match="ragged nested lists are not supported"):
        types.Datum(
            model_input=types.ModelInput.from_ints([1, 2]),
            loss_fn_inputs={"target_tokens": [[101, 102], [201]]},
        )


@pytest.mark.asyncio
async def test_forward_backward_custom_rejects_unexpected_loss_fn_input_keys():
    client = _FakeTrainingClient()
    datum = types.Datum(
        model_input=types.ModelInput.from_ints([1, 2]),
        loss_fn_inputs={
            "target_tokens": [101, 102],
            # "advantages" is not a key the custom-loss path accepts.
            "advantages": [1.0, 1.0],
        },
    )

    def custom_loss(
        data: list[types.Datum], logprobs_list: list[torch.Tensor]
    ) -> tuple[torch.Tensor, dict[str, float]]:
        del data, logprobs_list
        return torch.tensor(0.0, requires_grad=True), {}

    with pytest.raises(ValueError, match="only supports loss_fn_inputs keys"):
        await client.forward_backward_custom_async(
            [datum],
            custom_loss,
            loss_type_input="logprobs",
        )