diff --git a/environments/community/sql_query_env/sql_query_env.py b/environments/community/sql_query_env/sql_query_env.py index f42b636d..b5d1d93b 100644 --- a/environments/community/sql_query_env/sql_query_env.py +++ b/environments/community/sql_query_env/sql_query_env.py @@ -343,6 +343,7 @@ class SQLQueryEnv(BaseEnv): scores["tokens"] = list() scores["masks"] = list() scores["scores"] = list() + scores["inference_logprobs"] = list() # Get table info from first item gold_sql = rollout_group_data[0]["gold_sql"]