initial export to hub

This commit is contained in:
Kashif Rasul 2026-01-27 19:55:25 +01:00
parent ca40e08bb4
commit d6aa21005b
2 changed files with 239 additions and 0 deletions

View file

@ -288,6 +288,56 @@ async def get_checkpoint_archive_url_from_tinker_path_async(
Async version of get_checkpoint_archive_url_from_tinker_path.
#### `export_checkpoint_to_hub`
```python
def export_checkpoint_to_hub(
tinker_path: str,
repo_id: str | None,
*,
private: bool = True,
token: str | None = None,
revision: str | None = None,
commit_message: str | None = None,
create_pr: bool = False,
exist_ok: bool = True,
allow_patterns: Sequence[str] | None = None,
ignore_patterns: Sequence[str] | None = None,
add_model_card: bool = True,
) -> str
```
Download a checkpoint archive, extract the PEFT adapter files, optionally add a README.md
model card, and upload to the Hugging Face Hub.
Args:
- `tinker_path`: The tinker path to the checkpoint (e.g., "tinker://run-id/weights/0001")
- `repo_id`: Hugging Face repo ID (e.g., "username/my-lora-adapter"). If None,
a name is derived from the base model and checkpoint path.
- `private`: Whether to create the repo as private (default True)
- `token`: Hugging Face access token (optional)
- `revision`: Target branch/revision to upload to (optional)
- `commit_message`: Commit message for the upload (optional)
- `create_pr`: Whether to create a PR instead of pushing to the main branch
- `exist_ok`: Whether repo creation should succeed if repo exists
- `allow_patterns`: Optional list of file patterns to include
- `ignore_patterns`: Optional list of file patterns to exclude
- `add_model_card`: Whether to add a README.md if missing (default True)
Returns:
- The repo_id that was uploaded to
Example:
```python
rest_client = service_client.create_rest_client()
repo_id = rest_client.export_checkpoint_to_hub(
"tinker://run-id/weights/final",
"username/my-lora-adapter",
private=True,
)
print(f"Uploaded to: {repo_id}")
```
#### `publish_checkpoint_from_tinker_path`
```python

View file

@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Sequence
from concurrent.futures import Future as ConcurrentFuture from concurrent.futures import Future as ConcurrentFuture
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
@ -397,6 +398,194 @@ class RestClient(TelemetryProvider):
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
) )
@sync_only
@capture_exceptions(fatal=True)
def export_checkpoint_to_hub(
    self,
    tinker_path: str,
    repo_id: str | None,
    *,
    private: bool = True,
    token: str | None = None,
    revision: str | None = None,
    commit_message: str | None = None,
    create_pr: bool = False,
    exist_ok: bool = True,
    allow_patterns: Sequence[str] | None = None,
    ignore_patterns: Sequence[str] | None = None,
    add_model_card: bool = True,
) -> str:
    """Export a checkpoint archive to the Hugging Face Hub as a PEFT adapter.

    This downloads the checkpoint archive, extracts it locally, optionally adds a
    README.md model card, and uploads the folder to the Hub.

    Args:
    - `tinker_path`: The tinker path to the checkpoint (e.g., "tinker://run-id/weights/0001")
    - `repo_id`: Hugging Face repo ID (e.g., "username/my-lora-adapter"). If None,
      a name is derived from the base model and checkpoint path.
    - `private`: Whether to create the repo as private (default True)
    - `token`: Hugging Face access token (optional)
    - `revision`: Target branch/revision to upload to (optional)
    - `commit_message`: Commit message for the upload (optional)
    - `create_pr`: Whether to create a PR instead of pushing to the main branch
    - `exist_ok`: Whether repo creation should succeed if repo exists
    - `allow_patterns`: Optional list of file patterns to include
    - `ignore_patterns`: Optional list of file patterns to exclude
    - `add_model_card`: Whether to add a README.md if missing (default True)

    Returns:
    - The repo_id that was uploaded to

    Raises:
    - `ImportError`: If huggingface_hub is not installed.
    - `ValueError`: If the archive contains unsafe paths or is not a PEFT adapter.
    """
    # Lazy imports to keep base SDK lightweight
    try:
        from huggingface_hub import HfApi
    except ImportError as exc:  # pragma: no cover - optional dependency
        raise ImportError(
            "huggingface_hub is required for export_checkpoint_to_hub. "
            "Install it with: pip install huggingface_hub"
        ) from exc

    import os
    import shutil
    import tarfile
    import tempfile
    import urllib.request
    from pathlib import Path

    # Parse once: this both validates the tinker path (raises on malformed
    # input) and provides the run/checkpoint ids used for the derived repo
    # name below.
    parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)

    def _safe_extract(tar: tarfile.TarFile, path: Path) -> None:
        # Guard against path traversal ("../" components or absolute member
        # names). A bare startswith() prefix check is NOT sufficient here:
        # "/tmp/extract-evil".startswith("/tmp/extract") is True even though
        # the target lies outside the extraction directory, so we use
        # Path.is_relative_to on resolved paths instead.
        root = path.resolve()
        for member in tar.getmembers():
            if not (path / member.name).resolve().is_relative_to(root):
                raise ValueError("Unsafe path in tar archive")
        tar.extractall(path=path)

    def _sanitize_repo_name(value: str) -> str:
        # Map every character the Hub rejects to "-", collapse repeated
        # separators, and strip separator characters from both ends.
        safe_chars = [
            ch if ch.isalnum() or ch in {"-", "_", "."} else "-" for ch in value
        ]
        name = "".join(safe_chars)
        while "--" in name:
            name = name.replace("--", "-")
        return name.strip("-_ .")

    url_response = self.get_checkpoint_archive_url_from_tinker_path(tinker_path).result()

    with tempfile.TemporaryDirectory() as temp_dir:
        temp_root = Path(temp_dir)
        archive_path = temp_root / "checkpoint.tar"
        extract_dir = temp_root / "extract"
        extract_dir.mkdir(parents=True, exist_ok=True)

        # Stream the archive to disk so we never hold the whole file in memory.
        with urllib.request.urlopen(url_response.url, timeout=60) as response:
            with open(archive_path, "wb") as f:
                shutil.copyfileobj(response, f)

        # Extract archive (with traversal protection above)
        with tarfile.open(archive_path, "r") as tar:
            _safe_extract(tar, extract_dir)

        # Validate PEFT adapter files exist
        adapter_config = extract_dir / "adapter_config.json"
        adapter_safetensors = extract_dir / "adapter_model.safetensors"
        adapter_bin = extract_dir / "adapter_model.bin"
        if not adapter_config.exists() or not (
            adapter_safetensors.exists() or adapter_bin.exists()
        ):
            raise ValueError(
                "Checkpoint archive does not contain a PEFT adapter. "
                "Expected adapter_config.json and adapter_model.safetensors (or adapter_model.bin)."
            )

        # Best-effort metadata lookup for the model card and the derived repo
        # name; the export still proceeds if this call fails.
        base_model = "unknown"
        lora_rank = None
        train_mlp = None
        train_attn = None
        train_unembed = None
        try:
            weights_info = self.get_weights_info_by_tinker_path(tinker_path).result()
            base_model = weights_info.base_model
            lora_rank = weights_info.lora_rank
            train_mlp = weights_info.train_mlp
            train_attn = weights_info.train_attn
            train_unembed = weights_info.train_unembed
        except Exception:
            pass

        if repo_id is None:
            base_short = base_model.split("/")[-1] if base_model != "unknown" else "adapter"
            checkpoint_id = parsed_tinker_path.checkpoint_id.replace("/", "-")
            derived = f"tinker-{base_short}-{parsed_tinker_path.training_run_id}-{checkpoint_id}"
            repo_id = _sanitize_repo_name(derived)

        # Add a lightweight model card if missing
        readme_path = extract_dir / "README.md"
        if add_model_card and not readme_path.exists():
            tags: list[str] = ["tinker", "peft", "lora", "transformers"]
            if base_model != "unknown":
                tags.append(f"base_model:adapter:{base_model}")
            model_card = [
                "---",
                f"base_model: {base_model}",
                "library_name: peft",
                "tags:",
            ]
            for tag in tags:
                model_card.append(f"- {tag}")
            model_card.append(f"tinker_path: {tinker_path}")
            model_card.extend(
                [
                    "---",
                    "",
                    "# LoRA Adapter (Tinker)",
                    "",
                    "This repository contains a LoRA adapter exported from Tinker.",
                    "",
                    "## Source",
                    "",
                    f"- Tinker checkpoint: {tinker_path}",
                    "",
                    "## Details",
                    "",
                    f"- Base model: {base_model}",
                ]
            )
            if lora_rank is not None:
                model_card.append(f"- LoRA rank: {lora_rank}")
            if train_mlp is not None or train_attn is not None or train_unembed is not None:
                model_card.append(
                    f"- Trained modules: attn={train_attn}, mlp={train_mlp}, unembed={train_unembed}"
                )
            model_card.append("")
            readme_path.write_text("\n".join(model_card), encoding="utf-8")

        api = HfApi(token=token)
        api.create_repo(repo_id=repo_id, private=private, exist_ok=exist_ok)
        api.upload_folder(
            folder_path=os.fspath(extract_dir),
            repo_id=repo_id,
            path_in_repo="",
            revision=revision,
            commit_message=commit_message,
            create_pr=create_pr,
            allow_patterns=list(allow_patterns) if allow_patterns else None,
            ignore_patterns=list(ignore_patterns) if ignore_patterns else None,
        )
        return repo_id
def _publish_checkpoint_submit( def _publish_checkpoint_submit(
self, training_run_id: types.ModelID, checkpoint_id: str self, training_run_id: types.ModelID, checkpoint_id: str
) -> AwaitableConcurrentFuture[None]: ) -> AwaitableConcurrentFuture[None]: