# ModelNet40 official train/test split.
# (Both dataset locations used below are ModelNet40 directories.)
if FLAGS.normal:
    # Branch used when normals are requested: samples are read through
    # ModelNetDataset from the "normal_resampled" directory.
    # The requested point count is capped at 10000 for this dataset.
    assert NUM_POINT <= 10000
    DATA_PATH = os.path.join(ROOT_DIR, 'data/modelnet40_normal_resampled')
    TRAIN_DATASET = modelnet_dataset.ModelNetDataset(
        root=DATA_PATH,
        npoints=NUM_POINT,
        split='train',
        normal_channel=FLAGS.normal,
        batch_size=BATCH_SIZE)
    TEST_DATASET = modelnet_dataset.ModelNetDataset(
        root=DATA_PATH,
        npoints=NUM_POINT,
        split='test',
        normal_channel=FLAGS.normal,
        batch_size=BATCH_SIZE)
else:
    # Branch without normals: samples are read through ModelNetH5Dataset,
    # driven by text files that list the .h5 files to load.
    # The cap is 2048 points here -- presumably the number of points stored
    # per shape in the HDF5 files (per the directory name); TODO confirm.
    assert NUM_POINT <= 2048
    # Only the training set is built with shuffle=True; the test set is
    # built with shuffle=False.
    TRAIN_DATASET = modelnet_h5_dataset.ModelNetH5Dataset(
        os.path.join(BASE_DIR, 'data/modelnet40_ply_hdf5_2048/train_files.txt'),
        batch_size=BATCH_SIZE,
        npoints=NUM_POINT,
        shuffle=True)
    TEST_DATASET = modelnet_h5_dataset.ModelNetH5Dataset(
        os.path.join(BASE_DIR, 'data/modelnet40_ply_hdf5_2048/test_files.txt'),
        batch_size=BATCH_SIZE,
        npoints=NUM_POINT,
        shuffle=False)
不使用法向量(即 FLAGS.normal 为 False,走 else 分支)时,训练数据(TRAIN_DATASET)是5个.h5格式的文件:
data/modelnet40_ply_hdf5_2048/ply_data_train0.h5
data/modelnet40_ply_hdf5_2048/ply_data_train1.h5
data/modelnet40_ply_hdf5_2048/ply_data_train2.h5
data/modelnet40_ply_hdf5_2048/ply_data_train3.h5
data/modelnet40_ply_hdf5_2048/ply_data_train4.h5
训练之前把5个训练文件的顺序打乱(训练集以 shuffle=True 构造;测试集以 shuffle=False 构造,文件顺序保持不变):
if self.shuffle: np.random.shuffle(self.file_idxs)
测试数据(TEST_DATASET)是2个.h5格式的文件:
data/modelnet40_ply_hdf5_2048/ply_data_test0.h5
data/modelnet40_ply_hdf5_2048/ply_data_test1.h5