git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add results to compiled extensions
author: Federico Caselli <cfederico87@gmail.com>
Tue, 28 Nov 2023 21:40:03 +0000 (22:40 +0100)
committer: Federico Caselli <cfederico87@gmail.com>
Fri, 17 May 2024 20:04:45 +0000 (20:04 +0000)
Add the ability to compare a saved result with the current run

Change-Id: I0039cc93ed68d5957753ea49c076d934191e6cd0

lib/sqlalchemy/engine/result.py
test/perf/compiled_extensions/base.py
test/perf/compiled_extensions/command.py
test/perf/compiled_extensions/result.py [new file with mode: 0644]

index 226b7f8c6360f2895316d8c205b8d72e7d36845a..ad39756bd847b87416a1b5d3a930391f870206ac 100644 (file)
@@ -270,6 +270,7 @@ class SimpleResultMetaData(ResultMetaData):
         self._translated_indexes = _translated_indexes
         self._unique_filters = _unique_filters
         if extra:
+            assert len(self._keys) == len(extra)
             recs_names = [
                 (
                     (name,) + (extras if extras else ()),
index fd6c4198fe19902ac062970e34ae097809c471d7..ccf222437cfb67a2a695e669f0bf4b1a5c3de8ea 100644 (file)
@@ -120,4 +120,4 @@ class Case:
                 print("\t", f)
 
         cls.update_results(results)
-        return results
+        return results, [name for name, _ in objects]
index 97cf725460a24a866d7daa59a2e1e32600913b9c..587a9127dcdd9fb750352b7df1d05479ab1b2463 100644 (file)
@@ -1,30 +1,34 @@
 from collections import defaultdict
+from datetime import datetime
+import subprocess
 
+import sqlalchemy as sa
 from .base import Case
 
 if True:
     from . import cache_key  # noqa: F401
     from . import collections_  # noqa: F401
     from . import misc  # noqa: F401
+    from . import result  # noqa: F401
     from . import row  # noqa: F401
 
 
 def tabulate(
-    result_by_impl: dict[str, dict[str, float]],
+    impl_names: list[str],
     result_by_method: dict[str, dict[str, float]],
 ):
     if not result_by_method:
         return
-    dim = 11
+    dim = max(len(n) for n in impl_names)
+    dim = min(dim, 20)
 
     width = max(20, *(len(m) + 1 for m in result_by_method))
 
     string_cell = "{:<%s}" % dim
-    header = "{:<%s}|" % width + f" {string_cell} |" * len(result_by_impl)
+    header = "{:<%s}|" % width + f" {string_cell} |" * len(impl_names)
     num_format = "{:<%s.9f}" % dim
-    csv_row = "{:<%s}|" % width + " {} |" * len(result_by_impl)
-    names = list(result_by_impl)
-    print(header.format("", *names))
+    csv_row = "{:<%s}|" % width + " {} |" * len(impl_names)
+    print(header.format("", *impl_names))
 
     for meth in result_by_method:
         data = result_by_method[meth]
@@ -34,11 +38,21 @@ def tabulate(
                 if name in data
                 else string_cell.format("—")
             )
-            for name in names
+            for name in impl_names
         ]
         print(csv_row.format(meth, *strings))
 
 
+def find_git_sha():
+    try:
+        git_res = subprocess.run(
+            ["git", "rev-parse", "--short", "HEAD"], stdout=subprocess.PIPE
+        )
+        return git_res.stdout.decode("utf-8").strip()
+    except Exception:
+        return None
+
+
 def main():
     import argparse
 
@@ -58,24 +72,97 @@ def main():
         "--factor", help="scale number passed to timeit", type=float, default=1
     )
     parser.add_argument("--csv", help="save to csv", action="store_true")
+    save_group = parser.add_argument_group("Save result for later compare")
+    save_group.add_argument(
+        "--save-db",
+        help="Name of the sqlite db file to use",
+        const="perf.db",
+        nargs="?",
+    )
+    save_group.add_argument(
+        "--save-name",
+        help="A name given to the current save. "
+        "Can be used later to compare against this run.",
+    )
+
+    compare_group = parser.add_argument_group("Compare against stored data")
+    compare_group.add_argument(
+        "--compare-db",
+        help="Name of the sqlite db file to read for the compare data",
+        const="perf.db",
+        nargs="?",
+    )
+    compare_group.add_argument(
+        "--compare-filter",
+        help="Filter the compare data using this string. Can include "
+        "git-short-sha, save-name previously used or date. By default the "
+        "latest values are used",
+    )
 
     args = parser.parse_args()
 
+    to_run: list[type[Case]]
     if "all" in args.case:
         to_run = cases
     else:
         to_run = [c for c in cases if c.__name__ in args.case]
 
+    if args.save_db:
+        save_engine = sa.create_engine(
+            f"sqlite:///{args.save_db}", poolclass=sa.NullPool
+        )
+        PerfTable.metadata.create_all(save_engine)
+        sha = find_git_sha()
+
+    if args.compare_db:
+        compare_engine = sa.create_engine(
+            f"sqlite:///{args.compare_db}", poolclass=sa.NullPool
+        )
+        stmt = (
+            sa.select(PerfTable)
+            .where(PerfTable.c.factor == args.factor)
+            .order_by(PerfTable.c.created.desc())
+        )
+        if args.compare_filter:
+            cf = args.compare_filter
+            stmt = stmt.where(
+                sa.or_(
+                    PerfTable.c.created.cast(sa.Text).icontains(cf),
+                    PerfTable.c.git_short_sha.icontains(cf),
+                    PerfTable.c.save_name.icontains(cf),
+                ),
+            )
+
     for case in to_run:
         print("Running case", case.__name__)
-        result_by_impl = case.run_case(args.factor, args.filter)
+        if args.compare_db:
+            with compare_engine.connect() as conn:
+                case_stmt = stmt.where(PerfTable.c.case == case.__name__)
+                compare_by_meth = defaultdict(dict)
+                for prow in conn.execute(case_stmt):
+                    if prow.impl in compare_by_meth[prow.method]:
+                        continue
+                    compare_by_meth[prow.method][prow.impl] = prow.value
+        else:
+            compare_by_meth = {}
+
+        result_by_impl, impl_names = case.run_case(args.factor, args.filter)
 
         result_by_method = defaultdict(dict)
-        for name in result_by_impl:
-            for meth in result_by_impl[name]:
-                result_by_method[meth][name] = result_by_impl[name][meth]
-
-        tabulate(result_by_impl, result_by_method)
+        all_impls = dict.fromkeys(result_by_impl)
+        for impl in result_by_impl:
+            for meth in result_by_impl[impl]:
+                meth_dict = result_by_method[meth]
+                meth_dict[impl] = result_by_impl[impl][meth]
+                if meth in compare_by_meth and impl in compare_by_meth[meth]:
+                    cmp_impl = f"compare {impl}"
+                    over = f"{impl} / compare"
+                    all_impls[cmp_impl] = None
+                    all_impls[over] = None
+                    meth_dict[cmp_impl] = compare_by_meth[meth][impl]
+                    meth_dict[over] = meth_dict[impl] / meth_dict[cmp_impl]
+
+        tabulate(list(all_impls), result_by_method)
 
         if args.csv:
             import csv
@@ -87,3 +174,36 @@ def main():
                 for n in result_by_method:
                     w.writerow({"": n, **result_by_method[n]})
             print("Wrote file", file_name)
+
+        if args.save_db:
+            data = [
+                {
+                    "case": case.__name__,
+                    "impl": impl,
+                    "method": meth,
+                    "value": result_by_impl[impl][meth],
+                    "factor": args.factor,
+                    "save_name": args.save_name,
+                    "git_short_sha": sha,
+                    "created": Now,
+                }
+                for impl in impl_names
+                for meth in result_by_impl[impl]
+            ]
+            with save_engine.begin() as conn:
+                conn.execute(PerfTable.insert(), data)
+
+
+PerfTable = sa.Table(
+    "perf_table",
+    sa.MetaData(),
+    sa.Column("case", sa.Text, nullable=False),
+    sa.Column("impl", sa.Text, nullable=False),
+    sa.Column("method", sa.Text, nullable=False),
+    sa.Column("value", sa.Float),
+    sa.Column("factor", sa.Float),
+    sa.Column("save_name", sa.Text),
+    sa.Column("git_short_sha", sa.Text),
+    sa.Column("created", sa.DateTime, nullable=False),
+)
+Now = datetime.now()
diff --git a/test/perf/compiled_extensions/result.py b/test/perf/compiled_extensions/result.py
new file mode 100644 (file)
index 0000000..b3f7145
--- /dev/null
@@ -0,0 +1,305 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from itertools import product
+from operator import itemgetter
+from typing import Callable
+from typing import Optional
+
+import sqlalchemy as sa
+from sqlalchemy.dialects import sqlite
+from sqlalchemy.engine import cursor
+from sqlalchemy.engine import result
+from sqlalchemy.engine.default import DefaultExecutionContext
+from .base import Case
+from .base import test_case
+
+
+class _CommonResult(Case):
+    @classmethod
+    def init_class(cls):
+        # 3-col
+        cls.def3_plain = Definition(list("abc"))
+        cls.def3_1proc = Definition(list("abc"), [None, str, None])
+        cls.def3_tf = Definition(list("abc"), tuplefilter=itemgetter(1, 2))
+        cls.def3_1proc_tf = Definition(
+            list("abc"), [None, str, None], itemgetter(1, 2)
+        )
+        cls.data3_100 = [(i, i + i, i - 1) for i in range(100)]
+        cls.data3_1000 = [(i, i + i, i - 1) for i in range(1000)]
+        cls.data3_10000 = [(i, i + i, i - 1) for i in range(10000)]
+
+        cls.make_test_cases("row3col", "def3_", "data3_")
+
+        # 21-col
+        cols = [f"c_{i}" for i in range(21)]
+        cls.def21_plain = Definition(cols)
+        cls.def21_7proc = Definition(cols, [None, str, None] * 7)
+        cls.def21_tf = Definition(
+            cols, tuplefilter=itemgetter(1, 2, 9, 17, 18)
+        )
+        cls.def21_7proc_tf = Definition(
+            cols, [None, str, None] * 7, itemgetter(1, 2, 9, 17, 18)
+        )
+        cls.data21_100 = [(i, i + i, i - 1) * 7 for i in range(100)]
+        cls.data21_1000 = [(i, i + i, i - 1) * 7 for i in range(1000)]
+        cls.data21_10000 = [(i, i + i, i - 1) * 7 for i in range(10000)]
+
+        cls.make_test_cases("row21col", "def21_", "data21_")
+
+    @classmethod
+    def make_test_cases(cls, prefix: str, def_prefix: str, data_prefix: str):
+        all_defs = [
+            (k, v) for k, v in vars(cls).items() if k.startswith(def_prefix)
+        ]
+        all_data = [
+            (k, v) for k, v in vars(cls).items() if k.startswith(data_prefix)
+        ]
+        assert all_defs and all_data
+
+        def make_case(name, definition, data, number):
+            init_args = cls.get_init_args_callable(definition, data)
+
+            def go_all(self):
+                result = self.impl(*init_args())
+                result.all()
+
+            setattr(cls, name + "_all", test_case(go_all, number=number))
+
+            def go_all_uq(self):
+                result = self.impl(*init_args()).unique()
+                result.all()
+
+            setattr(cls, name + "_all_uq", test_case(go_all_uq, number=number))
+
+            def go_iter(self):
+                result = self.impl(*init_args())
+                for _ in result:
+                    pass
+
+            setattr(cls, name + "_iter", test_case(go_iter, number=number))
+
+            def go_iter_uq(self):
+                result = self.impl(*init_args()).unique()
+                for _ in result:
+                    pass
+
+            setattr(
+                cls, name + "_iter_uq", test_case(go_iter_uq, number=number)
+            )
+
+            def go_many(self):
+                result = self.impl(*init_args())
+                while result.fetchmany(10):
+                    pass
+
+            setattr(cls, name + "_many", test_case(go_many, number=number))
+
+            def go_many_uq(self):
+                result = self.impl(*init_args()).unique()
+                while result.fetchmany(10):
+                    pass
+
+            setattr(
+                cls, name + "_many_uq", test_case(go_many_uq, number=number)
+            )
+
+            def go_one(self):
+                result = self.impl(*init_args())
+                while result.fetchone() is not None:
+                    pass
+
+            setattr(cls, name + "_one", test_case(go_one, number=number))
+
+            def go_one_uq(self):
+                result = self.impl(*init_args()).unique()
+                while result.fetchone() is not None:
+                    pass
+
+            setattr(cls, name + "_one_uq", test_case(go_one_uq, number=number))
+
+            def go_scalar_all(self):
+                result = self.impl(*init_args())
+                result.scalars().all()
+
+            setattr(
+                cls, name + "_sc_all", test_case(go_scalar_all, number=number)
+            )
+
+            def go_scalar_iter(self):
+                result = self.impl(*init_args())
+                rs = result.scalars()
+                for _ in rs:
+                    pass
+
+            setattr(
+                cls,
+                name + "_sc_iter",
+                test_case(go_scalar_iter, number=number),
+            )
+
+            def go_scalar_many(self):
+                result = self.impl(*init_args())
+                rs = result.scalars()
+                while rs.fetchmany(10):
+                    pass
+
+            setattr(
+                cls,
+                name + "_sc_many",
+                test_case(go_scalar_many, number=number),
+            )
+
+        for (def_name, definition), (data_name, data) in product(
+            all_defs, all_data
+        ):
+            name = (
+                f"{prefix}_{def_name.removeprefix(def_prefix)}_"
+                f"{data_name.removeprefix(data_prefix)}"
+            )
+            number = 500 if data_name.endswith("10000") else None
+            make_case(name, definition, data, number)
+
+    @classmethod
+    def get_init_args_callable(
+        cls, definition: Definition, data: list
+    ) -> Callable:
+        raise NotImplementedError
+
+
+class IteratorResult(_CommonResult):
+    NUMBER = 1_000
+
+    impl: result.IteratorResult
+
+    @staticmethod
+    def default():
+        return cursor.IteratorResult
+
+    IMPLEMENTATIONS = {"default": default.__func__}
+
+    @classmethod
+    def get_init_args_callable(
+        cls, definition: Definition, data: list
+    ) -> Callable:
+        meta = result.SimpleResultMetaData(
+            definition.columns,
+            _processors=definition.processors,
+            _tuplefilter=definition.tuplefilter,
+        )
+        return lambda: (meta, iter(data))
+
+
+class CursorResult(_CommonResult):
+    NUMBER = 1_000
+
+    impl: cursor.CursorResult
+
+    @staticmethod
+    def default():
+        return cursor.CursorResult
+
+    IMPLEMENTATIONS = {"default": default.__func__}
+
+    @classmethod
+    def get_init_args_callable(
+        cls, definition: Definition, data: list
+    ) -> Callable:
+        if definition.processors:
+            proc_dict = {
+                c: p for c, p in zip(definition.columns, definition.processors)
+            }
+        else:
+            proc_dict = None
+
+        class MockExecutionContext(DefaultExecutionContext):
+            def create_cursor(self):
+                return _MockCursor(data, self.compiled)
+
+            def get_result_processor(self, type_, colname, coltype):
+                return None if proc_dict is None else proc_dict[colname]
+
+            def args_for_new_cursor_result(self):
+                self.cursor = self.create_cursor()
+                return (
+                    self,
+                    self.cursor_fetch_strategy,
+                    context.cursor.description,
+                )
+
+        dialect = sqlite.dialect()
+        stmt = sa.select(
+            *(sa.column(c) for c in definition.columns)
+        ).select_from(sa.table("t"))
+        compiled = stmt._compile_w_cache(
+            dialect, compiled_cache=None, column_keys=[]
+        )[0]
+
+        context = MockExecutionContext._init_compiled(
+            dialect=dialect,
+            connection=_MockConnection(dialect),
+            dbapi_connection=None,
+            execution_options={},
+            compiled=compiled,
+            parameters=[],
+            invoked_statement=stmt,
+            extracted_parameters=None,
+        )
+        _ = context._setup_result_proxy()
+        assert compiled._cached_metadata
+
+        return context.args_for_new_cursor_result
+
+
+class _MockCursor:
+    def __init__(self, rows: list[tuple], compiled):
+        self._rows = list(rows)
+        if compiled._result_columns is None:
+            self.description = None
+        else:
+            self.description = [
+                (rc.keyname, 42, None, None, None, True)
+                for rc in compiled._result_columns
+            ]
+
+    def close(self):
+        pass
+
+    def fetchone(self):
+        if self._rows:
+            return self._rows.pop(0)
+        else:
+            return None
+
+    def fetchmany(self, size=None):
+        if size is None:
+            return self.fetchall()
+        else:
+            ret = self._rows[:size]
+            self._rows[:size] = []
+            return ret
+
+    def fetchall(self):
+        ret = self._rows
+        self._rows = []
+        return ret
+
+
+class _MockConnection:
+    _echo = False
+
+    def __init__(self, dialect):
+        self.dialect = dialect
+
+    def _safe_close_cursor(self, cursor):
+        cursor.close()
+
+    def _handle_dbapi_exception(self, e, *args, **kw):
+        raise e
+
+
+@dataclass
+class Definition:
+    columns: list[str]
+    processors: Optional[list[Optional[Callable]]] = None
+    tuplefilter: Optional[Callable] = None