Fix up dataset and data dumps

This commit is contained in:
teknium1 2025-05-26 01:50:22 -07:00
parent ae0340bb9f
commit e2ea82b29b

View file

@ -10,6 +10,7 @@ import logging # Add logging import
import os # Add os import
import random # Ensured import random is present
import re
import uuid # Import uuid module
from difflib import SequenceMatcher
from typing import Dict, List, Optional, Tuple, Union
@ -100,11 +101,11 @@ class SWERLEnvConfig(BaseEnvConfig):
# HF Dataset Configs
dataset_name: str = Field(
default="princeton-nlp/SWE-bench_Lite_oracle",
description="Name of the Hugging Face dataset to load.",
description="Name of the Hugging Face dataset to load for training (and evaluation if dataset_name_eval is not set).", # noqa: E501
)
dataset_config_name: Optional[str] = Field(
default=None,
description="Configuration name for the Hugging Face dataset (e.g., a subset).",
description="Configuration name for the Hugging Face dataset for training (e.g., a subset) (and evaluation if dataset_config_name_eval is not set).", # noqa: E501
)
dataset_split_train: str = Field(
default="train", description="Dataset split to use for training."
@ -122,6 +123,15 @@ class SWERLEnvConfig(BaseEnvConfig):
dataset_oracle_patch_column: str = Field(
default="patch", description="Column name for the oracle patch."
)
# New fields for evaluation dataset
dataset_name_eval: Optional[str] = Field(
default=None,
description="Optional: Name of the Hugging Face dataset to load for evaluation. If None, uses dataset_name.",
)
dataset_config_name_eval: Optional[str] = Field(
default=None,
description="Optional: Configuration name for the Hugging Face dataset for evaluation. If None, uses dataset_config_name.", # noqa: E501
)
max_train_samples: Optional[int] = Field(
default=None,
description="Maximum number of training samples to load. None for all.",
@ -185,13 +195,14 @@ class SWERLEnv(BaseEnv):
self.think_tags_well_formed_buffer = []
# For saving rollouts to JSONL
self.run_uuid = str(uuid.uuid4()) # Generate a UUID for this run
self.rollouts_to_save_buffer: List[
Dict[str, Union[str, List[Dict[str, Union[List[Dict[str, str]], float]]]]]
] = []
self.processed_item_count = 0
# Creates .../atropos/environments/datadumps/ relative to the project structure
# Creates .../atropos/environments/swe_rl/data_dumps/ relative to the project structure
self.datadumps_dir = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "datadumps"
os.path.dirname(os.path.abspath(__file__)), "swe_rl", "data_dumps"
)
self.save_file_batch_num = 0
@ -216,15 +227,18 @@ class SWERLEnv(BaseEnv):
ensure_scores_are_not_same=True,
eval_n_samples=1,
# HF Dataset Configs
dataset_name="princeton-nlp/SWE-bench_Lite_oracle",
dataset_name="NousResearch/SWE-smith-oracle",
dataset_config_name=None,
dataset_split_train="test",
dataset_split_train="train",
dataset_split_eval="test",
dataset_issue_column="problem_statement",
dataset_code_context_column="text",
dataset_oracle_patch_column="patch",
max_train_samples=1000,
max_test_samples=200,
# Initialize new eval dataset fields
dataset_name_eval=None, # Or a default eval dataset name
dataset_config_name_eval=None, # Or a default eval config name
)
server_configs = [
APIServerConfig(
@ -238,11 +252,14 @@ class SWERLEnv(BaseEnv):
return env_config, server_configs
async def setup(self):
self.logger.info(
f"Setting up SWE-RL environment with dataset: {self.config.dataset_name}"
)
self.logger.info("Setting up SWE-RL environment...")
self.train_dataset = []
self.test_dataset = []
train_data_raw = None
eval_data_raw = None
eval_sampled_from_train = (
False # Flag to track if eval set was sampled from train
)
try:
# Load training data
@ -253,96 +270,174 @@ class SWERLEnv(BaseEnv):
path=self.config.dataset_name,
name=self.config.dataset_config_name,
split=self.config.dataset_split_train,
streaming=False,
streaming=False, # Ensure we get a Dataset object for potential sampling
)
if self.config.max_train_samples is not None:
self.logger.info(
f"Taking max {self.config.max_train_samples} training samples."
f"Applying max_train_samples: {self.config.max_train_samples} to training data."
)
# For datasets.Dataset, select directly. For IterableDataset, .take()
if hasattr(train_data_raw, "select"): # It's a Dataset
indices = range(
min(len(train_data_raw), self.config.max_train_samples)
)
train_data_raw = train_data_raw.select(indices)
else: # It's an IterableDataset
else: # It's an IterableDataset (less likely with streaming=False)
train_data_raw = train_data_raw.take(self.config.max_train_samples)
# Load evaluation data
self.logger.info(
f"Loading evaluation data from {self.config.dataset_name}, split: {self.config.dataset_split_eval}"
f"Training data size after max_train_samples: {len(train_data_raw) if hasattr(train_data_raw, '__len__') else 'unknown (iterable)'}" # noqa: E501
)
eval_data_raw = load_dataset(
path=self.config.dataset_name,
name=self.config.dataset_config_name,
split=self.config.dataset_split_eval,
streaming=False,
)
if self.config.max_test_samples is not None:
# --- Evaluation Data Handling ---
if self.config.dataset_name_eval is not None:
# Load dedicated evaluation dataset
eval_dataset_name = self.config.dataset_name_eval
eval_dataset_config_name = self.config.dataset_config_name_eval
self.logger.info(
f"Taking max {self.config.max_test_samples} test samples."
f"Loading dedicated evaluation data from {eval_dataset_name}, split: {self.config.dataset_split_eval}" # noqa: E501
)
if hasattr(eval_data_raw, "select"): # It's a Dataset
indices = range(
min(len(eval_data_raw), self.config.max_test_samples)
)
eval_data_raw = eval_data_raw.select(indices)
else: # It's an IterableDataset
eval_data_raw = eval_data_raw.take(self.config.max_test_samples)
eval_data_raw = load_dataset(
path=eval_dataset_name,
name=eval_dataset_config_name,
split=self.config.dataset_split_eval,
streaming=False,
)
else:
# No dedicated eval dataset, try to sample from training data
self.logger.info(
"No specific evaluation dataset provided. Attempting to sample up to 500 instances from training data for evaluation." # noqa: E501
)
num_eval_samples_to_take = 500
self.logger.info("Mapping dataset columns...")
for item_idx, raw_item in enumerate(train_data_raw):
try:
self.train_dataset.append(
{
"item_id": f"train_{self.config.dataset_name}_{item_idx}",
"issue": raw_item[self.config.dataset_issue_column],
"code_context": raw_item[
self.config.dataset_code_context_column
],
"oracle_patch": raw_item[
self.config.dataset_oracle_patch_column
],
}
)
except KeyError as e:
self.logger.error(
f"Column mapping error for training item {item_idx}: {e}. Skipping item. Raw: {str(raw_item)[:500]}" # noqa: E501
)
continue
if (
train_data_raw
and hasattr(train_data_raw, "select")
and len(train_data_raw) >= num_eval_samples_to_take
):
all_train_indices = list(range(len(train_data_raw)))
random.shuffle(all_train_indices)
for item_idx, raw_item in enumerate(eval_data_raw):
try:
self.test_dataset.append(
{
"item_id": f"test_{self.config.dataset_name}_{item_idx}",
"issue": raw_item[self.config.dataset_issue_column],
"code_context": raw_item[
self.config.dataset_code_context_column
],
"oracle_patch": raw_item[
self.config.dataset_oracle_patch_column
],
}
eval_indices = all_train_indices[:num_eval_samples_to_take]
remaining_train_indices = all_train_indices[
num_eval_samples_to_take:
]
eval_data_raw = train_data_raw.select(eval_indices)
original_train_len = len(train_data_raw)
train_data_raw = train_data_raw.select(
remaining_train_indices
) # Update train_data_raw
self.logger.info(
f"Successfully sampled {len(eval_data_raw)} instances from training data (original size {original_train_len}) for evaluation. " # noqa: E501
f"Remaining training data size: {len(train_data_raw)}." # noqa: E501
)
except KeyError as e:
self.logger.error(
f"Column mapping error for eval item {item_idx}: {e}. Skipping item. Raw: {str(raw_item)[:500]}"
eval_sampled_from_train = True
elif train_data_raw and hasattr(
train_data_raw, "select"
): # Train data exists but not enough samples
self.logger.warning(
f"Training data has {len(train_data_raw)} samples, which is less than the desired {num_eval_samples_to_take} for evaluation sampling. " # noqa: E501
f"No evaluation set will be derived from training data by sampling." # noqa: E501
)
continue
else: # train_data_raw is None, not a Dataset for sampling, or other issue
self.logger.warning(
"Could not sample from training data for evaluation (e.g., training data not loaded, not a suitable type for sampling, or empty). " # noqa: E501
"Evaluation set will be empty." # noqa: E501
)
# Apply max_test_samples to the resulting eval_data_raw, regardless of its origin
if eval_data_raw and self.config.max_test_samples is not None:
if len(eval_data_raw) > self.config.max_test_samples:
self.logger.info(
f"Applying max_test_samples: Capping evaluation set from {len(eval_data_raw)} to {self.config.max_test_samples}." # noqa: E501
)
if hasattr(eval_data_raw, "select"):
indices = range(
min(len(eval_data_raw), self.config.max_test_samples)
)
eval_data_raw = eval_data_raw.select(indices)
elif hasattr(
eval_data_raw, "take"
): # Fallback for IterableDataset (less likely here)
eval_data_raw = eval_data_raw.take(self.config.max_test_samples)
else:
self.logger.info(
f"Evaluation set has {len(eval_data_raw)} samples. max_test_samples ({self.config.max_test_samples}) is >= this, so no change to eval set size based on this cap." # noqa: E501
)
elif not eval_data_raw and self.config.max_test_samples is not None:
self.logger.info(
"max_test_samples is set, but there is no evaluation data to apply it to at this stage." # noqa: E501
)
# --- Populate self.train_dataset and self.test_dataset ---
self.logger.info("Mapping dataset columns for final train/test sets...")
if (
train_data_raw
): # train_data_raw could be empty if all were taken for eval
for item_idx, raw_item in enumerate(train_data_raw):
try:
self.train_dataset.append(
{
"item_id": f"train_{self.config.dataset_name.replace('/', '_')}_{item_idx}",
"issue": raw_item[self.config.dataset_issue_column],
"code_context": raw_item[
self.config.dataset_code_context_column
],
"oracle_patch": raw_item[
self.config.dataset_oracle_patch_column
],
}
)
except KeyError as e:
self.logger.error(
f"Column mapping error for training item {item_idx}: {e}. Skipping item. Raw: {str(raw_item)[:500]}" # noqa: E501
)
continue
eval_ds_name_for_item_id = "unknown_eval_source"
if self.config.dataset_name_eval:
eval_ds_name_for_item_id = self.config.dataset_name_eval.replace(
"/", "_"
)
elif eval_sampled_from_train:
eval_ds_name_for_item_id = (
f"{self.config.dataset_name.replace('/', '_')}_sampled_as_eval"
)
if eval_data_raw:
for item_idx, raw_item in enumerate(eval_data_raw):
try:
self.test_dataset.append(
{
"item_id": f"test_{eval_ds_name_for_item_id}_{item_idx}",
"issue": raw_item[self.config.dataset_issue_column],
"code_context": raw_item[
self.config.dataset_code_context_column
],
"oracle_patch": raw_item[
self.config.dataset_oracle_patch_column
],
}
)
except KeyError as e:
self.logger.error(
f"Column mapping error for eval item {item_idx} from '{eval_ds_name_for_item_id}': {e}. Skipping item. Raw: {str(raw_item)[:500]}" # noqa: E501
)
continue
random.shuffle(self.train_dataset)
random.shuffle(self.test_dataset)
self.logger.info(
f"Loaded dataset: {len(self.train_dataset)} training examples, {len(self.test_dataset)} test examples."
f"Loaded dataset: {len(self.train_dataset)} training examples, {len(self.test_dataset)} test examples." # noqa: E501
)
except Exception as e:
self.logger.error(
f"Error loading or processing dataset '{self.config.dataset_name}': {e}"
f"Error loading or processing dataset: {e}" # Simplified main error message
)
self.train_dataset = []
self.train_dataset = [] # Ensure they are reset on error
self.test_dataset = []
# Consider re-raising if dataset is critical: raise e
@ -351,9 +446,26 @@ class SWERLEnv(BaseEnv):
self.logger.warning(
"Training dataset is empty after setup. Check dataset configuration and availability."
)
if not self.test_dataset:
eval_source_description_for_log = "not available"
if self.config.dataset_name_eval:
eval_source_description_for_log = (
f"dedicated set ({self.config.dataset_name_eval})"
)
elif eval_sampled_from_train:
eval_source_description_for_log = (
f"sampled from training set ({self.config.dataset_name})"
)
elif (
self.config.dataset_name_eval is None
): # No dedicated, and sampling was not attempted or failed
eval_source_description_for_log = (
"sampling from train was not possible or yielded no data"
)
self.logger.warning(
"Test dataset is empty after setup. Check dataset configuration and availability."
f"Test dataset (source: {eval_source_description_for_log}) is empty after setup. Check dataset configuration and availability." # noqa: E501
)
async def get_next_item(self) -> Optional[Dict[str, str]]:
@ -780,7 +892,8 @@ class SWERLEnv(BaseEnv):
return # Don't proceed if directory creation fails
file_path = os.path.join(
self.datadumps_dir, f"rollouts_batch_{self.save_file_batch_num:04d}.jsonl"
self.datadumps_dir,
f"swe_rl_environment_rollouts_{self.run_uuid}_{self.save_file_batch_num:04d}.jsonl",
)
try: