diff --git a/environments/dataset_environment/dataset_env.py b/environments/dataset_environment/dataset_env.py index 23548c73..ecfc9ec2 100644 --- a/environments/dataset_environment/dataset_env.py +++ b/environments/dataset_environment/dataset_env.py @@ -78,6 +78,7 @@ class DatasetEnv(BaseEnv): self.metric_buffer = {} self.reward_function = self._initialize_reward_function() + self.current_item = None def _initialize_reward_function(self): if hasattr(self.config, "reward_functions") and self.config.reward_functions: