diff --git a/docs/api/restclient.md b/docs/api/restclient.md index 0a1b89f..249c526 100644 --- a/docs/api/restclient.md +++ b/docs/api/restclient.md @@ -288,6 +288,56 @@ async def get_checkpoint_archive_url_from_tinker_path_async( Async version of get_checkpoint_archive_url_from_tinker_path. +#### `export_checkpoint_to_hub` + +```python +def export_checkpoint_to_hub( + tinker_path: str, + repo_id: str | None, + *, + private: bool = True, + token: str | None = None, + revision: str | None = None, + commit_message: str | None = None, + create_pr: bool = False, + exist_ok: bool = True, + allow_patterns: Sequence[str] | None = None, + ignore_patterns: Sequence[str] | None = None, + add_model_card: bool = True, +) -> str +``` + +Download a checkpoint archive, extract the PEFT adapter files, optionally add a README.md +model card, and upload to the Hugging Face Hub. + +Args: +- `tinker_path`: The tinker path to the checkpoint (e.g., "tinker://run-id/weights/0001") +- `repo_id`: Hugging Face repo ID (e.g., "username/my-lora-adapter"). If None, + a name is derived from the base model and checkpoint path. 
+- `private`: Whether to create the repo as private (default True) +- `token`: Hugging Face access token (optional) +- `revision`: Target branch/revision to upload to (optional) +- `commit_message`: Commit message for the upload (optional) +- `create_pr`: Whether to create a PR instead of pushing to the main branch +- `exist_ok`: Whether repo creation should succeed if repo exists +- `allow_patterns`: Optional list of file patterns to include +- `ignore_patterns`: Optional list of file patterns to exclude +- `add_model_card`: Whether to add a README.md if missing (default True) + +Returns: +- The repo_id that was uploaded to + +Example: +```python +rest_client = service_client.create_rest_client() +repo_id = rest_client.export_checkpoint_to_hub( + "tinker://run-id/weights/final", + "username/my-lora-adapter", + private=True, +) +print(f"Uploaded to: {repo_id}") +``` + #### `publish_checkpoint_from_tinker_path` ```python diff --git a/src/tinker/lib/public_interfaces/rest_client.py b/src/tinker/lib/public_interfaces/rest_client.py index cb6ae7a..eed3402 100644 --- a/src/tinker/lib/public_interfaces/rest_client.py +++ b/src/tinker/lib/public_interfaces/rest_client.py @@ -3,6 +3,7 @@ from __future__ import annotations import logging +from typing import Sequence from concurrent.futures import Future as ConcurrentFuture from typing import TYPE_CHECKING @@ -397,6 +398,194 @@ class RestClient(TelemetryProvider): parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id ) + @sync_only + @capture_exceptions(fatal=True) + def export_checkpoint_to_hub( + self, + tinker_path: str, + repo_id: str | None, + *, + private: bool = True, + token: str | None = None, + revision: str | None = None, + commit_message: str | None = None, + create_pr: bool = False, + exist_ok: bool = True, + allow_patterns: Sequence[str] | None = None, + ignore_patterns: Sequence[str] | None = None, + add_model_card: bool = True, + ) -> str: + """Export a checkpoint archive to the Hugging Face 
Hub as a PEFT adapter. + + This downloads the checkpoint archive, extracts it locally, optionally adds a + README.md model card, and uploads the folder to the Hub. + + Args: + - `tinker_path`: The tinker path to the checkpoint (e.g., "tinker://run-id/weights/0001") + - `repo_id`: Hugging Face repo ID (e.g., "username/my-lora-adapter"). If None, + a name is derived from the base model and checkpoint path. + - `private`: Whether to create the repo as private (default True) + - `token`: Hugging Face access token (optional) + - `revision`: Target branch/revision to upload to (optional) + - `commit_message`: Commit message for the upload (optional) + - `create_pr`: Whether to create a PR instead of pushing to the main branch + - `exist_ok`: Whether repo creation should succeed if repo exists + - `allow_patterns`: Optional list of file patterns to include + - `ignore_patterns`: Optional list of file patterns to exclude + - `add_model_card`: Whether to add a README.md if missing (default True) + + Returns: + - The repo_id that was uploaded to + """ + # Lazy imports to keep base SDK lightweight + try: + from huggingface_hub import HfApi + except ImportError as exc: # pragma: no cover - optional dependency + raise ImportError( + "huggingface_hub is required for export_checkpoint_to_hub. 
" + "Install it with: pip install huggingface_hub" + ) from exc + + import os + import tarfile + import tempfile + import urllib.request + from pathlib import Path + + # Validate tinker path + types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path) + + def _safe_extract(tar: tarfile.TarFile, path: Path) -> None: + for member in tar.getmembers(): + member_path = path / member.name + if not str(member_path.resolve()).startswith(str(path.resolve())): + raise ValueError("Unsafe path in tar archive") + tar.extractall(path=path) + + parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path) + + def _sanitize_repo_name(value: str) -> str: + safe_chars = [] + for ch in value: + if ch.isalnum() or ch in {"-", "_", "."}: + safe_chars.append(ch) + else: + safe_chars.append("-") + # Collapse repeated separators + name = "".join(safe_chars) + while "--" in name: + name = name.replace("--", "-") + return name.strip("-_ .") + + url_response = self.get_checkpoint_archive_url_from_tinker_path(tinker_path).result() + + with tempfile.TemporaryDirectory() as temp_dir: + temp_root = Path(temp_dir) + archive_path = temp_root / "checkpoint.tar" + extract_dir = temp_root / "extract" + extract_dir.mkdir(parents=True, exist_ok=True) + + # Download archive + with urllib.request.urlopen(url_response.url, timeout=60) as response: + with open(archive_path, "wb") as f: + while True: + chunk = response.read(8192) + if not chunk: + break + f.write(chunk) + + # Extract archive + with tarfile.open(archive_path, "r") as tar: + _safe_extract(tar, extract_dir) + + # Validate PEFT adapter files exist + adapter_config = extract_dir / "adapter_config.json" + adapter_safetensors = extract_dir / "adapter_model.safetensors" + adapter_bin = extract_dir / "adapter_model.bin" + if not adapter_config.exists() or not (adapter_safetensors.exists() or adapter_bin.exists()): + raise ValueError( + "Checkpoint archive does not contain a PEFT adapter. 
" + "Expected adapter_config.json and adapter_model.safetensors (or adapter_model.bin)." + ) + + base_model = "unknown" + lora_rank = None + train_mlp = None + train_attn = None + train_unembed = None + try: + weights_info = self.get_weights_info_by_tinker_path(tinker_path).result() + base_model = weights_info.base_model + lora_rank = weights_info.lora_rank + train_mlp = weights_info.train_mlp + train_attn = weights_info.train_attn + train_unembed = weights_info.train_unembed + except Exception: + pass + + if repo_id is None: + base_short = base_model.split("/")[-1] if base_model != "unknown" else "adapter" + checkpoint_id = parsed_tinker_path.checkpoint_id.replace("/", "-") + derived = f"tinker-{base_short}-{parsed_tinker_path.training_run_id}-{checkpoint_id}" + repo_id = _sanitize_repo_name(derived) + + # Add a lightweight model card if missing + readme_path = extract_dir / "README.md" + if add_model_card and not readme_path.exists(): + tags: list[str] = ["tinker", "peft", "lora", "transformers"] + if base_model != "unknown": + tags.append(f"base_model:adapter:{base_model}") + model_card = [ + "---", + f"base_model: {base_model}", + "library_name: peft", + "tags:", + ] + for tag in tags: + model_card.append(f"- {tag}") + model_card.append(f"tinker_path: {tinker_path}") + model_card.extend( + [ + "---", + "", + "# LoRA Adapter (Tinker)", + "", + f"This repository contains a LoRA adapter exported from Tinker.", + "", + "## Source", + "", + f"- Tinker checkpoint: {tinker_path}", + "", + "## Details", + "", + f"- Base model: {base_model}", + ] + ) + if lora_rank is not None: + model_card.append(f"- LoRA rank: {lora_rank}") + if train_mlp is not None or train_attn is not None or train_unembed is not None: + model_card.append( + f"- Trained modules: attn={train_attn}, mlp={train_mlp}, unembed={train_unembed}" + ) + model_card.append("") + readme_path.write_text("\n".join(model_card), encoding="utf-8") + + api = HfApi(token=token) + api.create_repo(repo_id=repo_id, 
private=private, exist_ok=exist_ok) + + api.upload_folder( + folder_path=os.fspath(extract_dir), + repo_id=repo_id, + path_in_repo="", + revision=revision, + commit_message=commit_message, + create_pr=create_pr, + allow_patterns=list(allow_patterns) if allow_patterns else None, + ignore_patterns=list(ignore_patterns) if ignore_patterns else None, + ) + + return repo_id + def _publish_checkpoint_submit( self, training_run_id: types.ModelID, checkpoint_id: str ) -> AwaitableConcurrentFuture[None]: