mirror of
https://github.com/thinking-machines-lab/tinker.git
synced 2026-04-19 12:58:01 +00:00
CLI to export checkpoints to HF (#16)
Adds support for uploading checkpoints to HF.
This commit is contained in:
parent
ca40e08bb4
commit
55114e8c45
2 changed files with 474 additions and 91 deletions
|
|
@ -62,6 +62,7 @@ class LazyGroup(click.Group):
|
|||
- `tinker run info <run-id>` - Show details of a specific run
|
||||
- `tinker checkpoint list` - List all checkpoints
|
||||
- `tinker checkpoint info <checkpoint-id>` - Show checkpoint details
|
||||
- `tinker checkpoint push-hf <checkpoint-path>` - Upload a checkpoint to Hugging Face Hub
|
||||
|
||||
### 4. Output System with Inheritance
|
||||
|
||||
|
|
@ -241,6 +242,9 @@ tinker checkpoint list run-abc123
|
|||
# Show checkpoint details
|
||||
tinker checkpoint info ckpt-xyz789
|
||||
|
||||
# Upload checkpoint to Hugging Face Hub
|
||||
tinker checkpoint push-hf tinker://run-abc123/sampler_weights/000040 --repo username/my-lora-adapter
|
||||
|
||||
# JSON output
|
||||
tinker --format json run list
|
||||
tinker --format json checkpoint list
|
||||
|
|
|
|||
|
|
@ -239,6 +239,50 @@ class CheckpointDownloadOutput(OutputBase):
|
|||
return rows
|
||||
|
||||
|
||||
class CheckpointHubUploadOutput(OutputBase):
|
||||
"""Output for 'tinker checkpoint push-hf' command."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
checkpoint_path: str,
|
||||
repo_id: str,
|
||||
revision: str | None = None,
|
||||
public: bool | None = None,
|
||||
):
|
||||
self.checkpoint_path = checkpoint_path
|
||||
self.repo_id = repo_id
|
||||
self.revision = revision
|
||||
self.public = public
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
result = {
|
||||
"checkpoint_path": self.checkpoint_path,
|
||||
"repo_id": self.repo_id,
|
||||
}
|
||||
if self.revision is not None:
|
||||
result["revision"] = self.revision
|
||||
if self.public is not None:
|
||||
result["public"] = self.public
|
||||
return result
|
||||
|
||||
def get_title(self) -> str | None:
|
||||
return f"Checkpoint Hub Upload: {self.checkpoint_path}"
|
||||
|
||||
def get_table_columns(self) -> List[str]:
|
||||
return ["Property", "Value"]
|
||||
|
||||
def get_table_rows(self) -> List[List[str]]:
|
||||
rows = [
|
||||
["Checkpoint Path", self.checkpoint_path],
|
||||
["Repo ID", self.repo_id],
|
||||
]
|
||||
if self.revision is not None:
|
||||
rows.append(["Revision", self.revision])
|
||||
if self.public is not None:
|
||||
rows.append(["Public", format_bool(self.public)])
|
||||
return rows
|
||||
|
||||
|
||||
def get_checkpoint_from_path(client: "RestClient", checkpoint_path: str) -> "Checkpoint":
|
||||
"""Get checkpoint details from a tinker path.
|
||||
|
||||
|
|
@ -280,6 +324,307 @@ def get_checkpoint_from_path(client: "RestClient", checkpoint_path: str) -> "Che
|
|||
raise TinkerCliError(f"Failed to retrieve checkpoint: {e}")
|
||||
|
||||
|
||||
def _export_checkpoint_to_hub(
|
||||
client: "RestClient",
|
||||
tinker_path: str,
|
||||
repo_id: str | None,
|
||||
*,
|
||||
private: bool,
|
||||
revision: str | None,
|
||||
commit_message: str | None,
|
||||
create_pr: bool,
|
||||
exist_ok: bool,
|
||||
allow_patterns: list[str] | None,
|
||||
ignore_patterns: list[str] | None,
|
||||
add_model_card: bool,
|
||||
) -> str:
|
||||
# Lazy imports to keep CLI startup fast
|
||||
try:
|
||||
from huggingface_hub import HfApi, hf_hub_download
|
||||
except ImportError as exc:
|
||||
raise TinkerCliError(
|
||||
"huggingface_hub is required for this command.",
|
||||
"Install it with: pip install huggingface_hub, then run: hf auth login",
|
||||
) from exc
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from tinker import ParsedCheckpointTinkerPath
|
||||
|
||||
# Validate tinker path
|
||||
parsed_tinker_path = ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
|
||||
|
||||
api = HfApi()
|
||||
try:
|
||||
api.whoami()
|
||||
except Exception as exc:
|
||||
raise TinkerCliError("Not logged in", "Run: hf auth login") from exc
|
||||
|
||||
def _sanitize_repo_name(value: str) -> str:
|
||||
safe_chars = []
|
||||
for ch in value:
|
||||
if ch.isalnum() or ch in {"-", "_", "."}:
|
||||
safe_chars.append(ch)
|
||||
else:
|
||||
safe_chars.append("-")
|
||||
name = "".join(safe_chars)
|
||||
while "--" in name:
|
||||
name = name.replace("--", "-")
|
||||
return name.strip("-_ .")
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_root = Path(temp_dir)
|
||||
archive_path = temp_root / "checkpoint.tar"
|
||||
extract_dir = temp_root / "extract"
|
||||
extract_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
url_response = client.get_checkpoint_archive_url_from_tinker_path(tinker_path).result()
|
||||
_download_checkpoint_archive(
|
||||
url_response.url,
|
||||
archive_path=archive_path,
|
||||
show_progress=False,
|
||||
format="json",
|
||||
)
|
||||
_safe_extract_tar(archive_path, extract_dir, show_progress=False, format="json")
|
||||
|
||||
adapter_config = extract_dir / "adapter_config.json"
|
||||
adapter_safetensors = extract_dir / "adapter_model.safetensors"
|
||||
adapter_bin = extract_dir / "adapter_model.bin"
|
||||
checkpoint_complete = extract_dir / "checkpoint_complete"
|
||||
if not adapter_config.exists() or not (
|
||||
adapter_safetensors.exists() or adapter_bin.exists()
|
||||
):
|
||||
raise TinkerCliError(
|
||||
"Checkpoint archive does not contain a PEFT adapter.",
|
||||
"Expected adapter_config.json and adapter_model.safetensors (or adapter_model.bin).",
|
||||
)
|
||||
if not checkpoint_complete.exists():
|
||||
raise TinkerCliError(
|
||||
"Checkpoint archive is missing 'checkpoint_complete'.",
|
||||
"The adapter files may be incomplete.",
|
||||
)
|
||||
|
||||
base_model = "unknown"
|
||||
lora_rank = None
|
||||
train_mlp = None
|
||||
train_attn = None
|
||||
train_unembed = None
|
||||
try:
|
||||
weights_info = client.get_weights_info_by_tinker_path(tinker_path).result()
|
||||
base_model = weights_info.base_model
|
||||
lora_rank = weights_info.lora_rank
|
||||
train_mlp = weights_info.train_mlp
|
||||
train_attn = weights_info.train_attn
|
||||
train_unembed = weights_info.train_unembed
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
config_data = json.loads(adapter_config.read_text(encoding="utf-8"))
|
||||
if not isinstance(config_data.get("base_model_name_or_path"), str):
|
||||
config_data["base_model_name_or_path"] = base_model
|
||||
adapter_config.write_text(
|
||||
json.dumps(config_data, indent=2, sort_keys=True) + "\n", encoding="utf-8"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if repo_id is None:
|
||||
base_short = base_model.split("/")[-1] if base_model != "unknown" else "adapter"
|
||||
derived = f"tinker-{base_short}-{parsed_tinker_path.training_run_id}"
|
||||
repo_id = _sanitize_repo_name(derived)
|
||||
if revision is None:
|
||||
revision = _sanitize_repo_name(parsed_tinker_path.checkpoint_id.replace("/", "-"))
|
||||
|
||||
readme_path = extract_dir / "README.md"
|
||||
if add_model_card and not readme_path.exists():
|
||||
tags: List[str] = ["tinker", "peft", "lora"]
|
||||
if base_model != "unknown":
|
||||
tags.append(f"base_model:adapter:{base_model}")
|
||||
model_card = [
|
||||
"---",
|
||||
f"base_model: {base_model}",
|
||||
"library_name: peft",
|
||||
"tags:",
|
||||
]
|
||||
for tag in tags:
|
||||
model_card.append(f"- {tag}")
|
||||
model_card.append(f"tinker_path: {tinker_path}")
|
||||
model_card.extend(
|
||||
[
|
||||
"---",
|
||||
"",
|
||||
"# Tinker LoRA Adapter",
|
||||
"",
|
||||
"This repository contains a LoRA adapter exported from Tinker.",
|
||||
"",
|
||||
"## Usage",
|
||||
"",
|
||||
"```python",
|
||||
"from transformers import AutoModelForCausalLM",
|
||||
"",
|
||||
f'adapter_id = "{repo_id}"',
|
||||
f'base_model = "{base_model}"',
|
||||
"",
|
||||
'model = AutoModelForCausalLM.from_pretrained(adapter_id, device_map="auto")',
|
||||
"```",
|
||||
"",
|
||||
"## Source",
|
||||
"",
|
||||
"```",
|
||||
f"{tinker_path}",
|
||||
"```",
|
||||
"",
|
||||
"## Details",
|
||||
"",
|
||||
f"- Base model: {base_model}",
|
||||
]
|
||||
)
|
||||
if lora_rank is not None:
|
||||
model_card.append(f"- LoRA rank: {lora_rank}")
|
||||
if train_mlp is not None or train_attn is not None or train_unembed is not None:
|
||||
model_card.append(
|
||||
f"- Trained modules: attn={train_attn}, mlp={train_mlp}, unembed={train_unembed}"
|
||||
)
|
||||
model_card.append("")
|
||||
readme_path.write_text("\n".join(model_card), encoding="utf-8")
|
||||
|
||||
api.create_repo(repo_id=repo_id, private=private, exist_ok=exist_ok)
|
||||
|
||||
def _readme_tinker_path() -> str | None:
|
||||
try:
|
||||
readme_file = hf_hub_download(
|
||||
repo_id=repo_id,
|
||||
filename="README.md",
|
||||
revision=revision,
|
||||
token=None,
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
try:
|
||||
text = Path(readme_file).read_text(encoding="utf-8", errors="ignore")
|
||||
except Exception:
|
||||
return None
|
||||
match = re.search(r"tinker://[^\s`]+", text)
|
||||
return match.group(0) if match else None
|
||||
|
||||
existing_tinker_path = _readme_tinker_path()
|
||||
if existing_tinker_path and existing_tinker_path != tinker_path:
|
||||
raise TinkerCliError(
|
||||
"Repo ID appears to contain a different Tinker checkpoint.",
|
||||
f"Found {existing_tinker_path}, expected {tinker_path}.",
|
||||
)
|
||||
|
||||
if allow_patterns is None:
|
||||
ignore_patterns = list(ignore_patterns) if ignore_patterns else []
|
||||
if "checkpoint_complete" not in ignore_patterns:
|
||||
ignore_patterns.append("checkpoint_complete")
|
||||
|
||||
api.upload_folder(
|
||||
folder_path=os.fspath(extract_dir),
|
||||
repo_id=repo_id,
|
||||
path_in_repo="",
|
||||
revision=revision,
|
||||
commit_message=commit_message,
|
||||
create_pr=create_pr,
|
||||
allow_patterns=list(allow_patterns) if allow_patterns else None,
|
||||
ignore_patterns=list(ignore_patterns) if ignore_patterns else None,
|
||||
)
|
||||
|
||||
return repo_id
|
||||
|
||||
|
||||
def _safe_extract_tar(
|
||||
archive_path,
|
||||
extract_dir,
|
||||
*,
|
||||
show_progress: bool,
|
||||
format: str,
|
||||
) -> None:
|
||||
import tarfile
|
||||
|
||||
base = extract_dir.resolve()
|
||||
with tarfile.open(archive_path, "r") as tar:
|
||||
members = tar.getmembers()
|
||||
for member in members:
|
||||
if member.issym() or member.islnk():
|
||||
raise TinkerCliError(
|
||||
"Unsafe symlink or hardlink in tar archive",
|
||||
"Archive may be corrupted or malicious.",
|
||||
)
|
||||
member_path = (extract_dir / member.name).resolve()
|
||||
if not str(member_path).startswith(str(base)):
|
||||
raise TinkerCliError(
|
||||
"Unsafe path in tar archive",
|
||||
"Archive may be corrupted or malicious.",
|
||||
)
|
||||
if show_progress and format != "json":
|
||||
with click.progressbar(
|
||||
members,
|
||||
label="Extracting archive ",
|
||||
show_percent=True,
|
||||
show_pos=True,
|
||||
) as bar:
|
||||
for member in bar:
|
||||
tar.extract(member, path=extract_dir)
|
||||
else:
|
||||
tar.extractall(path=extract_dir)
|
||||
|
||||
|
||||
def _download_checkpoint_archive(
|
||||
url: str,
|
||||
*,
|
||||
archive_path,
|
||||
show_progress: bool,
|
||||
format: str,
|
||||
) -> int:
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
|
||||
try:
|
||||
with urllib.request.urlopen(url, timeout=60) as response:
|
||||
total_size = int(response.headers.get("Content-Length", 0))
|
||||
|
||||
if show_progress and format != "json":
|
||||
with click.progressbar(
|
||||
length=total_size,
|
||||
label="Downloading archive",
|
||||
show_percent=True,
|
||||
show_pos=True,
|
||||
show_eta=True,
|
||||
) as bar:
|
||||
with open(archive_path, "wb") as f:
|
||||
while True:
|
||||
chunk = response.read(8192)
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
bar.update(len(chunk))
|
||||
else:
|
||||
with open(archive_path, "wb") as f:
|
||||
while True:
|
||||
chunk = response.read(8192)
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
except urllib.error.URLError as e:
|
||||
raise TinkerCliError(
|
||||
f"Failed to download checkpoint: {e}",
|
||||
"Please check your network connection and try again.",
|
||||
) from e
|
||||
except IOError as e:
|
||||
raise TinkerCliError(
|
||||
f"Failed to save checkpoint: {e}",
|
||||
f"Please check that you have write permissions to {archive_path.parent}",
|
||||
) from e
|
||||
|
||||
return total_size
|
||||
|
||||
|
||||
# Click command group for checkpoint commands
|
||||
@click.group()
|
||||
def cli():
|
||||
|
|
@ -545,10 +890,7 @@ def download(
|
|||
"""
|
||||
# Lazy imports to maintain fast CLI startup
|
||||
import shutil
|
||||
import tarfile
|
||||
import tempfile
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
from pathlib import Path
|
||||
|
||||
# Validate it's a tinker path
|
||||
|
|
@ -584,99 +926,136 @@ def download(
|
|||
"Use --force to overwrite or choose a different output directory.",
|
||||
)
|
||||
|
||||
# Create client and get download URL
|
||||
client = create_rest_client()
|
||||
url_response = client.get_checkpoint_archive_url_from_tinker_path(checkpoint_path).result()
|
||||
|
||||
# Use a temporary directory for the archive
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
archive_path = Path(temp_dir) / f"{checkpoint_id}.tar"
|
||||
extract_dir = target_path
|
||||
|
||||
# Download the archive with progress bar
|
||||
# Create client and get download URL
|
||||
client = create_rest_client()
|
||||
url_response = client.get_checkpoint_archive_url_from_tinker_path(checkpoint_path).result()
|
||||
|
||||
total_size = _download_checkpoint_archive(
|
||||
url_response.url,
|
||||
archive_path=archive_path,
|
||||
show_progress=True,
|
||||
format=format,
|
||||
)
|
||||
|
||||
# Extract the checkpoint
|
||||
try:
|
||||
# Open the URL connection
|
||||
with urllib.request.urlopen(url_response.url, timeout=30) as response:
|
||||
# Get total file size from headers
|
||||
total_size = int(response.headers.get("Content-Length", 0))
|
||||
|
||||
# Download with progress bar
|
||||
if format != "json":
|
||||
with click.progressbar(
|
||||
length=total_size,
|
||||
label="Downloading archive",
|
||||
show_percent=True,
|
||||
show_pos=True,
|
||||
show_eta=True,
|
||||
) as bar:
|
||||
with open(archive_path, "wb") as f:
|
||||
while True:
|
||||
chunk = response.read(8192)
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
bar.update(len(chunk))
|
||||
else:
|
||||
# Silent download for JSON output
|
||||
with open(archive_path, "wb") as f:
|
||||
while True:
|
||||
chunk = response.read(8192)
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
|
||||
# Extract the checkpoint
|
||||
try:
|
||||
# Create extraction directory
|
||||
extract_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Extract the tar archive
|
||||
with tarfile.open(archive_path, "r") as tar:
|
||||
# Get list of members for progress tracking
|
||||
members = tar.getmembers()
|
||||
|
||||
if format != "json":
|
||||
with click.progressbar(
|
||||
members,
|
||||
label="Extracting archive ",
|
||||
show_percent=True,
|
||||
show_pos=True,
|
||||
) as bar:
|
||||
for member in bar:
|
||||
tar.extract(member, path=extract_dir)
|
||||
else:
|
||||
# Extract all at once for few files
|
||||
tar.extractall(path=extract_dir)
|
||||
|
||||
destination = str(extract_dir)
|
||||
|
||||
# Delete archive after successful extraction
|
||||
if archive_path.exists():
|
||||
archive_path.unlink()
|
||||
|
||||
except tarfile.TarError as e:
|
||||
raise TinkerCliError(
|
||||
f"Failed to extract archive: {e}",
|
||||
"The downloaded file may be corrupted. Try downloading again.",
|
||||
)
|
||||
|
||||
# Create output object
|
||||
output_obj = CheckpointDownloadOutput(
|
||||
checkpoint_path=checkpoint_path,
|
||||
file_size_bytes=total_size if total_size > 0 else None,
|
||||
destination=destination,
|
||||
)
|
||||
|
||||
# Print in requested format
|
||||
output_obj.print(format=format)
|
||||
|
||||
except urllib.error.URLError as e:
|
||||
extract_dir.mkdir(parents=True, exist_ok=True)
|
||||
_safe_extract_tar(archive_path, extract_dir, show_progress=True, format=format)
|
||||
destination = str(extract_dir)
|
||||
if archive_path.exists():
|
||||
archive_path.unlink()
|
||||
except Exception as e:
|
||||
raise TinkerCliError(
|
||||
f"Failed to download checkpoint: {e}",
|
||||
"Please check your network connection and try again.",
|
||||
)
|
||||
except IOError as e:
|
||||
raise TinkerCliError(
|
||||
f"Failed to save checkpoint: {e}",
|
||||
f"Please check that you have write permissions to {output_dir}",
|
||||
f"Failed to extract archive: {e}",
|
||||
"The downloaded file may be corrupted. Try downloading again.",
|
||||
)
|
||||
|
||||
output_obj = CheckpointDownloadOutput(
|
||||
checkpoint_path=checkpoint_path,
|
||||
file_size_bytes=total_size if total_size > 0 else None,
|
||||
destination=destination,
|
||||
)
|
||||
output_obj.print(format=format)
|
||||
|
||||
|
||||
@cli.command(name="push-hf")
|
||||
@click.argument("checkpoint_path")
|
||||
@click.option(
|
||||
"--repo",
|
||||
"-r",
|
||||
"repo_id",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Hugging Face repo ID (e.g., username/my-lora-adapter). If omitted, derive from run.",
|
||||
)
|
||||
@click.option(
|
||||
"--public",
|
||||
is_flag=True,
|
||||
help="Create or upload to a public repo (default: private).",
|
||||
)
|
||||
@click.option(
|
||||
"--revision",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Target branch/revision to upload to (optional).",
|
||||
)
|
||||
@click.option(
|
||||
"--commit-message",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Commit message for the upload (optional).",
|
||||
)
|
||||
@click.option(
|
||||
"--create-pr",
|
||||
is_flag=True,
|
||||
help="Create a pull request instead of pushing to the main branch.",
|
||||
)
|
||||
@click.option(
|
||||
"--allow-pattern",
|
||||
"allow_patterns",
|
||||
multiple=True,
|
||||
help="Only upload files matching this pattern (can be repeated).",
|
||||
)
|
||||
@click.option(
|
||||
"--ignore-pattern",
|
||||
"ignore_patterns",
|
||||
multiple=True,
|
||||
help="Skip files matching this pattern (can be repeated).",
|
||||
)
|
||||
@click.option(
|
||||
"--no-model-card",
|
||||
is_flag=True,
|
||||
help="Do not create a README.md model card if one is missing.",
|
||||
)
|
||||
@click.pass_obj
|
||||
@handle_api_errors
|
||||
def push_hf(
|
||||
cli_context: CLIContext,
|
||||
checkpoint_path: str,
|
||||
repo_id: str | None,
|
||||
public: bool,
|
||||
revision: str | None,
|
||||
commit_message: str | None,
|
||||
create_pr: bool,
|
||||
allow_patterns: tuple[str, ...],
|
||||
ignore_patterns: tuple[str, ...],
|
||||
no_model_card: bool,
|
||||
) -> None:
|
||||
"""Upload a checkpoint to the Hugging Face Hub as a PEFT adapter.
|
||||
|
||||
CHECKPOINT_PATH must be a tinker path (e.g., tinker://run-id/sampler_weights/0001).
|
||||
"""
|
||||
# Validate it's a tinker path
|
||||
if not checkpoint_path.startswith("tinker://"):
|
||||
raise TinkerCliError(
|
||||
f"Invalid checkpoint path: {checkpoint_path}",
|
||||
"Checkpoint path must be in the format: tinker://run-id/sampler_weights/0001",
|
||||
)
|
||||
|
||||
client = create_rest_client()
|
||||
repo_id_out = _export_checkpoint_to_hub(
|
||||
client,
|
||||
checkpoint_path,
|
||||
repo_id,
|
||||
private=not public,
|
||||
revision=revision,
|
||||
commit_message=commit_message,
|
||||
create_pr=create_pr,
|
||||
exist_ok=True,
|
||||
allow_patterns=list(allow_patterns) if allow_patterns else None,
|
||||
ignore_patterns=list(ignore_patterns) if ignore_patterns else None,
|
||||
add_model_card=not no_model_card,
|
||||
)
|
||||
|
||||
output_obj = CheckpointHubUploadOutput(
|
||||
checkpoint_path=checkpoint_path,
|
||||
repo_id=repo_id_out,
|
||||
revision=revision,
|
||||
public=public,
|
||||
)
|
||||
output_obj.print(format=cli_context.format)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue