mirror of
https://github.com/InternLM/InternBootcamp.git
synced 2026-04-19 12:58:04 +00:00
init-commit
This commit is contained in:
commit
18a552597a
3461 changed files with 1150579 additions and 0 deletions
61
examples/xpuyu_usage/bootcamp_rl/judgers/base_judger.py
Executable file
61
examples/xpuyu_usage/bootcamp_rl/judgers/base_judger.py
Executable file
|
|
@ -0,0 +1,61 @@
|
|||
# Copyright (c) InternLM. All rights reserved.
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
Optional,
|
||||
Type,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
T = TypeVar("T")
|
||||
MessageItem = TypedDict("MessageItem", {"role": str, "content": str})
|
||||
Reward = Union[float, List[float], None]
|
||||
MetaData = TypedDict("MetaData", {"data_source": str})
|
||||
|
||||
|
||||
@dataclass
|
||||
class JudgeStatus(Generic[T]):
|
||||
ok: bool = True
|
||||
reason: Optional[str] = None
|
||||
handle: Optional[T] = None
|
||||
|
||||
|
||||
class BaseJudger(ABC):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def on_data_received(
|
||||
self,
|
||||
prompt_messages: List[MessageItem],
|
||||
completion_messages: List[MessageItem],
|
||||
metadata: dict,
|
||||
) -> JudgeStatus:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def on_reward_required(
|
||||
self,
|
||||
status: JudgeStatus,
|
||||
timeout: Optional[float] = None,
|
||||
) -> Reward:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
registered_judgers: Dict[str, Type[BaseJudger]] = {}
|
||||
|
||||
|
||||
def register_judger(name: str):
|
||||
global registered_judgers
|
||||
|
||||
def wrapper(cls):
|
||||
assert name not in registered_judgers, f"{name} already in {registered_judgers}"
|
||||
registered_judgers[name] = cls
|
||||
return cls
|
||||
|
||||
return wrapper
|
||||
Loading…
Add table
Add a link
Reference in a new issue