mirror of
https://github.com/thinking-machines-lab/tinker.git
synced 2026-05-03 17:53:19 +00:00
CLI to export checkpoints to HF (#16)
Adds support for uploading checkpoints to HF.
This commit is contained in:
parent
ca40e08bb4
commit
55114e8c45
2 changed files with 474 additions and 91 deletions
|
|
@ -62,6 +62,7 @@ class LazyGroup(click.Group):
|
||||||
- `tinker run info <run-id>` - Show details of a specific run
|
- `tinker run info <run-id>` - Show details of a specific run
|
||||||
- `tinker checkpoint list` - List all checkpoints
|
- `tinker checkpoint list` - List all checkpoints
|
||||||
- `tinker checkpoint info <checkpoint-id>` - Show checkpoint details
|
- `tinker checkpoint info <checkpoint-id>` - Show checkpoint details
|
||||||
|
- `tinker checkpoint push-hf <checkpoint-path>` - Upload a checkpoint to Hugging Face Hub
|
||||||
|
|
||||||
### 4. Output System with Inheritance
|
### 4. Output System with Inheritance
|
||||||
|
|
||||||
|
|
@ -241,6 +242,9 @@ tinker checkpoint list run-abc123
|
||||||
# Show checkpoint details
|
# Show checkpoint details
|
||||||
tinker checkpoint info ckpt-xyz789
|
tinker checkpoint info ckpt-xyz789
|
||||||
|
|
||||||
|
# Upload checkpoint to Hugging Face Hub
|
||||||
|
tinker checkpoint push-hf tinker://run-abc123/sampler_weights/000040 --repo username/my-lora-adapter
|
||||||
|
|
||||||
# JSON output
|
# JSON output
|
||||||
tinker --format json run list
|
tinker --format json run list
|
||||||
tinker --format json checkpoint list
|
tinker --format json checkpoint list
|
||||||
|
|
|
||||||
|
|
@ -239,6 +239,50 @@ class CheckpointDownloadOutput(OutputBase):
|
||||||
return rows
|
return rows
|
||||||
|
|
||||||
|
|
||||||
|
class CheckpointHubUploadOutput(OutputBase):
    """Result record rendered after 'tinker checkpoint push-hf' finishes.

    Carries the source checkpoint path and the destination repo ID, plus an
    optional revision and visibility flag. Optional fields that are ``None``
    are left out of both the dict form and the table form.
    """

    def __init__(
        self,
        checkpoint_path: str,
        repo_id: str,
        revision: str | None = None,
        public: bool | None = None,
    ):
        self.checkpoint_path = checkpoint_path
        self.repo_id = repo_id
        self.revision = revision
        self.public = public

    def to_dict(self) -> Dict[str, Any]:
        """Return the upload summary as a plain dict, skipping unset optionals."""
        payload: Dict[str, Any] = {
            "checkpoint_path": self.checkpoint_path,
            "repo_id": self.repo_id,
        }
        # Insertion order matters for stable output: revision first, then public.
        optional_fields = (("revision", self.revision), ("public", self.public))
        for key, value in optional_fields:
            if value is not None:
                payload[key] = value
        return payload

    def get_title(self) -> str | None:
        """Return the heading shown above the table."""
        return f"Checkpoint Hub Upload: {self.checkpoint_path}"

    def get_table_columns(self) -> List[str]:
        """Return the two column headers of the property/value table."""
        return ["Property", "Value"]

    def get_table_rows(self) -> List[List[str]]:
        """Return one [property, value] row per populated field."""
        table = [
            ["Checkpoint Path", self.checkpoint_path],
            ["Repo ID", self.repo_id],
        ]
        if self.revision is not None:
            table.append(["Revision", self.revision])
        if self.public is not None:
            table.append(["Public", format_bool(self.public)])
        return table
|
||||||
|
|
||||||
|
|
||||||
def get_checkpoint_from_path(client: "RestClient", checkpoint_path: str) -> "Checkpoint":
|
def get_checkpoint_from_path(client: "RestClient", checkpoint_path: str) -> "Checkpoint":
|
||||||
"""Get checkpoint details from a tinker path.
|
"""Get checkpoint details from a tinker path.
|
||||||
|
|
||||||
|
|
@ -280,6 +324,307 @@ def get_checkpoint_from_path(client: "RestClient", checkpoint_path: str) -> "Che
|
||||||
raise TinkerCliError(f"Failed to retrieve checkpoint: {e}")
|
raise TinkerCliError(f"Failed to retrieve checkpoint: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def _render_hub_model_card(
    *,
    repo_id: str,
    tinker_path: str,
    base_model: str,
    lora_rank,
    train_mlp,
    train_attn,
    train_unembed,
) -> str:
    """Build the README.md model card text for an exported LoRA adapter."""
    tags = ["tinker", "peft", "lora"]
    if base_model != "unknown":
        tags.append(f"base_model:adapter:{base_model}")

    lines = [
        # YAML front matter read by the Hub.
        "---",
        f"base_model: {base_model}",
        "library_name: peft",
        "tags:",
        *(f"- {tag}" for tag in tags),
        f"tinker_path: {tinker_path}",
        "---",
        "",
        "# Tinker LoRA Adapter",
        "",
        "This repository contains a LoRA adapter exported from Tinker.",
        "",
        "## Usage",
        "",
        "```python",
        "from transformers import AutoModelForCausalLM",
        "",
        f'adapter_id = "{repo_id}"',
        f'base_model = "{base_model}"',
        "",
        'model = AutoModelForCausalLM.from_pretrained(adapter_id, device_map="auto")',
        "```",
        "",
        "## Source",
        "",
        "```",
        f"{tinker_path}",
        "```",
        "",
        "## Details",
        "",
        f"- Base model: {base_model}",
    ]
    if lora_rank is not None:
        lines.append(f"- LoRA rank: {lora_rank}")
    if train_mlp is not None or train_attn is not None or train_unembed is not None:
        lines.append(
            f"- Trained modules: attn={train_attn}, mlp={train_mlp}, unembed={train_unembed}"
        )
    lines.append("")
    return "\n".join(lines)


def _export_checkpoint_to_hub(
    client: "RestClient",
    tinker_path: str,
    repo_id: str | None,
    *,
    private: bool,
    revision: str | None,
    commit_message: str | None,
    create_pr: bool,
    exist_ok: bool,
    allow_patterns: list[str] | None,
    ignore_patterns: list[str] | None,
    add_model_card: bool,
) -> str:
    """Download a checkpoint archive and upload it to the Hugging Face Hub.

    The archive is fetched into a temporary directory, extracted, checked for
    PEFT adapter files, optionally given a generated README.md model card, and
    then uploaded with ``HfApi.upload_folder``.

    Args:
        client: REST client used to obtain the archive URL and weights info.
        tinker_path: ``tinker://`` path of the checkpoint to export.
        repo_id: Destination repo ID; derived from the base model name and the
            training run ID when ``None``.
        private: Passed to ``create_repo`` as the repo visibility.
        revision: Target branch; derived from the checkpoint ID when ``None``.
        commit_message: Commit message forwarded to ``upload_folder``.
        create_pr: Forwarded to ``upload_folder``.
        exist_ok: Forwarded to ``create_repo``.
        allow_patterns: Optional allow-list of file patterns to upload.
        ignore_patterns: Optional ignore-list of file patterns. When no
            allow-list is given, ``checkpoint_complete`` is always ignored.
        add_model_card: Generate README.md when the archive has none.

    Returns:
        The repo ID that was uploaded to.

    Raises:
        TinkerCliError: If ``huggingface_hub`` is missing, the user is not
            logged in, the archive lacks adapter files or the
            ``checkpoint_complete`` marker, or the target repo's README
            references a different tinker path.
    """
    # Lazy imports to keep CLI startup fast
    try:
        from huggingface_hub import HfApi, hf_hub_download
    except ImportError as exc:
        raise TinkerCliError(
            "huggingface_hub is required for this command.",
            "Install it with: pip install huggingface_hub, then run: hf auth login",
        ) from exc

    import json
    import os
    import re
    import tempfile
    from pathlib import Path

    from tinker import ParsedCheckpointTinkerPath

    # Validate tinker path
    parsed_tinker_path = ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)

    api = HfApi()
    try:
        api.whoami()
    except Exception as exc:
        raise TinkerCliError("Not logged in", "Run: hf auth login") from exc

    def _sanitize_repo_name(value: str) -> str:
        """Replace characters outside [alnum - _ .] with '-' and tidy the result."""
        name = "".join(ch if (ch.isalnum() or ch in "-_.") else "-" for ch in value)
        while "--" in name:
            name = name.replace("--", "-")
        return name.strip("-_ .")

    def _readme_tinker_path() -> str | None:
        """Return the first tinker:// path in the target repo's README, if any."""
        try:
            readme_file = hf_hub_download(
                repo_id=repo_id,
                filename="README.md",
                revision=revision,
                token=None,
            )
        except Exception:
            # No README, no such revision, or no access: nothing to compare against.
            return None
        try:
            text = Path(readme_file).read_text(encoding="utf-8", errors="ignore")
        except Exception:
            return None
        match = re.search(r"tinker://[^\s`]+", text)
        return match.group(0) if match else None

    with tempfile.TemporaryDirectory() as temp_dir:
        temp_root = Path(temp_dir)
        archive_path = temp_root / "checkpoint.tar"
        extract_dir = temp_root / "extract"
        extract_dir.mkdir(parents=True, exist_ok=True)

        url_response = client.get_checkpoint_archive_url_from_tinker_path(tinker_path).result()
        _download_checkpoint_archive(
            url_response.url,
            archive_path=archive_path,
            show_progress=False,
            format="json",
        )
        _safe_extract_tar(archive_path, extract_dir, show_progress=False, format="json")

        adapter_config = extract_dir / "adapter_config.json"
        adapter_safetensors = extract_dir / "adapter_model.safetensors"
        adapter_bin = extract_dir / "adapter_model.bin"
        checkpoint_complete = extract_dir / "checkpoint_complete"
        if not adapter_config.exists() or not (
            adapter_safetensors.exists() or adapter_bin.exists()
        ):
            raise TinkerCliError(
                "Checkpoint archive does not contain a PEFT adapter.",
                "Expected adapter_config.json and adapter_model.safetensors (or adapter_model.bin).",
            )
        if not checkpoint_complete.exists():
            raise TinkerCliError(
                "Checkpoint archive is missing 'checkpoint_complete'.",
                "The adapter files may be incomplete.",
            )

        base_model = "unknown"
        lora_rank = None
        train_mlp = None
        train_attn = None
        train_unembed = None
        try:
            weights_info = client.get_weights_info_by_tinker_path(tinker_path).result()
            base_model = weights_info.base_model
            lora_rank = weights_info.lora_rank
            train_mlp = weights_info.train_mlp
            train_attn = weights_info.train_attn
            train_unembed = weights_info.train_unembed
        except Exception:
            # Best effort: this metadata only enriches the model card and the
            # derived repo name, so the export continues with placeholders.
            pass

        try:
            config_data = json.loads(adapter_config.read_text(encoding="utf-8"))
            # Only fill in a real base model; writing the "unknown" placeholder
            # would make the adapter point at a model ID that does not exist.
            if base_model != "unknown" and not isinstance(
                config_data.get("base_model_name_or_path"), str
            ):
                config_data["base_model_name_or_path"] = base_model
                adapter_config.write_text(
                    json.dumps(config_data, indent=2, sort_keys=True) + "\n", encoding="utf-8"
                )
        except (OSError, ValueError, AttributeError):
            # Unreadable, malformed, or non-object JSON: upload the config untouched.
            pass

        if repo_id is None:
            base_short = base_model.split("/")[-1] if base_model != "unknown" else "adapter"
            derived = f"tinker-{base_short}-{parsed_tinker_path.training_run_id}"
            repo_id = _sanitize_repo_name(derived)
        if revision is None:
            revision = _sanitize_repo_name(parsed_tinker_path.checkpoint_id.replace("/", "-"))

        readme_path = extract_dir / "README.md"
        if add_model_card and not readme_path.exists():
            readme_path.write_text(
                _render_hub_model_card(
                    repo_id=repo_id,
                    tinker_path=tinker_path,
                    base_model=base_model,
                    lora_rank=lora_rank,
                    train_mlp=train_mlp,
                    train_attn=train_attn,
                    train_unembed=train_unembed,
                ),
                encoding="utf-8",
            )

        api.create_repo(repo_id=repo_id, private=private, exist_ok=exist_ok)

        # Refuse to overwrite a repo/revision that documents another checkpoint.
        existing_tinker_path = _readme_tinker_path()
        if existing_tinker_path and existing_tinker_path != tinker_path:
            raise TinkerCliError(
                "Repo ID appears to contain a different Tinker checkpoint.",
                f"Found {existing_tinker_path}, expected {tinker_path}.",
            )

        if allow_patterns is None:
            # The completion marker is internal bookkeeping, not part of the adapter.
            ignore_patterns = list(ignore_patterns) if ignore_patterns else []
            if "checkpoint_complete" not in ignore_patterns:
                ignore_patterns.append("checkpoint_complete")

        # Commits can only target an existing branch, and the revision is by
        # default a freshly derived name, so make sure the branch exists first.
        # "refs/pr/N" targets are PR refs rather than branches and are left alone.
        if revision and not revision.startswith("refs/pr"):
            api.create_branch(repo_id=repo_id, branch=revision, exist_ok=True)

        api.upload_folder(
            folder_path=os.fspath(extract_dir),
            repo_id=repo_id,
            path_in_repo="",
            revision=revision,
            commit_message=commit_message,
            create_pr=create_pr,
            allow_patterns=list(allow_patterns) if allow_patterns else None,
            ignore_patterns=list(ignore_patterns) if ignore_patterns else None,
        )

    return repo_id
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_extract_tar(
    archive_path,
    extract_dir,
    *,
    show_progress: bool,
    format: str,
) -> None:
    """Extract a tar archive into ``extract_dir`` after validating every member.

    All members are checked before anything is written, so a rejected archive
    leaves ``extract_dir`` untouched.

    Args:
        archive_path: Path of the tar archive to read.
        extract_dir: Destination directory (a ``pathlib.Path``).
        show_progress: Show a per-member progress bar (suppressed for JSON output).
        format: Output format of the CLI; ``"json"`` disables the progress bar.

    Raises:
        TinkerCliError: If the archive contains a symlink, hardlink, device or
            FIFO member, or a member whose path resolves outside ``extract_dir``.
    """
    import tarfile

    base = extract_dir.resolve()
    # Use the stdlib "data" extraction filter where this Python provides it
    # (feature check recommended by the tarfile docs) as defence in depth.
    extract_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}

    with tarfile.open(archive_path, "r") as tar:
        members = tar.getmembers()
        for member in members:
            if member.issym() or member.islnk():
                raise TinkerCliError(
                    "Unsafe symlink or hardlink in tar archive",
                    "Archive may be corrupted or malicious.",
                )
            if member.isdev():
                # Character/block devices and FIFOs have no place in a checkpoint.
                raise TinkerCliError(
                    "Unsafe special file in tar archive",
                    "Archive may be corrupted or malicious.",
                )
            member_path = (extract_dir / member.name).resolve()
            # Compare path components, not string prefixes: "/tmp/extract_evil"
            # starts with "/tmp/extract" but is not inside it.
            if member_path != base and base not in member_path.parents:
                raise TinkerCliError(
                    "Unsafe path in tar archive",
                    "Archive may be corrupted or malicious.",
                )

        if show_progress and format != "json":
            with click.progressbar(
                members,
                label="Extracting archive ",
                show_percent=True,
                show_pos=True,
            ) as bar:
                for member in bar:
                    tar.extract(member, path=extract_dir, **extract_kwargs)
        else:
            tar.extractall(path=extract_dir, **extract_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def _download_checkpoint_archive(
    url: str,
    *,
    archive_path,
    show_progress: bool,
    format: str,
) -> int:
    """Stream the checkpoint archive at ``url`` into ``archive_path``.

    Args:
        url: Download URL of the archive.
        archive_path: Destination file (a ``pathlib.Path``); overwritten if present.
        show_progress: Show a download progress bar (suppressed for JSON output).
        format: Output format of the CLI; ``"json"`` disables the progress bar.

    Returns:
        The size announced by the ``Content-Length`` header, or 0 when the
        header is absent.

    Raises:
        TinkerCliError: If the download fails or is cut short, or if the
            archive cannot be written to disk.
    """
    import contextlib
    import http.client
    import urllib.error
    import urllib.request

    chunk_size = 8192
    bytes_written = 0

    try:
        with urllib.request.urlopen(url, timeout=60) as response:
            total_size = int(response.headers.get("Content-Length", 0))

            if show_progress and format != "json":
                progress = click.progressbar(
                    length=total_size,
                    label="Downloading archive",
                    show_percent=True,
                    show_pos=True,
                    show_eta=True,
                )
            else:
                # Yields None, so the loop below can skip progress updates.
                progress = contextlib.nullcontext()

            with progress as bar, open(archive_path, "wb") as f:
                while True:
                    try:
                        chunk = response.read(chunk_size)
                    except (OSError, http.client.HTTPException) as exc:
                        # Timeouts and resets while reading are plain OSErrors;
                        # re-raise as URLError so they are reported as a network
                        # failure instead of a local write-permission problem.
                        raise urllib.error.URLError(exc) from exc
                    if not chunk:
                        break
                    f.write(chunk)
                    bytes_written += len(chunk)
                    if bar is not None:
                        bar.update(len(chunk))

        # A connection closed early just looks like end-of-stream to read(),
        # so compare against the announced size to catch truncated archives.
        if total_size and bytes_written != total_size:
            raise urllib.error.URLError(
                f"connection closed after {bytes_written} of {total_size} bytes"
            )
    except urllib.error.URLError as e:
        raise TinkerCliError(
            f"Failed to download checkpoint: {e}",
            "Please check your network connection and try again.",
        ) from e
    except IOError as e:
        raise TinkerCliError(
            f"Failed to save checkpoint: {e}",
            f"Please check that you have write permissions to {archive_path.parent}",
        ) from e

    return total_size
|
||||||
|
|
||||||
|
|
||||||
# Click command group for checkpoint commands
|
# Click command group for checkpoint commands
|
||||||
@click.group()
|
@click.group()
|
||||||
def cli():
|
def cli():
|
||||||
|
|
@ -545,10 +890,7 @@ def download(
|
||||||
"""
|
"""
|
||||||
# Lazy imports to maintain fast CLI startup
|
# Lazy imports to maintain fast CLI startup
|
||||||
import shutil
|
import shutil
|
||||||
import tarfile
|
|
||||||
import tempfile
|
import tempfile
|
||||||
import urllib.error
|
|
||||||
import urllib.request
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
# Validate it's a tinker path
|
# Validate it's a tinker path
|
||||||
|
|
@ -584,99 +926,136 @@ def download(
|
||||||
"Use --force to overwrite or choose a different output directory.",
|
"Use --force to overwrite or choose a different output directory.",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create client and get download URL
|
|
||||||
client = create_rest_client()
|
|
||||||
url_response = client.get_checkpoint_archive_url_from_tinker_path(checkpoint_path).result()
|
|
||||||
|
|
||||||
# Use a temporary directory for the archive
|
# Use a temporary directory for the archive
|
||||||
with tempfile.TemporaryDirectory() as temp_dir:
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
archive_path = Path(temp_dir) / f"{checkpoint_id}.tar"
|
archive_path = Path(temp_dir) / f"{checkpoint_id}.tar"
|
||||||
extract_dir = target_path
|
extract_dir = target_path
|
||||||
|
|
||||||
# Download the archive with progress bar
|
# Create client and get download URL
|
||||||
|
client = create_rest_client()
|
||||||
|
url_response = client.get_checkpoint_archive_url_from_tinker_path(checkpoint_path).result()
|
||||||
|
|
||||||
|
total_size = _download_checkpoint_archive(
|
||||||
|
url_response.url,
|
||||||
|
archive_path=archive_path,
|
||||||
|
show_progress=True,
|
||||||
|
format=format,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract the checkpoint
|
||||||
try:
|
try:
|
||||||
# Open the URL connection
|
extract_dir.mkdir(parents=True, exist_ok=True)
|
||||||
with urllib.request.urlopen(url_response.url, timeout=30) as response:
|
_safe_extract_tar(archive_path, extract_dir, show_progress=True, format=format)
|
||||||
# Get total file size from headers
|
destination = str(extract_dir)
|
||||||
total_size = int(response.headers.get("Content-Length", 0))
|
if archive_path.exists():
|
||||||
|
archive_path.unlink()
|
||||||
# Download with progress bar
|
except Exception as e:
|
||||||
if format != "json":
|
|
||||||
with click.progressbar(
|
|
||||||
length=total_size,
|
|
||||||
label="Downloading archive",
|
|
||||||
show_percent=True,
|
|
||||||
show_pos=True,
|
|
||||||
show_eta=True,
|
|
||||||
) as bar:
|
|
||||||
with open(archive_path, "wb") as f:
|
|
||||||
while True:
|
|
||||||
chunk = response.read(8192)
|
|
||||||
if not chunk:
|
|
||||||
break
|
|
||||||
f.write(chunk)
|
|
||||||
bar.update(len(chunk))
|
|
||||||
else:
|
|
||||||
# Silent download for JSON output
|
|
||||||
with open(archive_path, "wb") as f:
|
|
||||||
while True:
|
|
||||||
chunk = response.read(8192)
|
|
||||||
if not chunk:
|
|
||||||
break
|
|
||||||
f.write(chunk)
|
|
||||||
|
|
||||||
# Extract the checkpoint
|
|
||||||
try:
|
|
||||||
# Create extraction directory
|
|
||||||
extract_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# Extract the tar archive
|
|
||||||
with tarfile.open(archive_path, "r") as tar:
|
|
||||||
# Get list of members for progress tracking
|
|
||||||
members = tar.getmembers()
|
|
||||||
|
|
||||||
if format != "json":
|
|
||||||
with click.progressbar(
|
|
||||||
members,
|
|
||||||
label="Extracting archive ",
|
|
||||||
show_percent=True,
|
|
||||||
show_pos=True,
|
|
||||||
) as bar:
|
|
||||||
for member in bar:
|
|
||||||
tar.extract(member, path=extract_dir)
|
|
||||||
else:
|
|
||||||
# Extract all at once for few files
|
|
||||||
tar.extractall(path=extract_dir)
|
|
||||||
|
|
||||||
destination = str(extract_dir)
|
|
||||||
|
|
||||||
# Delete archive after successful extraction
|
|
||||||
if archive_path.exists():
|
|
||||||
archive_path.unlink()
|
|
||||||
|
|
||||||
except tarfile.TarError as e:
|
|
||||||
raise TinkerCliError(
|
|
||||||
f"Failed to extract archive: {e}",
|
|
||||||
"The downloaded file may be corrupted. Try downloading again.",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create output object
|
|
||||||
output_obj = CheckpointDownloadOutput(
|
|
||||||
checkpoint_path=checkpoint_path,
|
|
||||||
file_size_bytes=total_size if total_size > 0 else None,
|
|
||||||
destination=destination,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Print in requested format
|
|
||||||
output_obj.print(format=format)
|
|
||||||
|
|
||||||
except urllib.error.URLError as e:
|
|
||||||
raise TinkerCliError(
|
raise TinkerCliError(
|
||||||
f"Failed to download checkpoint: {e}",
|
f"Failed to extract archive: {e}",
|
||||||
"Please check your network connection and try again.",
|
"The downloaded file may be corrupted. Try downloading again.",
|
||||||
)
|
|
||||||
except IOError as e:
|
|
||||||
raise TinkerCliError(
|
|
||||||
f"Failed to save checkpoint: {e}",
|
|
||||||
f"Please check that you have write permissions to {output_dir}",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
output_obj = CheckpointDownloadOutput(
|
||||||
|
checkpoint_path=checkpoint_path,
|
||||||
|
file_size_bytes=total_size if total_size > 0 else None,
|
||||||
|
destination=destination,
|
||||||
|
)
|
||||||
|
output_obj.print(format=format)
|
||||||
|
|
||||||
|
|
||||||
|
@cli.command(name="push-hf")
@click.argument("checkpoint_path")
@click.option(
    "--repo",
    "-r",
    "repo_id",
    type=str,
    default=None,
    help="Hugging Face repo ID (e.g., username/my-lora-adapter). If omitted, derive from run.",
)
@click.option(
    "--public",
    is_flag=True,
    help="Create or upload to a public repo (default: private).",
)
@click.option(
    "--revision",
    type=str,
    default=None,
    help="Target branch/revision to upload to (optional).",
)
@click.option(
    "--commit-message",
    type=str,
    default=None,
    help="Commit message for the upload (optional).",
)
@click.option(
    "--create-pr",
    is_flag=True,
    help="Create a pull request instead of pushing to the main branch.",
)
@click.option(
    "--allow-pattern",
    "allow_patterns",
    multiple=True,
    help="Only upload files matching this pattern (can be repeated).",
)
@click.option(
    "--ignore-pattern",
    "ignore_patterns",
    multiple=True,
    help="Skip files matching this pattern (can be repeated).",
)
@click.option(
    "--no-model-card",
    is_flag=True,
    help="Do not create a README.md model card if one is missing.",
)
@click.pass_obj
@handle_api_errors
def push_hf(
    cli_context: CLIContext,
    checkpoint_path: str,
    repo_id: str | None,
    public: bool,
    revision: str | None,
    commit_message: str | None,
    create_pr: bool,
    allow_patterns: tuple[str, ...],
    ignore_patterns: tuple[str, ...],
    no_model_card: bool,
) -> None:
    """Upload a checkpoint to the Hugging Face Hub as a PEFT adapter.

    CHECKPOINT_PATH must be a tinker path (e.g., tinker://run-id/sampler_weights/0001).
    """
    # Guard clause: anything that is not a tinker:// URI is rejected up front.
    is_tinker_path = checkpoint_path.startswith("tinker://")
    if not is_tinker_path:
        raise TinkerCliError(
            f"Invalid checkpoint path: {checkpoint_path}",
            "Checkpoint path must be in the format: tinker://run-id/sampler_weights/0001",
        )

    # click hands over repeated options as tuples; the exporter expects a list,
    # or None when the option was never given.
    allow_list = list(allow_patterns) or None
    ignore_list = list(ignore_patterns) or None

    rest_client = create_rest_client()
    uploaded_repo = _export_checkpoint_to_hub(
        rest_client,
        checkpoint_path,
        repo_id,
        private=not public,
        revision=revision,
        commit_message=commit_message,
        create_pr=create_pr,
        exist_ok=True,
        allow_patterns=allow_list,
        ignore_patterns=ignore_list,
        add_model_card=not no_model_card,
    )

    # Report the repo actually used (it may have been derived by the exporter).
    summary = CheckpointHubUploadOutput(
        checkpoint_path=checkpoint_path,
        repo_id=uploaded_repo,
        revision=revision,
        public=public,
    )
    summary.print(format=cli_context.format)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue