$a\cdot b \equiv k \pmod{mod}$ 等价于 $a \equiv k\cdot b^{-1} \pmod{mod}$（$mod=10^6+3$ 是素数，且 $b\not\equiv 0$，所以逆元 $b^{-1}$ 存在）

然后树分治,用hashmap记录即可,unorder_map/map貌似会TLE,我手写了一个

注意模数只有 $10^6+3$，所以 $[1,mod)$ 内所有数的逆元可以用递推 $inv_i = -\lfloor mod/i \rfloor \cdot inv_{mod \bmod i} \pmod{mod}$ 线性预处理。

复杂度：$O(n\log n)$ 次哈希表操作。

#include<bits/stdc++.h>
// Shorthand macros used throughout the file.
#define ll long long
// Inclusive loop: ii runs from a to b.
#define rep(ii,a,b) for(int ii=a;ii<=b;++ii)
// iostream speed-up; not referenced in the visible code (input goes through fastio).
#define IO ios::sync_with_stdio(false);cin.tie(0);cout.tie(0)
#define fi first
#define se second
#define mp make_pair
// Note: pii is a pair of long long, not of int.
#define pii pair<ll,ll>
using namespace std;
// maxn bounds the per-vertex arrays and inv[] (which is filled up to mod-1);
// maxm bounds the edge pool (each tree edge is added in both directions).
const int maxn=1e6+10,maxm=2e6+10;
// mod is the modulus of the path product (1e6+3 also appears in the prime list below).
// INF appears unused in this file.
const ll INF=0x88888888,mod=1e6+3;
// n: vertex count, k: target product; casn and m appear unused in this file.
int casn,n,m,k;
// val[i]: value on vertex i (read in main); inv[i]: inverse of i modulo mod.
ll val[maxn],inv[maxn];
const int maxsz=4e6+9;// hash_map bucket count. Prime table: 1e7+19,2e7+3,3e7+23
// more primes: 1e6+3,2e6+3,3e6+7,4e6+9,1e5+3,2e5+3,3e5+7,4e5+9
// The number of operator[] calls (lookups/inserts) must stay below maxsz; maxsz should preferably be prime.
// count() never inserts a new node.
class hash_map{public:
  struct node{ll u;int v,next;}e[maxsz<<1];
  int head[maxsz],nume,numk,id[maxsz];
  bool count(ll u){
    int hs=u%maxsz;
    for(int i=head[hs];i;i=e[i].next)
      if(e[i].u==u) return 1;
    return 0;
  }
  int& operator[](ll u){
    int hs=u%maxsz;
    for(int i=head[hs];i;i=e[i].next)
      if(e[i].u==u) return e[i].v;
    if(!head[hs])id[++numk]=hs;
    return e[++nume]=(node){u,0,head[hs]},head[hs]=nume,e[nume].v;
  }
  void clear(){
    rep(i,0,numk)head[id[i]]=0;
    numk=nume=0;
  }
};

namespace graph{
  // Adjacency lists in an edge pool: e[i].next chains the edges leaving one
  // vertex and 0 ends a list (init() sets nume=1, so real edges start at 2).
  struct node{int to,next;}e[maxm];
  // vis[v]=1 once v has been used as a centroid (v is then "removed").
  // all: size of the component being searched for a centroid.
  // root/maxt: best centroid found so far and the size of its largest part.
  int head[maxn],nume,all,vis[maxn],root,maxt;
  int sz[maxn];
  // Adds the directed edge a -> b.
  void add(int a,int b){
    e[++nume]={b,head[a]};
    head[a]=nume;
  }
  // Clears adjacency lists and removal marks for vertices 0..n.
  void init(int n){
    rep(i,0,n) vis[i]=head[i]=0;
    root=nume=1;
  }
  // Fills sz[] for the component of `now` (removed vertices are skipped),
  // rooted at `now`; afterwards sz[now] is the exact component size.
  void getsz(int now,int fa){
    sz[now]=1;
    for(int i=head[now];i;i=e[i].next){
      int to=e[i].to;
      if(to==fa||vis[to]) continue;
      getsz(to,now);
      sz[now]+=sz[to];
    }
  }
  // Searches the component of `now` (its size must be stored in `all`) for a
  // centroid: the vertex minimizing the largest part left after removing it.
  // Those parts are each child subtree and the part above (all - sz[now]).
  // Using the LARGEST child, not the sum of children, is what guarantees that
  // every remaining part has at most all/2 vertices.
  void getroot(int now,int fa){
    sz[now]=1;
    int heavy=0;
    for(int i=head[now];i;i=e[i].next){
      int to=e[i].to;
      if(to==fa||vis[to]) continue;
      getroot(to,now);
      sz[now]+=sz[to];
      heavy=max(heavy,sz[to]);
    }
    int tmp=max(heavy,all-sz[now]);
    if(maxt>tmp) maxt=tmp,root=now;
  }// end of the generic centroid-decomposition part
  // p: product (mod mod) of a path ending at the current centroid -> smallest
  // vertex id at the other end, over the subtrees processed so far.
  hash_map p;
  int dfn;
  // stree[1..dfn]: (path product, vertex) collected by dfs().
  pii stree[maxn];
  // Best pair found (first <= second) and whether any pair exists.
  pii ans;
  int flag;
  // Records every vertex below the subtree entry with the product of values
  // from the entry vertex down to it (the centroid's value is not included).
  void dfs(int now,int fa,ll dis){
    stree[++dfn]=mp(dis,now);
    for(int i=head[now];i;i=e[i].next){
      int to=e[i].to;
      if(to==fa||vis[to]) continue;
      dfs(to,now,dis*val[to]%mod);
    }
  }
  // Handles one child subtree `now` of centroid `center`:
  // 1) for each vertex u with product d, a partner stored under k * d^{-1}
  //    closes a path whose total product is k; keep the lexicographically
  //    smallest (min, max) pair;
  // 2) then insert this subtree's products (now including val[center]),
  //    keeping the smallest vertex id per product.
  void cal(int now,int center){
    dfn=0;
    dfs(now,now,val[now]);
    rep(i,1,dfn){
      ll x=k*inv[stree[i].fi]%mod;
      if(p.count(x)){
        pii tmp=mp(p[x],stree[i].se);
        if(tmp.fi>tmp.se) swap(tmp.fi,tmp.se);
        if(tmp.fi<ans.fi||(tmp.fi==ans.fi&&tmp.se<ans.se)) ans=tmp;
        flag=1;
      }
    }
    rep(i,1,dfn){
      (stree[i].fi*=val[center])%=mod;
      int &x=p[stree[i].fi];
      if(!x) x=stree[i].se;
      else if(x>stree[i].se)x=stree[i].se;
    }
  }
  // Counts all paths through centroid `now`, then removes `now` and recurses
  // into the centroid of every remaining neighbouring component.
  void getans(int now){
    vis[now]=1;p.clear();
    p[val[now]]=now; // the centroid alone is a valid path end
    for(int i=head[now];i;i=e[i].next){
      int to=e[i].to;
      if(vis[to]) continue;
      cal(to,now);
    }
    for(int i=head[now];i;i=e[i].next){
      int to=e[i].to;
      if(vis[to]) continue;
      getsz(to,now);          // exact size of the component behind `to`
      all=sz[to],maxt=n+1;
      getroot(to,now);getans(root);
    }
  }
  // Runs the decomposition on the tree with vertices 1..n; results are left
  // in flag (any pair found) and ans (smallest pair).
  void solve(int n){
    flag=0;maxt=all=n;
    ans=mp(1e9,1e9);
    getroot(1,1);
    getans(root);
  }
}using namespace graph;
namespace fastio{// fast input of integers and strings, fast output of integers
// Digit test taking an int, so EOF (-1) and bytes >= 0x80 are handled
// identically whether plain char is signed or unsigned.
bool isdigit(int c){return c>=48&&c<=57;}
const int maxsz=1e7;
class fast_iostream{public:
  // One-character look-ahead: a byte value 0..255, or EOF at end of input.
  int ch=get_char();
  bool endf=1,flag;
  // Returns the next byte of stdin as an unsigned value, or EOF.
  // The cast to unsigned char keeps a 0xFF byte distinct from EOF, and makes
  // `ch!=EOF` terminate on platforms where plain char is unsigned.
  int get_char(){
    static char buffer[maxsz],*a=buffer,*b=buffer;
    if(a==b){
      a=buffer;
      b=buffer+fread(buffer,1,maxsz,stdin);
      if(a==b) return EOF;
    }
    return (unsigned char)*a++;
  }
  // Reads an optionally '-'-prefixed decimal integer; false (and endf=0) at EOF.
  template<typename type>bool get_int(type& tmp){
    flag=tmp=0;
    while(!isdigit(ch)&&ch!=EOF){flag=ch=='-';ch=get_char();}
    if(ch==EOF)return endf=0;
    do{tmp=ch-48+tmp*10;}while(isdigit(ch=get_char()));
    if(flag)tmp=-tmp;
    return 1;
  }
  // Reads a whitespace-delimited token into str and returns its length
  // (0 and endf=0 at EOF).
  int get_str(char* str){
    char* tmp=str;
    while(ch=='\r'||ch=='\n'||ch==' ')ch=get_char();
    if(ch==EOF){endf=0;*tmp=0;return 0;}
    do{*(tmp++)=(char)ch;ch=get_char();}while(ch!='\r'&&ch!='\n'&&ch!=' '&&ch!=EOF);
    *(tmp++)=0;
    return(int)(tmp-str-1);
  }
  fast_iostream& operator>>(char* tmp){get_str(tmp);return *this;}
  template<typename type>fast_iostream& operator>>(type& tmp){get_int(tmp);return *this;}
  // False once a read has hit end of input.
  operator bool() const {return endf;}
};
// Writes an integer to stdout. Digits are taken without negating the value,
// so the minimum value of a signed type does not overflow.
template<typename type>void put(type tmp){
  if (tmp==0){putchar(48);return;}
  static int top,stk[21];
  bool neg=tmp<0;
  if (neg) putchar('-');
  while(tmp){
    int d=(int)(tmp%10);       // in (-10,0] for negative tmp (truncating division)
    stk[++top]=neg?-d:d;
    tmp/=10;
  }
  while(top)putchar(stk[top--]+48);
}
}fastio::fast_iostream io;
#define cin io
// Reads test cases (n, k, n vertex values, n-1 edges) until end of input and
// prints the smallest vertex pair whose path product is k modulo mod, or
// "No solution".
int main() {
  // Linear-time inverse table (mod is prime): inv[i] = -(mod/i) * inv[mod%i].
  // Writing it as (mod - mod/i) keeps every intermediate value positive.
  inv[1]=1;
  rep(i,2,mod-1) inv[i]=(mod-mod/i)*inv[mod%i]%mod;
  while(cin>>n>>k){
    graph::init(n);
    rep(i,1,n){
      cin>>val[i];
      val[i]%=mod; // keep products, and thus inv[] indices, inside [0, mod)
    }
    int a,b;
    rep(i,2,n){
      cin>>a>>b;
      graph::add(a,b);
      graph::add(b,a);
    }
    graph::solve(n);
    if(!flag) puts("No solution");
    else {
      if(ans.fi>ans.se) swap(ans.fi,ans.se);
      fastio::put(ans.fi);putchar(' ');
      fastio::put(ans.se);putchar('\n');
    }
  }
}

跑得还挺快的。

01-22 20:00