keras_API summary notes (read the manual carefully), Part 2: the functional API
2021-03-02 18:28
Tags: _for cti ola parallel perm embed none save ati
Original article: https://www.cnblogs.com/Turing-dz/p/13030925.html

In the functional API, both the inputs and the outputs are tensors. A set of input tensors plus a set of output tensors is all that is needed to define a Model. Such a model can be trained just like a Keras Sequential model. model.summary() prints an overview of the model, and model.get_config() returns a dict containing the model's configuration.

1. Building a Model
    import keras
    from keras.layers import Input, Dense, TimeDistributed, Embedding, LSTM, MaxPooling2D, Flatten
    from keras.models import Model

    inputs = Input(shape=(784,))
    x = Dense(64, activation='relu')(inputs)   # a layer is called on a tensor and returns a tensor
    x = Dense(64, activation='relu')(x)
    out = Dense(10, activation='softmax')(x)

    # The lines below are independent layer idioms, not a continuation of the graph above;
    # y, a1, a2, b1, b2, b3 stand for previously defined tensors
    x = Embedding(output_dim=512, input_dim=1000, input_length=100)(x)   # x: integer sequences of length 100
    lstm_out = LSTM(32)(x)
    x = keras.layers.concatenate([lstm_out, x], axis=-1)   # join tensors along the last axis
    x = MaxPooling2D((3, 3), strides=(1, 1), padding='same')(x)
    z = keras.layers.add([x, y])   # element-wise sum, e.g. a residual connection
    x = Flatten()(x)
    model = Model(inputs=[a1, a2], outputs=[b1, b2, b3])   # multi-input, multi-output model

2. Compiling
    model = Model(inputs=inputs, outputs=out)
    model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
    compile(optimizer, loss=None, metrics=None, loss_weights=None, sample_weight_mode=None, weighted_metrics=None, target_tensors=None)

    # A model can be called like a layer: TimeDistributed turns an image classification
    # model into a video (sequence) classification model
    input_sequences = Input(shape=(timesteps, dim))   # (number of time steps, vector dimension); dim must match the model's input (784 here)
    processed_sequences = TimeDistributed(model)(input_sequences)

3. Training
    model.fit(data, labels)
    fit(x=None, y=None, batch_size=None, epochs=1, verbose=1, callbacks=None, validation_split=0.0, validation_data=None, shuffle=True, class_weight=None, sample_weight=None, initial_epoch=0, steps_per_epoch=None, validation_steps=None)

4. Evaluation
    evaluate(x=None, y=None, batch_size=None, verbose=1, sample_weight=None, steps=None)

5. Prediction
    predict(x, batch_size=None, verbose=0, steps=None)

a. Data parallelism (each GPU runs a replica of the model on a slice of every batch)
    from keras.utils import multi_gpu_model
    parallel_model = multi_gpu_model(model, gpus=8)   # the model replicated across 8 GPUs

b. Device parallelism (different parts of the model run on different devices)
    import tensorflow as tf
    shared_lstm = keras.layers.LSTM(64)
    # x1, x2: two sequence input tensors
    with tf.device('/gpu:0'):
        encode_a = shared_lstm(x1)   # process the first sequence on one GPU
    with tf.device('/gpu:1'):
        encode_b = shared_lstm(x2)   # process the other sequence on another GPU
    with tf.device('/cpu:0'):
        merged_vector = keras.layers.concatenate([encode_a, encode_b], axis=-1)   # concatenate the results on the CPU

c. Saving and reloading the whole model (architecture, weights, optimizer state)
    from keras.models import load_model
    model.save('my.h5')
    del model
    model = load_model('my.h5')

d. Saving and loading only the model architecture
    json_string = model.to_json()   # serialize the architecture to JSON
    yaml_string = model.to_yaml()   # serialize the architecture to YAML
    from keras.models import model_from_json, model_from_yaml
    # Rebuild the model (architecture only; the weights are freshly initialized)
    model = model_from_json(json_string)
    model = model_from_yaml(yaml_string)

e. Saving and loading only the model weights
    model.save_weights('weights.h5')
    model.load_weights('weights.h5', by_name=True)   # by_name=True: load weights only into layers with matching names; layers with other names are skipped

f. Training and evaluating on a single batch
    model.train_on_batch(x, y)
    model.test_on_batch(x, y)

g. Early stopping once the validation loss stops decreasing
    from keras.callbacks import EarlyStopping
    early_stopping = EarlyStopping(monitor='val_loss', patience=2)
    model.fit(x, y, validation_split=0.2, callbacks=[early_stopping])

h. Freezing and unfreezing layer parameters
    layer.trainable = False   # freeze; set True to unfreeze, then call compile() again for the change to take effect

i. Image models (pretrained models in keras.applications)
    from keras.applications.xception import Xception
    from keras.applications.vgg16 import VGG16
    from keras.applications.vgg19 import VGG19
    from keras.applications.resnet50 import ResNet50
    from keras.applications.inception_v3 import InceptionV3
    from keras.applications.inception_resnet_v2 import InceptionResNetV2
    from keras.applications.mobilenet import MobileNet
    from keras.applications.densenet import DenseNet121
    from keras.applications.densenet import DenseNet169
    from keras.applications.densenet import DenseNet201
    from keras.applications.nasnet import NASNetLarge
    from keras.applications.nasnet import NASNetMobile
    from keras.applications.mobilenet_v2 import MobileNetV2
    model = VGG16(weights='imagenet', include_top=True)

j.

k. Permuting the input dimensions
    keras.layers.Permute(dims)

l. Wrapping an arbitrary expression as a Layer object
    keras.layers.Lambda(function, output_shape=None, mask=None, arguments=None)

m. Upsampling, padding and pooling layers
    keras.layers.UpSampling2D(size=(2, 2), data_format=None, interpolation='nearest')
    keras.layers.ZeroPadding2D(padding=(1, 1), data_format=None)
    keras.layers.MaxPooling2D(pool_size=(2, 2), strides=None, padding='valid', data_format=None)
    keras.layers.AveragePooling2D(pool_size=(2, 2), strides=None, padding='valid', data_format=None)
    Global max pooling: keras.layers.GlobalMaxPooling2D(data_format=None)
    Global average pooling: keras.layers.GlobalAveragePooling2D(data_format=None)
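The build / compile / fit / evaluate / predict steps in items 1 to 5 can be tied together in one minimal sketch. It assumes a current `keras` package (Keras 2 under TensorFlow 2.x, or Keras 3) and uses random arrays as a stand-in for real 784-value images with 10 classes. It is an illustration, not the author's code.

```python
# Minimal functional-API workflow (items 1-5) on random stand-in data
import numpy as np
import keras
from keras import layers

# 1. Build: call layers on tensors, then wrap the input and output tensors in a Model
inputs = keras.Input(shape=(784,))
x = layers.Dense(64, activation="relu")(inputs)
x = layers.Dense(64, activation="relu")(x)
outputs = layers.Dense(10, activation="softmax")(x)
model = keras.Model(inputs=inputs, outputs=outputs)

model.summary()               # print a layer-by-layer overview
config = model.get_config()   # dict describing the architecture

# 2. Compile
model.compile(optimizer="rmsprop",
              loss="categorical_crossentropy",
              metrics=["accuracy"])

# Hypothetical data: 256 random flattened "images", 10 random classes (one-hot)
rng = np.random.default_rng(0)
data = rng.random((256, 784)).astype("float32")
labels = keras.utils.to_categorical(rng.integers(0, 10, size=256), num_classes=10)

# 3. Train, 4. evaluate, 5. predict
model.fit(data, labels, batch_size=32, epochs=2, verbose=0)
loss, acc = model.evaluate(data, labels, verbose=0)
probs = model.predict(data[:5], verbose=0)   # one softmax distribution per sample
```

Because the last layer is a softmax, each row of `probs` is a probability distribution over the 10 classes.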
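The multi-input, multi-output pattern `Model(inputs=[a1, a2], outputs=[b1, b2, b3])` from item 1 can be made concrete as follows. In this sketch the input sizes, output sizes and layer names are hypothetical choices. The two inputs are joined with `concatenate`, and `predict` then returns one array per output.

```python
import numpy as np
import keras
from keras import layers

# Two hypothetical inputs and three hypothetical outputs
a1 = keras.Input(shape=(16,), name="a1")
a2 = keras.Input(shape=(8,), name="a2")
merged = layers.concatenate([a1, a2], axis=-1)   # shape (batch, 16 + 8)
hidden = layers.Dense(32, activation="relu")(merged)
b1 = layers.Dense(1, name="b1")(hidden)
b2 = layers.Dense(4, activation="softmax", name="b2")(hidden)
b3 = layers.Dense(2, name="b3")(hidden)
model = keras.Model(inputs=[a1, a2], outputs=[b1, b2, b3])

rng = np.random.default_rng(0)
x1 = rng.random((3, 16)).astype("float32")
x2 = rng.random((3, 8)).astype("float32")
out1, out2, out3 = model.predict([x1, x2], verbose=0)   # inputs and outputs in list order
```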
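The residual connection `z = keras.layers.add([x, y])` from item 1 can be checked numerically. This sketch exposes the intermediate tensors as extra model outputs, then confirms that `add` is an element-wise sum of two tensors with the same shape. The layer sizes are arbitrary assumptions.

```python
import numpy as np
import keras
from keras import layers

inp = keras.Input(shape=(16,))
x = layers.Dense(16, activation="relu")(inp)
y = layers.Dense(16)(x)               # must have the same shape as x to be added
z = layers.add([x, y])                # residual (skip) connection: z = x + y
probe = keras.Model(inp, [x, y, z])   # expose intermediate tensors for checking

sample = np.random.default_rng(0).random((4, 16)).astype("float32")
x_val, y_val, z_val = probe.predict(sample, verbose=0)
```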
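The `TimeDistributed(model)` trick in item 2 applies one per-frame model, with shared weights, to every time step of a sequence. The sketch below uses a small hypothetical frame classifier and an arbitrary length of 20 time steps. It then checks that the output for one time step matches running the frame model on that frame directly.

```python
import numpy as np
import keras
from keras import layers

# Hypothetical per-frame classifier: 784-value frames, 10 classes
frame_in = keras.Input(shape=(784,))
h = layers.Dense(32, activation="relu")(frame_in)
frame_out = layers.Dense(10, activation="softmax")(h)
frame_model = keras.Model(frame_in, frame_out)

# Apply the same model to each of 20 time steps (20 is an arbitrary choice)
input_sequences = keras.Input(shape=(20, 784))
processed_sequences = layers.TimeDistributed(frame_model)(input_sequences)
sequence_model = keras.Model(input_sequences, processed_sequences)

clips = np.random.default_rng(0).random((2, 20, 784)).astype("float32")
per_frame_probs = sequence_model.predict(clips, verbose=0)   # shape (2, 20, 10)
frame3_probs = frame_model.predict(clips[:, 3], verbose=0)    # frame 3 on its own
```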
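Items d and e can be combined: save the architecture as JSON, rebuild a freshly initialized model from it, then copy the weights over through a weights file. This sketch assumes the `.weights.h5` filename suffix, which Keras 3 requires for `save_weights` and older Keras treats as an ordinary HDF5 file. It writes to a temporary directory and omits `by_name`, since newer Keras versions may not accept that argument.

```python
import os
import tempfile
import numpy as np
import keras
from keras import layers

inp = keras.Input(shape=(4,))
hidden = layers.Dense(8, activation="relu", name="body")(inp)
out = layers.Dense(3, name="head")(hidden)
model = keras.Model(inp, out)

# d. Architecture only: JSON string -> new model with freshly initialized weights
json_string = model.to_json()
rebuilt = keras.models.model_from_json(json_string)

# e. Weights only: save from the original, load into the rebuilt architecture
weights_path = os.path.join(tempfile.mkdtemp(), "demo.weights.h5")
model.save_weights(weights_path)
rebuilt.load_weights(weights_path)

sample = np.random.default_rng(0).random((2, 4)).astype("float32")
original_pred = model.predict(sample, verbose=0)
rebuilt_pred = rebuilt.predict(sample, verbose=0)
```

After loading, both models hold identical weights, so they produce the same predictions.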
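The difference between `train_on_batch` and `test_on_batch` in item f is whether the weights are updated. This sketch uses a hypothetical 3-class model compiled with an accuracy metric. Both calls then return `[loss, accuracy]` for exactly the batch they are given.

```python
import numpy as np
import keras
from keras import layers

inp = keras.Input(shape=(8,))
out = layers.Dense(3, activation="softmax")(inp)
model = keras.Model(inp, out)
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])

rng = np.random.default_rng(0)
x_batch = rng.random((16, 8)).astype("float32")
y_batch = keras.utils.to_categorical(rng.integers(0, 3, size=16), num_classes=3)

# One gradient update on exactly this batch; returns [loss, accuracy]
train_result = model.train_on_batch(x_batch, y_batch)
# Loss and accuracy on the same batch, without updating the weights
test_result = model.test_on_batch(x_batch, y_batch)
```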
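The early-stopping recipe in item g can be run end to end as a sketch. It uses hypothetical noisy regression data and a generous epoch budget. Training ends once `val_loss` has failed to improve for 2 consecutive epochs, or when the budget runs out. The exact stopping epoch depends on the data, so only its bounds are checked.

```python
import numpy as np
import keras
from keras import layers
from keras.callbacks import EarlyStopping

inp = keras.Input(shape=(10,))
out = layers.Dense(1)(layers.Dense(16, activation="relu")(inp))
model = keras.Model(inp, out)
model.compile(optimizer="adam", loss="mse")

# Hypothetical noisy regression data: target is the feature sum plus noise
rng = np.random.default_rng(0)
x = rng.random((200, 10)).astype("float32")
y = (x.sum(axis=1, keepdims=True) + rng.normal(0, 0.1, (200, 1))).astype("float32")

early_stopping = EarlyStopping(monitor="val_loss", patience=2)
history = model.fit(x, y, validation_split=0.2, epochs=50,
                    callbacks=[early_stopping], verbose=0)
epochs_run = len(history.history["loss"])   # at most 50; fewer if stopping triggered
```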
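The freezing switch in item h can be verified directly. In this sketch the layer named `frozen` is a hypothetical example. It is frozen before `compile()`, so its weights stay unchanged through training, while the other layer keeps its two trainable weight tensors (kernel and bias).

```python
import numpy as np
import keras
from keras import layers

frozen = layers.Dense(8, activation="relu", name="frozen")
head = layers.Dense(1, name="head")
inp = keras.Input(shape=(4,))
model = keras.Model(inp, head(frozen(inp)))

frozen.trainable = False                      # freeze this layer's kernel and bias
model.compile(optimizer="sgd", loss="mse")    # (re)compile after changing trainable

before = [w.copy() for w in frozen.get_weights()]
rng = np.random.default_rng(0)
model.fit(rng.random((64, 4)).astype("float32"),
          rng.random((64, 1)).astype("float32"), epochs=2, verbose=0)
after = frozen.get_weights()
```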
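The `Permute` and `Lambda` layers from items k and l can be combined in a small sketch. `Permute` counts dimensions from 1 and leaves the batch axis untouched, so `(2, 1)` swaps the two non-batch axes. `Lambda` wraps an arbitrary expression, here a hypothetical "multiply by 2", as a layer.

```python
import numpy as np
import keras
from keras import layers

inp = keras.Input(shape=(3, 4))
swapped = layers.Permute((2, 1))(inp)                 # (batch, 3, 4) -> (batch, 4, 3)
doubled = layers.Lambda(lambda t: t * 2.0)(swapped)   # arbitrary expression as a layer
model = keras.Model(inp, doubled)

sample = np.arange(12, dtype="float32").reshape(1, 3, 4)
result = model.predict(sample, verbose=0)
```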
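The layers in item m can be compared by their effect on shape. The sketch below calls each layer eagerly on one hypothetical 8x8 RGB image in the default channels_last layout, (batch, rows, cols, channels), and records the output shapes. With `strides=None`, the pooling stride defaults to `pool_size`.

```python
import numpy as np
from keras import layers

image = np.random.default_rng(0).random((1, 8, 8, 3)).astype("float32")

shapes = {
    "UpSampling2D": tuple(layers.UpSampling2D(size=(2, 2))(image).shape),       # rows/cols doubled
    "ZeroPadding2D": tuple(layers.ZeroPadding2D(padding=(1, 1))(image).shape),  # 1 zero row/col per side
    "MaxPooling2D": tuple(layers.MaxPooling2D(pool_size=(2, 2))(image).shape),  # 2x2 windows, stride 2
    "AveragePooling2D": tuple(layers.AveragePooling2D(pool_size=(2, 2))(image).shape),
    "GlobalMaxPooling2D": tuple(layers.GlobalMaxPooling2D()(image).shape),      # one value per channel
    "GlobalAveragePooling2D": tuple(layers.GlobalAveragePooling2D()(image).shape),
}
```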