mirror of
https://github.com/thinking-machines-lab/tinker.git
synced 2026-04-23 16:54:58 +00:00
Sync contents
This commit is contained in:
parent
a187cb5b2f
commit
77cc8a0810
7 changed files with 226 additions and 226 deletions
|
|
@ -1,4 +1,4 @@
|
||||||
{
|
{
|
||||||
"last_synced_sha": "1c67712168cd74fc2bda614f239dcad5f7b2e166",
|
"last_synced_sha": "7a5a376dc40cd21c3c744270af661b021e9fa9c6",
|
||||||
"last_sync_time": "2026-02-13T22:05:28.522838"
|
"last_sync_time": "2026-02-27T05:49:39.458251"
|
||||||
}
|
}
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
[project]
|
[project]
|
||||||
name = "tinker"
|
name = "tinker"
|
||||||
version = "0.13.1"
|
version = "0.14.0"
|
||||||
description = "The official Python SDK for the tinker API"
|
description = "The official Python SDK for the tinker API"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = "Apache-2.0"
|
license = "Apache-2.0"
|
||||||
|
|
|
||||||
|
|
@ -85,186 +85,200 @@ class _APIFuture(APIFuture[T]): # pyright: ignore[reportUnusedClass]
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
iteration = -1
|
iteration = -1
|
||||||
connection_error_retries = 0
|
connection_error_retries = 0
|
||||||
|
allow_metadata_only = True
|
||||||
|
|
||||||
while True:
|
async with contextlib.AsyncExitStack() as stack:
|
||||||
iteration += 1
|
while True:
|
||||||
|
iteration += 1
|
||||||
|
|
||||||
if timeout is not None and time.time() - start_time > timeout:
|
if timeout is not None and time.time() - start_time > timeout:
|
||||||
if telemetry := self.get_telemetry():
|
if telemetry := self.get_telemetry():
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
telemetry.log(
|
telemetry.log(
|
||||||
"APIFuture.result_async.timeout",
|
"APIFuture.result_async.timeout",
|
||||||
event_data={
|
event_data={
|
||||||
"request_id": self.request_id,
|
"request_id": self.request_id,
|
||||||
"request_type": self.request_type,
|
"request_type": self.request_type,
|
||||||
"timeout": timeout,
|
"timeout": timeout,
|
||||||
"iteration": iteration,
|
"iteration": iteration,
|
||||||
"elapsed_time": current_time - start_time,
|
"elapsed_time": current_time - start_time,
|
||||||
},
|
},
|
||||||
severity="ERROR",
|
severity="ERROR",
|
||||||
)
|
)
|
||||||
raise TimeoutError(
|
raise TimeoutError(
|
||||||
f"Timeout of {timeout} seconds reached while waiting for result of {self.request_id=}"
|
f"Timeout of {timeout} seconds reached while waiting for result of {self.request_id=}"
|
||||||
)
|
|
||||||
|
|
||||||
# Headers for telemetry
|
|
||||||
headers = {
|
|
||||||
"X-Tinker-Request-Iteration": str(iteration),
|
|
||||||
"X-Tinker-Request-Type": self.request_type,
|
|
||||||
}
|
|
||||||
if iteration == 0:
|
|
||||||
headers["X-Tinker-Create-Promise-Roundtrip-Time"] = str(
|
|
||||||
self.request_queue_roundtrip_time
|
|
||||||
)
|
|
||||||
|
|
||||||
# Function hasn't been called yet, execute it now
|
|
||||||
try:
|
|
||||||
with self.holder.aclient(ClientConnectionPoolType.RETRIEVE_PROMISE) as client:
|
|
||||||
response = await client.futures.with_raw_response.retrieve(
|
|
||||||
request=FutureRetrieveRequest(request_id=self.request_id),
|
|
||||||
timeout=45,
|
|
||||||
extra_headers=headers,
|
|
||||||
max_retries=0,
|
|
||||||
)
|
|
||||||
except tinker.APIStatusError as e:
|
|
||||||
connection_error_retries = 0
|
|
||||||
should_retry = e.status_code == 408 or e.status_code in range(500, 600)
|
|
||||||
user_error = is_user_error(e)
|
|
||||||
if telemetry := self.get_telemetry():
|
|
||||||
current_time = time.time()
|
|
||||||
telemetry.log(
|
|
||||||
"APIFuture.result_async.api_status_error",
|
|
||||||
event_data={
|
|
||||||
"request_id": self.request_id,
|
|
||||||
"request_type": self.request_type,
|
|
||||||
"status_code": e.status_code,
|
|
||||||
"exception": str(e),
|
|
||||||
"should_retry": should_retry,
|
|
||||||
"is_user_error": user_error,
|
|
||||||
"iteration": iteration,
|
|
||||||
"elapsed_time": current_time - start_time,
|
|
||||||
},
|
|
||||||
severity="WARNING" if should_retry or user_error else "ERROR",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Retry 408s until we time out
|
headers = {
|
||||||
if e.status_code == 408:
|
"X-Tinker-Request-Iteration": str(iteration),
|
||||||
if self._queue_state_observer is not None:
|
"X-Tinker-Request-Type": self.request_type,
|
||||||
with contextlib.suppress(Exception):
|
}
|
||||||
response = e.response.json()
|
if iteration == 0:
|
||||||
if queue_state_str := response.get("queue_state", None):
|
headers["X-Tinker-Create-Promise-Roundtrip-Time"] = str(
|
||||||
queue_state_reason = response.get("queue_state_reason", None)
|
self.request_queue_roundtrip_time
|
||||||
if queue_state_str == "active":
|
)
|
||||||
queue_state = QueueState.ACTIVE
|
|
||||||
elif queue_state_str == "paused_rate_limit":
|
try:
|
||||||
queue_state = QueueState.PAUSED_RATE_LIMIT
|
with self.holder.aclient(ClientConnectionPoolType.RETRIEVE_PROMISE) as client:
|
||||||
elif queue_state_str == "paused_capacity":
|
response = await client.futures.with_raw_response.retrieve(
|
||||||
queue_state = QueueState.PAUSED_CAPACITY
|
request=FutureRetrieveRequest(
|
||||||
else:
|
request_id=self.request_id,
|
||||||
queue_state = QueueState.UNKNOWN
|
allow_metadata_only=allow_metadata_only,
|
||||||
self._queue_state_observer.on_queue_state_change(
|
),
|
||||||
queue_state, queue_state_reason
|
timeout=45,
|
||||||
)
|
extra_headers=headers,
|
||||||
continue
|
max_retries=0,
|
||||||
if e.status_code == 410:
|
)
|
||||||
raise RetryableException(
|
except tinker.APIStatusError as e:
|
||||||
message=f"Promise expired/broken for request {self.untyped_future.request_id}"
|
connection_error_retries = 0
|
||||||
|
should_retry = e.status_code == 408 or e.status_code in range(500, 600)
|
||||||
|
user_error = is_user_error(e)
|
||||||
|
if telemetry := self.get_telemetry():
|
||||||
|
current_time = time.time()
|
||||||
|
telemetry.log(
|
||||||
|
"APIFuture.result_async.api_status_error",
|
||||||
|
event_data={
|
||||||
|
"request_id": self.request_id,
|
||||||
|
"request_type": self.request_type,
|
||||||
|
"status_code": e.status_code,
|
||||||
|
"exception": str(e),
|
||||||
|
"should_retry": should_retry,
|
||||||
|
"is_user_error": user_error,
|
||||||
|
"iteration": iteration,
|
||||||
|
"elapsed_time": current_time - start_time,
|
||||||
|
},
|
||||||
|
severity="WARNING" if should_retry or user_error else "ERROR",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Retry 408s until we time out
|
||||||
|
if e.status_code == 408:
|
||||||
|
if self._queue_state_observer is not None:
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
response = e.response.json()
|
||||||
|
if queue_state_str := response.get("queue_state", None):
|
||||||
|
queue_state_reason = response.get("queue_state_reason", None)
|
||||||
|
if queue_state_str == "active":
|
||||||
|
queue_state = QueueState.ACTIVE
|
||||||
|
elif queue_state_str == "paused_rate_limit":
|
||||||
|
queue_state = QueueState.PAUSED_RATE_LIMIT
|
||||||
|
elif queue_state_str == "paused_capacity":
|
||||||
|
queue_state = QueueState.PAUSED_CAPACITY
|
||||||
|
else:
|
||||||
|
queue_state = QueueState.UNKNOWN
|
||||||
|
self._queue_state_observer.on_queue_state_change(
|
||||||
|
queue_state, queue_state_reason
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
if e.status_code == 410:
|
||||||
|
raise RetryableException(
|
||||||
|
message=f"Promise expired/broken for request {self.untyped_future.request_id}"
|
||||||
|
) from e
|
||||||
|
if e.status_code in range(500, 600):
|
||||||
|
continue
|
||||||
|
raise ValueError(
|
||||||
|
f"Error retrieving result: {e} with status code {e.status_code=} for {self.request_id=} and expected type {self.model_cls=}"
|
||||||
) from e
|
) from e
|
||||||
if e.status_code in range(500, 600):
|
except tinker.APIConnectionError as e:
|
||||||
|
if telemetry := self.get_telemetry():
|
||||||
|
current_time = time.time()
|
||||||
|
telemetry.log(
|
||||||
|
"APIFuture.result_async.connection_error",
|
||||||
|
event_data={
|
||||||
|
"request_id": self.request_id,
|
||||||
|
"request_type": self.request_type,
|
||||||
|
"exception": str(e),
|
||||||
|
"connection_error_retries": connection_error_retries,
|
||||||
|
"iteration": iteration,
|
||||||
|
"elapsed_time": current_time - start_time,
|
||||||
|
},
|
||||||
|
severity="WARNING",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Retry all connection errors with exponential backoff
|
||||||
|
await asyncio.sleep(min(2**connection_error_retries, 30))
|
||||||
|
connection_error_retries += 1
|
||||||
continue
|
continue
|
||||||
raise ValueError(
|
|
||||||
f"Error retrieving result: {e} with status code {e.status_code=} for {self.request_id=} and expected type {self.model_cls=}"
|
result_dict: Any = await response.json()
|
||||||
) from e
|
|
||||||
except tinker.APIConnectionError as e:
|
if "type" in result_dict and result_dict["type"] == "try_again":
|
||||||
if telemetry := self.get_telemetry():
|
logger.warning(f"Retrying request {self.request_id=} because of try_again")
|
||||||
current_time = time.time()
|
continue
|
||||||
telemetry.log(
|
|
||||||
"APIFuture.result_async.connection_error",
|
if result_dict.get("status") == "complete_metadata":
|
||||||
event_data={
|
# metadata only response should be returned only once
|
||||||
"request_id": self.request_id,
|
assert allow_metadata_only
|
||||||
"request_type": self.request_type,
|
allow_metadata_only = False
|
||||||
"exception": str(e),
|
|
||||||
"connection_error_retries": connection_error_retries,
|
response_payload_size = result_dict.get("response_payload_size", 0)
|
||||||
"iteration": iteration,
|
assert response_payload_size is not None
|
||||||
"elapsed_time": current_time - start_time,
|
await stack.enter_async_context(
|
||||||
},
|
self.holder._inflight_response_bytes_semaphore.acquire(
|
||||||
severity="WARNING",
|
response_payload_size
|
||||||
|
)
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if "error" in result_dict:
|
||||||
|
error_category = RequestErrorCategory.Unknown
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
error_category = RequestErrorCategory(result_dict.get("category"))
|
||||||
|
|
||||||
|
user_error = error_category is RequestErrorCategory.User
|
||||||
|
if telemetry := self.get_telemetry():
|
||||||
|
current_time = time.time()
|
||||||
|
telemetry.log(
|
||||||
|
"APIFuture.result_async.application_error",
|
||||||
|
event_data={
|
||||||
|
"request_id": self.request_id,
|
||||||
|
"request_type": self.request_type,
|
||||||
|
"error": result_dict["error"],
|
||||||
|
"error_category": error_category.name,
|
||||||
|
"is_user_error": user_error,
|
||||||
|
"iteration": iteration,
|
||||||
|
"elapsed_time": current_time - start_time,
|
||||||
|
},
|
||||||
|
severity="WARNING" if user_error else "ERROR",
|
||||||
|
)
|
||||||
|
|
||||||
|
error_message = result_dict["error"]
|
||||||
|
raise RequestFailedError(
|
||||||
|
f"Request failed: {error_message} for {self.request_id=} and expected type {self.model_cls=}",
|
||||||
|
request_id=self.request_id,
|
||||||
|
category=error_category,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Retry all connection errors with exponential backoff
|
try:
|
||||||
await asyncio.sleep(min(2**connection_error_retries, 30))
|
if inspect.isclass(self.model_cls) and issubclass(self.model_cls, BaseModel):
|
||||||
connection_error_retries += 1
|
self._cached_result = self.model_cls.model_validate(result_dict)
|
||||||
continue
|
else:
|
||||||
|
self._cached_result = result_dict
|
||||||
|
return cast(T, self._cached_result)
|
||||||
|
except Exception as e:
|
||||||
|
if telemetry := self.get_telemetry():
|
||||||
|
current_time = time.time()
|
||||||
|
telemetry.log(
|
||||||
|
"APIFuture.result_async.validation_error",
|
||||||
|
event_data={
|
||||||
|
"request_id": self.request_id,
|
||||||
|
"request_type": self.request_type,
|
||||||
|
"exception": str(e),
|
||||||
|
"exception_type": type(e).__name__,
|
||||||
|
"exception_stack": "".join(
|
||||||
|
traceback.format_exception(type(e), e, e.__traceback__)
|
||||||
|
)
|
||||||
|
if e.__traceback__
|
||||||
|
else None,
|
||||||
|
"model_cls": str(self.model_cls),
|
||||||
|
"iteration": iteration,
|
||||||
|
"elapsed_time": current_time - start_time,
|
||||||
|
},
|
||||||
|
severity="ERROR",
|
||||||
|
)
|
||||||
|
|
||||||
# Function hasn't been called yet, execute it now
|
raise ValueError(
|
||||||
result_dict: Any = await response.json()
|
f"Error retrieving result: {e} for {self.request_id=} and expected type {self.model_cls=}"
|
||||||
|
) from e
|
||||||
if "type" in result_dict and result_dict["type"] == "try_again":
|
|
||||||
logger.warning(f"Retrying request {self.request_id=} because of try_again")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if "error" in result_dict:
|
|
||||||
error_category = RequestErrorCategory.Unknown
|
|
||||||
with contextlib.suppress(Exception):
|
|
||||||
error_category = RequestErrorCategory(result_dict.get("category"))
|
|
||||||
|
|
||||||
user_error = error_category is RequestErrorCategory.User
|
|
||||||
if telemetry := self.get_telemetry():
|
|
||||||
current_time = time.time()
|
|
||||||
telemetry.log(
|
|
||||||
"APIFuture.result_async.application_error",
|
|
||||||
event_data={
|
|
||||||
"request_id": self.request_id,
|
|
||||||
"request_type": self.request_type,
|
|
||||||
"error": result_dict["error"],
|
|
||||||
"error_category": error_category.name,
|
|
||||||
"is_user_error": user_error,
|
|
||||||
"iteration": iteration,
|
|
||||||
"elapsed_time": current_time - start_time,
|
|
||||||
},
|
|
||||||
severity="WARNING" if user_error else "ERROR",
|
|
||||||
)
|
|
||||||
|
|
||||||
error_message = result_dict["error"]
|
|
||||||
raise RequestFailedError(
|
|
||||||
f"Request failed: {error_message} for {self.request_id=} and expected type {self.model_cls=}",
|
|
||||||
request_id=self.request_id,
|
|
||||||
category=error_category,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Check if model_cls is a BaseModel subclass before calling model_validate
|
|
||||||
if inspect.isclass(self.model_cls) and issubclass(self.model_cls, BaseModel):
|
|
||||||
self._cached_result = self.model_cls.model_validate(result_dict)
|
|
||||||
else:
|
|
||||||
# For non-BaseModel types, just return the result directly
|
|
||||||
self._cached_result = result_dict
|
|
||||||
return cast(T, self._cached_result)
|
|
||||||
except Exception as e:
|
|
||||||
if telemetry := self.get_telemetry():
|
|
||||||
current_time = time.time()
|
|
||||||
telemetry.log(
|
|
||||||
"APIFuture.result_async.validation_error",
|
|
||||||
event_data={
|
|
||||||
"request_id": self.request_id,
|
|
||||||
"request_type": self.request_type,
|
|
||||||
"exception": str(e),
|
|
||||||
"exception_type": type(e).__name__,
|
|
||||||
"exception_stack": "".join(
|
|
||||||
traceback.format_exception(type(e), e, e.__traceback__)
|
|
||||||
)
|
|
||||||
if e.__traceback__
|
|
||||||
else None,
|
|
||||||
"model_cls": str(self.model_cls),
|
|
||||||
"iteration": iteration,
|
|
||||||
"elapsed_time": current_time - start_time,
|
|
||||||
},
|
|
||||||
severity="ERROR",
|
|
||||||
)
|
|
||||||
|
|
||||||
raise ValueError(
|
|
||||||
f"Error retrieving result: {e} for {self.request_id=} and expected type {self.model_cls=}"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def request_id(self) -> str:
|
def request_id(self) -> str:
|
||||||
|
|
|
||||||
|
|
@ -162,6 +162,7 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
|
||||||
self._sample_dispatch_semaphore: asyncio.Semaphore = asyncio.Semaphore(400)
|
self._sample_dispatch_semaphore: asyncio.Semaphore = asyncio.Semaphore(400)
|
||||||
self._sample_dispatch_throttled_semaphore: asyncio.Semaphore = asyncio.Semaphore(10)
|
self._sample_dispatch_throttled_semaphore: asyncio.Semaphore = asyncio.Semaphore(10)
|
||||||
self._sample_dispatch_bytes_semaphore: BytesSemaphore = BytesSemaphore(5 * 1024 * 1024)
|
self._sample_dispatch_bytes_semaphore: BytesSemaphore = BytesSemaphore(5 * 1024 * 1024)
|
||||||
|
self._inflight_response_bytes_semaphore: BytesSemaphore = BytesSemaphore(5 * 1024 * 1024)
|
||||||
self._training_client_lock: threading.Lock = threading.Lock()
|
self._training_client_lock: threading.Lock = threading.Lock()
|
||||||
|
|
||||||
if session_id is not None:
|
if session_id is not None:
|
||||||
|
|
@ -336,7 +337,7 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
self.run_coroutine_threadsafe(self._async_cleanup())
|
self.run_coroutine_threadsafe(self._async_cleanup())
|
||||||
if telemetry := self._telemetry:
|
if telemetry := getattr(self, "_telemetry", None):
|
||||||
telemetry.stop()
|
telemetry.stop()
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
|
|
|
||||||
|
|
@ -14,4 +14,35 @@ class ForwardBackwardOutput(BaseModel):
|
||||||
"""Dictionary mapping field names to tensor data"""
|
"""Dictionary mapping field names to tensor data"""
|
||||||
|
|
||||||
metrics: Dict[str, float]
|
metrics: Dict[str, float]
|
||||||
"""Training metrics as key-value pairs"""
|
"""Training metrics as key-value pairs.
|
||||||
|
|
||||||
|
The following metrics are recorded only during MoE (Mixture of Experts) training.
|
||||||
|
Note: Don't fixate on the exact values of these metrics at the start of training.
|
||||||
|
Different models on different data will have different initial values. How these
|
||||||
|
metrics evolve over training is what matters.
|
||||||
|
|
||||||
|
In the definitions below, *perfect balance* means ``total_tokens / num_experts``
|
||||||
|
— the number of tokens each expert would receive if routing were perfectly uniform.
|
||||||
|
|
||||||
|
- ``e_frac_with_tokens:mean``: Fraction of experts that received at least one token,
|
||||||
|
averaged across layers. A value of 1.0 means every expert got work; 0.5 means half
|
||||||
|
were idle. Decreasing over time is concerning (routing collapse).
|
||||||
|
|
||||||
|
- ``e_frac_oversubscribed:mean``: Fraction of experts receiving more tokens than
|
||||||
|
perfect balance, averaged across layers. Increasing over time is concerning.
|
||||||
|
|
||||||
|
- ``e_max_violation:mean``: How much the most overloaded expert exceeds perfect
|
||||||
|
balance, as a fraction of perfect balance, averaged across layers. Computed as
|
||||||
|
``(max_tokens - perfect_balance) / perfect_balance``. A value of 2.0 means the
|
||||||
|
busiest expert got 3x the fair share. Increasing over time is concerning.
|
||||||
|
|
||||||
|
- ``e_max_violation:max``: Same as ``e_max_violation:mean`` but takes the max
|
||||||
|
across layers instead of the mean. Shows the worst-case load imbalance in any
|
||||||
|
single layer.
|
||||||
|
|
||||||
|
- ``e_min_violation:mean``: How much the least loaded expert is below perfect
|
||||||
|
balance, as a fraction of perfect balance, averaged across layers. Computed as
|
||||||
|
``(min_tokens - perfect_balance) / perfect_balance``. A value of -0.5 means the
|
||||||
|
least-used expert got half the fair share; -1.0 means it got nothing. Typically
|
||||||
|
negative. Decreasing over time (more negative) is concerning.
|
||||||
|
"""
|
||||||
|
|
|
||||||
|
|
@ -7,3 +7,7 @@ __all__ = ["FutureRetrieveRequest"]
|
||||||
class FutureRetrieveRequest(StrictBase):
|
class FutureRetrieveRequest(StrictBase):
|
||||||
request_id: RequestID
|
request_id: RequestID
|
||||||
"""The ID of the request to retrieve"""
|
"""The ID of the request to retrieve"""
|
||||||
|
|
||||||
|
allow_metadata_only: bool = False
|
||||||
|
"""When True, the server may return only response metadata (status and size)
|
||||||
|
instead of the full payload if the response exceeds the server's inline size limit."""
|
||||||
|
|
|
||||||
|
|
@ -1,50 +0,0 @@
|
||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
from httpx import URL
|
|
||||||
|
|
||||||
from tinker._base_client import BaseClient
|
|
||||||
from tinker.lib.public_interfaces import _convert_forward_backward_input
|
|
||||||
from tinker.types import ForwardBackwardInput
|
|
||||||
|
|
||||||
|
|
||||||
class TestBaseClient(BaseClient):
|
|
||||||
def __init__(self, base_url: str = "http://localhost:8000"):
|
|
||||||
# Test the fixed base_url parameter handling
|
|
||||||
super().__init__(version="1.0.0", base_url=base_url, _strict_response_validation=True)
|
|
||||||
|
|
||||||
def make_status_error(self, err_msg: str, body: object, response: object) -> None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _enforce_trailing_slash(self, url: URL) -> URL:
|
|
||||||
return url
|
|
||||||
|
|
||||||
|
|
||||||
# Test that the base_url parameter is handled correctly
|
|
||||||
client = TestBaseClient("http://example.com")
|
|
||||||
print(f"Base URL correctly set to: {client._base_url}")
|
|
||||||
|
|
||||||
# Create a proper dictionary to initialize the model
|
|
||||||
model_data = {
|
|
||||||
"data": [
|
|
||||||
{
|
|
||||||
"model_input": {"chunks": [{"type": "encoded_text", "tokens": [1, 2, 3]}]},
|
|
||||||
"loss_fn_inputs": {},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"loss_fn": "cross_entropy", # Use a valid value from the enum
|
|
||||||
}
|
|
||||||
|
|
||||||
# Create a mock model just to test the base_url parameter
|
|
||||||
try:
|
|
||||||
# Properly initialize the ForwardBackwardInput model
|
|
||||||
input_obj = ForwardBackwardInput.model_validate(model_data)
|
|
||||||
|
|
||||||
# Test the convert function
|
|
||||||
result = _convert_forward_backward_input(input_obj)
|
|
||||||
print(f"Conversion successful: {result}")
|
|
||||||
except Exception as e:
|
|
||||||
# Since we're just testing the base_url parameter fix, we can ignore model validation errors
|
|
||||||
print(f"Note: Could not validate ForwardBackwardInput model: {e}")
|
|
||||||
print("But that's okay since we're just testing the base_url parameter fix")
|
|
||||||
|
|
||||||
print("All tests passed!")
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue