mirror of
https://github.com/thinking-machines-lab/tinker.git
synced 2026-04-19 12:58:01 +00:00
add cli
This commit is contained in:
parent
3e75e63d40
commit
0d81ac458a
2 changed files with 144 additions and 0 deletions
|
|
@ -62,6 +62,7 @@ class LazyGroup(click.Group):
|
|||
- `tinker run info <run-id>` - Show details of a specific run
|
||||
- `tinker checkpoint list` - List all checkpoints
|
||||
- `tinker checkpoint info <checkpoint-id>` - Show checkpoint details
|
||||
- `tinker checkpoint push-hf <checkpoint-path>` - Upload a checkpoint to Hugging Face Hub
|
||||
|
||||
### 4. Output System with Inheritance
|
||||
|
||||
|
|
@ -241,6 +242,9 @@ tinker checkpoint list run-abc123
|
|||
# Show checkpoint details
|
||||
tinker checkpoint info ckpt-xyz789
|
||||
|
||||
# Upload checkpoint to Hugging Face Hub
|
||||
tinker checkpoint push-hf tinker://run-abc123/sampler_weights/000040 --repo username/my-lora-adapter
|
||||
|
||||
# JSON output
|
||||
tinker --format json run list
|
||||
tinker --format json checkpoint list
|
||||
|
|
|
|||
|
|
@ -239,6 +239,50 @@ class CheckpointDownloadOutput(OutputBase):
|
|||
return rows
|
||||
|
||||
|
||||
class CheckpointHubUploadOutput(OutputBase):
|
||||
"""Output for 'tinker checkpoint push-hf' command."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
checkpoint_path: str,
|
||||
repo_id: str,
|
||||
revision: str | None = None,
|
||||
public: bool | None = None,
|
||||
):
|
||||
self.checkpoint_path = checkpoint_path
|
||||
self.repo_id = repo_id
|
||||
self.revision = revision
|
||||
self.public = public
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
result = {
|
||||
"checkpoint_path": self.checkpoint_path,
|
||||
"repo_id": self.repo_id,
|
||||
}
|
||||
if self.revision is not None:
|
||||
result["revision"] = self.revision
|
||||
if self.public is not None:
|
||||
result["public"] = self.public
|
||||
return result
|
||||
|
||||
def get_title(self) -> str | None:
|
||||
return f"Checkpoint Hub Upload: {self.checkpoint_path}"
|
||||
|
||||
def get_table_columns(self) -> List[str]:
|
||||
return ["Property", "Value"]
|
||||
|
||||
def get_table_rows(self) -> List[List[str]]:
|
||||
rows = [
|
||||
["Checkpoint Path", self.checkpoint_path],
|
||||
["Repo ID", self.repo_id],
|
||||
]
|
||||
if self.revision is not None:
|
||||
rows.append(["Revision", self.revision])
|
||||
if self.public is not None:
|
||||
rows.append(["Public", format_bool(self.public)])
|
||||
return rows
|
||||
|
||||
|
||||
def get_checkpoint_from_path(client: "RestClient", checkpoint_path: str) -> "Checkpoint":
|
||||
"""Get checkpoint details from a tinker path.
|
||||
|
||||
|
|
@ -680,3 +724,99 @@ def download(
|
|||
f"Failed to save checkpoint: {e}",
|
||||
f"Please check that you have write permissions to {output_dir}",
|
||||
)
|
||||
|
||||
|
||||
@cli.command(name="push-hf")
|
||||
@click.argument("checkpoint_path")
|
||||
@click.option(
|
||||
"--repo",
|
||||
"-r",
|
||||
"repo_id",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Hugging Face repo ID (e.g., username/my-lora-adapter). If omitted, derive from run.",
|
||||
)
|
||||
@click.option(
|
||||
"--public",
|
||||
is_flag=True,
|
||||
help="Create or upload to a public repo (default: private).",
|
||||
)
|
||||
@click.option(
|
||||
"--revision",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Target branch/revision to upload to (optional).",
|
||||
)
|
||||
@click.option(
|
||||
"--commit-message",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Commit message for the upload (optional).",
|
||||
)
|
||||
@click.option(
|
||||
"--create-pr",
|
||||
is_flag=True,
|
||||
help="Create a pull request instead of pushing to the main branch.",
|
||||
)
|
||||
@click.option(
|
||||
"--allow-pattern",
|
||||
"allow_patterns",
|
||||
multiple=True,
|
||||
help="Only upload files matching this pattern (can be repeated).",
|
||||
)
|
||||
@click.option(
|
||||
"--ignore-pattern",
|
||||
"ignore_patterns",
|
||||
multiple=True,
|
||||
help="Skip files matching this pattern (can be repeated).",
|
||||
)
|
||||
@click.option(
|
||||
"--no-model-card",
|
||||
is_flag=True,
|
||||
help="Do not create a README.md model card if one is missing.",
|
||||
)
|
||||
@click.pass_obj
|
||||
@handle_api_errors
|
||||
def push_hf(
|
||||
cli_context: CLIContext,
|
||||
checkpoint_path: str,
|
||||
repo_id: str | None,
|
||||
public: bool,
|
||||
revision: str | None,
|
||||
commit_message: str | None,
|
||||
create_pr: bool,
|
||||
allow_patterns: tuple[str, ...],
|
||||
ignore_patterns: tuple[str, ...],
|
||||
no_model_card: bool,
|
||||
) -> None:
|
||||
"""Upload a checkpoint to the Hugging Face Hub as a PEFT adapter.
|
||||
|
||||
CHECKPOINT_PATH must be a tinker path (e.g., tinker://run-id/weights/0001).
|
||||
"""
|
||||
# Validate it's a tinker path
|
||||
if not checkpoint_path.startswith("tinker://"):
|
||||
raise TinkerCliError(
|
||||
f"Invalid checkpoint path: {checkpoint_path}",
|
||||
"Checkpoint path must be in the format: tinker://run-id/weights/0001",
|
||||
)
|
||||
|
||||
client = create_rest_client()
|
||||
repo_id_out = client.export_checkpoint_to_hub(
|
||||
checkpoint_path,
|
||||
repo_id,
|
||||
private=not public,
|
||||
revision=revision,
|
||||
commit_message=commit_message,
|
||||
create_pr=create_pr,
|
||||
allow_patterns=list(allow_patterns) if allow_patterns else None,
|
||||
ignore_patterns=list(ignore_patterns) if ignore_patterns else None,
|
||||
add_model_card=not no_model_card,
|
||||
)
|
||||
|
||||
output_obj = CheckpointHubUploadOutput(
|
||||
checkpoint_path=checkpoint_path,
|
||||
repo_id=repo_id_out,
|
||||
revision=revision,
|
||||
public=public,
|
||||
)
|
||||
output_obj.print(format=cli_context.format)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue