[源码解析] TensorFlow 分布式之 ParameterServerStrategy V2

对于 ParameterServerStrategy V2,我们将从几个方面来研究:如何与集群建立连接,如何生成变量,如何获取数据,如何运行。其中,变量和作用域我们在前文已经研究过,运行在 MirroredStrategy 里面也介绍,所以本文主要看看如何使用,如何初始化。在下一篇之中会重点看看如何分发计算。

安利两个github,都是非常好的学习资料,推荐。

https://github.com/yuhuiaws/ML-study

https://github.com/Jack47/hack-SysML

另外推荐西门宇少的最新大作让Pipeline在Transformer LM上沿着Token level并行起来——TeraPipe

本系列其他文章是:

[翻译] TensorFlow 分布式之论文篇 "TensorFlow : Large-Scale Machine Learning on Heterogeneous Distributed Systems"

[翻译] TensorFlow 分布式之论文篇 "Implementation of Control Flow in TensorFlow"

[源码解析] TensorFlow 分布式环境(1) --- 总体架构

[源码解析] TensorFlow 分布式环境(2)---Master 静态逻辑

[源码解析] TensorFlow 分布式环境(3)--- Worker 静态逻辑

[源码解析] TensorFlow 分布式环境(4) --- WorkerCache

[源码解析] TensorFlow 分布式环境(5) --- Session

[源码解析] TensorFlow 分布式环境(7) --- Worker 动态逻辑

[源码解析] TensorFlow 分布式环境(8) --- 通信机制

[翻译] 使用 TensorFlow 进行分布式训练

[源码解析] TensorFlow 分布式 DistributedStrategy 之基础篇

[源码解析] TensorFlow 之 分布式变量

[源码解析] TensorFlow 分布式之 MirroredStrategy

[源码解析] TensorFlow 分布式之 MirroredStrategy 分发计算

[源码解析] TensorFlow 分布式之 ParameterServerStrategy V1

1. 如何使用

在 TensorFlow 2 中,参数服务器训练由 tf.distribution.experimental.ParameterServerStrategy 类提供支持,该类将训练步骤分布到一个可扩展到数千个工作者(伴随着参数服务器)的集群。

1.1 训练方法

支持训练有两种主要方法:

  • Keras Model.fit API。如果用户喜欢用高层次抽象来训练,则建议使用这种方式。
  • 自定义训练循环(custom training loop)。如果用户需要自己实现或者定义训练细节,则可以考虑这种方式。

1.2 集群

无论选择何种API( Model.fit 或自定义训练循环),TensorFlow 2中的分布式训练都会涉及如下概念:一个"集群" 有若干个"作业(job)",每个作业可能包括一个或多个"任务"。而当使用参数服务器训练时,建议使用如下配置:

  • 一个协调者(coordinator ) job(job名称为 chief)。
  • 多个工作者 jobs(job名称为 worker)。
  • 多个参数服务器 jobs(job名称为 ps)。

协调者负责创建资源、分配训练任务、写检查点和处理任务失败,工作者参数服务器则运行 tf.distribution.Server 来听取协调者的请求。

1.3 使用 Model.fit API 进行训练

如果使用 "Model.fit" API,则参数服务器训练需要协调者使用 tf.distribution.experimental.ParameterServerStrategy 对象和 tf.keras.utils.experimental.DatasetCreator 作为输入。与其他策略类似,其工作流程包括:创建和编译模型,准备回调,调用 Model.fit。

1.4 使用自定义循环进行训练

TensorFlow 2 推荐使用一种基于中央协调的架构来进行参数服务器训练。每个工作者和参数服务器都运行一个 tf.distribution.Server,在此基础上,一个协调者任务负责在工作者和参数服务器上创建资源,调度功能,并协调训练。协调器使用 tf.distribution.experimental.coordinator.ClusterCoordinator 来协调集群,使用 tf.distribution.experimental.ParameterServerStrategy 来定义参数服务器上的变量和工作者的计算。在自定义训练循环中, tf.distribution.experimental.coordinator.ClusterCoordinator 类是用于协调器的关键组件。

  • ClusterCoordinator 类需要与 tf.distribution.Strategy 对象一起工作。
  • 对于参数服务器训练, ClusterCoordinator 需要与 tf.distribution.experimental.ParameterServerStrategy 一起工作。
  • 这个 tf.distribution.Strategy 对象需要使用者提供集群的信息,并使用这些信息来定义训练步骤。然后, ClusterCoordinator 对象将这些训练步骤的执行分派给远程工作者。

ClusterCoordinator 提供的最重要的 API 是 schedule 。

  • Schedule API 把一个 tf.function 插入队列,并立即返回一个类似 future 的 RemoteValue 。
  • 在队列之中排队的函数被派发给后台线程中的远程工作者,他们的 RemoteValue 将被异步赋值。
  • 由于 schedule 不需要执行分配任务,因此传递进来的 tf.function 可以在任何可用的工作者上执行。
  • 如果被执行的工作者在结束之前变得不可用,该 tf.function 将在另一个可用的工作者上重试。
  • 由于函数的执行不是原子性的,所以一个函数可能被执行多次。

除了调度远程函数这个功能之外,ClusterCoordinator 还帮助在所有工作者上创建数据集,以及当一个工作者从失败中恢复时重建这些数据集。

1.5 建立集群

如上所述,一个参数服务器训练集群需要一个协调者任务来运行你的训练程序,程序包括一个或几个运行TensorFlow 服务器( tf.distribution.Server )的工作者和参数服务器,可能还有一个运行 side-car 评估的评估任务。设置它们的要求是。

  • 协调者(coordinator)任务需要知道所有其他 TensorFlow 服务器(评估者除外)的地址和端口。
  • 工作者和参数服务器需要知道他们应该监听哪个端口。为了简单起见,用户通常可以在这些任务上创建 TensorFlow 服务器时传入完整的集群信息。
  • 评估器(evaluator)任务不需要知道训练集群的设置,它也不应该试图连接到训练集群。
  • 工作者和参数服务器的任务类型应该分为 "worker" 和 "ps" 两种。出于历史原因,协调器应使用 "chief" 作为任务类型。

2. 初始化

2.1 用例

以下是如何初始化 ParameterServerStrategy 的样例,无论是使用 Model.fit 还是自定义循环,都需要这步工作。为了使用 GPU 进行训练,需要为每个工作者分配可见的 GPU。 ParameterServerStrategy 将使用每个工作者上所有可用的 GPU,但有个限制是:所有工作者都应该有相同数量的 GPU 可用。

variable_partitioner = (
    tf.distribute.experimental.partitioners.MinSizePartitioner(
        min_shard_bytes=(256 << 10),
        max_shards=NUM_PS))

strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver,
    variable_partitioner=variable_partitioner)

对于 variable_partitioner,这是一个 distribute.experimental.partitioners.Partitioner,其指定如何对变量进行分区。如果是 None,变量将不被分割,其特点如下:

  • 此参数取值是 tf.distribute.experimental.partitioners 中预定义的分区器。一个常用的分区器是 MinSizePartitioner(min_shard_bytes = 256 << 10, max_shards = num_ps),它为每个分片分配至少 256K,每个 ps 最多得到一个分片。
  • 在策略 scope 下创建的每个变量都会调用 variable_partitioner,以指示该变量应如何分区。沿着分区轴只有一个分区的变量(即不需要分区)将被创建为一个普通的 tf.Variable 。
  • 只支持第一个/最外层轴的分区。
  • Div 分区策略被用来对变量进行分区。假设我们沿着变量的第一轴分配连续的整数 id,那么 id 会以连续的方式分配给分片,同时试图保持每个分片的大小相同。如果 id 不能平均分配给分片的数量,那么前几个分片中的每一个将被多分配一个 id。例如,一个变量的第一个维度是 13,它有 13 个 id,它们被分成 5 个分片。 [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]] .
  • 在 strategy.extended.colocate_vars_with 下创建的变量将不会被分割。

2.2 集群设置

在真实的生产环境中,用户需要在不同机器上的所有不同进程中运行训练任务。在每个任务上配置集群信息的最简单方法是设置"TF_CONFIG" 环境变量,并使用 tf.distribution.cluster_resolver.TFConfigClusterResolver 来解析"TF_CONFIG" 。如果用户使用 Kubernetes 或其他配置模板开始训练任务,很可能这些模板已经设置了"TF_CONFIG"

2.2.1 设置 "TF_CONFIG" 环境变量

假定你有 3 个工作者,3 个参数服务器,那么 worker 1 的 "TF_CONFIG" 可以如下:

os.environ["TF_CONFIG"] = json.dumps({
   "cluster": {
       "worker": ["host1:port","host2:port","host3:port"],
       "ps": ["host4:port","host5:port"],
       "chief": ["host6:port"]
    },
   "task": {"type":"worker","index": 1}
})

2.2.2 使用二进制文件

如果你喜欢用一个二进制文件来运行所有这些任务,你将需要在程序开始就指明不同分支负责处理不同的角色。

cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
if cluster_resolver.task_type in ("worker","ps"):
  # Start a TensorFlow server and wait.
elif cluster_resolver.task_type =="evaluator":
  # Run side-car evaluation
else:
  # Run the coordinator.

如下代码启动一个 TensorFlow server 然后等待完成。

# Set the environment variable to allow reporting worker and ps failure to the
# coordinator. This is a workaround and won't be necessary in the future.
os.environ["GRPC_FAIL_FAST"] ="use_caller"

server = tf.distribute.Server(
    cluster_resolver.cluster_spec(),
    job_name=cluster_resolver.task_type,
    task_index=cluster_resolver.task_id,
    protocol=cluster_resolver.rpc_layer or"grpc",
    start=True)
server.join()

2.3 初始化方法

初始化方法如下,主要工作是连接到集群,然后调用 _extended 进行继续初始化。

  def __init__(self, cluster_resolver, variable_partitioner=None):
   """Initializes the TF2 parameter server strategy.

    This initializes the  tf.distribute.experimental.ParameterServerStrategy
    object to be ready for use with
     tf.distribute.experimental.coordinator.ClusterCoordinator .
   """
    # pyformat: enable
    self._cluster_resolver = cluster_resolver

    self._verify_args_and_config(cluster_resolver)
    self._cluster_coordinator = None

    self._connect_to_cluster(coordinator_name="chief") # 连接到集群
    self._extended = ParameterServerStrategyV2Extended(self, cluster_resolver,
                                                       variable_partitioner)
    super(ParameterServerStrategyV2, self).__init__(self._extended)
    distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
       "ParameterServerStrategy")
    self._should_use_with_coordinator = True
    # Used while constructing distributed iterators.
    self._canonicalize_devices = False

2.4 连接到集群

_connect_to_cluster 起到了连接到集群的作用,其主要逻辑是设置了 filter,然后调用 remote.connect_to_cluster 去连接集群。

  def _connect_to_cluster(self, coordinator_name):
    if coordinator_name in ["worker","ps"]:
      raise ValueError("coordinator name should not be 'worker' or 'ps'.")
    cluster_spec = self._cluster_resolver.cluster_spec()
    self._num_workers = len(cluster_spec.as_dict().get("worker", ()))
    self._num_ps = len(cluster_spec.as_dict().get("ps", ()))

    device_filters = server_lib.ClusterDeviceFilters()
    # For any worker, only the devices on ps and coordinator nodes are visible
    for i in range(self._num_workers):
      device_filters.set_device_filters(
         "worker", i, ["/job:ps","/job:%s" % coordinator_name])
    # Similarly for any ps, only the devices on workers and coordinator are
    # visible
    for i in range(self._num_ps):
      device_filters.set_device_filters(
         "ps", i, ["/job:worker","/job:%s" % coordinator_name])

    # Allow at most one outstanding RPC for each worker at a certain time. This
    # is to simplify worker failure handling in the runtime
    os.environ["TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE"] ="False"

    remote.connect_to_cluster(
        cluster_spec,
        job_name=coordinator_name,
        protocol=self._cluster_resolver.rpc_layer,
        cluster_device_filters=device_filters)

    distribute_lib.distribution_strategy_replica_gauge.get_cell(
       "ps_strategy_num_workers").set(self._num_workers)
    distribute_lib.distribution_strategy_replica_gauge.get_cell(
       "ps_strategy_num_ps").set(self._num_ps)

connect_to_cluster 方法会连接到给定的集群,使集群上的设备可用。如果给定的本地 job 名称没有出现在集群规范中,它将被自动添加,并且使用本地主机上一个未使用的端口。

工作者如果在被过滤的远程设备上访问资源或启动程序/功能,将导致一个未知设备错误。对于任何远程任务,如果没有设备过滤器,所有的集群设备都是可见的;如果指定了设备过滤器,任务则只能看到与至少一个过滤器匹配的设备。任务本身的设备始终是可见的。

以下是使用样例。

cdf = tf.config.experimental.ClusterDeviceFilters()
# For any worker, only the devices on PS nodes and itself are visible
for i in range(num_workers):
  cdf.set_device_filters('worker', i, ['/job:ps'])
# Similarly for any ps, only the devices on workers and itself are visible
for i in range(num_ps):
  cdf.set_device_filters('ps', i, ['/job:worker'])

tf.config.experimental_connect_to_cluster(cluster_def,
                                          cluster_device_filters=cdf)

具体 connect_to_cluster 的代码如下。

@tf_export("config.experimental_connect_to_cluster")
def connect_to_cluster(cluster_spec_or_resolver,
                       job_name="localhost",
                       task_index=0,
                       protocol=None,
                       make_master_device_default=True,
                       cluster_device_filters=None):
 """Connects to the given cluster.

  Will make devices on the cluster available to use. Note that calling this more
  than once will work, but will invalidate any tensor handles on the old remote
  devices.

  If the given local job name is not present in the cluster specification, it
  will be automatically added, using an unused port on the localhost.

  Device filters can be specified to isolate groups of remote tasks to avoid
  undesired accesses between workers. Workers accessing resources or launching
  ops / functions on filtered remote devices will result in errors (unknown
  devices). For any remote task, if no device filter is present, all cluster
  devices will be visible; if any device filter is specified, it can only
  see devices matching at least one filter. Devices on the task itself are
  always visible. Device filters can be particially specified.

  Args:
    cluster_spec_or_resolver: A  ClusterSpec  or  ClusterResolver  describing
      the cluster.
    job_name: The name of the local job.
    task_index: The local task index.
    protocol: The communication protocol, such as "grpc" . If unspecified, will
      use the default from  python/platform/remote_utils.py .
    make_master_device_default: If True and a cluster resolver is passed, will
      automatically enter the master task device scope, which indicates the
      master becomes the default device to run ops. It won't do anything if
      a cluster spec is passed. Will throw an error if the caller is currently
      already in some device scope.
    cluster_device_filters: an instance of
       tf.train.experimental/ClusterDeviceFilters  that specify device filters
      to the remote tasks in cluster.
 """
  if not context.executing_eagerly():
    raise ValueError(
       " tf.config.experimental_connect_to_cluster  can only be called in"
       "eager mode."
    )
  protocol = protocol or remote_utils.get_default_communication_protocol()
  if isinstance(cluster_spec_or_resolver, server_lib.ClusterSpec):
    cluster_spec = cluster_spec_or_resolver
  elif isinstance(cluster_spec_or_resolver, cluster_resolver.ClusterResolver):
    if cluster_spec_or_resolver.master() in _LOCAL_MASTERS:
      # Do nothing if the master is local.
      return
    cluster_spec = cluster_spec_or_resolver.cluster_spec()
  else:
    raise ValueError(
       " cluster_spec_or_resolver  must be a  ClusterSpec  or a"
       " ClusterResolver .")

  cluster_def = copy.deepcopy(cluster_spec.as_cluster_def())
  if cluster_device_filters:
    if isinstance(cluster_device_filters, server_lib.ClusterDeviceFilters):
      cluster_device_filters = copy.deepcopy(
          cluster_device_filters._as_cluster_device_filters())
    else:
      raise ValueError(" cluster_device_filters  must be an instance of"
                      " tf.train.experimental.ClusterDeviceFilters .")

  # Automatically add local job, if not part of the cluster spec.
  if job_name not in cluster_spec.jobs:
    local_port = pywrap_tfe.TF_PickUnusedPortOrDie()
    job_def = cluster_def.job.add()
    job_def.name = job_name
    job_def.tasks[0] ="localhost:{}".format(local_port)

  server_def = ServerDef(
      cluster=cluster_def,
      job_name=job_name,
      task_index=task_index,
      protocol=protocol,
      default_session_config=context.context().config,
      cluster_device_filters=cluster_device_filters)

  if context.get_server_def() is None:
    context.set_server_def(server_def) # 这里会做处理设备
  else:
    context.update_server_def(server_def)

  # 配置 master Device
  if make_master_device_default and isinstance(
      cluster_spec_or_resolver,
      cluster_resolver.ClusterResolver) and cluster_spec_or_resolver.master():
    master = cluster_spec_or_resolver.master()
    master_job_name = None
    master_task_id = None
    for job_name in cluster_spec.jobs:
      for task_id in cluster_spec.task_indices(job_name):
        task_address = cluster_spec.task_address(job_name, task_id)
        if master in task_address or task_address in master:
          master_job_name = job_name
          master_task_id = task_id
          break

    if not master_job_name:
      raise ValueError(
         " make_master_device_default  is set to True but cannot find"
         "master %s in the cluster" % master)

    master_device ="/job:{}/replica:0/task:{}".format(master_job_name,
                                                       master_task_id)
    master_device = device_util.canonicalize(master_device)
    current_device = device_util.current()
    if current_device:
      current_device = device_util.canonicalize(current_device)
    if current_device and current_device != master_device:
      raise ValueError(" connect_to_cluster  is called inside existing device"
                      "scope %s, which is different from the master device"
                      "scope %s to enter. This is not allowed." %
                       (current_device, master_device))

    if not current_device:
      logging.info("Entering into master device scope: %s", master_device)
      ops.device(master_device).__enter__()

2.5 初始化设备

set_server_def 会调用 _initialize_logical_devices 来初始化逻辑设备。

  def set_server_def(self, server_def, keep_alive_secs=_KEEP_ALIVE_SECS):
   """Allow setting a server_def on the context.

    When a server def is replaced, it effectively clears a bunch of caches
    within the context. If you attempt to use a tensor object that was pointing
    to a tensor on the remote device, it will raise an error.

    Args:
      server_def: A tensorflow::ServerDef proto. Enables execution on remote
        devices.
      keep_alive_secs: Num. seconds after which the remote end will hang up. As
        long as the client is still alive, the server state for the context will
        be kept alive. If the client is killed (or there is some failure), the
        server will clean up its context keep_alive_secs after the final RPC it
        receives.

    Raises:
      ValueError: if server_def is None.
   """
    if not server_def:
      raise ValueError("server_def is None.")

    self._server_def = server_def

    if self._context_handle:
      server_def_str = server_def.SerializeToString()
      pywrap_tfe.TFE_ContextSetServerDef(self._context_handle, keep_alive_secs,
                                         server_def_str)
      self._initialize_logical_devices()

    # Clear all the caches in case there are remote tensors in them.
    self._clear_caches()

_initialize_logical_devices 则会调用上下文对象的方法和一些其他方法来实现功能。

  def _initialize_logical_devices(self):
   """Helper to initialize devices."""
    # Store list of devices
    logical_devices = []
    context_devices = []
    device_list = pywrap_tfe.TFE_ContextListDevices(self._context_handle)
    try:
      self._num_gpus = 0
      for i in range(pywrap_tfe.TF_DeviceListCount(device_list)):
        dev_name = pywrap_tfe.TF_DeviceListName(device_list, i)
        context_devices.append(pydev.canonical_name(dev_name))
        spec = pydev.DeviceSpec.from_string(dev_name)
        # If the job is localhost, we assume that the cluster has not yet been
        # configured and thus clear the job, replica & task.
        if spec.job =="localhost":
          spec = spec.replace(job=None, replica=None, task=None)
        logical_devices.append(
            LogicalDevice(name=spec.to_string(), device_type=spec.device_type))
        dev_type = pywrap_tfe.TF_DeviceListType(device_list, i)
        if dev_type =="GPU":
          self._num_gpus += 1

    finally:
      self._logical_devices = logical_devices
      self._context_devices = context_devices
      pywrap_tfe.TF_DeleteDeviceList(device_list)

我们以 TFE_ContextListDevices 为例来看,其调用到了 Context 的 ListDevices 方法。

TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
  TF_DeviceList* l = new TF_DeviceList;
  tensorflow::unwrap(ctx)->ListDevices(&l->response);
  return l;
}

上下文如何实现,就需要具体情况具体分析了,比如下面的生成上下文的代码。

TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
  if (opts->use_tfrt) {
#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
    tfrt::tf::ContextInterface* tfrt_context = new tfrt::tf::ContextInterface(
        opts->session_options.options,
        static_cast<tensorflow::ContextDevicePlacementPolicy>(
            opts->device_placement_policy),
        opts->async, opts->use_tfrt_distributed_runtime);
#if !defined(IS_MOBILE_PLATFORM)
    tfrt_context->SetDistributedManager(
        tfrt::tf::CreateDistributedManagerContext(
            tfrt_context->GetCoreRuntime()->GetHostContext()));
#endif  // !IS_MOBILE_PLATFORM
    return tensorflow::wrap(tfrt_context);
#else
    status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
    return nullptr;
#endif  // PLATFORM_GOOGLE && !LIBTPU_ON_GCE
  }
  std::vector<std::unique_ptr<tensorflow::Device>> devices;
  status->status = tensorflow::DeviceFactory::AddDevices(
      opts->session_options.options,"/job:localhost/replica:0/task:0",
      &devices);
  if (!status->status.ok()) return nullptr;
  std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
      new tensorflow::DynamicDeviceMgr(std::move(devices)));

  tensorflow::Rendezvous* r =
      new tensorflow::IntraProcessRendezvous(device_mgr.get());
  tensorflow::EagerContext* eager_context = new tensorflow::EagerContext(
      opts->session_options.options,
      static_cast<tensorflow::ContextDevicePlacementPolicy>(
          opts->device_placement_policy),
      opts->async, device_mgr.release(),
      /*device_mgr_owned*/ true, r,
      /*cluster_flr=*/nullptr,
      /*collective_executor_mgr=*/nullptr,
      /*run_eager_op_as_function=*/opts->run_eager_op_as_function);
#if !defined(IS_MOBILE_PLATFORM)
  eager_context->SetDistributedManager(
      std::make_unique<tensorflow::EagerContextDistributedManager>(
          eager_context));
#endif  // !IS_MOBILE_PLATFORM
  return tensorflow::wrap(eager_context);
}

2.6 Master 设备

在 connect_to_cluster 之中,会调用 ops.device(master_device).enter() 来设置 master Device。代码位于 tensorflow/python/framework/ops.py。 device_name_or_function 参数可以是一个设备名称字符串,一个设备函数,或者是None:

  • 如果它是一个设备名称字符串,在这个上下文中构建的所有操作将被分配给具有该名称的设备,除非被嵌套的 device() 上下文覆盖。
  • 如果它是一个函数,它将被视为一个从操作对象到设备名称字符串的函数,并且在每次创建一个新操作时被调用。该操作将被分配给具有返回名称的设备。
  • 如果它是 None,所有来自包围上下文(enclosing context)的 device() 调用将被忽略。
@tf_export(v1=["device"])
def device(device_name_or_function):
 """Wrapper for  Graph.device()  using the default graph.

  See  tf.Graph.device  for more details.

  Args:
    device_name_or_function: The device name or function to use in the context.

  Returns:
    A context manager that specifies the default device to use for newly
    created ops.

  Raises:
    RuntimeError: If eager execution is enabled and a function is passed in.
 """
  if context.executing_eagerly():
    if callable(device_name_or_function):
      raise RuntimeError(
         "tf.device does not support functions when eager execution"
         "is enabled.")
    return context.device(device_name_or_function)
  elif executing_eagerly_outside_functions():
    @tf_contextlib.contextmanager
    def combined(device_name_or_function):
      with get_default_graph().device(device_name_or_function):
        if not callable(device_name_or_function):
          with context.device(device_name_or_function):
            yield
        else:
          yield
    return combined(device_name_or_function)
  else:
    return get_default_graph().device(device_name_or_function)

3. 使用 Model.fit 训练

Keras 通过 Model.fit 提供了一个易于使用的训练 API,它在幕后处理训练循环,并且通过可重写的 train_step 和回调方法提供了灵活性,也提供了检查点保存或 TensorBoard 摘要保存等功能。通过 Model.fit,同样的训练代码只需通过简单地交换策略对象即可被用于其他策略。

3.1 输入数据

使用参数服务器训练的 Model.fit 需要在一个 callable 中提供输入数据,该 callable 接收一个 tf.distribution.InputContext 类型的参数,并返回一个 tf.data.Dataset 。然后,系统将创建一个 tf.keras.utils.experimental.DatasetCreator 对象,它接受上述的 callable,并通过 input_options 参数创建一个可选的 tf.distribution.InputOptions 对象。

注意,建议用参数服务器训练来 shuffle 和 repeat 数据,并在 fit 调用中指定 steps_per_epoch,这样库就会知道 epoch 的界限。

关于 InputContext 参数的更多信息,请参见官方 Distributed input 教程。

def dataset_fn(input_context):
  global_batch_size = 64
  batch_size = input_context.get_per_replica_batch_size(global_batch_size)

  x = tf.random.uniform((10, 10))
  y = tf.random.uniform((10,))

  dataset = tf.data.Dataset.from_tensor_slices((x, y)).shuffle(10).repeat()
  dataset = dataset.shard(
      input_context.num_input_pipelines,
      input_context.input_pipeline_id)
  dataset = dataset.batch(batch_size)
  dataset = dataset.prefetch(2)

  return dataset

dc = tf.keras.utils.experimental.DatasetCreator(dataset_fn)

dataset_fn 中的代码将在每个工作者的输入设备上被调用,这个设备通常是CPU。

3.2 模型构建和编译

处理好数据之后,用户需要创建一个 tf.keras.Model,然后是一个 Model.compile 调用,以纳入组件,如优化器、度量或参数(如 steps_per_execution)。

with strategy.scope():
  model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
  model.compile(tf.keras.optimizers.SGD(), loss='mse', steps_per_execution=10)

3.3 回调和训练

在你调用 model.fit 进行实际训练之前,还需要为常见的工作准备所需的回调,例如。

  • ModelCheckpoint :保存模型的权重。
  • BackupAndRestore :确保训练进度被自动备份,并在集群出现不可用情况(如中止或抢占)时恢复;
  • TensorBoard :将进度报告保存为摘要文件,在 TensorBoard 工具中进行可视化。

注意:由于性能方面的考虑,自定义回调在与 ParameterServerStrategy 一起使用时不能覆盖批级(batch level)回调。请修改你的自定义回调成为 epoch 级别的调用,并将 steps_per_epoch 调整到一个合适的值。此外,当与 ParameterServerStrategy 一起使用时, steps_per_epoch 是 Model.fit 的一个必要参数。

working_dir = '/tmp/my_working_dir'
log_dir = os.path.join(working_dir, 'log')
ckpt_filepath = os.path.join(working_dir, 'ckpt')
backup_dir = os.path.join(working_dir, 'backup')

callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir=log_dir),
    tf.keras.callbacks.ModelCheckpoint(filepath=ckpt_filepath),
    tf.keras.callbacks.experimental.BackupAndRestore(backup_dir=backup_dir),
]

model.fit(dc, epochs=5, steps_per_epoch=20, callbacks=callbacks)

3.4 直接使用 ClusterCoordinator (optional)

即使你选择了 Model.fit 训练路径,你也可以选择实例化一个 tf.distribution.experimental.coordinator.ClusterCoordinator 对象来安排你希望在工作者上执行的其他功能。

0x04 自定义训练

使用 tf.distribution.Strategy 的自定义训练循环为定义训练循环提供了极大的灵活性。通过上面定义的 ParameterServerStrategy (作为 strategy ),用户可以使用 tf.distribution.experimental.coordinator.ClusterCoordinator 将训练步骤调度给远程工作者来执行。

和其他 tf.distribution.Strategy 的训练循环一样,用户需要创建一个模型,定义一个数据集和一个步进函数(step function)。为了确保高效的数据集预取,建议使用下面会提到的分布式数据集创建 API。此外,确保在 worker_fn 内调用 Strategy.run,这样可以充分利用分配给工作者的 GPU。

我们接下来看看如何创建这些组件。

4.1 配置数据

首先,编写一个函数来创建一个数据集,其中包括由 Keras preprocessing layers 所实现的预处理逻辑。我们在 dataset_fn 之外创建这些层,但在 dataset_fn 内应用转换,因为我们将把 dataset_fn 包裹到 tf.function 中,它不允许在其内部创建变量。

feature_vocab = [
   "avenger","ironman","batman","hulk","spiderman","kingkong","wonder_woman"
]
label_vocab = ["yes","no"]

with strategy.scope():
  feature_lookup_layer = tf.keras.layers.StringLookup(
      vocabulary=feature_vocab,
      mask_token=None)
  label_lookup_layer = tf.keras.layers.StringLookup(
      vocabulary=label_vocab,
      num_oov_indices=0,
      mask_token=None)

  raw_feature_input = tf.keras.layers.Input(
      shape=(3,),
      dtype=tf.string,
      name="feature")
  feature_id_input = feature_lookup_layer(raw_feature_input)
  feature_preprocess_stage = tf.keras.Model(
      {"features": raw_feature_input},
      feature_id_input)

  raw_label_input = tf.keras.layers.Input(
      shape=(1,),
      dtype=tf.string,
      name="label")
  label_id_input = label_lookup_layer(raw_label_input)

  label_preprocess_stage = tf.keras.Model(
      {"label": raw_label_input},
      label_id_input)

以下是构建数据的代码。

def feature_and_label_gen(num_examples=200):
  examples = {"features": [],"label": []}
  for _ in range(num_examples):
    features = random.sample(feature_vocab, 3)
    label = ["yes"] if"avenger" in features else ["no"]
    examples["features"].append(features)
    examples["label"].append(label)
  return examples

examples = feature_and_label_gen()

然后,使用 dataset_fn 把训练数据集包装起来。

def dataset_fn(_):
  raw_dataset = tf.data.Dataset.from_tensor_slices(examples)

  train_dataset = raw_dataset.map(
      lambda x: (
          {"features": feature_preprocess_stage(x["features"])},
          label_preprocess_stage(x["label"])
      )).shuffle(200).batch(32).repeat()
  return train_dataset

4.2 建立模型

接下来,我们来建立模型和其他对象,要确保在 strategy.scope 之下创建这些变量。

# These variables created under the  strategy.scope  will be placed on parameter
# servers in a round-robin fashion.
with strategy.scope():
  # Create the model. The input needs to be compatible with Keras processing layers.
  model_input = tf.keras.layers.Input(
      shape=(3,), dtype=tf.int64, name="model_input")

  emb_layer = tf.keras.layers.Embedding(
      input_dim=len(feature_lookup_layer.get_vocabulary()), output_dim=16384)
  emb_output = tf.reduce_mean(emb_layer(model_input), axis=1)
  dense_output = tf.keras.layers.Dense(units=1, activation="sigmoid")(emb_output)
  model = tf.keras.Model({"features": model_input}, dense_output)

  optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.1)
  accuracy = tf.keras.metrics.Accuracy()

然后需要确保使用 FixedShardsPartitioner 将所有变量分成两个分片,每个分片被分配给不同的参数服务器。

assert len(emb_layer.weights) == 2
assert emb_layer.weights[0].shape == (4, 16384)
assert emb_layer.weights[1].shape == (4, 16384)
assert emb_layer.weights[0].device =="/job:ps/replica:0/task:0/device:CPU:0"
assert emb_layer.weights[1].device =="/job:ps/replica:0/task:1/device:CPU:0"

4.3 定义训练步骤

第三步则是使用 tf.function 来创建训练 step。

@tf.function
def step_fn(iterator):

  def replica_fn(batch_data, labels):
    with tf.GradientTape() as tape:
      pred = model(batch_data, training=True)
      per_example_loss = tf.keras.losses.BinaryCrossentropy(
              reduction=tf.keras.losses.Reduction.NONE)(labels, pred)
      loss = tf.nn.compute_average_loss(per_example_loss)
      gradients = tape.gradient(loss, model.trainable_variables)

    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    actual_pred = tf.cast(tf.greater(pred, 0.5), tf.int64)
    accuracy.update_state(labels, actual_pred)
    return loss

  batch_data, labels = next(iterator)
  losses = strategy.run(replica_fn, args=(batch_data, labels))
  return strategy.reduce(tf.distribute.ReduceOp.SUM, losses, axis=None)

在上面的训练步进函数中,在 step_fn 中调用 Strategy.run 和 Strategy.reduce 就可以支持每个工作者的多个GPU。工作者被分配 GPU 之后, Strategy.run 将在多个模型副本上分配数据集。

4.4 分配计算到远端

在使用 ParameterServerStrategy 定义所有的计算后,你将使用 tf.distribution.experimental.coordinator.ClusterCoordinator 类来创建资源并将训练步骤分配给远程工作者。

coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(strategy)

然后,为每个工作者(per-worker)创建一个数据集和一个迭代器。在下面的 per_worker_dataset_fn 中,建议将 dataset_fn 包裹到 strategy.distribution_datasets_from_function 中,以允许无缝高效的把数据预取到 GPU。

@tf.function
def per_worker_dataset_fn():
  return strategy.distribute_datasets_from_function(dataset_fn)

per_worker_dataset = coordinator.create_per_worker_dataset(per_worker_dataset_fn)
per_worker_iterator = iter(per_worker_dataset)

最后一步是使用 ClusterCoordinator.schedule 将计算分配给远程工作者。

  • schedule 方法把一个 tf.function 插入队列,并立即返回一个 future-like 的 RemoteValue 。队列之中的函数将被派发给后台线程中的远程工作者,RemoteValue 将被异步填充。
  • 可以使用 join 方法( ClusterCoordinator.join )来等待所有被规划(scheduled)的函数执行完毕。
num_epoches = 4
steps_per_epoch = 5
for i in range(num_epoches):
  accuracy.reset_states()
  for _ in range(steps_per_epoch):
    coordinator.schedule(step_fn, args=(per_worker_iterator,))
  # Wait at epoch boundaries.
  coordinator.join()
  print ("Finished epoch %d, accuracy is %f." % (i, accuracy.result().numpy()))

下面是如何得到 RemoteValue 的结果。

loss = coordinator.schedule(step_fn, args=(per_worker_iterator,))
print ("Final loss is %f" % loss.fetch())

或者,你可以启动所有的步骤,并在等待完成时做一些事情。

for _ in range(total_steps):
  coordinator.schedule(step_fn, args=(per_worker_iterator,))
while not coordinator.done():
  time.sleep(10)
  # Do something like logging metrics or writing checkpoints.

4.5 建立数据集

上述代码中的数据集是使用 ClusterCoordinator.create_per_worker_dataset API 创建的。它为每个工作者创建一个数据集,并返回一个容器对象。你可以调用 iter 方法来创建一个属于每个工作者(per-worker)的迭代器。在工作者执行函数之前, ClusterCoordinator.schedule 方法的输入参数将被设置成工作者的相应切片(slice)。

目前, ClusterCoordinator.schedule 方法假定worker都是相同的,因此假定不同worker上的数据集是相同的,如果数据集包含 Dataset.shuffle 操作,则数据集可能会被shuffle。正因为如此,建议用户安排运行有限的步骤,而不是依赖数据集的 OutOfRangeError 。

另一个重要的注意事项是, tf.data 数据集不支持跨任务边界的隐式序列化和反序列化。所以在传递给 ClusterCoordinator.create_per_worker_dataset 的函数内创建整个数据集是很重要的。

5. 运行

5.1 直接运行

如果直接调用 run 来运行,则 ParameterServerStrategy 和其他策略套路类似,比如在 parameter_server_strategy_v2 之中调用了 mirrored_run,所以我们不在赘述。

  def _call_for_each_replica(self, fn, args, kwargs):
    self._assert_being_scheduled_by_cluster_coordinator()

    return mirrored_run.call_for_each_replica(self._container_strategy(), fn,
                                              args, kwargs)

5.2 ClusterCoordinator

另一种方式是使用 ClusterCoordinator 来运行,我们将在下一章节结合自定义训练循环来进行分析。

6. 性能改进

如果你在使用 ParameterServerStrategy 和 ClusterResolver 训练时发现性能问题,可能有几个原因。

一个常见的原因是参数服务器的负载不平衡,一些重载的参数服务器已经达到容量。也可能有多种根本原因。缓解这个问题的一些简单方法是:

  1. 在构建 ParameterServerStrategy 时,通过指定一个 variable_partitioner 来分割你的大型模型变量。
  2. 如果可能的话,避免创建一个所有参数服务器都需要的热点(hotspot)变量。例如,在优化器中使用一个恒定的学习率或子类 tf.keras.optimizers.schedules.LearningRateSchedule,因为默认行为是:学习率将成为一个放在特定参数服务器上的变量,但是此变量在每一步中被所有其他参数服务器使用。
  3. 在将你的大词汇表传递给 Keras 预处理层之前,对它们进行 shuffle。

性能问题的另一个可能原因是协调器。你的第一个 schedule / join 的实现是基于Python的,因此可能有线程开销。另外,协调器和工作者之间的延迟也可能很大。如果是这种情况,那么建议:

  • 对于 Model.fit,你可以将 Model.compile 提供的 steps_per_execution 参数设置为大于1的值。
  • 对于一个自定义的训练循环,你可以将多个步骤打包到一个 tf.function 中。
steps_per_invocation = 10

@tf.function
def step_fn(iterator):
  for _ in range(steps_per_invocation):
    features, labels = next(iterator)
    def replica_fn(features, labels):
      ...

    strategy.run(replica_fn, args=(features, labels))

随着库的进一步优化,希望可以让大多数用户在未来不必手动打包步骤。此外,提高性能的一个小窍门是安排没有返回值的函数。

7. 已知限制

在上述章节中已经涉及了大部分已知的限制。本节提供一个总结。

7.1 ParameterServerStrategy

  • os.environment["grpc_fail_fast"]="use_caller" 在包括协调器在内的每个任务上都需要,以使容错正常工作。
  • 不支持同步的参数服务器训练。
  • 通常需要将多个步骤打包到一个函数中,以实现最佳性能。
  • 不支持通过 tf.saved_model.load 加载含有分片变量的保存模型。注意使用 TensorFlow Serving 加载这样的 saved_model 是可以的。
  • 不支持将包含分片优化器插槽(slot)变量的检查点加载到不同数量的分片中。
  • 不支持在不重启协调者任务的情况下从参数服务器故障中恢复。
  • 使用 tf.lookup.StaticHashTable(它通常被一些 Keras 预处理层采用,如 tf.keras.layer.IntegerLookup 、 tf.keras.layer.StringLookup 和 tf.keras.layer.TextVectorization )将导致在这一步之中参数服务器训练所使用的资源被放在协调器上。这会影响从工作者到协调器的查找RPC的性能。这是目前需要解决的一个高度优先事项。

7.2 Model.fit

  • steps_per_epoch 参数在 Model.fit 中是必需的。你可以选择一个值来确保epoch之内被分割恰当。

  • 由于性能原因, ParameterServerStrategy 不支持批量级自定义回调。你应该将这些调用转换为epoch级的调用,并适当选择 steps_per_epoch,以便每隔 steps_per_epoch 步数调用这些回调。内置回调不受影响:它们的批处理级调用已经被修改为可执行的。官方正在计划为"ParameterServerStrategy"支持批量调用。

  • 出于同样的原因,与其他策略不同,进度条和指标只在epoch边界被记录。

  • 不支持 run_eagerly 。

7.3 自定义循环

  • ClusterCoordinator.schedule 不支持数据集的访问量保证(visitation guarantees)。

0xFF 参考

https://www.youtube.com/watch?v=B2Tpv_N7wkg&ab_channel=TensorFlow

[中字] TFRT: 新的 TensorFlow 运行库 - TF Dev Summit '20

深入理解 TFRT

Inside TensorFlow: Eager execution runtime

【 深度学习框架tensorflow: Inside TensorFlow 】Inside TensorFlow(合辑)

https://github.com/tensorflow/docs-l10n/blob/07e15a23c7fa397bc44acbf20f997f7cb268ab1c/site/en-snapshot/tutorials/distribute/parameter_server_training.ipynb

05-14 11:09