]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Handle memoriview and buffer data as well as bytes, better import
authorJacopo Farina <jacopo.farina@flixbus.com>
Tue, 14 Sep 2021 09:24:13 +0000 (11:24 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 21 Sep 2021 17:11:11 +0000 (18:11 +0100)
psycopg/psycopg/types/geometry.py

index 7956ebb3807f22d1a9c0faf7ec4e27986f9a4067..0d88a052e61d7b63450a6186b0a89c801e886c18 100644 (file)
@@ -5,14 +5,14 @@ Adapters for PostGIS geometries
 from typing import Optional, Type
 
 from .. import postgres
-from ..abc import AdaptContext
+from ..abc import AdaptContext, Buffer
 from ..adapt import Dumper, Loader
 from ..pq import Format
 from .._typeinfo import TypeInfo
 
 
 try:
-    import shapely.wkb as wkb
+    from shapely.wkb import loads, dumps
     from shapely.geometry.base import BaseGeometry
 
 except ImportError:
@@ -25,30 +25,34 @@ except ImportError:
 class GeometryBinaryLoader(Loader):
     format = Format.BINARY
 
-    def load(self, data: bytes) -> "BaseGeometry":
-        return wkb.loads(data)
+    def load(self, data: Buffer) -> "BaseGeometry":
+        if not isinstance(data, bytes):
+            data = bytes(data)
+        return loads(data)
 
 
 class GeometryLoader(Loader):
     format = Format.TEXT
 
-    def load(self, data: bytes) -> "BaseGeometry":
+    def load(self, data: Buffer) -> "BaseGeometry":
         # it's a hex string in binary
-        return wkb.loads(data.decode(), hex=True)
+        if isinstance(data, memoryview):
+            data = bytes(data)
+        return loads(data.decode(), hex=True)
 
 
 class GeometryBinaryDumper(Dumper):
     format = Format.BINARY
 
     def dump(self, obj: "BaseGeometry") -> bytes:
-        return wkb.dumps(obj).encode()  # type: ignore
+        return dumps(obj)  # type: ignore
 
 
 class GeometryDumper(Dumper):
     format = Format.TEXT
 
     def dump(self, obj: "BaseGeometry") -> bytes:
-        return wkb.dumps(obj, hex=True).encode()  # type: ignore
+        return dumps(obj, hex=True).encode()  # type: ignore
 
 
 def register_shapely(