```
2022-01-02 18:41:16.826148: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)
Traceback (most recent call last):
  File "E:/Code/PyCharm/TensorFlow学习/Keras/自定义评估指标.py", line 59, in <module>
    validation_split=0.2
  File "D:\Anaconda\lib\site-packages\keras\engine\training.py", line 1184, in fit
    tmp_logs = self.train_function(iterator)
  File "D:\Anaconda\lib\site-packages\tensorflow\python\eager\def_function.py", line 885, in __call__
    result = self._call(*args, **kwds)
  File "D:\Anaconda\lib\site-packages\tensorflow\python\eager\def_function.py", line 933, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "D:\Anaconda\lib\site-packages\tensorflow\python\eager\def_function.py", line 760, in _initialize
    *args, **kwds))
  File "D:\Anaconda\lib\site-packages\tensorflow\python\eager\function.py", line 3066, in _get_concrete_function_internal_garbage_collected
    graph_function, _ = self._maybe_define_function(args, kwargs)
  File "D:\Anaconda\lib\site-packages\tensorflow\python\eager\function.py", line 3463, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "D:\Anaconda\lib\site-packages\tensorflow\python\eager\function.py", line 3308, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "D:\Anaconda\lib\site-packages\tensorflow\python\framework\func_graph.py", line 1007, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "D:\Anaconda\lib\site-packages\tensorflow\python\eager\def_function.py", line 668, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "D:\Anaconda\lib\site-packages\tensorflow\python\framework\func_graph.py", line 994, in wrapper
    raise e.ag_error_metadata.to_exception(e)
TypeError: in user code:

    D:\Anaconda\lib\site-packages\keras\engine\training.py:853 train_function  *
        return step_function(self, iterator)

    TypeError: tf__update_state() got an unexpected keyword argument 'sample_weight'
```
Cause:
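This happens when you implement custom layers, metrics, and similar components in TensorFlow and define some of the required methods without the optional (default-valued) parameters that the framework passes to them. The call then fails while the model is training. In this case, Keras's training loop calls the metric's `update_state()` with `sample_weight` as a keyword argument, even when no sample weights are in use. The custom `update_state(self, y_true, y_pred)` does not declare that parameter, so the call fails. The name `tf__update_state` in the error message refers to the AutoGraph-converted version of that method.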
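You can reproduce the failure without TensorFlow. The plain-Python sketch below is only an illustration. `BrokenMetric`, `FixedMetric`, and `framework_update` are hypothetical names. `framework_update` imitates a training loop that always passes `sample_weight` by keyword, even when its value is `None`; it is not Keras source code.

```python
class BrokenMetric:
    # Mirrors the buggy override: the signature has no sample_weight parameter.
    def update_state(self, y_true, y_pred):
        return "updated"


class FixedMetric:
    # The fix: declare sample_weight with a default value of None.
    def update_state(self, y_true, y_pred, sample_weight=None):
        return "updated"


def framework_update(metric, y_true, y_pred, mask=None):
    # Hypothetical stand-in for the training loop: sample_weight is always
    # passed by keyword, even when mask is None.
    return metric.update_state(y_true, y_pred, sample_weight=mask)


try:
    framework_update(BrokenMetric(), [1], [1])
except TypeError as exc:
    # The message says update_state() got an unexpected keyword argument 'sample_weight'
    print(exc)

print(framework_update(FixedMetric(), [1], [1]))
```

Only the signature differs between the two classes. Because `sample_weight` defaults to `None`, `FixedMetric.update_state` also works when a caller leaves it out.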
Solution:
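Add the missing optional parameter, with a default value, to the method signature. For the custom metric, that means declaring `sample_weight=None` in `update_state`: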
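```python
def update_state(self, y_true, y_pred, sample_weight=None):  # sample_weight now has a default
    # Convert class probabilities to predicted labels with shape (batch, 1)
    y_pred = tf.reshape(tf.argmax(y_pred, axis=1), shape=(-1, 1))
    # Element-wise comparison of true and predicted labels, as 0.0 / 1.0
    values = tf.equal(tf.cast(y_true, 'int32'), tf.cast(y_pred, 'int32'))
    values = tf.cast(values, 'float32')
    # self.true_positives is a state variable created in __init__ (not shown); sample_weight is ignored here
    self.true_positives.assign_add(tf.reduce_sum(values))
```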
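The snippet above shows only `update_state`. For reference, here is a minimal, self-contained sketch of a complete `tf.keras.metrics.Metric` subclass with the corrected signature. It assumes TensorFlow 2.x. The class name `ArgmaxHitRate` and the state names `correct` and `seen` are hypothetical. Unlike the original snippet, this sketch applies `sample_weight` when one is given, and it divides by the (weighted) number of samples seen, so `result()` returns an accuracy rather than a raw count.

```python
import tensorflow as tf


class ArgmaxHitRate(tf.keras.metrics.Metric):
    """Hypothetical example metric: share of samples whose argmax prediction equals the integer label."""

    def __init__(self, name="argmax_hit_rate", **kwargs):
        super().__init__(name=name, **kwargs)
        # Scalar state variables accumulated across batches (reset to 0 by reset_state()).
        self.correct = self.add_weight(name="correct", initializer="zeros")
        self.seen = self.add_weight(name="seen", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        # sample_weight must be declared: Keras passes it by keyword during fit().
        y_true = tf.reshape(tf.cast(y_true, tf.int64), [-1])
        y_pred = tf.argmax(y_pred, axis=-1)  # int64 predicted labels
        hits = tf.cast(tf.equal(y_true, y_pred), tf.float32)
        if sample_weight is not None:
            # Weight each sample's hit and count it by its weight.
            sample_weight = tf.reshape(tf.cast(sample_weight, tf.float32), [-1])
            hits = hits * sample_weight
            self.seen.assign_add(tf.reduce_sum(sample_weight))
        else:
            self.seen.assign_add(tf.cast(tf.size(hits), tf.float32))
        self.correct.assign_add(tf.reduce_sum(hits))

    def result(self):
        # Returns 0.0 instead of NaN before any samples have been seen.
        return tf.math.divide_no_nan(self.correct, self.seen)


metric = ArgmaxHitRate()
metric.update_state([[0], [1], [2]], [[0.9, 0.1, 0.0], [0.2, 0.7, 0.1], [0.6, 0.3, 0.1]])
print(float(metric.result()))
```

You can then pass the metric to `model.compile(..., metrics=[ArgmaxHitRate()])` like a built-in one. Keras already ships `tf.keras.metrics.SparseCategoricalAccuracy` for this exact computation, so the class is mainly useful as a template for your own metrics.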