mirror of
https://github.com/open-thought/reasoning-gym.git
synced 2026-04-30 17:40:45 +00:00
adds zebrapuzzles
This commit is contained in:
parent
5d84b6bec5
commit
0c9094e9f4
20 changed files with 2447 additions and 2 deletions
|
|
@ -0,0 +1,54 @@
|
|||
from z3 import *
|
||||
import solver as solver
|
||||
import itertools
|
||||
from collections import defaultdict
|
||||
import sat_utils
|
||||
|
||||
def check_singe_clues(puzzle_idx, answer_header, answer_value, running_clues, idx2clue, used_clue, all_cells):
|
||||
solved_cell = defaultdict(int)
|
||||
valid_solutions = defaultdict(int)
|
||||
print('check single clues')
|
||||
for idx in idx2clue:
|
||||
if idx not in used_clue:
|
||||
new_clues = running_clues+idx2clue[idx].as_cnf()
|
||||
numbered_cnf, num2var = sat_utils.translate(new_clues)
|
||||
single_solver = solver.my_solver(puzzle_idx, answer_header, answer_value, numbered_cnf, num2var)
|
||||
cell_info = single_solver.check_cell_difficulty()
|
||||
num_cell=0
|
||||
for cell in cell_info:
|
||||
if cell not in all_cells:
|
||||
num_cell+=1
|
||||
solved_cell[idx] = num_cell
|
||||
del (single_solver)
|
||||
idx = sorted(solved_cell.items(),key=lambda x: x[1], reverse=True)[0][0]
|
||||
if solved_cell[idx]!=0:
|
||||
return [idx]
|
||||
else:
|
||||
return [1000]
|
||||
|
||||
def check(puzzle_idx, answer_header, answer_value, running_clues, idx2clue, used_clue, all_cells):
|
||||
idx = check_singe_clues(puzzle_idx, answer_header, answer_value, running_clues, idx2clue, used_clue, all_cells)
|
||||
if idx[0]<1000:
|
||||
return idx
|
||||
else:
|
||||
print('check multiple clues')
|
||||
solved_cell = defaultdict(int)
|
||||
all_left_clues_idx = [i for i in idx2clue if i not in used_clue]
|
||||
print(all_left_clues_idx)
|
||||
for n in range(2, 10):
|
||||
combinations = itertools.combinations(all_left_clues_idx, n)
|
||||
for comb in combinations:
|
||||
new_clues = [clues for clues in running_clues]
|
||||
for comb_idx in comb:
|
||||
new_clues += idx2clue[comb_idx].as_cnf()
|
||||
numbered_cnf, num2var = sat_utils.translate(new_clues)
|
||||
single_solver = solver.my_solver(puzzle_idx, answer_header, answer_value, numbered_cnf, num2var)
|
||||
cell_info = single_solver.check_cell_difficulty()
|
||||
num_cell = 0
|
||||
for cell in cell_info:
|
||||
if cell not in all_cells:
|
||||
num_cell += 1
|
||||
solved_cell[comb] = num_cell
|
||||
del (single_solver)
|
||||
if num_cell!=0:
|
||||
return comb
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
|
||||
|
||||
def print_clue(clue_idxs, idx2clue, new_cell, first=False):
|
||||
if len(clue_idxs)==1:
|
||||
# print(type(idx2clue[clue_idxs[0]]))
|
||||
# if type(idx2clue[clue_idxs[0]]) == clues.found_at:
|
||||
return single_find_at(idx2clue[clue_idxs[0]], new_cell, first)
|
||||
else:
|
||||
if first:
|
||||
result= 'First combining clues: '
|
||||
else:
|
||||
result= 'Then combining clues: '
|
||||
for current_idx in clue_idxs:
|
||||
result += "<{}>".format(idx2clue[current_idx])
|
||||
result += ' Unique Values Rules and the fixed table structure. We know that '
|
||||
for cell in new_cell:
|
||||
result += cell
|
||||
result += '. '
|
||||
return result
|
||||
|
||||
def single_find_at(clue, new_cell, first=False):
|
||||
if first:
|
||||
result="First applying clue: "
|
||||
else:
|
||||
result="Then applying clue: "
|
||||
if len(new_cell)>1:
|
||||
result += "<{}> and Unique Values We know that ".format(clue)
|
||||
else:
|
||||
result += "<{}> We know that ".format(clue)
|
||||
for cell in new_cell:
|
||||
result += cell
|
||||
result += '. '
|
||||
return result
|
||||
45
reasoning_gym/logic/contrib/logic_puzzle/graph/main.py
Normal file
45
reasoning_gym/logic/contrib/logic_puzzle/graph/main.py
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
import pickle
|
||||
from tqdm import tqdm
|
||||
from z3 import *
|
||||
import json
|
||||
import argparse
|
||||
import solver
|
||||
import sat_utils
|
||||
|
||||
def solve_logic_grid_puzzle(inputfile, ground_truth):
|
||||
answers = [json.loads(answer) for answer in list(open(ground_truth, 'r'))][0]
|
||||
cell_difficulty = {}
|
||||
with open(inputfile, 'rb') as f:
|
||||
puzzles = pickle.load(f)
|
||||
for i in tqdm(range(len(puzzles[:]))):
|
||||
d=puzzles[i]
|
||||
assert d['idx']==answers[i]['idx']
|
||||
answer_header = answers[i]['solution']['table_header']
|
||||
answer_value = answers[i]['solution']['table_rows']
|
||||
# read cnf form
|
||||
symbolic_cnf = d['puzzle'].as_cnf()
|
||||
numbered_cnf, num2var = sat_utils.translate(symbolic_cnf)
|
||||
|
||||
single_solver = solver.my_solver(d['idx'], answer_header, answer_value, numbered_cnf, num2var)
|
||||
solution, unique = single_solver.check_solution()
|
||||
if unique:
|
||||
for row_num in range(len(solution)):
|
||||
for column_num in range(len(solution[row_num])):
|
||||
assert row_num == solution[row_num][column_num]
|
||||
difficulty = single_solver.check_problem_difficulty()
|
||||
cell_difficulty[d['idx']]=difficulty
|
||||
|
||||
with open('./logic_grid_puzzles.test.difficulty.pkl', 'wb') as outputfile:
|
||||
pickle.dump(cell_difficulty, outputfile)
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--input_data', type=str, default="../logic_grid_puzzles.test.pkl")
|
||||
parser.add_argument('--ground_truth', type=str, default="../logic_grid_puzzles.test.json")
|
||||
args = parser.parse_args()
|
||||
|
||||
solve_logic_grid_puzzle(args.input_data, args.ground_truth)
|
||||
return 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,191 @@
|
|||
import argparse
|
||||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import Sequence
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def parse_table(table_string):
|
||||
flat_rows = [r for r in table_string.split('\n') if r.startswith('$')]
|
||||
table = {}
|
||||
for flat_row in flat_rows:
|
||||
flat_cells = [c for c in flat_row.split('|') if c]
|
||||
house, attributes = flat_cells[0], flat_cells[1:]
|
||||
h_prefix, h_number = house.split(':')
|
||||
assert h_prefix == '$ House'
|
||||
house_key = f'House {h_number.strip()}'
|
||||
table[house_key] = {}
|
||||
|
||||
for att in attributes:
|
||||
if not ':' in att:
|
||||
continue
|
||||
name, value = att.split(':')
|
||||
table[house_key][name.strip()] = value.strip()
|
||||
|
||||
return table
|
||||
|
||||
|
||||
def copy_table(table):
|
||||
new_table = {}
|
||||
for house, attributes in table.items():
|
||||
new_table[house] = {}
|
||||
for att_name, att_value in attributes.items():
|
||||
new_table[house][att_name] = None
|
||||
return new_table
|
||||
|
||||
|
||||
def parse_step(step_text, table_to_fill):
|
||||
parsed_clues = [s.split('>')[0] for s in step_text.split('<') if '>' in s]
|
||||
ans_texts = [s.strip() for s in step_text.split('We know that ')[-1].split('.') if s.strip().startswith('The ')]
|
||||
for ans_text in ans_texts:
|
||||
try:
|
||||
att_name = ans_text.split(' in house ')[0].split('The ')[-1].strip()
|
||||
house_num = ans_text.split(' in house ')[1].split(' is ')[0].strip()
|
||||
att_value = ans_text.split(' is ')[-1].strip()
|
||||
table_to_fill[f'House {house_num}'][att_name] = att_value
|
||||
except:
|
||||
print(parsed_clues)
|
||||
print(ans_text)
|
||||
print(table_to_fill)
|
||||
print()
|
||||
continue
|
||||
return parsed_clues, table_to_fill
|
||||
|
||||
|
||||
def pre_process(src_file: os.PathLike, dest_dir: os.PathLike, overwrite: bool = False):
|
||||
outputs = json.load(open(src_file, 'r'))['data']
|
||||
|
||||
src_file = Path(src_file)
|
||||
os.makedirs(dest_dir, exist_ok=True)
|
||||
save_file = os.path.join(dest_dir, f"parsed_{src_file.name}")
|
||||
|
||||
if not overwrite and os.path.exists(save_file):
|
||||
return
|
||||
|
||||
items = []
|
||||
for output in tqdm(outputs, desc=f"{src_file.stem}"):
|
||||
groundtruth_table = parse_table(output['truth_outputs'][0])
|
||||
if type(output['output_text']) == list:
|
||||
output_text = output['output_text'][0].split('\n')
|
||||
else:
|
||||
output_text = output['output_text'].split('\n')
|
||||
step_texts = [o for o in output_text if o.startswith('Step ')]
|
||||
steps = []
|
||||
partial_table = copy_table(groundtruth_table)
|
||||
for s_text in step_texts:
|
||||
clues, filled_table = parse_step(s_text, partial_table)
|
||||
steps.append({
|
||||
'clues': clues,
|
||||
'partial_table': filled_table,
|
||||
})
|
||||
partial_table = deepcopy(filled_table)
|
||||
|
||||
predict_table_texts = [o for o in output_text if o.startswith('$ House')]
|
||||
if not predict_table_texts:
|
||||
predicted_table = partial_table
|
||||
else:
|
||||
predicted_table = parse_table('\n'.join(predict_table_texts))
|
||||
|
||||
items.append({
|
||||
'groundtruth_table': groundtruth_table,
|
||||
'predicted_table': predicted_table,
|
||||
'steps': steps
|
||||
})
|
||||
|
||||
with open(save_file, 'w') as f:
|
||||
json.dump(items, f, indent=4)
|
||||
|
||||
|
||||
def error_analysis(home_path: os.PathLike, file_names: Sequence[str] = None):
|
||||
if not file_names:
|
||||
parsed_files = [p for p in Path(home_path).glob("parsed_table*judged.json")]
|
||||
else:
|
||||
parsed_files = [os.path.join(home_path, f) for f in file_names]
|
||||
|
||||
datasets = []
|
||||
for parsed_file in parsed_files:
|
||||
datasets.extend([json.loads(s) for s in open(parsed_file, 'r').readlines()])
|
||||
|
||||
error_types = ['correct', 'type1', 'type2', 'type3']
|
||||
error_stats = {}
|
||||
for data in tqdm(datasets):
|
||||
gt_table = data['groundtruth_table']
|
||||
previous_type = None
|
||||
for i, step in enumerate(data["steps"]):
|
||||
partial_table = step['partial_table']
|
||||
# whether value correct:
|
||||
value_to_check = []
|
||||
for house_name, attributes in partial_table.items():
|
||||
for att, value in attributes.items():
|
||||
if not i:
|
||||
previous_empty = True
|
||||
else:
|
||||
if att not in data["steps"][i-1]['partial_table'][house_name]:
|
||||
previous_empty = False
|
||||
else:
|
||||
previous_empty = data["steps"][i-1]['partial_table'][house_name][att] is None
|
||||
if value is not None and previous_empty:
|
||||
value_to_check.append((house_name, att))
|
||||
|
||||
value_correct = all(
|
||||
[partial_table[house_name][att] == gt_table[house_name][att] for house_name, att in value_to_check]
|
||||
)
|
||||
operation_correct = step['label']
|
||||
if not i:
|
||||
ancestor_correct = True
|
||||
else:
|
||||
ancestor_correct = previous_type == 'correct'
|
||||
|
||||
if value_correct:
|
||||
if ancestor_correct:
|
||||
node_type = 'correct'
|
||||
else:
|
||||
node_type = 'type3'
|
||||
else:
|
||||
if operation_correct:
|
||||
node_type = 'type2'
|
||||
else:
|
||||
node_type = 'type1'
|
||||
|
||||
previous_type = node_type
|
||||
|
||||
if i not in error_stats.keys():
|
||||
error_stats[i] = {e: 0 for e in error_types}
|
||||
error_stats[i][node_type] += 1
|
||||
|
||||
print(json.dumps(error_stats))
|
||||
percent_stats = {}
|
||||
number_nodes = {}
|
||||
for key, values in error_stats.items():
|
||||
percent_stats[key] = {}
|
||||
total_nodes = sum([values[t] for t in error_types])
|
||||
number_nodes[key] = total_nodes
|
||||
for t in error_types:
|
||||
percent_stats[key][t] = values[t] / total_nodes
|
||||
print(json.dumps(percent_stats))
|
||||
print(number_nodes)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument('--cot_dir', type=str, default=None)
|
||||
parser.add_argument('--output_dir', type=str, default='tmp')
|
||||
parser.add_argument('--overwrite', action='store_true', default=False)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
cot_dir = Path(args.cot_dir)
|
||||
assert cot_dir.exists(), "cot_dir does not exist"
|
||||
assert cot_dir.is_dir(), "cot_dir is not a directory"
|
||||
|
||||
for cot_file in cot_dir.glob("table*.json"):
|
||||
pre_process(cot_file, args.output_dir, args.overwrite)
|
||||
|
||||
error_analysis(args.output_dir)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
102
reasoning_gym/logic/contrib/logic_puzzle/graph/reasoning_path.py
Normal file
102
reasoning_gym/logic/contrib/logic_puzzle/graph/reasoning_path.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
import pickle
|
||||
from tqdm import tqdm
|
||||
from z3 import *
|
||||
import json
|
||||
import argparse
|
||||
import solver as solver
|
||||
import check_clues
|
||||
import graph_literals
|
||||
from collections import defaultdict
|
||||
import sat_utils
|
||||
|
||||
|
||||
def get_idx2clue(clues):
|
||||
clue_num2_clue = defaultdict(int)
|
||||
clue_type2_clue_num = defaultdict(list)
|
||||
clue_num = 0
|
||||
for clue in list(clues):
|
||||
clue_num2_clue[clue_num] = clue
|
||||
clue_type2_clue_num[type(clue)].append(clue_num)
|
||||
clue_num += 1
|
||||
return clue_num2_clue, clue_type2_clue_num
|
||||
|
||||
def logic_grid_puzzle(inputfile, ground_truth, size, lower_part, higher_part):
|
||||
reasoning_result = []
|
||||
answers = json.load(open(ground_truth, 'r'))
|
||||
puzzles = pickle.load(open(inputfile, 'rb'))
|
||||
cell_difficulty = {}
|
||||
mode = inputfile[inputfile.find('puzzles.') + 8:inputfile.find('.pkl')]
|
||||
print('Number of puzzles', len(answers))
|
||||
assert len(answers) == len(puzzles)
|
||||
new_data = []
|
||||
for i in tqdm(range(len(puzzles[:]))):
|
||||
d = puzzles[i]
|
||||
per_size_idx = int(puzzles[i]['idx'].split('-')[-1])
|
||||
if d['idx'].startswith("lgp-"+mode+"-"+size) and per_size_idx>=lower_part and per_size_idx<higher_part:
|
||||
print('Puzzle id:', d['idx'])
|
||||
print('Puzzle', d)
|
||||
print("Solving puzzle"+"==============="*7+"Solving puzzle")
|
||||
|
||||
assert d['idx']==answers[i]['idx']
|
||||
answer_header = answers[i]['solution']['table_header']
|
||||
answer_value = answers[i]['solution']['table_rows']
|
||||
|
||||
puzzle_clues = d['puzzle'].clues
|
||||
idx2clue, clue_type2_idx = get_idx2clue(puzzle_clues)
|
||||
|
||||
running_clues = []
|
||||
all_cells = []
|
||||
used_clue = []
|
||||
self_constraints = d['puzzle'].constraints
|
||||
running_clues.extend(self_constraints)
|
||||
|
||||
# for i in range(len(idx2clue)):
|
||||
first=True
|
||||
step_num=1
|
||||
reasoning = ""
|
||||
while len(used_clue)<len(idx2clue):
|
||||
reasoning += 'Step {}: '.format(step_num)
|
||||
step_num+=1
|
||||
clue_idxs = check_clues.check(d['idx'],
|
||||
answer_header, answer_value,
|
||||
running_clues, idx2clue, used_clue, all_cells)
|
||||
for current_clue_idx in clue_idxs:
|
||||
running_clues.extend(idx2clue[current_clue_idx].as_cnf())
|
||||
used_clue.append(current_clue_idx)
|
||||
|
||||
numbered_cnf, num2var = sat_utils.translate(running_clues)
|
||||
single_solver = solver.my_solver(d['idx'], answer_header, answer_value, numbered_cnf, num2var)
|
||||
cell_info = single_solver.check_cell_difficulty()
|
||||
new_cell=[]
|
||||
for cell in cell_info:
|
||||
if cell not in all_cells:
|
||||
new_cell.append(cell)
|
||||
all_cells.append(cell)
|
||||
reasoning += graph_literals.print_clue(clue_idxs, idx2clue, new_cell, first)
|
||||
first=False
|
||||
reasoning+='\n'
|
||||
assert len(all_cells) == len(answer_value)*(len(answer_value[0])-1)
|
||||
reasoning+='The puzzle is solved.'
|
||||
reasoning_result.append(reasoning)
|
||||
print(reasoning)
|
||||
single_data = answers[i]
|
||||
single_data["reasoning"]=reasoning
|
||||
new_data.append(single_data)
|
||||
|
||||
with open("./data/logic_grid_puzzles.reasoning."+mode+size+"-"+str(lower_part)+"_"+str(higher_part)+".json", "w") as outputfile:
|
||||
json.dump(new_data, outputfile)
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--input_data', type=str, default="./data/logic_grid_puzzles.test_id_xl.pkl")
|
||||
parser.add_argument('--ground_truth', type=str, default="./data/ogic_grid_puzzles.test_id_xl.json")
|
||||
parser.add_argument('--size', type=str, default="2x")
|
||||
parser.add_argument('--lower_part', type=int, default=0) #min data index
|
||||
parser.add_argument('--higher_part', type=int, default=100) #max data index
|
||||
args = parser.parse_args()
|
||||
|
||||
logic_grid_puzzle(args.input_data, args.ground_truth, args.size, args.lower_part, args.higher_part)
|
||||
return 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
137
reasoning_gym/logic/contrib/logic_puzzle/graph/solver.py
Normal file
137
reasoning_gym/logic/contrib/logic_puzzle/graph/solver.py
Normal file
|
|
@ -0,0 +1,137 @@
|
|||
from z3 import *
|
||||
import collections
|
||||
import solver_utils
|
||||
import numpy as np
|
||||
|
||||
class my_solver:
|
||||
def __init__(self, id, feature_name, feature_value, numbered_cnf, num2var):
|
||||
self.id = id
|
||||
self.num2var = num2var
|
||||
self.clue_num = 0
|
||||
|
||||
self.feature_name = feature_name
|
||||
self.feature_value = feature_value
|
||||
self.numbered_cnf = numbered_cnf
|
||||
|
||||
self.init_solver(feature_name, feature_value, numbered_cnf)
|
||||
|
||||
def init_solver(self, feature_name, feature_value, numbered_cnf):
|
||||
set_param(proof=True)
|
||||
self.s = Solver()
|
||||
|
||||
self.s.set(unsat_core=True)
|
||||
house_total = len(feature_value)
|
||||
self.house_total = house_total
|
||||
features, mapping = self.define_feature(feature_name, feature_value)
|
||||
|
||||
self.houses = [
|
||||
[solver_utils.instanciate_int_constrained('%s%d' % (prop, n), self.s, house_total) for prop in
|
||||
features.keys()] for n in
|
||||
range(house_total)]
|
||||
self.mapping = mapping
|
||||
self.feature = features
|
||||
# because we read feature from GT, we can make this assumption to calculate difficulty.
|
||||
self.ground_truth = self.feature
|
||||
for single_cnf in numbered_cnf:
|
||||
if len(single_cnf) == 1:
|
||||
self.add_single_clue(single_cnf)
|
||||
else:
|
||||
self.add_mul_val_clue(single_cnf)
|
||||
self.clue_num += 1
|
||||
|
||||
def define_feature(self, feature_name, feature_value):
|
||||
features = collections.OrderedDict()
|
||||
for i in range(len(feature_name)):
|
||||
if feature_name[i] != 'House':
|
||||
features[feature_name[i]]=[]
|
||||
for house_j in feature_value:
|
||||
features[feature_name[i]].append(house_j[i])
|
||||
|
||||
feature_mapping = {list(features.keys())[i]: i for i in range(len(features.keys()))}
|
||||
return features, feature_mapping
|
||||
|
||||
def decode_var(self, var_num):
|
||||
attribute, house_num = (self.num2var[var_num].split(' '))
|
||||
attribute_name, attribute_value = (attribute.split('.'))
|
||||
if '_' in attribute_value:
|
||||
attribute_value = attribute_value.replace('_', ' ')
|
||||
house_num = int(house_num) - 1
|
||||
return house_num, attribute_name, attribute_value
|
||||
|
||||
|
||||
def normal_vs_not_constraints(self, house_num, attribute_name, attribute_value):
|
||||
if attribute_name.startswith('~'):
|
||||
attribute_name = attribute_name[1:]
|
||||
return Not(self.houses[house_num][self.mapping[attribute_name]] == self.feature[attribute_name].index(attribute_value))
|
||||
else:
|
||||
return self.houses[house_num][self.mapping[attribute_name]] == self.feature[attribute_name].index(attribute_value)
|
||||
|
||||
def add_single_clue(self, single_cnf):
|
||||
house_num, attribute_name, attribute_value = self.decode_var(single_cnf[0])
|
||||
self.s.assert_and_track(self.normal_vs_not_constraints(house_num, attribute_name, attribute_value), 'a'+str(self.clue_num))
|
||||
|
||||
def add_mul_val_clue(self, single_cnf):
|
||||
cons = []
|
||||
for i in range(len(single_cnf)):
|
||||
house_num, attribute_name, attribute_value = self.decode_var(single_cnf[i])
|
||||
cons.append(self.normal_vs_not_constraints(house_num, attribute_name, attribute_value))
|
||||
self.s.assert_and_track(Or(*cons), 'a' + str(self.clue_num))
|
||||
|
||||
|
||||
def check_solution(self):
|
||||
if self.s.check() == unsat:
|
||||
c = self.s.unsat_core()
|
||||
print("Size of the unsat core:", len(c))
|
||||
print("Unsat core:", ", ".join([str(i) for i in c]))
|
||||
|
||||
proof_str = str(self.s.proof())
|
||||
print(self.s.proof())
|
||||
print("Proof length:", len(proof_str))
|
||||
|
||||
else:
|
||||
m = self.s.model()
|
||||
solution = [[m[case].as_long() for case in line] for line in self.houses]
|
||||
unique=True
|
||||
if count_solutions(self.s)!=1:
|
||||
print(self.id)
|
||||
unique=False
|
||||
# print("Number of solutions:", count_solutions(self.s))
|
||||
return solution, unique
|
||||
|
||||
def print_solution(self, solution):
|
||||
print("Solution:")
|
||||
print(self.feature)
|
||||
|
||||
def check_cell_difficulty(self):
|
||||
results = []
|
||||
clue_num=0
|
||||
for i in range(self.house_total):
|
||||
for feature in self.ground_truth:
|
||||
self.s.assert_and_track(Not(self.houses[i][self.mapping[feature]] == i), "check"+str(clue_num))
|
||||
if self.s.check() == unsat:
|
||||
results.append("The {} in house {} is {}".format(feature, i+1, self.ground_truth[feature][i]))
|
||||
self.s.reset()
|
||||
self.init_solver(self.feature_name, self.feature_value, self.numbered_cnf)
|
||||
return results
|
||||
|
||||
def check_statement_difficulty(self):
|
||||
clue_num=0
|
||||
proof_length = np.zeros((self.house_total, len(self.ground_truth)))
|
||||
for i in range(self.house_total):
|
||||
for feature in self.ground_truth:
|
||||
self.s.assert_and_track(Not(self.houses[i][self.mapping[feature]] == i), "check"+str(clue_num))
|
||||
clue_num+=1
|
||||
if self.s.check() == unsat:
|
||||
c = self.s.unsat_core()
|
||||
proof_length[i, self.mapping[feature]] = self.s.statistics().propagations
|
||||
print(proof_length)
|
||||
return proof_length
|
||||
|
||||
def check_problem_difficulty(self):
|
||||
proof_length = np.zeros((1, 1))
|
||||
proof_length[0, 0] = self.s.statistics().propagations
|
||||
return proof_length
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,34 @@
|
|||
from z3 import *
|
||||
|
||||
|
||||
def column(matrix, i):
|
||||
return [matrix[j][i] for j in range(len(matrix))]
|
||||
|
||||
def instanciate_int_constrained(name, s, card):
|
||||
x = Int(name)
|
||||
# Each int represent an index in p[name]
|
||||
s.add(x >= 0, x <= card - 1)
|
||||
return x
|
||||
|
||||
def count_solutions(s, max=1e9):
|
||||
count = 0
|
||||
while s.check() == sat:
|
||||
count += 1
|
||||
if count >= max:
|
||||
return count
|
||||
m = s.model()
|
||||
|
||||
# Create a new constraint the blocks the current model
|
||||
block = []
|
||||
for d in m:
|
||||
# d is a declaration
|
||||
if d.arity() > 0:
|
||||
raise Z3Exception("uninterpreted functions are not supported")
|
||||
# create a constant from declaration
|
||||
c = d()
|
||||
if is_array(c) or c.sort().kind() == Z3_UNINTERPRETED_SORT:
|
||||
raise Z3Exception("arrays and uninterpreted sorts are not supported")
|
||||
block.append(c != m[d])
|
||||
|
||||
s.add(Or(block))
|
||||
return count
|
||||
Loading…
Add table
Add a link
Reference in a new issue