add gym taxi env

This commit is contained in:
dmahan93 2025-05-09 19:05:01 -05:00
parent c1ba77ec26
commit 92428fec8f
4 changed files with 377 additions and 7 deletions

View file

@ -60,6 +60,17 @@ class ScoredDataGroup(TypedDict):
overrides: Optional[List[Dict]]
class ScoredDataItem(TypedDict):
    """Scored result of a single trajectory, expressed as a typed dict.

    The default ``collect_trajectories`` implementation reads these keys from
    each item and appends them to the same-named list of a
    ``ScoredDataGroup`` (``group_overrides`` is merged with ``dict.update``
    instead of appended).
    """

    # Token sequence for this trajectory (presumably tokenizer ids -- confirm).
    tokens: List[int]
    # One int per entry; NOTE(review): presumably aligned 1:1 with ``tokens``
    # -- confirm against the producer.
    masks: List[int]
    # A single float for the whole item, despite the plural key name.
    scores: float
    # Optional; the default merge only forwards this when it is not None.
    advantages: Optional[List[float]]
    # Optional; the default merge only forwards this when it is not None.
    ref_logprobs: Optional[List[float]]
    # Optional message list for this trajectory; forwarded only when not None.
    messages: Optional[List[Message]]
    # Optional dict merged into the group's ``group_overrides`` via ``update``,
    # so later items overwrite keys set by earlier ones.
    group_overrides: Optional[Dict]
    # Optional per-item dict appended to the group's ``overrides`` list.
    overrides: Optional[Dict]
class EvalHandlingEnum(Enum):
"""
Enum for handling evals.
@ -228,7 +239,9 @@ class BaseEnv(ABC):
"""
return cls.env_config_cls(), ServerBaseline()
async def collect_trajectory(
    self, item: Item
) -> Tuple[Optional[Union[ScoredDataItem, Any]], List[Item]]:
    """Collect one trajectory for ``item``; subclasses must override this.

    The base implementation does no work and always raises.

    Args:
        item: The work item to collect a trajectory for.

    Returns:
        A 2-tuple ``(result, new_items)``. The default
        ``collect_trajectories`` requires ``result`` to be a
        ``ScoredDataItem`` dict and extends its backlog with ``new_items``.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError(
        "Handle env single method must be implemented in subclass "
    )
@ -248,13 +261,38 @@ class BaseEnv(ABC):
for _ in range(self.config.group_size):
tasks.append(self.collect_trajectory(item))
results = await asyncio.gather(*tasks)
if any(not isinstance(result[0], dict) for result in results):
logging.error("something wasn't a ScoredDataItem")
raise ValueError(
"collect_trajectory must return a ScoredDataItem or None to use the default "
"collect_trajectories method"
)
backlog = []
to_postprocess = []
to_postprocess = ScoredDataGroup()
to_postprocess["tokens"] = []
to_postprocess["masks"] = []
to_postprocess["scores"] = []
to_postprocess["advantages"] = []
to_postprocess["ref_logprobs"] = []
to_postprocess["messages"] = []
to_postprocess["group_overrides"] = {}
to_postprocess["overrides"] = []
print("Processing results")
for result in results:
if result[0] is not None:
to_postprocess.append(result[0])
to_postprocess["tokens"].append(result[0]["tokens"])
to_postprocess["masks"].append(result[0]["masks"])
to_postprocess["scores"].append(result[0]["scores"])
if result[0].get("advantages", None) is not None:
to_postprocess["advantages"].append(result[0]["advantages"])
if result[0].get("ref_logprobs", None) is not None:
to_postprocess["ref_logprobs"].append(result[0]["ref_logprobs"])
if result[0].get("messages", None) is not None:
to_postprocess["messages"].append(result[0]["messages"])
if result[0].get("group_overrides", None) is not None:
to_postprocess["group_overrides"].update(result[0]["group_overrides"])
if result[0].get("overrides", None) is not None:
to_postprocess["overrides"].append(result[0]["overrides"])
backlog.extend(result[1])
random.shuffle(backlog)
return to_postprocess, backlog
async def postprocess_histories(