分布式TensorFlow由高性能gRPC库底层技术支持。Martin Abadi、Ashish Agarwal、Paul Barham论文《TensorFlow:Large-Scale Machine Learning on Heterogeneous Distributed Systems》。


分布式原理。分布式集群 由多个服务器进程、客户端进程组成。部署方式,单机多卡、分布式(多机多卡)。多机多卡TensorFlow分布式。


单机多卡,单台服务器多块GPU。训练过程:在单机单GPU训练,数据一个批次(batch)一个批次训练。单机多GPU,一次处理多个批次数据,每个GPU处理一个批次数据计算。变量参数保存在CPU,数据由CPU分发给多个GPU,GPU计算每个批次更新梯度。CPU收集完多个GPU更新梯度,计算平均梯度,更新参数。继续计算更新梯度。处理速度取决最慢GPU速度。


分布式,训练在多个工作节点(worker)。工作节点,实现计算单元。计算服务器单卡,指服务器。计算服务器多卡,多个GPU划分多个工作节点。数据量大,超过一台机器处理能力,须用分布式。


分布式TensorFlow底层通信,gRPC(google remote procedure call)。gRPC,谷歌开源高性能、跨语言RPC框架。RPC协议,远程过程调用协议,网络从远程计算机程度请求服务。


分布式部署方式。分布式运行,多个计算单元(工作节点),后端服务器部署单工作节点、多工作节点。


单工作节点部署。每台服务器运行一个工作节点,服务器多个GPU,一个工作节点可以访问多块GPU卡。代码tf.device()指定运行操作设备。优势,单机多GPU间通信,效率高。劣势,手动代码指定设备。


多工作节点部署。一台服务器运行多个工作节点。


设置CUDA_VISIBLE_DEVICES环境变量,限制各个工作节点只可见一个GPU,启动进程添加环境变量。用tf.device()指定特定GPU。多工作节点部署优势,代码简单,提高GPU使用率。劣势,工作节点通信,需部署多个工作节点。https://github.com/tobegit3hub/tensorflow_examples/tree/master/distributed_tensorflow 。


    CUDA_VISIBLE_DEVICES='' python ./distributed_supervisor.py --ps_hosts=127.0.0.1:2222,127.0.0.1:2223 --worker_hosts=127.0.0.1:2224,127.0.0.1:2225 --job_name=ps --task_index=0
    CUDA_VISIBLE_DEVICES='' python ./distributed_supervisor.py --ps_hosts=127.0.0.1:2222,127.0.0.1:2223 --worker_hosts=127.0.0.1:2224,127.0.0.1:2225 --job_name=ps --task_index=1
    CUDA_VISIBLE_DEVICES='0' python ./distributed_supervisor.py --ps_hosts=127.0.0.1:2222,127.0.0.1:2223 --worker_hosts=127.0.0.1:2224,127.0.0.1:2225 --job_name=worker --task_index=0
    CUDA_VISIBLE_DEVICES='1' python ./distributed_supervisor.py --ps_hosts=127.0.0.1:2222,127.0.0.1:2223 --worker_hosts=127.0.0.1:2224,127.0.0.1:2225 --job_name=worker --task_index=1


分布式架构。https://www.tensorflow.org/extend/architecture 。客户端(client)、服务端(server),服务端包括主节点(master)、工作节点(worker)组成。


客户端、主节点、工作节点关系。TensorFlow,客户端会话联系主节点,实际工作由工作节点实现,每个工作节点占一台设备(TensorFlow具体计算硬件抽象,CPU或GPU)。单机模式,客户端、主节点、工作节点在同一台服务器。分布模式,可不同服务器。客户端->主节点->工作节点/job:worker/task:0->/job:ps/task:0。
客户端。建立TensorFlow计算图,建立与集群交互会话层。代码包含Session()。一个客户端可同时与多个服务端相连,一具服务端也可与多个客户端相连。
服务端。运行tf.train.Server实例进程,TensroFlow执行任务集群(cluster)一部分。有主节点服务(Master service)和工作节点服务(Worker service)。运行中,一个主节点进程和数个工作节点进程,主节点进程和工作接点进程通过接口通信。单机多卡和分布式结构相同,只需要更改通信接口实现切换。
主节点服务。实现tensorflow::Session接口。通过RPC服务程序连接工作节点,与工作节点服务进程工作任务通信。TensorFlow服务端,task_index为0作业(job)。
工作节点服务。实现worker_service.proto接口,本地设备计算部分图。TensorFlow服务端,所有工作节点包含工作节点服务逻辑。每个工作节点负责管理一个或多个设备。工作节点可以是本地不同端口不同进程,或多台服务多个进程。运行TensorFlow分布式执行任务集,一个或多个作业(job)。每个作业,一个或多个相同目的任务(task)。每个任务,一个工作进程执行。作业是任务集合,集群是作业集合。
分布式机器学习框架,作业分参数作业(parameter job)和工作节点作业(worker job)。参数作业运行服务器为参数服务器(parameter server,PS),管理参数存储、更新。工作节点作业,管理无状态主要从事计算任务。模型越大,参数越多,模型参数更新超过一台机器性能,需要把参数分开到不同机器存储更新。参数服务,多台机器组成集群,类似分布式存储架构,涉及数据同步、一致性,参数存储为键值对(key-value)。分布式键值内存数据库,加参数更新操作。李沐《Parameter Server for Distributed Machine Learning》http://www.cs.cmu.edu/~muli/file/ps.pdf 。
参数存储更新在参数作业进行,模型计算在工作节点作业进行。TensorFlow分布式实现作业间数据传输,参数作业到工作节点作业前向传播,工作节点作业到参数作业反向传播。
任务。特定TensorFlow服务器独立进程,在作业中拥有对应序号。一个任务对应一个工作节点。集群->作业->任务->工作节点。


客户端、主节点、工作节点交互过程。单机多卡交互,客户端->会话运行->主节点->执行子图->工作节点->GPU0?GPU1。分布式交互,客户端->会话运行->主节点进程->执行子图1->工作节点进程1->GPU0?GPU1。《TensorFlow:Large-Scale Machine Learning on Heterogeneous distributed Systems》https://arxiv.org/abs/1603.04467v1 。


分布式模式。


数据并行。https://www.tensorflow.org/tutorials/deep_cnn 。CPU负责梯度平均、参数更新,不同GPU训练模型副本(model replica)。基于训练样例子集训练,模型有独立性。
步骤:不同GPU分别定义模型网络结构。单个GPU从数据管道读取不同数据块,前向传播,计算损失,计算当前变量梯度。所有GPU输出梯度数据转移到CPU,梯度求平均操作,模型变量更新。重复,直到模型变量收敛。
数据并行,提高SGD效率。SGD mini-batch样本,切成多份,模型复制多份,在多个模型上同时计算。多个模型计算速度不一致,CPU更新变量有同步、异步两个方案。


同步更新、异步更新。分布式随机梯度下降法,模型参数分布式存储在不同参数服务上,工作节点并行训练数据,和参数服务器通信获取模型参数。
同步随机梯度下降法(Sync-SGD,同步更新、同步训练),训练时,每个节点上工作任务读入共享参数,执行并行梯度计算,同步需要等待所有工作节点把局部梯度处好,将所有共享参数合并、累加,再一次性更新到模型参数,下一批次,所有工作节点用模型更新后参数训练。优势,每个训练批次考虑所有工作节点训练情部,损失下降稳定。劣势,性能瓶颈在最慢工作节点。异楹设备,工作节点性能不同,劣势明显。
异步随机梯度下降法(Async-SGD,异步更新、异步训练),每个工作节点任务独立计算局部梯度,异步更新到模型参数,不需执行协调、等待操作。优势,性能不存在瓶颈。劣势,每个工作节点计算梯度值发磅回参数服务器有参数更新冲突,影响算法收剑速度,损失下降过程抖动较大。
同步更新、异步更新实现区别于更新参数服务器参数策略。数据量小,各节点计算能力较均衡,用同步模型。数据量大,各机器计算性能参差不齐,用异步模式。
带备份的Sync-SGD(Sync-SDG with backup)。Jianmin Chen、Xinghao Pan、Rajat Monga、Aamy Bengio、Rafal Jozefowicz论文《Revisiting Distributed Synchronous SGD》https://arxiv.org/abs/1604.00981 。增加工作节点,解决部分工作节点计算慢问题。工作节点总数n+n*5%,n为集群工作节点数。异步更新设定接受到n个工作节点参数直接更新参数服务器模型参数,进入下一批次模型训练。计算较慢节点训练参数直接丢弃。
同步更新、异步更新有图内模式(in-graph pattern)和图间模式(between-graph pattern),独立于图内(in-graph)、图间(between-graph)概念。
图内复制(in-grasph replication),所有操作(operation)在同一个图中,用一个客户端来生成图,把所有操作分配到集群所有参数服务器和工作节点上。国内复制和单机多卡类似,扩展到多机多卡,数据分发还是在客户端一个节点上。优势,计算节点只需要调用join()函数等待任务,客户端随时提交数据就可以训练。劣势,训练数据分发在一个节点上,要分发给不同工作节点,严重影响并发训练速度。
图间复制(between-graph replication),每一个工作节点创建一个图,训练参数保存在参数服务器,数据不分发,各个工作节点独立计算,计算完成把要更新参数告诉参数服务器,参数服务器更新参数。优势,不需要数据分发,各个工作节点都创建图和读取数据训练。劣势,工作节点既是图创建者又是计算任务执行者,某个工作节点宕机影响集群工作。大数据相关深度学习推荐使用图间模式。


模型并行。切分模型,模型不同部分执行在不同设备上,一个批次样本可以在不同设备同时执行。TensorFlow尽量让相邻计算在同一台设备上完成节省网络开销。Martin Abadi、Ashish Agarwal、Paul Barham论文《TensorFlow:Large-Scale Machine Learning on Heterogeneous Distributed Systems》https://arxiv.org/abs/1603.04467v1 。


模型并行、数据并行,TensorFlow中,计算可以分离,参数可以分离。可以在每个设备上分配计算节点,让对应参数也在该设备上,计算参数放一起。


分布式API。https://www.tensorflow.org/deploy/distributed 。
创建集群,每个任务(task)启动一个服务(工作节点服务或主节点服务)。任务可以分布不同机器,可以同一台机器启动多个任务,用不同GPU运行。每个任务完成工作:创建一个tf.train.ClusterSpec,对集群所有任务进行描述,描述内容对所有任务相同。创建一个tf.train.Server,创建一个服务,运行相应作业计算任务。
TensorFlow分布式开发API。tf.train.ClusterSpec({"ps":ps_hosts,"worker":worke_hosts})。创建TensorFlow集群描述信息,ps、worker为作业名称,ps_phsts、worker_hosts为作业任务所在节点地址信息。tf.train.ClusterSpec传入参数,作业和任务间关系映射,映射关系任务通过IP地址、端口号表示。


    结构 tf.train.ClusterSpec({"local":["localhost:2222","localhost:2223"]})
    可用任务 /job:local/task:0?/job:local/task:1。
    结构 tf.train.ClusterSpec({"worker":["worker0.example.com:2222","worker1.example.com:2222","worker2.example.com:2222"],"ps":["ps0.example.com:2222","ps1.example.com:2222"]})
    可用任务 /job:worker/task:0? /job:worker/task:1? /job:worker/task:2? /job:ps/task:0? /job:ps/task:1
tf.train.Server(cluster,job_name,task_index)。创建服务(主节点服务或工作节点服务),运行作业计算任务,运行任务在task_index指定机器启动。


    #任务0 
    cluster = tr.train.ClusterSpec({"local":["localhost:2222","localhost:2223"]})
    server  = tr.train.Server(cluster,job_name="local",task_index=0) 
    #任务1 
    cluster = tr.train.ClusterSpec({"local":["localhost:2222","localhost:2223"]})
    server  = tr.train.Server(cluster,job_name="local",task_index=1)。
自动化管理节点、监控节点工具。集群管理工具Kubernetes。
tf.device(device_name_or_function)。设定指定设备执行张量运算,批定代码运行CPU、GPU。


    #指定在task0所在机器执行Tensor操作运算 
    with tf.device("/job:ps/task:0"):
      weights_1 = tf.Variable(…)
      biases_1 = tf.Variable(…)


分布式训练代码框架。创建TensorFlow服务器集群,在该集群分布式计算数据流图。https://github.com/tensorflow/tensorflow/blob/master/tensorflow/docs_src/deploy/distributed.md 。


    import argparse
    import sys
    import tensorflow as tf
    FLAGS = None
    def main(_):
      # 第1步:命令行参数解析,获取集群信息ps_hosts、worker_hosts
      # 当前节点角色信息job_name、task_index
      ps_hosts = FLAGS.ps_hosts.split(",")
      worker_hosts = FLAGS.worker_hosts.split(",")
      # 第2步:创建当前任务节点服务器
      # Create a cluster from the parameter server and worker hosts.
      cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
      # Create and start a server for the local task.
      server = tf.train.Server(cluster,
                               job_name=FLAGS.job_name,
                               task_index=FLAGS.task_index)
      # 第3步:如果当前节点是参数服务器,调用server.join()无休止等待;如果是工作节点,执行第4步
      if FLAGS.job_name == "ps":
        server.join()
      # 第4步:构建要训练模型,构建计算图
      elif FLAGS.job_name == "worker":
        # Assigns ops to the local worker by default.
        with tf.device(tf.train.replica_device_setter(
            worker_device="/job:worker/task:%d" % FLAGS.task_index,
            cluster=cluster)):
          # Build model...
          loss = ...
          global_step = tf.contrib.framework.get_or_create_global_step()
          train_op = tf.train.AdagradOptimizer(0.01).minimize(
              loss, global_step=global_step)
        # The StopAtStepHook handles stopping after running given steps.
        # 第5步管理模型训练过程
        hooks=[tf.train.StopAtStepHook(last_step=1000000)]
        # The MonitoredTrainingSession takes care of session initialization,
        # restoring from a checkpoint, saving to a checkpoint, and closing when done
        # or an error occurs.
        with tf.train.MonitoredTrainingSession(master=server.target,
                                               is_chief=(FLAGS.task_index == 0),
                                               checkpoint_dir="/tmp/train_logs",
                                               hooks=hooks) as mon_sess:
          while not mon_sess.should_stop():
            # Run a training step asynchronously.
            # See `tf.train.SyncReplicasOptimizer` for additional details on how to
            # perform *synchronous* training.
            # mon_sess.run handles AbortedError in case of preempted PS.
            # 训练模型
            mon_sess.run(train_op)
    if __name__ == "__main__":
      parser = argparse.ArgumentParser()
      parser.register("type", "bool", lambda v: v.lower() == "true")
      # Flags for defining the tf.train.ClusterSpec
      parser.add_argument(
          "--ps_hosts",
          type=str,
          default="",
          help="Comma-separated list of hostname:port pairs"
      )
      parser.add_argument(
          "--worker_hosts",
          type=str,
          default="",
          help="Comma-separated list of hostname:port pairs"
      )
      parser.add_argument(
          "--job_name",
          type=str,
          default="",
          help="One of 'ps', 'worker'"
      )
      # Flags for defining the tf.train.Server
      parser.add_argument(
          "--task_index",
          type=int,
          default=0,
          help="Index of task within the job"
      )
      FLAGS, unparsed = parser.parse_known_args()
      tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)


分布式最佳实践。https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/dist_test/python/mnist_replica.py 。
MNIST数据集分布式训练。开设3个端口作分布式工作节点部署,2222端口参数服务器,2223端口工作节点0,2224端口工作节点1。参数服务器执行参数更新任务,工作节点0?工作节点1执行图模型训练计算任务。参数服务器/job:ps/task:0 cocalhost:2222,工作节点/job:worker/task:0 cocalhost:2223,工作节点/job:worker/task:1 cocalhost:2224。
运行代码。


    python mnist_replica.py --job_name="ps" --task_index=0
    python mnist_replica.py --job_name="worker" --task_index=0
    python mnist_replica.py --job_name="worker" --task_index=1


    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function
    import math
    import sys
    import tempfile
    import time
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    # 定义常量,用于创建数据流图
    flags = tf.app.flags
    flags.DEFINE_string("data_dir", "/tmp/mnist-data",
                        "Directory for storing mnist data")
    # 只下载数据,不做其他操作
    flags.DEFINE_boolean("download_only", False,
                         "Only perform downloading of data; Do not proceed to "
                         "session preparation, model definition or training")
    # task_index从0开始。0代表用来初始化变量的第一个任务
    flags.DEFINE_integer("task_index", None,
                         "Worker task index, should be >= 0. task_index=0 is "
                         "the master worker task the performs the variable "
                         "initialization ")
    # 每台机器GPU个数,机器没有GPU为0
    flags.DEFINE_integer("num_gpus", 1,
                         "Total number of gpus for each machine."
                         "If you don't use GPU, please set it to '0'")
    # 同步训练模型下,设置收集工作节点数量。默认工作节点总数
    flags.DEFINE_integer("replicas_to_aggregate", None,
                         "Number of replicas to aggregate before parameter update"
                         "is applied (For sync_replicas mode only; default: "
                         "num_workers)")
    flags.DEFINE_integer("hidden_units", 100,
                         "Number of units in the hidden layer of the NN")
    # 训练次数
    flags.DEFINE_integer("train_steps", 200,
                         "Number of (global) training steps to perform")
    flags.DEFINE_integer("batch_size", 100, "Training batch size")
    flags.DEFINE_float("learning_rate", 0.01, "Learning rate")
    # 使用同步训练、异步训练
    flags.DEFINE_boolean("sync_replicas", False,
                         "Use the sync_replicas (synchronized replicas) mode, "
                         "wherein the parameter updates from workers are aggregated "
                         "before applied to avoid stale gradients")
    # 如果服务器已经存在,采用gRPC协议通信;如果不存在,采用进程间通信
    flags.DEFINE_boolean(
        "existing_servers", False, "Whether servers already exists. If True, "
        "will use the worker hosts via their GRPC URLs (one client process "
        "per worker host). Otherwise, will create an in-process TensorFlow "
        "server.")
    # 参数服务器主机
    flags.DEFINE_string("ps_hosts","localhost:2222",
                        "Comma-separated list of hostname:port pairs")
    # 工作节点主机
    flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224",
                        "Comma-separated list of hostname:port pairs")
    # 本作业是工作节点还是参数服务器
    flags.DEFINE_string("job_name", None,"job name: worker or ps")
    FLAGS = flags.FLAGS
    IMAGE_PIXELS = 28
    def main(unused_argv):
      mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
      if FLAGS.download_only:
        sys.exit(0)
      if FLAGS.job_name is None or FLAGS.job_name == "":
        raise ValueError("Must specify an explicit `job_name`")
      if FLAGS.task_index is None or FLAGS.task_index =="":
        raise ValueError("Must specify an explicit `task_index`")
      print("job name = %s" % FLAGS.job_name)
      print("task index = %d" % FLAGS.task_index)
      #Construct the cluster and start the server
      # 读取集群描述信息
      ps_spec = FLAGS.ps_hosts.split(",")
      worker_spec = FLAGS.worker_hosts.split(",")
      # Get the number of workers.
      num_workers = len(worker_spec)
      # 创建TensorFlow集群描述对象
      cluster = tf.train.ClusterSpec({
          "ps": ps_spec,
          "worker": worker_spec})
      # 为本地执行任务创建TensorFlow Server对象。
      if not FLAGS.existing_servers:
        # Not using existing servers. Create an in-process server.
        # 创建本地Sever对象,从tf.train.Server这个定义开始,每个节点开始不同
        # 根据执行的命令的参数(作业名字)不同,决定这个任务是哪个任务
        # 如果作业名字是ps,进程就加入这里,作为参数更新的服务,等待其他工作节点给它提交参数更新的数据
        # 如果作业名字是worker,就执行后面的计算任务
        server = tf.train.Server(
            cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
        # 如果是参数服务器,直接启动即可。这里,进程就会阻塞在这里
        # 下面的tf.train.replica_device_setter代码会将参数批定给ps_server保管
        if FLAGS.job_name == "ps":
          server.join()
      # 处理工作节点
      # 找出worker的主节点,即task_index为0的点
      is_chief = (FLAGS.task_index == 0)
      # 如果使用gpu
      if FLAGS.num_gpus > 0:
        # Avoid gpu allocation conflict: now allocate task_num -> #gpu
        # for each worker in the corresponding machine
        gpu = (FLAGS.task_index % FLAGS.num_gpus)
        # 分配worker到指定gpu上运行
        worker_device = "/job:worker/task:%d/gpu:%d" % (FLAGS.task_index, gpu)
      # 如果使用cpu
      elif FLAGS.num_gpus == 0:
        # Just allocate the CPU to worker server
        # 把cpu分配给worker
        cpu = 0
        worker_device = "/job:worker/task:%d/cpu:%d" % (FLAGS.task_index, cpu)
      # The device setter will automatically place Variables ops on separate
      # parameter servers (ps). The non-Variable ops will be placed on the workers.
      # The ps use CPU and workers use corresponding GPU
      # 用tf.train.replica_device_setter将涉及变量操作分配到参数服务器上,使用CPU。将涉及非变量操作分配到工作节点上,使用上一步worker_device值。
      # 在这个with语句之下定义的参数,会自动分配到参数服务器上去定义。如果有多个参数服务器,就轮流循环分配
      with tf.device(
          tf.train.replica_device_setter(
              worker_device=worker_device,
              ps_device="/job:ps/cpu:0",
              cluster=cluster)):
    
        # 定义全局步长,默认值为0
        global_step = tf.Variable(0, name="global_step", trainable=False)
        # Variables of the hidden layer
        # 定义隐藏层参数变量,这里是全连接神经网络隐藏层
        hid_w = tf.Variable(
            tf.truncated_normal(
                [IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units],
                stddev=1.0 / IMAGE_PIXELS),
            name="hid_w")
        hid_b = tf.Variable(tf.zeros([FLAGS.hidden_units]), name="hid_b")
        # Variables of the softmax layer
        # 定义Softmax 回归层参数变量
        sm_w = tf.Variable(
            tf.truncated_normal(
                [FLAGS.hidden_units, 10],
                stddev=1.0 / math.sqrt(FLAGS.hidden_units)),
            name="sm_w")
        sm_b = tf.Variable(tf.zeros([10]), name="sm_b")
        # Ops: located on the worker specified with FLAGS.task_index
        # 定义模型输入数据变量
        x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS])
        y_ = tf.placeholder(tf.float32, [None, 10])
        # 构建隐藏层
        hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
        hid = tf.nn.relu(hid_lin)
        # 构建损失函数和优化器
        y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
        cross_entropy = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
        # 异步训练模式:自己计算完成梯度就去更新参数,不同副本之间不会去协调进度
        opt = tf.train.AdamOptimizer(FLAGS.learning_rate)
        # 同步训练模式
        if FLAGS.sync_replicas:
          if FLAGS.replicas_to_aggregate is None:
            replicas_to_aggregate = num_workers
          else:
            replicas_to_aggregate = FLAGS.replicas_to_aggregate
          # 使用SyncReplicasOptimizer作优化器,并且是在图间复制情况下
          # 在图内复制情况下将所有梯度平均
          opt = tf.train.SyncReplicasOptimizer(
              opt,
              replicas_to_aggregate=replicas_to_aggregate,
              total_num_replicas=num_workers,
              name="mnist_sync_replicas")
        train_step = opt.minimize(cross_entropy, global_step=global_step)
        if FLAGS.sync_replicas:
          local_init_op = opt.local_step_init_op
          if is_chief:
            # 所有进行计算工作节点里一个主工作节点(chief)
            # 主节点负责初始化参数、模型保存、概要保存
            local_init_op = opt.chief_init_op
          ready_for_local_init_op = opt.ready_for_local_init_op
          # Initial token and chief queue runners required by the sync_replicas mode
          # 同步训练模式所需初始令牌、主队列
          chief_queue_runner = opt.get_chief_queue_runner()
          sync_init_op = opt.get_init_tokens_op()
        init_op = tf.global_variables_initializer()
        train_dir = tempfile.mkdtemp()
        if FLAGS.sync_replicas:
          # 创建一个监管程序,用于统计训练模型过程中的信息
          # lodger 是保存和加载模型路径
          # 启动就会去这个logdir目录看是否有检查点文件,有的话就自动加载
          # 没有就用init_op指定初始化参数
          # 主工作节点(chief)负责模型参数初始化工作
          # 过程中,其他工作节点等待主节眯完成初始化工作,初始化完成后,一起开始训练数据
          # global_step值是所有计算节点共享的
          # 在执行损失函数最小值时自动加1,通过global_step知道所有计算节点一共计算多少步
          sv = tf.train.Supervisor(
              is_chief=is_chief,
              logdir=train_dir,
              init_op=init_op,
              local_init_op=local_init_op,
              ready_for_local_init_op=ready_for_local_init_op,
              recovery_wait_secs=1,
              global_step=global_step)
        else:
          sv = tf.train.Supervisor(
              is_chief=is_chief,
              logdir=train_dir,
              init_op=init_op,
              recovery_wait_secs=1,
              global_step=global_step)
        # 创建会话,设置属性allow_soft_placement为True
        # 所有操作默认使用被指定设置,如GPU
        # 如果该操作函数没有GPU实现,自动使用CPU设备
        sess_config = tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=False,
            device_filters=["/job:ps", "/job:worker/task:%d" % FLAGS.task_index])
        # The chief worker (task_index==0) session will prepare the session,
        # while the remaining workers will wait for the preparation to complete.
        # 主工作节点(chief),task_index为0节点初始化会话
        # 其余工作节点等待会话被初始化后进行计算
        if is_chief:
          print("Worker %d: Initializing session..." % FLAGS.task_index)
        else:
          print("Worker %d: Waiting for session to be initialized..." %
                FLAGS.task_index)
        if FLAGS.existing_servers:
          server_grpc_url = "grpc://" + worker_spec[FLAGS.task_index]
          print("Using existing server at: %s" % server_grpc_url)
          # 创建TensorFlow会话对象,用于执行TensorFlow图计算
          # prepare_or_wait_for_session需要参数初始化完成且主节点准备好后,才开始训练
          sess = sv.prepare_or_wait_for_session(server_grpc_url,
                                                config=sess_config)
        else:
          sess = sv.prepare_or_wait_for_session(server.target, config=sess_config)
        print("Worker %d: Session initialization complete." % FLAGS.task_index)
        if FLAGS.sync_replicas and is_chief:
          # Chief worker will start the chief queue runner and call the init op.
          sess.run(sync_init_op)
          sv.start_queue_runners(sess, [chief_queue_runner])
        # Perform training
        # 执行分布式模型训练
        time_begin = time.time()
        print("Training begins @ %f" % time_begin)
        local_step = 0
        while True:
          # Training feed
          # 读入MNIST训练数据,默认每批次100张图片
          batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size)
          train_feed = {x: batch_xs, y_: batch_ys}
          _, step = sess.run([train_step, global_step], feed_dict=train_feed)
          local_step += 1
          now = time.time()
          print("%f: Worker %d: training step %d done (global step: %d)" %
                (now, FLAGS.task_index, local_step, step))
          if step >= FLAGS.train_steps:
            break
        time_end = time.time()
        print("Training ends @ %f" % time_end)
        training_time = time_end - time_begin
        print("Training elapsed time: %f s" % training_time)
        # Validation feed
        # 读入MNIST验证数据,计算验证的交叉熵
        val_feed = {x: mnist.validation.images, y_: mnist.validation.labels}
        val_xent = sess.run(cross_entropy, feed_dict=val_feed)
        print("After %d training step(s), validation cross entropy = %g" %
              (FLAGS.train_steps, val_xent))
    if __name__ == "__main__":
      tf.app.run()


参考资料:
《TensorFlow技术解析与实战》


欢迎推荐上海机器学习工作机会,我的微信:qingxingfengzi
09-10 11:33