mirror of
https://github.com/InternLM/InternBootcamp.git
synced 2026-04-27 17:23:17 +00:00
390 lines
11 KiB
Python
Executable file
390 lines
11 KiB
Python
Executable file
"""#
|
|
|
|
### 谜题描述
|
|
You are given a tree consisting of n nodes. You generate an array from the tree by marking nodes one by one.
|
|
|
|
Initially, when no nodes are marked, a node is equiprobably chosen and marked from the entire tree.
|
|
|
|
After that, until all nodes are marked, a node is equiprobably chosen and marked from the set of unmarked nodes with at least one edge to a marked node.
|
|
|
|
It can be shown that the process marks all nodes in the tree.
|
|
|
|
The final array a is the list of the nodes' labels in order of the time each node was marked.
|
|
|
|
Find the expected number of inversions in the array that is generated by the tree and the aforementioned process.
|
|
|
|
The number of inversions in an array a is the number of pairs of indices (i, j) such that i < j and a_i > a_j. For example, the array [4, 1, 3, 2] contains 4 inversions: (1, 2), (1, 3), (1, 4), (3, 4).
|
|
|
|
Input
|
|
|
|
The first line contains a single integer n (2 ≤ n ≤ 200) — the number of nodes in the tree.
|
|
|
|
The next n - 1 lines each contains two integers x and y (1 ≤ x, y ≤ n; x ≠ y), denoting an edge between node x and y.
|
|
|
|
It's guaranteed that the given edges form a tree.
|
|
|
|
Output
|
|
|
|
Output the expected number of inversions in the generated array modulo 10^9+7.
|
|
|
|
Formally, let M = 10^9+7. It can be shown that the answer can be expressed as an irreducible fraction p/q, where p and q are integers and q not ≡ 0 \pmod{M}. Output the integer equal to p ⋅ q^{-1} mod M. In other words, output such an integer x that 0 ≤ x < M and x ⋅ q ≡ p \pmod{M}.
|
|
|
|
Examples
|
|
|
|
Input
|
|
|
|
|
|
3
|
|
1 2
|
|
1 3
|
|
|
|
|
|
Output
|
|
|
|
|
|
166666669
|
|
|
|
|
|
Input
|
|
|
|
|
|
6
|
|
2 1
|
|
2 3
|
|
6 1
|
|
1 4
|
|
2 5
|
|
|
|
|
|
Output
|
|
|
|
|
|
500000009
|
|
|
|
|
|
Input
|
|
|
|
|
|
5
|
|
1 2
|
|
1 3
|
|
1 4
|
|
2 5
|
|
|
|
|
|
Output
|
|
|
|
|
|
500000007
|
|
|
|
Note
|
|
|
|
This is the tree from the first sample:
|
|
|
|
<image>
|
|
|
|
For the first sample, the arrays are almost fixed. If node 2 is chosen initially, then the only possible array is [2, 1, 3] (1 inversion). If node 3 is chosen initially, then the only possible array is [3, 1, 2] (2 inversions). If node 1 is chosen initially, the arrays [1, 2, 3] (0 inversions) and [1, 3, 2] (1 inversion) are the only possibilities and equiprobable. In total, the expected number of inversions is 1/3⋅ 1 + 1/3 ⋅ 2 + 1/3 ⋅ (1/2 ⋅ 0 + 1/2 ⋅ 1) = 7/6.
|
|
|
|
166666669 ⋅ 6 = 7 \pmod {10^9 + 7}, so the answer is 166666669.
|
|
|
|
This is the tree from the second sample:
|
|
|
|
<image>
|
|
|
|
This is the tree from the third sample:
|
|
|
|
<image>
|
|
|
|
Here is a reference code to solve this task. You can use this to help you genereate cases or validate the solution.
|
|
```python
|
|
import sys
|
|
|
|
f = sys.stdin
|
|
|
|
def line():
|
|
return f.readline().strip().split()
|
|
|
|
class RangeQuery:
|
|
def __init__(self,array,func,func_eq):
|
|
n = len(array)
|
|
self.array = array
|
|
self.func = func
|
|
self.func_eq = func_eq
|
|
self.row = n.bit_length()
|
|
self.tab = [[0 for _ in range(n)] for _ in range(self.row)]
|
|
|
|
self.tab[0] = list(range(n))
|
|
for e in range(1, self.row):
|
|
for j in range(n - 2**e + 1):
|
|
self.tab[e][j] = self._apply_func(self.tab[e-1][j], self.tab[e-1][j + 2**(e-1)])
|
|
|
|
def _apply_func(self,l,r):
|
|
val_l = self.array[l]
|
|
val_r = self.array[r]
|
|
if val_l == val_r:
|
|
return self.func_eq(l,r)
|
|
return l if self.func(val_l,val_r) == val_l else r
|
|
|
|
''' l,r inclusive '''
|
|
def index(self,l,r):
|
|
e = 0 if r-l == 0 else (r-l+1).bit_length() - 1
|
|
return self._apply_func(self.tab[e][l], self.tab[e][r - 2**e + 1])
|
|
|
|
''' l,r inclusive '''
|
|
def query(self,l,r):
|
|
return self.array[self.index(l, r)]
|
|
|
|
|
|
class LCA_RMQ:
|
|
def __init__(self, g, root):
|
|
sz = len(g)
|
|
self.pop_idx = [None]*sz
|
|
|
|
trace = []
|
|
seen = [-1]*sz
|
|
par = [None]*sz
|
|
|
|
q = [(root,0)] # node,level
|
|
seen[root] = 0
|
|
|
|
while q:
|
|
n,lvl = q[-1]
|
|
|
|
ch = g[n]
|
|
if seen[n] == len(ch):
|
|
q.pop()
|
|
self.pop_idx[n] = len(trace)
|
|
else:
|
|
c = ch[seen[n]]
|
|
if seen[c] == -1:
|
|
seen[c] = 0
|
|
par[c] = n
|
|
q.append((c, lvl+1))
|
|
seen[n] += 1
|
|
if c == par[n]:
|
|
continue
|
|
|
|
trace.append((lvl,n)) # level,node
|
|
|
|
self.rmq = RangeQuery(trace, min, min)
|
|
|
|
def query(self, u, v):
|
|
idx_u = self.pop_idx[u]
|
|
idx_v = self.pop_idx[v]
|
|
return self.rmq.query(min(idx_u,idx_v), max(idx_u,idx_v))
|
|
|
|
|
|
def solve():
|
|
w = 0
|
|
for root in range(N):
|
|
lca = LCA_RMQ(G,root)
|
|
for v in range(N):
|
|
for u in range(v+1,N):
|
|
lvl = lca.query(u, v)[0]
|
|
lvl_u = lca.query(u, u)[0]
|
|
lvl_v = lca.query(v, v)[0]
|
|
w += DP[lvl_u-lvl][lvl_v-lvl]
|
|
w %= M
|
|
|
|
w *= INV_M[N]
|
|
w %= M
|
|
|
|
return str(w)
|
|
|
|
|
|
def inv_modulo_all(limit):
|
|
mod_inv = [0,1]
|
|
for i in range(2, limit+1):
|
|
mod_inv.append(M - (M // i) * mod_inv[M % i] % M)
|
|
return mod_inv
|
|
|
|
Z = 201
|
|
M = 10**9 + 7
|
|
INV_M = inv_modulo_all(Z)
|
|
DP = [[0 for _ in range(Z)] for _ in range(Z)]
|
|
for j in range(Z):
|
|
DP[0][j] = 1
|
|
for i in range(1,Z):
|
|
for j in range(1,Z):
|
|
DP[i][j] = (DP[i-1][j] + DP[i][j-1]) % M
|
|
DP[i][j] *= INV_M[2]
|
|
DP[i][j] %= M
|
|
|
|
T = 1
|
|
for test in range(1,T+1):
|
|
N = int(line()[0])
|
|
G = [[] for _ in range(N)]
|
|
for _ in range(N-1):
|
|
u,v = map(int, line())
|
|
G[u-1].append(v-1)
|
|
G[v-1].append(u-1)
|
|
|
|
sys.stdout.write(solve())
|
|
sys.stdout.write(\"\n\")
|
|
|
|
f.close()
|
|
```
|
|
|
|
|
|
请完成上述谜题的训练场环境类实现,包括所有必要的方法。
|
|
"""
|
|
|
|
from bootcamp import Basebootcamp
|
|
import re
|
|
import random
|
|
from bootcamp import Basebootcamp
|
|
|
|
# 预处理部分,参考原代码中的全局变量
|
|
MOD = 10**9 + 7
|
|
Z = 201
|
|
|
|
def inv_modulo_all(limit):
|
|
mod_inv = [0, 1]
|
|
for i in range(2, limit + 1):
|
|
mod_inv.append(MOD - (MOD // i) * mod_inv[MOD % i] % MOD)
|
|
return mod_inv
|
|
|
|
INV_M = inv_modulo_all(Z)
|
|
|
|
DP = [[0] * Z for _ in range(Z)]
|
|
for j in range(Z):
|
|
DP[0][j] = 1
|
|
for i in range(1, Z):
|
|
for j in range(1, Z):
|
|
DP[i][j] = (DP[i-1][j] + DP[i][j-1]) % MOD
|
|
DP[i][j] = (DP[i][j] * INV_M[2]) % MOD
|
|
|
|
class RangeQuery:
|
|
def __init__(self, array, func, func_eq):
|
|
n = len(array)
|
|
self.array = array
|
|
self.func = func
|
|
self.func_eq = func_eq
|
|
self.row = n.bit_length()
|
|
self.tab = [[0] * n for _ in range(self.row)]
|
|
self.tab[0] = list(range(n))
|
|
for e in range(1, self.row):
|
|
for j in range(n - (1 << e) + 1):
|
|
self.tab[e][j] = self._apply_func(self.tab[e-1][j], self.tab[e-1][j + (1 << (e-1))])
|
|
|
|
def _apply_func(self, l, r):
|
|
val_l = self.array[l]
|
|
val_r = self.array[r]
|
|
if val_l == val_r:
|
|
return self.func_eq(l, r)
|
|
return l if self.func(val_l, val_r) == val_l else r
|
|
|
|
def index(self, l, r):
|
|
if l > r:
|
|
l, r = r, l
|
|
e = (r - l + 1).bit_length() - 1
|
|
return self._apply_func(self.tab[e][l], self.tab[e][r - (1 << e) + 1])
|
|
|
|
def query(self, l, r):
|
|
return self.array[self.index(l, r)]
|
|
|
|
class LCA_RMQ:
|
|
def __init__(self, g, root):
|
|
sz = len(g)
|
|
self.pop_idx = [None] * sz
|
|
trace = []
|
|
seen = [-1] * sz
|
|
par = [None] * sz
|
|
q = [(root, 0)]
|
|
seen[root] = 0
|
|
while q:
|
|
n, lvl = q[-1]
|
|
if seen[n] < len(g[n]):
|
|
c = g[n][seen[n]]
|
|
if seen[c] == -1:
|
|
seen[c] = 0
|
|
par[c] = n
|
|
q.append((c, lvl + 1))
|
|
seen[n] += 1
|
|
if c == par[n]:
|
|
continue
|
|
else:
|
|
q.pop()
|
|
self.pop_idx[n] = len(trace)
|
|
trace.append((lvl, n))
|
|
self.rmq = RangeQuery(trace, lambda a, b: a if a < b else b, lambda a, b: min(a, b))
|
|
|
|
def query(self, u, v):
|
|
idx_u = self.pop_idx[u]
|
|
idx_v = self.pop_idx[v]
|
|
return self.rmq.query(min(idx_u, idx_v), max(idx_u, idx_v))
|
|
|
|
class Dtreearraybootcamp(Basebootcamp):
|
|
def __init__(self, n_min=2, n_max=200, **params):
|
|
super().__init__(**params)
|
|
self.n_min = n_min
|
|
self.n_max = n_max
|
|
|
|
def case_generator(self):
|
|
n = random.randint(self.n_min, self.n_max)
|
|
nodes = list(range(1, n + 1))
|
|
random.shuffle(nodes)
|
|
edges = []
|
|
connected = {nodes[0]}
|
|
for i in range(1, n):
|
|
node = nodes[i]
|
|
parent = random.choice(list(connected))
|
|
edges.append((parent, node))
|
|
connected.add(node)
|
|
random.shuffle(edges)
|
|
shuffled_edges = []
|
|
for u, v in edges:
|
|
if random.choice([True, False]):
|
|
shuffled_edges.append([u, v])
|
|
else:
|
|
shuffled_edges.append([v, u])
|
|
return {
|
|
'n': n,
|
|
'edges': shuffled_edges
|
|
}
|
|
|
|
@staticmethod
|
|
def prompt_func(question_case):
|
|
n = question_case['n']
|
|
edges = question_case['edges']
|
|
edges_str = '\n'.join([f"{u} {v}" for u, v in edges])
|
|
return f"""You are given a tree with {n} nodes. The edges are:
|
|
{edges_str}
|
|
|
|
Calculate the expected number of inversions in the array generated by the marking process. Initially, a node is chosen uniformly at random. Each subsequent node is chosen uniformly at random from unmarked nodes adjacent to at least one marked node.
|
|
|
|
Output the expected value modulo 10^9+7. Put your answer within [answer] and [/answer], e.g., [answer]12345[/answer]."""
|
|
|
|
@staticmethod
|
|
def extract_output(output):
|
|
matches = re.findall(r'\[answer\](.*?)\[/answer\]', output, re.DOTALL)
|
|
return matches[-1].strip() if matches else None
|
|
|
|
@classmethod
|
|
def _verify_correction(cls, solution, identity):
|
|
try:
|
|
n = identity['n']
|
|
edges = identity['edges']
|
|
G = [[] for _ in range(n)]
|
|
for u, v in edges:
|
|
u_idx = u - 1
|
|
v_idx = v - 1
|
|
G[u_idx].append(v_idx)
|
|
G[v_idx].append(u_idx)
|
|
w = 0
|
|
for root in range(n):
|
|
lca = LCA_RMQ(G, root)
|
|
for v in range(n):
|
|
for u in range(v + 1, n):
|
|
lvl = lca.query(u, v)[0]
|
|
lvl_u = lca.query(u, u)[0]
|
|
lvl_v = lca.query(v, v)[0]
|
|
du = lvl_u - lvl
|
|
dv = lvl_v - lvl
|
|
w += DP[du][dv]
|
|
w %= MOD
|
|
w = w * INV_M[n] % MOD
|
|
correct = str(w)
|
|
return solution == correct
|
|
except:
|
|
return False
|