git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add support for aiosqlite
author: Federico Caselli <cfederico87@gmail.com>
Sat, 6 Feb 2021 14:17:20 +0000 (15:17 +0100)
committer: Mike Bayer <mike_mp@zzzcomputing.com>
Wed, 24 Mar 2021 15:45:39 +0000 (11:45 -0400)
Added support for the aiosqlite database driver for use with the
SQLAlchemy asyncio extension.

Fixes: #5920
Change-Id: Id11a320516a44e886a6f518d2866a0f992413e55

24 files changed:
doc/build/changelog/unreleased_14/5920.rst [new file with mode: 0644]
doc/build/dialects/sqlite.rst
lib/sqlalchemy/dialects/mysql/aiomysql.py
lib/sqlalchemy/dialects/sqlite/__init__.py
lib/sqlalchemy/dialects/sqlite/aiosqlite.py [new file with mode: 0644]
lib/sqlalchemy/dialects/sqlite/provision.py
lib/sqlalchemy/pool/base.py
lib/sqlalchemy/pool/impl.py
lib/sqlalchemy/testing/engines.py
lib/sqlalchemy/testing/fixtures.py
lib/sqlalchemy/testing/requirements.py
lib/sqlalchemy/testing/suite/test_dialect.py
lib/sqlalchemy/testing/suite/test_results.py
lib/sqlalchemy/util/concurrency.py
setup.cfg
test/dialect/test_sqlite.py
test/engine/test_execute.py
test/engine/test_pool.py
test/engine/test_reconnect.py
test/ext/asyncio/test_engine_py3k.py
test/ext/asyncio/test_session_py3k.py
test/orm/test_transaction.py
test/requirements.py
tox.ini

diff --git a/doc/build/changelog/unreleased_14/5920.rst b/doc/build/changelog/unreleased_14/5920.rst
new file mode 100644 (file)
index 0000000..a44148c
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: sqlite, feature, asyncio
+    :tickets: 5920
+
+    Added support for the aiosqlite database driver for use with the
+    SQLAlchemy asyncio extension.
+
+    .. seealso::
+
+      :ref:`aiosqlite`
index 0c04ce3f5138c64dc699de857e7aea48b1ce2b37..e9a84e338806701ffba78c861a8415edc070e763 100644 (file)
@@ -35,11 +35,20 @@ SQLite DML Constructs
 .. autoclass:: sqlalchemy.dialects.sqlite.Insert
   :members:
 
+.. _pysqlite:
+
 Pysqlite
 --------
 
 .. automodule:: sqlalchemy.dialects.sqlite.pysqlite
 
+.. _aiosqlite:
+
+Aiosqlite
+---------
+
+.. automodule:: sqlalchemy.dialects.sqlite.aiosqlite
+
 Pysqlcipher
 -----------
 
index cab6df499f980ce4e5faa629ba6a7341c83ff772..c8c7c0f978af69b7ec676796fdc389c60012e9a6 100644 (file)
@@ -82,6 +82,13 @@ class AsyncAdapt_aiomysql_cursor:
         return self._cursor.lastrowid
 
     def close(self):
+        # note we aren't actually closing the cursor here,
+        # we are just letting GC do it.   to allow this to be async
+        # we would need the Result to change how it does "Safe close cursor".
+        # MySQL "cursors" don't actually have state to be "closed" besides
+        # exhausting rows, which we already have done for sync cursor.
+        # another option would be to emulate aiosqlite dialect and assign
+        # cursor only if we are doing server side cursor operation.
         self._rows[:] = []
 
     def execute(self, operation, parameters=None):
index d12203cbd7b8590d78765baf8f400ef9180fe297..8b24a19fd5eb183b57b29d1bde4a00107cd97911 100644 (file)
@@ -26,6 +26,10 @@ from .base import TIMESTAMP
 from .base import VARCHAR
 from .dml import Insert
 from .dml import insert
+from ...util import compat
+
+if compat.py3k:
+    from . import aiosqlite  # noqa
 
 # default dialect
 base.dialect = dialect = pysqlite.dialect
diff --git a/lib/sqlalchemy/dialects/sqlite/aiosqlite.py b/lib/sqlalchemy/dialects/sqlite/aiosqlite.py
new file mode 100644 (file)
index 0000000..e4b7d1d
--- /dev/null
@@ -0,0 +1,331 @@
+# sqlite/aiosqlite.py
+# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+r"""
+
+.. dialect:: sqlite+aiosqlite
+    :name: aiosqlite
+    :dbapi: aiosqlite
+    :connectstring: sqlite+aiosqlite:///file_path
+    :url: https://pypi.org/project/aiosqlite/
+
+The aiosqlite dialect provides support for the SQLAlchemy asyncio interface
+running on top of pysqlite.
+
+aiosqlite is a wrapper around pysqlite that uses a background thread for
+each connection.   It does not actually use non-blocking IO, as SQLite
+databases are not socket-based.  However it does provide a working asyncio
+interface that's useful for testing and prototyping purposes.
+
+Using a special asyncio mediation layer, the aiosqlite dialect is usable
+as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>`
+extension package.
+
+This dialect should normally be used only with the
+:func:`_asyncio.create_async_engine` engine creation function::
+
+    from sqlalchemy.ext.asyncio import create_async_engine
+    engine = create_async_engine("sqlite+aiosqlite:///filename")
+
+The URL passes through all arguments to the ``pysqlite`` driver, so all
+connection arguments are the same as they are for that of :ref:`pysqlite`.
+
+
+"""  # noqa
+
+from .base import SQLiteExecutionContext
+from .pysqlite import SQLiteDialect_pysqlite
+from ... import pool
+from ... import util
+from ...util.concurrency import await_fallback
+from ...util.concurrency import await_only
+
+
+class AsyncAdapt_aiosqlite_cursor:
+    __slots__ = (
+        "_adapt_connection",
+        "_connection",
+        "description",
+        "await_",
+        "_rows",
+        "arraysize",
+        "rowcount",
+        "lastrowid",
+    )
+
+    server_side = False
+
+    def __init__(self, adapt_connection):
+        self._adapt_connection = adapt_connection
+        self._connection = adapt_connection._connection
+        self.await_ = adapt_connection.await_
+        self.arraysize = 1
+        self.rowcount = -1
+        self.description = None
+        self._rows = []
+
+    def close(self):
+        self._rows[:] = []
+
+    def execute(self, operation, parameters=None):
+        try:
+            _cursor = self.await_(self._connection.cursor())
+
+            if parameters is None:
+                self.await_(_cursor.execute(operation))
+            else:
+                self.await_(_cursor.execute(operation, parameters))
+
+            if _cursor.description:
+                self.description = _cursor.description
+                self.lastrowid = self.rowcount = -1
+
+                if not self.server_side:
+                    self._rows = self.await_(_cursor.fetchall())
+            else:
+                self.description = None
+                self.lastrowid = _cursor.lastrowid
+                self.rowcount = _cursor.rowcount
+
+            if not self.server_side:
+                self.await_(_cursor.close())
+            else:
+                self._cursor = _cursor
+        except Exception as error:
+            self._adapt_connection._handle_exception(error)
+
+    def executemany(self, operation, seq_of_parameters):
+        try:
+            _cursor = self.await_(self._connection.cursor())
+            self.await_(_cursor.executemany(operation, seq_of_parameters))
+            self.description = None
+            self.lastrowid = _cursor.lastrowid
+            self.rowcount = _cursor.rowcount
+            self.await_(_cursor.close())
+        except Exception as error:
+            self._adapt_connection._handle_exception(error)
+
+    def setinputsizes(self, *inputsizes):
+        pass
+
+    def __iter__(self):
+        while self._rows:
+            yield self._rows.pop(0)
+
+    def fetchone(self):
+        if self._rows:
+            return self._rows.pop(0)
+        else:
+            return None
+
+    def fetchmany(self, size=None):
+        if size is None:
+            size = self.arraysize
+
+        retval = self._rows[0:size]
+        self._rows[:] = self._rows[size:]
+        return retval
+
+    def fetchall(self):
+        retval = self._rows[:]
+        self._rows[:] = []
+        return retval
+
+
+class AsyncAdapt_aiosqlite_ss_cursor(AsyncAdapt_aiosqlite_cursor):
+    __slots__ = "_cursor"
+
+    server_side = True
+
+    def __init__(self, *arg, **kw):
+        super().__init__(*arg, **kw)
+        self._cursor = None
+
+    def close(self):
+        if self._cursor is not None:
+            self.await_(self._cursor.close())
+            self._cursor = None
+
+    def fetchone(self):
+        return self.await_(self._cursor.fetchone())
+
+    def fetchmany(self, size=None):
+        if size is None:
+            size = self.arraysize
+        return self.await_(self._cursor.fetchmany(size=size))
+
+    def fetchall(self):
+        return self.await_(self._cursor.fetchall())
+
+
+class AsyncAdapt_aiosqlite_connection:
+    await_ = staticmethod(await_only)
+    __slots__ = ("dbapi", "_connection")
+
+    def __init__(self, dbapi, connection):
+        self.dbapi = dbapi
+        self._connection = connection
+
+    @property
+    def isolation_level(self):
+        return self._connection.isolation_level
+
+    @isolation_level.setter
+    def isolation_level(self, value):
+        try:
+            self._connection.isolation_level = value
+        except Exception as error:
+            self._handle_exception(error)
+
+    def create_function(self, *args, **kw):
+        try:
+            self.await_(self._connection.create_function(*args, **kw))
+        except Exception as error:
+            self._handle_exception(error)
+
+    def cursor(self, server_side=False):
+        if server_side:
+            return AsyncAdapt_aiosqlite_ss_cursor(self)
+        else:
+            return AsyncAdapt_aiosqlite_cursor(self)
+
+    def execute(self, *args, **kw):
+        return self.await_(self._connection.execute(*args, **kw))
+
+    def rollback(self):
+        try:
+            self.await_(self._connection.rollback())
+        except Exception as error:
+            self._handle_exception(error)
+
+    def commit(self):
+        try:
+            self.await_(self._connection.commit())
+        except Exception as error:
+            self._handle_exception(error)
+
+    def close(self):
+        # print(">close", self)
+        try:
+            self.await_(self._connection.close())
+        except Exception as error:
+            self._handle_exception(error)
+
+    def _handle_exception(self, error):
+        if (
+            isinstance(error, ValueError)
+            and error.args[0] == "no active connection"
+        ):
+            util.raise_(
+                self.dbapi.sqlite.OperationalError("no active connection"),
+                from_=error,
+            )
+        else:
+            raise error
+
+
+class AsyncAdaptFallback_aiosqlite_connection(AsyncAdapt_aiosqlite_connection):
+    __slots__ = ()
+
+    await_ = staticmethod(await_fallback)
+
+
+class AsyncAdapt_aiosqlite_dbapi:
+    def __init__(self, aiosqlite, sqlite):
+        self.aiosqlite = aiosqlite
+        self.sqlite = sqlite
+        self.paramstyle = "qmark"
+        self._init_dbapi_attributes()
+
+    def _init_dbapi_attributes(self):
+        for name in (
+            "DatabaseError",
+            "Error",
+            "IntegrityError",
+            "NotSupportedError",
+            "OperationalError",
+            "ProgrammingError",
+            "sqlite_version",
+            "sqlite_version_info",
+        ):
+            setattr(self, name, getattr(self.aiosqlite, name))
+
+        for name in ("PARSE_COLNAMES", "PARSE_DECLTYPES"):
+            setattr(self, name, getattr(self.sqlite, name))
+
+        for name in ("Binary",):
+            setattr(self, name, getattr(self.sqlite, name))
+
+    def connect(self, *arg, **kw):
+        async_fallback = kw.pop("async_fallback", False)
+
+        # Q. WHY do we need this?
+        # A. Because there is no way to set connection.isolation_level
+        #    otherwise
+        # Q. BUT HOW do you know it is SAFE ?????
+        # A. The only operation that isn't safe is the isolation level set
+        #    operation which aiosqlite appears to have let slip through even
+        #    though pysqlite appears to do check_same_thread for this.
+        #    All execute operations etc. should be safe because they all
+        #    go through the single executor thread.
+
+        kw["check_same_thread"] = False
+
+        connection = self.aiosqlite.connect(*arg, **kw)
+
+        # it's a Thread.   you'll thank us later
+        connection.daemon = True
+
+        if util.asbool(async_fallback):
+            return AsyncAdaptFallback_aiosqlite_connection(
+                self,
+                await_fallback(connection),
+            )
+        else:
+            return AsyncAdapt_aiosqlite_connection(
+                self,
+                await_only(connection),
+            )
+
+
+class SQLiteExecutionContext_aiosqlite(SQLiteExecutionContext):
+    def create_server_side_cursor(self):
+        return self._dbapi_connection.cursor(server_side=True)
+
+
+class SQLiteDialect_aiosqlite(SQLiteDialect_pysqlite):
+    driver = "aiosqlite"
+
+    is_async = True
+
+    supports_server_side_cursors = True
+
+    execution_ctx_cls = SQLiteExecutionContext_aiosqlite
+
+    @classmethod
+    def dbapi(cls):
+        return AsyncAdapt_aiosqlite_dbapi(
+            __import__("aiosqlite"), __import__("sqlite3")
+        )
+
+    @classmethod
+    def get_pool_class(cls, url):
+        if cls._is_url_file_db(url):
+            return pool.NullPool
+        else:
+            return pool.StaticPool
+
+    def is_disconnect(self, e, connection, cursor):
+        if isinstance(
+            e, self.dbapi.OperationalError
+        ) and "no active connection" in str(e):
+            return True
+
+        return super().is_disconnect(e, connection, cursor)
+
+
+dialect = SQLiteDialect_aiosqlite
index a481be27ef0cb0849e5334d4d2a0abb772a35709..d0d12695da38f7b7d690966a8292e5a64d8f8a5a 100644 (file)
@@ -11,13 +11,22 @@ from ...testing.provision import stop_test_class_outside_fixtures
 from ...testing.provision import temp_table_keyword_args
 
 
+# likely needs a generate_driver_url() def here for the --dbdriver part to
+# work
+
+_drivernames = set()
+
+
 @follower_url_from_main.for_db("sqlite")
 def _sqlite_follower_url_from_main(url, ident):
     url = sa_url.make_url(url)
     if not url.database or url.database == ":memory:":
         return url
     else:
-        return sa_url.make_url("sqlite:///%s.db" % ident)
+        _drivernames.add(url.get_driver_name())
+        return sa_url.make_url(
+            "sqlite+%s:///%s.db" % (url.get_driver_name(), ident)
+        )
 
 
 @post_configure_engine.for_db("sqlite")
@@ -35,12 +44,13 @@ def _sqlite_post_configure_engine(url, engine, follower_ident):
             # expected to be already present, so for now it just stays
             # in a given checkout directory.
             dbapi_connection.execute(
-                'ATTACH DATABASE "test_schema.db" AS test_schema'
+                'ATTACH DATABASE "%s_test_schema.db" AS test_schema'
+                % (engine.driver,)
             )
         else:
             dbapi_connection.execute(
-                'ATTACH DATABASE "%s_test_schema.db" AS test_schema'
-                % follower_ident
+                'ATTACH DATABASE "%s_%s_test_schema.db" AS test_schema'
+                % (follower_ident, engine.driver)
             )
 
 
@@ -51,7 +61,10 @@ def _sqlite_create_db(cfg, eng, ident):
 
 @drop_db.for_db("sqlite")
 def _sqlite_drop_db(cfg, eng, ident):
-    for path in ["%s.db" % ident, "%s_test_schema.db" % ident]:
+    for path in [
+        "%s.db" % ident,
+        "%s_%s_test_schema.db" % (ident, eng.driver),
+    ]:
         if os.path.exists(path):
             log.info("deleting SQLite database file: %s" % path)
             os.remove(path)
@@ -71,9 +84,9 @@ def stop_test_class_outside_fixtures(config, db, cls):
 
         # some sqlite file tests are not cleaning up well yet, so do this
         # just to make things simple for now
-        for file in files:
-            if file:
-                os.remove(file)
+        for file_ in files:
+            if file_ and os.path.exists(file_):
+                os.remove(file_)
 
 
 @temp_table_keyword_args.for_db("sqlite")
@@ -89,7 +102,19 @@ def _reap_sqlite_dbs(url, idents):
     for ident in idents:
         # we don't have a config so we can't call _sqlite_drop_db due to the
         # decorator
-        for path in ["%s.db" % ident, "%s_test_schema.db" % ident]:
+        for path in (
+            [
+                "%s.db" % ident,
+            ]
+            + [
+                "%s_test_schema.db" % (drivername,)
+                for drivername in _drivernames
+            ]
+            + [
+                "%s_%s_test_schema.db" % (ident, drivername)
+                for drivername in _drivernames
+            ]
+        ):
             if os.path.exists(path):
                 log.info("deleting SQLite database file: %s" % path)
                 os.remove(path)
index 6ec4896046d6c5061ce3b6b4887e4be483177401..d14316fdbeabfd9901d6c4a47b702d5d490a9665 100644 (file)
@@ -574,6 +574,13 @@ class _ConnectionRecord(object):
             self.__connect()
         return self.connection
 
+    def _is_hard_or_soft_invalidated(self):
+        return (
+            self.connection is None
+            or self.__pool._invalidate_time > self.starttime
+            or (self._soft_invalidate_time > self.starttime)
+        )
+
     def __close(self):
         self.finalize_callback.clear()
         if self.__pool.dispatch.close:
index 08371a31a4f9c905628cf615580065daa2835ef7..730293273adaa94315a23fd4b204c85179aab1d6 100644 (file)
@@ -395,15 +395,11 @@ class StaticPool(Pool):
     """A Pool of exactly one connection, used for all requests.
 
     Reconnect-related functions such as ``recycle`` and connection
-    invalidation (which is also used to support auto-reconnect) are not
-    currently supported by this Pool implementation but may be implemented
-    in a future release.
+    invalidation (which is also used to support auto-reconnect) are only
+    partially supported right now and may not yield good results.
 
-    """
 
-    @util.memoized_property
-    def _conn(self):
-        return self._creator()
+    """
 
     @util.memoized_property
     def connection(self):
@@ -413,9 +409,12 @@ class StaticPool(Pool):
         return "StaticPool"
 
     def dispose(self):
-        if "_conn" in self.__dict__:
-            self._conn.close()
-            self._conn = None
+        if (
+            "connection" in self.__dict__
+            and self.connection.connection is not None
+        ):
+            self.connection.close()
+            del self.__dict__["connection"]
 
     def recreate(self):
         self.logger.info("Pool recreating")
@@ -430,14 +429,26 @@ class StaticPool(Pool):
             dialect=self._dialect,
         )
 
+    def _transfer_from(self, other_static_pool):
+        # used by the test suite to make a new engine / pool without
+        # losing the state of an existing SQLite :memory: connection
+        self._invoke_creator = (
+            lambda crec: other_static_pool.connection.connection
+        )
+
     def _create_connection(self):
-        return self._conn
+        raise NotImplementedError()
 
     def _do_return_conn(self, conn):
         pass
 
     def _do_get(self):
-        return self.connection
+        rec = self.connection
+        if rec._is_hard_or_soft_invalidated():
+            del self.__dict__["connection"]
+            rec = self.connection
+
+        return rec
 
 
 class AssertionPool(Pool):
index a313c298a522eac1f4104b9332d209229d4c16e4..3faf9685757fdfecc9a7932400863af5feaa9066 100644 (file)
@@ -266,7 +266,13 @@ def reconnecting_engine(url=None, options=None):
     return engine
 
 
-def testing_engine(url=None, options=None, future=None, asyncio=False):
+def testing_engine(
+    url=None,
+    options=None,
+    future=None,
+    asyncio=False,
+    transfer_staticpool=False,
+):
     """Produce an engine configured by --options with optional overrides."""
 
     if asyncio:
@@ -300,6 +306,12 @@ def testing_engine(url=None, options=None, future=None, asyncio=False):
 
     engine = create_engine(url, **options)
 
+    if transfer_staticpool:
+        from sqlalchemy.pool import StaticPool
+
+        if config.db is not None and isinstance(config.db.pool, StaticPool):
+            engine.pool._transfer_from(config.db.pool)
+
     if scope == "global":
         if asyncio:
             engine.sync_engine._has_events = True
index f47277b4aea4f2042ccab1be6edf724544f73029..c3eb1b36392bc69c54cf578f562bec22b482ae16 100644 (file)
@@ -50,6 +50,13 @@ class TestBase(object):
     def assert_(self, val, msg=None):
         assert val, msg
 
+    @config.fixture()
+    def connection_no_trans(self):
+        eng = getattr(self, "bind", None) or config.db
+
+        with eng.connect() as conn:
+            yield conn
+
     @config.fixture()
     def connection(self):
         global _connection_fixture_connection
index de2b8f12c3d0a84a08a396921e92be5faf5b1f68..208ba00919cff6e9a9ffe6ae69ad3c5a7861dab5 100644 (file)
@@ -18,6 +18,7 @@ to provide specific inclusion/exclusions.
 import platform
 import sys
 
+from sqlalchemy.pool.impl import QueuePool
 from . import exclusions
 from .. import util
 
@@ -116,6 +117,15 @@ class SuiteRequirements(Requirements):
             or self.deferrable_fks.enabled
         )
 
+    @property
+    def queue_pool(self):
+        """target database is using QueuePool"""
+
+        def go(config):
+            return isinstance(config.db.pool, QueuePool)
+
+        return exclusions.only_if(go)
+
     @property
     def self_referential_foreign_keys(self):
         """Target database must support self-referential foreign keys."""
index a236b10769b6313633883cdca99cf9185dc3c962..c2c17d0ddd1cfd43fa64f2063e6a2d19e416abcd 100644 (file)
@@ -180,8 +180,8 @@ class AutocommitIsolationTest(fixtures.TablesTest):
         with conn.begin():
             conn.execute(self.tables.some_table.delete())
 
-    def test_autocommit_on(self):
-        conn = config.db.connect()
+    def test_autocommit_on(self, connection_no_trans):
+        conn = connection_no_trans
         c2 = conn.execution_options(isolation_level="AUTOCOMMIT")
         self._test_conn_autocommits(c2, True)
 
@@ -189,12 +189,14 @@ class AutocommitIsolationTest(fixtures.TablesTest):
 
         self._test_conn_autocommits(conn, False)
 
-    def test_autocommit_off(self):
-        conn = config.db.connect()
+    def test_autocommit_off(self, connection_no_trans):
+        conn = connection_no_trans
         self._test_conn_autocommits(conn, False)
 
-    def test_turn_autocommit_off_via_default_iso_level(self):
-        conn = config.db.connect()
+    def test_turn_autocommit_off_via_default_iso_level(
+        self, connection_no_trans
+    ):
+        conn = connection_no_trans
         conn = conn.execution_options(isolation_level="AUTOCOMMIT")
         self._test_conn_autocommits(conn, True)
 
index e8dd6cf2c9e08bd517eee84aebadf28274b6d5fe..6c2880ad48cb034557dd60d9d2b61a5ed01a6034 100644 (file)
@@ -227,6 +227,8 @@ class ServerSideCursorsTest(
     __backend__ = True
 
     def _is_server_side(self, cursor):
+        # TODO: this is a huge issue as it prevents these tests from being
+        # usable by third party dialects.
         if self.engine.dialect.driver == "psycopg2":
             return bool(cursor.name)
         elif self.engine.dialect.driver == "pymysql":
@@ -239,7 +241,7 @@ class ServerSideCursorsTest(
             return isinstance(cursor, sscursor)
         elif self.engine.dialect.driver == "mariadbconnector":
             return not cursor.buffered
-        elif self.engine.dialect.driver == "asyncpg":
+        elif self.engine.dialect.driver in ("asyncpg", "aiosqlite"):
             return cursor.server_side
         else:
             return False
@@ -279,7 +281,14 @@ class ServerSideCursorsTest(
             False,
         ),
         ("for_update_expr", True, select(1).with_for_update(), True),
-        ("for_update_string", True, "SELECT 1 FOR UPDATE", True),
+        # TODO: need a real requirement for this, or dont use this test
+        (
+            "for_update_string",
+            True,
+            "SELECT 1 FOR UPDATE",
+            True,
+            testing.skip_if("sqlite"),
+        ),
         ("text_no_ss", False, text("select 42"), False),
         (
             "text_ss_option",
index c44efba6202ba301b9d17f07623da3a670205d13..e26f305d940f2ce98a5bdd655554c22a4f3a6ceb 100644 (file)
@@ -32,7 +32,7 @@ if not have_greenlet:
             )
 
     def await_only(thing):  # noqa F811
-        return thing
+        _not_implemented()
 
     def await_fallback(thing):  # noqa F81
         return thing
index fd196f4f56cf42cd968bd63ee50294e49f6fa2ba..82251e38fab7e10cbfd0bbac477ac1483e9dc3ba 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -74,6 +74,9 @@ pymysql =
 aiomysql =
     %(asyncio)s
     aiomysql;python_version>="3"
+aiosqlite =
+    %(asyncio)s
+    aiosqlite;python_version>="3"
 
 [egg_info]
 tag_build = dev
@@ -136,7 +139,9 @@ oracle_db_link = test_link
 [db]
 default = sqlite:///:memory:
 sqlite = sqlite:///:memory:
+aiosqlite = sqlite+aiosqlite:///:memory:
 sqlite_file = sqlite:///querytest.db
+aiosqlite_file = sqlite+aiosqlite:///async_querytest.db
 postgresql = postgresql://scott:tiger@127.0.0.1:5432/test
 asyncpg = postgresql+asyncpg://scott:tiger@127.0.0.1:5432/test
 asyncpg_fallback = postgresql+asyncpg://scott:tiger@127.0.0.1:5432/test?async_fallback=true
index aee97e8c62315d1fb3a808e0a26857b9620d4564..00bf109c026bfa39fcbcc41220f88f6a84aa1724 100644 (file)
@@ -774,7 +774,7 @@ class AttachedDBTest(fixtures.TestBase):
 
     def _fixture(self):
         meta = self.metadata
-        self.conn = self.engine.connect()
+        self.conn = self.engine.connect()
         Table("created", meta, Column("foo", Integer), Column("bar", String))
         Table("local_only", meta, Column("q", Integer), Column("p", Integer))
 
index 7b997a017c9ee2667268ad98ff6d313a17bef836..ce91f76a3daa8b33bf4a39fbd9f5062e910a0f27 100644 (file)
@@ -2719,7 +2719,7 @@ class HandleErrorTest(fixtures.TestBase):
             ):
                 assert_raises(MySpecialException, conn.get_isolation_level)
 
-    @testing.only_on("sqlite")
+    @testing.only_on("sqlite+pysqlite")
     def test_cursor_close_resultset_failed_connectionless(self):
         engine = engines.testing_engine()
 
@@ -2755,7 +2755,7 @@ class HandleErrorTest(fixtures.TestBase):
         # connection is closed
         assert the_conn[0].closed
 
-    @testing.only_on("sqlite")
+    @testing.only_on("sqlite+pysqlite")
     def test_cursor_close_resultset_failed_explicit(self):
         engine = engines.testing_engine()
 
index 3c2257331daf325355f292080b94de2851695ef4..1c388050475bd1bdb984e23752da29323d62e673 100644 (file)
@@ -329,7 +329,7 @@ class PoolDialectTest(PoolTestBase):
         self._do_test(pool.NullPool, ["R", "CL", "R", "CL"])
 
     def test_static_pool(self):
-        self._do_test(pool.StaticPool, ["R", "R"])
+        self._do_test(pool.StaticPool, ["R", "CL", "R"])
 
 
 class PoolEventsTest(PoolTestBase):
@@ -1960,6 +1960,21 @@ class StaticPoolTest(PoolTestBase):
         p2 = p.recreate()
         assert p._creator is p2._creator
 
+    def test_connect(self):
+        dbapi = MockDBAPI()
+
+        def creator():
+            return dbapi.connect("foo.db")
+
+        p = pool.StaticPool(creator)
+
+        c1 = p.connect()
+        conn = c1.connection
+        c1.close()
+
+        c2 = p.connect()
+        is_(conn, c2.connection)
+
 
 class CreatorCompatibilityTest(PoolTestBase):
     def test_creator_callable_outside_noarg(self):
index 19b9a84a495c8b74737eacc12afaaac2d3a55c64..04cf5440cf887929b13a7cc4ce8acc047081a8b4 100644 (file)
@@ -1383,6 +1383,7 @@ class InvalidateDuringResultTest(fixtures.TestBase):
             "+pymysql",
             "+pg8000",
             "+asyncpg",
+            "+aiosqlite",
             "+aiomysql",
         ],
         "Buffers the result set and doesn't check for connection close",
index 512f0447f8081baed5f4cb1d0c9169659a90c3f4..dc791215cf2d1371c933b7f657ec721565155db2 100644 (file)
@@ -16,6 +16,7 @@ from sqlalchemy import union_all
 from sqlalchemy.ext.asyncio import create_async_engine
 from sqlalchemy.ext.asyncio import engine as _async_engine
 from sqlalchemy.ext.asyncio import exc as asyncio_exc
+from sqlalchemy.pool import AsyncAdaptedQueuePool
 from sqlalchemy.testing import async_test
 from sqlalchemy.testing import combinations
 from sqlalchemy.testing import engines
@@ -38,7 +39,7 @@ class EngineFixture(fixtures.TablesTest):
 
     @testing.fixture
     def async_engine(self):
-        return engines.testing_engine(asyncio=True)
+        return engines.testing_engine(asyncio=True, transfer_staticpool=True)
 
     @classmethod
     def define_tables(cls, metadata):
@@ -80,7 +81,9 @@ class AsyncEngineTest(EngineFixture):
     @async_test
     async def test_engine_eq_ne(self, async_engine):
         e2 = _async_engine.AsyncEngine(async_engine.sync_engine)
-        e3 = testing.engines.testing_engine(asyncio=True)
+        e3 = testing.engines.testing_engine(
+            asyncio=True, transfer_staticpool=True
+        )
 
         eq_(async_engine, e2)
         ne_(async_engine, e3)
@@ -197,6 +200,7 @@ class AsyncEngineTest(EngineFixture):
             is_false(conn.in_transaction())
             is_false(conn.in_nested_transaction())
 
+    @testing.requires.queue_pool
     @async_test
     async def test_invalidate(self, async_engine):
         conn = await async_engine.connect()
@@ -254,6 +258,9 @@ class AsyncEngineTest(EngineFixture):
 
         eq_(isolation_level, "SERIALIZABLE")
 
+        await conn.close()
+
+    @testing.requires.queue_pool
     @async_test
     async def test_dispose(self, async_engine):
         c1 = await async_engine.connect()
@@ -263,12 +270,16 @@ class AsyncEngineTest(EngineFixture):
         await c2.close()
 
         p1 = async_engine.pool
-        eq_(async_engine.pool.checkedin(), 2)
+
+        if isinstance(p1, AsyncAdaptedQueuePool):
+            eq_(async_engine.pool.checkedin(), 2)
 
         await async_engine.dispose()
-        eq_(async_engine.pool.checkedin(), 0)
+        if isinstance(p1, AsyncAdaptedQueuePool):
+            eq_(async_engine.pool.checkedin(), 0)
         is_not(p1, async_engine.pool)
 
+    @testing.requires.independent_connections
     @async_test
     async def test_init_once_concurrency(self, async_engine):
         c1 = async_engine.connect()
@@ -362,6 +373,7 @@ class AsyncEngineTest(EngineFixture):
             ):
                 await trans.rollback(),
 
+    @testing.requires.queue_pool
     @async_test
     async def test_pool_exhausted_some_timeout(self, async_engine):
         engine = create_async_engine(
@@ -374,6 +386,7 @@ class AsyncEngineTest(EngineFixture):
             with expect_raises(exc.TimeoutError):
                 await engine.connect()
 
+    @testing.requires.queue_pool
     @async_test
     async def test_pool_exhausted_no_timeout(self, async_engine):
         engine = create_async_engine(
index d308764fbcc11062180e589d21e6f4c0e7ab0f99..032176ea6d60eca764582ac8537edab8d6060ac5 100644 (file)
@@ -25,7 +25,7 @@ class AsyncFixture(_fixtures.FixtureTest):
 
     @testing.fixture
     def async_engine(self):
-        return engines.testing_engine(asyncio=True)
+        return engines.testing_engine(asyncio=True, transfer_staticpool=True)
 
     @testing.fixture
     def async_session(self, async_engine):
index 7f77b01c781e2c6a9f1dd28dbaf27e822ab95d33..0e49ff2c349e64e5e4e8e6b196f81a7fae5bcdcf 100644 (file)
@@ -41,50 +41,58 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest):
     run_inserts = None
     __backend__ = True
 
-    def test_no_close_transaction_on_flush(self):
+    @testing.fixture
+    def conn(self):
+        with testing.db.connect() as conn:
+            yield conn
+
+    @testing.fixture
+    def future_conn(self):
+
+        engine = Engine._future_facade(testing.db)
+        with engine.connect() as conn:
+            yield conn
+
+    def test_no_close_transaction_on_flush(self, conn):
         User, users = self.classes.User, self.tables.users
 
-        with testing.db.connect() as c:
-            mapper(User, users)
-            s = Session(bind=c)
-            s.begin()
-            tran = s._legacy_transaction()
-            s.add(User(name="first"))
-            s.flush()
-            c.exec_driver_sql("select * from users")
-            u = User(name="two")
-            s.add(u)
-            s.flush()
-            u = User(name="third")
-            s.add(u)
-            s.flush()
-            assert s._legacy_transaction() is tran
-            tran.close()
+        c = conn
+        mapper(User, users)
+        s = Session(bind=c)
+        s.begin()
+        tran = s._legacy_transaction()
+        s.add(User(name="first"))
+        s.flush()
+        c.exec_driver_sql("select * from users")
+        u = User(name="two")
+        s.add(u)
+        s.flush()
+        u = User(name="third")
+        s.add(u)
+        s.flush()
+        assert s._legacy_transaction() is tran
+        tran.close()
 
-    @engines.close_open_connections
-    def test_subtransaction_on_external_subtrans(self):
+    def test_subtransaction_on_external_subtrans(self, conn):
         users, User = self.tables.users, self.classes.User
 
         mapper(User, users)
 
-        with testing.db.connect() as conn:
-            trans = conn.begin()
-            sess = Session(bind=conn, autocommit=False, autoflush=True)
-            sess.begin(subtransactions=True)
-            u = User(name="ed")
-            sess.add(u)
-            sess.flush()
-            sess.commit()  # commit does nothing
-            trans.rollback()  # rolls back
-            assert len(sess.query(User).all()) == 0
-            sess.close()
+        trans = conn.begin()
+        sess = Session(bind=conn, autocommit=False, autoflush=True)
+        sess.begin(subtransactions=True)
+        u = User(name="ed")
+        sess.add(u)
+        sess.flush()
+        sess.commit()  # commit does nothing
+        trans.rollback()  # rolls back
+        assert len(sess.query(User).all()) == 0
+        sess.close()
 
-    @engines.close_open_connections
-    def test_subtransaction_on_external_no_begin(self):
+    def test_subtransaction_on_external_no_begin(self, conn):
         users, User = self.tables.users, self.classes.User
 
         mapper(User, users)
-        conn = testing.db.connect()
         trans = conn.begin()
         sess = Session(bind=conn, autocommit=False, autoflush=True)
         u = User(name="ed")
@@ -96,40 +104,31 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest):
         sess.close()
 
     @testing.requires.savepoints
-    @engines.close_open_connections
-    def test_external_nested_transaction(self):
+    def test_external_nested_transaction(self, conn):
         users, User = self.tables.users, self.classes.User
 
         mapper(User, users)
-        try:
-            conn = testing.db.connect()
-            trans = conn.begin()
-            sess = Session(bind=conn, autocommit=False, autoflush=True)
-            u1 = User(name="u1")
-            sess.add(u1)
-            sess.flush()
+        trans = conn.begin()
+        sess = Session(bind=conn, autocommit=False, autoflush=True)
+        u1 = User(name="u1")
+        sess.add(u1)
+        sess.flush()
 
-            savepoint = sess.begin_nested()
-            u2 = User(name="u2")
-            sess.add(u2)
-            sess.flush()
-            savepoint.rollback()
+        savepoint = sess.begin_nested()
+        u2 = User(name="u2")
+        sess.add(u2)
+        sess.flush()
+        savepoint.rollback()
 
-            trans.commit()
-            assert len(sess.query(User).all()) == 1
-        except Exception:
-            conn.close()
-            raise
+        trans.commit()
+        assert len(sess.query(User).all()) == 1
 
-    @engines.close_open_connections
-    def test_subtransaction_on_external_commit_future(self):
+    def test_subtransaction_on_external_commit_future(self, future_conn):
         users, User = self.tables.users, self.classes.User
 
         mapper(User, users)
 
-        engine = Engine._future_facade(testing.db)
-
-        conn = engine.connect()
+        conn = future_conn
         conn.begin()
 
         sess = Session(bind=conn, autocommit=False, autoflush=True)
@@ -141,15 +140,12 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest):
         assert len(sess.query(User).all()) == 0
         sess.close()
 
-    @engines.close_open_connections
-    def test_subtransaction_on_external_rollback_future(self):
+    def test_subtransaction_on_external_rollback_future(self, future_conn):
         users, User = self.tables.users, self.classes.User
 
         mapper(User, users)
 
-        engine = Engine._future_facade(testing.db)
-
-        conn = engine.connect()
+        conn = future_conn
         conn.begin()
 
         sess = Session(bind=conn, autocommit=False, autoflush=True)
@@ -162,29 +158,26 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest):
         sess.close()
 
     @testing.requires.savepoints
-    @engines.close_open_connections
-    def test_savepoint_on_external_future(self):
+    def test_savepoint_on_external_future(self, future_conn):
         users, User = self.tables.users, self.classes.User
 
         mapper(User, users)
 
-        engine = Engine._future_facade(testing.db)
-
-        with engine.connect() as conn:
-            conn.begin()
-            sess = Session(bind=conn, autocommit=False, autoflush=True)
-            u1 = User(name="u1")
-            sess.add(u1)
-            sess.flush()
+        conn = future_conn
+        conn.begin()
+        sess = Session(bind=conn, autocommit=False, autoflush=True)
+        u1 = User(name="u1")
+        sess.add(u1)
+        sess.flush()
 
-            sess.begin_nested()
-            u2 = User(name="u2")
-            sess.add(u2)
-            sess.flush()
-            sess.rollback()
+        sess.begin_nested()
+        u2 = User(name="u2")
+        sess.add(u2)
+        sess.flush()
+        sess.rollback()
 
-            conn.commit()
-            assert len(sess.query(User).all()) == 1
+        conn.commit()
+        assert len(sess.query(User).all()) == 1
 
     @testing.requires.savepoints
     def test_nested_accounting_new_items_removed(self):
@@ -1214,18 +1207,18 @@ class SubtransactionRecipeTest(FixtureTest):
         users, User = self.tables.users, self.classes.User
 
         mapper(User, users)
-        conn = testing.db.connect()
-        trans = conn.begin()
-        sess = Session(conn, future=self.future)
+        with testing.db.connect() as conn:
+            trans = conn.begin()
+            sess = Session(conn, future=self.future)
 
-        with subtransaction_recipe(sess):
-            u = User(name="ed")
-            sess.add(u)
-            sess.flush()
-            # commit does nothing
-        trans.rollback()  # rolls back
-        assert len(sess.query(User).all()) == 0
-        sess.close()
+            with subtransaction_recipe(sess):
+                u = User(name="ed")
+                sess.add(u)
+                sess.flush()
+                # commit does nothing
+            trans.rollback()  # rolls back
+            assert len(sess.query(User).all()) == 0
+            sess.close()
 
     def test_recipe_commit_one(self, subtransaction_recipe):
         User, users = self.classes.User, self.tables.users
index 33f5e7fd0612436daaf4ef21e11bf1d5557ac3c8..971ed07733470bc59f61d8a4bcced0cda987cc80 100644 (file)
@@ -889,6 +889,7 @@ class DefaultRequirements(SuiteRequirements):
         return fails_on_everything_except(
             "mysql",
             "mariadb",
+            "sqlite+aiosqlite",
             "sqlite+pysqlite",
             "sqlite+pysqlcipher",
             "sybase",
@@ -925,6 +926,7 @@ class DefaultRequirements(SuiteRequirements):
             "mysql",
             "mariadb",
             "sqlite+pysqlite",
+            "sqlite+aiosqlite",
             "sqlite+pysqlcipher",
             "mssql",
         )
diff --git a/tox.ini b/tox.ini
index 96604887756c5fd39848ad0deaf11cf792a9deb7..e01b82372fdf5e9072ac498936a3882365620a40 100644 (file)
--- a/tox.ini
+++ b/tox.ini
@@ -23,6 +23,8 @@ deps=
      mock; python_version < '3.3'
      importlib_metadata; python_version < '3.8'
 
+     sqlite: .[aiosqlite]
+     sqlite_file: .[aiosqlite]
      postgresql: .[postgresql]
      postgresql: .[postgresql_asyncpg]; python_version >= '3'
      postgresql: .[postgresql_pg8000]; python_version >= '3'
@@ -85,6 +87,8 @@ setenv=
 
     sqlite: SQLITE={env:TOX_SQLITE:--db sqlite}
     sqlite_file: SQLITE={env:TOX_SQLITE_FILE:--db sqlite_file}
+    py3{,5,6,7,8,9,10,11}-sqlite: EXTRA_SQLITE_DRIVERS={env:EXTRA_SQLITE_DRIVERS:--dbdriver sqlite --dbdriver aiosqlite}
+    py3{,5,6,7,8,9,10,11}-sqlite_file: EXTRA_SQLITE_DRIVERS={env:EXTRA_SQLITE_DRIVERS:--dbdriver sqlite --dbdriver aiosqlite}
 
     postgresql: POSTGRESQL={env:TOX_POSTGRESQL:--db postgresql}
     py2{,7}-postgresql: POSTGRESQL={env:TOX_POSTGRESQL_PY2K:{env:TOX_POSTGRESQL:--db postgresql}}
@@ -104,13 +108,13 @@ setenv=
 # tox as of 2.0 blocks all environment variables from the
 # outside, unless they are here (or in TOX_TESTENV_PASSENV,
 # wildcards OK).  Need at least these
-passenv=ORACLE_HOME NLS_LANG TOX_POSTGRESQL TOX_POSTGRESQL_PY2K TOX_MYSQL TOX_MYSQL_PY2K TOX_ORACLE TOX_MSSQL TOX_SQLITE TOX_SQLITE_FILE TOX_WORKERS EXTRA_PG_DRIVERS EXTRA_MYSQL_DRIVERS
+passenv=ORACLE_HOME NLS_LANG TOX_POSTGRESQL TOX_POSTGRESQL_PY2K TOX_MYSQL TOX_MYSQL_PY2K TOX_ORACLE TOX_MSSQL TOX_SQLITE TOX_SQLITE_FILE TOX_WORKERS EXTRA_SQLITE_DRIVERS EXTRA_PG_DRIVERS EXTRA_MYSQL_DRIVERS
 
 # for nocext, we rm *.so in lib in case we are doing usedevelop=True
 commands=
   cext: /bin/true
   nocext: sh -c "rm -f lib/sqlalchemy/*.so"
-  {env:BASECOMMAND} {env:WORKERS} {env:SQLITE:} {env:POSTGRESQL:} {env:EXTRA_PG_DRIVERS:} {env:MYSQL:} {env:EXTRA_MYSQL_DRIVERS:} {env:ORACLE:} {env:MSSQL:} {env:BACKENDONLY:} {env:IDENTS:} {env:MEMUSAGE:} {env:COVERAGE:} {posargs}
+  {env:BASECOMMAND} {env:WORKERS} {env:SQLITE:} {env:EXTRA_SQLITE_DRIVERS:} {env:POSTGRESQL:} {env:EXTRA_PG_DRIVERS:} {env:MYSQL:} {env:EXTRA_MYSQL_DRIVERS:} {env:ORACLE:} {env:MSSQL:} {env:BACKENDONLY:} {env:IDENTS:} {env:MEMUSAGE:} {env:COVERAGE:} {posargs}
   oracle,mssql,sqlite_file: python reap_dbs.py db_idents.txt