|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
|
import platform |
|
|
import re |
|
|
|
|
|
import pytest |
|
|
|
|
|
|
|
|
def pytest_addoption(parser):
    """Register the ``--regression`` command line flag with pytest.

    The flag is a boolean switch that is off by default; it is read back via
    ``config.getoption("--regression")``.

    Args:
        parser: The pytest option parser that the flag is added to.
    """
    flag_settings = {
        "action": "store_true",
        "default": False,
        "help": "run regression tests",
    }
    parser.addoption("--regression", **flag_settings)
|
|
|
|
|
|
|
|
def pytest_configure(config):
    """Declare the ``regression`` marker in pytest's ``markers`` ini setting.

    Args:
        config: The pytest config object that the marker line is added to.
    """
    marker_line = "regression: mark regression tests"
    config.addinivalue_line("markers", marker_line)
|
|
|
|
|
|
|
|
# Logger of the ``transformers`` library; the handler below is attached to it.
logger = logging.getLogger("transformers")


class ErrorOnDeprecation(logging.Handler):
    """Logging handler that turns deprecation-like log records into errors."""

    def emit(self, record):
        """Raise if the record's message looks like a deprecation notice.

        The lower-cased message is searched for the substrings ``"deprecat"``
        and ``"future"``. Messages that also contain ``"torch_dtype"`` are
        exempt and pass through silently.

        Args:
            record: The ``logging.LogRecord`` to inspect.

        Raises:
            AssertionError: If the message matches and is not exempt.
        """
        text = record.getMessage().lower()
        looks_deprecated = any(marker in text for marker in ("deprecat", "future"))
        if looks_deprecated and "torch_dtype" not in text:
            raise AssertionError(f"**Transformers Deprecation**: {text}")


# Install the handler and make sure WARNING-level records reach it.
handler = ErrorOnDeprecation()
logger.setLevel(logging.WARNING)
logger.addHandler(handler)
|
|
|
|
|
|
|
|
def pytest_collection_modifyitems(config, items):
    """Skip tests marked ``regression`` unless ``--regression`` was passed.

    Args:
        config: The pytest config object; its ``--regression`` option is read.
        items: The collected test items; those with ``"regression"`` in their
            keywords receive a skip marker when the option is not set.
    """
    if not config.getoption("--regression"):
        skip_regression = pytest.mark.skip(reason="need --regression option to run regression tests")
        # The generator is lazy, so each item is checked and marked in turn.
        for regression_item in (candidate for candidate in items if "regression" in candidate.keywords):
            regression_item.add_marker(skip_regression)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_makereport(item, call):
    """
    Hook into pytest's report creation to turn one known MacOS failure, caused by transformers, into a skip.

    The failure stems from https://github.com/huggingface/transformers/pull/37785, after which torch.load fails when
    torch < 2.6 is used.

    The MacOS x86 runners have to use an older torch version, so this workaround is needed to keep the CI green.
    """
    outcome = yield
    report = outcome.get_result()

    # Start of the error message that identifies the known failure.
    known_failure = re.compile(r"Due to a serious vulnerability issue in `torch.load`")

    # Only failed reports from the setup or call stage are of interest.
    if not report.failed:
        return
    if report.when not in ("setup", "call"):
        return
    # The workaround applies to MacOS only.
    if platform.system() != "Darwin":
        return

    failure_text = str(call.excinfo.value)
    if known_failure.search(failure_text) is None:
        return

    # Rewrite the report in place: mark it as skipped and attach the reason.
    report.outcome = "skipped"
    report.wasxfail = "Error known to occur on MacOS with older torch versions, won't be fixed"
|
|
|