diff --git a/GALLERY.md b/GALLERY.md index 6180ca5c..48fc0caf 100644 --- a/GALLERY.md +++ b/GALLERY.md @@ -20,6 +20,7 @@ This gallery shows examples from all available datasets using their default conf - [gcd](#gcd) - [intermediate_integration](#intermediate_integration) - [largest_island](#largest_island) +- [course_schedule](#course_schedule) - [lcm](#lcm) - [leg_counting](#leg_counting) - [letter_counting](#letter_counting) @@ -971,6 +972,64 @@ Metadata: {'grid': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, ```` +### course_schedule + +Generates course schedule exercises, and checks if the given course schedule is valid + +Default configuration: +```python +num_courses = 5 +max_num_prerequisites = 2 +p_solvable = 0.5 +min_cycle_length = 3 +max_cycle_length = 5 +``` + +Example tasks: +```` +Example 1: +Question: There are a total of 5 courses you have to take, labeled from 0 to 4. + +You are given the following list of prerequisites, where prerequisites[i] = (a_i, b_i) indicates that you must first take course b_i first if you want to take course a_i: +[(2, 1), (4, 2), (4, 3), (2, 3)] + +Return True if you can finish all courses considering the prerequisites, or False otherwise. + +Answer: True + +Metadata: {'courses': [3, 1, 2, 4, 0], 'prerequisites': [(2, 1), (4, 2), (4, 3), (2, 3)], 'solution': True, 'solvable': True} + +-------------------------------------------------- + +Example 2: +Question: There are a total of 5 courses you have to take, labeled from 0 to 4. + +You are given the following list of prerequisites, where prerequisites[i] = (a_i, b_i) indicates that you must first take course b_i first if you want to take course a_i: +[(3, 0), (2, 4), (2, 3), (4, 1), (3, 1), (0, 1), (0, 2), (1, 3)] + +Return True if you can finish all courses considering the prerequisites, or False otherwise. 
+ +Answer: False + +Metadata: {'courses': [1, 4, 3, 2, 0], 'prerequisites': [(3, 0), (2, 4), (2, 3), (4, 1), (3, 1), (0, 1), (0, 2), (1, 3)], 'solution': False, 'solvable': False} + +-------------------------------------------------- + +Example 3: +Question: There are a total of 5 courses you have to take, labeled from 0 to 4. + +You are given the following list of prerequisites, where prerequisites[i] = (a_i, b_i) indicates that you must first take course b_i first if you want to take course a_i: +[] + +Return True if you can finish all courses considering the prerequisites, or False otherwise. + +Answer: True + +Metadata: {'courses': [2, 1, 4, 0, 3], 'prerequisites': [], 'solution': True, 'solvable': True} + +-------------------------------------------------- +```` + ### lcm Generates Least Common Multiple (LCM) tasks diff --git a/README.md b/README.md index 0fedd335..2107b030 100644 --- a/README.md +++ b/README.md @@ -122,6 +122,7 @@ See the [Dataset Gallery](GALLERY.md) for a complete list of available datasets - `FamilyRelationshipsDataset`: Generate family relationship reasoning tasks with family trees - `QuantumLockDataset`: Generates puzzles which involve stateful arithmetic and a correct sequence of operations - `LargestIslandDataset`: Generate a grid with islands and find the largest one +- `CourseScheduleDataset`: Generate a course schedule with prerequisites and find whether you can complete all courses ### Game Tasks diff --git a/reasoning_gym/graphs/__init__.py b/reasoning_gym/graphs/__init__.py index ee722e38..4d1ccd8f 100644 --- a/reasoning_gym/graphs/__init__.py +++ b/reasoning_gym/graphs/__init__.py @@ -1,3 +1,4 @@ +from .course_schedule import CourseScheduleConfig, CourseScheduleDataset from .family_relationships import FamilyRelationshipsConfig, FamilyRelationshipsDataset from .largest_island import LargestIslandDataset from .quantum_lock import QuantumLockConfig, QuantumLockDataset @@ -8,4 +9,6 @@ __all__ = [ "QuantumLockConfig", 
"QuantumLockDataset", "LargestIslandDataset", + "CourseScheduleDataset", + "CourseScheduleConfig", ] diff --git a/reasoning_gym/graphs/course_schedule.py b/reasoning_gym/graphs/course_schedule.py new file mode 100644 index 00000000..22f597bf --- /dev/null +++ b/reasoning_gym/graphs/course_schedule.py @@ -0,0 +1,145 @@ +""" +Determine if you can complete all courses given their prerequisite relationships. + +A popular topological sort Leetcode problem: +https://leetcode.com/problems/course-schedule/description/ +""" + +from collections import defaultdict +from dataclasses import dataclass +from random import Random +from typing import List, Optional + +from ..factory import ProceduralDataset, register_dataset + +MAX_NUM_COURSES = 1_000 + +QUESTION_TEMPLATE = """There are a total of {num_courses} courses you have to take, labeled from 0 to {last_index}. + +You are given the following list of prerequisites, where prerequisites[i] = (a_i, b_i) indicates that you must first take course b_i first if you want to take course a_i: +{prerequisites} + +Return True if you can finish all courses considering the prerequisites, or False otherwise. 
@dataclass
class CourseScheduleConfig:
    """Configuration for Course Schedule dataset generation.

    Call ``validate()`` to check that the values below are mutually consistent.
    """

    num_courses: int = 5  # Total number of courses (ranging from 0 to num_courses - 1)
    max_num_prerequisites: int = 2  # Maximum number of prerequisites (per course)
    p_solvable: float = 0.5  # Probability that the course schedule is solvable
    min_cycle_length: int = 3  # Minimum length of a cycle in the prerequisites (if unsolvable)
    max_cycle_length: int = 5  # Maximum length of a cycle in the prerequisites (if unsolvable)

    size: int = 500  # Virtual dataset size
    seed: Optional[int] = None

    def validate(self):
        """Validate configuration parameters.

        Raises:
            AssertionError: if any parameter is outside its accepted range.
        """
        assert 1 <= self.num_courses <= MAX_NUM_COURSES, f"num_courses must be between 1 and {MAX_NUM_COURSES}"
        assert (
            1 <= self.max_num_prerequisites <= self.num_courses
        ), "max_num_prerequisites must be between 1 and num_courses"
        assert 0 <= self.p_solvable <= 1, "p_solvable must be between 0 and 1"
        assert (
            3 <= self.min_cycle_length <= self.max_cycle_length
        ), "min_cycle_length must be between 3 and max_cycle_length"
        # An unsolvable item samples `min_cycle_length` distinct courses to build its
        # cycle, so that many courses must exist. Only when p_solvable is exactly 1 is
        # no cycle ever requested (the generator tests `rng.random() < p_solvable`).
        assert (
            self.p_solvable >= 1 or self.min_cycle_length <= self.num_courses
        ), "min_cycle_length must not exceed num_courses unless p_solvable is 1"


class CourseScheduleDataset(ProceduralDataset):
    """Generates Course Schedule exercises with configurable difficulty"""

    def __init__(self, config: CourseScheduleConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def __len__(self) -> int:
        return self.config.size

    def __iter__(self):
        self._current_idx = 0
        return self

    def __next__(self):
        if self._current_idx >= self.config.size:
            raise StopIteration
        item = self[self._current_idx]
        self._current_idx += 1
        return item

    def _can_finish(self, num_courses: int, prerequisites: List[List[int]]) -> bool:
        """Return True if no prerequisite cycle is reachable from courses 0..num_courses-1.

        Args:
            num_courses: Courses ``0 .. num_courses - 1`` are used as search roots.
            prerequisites: ``(course, prerequisite)`` pairs.

        Returns:
            False as soon as a cycle is found, True otherwise.

        The depth-first search keeps an explicit stack instead of recursing, so a
        long prerequisite chain cannot exceed the interpreter's recursion limit.
        """
        adj = defaultdict(list)
        for course, prereq in prerequisites:
            adj[course].append(prereq)

        in_progress, done = 1, 2
        state = {}  # course -> in_progress (on the current DFS path) or done

        for root in range(num_courses):
            if root in state:
                # The stack is empty between roots, so a known course is fully explored.
                continue
            state[root] = in_progress
            stack = [(root, iter(adj.get(root, ())))]
            while stack:
                node, pending = stack[-1]
                for nxt in pending:
                    status = state.get(nxt)
                    if status == in_progress:
                        # Back edge to a course on the current path: a cycle.
                        return False
                    if status is None:
                        state[nxt] = in_progress
                        stack.append((nxt, iter(adj.get(nxt, ()))))
                        break
                else:
                    # All prerequisites of `node` are explored without finding a cycle.
                    state[node] = done
                    stack.pop()

        return True

    def _create_prerequisites(self, rng: Random, courses: List[int], solvable: bool) -> List[tuple]:
        """Create a shuffled, duplicate-free list of (course, prerequisite) tuples.

        Args:
            rng: Random source; the sequence of calls made on it determines the output.
            courses: Course ids in the order used to build the acyclic part.
            solvable: When False, one extra cycle is added on top of the acyclic part.

        Raises:
            ValueError: if ``solvable`` is False and there are fewer courses than
                ``min_cycle_length``.
        """
        prerequisites = []
        # Generate a valid course schedule: a course only depends on courses placed
        # before it in `courses`, so this part cannot contain a cycle.
        for idx in range(len(courses) - 1, 0, -1):
            current_course = courses[idx]
            available_prereqs = courses[:idx]  # Only earlier courses can be prerequisites
            num_prerequisites = rng.randint(0, min(len(available_prereqs), self.config.max_num_prerequisites))
            if num_prerequisites > 0:
                chosen_prereqs = rng.sample(available_prereqs, num_prerequisites)
                prerequisites.extend((current_course, p) for p in chosen_prereqs)

        if not solvable:
            # If solution should be unsolvable, create a cycle
            longest_cycle = min(self.config.max_cycle_length, len(courses))
            if longest_cycle < self.config.min_cycle_length:
                raise ValueError(
                    f"cannot build a cycle of at least {self.config.min_cycle_length} courses "
                    f"from {len(courses)} course(s)"
                )
            cycle_length = rng.randint(self.config.min_cycle_length, longest_cycle)
            cycle_courses = rng.sample(courses, cycle_length)
            for i in range(cycle_length):
                prerequisites.append((cycle_courses[i], cycle_courses[(i + 1) % cycle_length]))

        # remove potential duplicates (a cycle edge may repeat an edge chosen above)
        prerequisites = list(set(prerequisites))
        rng.shuffle(prerequisites)
        return prerequisites

    def __getitem__(self, idx: int) -> dict:
        """Generate a single Course Schedule question.

        The item is derived from ``Random(self.seed + idx)``. The returned dict has
        the keys ``question``, ``answer`` and ``metadata``.
        """
        rng = Random(self.seed + idx)

        courses = list(range(self.config.num_courses))
        rng.shuffle(courses)

        solvable = rng.random() < self.config.p_solvable

        prerequisites = self._create_prerequisites(rng, courses, solvable)
        # The answer is recomputed from the generated edges, not copied from `solvable`.
        answer = self._can_finish(self.config.num_courses, prerequisites)

        return {
            "question": QUESTION_TEMPLATE.format(
                num_courses=self.config.num_courses,
                last_index=self.config.num_courses - 1,
                prerequisites=str(prerequisites),
            ),
            "answer": str(answer),
            "metadata": {"courses": courses, "prerequisites": prerequisites, "solution": answer, "solvable": solvable},
        }


register_dataset("course_schedule", CourseScheduleDataset, CourseScheduleConfig)


# ---- tests/test_course_schedule.py (new file added by the same patch) ----
"""Tests for Course Schedule puzzle generation"""

import pytest

from reasoning_gym.graphs.course_schedule import CourseScheduleConfig, CourseScheduleDataset


def test_course_schedule_config_validation():
    """Test that invalid configs raise appropriate errors"""
    invalid_kwargs = [
        {"num_courses": -1},  # Negative not allowed
        {"num_courses": 0},  # Zero not allowed
        {"max_num_prerequisites": -1},  # Negative not allowed
        {"max_num_prerequisites": 0},  # Zero not allowed
        {"num_courses": 3, "max_num_prerequisites": 5},  # max_num_prerequisites > num_courses
        {"p_solvable": -0.1},  # < 0 not allowed
        {"p_solvable": 1.1},  # > 1 not allowed
        {"min_cycle_length": 2},  # < 3 not allowed
        {"min_cycle_length": 3, "max_cycle_length": 2},  # min_cycle_length > max_cycle_length
        {"num_courses": 2, "min_cycle_length": 3},  # a cycle cannot fit into the courses
    ]
    for kwargs in invalid_kwargs:
        config = CourseScheduleConfig(**kwargs)
        with pytest.raises(AssertionError):
            config.validate()

    # No cycle is ever requested when every item is solvable, so this must pass
    CourseScheduleConfig(num_courses=2, p_solvable=1.0).validate()


def test_course_schedule_can_finish_deep_chain():
    """Test that a chain deeper than the recursion limit is handled"""
    import sys

    dataset = CourseScheduleDataset(CourseScheduleConfig(seed=42))
    num_courses = sys.getrecursionlimit() + 100
    chain = [[i, i + 1] for i in range(num_courses - 1)]

    assert dataset._can_finish(num_courses=num_courses, prerequisites=chain) is True
    # Closing the chain turns it into one long cycle
    assert dataset._can_finish(num_courses=num_courses, prerequisites=chain + [[num_courses - 1, 0]]) is False
def test_course_schedule_dataset_deterministic():
    """Test that dataset generates same items with same seed"""
    config = CourseScheduleConfig(seed=42, size=10)
    dataset1 = CourseScheduleDataset(config)
    dataset2 = CourseScheduleDataset(config)

    for i in range(len(dataset1)):
        assert dataset1[i] == dataset2[i]


def test_course_schedule_dataset_items():
    """Test basic properties of generated items"""
    config = CourseScheduleConfig(num_courses=15, size=10, seed=42)
    dataset = CourseScheduleDataset(config)

    for i in range(len(dataset)):
        item = dataset[i]
        # Check item structure
        assert isinstance(item, dict)
        assert "question" in item
        assert "answer" in item
        assert "metadata" in item

        # Check metadata
        metadata = item["metadata"]
        assert "courses" in metadata
        assert "prerequisites" in metadata
        assert "solution" in metadata
        assert "solvable" in metadata

        courses = metadata["courses"]
        prerequisites = metadata["prerequisites"]
        solvable = metadata["solvable"]  # Solution dictated by p_solvable
        solution = metadata["solution"]  # Solution obtained from topological sort

        # Verify metadata: `courses` must be a permutation of 0 .. num_courses - 1
        assert sorted(courses) == list(range(config.num_courses))
        assert len(prerequisites) <= config.max_num_prerequisites * config.num_courses
        assert all(len(prereq) == 2 for prereq in prerequisites)
        for course, prereq in prerequisites:
            assert 0 <= course < config.num_courses
            assert 0 <= prereq < config.num_courses
            assert course != prereq
        assert solution == solvable
        # The textual answer is the string form of the computed solution
        assert item["answer"] == str(solution)


def test_course_schedule_dataset_iteration():
    """Test that iteration respects dataset size"""
    config = CourseScheduleConfig(size=5, seed=42)
    dataset = CourseScheduleDataset(config)

    items = list(dataset)
    assert len(items) == config.size

    # Test multiple iterations yield same items
    assert items == list(dataset)


def test_course_schedule_answer():
    """Test the _can_finish method"""
    config = CourseScheduleConfig(seed=42)
    dataset = CourseScheduleDataset(config)

    prerequisites = [[0, 1]]
    assert dataset._can_finish(num_courses=2, prerequisites=prerequisites) is True

    # Direct cycle
    prerequisites = [[0, 1], [1, 0]]
    assert dataset._can_finish(num_courses=2, prerequisites=prerequisites) is False

    # Empty prerequisites
    prerequisites = []
    assert dataset._can_finish(num_courses=2, prerequisites=prerequisites) is True

    # Indirect cycle of length 3
    prerequisites = [[0, 1], [1, 2], [2, 0]]
    assert dataset._can_finish(num_courses=3, prerequisites=prerequisites) is False

    # A course that requires itself
    prerequisites = [[0, 0]]
    assert dataset._can_finish(num_courses=1, prerequisites=prerequisites) is False

    # Shared prerequisite (diamond): course 0 is reached twice but there is no cycle
    prerequisites = [[3, 1], [3, 2], [1, 0], [2, 0]]
    assert dataset._can_finish(num_courses=4, prerequisites=prerequisites) is True