Transformers 4.37 中文文档(十一)(2)https://developer.aliyun.com/article/1564971
编写测试
🤗 transformers 测试基于unittest
,但由pytest
运行,因此大多数情况下可以使用两个系统的功能。
您可以在这里阅读支持的功能,但要记住的重要事情是大多数pytest
固定装置不起作用。也不是参数化,但我们使用模块parameterized
以类似的方式工作。
参数化
经常需要多次运行相同的测试,但使用不同的参数。可以从测试内部完成,但是那样就无法仅为一个参数集运行该测试。
# test_this1.py import unittest from parameterized import parameterized class TestMathUnitTest(unittest.TestCase): @parameterized.expand( [ ("negative", -1.5, -2.0), ("integer", 1, 1.0), ("large fraction", 1.6, 1), ] ) def test_floor(self, name, input, expected): assert_equal(math.floor(input), expected)
现在,默认情况下,此测试将运行 3 次,每次将test_floor
的最后 3 个参数分配给参数列表中的相应参数。
您可以仅运行negative
和integer
参数集:
pytest -k "negative and integer" tests/test_mytest.py
或除了negative
子测试外的所有子测试,使用:
pytest -k "not negative" tests/test_mytest.py
除了使用刚提到的-k
过滤器,您可以找出每个子测试的确切名称,并使用其确切名称运行任何或所有子测试。
pytest test_this1.py --collect-only -q
它将列出:
test_this1.py::TestMathUnitTest::test_floor_0_negative test_this1.py::TestMathUnitTest::test_floor_1_integer test_this1.py::TestMathUnitTest::test_floor_2_large_fraction
现在您可以仅运行 2 个特定的子测试:
pytest test_this1.py::TestMathUnitTest::test_floor_0_negative test_this1.py::TestMathUnitTest::test_floor_1_integer
模块parameterized已经是transformers
的开发依赖项,适用于unittests
和pytest
测试。
但是,如果测试不是unittest
,则可以使用pytest.mark.parametrize
(或者您可能会看到它在一些现有测试中使用,主要在examples
下)。
以下是相同的示例,这次使用pytest
的parametrize
标记:
# test_this2.py import pytest @pytest.mark.parametrize( "name, input, expected", [ ("negative", -1.5, -2.0), ("integer", 1, 1.0), ("large fraction", 1.6, 1), ], ) def test_floor(name, input, expected): assert_equal(math.floor(input), expected)
与parameterized
一样,使用pytest.mark.parametrize
可以对要运行的子测试进行精细控制,如果-k
过滤器无法完成任务。除此之外,此参数化函数为子测试创建了一组略有不同的名称。以下是它们的样子:
pytest test_this2.py --collect-only -q
它将列出:
test_this2.py::test_floor[integer-1-1.0] test_this2.py::test_floor[negative--1.5--2.0] test_this2.py::test_floor[large fraction-1.6-1]
现在您可以仅运行特定的测试:
pytest test_this2.py::test_floor[negative--1.5--2.0] test_this2.py::test_floor[integer-1-1.0]
与前面的示例相同。
文件和目录
在测试中,我们经常需要知道事物相对于当前测试文件的位置,这并不是微不足道的,因为测试可能会从多个目录调用,或者可能位于具有不同深度的子目录中。一个辅助类transformers.test_utils.TestCasePlus
通过整理所有基本路径并提供易于访问的访问器来解决这个问题:
pathlib
对象(全部完全解析):
test_file_path
- 当前测试文件路径,即__file__
test_file_dir
- 包含当前测试文件的目录tests_dir
-tests
测试套件的目录examples_dir
-examples
测试套件的目录repo_root_dir
- 仓库的目录src_dir
-src
的目录(即transformers
子目录所在的地方)
- 字符串化路径—与上述相同,但这些返回路径作为字符串,而不是
pathlib
对象:
test_file_path_str
test_file_dir_str
tests_dir_str
examples_dir_str
repo_root_dir_str
src_dir_str
要开始使用这些,您只需要确保测试位于transformers.test_utils.TestCasePlus
的子类中。例如:
from transformers.testing_utils import TestCasePlus class PathExampleTest(TestCasePlus): def test_something_involving_local_locations(self): data_dir = self.tests_dir / "fixtures/tests_samples/wmt_en_ro"
如果您不需要通过pathlib
操纵路径,或者只需要路径作为字符串,您可以始终在pathlib
对象上调用str()
或使用以_str
结尾的访问器。例如:
from transformers.testing_utils import TestCasePlus class PathExampleTest(TestCasePlus): def test_something_involving_stringified_locations(self): examples_dir = self.examples_dir_str
临时文件和目录
使用唯一的临时文件和目录对于并行测试运行至关重要,以便测试不会覆盖彼此的数据。此外,我们希望在每个创建它们的测试结束时删除临时文件和目录。因此,使用像tempfile
这样满足这些需求的软件包是至关重要的。
然而,在调试测试时,您需要能够看到临时文件或目录中的内容,并且希望知道其确切路径,而不是在每次测试重新运行时随机化。
辅助类transformers.test_utils.TestCasePlus
最适合用于这些目的。它是unittest.TestCase
的子类,因此我们可以在测试模块中轻松继承它。
以下是其用法示例:
from transformers.testing_utils import TestCasePlus class ExamplesTests(TestCasePlus): def test_whatever(self): tmp_dir = self.get_auto_remove_tmp_dir()
此代码创建一个唯一的临时目录,并将tmp_dir
设置为其位置。
- 创建一个唯一的临时目录:
def test_whatever(self): tmp_dir = self.get_auto_remove_tmp_dir()
tmp_dir
将包含创建的临时目录的路径。它将在测试结束时自动删除。
- 创建我选择的临时目录,在测试开始之前确保它为空,并在测试结束后不清空它。
def test_whatever(self): tmp_dir = self.get_auto_remove_tmp_dir("./xxx")
当您想要监视特定目录并确保之前的测试没有在其中留下任何数据时,这很有用。
- 您可以通过直接覆盖
before
和after
参数来覆盖默认行为,从而导致以下行为之一:
before=True
:临时目录将始终在测试开始时清除。before=False
:如果临时目录已经存在,则任何现有文件将保留在那里。after=True
:临时目录将始终在测试结束时被删除。after=False
:临时目录将始终在测试结束时保持不变。
为了安全运行等效于rm -r
的操作,只允许项目存储库检出的子目录,如果使用了显式的tmp_dir
,则不会错误地删除任何/tmp
或类似的文件系统重要部分。即请始终传递以./
开头的路径。
每个测试可以注册多个临时目录,它们都将自动删除,除非另有要求。
临时 sys.path 覆盖
如果需要临时覆盖sys.path
以从另一个测试中导入,例如,可以使用ExtendSysPath
上下文管理器。示例:
import os from transformers.testing_utils import ExtendSysPath bindir = os.path.abspath(os.path.dirname(__file__)) with ExtendSysPath(f"{bindir}/.."): from test_trainer import TrainerIntegrationCommon # noqa
跳过测试
当发现错误并编写新测试但尚未修复错误时,这很有用。为了能够将其提交到主存储库,我们需要确保在make test
期间跳过它。
方法:
- skip表示您期望测试仅在满足某些条件时才通过,否则 pytest 应该跳过运行测试。常见示例是在非 Windows 平台上跳过仅适用于 Windows 的测试,或者跳过依赖于当前不可用的外部资源的测试(例如数据库)。
- xfail表示您期望由于某种原因测试失败。一个常见示例是测试尚未实现的功能,或者尚未修复的错误。当一个测试尽管预期失败(标记为 pytest.mark.xfail)仍然通过时,它是一个 xpass,并将在测试摘要中报告。
两者之间的一个重要区别是,skip
不运行测试,而xfail
会运行。因此,如果导致一些糟糕状态的有缺陷代码会影响其他测试,请不要使用xfail
。
实施
- 以下是如何无条件跳过整个测试:
@unittest.skip("this bug needs to be fixed") def test_feature_x():
或通过 pytest:
@pytest.mark.skip(reason="this bug needs to be fixed")
或xfail
方式:
@pytest.mark.xfail def test_feature_x():
以下是如何根据测试内部检查跳过测试的方法:
def test_feature_x(): if not has_something(): pytest.skip("unsupported configuration")
或整个模块:
import pytest if not pytest.config.getoption("--custom-flag"): pytest.skip("--custom-flag is missing, skipping tests", allow_module_level=True)
或xfail
方式:
def test_feature_x(): pytest.xfail("expected to fail until bug XYZ is fixed")
- 以下是如何跳过模块中的所有测试,如果某个导入丢失:
docutils = pytest.importorskip("docutils", minversion="0.3")
- 根据条件跳过测试:
@pytest.mark.skipif(sys.version_info < (3,6), reason="requires python3.6 or higher") def test_feature_x():
或:
@unittest.skipIf(torch_device == "cpu", "Can't do half precision") def test_feature_x():
或跳过整个模块:
@pytest.mark.skipif(sys.platform == 'win32', reason="does not run on windows") class TestClass(): def test_feature_x(self):
更多详细信息、示例和方法请参见这里。
慢速测试
测试库不断增长,一些测试需要几分钟才能运行,因此我们无法等待一个小时才能完成 CI 的测试套件。因此,除了一些必要测试的例外,慢速测试应该标记为下面的示例:
from transformers.testing_utils import slow @slow def test_integration_foo():
一旦将测试标记为@slow
,要运行这些测试,请设置RUN_SLOW=1
环境变量,例如:
RUN_SLOW=1 pytest tests
一些装饰器如@parameterized
会重写测试名称,因此@slow
和其他跳过装饰器@require_*
必须在最后列出才能正常工作。以下是正确使用的示例:
@parameteriz ed.expand(...) @slow def test_integration_foo():
正如本文档开头所解释的,慢速测试会定期运行,而不是在 PR 的 CI 检查中运行。因此,可能会在提交 PR 时错过一些问题并合并。这些问题将在下一个定期 CI 作业中被捕获。但这也意味着在提交 PR 之前在您的计算机上运行慢速测试非常重要。
以下是选择哪些测试应标记为慢速的大致决策机制:
如果测试侧重于库的内部组件(例如,建模文件、标记文件、流水线),那么我们应该在非慢速测试套件中运行该测试。如果侧重于库的其他方面,例如文档或示例,则应在慢速测试套件中运行这些测试。然后,为了完善这种方法,我们应该有例外情况:
- 所有需要下载一组大量权重或大于~50MB 的数据集(例如,模型或分词器集成测试,管道集成测试)的测试都应该设置为慢速。如果要添加新模型,应该创建并上传到 hub 一个其微型版本(具有随机权重)用于集成测试。这将在以下段落中讨论。
- 所有需要进行训练但没有专门优化为快速的测试都应该设置为慢速。
- 如果其中一些应该是非慢速测试的测试非常慢,可以引入异常,并将它们设置为
@slow
。自动建模测试,将大文件保存和加载到磁盘上,是一个被标记为@slow
的测试的很好的例子。 - 如果一个测试在 CI 上完成时间不到 1 秒(包括下载),那么它应该是一个正常的测试。
总的来说,所有非慢速测试都需要完全覆盖不同的内部功能,同时保持快速。例如,通过使用具有随机权重的特别创建的微型模型进行测试,可以实现显著的覆盖范围。这些模型具有最小数量的层(例如 2),词汇量(例如 1000)等。然后@slow
测试可以使用大型慢速模型进行定性测试。要查看这些的用法,只需搜索带有tiny的模型:
grep tiny tests examples
这里是一个脚本的示例,创建了微型模型stas/tiny-wmt19-en-de。您可以轻松调整它以适应您特定模型的架构。
如果例如下载一个巨大模型的开销很大,那么很容易错误地测量运行时间,但如果您在本地测试它,下载的文件将被缓存,因此下载时间不会被测量。因此,应该查看 CI 日志中的执行速度报告(pytest --durations=0 tests
的输出)。
该报告还有助于找到未标记为慢的慢异常值,或者需要重新编写以提高速度的测试。如果您注意到 CI 上的测试套件开始变慢,此报告的顶部列表将显示最慢的测试。
测试 stdout/stderr 输出
为了测试写入stdout
和/或stderr
的函数,测试可以使用pytest
的capsys 系统来访问这些流。以下是实现这一目标的方法:
import sys def print_to_stdout(s): print(s) def print_to_stderr(s): sys.stderr.write(s) def test_result_and_stdout(capsys): msg = "Hello" print_to_stdout(msg) print_to_stderr(msg) out, err = capsys.readouterr() # consume the captured output streams # optional: if you want to replay the consumed streams: sys.stdout.write(out) sys.stderr.write(err) # test: assert msg in out assert msg in err
当然,大多数情况下,stderr
将作为异常的一部分出现,因此在这种情况下必须使用 try/except:
def raise_exception(msg): raise ValueError(msg) def test_something_exception(): msg = "Not a good value" error = "" try: raise_exception(msg) except Exception as e: error = str(e) assert msg in error, f"{msg} is in the exception:\n{error}"
另一种捕获 stdout 的方法是通过contextlib.redirect_stdout
:
from io import StringIO from contextlib import redirect_stdout def print_to_stdout(s): print(s) def test_result_and_stdout(): msg = "Hello" buffer = StringIO() with redirect_stdout(buffer): print_to_stdout(msg) out = buffer.getvalue() # optional: if you want to replay the consumed streams: sys.stdout.write(out) # test: assert msg in out
捕获 stdout 的一个重要潜在问题是它可能包含\r
字符,这些字符在正常的print
中会重置到目前为止已经打印的所有内容。对于pytest
没有问题,但对于pytest -s
,这些字符会包含在缓冲区中,因此为了能够在有和没有-s
的情况下运行测试,必须对捕获的输出进行额外的清理,使用re.sub(r'~.*\r', '', buf, 0, re.M)
。
但是,我们有一个辅助上下文管理器包装器,可以自动处理所有这些,无论它是否包含一些\r
,所以这很简单:
from transformers.testing_utils import CaptureStdout with CaptureStdout() as cs: function_that_writes_to_stdout() print(cs.out)
这里是一个完整的测试示例:
from transformers.testing_utils import CaptureStdout msg = "Secret message\r" final = "Hello World" with CaptureStdout() as cs: print(msg + final) assert cs.out == final + "\n", f"captured: {cs.out}, expecting {final}"
如果您想捕获stderr
,请改用CaptureStderr
类:
from transformers.testing_utils import CaptureStderr with CaptureStderr() as cs: function_that_writes_to_stderr() print(cs.err)
如果需要同时捕获两个流,请使用父类CaptureStd
:
from transformers.testing_utils import CaptureStd with CaptureStd() as cs: function_that_writes_to_stdout_and_stderr() print(cs.err, cs.out)
此外,为了帮助调试测试问题,默认情况下这些上下文管理器会在退出上下文时自动重放捕获的流。
Transformers 4.37 中文文档(十一)(4)https://developer.aliyun.com/article/1564974