This article summarizes how the baseline PointNav PPO model in the habitat environment is built and trained.

0 Training code

This code already appeared in the previous article; it is pasted here again:

import random
import numpy as np
from habitat_baselines.common.baseline_registry import baseline_registry
from habitat_baselines.config.default import get_config as get_baselines_config
import torch

if __name__ == "__main__":
    run_type = "train"
    config = get_baselines_config("../habitat_baselines/config/pointnav/ppo_pointnav_example.yaml")

    config.defrost()
    config.TASK_CONFIG.DATASET.DATA_PATH="/home/yons/LK/skill_transformer-main/data/datasets/pointnav/habitat-test-scenes/v1/{split}/{split}.json.gz"
    config.TASK_CONFIG.DATASET.SCENES_DIR="/home/yons/LK/skill_transformer-main/data/scene_datasets"
    config.freeze()
   
    random.seed(config.TASK_CONFIG.SEED)
    np.random.seed(config.TASK_CONFIG.SEED)
    torch.manual_seed(config.TASK_CONFIG.SEED)
    if config.FORCE_TORCH_SINGLE_THREADED and torch.cuda.is_available():
        torch.set_num_threads(1)

    trainer_init = baseline_registry.get_trainer(config.TRAINER_NAME)
    print('trainer_init:',trainer_init)
    assert trainer_init is not None, f"{config.TRAINER_NAME} is not supported"
    trainer = trainer_init(config)

    if run_type == "train":
        trainer.train()
    elif run_type == "eval":
        trainer.eval()

1 trainer

TRAINER_NAME has to be set in the configuration file; in the code above, config.TRAINER_NAME is ppo. Then, through the line

trainer = trainer_init(config)

trainer becomes an rl.ppo.ppo_trainer.PPOTrainer object, which is defined in ......./habitat_baselines/rl/ppo/ppo_trainer.py.
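For context, the lookup behind register_trainer/get_trainer is just a name-to-class registry: register_trainer stores a class under a string key, and get_trainer returns that class so the main script can instantiate whatever TRAINER_NAME selects. Below is a minimal sketch of that pattern; MiniRegistry and DummyTrainer are illustrative names, not the actual habitat_baselines.common.baseline_registry implementation.

# Minimal sketch of a name-to-class trainer registry (illustrative only,
# not the real habitat_baselines.common.baseline_registry code).
from typing import Dict, Optional, Type

class MiniRegistry:
    _trainers: Dict[str, Type] = {}

    @classmethod
    def register_trainer(cls, name: str):
        # Used as a decorator: stores the decorated class under `name`.
        def wrapper(trainer_cls: Type) -> Type:
            cls._trainers[name] = trainer_cls
            return trainer_cls
        return wrapper

    @classmethod
    def get_trainer(cls, name: str) -> Optional[Type]:
        return cls._trainers.get(name)

@MiniRegistry.register_trainer("ppo")
class DummyTrainer:
    def __init__(self, config):
        self.config = config

trainer_cls = MiniRegistry.get_trainer("ppo")  # -> DummyTrainer
trainer = trainer_cls(config={"TRAINER_NAME": "ppo"})

This is why the YAML only needs the string ppo: the decorator on PPOTrainer has already registered the class under that name before get_trainer is called.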

2 Training process

The training process calls the train() method of rl.ppo.ppo_trainer.PPOTrainer:

  trainer.train()

TODO: analyze the train() method
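Until that analysis is written up, the sketch below shows the core computation that each update inside train() revolves around: the PPO clipped surrogate loss. The function name, tensor shapes, and the clip_param value are illustrative assumptions for the example, not habitat's exact code.

import torch

def ppo_clipped_loss(new_log_probs, old_log_probs, advantages, clip_param=0.2):
    # Probability ratio between the current policy and the rollout-time policy.
    ratio = torch.exp(new_log_probs - old_log_probs)
    surr1 = ratio * advantages
    # Clipping keeps a single update from moving the policy too far.
    surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantages
    # PPO maximizes the minimum of the two surrogates; return a loss to minimize.
    return -torch.min(surr1, surr2).mean()

# Example with dummy values:
new_lp = torch.tensor([-1.0, -0.5, -2.0])
old_lp = torch.tensor([-1.2, -0.4, -2.1])
adv = torch.tensor([0.5, -0.3, 1.0])
print(ppo_clipped_loss(new_lp, old_lp, adv))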

3 Model structure definition

PPO has an actor-critic structure and therefore needs two networks: an actor network and a critic network. The two networks may or may not share parameters. habitat's PPO shares parameters in the feature-extraction stage and then branches into two heads, as the sketch below illustrates.
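Here is a minimal sketch of such a shared-encoder actor-critic policy, assuming a flat observation vector and a discrete action space; habitat's actual NetPolicy is more elaborate (visual encoder plus recurrent state encoder), so the class name and dimensions below are illustrative only.

import torch
import torch.nn as nn

class SharedActorCritic(nn.Module):
    def __init__(self, obs_dim: int, hidden_dim: int, num_actions: int):
        super().__init__()
        # Shared feature extractor: its parameters are used by both heads.
        self.encoder = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
        )
        # Actor head: action logits (the policy).
        self.actor = nn.Linear(hidden_dim, num_actions)
        # Critic head: scalar state-value estimate.
        self.critic = nn.Linear(hidden_dim, 1)

    def forward(self, obs):
        features = self.encoder(obs)
        return self.actor(features), self.critic(features)

# Example usage with a dummy observation batch:
model = SharedActorCritic(obs_dim=64, hidden_dim=128, num_actions=4)
logits, value = model(torch.randn(2, 64))
print(logits.shape, value.shape)  # torch.Size([2, 4]) torch.Size([2, 1])

The trainer class that builds and optimizes this kind of policy is PPOTrainer, defined as follows: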

@baseline_registry.register_trainer(name="ddppo")
@baseline_registry.register_trainer(name="ppo")
class PPOTrainer(BaseRLTrainer):
    r"""Trainer class for PPO algorithm
    Paper: https://arxiv.org/abs/1707.06347.
    """
    supported_tasks = ["Nav-v0"]

    SHORT_ROLLOUT_THRESHOLD: float = 0.25
    _is_distributed: bool
    envs: VectorEnv
    agent: PPO
    actor_critic: NetPolicy

    def __init__(self, config=None):
        super().__init__(config)
        self.actor_critic = None
        self.agent = None
        self.envs = None
        self.obs_transforms = []

        self._static_encoder = False
        self._encoder = None
        self._obs_space = None

        # Distributed if the world size would be
        # greater than 1
        self._is_distributed = get_distrib_size()[2] > 1