diff --git a/src/tinker/cli/CLAUDE.md b/src/tinker/cli/CLAUDE.md index f321106..29655c6 100644 --- a/src/tinker/cli/CLAUDE.md +++ b/src/tinker/cli/CLAUDE.md @@ -62,6 +62,7 @@ class LazyGroup(click.Group): - `tinker run info ` - Show details of a specific run - `tinker checkpoint list` - List all checkpoints - `tinker checkpoint info ` - Show checkpoint details +- `tinker checkpoint push-hf ` - Upload a checkpoint to Hugging Face Hub ### 4. Output System with Inheritance @@ -241,6 +242,9 @@ tinker checkpoint list run-abc123 # Show checkpoint details tinker checkpoint info ckpt-xyz789 +# Upload checkpoint to Hugging Face Hub +tinker checkpoint push-hf tinker://run-abc123/sampler_weights/000040 --repo username/my-lora-adapter + # JSON output tinker --format json run list tinker --format json checkpoint list diff --git a/src/tinker/cli/commands/checkpoint.py b/src/tinker/cli/commands/checkpoint.py index 6ed948a..ad7bc8e 100644 --- a/src/tinker/cli/commands/checkpoint.py +++ b/src/tinker/cli/commands/checkpoint.py @@ -239,6 +239,50 @@ class CheckpointDownloadOutput(OutputBase): return rows +class CheckpointHubUploadOutput(OutputBase): + """Output for 'tinker checkpoint push-hf' command.""" + + def __init__( + self, + checkpoint_path: str, + repo_id: str, + revision: str | None = None, + public: bool | None = None, + ): + self.checkpoint_path = checkpoint_path + self.repo_id = repo_id + self.revision = revision + self.public = public + + def to_dict(self) -> Dict[str, Any]: + result = { + "checkpoint_path": self.checkpoint_path, + "repo_id": self.repo_id, + } + if self.revision is not None: + result["revision"] = self.revision + if self.public is not None: + result["public"] = self.public + return result + + def get_title(self) -> str | None: + return f"Checkpoint Hub Upload: {self.checkpoint_path}" + + def get_table_columns(self) -> List[str]: + return ["Property", "Value"] + + def get_table_rows(self) -> List[List[str]]: + rows = [ + ["Checkpoint Path", 
self.checkpoint_path], + ["Repo ID", self.repo_id], + ] + if self.revision is not None: + rows.append(["Revision", self.revision]) + if self.public is not None: + rows.append(["Public", format_bool(self.public)]) + return rows + + def get_checkpoint_from_path(client: "RestClient", checkpoint_path: str) -> "Checkpoint": """Get checkpoint details from a tinker path. @@ -280,6 +324,307 @@ def get_checkpoint_from_path(client: "RestClient", checkpoint_path: str) -> "Che raise TinkerCliError(f"Failed to retrieve checkpoint: {e}") +def _export_checkpoint_to_hub( + client: "RestClient", + tinker_path: str, + repo_id: str | None, + *, + private: bool, + revision: str | None, + commit_message: str | None, + create_pr: bool, + exist_ok: bool, + allow_patterns: list[str] | None, + ignore_patterns: list[str] | None, + add_model_card: bool, +) -> str: + # Lazy imports to keep CLI startup fast + try: + from huggingface_hub import HfApi, hf_hub_download + except ImportError as exc: + raise TinkerCliError( + "huggingface_hub is required for this command.", + "Install it with: pip install huggingface_hub, then run: hf auth login", + ) from exc + + import json + import os + import re + import tempfile + from pathlib import Path + + from tinker import ParsedCheckpointTinkerPath + + # Validate tinker path + parsed_tinker_path = ParsedCheckpointTinkerPath.from_tinker_path(tinker_path) + + api = HfApi() + try: + api.whoami() + except Exception as exc: + raise TinkerCliError("Not logged in", "Run: hf auth login") from exc + + def _sanitize_repo_name(value: str) -> str: + safe_chars = [] + for ch in value: + if ch.isalnum() or ch in {"-", "_", "."}: + safe_chars.append(ch) + else: + safe_chars.append("-") + name = "".join(safe_chars) + while "--" in name: + name = name.replace("--", "-") + return name.strip("-_ .") + + with tempfile.TemporaryDirectory() as temp_dir: + temp_root = Path(temp_dir) + archive_path = temp_root / "checkpoint.tar" + extract_dir = temp_root / "extract" + 
extract_dir.mkdir(parents=True, exist_ok=True) + + url_response = client.get_checkpoint_archive_url_from_tinker_path(tinker_path).result() + _download_checkpoint_archive( + url_response.url, + archive_path=archive_path, + show_progress=False, + format="json", + ) + _safe_extract_tar(archive_path, extract_dir, show_progress=False, format="json") + + adapter_config = extract_dir / "adapter_config.json" + adapter_safetensors = extract_dir / "adapter_model.safetensors" + adapter_bin = extract_dir / "adapter_model.bin" + checkpoint_complete = extract_dir / "checkpoint_complete" + if not adapter_config.exists() or not ( + adapter_safetensors.exists() or adapter_bin.exists() + ): + raise TinkerCliError( + "Checkpoint archive does not contain a PEFT adapter.", + "Expected adapter_config.json and adapter_model.safetensors (or adapter_model.bin).", + ) + if not checkpoint_complete.exists(): + raise TinkerCliError( + "Checkpoint archive is missing 'checkpoint_complete'.", + "The adapter files may be incomplete.", + ) + + base_model = "unknown" + lora_rank = None + train_mlp = None + train_attn = None + train_unembed = None + try: + weights_info = client.get_weights_info_by_tinker_path(tinker_path).result() + base_model = weights_info.base_model + lora_rank = weights_info.lora_rank + train_mlp = weights_info.train_mlp + train_attn = weights_info.train_attn + train_unembed = weights_info.train_unembed + except Exception: + pass + + try: + config_data = json.loads(adapter_config.read_text(encoding="utf-8")) + if not isinstance(config_data.get("base_model_name_or_path"), str): + config_data["base_model_name_or_path"] = base_model + adapter_config.write_text( + json.dumps(config_data, indent=2, sort_keys=True) + "\n", encoding="utf-8" + ) + except Exception: + pass + + if repo_id is None: + base_short = base_model.split("/")[-1] if base_model != "unknown" else "adapter" + derived = f"tinker-{base_short}-{parsed_tinker_path.training_run_id}" + repo_id = _sanitize_repo_name(derived) 
+ if revision is None: + revision = _sanitize_repo_name(parsed_tinker_path.checkpoint_id.replace("/", "-")) + + readme_path = extract_dir / "README.md" + if add_model_card and not readme_path.exists(): + tags: List[str] = ["tinker", "peft", "lora"] + if base_model != "unknown": + tags.append(f"base_model:adapter:{base_model}") + model_card = [ + "---", + f"base_model: {base_model}", + "library_name: peft", + "tags:", + ] + for tag in tags: + model_card.append(f"- {tag}") + model_card.append(f"tinker_path: {tinker_path}") + model_card.extend( + [ + "---", + "", + "# Tinker LoRA Adapter", + "", + "This repository contains a LoRA adapter exported from Tinker.", + "", + "## Usage", + "", + "```python", + "from transformers import AutoModelForCausalLM", + "", + f'adapter_id = "{repo_id}"', + f'base_model = "{base_model}"', + "", + 'model = AutoModelForCausalLM.from_pretrained(adapter_id, device_map="auto")', + "```", + "", + "## Source", + "", + "```", + f"{tinker_path}", + "```", + "", + "## Details", + "", + f"- Base model: {base_model}", + ] + ) + if lora_rank is not None: + model_card.append(f"- LoRA rank: {lora_rank}") + if train_mlp is not None or train_attn is not None or train_unembed is not None: + model_card.append( + f"- Trained modules: attn={train_attn}, mlp={train_mlp}, unembed={train_unembed}" + ) + model_card.append("") + readme_path.write_text("\n".join(model_card), encoding="utf-8") + + api.create_repo(repo_id=repo_id, private=private, exist_ok=exist_ok) + + def _readme_tinker_path() -> str | None: + try: + readme_file = hf_hub_download( + repo_id=repo_id, + filename="README.md", + revision=revision, + token=None, + ) + except Exception: + return None + try: + text = Path(readme_file).read_text(encoding="utf-8", errors="ignore") + except Exception: + return None + match = re.search(r"tinker://[^\s`]+", text) + return match.group(0) if match else None + + existing_tinker_path = _readme_tinker_path() + if existing_tinker_path and existing_tinker_path != 
tinker_path: + raise TinkerCliError( + "Repo ID appears to contain a different Tinker checkpoint.", + f"Found {existing_tinker_path}, expected {tinker_path}.", + ) + + if allow_patterns is None: + ignore_patterns = list(ignore_patterns) if ignore_patterns else [] + if "checkpoint_complete" not in ignore_patterns: + ignore_patterns.append("checkpoint_complete") + + api.upload_folder( + folder_path=os.fspath(extract_dir), + repo_id=repo_id, + path_in_repo="", + revision=revision, + commit_message=commit_message, + create_pr=create_pr, + allow_patterns=list(allow_patterns) if allow_patterns else None, + ignore_patterns=list(ignore_patterns) if ignore_patterns else None, + ) + + return repo_id + + +def _safe_extract_tar( + archive_path, + extract_dir, + *, + show_progress: bool, + format: str, +) -> None: + import tarfile + + base = extract_dir.resolve() + with tarfile.open(archive_path, "r") as tar: + members = tar.getmembers() + for member in members: + if member.issym() or member.islnk(): + raise TinkerCliError( + "Unsafe symlink or hardlink in tar archive", + "Archive may be corrupted or malicious.", + ) + member_path = (extract_dir / member.name).resolve() + if not str(member_path).startswith(str(base)): + raise TinkerCliError( + "Unsafe path in tar archive", + "Archive may be corrupted or malicious.", + ) + if show_progress and format != "json": + with click.progressbar( + members, + label="Extracting archive ", + show_percent=True, + show_pos=True, + ) as bar: + for member in bar: + tar.extract(member, path=extract_dir) + else: + tar.extractall(path=extract_dir) + + +def _download_checkpoint_archive( + url: str, + *, + archive_path, + show_progress: bool, + format: str, +) -> int: + import urllib.error + import urllib.request + + try: + with urllib.request.urlopen(url, timeout=60) as response: + total_size = int(response.headers.get("Content-Length", 0)) + + if show_progress and format != "json": + with click.progressbar( + length=total_size, + 
label="Downloading archive", + show_percent=True, + show_pos=True, + show_eta=True, + ) as bar: + with open(archive_path, "wb") as f: + while True: + chunk = response.read(8192) + if not chunk: + break + f.write(chunk) + bar.update(len(chunk)) + else: + with open(archive_path, "wb") as f: + while True: + chunk = response.read(8192) + if not chunk: + break + f.write(chunk) + except urllib.error.URLError as e: + raise TinkerCliError( + f"Failed to download checkpoint: {e}", + "Please check your network connection and try again.", + ) from e + except IOError as e: + raise TinkerCliError( + f"Failed to save checkpoint: {e}", + f"Please check that you have write permissions to {archive_path.parent}", + ) from e + + return total_size + + # Click command group for checkpoint commands @click.group() def cli(): @@ -545,10 +890,7 @@ def download( """ # Lazy imports to maintain fast CLI startup import shutil - import tarfile import tempfile - import urllib.error - import urllib.request from pathlib import Path # Validate it's a tinker path @@ -584,99 +926,136 @@ def download( "Use --force to overwrite or choose a different output directory.", ) - # Create client and get download URL - client = create_rest_client() - url_response = client.get_checkpoint_archive_url_from_tinker_path(checkpoint_path).result() - # Use a temporary directory for the archive with tempfile.TemporaryDirectory() as temp_dir: archive_path = Path(temp_dir) / f"{checkpoint_id}.tar" extract_dir = target_path - # Download the archive with progress bar + # Create client and get download URL + client = create_rest_client() + url_response = client.get_checkpoint_archive_url_from_tinker_path(checkpoint_path).result() + + total_size = _download_checkpoint_archive( + url_response.url, + archive_path=archive_path, + show_progress=True, + format=format, + ) + + # Extract the checkpoint try: - # Open the URL connection - with urllib.request.urlopen(url_response.url, timeout=30) as response: - # Get total file size 
from headers - total_size = int(response.headers.get("Content-Length", 0)) - - # Download with progress bar - if format != "json": - with click.progressbar( - length=total_size, - label="Downloading archive", - show_percent=True, - show_pos=True, - show_eta=True, - ) as bar: - with open(archive_path, "wb") as f: - while True: - chunk = response.read(8192) - if not chunk: - break - f.write(chunk) - bar.update(len(chunk)) - else: - # Silent download for JSON output - with open(archive_path, "wb") as f: - while True: - chunk = response.read(8192) - if not chunk: - break - f.write(chunk) - - # Extract the checkpoint - try: - # Create extraction directory - extract_dir.mkdir(parents=True, exist_ok=True) - - # Extract the tar archive - with tarfile.open(archive_path, "r") as tar: - # Get list of members for progress tracking - members = tar.getmembers() - - if format != "json": - with click.progressbar( - members, - label="Extracting archive ", - show_percent=True, - show_pos=True, - ) as bar: - for member in bar: - tar.extract(member, path=extract_dir) - else: - # Extract all at once for few files - tar.extractall(path=extract_dir) - - destination = str(extract_dir) - - # Delete archive after successful extraction - if archive_path.exists(): - archive_path.unlink() - - except tarfile.TarError as e: - raise TinkerCliError( - f"Failed to extract archive: {e}", - "The downloaded file may be corrupted. 
@cli.command(name="push-hf")
@click.argument("checkpoint_path")
@click.option(
    "--repo",
    "-r",
    "repo_id",
    type=str,
    default=None,
    help="Hugging Face repo ID (e.g., username/my-lora-adapter). If omitted, derive from run.",
)
@click.option(
    "--public",
    is_flag=True,
    help="Create or upload to a public repo (default: private).",
)
@click.option(
    "--revision",
    type=str,
    default=None,
    help="Target branch/revision to upload to (optional).",
)
@click.option(
    "--commit-message",
    type=str,
    default=None,
    help="Commit message for the upload (optional).",
)
@click.option(
    "--create-pr",
    is_flag=True,
    help="Create a pull request instead of pushing to the main branch.",
)
@click.option(
    "--allow-pattern",
    "allow_patterns",
    multiple=True,
    help="Only upload files matching this pattern (can be repeated).",
)
@click.option(
    "--ignore-pattern",
    "ignore_patterns",
    multiple=True,
    help="Skip files matching this pattern (can be repeated).",
)
@click.option(
    "--no-model-card",
    is_flag=True,
    help="Do not create a README.md model card if one is missing.",
)
@click.pass_obj
@handle_api_errors
def push_hf(
    cli_context: CLIContext,
    checkpoint_path: str,
    repo_id: str | None,
    public: bool,
    revision: str | None,
    commit_message: str | None,
    create_pr: bool,
    allow_patterns: tuple[str, ...],
    ignore_patterns: tuple[str, ...],
    no_model_card: bool,
) -> None:
    """Upload a checkpoint to the Hugging Face Hub as a PEFT adapter.

    CHECKPOINT_PATH must be a tinker path (e.g., tinker://run-id/sampler_weights/0001).
    """
    # Reject anything that is not a tinker:// path before touching the network.
    is_tinker_path = checkpoint_path.startswith("tinker://")
    if not is_tinker_path:
        raise TinkerCliError(
            f"Invalid checkpoint path: {checkpoint_path}",
            "Checkpoint path must be in the format: tinker://run-id/sampler_weights/0001",
        )

    # click hands repeated options over as tuples (empty when unused); the
    # export helper expects a list, or None for "no filter".
    include_patterns = list(allow_patterns) or None
    exclude_patterns = list(ignore_patterns) or None

    resolved_repo_id = _export_checkpoint_to_hub(
        create_rest_client(),
        checkpoint_path,
        repo_id,
        private=not public,
        revision=revision,
        commit_message=commit_message,
        create_pr=create_pr,
        exist_ok=True,
        allow_patterns=include_patterns,
        ignore_patterns=exclude_patterns,
        add_model_card=not no_model_card,
    )

    # Report the repo actually used (it may have been derived) together with
    # the options the user passed on the command line.
    summary = CheckpointHubUploadOutput(
        checkpoint_path=checkpoint_path,
        repo_id=resolved_repo_id,
        revision=revision,
        public=public,
    )
    summary.print(format=cli_context.format)