mirror of
https://github.com/InternLM/InternBootcamp.git
synced 2026-04-19 12:58:04 +00:00
336 lines
No EOL
12 KiB
Python
Executable file
336 lines
No EOL
12 KiB
Python
Executable file
import math
|
|
import random
|
|
import fractions
|
|
import json
|
|
import tqdm
|
|
import os
|
|
import numpy as np
|
|
from multiprocessing import Pool, Queue, Process
|
|
|
|
class Game24Plus:
|
|
"""General version of game24 (Krypto), given N numbers less than X, find a way to get number M using 4 basic operations."""
|
|
def __init__(self, num_numbers, range_max, target_max, seed=None):
|
|
self.rng = np.random.default_rng(seed)
|
|
self.num_numbers = num_numbers
|
|
self.target_max = target_max
|
|
self.range_max = range_max
|
|
self.operations = ['+', '-', '*', '/']
|
|
|
|
def sample_one_number(self, num_min, num_max):
|
|
return self.rng.integers(num_min, num_max, size=1)[0]
|
|
|
|
def get_numbers(self):
|
|
numbers = self.rng.integers(0, self.range_max, size=self.num_numbers).tolist()
|
|
numbers.sort()
|
|
return numbers
|
|
|
|
def enumerate_all_numbers(self, num_numbers):
|
|
if num_numbers == 1:
|
|
for i in range(1, self.range_max):
|
|
yield [i]
|
|
return
|
|
for i in range(1, self.range_max):
|
|
for j in self.enumerate_all_numbers(num_numbers - 1):
|
|
yield [i] + j
|
|
|
|
def sample_operation(self, numbers: list):
|
|
n_pos = self.rng.choice(len(numbers), size=2, replace=False)
|
|
n1 = numbers[n_pos[0]]
|
|
n2 = numbers[n_pos[1]]
|
|
op = self.rng.choice(self.operations).item()
|
|
if op == '/':
|
|
# if n2 == 0 or n1 % n2 != 0:
|
|
if n2 == 0:
|
|
return None
|
|
return n1, op, n2, n_pos[0], n_pos[1]
|
|
|
|
def calculate(self, n1, op, n2):
|
|
if op == '+':
|
|
return n1 + n2
|
|
elif op == '-':
|
|
return n1 - n2
|
|
elif op == '*':
|
|
return n1 * n2
|
|
elif op == '/':
|
|
if n1 % n2 != 0:
|
|
return fractions.Fraction(n1, n2)
|
|
return n1 // n2
|
|
|
|
def get_target(self, numbers: list):
|
|
current_numbers = numbers.copy()
|
|
current_operations = []
|
|
while len(current_numbers) > 1:
|
|
r = self.sample_operation(current_numbers)
|
|
if r is None:
|
|
continue
|
|
n1, op, n2, pos1, pos2 = r
|
|
current_numbers.pop(max(pos1, pos2))
|
|
current_numbers.pop(min(pos1, pos2))
|
|
res = self.calculate(n1, op, n2)
|
|
if isinstance(res, fractions.Fraction):
|
|
a, b = res.as_integer_ratio()
|
|
if a % b == 0:
|
|
res = int(a/b)
|
|
current_operations.append((str(n1), op, str(n2), str(res)))
|
|
current_numbers.append(res)
|
|
return current_numbers[0], current_operations
|
|
|
|
def get_target_limit_range(self, numbers):
|
|
target, operations = self.get_target(numbers)
|
|
num_try = 0
|
|
while not isinstance(target, int) or target < 0 or target > self.target_max:
|
|
target, operations = self.get_target(numbers)
|
|
num_try += 1
|
|
if num_try > 1000:
|
|
return None, None
|
|
return target, operations
|
|
|
|
def solve(self, numbers, target):
|
|
def precise_calculate(n1, op, n2):
|
|
if op == '+':
|
|
return n1 + n2
|
|
elif op == '-':
|
|
return n1 - n2
|
|
elif op == '*':
|
|
return n1 * n2
|
|
elif op == '/':
|
|
if n1 % n2 != 0:
|
|
return fractions.Fraction(n1, n2)
|
|
return n1 // n2
|
|
|
|
if len(numbers) == 1:
|
|
if isinstance(numbers[0], int) and numbers[0] == target:
|
|
return []
|
|
else:
|
|
a, b = numbers[0].as_integer_ratio()
|
|
if a % b == 0 and int(a/b) == target:
|
|
return []
|
|
return None
|
|
for i, n1 in enumerate(numbers):
|
|
for j, n2 in enumerate(numbers):
|
|
if i == j:
|
|
continue
|
|
for op in self.operations:
|
|
if op == '/':
|
|
# if n2 == 0 or n1 % n2 != 0:
|
|
if n2 == 0:
|
|
continue
|
|
res = precise_calculate(n1, op, n2)
|
|
new_num = [n for k, n in enumerate(numbers) if k != i and k != j] + [res]
|
|
new_how = [(str(n1), op, str(n2), str(res))]
|
|
part_solution = self.solve(new_num, target)
|
|
if part_solution is not None:
|
|
return new_how + part_solution
|
|
def generate_ground_truth(self, numbers, target):
|
|
calc_steps = self.solve(numbers, target)
|
|
|
|
result_map = {}
|
|
|
|
for step in calc_steps:
|
|
operand1, operator, operand2, result = step
|
|
|
|
# 如果操作数是之前的结果,则替换为对应的表达式
|
|
if operand1 in result_map:
|
|
operand1 = f"({result_map[operand1]})"
|
|
else:
|
|
operand1 = operand1
|
|
|
|
if operand2 in result_map:
|
|
operand2 = f"({result_map[operand2]})"
|
|
else:
|
|
operand2 = operand2
|
|
|
|
# 构建当前步骤的表达式
|
|
if operator == "":
|
|
expression = operand1 # 如果没有操作符,直接取操作数
|
|
else:
|
|
expression = f"{operand1} {operator} {operand2}"
|
|
|
|
# 将当前结果存入映射表
|
|
result_map[result] = expression
|
|
|
|
# 最终结果是最后一个元组的结果
|
|
final_result = calc_steps[-1][-1]
|
|
final_expression = result_map[final_result]
|
|
|
|
return final_expression
|
|
|
|
|
|
def construct_game24_v1(num_numbers=3, range_max=101, num_samples=10000, target_max=1000, seed=1234, output_dir=None):
|
|
game = Game24Plus(num_numbers, range_max, target_max=target_max, seed=seed)
|
|
with open(os.path.join(output_dir, f'train_m={num_numbers}.jsonl'), 'w') as fp:
|
|
puzzle_dict = {}
|
|
for idx in tqdm.trange(num_samples):
|
|
num_try = 0
|
|
while True:
|
|
numbers = game.get_numbers()
|
|
puzzle = ' '.join(str(n) for n in numbers)
|
|
if puzzle in puzzle_dict:
|
|
num_try += 1
|
|
if num_try > 1000:
|
|
puzzle = None
|
|
break
|
|
continue
|
|
else:
|
|
break
|
|
if puzzle is None:
|
|
print('Failed to generate unique puzzle')
|
|
break
|
|
puzzle_dict[puzzle] = True
|
|
num_try = 0
|
|
data = {}
|
|
while True:
|
|
target, operations = game.get_target_limit_range(numbers)
|
|
target_str = str(target)
|
|
key = puzzle + ' ' + target_str
|
|
if key in data:
|
|
num_try += 1
|
|
if num_try > 10:
|
|
break
|
|
else:
|
|
data[key] = {
|
|
'puzzle': puzzle,
|
|
'target': target_str,
|
|
'operations': operations
|
|
}
|
|
num_try = 0
|
|
for key, ex in data.items():
|
|
print(json.dumps(ex), file=fp)
|
|
|
|
|
|
def construct_helper(args):
|
|
num_numbers, range_max, num_samples_per_target, target_min, target_max, seed = args
|
|
game = Game24Plus(num_numbers, range_max, target_max=target_max, seed=seed)
|
|
data = []
|
|
for target in range(target_min, target_max):
|
|
num_try = 0
|
|
target_str = str(target)
|
|
seen_data = {}
|
|
while True:
|
|
numbers = game.get_numbers()
|
|
puzzle = ' '.join(str(n) for n in numbers)
|
|
if puzzle in seen_data:
|
|
continue
|
|
operations = game.solve(numbers, target)
|
|
if operations is None:
|
|
num_try += 1
|
|
if num_try > 10000:
|
|
break
|
|
continue
|
|
ex = {
|
|
'puzzle': puzzle,
|
|
'target': target_str,
|
|
'operations': operations
|
|
}
|
|
data.append(ex)
|
|
num_try = 0
|
|
seen_data[puzzle] = 1
|
|
if len(seen_data) > num_samples_per_target:
|
|
break
|
|
|
|
return data
|
|
|
|
|
|
def construct_game24_v3(num_numbers=3, range_max=101, num_samples=50000, target_max=1000, seed=1234, num_workers=64, output_dir=None):
|
|
print(f'Args', num_numbers, range_max, num_samples, target_max, seed, num_workers)
|
|
arg_list = []
|
|
num_samples_per_target = (num_samples // target_max) + 1
|
|
chunk_size = 1
|
|
start_target = 0
|
|
while start_target < target_max:
|
|
_target_min = start_target
|
|
_target_max = min(target_max, start_target + chunk_size)
|
|
arg_list.append((
|
|
num_numbers, range_max, num_samples_per_target, _target_min, _target_max, seed
|
|
))
|
|
start_target = _target_max
|
|
# print(arg_list)
|
|
with open(os.path.join(output_dir, f'train_m={num_numbers}.jsonl'), 'w') as fp:
|
|
with Pool(num_workers) as p:
|
|
for data in tqdm.tqdm(p.imap_unordered(construct_helper, arg_list), total=len(arg_list)):
|
|
for ex in data:
|
|
print(json.dumps(ex), file=fp)
|
|
|
|
|
|
def construct_helper_for_v4(args_queue, out_queue):
|
|
while True:
|
|
args = args_queue.get()
|
|
if args is None:
|
|
break
|
|
numbers = args[0]
|
|
game = args[1]
|
|
puzzle = ' '.join(str(n) for n in numbers)
|
|
num_try = 0
|
|
data = {}
|
|
while True:
|
|
target, operations = game.get_target_limit_range(numbers)
|
|
if target in data:
|
|
num_try += 1
|
|
if num_try > 10:
|
|
break
|
|
else:
|
|
data[target] = {
|
|
'puzzle': puzzle,
|
|
'target': str(target),
|
|
'operations': operations
|
|
}
|
|
out_queue.put(data)
|
|
|
|
|
|
def construct_game24_v4(num_numbers=3, range_max=101, num_samples=200000, target_max=1000, seed=1234, num_workers=64, output_dir=None):
|
|
print(f'Args', num_numbers, range_max, num_samples, target_max, seed, num_workers)
|
|
seen_numbers = set()
|
|
procs = []
|
|
args_queue = Queue()
|
|
out_queue = Queue()
|
|
for i in range(num_workers):
|
|
p = Process(target=construct_helper_for_v4, args=(args_queue, out_queue))
|
|
p.start()
|
|
procs.append(p)
|
|
|
|
game = Game24Plus(num_numbers, range_max, target_max=target_max, seed=seed)
|
|
with open(os.path.join(output_dir, f'train_m={num_numbers}.jsonl'), 'w') as fp:
|
|
for i in range(num_samples):
|
|
numbers = game.get_numbers()
|
|
puzzle = ' '.join(str(n) for n in numbers)
|
|
if puzzle in seen_numbers:
|
|
continue
|
|
seen_numbers.add(puzzle)
|
|
args_queue.put((numbers, game))
|
|
for i in range(num_workers):
|
|
args_queue.put(None)
|
|
|
|
count = 0
|
|
tqdm_bar = tqdm.tqdm(total=num_samples)
|
|
while count < num_samples:
|
|
data = out_queue.get()
|
|
for target, ex in data.items():
|
|
print(json.dumps(ex), file=fp)
|
|
count += 1
|
|
tqdm_bar.update(1)
|
|
|
|
for p in procs:
|
|
p.kill()
|
|
|
|
|
|
if __name__ == "__main__":
|
|
# 初始化 Game24Plus 实例
|
|
game = Game24Plus(num_numbers=6, range_max=10, target_max=24, seed=46000767)
|
|
|
|
# 获取一组随机数字
|
|
numbers = game.get_numbers()
|
|
print("Numbers:", numbers)
|
|
|
|
# 设置目标值
|
|
target = 24
|
|
print("Target:", target)
|
|
|
|
print("Trace:", game.solve(numbers, target))
|
|
|
|
# 尝试生成一步答案算式
|
|
expression = game.generate_ground_truth(numbers, target)
|
|
if expression:
|
|
print("expression:", expression)
|
|
else:
|
|
print("No expression found.")
|
|
|