CLI to export checkpoints to HF (#16)

Adds support for uploading checkpoints to HF.
This commit is contained in:
Kashif Rasul 2026-01-31 01:59:32 +01:00 committed by GitHub
parent ca40e08bb4
commit 55114e8c45
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 474 additions and 91 deletions

View file

@ -62,6 +62,7 @@ class LazyGroup(click.Group):
- `tinker run info <run-id>` - Show details of a specific run
- `tinker checkpoint list` - List all checkpoints
- `tinker checkpoint info <checkpoint-id>` - Show checkpoint details
- `tinker checkpoint push-hf <checkpoint-path>` - Upload a checkpoint to Hugging Face Hub
### 4. Output System with Inheritance
@ -241,6 +242,9 @@ tinker checkpoint list run-abc123
# Show checkpoint details
tinker checkpoint info ckpt-xyz789
# Upload checkpoint to Hugging Face Hub
tinker checkpoint push-hf tinker://run-abc123/sampler_weights/000040 --repo username/my-lora-adapter
# JSON output
tinker --format json run list
tinker --format json checkpoint list

View file

@ -239,6 +239,50 @@ class CheckpointDownloadOutput(OutputBase):
return rows
class CheckpointHubUploadOutput(OutputBase):
    """Result record rendered by the 'tinker checkpoint push-hf' command.

    Stores the source checkpoint path and the destination repo id, plus an
    optional revision and visibility flag. Optional values that are ``None``
    are left out of both the dict and the table renderings.
    """

    def __init__(
        self,
        checkpoint_path: str,
        repo_id: str,
        revision: str | None = None,
        public: bool | None = None,
    ):
        self.checkpoint_path, self.repo_id = checkpoint_path, repo_id
        self.revision, self.public = revision, public

    def to_dict(self) -> Dict[str, Any]:
        """Return the upload summary as a dict, skipping unset optional fields."""
        payload: Dict[str, Any] = {
            "checkpoint_path": self.checkpoint_path,
            "repo_id": self.repo_id,
        }
        optional_fields = {"revision": self.revision, "public": self.public}
        payload.update(
            {key: value for key, value in optional_fields.items() if value is not None}
        )
        return payload

    def get_title(self) -> str | None:
        """Return the heading shown above the table output."""
        return f"Checkpoint Hub Upload: {self.checkpoint_path}"

    def get_table_columns(self) -> List[str]:
        """Return the two column headers of the property/value table."""
        return ["Property", "Value"]

    def get_table_rows(self) -> List[List[str]]:
        """Return one property/value row per populated field."""
        table = [
            ["Checkpoint Path", self.checkpoint_path],
            ["Repo ID", self.repo_id],
        ]
        if self.revision is not None:
            table += [["Revision", self.revision]]
        if self.public is not None:
            table += [["Public", format_bool(self.public)]]
        return table
def get_checkpoint_from_path(client: "RestClient", checkpoint_path: str) -> "Checkpoint":
"""Get checkpoint details from a tinker path.
@ -280,6 +324,307 @@ def get_checkpoint_from_path(client: "RestClient", checkpoint_path: str) -> "Che
raise TinkerCliError(f"Failed to retrieve checkpoint: {e}")
def _export_checkpoint_to_hub(
    client: "RestClient",
    tinker_path: str,
    repo_id: str | None,
    *,
    private: bool,
    revision: str | None,
    commit_message: str | None,
    create_pr: bool,
    exist_ok: bool,
    allow_patterns: list[str] | None,
    ignore_patterns: list[str] | None,
    add_model_card: bool,
) -> str:
    """Export a Tinker checkpoint to the Hugging Face Hub as a PEFT adapter.

    The checkpoint archive is downloaded and extracted into a temporary
    directory, checked for the expected adapter files, optionally given a
    generated README.md model card, and uploaded with ``HfApi.upload_folder``.

    Args:
        client: REST client used to fetch the archive URL and weights info.
        tinker_path: ``tinker://`` path of the checkpoint to export.
        repo_id: Target Hub repo. When ``None`` a name is derived from the
            base model name and the training run id.
        private: Forwarded to ``HfApi.create_repo``.
        revision: Target revision. When ``None`` it is derived from the
            checkpoint id of ``tinker_path``.
        commit_message: Forwarded to ``HfApi.upload_folder``.
        create_pr: Forwarded to ``HfApi.upload_folder``.
        exist_ok: Forwarded to ``HfApi.create_repo``.
        allow_patterns: Upload allow-list. When ``None`` the
            ``checkpoint_complete`` marker is added to the ignore list.
        ignore_patterns: Upload ignore-list.
        add_model_card: Write a README.md model card when the extracted
            archive does not already contain one.

    Returns:
        The repo id that was uploaded to, as reported by ``create_repo``
        (falling back to the requested or derived id).

    Raises:
        TinkerCliError: If ``huggingface_hub`` is not installed, the user is
            not logged in, the archive lacks the adapter files or the
            ``checkpoint_complete`` marker, or the target repo's README
            references a different Tinker checkpoint.
    """
    # Lazy imports to keep CLI startup fast
    try:
        from huggingface_hub import HfApi, hf_hub_download
    except ImportError as exc:
        raise TinkerCliError(
            "huggingface_hub is required for this command.",
            "Install it with: pip install huggingface_hub, then run: hf auth login",
        ) from exc

    import json
    import os
    import re
    import tempfile
    from pathlib import Path

    from tinker import ParsedCheckpointTinkerPath

    # Validate tinker path
    parsed_tinker_path = ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)

    api = HfApi()
    try:
        api.whoami()
    except Exception as exc:
        raise TinkerCliError("Not logged in", "Run: hf auth login") from exc

    def _sanitize_repo_name(value: str) -> str:
        """Replace disallowed characters with '-' and trim separators from the ends."""
        name = "".join(
            ch if ch.isalnum() or ch in {"-", "_", "."} else "-" for ch in value
        )
        # Collapse runs of dashes produced by consecutive replaced characters.
        name = re.sub(r"-{2,}", "-", name)
        return name.strip("-_ .")

    with tempfile.TemporaryDirectory() as temp_dir:
        temp_root = Path(temp_dir)
        archive_path = temp_root / "checkpoint.tar"
        extract_dir = temp_root / "extract"
        extract_dir.mkdir(parents=True, exist_ok=True)

        url_response = client.get_checkpoint_archive_url_from_tinker_path(tinker_path).result()
        _download_checkpoint_archive(
            url_response.url,
            archive_path=archive_path,
            show_progress=False,
            format="json",
        )
        _safe_extract_tar(archive_path, extract_dir, show_progress=False, format="json")

        adapter_config = extract_dir / "adapter_config.json"
        adapter_safetensors = extract_dir / "adapter_model.safetensors"
        adapter_bin = extract_dir / "adapter_model.bin"
        checkpoint_complete = extract_dir / "checkpoint_complete"

        has_weights = adapter_safetensors.exists() or adapter_bin.exists()
        if not adapter_config.exists() or not has_weights:
            raise TinkerCliError(
                "Checkpoint archive does not contain a PEFT adapter.",
                "Expected adapter_config.json and adapter_model.safetensors (or adapter_model.bin).",
            )
        if not checkpoint_complete.exists():
            raise TinkerCliError(
                "Checkpoint archive is missing 'checkpoint_complete'.",
                "The adapter files may be incomplete.",
            )

        base_model = "unknown"
        lora_rank = None
        train_mlp = None
        train_attn = None
        train_unembed = None
        try:
            weights_info = client.get_weights_info_by_tinker_path(tinker_path).result()
            base_model = weights_info.base_model
            lora_rank = weights_info.lora_rank
            train_mlp = weights_info.train_mlp
            train_attn = weights_info.train_attn
            train_unembed = weights_info.train_unembed
        except Exception:
            # Best effort: this metadata only enriches the model card and the
            # derived repo name, so the export continues with the defaults.
            pass

        try:
            config_data = json.loads(adapter_config.read_text(encoding="utf-8"))
            if not isinstance(config_data.get("base_model_name_or_path"), str):
                # Fill in the base model so the adapter config names one.
                config_data["base_model_name_or_path"] = base_model
                adapter_config.write_text(
                    json.dumps(config_data, indent=2, sort_keys=True) + "\n", encoding="utf-8"
                )
        except Exception:
            # Best effort: an unreadable or unparsable config is uploaded as-is.
            pass

        if repo_id is None:
            base_short = base_model.split("/")[-1] if base_model != "unknown" else "adapter"
            derived = f"tinker-{base_short}-{parsed_tinker_path.training_run_id}"
            repo_id = _sanitize_repo_name(derived)

        if revision is None:
            revision = _sanitize_repo_name(parsed_tinker_path.checkpoint_id.replace("/", "-"))

        # Create the repo before writing the model card and adopt the id the
        # Hub reports. A bare name (always the case for a derived id) is created
        # under the caller's namespace, and the later download/upload calls and
        # the usage snippet need the fully qualified "owner/name" form.
        created_repo = api.create_repo(repo_id=repo_id, private=private, exist_ok=exist_ok)
        repo_id = getattr(created_repo, "repo_id", None) or repo_id

        readme_path = extract_dir / "README.md"
        if add_model_card and not readme_path.exists():
            tags: list[str] = ["tinker", "peft", "lora"]
            if base_model != "unknown":
                tags.append(f"base_model:adapter:{base_model}")
            model_card = [
                "---",
                f"base_model: {base_model}",
                "library_name: peft",
                "tags:",
            ]
            model_card.extend(f"- {tag}" for tag in tags)
            model_card.append(f"tinker_path: {tinker_path}")
            model_card.extend(
                [
                    "---",
                    "",
                    "# Tinker LoRA Adapter",
                    "",
                    "This repository contains a LoRA adapter exported from Tinker.",
                    "",
                    "## Usage",
                    "",
                    "```python",
                    "from transformers import AutoModelForCausalLM",
                    "",
                    f'adapter_id = "{repo_id}"',
                    f'base_model = "{base_model}"',
                    "",
                    'model = AutoModelForCausalLM.from_pretrained(adapter_id, device_map="auto")',
                    "```",
                    "",
                    "## Source",
                    "",
                    "```",
                    f"{tinker_path}",
                    "```",
                    "",
                    "## Details",
                    "",
                    f"- Base model: {base_model}",
                ]
            )
            if lora_rank is not None:
                model_card.append(f"- LoRA rank: {lora_rank}")
            if train_mlp is not None or train_attn is not None or train_unembed is not None:
                model_card.append(
                    f"- Trained modules: attn={train_attn}, mlp={train_mlp}, unembed={train_unembed}"
                )
            model_card.append("")
            readme_path.write_text("\n".join(model_card), encoding="utf-8")

        def _readme_tinker_path() -> str | None:
            """Return the first tinker:// path in the repo's README, if one can be read."""
            try:
                readme_file = hf_hub_download(
                    repo_id=repo_id,
                    filename="README.md",
                    revision=revision,
                    token=None,
                )
            except Exception:
                # No README at this revision (or it cannot be fetched): nothing to compare.
                return None
            try:
                text = Path(readme_file).read_text(encoding="utf-8", errors="ignore")
            except Exception:
                return None
            match = re.search(r"tinker://[^\s`]+", text)
            return match.group(0) if match else None

        # Refuse to overwrite a repo whose README points at another checkpoint.
        existing_tinker_path = _readme_tinker_path()
        if existing_tinker_path and existing_tinker_path != tinker_path:
            raise TinkerCliError(
                "Repo ID appears to contain a different Tinker checkpoint.",
                f"Found {existing_tinker_path}, expected {tinker_path}.",
            )

        if allow_patterns is None:
            # Without an explicit allow-list, keep the completion marker out of the repo.
            ignore_patterns = list(ignore_patterns) if ignore_patterns else []
            if "checkpoint_complete" not in ignore_patterns:
                ignore_patterns.append("checkpoint_complete")

        api.upload_folder(
            folder_path=os.fspath(extract_dir),
            repo_id=repo_id,
            path_in_repo="",
            revision=revision,
            commit_message=commit_message,
            create_pr=create_pr,
            allow_patterns=list(allow_patterns) if allow_patterns else None,
            ignore_patterns=list(ignore_patterns) if ignore_patterns else None,
        )

    return repo_id
def _safe_extract_tar(
    archive_path,
    extract_dir,
    *,
    show_progress: bool,
    format: str,
) -> None:
    """Extract a tar archive into ``extract_dir`` after vetting every member.

    All members are checked before anything is written: symlinks and
    hardlinks are rejected, and so is any member whose resolved destination
    falls outside ``extract_dir``.

    Args:
        archive_path: Path of the tar archive to read.
        extract_dir: Destination directory, as a ``pathlib.Path``.
        show_progress: Show a click progress bar while extracting.
        format: CLI output format; the progress bar is suppressed for "json".

    Raises:
        TinkerCliError: If the archive contains a link member or a member
            that would be written outside ``extract_dir``.
    """
    import tarfile

    base = extract_dir.resolve()
    with tarfile.open(archive_path, "r") as tar:
        members = tar.getmembers()
        for member in members:
            if member.issym() or member.islnk():
                raise TinkerCliError(
                    "Unsafe symlink or hardlink in tar archive",
                    "Archive may be corrupted or malicious.",
                )
            member_path = (extract_dir / member.name).resolve()
            # Compare path components, not string prefixes: a sibling such as
            # "<base>_evil/x" starts with str(base) but lies outside of it.
            if member_path != base and base not in member_path.parents:
                raise TinkerCliError(
                    "Unsafe path in tar archive",
                    "Archive may be corrupted or malicious.",
                )

        if show_progress and format != "json":
            with click.progressbar(
                members,
                label="Extracting archive ",
                show_percent=True,
                show_pos=True,
            ) as bar:
                for member in bar:
                    tar.extract(member, path=extract_dir)
        else:
            tar.extractall(path=extract_dir)
def _download_checkpoint_archive(
    url: str,
    *,
    archive_path,
    show_progress: bool,
    format: str,
) -> int:
    """Stream the archive at ``url`` into ``archive_path``.

    Args:
        url: Location of the checkpoint archive.
        archive_path: Destination file, as a ``pathlib.Path``.
        show_progress: Show a click progress bar while downloading.
        format: CLI output format; the progress bar is suppressed for "json".

    Returns:
        The size from the response's Content-Length header, or 0 when the
        header is absent.

    Raises:
        TinkerCliError: If the URL cannot be fetched or the file cannot be
            written.
    """
    import urllib.error
    import urllib.request

    def _copy_stream(source, notify=None) -> None:
        # Copy in 8 KiB chunks, reporting each chunk's length when asked to.
        with open(archive_path, "wb") as sink:
            while chunk := source.read(8192):
                sink.write(chunk)
                if notify is not None:
                    notify(len(chunk))

    try:
        with urllib.request.urlopen(url, timeout=60) as response:
            total_size = int(response.headers.get("Content-Length", 0))
            if not show_progress or format == "json":
                _copy_stream(response)
            else:
                with click.progressbar(
                    length=total_size,
                    label="Downloading archive",
                    show_percent=True,
                    show_pos=True,
                    show_eta=True,
                ) as bar:
                    _copy_stream(response, bar.update)
    except urllib.error.URLError as e:
        raise TinkerCliError(
            f"Failed to download checkpoint: {e}",
            "Please check your network connection and try again.",
        ) from e
    except IOError as e:
        raise TinkerCliError(
            f"Failed to save checkpoint: {e}",
            f"Please check that you have write permissions to {archive_path.parent}",
        ) from e

    return total_size
# Click command group for checkpoint commands
@click.group()
def cli():
@ -545,10 +890,7 @@ def download(
"""
# Lazy imports to maintain fast CLI startup
import shutil
import tarfile
import tempfile
import urllib.error
import urllib.request
from pathlib import Path
# Validate it's a tinker path
@ -584,99 +926,136 @@ def download(
"Use --force to overwrite or choose a different output directory.",
)
# Create client and get download URL
client = create_rest_client()
url_response = client.get_checkpoint_archive_url_from_tinker_path(checkpoint_path).result()
# Use a temporary directory for the archive
with tempfile.TemporaryDirectory() as temp_dir:
archive_path = Path(temp_dir) / f"{checkpoint_id}.tar"
extract_dir = target_path
# Download the archive with progress bar
# Create client and get download URL
client = create_rest_client()
url_response = client.get_checkpoint_archive_url_from_tinker_path(checkpoint_path).result()
total_size = _download_checkpoint_archive(
url_response.url,
archive_path=archive_path,
show_progress=True,
format=format,
)
# Extract the checkpoint
try:
# Open the URL connection
with urllib.request.urlopen(url_response.url, timeout=30) as response:
# Get total file size from headers
total_size = int(response.headers.get("Content-Length", 0))
# Download with progress bar
if format != "json":
with click.progressbar(
length=total_size,
label="Downloading archive",
show_percent=True,
show_pos=True,
show_eta=True,
) as bar:
with open(archive_path, "wb") as f:
while True:
chunk = response.read(8192)
if not chunk:
break
f.write(chunk)
bar.update(len(chunk))
else:
# Silent download for JSON output
with open(archive_path, "wb") as f:
while True:
chunk = response.read(8192)
if not chunk:
break
f.write(chunk)
# Extract the checkpoint
try:
# Create extraction directory
extract_dir.mkdir(parents=True, exist_ok=True)
# Extract the tar archive
with tarfile.open(archive_path, "r") as tar:
# Get list of members for progress tracking
members = tar.getmembers()
if format != "json":
with click.progressbar(
members,
label="Extracting archive ",
show_percent=True,
show_pos=True,
) as bar:
for member in bar:
tar.extract(member, path=extract_dir)
else:
# Extract all at once for few files
tar.extractall(path=extract_dir)
destination = str(extract_dir)
# Delete archive after successful extraction
if archive_path.exists():
archive_path.unlink()
except tarfile.TarError as e:
raise TinkerCliError(
f"Failed to extract archive: {e}",
"The downloaded file may be corrupted. Try downloading again.",
)
# Create output object
output_obj = CheckpointDownloadOutput(
checkpoint_path=checkpoint_path,
file_size_bytes=total_size if total_size > 0 else None,
destination=destination,
)
# Print in requested format
output_obj.print(format=format)
except urllib.error.URLError as e:
extract_dir.mkdir(parents=True, exist_ok=True)
_safe_extract_tar(archive_path, extract_dir, show_progress=True, format=format)
destination = str(extract_dir)
if archive_path.exists():
archive_path.unlink()
except Exception as e:
raise TinkerCliError(
f"Failed to download checkpoint: {e}",
"Please check your network connection and try again.",
)
except IOError as e:
raise TinkerCliError(
f"Failed to save checkpoint: {e}",
f"Please check that you have write permissions to {output_dir}",
f"Failed to extract archive: {e}",
"The downloaded file may be corrupted. Try downloading again.",
)
output_obj = CheckpointDownloadOutput(
checkpoint_path=checkpoint_path,
file_size_bytes=total_size if total_size > 0 else None,
destination=destination,
)
output_obj.print(format=format)
@cli.command(name="push-hf")
@click.argument("checkpoint_path")
@click.option(
    "--repo",
    "-r",
    "repo_id",
    type=str,
    default=None,
    help="Hugging Face repo ID (e.g., username/my-lora-adapter). If omitted, derive from run.",
)
@click.option(
    "--public",
    is_flag=True,
    help="Create or upload to a public repo (default: private).",
)
@click.option(
    "--revision",
    type=str,
    default=None,
    help="Target branch/revision to upload to (optional).",
)
@click.option(
    "--commit-message",
    type=str,
    default=None,
    help="Commit message for the upload (optional).",
)
@click.option(
    "--create-pr",
    is_flag=True,
    help="Create a pull request instead of pushing to the main branch.",
)
@click.option(
    "--allow-pattern",
    "allow_patterns",
    multiple=True,
    help="Only upload files matching this pattern (can be repeated).",
)
@click.option(
    "--ignore-pattern",
    "ignore_patterns",
    multiple=True,
    help="Skip files matching this pattern (can be repeated).",
)
@click.option(
    "--no-model-card",
    is_flag=True,
    help="Do not create a README.md model card if one is missing.",
)
@click.pass_obj
@handle_api_errors
def push_hf(
    cli_context: CLIContext,
    checkpoint_path: str,
    repo_id: str | None,
    public: bool,
    revision: str | None,
    commit_message: str | None,
    create_pr: bool,
    allow_patterns: tuple[str, ...],
    ignore_patterns: tuple[str, ...],
    no_model_card: bool,
) -> None:
    """Upload a checkpoint to the Hugging Face Hub as a PEFT adapter.

    CHECKPOINT_PATH must be a tinker path (e.g., tinker://run-id/sampler_weights/0001).
    """
    # Reject anything that is not a tinker:// path before creating a client.
    if not checkpoint_path.startswith("tinker://"):
        raise TinkerCliError(
            f"Invalid checkpoint path: {checkpoint_path}",
            "Checkpoint path must be in the format: tinker://run-id/sampler_weights/0001",
        )

    # Empty option tuples become None so the exporter sees "no filter given".
    uploaded_repo_id = _export_checkpoint_to_hub(
        create_rest_client(),
        checkpoint_path,
        repo_id,
        private=not public,
        revision=revision,
        commit_message=commit_message,
        create_pr=create_pr,
        exist_ok=True,
        allow_patterns=list(allow_patterns) or None,
        ignore_patterns=list(ignore_patterns) or None,
        add_model_card=not no_model_card,
    )

    # Report the upload in the format selected on the CLI context.
    CheckpointHubUploadOutput(
        checkpoint_path=checkpoint_path,
        repo_id=uploaded_repo_id,
        revision=revision,
        public=public,
    ).print(format=cli_context.format)