mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-19 12:58:07 +00:00
Add ACRE(Abstract Causal REasoning Beyond Covariation) python generators (#199)
* Add acre python generators * acre: improved prompt & formatting of examples, support arbitrary sizes --------- Co-authored-by: Andreas Koepf <andreas.koepf@provisio.com>
This commit is contained in:
parent
e62b45d61c
commit
c8c3930797
6 changed files with 1152 additions and 4 deletions
290
reasoning_gym/data/acre_objects.json
Normal file
290
reasoning_gym/data/acre_objects.json
Normal file
|
|
@ -0,0 +1,290 @@
|
|||
{
|
||||
"0": {
|
||||
"material": "rubber",
|
||||
"size": "medium",
|
||||
"shape": "cube",
|
||||
"color": "gray"
|
||||
},
|
||||
"1": {
|
||||
"material": "metal",
|
||||
"size": "medium",
|
||||
"shape": "cube",
|
||||
"color": "gray"
|
||||
},
|
||||
"2": {
|
||||
"material": "rubber",
|
||||
"size": "medium",
|
||||
"shape": "cube",
|
||||
"color": "red"
|
||||
},
|
||||
"3": {
|
||||
"material": "metal",
|
||||
"size": "medium",
|
||||
"shape": "cube",
|
||||
"color": "red"
|
||||
},
|
||||
"4": {
|
||||
"material": "rubber",
|
||||
"size": "medium",
|
||||
"shape": "cube",
|
||||
"color": "blue"
|
||||
},
|
||||
"5": {
|
||||
"material": "metal",
|
||||
"size": "medium",
|
||||
"shape": "cube",
|
||||
"color": "blue"
|
||||
},
|
||||
"6": {
|
||||
"material": "rubber",
|
||||
"size": "medium",
|
||||
"shape": "cube",
|
||||
"color": "green"
|
||||
},
|
||||
"7": {
|
||||
"material": "metal",
|
||||
"size": "medium",
|
||||
"shape": "cube",
|
||||
"color": "green"
|
||||
},
|
||||
"8": {
|
||||
"material": "rubber",
|
||||
"size": "medium",
|
||||
"shape": "cube",
|
||||
"color": "brown"
|
||||
},
|
||||
"9": {
|
||||
"material": "metal",
|
||||
"size": "medium",
|
||||
"shape": "cube",
|
||||
"color": "brown"
|
||||
},
|
||||
"10": {
|
||||
"material": "rubber",
|
||||
"size": "medium",
|
||||
"shape": "cube",
|
||||
"color": "purple"
|
||||
},
|
||||
"11": {
|
||||
"material": "metal",
|
||||
"size": "medium",
|
||||
"shape": "cube",
|
||||
"color": "purple"
|
||||
},
|
||||
"12": {
|
||||
"material": "rubber",
|
||||
"size": "medium",
|
||||
"shape": "cube",
|
||||
"color": "cyan"
|
||||
},
|
||||
"13": {
|
||||
"material": "metal",
|
||||
"size": "medium",
|
||||
"shape": "cube",
|
||||
"color": "cyan"
|
||||
},
|
||||
"14": {
|
||||
"material": "rubber",
|
||||
"size": "medium",
|
||||
"shape": "cube",
|
||||
"color": "yellow"
|
||||
},
|
||||
"15": {
|
||||
"material": "metal",
|
||||
"size": "medium",
|
||||
"shape": "cube",
|
||||
"color": "yellow"
|
||||
},
|
||||
"16": {
|
||||
"material": "rubber",
|
||||
"size": "medium",
|
||||
"shape": "cylinder",
|
||||
"color": "gray"
|
||||
},
|
||||
"17": {
|
||||
"material": "metal",
|
||||
"size": "medium",
|
||||
"shape": "cylinder",
|
||||
"color": "gray"
|
||||
},
|
||||
"18": {
|
||||
"material": "rubber",
|
||||
"size": "medium",
|
||||
"shape": "cylinder",
|
||||
"color": "red"
|
||||
},
|
||||
"19": {
|
||||
"material": "metal",
|
||||
"size": "medium",
|
||||
"shape": "cylinder",
|
||||
"color": "red"
|
||||
},
|
||||
"20": {
|
||||
"material": "rubber",
|
||||
"size": "medium",
|
||||
"shape": "cylinder",
|
||||
"color": "blue"
|
||||
},
|
||||
"21": {
|
||||
"material": "metal",
|
||||
"size": "medium",
|
||||
"shape": "cylinder",
|
||||
"color": "blue"
|
||||
},
|
||||
"22": {
|
||||
"material": "rubber",
|
||||
"size": "medium",
|
||||
"shape": "cylinder",
|
||||
"color": "green"
|
||||
},
|
||||
"23": {
|
||||
"material": "metal",
|
||||
"size": "medium",
|
||||
"shape": "cylinder",
|
||||
"color": "green"
|
||||
},
|
||||
"24": {
|
||||
"material": "rubber",
|
||||
"size": "medium",
|
||||
"shape": "cylinder",
|
||||
"color": "brown"
|
||||
},
|
||||
"25": {
|
||||
"material": "metal",
|
||||
"size": "medium",
|
||||
"shape": "cylinder",
|
||||
"color": "brown"
|
||||
},
|
||||
"26": {
|
||||
"material": "rubber",
|
||||
"size": "medium",
|
||||
"shape": "cylinder",
|
||||
"color": "purple"
|
||||
},
|
||||
"27": {
|
||||
"material": "metal",
|
||||
"size": "medium",
|
||||
"shape": "cylinder",
|
||||
"color": "purple"
|
||||
},
|
||||
"28": {
|
||||
"material": "rubber",
|
||||
"size": "medium",
|
||||
"shape": "cylinder",
|
||||
"color": "cyan"
|
||||
},
|
||||
"29": {
|
||||
"material": "metal",
|
||||
"size": "medium",
|
||||
"shape": "cylinder",
|
||||
"color": "cyan"
|
||||
},
|
||||
"30": {
|
||||
"material": "rubber",
|
||||
"size": "medium",
|
||||
"shape": "cylinder",
|
||||
"color": "yellow"
|
||||
},
|
||||
"31": {
|
||||
"material": "metal",
|
||||
"size": "medium",
|
||||
"shape": "cylinder",
|
||||
"color": "yellow"
|
||||
},
|
||||
"32": {
|
||||
"material": "rubber",
|
||||
"size": "medium",
|
||||
"shape": "sphere",
|
||||
"color": "gray"
|
||||
},
|
||||
"33": {
|
||||
"material": "metal",
|
||||
"size": "medium",
|
||||
"shape": "sphere",
|
||||
"color": "gray"
|
||||
},
|
||||
"34": {
|
||||
"material": "rubber",
|
||||
"size": "medium",
|
||||
"shape": "sphere",
|
||||
"color": "red"
|
||||
},
|
||||
"35": {
|
||||
"material": "metal",
|
||||
"size": "medium",
|
||||
"shape": "sphere",
|
||||
"color": "red"
|
||||
},
|
||||
"36": {
|
||||
"material": "rubber",
|
||||
"size": "medium",
|
||||
"shape": "sphere",
|
||||
"color": "blue"
|
||||
},
|
||||
"37": {
|
||||
"material": "metal",
|
||||
"size": "medium",
|
||||
"shape": "sphere",
|
||||
"color": "blue"
|
||||
},
|
||||
"38": {
|
||||
"material": "rubber",
|
||||
"size": "medium",
|
||||
"shape": "sphere",
|
||||
"color": "green"
|
||||
},
|
||||
"39": {
|
||||
"material": "metal",
|
||||
"size": "medium",
|
||||
"shape": "sphere",
|
||||
"color": "green"
|
||||
},
|
||||
"40": {
|
||||
"material": "rubber",
|
||||
"size": "medium",
|
||||
"shape": "sphere",
|
||||
"color": "brown"
|
||||
},
|
||||
"41": {
|
||||
"material": "metal",
|
||||
"size": "medium",
|
||||
"shape": "sphere",
|
||||
"color": "brown"
|
||||
},
|
||||
"42": {
|
||||
"material": "rubber",
|
||||
"size": "medium",
|
||||
"shape": "sphere",
|
||||
"color": "purple"
|
||||
},
|
||||
"43": {
|
||||
"material": "metal",
|
||||
"size": "medium",
|
||||
"shape": "sphere",
|
||||
"color": "purple"
|
||||
},
|
||||
"44": {
|
||||
"material": "rubber",
|
||||
"size": "medium",
|
||||
"shape": "sphere",
|
||||
"color": "cyan"
|
||||
},
|
||||
"45": {
|
||||
"material": "metal",
|
||||
"size": "medium",
|
||||
"shape": "sphere",
|
||||
"color": "cyan"
|
||||
},
|
||||
"46": {
|
||||
"material": "rubber",
|
||||
"size": "medium",
|
||||
"shape": "sphere",
|
||||
"color": "yellow"
|
||||
},
|
||||
"47": {
|
||||
"material": "metal",
|
||||
"size": "medium",
|
||||
"shape": "sphere",
|
||||
"color": "yellow"
|
||||
}
|
||||
}
|
||||
|
|
@ -2,9 +2,7 @@
|
|||
Arithmetic tasks for training reasoning capabilities:
|
||||
"""
|
||||
|
||||
from .acre.acre import ACREDataset, ACREDatasetConfig
|
||||
from .list_functions import ListFunctionsDataset, ListFunctionsDatasetConfig
|
||||
|
||||
__all__ = [
|
||||
"ListFunctionsDataset",
|
||||
"ListFunctionsDatasetConfig",
|
||||
]
|
||||
__all__ = ["ListFunctionsDataset", "ListFunctionsDatasetConfig", "ACREDataset", "ACREDatasetConfig"]
|
||||
|
|
|
|||
93
reasoning_gym/induction/acre/acre.py
Normal file
93
reasoning_gym/induction/acre/acre.py
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
"""ACRE(Abstract Causal REasoning Beyond Covariation) dataset"""
|
||||
|
||||
# Culled and Adapted from https://github.com/WellyZhang/ACRE
|
||||
|
||||
from dataclasses import dataclass
|
||||
from random import Random
|
||||
from typing import Optional
|
||||
|
||||
from reasoning_gym.factory import ProceduralDataset, register_dataset
|
||||
|
||||
from .blicket import config_control, dist_control, final_parse, serialize
|
||||
from .const import ALL_CONFIG_SIZE, ATTR_CONFIG_SIZE
|
||||
|
||||
|
||||
# Create blicket questions
|
||||
@dataclass
|
||||
class ACREDatasetConfig:
|
||||
"""Configuration for ACRE dataset generation"""
|
||||
|
||||
train: int = 1 # The default is 1 for training, otherwise 0 for validation and testing
|
||||
size: int = 500 # Split ratio = 6 : 2 : 2 -> IID : Comp : Sys
|
||||
seed: Optional[int] = None
|
||||
|
||||
def validate(self) -> None:
|
||||
"""Validate configuration parameters"""
|
||||
assert self.train in (0, 1), "train must be either 0 or 1"
|
||||
assert self.size > 0, "Dataset size must be positive."
|
||||
|
||||
|
||||
class ACREDataset(ProceduralDataset):
|
||||
|
||||
def __init__(self, config: ACREDatasetConfig):
|
||||
super().__init__(config, config.seed, config.size)
|
||||
self.questions = self._generate_questions()
|
||||
self.prompt_template = """You are a researcher studying causal relationships using Blicket experiments. In these experiments, certain objects (called 'blickets') have the hidden property of activating a detector, causing its light to turn on.
|
||||
|
||||
Each example shows the results of placing different combinations of objects on the detector. Each object is described by color, material and shape. Your task is to determine whether a new combination of objects will cause the detector to activate.
|
||||
|
||||
After observing the previous examples, respond with:
|
||||
- "on" if you can determine the detector light will turn on
|
||||
- "off" if you can determine the detector light will stay off
|
||||
- "undetermined" if there is insufficient evidence to reach a conclusion
|
||||
|
||||
Do not use quotation marks in your answer.
|
||||
|
||||
Previous experimental results:
|
||||
{examples}
|
||||
|
||||
New test case:
|
||||
{input}
|
||||
|
||||
What is the detector light status?"""
|
||||
|
||||
def _generate_questions(self):
|
||||
"""
|
||||
Generates questions of particular size
|
||||
"""
|
||||
|
||||
iid_size = int(0.6 * self.config.size)
|
||||
comp_size = int(0.2 * self.config.size)
|
||||
sys_size = self.config.size - (iid_size + comp_size)
|
||||
rng = Random(self.seed)
|
||||
iid_questions = config_control(iid_size, self.config.train, ALL_CONFIG_SIZE, "IID", rng)
|
||||
comp_questions = config_control(comp_size, self.config.train, ATTR_CONFIG_SIZE, "Comp", rng)
|
||||
sys_questions = dist_control(sys_size, self.config.train, "Sys", rng)
|
||||
|
||||
questions = []
|
||||
questions.extend(iid_questions)
|
||||
questions.extend(comp_questions)
|
||||
questions.extend(sys_questions)
|
||||
rng.shuffle(questions)
|
||||
final_questions = final_parse(serialized_questions=serialize(questions))
|
||||
return final_questions
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
"""Generate a single induction-based list function dataset"""
|
||||
input = self.questions[idx]
|
||||
examples = input["examples"]
|
||||
formatted_examples = ""
|
||||
for object in examples:
|
||||
input_ = ", ".join(" ".join(x) for x in object["input"])
|
||||
output = object["output"]
|
||||
if len(formatted_examples) > 0:
|
||||
formatted_examples += "\n"
|
||||
formatted_examples += f"{input_} → {output}"
|
||||
|
||||
prompt_input = ", ".join(" ".join(x) for x in input["question"]["input"])
|
||||
answer = input["question"]["output"]
|
||||
question = self.prompt_template.format(examples=formatted_examples, input=prompt_input)
|
||||
return {"question": question, "answer": answer, "metadata": {}}
|
||||
|
||||
|
||||
register_dataset("acre", ACREDataset, ACREDatasetConfig)
|
||||
687
reasoning_gym/induction/acre/blicket.py
Normal file
687
reasoning_gym/induction/acre/blicket.py
Normal file
|
|
@ -0,0 +1,687 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from reasoning_gym.data import get_data_file_path
|
||||
|
||||
from .const import (
|
||||
ALL_CONFIG_SIZE,
|
||||
ONOFFOFF_MAX_NON,
|
||||
ONOFFOFF_MAX_POTENTIAL,
|
||||
ONOFFOFF_MIN_NON,
|
||||
ONOFFOFF_MIN_POTENTIAL,
|
||||
ONONOFF_MAX_NON,
|
||||
ONONOFF_MAX_POTENTIAL,
|
||||
ONONOFF_MIN_NON,
|
||||
ONONOFF_MIN_POTENTIAL,
|
||||
)
|
||||
|
||||
|
||||
class BlicketView(object):
|
||||
def __init__(self, light_state="no"):
|
||||
# light state: "no" for no light, "off" for light off, "on" for light on
|
||||
self.objects = []
|
||||
self.light_state = light_state
|
||||
|
||||
def add_objects(self, objects):
|
||||
for obj in objects:
|
||||
self.objects.append(obj)
|
||||
|
||||
def remove_objects(self, objects):
|
||||
for obj in objects:
|
||||
self.objects.remove(obj)
|
||||
|
||||
def __repr__(self):
|
||||
return "BlicketView(objects={}, light_state={})".format(self.objects, self.light_state)
|
||||
|
||||
|
||||
class BlicketQuestion(object):
|
||||
def __init__(
|
||||
self,
|
||||
min_potential_blickets,
|
||||
max_potential_blickets,
|
||||
min_non_blickets,
|
||||
max_non_blickets,
|
||||
config_size,
|
||||
shuffle,
|
||||
rng,
|
||||
):
|
||||
self.min_potential_blickets = min_potential_blickets
|
||||
self.max_potential_blickets = max_potential_blickets
|
||||
self.min_non_blickets = min_non_blickets
|
||||
self.max_non_blickets = max_non_blickets
|
||||
self.config_size = config_size
|
||||
self.shuffle = shuffle
|
||||
self.rng = rng
|
||||
|
||||
potential_blicket_num = self.rng.randint(self.min_potential_blickets, self.max_potential_blickets)
|
||||
non_blicket_num = self.rng.randint(self.min_non_blickets, self.max_non_blickets)
|
||||
|
||||
samples = self.rng.sample(list(range(self.config_size)), k=potential_blicket_num + non_blicket_num)
|
||||
|
||||
self.blickets = []
|
||||
self.set_blickets = []
|
||||
self.potential_blickets = samples[:potential_blicket_num]
|
||||
self.non_blickets = samples[potential_blicket_num:]
|
||||
|
||||
self.direct = []
|
||||
self.indirect = []
|
||||
self.screen_off = []
|
||||
|
||||
def get_habituation_views(self):
|
||||
blicket_obj = self.potential_blickets[0]
|
||||
self.add_blicket(blicket_obj)
|
||||
|
||||
non_blicket_obj = self.non_blickets[0]
|
||||
|
||||
view_with_blicket = BlicketView("on")
|
||||
view_with_blicket.add_objects([blicket_obj])
|
||||
|
||||
view_with_non_blicket = BlicketView("off")
|
||||
view_with_non_blicket.add_objects([non_blicket_obj])
|
||||
|
||||
view_with_both = BlicketView("on")
|
||||
view_with_both.add_objects([blicket_obj, non_blicket_obj])
|
||||
|
||||
habituation_views = [view_with_blicket, view_with_non_blicket, view_with_both]
|
||||
|
||||
# bookkeeping for candidate choices
|
||||
self.direct.append(blicket_obj)
|
||||
self.screen_off.append(non_blicket_obj)
|
||||
|
||||
return habituation_views
|
||||
|
||||
def get_evidence_views(self):
|
||||
raise NotImplementedError("The parent class should not be used.")
|
||||
|
||||
def get_views(self):
|
||||
habituation_views = self.get_habituation_views()
|
||||
evidence_views = self.get_evidence_views()
|
||||
if self.shuffle:
|
||||
self.rng.shuffle(habituation_views)
|
||||
self.rng.shuffle(evidence_views)
|
||||
|
||||
return habituation_views + evidence_views
|
||||
|
||||
def sanity_check(self):
|
||||
# no duplicates
|
||||
assert len(self.direct) == len(set(self.direct))
|
||||
assert len(self.indirect) == len(set(self.indirect))
|
||||
assert len(self.screen_off) == len(set(self.screen_off))
|
||||
# no intersection
|
||||
assert len(set(self.direct).intersection(self.indirect)) == 0
|
||||
assert len(set(self.direct).intersection(self.screen_off)) == 0
|
||||
assert len(set(self.indirect).intersection(self.screen_off)) == 0
|
||||
# completeness
|
||||
assert set(self.direct + self.indirect + self.screen_off) == set(self.blickets + self.non_blickets)
|
||||
|
||||
def check_blickets(self, views):
|
||||
blickets = set()
|
||||
non_blickets = set()
|
||||
potential_blickets = set()
|
||||
|
||||
direct = set()
|
||||
indirect = set()
|
||||
screen_off = set()
|
||||
|
||||
on_views = [view for view in views if view.light_state == "on"]
|
||||
off_views = [view for view in views if view.light_state == "off"]
|
||||
|
||||
assert len(on_views) + len(off_views) == len(views)
|
||||
|
||||
for off_view in off_views:
|
||||
non_blickets.update(off_view.objects)
|
||||
for on_view in on_views:
|
||||
if len(on_view.objects) == 1:
|
||||
blickets.update(on_view.objects)
|
||||
else:
|
||||
diff_set = set(on_view.objects).difference(non_blickets)
|
||||
if len(diff_set) == 1:
|
||||
blickets.update(diff_set)
|
||||
all_objects = set()
|
||||
for view in views:
|
||||
all_objects.update(view.objects)
|
||||
potential_blickets.update(all_objects.difference(non_blickets).difference(blickets))
|
||||
|
||||
assert blickets == set(self.blickets)
|
||||
assert non_blickets == set(self.non_blickets)
|
||||
assert potential_blickets == set(self.potential_blickets)
|
||||
|
||||
for on_view in on_views:
|
||||
if len(on_view.objects) == 1:
|
||||
direct.update(on_view.objects)
|
||||
on_view_objects = set()
|
||||
for on_view in on_views:
|
||||
on_view_objects.update(on_view.objects)
|
||||
off_view_objects = set()
|
||||
for off_view in off_views:
|
||||
off_view_objects.update(off_view.objects)
|
||||
|
||||
direct.update(off_view_objects.difference(on_view_objects))
|
||||
screen_off.update(off_view_objects.intersection(on_view_objects))
|
||||
|
||||
for on_view in on_views:
|
||||
if len(on_view.objects) > 1:
|
||||
diff_set = set(on_view.objects).difference(non_blickets)
|
||||
if len(diff_set) == 1 and not diff_set.issubset(direct):
|
||||
indirect.update(diff_set)
|
||||
|
||||
assert direct == set(self.direct)
|
||||
assert indirect == set(self.indirect)
|
||||
assert screen_off == set(self.screen_off)
|
||||
|
||||
set_blickets = set()
|
||||
for on_view in on_views:
|
||||
on_diff_set = set(on_view.objects).difference(non_blickets)
|
||||
if len(on_diff_set.intersection(blickets)) == 0:
|
||||
potential_set_blicket = list(on_diff_set)
|
||||
potential_set_blicket.sort()
|
||||
set_blickets.add(tuple(potential_set_blicket))
|
||||
remove_set = set()
|
||||
for elem in set_blickets:
|
||||
for another_elem in set_blickets:
|
||||
elem_set = set(elem)
|
||||
another_elem_set = set(another_elem)
|
||||
if elem_set < another_elem_set:
|
||||
remove_set.add(another_elem)
|
||||
set_blickets.difference_update(remove_set)
|
||||
|
||||
assert set_blickets == set(self.set_blickets), "set_blickets:{}, self.set_blickets:{}".format(
|
||||
set_blickets, self.set_blickets
|
||||
)
|
||||
|
||||
def check_labels(self, view, label):
|
||||
diff_set = set(view.objects).difference(self.non_blickets)
|
||||
if len(diff_set) == 0:
|
||||
assert label == 0
|
||||
else:
|
||||
blicket_inter = diff_set.intersection(self.blickets)
|
||||
if len(blicket_inter) > 0:
|
||||
assert label == 2
|
||||
else:
|
||||
if self.has_set_blicket(view):
|
||||
assert label == 2
|
||||
else:
|
||||
assert label == 1
|
||||
|
||||
def has_set_blicket(self, view):
|
||||
for set_blicket in self.set_blickets:
|
||||
if set(set_blicket).issubset(view.objects):
|
||||
return True
|
||||
return False
|
||||
|
||||
def union_sample(self, union, must_have_one=False):
|
||||
if must_have_one:
|
||||
first_set = self.rng.sample(union, k=1)
|
||||
else:
|
||||
first_num = self.rng.randint(1, len(union))
|
||||
first_set = self.rng.sample(union, k=first_num)
|
||||
second_set = list(set(union).difference(first_set))
|
||||
if len(second_set) > 0:
|
||||
additional_num_min = 0
|
||||
else:
|
||||
additional_num_min = 1
|
||||
additional_num = self.rng.randint(additional_num_min, len(first_set))
|
||||
additional_samples = self.rng.sample(first_set, k=additional_num)
|
||||
second_set += additional_samples
|
||||
|
||||
return first_set, second_set
|
||||
|
||||
def add_noise(self, view):
|
||||
# the first non blicket used in habituation and hence the skip
|
||||
noise_num = self.rng.randint(0, len(self.non_blickets) - 1)
|
||||
noise = self.rng.sample(self.non_blickets[1:], k=noise_num)
|
||||
view.add_objects(noise)
|
||||
|
||||
def add_blicket(self, obj):
|
||||
self.potential_blickets.remove(obj)
|
||||
self.blickets.append(obj)
|
||||
|
||||
def add_set_blicket(self, set_blicket):
|
||||
to_remove_set = []
|
||||
for already_set_blicket in self.set_blickets:
|
||||
if set(set_blicket) < set(already_set_blicket):
|
||||
to_remove_set.append(already_set_blicket)
|
||||
if set(already_set_blicket) < set(set_blicket):
|
||||
to_remove_set.append(set_blicket)
|
||||
self.set_blickets.append(set_blicket)
|
||||
for elem in to_remove_set:
|
||||
self.set_blickets.remove(elem)
|
||||
|
||||
def generate_cause_questions(self, train, regime="IID"):
|
||||
|
||||
def fixed_sum_sample(lower, upper, fixed_sum):
|
||||
values = [0] * len(lower)
|
||||
assert sum(upper) >= fixed_sum
|
||||
while True:
|
||||
residual = fixed_sum
|
||||
for i in range(len(values) - 1):
|
||||
values[i] = self.rng.randint(lower[i], min(upper[i], residual))
|
||||
residual = residual - values[i]
|
||||
if residual >= lower[-1] and residual <= upper[-1]:
|
||||
values[-1] = residual
|
||||
break
|
||||
assert sum(values) == fixed_sum
|
||||
return values
|
||||
|
||||
# direct_sample, indirect_sample, screen_off_sample, potential_sample
|
||||
# these magic numbers and changes are to adjust dataset statistics
|
||||
constraint_lower = [0] * 3
|
||||
if regime == "Comp":
|
||||
# for Comp split
|
||||
constraint_upper = [len(self.direct), len(self.indirect), max(len(self.screen_off) - 2, 0)]
|
||||
if regime == "Sys":
|
||||
if train:
|
||||
# for Sys train split
|
||||
constraint_upper = [len(self.direct), len(self.indirect), max(len(self.screen_off) - 3, 0)]
|
||||
else:
|
||||
# for Sys val / test split
|
||||
constraint_upper = [len(self.direct), len(self.indirect), max(len(self.screen_off) - 1, 0)]
|
||||
if regime == "IID":
|
||||
# for IID split
|
||||
constraint_upper = [len(self.direct), len(self.indirect), max(len(self.screen_off) - 1, 0)]
|
||||
potential_sample_num = min(len(self.potential_blickets), 1)
|
||||
direct_sample_num, indirect_sample_num, screen_off_sample_num = fixed_sum_sample(
|
||||
constraint_lower, constraint_upper, 2 - potential_sample_num
|
||||
)
|
||||
direct_samples = self.rng.sample(self.direct, k=direct_sample_num)
|
||||
indirect_samples = self.rng.sample(self.indirect, k=indirect_sample_num)
|
||||
screen_off_samples = self.rng.sample(self.screen_off, k=screen_off_sample_num)
|
||||
potential_samples = self.rng.sample(self.potential_blickets, k=potential_sample_num)
|
||||
|
||||
questions = []
|
||||
# label: 0 for light off, 1 for unknown, 2 for light up
|
||||
for direct_sample in direct_samples:
|
||||
assert direct_sample in self.blickets or direct_sample in self.non_blickets
|
||||
if direct_sample in self.blickets:
|
||||
label = 2
|
||||
else:
|
||||
label = 0
|
||||
cause_view = BlicketView("no")
|
||||
cause_view.add_objects([direct_sample])
|
||||
questions.append((cause_view, label, "direct"))
|
||||
for indirect_sample in indirect_samples:
|
||||
assert indirect_sample in self.blickets
|
||||
label = 2
|
||||
cause_view = BlicketView("no")
|
||||
cause_view.add_objects([indirect_sample])
|
||||
questions.append((cause_view, label, "indirect"))
|
||||
for screen_off_sample in screen_off_samples:
|
||||
assert screen_off_sample in self.non_blickets
|
||||
label = 0
|
||||
cause_view = BlicketView("no")
|
||||
cause_view.add_objects([screen_off_sample])
|
||||
questions.append((cause_view, label, "screen_off"))
|
||||
for potential_sample in potential_samples:
|
||||
assert potential_sample in self.potential_blickets
|
||||
label = 1
|
||||
cause_view = BlicketView("no")
|
||||
cause_view.add_objects([potential_sample])
|
||||
questions.append((cause_view, label, "potential"))
|
||||
|
||||
if self.shuffle:
|
||||
self.rng.shuffle(questions)
|
||||
|
||||
return questions
|
||||
|
||||
def generate_intervention_questions(self, views):
|
||||
|
||||
# on_views = [view for view in views if view.light_state == "on" and len(view.objects) >= 2]
|
||||
off_views = [view for view in views if view.light_state == "off"]
|
||||
|
||||
# on_view = rng.sample(on_views, k=1)[0]
|
||||
# on_view_ref = views.index(on_view)
|
||||
off_view = self.rng.sample(off_views, k=1)[0]
|
||||
off_view_ref = views.index(off_view)
|
||||
|
||||
questions = []
|
||||
|
||||
all_possibilities = (
|
||||
list(set(self.blickets + self.potential_blickets + self.non_blickets).difference(off_view.objects))
|
||||
+ self.set_blickets
|
||||
)
|
||||
# adjust weight during sample for better statistics
|
||||
all_possibilities += list(set(self.potential_blickets).difference(off_view.objects))
|
||||
possibilities = self.rng.sample(all_possibilities, k=2)
|
||||
for possibility in possibilities:
|
||||
if possibility in self.direct:
|
||||
q_type = "direct"
|
||||
elif possibility in self.indirect:
|
||||
q_type = "indirect"
|
||||
elif possibility in self.screen_off:
|
||||
q_type = "screen_off"
|
||||
elif possibility in self.potential_blickets:
|
||||
q_type = "potential"
|
||||
else:
|
||||
assert possibility in self.set_blickets
|
||||
q_type = "indirect"
|
||||
if possibility in self.blickets:
|
||||
label = 2
|
||||
elif possibility in self.potential_blickets:
|
||||
label = 1
|
||||
elif possibility in self.non_blickets:
|
||||
label = 0
|
||||
else:
|
||||
assert possibility in self.set_blickets
|
||||
label = 2
|
||||
intervention_view = BlicketView("no")
|
||||
intervention_view.add_objects(off_view.objects)
|
||||
if type(possibility) == tuple:
|
||||
possibility = list(possibility)
|
||||
else:
|
||||
possibility = [possibility]
|
||||
intervention_view.add_objects(possibility)
|
||||
questions.append((intervention_view, label, q_type, off_view_ref))
|
||||
|
||||
if self.shuffle:
|
||||
self.rng.shuffle(questions)
|
||||
|
||||
return questions
|
||||
|
||||
|
||||
class OnOffOff(BlicketQuestion):
|
||||
def __init__(
|
||||
self,
|
||||
min_potential_blickets,
|
||||
max_potential_blickets,
|
||||
min_non_blickets,
|
||||
max_non_blickets,
|
||||
config_size,
|
||||
shuffle,
|
||||
rng,
|
||||
):
|
||||
super(OnOffOff, self).__init__(
|
||||
min_potential_blickets,
|
||||
max_potential_blickets,
|
||||
min_non_blickets,
|
||||
max_non_blickets,
|
||||
config_size,
|
||||
shuffle,
|
||||
rng,
|
||||
)
|
||||
|
||||
def get_evidence_views(self):
|
||||
# the first non blicket used in habituation and hence the skip
|
||||
first_off, second_off = self.union_sample(self.non_blickets[1:])
|
||||
|
||||
first_off_view = BlicketView("off")
|
||||
first_off_view.add_objects(first_off)
|
||||
|
||||
second_off_view = BlicketView("off")
|
||||
second_off_view.add_objects(second_off)
|
||||
|
||||
on_view = BlicketView("on")
|
||||
|
||||
on_list = self.potential_blickets[:]
|
||||
|
||||
if len(on_list) == 1:
|
||||
self.add_blicket(on_list[0])
|
||||
else:
|
||||
set_blicket = on_list[:]
|
||||
set_blicket.sort()
|
||||
set_blicket = tuple(set_blicket)
|
||||
self.add_set_blicket(set_blicket)
|
||||
|
||||
on_view.add_objects(on_list)
|
||||
|
||||
self.add_noise(on_view)
|
||||
|
||||
evidence_views = [on_view, first_off_view, second_off_view]
|
||||
|
||||
# bookkeeping for candidate choices
|
||||
off_union_set = set(self.non_blickets[1:])
|
||||
on_set = set(on_view.objects)
|
||||
on_diff_set = on_set.difference(off_union_set)
|
||||
direct_set = list(off_union_set.difference(on_set))
|
||||
screen_off_set = list(off_union_set.intersection(on_set))
|
||||
|
||||
self.direct.extend(direct_set)
|
||||
self.screen_off.extend(screen_off_set)
|
||||
|
||||
if len(on_set) == 1:
|
||||
self.direct.extend(list(on_set))
|
||||
else:
|
||||
if len(on_diff_set) == 1:
|
||||
self.indirect.extend(list(on_diff_set))
|
||||
return evidence_views
|
||||
|
||||
|
||||
class OnOnOff(BlicketQuestion):
|
||||
def __init__(
|
||||
self,
|
||||
min_potential_blickets,
|
||||
max_potential_blickets,
|
||||
min_non_blickets,
|
||||
max_non_blickets,
|
||||
config_size,
|
||||
shuffle,
|
||||
rng,
|
||||
):
|
||||
super(OnOnOff, self).__init__(
|
||||
min_potential_blickets,
|
||||
max_potential_blickets,
|
||||
min_non_blickets,
|
||||
max_non_blickets,
|
||||
config_size,
|
||||
shuffle,
|
||||
rng,
|
||||
)
|
||||
|
||||
def get_evidence_views(self):
|
||||
# the first non blicket used in habituation and hence the skip
|
||||
off_list = self.non_blickets[1:]
|
||||
|
||||
off_view = BlicketView("off")
|
||||
off_view.add_objects(off_list)
|
||||
|
||||
first_on, second_on = self.union_sample(self.potential_blickets)
|
||||
|
||||
first_on_view = BlicketView("on")
|
||||
first_on_view.add_objects(first_on)
|
||||
|
||||
second_on_view = BlicketView("on")
|
||||
second_on_view.add_objects(second_on)
|
||||
|
||||
for on in [first_on, second_on]:
|
||||
if len(on) == 1:
|
||||
if on[0] not in self.blickets:
|
||||
self.add_blicket(on[0])
|
||||
for on in [first_on, second_on]:
|
||||
if len(set(on).intersection(self.blickets)) == 0:
|
||||
set_blicket = on[:]
|
||||
set_blicket.sort()
|
||||
set_blicket = tuple(set_blicket)
|
||||
if set_blicket not in self.set_blickets:
|
||||
self.add_set_blicket(set_blicket)
|
||||
|
||||
self.add_noise(first_on_view)
|
||||
self.add_noise(second_on_view)
|
||||
|
||||
evidence_views = [first_on_view, second_on_view, off_view]
|
||||
|
||||
# bookkeeping for candidate choices
|
||||
on_union_set = set(first_on_view.objects).union(second_on_view.objects)
|
||||
off_set = set(off_view.objects)
|
||||
|
||||
direct_set = list(off_set.difference(on_union_set))
|
||||
screen_off_set = list(off_set.intersection(on_union_set))
|
||||
self.direct.extend(direct_set)
|
||||
self.screen_off.extend(screen_off_set)
|
||||
|
||||
for on_view in [first_on_view, second_on_view]:
|
||||
if len(on_view.objects) == 1:
|
||||
obj = on_view.objects[0]
|
||||
if obj not in self.direct:
|
||||
self.direct.append(obj)
|
||||
for on_view in [first_on_view, second_on_view]:
|
||||
diff_set = set(on_view.objects).difference(off_set)
|
||||
if len(on_view.objects) > 1 and len(diff_set) == 1:
|
||||
obj = diff_set.pop()
|
||||
if obj not in self.indirect and obj not in self.direct:
|
||||
self.indirect.append(obj)
|
||||
|
||||
return evidence_views
|
||||
|
||||
|
||||
def serialize(questions):
|
||||
question_list = []
|
||||
for question in questions:
|
||||
view_list = []
|
||||
for i in range(6):
|
||||
json_dict = {}
|
||||
json_dict["light_state"] = question[i].light_state
|
||||
json_dict["objects"] = question[i].objects
|
||||
view_list.append(json_dict)
|
||||
for i in range(6, 8):
|
||||
json_dict = {}
|
||||
json_dict["light_state"] = question[i][0].light_state
|
||||
json_dict["objects"] = question[i][0].objects
|
||||
json_dict["label"] = question[i][1]
|
||||
json_dict["type"] = question[i][2]
|
||||
view_list.append(json_dict)
|
||||
for i in range(8, 10):
|
||||
json_dict = {}
|
||||
json_dict["light_state"] = question[i][0].light_state
|
||||
json_dict["objects"] = question[i][0].objects
|
||||
json_dict["label"] = question[i][1]
|
||||
json_dict["type"] = question[i][2]
|
||||
json_dict["ref"] = question[i][3]
|
||||
view_list.append(json_dict)
|
||||
question_list.append(view_list)
|
||||
return question_list
|
||||
|
||||
|
||||
def config_control(size, train, config_size, regime, rng):
    """Generate ``size`` blicket questions with a fixed machine config size.

    Half of the questions come from an OnOffOff machine and half from an
    OnOnOff machine; the combined list is shuffled before returning.

    Args:
        size: total number of questions to generate (``size // 2`` per machine type).
        train: whether to generate training-regime cause questions.
        config_size: number of object configurations available to each machine.
        regime: question-generation regime forwarded to the machines.
        rng: random number generator used for generation and shuffling.

    Returns:
        A shuffled list of questions, each ``context_views + cause_questions +
        intervention_questions``.
    """

    def build_question(blicket_machine):
        # Shared per-machine pipeline: views, sanity checks, then questions.
        # (Previously duplicated verbatim for both machine types.)
        context_views = blicket_machine.get_views()
        blicket_machine.sanity_check()
        blicket_machine.check_blickets(context_views)

        cause_questions = blicket_machine.generate_cause_questions(train=train, regime=regime)
        for cause_question in cause_questions:
            blicket_machine.check_labels(cause_question[0], cause_question[1])
        intervention_questions = blicket_machine.generate_intervention_questions(context_views)
        for intervention_question in intervention_questions:
            blicket_machine.check_labels(intervention_question[0], intervention_question[1])
        return context_views + cause_questions + intervention_questions

    questions = []
    for _ in range(size // 2):
        blicket_machine = OnOffOff(
            ONOFFOFF_MIN_POTENTIAL, ONOFFOFF_MAX_POTENTIAL, ONOFFOFF_MIN_NON, ONOFFOFF_MAX_NON, config_size, True, rng
        )
        questions.append(build_question(blicket_machine))
    for _ in range(size // 2):
        blicket_machine = OnOnOff(
            ONONOFF_MIN_POTENTIAL, ONONOFF_MAX_POTENTIAL, ONONOFF_MIN_NON, ONONOFF_MAX_NON, config_size, True, rng
        )
        questions.append(build_question(blicket_machine))
    rng.shuffle(questions)
    return questions
|
||||
|
||||
|
||||
def dist_control(size, train, regime, rng):
    """Generate questions where train and test use different machine types.

    Training samples come from an OnOffOff machine, evaluation samples from
    an OnOnOff machine, both drawing from the full configuration pool.

    Args:
        size: number of questions to generate.
        train: True selects the OnOffOff machine, False the OnOnOff machine.
        regime: question-generation regime forwarded to the machine.
        rng: random number generator used for generation and shuffling.

    Returns:
        A shuffled list of questions, each ``context_views + cause_questions +
        intervention_questions``.
    """
    questions = []
    for _ in range(size):
        if train:
            machine = OnOffOff(
                ONOFFOFF_MIN_POTENTIAL,
                ONOFFOFF_MAX_POTENTIAL,
                ONOFFOFF_MIN_NON,
                ONOFFOFF_MAX_NON,
                ALL_CONFIG_SIZE,
                True,
                rng,
            )
        else:
            machine = OnOnOff(
                ONONOFF_MIN_POTENTIAL,
                ONONOFF_MAX_POTENTIAL,
                ONONOFF_MIN_NON,
                ONONOFF_MAX_NON,
                ALL_CONFIG_SIZE,
                True,
                rng,
            )

        # Build the context, then validate machine state against it.
        context_views = machine.get_views()
        machine.sanity_check()
        machine.check_blickets(context_views)

        # Cause questions: (view, label, ...) — verify each label before use.
        cause_questions = machine.generate_cause_questions(train=train, regime=regime)
        for question in cause_questions:
            machine.check_labels(question[0], question[1])

        # Intervention questions follow the same validation pattern.
        intervention_questions = machine.generate_intervention_questions(context_views)
        for question in intervention_questions:
            machine.check_labels(question[0], question[1])

        questions.append(context_views + cause_questions + intervention_questions)
    rng.shuffle(questions)
    return questions
|
||||
|
||||
|
||||
# Text translation functions

# Index-to-text mapping for light-state labels: 0 -> "off",
# 1 -> "undetermined", 2 -> "on".
LIGHT_DICT_TEXT = ["off", "undetermined", "on"]

# Mapping from object id (as string) to its attribute dict, loaded once at
# import time from the packaged JSON data file.
OBJECT_DICT_FILE_PATH = get_data_file_path("acre_objects.json")
with open(OBJECT_DICT_FILE_PATH, "r") as object_dict_file:
    OBJECT_DICT_TEXT = json.load(object_dict_file)
|
||||
|
||||
|
||||
def get_example_text(objects, light_state):
    """Render a context view as textual object attributes plus its light state.

    Args:
        objects: iterable of object ids (keys into ``OBJECT_DICT_TEXT``).
        light_state: the observed light state for this view.

    Returns:
        ``{"input": [[color, material, shape], ...], "output": light_state}``
    """
    input_examples = []
    # NOTE: loop variable renamed from `object`, which shadowed the builtin.
    for obj_id in objects:
        attrs = OBJECT_DICT_TEXT[str(obj_id)]
        # Each object is rendered as a [color, material, shape] triple.
        input_examples.append([attrs["color"], attrs["material"], attrs["shape"]])
    return {"input": input_examples, "output": light_state}
|
||||
|
||||
|
||||
# Label has 0, 1, 2 integer which are indexed into the array ["off", "undetermined", "on"]
|
||||
def get_trial_text(objects, label):
    """Render a test trial as textual object attributes plus its labelled light state.

    Args:
        objects: iterable of object ids (keys into ``OBJECT_DICT_TEXT``).
        label: integer 0/1/2 indexing ``LIGHT_DICT_TEXT``
            ("off"/"undetermined"/"on").

    Returns:
        ``{"input": [[color, material, shape], ...], "output": <light-state text>}``
    """
    input_examples = []
    # NOTE: loop variable renamed from `object`, which shadowed the builtin;
    # dead commented-out code removed.
    for obj_id in objects:
        attrs = OBJECT_DICT_TEXT[str(obj_id)]
        # Each object is rendered as a [color, material, shape] triple.
        input_examples.append([attrs["color"], attrs["material"], attrs["shape"]])
    return {"input": input_examples, "output": LIGHT_DICT_TEXT[label]}
|
||||
|
||||
|
||||
# Select translation functions (text back-end; rebind these to change the
# output format without touching callers).
get_example = get_example_text
get_trial = get_trial_text


def get_answer_list():
    """Return the list of possible light-state answers."""
    # PEP 8 (E731): a def is preferred over assigning a lambda to a name.
    return LIGHT_DICT_TEXT
|
||||
|
||||
|
||||
def final_parse(serialized_questions) -> List[Dict[str, Any]]:
    """Split serialized questions into (examples, question) records.

    Each serialized sample mixes context examples (``light_state != "no"``)
    with test trials (``light_state == "no"``).  Every trial produces one
    output record paired with that sample's examples.

    Returns:
        A list of ``{"examples": [...], "question": {...}}`` dicts, one per
        test trial across all samples.
    """
    output_data = []
    for sample in serialized_questions:  # 6 examples + 4 tests per sample
        examples = []
        for entry in sample:
            if entry["light_state"] == "no":
                # Test trial: emit a record referencing this sample's examples.
                trial = get_trial(entry["objects"], entry["label"])
                output_data.append({"examples": examples, "question": trial})
            else:
                # Context example with an observed light state.
                examples.append(get_example(entry["objects"], entry["light_state"]))
    return output_data
|
||||
15
reasoning_gym/induction/acre/const.py
Normal file
15
reasoning_gym/induction/acre/const.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
# -*- coding: utf-8 -*-

# Sampling bounds for the OnOffOff blicket machine (passed positionally as
# min/max potential and min/max non counts — presumably the number of
# potential-blicket and non-blicket objects drawn; verify against the
# OnOffOff constructor).
ONOFFOFF_MIN_POTENTIAL = 2
ONOFFOFF_MAX_POTENTIAL = 3
ONOFFOFF_MIN_NON = 3
ONOFFOFF_MAX_NON = 5

# Sampling bounds for the OnOnOff blicket machine (same roles as above).
ONONOFF_MIN_POTENTIAL = 2
ONONOFF_MAX_POTENTIAL = 4
ONONOFF_MIN_NON = 3
ONONOFF_MAX_NON = 4

# Object-configuration pool sizes passed as `config_size` to the machines.
ATTR_CONFIG_SIZE = 24
ALL_CONFIG_SIZE = 48
|
||||
65
tests/test_acre.py
Normal file
65
tests/test_acre.py
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
import pytest
|
||||
|
||||
from reasoning_gym.induction.acre.acre import ACREDataset, ACREDatasetConfig
|
||||
|
||||
|
||||
def test_acre_config_validation():
    """A negative size must be rejected by config validation."""
    invalid_config = ACREDatasetConfig(size=-1)
    with pytest.raises(AssertionError):
        invalid_config.validate()
|
||||
|
||||
|
||||
def test_acre_dataset_deterministic():
    """The same seed must reproduce the exact same items."""
    cfg = ACREDatasetConfig(seed=42, size=10)
    first = ACREDataset(cfg)
    second = ACREDataset(cfg)

    assert all(first[i] == second[i] for i in range(len(first)))
|
||||
|
||||
|
||||
def test_acre_items():
    """Every generated item is a dict with string question/answer fields."""
    dataset = ACREDataset(ACREDatasetConfig(size=50, seed=42))

    for idx in range(len(dataset)):
        entry = dataset[idx]
        assert isinstance(entry, dict)
        for key in ("question", "answer"):
            assert key in entry
            assert isinstance(entry[key], str)
|
||||
|
||||
|
||||
def test_acre_iteration():
    """Iteration must yield exactly `size` items, repeatably."""
    config = ACREDatasetConfig(size=10, seed=42)  # small size for testing
    dataset = ACREDataset(config)

    # Manual iteration
    collected = []
    for entry in dataset:
        collected.append(entry)
    assert len(collected) == config.size, "Iterator should yield exactly size items"

    # list() conversion
    assert len(list(dataset)) == config.size, "Iterator should yield exactly size items"

    # Repeated passes must produce identical items
    first_pass = list(dataset)
    second_pass = list(dataset)
    assert first_pass == second_pass, "Multiple iterations should yield same items"
|
||||
|
||||
|
||||
def test_acre_questions_generator():
|
||||
"""Test question generator loading and access"""
|
||||
config = ACREDatasetConfig(size=10, seed=42)
|
||||
dataset = ACREDataset(config)
|
||||
|
||||
# Test properties of questions
|
||||
assert isinstance(dataset.questions, list)
|
||||
assert len(dataset.questions) > 0
|
||||
Loading…
Add table
Add a link
Reference in a new issue