from collections import defaultdict
+from datetime import datetime
+import subprocess
+import sqlalchemy as sa
from .base import Case
if True:
from . import cache_key # noqa: F401
from . import collections_ # noqa: F401
from . import misc # noqa: F401
+ from . import result # noqa: F401
from . import row # noqa: F401
def tabulate(
- result_by_impl: dict[str, dict[str, float]],
+ impl_names: list[str],
result_by_method: dict[str, dict[str, float]],
):
if not result_by_method:
return
- dim = 11
+ dim = max(len(n) for n in impl_names)
+ dim = min(dim, 20)
width = max(20, *(len(m) + 1 for m in result_by_method))
string_cell = "{:<%s}" % dim
- header = "{:<%s}|" % width + f" {string_cell} |" * len(result_by_impl)
+ header = "{:<%s}|" % width + f" {string_cell} |" * len(impl_names)
num_format = "{:<%s.9f}" % dim
- csv_row = "{:<%s}|" % width + " {} |" * len(result_by_impl)
- names = list(result_by_impl)
- print(header.format("", *names))
+ csv_row = "{:<%s}|" % width + " {} |" * len(impl_names)
+ print(header.format("", *impl_names))
for meth in result_by_method:
data = result_by_method[meth]
if name in data
else string_cell.format("—")
)
- for name in names
+ for name in impl_names
]
print(csv_row.format(meth, *strings))
def find_git_sha():
    """Return the current git short sha, or None if it cannot be determined.

    Used to tag rows saved via ``--save-db`` so a later run can be compared
    against a specific revision.

    Returns:
        The short sha as a string, or None when git is unavailable or the
        working directory is not a git checkout.
    """
    try:
        git_res = subprocess.run(
            ["git", "rev-parse", "--short", "HEAD"],
            stdout=subprocess.PIPE,
            # don't spray git's error output onto the benchmark report
            stderr=subprocess.DEVNULL,
            # a non-zero exit (e.g. not a git repo) must yield None,
            # not an empty string from an empty stdout
            check=True,
        )
        return git_res.stdout.decode("utf-8").strip()
    except Exception:
        # git missing or command failed; saving still works without a sha
        return None
+
+
def main():
import argparse
"--factor", help="scale number passed to timeit", type=float, default=1
)
parser.add_argument("--csv", help="save to csv", action="store_true")
+ save_group = parser.add_argument_group("Save result for later compare")
+ save_group.add_argument(
+ "--save-db",
+ help="Name of the sqlite db file to use",
+ const="perf.db",
+ nargs="?",
+ )
+ save_group.add_argument(
+ "--save-name",
+ help="A name given to the current save. "
+ "Can be used later to compare against this run.",
+ )
+
+ compare_group = parser.add_argument_group("Compare against stored data")
+ compare_group.add_argument(
+ "--compare-db",
+ help="Name of the sqlite db file to read for the compare data",
+ const="perf.db",
+ nargs="?",
+ )
+ compare_group.add_argument(
+ "--compare-filter",
+ help="Filter the compare data using this string. Can include "
+ "git-short-sha, save-name previously used or date. By default the "
+ "latest values are used",
+ )
args = parser.parse_args()
+ to_run: list[type[Case]]
if "all" in args.case:
to_run = cases
else:
to_run = [c for c in cases if c.__name__ in args.case]
+ if args.save_db:
+ save_engine = sa.create_engine(
+ f"sqlite:///{args.save_db}", poolclass=sa.NullPool
+ )
+ PerfTable.metadata.create_all(save_engine)
+ sha = find_git_sha()
+
+ if args.compare_db:
+ compare_engine = sa.create_engine(
+ f"sqlite:///{args.compare_db}", poolclass=sa.NullPool
+ )
+ stmt = (
+ sa.select(PerfTable)
+ .where(PerfTable.c.factor == args.factor)
+ .order_by(PerfTable.c.created.desc())
+ )
+ if args.compare_filter:
+ cf = args.compare_filter
+ stmt = stmt.where(
+ sa.or_(
+ PerfTable.c.created.cast(sa.Text).icontains(cf),
+ PerfTable.c.git_short_sha.icontains(cf),
+ PerfTable.c.save_name.icontains(cf),
+ ),
+ )
+
for case in to_run:
print("Running case", case.__name__)
- result_by_impl = case.run_case(args.factor, args.filter)
+ if args.compare_db:
+ with compare_engine.connect() as conn:
+ case_stmt = stmt.where(PerfTable.c.case == case.__name__)
+ compare_by_meth = defaultdict(dict)
+ for prow in conn.execute(case_stmt):
+ if prow.impl in compare_by_meth[prow.method]:
+ continue
+ compare_by_meth[prow.method][prow.impl] = prow.value
+ else:
+ compare_by_meth = {}
+
+ result_by_impl, impl_names = case.run_case(args.factor, args.filter)
result_by_method = defaultdict(dict)
- for name in result_by_impl:
- for meth in result_by_impl[name]:
- result_by_method[meth][name] = result_by_impl[name][meth]
-
- tabulate(result_by_impl, result_by_method)
+ all_impls = dict.fromkeys(result_by_impl)
+ for impl in result_by_impl:
+ for meth in result_by_impl[impl]:
+ meth_dict = result_by_method[meth]
+ meth_dict[impl] = result_by_impl[impl][meth]
+ if meth in compare_by_meth and impl in compare_by_meth[meth]:
+ cmp_impl = f"compare {impl}"
+ over = f"{impl} / compare"
+ all_impls[cmp_impl] = None
+ all_impls[over] = None
+ meth_dict[cmp_impl] = compare_by_meth[meth][impl]
+ meth_dict[over] = meth_dict[impl] / meth_dict[cmp_impl]
+
+ tabulate(list(all_impls), result_by_method)
if args.csv:
import csv
for n in result_by_method:
w.writerow({"": n, **result_by_method[n]})
print("Wrote file", file_name)
+
+ if args.save_db:
+ data = [
+ {
+ "case": case.__name__,
+ "impl": impl,
+ "method": meth,
+ "value": result_by_impl[impl][meth],
+ "factor": args.factor,
+ "save_name": args.save_name,
+ "git_short_sha": sha,
+ "created": Now,
+ }
+ for impl in impl_names
+ for meth in result_by_impl[impl]
+ ]
+ with save_engine.begin() as conn:
+ conn.execute(PerfTable.insert(), data)
+
+
# One row per (case, impl, method) benchmark value; written by --save-db and
# queried by --compare-db in main().
PerfTable = sa.Table(
    "perf_table",
    sa.MetaData(),
    sa.Column("case", sa.Text, nullable=False),  # Case class name
    sa.Column("impl", sa.Text, nullable=False),  # implementation name
    sa.Column("method", sa.Text, nullable=False),  # benchmark method name
    sa.Column("value", sa.Float),  # measured timing value
    sa.Column("factor", sa.Float),  # --factor the run used
    sa.Column("save_name", sa.Text),  # optional label from --save-name
    sa.Column("git_short_sha", sa.Text),  # from find_git_sha(), may be NULL
    sa.Column("created", sa.DateTime, nullable=False),
)
# Captured once at import time so every row of a run shares one "created"
# timestamp. NOTE(review): naive local time — confirm UTC isn't expected
# when comparing runs across machines.
Now = datetime.now()
--- /dev/null
+from __future__ import annotations
+
+from dataclasses import dataclass
+from itertools import product
+from operator import itemgetter
+from typing import Callable
+from typing import Optional
+
+import sqlalchemy as sa
+from sqlalchemy.dialects import sqlite
+from sqlalchemy.engine import cursor
+from sqlalchemy.engine import result
+from sqlalchemy.engine.default import DefaultExecutionContext
+from .base import Case
+from .base import test_case
+
+
class _CommonResult(Case):
    """Shared machinery for the result-object benchmarks.

    ``init_class`` builds a grid of column Definitions x data sizes and,
    for each combination, ``make_test_cases`` attaches one timed method per
    fetch style (all / iter / fetchmany / fetchone / scalars, plus
    ``unique()`` variants) onto the class via ``test_case``.  Subclasses
    supply ``get_init_args_callable`` to describe how to construct their
    particular result implementation.
    """

    @classmethod
    def init_class(cls):
        # 3-col
        cls.def3_plain = Definition(list("abc"))
        cls.def3_1proc = Definition(list("abc"), [None, str, None])
        cls.def3_tf = Definition(list("abc"), tuplefilter=itemgetter(1, 2))
        cls.def3_1proc_tf = Definition(
            list("abc"), [None, str, None], itemgetter(1, 2)
        )
        cls.data3_100 = [(i, i + i, i - 1) for i in range(100)]
        cls.data3_1000 = [(i, i + i, i - 1) for i in range(1000)]
        cls.data3_10000 = [(i, i + i, i - 1) for i in range(10000)]

        cls.make_test_cases("row3col", "def3_", "data3_")

        # 21-col
        cols = [f"c_{i}" for i in range(21)]
        cls.def21_plain = Definition(cols)
        cls.def21_7proc = Definition(cols, [None, str, None] * 7)
        cls.def21_tf = Definition(
            cols, tuplefilter=itemgetter(1, 2, 9, 17, 18)
        )
        cls.def21_7proc_tf = Definition(
            cols, [None, str, None] * 7, itemgetter(1, 2, 9, 17, 18)
        )
        # each 3-tuple repeated 7 times yields a 21-element row
        cls.data21_100 = [(i, i + i, i - 1) * 7 for i in range(100)]
        cls.data21_1000 = [(i, i + i, i - 1) * 7 for i in range(1000)]
        cls.data21_10000 = [(i, i + i, i - 1) * 7 for i in range(10000)]

        cls.make_test_cases("row21col", "def21_", "data21_")

    @classmethod
    def make_test_cases(cls, prefix: str, def_prefix: str, data_prefix: str):
        """Attach timed fetch-style cases for every def/data combination.

        Collects the ``def_prefix*`` Definitions and ``data_prefix*`` data
        lists previously stored on the class, and for each pairing installs
        the full set of ``go_*`` benchmark methods under names like
        ``{prefix}_{def}_{size}_all``.
        """
        all_defs = [
            (k, v) for k, v in vars(cls).items() if k.startswith(def_prefix)
        ]
        all_data = [
            (k, v) for k, v in vars(cls).items() if k.startswith(data_prefix)
        ]
        assert all_defs and all_data

        def make_case(name, definition, data, number):
            # each go_* closure builds a fresh result object through the
            # subclass-provided callable, then drains it in one style
            init_args = cls.get_init_args_callable(definition, data)

            def go_all(self):
                result = self.impl(*init_args())
                result.all()

            setattr(cls, name + "_all", test_case(go_all, number=number))

            def go_all_uq(self):
                result = self.impl(*init_args()).unique()
                result.all()

            setattr(cls, name + "_all_uq", test_case(go_all_uq, number=number))

            def go_iter(self):
                result = self.impl(*init_args())
                for _ in result:
                    pass

            setattr(cls, name + "_iter", test_case(go_iter, number=number))

            def go_iter_uq(self):
                result = self.impl(*init_args()).unique()
                for _ in result:
                    pass

            setattr(
                cls, name + "_iter_uq", test_case(go_iter_uq, number=number)
            )

            def go_many(self):
                result = self.impl(*init_args())
                while result.fetchmany(10):
                    pass

            setattr(cls, name + "_many", test_case(go_many, number=number))

            def go_many_uq(self):
                result = self.impl(*init_args()).unique()
                while result.fetchmany(10):
                    pass

            setattr(
                cls, name + "_many_uq", test_case(go_many_uq, number=number)
            )

            def go_one(self):
                result = self.impl(*init_args())
                while result.fetchone() is not None:
                    pass

            setattr(cls, name + "_one", test_case(go_one, number=number))

            def go_one_uq(self):
                result = self.impl(*init_args()).unique()
                while result.fetchone() is not None:
                    pass

            setattr(cls, name + "_one_uq", test_case(go_one_uq, number=number))

            def go_scalar_all(self):
                result = self.impl(*init_args())
                result.scalars().all()

            setattr(
                cls, name + "_sc_all", test_case(go_scalar_all, number=number)
            )

            def go_scalar_iter(self):
                result = self.impl(*init_args())
                rs = result.scalars()
                for _ in rs:
                    pass

            setattr(
                cls,
                name + "_sc_iter",
                test_case(go_scalar_iter, number=number),
            )

            def go_scalar_many(self):
                result = self.impl(*init_args())
                rs = result.scalars()
                while rs.fetchmany(10):
                    pass

            setattr(
                cls,
                name + "_sc_many",
                test_case(go_scalar_many, number=number),
            )

        for (def_name, definition), (data_name, data) in product(
            all_defs, all_data
        ):
            name = (
                f"{prefix}_{def_name.removeprefix(def_prefix)}_"
                f"{data_name.removeprefix(data_prefix)}"
            )
            # cap timeit iterations for the largest datasets to bound runtime
            number = 500 if data_name.endswith("10000") else None
            make_case(name, definition, data, number)

    @classmethod
    def get_init_args_callable(
        cls, definition: Definition, data: list
    ) -> Callable:
        """Return a zero-arg callable producing ctor args for ``self.impl``.

        Subclasses must override; called once per generated case, outside
        the timed region.
        """
        raise NotImplementedError
+
+
class IteratorResult(_CommonResult):
    """Benchmark ``IteratorResult`` fed from an in-memory iterator."""

    NUMBER = 1_000

    impl: result.IteratorResult

    @staticmethod
    def default():
        # IteratorResult is declared in sqlalchemy.engine.result (matching
        # the ``impl`` annotation above); referencing it through the cursor
        # module relied on an incidental re-export.
        return result.IteratorResult

    IMPLEMENTATIONS = {"default": default.__func__}

    @classmethod
    def get_init_args_callable(
        cls, definition: Definition, data: list
    ) -> Callable:
        """Build the (metadata, iterator) args used to construct the result."""
        meta = result.SimpleResultMetaData(
            definition.columns,
            _processors=definition.processors,
            _tuplefilter=definition.tuplefilter,
        )
        # fresh iterator per call so each timed run consumes its own data
        return lambda: (meta, iter(data))
+
+
class CursorResult(_CommonResult):
    """Benchmark ``CursorResult`` backed by a mock DBAPI cursor."""

    NUMBER = 1_000

    impl: cursor.CursorResult

    @staticmethod
    def default():
        return cursor.CursorResult

    IMPLEMENTATIONS = {"default": default.__func__}

    @classmethod
    def get_init_args_callable(
        cls, definition: Definition, data: list
    ) -> Callable:
        """Return a callable yielding (context, fetch_strategy, description).

        Compiles a SELECT over the definition's columns once, wraps it in a
        mock execution context/connection, and hands back
        ``args_for_new_cursor_result`` so each timed construction of
        ``CursorResult`` reuses the compiled state with a fresh cursor.
        """
        if definition.processors:
            proc_dict = {
                c: p for c, p in zip(definition.columns, definition.processors)
            }
        else:
            proc_dict = None

        class MockExecutionContext(DefaultExecutionContext):
            def create_cursor(self):
                return _MockCursor(data, self.compiled)

            def get_result_processor(self, type_, colname, coltype):
                # serve the per-column processor from the Definition
                return None if proc_dict is None else proc_dict[colname]

            def args_for_new_cursor_result(self):
                # fresh cursor each call so every benchmark run starts with
                # the full data set; ``context`` is the enclosing variable
                # bound below, resolved when this method is called
                self.cursor = self.create_cursor()
                return (
                    self,
                    self.cursor_fetch_strategy,
                    context.cursor.description,
                )

        dialect = sqlite.dialect()
        stmt = sa.select(
            *(sa.column(c) for c in definition.columns)
        ).select_from(sa.table("t"))
        compiled = stmt._compile_w_cache(
            dialect, compiled_cache=None, column_keys=[]
        )[0]

        context = MockExecutionContext._init_compiled(
            dialect=dialect,
            connection=_MockConnection(dialect),
            dbapi_connection=None,
            execution_options={},
            compiled=compiled,
            parameters=[],
            invoked_statement=stmt,
            extracted_parameters=None,
        )
        # run the result setup once, outside the timed region, so the
        # compiled statement caches its result metadata
        _ = context._setup_result_proxy()
        assert compiled._cached_metadata

        return context.args_for_new_cursor_result
+
+
+class _MockCursor:
+ def __init__(self, rows: list[tuple], compiled):
+ self._rows = list(rows)
+ if compiled._result_columns is None:
+ self.description = None
+ else:
+ self.description = [
+ (rc.keyname, 42, None, None, None, True)
+ for rc in compiled._result_columns
+ ]
+
+ def close(self):
+ pass
+
+ def fetchone(self):
+ if self._rows:
+ return self._rows.pop(0)
+ else:
+ return None
+
+ def fetchmany(self, size=None):
+ if size is None:
+ return self.fetchall()
+ else:
+ ret = self._rows[:size]
+ self._rows[:size] = []
+ return ret
+
+ def fetchall(self):
+ ret = self._rows
+ self._rows = []
+ return ret
+
+
+class _MockConnection:
+ _echo = False
+
+ def __init__(self, dialect):
+ self.dialect = dialect
+
+ def _safe_close_cursor(self, cursor):
+ cursor.close()
+
+ def _handle_dbapi_exception(self, e, *args, **kw):
+ raise e
+
+
@dataclass
class Definition:
    """Column setup used to build result metadata for a benchmark case."""

    # result column names
    columns: list[str]
    # optional per-column value processors (a None entry = no processing)
    processors: Optional[list[Optional[Callable]]] = None
    # optional callable (e.g. itemgetter) selecting a subset of each row
    tuplefilter: Optional[Callable] = None