keras_API汇总积累(熟读手册)二,函数式API

2021-03-02 18:28

阅读:705

标签:_for   cti   ola   parallel   perm   embed   none   保存   ati   

输入和输出均为张量,它们都可以用来定义一个模型(Model),这样的模型同 Keras 的 Sequential 模型一样,都可以被训练。

1.建立Model

from keras.layers import Input,Dense,TimeDistributed,Embedding,LSTM,contatenate,Maxpooling2D,Flatten

from keras.models import Model

inputs=Input(shape=(784,))

x=Dense(64,activation=‘relu‘)(inputs)

x=Dense(64,activation=‘relu‘)(x)

out=Dense(10,activation=‘softmax‘)(x)

x=Embedding(output_dim=512,input_dim=1000,input_length=100)(x)

lstm_out=LSTM(32)(x)

x=keras.layers.concatenate([lstm,x],axis=-1)

x=MaxPooling2D((3,3),strides=(1,1),padding=‘same‘)(x)

z=keras.layers.add([x,y])#残差网络

x=Flatten()(x)

model = Model(inputs=[a1, a2], outputs=[b1, b3, b3])

2.编译

model=Model(inputs=inputs,outputs=out)

processed_sequences=TimeDistributed(model)(input_sequences)#将图像分类模型转换成为视频分类模型,input_sequences=Input(shape=(时间序列,向量维度))

model.compile(optimizer=‘rmsprop‘,loss=‘categorical_crossentropy‘,metrics=[‘accuracy‘])

compile(optimizer, loss=None, metrics=None, loss_weights=None, sample_weight_mode=None, weighted_metrics=None, target_tensors=None)

3.训练

model.fit(data,labels)

 fit(x=None, y=None, batch_size=None, epochs=1, verbose=1, callbacks=None, validation_split=0.0, validation_data=None, shuffle=True, class_weight=None, sample_weight=None, initial_epoch=0, steps_per_epoch=None, validation_steps=None)

4.评估

evaluate(x=None, y=None, batch_size=None, verbose=1, sample_weight=None, steps=None)

5.预测、

predict(x, batch_size=None, verbose=0, steps=None)

a.数据并行(串)

from keras.utils import mulit_gpu_model

parallel_model=multi_gpu_model(model,gpu=8)#8个gpu并运的model。

b.设备并行(并)

with tf.device_scope(‘/gpu:0‘):

  encode_a=Lstm(x1)#在一个GPU上处理第一个序列

with tf.device_scope(‘/gpu:1‘):

  encode_a=Lstm(x2)#在另一个GPU上处理另一个序列

with tf.device_scope(‘/cpu:0‘):

  merged_vector=keras.layers.concatenate([encode_a,encode_b],axis=-1)#在cpu上连接结果

c.保存并重载模型(结构权重,优化器状态)

from keras.models import load_model

model.save(‘my.h5‘)

del model

model= load_model(‘my.h5‘)

d。只保存加载模型结构

json_string=model.to_json()#保存为JSON结构

yaml_string=model.to_yaml()#保存为YAML结构

from keras.models import model_from_json,model_from_yaml

重建模型:model=model_from_json(json_string)

model=model_from_yaml(yaml_json)

e.只保存加载模型权重

model.save_weights(‘weights.h5‘)

model.load_weights(‘weights.h5‘,by_name=True)#true时,就是将名字一样的层的权重加载,不一样的层不加载

 f.批量训练预测数据

model.train_on_batch(x,y)

model.test_on_batch(x,y)

g.验证集误差不再下降的早停;

from keras.callbacks import EarlyStopping

early_stopping = EarlyStopping(monitor=‘val_loss‘, patience=2)

model.fit(x, y, validation_split=0.2, callbacks=[early_stopping])

h.冻结释冻模型参数

layer.trainable=True/False

I.图形模型:

from keras.applications.xception import Xception

from keras.applications.vgg16 import VGG16

from keras.applications.vgg19 import VGG19

from keras.applications.resnet50 import ResNet50

from keras.applications.inception_v3 import InceptionV3

from keras.applications.inception_resnet_v2 import InceptionResNetV2

from keras.applications.mobilenet import MobileNet

from keras.applications.densenet import DenseNet121

from keras.applications.densenet import DenseNet169

from keras.applications.densenet import DenseNet201

from keras.applications.nasnet import NASNetLarge

from keras.applications.nasnet import NASNetMobile

from keras.applications.mobilenet_v2 import MobileNetV2

model = VGG16(weights=‘imagenet‘, include_top=True)

j.model.summary() 打印出模型概述信息

model.get_config() 返回包含模型配置信息的字典

k.置换输入维度

keras.layers.Permute(dims)

L.将任意表达式封装成layer对象

keras.layers.Lambda(function, output_shape=None, mask=None, arguments=None)

 

m.keras.layers.UpSampling2D(size=(2, 2), data_format=None, interpolation=‘nearest‘)

keras.layers.ZeroPadding2D(padding=(1, 1), data_format=None)

keras.layers.MaxPooling2D(pool_size=(2, 2), strides=None, padding=‘valid‘, data_format=None)

keras.layers.AveragePooling2D(pool_size=(2, 2), strides=None, padding=‘valid‘, data_format=None)

全局最大池化keras.layers.GlobalMaxPooling2D(data_format=None)

全局平均池化keras.layers.GlobalAveragePooling2D(data_format=None)

keras_API汇总积累(熟读手册)二,函数式API

标签:_for   cti   ola   parallel   perm   embed   none   保存   ati   

原文地址:https://www.cnblogs.com/Turing-dz/p/13030925.html


评论


亲,登录后才可以留言!