mirror of
https://github.com/thinking-machines-lab/tinker.git
synced 2026-04-19 12:58:01 +00:00
Sync contents
This commit is contained in:
parent
9ba155a34d
commit
422bc7b9f1
122 changed files with 516 additions and 477 deletions
|
|
@ -1,4 +1,4 @@
|
||||||
{
|
{
|
||||||
"last_synced_sha": "b099cce3c688f59468230f58027c38c2a81bcc13",
|
"last_synced_sha": "db70edcca9cdd2d06faa9e4af37538547f308b14",
|
||||||
"last_sync_time": "2025-10-01T18:17:16.846167"
|
"last_sync_time": "2025-10-15T16:37:21.808358"
|
||||||
}
|
}
|
||||||
|
|
@ -1,4 +0,0 @@
|
||||||
File generated from our OpenAPI spec by Stainless.
|
|
||||||
|
|
||||||
This directory can be used to store example files demonstrating usage of this SDK.
|
|
||||||
It is ignored by Stainless code generation and its content (other than this keep file) won't be touched.
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
[project]
|
[project]
|
||||||
name = "tinker"
|
name = "tinker"
|
||||||
version = "0.1.5"
|
version = "0.2.0"
|
||||||
description = "The official Python SDK for the tinker API"
|
description = "The official Python SDK for the tinker API"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = "Apache-2.0"
|
license = "Apache-2.0"
|
||||||
|
|
|
||||||
|
|
@ -1,27 +0,0 @@
|
||||||
#!/usr/bin/env bash
|
|
||||||
set -exuo pipefail
|
|
||||||
|
|
||||||
FILENAME=$(basename dist/*.whl)
|
|
||||||
|
|
||||||
RESPONSE=$(curl -X POST "$URL?filename=$FILENAME" \
|
|
||||||
-H "Authorization: Bearer $AUTH" \
|
|
||||||
-H "Content-Type: application/json")
|
|
||||||
|
|
||||||
SIGNED_URL=$(echo "$RESPONSE" | jq -r '.url')
|
|
||||||
|
|
||||||
if [[ "$SIGNED_URL" == "null" ]]; then
|
|
||||||
echo -e "\033[31mFailed to get signed URL.\033[0m"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
UPLOAD_RESPONSE=$(curl -v -X PUT \
|
|
||||||
-H "Content-Type: binary/octet-stream" \
|
|
||||||
--data-binary "@dist/$FILENAME" "$SIGNED_URL" 2>&1)
|
|
||||||
|
|
||||||
if echo "$UPLOAD_RESPONSE" | grep -q "HTTP/[0-9.]* 200"; then
|
|
||||||
echo -e "\033[32mUploaded build to Stainless storage.\033[0m"
|
|
||||||
echo -e "\033[32mInstallation: pip install 'https://pkg.stainless.com/s/tinker-python/$SHA/$FILENAME'\033[0m"
|
|
||||||
else
|
|
||||||
echo -e "\033[31mFailed to upload artifact.\033[0m"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
import typing as _t
|
import typing as _t
|
||||||
|
|
||||||
from . import types
|
from . import types
|
||||||
|
|
|
||||||
|
|
@ -1924,7 +1924,6 @@ def get_platform() -> Platform:
|
||||||
@lru_cache(maxsize=None)
|
@lru_cache(maxsize=None)
|
||||||
def platform_headers(version: str, *, platform: Platform | None) -> Dict[str, str]:
|
def platform_headers(version: str, *, platform: Platform | None) -> Dict[str, str]:
|
||||||
return {
|
return {
|
||||||
"X-Stainless-Lang": "python",
|
|
||||||
"X-Stainless-Package-Version": version,
|
"X-Stainless-Package-Version": version,
|
||||||
"X-Stainless-OS": str(platform or get_platform()),
|
"X-Stainless-OS": str(platform or get_platform()),
|
||||||
"X-Stainless-Arch": str(get_architecture()),
|
"X-Stainless-Arch": str(get_architecture()),
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
RAW_RESPONSE_HEADER = "X-Stainless-Raw-Response"
|
RAW_RESPONSE_HEADER = "X-Stainless-Raw-Response"
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import time
|
import time
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,2 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
__title__ = "tinker"
|
__title__ = "tinker"
|
||||||
__version__ = "0.0.1-alpha.1"
|
__version__ = "0.0.1-alpha.1"
|
||||||
|
|
|
||||||
|
|
@ -1,4 +0,0 @@
|
||||||
File generated from our OpenAPI spec by Stainless.
|
|
||||||
|
|
||||||
This directory can be used to store custom files to expand the SDK.
|
|
||||||
It is ignored by Stainless code generation and its content (other than this keep file) won't be touched.
|
|
||||||
|
|
@ -14,7 +14,7 @@ import tinker
|
||||||
from tinker import types
|
from tinker import types
|
||||||
from tinker.lib.client_connection_pool_type import ClientConnectionPoolType
|
from tinker.lib.client_connection_pool_type import ClientConnectionPoolType
|
||||||
from tinker.lib.public_interfaces.api_future import APIFuture
|
from tinker.lib.public_interfaces.api_future import APIFuture
|
||||||
from tinker.lib.telemetry import Telemetry
|
from tinker.lib.telemetry import Telemetry, is_user_error
|
||||||
|
|
||||||
from .._models import BaseModel
|
from .._models import BaseModel
|
||||||
from .retryable_exception import RetryableException
|
from .retryable_exception import RetryableException
|
||||||
|
|
@ -121,6 +121,7 @@ class _APIFuture(APIFuture[T]): # pyright: ignore[reportUnusedClass]
|
||||||
except tinker.APIStatusError as e:
|
except tinker.APIStatusError as e:
|
||||||
connection_error_retries = 0
|
connection_error_retries = 0
|
||||||
should_retry = e.status_code == 408 or e.status_code in range(500, 600)
|
should_retry = e.status_code == 408 or e.status_code in range(500, 600)
|
||||||
|
user_error = is_user_error(e)
|
||||||
if telemetry := self.get_telemetry():
|
if telemetry := self.get_telemetry():
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
telemetry.log(
|
telemetry.log(
|
||||||
|
|
@ -131,10 +132,11 @@ class _APIFuture(APIFuture[T]): # pyright: ignore[reportUnusedClass]
|
||||||
"status_code": e.status_code,
|
"status_code": e.status_code,
|
||||||
"exception": str(e),
|
"exception": str(e),
|
||||||
"should_retry": should_retry,
|
"should_retry": should_retry,
|
||||||
|
"is_user_error": user_error,
|
||||||
"iteration": iteration,
|
"iteration": iteration,
|
||||||
"elapsed_time": current_time - start_time,
|
"elapsed_time": current_time - start_time,
|
||||||
},
|
},
|
||||||
severity="WARNING" if should_retry else "ERROR",
|
severity="WARNING" if should_retry or user_error else "ERROR",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Retry 408s until we time out
|
# Retry 408s until we time out
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,7 @@ from tinker._exceptions import APIConnectionError, APIStatusError
|
||||||
from tinker.lib.async_tinker_provider import AsyncTinkerProvider
|
from tinker.lib.async_tinker_provider import AsyncTinkerProvider
|
||||||
from tinker.lib.client_connection_pool_type import ClientConnectionPoolType
|
from tinker.lib.client_connection_pool_type import ClientConnectionPoolType
|
||||||
from tinker.lib.public_interfaces.api_future import AwaitableConcurrentFuture
|
from tinker.lib.public_interfaces.api_future import AwaitableConcurrentFuture
|
||||||
from tinker.lib.telemetry import Telemetry, init_telemetry
|
from tinker.lib.telemetry import Telemetry, init_telemetry, is_user_error
|
||||||
from tinker.lib.telemetry_provider import TelemetryProvider
|
from tinker.lib.telemetry_provider import TelemetryProvider
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -190,6 +190,7 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
|
||||||
return await func(*args, **kwargs)
|
return await func(*args, **kwargs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
is_retryable = self._is_retryable_exception(e)
|
is_retryable = self._is_retryable_exception(e)
|
||||||
|
user_error = is_user_error(e)
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
elapsed_time = current_time - start_time
|
elapsed_time = current_time - start_time
|
||||||
if telemetry := self.get_telemetry():
|
if telemetry := self.get_telemetry():
|
||||||
|
|
@ -208,12 +209,13 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
|
||||||
else None,
|
else None,
|
||||||
"status_code": getattr(e, "status_code", None),
|
"status_code": getattr(e, "status_code", None),
|
||||||
"is_retryable": is_retryable,
|
"is_retryable": is_retryable,
|
||||||
|
"is_user_error": user_error,
|
||||||
"attempt_count": attempt_count,
|
"attempt_count": attempt_count,
|
||||||
"start_time": start_time,
|
"start_time": start_time,
|
||||||
"current_time": current_time,
|
"current_time": current_time,
|
||||||
"elapsed_time": elapsed_time,
|
"elapsed_time": elapsed_time,
|
||||||
},
|
},
|
||||||
severity="WARNING" if is_retryable else "ERROR",
|
severity="WARNING" if is_retryable or user_error else "ERROR",
|
||||||
)
|
)
|
||||||
if is_retryable and elapsed_time < MAX_WAIT_TIME:
|
if is_retryable and elapsed_time < MAX_WAIT_TIME:
|
||||||
# Apply exponential backoff
|
# Apply exponential backoff
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,7 @@ class RestClient(TelemetryProvider):
|
||||||
- list_checkpoints() - list available model checkpoints (both training and sampler)
|
- list_checkpoints() - list available model checkpoints (both training and sampler)
|
||||||
- get_training_run() - get model information and metadata as ModelEntry
|
- get_training_run() - get model information and metadata as ModelEntry
|
||||||
- delete_checkpoint() - delete an existing checkpoint for a training run
|
- delete_checkpoint() - delete an existing checkpoint for a training run
|
||||||
- download_sampler_weights_archive() - download sampler weights checkpoint as tar.gz archive
|
- get_checkpoint_archive_url() - get signed URL to download checkpoint archive
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
holder: Internal client managing HTTP connections and async operations
|
holder: Internal client managing HTTP connections and async operations
|
||||||
|
|
@ -55,6 +55,7 @@ class RestClient(TelemetryProvider):
|
||||||
self, training_run_id: types.ModelID
|
self, training_run_id: types.ModelID
|
||||||
) -> AwaitableConcurrentFuture[types.TrainingRun]:
|
) -> AwaitableConcurrentFuture[types.TrainingRun]:
|
||||||
"""Internal method to submit get model request."""
|
"""Internal method to submit get model request."""
|
||||||
|
|
||||||
async def _get_training_run_async() -> types.TrainingRun:
|
async def _get_training_run_async() -> types.TrainingRun:
|
||||||
async def _send_request() -> types.TrainingRun:
|
async def _send_request() -> types.TrainingRun:
|
||||||
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
||||||
|
|
@ -69,7 +70,9 @@ class RestClient(TelemetryProvider):
|
||||||
|
|
||||||
@sync_only
|
@sync_only
|
||||||
@capture_exceptions(fatal=True)
|
@capture_exceptions(fatal=True)
|
||||||
def get_training_run(self, training_run_id: types.ModelID) -> ConcurrentFuture[types.TrainingRun]:
|
def get_training_run(
|
||||||
|
self, training_run_id: types.ModelID
|
||||||
|
) -> ConcurrentFuture[types.TrainingRun]:
|
||||||
"""Get training run info.
|
"""Get training run info.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -103,7 +106,9 @@ class RestClient(TelemetryProvider):
|
||||||
|
|
||||||
@sync_only
|
@sync_only
|
||||||
@capture_exceptions(fatal=True)
|
@capture_exceptions(fatal=True)
|
||||||
def get_training_run_by_tinker_path(self, tinker_path: str) -> ConcurrentFuture[types.TrainingRun]:
|
def get_training_run_by_tinker_path(
|
||||||
|
self, tinker_path: str
|
||||||
|
) -> ConcurrentFuture[types.TrainingRun]:
|
||||||
"""Get training run info.
|
"""Get training run info.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -145,6 +150,7 @@ class RestClient(TelemetryProvider):
|
||||||
self, limit: int = 20, offset: int = 0
|
self, limit: int = 20, offset: int = 0
|
||||||
) -> AwaitableConcurrentFuture[types.TrainingRunsResponse]:
|
) -> AwaitableConcurrentFuture[types.TrainingRunsResponse]:
|
||||||
"""Internal method to submit list training runs request."""
|
"""Internal method to submit list training runs request."""
|
||||||
|
|
||||||
async def _list_training_runs_async() -> types.TrainingRunsResponse:
|
async def _list_training_runs_async() -> types.TrainingRunsResponse:
|
||||||
async def _send_request() -> types.TrainingRunsResponse:
|
async def _send_request() -> types.TrainingRunsResponse:
|
||||||
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
||||||
|
|
@ -210,6 +216,7 @@ class RestClient(TelemetryProvider):
|
||||||
self, training_run_id: types.ModelID
|
self, training_run_id: types.ModelID
|
||||||
) -> AwaitableConcurrentFuture[types.CheckpointsListResponse]:
|
) -> AwaitableConcurrentFuture[types.CheckpointsListResponse]:
|
||||||
"""Internal method to submit list model checkpoints request."""
|
"""Internal method to submit list model checkpoints request."""
|
||||||
|
|
||||||
async def _list_checkpoints_async():
|
async def _list_checkpoints_async():
|
||||||
async def _send_request():
|
async def _send_request():
|
||||||
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
||||||
|
|
@ -221,7 +228,9 @@ class RestClient(TelemetryProvider):
|
||||||
|
|
||||||
@sync_only
|
@sync_only
|
||||||
@capture_exceptions(fatal=True)
|
@capture_exceptions(fatal=True)
|
||||||
def list_checkpoints(self, training_run_id: types.ModelID) -> ConcurrentFuture[types.CheckpointsListResponse]:
|
def list_checkpoints(
|
||||||
|
self, training_run_id: types.ModelID
|
||||||
|
) -> ConcurrentFuture[types.CheckpointsListResponse]:
|
||||||
"""List available checkpoints (both training and sampler).
|
"""List available checkpoints (both training and sampler).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -242,7 +251,9 @@ class RestClient(TelemetryProvider):
|
||||||
return self._list_checkpoints_submit(training_run_id).future()
|
return self._list_checkpoints_submit(training_run_id).future()
|
||||||
|
|
||||||
@capture_exceptions(fatal=True)
|
@capture_exceptions(fatal=True)
|
||||||
async def list_checkpoints_async(self, training_run_id: types.ModelID) -> types.CheckpointsListResponse:
|
async def list_checkpoints_async(
|
||||||
|
self, training_run_id: types.ModelID
|
||||||
|
) -> types.CheckpointsListResponse:
|
||||||
"""Async version of list_checkpoints.
|
"""Async version of list_checkpoints.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -261,64 +272,66 @@ class RestClient(TelemetryProvider):
|
||||||
"""
|
"""
|
||||||
return await self._list_checkpoints_submit(training_run_id)
|
return await self._list_checkpoints_submit(training_run_id)
|
||||||
|
|
||||||
def _download_checkpoint_archive_submit(
|
def _get_checkpoint_archive_url_submit(
|
||||||
self, training_run_id: types.ModelID, checkpoint_id: str
|
self, training_run_id: types.ModelID, checkpoint_id: str
|
||||||
) -> AwaitableConcurrentFuture[bytes]:
|
) -> AwaitableConcurrentFuture[types.CheckpointArchiveUrlResponse]:
|
||||||
"""Internal method to submit download checkpoint archive request."""
|
"""Internal method to submit get checkpoint archive URL request."""
|
||||||
async def _download_checkpoint_archive_async():
|
|
||||||
async def _send_request():
|
async def _get_checkpoint_archive_url_async() -> types.CheckpointArchiveUrlResponse:
|
||||||
|
async def _send_request() -> types.CheckpointArchiveUrlResponse:
|
||||||
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
||||||
return await client.get(
|
return await client.weights.get_checkpoint_archive_url(
|
||||||
f"/api/v1/training_runs/{training_run_id}/checkpoints/{checkpoint_id}/archive",
|
model_id=training_run_id,
|
||||||
cast_to=bytes,
|
checkpoint_id=checkpoint_id,
|
||||||
options={"headers": {"accept": "application/gzip"}},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return await self.holder.execute_with_retries(_send_request)
|
return await self.holder.execute_with_retries(_send_request)
|
||||||
|
|
||||||
return self.holder.run_coroutine_threadsafe(_download_checkpoint_archive_async())
|
return self.holder.run_coroutine_threadsafe(_get_checkpoint_archive_url_async())
|
||||||
|
|
||||||
@sync_only
|
@sync_only
|
||||||
@capture_exceptions(fatal=True)
|
@capture_exceptions(fatal=True)
|
||||||
def download_checkpoint_archive(
|
def get_checkpoint_archive_url(
|
||||||
self, training_run_id: types.ModelID, checkpoint_id: str
|
self, training_run_id: types.ModelID, checkpoint_id: str
|
||||||
) -> ConcurrentFuture[bytes]:
|
) -> ConcurrentFuture[types.CheckpointArchiveUrlResponse]:
|
||||||
"""Download checkpoint as a tar.gz archive.
|
"""Get signed URL to download checkpoint archive.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
training_run_id: The training run ID to download weights for
|
training_run_id: The training run ID to download weights for
|
||||||
checkpoint_id: The checkpoint ID to download
|
checkpoint_id: The checkpoint ID to download
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A Future containing the archive data as bytes
|
A Future containing the CheckpointArchiveUrlResponse with signed URL and expiration
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> future = rest_client.download_checkpoint_archive("run-id", "checkpoint-123")
|
>>> future = rest_client.get_checkpoint_archive_url("run-id", "checkpoint-123")
|
||||||
>>> archive_data = future.result()
|
>>> response = future.result()
|
||||||
>>> with open(f"model-checkpoint.tar.gz", "wb") as f:
|
>>> print(f"Download URL: {response.url}")
|
||||||
... f.write(archive_data)
|
>>> print(f"Expires at: {response.expires_at}")
|
||||||
|
>>> # Use the URL to download the archive with your preferred HTTP client
|
||||||
"""
|
"""
|
||||||
return self._download_checkpoint_archive_submit(training_run_id, checkpoint_id).future()
|
return self._get_checkpoint_archive_url_submit(training_run_id, checkpoint_id).future()
|
||||||
|
|
||||||
@capture_exceptions(fatal=True)
|
@capture_exceptions(fatal=True)
|
||||||
async def download_checkpoint_archive_async(
|
async def get_checkpoint_archive_url_async(
|
||||||
self, training_run_id: types.ModelID, checkpoint_id: str
|
self, training_run_id: types.ModelID, checkpoint_id: str
|
||||||
) -> bytes:
|
) -> types.CheckpointArchiveUrlResponse:
|
||||||
"""Async version of download_checkpoint_archive.
|
"""Async version of get_checkpoint_archive_url.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
training_run_id: The model ID to download weights for
|
training_run_id: The model ID to download weights for
|
||||||
checkpoint_id: The checkpoint ID to download
|
checkpoint_id: The checkpoint ID to download
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Archive data as bytes
|
CheckpointArchiveUrlResponse with signed URL and expiration
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> archive_data = await rest_client.download_checkpoint_archive_async("run-id", "checkpoint-123")
|
>>> response = await rest_client.get_checkpoint_archive_url_async("run-id", "checkpoint-123")
|
||||||
>>> with open(f"model-checkpoint.tar.gz", "wb") as f:
|
>>> print(f"Download URL: {response.url}")
|
||||||
... f.write(archive_data)
|
>>> print(f"Expires at: {response.expires_at}")
|
||||||
|
>>> # Use the URL to download the archive with your preferred HTTP client
|
||||||
"""
|
"""
|
||||||
return await self._download_checkpoint_archive_submit(training_run_id, checkpoint_id)
|
return await self._get_checkpoint_archive_url_submit(training_run_id, checkpoint_id)
|
||||||
|
|
||||||
def _delete_checkpoint_submit(
|
def _delete_checkpoint_submit(
|
||||||
self, training_run_id: types.ModelID, checkpoint_id: str
|
self, training_run_id: types.ModelID, checkpoint_id: str
|
||||||
|
|
@ -339,13 +352,17 @@ class RestClient(TelemetryProvider):
|
||||||
|
|
||||||
@sync_only
|
@sync_only
|
||||||
@capture_exceptions(fatal=True)
|
@capture_exceptions(fatal=True)
|
||||||
def delete_checkpoint(self, training_run_id: types.ModelID, checkpoint_id: str) -> ConcurrentFuture[None]:
|
def delete_checkpoint(
|
||||||
|
self, training_run_id: types.ModelID, checkpoint_id: str
|
||||||
|
) -> ConcurrentFuture[None]:
|
||||||
"""Delete a checkpoint for a training run."""
|
"""Delete a checkpoint for a training run."""
|
||||||
|
|
||||||
return self._delete_checkpoint_submit(training_run_id, checkpoint_id).future()
|
return self._delete_checkpoint_submit(training_run_id, checkpoint_id).future()
|
||||||
|
|
||||||
@capture_exceptions(fatal=True)
|
@capture_exceptions(fatal=True)
|
||||||
async def delete_checkpoint_async(self, training_run_id: types.ModelID, checkpoint_id: str) -> None:
|
async def delete_checkpoint_async(
|
||||||
|
self, training_run_id: types.ModelID, checkpoint_id: str
|
||||||
|
) -> None:
|
||||||
"""Async version of delete_checkpoint."""
|
"""Async version of delete_checkpoint."""
|
||||||
|
|
||||||
await self._delete_checkpoint_submit(training_run_id, checkpoint_id)
|
await self._delete_checkpoint_submit(training_run_id, checkpoint_id)
|
||||||
|
|
@ -356,42 +373,53 @@ class RestClient(TelemetryProvider):
|
||||||
"""Delete a checkpoint referenced by a tinker path."""
|
"""Delete a checkpoint referenced by a tinker path."""
|
||||||
|
|
||||||
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
|
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
|
||||||
return self._delete_checkpoint_submit(parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id).future()
|
return self._delete_checkpoint_submit(
|
||||||
|
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
|
||||||
|
).future()
|
||||||
|
|
||||||
@capture_exceptions(fatal=True)
|
@capture_exceptions(fatal=True)
|
||||||
async def delete_checkpoint_from_tinker_path_async(self, tinker_path: str) -> None:
|
async def delete_checkpoint_from_tinker_path_async(self, tinker_path: str) -> None:
|
||||||
"""Async version of delete_checkpoint_from_tinker_path."""
|
"""Async version of delete_checkpoint_from_tinker_path."""
|
||||||
|
|
||||||
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
|
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
|
||||||
await self._delete_checkpoint_submit(parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id)
|
await self._delete_checkpoint_submit(
|
||||||
|
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
|
||||||
|
)
|
||||||
|
|
||||||
def get_telemetry(self) -> Telemetry | None:
|
def get_telemetry(self) -> Telemetry | None:
|
||||||
return self.holder.get_telemetry()
|
return self.holder.get_telemetry()
|
||||||
|
|
||||||
@sync_only
|
@sync_only
|
||||||
@capture_exceptions(fatal=True)
|
@capture_exceptions(fatal=True)
|
||||||
def download_checkpoint_archive_from_tinker_path(
|
def get_checkpoint_archive_url_from_tinker_path(
|
||||||
self, tinker_path: str
|
self, tinker_path: str
|
||||||
) -> ConcurrentFuture[bytes]:
|
) -> ConcurrentFuture[types.CheckpointArchiveUrlResponse]:
|
||||||
"""Download checkpoint as a tar.gz archive.
|
"""Get signed URL to download checkpoint archive.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tinker_path: The tinker path to the checkpoint
|
tinker_path: The tinker path to the checkpoint
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A Future containing the archive data as bytes
|
A Future containing the CheckpointArchiveUrlResponse with signed URL and expiration
|
||||||
"""
|
"""
|
||||||
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
|
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
|
||||||
return self._download_checkpoint_archive_submit(parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id).future()
|
return self._get_checkpoint_archive_url_submit(
|
||||||
|
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
|
||||||
|
).future()
|
||||||
|
|
||||||
@capture_exceptions(fatal=True)
|
@capture_exceptions(fatal=True)
|
||||||
async def download_checkpoint_archive_from_tinker_path_async(
|
async def get_checkpoint_archive_url_from_tinker_path_async(
|
||||||
self, tinker_path: str
|
self, tinker_path: str
|
||||||
) -> bytes:
|
) -> types.CheckpointArchiveUrlResponse:
|
||||||
"""Async version of download_checkpoint_archive_from_tinker_path.
|
"""Async version of get_checkpoint_archive_url_from_tinker_path.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tinker_path: The tinker path to the checkpoint
|
tinker_path: The tinker path to the checkpoint
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CheckpointArchiveUrlResponse with signed URL and expiration
|
||||||
"""
|
"""
|
||||||
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
|
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
|
||||||
return await self._download_checkpoint_archive_submit(parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id)
|
return await self._get_checkpoint_archive_url_submit(
|
||||||
|
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -78,13 +78,18 @@ class ServiceClient(TelemetryProvider):
|
||||||
return await self._get_server_capabilities_submit()
|
return await self._get_server_capabilities_submit()
|
||||||
|
|
||||||
def _create_model_submit(
|
def _create_model_submit(
|
||||||
self, base_model: str, lora_config: types.LoraConfig
|
self,
|
||||||
|
base_model: str,
|
||||||
|
lora_config: types.LoraConfig,
|
||||||
|
user_metadata: dict[str, str] | None = None,
|
||||||
) -> AwaitableConcurrentFuture[types.ModelID]:
|
) -> AwaitableConcurrentFuture[types.ModelID]:
|
||||||
async def _create_model_async():
|
async def _create_model_async():
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
||||||
future = await client.models.create(
|
future = await client.models.create(
|
||||||
base_model=base_model, lora_config=_to_lora_config_params(lora_config)
|
base_model=base_model,
|
||||||
|
lora_config=_to_lora_config_params(lora_config),
|
||||||
|
user_metadata=user_metadata,
|
||||||
)
|
)
|
||||||
create_model_response = await _APIFuture(
|
create_model_response = await _APIFuture(
|
||||||
types.CreateModelResponse,
|
types.CreateModelResponse,
|
||||||
|
|
@ -107,6 +112,7 @@ class ServiceClient(TelemetryProvider):
|
||||||
train_mlp: bool = True,
|
train_mlp: bool = True,
|
||||||
train_attn: bool = True,
|
train_attn: bool = True,
|
||||||
train_unembed: bool = True,
|
train_unembed: bool = True,
|
||||||
|
user_metadata: dict[str, str] | None = None,
|
||||||
) -> TrainingClient:
|
) -> TrainingClient:
|
||||||
assert any([train_mlp, train_attn, train_unembed]), (
|
assert any([train_mlp, train_attn, train_unembed]), (
|
||||||
"At least one of train_mlp, train_attn, or train_unembed must be True"
|
"At least one of train_mlp, train_attn, or train_unembed must be True"
|
||||||
|
|
@ -120,6 +126,7 @@ class ServiceClient(TelemetryProvider):
|
||||||
train_attn=train_attn,
|
train_attn=train_attn,
|
||||||
train_unembed=train_unembed,
|
train_unembed=train_unembed,
|
||||||
),
|
),
|
||||||
|
user_metadata=user_metadata,
|
||||||
).result()
|
).result()
|
||||||
logger.info(f"Creating TrainingClient for {model_id=}")
|
logger.info(f"Creating TrainingClient for {model_id=}")
|
||||||
return self.create_training_client(model_id)
|
return self.create_training_client(model_id)
|
||||||
|
|
@ -133,6 +140,7 @@ class ServiceClient(TelemetryProvider):
|
||||||
train_mlp: bool = True,
|
train_mlp: bool = True,
|
||||||
train_attn: bool = True,
|
train_attn: bool = True,
|
||||||
train_unembed: bool = True,
|
train_unembed: bool = True,
|
||||||
|
user_metadata: dict[str, str] | None = None,
|
||||||
) -> TrainingClient:
|
) -> TrainingClient:
|
||||||
assert any([train_mlp, train_attn, train_unembed]), (
|
assert any([train_mlp, train_attn, train_unembed]), (
|
||||||
"At least one of train_mlp, train_attn, or train_unembed must be True"
|
"At least one of train_mlp, train_attn, or train_unembed must be True"
|
||||||
|
|
@ -146,6 +154,7 @@ class ServiceClient(TelemetryProvider):
|
||||||
train_attn=train_attn,
|
train_attn=train_attn,
|
||||||
train_unembed=train_unembed,
|
train_unembed=train_unembed,
|
||||||
),
|
),
|
||||||
|
user_metadata=user_metadata,
|
||||||
)
|
)
|
||||||
logger.info(f"Creating TrainingClient for {model_id=}")
|
logger.info(f"Creating TrainingClient for {model_id=}")
|
||||||
return self.create_training_client(model_id)
|
return self.create_training_client(model_id)
|
||||||
|
|
@ -165,6 +174,7 @@ class ServiceClient(TelemetryProvider):
|
||||||
training_client = self.create_lora_training_client(
|
training_client = self.create_lora_training_client(
|
||||||
base_model=training_run.base_model,
|
base_model=training_run.base_model,
|
||||||
rank=training_run.lora_rank,
|
rank=training_run.lora_rank,
|
||||||
|
user_metadata=training_run.user_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
training_client.load_state(path).result()
|
training_client.load_state(path).result()
|
||||||
|
|
@ -181,6 +191,7 @@ class ServiceClient(TelemetryProvider):
|
||||||
training_client = await self.create_lora_training_client_async(
|
training_client = await self.create_lora_training_client_async(
|
||||||
base_model=training_run.base_model,
|
base_model=training_run.base_model,
|
||||||
rank=training_run.lora_rank,
|
rank=training_run.lora_rank,
|
||||||
|
user_metadata=training_run.user_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
load_future = await training_client.load_state_async(path)
|
load_future = await training_client.load_state_async(path)
|
||||||
|
|
@ -230,8 +241,6 @@ def _get_default_headers() -> dict[str, str]:
|
||||||
if (api_key := os.environ.get("TINKER_API_KEY", "")) and "X-API-Key" not in headers:
|
if (api_key := os.environ.get("TINKER_API_KEY", "")) and "X-API-Key" not in headers:
|
||||||
headers["X-API-Key"] = api_key
|
headers["X-API-Key"] = api_key
|
||||||
|
|
||||||
headers["X-Username"] = os.environ.get("USER", "")
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
client_id := os.environ.get("CLOUDFLARE_ACCESS_CLIENT_ID")
|
client_id := os.environ.get("CLOUDFLARE_ACCESS_CLIENT_ID")
|
||||||
) and "CF-Access-Client-Id" not in headers:
|
) and "CF-Access-Client-Id" not in headers:
|
||||||
|
|
|
||||||
|
|
@ -438,39 +438,6 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
|
||||||
) -> APIFuture[types.SaveWeightsForSamplerResponse]:
|
) -> APIFuture[types.SaveWeightsForSamplerResponse]:
|
||||||
return self.save_weights_for_sampler(name)
|
return self.save_weights_for_sampler(name)
|
||||||
|
|
||||||
@capture_exceptions(fatal=True)
|
|
||||||
def unload_model(
|
|
||||||
self,
|
|
||||||
) -> APIFuture[types.UnloadModelResponse]:
|
|
||||||
request_id = self._get_request_id()
|
|
||||||
|
|
||||||
@capture_exceptions(fatal=True)
|
|
||||||
async def _unload_model_async():
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
async def _send_request():
|
|
||||||
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
|
||||||
return await client.models.unload(
|
|
||||||
model_id=self._guaranteed_model_id(),
|
|
||||||
idempotency_key=self._make_idempotency_key(request_id),
|
|
||||||
)
|
|
||||||
|
|
||||||
async with self._take_turn(request_id):
|
|
||||||
future = await self.holder.execute_with_retries(_send_request)
|
|
||||||
return await _APIFuture(
|
|
||||||
types.UnloadModelResponse,
|
|
||||||
self.holder,
|
|
||||||
future,
|
|
||||||
request_start_time=start_time,
|
|
||||||
request_type="UnloadModel",
|
|
||||||
queue_state_observer=self,
|
|
||||||
)
|
|
||||||
|
|
||||||
return self.holder.run_coroutine_threadsafe(_unload_model_async())
|
|
||||||
|
|
||||||
async def unload_model_async(self) -> APIFuture[types.UnloadModelResponse]:
|
|
||||||
return self.unload_model()
|
|
||||||
|
|
||||||
def _get_info_submit(self) -> AwaitableConcurrentFuture[types.GetInfoResponse]:
|
def _get_info_submit(self) -> AwaitableConcurrentFuture[types.GetInfoResponse]:
|
||||||
request_id = self._get_request_id()
|
request_id = self._get_request_id()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ from typing import Any, Awaitable, Callable, Generic, Type, TypeVar
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
import tinker
|
import tinker
|
||||||
from tinker.lib.telemetry import Telemetry
|
from tinker.lib.telemetry import Telemetry, is_user_error
|
||||||
|
|
||||||
from .._constants import (
|
from .._constants import (
|
||||||
DEFAULT_CONNECTION_LIMITS,
|
DEFAULT_CONNECTION_LIMITS,
|
||||||
|
|
@ -191,7 +191,7 @@ class RetryHandler(Generic[T]): # noqa: UP046
|
||||||
exception_str = f"{type(e).__name__}: {str(e) or 'No error message'}"
|
exception_str = f"{type(e).__name__}: {str(e) or 'No error message'}"
|
||||||
self._errors_since_last_retry[exception_str] += 1
|
self._errors_since_last_retry[exception_str] += 1
|
||||||
should_retry = self._should_retry(e)
|
should_retry = self._should_retry(e)
|
||||||
|
user_error = is_user_error(e)
|
||||||
if telemetry := self.get_telemetry():
|
if telemetry := self.get_telemetry():
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
telemetry.log(
|
telemetry.log(
|
||||||
|
|
@ -209,12 +209,13 @@ class RetryHandler(Generic[T]): # noqa: UP046
|
||||||
else None,
|
else None,
|
||||||
"status_code": getattr(e, "status_code", None),
|
"status_code": getattr(e, "status_code", None),
|
||||||
"should_retry": should_retry,
|
"should_retry": should_retry,
|
||||||
|
"is_user_error": user_error,
|
||||||
"attempt_count": attempt_count,
|
"attempt_count": attempt_count,
|
||||||
"start_time": start_time,
|
"start_time": start_time,
|
||||||
"current_time": current_time,
|
"current_time": current_time,
|
||||||
"elapsed_time": current_time - start_time,
|
"elapsed_time": current_time - start_time,
|
||||||
},
|
},
|
||||||
severity="WARNING" if should_retry else "ERROR",
|
severity="WARNING" if should_retry or user_error else "ERROR",
|
||||||
)
|
)
|
||||||
|
|
||||||
if not should_retry:
|
if not should_retry:
|
||||||
|
|
|
||||||
|
|
@ -159,7 +159,7 @@ class Telemetry:
|
||||||
return self._log(self._generic_event(event_name, event_data, severity))
|
return self._log(self._generic_event(event_name, event_data, severity))
|
||||||
|
|
||||||
async def log_exception(self, exception: BaseException, severity: Severity = "ERROR") -> bool:
|
async def log_exception(self, exception: BaseException, severity: Severity = "ERROR") -> bool:
|
||||||
logged = self._log(self._exception_event(exception, severity))
|
logged = self._log(self._exception_or_user_error_event(exception, severity))
|
||||||
# trigger flush but don't block on it
|
# trigger flush but don't block on it
|
||||||
self._trigger_flush()
|
self._trigger_flush()
|
||||||
return logged
|
return logged
|
||||||
|
|
@ -167,7 +167,9 @@ class Telemetry:
|
||||||
async def log_fatal_exception(
|
async def log_fatal_exception(
|
||||||
self, exception: BaseException, severity: Severity = "ERROR"
|
self, exception: BaseException, severity: Severity = "ERROR"
|
||||||
) -> bool:
|
) -> bool:
|
||||||
logged = self._log(self._exception_event(exception, severity), self._session_end_event())
|
logged = self._log(
|
||||||
|
self._exception_or_user_error_event(exception, severity), self._session_end_event()
|
||||||
|
)
|
||||||
self._trigger_flush()
|
self._trigger_flush()
|
||||||
# wait for the flush to complete
|
# wait for the flush to complete
|
||||||
_ = await self._wait_until_drained()
|
_ = await self._wait_until_drained()
|
||||||
|
|
@ -177,7 +179,7 @@ class Telemetry:
|
||||||
|
|
||||||
@sync_only
|
@sync_only
|
||||||
def log_exception_sync(self, exception: BaseException, severity: Severity = "ERROR") -> bool:
|
def log_exception_sync(self, exception: BaseException, severity: Severity = "ERROR") -> bool:
|
||||||
logged = self._log(self._exception_event(exception, severity))
|
logged = self._log(self._exception_or_user_error_event(exception, severity))
|
||||||
# trigger flush but don't block on it
|
# trigger flush but don't block on it
|
||||||
self._trigger_flush()
|
self._trigger_flush()
|
||||||
return logged
|
return logged
|
||||||
|
|
@ -186,7 +188,9 @@ class Telemetry:
|
||||||
def log_fatal_exception_sync(
|
def log_fatal_exception_sync(
|
||||||
self, exception: BaseException, severity: Severity = "ERROR"
|
self, exception: BaseException, severity: Severity = "ERROR"
|
||||||
) -> bool:
|
) -> bool:
|
||||||
logged = self._log(self._exception_event(exception, severity), self._session_end_event())
|
logged = self._log(
|
||||||
|
self._exception_or_user_error_event(exception, severity), self._session_end_event()
|
||||||
|
)
|
||||||
self._trigger_flush()
|
self._trigger_flush()
|
||||||
# wait for the flush to complete
|
# wait for the flush to complete
|
||||||
if _current_loop() is None:
|
if _current_loop() is None:
|
||||||
|
|
@ -242,6 +246,15 @@ class Telemetry:
|
||||||
duration=str(end_time - self._session_start),
|
duration=str(end_time - self._session_start),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _exception_or_user_error_event(
|
||||||
|
self, exception: BaseException, severity: Severity
|
||||||
|
) -> TelemetryEvent:
|
||||||
|
return (
|
||||||
|
self._user_error_event(exception)
|
||||||
|
if is_user_error(exception)
|
||||||
|
else self._exception_event(exception, severity)
|
||||||
|
)
|
||||||
|
|
||||||
def _exception_event(
|
def _exception_event(
|
||||||
self, exception: BaseException, severity: Severity
|
self, exception: BaseException, severity: Severity
|
||||||
) -> UnhandledExceptionEvent:
|
) -> UnhandledExceptionEvent:
|
||||||
|
|
@ -260,6 +273,18 @@ class Telemetry:
|
||||||
else None,
|
else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _user_error_event(self, exception: BaseException) -> GenericEvent:
|
||||||
|
data: dict[str, object] = {"error_type": exception.__class__.__name__}
|
||||||
|
if message := str(exception):
|
||||||
|
data["message"] = message
|
||||||
|
if user_error := _get_user_error(exception):
|
||||||
|
status_code = getattr(user_error, "status_code", None)
|
||||||
|
if isinstance(status_code, int):
|
||||||
|
data["status_code"] = status_code
|
||||||
|
if body := getattr(user_error, "body", None):
|
||||||
|
data["body"] = body
|
||||||
|
return self._generic_event("user_error", data, "WARNING")
|
||||||
|
|
||||||
def _next_session_index(self) -> int:
|
def _next_session_index(self) -> int:
|
||||||
with self._session_index_lock:
|
with self._session_index_lock:
|
||||||
idx = self._session_index
|
idx = self._session_index
|
||||||
|
|
@ -372,6 +397,37 @@ def capture_exceptions(
|
||||||
return _decorate if func is None else _decorate(func)
|
return _decorate if func is None else _decorate(func)
|
||||||
|
|
||||||
|
|
||||||
|
def is_user_error(exception: BaseException) -> bool:
|
||||||
|
return _get_user_error(exception) is not None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_user_error(
|
||||||
|
exception: BaseException, visited: set[int] | None = None
|
||||||
|
) -> BaseException | None:
|
||||||
|
visited = set() if visited is None else visited
|
||||||
|
if id(exception) in visited:
|
||||||
|
return None
|
||||||
|
visited.add(id(exception))
|
||||||
|
|
||||||
|
if (
|
||||||
|
(status_code := getattr(exception, "status_code", None))
|
||||||
|
and isinstance(status_code, int)
|
||||||
|
and 400 <= status_code < 500
|
||||||
|
and status_code != 408
|
||||||
|
):
|
||||||
|
return exception
|
||||||
|
|
||||||
|
if (cause := getattr(exception, "__cause__", None)) is not None and (
|
||||||
|
user_error := _get_user_error(cause, visited)
|
||||||
|
) is not None:
|
||||||
|
return user_error
|
||||||
|
|
||||||
|
if (context := getattr(exception, "__context__", None)) is not None:
|
||||||
|
return _get_user_error(context, visited)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _to_send_params(batch: TelemetryBatch) -> TelemetrySendParams:
|
def _to_send_params(batch: TelemetryBatch) -> TelemetrySendParams:
|
||||||
return cast(TelemetrySendParams, cast(object, batch.model_dump()))
|
return cast(TelemetrySendParams, cast(object, batch.model_dump()))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -8,10 +8,17 @@ from concurrent.futures import Future as ConcurrentFuture
|
||||||
from typing import Any, TypeVar, cast
|
from typing import Any, TypeVar, cast
|
||||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||||
|
|
||||||
|
import httpx
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from tinker.lib.public_interfaces.api_future import AwaitableConcurrentFuture
|
from tinker._exceptions import (
|
||||||
|
APIStatusError,
|
||||||
|
BadRequestError,
|
||||||
|
ConflictError,
|
||||||
|
UnprocessableEntityError,
|
||||||
|
)
|
||||||
from tinker.lib.client_connection_pool_type import ClientConnectionPoolType
|
from tinker.lib.client_connection_pool_type import ClientConnectionPoolType
|
||||||
|
from tinker.lib.public_interfaces.api_future import AwaitableConcurrentFuture
|
||||||
from tinker.lib.telemetry import (
|
from tinker.lib.telemetry import (
|
||||||
MAX_BATCH_SIZE,
|
MAX_BATCH_SIZE,
|
||||||
MAX_QUEUE_SIZE,
|
MAX_QUEUE_SIZE,
|
||||||
|
|
@ -189,6 +196,30 @@ class TestTelemetryClass:
|
||||||
assert exception_event.error_type == "RuntimeError"
|
assert exception_event.error_type == "RuntimeError"
|
||||||
assert exception_event.error_message == "Test exception"
|
assert exception_event.error_message == "Test exception"
|
||||||
|
|
||||||
|
def test_log_exception_sync_user_error(self):
|
||||||
|
request = httpx.Request("GET", "https://example.com")
|
||||||
|
response = httpx.Response(400, request=request)
|
||||||
|
error = BadRequestError(
|
||||||
|
"Invalid request payload",
|
||||||
|
response=response,
|
||||||
|
body={"detail": "bad request"},
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(self.telemetry, "_trigger_flush") as mock_trigger:
|
||||||
|
result = self.telemetry.log_exception_sync(error, "ERROR")
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
assert len(self.telemetry._queue) == 2
|
||||||
|
mock_trigger.assert_called_once()
|
||||||
|
generic_event = self.telemetry._queue[-1]
|
||||||
|
assert isinstance(generic_event, GenericEvent)
|
||||||
|
assert generic_event.event_name == "user_error"
|
||||||
|
assert generic_event.severity == "WARNING"
|
||||||
|
assert generic_event.event_data["error_type"] == "BadRequestError"
|
||||||
|
assert generic_event.event_data["status_code"] == 400
|
||||||
|
assert generic_event.event_data["message"] == "Invalid request payload"
|
||||||
|
assert generic_event.event_data["body"] == {"detail": "bad request"}
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_log_exception_async(self):
|
async def test_log_exception_async(self):
|
||||||
try:
|
try:
|
||||||
|
|
@ -225,6 +256,84 @@ class TestTelemetryClass:
|
||||||
end_event = self.telemetry._queue[-1]
|
end_event = self.telemetry._queue[-1]
|
||||||
assert isinstance(end_event, SessionEndEvent)
|
assert isinstance(end_event, SessionEndEvent)
|
||||||
|
|
||||||
|
def test_log_fatal_exception_sync_user_error(self):
|
||||||
|
request = httpx.Request("GET", "https://example.com")
|
||||||
|
response = httpx.Response(422, request=request)
|
||||||
|
error = UnprocessableEntityError(
|
||||||
|
"Payload is invalid",
|
||||||
|
response=response,
|
||||||
|
body={"errors": ["invalid field"]},
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(self.telemetry, "_trigger_flush") as mock_trigger:
|
||||||
|
with patch.object(
|
||||||
|
self.telemetry, "_wait_until_drained_sync", return_value=True
|
||||||
|
) as mock_wait:
|
||||||
|
result = self.telemetry.log_fatal_exception_sync(error, "ERROR")
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
assert len(self.telemetry._queue) == 3
|
||||||
|
mock_trigger.assert_called_once()
|
||||||
|
mock_wait.assert_called_once()
|
||||||
|
generic_event = self.telemetry._queue[-2]
|
||||||
|
assert isinstance(generic_event, GenericEvent)
|
||||||
|
assert generic_event.event_name == "user_error"
|
||||||
|
assert generic_event.severity == "WARNING"
|
||||||
|
assert generic_event.event_data["error_type"] == "UnprocessableEntityError"
|
||||||
|
assert generic_event.event_data["status_code"] == 422
|
||||||
|
assert generic_event.event_data["message"] == "Payload is invalid"
|
||||||
|
assert generic_event.event_data["body"] == {"errors": ["invalid field"]}
|
||||||
|
end_event = self.telemetry._queue[-1]
|
||||||
|
assert isinstance(end_event, SessionEndEvent)
|
||||||
|
|
||||||
|
def test_log_exception_sync_timeout_not_user_error(self):
|
||||||
|
request = httpx.Request("GET", "https://example.com")
|
||||||
|
response = httpx.Response(408, request=request)
|
||||||
|
error = APIStatusError(
|
||||||
|
"Request timed out",
|
||||||
|
response=response,
|
||||||
|
body={"error": "Request timed out"},
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(self.telemetry, "_trigger_flush") as mock_trigger:
|
||||||
|
result = self.telemetry.log_exception_sync(error, "ERROR")
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
assert len(self.telemetry._queue) == 2
|
||||||
|
mock_trigger.assert_called_once()
|
||||||
|
exception_event = self.telemetry._queue[-1]
|
||||||
|
assert isinstance(exception_event, UnhandledExceptionEvent)
|
||||||
|
assert exception_event.error_type == "APIStatusError"
|
||||||
|
|
||||||
|
def test_log_exception_sync_value_error_with_api_status_cause(self):
|
||||||
|
request = httpx.Request("GET", "https://example.com")
|
||||||
|
response = httpx.Response(409, request=request)
|
||||||
|
conflict_error = ConflictError(
|
||||||
|
"Resource already exists",
|
||||||
|
response=response,
|
||||||
|
body={"detail": "Resource already exists"},
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
raise conflict_error
|
||||||
|
except ConflictError as exc:
|
||||||
|
try:
|
||||||
|
raise ValueError("Wrapped user error") from exc
|
||||||
|
except ValueError as outer:
|
||||||
|
wrapped_error = outer
|
||||||
|
|
||||||
|
with patch.object(self.telemetry, "_trigger_flush") as mock_trigger:
|
||||||
|
result = self.telemetry.log_exception_sync(wrapped_error, "ERROR")
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
assert len(self.telemetry._queue) == 2
|
||||||
|
mock_trigger.assert_called_once()
|
||||||
|
generic_event = self.telemetry._queue[-1]
|
||||||
|
assert isinstance(generic_event, GenericEvent)
|
||||||
|
assert generic_event.event_name == "user_error"
|
||||||
|
assert generic_event.event_data["status_code"] == 409
|
||||||
|
assert generic_event.event_data["body"] == {"detail": "Resource already exists"}
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_log_fatal_exception_async(self):
|
async def test_log_fatal_exception_async(self):
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from .models import (
|
from .models import (
|
||||||
ModelsResource,
|
ModelsResource,
|
||||||
AsyncModelsResource,
|
AsyncModelsResource,
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
|
||||||
|
|
@ -1,26 +1,23 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import httpx
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
|
|
||||||
import httpx
|
from .._base_client import make_request_options
|
||||||
|
|
||||||
from ..types import ModelID, model_create_params, model_unload_params, model_get_info_params
|
|
||||||
from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
|
|
||||||
from .._utils import maybe_transform, async_maybe_transform
|
|
||||||
from .._compat import cached_property
|
from .._compat import cached_property
|
||||||
from .._resource import SyncAPIResource, AsyncAPIResource
|
from .._resource import AsyncAPIResource, SyncAPIResource
|
||||||
from .._response import (
|
from .._response import (
|
||||||
to_raw_response_wrapper,
|
|
||||||
to_streamed_response_wrapper,
|
|
||||||
async_to_raw_response_wrapper,
|
async_to_raw_response_wrapper,
|
||||||
async_to_streamed_response_wrapper,
|
async_to_streamed_response_wrapper,
|
||||||
|
to_raw_response_wrapper,
|
||||||
|
to_streamed_response_wrapper,
|
||||||
)
|
)
|
||||||
from .._base_client import make_request_options
|
from .._types import NOT_GIVEN, Body, Headers, NotGiven, Query
|
||||||
from ..types.model_id import ModelID
|
from .._utils import async_maybe_transform, maybe_transform
|
||||||
|
from ..types import ModelID, model_create_params, model_get_info_params, model_unload_params
|
||||||
from ..types.get_info_response import GetInfoResponse
|
from ..types.get_info_response import GetInfoResponse
|
||||||
from ..types.lora_config_param import LoraConfigParam
|
from ..types.lora_config_param import LoraConfigParam
|
||||||
|
from ..types.model_id import ModelID
|
||||||
from ..types.shared.untyped_api_future import UntypedAPIFuture
|
from ..types.shared.untyped_api_future import UntypedAPIFuture
|
||||||
|
|
||||||
__all__ = ["ModelsResource", "AsyncModelsResource"]
|
__all__ = ["ModelsResource", "AsyncModelsResource"]
|
||||||
|
|
@ -50,6 +47,7 @@ class ModelsResource(SyncAPIResource):
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
base_model: str,
|
base_model: str,
|
||||||
|
user_metadata: dict[str, str] | None = None,
|
||||||
lora_config: LoraConfigParam | NotGiven = NOT_GIVEN,
|
lora_config: LoraConfigParam | NotGiven = NOT_GIVEN,
|
||||||
type: Literal["create_model"] = "create_model",
|
type: Literal["create_model"] = "create_model",
|
||||||
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
||||||
|
|
@ -81,6 +79,7 @@ class ModelsResource(SyncAPIResource):
|
||||||
body=maybe_transform(
|
body=maybe_transform(
|
||||||
{
|
{
|
||||||
"base_model": base_model,
|
"base_model": base_model,
|
||||||
|
"user_metadata": user_metadata,
|
||||||
"lora_config": lora_config,
|
"lora_config": lora_config,
|
||||||
"type": type,
|
"type": type,
|
||||||
},
|
},
|
||||||
|
|
@ -213,6 +212,7 @@ class AsyncModelsResource(AsyncAPIResource):
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
base_model: str,
|
base_model: str,
|
||||||
|
user_metadata: dict[str, str] | None = None,
|
||||||
lora_config: LoraConfigParam | NotGiven = NOT_GIVEN,
|
lora_config: LoraConfigParam | NotGiven = NOT_GIVEN,
|
||||||
type: Literal["create_model"] = "create_model",
|
type: Literal["create_model"] = "create_model",
|
||||||
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
||||||
|
|
@ -244,6 +244,7 @@ class AsyncModelsResource(AsyncAPIResource):
|
||||||
body=await async_maybe_transform(
|
body=await async_maybe_transform(
|
||||||
{
|
{
|
||||||
"base_model": base_model,
|
"base_model": base_model,
|
||||||
|
"user_metadata": user_metadata,
|
||||||
"lora_config": lora_config,
|
"lora_config": lora_config,
|
||||||
"type": type,
|
"type": type,
|
||||||
},
|
},
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Iterable
|
from typing import Iterable
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,14 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
|
|
||||||
|
import datetime
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from ..types import (
|
from ..types import (
|
||||||
ModelID,
|
ModelID,
|
||||||
CheckpointsListResponse,
|
CheckpointsListResponse,
|
||||||
|
CheckpointArchiveUrlResponse,
|
||||||
weight_load_params,
|
weight_load_params,
|
||||||
weight_save_params,
|
weight_save_params,
|
||||||
weight_save_for_sampler_params,
|
weight_save_for_sampler_params,
|
||||||
|
|
@ -23,6 +23,7 @@ from .._response import (
|
||||||
async_to_raw_response_wrapper,
|
async_to_raw_response_wrapper,
|
||||||
async_to_streamed_response_wrapper,
|
async_to_streamed_response_wrapper,
|
||||||
)
|
)
|
||||||
|
from .._exceptions import APIStatusError
|
||||||
from .._base_client import make_request_options
|
from .._base_client import make_request_options
|
||||||
from ..types.shared.untyped_api_future import UntypedAPIFuture
|
from ..types.shared.untyped_api_future import UntypedAPIFuture
|
||||||
|
|
||||||
|
|
@ -270,6 +271,84 @@ class WeightsResource(SyncAPIResource):
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_checkpoint_archive_url(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
model_id: ModelID,
|
||||||
|
checkpoint_id: str,
|
||||||
|
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
||||||
|
# The extra values given here take precedence over values defined on the client or passed to this method.
|
||||||
|
extra_headers: Headers | None = None,
|
||||||
|
extra_query: Query | None = None,
|
||||||
|
extra_body: Body | None = None,
|
||||||
|
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||||
|
) -> CheckpointArchiveUrlResponse:
|
||||||
|
"""
|
||||||
|
Get signed URL to download checkpoint archive.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_id: The training run ID to download weights for
|
||||||
|
checkpoint_id: The checkpoint ID to download
|
||||||
|
|
||||||
|
extra_headers: Send extra headers
|
||||||
|
|
||||||
|
extra_query: Add additional query parameters to the request
|
||||||
|
|
||||||
|
extra_body: Add additional JSON properties to the request
|
||||||
|
|
||||||
|
timeout: Override the client-level default timeout for this request, in seconds
|
||||||
|
"""
|
||||||
|
if not model_id:
|
||||||
|
raise ValueError(f"Expected a non-empty value for `model_id` but received {model_id!r}")
|
||||||
|
if not checkpoint_id:
|
||||||
|
raise ValueError(
|
||||||
|
f"Expected a non-empty value for `checkpoint_id` but received {checkpoint_id!r}"
|
||||||
|
)
|
||||||
|
|
||||||
|
from .._response import APIResponse
|
||||||
|
|
||||||
|
# Merge the accept header
|
||||||
|
merged_headers: Headers = {"accept": "application/gzip"}
|
||||||
|
if extra_headers is not None:
|
||||||
|
merged_headers = {**merged_headers, **extra_headers}
|
||||||
|
|
||||||
|
options = make_request_options(
|
||||||
|
extra_headers=merged_headers,
|
||||||
|
extra_query=extra_query,
|
||||||
|
extra_body=extra_body,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
options["follow_redirects"] = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = self._get(
|
||||||
|
f"/api/v1/training_runs/{model_id}/checkpoints/{checkpoint_id}/archive",
|
||||||
|
cast_to=APIResponse,
|
||||||
|
options=options,
|
||||||
|
)
|
||||||
|
except APIStatusError as e:
|
||||||
|
# On success, this API responds with a 302
|
||||||
|
if e.status_code != 302:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
location = e.response.headers.get("Location")
|
||||||
|
if location is None:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
expires = datetime.datetime.now(datetime.UTC) + datetime.timedelta(minutes=15)
|
||||||
|
try:
|
||||||
|
if expires_header := e.response.headers.get("Expires"):
|
||||||
|
expires = datetime.datetime.strptime(expires_header, "%a, %d %b %Y %H:%M:%S GMT")
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return CheckpointArchiveUrlResponse(
|
||||||
|
url=location,
|
||||||
|
expires=expires,
|
||||||
|
)
|
||||||
|
raise Exception("Unexpected error while getting checkpoint archive URL")
|
||||||
|
|
||||||
|
|
||||||
class AsyncWeightsResource(AsyncAPIResource):
|
class AsyncWeightsResource(AsyncAPIResource):
|
||||||
@cached_property
|
@cached_property
|
||||||
def with_raw_response(self) -> AsyncWeightsResourceWithRawResponse:
|
def with_raw_response(self) -> AsyncWeightsResourceWithRawResponse:
|
||||||
|
|
@ -511,6 +590,83 @@ class AsyncWeightsResource(AsyncAPIResource):
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def get_checkpoint_archive_url(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
model_id: ModelID,
|
||||||
|
checkpoint_id: str,
|
||||||
|
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
|
||||||
|
# The extra values given here take precedence over values defined on the client or passed to this method.
|
||||||
|
extra_headers: Headers | None = None,
|
||||||
|
extra_query: Query | None = None,
|
||||||
|
extra_body: Body | None = None,
|
||||||
|
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||||
|
) -> CheckpointArchiveUrlResponse:
|
||||||
|
"""
|
||||||
|
Get signed URL to download checkpoint archive.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_id: The training run ID to download weights for
|
||||||
|
checkpoint_id: The checkpoint ID to download
|
||||||
|
|
||||||
|
extra_headers: Send extra headers
|
||||||
|
|
||||||
|
extra_query: Add additional query parameters to the request
|
||||||
|
|
||||||
|
extra_body: Add additional JSON properties to the request
|
||||||
|
|
||||||
|
timeout: Override the client-level default timeout for this request, in seconds
|
||||||
|
"""
|
||||||
|
if not model_id:
|
||||||
|
raise ValueError(f"Expected a non-empty value for `model_id` but received {model_id!r}")
|
||||||
|
if not checkpoint_id:
|
||||||
|
raise ValueError(
|
||||||
|
f"Expected a non-empty value for `checkpoint_id` but received {checkpoint_id!r}"
|
||||||
|
)
|
||||||
|
|
||||||
|
from urllib.parse import urlparse, parse_qs
|
||||||
|
from .._response import AsyncAPIResponse
|
||||||
|
|
||||||
|
# Merge the accept header
|
||||||
|
merged_headers: Headers = {"accept": "application/gzip"}
|
||||||
|
if extra_headers is not None:
|
||||||
|
merged_headers = {**merged_headers, **extra_headers}
|
||||||
|
|
||||||
|
options = make_request_options(
|
||||||
|
extra_headers=merged_headers,
|
||||||
|
extra_query=extra_query,
|
||||||
|
extra_body=extra_body,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
options["follow_redirects"] = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await self._get(
|
||||||
|
f"/api/v1/training_runs/{model_id}/checkpoints/{checkpoint_id}/archive",
|
||||||
|
cast_to=AsyncAPIResponse,
|
||||||
|
options=options,
|
||||||
|
)
|
||||||
|
except APIStatusError as e:
|
||||||
|
# On success, this API responds with a 302
|
||||||
|
if e.status_code != 302:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
location = e.response.headers.get("Location")
|
||||||
|
if location is None:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
expires = datetime.datetime.now(datetime.UTC) + datetime.timedelta(minutes=15)
|
||||||
|
try:
|
||||||
|
if expires_header := e.response.headers.get("Expires"):
|
||||||
|
expires = datetime.datetime.strptime(expires_header, "%a, %d %b %Y %H:%M:%S GMT")
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return CheckpointArchiveUrlResponse(
|
||||||
|
url=location,
|
||||||
|
expires=expires,
|
||||||
|
)
|
||||||
|
raise Exception("Unexpected error while getting checkpoint archive URL")
|
||||||
|
|
||||||
class WeightsResourceWithRawResponse:
|
class WeightsResourceWithRawResponse:
|
||||||
def __init__(self, weights: WeightsResource) -> None:
|
def __init__(self, weights: WeightsResource) -> None:
|
||||||
|
|
@ -531,6 +687,9 @@ class WeightsResourceWithRawResponse:
|
||||||
self.delete_checkpoint = to_raw_response_wrapper(
|
self.delete_checkpoint = to_raw_response_wrapper(
|
||||||
weights.delete_checkpoint,
|
weights.delete_checkpoint,
|
||||||
)
|
)
|
||||||
|
self.get_checkpoint_archive_url = to_raw_response_wrapper(
|
||||||
|
weights.get_checkpoint_archive_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AsyncWeightsResourceWithRawResponse:
|
class AsyncWeightsResourceWithRawResponse:
|
||||||
|
|
@ -552,6 +711,9 @@ class AsyncWeightsResourceWithRawResponse:
|
||||||
self.delete_checkpoint = async_to_raw_response_wrapper(
|
self.delete_checkpoint = async_to_raw_response_wrapper(
|
||||||
weights.delete_checkpoint,
|
weights.delete_checkpoint,
|
||||||
)
|
)
|
||||||
|
self.get_checkpoint_archive_url = async_to_raw_response_wrapper(
|
||||||
|
weights.get_checkpoint_archive_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class WeightsResourceWithStreamingResponse:
|
class WeightsResourceWithStreamingResponse:
|
||||||
|
|
@ -573,6 +735,9 @@ class WeightsResourceWithStreamingResponse:
|
||||||
self.delete_checkpoint = to_streamed_response_wrapper(
|
self.delete_checkpoint = to_streamed_response_wrapper(
|
||||||
weights.delete_checkpoint,
|
weights.delete_checkpoint,
|
||||||
)
|
)
|
||||||
|
self.get_checkpoint_archive_url = to_streamed_response_wrapper(
|
||||||
|
weights.get_checkpoint_archive_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AsyncWeightsResourceWithStreamingResponse:
|
class AsyncWeightsResourceWithStreamingResponse:
|
||||||
|
|
@ -594,3 +759,6 @@ class AsyncWeightsResourceWithStreamingResponse:
|
||||||
self.delete_checkpoint = async_to_streamed_response_wrapper(
|
self.delete_checkpoint = async_to_streamed_response_wrapper(
|
||||||
weights.delete_checkpoint,
|
weights.delete_checkpoint,
|
||||||
)
|
)
|
||||||
|
self.get_checkpoint_archive_url = async_to_streamed_response_wrapper(
|
||||||
|
weights.get_checkpoint_archive_url,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
# There's an underscore in front of *Param classes (TypedDict) because they shouldn't be used.
|
# There's an underscore in front of *Param classes (TypedDict) because they shouldn't be used.
|
||||||
|
|
@ -16,7 +14,6 @@ from .stop_reason import StopReason as StopReason
|
||||||
from .tensor_data import TensorData as TensorData
|
from .tensor_data import TensorData as TensorData
|
||||||
from .loss_fn_type import LossFnType as LossFnType
|
from .loss_fn_type import LossFnType as LossFnType
|
||||||
from .tensor_dtype import TensorDtype as TensorDtype
|
from .tensor_dtype import TensorDtype as TensorDtype
|
||||||
from .error_response import ErrorResponse as ErrorResponse
|
|
||||||
from .loss_fn_inputs import LossFnInputs as LossFnInputs
|
from .loss_fn_inputs import LossFnInputs as LossFnInputs
|
||||||
from .loss_fn_output import LossFnOutput as LossFnOutput
|
from .loss_fn_output import LossFnOutput as LossFnOutput
|
||||||
from .sample_request import SampleRequest as SampleRequest
|
from .sample_request import SampleRequest as SampleRequest
|
||||||
|
|
@ -35,12 +32,18 @@ from .model_input_param import ModelInputParam as _ModelInputParam
|
||||||
from .tensor_data_param import TensorDataParam as _TensorDataParam
|
from .tensor_data_param import TensorDataParam as _TensorDataParam
|
||||||
from .encoded_text_chunk import EncodedTextChunk as EncodedTextChunk
|
from .encoded_text_chunk import EncodedTextChunk as EncodedTextChunk
|
||||||
from .optim_step_request import OptimStepRequest as OptimStepRequest
|
from .optim_step_request import OptimStepRequest as OptimStepRequest
|
||||||
from .checkpoint import Checkpoint as Checkpoint, CheckpointType as CheckpointType, ParsedCheckpointTinkerPath as ParsedCheckpointTinkerPath
|
from .checkpoint import (
|
||||||
|
Checkpoint as Checkpoint,
|
||||||
|
CheckpointType as CheckpointType,
|
||||||
|
ParsedCheckpointTinkerPath as ParsedCheckpointTinkerPath,
|
||||||
|
)
|
||||||
from .weight_load_params import WeightLoadParams as _WeightLoadParams
|
from .weight_load_params import WeightLoadParams as _WeightLoadParams
|
||||||
from .weight_save_params import WeightSaveParams as _WeightSaveParams
|
from .weight_save_params import WeightSaveParams as _WeightSaveParams
|
||||||
from .checkpoints_list_response import CheckpointsListResponse as CheckpointsListResponse
|
from .checkpoints_list_response import CheckpointsListResponse as CheckpointsListResponse
|
||||||
|
from .checkpoint_archive_url_response import (
|
||||||
|
CheckpointArchiveUrlResponse as CheckpointArchiveUrlResponse,
|
||||||
|
)
|
||||||
from .cursor import Cursor as Cursor
|
from .cursor import Cursor as Cursor
|
||||||
from .training_run_ids_response import TrainingRunIdsResponse as TrainingRunIdsResponse
|
|
||||||
from .training_runs_response import TrainingRunsResponse as TrainingRunsResponse
|
from .training_runs_response import TrainingRunsResponse as TrainingRunsResponse
|
||||||
from .forward_backward_input_param import ForwardBackwardInputParam as _ForwardBackwardInputParam
|
from .forward_backward_input_param import ForwardBackwardInputParam as _ForwardBackwardInputParam
|
||||||
from .forward_backward_input import ForwardBackwardInput as ForwardBackwardInput
|
from .forward_backward_input import ForwardBackwardInput as ForwardBackwardInput
|
||||||
|
|
@ -49,7 +52,6 @@ from .model_create_params import ModelCreateParams as _ModelCreateParams
|
||||||
from .model_unload_params import ModelUnloadParams as _ModelUnloadParams
|
from .model_unload_params import ModelUnloadParams as _ModelUnloadParams
|
||||||
from .session_end_event import SessionEndEvent as SessionEndEvent
|
from .session_end_event import SessionEndEvent as SessionEndEvent
|
||||||
from .telemetry_response import TelemetryResponse as TelemetryResponse
|
from .telemetry_response import TelemetryResponse as TelemetryResponse
|
||||||
from .try_again_response import TryAgainResponse as TryAgainResponse
|
|
||||||
from .optim_step_response import OptimStepResponse as OptimStepResponse
|
from .optim_step_response import OptimStepResponse as OptimStepResponse
|
||||||
from .session_start_event import SessionStartEvent as SessionStartEvent
|
from .session_start_event import SessionStartEvent as SessionStartEvent
|
||||||
from .create_model_request import CreateModelRequest as CreateModelRequest
|
from .create_model_request import CreateModelRequest as CreateModelRequest
|
||||||
|
|
@ -72,19 +74,32 @@ from .sampling_params_param import SamplingParamsParam as _SamplingParamsParam
|
||||||
from .sampling_sample_params import SamplingSampleParams as _SamplingSampleParams
|
from .sampling_sample_params import SamplingSampleParams as _SamplingSampleParams
|
||||||
from .sampling_asample_params import SamplingAsampleParams as _SamplingAsampleParams
|
from .sampling_asample_params import SamplingAsampleParams as _SamplingAsampleParams
|
||||||
from .future_retrieve_response import FutureRetrieveResponse as FutureRetrieveResponse
|
from .future_retrieve_response import FutureRetrieveResponse as FutureRetrieveResponse
|
||||||
from .compute_logprobs_response import ComputeLogprobsResponse as ComputeLogprobsResponse
|
|
||||||
from .image_asset_pointer_chunk import ImageAssetPointerChunk as ImageAssetPointerChunk
|
from .image_asset_pointer_chunk import ImageAssetPointerChunk as ImageAssetPointerChunk
|
||||||
from .training_optim_step_params import TrainingOptimStepParams as _TrainingOptimStepParams
|
from .training_optim_step_params import TrainingOptimStepParams as _TrainingOptimStepParams
|
||||||
from .weight_save_for_sampler_params import WeightSaveForSamplerParams as _WeightSaveForSamplerParams
|
from .weight_save_for_sampler_params import (
|
||||||
from .image_asset_pointer_chunk_param import ImageAssetPointerChunkParam as _ImageAssetPointerChunkParam
|
WeightSaveForSamplerParams as _WeightSaveForSamplerParams,
|
||||||
|
)
|
||||||
|
from .image_asset_pointer_chunk_param import (
|
||||||
|
ImageAssetPointerChunkParam as _ImageAssetPointerChunkParam,
|
||||||
|
)
|
||||||
from .session_end_event_param import SessionEndEventParam as _SessionEndEventParam
|
from .session_end_event_param import SessionEndEventParam as _SessionEndEventParam
|
||||||
from .session_start_event_param import SessionStartEventParam as _SessionStartEventParam
|
from .session_start_event_param import SessionStartEventParam as _SessionStartEventParam
|
||||||
from .unhandled_exception_event import UnhandledExceptionEvent as UnhandledExceptionEvent
|
from .unhandled_exception_event import UnhandledExceptionEvent as UnhandledExceptionEvent
|
||||||
from .unhandled_exception_event_param import UnhandledExceptionEventParam as _UnhandledExceptionEventParam
|
from .unhandled_exception_event_param import (
|
||||||
from .get_server_capabilities_response import GetServerCapabilitiesResponse as GetServerCapabilitiesResponse
|
UnhandledExceptionEventParam as _UnhandledExceptionEventParam,
|
||||||
from .save_weights_for_sampler_request import SaveWeightsForSamplerRequest as SaveWeightsForSamplerRequest
|
)
|
||||||
|
from .get_server_capabilities_response import (
|
||||||
|
GetServerCapabilitiesResponse as GetServerCapabilitiesResponse,
|
||||||
|
)
|
||||||
|
from .save_weights_for_sampler_request import (
|
||||||
|
SaveWeightsForSamplerRequest as SaveWeightsForSamplerRequest,
|
||||||
|
)
|
||||||
from .get_server_capabilities_response import SupportedModel as SupportedModel
|
from .get_server_capabilities_response import SupportedModel as SupportedModel
|
||||||
from .training_forward_backward_params import TrainingForwardBackwardParams as _TrainingForwardBackwardParams
|
from .training_forward_backward_params import (
|
||||||
from .save_weights_for_sampler_response import SaveWeightsForSamplerResponse as SaveWeightsForSamplerResponse
|
TrainingForwardBackwardParams as _TrainingForwardBackwardParams,
|
||||||
|
)
|
||||||
|
from .save_weights_for_sampler_response import (
|
||||||
|
SaveWeightsForSamplerResponse as SaveWeightsForSamplerResponse,
|
||||||
|
)
|
||||||
from .optim_step_request import AdamParams as AdamParams
|
from .optim_step_request import AdamParams as AdamParams
|
||||||
from .training_run import TrainingRun as TrainingRun
|
from .training_run import TrainingRun as TrainingRun
|
||||||
|
|
|
||||||
12
src/tinker/types/checkpoint_archive_url_response.py
Normal file
12
src/tinker/types/checkpoint_archive_url_response.py
Normal file
|
|
@ -0,0 +1,12 @@
|
||||||
|
import datetime
|
||||||
|
from .._models import BaseModel
|
||||||
|
|
||||||
|
__all__ = ["CheckpointArchiveUrlResponse"]
|
||||||
|
|
||||||
|
|
||||||
|
class CheckpointArchiveUrlResponse(BaseModel):
|
||||||
|
url: str
|
||||||
|
"""Signed URL to download the checkpoint archive"""
|
||||||
|
|
||||||
|
expires: datetime.datetime
|
||||||
|
"""Unix timestamp when the signed URL expires, if available"""
|
||||||
|
|
@ -1,14 +0,0 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
from typing_extensions import Literal, Sequence
|
|
||||||
|
|
||||||
from .._models import BaseModel
|
|
||||||
|
|
||||||
__all__ = ["ComputeLogprobsResponse"]
|
|
||||||
|
|
||||||
|
|
||||||
class ComputeLogprobsResponse(BaseModel):
|
|
||||||
logprobs: Sequence[Optional[float]]
|
|
||||||
|
|
||||||
type: Literal["compute_logprobs"] = "compute_logprobs"
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
from typing import Any, Optional
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
|
|
||||||
from .._models import StrictBase
|
from .._models import StrictBase
|
||||||
|
|
@ -12,6 +11,9 @@ __all__ = ["CreateModelRequest"]
|
||||||
class CreateModelRequest(StrictBase):
|
class CreateModelRequest(StrictBase):
|
||||||
base_model: str
|
base_model: str
|
||||||
|
|
||||||
|
"""Optional metadata about this model/training run, set by the end-user"""
|
||||||
|
user_metadata: Optional[dict[str, Any]] = None
|
||||||
|
|
||||||
lora_config: Optional[LoraConfig] = None
|
lora_config: Optional[LoraConfig] = None
|
||||||
|
|
||||||
type: Literal["create_model"] = "create_model"
|
type: Literal["create_model"] = "create_model"
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
|
|
||||||
from .._compat import PYDANTIC_V2, ConfigDict
|
from .._compat import PYDANTIC_V2, ConfigDict
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from typing import Any, TYPE_CHECKING
|
from typing import Any, TYPE_CHECKING
|
||||||
|
|
||||||
from pydantic import model_validator
|
from pydantic import model_validator
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing_extensions import Required, TypedDict
|
from typing_extensions import Required, TypedDict
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from typing import Sequence
|
from typing import Sequence
|
||||||
|
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Iterable
|
from typing import Iterable
|
||||||
|
|
|
||||||
|
|
@ -1,18 +0,0 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from typing import Dict, Optional
|
|
||||||
|
|
||||||
from .._models import BaseModel
|
|
||||||
|
|
||||||
__all__ = ["ErrorResponse"]
|
|
||||||
|
|
||||||
|
|
||||||
class ErrorResponse(BaseModel):
|
|
||||||
error: str
|
|
||||||
"""Error code"""
|
|
||||||
|
|
||||||
message: str
|
|
||||||
"""Human-readable error message"""
|
|
||||||
|
|
||||||
details: Optional[Dict[str, object]] = None
|
|
||||||
"""Additional error details"""
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from typing_extensions import Literal, TypeAlias
|
from typing_extensions import Literal, TypeAlias
|
||||||
|
|
||||||
__all__ = ["EventType"]
|
__all__ = ["EventType"]
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from .datum import Datum
|
from .datum import Datum
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Iterable
|
from typing import Iterable
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
|
||||||
from .._models import BaseModel
|
from .._models import BaseModel
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing_extensions import Required, TypedDict
|
from typing_extensions import Required, TypedDict
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from typing import Union
|
from typing import Union
|
||||||
from typing_extensions import TypeAlias
|
from typing_extensions import TypeAlias
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from typing import Optional, Literal
|
from typing import Optional, Literal
|
||||||
|
|
||||||
from .._compat import PYDANTIC_V2, ConfigDict
|
from .._compat import PYDANTIC_V2, ConfigDict
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from .._models import BaseModel
|
from .._models import BaseModel
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
|
|
||||||
from .._models import BaseModel
|
from .._models import BaseModel
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
|
|
||||||
from .._models import StrictBase
|
from .._models import StrictBase
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing_extensions import Literal, Required, TypedDict
|
from typing_extensions import Literal, Required, TypedDict
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from .._models import StrictBase
|
from .._models import StrictBase
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from typing_extensions import TypeAlias
|
from typing_extensions import TypeAlias
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from typing_extensions import TypeAlias
|
from typing_extensions import TypeAlias
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from typing_extensions import Literal, TypeAlias
|
from typing_extensions import Literal, TypeAlias
|
||||||
|
|
||||||
__all__ = ["LossFnType"]
|
__all__ = ["LossFnType"]
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
from typing_extensions import Literal, Required, TypedDict
|
from typing_extensions import Literal, Required, TypedDict
|
||||||
|
|
||||||
from .lora_config_param import LoraConfigParam
|
from .lora_config_param import LoraConfigParam
|
||||||
|
|
@ -14,4 +14,7 @@ class ModelCreateParams(TypedDict, total=False):
|
||||||
|
|
||||||
lora_config: LoraConfigParam
|
lora_config: LoraConfigParam
|
||||||
|
|
||||||
|
"""Optional metadata about this model/training run, set by the end-user"""
|
||||||
|
user_metadata: Optional[dict[str, Any]]
|
||||||
|
|
||||||
type: Literal["create_model"] = "create_model"
|
type: Literal["create_model"] = "create_model"
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing_extensions import Literal, Required, TypedDict
|
from typing_extensions import Literal, Required, TypedDict
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from typing_extensions import TypeAlias
|
from typing_extensions import TypeAlias
|
||||||
|
|
||||||
__all__ = ["ModelID"]
|
__all__ = ["ModelID"]
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from .._models import StrictBase
|
from .._models import StrictBase
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from typing import Union
|
from typing import Union
|
||||||
from typing_extensions import Annotated, TypeAlias
|
from typing_extensions import Annotated, TypeAlias
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Iterable
|
from typing import Iterable
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing_extensions import Literal, Required, TypedDict
|
from typing_extensions import Literal, Required, TypedDict
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
from .._models import BaseModel
|
from .._models import BaseModel
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from typing_extensions import TypeAlias
|
from typing_extensions import TypeAlias
|
||||||
|
|
||||||
__all__ = ["RequestID"]
|
__all__ = ["RequestID"]
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from typing import List, Optional, Sequence
|
from typing import List, Optional, Sequence
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from .._models import BaseModel
|
from .._models import BaseModel
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing_extensions import Literal, Required, TypedDict
|
from typing_extensions import Literal, Required, TypedDict
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from typing import Optional, Sequence, Union
|
from typing import Optional, Sequence, Union
|
||||||
|
|
||||||
from .._models import BaseModel
|
from .._models import BaseModel
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import List, Union, Iterable, Optional
|
from typing import List, Union, Iterable, Optional
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing_extensions import Literal, Required, TypedDict
|
from typing_extensions import Literal, Required, TypedDict
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from .._models import BaseModel
|
from .._models import BaseModel
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from .._models import BaseModel
|
from .._models import BaseModel
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from typing_extensions import Literal, TypeAlias
|
from typing_extensions import Literal, TypeAlias
|
||||||
|
|
||||||
__all__ = ["Severity"]
|
__all__ = ["Severity"]
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from .untyped_api_future import UntypedAPIFuture as UntypedAPIFuture
|
from .untyped_api_future import UntypedAPIFuture as UntypedAPIFuture
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from ..._compat import PYDANTIC_V2, ConfigDict
|
from ..._compat import PYDANTIC_V2, ConfigDict
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from typing_extensions import Literal, TypeAlias
|
from typing_extensions import Literal, TypeAlias
|
||||||
|
|
||||||
__all__ = ["StopReason"]
|
__all__ = ["StopReason"]
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from .._models import BaseModel
|
from .._models import BaseModel
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from typing_extensions import TypeAlias
|
from typing_extensions import TypeAlias
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
|
|
||||||
from .._models import BaseModel
|
from .._models import BaseModel
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Iterable
|
from typing import Iterable
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from typing import List, Union, Optional, TYPE_CHECKING, Any
|
from typing import List, Union, Optional, TYPE_CHECKING, Any
|
||||||
|
|
||||||
from .._models import StrictBase
|
from .._models import StrictBase
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Union, Iterable, Optional
|
from typing import Union, Iterable, Optional
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from typing_extensions import Literal, TypeAlias
|
from typing_extensions import Literal, TypeAlias
|
||||||
|
|
||||||
__all__ = ["TensorDtype"]
|
__all__ = ["TensorDtype"]
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing_extensions import Required, TypedDict
|
from typing_extensions import Required, TypedDict
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing_extensions import Required, TypedDict
|
from typing_extensions import Required, TypedDict
|
||||||
|
|
|
||||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue