mirror of
https://github.com/thinking-machines-lab/tinker.git
synced 2026-04-23 16:54:58 +00:00
initial export to hub
This commit is contained in:
parent
ca40e08bb4
commit
d6aa21005b
2 changed files with 239 additions and 0 deletions
|
|
@ -288,6 +288,56 @@ async def get_checkpoint_archive_url_from_tinker_path_async(
|
||||||
|
|
||||||
Async version of get_checkpoint_archive_url_from_tinker_path.
|
Async version of get_checkpoint_archive_url_from_tinker_path.
|
||||||
|
|
||||||
|
#### `export_checkpoint_to_hub`
|
||||||
|
|
||||||
|
```python
|
||||||
|
def export_checkpoint_to_hub(
|
||||||
|
tinker_path: str,
|
||||||
|
repo_id: str | None,
|
||||||
|
*,
|
||||||
|
private: bool = True,
|
||||||
|
token: str | None = None,
|
||||||
|
revision: str | None = None,
|
||||||
|
commit_message: str | None = None,
|
||||||
|
create_pr: bool = False,
|
||||||
|
exist_ok: bool = True,
|
||||||
|
allow_patterns: Sequence[str] | None = None,
|
||||||
|
ignore_patterns: Sequence[str] | None = None,
|
||||||
|
add_model_card: bool = True,
|
||||||
|
) -> str
|
||||||
|
```
|
||||||
|
|
||||||
|
Download a checkpoint archive, extract the PEFT adapter files, optionally add a README.md
|
||||||
|
model card, and upload to the Hugging Face Hub.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- `tinker_path`: The tinker path to the checkpoint (e.g., "tinker://run-id/weights/0001")
|
||||||
|
- `repo_id`: Hugging Face repo ID (e.g., "username/my-lora-adapter"). If None,
|
||||||
|
a name is derived from the base model and checkpoint path.
|
||||||
|
- `private`: Whether to create the repo as private (default True)
|
||||||
|
- `token`: Hugging Face access token (optional)
|
||||||
|
- `revision`: Target branch/revision to upload to (optional)
|
||||||
|
- `commit_message`: Commit message for the upload (optional)
|
||||||
|
- `create_pr`: Whether to create a PR instead of pushing to the main branch
|
||||||
|
- `exist_ok`: Whether repo creation should succeed if repo exists
|
||||||
|
- `allow_patterns`: Optional list of file patterns to include
|
||||||
|
- `ignore_patterns`: Optional list of file patterns to exclude
|
||||||
|
- `add_model_card`: Whether to add a README.md if missing (default True)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- The repo_id that was uploaded to
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
rest_client = service_client.create_rest_client()
|
||||||
|
repo_id = rest_client.export_checkpoint_to_hub(
|
||||||
|
"tinker://run-id/weights/final",
|
||||||
|
"username/my-lora-adapter",
|
||||||
|
private=True,
|
||||||
|
)
|
||||||
|
print(f"Uploaded to: {repo_id}")
|
||||||
|
```
|
||||||
|
|
||||||
#### `publish_checkpoint_from_tinker_path`
|
#### `publish_checkpoint_from_tinker_path`
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Sequence
|
||||||
from concurrent.futures import Future as ConcurrentFuture
|
from concurrent.futures import Future as ConcurrentFuture
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
|
@ -397,6 +398,194 @@ class RestClient(TelemetryProvider):
|
||||||
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
|
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@sync_only
@capture_exceptions(fatal=True)
def export_checkpoint_to_hub(
    self,
    tinker_path: str,
    repo_id: str | None,
    *,
    private: bool = True,
    token: str | None = None,
    revision: str | None = None,
    commit_message: str | None = None,
    create_pr: bool = False,
    exist_ok: bool = True,
    allow_patterns: Sequence[str] | None = None,
    ignore_patterns: Sequence[str] | None = None,
    add_model_card: bool = True,
) -> str:
    """Export a checkpoint archive to the Hugging Face Hub as a PEFT adapter.

    This downloads the checkpoint archive, extracts it locally, optionally adds a
    README.md model card, and uploads the folder to the Hub.

    Args:
    - `tinker_path`: The tinker path to the checkpoint (e.g., "tinker://run-id/weights/0001")
    - `repo_id`: Hugging Face repo ID (e.g., "username/my-lora-adapter"). If None,
        a name is derived from the base model and checkpoint path.
    - `private`: Whether to create the repo as private (default True)
    - `token`: Hugging Face access token (optional)
    - `revision`: Target branch/revision to upload to (optional)
    - `commit_message`: Commit message for the upload (optional)
    - `create_pr`: Whether to create a PR instead of pushing to the main branch
    - `exist_ok`: Whether repo creation should succeed if repo exists
    - `allow_patterns`: Optional list of file patterns to include
    - `ignore_patterns`: Optional list of file patterns to exclude
    - `add_model_card`: Whether to add a README.md if missing (default True)

    Returns:
    - The repo_id that was uploaded to
    """
    # Lazy imports to keep base SDK lightweight
    try:
        from huggingface_hub import HfApi
    except ImportError as exc:  # pragma: no cover - optional dependency
        raise ImportError(
            "huggingface_hub is required for export_checkpoint_to_hub. "
            "Install it with: pip install huggingface_hub"
        ) from exc

    import os
    import tarfile
    import tempfile
    import urllib.request
    from pathlib import Path

    # Validate the tinker path once and reuse the parsed result below
    # (the original parsed it twice).
    parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)

    def _safe_extract(tar: tarfile.TarFile, path: Path) -> None:
        # Reject members that would escape the extraction directory
        # (path traversal via "../" or absolute names). os.path.commonpath
        # is used instead of a raw startswith() prefix check so a sibling
        # directory such as "extract-evil" cannot slip past "extract".
        root = str(path.resolve())
        for member in tar.getmembers():
            member_path = str((path / member.name).resolve())
            if os.path.commonpath([root, member_path]) != root:
                raise ValueError("Unsafe path in tar archive")
        tar.extractall(path=path)

    def _sanitize_repo_name(value: str) -> str:
        # Keep alphanumerics plus "-", "_", "."; map everything else to "-".
        safe_chars = [
            ch if ch.isalnum() or ch in {"-", "_", "."} else "-" for ch in value
        ]
        # Collapse repeated separators
        name = "".join(safe_chars)
        while "--" in name:
            name = name.replace("--", "-")
        return name.strip("-_ .")

    url_response = self.get_checkpoint_archive_url_from_tinker_path(tinker_path).result()

    with tempfile.TemporaryDirectory() as temp_dir:
        temp_root = Path(temp_dir)
        archive_path = temp_root / "checkpoint.tar"
        extract_dir = temp_root / "extract"
        extract_dir.mkdir(parents=True, exist_ok=True)

        # Download archive in fixed-size chunks to bound memory use
        with urllib.request.urlopen(url_response.url, timeout=60) as response:
            with open(archive_path, "wb") as f:
                while chunk := response.read(8192):
                    f.write(chunk)

        # Extract archive
        with tarfile.open(archive_path, "r") as tar:
            _safe_extract(tar, extract_dir)

        # Validate PEFT adapter files exist
        adapter_config = extract_dir / "adapter_config.json"
        adapter_safetensors = extract_dir / "adapter_model.safetensors"
        adapter_bin = extract_dir / "adapter_model.bin"
        if not adapter_config.exists() or not (adapter_safetensors.exists() or adapter_bin.exists()):
            raise ValueError(
                "Checkpoint archive does not contain a PEFT adapter. "
                "Expected adapter_config.json and adapter_model.safetensors (or adapter_model.bin)."
            )

        # Best-effort metadata lookup used for repo naming and the model card;
        # the export should still succeed if this service call fails.
        base_model = "unknown"
        lora_rank = None
        train_mlp = None
        train_attn = None
        train_unembed = None
        try:
            weights_info = self.get_weights_info_by_tinker_path(tinker_path).result()
            base_model = weights_info.base_model
            lora_rank = weights_info.lora_rank
            train_mlp = weights_info.train_mlp
            train_attn = weights_info.train_attn
            train_unembed = weights_info.train_unembed
        except Exception:
            # Deliberate best-effort: metadata is optional.
            pass

        if repo_id is None:
            base_short = base_model.split("/")[-1] if base_model != "unknown" else "adapter"
            checkpoint_id = parsed_tinker_path.checkpoint_id.replace("/", "-")
            derived = f"tinker-{base_short}-{parsed_tinker_path.training_run_id}-{checkpoint_id}"
            repo_id = _sanitize_repo_name(derived)

        # Add a lightweight model card if missing
        readme_path = extract_dir / "README.md"
        if add_model_card and not readme_path.exists():
            tags: list[str] = ["tinker", "peft", "lora", "transformers"]
            if base_model != "unknown":
                tags.append(f"base_model:adapter:{base_model}")
            # YAML front matter first, then the markdown body.
            model_card = [
                "---",
                f"base_model: {base_model}",
                "library_name: peft",
                "tags:",
            ]
            for tag in tags:
                model_card.append(f"- {tag}")
            model_card.append(f"tinker_path: {tinker_path}")
            model_card.extend(
                [
                    "---",
                    "",
                    "# LoRA Adapter (Tinker)",
                    "",
                    "This repository contains a LoRA adapter exported from Tinker.",
                    "",
                    "## Source",
                    "",
                    f"- Tinker checkpoint: {tinker_path}",
                    "",
                    "## Details",
                    "",
                    f"- Base model: {base_model}",
                ]
            )
            if lora_rank is not None:
                model_card.append(f"- LoRA rank: {lora_rank}")
            if train_mlp is not None or train_attn is not None or train_unembed is not None:
                model_card.append(
                    f"- Trained modules: attn={train_attn}, mlp={train_mlp}, unembed={train_unembed}"
                )
            model_card.append("")
            readme_path.write_text("\n".join(model_card), encoding="utf-8")

        api = HfApi(token=token)
        api.create_repo(repo_id=repo_id, private=private, exist_ok=exist_ok)

        api.upload_folder(
            folder_path=os.fspath(extract_dir),
            repo_id=repo_id,
            path_in_repo="",
            revision=revision,
            commit_message=commit_message,
            create_pr=create_pr,
            allow_patterns=list(allow_patterns) if allow_patterns else None,
            ignore_patterns=list(ignore_patterns) if ignore_patterns else None,
        )

    return repo_id
|
||||||
|
|
||||||
def _publish_checkpoint_submit(
|
def _publish_checkpoint_submit(
|
||||||
self, training_run_id: types.ModelID, checkpoint_id: str
|
self, training_run_id: types.ModelID, checkpoint_id: str
|
||||||
) -> AwaitableConcurrentFuture[None]:
|
) -> AwaitableConcurrentFuture[None]:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue