解决TypeError: tf__update_state() got an unexpected keyword argument ‘sample_weight‘

简介: 解决TypeError: tf__update_state() got an unexpected keyword argument ‘sample_weight‘

2022-01-02 18:41:16.826148: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)
Traceback (most recent call last):
  File "E:/Code/PyCharm/TensorFlow学习/Keras/自定义评估指标.py", line 59, in <module>
    validation_split=0.2
  File "D:\Anaconda\lib\site-packages\keras\engine\training.py", line 1184, in fit
    tmp_logs = self.train_function(iterator)
  File "D:\Anaconda\lib\site-packages\tensorflow\python\eager\def_function.py", line 885, in __call__
    result = self._call(*args, **kwds)
  File "D:\Anaconda\lib\site-packages\tensorflow\python\eager\def_function.py", line 933, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "D:\Anaconda\lib\site-packages\tensorflow\python\eager\def_function.py", line 760, in _initialize
    *args, **kwds))
  File "D:\Anaconda\lib\site-packages\tensorflow\python\eager\function.py", line 3066, in _get_concrete_function_internal_garbage_collected
    graph_function, _ = self._maybe_define_function(args, kwargs)
  File "D:\Anaconda\lib\site-packages\tensorflow\python\eager\function.py", line 3463, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "D:\Anaconda\lib\site-packages\tensorflow\python\eager\function.py", line 3308, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "D:\Anaconda\lib\site-packages\tensorflow\python\framework\func_graph.py", line 1007, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "D:\Anaconda\lib\site-packages\tensorflow\python\eager\def_function.py", line 668, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "D:\Anaconda\lib\site-packages\tensorflow\python\framework\func_graph.py", line 994, in wrapper
    raise e.ag_error_metadata.to_exception(e)
TypeError: in user code:
    D:\Anaconda\lib\site-packages\keras\engine\training.py:853 train_function  *
        return step_function(self, iterator)
    TypeError: tf__update_state() got an unexpected keyword argument 'sample_weight'

问题原因:

使用TensorFlow实现一些自定义层或者指标等,在一些需要实现的函数参数没有加上默认参数,导致模型训练时传递参数出现问题

解决办法:

在指定地方添加上默认参数

def update_state(self, y_true, y_pred, sample_weight=None):
        y_pred = tf.reshape(tf.argmax(y_pred, axis=1), shape=(-1, 1))
        values = tf.cast(y_true, 'int32') == tf.cast(y_pred, 'int32')
        values = tf.cast(values, 'float32')
        self.true_positives.assign_add(tf.reduce_sum(values))


目录
相关文章
|
IDE PyTorch 网络安全
|
监控 异构计算
Jetson 学习笔记(八):htop查看CPU占用情况和jtop监控CPU和GPU
在NVIDIA Jetson平台上使用htop和jtop工具来监控CPU、GPU和内存的使用情况,并提供了安装和使用这些工具的具体命令。
1541 0
|
算法 程序员 Python
用伪代码表示算法
在算法设计和编程中,伪代码是一种非常重要的工具。它允许我们以一种既非特定编程语言又足够详细的方式来描述算法。伪代码的目标是提供一个清晰、简洁的算法表示,而不必拘泥于特定的编程语法或规则。本文将探讨伪代码的优势,并提供一个用伪代码表示算法的例子。
1229 1
|
机器学习/深度学习 算法 开发工具
【YOLOv8量化】普通CPU上加速推理可达100+FPS
【YOLOv8量化】普通CPU上加速推理可达100+FPS
2432 0
|
运维 前端开发 算法
开源中国【专访】 | CodeFuse:让研发变得更简单
CodeFuse 是蚂蚁集团自研的代码生成大模型,旨在简化研发流程,提供智能建议和实时支持。它能自动生成代码、添加注释、生成测试用例并优化代码。通过创新的 Rodimus 架构,CodeFuse 实现了“小体量,大能量”,显著提升了资源利用效率。其特色功能“图生代码”可将设计图一键转换为代码,准确率超过90%,大幅提高前端开发效率。此外,CodeFuse 还引入了“Code Graph”概念,帮助 LLM 更好地理解仓库级代码结构,缩短任务处理时间。未来,CodeFuse 将致力于全生命周期的研发支持,涵盖需求分析、代码生成到运维监测,推动行业技术迭代与创新。
869 3
|
Python
PyCharm View as Array 查看数组
PyCharm View as Array 查看数组
693 1
|
人工智能 自然语言处理 数据处理
【AI大模型】Transformers大模型库(十三):Datasets库
【AI大模型】Transformers大模型库(十三):Datasets库
963 0
|
机器学习/深度学习 人工智能 并行计算
英伟达系列显卡大解析B100、H200、L40S、A100、A800、H100、H800、V100如何选择,含架构技术和性能对比带你解决疑惑
英伟达系列显卡大解析B100、H200、L40S、A100、A800、H100、H800、V100如何选择,含架构技术和性能对比带你解决疑惑
英伟达系列显卡大解析B100、H200、L40S、A100、A800、H100、H800、V100如何选择,含架构技术和性能对比带你解决疑惑
|
IDE Java Linux
在Maven中设置JVM系统参数及Java应用调试实例
在Maven中设置JVM系统参数及Java应用调试实例
811 0