kd树踩坑指南

被网易问到了,原问题是
二维平面上,有很多很多点。
现在给一个quary点,要求从二维平面上找到距离quary点最新的点。何解?
后端有后端的做法,机器学习有机器学习的做法。这个题看过小蓝书的都会做,就是最近邻分类器的简单魔改。当时忘得差不多了,现在把kd树仿真了一下,坑还是比较多的,列一下。
小蓝书P40-P45就是这部分内容。
KD树是实现快速找最近的m个点算法实现。
KD树是对坐标轴做了矩形划分,不断用垂直坐标轴去划分矩形区域(***超平面叫超矩形区域),然后利用矩形区域减少搜索量。
由于垂直坐标轴划分的区域,KD树理应用于欧式距离?(个人猜测)
比较好的划分方法是每次选方差最大的方向对应的维度对空间划分,这里用的小蓝书上的方法。第一层就是第一维,第二层就是第二维,第三层就是第三维,更容易实现。
详情见小蓝书。
构造:
输入:k维空间数据集T,其中每个x都是列向量。
输出kd树。
1、构造根结点,以x第一维坐标的中位数为切分点,根结点对应划分两个区域,左子树代表<,右子树代表>,原数据根据大小划分到两个子树。
落在超平面上的点保存为根节点。
2、重复对与深度为j的结点,以(j模k)+1维的中位数切分数据构造左右子区域,等于数据落在结点上,小于落左边大于落右边。直到子区域没有实例。
坑点1:中位数是不是严格中位数?
应该不是,假如有偶数个样本,中位数要么是靠左的要么是靠右的,不可以是中间的平均,这样就没有落在结点上的数据实例了。
坑点2:落在超平面上的点保存在节点上
落在结点上可以有很多组数据,一个都不可以舍掉。虽然舍掉了绝大多数例子都可以过,但是严格的仿真过不了。
重复的数据是可能存在的,比如根结点,划分点是中位数5,那么第一维等于5的落在根上,要知道第一维等于5的数据有多个是小概率事件但不代表没有。
越往后落在结点上的数据越趋近于一个,甚至是我们只在结点上保存一个数据点发现绝大多数例子都是可以过的,原因就是这种异常概率还是太小了。除非数据范围很小,数据很多。
搜索:
输入:kd树,目标点
输出:最近邻。
1、找到包含目标点的叶结点:从根出发,递归向下访问,小于找左子结点,否则右子结点,直到是叶子结点。
2、叶结点为最近点。
3、递归回退,对每个结点操作:
1)如果该子结点更近,更新最近点。
2)找该子结点的另一对应子结点对应区域是否有更近的点,有则移动到另一子结点,搜索整个区域最近点。
判断的方法是以目标点与最近点的距离为半径画圆,判断会不会与目标区域相交。
不相交继续回退。
4、回退到根结点,输出最近点。
复杂度O(log(N))。实际随着k上涨效率迅速下降几乎接近线性扫描。
(个人仿真的结果,不应该叫迅速下降,而是极速下降。。。)

坑点1:算法中很多很多递归。
要想理解递归,必须先理解递归。个人觉得递归不太好理解,用栈模拟的。
坑点2:子结点回退父结点。
最大的坑点是子结点可以是空的,但是空的子结点如何当初始化默认的最近结点?这里书上没说清楚,导致我认为是第一个最近的非空结点。
奇妙的是就算写的不对,绝大多数示例是可以过的,我就直接找了牛客上的一个题(就5个test_case),就算写的很多不对的地方还是过了。
后来自己造了例子才发现了问题。
改bug,发现回退的时候由于是空结点,不能用前向的parent引用回退,那就很尴尬,导致改完了整个parent没有用了。
坑点3:无限递归。
由于相交时要求移动到另一个子结点。但是假如另一个子结点搜索完最近距离没有更新,那么这个另外的子结点又会移动到原来的子结点。导致无限互相跳死循环。
有可能是非递归写法导致的,但我不知道官方标准的写法。这里用了一个set()保存有没有遍历过,与书上的标准写法相违背。
坑点4:小于进入左结点,否则右结点。
小问题,否则这里包含了等于。

#建立树
#kdtree按垂直坐标轴的线划分原始空间,寻找方差最大的维度,按中位数切分成两个子方向,递归进行
#但是李航老师的书上给的是这样划分的,从1维度开始,切分,深度为2用二维度,3用3维度。超过k取模+1
#同样是中位数切分
#没有优化,主要用来学习实现而不是使用。

MAX_ = 0x3f3f3f3f3f3f3f3f
#原始数据采用每一列是一个样本
#因为用的python的list存数据,所以切片是行切片,使用numpy则可以列切片避免重复的转置了
########################################
################结点定义################
'''
val:切分点
layer:当前层
left:左结点
right:右结点
data:当前结点保存数据,是一个集合,每一行是一个样本
parent:父结点。看看就好,没啥用,因为理解错误这里修了很久的bug。
'''
class Node:
    def __init__(self,val,layer,left=None,right = None):
        self.val = val
        self.layer = layer
        self.left = left
        self.right = right
        self.data = None
        self.parent = None
    def print(self):
        print('layer:',self.layer,'val:',self.val)
########################################
################转置函数################
#手动深拷贝
def T(data):
    if not data:
        return None
    row,col = len(data),len(data[0])
    return [[data[i][j]for i in range(row)] for j in range(col)]
########################################
###############中位数获取###############
def get_mid(data,k):
    cur_data = data[k][:]
    mid = (1+len(cur_data))//2-1
    return _get_mid(cur_data,mid,0,len(cur_data)-1)#转化为第k大的数
def _get_mid(data,k,st,en):
    index = partition(data,k,st,en)
    if index==k:
        return data[k]
    else:
        if(index<k):
            return _get_mid(data,k,index+1,en)
        else: 
            return _get_mid(data,k,st,index-1)
def partition(data,k,st,en):
    if st==en:
        return st
    base = data[st]
    base_index = st
    st+=1
    while st<=en:
        while st<=en and data[st]<=base:
            st+=1
        while st<=en and data[en]>=base:
            en-=1
        if st<=en:
            data[st],data[en] = data[en],data[st]   
    data[en],data[base_index] = data[base_index],data[en]
    return en
########################################
#################距离计算###############
#此处是欧式距离
def dis(d1,d2):
    #应保持向量长度一致
    res = [(d1[i]-d2[i])**2 for i in range(len(d1))]
    return sum(res)**0.5
#x到结点最小距离,结点处存的是向量组
#返回值:(距离,数据向量)tuple
def disfind(d1,Node):   
    if not Node:
        return (MAX_,None)
    res = MAX_
    cur_data = None
    d2 = Node.data
    for i in d2:
        cur_res = dis(d1,i)
        if res>cur_res:
            res = cur_res
            cur_data = i
    return (res,cur_data)
########################################

#KDtree 单点
class kdtree:
    ########################################
    #################构造函数###############
    def __init__(self,data):
        self.dataset = data
        self.k = len(data)
        #注意,数据用的是x = [1,2,3,4].T的形式,是个列向量。一行就是一整个特征,一列就是一整个样本
        #这与小蓝书的结构是一致的
        self.root = None
        self.root = self._maketree(self.root,data,0)
        #父结点构建
        self.root.parent = None
        if self.root.left:
            self.root.left.parent = self.root
        if self.root.right:
            self.root.right.parent = self.root
    ########################################
    ###############建立二叉树###############
    def _maketree(self,root,data,dep):
        #变量解释root 父结点 data子数据集 dep深度
        if not data or len(data)==0:
            return None
        cur_k = dep % self.k#根据李航老师的说法写的,从0开始就不用+1了,书上从1开始
        mid = get_mid(data,cur_k)#找到中位数
        #new一个结点出来,mid切分的数据分割点,cur_k切分的特征维
        root = Node(mid,cur_k)
        #因为对样本切分,所以转置一下,此时一行是一个样本
        data = T(data)
        left_data = [x for x in data if x[cur_k]<mid]
        right_data = [x for x in data if x[cur_k]>mid]
        #只是获得切分点,任选一个即可
        root.data = [x for x in data if x[cur_k]==mid]
        root.left = self._maketree(root.left,T(left_data),dep+1)
        root.right = self._maketree(root.right,T(right_data),dep+1)
        #父结点构建,后续没用到
        if root.left:
            root.left.parent = root
        if root.right:
            root.right.parent = root  
        #
        return root
    ########################################
    ###############搜索最近邻###############
    def search(self,x):
        #向量维数保持一致
        if not x or len(x)!=self.k:
            return None
        ########################################
        ###########找到包含x的叶节点############
        nearest_Node = self.root
        #栈模拟递归,tuple存的是(当前结点,父结点,深度)
        Node_Stack = [(self.root,None,0)]
        nearest_k = 0
        visit = set()
        while nearest_Node:
            #此处往下递归入栈的时候就可以先找一个最近的距离出来:
            #result = disfind(x,nearest_Node)
            #min_dis,nearest_res = result[0],result[1]
            #入栈:
            cur_parent = nearest_Node
            if x[nearest_k]<nearest_Node.val:
                nearest_Node = nearest_Node.left
            elif x[nearest_k]>=nearest_Node.val:
                nearest_Node = nearest_Node.right
            nearest_k = (nearest_k+1)%self.k
            Node_Stack.append((nearest_Node,cur_parent,nearest_k))
        ########################################
        ##########回退每个结点进行操作##########
        #a)如果该结点实例点比最近点更新,该实例点为最近
        #b)检查该最近结点的子节点的父结点的另一个子节点区域内是否有更新的点(判交+递归)
        #相交,则移动到另一个子节点,接着递归进行最近邻搜索
        #存在数据里面的nearest_Node.data都是一大堆数据,选出最近的一个行向量
        result = disfind(x,nearest_Node)
        min_dis,nearest_res = result[0],result[1]
        #当前最近点:cur_Node
        while Node_Stack:
            cur = Node_Stack.pop()
            cur_Node,cur_parent,cur_k =cur[0],cur[1],cur[2]
            #没明白的点,应该是我有理解错的地方?找相邻子结点的时候由于与父结点相交,会发生无限递归。
            #相交,则移动到另一个子结点,递归完回退的时候,假如最短距离不变,检查发现又相交,再次移动到另一个子结点
            #同一个父结点的两个子结点互相交换,发生无限递归。。。
            #走过的路不再走,原文没有这个要求,但不加上跑不完。。。
            if cur_Node in visit:
                continue
            if cur_Node:
                visit.add(cur_Node)
            #更新最短距离
            cur_result = disfind(x,cur_Node)
            cur_d,cur_res = cur_result[0],cur_result[1]
            if cur_d < min_dis:
                min_dis = cur_d
                nearest_Node = cur_Node
                nearest_res = cur_res
            #跳出判断
            if cur_Node is self.root:
                break
            ########################################
            ##################判交##################
            #检查另一子节点对应的区域是否与目标点为求点,以目标点到当前最近点的距离为半径的球体相交
            #由于数据切分保存在父结点,父结点切分点左右是子结点,所以判断交在父结点判断
            if cur_parent.left is cur_Node:
                next_Node = cur_parent.right
            else:
                next_Node = cur_parent.left
            #判断相交 索引-1是父结点的分割点,因为父结点分割线把子结点分成左右,0变-1一样是最后一个,所以不用+模取模
            if next_Node:
                #(cur_k-1+k)%k才是正确的写法,但python'-1'正好是最后一个元素,所以可以直接-1使用
                #父结点分界线与圆相交
                if abs(x[cur_k-1] - cur_parent.val)<=min_dis:               
                #暴力if True 必过,复杂度O(n)
                #if True:
                #相交,递归
                    if next_Node not in visit:
                        Node_Stack.append((next_Node,cur_parent,cur_k))
                    while next_Node:
                        #递归
                        cur_parent = next_Node
                        if x[cur_k]<next_Node.val:
                            next_Node = next_Node.left
                        elif x[cur_k]>=next_Node.val:
                            next_Node = next_Node.right
                        cur_k = (cur_k + 1)%self.k
                        if next_Node not in visit:
                            Node_Stack.append((next_Node,cur_parent,cur_k))
        print('实际运算次数:',len(visit))
        return nearest_res
#最近的m个距离怎么写的?因为除了最近邻还有多近邻啊。本质没有任何区别,单点是更替最近点,多点更新最近的第m个点就行了。
#用大顶堆,更新距离:假如新距离比顶小,把新点覆盖顶后更新就好了。最后输出结果排个序就可以了。
#这样做就是m近邻。在原本最近的距离更新这儿改一下和判交的时候用最远的第m个点就好了。可以把前m个距离初始化无穷大。

#正确性检验
old_data = []
data_num = 1000
test_num = 100
feature = 2
#推荐feature改成2、10、20,有惊喜
print('原本运算次数:',data_num)
import random
for _ in range(data_num): 
    old_data.append([random.randint(-5000,5000) for _ in range(feature)])
data = T(old_data)
#暴力res1,kdtree res2
res1 = []
res2 = []
kd = kdtree(data)
test_data = []
for _ in range(test_num):
    test_data.append([random.randint(-5000,5000) for _ in range(feature)])
    min_dis = MAX_
    min_index = 0
    for i in range(len(old_data)):
        cur_dis = dis(old_data[i],test_data[-1])
        if cur_dis < min_dis:
            min_dis = cur_dis
            min_index = i
    res1.append(old_data[min_index])
    #print(dis(res1[-1],test_data[-1]),dis(res2[-1],test_data[-1]))
res = 0
for i in range(test_num):
    res2.append(kd.search(test_data[i]))
    res+=abs(dis(res1[i],test_data[i])-dis(res2[i],test_data[i]))
print('error:',res)


#网易##秋招##算法工程师#
全部评论
tql!
点赞 回复
分享
发布于 2019-10-27 21:36
原来是答案是这个... 求大佬告知小蓝书是啥
点赞 回复
分享
发布于 2019-10-27 21:45
阅文集团
校招火热招聘中
官网直投
太秀了
点赞 回复
分享
发布于 2019-10-27 21:49

相关推荐

4 25 评论
分享
牛客网
牛客企业服务