mirror of
https://github.com/InternLM/InternBootcamp.git
synced 2026-04-19 12:58:04 +00:00
223 lines
6.7 KiB
Python
Executable file
223 lines
6.7 KiB
Python
Executable file
"""#
|
||
|
||
### 谜题描述
|
||
You have been offered a job in a company developing a large social network. Your first task is connected with searching profiles that most probably belong to the same user.
|
||
|
||
The social network contains n registered profiles, numbered from 1 to n. Some pairs there are friends (the \"friendship\" relationship is mutual, that is, if i is friends with j, then j is also friends with i). Let's say that profiles i and j (i ≠ j) are doubles, if for any profile k (k ≠ i, k ≠ j) one of the two statements is true: either k is friends with i and j, or k isn't friends with either of them. Also, i and j can be friends or not be friends.
|
||
|
||
Your task is to count the number of different unordered pairs (i, j), such that the profiles i and j are doubles. Note that the pairs are unordered, that is, pairs (a, b) and (b, a) are considered identical.
|
||
|
||
Input
|
||
|
||
The first line contains two space-separated integers n and m (1 ≤ n ≤ 106, 0 ≤ m ≤ 106), — the number of profiles and the number of pairs of friends, correspondingly.
|
||
|
||
Next m lines contains descriptions of pairs of friends in the format \"v u\", where v and u (1 ≤ v, u ≤ n, v ≠ u) are numbers of profiles that are friends with each other. It is guaranteed that each unordered pair of friends occurs no more than once and no profile is friends with itself.
|
||
|
||
Output
|
||
|
||
Print the single integer — the number of unordered pairs of profiles that are doubles.
|
||
|
||
Please do not use the %lld specificator to read or write 64-bit integers in С++. It is preferred to use the %I64d specificator.
|
||
|
||
Examples
|
||
|
||
Input
|
||
|
||
3 3
|
||
1 2
|
||
2 3
|
||
1 3
|
||
|
||
|
||
Output
|
||
|
||
3
|
||
|
||
|
||
Input
|
||
|
||
3 0
|
||
|
||
|
||
Output
|
||
|
||
3
|
||
|
||
|
||
Input
|
||
|
||
4 1
|
||
1 3
|
||
|
||
|
||
Output
|
||
|
||
2
|
||
|
||
Note
|
||
|
||
In the first and second sample any two profiles are doubles.
|
||
|
||
In the third sample the doubles are pairs of profiles (1, 3) and (2, 4).
|
||
|
||
Here is a reference code to solve this task. You can use this to help you genereate cases or validate the solution.
|
||
```python
|
||
#include <bits/stdc++.h>
|
||
using namespace std;
|
||
long long p[1000005], h[1000005];
|
||
unordered_map<long long, long long> m1, m2;
|
||
int main() {
|
||
long long ans = 0, n, m, B = 37, x, y;
|
||
m1.reserve(1 << 12), m2.reserve(1 << 12);
|
||
m1.max_load_factor(0.25), m2.max_load_factor(0.25);
|
||
scanf(\"%lld%lld\", &n, &m), p[0] = 1;
|
||
for (long long i = 1; i < 1000005; i++) p[i] = p[i - 1] * B;
|
||
while (m--) scanf(\"%lld%lld\", &x, &y), h[x] += p[y], h[y] += p[x];
|
||
for (long long i = 1; i <= n; i++) m1[h[i]]++, h[i] += p[i], m2[h[i]]++;
|
||
for (auto i : m1) ans += (i.second * (i.second - 1)) / 2;
|
||
for (auto i : m2) ans += (i.second * (i.second - 1)) / 2;
|
||
cout << ans;
|
||
}
|
||
```
|
||
|
||
|
||
请完成上述谜题的训练场环境类实现,包括所有必要的方法。
|
||
"""
|
||
|
||
from bootcamp import Basebootcamp
|
||
import random
|
||
from itertools import combinations
|
||
from collections import defaultdict
|
||
import re
|
||
from bootcamp import Basebootcamp
|
||
|
||
class Cdoubleprofilesbootcamp(Basebootcamp):
|
||
def __init__(self, max_n=10, max_m=20):
|
||
"""
|
||
初始化参数,修正n的取值范围包含1
|
||
增加参数校验确保合法案例生成
|
||
"""
|
||
self.max_n = max(1, max_n) # 确保max_n至少为1
|
||
self.max_m = max(0, max_m)
|
||
|
||
def case_generator(self):
|
||
"""
|
||
生成合法谜题实例,完全覆盖n=1的边界情况
|
||
优化边的生成逻辑,确保完全的随机性
|
||
"""
|
||
n = random.randint(1, self.max_n) # 修复n的取值范围
|
||
max_possible_edges = n * (n - 1) // 2
|
||
m = random.randint(0, min(self.max_m, max_possible_edges))
|
||
|
||
# 生成所有可能的无向边并随机选取
|
||
all_edges = list(combinations(range(1, n+1), 2))
|
||
selected_edges = random.sample(all_edges, m) if m > 0 else []
|
||
selected_edges = sorted([tuple(sorted(e)) for e in selected_edges]) # 标准化边格式
|
||
|
||
correct_answer = self.compute_answer(n, m, selected_edges)
|
||
return {
|
||
"n": n,
|
||
"m": m,
|
||
"edges": selected_edges,
|
||
"correct_answer": correct_answer
|
||
}
|
||
|
||
@staticmethod
|
||
def compute_answer(n, m, edges):
|
||
"""
|
||
改进哈希计算逻辑,增加对大数的容错处理
|
||
使用更安全的模运算防止数值溢出
|
||
"""
|
||
MOD = 10**18 + 3 # 大素数防止哈希碰撞
|
||
B = 37
|
||
|
||
p = [1] * (n + 2) # 扩展数组长度防止越界
|
||
for i in range(1, n+2):
|
||
p[i] = (p[i-1] * B) % MOD
|
||
|
||
h = [0] * (n + 2)
|
||
for v, u in edges:
|
||
h[v] = (h[v] + p[u]) % MOD
|
||
h[u] = (h[u] + p[v]) % MOD
|
||
|
||
m1 = defaultdict(int)
|
||
m2 = defaultdict(int)
|
||
for i in range(1, n+1):
|
||
m1[h[i]] += 1
|
||
m2[(h[i] + p[i]) % MOD] += 1 # 增加模运算
|
||
|
||
ans = 0
|
||
for cnt in m1.values():
|
||
ans += cnt * (cnt - 1) // 2
|
||
for cnt in m2.values():
|
||
ans += cnt * (cnt - 1) // 2
|
||
return ans
|
||
|
||
@staticmethod
|
||
def prompt_func(question_case) -> str:
|
||
"""
|
||
增强问题描述规范,明确边界条件
|
||
增加输入输出样例的规范说明
|
||
"""
|
||
n = question_case['n']
|
||
m = question_case['m']
|
||
edges = question_case['edges']
|
||
input_lines = [f"{n} {m}"]
|
||
input_lines += [f"{v} {u}" for v, u in sorted(edges)]
|
||
input_str = '\n'.join(input_lines)
|
||
|
||
prompt = f"""You are working at a social network company and need to find profile doubles. Two profiles i and j are doubles if for any other profile k ≠ i,j:
|
||
- k is friends with both i and j OR
|
||
- k is friends with neither i nor j
|
||
|
||
Rules:
|
||
1. Profiles are numbered 1..n
|
||
2. Friendship is mutual
|
||
3. Pairs (i,j) and (j,i) count as the same pair
|
||
4. n can be 1 (then answer is 0)
|
||
|
||
Input Format:
|
||
- First line: n m
|
||
- Next m lines: pairs of friends
|
||
|
||
Output Format:
|
||
- Single integer: number of valid doubles pairs
|
||
|
||
Example Input 1:
|
||
3 3
|
||
1 2
|
||
2 3
|
||
1 3
|
||
|
||
Example Output 1:
|
||
3
|
||
|
||
Current Input:
|
||
{input_str}
|
||
|
||
Present the final integer answer between [answer] and [/answer] tags."""
|
||
return prompt
|
||
|
||
@staticmethod
|
||
def extract_output(output):
|
||
"""
|
||
改进正则表达式模式,严格匹配整数格式
|
||
支持科学计数法等特殊格式的转换
|
||
"""
|
||
pattern = r'\[answer\s*\]\s*(-?\d+)\s*\[/answer\s*\]'
|
||
matches = re.findall(pattern, output, re.IGNORECASE)
|
||
if matches:
|
||
try:
|
||
return int(matches[-1].strip())
|
||
except ValueError:
|
||
pass
|
||
return None
|
||
|
||
@classmethod
|
||
def _verify_correction(cls, solution, identity):
|
||
"""
|
||
增加类型检查,确保比较有效性
|
||
"""
|
||
if not isinstance(solution, int):
|
||
return False
|
||
return solution == identity.get('correct_answer', -1)
|