Sync contents

This commit is contained in:
Daniel Xu 2025-12-15 01:00:20 +00:00
parent 5ad4282c96
commit e7a0d0ca2d
15 changed files with 91 additions and 31 deletions

View file

@ -1,4 +1,4 @@
{
"last_synced_sha": "2f8a0c29b227897c25451163c8fa90adb2daa6e6",
"last_sync_time": "2025-12-08T00:05:28.592609"
"last_synced_sha": "6752e178932fa060b6a916ff0e2aefd1d0410970",
"last_sync_time": "2025-12-15T01:00:20.438218"
}

View file

@ -23,7 +23,7 @@ Args:
Example:
```python
sampling_client = service_client.create_sampling_client(base_model="Qwen/Qwen2.5-7B")
sampling_client = service_client.create_sampling_client(base_model="Qwen/Qwen3-8B")
prompt = types.ModelInput.from_ints(tokenizer.encode("The weather today is"))
params = types.SamplingParams(max_tokens=20, temperature=0.7)
future = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1)

View file

@ -74,7 +74,7 @@ def create_lora_training_client(
Create a TrainingClient for LoRA fine-tuning.
Args:
- `base_model`: Name of the base model to fine-tune (e.g., "Qwen/Qwen2.5-7B")
- `base_model`: Name of the base model to fine-tune (e.g., "Qwen/Qwen3-8B")
- `rank`: LoRA rank controlling the size of adaptation matrices (default 32)
- `seed`: Random seed for initialization. None means random seed.
- `train_mlp`: Whether to train MLP layers (default True)
@ -88,7 +88,7 @@ Returns:
Example:
```python
training_client = service_client.create_lora_training_client(
base_model="Qwen/Qwen2.5-7B",
base_model="Qwen/Qwen3-8B",
rank=16,
train_mlp=True,
train_attn=True
@ -203,7 +203,7 @@ Create a SamplingClient for text generation.
Args:
- `model_path`: Path to saved model weights (e.g., "tinker://run-id/weights/checkpoint-001")
- `base_model`: Name of base model to use (e.g., "Qwen/Qwen2.5-7B")
- `base_model`: Name of base model to use (e.g., "Qwen/Qwen3-8B")
- `retry_config`: Optional configuration for retrying failed requests
Returns:
@ -216,7 +216,7 @@ Example:
```python
# Use a base model
sampling_client = service_client.create_sampling_client(
base_model="Qwen/Qwen2.5-7B"
base_model="Qwen/Qwen3-8B"
)
# Or use saved weights

View file

@ -21,7 +21,7 @@ Args:
Example:
```python
training_client = service_client.create_lora_training_client(base_model="Qwen/Qwen2.5-7B")
training_client = service_client.create_lora_training_client(base_model="Qwen/Qwen3-8B")
fwdbwd_future = training_client.forward_backward(training_data, "cross_entropy")
optim_future = training_client.optim_step(types.AdamParams(learning_rate=1e-4))
fwdbwd_result = fwdbwd_future.result() # Wait for gradients

View file

@ -84,6 +84,8 @@ class AsyncTinker(AsyncAPIClient):
raise TinkerError(
"The api_key client option must be set either by passing api_key to the client or by setting the TINKER_API_KEY environment variable"
)
if not api_key.startswith("tml-"):
raise TinkerError("The api_key must start with the 'tml-' prefix")
self.api_key = api_key
if base_url is None:

View file

@ -36,19 +36,19 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
The SamplingClient lets you generate text tokens from either a base model or from weights
you've saved using a TrainingClient. You typically get one by calling
`service_client.create_sampling_client()` or `training_client.save_weights_and_get_sampling_client()`.
Key methods:
- sample() - generate text completions with customizable parameters
- compute_logprobs() - get log probabilities for prompt tokens
Args:
- `holder`: Internal client managing HTTP connections and async operations
Create method parameters:
- `model_path`: Path to saved model weights (starts with 'tinker://')
- `base_model`: Name of base model to use for inference
- `base_model`: Name of base model to use for inference (e.g., 'Qwen/Qwen3-8B')
- `retry_config`: Configuration for retrying failed requests
Example:
```python
sampling_client = service_client.create_sampling_client(base_model="Qwen/Qwen2.5-7B")
sampling_client = service_client.create_sampling_client(base_model="Qwen/Qwen3-8B")
prompt = types.ModelInput.from_ints(tokenizer.encode("The weather today is"))
params = types.SamplingParams(max_tokens=20, temperature=0.7)
future = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1)
@ -300,7 +300,9 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
def get_telemetry(self) -> Telemetry | None:
return self.holder.get_telemetry()
def on_queue_state_change(self, queue_state: QueueState, queue_state_reason: str | None) -> None:
def on_queue_state_change(
self, queue_state: QueueState, queue_state_reason: str | None
) -> None:
QUEUE_STATE_LOG_INTERVAL = 60
if queue_state == QueueState.ACTIVE:
return

View file

@ -143,7 +143,9 @@ class ServiceClient(TelemetryProvider):
model_id = create_model_response.model_id
from .training_client import TrainingClient
training_client = TrainingClient(self.holder, model_seq_id=model_seq_id, model_id=model_id)
training_client = TrainingClient(
self.holder, model_seq_id=model_seq_id, model_id=model_id
)
logger.info(f"TrainingClient initialized for model {model_id}")
return training_client
@ -164,7 +166,7 @@ class ServiceClient(TelemetryProvider):
"""Create a TrainingClient for LoRA fine-tuning.
Args:
- `base_model`: Name of the base model to fine-tune (e.g., "Qwen/Qwen2.5-7B")
- `base_model`: Name of the base model to fine-tune (e.g., "Qwen/Qwen3-8B")
- `rank`: LoRA rank controlling the size of adaptation matrices (default 32)
- `seed`: Random seed for initialization. None means random seed.
- `train_mlp`: Whether to train MLP layers (default True)
@ -178,7 +180,7 @@ class ServiceClient(TelemetryProvider):
Example:
```python
training_client = service_client.create_lora_training_client(
base_model="Qwen/Qwen2.5-7B",
base_model="Qwen/Qwen3-8B",
rank=16,
train_mlp=True,
train_attn=True
@ -352,7 +354,7 @@ class ServiceClient(TelemetryProvider):
Args:
- `model_path`: Path to saved model weights (e.g., "tinker://run-id/weights/checkpoint-001")
- `base_model`: Name of base model to use (e.g., "Qwen/Qwen2.5-7B")
- `base_model`: Name of base model to use (e.g., "Qwen/Qwen3-8B")
- `retry_config`: Optional configuration for retrying failed requests
Returns:
@ -365,7 +367,7 @@ class ServiceClient(TelemetryProvider):
```python
# Use a base model
sampling_client = service_client.create_sampling_client(
base_model="Qwen/Qwen2.5-7B"
base_model="Qwen/Qwen3-8B"
)
# Or use saved weights

View file

@ -44,6 +44,8 @@ MAX_CHUNK_LEN = 1024
MAX_CHUNK_BYTES_COUNT = 5000000
MODEL_ID_NOT_SET_ERROR = "model_id must be set before calling forward. Try initializing the TrainingClient with a model_id by either calling create_lora_training_client on the ServiceClient, or initializing the TrainingClient with an existing model_id."
# Type alias for custom loss functions.
# Args: (data: List[Datum], model_outputs: List[Any]) -> (loss: Any, metrics: Dict[str, float])
CustomLossFnV1 = Callable[[List[types.Datum], List[Any]], Tuple[Any, Dict[str, float]]]
@ -63,7 +65,7 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
Example:
```python
training_client = service_client.create_lora_training_client(base_model="Qwen/Qwen2.5-7B")
training_client = service_client.create_lora_training_client(base_model="Qwen/Qwen3-8B")
fwdbwd_future = training_client.forward_backward(training_data, "cross_entropy")
optim_future = training_client.optim_step(types.AdamParams(learning_rate=1e-4))
fwdbwd_result = fwdbwd_future.result() # Wait for gradients
@ -122,7 +124,9 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
return self.model_id
def _estimate_bytes_count(self, datum: types.Datum) -> int:
return self.holder.estimate_bytes_count_in_model_input(datum.model_input) + sum(len(value.data) * 10 for _, value in datum.loss_fn_inputs.items())
return self.holder.estimate_bytes_count_in_model_input(datum.model_input) + sum(
len(value.data) * 10 for _, value in datum.loss_fn_inputs.items()
)
def _chunked_requests_generator(
self, data: List[types.Datum]
@ -831,7 +835,9 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
def get_telemetry(self) -> Telemetry | None:
return self.holder.get_telemetry()
def on_queue_state_change(self, queue_state: QueueState, queue_state_reason: str | None) -> None:
def on_queue_state_change(
self, queue_state: QueueState, queue_state_reason: str | None
) -> None:
QUEUE_STATE_LOG_INTERVAL = 60
if queue_state == QueueState.ACTIVE:
return

View file

@ -37,18 +37,37 @@ def is_retryable_status_code(status_code: int) -> bool:
@dataclass
class RetryConfig:
"""Configuration for retry behavior on failed API requests.
Controls connection limits, timeouts, and exponential backoff parameters
for automatic request retries.
"""
max_connections: int = DEFAULT_CONNECTION_LIMITS.max_connections or 100
"""Maximum number of concurrent connections allowed."""
progress_timeout: float = 120 * 60 # Very long straggler
"""Timeout in seconds before failing if no progress is made."""
retry_delay_base: float = INITIAL_RETRY_DELAY
"""Initial delay in seconds before first retry."""
retry_delay_max: float = MAX_RETRY_DELAY
"""Maximum delay in seconds between retries."""
jitter_factor: float = 0.25
"""Random jitter factor (0-1) applied to retry delays."""
enable_retry_logic: bool = True
"""Whether to enable automatic retries on failure."""
retryable_exceptions: tuple[Type[Exception], ...] = (
asyncio.TimeoutError,
tinker.APIConnectionError,
httpx.TimeoutException,
RetryableException,
)
"""Exception types that should trigger a retry."""
def __post_init__(self):
if self.max_connections <= 0:

View file

@ -15,11 +15,13 @@ class CreateModelRequest(StrictBase):
model_seq_id: int
base_model: str
"""The name of the base model to fine-tune (e.g., 'Qwen/Qwen3-8B')."""
"""Optional metadata about this model/training run, set by the end-user"""
user_metadata: Optional[dict[str, Any]] = None
"""Optional metadata about this model/training run, set by the end-user."""
lora_config: Optional[LoraConfig] = None
"""LoRA configuration"""
type: Literal["create_model"] = "create_model"

View file

@ -8,7 +8,7 @@ __all__ = ["ForwardBackwardOutput"]
class ForwardBackwardOutput(BaseModel):
loss_fn_output_type: str
"""The type of the ForwardBackward output. Can be one of [...] TODO"""
"""The class name of the loss function output records (e.g., 'TorchLossReturn', 'ArrayRecord')."""
loss_fn_outputs: List[LossFnOutput]
"""Dictionary mapping field names to tensor data"""

View file

@ -8,25 +8,38 @@ __all__ = ["GetInfoResponse", "ModelData"]
class ModelData(BaseModel):
"""Metadata about a model's architecture and configuration."""
arch: Optional[str] = None
"""The model architecture identifier."""
model_name: Optional[str] = None
"""The human-readable model name."""
tokenizer_id: Optional[str] = None
"""The identifier of the tokenizer used by this model."""
class GetInfoResponse(BaseModel):
"""Response containing information about a training client's model."""
type: Optional[Literal["get_info"]] = None
"""Response type identifier."""
model_data: ModelData
"""Detailed metadata about the model."""
model_id: ModelID
"""Unique identifier for the model."""
is_lora: Optional[bool] = None
"""Whether this is a LoRA fine-tuned model."""
lora_rank: Optional[int] = None
"""The rank of the LoRA adaptation, if applicable."""
model_name: Optional[str] = None
"""The name of the model."""
if PYDANTIC_V2:
# allow fields with a `model_` prefix

View file

@ -6,8 +6,14 @@ __all__ = ["GetServerCapabilitiesResponse", "SupportedModel"]
class SupportedModel(BaseModel):
"""Information about a model supported by the server."""
model_name: Optional[str] = None
"""The name of the supported model."""
class GetServerCapabilitiesResponse(BaseModel):
"""Response containing the server's supported models and capabilities."""
supported_models: List[SupportedModel]
"""List of models available on the server."""

View file

@ -43,7 +43,7 @@ def pytest_collection_modifyitems(items: list[pytest.Function]) -> None:
base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
api_key = "My API Key"
api_key = "tml-My API Key"
@pytest.fixture(scope="session")

View file

@ -35,7 +35,7 @@ from tinker._base_client import (
from .utils import update_env
base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
api_key = "My API Key"
api_key = "tml-My API Key"
def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]:
@ -97,9 +97,9 @@ class TestTinker:
copied = self.client.copy()
assert id(copied) != id(self.client)
copied = self.client.copy(api_key="another My API Key")
assert copied.api_key == "another My API Key"
assert self.client.api_key == "My API Key"
copied = self.client.copy(api_key="tml-another My API Key")
assert copied.api_key == "tml-another My API Key"
assert self.client.api_key == "tml-My API Key"
def test_copy_default_options(self) -> None:
# options that have a default are overridden correctly
@ -353,6 +353,10 @@ class TestTinker:
client2 = Tinker(base_url=base_url, api_key=None, _strict_response_validation=True)
_ = client2
def test_api_key_prefix_validation(self) -> None:
with pytest.raises(TinkerError):
Tinker(base_url=base_url, api_key="not-tml-prefix", _strict_response_validation=True)
def test_default_query_option(self) -> None:
client = Tinker(
base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
@ -1076,9 +1080,9 @@ class TestAsyncTinker:
copied = self.client.copy()
assert id(copied) != id(self.client)
copied = self.client.copy(api_key="another My API Key")
assert copied.api_key == "another My API Key"
assert self.client.api_key == "My API Key"
copied = self.client.copy(api_key="tml-another My API Key")
assert copied.api_key == "tml-another My API Key"
assert self.client.api_key == "tml-My API Key"
def test_copy_default_options(self) -> None:
# options that have a default are overridden correctly
@ -1334,6 +1338,10 @@ class TestAsyncTinker:
client2 = AsyncTinker(base_url=base_url, api_key=None, _strict_response_validation=True)
_ = client2
def test_api_key_prefix_validation(self) -> None:
with pytest.raises(TinkerError):
AsyncTinker(base_url=base_url, api_key="not-tml-prefix", _strict_response_validation=True)
def test_default_query_option(self) -> None:
client = AsyncTinker(
base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}