InternBootcamp/internbootcamp/bootcamp/cryoukosmemorynote/cryoukosmemorynote.py
2025-05-23 15:27:15 +08:00

823 lines
23 KiB
Python
Executable file
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""#
### 谜题描述
Ryouko is an extremely forgetful girl, she could even forget something that has just happened. So in order to remember, she takes a notebook with her, called Ryouko's Memory Note. She writes what she sees and what she hears on the notebook, and the notebook became her memory.
Though Ryouko is forgetful, she is also born with superb analyzing abilities. However, analyzing depends greatly on gathered information, in other words, memory. So she has to shuffle through her notebook whenever she needs to analyze, which is tough work.
Ryouko's notebook consists of n pages, numbered from 1 to n. To make life (and this problem) easier, we consider that to turn from page x to page y, |x - y| pages should be turned. During analyzing, Ryouko needs m pieces of information, the i-th piece of information is on page ai. Information must be read from the notebook in order, so the total number of pages that Ryouko needs to turn is <image>.
Ryouko wants to decrease the number of pages that need to be turned. In order to achieve this, she can merge two pages of her notebook. If Ryouko merges page x to page y, she would copy all the information on page x to y (1 ≤ x, y ≤ n), and consequently, all elements in sequence a that was x would become y. Note that x can be equal to y, in which case no changes take place.
Please tell Ryouko the minimum number of pages that she needs to turn. Note she can apply the described operation at most once before the reading. Note that the answer can exceed 32-bit integers.
Input
The first line of input contains two integers n and m (1 ≤ n, m ≤ 105).
The next line contains m integers separated by spaces: a1, a2, ..., am (1 ≤ ai ≤ n).
Output
Print a single integer — the minimum number of pages Ryouko needs to turn.
Examples
Input
4 6
1 2 3 4 3 2
Output
3
Input
10 5
9 4 3 8 8
Output
6
Note
In the first sample, the optimal solution is to merge page 4 to 3, after merging sequence a becomes {1, 2, 3, 3, 3, 2}, so the number of pages Ryouko needs to turn is |1 - 2| + |2 - 3| + |3 - 3| + |3 - 3| + |3 - 2| = 3.
In the second sample, optimal solution is achieved by merging page 9 to 4.
Here is a reference code to solve this task. You can use this to help you genereate cases or validate the solution.
```python
# Enter your code here. Read input from STDIN. Print output to STDOUT# ===============================================================================================
# importing some useful libraries.
from __future__ import division, print_function
from fractions import Fraction
import sys
import os
from io import BytesIO, IOBase
from itertools import *
import bisect
from heapq import *
from math import ceil, floor
from copy import *
from collections import deque, defaultdict
from collections import Counter as counter # Counter(list) return a dict with {key: count}
from itertools import combinations # if a = [1,2,3] then print(list(comb(a,2))) -----> [(1, 2), (1, 3), (2, 3)]
from itertools import permutations as permutate
from bisect import bisect_left as bl
from operator import *
# If the element is already present in the list,
# the left most position where element has to be inserted is returned.
from bisect import bisect_right as br
from bisect import bisect
# If the element is already present in the list,
# the right most position where element has to be inserted is returned
# ==============================================================================================
# fast I/O region
BUFSIZE = 8192
from sys import stderr
class FastIO(IOBase):
newlines = 0
def __init__(self, file):
self._fd = file.fileno()
self.buffer = BytesIO()
self.writable = \"x\" in file.mode or \"r\" not in file.mode
self.write = self.buffer.write if self.writable else None
def read(self):
while True:
b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE))
if not b:
break
ptr = self.buffer.tell()
self.buffer.seek(0, 2), self.buffer.write(b), self.buffer.seek(ptr)
self.newlines = 0
return self.buffer.read()
def readline(self):
while self.newlines == 0:
b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE))
self.newlines = b.count(b\"\n\") + (not b)
ptr = self.buffer.tell()
self.buffer.seek(0, 2), self.buffer.write(b), self.buffer.seek(ptr)
self.newlines -= 1
return self.buffer.readline()
def flush(self):
if self.writable:
os.write(self._fd, self.buffer.getvalue())
self.buffer.truncate(0), self.buffer.seek(0)
class IOWrapper(IOBase):
def __init__(self, file):
self.buffer = FastIO(file)
self.flush = self.buffer.flush
self.writable = self.buffer.writable
self.write = lambda s: self.buffer.write(s.encode(\"ascii\"))
self.read = lambda: self.buffer.read().decode(\"ascii\")
self.readline = lambda: self.buffer.readline().decode(\"ascii\")
def print(*args, **kwargs):
\"\"\"Prints the values to a stream, or to sys.stdout by default.\"\"\"
sep, file = kwargs.pop(\"sep\", \" \"), kwargs.pop(\"file\", sys.stdout)
at_start = True
for x in args:
if not at_start:
file.write(sep)
file.write(str(x))
at_start = False
file.write(kwargs.pop(\"end\", \"\n\"))
if kwargs.pop(\"flush\", False):
file.flush()
if sys.version_info[0] < 3:
sys.stdin, sys.stdout = FastIO(sys.stdin), FastIO(sys.stdout)
else:
sys.stdin, sys.stdout = IOWrapper(sys.stdin), IOWrapper(sys.stdout)
# inp = lambda: sys.stdin.readline().rstrip(\"\r\n\")
# ===============================================================================================
### START ITERATE RECURSION ###
from types import GeneratorType
def iterative(f, stack=[]):
def wrapped_func(*args, **kwargs):
if stack: return f(*args, **kwargs)
to = f(*args, **kwargs)
while True:
if type(to) is GeneratorType:
stack.append(to)
to = next(to)
continue
stack.pop()
if not stack: break
to = stack[-1].send(to)
return to
return wrapped_func
#### END ITERATE RECURSION ####
###########################
# Sorted list
class SortedList:
def __init__(self, iterable=[], _load=200):
\"\"\"Initialize sorted list instance.\"\"\"
values = sorted(iterable)
self._len = _len = len(values)
self._load = _load
self._lists = _lists = [values[i:i + _load] for i in range(0, _len, _load)]
self._list_lens = [len(_list) for _list in _lists]
self._mins = [_list[0] for _list in _lists]
self._fen_tree = []
self._rebuild = True
def _fen_build(self):
\"\"\"Build a fenwick tree instance.\"\"\"
self._fen_tree[:] = self._list_lens
_fen_tree = self._fen_tree
for i in range(len(_fen_tree)):
if i | i + 1 < len(_fen_tree):
_fen_tree[i | i + 1] += _fen_tree[i]
self._rebuild = False
def _fen_update(self, index, value):
\"\"\"Update `fen_tree[index] += value`.\"\"\"
if not self._rebuild:
_fen_tree = self._fen_tree
while index < len(_fen_tree):
_fen_tree[index] += value
index |= index + 1
def _fen_query(self, end):
\"\"\"Return `sum(_fen_tree[:end])`.\"\"\"
if self._rebuild:
self._fen_build()
_fen_tree = self._fen_tree
x = 0
while end:
x += _fen_tree[end - 1]
end &= end - 1
return x
def _fen_findkth(self, k):
\"\"\"Return a pair of (the largest `idx` such that `sum(_fen_tree[:idx]) <= k`, `k - sum(_fen_tree[:idx])`).\"\"\"
_list_lens = self._list_lens
if k < _list_lens[0]:
return 0, k
if k >= self._len - _list_lens[-1]:
return len(_list_lens) - 1, k + _list_lens[-1] - self._len
if self._rebuild:
self._fen_build()
_fen_tree = self._fen_tree
idx = -1
for d in reversed(range(len(_fen_tree).bit_length())):
right_idx = idx + (1 << d)
if right_idx < len(_fen_tree) and k >= _fen_tree[right_idx]:
idx = right_idx
k -= _fen_tree[idx]
return idx + 1, k
def _delete(self, pos, idx):
\"\"\"Delete value at the given `(pos, idx)`.\"\"\"
_lists = self._lists
_mins = self._mins
_list_lens = self._list_lens
self._len -= 1
self._fen_update(pos, -1)
del _lists[pos][idx]
_list_lens[pos] -= 1
if _list_lens[pos]:
_mins[pos] = _lists[pos][0]
else:
del _lists[pos]
del _list_lens[pos]
del _mins[pos]
self._rebuild = True
def _loc_left(self, value):
\"\"\"Return an index pair that corresponds to the first position of `value` in the sorted list.\"\"\"
if not self._len:
return 0, 0
_lists = self._lists
_mins = self._mins
lo, pos = -1, len(_lists) - 1
while lo + 1 < pos:
mi = (lo + pos) >> 1
if value <= _mins[mi]:
pos = mi
else:
lo = mi
if pos and value <= _lists[pos - 1][-1]:
pos -= 1
_list = _lists[pos]
lo, idx = -1, len(_list)
while lo + 1 < idx:
mi = (lo + idx) >> 1
if value <= _list[mi]:
idx = mi
else:
lo = mi
return pos, idx
def _loc_right(self, value):
\"\"\"Return an index pair that corresponds to the last position of `value` in the sorted list.\"\"\"
if not self._len:
return 0, 0
_lists = self._lists
_mins = self._mins
pos, hi = 0, len(_lists)
while pos + 1 < hi:
mi = (pos + hi) >> 1
if value < _mins[mi]:
hi = mi
else:
pos = mi
_list = _lists[pos]
lo, idx = -1, len(_list)
while lo + 1 < idx:
mi = (lo + idx) >> 1
if value < _list[mi]:
idx = mi
else:
lo = mi
return pos, idx
def add(self, value):
\"\"\"Add `value` to sorted list.\"\"\"
_load = self._load
_lists = self._lists
_mins = self._mins
_list_lens = self._list_lens
self._len += 1
if _lists:
pos, idx = self._loc_right(value)
self._fen_update(pos, 1)
_list = _lists[pos]
_list.insert(idx, value)
_list_lens[pos] += 1
_mins[pos] = _list[0]
if _load + _load < len(_list):
_lists.insert(pos + 1, _list[_load:])
_list_lens.insert(pos + 1, len(_list) - _load)
_mins.insert(pos + 1, _list[_load])
_list_lens[pos] = _load
del _list[_load:]
self._rebuild = True
else:
_lists.append([value])
_mins.append(value)
_list_lens.append(1)
self._rebuild = True
def discard(self, value):
\"\"\"Remove `value` from sorted list if it is a member.\"\"\"
_lists = self._lists
if _lists:
pos, idx = self._loc_right(value)
if idx and _lists[pos][idx - 1] == value:
self._delete(pos, idx - 1)
def remove(self, value):
\"\"\"Remove `value` from sorted list; `value` must be a member.\"\"\"
_len = self._len
self.discard(value)
if _len == self._len:
raise ValueError('{0!r} not in list'.format(value))
def pop(self, index=-1):
\"\"\"Remove and return value at `index` in sorted list.\"\"\"
pos, idx = self._fen_findkth(self._len + index if index < 0 else index)
value = self._lists[pos][idx]
self._delete(pos, idx)
return value
def bisect_left(self, value):
\"\"\"Return the first index to insert `value` in the sorted list.\"\"\"
pos, idx = self._loc_left(value)
return self._fen_query(pos) + idx
def bisect_right(self, value):
\"\"\"Return the last index to insert `value` in the sorted list.\"\"\"
pos, idx = self._loc_right(value)
return self._fen_query(pos) + idx
def count(self, value):
\"\"\"Return number of occurrences of `value` in the sorted list.\"\"\"
return self.bisect_right(value) - self.bisect_left(value)
def __len__(self):
\"\"\"Return the size of the sorted list.\"\"\"
return self._len
def __getitem__(self, index):
\"\"\"Lookup value at `index` in sorted list.\"\"\"
pos, idx = self._fen_findkth(self._len + index if index < 0 else index)
return self._lists[pos][idx]
def __delitem__(self, index):
\"\"\"Remove value at `index` from sorted list.\"\"\"
pos, idx = self._fen_findkth(self._len + index if index < 0 else index)
self._delete(pos, idx)
def __contains__(self, value):
\"\"\"Return true if `value` is an element of the sorted list.\"\"\"
_lists = self._lists
if _lists:
pos, idx = self._loc_left(value)
return idx < len(_lists[pos]) and _lists[pos][idx] == value
return False
def __iter__(self):
\"\"\"Return an iterator over the sorted list.\"\"\"
return (value for _list in self._lists for value in _list)
def __reversed__(self):
\"\"\"Return a reverse iterator over the sorted list.\"\"\"
return (value for _list in reversed(self._lists) for value in reversed(_list))
def __repr__(self):
\"\"\"Return string representation of sorted list.\"\"\"
return 'SortedList({0})'.format(list(self))
# ===============================================================================================
# some shortcuts
mod = 1000000007
def testcase(t):
for p in range(t):
solve()
def pow(x, y, p):
res = 1 # Initialize result
x = x % p # Update x if it is more , than or equal to p
if (x == 0):
return 0
while (y > 0):
if ((y & 1) == 1): # If y is odd, multiply, x with result
res = (res * x) % p
y = y >> 1 # y = y/2
x = (x * x) % p
return res
from functools import reduce
def factors(n):
return set(reduce(list.__add__,
([i, n // i] for i in range(1, int(n ** 0.5) + 1) if n % i == 0)))
def gcd(a, b):
if a == b: return a
while b > 0: a, b = b, a % b
return a
# discrete binary search
# minimise:
# def search():
# l = 0
# r = 10 ** 15
#
# for i in range(200):
# if isvalid(l):
# return l
# if l == r:
# return l
# m = (l + r) // 2
# if isvalid(m) and not isvalid(m - 1):
# return m
# if isvalid(m):
# r = m + 1
# else:
# l = m
# return m
# maximise:
# def search():
# l = 0
# r = 10 ** 15
#
# for i in range(200):
# # print(l,r)
# if isvalid(r):
# return r
# if l == r:
# return l
# m = (l + r) // 2
# if isvalid(m) and not isvalid(m + 1):
# return m
# if isvalid(m):
# l = m
# else:
# r = m - 1
# return m
##############Find sum of product of subsets of size k in a array
# ar=[0,1,2,3]
# k=3
# n=len(ar)-1
# dp=[0]*(n+1)
# dp[0]=1
# for pos in range(1,n+1):
# dp[pos]=0
# l=max(1,k+pos-n-1)
# for j in range(min(pos,k),l-1,-1):
# dp[j]=dp[j]+ar[pos]*dp[j-1]
# print(dp[k])
def prefix_sum(ar): # [1,2,3,4]->[1,3,6,10]
return list(accumulate(ar))
def suffix_sum(ar): # [1,2,3,4]->[10,9,7,4]
return list(accumulate(ar[::-1]))[::-1]
def N():
return int(inp())
dx = [0, 0, 1, -1]
dy = [1, -1, 0, 0]
def YES():
print(\"YES\")
def NO():
print(\"NO\")
def Yes():
print(\"Yes\")
def No():
print(\"No\")
# =========================================================================================
from collections import defaultdict
def numberOfSetBits(i):
i = i - ((i >> 1) & 0x55555555)
i = (i & 0x33333333) + ((i >> 2) & 0x33333333)
return (((i + (i >> 4) & 0xF0F0F0F) * 0x1010101) & 0xffffffff) >> 24
class MergeFind:
def __init__(self, n):
self.parent = list(range(n))
self.size = [1] * n
self.num_sets = n
# self.lista = [[_] for _ in range(n)]
def find(self, a):
to_update = []
while a != self.parent[a]:
to_update.append(a)
a = self.parent[a]
for b in to_update:
self.parent[b] = a
return self.parent[a]
def merge(self, a, b):
a = self.find(a)
b = self.find(b)
if a == b:
return
if self.size[a] < self.size[b]:
a, b = b, a
self.num_sets -= 1
self.parent[b] = a
self.size[a] += self.size[b]
# self.lista[a] += self.lista[b]
# self.lista[b] = []
def set_size(self, a):
return self.size[self.find(a)]
def __len__(self):
return self.num_sets
def lcm(a, b):
return abs((a // gcd(a, b)) * b)
# #
# to find factorial and ncr
# tot = 100005
# mod = 10**9 + 7
# fac = [1, 1]
# finv = [1, 1]
# inv = [0, 1]
#
# for i in range(2, tot + 1):
# fac.append((fac[-1] * i) % mod)
# inv.append(mod - (inv[mod % i] * (mod // i) % mod))
# finv.append(finv[-1] * inv[-1] % mod)
def comb(n, r):
if n < r:
return 0
else:
return fac[n] * (finv[r] * finv[n - r] % mod) % mod
def inp(): return sys.stdin.readline().rstrip(\"\r\n\") # for fast input
def out(var): sys.stdout.write(str(var)) # for fast output, always take string
def lis(): return list(map(int, inp().split()))
def stringlis(): return list(map(str, inp().split()))
def sep(): return map(int, inp().split())
def strsep(): return map(str, inp().split())
def fsep(): return map(float, inp().split())
def nextline(): out(\"\n\") # as stdout.write always print sring.
def arr1d(n, v):
return [v] * n
def arr2d(n, m, v):
return [[v] * m for _ in range(n)]
def arr3d(n, m, p, v):
return [[[v] * p for _ in range(m)] for i in range(n)]
def ceil(a, b):
return (a + b - 1) // b
# co-ordinate compression
# ma={s:idx for idx,s in enumerate(sorted(set(l+r)))}
# mxn=100005
# lrg=[0]*mxn
# for i in range(2,mxn-3):
# if (lrg[i]==0):
# for j in range(i,mxn-3,i):
# lrg[j]=i
def solve():
n,m=sep()
ar=lis()
adj=[[] for _ in range(n+1)]
for i in range(m):
if(i-1>=0):
if(ar[i]!=ar[i-1]):
adj[ar[i]].append(ar[i-1])
if (i + 1 < m):
if (ar[i] != ar[i + 1]):
adj[ar[i]].append(ar[i + 1])
totscore=0
for i in range(1,m):
totscore+=abs(ar[i]-ar[i-1])
redscore=0
for i in range(1,n+1):
adj[i].sort()
curscore=0
l=(len(adj[i]))
if l==0:continue
med=adj[i][(l)//2]
besscore=0
for j in adj[i]:
curscore+=abs(i-j)
besscore+=abs(med-j)
redscore=max(redscore,curscore-besscore)
print(min(totscore-redscore,totscore))
solve()
# testcase(int(inp()))
```
请完成上述谜题的训练场环境类实现,包括所有必要的方法。
"""
from bootcamp import Basebootcamp
import random
import re
from typing import Dict, Any
class Cryoukosmemorynotebootcamp(Basebootcamp):
def __init__(self, max_n=20, max_m=20):
"""
初始化训练场参数默认最大页面数为20最大序列长度为20。
"""
self.max_n = max_n
self.max_m = max_m
def case_generator(self) -> Dict[str, Any]:
"""
生成谜题实例包含n, m和页面序列a。
"""
n = random.randint(1, self.max_n)
m = random.randint(1, self.max_m)
a = [random.randint(1, n) for _ in range(m)]
return {
'n': n,
'm': m,
'a': a
}
@staticmethod
def prompt_func(question_case) -> str:
"""
将谜题实例转换为详细的自然语言问题描述。
"""
n = question_case['n']
m = question_case['m']
a = ' '.join(map(str, question_case['a']))
problem_text = f"""你是Ryouko需要解决一个关于记忆笔记本页面优化的问题。你的笔记本共有{n}编号从1到{n}。你需要按顺序查阅以下页面序列(共{m}次):{a}
每次翻页的代价是当前页与目标页的绝对差值。你最多可以执行一次合并操作选择一个页面x合并到页面y使得所有在序列中的x都会被替换为y。请计算经过最优合并后最小的总翻页代价。
例如,当输入为:
4 6
1 2 3 4 3 2
合并页面4到3后总代价为3因此最终答案为3。
你的任务是解决以下具体案例:
- 笔记本总页数n = {n}
- 序列长度m = {m}
- 访问序列a = {a}
请将最终答案的整数值放置在[answer]和[/answer]标签之间。例如:[answer]42[/answer]"""
return problem_text
@staticmethod
def extract_output(output: str):
"""
从模型输出中提取最后一个[answer]标签包裹的整数。
"""
matches = re.findall(r'\[answer\](.*?)\[/answer\]', output, re.DOTALL)
if not matches:
return None
last_match = matches[-1].strip()
try:
return int(last_match)
except ValueError:
return None
@classmethod
def _verify_correction(cls, solution: int, identity: Dict[str, Any]) -> bool:
"""
验证答案是否符合计算出的最小翻页数。
"""
try:
# 计算正确答案
n = identity['n']
m = identity['m']
a = identity['a']
correct = cls.compute_min_turns(n, m, a)
return solution == correct
except:
return False
@staticmethod
def compute_min_turns(n: int, m: int, a: list) -> int:
"""
计算给定案例的最小翻页数。
"""
if m <= 1:
return 0
original_cost = sum(abs(a[i] - a[i-1]) for i in range(1, m))
adj = [[] for _ in range(n+1)] # 邻接关系存储
# 构建邻接关系
for i in range(m):
current = a[i]
if i > 0 and a[i-1] != current:
adj[current].append(a[i-1])
if i < m-1 and a[i+1] != current:
adj[current].append(a[i+1])
max_reduction = 0
for page in range(1, n+1):
neighbors = adj[page]
if not neighbors:
continue
# 计算原始总代价
original_sum = sum(abs(page - x) for x in neighbors)
# 计算最优合并后的代价
sorted_neighbors = sorted(neighbors)
median_index = len(sorted_neighbors) // 2
median = sorted_neighbors[median_index]
optimized_sum = sum(abs(median - x) for x in sorted_neighbors)
# 更新最大减少量
current_reduction = original_sum - optimized_sum
if current_reduction > max_reduction:
max_reduction = current_reduction
return original_cost - max_reduction