]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
test(shapely): better use of fixtures to drop test duplications
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 26 Mar 2025 14:46:00 +0000 (15:46 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 26 Mar 2025 18:28:18 +0000 (19:28 +0100)
tests/types/test_shapely.py

index b4027b0cbbbea647d3f7c0c35079b31ba52d1d6b..a681ffedaeb7c24188d96d546bd0d9704aa7b935 100644 (file)
@@ -14,13 +14,23 @@ from psycopg.types.shapely import register_shapely, shapely_version
 if shapely_version >= (2, 0):
     from shapely import get_srid, set_srid
 else:
-    set_srid = get_srid = None  # type: ignore[assignment]
+
+    def set_srid(obj, srid):  # type: ignore[no-redef]
+        return obj
+
+    def get_srid(obj):  # type: ignore[no-redef]
+        raise NotImplementedError
+
 
 pytestmark = [
     pytest.mark.postgis,
     pytest.mark.crdb("skip"),
 ]
 
+SAMPLE_POINT = Point(1.2, 3.4)
+SAMPLE_POLYGON = Polygon([(0, 0), (1, 1), (1, 0)])
+SAMPLE_POLYGON_4326 = set_srid(SAMPLE_POLYGON, 4326)
+
 # real example, with CRS and "holes"
 MULTIPOLYGON_GEOJSON = """
 {
@@ -91,74 +101,42 @@ def test_no_info_error(conn):
         register_shapely(None, conn)  # type: ignore[arg-type]
 
 
-def test_with_adapter(shapely_conn):
-    SAMPLE_POINT = Point(1.2, 3.4)
-    SAMPLE_POLYGON = Polygon([(0, 0), (1, 1), (1, 0)])
-
-    assert (
-        shapely_conn.execute("SELECT pg_typeof(%s)", [SAMPLE_POINT]).fetchone()[0]
-        == "geometry"
-    )
-
-    assert (
-        shapely_conn.execute("SELECT pg_typeof(%s)", [SAMPLE_POLYGON]).fetchone()[0]
-        == "geometry"
-    )
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("obj", ["SAMPLE_POINT", "SAMPLE_POLYGON"])
+def test_with_adapter(shapely_conn, obj, fmt_in):
+    obj = globals()[obj]
+    with shapely_conn.cursor() as cur:
+        cur.execute(f"SELECT pg_typeof(%{fmt_in})", [obj])
+        assert cur.fetchone()[0] == "geometry"
 
 
 @pytest.mark.parametrize("fmt_in", PyFormat)
 @pytest.mark.parametrize("fmt_out", Format)
-def test_write_read_shape(shapely_conn, fmt_in, fmt_out):
-    SAMPLE_POINT = Point(1.2, 3.4)
-    SAMPLE_POLYGON_4326 = Polygon([(0, 0), (1, 1), (1, 0)])
-    if set_srid is not None:
-        SAMPLE_POLYGON_4326 = set_srid(SAMPLE_POLYGON_4326, 4326)
-
+@pytest.mark.parametrize(
+    "obj, srid",
+    [("SAMPLE_POINT", 0), ("SAMPLE_POLYGON", 0), ("SAMPLE_POLYGON_4326", 4326)],
+)
+def test_write_read_shape(shapely_conn, fmt_in, fmt_out, obj, srid):
+    obj = globals()[obj]
     with shapely_conn.cursor(binary=fmt_out) as cur:
-        cur.execute(
-            """
-        create table sample_geoms(
-            id     INTEGER PRIMARY KEY,
-            geom   geometry
-        )
-        """
-        )
-        cur.execute(
-            f"insert into sample_geoms(id, geom) VALUES(1, %{fmt_in})",
-            (SAMPLE_POINT,),
-        )
-        cur.execute(
-            f"insert into sample_geoms(id, geom) VALUES(2, %{fmt_in})",
-            (SAMPLE_POLYGON_4326,),
-        )
-
-        cur.execute("select geom from sample_geoms where id=1")
+        cur.execute("drop table if exists sample_geoms")
+        cur.execute("create table sample_geoms(id SERIAL PRIMARY KEY, geom geometry)")
+        cur.execute(f"insert into sample_geoms(geom) VALUES(%{fmt_in})", (obj,))
+        cur.execute("select geom from sample_geoms")
         result = cur.fetchone()[0]
-        assert result == SAMPLE_POINT
-        if get_srid is not None:
-            assert get_srid(result) == 0
-
-        cur.execute("select geom from sample_geoms where id=2")
-        result = cur.fetchone()[0]
-        assert result == SAMPLE_POLYGON_4326
-        if get_srid is not None:
-            assert get_srid(result) == 4326
+        assert result == obj
+        if shapely_version >= (2, 0):
+            assert get_srid(result) == srid
 
 
 @pytest.mark.parametrize("fmt_out", Format)
 def test_match_geojson(shapely_conn, fmt_out):
-    SAMPLE_POINT = Point(1.2, 3.4)
     with shapely_conn.cursor(binary=fmt_out) as cur:
-        cur.execute(
-            """
-            select ST_GeomFromGeoJSON(%s)
-            """,
-            (SAMPLE_POINT_GEOJSON,),
-        )
+        cur.execute("select ST_GeomFromGeoJSON(%s)", (SAMPLE_POINT_GEOJSON,))
         result = cur.fetchone()[0]
         # clone the coordinates to have a list instead of a shapely wrapper
         assert result.coords[:] == SAMPLE_POINT.coords[:]
-        #
+
         cur.execute("select ST_GeomFromGeoJSON(%s)", (MULTIPOLYGON_GEOJSON,))
         result = cur.fetchone()[0]
         assert isinstance(result, MultiPolygon)