题目描述

一个原力网络可以看成是一个可能存在重边但没有自环的无向图。每条边有一种属性和一个权值。属性可能是R、G、B三种当中的一种,代表这条边上原力的类型。权值是一个正整数,代表这条边上的原力强度。原力技术的核心在于将R、G、B三种不同的原力融合在一起产生单一的、便于利用的原力。为了评估一个能源网络,JYY需要找到所有满足要求的三元环(首尾相接的三条边),其中R、G、B三种边各一条。一个三元环产生的能量是其中三条边的权值之积。
现在对于给出的原力网络,JYY想知道这个网络的总能量是多少。网络的总能量是所有满足要求三元环的能量之和。

输入

第一行包含两个正整数N、M。表示原力网络的总顶点个数和总边数。
接下来M行,每行包含三个正整数ui,vi,wi和一个字符ci。
表示编号ui和vi的顶点之间存在属性为ci权值为wi的一条边。
N≤50,000,M≤100,000,1≤?Wi≤10^6

输出

输出一行一个整数,表示这个原力网络的总能量模10^9+7的值

样例输入

4 6
1 2 2 R
2 4 3 G
4 3 5 R
3 1 7 G
1 4 11 B
2 3 13 B

样例输出

828


题解

根号分治+STL-map

看到这种根本没法写出什么玄学数据结构之类的,大概率就是根号分治了。

对于本题,由于边数只有 $m$ ,因此度数大于等于 $\sqrt m$ 的点只有 $O(\sqrt m)$ 个,我们称这样的点为大点,度数小于 $\sqrt m$ 的称为小点。

那么对于一个三元环:

如果三个点都是大点:这种情况下我们暴力枚举三个大点,求出是否有满足条件的三元环并加入到答案中即可。时间复杂度为 $O((\sqrt m)^3)=O(m\sqrt m)$ ;

如果三个点中有小点:这种情况下我们枚举每个小点和它的两条出边,判断这三个点是否有满足条件的三元环。此时,枚举第一条出边相当于枚举图中所有边,第二条出边是度数复杂度,而度数小于 $\sqrt m$ ,因此复杂度也是 $O(m\sqrt m)$ 的。注意这个过程需要保证不重不漏,因此只考虑枚举点为这三个点中编号最小的小点的答案。

那么如何判断是否有满足条件的三元环呢?我偷懒了使用STL-map判断两点之间有没有某颜色的边,复杂度上会多一个log。

时间复杂度 $O(m\sqrt m\log m)$ ,实际上跑得挺快的 然而在bz上还是倒数第一...

#include <map>
#include <cmath>
#include <cstdio>
#define N 50010
#define mod 1000000007
using namespace std;
typedef long long ll;
struct data
{
int x , y , z;
data() {}
data(int a , int b , int c) {x = a , y = b , z = c;}
bool operator<(const data &a)const {return x == a.x ? y == a.y ? z < a.z : y < a.y : x < a.x;}
};
map<data , ll> mp;
int head[N] , to[N << 2] , val[N << 2] , opt[N << 2] , next[N << 2] , cnt , d[N] , id[350] , tot;
char str[5];
inline void add(int x , int y , int v , int c)
{
to[++cnt] = y , val[cnt] = v , opt[cnt] = c , next[cnt] = head[x] , head[x] = cnt;
}
int main()
{
int n , m , si , i , j , k , x , y , z , t;
ll ans = 0;
scanf("%d%d" , &n , &m) , si = (int)sqrt(m);
for(i = 1 ; i <= m ; i ++ )
{
scanf("%d%d%d%s" , &x , &y , &z , str);
t = (str[0] == 'R' ? 1 : str[0] == 'G' ? 2 : 3);
add(x , y , z , t) , add(y , x , z , t) , d[x] ++ , d[y] ++ ;
(mp[data(x , y , t)] += z) %= mod , (mp[data(y , x , t)] += z) %= mod;
}
for(i = 1 ; i <= n ; i ++ )
if(d[i] >= si)
id[++tot] = i;
for(i = 1 ; i <= tot ; i ++ )
for(j = 1 ; j <= tot ; j ++ )
for(k = 1 ; k <= tot ; k ++ )
ans = (ans + mp[data(id[i] , id[j] , 1)] * mp[data(id[i] , id[k] , 2)] % mod * mp[data(id[j] , id[k] , 3)]) % mod;
for(i = 1 ; i <= n ; i ++ )
if(d[i] < si)
for(j = head[i] ; j ; j = next[j])
if(d[to[j]] >= si || to[j] > i)
for(k = next[j] ; k ; k = next[k])
if(opt[k] != opt[j] && (d[to[k]] >= si || to[k] > i))
ans = (ans + mp[data(to[j] , to[k] , 6 - opt[j] - opt[k])] * val[j] % mod * val[k]) % mod;
printf("%lld\n" , ans);
return 0;
}
05-11 22:38