我有2个Numpy数组,第1个有210行,第2个有30行,都包含4列,我想在两个仅包含0或1的数组的第4列上应用条件/过滤器。
因此,我想将第1个数组的0检测为Train_Safe,将第1个数组的1检测为Train_Cracked,将第2个数组的0检测为Test_Safe,将第2个数组的1检测为Test_Cracked,并使用Matplotlib在3D散点图上绘制这些值,我尝试使用此代码:
for i in X_train_merge[0:, 3]:
if i == 0:
x_vals_train_0 = X_train_merge[0:, 0:1]
y_vals_train_0 = X_train_merge[0:, 1:2]
z_vals_train_0 = X_train_merge[0:, 2:3]
elif i == 1:
x_vals_train_1 = X_train_merge[0:, 0:1]
y_vals_train_1 = X_train_merge[0:, 1:2]
z_vals_train_1 = X_train_merge[0:, 2:3]
for j in X_test_merge[0:, 3]:
if j == 0:
x_vals_test_0 = X_test_merge[0:, 0:1]
y_vals_test_0 = X_test_merge[0:, 1:2]
z_vals_test_0 = X_test_merge[0:, 2:3]
elif j == 1:
x_vals_test_1 = X_test_merge[0:, 0:1]
y_vals_test_1 = X_test_merge[0:, 1:2]
z_vals_test_1 = X_test_merge[0:, 2:3]
ax.scatter(x_vals_train_0, y_vals_train_0, z_vals_train_0, c='g', marker='o', label="Train_Safe")
ax.scatter(x_vals_train_1, y_vals_train_1, z_vals_train_1, c='b', marker='o', label="Train_Cracked")
ax.scatter(x_vals_test_0, y_vals_test_0, z_vals_test_0, c='black', marker='*', label="Test_Safe")
ax.scatter(x_vals_test_1, y_vals_test_1, z_vals_test_1, c='brown', marker='*', label="Test_Cracked")
它可以绘制/给出所有数据点,而不会将其分散/划分为Train_Safe,Train_Cracked,Test_Safe和Test_Cracked。有关此任务的任何建议/解决方案。提前致谢。
最佳答案
有礼貌地提供玩具数据
import numpy as np
a = np.random.rand(10, 4)
a[:, 3] = a[:, 3] > 0.5
a
np.array([[ 0.93011873, 0.80167023, 0.46502502, 0. ],
[ 0.48754049, 0.331763 , 0.19391945, 1. ],
[ 0.17976529, 0.73625689, 0.6550934 , 0. ],
[ 0.17797159, 0.89597292, 0.67507392, 1. ],
[ 0.89972382, 0.86131195, 0.85239512, 1. ],
[ 0.59199271, 0.14223656, 0.12101887, 1. ],
[ 0.71962168, 0.89132196, 0.61149278, 0. ],
[ 0.63606024, 0.04821054, 0.49971309, 1. ],
[ 0.18976505, 0.49880633, 0.93362872, 1. ],
[ 0.00154421, 0.79748799, 0.46080879, 0. ]])
那么
np.where
是工具:ts = a[np.where(a[:, -1] == 0), :-1].T
tc = a[np.where(a[:, -1] == 1), :-1].T
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(*ts, c='g', marker='o', label="Train_Safe")
ax.scatter(*tc, c='b', marker='o', label="Train_Cracked")
fig.show()
关于python - 在Numpy阵列的一列上应用条件/过滤器,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/47966665/