mirror of
https://github.com/thinking-machines-lab/tinker.git
synced 2026-04-23 16:54:58 +00:00
897 lines
37 KiB
Python
897 lines
37 KiB
Python
"""RestClient for Tinker API REST operations."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from typing import Sequence
|
|
from concurrent.futures import Future as ConcurrentFuture
|
|
from typing import TYPE_CHECKING
|
|
|
|
from tinker import types
|
|
from tinker._types import NoneType
|
|
from tinker.lib.client_connection_pool_type import ClientConnectionPoolType
|
|
from tinker.lib.public_interfaces.api_future import APIFuture, AwaitableConcurrentFuture
|
|
from tinker.lib.telemetry import Telemetry, capture_exceptions
|
|
from tinker.lib.telemetry_provider import TelemetryProvider
|
|
|
|
from ..sync_only import sync_only
|
|
|
|
if TYPE_CHECKING:
|
|
from ..internal_client_holder import InternalClientHolder
|
|
|
|
# pyright: reportPrivateImportUsage=false
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RestClient(TelemetryProvider):
    """Client for REST API operations like listing checkpoints and metadata.

    The RestClient provides access to various REST endpoints for querying
    model information, checkpoints, and other resources. You typically get one
    by calling `service_client.create_rest_client()`.

    Key methods:
    - list_checkpoints() - list available model checkpoints (both training and sampler)
    - list_user_checkpoints() - list all checkpoints across all user's training runs
    - get_training_run() - get model information and metadata as ModelEntry
    - delete_checkpoint() - delete an existing checkpoint for a training run
    - get_checkpoint_archive_url() - get signed URL to download checkpoint archive
    - publish_checkpoint_from_tinker_path() - publish a checkpoint to make it public
    - unpublish_checkpoint_from_tinker_path() - unpublish a checkpoint to make it private

    Args:
    - `holder`: Internal client managing HTTP connections and async operations

    Example:
    ```python
    rest_client = service_client.create_rest_client()
    training_run = rest_client.get_training_run("run-id").result()
    print(f"Training Run: {training_run.training_run_id}, LoRA: {training_run.is_lora}")
    checkpoints = rest_client.list_checkpoints("run-id").result()
    print(f"Found {len(checkpoints.checkpoints)} checkpoints")
    for checkpoint in checkpoints.checkpoints:
        print(f"  {checkpoint.checkpoint_type}: {checkpoint.checkpoint_id}")
    ```
    """

    def __init__(self, holder: InternalClientHolder):
        """Store the shared client holder.

        The holder owns the HTTP connection pools and the background event
        loop that all ``*_submit`` helpers schedule their coroutines onto.
        """
        self.holder = holder
|
|
|
|
def _get_training_run_submit(
|
|
self, training_run_id: types.ModelID
|
|
) -> AwaitableConcurrentFuture[types.TrainingRun]:
|
|
"""Internal method to submit get model request."""
|
|
|
|
async def _get_training_run_async() -> types.TrainingRun:
|
|
async def _send_request() -> types.TrainingRun:
|
|
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
|
return await client.get(
|
|
f"/api/v1/training_runs/{training_run_id}",
|
|
cast_to=types.TrainingRun,
|
|
)
|
|
|
|
return await self.holder.execute_with_retries(_send_request)
|
|
|
|
return self.holder.run_coroutine_threadsafe(_get_training_run_async())
|
|
|
|
@sync_only
|
|
@capture_exceptions(fatal=True)
|
|
def get_training_run(
|
|
self, training_run_id: types.ModelID
|
|
) -> ConcurrentFuture[types.TrainingRun]:
|
|
"""Get training run info.
|
|
|
|
Args:
|
|
- `training_run_id`: The training run ID to get information for
|
|
|
|
Returns:
|
|
- A `Future` containing the training run information
|
|
|
|
Example:
|
|
```python
|
|
future = rest_client.get_training_run("run-id")
|
|
response = future.result()
|
|
print(f"Training Run ID: {response.training_run_id}, Base: {response.base_model}")
|
|
```
|
|
"""
|
|
return self._get_training_run_submit(training_run_id).future()
|
|
|
|
@capture_exceptions(fatal=True)
|
|
async def get_training_run_async(self, training_run_id: types.ModelID) -> types.TrainingRun:
|
|
"""Async version of get_training_run."""
|
|
return await self._get_training_run_submit(training_run_id)
|
|
|
|
@sync_only
|
|
@capture_exceptions(fatal=True)
|
|
def get_training_run_by_tinker_path(
|
|
self, tinker_path: str
|
|
) -> ConcurrentFuture[types.TrainingRun]:
|
|
"""Get training run info.
|
|
|
|
Args:
|
|
- `tinker_path`: The tinker path to the checkpoint
|
|
|
|
Returns:
|
|
- A `Future` containing the training run information
|
|
|
|
Example:
|
|
```python
|
|
future = rest_client.get_training_run_by_tinker_path("tinker://run-id/weights/checkpoint-001")
|
|
response = future.result()
|
|
print(f"Training Run ID: {response.training_run_id}, Base: {response.base_model}")
|
|
```
|
|
"""
|
|
parsed_checkpoint_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(
|
|
tinker_path
|
|
)
|
|
return self.get_training_run(parsed_checkpoint_tinker_path.training_run_id)
|
|
|
|
@capture_exceptions(fatal=True)
|
|
async def get_training_run_by_tinker_path_async(self, tinker_path: str) -> types.TrainingRun:
|
|
"""Async version of get_training_run_by_tinker_path."""
|
|
parsed_checkpoint_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(
|
|
tinker_path
|
|
)
|
|
return await self.get_training_run_async(parsed_checkpoint_tinker_path.training_run_id)
|
|
|
|
@capture_exceptions(fatal=True)
|
|
def get_weights_info_by_tinker_path(
|
|
self, tinker_path: str
|
|
) -> APIFuture[types.WeightsInfoResponse]:
|
|
"""Get checkpoint information from a tinker path.
|
|
|
|
Args:
|
|
- `tinker_path`: The tinker path to the checkpoint
|
|
|
|
Returns:
|
|
- An `APIFuture` containing the checkpoint information. The future is awaitable.
|
|
|
|
Example:
|
|
```python
|
|
future = rest_client.get_weights_info_by_tinker_path("tinker://run-id/weights/checkpoint-001")
|
|
response = future.result() # or await future
|
|
print(f"Base Model: {response.base_model}, LoRA Rank: {response.lora_rank}")
|
|
```
|
|
"""
|
|
|
|
async def _get_weights_info_async() -> types.WeightsInfoResponse:
|
|
async def _send_request() -> types.WeightsInfoResponse:
|
|
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
|
return await client.post(
|
|
"/api/v1/weights_info",
|
|
body={"tinker_path": tinker_path},
|
|
cast_to=types.WeightsInfoResponse,
|
|
)
|
|
|
|
return await self.holder.execute_with_retries(_send_request)
|
|
|
|
return self.holder.run_coroutine_threadsafe(_get_weights_info_async())
|
|
|
|
def _list_training_runs_submit(
|
|
self, limit: int = 20, offset: int = 0
|
|
) -> AwaitableConcurrentFuture[types.TrainingRunsResponse]:
|
|
"""Internal method to submit list training runs request."""
|
|
|
|
async def _list_training_runs_async() -> types.TrainingRunsResponse:
|
|
async def _send_request() -> types.TrainingRunsResponse:
|
|
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
|
params: dict[str, object] = {"limit": limit, "offset": offset}
|
|
|
|
return await client.get(
|
|
"/api/v1/training_runs",
|
|
options={"params": params},
|
|
cast_to=types.TrainingRunsResponse,
|
|
)
|
|
|
|
return await self.holder.execute_with_retries(_send_request)
|
|
|
|
return self.holder.run_coroutine_threadsafe(_list_training_runs_async())
|
|
|
|
@sync_only
|
|
@capture_exceptions(fatal=True)
|
|
def list_training_runs(
|
|
self, limit: int = 20, offset: int = 0
|
|
) -> ConcurrentFuture[types.TrainingRunsResponse]:
|
|
"""List training runs with pagination support.
|
|
|
|
Args:
|
|
- `limit`: Maximum number of training runs to return (default 20)
|
|
- `offset`: Offset for pagination (default 0)
|
|
|
|
Returns:
|
|
- A `Future` containing the `TrainingRunsResponse` with training runs and cursor info
|
|
|
|
Example:
|
|
```python
|
|
future = rest_client.list_training_runs(limit=50)
|
|
response = future.result()
|
|
print(f"Found {len(response.training_runs)} training runs")
|
|
print(f"Total: {response.cursor.total_count}")
|
|
# Get next page
|
|
next_page = rest_client.list_training_runs(limit=50, offset=50)
|
|
```
|
|
"""
|
|
return self._list_training_runs_submit(limit, offset).future()
|
|
|
|
@capture_exceptions(fatal=True)
|
|
async def list_training_runs_async(
|
|
self, limit: int = 20, offset: int = 0
|
|
) -> types.TrainingRunsResponse:
|
|
"""Async version of list_training_runs."""
|
|
return await self._list_training_runs_submit(limit, offset)
|
|
|
|
def _list_checkpoints_submit(
|
|
self, training_run_id: types.ModelID
|
|
) -> AwaitableConcurrentFuture[types.CheckpointsListResponse]:
|
|
"""Internal method to submit list model checkpoints request."""
|
|
|
|
async def _list_checkpoints_async():
|
|
async def _send_request():
|
|
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
|
return await client.weights.list(training_run_id)
|
|
|
|
return await self.holder.execute_with_retries(_send_request)
|
|
|
|
return self.holder.run_coroutine_threadsafe(_list_checkpoints_async())
|
|
|
|
@sync_only
|
|
@capture_exceptions(fatal=True)
|
|
def list_checkpoints(
|
|
self, training_run_id: types.ModelID
|
|
) -> ConcurrentFuture[types.CheckpointsListResponse]:
|
|
"""List available checkpoints (both training and sampler).
|
|
|
|
Args:
|
|
- `training_run_id`: The training run ID to list checkpoints for
|
|
|
|
Returns:
|
|
- A `Future` containing the `CheckpointsListResponse` with available checkpoints
|
|
|
|
Example:
|
|
```python
|
|
future = rest_client.list_checkpoints("run-id")
|
|
response = future.result()
|
|
for checkpoint in response.checkpoints:
|
|
if checkpoint.checkpoint_type == "training":
|
|
print(f"Training checkpoint: {checkpoint.checkpoint_id}")
|
|
elif checkpoint.checkpoint_type == "sampler":
|
|
print(f"Sampler checkpoint: {checkpoint.checkpoint_id}")
|
|
```
|
|
"""
|
|
return self._list_checkpoints_submit(training_run_id).future()
|
|
|
|
@capture_exceptions(fatal=True)
|
|
async def list_checkpoints_async(
|
|
self, training_run_id: types.ModelID
|
|
) -> types.CheckpointsListResponse:
|
|
"""Async version of list_checkpoints."""
|
|
return await self._list_checkpoints_submit(training_run_id)
|
|
|
|
def _get_checkpoint_archive_url_submit(
|
|
self, training_run_id: types.ModelID, checkpoint_id: str
|
|
) -> AwaitableConcurrentFuture[types.CheckpointArchiveUrlResponse]:
|
|
"""Internal method to submit get checkpoint archive URL request."""
|
|
|
|
async def _get_checkpoint_archive_url_async() -> types.CheckpointArchiveUrlResponse:
|
|
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
|
return await client.weights.get_checkpoint_archive_url(
|
|
model_id=training_run_id,
|
|
checkpoint_id=checkpoint_id,
|
|
)
|
|
|
|
return self.holder.run_coroutine_threadsafe(_get_checkpoint_archive_url_async())
|
|
|
|
@sync_only
|
|
@capture_exceptions(fatal=True)
|
|
def get_checkpoint_archive_url(
|
|
self, training_run_id: types.ModelID, checkpoint_id: str
|
|
) -> ConcurrentFuture[types.CheckpointArchiveUrlResponse]:
|
|
"""Get signed URL to download checkpoint archive.
|
|
|
|
Args:
|
|
- `training_run_id`: The training run ID to download weights for
|
|
- `checkpoint_id`: The checkpoint ID to download
|
|
|
|
Returns:
|
|
- A `Future` containing the `CheckpointArchiveUrlResponse` with signed URL and expiration
|
|
|
|
Example:
|
|
```python
|
|
future = rest_client.get_checkpoint_archive_url("run-id", "checkpoint-123")
|
|
response = future.result()
|
|
print(f"Download URL: {response.url}")
|
|
print(f"Expires at: {response.expires_at}")
|
|
# Use the URL to download the archive with your preferred HTTP client
|
|
```
|
|
"""
|
|
return self._get_checkpoint_archive_url_submit(training_run_id, checkpoint_id).future()
|
|
|
|
@capture_exceptions(fatal=True)
|
|
async def get_checkpoint_archive_url_async(
|
|
self, training_run_id: types.ModelID, checkpoint_id: str
|
|
) -> types.CheckpointArchiveUrlResponse:
|
|
"""Async version of get_checkpoint_archive_url."""
|
|
return await self._get_checkpoint_archive_url_submit(training_run_id, checkpoint_id)
|
|
|
|
def _delete_checkpoint_submit(
|
|
self, training_run_id: types.ModelID, checkpoint_id: str
|
|
) -> AwaitableConcurrentFuture[None]:
|
|
"""Internal method to submit delete checkpoint request."""
|
|
|
|
async def _delete_checkpoint_async() -> None:
|
|
async def _send_request() -> None:
|
|
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
|
await client.delete(
|
|
f"/api/v1/training_runs/{training_run_id}/checkpoints/{checkpoint_id}",
|
|
cast_to=NoneType,
|
|
)
|
|
|
|
return await self.holder.execute_with_retries(_send_request)
|
|
|
|
return self.holder.run_coroutine_threadsafe(_delete_checkpoint_async())
|
|
|
|
@sync_only
|
|
@capture_exceptions(fatal=True)
|
|
def delete_checkpoint(
|
|
self, training_run_id: types.ModelID, checkpoint_id: str
|
|
) -> ConcurrentFuture[None]:
|
|
"""Delete a checkpoint for a training run."""
|
|
|
|
return self._delete_checkpoint_submit(training_run_id, checkpoint_id).future()
|
|
|
|
@capture_exceptions(fatal=True)
|
|
async def delete_checkpoint_async(
|
|
self, training_run_id: types.ModelID, checkpoint_id: str
|
|
) -> None:
|
|
"""Async version of delete_checkpoint."""
|
|
|
|
await self._delete_checkpoint_submit(training_run_id, checkpoint_id)
|
|
|
|
@sync_only
|
|
@capture_exceptions(fatal=True)
|
|
def delete_checkpoint_from_tinker_path(self, tinker_path: str) -> ConcurrentFuture[None]:
|
|
"""Delete a checkpoint referenced by a tinker path."""
|
|
|
|
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
|
|
return self._delete_checkpoint_submit(
|
|
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
|
|
).future()
|
|
|
|
@capture_exceptions(fatal=True)
|
|
async def delete_checkpoint_from_tinker_path_async(self, tinker_path: str) -> None:
|
|
"""Async version of delete_checkpoint_from_tinker_path."""
|
|
|
|
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
|
|
await self._delete_checkpoint_submit(
|
|
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
|
|
)
|
|
|
|
    def get_telemetry(self) -> Telemetry | None:
        """TelemetryProvider hook: report through the holder's shared Telemetry.

        Returns None when telemetry is not configured on the holder.
        """
        return self.holder.get_telemetry()
|
|
|
|
@sync_only
|
|
@capture_exceptions(fatal=True)
|
|
def get_checkpoint_archive_url_from_tinker_path(
|
|
self, tinker_path: str
|
|
) -> ConcurrentFuture[types.CheckpointArchiveUrlResponse]:
|
|
"""Get signed URL to download checkpoint archive.
|
|
|
|
Args:
|
|
- `tinker_path`: The tinker path to the checkpoint
|
|
|
|
Returns:
|
|
- A `Future` containing the `CheckpointArchiveUrlResponse` with signed URL and expiration
|
|
"""
|
|
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
|
|
return self._get_checkpoint_archive_url_submit(
|
|
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
|
|
).future()
|
|
|
|
@capture_exceptions(fatal=True)
|
|
async def get_checkpoint_archive_url_from_tinker_path_async(
|
|
self, tinker_path: str
|
|
) -> types.CheckpointArchiveUrlResponse:
|
|
"""Async version of get_checkpoint_archive_url_from_tinker_path."""
|
|
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
|
|
return await self._get_checkpoint_archive_url_submit(
|
|
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
|
|
)
|
|
|
|
@sync_only
|
|
@capture_exceptions(fatal=True)
|
|
def export_checkpoint_to_hub(
|
|
self,
|
|
tinker_path: str,
|
|
repo_id: str | None,
|
|
*,
|
|
private: bool = True,
|
|
token: str | None = None,
|
|
revision: str | None = None,
|
|
commit_message: str | None = None,
|
|
create_pr: bool = False,
|
|
exist_ok: bool = True,
|
|
allow_patterns: Sequence[str] | None = None,
|
|
ignore_patterns: Sequence[str] | None = None,
|
|
add_model_card: bool = True,
|
|
) -> str:
|
|
"""Export a checkpoint archive to the Hugging Face Hub as a PEFT adapter.
|
|
|
|
This downloads the checkpoint archive, extracts it locally, optionally adds a
|
|
README.md model card, and uploads the folder to the Hub.
|
|
|
|
Args:
|
|
- `tinker_path`: The tinker path to the checkpoint (e.g., "tinker://run-id/weights/0001")
|
|
- `repo_id`: Hugging Face repo ID (e.g., "username/my-lora-adapter"). If None,
|
|
a name is derived from the base model and checkpoint path.
|
|
- `private`: Whether to create the repo as private (default False)
|
|
- `token`: Hugging Face access token (optional)
|
|
- `revision`: Target branch/revision to upload to (optional)
|
|
- `commit_message`: Commit message for the upload (optional)
|
|
- `create_pr`: Whether to create a PR instead of pushing to the main branch
|
|
- `exist_ok`: Whether repo creation should succeed if repo exists
|
|
- `allow_patterns`: Optional list of file patterns to include
|
|
- `ignore_patterns`: Optional list of file patterns to exclude
|
|
- `add_model_card`: Whether to add a README.md if missing (default True)
|
|
|
|
Returns:
|
|
- The repo_id that was uploaded to
|
|
"""
|
|
# Lazy imports to keep base SDK lightweight
|
|
try:
|
|
from huggingface_hub import HfApi
|
|
except ImportError as exc: # pragma: no cover - optional dependency
|
|
raise ImportError(
|
|
"huggingface_hub is required for export_checkpoint_to_hub. "
|
|
"Install it with: pip install huggingface_hub"
|
|
) from exc
|
|
|
|
import os
|
|
import tarfile
|
|
import tempfile
|
|
import urllib.request
|
|
from pathlib import Path
|
|
|
|
# Validate tinker path
|
|
types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
|
|
|
|
def _safe_extract(tar: tarfile.TarFile, path: Path) -> None:
|
|
for member in tar.getmembers():
|
|
member_path = path / member.name
|
|
if not str(member_path.resolve()).startswith(str(path.resolve())):
|
|
raise ValueError("Unsafe path in tar archive")
|
|
tar.extractall(path=path)
|
|
|
|
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
|
|
|
|
def _sanitize_repo_name(value: str) -> str:
|
|
safe_chars = []
|
|
for ch in value:
|
|
if ch.isalnum() or ch in {"-", "_", "."}:
|
|
safe_chars.append(ch)
|
|
else:
|
|
safe_chars.append("-")
|
|
# Collapse repeated separators
|
|
name = "".join(safe_chars)
|
|
while "--" in name:
|
|
name = name.replace("--", "-")
|
|
return name.strip("-_ .")
|
|
|
|
url_response = self.get_checkpoint_archive_url_from_tinker_path(tinker_path).result()
|
|
|
|
with tempfile.TemporaryDirectory() as temp_dir:
|
|
temp_root = Path(temp_dir)
|
|
archive_path = temp_root / "checkpoint.tar"
|
|
extract_dir = temp_root / "extract"
|
|
extract_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
# Download archive
|
|
with urllib.request.urlopen(url_response.url, timeout=60) as response:
|
|
with open(archive_path, "wb") as f:
|
|
while True:
|
|
chunk = response.read(8192)
|
|
if not chunk:
|
|
break
|
|
f.write(chunk)
|
|
|
|
# Extract archive
|
|
with tarfile.open(archive_path, "r") as tar:
|
|
_safe_extract(tar, extract_dir)
|
|
|
|
# Validate PEFT adapter files exist
|
|
adapter_config = extract_dir / "adapter_config.json"
|
|
adapter_safetensors = extract_dir / "adapter_model.safetensors"
|
|
adapter_bin = extract_dir / "adapter_model.bin"
|
|
if not adapter_config.exists() or not (adapter_safetensors.exists() or adapter_bin.exists()):
|
|
raise ValueError(
|
|
"Checkpoint archive does not contain a PEFT adapter. "
|
|
"Expected adapter_config.json and adapter_model.safetensors (or adapter_model.bin)."
|
|
)
|
|
|
|
base_model = "unknown"
|
|
lora_rank = None
|
|
train_mlp = None
|
|
train_attn = None
|
|
train_unembed = None
|
|
try:
|
|
weights_info = self.get_weights_info_by_tinker_path(tinker_path).result()
|
|
base_model = weights_info.base_model
|
|
lora_rank = weights_info.lora_rank
|
|
train_mlp = weights_info.train_mlp
|
|
train_attn = weights_info.train_attn
|
|
train_unembed = weights_info.train_unembed
|
|
except Exception:
|
|
pass
|
|
|
|
if repo_id is None:
|
|
base_short = base_model.split("/")[-1] if base_model != "unknown" else "adapter"
|
|
checkpoint_id = parsed_tinker_path.checkpoint_id.replace("/", "-")
|
|
derived = f"tinker-{base_short}-{parsed_tinker_path.training_run_id}-{checkpoint_id}"
|
|
repo_id = _sanitize_repo_name(derived)
|
|
|
|
# Add a lightweight model card if missing
|
|
readme_path = extract_dir / "README.md"
|
|
if add_model_card and not readme_path.exists():
|
|
tags: list[str] = ["tinker", "peft", "lora", "transformers"]
|
|
if base_model != "unknown":
|
|
tags.append(f"base_model:adapter:{base_model}")
|
|
model_card = [
|
|
"---",
|
|
f"base_model: {base_model}",
|
|
"library_name: peft",
|
|
"tags:",
|
|
]
|
|
for tag in tags:
|
|
model_card.append(f"- {tag}")
|
|
model_card.append(f"tinker_path: {tinker_path}")
|
|
model_card.extend(
|
|
[
|
|
"---",
|
|
"",
|
|
"# LoRA Adapter (Tinker)",
|
|
"",
|
|
f"This repository contains a LoRA adapter exported from Tinker.",
|
|
"",
|
|
"## Source",
|
|
"",
|
|
f"- Tinker checkpoint: {tinker_path}",
|
|
"",
|
|
"## Details",
|
|
"",
|
|
f"- Base model: {base_model}",
|
|
]
|
|
)
|
|
if lora_rank is not None:
|
|
model_card.append(f"- LoRA rank: {lora_rank}")
|
|
if train_mlp is not None or train_attn is not None or train_unembed is not None:
|
|
model_card.append(
|
|
f"- Trained modules: attn={train_attn}, mlp={train_mlp}, unembed={train_unembed}"
|
|
)
|
|
model_card.append("")
|
|
readme_path.write_text("\n".join(model_card), encoding="utf-8")
|
|
|
|
api = HfApi(token=token)
|
|
api.create_repo(repo_id=repo_id, private=private, exist_ok=exist_ok)
|
|
|
|
api.upload_folder(
|
|
folder_path=os.fspath(extract_dir),
|
|
repo_id=repo_id,
|
|
path_in_repo="",
|
|
revision=revision,
|
|
commit_message=commit_message,
|
|
create_pr=create_pr,
|
|
allow_patterns=list(allow_patterns) if allow_patterns else None,
|
|
ignore_patterns=list(ignore_patterns) if ignore_patterns else None,
|
|
)
|
|
|
|
return repo_id
|
|
|
|
def _publish_checkpoint_submit(
|
|
self, training_run_id: types.ModelID, checkpoint_id: str
|
|
) -> AwaitableConcurrentFuture[None]:
|
|
"""Internal method to submit publish checkpoint request."""
|
|
|
|
async def _publish_checkpoint_async() -> None:
|
|
async def _send_request() -> None:
|
|
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
|
await client.post(
|
|
f"/api/v1/training_runs/{training_run_id}/checkpoints/{checkpoint_id}/publish",
|
|
cast_to=NoneType,
|
|
)
|
|
|
|
return await self.holder.execute_with_retries(_send_request)
|
|
|
|
return self.holder.run_coroutine_threadsafe(_publish_checkpoint_async())
|
|
|
|
@sync_only
|
|
@capture_exceptions(fatal=True)
|
|
def publish_checkpoint_from_tinker_path(self, tinker_path: str) -> ConcurrentFuture[None]:
|
|
"""Publish a checkpoint referenced by a tinker path to make it publicly accessible.
|
|
|
|
Only the exact owner of the training run can publish checkpoints.
|
|
Published checkpoints can be unpublished using the unpublish_checkpoint_from_tinker_path method.
|
|
|
|
Args:
|
|
- `tinker_path`: The tinker path to the checkpoint (e.g., "tinker://run-id/weights/0001")
|
|
|
|
Returns:
|
|
- A `Future` that completes when the checkpoint is published
|
|
|
|
Raises:
|
|
HTTPException: 400 if checkpoint identifier is invalid
|
|
HTTPException: 404 if checkpoint not found or user doesn't own the training run
|
|
HTTPException: 409 if checkpoint is already public
|
|
HTTPException: 500 if there's an error publishing the checkpoint
|
|
|
|
Example:
|
|
```python
|
|
future = rest_client.publish_checkpoint_from_tinker_path("tinker://run-id/weights/0001")
|
|
future.result() # Wait for completion
|
|
print("Checkpoint published successfully")
|
|
```
|
|
"""
|
|
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
|
|
return self._publish_checkpoint_submit(
|
|
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
|
|
).future()
|
|
|
|
@capture_exceptions(fatal=True)
|
|
async def publish_checkpoint_from_tinker_path_async(self, tinker_path: str) -> None:
|
|
"""Async version of publish_checkpoint_from_tinker_path."""
|
|
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
|
|
await self._publish_checkpoint_submit(
|
|
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
|
|
)
|
|
|
|
def _unpublish_checkpoint_submit(
|
|
self, training_run_id: types.ModelID, checkpoint_id: str
|
|
) -> AwaitableConcurrentFuture[None]:
|
|
"""Internal method to submit unpublish checkpoint request."""
|
|
|
|
async def _unpublish_checkpoint_async() -> None:
|
|
async def _send_request() -> None:
|
|
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
|
await client.delete(
|
|
f"/api/v1/training_runs/{training_run_id}/checkpoints/{checkpoint_id}/publish",
|
|
cast_to=NoneType,
|
|
)
|
|
|
|
return await self.holder.execute_with_retries(_send_request)
|
|
|
|
return self.holder.run_coroutine_threadsafe(_unpublish_checkpoint_async())
|
|
|
|
@sync_only
|
|
@capture_exceptions(fatal=True)
|
|
def unpublish_checkpoint_from_tinker_path(self, tinker_path: str) -> ConcurrentFuture[None]:
|
|
"""Unpublish a checkpoint referenced by a tinker path to make it private again.
|
|
|
|
Only the exact owner of the training run can unpublish checkpoints.
|
|
This reverses the effect of publishing a checkpoint.
|
|
|
|
Args:
|
|
- `tinker_path`: The tinker path to the checkpoint (e.g., "tinker://run-id/weights/0001")
|
|
|
|
Returns:
|
|
- A `Future` that completes when the checkpoint is unpublished
|
|
|
|
Raises:
|
|
HTTPException: 400 if checkpoint identifier is invalid
|
|
HTTPException: 404 if checkpoint not found or user doesn't own the training run
|
|
HTTPException: 409 if checkpoint is already private
|
|
HTTPException: 500 if there's an error unpublishing the checkpoint
|
|
|
|
Example:
|
|
```python
|
|
future = rest_client.unpublish_checkpoint_from_tinker_path("tinker://run-id/weights/0001")
|
|
future.result() # Wait for completion
|
|
print("Checkpoint unpublished successfully")
|
|
```
|
|
"""
|
|
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
|
|
return self._unpublish_checkpoint_submit(
|
|
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
|
|
).future()
|
|
|
|
@capture_exceptions(fatal=True)
|
|
async def unpublish_checkpoint_from_tinker_path_async(self, tinker_path: str) -> None:
|
|
"""Async version of unpublish_checkpoint_from_tinker_path."""
|
|
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
|
|
await self._unpublish_checkpoint_submit(
|
|
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
|
|
)
|
|
|
|
def _list_user_checkpoints_submit(
|
|
self, limit: int = 100, offset: int = 0
|
|
) -> AwaitableConcurrentFuture[types.CheckpointsListResponse]:
|
|
"""Internal method to submit list user checkpoints request."""
|
|
|
|
async def _list_user_checkpoints_async() -> types.CheckpointsListResponse:
|
|
async def _send_request() -> types.CheckpointsListResponse:
|
|
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
|
params: dict[str, object] = {"limit": limit, "offset": offset}
|
|
|
|
return await client.get(
|
|
"/api/v1/checkpoints",
|
|
options={"params": params},
|
|
cast_to=types.CheckpointsListResponse,
|
|
)
|
|
|
|
return await self.holder.execute_with_retries(_send_request)
|
|
|
|
return self.holder.run_coroutine_threadsafe(_list_user_checkpoints_async())
|
|
|
|
@sync_only
|
|
@capture_exceptions(fatal=True)
|
|
def list_user_checkpoints(
|
|
self, limit: int = 100, offset: int = 0
|
|
) -> ConcurrentFuture[types.CheckpointsListResponse]:
|
|
"""List all checkpoints for the current user across all their training runs.
|
|
|
|
This method retrieves checkpoints from all training runs owned by the authenticated user,
|
|
sorted by time (newest first). It supports pagination for efficiently handling large
|
|
numbers of checkpoints.
|
|
|
|
Args:
|
|
- `limit`: Maximum number of checkpoints to return (default 100)
|
|
- `offset`: Offset for pagination (default 0)
|
|
|
|
Returns:
|
|
- A `Future` containing the `CheckpointsListResponse` with checkpoints and cursor info
|
|
|
|
Example:
|
|
```python
|
|
future = rest_client.list_user_checkpoints(limit=50)
|
|
response = future.result()
|
|
print(f"Found {len(response.checkpoints)} checkpoints")
|
|
print(f"Total: {response.cursor.total_count if response.cursor else 'Unknown'}")
|
|
for checkpoint in response.checkpoints:
|
|
print(f" {checkpoint.training_run_id}/{checkpoint.checkpoint_id}")
|
|
# Get next page if there are more checkpoints
|
|
if response.cursor and response.cursor.offset + response.cursor.limit < response.cursor.total_count:
|
|
next_page = rest_client.list_user_checkpoints(limit=50, offset=50)
|
|
```
|
|
"""
|
|
return self._list_user_checkpoints_submit(limit, offset).future()
|
|
|
|
@capture_exceptions(fatal=True)
|
|
async def list_user_checkpoints_async(
|
|
self, limit: int = 100, offset: int = 0
|
|
) -> types.CheckpointsListResponse:
|
|
"""Async version of list_user_checkpoints."""
|
|
return await self._list_user_checkpoints_submit(limit, offset)
|
|
|
|
def _get_session_submit(
|
|
self, session_id: str
|
|
) -> AwaitableConcurrentFuture[types.GetSessionResponse]:
|
|
"""Internal method to submit get session request."""
|
|
|
|
async def _get_session_async() -> types.GetSessionResponse:
|
|
async def _send_request() -> types.GetSessionResponse:
|
|
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
|
return await client.get(
|
|
f"/api/v1/sessions/{session_id}",
|
|
cast_to=types.GetSessionResponse,
|
|
)
|
|
|
|
return await self.holder.execute_with_retries(_send_request)
|
|
|
|
return self.holder.run_coroutine_threadsafe(_get_session_async())
|
|
|
|
@sync_only
|
|
@capture_exceptions(fatal=True)
|
|
def get_session(self, session_id: str) -> ConcurrentFuture[types.GetSessionResponse]:
|
|
"""Get session information including all training runs and samplers.
|
|
|
|
Args:
|
|
- `session_id`: The session ID to get information for
|
|
|
|
Returns:
|
|
- A `Future` containing the `GetSessionResponse` with training_run_ids and sampler_ids
|
|
|
|
Example:
|
|
```python
|
|
future = rest_client.get_session("session-id")
|
|
response = future.result()
|
|
print(f"Training runs: {len(response.training_run_ids)}")
|
|
print(f"Samplers: {len(response.sampler_ids)}")
|
|
```
|
|
"""
|
|
return self._get_session_submit(session_id).future()
|
|
|
|
@capture_exceptions(fatal=True)
|
|
async def get_session_async(self, session_id: str) -> types.GetSessionResponse:
|
|
"""Async version of get_session."""
|
|
return await self._get_session_submit(session_id)
|
|
|
|
def _list_sessions_submit(
|
|
self, limit: int = 20, offset: int = 0
|
|
) -> AwaitableConcurrentFuture[types.ListSessionsResponse]:
|
|
"""Internal method to submit list sessions request."""
|
|
|
|
async def _list_sessions_async() -> types.ListSessionsResponse:
|
|
async def _send_request() -> types.ListSessionsResponse:
|
|
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
|
params: dict[str, object] = {"limit": limit, "offset": offset}
|
|
|
|
return await client.get(
|
|
"/api/v1/sessions",
|
|
options={"params": params},
|
|
cast_to=types.ListSessionsResponse,
|
|
)
|
|
|
|
return await self.holder.execute_with_retries(_send_request)
|
|
|
|
return self.holder.run_coroutine_threadsafe(_list_sessions_async())
|
|
|
|
@sync_only
|
|
@capture_exceptions(fatal=True)
|
|
def list_sessions(
|
|
self, limit: int = 20, offset: int = 0
|
|
) -> ConcurrentFuture[types.ListSessionsResponse]:
|
|
"""List sessions with pagination support.
|
|
|
|
Args:
|
|
- `limit`: Maximum number of sessions to return (default 20)
|
|
- `offset`: Offset for pagination (default 0)
|
|
|
|
Returns:
|
|
- A `Future` containing the `ListSessionsResponse` with list of session IDs
|
|
|
|
Example:
|
|
```python
|
|
future = rest_client.list_sessions(limit=50)
|
|
response = future.result()
|
|
print(f"Found {len(response.sessions)} sessions")
|
|
# Get next page
|
|
next_page = rest_client.list_sessions(limit=50, offset=50)
|
|
```
|
|
"""
|
|
return self._list_sessions_submit(limit, offset).future()
|
|
|
|
@capture_exceptions(fatal=True)
|
|
async def list_sessions_async(
|
|
self, limit: int = 20, offset: int = 0
|
|
) -> types.ListSessionsResponse:
|
|
"""Async version of list_sessions."""
|
|
return await self._list_sessions_submit(limit, offset)
|
|
|
|
@capture_exceptions(fatal=True)
|
|
def get_sampler(self, sampler_id: str) -> APIFuture[types.GetSamplerResponse]:
|
|
"""Get sampler information.
|
|
|
|
Args:
|
|
- `sampler_id`: The sampler ID (sampling_session_id) to get information for
|
|
|
|
Returns:
|
|
- An `APIFuture` containing the `GetSamplerResponse` with sampler details
|
|
|
|
Example:
|
|
```python
|
|
# Sync usage
|
|
future = rest_client.get_sampler("session-id:sample:0")
|
|
response = future.result()
|
|
print(f"Base model: {response.base_model}")
|
|
print(f"Model path: {response.model_path}")
|
|
|
|
# Async usage
|
|
response = await rest_client.get_sampler("session-id:sample:0")
|
|
print(f"Base model: {response.base_model}")
|
|
```
|
|
"""
|
|
|
|
async def _get_sampler_async() -> types.GetSamplerResponse:
|
|
async def _send_request() -> types.GetSamplerResponse:
|
|
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
|
return await client.get(
|
|
f"/api/v1/samplers/{sampler_id}",
|
|
cast_to=types.GetSamplerResponse,
|
|
)
|
|
|
|
return await self.holder.execute_with_retries(_send_request)
|
|
|
|
return self.holder.run_coroutine_threadsafe(_get_sampler_async())
|
|
|
|
@capture_exceptions(fatal=True)
|
|
async def get_sampler_async(self, sampler_id: str) -> types.GetSamplerResponse:
|
|
"""Async version of get_sampler."""
|
|
return await self.get_sampler(sampler_id)
|