reasoning-gym/reasoning_gym/induction/acre/blicket.py
Adefioye c8c3930797
Add ACRE(Abstract Causal REasoning Beyond Covariation) python generators (#199)
* Add acre python generators
* acre: improved prompt & formatting of examples, support arbitrary sizes

---------

Co-authored-by: Andreas Koepf <andreas.koepf@provisio.com>
2025-03-10 00:09:54 +01:00

687 lines
25 KiB
Python

# -*- coding: utf-8 -*-
import json
from typing import Any, Dict, List
from reasoning_gym.data import get_data_file_path
from .const import (
ALL_CONFIG_SIZE,
ONOFFOFF_MAX_NON,
ONOFFOFF_MAX_POTENTIAL,
ONOFFOFF_MIN_NON,
ONOFFOFF_MIN_POTENTIAL,
ONONOFF_MAX_NON,
ONONOFF_MAX_POTENTIAL,
ONONOFF_MIN_NON,
ONONOFF_MIN_POTENTIAL,
)
class BlicketView:
    """A single panel: the objects placed on the machine plus the light state.

    ``light_state`` is ``"no"`` when no light is shown, ``"off"`` when the
    light is off and ``"on"`` when the light is on.
    """

    def __init__(self, light_state="no"):
        self.light_state = light_state
        self.objects = []

    def add_objects(self, objects):
        """Append every item of ``objects`` to this view, preserving order."""
        self.objects.extend(objects)

    def remove_objects(self, objects):
        """Remove one occurrence of each item of ``objects`` from this view.

        ``list.remove`` raises ``ValueError`` if an item is not present.
        """
        for item in objects:
            self.objects.remove(item)

    def __repr__(self):
        return "BlicketView(objects={}, light_state={})".format(self.objects, self.light_state)
class BlicketQuestion(object):
    """Base generator for a single blicket-machine problem.

    Objects are integer ids sampled from ``range(config_size)`` and split into
    ``potential_blickets`` and ``non_blickets``.  While views are generated,
    ids are moved from ``potential_blickets`` into ``blickets`` and tuples of
    ids are recorded in ``set_blickets``.  The ``direct``, ``indirect`` and
    ``screen_off`` lists are bookkeeping used to pick and tag question
    candidates.

    Subclasses supply the evidence views by overriding
    :meth:`get_evidence_views`.
    """

    def __init__(
        self,
        min_potential_blickets,
        max_potential_blickets,
        min_non_blickets,
        max_non_blickets,
        config_size,
        shuffle,
        rng,
    ):
        """Sample the object ids used by this problem.

        Args:
            min_potential_blickets: lower bound passed to ``rng.randint`` for
                the number of potential blickets.
            max_potential_blickets: matching upper bound.
            min_non_blickets: lower bound passed to ``rng.randint`` for the
                number of non-blickets.
            max_non_blickets: matching upper bound.
            config_size: ids are sampled without replacement from
                ``range(config_size)``.
            shuffle: when truthy, views and questions are shuffled with ``rng``.
            rng: random source providing ``randint``, ``sample`` and
                ``shuffle`` (presumably a ``random.Random``).
        """
        self.min_potential_blickets = min_potential_blickets
        self.max_potential_blickets = max_potential_blickets
        self.min_non_blickets = min_non_blickets
        self.max_non_blickets = max_non_blickets
        self.config_size = config_size
        self.shuffle = shuffle
        self.rng = rng

        potential_blicket_num = self.rng.randint(self.min_potential_blickets, self.max_potential_blickets)
        non_blicket_num = self.rng.randint(self.min_non_blickets, self.max_non_blickets)
        samples = self.rng.sample(list(range(self.config_size)), k=potential_blicket_num + non_blicket_num)

        self.blickets = []
        self.set_blickets = []
        self.potential_blickets = samples[:potential_blicket_num]
        self.non_blickets = samples[potential_blicket_num:]

        self.direct = []
        self.indirect = []
        self.screen_off = []

    def get_habituation_views(self):
        """Build the three introductory views.

        The views are ``[blicket] -> on``, ``[non-blicket] -> off`` and
        ``[blicket, non-blicket] -> on``, using the first potential blicket
        and the first non-blicket.

        Side effects: the first potential blicket is promoted to a blicket and
        recorded in ``direct``; the first non-blicket is recorded in
        ``screen_off``.
        """
        blicket_obj = self.potential_blickets[0]
        self.add_blicket(blicket_obj)
        non_blicket_obj = self.non_blickets[0]

        view_with_blicket = BlicketView("on")
        view_with_blicket.add_objects([blicket_obj])
        view_with_non_blicket = BlicketView("off")
        view_with_non_blicket.add_objects([non_blicket_obj])
        view_with_both = BlicketView("on")
        view_with_both.add_objects([blicket_obj, non_blicket_obj])
        habituation_views = [view_with_blicket, view_with_non_blicket, view_with_both]

        # bookkeeping for candidate choices
        self.direct.append(blicket_obj)
        self.screen_off.append(non_blicket_obj)
        return habituation_views

    def get_evidence_views(self):
        """Return the evidence views; must be overridden by subclasses."""
        raise NotImplementedError("The parent class should not be used.")

    def get_views(self):
        """Return habituation views followed by evidence views.

        When ``self.shuffle`` is truthy each group is shuffled on its own, so
        habituation views always precede evidence views.
        """
        habituation_views = self.get_habituation_views()
        evidence_views = self.get_evidence_views()
        if self.shuffle:
            self.rng.shuffle(habituation_views)
            self.rng.shuffle(evidence_views)
        return habituation_views + evidence_views

    def sanity_check(self):
        """Assert that the bookkeeping lists are consistent.

        ``direct``, ``indirect`` and ``screen_off`` must be duplicate-free,
        pairwise disjoint and together equal to ``blickets + non_blickets``.
        """
        # no duplicates
        assert len(self.direct) == len(set(self.direct))
        assert len(self.indirect) == len(set(self.indirect))
        assert len(self.screen_off) == len(set(self.screen_off))
        # no intersection
        assert len(set(self.direct).intersection(self.indirect)) == 0
        assert len(set(self.direct).intersection(self.screen_off)) == 0
        assert len(set(self.indirect).intersection(self.screen_off)) == 0
        # completeness
        assert set(self.direct + self.indirect + self.screen_off) == set(self.blickets + self.non_blickets)

    def check_blickets(self, views):
        """Re-derive every category from ``views`` and assert it matches state.

        The blickets, non-blickets, potential blickets, the three bookkeeping
        categories and the minimal set-blickets are recomputed from the views
        alone and compared against the attributes of this instance.

        Raises:
            AssertionError: if any recomputed category differs.
        """
        blickets = set()
        non_blickets = set()
        potential_blickets = set()
        direct = set()
        indirect = set()
        screen_off = set()

        on_views = [view for view in views if view.light_state == "on"]
        off_views = [view for view in views if view.light_state == "off"]
        assert len(on_views) + len(off_views) == len(views)

        # Anything that appears in an "off" view is a non-blicket.
        for off_view in off_views:
            non_blickets.update(off_view.objects)
        # A lone object in an "on" view, or the only object left once the
        # non-blickets are removed, is a blicket.
        for on_view in on_views:
            if len(on_view.objects) == 1:
                blickets.update(on_view.objects)
            else:
                diff_set = set(on_view.objects).difference(non_blickets)
                if len(diff_set) == 1:
                    blickets.update(diff_set)
        all_objects = set()
        for view in views:
            all_objects.update(view.objects)
        potential_blickets.update(all_objects.difference(non_blickets).difference(blickets))

        assert blickets == set(self.blickets)
        assert non_blickets == set(self.non_blickets)
        assert potential_blickets == set(self.potential_blickets)

        for on_view in on_views:
            if len(on_view.objects) == 1:
                direct.update(on_view.objects)
        on_view_objects = set()
        for on_view in on_views:
            on_view_objects.update(on_view.objects)
        off_view_objects = set()
        for off_view in off_views:
            off_view_objects.update(off_view.objects)
        # Objects seen only in "off" views are direct; objects seen in both
        # "off" and "on" views are screened off.
        direct.update(off_view_objects.difference(on_view_objects))
        screen_off.update(off_view_objects.intersection(on_view_objects))
        for on_view in on_views:
            if len(on_view.objects) > 1:
                diff_set = set(on_view.objects).difference(non_blickets)
                if len(diff_set) == 1 and not diff_set.issubset(direct):
                    indirect.update(diff_set)

        assert direct == set(self.direct)
        assert indirect == set(self.indirect)
        assert screen_off == set(self.screen_off)

        set_blickets = set()
        for on_view in on_views:
            on_diff_set = set(on_view.objects).difference(non_blickets)
            if len(on_diff_set.intersection(blickets)) == 0:
                potential_set_blicket = list(on_diff_set)
                potential_set_blicket.sort()
                set_blickets.add(tuple(potential_set_blicket))
        # Keep only minimal sets: drop every tuple that strictly contains another.
        remove_set = set()
        for elem in set_blickets:
            elem_set = set(elem)
            for another_elem in set_blickets:
                if elem_set < set(another_elem):
                    remove_set.add(another_elem)
        set_blickets.difference_update(remove_set)

        assert set_blickets == set(self.set_blickets), "set_blickets:{}, self.set_blickets:{}".format(
            set_blickets, self.set_blickets
        )

    def check_labels(self, view, label):
        """Assert that ``label`` is the expected answer for ``view``.

        Expected label: ``0`` when the view holds only non-blickets, ``2``
        when it holds a blicket or a complete set-blicket, ``1`` otherwise.
        """
        diff_set = set(view.objects).difference(self.non_blickets)
        if len(diff_set) == 0:
            assert label == 0
        else:
            blicket_inter = diff_set.intersection(self.blickets)
            if len(blicket_inter) > 0:
                assert label == 2
            else:
                if self.has_set_blicket(view):
                    assert label == 2
                else:
                    assert label == 1

    def has_set_blicket(self, view):
        """Return True when ``view`` contains every member of some set-blicket."""
        return any(set(set_blicket).issubset(view.objects) for set_blicket in self.set_blickets)

    def union_sample(self, union, must_have_one=False):
        """Split ``union`` into two lists that together cover it.

        The first list is a random non-empty sample of ``union`` (exactly one
        element when ``must_have_one`` is true).  The second list holds every
        element missing from the first plus a random selection of elements of
        the first; at least one such element is added when nothing is missing,
        so the second list is never empty.

        Returns:
            ``(first_set, second_set)`` as two lists.
        """
        if must_have_one:
            first_set = self.rng.sample(union, k=1)
        else:
            first_num = self.rng.randint(1, len(union))
            first_set = self.rng.sample(union, k=first_num)
        second_set = list(set(union).difference(first_set))
        if len(second_set) > 0:
            additional_num_min = 0
        else:
            additional_num_min = 1
        additional_num = self.rng.randint(additional_num_min, len(first_set))
        additional_samples = self.rng.sample(first_set, k=additional_num)
        second_set += additional_samples
        return first_set, second_set

    def add_noise(self, view):
        """Add a random number (possibly zero) of non-blickets to ``view``."""
        # the first non blicket used in habituation and hence the skip
        noise_num = self.rng.randint(0, len(self.non_blickets) - 1)
        noise = self.rng.sample(self.non_blickets[1:], k=noise_num)
        view.add_objects(noise)

    def add_blicket(self, obj):
        """Promote ``obj`` from potential blicket to blicket."""
        self.potential_blickets.remove(obj)
        self.blickets.append(obj)

    def add_set_blicket(self, set_blicket):
        """Record ``set_blicket`` while keeping only minimal sets.

        Existing entries that strictly contain the new tuple are dropped.  The
        new tuple itself is dropped when an existing entry is a strict subset
        of it.
        """
        new_members = set(set_blicket)
        to_remove_set = []
        new_is_redundant = False
        for already_set_blicket in self.set_blickets:
            already_members = set(already_set_blicket)
            if new_members < already_members:
                to_remove_set.append(already_set_blicket)
            if already_members < new_members:
                new_is_redundant = True
        self.set_blickets.append(set_blicket)
        # Queue the new tuple for removal once only: it was appended once, so
        # queueing it per matching subset made the second ``remove`` raise
        # ValueError.
        if new_is_redundant:
            to_remove_set.append(set_blicket)
        for elem in to_remove_set:
            self.set_blickets.remove(elem)

    def generate_cause_questions(self, train, regime="IID"):
        """Create the single-object ("cause") questions.

        Two objects are asked about in total: one potential blicket when any
        exists, the rest drawn from ``direct``, ``indirect`` and
        ``screen_off``.

        Args:
            train: selects the train variant of the ``"Sys"`` regime.
            regime: ``"IID"``, ``"Comp"`` or ``"Sys"``; controls how many
                screened-off objects are eligible.

        Returns:
            A list of ``(view, label, q_type)`` tuples where ``label`` is 0
            for light off, 1 for unknown and 2 for light on.

        Raises:
            ValueError: if ``regime`` is not one of the supported names.
        """
        # Previously an unknown regime left ``constraint_upper`` unbound and
        # surfaced later as an UnboundLocalError.
        if regime not in ("Comp", "Sys", "IID"):
            raise ValueError("unknown regime {!r}; expected 'Comp', 'Sys' or 'IID'".format(regime))

        def fixed_sum_sample(lower, upper, fixed_sum):
            """Rejection-sample ints within [lower, upper] adding up to fixed_sum."""
            values = [0] * len(lower)
            # Without this guarantee the loop below could never terminate.
            assert sum(upper) >= fixed_sum
            while True:
                residual = fixed_sum
                for i in range(len(values) - 1):
                    values[i] = self.rng.randint(lower[i], min(upper[i], residual))
                    residual = residual - values[i]
                if residual >= lower[-1] and residual <= upper[-1]:
                    values[-1] = residual
                    break
            assert sum(values) == fixed_sum
            return values

        # direct_sample, indirect_sample, screen_off_sample, potential_sample
        # these magic numbers and changes are to adjust dataset statistics
        constraint_lower = [0] * 3
        if regime == "Comp":
            screen_off_reserve = 2
        elif regime == "Sys":
            # train split keeps more screened-off objects out of the pool
            screen_off_reserve = 3 if train else 1
        else:  # "IID"
            screen_off_reserve = 1
        constraint_upper = [
            len(self.direct),
            len(self.indirect),
            max(len(self.screen_off) - screen_off_reserve, 0),
        ]

        potential_sample_num = min(len(self.potential_blickets), 1)
        direct_sample_num, indirect_sample_num, screen_off_sample_num = fixed_sum_sample(
            constraint_lower, constraint_upper, 2 - potential_sample_num
        )

        direct_samples = self.rng.sample(self.direct, k=direct_sample_num)
        indirect_samples = self.rng.sample(self.indirect, k=indirect_sample_num)
        screen_off_samples = self.rng.sample(self.screen_off, k=screen_off_sample_num)
        potential_samples = self.rng.sample(self.potential_blickets, k=potential_sample_num)

        questions = []
        # label: 0 for light off, 1 for unknown, 2 for light up
        for direct_sample in direct_samples:
            assert direct_sample in self.blickets or direct_sample in self.non_blickets
            if direct_sample in self.blickets:
                label = 2
            else:
                label = 0
            cause_view = BlicketView("no")
            cause_view.add_objects([direct_sample])
            questions.append((cause_view, label, "direct"))
        for indirect_sample in indirect_samples:
            assert indirect_sample in self.blickets
            label = 2
            cause_view = BlicketView("no")
            cause_view.add_objects([indirect_sample])
            questions.append((cause_view, label, "indirect"))
        for screen_off_sample in screen_off_samples:
            assert screen_off_sample in self.non_blickets
            label = 0
            cause_view = BlicketView("no")
            cause_view.add_objects([screen_off_sample])
            questions.append((cause_view, label, "screen_off"))
        for potential_sample in potential_samples:
            assert potential_sample in self.potential_blickets
            label = 1
            cause_view = BlicketView("no")
            cause_view.add_objects([potential_sample])
            questions.append((cause_view, label, "potential"))

        if self.shuffle:
            self.rng.shuffle(questions)
        return questions

    def generate_intervention_questions(self, views):
        """Create questions that add objects to one of the "off" views.

        A random "off" view is chosen, then two candidates (a single object or
        a whole set-blicket) not already in that view are sampled and added to
        a copy of it.

        Args:
            views: the context views returned by :meth:`get_views`.

        Returns:
            A list of ``(view, label, q_type, off_view_ref)`` tuples, where
            ``off_view_ref`` is the index of the chosen "off" view in ``views``.
        """
        off_views = [view for view in views if view.light_state == "off"]
        off_view = self.rng.sample(off_views, k=1)[0]
        off_view_ref = views.index(off_view)

        questions = []
        all_possibilities = (
            list(set(self.blickets + self.potential_blickets + self.non_blickets).difference(off_view.objects))
            + self.set_blickets
        )
        # adjust weight during sample for better statistics
        all_possibilities += list(set(self.potential_blickets).difference(off_view.objects))
        possibilities = self.rng.sample(all_possibilities, k=2)
        for possibility in possibilities:
            if possibility in self.direct:
                q_type = "direct"
            elif possibility in self.indirect:
                q_type = "indirect"
            elif possibility in self.screen_off:
                q_type = "screen_off"
            elif possibility in self.potential_blickets:
                q_type = "potential"
            else:
                assert possibility in self.set_blickets
                q_type = "indirect"

            if possibility in self.blickets:
                label = 2
            elif possibility in self.potential_blickets:
                label = 1
            elif possibility in self.non_blickets:
                label = 0
            else:
                assert possibility in self.set_blickets
                label = 2

            intervention_view = BlicketView("no")
            intervention_view.add_objects(off_view.objects)
            # A set-blicket is stored as a tuple of ids; a single object is a bare id.
            if isinstance(possibility, tuple):
                possibility = list(possibility)
            else:
                possibility = [possibility]
            intervention_view.add_objects(possibility)
            questions.append((intervention_view, label, q_type, off_view_ref))

        if self.shuffle:
            self.rng.shuffle(questions)
        return questions
class OnOffOff(BlicketQuestion):
    """Problem whose evidence is one "on" view followed by two "off" views."""

    def __init__(
        self,
        min_potential_blickets,
        max_potential_blickets,
        min_non_blickets,
        max_non_blickets,
        config_size,
        shuffle,
        rng,
    ):
        super().__init__(
            min_potential_blickets,
            max_potential_blickets,
            min_non_blickets,
            max_non_blickets,
            config_size,
            shuffle,
            rng,
        )

    def get_evidence_views(self):
        """Build ``[on_view, first_off_view, second_off_view]``.

        The two "off" views together cover every non-blicket except the first
        one.  The "on" view holds all remaining potential blickets plus random
        non-blicket noise.  A single remaining potential blicket is promoted
        to a blicket; otherwise the whole group is recorded as a set-blicket.
        """
        # The first non-blicket was used by the habituation views, so skip it.
        spare_non_blickets = self.non_blickets[1:]
        off_members_a, off_members_b = self.union_sample(spare_non_blickets)

        off_panels = []
        for members in (off_members_a, off_members_b):
            panel = BlicketView("off")
            panel.add_objects(members)
            off_panels.append(panel)

        # Copy first: add_blicket() mutates self.potential_blickets.
        candidates = self.potential_blickets[:]
        if len(candidates) == 1:
            self.add_blicket(candidates[0])
        else:
            self.add_set_blicket(tuple(sorted(candidates)))

        lit_panel = BlicketView("on")
        lit_panel.add_objects(candidates)
        self.add_noise(lit_panel)

        # Bookkeeping for candidate choices.
        off_pool = set(spare_non_blickets)
        lit_objects = set(lit_panel.objects)
        lit_only = lit_objects.difference(off_pool)
        self.direct.extend(off_pool.difference(lit_objects))
        self.screen_off.extend(off_pool.intersection(lit_objects))
        if len(lit_objects) == 1:
            self.direct.extend(lit_objects)
        elif len(lit_only) == 1:
            self.indirect.extend(lit_only)

        return [lit_panel] + off_panels
class OnOnOff(BlicketQuestion):
    """Problem whose evidence is two "on" views followed by one "off" view."""

    def __init__(
        self,
        min_potential_blickets,
        max_potential_blickets,
        min_non_blickets,
        max_non_blickets,
        config_size,
        shuffle,
        rng,
    ):
        super().__init__(
            min_potential_blickets,
            max_potential_blickets,
            min_non_blickets,
            max_non_blickets,
            config_size,
            shuffle,
            rng,
        )

    def get_evidence_views(self):
        """Build ``[first_on_view, second_on_view, off_view]``.

        The "off" view holds every non-blicket except the first one.  The two
        "on" views are seeded with lists from ``union_sample`` that together
        cover the current potential blickets, and each then receives random
        non-blicket noise.
        """
        # The first non-blicket was used by the habituation views, so skip it.
        dark_panel = BlicketView("off")
        dark_panel.add_objects(self.non_blickets[1:])

        seeds = list(self.union_sample(self.potential_blickets))
        lit_panels = []
        for seed in seeds:
            panel = BlicketView("on")
            panel.add_objects(seed)
            lit_panels.append(panel)

        # A seed with a single object identifies that object as a blicket.
        for seed in seeds:
            if len(seed) == 1 and seed[0] not in self.blickets:
                self.add_blicket(seed[0])
        # A seed without any known blicket lights the machine only as a group.
        for seed in seeds:
            if not set(seed).intersection(self.blickets):
                group = tuple(sorted(seed))
                if group not in self.set_blickets:
                    self.add_set_blicket(group)

        for panel in lit_panels:
            self.add_noise(panel)

        # Bookkeeping for candidate choices.
        lit_objects = set(lit_panels[0].objects).union(lit_panels[1].objects)
        dark_objects = set(dark_panel.objects)
        self.direct.extend(dark_objects.difference(lit_objects))
        self.screen_off.extend(dark_objects.intersection(lit_objects))

        for panel in lit_panels:
            if len(panel.objects) != 1:
                continue
            lone = panel.objects[0]
            if lone not in self.direct:
                self.direct.append(lone)

        for panel in lit_panels:
            unexplained = set(panel.objects).difference(dark_objects)
            if len(panel.objects) > 1 and len(unexplained) == 1:
                candidate = next(iter(unexplained))
                if not (candidate in self.indirect or candidate in self.direct):
                    self.indirect.append(candidate)

        return lit_panels + [dark_panel]
def serialize(questions):
    """Convert generated samples into lists of plain dictionaries.

    Each sample is indexed at positions 0-9: positions 0-5 are context views,
    6-7 are ``(view, label, type)`` cause questions and 8-9 are
    ``(view, label, type, ref)`` intervention questions.  The ``objects``
    lists are stored by reference, not copied.
    """
    serialized = []
    for sample in questions:
        records = []
        # Context views carry only the light state and the objects.
        for slot in range(6):
            panel = sample[slot]
            records.append({"light_state": panel.light_state, "objects": panel.objects})
        # Cause questions add the label and the question type.
        for slot in range(6, 8):
            probe = sample[slot]
            records.append(
                {
                    "light_state": probe[0].light_state,
                    "objects": probe[0].objects,
                    "label": probe[1],
                    "type": probe[2],
                }
            )
        # Intervention questions also reference the view they modify.
        for slot in range(8, 10):
            probe = sample[slot]
            records.append(
                {
                    "light_state": probe[0].light_state,
                    "objects": probe[0].objects,
                    "label": probe[1],
                    "type": probe[2],
                    "ref": probe[3],
                }
            )
        serialized.append(records)
    return serialized
def config_control(size, train, config_size, regime, rng):
    """Generate samples split evenly between ``OnOffOff`` and ``OnOnOff``.

    ``size // 2`` samples of each kind are produced (so an odd ``size`` yields
    ``size - 1`` samples), validated, and finally shuffled with ``rng``.
    Each sample is the list of context views followed by the cause questions
    and the intervention questions.
    """

    def build_sample(machine):
        # Generate, self-check and assemble one sample from a fresh machine.
        context = machine.get_views()
        machine.sanity_check()
        machine.check_blickets(context)
        cause = machine.generate_cause_questions(train=train, regime=regime)
        for probe in cause:
            machine.check_labels(probe[0], probe[1])
        intervention = machine.generate_intervention_questions(context)
        for probe in intervention:
            machine.check_labels(probe[0], probe[1])
        return context + cause + intervention

    per_kind = size // 2
    questions = []
    for _ in range(per_kind):
        machine = OnOffOff(
            ONOFFOFF_MIN_POTENTIAL, ONOFFOFF_MAX_POTENTIAL, ONOFFOFF_MIN_NON, ONOFFOFF_MAX_NON, config_size, True, rng
        )
        questions.append(build_sample(machine))
    for _ in range(per_kind):
        machine = OnOnOff(
            ONONOFF_MIN_POTENTIAL, ONONOFF_MAX_POTENTIAL, ONONOFF_MIN_NON, ONONOFF_MAX_NON, config_size, True, rng
        )
        questions.append(build_sample(machine))
    rng.shuffle(questions)
    return questions
def dist_control(size, train, regime, rng):
    """Generate ``size`` samples of one kind, chosen by ``train``.

    ``OnOffOff`` machines are used when ``train`` is truthy and ``OnOnOff``
    machines otherwise, always with ``ALL_CONFIG_SIZE`` objects.  Every sample
    is validated and the resulting list is shuffled with ``rng``.
    """

    def new_machine():
        # Pick the machine type (and its matching bounds) for this split.
        if train:
            return OnOffOff(
                ONOFFOFF_MIN_POTENTIAL,
                ONOFFOFF_MAX_POTENTIAL,
                ONOFFOFF_MIN_NON,
                ONOFFOFF_MAX_NON,
                ALL_CONFIG_SIZE,
                True,
                rng,
            )
        return OnOnOff(
            ONONOFF_MIN_POTENTIAL,
            ONONOFF_MAX_POTENTIAL,
            ONONOFF_MIN_NON,
            ONONOFF_MAX_NON,
            ALL_CONFIG_SIZE,
            True,
            rng,
        )

    questions = []
    for _ in range(size):
        machine = new_machine()
        context = machine.get_views()
        machine.sanity_check()
        machine.check_blickets(context)
        cause = machine.generate_cause_questions(train=train, regime=regime)
        for probe in cause:
            machine.check_labels(probe[0], probe[1])
        intervention = machine.generate_intervention_questions(context)
        for probe in intervention:
            machine.check_labels(probe[0], probe[1])
        questions.append(context + cause + intervention)
    rng.shuffle(questions)
    return questions
# Text translation functions

# Answer vocabulary; question labels 0, 1 and 2 index into this list.
LIGHT_DICT_TEXT = ["off", "undetermined", "on"]

# Object descriptions (color / shape / material), keyed by the object id as a string.
OBJECT_DICT_FILE_PATH = get_data_file_path("acre_objects.json")
# Decode explicitly as UTF-8 so loading does not depend on the platform's locale encoding.
with open(OBJECT_DICT_FILE_PATH, "r", encoding="utf-8") as object_dict_file:
    OBJECT_DICT_TEXT = json.load(object_dict_file)
def get_example_text(objects, light_state):
    """Describe a context view as ``{"input": [...], "output": light_state}``.

    Every object id is looked up in ``OBJECT_DICT_TEXT`` (keyed by the id as a
    string) and rendered as ``[color, material, shape]``.
    """
    described = []
    for obj_id in objects:
        attrs = OBJECT_DICT_TEXT[str(obj_id)]
        color, shape, material = attrs["color"], attrs["shape"], attrs["material"]
        described.append([color, material, shape])
    return {"input": described, "output": light_state}
# `label` is an integer (0, 1 or 2) that indexes into LIGHT_DICT_TEXT, i.e. ["off", "undetermined", "on"].
def get_trial_text(objects, label):
    """Describe a question as ``{"input": [...], "output": <answer text>}``.

    Every object id is looked up in ``OBJECT_DICT_TEXT`` (keyed by the id as a
    string) and rendered as ``[color, material, shape]``.  The answer text is
    ``LIGHT_DICT_TEXT[label]``.
    """
    described = []
    for obj_id in objects:
        attrs = OBJECT_DICT_TEXT[str(obj_id)]
        color, shape, material = attrs["color"], attrs["shape"], attrs["material"]
        described.append([color, material, shape])
    answer = LIGHT_DICT_TEXT[label]
    return {"input": described, "output": answer}
# Select translation functions
get_example = get_example_text
get_trial = get_trial_text


def get_answer_list():
    """Return ``LIGHT_DICT_TEXT``, the list of possible answer strings."""
    # A named function instead of a lambda bound to a name (PEP 8), so the
    # callable has a meaningful ``__name__`` in tracebacks.
    return LIGHT_DICT_TEXT
def final_parse(serialized_questions) -> List[Dict[str, Any]]:
    """Turn serialized samples into one record per question.

    Entries whose ``light_state`` is not ``"no"`` are context examples; the
    remaining entries are questions.  Each question yields
    ``{"examples": [...], "question": {...}}``.

    All records produced from one sample share the same ``examples`` list
    object, so examples that appear after a question in the sample are still
    visible through that question's record.
    """
    parsed = []
    for sample in serialized_questions:
        examples = []
        for entry in sample:
            if entry["light_state"] == "no":
                # Test case: pair the question with the sample's examples.
                question = get_trial(entry["objects"], entry["label"])
                parsed.append({"examples": examples, "question": question})
                continue
            # Context view: becomes an input/output example.
            examples.append(get_example(entry["objects"], entry["light_state"]))
    return parsed