mirror of
https://github.com/InternLM/InternBootcamp.git
synced 2026-04-19 12:58:04 +00:00
61 lines
1.3 KiB
Python
Executable file
61 lines
1.3 KiB
Python
Executable file
# Copyright (c) InternLM. All rights reserved.
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass
|
|
from typing import (
|
|
Dict,
|
|
Generic,
|
|
List,
|
|
Optional,
|
|
Type,
|
|
TypedDict,
|
|
TypeVar,
|
|
Union,
|
|
)
|
|
|
|
# Type parameter of JudgeStatus; it is the type of JudgeStatus.handle.
T = TypeVar("T")


class MessageItem(TypedDict):
    """Dict with two string entries, ``role`` and ``content``.

    NOTE(review): presumably one chat-style message -- confirm with callers.
    """

    role: str
    content: str


# A reward is a single float, a list of floats, or None.
Reward = Optional[Union[float, List[float]]]


class MetaData(TypedDict):
    """Dict carrying a single string entry, ``data_source``."""

    data_source: str
|
|
|
|
|
|
@dataclass()
class JudgeStatus(Generic[T]):
    """Status record exchanged between the two BaseJudger hooks.

    ``BaseJudger.on_data_received`` returns an instance and
    ``BaseJudger.on_reward_required`` takes one as its ``status`` argument.
    The class is generic over ``T``, the type of ``handle``.
    """

    # Success flag; a freshly constructed status is "ok" unless told otherwise.
    ok: bool = True
    # Optional free-text string; presumably explains a failure -- TODO confirm.
    reason: Union[str, None] = None
    # Optional payload of type T; its meaning is left to the concrete judger.
    handle: Union[T, None] = None
|
|
|
|
|
|
class BaseJudger(ABC):
    """Abstract interface for judgers.

    Both ``on_data_received`` and ``on_reward_required`` are abstract, so a
    subclass must override the two of them before it can be instantiated.
    """

    def __init__(self):
        """Initialise the base class; no state is set up at this level."""

    @abstractmethod
    def on_data_received(
        self,
        prompt_messages: List[MessageItem],
        completion_messages: List[MessageItem],
        metadata: dict,
    ) -> JudgeStatus:
        """Accept one sample and report a JudgeStatus for it.

        Args:
            prompt_messages: List of MessageItem dicts.
                NOTE(review): presumably the prompt side of the
                conversation -- confirm with callers.
            completion_messages: List of MessageItem dicts.
                NOTE(review): presumably the generated reply -- confirm.
            metadata: Plain dict of extra information about the sample;
                its schema is not fixed by this signature.

        Returns:
            A JudgeStatus produced by the concrete judger.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError

    @abstractmethod
    def on_reward_required(
        self, status: JudgeStatus, timeout: Optional[float] = None
    ) -> Reward:
        """Turn a JudgeStatus into a reward.

        Args:
            status: A JudgeStatus, i.e. the type ``on_data_received`` returns.
            timeout: Optional float limit, None by default.
                NOTE(review): unit is presumably seconds -- confirm.

        Returns:
            A Reward: a float, a list of floats, or None.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError
|
|
|
|
|
|
# Module-level registry mapping a judger name to the BaseJudger subclass
# stored under that name. Entries are added by the decorator that
# register_judger() returns.
registered_judgers: Dict[str, Type[BaseJudger]] = {}
|
|
|
|
|
|
def register_judger(name: str):
    """Build a class decorator that records a class under ``name``.

    Typical use::

        @register_judger("my_judger")
        class MyJudger(BaseJudger):
            ...

    Args:
        name: Key under which the decorated class is stored in the
            module-level ``registered_judgers`` mapping.

    Returns:
        A decorator that stores the class it receives in
        ``registered_judgers[name]`` and returns that same class unchanged.

    Raises:
        AssertionError: Raised by the returned decorator (when it is applied
            to a class, not when ``register_judger`` itself is called) if
            ``name`` is already present in ``registered_judgers``.
    """

    def wrapper(cls):
        # Raise explicitly instead of using an `assert` statement: asserts are
        # removed under `python -O`, which would let a duplicate name silently
        # overwrite the earlier registration. The exception type and message
        # are kept so existing callers see the same failure.
        if name in registered_judgers:
            raise AssertionError(f"{name} already in {registered_judgers}")
        registered_judgers[name] = cls
        return cls

    return wrapper
|