主要内容:并查集
并查集
并查集的题目感觉大部分都是模板题,上板子!!
class UnionFind:
def __init__(self, n):
self.pa = list(range(n))
self.size = [1]*n
self.cnt = n
def find(self, x):
if self.pa[x] != x:
self.pa[x] = self.find(self.pa[x])
return self.pa[x]
def merge(self, x, y):
fx = self.find(x)
fy = self.find(y)
if fx == fy:
return False
self.pa[fx] = fy
self.size[fy] += self.size[fx]
self.cnt -= 1
return True
def is_same(self, x, y):
return self.find(x) == self.find(y)
merge函数用于判断x,y是否联通,如果联通return False。
lanqiao19719吊坠
# 并查集模板
class UnionFind:
def __init__(self, n):
self.pa = list(range(n))
self.size = [1]*n
self.cnt = n
def find(self, x):
if self.pa[x] != x:
self.pa[x] = self.find(self.pa[x])
return self.pa[x]
def merge(self, x, y):
fx = self.find(x)
fy = self.find(y)
if fx == fy:
return False
self.pa[fx] = fy
self.size[fy] += self.size[fx]
self.cnt -= 1
return True
def is_same(self, x, y):
return self.find(x) == self.find(y)
n, m = map(int, input().split()) # 输入处理
strings = [] # 记录字符串
for _ in range(n):
strings.append(input())
# 边权为这两个字符串的最长公共子串的长度,可以按环形旋转改变起始位置,但不能翻转
def f(s):
m = len(s)
# 两个字符串拼接起来
s_concat = s + s
# 字典,键为子串长度,值为子串
suffix_dict = {}
for i in range(m):
rotated = s_concat[i:i+m]
for k in range(1, m):
# 旋转之后的子串
suffix = rotated[-k:]
if k not in suffix_dict:
suffix_dict[k] = set()
suffix_dict[k].add(suffix)
return suffix_dict
suffix_dicts = []
for s in strings:
suffix_dicts.append(f(s))
# 建图,包括边权,连接点
edges = []
for i in range(n):
for j in range(i+1, n):
max_ij_k = 0
for k in range(m, -1, -1):
suffix_set_i = suffix_dicts[i].get(k, set())
suffix_set_j = suffix_dicts[j].get(k, set())
# 如果两个集合相交不为空,记录max_ij_k为k,因为是逆序的,所以直接记录并break
if suffix_set_i & suffix_set_j:
max_ij_k = k
break
weight = max_ij_k
edges.append((weight, i, j))
# 边权从小到大排序
edges.sort(reverse=True, key=lambda x : x[0])
uf = UnionFind(n)
# ans记录值,cnt记录次数
ans = 0
cnt = 0
for weight, i, j in edges:
if uf.merge(i, j):
ans += weight
cnt += 1
# 临界
if cnt == n-1:
break
print(ans)
3493. 属性图
class UnionFind:
def __init__(self, n):
self.pa = list(range(n))
self.size = [1]* n
self.cnt = n
def find(self, x):
if self.pa[x] != x:
self.pa[x] = self.find(self.pa[x])
return self.pa[x]
def merge(self, x, y):
fx = self.find(x)
fy = self.find(y)
if fx == fy:
return False
self.pa[fx] = fy
self.size[fy] += self.size[fx]
self.cnt -= 1
return True
def is_same(self, x, y):
return self.find(x) == self.find(y)
class Solution:
def numberOfComponents(self, properties: List[List[int]], k: int) -> int:
sets = list(map(set, properties))
uf = UnionFind(len(properties))
for i, a in enumerate(sets):
for j, b in enumerate(sets[:i]):
if len(a&b) >= k:
uf.merge(i, j)
return uf.cnt
思路:并查集,先利用集合的特性去重,根据properties的长度实例化并查集,双重循环得到集合a和集合b,根据题目要求当 intersect(properties[i], properties[j]) >= k
(其中 i
和 j
的范围为 [0, n - 1]
且 i != j
),节点 i
和节点 j
之间有一条边,即当满足条件时,将i,j连起来,并在merge函数中self.cnt -= 1。最终返回uf.cnt就行。
1971. 寻找图中是否存在路径
class UnionFind:
def __init__(self, n):
self.pa = list(range(n))
self.size = [1]*n
self.cnt = n
def find(self, x):
if self.pa[x] != x:
self.pa[x] = self.find(self.pa[x])
return self.pa[x]
def merge(self, x, y):
fx = self.find(x)
fy = self.find(y)
if fx == fy:
return False
self.pa[fx] = fy
self.size[fy] += self.size[fx]
self.cnt -= 1
return True
def is_same(self, x, y):
return self.find(x) == self.find(y)
class Solution:
def validPath(self, n: int, edges: List[List[int]], source: int, destination: int) -> bool:
uf = UnionFind(n)
for i, j in edges:
uf.merge(i, j)
return uf.is_same(source, destination)
实例化一个并查集,遍历edges中的i,j并连起来,遍历结束就使用is_same()函数进行判断是否连在一起。
200. 岛屿数量
class UnionFind:
def __init__(self, n):
self.pa = list(range(n))
self.size = [1]*n
self.cnt = n
def find(self, x):
if self.pa[x] != x:
self.pa[x] = self.find(self.pa[x])
return self.pa[x]
def merge(self, x, y):
fx = self.find(x)
fy = self.find(y)
if fx == fy:
return False
self.pa[fx] = fy
self.size[fy] += self.size[fx]
self.cnt -= 1
return True
def is_same(self, x, y):
return self.find(x) == self.find(y)
class Solution:
def numIslands(self, grid: List[List[str]]) -> int:
n = len(grid)
m = len(grid[0])
uf = UnionFind(m*n)
ocean = 0
for i in range(n):
for j in range(m):
if grid[i][j] == "0":
ocean += 1
else:
# 向下查看
if i < n-1 and grid[i+1][j] == "1":
uf.merge(i*m+j, (i+1)*m+j)
# 向右查看
if j < m-1 and grid[i][j+1] == "1":
uf.merge(i*m+j, i*m+j+1)
return uf.cnt - ocean
思路:获得grid的高n宽m,通过n*m实例化并查集,将grid中的每个元素当作一个点看,然后使用ocean记录海水的熟练,当grid[i][j]==“1”时,向下向右查看,如果下面是1将当前位置i*m+j和(i+1)*m+j连起来,当右边是1将当前位置和i*m+j+1连起来。最终返回uf.cnt-ocean即为所求答案。
1631. 最小体力消耗路径
class UnionFind:
def __init__(self, n):
self.pa = list(range(n))
self.size = [1]* n
self.cnt = n
def find(self, x):
if self.pa[x] != x:
self.pa[x] = self.find(self.pa[x])
return self.pa[x]
def merge(self, x, y):
fx = self.find(x)
fy = self.find(y)
if fx == fy:
return False
self.pa[fx] = fy
self.size[fy] += self.size[fx]
self.cnt -= 1
return True
def is_same(self, x, y):
return self.find(x) == self.find(y)
class Solution:
def minimumEffortPath(self, heights: List[List[int]]) -> int:
n = len(heights)
m = len(heights[0])
uf = UnionFind(n*m)
edges = []
dirs = (0, 1, 0)
for i in range(n):
for j in range(m):
for a, b in pairwise(dirs):
x = i + a
y = j + b
if 0 <= x < n and 0 <= y < m:
edges.append((abs(heights[i][j] - heights[x][y]), i*m+j, x*m+y))
edges.sort() # 求最小
for h, i, j in edges:
uf.merge(i, j)
if uf.is_same(0, m*n-1):
return h
return 0
思路:和岛屿数量思路类似,通过n*m实例化并查集,edges记录i,j和x,y之间的高度之差绝对值。根据这个值进行排序edges,然后开始遍历edges,每次遍历将i,j连起来,并判断起点0和m*n-1是否连起来了,连起来了就直接return h因为edges是在此之前排过序的。
924. 尽量减少恶意软件的传播
class UnionFind:
def __init__(self, n):
self.pa = list(range(n))
self.size = [1]* n
self.cnt = n
def find(self, x):
if self.pa[x] != x:
self.pa[x] = self.find(self.pa[x])
return self.pa[x]
def merge(self, x, y):
fx = self.find(x)
fy = self.find(y)
if fx == fy:
return False
self.pa[fx] = fy
self.size[fy] += self.size[fx]
self.cnt -= 1
return True
def is_same(self, x, y):
return self.find(x) == self.find(y)
class Solution:
def minMalwareSpread(self, graph: List[List[int]], initial: List[int]) -> int:
n = len(graph)
m = len(graph[0])
uf = UnionFind(n)
for i in range(n):
for j in range(i + 1, n):
graph[i][j] and uf.merge(i, j)
cnt = Counter(uf.find(x) for x in initial)
ans, mx = n, 0
for x in initial:
root = uf.find(x)
if cnt[root] > 1:
continue
sz = uf.size[root]
if sz > mx or (sz == mx and x < ans):
ans = x
mx = sz
return min(initial) if ans == n else ans