mirror of
https://github.com/thinking-machines-lab/tinker.git
synced 2026-04-26 17:13:11 +00:00
initial export to hub
This commit is contained in:
parent
ca40e08bb4
commit
d6aa21005b
2 changed files with 239 additions and 0 deletions
|
|
@ -288,6 +288,56 @@ async def get_checkpoint_archive_url_from_tinker_path_async(
|
|||
|
||||
Async version of get_checkpoint_archive_url_from_tinker_path.
|
||||
|
||||
#### `export_checkpoint_to_hub`
|
||||
|
||||
```python
|
||||
def export_checkpoint_to_hub(
|
||||
tinker_path: str,
|
||||
repo_id: str | None,
|
||||
*,
|
||||
private: bool = True,
|
||||
token: str | None = None,
|
||||
revision: str | None = None,
|
||||
commit_message: str | None = None,
|
||||
create_pr: bool = False,
|
||||
exist_ok: bool = True,
|
||||
allow_patterns: Sequence[str] | None = None,
|
||||
ignore_patterns: Sequence[str] | None = None,
|
||||
add_model_card: bool = True,
|
||||
) -> str
|
||||
```
|
||||
|
||||
Download a checkpoint archive, extract the PEFT adapter files, optionally add a README.md
|
||||
model card, and upload to the Hugging Face Hub.
|
||||
|
||||
Args:
|
||||
- `tinker_path`: The tinker path to the checkpoint (e.g., "tinker://run-id/weights/0001")
|
||||
- `repo_id`: Hugging Face repo ID (e.g., "username/my-lora-adapter"). If None,
|
||||
a name is derived from the base model and checkpoint path.
|
||||
- `private`: Whether to create the repo as private (default True)
|
||||
- `token`: Hugging Face access token (optional)
|
||||
- `revision`: Target branch/revision to upload to (optional)
|
||||
- `commit_message`: Commit message for the upload (optional)
|
||||
- `create_pr`: Whether to create a PR instead of pushing to the main branch
|
||||
- `exist_ok`: Whether repo creation should succeed if repo exists
|
||||
- `allow_patterns`: Optional list of file patterns to include
|
||||
- `ignore_patterns`: Optional list of file patterns to exclude
|
||||
- `add_model_card`: Whether to add a README.md if missing (default True)
|
||||
|
||||
Returns:
|
||||
- The repo_id that was uploaded to
|
||||
|
||||
Example:
|
||||
```python
|
||||
rest_client = service_client.create_rest_client()
|
||||
repo_id = rest_client.export_checkpoint_to_hub(
|
||||
"tinker://run-id/weights/final",
|
||||
"username/my-lora-adapter",
|
||||
private=True,
|
||||
)
|
||||
print(f"Uploaded to: {repo_id}")
|
||||
```
|
||||
|
||||
#### `publish_checkpoint_from_tinker_path`
|
||||
|
||||
```python
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Sequence
|
||||
from concurrent.futures import Future as ConcurrentFuture
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
|
@ -397,6 +398,194 @@ class RestClient(TelemetryProvider):
|
|||
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
|
||||
)
|
||||
|
||||
@sync_only
@capture_exceptions(fatal=True)
def export_checkpoint_to_hub(
    self,
    tinker_path: str,
    repo_id: str | None,
    *,
    private: bool = True,
    token: str | None = None,
    revision: str | None = None,
    commit_message: str | None = None,
    create_pr: bool = False,
    exist_ok: bool = True,
    allow_patterns: Sequence[str] | None = None,
    ignore_patterns: Sequence[str] | None = None,
    add_model_card: bool = True,
) -> str:
    """Export a checkpoint archive to the Hugging Face Hub as a PEFT adapter.

    This downloads the checkpoint archive, extracts it locally, optionally adds a
    README.md model card, and uploads the folder to the Hub.

    Args:
    - `tinker_path`: The tinker path to the checkpoint (e.g., "tinker://run-id/weights/0001")
    - `repo_id`: Hugging Face repo ID (e.g., "username/my-lora-adapter"). If None,
          a name is derived from the base model and checkpoint path.
    - `private`: Whether to create the repo as private (default True)
    - `token`: Hugging Face access token (optional)
    - `revision`: Target branch/revision to upload to (optional)
    - `commit_message`: Commit message for the upload (optional)
    - `create_pr`: Whether to create a PR instead of pushing to the main branch
    - `exist_ok`: Whether repo creation should succeed if repo exists
    - `allow_patterns`: Optional list of file patterns to include
    - `ignore_patterns`: Optional list of file patterns to exclude
    - `add_model_card`: Whether to add a README.md if missing (default True)

    Returns:
    - The repo_id that was uploaded to

    Raises:
    - `ImportError`: if `huggingface_hub` is not installed.
    - `ValueError`: if the archive contains unsafe paths or is not a PEFT adapter.
    """
    # Lazy imports to keep base SDK lightweight
    try:
        from huggingface_hub import HfApi
    except ImportError as exc:  # pragma: no cover - optional dependency
        raise ImportError(
            "huggingface_hub is required for export_checkpoint_to_hub. "
            "Install it with: pip install huggingface_hub"
        ) from exc

    import os
    import tarfile
    import tempfile
    import urllib.request
    from pathlib import Path

    # Parse (and thereby validate) the tinker path exactly once; the parsed
    # pieces are reused below when deriving a default repo name.
    parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)

    def _safe_extract(tar: tarfile.TarFile, path: Path) -> None:
        # Guard against path traversal (e.g. "../../etc/passwd" members).
        # NOTE: a plain string-prefix check is insufficient because
        # "/tmp/extract-evil".startswith("/tmp/extract") is True; compare
        # resolved paths component-wise instead.
        root = path.resolve()
        for member in tar.getmembers():
            target = (path / member.name).resolve()
            if target != root and root not in target.parents:
                raise ValueError("Unsafe path in tar archive")
        tar.extractall(path=path)

    def _sanitize_repo_name(value: str) -> str:
        # Hub repo names allow alphanumerics plus "-", "_", "."; map anything
        # else to "-" and collapse runs of separators.
        safe_chars = []
        for ch in value:
            if ch.isalnum() or ch in {"-", "_", "."}:
                safe_chars.append(ch)
            else:
                safe_chars.append("-")
        # Collapse repeated separators
        name = "".join(safe_chars)
        while "--" in name:
            name = name.replace("--", "-")
        return name.strip("-_ .")

    url_response = self.get_checkpoint_archive_url_from_tinker_path(tinker_path).result()

    with tempfile.TemporaryDirectory() as temp_dir:
        temp_root = Path(temp_dir)
        archive_path = temp_root / "checkpoint.tar"
        extract_dir = temp_root / "extract"
        extract_dir.mkdir(parents=True, exist_ok=True)

        # Download archive in fixed-size chunks so large checkpoints are never
        # held fully in memory.
        with urllib.request.urlopen(url_response.url, timeout=60) as response:
            with open(archive_path, "wb") as f:
                while chunk := response.read(8192):
                    f.write(chunk)

        # Extract archive (with the traversal guard above)
        with tarfile.open(archive_path, "r") as tar:
            _safe_extract(tar, extract_dir)

        # Validate PEFT adapter files exist before touching the Hub
        adapter_config = extract_dir / "adapter_config.json"
        adapter_safetensors = extract_dir / "adapter_model.safetensors"
        adapter_bin = extract_dir / "adapter_model.bin"
        if not adapter_config.exists() or not (
            adapter_safetensors.exists() or adapter_bin.exists()
        ):
            raise ValueError(
                "Checkpoint archive does not contain a PEFT adapter. "
                "Expected adapter_config.json and adapter_model.safetensors (or adapter_model.bin)."
            )

        # Best-effort metadata lookup for the model card / derived repo name;
        # the export still proceeds if the weights-info endpoint is unavailable.
        base_model = "unknown"
        lora_rank = None
        train_mlp = None
        train_attn = None
        train_unembed = None
        try:
            weights_info = self.get_weights_info_by_tinker_path(tinker_path).result()
            base_model = weights_info.base_model
            lora_rank = weights_info.lora_rank
            train_mlp = weights_info.train_mlp
            train_attn = weights_info.train_attn
            train_unembed = weights_info.train_unembed
        except Exception:
            # Metadata is optional; log (rather than silently swallow) and
            # fall back to generic card fields.
            logging.getLogger(__name__).debug(
                "Could not fetch weights info for %s; exporting without metadata.",
                tinker_path,
                exc_info=True,
            )

        if repo_id is None:
            base_short = base_model.split("/")[-1] if base_model != "unknown" else "adapter"
            checkpoint_id = parsed_tinker_path.checkpoint_id.replace("/", "-")
            derived = f"tinker-{base_short}-{parsed_tinker_path.training_run_id}-{checkpoint_id}"
            repo_id = _sanitize_repo_name(derived)

        # Add a lightweight model card if missing
        readme_path = extract_dir / "README.md"
        if add_model_card and not readme_path.exists():
            tags: list[str] = ["tinker", "peft", "lora", "transformers"]
            if base_model != "unknown":
                tags.append(f"base_model:adapter:{base_model}")
            # YAML front matter first, then the markdown body.
            model_card = [
                "---",
                f"base_model: {base_model}",
                "library_name: peft",
                "tags:",
            ]
            for tag in tags:
                model_card.append(f"- {tag}")
            model_card.append(f"tinker_path: {tinker_path}")
            model_card.extend(
                [
                    "---",
                    "",
                    "# LoRA Adapter (Tinker)",
                    "",
                    "This repository contains a LoRA adapter exported from Tinker.",
                    "",
                    "## Source",
                    "",
                    f"- Tinker checkpoint: {tinker_path}",
                    "",
                    "## Details",
                    "",
                    f"- Base model: {base_model}",
                ]
            )
            if lora_rank is not None:
                model_card.append(f"- LoRA rank: {lora_rank}")
            if train_mlp is not None or train_attn is not None or train_unembed is not None:
                model_card.append(
                    f"- Trained modules: attn={train_attn}, mlp={train_mlp}, unembed={train_unembed}"
                )
            model_card.append("")
            readme_path.write_text("\n".join(model_card), encoding="utf-8")

        api = HfApi(token=token)
        api.create_repo(repo_id=repo_id, private=private, exist_ok=exist_ok)

        api.upload_folder(
            folder_path=os.fspath(extract_dir),
            repo_id=repo_id,
            path_in_repo="",
            revision=revision,
            commit_message=commit_message,
            create_pr=create_pr,
            allow_patterns=list(allow_patterns) if allow_patterns else None,
            ignore_patterns=list(ignore_patterns) if ignore_patterns else None,
        )

        return repo_id
|
||||
|
||||
def _publish_checkpoint_submit(
|
||||
self, training_run_id: types.ModelID, checkpoint_id: str
|
||||
) -> AwaitableConcurrentFuture[None]:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue