This commit is contained in:
Kashif Rasul 2026-01-28 09:35:08 +01:00
parent 3e75e63d40
commit 0d81ac458a
2 changed files with 144 additions and 0 deletions

View file

@ -62,6 +62,7 @@ class LazyGroup(click.Group):
- `tinker run info <run-id>` - Show details of a specific run
- `tinker checkpoint list` - List all checkpoints
- `tinker checkpoint info <checkpoint-id>` - Show checkpoint details
- `tinker checkpoint push-hf <checkpoint-path>` - Upload a checkpoint to Hugging Face Hub
### 4. Output System with Inheritance
@ -241,6 +242,9 @@ tinker checkpoint list run-abc123
# Show checkpoint details
tinker checkpoint info ckpt-xyz789
# Upload checkpoint to Hugging Face Hub
tinker checkpoint push-hf tinker://run-abc123/sampler_weights/000040 --repo username/my-lora-adapter
# JSON output
tinker --format json run list
tinker --format json checkpoint list

View file

@ -239,6 +239,50 @@ class CheckpointDownloadOutput(OutputBase):
return rows
class CheckpointHubUploadOutput(OutputBase):
"""Output for 'tinker checkpoint push-hf' command."""
def __init__(
self,
checkpoint_path: str,
repo_id: str,
revision: str | None = None,
public: bool | None = None,
):
self.checkpoint_path = checkpoint_path
self.repo_id = repo_id
self.revision = revision
self.public = public
def to_dict(self) -> Dict[str, Any]:
result = {
"checkpoint_path": self.checkpoint_path,
"repo_id": self.repo_id,
}
if self.revision is not None:
result["revision"] = self.revision
if self.public is not None:
result["public"] = self.public
return result
def get_title(self) -> str | None:
return f"Checkpoint Hub Upload: {self.checkpoint_path}"
def get_table_columns(self) -> List[str]:
return ["Property", "Value"]
def get_table_rows(self) -> List[List[str]]:
rows = [
["Checkpoint Path", self.checkpoint_path],
["Repo ID", self.repo_id],
]
if self.revision is not None:
rows.append(["Revision", self.revision])
if self.public is not None:
rows.append(["Public", format_bool(self.public)])
return rows
def get_checkpoint_from_path(client: "RestClient", checkpoint_path: str) -> "Checkpoint":
"""Get checkpoint details from a tinker path.
@ -680,3 +724,99 @@ def download(
f"Failed to save checkpoint: {e}",
f"Please check that you have write permissions to {output_dir}",
)
@cli.command(name="push-hf")
@click.argument("checkpoint_path")
@click.option(
"--repo",
"-r",
"repo_id",
type=str,
default=None,
help="Hugging Face repo ID (e.g., username/my-lora-adapter). If omitted, derive from run.",
)
@click.option(
"--public",
is_flag=True,
help="Create or upload to a public repo (default: private).",
)
@click.option(
"--revision",
type=str,
default=None,
help="Target branch/revision to upload to (optional).",
)
@click.option(
"--commit-message",
type=str,
default=None,
help="Commit message for the upload (optional).",
)
@click.option(
"--create-pr",
is_flag=True,
help="Create a pull request instead of pushing to the main branch.",
)
@click.option(
"--allow-pattern",
"allow_patterns",
multiple=True,
help="Only upload files matching this pattern (can be repeated).",
)
@click.option(
"--ignore-pattern",
"ignore_patterns",
multiple=True,
help="Skip files matching this pattern (can be repeated).",
)
@click.option(
"--no-model-card",
is_flag=True,
help="Do not create a README.md model card if one is missing.",
)
@click.pass_obj
@handle_api_errors
def push_hf(
cli_context: CLIContext,
checkpoint_path: str,
repo_id: str | None,
public: bool,
revision: str | None,
commit_message: str | None,
create_pr: bool,
allow_patterns: tuple[str, ...],
ignore_patterns: tuple[str, ...],
no_model_card: bool,
) -> None:
"""Upload a checkpoint to the Hugging Face Hub as a PEFT adapter.
CHECKPOINT_PATH must be a tinker path (e.g., tinker://run-id/weights/0001).
"""
# Validate it's a tinker path
if not checkpoint_path.startswith("tinker://"):
raise TinkerCliError(
f"Invalid checkpoint path: {checkpoint_path}",
"Checkpoint path must be in the format: tinker://run-id/weights/0001",
)
client = create_rest_client()
repo_id_out = client.export_checkpoint_to_hub(
checkpoint_path,
repo_id,
private=not public,
revision=revision,
commit_message=commit_message,
create_pr=create_pr,
allow_patterns=list(allow_patterns) if allow_patterns else None,
ignore_patterns=list(ignore_patterns) if ignore_patterns else None,
add_model_card=not no_model_card,
)
output_obj = CheckpointHubUploadOutput(
checkpoint_path=checkpoint_path,
repo_id=repo_id_out,
revision=revision,
public=public,
)
output_obj.print(format=cli_context.format)