]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added random data generator and test for memory leaks
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 12 Jan 2021 22:18:01 +0000 (23:18 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 12 Jan 2021 22:18:01 +0000 (23:18 +0100)
Only implemented for binary types, as there are fewer of them than text ones...

tests/conftest.py
tests/fix_faker.py [new file with mode: 0644]
tests/test_cursor.py

index 7b1670859654033f26f30102475f182292fc9365..91b87b6924049d2b7328b486602575439f991eb6 100644 (file)
@@ -1,6 +1,7 @@
 pytest_plugins = (
     "tests.fix_db",
     "tests.fix_pq",
+    "tests.fix_faker",
 )
 
 
diff --git a/tests/fix_faker.py b/tests/fix_faker.py
new file mode 100644 (file)
index 0000000..c5b6d2f
--- /dev/null
@@ -0,0 +1,352 @@
+import importlib
+from math import isnan
+from uuid import UUID
+from random import choice, random, randrange
+from collections import deque
+
+import pytest
+
+import psycopg3
+from psycopg3 import sql
+from psycopg3.oids import builtins
+from psycopg3.adapt import Format
+
+
+@pytest.fixture
+def faker(conn):
+    """Return a `Faker` generating records for the test connection."""
+    generator = Faker(conn)
+    return generator
+
+
+class Faker:
+    """
+    An object to generate random records.
+
+    The records follow a schema, which is a list of specs: a spec can be a
+    type, a one-item list ``[type]`` (an array of that type), or a tuple of
+    specs. From the schema the object derives the statements to create,
+    fill and read back a test table, makes random records to insert, and
+    checks that what is fetched matches what was inserted.
+
+    Values for a class are made by the `make_<name>()` methods and compared
+    by the `match_<name>()` ones (see `_get_method()` for the names used).
+    """
+
+    # Limits for the random values generated
+    json_max_level = 3
+    json_max_length = 10
+    str_max_length = 100
+    list_max_length = 20
+
+    def __init__(self, connection):
+        self.conn = connection
+        self.format = Format.TEXT
+        self.records = []
+
+        self._schema = None
+        self._types_names = None
+        self._makers = {}
+        self.table_name = sql.Identifier("fake_table")
+
+    @property
+    def schema(self):
+        """The list of specs of the fields; chosen at random on first use."""
+        if not self._schema:
+            self._schema = self.choose_schema()
+        return self._schema
+
+    @schema.setter
+    def schema(self, schema):
+        self._schema = schema
+        # the types names depend on the schema: have them calculated again
+        self._types_names = None
+
+    @property
+    def fields_names(self):
+        """The identifiers of the table fields, one per schema item."""
+        return [sql.Identifier(f"fld_{i}") for i in range(len(self.schema))]
+
+    @property
+    def types_names(self):
+        """
+        The SQL snippets naming the database type of each field.
+
+        They are found by dumping a sample record without nulls and looking
+        up the oid of the dumper used for each value. The result is cached
+        until a new schema is set.
+        """
+        if self._types_names:
+            return self._types_names
+
+        record = self.make_record(nulls=0)
+        tx = psycopg3.adapt.Transformer(self.conn)
+        types = []
+        for value in record:
+            dumper = tx.get_dumper(value, self.format)
+            dumper.dump(value)  # load the oid if it's dynamic (e.g. array)
+            # use the "text" entry if the lookup by oid finds nothing
+            info = builtins.get(dumper.oid) or builtins.get("text")
+            if dumper.oid == info.array_oid:
+                types.append(sql.SQL("{}[]").format(sql.Identifier(info.name)))
+            else:
+                types.append(sql.Identifier(info.name))
+
+        self._types_names = types
+        return types
+
+    @property
+    def drop_stmt(self):
+        """The statement to drop the test table, if it exists."""
+        return sql.SQL("drop table if exists {}").format(self.table_name)
+
+    @property
+    def create_stmt(self):
+        """The statement to create the test table, with a serial `id`."""
+        fields = [
+            sql.SQL("{} {}").format(name, type_name)
+            for name, type_name in zip(self.fields_names, self.types_names)
+        ]
+        return sql.SQL(
+            "create table {table} (id serial primary key, {fields})"
+        ).format(table=self.table_name, fields=sql.SQL(", ").join(fields))
+
+    @property
+    def insert_stmt(self):
+        """The statement to insert a record, one placeholder per field."""
+        phs = [sql.Placeholder(format=self.format) for _ in self.schema]
+        return sql.SQL("insert into {} ({}) values ({})").format(
+            self.table_name,
+            sql.SQL(", ").join(self.fields_names),
+            sql.SQL(", ").join(phs),
+        )
+
+    @property
+    def select_stmt(self):
+        """The statement to read the records back, in insertion order."""
+        fields = sql.SQL(", ").join(self.fields_names)
+        return sql.SQL("select {} from {} order by id").format(
+            fields, self.table_name
+        )
+
+    def choose_schema(self, types=None, nfields=20):
+        """
+        Return a random list of `nfields` specs picked among `types`.
+
+        If `types` is not specified use the result of
+        `get_supported_types()`. The schema is only returned, not stored:
+        assign it to the `schema` attribute in order to use it.
+        """
+        if not types:
+            types = self.get_supported_types()
+
+        # choice() needs a sequence: sort the types by name to make one
+        types_list = sorted(types, key=lambda cls: cls.__name__)
+        schema = [choice(types_list) for i in range(nfields)]
+        for i, cls in enumerate(schema):
+            if cls is list:
+                # choose the type of the array items (never a list itself)
+                while 1:
+                    scls = choice(types_list)
+                    if scls is not list:
+                        break
+                schema[i] = [scls]
+            elif cls is tuple:
+                # a tuple is described by the schema of its items
+                schema[i] = tuple(
+                    self.choose_schema(types=types, nfields=nfields)
+                )
+
+        return schema
+
+    def make_records(self, nrecords):
+        """Store `nrecords` random records, with some nulls, in `records`."""
+        self.records = [self.make_record(nulls=0.05) for i in range(nrecords)]
+
+    def make_record(self, nulls=0):
+        """
+        Return a tuple of random values following the schema.
+
+        `nulls` is the chance, between 0 and 1, that each value is None.
+        """
+        if not nulls:
+            return tuple(self.make(spec) for spec in self.schema)
+        else:
+            return tuple(
+                self.make(spec) if random() > nulls else None
+                for spec in self.schema
+            )
+
+    def assert_record(self, got, want):
+        """Assert that the record `got` matches `want`, field by field."""
+        for spec, g, w in zip(self.schema, got, want):
+            if g is None and w is None:
+                continue
+            m = self.get_matcher(spec)
+            m(spec, g, w)
+
+    def get_supported_types(self):
+        """
+        Return the set of classes with a dumper in the current format.
+
+        Classes registered by name are imported. Raise `NotImplementedError`
+        if there is a class for which no maker is available.
+        """
+        dumpers = self.conn.adapters._dumpers[self.format]
+        rv = set()
+        for cls in dumpers.keys():
+            if isinstance(cls, str):
+                cls = deep_import(cls)
+            rv.add(cls)
+
+        # check all the types are handled
+        for cls in rv:
+            self.get_maker(cls)
+
+        return rv
+
+    def get_maker(self, spec):
+        """
+        Return the method to make random values for `spec`.
+
+        Raise `NotImplementedError` if no method is available.
+        """
+        # a spec such as [int] or (int, str) is handled by its container type
+        cls = spec if isinstance(spec, type) else type(spec)
+
+        try:
+            return self._makers[cls]
+        except KeyError:
+            pass
+
+        meth = self._get_method("make", cls)
+        if meth:
+            self._makers[cls] = meth
+            return meth
+        else:
+            raise NotImplementedError(
+                f"cannot make fake objects of class {cls}"
+            )
+
+    def get_matcher(self, spec):
+        """
+        Return the method to compare values for `spec`.
+
+        If no specific method is available return `match_any()`.
+        """
+        # a spec such as [int] or (int, str) is handled by its container type
+        cls = spec if isinstance(spec, type) else type(spec)
+        meth = self._get_method("match", cls)
+        return meth if meth else self.match_any
+
+    def _get_method(self, prefix, cls):
+        """
+        Return the method of this object handling `cls`, None if not found.
+
+        The names tried go from the bare class name to the fully qualified
+        one, with dots replaced by underscores: for `uuid.UUID` and prefix
+        "make" they are `make_UUID`, then `make_uuid_UUID`. The module name
+        is not used for builtin classes.
+        """
+        name = cls.__name__
+        if cls.__module__ != "builtins":
+            name = f"{cls.__module__}.{name}"
+
+        parts = name.split(".")
+        for i in range(len(parts)):
+            mname = f"{prefix}_{'_'.join(parts[-(i + 1) :])}"
+            meth = getattr(self, mname, None)
+            if meth:
+                return meth
+
+        return None
+
+    # methods to implement specific objects
+
+    def make(self, spec):
+        """Return a random value for `spec`."""
+        # spec can be a type or a list [type] or a tuple (spec, spec, ...)
+        return self.get_maker(spec)(spec)
+
+    def match_any(self, spec, got, want):
+        """Compare two values by equality (the default matcher)."""
+        assert got == want
+
+    def make_bool(self, spec):
+        return choice((True, False))
+
+    def make_bytearray(self, spec):
+        return self.make_bytes(spec)
+
+    def make_bytes(self, spec):
+        # spec is the class to return, built from a random bytes string
+        length = randrange(self.str_max_length)
+        return spec(bytes([randrange(256) for i in range(length)]))
+
+    def make_float(self, spec):
+        if random() <= 0.99:
+            # this exponent should generate no inf
+            return float(
+                f"{choice('-+')}0.{randrange(1 << 53)}e{randrange(-310,309)}"
+            )
+        else:
+            return choice(
+                (0.0, -0.0, float("-inf"), float("inf"), float("nan"))
+            )
+
+    def match_float(self, spec, got, want):
+        # nan is different from itself, so it cannot be compared by equality
+        if got is not None and isnan(got):
+            assert isnan(want)
+        else:
+            assert got == want
+
+    def make_int(self, spec):
+        return randrange(-(1 << 63), 1 << 63)
+
+    def make_Int2(self, spec):
+        return spec(randrange(-(1 << 15), 1 << 15))
+
+    def make_Int4(self, spec):
+        return spec(randrange(-(1 << 31), 1 << 31))
+
+    def make_Int8(self, spec):
+        return spec(randrange(-(1 << 63), 1 << 63))
+
+    def make_Json(self, spec):
+        return spec(self._make_json())
+
+    def match_Json(self, spec, got, want):
+        # the value fetched is compared with the object wrapped by `want`
+        if want is not None:
+            want = want.obj
+        assert got == want
+
+    def make_Jsonb(self, spec):
+        return spec(self._make_json())
+
+    def match_Jsonb(self, spec, got, want):
+        return self.match_Json(spec, got, want)
+
+    def make_JsonFloat(self, spec):
+        # A float limited to what json accepts
+        # this exponent should generate no inf
+        return float(
+            f"{choice('-+')}0.{randrange(1 << 20)}e{randrange(-15,15)}"
+        )
+
+    def make_list(self, spec):
+        # don't make empty lists because they regularly fail cast
+        length = randrange(1, self.list_max_length)
+        spec = spec[0]
+        return [self.make(spec) for i in range(length)]
+
+    def match_list(self, spec, got, want):
+        assert len(got) == len(want)
+        # the items are compared according to the spec of the items
+        item_spec = spec[0]
+        m = self.get_matcher(item_spec)
+        for g, w in zip(got, want):
+            m(item_spec, g, w)
+
+    def make_memoryview(self, spec):
+        return self.make_bytes(spec)
+
+    def make_NoneType(self, spec):
+        return None
+
+    def make_Oid(self, spec):
+        return spec(randrange(1 << 32))
+
+    def make_str(self, spec, length=0):
+        """
+        Return a random string of `length` characters.
+
+        If `length` is 0 choose a random length below `str_max_length`.
+        """
+        if not length:
+            length = randrange(self.str_max_length)
+
+        rv = []
+        while len(rv) < length:
+            # Half of the chars are ASCII, the others are any codepoint.
+            # Never generate the NUL char, which Postgres strings can't hold.
+            c = randrange(1, 128) if random() < 0.5 else randrange(1, 0x110000)
+            # discard the surrogates: they are not valid chars on their own
+            if not (0xD800 <= c <= 0xDBFF or 0xDC00 <= c <= 0xDFFF):
+                rv.append(c)
+
+        return "".join(map(chr, rv))
+
+    def make_UUID(self, spec):
+        return UUID(bytes=bytes([randrange(256) for i in range(16)]))
+
+    def _make_json(self, container_chance=0.66):
+        """
+        Return a random object made of lists, dicts and scalar values.
+
+        `container_chance` is the chance to return a list or a dict rather
+        than a scalar: it is halved at every level of nesting.
+        """
+        rec_types = [list, dict]
+        scal_types = [type(None), int, JsonFloat, bool, str]
+        if random() < container_chance:
+            cls = choice(rec_types)
+            if cls is list:
+                return [
+                    self._make_json(container_chance=container_chance / 2.0)
+                    for i in range(randrange(self.json_max_length))
+                ]
+            elif cls is dict:
+                return {
+                    self.make_str(str, 15): self._make_json(
+                        container_chance=container_chance / 2.0
+                    )
+                    for i in range(randrange(self.json_max_length))
+                }
+            else:
+                assert False, f"unknown rec type: {cls}"
+
+        else:
+            cls = choice(scal_types)
+            return self.make(cls)
+
+
+class JsonFloat:
+    """Marker class used as spec to make floats limited to what json accepts."""
+
+
+def deep_import(name):
+    """
+    Return the object identified by the dot-separated `name`.
+
+    The first part of the name is imported as a module; each following part
+    is taken as an attribute of the object found so far or, if there is no
+    such attribute, imported as a submodule.
+
+    Raise `ValueError` if `name` is empty.
+    """
+    # Check the name itself: splitting a string never returns an empty list
+    if not name:
+        raise ValueError("name must be a dot-separated name")
+
+    parts = name.split(".")
+    thing = importlib.import_module(parts[0])
+    for i, attr in enumerate(parts[1:], 2):
+        if hasattr(thing, attr):
+            thing = getattr(thing, attr)
+        else:
+            # a submodule not imported yet: import it by its dotted path
+            thing = importlib.import_module(".".join(parts[:i]))
+
+    return thing
index c267f7d1dce9d53387db0d28cf2b0d7a04c6be55..e12c5c1bb4b017a84e86dbe714fef1dd38b73657 100644 (file)
@@ -6,6 +6,7 @@ import pytest
 
 import psycopg3
 from psycopg3.oids import builtins
+from psycopg3.adapt import Format
 
 
 def test_close(conn):
@@ -392,3 +393,33 @@ def test_str(conn):
     cur.close()
     assert "[closed]" in str(cur)
     assert "[INTRANS]" in str(cur)
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("fmt", [Format.TEXT, Format.BINARY])
+def test_leak_fetchall(dsn, faker, fmt):
+    """
+    Check that inserting and fetching records doesn't leak objects.
+
+    The same random records are inserted and fetched back three times, each
+    time on a new connection: the number of objects known to the garbage
+    collector must be the same after every round.
+    """
+    if fmt == Format.TEXT:
+        pytest.xfail("faker to extend to all text dumpers")
+
+    # Set the format before choosing the schema: the types depend on it
+    faker.format = fmt
+    # choose_schema() only returns the schema: store it on the faker
+    faker.schema = faker.choose_schema()
+    faker.make_records(100)
+
+    n = []
+    for i in range(3):
+        with psycopg3.connect(dsn) as conn:
+            with conn.cursor(format=fmt) as cur:
+                cur.execute(faker.drop_stmt)
+                cur.execute(faker.create_stmt)
+                cur.executemany(faker.insert_stmt, faker.records)
+                cur.execute(faker.select_stmt)
+                for got, want in zip(cur.fetchall(), faker.records):
+                    faker.assert_record(got, want)
+        # drop the references before counting the objects still alive
+        del cur, conn
+        gc.collect()
+        gc.collect()
+        n.append(len(gc.get_objects()))
+
+    assert (
+        n[0] == n[1] == n[2]
+    ), f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"