本文对habitat环境中的baseline点导航PPO的模型的构建和训练进行总结
0 训练代码
这个代码在上一篇文章出现过,再粘贴过来,如下:
import random
import numpy as np
from habitat_baselines.common.baseline_registry import baseline_registry
from habitat_baselines.config.default import get_config as get_baselines_config
import torch
if __name__ == "__main__":
run_type = "train"
config = get_baselines_config("../habitat_baselines/config/pointnav/ppo_pointnav_example.yaml")
config.defrost()
config.TASK_CONFIG.DATASET.DATA_PATH="/home/yons/LK/skill_transformer-main/data/datasets/pointnav/habitat-test-scenes/v1/{split}/{split}.json.gz"
config.TASK_CONFIG.DATASET.SCENES_DIR="/home/yons/LK/skill_transformer-main/data/scene_datasets"
config.freeze()
random.seed(config.TASK_CONFIG.SEED)
np.random.seed(config.TASK_CONFIG.SEED)
torch.manual_seed(config.TASK_CONFIG.SEED)
if config.FORCE_TORCH_SINGLE_THREADED and torch.cuda.is_available():
torch.set_num_threads(1)
trainer_init = baseline_registry.get_trainer(config.TRAINER_NAME)
print('trainer_init:',trainer_init)
assert trainer_init is not None, f"{
config.TRAINER_NAME} is not supported"
trainer = trainer_init(config)
if run_type == "train":
trainer.train()
elif run_type == "eval":
trainer.eval()
1 trainer
在配置文件中要配置TRAINER_NAME,上面代码中的config.TRAINER_NAME是ppo,然后通过
trainer = trainer_init(config)
这句话,trainer就成了一个rl.ppo.ppo_trainer.PPOTrainer
对象。在......./habitat_baselines/rl/ppo/ppo_trainer.py
中定义
2 训练过程
我们看到训练过程调用了rl.ppo.ppo_trainer.PPOTrainer
中的train()
方法
trainer.train()
TODO:对train()
方法进行分析
3 模型结构定义
PPO是actor_critic结构,需要两个网络一个actor网络,一个critic网络。这两个网络可以共享参数也可以不共享参数。habitat中的ppo在特征提取阶段采用了参数共享,然后分出了两个头。
@baseline_registry.register_trainer(name="ddppo")
@baseline_registry.register_trainer(name="ppo")
class PPOTrainer(BaseRLTrainer):
r"""Trainer class for PPO algorithm
Paper: https://arxiv.org/abs/1707.06347.
"""
supported_tasks = ["Nav-v0"]
SHORT_ROLLOUT_THRESHOLD: float = 0.25
_is_distributed: bool
envs: VectorEnv
agent: PPO
actor_critic: NetPolicy
def __init__(self, config=None):
super().__init__(config)
self.actor_critic = None
self.agent = None
self.envs = None
self.obs_transforms = []
self._static_encoder = False
self._encoder = None
self._obs_space = None
# Distributed if the world size would be
# greater than 1
self._is