作为测试套件的一部分,我必须检查函数返回的numpy数组是否正确。使用np.array_equal
可以很容易地进行此检查,它返回一个布尔值,判断所有数组元素是否相同。
如果测试失败,则错误消息对于了解导致失败的原因不是特别有用。
import unittest
import numpy as np
class TestArray(unittest.TestCase):
def test_values(self):
x = np.array([1, 2])
self.assertTrue(np.array_equal(x, [1, 3]))
if __name__ == "__main__":
unittest.main()
测试失败消息:
Traceback (most recent call last):
File "test.py", line 7, in test_values
self.assertTrue(np.array_equal(x, [1, 3]))
AssertionError: False is not true
有没有一种简单的方法来检查条目是否相等,从而显示不相等的第一个条目的索引和值?我想要类似的错误消息:
AssertionError: Arrays not equal at index 1 (2 != 3)
问题来源:stackoverflow
从np.array_equal
中,我们可以获取代码并重写它,在末尾添加另一个检查
def array_equal(a1, a2):
try:
a1, a2 = asarray(a1), asarray(a2)
except Exception:
return False
if a1.shape != a2.shape:
return False
eq = asarray(a1 == a2) # [ True False False True]
if not bool(eq.all()):
errors = [f"idx:{idx} ({vals[0]}!={vals[1]})"
for idx, vals in enumerate(zip(a1, a2))
if not eq[idx]]
raise AssertionError("Arrays not equal " + " ".join(errors))
return True
class TestArray(unittest.TestCase):
def test_values(self):
x = np.array([1, 1, 1, 1])
self.assertTrue(array_equal(x, [1, 2, 3, 1]))
if __name__ == "__main__":
unittest.main()
给AssertionError:数组不等于idx:1(1!= 2)idx:2(1!= 3)
回答来源:stackoverflow
版权声明:本文内容由阿里云实名注册用户自发贡献,版权归原作者所有,阿里云开发者社区不拥有其著作权,亦不承担相应法律责任。具体规则请查看《阿里云开发者社区用户服务协议》和《阿里云开发者社区知识产权保护指引》。如果您发现本社区中有涉嫌抄袭的内容,填写侵权投诉表单进行举报,一经查实,本社区将立刻删除涉嫌侵权内容。