- Leetcode 3553. Minimum Weighted Subgraph With the Required Paths II
- 1. 解题思路
- 2. 代码实现
- 题目链接:3553. Minimum Weighted Subgraph With the Required Paths II
1. 解题思路
这一题很惭愧,并没有自力搞定,是看了大佬们的解答才有了思路,然后还没有做出来,最后最核心的内容是问了deepseek搞定的,发现是一个经典算法,也算是涨姿势了……
言归正传,这一题要求包含给定的三个点的最小子树的距离,事实上我们只需要分别求出任意其中两点的距离,然后三者相加除以2即是我们的答案了。
因此,问题就变成了如何求树当中任意两点的距离。而这个,我们事实上也只需要指定树的一个根节点,那么任意两个点的距离事实上就是这两个点分别关于这个根节点的距离之和减去他们的最小公共父节点到根节点的距离的两倍即可求得。
于是乎,我们的问题就只剩下如何求一个树上任意两个点的最小公共父节点了。我最开始是通过暴力解法来的,但这样的话单次query的时间复杂度就是 O ( N ) O(N) O(N)了,最终导致了整体代码的超时,问了一下deepseek之后发现求树的最小公共父节点事实上是一个非常经典的算法,虽然我还没完全看懂,不过copy了一下deepseek给到的算法实现之后,上述题目倒是直接搞定了……
只能说,又双叒叕是被deepseek完虐的一天,唉……
2. 代码实现
给出python代码实现如下:
class LCA:
def __init__(self, root, tree):
self.max_level = 20 # 根据树的高度调整
self.parent = defaultdict(lambda: [-1]*self.max_level)
self.depth = {}
self.preprocess(root, tree)
def preprocess(self, root, tree):
stack = [(root, -1, 0)] # (node, parent, depth)
while stack:
node, par, d = stack.pop()
self.depth[node] = d
self.parent[node][0] = par
for k in range(1, self.max_level):
if self.parent[node][k-1] != -1:
self.parent[node][k] = self.parent[self.parent[node][k-1]][k-1]
for child in tree[node]:
if child != par:
stack.append((child, node, d+1))
def query(self, u, v):
if self.depth[u] < self.depth[v]:
u, v = v, u
# 对齐深度
for k in range(self.max_level-1, -1, -1):
if self.depth[u] - (1 << k) >= self.depth[v]:
u = self.parent[u][k]
if u == v:
return u
# 同步跳转
for k in range(self.max_level-1, -1, -1):
if self.parent[u][k] != -1 and self.parent[u][k] != self.parent[v][k]:
u = self.parent[u][k]
v = self.parent[v][k]
return self.parent[u][0]
class Solution:
def minimumWeight(self, edges: List[List[int]], queries: List[List[int]]) -> List[int]:
n = len(edges)
graph = defaultdict(list)
for u, v, w in edges:
graph[u].append((v, w))
graph[v].append((u, w))
seen = {0}
parents = defaultdict(int)
dist = defaultdict(int)
tree = defaultdict(list)
parents[0] = -1
q = [(0, 0)]
while q:
d, u = q.pop(0)
for v, w in graph[u]:
if v in seen:
continue
seen.add(v)
tree[u].append(v)
parents[v] = u
dist[v] = d + w
q.append((d+w, v))
lca = LCA(0, tree)
def get_dist(u, v):
p = lca.query(u, v)
return dist[u] + dist[v] - 2*dist[p]
def query(u, v, w):
return (get_dist(u, v) + get_dist(u, w) + get_dist(v, w)) // 2
return [query(u, v, w) for u, v, w in queries]
提交代码评测得到:耗时5339ms,占用内存174.4MB。