mirror of
https://github.com/InternLM/InternBootcamp.git
synced 2026-04-25 17:10:49 +00:00
init-commit
This commit is contained in:
commit
18a552597a
3461 changed files with 1150579 additions and 0 deletions
336
internbootcamp/libs/game24/game24.py
Executable file
336
internbootcamp/libs/game24/game24.py
Executable file
|
|
@ -0,0 +1,336 @@
|
|||
import math
|
||||
import random
|
||||
import fractions
|
||||
import json
|
||||
import tqdm
|
||||
import os
|
||||
import numpy as np
|
||||
from multiprocessing import Pool, Queue, Process
|
||||
|
||||
class Game24Plus:
    """General version of game24 (Krypto).

    Given N numbers less than X, find a way to get number M using the four
    basic operations. Division is exact: a quotient that is not a whole
    number is carried as a ``fractions.Fraction``.
    """

    def __init__(self, num_numbers, range_max, target_max, seed=None):
        """Store the puzzle configuration and create a seeded NumPy generator.

        Args:
            num_numbers: How many numbers each puzzle contains.
            range_max: Exclusive upper bound for every drawn number.
            target_max: Inclusive upper bound accepted by
                ``get_target_limit_range``.
            seed: Seed forwarded to ``np.random.default_rng``.
        """
        self.rng = np.random.default_rng(seed)
        self.num_numbers = num_numbers
        self.target_max = target_max
        self.range_max = range_max
        self.operations = ['+', '-', '*', '/']

    @staticmethod
    def _as_int_if_whole(value):
        """Return ``value`` as a plain ``int`` when it is a whole Fraction."""
        if isinstance(value, fractions.Fraction) and value.denominator == 1:
            # Integer arithmetic only: int(a / b) would round-trip through a
            # float and silently lose precision for large values.
            return int(value.numerator)
        return value

    def sample_one_number(self, num_min, num_max):
        """Draw one integer from ``[num_min, num_max)``."""
        return self.rng.integers(num_min, num_max, size=1)[0]

    def get_numbers(self):
        """Draw ``num_numbers`` integers from ``[0, range_max)``, sorted ascending."""
        numbers = self.rng.integers(0, self.range_max, size=self.num_numbers).tolist()
        numbers.sort()
        return numbers

    def enumerate_all_numbers(self, num_numbers):
        """Yield every list of ``num_numbers`` integers from ``[1, range_max)``.

        Note that, unlike ``get_numbers``, the enumeration starts at 1.
        """
        if num_numbers == 1:
            for i in range(1, self.range_max):
                yield [i]
            return
        for i in range(1, self.range_max):
            for rest in self.enumerate_all_numbers(num_numbers - 1):
                yield [i] + rest

    def sample_operation(self, numbers: list):
        """Pick two distinct positions and a random operator.

        Returns:
            ``(n1, op, n2, pos1, pos2)``, or ``None`` when the draw is a
            division by zero (the caller is expected to draw again).
        """
        n_pos = self.rng.choice(len(numbers), size=2, replace=False)
        n1 = numbers[n_pos[0]]
        n2 = numbers[n_pos[1]]
        op = self.rng.choice(self.operations).item()
        if op == '/' and n2 == 0:
            return None
        return n1, op, n2, n_pos[0], n_pos[1]

    def calculate(self, n1, op, n2):
        """Apply ``op`` to ``n1`` and ``n2`` exactly.

        Inexact divisions return a ``fractions.Fraction``. The caller must
        make sure ``n2`` is non-zero for ``'/'``. An unknown operator yields
        ``None``.
        """
        if op == '+':
            return n1 + n2
        if op == '-':
            return n1 - n2
        if op == '*':
            return n1 * n2
        if op == '/':
            if n1 % n2 != 0:
                return fractions.Fraction(n1, n2)
            return n1 // n2
        return None

    def get_target(self, numbers: list):
        """Randomly combine all ``numbers`` into a single value.

        Returns:
            ``(value, operations)`` where ``operations`` is a list of
            ``(operand1, operator, operand2, result)`` string tuples in the
            order they were applied.
        """
        current_numbers = numbers.copy()
        current_operations = []
        while len(current_numbers) > 1:
            r = self.sample_operation(current_numbers)
            if r is None:
                continue
            n1, op, n2, pos1, pos2 = r
            # Remove the larger index first so the smaller one stays valid.
            current_numbers.pop(max(pos1, pos2))
            current_numbers.pop(min(pos1, pos2))
            res = self._as_int_if_whole(self.calculate(n1, op, n2))
            current_operations.append((str(n1), op, str(n2), str(res)))
            current_numbers.append(res)
        return current_numbers[0], current_operations

    def get_target_limit_range(self, numbers, max_tries=1000):
        """Sample random targets until one is an int in ``[0, target_max]``.

        Args:
            numbers: The puzzle numbers.
            max_tries: Number of re-draws allowed after the first attempt.

        Returns:
            ``(target, operations)``, or ``(None, None)`` when no acceptable
            target was found within the retry budget.
        """
        target, operations = self.get_target(numbers)
        num_try = 0
        while not isinstance(target, int) or target < 0 or target > self.target_max:
            target, operations = self.get_target(numbers)
            num_try += 1
            if num_try > max_tries:
                return None, None
        return target, operations

    def solve(self, numbers, target):
        """Depth-first search for a sequence of operations reaching ``target``.

        Returns:
            A list of ``(operand1, operator, operand2, result)`` string
            tuples (empty when the single remaining number already equals
            the target), or ``None`` when no solution exists.
        """
        if len(numbers) == 1:
            # Exact comparison; works for both int and Fraction values.
            return [] if numbers[0] == target else None

        for i, n1 in enumerate(numbers):
            for j, n2 in enumerate(numbers):
                if i == j:
                    continue
                for op in self.operations:
                    if op == '/' and n2 == 0:
                        continue
                    res = self.calculate(n1, op, n2)
                    new_num = [n for k, n in enumerate(numbers) if k != i and k != j] + [res]
                    part_solution = self.solve(new_num, target)
                    if part_solution is not None:
                        return [(str(n1), op, str(n2), str(res))] + part_solution
        return None

    def generate_ground_truth(self, numbers, target):
        """Render a solution as one parenthesised expression.

        Returns:
            The expression string, or ``None`` when ``solve`` finds no
            solution. When ``numbers`` is a single value that already equals
            the target, that value is returned as the expression.
        """
        calc_steps = self.solve(numbers, target)
        if calc_steps is None:
            return None
        if not calc_steps:
            return str(numbers[0])

        # Pool of values that are still unused, keyed by their printed form.
        # Each entry is consumed exactly once, so an original number that
        # happens to equal an intermediate result is not mistaken for it.
        available = {}
        for n in numbers:
            available.setdefault(str(n), []).append(str(n))

        def take(value):
            stack = available.get(value)
            # Prefer the most recently produced entry; fall back to the raw
            # text if the step refers to something the pool does not hold.
            return stack.pop() if stack else value

        expression = ""
        for operand1, operator, operand2, result in calc_steps:
            left = take(operand1)
            right = take(operand2)
            if operator == "":
                expression = left  # no operator: the operand is the expression
            else:
                expression = f"{left} {operator} {right}"
            # Intermediate results are stored parenthesised for later reuse.
            available.setdefault(result, []).append(f"({expression})")

        # The last step's expression is the complete answer.
        return expression
|
||||
|
||||
|
||||
def construct_game24_v1(num_numbers=3, range_max=101, num_samples=10000, target_max=1000, seed=1234, output_dir=None):
    """Write randomly generated puzzles to ``train_m=<num_numbers>.jsonl``.

    For each of ``num_samples`` unique number sets, random targets are
    sampled until more than 10 consecutive attempts fail to produce a new
    one; every distinct target becomes one JSON line with the keys
    ``puzzle``, ``target`` and ``operations``.

    Args:
        num_numbers: Numbers per puzzle.
        range_max: Exclusive upper bound for the drawn numbers.
        num_samples: Number of unique number sets to generate.
        target_max: Largest accepted target.
        seed: Seed for the puzzle generator.
        output_dir: Directory that receives the output file.
    """
    game = Game24Plus(num_numbers, range_max, target_max=target_max, seed=seed)
    with open(os.path.join(output_dir, f'train_m={num_numbers}.jsonl'), 'w') as fp:
        seen_puzzles = set()
        for _ in tqdm.trange(num_samples):
            # Draw number sets until an unseen one turns up.
            num_try = 0
            while True:
                numbers = game.get_numbers()
                puzzle = ' '.join(str(n) for n in numbers)
                if puzzle not in seen_puzzles:
                    break
                num_try += 1
                if num_try > 1000:
                    puzzle = None
                    break
            if puzzle is None:
                print('Failed to generate unique puzzle')
                break
            seen_puzzles.add(puzzle)

            num_try = 0
            data = {}
            while num_try <= 10:
                target, operations = game.get_target_limit_range(numbers)
                if target is None:
                    # No valid target was found; count it as a failed attempt
                    # rather than writing a bogus 'None' target to the file.
                    num_try += 1
                    continue
                target_str = str(target)
                key = puzzle + ' ' + target_str
                if key in data:
                    num_try += 1
                else:
                    data[key] = {
                        'puzzle': puzzle,
                        'target': target_str,
                        'operations': operations
                    }
                    num_try = 0
            for ex in data.values():
                print(json.dumps(ex), file=fp)
|
||||
|
||||
|
||||
def construct_helper(args):
    """Collect solvable puzzles for every target in ``[target_min, target_max)``.

    Args:
        args: Tuple ``(num_numbers, range_max, num_samples_per_target,
            target_min, target_max, seed)``; packed into one argument so the
            function can be used with ``Pool.imap_unordered``.

    Returns:
        A list of dicts with the keys ``puzzle``, ``target`` and
        ``operations``.
    """
    num_numbers, range_max, num_samples_per_target, target_min, target_max, seed = args
    game = Game24Plus(num_numbers, range_max, target_max=target_max, seed=seed)
    data = []
    for target in range(target_min, target_max):
        num_try = 0
        target_str = str(target)
        seen_puzzles = set()
        while True:
            numbers = game.get_numbers()
            puzzle = ' '.join(str(n) for n in numbers)
            # A duplicate draw counts as a failed attempt too; otherwise the
            # loop never ends once every solvable puzzle has been collected.
            operations = None if puzzle in seen_puzzles else game.solve(numbers, target)
            if operations is None:
                num_try += 1
                if num_try > 10000:
                    break
                continue
            data.append({
                'puzzle': puzzle,
                'target': target_str,
                'operations': operations
            })
            num_try = 0
            seen_puzzles.add(puzzle)
            # Strict '>' is kept from the original, so each target collects
            # num_samples_per_target + 1 examples before stopping.
            if len(seen_puzzles) > num_samples_per_target:
                break

    return data
|
||||
|
||||
|
||||
def construct_game24_v3(num_numbers=3, range_max=101, num_samples=50000, target_max=1000, seed=1234, num_workers=64, output_dir=None):
    """Build a target-driven dataset in parallel.

    The target range ``[0, target_max)`` is split into chunks of one target
    each; every chunk is handed to ``construct_helper`` in a worker pool and
    the returned examples are written as JSON lines to
    ``train_m=<num_numbers>.jsonl`` in ``output_dir``.
    """
    print('Args', num_numbers, range_max, num_samples, target_max, seed, num_workers)
    num_samples_per_target = (num_samples // target_max) + 1
    chunk_size = 1
    # One work item per chunk of consecutive targets.
    arg_list = [
        (
            num_numbers, range_max, num_samples_per_target,
            chunk_start, min(target_max, chunk_start + chunk_size), seed
        )
        for chunk_start in range(0, target_max, chunk_size)
    ]
    out_path = os.path.join(output_dir, f'train_m={num_numbers}.jsonl')
    with open(out_path, 'w') as fp, Pool(num_workers) as pool:
        results = pool.imap_unordered(construct_helper, arg_list)
        for examples in tqdm.tqdm(results, total=len(arg_list)):
            for ex in examples:
                print(json.dumps(ex), file=fp)
|
||||
|
||||
|
||||
def construct_helper_for_v4(args_queue, out_queue):
    """Worker loop: turn queued number sets into per-target examples.

    Reads ``(numbers, game)`` tuples from ``args_queue`` until a ``None``
    sentinel arrives. For each tuple, random targets are sampled until more
    than 10 attempts fail to produce a new one, and one dict mapping
    ``target -> example`` is put on ``out_queue``.

    Args:
        args_queue: Queue of ``(numbers, game)`` tuples, ended by ``None``.
        out_queue: Queue receiving exactly one result dict per input tuple.
    """
    while True:
        args = args_queue.get()
        if args is None:
            break
        numbers, game = args[0], args[1]
        puzzle = ' '.join(str(n) for n in numbers)
        num_try = 0
        data = {}
        while num_try <= 10:
            target, operations = game.get_target_limit_range(numbers)
            # A (None, None) result means no valid target was found; count it
            # as a failed attempt instead of emitting a 'None' example.
            if target is None or target in data:
                num_try += 1
                continue
            data[target] = {
                'puzzle': puzzle,
                'target': str(target),
                'operations': operations
            }
        out_queue.put(data)
|
||||
|
||||
|
||||
def construct_game24_v4(num_numbers=3, range_max=101, num_samples=200000, target_max=1000, seed=1234, num_workers=64, output_dir=None):
    """Build a dataset with worker processes fed through queues.

    Up to ``num_samples`` unique number sets are queued for
    ``construct_helper_for_v4`` workers. Results are written as JSON lines
    to ``train_m=<num_numbers>.jsonl`` until at least ``num_samples``
    examples have been written or every queued task has reported back,
    whichever comes first.
    """
    print('Args', num_numbers, range_max, num_samples, target_max, seed, num_workers)
    seen_numbers = set()
    procs = []
    args_queue = Queue()
    out_queue = Queue()
    for _ in range(num_workers):
        p = Process(target=construct_helper_for_v4, args=(args_queue, out_queue))
        p.start()
        procs.append(p)

    try:
        game = Game24Plus(num_numbers, range_max, target_max=target_max, seed=seed)
        with open(os.path.join(output_dir, f'train_m={num_numbers}.jsonl'), 'w') as fp:
            num_tasks = 0
            for _ in range(num_samples):
                numbers = game.get_numbers()
                puzzle = ' '.join(str(n) for n in numbers)
                if puzzle in seen_numbers:
                    continue
                seen_numbers.add(puzzle)
                args_queue.put((numbers, game))
                num_tasks += 1
            # One sentinel per worker so each loop can terminate.
            for _ in range(num_workers):
                args_queue.put(None)

            count = 0
            received = 0
            with tqdm.tqdm(total=num_samples) as tqdm_bar:
                # Each task produces exactly one result dict. Stop once all of
                # them have arrived; waiting for more would block forever when
                # the tasks yield fewer than num_samples examples in total.
                while count < num_samples and received < num_tasks:
                    data = out_queue.get()
                    received += 1
                    for ex in data.values():
                        print(json.dumps(ex), file=fp)
                        count += 1
                        tqdm_bar.update(1)
    finally:
        # Always stop the workers, including on error or early completion.
        for p in procs:
            p.kill()
        for p in procs:
            p.join()
        # Unsent tasks may still sit in the queue's buffer; they are no longer
        # needed, so do not wait for them to be flushed at exit.
        args_queue.cancel_join_thread()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Initialise a Game24Plus instance.
    game = Game24Plus(num_numbers=6, range_max=10, target_max=24, seed=46000767)

    # Draw a random set of numbers.
    numbers = game.get_numbers()
    print("Numbers:", numbers)

    # Set the target value.
    target = 24
    print("Target:", target)

    trace = game.solve(numbers, target)
    print("Trace:", trace)

    # Build the single-expression answer only when a solution exists, so the
    # "No expression found." branch is reached instead of an exception.
    expression = game.generate_ground_truth(numbers, target) if trace is not None else None
    if expression:
        print("expression:", expression)
    else:
        print("No expression found.")
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue