This article covers how to handle a ResNet model in TensorFlow Federated and may be a useful reference for readers facing a similar problem.

Problem description

I tried to customize the model in the "Image classification" tutorial of TensorFlow Federated, which originally uses a Sequential model. I switched to Keras ResNet50, but as soon as training begins I always get an "Incompatible shapes" error.

Here is my code:

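import tensorflow as tf
import tensorflow_federated as tff

NUM_CLIENTS = 4
NUM_EPOCHS = 10
BATCH_SIZE = 2
SHUFFLE_BUFFER = 5

def create_compiled_keras_model():
  # include_top=False with pooling=None: the network ends at its last conv
  # block and outputs a 4-D feature map, not one score per class.
  model = tf.keras.applications.resnet.ResNet50(
      include_top=False, weights='imagenet',
      input_tensor=tf.keras.layers.Input(shape=(100, 300, 3)), pooling=None)

  model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      optimizer=tf.keras.optimizers.SGD(learning_rate=0.02),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
  return model


def model_fn():
  # A new Keras model is built on every call; TFF invokes model_fn inside its
  # own tf.Graph context. sample_batch is one batch of preprocessed client
  # data, defined elsewhere as in the tutorial.
  keras_model = create_compiled_keras_model()
  return tff.learning.from_compiled_keras_model(keras_model, sample_batch)

iterative_process = tff.learning.build_federated_averaging_process(model_fn)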

Error message: (image not available)

I suspect the shapes are incompatible because the epoch and client information is somehow missing. I would be very grateful for any hint.
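A more likely explanation than missing epoch or client information is the model head. With `include_top=False` and `pooling=None`, ResNet50 stops at its last convolutional block and returns a 4-D feature map. `SparseCategoricalCrossentropy` and `SparseCategoricalAccuracy`, however, expect one row of class scores per example. The following sketch computes that output shape for the question's 100×300×3 input. It is plain Python arithmetic rather than a call into Keras, and it assumes the layer layout of `tf.keras.applications.resnet.ResNet50` (the padding, kernel sizes and strides noted in the comments). The helper names are hypothetical.

```python
"""Sketch: output shape of Keras ResNet50 with include_top=False.

Assumes the layout of tf.keras.applications.resnet.ResNet50:
ZeroPadding(3) + 7x7 stride-2 conv, ZeroPadding(1) + 3x3 stride-2 max pool,
then stages conv3_x, conv4_x, conv5_x, each starting with a stride-2 1x1
'valid' conv. conv2_x keeps the resolution. The last stage has 2048 channels.
"""
from typing import Optional, Tuple


def valid_out(size: int, kernel: int, stride: int, pad: int = 0) -> int:
    """Output length of a 'valid' conv/pool after symmetric zero padding."""
    return (size + 2 * pad - kernel) // stride + 1


def resnet50_features_shape(height: int, width: int, batch: int = 2,
                            pooling: Optional[str] = None) -> Tuple[int, ...]:
    """Shape of ResNet50(include_top=False) output for one input size."""
    dims = []
    for size in (height, width):
        size = valid_out(size, kernel=7, stride=2, pad=3)  # conv1
        size = valid_out(size, kernel=3, stride=2, pad=1)  # pool1
        for _ in range(3):  # first block of conv3_x, conv4_x, conv5_x
            size = valid_out(size, kernel=1, stride=2)
        dims.append(size)
    if pooling in ("avg", "max"):
        return (batch, 2048)  # global pooling removes height and width
    return (batch, dims[0], dims[1], 2048)


if __name__ == "__main__":
    print(resnet50_features_shape(100, 300))                 # (2, 4, 10, 2048)
    print(resnet50_features_shape(100, 300, pooling="avg"))  # (2, 2048)
```

With the question's `BATCH_SIZE` of 2, the output is `(2, 4, 10, 2048)`, which cannot line up with integer class labels. Passing `pooling='avg'` reduces it to `(2, 2048)`. Adding a `tf.keras.layers.Dense(num_classes, activation='softmax')` layer on top would then give the `(batch, num_classes)` output the loss expects. Here `num_classes` is whatever your dataset uses.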

Update:

tff.learning.build_federated_averaging_process

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-164-dac26193d9d8> in <module>()
----> 1 iterative_process = tff.learning.build_federated_averaging_process(model_fn)
      2
      3 # iterative_process = build_federated_averaging_process(model_fn)

13 frames
/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/learning/federated_averaging.py in build_federated_averaging_process(model_fn, server_optimizer_fn, client_weight_fn, stateful_delta_aggregate_fn, stateful_model_broadcast_fn)
    165   return optimizer_utils.build_model_delta_optimizer_process(
    166       model_fn, client_fed_avg, server_optimizer_fn,
--> 167       stateful_delta_aggregate_fn, stateful_model_broadcast_fn)

/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/learning/framework/optimizer_utils.py in build_model_delta_optimizer_process(model_fn, model_to_client_delta_fn, server_optimizer_fn, stateful_delta_aggregate_fn, stateful_model_broadcast_fn)
    349   # still need this.
    350   with tf.Graph().as_default():
--> 351     dummy_model_for_metadata = model_utils.enhance(model_fn())
    352
    353   # ===========================================================================

<ipython-input-159-b2763ace8e5b> in model_fn()
      1 def model_fn():
      2   keras_model = model
----> 3   return tff.learning.from_compiled_keras_model(keras_model, sample_batch)

/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/learning/keras_utils.py in from_compiled_keras_model(keras_model, dummy_batch)
    211   # Model.test_on_batch() once before asking for metrics.
    212   if isinstance(dummy_tensors, collections.Mapping):
--> 213     keras_model.test_on_batch(**dummy_tensors)
    214   else:
    215     keras_model.test_on_batch(*dummy_tensors)

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training.py in test_on_batch(self, x, y, sample_weight, reset_metrics)
   1007         sample_weight=sample_weight,
   1008         reset_metrics=reset_metrics,
-> 1009         standalone=True)
   1010     outputs = (
   1011         outputs['total_loss'] + outputs['output_losses'] + outputs['metrics'])

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_v2_utils.py in test_on_batch(model, x, y, sample_weight, reset_metrics, standalone)
    503       y,
    504       sample_weights=sample_weights,
--> 505       output_loss_metrics=model._output_loss_metrics)
    506
    507   if reset_metrics:

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/def_function.py in __call__(self, *args, **kwds)
    568         xla_context.Exit()
    569     else:
--> 570       result = self._call(*args, **kwds)
    571
    572     if tracing_count == self._get_tracing_count():

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/def_function.py in _call(self, *args, **kwds)
    606       # In this case we have not created variables on the first call. So we can
    607       # run the first trace but we should fail if variables are created.
--> 608       results = self._stateful_fn(*args, **kwds)
    609       if self._created_variables:
    610         raise ValueError("Creating variables on a non-first call to a function"

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py in __call__(self, *args, **kwargs)
   2407     """Calls a graph function specialized to the inputs."""
   2408     with self._lock:
-> 2409       graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
   2410     return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
   2411

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   2765
   2766       self._function_cache.missed.add(call_context_key)
-> 2767       graph_function = self._create_graph_function(args, kwargs)
   2768       self._function_cache.primary[cache_key] = graph_function
   2769       return graph_function, args, kwargs

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   2655             arg_names=arg_names,
   2656             override_flat_arg_shapes=override_flat_arg_shapes,
-> 2657             capture_by_value=self._capture_by_value),
   2658         self._function_attributes,
   2659         # Tell the ConcreteFunction to clean up its graph once it goes out of

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    979         _, original_func = tf_decorator.unwrap(python_func)
    980
--> 981       func_outputs = python_func(*func_args, **func_kwargs)
    982
    983       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    437         # __wrapped__ allows AutoGraph to swap in a converted function. We give
    438         # the function a weak reference to itself to avoid a reference cycle.
--> 439         return weak_wrapped_fn().__wrapped__(*args, **kwds)
    440     weak_wrapped_fn = weakref.ref(wrapped_fn)
    441

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/func_graph.py in wrapper(*args, **kwargs)
    966           except Exception as e:  # pylint:disable=broad-except
    967             if hasattr(e, "ag_error_metadata"):
--> 968               raise e.ag_error_metadata.to_exception(e)
    969             else:
    970               raise

AssertionError: in user code:

    /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_eager.py:345 test_on_batch  *
        with backend.eager_learning_phase_scope(0):
    /usr/lib/python3.6/contextlib.py:81 __enter__
        return next(self.gen)
    /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/backend.py:425 eager_learning_phase_scope
        assert ops.executing_eagerly_outside_functions()

    AssertionError:
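This traceback fails while `tff.learning.build_federated_averaging_process` is still constructing the process. Note that the `model_fn` shown in it reuses a global `model` instead of calling `create_compiled_keras_model()`. Because TFF calls `model_fn()` inside `with tf.Graph().as_default():` (line 350 of `optimizer_utils.py` above), a Keras model created earlier in eager mode is a plausible cause of this assertion.

Once construction succeeds, each call to `iterative_process.next(state, federated_train_data)` runs one round of federated averaging. The following is a toy sketch of such a round, not TFF's implementation. It uses a single scalar weight, a hypothetical least-squares task, and hypothetical helper names. In TFF, clients send model deltas and the server applies their example-weighted average through a server optimizer. The toy instead averages the client weights directly, which is equivalent only for a plain SGD server step with learning rate 1.0.

```python
"""Toy federated averaging round (illustration only, not TFF's code).

Each hypothetical client fits y ~ w * x with plain SGD. The server then
averages the returned weights, weighting each client by its example count.
"""
from typing import List, Tuple

Client = List[Tuple[float, float]]  # hypothetical (x, y) pairs on one client


def local_update(w: float, data: Client, lr: float, epochs: int) -> float:
    """Runs `epochs` passes of SGD on one client's data; returns the new w."""
    for _ in range(epochs):
        for x, y in data:
            grad = 2.0 * (w * x - y) * x  # derivative of (w*x - y)**2 in w
            w -= lr * grad
    return w


def federated_round(w: float, clients: List[Client], lr: float = 0.01,
                    epochs: int = 1) -> float:
    """Broadcasts w, trains on every client, returns the weighted average."""
    total = sum(len(data) for data in clients)
    new_w = 0.0
    for data in clients:
        new_w += local_update(w, data, lr, epochs) * len(data) / total
    return new_w


if __name__ == "__main__":
    # Three clients of different sizes, all drawn from y = 3x.
    clients = [[(1.0, 3.0), (2.0, 6.0)],
               [(0.5, 1.5)],
               [(1.5, 4.5), (1.0, 3.0), (2.0, 6.0)]]
    w = 0.0
    for round_num in range(1, 51):
        w = federated_round(w, clients, lr=0.05)
    print("after 50 rounds, w =", w)  # approaches 3.0
```

Weighting by example count means that the three-example client pulls the average harder than the one-example client, mirroring TFF's default client weighting.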

Recommended answer

I have the same problem. When I execute these lines:

state, metrics = iterative_process.next(state, federated_train_data)
print('round 1, metrics={}'.format(metrics))

I get this error:

InvalidArgumentError: 2 root error(s) found.
  (0) Invalid argument: Default MaxPoolingOp only supports NHWC on device type CPU
     [[{{node StatefulPartitionedCall/StatefulPartitionedCall/sequential/vgg16/block1_pool/MaxPool}}]]
     [[subcomputation/StatefulPartitionedCall_1/ReduceDataset]]
     [[subcomputation/StatefulPartitionedCall_1/ReduceDataset/_140]]
  (1) Invalid argument: Default MaxPoolingOp only supports NHWC on device type CPU
     [[{{node StatefulPartitionedCall/StatefulPartitionedCall/sequential/vgg16/block1_pool/MaxPool}}]]
     [[subcomputation/StatefulPartitionedCall_1/ReduceDataset]]
0 successful operations.
0 derived errors ignored.

Note that I am using VGG16. Do you have any idea what causes this type of error?
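The message says that a MaxPool op was asked to run on CPU with a data layout other than NHWC (batch, height, width, channels), the only layout the default CPU kernel supports. The thread does not confirm the cause. One possibility, stated here only as an assumption, is that the images or the model use the channels_first (NCHW) layout. Keras defaults to channels_last, which `tf.keras.backend.image_data_format()` reports. If the input batches really are NCHW, they can be reordered to NHWC before training. In TensorFlow that is `tf.transpose(x, perm=[0, 2, 3, 1])`. The following pure-Python sketch performs the same reordering on nested lists to show which axis moves where. `nchw_to_nhwc` is a hypothetical helper name.

```python
"""Sketch: reorder an image batch from NCHW (channels_first) to NHWC.

Pure-Python equivalent of tf.transpose(x, perm=[0, 2, 3, 1]).
The input is a nested list indexed as batch[n][c][h][w].
"""
from typing import List

Batch = List[List[List[List[float]]]]


def nchw_to_nhwc(batch: Batch) -> Batch:
    """Returns out with out[n][h][w][c] == batch[n][c][h][w]."""
    return [
        [
            [[image[c][h][w] for c in range(len(image))]  # channels innermost
             for w in range(len(image[0][0]))]
            for h in range(len(image[0]))
        ]
        for image in batch
    ]


if __name__ == "__main__":
    # One 2x2 image with 3 channels, stored channels_first.
    x = [[[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]]]
    print(nchw_to_nhwc(x)[0][0][0])  # [1, 5, 9]: all channels of pixel (0, 0)
```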

That concludes this article on using a ResNet model in TensorFlow Federated. We hope the recommended answer is helpful.

06-27 01:38