[Python] python神经网络使用Keras进行模型的保存与读取

3465 3
小明Python 2022-12-24 20:05:49 | 显示全部楼层 |阅读模式
一、Keras中保存与读取的重要函数
1、model.save

model.save用于保存模型,在保存模型前,首先要利用pip install安装h5py的模块,这个模块在Keras的模型保存与读取中常常被使用,用于定义保存格式。

  1. pip install h5py
复制代码

完成安装后,可以通过如下函数保存模型。

2、load_model

load_model用于载入模型。

具体使用方式如下:

  1. model = load_model("./model.hdf5")
复制代码

代码:

  1. import numpy as np
  2. from keras.models import Sequential,load_model,save_model
  3. from keras.layers import Dense,Activation ## 全连接层
  4. from keras.datasets import mnist
  5. from keras.utils import np_utils
  6. from keras.optimizers import RMSprop
  7. # 获取训练集
  8. (X_train,Y_train),(X_test,Y_test) = mnist.load_data()
  9. # 首先进行标准化
  10. X_train = X_train.reshape(X_train.shape[0],-1)/255
  11. X_test = X_test.reshape(X_test.shape[0],-1)/255
  12. # 计算categorical_crossentropy需要对分类结果进行categorical
  13. # 即需要将标签转化为形如(nb_samples, nb_classes)的二值序列
  14. Y_train = np_utils.to_categorical(Y_train,num_classes= 10)
  15. Y_test = np_utils.to_categorical(Y_test,num_classes= 10)
  16. # 构建模型
  17. model = Sequential([
  18.     Dense(32,input_dim = 784),
  19.     Activation("relu"),
  20.     Dense(10),
  21.     Activation("softmax")
  22.     ]
  23. )
  24. rmsprop = RMSprop(lr = 0.001,rho = 0.9,epsilon = 1e-08,decay = 0)
  25. ## compile
  26. model.compile(loss = 'categorical_crossentropy',optimizer = rmsprop,metrics=['accuracy'])
  27. print("\ntraining")
  28. cost = model.fit(X_train,Y_train,nb_epoch = 2,batch_size = 100)
  29. print("\nTest")
  30. # 测试
  31. cost,accuracy = model.evaluate(X_test,Y_test)
  32. print("accuracy:",accuracy)
  33. # 保存模型
  34. model.save("./model.hdf5")
  35. # 删除现有模型
  36. del model
  37. print("model had been del")
  38. # 再次载入模型
  39. model = load_model("./model.hdf5")
  40. # 预测
  41. cost,accuracy = model.evaluate(X_test,Y_test)
  42. print("accuracy:",accuracy)
复制代码












小明Python 2022-12-25 13:48:42 | 显示全部楼层

我觉得彳亍
您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

中国红客联盟公众号

联系站长QQ:5520533

admin@chnhonker.com
Copyright © 2001-2025 Discuz Team. Powered by Discuz! X3.5 ( 粤ICP备13060014号 )|天天打卡 本站已运行