fix: Retry RequestFailedError with Unknown category

Training jobs using on-policy methods like GKD would fail permanently
when encountering transient server errors during KL penalty computation.
The RetryHandler wasn't configured to retry RequestFailedError exceptions,
causing the entire training run to abort on the first occurrence.

Changes:
- Add retry logic for RequestFailedError when category is Unknown
- Skip retry for User/Server categories (these require code changes)
- Add max_retries parameter (default 5; 0 means unlimited) to prevent infinite retry loops
- Improve logging to show error category for debugging

The Unknown category may indicate transient server-side issues that often
resolve on retry, similar to 5xx HTTP errors. User errors are not retried
since they indicate invalid input that won't succeed without changes.

Fixes #158
This commit is contained in:
Blake Ledden 2025-12-20 22:55:50 -08:00
parent 2d8e9d5e00
commit 55b60f5c5c

View file

@ -61,6 +61,9 @@ class RetryConfig:
enable_retry_logic: bool = True
"""Whether to enable automatic retries on failure."""
max_retries: int = 5
"""Maximum number of retry attempts before giving up. Set to 0 for unlimited retries."""
retryable_exceptions: tuple[Type[Exception], ...] = (
asyncio.TimeoutError,
tinker.APIConnectionError,
@ -82,6 +85,7 @@ class RetryConfig:
self.retry_delay_max,
self.jitter_factor,
self.enable_retry_logic,
self.max_retries,
self.retryable_exceptions,
)
)
@ -241,6 +245,13 @@ class RetryHandler(Generic[T]): # noqa: UP046
logger.error(f"Request failed with non-retryable error: {exception_str}")
raise
# Check if we've exceeded max retries (0 means unlimited)
if self.config.max_retries > 0 and attempt_count >= self.config.max_retries:
logger.error(
f"Request failed after {attempt_count} attempts (max_retries={self.config.max_retries}): {exception_str}"
)
raise
self._log_retry_reason(e, attempt_count)
self._retry_count += 1
@ -263,6 +274,14 @@ class RetryHandler(Generic[T]): # noqa: UP046
if isinstance(exception, tinker.APIStatusError):
return is_retryable_status_code(exception.status_code)
# Retry RequestFailedError when the error category is Unknown, as these
# may be transient server-side issues that could succeed on retry.
# User errors (category=User) are not retried as they require client changes.
if isinstance(exception, tinker.RequestFailedError):
from tinker.types import RequestErrorCategory
return exception.category == RequestErrorCategory.Unknown
return False
def _log_retry_reason(self, exception: Exception, attempt_count: int):
@ -275,6 +294,11 @@ class RetryHandler(Generic[T]): # noqa: UP046
logger.debug(
f"Request attempt #{attempt_count} failed with status {exception.status_code}"
)
elif isinstance(exception, tinker.RequestFailedError):
logger.debug(
f"Request attempt #{attempt_count} failed with category {exception.category}: "
f"{exception.message}"
)
else:
logger.debug(f"Request attempt #{attempt_count} failed with error: {exception}")