1 引言
keras搭建神经网络模型有三种方式,第一种是使用sequential,第二种函数API,第三种是Class。第二种在IDE直接家断点就可以调试。但是在Class封装的神经网络中,如下,添加断点后,运行是不会进入到调试的。
# 模型
class test_layer(keras.layers.Layer):
def __init__(self, **kwargs):
super(test_layer, self).__init__(**kwargs)
def build(self, input_shape):
self.w = K.variable(1.)
self._trainable_weights.append(self.w)
super(test_layer, self).build(input_shape)
def call(self, x, **kwargs):
m = x * x # 在这设置断点
n = self.w * K.sqrt(x)
return m + n
# 主函数
import tensorflow as tf
import keras
import keras.backend as K
input = keras.layers.Input((100,1))
y = test_layer()(input)
model = keras.Model(input,y)
model.predict(np.ones((100,1)))
2 实现
添加断点后,通过单独调用Class中的call类,并传入实参,就可以进入到call函数进行调试查看
# 主函数
import tensorflow as tf
import keras
import keras.backend as K
test_input = np.ones((100,1)
model = test_layer()
test = model.call(test_input)