diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py index 6180af09..278a5a05 100644 --- a/atroposlib/envs/base.py +++ b/atroposlib/envs/base.py @@ -178,7 +178,6 @@ class BaseEnvConfig(BaseModel): class BaseEnv(ABC): - name: Optional[str] = None env_config_cls: BaseEnvConfig = BaseEnvConfig server_cls: APIServer = APIServer @@ -224,7 +223,6 @@ class BaseEnv(ABC): self.checkpoint_dir = "" self.checkpoint_interval = -1 if self.config.data_path_to_save_groups is not None: - Path(self.config.data_path_to_save_groups).parent.mkdir( parents=True, exist_ok=True ) @@ -286,7 +284,9 @@ class BaseEnv(ABC): "Handle env single method must be implemented in subclass " ) - async def collect_trajectories(self, item: Item) -> Tuple[ + async def collect_trajectories( + self, item: Item + ) -> Tuple[ Union[ Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]], List[Any | None] ], @@ -830,7 +830,7 @@ class BaseEnv(ABC): group.setdefault("group_overrides", None) for mask in group["masks"]: - self.completion_lengths.append(len(mask)) + self.completion_lengths.append(sum(m != -100 for m in mask)) if self.max_token_len <= 0: warnings.warn( @@ -1461,9 +1461,10 @@ class BaseEnv(ABC): """ # Get the default configurations from the specific environment class via config_init - default_env_config_from_init, default_server_configs_from_init = ( - cls.config_init() - ) + ( + default_env_config_from_init, + default_server_configs_from_init, + ) = cls.config_init() # Define namespace prefixes env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}" @@ -1549,9 +1550,9 @@ class BaseEnv(ABC): env_config_dict_base["ensure_scores_are_not_same"] = False env_config_dict_base["include_messages"] = True if env_config_dict_base.get("data_path_to_save_groups") is None: - env_config_dict_base["data_path_to_save_groups"] = ( - f"data/{cls.name or 'groups'}.jsonl" - ) + env_config_dict_base[ + "data_path_to_save_groups" + ] = f"data/{cls.name or 'groups'}.jsonl" env_config_dict_base["use_wandb"] = True env_config_dict = merge_dicts( @@ -1729,9 +1730,10 @@ class BaseEnv(ABC): type: The CliEvaluateConfig class for evaluate commands. """ # Get the default configurations from the specific environment class via config_init - default_env_config_from_init, default_server_configs_from_init = ( - cls.config_init() - ) + ( + default_env_config_from_init, + default_server_configs_from_init, + ) = cls.config_init() # Define namespace prefixes env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}"