From: Sebastián Ramírez Date: Thu, 19 Jun 2025 14:29:32 +0000 (+0200) Subject: ✅ Simplify tests setup, one test file for multiple source variants (#1407) X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=f1c9d15525b4e9da0644e95991ce145c34f8b86e;p=thirdparty%2Ffastapi%2Fsqlmodel.git ✅ Simplify tests setup, one test file for multiple source variants (#1407) --- diff --git a/tests/conftest.py b/tests/conftest.py index 9e8a45cc..98a4d2b7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,10 @@ import shutil import subprocess import sys +from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Callable, Dict, List, Union +from typing import Any, Callable, Dict, Generator, List, Union +from unittest.mock import patch import pytest from pydantic import BaseModel @@ -26,7 +28,7 @@ def clear_sqlmodel() -> Any: @pytest.fixture() -def cov_tmp_path(tmp_path: Path): +def cov_tmp_path(tmp_path: Path) -> Generator[Path, None, None]: yield tmp_path for coverage_path in tmp_path.glob(".coverage*"): coverage_destiny_path = top_level_path / coverage_path.name @@ -53,8 +55,8 @@ def coverage_run(*, module: str, cwd: Union[str, Path]) -> subprocess.CompletedP def get_testing_print_function( calls: List[List[Union[str, Dict[str, Any]]]], ) -> Callable[..., Any]: - def new_print(*args): - data = [] + def new_print(*args: Any) -> None: + data: List[Any] = [] for arg in args: if isinstance(arg, BaseModel): data.append(arg.model_dump()) @@ -71,6 +73,19 @@ def get_testing_print_function( return new_print +@dataclass +class PrintMock: + calls: List[Any] = field(default_factory=list) + + +@pytest.fixture(name="print_mock") +def print_mock_fixture() -> Generator[PrintMock, None, None]: + print_mock = PrintMock() + new_print = get_testing_print_function(print_mock.calls) + with patch("builtins.print", new=new_print): + yield print_mock + + needs_pydanticv2 = pytest.mark.skipif(not IS_PYDANTIC_V2, reason="requires Pydantic v2") needs_pydanticv1 = pytest.mark.skipif(IS_PYDANTIC_V2, reason="requires Pydantic v1") diff --git a/tests/test_tutorial/test_automatic_id_none_refresh/test_tutorial001_py310_tutorial002_py310.py b/tests/test_tutorial/test_automatic_id_none_refresh/test_tutorial001_py310_tutorial002_py310.py deleted file mode 100644 index 9ffcd8ae..00000000 --- a/tests/test_tutorial/test_automatic_id_none_refresh/test_tutorial001_py310_tutorial002_py310.py +++ /dev/null @@ -1,163 +0,0 @@ -from typing import Any, Dict, List, Union -from unittest.mock import patch - -from sqlmodel import create_engine - -from tests.conftest import get_testing_print_function, needs_py310 - - -def check_calls(calls: List[List[Union[str, Dict[str, Any]]]]): - assert calls[0] == ["Before interacting with the database"] - assert calls[1] == [ - "Hero 1:", - { - "id": None, - "name": "Deadpond", - "secret_name": "Dive Wilson", - "age": None, - }, - ] - assert calls[2] == [ - "Hero 2:", - { - "id": None, - "name": "Spider-Boy", - "secret_name": "Pedro Parqueador", - "age": None, - }, - ] - assert calls[3] == [ - "Hero 3:", - { - "id": None, - "name": "Rusty-Man", - "secret_name": "Tommy Sharp", - "age": 48, - }, - ] - assert calls[4] == ["After adding to the session"] - assert calls[5] == [ - "Hero 1:", - { - "id": None, - "name": "Deadpond", - "secret_name": "Dive Wilson", - "age": None, - }, - ] - assert calls[6] == [ - "Hero 2:", - { - "id": None, - "name": "Spider-Boy", - "secret_name": "Pedro Parqueador", - "age": None, - }, - ] - assert calls[7] == [ - "Hero 3:", - { - "id": None, - "name": "Rusty-Man", - "secret_name": "Tommy Sharp", - "age": 48, - }, - ] - assert calls[8] == ["After committing the session"] - assert calls[9] == ["Hero 1:", {}] - assert calls[10] == ["Hero 2:", {}] - assert calls[11] == ["Hero 3:", {}] - assert calls[12] == ["After committing the session, show IDs"] - assert calls[13] == ["Hero 1 ID:", 1] - assert calls[14] == ["Hero 2 ID:", 2] - assert calls[15] == ["Hero 3 ID:", 3] - assert calls[16] == ["After committing the session, show names"] - assert calls[17] == ["Hero 1 name:", "Deadpond"] - assert calls[18] == ["Hero 2 name:", "Spider-Boy"] - assert calls[19] == ["Hero 3 name:", "Rusty-Man"] - assert calls[20] == ["After refreshing the heroes"] - assert calls[21] == [ - "Hero 1:", - { - "id": 1, - "name": "Deadpond", - "secret_name": "Dive Wilson", - "age": None, - }, - ] - assert calls[22] == [ - "Hero 2:", - { - "id": 2, - "name": "Spider-Boy", - "secret_name": "Pedro Parqueador", - "age": None, - }, - ] - assert calls[23] == [ - "Hero 3:", - { - "id": 3, - "name": "Rusty-Man", - "secret_name": "Tommy Sharp", - "age": 48, - }, - ] - assert calls[24] == ["After the session closes"] - assert calls[21] == [ - "Hero 1:", - { - "id": 1, - "name": "Deadpond", - "secret_name": "Dive Wilson", - "age": None, - }, - ] - assert calls[22] == [ - "Hero 2:", - { - "id": 2, - "name": "Spider-Boy", - "secret_name": "Pedro Parqueador", - "age": None, - }, - ] - assert calls[23] == [ - "Hero 3:", - { - "id": 3, - "name": "Rusty-Man", - "secret_name": "Tommy Sharp", - "age": 48, - }, - ] - - -@needs_py310 -def test_tutorial_001(clear_sqlmodel): - from docs_src.tutorial.automatic_id_none_refresh import tutorial001_py310 as mod - - mod.sqlite_url = "sqlite://" - mod.engine = create_engine(mod.sqlite_url) - calls = [] - - new_print = get_testing_print_function(calls) - - with patch("builtins.print", new=new_print): - mod.main() - check_calls(calls) - - -@needs_py310 -def test_tutorial_002(clear_sqlmodel): - from docs_src.tutorial.automatic_id_none_refresh import tutorial002_py310 as mod - - mod.sqlite_url = "sqlite://" - mod.engine = create_engine(mod.sqlite_url) - calls = [] - - new_print = get_testing_print_function(calls) - - with patch("builtins.print", new=new_print): - mod.main() - check_calls(calls) diff --git a/tests/test_tutorial/test_automatic_id_none_refresh/test_tutorial001_tutorial002.py b/tests/test_tutorial/test_automatic_id_none_refresh/test_tutorial001_tutorial002.py index 5c250471..7233e40b 100644 --- a/tests/test_tutorial/test_automatic_id_none_refresh/test_tutorial001_tutorial002.py +++ b/tests/test_tutorial/test_automatic_id_none_refresh/test_tutorial001_tutorial002.py @@ -1,12 +1,14 @@ +import importlib +from types import ModuleType from typing import Any, Dict, List, Union -from unittest.mock import patch +import pytest from sqlmodel import create_engine -from tests.conftest import get_testing_print_function +from tests.conftest import PrintMock, needs_py310 -def check_calls(calls: List[List[Union[str, Dict[str, Any]]]]): +def check_calls(calls: List[List[Union[str, Dict[str, Any]]]]) -> None: assert calls[0] == ["Before interacting with the database"] assert calls[1] == [ "Hero 1:", @@ -133,29 +135,25 @@ def check_calls(calls: List[List[Union[str, Dict[str, Any]]]]): ] -def test_tutorial_001(): - from docs_src.tutorial.automatic_id_none_refresh import tutorial001 as mod +@pytest.fixture( + name="module", + params=[ + "tutorial001", + "tutorial002", + pytest.param("tutorial001_py310", marks=needs_py310), + pytest.param("tutorial002_py310", marks=needs_py310), + ], +) +def get_module(request: pytest.FixtureRequest) -> ModuleType: + module = importlib.import_module( + f"docs_src.tutorial.automatic_id_none_refresh.{request.param}" + ) + module.sqlite_url = "sqlite://" + module.engine = create_engine(module.sqlite_url) - mod.sqlite_url = "sqlite://" - mod.engine = create_engine(mod.sqlite_url) - calls = [] + return module - new_print = get_testing_print_function(calls) - with patch("builtins.print", new=new_print): - mod.main() - check_calls(calls) - - -def test_tutorial_002(): - from docs_src.tutorial.automatic_id_none_refresh import tutorial002 as mod - - mod.sqlite_url = "sqlite://" - mod.engine = create_engine(mod.sqlite_url) - calls = [] - - new_print = get_testing_print_function(calls) - - with patch("builtins.print", new=new_print): - mod.main() - check_calls(calls) +def test_tutorial_001_tutorial_002(print_mock: PrintMock, module: ModuleType) -> None: + module.main() + check_calls(print_mock.calls)