思路十分简单,答案只有 3 种可能,但是有一些细节需要额外注意一下. 

code: 

#include <bits/stdc++.h>
#define N 300002
#define setIO(s) freopen(s".in","r",stdin)
using namespace std;
int val[N],hd[N],to[N<<1],nex[N<<1],d1[N],d2[N],n,edges,maxx,mx,m2,cnt,uu;
void add(int u,int v)
{
    nex[++edges]=hd[u],hd[u]=edges,to[edges]=v;
}
void dfs(int u,int ff)
{
    if(val[u]==mx) d1[u]=0, uu=u;
    for(int i=hd[u];i;i=nex[i])
    {
        int v=to[i];
        if(v==ff) continue;
        dfs(v,u);
        if(d1[v]+1>d1[u])
        {
            d2[u]=d1[u],d1[u]=d1[v]+1;
        }
        else if(d1[v]+1>d2[u]) d2[u]=d1[v]+1;
    }
    maxx=max(d1[u]+d2[u], maxx);
}
int main()
{
    int i,j;
    // setIO("input");
    mx=-1000300000;
    m2=mx;
    scanf("%d",&n);
    for(i=1;i<=n;++i)
    {
        scanf("%d",&val[i]),mx=max(mx,val[i]);
    }
    for(i=1;i<=n;++i) if(val[i]<mx) m2=max(m2, val[i]);
    for(i=1;i<=n;++i) if(val[i]==m2) ++cnt;
    for(i=1;i<n;++i)
    {
        int u,v;
        scanf("%d%d",&u,&v),add(u,v),add(v,u);
    }
    memset(d1,-0x3f,sizeof(d1));
    memset(d2,-0x3f,sizeof(d2));
    dfs(1,0);
    if(maxx==0)
    {
        if(m2!=mx-1)
            printf("%d\n",mx);
        else
        {
            for(int i=hd[uu];i;i=nex[i])
            {
                int v=to[i];
                if(val[v]==m2) --cnt;
            }
            if(cnt) printf("%d\n",mx+1);
            else printf("%d\n",mx);
        }
    }
    else if(maxx<=2) printf("%d\n",mx+1);
    else printf("%d\n",mx+2);
    return 0;
}

  

01-22 06:56