mirror of
https://github.com/NousResearch/atropos.git
synced 2026-05-03 17:53:17 +00:00
Fix up dataset and data dumps
This commit is contained in:
parent
ae0340bb9f
commit
e2ea82b29b
1 changed files with 185 additions and 72 deletions
|
|
@ -10,6 +10,7 @@ import logging # Add logging import
|
|||
import os # Add os import
|
||||
import random # Ensured import random is present
|
||||
import re
|
||||
import uuid # Import uuid module
|
||||
from difflib import SequenceMatcher
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
|
|
@ -100,11 +101,11 @@ class SWERLEnvConfig(BaseEnvConfig):
|
|||
# HF Dataset Configs
|
||||
dataset_name: str = Field(
|
||||
default="princeton-nlp/SWE-bench_Lite_oracle",
|
||||
description="Name of the Hugging Face dataset to load.",
|
||||
description="Name of the Hugging Face dataset to load for training (and evaluation if dataset_name_eval is not set).", # noqa: E501
|
||||
)
|
||||
dataset_config_name: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Configuration name for the Hugging Face dataset (e.g., a subset).",
|
||||
description="Configuration name for the Hugging Face dataset for training (e.g., a subset) (and evaluation if dataset_config_name_eval is not set).", # noqa: E501
|
||||
)
|
||||
dataset_split_train: str = Field(
|
||||
default="train", description="Dataset split to use for training."
|
||||
|
|
@ -122,6 +123,15 @@ class SWERLEnvConfig(BaseEnvConfig):
|
|||
dataset_oracle_patch_column: str = Field(
|
||||
default="patch", description="Column name for the oracle patch."
|
||||
)
|
||||
# New fields for evaluation dataset
|
||||
dataset_name_eval: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Optional: Name of the Hugging Face dataset to load for evaluation. If None, uses dataset_name.",
|
||||
)
|
||||
dataset_config_name_eval: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Optional: Configuration name for the Hugging Face dataset for evaluation. If None, uses dataset_config_name.", # noqa: E501
|
||||
)
|
||||
max_train_samples: Optional[int] = Field(
|
||||
default=None,
|
||||
description="Maximum number of training samples to load. None for all.",
|
||||
|
|
@ -185,13 +195,14 @@ class SWERLEnv(BaseEnv):
|
|||
self.think_tags_well_formed_buffer = []
|
||||
|
||||
# For saving rollouts to JSONL
|
||||
self.run_uuid = str(uuid.uuid4()) # Generate a UUID for this run
|
||||
self.rollouts_to_save_buffer: List[
|
||||
Dict[str, Union[str, List[Dict[str, Union[List[Dict[str, str]], float]]]]]
|
||||
] = []
|
||||
self.processed_item_count = 0
|
||||
# Creates .../atropos/environments/datadumps/ relative to the project structure
|
||||
# Creates .../atropos/environments/swe_rl/data_dumps/ relative to the project structure
|
||||
self.datadumps_dir = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), "datadumps"
|
||||
os.path.dirname(os.path.abspath(__file__)), "swe_rl", "data_dumps"
|
||||
)
|
||||
self.save_file_batch_num = 0
|
||||
|
||||
|
|
@ -216,15 +227,18 @@ class SWERLEnv(BaseEnv):
|
|||
ensure_scores_are_not_same=True,
|
||||
eval_n_samples=1,
|
||||
# HF Dataset Configs
|
||||
dataset_name="princeton-nlp/SWE-bench_Lite_oracle",
|
||||
dataset_name="NousResearch/SWE-smith-oracle",
|
||||
dataset_config_name=None,
|
||||
dataset_split_train="test",
|
||||
dataset_split_train="train",
|
||||
dataset_split_eval="test",
|
||||
dataset_issue_column="problem_statement",
|
||||
dataset_code_context_column="text",
|
||||
dataset_oracle_patch_column="patch",
|
||||
max_train_samples=1000,
|
||||
max_test_samples=200,
|
||||
# Initialize new eval dataset fields
|
||||
dataset_name_eval=None, # Or a default eval dataset name
|
||||
dataset_config_name_eval=None, # Or a default eval config name
|
||||
)
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
|
|
@ -238,11 +252,14 @@ class SWERLEnv(BaseEnv):
|
|||
return env_config, server_configs
|
||||
|
||||
async def setup(self):
|
||||
self.logger.info(
|
||||
f"Setting up SWE-RL environment with dataset: {self.config.dataset_name}"
|
||||
)
|
||||
self.logger.info("Setting up SWE-RL environment...")
|
||||
self.train_dataset = []
|
||||
self.test_dataset = []
|
||||
train_data_raw = None
|
||||
eval_data_raw = None
|
||||
eval_sampled_from_train = (
|
||||
False # Flag to track if eval set was sampled from train
|
||||
)
|
||||
|
||||
try:
|
||||
# Load training data
|
||||
|
|
@ -253,96 +270,174 @@ class SWERLEnv(BaseEnv):
|
|||
path=self.config.dataset_name,
|
||||
name=self.config.dataset_config_name,
|
||||
split=self.config.dataset_split_train,
|
||||
streaming=False,
|
||||
streaming=False, # Ensure we get a Dataset object for potential sampling
|
||||
)
|
||||
if self.config.max_train_samples is not None:
|
||||
self.logger.info(
|
||||
f"Taking max {self.config.max_train_samples} training samples."
|
||||
f"Applying max_train_samples: {self.config.max_train_samples} to training data."
|
||||
)
|
||||
# For datasets.Dataset, select directly. For IterableDataset, .take()
|
||||
if hasattr(train_data_raw, "select"): # It's a Dataset
|
||||
indices = range(
|
||||
min(len(train_data_raw), self.config.max_train_samples)
|
||||
)
|
||||
train_data_raw = train_data_raw.select(indices)
|
||||
else: # It's an IterableDataset
|
||||
else: # It's an IterableDataset (less likely with streaming=False)
|
||||
train_data_raw = train_data_raw.take(self.config.max_train_samples)
|
||||
|
||||
# Load evaluation data
|
||||
self.logger.info(
|
||||
f"Loading evaluation data from {self.config.dataset_name}, split: {self.config.dataset_split_eval}"
|
||||
f"Training data size after max_train_samples: {len(train_data_raw) if hasattr(train_data_raw, '__len__') else 'unknown (iterable)'}" # noqa: E501
|
||||
)
|
||||
eval_data_raw = load_dataset(
|
||||
path=self.config.dataset_name,
|
||||
name=self.config.dataset_config_name,
|
||||
split=self.config.dataset_split_eval,
|
||||
streaming=False,
|
||||
)
|
||||
if self.config.max_test_samples is not None:
|
||||
|
||||
# --- Evaluation Data Handling ---
|
||||
if self.config.dataset_name_eval is not None:
|
||||
# Load dedicated evaluation dataset
|
||||
eval_dataset_name = self.config.dataset_name_eval
|
||||
eval_dataset_config_name = self.config.dataset_config_name_eval
|
||||
self.logger.info(
|
||||
f"Taking max {self.config.max_test_samples} test samples."
|
||||
f"Loading dedicated evaluation data from {eval_dataset_name}, split: {self.config.dataset_split_eval}" # noqa: E501
|
||||
)
|
||||
if hasattr(eval_data_raw, "select"): # It's a Dataset
|
||||
indices = range(
|
||||
min(len(eval_data_raw), self.config.max_test_samples)
|
||||
)
|
||||
eval_data_raw = eval_data_raw.select(indices)
|
||||
else: # It's an IterableDataset
|
||||
eval_data_raw = eval_data_raw.take(self.config.max_test_samples)
|
||||
eval_data_raw = load_dataset(
|
||||
path=eval_dataset_name,
|
||||
name=eval_dataset_config_name,
|
||||
split=self.config.dataset_split_eval,
|
||||
streaming=False,
|
||||
)
|
||||
else:
|
||||
# No dedicated eval dataset, try to sample from training data
|
||||
self.logger.info(
|
||||
"No specific evaluation dataset provided. Attempting to sample up to 500 instances from training data for evaluation." # noqa: E501
|
||||
)
|
||||
num_eval_samples_to_take = 500
|
||||
|
||||
self.logger.info("Mapping dataset columns...")
|
||||
for item_idx, raw_item in enumerate(train_data_raw):
|
||||
try:
|
||||
self.train_dataset.append(
|
||||
{
|
||||
"item_id": f"train_{self.config.dataset_name}_{item_idx}",
|
||||
"issue": raw_item[self.config.dataset_issue_column],
|
||||
"code_context": raw_item[
|
||||
self.config.dataset_code_context_column
|
||||
],
|
||||
"oracle_patch": raw_item[
|
||||
self.config.dataset_oracle_patch_column
|
||||
],
|
||||
}
|
||||
)
|
||||
except KeyError as e:
|
||||
self.logger.error(
|
||||
f"Column mapping error for training item {item_idx}: {e}. Skipping item. Raw: {str(raw_item)[:500]}" # noqa: E501
|
||||
)
|
||||
continue
|
||||
if (
|
||||
train_data_raw
|
||||
and hasattr(train_data_raw, "select")
|
||||
and len(train_data_raw) >= num_eval_samples_to_take
|
||||
):
|
||||
all_train_indices = list(range(len(train_data_raw)))
|
||||
random.shuffle(all_train_indices)
|
||||
|
||||
for item_idx, raw_item in enumerate(eval_data_raw):
|
||||
try:
|
||||
self.test_dataset.append(
|
||||
{
|
||||
"item_id": f"test_{self.config.dataset_name}_{item_idx}",
|
||||
"issue": raw_item[self.config.dataset_issue_column],
|
||||
"code_context": raw_item[
|
||||
self.config.dataset_code_context_column
|
||||
],
|
||||
"oracle_patch": raw_item[
|
||||
self.config.dataset_oracle_patch_column
|
||||
],
|
||||
}
|
||||
eval_indices = all_train_indices[:num_eval_samples_to_take]
|
||||
remaining_train_indices = all_train_indices[
|
||||
num_eval_samples_to_take:
|
||||
]
|
||||
|
||||
eval_data_raw = train_data_raw.select(eval_indices)
|
||||
original_train_len = len(train_data_raw)
|
||||
train_data_raw = train_data_raw.select(
|
||||
remaining_train_indices
|
||||
) # Update train_data_raw
|
||||
|
||||
self.logger.info(
|
||||
f"Successfully sampled {len(eval_data_raw)} instances from training data (original size {original_train_len}) for evaluation. " # noqa: E501
|
||||
f"Remaining training data size: {len(train_data_raw)}." # noqa: E501
|
||||
)
|
||||
except KeyError as e:
|
||||
self.logger.error(
|
||||
f"Column mapping error for eval item {item_idx}: {e}. Skipping item. Raw: {str(raw_item)[:500]}"
|
||||
eval_sampled_from_train = True
|
||||
elif train_data_raw and hasattr(
|
||||
train_data_raw, "select"
|
||||
): # Train data exists but not enough samples
|
||||
self.logger.warning(
|
||||
f"Training data has {len(train_data_raw)} samples, which is less than the desired {num_eval_samples_to_take} for evaluation sampling. " # noqa: E501
|
||||
f"No evaluation set will be derived from training data by sampling." # noqa: E501
|
||||
)
|
||||
continue
|
||||
else: # train_data_raw is None, not a Dataset for sampling, or other issue
|
||||
self.logger.warning(
|
||||
"Could not sample from training data for evaluation (e.g., training data not loaded, not a suitable type for sampling, or empty). " # noqa: E501
|
||||
"Evaluation set will be empty." # noqa: E501
|
||||
)
|
||||
|
||||
# Apply max_test_samples to the resulting eval_data_raw, regardless of its origin
|
||||
if eval_data_raw and self.config.max_test_samples is not None:
|
||||
if len(eval_data_raw) > self.config.max_test_samples:
|
||||
self.logger.info(
|
||||
f"Applying max_test_samples: Capping evaluation set from {len(eval_data_raw)} to {self.config.max_test_samples}." # noqa: E501
|
||||
)
|
||||
if hasattr(eval_data_raw, "select"):
|
||||
indices = range(
|
||||
min(len(eval_data_raw), self.config.max_test_samples)
|
||||
)
|
||||
eval_data_raw = eval_data_raw.select(indices)
|
||||
elif hasattr(
|
||||
eval_data_raw, "take"
|
||||
): # Fallback for IterableDataset (less likely here)
|
||||
eval_data_raw = eval_data_raw.take(self.config.max_test_samples)
|
||||
else:
|
||||
self.logger.info(
|
||||
f"Evaluation set has {len(eval_data_raw)} samples. max_test_samples ({self.config.max_test_samples}) is >= this, so no change to eval set size based on this cap." # noqa: E501
|
||||
)
|
||||
elif not eval_data_raw and self.config.max_test_samples is not None:
|
||||
self.logger.info(
|
||||
"max_test_samples is set, but there is no evaluation data to apply it to at this stage." # noqa: E501
|
||||
)
|
||||
|
||||
# --- Populate self.train_dataset and self.test_dataset ---
|
||||
self.logger.info("Mapping dataset columns for final train/test sets...")
|
||||
if (
|
||||
train_data_raw
|
||||
): # train_data_raw could be empty if all were taken for eval
|
||||
for item_idx, raw_item in enumerate(train_data_raw):
|
||||
try:
|
||||
self.train_dataset.append(
|
||||
{
|
||||
"item_id": f"train_{self.config.dataset_name.replace('/', '_')}_{item_idx}",
|
||||
"issue": raw_item[self.config.dataset_issue_column],
|
||||
"code_context": raw_item[
|
||||
self.config.dataset_code_context_column
|
||||
],
|
||||
"oracle_patch": raw_item[
|
||||
self.config.dataset_oracle_patch_column
|
||||
],
|
||||
}
|
||||
)
|
||||
except KeyError as e:
|
||||
self.logger.error(
|
||||
f"Column mapping error for training item {item_idx}: {e}. Skipping item. Raw: {str(raw_item)[:500]}" # noqa: E501
|
||||
)
|
||||
continue
|
||||
|
||||
eval_ds_name_for_item_id = "unknown_eval_source"
|
||||
if self.config.dataset_name_eval:
|
||||
eval_ds_name_for_item_id = self.config.dataset_name_eval.replace(
|
||||
"/", "_"
|
||||
)
|
||||
elif eval_sampled_from_train:
|
||||
eval_ds_name_for_item_id = (
|
||||
f"{self.config.dataset_name.replace('/', '_')}_sampled_as_eval"
|
||||
)
|
||||
|
||||
if eval_data_raw:
|
||||
for item_idx, raw_item in enumerate(eval_data_raw):
|
||||
try:
|
||||
self.test_dataset.append(
|
||||
{
|
||||
"item_id": f"test_{eval_ds_name_for_item_id}_{item_idx}",
|
||||
"issue": raw_item[self.config.dataset_issue_column],
|
||||
"code_context": raw_item[
|
||||
self.config.dataset_code_context_column
|
||||
],
|
||||
"oracle_patch": raw_item[
|
||||
self.config.dataset_oracle_patch_column
|
||||
],
|
||||
}
|
||||
)
|
||||
except KeyError as e:
|
||||
self.logger.error(
|
||||
f"Column mapping error for eval item {item_idx} from '{eval_ds_name_for_item_id}': {e}. Skipping item. Raw: {str(raw_item)[:500]}" # noqa: E501
|
||||
)
|
||||
continue
|
||||
|
||||
random.shuffle(self.train_dataset)
|
||||
random.shuffle(self.test_dataset)
|
||||
|
||||
self.logger.info(
|
||||
f"Loaded dataset: {len(self.train_dataset)} training examples, {len(self.test_dataset)} test examples."
|
||||
f"Loaded dataset: {len(self.train_dataset)} training examples, {len(self.test_dataset)} test examples." # noqa: E501
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
f"Error loading or processing dataset '{self.config.dataset_name}': {e}"
|
||||
f"Error loading or processing dataset: {e}" # Simplified main error message
|
||||
)
|
||||
self.train_dataset = []
|
||||
self.train_dataset = [] # Ensure they are reset on error
|
||||
self.test_dataset = []
|
||||
# Consider re-raising if dataset is critical: raise e
|
||||
|
||||
|
|
@ -351,9 +446,26 @@ class SWERLEnv(BaseEnv):
|
|||
self.logger.warning(
|
||||
"Training dataset is empty after setup. Check dataset configuration and availability."
|
||||
)
|
||||
|
||||
if not self.test_dataset:
|
||||
eval_source_description_for_log = "not available"
|
||||
if self.config.dataset_name_eval:
|
||||
eval_source_description_for_log = (
|
||||
f"dedicated set ({self.config.dataset_name_eval})"
|
||||
)
|
||||
elif eval_sampled_from_train:
|
||||
eval_source_description_for_log = (
|
||||
f"sampled from training set ({self.config.dataset_name})"
|
||||
)
|
||||
elif (
|
||||
self.config.dataset_name_eval is None
|
||||
): # No dedicated, and sampling was not attempted or failed
|
||||
eval_source_description_for_log = (
|
||||
"sampling from train was not possible or yielded no data"
|
||||
)
|
||||
|
||||
self.logger.warning(
|
||||
"Test dataset is empty after setup. Check dataset configuration and availability."
|
||||
f"Test dataset (source: {eval_source_description_for_log}) is empty after setup. Check dataset configuration and availability." # noqa: E501
|
||||
)
|
||||
|
||||
async def get_next_item(self) -> Optional[Dict[str, str]]:
|
||||
|
|
@ -780,7 +892,8 @@ class SWERLEnv(BaseEnv):
|
|||
return # Don't proceed if directory creation fails
|
||||
|
||||
file_path = os.path.join(
|
||||
self.datadumps_dir, f"rollouts_batch_{self.save_file_batch_num:04d}.jsonl"
|
||||
self.datadumps_dir,
|
||||
f"swe_rl_environment_rollouts_{self.run_uuid}_{self.save_file_batch_num:04d}.jsonl",
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue