题目如下:

示例1
输入
6 4
1 2
2 3
1 3
4 6
0 0 0 0 0 0
输出
14
示例2
输入
6 4
1 2
2 3
1 3
4 6
0 0 0 0 2 0
输出
1
题目链接
题解 or 思路:
首先如果我们理解题意了,这个题是顶级诈骗。
 因为是无向图,我们需要记录图中 环的大小 & 环中的 炸弹数 所以我们可以使用 带权并查集 来维护。
 设:
 环的 大小 为: 
    
     
      
       
        c
       
       
        n
       
       
        
         t
        
        
         
          环
         
         
          i
         
        
       
      
      
       cnt_{环i}
      
     
    cnt环i
 环中的 炸弹数 为: 
    
     
      
       
        c
       
       
        n
       
       
        
         t
        
        
         
          炸
         
         
          i
         
        
       
      
      
       cnt_{炸i}
      
     
    cnt炸i
 最终炸弹数的和为: 
    
     
      
       
        s
       
       
        u
       
       
        m
       
      
      
       sum
      
     
    sum
我们可以分两种情况:
- 最终所有节点的炸弹数为 
      
       
        
         
          0
         
        
        
         0
        
       
      0
 那么答案就是 ∑ c n t 环 i ∗ c n t 环 i \sum cnt_{环i} * cnt_{环i} ∑cnt环i∗cnt环i
 因为:在同一个连通块中 任意一点能到达任意一点
- 最终所有节点的炸弹数不为 
      
       
        
         
          0
         
        
        
         0
        
       
      0
 如果存在一个环, c n t 炸 i cnt_{炸i} cnt炸i == sum
 那么答案就是 c n t 环 i ∗ c n t 环 i cnt_{环i} * cnt_{环i} cnt环i∗cnt环i
 如果找不到:
 那么就是没有合法的方案,答案为 0 0 0
AC 代码:
#include <iostream>
#include <algorithm>
#include <vector>
#include <string>
#include <numeric>
#include <cstring>
#include <cmath>
#include <map>
#include <unordered_map>
#include <bitset>
#include <set>
#include <random>
#include <ctime>
#include <queue>
#include <stack>
#include <climits>
#define buff                     \
    ios::sync_with_stdio(false); \
    cin.tie(0);
// #define int long long
#define ll long long
#define PII pair<int, int>
#define px first
#define py second
typedef std::mt19937 Random_mt19937;
Random_mt19937 rnd(time(0));
using namespace std;
const int mod = 1e9 + 7;
const int inf = 2147483647;
const int N = 200009;
int n, m, f[N], cnt[N], s[N];
int find(int x)
{
    if (x == f[x])
        return f[x];
    int fx = find(f[x]);
    s[x] += s[f[x]];
    cnt[x] += cnt[f[x]];
    f[x] = fx;
    return f[x];
}
void join(int x, int y)
{
    int xx = find(x);
    int yy = find(y);
    if (xx != yy)
    {
        f[yy] = xx;
        cnt[xx] += cnt[yy];
        s[xx] += s[yy];
    }
}
void solve()
{
    cin >> n >> m;
    for (int i = 1; i <= n; i++)
        f[i] = i, cnt[i] = 1;
    int sum = 0;
    vector<PII> v;
    for (int i = 1; i <= m; i++)
    {
        int a, b;
        cin >> a >> b;
        v.push_back({a, b});
    }
    for (int i = 1; i <= n; i++)
        cin >> s[i], sum += s[i];
    for (int i = 0; i < m; i++)
        join(v[i].first, v[i].second);
    ll ans = 0;
    if (sum == 0)
    {
        for (int i = 1; i <= n; i++)
            if (f[i] == i)
                ans += (ll)cnt[i] * cnt[i];
    }
    else
    {
        for (int i = 1; i <= n; i++)
            if (f[i] == i)
                if (s[i] == sum)
                    ans = (ll)cnt[i] * cnt[i];
    }
    cout << ans << '\n';
}
int main()
{
    buff;
    solve();
}















