首页 > 试题广场 >

异或

[编程题]异或
  • 热度指数:17002 时间限制:C/C++ 1秒,其他语言2秒 空间限制:C/C++ 32M,其他语言64M
  • 算法知识视频讲解
给定整数m以及n各数字A1,A2,..An,将数列A中所有元素两两异或,共能得到n(n-1)/2个结果,请求出这些结果中大于m的有多少个。

输入描述:
第一行包含两个整数n,m. 
第二行给出n个整数A1,A2,...,An。
数据范围
对于30%的数据,1 <= n, m <= 1000
对于100%的数据,1 <= n, m, Ai <= 10^5


输出描述:
输出仅包括一行,即所求的答案
示例1

输入

3 10  
6 5 10

输出

2
// 似乎不用各种比较m和a啊
// eg:
// 为了方便举例,假设所有数都是5位二进制,m = 0x00100
// 只需要枚举ai:
// 找到跟ai异或后是0x1****、0x01***、0x0011*、0x00101的个数就行了
// 而异或是可逆的,也就是找 ai^m满足上面条件的个数。
#include <stdio.h>
#include <string.h>
#include <stdint.h>

int n,m;
int a[100010];
int c[(1<<18)+1];

int maxDepth = 16;

// 存0x1****、0x01***等等分别的个数
void buildTrie() {
    memset(c, 0, sizeof(c));
    for(int i=0;i<n;i++) {
        int now = 1;
        for(int j=maxDepth;j>=0;j--) {
            now = (now<<1) + ((a[i]&(1<<j))>0 ? 1 : 0);
            c[now]++;
            // printf("p %d add c[%d] to %d\n", a[i], now, c[now]);
        }
    }
}

int countTrie(int p, int depth) {
    int now = 1;
    for(int j=maxDepth;j>=depth;j--) {
        now = (now<<1) + ((p&(1<<j))>0 ? 1 : 0);
    }
    // printf("p=%d, depth=%d, c[%d]=%d\n", p, depth, now, c[now]);
    return c[now];
}

int calTrie(int p) {
    int count = 0;
    int nowbit = 0;
    for(int j=maxDepth;j>=0;j--) {
        int bit = (1<<j);
        int k = m&bit;
        nowbit |= k;
            // 该位是0,则置上1肯定比m大
        if(k == 0) {
            count += countTrie((nowbit|bit)^p, j);
        }
    }
    return count;
}

int main() {
    scanf("%d%d", &n, &m);
    for(int i=0;i<n;i++) {
        scanf("%d", &a[i]);
    }
    buildTrie();
    __int64_t count = 0;
    for(int i=0;i<n;i++) {
        count += calTrie(a[i]);
    }
    printf("%ld", count/2);
    return 0;
}

发表于 2018-06-27 12:33:32 回复(0)
依据楼上的思路,将循环改为递归,更短小更易理解
```
#include<iostream>
#include<vector>

using namespace std;

class query_tree {
public:
	query_tree *next[2]{NULL,NULL};
    int count;
    query_tree() :count(1) { }
};

query_tree root;
void build_tree(int m) {
	query_tree *cur=&root;
	for(int j=16;j>=0;j--) {
    	bool flag=m>>j & 1;
        if(!cur->next[flag]) {
            cur->next[flag]=new query_tree;
        }
        else
           	cur->next[flag]->count++;
        cur=cur->next[flag];
    }
}

long long query_num(int n,int m,query_tree *root,int index) {
    if(index<0)
        return 0;
    int n_i=n>>index & 1;
    int m_i=m>>index & 1;
    if(n_i==1 && m_i==1) {
        return root->next[0]?query_num(n,m,root->next[0],index-1):0;
    }
    else if(n_i==1 && m_i==0){
        long long val1=root->next[0]?root->next[0]->count:0;
        long long val2=root->next[1]?query_num(n,m,root->next[1],index-1):0;
        return val1+val2;
    }
    else if(n_i==0 && m_i==1){
        return root->next[1]?query_num(n,m,root->next[1],index-1):0;
    }
    else {
        long long val1=root->next[1]?root->next[1]->count:0;
        long long val2=root->next[0]?query_num(n,m,root->next[0],index-1):0;
        return val1+val2;
    }
}
int main() {
    int n,m;
    cin>>n>>m;
    vector<int> vi(n);
    long long count=0;
    for(int i=0;i<n;i++) {
        cin>>vi[i];
        build_tree(vi[i]);
    }
    for(int i=0;i<n;i++)
        count += query_num(vi[i],m,&root,16);
    cout<<count/2;
    return 0;
}
```

编辑于 2017-09-09 21:37:47 回复(0)
from itertools import combinations
n,m = map(int,input().split())
print(len(set(filter(lambda x:x > m,map(lambda x:x[0]^x[1],combinations(list(map(int,input().split())),2))))))
不知道为什么会报错,在小数据集上测试结果还可以,我看这题还没有Python解法,抛砖引玉啦
发表于 2018-07-08 20:28:27 回复(1)
/*1、把所有数字转换成等长的字符串,前补0 
 *2、把字符串插入到tire树中 
 *3、每插入一次,同时查询已经在tire树中的串和当前要插入的串的亦或结果大于m的有多少个(最后一期算再除2也可以,这里直接避免了重复),累加cnt * 查询时,如果m的第k位为0,那么与当前串如果为0,显然tire中k位为1的左右子串都大于m,因此结果加上tire中k位为1的cnt,然后要找tire中第k位为0的,看下一位会不会产生比m大的;如果当前串第K位为1与之相反。 如果m的第K为1,如果当前串的第K位是0,与上面对应,显然tire中k位为0的左右子串都小于m,不累计cnt,直接找tire中第k位为1的时候,k+1位会不会比m大。 
 * 4、输出结果 */
#include<iostream> #include<cstring> #include<algorithm> #include<string> usingnamespacestd; structNode {     Node* next[2];     intcnt;     Node()     {         cnt=0;         memset(next,0,sizeof(next));     } }; voidinsert(Node* root,constchar* s) {     while(*s)     {         if(!root->next[*s-'0'])             root->next[*s-'0']=newNode();         root = root->next[*s-'0'];         root->cnt++;         ++s;     } } longlongsearch(Node* root,constchar* now,constchar* s) {     longlongcnt = 0;     while(*s)     {         if(*now=='0'&&*s=='1')         {             if(root->next[1]==NULL)                 break;             root=root->next[1];         }         elseif(*now=='0'&&*s=='0')         {             if(root->next[1])                 cnt+=root->next[1]->cnt;             if(root->next[0]==NULL)                 break;             root=root->next[0];         }         elseif(*now=='1'&&*s=='1')         {             if(root->next[0]==NULL)                 break;             root = root->next[0];         }         elseif(*now=='1'&&*s=='0')         {             if(root->next[0])                 cnt+=root->next[0]->cnt;             if(root->next[1]==NULL)                 break;             root=root->next[1];         }         ++s;         ++now;     }     returncnt; } string int2str(intvalue) {     string s;     while(value)     {         s+=(value%2+'0');         value/=2;     }     while(s.size()<18)         s+='0';     reverse(s.begin(),s.end());     returns; } intmain() {     intn,m;     cin>>n>>m;     intk;     Node* root = newNode();     longlongcnt = 0;     string sz_m = int2str(m);     while(n--)     {         cin>>k;         string s=int2str(k);         insert(root, s.c_str() );         cnt+=search(root, s.c_str() ,sz_m.c_str());     }     cout<<cnt<<endl;     return0; }
编辑于 2018-05-22 10:26:55 回复(0)
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
ll ans;
int m;
struct Node
{
    int sz;
    struct Node *ls,*rs;
    Node():sz(0),ls(NULL),rs(NULL){}
    ~Node(){delete ls;delete rs;}
};
struct Trie
{
    int B;
    Node *root;
    Trie(int B):B(B),root(new Node){}
    ~Trie(){delete root;}
    void insert(int k)
    {
        Node *now = root;
        root->sz+=1;
        for(int i=B-1;i>=0;--i)
        {
            int b=((k>>i)&1);
            if(b==0)
            {
                if(now->ls == NULL) now->ls = new Node;
                now=now->ls;
            }
            else
            {
                if(now->rs == NULL) now->rs = new Node;
                now=now->rs;
            }
            now->sz+=1;
        }
    }
    void query(int k)
    {
        Node *now = root;
        int x = k^m;
        for(int i=B-1;i>=0;--i)
        {
            int b = ((x>>i)&1);
            if(((m>>i)&1)^1)
            {
                if(b==1&&now->ls!=NULL) ans+=now->ls->sz;
                else if(b==0&&now->rs!=NULL) ans+=now->rs->sz;
            }
            if(b==0)
            {
                if(now->ls == NULL) break;
                now = now->ls;
            }
            else
            {
                if(now->rs == NULL) break;
                now = now->rs;
            }
        }
    }
};
int main()
{
    Trie t(20);
    int n;
    scanf("%d%d",&n,&m);
    for(int i=0;i<n;++i)
    {
        int k ;
        scanf("%d",&k);
        t.query(k);
        t.insert(k);
    }
    cout << ans << '\n';
    return 0;
}
发表于 2017-12-19 23:39:07 回复(0)
#include <algorithm>
#include <stdio.h>
#include <string.h>
#include <math.h>
using namespace std;
typedef long long ll;
const int maxn=100007;
int a[maxn];
struct Node{
    int num;
    Node *next[2];
    void init(){
        num=0;
        memset(next,(int)NULL,sizeof(next));
    }
}newnode[maxn*100];
Node *root;
int p;
Node* getnewnode(){
    newnode[p].init();
    return &newnode[p++];
}
void init(){
    p=0,root=getnewnode();
}
void insert(Node *cur,char *s) {
    if(*s=='\0') return;
    int index=*s-'0';
    if(cur->next[index]==NULL) cur->next[index]=getnewnode();
    cur->next[index]->num++;
    insert(cur->next[index],s+1);
}
ll query(Node *cur,int curi,char *s,char *sm){
    if(!cur) return 0;
    int si=*s-'0',smi=*sm-'0';
    if((curi^si)<smi) return 0;
    else if((curi^si)==smi) 
        return query(cur->next[0],0,s+1,sm+1)+query(cur->next[1],1,s+1,sm+1);
    else if((curi^si)>smi) return 1ll*cur->num;
    return 0;
}
void getstr(int num,char *s,int n) {
    int i=n;
    for(i=0;i<n;i++) s[i]='0';
    s[i--]='\0';
    while(num){
        if(num%2) s[i]='1';
        i--,num>>= 1;
    }
}
char str[27],strm[27];
int main() {
    int i,j,n,m,mx=0,mx2=0;
    scanf("%d%d",&n,&m);
    for(i=0;i<n;i++) scanf("%d",a+i);
    sort(a,a+n);
    mx=max(m,a[n-1]);
    while(mx) mx2++,mx>>=1;
    getstr(m,strm,mx2);
    init();
    int tmp,stri;
    ll ans=0;
    for(i=0;i<n;i++){
        getstr(a[i],str,mx2);
        ans+=query(root->next[0],0,str,strm)+query(root->next[1],1,str,strm);
        insert(root,str);
    }
    printf("%lld\n",ans);
}

发表于 2017-10-30 15:45:37 回复(0)
对于100%的数据,1 <= n, m, Ai <= 10^5
,n*(n-1)/2这个数据怎么也超不了,int表示的范围,我用int建的数组,告诉我超了范围了。



发表于 2018-08-10 15:33:53 回复(1)
直接计算肯定是超时的,所以这问题不能使用暴力破解,考虑到从高位到地位,依次进行位运算,如果两个数异或结果在某高位为1,而m的对应位为0,则肯定任何这两位异或结果为1的都会比m大。
由此,考虑使用字典树(TrieTree)从高位到第位建立字典,再使用每个元素依次去字典中查对应高位异或为1, 而m为0的数的个数,相加在除以2既是最终的结果;直接贴出代码如下,非原创,欢迎讨论;
补充:queryTrieTree在搜索的过程中,是从高位往低位搜索,那么,如果有一个数与字典中的数异或结果的第k位大于m的第k位,那么该数与对应分支中所有的数异或结果都会大于m, 否则,就要搜索在第k位异或相等的情况下,更低位的异或结果。queryTrieTree中四个分支的作用分别如下:
1. aDigit=1, mDigit=1时,字典中第k位为0,异或结果为1,需要继续搜索更低位,第k位为1,异或结果为0,小于mDigit,不用理会;
2. aDigit=0, mDigit=1时,字典中第k位为1,异或结果为1,需要继续搜索更低位,第k位为0,异或结果为0,小于mDigit,不用理会;
3. aDigit=1, mDigit=0时,字典中第k位为0,异或结果为1,与对应分支所有数异或,结果都会大于m,第k位为1,异或结果为0,递归获得结果;
4. aDigit=0, mDigit=0时,字典中第k位为1,异或结果为1,与对应分支所有数异或,结果都会大于m,第k位为0,异或结果为0,递归获得结果;
import java.util.Scanner;

public class Main {
    private static class TrieTree {
        TrieTree[] next = new TrieTree[2];
        int count = 1;
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        while (sc.hasNext()){
            int n = sc.nextInt();
            int m = sc.nextInt();
            int[] a = new int[n];
            for (int i = 0; i < n; i++) {
                a[i] = sc.nextInt();
            }
            System.out.println(solve(a, m));
        }
    }

    private static long solve(int[] a, int m) {
        TrieTree trieTree = buildTrieTree(a);
        long result = 0;
        for (int i = 0; i < a.length; i++) {
            result += queryTrieTree(trieTree, a[i], m, 31);
        }
        return result / 2;
    }

    private static long queryTrieTree(TrieTree trieTree, int a, int m, int index) {
        if(trieTree == null)
            return 0;

        TrieTree current = trieTree;
        for (int i = index; i >= 0; i--) {
            int aDigit = (a >> i) & 1;
            int mDigit = (m >> i) & 1;
            if(aDigit == 1 && mDigit == 1) {
                if(current.next[0] == null)
                    return 0;
                current = current.next[0];
            } else if (aDigit == 0 && mDigit == 1) {
                if(current.next[1] == null)
                    return 0;
                current = current.next[1];
            } else if (aDigit == 1 && mDigit == 0) {
                long p = queryTrieTree(current.next[1], a, m, i - 1);
                long q = current.next[0] == null ? 0 : current.next[0].count;
                return p + q;
            } else if (aDigit == 0 && mDigit == 0) {
                long p = queryTrieTree(current.next[0], a, m, i - 1);
                long q = current.next[1] == null ? 0 : current.next[1].count;
                return p + q;
            }
        }
        return 0;
    }

    private static TrieTree buildTrieTree(int[] a) {
        TrieTree trieTree = new TrieTree();
        for (int i = 0; i < a.length; i++) {
            TrieTree current = trieTree;
            for (int j = 31; j >= 0; j--) {
                int digit = (a[i] >> j) & 1;
                if(current.next[digit] == null) {
                    current.next[digit] = new TrieTree();
                } else {
                    current.next[digit].count ++;
                }
                current = current.next[digit];
            }
        }
        return trieTree;
    }
}

编辑于 2017-08-02 17:09:35 回复(25)
/*
C++
思路来源:潇潇古月
思路:
    直接计算肯定是超时的,所以这问题不能使用暴力破解,考虑到从高位到地位,依次进行位运算,
    如果两个数异或结果在某高位为1,而m的对应位为0,则肯定任何这两位异或结果为1的都会比m大。
    由此,考虑使用字典树(TrieTree)从高位到第位建立字典,再使用每个元素依次去字典中查对应
    高位异或为1, 而m为0的数的个数,相加在除以2既是最终的结果;直接贴出代码如下,非原创,欢迎讨论;
    补充:queryTrieTree在搜索的过程中,是从高位往低位搜索,那么,如果有一个数与字典中的数异或结果
    的第k位大于m的第k位,那么该数与对应分支中所有的数异或结果都会大于m, 否则,就要搜索在第k位异或
    相等的情况下,更低位的异或结果。queryTrieTree中四个分支的作用分别如下:
    1. aDigit=1, mDigit=1时,字典中第k位为0,异或结果为1,需要继续搜索更低位,第k位为1,异或结果为0,小于mDigit,不用理会;
    2. aDigit=0, mDigit=1时,字典中第k位为1,异或结果为1,需要继续搜索更低位,第k位为0,异或结果为0,小于mDigit,不用理会;
    3. aDigit=1, mDigit=0时,字典中第k位为0,异或结果为1,与对应分支所有数异或,结果都会大于m,第k位为1,异或结果为0,递归获得结果;
    4. aDigit=0, mDigit=0时,字典中第k位为1,异或结果为1,与对应分支所有数异或,结果都会大于m,第k位为0,异或结果为0,递归获得结果;

改进:
    1.字典树17位即可保证大于100000,移位范围为1~16位,则字典树构建时从16~0即可。
       字典树第一层不占位,实际上是15~-1层有数据,这也是数据中next的用法。
    2.queryTrieTree函数需要考虑到index为-1时的返回值。
    
时间复杂度:O(n);
空间复杂度O(k),k为常数(trie树的高度),因此可以认为O(1)。
 */
#include <iostream>
#include <vector>
using namespace std;

struct TrieTree
{
    int count;//每个节点存的次数
    struct TrieTree* next[2]{NULL,NULL};//每个节点存储两个节点指针
    TrieTree():count(1){}
};

TrieTree* buildTrieTree(const vector<int>& array)
{
    TrieTree* trieTree = new TrieTree();
    for(int i=0;i<(int)array.size();++i)
    {
        TrieTree* cur = trieTree;
        for(int j=16;j>=0;--j)
        {
            int digit = (array[i] >> j) & 1;
            if(NULL == cur->next[digit])
                cur->next[digit] = new TrieTree();
            else
                ++(cur->next[digit]->count);
            cur = cur->next[digit];
        }
    }
    return trieTree;
}

//查询字典树
long long queryTrieTree(TrieTree*& trieTree, const int a, const int m, const int index)
{
    if(NULL == trieTree)
        return 0;

    TrieTree* cur = trieTree;
    
    for(int i=index;i>=0;--i)
    {
        int aDigit = (a >> i) & 1;
        int mDigit = (m >> i) & 1;

        if(1==aDigit && 1==mDigit)
        {
            if(NULL == cur->next[0])
                return 0;
            cur = cur->next[0];
        }
        else if(0 == aDigit && 1==mDigit)
        {
            if(NULL == cur->next[1])
                return 0;
            cur = cur->next[1];
        }
        else if(1 == aDigit && 0 == mDigit)
        {
            long long val0 =  (NULL == cur->next[0]) ? 0 : cur->next[0]->count;
            long long val1 =  queryTrieTree(cur->next[1],a,m,i-1);
            return val0+val1;
        }
        else if(0 == aDigit && 0 == mDigit)
        {
            long long val0 =  queryTrieTree(cur->next[0],a,m,i-1);
            long long val1 =  (NULL == cur->next[1]) ? 0 : cur->next[1]->count;
            return val0+val1;
        }
    }    
    return 0;//此时index==-1,这种情况肯定返回0(其他情况在循环体中都考虑到了)
}

//结果可能超过了int范围,因此用long long
long long solve(const vector<int>& array, const int& m)
{
    TrieTree* trieTree = buildTrieTree(array);
    long long result = 0;
    for(int i=0;i<(int)array.size();++i)
    {
        result += queryTrieTree(trieTree,array[i],m,16);
    }
    return result /2;
}

int main()
{
    int n,m;
    while(cin>>n>>m)
    {
        vector<int> array(n);
        for(int i=0;i<n;++i)
            cin>>array[i];
        cout<< solve(array,m) <<endl;
    }
    return 0;
}

发表于 2017-08-06 15:58:24 回复(8)

这题C++直接暴力都有80%

发表于 2018-10-05 17:09:50 回复(0)
import java.util.*;
/*fine,超时了暂时没招了*/
public class Main
{
    public static void main(String[] args)
    {
        Scanner sc = new Scanner(System.in);
        int m,n;
        n=sc.nextInt();m=sc.nextInt();
        
        int[] A=new int[n];
        for(int i=0;i<n;i++)
            A[i]=sc.nextInt();
        f(n,m,A);
    }
    public static void f(int n,int m,int[] A)
        
    {
        int ct=0;
        for(int i=0;i<n;i++)
            for(int j=i+1;j<n;j++)
            {
                int t=A[i]^A[j];
                if(t>m)ct++;
            }
        System.out.println(ct);
    }
}


发表于 2021-08-22 16:32:36 回复(0)
#include <bits/stdc++.h>
using namespace std;

class Trie {
public:
    Trie() : cnt(0) {}
    
    void insert(int n) {
        Trie* p = this;
        for (int i = 31; i >= -1; --i) {
            ++p->cnt;
            if (i == -1) break;
            int t = !!((1<<i) & n);
            if (!p->child[t]) p->child[t] = new Trie();
            p = p->child[t];
        }
    }
    
    long long search(int x, int M) {
        Trie* p = this;
        long long ret = 0;
        for (int i = 31; i >= 0; --i) {
            if (!p) break;
            int t0 = !!((1<<i) & x), t1 = !!((1<<i) & M);
            if (t1 == 1) {
                p = p->child[t0^1];
            }
            else {
                ret += p->child[t0^1] ? p->child[t0^1]->cnt : 0;
                p = p->child[t0];
            }
        }
        return ret;
    }
    
    Trie* child[2];
    long long cnt;
};


int main() {
    int N, M, x;
    cin >> N >> M;
    Trie* trie = new Trie();
    long long res = 0;
    while(N--) {
        scanf("%d", &x);
        res += trie->search(x, M);
        trie->insert(x);
    }
    cout << res << "\n";
    return 0;
}

发表于 2020-12-02 20:56:35 回复(0)
#include<bits/stdc++.h>
using namespace std;

const int maxn =1e5+7;
int tree[2*maxn][2];
int cnt[2*maxn][2];
int tot=0;
void ins(int x){
    int idx=0;
    for(int i=17;i>=0;i--){
        int bit=(x>>i)&1;
        if(tree[idx][bit]==0) tree[idx][bit]=++tot;
        cnt[idx][bit]++;
        idx=tree[idx][bit];
    }
}


long long ans=0;
void fd(int idx,int x,int m,int n_bit){
    if(n_bit==-1||idx>tot) return ;
    if(!tree[idx][0]&&!tree[idx][1]) return ;
    // this bit of m is 1
    int b1=(x>>n_bit)&1;
    int b2=(m>>n_bit)&1;
    if(b2==1){
        //this bit of m is 1 ,of x is 0 then y must be 1
        if(b1==0&&tree[idx][1]) fd(tree[idx][1],x,m,n_bit-1);
        //this bit of m is 1 ,of x is 1 then y must be 0
        if(b1==1&&tree[idx][0]) fd(tree[idx][0],x,m,n_bit-1);
    }
    if(b2==0){
        if(b1==0){
            //this bit of m is 0 ,of x is 0,add number of bit 1 and recursion for bit 0
            ans+=cnt[idx][1];
            if(tree[idx][0]) fd(tree[idx][0],x,m,n_bit-1);
        }
        if(b1==1){
            //this bit of m is 0 ,of x is 1,add number of bit 0 and recursion for bit 1
            ans+=cnt[idx][0];
            if(tree[idx][1]) fd(tree[idx][1],x,m,n_bit-1);
        }
    }
}


int main(){

    int N,M;
    cin>>N>>M;
    vector<int> vec(N);
    for(int i=0;i<N;i++){
        cin>>vec[i];
        ins(vec[i]);
    }
    for(int i=0;i<N;i++){
        fd(0,vec[i],M,17);
    }
    cout<<ans/2<<endl;
    //3 10 5 6 10
//    for(int i=0;i<36;i++){
//        cout<<tree[i][0] <<" "<<tree[i][1]<<endl;
//    }
//    cout<<"-------------------\n";
//    for(int i=0;i<36;i++){
//        cout<<tree[i][0] <<" "<<tree[i][1]<<endl;
//    }

}
tree数组开的大小一般是,N*每一层节点的最大种类数,比如英文小写是26个,数字字符是10个,等等。
发表于 2020-09-20 00:04:56 回复(0)
直接暴力运算通过率是80%,求解哪里出了问题。
#include<iostream>
using namespace std;
int main(void)
{
    int n,m;
    while(cin>>n>>m)
    {
        int A[100001]={0};
        int i,j,s=0;
        for(i=0;i<n;i++)
            cin>>A[i];
        for(i=0;i<n;i++)
            for(j=i+1;j<n;j++)
            {
                int t;
                t=A[i]^A[j];
                if(t>m) s++;
            }
        cout<<s<<endl;
    }
    return 0;
}

编辑于 2020-09-14 17:35:01 回复(0)

超时
n,m=map(int,input().split())
A=list(map(int,input().split()))
sum=0
k=n-1
i=0
while k:
for j in range(i+1,n):
if A[i]^A[j] > m:
sum+=1
i+=1
k-=1
print(sum)

发表于 2020-08-30 16:59:11 回复(0)
#include <iostream>
using namespace std;

#define MAX_N 100000

struct TrieTreeNode
{
	int size = 0;
	TrieTreeNode* child[2];
};

int minSum;
int a[MAX_N];
TrieTreeNode* root;

void InsertToTree(int num)
{
	TrieTreeNode* currentNode = root;

	for (int i = 17; i >= 0; i--)
	{
		if (nullptr == currentNode->child[(num >> i) & 1])
		{
			TrieTreeNode* newNode = new TrieTreeNode();
			currentNode->child[(num >> i) & 1] = newNode;
			currentNode = newNode;
		}
		else
		{
			currentNode = currentNode->child[(num >> i) & 1];
		}
		currentNode->size++;
	}
}

int Calculate(int num)
{
	TrieTreeNode* currentNode = root;

	int numBit, minBit;
	int result = 0;
	for (int i = 17; i >= 0; i--)
	{
		numBit = (num >> i) & 1;
		minBit = (minSum >> i) & 1;
		// 如果m的当前位是0
		if (minBit == 0)
		{
			// 如果当前节点与numBit不同分支上有节点,则将其size加入result
			if (nullptr != currentNode->child[!numBit])
			{
				result += (currentNode->child[!numBit]->size);
			}	
			if (nullptr != currentNode->child[numBit])
			{
				currentNode = currentNode->child[numBit];
				continue;
			}
			break;
		}
		else if (nullptr != currentNode->child[!numBit])
		{
			currentNode = currentNode->child[!numBit];
			continue;
		}
		break;
	}
	return result;
}

int main()
{   
	root = new TrieTreeNode();
	minSum = 0;
	// memset(a, 0, MAX_N * sizeof(int));
    
    int n;
	cin >> n >> minSum;
	for (int i = 0; i < n; i++)
	{
		cin >> a[i];
		InsertToTree(a[i]);
	}
	long long total = 0;
	for (int i = 0; i < n; i++)
	{
		total += Calculate(a[i]);
	}
	cout << total / 2;
}

发表于 2020-08-29 20:31:03 回复(0)
只能过80%,96609 650样例过不了,但是99202 47499的样例都可以过了,为什么呀。。。
def get_bin(x):
    res = []
    while x:
        res.append(str(x&1))
        x >>= 1
    res += (18-len(res))*['0']
    return "".join(res[::-1])
# print(get_bin(100000))
x = input().split()
n, m = int(x[0]), int(x[1])
s = [int(i) for i in input().split()]
global ans
ans = 0
m_bin = get_bin(m)
# print(m_bin.find('1'))
# print(m_bin)
k = m_bin.find('1')
bins = {}
for i in range(k, 19):
    bins[i] = {}
for num in s:
    num_bin = get_bin(num)
    for j in range(k, 19):
        bins[j][num_bin[:j]] = bins[j].get(num_bin[:j], 0) + 1
for key in bins[k]:
    ans += bins[k][key] * (n-bins[k][key])
# print(bins)
ans //= 2
def search(kk, key1, key2):
    if kk >= 18:
        return
    global ans
    key1_1 = key1 + '1'
    key1_0 = key1 + '0'
    key2_1 = key2 + '1'
    key2_0 = key2 + '0'
    if m_bin[kk] == '1':
        if key1_1 in bins[kk+1] and key2_0 in bins[kk+1]:
            search(kk+1, key1_1, key2_0)
        if key1_0 in bins[kk + 1] and key2_1 in bins[kk + 1]:
            search(kk+1, key1_0, key2_1)
    else:
        # print(bins[kk+1])
        # print(key1, key2)
        ans += bins[kk+1].get(key1_0, 0) * bins[kk+1].get(key2_1, 0) + bins[kk+1].get(key1_1, 0) * bins[kk+1].get(key2_0, 0)
        if key1_1 in bins[kk+1] and key2_1 in bins[kk+1]:
            search(kk + 1, key1_1, key2_1)
        if key1_0 in bins[kk+1] and key2_0 in bins[kk+1]:
            search(kk + 1, key1_0, key2_0)
for key in bins[k]:
    search(k+1, key+"1", key+"0")
print(ans)


发表于 2020-08-15 11:57:09 回复(0)
为此题写了一篇博客:
发表于 2020-08-02 14:40:02 回复(0)

看到没有Go语言版本的,那我就写个Go语言版本的吧,思路参考第一页的大佬的思路,字典树是个好东西

package main
import (
    "fmt"
)
type TrieTree struct {//01字典树
    next [2]*TrieTree
    count int
}
func createTrieTree() *TrieTree {
    return &TrieTree{
        next:  [2]*TrieTree{},
        count: 1,
    }
}
func buildTrieTree(trieTree *TrieTree,A []int) *TrieTree {
    for i := 0; i<len(A); i++ {
        current := trieTree
        for j := 31; j >=0 ; j-- {
            digit := (A[i]>>j) & 1
            if current.next[digit] == nil {
                current.next[digit] = createTrieTree()
            }else {
                current.next[digit].count++
            }
            current = current.next[digit]
        }
    }
    return trieTree
}
func queryTrieTree(trieTree *TrieTree, a int, m int, digitNum int) int {
    if trieTree == nil {
        return 0
    }
    current := trieTree;
    for i := digitNum; i >= 0; i-- {
        aDigit, mDigit := (a >> i) & 1, (m >> i) & 1;
        if aDigit == 1 && mDigit == 1 {
            if current.next[0] == nil {
                return 0
            }
            current = current.next[0]
        }else if aDigit == 0 && mDigit == 1 {
            if current.next[1] == nil {
                return 0
            }
            current = current.next[1]
        }else if aDigit == 0 && mDigit == 0{
            p := queryTrieTree(current.next[0], a, m, i - 1)
            var q int
            if current.next[1] == nil {
                q = 0
            }else{
                q = current.next[1].count
            }
            return p + q
        }else if aDigit == 1 && mDigit == 0 {
            p := queryTrieTree(current.next[1], a, m, i -1)
            var q int
            if current.next[0] == nil{
                q = 0
            }else{
                q = current.next[0].count
            }
            return p + q
        }
    }
    return 0
}
func solve(n int, m int, A []int) int {
    trieTree := createTrieTree()
    trieTree = buildTrieTree(trieTree, A)
    num := 0
    for i := 0; i < n; i++ {
        num += queryTrieTree(trieTree, A[i], m, 31)
    }
    return num/2
}
func main() {
    var n int
    fmt.Scanf("%d", &n)
    var m int
    fmt.Scanf("%d", &m)
    A := make([]int, n)
    for i:=0; i<n; i++{
        fmt.Scanf("%d", &A[i])
    }
    fmt.Print(solve(n, m, A))
}
发表于 2020-07-30 18:33:29 回复(0)

   
#include<bits/stdc++.h>
using namespace std;
struct TrieNode
{
    long path;
    long end;
    vector<TrieNode*>map;
    TrieNode()
    {
        path=0;
        end=0;
        map.resize(2,NULL);
    }
};
class Trie
{
public:
    TrieNode*root;
    vector<long>arr;
    long n;
public:
    Trie(){root=new TrieNode();n=0;}
    void insert(long num);
    long solve(long m);
};
void Trie::insert(long num)
{
     n++;
     arr.push_back(num);
     TrieNode* p=root;
     p->path++;
     long pow=1;
     long index;
     for(long i=16;i>=0;i--)
     {
         pow=1<<i;
         if((num&pow)>0){
            index=1;
         } else{
            index=0;
         }
         if(p->map[index]==NULL)
         {
             p->map[index]=new TrieNode();
         }
         p=p->map[index];
         p->path++;
     }
     p->end++;
};
long Trie::solve(long m)
{
    long res=0;
    long pow=1;
    long t1,t2;
    TrieNode* p=root;
    for(long i=0;i<n;i++)
    {
        p=root;
        for(long j=16;j>=0;j--)
        {
            pow=1<<j;
            t1=pow&m;
            t2=pow&arr[i];
            if(t1==0&&t2>0)
            {
                if(p->map[0]!=NULL){
                    res=res+p->map[0]->path;
                }
                if(p->map[1]!=NULL){
                    p=p->map[1];
                }else{
                    break;
                }
            }
            else if(t1==0&&t2==0)
            {
                if(p->map[1]!=NULL){
                    res=res+p->map[1]->path;
                }
                if(p->map[0]!=NULL){
                    p=p->map[0];
                }else{
                    break;
                }
            }
            else if(t1>0&&t2>0)
            {
                if(p->map[0]!=NULL){
                    p=p->map[0];
                }else{
                    break;
                }
            }
            else if(t1>0&&t2==0)
            {
                if(p->map[1]!=NULL){
                    p=p->map[1];
                }else{
                    break;
                }
            }
        }
    }
    return res/2;
}
int main()
{
    long res=0;
    Trie myTrie;
    long n,m;
    cin>>n>>m;
    long num;
    for(long i=0;i<n;i++)
    {
        cin>>num;
        myTrie.insert(num);
    }
    res=myTrie.solve(m);
    cout<<res<<endl;
    return 0;
}


发表于 2020-07-09 16:53:21 回复(0)

问题信息

难度:
43条回答 17669浏览

热门推荐

通过挑战的用户

查看代码