Sync contents

This commit is contained in:
Daniel Xu 2025-10-15 16:37:21 +00:00
parent 9ba155a34d
commit 422bc7b9f1
122 changed files with 516 additions and 477 deletions

View file

@ -1,4 +1,4 @@
{
"last_synced_sha": "b099cce3c688f59468230f58027c38c2a81bcc13",
"last_sync_time": "2025-10-01T18:17:16.846167"
"last_synced_sha": "db70edcca9cdd2d06faa9e4af37538547f308b14",
"last_sync_time": "2025-10-15T16:37:21.808358"
}

View file

@ -1,4 +0,0 @@
File generated from our OpenAPI spec by Stainless.
This directory can be used to store example files demonstrating usage of this SDK.
It is ignored by Stainless code generation and its content (other than this keep file) won't be touched.

View file

@ -1,6 +1,6 @@
[project]
name = "tinker"
version = "0.1.5"
version = "0.2.0"
description = "The official Python SDK for the tinker API"
readme = "README.md"
license = "Apache-2.0"

View file

@ -1,27 +0,0 @@
#!/usr/bin/env bash
# Upload the freshly built wheel from dist/ to Stainless artifact storage.
# Expects in the environment: URL (signing endpoint), AUTH (bearer token),
# SHA (commit sha used in the published install URL).
# -e: exit on error, -x: trace, -u: error on unset vars, pipefail: fail on pipe errors
set -exuo pipefail
# Exactly one wheel is assumed to exist in dist/ (glob expands to a single file).
FILENAME=$(basename dist/*.whl)
# Request a one-time signed upload URL for this artifact from the signing endpoint.
RESPONSE=$(curl -X POST "$URL?filename=$FILENAME" \
-H "Authorization: Bearer $AUTH" \
-H "Content-Type: application/json")
# jq -r prints the raw string; a missing/absent .url key yields the literal "null".
SIGNED_URL=$(echo "$RESPONSE" | jq -r '.url')
if [[ "$SIGNED_URL" == "null" ]]; then
echo -e "\033[31mFailed to get signed URL.\033[0m"
exit 1
fi
# PUT the wheel bytes to the signed URL; -v merged into stdout (2>&1) so the
# HTTP status line can be grepped from the captured output below.
UPLOAD_RESPONSE=$(curl -v -X PUT \
-H "Content-Type: binary/octet-stream" \
--data-binary "@dist/$FILENAME" "$SIGNED_URL" 2>&1)
# Success is detected by an "HTTP/<version> 200" status line in the verbose output.
if echo "$UPLOAD_RESPONSE" | grep -q "HTTP/[0-9.]* 200"; then
echo -e "\033[32mUploaded build to Stainless storage.\033[0m"
echo -e "\033[32mInstallation: pip install 'https://pkg.stainless.com/s/tinker-python/$SHA/$FILENAME'\033[0m"
else
echo -e "\033[31mFailed to upload artifact.\033[0m"
exit 1
fi

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
import typing as _t
from . import types

View file

@ -1924,7 +1924,6 @@ def get_platform() -> Platform:
@lru_cache(maxsize=None)
def platform_headers(version: str, *, platform: Platform | None) -> Dict[str, str]:
return {
"X-Stainless-Lang": "python",
"X-Stainless-Package-Version": version,
"X-Stainless-OS": str(platform or get_platform()),
"X-Stainless-Arch": str(get_architecture()),

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
import os

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
import httpx
RAW_RESPONSE_HEADER = "X-Stainless-Raw-Response"

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing_extensions import Literal

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
import time

View file

@ -1,4 +1,2 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
__title__ = "tinker"
__version__ = "0.0.1-alpha.1"

View file

@ -1,4 +0,0 @@
File generated from our OpenAPI spec by Stainless.
This directory can be used to store custom files to expand the SDK.
It is ignored by Stainless code generation and its content (other than this keep file) won't be touched.

View file

@ -14,7 +14,7 @@ import tinker
from tinker import types
from tinker.lib.client_connection_pool_type import ClientConnectionPoolType
from tinker.lib.public_interfaces.api_future import APIFuture
from tinker.lib.telemetry import Telemetry
from tinker.lib.telemetry import Telemetry, is_user_error
from .._models import BaseModel
from .retryable_exception import RetryableException
@ -121,6 +121,7 @@ class _APIFuture(APIFuture[T]): # pyright: ignore[reportUnusedClass]
except tinker.APIStatusError as e:
connection_error_retries = 0
should_retry = e.status_code == 408 or e.status_code in range(500, 600)
user_error = is_user_error(e)
if telemetry := self.get_telemetry():
current_time = time.time()
telemetry.log(
@ -131,10 +132,11 @@ class _APIFuture(APIFuture[T]): # pyright: ignore[reportUnusedClass]
"status_code": e.status_code,
"exception": str(e),
"should_retry": should_retry,
"is_user_error": user_error,
"iteration": iteration,
"elapsed_time": current_time - start_time,
},
severity="WARNING" if should_retry else "ERROR",
severity="WARNING" if should_retry or user_error else "ERROR",
)
# Retry 408s until we time out

View file

@ -19,7 +19,7 @@ from tinker._exceptions import APIConnectionError, APIStatusError
from tinker.lib.async_tinker_provider import AsyncTinkerProvider
from tinker.lib.client_connection_pool_type import ClientConnectionPoolType
from tinker.lib.public_interfaces.api_future import AwaitableConcurrentFuture
from tinker.lib.telemetry import Telemetry, init_telemetry
from tinker.lib.telemetry import Telemetry, init_telemetry, is_user_error
from tinker.lib.telemetry_provider import TelemetryProvider
logger = logging.getLogger(__name__)
@ -190,6 +190,7 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
return await func(*args, **kwargs)
except Exception as e:
is_retryable = self._is_retryable_exception(e)
user_error = is_user_error(e)
current_time = time.time()
elapsed_time = current_time - start_time
if telemetry := self.get_telemetry():
@ -208,12 +209,13 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
else None,
"status_code": getattr(e, "status_code", None),
"is_retryable": is_retryable,
"is_user_error": user_error,
"attempt_count": attempt_count,
"start_time": start_time,
"current_time": current_time,
"elapsed_time": elapsed_time,
},
severity="WARNING" if is_retryable else "ERROR",
severity="WARNING" if is_retryable or user_error else "ERROR",
)
if is_retryable and elapsed_time < MAX_WAIT_TIME:
# Apply exponential backoff

View file

@ -33,7 +33,7 @@ class RestClient(TelemetryProvider):
- list_checkpoints() - list available model checkpoints (both training and sampler)
- get_training_run() - get model information and metadata as ModelEntry
- delete_checkpoint() - delete an existing checkpoint for a training run
- download_sampler_weights_archive() - download sampler weights checkpoint as tar.gz archive
- get_checkpoint_archive_url() - get signed URL to download checkpoint archive
Args:
holder: Internal client managing HTTP connections and async operations
@ -55,6 +55,7 @@ class RestClient(TelemetryProvider):
self, training_run_id: types.ModelID
) -> AwaitableConcurrentFuture[types.TrainingRun]:
"""Internal method to submit get model request."""
async def _get_training_run_async() -> types.TrainingRun:
async def _send_request() -> types.TrainingRun:
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
@ -69,7 +70,9 @@ class RestClient(TelemetryProvider):
@sync_only
@capture_exceptions(fatal=True)
def get_training_run(self, training_run_id: types.ModelID) -> ConcurrentFuture[types.TrainingRun]:
def get_training_run(
self, training_run_id: types.ModelID
) -> ConcurrentFuture[types.TrainingRun]:
"""Get training run info.
Args:
@ -103,7 +106,9 @@ class RestClient(TelemetryProvider):
@sync_only
@capture_exceptions(fatal=True)
def get_training_run_by_tinker_path(self, tinker_path: str) -> ConcurrentFuture[types.TrainingRun]:
def get_training_run_by_tinker_path(
self, tinker_path: str
) -> ConcurrentFuture[types.TrainingRun]:
"""Get training run info.
Args:
@ -145,6 +150,7 @@ class RestClient(TelemetryProvider):
self, limit: int = 20, offset: int = 0
) -> AwaitableConcurrentFuture[types.TrainingRunsResponse]:
"""Internal method to submit list training runs request."""
async def _list_training_runs_async() -> types.TrainingRunsResponse:
async def _send_request() -> types.TrainingRunsResponse:
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
@ -210,6 +216,7 @@ class RestClient(TelemetryProvider):
self, training_run_id: types.ModelID
) -> AwaitableConcurrentFuture[types.CheckpointsListResponse]:
"""Internal method to submit list model checkpoints request."""
async def _list_checkpoints_async():
async def _send_request():
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
@ -221,7 +228,9 @@ class RestClient(TelemetryProvider):
@sync_only
@capture_exceptions(fatal=True)
def list_checkpoints(self, training_run_id: types.ModelID) -> ConcurrentFuture[types.CheckpointsListResponse]:
def list_checkpoints(
self, training_run_id: types.ModelID
) -> ConcurrentFuture[types.CheckpointsListResponse]:
"""List available checkpoints (both training and sampler).
Args:
@ -242,7 +251,9 @@ class RestClient(TelemetryProvider):
return self._list_checkpoints_submit(training_run_id).future()
@capture_exceptions(fatal=True)
async def list_checkpoints_async(self, training_run_id: types.ModelID) -> types.CheckpointsListResponse:
async def list_checkpoints_async(
self, training_run_id: types.ModelID
) -> types.CheckpointsListResponse:
"""Async version of list_checkpoints.
Args:
@ -261,64 +272,66 @@ class RestClient(TelemetryProvider):
"""
return await self._list_checkpoints_submit(training_run_id)
def _download_checkpoint_archive_submit(
def _get_checkpoint_archive_url_submit(
self, training_run_id: types.ModelID, checkpoint_id: str
) -> AwaitableConcurrentFuture[bytes]:
"""Internal method to submit download checkpoint archive request."""
async def _download_checkpoint_archive_async():
async def _send_request():
) -> AwaitableConcurrentFuture[types.CheckpointArchiveUrlResponse]:
"""Internal method to submit get checkpoint archive URL request."""
async def _get_checkpoint_archive_url_async() -> types.CheckpointArchiveUrlResponse:
async def _send_request() -> types.CheckpointArchiveUrlResponse:
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
return await client.get(
f"/api/v1/training_runs/{training_run_id}/checkpoints/{checkpoint_id}/archive",
cast_to=bytes,
options={"headers": {"accept": "application/gzip"}},
return await client.weights.get_checkpoint_archive_url(
model_id=training_run_id,
checkpoint_id=checkpoint_id,
)
return await self.holder.execute_with_retries(_send_request)
return self.holder.run_coroutine_threadsafe(_download_checkpoint_archive_async())
return self.holder.run_coroutine_threadsafe(_get_checkpoint_archive_url_async())
@sync_only
@capture_exceptions(fatal=True)
def download_checkpoint_archive(
def get_checkpoint_archive_url(
self, training_run_id: types.ModelID, checkpoint_id: str
) -> ConcurrentFuture[bytes]:
"""Download checkpoint as a tar.gz archive.
) -> ConcurrentFuture[types.CheckpointArchiveUrlResponse]:
"""Get signed URL to download checkpoint archive.
Args:
training_run_id: The training run ID to download weights for
checkpoint_id: The checkpoint ID to download
Returns:
A Future containing the archive data as bytes
A Future containing the CheckpointArchiveUrlResponse with signed URL and expiration
Example:
>>> future = rest_client.download_checkpoint_archive("run-id", "checkpoint-123")
>>> archive_data = future.result()
>>> with open(f"model-checkpoint.tar.gz", "wb") as f:
... f.write(archive_data)
>>> future = rest_client.get_checkpoint_archive_url("run-id", "checkpoint-123")
>>> response = future.result()
>>> print(f"Download URL: {response.url}")
>>> print(f"Expires at: {response.expires_at}")
>>> # Use the URL to download the archive with your preferred HTTP client
"""
return self._download_checkpoint_archive_submit(training_run_id, checkpoint_id).future()
return self._get_checkpoint_archive_url_submit(training_run_id, checkpoint_id).future()
@capture_exceptions(fatal=True)
async def download_checkpoint_archive_async(
async def get_checkpoint_archive_url_async(
self, training_run_id: types.ModelID, checkpoint_id: str
) -> bytes:
"""Async version of download_checkpoint_archive.
) -> types.CheckpointArchiveUrlResponse:
"""Async version of get_checkpoint_archive_url.
Args:
training_run_id: The model ID to download weights for
checkpoint_id: The checkpoint ID to download
Returns:
Archive data as bytes
CheckpointArchiveUrlResponse with signed URL and expiration
Example:
>>> archive_data = await rest_client.download_checkpoint_archive_async("run-id", "checkpoint-123")
>>> with open(f"model-checkpoint.tar.gz", "wb") as f:
... f.write(archive_data)
>>> response = await rest_client.get_checkpoint_archive_url_async("run-id", "checkpoint-123")
>>> print(f"Download URL: {response.url}")
>>> print(f"Expires at: {response.expires_at}")
>>> # Use the URL to download the archive with your preferred HTTP client
"""
return await self._download_checkpoint_archive_submit(training_run_id, checkpoint_id)
return await self._get_checkpoint_archive_url_submit(training_run_id, checkpoint_id)
def _delete_checkpoint_submit(
self, training_run_id: types.ModelID, checkpoint_id: str
@ -339,13 +352,17 @@ class RestClient(TelemetryProvider):
@sync_only
@capture_exceptions(fatal=True)
def delete_checkpoint(self, training_run_id: types.ModelID, checkpoint_id: str) -> ConcurrentFuture[None]:
def delete_checkpoint(
self, training_run_id: types.ModelID, checkpoint_id: str
) -> ConcurrentFuture[None]:
"""Delete a checkpoint for a training run."""
return self._delete_checkpoint_submit(training_run_id, checkpoint_id).future()
@capture_exceptions(fatal=True)
async def delete_checkpoint_async(self, training_run_id: types.ModelID, checkpoint_id: str) -> None:
async def delete_checkpoint_async(
self, training_run_id: types.ModelID, checkpoint_id: str
) -> None:
"""Async version of delete_checkpoint."""
await self._delete_checkpoint_submit(training_run_id, checkpoint_id)
@ -356,42 +373,53 @@ class RestClient(TelemetryProvider):
"""Delete a checkpoint referenced by a tinker path."""
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
return self._delete_checkpoint_submit(parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id).future()
return self._delete_checkpoint_submit(
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
).future()
@capture_exceptions(fatal=True)
async def delete_checkpoint_from_tinker_path_async(self, tinker_path: str) -> None:
"""Async version of delete_checkpoint_from_tinker_path."""
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
await self._delete_checkpoint_submit(parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id)
await self._delete_checkpoint_submit(
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
)
def get_telemetry(self) -> Telemetry | None:
return self.holder.get_telemetry()
@sync_only
@capture_exceptions(fatal=True)
def download_checkpoint_archive_from_tinker_path(
def get_checkpoint_archive_url_from_tinker_path(
self, tinker_path: str
) -> ConcurrentFuture[bytes]:
"""Download checkpoint as a tar.gz archive.
) -> ConcurrentFuture[types.CheckpointArchiveUrlResponse]:
"""Get signed URL to download checkpoint archive.
Args:
tinker_path: The tinker path to the checkpoint
Returns:
A Future containing the archive data as bytes
A Future containing the CheckpointArchiveUrlResponse with signed URL and expiration
"""
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
return self._download_checkpoint_archive_submit(parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id).future()
return self._get_checkpoint_archive_url_submit(
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
).future()
@capture_exceptions(fatal=True)
async def download_checkpoint_archive_from_tinker_path_async(
async def get_checkpoint_archive_url_from_tinker_path_async(
self, tinker_path: str
) -> bytes:
"""Async version of download_checkpoint_archive_from_tinker_path.
) -> types.CheckpointArchiveUrlResponse:
"""Async version of get_checkpoint_archive_url_from_tinker_path.
Args:
tinker_path: The tinker path to the checkpoint
Returns:
CheckpointArchiveUrlResponse with signed URL and expiration
"""
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
return await self._download_checkpoint_archive_submit(parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id)
return await self._get_checkpoint_archive_url_submit(
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
)

View file

@ -78,13 +78,18 @@ class ServiceClient(TelemetryProvider):
return await self._get_server_capabilities_submit()
def _create_model_submit(
self, base_model: str, lora_config: types.LoraConfig
self,
base_model: str,
lora_config: types.LoraConfig,
user_metadata: dict[str, str] | None = None,
) -> AwaitableConcurrentFuture[types.ModelID]:
async def _create_model_async():
start_time = time.time()
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
future = await client.models.create(
base_model=base_model, lora_config=_to_lora_config_params(lora_config)
base_model=base_model,
lora_config=_to_lora_config_params(lora_config),
user_metadata=user_metadata,
)
create_model_response = await _APIFuture(
types.CreateModelResponse,
@ -107,6 +112,7 @@ class ServiceClient(TelemetryProvider):
train_mlp: bool = True,
train_attn: bool = True,
train_unembed: bool = True,
user_metadata: dict[str, str] | None = None,
) -> TrainingClient:
assert any([train_mlp, train_attn, train_unembed]), (
"At least one of train_mlp, train_attn, or train_unembed must be True"
@ -120,6 +126,7 @@ class ServiceClient(TelemetryProvider):
train_attn=train_attn,
train_unembed=train_unembed,
),
user_metadata=user_metadata,
).result()
logger.info(f"Creating TrainingClient for {model_id=}")
return self.create_training_client(model_id)
@ -133,6 +140,7 @@ class ServiceClient(TelemetryProvider):
train_mlp: bool = True,
train_attn: bool = True,
train_unembed: bool = True,
user_metadata: dict[str, str] | None = None,
) -> TrainingClient:
assert any([train_mlp, train_attn, train_unembed]), (
"At least one of train_mlp, train_attn, or train_unembed must be True"
@ -146,6 +154,7 @@ class ServiceClient(TelemetryProvider):
train_attn=train_attn,
train_unembed=train_unembed,
),
user_metadata=user_metadata,
)
logger.info(f"Creating TrainingClient for {model_id=}")
return self.create_training_client(model_id)
@ -165,6 +174,7 @@ class ServiceClient(TelemetryProvider):
training_client = self.create_lora_training_client(
base_model=training_run.base_model,
rank=training_run.lora_rank,
user_metadata=training_run.user_metadata,
)
training_client.load_state(path).result()
@ -181,6 +191,7 @@ class ServiceClient(TelemetryProvider):
training_client = await self.create_lora_training_client_async(
base_model=training_run.base_model,
rank=training_run.lora_rank,
user_metadata=training_run.user_metadata,
)
load_future = await training_client.load_state_async(path)
@ -230,8 +241,6 @@ def _get_default_headers() -> dict[str, str]:
if (api_key := os.environ.get("TINKER_API_KEY", "")) and "X-API-Key" not in headers:
headers["X-API-Key"] = api_key
headers["X-Username"] = os.environ.get("USER", "")
if (
client_id := os.environ.get("CLOUDFLARE_ACCESS_CLIENT_ID")
) and "CF-Access-Client-Id" not in headers:

View file

@ -438,39 +438,6 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
) -> APIFuture[types.SaveWeightsForSamplerResponse]:
return self.save_weights_for_sampler(name)
@capture_exceptions(fatal=True)
def unload_model(
self,
) -> APIFuture[types.UnloadModelResponse]:
request_id = self._get_request_id()
@capture_exceptions(fatal=True)
async def _unload_model_async():
start_time = time.time()
async def _send_request():
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
return await client.models.unload(
model_id=self._guaranteed_model_id(),
idempotency_key=self._make_idempotency_key(request_id),
)
async with self._take_turn(request_id):
future = await self.holder.execute_with_retries(_send_request)
return await _APIFuture(
types.UnloadModelResponse,
self.holder,
future,
request_start_time=start_time,
request_type="UnloadModel",
queue_state_observer=self,
)
return self.holder.run_coroutine_threadsafe(_unload_model_async())
async def unload_model_async(self) -> APIFuture[types.UnloadModelResponse]:
return self.unload_model()
def _get_info_submit(self) -> AwaitableConcurrentFuture[types.GetInfoResponse]:
request_id = self._get_request_id()

View file

@ -17,7 +17,7 @@ from typing import Any, Awaitable, Callable, Generic, Type, TypeVar
import httpx
import tinker
from tinker.lib.telemetry import Telemetry
from tinker.lib.telemetry import Telemetry, is_user_error
from .._constants import (
DEFAULT_CONNECTION_LIMITS,
@ -191,7 +191,7 @@ class RetryHandler(Generic[T]): # noqa: UP046
exception_str = f"{type(e).__name__}: {str(e) or 'No error message'}"
self._errors_since_last_retry[exception_str] += 1
should_retry = self._should_retry(e)
user_error = is_user_error(e)
if telemetry := self.get_telemetry():
current_time = time.time()
telemetry.log(
@ -209,12 +209,13 @@ class RetryHandler(Generic[T]): # noqa: UP046
else None,
"status_code": getattr(e, "status_code", None),
"should_retry": should_retry,
"is_user_error": user_error,
"attempt_count": attempt_count,
"start_time": start_time,
"current_time": current_time,
"elapsed_time": current_time - start_time,
},
severity="WARNING" if should_retry else "ERROR",
severity="WARNING" if should_retry or user_error else "ERROR",
)
if not should_retry:

View file

@ -159,7 +159,7 @@ class Telemetry:
return self._log(self._generic_event(event_name, event_data, severity))
async def log_exception(self, exception: BaseException, severity: Severity = "ERROR") -> bool:
logged = self._log(self._exception_event(exception, severity))
logged = self._log(self._exception_or_user_error_event(exception, severity))
# trigger flush but don't block on it
self._trigger_flush()
return logged
@ -167,7 +167,9 @@ class Telemetry:
async def log_fatal_exception(
self, exception: BaseException, severity: Severity = "ERROR"
) -> bool:
logged = self._log(self._exception_event(exception, severity), self._session_end_event())
logged = self._log(
self._exception_or_user_error_event(exception, severity), self._session_end_event()
)
self._trigger_flush()
# wait for the flush to complete
_ = await self._wait_until_drained()
@ -177,7 +179,7 @@ class Telemetry:
@sync_only
def log_exception_sync(self, exception: BaseException, severity: Severity = "ERROR") -> bool:
logged = self._log(self._exception_event(exception, severity))
logged = self._log(self._exception_or_user_error_event(exception, severity))
# trigger flush but don't block on it
self._trigger_flush()
return logged
@ -186,7 +188,9 @@ class Telemetry:
def log_fatal_exception_sync(
self, exception: BaseException, severity: Severity = "ERROR"
) -> bool:
logged = self._log(self._exception_event(exception, severity), self._session_end_event())
logged = self._log(
self._exception_or_user_error_event(exception, severity), self._session_end_event()
)
self._trigger_flush()
# wait for the flush to complete
if _current_loop() is None:
@ -242,6 +246,15 @@ class Telemetry:
duration=str(end_time - self._session_start),
)
def _exception_or_user_error_event(
    self, exception: BaseException, severity: Severity
) -> TelemetryEvent:
    """Build the telemetry event for *exception*.

    4xx client errors (per ``is_user_error``) become a "user_error" generic
    event; anything else becomes an unhandled-exception event at *severity*.
    """
    if is_user_error(exception):
        return self._user_error_event(exception)
    return self._exception_event(exception, severity)
def _exception_event(
self, exception: BaseException, severity: Severity
) -> UnhandledExceptionEvent:
@ -260,6 +273,18 @@ class Telemetry:
else None,
)
def _user_error_event(self, exception: BaseException) -> GenericEvent:
    """Build a WARNING-severity "user_error" event describing a 4xx client error.

    Includes the exception's class name, its message (if non-empty), and — when
    a user error is found in the cause/context chain — that error's integer
    ``status_code`` and truthy ``body``, if present.
    """
    payload: dict[str, object] = {"error_type": type(exception).__name__}
    message = str(exception)
    if message:
        payload["message"] = message
    user_error = _get_user_error(exception)
    if user_error:
        code = getattr(user_error, "status_code", None)
        if isinstance(code, int):
            payload["status_code"] = code
        error_body = getattr(user_error, "body", None)
        if error_body:
            payload["body"] = error_body
    return self._generic_event("user_error", payload, "WARNING")
def _next_session_index(self) -> int:
with self._session_index_lock:
idx = self._session_index
@ -372,6 +397,37 @@ def capture_exceptions(
return _decorate if func is None else _decorate(func)
def is_user_error(exception: BaseException) -> bool:
    """Return True if *exception* (or an exception in its cause/context chain)
    carries a 4xx status code other than 408."""
    found = _get_user_error(exception)
    return found is not None
def _get_user_error(
exception: BaseException, visited: set[int] | None = None
) -> BaseException | None:
visited = set() if visited is None else visited
if id(exception) in visited:
return None
visited.add(id(exception))
if (
(status_code := getattr(exception, "status_code", None))
and isinstance(status_code, int)
and 400 <= status_code < 500
and status_code != 408
):
return exception
if (cause := getattr(exception, "__cause__", None)) is not None and (
user_error := _get_user_error(cause, visited)
) is not None:
return user_error
if (context := getattr(exception, "__context__", None)) is not None:
return _get_user_error(context, visited)
return None
def _to_send_params(batch: TelemetryBatch) -> TelemetrySendParams:
return cast(TelemetrySendParams, cast(object, batch.model_dump()))

View file

@ -8,10 +8,17 @@ from concurrent.futures import Future as ConcurrentFuture
from typing import Any, TypeVar, cast
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import httpx
import pytest
from tinker.lib.public_interfaces.api_future import AwaitableConcurrentFuture
from tinker._exceptions import (
APIStatusError,
BadRequestError,
ConflictError,
UnprocessableEntityError,
)
from tinker.lib.client_connection_pool_type import ClientConnectionPoolType
from tinker.lib.public_interfaces.api_future import AwaitableConcurrentFuture
from tinker.lib.telemetry import (
MAX_BATCH_SIZE,
MAX_QUEUE_SIZE,
@ -189,6 +196,30 @@ class TestTelemetryClass:
assert exception_event.error_type == "RuntimeError"
assert exception_event.error_message == "Test exception"
def test_log_exception_sync_user_error(self):
request = httpx.Request("GET", "https://example.com")
response = httpx.Response(400, request=request)
error = BadRequestError(
"Invalid request payload",
response=response,
body={"detail": "bad request"},
)
with patch.object(self.telemetry, "_trigger_flush") as mock_trigger:
result = self.telemetry.log_exception_sync(error, "ERROR")
assert result is True
assert len(self.telemetry._queue) == 2
mock_trigger.assert_called_once()
generic_event = self.telemetry._queue[-1]
assert isinstance(generic_event, GenericEvent)
assert generic_event.event_name == "user_error"
assert generic_event.severity == "WARNING"
assert generic_event.event_data["error_type"] == "BadRequestError"
assert generic_event.event_data["status_code"] == 400
assert generic_event.event_data["message"] == "Invalid request payload"
assert generic_event.event_data["body"] == {"detail": "bad request"}
@pytest.mark.asyncio
async def test_log_exception_async(self):
try:
@ -225,6 +256,84 @@ class TestTelemetryClass:
end_event = self.telemetry._queue[-1]
assert isinstance(end_event, SessionEndEvent)
def test_log_fatal_exception_sync_user_error(self):
request = httpx.Request("GET", "https://example.com")
response = httpx.Response(422, request=request)
error = UnprocessableEntityError(
"Payload is invalid",
response=response,
body={"errors": ["invalid field"]},
)
with patch.object(self.telemetry, "_trigger_flush") as mock_trigger:
with patch.object(
self.telemetry, "_wait_until_drained_sync", return_value=True
) as mock_wait:
result = self.telemetry.log_fatal_exception_sync(error, "ERROR")
assert result is True
assert len(self.telemetry._queue) == 3
mock_trigger.assert_called_once()
mock_wait.assert_called_once()
generic_event = self.telemetry._queue[-2]
assert isinstance(generic_event, GenericEvent)
assert generic_event.event_name == "user_error"
assert generic_event.severity == "WARNING"
assert generic_event.event_data["error_type"] == "UnprocessableEntityError"
assert generic_event.event_data["status_code"] == 422
assert generic_event.event_data["message"] == "Payload is invalid"
assert generic_event.event_data["body"] == {"errors": ["invalid field"]}
end_event = self.telemetry._queue[-1]
assert isinstance(end_event, SessionEndEvent)
def test_log_exception_sync_timeout_not_user_error(self):
request = httpx.Request("GET", "https://example.com")
response = httpx.Response(408, request=request)
error = APIStatusError(
"Request timed out",
response=response,
body={"error": "Request timed out"},
)
with patch.object(self.telemetry, "_trigger_flush") as mock_trigger:
result = self.telemetry.log_exception_sync(error, "ERROR")
assert result is True
assert len(self.telemetry._queue) == 2
mock_trigger.assert_called_once()
exception_event = self.telemetry._queue[-1]
assert isinstance(exception_event, UnhandledExceptionEvent)
assert exception_event.error_type == "APIStatusError"
def test_log_exception_sync_value_error_with_api_status_cause(self):
request = httpx.Request("GET", "https://example.com")
response = httpx.Response(409, request=request)
conflict_error = ConflictError(
"Resource already exists",
response=response,
body={"detail": "Resource already exists"},
)
try:
raise conflict_error
except ConflictError as exc:
try:
raise ValueError("Wrapped user error") from exc
except ValueError as outer:
wrapped_error = outer
with patch.object(self.telemetry, "_trigger_flush") as mock_trigger:
result = self.telemetry.log_exception_sync(wrapped_error, "ERROR")
assert result is True
assert len(self.telemetry._queue) == 2
mock_trigger.assert_called_once()
generic_event = self.telemetry._queue[-1]
assert isinstance(generic_event, GenericEvent)
assert generic_event.event_name == "user_error"
assert generic_event.event_data["status_code"] == 409
assert generic_event.event_data["body"] == {"detail": "Resource already exists"}
@pytest.mark.asyncio
async def test_log_fatal_exception_async(self):
try:

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from .models import (
ModelsResource,
AsyncModelsResource,

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing import Any, cast

View file

@ -1,26 +1,23 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
import httpx
from typing_extensions import Literal
import httpx
from ..types import ModelID, model_create_params, model_unload_params, model_get_info_params
from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from .._utils import maybe_transform, async_maybe_transform
from .._base_client import make_request_options
from .._compat import cached_property
from .._resource import SyncAPIResource, AsyncAPIResource
from .._resource import AsyncAPIResource, SyncAPIResource
from .._response import (
to_raw_response_wrapper,
to_streamed_response_wrapper,
async_to_raw_response_wrapper,
async_to_streamed_response_wrapper,
to_raw_response_wrapper,
to_streamed_response_wrapper,
)
from .._base_client import make_request_options
from ..types.model_id import ModelID
from .._types import NOT_GIVEN, Body, Headers, NotGiven, Query
from .._utils import async_maybe_transform, maybe_transform
from ..types import ModelID, model_create_params, model_get_info_params, model_unload_params
from ..types.get_info_response import GetInfoResponse
from ..types.lora_config_param import LoraConfigParam
from ..types.model_id import ModelID
from ..types.shared.untyped_api_future import UntypedAPIFuture
__all__ = ["ModelsResource", "AsyncModelsResource"]
@ -50,6 +47,7 @@ class ModelsResource(SyncAPIResource):
self,
*,
base_model: str,
user_metadata: dict[str, str] | None = None,
lora_config: LoraConfigParam | NotGiven = NOT_GIVEN,
type: Literal["create_model"] = "create_model",
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
@ -81,6 +79,7 @@ class ModelsResource(SyncAPIResource):
body=maybe_transform(
{
"base_model": base_model,
"user_metadata": user_metadata,
"lora_config": lora_config,
"type": type,
},
@ -213,6 +212,7 @@ class AsyncModelsResource(AsyncAPIResource):
self,
*,
base_model: str,
user_metadata: dict[str, str] | None = None,
lora_config: LoraConfigParam | NotGiven = NOT_GIVEN,
type: Literal["create_model"] = "create_model",
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
@ -244,6 +244,7 @@ class AsyncModelsResource(AsyncAPIResource):
body=await async_maybe_transform(
{
"base_model": base_model,
"user_metadata": user_metadata,
"lora_config": lora_config,
"type": type,
},

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing_extensions import Literal

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
import httpx

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing import Iterable

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing_extensions import Literal

View file

@ -1,14 +1,14 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing_extensions import Literal
import datetime
import httpx
from ..types import (
ModelID,
CheckpointsListResponse,
CheckpointArchiveUrlResponse,
weight_load_params,
weight_save_params,
weight_save_for_sampler_params,
@ -23,6 +23,7 @@ from .._response import (
async_to_raw_response_wrapper,
async_to_streamed_response_wrapper,
)
from .._exceptions import APIStatusError
from .._base_client import make_request_options
from ..types.shared.untyped_api_future import UntypedAPIFuture
@ -270,6 +271,84 @@ class WeightsResource(SyncAPIResource):
return None
def get_checkpoint_archive_url(
    self,
    *,
    model_id: ModelID,
    checkpoint_id: str,
    # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
    # The extra values given here take precedence over values defined on the client or passed to this method.
    extra_headers: Headers | None = None,
    extra_query: Query | None = None,
    extra_body: Body | None = None,
    timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> CheckpointArchiveUrlResponse:
    """
    Get signed URL to download checkpoint archive.

    The server answers this endpoint with a 302 redirect whose ``Location``
    header is the signed URL; redirects are deliberately not followed so that
    the URL can be returned to the caller instead of being fetched.

    Args:
      model_id: The training run ID to download weights for

      checkpoint_id: The checkpoint ID to download

      extra_headers: Send extra headers

      extra_query: Add additional query parameters to the request

      extra_body: Add additional JSON properties to the request

      timeout: Override the client-level default timeout for this request, in seconds

    Raises:
      ValueError: If ``model_id`` or ``checkpoint_id`` is empty.
      APIStatusError: For any non-302 error status, or a 302 response that
        carries no ``Location`` header.
      RuntimeError: If the server unexpectedly returns a success status
        instead of the expected redirect.
    """
    if not model_id:
        raise ValueError(f"Expected a non-empty value for `model_id` but received {model_id!r}")
    if not checkpoint_id:
        raise ValueError(
            f"Expected a non-empty value for `checkpoint_id` but received {checkpoint_id!r}"
        )

    from .._response import APIResponse

    # Merge the accept header
    merged_headers: Headers = {"accept": "application/gzip"}
    if extra_headers is not None:
        merged_headers = {**merged_headers, **extra_headers}

    options = make_request_options(
        extra_headers=merged_headers,
        extra_query=extra_query,
        extra_body=extra_body,
        timeout=timeout,
    )
    # The redirect itself carries the signed URL, so it must not be followed.
    options["follow_redirects"] = False

    try:
        self._get(
            f"/api/v1/training_runs/{model_id}/checkpoints/{checkpoint_id}/archive",
            cast_to=APIResponse,
            options=options,
        )
    except APIStatusError as e:
        # On success, this API responds with a 302
        if e.status_code != 302:
            raise
        location = e.response.headers.get("Location")
        if location is None:
            raise
        # Fall back to a 15-minute expiry when the Expires header is absent or
        # unparsable. timezone.utc (rather than datetime.UTC, added in 3.11)
        # keeps older Python versions working.
        expires = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(minutes=15)
        try:
            if expires_header := e.response.headers.get("Expires"):
                # strptime returns a naive datetime; the header is defined as
                # GMT, so attach UTC explicitly to avoid mixing naive and
                # aware values.
                expires = datetime.datetime.strptime(
                    expires_header, "%a, %d %b %Y %H:%M:%S GMT"
                ).replace(tzinfo=datetime.timezone.utc)
        except ValueError:
            pass
        return CheckpointArchiveUrlResponse(
            url=location,
            expires=expires,
        )
    # Reached only if the server replied with a success status instead of the
    # expected 302 redirect.
    raise RuntimeError("Expected a 302 redirect while getting checkpoint archive URL")
class AsyncWeightsResource(AsyncAPIResource):
@cached_property
def with_raw_response(self) -> AsyncWeightsResourceWithRawResponse:
@ -511,6 +590,83 @@ class AsyncWeightsResource(AsyncAPIResource):
return None
async def get_checkpoint_archive_url(
    self,
    *,
    model_id: ModelID,
    checkpoint_id: str,
    # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
    # The extra values given here take precedence over values defined on the client or passed to this method.
    extra_headers: Headers | None = None,
    extra_query: Query | None = None,
    extra_body: Body | None = None,
    timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> CheckpointArchiveUrlResponse:
    """
    Get signed URL to download checkpoint archive.

    The server answers this endpoint with a 302 redirect whose ``Location``
    header is the signed URL; redirects are deliberately not followed so that
    the URL can be returned to the caller instead of being fetched.

    Args:
      model_id: The training run ID to download weights for

      checkpoint_id: The checkpoint ID to download

      extra_headers: Send extra headers

      extra_query: Add additional query parameters to the request

      extra_body: Add additional JSON properties to the request

      timeout: Override the client-level default timeout for this request, in seconds

    Raises:
      ValueError: If ``model_id`` or ``checkpoint_id`` is empty.
      APIStatusError: For any non-302 error status, or a 302 response that
        carries no ``Location`` header.
      RuntimeError: If the server unexpectedly returns a success status
        instead of the expected redirect.
    """
    if not model_id:
        raise ValueError(f"Expected a non-empty value for `model_id` but received {model_id!r}")
    if not checkpoint_id:
        raise ValueError(
            f"Expected a non-empty value for `checkpoint_id` but received {checkpoint_id!r}"
        )

    from .._response import AsyncAPIResponse

    # Merge the accept header
    merged_headers: Headers = {"accept": "application/gzip"}
    if extra_headers is not None:
        merged_headers = {**merged_headers, **extra_headers}

    options = make_request_options(
        extra_headers=merged_headers,
        extra_query=extra_query,
        extra_body=extra_body,
        timeout=timeout,
    )
    # The redirect itself carries the signed URL, so it must not be followed.
    options["follow_redirects"] = False

    try:
        await self._get(
            f"/api/v1/training_runs/{model_id}/checkpoints/{checkpoint_id}/archive",
            cast_to=AsyncAPIResponse,
            options=options,
        )
    except APIStatusError as e:
        # On success, this API responds with a 302
        if e.status_code != 302:
            raise
        location = e.response.headers.get("Location")
        if location is None:
            raise
        # Fall back to a 15-minute expiry when the Expires header is absent or
        # unparsable. timezone.utc (rather than datetime.UTC, added in 3.11)
        # keeps older Python versions working.
        expires = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(minutes=15)
        try:
            if expires_header := e.response.headers.get("Expires"):
                # strptime returns a naive datetime; the header is defined as
                # GMT, so attach UTC explicitly to avoid mixing naive and
                # aware values.
                expires = datetime.datetime.strptime(
                    expires_header, "%a, %d %b %Y %H:%M:%S GMT"
                ).replace(tzinfo=datetime.timezone.utc)
        except ValueError:
            pass
        return CheckpointArchiveUrlResponse(
            url=location,
            expires=expires,
        )
    # Reached only if the server replied with a success status instead of the
    # expected 302 redirect.
    raise RuntimeError("Expected a 302 redirect while getting checkpoint archive URL")
class WeightsResourceWithRawResponse:
def __init__(self, weights: WeightsResource) -> None:
@ -531,6 +687,9 @@ class WeightsResourceWithRawResponse:
self.delete_checkpoint = to_raw_response_wrapper(
weights.delete_checkpoint,
)
self.get_checkpoint_archive_url = to_raw_response_wrapper(
weights.get_checkpoint_archive_url,
)
class AsyncWeightsResourceWithRawResponse:
@ -552,6 +711,9 @@ class AsyncWeightsResourceWithRawResponse:
self.delete_checkpoint = async_to_raw_response_wrapper(
weights.delete_checkpoint,
)
self.get_checkpoint_archive_url = async_to_raw_response_wrapper(
weights.get_checkpoint_archive_url,
)
class WeightsResourceWithStreamingResponse:
@ -573,6 +735,9 @@ class WeightsResourceWithStreamingResponse:
self.delete_checkpoint = to_streamed_response_wrapper(
weights.delete_checkpoint,
)
self.get_checkpoint_archive_url = to_streamed_response_wrapper(
weights.get_checkpoint_archive_url,
)
class AsyncWeightsResourceWithStreamingResponse:
@ -594,3 +759,6 @@ class AsyncWeightsResourceWithStreamingResponse:
self.delete_checkpoint = async_to_streamed_response_wrapper(
weights.delete_checkpoint,
)
self.get_checkpoint_archive_url = async_to_streamed_response_wrapper(
weights.get_checkpoint_archive_url,
)

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
# There's an underscore in front of *Param classes (TypedDict) because they shouldn't be used.
@ -16,7 +14,6 @@ from .stop_reason import StopReason as StopReason
from .tensor_data import TensorData as TensorData
from .loss_fn_type import LossFnType as LossFnType
from .tensor_dtype import TensorDtype as TensorDtype
from .error_response import ErrorResponse as ErrorResponse
from .loss_fn_inputs import LossFnInputs as LossFnInputs
from .loss_fn_output import LossFnOutput as LossFnOutput
from .sample_request import SampleRequest as SampleRequest
@ -35,12 +32,18 @@ from .model_input_param import ModelInputParam as _ModelInputParam
from .tensor_data_param import TensorDataParam as _TensorDataParam
from .encoded_text_chunk import EncodedTextChunk as EncodedTextChunk
from .optim_step_request import OptimStepRequest as OptimStepRequest
from .checkpoint import Checkpoint as Checkpoint, CheckpointType as CheckpointType, ParsedCheckpointTinkerPath as ParsedCheckpointTinkerPath
from .checkpoint import (
Checkpoint as Checkpoint,
CheckpointType as CheckpointType,
ParsedCheckpointTinkerPath as ParsedCheckpointTinkerPath,
)
from .weight_load_params import WeightLoadParams as _WeightLoadParams
from .weight_save_params import WeightSaveParams as _WeightSaveParams
from .checkpoints_list_response import CheckpointsListResponse as CheckpointsListResponse
from .checkpoint_archive_url_response import (
CheckpointArchiveUrlResponse as CheckpointArchiveUrlResponse,
)
from .cursor import Cursor as Cursor
from .training_run_ids_response import TrainingRunIdsResponse as TrainingRunIdsResponse
from .training_runs_response import TrainingRunsResponse as TrainingRunsResponse
from .forward_backward_input_param import ForwardBackwardInputParam as _ForwardBackwardInputParam
from .forward_backward_input import ForwardBackwardInput as ForwardBackwardInput
@ -49,7 +52,6 @@ from .model_create_params import ModelCreateParams as _ModelCreateParams
from .model_unload_params import ModelUnloadParams as _ModelUnloadParams
from .session_end_event import SessionEndEvent as SessionEndEvent
from .telemetry_response import TelemetryResponse as TelemetryResponse
from .try_again_response import TryAgainResponse as TryAgainResponse
from .optim_step_response import OptimStepResponse as OptimStepResponse
from .session_start_event import SessionStartEvent as SessionStartEvent
from .create_model_request import CreateModelRequest as CreateModelRequest
@ -72,19 +74,32 @@ from .sampling_params_param import SamplingParamsParam as _SamplingParamsParam
from .sampling_sample_params import SamplingSampleParams as _SamplingSampleParams
from .sampling_asample_params import SamplingAsampleParams as _SamplingAsampleParams
from .future_retrieve_response import FutureRetrieveResponse as FutureRetrieveResponse
from .compute_logprobs_response import ComputeLogprobsResponse as ComputeLogprobsResponse
from .image_asset_pointer_chunk import ImageAssetPointerChunk as ImageAssetPointerChunk
from .training_optim_step_params import TrainingOptimStepParams as _TrainingOptimStepParams
from .weight_save_for_sampler_params import WeightSaveForSamplerParams as _WeightSaveForSamplerParams
from .image_asset_pointer_chunk_param import ImageAssetPointerChunkParam as _ImageAssetPointerChunkParam
from .weight_save_for_sampler_params import (
WeightSaveForSamplerParams as _WeightSaveForSamplerParams,
)
from .image_asset_pointer_chunk_param import (
ImageAssetPointerChunkParam as _ImageAssetPointerChunkParam,
)
from .session_end_event_param import SessionEndEventParam as _SessionEndEventParam
from .session_start_event_param import SessionStartEventParam as _SessionStartEventParam
from .unhandled_exception_event import UnhandledExceptionEvent as UnhandledExceptionEvent
from .unhandled_exception_event_param import UnhandledExceptionEventParam as _UnhandledExceptionEventParam
from .get_server_capabilities_response import GetServerCapabilitiesResponse as GetServerCapabilitiesResponse
from .save_weights_for_sampler_request import SaveWeightsForSamplerRequest as SaveWeightsForSamplerRequest
from .unhandled_exception_event_param import (
UnhandledExceptionEventParam as _UnhandledExceptionEventParam,
)
from .get_server_capabilities_response import (
GetServerCapabilitiesResponse as GetServerCapabilitiesResponse,
)
from .save_weights_for_sampler_request import (
SaveWeightsForSamplerRequest as SaveWeightsForSamplerRequest,
)
from .get_server_capabilities_response import SupportedModel as SupportedModel
from .training_forward_backward_params import TrainingForwardBackwardParams as _TrainingForwardBackwardParams
from .save_weights_for_sampler_response import SaveWeightsForSamplerResponse as SaveWeightsForSamplerResponse
from .training_forward_backward_params import (
TrainingForwardBackwardParams as _TrainingForwardBackwardParams,
)
from .save_weights_for_sampler_response import (
SaveWeightsForSamplerResponse as SaveWeightsForSamplerResponse,
)
from .optim_step_request import AdamParams as AdamParams
from .training_run import TrainingRun as TrainingRun

View file

@ -0,0 +1,12 @@
import datetime
from .._models import BaseModel
__all__ = ["CheckpointArchiveUrlResponse"]
class CheckpointArchiveUrlResponse(BaseModel):
    """A signed download URL for a checkpoint archive, together with the
    time at which the URL stops being valid."""

    url: str
    """Signed URL to download the checkpoint archive"""

    expires: datetime.datetime
    """When the signed URL expires"""

View file

@ -1,14 +0,0 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import Optional
from typing_extensions import Literal, Sequence
from .._models import BaseModel
__all__ = ["ComputeLogprobsResponse"]
class ComputeLogprobsResponse(BaseModel):
logprobs: Sequence[Optional[float]]
type: Literal["compute_logprobs"] = "compute_logprobs"

View file

@ -1,6 +1,5 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import Any, Optional
from typing import Optional
from typing_extensions import Literal
from .._models import StrictBase
@ -12,6 +11,9 @@ __all__ = ["CreateModelRequest"]
class CreateModelRequest(StrictBase):
    base_model: str

    user_metadata: Optional[dict[str, Any]] = None
    """Optional metadata about this model/training run, set by the end-user"""

    lora_config: Optional[LoraConfig] = None

    type: Literal["create_model"] = "create_model"

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing_extensions import Literal
from .._compat import PYDANTIC_V2, ConfigDict

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import Any, TYPE_CHECKING
from pydantic import model_validator

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing_extensions import Required, TypedDict

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import Sequence
from typing_extensions import Literal

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing import Iterable

View file

@ -1,18 +0,0 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import Dict, Optional
from .._models import BaseModel
__all__ = ["ErrorResponse"]
class ErrorResponse(BaseModel):
error: str
"""Error code"""
message: str
"""Human-readable error message"""
details: Optional[Dict[str, object]] = None
"""Additional error details"""

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing_extensions import Literal, TypeAlias
__all__ = ["EventType"]

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import List
from .datum import Datum

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing import Iterable

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import Dict, List
from .._models import BaseModel

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing_extensions import Required, TypedDict

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import Union
from typing_extensions import TypeAlias

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from datetime import datetime
from typing import Dict

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from datetime import datetime

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import Optional
from typing_extensions import Literal

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import Optional, Literal
from .._compat import PYDANTIC_V2, ConfigDict

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import List, Optional
from .._models import BaseModel

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing_extensions import Literal
from .._models import BaseModel

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing_extensions import Literal
from .._models import StrictBase

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing_extensions import Literal, Required, TypedDict

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import Optional
from typing_extensions import Literal

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import Optional
from typing_extensions import Literal

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import Optional
from .._models import StrictBase

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing import Optional

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import Dict
from typing_extensions import TypeAlias

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing import Dict

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import Dict
from typing_extensions import TypeAlias

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing_extensions import Literal, TypeAlias
__all__ = ["LossFnType"]

View file

@ -1,7 +1,7 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing import Any, Optional
from typing_extensions import Literal, Required, TypedDict
from .lora_config_param import LoraConfigParam
@ -14,4 +14,7 @@ class ModelCreateParams(TypedDict, total=False):
lora_config: LoraConfigParam
"""Optional metadata about this model/training run, set by the end-user"""
user_metadata: Optional[dict[str, Any]]
type: Literal["create_model"] = "create_model"

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing_extensions import Literal, Required, TypedDict

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing_extensions import TypeAlias
__all__ = ["ModelID"]

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import List
from .._models import StrictBase

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import Union
from typing_extensions import Annotated, TypeAlias

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing import Union

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing import Iterable

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing_extensions import Literal, Required, TypedDict

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import Optional
from typing_extensions import Literal

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import Dict, Optional
from .._models import BaseModel

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing_extensions import TypeAlias
__all__ = ["RequestID"]

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import Optional
from typing_extensions import Literal

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import List, Optional, Sequence
from typing_extensions import Literal

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import List, Optional
from .._models import BaseModel

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing_extensions import Literal, Required, TypedDict

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import Optional, Sequence, Union
from .._models import BaseModel

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing import List, Union, Iterable, Optional

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing_extensions import Literal, Required, TypedDict

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import Optional
from typing_extensions import Literal

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import Optional
from typing_extensions import Literal

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import Optional
from typing_extensions import Literal

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import Optional
from typing_extensions import Literal

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from datetime import datetime
from .._models import BaseModel

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing import Union

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from datetime import datetime
from .._models import BaseModel

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing import Union

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing_extensions import Literal, TypeAlias
__all__ = ["Severity"]

View file

@ -1,3 +1 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from .untyped_api_future import UntypedAPIFuture as UntypedAPIFuture

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import Optional
from ..._compat import PYDANTIC_V2, ConfigDict

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing_extensions import Literal, TypeAlias
__all__ = ["StopReason"]

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import List
from .._models import BaseModel

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import Union
from typing_extensions import TypeAlias

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing import Union

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing_extensions import Literal
from .._models import BaseModel

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing import Iterable

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing import List, Union, Optional, TYPE_CHECKING, Any
from .._models import StrictBase

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing import Union, Iterable, Optional

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from typing_extensions import Literal, TypeAlias
__all__ = ["TensorDtype"]

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing_extensions import Required, TypedDict

View file

@ -1,5 +1,3 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
from __future__ import annotations
from typing_extensions import Required, TypedDict

Some files were not shown because too many files have changed in this diff Show more