问题描述
我是机器学习和学习的新手.深度学习.我想澄清与train_test_split
有关的疑问,然后再训练
I am new to Machine Learning & Deep Learning. I would like to clarify my doubt related to train_test_split
before training
我有一个大小为(302, 100, 5)
的数据集,其中
I have a data set of size (302, 100, 5)
, where,
(207,100,5)
属于class 0
(95,100,5)
属于class 1.
我想使用LSTM(自序列数据)进行分类
I would like to perform Classification using LSTM (since, sequence Data)
我如何拆分我的数据集进行培训,因为这些课程确实可以没有相等的分布集?
How can i split my data set for training, since the classes donot have equal distribution sets ?
选项1 :考虑整个数据[(302,100, 5) - both classes (0 & 1)]
,将其洗牌,train_test_split,进行培训.
Option 1 : Consider whole data [(302,100, 5) - both classes (0 & 1)]
, shuffle it, train_test_split,proceed training.
选项2::均分两个类数据集[(95,100,5) - class 0 & (95,100,5) - class 1]
,将其洗牌,train_test_split,继续进行培训.
Option 2 : Split both class data set equally[(95,100,5) - class 0 & (95,100,5) - class 1]
, shuffle it,train_test_split, proceed training.
在训练之前进行分割的更好方法是什么,这样就可以在减少损失,准确性,预测等方面获得更好的结果?
What will be the better way of splitting before training, so that i can get better results in terms of loss reduction, accuracy, prediction, ?
如果除了上述2个选项以外还有其他选项,请推荐
If there are other options rather than above 2 options, kindly recommend,
基于评论部分,我包括了部分数据:
X_train:形状(241 * 100 * 5)
每100 * 5中的每一行对应1个时间步长最后100行对应以毫秒(ms)为单位的100个时间步长
Each row in every 100*5 corresponds to 1 Time stepFinally 100 rows corresponds to 100 Time steps in milli seconds (ms)
array([[[0.98620635, 0. , 0.12752912, 0.60897341, 0.46903766],
[0.97345112, 0. , 0.12752912, 0.49205995, 0.38709902],
[0.9566397 , 0. , 0.12752912, 0.45728718, 0.42154812],
...,
[0.28669754, 0.8852459 , 0.12752912, 0.8786213 , 0.80125523],
[0.31559784, 0.8852459 , 0.20968731, 0.89087803, 0.79476987],
[0.34368841, 0.8852459 , 0.12752912, 0.89087803, 0.71066946]],
[[0.97957188, 0.14909194, 0.04159147, 0.50548561, 0.34209531],
[0.9687237 , 0.13964397, 0.04159147, 0.55926067, 0.64613533],
[0.96596236, 0.13553813, 0.04159147, 0.55903796, 0.85299319],
...,
[0.49309139, 0.72396527, 0.04159147, 0.81998825, 0.12362443],
[0.52072591, 0.70872926, 0.04159147, 0.82361951, 0.89639432],
[0.54441507, 0.71835207, 0.04159147, 0.84964602, 1. ]],
[[0.48151381, 0.875 , 0.16666667, 0.90637286, 0.62737926],
[0.53325374, 0.8625 , 0.33333333, 0.87881677, 0.5321154 ],
[0.57506452, 0.81859091, 0.16666667, 0.84915758, 0.3552661 ],
...,
[0.34456041, 0.92993213, 0.33333333, 0.92953899, 0.78782408],
[0.39496018, 0.90523485, 0.33333333, 0.9117954 , 0.54579383],
[0.44187985, 0.8625 , 0.33333333, 0.84163194, 0.25789356]],
...,
[[0.16368355, 0. , 0.15313225, 0.40101906, 0.36784741],
[0.15679684, 0. , 0.15313225, 0.4435126 , 0.67351994],
[0.15544309, 0.06132052, 0.15313225, 0.40101906, 0.36611345],
...,
[0.43936628, 0.68292683, 0.15313225, 0.82305329, 0.36784741],
[0.49751546, 0.68292683, 0.07764888, 0.84141109, 0.42828833],
[0.53288488, 0.68292683, 0.15313225, 0.85959823, 0.36784741]],
[[0.9418247 , 0.30821318, 0.03072816, 0.744977 , 0.93769733],
[0.9537216 , 0.28989357, 0.03072816, 0.74576381, 0.98468743],
[0.96455286, 0.21736423, 0.03072816, 0.74182977, 1. ],
...,
[0.36273884, 0.60113245, 0.06145633, 0.85409181, 0.32277415],
[0.38774614, 0.57789971, 0.05844559, 0.82937631, 0. ],
[0.41546859, 0.57789971, 0.03072816, 0.79315883, 0.31256578]],
[[0.97868688, 0.06451613, 0.00411829, 0.64705259, 0.69827586],
[0.97999663, 0.06451613, 0.02256676, 0.66812232, 0.75195925],
[0.97143037, 0.02476377, 0.02256676, 0.66317859, 0.78487461],
...,
[0.50336862, 0.73867709, 0.02256676, 0.84921606, 0.1226489 ],
[0.54003486, 0.72043011, 0.02256676, 0.82679269, 0.20297806],
[0.57594039, 0.70967742, 0.02256676, 0.83350205, 0. ]]])
Y_train:形状(241,)
[1. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 0. 0. 1. 1. 0. 0. 0.
1. 1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 1. 1.
0. 0. 1. 0. 0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 1. 1. 0. 0. 1. 0. 1. 0. 1. 0.
0. 1. 0. 0. 1. 1. 0. 0. 0. 1. 1. 0. 0. 1. 0. 0. 1. 0. 0. 1. 0. 0. 0. 1.
1. 0. 0. 1. 0. 1. 0. 0. 0. 1. 1. 0. 0. 0. 0. 1. 1. 1. 0. 0. 0. 0. 0. 1.
0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 1. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 1. 1. 0. 1. 1. 0. 0. 0. 0. 0. 1. 0.
0. 0. 0. 0. 0. 0. 0. 1. 0. 1. 0. 1. 0. 0. 0. 1. 1. 0. 0. 1. 1. 1. 0. 1.
0. 1. 0. 1. 0. 0. 0. 1. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.
1. 0. 0. 1. 1. 1. 0. 1. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 1.
1.]
供参考,正如您在上面看到的,X火车数据很大,我无法包含我的整个X_train数据的完整集合.因此,在这里我仅提供数据的一个细分,以便更好地了解1个细分的数据外观(i.e X_train[0] : shape- (100*5))
.其余的240
或多或少如下图
For reference, As you can see above, X-train data is large, I cannot include complete set of my entire X_train data. So I provide only one segment of my data here for better understanding of how my data looks like for 1 segment, (i.e X_train[0] : shape- (100*5))
. The remaining 240
will more or less looks like below
array([[9.86206354e-01, 0.00000000e+00, 1.27529123e-01, 2.29139335e-02,
6.08973407e-01, 4.69037657e-01],
[9.73451120e-01, 0.00000000e+00, 1.27529123e-01, 2.60807671e-02,
4.92059955e-01, 3.87099024e-01],
[9.56639704e-01, 0.00000000e+00, 1.27529123e-01, 2.64184174e-02,
4.57287179e-01, 4.21548117e-01],
[9.34897700e-01, 0.00000000e+00, 1.27529123e-01, 2.64184174e-02,
4.84177685e-01, 4.69037657e-01],
[9.18030989e-01, 0.00000000e+00, 1.27529123e-01, 2.64184174e-02,
4.86406180e-01, 4.08577406e-01],
[9.02168015e-01, 0.00000000e+00, 1.27529123e-01, 2.64020795e-02,
4.84920517e-01, 4.04184100e-01],
[8.82551572e-01, 0.00000000e+00, 1.27529123e-01, 2.56783096e-02,
4.51195959e-01, 3.78661088e-01],
[8.69975342e-01, 0.00000000e+00, 1.27529123e-01, 2.40477851e-02,
4.70286733e-01, 4.23640167e-01],
[8.41027241e-01, 0.00000000e+00, 1.27529123e-01, 1.75387576e-02,
5.04754123e-01, 4.34728033e-01],
[8.28189535e-01, 5.28763040e-01, 1.27529123e-01, 6.89133486e-03,
4.98662903e-01, 4.58368201e-01],
[8.21784739e-01, 8.21162444e-01, 1.27529123e-01, 1.06196483e-02,
5.87431288e-01, 5.72594142e-01],
[8.26651597e-01, 9.96721311e-01, 1.27529123e-01, 1.75044480e-02,
6.89050661e-01, 5.40376569e-01],
[8.42115326e-01, 1.00000000e+00, 1.27529123e-01, 1.71205069e-02,
8.35388501e-01, 4.69037657e-01],
[8.64071009e-01, 9.26875310e-01, 1.27529123e-01, 1.34068975e-02,
1.00000000e+00, 4.65062762e-01],
[8.79579724e-01, 7.60158967e-01, 1.27529123e-01, 4.65303975e-03,
9.61744169e-01, 3.65481172e-01],
[9.03630040e-01, 7.61549925e-01, 1.27529123e-01, 4.21518348e-03,
9.22076957e-01, 3.78033473e-01],
[9.18435858e-01, 6.72429210e-01, 1.27529123e-01, 2.70229205e-03,
9.39979201e-01, 5.03138075e-01],
[9.29983046e-01, 6.85345256e-01, 1.27529123e-01, 9.05120794e-04,
8.53736443e-01, 5.52510460e-01],
[9.48081232e-01, 5.78539493e-01, 1.27529123e-01, 6.96485550e-03,
8.84415391e-01, 3.04602510e-01],
[9.48112160e-01, 5.55091903e-01, 1.27529123e-01, 1.10493356e-02,
8.19046204e-01, 4.78661088e-01],
[9.61281634e-01, 5.08693492e-01, 1.27529123e-01, 9.36162843e-03,
8.23651761e-01, 3.21548117e-01],
[9.72179346e-01, 4.91803279e-01, 1.27529123e-01, 9.82725917e-03,
7.57391175e-01, 4.96025105e-01],
[9.84752763e-01, 4.91803279e-01, 1.27529123e-01, 7.04491131e-03,
7.59322538e-01, 3.95397490e-01],
[9.90300024e-01, 4.91803279e-01, 1.27529123e-01, 8.19346712e-03,
7.64819492e-01, 4.69037657e-01],
[9.88306609e-01, 3.77049180e-01, 1.27529123e-01, 8.62642201e-03,
7.93492795e-01, 4.16945607e-01],
[9.91084457e-01, 3.93442623e-01, 1.27529123e-01, 9.16557339e-03,
7.10741346e-01, 4.72175732e-01],
[1.00000000e+00, 3.78936910e-01, 1.27529123e-01, 1.16538387e-02,
6.93359085e-01, 4.76987448e-01],
[9.98925974e-01, 3.93442623e-01, 1.27529123e-01, 1.21309060e-02,
7.16609716e-01, 3.46025105e-01],
[9.92838888e-01, 3.32141083e-01, 1.27529123e-01, 1.19315833e-02,
7.31540633e-01, 4.16527197e-01],
[9.90637415e-01, 3.36910084e-01, 1.27529123e-01, 9.95632874e-03,
7.12524142e-01, 4.15481172e-01],
[9.90761125e-01, 3.38301043e-01, 1.27529123e-01, 6.59235091e-03,
6.86970732e-01, 4.37656904e-01],
[9.90274720e-01, 3.27868852e-01, 2.10913550e-01, 5.68396253e-03,
7.09181399e-01, 4.99372385e-01],
[9.83015202e-01, 3.27868852e-01, 1.27529123e-01, 2.14974358e-02,
7.31392067e-01, 6.41631799e-01],
[9.77392028e-01, 2.85245902e-01, 1.47762109e-01, 2.52861995e-02,
7.09478532e-01, 6.07112971e-01],
[9.75300207e-01, 2.78688525e-01, 1.27529123e-01, 2.91468501e-02,
6.70257020e-01, 6.28242678e-01],
[9.74917831e-01, 2.71733731e-01, 1.27529123e-01, 3.58780734e-02,
6.70257020e-01, 5.72594142e-01],
[9.64950755e-01, 2.62295082e-01, 1.27529123e-01, 3.92992339e-02,
6.36383895e-01, 6.67991632e-01],
[9.63159774e-01, 2.62295082e-01, 1.27529123e-01, 4.82932591e-02,
6.93581934e-01, 5.46443515e-01],
[9.54983679e-01, 2.90511674e-01, 1.27529123e-01, 4.90627752e-02,
6.59708810e-01, 7.40376569e-01],
[9.57595643e-01, 3.11475410e-01, 1.27529123e-01, 4.72492660e-02,
6.49977715e-01, 5.61297071e-01],
[9.51511369e-01, 2.95081967e-01, 1.27529123e-01, 1.82576261e-02,
6.64314366e-01, 5.22384937e-01],
[9.48528275e-01, 2.95081967e-01, 1.27529123e-01, 3.89659403e-03,
6.29846977e-01, 3.20711297e-01],
[9.47085931e-01, 2.95081967e-01, 1.27529123e-01, 6.86682798e-03,
6.48417769e-01, 4.38284519e-01],
[9.38153518e-01, 2.95081967e-01, 1.27529123e-01, 5.73951146e-03,
7.04130144e-01, 5.32635983e-01],
[9.38114156e-01, 2.95081967e-01, 1.27529123e-01, 2.05955826e-02,
6.85782202e-01, 5.47280335e-01],
[9.35597786e-01, 2.95081967e-01, 1.27529123e-01, 2.91141743e-02,
6.69142772e-01, 7.13807531e-01],
[9.29311077e-01, 2.72826627e-01, 1.27529123e-01, 2.91141743e-02,
6.81622344e-01, 5.72594142e-01],
[9.25495753e-01, 2.23646299e-01, 1.27529123e-01, 2.65507546e-02,
6.35566781e-01, 6.41004184e-01],
[9.18525829e-01, 2.08643815e-03, 1.27529123e-01, 2.37618715e-02,
6.09641955e-01, 5.02928870e-01],
[8.91801693e-01, 0.00000000e+00, 1.27529123e-01, 9.27013608e-03,
5.26073392e-01, 4.21338912e-01],
[8.77693149e-01, 0.00000000e+00, 1.27529123e-01, 8.13628440e-03,
4.22522656e-01, 3.44560669e-01],
[8.61894841e-01, 0.00000000e+00, 1.27529123e-01, 1.49639014e-02,
4.52755906e-01, 3.65481172e-01],
[8.44254943e-01, 0.00000000e+00, 1.27529123e-01, 2.29515107e-02,
4.59069975e-01, 3.76150628e-01],
[8.21183060e-01, 0.00000000e+00, 1.27529123e-01, 3.97583295e-02,
4.60852771e-01, 2.60460251e-01],
[8.04116726e-01, 0.00000000e+00, 1.27529123e-01, 5.89292454e-02,
4.26905363e-01, 1.97907950e-01],
[7.81311943e-01, 0.00000000e+00, 1.27529123e-01, 8.53656345e-02,
4.37379290e-01, 1.00836820e-01],
[7.60863270e-01, 0.00000000e+00, 1.27529123e-01, 1.03087377e-01,
4.37379290e-01, 6.98744770e-02],
[7.41227145e-01, 0.00000000e+00, 1.27529123e-01, 1.14206966e-01,
4.27128213e-01, 1.58368201e-01],
[7.26694052e-01, 0.00000000e+00, 1.27529123e-01, 1.17776801e-01,
4.37379290e-01, 0.00000000e+00],
[7.08716764e-01, 0.00000000e+00, 1.27529123e-01, 1.17288297e-01,
4.48596048e-01, 2.18619247e-01],
[6.90483621e-01, 0.00000000e+00, 1.27529123e-01, 1.08491961e-01,
4.58549993e-01, 1.26987448e-01],
[6.67451099e-01, 0.00000000e+00, 1.27529123e-01, 8.38217010e-02,
4.99628584e-01, 3.55020921e-01],
[6.51610618e-01, 0.00000000e+00, 1.27529123e-01, 4.32889541e-02,
5.10919626e-01, 4.83054393e-01],
[6.31195684e-01, 0.00000000e+00, 1.27529123e-01, 1.29200275e-02,
5.21170703e-01, 4.97907950e-01],
[6.14317726e-01, 0.00000000e+00, 2.26241570e-01, 9.32895259e-04,
4.98960036e-01, 4.69037657e-01],
[5.98165158e-01, 0.00000000e+00, 5.90435316e-01, 0.00000000e+00,
4.61892735e-01, 5.03556485e-01],
[5.68221755e-01, 0.00000000e+00, 6.33353771e-01, 1.61745413e-03,
4.25122567e-01, 4.69037657e-01],
[5.35292447e-01, 0.00000000e+00, 1.00000000e+00, 8.99402522e-03,
3.58490566e-01, 5.10041841e-01],
[5.10766973e-01, 0.00000000e+00, 3.93010423e-01, 3.39894098e-02,
3.27068786e-01, 6.15690377e-01],
[4.78939807e-01, 0.00000000e+00, 5.32188841e-01, 5.98114931e-02,
3.27068786e-01, 6.22175732e-01],
[4.47053597e-01, 0.00000000e+00, 4.31023912e-01, 8.44245703e-02,
3.24023176e-01, 6.76150628e-01],
[4.13654754e-01, 0.00000000e+00, 5.32188841e-01, 1.07209434e-01,
2.90298618e-01, 7.08577406e-01],
[3.80151882e-01, 0.00000000e+00, 7.97057020e-01, 1.21122807e-01,
1.19150201e-01, 4.95397490e-01],
[3.28235926e-01, 0.00000000e+00, 3.56223176e-01, 1.23820198e-01,
0.00000000e+00, 6.65271967e-01],
[2.83452966e-01, 0.00000000e+00, 2.28694053e-01, 1.22658572e-01,
2.65933739e-02, 5.55648536e-01],
[2.38616587e-01, 0.00000000e+00, 2.28694053e-01, 1.22990232e-01,
9.41910563e-02, 4.92887029e-01],
[1.82964031e-01, 0.00000000e+00, 5.19926426e-01, 1.30564491e-01,
8.97340663e-02, 4.94142259e-01],
[1.43835174e-01, 0.00000000e+00, 5.25444513e-01, 1.64135650e-01,
1.14618927e-01, 7.40585774e-01],
[1.04402664e-01, 0.00000000e+00, 1.55119559e-01, 2.41378071e-01,
1.98261774e-01, 6.50418410e-01],
[7.96438281e-02, 0.00000000e+00, 7.11220110e-02, 3.27145618e-01,
2.89110088e-01, 7.45188285e-01],
[6.36065353e-02, 0.00000000e+00, 0.00000000e+00, 4.11129065e-01,
4.05140395e-01, 6.88912134e-01],
[4.11672585e-02, 0.00000000e+00, 2.52605763e-01, 5.62182942e-01,
4.54315852e-01, 1.00000000e+00],
[2.87063044e-02, 0.00000000e+00, 1.27529123e-01, 6.81786323e-01,
4.59515674e-01, 9.32217573e-01],
[1.70269716e-02, 1.58966716e-03, 1.27529123e-01, 7.33474602e-01,
4.37453573e-01, 6.07322176e-01],
[3.30361486e-03, 6.37853949e-01, 1.27529123e-01, 8.06276376e-01,
4.69692468e-01, 7.54602510e-01],
[0.00000000e+00, 7.89369101e-01, 1.27529123e-01, 8.85843682e-01,
5.10919626e-01, 8.70502092e-01],
[5.13114648e-03, 8.19672131e-01, 1.27529123e-01, 9.60932765e-01,
5.99316595e-01, 8.79288703e-01],
[2.16829598e-02, 8.36065574e-01, 1.27529123e-01, 9.99121020e-01,
7.28866439e-01, 8.56903766e-01],
[4.27951674e-02, 8.36065574e-01, 1.27529123e-01, 1.00000000e+00,
8.67181697e-01, 7.88912134e-01],
[7.02334461e-02, 8.36065574e-01, 1.27529123e-01, 9.93500775e-01,
8.46308127e-01, 9.78451883e-01],
[9.73680733e-02, 8.36065574e-01, 1.27529123e-01, 9.87896869e-01,
8.66364582e-01, 8.59414226e-01],
[1.23611427e-01, 8.36065574e-01, 1.27529123e-01, 9.69613102e-01,
8.35685634e-01, 9.17991632e-01],
[1.52157471e-01, 8.68852459e-01, 1.27529123e-01, 9.22226597e-01,
7.96686971e-01, 9.65062762e-01],
[1.77979087e-01, 8.68852459e-01, 1.27529123e-01, 8.61132577e-01,
8.29594414e-01, 8.14225941e-01],
[2.03010647e-01, 8.84252360e-01, 1.27529123e-01, 8.13277174e-01,
8.29594414e-01, 9.11506276e-01],
[2.32490138e-01, 8.85245902e-01, 1.27529123e-01, 7.59549923e-01,
8.41851137e-01, 9.52301255e-01],
[2.58952796e-01, 8.85245902e-01, 1.27529123e-01, 6.97804020e-01,
8.55667806e-01, 8.68200837e-01],
[2.86697538e-01, 8.85245902e-01, 1.27529123e-01, 6.25149288e-01,
8.78621304e-01, 8.01255230e-01],
[3.15597842e-01, 8.85245902e-01, 2.09687308e-01, 5.51940700e-01,
8.90878027e-01, 7.94769874e-01],
[3.43688409e-01, 8.85245902e-01, 1.27529123e-01, 4.75801089e-01,
8.90878027e-01, 7.10669456e-01]])
推荐答案
TLDR:尝试两者!
在数据集不平衡之前,我曾遇到过类似情况.我使用了 train_test_split 或 KFold
TLDR: Try both!
I have been in similar situations before where my dataset was imbalanced. I used train_test_split or KFold to get through.
但是,一旦我偶然发现了处理不平衡数据集的问题,并碰到了过度平衡和欠平衡的技术.为此,我建议使用以下库: imblearn
However, once I stumbled upon the problem of handling imbalanced datasets and came across the techniques of overbalancing and underbalancing. To do this, I would recommend using the library: imblearn
您会在这里找到各种技巧来处理其中一个类别的人数超过另一个类别的情况.我个人曾经使用过 SMOTE SMOTE 很多,并且在这种情况下取得了相对较好的成功.
其他参考文献:
You will find various techniques there to handle the cases where one of your classes outnumbers the other one. I personally have used SMOTE a lot and have had relatively better success in such cases.
Other references:
https://www.analyticsvidhya.com/blog/2017/03/不平衡分类问题/
https://towardsdatascience.com/handling-imbalanced-datasets -in-machine-learning-7a0e84220f28
这篇关于拆分用于分类问题的数据集的正确程序是什么?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!