mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
Bug fix
This commit is contained in:
parent
9dbef4e552
commit
53984580c8
1 changed files with 15 additions and 13 deletions
|
|
@ -178,7 +178,6 @@ class BaseEnvConfig(BaseModel):
|
|||
|
||||
|
||||
class BaseEnv(ABC):
|
||||
|
||||
name: Optional[str] = None
|
||||
env_config_cls: BaseEnvConfig = BaseEnvConfig
|
||||
server_cls: APIServer = APIServer
|
||||
|
|
@ -224,7 +223,6 @@ class BaseEnv(ABC):
|
|||
self.checkpoint_dir = ""
|
||||
self.checkpoint_interval = -1
|
||||
if self.config.data_path_to_save_groups is not None:
|
||||
|
||||
Path(self.config.data_path_to_save_groups).parent.mkdir(
|
||||
parents=True, exist_ok=True
|
||||
)
|
||||
|
|
@ -286,7 +284,9 @@ class BaseEnv(ABC):
|
|||
"Handle env single method must be implemented in subclass "
|
||||
)
|
||||
|
||||
async def collect_trajectories(self, item: Item) -> Tuple[
|
||||
async def collect_trajectories(
|
||||
self, item: Item
|
||||
) -> Tuple[
|
||||
Union[
|
||||
Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]], List[Any | None]
|
||||
],
|
||||
|
|
@ -830,7 +830,7 @@ class BaseEnv(ABC):
|
|||
group.setdefault("group_overrides", None)
|
||||
|
||||
for mask in group["masks"]:
|
||||
self.completion_lengths.append(len(mask))
|
||||
self.completion_lengths.append(sum(m != -100 for m in mask))
|
||||
|
||||
if self.max_token_len <= 0:
|
||||
warnings.warn(
|
||||
|
|
@ -1461,9 +1461,10 @@ class BaseEnv(ABC):
|
|||
"""
|
||||
|
||||
# Get the default configurations from the specific environment class via config_init
|
||||
default_env_config_from_init, default_server_configs_from_init = (
|
||||
cls.config_init()
|
||||
)
|
||||
(
|
||||
default_env_config_from_init,
|
||||
default_server_configs_from_init,
|
||||
) = cls.config_init()
|
||||
|
||||
# Define namespace prefixes
|
||||
env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}"
|
||||
|
|
@ -1549,9 +1550,9 @@ class BaseEnv(ABC):
|
|||
env_config_dict_base["ensure_scores_are_not_same"] = False
|
||||
env_config_dict_base["include_messages"] = True
|
||||
if env_config_dict_base.get("data_path_to_save_groups") is None:
|
||||
env_config_dict_base["data_path_to_save_groups"] = (
|
||||
f"data/{cls.name or 'groups'}.jsonl"
|
||||
)
|
||||
env_config_dict_base[
|
||||
"data_path_to_save_groups"
|
||||
] = f"data/{cls.name or 'groups'}.jsonl"
|
||||
env_config_dict_base["use_wandb"] = True
|
||||
|
||||
env_config_dict = merge_dicts(
|
||||
|
|
@ -1729,9 +1730,10 @@ class BaseEnv(ABC):
|
|||
type: The CliEvaluateConfig class for evaluate commands.
|
||||
"""
|
||||
# Get the default configurations from the specific environment class via config_init
|
||||
default_env_config_from_init, default_server_configs_from_init = (
|
||||
cls.config_init()
|
||||
)
|
||||
(
|
||||
default_env_config_from_init,
|
||||
default_server_configs_from_init,
|
||||
) = cls.config_init()
|
||||
|
||||
# Define namespace prefixes
|
||||
env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue