diff --git a/reasoning_gym/data/acre_objects.json b/reasoning_gym/data/acre_objects.json new file mode 100644 index 00000000..ffb2a8c1 --- /dev/null +++ b/reasoning_gym/data/acre_objects.json @@ -0,0 +1,290 @@ +{ + "0": { + "material": "rubber", + "size": "medium", + "shape": "cube", + "color": "gray" + }, + "1": { + "material": "metal", + "size": "medium", + "shape": "cube", + "color": "gray" + }, + "2": { + "material": "rubber", + "size": "medium", + "shape": "cube", + "color": "red" + }, + "3": { + "material": "metal", + "size": "medium", + "shape": "cube", + "color": "red" + }, + "4": { + "material": "rubber", + "size": "medium", + "shape": "cube", + "color": "blue" + }, + "5": { + "material": "metal", + "size": "medium", + "shape": "cube", + "color": "blue" + }, + "6": { + "material": "rubber", + "size": "medium", + "shape": "cube", + "color": "green" + }, + "7": { + "material": "metal", + "size": "medium", + "shape": "cube", + "color": "green" + }, + "8": { + "material": "rubber", + "size": "medium", + "shape": "cube", + "color": "brown" + }, + "9": { + "material": "metal", + "size": "medium", + "shape": "cube", + "color": "brown" + }, + "10": { + "material": "rubber", + "size": "medium", + "shape": "cube", + "color": "purple" + }, + "11": { + "material": "metal", + "size": "medium", + "shape": "cube", + "color": "purple" + }, + "12": { + "material": "rubber", + "size": "medium", + "shape": "cube", + "color": "cyan" + }, + "13": { + "material": "metal", + "size": "medium", + "shape": "cube", + "color": "cyan" + }, + "14": { + "material": "rubber", + "size": "medium", + "shape": "cube", + "color": "yellow" + }, + "15": { + "material": "metal", + "size": "medium", + "shape": "cube", + "color": "yellow" + }, + "16": { + "material": "rubber", + "size": "medium", + "shape": "cylinder", + "color": "gray" + }, + "17": { + "material": "metal", + "size": "medium", + "shape": "cylinder", + "color": "gray" + }, + "18": { + "material": "rubber", + "size": "medium", + 
"shape": "cylinder", + "color": "red" + }, + "19": { + "material": "metal", + "size": "medium", + "shape": "cylinder", + "color": "red" + }, + "20": { + "material": "rubber", + "size": "medium", + "shape": "cylinder", + "color": "blue" + }, + "21": { + "material": "metal", + "size": "medium", + "shape": "cylinder", + "color": "blue" + }, + "22": { + "material": "rubber", + "size": "medium", + "shape": "cylinder", + "color": "green" + }, + "23": { + "material": "metal", + "size": "medium", + "shape": "cylinder", + "color": "green" + }, + "24": { + "material": "rubber", + "size": "medium", + "shape": "cylinder", + "color": "brown" + }, + "25": { + "material": "metal", + "size": "medium", + "shape": "cylinder", + "color": "brown" + }, + "26": { + "material": "rubber", + "size": "medium", + "shape": "cylinder", + "color": "purple" + }, + "27": { + "material": "metal", + "size": "medium", + "shape": "cylinder", + "color": "purple" + }, + "28": { + "material": "rubber", + "size": "medium", + "shape": "cylinder", + "color": "cyan" + }, + "29": { + "material": "metal", + "size": "medium", + "shape": "cylinder", + "color": "cyan" + }, + "30": { + "material": "rubber", + "size": "medium", + "shape": "cylinder", + "color": "yellow" + }, + "31": { + "material": "metal", + "size": "medium", + "shape": "cylinder", + "color": "yellow" + }, + "32": { + "material": "rubber", + "size": "medium", + "shape": "sphere", + "color": "gray" + }, + "33": { + "material": "metal", + "size": "medium", + "shape": "sphere", + "color": "gray" + }, + "34": { + "material": "rubber", + "size": "medium", + "shape": "sphere", + "color": "red" + }, + "35": { + "material": "metal", + "size": "medium", + "shape": "sphere", + "color": "red" + }, + "36": { + "material": "rubber", + "size": "medium", + "shape": "sphere", + "color": "blue" + }, + "37": { + "material": "metal", + "size": "medium", + "shape": "sphere", + "color": "blue" + }, + "38": { + "material": "rubber", + "size": "medium", + "shape": 
"sphere", + "color": "green" + }, + "39": { + "material": "metal", + "size": "medium", + "shape": "sphere", + "color": "green" + }, + "40": { + "material": "rubber", + "size": "medium", + "shape": "sphere", + "color": "brown" + }, + "41": { + "material": "metal", + "size": "medium", + "shape": "sphere", + "color": "brown" + }, + "42": { + "material": "rubber", + "size": "medium", + "shape": "sphere", + "color": "purple" + }, + "43": { + "material": "metal", + "size": "medium", + "shape": "sphere", + "color": "purple" + }, + "44": { + "material": "rubber", + "size": "medium", + "shape": "sphere", + "color": "cyan" + }, + "45": { + "material": "metal", + "size": "medium", + "shape": "sphere", + "color": "cyan" + }, + "46": { + "material": "rubber", + "size": "medium", + "shape": "sphere", + "color": "yellow" + }, + "47": { + "material": "metal", + "size": "medium", + "shape": "sphere", + "color": "yellow" + } +} diff --git a/reasoning_gym/induction/__init__.py b/reasoning_gym/induction/__init__.py index f6dc1504..814d846c 100644 --- a/reasoning_gym/induction/__init__.py +++ b/reasoning_gym/induction/__init__.py @@ -2,9 +2,7 @@ Arithmetic tasks for training reasoning capabilities: """ +from .acre.acre import ACREDataset, ACREDatasetConfig from .list_functions import ListFunctionsDataset, ListFunctionsDatasetConfig -__all__ = [ - "ListFunctionsDataset", - "ListFunctionsDatasetConfig", -] +__all__ = ["ListFunctionsDataset", "ListFunctionsDatasetConfig", "ACREDataset", "ACREDatasetConfig"] diff --git a/reasoning_gym/induction/acre/acre.py b/reasoning_gym/induction/acre/acre.py new file mode 100644 index 00000000..344c9dc8 --- /dev/null +++ b/reasoning_gym/induction/acre/acre.py @@ -0,0 +1,93 @@ +"""ACRE(Abstract Causal REasoning Beyond Covariation) dataset""" + +# Culled and Adapted from https://github.com/WellyZhang/ACRE + +from dataclasses import dataclass +from random import Random +from typing import Optional + +from reasoning_gym.factory import ProceduralDataset, 
register_dataset + +from .blicket import config_control, dist_control, final_parse, serialize +from .const import ALL_CONFIG_SIZE, ATTR_CONFIG_SIZE + + +# Create blicket questions +@dataclass +class ACREDatasetConfig: + """Configuration for ACRE dataset generation""" + + train: int = 1 # The default is 1 for training, otherwise 0 for validation and testing + size: int = 500 # Split ratio = 6 : 2 : 2 -> IID : Comp : Sys + seed: Optional[int] = None + + def validate(self) -> None: + """Validate configuration parameters""" + assert self.train in (0, 1), "train must be either 0 or 1" + assert self.size > 0, "Dataset size must be positive." + + +class ACREDataset(ProceduralDataset): + + def __init__(self, config: ACREDatasetConfig): + super().__init__(config, config.seed, config.size) + self.questions = self._generate_questions() + self.prompt_template = """You are a researcher studying causal relationships using Blicket experiments. In these experiments, certain objects (called 'blickets') have the hidden property of activating a detector, causing its light to turn on. + +Each example shows the results of placing different combinations of objects on the detector. Each object is described by color, material and shape. Your task is to determine whether a new combination of objects will cause the detector to activate. + +After observing the previous examples, respond with: +- "on" if you can determine the detector light will turn on +- "off" if you can determine the detector light will stay off +- "undetermined" if there is insufficient evidence to reach a conclusion + +Do not use quotation marks in your answer. 
+ +Previous experimental results: +{examples} + +New test case: +{input} + +What is the detector light status?""" + + def _generate_questions(self): + """ + Generate the full question pool, split 6:2:2 across the IID, Comp and Sys regimes + """ + + iid_size = int(0.6 * self.config.size) + comp_size = int(0.2 * self.config.size) + sys_size = self.config.size - (iid_size + comp_size) + rng = Random(self.seed) + iid_questions = config_control(iid_size, self.config.train, ALL_CONFIG_SIZE, "IID", rng) + comp_questions = config_control(comp_size, self.config.train, ATTR_CONFIG_SIZE, "Comp", rng) + sys_questions = dist_control(sys_size, self.config.train, "Sys", rng) + + questions = [] + questions.extend(iid_questions) + questions.extend(comp_questions) + questions.extend(sys_questions) + rng.shuffle(questions) + final_questions = final_parse(serialized_questions=serialize(questions)) + return final_questions + + def __getitem__(self, idx: int) -> dict: + """Return the idx-th ACRE item as a dict with question, answer and metadata keys""" + input = self.questions[idx] + examples = input["examples"] + formatted_examples = "" + for object in examples: + input_ = ", ".join(" ".join(x) for x in object["input"]) + output = object["output"] + if len(formatted_examples) > 0: + formatted_examples += "\n" + formatted_examples += f"{input_} → {output}" + + prompt_input = ", ".join(" ".join(x) for x in input["question"]["input"]) + answer = input["question"]["output"] + question = self.prompt_template.format(examples=formatted_examples, input=prompt_input) + return {"question": question, "answer": answer, "metadata": {}} + + +register_dataset("acre", ACREDataset, ACREDatasetConfig) diff --git a/reasoning_gym/induction/acre/blicket.py b/reasoning_gym/induction/acre/blicket.py new file mode 100644 index 00000000..7d3658a1 --- /dev/null +++ b/reasoning_gym/induction/acre/blicket.py @@ -0,0 +1,687 @@ +# -*- coding: utf-8 -*- + + +import json +from typing import Any, Dict, List + +from reasoning_gym.data import get_data_file_path + +from .const import 
( + ALL_CONFIG_SIZE, + ONOFFOFF_MAX_NON, + ONOFFOFF_MAX_POTENTIAL, + ONOFFOFF_MIN_NON, + ONOFFOFF_MIN_POTENTIAL, + ONONOFF_MAX_NON, + ONONOFF_MAX_POTENTIAL, + ONONOFF_MIN_NON, + ONONOFF_MIN_POTENTIAL, +) + + +class BlicketView(object): + def __init__(self, light_state="no"): + # light state: "no" for no light, "off" for light off, "on" for light on + self.objects = [] + self.light_state = light_state + + def add_objects(self, objects): + for obj in objects: + self.objects.append(obj) + + def remove_objects(self, objects): + for obj in objects: + self.objects.remove(obj) + + def __repr__(self): + return "BlicketView(objects={}, light_state={})".format(self.objects, self.light_state) + + +class BlicketQuestion(object): + def __init__( + self, + min_potential_blickets, + max_potential_blickets, + min_non_blickets, + max_non_blickets, + config_size, + shuffle, + rng, + ): + self.min_potential_blickets = min_potential_blickets + self.max_potential_blickets = max_potential_blickets + self.min_non_blickets = min_non_blickets + self.max_non_blickets = max_non_blickets + self.config_size = config_size + self.shuffle = shuffle + self.rng = rng + + potential_blicket_num = self.rng.randint(self.min_potential_blickets, self.max_potential_blickets) + non_blicket_num = self.rng.randint(self.min_non_blickets, self.max_non_blickets) + + samples = self.rng.sample(list(range(self.config_size)), k=potential_blicket_num + non_blicket_num) + + self.blickets = [] + self.set_blickets = [] + self.potential_blickets = samples[:potential_blicket_num] + self.non_blickets = samples[potential_blicket_num:] + + self.direct = [] + self.indirect = [] + self.screen_off = [] + + def get_habituation_views(self): + blicket_obj = self.potential_blickets[0] + self.add_blicket(blicket_obj) + + non_blicket_obj = self.non_blickets[0] + + view_with_blicket = BlicketView("on") + view_with_blicket.add_objects([blicket_obj]) + + view_with_non_blicket = BlicketView("off") + 
view_with_non_blicket.add_objects([non_blicket_obj]) + + view_with_both = BlicketView("on") + view_with_both.add_objects([blicket_obj, non_blicket_obj]) + + habituation_views = [view_with_blicket, view_with_non_blicket, view_with_both] + + # bookkeeping for candidate choices + self.direct.append(blicket_obj) + self.screen_off.append(non_blicket_obj) + + return habituation_views + + def get_evidence_views(self): + raise NotImplementedError("The parent class should not be used.") + + def get_views(self): + habituation_views = self.get_habituation_views() + evidence_views = self.get_evidence_views() + if self.shuffle: + self.rng.shuffle(habituation_views) + self.rng.shuffle(evidence_views) + + return habituation_views + evidence_views + + def sanity_check(self): + # no duplicates + assert len(self.direct) == len(set(self.direct)) + assert len(self.indirect) == len(set(self.indirect)) + assert len(self.screen_off) == len(set(self.screen_off)) + # no intersection + assert len(set(self.direct).intersection(self.indirect)) == 0 + assert len(set(self.direct).intersection(self.screen_off)) == 0 + assert len(set(self.indirect).intersection(self.screen_off)) == 0 + # completeness + assert set(self.direct + self.indirect + self.screen_off) == set(self.blickets + self.non_blickets) + + def check_blickets(self, views): + blickets = set() + non_blickets = set() + potential_blickets = set() + + direct = set() + indirect = set() + screen_off = set() + + on_views = [view for view in views if view.light_state == "on"] + off_views = [view for view in views if view.light_state == "off"] + + assert len(on_views) + len(off_views) == len(views) + + for off_view in off_views: + non_blickets.update(off_view.objects) + for on_view in on_views: + if len(on_view.objects) == 1: + blickets.update(on_view.objects) + else: + diff_set = set(on_view.objects).difference(non_blickets) + if len(diff_set) == 1: + blickets.update(diff_set) + all_objects = set() + for view in views: + 
all_objects.update(view.objects) + potential_blickets.update(all_objects.difference(non_blickets).difference(blickets)) + + assert blickets == set(self.blickets) + assert non_blickets == set(self.non_blickets) + assert potential_blickets == set(self.potential_blickets) + + for on_view in on_views: + if len(on_view.objects) == 1: + direct.update(on_view.objects) + on_view_objects = set() + for on_view in on_views: + on_view_objects.update(on_view.objects) + off_view_objects = set() + for off_view in off_views: + off_view_objects.update(off_view.objects) + + direct.update(off_view_objects.difference(on_view_objects)) + screen_off.update(off_view_objects.intersection(on_view_objects)) + + for on_view in on_views: + if len(on_view.objects) > 1: + diff_set = set(on_view.objects).difference(non_blickets) + if len(diff_set) == 1 and not diff_set.issubset(direct): + indirect.update(diff_set) + + assert direct == set(self.direct) + assert indirect == set(self.indirect) + assert screen_off == set(self.screen_off) + + set_blickets = set() + for on_view in on_views: + on_diff_set = set(on_view.objects).difference(non_blickets) + if len(on_diff_set.intersection(blickets)) == 0: + potential_set_blicket = list(on_diff_set) + potential_set_blicket.sort() + set_blickets.add(tuple(potential_set_blicket)) + remove_set = set() + for elem in set_blickets: + for another_elem in set_blickets: + elem_set = set(elem) + another_elem_set = set(another_elem) + if elem_set < another_elem_set: + remove_set.add(another_elem) + set_blickets.difference_update(remove_set) + + assert set_blickets == set(self.set_blickets), "set_blickets:{}, self.set_blickets:{}".format( + set_blickets, self.set_blickets + ) + + def check_labels(self, view, label): + diff_set = set(view.objects).difference(self.non_blickets) + if len(diff_set) == 0: + assert label == 0 + else: + blicket_inter = diff_set.intersection(self.blickets) + if len(blicket_inter) > 0: + assert label == 2 + else: + if 
self.has_set_blicket(view): + assert label == 2 + else: + assert label == 1 + + def has_set_blicket(self, view): + for set_blicket in self.set_blickets: + if set(set_blicket).issubset(view.objects): + return True + return False + + def union_sample(self, union, must_have_one=False): + if must_have_one: + first_set = self.rng.sample(union, k=1) + else: + first_num = self.rng.randint(1, len(union)) + first_set = self.rng.sample(union, k=first_num) + second_set = list(set(union).difference(first_set)) + if len(second_set) > 0: + additional_num_min = 0 + else: + additional_num_min = 1 + additional_num = self.rng.randint(additional_num_min, len(first_set)) + additional_samples = self.rng.sample(first_set, k=additional_num) + second_set += additional_samples + + return first_set, second_set + + def add_noise(self, view): + # the first non blicket used in habituation and hence the skip + noise_num = self.rng.randint(0, len(self.non_blickets) - 1) + noise = self.rng.sample(self.non_blickets[1:], k=noise_num) + view.add_objects(noise) + + def add_blicket(self, obj): + self.potential_blickets.remove(obj) + self.blickets.append(obj) + + def add_set_blicket(self, set_blicket): + to_remove_set = [] + for already_set_blicket in self.set_blickets: + if set(set_blicket) < set(already_set_blicket): + to_remove_set.append(already_set_blicket) + if set(already_set_blicket) < set(set_blicket): + to_remove_set.append(set_blicket) + self.set_blickets.append(set_blicket) + for elem in to_remove_set: + self.set_blickets.remove(elem) + + def generate_cause_questions(self, train, regime="IID"): + + def fixed_sum_sample(lower, upper, fixed_sum): + values = [0] * len(lower) + assert sum(upper) >= fixed_sum + while True: + residual = fixed_sum + for i in range(len(values) - 1): + values[i] = self.rng.randint(lower[i], min(upper[i], residual)) + residual = residual - values[i] + if residual >= lower[-1] and residual <= upper[-1]: + values[-1] = residual + break + assert sum(values) == fixed_sum 
+ return values + + # direct_sample, indirect_sample, screen_off_sample, potential_sample + # these magic numbers and changes are to adjust dataset statistics + constraint_lower = [0] * 3 + if regime == "Comp": + # for Comp split + constraint_upper = [len(self.direct), len(self.indirect), max(len(self.screen_off) - 2, 0)] + if regime == "Sys": + if train: + # for Sys train split + constraint_upper = [len(self.direct), len(self.indirect), max(len(self.screen_off) - 3, 0)] + else: + # for Sys val / test split + constraint_upper = [len(self.direct), len(self.indirect), max(len(self.screen_off) - 1, 0)] + if regime == "IID": + # for IID split + constraint_upper = [len(self.direct), len(self.indirect), max(len(self.screen_off) - 1, 0)] + potential_sample_num = min(len(self.potential_blickets), 1) + direct_sample_num, indirect_sample_num, screen_off_sample_num = fixed_sum_sample( + constraint_lower, constraint_upper, 2 - potential_sample_num + ) + direct_samples = self.rng.sample(self.direct, k=direct_sample_num) + indirect_samples = self.rng.sample(self.indirect, k=indirect_sample_num) + screen_off_samples = self.rng.sample(self.screen_off, k=screen_off_sample_num) + potential_samples = self.rng.sample(self.potential_blickets, k=potential_sample_num) + + questions = [] + # label: 0 for light off, 1 for unknown, 2 for light up + for direct_sample in direct_samples: + assert direct_sample in self.blickets or direct_sample in self.non_blickets + if direct_sample in self.blickets: + label = 2 + else: + label = 0 + cause_view = BlicketView("no") + cause_view.add_objects([direct_sample]) + questions.append((cause_view, label, "direct")) + for indirect_sample in indirect_samples: + assert indirect_sample in self.blickets + label = 2 + cause_view = BlicketView("no") + cause_view.add_objects([indirect_sample]) + questions.append((cause_view, label, "indirect")) + for screen_off_sample in screen_off_samples: + assert screen_off_sample in self.non_blickets + label = 0 + cause_view 
= BlicketView("no") + cause_view.add_objects([screen_off_sample]) + questions.append((cause_view, label, "screen_off")) + for potential_sample in potential_samples: + assert potential_sample in self.potential_blickets + label = 1 + cause_view = BlicketView("no") + cause_view.add_objects([potential_sample]) + questions.append((cause_view, label, "potential")) + + if self.shuffle: + self.rng.shuffle(questions) + + return questions + + def generate_intervention_questions(self, views): + + # on_views = [view for view in views if view.light_state == "on" and len(view.objects) >= 2] + off_views = [view for view in views if view.light_state == "off"] + + # on_view = rng.sample(on_views, k=1)[0] + # on_view_ref = views.index(on_view) + off_view = self.rng.sample(off_views, k=1)[0] + off_view_ref = views.index(off_view) + + questions = [] + + all_possibilities = ( + list(set(self.blickets + self.potential_blickets + self.non_blickets).difference(off_view.objects)) + + self.set_blickets + ) + # adjust weight during sample for better statistics + all_possibilities += list(set(self.potential_blickets).difference(off_view.objects)) + possibilities = self.rng.sample(all_possibilities, k=2) + for possibility in possibilities: + if possibility in self.direct: + q_type = "direct" + elif possibility in self.indirect: + q_type = "indirect" + elif possibility in self.screen_off: + q_type = "screen_off" + elif possibility in self.potential_blickets: + q_type = "potential" + else: + assert possibility in self.set_blickets + q_type = "indirect" + if possibility in self.blickets: + label = 2 + elif possibility in self.potential_blickets: + label = 1 + elif possibility in self.non_blickets: + label = 0 + else: + assert possibility in self.set_blickets + label = 2 + intervention_view = BlicketView("no") + intervention_view.add_objects(off_view.objects) + if type(possibility) == tuple: + possibility = list(possibility) + else: + possibility = [possibility] + 
intervention_view.add_objects(possibility) + questions.append((intervention_view, label, q_type, off_view_ref)) + + if self.shuffle: + self.rng.shuffle(questions) + + return questions + + +class OnOffOff(BlicketQuestion): + def __init__( + self, + min_potential_blickets, + max_potential_blickets, + min_non_blickets, + max_non_blickets, + config_size, + shuffle, + rng, + ): + super(OnOffOff, self).__init__( + min_potential_blickets, + max_potential_blickets, + min_non_blickets, + max_non_blickets, + config_size, + shuffle, + rng, + ) + + def get_evidence_views(self): + # the first non blicket used in habituation and hence the skip + first_off, second_off = self.union_sample(self.non_blickets[1:]) + + first_off_view = BlicketView("off") + first_off_view.add_objects(first_off) + + second_off_view = BlicketView("off") + second_off_view.add_objects(second_off) + + on_view = BlicketView("on") + + on_list = self.potential_blickets[:] + + if len(on_list) == 1: + self.add_blicket(on_list[0]) + else: + set_blicket = on_list[:] + set_blicket.sort() + set_blicket = tuple(set_blicket) + self.add_set_blicket(set_blicket) + + on_view.add_objects(on_list) + + self.add_noise(on_view) + + evidence_views = [on_view, first_off_view, second_off_view] + + # bookkeeping for candidate choices + off_union_set = set(self.non_blickets[1:]) + on_set = set(on_view.objects) + on_diff_set = on_set.difference(off_union_set) + direct_set = list(off_union_set.difference(on_set)) + screen_off_set = list(off_union_set.intersection(on_set)) + + self.direct.extend(direct_set) + self.screen_off.extend(screen_off_set) + + if len(on_set) == 1: + self.direct.extend(list(on_set)) + else: + if len(on_diff_set) == 1: + self.indirect.extend(list(on_diff_set)) + return evidence_views + + +class OnOnOff(BlicketQuestion): + def __init__( + self, + min_potential_blickets, + max_potential_blickets, + min_non_blickets, + max_non_blickets, + config_size, + shuffle, + rng, + ): + super(OnOnOff, self).__init__( + 
min_potential_blickets, + max_potential_blickets, + min_non_blickets, + max_non_blickets, + config_size, + shuffle, + rng, + ) + + def get_evidence_views(self): + # the first non blicket used in habituation and hence the skip + off_list = self.non_blickets[1:] + + off_view = BlicketView("off") + off_view.add_objects(off_list) + + first_on, second_on = self.union_sample(self.potential_blickets) + + first_on_view = BlicketView("on") + first_on_view.add_objects(first_on) + + second_on_view = BlicketView("on") + second_on_view.add_objects(second_on) + + for on in [first_on, second_on]: + if len(on) == 1: + if on[0] not in self.blickets: + self.add_blicket(on[0]) + for on in [first_on, second_on]: + if len(set(on).intersection(self.blickets)) == 0: + set_blicket = on[:] + set_blicket.sort() + set_blicket = tuple(set_blicket) + if set_blicket not in self.set_blickets: + self.add_set_blicket(set_blicket) + + self.add_noise(first_on_view) + self.add_noise(second_on_view) + + evidence_views = [first_on_view, second_on_view, off_view] + + # bookkeeping for candidate choices + on_union_set = set(first_on_view.objects).union(second_on_view.objects) + off_set = set(off_view.objects) + + direct_set = list(off_set.difference(on_union_set)) + screen_off_set = list(off_set.intersection(on_union_set)) + self.direct.extend(direct_set) + self.screen_off.extend(screen_off_set) + + for on_view in [first_on_view, second_on_view]: + if len(on_view.objects) == 1: + obj = on_view.objects[0] + if obj not in self.direct: + self.direct.append(obj) + for on_view in [first_on_view, second_on_view]: + diff_set = set(on_view.objects).difference(off_set) + if len(on_view.objects) > 1 and len(diff_set) == 1: + obj = diff_set.pop() + if obj not in self.indirect and obj not in self.direct: + self.indirect.append(obj) + + return evidence_views + + +def serialize(questions): + question_list = [] + for question in questions: + view_list = [] + for i in range(6): + json_dict = {} + 
json_dict["light_state"] = question[i].light_state + json_dict["objects"] = question[i].objects + view_list.append(json_dict) + for i in range(6, 8): + json_dict = {} + json_dict["light_state"] = question[i][0].light_state + json_dict["objects"] = question[i][0].objects + json_dict["label"] = question[i][1] + json_dict["type"] = question[i][2] + view_list.append(json_dict) + for i in range(8, 10): + json_dict = {} + json_dict["light_state"] = question[i][0].light_state + json_dict["objects"] = question[i][0].objects + json_dict["label"] = question[i][1] + json_dict["type"] = question[i][2] + json_dict["ref"] = question[i][3] + view_list.append(json_dict) + question_list.append(view_list) + return question_list + + +def config_control(size, train, config_size, regime, rng): + questions = [] + for _ in range(size // 2): + blicket_machine = OnOffOff( + ONOFFOFF_MIN_POTENTIAL, ONOFFOFF_MAX_POTENTIAL, ONOFFOFF_MIN_NON, ONOFFOFF_MAX_NON, config_size, True, rng + ) + context_views = blicket_machine.get_views() + + blicket_machine.sanity_check() + blicket_machine.check_blickets(context_views) + + cause_questions = blicket_machine.generate_cause_questions(train=train, regime=regime) + for cause_question in cause_questions: + blicket_machine.check_labels(cause_question[0], cause_question[1]) + intervention_questions = blicket_machine.generate_intervention_questions(context_views) + for intervention_question in intervention_questions: + blicket_machine.check_labels(intervention_question[0], intervention_question[1]) + questions.append(context_views + cause_questions + intervention_questions) + for _ in range(size // 2): + blicket_machine = OnOnOff( + ONONOFF_MIN_POTENTIAL, ONONOFF_MAX_POTENTIAL, ONONOFF_MIN_NON, ONONOFF_MAX_NON, config_size, True, rng + ) + context_views = blicket_machine.get_views() + + blicket_machine.sanity_check() + blicket_machine.check_blickets(context_views) + + cause_questions = blicket_machine.generate_cause_questions(train=train, regime=regime) + 
for cause_question in cause_questions: + blicket_machine.check_labels(cause_question[0], cause_question[1]) + intervention_questions = blicket_machine.generate_intervention_questions(context_views) + for intervention_question in intervention_questions: + blicket_machine.check_labels(intervention_question[0], intervention_question[1]) + questions.append(context_views + cause_questions + intervention_questions) + rng.shuffle(questions) + return questions + + +def dist_control(size, train, regime, rng): + questions = [] + for _ in range(size): + if train: + blicket_machine = OnOffOff( + ONOFFOFF_MIN_POTENTIAL, + ONOFFOFF_MAX_POTENTIAL, + ONOFFOFF_MIN_NON, + ONOFFOFF_MAX_NON, + ALL_CONFIG_SIZE, + True, + rng, + ) + else: + blicket_machine = OnOnOff( + ONONOFF_MIN_POTENTIAL, + ONONOFF_MAX_POTENTIAL, + ONONOFF_MIN_NON, + ONONOFF_MAX_NON, + ALL_CONFIG_SIZE, + True, + rng, + ) + context_views = blicket_machine.get_views() + + blicket_machine.sanity_check() + blicket_machine.check_blickets(context_views) + + cause_questions = blicket_machine.generate_cause_questions(train=train, regime=regime) + for cause_question in cause_questions: + blicket_machine.check_labels(cause_question[0], cause_question[1]) + intervention_questions = blicket_machine.generate_intervention_questions(context_views) + for intervention_question in intervention_questions: + blicket_machine.check_labels(intervention_question[0], intervention_question[1]) + questions.append(context_views + cause_questions + intervention_questions) + rng.shuffle(questions) + return questions + + +# Text translation functions +LIGHT_DICT_TEXT = ["off", "undetermined", "on"] +OBJECT_DICT_FILE_PATH = get_data_file_path("acre_objects.json") +with open(OBJECT_DICT_FILE_PATH, "r") as object_dict_file: + OBJECT_DICT_TEXT = json.load(object_dict_file) + + +def get_example_text(objects, light_state): + input_examples = [] + for object in objects: + obj_desc = OBJECT_DICT_TEXT[str(object)] + color = obj_desc["color"] + shape = 
obj_desc["shape"] + material = obj_desc["material"] + # [color, material, shape] + input_examples.append([color, material, shape]) + return {"input": input_examples, "output": light_state} + + +# Label has 0, 1, 2 integer which are indexed into the array ["off", "undetermined", "on"] +def get_trial_text(objects, label): + # object_text = "" + input_examples = [] + for object in objects: + obj_desc = OBJECT_DICT_TEXT[str(object)] + color = obj_desc["color"] + shape = obj_desc["shape"] + material = obj_desc["material"] + # [color, material, shape] + input_examples.append([color, material, shape]) + return {"input": input_examples, "output": LIGHT_DICT_TEXT[label]} + + +# Select translation functions +get_example = get_example_text +get_trial = get_trial_text +get_answer_list = lambda: LIGHT_DICT_TEXT + + +def final_parse(serialized_questions) -> List[Dict[str, Any]]: + output_data = [] + for sample in serialized_questions: # For each data in input (6 examples, 4 tests) + # {"examples": [{"input": [[color, material, shape]], "output": "on"}], "question": {"input": [color, material, shape], "output", "on", "type": "direct"}} + # examples, ideally is an array of dict(input, output) + examples = [] + for entry in sample: + if entry["light_state"] != "no": # example + example = get_example(entry["objects"], entry["light_state"]) + examples.append(example) + else: # test case + # question is a dict of input and output + question = get_trial(entry["objects"], entry["label"]) + output_data.append({"examples": examples, "question": question}) + + return output_data diff --git a/reasoning_gym/induction/acre/const.py b/reasoning_gym/induction/acre/const.py new file mode 100644 index 00000000..337ee099 --- /dev/null +++ b/reasoning_gym/induction/acre/const.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- + + +ONOFFOFF_MIN_POTENTIAL = 2 +ONOFFOFF_MAX_POTENTIAL = 3 +ONOFFOFF_MIN_NON = 3 +ONOFFOFF_MAX_NON = 5 + +ONONOFF_MIN_POTENTIAL = 2 +ONONOFF_MAX_POTENTIAL = 4 +ONONOFF_MIN_NON = 3 
+ONONOFF_MAX_NON = 4 + +ATTR_CONFIG_SIZE = 24 +ALL_CONFIG_SIZE = 48 diff --git a/tests/test_acre.py b/tests/test_acre.py new file mode 100644 index 00000000..a0c4ab74 --- /dev/null +++ b/tests/test_acre.py @@ -0,0 +1,65 @@ +import pytest + +from reasoning_gym.induction.acre.acre import ACREDataset, ACREDatasetConfig + + +def test_acre_config_validation(): + """Test that config validation works""" + config = ACREDatasetConfig(size=-1) + with pytest.raises(AssertionError): + config.validate() + + +def test_acre_dataset_deterministic(): + """Test that dataset generates same items with same seed""" + config = ACREDatasetConfig(seed=42, size=10) + dataset1 = ACREDataset(config) + dataset2 = ACREDataset(config) + + for i in range(len(dataset1)): + assert dataset1[i] == dataset2[i] + + +def test_acre_items(): + """Test basic properties of generated items""" + config = ACREDatasetConfig(size=50, seed=42) + dataset = ACREDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert isinstance(item["question"], str) + assert isinstance(item["answer"], str) + + +def test_acre_iteration(): + """Test that iteration respects dataset size""" + config = ACREDatasetConfig(size=10, seed=42) # Small size for testing + dataset = ACREDataset(config) + + # Test manual iteration + items = [] + for item in dataset: + items.append(item) + assert len(items) == config.size, "Iterator should yield exactly size items" + + # Test list conversion + items = list(dataset) + assert len(items) == config.size, "Iterator should yield exactly size items" + + # Test multiple iterations + first_items = list(dataset) + second_items = list(dataset) + assert first_items == second_items, "Multiple iterations should yield same items" + + +def test_acre_questions_generator(): + """Test question generator loading and access""" + config = ACREDatasetConfig(size=10, seed=42) + dataset = ACREDataset(config) + + # 
Test properties of questions + assert isinstance(dataset.questions, list) + assert len(dataset.questions) > 0