diff --git a/src/tinker/cli/CLAUDE.md b/src/tinker/cli/CLAUDE.md index f321106..29655c6 100644 --- a/src/tinker/cli/CLAUDE.md +++ b/src/tinker/cli/CLAUDE.md @@ -62,6 +62,7 @@ class LazyGroup(click.Group): - `tinker run info ` - Show details of a specific run - `tinker checkpoint list` - List all checkpoints - `tinker checkpoint info ` - Show checkpoint details +- `tinker checkpoint push-hf ` - Upload a checkpoint to Hugging Face Hub ### 4. Output System with Inheritance @@ -241,6 +242,9 @@ tinker checkpoint list run-abc123 # Show checkpoint details tinker checkpoint info ckpt-xyz789 +# Upload checkpoint to Hugging Face Hub +tinker checkpoint push-hf tinker://run-abc123/sampler_weights/000040 --repo username/my-lora-adapter + # JSON output tinker --format json run list tinker --format json checkpoint list diff --git a/src/tinker/cli/commands/checkpoint.py b/src/tinker/cli/commands/checkpoint.py index 6ed948a..410f07d 100644 --- a/src/tinker/cli/commands/checkpoint.py +++ b/src/tinker/cli/commands/checkpoint.py @@ -239,6 +239,50 @@ class CheckpointDownloadOutput(OutputBase): return rows +class CheckpointHubUploadOutput(OutputBase): + """Output for 'tinker checkpoint push-hf' command.""" + + def __init__( + self, + checkpoint_path: str, + repo_id: str, + revision: str | None = None, + public: bool | None = None, + ): + self.checkpoint_path = checkpoint_path + self.repo_id = repo_id + self.revision = revision + self.public = public + + def to_dict(self) -> Dict[str, Any]: + result = { + "checkpoint_path": self.checkpoint_path, + "repo_id": self.repo_id, + } + if self.revision is not None: + result["revision"] = self.revision + if self.public is not None: + result["public"] = self.public + return result + + def get_title(self) -> str | None: + return f"Checkpoint Hub Upload: {self.checkpoint_path}" + + def get_table_columns(self) -> List[str]: + return ["Property", "Value"] + + def get_table_rows(self) -> List[List[str]]: + rows = [ + ["Checkpoint Path", self.checkpoint_path], + ["Repo ID", self.repo_id], + ] + if self.revision is not None: + rows.append(["Revision", self.revision]) + if self.public is not None: + rows.append(["Public", format_bool(self.public)]) + return rows + + def get_checkpoint_from_path(client: "RestClient", checkpoint_path: str) -> "Checkpoint": """Get checkpoint details from a tinker path. @@ -680,3 +724,99 @@ def download( f"Failed to save checkpoint: {e}", f"Please check that you have write permissions to {output_dir}", ) + + +@cli.command(name="push-hf") +@click.argument("checkpoint_path") +@click.option( + "--repo", + "-r", + "repo_id", + type=str, + default=None, + help="Hugging Face repo ID (e.g., username/my-lora-adapter). If omitted, derive from run.", +) +@click.option( + "--public", + is_flag=True, + help="Create or upload to a public repo (default: private).", +) +@click.option( + "--revision", + type=str, + default=None, + help="Target branch/revision to upload to (optional).", +) +@click.option( + "--commit-message", + type=str, + default=None, + help="Commit message for the upload (optional).", +) +@click.option( + "--create-pr", + is_flag=True, + help="Create a pull request instead of pushing to the main branch.", +) +@click.option( + "--allow-pattern", + "allow_patterns", + multiple=True, + help="Only upload files matching this pattern (can be repeated).", +) +@click.option( + "--ignore-pattern", + "ignore_patterns", + multiple=True, + help="Skip files matching this pattern (can be repeated).", +) +@click.option( + "--no-model-card", + is_flag=True, + help="Do not create a README.md model card if one is missing.", +) +@click.pass_obj +@handle_api_errors +def push_hf( + cli_context: CLIContext, + checkpoint_path: str, + repo_id: str | None, + public: bool, + revision: str | None, + commit_message: str | None, + create_pr: bool, + allow_patterns: tuple[str, ...], + ignore_patterns: tuple[str, ...], + no_model_card: bool, +) -> None: + """Upload a checkpoint to the Hugging Face Hub as a PEFT adapter. + + CHECKPOINT_PATH must be a tinker path (e.g., tinker://run-id/weights/0001). + """ + # Validate it's a tinker path + if not checkpoint_path.startswith("tinker://"): + raise TinkerCliError( + f"Invalid checkpoint path: {checkpoint_path}", + "Checkpoint path must be in the format: tinker://run-id/weights/0001", + ) + + client = create_rest_client() + repo_id_out = client.export_checkpoint_to_hub( + checkpoint_path, + repo_id, + private=not public, + revision=revision, + commit_message=commit_message, + create_pr=create_pr, + allow_patterns=list(allow_patterns) if allow_patterns else None, + ignore_patterns=list(ignore_patterns) if ignore_patterns else None, + add_model_card=not no_model_card, + ) + + output_obj = CheckpointHubUploadOutput( + checkpoint_path=checkpoint_path, + repo_id=repo_id_out, + revision=revision, + public=public, + ) + output_obj.print(format=cli_context.format)