]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix(enum): move `enum` arg of register_enum() as third item
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 21 Apr 2022 08:22:44 +0000 (10:22 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 22 Apr 2022 03:03:24 +0000 (05:03 +0200)
This is more consistent with the other register_*() functions and allow
a more natural order of arguments if others must be added (e.g.
mapping).

docs/basic/adapt.rst
psycopg/psycopg/types/enum.py
tests/types/test_enum.py

index 97c006b2797e2399684ba6c20eedfe3cea546132..b06541380041baac9b889499044c45ca80e58076 100644 (file)
@@ -449,7 +449,7 @@ Example::
     >>> conn.execute("CREATE TYPE user_role AS ENUM ('ADMIN', 'EDITOR', 'GUEST')")
 
     >>> info = EnumInfo.fetch(conn, "user_role")
-    >>> register_enum(info, UserRole, conn)
+    >>> register_enum(info, conn, UserRole)
 
     >>> some_editor = info.enum.EDITOR
     >>> some_editor
index 70cf016cf6164aa8d4f17ab2da0c3ec85c1bb4a9..f79a7d7d89864b9f0856507120c04fbfbff1056f 100644 (file)
@@ -59,16 +59,16 @@ class EnumBinaryDumper(EnumDumper):
 
 def register_enum(
     info: EnumInfo,
-    enum: Optional[Type[E]] = None,
     context: Optional[AdaptContext] = None,
+    enum: Optional[Type[E]] = None,
 ) -> None:
     """Register the adapters to load and dump a enum type.
 
     :param info: The object with the information about the enum to register.
-    :param enum: Python enum type matching to the PostgreSQL one. If `!None`,
-        a new enum will be generated and exposed as `EnumInfo.enum`.
     :param context: The context where to register the adapters. If `!None`,
         register it globally.
+    :param enum: Python enum type matching to the PostgreSQL one. If `!None`,
+        a new enum will be generated and exposed as `EnumInfo.enum`.
     """
 
     if not info:
index 182fd8ace10b9488c7048855dd962358021de9e6..3cffe372966f8ffa0c139fc69bfb7c1d4faf679a 100644 (file)
@@ -91,7 +91,7 @@ def test_enum_loader(conn, testenum, encoding, fmt_in, fmt_out):
     conn.execute(f"set client_encoding to {encoding}")
 
     name, enum, labels = testenum
-    register_enum(EnumInfo.fetch(conn, name), enum, conn)
+    register_enum(EnumInfo.fetch(conn, name), conn, enum=enum)
 
     for label in labels:
         cur = conn.execute(f"select %{fmt_in}::{name}", [label], binary=fmt_out)
@@ -103,7 +103,7 @@ def test_enum_loader(conn, testenum, encoding, fmt_in, fmt_out):
 @pytest.mark.parametrize("enum", ascii_cases)
 def test_enum_loader_sqlascii(conn, enum, fmt_in, fmt_out):
     info = EnumInfo.fetch(conn, enum.__name__.lower())
-    register_enum(info, enum, conn)
+    register_enum(info, conn, enum)
     conn.execute("set client_encoding to sql_ascii")
 
     for label in info.labels:
@@ -118,7 +118,7 @@ def test_enum_dumper(conn, testenum, encoding, fmt_in, fmt_out):
     conn.execute(f"set client_encoding to {encoding}")
 
     name, enum, labels = testenum
-    register_enum(EnumInfo.fetch(conn, name), enum, conn)
+    register_enum(EnumInfo.fetch(conn, name), conn, enum)
 
     for item in enum:
         cur = conn.execute(f"select %{fmt_in}", [item], binary=fmt_out)
@@ -130,7 +130,7 @@ def test_enum_dumper(conn, testenum, encoding, fmt_in, fmt_out):
 @pytest.mark.parametrize("enum", ascii_cases)
 def test_enum_dumper_sqlascii(conn, enum, fmt_in, fmt_out):
     info = EnumInfo.fetch(conn, enum.__name__.lower())
-    register_enum(info, enum, conn)
+    register_enum(info, conn, enum)
     conn.execute("set client_encoding to sql_ascii")
 
     for item in enum:
@@ -178,7 +178,7 @@ def test_enum_array_loader(conn, testenum, encoding, fmt_in, fmt_out):
     conn.execute(f"set client_encoding to {encoding}")
 
     name, enum, labels = testenum
-    register_enum(EnumInfo.fetch(conn, name), enum, conn)
+    register_enum(EnumInfo.fetch(conn, name), conn, enum)
 
     cur = conn.execute(f"select %{fmt_in}::{name}[]", [labels], binary=fmt_out)
     assert cur.fetchone()[0] == list(enum)
@@ -191,7 +191,7 @@ def test_enum_array_dumper(conn, testenum, encoding, fmt_in, fmt_out):
     conn.execute(f"set client_encoding to {encoding}")
 
     name, enum, labels = testenum
-    register_enum(EnumInfo.fetch(conn, name), enum, conn)
+    register_enum(EnumInfo.fetch(conn, name), conn, enum)
 
     cur = conn.execute(f"select %{fmt_in}", [list(enum)], binary=fmt_out)
     assert cur.fetchone()[0] == list(enum)
@@ -233,7 +233,7 @@ async def test_enum_async(aconn, testenum, encoding, fmt_in, fmt_out):
     await aconn.execute(f"set client_encoding to {encoding}")
 
     name, enum, labels = testenum
-    register_enum(await EnumInfo.fetch(aconn, name), enum, aconn)
+    register_enum(await EnumInfo.fetch(aconn, name), aconn, enum)
 
     for label in labels:
         cur = await aconn.execute(f"select %{fmt_in}::{name}", [label], binary=fmt_out)