mirror of
https://github.com/thinking-machines-lab/tinker.git
synced 2026-04-19 12:58:01 +00:00
Sync contents
This commit is contained in:
parent
5ad4282c96
commit
e7a0d0ca2d
15 changed files with 91 additions and 31 deletions
|
|
@ -1,4 +1,4 @@
|
|||
{
|
||||
"last_synced_sha": "2f8a0c29b227897c25451163c8fa90adb2daa6e6",
|
||||
"last_sync_time": "2025-12-08T00:05:28.592609"
|
||||
"last_synced_sha": "6752e178932fa060b6a916ff0e2aefd1d0410970",
|
||||
"last_sync_time": "2025-12-15T01:00:20.438218"
|
||||
}
|
||||
|
|
@ -23,7 +23,7 @@ Args:
|
|||
|
||||
Example:
|
||||
```python
|
||||
sampling_client = service_client.create_sampling_client(base_model="Qwen/Qwen2.5-7B")
|
||||
sampling_client = service_client.create_sampling_client(base_model="Qwen/Qwen3-8B")
|
||||
prompt = types.ModelInput.from_ints(tokenizer.encode("The weather today is"))
|
||||
params = types.SamplingParams(max_tokens=20, temperature=0.7)
|
||||
future = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1)
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ def create_lora_training_client(
|
|||
Create a TrainingClient for LoRA fine-tuning.
|
||||
|
||||
Args:
|
||||
- `base_model`: Name of the base model to fine-tune (e.g., "Qwen/Qwen2.5-7B")
|
||||
- `base_model`: Name of the base model to fine-tune (e.g., "Qwen/Qwen3-8B")
|
||||
- `rank`: LoRA rank controlling the size of adaptation matrices (default 32)
|
||||
- `seed`: Random seed for initialization. None means random seed.
|
||||
- `train_mlp`: Whether to train MLP layers (default True)
|
||||
|
|
@ -88,7 +88,7 @@ Returns:
|
|||
Example:
|
||||
```python
|
||||
training_client = service_client.create_lora_training_client(
|
||||
base_model="Qwen/Qwen2.5-7B",
|
||||
base_model="Qwen/Qwen3-8B",
|
||||
rank=16,
|
||||
train_mlp=True,
|
||||
train_attn=True
|
||||
|
|
@ -203,7 +203,7 @@ Create a SamplingClient for text generation.
|
|||
|
||||
Args:
|
||||
- `model_path`: Path to saved model weights (e.g., "tinker://run-id/weights/checkpoint-001")
|
||||
- `base_model`: Name of base model to use (e.g., "Qwen/Qwen2.5-7B")
|
||||
- `base_model`: Name of base model to use (e.g., "Qwen/Qwen3-8B")
|
||||
- `retry_config`: Optional configuration for retrying failed requests
|
||||
|
||||
Returns:
|
||||
|
|
@ -216,7 +216,7 @@ Example:
|
|||
```python
|
||||
# Use a base model
|
||||
sampling_client = service_client.create_sampling_client(
|
||||
base_model="Qwen/Qwen2.5-7B"
|
||||
base_model="Qwen/Qwen3-8B"
|
||||
)
|
||||
|
||||
# Or use saved weights
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ Args:
|
|||
|
||||
Example:
|
||||
```python
|
||||
training_client = service_client.create_lora_training_client(base_model="Qwen/Qwen2.5-7B")
|
||||
training_client = service_client.create_lora_training_client(base_model="Qwen/Qwen3-8B")
|
||||
fwdbwd_future = training_client.forward_backward(training_data, "cross_entropy")
|
||||
optim_future = training_client.optim_step(types.AdamParams(learning_rate=1e-4))
|
||||
fwdbwd_result = fwdbwd_future.result() # Wait for gradients
|
||||
|
|
|
|||
|
|
@ -84,6 +84,8 @@ class AsyncTinker(AsyncAPIClient):
|
|||
raise TinkerError(
|
||||
"The api_key client option must be set either by passing api_key to the client or by setting the TINKER_API_KEY environment variable"
|
||||
)
|
||||
if not api_key.startswith("tml-"):
|
||||
raise TinkerError("The api_key must start with the 'tml-' prefix")
|
||||
self.api_key = api_key
|
||||
|
||||
if base_url is None:
|
||||
|
|
|
|||
|
|
@ -36,19 +36,19 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
|
|||
The SamplingClient lets you generate text tokens from either a base model or from weights
|
||||
you've saved using a TrainingClient. You typically get one by calling
|
||||
`service_client.create_sampling_client()` or `training_client.save_weights_and_get_sampling_client()`.
|
||||
|
||||
Key methods:
|
||||
- sample() - generate text completions with customizable parameters
|
||||
- compute_logprobs() - get log probabilities for prompt tokens
|
||||
|
||||
Args:
|
||||
- `holder`: Internal client managing HTTP connections and async operations
|
||||
Create method parameters:
|
||||
- `model_path`: Path to saved model weights (starts with 'tinker://')
|
||||
- `base_model`: Name of base model to use for inference
|
||||
- `base_model`: Name of base model to use for inference (e.g., 'Qwen/Qwen3-8B')
|
||||
- `retry_config`: Configuration for retrying failed requests
|
||||
|
||||
Example:
|
||||
```python
|
||||
sampling_client = service_client.create_sampling_client(base_model="Qwen/Qwen2.5-7B")
|
||||
sampling_client = service_client.create_sampling_client(base_model="Qwen/Qwen3-8B")
|
||||
prompt = types.ModelInput.from_ints(tokenizer.encode("The weather today is"))
|
||||
params = types.SamplingParams(max_tokens=20, temperature=0.7)
|
||||
future = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1)
|
||||
|
|
@ -300,7 +300,9 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
|
|||
def get_telemetry(self) -> Telemetry | None:
|
||||
return self.holder.get_telemetry()
|
||||
|
||||
def on_queue_state_change(self, queue_state: QueueState, queue_state_reason: str | None) -> None:
|
||||
def on_queue_state_change(
|
||||
self, queue_state: QueueState, queue_state_reason: str | None
|
||||
) -> None:
|
||||
QUEUE_STATE_LOG_INTERVAL = 60
|
||||
if queue_state == QueueState.ACTIVE:
|
||||
return
|
||||
|
|
|
|||
|
|
@ -143,7 +143,9 @@ class ServiceClient(TelemetryProvider):
|
|||
model_id = create_model_response.model_id
|
||||
from .training_client import TrainingClient
|
||||
|
||||
training_client = TrainingClient(self.holder, model_seq_id=model_seq_id, model_id=model_id)
|
||||
training_client = TrainingClient(
|
||||
self.holder, model_seq_id=model_seq_id, model_id=model_id
|
||||
)
|
||||
logger.info(f"TrainingClient initialized for model {model_id}")
|
||||
return training_client
|
||||
|
||||
|
|
@ -164,7 +166,7 @@ class ServiceClient(TelemetryProvider):
|
|||
"""Create a TrainingClient for LoRA fine-tuning.
|
||||
|
||||
Args:
|
||||
- `base_model`: Name of the base model to fine-tune (e.g., "Qwen/Qwen2.5-7B")
|
||||
- `base_model`: Name of the base model to fine-tune (e.g., "Qwen/Qwen3-8B")
|
||||
- `rank`: LoRA rank controlling the size of adaptation matrices (default 32)
|
||||
- `seed`: Random seed for initialization. None means random seed.
|
||||
- `train_mlp`: Whether to train MLP layers (default True)
|
||||
|
|
@ -178,7 +180,7 @@ class ServiceClient(TelemetryProvider):
|
|||
Example:
|
||||
```python
|
||||
training_client = service_client.create_lora_training_client(
|
||||
base_model="Qwen/Qwen2.5-7B",
|
||||
base_model="Qwen/Qwen3-8B",
|
||||
rank=16,
|
||||
train_mlp=True,
|
||||
train_attn=True
|
||||
|
|
@ -352,7 +354,7 @@ class ServiceClient(TelemetryProvider):
|
|||
|
||||
Args:
|
||||
- `model_path`: Path to saved model weights (e.g., "tinker://run-id/weights/checkpoint-001")
|
||||
- `base_model`: Name of base model to use (e.g., "Qwen/Qwen2.5-7B")
|
||||
- `base_model`: Name of base model to use (e.g., "Qwen/Qwen3-8B")
|
||||
- `retry_config`: Optional configuration for retrying failed requests
|
||||
|
||||
Returns:
|
||||
|
|
@ -365,7 +367,7 @@ class ServiceClient(TelemetryProvider):
|
|||
```python
|
||||
# Use a base model
|
||||
sampling_client = service_client.create_sampling_client(
|
||||
base_model="Qwen/Qwen2.5-7B"
|
||||
base_model="Qwen/Qwen3-8B"
|
||||
)
|
||||
|
||||
# Or use saved weights
|
||||
|
|
|
|||
|
|
@ -44,6 +44,8 @@ MAX_CHUNK_LEN = 1024
|
|||
MAX_CHUNK_BYTES_COUNT = 5000000
|
||||
# User-facing message for the case where a model_id has not been set on the TrainingClient.
MODEL_ID_NOT_SET_ERROR = "model_id must be set before calling forward. Try initializing the TrainingClient with a model_id by either calling create_lora_training_client on the ServiceClient, or initializing the TrainingClient with an existing model_id."
|
||||
|
||||
# Type alias for custom loss functions.
|
||||
# Args: (data: List[Datum], model_outputs: List[Any]) -> (loss: Any, metrics: Dict[str, float])
|
||||
CustomLossFnV1 = Callable[[List[types.Datum], List[Any]], Tuple[Any, Dict[str, float]]]
|
||||
|
||||
|
||||
|
|
@ -63,7 +65,7 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
|
|||
|
||||
Example:
|
||||
```python
|
||||
training_client = service_client.create_lora_training_client(base_model="Qwen/Qwen2.5-7B")
|
||||
training_client = service_client.create_lora_training_client(base_model="Qwen/Qwen3-8B")
|
||||
fwdbwd_future = training_client.forward_backward(training_data, "cross_entropy")
|
||||
optim_future = training_client.optim_step(types.AdamParams(learning_rate=1e-4))
|
||||
fwdbwd_result = fwdbwd_future.result() # Wait for gradients
|
||||
|
|
@ -122,7 +124,9 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
|
|||
return self.model_id
|
||||
|
||||
def _estimate_bytes_count(self, datum: types.Datum) -> int:
|
||||
return self.holder.estimate_bytes_count_in_model_input(datum.model_input) + sum(len(value.data) * 10 for _, value in datum.loss_fn_inputs.items())
|
||||
return self.holder.estimate_bytes_count_in_model_input(datum.model_input) + sum(
|
||||
len(value.data) * 10 for _, value in datum.loss_fn_inputs.items()
|
||||
)
|
||||
|
||||
def _chunked_requests_generator(
|
||||
self, data: List[types.Datum]
|
||||
|
|
@ -831,7 +835,9 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
|
|||
def get_telemetry(self) -> Telemetry | None:
|
||||
return self.holder.get_telemetry()
|
||||
|
||||
def on_queue_state_change(self, queue_state: QueueState, queue_state_reason: str | None) -> None:
|
||||
def on_queue_state_change(
|
||||
self, queue_state: QueueState, queue_state_reason: str | None
|
||||
) -> None:
|
||||
QUEUE_STATE_LOG_INTERVAL = 60
|
||||
if queue_state == QueueState.ACTIVE:
|
||||
return
|
||||
|
|
|
|||
|
|
@ -37,18 +37,37 @@ def is_retryable_status_code(status_code: int) -> bool:
|
|||
|
||||
@dataclass
|
||||
class RetryConfig:
|
||||
"""Configuration for retry behavior on failed API requests.
|
||||
|
||||
Controls connection limits, timeouts, and exponential backoff parameters
|
||||
for automatic request retries.
|
||||
"""
|
||||
|
||||
max_connections: int = DEFAULT_CONNECTION_LIMITS.max_connections or 100
|
||||
"""Maximum number of concurrent connections allowed."""
|
||||
|
||||
progress_timeout: float = 120 * 60 # Very long straggler
|
||||
"""Timeout in seconds before failing if no progress is made."""
|
||||
|
||||
retry_delay_base: float = INITIAL_RETRY_DELAY
|
||||
"""Initial delay in seconds before first retry."""
|
||||
|
||||
retry_delay_max: float = MAX_RETRY_DELAY
|
||||
"""Maximum delay in seconds between retries."""
|
||||
|
||||
jitter_factor: float = 0.25
|
||||
"""Random jitter factor (0-1) applied to retry delays."""
|
||||
|
||||
enable_retry_logic: bool = True
|
||||
"""Whether to enable automatic retries on failure."""
|
||||
|
||||
retryable_exceptions: tuple[Type[Exception], ...] = (
|
||||
asyncio.TimeoutError,
|
||||
tinker.APIConnectionError,
|
||||
httpx.TimeoutException,
|
||||
RetryableException,
|
||||
)
|
||||
"""Exception types that should trigger a retry."""
|
||||
|
||||
def __post_init__(self):
|
||||
if self.max_connections <= 0:
|
||||
|
|
|
|||
|
|
@ -15,11 +15,13 @@ class CreateModelRequest(StrictBase):
|
|||
model_seq_id: int
|
||||
|
||||
base_model: str
|
||||
"""The name of the base model to fine-tune (e.g., 'Qwen/Qwen3-8B')."""
|
||||
|
||||
"""Optional metadata about this model/training run, set by the end-user"""
|
||||
user_metadata: Optional[dict[str, Any]] = None
|
||||
"""Optional metadata about this model/training run, set by the end-user."""
|
||||
|
||||
lora_config: Optional[LoraConfig] = None
|
||||
"""LoRA configuration"""
|
||||
|
||||
type: Literal["create_model"] = "create_model"
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ __all__ = ["ForwardBackwardOutput"]
|
|||
|
||||
class ForwardBackwardOutput(BaseModel):
|
||||
loss_fn_output_type: str
|
||||
"""The type of the ForwardBackward output. Can be one of [...] TODO"""
|
||||
"""The class name of the loss function output records (e.g., 'TorchLossReturn', 'ArrayRecord')."""
|
||||
|
||||
loss_fn_outputs: List[LossFnOutput]
|
||||
"""Dictionary mapping field names to tensor data"""
|
||||
|
|
|
|||
|
|
@ -8,25 +8,38 @@ __all__ = ["GetInfoResponse", "ModelData"]
|
|||
|
||||
|
||||
class ModelData(BaseModel):
|
||||
"""Metadata about a model's architecture and configuration."""
|
||||
|
||||
arch: Optional[str] = None
|
||||
"""The model architecture identifier."""
|
||||
|
||||
model_name: Optional[str] = None
|
||||
"""The human-readable model name."""
|
||||
|
||||
tokenizer_id: Optional[str] = None
|
||||
"""The identifier of the tokenizer used by this model."""
|
||||
|
||||
|
||||
class GetInfoResponse(BaseModel):
|
||||
"""Response containing information about a training client's model."""
|
||||
|
||||
type: Optional[Literal["get_info"]] = None
|
||||
"""Response type identifier."""
|
||||
|
||||
model_data: ModelData
|
||||
"""Detailed metadata about the model."""
|
||||
|
||||
model_id: ModelID
|
||||
"""Unique identifier for the model."""
|
||||
|
||||
is_lora: Optional[bool] = None
|
||||
"""Whether this is a LoRA fine-tuned model."""
|
||||
|
||||
lora_rank: Optional[int] = None
|
||||
"""The rank of the LoRA adaptation, if applicable."""
|
||||
|
||||
model_name: Optional[str] = None
|
||||
"""The name of the model."""
|
||||
|
||||
if PYDANTIC_V2:
|
||||
# allow fields with a `model_` prefix
|
||||
|
|
|
|||
|
|
@ -6,8 +6,14 @@ __all__ = ["GetServerCapabilitiesResponse", "SupportedModel"]
|
|||
|
||||
|
||||
class SupportedModel(BaseModel):
|
||||
"""Information about a model supported by the server."""
|
||||
|
||||
model_name: Optional[str] = None
|
||||
"""The name of the supported model."""
|
||||
|
||||
|
||||
class GetServerCapabilitiesResponse(BaseModel):
|
||||
"""Response containing the server's supported models and capabilities."""
|
||||
|
||||
supported_models: List[SupportedModel]
|
||||
"""List of models available on the server."""
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ def pytest_collection_modifyitems(items: list[pytest.Function]) -> None:
|
|||
|
||||
base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
|
||||
|
||||
api_key = "My API Key"
|
||||
api_key = "tml-My API Key"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ from tinker._base_client import (
|
|||
from .utils import update_env
|
||||
|
||||
base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
|
||||
api_key = "My API Key"
|
||||
api_key = "tml-My API Key"
|
||||
|
||||
|
||||
def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]:
|
||||
|
|
@ -97,9 +97,9 @@ class TestTinker:
|
|||
copied = self.client.copy()
|
||||
assert id(copied) != id(self.client)
|
||||
|
||||
copied = self.client.copy(api_key="another My API Key")
|
||||
assert copied.api_key == "another My API Key"
|
||||
assert self.client.api_key == "My API Key"
|
||||
copied = self.client.copy(api_key="tml-another My API Key")
|
||||
assert copied.api_key == "tml-another My API Key"
|
||||
assert self.client.api_key == "tml-My API Key"
|
||||
|
||||
def test_copy_default_options(self) -> None:
|
||||
# options that have a default are overridden correctly
|
||||
|
|
@ -353,6 +353,10 @@ class TestTinker:
|
|||
client2 = Tinker(base_url=base_url, api_key=None, _strict_response_validation=True)
|
||||
_ = client2
|
||||
|
||||
def test_api_key_prefix_validation(self) -> None:
|
||||
with pytest.raises(TinkerError):
|
||||
Tinker(base_url=base_url, api_key="not-tml-prefix", _strict_response_validation=True)
|
||||
|
||||
def test_default_query_option(self) -> None:
|
||||
client = Tinker(
|
||||
base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
|
||||
|
|
@ -1076,9 +1080,9 @@ class TestAsyncTinker:
|
|||
copied = self.client.copy()
|
||||
assert id(copied) != id(self.client)
|
||||
|
||||
copied = self.client.copy(api_key="another My API Key")
|
||||
assert copied.api_key == "another My API Key"
|
||||
assert self.client.api_key == "My API Key"
|
||||
copied = self.client.copy(api_key="tml-another My API Key")
|
||||
assert copied.api_key == "tml-another My API Key"
|
||||
assert self.client.api_key == "tml-My API Key"
|
||||
|
||||
def test_copy_default_options(self) -> None:
|
||||
# options that have a default are overridden correctly
|
||||
|
|
@ -1334,6 +1338,10 @@ class TestAsyncTinker:
|
|||
client2 = AsyncTinker(base_url=base_url, api_key=None, _strict_response_validation=True)
|
||||
_ = client2
|
||||
|
||||
def test_api_key_prefix_validation(self) -> None:
|
||||
with pytest.raises(TinkerError):
|
||||
AsyncTinker(base_url=base_url, api_key="not-tml-prefix", _strict_response_validation=True)
|
||||
|
||||
def test_default_query_option(self) -> None:
|
||||
client = AsyncTinker(
|
||||
base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue