我有2个Numpy数组,第1个有210行,第2个有30行,都包含4列,我想在两个仅包含0或1的数组的第4列上应用条件/过滤器。
因此,我想将第1个数组的0检测为Train_Safe,将第1个数组的1检测为Train_Cracked,将第2个数组的0检测为Test_Safe,将第2个数组的1检测为Test_Cracked,并使用Matplotlib在3D散点图上绘制这些值,我尝试使用此代码:

    for i in X_train_merge[0:, 3]:
    if i == 0:
        x_vals_train_0 = X_train_merge[0:, 0:1]
        y_vals_train_0 = X_train_merge[0:, 1:2]
        z_vals_train_0 = X_train_merge[0:, 2:3]
    elif i == 1:
        x_vals_train_1 = X_train_merge[0:, 0:1]
        y_vals_train_1 = X_train_merge[0:, 1:2]
        z_vals_train_1 = X_train_merge[0:, 2:3]
for j in X_test_merge[0:, 3]:
    if j == 0:
        x_vals_test_0 = X_test_merge[0:, 0:1]
        y_vals_test_0 = X_test_merge[0:, 1:2]
        z_vals_test_0 = X_test_merge[0:, 2:3]
    elif j == 1:
        x_vals_test_1 = X_test_merge[0:, 0:1]
        y_vals_test_1 = X_test_merge[0:, 1:2]
        z_vals_test_1 = X_test_merge[0:, 2:3]

ax.scatter(x_vals_train_0, y_vals_train_0, z_vals_train_0, c='g', marker='o', label="Train_Safe")
ax.scatter(x_vals_train_1, y_vals_train_1, z_vals_train_1, c='b', marker='o', label="Train_Cracked")
ax.scatter(x_vals_test_0, y_vals_test_0, z_vals_test_0, c='black', marker='*', label="Test_Safe")
ax.scatter(x_vals_test_1, y_vals_test_1, z_vals_test_1, c='brown', marker='*', label="Test_Cracked")


它可以绘制/给出所有数据点,而不会将其分散/划分为Train_Safe,Train_Cracked,Test_Safe和Test_Cracked。有关此任务的任何建议/解决方案。提前致谢。

最佳答案

有礼貌地提供玩具数据

import numpy as np
a = np.random.rand(10, 4)

a[:, 3] = a[:, 3] > 0.5

a

np.array([[ 0.93011873,  0.80167023,  0.46502502,  0.        ],
       [ 0.48754049,  0.331763  ,  0.19391945,  1.        ],
       [ 0.17976529,  0.73625689,  0.6550934 ,  0.        ],
       [ 0.17797159,  0.89597292,  0.67507392,  1.        ],
       [ 0.89972382,  0.86131195,  0.85239512,  1.        ],
       [ 0.59199271,  0.14223656,  0.12101887,  1.        ],
       [ 0.71962168,  0.89132196,  0.61149278,  0.        ],
       [ 0.63606024,  0.04821054,  0.49971309,  1.        ],
       [ 0.18976505,  0.49880633,  0.93362872,  1.        ],
       [ 0.00154421,  0.79748799,  0.46080879,  0.        ]])


那么np.where是工具:

ts = a[np.where(a[:, -1] == 0), :-1].T

tc = a[np.where(a[:, -1] == 1), :-1].T

from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

ax.scatter(*ts, c='g', marker='o', label="Train_Safe")
ax.scatter(*tc, c='b', marker='o', label="Train_Cracked")
fig.show()


python - 在Numpy阵列的一列上应用条件/过滤器-LMLPHP

关于python - 在Numpy阵列的一列上应用条件/过滤器,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/47966665/

10-12 18:58