目录
基于keras框架利用深度卷积对抗网络DCGAN算法对MNIST数据集实现图像生成
相关文章
DL之DCGAN(Keras框架):基于keras框架利用深度卷积对抗网络DCGAN算法对MNIST数据集实现图像生成(保存h5模型→加载模型)
DL之DCGAN(Keras框架):基于keras框架利用深度卷积对抗网络DCGAN算法对MNIST数据集实现图像生成(保存h5模型→加载模型)实现
基于keras框架利用深度卷积对抗网络DCGAN算法对MNIST数据集实现图像生成
设计思路
输出结果
1. X像素取值范围是[-1.0, 1.0] 2. _________________________________________________________________ 3. Layer (type) Output Shape Param # 4. ================================================================= 5. dense_1 (Dense) (None, 1024) 103424 6. _________________________________________________________________ 7. activation_1 (Activation) (None, 1024) 0 8. _________________________________________________________________ 9. dense_2 (Dense) (None, 6272) 6428800 10. _________________________________________________________________ 11. batch_normalization_1 (Batch (None, 6272) 25088 12. _________________________________________________________________ 13. activation_2 (Activation) (None, 6272) 0 14. _________________________________________________________________ 15. reshape_1 (Reshape) (None, 7, 7, 128) 0 16. _________________________________________________________________ 17. up_sampling2d_1 (UpSampling2 (None, 14, 14, 128) 0 18. _________________________________________________________________ 19. conv2d_1 (Conv2D) (None, 14, 14, 64) 204864 20. _________________________________________________________________ 21. activation_3 (Activation) (None, 14, 14, 64) 0 22. _________________________________________________________________ 23. up_sampling2d_2 (UpSampling2 (None, 28, 28, 64) 0 24. _________________________________________________________________ 25. conv2d_2 (Conv2D) (None, 28, 28, 1) 1601 26. _________________________________________________________________ 27. activation_4 (Activation) (None, 28, 28, 1) 0 28. ================================================================= 29. Total params: 6,763,777 30. Trainable params: 6,751,233 31. Non-trainable params: 12,544 32. _________________________________________________________________ 33. _________________________________________________________________ 34. Layer (type) Output Shape Param # 35. ================================================================= 36. conv2d_3 (Conv2D) (None, 28, 28, 64) 1664 37. _________________________________________________________________ 38. activation_5 (Activation) (None, 28, 28, 64) 0 39. _________________________________________________________________ 40. max_pooling2d_1 (MaxPooling2 (None, 14, 14, 64) 0 41. _________________________________________________________________ 42. conv2d_4 (Conv2D) (None, 10, 10, 128) 204928 43. _________________________________________________________________ 44. activation_6 (Activation) (None, 10, 10, 128) 0 45. _________________________________________________________________ 46. max_pooling2d_2 (MaxPooling2 (None, 5, 5, 128) 0 47. _________________________________________________________________ 48. flatten_1 (Flatten) (None, 3200) 0 49. _________________________________________________________________ 50. dense_3 (Dense) (None, 1024) 3277824 51. _________________________________________________________________ 52. activation_7 (Activation) (None, 1024) 0 53. _________________________________________________________________ 54. dense_4 (Dense) (None, 1) 1025 55. _________________________________________________________________ 56. activation_8 (Activation) (None, 1) 0 57. ================================================================= 58. Total params: 3,485,441 59. Trainable params: 3,485,441 60. Non-trainable params: 0 61. _________________________________________________________________ 62. 2020-11-24 21:53:56.659897: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 63. (25, 28, 28, 1)
核心代码
1. def generator_model(): 2. model = Sequential() 3. model.add(Dense(input_dim=100, units=1024)) # 1034 1024 4. model.add(Activation('tanh')) 5. model.add(Dense(128*7*7)) 6. model.add(BatchNormalization()) 7. model.add(Activation('tanh')) 8. model.add(Reshape((7, 7, 128), input_shape=(128*7*7,))) 9. model.add(UpSampling2D(size=(2, 2))) 10. model.add(Conv2D(64, (5, 5), padding='same')) 11. model.add(Activation('tanh')) 12. model.add(UpSampling2D(size=(2, 2))) 13. model.add(Conv2D(1, (5, 5), padding='same')) 14. model.add(Activation('tanh')) 15. return model 16. def discriminator_model(): # 定义鉴别网络:输入一张图像,输出0(伪造)/1(真实) 17. model = Sequential() 18. model.add( 19. Conv2D(64, (5, 5), 20. padding='same', 21. input_shape=(28, 28, 1)) 22. ) 23. model.add(Activation('tanh')) 24. model.add(MaxPooling2D(pool_size=(2, 2))) 25. model.add(Conv2D(128, (5, 5))) 26. model.add(Activation('tanh')) 27. model.add(MaxPooling2D(pool_size=(2, 2))) 28. model.add(Flatten()) 29. model.add(Dense(1024)) 30. model.add(Activation('tanh')) 31. model.add(Dense(1)) 32. model.add(Activation('sigmoid')) 33. return model 34. 35. 36. g = generator_model() 37. g.summary() 38. 39. d = discriminator_model() 40. d.summary()