This commit is contained in:
Alexey Gorbatovski 2025-07-15 14:37:55 +03:00
parent 9dbef4e552
commit 53984580c8

View file

@ -178,7 +178,6 @@ class BaseEnvConfig(BaseModel):
class BaseEnv(ABC):
name: Optional[str] = None
env_config_cls: BaseEnvConfig = BaseEnvConfig
server_cls: APIServer = APIServer
@ -224,7 +223,6 @@ class BaseEnv(ABC):
self.checkpoint_dir = ""
self.checkpoint_interval = -1
if self.config.data_path_to_save_groups is not None:
Path(self.config.data_path_to_save_groups).parent.mkdir(
parents=True, exist_ok=True
)
@ -286,7 +284,9 @@ class BaseEnv(ABC):
"Handle env single method must be implemented in subclass "
)
async def collect_trajectories(self, item: Item) -> Tuple[
async def collect_trajectories(
self, item: Item
) -> Tuple[
Union[
Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]], List[Any | None]
],
@ -830,7 +830,7 @@ class BaseEnv(ABC):
group.setdefault("group_overrides", None)
for mask in group["masks"]:
self.completion_lengths.append(len(mask))
self.completion_lengths.append(sum(m != -100 for m in mask))
if self.max_token_len <= 0:
warnings.warn(
@ -1461,9 +1461,10 @@ class BaseEnv(ABC):
"""
# Get the default configurations from the specific environment class via config_init
default_env_config_from_init, default_server_configs_from_init = (
cls.config_init()
)
(
default_env_config_from_init,
default_server_configs_from_init,
) = cls.config_init()
# Define namespace prefixes
env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}"
@ -1549,9 +1550,9 @@ class BaseEnv(ABC):
env_config_dict_base["ensure_scores_are_not_same"] = False
env_config_dict_base["include_messages"] = True
if env_config_dict_base.get("data_path_to_save_groups") is None:
env_config_dict_base["data_path_to_save_groups"] = (
f"data/{cls.name or 'groups'}.jsonl"
)
env_config_dict_base[
"data_path_to_save_groups"
] = f"data/{cls.name or 'groups'}.jsonl"
env_config_dict_base["use_wandb"] = True
env_config_dict = merge_dicts(
@ -1729,9 +1730,10 @@ class BaseEnv(ABC):
type: The CliEvaluateConfig class for evaluate commands.
"""
# Get the default configurations from the specific environment class via config_init
default_env_config_from_init, default_server_configs_from_init = (
cls.config_init()
)
(
default_env_config_from_init,
default_server_configs_from_init,
) = cls.config_init()
# Define namespace prefixes
env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}"