InternBootcamp/examples/xpuyu_usage/bootcamp_rl/judgers/base_judger.py
2025-05-23 15:27:15 +08:00

61 lines
1.3 KiB
Python
Executable file

# Copyright (c) InternLM. All rights reserved.
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import (
Dict,
Generic,
List,
Optional,
Type,
TypedDict,
TypeVar,
Union,
)
# Type parameter for the payload carried by ``JudgeStatus.handle``.
T = TypeVar("T")


class MessageItem(TypedDict):
    """A single chat message: a ``role`` string plus its ``content`` string."""

    role: str
    content: str


# A reward is a single float, a list of floats, or None.
Reward = Union[float, List[float], None]


class MetaData(TypedDict):
    """Metadata dictionary with a single ``data_source`` string key."""

    data_source: str
@dataclass
class JudgeStatus(Generic[T]):
    """Status record returned by a judger, generic over the type of ``handle``.

    All three fields have defaults, so ``JudgeStatus()`` yields
    ``ok=True, reason=None, handle=None``.
    """

    # Success flag; defaults to True.
    ok: bool = True
    # Optional free-text explanation (presumably set when ``ok`` is False --
    # confirm against the concrete judgers).
    reason: Optional[str] = None
    # Optional payload of type ``T``; its meaning is defined by the judger
    # that produces it.
    handle: Optional[T] = None
class BaseJudger(ABC):
    """Abstract base class for judgers.

    Concrete subclasses must implement both ``on_data_received`` and
    ``on_reward_required``; the class cannot be instantiated until they do.
    """

    def __init__(self):
        """Initialise the base judger; there is no state to set up here."""

    @abstractmethod
    def on_data_received(
        self,
        prompt_messages: List[MessageItem],
        completion_messages: List[MessageItem],
        metadata: dict,
    ) -> JudgeStatus:
        """Accept one prompt/completion pair and return a ``JudgeStatus``.

        Args:
            prompt_messages: Messages that make up the prompt.
            completion_messages: Messages that make up the completion.
            metadata: Extra information as a plain dict; its schema is not
                defined in this class.

        Returns:
            A ``JudgeStatus`` produced by the subclass.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError

    @abstractmethod
    def on_reward_required(
        self,
        status: JudgeStatus,
        timeout: Optional[float] = None,
    ) -> Reward:
        """Produce the reward for a given ``JudgeStatus``.

        Args:
            status: A ``JudgeStatus`` (presumably the one returned by
                ``on_data_received`` -- confirm against callers).
            timeout: Optional time limit; ``None`` by default. Units are not
                specified here (presumably seconds -- confirm).

        Returns:
            A ``Reward``: a float, a list of floats, or ``None``.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError
# Module-level registry mapping a judger name to the class registered under it.
# ``BaseJudger`` is written as a string forward reference so the annotation
# does not have to be resolvable at the moment this line is evaluated.
registered_judgers: Dict[str, Type["BaseJudger"]] = {}


def register_judger(name: str):
    """Return a class decorator that records a class in ``registered_judgers``.

    Args:
        name: Key under which the decorated class is stored.

    Returns:
        A decorator that stores the class it receives under ``name`` and
        returns that class unchanged.

    Raises:
        AssertionError: Raised by the returned decorator (not by this call)
            when ``name`` is already present in ``registered_judgers``.
    """

    def wrapper(cls):
        # Raise explicitly instead of using an ``assert`` statement: asserts
        # are removed when Python runs with ``-O``, which would let a
        # duplicate name silently overwrite the earlier registration.
        if name in registered_judgers:
            raise AssertionError(f"{name} already in {registered_judgers}")
        registered_judgers[name] = cls
        return cls

    return wrapper