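I'm trying to build a nearest-neighbor graph, i.e. a scatter plot in which every data point is connected to its k nearest neighbors. My current solution works, but it is clearly not efficient. Here is what I have so far: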

import numpy as np
from scipy.spatial.distance import pdist, squareform
from matplotlib import pyplot as plt

N = 250
X = np.random.random((N, 2))
k = 4

# matrix of pairwise Euclidean distances
distmat = squareform(pdist(X, 'euclidean'))

# select the kNN for each datapoint
# (column 0 of argsort is the point itself at distance 0, so skip it)
neighbors = np.sort(np.argsort(distmat, axis=1)[:, 1:k+1])

plt.figure(figsize = (8, 8))
plt.scatter(X[:,0], X[:,1], c = 'black')
for i in np.arange(N):
    for j in np.arange(k):
        x1 = np.array([X[i,:][0], X[neighbors[i, j], :][0]])
        x2 = np.array([X[i,:][1], X[neighbors[i, j], :][1]])
        plt.plot(x1, x2, color = 'black')
plt.show()



Is there a more efficient way to build this plot?

Best answer

Use a LineCollection to draw all the edges at once instead of plotting each one separately:

import numpy as np
from scipy.spatial.distance import pdist, squareform
from matplotlib import pyplot as plt
from matplotlib.collections import LineCollection

N = 250
X = np.random.rand(N, 2)
k = 4

# matrix of pairwise Euclidean distances
distmat = squareform(pdist(X, 'euclidean'))

# select the kNN for each datapoint
# (column 0 of argsort is the point itself at distance 0, so skip it)
neighbors = np.sort(np.argsort(distmat, axis=1)[:, 1:k+1])

# get edge coordinates: coordinates[i, j, p, c] is coordinate c (x or y)
# of endpoint p of the edge from point i to its j-th neighbor
coordinates = np.zeros((N, k, 2, 2))
for i in np.arange(N):
    for j in np.arange(k):
        coordinates[i, j, :, 0] = np.array([X[i,:][0], X[neighbors[i, j], :][0]])
        coordinates[i, j, :, 1] = np.array([X[i,:][1], X[neighbors[i, j], :][1]])

# create line artists
lines = LineCollection(coordinates.reshape((N*k, 2, 2)), color='black')

fig, ax = plt.subplots(1,1,figsize = (8, 8))
ax.scatter(X[:,0], X[:,1], c = 'black')
# add_collection (rather than add_artist) registers the lines as a collection
# and includes them in the axes data limits
ax.add_collection(lines)
plt.show()
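The Python loop that fills coordinates, and the dense N x N distance matrix, can also be avoided. Below is a minimal sketch under a few assumptions: it swaps in SciPy's cKDTree for the neighbor search (this is not part of the answer above), it assumes X contains no duplicate points, and knn_indices and knn_segments are hypothetical helper names.

```python
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.collections import LineCollection
from scipy.spatial import cKDTree


def knn_indices(X, k):
    """Indices of each point's k nearest neighbors, shape (N, k).

    Queries k + 1 neighbors because the nearest hit for every point is the
    point itself (distance 0); that first column is dropped. Assumes there
    are no duplicate points in X.
    """
    _, idx = cKDTree(X).query(X, k=k + 1)
    return idx[:, 1:]


def knn_segments(X, neighbors):
    """(N*k, 2, 2) array of segments, one per (point, neighbor) pair."""
    N, k = neighbors.shape
    # Start of each edge: row i of X repeated k times -> (N, k, 2)
    starts = np.broadcast_to(X[:, None, :], (N, k, 2))
    # End of each edge: fancy indexing gathers neighbor coordinates -> (N, k, 2)
    ends = X[neighbors]
    # New axis 2 holds the two endpoints -> (N, k, 2, 2), then flatten edges
    return np.stack([starts, ends], axis=2).reshape(N * k, 2, 2)


if __name__ == "__main__":
    N, k = 250, 4
    X = np.random.rand(N, 2)
    segments = knn_segments(X, knn_indices(X, k))

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.scatter(X[:, 0], X[:, 1], c="black")
    ax.add_collection(LineCollection(segments, colors="black"))
    plt.show()
```

The segment layout matches the loop above: segment i*k + j runs from X[i] to X[neighbors[i, j]]. As in the original, an edge between two mutual neighbors is still drawn twice.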


On my machine your code takes about 1 second to run; my version takes 65 ms.

A similar question, "python - Efficient way to connect the k nearest neighbors in a scatter plot using matplotlib", can be found on Stack Overflow: https://stackoverflow.com/questions/50040310/

10-15 18:50