generate_occupancy_data 函数

def generate_occupancy_data(nusc: NuScenes, cur_sample, num_sweeps, save_path='./occupacy/', gt_from: str = 'lidarseg'):
    pcs = []  # 用于存储关键帧的点云数据
    pc_segs = []  # 用于存储关键帧的标签

    intermediate_pcs = []  # 用于存储非关键帧的点云数据
    intermediate_labels = []  # 用于存储非关键帧的标签

    # 获取当前帧的LIDAR数据
    lidar_data = nusc.get('sample_data', cur_sample['data']['LIDAR_TOP'])
    pc = LidarPointCloud.from_file(nusc.dataroot + lidar_data['filename'])
    filename = os.path.split(lidar_data['filename'])[-1]
    lidar_sd_token = cur_sample['data']['LIDAR_TOP']

    # 获取当前帧的点云标签文件路径并读取标签
    lidarseg_labels_filename = os.path.join(nusc.dataroot, nusc.get(gt_from, lidar_sd_token)['filename'])
    lidar_seg = load_bin_file(lidarseg_labels_filename, type=gt_from)

    # 对齐关键帧
    count_prev_frame = 0
    prev_frame = cur_sample.copy()

    # 查找前 `num_sweeps` 帧
    while num_sweeps > 0:
        if prev_frame['prev'] == '':
            break
        prev_frame = nusc.get('sample', prev_frame['prev'])
        count_prev_frame += 1
        if count_prev_frame == num_sweeps:
            break

    # 获取当前帧的信息
    cur_sample_info = get_frame_info(cur_sample, nusc=nusc)

    # 将前 `num_sweeps` 帧的关键帧对齐到当前帧
    if count_prev_frame > 0:
        prev_info = get_frame_info(prev_frame, nusc)
    pc_points = None
    pc_seg = None
    while count_prev_frame > 0:
        income_info = get_frame_info(frame=prev_frame, nusc=nusc)
        prev_frame = nusc.get('sample', prev_frame['next'])
        prev_info = income_info
        pc_points, pc_seg = keyframe_align(prev_info, cur_sample_info)
        pcs.append(pc_points)
        pc_segs.append(pc_seg)
        count_prev_frame -= 1

    # 将后 `num_sweeps` 帧的关键帧对齐到当前帧
    next_frame = cur_sample.copy()
    pc_points = None
    pc_seg = None
    count_next_frame = 0
    while num_sweeps > 0:
        if next_frame['next'] == '':
            break
        next_frame = nusc.get('sample', next_frame['next'])
        count_next_frame += 1
        if count_next_frame == num_sweeps:
            break

    if count_next_frame > 0:
        prev_info = get_frame_info(next_frame, nusc=nusc)

    while count_next_frame > 0:
        income_info = get_frame_info(frame=next_frame, nusc=nusc)
        prev_info = income_info
        next_frame = nusc.get('sample', next_frame['prev'])
        pc_points, pc_seg = keyframe_align(prev_info, cur_sample_info)
        pcs.append(pc_points)
        pc_segs.append(pc_seg)
        count_next_frame -= 1

    # 合并所有关键帧的点云数据和标签
    pcs = np.concatenate(pcs, axis=-1)
    pc_segs = np.concatenate(pc_segs)

    # 将合并后的关键帧点云数据和标签添加到当前帧的点云数据和标签中
    pc.points = np.concatenate((pc.points, pcs), axis=-1)
    lidar_seg = np.concatenate((lidar_seg, pc_segs))

    # 应用范围过滤器,移除超出指定范围的点
    range_mask = (pc.points[0, :] <= 60) & (pc.points[0, :] >= -60) \
                 & (pc.points[1, :] <= 60) 
06-19 15:34