From: Sebastián Ramírez Date: Thu, 19 Jun 2025 16:19:22 +0000 (+0200) Subject: ✅ Simplify tests for `tests/test_tutorial/test_code_structure/test_tutorial001.py... X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=8ae5f9b6c8ec7883f0e95277cc44340cf3489335;p=thirdparty%2Ffastapi%2Fsqlmodel.git ✅ Simplify tests for `tests/test_tutorial/test_code_structure/test_tutorial001.py`, one test file for multiple variants (#1408) --- diff --git a/tests/test_tutorial/test_code_structure/test_tutorial001.py b/tests/test_tutorial/test_code_structure/test_tutorial001.py index 109c1ef5..99ae5c00 100644 --- a/tests/test_tutorial/test_code_structure/test_tutorial001.py +++ b/tests/test_tutorial/test_code_structure/test_tutorial001.py @@ -1,8 +1,11 @@ -from unittest.mock import patch +import importlib +from dataclasses import dataclass +from types import ModuleType +import pytest from sqlmodel import create_engine -from ...conftest import get_testing_print_function +from tests.conftest import PrintMock, needs_py39, needs_py310 expected_calls = [ [ @@ -22,16 +25,34 @@ expected_calls = [ ] -def test_tutorial(): - from docs_src.tutorial.code_structure.tutorial001 import app, database +@dataclass +class Modules: + app: ModuleType + database: ModuleType - database.sqlite_url = "sqlite://" - database.engine = create_engine(database.sqlite_url) - app.engine = database.engine - calls = [] - new_print = get_testing_print_function(calls) - - with patch("builtins.print", new=new_print): - app.main() - assert calls == expected_calls +@pytest.fixture( + name="modules", + params=[ + "tutorial001", + pytest.param("tutorial001_py39", marks=needs_py39), + pytest.param("tutorial001_py310", marks=needs_py310), + ], +) +def get_modules(request: pytest.FixtureRequest) -> Modules: + app_module = importlib.import_module( + f"docs_src.tutorial.code_structure.{request.param}.app" + ) + database_module = importlib.import_module( + f"docs_src.tutorial.code_structure.{request.param}.database" + ) + database_module.sqlite_url = "sqlite://" + database_module.engine = create_engine(database_module.sqlite_url) + app_module.engine = database_module.engine + + return Modules(app=app_module, database=database_module) + + +def test_tutorial(print_mock: PrintMock, modules: Modules): + modules.app.main() + assert print_mock.calls == expected_calls diff --git a/tests/test_tutorial/test_code_structure/test_tutorial001_py310.py b/tests/test_tutorial/test_code_structure/test_tutorial001_py310.py deleted file mode 100644 index 126bef25..00000000 --- a/tests/test_tutorial/test_code_structure/test_tutorial001_py310.py +++ /dev/null @@ -1,38 +0,0 @@ -from unittest.mock import patch - -from sqlmodel import create_engine - -from ...conftest import get_testing_print_function, needs_py310 - -expected_calls = [ - [ - "Created hero:", - { - "id": 1, - "name": "Deadpond", - "age": None, - "secret_name": "Dive Wilson", - "team_id": 1, - }, - ], - [ - "Hero's team:", - {"name": "Z-Force", "headquarters": "Sister Margaret's Bar", "id": 1}, - ], -] - - -@needs_py310 -def test_tutorial(): - from docs_src.tutorial.code_structure.tutorial001_py310 import app, database - - database.sqlite_url = "sqlite://" - database.engine = create_engine(database.sqlite_url) - app.engine = database.engine - calls = [] - - new_print = get_testing_print_function(calls) - - with patch("builtins.print", new=new_print): - app.main() - assert calls == expected_calls diff --git a/tests/test_tutorial/test_code_structure/test_tutorial001_py39.py b/tests/test_tutorial/test_code_structure/test_tutorial001_py39.py deleted file mode 100644 index 02f692ea..00000000 --- a/tests/test_tutorial/test_code_structure/test_tutorial001_py39.py +++ /dev/null @@ -1,38 +0,0 @@ -from unittest.mock import patch - -from sqlmodel import create_engine - -from ...conftest import get_testing_print_function, needs_py39 - -expected_calls = [ - [ - "Created hero:", - { - "id": 1, - "name": "Deadpond", - "age": None, - "secret_name": "Dive Wilson", - "team_id": 1, - }, - ], - [ - "Hero's team:", - {"name": "Z-Force", "headquarters": "Sister Margaret's Bar", "id": 1}, - ], -] - - -@needs_py39 -def test_tutorial(): - from docs_src.tutorial.code_structure.tutorial001_py39 import app, database - - database.sqlite_url = "sqlite://" - database.engine = create_engine(database.sqlite_url) - app.engine = database.engine - calls = [] - - new_print = get_testing_print_function(calls) - - with patch("builtins.print", new=new_print): - app.main() - assert calls == expected_calls