InternBootcamp/internbootcamp/bootcamp/cnetworksafety/cnetworksafety.py
2025-05-23 15:27:15 +08:00

254 lines
8.9 KiB
Python
Executable file

"""#
### 谜题描述
The Metropolis computer network consists of n servers, each has an encryption key in the range from 0 to 2^k - 1 assigned to it. Let c_i be the encryption key assigned to the i-th server. Additionally, m pairs of servers are directly connected via a data communication channel. Because of the encryption algorithms specifics, a data communication channel can only be considered safe if the two servers it connects have distinct encryption keys. The initial assignment of encryption keys is guaranteed to keep all data communication channels safe.
You have been informed that a new virus is actively spreading across the internet, and it is capable to change the encryption key of any server it infects. More specifically, the virus body contains some unknown number x in the same aforementioned range, and when server i is infected, its encryption key changes from c_i to c_i ⊕ x, where ⊕ denotes the [bitwise XOR operation](https://en.wikipedia.org/wiki/Bitwise_operation#XOR).
Sadly, you know neither the number x nor which servers of Metropolis are going to be infected by the dangerous virus, so you have decided to count the number of such situations in which all data communication channels remain safe. Formally speaking, you need to find the number of pairs (A, x), where A is some (possibly empty) subset of the set of servers and x is some number in the range from 0 to 2^k - 1, such that when all servers from the chosen subset A and none of the others are infected by a virus containing the number x, all data communication channels remain safe. Since this number can be quite big, you are asked to find its remainder modulo 10^9 + 7.
Input
The first line of input contains three integers n, m and k (1 ≤ n ≤ 500 000, 0 ≤ m ≤ min((n(n - 1))/(2), 500 000), 0 ≤ k ≤ 60) — the number of servers, the number of pairs of servers directly connected by a data communication channel, and the parameter k, which defines the range of possible values for encryption keys.
The next line contains n integers c_i (0 ≤ c_i ≤ 2^k - 1), the i-th of which is the encryption key used by the i-th server.
The next m lines contain two integers u_i and v_i each (1 ≤ u_i, v_i ≤ n, u_i ≠ v_i) denoting that those servers are connected by a data communication channel. It is guaranteed that each pair of servers appears in this list at most once.
Output
The only output line should contain a single integer — the number of safe infections of some subset of servers by a virus with some parameter, modulo 10^9 + 7.
Examples
Input
4 4 2
0 1 0 1
1 2
2 3
3 4
4 1
Output
50
Input
4 5 3
7 1 7 2
1 2
2 3
3 4
4 1
2 4
Output
96
Note
Consider the first example.
Possible values for the number x contained by the virus are 0, 1, 2 and 3.
For values 0, 2 and 3 the virus can infect any subset of the set of servers, which gives us 16 pairs for each values. A virus containing the number 1 can infect either all of the servers, or none. This gives us 16 + 2 + 16 + 16 = 50 pairs in total.
Here is a reference code to solve this task. You can use this to help you genereate cases or validate the solution.
```python
#include <bits/stdc++.h>
using namespace std;
map<long long, vector<pair<long long, long long>>> m1;
vector<long long> used;
vector<vector<long long>> g;
const long long inf = 1e9 + 7;
void dfs(long long v, long long timer) {
used[v] = timer;
for (auto u : g[v]) {
if (used[u] != timer) dfs(u, timer);
}
}
long long N = 500005;
signed main() {
ios_base::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
long long n, m, k;
cin >> n >> m >> k;
vector<long long> pwr(N);
pwr[0] = 1;
for (long long i = 1; i < N; ++i) {
pwr[i] = (2 * pwr[i - 1]) % inf;
}
vector<long long> a(n);
for (long long i = 0; i < n; ++i) {
cin >> a[i];
}
for (long long i = 0; i < m; ++i) {
long long u, v;
cin >> u >> v;
--u, --v;
m1[(a[u] ^ a[v])].push_back({u, v});
}
g.resize(n);
used.resize(n);
long long answer = 0, timer = 0;
for (auto x : m1) {
timer++;
set<long long> V;
const auto& cur = x.second;
for (auto u : cur) {
V.insert(u.first);
V.insert(u.second);
g[u.first].push_back(u.second);
g[u.second].push_back(u.first);
}
long long q = 0;
for (auto u : V) {
if (used[u] != timer) {
dfs(u, timer);
q++;
}
}
answer = (answer + pwr[n - V.size() + q]) % inf;
for (auto u : V) {
g[u].clear();
}
}
answer += (((pwr[k] - m1.size())) * pwr[n]) % inf;
answer = (answer % inf + inf) % inf;
cout << answer;
}
```
请完成上述谜题的训练场环境类实现,包括所有必要的方法。
"""
from bootcamp import Basebootcamp
import random
from collections import defaultdict
import re
MOD = 10**9 + 7
class Cnetworksafetybootcamp(Basebootcamp):
def __init__(self, max_n=6, max_k=3, **params):
self.max_n = params.get('max_n', max_n)
self.max_k = params.get('max_k', max_k)
self.max_m_attempts = params.get('max_m_attempts', 10)
def case_generator(self):
while True:
n = random.randint(1, self.max_n)
k = random.randint(0, self.max_k)
max_key = (1 << k) - 1 if k > 0 else 0
c = [random.randint(0, max_key) for _ in range(n)]
# Generate candidate edges with distinct keys
candidate_edges = set()
for u in range(n):
for v in range(u + 1, n):
if c[u] != c[v]:
candidate_edges.add((u + 1, v + 1)) # 1-based index
# Convert to sorted list for sampling
candidate_edges = sorted(candidate_edges)
m_max = min(len(candidate_edges), self.max_m_attempts)
m = random.randint(0, m_max) if candidate_edges else 0
# Handle edge cases for small n
selected_edges = []
if candidate_edges and m > 0:
selected_edges = random.sample(candidate_edges, m)
# Validate constraints
if m == 0 or len(selected_edges) == m:
break
return {
'n': n,
'm': m,
'k': k,
'c': c,
'edges': selected_edges
}
@staticmethod
def prompt_func(question_case):
case = question_case
edges_str = '\n'.join(f"{u} {v}" for u, v in case['edges'])
return (
"Cnetworksafety Network Security Problem\n\n"
"**Background**:\n"
"A virus can flip encryption keys using XOR with an unknown number x. "
"Find the number of safe (infected subset, x) pairs that keep all communication channels secure.\n\n"
"**Input Format**:\n"
f"- Line 1: {case['n']} {case['m']} {case['k']}\n"
f"- Line 2: {' '.join(map(str, case['c']))}\n"
f"- Next {case['m']} lines (connections):\n{edges_str}\n\n"
"**Requirements**:\n"
"Output the answer modulo 1e9+7\n"
"Put your final answer within [answer]...[/answer] tags."
)
@staticmethod
def extract_output(output):
answers = re.findall(r'\[answer\](.*?)\[/answer\]', output, re.DOTALL)
return answers[-1].strip() if answers else None
@classmethod
def _verify_correction(cls, solution, identity):
try:
user_ans = int(solution.strip()) % MOD
except:
return False
# Reference algorithm implementation
n, m, k = identity['n'], identity['m'], identity['k']
c = identity['c']
edges = [(u-1, v-1) for u, v in identity['edges']] # Convert to 0-based
d_map = defaultdict(list)
for u, v in edges:
d = c[u] ^ c[v]
d_map[d].append((u, v))
result = 0
# Process each xor value
for d, edges in d_map.items():
adj = [[] for _ in range(n)]
nodes = set()
for u, v in edges:
adj[u].append(v)
adj[v].append(u)
nodes.add(u)
nodes.add(v)
visited = [False] * n
components = 0
for node in nodes:
if not visited[node]:
components += 1
stack = [node]
visited[node] = True
while stack:
u = stack.pop()
for v in adj[u]:
if not visited[v]:
visited[v] = True
stack.append(v)
free_nodes = n - len(nodes)
result = (result + pow(2, free_nodes + components, MOD)) % MOD
# Handle x values not in d_map
other_x = (pow(2, k, MOD) - len(d_map)) % MOD
result = (result + other_x * pow(2, n, MOD)) % MOD
return user_ans == result % MOD