K近邻法与kd树

一、k近邻算法

k近邻算法没有显式的学习过程。

二、k近邻模型

k近邻模型由三个基本要素组成——距离度量,k值选择,分类决策规则

2.1 距离度量

闵可夫斯基距离公式:
\[L_p(x_i,x_j) = (\sum_{l=1}^{n}|x_i^{(l)}-x_j^{(l)}|^p)^{\frac{1}{p}}\]
当p = 2 时,公式称为欧式距离。

当p = 1时,公式称为曼哈顿距离。

当p = \(\infin\)时,它时各个坐标的最大值,即:\(L_{\infin}(x_i,x_j) = \mathop{max}_l|x_i^{(l)} - x_j^{(l)}|\)

2.2 k值的选择

k值选择越小,对噪声越敏感,越容易造成过拟合

2.3 分类决策规则

一般采用多数表决规则(majority voting rule)

三、k近邻法的实现:kd树

但我们使用k近邻算法时,其中最耗时的就是对k近邻的搜索,当我们需要比较每一个训练数据点与所求样本之间的距离,如果是线性搜索,那么当特征空间的维数和训练数据量非常大时,这就变成了k近邻法的鸡肋了。

为了提高k近邻搜索的效率,我么可以考虑使用特殊的结构存储训练数据,以减少计算距离的次数。这里我们就要用到kd树的方法。(这里的k指的是k维空间)

3.1 构造kd树

kd树是一种对k维空间中的实例点进行存储以便对其进行快速检索的树形数据结构,kd树是二叉树,表示对k维空间的一个划分。构造kd树相当于不断的用垂直于坐标轴的超平面将k维空间切分,构成一系列的k维超矩形空间。

构造kd树的算法如下:

代码(简化代码实现,只实现二维的)

先对kd树和结点进行定义

class KDT(object):
    def __init__(self):
        self.head = None


class KDTNode(object):
    def __init__(self, parent, value, axis):
        self.left = None
        self.right = None
        self.parent = parent
        self.value = value
        self.axis = axis
        self.isVisit = False  #搜索时记录是否被访问过

下面是kd树的构造过程和其用到的工具方法:

def handleData(data, idx):
    '''
    将数据根据idx维度的中位数进行划分
    :param data: 要划分的数据
    :param idx:  在数据的哪个维度上进行划分
    :return: 作为划分结点的med ,左子树数据left,右子树数据right
    '''
    tmp = data[:, idx]
    medI = np.median(tmp)
    no = (np.abs(tmp - medI)).argmin()
    medI = data[no, idx]
    med = data[no]
    data = np.delete(data, no, axis=0)
    left = data[data[:, idx] < medI]
    right = data[data[:, idx] >= medI]
    return med, left, right

def builtNode(data, parent, idx):
    '''
    构造kd树的结点
    :param data: 输入的数据
    :param parent: 父节点
    :param idx: 在idx维度上进行划分
    :return:
    '''
    # 将数据分为idx维度上的中间节点,左部分,右部分
    med, left, right = handleData(data, idx)
    node = KDTNode(parent=parent, value=med, axis=idx)
    if (left.size != 0):
        node.left = builtNode(left, node, 1 - idx)
    if (right.size != 0):
        node.right = builtNode(right, node, 1 - idx)
    return node

def builtKDT(data):
    '''
    构造kd树
    :param data: 输入的数据
    :return:
    '''
    kdt = KDT()
    kdt.root = builtNode(data, None, 0)
    return kdt

def printKDT(kdt):
    '''
    二叉树的先序输出
    :param kdt:
    :return:
    '''
    printNode(kdt.root)

def printNode(node):
    print('结点数据:{},划分维度:{}'.format(node.value, node.axis))
    if node.left != None:
        printNode(node.left)
    if node.right != None:
        printNode(node.right)
data = np.array([[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]])
kdt = builtKDT(data2)
printKDT(kdt)

上述代码的结果截图如下:

3.2 搜索kd树

下面介绍如何利用kd树进行最近邻搜索(k=1):给定一个 目标点,首先找到代表目标点区域的叶节点,然后从该叶节点出发,依次回退到父节点,不断查找与目标节点最邻近的节点,当确定不可能存在更近的节点时终止。

def findleaf(root, x):
    '''
    根据kd树,找到x对应的叶节点
    :param root: kd树的根节点
    :param x: 目标节点
    :return:  叶节点
    '''
    axis = root.axis
    leaf = root
    if (x[axis] < root.value[axis]) & (root.left != None):
        leaf = findleaf(root.left, x)
    if (x[axis] >= root.value[axis]) & (root.right != None):
        leaf = findleaf(root.right, x)
    return leaf


def nearstNeighborSearch(node, x, curDis, cur, searchSibling):
    '''
    递归的查找最近邻点
    :param node:
    :param x:
    :param curDis:
    :param cur:
    :param searchSibling:
    :return:
    '''
    if node.isVisit:
        return curDis, cur
    node.isVisit = True
    dis = np.sqrt(np.sum((node.value - x) ** 2))
    if dis < curDis:
        curDis = dis
        cur = node
    if (searchSibling == 1) & (node.left != None):
        tmpDis, tmp = nearstNeighbor(node.left, x, curDis=curDis, cur=cur)
        if tmpDis < curDis:
            curDis = tmpDis
            cur = tmp

    if (searchSibling == 2) & (node.right != None):
        tmpDis, tmp = nearstNeighbor(node.right, x, curDis=curDis, cur=cur)
        if tmpDis < curDis:
            curDis = tmpDis
            cur = tmp
    if node.parent != None:
        axis = node.parent.axis
        tmpDis, tmp = 0, None
        if (x[axis] < node.parent.value[axis]) & (node.parent.value[axis] - x[axis] < curDis):
            tmpDis, tmp = nearstNeighborSearch(node.parent, x, curDis, cur, 2)
        elif (x[axis] >= node.parent.value[axis]) & (x[axis] - node.parent.value[axis] < curDis):
            tmpDis, tmp = nearstNeighborSearch(node.parent, x, curDis, cur, 1)
        else:
            tmpDis, tmp = nearstNeighborSearch(node.parent, x, curDis, cur, 0)
        if tmpDis < curDis:
            curDis = tmpDis
            cur = tmp
    return curDis, cur


def nearstNeighbor(root, x, curDis=np.inf, cur=None):
    leaf = findleaf(root, x)
    curDis, cur = nearstNeighborSearch(leaf, x, curDis, cur, 0)
    return curDis, cur

data = np.array([[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]])
kdt = builtKDT(data)
printKDT(kdt)
x = np.array([4, 5])
nearestDis, nearestNode = nearstNeighbor(kdt.root, x, curDis=np.inf, cur=None)
print(nearestDis, nearestNode.value) #输出最短距离 最近邻点

运行的结果如下:

如果实例点是随机分布的,kd树搜索的平均计算复杂度是\(O(logN)\).

12-24 17:34