mirror of
https://github.com/thinking-machines-lab/tinker.git
synced 2026-04-30 17:40:38 +00:00
Sync contents
This commit is contained in:
parent
3e4e4e3560
commit
951d660110
32 changed files with 3895 additions and 635 deletions
9
docs/api/_meta.json
Normal file
9
docs/api/_meta.json
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
{
|
||||
"serviceclient": "ServiceClient",
|
||||
"trainingclient": "TrainingClient",
|
||||
"samplingclient": "SamplingClient",
|
||||
"restclient": "RestClient",
|
||||
"apifuture": "APIFuture",
|
||||
"types": "Parameters",
|
||||
"exceptions": "Exceptions"
|
||||
}
|
||||
155
docs/api/apifuture.md
Normal file
155
docs/api/apifuture.md
Normal file
|
|
@ -0,0 +1,155 @@
|
|||
# `tinker.lib.public_interfaces.api_future`
|
||||
|
||||
API Future classes for handling async operations with retry logic.
|
||||
|
||||
## `APIFuture` Objects
|
||||
|
||||
```python
|
||||
class APIFuture(ABC, Generic[T])
|
||||
```
|
||||
|
||||
Abstract base class for futures that can be awaited or accessed synchronously.
|
||||
|
||||
APIFuture provides a unified interface for handling async operations that can
|
||||
be accessed both synchronously (via result()) and asynchronously (via await or result_async()).
|
||||
This allows for flexible usage patterns in both sync and async contexts.
|
||||
|
||||
The future can be awaited directly in async contexts:
|
||||
```python
|
||||
result = await api_future # Equivalent to await api_future.result_async()
|
||||
```
|
||||
|
||||
Or accessed synchronously:
|
||||
```python
|
||||
result = api_future.result() # Blocks until complete
|
||||
```
|
||||
|
||||
Args:
|
||||
T: The type of the result value
|
||||
|
||||
Example:
|
||||
```python
|
||||
# In async context
|
||||
future = training_client.forward_backward(data, "cross_entropy")
|
||||
result = await future # Or await future.result_async()
|
||||
|
||||
# In sync context
|
||||
future = training_client.forward_backward(data, "cross_entropy")
|
||||
result = future.result()
|
||||
```
|
||||
|
||||
#### `result_async`
|
||||
|
||||
```python
|
||||
@abstractmethod
|
||||
async def result_async(timeout: float | None = None) -> T
|
||||
```
|
||||
|
||||
Get the result asynchronously with optional timeout.
|
||||
|
||||
Args:
|
||||
timeout: Maximum time to wait in seconds. None means wait indefinitely.
|
||||
|
||||
Returns:
|
||||
The result value of type T
|
||||
|
||||
Raises:
|
||||
TimeoutError: If timeout is exceeded
|
||||
|
||||
#### `result`
|
||||
|
||||
```python
|
||||
@abstractmethod
|
||||
def result(timeout: float | None = None) -> T
|
||||
```
|
||||
|
||||
Get the result synchronously with optional timeout.
|
||||
|
||||
Args:
|
||||
timeout: Maximum time to wait in seconds. None means wait indefinitely.
|
||||
|
||||
Returns:
|
||||
The result value of type T
|
||||
|
||||
Raises:
|
||||
TimeoutError: If timeout is exceeded
|
||||
|
||||
## `AwaitableConcurrentFuture` Objects
|
||||
|
||||
```python
|
||||
class AwaitableConcurrentFuture(APIFuture[T])
|
||||
```
|
||||
|
||||
Implementation of APIFuture that wraps a concurrent.futures.Future.
|
||||
|
||||
This class bridges Python's concurrent.futures with asyncio, allowing a
|
||||
standard Future to be used in async contexts. It's commonly returned by
|
||||
Tinker API methods to provide both sync and async access patterns.
|
||||
|
||||
Args:
|
||||
future: A concurrent.futures.Future to wrap
|
||||
|
||||
Example:
|
||||
```python
|
||||
# Internal usage - typically you receive these from API methods
|
||||
concurrent_future = some_operation()
|
||||
api_future = AwaitableConcurrentFuture(concurrent_future)
|
||||
|
||||
# Can be used synchronously
|
||||
result = api_future.result()
|
||||
|
||||
# Or asynchronously
|
||||
result = await api_future
|
||||
```
|
||||
|
||||
#### `result`
|
||||
|
||||
```python
|
||||
def result(timeout: float | None = None) -> T
|
||||
```
|
||||
|
||||
Get the result synchronously with optional timeout.
|
||||
|
||||
Args:
|
||||
timeout: Maximum time to wait in seconds. None means wait indefinitely.
|
||||
|
||||
Returns:
|
||||
The result value of type T
|
||||
|
||||
Raises:
|
||||
TimeoutError: If timeout is exceeded
|
||||
Exception: Any exception raised by the underlying operation
|
||||
|
||||
Example:
|
||||
```python
|
||||
future = rest_client.get_training_run("run-id")
|
||||
result = future.result(timeout=30) # Wait up to 30 seconds
|
||||
```
|
||||
|
||||
#### `result_async`
|
||||
|
||||
```python
|
||||
async def result_async(timeout: float | None = None) -> T
|
||||
```
|
||||
|
||||
Async version of result.
|
||||
|
||||
#### `future`
|
||||
|
||||
```python
|
||||
def future() -> ConcurrentFuture[T]
|
||||
```
|
||||
|
||||
Get the underlying concurrent.futures.Future.
|
||||
|
||||
Returns:
|
||||
The wrapped ConcurrentFuture object
|
||||
|
||||
Example:
|
||||
```python
|
||||
api_future = rest_client.get_training_run("run-id")
|
||||
concurrent_future = api_future.future()
|
||||
# Can now use standard concurrent.futures methods
|
||||
if concurrent_future.done():
|
||||
result = concurrent_future.result()
|
||||
```
|
||||
132
docs/api/exceptions.md
Normal file
132
docs/api/exceptions.md
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
# `tinker._exceptions`
|
||||
|
||||
## `TinkerError` Objects
|
||||
|
||||
```python
|
||||
class TinkerError(Exception)
|
||||
```
|
||||
|
||||
Base exception for all Tinker-related errors.
|
||||
|
||||
## `APIError` Objects
|
||||
|
||||
```python
|
||||
class APIError(TinkerError)
|
||||
```
|
||||
|
||||
Base class for all API-related errors.
|
||||
|
||||
#### `body`
|
||||
|
||||
The API response body.
|
||||
|
||||
If the API responded with a valid JSON structure then this property will be the
|
||||
decoded result.
|
||||
|
||||
If it isn't a valid JSON structure then this will be the raw response.
|
||||
|
||||
If there was no response associated with this error then it will be `None`.
|
||||
|
||||
## `APIResponseValidationError` Objects
|
||||
|
||||
```python
|
||||
class APIResponseValidationError(APIError)
|
||||
```
|
||||
|
||||
Raised when API response doesn't match expected schema.
|
||||
|
||||
## `APIStatusError` Objects
|
||||
|
||||
```python
|
||||
class APIStatusError(APIError)
|
||||
```
|
||||
|
||||
Raised when an API response has a status code of 4xx or 5xx.
|
||||
|
||||
## `APIConnectionError` Objects
|
||||
|
||||
```python
|
||||
class APIConnectionError(APIError)
|
||||
```
|
||||
|
||||
Raised when a connection error occurs while making an API request.
|
||||
|
||||
## `APITimeoutError` Objects
|
||||
|
||||
```python
|
||||
class APITimeoutError(APIConnectionError)
|
||||
```
|
||||
|
||||
Raised when an API request times out.
|
||||
|
||||
## `BadRequestError` Objects
|
||||
|
||||
```python
|
||||
class BadRequestError(APIStatusError)
|
||||
```
|
||||
|
||||
HTTP 400: The request was invalid or malformed.
|
||||
|
||||
## `AuthenticationError` Objects
|
||||
|
||||
```python
|
||||
class AuthenticationError(APIStatusError)
|
||||
```
|
||||
|
||||
HTTP 401: Authentication credentials are missing or invalid.
|
||||
|
||||
## `PermissionDeniedError` Objects
|
||||
|
||||
```python
|
||||
class PermissionDeniedError(APIStatusError)
|
||||
```
|
||||
|
||||
HTTP 403: Insufficient permissions to access the resource.
|
||||
|
||||
## `NotFoundError` Objects
|
||||
|
||||
```python
|
||||
class NotFoundError(APIStatusError)
|
||||
```
|
||||
|
||||
HTTP 404: The requested resource was not found.
|
||||
|
||||
## `ConflictError` Objects
|
||||
|
||||
```python
|
||||
class ConflictError(APIStatusError)
|
||||
```
|
||||
|
||||
HTTP 409: The request conflicts with the current state of the resource.
|
||||
|
||||
## `UnprocessableEntityError` Objects
|
||||
|
||||
```python
|
||||
class UnprocessableEntityError(APIStatusError)
|
||||
```
|
||||
|
||||
HTTP 422: The request was well-formed but contains semantic errors.
|
||||
|
||||
## `RateLimitError` Objects
|
||||
|
||||
```python
|
||||
class RateLimitError(APIStatusError)
|
||||
```
|
||||
|
||||
HTTP 429: Too many requests, rate limit exceeded.
|
||||
|
||||
## `InternalServerError` Objects
|
||||
|
||||
```python
|
||||
class InternalServerError(APIStatusError)
|
||||
```
|
||||
|
||||
HTTP 500+: An error occurred on the server.
|
||||
|
||||
## `RequestFailedError` Objects
|
||||
|
||||
```python
|
||||
class RequestFailedError(TinkerError)
|
||||
```
|
||||
|
||||
Raised when an asynchronous request completes in a failed state.
|
||||
556
docs/api/restclient.md
Normal file
556
docs/api/restclient.md
Normal file
|
|
@ -0,0 +1,556 @@
|
|||
# `tinker.lib.public_interfaces.rest_client`
|
||||
|
||||
RestClient for Tinker API REST operations.
|
||||
|
||||
## `RestClient` Objects
|
||||
|
||||
```python
|
||||
class RestClient(TelemetryProvider)
|
||||
```
|
||||
|
||||
Client for REST API operations like listing checkpoints and metadata.
|
||||
|
||||
The RestClient provides access to various REST endpoints for querying
|
||||
model information, checkpoints, and other resources. You typically get one
|
||||
by calling `service_client.create_rest_client()`.
|
||||
|
||||
Key methods:
|
||||
- list_checkpoints() - list available model checkpoints (both training and sampler)
|
||||
- list_user_checkpoints() - list all checkpoints across all user's training runs
|
||||
- get_training_run() - get model information and metadata as ModelEntry
|
||||
- delete_checkpoint() - delete an existing checkpoint for a training run
|
||||
- get_checkpoint_archive_url() - get signed URL to download checkpoint archive
|
||||
- publish_checkpoint_from_tinker_path() - publish a checkpoint to make it public
|
||||
- unpublish_checkpoint_from_tinker_path() - unpublish a checkpoint to make it private
|
||||
|
||||
Args:
|
||||
holder: Internal client managing HTTP connections and async operations
|
||||
|
||||
Example:
|
||||
```python
|
||||
rest_client = service_client.create_rest_client()
|
||||
training_run = rest_client.get_training_run("run-id").result()
|
||||
print(f"Training Run: {training_run.training_run_id}, LoRA: {training_run.is_lora}")
|
||||
checkpoints = rest_client.list_checkpoints("run-id").result()
|
||||
print(f"Found {len(checkpoints.checkpoints)} checkpoints")
|
||||
for checkpoint in checkpoints.checkpoints:
|
||||
print(f" {checkpoint.checkpoint_type}: {checkpoint.checkpoint_id}")
|
||||
```
|
||||
|
||||
#### `get_training_run`
|
||||
|
||||
```python
|
||||
@sync_only
|
||||
@capture_exceptions(fatal=True)
|
||||
def get_training_run(
|
||||
training_run_id: types.ModelID) -> ConcurrentFuture[types.TrainingRun]
|
||||
```
|
||||
|
||||
Get training run info.
|
||||
|
||||
Args:
|
||||
training_run_id: The training run ID to get information for
|
||||
|
||||
Returns:
|
||||
A Future containing the training run information
|
||||
|
||||
Example:
|
||||
```python
|
||||
future = rest_client.get_training_run("run-id")
|
||||
response = future.result()
|
||||
print(f"Training Run ID: {response.training_run_id}, Base: {response.base_model}")
|
||||
```
|
||||
|
||||
#### `get_training_run_async`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
async def get_training_run_async(
|
||||
training_run_id: types.ModelID) -> types.TrainingRun
|
||||
```
|
||||
|
||||
Async version of get_training_run.
|
||||
|
||||
#### `get_training_run_by_tinker_path`
|
||||
|
||||
```python
|
||||
@sync_only
|
||||
@capture_exceptions(fatal=True)
|
||||
def get_training_run_by_tinker_path(
|
||||
tinker_path: str) -> ConcurrentFuture[types.TrainingRun]
|
||||
```
|
||||
|
||||
Get training run info.
|
||||
|
||||
Args:
|
||||
tinker_path: The tinker path to the checkpoint
|
||||
|
||||
Returns:
|
||||
A Future containing the training run information
|
||||
|
||||
Example:
|
||||
```python
|
||||
future = rest_client.get_training_run_by_tinker_path("tinker://run-id/weights/checkpoint-001")
|
||||
response = future.result()
|
||||
print(f"Training Run ID: {response.training_run_id}, Base: {response.base_model}")
|
||||
```
|
||||
|
||||
#### `get_training_run_by_tinker_path_async`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
async def get_training_run_by_tinker_path_async(
|
||||
tinker_path: str) -> types.TrainingRun
|
||||
```
|
||||
|
||||
Async version of get_training_run_by_tinker_path.
|
||||
|
||||
#### `get_weights_info_by_tinker_path`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
def get_weights_info_by_tinker_path(
|
||||
tinker_path: str) -> APIFuture[types.WeightsInfoResponse]
|
||||
```
|
||||
|
||||
Get checkpoint information from a tinker path.
|
||||
|
||||
Args:
|
||||
tinker_path: The tinker path to the checkpoint
|
||||
|
||||
Returns:
|
||||
An APIFuture containing the checkpoint information. The future is awaitable.
|
||||
|
||||
Example:
|
||||
```python
|
||||
future = rest_client.get_weights_info_by_tinker_path("tinker://run-id/weights/checkpoint-001")
|
||||
response = future.result() # or await future
|
||||
print(f"Base Model: {response.base_model}, LoRA Rank: {response.lora_rank}")
|
||||
```
|
||||
|
||||
#### `list_training_runs`
|
||||
|
||||
```python
|
||||
@sync_only
|
||||
@capture_exceptions(fatal=True)
|
||||
def list_training_runs(
|
||||
limit: int = 20,
|
||||
offset: int = 0) -> ConcurrentFuture[types.TrainingRunsResponse]
|
||||
```
|
||||
|
||||
List training runs with pagination support.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of training runs to return (default 20)
|
||||
offset: Offset for pagination (default 0)
|
||||
|
||||
Returns:
|
||||
A Future containing the TrainingRunsResponse with training runs and cursor info
|
||||
|
||||
Example:
|
||||
```python
|
||||
future = rest_client.list_training_runs(limit=50)
|
||||
response = future.result()
|
||||
print(f"Found {len(response.training_runs)} training runs")
|
||||
print(f"Total: {response.cursor.total_count}")
|
||||
# Get next page
|
||||
next_page = rest_client.list_training_runs(limit=50, offset=50)
|
||||
```
|
||||
|
||||
#### `list_training_runs_async`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
async def list_training_runs_async(limit: int = 20,
|
||||
offset: int = 0
|
||||
) -> types.TrainingRunsResponse
|
||||
```
|
||||
|
||||
Async version of list_training_runs.
|
||||
|
||||
#### `list_checkpoints`
|
||||
|
||||
```python
|
||||
@sync_only
|
||||
@capture_exceptions(fatal=True)
|
||||
def list_checkpoints(
|
||||
training_run_id: types.ModelID
|
||||
) -> ConcurrentFuture[types.CheckpointsListResponse]
|
||||
```
|
||||
|
||||
List available checkpoints (both training and sampler).
|
||||
|
||||
Args:
|
||||
training_run_id: The training run ID to list checkpoints for
|
||||
|
||||
Returns:
|
||||
A Future containing the CheckpointsListResponse with available checkpoints
|
||||
|
||||
Example:
|
||||
```python
|
||||
future = rest_client.list_checkpoints("run-id")
|
||||
response = future.result()
|
||||
for checkpoint in response.checkpoints:
|
||||
if checkpoint.checkpoint_type == "training":
|
||||
print(f"Training checkpoint: {checkpoint.checkpoint_id}")
|
||||
elif checkpoint.checkpoint_type == "sampler":
|
||||
print(f"Sampler checkpoint: {checkpoint.checkpoint_id}")
|
||||
```
|
||||
|
||||
#### `list_checkpoints_async`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
async def list_checkpoints_async(
|
||||
training_run_id: types.ModelID) -> types.CheckpointsListResponse
|
||||
```
|
||||
|
||||
Async version of list_checkpoints.
|
||||
|
||||
#### `get_checkpoint_archive_url`
|
||||
|
||||
```python
|
||||
@sync_only
|
||||
@capture_exceptions(fatal=True)
|
||||
def get_checkpoint_archive_url(
|
||||
training_run_id: types.ModelID, checkpoint_id: str
|
||||
) -> ConcurrentFuture[types.CheckpointArchiveUrlResponse]
|
||||
```
|
||||
|
||||
Get signed URL to download checkpoint archive.
|
||||
|
||||
Args:
|
||||
training_run_id: The training run ID to download weights for
|
||||
checkpoint_id: The checkpoint ID to download
|
||||
|
||||
Returns:
|
||||
A Future containing the CheckpointArchiveUrlResponse with signed URL and expiration
|
||||
|
||||
Example:
|
||||
```python
|
||||
future = rest_client.get_checkpoint_archive_url("run-id", "checkpoint-123")
|
||||
response = future.result()
|
||||
print(f"Download URL: {response.url}")
|
||||
print(f"Expires at: {response.expires_at}")
|
||||
# Use the URL to download the archive with your preferred HTTP client
|
||||
```
|
||||
|
||||
#### `get_checkpoint_archive_url_async`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
async def get_checkpoint_archive_url_async(
|
||||
training_run_id: types.ModelID,
|
||||
checkpoint_id: str) -> types.CheckpointArchiveUrlResponse
|
||||
```
|
||||
|
||||
Async version of get_checkpoint_archive_url.
|
||||
|
||||
#### `delete_checkpoint`
|
||||
|
||||
```python
|
||||
@sync_only
|
||||
@capture_exceptions(fatal=True)
|
||||
def delete_checkpoint(training_run_id: types.ModelID,
|
||||
checkpoint_id: str) -> ConcurrentFuture[None]
|
||||
```
|
||||
|
||||
Delete a checkpoint for a training run.
|
||||
|
||||
#### `delete_checkpoint_async`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
async def delete_checkpoint_async(training_run_id: types.ModelID,
|
||||
checkpoint_id: str) -> None
|
||||
```
|
||||
|
||||
Async version of delete_checkpoint.
|
||||
|
||||
#### `delete_checkpoint_from_tinker_path`
|
||||
|
||||
```python
|
||||
@sync_only
|
||||
@capture_exceptions(fatal=True)
|
||||
def delete_checkpoint_from_tinker_path(
|
||||
tinker_path: str) -> ConcurrentFuture[None]
|
||||
```
|
||||
|
||||
Delete a checkpoint referenced by a tinker path.
|
||||
|
||||
#### `delete_checkpoint_from_tinker_path_async`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
async def delete_checkpoint_from_tinker_path_async(tinker_path: str) -> None
|
||||
```
|
||||
|
||||
Async version of delete_checkpoint_from_tinker_path.
|
||||
|
||||
#### `get_checkpoint_archive_url_from_tinker_path`
|
||||
|
||||
```python
|
||||
@sync_only
|
||||
@capture_exceptions(fatal=True)
|
||||
def get_checkpoint_archive_url_from_tinker_path(
|
||||
tinker_path: str
|
||||
) -> ConcurrentFuture[types.CheckpointArchiveUrlResponse]
|
||||
```
|
||||
|
||||
Get signed URL to download checkpoint archive.
|
||||
|
||||
Args:
|
||||
tinker_path: The tinker path to the checkpoint
|
||||
|
||||
Returns:
|
||||
A Future containing the CheckpointArchiveUrlResponse with signed URL and expiration
|
||||
|
||||
#### `get_checkpoint_archive_url_from_tinker_path_async`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
async def get_checkpoint_archive_url_from_tinker_path_async(
|
||||
tinker_path: str) -> types.CheckpointArchiveUrlResponse
|
||||
```
|
||||
|
||||
Async version of get_checkpoint_archive_url_from_tinker_path.
|
||||
|
||||
#### `publish_checkpoint_from_tinker_path`
|
||||
|
||||
```python
|
||||
@sync_only
|
||||
@capture_exceptions(fatal=True)
|
||||
def publish_checkpoint_from_tinker_path(
|
||||
tinker_path: str) -> ConcurrentFuture[None]
|
||||
```
|
||||
|
||||
Publish a checkpoint referenced by a tinker path to make it publicly accessible.
|
||||
|
||||
Only the exact owner of the training run can publish checkpoints.
|
||||
Published checkpoints can be unpublished using the unpublish_checkpoint_from_tinker_path method.
|
||||
|
||||
Args:
|
||||
tinker_path: The tinker path to the checkpoint (e.g., "tinker://run-id/weights/0001")
|
||||
|
||||
Returns:
|
||||
A Future that completes when the checkpoint is published
|
||||
|
||||
Raises:
|
||||
HTTPException: 400 if checkpoint identifier is invalid
|
||||
HTTPException: 404 if checkpoint not found or user doesn't own the training run
|
||||
HTTPException: 409 if checkpoint is already public
|
||||
HTTPException: 500 if there's an error publishing the checkpoint
|
||||
|
||||
Example:
|
||||
```python
|
||||
future = rest_client.publish_checkpoint_from_tinker_path("tinker://run-id/weights/0001")
|
||||
future.result() # Wait for completion
|
||||
print("Checkpoint published successfully")
|
||||
```
|
||||
|
||||
#### `publish_checkpoint_from_tinker_path_async`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
async def publish_checkpoint_from_tinker_path_async(tinker_path: str) -> None
|
||||
```
|
||||
|
||||
Async version of publish_checkpoint_from_tinker_path.
|
||||
|
||||
#### `unpublish_checkpoint_from_tinker_path`
|
||||
|
||||
```python
|
||||
@sync_only
|
||||
@capture_exceptions(fatal=True)
|
||||
def unpublish_checkpoint_from_tinker_path(
|
||||
tinker_path: str) -> ConcurrentFuture[None]
|
||||
```
|
||||
|
||||
Unpublish a checkpoint referenced by a tinker path to make it private again.
|
||||
|
||||
Only the exact owner of the training run can unpublish checkpoints.
|
||||
This reverses the effect of publishing a checkpoint.
|
||||
|
||||
Args:
|
||||
tinker_path: The tinker path to the checkpoint (e.g., "tinker://run-id/weights/0001")
|
||||
|
||||
Returns:
|
||||
A Future that completes when the checkpoint is unpublished
|
||||
|
||||
Raises:
|
||||
HTTPException: 400 if checkpoint identifier is invalid
|
||||
HTTPException: 404 if checkpoint not found or user doesn't own the training run
|
||||
HTTPException: 409 if checkpoint is already private
|
||||
HTTPException: 500 if there's an error unpublishing the checkpoint
|
||||
|
||||
Example:
|
||||
```python
|
||||
future = rest_client.unpublish_checkpoint_from_tinker_path("tinker://run-id/weights/0001")
|
||||
future.result() # Wait for completion
|
||||
print("Checkpoint unpublished successfully")
|
||||
```
|
||||
|
||||
#### `unpublish_checkpoint_from_tinker_path_async`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
async def unpublish_checkpoint_from_tinker_path_async(
|
||||
tinker_path: str) -> None
|
||||
```
|
||||
|
||||
Async version of unpublish_checkpoint_from_tinker_path.
|
||||
|
||||
#### `list_user_checkpoints`
|
||||
|
||||
```python
|
||||
@sync_only
|
||||
@capture_exceptions(fatal=True)
|
||||
def list_user_checkpoints(
|
||||
limit: int = 100,
|
||||
offset: int = 0) -> ConcurrentFuture[types.CheckpointsListResponse]
|
||||
```
|
||||
|
||||
List all checkpoints for the current user across all their training runs.
|
||||
|
||||
This method retrieves checkpoints from all training runs owned by the authenticated user,
|
||||
sorted by time (newest first). It supports pagination for efficiently handling large
|
||||
numbers of checkpoints.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of checkpoints to return (default 100)
|
||||
offset: Offset for pagination (default 0)
|
||||
|
||||
Returns:
|
||||
A Future containing the CheckpointsListResponse with checkpoints and cursor info
|
||||
|
||||
Example:
|
||||
```python
|
||||
future = rest_client.list_user_checkpoints(limit=50)
|
||||
response = future.result()
|
||||
print(f"Found {len(response.checkpoints)} checkpoints")
|
||||
print(f"Total: {response.cursor.total_count if response.cursor else 'Unknown'}")
|
||||
for checkpoint in response.checkpoints:
|
||||
print(f" {checkpoint.training_run_id}/{checkpoint.checkpoint_id}")
|
||||
# Get next page if there are more checkpoints
|
||||
if response.cursor and response.cursor.offset + response.cursor.limit < response.cursor.total_count:
|
||||
next_page = rest_client.list_user_checkpoints(limit=50, offset=50)
|
||||
```
|
||||
|
||||
#### `list_user_checkpoints_async`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
async def list_user_checkpoints_async(limit: int = 100,
|
||||
offset: int = 0
|
||||
) -> types.CheckpointsListResponse
|
||||
```
|
||||
|
||||
Async version of list_user_checkpoints.
|
||||
|
||||
#### `get_session`
|
||||
|
||||
```python
|
||||
@sync_only
|
||||
@capture_exceptions(fatal=True)
|
||||
def get_session(session_id: str) -> ConcurrentFuture[types.GetSessionResponse]
|
||||
```
|
||||
|
||||
Get session information including all training runs and samplers.
|
||||
|
||||
Args:
|
||||
session_id: The session ID to get information for
|
||||
|
||||
Returns:
|
||||
A Future containing the GetSessionResponse with training_run_ids and sampler_ids
|
||||
|
||||
Example:
|
||||
```python
|
||||
future = rest_client.get_session("session-id")
|
||||
response = future.result()
|
||||
print(f"Training runs: {len(response.training_run_ids)}")
|
||||
print(f"Samplers: {len(response.sampler_ids)}")
|
||||
```
|
||||
|
||||
#### `get_session_async`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
async def get_session_async(session_id: str) -> types.GetSessionResponse
|
||||
```
|
||||
|
||||
Async version of get_session.
|
||||
|
||||
#### `list_sessions`
|
||||
|
||||
```python
|
||||
@sync_only
|
||||
@capture_exceptions(fatal=True)
|
||||
def list_sessions(
|
||||
limit: int = 20,
|
||||
offset: int = 0) -> ConcurrentFuture[types.ListSessionsResponse]
|
||||
```
|
||||
|
||||
List sessions with pagination support.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of sessions to return (default 20)
|
||||
offset: Offset for pagination (default 0)
|
||||
|
||||
Returns:
|
||||
A Future containing the ListSessionsResponse with list of session IDs
|
||||
|
||||
Example:
|
||||
```python
|
||||
future = rest_client.list_sessions(limit=50)
|
||||
response = future.result()
|
||||
print(f"Found {len(response.sessions)} sessions")
|
||||
# Get next page
|
||||
next_page = rest_client.list_sessions(limit=50, offset=50)
|
||||
```
|
||||
|
||||
#### `list_sessions_async`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
async def list_sessions_async(limit: int = 20,
|
||||
offset: int = 0) -> types.ListSessionsResponse
|
||||
```
|
||||
|
||||
Async version of list_sessions.
|
||||
|
||||
#### `get_sampler`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
def get_sampler(sampler_id: str) -> APIFuture[types.GetSamplerResponse]
|
||||
```
|
||||
|
||||
Get sampler information.
|
||||
|
||||
Args:
|
||||
sampler_id: The sampler ID (sampling_session_id) to get information for
|
||||
|
||||
Returns:
|
||||
An APIFuture containing the GetSamplerResponse with sampler details
|
||||
|
||||
Example:
|
||||
```python
|
||||
# Sync usage
|
||||
future = rest_client.get_sampler("session-id:sample:0")
|
||||
response = future.result()
|
||||
print(f"Base model: {response.base_model}")
|
||||
print(f"Model path: {response.model_path}")
|
||||
|
||||
# Async usage
|
||||
response = await rest_client.get_sampler("session-id:sample:0")
|
||||
print(f"Base model: {response.base_model}")
|
||||
```
|
||||
|
||||
#### `get_sampler_async`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
async def get_sampler_async(sampler_id: str) -> types.GetSamplerResponse
|
||||
```
|
||||
|
||||
Async version of get_sampler.
|
||||
116
docs/api/samplingclient.md
Normal file
116
docs/api/samplingclient.md
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
# `tinker.lib.public_interfaces.sampling_client`
|
||||
|
||||
SamplingClient for Tinker API.
|
||||
|
||||
## `SamplingClient` Objects
|
||||
|
||||
```python
|
||||
class SamplingClient(TelemetryProvider, QueueStateObserver)
|
||||
```
|
||||
|
||||
Client for text generation and inference from trained or base models.
|
||||
|
||||
The SamplingClient lets you generate text tokens from either a base model or from weights
|
||||
you've saved using a TrainingClient. You typically get one by calling
|
||||
`service_client.create_sampling_client()` or `training_client.save_weights_and_get_sampling_client()`.
|
||||
Key methods:
|
||||
- sample() - generate text completions with customizable parameters
|
||||
- compute_logprobs() - get log probabilities for prompt tokens
|
||||
|
||||
Args:
|
||||
holder: Internal client managing HTTP connections and async operations
|
||||
model_path: Path to saved model weights (starts with 'tinker://')
|
||||
base_model: Name of base model to use for inference
|
||||
retry_config: Configuration for retrying failed requests
|
||||
|
||||
Example:
|
||||
```python
|
||||
sampling_client = service_client.create_sampling_client(base_model="Qwen/Qwen2.5-7B")
|
||||
prompt = types.ModelInput.from_ints(tokenizer.encode("The weather today is"))
|
||||
params = types.SamplingParams(max_tokens=20, temperature=0.7)
|
||||
future = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1)
|
||||
result = future.result()
|
||||
```
|
||||
|
||||
#### `sample`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
def sample(
|
||||
prompt: types.ModelInput,
|
||||
num_samples: int,
|
||||
sampling_params: types.SamplingParams,
|
||||
include_prompt_logprobs: bool = False,
|
||||
topk_prompt_logprobs: int = 0
|
||||
) -> ConcurrentFuture[types.SampleResponse]
|
||||
```
|
||||
|
||||
Generate text completions from the model.
|
||||
|
||||
Args:
|
||||
prompt: The input tokens as ModelInput
|
||||
num_samples: Number of independent samples to generate
|
||||
sampling_params: Parameters controlling generation (temperature, max_tokens, etc.)
|
||||
include_prompt_logprobs: Whether to include log probabilities for prompt tokens
|
||||
topk_prompt_logprobs: Number of top token log probabilities to return per position
|
||||
|
||||
Returns:
|
||||
A Future containing the SampleResponse with generated text
|
||||
|
||||
Example:
|
||||
```python
|
||||
prompt = types.ModelInput.from_ints(tokenizer.encode("The weather today is"))
|
||||
params = types.SamplingParams(max_tokens=20, temperature=0.7)
|
||||
future = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1)
|
||||
result = future.result()
|
||||
for sample in result.samples:
|
||||
print(tokenizer.decode(sample.tokens))
|
||||
```
|
||||
|
||||
#### `sample_async`
|
||||
|
||||
```python
|
||||
async def sample_async(prompt: types.ModelInput,
|
||||
num_samples: int,
|
||||
sampling_params: types.SamplingParams,
|
||||
include_prompt_logprobs: bool = False,
|
||||
topk_prompt_logprobs: int = 0) -> types.SampleResponse
|
||||
```
|
||||
|
||||
Async version of sample.
|
||||
|
||||
#### `compute_logprobs`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
def compute_logprobs(
|
||||
prompt: types.ModelInput) -> ConcurrentFuture[list[float | None]]
|
||||
```
|
||||
|
||||
Compute log probabilities for prompt tokens.
|
||||
|
||||
Args:
|
||||
prompt: The input tokens as ModelInput
|
||||
|
||||
Returns:
|
||||
A Future containing a list of log probabilities for each token in the prompt.
|
||||
None values indicate tokens where log probabilities couldn't be computed.
|
||||
|
||||
Example:
|
||||
```python
|
||||
prompt = types.ModelInput.from_ints(tokenizer.encode("Hello world"))
|
||||
future = sampling_client.compute_logprobs(prompt)
|
||||
logprobs = future.result()
|
||||
for i, logprob in enumerate(logprobs):
|
||||
if logprob is not None:
|
||||
print(f"Token {i}: logprob = {logprob:.4f}")
|
||||
```
|
||||
|
||||
#### `compute_logprobs_async`
|
||||
|
||||
```python
|
||||
async def compute_logprobs_async(
|
||||
prompt: types.ModelInput) -> list[float | None]
|
||||
```
|
||||
|
||||
Async version of compute_logprobs.
|
||||
235
docs/api/serviceclient.md
Normal file
235
docs/api/serviceclient.md
Normal file
|
|
@ -0,0 +1,235 @@
|
|||
# `tinker.lib.public_interfaces.service_client`
|
||||
|
||||
ServiceClient for Tinker API.
|
||||
|
||||
## `ServiceClient` Objects
|
||||
|
||||
```python
|
||||
class ServiceClient(TelemetryProvider)
|
||||
```
|
||||
|
||||
The ServiceClient is the main entry point for the Tinker API. It provides methods to:
|
||||
- Query server capabilities and health status
|
||||
- Generate TrainingClient instances for model training workflows
|
||||
- Generate SamplingClient instances for text generation and inference
|
||||
- Generate RestClient instances for REST API operations like listing weights
|
||||
|
||||
Args:
|
||||
**kwargs: advanced options passed to the underlying HTTP client,
|
||||
including API keys, headers, and connection settings.
|
||||
|
||||
Example:
|
||||
```python
|
||||
client = ServiceClient()
|
||||
# ^^^ near-instant
|
||||
training_client = client.create_lora_training_client(base_model="Qwen/Qwen3-8B")
|
||||
# ^^^ takes a moment as we initialize the model and assign resources
|
||||
sampling_client = client.create_sampling_client(base_model="Qwen/Qwen3-8B")
|
||||
# ^^^ near-instant
|
||||
rest_client = client.create_rest_client()
|
||||
# ^^^ near-instant
|
||||
```
|
||||
|
||||
#### `get_server_capabilities`
|
||||
|
||||
```python
|
||||
@sync_only
|
||||
@capture_exceptions(fatal=True)
|
||||
def get_server_capabilities() -> types.GetServerCapabilitiesResponse
|
||||
```
|
||||
|
||||
Query the server's supported features and capabilities.
|
||||
|
||||
Returns:
|
||||
GetServerCapabilitiesResponse with available models, features, and limits
|
||||
|
||||
Example:
|
||||
```python
|
||||
capabilities = service_client.get_server_capabilities()
|
||||
print(f"Supported models: {capabilities.supported_models}")
|
||||
print(f"Max batch size: {capabilities.max_batch_size}")
|
||||
```
|
||||
|
||||
#### `get_server_capabilities_async`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
async def get_server_capabilities_async(
|
||||
) -> types.GetServerCapabilitiesResponse
|
||||
```
|
||||
|
||||
Async version of get_server_capabilities.
|
||||
|
||||
#### `create_lora_training_client`
|
||||
|
||||
```python
|
||||
@sync_only
|
||||
@capture_exceptions(fatal=True)
|
||||
def create_lora_training_client(
|
||||
base_model: str,
|
||||
rank: int = 32,
|
||||
seed: int | None = None,
|
||||
train_mlp: bool = True,
|
||||
train_attn: bool = True,
|
||||
train_unembed: bool = True,
|
||||
user_metadata: dict[str, str] | None = None) -> TrainingClient
|
||||
```
|
||||
|
||||
Create a TrainingClient for LoRA fine-tuning.
|
||||
|
||||
Args:
|
||||
base_model: Name of the base model to fine-tune (e.g., "Qwen/Qwen2.5-7B")
|
||||
rank: LoRA rank controlling the size of adaptation matrices (default 32)
|
||||
seed: Random seed for initialization. None means random seed.
|
||||
train_mlp: Whether to train MLP layers (default True)
|
||||
train_attn: Whether to train attention layers (default True)
|
||||
train_unembed: Whether to train unembedding layers (default True)
|
||||
user_metadata: Optional metadata to attach to the training run
|
||||
|
||||
Returns:
|
||||
TrainingClient configured for LoRA training
|
||||
|
||||
Example:
|
||||
```python
|
||||
training_client = service_client.create_lora_training_client(
|
||||
base_model="Qwen/Qwen2.5-7B",
|
||||
rank=16,
|
||||
train_mlp=True,
|
||||
train_attn=True
|
||||
)
|
||||
# Now use training_client.forward_backward() to train
|
||||
```
|
||||
|
||||
#### `create_lora_training_client_async`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
async def create_lora_training_client_async(
|
||||
base_model: str,
|
||||
rank: int = 32,
|
||||
seed: int | None = None,
|
||||
train_mlp: bool = True,
|
||||
train_attn: bool = True,
|
||||
train_unembed: bool = True,
|
||||
user_metadata: dict[str, str] | None = None) -> TrainingClient
|
||||
```
|
||||
|
||||
Async version of create_lora_training_client.
|
||||
|
||||
#### `create_training_client_from_state`
|
||||
|
||||
```python
|
||||
@sync_only
|
||||
@capture_exceptions(fatal=True)
|
||||
def create_training_client_from_state(
|
||||
path: str,
|
||||
user_metadata: dict[str, str] | None = None) -> TrainingClient
|
||||
```
|
||||
|
||||
Create a TrainingClient from saved model weights.
|
||||
|
||||
Args:
|
||||
path: Tinker path to saved weights (e.g., "tinker://run-id/weights/checkpoint-001")
|
||||
user_metadata: Optional metadata to attach to the new training run
|
||||
|
||||
Returns:
|
||||
TrainingClient loaded with the specified weights
|
||||
|
||||
Example:
|
||||
```python
|
||||
# Resume training from a checkpoint
|
||||
training_client = service_client.create_training_client_from_state(
|
||||
"tinker://run-id/weights/checkpoint-001"
|
||||
)
|
||||
# Continue training from the loaded state
|
||||
```
|
||||
|
||||
#### `create_training_client_from_state_async`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
async def create_training_client_from_state_async(
|
||||
path: str,
|
||||
user_metadata: dict[str, str] | None = None) -> TrainingClient
|
||||
```
|
||||
|
||||
Async version of create_training_client_from_state.
|
||||
|
||||
#### `create_sampling_client`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
def create_sampling_client(
|
||||
model_path: str | None = None,
|
||||
base_model: str | None = None,
|
||||
retry_config: RetryConfig | None = None) -> SamplingClient
|
||||
```
|
||||
|
||||
Create a SamplingClient for text generation.
|
||||
|
||||
Args:
|
||||
model_path: Path to saved model weights (e.g., "tinker://run-id/weights/checkpoint-001")
|
||||
base_model: Name of base model to use (e.g., "Qwen/Qwen2.5-7B")
|
||||
retry_config: Optional configuration for retrying failed requests
|
||||
|
||||
Returns:
|
||||
SamplingClient configured for text generation
|
||||
|
||||
Raises:
|
||||
ValueError: If neither model_path nor base_model is provided
|
||||
|
||||
Example:
|
||||
```python
|
||||
# Use a base model
|
||||
sampling_client = service_client.create_sampling_client(
|
||||
base_model="Qwen/Qwen2.5-7B"
|
||||
)
|
||||
|
||||
# Or use saved weights
|
||||
sampling_client = service_client.create_sampling_client(
|
||||
model_path="tinker://run-id/weights/checkpoint-001"
|
||||
)
|
||||
```
|
||||
|
||||
#### `create_sampling_client_async`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
async def create_sampling_client_async(
|
||||
model_path: str | None = None,
|
||||
base_model: str | None = None,
|
||||
retry_config: RetryConfig | None = None) -> SamplingClient
|
||||
```
|
||||
|
||||
Async version of create_sampling_client.
|
||||
|
||||
#### `create_rest_client`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
def create_rest_client() -> RestClient
|
||||
```
|
||||
|
||||
Create a RestClient for REST API operations.
|
||||
|
||||
The RestClient provides access to various REST endpoints for querying
|
||||
model information, checkpoints, sessions, and managing checkpoint visibility.
|
||||
|
||||
Returns:
|
||||
RestClient for accessing REST API endpoints
|
||||
|
||||
Example:
|
||||
```python
|
||||
rest_client = service_client.create_rest_client()
|
||||
|
||||
# List checkpoints for a training run
|
||||
checkpoints = rest_client.list_checkpoints("run-id").result()
|
||||
|
||||
# Get training run info
|
||||
training_run = rest_client.get_training_run("run-id").result()
|
||||
|
||||
# Publish a checkpoint
|
||||
rest_client.publish_checkpoint_from_tinker_path(
|
||||
"tinker://run-id/weights/checkpoint-001"
|
||||
).result()
|
||||
```
|
||||
480
docs/api/trainingclient.md
Normal file
480
docs/api/trainingclient.md
Normal file
|
|
@ -0,0 +1,480 @@
|
|||
# `tinker.lib.public_interfaces.training_client`
|
||||
|
||||
TrainingClient for Tinker API.
|
||||
|
||||
## `TrainingClient` Objects
|
||||
|
||||
```python
|
||||
class TrainingClient(TelemetryProvider, QueueStateObserver)
|
||||
```
|
||||
|
||||
Client for training ML models with forward/backward passes and optimization.
|
||||
|
||||
The TrainingClient corresponds to a fine-tuned model that you can train and sample from.
|
||||
You typically get one by calling `service_client.create_lora_training_client()`.
|
||||
Key methods:
|
||||
- forward_backward() - compute gradients for training
|
||||
- optim_step() - update model parameters with Adam optimizer
|
||||
- save_weights_and_get_sampling_client() - export trained model for inference
|
||||
|
||||
Args:
|
||||
holder: Internal client managing HTTP connections and async operations
|
||||
model_id: Unique identifier for the model to train. Required for training operations.
|
||||
|
||||
Example:
|
||||
```python
|
||||
training_client = service_client.create_lora_training_client(base_model="Qwen/Qwen2.5-7B")
|
||||
fwdbwd_future = training_client.forward_backward(training_data, "cross_entropy")
|
||||
optim_future = training_client.optim_step(types.AdamParams(learning_rate=1e-4))
|
||||
fwdbwd_result = fwdbwd_future.result() # Wait for gradients
|
||||
optim_result = optim_future.result() # Wait for parameter update
|
||||
sampling_client = training_client.save_weights_and_get_sampling_client("my-model")
|
||||
```
|
||||
|
||||
#### `forward`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
def forward(
|
||||
data: List[types.Datum],
|
||||
loss_fn: types.LossFnType,
|
||||
loss_fn_config: Dict[str, float] | None = None
|
||||
) -> APIFuture[types.ForwardBackwardOutput]
|
||||
```
|
||||
|
||||
Compute forward pass without gradients.
|
||||
|
||||
Args:
|
||||
data: List of training data samples
|
||||
loss_fn: Loss function type (e.g., "cross_entropy")
|
||||
loss_fn_config: Optional configuration for the loss function
|
||||
|
||||
Returns:
|
||||
APIFuture containing the forward pass outputs and loss
|
||||
|
||||
Example:
|
||||
```python
|
||||
data = [types.Datum(
|
||||
model_input=types.ModelInput.from_ints(tokenizer.encode("Hello")),
|
||||
loss_fn_inputs={"target_tokens": types.ModelInput.from_ints(tokenizer.encode("world"))}
|
||||
)]
|
||||
future = training_client.forward(data, "cross_entropy")
|
||||
result = await future
|
||||
print(f"Loss: {result.loss}")
|
||||
```
|
||||
|
||||
#### `forward_async`
|
||||
|
||||
```python
|
||||
async def forward_async(
|
||||
data: List[types.Datum],
|
||||
loss_fn: types.LossFnType,
|
||||
loss_fn_config: Dict[str, float] | None = None
|
||||
) -> APIFuture[types.ForwardBackwardOutput]
|
||||
```
|
||||
|
||||
Async version of forward.
|
||||
|
||||
#### `forward_backward`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
def forward_backward(
|
||||
data: List[types.Datum],
|
||||
loss_fn: types.LossFnType,
|
||||
loss_fn_config: Dict[str, float] | None = None
|
||||
) -> APIFuture[types.ForwardBackwardOutput]
|
||||
```
|
||||
|
||||
Compute forward pass and backward pass to calculate gradients.
|
||||
|
||||
Args:
|
||||
data: List of training data samples
|
||||
loss_fn: Loss function type (e.g., "cross_entropy")
|
||||
loss_fn_config: Optional configuration for the loss function
|
||||
|
||||
Returns:
|
||||
APIFuture containing the forward/backward outputs, loss, and gradients
|
||||
|
||||
Example:
|
||||
```python
|
||||
data = [types.Datum(
|
||||
model_input=types.ModelInput.from_ints(tokenizer.encode("Hello")),
|
||||
loss_fn_inputs={"target_tokens": types.ModelInput.from_ints(tokenizer.encode("world"))}
|
||||
)]
|
||||
|
||||
# Compute gradients
|
||||
fwdbwd_future = training_client.forward_backward(data, "cross_entropy")
|
||||
|
||||
# Update parameters
|
||||
optim_future = training_client.optim_step(
|
||||
types.AdamParams(learning_rate=1e-4)
|
||||
)
|
||||
|
||||
fwdbwd_result = await fwdbwd_future
|
||||
print(f"Loss: {fwdbwd_result.loss}")
|
||||
```
|
||||
|
||||
#### `forward_backward_async`
|
||||
|
||||
```python
|
||||
async def forward_backward_async(
|
||||
data: List[types.Datum],
|
||||
loss_fn: types.LossFnType,
|
||||
loss_fn_config: Dict[str, float] | None = None
|
||||
) -> APIFuture[types.ForwardBackwardOutput]
|
||||
```
|
||||
|
||||
Async version of forward_backward.
|
||||
|
||||
#### `forward_backward_custom`
|
||||
|
||||
```python
|
||||
@sync_only
|
||||
@capture_exceptions(fatal=True)
|
||||
def forward_backward_custom(
|
||||
data: List[types.Datum],
|
||||
loss_fn: CustomLossFnV1) -> APIFuture[types.ForwardBackwardOutput]
|
||||
```
|
||||
|
||||
Compute forward/backward with a custom loss function.
|
||||
|
||||
Allows you to define custom loss functions that operate on log probabilities.
|
||||
The custom function receives logprobs and computes loss and gradients.
|
||||
|
||||
Args:
|
||||
data: List of training data samples
|
||||
loss_fn: Custom loss function that takes (data, logprobs) and returns (loss, metrics)
|
||||
|
||||
Returns:
|
||||
APIFuture containing the forward/backward outputs with custom loss
|
||||
|
||||
Example:
|
||||
```python
|
||||
def custom_loss(data, logprobs_list):
|
||||
# Custom loss computation
|
||||
loss = torch.mean(torch.stack([torch.mean(lp) for lp in logprobs_list]))
|
||||
metrics = {"custom_metric": loss.item()}
|
||||
return loss, metrics
|
||||
|
||||
future = training_client.forward_backward_custom(data, custom_loss)
|
||||
result = future.result()
|
||||
print(f"Custom loss: {result.loss}")
|
||||
print(f"Metrics: {result.metrics}")
|
||||
```
|
||||
|
||||
#### `forward_backward_custom_async`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
async def forward_backward_custom_async(
|
||||
data: List[types.Datum],
|
||||
loss_fn: CustomLossFnV1) -> APIFuture[types.ForwardBackwardOutput]
|
||||
```
|
||||
|
||||
Async version of forward_backward_custom.
|
||||
|
||||
#### `optim_step`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
def optim_step(
|
||||
adam_params: types.AdamParams) -> APIFuture[types.OptimStepResponse]
|
||||
```
|
||||
|
||||
Update model parameters using Adam optimizer.
|
||||
|
||||
Args:
|
||||
adam_params: Adam optimizer parameters (learning_rate, beta1, beta2, eps, weight_decay)
|
||||
|
||||
Returns:
|
||||
APIFuture containing optimizer step response
|
||||
|
||||
Example:
|
||||
```python
|
||||
# First compute gradients
|
||||
fwdbwd_future = training_client.forward_backward(data, "cross_entropy")
|
||||
|
||||
# Then update parameters
|
||||
optim_future = training_client.optim_step(
|
||||
types.AdamParams(
|
||||
learning_rate=1e-4,
|
||||
weight_decay=0.01
|
||||
)
|
||||
)
|
||||
|
||||
# Wait for both to complete
|
||||
fwdbwd_result = await fwdbwd_future
|
||||
optim_result = await optim_future
|
||||
```
|
||||
|
||||
#### `optim_step_async`
|
||||
|
||||
```python
|
||||
async def optim_step_async(
|
||||
adam_params: types.AdamParams) -> APIFuture[types.OptimStepResponse]
|
||||
```
|
||||
|
||||
Async version of optim_step.
|
||||
|
||||
#### `save_state`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
def save_state(name: str) -> APIFuture[types.SaveWeightsResponse]
|
||||
```
|
||||
|
||||
Save model weights to persistent storage.
|
||||
|
||||
Args:
|
||||
name: Name for the saved checkpoint
|
||||
|
||||
Returns:
|
||||
APIFuture containing the save response with checkpoint path
|
||||
|
||||
Example:
|
||||
```python
|
||||
# Save after training
|
||||
save_future = training_client.save_state("checkpoint-001")
|
||||
result = await save_future
|
||||
print(f"Saved to: {result.path}")
|
||||
```
|
||||
|
||||
#### `save_state_async`
|
||||
|
||||
```python
|
||||
async def save_state_async(name: str) -> APIFuture[types.SaveWeightsResponse]
|
||||
```
|
||||
|
||||
Async version of save_state.
|
||||
|
||||
#### `load_state`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
def load_state(path: str) -> APIFuture[types.LoadWeightsResponse]
|
||||
```
|
||||
|
||||
Load model weights from a saved checkpoint.
|
||||
|
||||
Args:
|
||||
path: Tinker path to saved weights (e.g., "tinker://run-id/weights/checkpoint-001")
|
||||
|
||||
Returns:
|
||||
APIFuture containing the load response
|
||||
|
||||
Example:
|
||||
```python
|
||||
# Load checkpoint to continue training
|
||||
load_future = training_client.load_state("tinker://run-id/weights/checkpoint-001")
|
||||
await load_future
|
||||
# Continue training from loaded state
|
||||
```
|
||||
|
||||
#### `load_state_async`
|
||||
|
||||
```python
|
||||
async def load_state_async(path: str) -> APIFuture[types.LoadWeightsResponse]
|
||||
```
|
||||
|
||||
Async version of load_state.
|
||||
|
||||
#### `load_state_with_optimizer`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
def load_state_with_optimizer(
|
||||
path: str) -> APIFuture[types.LoadWeightsResponse]
|
||||
```
|
||||
|
||||
Load model weights and optimizer state from a checkpoint.
|
||||
|
||||
Args:
|
||||
path: Tinker path to saved weights (e.g., "tinker://run-id/weights/checkpoint-001")
|
||||
|
||||
Returns:
|
||||
APIFuture containing the load response
|
||||
|
||||
Example:
|
||||
```python
|
||||
# Resume training with optimizer state
|
||||
load_future = training_client.load_state_with_optimizer(
|
||||
"tinker://run-id/weights/checkpoint-001"
|
||||
)
|
||||
await load_future
|
||||
# Continue training with restored optimizer momentum
|
||||
```
|
||||
|
||||
#### `load_state_with_optimizer_async`
|
||||
|
||||
```python
|
||||
async def load_state_with_optimizer_async(
|
||||
path: str) -> APIFuture[types.LoadWeightsResponse]
|
||||
```
|
||||
|
||||
Async version of load_state_with_optimizer.
|
||||
|
||||
#### `save_weights_for_sampler`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
def save_weights_for_sampler(
|
||||
name: str) -> APIFuture[types.SaveWeightsForSamplerResponse]
|
||||
```
|
||||
|
||||
Save model weights for use with a SamplingClient.
|
||||
|
||||
Args:
|
||||
name: Name for the saved sampler weights
|
||||
|
||||
Returns:
|
||||
APIFuture containing the save response with sampler path
|
||||
|
||||
Example:
|
||||
```python
|
||||
# Save weights for inference
|
||||
save_future = training_client.save_weights_for_sampler("sampler-001")
|
||||
result = await save_future
|
||||
print(f"Sampler weights saved to: {result.path}")
|
||||
|
||||
# Use the path to create a sampling client
|
||||
sampling_client = service_client.create_sampling_client(
|
||||
model_path=result.path
|
||||
)
|
||||
```
|
||||
|
||||
#### `save_weights_for_sampler_async`
|
||||
|
||||
```python
|
||||
async def save_weights_for_sampler_async(
|
||||
name: str) -> APIFuture[types.SaveWeightsForSamplerResponse]
|
||||
```
|
||||
|
||||
Async version of save_weights_for_sampler.
|
||||
|
||||
#### `get_info`
|
||||
|
||||
```python
|
||||
@sync_only
|
||||
@capture_exceptions(fatal=True)
|
||||
def get_info() -> types.GetInfoResponse
|
||||
```
|
||||
|
||||
Get information about the current model.
|
||||
|
||||
Returns:
|
||||
GetInfoResponse with model configuration and metadata
|
||||
|
||||
Example:
|
||||
```python
|
||||
info = training_client.get_info()
|
||||
print(f"Model ID: {info.model_data.model_id}")
|
||||
print(f"Base model: {info.model_data.model_name}")
|
||||
print(f"LoRA rank: {info.model_data.lora_rank}")
|
||||
```
|
||||
|
||||
#### `get_info_async`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
async def get_info_async() -> types.GetInfoResponse
|
||||
```
|
||||
|
||||
Async version of get_info.
|
||||
|
||||
#### `get_tokenizer`
|
||||
|
||||
```python
|
||||
@cache
|
||||
@capture_exceptions(fatal=True)
|
||||
def get_tokenizer() -> PreTrainedTokenizer
|
||||
```
|
||||
|
||||
Get the tokenizer for the current model.
|
||||
|
||||
Returns:
|
||||
PreTrainedTokenizer compatible with the model
|
||||
|
||||
Example:
|
||||
```python
|
||||
tokenizer = training_client.get_tokenizer()
|
||||
tokens = tokenizer.encode("Hello world")
|
||||
text = tokenizer.decode(tokens)
|
||||
```
|
||||
|
||||
#### `create_sampling_client`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
def create_sampling_client(
|
||||
model_path: str,
|
||||
retry_config: RetryConfig | None = None) -> SamplingClient
|
||||
```
|
||||
|
||||
Create a SamplingClient from saved weights.
|
||||
|
||||
Args:
|
||||
model_path: Tinker path to saved weights
|
||||
retry_config: Optional configuration for retrying failed requests
|
||||
|
||||
Returns:
|
||||
SamplingClient configured with the specified weights
|
||||
|
||||
Example:
|
||||
```python
|
||||
sampling_client = training_client.create_sampling_client(
|
||||
"tinker://run-id/weights/checkpoint-001"
|
||||
)
|
||||
# Use sampling_client for inference
|
||||
```
|
||||
|
||||
#### `create_sampling_client_async`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
async def create_sampling_client_async(
|
||||
model_path: str,
|
||||
retry_config: RetryConfig | None = None) -> SamplingClient
|
||||
```
|
||||
|
||||
Async version of create_sampling_client.
|
||||
|
||||
#### `save_weights_and_get_sampling_client`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
def save_weights_and_get_sampling_client(
|
||||
name: str | None = None,
|
||||
retry_config: RetryConfig | None = None) -> SamplingClient
|
||||
```
|
||||
|
||||
Save current weights and create a SamplingClient for inference.
|
||||
|
||||
Args:
|
||||
name: Optional name for the saved weights (currently ignored for ephemeral saves)
|
||||
retry_config: Optional configuration for retrying failed requests
|
||||
|
||||
Returns:
|
||||
SamplingClient configured with the current model weights
|
||||
|
||||
Example:
|
||||
```python
|
||||
# After training, create a sampling client directly
|
||||
sampling_client = training_client.save_weights_and_get_sampling_client()
|
||||
|
||||
# Now use it for inference
|
||||
prompt = types.ModelInput.from_ints(tokenizer.encode("Hello"))
|
||||
params = types.SamplingParams(max_tokens=20)
|
||||
result = sampling_client.sample(prompt, 1, params).result()
|
||||
```
|
||||
|
||||
#### `save_weights_and_get_sampling_client_async`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
async def save_weights_and_get_sampling_client_async(
|
||||
name: str | None = None,
|
||||
retry_config: RetryConfig | None = None) -> SamplingClient
|
||||
```
|
||||
|
||||
Async version of save_weights_and_get_sampling_client.
|
||||
911
docs/api/types.md
Normal file
911
docs/api/types.md
Normal file
|
|
@ -0,0 +1,911 @@
|
|||
# `tinker.types.optim_step_request`
|
||||
|
||||
## `AdamParams` Objects
|
||||
|
||||
```python
|
||||
class AdamParams(StrictBase)
|
||||
```
|
||||
|
||||
#### `learning_rate`
|
||||
|
||||
Learning rate for the optimizer
|
||||
|
||||
#### `beta1`
|
||||
|
||||
Coefficient used for computing running averages of the gradient
|
||||
|
||||
#### `beta2`
|
||||
|
||||
Coefficient used for computing running averages of the squared gradient
|
||||
|
||||
#### `eps`
|
||||
|
||||
Term added to the denominator to improve numerical stability
|
||||
|
||||
# `tinker.types.optim_step_response`
|
||||
|
||||
## `OptimStepResponse` Objects
|
||||
|
||||
```python
|
||||
class OptimStepResponse(BaseModel)
|
||||
```
|
||||
|
||||
#### `metrics`
|
||||
|
||||
Optimization step metrics as key-value pairs
|
||||
|
||||
# `tinker.types.model_input`
|
||||
|
||||
## `ModelInput` Objects
|
||||
|
||||
```python
|
||||
class ModelInput(StrictBase)
|
||||
```
|
||||
|
||||
#### `chunks`
|
||||
|
||||
Sequence of input chunks (formerly TokenSequence)
|
||||
|
||||
#### `from_ints`
|
||||
|
||||
```python
|
||||
@classmethod
|
||||
def from_ints(cls, tokens: List[int]) -> "ModelInput"
|
||||
```
|
||||
|
||||
Create a ModelInput from a list of ints (tokens).
|
||||
|
||||
#### `to_ints`
|
||||
|
||||
```python
|
||||
def to_ints() -> List[int]
|
||||
```
|
||||
|
||||
Convert the ModelInput to a list of ints (tokens)
|
||||
Throws exception if there are any non-token chunks
|
||||
|
||||
#### `length`
|
||||
|
||||
```python
|
||||
@property
|
||||
def length() -> int
|
||||
```
|
||||
|
||||
Return the total context length used by this ModelInput.
|
||||
|
||||
#### `empty`
|
||||
|
||||
```python
|
||||
@classmethod
|
||||
def empty(cls) -> "ModelInput"
|
||||
```
|
||||
|
||||
Create an empty ModelInput.
|
||||
|
||||
#### `append`
|
||||
|
||||
```python
|
||||
def append(chunk: ModelInputChunk) -> "ModelInput"
|
||||
```
|
||||
|
||||
Add a new chunk, return a new ModelInput.
|
||||
|
||||
#### `append_int`
|
||||
|
||||
```python
|
||||
def append_int(token: int) -> "ModelInput"
|
||||
```
|
||||
|
||||
Add a new token, return a new ModelInput.
|
||||
|
||||
# `tinker.types.weights_info_response`
|
||||
|
||||
## `WeightsInfoResponse` Objects
|
||||
|
||||
```python
|
||||
class WeightsInfoResponse(BaseModel)
|
||||
```
|
||||
|
||||
Minimal information for loading public checkpoints.
|
||||
|
||||
# `tinker.types.checkpoint`
|
||||
|
||||
## `Checkpoint` Objects
|
||||
|
||||
```python
|
||||
class Checkpoint(BaseModel)
|
||||
```
|
||||
|
||||
#### `checkpoint_id`
|
||||
|
||||
The checkpoint ID
|
||||
|
||||
#### `checkpoint_type`
|
||||
|
||||
The type of checkpoint (training or sampler)
|
||||
|
||||
#### `time`
|
||||
|
||||
The time when the checkpoint was created
|
||||
|
||||
#### `tinker_path`
|
||||
|
||||
The tinker path to the checkpoint
|
||||
|
||||
#### `size_bytes`
|
||||
|
||||
The size of the checkpoint in bytes
|
||||
|
||||
#### `public`
|
||||
|
||||
Whether the checkpoint is publicly accessible
|
||||
|
||||
## `ParsedCheckpointTinkerPath` Objects
|
||||
|
||||
```python
|
||||
class ParsedCheckpointTinkerPath(BaseModel)
|
||||
```
|
||||
|
||||
#### `tinker_path`
|
||||
|
||||
The tinker path to the checkpoint
|
||||
|
||||
#### `training_run_id`
|
||||
|
||||
The training run ID
|
||||
|
||||
#### `checkpoint_type`
|
||||
|
||||
The type of checkpoint (training or sampler)
|
||||
|
||||
#### `checkpoint_id`
|
||||
|
||||
The checkpoint ID
|
||||
|
||||
#### `from_tinker_path`
|
||||
|
||||
```python
|
||||
@classmethod
|
||||
def from_tinker_path(cls, tinker_path: str) -> "ParsedCheckpointTinkerPath"
|
||||
```
|
||||
|
||||
Parse a tinker path to an instance of ParsedCheckpointTinkerPath
|
||||
|
||||
# `tinker.types.checkpoint_archive_url_response`
|
||||
|
||||
## `CheckpointArchiveUrlResponse` Objects
|
||||
|
||||
```python
|
||||
class CheckpointArchiveUrlResponse(BaseModel)
|
||||
```
|
||||
|
||||
#### `url`
|
||||
|
||||
Signed URL to download the checkpoint archive
|
||||
|
||||
#### `expires`
|
||||
|
||||
Unix timestamp when the signed URL expires, if available
|
||||
|
||||
# `tinker.types.sampled_sequence`
|
||||
|
||||
## `SampledSequence` Objects
|
||||
|
||||
```python
|
||||
class SampledSequence(BaseModel)
|
||||
```
|
||||
|
||||
#### `stop_reason`
|
||||
|
||||
Reason why sampling stopped
|
||||
|
||||
#### `tokens`
|
||||
|
||||
List of generated token IDs
|
||||
|
||||
#### `logprobs`
|
||||
|
||||
Log probabilities for each token (optional)
|
||||
|
||||
# `tinker.types.try_again_response`
|
||||
|
||||
## `TryAgainResponse` Objects
|
||||
|
||||
```python
|
||||
class TryAgainResponse(BaseModel)
|
||||
```
|
||||
|
||||
#### `request_id`
|
||||
|
||||
Request ID that is still pending
|
||||
|
||||
# `tinker.types.load_weights_request`
|
||||
|
||||
## `LoadWeightsRequest` Objects
|
||||
|
||||
```python
|
||||
class LoadWeightsRequest(StrictBase)
|
||||
```
|
||||
|
||||
#### `path`
|
||||
|
||||
A tinker URI for model weights at a specific step
|
||||
|
||||
#### `optimizer`
|
||||
|
||||
Whether to load optimizer state along with model weights
|
||||
|
||||
# `tinker.types.telemetry_send_request`
|
||||
|
||||
## `TelemetrySendRequest` Objects
|
||||
|
||||
```python
|
||||
class TelemetrySendRequest(StrictBase)
|
||||
```
|
||||
|
||||
#### `platform`
|
||||
|
||||
Host platform name
|
||||
|
||||
#### `sdk_version`
|
||||
|
||||
SDK version string
|
||||
|
||||
# `tinker.types.image_asset_pointer_chunk`
|
||||
|
||||
## `ImageAssetPointerChunk` Objects
|
||||
|
||||
```python
|
||||
class ImageAssetPointerChunk(StrictBase)
|
||||
```
|
||||
|
||||
#### `format`
|
||||
|
||||
Image format
|
||||
|
||||
#### `height`
|
||||
|
||||
Image height in pixels
|
||||
|
||||
#### `location`
|
||||
|
||||
Path or URL to the image asset
|
||||
|
||||
#### `tokens`
|
||||
|
||||
Number of tokens this image represents
|
||||
|
||||
#### `width`
|
||||
|
||||
Image width in pixels
|
||||
|
||||
# `tinker.types.checkpoints_list_response`
|
||||
|
||||
## `CheckpointsListResponse` Objects
|
||||
|
||||
```python
|
||||
class CheckpointsListResponse(BaseModel)
|
||||
```
|
||||
|
||||
#### `checkpoints`
|
||||
|
||||
List of available model checkpoints for the model
|
||||
|
||||
#### `cursor`
|
||||
|
||||
Pagination cursor information (None for unpaginated responses)
|
||||
|
||||
# `tinker.types.generic_event`
|
||||
|
||||
## `GenericEvent` Objects
|
||||
|
||||
```python
|
||||
class GenericEvent(BaseModel)
|
||||
```
|
||||
|
||||
#### `event`
|
||||
|
||||
Telemetry event type
|
||||
|
||||
#### `event_name`
|
||||
|
||||
Low-cardinality event name
|
||||
|
||||
#### `severity`
|
||||
|
||||
Log severity level
|
||||
|
||||
#### `event_data`
|
||||
|
||||
Arbitrary structured JSON payload
|
||||
|
||||
# `tinker.types.encoded_text_chunk`
|
||||
|
||||
## `EncodedTextChunk` Objects
|
||||
|
||||
```python
|
||||
class EncodedTextChunk(StrictBase)
|
||||
```
|
||||
|
||||
#### `tokens`
|
||||
|
||||
Array of token IDs
|
||||
|
||||
# `tinker.types.forward_backward_input`
|
||||
|
||||
## `ForwardBackwardInput` Objects
|
||||
|
||||
```python
|
||||
class ForwardBackwardInput(StrictBase)
|
||||
```
|
||||
|
||||
#### `data`
|
||||
|
||||
Array of input data for the forward/backward pass
|
||||
|
||||
#### `loss_fn`
|
||||
|
||||
Fully qualified function path for the loss function
|
||||
|
||||
#### `loss_fn_config`
|
||||
|
||||
Optional configuration parameters for the loss function (e.g., PPO clip thresholds, DPO beta)
|
||||
|
||||
# `tinker.types.session_start_event`
|
||||
|
||||
## `SessionStartEvent` Objects
|
||||
|
||||
```python
|
||||
class SessionStartEvent(BaseModel)
|
||||
```
|
||||
|
||||
#### `event`
|
||||
|
||||
Telemetry event type
|
||||
|
||||
#### `severity`
|
||||
|
||||
Log severity level
|
||||
|
||||
# `tinker.types.training_runs_response`
|
||||
|
||||
## `TrainingRunsResponse` Objects
|
||||
|
||||
```python
|
||||
class TrainingRunsResponse(BaseModel)
|
||||
```
|
||||
|
||||
#### `training_runs`
|
||||
|
||||
List of training runs
|
||||
|
||||
#### `cursor`
|
||||
|
||||
Pagination cursor information
|
||||
|
||||
# `tinker.types.save_weights_response`
|
||||
|
||||
## `SaveWeightsResponse` Objects
|
||||
|
||||
```python
|
||||
class SaveWeightsResponse(BaseModel)
|
||||
```
|
||||
|
||||
#### `path`
|
||||
|
||||
A tinker URI for model weights at a specific step
|
||||
|
||||
# `tinker.types.sample_request`
|
||||
|
||||
## `SampleRequest` Objects
|
||||
|
||||
```python
|
||||
class SampleRequest(StrictBase)
|
||||
```
|
||||
|
||||
#### `num_samples`
|
||||
|
||||
Number of samples to generate
|
||||
|
||||
#### `base_model`
|
||||
|
||||
Optional base model name to sample from.
|
||||
|
||||
If model_path is provided, this value is inferred from it. When sampling directly
against a base model (i.e., no model_path is given), this field is required.
|
||||
|
||||
#### `model_path`
|
||||
|
||||
Optional tinker:// path to your model weights or LoRA weights.
|
||||
|
||||
If not provided, samples against the base model.
|
||||
|
||||
#### `sampling_session_id`
|
||||
|
||||
Optional sampling session ID to use instead of model_path/base_model.
|
||||
|
||||
If provided along with seq_id, the model configuration will be loaded from the
|
||||
sampling session. This is useful for multi-turn conversations.
|
||||
|
||||
#### `seq_id`
|
||||
|
||||
Sequence ID within the sampling session.
|
||||
|
||||
Required when sampling_session_id is provided. Used to generate deterministic
|
||||
request IDs for the sampling request.
|
||||
|
||||
#### `prompt_logprobs`
|
||||
|
||||
If set to `true`, computes and returns logprobs on the prompt tokens.
|
||||
|
||||
Defaults to false.
|
||||
|
||||
#### `topk_prompt_logprobs`
|
||||
|
||||
If set to a positive integer, returns the top-k logprobs for each prompt token.
|
||||
|
||||
# `tinker.types.forward_backward_output`
|
||||
|
||||
## `ForwardBackwardOutput` Objects
|
||||
|
||||
```python
|
||||
class ForwardBackwardOutput(BaseModel)
|
||||
```
|
||||
|
||||
#### `loss_fn_output_type`
|
||||
|
||||
The type of the ForwardBackward output. Can be one of [...] TODO
|
||||
|
||||
#### `loss_fn_outputs`
|
||||
|
||||
Dictionary mapping field names to tensor data
|
||||
|
||||
#### `metrics`
|
||||
|
||||
Training metrics as key-value pairs
|
||||
|
||||
# `tinker.types.sample_response`
|
||||
|
||||
## `SampleResponse` Objects
|
||||
|
||||
```python
|
||||
class SampleResponse(BaseModel)
|
||||
```
|
||||
|
||||
#### `prompt_logprobs`
|
||||
|
||||
If prompt_logprobs was set to true in the request, logprobs are computed for
|
||||
every token in the prompt. The `prompt_logprobs` response contains a float32
|
||||
value for every token in the prompt.
|
||||
|
||||
#### `topk_prompt_logprobs`
|
||||
|
||||
If topk_prompt_logprobs was set to a positive integer k in the request,
|
||||
the top-k logprobs are computed for every token in the prompt. The
|
||||
`topk_prompt_logprobs` response contains, for every token in the prompt,
|
||||
a list of up to k (token_id, logprob) tuples.
|
||||
|
||||
# `tinker.types.create_sampling_session_response`
|
||||
|
||||
## `CreateSamplingSessionResponse` Objects
|
||||
|
||||
```python
|
||||
class CreateSamplingSessionResponse(BaseModel)
|
||||
```
|
||||
|
||||
#### `sampling_session_id`
|
||||
|
||||
The generated sampling session ID
|
||||
|
||||
# `tinker.types.cursor`
|
||||
|
||||
## `Cursor` Objects
|
||||
|
||||
```python
|
||||
class Cursor(BaseModel)
|
||||
```
|
||||
|
||||
#### `offset`
|
||||
|
||||
The offset used for pagination
|
||||
|
||||
#### `limit`
|
||||
|
||||
The maximum number of items requested
|
||||
|
||||
#### `total_count`
|
||||
|
||||
The total number of items available
|
||||
|
||||
# `tinker.types.create_model_request`
|
||||
|
||||
## `CreateModelRequest` Objects
|
||||
|
||||
```python
|
||||
class CreateModelRequest(StrictBase)
|
||||
```
|
||||
|
||||
#### `base_model`
|
||||
|
||||
Optional metadata about this model/training run, set by the end-user
|
||||
|
||||
# `tinker.types.datum`
|
||||
|
||||
## `Datum` Objects
|
||||
|
||||
```python
|
||||
class Datum(StrictBase)
|
||||
```
|
||||
|
||||
#### `loss_fn_inputs`
|
||||
|
||||
Dictionary mapping field names to tensor data
|
||||
|
||||
#### `convert_tensors`
|
||||
|
||||
```python
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def convert_tensors(cls, data: Any) -> Any
|
||||
```
|
||||
|
||||
Convert torch.Tensor and numpy arrays to TensorData in loss_fn_inputs during construction.
|
||||
|
||||
# `tinker.types.training_run`
|
||||
|
||||
## `TrainingRun` Objects
|
||||
|
||||
```python
|
||||
class TrainingRun(BaseModel)
|
||||
```
|
||||
|
||||
#### `training_run_id`
|
||||
|
||||
The unique identifier for the training run
|
||||
|
||||
#### `base_model`
|
||||
|
||||
The base model name this model is derived from
|
||||
|
||||
#### `model_owner`
|
||||
|
||||
The owner/creator of this model
|
||||
|
||||
#### `is_lora`
|
||||
|
||||
Whether this model uses LoRA (Low-Rank Adaptation)
|
||||
|
||||
#### `corrupted`
|
||||
|
||||
Whether the model is in a corrupted state
|
||||
|
||||
#### `lora_rank`
|
||||
|
||||
The LoRA rank if this is a LoRA model, null otherwise
|
||||
|
||||
#### `last_request_time`
|
||||
|
||||
The timestamp of the last request made to this model
|
||||
|
||||
#### `last_checkpoint`
|
||||
|
||||
The most recent training checkpoint, if available
|
||||
|
||||
#### `last_sampler_checkpoint`
|
||||
|
||||
The most recent sampler checkpoint, if available
|
||||
|
||||
#### `user_metadata`
|
||||
|
||||
Optional metadata about this training run, set by the end-user
|
||||
|
||||
# `tinker.types.session_end_event`
|
||||
|
||||
## `SessionEndEvent` Objects
|
||||
|
||||
```python
|
||||
class SessionEndEvent(BaseModel)
|
||||
```
|
||||
|
||||
#### `duration`
|
||||
|
||||
ISO 8601 duration string
|
||||
|
||||
#### `event`
|
||||
|
||||
Telemetry event type
|
||||
|
||||
#### `severity`
|
||||
|
||||
Log severity level
|
||||
|
||||
# `tinker.types.telemetry_batch`
|
||||
|
||||
## `TelemetryBatch` Objects
|
||||
|
||||
```python
|
||||
class TelemetryBatch(BaseModel)
|
||||
```
|
||||
|
||||
#### `platform`
|
||||
|
||||
Host platform name
|
||||
|
||||
#### `sdk_version`
|
||||
|
||||
SDK version string
|
||||
|
||||
# `tinker.types.unhandled_exception_event`
|
||||
|
||||
## `UnhandledExceptionEvent` Objects
|
||||
|
||||
```python
|
||||
class UnhandledExceptionEvent(BaseModel)
|
||||
```
|
||||
|
||||
#### `event`
|
||||
|
||||
Telemetry event type
|
||||
|
||||
#### `severity`
|
||||
|
||||
Log severity level
|
||||
|
||||
#### `traceback`
|
||||
|
||||
Optional Python traceback string
|
||||
|
||||
# `tinker.types.image_chunk`
|
||||
|
||||
## `ImageChunk` Objects
|
||||
|
||||
```python
|
||||
class ImageChunk(StrictBase)
|
||||
```
|
||||
|
||||
#### `data`
|
||||
|
||||
Image data as bytes
|
||||
|
||||
#### `format`
|
||||
|
||||
Image format
|
||||
|
||||
#### `height`
|
||||
|
||||
Image height in pixels
|
||||
|
||||
#### `tokens`
|
||||
|
||||
Number of tokens this image represents
|
||||
|
||||
#### `width`
|
||||
|
||||
Image width in pixels
|
||||
|
||||
#### `expected_tokens`
|
||||
|
||||
Expected number of tokens this image represents.
|
||||
This is only advisory: the tinker backend will compute the number of tokens
|
||||
from the image, and we can fail requests quickly if the tokens does not
|
||||
match expected_tokens.
|
||||
|
||||
#### `validate_data`
|
||||
|
||||
```python
|
||||
@field_validator("data", mode="before")
|
||||
@classmethod
|
||||
def validate_data(cls, value: Union[bytes, str]) -> bytes
|
||||
```
|
||||
|
||||
Deserialize base64 string to bytes if needed.
|
||||
|
||||
#### `serialize_data`
|
||||
|
||||
```python
|
||||
@field_serializer("data")
|
||||
def serialize_data(value: bytes) -> str
|
||||
```
|
||||
|
||||
Serialize bytes to base64 string for JSON.
|
||||
|
||||
# `tinker.types.save_weights_request`
|
||||
|
||||
## `SaveWeightsRequest` Objects
|
||||
|
||||
```python
|
||||
class SaveWeightsRequest(StrictBase)
|
||||
```
|
||||
|
||||
#### `path`
|
||||
|
||||
A file/directory name for the weights
|
||||
|
||||
# `tinker.types.lora_config`
|
||||
|
||||
## `LoraConfig` Objects
|
||||
|
||||
```python
|
||||
class LoraConfig(StrictBase)
|
||||
```
|
||||
|
||||
#### `rank`
|
||||
|
||||
LoRA rank (dimension of low-rank matrices)
|
||||
|
||||
#### `seed`
|
||||
|
||||
Seed used for initialization of LoRA weights.
|
||||
|
||||
Useful if you need deterministic or reproducible initialization of weights.
|
||||
|
||||
#### `train_unembed`
|
||||
|
||||
Whether to add lora to the unembedding layer
|
||||
|
||||
#### `train_mlp`
|
||||
|
||||
Whether to add loras to the MLP layers (including MoE layers)
|
||||
|
||||
#### `train_attn`
|
||||
|
||||
Whether to add loras to the attention layers
|
||||
|
||||
# `tinker.types.create_sampling_session_request`
|
||||
|
||||
## `CreateSamplingSessionRequest` Objects
|
||||
|
||||
```python
|
||||
class CreateSamplingSessionRequest(StrictBase)
|
||||
```
|
||||
|
||||
#### `session_id`
|
||||
|
||||
The session ID to create the sampling session within
|
||||
|
||||
#### `sampling_session_seq_id`
|
||||
|
||||
Sequence ID for the sampling session within the session
|
||||
|
||||
#### `base_model`
|
||||
|
||||
Optional base model name to sample from.
|
||||
|
||||
Is inferred from model_path, if provided. If sampling against a base model, this
|
||||
is required.
|
||||
|
||||
#### `model_path`
|
||||
|
||||
Optional tinker:// path to your model weights or LoRA weights.
|
||||
|
||||
If not provided, samples against the base model.
|
||||
|
||||
# `tinker.types.future_retrieve_request`
|
||||
|
||||
## `FutureRetrieveRequest` Objects
|
||||
|
||||
```python
|
||||
class FutureRetrieveRequest(StrictBase)
|
||||
```
|
||||
|
||||
#### `request_id`
|
||||
|
||||
The ID of the request to retrieve
|
||||
|
||||
# `tinker.types.tensor_data`
|
||||
|
||||
## `TensorData` Objects
|
||||
|
||||
```python
|
||||
class TensorData(StrictBase)
|
||||
```
|
||||
|
||||
#### `data`
|
||||
|
||||
Flattened tensor data as array of numbers.
|
||||
|
||||
#### `shape`
|
||||
|
||||
Optional.
|
||||
|
||||
The shape of the tensor (see PyTorch tensor.shape). The shape of a
|
||||
one-dimensional list of length N is `(N,)`. Can usually be inferred if not
|
||||
provided, and is generally inferred as a 1D tensor.
|
||||
|
||||
#### `to_numpy`
|
||||
|
||||
```python
|
||||
def to_numpy() -> npt.NDArray[Any]
|
||||
```
|
||||
|
||||
Convert TensorData to numpy array.
|
||||
|
||||
#### `to_torch`
|
||||
|
||||
```python
|
||||
def to_torch() -> "torch.Tensor"
|
||||
```
|
||||
|
||||
Convert TensorData to torch tensor.
|
||||
|
||||
# `tinker.types.save_weights_for_sampler_request`
|
||||
|
||||
## `SaveWeightsForSamplerRequest` Objects
|
||||
|
||||
```python
|
||||
class SaveWeightsForSamplerRequest(StrictBase)
|
||||
```
|
||||
|
||||
#### `path`
|
||||
|
||||
A file/directory name for the weights
|
||||
|
||||
# `tinker.types.sampling_params`
|
||||
|
||||
## `SamplingParams` Objects
|
||||
|
||||
```python
|
||||
class SamplingParams(BaseModel)
|
||||
```
|
||||
|
||||
#### `max_tokens`
|
||||
|
||||
Maximum number of tokens to generate
|
||||
|
||||
#### `seed`
|
||||
|
||||
Random seed for reproducible generation
|
||||
|
||||
#### `stop`
|
||||
|
||||
Stop sequences for generation
|
||||
|
||||
#### `temperature`
|
||||
|
||||
Sampling temperature
|
||||
|
||||
#### `top_k`
|
||||
|
||||
Top-k sampling parameter (-1 for no limit)
|
||||
|
||||
#### `top_p`
|
||||
|
||||
Nucleus sampling probability
|
||||
|
||||
# `tinker.types.save_weights_for_sampler_response`
|
||||
|
||||
## `SaveWeightsForSamplerResponseInternal` Objects
|
||||
|
||||
```python
|
||||
class SaveWeightsForSamplerResponseInternal(BaseModel)
|
||||
```
|
||||
|
||||
#### `path`
|
||||
|
||||
A tinker URI for model weights for sampling at a specific step
|
||||
|
||||
#### `sampling_session_id`
|
||||
|
||||
The generated sampling session ID
|
||||
|
||||
## `SaveWeightsForSamplerResponse` Objects
|
||||
|
||||
```python
|
||||
class SaveWeightsForSamplerResponse(BaseModel)
|
||||
```
|
||||
|
||||
#### `path`
|
||||
|
||||
A tinker URI for model weights for sampling at a specific step
|
||||
|
||||
# `tinker.types.load_weights_response`
|
||||
|
||||
## `LoadWeightsResponse` Objects
|
||||
|
||||
```python
|
||||
class LoadWeightsResponse(BaseModel)
|
||||
```
|
||||
|
||||
#### `path`
|
||||
|
||||
A tinker URI for model weights at a specific step
|
||||
Loading…
Add table
Add a link
Reference in a new issue