]> git.ipfire.org Git - thirdparty/fastapi/sqlmodel.git/commitdiff
✅ Simplify tests setup, one test file for multiple source variants (#1407)
authorSebastián Ramírez <tiangolo@gmail.com>
Thu, 19 Jun 2025 14:29:32 +0000 (16:29 +0200)
committerGitHub <noreply@github.com>
Thu, 19 Jun 2025 14:29:32 +0000 (16:29 +0200)
tests/conftest.py
tests/test_tutorial/test_automatic_id_none_refresh/test_tutorial001_py310_tutorial002_py310.py [deleted file]
tests/test_tutorial/test_automatic_id_none_refresh/test_tutorial001_tutorial002.py

index 9e8a45cc2ca02808aef5a976eb5ff56aebab328f..98a4d2b7e621a002441265217cef4a67cf15d409 100644 (file)
@@ -1,8 +1,10 @@
 import shutil
 import subprocess
 import sys
+from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Union
+from typing import Any, Callable, Dict, Generator, List, Union
+from unittest.mock import patch
 
 import pytest
 from pydantic import BaseModel
@@ -26,7 +28,7 @@ def clear_sqlmodel() -> Any:
 
 
 @pytest.fixture()
-def cov_tmp_path(tmp_path: Path):
+def cov_tmp_path(tmp_path: Path) -> Generator[Path, None, None]:
     yield tmp_path
     for coverage_path in tmp_path.glob(".coverage*"):
         coverage_destiny_path = top_level_path / coverage_path.name
@@ -53,8 +55,8 @@ def coverage_run(*, module: str, cwd: Union[str, Path]) -> subprocess.CompletedP
 def get_testing_print_function(
     calls: List[List[Union[str, Dict[str, Any]]]],
 ) -> Callable[..., Any]:
-    def new_print(*args):
-        data = []
+    def new_print(*args: Any) -> None:
+        data: List[Any] = []
         for arg in args:
             if isinstance(arg, BaseModel):
                 data.append(arg.model_dump())
@@ -71,6 +73,19 @@ def get_testing_print_function(
     return new_print
 
 
+@dataclass
+class PrintMock:
+    calls: List[Any] = field(default_factory=list)
+
+
+@pytest.fixture(name="print_mock")
+def print_mock_fixture() -> Generator[PrintMock, None, None]:
+    print_mock = PrintMock()
+    new_print = get_testing_print_function(print_mock.calls)
+    with patch("builtins.print", new=new_print):
+        yield print_mock
+
+
 needs_pydanticv2 = pytest.mark.skipif(not IS_PYDANTIC_V2, reason="requires Pydantic v2")
 needs_pydanticv1 = pytest.mark.skipif(IS_PYDANTIC_V2, reason="requires Pydantic v1")
 
diff --git a/tests/test_tutorial/test_automatic_id_none_refresh/test_tutorial001_py310_tutorial002_py310.py b/tests/test_tutorial/test_automatic_id_none_refresh/test_tutorial001_py310_tutorial002_py310.py
deleted file mode 100644 (file)
index 9ffcd8a..0000000
+++ /dev/null
@@ -1,163 +0,0 @@
-from typing import Any, Dict, List, Union
-from unittest.mock import patch
-
-from sqlmodel import create_engine
-
-from tests.conftest import get_testing_print_function, needs_py310
-
-
-def check_calls(calls: List[List[Union[str, Dict[str, Any]]]]):
-    assert calls[0] == ["Before interacting with the database"]
-    assert calls[1] == [
-        "Hero 1:",
-        {
-            "id": None,
-            "name": "Deadpond",
-            "secret_name": "Dive Wilson",
-            "age": None,
-        },
-    ]
-    assert calls[2] == [
-        "Hero 2:",
-        {
-            "id": None,
-            "name": "Spider-Boy",
-            "secret_name": "Pedro Parqueador",
-            "age": None,
-        },
-    ]
-    assert calls[3] == [
-        "Hero 3:",
-        {
-            "id": None,
-            "name": "Rusty-Man",
-            "secret_name": "Tommy Sharp",
-            "age": 48,
-        },
-    ]
-    assert calls[4] == ["After adding to the session"]
-    assert calls[5] == [
-        "Hero 1:",
-        {
-            "id": None,
-            "name": "Deadpond",
-            "secret_name": "Dive Wilson",
-            "age": None,
-        },
-    ]
-    assert calls[6] == [
-        "Hero 2:",
-        {
-            "id": None,
-            "name": "Spider-Boy",
-            "secret_name": "Pedro Parqueador",
-            "age": None,
-        },
-    ]
-    assert calls[7] == [
-        "Hero 3:",
-        {
-            "id": None,
-            "name": "Rusty-Man",
-            "secret_name": "Tommy Sharp",
-            "age": 48,
-        },
-    ]
-    assert calls[8] == ["After committing the session"]
-    assert calls[9] == ["Hero 1:", {}]
-    assert calls[10] == ["Hero 2:", {}]
-    assert calls[11] == ["Hero 3:", {}]
-    assert calls[12] == ["After committing the session, show IDs"]
-    assert calls[13] == ["Hero 1 ID:", 1]
-    assert calls[14] == ["Hero 2 ID:", 2]
-    assert calls[15] == ["Hero 3 ID:", 3]
-    assert calls[16] == ["After committing the session, show names"]
-    assert calls[17] == ["Hero 1 name:", "Deadpond"]
-    assert calls[18] == ["Hero 2 name:", "Spider-Boy"]
-    assert calls[19] == ["Hero 3 name:", "Rusty-Man"]
-    assert calls[20] == ["After refreshing the heroes"]
-    assert calls[21] == [
-        "Hero 1:",
-        {
-            "id": 1,
-            "name": "Deadpond",
-            "secret_name": "Dive Wilson",
-            "age": None,
-        },
-    ]
-    assert calls[22] == [
-        "Hero 2:",
-        {
-            "id": 2,
-            "name": "Spider-Boy",
-            "secret_name": "Pedro Parqueador",
-            "age": None,
-        },
-    ]
-    assert calls[23] == [
-        "Hero 3:",
-        {
-            "id": 3,
-            "name": "Rusty-Man",
-            "secret_name": "Tommy Sharp",
-            "age": 48,
-        },
-    ]
-    assert calls[24] == ["After the session closes"]
-    assert calls[21] == [
-        "Hero 1:",
-        {
-            "id": 1,
-            "name": "Deadpond",
-            "secret_name": "Dive Wilson",
-            "age": None,
-        },
-    ]
-    assert calls[22] == [
-        "Hero 2:",
-        {
-            "id": 2,
-            "name": "Spider-Boy",
-            "secret_name": "Pedro Parqueador",
-            "age": None,
-        },
-    ]
-    assert calls[23] == [
-        "Hero 3:",
-        {
-            "id": 3,
-            "name": "Rusty-Man",
-            "secret_name": "Tommy Sharp",
-            "age": 48,
-        },
-    ]
-
-
-@needs_py310
-def test_tutorial_001(clear_sqlmodel):
-    from docs_src.tutorial.automatic_id_none_refresh import tutorial001_py310 as mod
-
-    mod.sqlite_url = "sqlite://"
-    mod.engine = create_engine(mod.sqlite_url)
-    calls = []
-
-    new_print = get_testing_print_function(calls)
-
-    with patch("builtins.print", new=new_print):
-        mod.main()
-    check_calls(calls)
-
-
-@needs_py310
-def test_tutorial_002(clear_sqlmodel):
-    from docs_src.tutorial.automatic_id_none_refresh import tutorial002_py310 as mod
-
-    mod.sqlite_url = "sqlite://"
-    mod.engine = create_engine(mod.sqlite_url)
-    calls = []
-
-    new_print = get_testing_print_function(calls)
-
-    with patch("builtins.print", new=new_print):
-        mod.main()
-    check_calls(calls)
index 5c2504710b3f1ae2234c89bffaf942938b04be8f..7233e40be8fe340a82dde5f05c65d8bb40b2e7eb 100644 (file)
@@ -1,12 +1,14 @@
+import importlib
+from types import ModuleType
 from typing import Any, Dict, List, Union
-from unittest.mock import patch
 
+import pytest
 from sqlmodel import create_engine
 
-from tests.conftest import get_testing_print_function
+from tests.conftest import PrintMock, needs_py310
 
 
-def check_calls(calls: List[List[Union[str, Dict[str, Any]]]]):
+def check_calls(calls: List[List[Union[str, Dict[str, Any]]]]) -> None:
     assert calls[0] == ["Before interacting with the database"]
     assert calls[1] == [
         "Hero 1:",
@@ -133,29 +135,25 @@ def check_calls(calls: List[List[Union[str, Dict[str, Any]]]]):
     ]
 
 
-def test_tutorial_001():
-    from docs_src.tutorial.automatic_id_none_refresh import tutorial001 as mod
+@pytest.fixture(
+    name="module",
+    params=[
+        "tutorial001",
+        "tutorial002",
+        pytest.param("tutorial001_py310", marks=needs_py310),
+        pytest.param("tutorial002_py310", marks=needs_py310),
+    ],
+)
+def get_module(request: pytest.FixtureRequest) -> ModuleType:
+    module = importlib.import_module(
+        f"docs_src.tutorial.automatic_id_none_refresh.{request.param}"
+    )
+    module.sqlite_url = "sqlite://"
+    module.engine = create_engine(module.sqlite_url)
 
-    mod.sqlite_url = "sqlite://"
-    mod.engine = create_engine(mod.sqlite_url)
-    calls = []
+    return module
 
-    new_print = get_testing_print_function(calls)
 
-    with patch("builtins.print", new=new_print):
-        mod.main()
-    check_calls(calls)
-
-
-def test_tutorial_002():
-    from docs_src.tutorial.automatic_id_none_refresh import tutorial002 as mod
-
-    mod.sqlite_url = "sqlite://"
-    mod.engine = create_engine(mod.sqlite_url)
-    calls = []
-
-    new_print = get_testing_print_function(calls)
-
-    with patch("builtins.print", new=new_print):
-        mod.main()
-    check_calls(calls)
+def test_tutorial_001_tutorial_002(print_mock: PrintMock, module: ModuleType) -> None:
+    module.main()
+    check_calls(print_mock.calls)