比赛 2024暑期C班集训2 评测结果 AAAAAAAAAA
题目名称 大力枚举 最终得分 100
用户昵称 ┭┮﹏┭┮ 运行时间 0.131 s
代码语言 C++ 内存使用 2.21 MiB
提交时间 2024-07-02 09:04:20
显示代码纯文本
#include <bits/stdc++.h> 
using namespace std;
// Shorthand for the 64-bit signed integer type used throughout the file.
using ll = long long;
// Array capacity: inputs are stored 1-based (index 0 unused) plus some slack.
constexpr int N = 100000 + 10;
// Modulus for all arithmetic; C() relies on it being prime (Fermat inverse).
constexpr ll mod = 1000000007;


// Fast reader for one signed integer from stdin, using getchar().
// Every non-digit character before the number is skipped. A '-' among the
// skipped characters makes the result negative. The character immediately
// after the last digit is consumed.
// Returns 0 if end of input is reached before any digit is seen; the
// original version spun forever in that case.
long long read(){
   long long x = 0,f = 1;
   // Must be int, not char: EOF (-1) has to stay distinct from every byte
   // value. With an unsigned char it would become 255 and never match EOF.
   int c = getchar();
   for(;c < '0' || c > '9';c = getchar()){
      if(c == EOF)return 0;
      if(c == '-')f = -1;
   }
   for(;c >= '0' && c <= '9';c = getchar())x = x * 10 + (c - '0');
   return x * f;
}

// n: number of input values (the first integer read by main).
int n;
// a[1..n]: the input values as stored by main; main then overwrites each
//          entry with the 1-based index of that value inside c.
// c[1..len]: the stored values sorted ascending with duplicates removed
//            (len is a local of main).
// cnt[k]: how many inputs are equal to c[k].
ll a[N],c[N],cnt[N];
// Prefix tables over the distinct values, all kept modulo mod:
// s_k[i] is the sum, over every choice of k distinct input positions whose
// values lie in c[1..i], of the product of those k values (the k-th
// elementary symmetric sum restricted to the first i distinct values).
ll s1[N],s2[N],s3[N],s4[N];
// Binary exponentiation: returns x^y mod `mod` as a value in [0, mod).
// x is first reduced into [0, mod), for two reasons:
//  - C++ `%` keeps the sign of the dividend, so a negative base would
//    otherwise give a negative result.
//  - An unreduced |x| above ~3.04e9 would overflow x * x in long long.
// Precondition: y >= 0. With an arithmetic right shift, a negative y never
// reaches 0 and the loop does not terminate.
ll ksm(ll x,ll y){
    x %= mod;
    if(x < 0)x += mod;
    ll ans = 1;
    while(y){
        if(y & 1)ans = ans * x % mod;
        x = x * x % mod,y >>= 1;
    }
    return ans;
}
// Binomial coefficient C(n, m) modulo the prime `mod`.
// Computed as the falling factorial n*(n-1)*...*(n-m+1), divided by m!. The
// division is a multiplication by the Fermat inverse (m!)^(mod-2).
// Meant for small m: each loop runs at most m times. When m > n >= 0 the
// falling factorial passes through 0, so the result is 0.
ll C(ll n,ll m){
    ll numer = 1;
    for(ll k = 0;k < m;k++)numer = numer * (n - k) % mod;
    ll denom = 1;
    for(ll k = 2;k <= m;k++)denom = denom * k % mod;
    return numer * ksm(denom,mod-2) % mod;
}
int main(){
    freopen("enumerate.in","r",stdin);
    freopen("enumerate.out","w",stdout);
    n = read();
    for(int i = 1;i <= n;i++)c[i] = a[i] = read();
    sort(c+1,c+1+n);
    int len = unique(c+1,c+1+n) - (c+1);
    for(int i = 1;i <= n;i++)a[i] = lower_bound(c+1,c+1+len,a[i]) - c,cnt[a[i]]++;
    //
    for(int i = 1;i <= len;i++){
        s1[i] = s1[i-1] % mod;
        if(cnt[i])s1[i] = (s1[i] + cnt[i] * c[i] % mod) % mod;
    }
    for(int i = 1;i <= len;i++){
        s2[i] = s2[i-1] % mod;
        if(cnt[i] > 0)s2[i] = (s2[i] + cnt[i] * c[i] % mod * s1[i-1] % mod) % mod;
        if(cnt[i] > 1)s2[i] = (s2[i] + C(cnt[i],2) * ksm(c[i],2) % mod) % mod;
    }
    for(int i = 1;i <= len;i++){
        s3[i] = s3[i-1] % mod;
        if(cnt[i] > 0)s3[i] = (s3[i] + cnt[i] * c[i] % mod * s2[i-1] % mod) % mod;
        if(cnt[i] > 1)s3[i] = (s3[i] + C(cnt[i],2) * s1[i-1] % mod * ksm(c[i],2) % mod) % mod;
        if(cnt[i] > 2)s3[i] = (s3[i] + C(cnt[i],3) * ksm(c[i],3) % mod) % mod;
    }
    for(int i = 1;i <= len;i++){
        s4[i] = s4[i-1] % mod;
        if(cnt[i] > 0)s4[i] = (s4[i] + cnt[i] * c[i] % mod * s3[i-1] % mod) % mod;
        if(cnt[i] > 1)s4[i] = (s4[i] + C(cnt[i],2) * s2[i-1] % mod * ksm(c[i],2) % mod) % mod;
        if(cnt[i] > 2)s4[i] = (s4[i] + C(cnt[i],3) * s1[i-1] % mod * ksm(c[i],3) % mod) % mod;
        if(cnt[i] > 3)s4[i] = (s4[i] + C(cnt[i],4) * ksm(c[i],4) % mod) % mod;
    }
    printf("%lld\n",s4[len]);
    
    
    return 0;
}