diff --git a/atroposlib/api/server.py b/atroposlib/api/server.py index 02f734ba..4c6e0ad1 100644 --- a/atroposlib/api/server.py +++ b/atroposlib/api/server.py @@ -200,6 +200,8 @@ def _process_scored_data(scored_data: ScoredData) -> Dict[str, Any]: app.state.queue = [] if not hasattr(app.state, "buffer"): app.state.buffer = {} + if not hasattr(app.state, "total_rollouts_processed"): + app.state.total_rollouts_processed = 0 data_dict = _scored_data_to_dict(scored_data) env_id = data_dict.get("env_id") @@ -477,7 +479,7 @@ async def scored_data_list(scored_data_list: List[ScoredData]): async def get_status(): try: return { - "current_step": app.state.status_dict["step"], + "current_step": app.state.status_dict.get("step", 0), "queue_size": len(app.state.queue), } except AttributeError: @@ -541,7 +543,8 @@ async def get_status_env(env: EnvIdentifier): # Calculate total minimum allocations total_min_allocation = 0.0 - for env_config in app.state.envs: + envs = getattr(app.state, "envs", []) + for env_config in envs: if ( env_config.get("connected", False) and env_config.get("min_batch_allocation") is not None