mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-22 16:48:57 +00:00
Merge branch 'main' into blackjack2-env
This commit is contained in:
commit
36f6822d71
9 changed files with 455 additions and 61 deletions
|
|
@ -61,6 +61,17 @@ class ScoredDataGroup(TypedDict):
|
|||
overrides: Optional[List[Dict]]
|
||||
|
||||
|
||||
class ScoredDataItem(TypedDict):
    """Scored result for a single trajectory.

    This is the per-item counterpart of ``ScoredDataGroup``: the default
    ``collect_trajectories`` implementation appends each field of an item
    onto the list of the same name in the group (and merges
    ``group_overrides`` with ``dict.update``).
    """

    # Required per-trajectory data. Presumably ``tokens`` and ``masks`` are
    # parallel sequences -- TODO confirm against the trainer.
    tokens: List[int]
    masks: List[int]
    # A single float for this item, despite the plural field name.
    scores: float

    # Optional extras; the default aggregation skips any of these that are
    # missing or None.
    advantages: Optional[List[float]]
    ref_logprobs: Optional[List[float]]
    messages: Optional[List[Message]]
    group_overrides: Optional[Dict]
    overrides: Optional[Dict]
|
||||
|
||||
|
||||
class EvalHandlingEnum(Enum):
|
||||
"""
|
||||
Enum for handling evals.
|
||||
|
|
@ -229,7 +240,9 @@ class BaseEnv(ABC):
|
|||
"""
|
||||
return cls.env_config_cls(), ServerBaseline()
|
||||
|
||||
async def collect_trajectory(
    self, item: Item
) -> Tuple[Optional[Union[ScoredDataItem, Any]], List[Item]]:
    """Collect one trajectory for ``item``; subclasses must override.

    The annotated return value is a 2-tuple: the scored result for this
    trajectory (or ``None``), followed by a list of further items, which
    the default ``collect_trajectories`` gathers into its backlog.

    NOTE(review): the default ``collect_trajectories`` shown in this diff
    rejects any first element that is not a ``dict`` -- which includes
    ``None`` -- even though its error message says ``None`` is allowed.
    Confirm the intended contract before returning ``None`` from here.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    not_implemented_message = (
        "Handle env single method must be implemented in subclass "
    )
    raise NotImplementedError(not_implemented_message)
|
||||
|
|
@ -249,13 +262,38 @@ class BaseEnv(ABC):
|
|||
for _ in range(self.config.group_size):
|
||||
tasks.append(self.collect_trajectory(item))
|
||||
results = await asyncio.gather(*tasks)
|
||||
if any(not isinstance(result[0], dict) for result in results):
|
||||
logging.error("something wasn't a ScoredDataItem")
|
||||
raise ValueError(
|
||||
"collect_trajectory must return a ScoredDataItem or None to use the default "
|
||||
"collect_trajectories method"
|
||||
)
|
||||
backlog = []
|
||||
to_postprocess = []
|
||||
to_postprocess = ScoredDataGroup()
|
||||
to_postprocess["tokens"] = []
|
||||
to_postprocess["masks"] = []
|
||||
to_postprocess["scores"] = []
|
||||
to_postprocess["advantages"] = []
|
||||
to_postprocess["ref_logprobs"] = []
|
||||
to_postprocess["messages"] = []
|
||||
to_postprocess["group_overrides"] = {}
|
||||
to_postprocess["overrides"] = []
|
||||
print("Processing results")
|
||||
for result in results:
|
||||
if result[0] is not None:
|
||||
to_postprocess.append(result[0])
|
||||
to_postprocess["tokens"].append(result[0]["tokens"])
|
||||
to_postprocess["masks"].append(result[0]["masks"])
|
||||
to_postprocess["scores"].append(result[0]["scores"])
|
||||
if result[0].get("advantages", None) is not None:
|
||||
to_postprocess["advantages"].append(result[0]["advantages"])
|
||||
if result[0].get("ref_logprobs", None) is not None:
|
||||
to_postprocess["ref_logprobs"].append(result[0]["ref_logprobs"])
|
||||
if result[0].get("messages", None) is not None:
|
||||
to_postprocess["messages"].append(result[0]["messages"])
|
||||
if result[0].get("group_overrides", None) is not None:
|
||||
to_postprocess["group_overrides"].update(result[0]["group_overrides"])
|
||||
if result[0].get("overrides", None) is not None:
|
||||
to_postprocess["overrides"].append(result[0]["overrides"])
|
||||
backlog.extend(result[1])
|
||||
random.shuffle(backlog)
|
||||
return to_postprocess, backlog
|
||||
|
||||
async def postprocess_histories(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue