From 5e51e3104380f8a47453b6ed80c7077684ede631 Mon Sep 17 00:00:00 2001
From: Daniel Xu
Date: Tue, 2 Dec 2025 04:15:53 +0000
Subject: [PATCH] Sync contents

---
 .sync_state                                   |  4 +-
 docs/api/serviceclient.md                     | 40 ++++++++++++
 docs/api/types.md                             | 27 ++------
 pyproject.toml                                |  2 +-
 .../lib/public_interfaces/service_client.py  | 62 +++++++++++++++++++
 5 files changed, 110 insertions(+), 25 deletions(-)

diff --git a/.sync_state b/.sync_state
index f388eee..a07e116 100644
--- a/.sync_state
+++ b/.sync_state
@@ -1,4 +1,4 @@
 {
-    "last_synced_sha": "52e233ae8c999937881c32b6b15606de6b391789",
-    "last_sync_time": "2025-12-02T02:40:27.052745"
+    "last_synced_sha": "4431cf66bde9b717e16d5c9af17c5083e183ac09",
+    "last_sync_time": "2025-12-02T04:15:53.354450"
 }
\ No newline at end of file
diff --git a/docs/api/serviceclient.md b/docs/api/serviceclient.md
index 0895c2a..ac1365c 100644
--- a/docs/api/serviceclient.md
+++ b/docs/api/serviceclient.md
@@ -147,6 +147,46 @@ async def create_training_client_from_state_async(
     path: str,
     user_metadata: dict[str, str] | None = None) -> TrainingClient
 ```
 
 Async version of create_training_client_from_state.
 
+#### `create_training_client_from_state_with_optimizer`
+
+```python
+def create_training_client_from_state_with_optimizer(
+    path: str,
+    user_metadata: dict[str, str] | None = None) -> TrainingClient
+```
+
+Create a TrainingClient from saved model weights and optimizer state.
+
+This is similar to create_training_client_from_state but also restores
+optimizer state (e.g., Adam momentum), which is useful for resuming
+training exactly where it left off.
+
+Args:
+- `path`: Tinker path to saved weights (e.g., "tinker://run-id/weights/checkpoint-001")
+- `user_metadata`: Optional metadata to attach to the new training run
+
+Returns:
+- `TrainingClient` loaded with the specified weights and optimizer state
+
+Example:
+```python
+# Resume training from a checkpoint with optimizer state
+training_client = service_client.create_training_client_from_state_with_optimizer(
+    "tinker://run-id/weights/checkpoint-001"
+)
+# Continue training with restored optimizer momentum
+```
+
+#### `create_training_client_from_state_with_optimizer_async`
+
+```python
+async def create_training_client_from_state_with_optimizer_async(
+    path: str,
+    user_metadata: dict[str, str] | None = None) -> TrainingClient
+```
+
+Async version of create_training_client_from_state_with_optimizer.
+
 #### `create_sampling_client`
 
 ```python
diff --git a/docs/api/types.md b/docs/api/types.md
index 05731a3..3d92d75 100644
--- a/docs/api/types.md
+++ b/docs/api/types.md
@@ -237,21 +237,16 @@ class ImageAssetPointerChunk(StrictBase)
 
 Image format
 
-#### `height`
-
-Image height in pixels
-
 #### `location`
 
 Path or URL to the image asset
 
-#### `tokens`
+#### `expected_tokens`
 
-Number of tokens this image represents
-
-#### `width`
-
-Image width in pixels
+Expected number of tokens this image represents.
+This is only advisory: the tinker backend will compute the number of tokens
+from the image, and we can fail requests quickly if the tokens does not
+match expected_tokens.
 
 ## `CheckpointsListResponse` Objects
@@ -605,18 +600,6 @@ Image data as bytes
 
 Image format
 
-#### `height`
-
-Image height in pixels
-
-#### `tokens`
-
-Number of tokens this image represents
-
-#### `width`
-
-Image width in pixels
-
 #### `expected_tokens`
 
 Expected number of tokens this image represents.
diff --git a/pyproject.toml b/pyproject.toml
index 5bc89bf..f19ba20 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "tinker"
-version = "0.6.0"
+version = "0.6.1"
 description = "The official Python SDK for the tinker API"
 readme = "README.md"
 license = "Apache-2.0"
diff --git a/src/tinker/lib/public_interfaces/service_client.py b/src/tinker/lib/public_interfaces/service_client.py
index e14774b..cb29abc 100644
--- a/src/tinker/lib/public_interfaces/service_client.py
+++ b/src/tinker/lib/public_interfaces/service_client.py
@@ -275,6 +275,68 @@ class ServiceClient(TelemetryProvider):
         await load_future.result_async()
         return training_client
 
+    @sync_only
+    @capture_exceptions(fatal=True)
+    def create_training_client_from_state_with_optimizer(
+        self, path: str, user_metadata: dict[str, str] | None = None
+    ) -> TrainingClient:
+        """Create a TrainingClient from saved model weights and optimizer state.
+
+        This is similar to create_training_client_from_state but also restores
+        optimizer state (e.g., Adam momentum), which is useful for resuming
+        training exactly where it left off.
+
+        Args:
+        - `path`: Tinker path to saved weights (e.g., "tinker://run-id/weights/checkpoint-001")
+        - `user_metadata`: Optional metadata to attach to the new training run
+
+        Returns:
+        - `TrainingClient` loaded with the specified weights and optimizer state
+
+        Example:
+        ```python
+        # Resume training from a checkpoint with optimizer state
+        training_client = service_client.create_training_client_from_state_with_optimizer(
+            "tinker://run-id/weights/checkpoint-001"
+        )
+        # Continue training with restored optimizer momentum
+        ```
+        """
+        rest_client = self.create_rest_client()
+        # Use weights info endpoint which allows access to models with public checkpoints
+        weights_info = rest_client.get_weights_info_by_tinker_path(path).result()
+
+        training_client = self.create_lora_training_client(
+            base_model=weights_info.base_model,
+            rank=weights_info.lora_rank,
+            user_metadata=user_metadata,
+        )
+
+        training_client.load_state_with_optimizer(path).result()
+        return training_client
+
+    @capture_exceptions(fatal=True)
+    async def create_training_client_from_state_with_optimizer_async(
+        self, path: str, user_metadata: dict[str, str] | None = None
+    ) -> TrainingClient:
+        """Async version of create_training_client_from_state_with_optimizer."""
+        rest_client = self.create_rest_client()
+        # Use weights info endpoint which allows access to models with public checkpoints
+        weights_info = await rest_client.get_weights_info_by_tinker_path(path)
+
+        # Right now all training runs are LoRa runs.
+        assert weights_info.is_lora and weights_info.lora_rank is not None
+
+        training_client = await self.create_lora_training_client_async(
+            base_model=weights_info.base_model,
+            rank=weights_info.lora_rank,
+            user_metadata=user_metadata,
+        )
+
+        load_future = await training_client.load_state_with_optimizer_async(path)
+        await load_future.result_async()
+        return training_client
+
     @capture_exceptions(fatal=True)
     def create_sampling_client(
         self,