文军的烹饪实验室

文军的烹饪实验室

使用函数式API构建模型,使得模型可以处理多输入多输出。

1、查看tensorflow版本

import tensorflow as tf

print('Tensorflow Version:{}'.format(tf.__version__))
print(tf.config.list_physical_devices())

MLP实现fashion_mnist数据集分类(2)-函数式API构建模型(tensorflow)-LMLPHP

2、fashion_mnist数据集分类模型

2.1 使用Sequential构建模型

from keras import Sequential
from keras.layers import Flatten,Dense,Dropout
from keras import Input

model = Sequential()
model.add(Input(shape=(28,28)))
model.add(Flatten())
model.add(Dense(units=256,kernel_initializer='normal',activation='relu'))
model.add(Dropout(rate=0.1))
model.add(Dense(units=64,kernel_initializer='normal',activation='relu'))
model.add(Dropout(rate=0.1))
model.add(Dense(units=10,kernel_initializer='normal',activation='softmax'))
model.summary()

MLP实现fashion_mnist数据集分类(2)-函数式API构建模型(tensorflow)-LMLPHP

2.2 使用函数式API构建模型

from keras.layers import Flatten,Dense,Dropout
from keras import Input,Model

input = Input(shape=(28,28))
x = Flatten()(input)
x = Dense(units=256,kernel_initializer='normal',activation='relu')(x)
x = Dropout(rate=0.1)(x)
x = Dense(units=64,kernel_initializer='normal',activation='relu')(x)
x = Dropout(rate=0.1)(x)
output = Dense(units=10,kernel_initializer='normal',activation='softmax')(x)
model = Model(inputs=input, outputs=output)
model.summary()

MLP实现fashion_mnist数据集分类(2)-函数式API构建模型(tensorflow)-LMLPHP
可以看到两个模型的结构是一样的,编译和训练也是一样的。

3、使用函数式API搭建多输入多输出模型

两个输入一个输出,对比两个图片是否一样。

from keras.layers import Flatten,Dense,Dropout
from keras import Input,Model
import keras

input1 = Input(shape=(28,28))
input2 = Input(shape=(28,28))
x1 = Flatten()(input1)
x2 = Flatten()(input2)
x = keras.layers.concatenate([x1,x2])
x = Dense(units=256,kernel_initializer='normal',activation='relu')(x)
x = Dropout(rate=0.1)(x)
x = Dense(units=64,kernel_initializer='normal',activation='relu')(x)
x = Dropout(rate=0.1)(x)
output = Dense(units=1,kernel_initializer='normal',activation='sigmoid')(x)
model = Model(inputs=[input1,input2], outputs=output) # 两个输入,一个输出
model.summary()

MLP实现fashion_mnist数据集分类(2)-函数式API构建模型(tensorflow)-LMLPHP

05-04 06:51