diff --git a/.sync_state b/.sync_state index 19b7c46..4421eae 100644 --- a/.sync_state +++ b/.sync_state @@ -1,4 +1,4 @@ { - "last_synced_sha": "a4665d66b6c66527095652259bc8810cf843f425", - "last_sync_time": "2025-11-21T01:42:38.758898" + "last_synced_sha": "f39e53a99d7b8847b819df7c566f21fd64845ac0", + "last_sync_time": "2025-11-25T03:53:31.673304" } \ No newline at end of file diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 0000000..92b774d --- /dev/null +++ b/docs/README.md @@ -0,0 +1,15 @@ +# Tinker SDK Documentation + +This directory contains auto-generated API documentation for the Tinker Python SDK. + +### Generate Docs + +Run the documentation generation script: +```bash +uv run scripts/generate_docs.py +``` + +### Notes + +* Only types/classes/methods with doc-string attached will have docs generated +* Please check in the generated artifacts diff --git a/docs/api/_meta.json b/docs/api/_meta.json new file mode 100644 index 0000000..ee602be --- /dev/null +++ b/docs/api/_meta.json @@ -0,0 +1,9 @@ +{ + "serviceclient": "ServiceClient", + "trainingclient": "TrainingClient", + "samplingclient": "SamplingClient", + "restclient": "RestClient", + "apifuture": "APIFuture", + "types": "Parameters", + "exceptions": "Exceptions" +} diff --git a/docs/api/apifuture.md b/docs/api/apifuture.md new file mode 100644 index 0000000..871dfa2 --- /dev/null +++ b/docs/api/apifuture.md @@ -0,0 +1,155 @@ +# `tinker.lib.public_interfaces.api_future` + +API Future classes for handling async operations with retry logic. + +## `APIFuture` Objects + +```python +class APIFuture(ABC, Generic[T]) +``` + +Abstract base class for futures that can be awaited or accessed synchronously. + +APIFuture provides a unified interface for handling async operations that can +be accessed both synchronously (via result()) and asynchronously (via await or result_async()). +This allows for flexible usage patterns in both sync and async contexts. + +The future can be awaited directly in async contexts: +```python +result = await api_future # Equivalent to await api_future.result_async() +``` + +Or accessed synchronously: +```python +result = api_future.result() # Blocks until complete +``` + +Args: + T: The type of the result value + +Example: +```python +# In async context +future = training_client.forward_backward(data, "cross_entropy") +result = await future # Or await future.result_async() + +# In sync context +future = training_client.forward_backward(data, "cross_entropy") +result = future.result() +``` + +#### `result_async` + +```python +@abstractmethod +async def result_async(timeout: float | None = None) -> T +``` + +Get the result asynchronously with optional timeout. + +Args: + timeout: Maximum time to wait in seconds. None means wait indefinitely. + +Returns: + The result value of type T + +Raises: + TimeoutError: If timeout is exceeded + +#### `result` + +```python +@abstractmethod +def result(timeout: float | None = None) -> T +``` + +Get the result synchronously with optional timeout. + +Args: + timeout: Maximum time to wait in seconds. None means wait indefinitely. + +Returns: + The result value of type T + +Raises: + TimeoutError: If timeout is exceeded + +## `AwaitableConcurrentFuture` Objects + +```python +class AwaitableConcurrentFuture(APIFuture[T]) +``` + +Implementation of APIFuture that wraps a concurrent.futures.Future. + +This class bridges Python's concurrent.futures with asyncio, allowing a +standard Future to be used in async contexts. 
It's commonly returned by +Tinker API methods to provide both sync and async access patterns. + +Args: + future: A concurrent.futures.Future to wrap + +Example: +```python +# Internal usage - typically you receive these from API methods +concurrent_future = some_operation() +api_future = AwaitableConcurrentFuture(concurrent_future) + +# Can be used synchronously +result = api_future.result() + +# Or asynchronously +result = await api_future +``` + +#### `result` + +```python +def result(timeout: float | None = None) -> T +``` + +Get the result synchronously with optional timeout. + +Args: + timeout: Maximum time to wait in seconds. None means wait indefinitely. + +Returns: + The result value of type T + +Raises: + TimeoutError: If timeout is exceeded + Exception: Any exception raised by the underlying operation + +Example: +```python +future = rest_client.get_training_run("run-id") +result = future.result(timeout=30) # Wait up to 30 seconds +``` + +#### `result_async` + +```python +async def result_async(timeout: float | None = None) -> T +``` + +Async version of result. + +#### `future` + +```python +def future() -> ConcurrentFuture[T] +``` + +Get the underlying concurrent.futures.Future. + +Returns: + The wrapped ConcurrentFuture object + +Example: +```python +api_future = rest_client.get_training_run("run-id") +concurrent_future = api_future.future() +# Can now use standard concurrent.futures methods +if concurrent_future.done(): + result = concurrent_future.result() +``` diff --git a/docs/api/exceptions.md b/docs/api/exceptions.md new file mode 100644 index 0000000..cbdecf2 --- /dev/null +++ b/docs/api/exceptions.md @@ -0,0 +1,132 @@ +# `tinker._exceptions` + +## `TinkerError` Objects + +```python +class TinkerError(Exception) +``` + +Base exception for all Tinker-related errors. + +## `APIError` Objects + +```python +class APIError(TinkerError) +``` + +Base class for all API-related errors. + +#### `body` + +The API response body. + +If the API responded with a valid JSON structure then this property will be the +decoded result. + +If it isn't a valid JSON structure then this will be the raw response. + +If there was no response associated with this error then it will be `None`. + +## `APIResponseValidationError` Objects + +```python +class APIResponseValidationError(APIError) +``` + +Raised when API response doesn't match expected schema. + +## `APIStatusError` Objects + +```python +class APIStatusError(APIError) +``` + +Raised when an API response has a status code of 4xx or 5xx. + +## `APIConnectionError` Objects + +```python +class APIConnectionError(APIError) +``` + +Raised when a connection error occurs while making an API request. + +## `APITimeoutError` Objects + +```python +class APITimeoutError(APIConnectionError) +``` + +Raised when an API request times out. + +## `BadRequestError` Objects + +```python +class BadRequestError(APIStatusError) +``` + +HTTP 400: The request was invalid or malformed. + +## `AuthenticationError` Objects + +```python +class AuthenticationError(APIStatusError) +``` + +HTTP 401: Authentication credentials are missing or invalid. + +## `PermissionDeniedError` Objects + +```python +class PermissionDeniedError(APIStatusError) +``` + +HTTP 403: Insufficient permissions to access the resource. + +## `NotFoundError` Objects + +```python +class NotFoundError(APIStatusError) +``` + +HTTP 404: The requested resource was not found. 
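+
+For example, looking up a training run that does not exist surfaces as this exception. A minimal handling sketch (assuming `service_client` was created as shown on the ServiceClient page, and that the exception classes are imported from the `tinker._exceptions` module documented here):
+
+```python
+from tinker._exceptions import APIStatusError, NotFoundError
+
+rest_client = service_client.create_rest_client()
+try:
+    run = rest_client.get_training_run("run-id").result()
+except NotFoundError:
+    # HTTP 404: the training run does not exist or is not visible to this user
+    print("Training run not found")
+except APIStatusError as err:
+    # any other 4xx/5xx response; `body` is the decoded response body documented on APIError above
+    print(f"API error: {err.body}")
+```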
+ +## `ConflictError` Objects + +```python +class ConflictError(APIStatusError) +``` + +HTTP 409: The request conflicts with the current state of the resource. + +## `UnprocessableEntityError` Objects + +```python +class UnprocessableEntityError(APIStatusError) +``` + +HTTP 422: The request was well-formed but contains semantic errors. + +## `RateLimitError` Objects + +```python +class RateLimitError(APIStatusError) +``` + +HTTP 429: Too many requests, rate limit exceeded. + +## `InternalServerError` Objects + +```python +class InternalServerError(APIStatusError) +``` + +HTTP 500+: An error occurred on the server. + +## `RequestFailedError` Objects + +```python +class RequestFailedError(TinkerError) +``` + +Raised when an asynchronous request completes in a failed state. diff --git a/docs/api/restclient.md b/docs/api/restclient.md new file mode 100644 index 0000000..d50723c --- /dev/null +++ b/docs/api/restclient.md @@ -0,0 +1,556 @@ +# `tinker.lib.public_interfaces.rest_client` + +RestClient for Tinker API REST operations. + +## `RestClient` Objects + +```python +class RestClient(TelemetryProvider) +``` + +Client for REST API operations like listing checkpoints and metadata. + +The RestClient provides access to various REST endpoints for querying +model information, checkpoints, and other resources. You typically get one +by calling `service_client.create_rest_client()`. + +Key methods: +- list_checkpoints() - list available model checkpoints (both training and sampler) +- list_user_checkpoints() - list all checkpoints across all user's training runs +- get_training_run() - get model information and metadata as ModelEntry +- delete_checkpoint() - delete an existing checkpoint for a training run +- get_checkpoint_archive_url() - get signed URL to download checkpoint archive +- publish_checkpoint_from_tinker_path() - publish a checkpoint to make it public +- unpublish_checkpoint_from_tinker_path() - unpublish a checkpoint to make it private + +Args: + holder: Internal client managing HTTP connections and async operations + +Example: +```python +rest_client = service_client.create_rest_client() +training_run = rest_client.get_training_run("run-id").result() +print(f"Training Run: {training_run.training_run_id}, LoRA: {training_run.is_lora}") +checkpoints = rest_client.list_checkpoints("run-id").result() +print(f"Found {len(checkpoints.checkpoints)} checkpoints") +for checkpoint in checkpoints.checkpoints: + print(f" {checkpoint.checkpoint_type}: {checkpoint.checkpoint_id}") +``` + +#### `get_training_run` + +```python +@sync_only +@capture_exceptions(fatal=True) +def get_training_run( + training_run_id: types.ModelID) -> ConcurrentFuture[types.TrainingRun] +``` + +Get training run info. + +Args: + training_run_id: The training run ID to get information for + +Returns: + A Future containing the training run information + +Example: +```python +future = rest_client.get_training_run("run-id") +response = future.result() +print(f"Training Run ID: {response.training_run_id}, Base: {response.base_model}") +``` + +#### `get_training_run_async` + +```python +@capture_exceptions(fatal=True) +async def get_training_run_async( + training_run_id: types.ModelID) -> types.TrainingRun +``` + +Async version of get_training_run. + +#### `get_training_run_by_tinker_path` + +```python +@sync_only +@capture_exceptions(fatal=True) +def get_training_run_by_tinker_path( + tinker_path: str) -> ConcurrentFuture[types.TrainingRun] +``` + +Get training run info. 
+ +Args: + tinker_path: The tinker path to the checkpoint + +Returns: + A Future containing the training run information + +Example: +```python +future = rest_client.get_training_run_by_tinker_path("tinker://run-id/weights/checkpoint-001") +response = future.result() +print(f"Training Run ID: {response.training_run_id}, Base: {response.base_model}") +``` + +#### `get_training_run_by_tinker_path_async` + +```python +@capture_exceptions(fatal=True) +async def get_training_run_by_tinker_path_async( + tinker_path: str) -> types.TrainingRun +``` + +Async version of get_training_run_by_tinker_path. + +#### `get_weights_info_by_tinker_path` + +```python +@capture_exceptions(fatal=True) +def get_weights_info_by_tinker_path( + tinker_path: str) -> APIFuture[types.WeightsInfoResponse] +``` + +Get checkpoint information from a tinker path. + +Args: + tinker_path: The tinker path to the checkpoint + +Returns: + An APIFuture containing the checkpoint information. The future is awaitable. + +Example: +```python +future = rest_client.get_weights_info_by_tinker_path("tinker://run-id/weights/checkpoint-001") +response = future.result() # or await future +print(f"Base Model: {response.base_model}, LoRA Rank: {response.lora_rank}") +``` + +#### `list_training_runs` + +```python +@sync_only +@capture_exceptions(fatal=True) +def list_training_runs( + limit: int = 20, + offset: int = 0) -> ConcurrentFuture[types.TrainingRunsResponse] +``` + +List training runs with pagination support. + +Args: + limit: Maximum number of training runs to return (default 20) + offset: Offset for pagination (default 0) + +Returns: + A Future containing the TrainingRunsResponse with training runs and cursor info + +Example: +```python +future = rest_client.list_training_runs(limit=50) +response = future.result() +print(f"Found {len(response.training_runs)} training runs") +print(f"Total: {response.cursor.total_count}") +# Get next page +next_page = rest_client.list_training_runs(limit=50, offset=50) +``` + +#### `list_training_runs_async` + +```python +@capture_exceptions(fatal=True) +async def list_training_runs_async(limit: int = 20, + offset: int = 0 + ) -> types.TrainingRunsResponse +``` + +Async version of list_training_runs. + +#### `list_checkpoints` + +```python +@sync_only +@capture_exceptions(fatal=True) +def list_checkpoints( + training_run_id: types.ModelID +) -> ConcurrentFuture[types.CheckpointsListResponse] +``` + +List available checkpoints (both training and sampler). + +Args: + training_run_id: The training run ID to list checkpoints for + +Returns: + A Future containing the CheckpointsListResponse with available checkpoints + +Example: +```python +future = rest_client.list_checkpoints("run-id") +response = future.result() +for checkpoint in response.checkpoints: + if checkpoint.checkpoint_type == "training": + print(f"Training checkpoint: {checkpoint.checkpoint_id}") + elif checkpoint.checkpoint_type == "sampler": + print(f"Sampler checkpoint: {checkpoint.checkpoint_id}") +``` + +#### `list_checkpoints_async` + +```python +@capture_exceptions(fatal=True) +async def list_checkpoints_async( + training_run_id: types.ModelID) -> types.CheckpointsListResponse +``` + +Async version of list_checkpoints. + +#### `get_checkpoint_archive_url` + +```python +@sync_only +@capture_exceptions(fatal=True) +def get_checkpoint_archive_url( + training_run_id: types.ModelID, checkpoint_id: str +) -> ConcurrentFuture[types.CheckpointArchiveUrlResponse] +``` + +Get signed URL to download checkpoint archive. 
+ +Args: + training_run_id: The training run ID to download weights for + checkpoint_id: The checkpoint ID to download + +Returns: + A Future containing the CheckpointArchiveUrlResponse with signed URL and expiration + +Example: +```python +future = rest_client.get_checkpoint_archive_url("run-id", "checkpoint-123") +response = future.result() +print(f"Download URL: {response.url}") +print(f"Expires at: {response.expires_at}") +# Use the URL to download the archive with your preferred HTTP client +``` + +#### `get_checkpoint_archive_url_async` + +```python +@capture_exceptions(fatal=True) +async def get_checkpoint_archive_url_async( + training_run_id: types.ModelID, + checkpoint_id: str) -> types.CheckpointArchiveUrlResponse +``` + +Async version of get_checkpoint_archive_url. + +#### `delete_checkpoint` + +```python +@sync_only +@capture_exceptions(fatal=True) +def delete_checkpoint(training_run_id: types.ModelID, + checkpoint_id: str) -> ConcurrentFuture[None] +``` + +Delete a checkpoint for a training run. + +#### `delete_checkpoint_async` + +```python +@capture_exceptions(fatal=True) +async def delete_checkpoint_async(training_run_id: types.ModelID, + checkpoint_id: str) -> None +``` + +Async version of delete_checkpoint. + +#### `delete_checkpoint_from_tinker_path` + +```python +@sync_only +@capture_exceptions(fatal=True) +def delete_checkpoint_from_tinker_path( + tinker_path: str) -> ConcurrentFuture[None] +``` + +Delete a checkpoint referenced by a tinker path. + +#### `delete_checkpoint_from_tinker_path_async` + +```python +@capture_exceptions(fatal=True) +async def delete_checkpoint_from_tinker_path_async(tinker_path: str) -> None +``` + +Async version of delete_checkpoint_from_tinker_path. + +#### `get_checkpoint_archive_url_from_tinker_path` + +```python +@sync_only +@capture_exceptions(fatal=True) +def get_checkpoint_archive_url_from_tinker_path( + tinker_path: str +) -> ConcurrentFuture[types.CheckpointArchiveUrlResponse] +``` + +Get signed URL to download checkpoint archive. + +Args: + tinker_path: The tinker path to the checkpoint + +Returns: + A Future containing the CheckpointArchiveUrlResponse with signed URL and expiration + +#### `get_checkpoint_archive_url_from_tinker_path_async` + +```python +@capture_exceptions(fatal=True) +async def get_checkpoint_archive_url_from_tinker_path_async( + tinker_path: str) -> types.CheckpointArchiveUrlResponse +``` + +Async version of get_checkpoint_archive_url_from_tinker_path. + +#### `publish_checkpoint_from_tinker_path` + +```python +@sync_only +@capture_exceptions(fatal=True) +def publish_checkpoint_from_tinker_path( + tinker_path: str) -> ConcurrentFuture[None] +``` + +Publish a checkpoint referenced by a tinker path to make it publicly accessible. + +Only the exact owner of the training run can publish checkpoints. +Published checkpoints can be unpublished using the unpublish_checkpoint_from_tinker_path method. 
+ +Args: + tinker_path: The tinker path to the checkpoint (e.g., "tinker://run-id/weights/0001") + +Returns: + A Future that completes when the checkpoint is published + +Raises: + HTTPException: 400 if checkpoint identifier is invalid + HTTPException: 404 if checkpoint not found or user doesn't own the training run + HTTPException: 409 if checkpoint is already public + HTTPException: 500 if there's an error publishing the checkpoint + +Example: +```python +future = rest_client.publish_checkpoint_from_tinker_path("tinker://run-id/weights/0001") +future.result() # Wait for completion +print("Checkpoint published successfully") +``` + +#### `publish_checkpoint_from_tinker_path_async` + +```python +@capture_exceptions(fatal=True) +async def publish_checkpoint_from_tinker_path_async(tinker_path: str) -> None +``` + +Async version of publish_checkpoint_from_tinker_path. + +#### `unpublish_checkpoint_from_tinker_path` + +```python +@sync_only +@capture_exceptions(fatal=True) +def unpublish_checkpoint_from_tinker_path( + tinker_path: str) -> ConcurrentFuture[None] +``` + +Unpublish a checkpoint referenced by a tinker path to make it private again. + +Only the exact owner of the training run can unpublish checkpoints. +This reverses the effect of publishing a checkpoint. + +Args: + tinker_path: The tinker path to the checkpoint (e.g., "tinker://run-id/weights/0001") + +Returns: + A Future that completes when the checkpoint is unpublished + +Raises: + HTTPException: 400 if checkpoint identifier is invalid + HTTPException: 404 if checkpoint not found or user doesn't own the training run + HTTPException: 409 if checkpoint is already private + HTTPException: 500 if there's an error unpublishing the checkpoint + +Example: +```python +future = rest_client.unpublish_checkpoint_from_tinker_path("tinker://run-id/weights/0001") +future.result() # Wait for completion +print("Checkpoint unpublished successfully") +``` + +#### `unpublish_checkpoint_from_tinker_path_async` + +```python +@capture_exceptions(fatal=True) +async def unpublish_checkpoint_from_tinker_path_async( + tinker_path: str) -> None +``` + +Async version of unpublish_checkpoint_from_tinker_path. + +#### `list_user_checkpoints` + +```python +@sync_only +@capture_exceptions(fatal=True) +def list_user_checkpoints( + limit: int = 100, + offset: int = 0) -> ConcurrentFuture[types.CheckpointsListResponse] +``` + +List all checkpoints for the current user across all their training runs. + +This method retrieves checkpoints from all training runs owned by the authenticated user, +sorted by time (newest first). It supports pagination for efficiently handling large +numbers of checkpoints. 
+ +Args: + limit: Maximum number of checkpoints to return (default 100) + offset: Offset for pagination (default 0) + +Returns: + A Future containing the CheckpointsListResponse with checkpoints and cursor info + +Example: +```python +future = rest_client.list_user_checkpoints(limit=50) +response = future.result() +print(f"Found {len(response.checkpoints)} checkpoints") +print(f"Total: {response.cursor.total_count if response.cursor else 'Unknown'}") +for checkpoint in response.checkpoints: + print(f" {checkpoint.training_run_id}/{checkpoint.checkpoint_id}") +# Get next page if there are more checkpoints +if response.cursor and response.cursor.offset + response.cursor.limit < response.cursor.total_count: + next_page = rest_client.list_user_checkpoints(limit=50, offset=50) +``` + +#### `list_user_checkpoints_async` + +```python +@capture_exceptions(fatal=True) +async def list_user_checkpoints_async(limit: int = 100, + offset: int = 0 + ) -> types.CheckpointsListResponse +``` + +Async version of list_user_checkpoints. + +#### `get_session` + +```python +@sync_only +@capture_exceptions(fatal=True) +def get_session(session_id: str) -> ConcurrentFuture[types.GetSessionResponse] +``` + +Get session information including all training runs and samplers. + +Args: + session_id: The session ID to get information for + +Returns: + A Future containing the GetSessionResponse with training_run_ids and sampler_ids + +Example: +```python +future = rest_client.get_session("session-id") +response = future.result() +print(f"Training runs: {len(response.training_run_ids)}") +print(f"Samplers: {len(response.sampler_ids)}") +``` + +#### `get_session_async` + +```python +@capture_exceptions(fatal=True) +async def get_session_async(session_id: str) -> types.GetSessionResponse +``` + +Async version of get_session. + +#### `list_sessions` + +```python +@sync_only +@capture_exceptions(fatal=True) +def list_sessions( + limit: int = 20, + offset: int = 0) -> ConcurrentFuture[types.ListSessionsResponse] +``` + +List sessions with pagination support. + +Args: + limit: Maximum number of sessions to return (default 20) + offset: Offset for pagination (default 0) + +Returns: + A Future containing the ListSessionsResponse with list of session IDs + +Example: +```python +future = rest_client.list_sessions(limit=50) +response = future.result() +print(f"Found {len(response.sessions)} sessions") +# Get next page +next_page = rest_client.list_sessions(limit=50, offset=50) +``` + +#### `list_sessions_async` + +```python +@capture_exceptions(fatal=True) +async def list_sessions_async(limit: int = 20, + offset: int = 0) -> types.ListSessionsResponse +``` + +Async version of list_sessions. + +#### `get_sampler` + +```python +@capture_exceptions(fatal=True) +def get_sampler(sampler_id: str) -> APIFuture[types.GetSamplerResponse] +``` + +Get sampler information. 
+ +Args: + sampler_id: The sampler ID (sampling_session_id) to get information for + +Returns: + An APIFuture containing the GetSamplerResponse with sampler details + +Example: +```python +# Sync usage +future = rest_client.get_sampler("session-id:sample:0") +response = future.result() +print(f"Base model: {response.base_model}") +print(f"Model path: {response.model_path}") + +# Async usage +response = await rest_client.get_sampler("session-id:sample:0") +print(f"Base model: {response.base_model}") +``` + +#### `get_sampler_async` + +```python +@capture_exceptions(fatal=True) +async def get_sampler_async(sampler_id: str) -> types.GetSamplerResponse +``` + +Async version of get_sampler. diff --git a/docs/api/samplingclient.md b/docs/api/samplingclient.md new file mode 100644 index 0000000..95f95a3 --- /dev/null +++ b/docs/api/samplingclient.md @@ -0,0 +1,116 @@ +# `tinker.lib.public_interfaces.sampling_client` + +SamplingClient for Tinker API. + +## `SamplingClient` Objects + +```python +class SamplingClient(TelemetryProvider, QueueStateObserver) +``` + +Client for text generation and inference from trained or base models. + +The SamplingClient lets you generate text tokens from either a base model or from weights +you've saved using a TrainingClient. You typically get one by calling +`service_client.create_sampling_client()` or `training_client.save_weights_and_get_sampling_client()`. +Key methods: +- sample() - generate text completions with customizable parameters +- compute_logprobs() - get log probabilities for prompt tokens + +Args: + holder: Internal client managing HTTP connections and async operations + model_path: Path to saved model weights (starts with 'tinker://') + base_model: Name of base model to use for inference + retry_config: Configuration for retrying failed requests + +Example: +```python +sampling_client = service_client.create_sampling_client(base_model="Qwen/Qwen2.5-7B") +prompt = types.ModelInput.from_ints(tokenizer.encode("The weather today is")) +params = types.SamplingParams(max_tokens=20, temperature=0.7) +future = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1) +result = future.result() +``` + +#### `sample` + +```python +@capture_exceptions(fatal=True) +def sample( + prompt: types.ModelInput, + num_samples: int, + sampling_params: types.SamplingParams, + include_prompt_logprobs: bool = False, + topk_prompt_logprobs: int = 0 +) -> ConcurrentFuture[types.SampleResponse] +``` + +Generate text completions from the model. + +Args: + prompt: The input tokens as ModelInput + num_samples: Number of independent samples to generate + sampling_params: Parameters controlling generation (temperature, max_tokens, etc.) 
+ include_prompt_logprobs: Whether to include log probabilities for prompt tokens + topk_prompt_logprobs: Number of top token log probabilities to return per position + +Returns: + A Future containing the SampleResponse with generated text + +Example: +```python +prompt = types.ModelInput.from_ints(tokenizer.encode("The weather today is")) +params = types.SamplingParams(max_tokens=20, temperature=0.7) +future = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1) +result = future.result() +for sample in result.samples: + print(tokenizer.decode(sample.tokens)) +``` + +#### `sample_async` + +```python +async def sample_async(prompt: types.ModelInput, + num_samples: int, + sampling_params: types.SamplingParams, + include_prompt_logprobs: bool = False, + topk_prompt_logprobs: int = 0) -> types.SampleResponse +``` + +Async version of sample. + +#### `compute_logprobs` + +```python +@capture_exceptions(fatal=True) +def compute_logprobs( + prompt: types.ModelInput) -> ConcurrentFuture[list[float | None]] +``` + +Compute log probabilities for prompt tokens. + +Args: + prompt: The input tokens as ModelInput + +Returns: + A Future containing a list of log probabilities for each token in the prompt. + None values indicate tokens where log probabilities couldn't be computed. + +Example: +```python +prompt = types.ModelInput.from_ints(tokenizer.encode("Hello world")) +future = sampling_client.compute_logprobs(prompt) +logprobs = future.result() +for i, logprob in enumerate(logprobs): + if logprob is not None: + print(f"Token {i}: logprob = {logprob:.4f}") +``` + +#### `compute_logprobs_async` + +```python +async def compute_logprobs_async( + prompt: types.ModelInput) -> list[float | None] +``` + +Async version of compute_logprobs. diff --git a/docs/api/serviceclient.md b/docs/api/serviceclient.md new file mode 100644 index 0000000..cd0a9e4 --- /dev/null +++ b/docs/api/serviceclient.md @@ -0,0 +1,235 @@ +# `tinker.lib.public_interfaces.service_client` + +ServiceClient for Tinker API. + +## `ServiceClient` Objects + +```python +class ServiceClient(TelemetryProvider) +``` + +The ServiceClient is the main entry point for the Tinker API. It provides methods to: +- Query server capabilities and health status +- Generate TrainingClient instances for model training workflows +- Generate SamplingClient instances for text generation and inference +- Generate RestClient instances for REST API operations like listing weights + +Args: + **kwargs: advanced options passed to the underlying HTTP client, + including API keys, headers, and connection settings. + +Example: +```python +client = ServiceClient() +# ^^^ near-instant +training_client = client.create_lora_training_client(base_model="Qwen/Qwen3-8B") +# ^^^ takes a moment as we initialize the model and assign resources +sampling_client = client.create_sampling_client(base_model="Qwen/Qwen3-8B") +# ^^^ near-instant +rest_client = client.create_rest_client() +# ^^^ near-instant +``` + +#### `get_server_capabilities` + +```python +@sync_only +@capture_exceptions(fatal=True) +def get_server_capabilities() -> types.GetServerCapabilitiesResponse +``` + +Query the server's supported features and capabilities. 
+ +Returns: + GetServerCapabilitiesResponse with available models, features, and limits + +Example: +```python +capabilities = service_client.get_server_capabilities() +print(f"Supported models: {capabilities.supported_models}") +print(f"Max batch size: {capabilities.max_batch_size}") +``` + +#### `get_server_capabilities_async` + +```python +@capture_exceptions(fatal=True) +async def get_server_capabilities_async( +) -> types.GetServerCapabilitiesResponse +``` + +Async version of get_server_capabilities. + +#### `create_lora_training_client` + +```python +@sync_only +@capture_exceptions(fatal=True) +def create_lora_training_client( + base_model: str, + rank: int = 32, + seed: int | None = None, + train_mlp: bool = True, + train_attn: bool = True, + train_unembed: bool = True, + user_metadata: dict[str, str] | None = None) -> TrainingClient +``` + +Create a TrainingClient for LoRA fine-tuning. + +Args: + base_model: Name of the base model to fine-tune (e.g., "Qwen/Qwen2.5-7B") + rank: LoRA rank controlling the size of adaptation matrices (default 32) + seed: Random seed for initialization. None means random seed. + train_mlp: Whether to train MLP layers (default True) + train_attn: Whether to train attention layers (default True) + train_unembed: Whether to train unembedding layers (default True) + user_metadata: Optional metadata to attach to the training run + +Returns: + TrainingClient configured for LoRA training + +Example: +```python +training_client = service_client.create_lora_training_client( + base_model="Qwen/Qwen2.5-7B", + rank=16, + train_mlp=True, + train_attn=True +) +# Now use training_client.forward_backward() to train +``` + +#### `create_lora_training_client_async` + +```python +@capture_exceptions(fatal=True) +async def create_lora_training_client_async( + base_model: str, + rank: int = 32, + seed: int | None = None, + train_mlp: bool = True, + train_attn: bool = True, + train_unembed: bool = True, + user_metadata: dict[str, str] | None = None) -> TrainingClient +``` + +Async version of create_lora_training_client. + +#### `create_training_client_from_state` + +```python +@sync_only +@capture_exceptions(fatal=True) +def create_training_client_from_state( + path: str, + user_metadata: dict[str, str] | None = None) -> TrainingClient +``` + +Create a TrainingClient from saved model weights. + +Args: + path: Tinker path to saved weights (e.g., "tinker://run-id/weights/checkpoint-001") + user_metadata: Optional metadata to attach to the new training run + +Returns: + TrainingClient loaded with the specified weights + +Example: +```python +# Resume training from a checkpoint +training_client = service_client.create_training_client_from_state( + "tinker://run-id/weights/checkpoint-001" +) +# Continue training from the loaded state +``` + +#### `create_training_client_from_state_async` + +```python +@capture_exceptions(fatal=True) +async def create_training_client_from_state_async( + path: str, + user_metadata: dict[str, str] | None = None) -> TrainingClient +``` + +Async version of create_training_client_from_state. + +#### `create_sampling_client` + +```python +@capture_exceptions(fatal=True) +def create_sampling_client( + model_path: str | None = None, + base_model: str | None = None, + retry_config: RetryConfig | None = None) -> SamplingClient +``` + +Create a SamplingClient for text generation. 
+ +Args: + model_path: Path to saved model weights (e.g., "tinker://run-id/weights/checkpoint-001") + base_model: Name of base model to use (e.g., "Qwen/Qwen2.5-7B") + retry_config: Optional configuration for retrying failed requests + +Returns: + SamplingClient configured for text generation + +Raises: + ValueError: If neither model_path nor base_model is provided + +Example: +```python +# Use a base model +sampling_client = service_client.create_sampling_client( + base_model="Qwen/Qwen2.5-7B" +) + +# Or use saved weights +sampling_client = service_client.create_sampling_client( + model_path="tinker://run-id/weights/checkpoint-001" +) +``` + +#### `create_sampling_client_async` + +```python +@capture_exceptions(fatal=True) +async def create_sampling_client_async( + model_path: str | None = None, + base_model: str | None = None, + retry_config: RetryConfig | None = None) -> SamplingClient +``` + +Async version of create_sampling_client. + +#### `create_rest_client` + +```python +@capture_exceptions(fatal=True) +def create_rest_client() -> RestClient +``` + +Create a RestClient for REST API operations. + +The RestClient provides access to various REST endpoints for querying +model information, checkpoints, sessions, and managing checkpoint visibility. + +Returns: + RestClient for accessing REST API endpoints + +Example: +```python +rest_client = service_client.create_rest_client() + +# List checkpoints for a training run +checkpoints = rest_client.list_checkpoints("run-id").result() + +# Get training run info +training_run = rest_client.get_training_run("run-id").result() + +# Publish a checkpoint +rest_client.publish_checkpoint_from_tinker_path( + "tinker://run-id/weights/checkpoint-001" +).result() +``` diff --git a/docs/api/trainingclient.md b/docs/api/trainingclient.md new file mode 100644 index 0000000..74a0b7d --- /dev/null +++ b/docs/api/trainingclient.md @@ -0,0 +1,480 @@ +# `tinker.lib.public_interfaces.training_client` + +TrainingClient for Tinker API. + +## `TrainingClient` Objects + +```python +class TrainingClient(TelemetryProvider, QueueStateObserver) +``` + +Client for training ML models with forward/backward passes and optimization. + +The TrainingClient corresponds to a fine-tuned model that you can train and sample from. +You typically get one by calling `service_client.create_lora_training_client()`. +Key methods: +- forward_backward() - compute gradients for training +- optim_step() - update model parameters with Adam optimizer +- save_weights_and_get_sampling_client() - export trained model for inference + +Args: + holder: Internal client managing HTTP connections and async operations + model_id: Unique identifier for the model to train. Required for training operations. + +Example: +```python +training_client = service_client.create_lora_training_client(base_model="Qwen/Qwen2.5-7B") +fwdbwd_future = training_client.forward_backward(training_data, "cross_entropy") +optim_future = training_client.optim_step(types.AdamParams(learning_rate=1e-4)) +fwdbwd_result = fwdbwd_future.result() # Wait for gradients +optim_result = optim_future.result() # Wait for parameter update +sampling_client = training_client.save_weights_and_get_sampling_client("my-model") +``` + +#### `forward` + +```python +@capture_exceptions(fatal=True) +def forward( + data: List[types.Datum], + loss_fn: types.LossFnType, + loss_fn_config: Dict[str, float] | None = None +) -> APIFuture[types.ForwardBackwardOutput] +``` + +Compute forward pass without gradients. 
+ +Args: + data: List of training data samples + loss_fn: Loss function type (e.g., "cross_entropy") + loss_fn_config: Optional configuration for the loss function + +Returns: + APIFuture containing the forward pass outputs and loss + +Example: +```python +data = [types.Datum( + model_input=types.ModelInput.from_ints(tokenizer.encode("Hello")), + loss_fn_inputs={"target_tokens": types.ModelInput.from_ints(tokenizer.encode("world"))} +)] +future = training_client.forward(data, "cross_entropy") +result = await future +print(f"Loss: {result.loss}") +``` + +#### `forward_async` + +```python +async def forward_async( + data: List[types.Datum], + loss_fn: types.LossFnType, + loss_fn_config: Dict[str, float] | None = None +) -> APIFuture[types.ForwardBackwardOutput] +``` + +Async version of forward. + +#### `forward_backward` + +```python +@capture_exceptions(fatal=True) +def forward_backward( + data: List[types.Datum], + loss_fn: types.LossFnType, + loss_fn_config: Dict[str, float] | None = None +) -> APIFuture[types.ForwardBackwardOutput] +``` + +Compute forward pass and backward pass to calculate gradients. + +Args: + data: List of training data samples + loss_fn: Loss function type (e.g., "cross_entropy") + loss_fn_config: Optional configuration for the loss function + +Returns: + APIFuture containing the forward/backward outputs, loss, and gradients + +Example: +```python +data = [types.Datum( + model_input=types.ModelInput.from_ints(tokenizer.encode("Hello")), + loss_fn_inputs={"target_tokens": types.ModelInput.from_ints(tokenizer.encode("world"))} +)] + +# Compute gradients +fwdbwd_future = training_client.forward_backward(data, "cross_entropy") + +# Update parameters +optim_future = training_client.optim_step( + types.AdamParams(learning_rate=1e-4) +) + +fwdbwd_result = await fwdbwd_future +print(f"Loss: {fwdbwd_result.loss}") +``` + +#### `forward_backward_async` + +```python +async def forward_backward_async( + data: List[types.Datum], + loss_fn: types.LossFnType, + loss_fn_config: Dict[str, float] | None = None +) -> APIFuture[types.ForwardBackwardOutput] +``` + +Async version of forward_backward. + +#### `forward_backward_custom` + +```python +@sync_only +@capture_exceptions(fatal=True) +def forward_backward_custom( + data: List[types.Datum], + loss_fn: CustomLossFnV1) -> APIFuture[types.ForwardBackwardOutput] +``` + +Compute forward/backward with a custom loss function. + +Allows you to define custom loss functions that operate on log probabilities. +The custom function receives logprobs and computes loss and gradients. + +Args: + data: List of training data samples + loss_fn: Custom loss function that takes (data, logprobs) and returns (loss, metrics) + +Returns: + APIFuture containing the forward/backward outputs with custom loss + +Example: +```python +def custom_loss(data, logprobs_list): + # Custom loss computation + loss = torch.mean(torch.stack([torch.mean(lp) for lp in logprobs_list])) + metrics = {"custom_metric": loss.item()} + return loss, metrics + +future = training_client.forward_backward_custom(data, custom_loss) +result = future.result() +print(f"Custom loss: {result.loss}") +print(f"Metrics: {result.metrics}") +``` + +#### `forward_backward_custom_async` + +```python +@capture_exceptions(fatal=True) +async def forward_backward_custom_async( + data: List[types.Datum], + loss_fn: CustomLossFnV1) -> APIFuture[types.ForwardBackwardOutput] +``` + +Async version of forward_backward_custom. 
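+
+A hedged sketch of the async variant, reusing the toy `custom_loss` from the sync example above. Per the annotated return type, the coroutine itself resolves to an `APIFuture`, which is then awaited for the output; treat the double await as an assumption drawn from that signature rather than a guarantee:
+
+```python
+import torch
+
+def custom_loss(data, logprobs_list):
+    # same toy loss as in the sync example above
+    loss = torch.mean(torch.stack([torch.mean(lp) for lp in logprobs_list]))
+    return loss, {"custom_metric": loss.item()}
+
+async def train_step(training_client, data):
+    # the coroutine returns an APIFuture; awaiting it again yields the ForwardBackwardOutput
+    future = await training_client.forward_backward_custom_async(data, custom_loss)
+    result = await future
+    print(f"Custom loss: {result.loss}")
+```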
+ +#### `optim_step` + +```python +@capture_exceptions(fatal=True) +def optim_step( + adam_params: types.AdamParams) -> APIFuture[types.OptimStepResponse] +``` + +Update model parameters using Adam optimizer. + +Args: + adam_params: Adam optimizer parameters (learning_rate, betas, eps, weight_decay) + +Returns: + APIFuture containing optimizer step response + +Example: +```python +# First compute gradients +fwdbwd_future = training_client.forward_backward(data, "cross_entropy") + +# Then update parameters +optim_future = training_client.optim_step( + types.AdamParams( + learning_rate=1e-4, + weight_decay=0.01 + ) +) + +# Wait for both to complete +fwdbwd_result = await fwdbwd_future +optim_result = await optim_future +``` + +#### `optim_step_async` + +```python +async def optim_step_async( + adam_params: types.AdamParams) -> APIFuture[types.OptimStepResponse] +``` + +Async version of optim_step. + +#### `save_state` + +```python +@capture_exceptions(fatal=True) +def save_state(name: str) -> APIFuture[types.SaveWeightsResponse] +``` + +Save model weights to persistent storage. + +Args: + name: Name for the saved checkpoint + +Returns: + APIFuture containing the save response with checkpoint path + +Example: +```python +# Save after training +save_future = training_client.save_state("checkpoint-001") +result = await save_future +print(f"Saved to: {result.path}") +``` + +#### `save_state_async` + +```python +async def save_state_async(name: str) -> APIFuture[types.SaveWeightsResponse] +``` + +Async version of save_state. + +#### `load_state` + +```python +@capture_exceptions(fatal=True) +def load_state(path: str) -> APIFuture[types.LoadWeightsResponse] +``` + +Load model weights from a saved checkpoint. + +Args: + path: Tinker path to saved weights (e.g., "tinker://run-id/weights/checkpoint-001") + +Returns: + APIFuture containing the load response + +Example: +```python +# Load checkpoint to continue training +load_future = training_client.load_state("tinker://run-id/weights/checkpoint-001") +await load_future +# Continue training from loaded state +``` + +#### `load_state_async` + +```python +async def load_state_async(path: str) -> APIFuture[types.LoadWeightsResponse] +``` + +Async version of load_state. + +#### `load_state_with_optimizer` + +```python +@capture_exceptions(fatal=True) +def load_state_with_optimizer( + path: str) -> APIFuture[types.LoadWeightsResponse] +``` + +Load model weights and optimizer state from a checkpoint. + +Args: + path: Tinker path to saved weights (e.g., "tinker://run-id/weights/checkpoint-001") + +Returns: + APIFuture containing the load response + +Example: +```python +# Resume training with optimizer state +load_future = training_client.load_state_with_optimizer( + "tinker://run-id/weights/checkpoint-001" +) +await load_future +# Continue training with restored optimizer momentum +``` + +#### `load_state_with_optimizer_async` + +```python +async def load_state_with_optimizer_async( + path: str) -> APIFuture[types.LoadWeightsResponse] +``` + +Async version of load_state_with_optimizer. + +#### `save_weights_for_sampler` + +```python +@capture_exceptions(fatal=True) +def save_weights_for_sampler( + name: str) -> APIFuture[types.SaveWeightsForSamplerResponse] +``` + +Save model weights for use with a SamplingClient. 
+ +Args: + name: Name for the saved sampler weights + +Returns: + APIFuture containing the save response with sampler path + +Example: +```python +# Save weights for inference +save_future = training_client.save_weights_for_sampler("sampler-001") +result = await save_future +print(f"Sampler weights saved to: {result.path}") + +# Use the path to create a sampling client +sampling_client = service_client.create_sampling_client( + model_path=result.path +) +``` + +#### `save_weights_for_sampler_async` + +```python +async def save_weights_for_sampler_async( + name: str) -> APIFuture[types.SaveWeightsForSamplerResponse] +``` + +Async version of save_weights_for_sampler. + +#### `get_info` + +```python +@sync_only +@capture_exceptions(fatal=True) +def get_info() -> types.GetInfoResponse +``` + +Get information about the current model. + +Returns: + GetInfoResponse with model configuration and metadata + +Example: +```python +info = training_client.get_info() +print(f"Model ID: {info.model_data.model_id}") +print(f"Base model: {info.model_data.model_name}") +print(f"LoRA rank: {info.model_data.lora_rank}") +``` + +#### `get_info_async` + +```python +@capture_exceptions(fatal=True) +async def get_info_async() -> types.GetInfoResponse +``` + +Async version of get_info. + +#### `get_tokenizer` + +```python +@cache +@capture_exceptions(fatal=True) +def get_tokenizer() -> PreTrainedTokenizer +``` + +Get the tokenizer for the current model. + +Returns: + PreTrainedTokenizer compatible with the model + +Example: +```python +tokenizer = training_client.get_tokenizer() +tokens = tokenizer.encode("Hello world") +text = tokenizer.decode(tokens) +``` + +#### `create_sampling_client` + +```python +@capture_exceptions(fatal=True) +def create_sampling_client( + model_path: str, + retry_config: RetryConfig | None = None) -> SamplingClient +``` + +Create a SamplingClient from saved weights. + +Args: + model_path: Tinker path to saved weights + retry_config: Optional configuration for retrying failed requests + +Returns: + SamplingClient configured with the specified weights + +Example: +```python +sampling_client = training_client.create_sampling_client( + "tinker://run-id/weights/checkpoint-001" +) +# Use sampling_client for inference +``` + +#### `create_sampling_client_async` + +```python +@capture_exceptions(fatal=True) +async def create_sampling_client_async( + model_path: str, + retry_config: RetryConfig | None = None) -> SamplingClient +``` + +Async version of create_sampling_client. + +#### `save_weights_and_get_sampling_client` + +```python +@capture_exceptions(fatal=True) +def save_weights_and_get_sampling_client( + name: str | None = None, + retry_config: RetryConfig | None = None) -> SamplingClient +``` + +Save current weights and create a SamplingClient for inference. 
+ +Args: + name: Optional name for the saved weights (currently ignored for ephemeral saves) + retry_config: Optional configuration for retrying failed requests + +Returns: + SamplingClient configured with the current model weights + +Example: +```python +# After training, create a sampling client directly +sampling_client = training_client.save_weights_and_get_sampling_client() + +# Now use it for inference +prompt = types.ModelInput.from_ints(tokenizer.encode("Hello")) +params = types.SamplingParams(max_tokens=20) +result = sampling_client.sample(prompt, 1, params).result() +``` + +#### `save_weights_and_get_sampling_client_async` + +```python +@capture_exceptions(fatal=True) +async def save_weights_and_get_sampling_client_async( + name: str | None = None, + retry_config: RetryConfig | None = None) -> SamplingClient +``` + +Async version of save_weights_and_get_sampling_client. diff --git a/docs/api/types.md b/docs/api/types.md new file mode 100644 index 0000000..5aaf180 --- /dev/null +++ b/docs/api/types.md @@ -0,0 +1,911 @@ +# `tinker.types.optim_step_request` + +## `AdamParams` Objects + +```python +class AdamParams(StrictBase) +``` + +#### `learning_rate` + +Learning rate for the optimizer + +#### `beta1` + +Coefficient used for computing running averages of gradient + +#### `beta2` + +Coefficient used for computing running averages of gradient square + +#### `eps` + +Term added to the denominator to improve numerical stability + +# `tinker.types.optim_step_response` + +## `OptimStepResponse` Objects + +```python +class OptimStepResponse(BaseModel) +``` + +#### `metrics` + +Optimization step metrics as key-value pairs + +# `tinker.types.model_input` + +## `ModelInput` Objects + +```python +class ModelInput(StrictBase) +``` + +#### `chunks` + +Sequence of input chunks (formerly TokenSequence) + +#### `from_ints` + +```python +@classmethod +def from_ints(cls, tokens: List[int]) -> "ModelInput" +``` + +Create a ModelInput from a list of ints (tokens). + +#### `to_ints` + +```python +def to_ints() -> List[int] +``` + +Convert the ModelInput to a list of ints (tokens) +Throws exception if there are any non-token chunks + +#### `length` + +```python +@property +def length() -> int +``` + +Return the total context length used by this ModelInput. + +#### `empty` + +```python +@classmethod +def empty(cls) -> "ModelInput" +``` + +Create an empty ModelInput. + +#### `append` + +```python +def append(chunk: ModelInputChunk) -> "ModelInput" +``` + +Add a new chunk, return a new ModelInput. + +#### `append_int` + +```python +def append_int(token: int) -> "ModelInput" +``` + +Add a new token, return a new ModelInput. + +# `tinker.types.weights_info_response` + +## `WeightsInfoResponse` Objects + +```python +class WeightsInfoResponse(BaseModel) +``` + +Minimal information for loading public checkpoints. 
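+
+A hedged sketch of how this response might be used to inspect a public checkpoint before sampling from it (`rest_client` and `service_client` are assumed to have been created as shown on the other pages; the `base_model` and `lora_rank` fields follow the RestClient example for `get_weights_info_by_tinker_path`):
+
+```python
+path = "tinker://run-id/weights/checkpoint-001"
+
+# inspect the checkpoint's minimal loading info
+info = rest_client.get_weights_info_by_tinker_path(path).result()
+print(f"Base model: {info.base_model}, LoRA rank: {info.lora_rank}")
+
+# then sample from the same checkpoint
+sampling_client = service_client.create_sampling_client(model_path=path)
+```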
+ +# `tinker.types.checkpoint` + +## `Checkpoint` Objects + +```python +class Checkpoint(BaseModel) +``` + +#### `checkpoint_id` + +The checkpoint ID + +#### `checkpoint_type` + +The type of checkpoint (training or sampler) + +#### `time` + +The time when the checkpoint was created + +#### `tinker_path` + +The tinker path to the checkpoint + +#### `size_bytes` + +The size of the checkpoint in bytes + +#### `public` + +Whether the checkpoint is publicly accessible + +## `ParsedCheckpointTinkerPath` Objects + +```python +class ParsedCheckpointTinkerPath(BaseModel) +``` + +#### `tinker_path` + +The tinker path to the checkpoint + +#### `training_run_id` + +The training run ID + +#### `checkpoint_type` + +The type of checkpoint (training or sampler) + +#### `checkpoint_id` + +The checkpoint ID + +#### `from_tinker_path` + +```python +@classmethod +def from_tinker_path(cls, tinker_path: str) -> "ParsedCheckpointTinkerPath" +``` + +Parse a tinker path to an instance of ParsedCheckpointTinkerPath + +# `tinker.types.checkpoint_archive_url_response` + +## `CheckpointArchiveUrlResponse` Objects + +```python +class CheckpointArchiveUrlResponse(BaseModel) +``` + +#### `url` + +Signed URL to download the checkpoint archive + +#### `expires` + +Unix timestamp when the signed URL expires, if available + +# `tinker.types.sampled_sequence` + +## `SampledSequence` Objects + +```python +class SampledSequence(BaseModel) +``` + +#### `stop_reason` + +Reason why sampling stopped + +#### `tokens` + +List of generated token IDs + +#### `logprobs` + +Log probabilities for each token (optional) + +# `tinker.types.try_again_response` + +## `TryAgainResponse` Objects + +```python +class TryAgainResponse(BaseModel) +``` + +#### `request_id` + +Request ID that is still pending + +# `tinker.types.load_weights_request` + +## `LoadWeightsRequest` Objects + +```python +class LoadWeightsRequest(StrictBase) +``` + +#### `path` + +A tinker URI for model weights at a specific step + +#### `optimizer` + +Whether to load optimizer state along with model weights + +# `tinker.types.telemetry_send_request` + +## `TelemetrySendRequest` Objects + +```python +class TelemetrySendRequest(StrictBase) +``` + +#### `platform` + +Host platform name + +#### `sdk_version` + +SDK version string + +# `tinker.types.image_asset_pointer_chunk` + +## `ImageAssetPointerChunk` Objects + +```python +class ImageAssetPointerChunk(StrictBase) +``` + +#### `format` + +Image format + +#### `height` + +Image height in pixels + +#### `location` + +Path or URL to the image asset + +#### `tokens` + +Number of tokens this image represents + +#### `width` + +Image width in pixels + +# `tinker.types.checkpoints_list_response` + +## `CheckpointsListResponse` Objects + +```python +class CheckpointsListResponse(BaseModel) +``` + +#### `checkpoints` + +List of available model checkpoints for the model + +#### `cursor` + +Pagination cursor information (None for unpaginated responses) + +# `tinker.types.generic_event` + +## `GenericEvent` Objects + +```python +class GenericEvent(BaseModel) +``` + +#### `event` + +Telemetry event type + +#### `event_name` + +Low-cardinality event name + +#### `severity` + +Log severity level + +#### `event_data` + +Arbitrary structured JSON payload + +# `tinker.types.encoded_text_chunk` + +## `EncodedTextChunk` Objects + +```python +class EncodedTextChunk(StrictBase) +``` + +#### `tokens` + +Array of token IDs + +# `tinker.types.forward_backward_input` + +## `ForwardBackwardInput` Objects + +```python +class 
ForwardBackwardInput(StrictBase) +``` + +#### `data` + +Array of input data for the forward/backward pass + +#### `loss_fn` + +Fully qualified function path for the loss function + +#### `loss_fn_config` + +Optional configuration parameters for the loss function (e.g., PPO clip thresholds, DPO beta) + +# `tinker.types.session_start_event` + +## `SessionStartEvent` Objects + +```python +class SessionStartEvent(BaseModel) +``` + +#### `event` + +Telemetry event type + +#### `severity` + +Log severity level + +# `tinker.types.training_runs_response` + +## `TrainingRunsResponse` Objects + +```python +class TrainingRunsResponse(BaseModel) +``` + +#### `training_runs` + +List of training runs + +#### `cursor` + +Pagination cursor information + +# `tinker.types.save_weights_response` + +## `SaveWeightsResponse` Objects + +```python +class SaveWeightsResponse(BaseModel) +``` + +#### `path` + +A tinker URI for model weights at a specific step + +# `tinker.types.sample_request` + +## `SampleRequest` Objects + +```python +class SampleRequest(StrictBase) +``` + +#### `num_samples` + +Number of samples to generate + +#### `base_model` + +Optional base model name to sample from. + +Is inferred from model_path, if provided. If sampling against a base model, this +is required. + +#### `model_path` + +Optional tinker:// path to your model weights or LoRA weights. + +If not provided, samples against the base model. + +#### `sampling_session_id` + +Optional sampling session ID to use instead of model_path/base_model. + +If provided along with seq_id, the model configuration will be loaded from the +sampling session. This is useful for multi-turn conversations. + +#### `seq_id` + +Sequence ID within the sampling session. + +Required when sampling_session_id is provided. Used to generate deterministic +request IDs for the sampling request. + +#### `prompt_logprobs` + +If set to `true`, computes and returns logprobs on the prompt tokens. + +Defaults to false. + +#### `topk_prompt_logprobs` + +If set to a positive integer, returns the top-k logprobs for each prompt token. + +# `tinker.types.forward_backward_output` + +## `ForwardBackwardOutput` Objects + +```python +class ForwardBackwardOutput(BaseModel) +``` + +#### `loss_fn_output_type` + +The type of the ForwardBackward output. Can be one of [...] TODO + +#### `loss_fn_outputs` + +Dictionary mapping field names to tensor data + +#### `metrics` + +Training metrics as key-value pairs + +# `tinker.types.sample_response` + +## `SampleResponse` Objects + +```python +class SampleResponse(BaseModel) +``` + +#### `prompt_logprobs` + +If prompt_logprobs was set to true in the request, logprobs are computed for +every token in the prompt. The `prompt_logprobs` response contains a float32 +value for every token in the prompt. + +#### `topk_prompt_logprobs` + +If topk_prompt_logprobs was set to a positive integer k in the request, +the top-k logprobs are computed for every token in the prompt. The +`topk_prompt_logprobs` response contains, for every token in the prompt, +a list of up to k (token_id, logprob) tuples. 
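+
+A hedged sketch of requesting these fields through `SamplingClient.sample()` (parameter names follow the SamplingClient page; the tokenizer and prompt text are placeholders):
+
+```python
+prompt = types.ModelInput.from_ints(tokenizer.encode("The weather today is"))
+params = types.SamplingParams(max_tokens=8, temperature=0.0)
+
+result = sampling_client.sample(
+    prompt=prompt,
+    num_samples=1,
+    sampling_params=params,
+    include_prompt_logprobs=True,  # populates prompt_logprobs on the response
+    topk_prompt_logprobs=5,        # populates topk_prompt_logprobs on the response
+).result()
+
+if result.prompt_logprobs is not None:
+    print(f"{len(result.prompt_logprobs)} prompt-token logprobs returned")
+```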
+ +# `tinker.types.create_sampling_session_response` + +## `CreateSamplingSessionResponse` Objects + +```python +class CreateSamplingSessionResponse(BaseModel) +``` + +#### `sampling_session_id` + +The generated sampling session ID + +# `tinker.types.cursor` + +## `Cursor` Objects + +```python +class Cursor(BaseModel) +``` + +#### `offset` + +The offset used for pagination + +#### `limit` + +The maximum number of items requested + +#### `total_count` + +The total number of items available + +# `tinker.types.create_model_request` + +## `CreateModelRequest` Objects + +```python +class CreateModelRequest(StrictBase) +``` + +#### `base_model` + +Optional metadata about this model/training run, set by the end-user + +# `tinker.types.datum` + +## `Datum` Objects + +```python +class Datum(StrictBase) +``` + +#### `loss_fn_inputs` + +Dictionary mapping field names to tensor data + +#### `convert_tensors` + +```python +@model_validator(mode="before") +@classmethod +def convert_tensors(cls, data: Any) -> Any +``` + +Convert torch.Tensor and numpy arrays to TensorData in loss_fn_inputs during construction. + +# `tinker.types.training_run` + +## `TrainingRun` Objects + +```python +class TrainingRun(BaseModel) +``` + +#### `training_run_id` + +The unique identifier for the training run + +#### `base_model` + +The base model name this model is derived from + +#### `model_owner` + +The owner/creator of this model + +#### `is_lora` + +Whether this model uses LoRA (Low-Rank Adaptation) + +#### `corrupted` + +Whether the model is in a corrupted state + +#### `lora_rank` + +The LoRA rank if this is a LoRA model, null otherwise + +#### `last_request_time` + +The timestamp of the last request made to this model + +#### `last_checkpoint` + +The most recent training checkpoint, if available + +#### `last_sampler_checkpoint` + +The most recent sampler checkpoint, if available + +#### `user_metadata` + +Optional metadata about this training run, set by the end-user + +# `tinker.types.session_end_event` + +## `SessionEndEvent` Objects + +```python +class SessionEndEvent(BaseModel) +``` + +#### `duration` + +ISO 8601 duration string + +#### `event` + +Telemetry event type + +#### `severity` + +Log severity level + +# `tinker.types.telemetry_batch` + +## `TelemetryBatch` Objects + +```python +class TelemetryBatch(BaseModel) +``` + +#### `platform` + +Host platform name + +#### `sdk_version` + +SDK version string + +# `tinker.types.unhandled_exception_event` + +## `UnhandledExceptionEvent` Objects + +```python +class UnhandledExceptionEvent(BaseModel) +``` + +#### `event` + +Telemetry event type + +#### `severity` + +Log severity level + +#### `traceback` + +Optional Python traceback string + +# `tinker.types.image_chunk` + +## `ImageChunk` Objects + +```python +class ImageChunk(StrictBase) +``` + +#### `data` + +Image data as bytes + +#### `format` + +Image format + +#### `height` + +Image height in pixels + +#### `tokens` + +Number of tokens this image represents + +#### `width` + +Image width in pixels + +#### `expected_tokens` + +Expected number of tokens this image represents. +This is only advisory: the tinker backend will compute the number of tokens +from the image, and we can fail requests quickly if the tokens does not +match expected_tokens. + +#### `validate_data` + +```python +@field_validator("data", mode="before") +@classmethod +def validate_data(cls, value: Union[bytes, str]) -> bytes +``` + +Deserialize base64 string to bytes if needed. 
+ +#### `serialize_data` + +```python +@field_serializer("data") +def serialize_data(value: bytes) -> str +``` + +Serialize bytes to base64 string for JSON. + +# `tinker.types.save_weights_request` + +## `SaveWeightsRequest` Objects + +```python +class SaveWeightsRequest(StrictBase) +``` + +#### `path` + +A file/directory name for the weights + +# `tinker.types.lora_config` + +## `LoraConfig` Objects + +```python +class LoraConfig(StrictBase) +``` + +#### `rank` + +LoRA rank (dimension of low-rank matrices) + +#### `seed` + +Seed used for initialization of LoRA weights. + +Useful if you need deterministic or reproducible initialization of weights. + +#### `train_unembed` + +Whether to add lora to the unembedding layer + +#### `train_mlp` + +Whether to add loras to the MLP layers (including MoE layers) + +#### `train_attn` + +Whether to add loras to the attention layers + +# `tinker.types.create_sampling_session_request` + +## `CreateSamplingSessionRequest` Objects + +```python +class CreateSamplingSessionRequest(StrictBase) +``` + +#### `session_id` + +The session ID to create the sampling session within + +#### `sampling_session_seq_id` + +Sequence ID for the sampling session within the session + +#### `base_model` + +Optional base model name to sample from. + +Is inferred from model_path, if provided. If sampling against a base model, this +is required. + +#### `model_path` + +Optional tinker:// path to your model weights or LoRA weights. + +If not provided, samples against the base model. + +# `tinker.types.future_retrieve_request` + +## `FutureRetrieveRequest` Objects + +```python +class FutureRetrieveRequest(StrictBase) +``` + +#### `request_id` + +The ID of the request to retrieve + +# `tinker.types.tensor_data` + +## `TensorData` Objects + +```python +class TensorData(StrictBase) +``` + +#### `data` + +Flattened tensor data as array of numbers. + +#### `shape` + +Optional. + +The shape of the tensor (see PyTorch tensor.shape). The shape of a +one-dimensional list of length N is `(N,)`. Can usually be inferred if not +provided, and is generally inferred as a 1D tensor. + +#### `to_numpy` + +```python +def to_numpy() -> npt.NDArray[Any] +``` + +Convert TensorData to numpy array. + +#### `to_torch` + +```python +def to_torch() -> "torch.Tensor" +``` + +Convert TensorData to torch tensor. 
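A rough sketch of how `TensorData` round-trips with the array libraries, assuming it is re-exported under `tinker.types` like the other types in this reference and can be constructed directly from the `data` and `shape` fields documented above:

```python
from tinker import types

# A 2x3 tensor stored as flattened data plus an explicit shape.
td = types.TensorData(data=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=(2, 3))

arr = td.to_numpy()     # expected to come back with shape (2, 3)
tensor = td.to_torch()  # same data as a torch.Tensor

# With shape omitted, the field docs above say a 1D tensor is generally inferred.
td_1d = types.TensorData(data=[0.1, 0.2, 0.3])
print(td_1d.to_numpy().shape)  # expected: (3,)
```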
+ +# `tinker.types.save_weights_for_sampler_request` + +## `SaveWeightsForSamplerRequest` Objects + +```python +class SaveWeightsForSamplerRequest(StrictBase) +``` + +#### `path` + +A file/directory name for the weights + +# `tinker.types.sampling_params` + +## `SamplingParams` Objects + +```python +class SamplingParams(BaseModel) +``` + +#### `max_tokens` + +Maximum number of tokens to generate + +#### `seed` + +Random seed for reproducible generation + +#### `stop` + +Stop sequences for generation + +#### `temperature` + +Sampling temperature + +#### `top_k` + +Top-k sampling parameter (-1 for no limit) + +#### `top_p` + +Nucleus sampling probability + +# `tinker.types.save_weights_for_sampler_response` + +## `SaveWeightsForSamplerResponseInternal` Objects + +```python +class SaveWeightsForSamplerResponseInternal(BaseModel) +``` + +#### `path` + +A tinker URI for model weights for sampling at a specific step + +#### `sampling_session_id` + +The generated sampling session ID + +## `SaveWeightsForSamplerResponse` Objects + +```python +class SaveWeightsForSamplerResponse(BaseModel) +``` + +#### `path` + +A tinker URI for model weights for sampling at a specific step + +# `tinker.types.load_weights_response` + +## `LoadWeightsResponse` Objects + +```python +class LoadWeightsResponse(BaseModel) +``` + +#### `path` + +A tinker URI for model weights at a specific step diff --git a/pydoc-markdown.yml b/pydoc-markdown.yml new file mode 100644 index 0000000..663269f --- /dev/null +++ b/pydoc-markdown.yml @@ -0,0 +1,25 @@ +# Pydoc-markdown configuration for Tinker Python SDK +# Generates MDX-compatible documentation for Nextra integration + +loaders: + - type: python + search_path: [] + packages: [] + ignore_when_discovered: [__pycache__, tests, test_*, conftest, mock_*, _test*] + +processors: + - type: filter + documented_only: true + exclude_private: true + exclude_special: true + skip_empty_modules: true + +renderer: + type: markdown + classdef_code_block: true + code_headers: true + code_lang: true + escape_html_in_docstring: false + insert_header_anchors: false + signature_code_block: true + render_toc: false diff --git a/scripts/bootstrap b/scripts/bootstrap deleted file mode 100755 index bd692ad..0000000 --- a/scripts/bootstrap +++ /dev/null @@ -1,22 +0,0 @@ -#!/usr/bin/env bash - -set -e - -cd "$(dirname "$0")/.." - -if [ -f "Brewfile" ] && [ "$(uname -s)" = "Darwin" ] && [ "$SKIP_BREW" != "1" ]; then - brew bundle check >/dev/null 2>&1 || { - echo "==> Installing Homebrew dependencies…" - brew bundle - } -fi - -echo "==> Installing Python…" -uv python install - -echo "==> Installing Python dependencies…" -uv sync --all-extras - -echo "==> Exporting Python dependencies…" -# note: `--no-hashes` is required because of https://github.com/pypa/pip/issues/4995 -uv export -o requirements-dev.lock --no-hashes diff --git a/scripts/format b/scripts/format deleted file mode 100755 index 1d2f9c6..0000000 --- a/scripts/format +++ /dev/null @@ -1,14 +0,0 @@ -#!/usr/bin/env bash - -set -e - -cd "$(dirname "$0")/.." - -echo "==> Running ruff" -uv run ruff format -uv run ruff check --fix . 
-# run formatting again to fix any inconsistencies when imports are stripped -uv run ruff format - -echo "==> Formatting docs" -uv run python scripts/utils/ruffen-docs.py README.md api.md diff --git a/scripts/generate_docs.py b/scripts/generate_docs.py new file mode 100755 index 0000000..c3f6091 --- /dev/null +++ b/scripts/generate_docs.py @@ -0,0 +1,231 @@ +#!/usr/bin/env python3 +# /// script +# dependencies = [ +# "pydoc-markdown>=4.8.0", +# "pyyaml>=6.0", +# ] +# /// + + +import ast +import json +import os +import subprocess +import sys +from pathlib import Path +from typing import Dict, List + +import yaml + + +def cd_to_project_root(): + """Change to the project root (parent of the scripts directory).""" + script_dir = Path(__file__).resolve().parent + project_root = script_dir.parent + os.chdir(project_root) + print(f"Changed to project root: {project_root}") + + +class ModuleAnalyzer: + """Analyze Python modules to extract public API information.""" + + def __init__(self, src_path: Path): + self.src_path = src_path + + def get_module_exports(self, module_path: Path) -> List[str]: + """Extract __all__ exports from a module.""" + try: + content = module_path.read_text() + tree = ast.parse(content) + + for node in ast.walk(tree): + if isinstance(node, ast.Assign): + for target in node.targets: + if isinstance(target, ast.Name) and target.id == '__all__': + if isinstance(node.value, ast.List): + return [ + elt.s for elt in node.value.elts + if isinstance(elt, ast.Str) + ] or [ + elt.value for elt in node.value.elts + if isinstance(elt, ast.Constant) and isinstance(elt.value, str) + ] + except Exception as e: + print(f"Warning: Could not parse {module_path}: {e}") + return [] + + def find_all_modules(self) -> Dict[str, Path]: + """Find all Python modules in the package.""" + modules = {} + tinker_path = self.src_path / "tinker" + + for py_file in tinker_path.rglob("*.py"): + # Skip test files and private modules + if any(part.startswith('test') or part.startswith('_test') for part in py_file.parts): + continue + if '__pycache__' in py_file.parts: + continue + + # Calculate module name + relative_path = py_file.relative_to(self.src_path) + module_parts = list(relative_path.parts[:-1]) # Remove .py file + module_parts.append(relative_path.stem) + + # Skip __init__ files in module name + if module_parts[-1] == '__init__': + module_parts = module_parts[:-1] + + module_name = '.'.join(module_parts) + if module_name: # Skip empty module names + modules[module_name] = py_file + + return modules + + +class DocumentationGenerator: + """Generate documentation using pydoc-markdown.""" + + def __init__(self, config_path: Path, output_dir: Path): + self.config_path = config_path + self.output_dir = output_dir + self.output_dir.mkdir(parents=True, exist_ok=True) + self.analyzer = ModuleAnalyzer(Path('src')) + + def run_pydoc_markdown(self, modules: List[str], output_file: Path) -> bool: + """Run pydoc-markdown for specific modules.""" + try: + # Build the command + cmd = ['pydoc-markdown', 'pydoc-markdown.yml', '-I', 'src'] + + # Add modules + for module in modules: + cmd.extend(['-m', module]) + + # Run pydoc-markdown + result = subprocess.run( + cmd, + capture_output=True, + text=True, + check=False + ) + + if result.returncode == 0: + # Write output to file + output_file.parent.mkdir(parents=True, exist_ok=True) + output_file.write_text(result.stdout) + print(f"Generated: {output_file}") + return True + else: + print(f"Error generating {output_file}: {result.stderr}") + return False + + except 
Exception as e: + print(f"Exception generating {output_file}: {e}") + return False + + def generate_public_interfaces(self): + """Generate documentation for public interface classes.""" + print("\n=== Generating Public Interfaces Documentation ===") + + # Generate individual pages for each client + client_modules = [ + ('ServiceClient', 'tinker.lib.public_interfaces.service_client'), + ('TrainingClient', 'tinker.lib.public_interfaces.training_client'), + ('SamplingClient', 'tinker.lib.public_interfaces.sampling_client'), + ('RestClient', 'tinker.lib.public_interfaces.rest_client'), + ('APIFuture', 'tinker.lib.public_interfaces.api_future'), + ] + + for class_name, module in client_modules: + output_file = self.output_dir / f'{class_name.lower().replace("_", "-")}.md' + self.run_pydoc_markdown([module], output_file) + + def generate_all_types(self): + """Generate complete types reference.""" + print("\n=== Generating Complete Types Reference ===") + + # Get all type modules + all_modules = self.analyzer.find_all_modules() + type_modules = [m for m in all_modules.keys() if m.startswith('tinker.types')] + + if type_modules: + output_file = self.output_dir / 'types.md' + self.run_pydoc_markdown(type_modules, output_file) + + def generate_exceptions(self): + """Generate exception hierarchy documentation.""" + print("\n=== Generating Exception Documentation ===") + + output_file = self.output_dir / 'exceptions.md' + self.run_pydoc_markdown(['tinker._exceptions'], output_file) + + def generate_nextra_meta(self): + """Generate _meta.json for Nextra navigation.""" + print("\n=== Generating Nextra Navigation Metadata ===") + + meta = { + "serviceclient": "ServiceClient", + "trainingclient": "TrainingClient", + "samplingclient": "SamplingClient", + "restclient": "RestClient", + "apifuture": "APIFuture", + "types": "Parameters", + "exceptions": "Exceptions" + } + + meta_file = self.output_dir / '_meta.json' + meta_file.write_text(json.dumps(meta, indent=2)) + print(f"Generated: {meta_file}") + + def generate_all(self): + """Generate all documentation.""" + print("Starting documentation generation...") + print(f"Output directory: {self.output_dir}") + + # Generate documentation for each category + self.generate_public_interfaces() + self.generate_all_types() + self.generate_exceptions() + + # Generate Nextra metadata + self.generate_nextra_meta() + + print("\n=== Documentation Generation Complete ===") + print(f"Markdown files generated in: {self.output_dir}") + print("\nGenerated files:") + for file in sorted(self.output_dir.rglob("*.md")): + print(f" - {file.relative_to(self.output_dir)}") + + +def main(): + """Main entry point.""" + # Change to project root first + cd_to_project_root() + + # Paths + project_root = Path.cwd() + config_path = project_root / 'pydoc-markdown.yml' + output_dir = project_root / 'docs' / 'api' + + # Check if config exists + if not config_path.exists(): + print(f"Error: Configuration file not found at {config_path}") + print("Please run this script from the project root directory") + sys.exit(1) + + # Create generator and run + generator = DocumentationGenerator(config_path, output_dir) + generator.generate_all() + + # Print usage instructions + print("\n" + "=" * 50) + print("To use these docs in your Nextra project:") + print("1. Copy the docs/api directory to your Nextra project") + print("2. The markdown files are ready to use with Nextra") + print("3. 
Navigation structure is defined in _meta.json") + print("\nTo regenerate docs after code changes:") + print(" uv run scripts/generate_docs.py") + + +if __name__ == "__main__": + main() diff --git a/scripts/lint b/scripts/lint deleted file mode 100755 index 9cdd830..0000000 --- a/scripts/lint +++ /dev/null @@ -1,17 +0,0 @@ -#!/usr/bin/env bash - -set -e - -cd "$(dirname "$0")/.." - -echo "==> Running ruff" -uv run ruff check . - -echo "==> Running pyright" -uv run pyright --threads 64 - -echo "==> Running mypy" -uv run mypy . - -echo "==> Making sure it imports" -uv run python -c 'import tinker' diff --git a/scripts/mock b/scripts/mock deleted file mode 100755 index dfe2f01..0000000 --- a/scripts/mock +++ /dev/null @@ -1,41 +0,0 @@ -#!/usr/bin/env bash - -set -e - -cd "$(dirname "$0")/.." - -if [[ -n "$1" && "$1" != '--'* ]]; then - URL="$1" - shift -else - URL="$(grep 'openapi_spec_url' .stats.yml | cut -d' ' -f2)" -fi - -# Check if the URL is empty -if [ -z "$URL" ]; then - echo "Error: No OpenAPI spec path/url provided or found in .stats.yml" - exit 1 -fi - -echo "==> Starting mock server with URL ${URL}" - -# Run prism mock on the given spec -if [ "$1" == "--daemon" ]; then - npm exec --package=@stainless-api/prism-cli@5.15.0 -- prism mock "$URL" &>.prism.log & - - # Wait for server to come online - echo -n "Waiting for server" - while ! grep -q "✖ fatal\|Prism is listening" ".prism.log"; do - echo -n "." - sleep 0.1 - done - - if grep -q "✖ fatal" ".prism.log"; then - cat .prism.log - exit 1 - fi - - echo -else - npm exec --package=@stainless-api/prism-cli@5.15.0 -- prism mock "$URL" -fi diff --git a/bin/publish-pypi b/scripts/publish-pypi similarity index 100% rename from bin/publish-pypi rename to scripts/publish-pypi diff --git a/scripts/test b/scripts/test deleted file mode 100755 index a9ff2c2..0000000 --- a/scripts/test +++ /dev/null @@ -1,77 +0,0 @@ -#!/usr/bin/env bash - -set -e - -cd "$(dirname "$0")/.." - -RED='\033[0;31m' -GREEN='\033[0;32m' -YELLOW='\033[0;33m' -NC='\033[0m' # No Color - -function prism_is_running() { - curl --silent "http://localhost:4010" >/dev/null 2>&1 -} - -kill_server_on_port() { - pids=$(lsof -t -i tcp:"$1" || echo "") - if [ "$pids" != "" ]; then - kill "$pids" - echo "Stopped $pids." - fi -} - -function is_overriding_api_base_url() { - [ -n "$TEST_API_BASE_URL" ] -} - -if ! is_overriding_api_base_url && ! prism_is_running; then - # When we exit this script, make sure to kill the background mock server process - trap 'kill_server_on_port 4010' EXIT - - # Start the dev server - ./scripts/mock --daemon -fi - -if is_overriding_api_base_url; then - echo -e "${GREEN}✔ Running tests against ${TEST_API_BASE_URL}${NC}" - echo -elif ! prism_is_running; then - echo -e "${RED}ERROR:${NC} The test suite will not run without a mock Prism server" - echo -e "running against your OpenAPI spec." 
- echo - echo -e "To run the server, pass in the path or url of your OpenAPI" - echo -e "spec to the prism command:" - echo - echo -e " \$ ${YELLOW}npm exec --package=@stainless-api/prism-cli@5.15.0 -- prism mock path/to/your.openapi.yml${NC}" - echo - - exit 1 -else - echo -e "${GREEN}✔ Mock prism server is running with your OpenAPI spec${NC}" - echo -fi - -export DEFER_PYDANTIC_BUILD=false - -function run_tests() { - echo "==> Running tests with Pydantic v2" - uv run --all-extras --all-groups pytest "$@" - - echo "==> Running tests with Pydantic v1" - uv pip install 'pydantic<2' - uv run --all-extras --all-groups pytest "$@" -} - -# If UV_PYTHON is already set in the environment, just run the command once -if [[ -n "$UV_PYTHON" ]]; then - run_tests "$@" -else - # If UV_PYTHON is not set, run the command for min and max versions - - echo "==> Running tests for Python 3.9" - UV_PYTHON=3.9 run_tests "$@" - - echo "==> Running tests for Python 3.13" - UV_PYTHON=3.13 run_tests "$@" -fi diff --git a/scripts/utils/ruffen-docs.py b/scripts/utils/ruffen-docs.py deleted file mode 100644 index 0cf2bd2..0000000 --- a/scripts/utils/ruffen-docs.py +++ /dev/null @@ -1,167 +0,0 @@ -# fork of https://github.com/asottile/blacken-docs adapted for ruff -from __future__ import annotations - -import re -import sys -import argparse -import textwrap -import contextlib -import subprocess -from typing import Match, Optional, Sequence, Generator, NamedTuple, cast - -MD_RE = re.compile( - r"(?P^(?P *)```\s*python\n)" r"(?P.*?)" r"(?P^(?P=indent)```\s*$)", - re.DOTALL | re.MULTILINE, -) -MD_PYCON_RE = re.compile( - r"(?P^(?P *)```\s*pycon\n)" r"(?P.*?)" r"(?P^(?P=indent)```.*$)", - re.DOTALL | re.MULTILINE, -) -PYCON_PREFIX = ">>> " -PYCON_CONTINUATION_PREFIX = "..." -PYCON_CONTINUATION_RE = re.compile( - rf"^{re.escape(PYCON_CONTINUATION_PREFIX)}( |$)", -) -DEFAULT_LINE_LENGTH = 100 - - -class CodeBlockError(NamedTuple): - offset: int - exc: Exception - - -def format_str( - src: str, -) -> tuple[str, Sequence[CodeBlockError]]: - errors: list[CodeBlockError] = [] - - @contextlib.contextmanager - def _collect_error(match: Match[str]) -> Generator[None, None, None]: - try: - yield - except Exception as e: - errors.append(CodeBlockError(match.start(), e)) - - def _md_match(match: Match[str]) -> str: - code = textwrap.dedent(match["code"]) - with _collect_error(match): - code = format_code_block(code) - code = textwrap.indent(code, match["indent"]) - return f"{match['before']}{code}{match['after']}" - - def _pycon_match(match: Match[str]) -> str: - code = "" - fragment = cast(Optional[str], None) - - def finish_fragment() -> None: - nonlocal code - nonlocal fragment - - if fragment is not None: - with _collect_error(match): - fragment = format_code_block(fragment) - fragment_lines = fragment.splitlines() - code += f"{PYCON_PREFIX}{fragment_lines[0]}\n" - for line in fragment_lines[1:]: - # Skip blank lines to handle Black adding a blank above - # functions within blocks. A blank line would end the REPL - # continuation prompt. - # - # >>> if True: - # ... def f(): - # ... pass - # ... 
- if line: - code += f"{PYCON_CONTINUATION_PREFIX} {line}\n" - if fragment_lines[-1].startswith(" "): - code += f"{PYCON_CONTINUATION_PREFIX}\n" - fragment = None - - indentation = None - for line in match["code"].splitlines(): - orig_line, line = line, line.lstrip() - if indentation is None and line: - indentation = len(orig_line) - len(line) - continuation_match = PYCON_CONTINUATION_RE.match(line) - if continuation_match and fragment is not None: - fragment += line[continuation_match.end() :] + "\n" - else: - finish_fragment() - if line.startswith(PYCON_PREFIX): - fragment = line[len(PYCON_PREFIX) :] + "\n" - else: - code += orig_line[indentation:] + "\n" - finish_fragment() - return code - - def _md_pycon_match(match: Match[str]) -> str: - code = _pycon_match(match) - code = textwrap.indent(code, match["indent"]) - return f"{match['before']}{code}{match['after']}" - - src = MD_RE.sub(_md_match, src) - src = MD_PYCON_RE.sub(_md_pycon_match, src) - return src, errors - - -def format_code_block(code: str) -> str: - return subprocess.check_output( - [ - sys.executable, - "-m", - "ruff", - "format", - "--stdin-filename=script.py", - f"--line-length={DEFAULT_LINE_LENGTH}", - ], - encoding="utf-8", - input=code, - ) - - -def format_file( - filename: str, - skip_errors: bool, -) -> int: - with open(filename, encoding="UTF-8") as f: - contents = f.read() - new_contents, errors = format_str(contents) - for error in errors: - lineno = contents[: error.offset].count("\n") + 1 - print(f"{filename}:{lineno}: code block parse error {error.exc}") - if errors and not skip_errors: - return 1 - if contents != new_contents: - print(f"{filename}: Rewriting...") - with open(filename, "w", encoding="UTF-8") as f: - f.write(new_contents) - return 0 - else: - return 0 - - -def main(argv: Sequence[str] | None = None) -> int: - parser = argparse.ArgumentParser() - parser.add_argument( - "-l", - "--line-length", - type=int, - default=DEFAULT_LINE_LENGTH, - ) - parser.add_argument( - "-S", - "--skip-string-normalization", - action="store_true", - ) - parser.add_argument("-E", "--skip-errors", action="store_true") - parser.add_argument("filenames", nargs="*") - args = parser.parse_args(argv) - - retv = 0 - for filename in args.filenames: - retv |= format_file(filename, skip_errors=args.skip_errors) - return retv - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/src/tinker/_exceptions.py b/src/tinker/_exceptions.py index 0b0fdbb..abcc17f 100644 --- a/src/tinker/_exceptions.py +++ b/src/tinker/_exceptions.py @@ -1,9 +1,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING import httpx -from typing_extensions import Literal __all__ = [ "BadRequestError", @@ -22,10 +21,14 @@ if TYPE_CHECKING: class TinkerError(Exception): + """Base exception for all Tinker-related errors.""" + pass class APIError(TinkerError): + """Base class for all API-related errors.""" + message: str request: httpx.Request @@ -48,11 +51,19 @@ class APIError(TinkerError): class APIResponseValidationError(APIError): + """Raised when API response doesn't match expected schema.""" + response: httpx.Response status_code: int - def __init__(self, response: httpx.Response, body: object | None, *, message: str | None = None) -> None: - super().__init__(message or "Data returned by API invalid for expected schema.", response.request, body=body) + def __init__( + self, response: httpx.Response, body: object | None, *, message: str | None = None + ) -> None: + super().__init__( + 
message or "Data returned by API invalid for expected schema.", + response.request, + body=body, + ) self.response = response self.status_code = response.status_code @@ -70,44 +81,64 @@ class APIStatusError(APIError): class APIConnectionError(APIError): + """Raised when a connection error occurs while making an API request.""" + def __init__(self, *, message: str = "Connection error.", request: httpx.Request) -> None: super().__init__(message, request, body=None) class APITimeoutError(APIConnectionError): + """Raised when an API request times out.""" + def __init__(self, request: httpx.Request) -> None: super().__init__(message="Request timed out.", request=request) class BadRequestError(APIStatusError): - status_code: Literal[400] = 400 # pyright: ignore[reportIncompatibleVariableOverride] + """HTTP 400: The request was invalid or malformed.""" + + status_code: int = 400 class AuthenticationError(APIStatusError): - status_code: Literal[401] = 401 # pyright: ignore[reportIncompatibleVariableOverride] + """HTTP 401: Authentication credentials are missing or invalid.""" + + status_code: int = 401 class PermissionDeniedError(APIStatusError): - status_code: Literal[403] = 403 # pyright: ignore[reportIncompatibleVariableOverride] + """HTTP 403: Insufficient permissions to access the resource.""" + + status_code: int = 403 class NotFoundError(APIStatusError): - status_code: Literal[404] = 404 # pyright: ignore[reportIncompatibleVariableOverride] + """HTTP 404: The requested resource was not found.""" + + status_code: int = 404 class ConflictError(APIStatusError): - status_code: Literal[409] = 409 # pyright: ignore[reportIncompatibleVariableOverride] + """HTTP 409: The request conflicts with the current state of the resource.""" + + status_code: int = 409 class UnprocessableEntityError(APIStatusError): - status_code: Literal[422] = 422 # pyright: ignore[reportIncompatibleVariableOverride] + """HTTP 422: The request was well-formed but contains semantic errors.""" + + status_code: int = 422 class RateLimitError(APIStatusError): - status_code: Literal[429] = 429 # pyright: ignore[reportIncompatibleVariableOverride] + """HTTP 429: Too many requests, rate limit exceeded.""" + + status_code: int = 429 class InternalServerError(APIStatusError): + """HTTP 500+: An error occurred on the server.""" + pass diff --git a/src/tinker/lib/public_interfaces/api_future.py b/src/tinker/lib/public_interfaces/api_future.py index c9ae317..7aae38f 100644 --- a/src/tinker/lib/public_interfaces/api_future.py +++ b/src/tinker/lib/public_interfaces/api_future.py @@ -13,12 +13,65 @@ T = TypeVar("T") class APIFuture(ABC, Generic[T]): + """Abstract base class for futures that can be awaited or accessed synchronously. + + APIFuture provides a unified interface for handling async operations that can + be accessed both synchronously (via result()) and asynchronously (via await or result_async()). + This allows for flexible usage patterns in both sync and async contexts. 
+ + The future can be awaited directly in async contexts: + ```python + result = await api_future # Equivalent to await api_future.result_async() + ``` + + Or accessed synchronously: + ```python + result = api_future.result() # Blocks until complete + ``` + + Args: + T: The type of the result value + + Example: + ```python + # In async context + future = training_client.forward_backward(data, "cross_entropy") + result = await future # Or await future.result_async() + + # In sync context + future = training_client.forward_backward(data, "cross_entropy") + result = future.result() + ``` + """ + @abstractmethod async def result_async(self, timeout: float | None = None) -> T: + """Get the result asynchronously with optional timeout. + + Args: + timeout: Maximum time to wait in seconds. None means wait indefinitely. + + Returns: + The result value of type T + + Raises: + TimeoutError: If timeout is exceeded + """ raise NotImplementedError @abstractmethod def result(self, timeout: float | None = None) -> T: + """Get the result synchronously with optional timeout. + + Args: + timeout: Maximum time to wait in seconds. None means wait indefinitely. + + Returns: + The result value of type T + + Raises: + TimeoutError: If timeout is exceeded + """ raise NotImplementedError def __await__(self): @@ -26,15 +79,71 @@ class APIFuture(ABC, Generic[T]): class AwaitableConcurrentFuture(APIFuture[T]): + """Implementation of APIFuture that wraps a concurrent.futures.Future. + + This class bridges Python's concurrent.futures with asyncio, allowing a + standard Future to be used in async contexts. It's commonly returned by + Tinker API methods to provide both sync and async access patterns. + + Args: + future: A concurrent.futures.Future to wrap + + Example: + ```python + # Internal usage - typically you receive these from API methods + concurrent_future = some_operation() + api_future = AwaitableConcurrentFuture(concurrent_future) + + # Can be used synchronously + result = api_future.result() + + # Or asynchronously + result = await api_future + ``` + """ + def __init__(self, future: ConcurrentFuture[T]): self._future: ConcurrentFuture[T] = future def result(self, timeout: float | None = None) -> T: + """Get the result synchronously with optional timeout. + + Args: + timeout: Maximum time to wait in seconds. None means wait indefinitely. + + Returns: + The result value of type T + + Raises: + TimeoutError: If timeout is exceeded + Exception: Any exception raised by the underlying operation + + Example: + ```python + future = rest_client.get_training_run("run-id") + result = future.result(timeout=30) # Wait up to 30 seconds + ``` + """ return self._future.result(timeout) async def result_async(self, timeout: float | None = None) -> T: + """Async version of result.""" async with asyncio.timeout(timeout): return await asyncio.wrap_future(self._future) def future(self) -> ConcurrentFuture[T]: + """Get the underlying concurrent.futures.Future. 
+ + Returns: + The wrapped ConcurrentFuture object + + Example: + ```python + api_future = rest_client.get_training_run("run-id") + concurrent_future = api_future.future() + # Can now use standard concurrent.futures methods + if concurrent_future.done(): + result = concurrent_future.result() + ``` + """ return self._future diff --git a/src/tinker/lib/public_interfaces/rest_client.py b/src/tinker/lib/public_interfaces/rest_client.py index a918d3d..00f1159 100644 --- a/src/tinker/lib/public_interfaces/rest_client.py +++ b/src/tinker/lib/public_interfaces/rest_client.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING from tinker import NoneType, types from tinker.lib.client_connection_pool_type import ClientConnectionPoolType -from tinker.lib.public_interfaces.api_future import AwaitableConcurrentFuture +from tinker.lib.public_interfaces.api_future import APIFuture, AwaitableConcurrentFuture from tinker.lib.telemetry import Telemetry, capture_exceptions from tinker.lib.telemetry_provider import TelemetryProvider @@ -42,13 +42,15 @@ class RestClient(TelemetryProvider): holder: Internal client managing HTTP connections and async operations Example: - >>> rest_client = service_client.create_rest_client() - >>> training_run = rest_client.get_training_run("run-id").result() - >>> print(f"Training Run: {training_run.training_run_id}, LoRA: {training_run.is_lora}") - >>> checkpoints = rest_client.list_checkpoints("run-id").result() - >>> print(f"Found {len(checkpoints.checkpoints)} checkpoints") - >>> for checkpoint in checkpoints.checkpoints: - ... print(f" {checkpoint.checkpoint_type}: {checkpoint.checkpoint_id}") + ```python + rest_client = service_client.create_rest_client() + training_run = rest_client.get_training_run("run-id").result() + print(f"Training Run: {training_run.training_run_id}, LoRA: {training_run.is_lora}") + checkpoints = rest_client.list_checkpoints("run-id").result() + print(f"Found {len(checkpoints.checkpoints)} checkpoints") + for checkpoint in checkpoints.checkpoints: + print(f" {checkpoint.checkpoint_type}: {checkpoint.checkpoint_id}") + ``` """ def __init__(self, holder: InternalClientHolder): @@ -85,26 +87,17 @@ class RestClient(TelemetryProvider): A Future containing the training run information Example: - >>> future = rest_client.get_training_run("run-id") - >>> response = future.result() - >>> print(f"Training Run ID: {response.training_run_id}, Base: {response.base_model}") + ```python + future = rest_client.get_training_run("run-id") + response = future.result() + print(f"Training Run ID: {response.training_run_id}, Base: {response.base_model}") + ``` """ return self._get_training_run_submit(training_run_id).future() @capture_exceptions(fatal=True) async def get_training_run_async(self, training_run_id: types.ModelID) -> types.TrainingRun: - """Async version of get_training_run. 
- - Args: - training_run_id: The training run ID to get information for - - Returns: - Training run information - - Example: - >>> response = await rest_client.get_training_run_async("run-id") - >>> print(f"Training Run ID: {response.training_run_id}, Base: {response.base_model}") - """ + """Async version of get_training_run.""" return await self._get_training_run_submit(training_run_id) @sync_only @@ -121,9 +114,11 @@ class RestClient(TelemetryProvider): A Future containing the training run information Example: - >>> future = rest_client.get_training_run_by_tinker_path("tinker://run-id/weights/checkpoint-001") - >>> response = future.result() - >>> print(f"Training Run ID: {response.training_run_id}, Base: {response.base_model}") + ```python + future = rest_client.get_training_run_by_tinker_path("tinker://run-id/weights/checkpoint-001") + response = future.result() + print(f"Training Run ID: {response.training_run_id}, Base: {response.base_model}") + ``` """ parsed_checkpoint_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path( tinker_path @@ -132,22 +127,44 @@ class RestClient(TelemetryProvider): @capture_exceptions(fatal=True) async def get_training_run_by_tinker_path_async(self, tinker_path: str) -> types.TrainingRun: - """Async version of get_training_run. + """Async version of get_training_run_by_tinker_path.""" + parsed_checkpoint_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path( + tinker_path + ) + return await self.get_training_run_async(parsed_checkpoint_tinker_path.training_run_id) + + @capture_exceptions(fatal=True) + def get_weights_info_by_tinker_path( + self, tinker_path: str + ) -> APIFuture[types.WeightsInfoResponse]: + """Get checkpoint information from a tinker path. Args: tinker_path: The tinker path to the checkpoint Returns: - Training run information + An APIFuture containing the checkpoint information. The future is awaitable. 
Example: - >>> response = await rest_client.get_training_run_by_tinker_path_async("tinker://run-id/weights/checkpoint-001") - >>> print(f"Training Run ID: {response.training_run_id}, Base: {response.base_model}") + ```python + future = rest_client.get_weights_info_by_tinker_path("tinker://run-id/weights/checkpoint-001") + response = future.result() # or await future + print(f"Base Model: {response.base_model}, LoRA Rank: {response.lora_rank}") + ``` """ - parsed_checkpoint_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path( - tinker_path - ) - return await self.get_training_run_async(parsed_checkpoint_tinker_path.training_run_id) + + async def _get_weights_info_async() -> types.WeightsInfoResponse: + async def _send_request() -> types.WeightsInfoResponse: + with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client: + return await client.post( + "/api/v1/weights_info", + body={"tinker_path": tinker_path}, + cast_to=types.WeightsInfoResponse, + ) + + return await self.holder.execute_with_retries(_send_request) + + return self.holder.run_coroutine_threadsafe(_get_weights_info_async()) def _list_training_runs_submit( self, limit: int = 20, offset: int = 0 @@ -184,12 +201,14 @@ class RestClient(TelemetryProvider): A Future containing the TrainingRunsResponse with training runs and cursor info Example: - >>> future = rest_client.list_training_runs(limit=50) - >>> response = future.result() - >>> print(f"Found {len(response.training_runs)} training runs") - >>> print(f"Total: {response.cursor.total_count}") - >>> # Get next page - >>> next_page = rest_client.list_training_runs(limit=50, offset=50) + ```python + future = rest_client.list_training_runs(limit=50) + response = future.result() + print(f"Found {len(response.training_runs)} training runs") + print(f"Total: {response.cursor.total_count}") + # Get next page + next_page = rest_client.list_training_runs(limit=50, offset=50) + ``` """ return self._list_training_runs_submit(limit, offset).future() @@ -197,22 +216,7 @@ class RestClient(TelemetryProvider): async def list_training_runs_async( self, limit: int = 20, offset: int = 0 ) -> types.TrainingRunsResponse: - """Async version of list_training_runs. - - Args: - limit: Maximum number of training runs to return (default 20) - offset: Offset for pagination (default 0) - - Returns: - TrainingRunsResponse with training runs and cursor info - - Example: - >>> response = await rest_client.list_training_runs_async(limit=50) - >>> print(f"Found {len(response.training_runs)} training runs") - >>> print(f"Total: {response.cursor.total_count}") - >>> # Get next page - >>> next_page = await rest_client.list_training_runs_async(limit=50, offset=50) - """ + """Async version of list_training_runs.""" return await self._list_training_runs_submit(limit, offset) def _list_checkpoints_submit( @@ -243,13 +247,15 @@ class RestClient(TelemetryProvider): A Future containing the CheckpointsListResponse with available checkpoints Example: - >>> future = rest_client.list_checkpoints("run-id") - >>> response = future.result() - >>> for checkpoint in response.checkpoints: - ... if checkpoint.checkpoint_type == "training": - ... print(f"Training checkpoint: {checkpoint.checkpoint_id}") - ... elif checkpoint.checkpoint_type == "sampler": - ... 
print(f"Sampler checkpoint: {checkpoint.checkpoint_id}") + ```python + future = rest_client.list_checkpoints("run-id") + response = future.result() + for checkpoint in response.checkpoints: + if checkpoint.checkpoint_type == "training": + print(f"Training checkpoint: {checkpoint.checkpoint_id}") + elif checkpoint.checkpoint_type == "sampler": + print(f"Sampler checkpoint: {checkpoint.checkpoint_id}") + ``` """ return self._list_checkpoints_submit(training_run_id).future() @@ -257,22 +263,7 @@ class RestClient(TelemetryProvider): async def list_checkpoints_async( self, training_run_id: types.ModelID ) -> types.CheckpointsListResponse: - """Async version of list_checkpoints. - - Args: - training_run_id: The training run ID to list checkpoints for - - Returns: - CheckpointsListResponse with available checkpoints - - Example: - >>> response = await rest_client.list_checkpoints_async("run-id") - >>> for checkpoint in response.checkpoints: - ... if checkpoint.checkpoint_type == "training": - ... print(f"Training checkpoint: {checkpoint.checkpoint_id}") - ... elif checkpoint.checkpoint_type == "sampler": - ... print(f"Sampler checkpoint: {checkpoint.checkpoint_id}") - """ + """Async version of list_checkpoints.""" return await self._list_checkpoints_submit(training_run_id) def _get_checkpoint_archive_url_submit( @@ -304,11 +295,13 @@ class RestClient(TelemetryProvider): A Future containing the CheckpointArchiveUrlResponse with signed URL and expiration Example: - >>> future = rest_client.get_checkpoint_archive_url("run-id", "checkpoint-123") - >>> response = future.result() - >>> print(f"Download URL: {response.url}") - >>> print(f"Expires at: {response.expires_at}") - >>> # Use the URL to download the archive with your preferred HTTP client + ```python + future = rest_client.get_checkpoint_archive_url("run-id", "checkpoint-123") + response = future.result() + print(f"Download URL: {response.url}") + print(f"Expires at: {response.expires_at}") + # Use the URL to download the archive with your preferred HTTP client + ``` """ return self._get_checkpoint_archive_url_submit(training_run_id, checkpoint_id).future() @@ -316,21 +309,7 @@ class RestClient(TelemetryProvider): async def get_checkpoint_archive_url_async( self, training_run_id: types.ModelID, checkpoint_id: str ) -> types.CheckpointArchiveUrlResponse: - """Async version of get_checkpoint_archive_url. - - Args: - training_run_id: The model ID to download weights for - checkpoint_id: The checkpoint ID to download - - Returns: - CheckpointArchiveUrlResponse with signed URL and expiration - - Example: - >>> response = await rest_client.get_checkpoint_archive_url_async("run-id", "checkpoint-123") - >>> print(f"Download URL: {response.url}") - >>> print(f"Expires at: {response.expires_at}") - >>> # Use the URL to download the archive with your preferred HTTP client - """ + """Async version of get_checkpoint_archive_url.""" return await self._get_checkpoint_archive_url_submit(training_run_id, checkpoint_id) def _delete_checkpoint_submit( @@ -411,14 +390,7 @@ class RestClient(TelemetryProvider): async def get_checkpoint_archive_url_from_tinker_path_async( self, tinker_path: str ) -> types.CheckpointArchiveUrlResponse: - """Async version of get_checkpoint_archive_url_from_tinker_path. 
- - Args: - tinker_path: The tinker path to the checkpoint - - Returns: - CheckpointArchiveUrlResponse with signed URL and expiration - """ + """Async version of get_checkpoint_archive_url_from_tinker_path.""" parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path) return await self._get_checkpoint_archive_url_submit( parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id @@ -462,9 +434,11 @@ class RestClient(TelemetryProvider): HTTPException: 500 if there's an error publishing the checkpoint Example: - >>> future = rest_client.publish_checkpoint_from_tinker_path("tinker://run-id/weights/0001") - >>> future.result() # Wait for completion - >>> print("Checkpoint published successfully") + ```python + future = rest_client.publish_checkpoint_from_tinker_path("tinker://run-id/weights/0001") + future.result() # Wait for completion + print("Checkpoint published successfully") + ``` """ parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path) return self._publish_checkpoint_submit( @@ -473,24 +447,7 @@ class RestClient(TelemetryProvider): @capture_exceptions(fatal=True) async def publish_checkpoint_from_tinker_path_async(self, tinker_path: str) -> None: - """Async version of publish_checkpoint_from_tinker_path. - - Only the exact owner of the training run can publish checkpoints. - Published checkpoints can be unpublished using the unpublish_checkpoint_from_tinker_path_async method. - - Args: - tinker_path: The tinker path to the checkpoint (e.g., "tinker://run-id/weights/0001") - - Raises: - HTTPException: 400 if checkpoint identifier is invalid - HTTPException: 404 if checkpoint not found or user doesn't own the training run - HTTPException: 409 if checkpoint is already public - HTTPException: 500 if there's an error publishing the checkpoint - - Example: - >>> await rest_client.publish_checkpoint_from_tinker_path_async("tinker://run-id/weights/0001") - >>> print("Checkpoint published successfully") - """ + """Async version of publish_checkpoint_from_tinker_path.""" parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path) await self._publish_checkpoint_submit( parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id @@ -534,9 +491,11 @@ class RestClient(TelemetryProvider): HTTPException: 500 if there's an error unpublishing the checkpoint Example: - >>> future = rest_client.unpublish_checkpoint_from_tinker_path("tinker://run-id/weights/0001") - >>> future.result() # Wait for completion - >>> print("Checkpoint unpublished successfully") + ```python + future = rest_client.unpublish_checkpoint_from_tinker_path("tinker://run-id/weights/0001") + future.result() # Wait for completion + print("Checkpoint unpublished successfully") + ``` """ parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path) return self._unpublish_checkpoint_submit( @@ -545,24 +504,7 @@ class RestClient(TelemetryProvider): @capture_exceptions(fatal=True) async def unpublish_checkpoint_from_tinker_path_async(self, tinker_path: str) -> None: - """Async version of unpublish_checkpoint_from_tinker_path. - - Only the exact owner of the training run can unpublish checkpoints. - This reverses the effect of publishing a checkpoint. 
- - Args: - tinker_path: The tinker path to the checkpoint (e.g., "tinker://run-id/weights/0001") - - Raises: - HTTPException: 400 if checkpoint identifier is invalid - HTTPException: 404 if checkpoint not found or user doesn't own the training run - HTTPException: 409 if checkpoint is already private - HTTPException: 500 if there's an error unpublishing the checkpoint - - Example: - >>> await rest_client.unpublish_checkpoint_from_tinker_path_async("tinker://run-id/weights/0001") - >>> print("Checkpoint unpublished successfully") - """ + """Async version of unpublish_checkpoint_from_tinker_path.""" parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path) await self._unpublish_checkpoint_submit( parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id @@ -607,15 +549,17 @@ class RestClient(TelemetryProvider): A Future containing the CheckpointsListResponse with checkpoints and cursor info Example: - >>> future = rest_client.list_user_checkpoints(limit=50) - >>> response = future.result() - >>> print(f"Found {len(response.checkpoints)} checkpoints") - >>> print(f"Total: {response.cursor.total_count if response.cursor else 'Unknown'}") - >>> for checkpoint in response.checkpoints: - ... print(f" {checkpoint.training_run_id}/{checkpoint.checkpoint_id}") - >>> # Get next page if there are more checkpoints - >>> if response.cursor and response.cursor.offset + response.cursor.limit < response.cursor.total_count: - ... next_page = rest_client.list_user_checkpoints(limit=50, offset=50) + ```python + future = rest_client.list_user_checkpoints(limit=50) + response = future.result() + print(f"Found {len(response.checkpoints)} checkpoints") + print(f"Total: {response.cursor.total_count if response.cursor else 'Unknown'}") + for checkpoint in response.checkpoints: + print(f" {checkpoint.training_run_id}/{checkpoint.checkpoint_id}") + # Get next page if there are more checkpoints + if response.cursor and response.cursor.offset + response.cursor.limit < response.cursor.total_count: + next_page = rest_client.list_user_checkpoints(limit=50, offset=50) + ``` """ return self._list_user_checkpoints_submit(limit, offset).future() @@ -623,29 +567,7 @@ class RestClient(TelemetryProvider): async def list_user_checkpoints_async( self, limit: int = 100, offset: int = 0 ) -> types.CheckpointsListResponse: - """Async version of list_user_checkpoints. - - This method retrieves checkpoints from all training runs owned by the authenticated user, - sorted by time (newest first). It supports pagination for efficiently handling large - numbers of checkpoints. - - Args: - limit: Maximum number of checkpoints to return (default 100) - offset: Offset for pagination (default 0) - - Returns: - CheckpointsListResponse with checkpoints and cursor info - - Example: - >>> response = await rest_client.list_user_checkpoints_async(limit=50) - >>> print(f"Found {len(response.checkpoints)} checkpoints") - >>> print(f"Total: {response.cursor.total_count if response.cursor else 'Unknown'}") - >>> for checkpoint in response.checkpoints: - ... print(f" {checkpoint.training_run_id}/{checkpoint.checkpoint_id}") - >>> # Get next page if there are more checkpoints - >>> if response.cursor and response.cursor.offset + response.cursor.limit < response.cursor.total_count: - ... 
next_page = await rest_client.list_user_checkpoints_async(limit=50, offset=50) - """ + """Async version of list_user_checkpoints.""" return await self._list_user_checkpoints_submit(limit, offset) def _get_session_submit( @@ -677,28 +599,18 @@ class RestClient(TelemetryProvider): A Future containing the GetSessionResponse with training_run_ids and sampler_ids Example: - >>> future = rest_client.get_session("session-id") - >>> response = future.result() - >>> print(f"Training runs: {len(response.training_run_ids)}") - >>> print(f"Samplers: {len(response.sampler_ids)}") + ```python + future = rest_client.get_session("session-id") + response = future.result() + print(f"Training runs: {len(response.training_run_ids)}") + print(f"Samplers: {len(response.sampler_ids)}") + ``` """ return self._get_session_submit(session_id).future() @capture_exceptions(fatal=True) async def get_session_async(self, session_id: str) -> types.GetSessionResponse: - """Async version of get_session. - - Args: - session_id: The session ID to get information for - - Returns: - GetSessionResponse with training_run_ids and sampler_ids - - Example: - >>> response = await rest_client.get_session_async("session-id") - >>> print(f"Training runs: {len(response.training_run_ids)}") - >>> print(f"Sampling sessions: {len(response.sampler_ids)}") - """ + """Async version of get_session.""" return await self._get_session_submit(session_id) def _list_sessions_submit( @@ -736,11 +648,13 @@ class RestClient(TelemetryProvider): A Future containing the ListSessionsResponse with list of session IDs Example: - >>> future = rest_client.list_sessions(limit=50) - >>> response = future.result() - >>> print(f"Found {len(response.sessions)} sessions") - >>> # Get next page - >>> next_page = rest_client.list_sessions(limit=50, offset=50) + ```python + future = rest_client.list_sessions(limit=50) + response = future.result() + print(f"Found {len(response.sessions)} sessions") + # Get next page + next_page = rest_client.list_sessions(limit=50, offset=50) + ``` """ return self._list_sessions_submit(limit, offset).future() @@ -748,19 +662,46 @@ class RestClient(TelemetryProvider): async def list_sessions_async( self, limit: int = 20, offset: int = 0 ) -> types.ListSessionsResponse: - """Async version of list_sessions. + """Async version of list_sessions.""" + return await self._list_sessions_submit(limit, offset) + + @capture_exceptions(fatal=True) + def get_sampler(self, sampler_id: str) -> APIFuture[types.GetSamplerResponse]: + """Get sampler information. 
Args: - limit: Maximum number of sessions to return (default 20) - offset: Offset for pagination (default 0) + sampler_id: The sampler ID (sampling_session_id) to get information for Returns: - ListSessionsResponse with list of session IDs + An APIFuture containing the GetSamplerResponse with sampler details Example: - >>> response = await rest_client.list_sessions_async(limit=50) - >>> print(f"Found {len(response.sessions)} sessions") - >>> # Get next page - >>> next_page = await rest_client.list_sessions_async(limit=50, offset=50) + ```python + # Sync usage + future = rest_client.get_sampler("session-id:sample:0") + response = future.result() + print(f"Base model: {response.base_model}") + print(f"Model path: {response.model_path}") + + # Async usage + response = await rest_client.get_sampler("session-id:sample:0") + print(f"Base model: {response.base_model}") + ``` """ - return await self._list_sessions_submit(limit, offset) + + async def _get_sampler_async() -> types.GetSamplerResponse: + async def _send_request() -> types.GetSamplerResponse: + with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client: + return await client.get( + f"/api/v1/samplers/{sampler_id}", + cast_to=types.GetSamplerResponse, + ) + + return await self.holder.execute_with_retries(_send_request) + + return self.holder.run_coroutine_threadsafe(_get_sampler_async()) + + @capture_exceptions(fatal=True) + async def get_sampler_async(self, sampler_id: str) -> types.GetSamplerResponse: + """Async version of get_sampler.""" + return await self.get_sampler(sampler_id) diff --git a/src/tinker/lib/public_interfaces/sampling_client.py b/src/tinker/lib/public_interfaces/sampling_client.py index 98ac78e..80b7e7f 100644 --- a/src/tinker/lib/public_interfaces/sampling_client.py +++ b/src/tinker/lib/public_interfaces/sampling_client.py @@ -47,11 +47,13 @@ class SamplingClient(TelemetryProvider, QueueStateObserver): retry_config: Configuration for retrying failed requests Example: - >>> sampling_client = service_client.create_sampling_client(base_model="Qwen/Qwen2.5-7B") - >>> prompt = types.ModelInput.from_ints(tokenizer.encode("The weather today is")) - >>> params = types.SamplingParams(max_tokens=20, temperature=0.7) - >>> future = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1) - >>> result = future.result() + ```python + sampling_client = service_client.create_sampling_client(base_model="Qwen/Qwen2.5-7B") + prompt = types.ModelInput.from_ints(tokenizer.encode("The weather today is")) + params = types.SamplingParams(max_tokens=20, temperature=0.7) + future = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1) + result = future.result() + ``` """ def __init__( @@ -88,8 +90,12 @@ class SamplingClient(TelemetryProvider, QueueStateObserver): retry_config: RetryConfig | None, ) -> SamplingClient: if sampling_session_id is None: - sampling_session_id = await holder._create_sampling_session(model_path=model_path, base_model=base_model) - return SamplingClient(holder, sampling_session_id=sampling_session_id, retry_config=retry_config) + sampling_session_id = await holder._create_sampling_session( + model_path=model_path, base_model=base_model + ) + return SamplingClient( + holder, sampling_session_id=sampling_session_id, retry_config=retry_config + ) @staticmethod def create( @@ -100,7 +106,15 @@ class SamplingClient(TelemetryProvider, QueueStateObserver): sampling_session_id: str | None = None, retry_config: RetryConfig | None = None, ) -> APIFuture[SamplingClient]: - 
return holder.run_coroutine_threadsafe(SamplingClient._create_impl(holder, model_path=model_path, base_model=base_model, sampling_session_id=sampling_session_id, retry_config=retry_config)) + return holder.run_coroutine_threadsafe( + SamplingClient._create_impl( + holder, + model_path=model_path, + base_model=base_model, + sampling_session_id=sampling_session_id, + retry_config=retry_config, + ) + ) async def _send_asample_request( self, @@ -182,7 +196,28 @@ class SamplingClient(TelemetryProvider, QueueStateObserver): include_prompt_logprobs: bool = False, topk_prompt_logprobs: int = 0, ) -> ConcurrentFuture[types.SampleResponse]: - """Internal method that does the actual API call without retry logic.""" + """Generate text completions from the model. + + Args: + prompt: The input tokens as ModelInput + num_samples: Number of independent samples to generate + sampling_params: Parameters controlling generation (temperature, max_tokens, etc.) + include_prompt_logprobs: Whether to include log probabilities for prompt tokens + topk_prompt_logprobs: Number of top token log probabilities to return per position + + Returns: + A Future containing the SampleResponse with generated text + + Example: + ```python + prompt = types.ModelInput.from_ints(tokenizer.encode("The weather today is")) + params = types.SamplingParams(max_tokens=20, temperature=0.7) + future = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1) + result = future.result() + for sample in result.samples: + print(tokenizer.decode(sample.tokens)) + ``` + """ async def _sample_async(): return await self._sample_async_impl( @@ -208,6 +243,7 @@ class SamplingClient(TelemetryProvider, QueueStateObserver): include_prompt_logprobs: bool = False, topk_prompt_logprobs: int = 0, ) -> types.SampleResponse: + """Async version of sample.""" return await AwaitableConcurrentFuture( self.sample( prompt, @@ -220,6 +256,26 @@ class SamplingClient(TelemetryProvider, QueueStateObserver): @capture_exceptions(fatal=True) def compute_logprobs(self, prompt: types.ModelInput) -> ConcurrentFuture[list[float | None]]: + """Compute log probabilities for prompt tokens. + + Args: + prompt: The input tokens as ModelInput + + Returns: + A Future containing a list of log probabilities for each token in the prompt. + None values indicate tokens where log probabilities couldn't be computed. + + Example: + ```python + prompt = types.ModelInput.from_ints(tokenizer.encode("Hello world")) + future = sampling_client.compute_logprobs(prompt) + logprobs = future.result() + for i, logprob in enumerate(logprobs): + if logprob is not None: + print(f"Token {i}: logprob = {logprob:.4f}") + ``` + """ + async def _compute_logprobs_async() -> list[float | None]: sample_res = await self._sample_async_impl( prompt, @@ -236,6 +292,7 @@ class SamplingClient(TelemetryProvider, QueueStateObserver): return self.holder.run_coroutine_threadsafe(_compute_logprobs_async_with_retries()).future() async def compute_logprobs_async(self, prompt: types.ModelInput) -> list[float | None]: + """Async version of compute_logprobs.""" return await AwaitableConcurrentFuture(self.compute_logprobs(prompt)) def get_telemetry(self) -> Telemetry | None: @@ -255,7 +312,9 @@ class SamplingClient(TelemetryProvider, QueueStateObserver): reason = "unknown" self._last_queue_state_logged = time.time() - logger.warning(f"Sampling is paused for sampler {self._sampling_session_id}. Reason: {reason}") + logger.warning( + f"Sampling is paused for sampler {self._sampling_session_id}. 
Reason: {reason}" + ) @lru_cache(maxsize=100) diff --git a/src/tinker/lib/public_interfaces/service_client.py b/src/tinker/lib/public_interfaces/service_client.py index 0cf5eaa..450949b 100644 --- a/src/tinker/lib/public_interfaces/service_client.py +++ b/src/tinker/lib/public_interfaces/service_client.py @@ -40,14 +40,16 @@ class ServiceClient(TelemetryProvider): including API keys, headers, and connection settings. Example: - >>> client = ServiceClient() - # ^^^ near-instant - >>> training_client = client.create_lora_training_client(base_model="Qwen/Qwen3-8B") - # ^^^ takes a moment as we initialize the model and assign resources - >>> sampling_client = client.create_sampling_client(base_model="Qwen/Qwen3-8B") - # ^^^ near-instant - >>> rest_client = client.create_rest_client() - # ^^^ near-instant + ```python + client = ServiceClient() + # ^^^ near-instant + training_client = client.create_lora_training_client(base_model="Qwen/Qwen3-8B") + # ^^^ takes a moment as we initialize the model and assign resources + sampling_client = client.create_sampling_client(base_model="Qwen/Qwen3-8B") + # ^^^ near-instant + rest_client = client.create_rest_client() + # ^^^ near-instant + ``` """ def __init__(self, user_metadata: dict[str, str] | None = None, **kwargs: Any): @@ -75,10 +77,23 @@ class ServiceClient(TelemetryProvider): @sync_only @capture_exceptions(fatal=True) def get_server_capabilities(self) -> types.GetServerCapabilitiesResponse: + """Query the server's supported features and capabilities. + + Returns: + GetServerCapabilitiesResponse with available models, features, and limits + + Example: + ```python + capabilities = service_client.get_server_capabilities() + print(f"Supported models: {capabilities.supported_models}") + print(f"Max batch size: {capabilities.max_batch_size}") + ``` + """ return self._get_server_capabilities_submit().result() @capture_exceptions(fatal=True) async def get_server_capabilities_async(self) -> types.GetServerCapabilitiesResponse: + """Async version of get_server_capabilities.""" return await self._get_server_capabilities_submit() def _create_lora_training_client_submit( @@ -142,6 +157,31 @@ class ServiceClient(TelemetryProvider): train_unembed: bool = True, user_metadata: dict[str, str] | None = None, ) -> TrainingClient: + """Create a TrainingClient for LoRA fine-tuning. + + Args: + base_model: Name of the base model to fine-tune (e.g., "Qwen/Qwen2.5-7B") + rank: LoRA rank controlling the size of adaptation matrices (default 32) + seed: Random seed for initialization. None means random seed. 
+ train_mlp: Whether to train MLP layers (default True) + train_attn: Whether to train attention layers (default True) + train_unembed: Whether to train unembedding layers (default True) + user_metadata: Optional metadata to attach to the training run + + Returns: + TrainingClient configured for LoRA training + + Example: + ```python + training_client = service_client.create_lora_training_client( + base_model="Qwen/Qwen2.5-7B", + rank=16, + train_mlp=True, + train_attn=True + ) + # Now use training_client.forward_backward() to train + ``` + """ return self._create_lora_training_client_submit( base_model, rank, @@ -163,6 +203,7 @@ class ServiceClient(TelemetryProvider): train_unembed: bool = True, user_metadata: dict[str, str] | None = None, ) -> TrainingClient: + """Async version of create_lora_training_client.""" return await self._create_lora_training_client_submit( base_model, rank, @@ -178,16 +219,32 @@ class ServiceClient(TelemetryProvider): def create_training_client_from_state( self, path: str, user_metadata: dict[str, str] | None = None ) -> TrainingClient: - rest_client = self.create_rest_client() - training_run = rest_client.get_training_run_by_tinker_path(path).result() + """Create a TrainingClient from saved model weights. - # Merge user metadata dicts - user_metdata = {**(training_run.user_metadata or {}), **(user_metadata or {})} + Args: + path: Tinker path to saved weights (e.g., "tinker://run-id/weights/checkpoint-001") + user_metadata: Optional metadata to attach to the new training run + + Returns: + TrainingClient loaded with the specified weights + + Example: + ```python + # Resume training from a checkpoint + training_client = service_client.create_training_client_from_state( + "tinker://run-id/weights/checkpoint-001" + ) + # Continue training from the loaded state + ``` + """ + rest_client = self.create_rest_client() + # Use weights info endpoint which allows access to models with public checkpoints + weights_info = rest_client.get_weights_info_by_tinker_path(path).result() training_client = self.create_lora_training_client( - base_model=training_run.base_model, - rank=training_run.lora_rank, - user_metadata=user_metdata, + base_model=weights_info.base_model, + rank=weights_info.lora_rank, + user_metadata=user_metadata, ) training_client.load_state(path).result() @@ -197,18 +254,18 @@ class ServiceClient(TelemetryProvider): async def create_training_client_from_state_async( self, path: str, user_metadata: dict[str, str] | None = None ) -> TrainingClient: + """Async version of create_training_client_from_state.""" rest_client = self.create_rest_client() - training_run = await rest_client.get_training_run_by_tinker_path_async(path) + # Use weights info endpoint which allows access to models with public checkpoints + weights_info = await rest_client.get_weights_info_by_tinker_path(path) # Right now all training runs are LoRa runs. 
- assert training_run.is_lora and training_run.lora_rank is not None - # Merge user metadata dicts - user_metdata = {**(training_run.user_metadata or {}), **(user_metadata or {})} + assert weights_info.is_lora and weights_info.lora_rank is not None training_client = await self.create_lora_training_client_async( - base_model=training_run.base_model, - rank=training_run.lora_rank, - user_metadata=user_metdata, + base_model=weights_info.base_model, + rank=weights_info.lora_rank, + user_metadata=user_metadata, ) load_future = await training_client.load_state_async(path) @@ -222,6 +279,32 @@ class ServiceClient(TelemetryProvider): base_model: str | None = None, retry_config: RetryConfig | None = None, ) -> SamplingClient: + """Create a SamplingClient for text generation. + + Args: + model_path: Path to saved model weights (e.g., "tinker://run-id/weights/checkpoint-001") + base_model: Name of base model to use (e.g., "Qwen/Qwen2.5-7B") + retry_config: Optional configuration for retrying failed requests + + Returns: + SamplingClient configured for text generation + + Raises: + ValueError: If neither model_path nor base_model is provided + + Example: + ```python + # Use a base model + sampling_client = service_client.create_sampling_client( + base_model="Qwen/Qwen2.5-7B" + ) + + # Or use saved weights + sampling_client = service_client.create_sampling_client( + model_path="tinker://run-id/weights/checkpoint-001" + ) + ``` + """ from .sampling_client import SamplingClient if model_path is None and base_model is None: @@ -240,6 +323,7 @@ class ServiceClient(TelemetryProvider): base_model: str | None = None, retry_config: RetryConfig | None = None, ) -> SamplingClient: + """Async version of create_sampling_client.""" from .sampling_client import SamplingClient if model_path is None and base_model is None: @@ -255,12 +339,27 @@ class ServiceClient(TelemetryProvider): def create_rest_client(self) -> RestClient: """Create a RestClient for REST API operations. + The RestClient provides access to various REST endpoints for querying + model information, checkpoints, sessions, and managing checkpoint visibility. + Returns: - RestClient: A client for listing weights and other REST operations + RestClient for accessing REST API endpoints Example: - >>> rest_client = service_client.create_rest_client() - >>> weights = rest_client.list_model_weights("my-model-id").result() + ```python + rest_client = service_client.create_rest_client() + + # List checkpoints for a training run + checkpoints = rest_client.list_checkpoints("run-id").result() + + # Get training run info + training_run = rest_client.get_training_run("run-id").result() + + # Publish a checkpoint + rest_client.publish_checkpoint_from_tinker_path( + "tinker://run-id/weights/checkpoint-001" + ).result() + ``` """ from .rest_client import RestClient diff --git a/src/tinker/lib/public_interfaces/training_client.py b/src/tinker/lib/public_interfaces/training_client.py index a8e7128..e367756 100644 --- a/src/tinker/lib/public_interfaces/training_client.py +++ b/src/tinker/lib/public_interfaces/training_client.py @@ -62,12 +62,14 @@ class TrainingClient(TelemetryProvider, QueueStateObserver): model_id: Unique identifier for the model to train. Required for training operations. 
Example: - >>> training_client = service_client.create_lora_training_client(base_model="Qwen/Qwen2.5-7B") - >>> fwdbwd_future = training_client.forward_backward(training_data, "cross_entropy") - >>> optim_future = training_client.optim_step(types.AdamParams(learning_rate=1e-4)) - >>> fwdbwd_result = fwdbwd_future.result() # Wait for gradients - >>> optim_result = optim_future.result() # Wait for parameter update - >>> sampling_client = training_client.save_weights_and_get_sampling_client("my-model") + ```python + training_client = service_client.create_lora_training_client(base_model="Qwen/Qwen2.5-7B") + fwdbwd_future = training_client.forward_backward(training_data, "cross_entropy") + optim_future = training_client.optim_step(types.AdamParams(learning_rate=1e-4)) + fwdbwd_result = fwdbwd_future.result() # Wait for gradients + optim_result = optim_future.result() # Wait for parameter update + sampling_client = training_client.save_weights_and_get_sampling_client("my-model") + ``` """ def __init__(self, holder: InternalClientHolder, model_seq_id: int, model_id: types.ModelID): @@ -175,6 +177,27 @@ class TrainingClient(TelemetryProvider, QueueStateObserver): loss_fn: types.LossFnType, loss_fn_config: Dict[str, float] | None = None, ) -> APIFuture[types.ForwardBackwardOutput]: + """Compute forward pass without gradients. + + Args: + data: List of training data samples + loss_fn: Loss function type (e.g., "cross_entropy") + loss_fn_config: Optional configuration for the loss function + + Returns: + APIFuture containing the forward pass outputs and loss + + Example: + ```python + data = [types.Datum( + model_input=types.ModelInput.from_ints(tokenizer.encode("Hello")), + loss_fn_inputs={"target_tokens": types.ModelInput.from_ints(tokenizer.encode("world"))} + )] + future = training_client.forward(data, "cross_entropy") + result = await future + print(f"Loss: {result.loss}") + ``` + """ requests = self._chunked_requests(data) @capture_exceptions(fatal=True) @@ -205,6 +228,7 @@ class TrainingClient(TelemetryProvider, QueueStateObserver): loss_fn: types.LossFnType, loss_fn_config: Dict[str, float] | None = None, ) -> APIFuture[types.ForwardBackwardOutput]: + """Async version of forward.""" return self.forward(data, loss_fn, loss_fn_config) async def _send_single_forward_backward_request( @@ -233,6 +257,35 @@ class TrainingClient(TelemetryProvider, QueueStateObserver): loss_fn: types.LossFnType, loss_fn_config: Dict[str, float] | None = None, ) -> APIFuture[types.ForwardBackwardOutput]: + """Compute forward pass and backward pass to calculate gradients. 
+ + Args: + data: List of training data samples + loss_fn: Loss function type (e.g., "cross_entropy") + loss_fn_config: Optional configuration for the loss function + + Returns: + APIFuture containing the forward/backward outputs, loss, and gradients + + Example: + ```python + data = [types.Datum( + model_input=types.ModelInput.from_ints(tokenizer.encode("Hello")), + loss_fn_inputs={"target_tokens": types.ModelInput.from_ints(tokenizer.encode("world"))} + )] + + # Compute gradients + fwdbwd_future = training_client.forward_backward(data, "cross_entropy") + + # Update parameters + optim_future = training_client.optim_step( + types.AdamParams(learning_rate=1e-4) + ) + + fwdbwd_result = await fwdbwd_future + print(f"Loss: {fwdbwd_result.loss}") + ``` + """ requests = self._chunked_requests(data) @capture_exceptions(fatal=True) @@ -269,6 +322,7 @@ class TrainingClient(TelemetryProvider, QueueStateObserver): loss_fn: types.LossFnType, loss_fn_config: Dict[str, float] | None = None, ) -> APIFuture[types.ForwardBackwardOutput]: + """Async version of forward_backward.""" return self.forward_backward(data, loss_fn, loss_fn_config) @sync_only @@ -276,7 +330,32 @@ class TrainingClient(TelemetryProvider, QueueStateObserver): def forward_backward_custom( self, data: List[types.Datum], loss_fn: CustomLossFnV1 ) -> APIFuture[types.ForwardBackwardOutput]: - """Synchronous version of forward_backward_custom_async.""" + """Compute forward/backward with a custom loss function. + + Allows you to define custom loss functions that operate on log probabilities. + The custom function receives logprobs and computes loss and gradients. + + Args: + data: List of training data samples + loss_fn: Custom loss function that takes (data, logprobs) and returns (loss, metrics) + + Returns: + APIFuture containing the forward/backward outputs with custom loss + + Example: + ```python + def custom_loss(data, logprobs_list): + # Custom loss computation + loss = torch.mean(torch.stack([torch.mean(lp) for lp in logprobs_list])) + metrics = {"custom_metric": loss.item()} + return loss, metrics + + future = training_client.forward_backward_custom(data, custom_loss) + result = future.result() + print(f"Custom loss: {result.loss}") + print(f"Metrics: {result.metrics}") + ``` + """ return self.holder.run_coroutine_threadsafe( self.forward_backward_custom_async(data, loss_fn) ).result() @@ -285,6 +364,7 @@ class TrainingClient(TelemetryProvider, QueueStateObserver): async def forward_backward_custom_async( self, data: List[types.Datum], loss_fn: CustomLossFnV1 ) -> APIFuture[types.ForwardBackwardOutput]: + """Async version of forward_backward_custom.""" import torch # First do a forward pass and get logprobs @@ -333,6 +413,32 @@ class TrainingClient(TelemetryProvider, QueueStateObserver): @capture_exceptions(fatal=True) def optim_step(self, adam_params: types.AdamParams) -> APIFuture[types.OptimStepResponse]: + """Update model parameters using Adam optimizer. 
+ + Args: + adam_params: Adam optimizer parameters (learning_rate, betas, eps, weight_decay) + + Returns: + APIFuture containing optimizer step response + + Example: + ```python + # First compute gradients + fwdbwd_future = training_client.forward_backward(data, "cross_entropy") + + # Then update parameters + optim_future = training_client.optim_step( + types.AdamParams( + learning_rate=1e-4, + weight_decay=0.01 + ) + ) + + # Wait for both to complete + fwdbwd_result = await fwdbwd_future + optim_result = await optim_future + ``` + """ request_id = self._get_request_id() @capture_exceptions(fatal=True) @@ -366,10 +472,27 @@ class TrainingClient(TelemetryProvider, QueueStateObserver): async def optim_step_async( self, adam_params: types.AdamParams ) -> APIFuture[types.OptimStepResponse]: + """Async version of optim_step.""" return self.optim_step(adam_params) @capture_exceptions(fatal=True) def save_state(self, name: str) -> APIFuture[types.SaveWeightsResponse]: + """Save model weights to persistent storage. + + Args: + name: Name for the saved checkpoint + + Returns: + APIFuture containing the save response with checkpoint path + + Example: + ```python + # Save after training + save_future = training_client.save_state("checkpoint-001") + result = await save_future + print(f"Saved to: {result.path}") + ``` + """ request_id = self._get_request_id() @capture_exceptions(fatal=True) @@ -401,43 +524,92 @@ class TrainingClient(TelemetryProvider, QueueStateObserver): return self.holder.run_coroutine_threadsafe(_save_state_async()) async def save_state_async(self, name: str) -> APIFuture[types.SaveWeightsResponse]: + """Async version of save_state.""" return self.save_state(name) @capture_exceptions(fatal=True) - def load_state(self, path: str) -> APIFuture[types.LoadWeightsResponse]: - request_id = self._get_request_id() + async def _load_state_impl( + self, request_id: int, path: str, optimizer: bool + ) -> types.LoadWeightsResponse: + start_time = time.time() - @capture_exceptions(fatal=True) - async def _load_state_async(): - start_time = time.time() - - async def _send_request(): - request = types.LoadWeightsRequest( - model_id=self._guaranteed_model_id(), - path=path, - seq_id=request_id + 1, - ) - with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client: - return await client.weights.load( - request=request, - ) - - async with self._take_turn(request_id): - future = await self.holder.execute_with_retries(_send_request) - return await _APIFuture( - types.LoadWeightsResponse, - self.holder, - future, - request_start_time=start_time, - request_type="LoadWeights", - queue_state_observer=self, + async def _send_request(): + request = types.LoadWeightsRequest( + model_id=self._guaranteed_model_id(), + path=path, + seq_id=request_id + 1, + optimizer=optimizer, ) + with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client: + return await client.weights.load( + request=request, + ) - return self.holder.run_coroutine_threadsafe(_load_state_async()) + async with self._take_turn(request_id): + future = await self.holder.execute_with_retries(_send_request) + return await _APIFuture( + types.LoadWeightsResponse, + self.holder, + future, + request_start_time=start_time, + request_type="LoadWeights", + queue_state_observer=self, + ) + + @capture_exceptions(fatal=True) + def load_state(self, path: str) -> APIFuture[types.LoadWeightsResponse]: + """Load model weights from a saved checkpoint. 
+ + Args: + path: Tinker path to saved weights (e.g., "tinker://run-id/weights/checkpoint-001") + + Returns: + APIFuture containing the load response + + Example: + ```python + # Load checkpoint to continue training + load_future = training_client.load_state("tinker://run-id/weights/checkpoint-001") + await load_future + # Continue training from loaded state + ``` + """ + request_id = self._get_request_id() + return self.holder.run_coroutine_threadsafe(self._load_state_impl(request_id, path, False)) async def load_state_async(self, path: str) -> APIFuture[types.LoadWeightsResponse]: + """Async version of load_state.""" return self.load_state(path) + @capture_exceptions(fatal=True) + def load_state_with_optimizer(self, path: str) -> APIFuture[types.LoadWeightsResponse]: + """Load model weights and optimizer state from a checkpoint. + + Args: + path: Tinker path to saved weights (e.g., "tinker://run-id/weights/checkpoint-001") + + Returns: + APIFuture containing the load response + + Example: + ```python + # Resume training with optimizer state + load_future = training_client.load_state_with_optimizer( + "tinker://run-id/weights/checkpoint-001" + ) + await load_future + # Continue training with restored optimizer momentum + ``` + """ + request_id = self._get_request_id() + return self.holder.run_coroutine_threadsafe(self._load_state_impl(request_id, path, True)) + + async def load_state_with_optimizer_async( + self, path: str + ) -> APIFuture[types.LoadWeightsResponse]: + """Async version of load_state_with_optimizer.""" + return self.load_state_with_optimizer(path) + @capture_exceptions(fatal=True) async def _save_weights_for_sampler_impl( self, request_id: int, name: str | None @@ -478,6 +650,27 @@ class TrainingClient(TelemetryProvider, QueueStateObserver): @capture_exceptions(fatal=True) def save_weights_for_sampler(self, name: str) -> APIFuture[types.SaveWeightsForSamplerResponse]: + """Save model weights for use with a SamplingClient. + + Args: + name: Name for the saved sampler weights + + Returns: + APIFuture containing the save response with sampler path + + Example: + ```python + # Save weights for inference + save_future = training_client.save_weights_for_sampler("sampler-001") + result = await save_future + print(f"Sampler weights saved to: {result.path}") + + # Use the path to create a sampling client + sampling_client = service_client.create_sampling_client( + model_path=result.path + ) + ``` + """ request_id = self._get_request_id() async def _save_weights_for_sampler_async(): @@ -490,6 +683,7 @@ class TrainingClient(TelemetryProvider, QueueStateObserver): async def save_weights_for_sampler_async( self, name: str ) -> APIFuture[types.SaveWeightsForSamplerResponse]: + """Async version of save_weights_for_sampler.""" return self.save_weights_for_sampler(name) def _get_info_submit(self) -> AwaitableConcurrentFuture[types.GetInfoResponse]: @@ -508,28 +702,76 @@ class TrainingClient(TelemetryProvider, QueueStateObserver): @sync_only @capture_exceptions(fatal=True) def get_info(self) -> types.GetInfoResponse: + """Get information about the current model. 
+ + Returns: + GetInfoResponse with model configuration and metadata + + Example: + ```python + info = training_client.get_info() + print(f"Model ID: {info.model_data.model_id}") + print(f"Base model: {info.model_data.model_name}") + print(f"LoRA rank: {info.model_data.lora_rank}") + ``` + """ return self._get_info_submit().result() @capture_exceptions(fatal=True) async def get_info_async(self) -> types.GetInfoResponse: + """Async version of get_info.""" return await self._get_info_submit() @cache @capture_exceptions(fatal=True) def get_tokenizer(self) -> PreTrainedTokenizer: + """Get the tokenizer for the current model. + + Returns: + PreTrainedTokenizer compatible with the model + + Example: + ```python + tokenizer = training_client.get_tokenizer() + tokens = tokenizer.encode("Hello world") + text = tokenizer.decode(tokens) + ``` + """ return _get_tokenizer(self._guaranteed_model_id(), self.holder) @capture_exceptions(fatal=True) def create_sampling_client( self, model_path: str, retry_config: RetryConfig | None = None ) -> SamplingClient: - return SamplingClient.create(self.holder, model_path=model_path, retry_config=retry_config).result() + """Create a SamplingClient from saved weights. + + Args: + model_path: Tinker path to saved weights + retry_config: Optional configuration for retrying failed requests + + Returns: + SamplingClient configured with the specified weights + + Example: + ```python + sampling_client = training_client.create_sampling_client( + "tinker://run-id/weights/checkpoint-001" + ) + # Use sampling_client for inference + ``` + """ + return SamplingClient.create( + self.holder, model_path=model_path, retry_config=retry_config + ).result() @capture_exceptions(fatal=True) async def create_sampling_client_async( self, model_path: str, retry_config: RetryConfig | None = None ) -> SamplingClient: - return await SamplingClient.create(self.holder, model_path=model_path, retry_config=retry_config) + """Async version of create_sampling_client.""" + return await SamplingClient.create( + self.holder, model_path=model_path, retry_config=retry_config + ) def save_weights_and_get_sampling_client_submit( self, retry_config: RetryConfig | None = None @@ -552,6 +794,26 @@ class TrainingClient(TelemetryProvider, QueueStateObserver): def save_weights_and_get_sampling_client( self, name: str | None = None, retry_config: RetryConfig | None = None ) -> SamplingClient: + """Save current weights and create a SamplingClient for inference. 
+ + Args: + name: Optional name for the saved weights (currently ignored for ephemeral saves) + retry_config: Optional configuration for retrying failed requests + + Returns: + SamplingClient configured with the current model weights + + Example: + ```python + # After training, create a sampling client directly + sampling_client = training_client.save_weights_and_get_sampling_client() + + # Now use it for inference + prompt = types.ModelInput.from_ints(tokenizer.encode("Hello")) + params = types.SamplingParams(max_tokens=20) + result = sampling_client.sample(prompt, 1, params).result() + ``` + """ # Ignore name argument for ephemeral save weights for sampler _ = name return self.save_weights_and_get_sampling_client_submit(retry_config).result() @@ -560,6 +822,7 @@ class TrainingClient(TelemetryProvider, QueueStateObserver): async def save_weights_and_get_sampling_client_async( self, name: str | None = None, retry_config: RetryConfig | None = None ) -> SamplingClient: + """Async version of save_weights_and_get_sampling_client.""" # Ignore name argument for ephemeral save weights for sampler _ = name return await self.save_weights_and_get_sampling_client_submit(retry_config) @@ -587,9 +850,11 @@ class TrainingClient(TelemetryProvider, QueueStateObserver): def _get_tokenizer(model_id: types.ModelID, holder: InternalClientHolder) -> PreTrainedTokenizer: # call get_info on model_id from transformers.models.auto.tokenization_auto import AutoTokenizer + try: from tml_tokenizers import get_tinker_tokenizer except ImportError: + def get_tinker_tokenizer(model_id: str) -> PreTrainedTokenizer | None: return None diff --git a/src/tinker/types/__init__.py b/src/tinker/types/__init__.py index 48115b9..482ade5 100644 --- a/src/tinker/types/__init__.py +++ b/src/tinker/types/__init__.py @@ -36,15 +36,18 @@ from .future_retrieve_response import FutureRetrieveResponse as FutureRetrieveRe from .get_info_request import GetInfoRequest as GetInfoRequest from .get_info_response import GetInfoResponse as GetInfoResponse from .get_info_response import ModelData as ModelData +from .get_sampler_response import GetSamplerResponse as GetSamplerResponse from .get_session_response import GetSessionResponse as GetSessionResponse from .list_sessions_response import ListSessionsResponse as ListSessionsResponse from .get_server_capabilities_response import ( GetServerCapabilitiesResponse as GetServerCapabilitiesResponse, ) from .get_server_capabilities_response import SupportedModel as SupportedModel +from .get_session_response import GetSessionResponse as GetSessionResponse from .health_response import HealthResponse as HealthResponse from .image_asset_pointer_chunk import ImageAssetPointerChunk as ImageAssetPointerChunk from .image_chunk import ImageChunk as ImageChunk +from .list_sessions_response import ListSessionsResponse as ListSessionsResponse from .load_weights_request import LoadWeightsRequest as LoadWeightsRequest from .load_weights_response import LoadWeightsResponse as LoadWeightsResponse from .lora_config import LoraConfig as LoraConfig @@ -93,3 +96,4 @@ from .training_runs_response import TrainingRunsResponse as TrainingRunsResponse from .unhandled_exception_event import UnhandledExceptionEvent as UnhandledExceptionEvent from .unload_model_request import UnloadModelRequest as UnloadModelRequest from .unload_model_response import UnloadModelResponse as UnloadModelResponse +from .weights_info_response import WeightsInfoResponse as WeightsInfoResponse diff --git a/src/tinker/types/get_sampler_response.py 
b/src/tinker/types/get_sampler_response.py new file mode 100644 index 0000000..a73d9f8 --- /dev/null +++ b/src/tinker/types/get_sampler_response.py @@ -0,0 +1,14 @@ +from .._models import BaseModel + +__all__ = ["GetSamplerResponse"] + + +class GetSamplerResponse(BaseModel): + # The sampler ID (sampling_session_id) + sampler_id: str + + # The base model name + base_model: str + + # Optional model path + model_path: str | None = None diff --git a/src/tinker/types/image_chunk.py b/src/tinker/types/image_chunk.py index 6f5676c..5028393 100644 --- a/src/tinker/types/image_chunk.py +++ b/src/tinker/types/image_chunk.py @@ -25,6 +25,12 @@ class ImageChunk(StrictBase): width: int """Image width in pixels""" + expected_tokens: int | None = None + """Expected number of tokens this image represents. + This is only advisory: the tinker backend will compute the number of tokens + from the image, and we can fail requests quickly if the computed token count does not + match expected_tokens.""" + type: Literal["image"] = "image" @field_validator("data", mode="before") diff --git a/src/tinker/types/load_weights_request.py b/src/tinker/types/load_weights_request.py index 23ee524..2d4484c 100644 --- a/src/tinker/types/load_weights_request.py +++ b/src/tinker/types/load_weights_request.py @@ -14,6 +14,9 @@ class LoadWeightsRequest(StrictBase): path: str """A tinker URI for model weights at a specific step""" + optimizer: bool + """Whether to load optimizer state along with model weights""" + seq_id: Optional[int] = None type: Literal["load_weights"] = "load_weights" diff --git a/src/tinker/types/loss_fn_type.py b/src/tinker/types/loss_fn_type.py index a260c86..0955375 100644 --- a/src/tinker/types/loss_fn_type.py +++ b/src/tinker/types/loss_fn_type.py @@ -2,4 +2,4 @@ from typing_extensions import Literal, TypeAlias __all__ = ["LossFnType"] -LossFnType: TypeAlias = Literal["cross_entropy", "importance_sampling", "ppo"] +LossFnType: TypeAlias = Literal["cross_entropy", "importance_sampling", "ppo", "cispo", "dro"] diff --git a/src/tinker/types/weights_info_response.py b/src/tinker/types/weights_info_response.py new file mode 100644 index 0000000..c9b7de8 --- /dev/null +++ b/src/tinker/types/weights_info_response.py @@ -0,0 +1,13 @@ +from .._models import BaseModel + +__all__ = ["WeightsInfoResponse"] + + +class WeightsInfoResponse(BaseModel): + """Minimal information for loading public checkpoints.""" + + base_model: str + + is_lora: bool + + lora_rank: int | None = None diff --git a/tests/test_service_client.py b/tests/test_service_client.py new file mode 100644 index 0000000..4ab0cbc --- /dev/null +++ b/tests/test_service_client.py @@ -0,0 +1,189 @@ +"""Tests for ServiceClient create_training_client_from_state method.""" + +from __future__ import annotations + +import os + +import httpx +import pytest +from respx import MockRouter + +import tinker +from tinker import types + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") + + +@pytest.mark.respx(base_url=base_url) +async def test_create_training_client_from_state_async(respx_mock: MockRouter) -> None: + """Test create_training_client_from_state_async uses public endpoint.""" + tinker_path = "tinker://test-model-123/weights/checkpoint-001" + weights_info_response = types.WeightsInfoResponse( + base_model="meta-llama/Llama-3.2-1B", is_lora=True, lora_rank=32 + ) + + # Mock the get_weights_info endpoint call + respx_mock.post("/api/v1/weights_info").mock( + return_value=httpx.Response(200, json=weights_info_response.model_dump()) + ) + + # Mock the
create model call + respx_mock.post("/api/v1/models").mock( + return_value=httpx.Response(200, json={"model_id": "new-model-id"}) + ) + + # Mock the load state call + respx_mock.post("/api/v1/load_weights").mock(return_value=httpx.Response(200, json={})) + + service_client = tinker.ServiceClient(base_url=base_url) + training_client = await service_client.create_training_client_from_state_async(tinker_path) + + assert training_client is not None + assert training_client.model_id == "new-model-id" + + +@pytest.mark.respx(base_url=base_url) +async def test_create_training_client_from_state_async_with_user_metadata( + respx_mock: MockRouter, +) -> None: + """Test create_training_client_from_state_async preserves user metadata.""" + tinker_path = "tinker://test-model-123/weights/checkpoint-001" + user_metadata = {"key1": "value1", "key2": "value2"} + weights_info_response = types.WeightsInfoResponse( + base_model="meta-llama/Llama-3.2-1B", is_lora=True, lora_rank=32 + ) + + # Mock the get_weights_info endpoint call + respx_mock.post("/api/v1/weights_info").mock( + return_value=httpx.Response(200, json=weights_info_response.model_dump()) + ) + + # Mock the create model call + respx_mock.post("/api/v1/models").mock( + return_value=httpx.Response(200, json={"model_id": "new-model-id"}) + ) + + # Mock the load state call + respx_mock.post("/api/v1/load_weights").mock(return_value=httpx.Response(200, json={})) + + service_client = tinker.ServiceClient(base_url=base_url) + training_client = await service_client.create_training_client_from_state_async( + tinker_path, user_metadata=user_metadata + ) + + assert training_client is not None + # Verify user_metadata was passed through (we can't directly check it, but the call succeeded) + + +@pytest.mark.respx(base_url=base_url) +async def test_create_training_client_from_state_async_not_lora(respx_mock: MockRouter) -> None: + """Test create_training_client_from_state_async raises assertion for non-LoRA model.""" + tinker_path = "tinker://test-model-123/weights/checkpoint-001" + + # Mock WeightsInfo response with is_lora=False + weights_info_response = types.WeightsInfoResponse( + base_model="meta-llama/Llama-3.2-1B", is_lora=False, lora_rank=None + ) + + # Mock the get_weights_info endpoint call + respx_mock.post("/api/v1/weights_info").mock( + return_value=httpx.Response(200, json=weights_info_response.model_dump()) + ) + + service_client = tinker.ServiceClient(base_url=base_url) + + # Should raise AssertionError because is_lora=False or lora_rank=None + with pytest.raises(AssertionError): + await service_client.create_training_client_from_state_async(tinker_path) + + +@pytest.mark.respx(base_url=base_url) +async def test_create_training_client_from_state_async_uses_public_endpoint( + respx_mock: MockRouter, +) -> None: + """Test that create_training_client_from_state_async uses get_weights_info_by_tinker_path.""" + tinker_path = "tinker://test-model-123/weights/checkpoint-001" + + # Mock WeightsInfo response + weights_info_response = types.WeightsInfoResponse( + base_model="meta-llama/Llama-3.2-1B", is_lora=True, lora_rank=32 + ) + + # Mock the get_weights_info endpoint call (public endpoint) + info_lite_route = respx_mock.post("/api/v1/weights_info").mock( + return_value=httpx.Response(200, json=weights_info_response.model_dump()) + ) + + # Mock the create model call + respx_mock.post("/api/v1/models").mock( + return_value=httpx.Response(200, json={"model_id": "new-model-id"}) + ) + + # Mock the load state call + 
respx_mock.post("/api/v1/load_weights").mock(return_value=httpx.Response(200, json={})) + + service_client = tinker.ServiceClient(base_url=base_url) + await service_client.create_training_client_from_state_async(tinker_path) + + # Verify it uses the public endpoint (info_lite), not the full training run endpoint + assert info_lite_route.called + + +@pytest.mark.respx(base_url=base_url) +def test_create_training_client_from_state_sync(respx_mock: MockRouter) -> None: + """Test create_training_client_from_state (sync) uses public endpoint.""" + tinker_path = "tinker://test-model-123/weights/checkpoint-001" + weights_info_response = types.WeightsInfoResponse( + base_model="meta-llama/Llama-3.2-1B", is_lora=True, lora_rank=32 + ) + + # Mock the get_weights_info endpoint call + respx_mock.post("/api/v1/weights_info").mock( + return_value=httpx.Response(200, json=weights_info_response.model_dump()) + ) + + # Mock the create model call + respx_mock.post("/api/v1/models").mock( + return_value=httpx.Response(200, json={"model_id": "new-model-id"}) + ) + + # Mock the load state call + respx_mock.post("/api/v1/load_weights").mock(return_value=httpx.Response(200, json={})) + + service_client = tinker.ServiceClient(base_url=base_url) + training_client = service_client.create_training_client_from_state(tinker_path) + + assert training_client is not None + assert training_client.model_id == "new-model-id" + + +@pytest.mark.respx(base_url=base_url) +def test_create_training_client_from_state_sync_uses_public_endpoint( + respx_mock: MockRouter, +) -> None: + """Test that create_training_client_from_state (sync) uses get_weights_info_by_tinker_path.""" + tinker_path = "tinker://test-model-123/weights/checkpoint-001" + + # Mock WeightsInfo response + weights_info_response = types.WeightsInfoResponse( + base_model="meta-llama/Llama-3.2-1B", is_lora=True, lora_rank=32 + ) + + # Mock the get_weights_info endpoint call (public endpoint) + info_lite_route = respx_mock.post("/api/v1/weights_info").mock( + return_value=httpx.Response(200, json=weights_info_response.model_dump()) + ) + + # Mock the create model call + respx_mock.post("/api/v1/models").mock( + return_value=httpx.Response(200, json={"model_id": "new-model-id"}) + ) + + # Mock the load state call + respx_mock.post("/api/v1/load_weights").mock(return_value=httpx.Response(200, json={})) + + service_client = tinker.ServiceClient(base_url=base_url) + service_client.create_training_client_from_state(tinker_path) + + # Verify it uses the public endpoint (info_lite), not the full training run endpoint + assert info_lite_route.called
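The docstrings added in this change each show one call in isolation; the sketch below strings them together into a single resume-train-sample loop. It is illustrative only: the `tinker://` checkpoint path and the base model name are placeholders, and the `Datum` construction simply mirrors the `forward_backward` docstring example above rather than prescribing the exact `loss_fn_inputs` schema.

```python
# Minimal end-to-end sketch using the APIs documented in this change.
# The checkpoint path and base model below are placeholders, not real artifacts.
import tinker
from tinker import types

service_client = tinker.ServiceClient()
training_client = service_client.create_lora_training_client(base_model="Qwen/Qwen2.5-7B")

# Resume from an earlier checkpoint, restoring optimizer state as well
# (load_state_with_optimizer is the new variant introduced here).
training_client.load_state_with_optimizer(
    "tinker://run-id/weights/checkpoint-001"  # hypothetical checkpoint path
).result()

tokenizer = training_client.get_tokenizer()
data = [
    types.Datum(
        model_input=types.ModelInput.from_ints(tokenizer.encode("Hello")),
        loss_fn_inputs={
            "target_tokens": types.ModelInput.from_ints(tokenizer.encode("world"))
        },
    )
]

# One training step: compute gradients, then apply an Adam update.
fwdbwd_future = training_client.forward_backward(data, "cross_entropy")
optim_future = training_client.optim_step(types.AdamParams(learning_rate=1e-4))
fwdbwd_future.result()
optim_future.result()

# Save the updated weights and sample from them.
sampling_client = training_client.save_weights_and_get_sampling_client()
prompt = types.ModelInput.from_ints(tokenizer.encode("The weather today is"))
params = types.SamplingParams(max_tokens=20, temperature=0.7)
sample_future = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1)
for sample in sample_future.result().samples:
    print(tokenizer.decode(sample.tokens))
```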