diff --git a/.sync_state b/.sync_state index 8e48452..0a7336c 100644 --- a/.sync_state +++ b/.sync_state @@ -1,4 +1,4 @@ { - "last_synced_sha": "2f8a0c29b227897c25451163c8fa90adb2daa6e6", - "last_sync_time": "2025-12-08T00:05:28.592609" + "last_synced_sha": "6752e178932fa060b6a916ff0e2aefd1d0410970", + "last_sync_time": "2025-12-15T01:00:20.438218" } \ No newline at end of file diff --git a/docs/api/samplingclient.md b/docs/api/samplingclient.md index 4970bd1..e16a0f7 100644 --- a/docs/api/samplingclient.md +++ b/docs/api/samplingclient.md @@ -23,7 +23,7 @@ Args: Example: ```python -sampling_client = service_client.create_sampling_client(base_model="Qwen/Qwen2.5-7B") +sampling_client = service_client.create_sampling_client(base_model="Qwen/Qwen3-8B") prompt = types.ModelInput.from_ints(tokenizer.encode("The weather today is")) params = types.SamplingParams(max_tokens=20, temperature=0.7) future = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1) diff --git a/docs/api/serviceclient.md b/docs/api/serviceclient.md index 83653e0..e0e683f 100644 --- a/docs/api/serviceclient.md +++ b/docs/api/serviceclient.md @@ -74,7 +74,7 @@ def create_lora_training_client( Create a TrainingClient for LoRA fine-tuning. Args: -- `base_model`: Name of the base model to fine-tune (e.g., "Qwen/Qwen2.5-7B") +- `base_model`: Name of the base model to fine-tune (e.g., "Qwen/Qwen3-8B") - `rank`: LoRA rank controlling the size of adaptation matrices (default 32) - `seed`: Random seed for initialization. None means random seed. - `train_mlp`: Whether to train MLP layers (default True) @@ -88,7 +88,7 @@ Returns: Example: ```python training_client = service_client.create_lora_training_client( - base_model="Qwen/Qwen2.5-7B", + base_model="Qwen/Qwen3-8B", rank=16, train_mlp=True, train_attn=True @@ -203,7 +203,7 @@ Create a SamplingClient for text generation. 
Args: - `model_path`: Path to saved model weights (e.g., "tinker://run-id/weights/checkpoint-001") -- `base_model`: Name of base model to use (e.g., "Qwen/Qwen2.5-7B") +- `base_model`: Name of base model to use (e.g., "Qwen/Qwen3-8B") - `retry_config`: Optional configuration for retrying failed requests Returns: @@ -216,7 +216,7 @@ Example: ```python # Use a base model sampling_client = service_client.create_sampling_client( - base_model="Qwen/Qwen2.5-7B" + base_model="Qwen/Qwen3-8B" ) # Or use saved weights diff --git a/docs/api/trainingclient.md b/docs/api/trainingclient.md index d251d16..487f904 100644 --- a/docs/api/trainingclient.md +++ b/docs/api/trainingclient.md @@ -21,7 +21,7 @@ Args: Example: ```python -training_client = service_client.create_lora_training_client(base_model="Qwen/Qwen2.5-7B") +training_client = service_client.create_lora_training_client(base_model="Qwen/Qwen3-8B") fwdbwd_future = training_client.forward_backward(training_data, "cross_entropy") optim_future = training_client.optim_step(types.AdamParams(learning_rate=1e-4)) fwdbwd_result = fwdbwd_future.result() # Wait for gradients diff --git a/src/tinker/_client.py b/src/tinker/_client.py index 1246295..b932b90 100644 --- a/src/tinker/_client.py +++ b/src/tinker/_client.py @@ -84,6 +84,8 @@ class AsyncTinker(AsyncAPIClient): raise TinkerError( "The api_key client option must be set either by passing api_key to the client or by setting the TINKER_API_KEY environment variable" ) + if not api_key.startswith("tml-"): + raise TinkerError("The api_key must start with the 'tml-' prefix") self.api_key = api_key if base_url is None: diff --git a/src/tinker/lib/public_interfaces/sampling_client.py b/src/tinker/lib/public_interfaces/sampling_client.py index fd913fe..31d4cb2 100644 --- a/src/tinker/lib/public_interfaces/sampling_client.py +++ b/src/tinker/lib/public_interfaces/sampling_client.py @@ -36,19 +36,19 @@ class SamplingClient(TelemetryProvider, QueueStateObserver): The SamplingClient lets you generate text tokens from either a base model or from weights you've saved using a TrainingClient. You typically get one by calling `service_client.create_sampling_client()` or `training_client.save_weights_and_get_sampling_client()`. 
+ Key methods: - sample() - generate text completions with customizable parameters - compute_logprobs() - get log probabilities for prompt tokens - Args: - - `holder`: Internal client managing HTTP connections and async operations + Create method parameters: - `model_path`: Path to saved model weights (starts with 'tinker://') - - `base_model`: Name of base model to use for inference + - `base_model`: Name of base model to use for inference (e.g., 'Qwen/Qwen3-8B') - `retry_config`: Configuration for retrying failed requests Example: ```python - sampling_client = service_client.create_sampling_client(base_model="Qwen/Qwen2.5-7B") + sampling_client = service_client.create_sampling_client(base_model="Qwen/Qwen3-8B") prompt = types.ModelInput.from_ints(tokenizer.encode("The weather today is")) params = types.SamplingParams(max_tokens=20, temperature=0.7) future = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1) @@ -300,7 +300,9 @@ class SamplingClient(TelemetryProvider, QueueStateObserver): def get_telemetry(self) -> Telemetry | None: return self.holder.get_telemetry() - def on_queue_state_change(self, queue_state: QueueState, queue_state_reason: str | None) -> None: + def on_queue_state_change( + self, queue_state: QueueState, queue_state_reason: str | None + ) -> None: QUEUE_STATE_LOG_INTERVAL = 60 if queue_state == QueueState.ACTIVE: return diff --git a/src/tinker/lib/public_interfaces/service_client.py b/src/tinker/lib/public_interfaces/service_client.py index 72ee967..21389c0 100644 --- a/src/tinker/lib/public_interfaces/service_client.py +++ b/src/tinker/lib/public_interfaces/service_client.py @@ -143,7 +143,9 @@ class ServiceClient(TelemetryProvider): model_id = create_model_response.model_id from .training_client import TrainingClient - training_client = TrainingClient(self.holder, model_seq_id=model_seq_id, model_id=model_id) + training_client = TrainingClient( + self.holder, model_seq_id=model_seq_id, model_id=model_id + ) logger.info(f"TrainingClient initialized for model {model_id}") return training_client @@ -164,7 +166,7 @@ class ServiceClient(TelemetryProvider): """Create a TrainingClient for LoRA fine-tuning. Args: - - `base_model`: Name of the base model to fine-tune (e.g., "Qwen/Qwen2.5-7B") + - `base_model`: Name of the base model to fine-tune (e.g., "Qwen/Qwen3-8B") - `rank`: LoRA rank controlling the size of adaptation matrices (default 32) - `seed`: Random seed for initialization. None means random seed. 
- `train_mlp`: Whether to train MLP layers (default True)
@@ -178,7 +180,7 @@ class ServiceClient(TelemetryProvider):
Example:
```python
training_client = service_client.create_lora_training_client(
- base_model="Qwen/Qwen2.5-7B",
+ base_model="Qwen/Qwen3-8B",
rank=16,
train_mlp=True,
train_attn=True
@@ -352,7 +354,7 @@ class ServiceClient(TelemetryProvider):
Args:
- `model_path`: Path to saved model weights (e.g., "tinker://run-id/weights/checkpoint-001")
- - `base_model`: Name of base model to use (e.g., "Qwen/Qwen2.5-7B")
+ - `base_model`: Name of base model to use (e.g., "Qwen/Qwen3-8B")
- `retry_config`: Optional configuration for retrying failed requests
Returns:
@@ -365,7 +367,7 @@
```python
# Use a base model
sampling_client = service_client.create_sampling_client(
- base_model="Qwen/Qwen2.5-7B"
+ base_model="Qwen/Qwen3-8B"
)
# Or use saved weights
diff --git a/src/tinker/lib/public_interfaces/training_client.py b/src/tinker/lib/public_interfaces/training_client.py
index 78d8296..6307260 100644
--- a/src/tinker/lib/public_interfaces/training_client.py
+++ b/src/tinker/lib/public_interfaces/training_client.py
@@ -44,6 +44,8 @@ MAX_CHUNK_LEN = 1024
MAX_CHUNK_BYTES_COUNT = 5000000
MODEL_ID_NOT_SET_ERROR = "model_id must be set before calling forward. Try initializing the TrainingClient with a model_id by either calling create_lora_training_client on the ServiceClient, or initializing the TrainingClient with an existing model_id."
+# Type alias for custom loss functions.
+# Args: (data: List[Datum], model_outputs: List[Any]) -> (loss: Any, metrics: Dict[str, float])
CustomLossFnV1 = Callable[[List[types.Datum], List[Any]], Tuple[Any, Dict[str, float]]]
@@ -63,7 +65,7 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
Example:
```python
- training_client = service_client.create_lora_training_client(base_model="Qwen/Qwen2.5-7B")
+ training_client = service_client.create_lora_training_client(base_model="Qwen/Qwen3-8B")
fwdbwd_future = training_client.forward_backward(training_data, "cross_entropy")
optim_future = training_client.optim_step(types.AdamParams(learning_rate=1e-4))
fwdbwd_result = fwdbwd_future.result() # Wait for gradients
@@ -122,7 +124,9 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
return self.model_id
def _estimate_bytes_count(self, datum: types.Datum) -> int:
- return self.holder.estimate_bytes_count_in_model_input(datum.model_input) + sum(len(value.data) * 10 for _, value in datum.loss_fn_inputs.items())
+ return self.holder.estimate_bytes_count_in_model_input(datum.model_input) + sum(
+ len(value.data) * 10 for _, value in datum.loss_fn_inputs.items()
+ )
def _chunked_requests_generator(
self, data: List[types.Datum]
@@ -831,7 +835,9 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
def get_telemetry(self) -> Telemetry | None:
return self.holder.get_telemetry()
- def on_queue_state_change(self, queue_state: QueueState, queue_state_reason: str | None) -> None:
+ def on_queue_state_change(
+ self, queue_state: QueueState, queue_state_reason: str | None
+ ) -> None:
QUEUE_STATE_LOG_INTERVAL = 60
if queue_state == QueueState.ACTIVE:
return
diff --git a/src/tinker/lib/retry_handler.py b/src/tinker/lib/retry_handler.py
index 887abe3..174d825 100644
--- a/src/tinker/lib/retry_handler.py
+++ b/src/tinker/lib/retry_handler.py
@@ -37,18 +37,37 @@ def is_retryable_status_code(status_code: int) -> bool:
@dataclass
class RetryConfig:
+ """Configuration for retry behavior on failed API requests.
+ + Controls connection limits, timeouts, and exponential backoff parameters + for automatic request retries. + """ + max_connections: int = DEFAULT_CONNECTION_LIMITS.max_connections or 100 + """Maximum number of concurrent connections allowed.""" + progress_timeout: float = 120 * 60 # Very long straggler + """Timeout in seconds before failing if no progress is made.""" + retry_delay_base: float = INITIAL_RETRY_DELAY + """Initial delay in seconds before first retry.""" + retry_delay_max: float = MAX_RETRY_DELAY + """Maximum delay in seconds between retries.""" + jitter_factor: float = 0.25 + """Random jitter factor (0-1) applied to retry delays.""" + enable_retry_logic: bool = True + """Whether to enable automatic retries on failure.""" + retryable_exceptions: tuple[Type[Exception], ...] = ( asyncio.TimeoutError, tinker.APIConnectionError, httpx.TimeoutException, RetryableException, ) + """Exception types that should trigger a retry.""" def __post_init__(self): if self.max_connections <= 0: diff --git a/src/tinker/types/create_model_request.py b/src/tinker/types/create_model_request.py index 91f6a07..5340453 100644 --- a/src/tinker/types/create_model_request.py +++ b/src/tinker/types/create_model_request.py @@ -15,11 +15,13 @@ class CreateModelRequest(StrictBase): model_seq_id: int base_model: str + """The name of the base model to fine-tune (e.g., 'Qwen/Qwen3-8B').""" - """Optional metadata about this model/training run, set by the end-user""" user_metadata: Optional[dict[str, Any]] = None + """Optional metadata about this model/training run, set by the end-user.""" lora_config: Optional[LoraConfig] = None + """LoRA configuration""" type: Literal["create_model"] = "create_model" diff --git a/src/tinker/types/forward_backward_output.py b/src/tinker/types/forward_backward_output.py index 6a81d76..502af09 100644 --- a/src/tinker/types/forward_backward_output.py +++ b/src/tinker/types/forward_backward_output.py @@ -8,7 +8,7 @@ __all__ = ["ForwardBackwardOutput"] class ForwardBackwardOutput(BaseModel): loss_fn_output_type: str - """The type of the ForwardBackward output. Can be one of [...] 
TODO""" + """The class name of the loss function output records (e.g., 'TorchLossReturn', 'ArrayRecord').""" loss_fn_outputs: List[LossFnOutput] """Dictionary mapping field names to tensor data""" diff --git a/src/tinker/types/get_info_response.py b/src/tinker/types/get_info_response.py index bfc947e..acc4f1c 100644 --- a/src/tinker/types/get_info_response.py +++ b/src/tinker/types/get_info_response.py @@ -8,25 +8,38 @@ __all__ = ["GetInfoResponse", "ModelData"] class ModelData(BaseModel): + """Metadata about a model's architecture and configuration.""" + arch: Optional[str] = None + """The model architecture identifier.""" model_name: Optional[str] = None + """The human-readable model name.""" tokenizer_id: Optional[str] = None + """The identifier of the tokenizer used by this model.""" class GetInfoResponse(BaseModel): + """Response containing information about a training client's model.""" + type: Optional[Literal["get_info"]] = None + """Response type identifier.""" model_data: ModelData + """Detailed metadata about the model.""" model_id: ModelID + """Unique identifier for the model.""" is_lora: Optional[bool] = None + """Whether this is a LoRA fine-tuned model.""" lora_rank: Optional[int] = None + """The rank of the LoRA adaptation, if applicable.""" model_name: Optional[str] = None + """The name of the model.""" if PYDANTIC_V2: # allow fields with a `model_` prefix diff --git a/src/tinker/types/get_server_capabilities_response.py b/src/tinker/types/get_server_capabilities_response.py index 6d17d57..c69e9e0 100644 --- a/src/tinker/types/get_server_capabilities_response.py +++ b/src/tinker/types/get_server_capabilities_response.py @@ -6,8 +6,14 @@ __all__ = ["GetServerCapabilitiesResponse", "SupportedModel"] class SupportedModel(BaseModel): + """Information about a model supported by the server.""" + model_name: Optional[str] = None + """The name of the supported model.""" class GetServerCapabilitiesResponse(BaseModel): + """Response containing the server's supported models and capabilities.""" + supported_models: List[SupportedModel] + """List of models available on the server.""" diff --git a/tests/conftest.py b/tests/conftest.py index e7c0344..1182efa 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -43,7 +43,7 @@ def pytest_collection_modifyitems(items: list[pytest.Function]) -> None: base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") -api_key = "My API Key" +api_key = "tml-My API Key" @pytest.fixture(scope="session") diff --git a/tests/test_client.py b/tests/test_client.py index f55ddd4..db11a44 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -35,7 +35,7 @@ from tinker._base_client import ( from .utils import update_env base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") -api_key = "My API Key" +api_key = "tml-My API Key" def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]: @@ -97,9 +97,9 @@ class TestTinker: copied = self.client.copy() assert id(copied) != id(self.client) - copied = self.client.copy(api_key="another My API Key") - assert copied.api_key == "another My API Key" - assert self.client.api_key == "My API Key" + copied = self.client.copy(api_key="tml-another My API Key") + assert copied.api_key == "tml-another My API Key" + assert self.client.api_key == "tml-My API Key" def test_copy_default_options(self) -> None: # options that have a default are overridden correctly @@ -353,6 +353,10 @@ class TestTinker: client2 = Tinker(base_url=base_url, api_key=None, _strict_response_validation=True) _ = 
client2 + def test_api_key_prefix_validation(self) -> None: + with pytest.raises(TinkerError): + Tinker(base_url=base_url, api_key="not-tml-prefix", _strict_response_validation=True) + def test_default_query_option(self) -> None: client = Tinker( base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"} @@ -1076,9 +1080,9 @@ class TestAsyncTinker: copied = self.client.copy() assert id(copied) != id(self.client) - copied = self.client.copy(api_key="another My API Key") - assert copied.api_key == "another My API Key" - assert self.client.api_key == "My API Key" + copied = self.client.copy(api_key="tml-another My API Key") + assert copied.api_key == "tml-another My API Key" + assert self.client.api_key == "tml-My API Key" def test_copy_default_options(self) -> None: # options that have a default are overridden correctly @@ -1334,6 +1338,10 @@ class TestAsyncTinker: client2 = AsyncTinker(base_url=base_url, api_key=None, _strict_response_validation=True) _ = client2 + def test_api_key_prefix_validation(self) -> None: + with pytest.raises(TinkerError): + AsyncTinker(base_url=base_url, api_key="not-tml-prefix", _strict_response_validation=True) + def test_default_query_option(self) -> None: client = AsyncTinker( base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
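
Note on the two behavioral changes above: the `tml-` prefix check added in `src/tinker/_client.py` runs at client construction time (see `test_api_key_prefix_validation`), and `RetryConfig` in `src/tinker/lib/retry_handler.py` now documents every field. The sketch below exercises both. It is illustrative only, not part of the patch; the top-level `from tinker import Tinker, TinkerError` import is an assumption, since the tests reference these names but their import lines are not shown in this diff.

```python
# Minimal sketch of the new api_key prefix validation and the documented
# RetryConfig fields. Assumed imports: the diff shows the names `Tinker`,
# `TinkerError`, and `RetryConfig`, but not where they are exported from.
from tinker import Tinker, TinkerError  # assumed top-level exports
from tinker.lib.retry_handler import RetryConfig  # path shown in the diff

# Keys without the "tml-" prefix are now rejected before any request is made.
try:
    Tinker(
        base_url="http://127.0.0.1:4010",
        api_key="not-tml-prefix",
        _strict_response_validation=True,  # mirrors the test suite's usage
    )
except TinkerError as err:
    print(err)  # "The api_key must start with the 'tml-' prefix"

# A correctly prefixed key constructs a client as before.
client = Tinker(
    base_url="http://127.0.0.1:4010",
    api_key="tml-My API Key",
    _strict_response_validation=True,
)

# RetryConfig defaults can be overridden field by field, per the docstrings
# added in retry_handler.py.
retry_config = RetryConfig(
    max_connections=50,        # cap on concurrent connections
    progress_timeout=10 * 60,  # seconds without progress before failing
    jitter_factor=0.25,        # random jitter (0-1) applied to retry delays
)
```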