|
1 | 1 | """Testing utilities for CmdStanPy."""
|
2 | 2 |
|
3 | 3 | import contextlib
|
4 |
| -import os |
5 |
| -import sys |
6 |
| -import unittest |
| 4 | +import logging |
| 5 | +import platform |
| 6 | +import re |
| 7 | +from typing import List, Type |
| 8 | +from unittest import mock |
7 | 9 | from importlib import reload
|
8 |
| -from io import StringIO |
| 10 | +import pytest |
9 | 11 |
|
10 | 12 |
|
11 |
| -class CustomTestCase(unittest.TestCase): |
12 |
| - # pylint: disable=invalid-name |
13 |
| - @contextlib.contextmanager |
14 |
| - def assertRaisesRegexNested(self, exc, msg): |
15 |
| - """A version of assertRaisesRegex that checks the full traceback. |
| 13 | +mark_windows_only = pytest.mark.skipif( |
| 14 | + platform.system() != 'Windows', reason='only runs on windows' |
| 15 | +) |
| 16 | +mark_not_windows = pytest.mark.skipif( |
| 17 | + platform.system() == 'Windows', reason='does not run on windows' |
| 18 | +) |
16 | 19 |
|
17 |
| - Useful for when an exception is raised from another and you wish to |
18 |
| - inspect the inner exception. |
19 |
| - """ |
20 |
| - with self.assertRaises(exc) as ctx: |
21 |
| - yield |
22 |
| - exception = ctx.exception |
23 |
| - exn_string = str(ctx.exception) |
24 |
| - while exception.__cause__ is not None: |
25 |
| - exception = exception.__cause__ |
26 |
| - exn_string += "\n" + str(exception) |
27 |
| - self.assertRegex(exn_string, msg) |
28 | 20 |
|
29 |
| - @contextlib.contextmanager |
30 |
| - def without_import(self, library, module): |
31 |
| - with unittest.mock.patch.dict('sys.modules', {library: None}): |
32 |
| - reload(module) |
33 |
| - yield |
34 |
| - reload(module) |
| 21 | +# pylint: disable=invalid-name |
| 22 | +@contextlib.contextmanager |
| 23 | +def raises_nested(expected_exception: Type[Exception], match: str) -> None: |
| 24 | + """A version of assertRaisesRegex that checks the full traceback. |
35 | 25 |
|
36 |
| - # recipe modified from https://stackoverflow.com/a/36491341 |
37 |
| - @contextlib.contextmanager |
38 |
| - def replace_stdin(self, target: str): |
39 |
| - orig = sys.stdin |
40 |
| - sys.stdin = StringIO(target) |
| 26 | + Useful for when an exception is raised from another and you wish to |
| 27 | + inspect the inner exception. |
| 28 | + """ |
| 29 | + with pytest.raises(expected_exception) as ctx: |
41 | 30 | yield
|
42 |
| - sys.stdin = orig |
43 |
| - |
44 |
| - # recipe from https://stackoverflow.com/a/34333710 |
45 |
| - @contextlib.contextmanager |
46 |
| - def modified_environ(self, *remove, **update): |
47 |
| - """ |
48 |
| - Temporarily updates the ``os.environ`` dictionary in-place. |
49 |
| -
|
50 |
| - The ``os.environ`` dictionary is updated in-place so that |
51 |
| - the modification is sure to work in all situations. |
| 31 | + exception: Exception = ctx.value |
| 32 | + lines = [] |
| 33 | + while exception: |
| 34 | + lines.append(str(exception)) |
| 35 | + exception = exception.__cause__ |
| 36 | + text = "\n".join(lines) |
| 37 | + assert re.search(match, text), f"pattern `{match}` does not match `{text}`" |
52 | 38 |
|
53 |
| - :param remove: Environment variables to remove. |
54 |
| - :param update: Dictionary of environment variables and values to |
55 |
| - add/update. |
56 |
| - """ |
57 |
| - env = os.environ |
58 |
| - update = update or {} |
59 |
| - remove = remove or [] |
60 | 39 |
|
61 |
| - # List of environment variables being updated or removed. |
62 |
| - stomped = (set(update.keys()) | set(remove)) & set(env.keys()) |
63 |
| - # Environment variables and values to restore on exit. |
64 |
| - update_after = {k: env[k] for k in stomped} |
65 |
| - # Environment variables and values to remove on exit. |
66 |
| - remove_after = frozenset(k for k in update if k not in env) |
| 40 | +@contextlib.contextmanager |
| 41 | +def without_import(library, module): |
| 42 | + with mock.patch.dict('sys.modules', {library: None}): |
| 43 | + reload(module) |
| 44 | + yield |
| 45 | + reload(module) |
67 | 46 |
|
68 |
| - try: |
69 |
| - env.update(update) |
70 |
| - for k in remove: |
71 |
| - env.pop(k, None) |
72 |
| - yield |
73 |
| - finally: |
74 |
| - env.update(update_after) |
75 |
| - for k in remove_after: |
76 |
| - env.pop(k) |
77 | 47 |
|
78 |
| - # pylint: disable=invalid-name |
79 |
| - def assertPathsEqual(self, path1, path2): |
80 |
| - """Assert paths are equal after normalization""" |
81 |
| - self.assertTrue(os.path.samefile(path1, path2)) |
| 48 | +def check_present( |
| 49 | + caplog: pytest.LogCaptureFixture, |
| 50 | + *conditions: List[tuple], |
| 51 | + clear: bool = True, |
| 52 | +) -> None: |
| 53 | + """ |
| 54 | + Check that all desired records exist. |
| 55 | + """ |
| 56 | + for condition in conditions: |
| 57 | + logger, level, message = condition |
| 58 | + if isinstance(level, str): |
| 59 | + level = getattr(logging, level) |
| 60 | + found = any( |
| 61 | + logger == logger_ and level == level_ and message.match(message_) |
| 62 | + if isinstance(message, re.Pattern) |
| 63 | + else message == message_ |
| 64 | + for logger_, level_, message_ in caplog.record_tuples |
| 65 | + ) |
| 66 | + if not found: |
| 67 | + raise ValueError(f"logs did not contain the record {condition}") |
| 68 | + if clear: |
| 69 | + caplog.clear() |
0 commit comments