mirror of
https://github.com/thinking-machines-lab/tinker.git
synced 2026-04-19 12:58:01 +00:00
fix: Retry RequestFailedError with Unknown category
Training jobs using on-policy methods like GKD would fail permanently when encountering transient server errors during KL penalty computation. The RetryHandler wasn't configured to retry RequestFailedError exceptions, causing the entire training run to abort on the first occurrence.

Changes:
- Add retry logic for RequestFailedError when category is Unknown
- Skip retry for User/Server categories (these require code changes)
- Add max_retries parameter (default 5) to prevent infinite loops
- Improve logging to show error category for debugging

The Unknown category indicates transient server-side issues that often resolve on retry, similar to 5xx HTTP errors. User errors are not retried since they indicate invalid input that won't succeed without changes.

Fixes #158
This commit is contained in:
parent
2d8e9d5e00
commit
55b60f5c5c
1 changed file with 24 additions and 0 deletions
|
|
@ -61,6 +61,9 @@ class RetryConfig:
|
|||
enable_retry_logic: bool = True
|
||||
"""Whether to enable automatic retries on failure."""
|
||||
|
||||
max_retries: int = 5
|
||||
"""Maximum number of retry attempts before giving up. Set to 0 for unlimited retries."""
|
||||
|
||||
retryable_exceptions: tuple[Type[Exception], ...] = (
|
||||
asyncio.TimeoutError,
|
||||
tinker.APIConnectionError,
|
||||
|
|
@ -82,6 +85,7 @@ class RetryConfig:
|
|||
self.retry_delay_max,
|
||||
self.jitter_factor,
|
||||
self.enable_retry_logic,
|
||||
self.max_retries,
|
||||
self.retryable_exceptions,
|
||||
)
|
||||
)
|
||||
|
|
@ -241,6 +245,13 @@ class RetryHandler(Generic[T]): # noqa: UP046
|
|||
logger.error(f"Request failed with non-retryable error: {exception_str}")
|
||||
raise
|
||||
|
||||
# Check if we've exceeded max retries (0 means unlimited)
|
||||
if self.config.max_retries > 0 and attempt_count >= self.config.max_retries:
|
||||
logger.error(
|
||||
f"Request failed after {attempt_count} attempts (max_retries={self.config.max_retries}): {exception_str}"
|
||||
)
|
||||
raise
|
||||
|
||||
self._log_retry_reason(e, attempt_count)
|
||||
self._retry_count += 1
|
||||
|
||||
|
|
@ -263,6 +274,14 @@ class RetryHandler(Generic[T]): # noqa: UP046
|
|||
if isinstance(exception, tinker.APIStatusError):
|
||||
return is_retryable_status_code(exception.status_code)
|
||||
|
||||
# Retry RequestFailedError when the error category is Unknown, as these
|
||||
# may be transient server-side issues that could succeed on retry.
|
||||
# User errors (category=User) are not retried as they require client changes.
|
||||
if isinstance(exception, tinker.RequestFailedError):
|
||||
from tinker.types import RequestErrorCategory
|
||||
|
||||
return exception.category == RequestErrorCategory.Unknown
|
||||
|
||||
return False
|
||||
|
||||
def _log_retry_reason(self, exception: Exception, attempt_count: int):
|
||||
|
|
@ -275,6 +294,11 @@ class RetryHandler(Generic[T]): # noqa: UP046
|
|||
logger.debug(
|
||||
f"Request attempt #{attempt_count} failed with status {exception.status_code}"
|
||||
)
|
||||
elif isinstance(exception, tinker.RequestFailedError):
|
||||
logger.debug(
|
||||
f"Request attempt #{attempt_count} failed with category {exception.category}: "
|
||||
f"{exception.message}"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"Request attempt #{attempt_count} failed with error: {exception}")
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue