tinker/src/tinker/lib/public_interfaces/rest_client.py
2026-01-27 19:55:25 +01:00

897 lines
37 KiB
Python

"""RestClient for Tinker API REST operations."""
from __future__ import annotations
import logging
from typing import Sequence
from concurrent.futures import Future as ConcurrentFuture
from typing import TYPE_CHECKING
from tinker import types
from tinker._types import NoneType
from tinker.lib.client_connection_pool_type import ClientConnectionPoolType
from tinker.lib.public_interfaces.api_future import APIFuture, AwaitableConcurrentFuture
from tinker.lib.telemetry import Telemetry, capture_exceptions
from tinker.lib.telemetry_provider import TelemetryProvider
from ..sync_only import sync_only
if TYPE_CHECKING:
from ..internal_client_holder import InternalClientHolder
# pyright: reportPrivateImportUsage=false
# Module logger; records are emitted under this module's dotted import path (__name__).
logger = logging.getLogger(__name__)
class RestClient(TelemetryProvider):
    """REST client for inspecting and managing training runs, checkpoints and sessions.

    Wraps the Tinker REST endpoints used to query model information, list and
    manage checkpoints, and look up sessions and samplers. Obtain an instance via
    `service_client.create_rest_client()` rather than constructing one directly.

    Most operations come in two flavours: a synchronous method that returns a
    future, and an `*_async` coroutine counterpart.

    Key methods:
    - list_checkpoints() - list available model checkpoints (both training and sampler)
    - list_user_checkpoints() - list all checkpoints across all user's training runs
    - list_training_runs() - page through the user's training runs
    - get_training_run() - get model information and metadata as ModelEntry
    - delete_checkpoint() - delete an existing checkpoint for a training run
    - get_checkpoint_archive_url() - get signed URL to download checkpoint archive
    - publish_checkpoint_from_tinker_path() - publish a checkpoint to make it public
    - unpublish_checkpoint_from_tinker_path() - unpublish a checkpoint to make it private
    - get_session() / list_sessions() / get_sampler() - inspect sessions and samplers
    - export_checkpoint_to_hub() - upload a checkpoint to the Hugging Face Hub as a PEFT adapter

    Args:
    - `holder`: Internal client managing HTTP connections and async operations

    Example:
    ```python
    rest_client = service_client.create_rest_client()
    training_run = rest_client.get_training_run("run-id").result()
    print(f"Training Run: {training_run.training_run_id}, LoRA: {training_run.is_lora}")
    checkpoints = rest_client.list_checkpoints("run-id").result()
    print(f"Found {len(checkpoints.checkpoints)} checkpoints")
    for checkpoint in checkpoints.checkpoints:
        print(f" {checkpoint.checkpoint_type}: {checkpoint.checkpoint_id}")
    ```
    """

    def __init__(self, holder: InternalClientHolder) -> None:
        """Bind the client to `holder`, which performs all HTTP and event-loop work."""
        self.holder = holder
def _get_training_run_submit(
    self, training_run_id: types.ModelID
) -> AwaitableConcurrentFuture[types.TrainingRun]:
    """Schedule a lookup of one training run on the holder's event loop.

    Args:
    - `training_run_id`: ID of the training run to fetch

    Returns:
    - An `AwaitableConcurrentFuture` resolving to the `TrainingRun`
    """
    endpoint = f"/api/v1/training_runs/{training_run_id}"

    async def fetch_once() -> types.TrainingRun:
        # The client context is entered inside the callable handed to
        # execute_with_retries, so every invocation gets its own context.
        with self.holder.aclient(ClientConnectionPoolType.TRAIN) as http:
            return await http.get(endpoint, cast_to=types.TrainingRun)

    async def fetch_with_retries() -> types.TrainingRun:
        return await self.holder.execute_with_retries(fetch_once)

    return self.holder.run_coroutine_threadsafe(fetch_with_retries())
@sync_only
@capture_exceptions(fatal=True)
def get_training_run(
    self, training_run_id: types.ModelID
) -> ConcurrentFuture[types.TrainingRun]:
    """Fetch metadata for a single training run.

    Synchronous entry point; see `get_training_run_async` for the coroutine variant.

    Args:
    - `training_run_id`: ID of the training run to describe

    Returns:
    - A `Future` that resolves to the `TrainingRun`

    Example:
    ```python
    run = rest_client.get_training_run("run-id").result()
    print(f"Training Run ID: {run.training_run_id}, Base: {run.base_model}")
    ```
    """
    pending = self._get_training_run_submit(training_run_id)
    return pending.future()
@capture_exceptions(fatal=True)
async def get_training_run_async(self, training_run_id: types.ModelID) -> types.TrainingRun:
    """Coroutine variant of `get_training_run`; resolves directly to the `TrainingRun`."""
    submitted = self._get_training_run_submit(training_run_id)
    training_run = await submitted
    return training_run
@sync_only
@capture_exceptions(fatal=True)
def get_training_run_by_tinker_path(
    self, tinker_path: str
) -> ConcurrentFuture[types.TrainingRun]:
    """Fetch the training run that owns the checkpoint at `tinker_path`.

    The training run ID is parsed out of the path and forwarded to
    `get_training_run`.

    Args:
    - `tinker_path`: Checkpoint tinker path, e.g. "tinker://run-id/weights/checkpoint-001"

    Returns:
    - A `Future` that resolves to the owning `TrainingRun`

    Example:
    ```python
    run = rest_client.get_training_run_by_tinker_path(
        "tinker://run-id/weights/checkpoint-001"
    ).result()
    print(f"Training Run ID: {run.training_run_id}, Base: {run.base_model}")
    ```
    """
    run_id = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path).training_run_id
    return self.get_training_run(run_id)
@capture_exceptions(fatal=True)
async def get_training_run_by_tinker_path_async(self, tinker_path: str) -> types.TrainingRun:
    """Coroutine variant of `get_training_run_by_tinker_path`."""
    run_id = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path).training_run_id
    training_run = await self.get_training_run_async(run_id)
    return training_run
@capture_exceptions(fatal=True)
def get_weights_info_by_tinker_path(
    self, tinker_path: str
) -> APIFuture[types.WeightsInfoResponse]:
    """Describe the weights stored at a checkpoint tinker path.

    Args:
    - `tinker_path`: Checkpoint tinker path, e.g. "tinker://run-id/weights/checkpoint-001"

    Returns:
    - An `APIFuture` resolving to the `WeightsInfoResponse`; it supports both
      `.result()` and `await`.

    Example:
    ```python
    info = rest_client.get_weights_info_by_tinker_path(
        "tinker://run-id/weights/checkpoint-001"
    ).result()  # or: await rest_client.get_weights_info_by_tinker_path(...)
    print(f"Base Model: {info.base_model}, LoRA Rank: {info.lora_rank}")
    ```
    """

    async def post_once() -> types.WeightsInfoResponse:
        with self.holder.aclient(ClientConnectionPoolType.TRAIN) as http:
            return await http.post(
                "/api/v1/weights_info",
                body=dict(tinker_path=tinker_path),
                cast_to=types.WeightsInfoResponse,
            )

    async def post_with_retries() -> types.WeightsInfoResponse:
        return await self.holder.execute_with_retries(post_once)

    return self.holder.run_coroutine_threadsafe(post_with_retries())
def _list_training_runs_submit(
    self, limit: int = 20, offset: int = 0
) -> AwaitableConcurrentFuture[types.TrainingRunsResponse]:
    """Schedule a request for one page of training runs on the holder's event loop.

    Args:
    - `limit`: Page size, sent as the `limit` query parameter
    - `offset`: Page start, sent as the `offset` query parameter

    Returns:
    - An `AwaitableConcurrentFuture` resolving to the `TrainingRunsResponse`
    """

    async def request_page() -> types.TrainingRunsResponse:
        query: dict[str, object] = dict(limit=limit, offset=offset)
        with self.holder.aclient(ClientConnectionPoolType.TRAIN) as http:
            return await http.get(
                "/api/v1/training_runs",
                options={"params": query},
                cast_to=types.TrainingRunsResponse,
            )

    async def request_page_with_retries() -> types.TrainingRunsResponse:
        return await self.holder.execute_with_retries(request_page)

    return self.holder.run_coroutine_threadsafe(request_page_with_retries())
@sync_only
@capture_exceptions(fatal=True)
def list_training_runs(
    self, limit: int = 20, offset: int = 0
) -> ConcurrentFuture[types.TrainingRunsResponse]:
    """Page through training runs.

    Args:
    - `limit`: Maximum number of training runs in the page (default 20)
    - `offset`: Number of training runs to skip before the page starts (default 0)

    Returns:
    - A `Future` resolving to a `TrainingRunsResponse` with the runs and cursor info

    Example:
    ```python
    page = rest_client.list_training_runs(limit=50).result()
    print(f"Found {len(page.training_runs)} training runs")
    print(f"Total: {page.cursor.total_count}")
    # Fetch the following page
    next_page = rest_client.list_training_runs(limit=50, offset=50)
    ```
    """
    pending = self._list_training_runs_submit(limit, offset)
    return pending.future()
@capture_exceptions(fatal=True)
async def list_training_runs_async(
    self, limit: int = 20, offset: int = 0
) -> types.TrainingRunsResponse:
    """Coroutine variant of `list_training_runs`."""
    submitted = self._list_training_runs_submit(limit, offset)
    page = await submitted
    return page
def _list_checkpoints_submit(
    self, training_run_id: types.ModelID
) -> AwaitableConcurrentFuture[types.CheckpointsListResponse]:
    """Schedule a checkpoint listing for one training run on the holder's event loop.

    Args:
    - `training_run_id`: ID of the training run whose checkpoints to list

    Returns:
    - An `AwaitableConcurrentFuture` resolving to the `CheckpointsListResponse`
    """

    async def list_once() -> types.CheckpointsListResponse:
        with self.holder.aclient(ClientConnectionPoolType.TRAIN) as http:
            return await http.weights.list(training_run_id)

    async def list_with_retries() -> types.CheckpointsListResponse:
        return await self.holder.execute_with_retries(list_once)

    return self.holder.run_coroutine_threadsafe(list_with_retries())
@sync_only
@capture_exceptions(fatal=True)
def list_checkpoints(
    self, training_run_id: types.ModelID
) -> ConcurrentFuture[types.CheckpointsListResponse]:
    """List every checkpoint (training and sampler) saved for one training run.

    Args:
    - `training_run_id`: ID of the training run whose checkpoints to list

    Returns:
    - A `Future` resolving to a `CheckpointsListResponse`

    Example:
    ```python
    response = rest_client.list_checkpoints("run-id").result()
    for checkpoint in response.checkpoints:
        if checkpoint.checkpoint_type == "training":
            print(f"Training checkpoint: {checkpoint.checkpoint_id}")
        elif checkpoint.checkpoint_type == "sampler":
            print(f"Sampler checkpoint: {checkpoint.checkpoint_id}")
    ```
    """
    pending = self._list_checkpoints_submit(training_run_id)
    return pending.future()
@capture_exceptions(fatal=True)
async def list_checkpoints_async(
    self, training_run_id: types.ModelID
) -> types.CheckpointsListResponse:
    """Coroutine variant of `list_checkpoints`."""
    submitted = self._list_checkpoints_submit(training_run_id)
    listing = await submitted
    return listing
def _get_checkpoint_archive_url_submit(
    self, training_run_id: types.ModelID, checkpoint_id: str
) -> AwaitableConcurrentFuture[types.CheckpointArchiveUrlResponse]:
    """Internal method to submit get checkpoint archive URL request.

    Schedules the request on the holder's event loop. Like every other request
    helper in this class, the call is routed through
    `holder.execute_with_retries`.

    Args:
    - `training_run_id`: ID of the training run that owns the checkpoint
    - `checkpoint_id`: ID of the checkpoint whose archive URL is requested

    Returns:
    - An `AwaitableConcurrentFuture` resolving to the `CheckpointArchiveUrlResponse`
    """

    async def _get_checkpoint_archive_url_async() -> types.CheckpointArchiveUrlResponse:
        async def _send_request() -> types.CheckpointArchiveUrlResponse:
            with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
                return await client.weights.get_checkpoint_archive_url(
                    model_id=training_run_id,
                    checkpoint_id=checkpoint_id,
                )

        # This request previously bypassed the retry wrapper, unlike every sibling
        # request. It only asks for a signed download URL, so retrying it matches
        # how the other lookups (e.g. weights.list) are treated.
        return await self.holder.execute_with_retries(_send_request)

    return self.holder.run_coroutine_threadsafe(_get_checkpoint_archive_url_async())
@sync_only
@capture_exceptions(fatal=True)
def get_checkpoint_archive_url(
    self, training_run_id: types.ModelID, checkpoint_id: str
) -> ConcurrentFuture[types.CheckpointArchiveUrlResponse]:
    """Request a signed URL from which a checkpoint archive can be downloaded.

    Args:
    - `training_run_id`: ID of the training run that owns the checkpoint
    - `checkpoint_id`: ID of the checkpoint to download

    Returns:
    - A `Future` resolving to a `CheckpointArchiveUrlResponse` (signed URL and expiry)

    Example:
    ```python
    response = rest_client.get_checkpoint_archive_url("run-id", "checkpoint-123").result()
    print(f"Download URL: {response.url}")
    print(f"Expires at: {response.expires_at}")
    # Download the archive from response.url with any HTTP client
    ```
    """
    pending = self._get_checkpoint_archive_url_submit(training_run_id, checkpoint_id)
    return pending.future()
@capture_exceptions(fatal=True)
async def get_checkpoint_archive_url_async(
    self, training_run_id: types.ModelID, checkpoint_id: str
) -> types.CheckpointArchiveUrlResponse:
    """Coroutine variant of `get_checkpoint_archive_url`."""
    submitted = self._get_checkpoint_archive_url_submit(training_run_id, checkpoint_id)
    url_response = await submitted
    return url_response
def _delete_checkpoint_submit(
    self, training_run_id: types.ModelID, checkpoint_id: str
) -> AwaitableConcurrentFuture[None]:
    """Schedule deletion of one checkpoint on the holder's event loop.

    Args:
    - `training_run_id`: ID of the training run that owns the checkpoint
    - `checkpoint_id`: ID of the checkpoint to delete

    Returns:
    - An `AwaitableConcurrentFuture` that completes once the DELETE succeeds
    """
    endpoint = f"/api/v1/training_runs/{training_run_id}/checkpoints/{checkpoint_id}"

    async def delete_once() -> None:
        with self.holder.aclient(ClientConnectionPoolType.TRAIN) as http:
            await http.delete(endpoint, cast_to=NoneType)

    async def delete_with_retries() -> None:
        return await self.holder.execute_with_retries(delete_once)

    return self.holder.run_coroutine_threadsafe(delete_with_retries())
@sync_only
@capture_exceptions(fatal=True)
def delete_checkpoint(
    self, training_run_id: types.ModelID, checkpoint_id: str
) -> ConcurrentFuture[None]:
    """Delete a checkpoint belonging to a training run.

    Args:
    - `training_run_id`: ID of the training run that owns the checkpoint
    - `checkpoint_id`: ID of the checkpoint to delete

    Returns:
    - A `Future` that completes when the deletion has finished
    """
    pending = self._delete_checkpoint_submit(training_run_id, checkpoint_id)
    return pending.future()
@capture_exceptions(fatal=True)
async def delete_checkpoint_async(
    self, training_run_id: types.ModelID, checkpoint_id: str
) -> None:
    """Coroutine variant of `delete_checkpoint`."""
    submitted = self._delete_checkpoint_submit(training_run_id, checkpoint_id)
    await submitted
@sync_only
@capture_exceptions(fatal=True)
def delete_checkpoint_from_tinker_path(self, tinker_path: str) -> ConcurrentFuture[None]:
    """Delete the checkpoint identified by a tinker path.

    Args:
    - `tinker_path`: Checkpoint tinker path, e.g. "tinker://run-id/weights/0001"

    Returns:
    - A `Future` that completes when the deletion has finished
    """
    parsed = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
    run_id, ckpt_id = parsed.training_run_id, parsed.checkpoint_id
    return self._delete_checkpoint_submit(run_id, ckpt_id).future()
@capture_exceptions(fatal=True)
async def delete_checkpoint_from_tinker_path_async(self, tinker_path: str) -> None:
    """Coroutine variant of `delete_checkpoint_from_tinker_path`."""
    parsed = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
    run_id, ckpt_id = parsed.training_run_id, parsed.checkpoint_id
    await self._delete_checkpoint_submit(run_id, ckpt_id)
def get_telemetry(self) -> Telemetry | None:
    """Return the telemetry object exposed by the underlying client holder, if any."""
    holder = self.holder
    return holder.get_telemetry()
@sync_only
@capture_exceptions(fatal=True)
def get_checkpoint_archive_url_from_tinker_path(
    self, tinker_path: str
) -> ConcurrentFuture[types.CheckpointArchiveUrlResponse]:
    """Request a signed archive-download URL for the checkpoint at a tinker path.

    Args:
    - `tinker_path`: Checkpoint tinker path, e.g. "tinker://run-id/weights/0001"

    Returns:
    - A `Future` resolving to a `CheckpointArchiveUrlResponse` (signed URL and expiry)
    """
    parsed = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
    run_id, ckpt_id = parsed.training_run_id, parsed.checkpoint_id
    return self._get_checkpoint_archive_url_submit(run_id, ckpt_id).future()
@capture_exceptions(fatal=True)
async def get_checkpoint_archive_url_from_tinker_path_async(
    self, tinker_path: str
) -> types.CheckpointArchiveUrlResponse:
    """Coroutine variant of `get_checkpoint_archive_url_from_tinker_path`."""
    parsed = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
    run_id, ckpt_id = parsed.training_run_id, parsed.checkpoint_id
    url_response = await self._get_checkpoint_archive_url_submit(run_id, ckpt_id)
    return url_response
@sync_only
@capture_exceptions(fatal=True)
def export_checkpoint_to_hub(
    self,
    tinker_path: str,
    repo_id: str | None,
    *,
    private: bool = True,
    token: str | None = None,
    revision: str | None = None,
    commit_message: str | None = None,
    create_pr: bool = False,
    exist_ok: bool = True,
    allow_patterns: Sequence[str] | None = None,
    ignore_patterns: Sequence[str] | None = None,
    add_model_card: bool = True,
) -> str:
    """Export a checkpoint archive to the Hugging Face Hub as a PEFT adapter.

    This downloads the checkpoint archive, extracts it locally, optionally adds a
    README.md model card, and uploads the folder to the Hub.

    Args:
    - `tinker_path`: The tinker path to the checkpoint (e.g., "tinker://run-id/weights/0001")
    - `repo_id`: Hugging Face repo ID (e.g., "username/my-lora-adapter"). If None,
      a name is derived from the base model and checkpoint path.
    - `private`: Whether to create the repo as private (default True)
    - `token`: Hugging Face access token (optional)
    - `revision`: Target branch/revision to upload to (optional)
    - `commit_message`: Commit message for the upload (optional)
    - `create_pr`: Whether to create a PR instead of pushing to the main branch
    - `exist_ok`: Whether repo creation should succeed if repo exists
    - `allow_patterns`: Optional list of file patterns to include
    - `ignore_patterns`: Optional list of file patterns to exclude
    - `add_model_card`: Whether to add a README.md if missing (default True)

    Returns:
    - The repo_id that was uploaded to

    Raises:
    - `ImportError`: if `huggingface_hub` is not installed
    - `ValueError`: if an archive entry (or a link target inside the archive) would
      land outside the extraction directory, or if the extracted archive lacks
      `adapter_config.json` plus `adapter_model.safetensors` / `adapter_model.bin`
      at its top level
    """
    # Lazy imports to keep base SDK lightweight
    try:
        from huggingface_hub import HfApi
    except ImportError as exc:  # pragma: no cover - optional dependency
        raise ImportError(
            "huggingface_hub is required for export_checkpoint_to_hub. "
            "Install it with: pip install huggingface_hub"
        ) from exc
    import os
    import shutil
    import tarfile
    import tempfile
    import urllib.request
    from pathlib import Path

    # Parse once up front; this validates the path before any network traffic.
    parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)

    def _is_within(root: str, candidate: str) -> bool:
        # Compare whole path components: a plain startswith() would wrongly
        # treat "/tmp/extract-evil" as being inside "/tmp/extract".
        try:
            return os.path.commonpath([root, candidate]) == root
        except ValueError:  # e.g. paths on different drives (Windows)
            return False

    def _safe_extract(tar: tarfile.TarFile, path: Path) -> None:
        # Refuse entries (and symlink/hard-link targets) that would escape `path`.
        root = os.path.realpath(path)
        for member in tar.getmembers():
            member_path = os.path.realpath(os.path.join(root, member.name))
            if not _is_within(root, member_path):
                raise ValueError("Unsafe path in tar archive")
            if member.issym() or member.islnk():
                # Symlink targets are relative to the link's own directory;
                # hard-link targets are archive paths relative to the root.
                link_base = os.path.dirname(member_path) if member.issym() else root
                link_target = os.path.realpath(os.path.join(link_base, member.linkname))
                if not _is_within(root, link_target):
                    raise ValueError("Unsafe link in tar archive")
        if hasattr(tarfile, "data_filter"):
            # Extraction filters exist on Python 3.12+ and security backports;
            # "data" additionally rejects special files and unsafe permissions.
            tar.extractall(path=path, filter="data")
        else:
            tar.extractall(path=path)

    def _sanitize_repo_name(value: str) -> str:
        # Hub repo names (per huggingface_hub's repo-id validation) allow
        # alphanumerics plus "-", "_", "."; forbid "--" and ".."; may not start or
        # end with "-" or "."; and are capped at 96 characters.
        name = "".join(
            ch if (ch.isascii() and ch.isalnum()) or ch in "-_." else "-" for ch in value
        )
        for doubled in ("--", ".."):
            while doubled in name:
                name = name.replace(doubled, doubled[0])
        return name.strip("-_ .")[:96].strip("-_ .")

    def _build_model_card(
        base_model: str,
        lora_rank: object,
        train_mlp: object,
        train_attn: object,
        train_unembed: object,
    ) -> str:
        # YAML front matter (Hub metadata and tags) followed by a short markdown body.
        tags = ["tinker", "peft", "lora", "transformers"]
        if base_model != "unknown":
            tags.append(f"base_model:adapter:{base_model}")
        lines = ["---", f"base_model: {base_model}", "library_name: peft", "tags:"]
        lines.extend(f"- {tag}" for tag in tags)
        lines.append(f"tinker_path: {tinker_path}")
        lines.extend(
            [
                "---",
                "",
                "# LoRA Adapter (Tinker)",
                "",
                "This repository contains a LoRA adapter exported from Tinker.",
                "",
                "## Source",
                "",
                f"- Tinker checkpoint: {tinker_path}",
                "",
                "## Details",
                "",
                f"- Base model: {base_model}",
            ]
        )
        if lora_rank is not None:
            lines.append(f"- LoRA rank: {lora_rank}")
        if train_mlp is not None or train_attn is not None or train_unembed is not None:
            lines.append(
                f"- Trained modules: attn={train_attn}, mlp={train_mlp}, unembed={train_unembed}"
            )
        lines.append("")
        return "\n".join(lines)

    url_response = self.get_checkpoint_archive_url_from_tinker_path(tinker_path).result()

    with tempfile.TemporaryDirectory() as temp_dir:
        temp_root = Path(temp_dir)
        archive_path = temp_root / "checkpoint.tar"
        extract_dir = temp_root / "extract"
        extract_dir.mkdir(parents=True, exist_ok=True)

        # Stream the archive to disk in chunks instead of holding it in memory.
        with urllib.request.urlopen(url_response.url, timeout=60) as response:
            with open(archive_path, "wb") as archive_file:
                shutil.copyfileobj(response, archive_file)

        with tarfile.open(archive_path, "r") as tar:
            _safe_extract(tar, extract_dir)

        # Validate PEFT adapter files exist
        adapter_config = extract_dir / "adapter_config.json"
        has_adapter_weights = (extract_dir / "adapter_model.safetensors").exists() or (
            extract_dir / "adapter_model.bin"
        ).exists()
        if not adapter_config.exists() or not has_adapter_weights:
            raise ValueError(
                "Checkpoint archive does not contain a PEFT adapter. "
                "Expected adapter_config.json and adapter_model.safetensors (or adapter_model.bin)."
            )

        base_model = "unknown"
        lora_rank = None
        train_mlp = None
        train_attn = None
        train_unembed = None
        try:
            weights_info = self.get_weights_info_by_tinker_path(tinker_path).result()
        except Exception:
            # Metadata only enriches the derived repo name and the model card, so
            # the export continues without it, but the failure is no longer hidden.
            logger.warning(
                "Could not fetch weights info for %s; exporting without base model metadata",
                tinker_path,
                exc_info=True,
            )
        else:
            base_model = weights_info.base_model or "unknown"
            lora_rank = weights_info.lora_rank
            train_mlp = weights_info.train_mlp
            train_attn = weights_info.train_attn
            train_unembed = weights_info.train_unembed

        if repo_id is None:
            base_short = base_model.split("/")[-1] if base_model != "unknown" else "adapter"
            checkpoint_id = parsed_tinker_path.checkpoint_id.replace("/", "-")
            derived = f"tinker-{base_short}-{parsed_tinker_path.training_run_id}-{checkpoint_id}"
            repo_id = _sanitize_repo_name(derived)

        # Add a lightweight model card if missing
        readme_path = extract_dir / "README.md"
        if add_model_card and not readme_path.exists():
            readme_path.write_text(
                _build_model_card(base_model, lora_rank, train_mlp, train_attn, train_unembed),
                encoding="utf-8",
            )

        # Upload must happen before the temporary directory is cleaned up.
        api = HfApi(token=token)
        api.create_repo(repo_id=repo_id, private=private, exist_ok=exist_ok)
        api.upload_folder(
            folder_path=os.fspath(extract_dir),
            repo_id=repo_id,
            path_in_repo="",
            revision=revision,
            commit_message=commit_message,
            create_pr=create_pr,
            allow_patterns=list(allow_patterns) if allow_patterns else None,
            ignore_patterns=list(ignore_patterns) if ignore_patterns else None,
        )
    return repo_id
def _publish_checkpoint_submit(
    self, training_run_id: types.ModelID, checkpoint_id: str
) -> AwaitableConcurrentFuture[None]:
    """Schedule a publish request for one checkpoint on the holder's event loop.

    Args:
    - `training_run_id`: ID of the training run that owns the checkpoint
    - `checkpoint_id`: ID of the checkpoint to publish

    Returns:
    - An `AwaitableConcurrentFuture` that completes once the POST succeeds
    """
    endpoint = f"/api/v1/training_runs/{training_run_id}/checkpoints/{checkpoint_id}/publish"

    async def publish_once() -> None:
        with self.holder.aclient(ClientConnectionPoolType.TRAIN) as http:
            await http.post(endpoint, cast_to=NoneType)

    async def publish_with_retries() -> None:
        return await self.holder.execute_with_retries(publish_once)

    return self.holder.run_coroutine_threadsafe(publish_with_retries())
@sync_only
@capture_exceptions(fatal=True)
def publish_checkpoint_from_tinker_path(self, tinker_path: str) -> ConcurrentFuture[None]:
    """Make the checkpoint at a tinker path publicly accessible.

    Only the exact owner of the training run may publish its checkpoints. Use
    `unpublish_checkpoint_from_tinker_path` to revert.

    Args:
    - `tinker_path`: Checkpoint tinker path, e.g. "tinker://run-id/weights/0001"

    Returns:
    - A `Future` that completes once the checkpoint is published

    Raises:
        HTTPException: 400 if checkpoint identifier is invalid
        HTTPException: 404 if checkpoint not found or user doesn't own the training run
        HTTPException: 409 if checkpoint is already public
        HTTPException: 500 if there's an error publishing the checkpoint

    Example:
    ```python
    rest_client.publish_checkpoint_from_tinker_path("tinker://run-id/weights/0001").result()
    print("Checkpoint published successfully")
    ```
    """
    parsed = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
    run_id, ckpt_id = parsed.training_run_id, parsed.checkpoint_id
    return self._publish_checkpoint_submit(run_id, ckpt_id).future()
@capture_exceptions(fatal=True)
async def publish_checkpoint_from_tinker_path_async(self, tinker_path: str) -> None:
    """Coroutine variant of `publish_checkpoint_from_tinker_path`."""
    parsed = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
    run_id, ckpt_id = parsed.training_run_id, parsed.checkpoint_id
    await self._publish_checkpoint_submit(run_id, ckpt_id)
def _unpublish_checkpoint_submit(
    self, training_run_id: types.ModelID, checkpoint_id: str
) -> AwaitableConcurrentFuture[None]:
    """Schedule an unpublish request for one checkpoint on the holder's event loop.

    Args:
    - `training_run_id`: ID of the training run that owns the checkpoint
    - `checkpoint_id`: ID of the checkpoint to make private again

    Returns:
    - An `AwaitableConcurrentFuture` that completes once the DELETE succeeds
    """
    endpoint = f"/api/v1/training_runs/{training_run_id}/checkpoints/{checkpoint_id}/publish"

    async def unpublish_once() -> None:
        with self.holder.aclient(ClientConnectionPoolType.TRAIN) as http:
            await http.delete(endpoint, cast_to=NoneType)

    async def unpublish_with_retries() -> None:
        return await self.holder.execute_with_retries(unpublish_once)

    return self.holder.run_coroutine_threadsafe(unpublish_with_retries())
@sync_only
@capture_exceptions(fatal=True)
def unpublish_checkpoint_from_tinker_path(self, tinker_path: str) -> ConcurrentFuture[None]:
    """Make a previously published checkpoint private again.

    Only the exact owner of the training run may unpublish its checkpoints; this
    reverses `publish_checkpoint_from_tinker_path`.

    Args:
    - `tinker_path`: Checkpoint tinker path, e.g. "tinker://run-id/weights/0001"

    Returns:
    - A `Future` that completes once the checkpoint is unpublished

    Raises:
        HTTPException: 400 if checkpoint identifier is invalid
        HTTPException: 404 if checkpoint not found or user doesn't own the training run
        HTTPException: 409 if checkpoint is already private
        HTTPException: 500 if there's an error unpublishing the checkpoint

    Example:
    ```python
    rest_client.unpublish_checkpoint_from_tinker_path("tinker://run-id/weights/0001").result()
    print("Checkpoint unpublished successfully")
    ```
    """
    parsed = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
    run_id, ckpt_id = parsed.training_run_id, parsed.checkpoint_id
    return self._unpublish_checkpoint_submit(run_id, ckpt_id).future()
@capture_exceptions(fatal=True)
async def unpublish_checkpoint_from_tinker_path_async(self, tinker_path: str) -> None:
    """Coroutine variant of `unpublish_checkpoint_from_tinker_path`."""
    parsed = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
    run_id, ckpt_id = parsed.training_run_id, parsed.checkpoint_id
    await self._unpublish_checkpoint_submit(run_id, ckpt_id)
def _list_user_checkpoints_submit(
    self, limit: int = 100, offset: int = 0
) -> AwaitableConcurrentFuture[types.CheckpointsListResponse]:
    """Schedule a request for one page of the user's checkpoints on the holder's loop.

    Args:
    - `limit`: Page size, sent as the `limit` query parameter
    - `offset`: Page start, sent as the `offset` query parameter

    Returns:
    - An `AwaitableConcurrentFuture` resolving to the `CheckpointsListResponse`
    """

    async def request_page() -> types.CheckpointsListResponse:
        query: dict[str, object] = dict(limit=limit, offset=offset)
        with self.holder.aclient(ClientConnectionPoolType.TRAIN) as http:
            return await http.get(
                "/api/v1/checkpoints",
                options={"params": query},
                cast_to=types.CheckpointsListResponse,
            )

    async def request_page_with_retries() -> types.CheckpointsListResponse:
        return await self.holder.execute_with_retries(request_page)

    return self.holder.run_coroutine_threadsafe(request_page_with_retries())
@sync_only
@capture_exceptions(fatal=True)
def list_user_checkpoints(
    self, limit: int = 100, offset: int = 0
) -> ConcurrentFuture[types.CheckpointsListResponse]:
    """Page through the current user's checkpoints across all of their training runs.

    Results cover every training run owned by the authenticated user and are
    ordered newest first.

    Args:
    - `limit`: Maximum number of checkpoints in the page (default 100)
    - `offset`: Number of checkpoints to skip before the page starts (default 0)

    Returns:
    - A `Future` resolving to a `CheckpointsListResponse` with checkpoints and cursor info

    Example:
    ```python
    response = rest_client.list_user_checkpoints(limit=50).result()
    print(f"Found {len(response.checkpoints)} checkpoints")
    print(f"Total: {response.cursor.total_count if response.cursor else 'Unknown'}")
    for checkpoint in response.checkpoints:
        print(f" {checkpoint.training_run_id}/{checkpoint.checkpoint_id}")
    # Fetch the following page if there is one
    cursor = response.cursor
    if cursor and cursor.offset + cursor.limit < cursor.total_count:
        next_page = rest_client.list_user_checkpoints(
            limit=50, offset=cursor.offset + cursor.limit
        )
    ```
    """
    pending = self._list_user_checkpoints_submit(limit, offset)
    return pending.future()
@capture_exceptions(fatal=True)
async def list_user_checkpoints_async(
    self, limit: int = 100, offset: int = 0
) -> types.CheckpointsListResponse:
    """Coroutine variant of `list_user_checkpoints`."""
    submitted = self._list_user_checkpoints_submit(limit, offset)
    page = await submitted
    return page
def _get_session_submit(
    self, session_id: str
) -> AwaitableConcurrentFuture[types.GetSessionResponse]:
    """Schedule a lookup of one session on the holder's event loop.

    Args:
    - `session_id`: ID of the session to fetch

    Returns:
    - An `AwaitableConcurrentFuture` resolving to the `GetSessionResponse`
    """
    endpoint = f"/api/v1/sessions/{session_id}"

    async def fetch_once() -> types.GetSessionResponse:
        with self.holder.aclient(ClientConnectionPoolType.TRAIN) as http:
            return await http.get(endpoint, cast_to=types.GetSessionResponse)

    async def fetch_with_retries() -> types.GetSessionResponse:
        return await self.holder.execute_with_retries(fetch_once)

    return self.holder.run_coroutine_threadsafe(fetch_with_retries())
@sync_only
@capture_exceptions(fatal=True)
def get_session(self, session_id: str) -> ConcurrentFuture[types.GetSessionResponse]:
    """Describe a session, including its training runs and samplers.

    Args:
    - `session_id`: ID of the session to describe

    Returns:
    - A `Future` resolving to a `GetSessionResponse` with training_run_ids and sampler_ids

    Example:
    ```python
    session = rest_client.get_session("session-id").result()
    print(f"Training runs: {len(session.training_run_ids)}")
    print(f"Samplers: {len(session.sampler_ids)}")
    ```
    """
    pending = self._get_session_submit(session_id)
    return pending.future()
@capture_exceptions(fatal=True)
async def get_session_async(self, session_id: str) -> types.GetSessionResponse:
    """Coroutine variant of `get_session`."""
    submitted = self._get_session_submit(session_id)
    session = await submitted
    return session
def _list_sessions_submit(
    self, limit: int = 20, offset: int = 0
) -> AwaitableConcurrentFuture[types.ListSessionsResponse]:
    """Schedule a request for one page of sessions on the holder's event loop.

    Args:
    - `limit`: Page size, sent as the `limit` query parameter
    - `offset`: Page start, sent as the `offset` query parameter

    Returns:
    - An `AwaitableConcurrentFuture` resolving to the `ListSessionsResponse`
    """

    async def request_page() -> types.ListSessionsResponse:
        query: dict[str, object] = dict(limit=limit, offset=offset)
        with self.holder.aclient(ClientConnectionPoolType.TRAIN) as http:
            return await http.get(
                "/api/v1/sessions",
                options={"params": query},
                cast_to=types.ListSessionsResponse,
            )

    async def request_page_with_retries() -> types.ListSessionsResponse:
        return await self.holder.execute_with_retries(request_page)

    return self.holder.run_coroutine_threadsafe(request_page_with_retries())
@sync_only
@capture_exceptions(fatal=True)
def list_sessions(
    self, limit: int = 20, offset: int = 0
) -> ConcurrentFuture[types.ListSessionsResponse]:
    """Page through sessions.

    Args:
    - `limit`: Maximum number of sessions in the page (default 20)
    - `offset`: Number of sessions to skip before the page starts (default 0)

    Returns:
    - A `Future` resolving to a `ListSessionsResponse` with the session IDs

    Example:
    ```python
    page = rest_client.list_sessions(limit=50).result()
    print(f"Found {len(page.sessions)} sessions")
    # Fetch the following page
    next_page = rest_client.list_sessions(limit=50, offset=50)
    ```
    """
    pending = self._list_sessions_submit(limit, offset)
    return pending.future()
@capture_exceptions(fatal=True)
async def list_sessions_async(
    self, limit: int = 20, offset: int = 0
) -> types.ListSessionsResponse:
    """Coroutine variant of `list_sessions`."""
    submitted = self._list_sessions_submit(limit, offset)
    page = await submitted
    return page
@capture_exceptions(fatal=True)
def get_sampler(self, sampler_id: str) -> APIFuture[types.GetSamplerResponse]:
    """Describe a sampler.

    Args:
    - `sampler_id`: The sampler ID (sampling_session_id) to look up

    Returns:
    - An `APIFuture` resolving to a `GetSamplerResponse`; it supports both
      `.result()` and `await`.

    Example:
    ```python
    # Synchronous
    sampler = rest_client.get_sampler("session-id:sample:0").result()
    print(f"Base model: {sampler.base_model}")
    print(f"Model path: {sampler.model_path}")
    # Asynchronous
    sampler = await rest_client.get_sampler("session-id:sample:0")
    print(f"Base model: {sampler.base_model}")
    ```
    """
    endpoint = f"/api/v1/samplers/{sampler_id}"

    async def fetch_once() -> types.GetSamplerResponse:
        with self.holder.aclient(ClientConnectionPoolType.TRAIN) as http:
            return await http.get(endpoint, cast_to=types.GetSamplerResponse)

    async def fetch_with_retries() -> types.GetSamplerResponse:
        return await self.holder.execute_with_retries(fetch_once)

    return self.holder.run_coroutine_threadsafe(fetch_with_retries())
@capture_exceptions(fatal=True)
async def get_sampler_async(self, sampler_id: str) -> types.GetSamplerResponse:
    """Coroutine variant of `get_sampler`; awaits the `APIFuture` it returns."""
    pending = self.get_sampler(sampler_id)
    sampler = await pending
    return sampler