--- /dev/null
+- orm
+ - the 'expire' option on query.update() has been renamed to 'fetch', thus matching
+ that of query.delete()
+ - query.update() and query.delete() both default to 'evaluate' for the synchronize
+ strategy.
+ - the 'synchronize' strategy for update() and delete() raises an error on failure.
+ There is no implicit fallback onto "fetch". Failure of evaluation is based
+ on the structure of criteria, so success/failure is deterministic based on
+ code structure.
+ - the "dont_load=True" flag on Session.merge() is deprecated and is now
+ "load=False".
+
+- sql
+ - returning() support is native to insert(), update(), delete(). Implementations
+ of varying levels of functionality exist for Postgresql, Firebird, MSSQL and
+ Oracle. returning() can be called explicitly with column expressions which
+ are then returned in the resultset, usually via fetchone() or first().
+
+ insert() constructs will also use RETURNING implicitly to get newly
+ generated primary key values, if the database version in use supports it
+ (a version number check is performed). This occurs if no end-user
+ returning() was specified.
+
+ - Databases which rely upon postfetch of "last inserted id" to get at a
+ generated sequence value (i.e. MySQL, MS-SQL) now work correctly
+ when there is a composite primary key where the "autoincrement" column
+ is not the first primary key column in the table.
+
+ - the last_inserted_ids() method has been renamed to the descriptor
+ "inserted_primary_key".
+
+- engines
+ - transaction isolation level may be specified with
+ create_engine(... isolation_level="..."); available on
+ postgresql and sqlite. [ticket:443]
+ - added first() method to ResultProxy, returns first row and closes
+ result set immediately.
+
+- schema
+ - deprecated metadata.connect() and threadlocalmetadata.connect() have been
+ removed - send the "bind" attribute to bind a metadata.
+ - deprecated metadata.table_iterator() method removed (use sorted_tables)
+ - the "metadata" argument is removed from DefaultGenerator and subclasses,
+ but remains locally present on Sequence, which is a standalone construct
+ in DDL.
+ - Removed public mutability from Index and Constraint objects:
+ - ForeignKeyConstraint.append_element()
+ - Index.append_column()
+ - UniqueConstraint.append_column()
+ - PrimaryKeyConstraint.add()
+ - PrimaryKeyConstraint.remove()
+ These should be constructed declaratively (i.e. in one construction).
+ - UniqueConstraint, Index, PrimaryKeyConstraint all accept lists
+ of column names or column objects as arguments.
+ - Other removed things:
+ - Table.key (no idea what this was for)
+ - Table.primary_key is not assignable - use table.append_constraint(PrimaryKeyConstraint(...))
+ - Column.bind (get via column.table.bind)
+ - Column.metadata (get via column.table.metadata)
+ - the use_alter flag on ForeignKey is now a shortcut option for operations that
+ can be hand-constructed using the DDL() event system. A side effect of this refactor
+ is that ForeignKeyConstraint objects with use_alter=True will *not* be emitted on
+ SQLite, which does not support ALTER for foreign keys. This has no effect on SQLite's
+ behavior since SQLite does not actually honor FOREIGN KEY constraints.
+
+- DDL
+ - the DDL() system has been greatly expanded:
+ - CreateTable()
+ - DropTable()
+ - AddConstraint()
+ - DropConstraint()
+ - CreateIndex()
+ - DropIndex()
+ - CreateSequence()
+ - DropSequence()
+ - these support "on" and "execute-at()" just like
+ plain DDL() does.
+ - the "on" callable passed to DDL() needs to accept **kw arguments.
+ In the case of MetaData before/after create/drop, the list of
+ Table objects for which CREATE/DROP DDL is to be issued is passed
+ as the kw argument "tables". This is necessary for metadata-level
+ DDL that is dependent on the presence of specific tables.
+
+- dialect refactor
+ - the "owner" keyword argument is removed from Table. Use "schema" to
+ represent any namespaces to be prepended to the table name.
+ - server_version_info becomes a static attribute.
+ - dialects receive an initialize() event on initial connection to
+ determine connection properties.
+ - dialects receive a visit_pool event have an opportunity to
+ establish pool listeners.
+ - cached TypeEngine classes are cached per-dialect class instead of per-dialect.
+ - Deprecated Dialect.get_params() removed.
+ - Dialect.get_rowcount() has been renamed to a descriptor "rowcount", and calls
+ cursor.rowcount directly. Dialects which need to hardwire a rowcount in for
+ certain calls should override the method to provide different behavior.
+ - functions and operators generated by the compiler now use (almost) regular
+ dispatch functions of the form "visit_<opname>" and "visit_<funcname>_fn"
+ to provide customed processing. This replaces the need to copy the "functions"
+ and "operators" dictionaries in compiler subclasses with straightforward
+ visitor methods, and also allows compiler subclasses complete control over
+ rendering, as the full _Function or _BinaryExpression object is passed in.
+
+- postgresql
+ - the "postgres" dialect is now named "postgresql" ! Connection strings look
+ like:
+
+ postgresql://scott:tiger@localhost/test
+ postgresql+pg8000://scott:tiger@localhost/test
+
+ The "postgres" name remains for backwards compatiblity in the following ways:
+
+ - There is a "postgres.py" dummy dialect which allows old URLs to work,
+ i.e. postgres://scott:tiger@localhost/test
+
+ - The "postgres" name can be imported from the old "databases" module,
+ i.e. "from sqlalchemy.databases import postgres" as well as "dialects",
+ "from sqlalchemy.dialects.postgres import base as pg", will send
+ a deprecation warning.
+
+ - Special expression arguments are now named "postgresql_returning"
+ and "postgresql_where", but the older "postgres_returning" and
+ "postgres_where" names still work with a deprecation warning.
+
+- mysql
+ - all the _detect_XXX() functions now run once underneath dialect.initialize()
+
+- oracle
+ - support for cx_Oracle's "native unicode" mode which does not require NLS_LANG
+ to be set. Use the latest 5.0.2 or later of cx_oracle.
+ - an NCLOB type is added to the base types.
+ - func.char_length is a generic function for LENGTH
+ - ForeignKey() which includes onupdate=<value> will emit a warning, not
+ emit ON UPDATE CASCADE which is unsupported by oracle
+ - the keys() method of RowProxy() now returns the result column names *normalized*
+ to be SQLAlchemy case insensitive names. This means they will be lower case
+ for case insensitive names, whereas the DBAPI would normally return them
+ as UPPERCASE names. This allows row keys() to be compatible with further
+ SQLAlchemy operations.
+
+- firebird
+ - the keys() method of RowProxy() now returns the result column names *normalized*
+ to be SQLAlchemy case insensitive names. This means they will be lower case
+ for case insensitive names, whereas the DBAPI would normally return them
+ as UPPERCASE names. This allows row keys() to be compatible with further
+ SQLAlchemy operations.
+
+- new dialects
+ - postgresql+pg8000
+ - postgresql+pypostgresql (partial)
+ - postgresql+zxjdbc
+ - mysql+pyodbc
+ - mysql+zxjdbc
+
+- mssql
+ - the "has_window_funcs" flag is removed. LIMIT/OFFSET usage will use ROW NUMBER as always,
+ and if on an older version of SQL Server, the operation fails. The behavior is exactly
+ the same except the error is raised by SQL server instead of the dialect, and no
+ flag setting is required to enable it.
+ - the "auto_identity_insert" flag is removed. This feature always takes effect
+ when an INSERT statement overrides a column that is known to have a sequence on it.
+ As with "has_window_funcs", if the underlying driver doesn't support this, then you
+ can't do this operation in any case, so there's no point in having a flag.
+ - using new dialect.initialize() feature to set up version-dependent behavior.
+
+- types
+ - PickleType now uses == for comparison of values when mutable=True,
+ unless the "comparator" argument with a comparsion function is specified to the type.
+ Objects being pickled will be compared based on identity (which defeats the purpose
+ of mutable=True) if __eq__() is not overridden or a comparison function is not provided.
+ - The default "precision" and "scale" arguments of Numeric and Float have been removed
+ and now default to None. NUMERIC and FLOAT will be rendered with no numeric arguments
+ by default unless these values are provided.
+ - AbstractType.get_search_list() is removed - the games that was used for are no
+ longer necessary.
+
+
\ No newline at end of file
- Repaired the printing of SQL exceptions which are not
based on parameters or are not executemany() style.
-- postgres
+- postgresql
- Deprecated the hardcoded TIMESTAMP function, which when
used as func.TIMESTAMP(value) would render "TIMESTAMP value".
This breaks on some platforms as PostgreSQL doesn't allow
fail on recent versions of pysqlite which raise
an error when fetchone() called with no rows present.
-- postgres
+- postgresql
- Index reflection won't fail when an index with
multiple expressions is encountered.
http://somethingaboutorange.com/mrl/projects/nose/0.11.1/index.html
-
SQLAlchemy implements a nose plugin that must be present when tests are run.
This plugin is available when SQLAlchemy is installed via setuptools.
+INSTANT TEST RUNNER
+-------------------
+
+A plain vanilla run of all tests using sqlite can be run via setup.py:
+
+ $ python setup.py test
+
+Setuptools will take care of the rest ! To run nose directly and have
+its full set of options available, read on...
+
SETUP
-----
Tests will target an in-memory SQLite database by default. To test against
another database, use the --dburi option with any standard SQLAlchemy URL:
- --dburi=postgres://user:password@localhost/test
+ --dburi=postgresql://user:password@localhost/test
-Use an empty database and a database user with general DBA privileges. The
-test suite will be creating and dropping many tables and other DDL, and
-preexisting tables will interfere with the tests
+Use an empty database and a database user with general DBA privileges.
+The test suite will be creating and dropping many tables and other DDL, and
+preexisting tables will interfere with the tests.
+
+IMPORTANT !: please see TIPS at the end if your are testing on POSTGRESQL,
+ORACLE, or MSSQL - additional steps are required to prepare a test database.
If you'll be running the tests frequently, database aliases can save a lot of
typing. The --dbs option lists the built-in aliases and their matching URLs:
Available --db options (use --dburi to override)
mysql mysql://scott:tiger@127.0.0.1:3306/test
oracle oracle://scott:tiger@127.0.0.1:1521
- postgres postgres://scott:tiger@127.0.0.1:5432/test
+ postgresql postgresql://scott:tiger@127.0.0.1:5432/test
[...]
To run tests against an aliased database:
- $ nosetests --db=postgres
+ $ nosetests --db=postgresql
To customize the URLs with your own users or hostnames, make a simple .ini
file called `test.cfg` at the top level of the SQLAlchemy source distribution
or a `.satest.cfg` in your home directory:
[db]
- postgres=postgres://myuser:mypass@localhost/mydb
+ postgresql=postgresql://myuser:mypass@localhost/mydb
Your custom entries will override the defaults and you'll see them reflected
in the output of --dbs.
TIPS
----
-PostgreSQL: The tests require an 'alt_schema' and 'alt_schema_2' to be present in
+
+PostgreSQL: The tests require an 'test_schema' and 'test_schema_2' to be present in
the testing database.
-PostgreSQL: When running the tests on postgres, postgres can get slower and
-slower each time you run the tests. This seems to be related to the constant
-creation/dropping of tables. Running a "VACUUM FULL" on the database will
-speed it up again.
+Oracle: the database owner should be named "scott" (this will be fixed),
+and an additional "owner" named "ed" is required:
+
+1. create a user 'ed' in the oracle database.
+2. in 'ed', issue the following statements:
+ create table parent(id integer primary key, data varchar2(50));
+ create table child(id integer primary key, data varchar2(50), parent_id integer references parent(id));
+ create synonym ptable for parent;
+ create synonym ctable for child;
+ grant all on parent to scott; (or to whoever you run the oracle tests as)
+ grant all on child to scott; (same)
+ grant all on ptable to scott;
+ grant all on ctable to scott;
MSSQL: Tests that involve multiple connections require Snapshot Isolation
ability implented on the test database in order to prevent deadlocks that will
--- /dev/null
+Trunk is now moved to SQLAlchemy 0.6.
+
+An ongoing wiki page of changes etc. is at:
+
+http://www.sqlalchemy.org/trac/wiki/06Migration
+
+SQLAlchemy 0.5 is now in a maintenance branch. Get it at:
+
+http://svn.sqlalchemy.org/sqlalchemy/branches/rel_0_5
+
+
if __name__ == '__main__':
- convert("test/orm/inheritance/abc_inheritance.py")
+ import sys
+ for f in sys.argv[1:]:
+ convert(f)
# walk()
This is the MIT license: `<http://www.opensource.org/licenses/mit-license.php>`_
-Copyright (c) 2005, 2006, 2007, 2008 Michael Bayer and contributors. SQLAlchemy is a trademark of Michael
+Copyright (c) 2005, 2006, 2007, 2008, 2009 Michael Bayer and contributors. SQLAlchemy is a trademark of Michael
Bayer.
Permission is hereby granted, free of charge, to any person obtaining a copy of this
Creating an engine is just a matter of issuing a single call, :func:`create_engine()`::
- engine = create_engine('postgres://scott:tiger@localhost:5432/mydatabase')
+ engine = create_engine('postgresql://scott:tiger@localhost:5432/mydatabase')
-The above engine invokes the ``postgres`` dialect and a connection pool which references ``localhost:5432``.
+The above engine invokes the ``postgresql`` dialect and a connection pool which references ``localhost:5432``.
The engine can be used directly to issue SQL to the database. The most generic way is to use connections, which you get via the ``connect()`` method::
Supported Databases
====================
-Recall that the ``Dialect`` is used to describe how to talk to a specific kind of database. Dialects are included with SQLAlchemy for many different backends; these can be seen as a Python package within the :mod:`~sqlalchemy.databases` package. Each dialect requires the appropriate DBAPI drivers to be installed separately.
+Recall that the ``Dialect`` is used to describe how to talk to a specific kind of database. Dialects are included with SQLAlchemy for many different backends; these can be seen as a Python package within the :mod:`~sqlalchemy.dialect` package. Each dialect requires the appropriate DBAPI drivers to be installed separately.
Dialects included with SQLAlchemy fall under one of three categories: supported, experimental, and third party. Supported drivers are those which work against the most common databases available in the open source world, including SQLite, PostgreSQL, MySQL, and Firebird. Very popular commercial databases which provide easy access to test platforms are also supported, these currently include MSSQL and Oracle. These dialects are tested frequently and the level of support should be close to 100% for each.
Downloads for each DBAPI at the time of this writing are as follows:
* Supported Dialects
-
- - PostgreSQL: `psycopg2 <http://www.initd.org/tracker/psycopg/wiki/PsycopgTwo>`_
+ - PostgreSQL: `psycopg2 <http://www.initd.org/tracker/psycopg/wiki/PsycopgTwo>`_ `pg8000 <http://pybrary.net/pg8000/>`_
+ - PostgreSQL on Jython: `PostgreSQL JDBC Driver <http://jdbc.postgresql.org/>`_
- SQLite: `sqlite3 <http://www.python.org/doc/2.5.2/lib/module-sqlite3.html>`_ (included in Python 2.5 or greater) `pysqlite <http://initd.org/tracker/pysqlite>`_
- MySQL: `MySQLDB (a.k.a. mysql-python) <http://sourceforge.net/projects/mysql-python>`_
+ - MySQL on Jython: `JDBC Driver for MySQL <http://www.mysql.com/products/connector/>`_
- Oracle: `cx_Oracle <http://cx-oracle.sourceforge.net/>`_
- Firebird: `kinterbasdb <http://kinterbasdb.sourceforge.net/>`_
- MS-SQL, MSAccess: `pyodbc <http://pyodbc.sourceforge.net/>`_ (recommended) `adodbapi <http://adodbapi.sourceforge.net/>`_ `pymssql <http://pymssql.sourceforge.net/>`_
* Experimental Dialects
-
- MSAccess: `pyodbc <http://pyodbc.sourceforge.net/>`_
- Informix: `informixdb <http://informixdb.sourceforge.net/>`_
- Sybase: TODO
- MAXDB: TODO
* Third Party Dialects
-
- DB2/Informix IDS: `ibm-db <http://code.google.com/p/ibm-db/>`_
The SQLAlchemy Wiki contains a page of database notes, describing whatever quirks and behaviors have been observed. Its a good place to check for issues with specific databases. `Database Notes <http://www.sqlalchemy.org/trac/wiki/DatabaseNotes>`_
SQLAlchemy indicates the source of an Engine strictly via `RFC-1738 <http://rfc.net/rfc1738.html>`_ style URLs, combined with optional keyword arguments to specify options for the Engine. The form of the URL is:
- driver://username:password@host:port/database
+ dialect+driver://username:password@host:port/database
+
+Dialect names include the identifying name of the SQLAlchemy dialect which include ``sqlite``, ``mysql``, ``postgresql``, ``oracle``, ``mssql``, and ``firebird``. The drivername is the name of the DBAPI to be used to connect to the database using all lowercase letters. If not specified, a "default" DBAPI will be imported if available - this default is typically the most widely known driver available for that backend (i.e. cx_oracle, pysqlite/sqlite3, psycopg2, mysqldb). For Jython connections, the driver is always `zxjdbc`, which is the JDBC-DBAPI bridge included with Jython.
+
+.. sourcecode:: python+sql
+
+ # postgresql - psycopg2 is the default driver.
+ pg_db = create_engine('postgresql://scott:tiger@localhost/mydatabase')
+ pg_db = create_engine('postgresql+psycopg2://scott:tiger@localhost/mydatabase')
+ pg_db = create_engine('postgresql+pg8000://scott:tiger@localhost/mydatabase')
-Dialect names include the identifying name of the SQLAlchemy dialect which include ``sqlite``, ``mysql``, ``postgres``, ``oracle``, ``mssql``, and ``firebird``. In SQLAlchemy 0.5 and earlier, the DBAPI implementation is automatically selected if more than one are available - currently this includes only MSSQL (pyodbc is the default, then adodbapi, then pymssql) and SQLite (sqlite3 is the default, or pysqlite if sqlite3 is not availble). When using MSSQL, ``create_engine()`` accepts a ``module`` argument which specifies the name of the desired DBAPI to be used, overriding the default behavior.
+ # postgresql on Jython
+ pg_db = create_engine('postgresql+zxjdbc://scott:tiger@localhost/mydatabase')
+
+ # mysql - MySQLdb (mysql-python) is the default driver
+ mysql_db = create_engine('mysql://scott:tiger@localhost/foo')
+ mysql_db = create_engine('mysql+mysqldb://scott:tiger@localhost/foo')
+
+ # mysql on Jython
+ mysql_db = create_engine('mysql+zxjdbc://localhost/foo')
- .. sourcecode:: python+sql
-
- # postgresql
- pg_db = create_engine('postgres://scott:tiger@localhost/mydatabase')
+ # mysql with pyodbc (buggy)
+ mysql_db = create_engine('mysql+pyodbc://scott:tiger@some_dsn')
- # mysql
- mysql_db = create_engine('mysql://scott:tiger@localhost/mydatabase')
-
- # oracle
+ # oracle - cx_oracle is the default driver
oracle_db = create_engine('oracle://scott:tiger@127.0.0.1:1521/sidname')
-
+
# oracle via TNS name
- oracle_db = create_engine('oracle://scott:tiger@tnsname')
-
+ oracle_db = create_engine('oracle+cx_oracle://scott:tiger@tnsname')
+
# mssql using ODBC datasource names. PyODBC is the default driver.
mssql_db = create_engine('mssql://mydsn')
- mssql_db = create_engine('mssql://scott:tiger@mydsn')
-
- # firebird
- firebird_db = create_engine('firebird://scott:tiger@localhost/sometest.gdm')
-
+ mssql_db = create_engine('mssql+pyodbc://mydsn')
+ mssql_db = create_engine('mssql+adodbapi://mydsn')
+ mssql_db = create_engine('mssql+pyodbc://username:password@mydsn')
+
SQLite connects to file based databases. The same URL format is used, omitting the hostname, and using the "file" portion as the filename of the database. This has the effect of four slashes being present for an absolute file path::
# sqlite://<nohostname>/<path>
Custom DBAPI connect() arguments
--------------------------------
-
Custom arguments used when issuing the ``connect()`` call to the underlying DBAPI may be issued in three distinct ways. String-based arguments can be passed directly from the URL string as query arguments:
.. sourcecode:: python+sql
- db = create_engine('postgres://scott:tiger@localhost/test?argument1=foo&argument2=bar')
+ db = create_engine('postgresql://scott:tiger@localhost/test?argument1=foo&argument2=bar')
If SQLAlchemy's database connector is aware of a particular query argument, it may convert its type from string to its proper type.
.. sourcecode:: python+sql
- db = create_engine('postgres://scott:tiger@localhost/test', connect_args = {'argument1':17, 'argument2':'bar'})
+ db = create_engine('postgresql://scott:tiger@localhost/test', connect_args = {'argument1':17, 'argument2':'bar'})
The most customizable connection method of all is to pass a ``creator`` argument, which specifies a callable that returns a DBAPI connection:
def connect():
return psycopg.connect(user='scott', host='localhost')
- db = create_engine('postgres://', creator=connect)
+ db = create_engine('postgresql://', creator=connect)
.. _create_engine_args:
.. sourcecode:: python+sql
- db = create_engine('postgres://...', encoding='latin1', echo=True)
+ db = create_engine('postgresql://...', encoding='latin1', echo=True)
Options common to all database dialects are described at :func:`~sqlalchemy.create_engine`.
* the ``inline=True`` flag is not set on the ``Insert()`` or ``Update()`` construct.
-For a statement execution which is not an executemany, the returned ``ResultProxy`` will contain a collection accessible via ``result.postfetch_cols()`` which contains a list of all ``Column`` objects which had an inline-executed default. Similarly, all parameters which were bound to the statement, including all Python and SQL expressions which were pre-executed, are present in the ``last_inserted_params()`` or ``last_updated_params()`` collections on ``ResultProxy``. The ``last_inserted_ids()`` collection contains a list of primary key values for the row inserted.
+For a statement execution which is not an executemany, the returned ``ResultProxy`` will contain a collection accessible via ``result.postfetch_cols()`` which contains a list of all ``Column`` objects which had an inline-executed default. Similarly, all parameters which were bound to the statement, including all Python and SQL expressions which were pre-executed, are present in the ``last_inserted_params()`` or ``last_updated_params()`` collections on ``ResultProxy``. The ``inserted_primary_key`` collection contains a list of primary key values for the row inserted.
DDL-Level Defaults
-------------------
-Access
-======
+Microsoft Access
+================
-.. automodule:: sqlalchemy.databases.access
+.. automodule:: sqlalchemy.dialects.access.base
Firebird
========
-.. automodule:: sqlalchemy.databases.firebird
+.. automodule:: sqlalchemy.dialects.firebird.base
sqlalchemy.databases
====================
+Supported Databases
+-------------------
+
+These backends are fully operational with
+current versions of SQLAlchemy.
+
.. toctree::
:glob:
- access
firebird
- informix
- maxdb
mssql
mysql
oracle
- postgres
+ postgresql
sqlite
+
+Unsupported Databases
+---------------------
+
+These backends are untested and may not be completely
+ported to current versions of SQLAlchemy.
+
+.. toctree::
+ :glob:
+
+ access
+ informix
+ maxdb
sybase
Informix
========
-.. automodule:: sqlalchemy.databases.informix
+.. automodule:: sqlalchemy.dialects.informix.base
MaxDB
=====
-.. automodule:: sqlalchemy.databases.maxdb
+.. automodule:: sqlalchemy.dialects.maxdb.base
-SQL Server
-==========
+Microsoft SQL Server
+====================
+
+.. automodule:: sqlalchemy.dialects.mssql.base
+
+PyODBC
+------
+.. automodule:: sqlalchemy.dialects.mssql.pyodbc
+
+AdoDBAPI
+--------
+.. automodule:: sqlalchemy.dialects.mssql.adodbapi
+
+pymssql
+-------
+.. automodule:: sqlalchemy.dialects.mssql.pymssql
+
-.. automodule:: sqlalchemy.databases.mssql
MySQL
=====
-.. automodule:: sqlalchemy.databases.mysql
+.. automodule:: sqlalchemy.dialects.mysql.base
MySQL Column Types
------------------
-.. autoclass:: MSNumeric
+.. autoclass:: NUMERIC
:members: __init__
:show-inheritance:
-.. autoclass:: MSDecimal
+.. autoclass:: DECIMAL
:members: __init__
:show-inheritance:
-.. autoclass:: MSDouble
+.. autoclass:: DOUBLE
:members: __init__
:show-inheritance:
-.. autoclass:: MSReal
+.. autoclass:: REAL
:members: __init__
:show-inheritance:
-.. autoclass:: MSFloat
+.. autoclass:: FLOAT
:members: __init__
:show-inheritance:
-.. autoclass:: MSInteger
+.. autoclass:: INTEGER
:members: __init__
:show-inheritance:
-.. autoclass:: MSBigInteger
+.. autoclass:: BIGINT
:members: __init__
:show-inheritance:
-.. autoclass:: MSMediumInteger
+.. autoclass:: MEDIUMINT
:members: __init__
:show-inheritance:
-.. autoclass:: MSTinyInteger
+.. autoclass:: TINYINT
:members: __init__
:show-inheritance:
-.. autoclass:: MSSmallInteger
+.. autoclass:: SMALLINT
:members: __init__
:show-inheritance:
-.. autoclass:: MSBit
+.. autoclass:: BIT
:members: __init__
:show-inheritance:
-.. autoclass:: MSDateTime
+.. autoclass:: DATETIME
:members: __init__
:show-inheritance:
-.. autoclass:: MSDate
+.. autoclass:: DATE
:members: __init__
:show-inheritance:
-.. autoclass:: MSTime
+.. autoclass:: TIME
:members: __init__
:show-inheritance:
-.. autoclass:: MSTimeStamp
+.. autoclass:: TIMESTAMP
:members: __init__
:show-inheritance:
-.. autoclass:: MSYear
+.. autoclass:: YEAR
:members: __init__
:show-inheritance:
-.. autoclass:: MSText
+.. autoclass:: TEXT
:members: __init__
:show-inheritance:
-.. autoclass:: MSTinyText
+.. autoclass:: TINYTEXT
:members: __init__
:show-inheritance:
-.. autoclass:: MSMediumText
+.. autoclass:: MEDIUMTEXT
:members: __init__
:show-inheritance:
-.. autoclass:: MSLongText
+.. autoclass:: LONGTEXT
:members: __init__
:show-inheritance:
-.. autoclass:: MSString
+.. autoclass:: VARCHAR
:members: __init__
:show-inheritance:
-.. autoclass:: MSChar
+.. autoclass:: CHAR
:members: __init__
:show-inheritance:
-.. autoclass:: MSNVarChar
+.. autoclass:: NVARCHAR
:members: __init__
:show-inheritance:
-.. autoclass:: MSNChar
+.. autoclass:: NCHAR
:members: __init__
:show-inheritance:
-.. autoclass:: MSVarBinary
+.. autoclass:: VARBINARY
:members: __init__
:show-inheritance:
-.. autoclass:: MSBinary
+.. autoclass:: BINARY
:members: __init__
:show-inheritance:
-.. autoclass:: MSBlob
+.. autoclass:: BLOB
:members: __init__
:show-inheritance:
-.. autoclass:: MSTinyBlob
+.. autoclass:: TINYBLOB
:members: __init__
:show-inheritance:
-.. autoclass:: MSMediumBlob
+.. autoclass:: MEDIUMBLOB
:members: __init__
:show-inheritance:
-.. autoclass:: MSLongBlob
+.. autoclass:: LONGBLOB
:members: __init__
:show-inheritance:
-.. autoclass:: MSEnum
+.. autoclass:: ENUM
:members: __init__
:show-inheritance:
-.. autoclass:: MSSet
+.. autoclass:: SET
:members: __init__
:show-inheritance:
-.. autoclass:: MSBoolean
+.. autoclass:: BOOLEAN
:members: __init__
:show-inheritance:
+MySQLdb Notes
+--------------
+
+.. automodule:: sqlalchemy.dialects.mysql.mysqldb
+
+zxjdbc Notes
+--------------
+
+.. automodule:: sqlalchemy.dialects.mysql.zxjdbc
Oracle
======
-.. automodule:: sqlalchemy.databases.oracle
+.. automodule:: sqlalchemy.dialects.oracle.base
+
+cx_Oracle Notes
+---------------
+
+.. automodule:: sqlalchemy.dialects.oracle.cx_oracle
+
+++ /dev/null
-PostgreSQL
-==========
-
-.. automodule:: sqlalchemy.databases.postgres
--- /dev/null
+PostgreSQL
+==========
+
+.. automodule:: sqlalchemy.dialects.postgresql.base
+
+psycopg2 Notes
+--------------
+
+.. automodule:: sqlalchemy.dialects.postgresql.psycopg2
+
+
+pg8000 Notes
+--------------
+
+.. automodule:: sqlalchemy.dialects.postgresql.pg8000
+
+zxjdbc Notes
+--------------
+
+.. automodule:: sqlalchemy.dialects.postgresql.zxjdbc
SQLite
======
-.. automodule:: sqlalchemy.databases.sqlite
+.. automodule:: sqlalchemy.dialects.sqlite.base
+Pysqlite
+--------
+
+.. automodule:: sqlalchemy.dialects.sqlite.pysqlite
\ No newline at end of file
Sybase
======
-.. automodule:: sqlalchemy.databases.sybase
+.. automodule:: sqlalchemy.dialects.sybase.base
.. autoclass:: ExecutionContext
:members:
-.. autoclass:: SchemaIterator
- :members:
- :show-inheritance:
-
``pool_size``, ``max_overflow``, ``pool_recycle`` and
``pool_timeout``. For example::
- engine = create_engine('postgres://me@localhost/mydb',
+ engine = create_engine('postgresql://me@localhost/mydb',
pool_size=20, max_overflow=0)
In the case of SQLite, a :class:`SingletonThreadPool` is provided instead,
For example, MySQL has a ``BIGINTEGER`` type and PostgreSQL has an
``INET`` type. To use these, import them from the module explicitly::
- from sqlalchemy.databases.mysql import MSBigInteger, MSEnum
+ from sqlalchemy.dialect.mysql import dialect as mysql
table = Table('foo', meta,
- Column('id', MSBigInteger),
- Column('enumerates', MSEnum('a', 'b', 'c'))
+ Column('id', mysql.BIGINTEGER),
+ Column('enumerates', mysql.ENUM('a', 'b', 'c'))
)
Or some PostgreSQL types::
- from sqlalchemy.databases.postgres import PGInet, PGArray
+ from sqlalchemy.dialect.postgresql import dialect as postgresql
table = Table('foo', meta,
- Column('ipaddress', PGInet),
- Column('elements', PGArray(str))
+ Column('ipaddress', postgresql.INET),
+ Column('elements', postgresql.ARRAY(str))
)
+Each dialect should provide the full set of typenames supported by
+that backend, so that a backend-specific schema can be created without
+the need to locate types::
+
+ from sqlalchemy.dialects.postgresql import dialect as pg
+
+ t = Table('mytable', metadata,
+ Column('id', pg.INTEGER, primary_key=True),
+ Column('name', pg.VARCHAR(300)),
+ Column('inetaddr', pg.INET)
+ )
+
+Where above, the INTEGER and VARCHAR types are ultimately from
+sqlalchemy.types, but the Postgresql dialect makes them available.
Custom Types
------------
class that makes it easy to augment the bind parameter and result
processing capabilities of one of the built in types.
-To build a type object from scratch, subclass `:class:TypeEngine`.
+To build a type object from scratch, subclass `:class:UserDefinedType`.
.. autoclass:: TypeDecorator
:members:
:inherited-members:
:show-inheritance:
+.. autoclass:: UserDefinedType
+ :members:
+ :undoc-members:
+ :inherited-members:
+ :show-inheritance:
+
.. autoclass:: TypeEngine
:members:
:undoc-members:
Session = sessionmaker()
# later, we create the engine
- engine = create_engine('postgres://...')
+ engine = create_engine('postgresql://...')
# associate it with our custom Session class
Session.configure(bind=engine)
# global application scope. create Session class, engine
Session = sessionmaker()
- engine = create_engine('postgres://...')
+ engine = create_engine('postgresql://...')
...
* An application which reads an object structure from a file and wishes to save it to the database might parse the file, build up the structure, and then use ``merge()`` to save it to the database, ensuring that the data within the file is used to formulate the primary key of each element of the structure. Later, when the file has changed, the same process can be re-run, producing a slightly different object structure, which can then be ``merged()`` in again, and the ``Session`` will automatically update the database to reflect those changes.
* A web application stores mapped entities within an HTTP session object. When each request starts up, the serialized data can be merged into the session, so that the original entity may be safely shared among requests and threads.
-``merge()`` is frequently used by applications which implement their own second level caches. This refers to an application which uses an in memory dictionary, or an tool like Memcached to store objects over long running spans of time. When such an object needs to exist within a ``Session``, ``merge()`` is a good choice since it leaves the original cached object untouched. For this use case, merge provides a keyword option called ``dont_load=True``. When this boolean flag is set to ``True``, ``merge()`` will not issue any SQL to reconcile the given object against the current state of the database, thereby reducing query overhead. The limitation is that the given object and all of its children may not contain any pending changes, and it's also of course possible that newer information in the database will not be present on the merged object, since no load is issued.
+``merge()`` is frequently used by applications which implement their own second level caches. This refers to an application which uses an in memory dictionary, or an tool like Memcached to store objects over long running spans of time. When such an object needs to exist within a ``Session``, ``merge()`` is a good choice since it leaves the original cached object untouched. For this use case, merge provides a keyword option called ``load=False``. When this boolean flag is set to ``False``, ``merge()`` will not issue any SQL to reconcile the given object against the current state of the database, thereby reducing query overhead. The limitation is that the given object and all of its children may not contain any pending changes, and it's also of course possible that newer information in the database will not be present on the merged object, since no load is issued.
Deleting
--------
Finally, for MySQL, PostgreSQL, and soon Oracle as well, the session can be instructed to use two-phase commit semantics. This will coordinate the committing of transactions across databases so that the transaction is either committed or rolled back in all databases. You can also ``prepare()`` the session for interacting with transactions not managed by SQLAlchemy. To use two phase transactions set the flag ``twophase=True`` on the session::
- engine1 = create_engine('postgres://db1')
- engine2 = create_engine('postgres://db2')
+ engine1 = create_engine('postgresql://db1')
+ engine2 = create_engine('postgresql://db2')
Session = sessionmaker(twophase=True)
When using the ``threadlocal`` engine context, the process above is simplified; the ``Session`` uses the same connection/transaction as everyone else in the current thread, whether or not you explicitly bind it::
- engine = create_engine('postgres://mydb', strategy="threadlocal")
+ engine = create_engine('postgresql://mydb', strategy="threadlocal")
engine.begin()
session = Session() # session takes place in the transaction like everyone else
Vertical partitioning places different kinds of objects, or different tables, across multiple databases::
- engine1 = create_engine('postgres://db1')
- engine2 = create_engine('postgres://db2')
+ engine1 = create_engine('postgresql://db1')
+ engine2 = create_engine('postgresql://db2')
Session = sessionmaker(twophase=True)
.. sourcecode:: pycon+sql
- >>> result.last_inserted_ids()
+ >>> result.inserted_primary_key
[1]
-The value of ``1`` was automatically generated by SQLite, but only because we did not specify the ``id`` column in our ``Insert`` statement; otherwise, our explicit value would have been used. In either case, SQLAlchemy always knows how to get at a newly generated primary key value, even though the method of generating them is different across different databases; each databases' ``Dialect`` knows the specific steps needed to determine the correct value (or values; note that ``last_inserted_ids()`` returns a list so that it supports composite primary keys).
+The value of ``1`` was automatically generated by SQLite, but only because we did not specify the ``id`` column in our ``Insert`` statement; otherwise, our explicit value would have been used. In either case, SQLAlchemy always knows how to get at a newly generated primary key value, even though the method of generating them is different across different databases; each databases' ``Dialect`` knows the specific steps needed to determine the correct value (or values; note that ``inserted_primary_key`` returns a list so that it supports composite primary keys).
Executing Multiple Statements
==============================
return runner.failures, runner.tries
def replace_file(s, newfile):
- engine = r"'(sqlite|postgres|mysql):///.*'"
+ engine = r"'(sqlite|postgresql|mysql):///.*'"
engine = re.compile(engine, re.MULTILINE)
s, n = re.subn(engine, "'sqlite:///" + newfile + "'", s)
if not n:
from sqlalchemy.orm import sessionmaker, column_property
from sqlalchemy.ext.declarative import declarative_base
- engine = create_engine('postgres://scott:tiger@localhost/gistest', echo=True)
+ engine = create_engine('postgresql://scott:tiger@localhost/gistest', echo=True)
metadata = MetaData(engine)
Base = declarative_base(metadata=metadata)
self.session.expunge(x)
_cache[self.cachekey] = ret
- return iter(self.session.merge(x, dont_load=True) for x in ret)
+ return iter(self.session.merge(x, load=False) for x in ret)
else:
return Query.__iter__(self)
--- /dev/null
+#!python
+"""Bootstrap setuptools installation
+
+If you want to use setuptools in your package's setup.py, just include this
+file in the same directory with it, and add this to the top of your setup.py::
+
+ from ez_setup import use_setuptools
+ use_setuptools()
+
+If you want to require a specific version of setuptools, set a download
+mirror, or use an alternate download directory, you can do so by supplying
+the appropriate options to ``use_setuptools()``.
+
+This file can also be run as a script to install or upgrade setuptools.
+"""
+import sys
+DEFAULT_VERSION = "0.6c9"
+DEFAULT_URL = "http://pypi.python.org/packages/%s/s/setuptools/" % sys.version[:3]
+
+md5_data = {
+ 'setuptools-0.6b1-py2.3.egg': '8822caf901250d848b996b7f25c6e6ca',
+ 'setuptools-0.6b1-py2.4.egg': 'b79a8a403e4502fbb85ee3f1941735cb',
+ 'setuptools-0.6b2-py2.3.egg': '5657759d8a6d8fc44070a9d07272d99b',
+ 'setuptools-0.6b2-py2.4.egg': '4996a8d169d2be661fa32a6e52e4f82a',
+ 'setuptools-0.6b3-py2.3.egg': 'bb31c0fc7399a63579975cad9f5a0618',
+ 'setuptools-0.6b3-py2.4.egg': '38a8c6b3d6ecd22247f179f7da669fac',
+ 'setuptools-0.6b4-py2.3.egg': '62045a24ed4e1ebc77fe039aa4e6f7e5',
+ 'setuptools-0.6b4-py2.4.egg': '4cb2a185d228dacffb2d17f103b3b1c4',
+ 'setuptools-0.6c1-py2.3.egg': 'b3f2b5539d65cb7f74ad79127f1a908c',
+ 'setuptools-0.6c1-py2.4.egg': 'b45adeda0667d2d2ffe14009364f2a4b',
+ 'setuptools-0.6c2-py2.3.egg': 'f0064bf6aa2b7d0f3ba0b43f20817c27',
+ 'setuptools-0.6c2-py2.4.egg': '616192eec35f47e8ea16cd6a122b7277',
+ 'setuptools-0.6c3-py2.3.egg': 'f181fa125dfe85a259c9cd6f1d7b78fa',
+ 'setuptools-0.6c3-py2.4.egg': 'e0ed74682c998bfb73bf803a50e7b71e',
+ 'setuptools-0.6c3-py2.5.egg': 'abef16fdd61955514841c7c6bd98965e',
+ 'setuptools-0.6c4-py2.3.egg': 'b0b9131acab32022bfac7f44c5d7971f',
+ 'setuptools-0.6c4-py2.4.egg': '2a1f9656d4fbf3c97bf946c0a124e6e2',
+ 'setuptools-0.6c4-py2.5.egg': '8f5a052e32cdb9c72bcf4b5526f28afc',
+ 'setuptools-0.6c5-py2.3.egg': 'ee9fd80965da04f2f3e6b3576e9d8167',
+ 'setuptools-0.6c5-py2.4.egg': 'afe2adf1c01701ee841761f5bcd8aa64',
+ 'setuptools-0.6c5-py2.5.egg': 'a8d3f61494ccaa8714dfed37bccd3d5d',
+ 'setuptools-0.6c6-py2.3.egg': '35686b78116a668847237b69d549ec20',
+ 'setuptools-0.6c6-py2.4.egg': '3c56af57be3225019260a644430065ab',
+ 'setuptools-0.6c6-py2.5.egg': 'b2f8a7520709a5b34f80946de5f02f53',
+ 'setuptools-0.6c7-py2.3.egg': '209fdf9adc3a615e5115b725658e13e2',
+ 'setuptools-0.6c7-py2.4.egg': '5a8f954807d46a0fb67cf1f26c55a82e',
+ 'setuptools-0.6c7-py2.5.egg': '45d2ad28f9750e7434111fde831e8372',
+ 'setuptools-0.6c8-py2.3.egg': '50759d29b349db8cfd807ba8303f1902',
+ 'setuptools-0.6c8-py2.4.egg': 'cba38d74f7d483c06e9daa6070cce6de',
+ 'setuptools-0.6c8-py2.5.egg': '1721747ee329dc150590a58b3e1ac95b',
+ 'setuptools-0.6c9-py2.3.egg': 'a83c4020414807b496e4cfbe08507c03',
+ 'setuptools-0.6c9-py2.4.egg': '260a2be2e5388d66bdaee06abec6342a',
+ 'setuptools-0.6c9-py2.5.egg': 'fe67c3e5a17b12c0e7c541b7ea43a8e6',
+ 'setuptools-0.6c9-py2.6.egg': 'ca37b1ff16fa2ede6e19383e7b59245a',
+}
+
+import sys, os
+try: from hashlib import md5
+except ImportError: from md5 import md5
+
+def _validate_md5(egg_name, data):
+ if egg_name in md5_data:
+ digest = md5(data).hexdigest()
+ if digest != md5_data[egg_name]:
+ print >>sys.stderr, (
+ "md5 validation of %s failed! (Possible download problem?)"
+ % egg_name
+ )
+ sys.exit(2)
+ return data
+
+def use_setuptools(
+ version=DEFAULT_VERSION, download_base=DEFAULT_URL, to_dir=os.curdir,
+ download_delay=15
+):
+ """Automatically find/download setuptools and make it available on sys.path
+
+ `version` should be a valid setuptools version number that is available
+ as an egg for download under the `download_base` URL (which should end with
+ a '/'). `to_dir` is the directory where setuptools will be downloaded, if
+ it is not already available. If `download_delay` is specified, it should
+ be the number of seconds that will be paused before initiating a download,
+ should one be required. If an older version of setuptools is installed,
+ this routine will print a message to ``sys.stderr`` and raise SystemExit in
+ an attempt to abort the calling script.
+ """
+ was_imported = 'pkg_resources' in sys.modules or 'setuptools' in sys.modules
+ def do_download():
+ egg = download_setuptools(version, download_base, to_dir, download_delay)
+ sys.path.insert(0, egg)
+ import setuptools; setuptools.bootstrap_install_from = egg
+ try:
+ import pkg_resources
+ except ImportError:
+ return do_download()
+ try:
+ pkg_resources.require("setuptools>="+version); return
+ except pkg_resources.VersionConflict, e:
+ if was_imported:
+ print >>sys.stderr, (
+ "The required version of setuptools (>=%s) is not available, and\n"
+ "can't be installed while this script is running. Please install\n"
+ " a more recent version first, using 'easy_install -U setuptools'."
+ "\n\n(Currently using %r)"
+ ) % (version, e.args[0])
+ sys.exit(2)
+ else:
+ del pkg_resources, sys.modules['pkg_resources'] # reload ok
+ return do_download()
+ except pkg_resources.DistributionNotFound:
+ return do_download()
+
+def download_setuptools(
+ version=DEFAULT_VERSION, download_base=DEFAULT_URL, to_dir=os.curdir,
+ delay = 15
+):
+ """Download setuptools from a specified location and return its filename
+
+ `version` should be a valid setuptools version number that is available
+ as an egg for download under the `download_base` URL (which should end
+ with a '/'). `to_dir` is the directory where the egg will be downloaded.
+ `delay` is the number of seconds to pause before an actual download attempt.
+ """
+ import urllib2, shutil
+ egg_name = "setuptools-%s-py%s.egg" % (version,sys.version[:3])
+ url = download_base + egg_name
+ saveto = os.path.join(to_dir, egg_name)
+ src = dst = None
+ if not os.path.exists(saveto): # Avoid repeated downloads
+ try:
+ from distutils import log
+ if delay:
+ log.warn("""
+---------------------------------------------------------------------------
+This script requires setuptools version %s to run (even to display
+help). I will attempt to download it for you (from
+%s), but
+you may need to enable firewall access for this script first.
+I will start the download in %d seconds.
+
+(Note: if this machine does not have network access, please obtain the file
+
+ %s
+
+and place it in this directory before rerunning this script.)
+---------------------------------------------------------------------------""",
+ version, download_base, delay, url
+ ); from time import sleep; sleep(delay)
+ log.warn("Downloading %s", url)
+ src = urllib2.urlopen(url)
+ # Read/write all in one block, so we don't create a corrupt file
+ # if the download is interrupted.
+ data = _validate_md5(egg_name, src.read())
+ dst = open(saveto,"wb"); dst.write(data)
+ finally:
+ if src: src.close()
+ if dst: dst.close()
+ return os.path.realpath(saveto)
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+def main(argv, version=DEFAULT_VERSION):
+ """Install or upgrade setuptools and EasyInstall"""
+ try:
+ import setuptools
+ except ImportError:
+ egg = None
+ try:
+ egg = download_setuptools(version, delay=0)
+ sys.path.insert(0,egg)
+ from setuptools.command.easy_install import main
+ return main(list(argv)+[egg]) # we're done here
+ finally:
+ if egg and os.path.exists(egg):
+ os.unlink(egg)
+ else:
+ if setuptools.__version__ == '0.0.1':
+ print >>sys.stderr, (
+ "You have an obsolete version of setuptools installed. Please\n"
+ "remove it from your system entirely before rerunning this script."
+ )
+ sys.exit(2)
+
+ req = "setuptools>="+version
+ import pkg_resources
+ try:
+ pkg_resources.require(req)
+ except pkg_resources.VersionConflict:
+ try:
+ from setuptools.command.easy_install import main
+ except ImportError:
+ from easy_install import main
+ main(list(argv)+[download_setuptools(delay=0)])
+ sys.exit(0) # try to force an exit
+ else:
+ if argv:
+ from setuptools.command.easy_install import main
+ main(argv)
+ else:
+ print "Setuptools version",version,"or greater has been installed."
+ print '(Run "ez_setup.py -U setuptools" to reinstall or upgrade.)'
+
+def update_md5(filenames):
+ """Update our built-in md5 registry"""
+
+ import re
+
+ for name in filenames:
+ base = os.path.basename(name)
+ f = open(name,'rb')
+ md5_data[base] = md5(f.read()).hexdigest()
+ f.close()
+
+ data = [" %r: %r,\n" % it for it in md5_data.items()]
+ data.sort()
+ repl = "".join(data)
+
+ import inspect
+ srcfile = inspect.getsourcefile(sys.modules[__name__])
+ f = open(srcfile, 'rb'); src = f.read(); f.close()
+
+ match = re.search("\nmd5_data = {\n([^}]+)}", src)
+ if not match:
+ print >>sys.stderr, "Internal error!"
+ sys.exit(2)
+
+ src = src[:match.start(1)] + repl + src[match.end(1):]
+ f = open(srcfile,'w')
+ f.write(src)
+ f.close()
+
+
+if __name__=='__main__':
+ if len(sys.argv)>2 and sys.argv[1]=='--md5update':
+ update_md5(sys.argv[2:])
+ else:
+ main(sys.argv[1:])
+
+
+
+
+
+
import sqlalchemy.exc as exceptions
sys.modules['sqlalchemy.exceptions'] = exceptions
-from sqlalchemy.types import (
- BLOB,
- BOOLEAN,
- Binary,
- Boolean,
- CHAR,
- CLOB,
- DATE,
- DATETIME,
- DECIMAL,
- Date,
- DateTime,
- FLOAT,
- Float,
- INT,
- Integer,
- Interval,
- NCHAR,
- NUMERIC,
- Numeric,
- PickleType,
- SMALLINT,
- SmallInteger,
- String,
- TEXT,
- TIME,
- TIMESTAMP,
- Text,
- Time,
- Unicode,
- UnicodeText,
- VARCHAR,
- )
-
from sqlalchemy.sql import (
alias,
and_,
update,
)
+from sqlalchemy.types import (
+ BLOB,
+ BOOLEAN,
+ Binary,
+ Boolean,
+ CHAR,
+ CLOB,
+ DATE,
+ DATETIME,
+ DECIMAL,
+ Date,
+ DateTime,
+ FLOAT,
+ Float,
+ INT,
+ INTEGER,
+ Integer,
+ Interval,
+ NCHAR,
+ NVARCHAR,
+ NUMERIC,
+ Numeric,
+ PickleType,
+ SMALLINT,
+ SmallInteger,
+ String,
+ TEXT,
+ TIME,
+ TIMESTAMP,
+ Text,
+ Time,
+ Unicode,
+ UnicodeText,
+ VARCHAR,
+ )
+
+
from sqlalchemy.schema import (
CheckConstraint,
Column,
__all__ = sorted(name for name, obj in locals().items()
if not (name.startswith('_') or inspect.ismodule(obj)))
-__version__ = '0.5.5'
+__version__ = '0.6beta1'
del inspect, sys
--- /dev/null
+
+
+class Connector(object):
+ pass
+
+
\ No newline at end of file
--- /dev/null
+from sqlalchemy.connectors import Connector
+
+class MxODBCConnector(Connector):
+ driver='mxodbc'
+ supports_sane_rowcount = False
+ supports_sane_multi_rowcount = False
+ supports_unicode_statements = False
+ supports_unicode_binds = False
+
+ @classmethod
+ def import_dbapi(cls):
+ import mxODBC as module
+ return module
+
+ def create_connect_args(self, url):
+ '''Return a tuple of *args,**kwargs'''
+ # FIXME: handle mx.odbc.Windows proprietary args
+ opts = url.translate_connect_args(username='user')
+ opts.update(url.query)
+ argsDict = {}
+ argsDict['user'] = opts['user']
+ argsDict['password'] = opts['password']
+ connArgs = [[opts['dsn']], argsDict]
+ return connArgs
--- /dev/null
+from sqlalchemy.connectors import Connector
+
+import sys
+import re
+import urllib
+
+class PyODBCConnector(Connector):
+ driver='pyodbc'
+ supports_sane_rowcount = False
+ supports_sane_multi_rowcount = False
+ # PyODBC unicode is broken on UCS-4 builds
+ supports_unicode = sys.maxunicode == 65535
+ supports_unicode_statements = supports_unicode
+ default_paramstyle = 'named'
+
+ # for non-DSN connections, this should
+ # hold the desired driver name
+ pyodbc_driver_name = None
+
+ @classmethod
+ def dbapi(cls):
+ return __import__('pyodbc')
+
+ def create_connect_args(self, url):
+ opts = url.translate_connect_args(username='user')
+ opts.update(url.query)
+
+ keys = opts
+ query = url.query
+
+ if 'odbc_connect' in keys:
+ connectors = [urllib.unquote_plus(keys.pop('odbc_connect'))]
+ else:
+ dsn_connection = 'dsn' in keys or ('host' in keys and 'database' not in keys)
+ if dsn_connection:
+ connectors= ['dsn=%s' % (keys.pop('host', '') or keys.pop('dsn', ''))]
+ else:
+ port = ''
+ if 'port' in keys and not 'port' in query:
+ port = ',%d' % int(keys.pop('port'))
+
+ connectors = ["DRIVER={%s}" % keys.pop('driver', self.pyodbc_driver_name),
+ 'Server=%s%s' % (keys.pop('host', ''), port),
+ 'Database=%s' % keys.pop('database', '') ]
+
+ user = keys.pop("user", None)
+ if user:
+ connectors.append("UID=%s" % user)
+ connectors.append("PWD=%s" % keys.pop('password', ''))
+ else:
+ connectors.append("TrustedConnection=Yes")
+
+ # if set to 'Yes', the ODBC layer will try to automagically convert
+ # textual data from your database encoding to your client encoding
+ # This should obviously be set to 'No' if you query a cp1253 encoded
+ # database from a latin1 client...
+ if 'odbc_autotranslate' in keys:
+ connectors.append("AutoTranslate=%s" % keys.pop("odbc_autotranslate"))
+
+ connectors.extend(['%s=%s' % (k,v) for k,v in keys.iteritems()])
+ return [[";".join (connectors)], {}]
+
+ def is_disconnect(self, e):
+ if isinstance(e, self.dbapi.ProgrammingError):
+ return "The cursor's connection has been closed." in str(e) or 'Attempt to use a closed connection.' in str(e)
+ elif isinstance(e, self.dbapi.Error):
+ return '[08S01]' in str(e)
+ else:
+ return False
+
+ def _get_server_version_info(self, connection):
+ dbapi_con = connection.connection
+ version = []
+ r = re.compile('[.\-]')
+ for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)):
+ try:
+ version.append(int(n))
+ except ValueError:
+ version.append(n)
+ return tuple(version)
--- /dev/null
+import sys
+from sqlalchemy.connectors import Connector
+
+class ZxJDBCConnector(Connector):
+ driver = 'zxjdbc'
+
+ supports_sane_rowcount = False
+ supports_sane_multi_rowcount = False
+
+ supports_unicode_binds = True
+ supports_unicode_statements = sys.version > '2.5.0+'
+ description_encoding = None
+ default_paramstyle = 'qmark'
+
+ jdbc_db_name = None
+ jdbc_driver_name = None
+
+ @classmethod
+ def dbapi(cls):
+ from com.ziclix.python.sql import zxJDBC
+ return zxJDBC
+
+ def _driver_kwargs(self):
+ """return kw arg dict to be sent to connect()."""
+ return {}
+
+ def create_connect_args(self, url):
+ hostname = url.host
+ dbname = url.database
+ d, u, p, v = ("jdbc:%s://%s/%s" % (self.jdbc_db_name, hostname, dbname),
+ url.username, url.password, self.jdbc_driver_name)
+ return [[d, u, p, v], self._driver_kwargs()]
+
+ def is_disconnect(self, e):
+ if not isinstance(e, self.dbapi.ProgrammingError):
+ return False
+ e = str(e)
+ return 'connection is closed' in e or 'cursor is closed' in e
+
+ def _get_server_version_info(self, connection):
+ # use connection.connection.dbversion, and parse appropriately
+ # to get a tuple
+ raise NotImplementedError()
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
+from sqlalchemy.dialects.sqlite import base as sqlite
+from sqlalchemy.dialects.postgresql import base as postgresql
+postgres = postgresql
+from sqlalchemy.dialects.mysql import base as mysql
+from sqlalchemy.dialects.oracle import base as oracle
+from sqlalchemy.dialects.firebird import base as firebird
+from sqlalchemy.dialects.maxdb import base as maxdb
+from sqlalchemy.dialects.informix import base as informix
+from sqlalchemy.dialects.mssql import base as mssql
+from sqlalchemy.dialects.access import base as access
+from sqlalchemy.dialects.sybase import base as sybase
+
+
+
__all__ = (
'access',
'maxdb',
'mssql',
'mysql',
- 'oracle',
- 'postgres',
+ 'postgresql',
'sqlite',
+ 'oracle',
'sybase',
)
+++ /dev/null
-# firebird.py
-# Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com
-#
-# This module is part of SQLAlchemy and is released under
-# the MIT License: http://www.opensource.org/licenses/mit-license.php
-
-"""
-Firebird backend
-================
-
-This module implements the Firebird backend, thru the kinterbasdb_
-DBAPI module.
-
-Firebird dialects
------------------
-
-Firebird offers two distinct dialects_ (not to be confused with the
-SA ``Dialect`` thing):
-
-dialect 1
- This is the old syntax and behaviour, inherited from Interbase pre-6.0.
-
-dialect 3
- This is the newer and supported syntax, introduced in Interbase 6.0.
-
-From the user point of view, the biggest change is in date/time
-handling: under dialect 1, there's a single kind of field, ``DATE``
-with a synonim ``DATETIME``, that holds a `timestamp` value, that is a
-date with hour, minute, second. Under dialect 3 there are three kinds,
-a ``DATE`` that holds a date, a ``TIME`` that holds a *time of the
-day* value and a ``TIMESTAMP``, equivalent to the old ``DATE``.
-
-The problem is that the dialect of a Firebird database is a property
-of the database itself [#]_ (that is, any single database has been
-created with one dialect or the other: there is no way to change the
-after creation). SQLAlchemy has a single instance of the class that
-controls all the connections to a particular kind of database, so it
-cannot easily differentiate between the two modes, and in particular
-it **cannot** simultaneously talk with two distinct Firebird databases
-with different dialects.
-
-By default this module is biased toward dialect 3, but you can easily
-tweak it to handle dialect 1 if needed::
-
- from sqlalchemy import types as sqltypes
- from sqlalchemy.databases.firebird import FBDate, colspecs, ischema_names
-
- # Adjust the mapping of the timestamp kind
- ischema_names['TIMESTAMP'] = FBDate
- colspecs[sqltypes.DateTime] = FBDate,
-
-Other aspects may be version-specific. You can use the ``server_version_info()`` method
-on the ``FBDialect`` class to do whatever is needed::
-
- from sqlalchemy.databases.firebird import FBCompiler
-
- if engine.dialect.server_version_info(connection) < (2,0):
- # Change the name of the function ``length`` to use the UDF version
- # instead of ``char_length``
- FBCompiler.LENGTH_FUNCTION_NAME = 'strlen'
-
-Pooling connections
--------------------
-
-The default strategy used by SQLAlchemy to pool the database connections
-in particular cases may raise an ``OperationalError`` with a message
-`"object XYZ is in use"`. This happens on Firebird when there are two
-connections to the database, one is using, or has used, a particular table
-and the other tries to drop or alter the same table. To garantee DDL
-operations success Firebird recommend doing them as the single connected user.
-
-In case your SA application effectively needs to do DDL operations while other
-connections are active, the following setting may alleviate the problem::
-
- from sqlalchemy import pool
- from sqlalchemy.databases.firebird import dialect
-
- # Force SA to use a single connection per thread
- dialect.poolclass = pool.SingletonThreadPool
-
-RETURNING support
------------------
-
-Firebird 2.0 supports returning a result set from inserts, and 2.1 extends
-that to deletes and updates.
-
-To use this pass the column/expression list to the ``firebird_returning``
-parameter when creating the queries::
-
- raises = tbl.update(empl.c.sales > 100, values=dict(salary=empl.c.salary * 1.1),
- firebird_returning=[empl.c.id, empl.c.salary]).execute().fetchall()
-
-
-.. [#] Well, that is not the whole story, as the client may still ask
- a different (lower) dialect...
-
-.. _dialects: http://mc-computing.com/Databases/Firebird/SQL_Dialect.html
-.. _kinterbasdb: http://sourceforge.net/projects/kinterbasdb
-"""
-
-
-import datetime, decimal, re
-
-from sqlalchemy import exc, schema, types as sqltypes, sql, util
-from sqlalchemy.engine import base, default
-
-
-_initialized_kb = False
-
-
-class FBNumeric(sqltypes.Numeric):
- """Handle ``NUMERIC(precision,scale)`` datatype."""
-
- def get_col_spec(self):
- if self.precision is None:
- return "NUMERIC"
- else:
- return "NUMERIC(%(precision)s, %(scale)s)" % { 'precision': self.precision,
- 'scale' : self.scale }
-
- def bind_processor(self, dialect):
- return None
-
- def result_processor(self, dialect):
- if self.asdecimal:
- return None
- else:
- def process(value):
- if isinstance(value, decimal.Decimal):
- return float(value)
- else:
- return value
- return process
-
-
-class FBFloat(sqltypes.Float):
- """Handle ``FLOAT(precision)`` datatype."""
-
- def get_col_spec(self):
- if not self.precision:
- return "FLOAT"
- else:
- return "FLOAT(%(precision)s)" % {'precision': self.precision}
-
-
-class FBInteger(sqltypes.Integer):
- """Handle ``INTEGER`` datatype."""
-
- def get_col_spec(self):
- return "INTEGER"
-
-
-class FBSmallInteger(sqltypes.Smallinteger):
- """Handle ``SMALLINT`` datatype."""
-
- def get_col_spec(self):
- return "SMALLINT"
-
-
-class FBDateTime(sqltypes.DateTime):
- """Handle ``TIMESTAMP`` datatype."""
-
- def get_col_spec(self):
- return "TIMESTAMP"
-
- def bind_processor(self, dialect):
- def process(value):
- if value is None or isinstance(value, datetime.datetime):
- return value
- else:
- return datetime.datetime(year=value.year,
- month=value.month,
- day=value.day)
- return process
-
-
-class FBDate(sqltypes.DateTime):
- """Handle ``DATE`` datatype."""
-
- def get_col_spec(self):
- return "DATE"
-
-
-class FBTime(sqltypes.Time):
- """Handle ``TIME`` datatype."""
-
- def get_col_spec(self):
- return "TIME"
-
-
-class FBText(sqltypes.Text):
- """Handle ``BLOB SUB_TYPE 1`` datatype (aka *textual* blob)."""
-
- def get_col_spec(self):
- return "BLOB SUB_TYPE 1"
-
-
-class FBString(sqltypes.String):
- """Handle ``VARCHAR(length)`` datatype."""
-
- def get_col_spec(self):
- if self.length:
- return "VARCHAR(%(length)s)" % {'length' : self.length}
- else:
- return "BLOB SUB_TYPE 1"
-
-
-class FBChar(sqltypes.CHAR):
- """Handle ``CHAR(length)`` datatype."""
-
- def get_col_spec(self):
- if self.length:
- return "CHAR(%(length)s)" % {'length' : self.length}
- else:
- return "BLOB SUB_TYPE 1"
-
-
-class FBBinary(sqltypes.Binary):
- """Handle ``BLOB SUB_TYPE 0`` datatype (aka *binary* blob)."""
-
- def get_col_spec(self):
- return "BLOB SUB_TYPE 0"
-
-
-class FBBoolean(sqltypes.Boolean):
- """Handle boolean values as a ``SMALLINT`` datatype."""
-
- def get_col_spec(self):
- return "SMALLINT"
-
-
-colspecs = {
- sqltypes.Integer : FBInteger,
- sqltypes.Smallinteger : FBSmallInteger,
- sqltypes.Numeric : FBNumeric,
- sqltypes.Float : FBFloat,
- sqltypes.DateTime : FBDateTime,
- sqltypes.Date : FBDate,
- sqltypes.Time : FBTime,
- sqltypes.String : FBString,
- sqltypes.Binary : FBBinary,
- sqltypes.Boolean : FBBoolean,
- sqltypes.Text : FBText,
- sqltypes.CHAR: FBChar,
-}
-
-
-ischema_names = {
- 'SHORT': lambda r: FBSmallInteger(),
- 'LONG': lambda r: FBInteger(),
- 'QUAD': lambda r: FBFloat(),
- 'FLOAT': lambda r: FBFloat(),
- 'DATE': lambda r: FBDate(),
- 'TIME': lambda r: FBTime(),
- 'TEXT': lambda r: FBString(r['flen']),
- 'INT64': lambda r: FBNumeric(precision=r['fprec'], scale=r['fscale'] * -1), # This generically handles NUMERIC()
- 'DOUBLE': lambda r: FBFloat(),
- 'TIMESTAMP': lambda r: FBDateTime(),
- 'VARYING': lambda r: FBString(r['flen']),
- 'CSTRING': lambda r: FBChar(r['flen']),
- 'BLOB': lambda r: r['stype']==1 and FBText() or FBBinary()
- }
-
-RETURNING_KW_NAME = 'firebird_returning'
-
-class FBExecutionContext(default.DefaultExecutionContext):
- pass
-
-
-class FBDialect(default.DefaultDialect):
- """Firebird dialect"""
- name = 'firebird'
- supports_sane_rowcount = False
- supports_sane_multi_rowcount = False
- max_identifier_length = 31
- preexecute_pk_sequences = True
- supports_pk_autoincrement = False
-
- def __init__(self, type_conv=200, concurrency_level=1, **kwargs):
- default.DefaultDialect.__init__(self, **kwargs)
-
- self.type_conv = type_conv
- self.concurrency_level = concurrency_level
-
- def dbapi(cls):
- import kinterbasdb
- return kinterbasdb
- dbapi = classmethod(dbapi)
-
- def create_connect_args(self, url):
- opts = url.translate_connect_args(username='user')
- if opts.get('port'):
- opts['host'] = "%s/%s" % (opts['host'], opts['port'])
- del opts['port']
- opts.update(url.query)
-
- type_conv = opts.pop('type_conv', self.type_conv)
- concurrency_level = opts.pop('concurrency_level', self.concurrency_level)
- global _initialized_kb
- if not _initialized_kb and self.dbapi is not None:
- _initialized_kb = True
- self.dbapi.init(type_conv=type_conv, concurrency_level=concurrency_level)
- return ([], opts)
-
- def type_descriptor(self, typeobj):
- return sqltypes.adapt_type(typeobj, colspecs)
-
- def server_version_info(self, connection):
- """Get the version of the Firebird server used by a connection.
-
- Returns a tuple of (`major`, `minor`, `build`), three integers
- representing the version of the attached server.
- """
-
- # This is the simpler approach (the other uses the services api),
- # that for backward compatibility reasons returns a string like
- # LI-V6.3.3.12981 Firebird 2.0
- # where the first version is a fake one resembling the old
- # Interbase signature. This is more than enough for our purposes,
- # as this is mainly (only?) used by the testsuite.
-
- from re import match
-
- fbconn = connection.connection.connection
- version = fbconn.server_version
- m = match('\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+) \w+ (\d+)\.(\d+)', version)
- if not m:
- raise AssertionError("Could not determine version from string '%s'" % version)
- return tuple([int(x) for x in m.group(5, 6, 4)])
-
- def _normalize_name(self, name):
- """Convert the name to lowercase if it is possible"""
-
- # Remove trailing spaces: FB uses a CHAR() type,
- # that is padded with spaces
- name = name and name.rstrip()
- if name is None:
- return None
- elif name.upper() == name and not self.identifier_preparer._requires_quotes(name.lower()):
- return name.lower()
- else:
- return name
-
- def _denormalize_name(self, name):
- """Revert a *normalized* name to its uppercase equivalent"""
-
- if name is None:
- return None
- elif name.lower() == name and not self.identifier_preparer._requires_quotes(name.lower()):
- return name.upper()
- else:
- return name
-
- def table_names(self, connection, schema):
- """Return a list of *normalized* table names omitting system relations."""
-
- s = """
- SELECT r.rdb$relation_name
- FROM rdb$relations r
- WHERE r.rdb$system_flag=0
- """
- return [self._normalize_name(row[0]) for row in connection.execute(s)]
-
- def has_table(self, connection, table_name, schema=None):
- """Return ``True`` if the given table exists, ignoring the `schema`."""
-
- tblqry = """
- SELECT 1 FROM rdb$database
- WHERE EXISTS (SELECT rdb$relation_name
- FROM rdb$relations
- WHERE rdb$relation_name=?)
- """
- c = connection.execute(tblqry, [self._denormalize_name(table_name)])
- row = c.fetchone()
- if row is not None:
- return True
- else:
- return False
-
- def has_sequence(self, connection, sequence_name):
- """Return ``True`` if the given sequence (generator) exists."""
-
- genqry = """
- SELECT 1 FROM rdb$database
- WHERE EXISTS (SELECT rdb$generator_name
- FROM rdb$generators
- WHERE rdb$generator_name=?)
- """
- c = connection.execute(genqry, [self._denormalize_name(sequence_name)])
- row = c.fetchone()
- if row is not None:
- return True
- else:
- return False
-
- def is_disconnect(self, e):
- if isinstance(e, self.dbapi.OperationalError):
- return 'Unable to complete network request to host' in str(e)
- elif isinstance(e, self.dbapi.ProgrammingError):
- msg = str(e)
- return ('Invalid connection state' in msg or
- 'Invalid cursor state' in msg)
- else:
- return False
-
- def reflecttable(self, connection, table, include_columns):
- # Query to extract the details of all the fields of the given table
- tblqry = """
- SELECT DISTINCT r.rdb$field_name AS fname,
- r.rdb$null_flag AS null_flag,
- t.rdb$type_name AS ftype,
- f.rdb$field_sub_type AS stype,
- f.rdb$field_length AS flen,
- f.rdb$field_precision AS fprec,
- f.rdb$field_scale AS fscale,
- COALESCE(r.rdb$default_source, f.rdb$default_source) AS fdefault
- FROM rdb$relation_fields r
- JOIN rdb$fields f ON r.rdb$field_source=f.rdb$field_name
- JOIN rdb$types t ON t.rdb$type=f.rdb$field_type AND t.rdb$field_name='RDB$FIELD_TYPE'
- WHERE f.rdb$system_flag=0 AND r.rdb$relation_name=?
- ORDER BY r.rdb$field_position
- """
- # Query to extract the PK/FK constrained fields of the given table
- keyqry = """
- SELECT se.rdb$field_name AS fname
- FROM rdb$relation_constraints rc
- JOIN rdb$index_segments se ON rc.rdb$index_name=se.rdb$index_name
- WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=?
- """
- # Query to extract the details of each UK/FK of the given table
- fkqry = """
- SELECT rc.rdb$constraint_name AS cname,
- cse.rdb$field_name AS fname,
- ix2.rdb$relation_name AS targetrname,
- se.rdb$field_name AS targetfname
- FROM rdb$relation_constraints rc
- JOIN rdb$indices ix1 ON ix1.rdb$index_name=rc.rdb$index_name
- JOIN rdb$indices ix2 ON ix2.rdb$index_name=ix1.rdb$foreign_key
- JOIN rdb$index_segments cse ON cse.rdb$index_name=ix1.rdb$index_name
- JOIN rdb$index_segments se ON se.rdb$index_name=ix2.rdb$index_name AND se.rdb$field_position=cse.rdb$field_position
- WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=?
- ORDER BY se.rdb$index_name, se.rdb$field_position
- """
- # Heuristic-query to determine the generator associated to a PK field
- genqry = """
- SELECT trigdep.rdb$depended_on_name AS fgenerator
- FROM rdb$dependencies tabdep
- JOIN rdb$dependencies trigdep ON (tabdep.rdb$dependent_name=trigdep.rdb$dependent_name
- AND trigdep.rdb$depended_on_type=14
- AND trigdep.rdb$dependent_type=2)
- JOIN rdb$triggers trig ON (trig.rdb$trigger_name=tabdep.rdb$dependent_name)
- WHERE tabdep.rdb$depended_on_name=?
- AND tabdep.rdb$depended_on_type=0
- AND trig.rdb$trigger_type=1
- AND tabdep.rdb$field_name=?
- AND (SELECT count(*)
- FROM rdb$dependencies trigdep2
- WHERE trigdep2.rdb$dependent_name = trigdep.rdb$dependent_name) = 2
- """
-
- tablename = self._denormalize_name(table.name)
-
- # get primary key fields
- c = connection.execute(keyqry, ["PRIMARY KEY", tablename])
- pkfields = [self._normalize_name(r['fname']) for r in c.fetchall()]
-
- # get all of the fields for this table
- c = connection.execute(tblqry, [tablename])
-
- found_table = False
- while True:
- row = c.fetchone()
- if row is None:
- break
- found_table = True
-
- name = self._normalize_name(row['fname'])
- if include_columns and name not in include_columns:
- continue
- args = [name]
-
- kw = {}
- # get the data type
- coltype = ischema_names.get(row['ftype'].rstrip())
- if coltype is None:
- util.warn("Did not recognize type '%s' of column '%s'" %
- (str(row['ftype']), name))
- coltype = sqltypes.NULLTYPE
- else:
- coltype = coltype(row)
- args.append(coltype)
-
- # is it a primary key?
- kw['primary_key'] = name in pkfields
-
- # is it nullable?
- kw['nullable'] = not bool(row['null_flag'])
-
- # does it have a default value?
- if row['fdefault'] is not None:
- # the value comes down as "DEFAULT 'value'"
- assert row['fdefault'].upper().startswith('DEFAULT '), row
- defvalue = row['fdefault'][8:]
- args.append(schema.DefaultClause(sql.text(defvalue)))
-
- col = schema.Column(*args, **kw)
- if kw['primary_key']:
- # if the PK is a single field, try to see if its linked to
- # a sequence thru a trigger
- if len(pkfields)==1:
- genc = connection.execute(genqry, [tablename, row['fname']])
- genr = genc.fetchone()
- if genr is not None:
- col.sequence = schema.Sequence(self._normalize_name(genr['fgenerator']))
-
- table.append_column(col)
-
- if not found_table:
- raise exc.NoSuchTableError(table.name)
-
- # get the foreign keys
- c = connection.execute(fkqry, ["FOREIGN KEY", tablename])
- fks = {}
- while True:
- row = c.fetchone()
- if not row:
- break
-
- cname = self._normalize_name(row['cname'])
- try:
- fk = fks[cname]
- except KeyError:
- fks[cname] = fk = ([], [])
- rname = self._normalize_name(row['targetrname'])
- schema.Table(rname, table.metadata, autoload=True, autoload_with=connection)
- fname = self._normalize_name(row['fname'])
- refspec = rname + '.' + self._normalize_name(row['targetfname'])
- fk[0].append(fname)
- fk[1].append(refspec)
-
- for name, value in fks.iteritems():
- table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], name=name, link_to_name=True))
-
- def do_execute(self, cursor, statement, parameters, **kwargs):
- # kinterbase does not accept a None, but wants an empty list
- # when there are no arguments.
- cursor.execute(statement, parameters or [])
-
- def do_rollback(self, connection):
- # Use the retaining feature, that keeps the transaction going
- connection.rollback(True)
-
- def do_commit(self, connection):
- # Use the retaining feature, that keeps the transaction going
- connection.commit(True)
-
-
-def _substring(s, start, length=None):
- "Helper function to handle Firebird 2 SUBSTRING builtin"
-
- if length is None:
- return "SUBSTRING(%s FROM %s)" % (s, start)
- else:
- return "SUBSTRING(%s FROM %s FOR %s)" % (s, start, length)
-
-
-class FBCompiler(sql.compiler.DefaultCompiler):
- """Firebird specific idiosincrasies"""
-
- # Firebird lacks a builtin modulo operator, but there is
- # an equivalent function in the ib_udf library.
- operators = sql.compiler.DefaultCompiler.operators.copy()
- operators.update({
- sql.operators.mod : lambda x, y:"mod(%s, %s)" % (x, y)
- })
-
- def visit_alias(self, alias, asfrom=False, **kwargs):
- # Override to not use the AS keyword which FB 1.5 does not like
- if asfrom:
- return self.process(alias.original, asfrom=True, **kwargs) + " " + self.preparer.format_alias(alias, self._anonymize(alias.name))
- else:
- return self.process(alias.original, **kwargs)
-
- functions = sql.compiler.DefaultCompiler.functions.copy()
- functions['substring'] = _substring
-
- def function_argspec(self, func):
- if func.clauses:
- return self.process(func.clause_expr)
- else:
- return ""
-
- def default_from(self):
- return " FROM rdb$database"
-
- def visit_sequence(self, seq):
- return "gen_id(%s, 1)" % self.preparer.format_sequence(seq)
-
- def get_select_precolumns(self, select):
- """Called when building a ``SELECT`` statement, position is just
- before column list Firebird puts the limit and offset right
- after the ``SELECT``...
- """
-
- result = ""
- if select._limit:
- result += "FIRST %d " % select._limit
- if select._offset:
- result +="SKIP %d " % select._offset
- if select._distinct:
- result += "DISTINCT "
- return result
-
- def limit_clause(self, select):
- """Already taken care of in the `get_select_precolumns` method."""
-
- return ""
-
- LENGTH_FUNCTION_NAME = 'char_length'
- def function_string(self, func):
- """Substitute the ``length`` function.
-
- On newer FB there is a ``char_length`` function, while older
- ones need the ``strlen`` UDF.
- """
-
- if func.name == 'length':
- return self.LENGTH_FUNCTION_NAME + '%(expr)s'
- return super(FBCompiler, self).function_string(func)
-
- def _append_returning(self, text, stmt):
- returning_cols = stmt.kwargs[RETURNING_KW_NAME]
- def flatten_columnlist(collist):
- for c in collist:
- if isinstance(c, sql.expression.Selectable):
- for co in c.columns:
- yield co
- else:
- yield c
- columns = [self.process(c, within_columns_clause=True)
- for c in flatten_columnlist(returning_cols)]
- text += ' RETURNING ' + ', '.join(columns)
- return text
-
- def visit_update(self, update_stmt):
- text = super(FBCompiler, self).visit_update(update_stmt)
- if RETURNING_KW_NAME in update_stmt.kwargs:
- return self._append_returning(text, update_stmt)
- else:
- return text
-
- def visit_insert(self, insert_stmt):
- text = super(FBCompiler, self).visit_insert(insert_stmt)
- if RETURNING_KW_NAME in insert_stmt.kwargs:
- return self._append_returning(text, insert_stmt)
- else:
- return text
-
- def visit_delete(self, delete_stmt):
- text = super(FBCompiler, self).visit_delete(delete_stmt)
- if RETURNING_KW_NAME in delete_stmt.kwargs:
- return self._append_returning(text, delete_stmt)
- else:
- return text
-
-
-class FBSchemaGenerator(sql.compiler.SchemaGenerator):
- """Firebird syntactic idiosincrasies"""
-
- def get_column_specification(self, column, **kwargs):
- colspec = self.preparer.format_column(column)
- colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
-
- default = self.get_column_default_string(column)
- if default is not None:
- colspec += " DEFAULT " + default
-
- if not column.nullable or column.primary_key:
- colspec += " NOT NULL"
-
- return colspec
-
- def visit_sequence(self, sequence):
- """Generate a ``CREATE GENERATOR`` statement for the sequence."""
-
- if not self.checkfirst or not self.dialect.has_sequence(self.connection, sequence.name):
- self.append("CREATE GENERATOR %s" % self.preparer.format_sequence(sequence))
- self.execute()
-
-
-class FBSchemaDropper(sql.compiler.SchemaDropper):
- """Firebird syntactic idiosincrasies"""
-
- def visit_sequence(self, sequence):
- """Generate a ``DROP GENERATOR`` statement for the sequence."""
-
- if not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name):
- self.append("DROP GENERATOR %s" % self.preparer.format_sequence(sequence))
- self.execute()
-
-
-class FBDefaultRunner(base.DefaultRunner):
- """Firebird specific idiosincrasies"""
-
- def visit_sequence(self, seq):
- """Get the next value from the sequence using ``gen_id()``."""
-
- return self.execute_string("SELECT gen_id(%s, 1) FROM rdb$database" % \
- self.dialect.identifier_preparer.format_sequence(seq))
-
-
-RESERVED_WORDS = set(
- ["action", "active", "add", "admin", "after", "all", "alter", "and", "any",
- "as", "asc", "ascending", "at", "auto", "autoddl", "avg", "based", "basename",
- "base_name", "before", "begin", "between", "bigint", "blob", "blobedit", "buffer",
- "by", "cache", "cascade", "case", "cast", "char", "character", "character_length",
- "char_length", "check", "check_point_len", "check_point_length", "close", "collate",
- "collation", "column", "commit", "committed", "compiletime", "computed", "conditional",
- "connect", "constraint", "containing", "continue", "count", "create", "cstring",
- "current", "current_connection", "current_date", "current_role", "current_time",
- "current_timestamp", "current_transaction", "current_user", "cursor", "database",
- "date", "day", "db_key", "debug", "dec", "decimal", "declare", "default", "delete",
- "desc", "descending", "describe", "descriptor", "disconnect", "display", "distinct",
- "do", "domain", "double", "drop", "echo", "edit", "else", "end", "entry_point",
- "escape", "event", "exception", "execute", "exists", "exit", "extern", "external",
- "extract", "fetch", "file", "filter", "float", "for", "foreign", "found", "free_it",
- "from", "full", "function", "gdscode", "generator", "gen_id", "global", "goto",
- "grant", "group", "group_commit_", "group_commit_wait", "having", "help", "hour",
- "if", "immediate", "in", "inactive", "index", "indicator", "init", "inner", "input",
- "input_type", "insert", "int", "integer", "into", "is", "isolation", "isql", "join",
- "key", "lc_messages", "lc_type", "left", "length", "lev", "level", "like", "logfile",
- "log_buffer_size", "log_buf_size", "long", "manual", "max", "maximum", "maximum_segment",
- "max_segment", "merge", "message", "min", "minimum", "minute", "module_name", "month",
- "names", "national", "natural", "nchar", "no", "noauto", "not", "null", "numeric",
- "num_log_buffers", "num_log_bufs", "octet_length", "of", "on", "only", "open", "option",
- "or", "order", "outer", "output", "output_type", "overflow", "page", "pagelength",
- "pages", "page_size", "parameter", "password", "plan", "position", "post_event",
- "precision", "prepare", "primary", "privileges", "procedure", "protected", "public",
- "quit", "raw_partitions", "rdb$db_key", "read", "real", "record_version", "recreate",
- "references", "release", "release", "reserv", "reserving", "restrict", "retain",
- "return", "returning_values", "returns", "revoke", "right", "role", "rollback",
- "row_count", "runtime", "savepoint", "schema", "second", "segment", "select",
- "set", "shadow", "shared", "shell", "show", "singular", "size", "smallint",
- "snapshot", "some", "sort", "sqlcode", "sqlerror", "sqlwarning", "stability",
- "starting", "starts", "statement", "static", "statistics", "sub_type", "sum",
- "suspend", "table", "terminator", "then", "time", "timestamp", "to", "transaction",
- "translate", "translation", "trigger", "trim", "type", "uncommitted", "union",
- "unique", "update", "upper", "user", "using", "value", "values", "varchar",
- "variable", "varying", "version", "view", "wait", "wait_time", "weekday", "when",
- "whenever", "where", "while", "with", "work", "write", "year", "yearday" ])
-
-
-class FBIdentifierPreparer(sql.compiler.IdentifierPreparer):
- """Install Firebird specific reserved words."""
-
- reserved_words = RESERVED_WORDS
-
- def __init__(self, dialect):
- super(FBIdentifierPreparer, self).__init__(dialect, omit_schema=True)
-
-
-dialect = FBDialect
-dialect.statement_compiler = FBCompiler
-dialect.schemagenerator = FBSchemaGenerator
-dialect.schemadropper = FBSchemaDropper
-dialect.defaultrunner = FBDefaultRunner
-dialect.preparer = FBIdentifierPreparer
-dialect.execution_ctx_cls = FBExecutionContext
+++ /dev/null
-"""
-information schema implementation.
-
-This module is deprecated and will not be present in this form in SQLAlchemy 0.6.
-
-"""
-from sqlalchemy import util
-
-util.warn_deprecated("the information_schema module is deprecated.")
-
-import sqlalchemy.sql as sql
-import sqlalchemy.exc as exc
-from sqlalchemy import select, MetaData, Table, Column, String, Integer
-from sqlalchemy.schema import DefaultClause, ForeignKeyConstraint
-
-ischema = MetaData()
-
-schemata = Table("schemata", ischema,
- Column("catalog_name", String),
- Column("schema_name", String),
- Column("schema_owner", String),
- schema="information_schema")
-
-tables = Table("tables", ischema,
- Column("table_catalog", String),
- Column("table_schema", String),
- Column("table_name", String),
- Column("table_type", String),
- schema="information_schema")
-
-columns = Table("columns", ischema,
- Column("table_schema", String),
- Column("table_name", String),
- Column("column_name", String),
- Column("is_nullable", Integer),
- Column("data_type", String),
- Column("ordinal_position", Integer),
- Column("character_maximum_length", Integer),
- Column("numeric_precision", Integer),
- Column("numeric_scale", Integer),
- Column("column_default", Integer),
- Column("collation_name", String),
- schema="information_schema")
-
-constraints = Table("table_constraints", ischema,
- Column("table_schema", String),
- Column("table_name", String),
- Column("constraint_name", String),
- Column("constraint_type", String),
- schema="information_schema")
-
-column_constraints = Table("constraint_column_usage", ischema,
- Column("table_schema", String),
- Column("table_name", String),
- Column("column_name", String),
- Column("constraint_name", String),
- schema="information_schema")
-
-pg_key_constraints = Table("key_column_usage", ischema,
- Column("table_schema", String),
- Column("table_name", String),
- Column("column_name", String),
- Column("constraint_name", String),
- Column("ordinal_position", Integer),
- schema="information_schema")
-
-#mysql_key_constraints = Table("key_column_usage", ischema,
-# Column("table_schema", String),
-# Column("table_name", String),
-# Column("column_name", String),
-# Column("constraint_name", String),
-# Column("referenced_table_schema", String),
-# Column("referenced_table_name", String),
-# Column("referenced_column_name", String),
-# schema="information_schema")
-
-key_constraints = pg_key_constraints
-
-ref_constraints = Table("referential_constraints", ischema,
- Column("constraint_catalog", String),
- Column("constraint_schema", String),
- Column("constraint_name", String),
- Column("unique_constraint_catlog", String),
- Column("unique_constraint_schema", String),
- Column("unique_constraint_name", String),
- Column("match_option", String),
- Column("update_rule", String),
- Column("delete_rule", String),
- schema="information_schema")
-
-
-def table_names(connection, schema):
- s = select([tables.c.table_name], tables.c.table_schema==schema)
- return [row[0] for row in connection.execute(s)]
-
-
-def reflecttable(connection, table, include_columns, ischema_names):
- key_constraints = pg_key_constraints
-
- if table.schema is not None:
- current_schema = table.schema
- else:
- current_schema = connection.default_schema_name()
-
- s = select([columns],
- sql.and_(columns.c.table_name==table.name,
- columns.c.table_schema==current_schema),
- order_by=[columns.c.ordinal_position])
-
- c = connection.execute(s)
- found_table = False
- while True:
- row = c.fetchone()
- if row is None:
- break
- #print "row! " + repr(row)
- # continue
- found_table = True
- (name, type, nullable, charlen, numericprec, numericscale, default) = (
- row[columns.c.column_name],
- row[columns.c.data_type],
- row[columns.c.is_nullable] == 'YES',
- row[columns.c.character_maximum_length],
- row[columns.c.numeric_precision],
- row[columns.c.numeric_scale],
- row[columns.c.column_default]
- )
- if include_columns and name not in include_columns:
- continue
-
- args = []
- for a in (charlen, numericprec, numericscale):
- if a is not None:
- args.append(a)
- coltype = ischema_names[type]
- #print "coltype " + repr(coltype) + " args " + repr(args)
- coltype = coltype(*args)
- colargs = []
- if default is not None:
- colargs.append(DefaultClause(sql.text(default)))
- table.append_column(Column(name, coltype, nullable=nullable, *colargs))
-
- if not found_table:
- raise exc.NoSuchTableError(table.name)
-
- # we are relying on the natural ordering of the constraint_column_usage table to return the referenced columns
- # in an order that corresponds to the ordinal_position in the key_constraints table, otherwise composite foreign keys
- # wont reflect properly. dont see a way around this based on whats available from information_schema
- s = select([constraints.c.constraint_name, constraints.c.constraint_type, constraints.c.table_name, key_constraints], use_labels=True, from_obj=[constraints.join(column_constraints, column_constraints.c.constraint_name==constraints.c.constraint_name).join(key_constraints, key_constraints.c.constraint_name==column_constraints.c.constraint_name)], order_by=[key_constraints.c.ordinal_position])
- s.append_column(column_constraints)
- s.append_whereclause(constraints.c.table_name==table.name)
- s.append_whereclause(constraints.c.table_schema==current_schema)
- colmap = [constraints.c.constraint_type, key_constraints.c.column_name, column_constraints.c.table_schema, column_constraints.c.table_name, column_constraints.c.column_name, constraints.c.constraint_name, key_constraints.c.ordinal_position]
- c = connection.execute(s)
-
- fks = {}
- while True:
- row = c.fetchone()
- if row is None:
- break
- (type, constrained_column, referred_schema, referred_table, referred_column, constraint_name, ordinal_position) = (
- row[colmap[0]],
- row[colmap[1]],
- row[colmap[2]],
- row[colmap[3]],
- row[colmap[4]],
- row[colmap[5]],
- row[colmap[6]]
- )
- #print "type %s on column %s to remote %s.%s.%s" % (type, constrained_column, referred_schema, referred_table, referred_column)
- if type == 'PRIMARY KEY':
- table.primary_key.add(table.c[constrained_column])
- elif type == 'FOREIGN KEY':
- try:
- fk = fks[constraint_name]
- except KeyError:
- fk = ([], [])
- fks[constraint_name] = fk
- if current_schema == referred_schema:
- referred_schema = table.schema
- if referred_schema is not None:
- Table(referred_table, table.metadata, autoload=True, schema=referred_schema, autoload_with=connection)
- refspec = ".".join([referred_schema, referred_table, referred_column])
- else:
- Table(referred_table, table.metadata, autoload=True, autoload_with=connection)
- refspec = ".".join([referred_table, referred_column])
- if constrained_column not in fk[0]:
- fk[0].append(constrained_column)
- if refspec not in fk[1]:
- fk[1].append(refspec)
-
- for name, value in fks.iteritems():
- table.append_constraint(ForeignKeyConstraint(value[0], value[1], name=name))
+++ /dev/null
-# mssql.py
-
-"""Support for the Microsoft SQL Server database.
-
-Driver
-------
-
-The MSSQL dialect will work with three different available drivers:
-
-* *pyodbc* - http://pyodbc.sourceforge.net/. This is the recommeded
- driver.
-
-* *pymssql* - http://pymssql.sourceforge.net/
-
-* *adodbapi* - http://adodbapi.sourceforge.net/
-
-Drivers are loaded in the order listed above based on availability.
-
-If you need to load a specific driver pass ``module_name`` when
-creating the engine::
-
- engine = create_engine('mssql://dsn', module_name='pymssql')
-
-``module_name`` currently accepts: ``pyodbc``, ``pymssql``, and
-``adodbapi``.
-
-Currently the pyodbc driver offers the greatest level of
-compatibility.
-
-Connecting
-----------
-
-Connecting with create_engine() uses the standard URL approach of
-``mssql://user:pass@host/dbname[?key=value&key=value...]``.
-
-If the database name is present, the tokens are converted to a
-connection string with the specified values. If the database is not
-present, then the host token is taken directly as the DSN name.
-
-Examples of pyodbc connection string URLs:
-
-* *mssql://mydsn* - connects using the specified DSN named ``mydsn``.
- The connection string that is created will appear like::
-
- dsn=mydsn;TrustedConnection=Yes
-
-* *mssql://user:pass@mydsn* - connects using the DSN named
- ``mydsn`` passing in the ``UID`` and ``PWD`` information. The
- connection string that is created will appear like::
-
- dsn=mydsn;UID=user;PWD=pass
-
-* *mssql://user:pass@mydsn/?LANGUAGE=us_english* - connects
- using the DSN named ``mydsn`` passing in the ``UID`` and ``PWD``
- information, plus the additional connection configuration option
- ``LANGUAGE``. The connection string that is created will appear
- like::
-
- dsn=mydsn;UID=user;PWD=pass;LANGUAGE=us_english
-
-* *mssql://user:pass@host/db* - connects using a connection string
- dynamically created that would appear like::
-
- DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass
-
-* *mssql://user:pass@host:123/db* - connects using a connection
- string that is dynamically created, which also includes the port
- information using the comma syntax. If your connection string
- requires the port information to be passed as a ``port`` keyword
- see the next example. This will create the following connection
- string::
-
- DRIVER={SQL Server};Server=host,123;Database=db;UID=user;PWD=pass
-
-* *mssql://user:pass@host/db?port=123* - connects using a connection
- string that is dynamically created that includes the port
- information as a separate ``port`` keyword. This will create the
- following connection string::
-
- DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass;port=123
-
-If you require a connection string that is outside the options
-presented above, use the ``odbc_connect`` keyword to pass in a
-urlencoded connection string. What gets passed in will be urldecoded
-and passed directly.
-
-For example::
-
- mssql:///?odbc_connect=dsn%3Dmydsn%3BDatabase%3Ddb
-
-would create the following connection string::
-
- dsn=mydsn;Database=db
-
-Encoding your connection string can be easily accomplished through
-the python shell. For example::
-
- >>> import urllib
- >>> urllib.quote_plus('dsn=mydsn;Database=db')
- 'dsn%3Dmydsn%3BDatabase%3Ddb'
-
-Additional arguments which may be specified either as query string
-arguments on the URL, or as keyword argument to
-:func:`~sqlalchemy.create_engine()` are:
-
-* *auto_identity_insert* - enables support for IDENTITY inserts by
- automatically turning IDENTITY INSERT ON and OFF as required.
- Defaults to ``True``.
-
-* *query_timeout* - allows you to override the default query timeout.
- Defaults to ``None``. This is only supported on pymssql.
-
-* *text_as_varchar* - if enabled this will treat all TEXT column
- types as their equivalent VARCHAR(max) type. This is often used if
- you need to compare a VARCHAR to a TEXT field, which is not
- supported directly on MSSQL. Defaults to ``False``.
-
-* *use_scope_identity* - allows you to specify that SCOPE_IDENTITY
- should be used in place of the non-scoped version @@IDENTITY.
- Defaults to ``False``. On pymssql this defaults to ``True``, and on
- pyodbc this defaults to ``True`` if the version of pyodbc being
- used supports it.
-
-* *has_window_funcs* - indicates whether or not window functions
- (LIMIT and OFFSET) are supported on the version of MSSQL being
- used. If you're running MSSQL 2005 or later turn this on to get
- OFFSET support. Defaults to ``False``.
-
-* *max_identifier_length* - allows you to se the maximum length of
- identfiers supported by the database. Defaults to 128. For pymssql
- the default is 30.
-
-* *schema_name* - use to set the schema name. Defaults to ``dbo``.
-
-Auto Increment Behavior
------------------------
-
-``IDENTITY`` columns are supported by using SQLAlchemy
-``schema.Sequence()`` objects. In other words::
-
- Table('test', mss_engine,
- Column('id', Integer,
- Sequence('blah',100,10), primary_key=True),
- Column('name', String(20))
- ).create()
-
-would yield::
-
- CREATE TABLE test (
- id INTEGER NOT NULL IDENTITY(100,10) PRIMARY KEY,
- name VARCHAR(20) NULL,
- )
-
-Note that the ``start`` and ``increment`` values for sequences are
-optional and will default to 1,1.
-
-* Support for ``SET IDENTITY_INSERT ON`` mode (automagic on / off for
- ``INSERT`` s)
-
-* Support for auto-fetching of ``@@IDENTITY/@@SCOPE_IDENTITY()`` on
- ``INSERT``
-
-Collation Support
------------------
-
-MSSQL specific string types support a collation parameter that
-creates a column-level specific collation for the column. The
-collation parameter accepts a Windows Collation Name or a SQL
-Collation Name. Supported types are MSChar, MSNChar, MSString,
-MSNVarchar, MSText, and MSNText. For example::
-
- Column('login', String(32, collation='Latin1_General_CI_AS'))
-
-will yield::
-
- login VARCHAR(32) COLLATE Latin1_General_CI_AS NULL
-
-LIMIT/OFFSET Support
---------------------
-
-MSSQL has no support for the LIMIT or OFFSET keysowrds. LIMIT is
-supported directly through the ``TOP`` Transact SQL keyword::
-
- select.limit
-
-will yield::
-
- SELECT TOP n
-
-If the ``has_window_funcs`` flag is set then LIMIT with OFFSET
-support is available through the ``ROW_NUMBER OVER`` construct. This
-construct requires an ``ORDER BY`` to be specified as well and is
-only available on MSSQL 2005 and later.
-
-Nullability
------------
-MSSQL has support for three levels of column nullability. The default
-nullability allows nulls and is explicit in the CREATE TABLE
-construct::
-
- name VARCHAR(20) NULL
-
-If ``nullable=None`` is specified then no specification is made. In
-other words the database's configured default is used. This will
-render::
-
- name VARCHAR(20)
-
-If ``nullable`` is ``True`` or ``False`` then the column will be
-``NULL` or ``NOT NULL`` respectively.
-
-Date / Time Handling
---------------------
-For MSSQL versions that support the ``DATE`` and ``TIME`` types
-(MSSQL 2008+) the data type is used. For versions that do not
-support the ``DATE`` and ``TIME`` types a ``DATETIME`` type is used
-instead and the MSSQL dialect handles converting the results
-properly. This means ``Date()`` and ``Time()`` are fully supported
-on all versions of MSSQL. If you do not desire this behavior then
-do not use the ``Date()`` or ``Time()`` types.
-
-Compatibility Levels
---------------------
-MSSQL supports the notion of setting compatibility levels at the
-database level. This allows, for instance, to run a database that
-is compatibile with SQL2000 while running on a SQL2005 database
-server. ``server_version_info`` will always retrun the database
-server version information (in this case SQL2005) and not the
-compatibiility level information. Because of this, if running under
-a backwards compatibility mode SQAlchemy may attempt to use T-SQL
-statements that are unable to be parsed by the database server.
-
-Known Issues
-------------
-
-* No support for more than one ``IDENTITY`` column per table
-
-* pymssql has problems with binary and unicode data that this module
- does **not** work around
-
-"""
-import datetime, decimal, inspect, operator, re, sys, urllib
-
-from sqlalchemy import sql, schema, exc, util
-from sqlalchemy import Table, MetaData, Column, ForeignKey, String, Integer
-from sqlalchemy.sql import select, compiler, expression, operators as sql_operators, functions as sql_functions
-from sqlalchemy.engine import default, base
-from sqlalchemy import types as sqltypes
-from decimal import Decimal as _python_Decimal
-
-
-RESERVED_WORDS = set(
- ['add', 'all', 'alter', 'and', 'any', 'as', 'asc', 'authorization',
- 'backup', 'begin', 'between', 'break', 'browse', 'bulk', 'by', 'cascade',
- 'case', 'check', 'checkpoint', 'close', 'clustered', 'coalesce',
- 'collate', 'column', 'commit', 'compute', 'constraint', 'contains',
- 'containstable', 'continue', 'convert', 'create', 'cross', 'current',
- 'current_date', 'current_time', 'current_timestamp', 'current_user',
- 'cursor', 'database', 'dbcc', 'deallocate', 'declare', 'default',
- 'delete', 'deny', 'desc', 'disk', 'distinct', 'distributed', 'double',
- 'drop', 'dump', 'else', 'end', 'errlvl', 'escape', 'except', 'exec',
- 'execute', 'exists', 'exit', 'external', 'fetch', 'file', 'fillfactor',
- 'for', 'foreign', 'freetext', 'freetexttable', 'from', 'full',
- 'function', 'goto', 'grant', 'group', 'having', 'holdlock', 'identity',
- 'identity_insert', 'identitycol', 'if', 'in', 'index', 'inner', 'insert',
- 'intersect', 'into', 'is', 'join', 'key', 'kill', 'left', 'like',
- 'lineno', 'load', 'merge', 'national', 'nocheck', 'nonclustered', 'not',
- 'null', 'nullif', 'of', 'off', 'offsets', 'on', 'open', 'opendatasource',
- 'openquery', 'openrowset', 'openxml', 'option', 'or', 'order', 'outer',
- 'over', 'percent', 'pivot', 'plan', 'precision', 'primary', 'print',
- 'proc', 'procedure', 'public', 'raiserror', 'read', 'readtext',
- 'reconfigure', 'references', 'replication', 'restore', 'restrict',
- 'return', 'revert', 'revoke', 'right', 'rollback', 'rowcount',
- 'rowguidcol', 'rule', 'save', 'schema', 'securityaudit', 'select',
- 'session_user', 'set', 'setuser', 'shutdown', 'some', 'statistics',
- 'system_user', 'table', 'tablesample', 'textsize', 'then', 'to', 'top',
- 'tran', 'transaction', 'trigger', 'truncate', 'tsequal', 'union',
- 'unique', 'unpivot', 'update', 'updatetext', 'use', 'user', 'values',
- 'varying', 'view', 'waitfor', 'when', 'where', 'while', 'with',
- 'writetext',
- ])
-
-
-class _StringType(object):
- """Base for MSSQL string types."""
-
- def __init__(self, collation=None, **kwargs):
- self.collation = kwargs.get('collate', collation)
-
- def _extend(self, spec):
- """Extend a string-type declaration with standard SQL
- COLLATE annotations.
- """
-
- if self.collation:
- collation = 'COLLATE %s' % self.collation
- else:
- collation = None
-
- return ' '.join([c for c in (spec, collation)
- if c is not None])
-
- def __repr__(self):
- attributes = inspect.getargspec(self.__init__)[0][1:]
- attributes.extend(inspect.getargspec(_StringType.__init__)[0][1:])
-
- params = {}
- for attr in attributes:
- val = getattr(self, attr)
- if val is not None and val is not False:
- params[attr] = val
-
- return "%s(%s)" % (self.__class__.__name__,
- ', '.join(['%s=%r' % (k, params[k]) for k in params]))
-
- def bind_processor(self, dialect):
- if self.convert_unicode or dialect.convert_unicode:
- if self.assert_unicode is None:
- assert_unicode = dialect.assert_unicode
- else:
- assert_unicode = self.assert_unicode
-
- if not assert_unicode:
- return None
-
- def process(value):
- if not isinstance(value, (unicode, sqltypes.NoneType)):
- if assert_unicode == 'warn':
- util.warn("Unicode type received non-unicode bind "
- "param value %r" % value)
- return value
- else:
- raise exc.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value)
- else:
- return value
- return process
- else:
- return None
-
-
-class MSNumeric(sqltypes.Numeric):
- def result_processor(self, dialect):
- if self.asdecimal:
- def process(value):
- if value is not None:
- return _python_Decimal(str(value))
- else:
- return value
- return process
- else:
- def process(value):
- return float(value)
- return process
-
- def bind_processor(self, dialect):
- def process(value):
- if value is None:
- # Not sure that this exception is needed
- return value
-
- elif isinstance(value, decimal.Decimal):
- if value.adjusted() < 0:
- result = "%s0.%s%s" % (
- (value < 0 and '-' or ''),
- '0' * (abs(value.adjusted()) - 1),
- "".join([str(nint) for nint in value._int]))
-
- else:
- if 'E' in str(value):
- result = "%s%s%s" % (
- (value < 0 and '-' or ''),
- "".join([str(s) for s in value._int]),
- "0" * (value.adjusted() - (len(value._int)-1)))
- else:
- if (len(value._int) - 1) > value.adjusted():
- result = "%s%s.%s" % (
- (value < 0 and '-' or ''),
- "".join([str(s) for s in value._int][0:value.adjusted() + 1]),
- "".join([str(s) for s in value._int][value.adjusted() + 1:]))
- else:
- result = "%s%s" % (
- (value < 0 and '-' or ''),
- "".join([str(s) for s in value._int][0:value.adjusted() + 1]))
-
- return result
-
- else:
- return value
-
- return process
-
- def get_col_spec(self):
- if self.precision is None:
- return "NUMERIC"
- else:
- return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale}
-
-
-class MSFloat(sqltypes.Float):
- def get_col_spec(self):
- if self.precision is None:
- return "FLOAT"
- else:
- return "FLOAT(%(precision)s)" % {'precision': self.precision}
-
-
-class MSReal(MSFloat):
- """A type for ``real`` numbers."""
-
- def __init__(self):
- """
- Construct a Real.
-
- """
- super(MSReal, self).__init__(precision=24)
-
- def adapt(self, impltype):
- return impltype()
-
- def get_col_spec(self):
- return "REAL"
-
-
-class MSInteger(sqltypes.Integer):
- def get_col_spec(self):
- return "INTEGER"
-
-
-class MSBigInteger(MSInteger):
- def get_col_spec(self):
- return "BIGINT"
-
-
-class MSTinyInteger(MSInteger):
- def get_col_spec(self):
- return "TINYINT"
-
-
-class MSSmallInteger(MSInteger):
- def get_col_spec(self):
- return "SMALLINT"
-
-
-class _DateTimeType(object):
- """Base for MSSQL datetime types."""
-
- def bind_processor(self, dialect):
- # if we receive just a date we can manipulate it
- # into a datetime since the db-api may not do this.
- def process(value):
- if type(value) is datetime.date:
- return datetime.datetime(value.year, value.month, value.day)
- return value
- return process
-
-
-class MSDateTime(_DateTimeType, sqltypes.DateTime):
- def get_col_spec(self):
- return "DATETIME"
-
-
-class MSDate(sqltypes.Date):
- def get_col_spec(self):
- return "DATE"
-
-
-class MSTime(sqltypes.Time):
- def __init__(self, precision=None, **kwargs):
- self.precision = precision
- super(MSTime, self).__init__()
-
- def get_col_spec(self):
- if self.precision:
- return "TIME(%s)" % self.precision
- else:
- return "TIME"
-
-
-class MSSmallDateTime(_DateTimeType, sqltypes.TypeEngine):
- def get_col_spec(self):
- return "SMALLDATETIME"
-
-
-class MSDateTime2(_DateTimeType, sqltypes.TypeEngine):
- def __init__(self, precision=None, **kwargs):
- self.precision = precision
-
- def get_col_spec(self):
- if self.precision:
- return "DATETIME2(%s)" % self.precision
- else:
- return "DATETIME2"
-
-
-class MSDateTimeOffset(_DateTimeType, sqltypes.TypeEngine):
- def __init__(self, precision=None, **kwargs):
- self.precision = precision
-
- def get_col_spec(self):
- if self.precision:
- return "DATETIMEOFFSET(%s)" % self.precision
- else:
- return "DATETIMEOFFSET"
-
-
-class MSDateTimeAsDate(_DateTimeType, MSDate):
- """ This is an implementation of the Date type for versions of MSSQL that
- do not support that specific type. In order to make it work a ``DATETIME``
- column specification is used and the results get converted back to just
- the date portion.
-
- """
-
- def get_col_spec(self):
- return "DATETIME"
-
- def result_processor(self, dialect):
- def process(value):
- # If the DBAPI returns the value as datetime.datetime(), truncate
- # it back to datetime.date()
- if type(value) is datetime.datetime:
- return value.date()
- return value
- return process
-
-
-class MSDateTimeAsTime(MSTime):
- """ This is an implementation of the Time type for versions of MSSQL that
- do not support that specific type. In order to make it work a ``DATETIME``
- column specification is used and the results get converted back to just
- the time portion.
-
- """
-
- __zero_date = datetime.date(1900, 1, 1)
-
- def get_col_spec(self):
- return "DATETIME"
-
- def bind_processor(self, dialect):
- def process(value):
- if type(value) is datetime.datetime:
- value = datetime.datetime.combine(self.__zero_date, value.time())
- elif type(value) is datetime.time:
- value = datetime.datetime.combine(self.__zero_date, value)
- return value
- return process
-
- def result_processor(self, dialect):
- def process(value):
- if type(value) is datetime.datetime:
- return value.time()
- elif type(value) is datetime.date:
- return datetime.time(0, 0, 0)
- return value
- return process
-
-
-class MSDateTime_adodbapi(MSDateTime):
- def result_processor(self, dialect):
- def process(value):
- # adodbapi will return datetimes with empty time values as datetime.date() objects.
- # Promote them back to full datetime.datetime()
- if type(value) is datetime.date:
- return datetime.datetime(value.year, value.month, value.day)
- return value
- return process
-
-
-class MSText(_StringType, sqltypes.Text):
- """MSSQL TEXT type, for variable-length text up to 2^31 characters."""
-
- def __init__(self, *args, **kwargs):
- """Construct a TEXT.
-
- :param collation: Optional, a column-level collation for this string
- value. Accepts a Windows Collation Name or a SQL Collation Name.
-
- """
- _StringType.__init__(self, **kwargs)
- sqltypes.Text.__init__(self, None,
- convert_unicode=kwargs.get('convert_unicode', False),
- assert_unicode=kwargs.get('assert_unicode', None))
-
- def get_col_spec(self):
- if self.dialect.text_as_varchar:
- return self._extend("VARCHAR(max)")
- else:
- return self._extend("TEXT")
-
-
-class MSNText(_StringType, sqltypes.UnicodeText):
- """MSSQL NTEXT type, for variable-length unicode text up to 2^30
- characters."""
-
- def __init__(self, *args, **kwargs):
- """Construct a NTEXT.
-
- :param collation: Optional, a column-level collation for this string
- value. Accepts a Windows Collation Name or a SQL Collation Name.
-
- """
- _StringType.__init__(self, **kwargs)
- sqltypes.UnicodeText.__init__(self, None,
- convert_unicode=kwargs.get('convert_unicode', True),
- assert_unicode=kwargs.get('assert_unicode', 'warn'))
-
- def get_col_spec(self):
- if self.dialect.text_as_varchar:
- return self._extend("NVARCHAR(max)")
- else:
- return self._extend("NTEXT")
-
-
-class MSString(_StringType, sqltypes.String):
- """MSSQL VARCHAR type, for variable-length non-Unicode data with a maximum
- of 8,000 characters."""
-
- def __init__(self, length=None, convert_unicode=False, assert_unicode=None, **kwargs):
- """Construct a VARCHAR.
-
- :param length: Optinal, maximum data length, in characters.
-
- :param convert_unicode: defaults to False. If True, convert
- ``unicode`` data sent to the database to a ``str``
- bytestring, and convert bytestrings coming back from the
- database into ``unicode``.
-
- Bytestrings are encoded using the dialect's
- :attr:`~sqlalchemy.engine.base.Dialect.encoding`, which
- defaults to `utf-8`.
-
- If False, may be overridden by
- :attr:`sqlalchemy.engine.base.Dialect.convert_unicode`.
-
- :param assert_unicode:
-
- If None (the default), no assertion will take place unless
- overridden by :attr:`sqlalchemy.engine.base.Dialect.assert_unicode`.
-
- If 'warn', will issue a runtime warning if a ``str``
- instance is used as a bind value.
-
- If true, will raise an :exc:`sqlalchemy.exc.InvalidRequestError`.
-
- :param collation: Optional, a column-level collation for this string
- value. Accepts a Windows Collation Name or a SQL Collation Name.
-
- """
- _StringType.__init__(self, **kwargs)
- sqltypes.String.__init__(self, length=length,
- convert_unicode=convert_unicode,
- assert_unicode=assert_unicode)
-
- def get_col_spec(self):
- if self.length:
- return self._extend("VARCHAR(%s)" % self.length)
- else:
- return self._extend("VARCHAR")
-
-
-class MSNVarchar(_StringType, sqltypes.Unicode):
- """MSSQL NVARCHAR type.
-
- For variable-length unicode character data up to 4,000 characters."""
-
- def __init__(self, length=None, **kwargs):
- """Construct a NVARCHAR.
-
- :param length: Optional, Maximum data length, in characters.
-
- :param collation: Optional, a column-level collation for this string
- value. Accepts a Windows Collation Name or a SQL Collation Name.
-
- """
- _StringType.__init__(self, **kwargs)
- sqltypes.Unicode.__init__(self, length=length,
- convert_unicode=kwargs.get('convert_unicode', True),
- assert_unicode=kwargs.get('assert_unicode', 'warn'))
-
- def adapt(self, impltype):
- return impltype(length=self.length,
- convert_unicode=self.convert_unicode,
- assert_unicode=self.assert_unicode,
- collation=self.collation)
-
- def get_col_spec(self):
- if self.length:
- return self._extend("NVARCHAR(%(length)s)" % {'length' : self.length})
- else:
- return self._extend("NVARCHAR")
-
-
-class MSChar(_StringType, sqltypes.CHAR):
- """MSSQL CHAR type, for fixed-length non-Unicode data with a maximum
- of 8,000 characters."""
-
- def __init__(self, length=None, convert_unicode=False, assert_unicode=None, **kwargs):
- """Construct a CHAR.
-
- :param length: Optinal, maximum data length, in characters.
-
- :param convert_unicode: defaults to False. If True, convert
- ``unicode`` data sent to the database to a ``str``
- bytestring, and convert bytestrings coming back from the
- database into ``unicode``.
-
- Bytestrings are encoded using the dialect's
- :attr:`~sqlalchemy.engine.base.Dialect.encoding`, which
- defaults to `utf-8`.
-
- If False, may be overridden by
- :attr:`sqlalchemy.engine.base.Dialect.convert_unicode`.
-
- :param assert_unicode:
-
- If None (the default), no assertion will take place unless
- overridden by :attr:`sqlalchemy.engine.base.Dialect.assert_unicode`.
-
- If 'warn', will issue a runtime warning if a ``str``
- instance is used as a bind value.
-
- If true, will raise an :exc:`sqlalchemy.exc.InvalidRequestError`.
-
- :param collation: Optional, a column-level collation for this string
- value. Accepts a Windows Collation Name or a SQL Collation Name.
-
- """
- _StringType.__init__(self, **kwargs)
- sqltypes.CHAR.__init__(self, length=length,
- convert_unicode=convert_unicode,
- assert_unicode=assert_unicode)
-
- def get_col_spec(self):
- if self.length:
- return self._extend("CHAR(%s)" % self.length)
- else:
- return self._extend("CHAR")
-
-
-class MSNChar(_StringType, sqltypes.NCHAR):
- """MSSQL NCHAR type.
-
- For fixed-length unicode character data up to 4,000 characters."""
-
- def __init__(self, length=None, **kwargs):
- """Construct an NCHAR.
-
- :param length: Optional, Maximum data length, in characters.
-
- :param collation: Optional, a column-level collation for this string
- value. Accepts a Windows Collation Name or a SQL Collation Name.
-
- """
- _StringType.__init__(self, **kwargs)
- sqltypes.NCHAR.__init__(self, length=length,
- convert_unicode=kwargs.get('convert_unicode', True),
- assert_unicode=kwargs.get('assert_unicode', 'warn'))
-
- def get_col_spec(self):
- if self.length:
- return self._extend("NCHAR(%(length)s)" % {'length' : self.length})
- else:
- return self._extend("NCHAR")
-
-
-class MSGenericBinary(sqltypes.Binary):
- """The Binary type assumes that a Binary specification without a length
- is an unbound Binary type whereas one with a length specification results
- in a fixed length Binary type.
-
- If you want standard MSSQL ``BINARY`` behavior use the ``MSBinary`` type.
-
- """
-
- def get_col_spec(self):
- if self.length:
- return "BINARY(%s)" % self.length
- else:
- return "IMAGE"
-
-
-class MSBinary(MSGenericBinary):
- def get_col_spec(self):
- if self.length:
- return "BINARY(%s)" % self.length
- else:
- return "BINARY"
-
-
-class MSVarBinary(MSGenericBinary):
- def get_col_spec(self):
- if self.length:
- return "VARBINARY(%s)" % self.length
- else:
- return "VARBINARY"
-
-
-class MSImage(MSGenericBinary):
- def get_col_spec(self):
- return "IMAGE"
-
-
-class MSBoolean(sqltypes.Boolean):
- def get_col_spec(self):
- return "BIT"
-
- def result_processor(self, dialect):
- def process(value):
- if value is None:
- return None
- return value and True or False
- return process
-
- def bind_processor(self, dialect):
- def process(value):
- if value is True:
- return 1
- elif value is False:
- return 0
- elif value is None:
- return None
- else:
- return value and True or False
- return process
-
-
-class MSTimeStamp(sqltypes.TIMESTAMP):
- def get_col_spec(self):
- return "TIMESTAMP"
-
-
-class MSMoney(sqltypes.TypeEngine):
- def get_col_spec(self):
- return "MONEY"
-
-
-class MSSmallMoney(MSMoney):
- def get_col_spec(self):
- return "SMALLMONEY"
-
-
-class MSUniqueIdentifier(sqltypes.TypeEngine):
- def get_col_spec(self):
- return "UNIQUEIDENTIFIER"
-
-
-class MSVariant(sqltypes.TypeEngine):
- def get_col_spec(self):
- return "SQL_VARIANT"
-
-ischema = MetaData()
-
-schemata = Table("SCHEMATA", ischema,
- Column("CATALOG_NAME", String, key="catalog_name"),
- Column("SCHEMA_NAME", String, key="schema_name"),
- Column("SCHEMA_OWNER", String, key="schema_owner"),
- schema="INFORMATION_SCHEMA")
-
-tables = Table("TABLES", ischema,
- Column("TABLE_CATALOG", String, key="table_catalog"),
- Column("TABLE_SCHEMA", String, key="table_schema"),
- Column("TABLE_NAME", String, key="table_name"),
- Column("TABLE_TYPE", String, key="table_type"),
- schema="INFORMATION_SCHEMA")
-
-columns = Table("COLUMNS", ischema,
- Column("TABLE_SCHEMA", String, key="table_schema"),
- Column("TABLE_NAME", String, key="table_name"),
- Column("COLUMN_NAME", String, key="column_name"),
- Column("IS_NULLABLE", Integer, key="is_nullable"),
- Column("DATA_TYPE", String, key="data_type"),
- Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
- Column("CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length"),
- Column("NUMERIC_PRECISION", Integer, key="numeric_precision"),
- Column("NUMERIC_SCALE", Integer, key="numeric_scale"),
- Column("COLUMN_DEFAULT", Integer, key="column_default"),
- Column("COLLATION_NAME", String, key="collation_name"),
- schema="INFORMATION_SCHEMA")
-
-constraints = Table("TABLE_CONSTRAINTS", ischema,
- Column("TABLE_SCHEMA", String, key="table_schema"),
- Column("TABLE_NAME", String, key="table_name"),
- Column("CONSTRAINT_NAME", String, key="constraint_name"),
- Column("CONSTRAINT_TYPE", String, key="constraint_type"),
- schema="INFORMATION_SCHEMA")
-
-column_constraints = Table("CONSTRAINT_COLUMN_USAGE", ischema,
- Column("TABLE_SCHEMA", String, key="table_schema"),
- Column("TABLE_NAME", String, key="table_name"),
- Column("COLUMN_NAME", String, key="column_name"),
- Column("CONSTRAINT_NAME", String, key="constraint_name"),
- schema="INFORMATION_SCHEMA")
-
-key_constraints = Table("KEY_COLUMN_USAGE", ischema,
- Column("TABLE_SCHEMA", String, key="table_schema"),
- Column("TABLE_NAME", String, key="table_name"),
- Column("COLUMN_NAME", String, key="column_name"),
- Column("CONSTRAINT_NAME", String, key="constraint_name"),
- Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
- schema="INFORMATION_SCHEMA")
-
-ref_constraints = Table("REFERENTIAL_CONSTRAINTS", ischema,
- Column("CONSTRAINT_CATALOG", String, key="constraint_catalog"),
- Column("CONSTRAINT_SCHEMA", String, key="constraint_schema"),
- Column("CONSTRAINT_NAME", String, key="constraint_name"),
- Column("UNIQUE_CONSTRAINT_CATLOG", String, key="unique_constraint_catalog"),
- Column("UNIQUE_CONSTRAINT_SCHEMA", String, key="unique_constraint_schema"),
- Column("UNIQUE_CONSTRAINT_NAME", String, key="unique_constraint_name"),
- Column("MATCH_OPTION", String, key="match_option"),
- Column("UPDATE_RULE", String, key="update_rule"),
- Column("DELETE_RULE", String, key="delete_rule"),
- schema="INFORMATION_SCHEMA")
-
-def _has_implicit_sequence(column):
- return column.primary_key and \
- column.autoincrement and \
- isinstance(column.type, sqltypes.Integer) and \
- not column.foreign_keys and \
- (
- column.default is None or
- (
- isinstance(column.default, schema.Sequence) and
- column.default.optional)
- )
-
-def _table_sequence_column(tbl):
- if not hasattr(tbl, '_ms_has_sequence'):
- tbl._ms_has_sequence = None
- for column in tbl.c:
- if getattr(column, 'sequence', False) or _has_implicit_sequence(column):
- tbl._ms_has_sequence = column
- break
- return tbl._ms_has_sequence
-
-class MSSQLExecutionContext(default.DefaultExecutionContext):
- IINSERT = False
- HASIDENT = False
-
- def pre_exec(self):
- """Activate IDENTITY_INSERT if needed."""
-
- if self.compiled.isinsert:
- tbl = self.compiled.statement.table
- seq_column = _table_sequence_column(tbl)
- self.HASIDENT = bool(seq_column)
- if self.dialect.auto_identity_insert and self.HASIDENT:
- self.IINSERT = tbl._ms_has_sequence.key in self.compiled_parameters[0]
- else:
- self.IINSERT = False
-
- if self.IINSERT:
- self.cursor.execute("SET IDENTITY_INSERT %s ON" %
- self.dialect.identifier_preparer.format_table(self.compiled.statement.table))
-
- def handle_dbapi_exception(self, e):
- if self.IINSERT:
- try:
- self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.dialect.identifier_preparer.format_table(self.compiled.statement.table))
- except:
- pass
-
- def post_exec(self):
- """Disable IDENTITY_INSERT if enabled."""
-
- if self.compiled.isinsert and not self.executemany and self.HASIDENT and not self.IINSERT:
- if not self._last_inserted_ids or self._last_inserted_ids[0] is None:
- if self.dialect.use_scope_identity:
- self.cursor.execute("SELECT scope_identity() AS lastrowid")
- else:
- self.cursor.execute("SELECT @@identity AS lastrowid")
- row = self.cursor.fetchone()
- self._last_inserted_ids = [int(row[0])] + self._last_inserted_ids[1:]
-
- if self.IINSERT:
- self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.dialect.identifier_preparer.format_table(self.compiled.statement.table))
-
-
-class MSSQLExecutionContext_pyodbc (MSSQLExecutionContext):
- def pre_exec(self):
- """where appropriate, issue "select scope_identity()" in the same statement"""
- super(MSSQLExecutionContext_pyodbc, self).pre_exec()
- if self.compiled.isinsert and self.HASIDENT and not self.IINSERT \
- and len(self.parameters) == 1 and self.dialect.use_scope_identity:
- self.statement += "; select scope_identity()"
-
- def post_exec(self):
- if self.HASIDENT and not self.IINSERT and self.dialect.use_scope_identity and not self.executemany:
- import pyodbc
- # Fetch the last inserted id from the manipulated statement
- # We may have to skip over a number of result sets with no data (due to triggers, etc.)
- while True:
- try:
- row = self.cursor.fetchone()
- break
- except pyodbc.Error, e:
- self.cursor.nextset()
- self._last_inserted_ids = [int(row[0])]
- else:
- super(MSSQLExecutionContext_pyodbc, self).post_exec()
-
-class MSSQLDialect(default.DefaultDialect):
- name = 'mssql'
- supports_default_values = True
- supports_empty_insert = False
- auto_identity_insert = True
- execution_ctx_cls = MSSQLExecutionContext
- text_as_varchar = False
- use_scope_identity = False
- has_window_funcs = False
- max_identifier_length = 128
- schema_name = "dbo"
-
- colspecs = {
- sqltypes.Unicode : MSNVarchar,
- sqltypes.Integer : MSInteger,
- sqltypes.Smallinteger: MSSmallInteger,
- sqltypes.Numeric : MSNumeric,
- sqltypes.Float : MSFloat,
- sqltypes.DateTime : MSDateTime,
- sqltypes.Date : MSDate,
- sqltypes.Time : MSTime,
- sqltypes.String : MSString,
- sqltypes.Binary : MSGenericBinary,
- sqltypes.Boolean : MSBoolean,
- sqltypes.Text : MSText,
- sqltypes.UnicodeText : MSNText,
- sqltypes.CHAR: MSChar,
- sqltypes.NCHAR: MSNChar,
- sqltypes.TIMESTAMP: MSTimeStamp,
- }
-
- ischema_names = {
- 'int' : MSInteger,
- 'bigint': MSBigInteger,
- 'smallint' : MSSmallInteger,
- 'tinyint' : MSTinyInteger,
- 'varchar' : MSString,
- 'nvarchar' : MSNVarchar,
- 'char' : MSChar,
- 'nchar' : MSNChar,
- 'text' : MSText,
- 'ntext' : MSNText,
- 'decimal' : MSNumeric,
- 'numeric' : MSNumeric,
- 'float' : MSFloat,
- 'datetime' : MSDateTime,
- 'datetime2' : MSDateTime2,
- 'datetimeoffset' : MSDateTimeOffset,
- 'date': MSDate,
- 'time': MSTime,
- 'smalldatetime' : MSSmallDateTime,
- 'binary' : MSBinary,
- 'varbinary' : MSVarBinary,
- 'bit': MSBoolean,
- 'real' : MSFloat,
- 'image' : MSImage,
- 'timestamp': MSTimeStamp,
- 'money': MSMoney,
- 'smallmoney': MSSmallMoney,
- 'uniqueidentifier': MSUniqueIdentifier,
- 'sql_variant': MSVariant,
- }
-
- def __new__(cls, *args, **kwargs):
- if cls is not MSSQLDialect:
- # this gets called with the dialect specific class
- return super(MSSQLDialect, cls).__new__(cls)
- dbapi = kwargs.get('dbapi', None)
- if dbapi:
- dialect = dialect_mapping.get(dbapi.__name__)
- return dialect(**kwargs)
- else:
- return object.__new__(cls)
-
- def __init__(self,
- auto_identity_insert=True, query_timeout=None,
- text_as_varchar=False, use_scope_identity=False,
- has_window_funcs=False, max_identifier_length=None,
- schema_name="dbo", **opts):
- self.auto_identity_insert = bool(auto_identity_insert)
- self.query_timeout = int(query_timeout or 0)
- self.schema_name = schema_name
-
- # to-do: the options below should use server version introspection to set themselves on connection
- self.text_as_varchar = bool(text_as_varchar)
- self.use_scope_identity = bool(use_scope_identity)
- self.has_window_funcs = bool(has_window_funcs)
- self.max_identifier_length = int(max_identifier_length or 0) or \
- self.max_identifier_length
- super(MSSQLDialect, self).__init__(**opts)
-
- @classmethod
- def dbapi(cls, module_name=None):
- if module_name:
- try:
- dialect_cls = dialect_mapping[module_name]
- return dialect_cls.import_dbapi()
- except KeyError:
- raise exc.InvalidRequestError("Unsupported MSSQL module '%s' requested (must be adodbpi, pymssql or pyodbc)" % module_name)
- else:
- for dialect_cls in [MSSQLDialect_pyodbc, MSSQLDialect_pymssql, MSSQLDialect_adodbapi]:
- try:
- return dialect_cls.import_dbapi()
- except ImportError, e:
- pass
- else:
- raise ImportError('No DBAPI module detected for MSSQL - please install pyodbc, pymssql, or adodbapi')
-
- @base.connection_memoize(('mssql', 'server_version_info'))
- def server_version_info(self, connection):
- """A tuple of the database server version.
-
- Formats the remote server version as a tuple of version values,
- e.g. ``(9, 0, 1399)``. If there are strings in the version number
- they will be in the tuple too, so don't count on these all being
- ``int`` values.
-
- This is a fast check that does not require a round trip. It is also
- cached per-Connection.
- """
- return connection.dialect._server_version_info(connection.connection)
-
- def _server_version_info(self, dbapi_con):
- """Return a tuple of the database's version number."""
- raise NotImplementedError()
-
- def create_connect_args(self, url):
- opts = url.translate_connect_args(username='user')
- opts.update(url.query)
- if 'auto_identity_insert' in opts:
- self.auto_identity_insert = bool(int(opts.pop('auto_identity_insert')))
- if 'query_timeout' in opts:
- self.query_timeout = int(opts.pop('query_timeout'))
- if 'text_as_varchar' in opts:
- self.text_as_varchar = bool(int(opts.pop('text_as_varchar')))
- if 'use_scope_identity' in opts:
- self.use_scope_identity = bool(int(opts.pop('use_scope_identity')))
- if 'has_window_funcs' in opts:
- self.has_window_funcs = bool(int(opts.pop('has_window_funcs')))
- return self.make_connect_string(opts, url.query)
-
- def type_descriptor(self, typeobj):
- newobj = sqltypes.adapt_type(typeobj, self.colspecs)
- # Some types need to know about the dialect
- if isinstance(newobj, (MSText, MSNText)):
- newobj.dialect = self
- return newobj
-
- def do_savepoint(self, connection, name):
- util.warn("Savepoint support in mssql is experimental and may lead to data loss.")
- connection.execute("IF @@TRANCOUNT = 0 BEGIN TRANSACTION")
- connection.execute("SAVE TRANSACTION %s" % name)
-
- def do_release_savepoint(self, connection, name):
- pass
-
- @base.connection_memoize(('dialect', 'default_schema_name'))
- def get_default_schema_name(self, connection):
- query = "SELECT user_name() as user_name;"
- user_name = connection.scalar(sql.text(query))
- if user_name is not None:
- # now, get the default schema
- query = """
- SELECT default_schema_name FROM
- sys.database_principals
- WHERE name = :user_name
- AND type = 'S'
- """
- try:
- default_schema_name = connection.scalar(sql.text(query),
- user_name=user_name)
- if default_schema_name is not None:
- return default_schema_name
- except:
- pass
- return self.schema_name
-
- def table_names(self, connection, schema):
- s = select([tables.c.table_name], tables.c.table_schema==schema)
- return [row[0] for row in connection.execute(s)]
-
-
- def has_table(self, connection, tablename, schema=None):
-
- current_schema = schema or self.get_default_schema_name(connection)
- s = sql.select([columns],
- current_schema
- and sql.and_(columns.c.table_name==tablename, columns.c.table_schema==current_schema)
- or columns.c.table_name==tablename,
- )
-
- c = connection.execute(s)
- row = c.fetchone()
- return row is not None
-
- def reflecttable(self, connection, table, include_columns):
- # Get base columns
- if table.schema is not None:
- current_schema = table.schema
- else:
- current_schema = self.get_default_schema_name(connection)
-
- s = sql.select([columns],
- current_schema
- and sql.and_(columns.c.table_name==table.name, columns.c.table_schema==current_schema)
- or columns.c.table_name==table.name,
- order_by=[columns.c.ordinal_position])
-
- c = connection.execute(s)
- found_table = False
- while True:
- row = c.fetchone()
- if row is None:
- break
- found_table = True
- (name, type, nullable, charlen, numericprec, numericscale, default, collation) = (
- row[columns.c.column_name],
- row[columns.c.data_type],
- row[columns.c.is_nullable] == 'YES',
- row[columns.c.character_maximum_length],
- row[columns.c.numeric_precision],
- row[columns.c.numeric_scale],
- row[columns.c.column_default],
- row[columns.c.collation_name]
- )
- if include_columns and name not in include_columns:
- continue
-
- coltype = self.ischema_names.get(type, None)
-
- kwargs = {}
- if coltype in (MSString, MSChar, MSNVarchar, MSNChar, MSText, MSNText, MSBinary, MSVarBinary, sqltypes.Binary):
- kwargs['length'] = charlen
- if collation:
- kwargs['collation'] = collation
- if coltype == MSText or (coltype in (MSString, MSNVarchar) and charlen == -1):
- kwargs.pop('length')
-
- if issubclass(coltype, sqltypes.Numeric):
- kwargs['scale'] = numericscale
- kwargs['precision'] = numericprec
-
- if coltype is None:
- util.warn("Did not recognize type '%s' of column '%s'" % (type, name))
- coltype = sqltypes.NULLTYPE
-
- coltype = coltype(**kwargs)
- colargs = []
- if default is not None:
- colargs.append(schema.DefaultClause(sql.text(default)))
- table.append_column(schema.Column(name, coltype, nullable=nullable, autoincrement=False, *colargs))
-
- if not found_table:
- raise exc.NoSuchTableError(table.name)
-
- # We also run an sp_columns to check for identity columns:
- cursor = connection.execute("sp_columns @table_name = '%s', @table_owner = '%s'" % (table.name, current_schema))
- ic = None
- while True:
- row = cursor.fetchone()
- if row is None:
- break
- col_name, type_name = row[3], row[5]
- if type_name.endswith("identity") and col_name in table.c:
- ic = table.c[col_name]
- ic.autoincrement = True
- # setup a psuedo-sequence to represent the identity attribute - we interpret this at table.create() time as the identity attribute
- ic.sequence = schema.Sequence(ic.name + '_identity', 1, 1)
- # MSSQL: only one identity per table allowed
- cursor.close()
- break
- if not ic is None:
- try:
- cursor = connection.execute("select ident_seed(?), ident_incr(?)", table.fullname, table.fullname)
- row = cursor.fetchone()
- cursor.close()
- if not row is None:
- ic.sequence.start = int(row[0])
- ic.sequence.increment = int(row[1])
- except:
- # ignoring it, works just like before
- pass
-
- # Add constraints
- RR = ref_constraints
- TC = constraints
- C = key_constraints.alias('C') #information_schema.constraint_column_usage: the constrained column
- R = key_constraints.alias('R') #information_schema.constraint_column_usage: the referenced column
-
- # Primary key constraints
- s = sql.select([C.c.column_name, TC.c.constraint_type], sql.and_(TC.c.constraint_name == C.c.constraint_name,
- C.c.table_name == table.name,
- C.c.table_schema == (table.schema or current_schema)))
- c = connection.execute(s)
- for row in c:
- if 'PRIMARY' in row[TC.c.constraint_type.name] and row[0] in table.c:
- table.primary_key.add(table.c[row[0]])
-
- # Foreign key constraints
- s = sql.select([C.c.column_name,
- R.c.table_schema, R.c.table_name, R.c.column_name,
- RR.c.constraint_name, RR.c.match_option, RR.c.update_rule, RR.c.delete_rule],
- sql.and_(C.c.table_name == table.name,
- C.c.table_schema == (table.schema or current_schema),
- C.c.constraint_name == RR.c.constraint_name,
- R.c.constraint_name == RR.c.unique_constraint_name,
- C.c.ordinal_position == R.c.ordinal_position
- ),
- order_by = [RR.c.constraint_name, R.c.ordinal_position])
- rows = connection.execute(s).fetchall()
-
- def _gen_fkref(table, rschema, rtbl, rcol):
- if rschema == current_schema and not table.schema:
- return '.'.join([rtbl, rcol])
- else:
- return '.'.join([rschema, rtbl, rcol])
-
- # group rows by constraint ID, to handle multi-column FKs
- fknm, scols, rcols = (None, [], [])
- for r in rows:
- scol, rschema, rtbl, rcol, rfknm, fkmatch, fkuprule, fkdelrule = r
- # if the reflected schema is the default schema then don't set it because this will
- # play into the metadata key causing duplicates.
- if rschema == current_schema and not table.schema:
- schema.Table(rtbl, table.metadata, autoload=True, autoload_with=connection)
- else:
- schema.Table(rtbl, table.metadata, schema=rschema, autoload=True, autoload_with=connection)
- if rfknm != fknm:
- if fknm:
- table.append_constraint(schema.ForeignKeyConstraint(scols, [_gen_fkref(table, s, t, c) for s, t, c in rcols], fknm, link_to_name=True))
- fknm, scols, rcols = (rfknm, [], [])
- if not scol in scols:
- scols.append(scol)
- if not (rschema, rtbl, rcol) in rcols:
- rcols.append((rschema, rtbl, rcol))
-
- if fknm and scols:
- table.append_constraint(schema.ForeignKeyConstraint(scols, [_gen_fkref(table, s, t, c) for s, t, c in rcols], fknm, link_to_name=True))
-
-
-class MSSQLDialect_pymssql(MSSQLDialect):
- supports_sane_rowcount = False
- max_identifier_length = 30
-
- @classmethod
- def import_dbapi(cls):
- import pymssql as module
- # pymmsql doesn't have a Binary method. we use string
- # TODO: monkeypatching here is less than ideal
- module.Binary = lambda st: str(st)
- try:
- module.version_info = tuple(map(int, module.__version__.split('.')))
- except:
- module.version_info = (0, 0, 0)
- return module
-
- def __init__(self, **params):
- super(MSSQLDialect_pymssql, self).__init__(**params)
- self.use_scope_identity = True
-
- # pymssql understands only ascii
- if self.convert_unicode:
- util.warn("pymssql does not support unicode")
- self.encoding = params.get('encoding', 'ascii')
-
- self.colspecs = MSSQLDialect.colspecs.copy()
- self.ischema_names = MSSQLDialect.ischema_names.copy()
- self.ischema_names['date'] = MSDateTimeAsDate
- self.colspecs[sqltypes.Date] = MSDateTimeAsDate
- self.ischema_names['time'] = MSDateTimeAsTime
- self.colspecs[sqltypes.Time] = MSDateTimeAsTime
-
- def create_connect_args(self, url):
- r = super(MSSQLDialect_pymssql, self).create_connect_args(url)
- if hasattr(self, 'query_timeout'):
- if self.dbapi.version_info > (0, 8, 0):
- r[1]['timeout'] = self.query_timeout
- else:
- self.dbapi._mssql.set_query_timeout(self.query_timeout)
- return r
-
- def make_connect_string(self, keys, query):
- if keys.get('port'):
- # pymssql expects port as host:port, not a separate arg
- keys['host'] = ''.join([keys.get('host', ''), ':', str(keys['port'])])
- del keys['port']
- return [[], keys]
-
- def is_disconnect(self, e):
- return isinstance(e, self.dbapi.DatabaseError) and "Error 10054" in str(e)
-
- def do_begin(self, connection):
- pass
-
-
-class MSSQLDialect_pyodbc(MSSQLDialect):
- supports_sane_rowcount = False
- supports_sane_multi_rowcount = False
- # PyODBC unicode is broken on UCS-4 builds
- supports_unicode = sys.maxunicode == 65535
- supports_unicode_statements = supports_unicode
- execution_ctx_cls = MSSQLExecutionContext_pyodbc
-
- def __init__(self, description_encoding='latin-1', **params):
- super(MSSQLDialect_pyodbc, self).__init__(**params)
- self.description_encoding = description_encoding
-
- if self.server_version_info < (10,):
- self.colspecs = MSSQLDialect.colspecs.copy()
- self.ischema_names = MSSQLDialect.ischema_names.copy()
- self.ischema_names['date'] = MSDateTimeAsDate
- self.colspecs[sqltypes.Date] = MSDateTimeAsDate
- self.ischema_names['time'] = MSDateTimeAsTime
- self.colspecs[sqltypes.Time] = MSDateTimeAsTime
-
- # FIXME: scope_identity sniff should look at server version, not the ODBC driver
- # whether use_scope_identity will work depends on the version of pyodbc
- try:
- import pyodbc
- self.use_scope_identity = hasattr(pyodbc.Cursor, 'nextset')
- except:
- pass
-
- @classmethod
- def import_dbapi(cls):
- import pyodbc as module
- return module
-
- def make_connect_string(self, keys, query):
- if 'max_identifier_length' in keys:
- self.max_identifier_length = int(keys.pop('max_identifier_length'))
-
- if 'odbc_connect' in keys:
- connectors = [urllib.unquote_plus(keys.pop('odbc_connect'))]
- else:
- dsn_connection = 'dsn' in keys or ('host' in keys and 'database' not in keys)
- if dsn_connection:
- connectors= ['dsn=%s' % (keys.pop('host', '') or keys.pop('dsn', ''))]
- else:
- port = ''
- if 'port' in keys and not 'port' in query:
- port = ',%d' % int(keys.pop('port'))
-
- connectors = ["DRIVER={%s}" % keys.pop('driver', 'SQL Server'),
- 'Server=%s%s' % (keys.pop('host', ''), port),
- 'Database=%s' % keys.pop('database', '') ]
-
- user = keys.pop("user", None)
- if user:
- connectors.append("UID=%s" % user)
- connectors.append("PWD=%s" % keys.pop('password', ''))
- else:
- connectors.append("TrustedConnection=Yes")
-
- # if set to 'Yes', the ODBC layer will try to automagically convert
- # textual data from your database encoding to your client encoding
- # This should obviously be set to 'No' if you query a cp1253 encoded
- # database from a latin1 client...
- if 'odbc_autotranslate' in keys:
- connectors.append("AutoTranslate=%s" % keys.pop("odbc_autotranslate"))
-
- connectors.extend(['%s=%s' % (k,v) for k,v in keys.iteritems()])
-
- return [[";".join (connectors)], {}]
-
- def is_disconnect(self, e):
- if isinstance(e, self.dbapi.ProgrammingError):
- return "The cursor's connection has been closed." in str(e) or 'Attempt to use a closed connection.' in str(e)
- elif isinstance(e, self.dbapi.Error):
- return '[08S01]' in str(e)
- else:
- return False
-
-
- def _server_version_info(self, dbapi_con):
- """Convert a pyodbc SQL_DBMS_VER string into a tuple."""
- version = []
- r = re.compile('[.\-]')
- for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)):
- try:
- version.append(int(n))
- except ValueError:
- version.append(n)
- return tuple(version)
-
-class MSSQLDialect_adodbapi(MSSQLDialect):
- supports_sane_rowcount = True
- supports_sane_multi_rowcount = True
- supports_unicode = sys.maxunicode == 65535
- supports_unicode_statements = True
-
- @classmethod
- def import_dbapi(cls):
- import adodbapi as module
- return module
-
- colspecs = MSSQLDialect.colspecs.copy()
- colspecs[sqltypes.DateTime] = MSDateTime_adodbapi
-
- ischema_names = MSSQLDialect.ischema_names.copy()
- ischema_names['datetime'] = MSDateTime_adodbapi
-
- def make_connect_string(self, keys, query):
- connectors = ["Provider=SQLOLEDB"]
- if 'port' in keys:
- connectors.append ("Data Source=%s, %s" % (keys.get("host"), keys.get("port")))
- else:
- connectors.append ("Data Source=%s" % keys.get("host"))
- connectors.append ("Initial Catalog=%s" % keys.get("database"))
- user = keys.get("user")
- if user:
- connectors.append("User Id=%s" % user)
- connectors.append("Password=%s" % keys.get("password", ""))
- else:
- connectors.append("Integrated Security=SSPI")
- return [[";".join (connectors)], {}]
-
- def is_disconnect(self, e):
- return isinstance(e, self.dbapi.adodbapi.DatabaseError) and "'connection failure'" in str(e)
-
-
-dialect_mapping = {
- 'pymssql': MSSQLDialect_pymssql,
- 'pyodbc': MSSQLDialect_pyodbc,
- 'adodbapi': MSSQLDialect_adodbapi
- }
-
-
-class MSSQLCompiler(compiler.DefaultCompiler):
- operators = compiler.OPERATORS.copy()
- operators.update({
- sql_operators.concat_op: '+',
- sql_operators.match_op: lambda x, y: "CONTAINS (%s, %s)" % (x, y)
- })
-
- functions = compiler.DefaultCompiler.functions.copy()
- functions.update (
- {
- sql_functions.now: 'CURRENT_TIMESTAMP',
- sql_functions.current_date: 'GETDATE()',
- 'length': lambda x: "LEN(%s)" % x,
- sql_functions.char_length: lambda x: "LEN(%s)" % x
- }
- )
-
- extract_map = compiler.DefaultCompiler.extract_map.copy()
- extract_map.update ({
- 'doy': 'dayofyear',
- 'dow': 'weekday',
- 'milliseconds': 'millisecond',
- 'microseconds': 'microsecond'
- })
-
- def __init__(self, *args, **kwargs):
- super(MSSQLCompiler, self).__init__(*args, **kwargs)
- self.tablealiases = {}
-
- def get_select_precolumns(self, select):
- """ MS-SQL puts TOP, it's version of LIMIT here """
- if select._distinct or select._limit:
- s = select._distinct and "DISTINCT " or ""
-
- if select._limit:
- if not select._offset:
- s += "TOP %s " % (select._limit,)
- else:
- if not self.dialect.has_window_funcs:
- raise exc.InvalidRequestError('MSSQL does not support LIMIT with an offset')
- return s
- return compiler.DefaultCompiler.get_select_precolumns(self, select)
-
- def limit_clause(self, select):
- # Limit in mssql is after the select keyword
- return ""
-
- def visit_select(self, select, **kwargs):
- """Look for ``LIMIT`` and OFFSET in a select statement, and if
- so tries to wrap it in a subquery with ``row_number()`` criterion.
-
- """
- if self.dialect.has_window_funcs and not getattr(select, '_mssql_visit', None) and select._offset:
- # to use ROW_NUMBER(), an ORDER BY is required.
- orderby = self.process(select._order_by_clause)
- if not orderby:
- raise exc.InvalidRequestError('MSSQL requires an order_by when using an offset.')
-
- _offset = select._offset
- _limit = select._limit
- select._mssql_visit = True
- select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("mssql_rn")).order_by(None).alias()
-
- limitselect = sql.select([c for c in select.c if c.key!='mssql_rn'])
- limitselect.append_whereclause("mssql_rn>%d" % _offset)
- if _limit is not None:
- limitselect.append_whereclause("mssql_rn<=%d" % (_limit + _offset))
- return self.process(limitselect, iswrapper=True, **kwargs)
- else:
- return compiler.DefaultCompiler.visit_select(self, select, **kwargs)
-
- def _schema_aliased_table(self, table):
- if getattr(table, 'schema', None) is not None:
- if table not in self.tablealiases:
- self.tablealiases[table] = table.alias()
- return self.tablealiases[table]
- else:
- return None
-
- def visit_table(self, table, mssql_aliased=False, **kwargs):
- if mssql_aliased:
- return super(MSSQLCompiler, self).visit_table(table, **kwargs)
-
- # alias schema-qualified tables
- alias = self._schema_aliased_table(table)
- if alias is not None:
- return self.process(alias, mssql_aliased=True, **kwargs)
- else:
- return super(MSSQLCompiler, self).visit_table(table, **kwargs)
-
- def visit_alias(self, alias, **kwargs):
- # translate for schema-qualified table aliases
- self.tablealiases[alias.original] = alias
- kwargs['mssql_aliased'] = True
- return super(MSSQLCompiler, self).visit_alias(alias, **kwargs)
-
- def visit_extract(self, extract):
- field = self.extract_map.get(extract.field, extract.field)
- return 'DATEPART("%s", %s)' % (field, self.process(extract.expr))
-
- def visit_rollback_to_savepoint(self, savepoint_stmt):
- return "ROLLBACK TRANSACTION %s" % self.preparer.format_savepoint(savepoint_stmt)
-
- def visit_column(self, column, result_map=None, **kwargs):
- if column.table is not None and \
- (not self.isupdate and not self.isdelete) or self.is_subquery():
- # translate for schema-qualified table aliases
- t = self._schema_aliased_table(column.table)
- if t is not None:
- converted = expression._corresponding_column_or_error(t, column)
-
- if result_map is not None:
- result_map[column.name.lower()] = (column.name, (column, ), column.type)
-
- return super(MSSQLCompiler, self).visit_column(converted, result_map=None, **kwargs)
-
- return super(MSSQLCompiler, self).visit_column(column, result_map=result_map, **kwargs)
-
- def visit_binary(self, binary, **kwargs):
- """Move bind parameters to the right-hand side of an operator, where
- possible.
-
- """
- if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq \
- and not isinstance(binary.right, expression._BindParamClause):
- return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator), **kwargs)
- else:
- if (binary.operator is operator.eq or binary.operator is operator.ne) and (
- (isinstance(binary.left, expression._FromGrouping) and isinstance(binary.left.element, expression._ScalarSelect)) or \
- (isinstance(binary.right, expression._FromGrouping) and isinstance(binary.right.element, expression._ScalarSelect)) or \
- isinstance(binary.left, expression._ScalarSelect) or isinstance(binary.right, expression._ScalarSelect)):
- op = binary.operator == operator.eq and "IN" or "NOT IN"
- return self.process(expression._BinaryExpression(binary.left, binary.right, op), **kwargs)
- return super(MSSQLCompiler, self).visit_binary(binary, **kwargs)
-
- def visit_insert(self, insert_stmt):
- insert_select = False
- if insert_stmt.parameters:
- insert_select = [p for p in insert_stmt.parameters.values() if isinstance(p, sql.Select)]
- if insert_select:
- self.isinsert = True
- colparams = self._get_colparams(insert_stmt)
- preparer = self.preparer
-
- insert = ' '.join(["INSERT"] +
- [self.process(x) for x in insert_stmt._prefixes])
-
- if not colparams and not self.dialect.supports_default_values and not self.dialect.supports_empty_insert:
- raise exc.CompileError(
- "The version of %s you are using does not support empty inserts." % self.dialect.name)
- elif not colparams and self.dialect.supports_default_values:
- return (insert + " INTO %s DEFAULT VALUES" % (
- (preparer.format_table(insert_stmt.table),)))
- else:
- return (insert + " INTO %s (%s) SELECT %s" %
- (preparer.format_table(insert_stmt.table),
- ', '.join([preparer.format_column(c[0])
- for c in colparams]),
- ', '.join([c[1] for c in colparams])))
- else:
- return super(MSSQLCompiler, self).visit_insert(insert_stmt)
-
- def label_select_column(self, select, column, asfrom):
- if isinstance(column, expression.Function):
- return column.label(None)
- else:
- return super(MSSQLCompiler, self).label_select_column(select, column, asfrom)
-
- def for_update_clause(self, select):
- # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use
- return ''
-
- def order_by_clause(self, select):
- order_by = self.process(select._order_by_clause)
-
- # MSSQL only allows ORDER BY in subqueries if there is a LIMIT
- if order_by and (not self.is_subquery() or select._limit):
- return " ORDER BY " + order_by
- else:
- return ""
-
-
-class MSSQLSchemaGenerator(compiler.SchemaGenerator):
- def get_column_specification(self, column, **kwargs):
- colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()
-
- if column.nullable is not None:
- if not column.nullable or column.primary_key:
- colspec += " NOT NULL"
- else:
- colspec += " NULL"
-
- if not column.table:
- raise exc.InvalidRequestError("mssql requires Table-bound columns in order to generate DDL")
-
- seq_col = _table_sequence_column(column.table)
-
- # install a IDENTITY Sequence if we have an implicit IDENTITY column
- if seq_col is column:
- sequence = getattr(column, 'sequence', None)
- if sequence:
- start, increment = sequence.start or 1, sequence.increment or 1
- else:
- start, increment = 1, 1
- colspec += " IDENTITY(%s,%s)" % (start, increment)
- else:
- default = self.get_column_default_string(column)
- if default is not None:
- colspec += " DEFAULT " + default
-
- return colspec
-
-class MSSQLSchemaDropper(compiler.SchemaDropper):
- def visit_index(self, index):
- self.append("\nDROP INDEX %s.%s" % (
- self.preparer.quote_identifier(index.table.name),
- self.preparer.quote(self._validate_identifier(index.name, False), index.quote)
- ))
- self.execute()
-
-
-class MSSQLIdentifierPreparer(compiler.IdentifierPreparer):
- reserved_words = RESERVED_WORDS
-
- def __init__(self, dialect):
- super(MSSQLIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']')
-
- def _escape_identifier(self, value):
- #TODO: determine MSSQL's escaping rules
- return value
-
- def quote_schema(self, schema, force=True):
- """Prepare a quoted table and schema name."""
- result = '.'.join([self.quote(x, force) for x in schema.split('.')])
- return result
-
-dialect = MSSQLDialect
-dialect.statement_compiler = MSSQLCompiler
-dialect.schemagenerator = MSSQLSchemaGenerator
-dialect.schemadropper = MSSQLSchemaDropper
-dialect.preparer = MSSQLIdentifierPreparer
-
+++ /dev/null
-# mxODBC.py
-# Copyright (C) 2007 Fisch Asset Management AG http://www.fam.ch
-# Coding: Alexander Houben alexander.houben@thor-solutions.ch
-#
-# This module is part of SQLAlchemy and is released under
-# the MIT License: http://www.opensource.org/licenses/mit-license.php
-
-"""
-A wrapper for a mx.ODBC.Windows DB-API connection.
-
-Makes sure the mx module is configured to return datetime objects instead
-of mx.DateTime.DateTime objects.
-"""
-
-from mx.ODBC.Windows import *
-
-
-class Cursor:
- def __init__(self, cursor):
- self.cursor = cursor
-
- def __getattr__(self, attr):
- res = getattr(self.cursor, attr)
- return res
-
- def execute(self, *args, **kwargs):
- res = self.cursor.execute(*args, **kwargs)
- return res
-
-
-class Connection:
- def myErrorHandler(self, connection, cursor, errorclass, errorvalue):
- err0, err1, err2, err3 = errorvalue
- #print ", ".join(["Err%d: %s"%(x, errorvalue[x]) for x in range(4)])
- if int(err1) == 109:
- # Ignore "Null value eliminated in aggregate function", this is not an error
- return
- raise errorclass, errorvalue
-
- def __init__(self, conn):
- self.conn = conn
- # install a mx ODBC error handler
- self.conn.errorhandler = self.myErrorHandler
-
- def __getattr__(self, attr):
- res = getattr(self.conn, attr)
- return res
-
- def cursor(self, *args, **kwargs):
- res = Cursor(self.conn.cursor(*args, **kwargs))
- return res
-
-
-# override 'connect' call
-def connect(*args, **kwargs):
- import mx.ODBC.Windows
- conn = mx.ODBC.Windows.Connect(*args, **kwargs)
- conn.datetimeformat = mx.ODBC.Windows.PYDATETIME_DATETIMEFORMAT
- return Connection(conn)
-Connect = connect
+++ /dev/null
-# oracle.py
-# Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com
-#
-# This module is part of SQLAlchemy and is released under
-# the MIT License: http://www.opensource.org/licenses/mit-license.php
-"""Support for the Oracle database.
-
-Oracle version 8 through current (11g at the time of this writing) are supported.
-
-Driver
-------
-
-The Oracle dialect uses the cx_oracle driver, available at
-http://cx-oracle.sourceforge.net/ . The dialect has several behaviors
-which are specifically tailored towards compatibility with this module.
-
-Connecting
-----------
-
-Connecting with create_engine() uses the standard URL approach of
-``oracle://user:pass@host:port/dbname[?key=value&key=value...]``. If dbname is present, the
-host, port, and dbname tokens are converted to a TNS name using the cx_oracle
-:func:`makedsn()` function. Otherwise, the host token is taken directly as a TNS name.
-
-Additional arguments which may be specified either as query string arguments on the
-URL, or as keyword arguments to :func:`~sqlalchemy.create_engine()` are:
-
-* *allow_twophase* - enable two-phase transactions. Defaults to ``True``.
-
-* *auto_convert_lobs* - defaults to True, see the section on LOB objects.
-
-* *auto_setinputsizes* - the cx_oracle.setinputsizes() call is issued for all bind parameters.
- This is required for LOB datatypes but can be disabled to reduce overhead. Defaults
- to ``True``.
-
-* *mode* - This is given the string value of SYSDBA or SYSOPER, or alternatively an
- integer value. This value is only available as a URL query string argument.
-
-* *threaded* - enable multithreaded access to cx_oracle connections. Defaults
- to ``True``. Note that this is the opposite default of cx_oracle itself.
-
-* *use_ansi* - Use ANSI JOIN constructs (see the section on Oracle 8). Defaults
- to ``True``. If ``False``, Oracle-8 compatible constructs are used for joins.
-
-* *optimize_limits* - defaults to ``False``. see the section on LIMIT/OFFSET.
-
-Auto Increment Behavior
------------------------
-
-SQLAlchemy Table objects which include integer primary keys are usually assumed to have
-"autoincrementing" behavior, meaning they can generate their own primary key values upon
-INSERT. Since Oracle has no "autoincrement" feature, SQLAlchemy relies upon sequences
-to produce these values. With the Oracle dialect, *a sequence must always be explicitly
-specified to enable autoincrement*. This is divergent with the majority of documentation
-examples which assume the usage of an autoincrement-capable database. To specify sequences,
-use the sqlalchemy.schema.Sequence object which is passed to a Column construct::
-
- t = Table('mytable', metadata,
- Column('id', Integer, Sequence('id_seq'), primary_key=True),
- Column(...), ...
- )
-
-This step is also required when using table reflection, i.e. autoload=True::
-
- t = Table('mytable', metadata,
- Column('id', Integer, Sequence('id_seq'), primary_key=True),
- autoload=True
- )
-
-LOB Objects
------------
-
-cx_oracle presents some challenges when fetching LOB objects. A LOB object in a result set
-is presented by cx_oracle as a cx_oracle.LOB object which has a read() method. By default,
-SQLAlchemy converts these LOB objects into Python strings. This is for two reasons. First,
-the LOB object requires an active cursor association, meaning if you were to fetch many rows
-at once such that cx_oracle had to go back to the database and fetch a new batch of rows,
-the LOB objects in the already-fetched rows are now unreadable and will raise an error.
-SQLA "pre-reads" all LOBs so that their data is fetched before further rows are read.
-The size of a "batch of rows" is controlled by the cursor.arraysize value, which SQLAlchemy
-defaults to 50 (cx_oracle normally defaults this to one).
-
-Secondly, the LOB object is not a standard DBAPI return value so SQLAlchemy seeks to
-"normalize" the results to look more like other DBAPIs.
-
-The conversion of LOB objects by this dialect is unique in SQLAlchemy in that it takes place
-for all statement executions, even plain string-based statements for which SQLA has no awareness
-of result typing. This is so that calls like fetchmany() and fetchall() can work in all cases
-without raising cursor errors. The conversion of LOB in all cases, as well as the "prefetch"
-of LOB objects, can be disabled using auto_convert_lobs=False.
-
-LIMIT/OFFSET Support
---------------------
-
-Oracle has no support for the LIMIT or OFFSET keywords. Whereas previous versions of SQLAlchemy
-used the "ROW NUMBER OVER..." construct to simulate LIMIT/OFFSET, SQLAlchemy 0.5 now uses
-a wrapped subquery approach in conjunction with ROWNUM. The exact methodology is taken from
-http://www.oracle.com/technology/oramag/oracle/06-sep/o56asktom.html . Note that the
-"FIRST ROWS()" optimization keyword mentioned is not used by default, as the user community felt
-this was stepping into the bounds of optimization that is better left on the DBA side, but this
-prefix can be added by enabling the optimize_limits=True flag on create_engine().
-
-Two Phase Transaction Support
------------------------------
-
-Two Phase transactions are implemented using XA transactions. Success has been reported of them
-working successfully but this should be regarded as an experimental feature.
-
-Oracle 8 Compatibility
-----------------------
-
-When using Oracle 8, a "use_ansi=False" flag is available which converts all
-JOIN phrases into the WHERE clause, and in the case of LEFT OUTER JOIN
-makes use of Oracle's (+) operator.
-
-Synonym/DBLINK Reflection
--------------------------
-
-When using reflection with Table objects, the dialect can optionally search for tables
-indicated by synonyms that reference DBLINK-ed tables by passing the flag
-oracle_resolve_synonyms=True as a keyword argument to the Table construct. If DBLINK
-is not in use this flag should be left off.
-
-"""
-
-import datetime, random, re
-
-from sqlalchemy import util, sql, schema, log
-from sqlalchemy.engine import default, base
-from sqlalchemy.sql import compiler, visitors, expression
-from sqlalchemy.sql import operators as sql_operators, functions as sql_functions
-from sqlalchemy import types as sqltypes
-
-
-class OracleNumeric(sqltypes.Numeric):
- def get_col_spec(self):
- if self.precision is None:
- return "NUMERIC"
- else:
- return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale}
-
-class OracleInteger(sqltypes.Integer):
- def get_col_spec(self):
- return "INTEGER"
-
-class OracleSmallInteger(sqltypes.Smallinteger):
- def get_col_spec(self):
- return "SMALLINT"
-
-class OracleDate(sqltypes.Date):
- def get_col_spec(self):
- return "DATE"
- def bind_processor(self, dialect):
- return None
-
- def result_processor(self, dialect):
- def process(value):
- if not isinstance(value, datetime.datetime):
- return value
- else:
- return value.date()
- return process
-
-class OracleDateTime(sqltypes.DateTime):
- def get_col_spec(self):
- return "DATE"
-
- def result_processor(self, dialect):
- def process(value):
- if value is None or isinstance(value, datetime.datetime):
- return value
- else:
- # convert cx_oracle datetime object returned pre-python 2.4
- return datetime.datetime(value.year, value.month,
- value.day,value.hour, value.minute, value.second)
- return process
-
-# Note:
-# Oracle DATE == DATETIME
-# Oracle does not allow milliseconds in DATE
-# Oracle does not support TIME columns
-
-# only if cx_oracle contains TIMESTAMP
-class OracleTimestamp(sqltypes.TIMESTAMP):
- def get_col_spec(self):
- return "TIMESTAMP"
-
- def get_dbapi_type(self, dialect):
- return dialect.TIMESTAMP
-
- def result_processor(self, dialect):
- def process(value):
- if value is None or isinstance(value, datetime.datetime):
- return value
- else:
- # convert cx_oracle datetime object returned pre-python 2.4
- return datetime.datetime(value.year, value.month,
- value.day,value.hour, value.minute, value.second)
- return process
-
-class OracleString(sqltypes.String):
- def get_col_spec(self):
- return "VARCHAR(%(length)s)" % {'length' : self.length}
-
-class OracleNVarchar(sqltypes.Unicode, OracleString):
- def get_col_spec(self):
- return "NVARCHAR2(%(length)s)" % {'length' : self.length}
-
-class OracleText(sqltypes.Text):
- def get_dbapi_type(self, dbapi):
- return dbapi.CLOB
-
- def get_col_spec(self):
- return "CLOB"
-
- def result_processor(self, dialect):
- super_process = super(OracleText, self).result_processor(dialect)
- if not dialect.auto_convert_lobs:
- return super_process
- lob = dialect.dbapi.LOB
- def process(value):
- if isinstance(value, lob):
- if super_process:
- return super_process(value.read())
- else:
- return value.read()
- else:
- if super_process:
- return super_process(value)
- else:
- return value
- return process
-
-
-class OracleChar(sqltypes.CHAR):
- def get_col_spec(self):
- return "CHAR(%(length)s)" % {'length' : self.length}
-
-class OracleBinary(sqltypes.Binary):
- def get_dbapi_type(self, dbapi):
- return dbapi.BLOB
-
- def get_col_spec(self):
- return "BLOB"
-
- def bind_processor(self, dialect):
- return None
-
- def result_processor(self, dialect):
- if not dialect.auto_convert_lobs:
- return None
- lob = dialect.dbapi.LOB
- def process(value):
- if isinstance(value, lob):
- return value.read()
- else:
- return value
- return process
-
-class OracleRaw(OracleBinary):
- def get_col_spec(self):
- return "RAW(%(length)s)" % {'length' : self.length}
-
-class OracleBoolean(sqltypes.Boolean):
- def get_col_spec(self):
- return "SMALLINT"
-
- def result_processor(self, dialect):
- def process(value):
- if value is None:
- return None
- return value and True or False
- return process
-
- def bind_processor(self, dialect):
- def process(value):
- if value is True:
- return 1
- elif value is False:
- return 0
- elif value is None:
- return None
- else:
- return value and True or False
- return process
-
-colspecs = {
- sqltypes.Integer : OracleInteger,
- sqltypes.Smallinteger : OracleSmallInteger,
- sqltypes.Numeric : OracleNumeric,
- sqltypes.Float : OracleNumeric,
- sqltypes.DateTime : OracleDateTime,
- sqltypes.Date : OracleDate,
- sqltypes.String : OracleString,
- sqltypes.Binary : OracleBinary,
- sqltypes.Boolean : OracleBoolean,
- sqltypes.Text : OracleText,
- sqltypes.TIMESTAMP : OracleTimestamp,
- sqltypes.CHAR: OracleChar,
-}
-
-ischema_names = {
- 'VARCHAR2' : OracleString,
- 'NVARCHAR2' : OracleNVarchar,
- 'CHAR' : OracleString,
- 'DATE' : OracleDateTime,
- 'DATETIME' : OracleDateTime,
- 'NUMBER' : OracleNumeric,
- 'BLOB' : OracleBinary,
- 'BFILE' : OracleBinary,
- 'CLOB' : OracleText,
- 'TIMESTAMP' : OracleTimestamp,
- 'RAW' : OracleRaw,
- 'FLOAT' : OracleNumeric,
- 'DOUBLE PRECISION' : OracleNumeric,
- 'LONG' : OracleText,
-}
-
-class OracleExecutionContext(default.DefaultExecutionContext):
- def pre_exec(self):
- super(OracleExecutionContext, self).pre_exec()
- if self.dialect.auto_setinputsizes:
- self.set_input_sizes()
- if self.compiled_parameters is not None and len(self.compiled_parameters) == 1:
- for key in self.compiled.binds:
- bindparam = self.compiled.binds[key]
- name = self.compiled.bind_names[bindparam]
- value = self.compiled_parameters[0][name]
- if bindparam.isoutparam:
- dbtype = bindparam.type.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
- if not hasattr(self, 'out_parameters'):
- self.out_parameters = {}
- self.out_parameters[name] = self.cursor.var(dbtype)
- self.parameters[0][name] = self.out_parameters[name]
-
- def create_cursor(self):
- c = self._connection.connection.cursor()
- if self.dialect.arraysize:
- c.cursor.arraysize = self.dialect.arraysize
- return c
-
- def get_result_proxy(self):
- if hasattr(self, 'out_parameters'):
- if self.compiled_parameters is not None and len(self.compiled_parameters) == 1:
- for bind, name in self.compiled.bind_names.iteritems():
- if name in self.out_parameters:
- type = bind.type
- result_processor = type.dialect_impl(self.dialect).result_processor(self.dialect)
- if result_processor is not None:
- self.out_parameters[name] = result_processor(self.out_parameters[name].getvalue())
- else:
- self.out_parameters[name] = self.out_parameters[name].getvalue()
- else:
- for k in self.out_parameters:
- self.out_parameters[k] = self.out_parameters[k].getvalue()
-
- if self.cursor.description is not None:
- for column in self.cursor.description:
- type_code = column[1]
- if type_code in self.dialect.ORACLE_BINARY_TYPES:
- return base.BufferedColumnResultProxy(self)
-
- return base.ResultProxy(self)
-
-class OracleDialect(default.DefaultDialect):
- name = 'oracle'
- supports_alter = True
- supports_unicode_statements = False
- max_identifier_length = 30
- supports_sane_rowcount = True
- supports_sane_multi_rowcount = False
- preexecute_pk_sequences = True
- supports_pk_autoincrement = False
- default_paramstyle = 'named'
-
- def __init__(self, use_ansi=True, auto_setinputsizes=True, auto_convert_lobs=True, threaded=True, allow_twophase=True, optimize_limits=False, arraysize=50, **kwargs):
- default.DefaultDialect.__init__(self, **kwargs)
- self.use_ansi = use_ansi
- self.threaded = threaded
- self.arraysize = arraysize
- self.allow_twophase = allow_twophase
- self.optimize_limits = optimize_limits
- self.supports_timestamp = self.dbapi is None or hasattr(self.dbapi, 'TIMESTAMP' )
- self.auto_setinputsizes = auto_setinputsizes
- self.auto_convert_lobs = auto_convert_lobs
- if self.dbapi is None or not self.auto_convert_lobs or not 'CLOB' in self.dbapi.__dict__:
- self.dbapi_type_map = {}
- self.ORACLE_BINARY_TYPES = []
- else:
- # only use this for LOB objects. using it for strings, dates
- # etc. leads to a little too much magic, reflection doesn't know if it should
- # expect encoded strings or unicodes, etc.
- self.dbapi_type_map = {
- self.dbapi.CLOB: OracleText(),
- self.dbapi.BLOB: OracleBinary(),
- self.dbapi.BINARY: OracleRaw(),
- }
- self.ORACLE_BINARY_TYPES = [getattr(self.dbapi, k) for k in ["BFILE", "CLOB", "NCLOB", "BLOB"] if hasattr(self.dbapi, k)]
-
- def dbapi(cls):
- import cx_Oracle
- return cx_Oracle
- dbapi = classmethod(dbapi)
-
- def create_connect_args(self, url):
- dialect_opts = dict(url.query)
- for opt in ('use_ansi', 'auto_setinputsizes', 'auto_convert_lobs',
- 'threaded', 'allow_twophase'):
- if opt in dialect_opts:
- util.coerce_kw_type(dialect_opts, opt, bool)
- setattr(self, opt, dialect_opts[opt])
-
- if url.database:
- # if we have a database, then we have a remote host
- port = url.port
- if port:
- port = int(port)
- else:
- port = 1521
- dsn = self.dbapi.makedsn(url.host, port, url.database)
- else:
- # we have a local tnsname
- dsn = url.host
-
- opts = dict(
- user=url.username,
- password=url.password,
- dsn=dsn,
- threaded=self.threaded,
- twophase=self.allow_twophase,
- )
- if 'mode' in url.query:
- opts['mode'] = url.query['mode']
- if isinstance(opts['mode'], basestring):
- mode = opts['mode'].upper()
- if mode == 'SYSDBA':
- opts['mode'] = self.dbapi.SYSDBA
- elif mode == 'SYSOPER':
- opts['mode'] = self.dbapi.SYSOPER
- else:
- util.coerce_kw_type(opts, 'mode', int)
- # Can't set 'handle' or 'pool' via URL query args, use connect_args
-
- return ([], opts)
-
- def is_disconnect(self, e):
- if isinstance(e, self.dbapi.InterfaceError):
- return "not connected" in str(e)
- else:
- return "ORA-03114" in str(e) or "ORA-03113" in str(e)
-
- def type_descriptor(self, typeobj):
- return sqltypes.adapt_type(typeobj, colspecs)
-
- def create_xid(self):
- """create a two-phase transaction ID.
-
- this id will be passed to do_begin_twophase(), do_rollback_twophase(),
- do_commit_twophase(). its format is unspecified."""
-
- id = random.randint(0, 2 ** 128)
- return (0x1234, "%032x" % id, "%032x" % 9)
-
- def do_release_savepoint(self, connection, name):
- # Oracle does not support RELEASE SAVEPOINT
- pass
-
- def do_begin_twophase(self, connection, xid):
- connection.connection.begin(*xid)
-
- def do_prepare_twophase(self, connection, xid):
- connection.connection.prepare()
-
- def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False):
- self.do_rollback(connection.connection)
-
- def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False):
- self.do_commit(connection.connection)
-
- def do_recover_twophase(self, connection):
- pass
-
- def has_table(self, connection, table_name, schema=None):
- if not schema:
- schema = self.get_default_schema_name(connection)
- cursor = connection.execute("""select table_name from all_tables where table_name=:name and owner=:schema_name""", {'name':self._denormalize_name(table_name), 'schema_name':self._denormalize_name(schema)})
- return cursor.fetchone() is not None
-
- def has_sequence(self, connection, sequence_name, schema=None):
- if not schema:
- schema = self.get_default_schema_name(connection)
- cursor = connection.execute("""select sequence_name from all_sequences where sequence_name=:name and sequence_owner=:schema_name""", {'name':self._denormalize_name(sequence_name), 'schema_name':self._denormalize_name(schema)})
- return cursor.fetchone() is not None
-
- def _normalize_name(self, name):
- if name is None:
- return None
- elif name.upper() == name and not self.identifier_preparer._requires_quotes(name.lower().decode(self.encoding)):
- return name.lower().decode(self.encoding)
- else:
- return name.decode(self.encoding)
-
- def _denormalize_name(self, name):
- if name is None:
- return None
- elif name.lower() == name and not self.identifier_preparer._requires_quotes(name.lower()):
- return name.upper().encode(self.encoding)
- else:
- return name.encode(self.encoding)
-
- def get_default_schema_name(self, connection):
- return self._normalize_name(connection.execute('SELECT USER FROM DUAL').scalar())
- get_default_schema_name = base.connection_memoize(
- ('dialect', 'default_schema_name'))(get_default_schema_name)
-
- def table_names(self, connection, schema):
- # note that table_names() isnt loading DBLINKed or synonym'ed tables
- if schema is None:
- s = "select table_name from all_tables where nvl(tablespace_name, 'no tablespace') NOT IN ('SYSTEM', 'SYSAUX')"
- cursor = connection.execute(s)
- else:
- s = "select table_name from all_tables where nvl(tablespace_name, 'no tablespace') NOT IN ('SYSTEM','SYSAUX') AND OWNER = :owner"
- cursor = connection.execute(s, {'owner': self._denormalize_name(schema)})
- return [self._normalize_name(row[0]) for row in cursor]
-
- def _resolve_synonym(self, connection, desired_owner=None, desired_synonym=None, desired_table=None):
- """search for a local synonym matching the given desired owner/name.
-
- if desired_owner is None, attempts to locate a distinct owner.
-
- returns the actual name, owner, dblink name, and synonym name if found.
- """
-
- sql = """select OWNER, TABLE_OWNER, TABLE_NAME, DB_LINK, SYNONYM_NAME
- from ALL_SYNONYMS WHERE """
-
- clauses = []
- params = {}
- if desired_synonym:
- clauses.append("SYNONYM_NAME=:synonym_name")
- params['synonym_name'] = desired_synonym
- if desired_owner:
- clauses.append("TABLE_OWNER=:desired_owner")
- params['desired_owner'] = desired_owner
- if desired_table:
- clauses.append("TABLE_NAME=:tname")
- params['tname'] = desired_table
-
- sql += " AND ".join(clauses)
-
- result = connection.execute(sql, **params)
- if desired_owner:
- row = result.fetchone()
- if row:
- return row['TABLE_NAME'], row['TABLE_OWNER'], row['DB_LINK'], row['SYNONYM_NAME']
- else:
- return None, None, None, None
- else:
- rows = result.fetchall()
- if len(rows) > 1:
- raise AssertionError("There are multiple tables visible to the schema, you must specify owner")
- elif len(rows) == 1:
- row = rows[0]
- return row['TABLE_NAME'], row['TABLE_OWNER'], row['DB_LINK'], row['SYNONYM_NAME']
- else:
- return None, None, None, None
-
- def reflecttable(self, connection, table, include_columns):
- preparer = self.identifier_preparer
-
- resolve_synonyms = table.kwargs.get('oracle_resolve_synonyms', False)
-
- if resolve_synonyms:
- actual_name, owner, dblink, synonym = self._resolve_synonym(connection, desired_owner=self._denormalize_name(table.schema), desired_synonym=self._denormalize_name(table.name))
- else:
- actual_name, owner, dblink, synonym = None, None, None, None
-
- if not actual_name:
- actual_name = self._denormalize_name(table.name)
- if not dblink:
- dblink = ''
- if not owner:
- owner = self._denormalize_name(table.schema or self.get_default_schema_name(connection))
-
- c = connection.execute ("select COLUMN_NAME, DATA_TYPE, DATA_LENGTH, DATA_PRECISION, DATA_SCALE, NULLABLE, DATA_DEFAULT from ALL_TAB_COLUMNS%(dblink)s where TABLE_NAME = :table_name and OWNER = :owner" % {'dblink':dblink}, {'table_name':actual_name, 'owner':owner})
-
- while True:
- row = c.fetchone()
- if row is None:
- break
-
- (colname, coltype, length, precision, scale, nullable, default) = (self._normalize_name(row[0]), row[1], row[2], row[3], row[4], row[5]=='Y', row[6])
-
- if include_columns and colname not in include_columns:
- continue
-
- # INTEGER if the scale is 0 and precision is null
- # NUMBER if the scale and precision are both null
- # NUMBER(9,2) if the precision is 9 and the scale is 2
- # NUMBER(3) if the precision is 3 and scale is 0
- #length is ignored except for CHAR and VARCHAR2
- if coltype == 'NUMBER' :
- if precision is None and scale is None:
- coltype = OracleNumeric
- elif precision is None and scale == 0 :
- coltype = OracleInteger
- else :
- coltype = OracleNumeric(precision, scale)
- elif coltype=='CHAR' or coltype=='VARCHAR2':
- coltype = ischema_names.get(coltype, OracleString)(length)
- else:
- coltype = re.sub(r'\(\d+\)', '', coltype)
- try:
- coltype = ischema_names[coltype]
- except KeyError:
- util.warn("Did not recognize type '%s' of column '%s'" %
- (coltype, colname))
- coltype = sqltypes.NULLTYPE
-
- colargs = []
- if default is not None:
- colargs.append(schema.DefaultClause(sql.text(default)))
-
- table.append_column(schema.Column(colname, coltype, nullable=nullable, *colargs))
-
- if not table.columns:
- raise AssertionError("Couldn't find any column information for table %s" % actual_name)
-
- c = connection.execute("""SELECT
- ac.constraint_name,
- ac.constraint_type,
- loc.column_name AS local_column,
- rem.table_name AS remote_table,
- rem.column_name AS remote_column,
- rem.owner AS remote_owner
- FROM all_constraints%(dblink)s ac,
- all_cons_columns%(dblink)s loc,
- all_cons_columns%(dblink)s rem
- WHERE ac.table_name = :table_name
- AND ac.constraint_type IN ('R','P')
- AND ac.owner = :owner
- AND ac.owner = loc.owner
- AND ac.constraint_name = loc.constraint_name
- AND ac.r_owner = rem.owner(+)
- AND ac.r_constraint_name = rem.constraint_name(+)
- -- order multiple primary keys correctly
- ORDER BY ac.constraint_name, loc.position, rem.position"""
- % {'dblink':dblink}, {'table_name' : actual_name, 'owner' : owner})
-
- fks = {}
- while True:
- row = c.fetchone()
- if row is None:
- break
- #print "ROW:" , row
- (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = row[0:2] + tuple([self._normalize_name(x) for x in row[2:]])
- if cons_type == 'P':
- table.primary_key.add(table.c[local_column])
- elif cons_type == 'R':
- try:
- fk = fks[cons_name]
- except KeyError:
- fk = ([], [])
- fks[cons_name] = fk
- if remote_table is None:
- # ticket 363
- util.warn(
- ("Got 'None' querying 'table_name' from "
- "all_cons_columns%(dblink)s - does the user have "
- "proper rights to the table?") % {'dblink':dblink})
- continue
-
- if resolve_synonyms:
- ref_remote_name, ref_remote_owner, ref_dblink, ref_synonym = self._resolve_synonym(connection, desired_owner=self._denormalize_name(remote_owner), desired_table=self._denormalize_name(remote_table))
- if ref_synonym:
- remote_table = self._normalize_name(ref_synonym)
- remote_owner = self._normalize_name(ref_remote_owner)
-
- if not table.schema and self._denormalize_name(remote_owner) == owner:
- refspec = ".".join([remote_table, remote_column])
- t = schema.Table(remote_table, table.metadata, autoload=True, autoload_with=connection, oracle_resolve_synonyms=resolve_synonyms, useexisting=True)
- else:
- refspec = ".".join([x for x in [remote_owner, remote_table, remote_column] if x])
- t = schema.Table(remote_table, table.metadata, autoload=True, autoload_with=connection, schema=remote_owner, oracle_resolve_synonyms=resolve_synonyms, useexisting=True)
-
- if local_column not in fk[0]:
- fk[0].append(local_column)
- if refspec not in fk[1]:
- fk[1].append(refspec)
-
- for name, value in fks.iteritems():
- table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], name=name, link_to_name=True))
-
-
-class _OuterJoinColumn(sql.ClauseElement):
- __visit_name__ = 'outer_join_column'
-
- def __init__(self, column):
- self.column = column
-
-class OracleCompiler(compiler.DefaultCompiler):
- """Oracle compiler modifies the lexical structure of Select
- statements to work under non-ANSI configured Oracle databases, if
- the use_ansi flag is False.
- """
-
- operators = compiler.DefaultCompiler.operators.copy()
- operators.update(
- {
- sql_operators.mod : lambda x, y:"mod(%s, %s)" % (x, y),
- sql_operators.match_op: lambda x, y: "CONTAINS (%s, %s)" % (x, y)
- }
- )
-
- functions = compiler.DefaultCompiler.functions.copy()
- functions.update (
- {
- sql_functions.now : 'CURRENT_TIMESTAMP'
- }
- )
-
- def __init__(self, *args, **kwargs):
- super(OracleCompiler, self).__init__(*args, **kwargs)
- self.__wheres = {}
-
- def default_from(self):
- """Called when a ``SELECT`` statement has no froms, and no ``FROM`` clause is to be appended.
-
- The Oracle compiler tacks a "FROM DUAL" to the statement.
- """
-
- return " FROM DUAL"
-
- def apply_function_parens(self, func):
- return len(func.clauses) > 0
-
- def visit_join(self, join, **kwargs):
- if self.dialect.use_ansi:
- return compiler.DefaultCompiler.visit_join(self, join, **kwargs)
- else:
- return self.process(join.left, asfrom=True) + ", " + self.process(join.right, asfrom=True)
-
- def _get_nonansi_join_whereclause(self, froms):
- clauses = []
-
- def visit_join(join):
- if join.isouter:
- def visit_binary(binary):
- if binary.operator == sql_operators.eq:
- if binary.left.table is join.right:
- binary.left = _OuterJoinColumn(binary.left)
- elif binary.right.table is join.right:
- binary.right = _OuterJoinColumn(binary.right)
- clauses.append(visitors.cloned_traverse(join.onclause, {}, {'binary':visit_binary}))
- else:
- clauses.append(join.onclause)
-
- for f in froms:
- visitors.traverse(f, {}, {'join':visit_join})
- return sql.and_(*clauses)
-
- def visit_outer_join_column(self, vc):
- return self.process(vc.column) + "(+)"
-
- def visit_sequence(self, seq):
- return self.dialect.identifier_preparer.format_sequence(seq) + ".nextval"
-
- def visit_alias(self, alias, asfrom=False, **kwargs):
- """Oracle doesn't like ``FROM table AS alias``. Is the AS standard SQL??"""
-
- if asfrom:
- alias_name = isinstance(alias.name, expression._generated_label) and \
- self._truncated_identifier("alias", alias.name) or alias.name
-
- return self.process(alias.original, asfrom=True, **kwargs) + " " +\
- self.preparer.format_alias(alias, alias_name)
- else:
- return self.process(alias.original, **kwargs)
-
- def _TODO_visit_compound_select(self, select):
- """Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle."""
- pass
-
- def visit_select(self, select, **kwargs):
- """Look for ``LIMIT`` and OFFSET in a select statement, and if
- so tries to wrap it in a subquery with ``rownum`` criterion.
- """
-
- if not getattr(select, '_oracle_visit', None):
- if not self.dialect.use_ansi:
- if self.stack and 'from' in self.stack[-1]:
- existingfroms = self.stack[-1]['from']
- else:
- existingfroms = None
-
- froms = select._get_display_froms(existingfroms)
- whereclause = self._get_nonansi_join_whereclause(froms)
- if whereclause:
- select = select.where(whereclause)
- select._oracle_visit = True
-
- if select._limit is not None or select._offset is not None:
- # See http://www.oracle.com/technology/oramag/oracle/06-sep/o56asktom.html
- #
- # Generalized form of an Oracle pagination query:
- # select ... from (
- # select /*+ FIRST_ROWS(N) */ ...., rownum as ora_rn from (
- # select distinct ... where ... order by ...
- # ) where ROWNUM <= :limit+:offset
- # ) where ora_rn > :offset
- # Outer select and "ROWNUM as ora_rn" can be dropped if limit=0
-
- # TODO: use annotations instead of clone + attr set ?
- select = select._generate()
- select._oracle_visit = True
-
- # Wrap the middle select and add the hint
- limitselect = sql.select([c for c in select.c])
- if select._limit and self.dialect.optimize_limits:
- limitselect = limitselect.prefix_with("/*+ FIRST_ROWS(%d) */" % select._limit)
-
- limitselect._oracle_visit = True
- limitselect._is_wrapper = True
-
- # If needed, add the limiting clause
- if select._limit is not None:
- max_row = select._limit
- if select._offset is not None:
- max_row += select._offset
- limitselect.append_whereclause(
- sql.literal_column("ROWNUM")<=max_row)
-
- # If needed, add the ora_rn, and wrap again with offset.
- if select._offset is None:
- select = limitselect
- else:
- limitselect = limitselect.column(
- sql.literal_column("ROWNUM").label("ora_rn"))
- limitselect._oracle_visit = True
- limitselect._is_wrapper = True
-
- offsetselect = sql.select(
- [c for c in limitselect.c if c.key!='ora_rn'])
- offsetselect._oracle_visit = True
- offsetselect._is_wrapper = True
-
- offsetselect.append_whereclause(
- sql.literal_column("ora_rn")>select._offset)
-
- select = offsetselect
-
- kwargs['iswrapper'] = getattr(select, '_is_wrapper', False)
- return compiler.DefaultCompiler.visit_select(self, select, **kwargs)
-
- def limit_clause(self, select):
- return ""
-
- def for_update_clause(self, select):
- if select.for_update == "nowait":
- return " FOR UPDATE NOWAIT"
- else:
- return super(OracleCompiler, self).for_update_clause(select)
-
-
-class OracleSchemaGenerator(compiler.SchemaGenerator):
- def get_column_specification(self, column, **kwargs):
- colspec = self.preparer.format_column(column)
- colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
- default = self.get_column_default_string(column)
- if default is not None:
- colspec += " DEFAULT " + default
-
- if not column.nullable:
- colspec += " NOT NULL"
- return colspec
-
- def visit_sequence(self, sequence):
- if not self.checkfirst or not self.dialect.has_sequence(self.connection, sequence.name, sequence.schema):
- self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence))
- self.execute()
-
-class OracleSchemaDropper(compiler.SchemaDropper):
- def visit_sequence(self, sequence):
- if not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name, sequence.schema):
- self.append("DROP SEQUENCE %s" % self.preparer.format_sequence(sequence))
- self.execute()
-
-class OracleDefaultRunner(base.DefaultRunner):
- def visit_sequence(self, seq):
- return self.execute_string("SELECT " + self.dialect.identifier_preparer.format_sequence(seq) + ".nextval FROM DUAL", {})
-
-class OracleIdentifierPreparer(compiler.IdentifierPreparer):
- def format_savepoint(self, savepoint):
- name = re.sub(r'^_+', '', savepoint.ident)
- return super(OracleIdentifierPreparer, self).format_savepoint(savepoint, name)
-
-
-dialect = OracleDialect
-dialect.statement_compiler = OracleCompiler
-dialect.schemagenerator = OracleSchemaGenerator
-dialect.schemadropper = OracleSchemaDropper
-dialect.preparer = OracleIdentifierPreparer
-dialect.defaultrunner = OracleDefaultRunner
-dialect.execution_ctx_cls = OracleExecutionContext
+++ /dev/null
-# sqlite.py
-# Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com
-#
-# This module is part of SQLAlchemy and is released under
-# the MIT License: http://www.opensource.org/licenses/mit-license.php
-
-"""Support for the SQLite database.
-
-Driver
-------
-
-When using Python 2.5 and above, the built in ``sqlite3`` driver is
-already installed and no additional installation is needed. Otherwise,
-the ``pysqlite2`` driver needs to be present. This is the same driver as
-``sqlite3``, just with a different name.
-
-The ``pysqlite2`` driver will be loaded first, and if not found, ``sqlite3``
-is loaded. This allows an explicitly installed pysqlite driver to take
-precedence over the built in one. As with all dialects, a specific
-DBAPI module may be provided to :func:`~sqlalchemy.create_engine()` to control
-this explicitly::
-
- from sqlite3 import dbapi2 as sqlite
- e = create_engine('sqlite:///file.db', module=sqlite)
-
-Full documentation on pysqlite is available at:
-`<http://www.initd.org/pub/software/pysqlite/doc/usage-guide.html>`_
-
-Connect Strings
----------------
-
-The file specification for the SQLite database is taken as the "database" portion of
-the URL. Note that the format of a url is::
-
- driver://user:pass@host/database
-
-This means that the actual filename to be used starts with the characters to the
-**right** of the third slash. So connecting to a relative filepath looks like::
-
- # relative path
- e = create_engine('sqlite:///path/to/database.db')
-
-An absolute path, which is denoted by starting with a slash, means you need **four**
-slashes::
-
- # absolute path
- e = create_engine('sqlite:////path/to/database.db')
-
-To use a Windows path, regular drive specifications and backslashes can be used.
-Double backslashes are probably needed::
-
- # absolute path on Windows
- e = create_engine('sqlite:///C:\\\\path\\\\to\\\\database.db')
-
-The sqlite ``:memory:`` identifier is the default if no filepath is present. Specify
-``sqlite://`` and nothing else::
-
- # in-memory database
- e = create_engine('sqlite://')
-
-Threading Behavior
-------------------
-
-Pysqlite connections do not support being moved between threads, unless
-the ``check_same_thread`` Pysqlite flag is set to ``False``. In addition,
-when using an in-memory SQLite database, the full database exists only within
-the scope of a single connection. It is reported that an in-memory
-database does not support being shared between threads regardless of the
-``check_same_thread`` flag - which means that a multithreaded
-application **cannot** share data from a ``:memory:`` database across threads
-unless access to the connection is limited to a single worker thread which communicates
-through a queueing mechanism to concurrent threads.
-
-To provide a default which accomodates SQLite's default threading capabilities
-somewhat reasonably, the SQLite dialect will specify that the :class:`~sqlalchemy.pool.SingletonThreadPool`
-be used by default. This pool maintains a single SQLite connection per thread
-that is held open up to a count of five concurrent threads. When more than five threads
-are used, a cleanup mechanism will dispose of excess unused connections.
-
-Two optional pool implementations that may be appropriate for particular SQLite usage scenarios:
-
- * the :class:`sqlalchemy.pool.StaticPool` might be appropriate for a multithreaded
- application using an in-memory database, assuming the threading issues inherent in
- pysqlite are somehow accomodated for. This pool holds persistently onto a single connection
- which is never closed, and is returned for all requests.
-
- * the :class:`sqlalchemy.pool.NullPool` might be appropriate for an application that
- makes use of a file-based sqlite database. This pool disables any actual "pooling"
- behavior, and simply opens and closes real connections corresonding to the :func:`connect()`
- and :func:`close()` methods. SQLite can "connect" to a particular file with very high
- efficiency, so this option may actually perform better without the extra overhead
- of :class:`SingletonThreadPool`. NullPool will of course render a ``:memory:`` connection
- useless since the database would be lost as soon as the connection is "returned" to the pool.
-
-Date and Time Types
--------------------
-
-SQLite does not have built-in DATE, TIME, or DATETIME types, and pysqlite does not provide
-out of the box functionality for translating values between Python `datetime` objects
-and a SQLite-supported format. SQLAlchemy's own :class:`~sqlalchemy.types.DateTime`
-and related types provide date formatting and parsing functionality when SQlite is used.
-The implementation classes are :class:`SLDateTime`, :class:`SLDate` and :class:`SLTime`.
-These types represent dates and times as ISO formatted strings, which also nicely
-support ordering. There's no reliance on typical "libc" internals for these functions
-so historical dates are fully supported.
-
-Unicode
--------
-
-In contrast to SQLAlchemy's active handling of date and time types for pysqlite, pysqlite's
-default behavior regarding Unicode is that all strings are returned as Python unicode objects
-in all cases. So even if the :class:`~sqlalchemy.types.Unicode` type is
-*not* used, you will still always receive unicode data back from a result set. It is
-**strongly** recommended that you do use the :class:`~sqlalchemy.types.Unicode` type
-to represent strings, since it will raise a warning if a non-unicode Python string is
-passed from the user application. Mixing the usage of non-unicode objects with returned unicode objects can
-quickly create confusion, particularly when using the ORM as internal data is not
-always represented by an actual database result string.
-
-"""
-
-
-import datetime, re, time
-
-from sqlalchemy import sql, schema, exc, pool, DefaultClause
-from sqlalchemy.engine import default
-import sqlalchemy.types as sqltypes
-import sqlalchemy.util as util
-from sqlalchemy.sql import compiler, functions as sql_functions
-from types import NoneType
-
-class SLNumeric(sqltypes.Numeric):
- def bind_processor(self, dialect):
- type_ = self.asdecimal and str or float
- def process(value):
- if value is not None:
- return type_(value)
- else:
- return value
- return process
-
- def get_col_spec(self):
- if self.precision is None:
- return "NUMERIC"
- else:
- return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale}
-
-class SLFloat(sqltypes.Float):
- def bind_processor(self, dialect):
- type_ = self.asdecimal and str or float
- def process(value):
- if value is not None:
- return type_(value)
- else:
- return value
- return process
-
- def get_col_spec(self):
- return "FLOAT"
-
-class SLInteger(sqltypes.Integer):
- def get_col_spec(self):
- return "INTEGER"
-
-class SLSmallInteger(sqltypes.Smallinteger):
- def get_col_spec(self):
- return "SMALLINT"
-
-class DateTimeMixin(object):
- def _bind_processor(self, format, elements):
- def process(value):
- if not isinstance(value, (NoneType, datetime.date, datetime.datetime, datetime.time)):
- raise TypeError("SQLite Date, Time, and DateTime types only accept Python datetime objects as input.")
- elif value is not None:
- return format % tuple([getattr(value, attr, 0) for attr in elements])
- else:
- return None
- return process
-
- def _result_processor(self, fn, regexp):
- def process(value):
- if value is not None:
- return fn(*[int(x or 0) for x in regexp.match(value).groups()])
- else:
- return None
- return process
-
-class SLDateTime(DateTimeMixin, sqltypes.DateTime):
- __legacy_microseconds__ = False
-
- def get_col_spec(self):
- return "TIMESTAMP"
-
- def bind_processor(self, dialect):
- if self.__legacy_microseconds__:
- return self._bind_processor(
- "%4.4d-%2.2d-%2.2d %2.2d:%2.2d:%2.2d.%s",
- ("year", "month", "day", "hour", "minute", "second", "microsecond")
- )
- else:
- return self._bind_processor(
- "%4.4d-%2.2d-%2.2d %2.2d:%2.2d:%2.2d.%06d",
- ("year", "month", "day", "hour", "minute", "second", "microsecond")
- )
-
- _reg = re.compile(r"(\d+)-(\d+)-(\d+)(?: (\d+):(\d+):(\d+)(?:\.(\d+))?)?")
- def result_processor(self, dialect):
- return self._result_processor(datetime.datetime, self._reg)
-
-class SLDate(DateTimeMixin, sqltypes.Date):
- def get_col_spec(self):
- return "DATE"
-
- def bind_processor(self, dialect):
- return self._bind_processor(
- "%4.4d-%2.2d-%2.2d",
- ("year", "month", "day")
- )
-
- _reg = re.compile(r"(\d+)-(\d+)-(\d+)")
- def result_processor(self, dialect):
- return self._result_processor(datetime.date, self._reg)
-
-class SLTime(DateTimeMixin, sqltypes.Time):
- __legacy_microseconds__ = False
-
- def get_col_spec(self):
- return "TIME"
-
- def bind_processor(self, dialect):
- if self.__legacy_microseconds__:
- return self._bind_processor(
- "%2.2d:%2.2d:%2.2d.%s",
- ("hour", "minute", "second", "microsecond")
- )
- else:
- return self._bind_processor(
- "%2.2d:%2.2d:%2.2d.%06d",
- ("hour", "minute", "second", "microsecond")
- )
-
- _reg = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d+))?")
- def result_processor(self, dialect):
- return self._result_processor(datetime.time, self._reg)
-
-class SLUnicodeMixin(object):
- def bind_processor(self, dialect):
- if self.convert_unicode or dialect.convert_unicode:
- if self.assert_unicode is None:
- assert_unicode = dialect.assert_unicode
- else:
- assert_unicode = self.assert_unicode
-
- if not assert_unicode:
- return None
-
- def process(value):
- if not isinstance(value, (unicode, NoneType)):
- if assert_unicode == 'warn':
- util.warn("Unicode type received non-unicode bind "
- "param value %r" % value)
- return value
- else:
- raise exc.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value)
- else:
- return value
- return process
- else:
- return None
-
- def result_processor(self, dialect):
- return None
-
-class SLText(SLUnicodeMixin, sqltypes.Text):
- def get_col_spec(self):
- return "TEXT"
-
-class SLString(SLUnicodeMixin, sqltypes.String):
- def get_col_spec(self):
- return "VARCHAR" + (self.length and "(%d)" % self.length or "")
-
-class SLChar(SLUnicodeMixin, sqltypes.CHAR):
- def get_col_spec(self):
- return "CHAR" + (self.length and "(%d)" % self.length or "")
-
-class SLBinary(sqltypes.Binary):
- def get_col_spec(self):
- return "BLOB"
-
-class SLBoolean(sqltypes.Boolean):
- def get_col_spec(self):
- return "BOOLEAN"
-
- def bind_processor(self, dialect):
- def process(value):
- if value is None:
- return None
- return value and 1 or 0
- return process
-
- def result_processor(self, dialect):
- def process(value):
- if value is None:
- return None
- return value == 1
- return process
-
-colspecs = {
- sqltypes.Binary: SLBinary,
- sqltypes.Boolean: SLBoolean,
- sqltypes.CHAR: SLChar,
- sqltypes.Date: SLDate,
- sqltypes.DateTime: SLDateTime,
- sqltypes.Float: SLFloat,
- sqltypes.Integer: SLInteger,
- sqltypes.NCHAR: SLChar,
- sqltypes.Numeric: SLNumeric,
- sqltypes.Smallinteger: SLSmallInteger,
- sqltypes.String: SLString,
- sqltypes.Text: SLText,
- sqltypes.Time: SLTime,
-}
-
-ischema_names = {
- 'BLOB': SLBinary,
- 'BOOL': SLBoolean,
- 'BOOLEAN': SLBoolean,
- 'CHAR': SLChar,
- 'DATE': SLDate,
- 'DATETIME': SLDateTime,
- 'DECIMAL': SLNumeric,
- 'FLOAT': SLFloat,
- 'INT': SLInteger,
- 'INTEGER': SLInteger,
- 'NUMERIC': SLNumeric,
- 'REAL': SLNumeric,
- 'SMALLINT': SLSmallInteger,
- 'TEXT': SLText,
- 'TIME': SLTime,
- 'TIMESTAMP': SLDateTime,
- 'VARCHAR': SLString,
-}
-
-class SQLiteExecutionContext(default.DefaultExecutionContext):
- def post_exec(self):
- if self.compiled.isinsert and not self.executemany:
- if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None:
- self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:]
-
-class SQLiteDialect(default.DefaultDialect):
- name = 'sqlite'
- supports_alter = False
- supports_unicode_statements = True
- default_paramstyle = 'qmark'
- supports_default_values = True
- supports_empty_insert = False
-
- def __init__(self, **kwargs):
- default.DefaultDialect.__init__(self, **kwargs)
- def vers(num):
- return tuple([int(x) for x in num.split('.')])
- if self.dbapi is not None:
- sqlite_ver = self.dbapi.version_info
- if sqlite_ver < (2, 1, '3'):
- util.warn(
- ("The installed version of pysqlite2 (%s) is out-dated "
- "and will cause errors in some cases. Version 2.1.3 "
- "or greater is recommended.") %
- '.'.join([str(subver) for subver in sqlite_ver]))
- if self.dbapi.sqlite_version_info < (3, 3, 8):
- self.supports_default_values = False
- self.supports_cast = (self.dbapi is None or vers(self.dbapi.sqlite_version) >= vers("3.2.3"))
-
- def dbapi(cls):
- try:
- from pysqlite2 import dbapi2 as sqlite
- except ImportError, e:
- try:
- from sqlite3 import dbapi2 as sqlite #try the 2.5+ stdlib name.
- except ImportError:
- raise e
- return sqlite
- dbapi = classmethod(dbapi)
-
- def server_version_info(self, connection):
- return self.dbapi.sqlite_version_info
-
- def create_connect_args(self, url):
- if url.username or url.password or url.host or url.port:
- raise exc.ArgumentError(
- "Invalid SQLite URL: %s\n"
- "Valid SQLite URL forms are:\n"
- " sqlite:///:memory: (or, sqlite://)\n"
- " sqlite:///relative/path/to/file.db\n"
- " sqlite:////absolute/path/to/file.db" % (url,))
- filename = url.database or ':memory:'
-
- opts = url.query.copy()
- util.coerce_kw_type(opts, 'timeout', float)
- util.coerce_kw_type(opts, 'isolation_level', str)
- util.coerce_kw_type(opts, 'detect_types', int)
- util.coerce_kw_type(opts, 'check_same_thread', bool)
- util.coerce_kw_type(opts, 'cached_statements', int)
-
- return ([filename], opts)
-
- def type_descriptor(self, typeobj):
- return sqltypes.adapt_type(typeobj, colspecs)
-
- def is_disconnect(self, e):
- return isinstance(e, self.dbapi.ProgrammingError) and "Cannot operate on a closed database." in str(e)
-
- def table_names(self, connection, schema):
- if schema is not None:
- qschema = self.identifier_preparer.quote_identifier(schema)
- master = '%s.sqlite_master' % qschema
- s = ("SELECT name FROM %s "
- "WHERE type='table' ORDER BY name") % (master,)
- rs = connection.execute(s)
- else:
- try:
- s = ("SELECT name FROM "
- " (SELECT * FROM sqlite_master UNION ALL "
- " SELECT * FROM sqlite_temp_master) "
- "WHERE type='table' ORDER BY name")
- rs = connection.execute(s)
- except exc.DBAPIError:
- raise
- s = ("SELECT name FROM sqlite_master "
- "WHERE type='table' ORDER BY name")
- rs = connection.execute(s)
-
- return [row[0] for row in rs]
-
- def has_table(self, connection, table_name, schema=None):
- quote = self.identifier_preparer.quote_identifier
- if schema is not None:
- pragma = "PRAGMA %s." % quote(schema)
- else:
- pragma = "PRAGMA "
- qtable = quote(table_name)
- cursor = _pragma_cursor(connection.execute("%stable_info(%s)" % (pragma, qtable)))
-
- row = cursor.fetchone()
-
- # consume remaining rows, to work around
- # http://www.sqlite.org/cvstrac/tktview?tn=1884
- while cursor.fetchone() is not None:
- pass
-
- return (row is not None)
-
- def reflecttable(self, connection, table, include_columns):
- preparer = self.identifier_preparer
- if table.schema is None:
- pragma = "PRAGMA "
- else:
- pragma = "PRAGMA %s." % preparer.quote_identifier(table.schema)
- qtable = preparer.format_table(table, False)
-
- c = _pragma_cursor(connection.execute("%stable_info(%s)" % (pragma, qtable)))
- found_table = False
- while True:
- row = c.fetchone()
- if row is None:
- break
-
- found_table = True
- (name, type_, nullable, default, has_default, primary_key) = (row[1], row[2].upper(), not row[3], row[4], row[4] is not None, row[5])
- name = re.sub(r'^\"|\"$', '', name)
- if include_columns and name not in include_columns:
- continue
- match = re.match(r'(\w+)(\(.*?\))?', type_)
- if match:
- coltype = match.group(1)
- args = match.group(2)
- else:
- coltype = "VARCHAR"
- args = ''
-
- try:
- coltype = ischema_names[coltype]
- except KeyError:
- util.warn("Did not recognize type '%s' of column '%s'" %
- (coltype, name))
- coltype = sqltypes.NullType
-
- if args is not None:
- args = re.findall(r'(\d+)', args)
- coltype = coltype(*[int(a) for a in args])
-
- colargs = []
- if has_default:
- colargs.append(DefaultClause(sql.text(default)))
- table.append_column(schema.Column(name, coltype, primary_key = primary_key, nullable = nullable, *colargs))
-
- if not found_table:
- raise exc.NoSuchTableError(table.name)
-
- c = _pragma_cursor(connection.execute("%sforeign_key_list(%s)" % (pragma, qtable)))
- fks = {}
- while True:
- row = c.fetchone()
- if row is None:
- break
- (constraint_name, tablename, localcol, remotecol) = (row[0], row[2], row[3], row[4])
- tablename = re.sub(r'^\"|\"$', '', tablename)
- localcol = re.sub(r'^\"|\"$', '', localcol)
- remotecol = re.sub(r'^\"|\"$', '', remotecol)
- try:
- fk = fks[constraint_name]
- except KeyError:
- fk = ([], [])
- fks[constraint_name] = fk
-
- # look up the table based on the given table's engine, not 'self',
- # since it could be a ProxyEngine
- remotetable = schema.Table(tablename, table.metadata, autoload=True, autoload_with=connection)
- constrained_column = table.c[localcol].name
- refspec = ".".join([tablename, remotecol])
- if constrained_column not in fk[0]:
- fk[0].append(constrained_column)
- if refspec not in fk[1]:
- fk[1].append(refspec)
- for name, value in fks.iteritems():
- table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], link_to_name=True))
- # check for UNIQUE indexes
- c = _pragma_cursor(connection.execute("%sindex_list(%s)" % (pragma, qtable)))
- unique_indexes = []
- while True:
- row = c.fetchone()
- if row is None:
- break
- if (row[2] == 1):
- unique_indexes.append(row[1])
- # loop thru unique indexes for one that includes the primary key
- for idx in unique_indexes:
- c = connection.execute("%sindex_info(%s)" % (pragma, idx))
- cols = []
- while True:
- row = c.fetchone()
- if row is None:
- break
- cols.append(row[2])
-
-def _pragma_cursor(cursor):
- if cursor.closed:
- cursor._fetchone_impl = lambda: None
- return cursor
-
-class SQLiteCompiler(compiler.DefaultCompiler):
- functions = compiler.DefaultCompiler.functions.copy()
- functions.update (
- {
- sql_functions.now: 'CURRENT_TIMESTAMP',
- sql_functions.char_length: 'length%(expr)s'
- }
- )
-
- extract_map = compiler.DefaultCompiler.extract_map.copy()
- extract_map.update({
- 'month': '%m',
- 'day': '%d',
- 'year': '%Y',
- 'second': '%S',
- 'hour': '%H',
- 'doy': '%j',
- 'minute': '%M',
- 'epoch': '%s',
- 'dow': '%w',
- 'week': '%W'
- })
-
- def visit_cast(self, cast, **kwargs):
- if self.dialect.supports_cast:
- return super(SQLiteCompiler, self).visit_cast(cast)
- else:
- return self.process(cast.clause)
-
- def visit_extract(self, extract):
- try:
- return "CAST(STRFTIME('%s', %s) AS INTEGER)" % (
- self.extract_map[extract.field], self.process(extract.expr))
- except KeyError:
- raise exc.ArgumentError(
- "%s is not a valid extract argument." % extract.field)
-
- def limit_clause(self, select):
- text = ""
- if select._limit is not None:
- text += " \n LIMIT " + str(select._limit)
- if select._offset is not None:
- if select._limit is None:
- text += " \n LIMIT -1"
- text += " OFFSET " + str(select._offset)
- else:
- text += " OFFSET 0"
- return text
-
- def for_update_clause(self, select):
- # sqlite has no "FOR UPDATE" AFAICT
- return ''
-
-
-class SQLiteSchemaGenerator(compiler.SchemaGenerator):
-
- def get_column_specification(self, column, **kwargs):
- colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()
- default = self.get_column_default_string(column)
- if default is not None:
- colspec += " DEFAULT " + default
-
- if not column.nullable:
- colspec += " NOT NULL"
- return colspec
-
-class SQLiteIdentifierPreparer(compiler.IdentifierPreparer):
- reserved_words = set([
- 'add', 'after', 'all', 'alter', 'analyze', 'and', 'as', 'asc',
- 'attach', 'autoincrement', 'before', 'begin', 'between', 'by',
- 'cascade', 'case', 'cast', 'check', 'collate', 'column', 'commit',
- 'conflict', 'constraint', 'create', 'cross', 'current_date',
- 'current_time', 'current_timestamp', 'database', 'default',
- 'deferrable', 'deferred', 'delete', 'desc', 'detach', 'distinct',
- 'drop', 'each', 'else', 'end', 'escape', 'except', 'exclusive',
- 'explain', 'false', 'fail', 'for', 'foreign', 'from', 'full', 'glob',
- 'group', 'having', 'if', 'ignore', 'immediate', 'in', 'index',
- 'initially', 'inner', 'insert', 'instead', 'intersect', 'into', 'is',
- 'isnull', 'join', 'key', 'left', 'like', 'limit', 'match', 'natural',
- 'not', 'notnull', 'null', 'of', 'offset', 'on', 'or', 'order', 'outer',
- 'plan', 'pragma', 'primary', 'query', 'raise', 'references',
- 'reindex', 'rename', 'replace', 'restrict', 'right', 'rollback',
- 'row', 'select', 'set', 'table', 'temp', 'temporary', 'then', 'to',
- 'transaction', 'trigger', 'true', 'union', 'unique', 'update', 'using',
- 'vacuum', 'values', 'view', 'virtual', 'when', 'where', 'indexed',
- ])
-
- def __init__(self, dialect):
- super(SQLiteIdentifierPreparer, self).__init__(dialect)
-
-dialect = SQLiteDialect
-dialect.poolclass = pool.SingletonThreadPool
-dialect.statement_compiler = SQLiteCompiler
-dialect.schemagenerator = SQLiteSchemaGenerator
-dialect.preparer = SQLiteIdentifierPreparer
-dialect.execution_ctx_cls = SQLiteExecutionContext
+++ /dev/null
-# sybase.py
-# Copyright (C) 2007 Fisch Asset Management AG http://www.fam.ch
-# Coding: Alexander Houben alexander.houben@thor-solutions.ch
-#
-# This module is part of SQLAlchemy and is released under
-# the MIT License: http://www.opensource.org/licenses/mit-license.php
-
-"""
-Sybase database backend.
-
-Known issues / TODO:
-
- * Uses the mx.ODBC driver from egenix (version 2.1.0)
- * The current version of sqlalchemy.databases.sybase only supports
- mx.ODBC.Windows (other platforms such as mx.ODBC.unixODBC still need
- some development)
- * Support for pyodbc has been built in but is not yet complete (needs
- further development)
- * Results of running tests/alltests.py:
- Ran 934 tests in 287.032s
- FAILED (failures=3, errors=1)
- * Tested on 'Adaptive Server Anywhere 9' (version 9.0.1.1751)
-"""
-
-import datetime, operator
-
-from sqlalchemy import util, sql, schema, exc
-from sqlalchemy.sql import compiler, expression
-from sqlalchemy.engine import default, base
-from sqlalchemy import types as sqltypes
-from sqlalchemy.sql import operators as sql_operators
-from sqlalchemy import MetaData, Table, Column
-from sqlalchemy import String, Integer, SMALLINT, CHAR, ForeignKey
-
-
-__all__ = [
- 'SybaseTypeError'
- 'SybaseNumeric', 'SybaseFloat', 'SybaseInteger', 'SybaseBigInteger',
- 'SybaseTinyInteger', 'SybaseSmallInteger',
- 'SybaseDateTime_mxodbc', 'SybaseDateTime_pyodbc',
- 'SybaseDate_mxodbc', 'SybaseDate_pyodbc',
- 'SybaseTime_mxodbc', 'SybaseTime_pyodbc',
- 'SybaseText', 'SybaseString', 'SybaseChar', 'SybaseBinary',
- 'SybaseBoolean', 'SybaseTimeStamp', 'SybaseMoney', 'SybaseSmallMoney',
- 'SybaseUniqueIdentifier',
- ]
-
-
-RESERVED_WORDS = set([
- "add", "all", "alter", "and",
- "any", "as", "asc", "backup",
- "begin", "between", "bigint", "binary",
- "bit", "bottom", "break", "by",
- "call", "capability", "cascade", "case",
- "cast", "char", "char_convert", "character",
- "check", "checkpoint", "close", "comment",
- "commit", "connect", "constraint", "contains",
- "continue", "convert", "create", "cross",
- "cube", "current", "current_timestamp", "current_user",
- "cursor", "date", "dbspace", "deallocate",
- "dec", "decimal", "declare", "default",
- "delete", "deleting", "desc", "distinct",
- "do", "double", "drop", "dynamic",
- "else", "elseif", "encrypted", "end",
- "endif", "escape", "except", "exception",
- "exec", "execute", "existing", "exists",
- "externlogin", "fetch", "first", "float",
- "for", "force", "foreign", "forward",
- "from", "full", "goto", "grant",
- "group", "having", "holdlock", "identified",
- "if", "in", "index", "index_lparen",
- "inner", "inout", "insensitive", "insert",
- "inserting", "install", "instead", "int",
- "integer", "integrated", "intersect", "into",
- "iq", "is", "isolation", "join",
- "key", "lateral", "left", "like",
- "lock", "login", "long", "match",
- "membership", "message", "mode", "modify",
- "natural", "new", "no", "noholdlock",
- "not", "notify", "null", "numeric",
- "of", "off", "on", "open",
- "option", "options", "or", "order",
- "others", "out", "outer", "over",
- "passthrough", "precision", "prepare", "primary",
- "print", "privileges", "proc", "procedure",
- "publication", "raiserror", "readtext", "real",
- "reference", "references", "release", "remote",
- "remove", "rename", "reorganize", "resource",
- "restore", "restrict", "return", "revoke",
- "right", "rollback", "rollup", "save",
- "savepoint", "scroll", "select", "sensitive",
- "session", "set", "setuser", "share",
- "smallint", "some", "sqlcode", "sqlstate",
- "start", "stop", "subtrans", "subtransaction",
- "synchronize", "syntax_error", "table", "temporary",
- "then", "time", "timestamp", "tinyint",
- "to", "top", "tran", "trigger",
- "truncate", "tsequal", "unbounded", "union",
- "unique", "unknown", "unsigned", "update",
- "updating", "user", "using", "validate",
- "values", "varbinary", "varchar", "variable",
- "varying", "view", "wait", "waitfor",
- "when", "where", "while", "window",
- "with", "with_cube", "with_lparen", "with_rollup",
- "within", "work", "writetext",
- ])
-
-ischema = MetaData()
-
-tables = Table("SYSTABLE", ischema,
- Column("table_id", Integer, primary_key=True),
- Column("file_id", SMALLINT),
- Column("table_name", CHAR(128)),
- Column("table_type", CHAR(10)),
- Column("creator", Integer),
- #schema="information_schema"
- )
-
-domains = Table("SYSDOMAIN", ischema,
- Column("domain_id", Integer, primary_key=True),
- Column("domain_name", CHAR(128)),
- Column("type_id", SMALLINT),
- Column("precision", SMALLINT, quote=True),
- #schema="information_schema"
- )
-
-columns = Table("SYSCOLUMN", ischema,
- Column("column_id", Integer, primary_key=True),
- Column("table_id", Integer, ForeignKey(tables.c.table_id)),
- Column("pkey", CHAR(1)),
- Column("column_name", CHAR(128)),
- Column("nulls", CHAR(1)),
- Column("width", SMALLINT),
- Column("domain_id", SMALLINT, ForeignKey(domains.c.domain_id)),
- # FIXME: should be mx.BIGINT
- Column("max_identity", Integer),
- # FIXME: should be mx.ODBC.Windows.LONGVARCHAR
- Column("default", String),
- Column("scale", Integer),
- #schema="information_schema"
- )
-
-foreignkeys = Table("SYSFOREIGNKEY", ischema,
- Column("foreign_table_id", Integer, ForeignKey(tables.c.table_id), primary_key=True),
- Column("foreign_key_id", SMALLINT, primary_key=True),
- Column("primary_table_id", Integer, ForeignKey(tables.c.table_id)),
- #schema="information_schema"
- )
-fkcols = Table("SYSFKCOL", ischema,
- Column("foreign_table_id", Integer, ForeignKey(columns.c.table_id), primary_key=True),
- Column("foreign_key_id", SMALLINT, ForeignKey(foreignkeys.c.foreign_key_id), primary_key=True),
- Column("foreign_column_id", Integer, ForeignKey(columns.c.column_id), primary_key=True),
- Column("primary_column_id", Integer),
- #schema="information_schema"
- )
-
-class SybaseTypeError(sqltypes.TypeEngine):
- def result_processor(self, dialect):
- return None
-
- def bind_processor(self, dialect):
- def process(value):
- raise exc.InvalidRequestError("Data type not supported", [value])
- return process
-
- def get_col_spec(self):
- raise exc.CompileError("Data type not supported")
-
-class SybaseNumeric(sqltypes.Numeric):
- def get_col_spec(self):
- if self.scale is None:
- if self.precision is None:
- return "NUMERIC"
- else:
- return "NUMERIC(%(precision)s)" % {'precision' : self.precision}
- else:
- return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale}
-
-class SybaseFloat(sqltypes.FLOAT, SybaseNumeric):
- def __init__(self, precision = 10, asdecimal = False, scale = 2, **kwargs):
- super(sqltypes.FLOAT, self).__init__(precision, asdecimal, **kwargs)
- self.scale = scale
-
- def get_col_spec(self):
- # if asdecimal is True, handle same way as SybaseNumeric
- if self.asdecimal:
- return SybaseNumeric.get_col_spec(self)
- if self.precision is None:
- return "FLOAT"
- else:
- return "FLOAT(%(precision)s)" % {'precision': self.precision}
-
- def result_processor(self, dialect):
- def process(value):
- if value is None:
- return None
- return float(value)
- if self.asdecimal:
- return SybaseNumeric.result_processor(self, dialect)
- return process
-
-class SybaseInteger(sqltypes.Integer):
- def get_col_spec(self):
- return "INTEGER"
-
-class SybaseBigInteger(SybaseInteger):
- def get_col_spec(self):
- return "BIGINT"
-
-class SybaseTinyInteger(SybaseInteger):
- def get_col_spec(self):
- return "TINYINT"
-
-class SybaseSmallInteger(SybaseInteger):
- def get_col_spec(self):
- return "SMALLINT"
-
-class SybaseDateTime_mxodbc(sqltypes.DateTime):
- def __init__(self, *a, **kw):
- super(SybaseDateTime_mxodbc, self).__init__(False)
-
- def get_col_spec(self):
- return "DATETIME"
-
-class SybaseDateTime_pyodbc(sqltypes.DateTime):
- def __init__(self, *a, **kw):
- super(SybaseDateTime_pyodbc, self).__init__(False)
-
- def get_col_spec(self):
- return "DATETIME"
-
- def result_processor(self, dialect):
- def process(value):
- if value is None:
- return None
- # Convert the datetime.datetime back to datetime.time
- return value
- return process
-
- def bind_processor(self, dialect):
- def process(value):
- if value is None:
- return None
- return value
- return process
-
-class SybaseDate_mxodbc(sqltypes.Date):
- def __init__(self, *a, **kw):
- super(SybaseDate_mxodbc, self).__init__(False)
-
- def get_col_spec(self):
- return "DATE"
-
-class SybaseDate_pyodbc(sqltypes.Date):
- def __init__(self, *a, **kw):
- super(SybaseDate_pyodbc, self).__init__(False)
-
- def get_col_spec(self):
- return "DATE"
-
-class SybaseTime_mxodbc(sqltypes.Time):
- def __init__(self, *a, **kw):
- super(SybaseTime_mxodbc, self).__init__(False)
-
- def get_col_spec(self):
- return "DATETIME"
-
- def result_processor(self, dialect):
- def process(value):
- if value is None:
- return None
- # Convert the datetime.datetime back to datetime.time
- return datetime.time(value.hour, value.minute, value.second, value.microsecond)
- return process
-
-class SybaseTime_pyodbc(sqltypes.Time):
- def __init__(self, *a, **kw):
- super(SybaseTime_pyodbc, self).__init__(False)
-
- def get_col_spec(self):
- return "DATETIME"
-
- def result_processor(self, dialect):
- def process(value):
- if value is None:
- return None
- # Convert the datetime.datetime back to datetime.time
- return datetime.time(value.hour, value.minute, value.second, value.microsecond)
- return process
-
- def bind_processor(self, dialect):
- def process(value):
- if value is None:
- return None
- return datetime.datetime(1970, 1, 1, value.hour, value.minute, value.second, value.microsecond)
- return process
-
-class SybaseText(sqltypes.Text):
- def get_col_spec(self):
- return "TEXT"
-
-class SybaseString(sqltypes.String):
- def get_col_spec(self):
- return "VARCHAR(%(length)s)" % {'length' : self.length}
-
-class SybaseChar(sqltypes.CHAR):
- def get_col_spec(self):
- return "CHAR(%(length)s)" % {'length' : self.length}
-
-class SybaseBinary(sqltypes.Binary):
- def get_col_spec(self):
- return "IMAGE"
-
-class SybaseBoolean(sqltypes.Boolean):
- def get_col_spec(self):
- return "BIT"
-
- def result_processor(self, dialect):
- def process(value):
- if value is None:
- return None
- return value and True or False
- return process
-
- def bind_processor(self, dialect):
- def process(value):
- if value is True:
- return 1
- elif value is False:
- return 0
- elif value is None:
- return None
- else:
- return value and True or False
- return process
-
-class SybaseTimeStamp(sqltypes.TIMESTAMP):
- def get_col_spec(self):
- return "TIMESTAMP"
-
-class SybaseMoney(sqltypes.TypeEngine):
- def get_col_spec(self):
- return "MONEY"
-
-class SybaseSmallMoney(SybaseMoney):
- def get_col_spec(self):
- return "SMALLMONEY"
-
-class SybaseUniqueIdentifier(sqltypes.TypeEngine):
- def get_col_spec(self):
- return "UNIQUEIDENTIFIER"
-
-class SybaseSQLExecutionContext(default.DefaultExecutionContext):
- pass
-
-class SybaseSQLExecutionContext_mxodbc(SybaseSQLExecutionContext):
-
- def __init__(self, dialect, connection, compiled=None, statement=None, parameters=None):
- super(SybaseSQLExecutionContext_mxodbc, self).__init__(dialect, connection, compiled, statement, parameters)
-
- def pre_exec(self):
- super(SybaseSQLExecutionContext_mxodbc, self).pre_exec()
-
- def post_exec(self):
- if self.compiled.isinsert:
- table = self.compiled.statement.table
- # get the inserted values of the primary key
-
- # get any sequence IDs first (using @@identity)
- self.cursor.execute("SELECT @@identity AS lastrowid")
- row = self.cursor.fetchone()
- lastrowid = int(row[0])
- if lastrowid > 0:
- # an IDENTITY was inserted, fetch it
- # FIXME: always insert in front ? This only works if the IDENTITY is the first column, no ?!
- if not hasattr(self, '_last_inserted_ids') or self._last_inserted_ids is None:
- self._last_inserted_ids = [lastrowid]
- else:
- self._last_inserted_ids = [lastrowid] + self._last_inserted_ids[1:]
- super(SybaseSQLExecutionContext_mxodbc, self).post_exec()
-
-class SybaseSQLExecutionContext_pyodbc(SybaseSQLExecutionContext):
- def __init__(self, dialect, connection, compiled=None, statement=None, parameters=None):
- super(SybaseSQLExecutionContext_pyodbc, self).__init__(dialect, connection, compiled, statement, parameters)
-
- def pre_exec(self):
- super(SybaseSQLExecutionContext_pyodbc, self).pre_exec()
-
- def post_exec(self):
- if self.compiled.isinsert:
- table = self.compiled.statement.table
- # get the inserted values of the primary key
-
- # get any sequence IDs first (using @@identity)
- self.cursor.execute("SELECT @@identity AS lastrowid")
- row = self.cursor.fetchone()
- lastrowid = int(row[0])
- if lastrowid > 0:
- # an IDENTITY was inserted, fetch it
- # FIXME: always insert in front ? This only works if the IDENTITY is the first column, no ?!
- if not hasattr(self, '_last_inserted_ids') or self._last_inserted_ids is None:
- self._last_inserted_ids = [lastrowid]
- else:
- self._last_inserted_ids = [lastrowid] + self._last_inserted_ids[1:]
- super(SybaseSQLExecutionContext_pyodbc, self).post_exec()
-
-class SybaseSQLDialect(default.DefaultDialect):
- colspecs = {
- # FIXME: unicode support
- #sqltypes.Unicode : SybaseUnicode,
- sqltypes.Integer : SybaseInteger,
- sqltypes.SmallInteger : SybaseSmallInteger,
- sqltypes.Numeric : SybaseNumeric,
- sqltypes.Float : SybaseFloat,
- sqltypes.String : SybaseString,
- sqltypes.Binary : SybaseBinary,
- sqltypes.Boolean : SybaseBoolean,
- sqltypes.Text : SybaseText,
- sqltypes.CHAR : SybaseChar,
- sqltypes.TIMESTAMP : SybaseTimeStamp,
- sqltypes.FLOAT : SybaseFloat,
- }
-
- ischema_names = {
- 'integer' : SybaseInteger,
- 'unsigned int' : SybaseInteger,
- 'unsigned smallint' : SybaseInteger,
- 'unsigned bigint' : SybaseInteger,
- 'bigint': SybaseBigInteger,
- 'smallint' : SybaseSmallInteger,
- 'tinyint' : SybaseTinyInteger,
- 'varchar' : SybaseString,
- 'long varchar' : SybaseText,
- 'char' : SybaseChar,
- 'decimal' : SybaseNumeric,
- 'numeric' : SybaseNumeric,
- 'float' : SybaseFloat,
- 'double' : SybaseFloat,
- 'binary' : SybaseBinary,
- 'long binary' : SybaseBinary,
- 'varbinary' : SybaseBinary,
- 'bit': SybaseBoolean,
- 'image' : SybaseBinary,
- 'timestamp': SybaseTimeStamp,
- 'money': SybaseMoney,
- 'smallmoney': SybaseSmallMoney,
- 'uniqueidentifier': SybaseUniqueIdentifier,
-
- 'java.lang.Object' : SybaseTypeError,
- 'java serialization' : SybaseTypeError,
- }
-
- name = 'sybase'
- # Sybase backend peculiarities
- supports_unicode_statements = False
- supports_sane_rowcount = False
- supports_sane_multi_rowcount = False
- execution_ctx_cls = SybaseSQLExecutionContext
-
- def __new__(cls, dbapi=None, *args, **kwargs):
- if cls != SybaseSQLDialect:
- return super(SybaseSQLDialect, cls).__new__(cls, *args, **kwargs)
- if dbapi:
- print dbapi.__name__
- dialect = dialect_mapping.get(dbapi.__name__)
- return dialect(*args, **kwargs)
- else:
- return object.__new__(cls, *args, **kwargs)
-
- def __init__(self, **params):
- super(SybaseSQLDialect, self).__init__(**params)
- self.text_as_varchar = False
- # FIXME: what is the default schema for sybase connections (DBA?) ?
- self.set_default_schema_name("dba")
-
- def dbapi(cls, module_name=None):
- if module_name:
- try:
- dialect_cls = dialect_mapping[module_name]
- return dialect_cls.import_dbapi()
- except KeyError:
- raise exc.InvalidRequestError("Unsupported SybaseSQL module '%s' requested (must be " + " or ".join([x for x in dialect_mapping.keys()]) + ")" % module_name)
- else:
- for dialect_cls in dialect_mapping.values():
- try:
- return dialect_cls.import_dbapi()
- except ImportError, e:
- pass
- else:
- raise ImportError('No DBAPI module detected for SybaseSQL - please install mxodbc')
- dbapi = classmethod(dbapi)
-
- def type_descriptor(self, typeobj):
- newobj = sqltypes.adapt_type(typeobj, self.colspecs)
- return newobj
-
- def last_inserted_ids(self):
- return self.context.last_inserted_ids
-
- def get_default_schema_name(self, connection):
- return self.schema_name
-
- def set_default_schema_name(self, schema_name):
- self.schema_name = schema_name
-
- def do_execute(self, cursor, statement, params, **kwargs):
- params = tuple(params)
- super(SybaseSQLDialect, self).do_execute(cursor, statement, params, **kwargs)
-
- # FIXME: remove ?
- def _execute(self, c, statement, parameters):
- try:
- if parameters == {}:
- parameters = ()
- c.execute(statement, parameters)
- self.context.rowcount = c.rowcount
- c.DBPROP_COMMITPRESERVE = "Y"
- except Exception, e:
- raise exc.DBAPIError.instance(statement, parameters, e)
-
- def table_names(self, connection, schema):
- """Ignore the schema and the charset for now."""
- s = sql.select([tables.c.table_name],
- sql.not_(tables.c.table_name.like("SYS%")) and
- tables.c.creator >= 100
- )
- rp = connection.execute(s)
- return [row[0] for row in rp.fetchall()]
-
- def has_table(self, connection, tablename, schema=None):
- # FIXME: ignore schemas for sybase
- s = sql.select([tables.c.table_name], tables.c.table_name == tablename)
-
- c = connection.execute(s)
- row = c.fetchone()
- print "has_table: " + tablename + ": " + str(bool(row is not None))
- return row is not None
-
- def reflecttable(self, connection, table, include_columns):
- # Get base columns
- if table.schema is not None:
- current_schema = table.schema
- else:
- current_schema = self.get_default_schema_name(connection)
-
- s = sql.select([columns, domains], tables.c.table_name==table.name, from_obj=[columns.join(tables).join(domains)], order_by=[columns.c.column_id])
-
- c = connection.execute(s)
- found_table = False
- # makes sure we append the columns in the correct order
- while True:
- row = c.fetchone()
- if row is None:
- break
- found_table = True
- (name, type, nullable, charlen, numericprec, numericscale, default, primary_key, max_identity, table_id, column_id) = (
- row[columns.c.column_name],
- row[domains.c.domain_name],
- row[columns.c.nulls] == 'Y',
- row[columns.c.width],
- row[domains.c.precision],
- row[columns.c.scale],
- row[columns.c.default],
- row[columns.c.pkey] == 'Y',
- row[columns.c.max_identity],
- row[tables.c.table_id],
- row[columns.c.column_id],
- )
- if include_columns and name not in include_columns:
- continue
-
- # FIXME: else problems with SybaseBinary(size)
- if numericscale == 0:
- numericscale = None
-
- args = []
- for a in (charlen, numericprec, numericscale):
- if a is not None:
- args.append(a)
- coltype = self.ischema_names.get(type, None)
- if coltype == SybaseString and charlen == -1:
- coltype = SybaseText()
- else:
- if coltype is None:
- util.warn("Did not recognize type '%s' of column '%s'" %
- (type, name))
- coltype = sqltypes.NULLTYPE
- coltype = coltype(*args)
- colargs = []
- if default is not None:
- colargs.append(schema.DefaultClause(sql.text(default)))
-
- # any sequences ?
- col = schema.Column(name, coltype, nullable=nullable, primary_key=primary_key, *colargs)
- if int(max_identity) > 0:
- col.sequence = schema.Sequence(name + '_identity')
- col.sequence.start = int(max_identity)
- col.sequence.increment = 1
-
- # append the column
- table.append_column(col)
-
- # any foreign key constraint for this table ?
- # note: no multi-column foreign keys are considered
- s = "select st1.table_name, sc1.column_name, st2.table_name, sc2.column_name from systable as st1 join sysfkcol on st1.table_id=sysfkcol.foreign_table_id join sysforeignkey join systable as st2 on sysforeignkey.primary_table_id = st2.table_id join syscolumn as sc1 on sysfkcol.foreign_column_id=sc1.column_id and sc1.table_id=st1.table_id join syscolumn as sc2 on sysfkcol.primary_column_id=sc2.column_id and sc2.table_id=st2.table_id where st1.table_name='%(table_name)s';" % { 'table_name' : table.name }
- c = connection.execute(s)
- foreignKeys = {}
- while True:
- row = c.fetchone()
- if row is None:
- break
- (foreign_table, foreign_column, primary_table, primary_column) = (
- row[0], row[1], row[2], row[3],
- )
- if not primary_table in foreignKeys.keys():
- foreignKeys[primary_table] = [['%s' % (foreign_column)], ['%s.%s'%(primary_table, primary_column)]]
- else:
- foreignKeys[primary_table][0].append('%s'%(foreign_column))
- foreignKeys[primary_table][1].append('%s.%s'%(primary_table, primary_column))
- for primary_table in foreignKeys.keys():
- #table.append_constraint(schema.ForeignKeyConstraint(['%s.%s'%(foreign_table, foreign_column)], ['%s.%s'%(primary_table,primary_column)]))
- table.append_constraint(schema.ForeignKeyConstraint(foreignKeys[primary_table][0], foreignKeys[primary_table][1], link_to_name=True))
-
- if not found_table:
- raise exc.NoSuchTableError(table.name)
-
-
-class SybaseSQLDialect_mxodbc(SybaseSQLDialect):
- execution_ctx_cls = SybaseSQLExecutionContext_mxodbc
-
- def __init__(self, **params):
- super(SybaseSQLDialect_mxodbc, self).__init__(**params)
-
- self.dbapi_type_map = {'getdate' : SybaseDate_mxodbc()}
-
- def import_dbapi(cls):
- #import mx.ODBC.Windows as module
- import mxODBC as module
- return module
- import_dbapi = classmethod(import_dbapi)
-
- colspecs = SybaseSQLDialect.colspecs.copy()
- colspecs[sqltypes.Time] = SybaseTime_mxodbc
- colspecs[sqltypes.Date] = SybaseDate_mxodbc
- colspecs[sqltypes.DateTime] = SybaseDateTime_mxodbc
-
- ischema_names = SybaseSQLDialect.ischema_names.copy()
- ischema_names['time'] = SybaseTime_mxodbc
- ischema_names['date'] = SybaseDate_mxodbc
- ischema_names['datetime'] = SybaseDateTime_mxodbc
- ischema_names['smalldatetime'] = SybaseDateTime_mxodbc
-
- def is_disconnect(self, e):
- # FIXME: optimize
- #return isinstance(e, self.dbapi.Error) and '[08S01]' in str(e)
- #return True
- return False
-
- def do_execute(self, cursor, statement, parameters, context=None, **kwargs):
- super(SybaseSQLDialect_mxodbc, self).do_execute(cursor, statement, parameters, context=context, **kwargs)
-
- def create_connect_args(self, url):
- '''Return a tuple of *args,**kwargs'''
- # FIXME: handle mx.odbc.Windows proprietary args
- opts = url.translate_connect_args(username='user')
- opts.update(url.query)
- argsDict = {}
- argsDict['user'] = opts['user']
- argsDict['password'] = opts['password']
- connArgs = [[opts['dsn']], argsDict]
- return connArgs
-
-
-class SybaseSQLDialect_pyodbc(SybaseSQLDialect):
- execution_ctx_cls = SybaseSQLExecutionContext_pyodbc
-
- def __init__(self, **params):
- super(SybaseSQLDialect_pyodbc, self).__init__(**params)
- self.dbapi_type_map = {'getdate' : SybaseDate_pyodbc()}
-
- def import_dbapi(cls):
- import mypyodbc as module
- return module
- import_dbapi = classmethod(import_dbapi)
-
- colspecs = SybaseSQLDialect.colspecs.copy()
- colspecs[sqltypes.Time] = SybaseTime_pyodbc
- colspecs[sqltypes.Date] = SybaseDate_pyodbc
- colspecs[sqltypes.DateTime] = SybaseDateTime_pyodbc
-
- ischema_names = SybaseSQLDialect.ischema_names.copy()
- ischema_names['time'] = SybaseTime_pyodbc
- ischema_names['date'] = SybaseDate_pyodbc
- ischema_names['datetime'] = SybaseDateTime_pyodbc
- ischema_names['smalldatetime'] = SybaseDateTime_pyodbc
-
- def is_disconnect(self, e):
- # FIXME: optimize
- #return isinstance(e, self.dbapi.Error) and '[08S01]' in str(e)
- #return True
- return False
-
- def do_execute(self, cursor, statement, parameters, context=None, **kwargs):
- super(SybaseSQLDialect_pyodbc, self).do_execute(cursor, statement, parameters, context=context, **kwargs)
-
- def create_connect_args(self, url):
- '''Return a tuple of *args,**kwargs'''
- # FIXME: handle pyodbc proprietary args
- opts = url.translate_connect_args(username='user')
- opts.update(url.query)
-
- self.autocommit = False
- if 'autocommit' in opts:
- self.autocommit = bool(int(opts.pop('autocommit')))
-
- argsDict = {}
- argsDict['UID'] = opts['user']
- argsDict['PWD'] = opts['password']
- argsDict['DSN'] = opts['dsn']
- connArgs = [[';'.join(["%s=%s"%(key, argsDict[key]) for key in argsDict])], {'autocommit' : self.autocommit}]
- return connArgs
-
-
-dialect_mapping = {
- 'sqlalchemy.databases.mxODBC' : SybaseSQLDialect_mxodbc,
-# 'pyodbc' : SybaseSQLDialect_pyodbc,
- }
-
-
-class SybaseSQLCompiler(compiler.DefaultCompiler):
- operators = compiler.DefaultCompiler.operators.copy()
- operators.update({
- sql_operators.mod: lambda x, y: "MOD(%s, %s)" % (x, y),
- })
-
- extract_map = compiler.DefaultCompiler.extract_map.copy()
- extract_map.update ({
- 'doy': 'dayofyear',
- 'dow': 'weekday',
- 'milliseconds': 'millisecond'
- })
-
-
- def bindparam_string(self, name):
- res = super(SybaseSQLCompiler, self).bindparam_string(name)
- if name.lower().startswith('literal'):
- res = 'STRING(%s)' % res
- return res
-
- def get_select_precolumns(self, select):
- s = select._distinct and "DISTINCT " or ""
- if select._limit:
- #if select._limit == 1:
- #s += "FIRST "
- #else:
- #s += "TOP %s " % (select._limit,)
- s += "TOP %s " % (select._limit,)
- if select._offset:
- if not select._limit:
- # FIXME: sybase doesn't allow an offset without a limit
- # so use a huge value for TOP here
- s += "TOP 1000000 "
- s += "START AT %s " % (select._offset+1,)
- return s
-
- def limit_clause(self, select):
- # Limit in sybase is after the select keyword
- return ""
-
- def visit_binary(self, binary):
- """Move bind parameters to the right-hand side of an operator, where possible."""
- if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq:
- return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator))
- else:
- return super(SybaseSQLCompiler, self).visit_binary(binary)
-
- def label_select_column(self, select, column, asfrom):
- if isinstance(column, expression.Function):
- return column.label(None)
- else:
- return super(SybaseSQLCompiler, self).label_select_column(select, column, asfrom)
-
- function_rewrites = {'current_date': 'getdate',
- }
- def visit_function(self, func):
- func.name = self.function_rewrites.get(func.name, func.name)
- res = super(SybaseSQLCompiler, self).visit_function(func)
- if func.name.lower() == 'getdate':
- # apply CAST operator
- # FIXME: what about _pyodbc ?
- cast = expression._Cast(func, SybaseDate_mxodbc)
- # infinite recursion
- # res = self.visit_cast(cast)
- res = "CAST(%s AS %s)" % (res, self.process(cast.typeclause))
- return res
-
- def visit_extract(self, extract):
- field = self.extract_map.get(extract.field, extract.field)
- return 'DATEPART("%s", %s)' % (field, self.process(extract.expr))
-
- def for_update_clause(self, select):
- # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use
- return ''
-
- def order_by_clause(self, select):
- order_by = self.process(select._order_by_clause)
-
- # SybaseSQL only allows ORDER BY in subqueries if there is a LIMIT
- if order_by and (not self.is_subquery() or select._limit):
- return " ORDER BY " + order_by
- else:
- return ""
-
-
-class SybaseSQLSchemaGenerator(compiler.SchemaGenerator):
- def get_column_specification(self, column, **kwargs):
-
- colspec = self.preparer.format_column(column)
-
- if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \
- column.autoincrement and isinstance(column.type, sqltypes.Integer):
- if column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional):
- column.sequence = schema.Sequence(column.name + '_seq')
-
- if hasattr(column, 'sequence'):
- column.table.has_sequence = column
- #colspec += " numeric(30,0) IDENTITY"
- colspec += " Integer IDENTITY"
- else:
- colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
-
- if not column.nullable:
- colspec += " NOT NULL"
-
- default = self.get_column_default_string(column)
- if default is not None:
- colspec += " DEFAULT " + default
-
- return colspec
-
-
-class SybaseSQLSchemaDropper(compiler.SchemaDropper):
- def visit_index(self, index):
- self.append("\nDROP INDEX %s.%s" % (
- self.preparer.quote_identifier(index.table.name),
- self.preparer.quote(self._validate_identifier(index.name, False), index.quote)
- ))
- self.execute()
-
-
-class SybaseSQLDefaultRunner(base.DefaultRunner):
- pass
-
-
-class SybaseSQLIdentifierPreparer(compiler.IdentifierPreparer):
- reserved_words = RESERVED_WORDS
-
- def __init__(self, dialect):
- super(SybaseSQLIdentifierPreparer, self).__init__(dialect)
-
- def _escape_identifier(self, value):
- #TODO: determin SybaseSQL's escapeing rules
- return value
-
- def _fold_identifier_case(self, value):
- #TODO: determin SybaseSQL's case folding rules
- return value
-
-
-dialect = SybaseSQLDialect
-dialect.statement_compiler = SybaseSQLCompiler
-dialect.schemagenerator = SybaseSQLSchemaGenerator
-dialect.schemadropper = SybaseSQLSchemaDropper
-dialect.preparer = SybaseSQLIdentifierPreparer
-dialect.defaultrunner = SybaseSQLDefaultRunner
--- /dev/null
+__all__ = (
+# 'access',
+# 'firebird',
+# 'informix',
+# 'maxdb',
+# 'mssql',
+ 'mysql',
+ 'oracle',
+ 'postgresql',
+ 'sqlite',
+# 'sybase',
+ )
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
+"""
+Support for the Microsoft Access database.
+
+This dialect is *not* tested on SQLAlchemy 0.6.
+
+
+"""
from sqlalchemy import sql, schema, types, exc, pool
from sqlalchemy.sql import compiler, expression
from sqlalchemy.engine import default, base
def get_col_spec(self):
return "TINYINT"
-class AcSmallInteger(types.Smallinteger):
+class AcSmallInteger(types.SmallInteger):
def get_col_spec(self):
return "SMALLINT"
colspecs = {
types.Unicode : AcUnicode,
types.Integer : AcInteger,
- types.Smallinteger: AcSmallInteger,
+ types.SmallInteger: AcSmallInteger,
types.Numeric : AcNumeric,
types.Float : AcFloat,
types.DateTime : AcDateTime,
return names
-class AccessCompiler(compiler.DefaultCompiler):
- extract_map = compiler.DefaultCompiler.extract_map.copy()
+class AccessCompiler(compiler.SQLCompiler):
+ extract_map = compiler.SQLCompiler.extract_map.copy()
extract_map.update ({
'month': 'm',
'day': 'd',
'dow': 'w',
'week': 'ww'
})
-
+
def visit_select_precolumns(self, select):
"""Access puts TOP, it's version of LIMIT here """
s = select.distinct and "DISTINCT " or ""
field = self.extract_map.get(extract.field, extract.field)
return 'DATEPART("%s", %s)' % (field, self.process(extract.expr))
-
-class AccessSchemaGenerator(compiler.SchemaGenerator):
+class AccessDDLCompiler(compiler.DDLCompiler):
def get_column_specification(self, column, **kwargs):
colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()
return colspec
-class AccessSchemaDropper(compiler.SchemaDropper):
- def visit_index(self, index):
-
+ def visit_drop_index(self, drop):
+ index = drop.element
self.append("\nDROP INDEX [%s].[%s]" % (index.table.name, self._validate_identifier(index.name, False)))
- self.execute()
-
-class AccessDefaultRunner(base.DefaultRunner):
- pass
class AccessIdentifierPreparer(compiler.IdentifierPreparer):
reserved_words = compiler.RESERVED_WORDS.copy()
dialect = AccessDialect
dialect.poolclass = pool.SingletonThreadPool
dialect.statement_compiler = AccessCompiler
-dialect.schemagenerator = AccessSchemaGenerator
-dialect.schemadropper = AccessSchemaDropper
+dialect.ddlcompiler = AccessDDLCompiler
dialect.preparer = AccessIdentifierPreparer
-dialect.defaultrunner = AccessDefaultRunner
-dialect.execution_ctx_cls = AccessExecutionContext
+dialect.execution_ctx_cls = AccessExecutionContext
\ No newline at end of file
--- /dev/null
+from sqlalchemy.dialects.firebird import base, kinterbasdb
+
+base.dialect = kinterbasdb.dialect
\ No newline at end of file
--- /dev/null
+# firebird.py
+# Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+"""
+Support for the Firebird database.
+
+Connectivity is usually supplied via the kinterbasdb_
+DBAPI module.
+
+Firebird dialects
+-----------------
+
+Firebird offers two distinct dialects_ (not to be confused with a
+SQLAlchemy ``Dialect``):
+
+dialect 1
+ This is the old syntax and behaviour, inherited from Interbase pre-6.0.
+
+dialect 3
+ This is the newer and supported syntax, introduced in Interbase 6.0.
+
+The SQLAlchemy Firebird dialect detects these versions and
+adjusts its representation of SQL accordingly. However,
+support for dialect 1 is not well tested and probably has
+incompatibilities.
+
+Firebird Locking Behavior
+-------------------------
+
+Firebird locks tables aggressively. For this reason, a DROP TABLE may
+hang until other transactions are released. SQLAlchemy does its best
+to release transactions as quickly as possible. The most common cause
+of hanging transactions is a non-fully consumed result set, i.e.::
+
+ result = engine.execute("select * from table")
+ row = result.fetchone()
+ return
+
+Where above, the ``ResultProxy`` has not been fully consumed. The
+connection will be returned to the pool and the transactional state
+rolled back once the Python garbage collector reclaims the objects
+which hold onto the connection, which often occurs asynchronously.
+The above use case can be alleviated by calling ``first()`` on the
+``ResultProxy`` which will fetch the first row and immediately close
+all remaining cursor/connection resources.
+
+RETURNING support
+-----------------
+
+Firebird 2.0 supports returning a result set from inserts, and 2.1 extends
+that to deletes and updates.
+
+To use this pass the column/expression list to the ``firebird_returning``
+parameter when creating the queries::
+
+ raises = tbl.update(empl.c.sales > 100, values=dict(salary=empl.c.salary * 1.1),
+ firebird_returning=[empl.c.id, empl.c.salary]).execute().fetchall()
+
+
+.. [#] Well, that is not the whole story, as the client may still ask
+ a different (lower) dialect...
+
+.. _dialects: http://mc-computing.com/Databases/Firebird/SQL_Dialect.html
+.. _kinterbasdb: http://sourceforge.net/projects/kinterbasdb
+
+"""
+
+
+import datetime, decimal, re
+
+from sqlalchemy import schema as sa_schema
+from sqlalchemy import exc, types as sqltypes, sql, util
+from sqlalchemy.sql import expression
+from sqlalchemy.engine import base, default, reflection
+from sqlalchemy.sql import compiler
+
+from sqlalchemy.types import (BIGINT, BLOB, BOOLEAN, CHAR, DATE,
+ FLOAT, INTEGER, NUMERIC, SMALLINT,
+ TEXT, TIME, TIMESTAMP, VARCHAR)
+
+
+RESERVED_WORDS = set(
+ ["action", "active", "add", "admin", "after", "all", "alter", "and", "any",
+ "as", "asc", "ascending", "at", "auto", "autoddl", "avg", "based", "basename",
+ "base_name", "before", "begin", "between", "bigint", "blob", "blobedit", "buffer",
+ "by", "cache", "cascade", "case", "cast", "char", "character", "character_length",
+ "char_length", "check", "check_point_len", "check_point_length", "close", "collate",
+ "collation", "column", "commit", "committed", "compiletime", "computed", "conditional",
+ "connect", "constraint", "containing", "continue", "count", "create", "cstring",
+ "current", "current_connection", "current_date", "current_role", "current_time",
+ "current_timestamp", "current_transaction", "current_user", "cursor", "database",
+ "date", "day", "db_key", "debug", "dec", "decimal", "declare", "default", "delete",
+ "desc", "descending", "describe", "descriptor", "disconnect", "display", "distinct",
+ "do", "domain", "double", "drop", "echo", "edit", "else", "end", "entry_point",
+ "escape", "event", "exception", "execute", "exists", "exit", "extern", "external",
+ "extract", "fetch", "file", "filter", "float", "for", "foreign", "found", "free_it",
+ "from", "full", "function", "gdscode", "generator", "gen_id", "global", "goto",
+ "grant", "group", "group_commit_", "group_commit_wait", "having", "help", "hour",
+ "if", "immediate", "in", "inactive", "index", "indicator", "init", "inner", "input",
+ "input_type", "insert", "int", "integer", "into", "is", "isolation", "isql", "join",
+ "key", "lc_messages", "lc_type", "left", "length", "lev", "level", "like", "logfile",
+ "log_buffer_size", "log_buf_size", "long", "manual", "max", "maximum", "maximum_segment",
+ "max_segment", "merge", "message", "min", "minimum", "minute", "module_name", "month",
+ "names", "national", "natural", "nchar", "no", "noauto", "not", "null", "numeric",
+ "num_log_buffers", "num_log_bufs", "octet_length", "of", "on", "only", "open", "option",
+ "or", "order", "outer", "output", "output_type", "overflow", "page", "pagelength",
+ "pages", "page_size", "parameter", "password", "plan", "position", "post_event",
+ "precision", "prepare", "primary", "privileges", "procedure", "protected", "public",
+ "quit", "raw_partitions", "rdb$db_key", "read", "real", "record_version", "recreate",
+ "references", "release", "release", "reserv", "reserving", "restrict", "retain",
+ "return", "returning_values", "returns", "revoke", "right", "role", "rollback",
+ "row_count", "runtime", "savepoint", "schema", "second", "segment", "select",
+ "set", "shadow", "shared", "shell", "show", "singular", "size", "smallint",
+ "snapshot", "some", "sort", "sqlcode", "sqlerror", "sqlwarning", "stability",
+ "starting", "starts", "statement", "static", "statistics", "sub_type", "sum",
+ "suspend", "table", "terminator", "then", "time", "timestamp", "to", "transaction",
+ "translate", "translation", "trigger", "trim", "type", "uncommitted", "union",
+ "unique", "update", "upper", "user", "using", "value", "values", "varchar",
+ "variable", "varying", "version", "view", "wait", "wait_time", "weekday", "when",
+ "whenever", "where", "while", "with", "work", "write", "year", "yearday" ])
+
+
+class _FBBoolean(sqltypes.Boolean):
+ def result_processor(self, dialect):
+ def process(value):
+ if value is None:
+ return None
+ return value and True or False
+ return process
+
+ def bind_processor(self, dialect):
+ def process(value):
+ if value is True:
+ return 1
+ elif value is False:
+ return 0
+ elif value is None:
+ return None
+ else:
+ return value and True or False
+ return process
+
+
+colspecs = {
+ sqltypes.Boolean: _FBBoolean,
+}
+
+ischema_names = {
+ 'SHORT': SMALLINT,
+ 'LONG': BIGINT,
+ 'QUAD': FLOAT,
+ 'FLOAT': FLOAT,
+ 'DATE': DATE,
+ 'TIME': TIME,
+ 'TEXT': TEXT,
+ 'INT64': NUMERIC,
+ 'DOUBLE': FLOAT,
+ 'TIMESTAMP': TIMESTAMP,
+ 'VARYING': VARCHAR,
+ 'CSTRING': CHAR,
+ 'BLOB': BLOB,
+ }
+
+
+# TODO: date conversion types (should be implemented as _FBDateTime, _FBDate, etc.
+# as bind/result functionality is required)
+
+class FBTypeCompiler(compiler.GenericTypeCompiler):
+ def visit_boolean(self, type_):
+ return self.visit_SMALLINT(type_)
+
+ def visit_datetime(self, type_):
+ return self.visit_TIMESTAMP(type_)
+
+ def visit_TEXT(self, type_):
+ return "BLOB SUB_TYPE 1"
+
+ def visit_BLOB(self, type_):
+ return "BLOB SUB_TYPE 0"
+
+
+class FBCompiler(sql.compiler.SQLCompiler):
+ """Firebird specific idiosincrasies"""
+
+ def visit_mod(self, binary, **kw):
+ # Firebird lacks a builtin modulo operator, but there is
+ # an equivalent function in the ib_udf library.
+ return "mod(%s, %s)" % (self.process(binary.left), self.process(binary.right))
+
+ def visit_alias(self, alias, asfrom=False, **kwargs):
+ if self.dialect._version_two:
+ return super(FBCompiler, self).visit_alias(alias, asfrom=asfrom, **kwargs)
+ else:
+ # Override to not use the AS keyword which FB 1.5 does not like
+ if asfrom:
+ alias_name = isinstance(alias.name, expression._generated_label) and \
+ self._truncated_identifier("alias", alias.name) or alias.name
+
+ return self.process(alias.original, asfrom=asfrom, **kwargs) + " " + \
+ self.preparer.format_alias(alias, alias_name)
+ else:
+ return self.process(alias.original, **kwargs)
+
+ def visit_substring_func(self, func, **kw):
+ s = self.process(func.clauses.clauses[0])
+ start = self.process(func.clauses.clauses[1])
+ if len(func.clauses.clauses) > 2:
+ length = self.process(func.clauses.clauses[2])
+ return "SUBSTRING(%s FROM %s FOR %s)" % (s, start, length)
+ else:
+ return "SUBSTRING(%s FROM %s)" % (s, start)
+
+ def visit_length_func(self, function, **kw):
+ if self.dialect._version_two:
+ return "char_length" + self.function_argspec(function)
+ else:
+ return "strlen" + self.function_argspec(function)
+
+ visit_char_length_func = visit_length_func
+
+ def function_argspec(self, func, **kw):
+ if func.clauses:
+ return self.process(func.clause_expr)
+ else:
+ return ""
+
+ def default_from(self):
+ return " FROM rdb$database"
+
+ def visit_sequence(self, seq):
+ return "gen_id(%s, 1)" % self.preparer.format_sequence(seq)
+
+ def get_select_precolumns(self, select):
+ """Called when building a ``SELECT`` statement, position is just
+ before column list Firebird puts the limit and offset right
+ after the ``SELECT``...
+ """
+
+ result = ""
+ if select._limit:
+ result += "FIRST %d " % select._limit
+ if select._offset:
+ result +="SKIP %d " % select._offset
+ if select._distinct:
+ result += "DISTINCT "
+ return result
+
+ def limit_clause(self, select):
+ """Already taken care of in the `get_select_precolumns` method."""
+
+ return ""
+
+ def returning_clause(self, stmt, returning_cols):
+
+ columns = [
+ self.process(
+ self.label_select_column(None, c, asfrom=False),
+ within_columns_clause=True,
+ result_map=self.result_map
+ )
+ for c in expression._select_iterables(returning_cols)
+ ]
+ return 'RETURNING ' + ', '.join(columns)
+
+
+class FBDDLCompiler(sql.compiler.DDLCompiler):
+ """Firebird syntactic idiosincrasies"""
+
+ def visit_create_sequence(self, create):
+ """Generate a ``CREATE GENERATOR`` statement for the sequence."""
+
+ if self.dialect._version_two:
+ return "CREATE SEQUENCE %s" % self.preparer.format_sequence(create.element)
+ else:
+ return "CREATE GENERATOR %s" % self.preparer.format_sequence(create.element)
+
+ def visit_drop_sequence(self, drop):
+ """Generate a ``DROP GENERATOR`` statement for the sequence."""
+
+ if self.dialect._version_two:
+ return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element)
+ else:
+ return "DROP GENERATOR %s" % self.preparer.format_sequence(drop.element)
+
+
+class FBDefaultRunner(base.DefaultRunner):
+ """Firebird specific idiosincrasies"""
+
+ def visit_sequence(self, seq):
+ """Get the next value from the sequence using ``gen_id()``."""
+
+ return self.execute_string("SELECT gen_id(%s, 1) FROM rdb$database" % \
+ self.dialect.identifier_preparer.format_sequence(seq))
+
+
+class FBIdentifierPreparer(sql.compiler.IdentifierPreparer):
+ """Install Firebird specific reserved words."""
+
+ reserved_words = RESERVED_WORDS
+
+ def __init__(self, dialect):
+ super(FBIdentifierPreparer, self).__init__(dialect, omit_schema=True)
+
+
+class FBDialect(default.DefaultDialect):
+ """Firebird dialect"""
+
+ name = 'firebird'
+
+ max_identifier_length = 31
+
+ supports_sequences = True
+ sequences_optional = False
+ supports_default_values = True
+ postfetch_lastrowid = False
+
+ requires_name_normalize = True
+ supports_empty_insert = False
+
+ statement_compiler = FBCompiler
+ ddl_compiler = FBDDLCompiler
+ defaultrunner = FBDefaultRunner
+ preparer = FBIdentifierPreparer
+ type_compiler = FBTypeCompiler
+
+ colspecs = colspecs
+ ischema_names = ischema_names
+
+ # defaults to dialect ver. 3,
+ # will be autodetected off upon
+ # first connect
+ _version_two = True
+
+ def initialize(self, connection):
+ super(FBDialect, self).initialize(connection)
+ self._version_two = self.server_version_info > (2, )
+ if not self._version_two:
+ # TODO: whatever other pre < 2.0 stuff goes here
+ self.ischema_names = ischema_names.copy()
+ self.ischema_names['TIMESTAMP'] = sqltypes.DATE
+ self.colspecs = {
+ sqltypes.DateTime: sqltypes.DATE
+ }
+ else:
+ self.implicit_returning = True
+
+ def normalize_name(self, name):
+ # Remove trailing spaces: FB uses a CHAR() type,
+ # that is padded with spaces
+ name = name and name.rstrip()
+ if name is None:
+ return None
+ elif name.upper() == name and \
+ not self.identifier_preparer._requires_quotes(name.lower()):
+ return name.lower()
+ else:
+ return name
+
+ def denormalize_name(self, name):
+ if name is None:
+ return None
+ elif name.lower() == name and \
+ not self.identifier_preparer._requires_quotes(name.lower()):
+ return name.upper()
+ else:
+ return name
+
+ def has_table(self, connection, table_name, schema=None):
+ """Return ``True`` if the given table exists, ignoring the `schema`."""
+
+ tblqry = """
+ SELECT 1 FROM rdb$database
+ WHERE EXISTS (SELECT rdb$relation_name
+ FROM rdb$relations
+ WHERE rdb$relation_name=?)
+ """
+ c = connection.execute(tblqry, [self.denormalize_name(table_name)])
+ return c.first() is not None
+
+ def has_sequence(self, connection, sequence_name):
+ """Return ``True`` if the given sequence (generator) exists."""
+
+ genqry = """
+ SELECT 1 FROM rdb$database
+ WHERE EXISTS (SELECT rdb$generator_name
+ FROM rdb$generators
+ WHERE rdb$generator_name=?)
+ """
+ c = connection.execute(genqry, [self.denormalize_name(sequence_name)])
+ return c.first() is not None
+
+ def table_names(self, connection, schema):
+ s = """
+ SELECT DISTINCT rdb$relation_name
+ FROM rdb$relation_fields
+ WHERE rdb$system_flag=0 AND rdb$view_context IS NULL
+ """
+ return [self.normalize_name(row[0]) for row in connection.execute(s)]
+
+ @reflection.cache
+ def get_table_names(self, connection, schema=None, **kw):
+ return self.table_names(connection, schema)
+
+ @reflection.cache
+ def get_view_names(self, connection, schema=None, **kw):
+ s = """
+ SELECT distinct rdb$view_name
+ FROM rdb$view_relations
+ """
+ return [self.normalize_name(row[0]) for row in connection.execute(s)]
+
+ @reflection.cache
+ def get_view_definition(self, connection, view_name, schema=None, **kw):
+ qry = """
+ SELECT rdb$view_source AS view_source
+ FROM rdb$relations
+ WHERE rdb$relation_name=?
+ """
+ rp = connection.execute(qry, [self.denormalize_name(view_name)])
+ row = rp.first()
+ if row:
+ return row['view_source']
+ else:
+ return None
+
+ @reflection.cache
+ def get_primary_keys(self, connection, table_name, schema=None, **kw):
+ # Query to extract the PK/FK constrained fields of the given table
+ keyqry = """
+ SELECT se.rdb$field_name AS fname
+ FROM rdb$relation_constraints rc
+ JOIN rdb$index_segments se ON rc.rdb$index_name=se.rdb$index_name
+ WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=?
+ """
+ tablename = self.denormalize_name(table_name)
+ # get primary key fields
+ c = connection.execute(keyqry, ["PRIMARY KEY", tablename])
+ pkfields = [self.normalize_name(r['fname']) for r in c.fetchall()]
+ return pkfields
+
+ @reflection.cache
+ def get_column_sequence(self, connection, table_name, column_name, schema=None, **kw):
+ tablename = self.denormalize_name(table_name)
+ colname = self.denormalize_name(column_name)
+ # Heuristic-query to determine the generator associated to a PK field
+ genqry = """
+ SELECT trigdep.rdb$depended_on_name AS fgenerator
+ FROM rdb$dependencies tabdep
+ JOIN rdb$dependencies trigdep
+ ON tabdep.rdb$dependent_name=trigdep.rdb$dependent_name
+ AND trigdep.rdb$depended_on_type=14
+ AND trigdep.rdb$dependent_type=2
+ JOIN rdb$triggers trig ON trig.rdb$trigger_name=tabdep.rdb$dependent_name
+ WHERE tabdep.rdb$depended_on_name=?
+ AND tabdep.rdb$depended_on_type=0
+ AND trig.rdb$trigger_type=1
+ AND tabdep.rdb$field_name=?
+ AND (SELECT count(*)
+ FROM rdb$dependencies trigdep2
+ WHERE trigdep2.rdb$dependent_name = trigdep.rdb$dependent_name) = 2
+ """
+ genc = connection.execute(genqry, [tablename, colname])
+ genr = genc.fetchone()
+ if genr is not None:
+ return dict(name=self.normalize_name(genr['fgenerator']))
+
+ @reflection.cache
+ def get_columns(self, connection, table_name, schema=None, **kw):
+ # Query to extract the details of all the fields of the given table
+ tblqry = """
+ SELECT DISTINCT r.rdb$field_name AS fname,
+ r.rdb$null_flag AS null_flag,
+ t.rdb$type_name AS ftype,
+ f.rdb$field_sub_type AS stype,
+ f.rdb$field_length AS flen,
+ f.rdb$field_precision AS fprec,
+ f.rdb$field_scale AS fscale,
+ COALESCE(r.rdb$default_source, f.rdb$default_source) AS fdefault
+ FROM rdb$relation_fields r
+ JOIN rdb$fields f ON r.rdb$field_source=f.rdb$field_name
+ JOIN rdb$types t
+ ON t.rdb$type=f.rdb$field_type AND t.rdb$field_name='RDB$FIELD_TYPE'
+ WHERE f.rdb$system_flag=0 AND r.rdb$relation_name=?
+ ORDER BY r.rdb$field_position
+ """
+ # get the PK, used to determine the eventual associated sequence
+ pkey_cols = self.get_primary_keys(connection, table_name)
+
+ tablename = self.denormalize_name(table_name)
+ # get all of the fields for this table
+ c = connection.execute(tblqry, [tablename])
+ cols = []
+ while True:
+ row = c.fetchone()
+ if row is None:
+ break
+ name = self.normalize_name(row['fname'])
+ # get the data type
+
+ colspec = row['ftype'].rstrip()
+ coltype = self.ischema_names.get(colspec)
+ if coltype is None:
+ util.warn("Did not recognize type '%s' of column '%s'" %
+ (colspec, name))
+ coltype = sqltypes.NULLTYPE
+ elif colspec == 'INT64':
+ coltype = coltype(precision=row['fprec'], scale=row['fscale'] * -1)
+ elif colspec in ('VARYING', 'CSTRING'):
+ coltype = coltype(row['flen'])
+ elif colspec == 'TEXT':
+ coltype = TEXT(row['flen'])
+ elif colspec == 'BLOB':
+ if row['stype'] == 1:
+ coltype = TEXT()
+ else:
+ coltype = BLOB()
+ else:
+ coltype = coltype(row)
+
+ # does it have a default value?
+ defvalue = None
+ if row['fdefault'] is not None:
+ # the value comes down as "DEFAULT 'value'"
+ assert row['fdefault'].upper().startswith('DEFAULT '), row
+ defvalue = row['fdefault'][8:]
+ col_d = {
+ 'name' : name,
+ 'type' : coltype,
+ 'nullable' : not bool(row['null_flag']),
+ 'default' : defvalue
+ }
+
+ # if the PK is a single field, try to see if its linked to
+ # a sequence thru a trigger
+ if len(pkey_cols)==1 and name==pkey_cols[0]:
+ seq_d = self.get_column_sequence(connection, tablename, name)
+ if seq_d is not None:
+ col_d['sequence'] = seq_d
+
+ cols.append(col_d)
+ return cols
+
+ @reflection.cache
+ def get_foreign_keys(self, connection, table_name, schema=None, **kw):
+ # Query to extract the details of each UK/FK of the given table
+ fkqry = """
+ SELECT rc.rdb$constraint_name AS cname,
+ cse.rdb$field_name AS fname,
+ ix2.rdb$relation_name AS targetrname,
+ se.rdb$field_name AS targetfname
+ FROM rdb$relation_constraints rc
+ JOIN rdb$indices ix1 ON ix1.rdb$index_name=rc.rdb$index_name
+ JOIN rdb$indices ix2 ON ix2.rdb$index_name=ix1.rdb$foreign_key
+ JOIN rdb$index_segments cse ON cse.rdb$index_name=ix1.rdb$index_name
+ JOIN rdb$index_segments se
+ ON se.rdb$index_name=ix2.rdb$index_name
+ AND se.rdb$field_position=cse.rdb$field_position
+ WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=?
+ ORDER BY se.rdb$index_name, se.rdb$field_position
+ """
+ tablename = self.denormalize_name(table_name)
+
+ c = connection.execute(fkqry, ["FOREIGN KEY", tablename])
+ fks = util.defaultdict(lambda:{
+ 'name' : None,
+ 'constrained_columns' : [],
+ 'referred_schema' : None,
+ 'referred_table' : None,
+ 'referred_columns' : []
+ })
+
+ for row in c:
+ cname = self.normalize_name(row['cname'])
+ fk = fks[cname]
+ if not fk['name']:
+ fk['name'] = cname
+ fk['referred_table'] = self.normalize_name(row['targetrname'])
+ fk['constrained_columns'].append(self.normalize_name(row['fname']))
+ fk['referred_columns'].append(
+ self.normalize_name(row['targetfname']))
+ return fks.values()
+
+ @reflection.cache
+ def get_indexes(self, connection, table_name, schema=None, **kw):
+ qry = """
+ SELECT ix.rdb$index_name AS index_name,
+ ix.rdb$unique_flag AS unique_flag,
+ ic.rdb$field_name AS field_name
+ FROM rdb$indices ix
+ JOIN rdb$index_segments ic
+ ON ix.rdb$index_name=ic.rdb$index_name
+ LEFT OUTER JOIN rdb$relation_constraints
+ ON rdb$relation_constraints.rdb$index_name = ic.rdb$index_name
+ WHERE ix.rdb$relation_name=? AND ix.rdb$foreign_key IS NULL
+ AND rdb$relation_constraints.rdb$constraint_type IS NULL
+ ORDER BY index_name, field_name
+ """
+ c = connection.execute(qry, [self.denormalize_name(table_name)])
+
+ indexes = util.defaultdict(dict)
+ for row in c:
+ indexrec = indexes[row['index_name']]
+ if 'name' not in indexrec:
+ indexrec['name'] = self.normalize_name(row['index_name'])
+ indexrec['column_names'] = []
+ indexrec['unique'] = bool(row['unique_flag'])
+
+ indexrec['column_names'].append(self.normalize_name(row['field_name']))
+
+ return indexes.values()
+
+ def do_execute(self, cursor, statement, parameters, **kwargs):
+ # kinterbase does not accept a None, but wants an empty list
+ # when there are no arguments.
+ cursor.execute(statement, parameters or [])
+
+ def do_rollback(self, connection):
+ # Use the retaining feature, that keeps the transaction going
+ connection.rollback(True)
+
+ def do_commit(self, connection):
+ # Use the retaining feature, that keeps the transaction going
+ connection.commit(True)
--- /dev/null
+from sqlalchemy.dialects.firebird.base import FBDialect, FBCompiler
+from sqlalchemy.engine.default import DefaultExecutionContext
+
+_initialized_kb = False
+
+class Firebird_kinterbasdb(FBDialect):
+ driver = 'kinterbasdb'
+ supports_sane_rowcount = False
+ supports_sane_multi_rowcount = False
+
+ def __init__(self, type_conv=200, concurrency_level=1, **kwargs):
+ super(Firebird_kinterbasdb, self).__init__(**kwargs)
+
+ self.type_conv = type_conv
+ self.concurrency_level = concurrency_level
+
+ @classmethod
+ def dbapi(cls):
+ k = __import__('kinterbasdb')
+ return k
+
+ def create_connect_args(self, url):
+ opts = url.translate_connect_args(username='user')
+ if opts.get('port'):
+ opts['host'] = "%s/%s" % (opts['host'], opts['port'])
+ del opts['port']
+ opts.update(url.query)
+
+ type_conv = opts.pop('type_conv', self.type_conv)
+ concurrency_level = opts.pop('concurrency_level', self.concurrency_level)
+ global _initialized_kb
+ if not _initialized_kb and self.dbapi is not None:
+ _initialized_kb = True
+ self.dbapi.init(type_conv=type_conv, concurrency_level=concurrency_level)
+ return ([], opts)
+
+ def _get_server_version_info(self, connection):
+ """Get the version of the Firebird server used by a connection.
+
+ Returns a tuple of (`major`, `minor`, `build`), three integers
+ representing the version of the attached server.
+ """
+
+ # This is the simpler approach (the other uses the services api),
+ # that for backward compatibility reasons returns a string like
+ # LI-V6.3.3.12981 Firebird 2.0
+ # where the first version is a fake one resembling the old
+ # Interbase signature. This is more than enough for our purposes,
+ # as this is mainly (only?) used by the testsuite.
+
+ from re import match
+
+ fbconn = connection.connection
+ version = fbconn.server_version
+ m = match('\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+) \w+ (\d+)\.(\d+)', version)
+ if not m:
+ raise AssertionError("Could not determine version from string '%s'" % version)
+ return tuple([int(x) for x in m.group(5, 6, 4)])
+
+ def is_disconnect(self, e):
+ if isinstance(e, self.dbapi.OperationalError):
+ return 'Unable to complete network request to host' in str(e)
+ elif isinstance(e, self.dbapi.ProgrammingError):
+ msg = str(e)
+ return ('Invalid connection state' in msg or
+ 'Invalid cursor state' in msg)
+ else:
+ return False
+
+dialect = Firebird_kinterbasdb
--- /dev/null
+from sqlalchemy.dialects.informix import base, informixdb
+
+base.dialect = informixdb.dialect
\ No newline at end of file
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
+"""Support for the Informix database.
+
+This dialect is *not* tested on SQLAlchemy 0.6.
+
+"""
+
import datetime
from sqlalchemy import types as sqltypes
-# for offset
-
-class informix_cursor(object):
- def __init__( self , con ):
- self.__cursor = con.cursor()
- self.rowcount = 0
-
- def offset( self , n ):
- if n > 0:
- self.fetchmany( n )
- self.rowcount = self.__cursor.rowcount - n
- if self.rowcount < 0:
- self.rowcount = 0
- else:
- self.rowcount = self.__cursor.rowcount
-
- def execute( self , sql , params ):
- if params is None or len( params ) == 0:
- params = []
-
- return self.__cursor.execute( sql , params )
-
- def __getattr__( self , name ):
- if name not in ( 'offset' , '__cursor' , 'rowcount' , '__del__' , 'execute' ):
- return getattr( self.__cursor , name )
-
-class InfoNumeric(sqltypes.Numeric):
- def get_col_spec(self):
- if not self.precision:
- return 'NUMERIC'
- else:
- return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale}
-
-class InfoInteger(sqltypes.Integer):
- def get_col_spec(self):
- return "INTEGER"
-
-class InfoSmallInteger(sqltypes.Smallinteger):
- def get_col_spec(self):
- return "SMALLINT"
-
-class InfoDate(sqltypes.Date):
- def get_col_spec( self ):
- return "DATE"
-
class InfoDateTime(sqltypes.DateTime ):
- def get_col_spec(self):
- return "DATETIME YEAR TO SECOND"
-
def bind_processor(self, dialect):
def process(value):
if value is not None:
return process
class InfoTime(sqltypes.Time ):
- def get_col_spec(self):
- return "DATETIME HOUR TO SECOND"
-
def bind_processor(self, dialect):
def process(value):
if value is not None:
return value
return process
-class InfoText(sqltypes.String):
- def get_col_spec(self):
- return "VARCHAR(255)"
-
-class InfoString(sqltypes.String):
- def get_col_spec(self):
- return "VARCHAR(%(length)s)" % {'length' : self.length}
-
- def bind_processor(self, dialect):
- def process(value):
- if value == '':
- return None
- else:
- return value
- return process
-
-class InfoChar(sqltypes.CHAR):
- def get_col_spec(self):
- return "CHAR(%(length)s)" % {'length' : self.length}
-
-class InfoBinary(sqltypes.Binary):
- def get_col_spec(self):
- return "BYTE"
class InfoBoolean(sqltypes.Boolean):
- default_type = 'NUM'
- def get_col_spec(self):
- return "SMALLINT"
-
def result_processor(self, dialect):
def process(value):
if value is None:
return process
colspecs = {
- sqltypes.Integer : InfoInteger,
- sqltypes.Smallinteger : InfoSmallInteger,
- sqltypes.Numeric : InfoNumeric,
- sqltypes.Float : InfoNumeric,
sqltypes.DateTime : InfoDateTime,
- sqltypes.Date : InfoDate,
sqltypes.Time: InfoTime,
- sqltypes.String : InfoString,
- sqltypes.Binary : InfoBinary,
sqltypes.Boolean : InfoBoolean,
- sqltypes.Text : InfoText,
- sqltypes.CHAR: InfoChar,
}
ischema_names = {
- 0 : InfoString, # CHAR
- 1 : InfoSmallInteger, # SMALLINT
- 2 : InfoInteger, # INT
- 3 : InfoNumeric, # Float
- 3 : InfoNumeric, # SmallFloat
- 5 : InfoNumeric, # DECIMAL
- 6 : InfoInteger, # Serial
- 7 : InfoDate, # DATE
- 8 : InfoNumeric, # MONEY
- 10 : InfoDateTime, # DATETIME
- 11 : InfoBinary, # BYTE
- 12 : InfoText, # TEXT
- 13 : InfoString, # VARCHAR
- 15 : InfoString, # NCHAR
- 16 : InfoString, # NVARCHAR
- 17 : InfoInteger, # INT8
- 18 : InfoInteger, # Serial8
- 43 : InfoString, # LVARCHAR
- -1 : InfoBinary, # BLOB
- -1 : InfoText, # CLOB
+ 0 : sqltypes.CHAR, # CHAR
+ 1 : sqltypes.SMALLINT, # SMALLINT
+ 2 : sqltypes.INTEGER, # INT
+ 3 : sqltypes.FLOAT, # Float
+ 3 : sqltypes.Float, # SmallFloat
+ 5 : sqltypes.DECIMAL, # DECIMAL
+ 6 : sqltypes.Integer, # Serial
+ 7 : sqltypes.DATE, # DATE
+ 8 : sqltypes.Numeric, # MONEY
+ 10 : sqltypes.DATETIME, # DATETIME
+ 11 : sqltypes.Binary, # BYTE
+ 12 : sqltypes.TEXT, # TEXT
+ 13 : sqltypes.VARCHAR, # VARCHAR
+ 15 : sqltypes.NCHAR, # NCHAR
+ 16 : sqltypes.NVARCHAR, # NVARCHAR
+ 17 : sqltypes.Integer, # INT8
+ 18 : sqltypes.Integer, # Serial8
+ 43 : sqltypes.String, # LVARCHAR
+ -1 : sqltypes.BLOB, # BLOB
+ -1 : sqltypes.CLOB, # CLOB
}
-class InfoExecutionContext(default.DefaultExecutionContext):
- # cursor.sqlerrd
- # 0 - estimated number of rows returned
- # 1 - serial value after insert or ISAM error code
- # 2 - number of rows processed
- # 3 - estimated cost
- # 4 - offset of the error into the SQL statement
- # 5 - rowid after insert
- def post_exec(self):
- if getattr(self.compiled, "isinsert", False) and self.last_inserted_ids() is None:
- self._last_inserted_ids = [self.cursor.sqlerrd[1]]
- elif hasattr( self.compiled , 'offset' ):
- self.cursor.offset( self.compiled.offset )
- super(InfoExecutionContext, self).post_exec()
-
- def create_cursor( self ):
- return informix_cursor( self.connection.connection )
-
-class InfoDialect(default.DefaultDialect):
- name = 'informix'
- default_paramstyle = 'qmark'
- # for informix 7.31
- max_identifier_length = 18
+class InfoTypeCompiler(compiler.GenericTypeCompiler):
+ def visit_DATETIME(self, type_):
+ return "DATETIME YEAR TO SECOND"
+
+ def visit_TIME(self, type_):
+ return "DATETIME HOUR TO SECOND"
+
+ def visit_binary(self, type_):
+ return "BYTE"
+
+ def visit_boolean(self, type_):
+ return "SMALLINT"
+
+class InfoSQLCompiler(compiler.SQLCompiler):
- def __init__(self, use_ansi=True, **kwargs):
- self.use_ansi = use_ansi
- default.DefaultDialect.__init__(self, **kwargs)
+ def __init__(self, *args, **kwargs):
+ self.limit = 0
+ self.offset = 0
- def dbapi(cls):
- import informixdb
- return informixdb
- dbapi = classmethod(dbapi)
+ compiler.SQLCompiler.__init__( self , *args, **kwargs )
+
+ def default_from(self):
+ return " from systables where tabname = 'systables' "
- def is_disconnect(self, e):
- if isinstance(e, self.dbapi.OperationalError):
- return 'closed the connection' in str(e) or 'connection not open' in str(e)
+ def get_select_precolumns( self , select ):
+ s = select._distinct and "DISTINCT " or ""
+ # only has limit
+ if select._limit:
+ off = select._offset or 0
+ s += " FIRST %s " % ( select._limit + off )
else:
- return False
+ s += ""
+ return s
- def do_begin(self , connect ):
- cu = connect.cursor()
- cu.execute( 'SET LOCK MODE TO WAIT' )
- #cu.execute( 'SET ISOLATION TO REPEATABLE READ' )
+ def visit_select(self, select):
+ if select._offset:
+ self.offset = select._offset
+ self.limit = select._limit or 0
+ # the column in order by clause must in select too
+
+ def __label( c ):
+ try:
+ return c._label.lower()
+ except:
+ return ''
- def type_descriptor(self, typeobj):
- return sqltypes.adapt_type(typeobj, colspecs)
+ # TODO: dont modify the original select, generate a new one
+ a = [ __label(c) for c in select._raw_columns ]
+ for c in select._order_by_clause.clauses:
+ if ( __label(c) not in a ):
+ select.append_column( c )
+
+ return compiler.SQLCompiler.visit_select(self, select)
+
+ def limit_clause(self, select):
+ return ""
- def create_connect_args(self, url):
- if url.host:
- dsn = '%s@%s' % ( url.database , url.host )
+ def visit_function( self , func ):
+ if func.name.lower() == 'current_date':
+ return "today"
+ elif func.name.lower() == 'current_time':
+ return "CURRENT HOUR TO SECOND"
+ elif func.name.lower() in ( 'current_timestamp' , 'now' ):
+ return "CURRENT YEAR TO SECOND"
else:
- dsn = url.database
+ return compiler.SQLCompiler.visit_function( self , func )
- if url.username:
- opt = { 'user':url.username , 'password': url.password }
+ def visit_clauselist(self, list, **kwargs):
+ return ', '.join([s for s in [self.process(c) for c in list.clauses] if s is not None])
+
+class InfoDDLCompiler(compiler.DDLCompiler):
+ def get_column_specification(self, column, first_pk=False):
+ colspec = self.preparer.format_column(column)
+ if column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and \
+ isinstance(column.type, sqltypes.Integer) and not getattr( self , 'has_serial' , False ) and first_pk:
+ colspec += " SERIAL"
+ self.has_serial = True
else:
- opt = {}
+ colspec += " " + self.dialect.type_compiler.process(column.type)
+ default = self.get_column_default_string(column)
+ if default is not None:
+ colspec += " DEFAULT " + default
+
+ if not column.nullable:
+ colspec += " NOT NULL"
+
+ return colspec
+
+ def post_create_table(self, table):
+ if hasattr( self , 'has_serial' ):
+ del self.has_serial
+ return ''
+
+class InfoIdentifierPreparer(compiler.IdentifierPreparer):
+ def __init__(self, dialect):
+ super(InfoIdentifierPreparer, self).__init__(dialect, initial_quote="'")
+
+ def format_constraint(self, constraint):
+ # informix doesnt support names for constraints
+ return ''
+
+ def _requires_quotes(self, value):
+ return False
+
+class InformixDialect(default.DefaultDialect):
+ name = 'informix'
+ # for informix 7.31
+ max_identifier_length = 18
+ type_compiler = InfoTypeCompiler
+ poolclass = pool.SingletonThreadPool
+ statement_compiler = InfoSQLCompiler
+ ddl_compiler = InfoDDLCompiler
+ preparer = InfoIdentifierPreparer
+ colspecs = colspecs
+ ischema_names = ischema_names
- return ([dsn], opt)
+ def do_begin(self , connect ):
+ cu = connect.cursor()
+ cu.execute( 'SET LOCK MODE TO WAIT' )
+ #cu.execute( 'SET ISOLATION TO REPEATABLE READ' )
def table_names(self, connection, schema):
s = "select tabname from systables"
for cons_name, cons_type, local_column in rows:
table.primary_key.add( table.c[local_column] )
-class InfoCompiler(compiler.DefaultCompiler):
- """Info compiler modifies the lexical structure of Select statements to work under
- non-ANSI configured Oracle databases, if the use_ansi flag is False."""
-
- def __init__(self, *args, **kwargs):
- self.limit = 0
- self.offset = 0
-
- compiler.DefaultCompiler.__init__( self , *args, **kwargs )
-
- def default_from(self):
- return " from systables where tabname = 'systables' "
-
- def get_select_precolumns( self , select ):
- s = select._distinct and "DISTINCT " or ""
- # only has limit
- if select._limit:
- off = select._offset or 0
- s += " FIRST %s " % ( select._limit + off )
- else:
- s += ""
- return s
-
- def visit_select(self, select):
- if select._offset:
- self.offset = select._offset
- self.limit = select._limit or 0
- # the column in order by clause must in select too
-
- def __label( c ):
- try:
- return c._label.lower()
- except:
- return ''
-
- # TODO: dont modify the original select, generate a new one
- a = [ __label(c) for c in select._raw_columns ]
- for c in select._order_by_clause.clauses:
- if ( __label(c) not in a ):
- select.append_column( c )
-
- return compiler.DefaultCompiler.visit_select(self, select)
-
- def limit_clause(self, select):
- return ""
-
- def visit_function( self , func ):
- if func.name.lower() == 'current_date':
- return "today"
- elif func.name.lower() == 'current_time':
- return "CURRENT HOUR TO SECOND"
- elif func.name.lower() in ( 'current_timestamp' , 'now' ):
- return "CURRENT YEAR TO SECOND"
- else:
- return compiler.DefaultCompiler.visit_function( self , func )
-
- def visit_clauselist(self, list, **kwargs):
- return ', '.join([s for s in [self.process(c) for c in list.clauses] if s is not None])
-
-class InfoSchemaGenerator(compiler.SchemaGenerator):
- def get_column_specification(self, column, first_pk=False):
- colspec = self.preparer.format_column(column)
- if column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and \
- isinstance(column.type, sqltypes.Integer) and not getattr( self , 'has_serial' , False ) and first_pk:
- colspec += " SERIAL"
- self.has_serial = True
- else:
- colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
- default = self.get_column_default_string(column)
- if default is not None:
- colspec += " DEFAULT " + default
-
- if not column.nullable:
- colspec += " NOT NULL"
-
- return colspec
-
- def post_create_table(self, table):
- if hasattr( self , 'has_serial' ):
- del self.has_serial
- return ''
-
- def visit_primary_key_constraint(self, constraint):
- # for informix 7.31 not support constraint name
- name = constraint.name
- constraint.name = None
- super(InfoSchemaGenerator, self).visit_primary_key_constraint(constraint)
- constraint.name = name
-
- def visit_unique_constraint(self, constraint):
- # for informix 7.31 not support constraint name
- name = constraint.name
- constraint.name = None
- super(InfoSchemaGenerator, self).visit_unique_constraint(constraint)
- constraint.name = name
-
- def visit_foreign_key_constraint( self , constraint ):
- if constraint.name is not None:
- constraint.use_alter = True
- else:
- super( InfoSchemaGenerator , self ).visit_foreign_key_constraint( constraint )
-
- def define_foreign_key(self, constraint):
- # for informix 7.31 not support constraint name
- if constraint.use_alter:
- name = constraint.name
- constraint.name = None
- self.append( "CONSTRAINT " )
- super(InfoSchemaGenerator, self).define_foreign_key(constraint)
- constraint.name = name
- if name is not None:
- self.append( " CONSTRAINT " + name )
- else:
- super(InfoSchemaGenerator, self).define_foreign_key(constraint)
-
- def visit_index(self, index):
- if len( index.columns ) == 1 and index.columns[0].foreign_key:
- return
- super(InfoSchemaGenerator, self).visit_index(index)
-
-class InfoIdentifierPreparer(compiler.IdentifierPreparer):
- def __init__(self, dialect):
- super(InfoIdentifierPreparer, self).__init__(dialect, initial_quote="'")
-
- def _requires_quotes(self, value):
- return False
-
-class InfoSchemaDropper(compiler.SchemaDropper):
- def drop_foreignkey(self, constraint):
- if constraint.name is not None:
- super( InfoSchemaDropper , self ).drop_foreignkey( constraint )
-
-dialect = InfoDialect
-poolclass = pool.SingletonThreadPool
-dialect.statement_compiler = InfoCompiler
-dialect.schemagenerator = InfoSchemaGenerator
-dialect.schemadropper = InfoSchemaDropper
-dialect.preparer = InfoIdentifierPreparer
-dialect.execution_ctx_cls = InfoExecutionContext
\ No newline at end of file
--- /dev/null
+from sqlalchemy.dialects.informix.base import InformixDialect
+from sqlalchemy.engine import default
+
+# for offset
+
+class informix_cursor(object):
+ def __init__( self , con ):
+ self.__cursor = con.cursor()
+ self.rowcount = 0
+
+ def offset( self , n ):
+ if n > 0:
+ self.fetchmany( n )
+ self.rowcount = self.__cursor.rowcount - n
+ if self.rowcount < 0:
+ self.rowcount = 0
+ else:
+ self.rowcount = self.__cursor.rowcount
+
+ def execute( self , sql , params ):
+ if params is None or len( params ) == 0:
+ params = []
+
+ return self.__cursor.execute( sql , params )
+
+ def __getattr__( self , name ):
+ if name not in ( 'offset' , '__cursor' , 'rowcount' , '__del__' , 'execute' ):
+ return getattr( self.__cursor , name )
+
+
+class InfoExecutionContext(default.DefaultExecutionContext):
+ # cursor.sqlerrd
+ # 0 - estimated number of rows returned
+ # 1 - serial value after insert or ISAM error code
+ # 2 - number of rows processed
+ # 3 - estimated cost
+ # 4 - offset of the error into the SQL statement
+ # 5 - rowid after insert
+ def post_exec(self):
+ if getattr(self.compiled, "isinsert", False) and self.inserted_primary_key is None:
+ self._last_inserted_ids = [self.cursor.sqlerrd[1]]
+ elif hasattr( self.compiled , 'offset' ):
+ self.cursor.offset( self.compiled.offset )
+
+ def create_cursor( self ):
+ return informix_cursor( self.connection.connection )
+
+
+class Informix_informixdb(InformixDialect):
+ driver = 'informixdb'
+ default_paramstyle = 'qmark'
+ execution_context_cls = InfoExecutionContext
+
+ @classmethod
+ def dbapi(cls):
+ return __import__('informixdb')
+
+ def create_connect_args(self, url):
+ if url.host:
+ dsn = '%s@%s' % ( url.database , url.host )
+ else:
+ dsn = url.database
+
+ if url.username:
+ opt = { 'user':url.username , 'password': url.password }
+ else:
+ opt = {}
+
+ return ([dsn], opt)
+
+
+ def is_disconnect(self, e):
+ if isinstance(e, self.dbapi.OperationalError):
+ return 'closed the connection' in str(e) or 'connection not open' in str(e)
+ else:
+ return False
+
+
+dialect = Informix_informixdb
\ No newline at end of file
--- /dev/null
+from sqlalchemy.dialects.maxdb import base, sapdb
+
+base.dialect = sapdb.dialect
\ No newline at end of file
"""Support for the MaxDB database.
-TODO: More module docs! MaxDB support is currently experimental.
+This dialect is *not* tested on SQLAlchemy 0.6.
Overview
--------
from sqlalchemy import types as sqltypes
-__all__ = [
- 'MaxString', 'MaxUnicode', 'MaxChar', 'MaxText', 'MaxInteger',
- 'MaxSmallInteger', 'MaxNumeric', 'MaxFloat', 'MaxTimestamp',
- 'MaxDate', 'MaxTime', 'MaxBoolean', 'MaxBlob',
- ]
-
-
class _StringType(sqltypes.String):
_type = None
super(_StringType, self).__init__(length=length, **kw)
self.encoding = encoding
- def get_col_spec(self):
- if self.length is None:
- spec = 'LONG'
- else:
- spec = '%s(%s)' % (self._type, self.length)
-
- if self.encoding is not None:
- spec = ' '.join([spec, self.encoding.upper()])
- return spec
-
def bind_processor(self, dialect):
if self.encoding == 'unicode':
return None
return spec
-class MaxInteger(sqltypes.Integer):
- def get_col_spec(self):
- return 'INTEGER'
-
-
-class MaxSmallInteger(MaxInteger):
- def get_col_spec(self):
- return 'SMALLINT'
-
-
class MaxNumeric(sqltypes.Numeric):
"""The FIXED (also NUMERIC, DECIMAL) data type."""
def bind_processor(self, dialect):
return None
- def get_col_spec(self):
- if self.scale and self.precision:
- return 'FIXED(%s, %s)' % (self.precision, self.scale)
- elif self.precision:
- return 'FIXED(%s)' % self.precision
- else:
- return 'INTEGER'
-
-
-class MaxFloat(sqltypes.Float):
- """The FLOAT data type."""
-
- def get_col_spec(self):
- if self.precision is None:
- return 'FLOAT'
- else:
- return 'FLOAT(%s)' % (self.precision,)
-
-
class MaxTimestamp(sqltypes.DateTime):
- def get_col_spec(self):
- return 'TIMESTAMP'
-
def bind_processor(self, dialect):
def process(value):
if value is None:
class MaxDate(sqltypes.Date):
- def get_col_spec(self):
- return 'DATE'
-
def bind_processor(self, dialect):
def process(value):
if value is None:
class MaxTime(sqltypes.Time):
- def get_col_spec(self):
- return 'TIME'
-
def bind_processor(self, dialect):
def process(value):
if value is None:
return process
-class MaxBoolean(sqltypes.Boolean):
- def get_col_spec(self):
- return 'BOOLEAN'
-
-
class MaxBlob(sqltypes.Binary):
- def get_col_spec(self):
- return 'LONG BYTE'
-
def bind_processor(self, dialect):
def process(value):
if value is None:
return value.read(value.remainingLength())
return process
+class MaxDBTypeCompiler(compiler.GenericTypeCompiler):
+ def _string_spec(self, string_spec, type_):
+ if type_.length is None:
+ spec = 'LONG'
+ else:
+ spec = '%s(%s)' % (string_spec, type_.length)
+
+ if getattr(type_, 'encoding'):
+ spec = ' '.join([spec, getattr(type_, 'encoding').upper()])
+ return spec
+
+ def visit_text(self, type_):
+ spec = 'LONG'
+ if getattr(type_, 'encoding', None):
+ spec = ' '.join((spec, type_.encoding))
+ elif type_.convert_unicode:
+ spec = ' '.join((spec, 'UNICODE'))
+
+ return spec
+ def visit_char(self, type_):
+ return self._string_spec("CHAR", type_)
+
+ def visit_string(self, type_):
+ return self._string_spec("VARCHAR", type_)
+
+ def visit_binary(self, type_):
+ return "LONG BYTE"
+
+ def visit_numeric(self, type_):
+ if type_.scale and type_.precision:
+ return 'FIXED(%s, %s)' % (type_.precision, type_.scale)
+ elif type_.precision:
+ return 'FIXED(%s)' % type_.precision
+ else:
+ return 'INTEGER'
+
+ def visit_BOOLEAN(self, type_):
+ return "BOOLEAN"
+
colspecs = {
- sqltypes.Integer: MaxInteger,
- sqltypes.Smallinteger: MaxSmallInteger,
sqltypes.Numeric: MaxNumeric,
- sqltypes.Float: MaxFloat,
sqltypes.DateTime: MaxTimestamp,
sqltypes.Date: MaxDate,
sqltypes.Time: MaxTime,
sqltypes.String: MaxString,
+ sqltypes.Unicode:MaxUnicode,
sqltypes.Binary: MaxBlob,
- sqltypes.Boolean: MaxBoolean,
sqltypes.Text: MaxText,
sqltypes.CHAR: MaxChar,
sqltypes.TIMESTAMP: MaxTimestamp,
}
ischema_names = {
- 'boolean': MaxBoolean,
- 'char': MaxChar,
- 'character': MaxChar,
- 'date': MaxDate,
- 'fixed': MaxNumeric,
- 'float': MaxFloat,
- 'int': MaxInteger,
- 'integer': MaxInteger,
- 'long binary': MaxBlob,
- 'long unicode': MaxText,
- 'long': MaxText,
- 'long': MaxText,
- 'smallint': MaxSmallInteger,
- 'time': MaxTime,
- 'timestamp': MaxTimestamp,
- 'varchar': MaxString,
+ 'boolean': sqltypes.BOOLEAN,
+ 'char': sqltypes.CHAR,
+ 'character': sqltypes.CHAR,
+ 'date': sqltypes.DATE,
+ 'fixed': sqltypes.Numeric,
+ 'float': sqltypes.FLOAT,
+ 'int': sqltypes.INT,
+ 'integer': sqltypes.INT,
+ 'long binary': sqltypes.BLOB,
+ 'long unicode': sqltypes.Text,
+ 'long': sqltypes.Text,
+ 'long': sqltypes.Text,
+ 'smallint': sqltypes.SmallInteger,
+ 'time': sqltypes.Time,
+ 'timestamp': sqltypes.TIMESTAMP,
+ 'varchar': sqltypes.VARCHAR,
}
-
+# TODO: migrate this to sapdb.py
class MaxDBExecutionContext(default.DefaultExecutionContext):
def post_exec(self):
# DB-API bug: if there were any functions as values,
return MaxDBResultProxy(self)
return engine_base.ResultProxy(self)
+ @property
+ def rowcount(self):
+ if hasattr(self, '_rowcount'):
+ return self._rowcount
+ else:
+ return self.cursor.rowcount
class MaxDBCachedColumnRow(engine_base.RowProxy):
"""A RowProxy that only runs result_processors once per column."""
else:
return self._get_col(key)
- def __getattr__(self, name):
- try:
- return self._get_col(name)
- except KeyError:
- raise AttributeError(name)
-
-
-class MaxDBResultProxy(engine_base.ResultProxy):
- _process_row = MaxDBCachedColumnRow
-
-
-class MaxDBDialect(default.DefaultDialect):
- name = 'maxdb'
- supports_alter = True
- supports_unicode_statements = True
- max_identifier_length = 32
- supports_sane_rowcount = True
- supports_sane_multi_rowcount = False
- preexecute_pk_sequences = True
-
- # MaxDB-specific
- datetimeformat = 'internal'
-
- def __init__(self, _raise_known_sql_errors=False, **kw):
- super(MaxDBDialect, self).__init__(**kw)
- self._raise_known = _raise_known_sql_errors
-
- if self.dbapi is None:
- self.dbapi_type_map = {}
- else:
- self.dbapi_type_map = {
- 'Long Binary': MaxBlob(),
- 'Long byte_t': MaxBlob(),
- 'Long Unicode': MaxText(),
- 'Timestamp': MaxTimestamp(),
- 'Date': MaxDate(),
- 'Time': MaxTime(),
- datetime.datetime: MaxTimestamp(),
- datetime.date: MaxDate(),
- datetime.time: MaxTime(),
- }
-
- def dbapi(cls):
- from sapdb import dbapi as _dbapi
- return _dbapi
- dbapi = classmethod(dbapi)
-
- def create_connect_args(self, url):
- opts = url.translate_connect_args(username='user')
- opts.update(url.query)
- return [], opts
-
- def type_descriptor(self, typeobj):
- if isinstance(typeobj, type):
- typeobj = typeobj()
- if isinstance(typeobj, sqltypes.Unicode):
- return typeobj.adapt(MaxUnicode)
- else:
- return sqltypes.adapt_type(typeobj, colspecs)
-
- def do_execute(self, cursor, statement, parameters, context=None):
- res = cursor.execute(statement, parameters)
- if isinstance(res, int) and context is not None:
- context._rowcount = res
-
- def do_release_savepoint(self, connection, name):
- # Does MaxDB truly support RELEASE SAVEPOINT <id>? All my attempts
- # produce "SUBTRANS COMMIT/ROLLBACK not allowed without SUBTRANS
- # BEGIN SQLSTATE: I7065"
- # Note that ROLLBACK TO works fine. In theory, a RELEASE should
- # just free up some transactional resources early, before the overall
- # COMMIT/ROLLBACK so omitting it should be relatively ok.
- pass
-
- def get_default_schema_name(self, connection):
- try:
- return self._default_schema_name
- except AttributeError:
- name = self.identifier_preparer._normalize_name(
- connection.execute('SELECT CURRENT_SCHEMA FROM DUAL').scalar())
- self._default_schema_name = name
- return name
-
- def has_table(self, connection, table_name, schema=None):
- denormalize = self.identifier_preparer._denormalize_name
- bind = [denormalize(table_name)]
- if schema is None:
- sql = ("SELECT tablename FROM TABLES "
- "WHERE TABLES.TABLENAME=? AND"
- " TABLES.SCHEMANAME=CURRENT_SCHEMA ")
- else:
- sql = ("SELECT tablename FROM TABLES "
- "WHERE TABLES.TABLENAME = ? AND"
- " TABLES.SCHEMANAME=? ")
- bind.append(denormalize(schema))
-
- rp = connection.execute(sql, bind)
- found = bool(rp.fetchone())
- rp.close()
- return found
-
- def table_names(self, connection, schema):
- if schema is None:
- sql = (" SELECT TABLENAME FROM TABLES WHERE "
- " SCHEMANAME=CURRENT_SCHEMA ")
- rs = connection.execute(sql)
- else:
- sql = (" SELECT TABLENAME FROM TABLES WHERE "
- " SCHEMANAME=? ")
- matchname = self.identifier_preparer._denormalize_name(schema)
- rs = connection.execute(sql, matchname)
- normalize = self.identifier_preparer._normalize_name
- return [normalize(row[0]) for row in rs]
-
- def reflecttable(self, connection, table, include_columns):
- denormalize = self.identifier_preparer._denormalize_name
- normalize = self.identifier_preparer._normalize_name
-
- st = ('SELECT COLUMNNAME, MODE, DATATYPE, CODETYPE, LEN, DEC, '
- ' NULLABLE, "DEFAULT", DEFAULTFUNCTION '
- 'FROM COLUMNS '
- 'WHERE TABLENAME=? AND SCHEMANAME=%s '
- 'ORDER BY POS')
-
- fk = ('SELECT COLUMNNAME, FKEYNAME, '
- ' REFSCHEMANAME, REFTABLENAME, REFCOLUMNNAME, RULE, '
- ' (CASE WHEN REFSCHEMANAME = CURRENT_SCHEMA '
- ' THEN 1 ELSE 0 END) AS in_schema '
- 'FROM FOREIGNKEYCOLUMNS '
- 'WHERE TABLENAME=? AND SCHEMANAME=%s '
- 'ORDER BY FKEYNAME ')
-
- params = [denormalize(table.name)]
- if not table.schema:
- st = st % 'CURRENT_SCHEMA'
- fk = fk % 'CURRENT_SCHEMA'
- else:
- st = st % '?'
- fk = fk % '?'
- params.append(denormalize(table.schema))
-
- rows = connection.execute(st, params).fetchall()
- if not rows:
- raise exc.NoSuchTableError(table.fullname)
-
- include_columns = set(include_columns or [])
-
- for row in rows:
- (name, mode, col_type, encoding, length, scale,
- nullable, constant_def, func_def) = row
-
- name = normalize(name)
-
- if include_columns and name not in include_columns:
- continue
-
- type_args, type_kw = [], {}
- if col_type == 'FIXED':
- type_args = length, scale
- # Convert FIXED(10) DEFAULT SERIAL to our Integer
- if (scale == 0 and
- func_def is not None and func_def.startswith('SERIAL')):
- col_type = 'INTEGER'
- type_args = length,
- elif col_type in 'FLOAT':
- type_args = length,
- elif col_type in ('CHAR', 'VARCHAR'):
- type_args = length,
- type_kw['encoding'] = encoding
- elif col_type == 'LONG':
- type_kw['encoding'] = encoding
-
- try:
- type_cls = ischema_names[col_type.lower()]
- type_instance = type_cls(*type_args, **type_kw)
- except KeyError:
- util.warn("Did not recognize type '%s' of column '%s'" %
- (col_type, name))
- type_instance = sqltypes.NullType
-
- col_kw = {'autoincrement': False}
- col_kw['nullable'] = (nullable == 'YES')
- col_kw['primary_key'] = (mode == 'KEY')
-
- if func_def is not None:
- if func_def.startswith('SERIAL'):
- if col_kw['primary_key']:
- # No special default- let the standard autoincrement
- # support handle SERIAL pk columns.
- col_kw['autoincrement'] = True
- else:
- # strip current numbering
- col_kw['server_default'] = schema.DefaultClause(
- sql.text('SERIAL'))
- col_kw['autoincrement'] = True
- else:
- col_kw['server_default'] = schema.DefaultClause(
- sql.text(func_def))
- elif constant_def is not None:
- col_kw['server_default'] = schema.DefaultClause(sql.text(
- "'%s'" % constant_def.replace("'", "''")))
-
- table.append_column(schema.Column(name, type_instance, **col_kw))
-
- fk_sets = itertools.groupby(connection.execute(fk, params),
- lambda row: row.FKEYNAME)
- for fkeyname, fkey in fk_sets:
- fkey = list(fkey)
- if include_columns:
- key_cols = set([r.COLUMNNAME for r in fkey])
- if key_cols != include_columns:
- continue
-
- columns, referants = [], []
- quote = self.identifier_preparer._maybe_quote_identifier
-
- for row in fkey:
- columns.append(normalize(row.COLUMNNAME))
- if table.schema or not row.in_schema:
- referants.append('.'.join(
- [quote(normalize(row[c]))
- for c in ('REFSCHEMANAME', 'REFTABLENAME',
- 'REFCOLUMNNAME')]))
- else:
- referants.append('.'.join(
- [quote(normalize(row[c]))
- for c in ('REFTABLENAME', 'REFCOLUMNNAME')]))
-
- constraint_kw = {'name': fkeyname.lower()}
- if fkey[0].RULE is not None:
- rule = fkey[0].RULE
- if rule.startswith('DELETE '):
- rule = rule[7:]
- constraint_kw['ondelete'] = rule
-
- table_kw = {}
- if table.schema or not row.in_schema:
- table_kw['schema'] = normalize(fkey[0].REFSCHEMANAME)
-
- ref_key = schema._get_table_key(normalize(fkey[0].REFTABLENAME),
- table_kw.get('schema'))
- if ref_key not in table.metadata.tables:
- schema.Table(normalize(fkey[0].REFTABLENAME),
- table.metadata,
- autoload=True, autoload_with=connection,
- **table_kw)
-
- constraint = schema.ForeignKeyConstraint(columns, referants, link_to_name=True,
- **constraint_kw)
- table.append_constraint(constraint)
-
- def has_sequence(self, connection, name):
- # [ticket:726] makes this schema-aware.
- denormalize = self.identifier_preparer._denormalize_name
- sql = ("SELECT sequence_name FROM SEQUENCES "
- "WHERE SEQUENCE_NAME=? ")
+ def __getattr__(self, name):
+ try:
+ return self._get_col(name)
+ except KeyError:
+ raise AttributeError(name)
- rp = connection.execute(sql, denormalize(name))
- found = bool(rp.fetchone())
- rp.close()
- return found
+class MaxDBResultProxy(engine_base.ResultProxy):
+ _process_row = MaxDBCachedColumnRow
-class MaxDBCompiler(compiler.DefaultCompiler):
- operators = compiler.DefaultCompiler.operators.copy()
- operators[sql_operators.mod] = lambda x, y: 'mod(%s, %s)' % (x, y)
+class MaxDBCompiler(compiler.SQLCompiler):
function_conversion = {
'CURRENT_DATE': 'DATE',
'TIMEZONE', 'TRANSACTION', 'TRUE', 'USER', 'UID', 'USERGROUP',
'UTCDATE', 'UTCDIFF'])
+ def visit_mod(self, binary, **kw):
+ return "mod(%s, %s)" % (self.process(binary.left), self.process(binary.right))
+
def default_from(self):
return ' FROM DUAL'
else:
return " WITH LOCK EXCLUSIVE"
- def apply_function_parens(self, func):
- if func.name.upper() in self.bare_functions:
- return len(func.clauses) > 0
+ def function_argspec(self, fn, **kw):
+ if fn.name.upper() in self.bare_functions:
+ return ""
+ elif len(fn.clauses) > 0:
+ return compiler.SQLCompiler.function_argspec(self, fn, **kw)
else:
- return True
+ return ""
def visit_function(self, fn, **kw):
transform = self.function_conversion.get(fn.name.upper(), None)
return name
-class MaxDBSchemaGenerator(compiler.SchemaGenerator):
+class MaxDBDDLCompiler(compiler.DDLCompiler):
def get_column_specification(self, column, **kw):
colspec = [self.preparer.format_column(column),
- column.type.dialect_impl(self.dialect).get_col_spec()]
+ self.dialect.type_compiler.process(column.type)]
if not column.nullable:
colspec.append('NOT NULL')
else:
return None
- def visit_sequence(self, sequence):
+ def visit_create_sequence(self, create):
"""Creates a SEQUENCE.
TODO: move to module doc?
maxdb_no_cache
Defaults to False. If true, sets NOCACHE.
"""
-
+ sequence = create.element
+
if (not sequence.optional and
(not self.checkfirst or
not self.dialect.has_sequence(self.connection, sequence.name))):
elif opts.get('no_cache', False):
ddl.append('NOCACHE')
- self.append(' '.join(ddl))
- self.execute()
+ return ' '.join(ddl)
-class MaxDBSchemaDropper(compiler.SchemaDropper):
- def visit_sequence(self, sequence):
- if (not sequence.optional and
- (not self.checkfirst or
- self.dialect.has_sequence(self.connection, sequence.name))):
- self.append("DROP SEQUENCE %s" %
- self.preparer.format_sequence(sequence))
- self.execute()
+class MaxDBDialect(default.DefaultDialect):
+ name = 'maxdb'
+ supports_alter = True
+ supports_unicode_statements = True
+ max_identifier_length = 32
+ supports_sane_rowcount = True
+ supports_sane_multi_rowcount = False
+
+ preparer = MaxDBIdentifierPreparer
+ statement_compiler = MaxDBCompiler
+ ddl_compiler = MaxDBDDLCompiler
+ defaultrunner = MaxDBDefaultRunner
+ execution_ctx_cls = MaxDBExecutionContext
+
+ colspecs = colspecs
+ ischema_names = ischema_names
+
+ # MaxDB-specific
+ datetimeformat = 'internal'
+
+ def __init__(self, _raise_known_sql_errors=False, **kw):
+ super(MaxDBDialect, self).__init__(**kw)
+ self._raise_known = _raise_known_sql_errors
+
+ if self.dbapi is None:
+ self.dbapi_type_map = {}
+ else:
+ self.dbapi_type_map = {
+ 'Long Binary': MaxBlob(),
+ 'Long byte_t': MaxBlob(),
+ 'Long Unicode': MaxText(),
+ 'Timestamp': MaxTimestamp(),
+ 'Date': MaxDate(),
+ 'Time': MaxTime(),
+ datetime.datetime: MaxTimestamp(),
+ datetime.date: MaxDate(),
+ datetime.time: MaxTime(),
+ }
+
+ def do_execute(self, cursor, statement, parameters, context=None):
+ res = cursor.execute(statement, parameters)
+ if isinstance(res, int) and context is not None:
+ context._rowcount = res
+
+ def do_release_savepoint(self, connection, name):
+ # Does MaxDB truly support RELEASE SAVEPOINT <id>? All my attempts
+ # produce "SUBTRANS COMMIT/ROLLBACK not allowed without SUBTRANS
+ # BEGIN SQLSTATE: I7065"
+ # Note that ROLLBACK TO works fine. In theory, a RELEASE should
+ # just free up some transactional resources early, before the overall
+ # COMMIT/ROLLBACK so omitting it should be relatively ok.
+ pass
+
+ def get_default_schema_name(self, connection):
+ try:
+ return self._default_schema_name
+ except AttributeError:
+ name = self.identifier_preparer._normalize_name(
+ connection.execute('SELECT CURRENT_SCHEMA FROM DUAL').scalar())
+ self._default_schema_name = name
+ return name
+
+ def has_table(self, connection, table_name, schema=None):
+ denormalize = self.identifier_preparer._denormalize_name
+ bind = [denormalize(table_name)]
+ if schema is None:
+ sql = ("SELECT tablename FROM TABLES "
+ "WHERE TABLES.TABLENAME=? AND"
+ " TABLES.SCHEMANAME=CURRENT_SCHEMA ")
+ else:
+ sql = ("SELECT tablename FROM TABLES "
+ "WHERE TABLES.TABLENAME = ? AND"
+ " TABLES.SCHEMANAME=? ")
+ bind.append(denormalize(schema))
+
+ rp = connection.execute(sql, bind)
+ found = bool(rp.fetchone())
+ rp.close()
+ return found
+
+ def table_names(self, connection, schema):
+ if schema is None:
+ sql = (" SELECT TABLENAME FROM TABLES WHERE "
+ " SCHEMANAME=CURRENT_SCHEMA ")
+ rs = connection.execute(sql)
+ else:
+ sql = (" SELECT TABLENAME FROM TABLES WHERE "
+ " SCHEMANAME=? ")
+ matchname = self.identifier_preparer._denormalize_name(schema)
+ rs = connection.execute(sql, matchname)
+ normalize = self.identifier_preparer._normalize_name
+ return [normalize(row[0]) for row in rs]
+
+ def reflecttable(self, connection, table, include_columns):
+ denormalize = self.identifier_preparer._denormalize_name
+ normalize = self.identifier_preparer._normalize_name
+
+ st = ('SELECT COLUMNNAME, MODE, DATATYPE, CODETYPE, LEN, DEC, '
+ ' NULLABLE, "DEFAULT", DEFAULTFUNCTION '
+ 'FROM COLUMNS '
+ 'WHERE TABLENAME=? AND SCHEMANAME=%s '
+ 'ORDER BY POS')
+
+ fk = ('SELECT COLUMNNAME, FKEYNAME, '
+ ' REFSCHEMANAME, REFTABLENAME, REFCOLUMNNAME, RULE, '
+ ' (CASE WHEN REFSCHEMANAME = CURRENT_SCHEMA '
+ ' THEN 1 ELSE 0 END) AS in_schema '
+ 'FROM FOREIGNKEYCOLUMNS '
+ 'WHERE TABLENAME=? AND SCHEMANAME=%s '
+ 'ORDER BY FKEYNAME ')
+
+ params = [denormalize(table.name)]
+ if not table.schema:
+ st = st % 'CURRENT_SCHEMA'
+ fk = fk % 'CURRENT_SCHEMA'
+ else:
+ st = st % '?'
+ fk = fk % '?'
+ params.append(denormalize(table.schema))
+
+ rows = connection.execute(st, params).fetchall()
+ if not rows:
+ raise exc.NoSuchTableError(table.fullname)
+
+ include_columns = set(include_columns or [])
+
+ for row in rows:
+ (name, mode, col_type, encoding, length, scale,
+ nullable, constant_def, func_def) = row
+
+ name = normalize(name)
+
+ if include_columns and name not in include_columns:
+ continue
+
+ type_args, type_kw = [], {}
+ if col_type == 'FIXED':
+ type_args = length, scale
+ # Convert FIXED(10) DEFAULT SERIAL to our Integer
+ if (scale == 0 and
+ func_def is not None and func_def.startswith('SERIAL')):
+ col_type = 'INTEGER'
+ type_args = length,
+ elif col_type in 'FLOAT':
+ type_args = length,
+ elif col_type in ('CHAR', 'VARCHAR'):
+ type_args = length,
+ type_kw['encoding'] = encoding
+ elif col_type == 'LONG':
+ type_kw['encoding'] = encoding
+
+ try:
+ type_cls = ischema_names[col_type.lower()]
+ type_instance = type_cls(*type_args, **type_kw)
+ except KeyError:
+ util.warn("Did not recognize type '%s' of column '%s'" %
+ (col_type, name))
+ type_instance = sqltypes.NullType
+
+ col_kw = {'autoincrement': False}
+ col_kw['nullable'] = (nullable == 'YES')
+ col_kw['primary_key'] = (mode == 'KEY')
+
+ if func_def is not None:
+ if func_def.startswith('SERIAL'):
+ if col_kw['primary_key']:
+ # No special default- let the standard autoincrement
+ # support handle SERIAL pk columns.
+ col_kw['autoincrement'] = True
+ else:
+ # strip current numbering
+ col_kw['server_default'] = schema.DefaultClause(
+ sql.text('SERIAL'))
+ col_kw['autoincrement'] = True
+ else:
+ col_kw['server_default'] = schema.DefaultClause(
+ sql.text(func_def))
+ elif constant_def is not None:
+ col_kw['server_default'] = schema.DefaultClause(sql.text(
+ "'%s'" % constant_def.replace("'", "''")))
+
+ table.append_column(schema.Column(name, type_instance, **col_kw))
+
+ fk_sets = itertools.groupby(connection.execute(fk, params),
+ lambda row: row.FKEYNAME)
+ for fkeyname, fkey in fk_sets:
+ fkey = list(fkey)
+ if include_columns:
+ key_cols = set([r.COLUMNNAME for r in fkey])
+ if key_cols != include_columns:
+ continue
+
+ columns, referants = [], []
+ quote = self.identifier_preparer._maybe_quote_identifier
+
+ for row in fkey:
+ columns.append(normalize(row.COLUMNNAME))
+ if table.schema or not row.in_schema:
+ referants.append('.'.join(
+ [quote(normalize(row[c]))
+ for c in ('REFSCHEMANAME', 'REFTABLENAME',
+ 'REFCOLUMNNAME')]))
+ else:
+ referants.append('.'.join(
+ [quote(normalize(row[c]))
+ for c in ('REFTABLENAME', 'REFCOLUMNNAME')]))
+
+ constraint_kw = {'name': fkeyname.lower()}
+ if fkey[0].RULE is not None:
+ rule = fkey[0].RULE
+ if rule.startswith('DELETE '):
+ rule = rule[7:]
+ constraint_kw['ondelete'] = rule
+
+ table_kw = {}
+ if table.schema or not row.in_schema:
+ table_kw['schema'] = normalize(fkey[0].REFSCHEMANAME)
+
+ ref_key = schema._get_table_key(normalize(fkey[0].REFTABLENAME),
+ table_kw.get('schema'))
+ if ref_key not in table.metadata.tables:
+ schema.Table(normalize(fkey[0].REFTABLENAME),
+ table.metadata,
+ autoload=True, autoload_with=connection,
+ **table_kw)
+
+ constraint = schema.ForeignKeyConstraint(columns, referants, link_to_name=True,
+ **constraint_kw)
+ table.append_constraint(constraint)
+
+ def has_sequence(self, connection, name):
+ # [ticket:726] makes this schema-aware.
+ denormalize = self.identifier_preparer._denormalize_name
+ sql = ("SELECT sequence_name FROM SEQUENCES "
+ "WHERE SEQUENCE_NAME=? ")
+
+ rp = connection.execute(sql, denormalize(name))
+ found = bool(rp.fetchone())
+ rp.close()
+ return found
+
def _autoserial_column(table):
return None, None
-dialect = MaxDBDialect
-dialect.preparer = MaxDBIdentifierPreparer
-dialect.statement_compiler = MaxDBCompiler
-dialect.schemagenerator = MaxDBSchemaGenerator
-dialect.schemadropper = MaxDBSchemaDropper
-dialect.defaultrunner = MaxDBDefaultRunner
-dialect.execution_ctx_cls = MaxDBExecutionContext
\ No newline at end of file
--- /dev/null
+from sqlalchemy.dialects.maxdb.base import MaxDBDialect
+
+class MaxDB_sapdb(MaxDBDialect):
+ driver = 'sapdb'
+
+ @classmethod
+ def dbapi(cls):
+ from sapdb import dbapi as _dbapi
+ return _dbapi
+
+ def create_connect_args(self, url):
+ opts = url.translate_connect_args(username='user')
+ opts.update(url.query)
+ return [], opts
+
+
+dialect = MaxDB_sapdb
\ No newline at end of file
--- /dev/null
+from sqlalchemy.dialects.mssql import base, pyodbc, adodbapi, pymssql
+
+base.dialect = pyodbc.dialect
\ No newline at end of file
--- /dev/null
+from sqlalchemy import types as sqltypes
+from sqlalchemy.dialects.mssql.base import MSDateTime, MSDialect
+import sys
+
+class MSDateTime_adodbapi(MSDateTime):
+ def result_processor(self, dialect):
+ def process(value):
+ # adodbapi will return datetimes with empty time values as datetime.date() objects.
+ # Promote them back to full datetime.datetime()
+ if type(value) is datetime.date:
+ return datetime.datetime(value.year, value.month, value.day)
+ return value
+ return process
+
+
+class MSDialect_adodbapi(MSDialect):
+ supports_sane_rowcount = True
+ supports_sane_multi_rowcount = True
+ supports_unicode = sys.maxunicode == 65535
+ supports_unicode_statements = True
+ driver = 'adodbapi'
+
+ @classmethod
+ def import_dbapi(cls):
+ import adodbapi as module
+ return module
+
+ colspecs = MSDialect.colspecs.copy()
+ colspecs[sqltypes.DateTime] = MSDateTime_adodbapi
+
+ def create_connect_args(self, url):
+ keys = url.query
+
+ connectors = ["Provider=SQLOLEDB"]
+ if 'port' in keys:
+ connectors.append ("Data Source=%s, %s" % (keys.get("host"), keys.get("port")))
+ else:
+ connectors.append ("Data Source=%s" % keys.get("host"))
+ connectors.append ("Initial Catalog=%s" % keys.get("database"))
+ user = keys.get("user")
+ if user:
+ connectors.append("User Id=%s" % user)
+ connectors.append("Password=%s" % keys.get("password", ""))
+ else:
+ connectors.append("Integrated Security=SSPI")
+ return [[";".join (connectors)], {}]
+
+ def is_disconnect(self, e):
+ return isinstance(e, self.dbapi.adodbapi.DatabaseError) and "'connection failure'" in str(e)
+
+dialect = MSDialect_adodbapi
--- /dev/null
+# mssql.py
+
+"""Support for the Microsoft SQL Server database.
+
+Driver
+------
+
+The MSSQL dialect will work with three different available drivers:
+
+* *pyodbc* - http://pyodbc.sourceforge.net/. This is the recommeded
+ driver.
+
+* *pymssql* - http://pymssql.sourceforge.net/
+
+* *adodbapi* - http://adodbapi.sourceforge.net/
+
+Drivers are loaded in the order listed above based on availability.
+
+If you need to load a specific driver pass ``module_name`` when
+creating the engine::
+
+ engine = create_engine('mssql+module_name://dsn')
+
+``module_name`` currently accepts: ``pyodbc``, ``pymssql``, and
+``adodbapi``.
+
+Currently the pyodbc driver offers the greatest level of
+compatibility.
+
+Connecting
+----------
+
+Connecting with create_engine() uses the standard URL approach of
+``mssql://user:pass@host/dbname[?key=value&key=value...]``.
+
+If the database name is present, the tokens are converted to a
+connection string with the specified values. If the database is not
+present, then the host token is taken directly as the DSN name.
+
+Examples of pyodbc connection string URLs:
+
+* *mssql+pyodbc://mydsn* - connects using the specified DSN named ``mydsn``.
+ The connection string that is created will appear like::
+
+ dsn=mydsn;TrustedConnection=Yes
+
+* *mssql+pyodbc://user:pass@mydsn* - connects using the DSN named
+ ``mydsn`` passing in the ``UID`` and ``PWD`` information. The
+ connection string that is created will appear like::
+
+ dsn=mydsn;UID=user;PWD=pass
+
+* *mssql+pyodbc://user:pass@mydsn/?LANGUAGE=us_english* - connects
+ using the DSN named ``mydsn`` passing in the ``UID`` and ``PWD``
+ information, plus the additional connection configuration option
+ ``LANGUAGE``. The connection string that is created will appear
+ like::
+
+ dsn=mydsn;UID=user;PWD=pass;LANGUAGE=us_english
+
+* *mssql+pyodbc://user:pass@host/db* - connects using a connection string
+ dynamically created that would appear like::
+
+ DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass
+
+* *mssql+pyodbc://user:pass@host:123/db* - connects using a connection
+ string that is dynamically created, which also includes the port
+ information using the comma syntax. If your connection string
+ requires the port information to be passed as a ``port`` keyword
+ see the next example. This will create the following connection
+ string::
+
+ DRIVER={SQL Server};Server=host,123;Database=db;UID=user;PWD=pass
+
+* *mssql+pyodbc://user:pass@host/db?port=123* - connects using a connection
+ string that is dynamically created that includes the port
+ information as a separate ``port`` keyword. This will create the
+ following connection string::
+
+ DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass;port=123
+
+If you require a connection string that is outside the options
+presented above, use the ``odbc_connect`` keyword to pass in a
+urlencoded connection string. What gets passed in will be urldecoded
+and passed directly.
+
+For example::
+
+ mssql+pyodbc:///?odbc_connect=dsn%3Dmydsn%3BDatabase%3Ddb
+
+would create the following connection string::
+
+ dsn=mydsn;Database=db
+
+Encoding your connection string can be easily accomplished through
+the python shell. For example::
+
+ >>> import urllib
+ >>> urllib.quote_plus('dsn=mydsn;Database=db')
+ 'dsn%3Dmydsn%3BDatabase%3Ddb'
+
+Additional arguments which may be specified either as query string
+arguments on the URL, or as keyword argument to
+:func:`~sqlalchemy.create_engine()` are:
+
+* *query_timeout* - allows you to override the default query timeout.
+ Defaults to ``None``. This is only supported on pymssql.
+
+* *use_scope_identity* - allows you to specify that SCOPE_IDENTITY
+ should be used in place of the non-scoped version @@IDENTITY.
+ Defaults to True.
+
+* *max_identifier_length* - allows you to se the maximum length of
+ identfiers supported by the database. Defaults to 128. For pymssql
+ the default is 30.
+
+* *schema_name* - use to set the schema name. Defaults to ``dbo``.
+
+Auto Increment Behavior
+-----------------------
+
+``IDENTITY`` columns are supported by using SQLAlchemy
+``schema.Sequence()`` objects. In other words::
+
+ Table('test', mss_engine,
+ Column('id', Integer,
+ Sequence('blah',100,10), primary_key=True),
+ Column('name', String(20))
+ ).create()
+
+would yield::
+
+ CREATE TABLE test (
+ id INTEGER NOT NULL IDENTITY(100,10) PRIMARY KEY,
+ name VARCHAR(20) NULL,
+ )
+
+Note that the ``start`` and ``increment`` values for sequences are
+optional and will default to 1,1.
+
+* Support for ``SET IDENTITY_INSERT ON`` mode (automagic on / off for
+ ``INSERT`` s)
+
+* Support for auto-fetching of ``@@IDENTITY/@@SCOPE_IDENTITY()`` on
+ ``INSERT``
+
+Collation Support
+-----------------
+
+MSSQL specific string types support a collation parameter that
+creates a column-level specific collation for the column. The
+collation parameter accepts a Windows Collation Name or a SQL
+Collation Name. Supported types are MSChar, MSNChar, MSString,
+MSNVarchar, MSText, and MSNText. For example::
+
+ Column('login', String(32, collation='Latin1_General_CI_AS'))
+
+will yield::
+
+ login VARCHAR(32) COLLATE Latin1_General_CI_AS NULL
+
+LIMIT/OFFSET Support
+--------------------
+
+MSSQL has no support for the LIMIT or OFFSET keysowrds. LIMIT is
+supported directly through the ``TOP`` Transact SQL keyword::
+
+ select.limit
+
+will yield::
+
+ SELECT TOP n
+
+If using SQL Server 2005 or above, LIMIT with OFFSET
+support is available through the ``ROW_NUMBER OVER`` construct.
+For versions below 2005, LIMIT with OFFSET usage will fail.
+
+Nullability
+-----------
+MSSQL has support for three levels of column nullability. The default
+nullability allows nulls and is explicit in the CREATE TABLE
+construct::
+
+ name VARCHAR(20) NULL
+
+If ``nullable=None`` is specified then no specification is made. In
+other words the database's configured default is used. This will
+render::
+
+ name VARCHAR(20)
+
+If ``nullable`` is ``True`` or ``False`` then the column will be
+``NULL` or ``NOT NULL`` respectively.
+
+Date / Time Handling
+--------------------
+DATE and TIME are supported. Bind parameters are converted
+to datetime.datetime() objects as required by most MSSQL drivers,
+and results are processed from strings if needed.
+The DATE and TIME types are not available for MSSQL 2005 and
+previous - if a server version below 2008 is detected, DDL
+for these types will be issued as DATETIME.
+
+Compatibility Levels
+--------------------
+MSSQL supports the notion of setting compatibility levels at the
+database level. This allows, for instance, to run a database that
+is compatibile with SQL2000 while running on a SQL2005 database
+server. ``server_version_info`` will always retrun the database
+server version information (in this case SQL2005) and not the
+compatibiility level information. Because of this, if running under
+a backwards compatibility mode SQAlchemy may attempt to use T-SQL
+statements that are unable to be parsed by the database server.
+
+Known Issues
+------------
+
+* No support for more than one ``IDENTITY`` column per table
+
+* pymssql has problems with binary and unicode data that this module
+ does **not** work around
+
+"""
+import datetime, decimal, inspect, operator, sys, re
+import itertools
+
+from sqlalchemy import sql, schema as sa_schema, exc, util
+from sqlalchemy.sql import select, compiler, expression, \
+ operators as sql_operators, \
+ functions as sql_functions, util as sql_util
+from sqlalchemy.engine import default, base, reflection
+from sqlalchemy import types as sqltypes
+from decimal import Decimal as _python_Decimal
+from sqlalchemy.types import INTEGER, BIGINT, SMALLINT, DECIMAL, NUMERIC, \
+ FLOAT, TIMESTAMP, DATETIME, DATE
+
+
+from sqlalchemy.dialects.mssql import information_schema as ischema
+
+MS_2008_VERSION = (10,)
+MS_2005_VERSION = (9,)
+MS_2000_VERSION = (8,)
+
+RESERVED_WORDS = set(
+ ['add', 'all', 'alter', 'and', 'any', 'as', 'asc', 'authorization',
+ 'backup', 'begin', 'between', 'break', 'browse', 'bulk', 'by', 'cascade',
+ 'case', 'check', 'checkpoint', 'close', 'clustered', 'coalesce',
+ 'collate', 'column', 'commit', 'compute', 'constraint', 'contains',
+ 'containstable', 'continue', 'convert', 'create', 'cross', 'current',
+ 'current_date', 'current_time', 'current_timestamp', 'current_user',
+ 'cursor', 'database', 'dbcc', 'deallocate', 'declare', 'default',
+ 'delete', 'deny', 'desc', 'disk', 'distinct', 'distributed', 'double',
+ 'drop', 'dump', 'else', 'end', 'errlvl', 'escape', 'except', 'exec',
+ 'execute', 'exists', 'exit', 'external', 'fetch', 'file', 'fillfactor',
+ 'for', 'foreign', 'freetext', 'freetexttable', 'from', 'full',
+ 'function', 'goto', 'grant', 'group', 'having', 'holdlock', 'identity',
+ 'identity_insert', 'identitycol', 'if', 'in', 'index', 'inner', 'insert',
+ 'intersect', 'into', 'is', 'join', 'key', 'kill', 'left', 'like',
+ 'lineno', 'load', 'merge', 'national', 'nocheck', 'nonclustered', 'not',
+ 'null', 'nullif', 'of', 'off', 'offsets', 'on', 'open', 'opendatasource',
+ 'openquery', 'openrowset', 'openxml', 'option', 'or', 'order', 'outer',
+ 'over', 'percent', 'pivot', 'plan', 'precision', 'primary', 'print',
+ 'proc', 'procedure', 'public', 'raiserror', 'read', 'readtext',
+ 'reconfigure', 'references', 'replication', 'restore', 'restrict',
+ 'return', 'revert', 'revoke', 'right', 'rollback', 'rowcount',
+ 'rowguidcol', 'rule', 'save', 'schema', 'securityaudit', 'select',
+ 'session_user', 'set', 'setuser', 'shutdown', 'some', 'statistics',
+ 'system_user', 'table', 'tablesample', 'textsize', 'then', 'to', 'top',
+ 'tran', 'transaction', 'trigger', 'truncate', 'tsequal', 'union',
+ 'unique', 'unpivot', 'update', 'updatetext', 'use', 'user', 'values',
+ 'varying', 'view', 'waitfor', 'when', 'where', 'while', 'with',
+ 'writetext',
+ ])
+
+
+class _MSNumeric(sqltypes.Numeric):
+ def result_processor(self, dialect):
+ if self.asdecimal:
+ def process(value):
+ if value is not None:
+ return _python_Decimal(str(value))
+ else:
+ return value
+ return process
+ else:
+ def process(value):
+ return float(value)
+ return process
+
+ def bind_processor(self, dialect):
+ def process(value):
+ if isinstance(value, decimal.Decimal):
+ if value.adjusted() < 0:
+ result = "%s0.%s%s" % (
+ (value < 0 and '-' or ''),
+ '0' * (abs(value.adjusted()) - 1),
+ "".join([str(nint) for nint in value._int]))
+
+ else:
+ if 'E' in str(value):
+ result = "%s%s%s" % (
+ (value < 0 and '-' or ''),
+ "".join([str(s) for s in value._int]),
+ "0" * (value.adjusted() - (len(value._int)-1)))
+ else:
+ if (len(value._int) - 1) > value.adjusted():
+ result = "%s%s.%s" % (
+ (value < 0 and '-' or ''),
+ "".join([str(s) for s in value._int][0:value.adjusted() + 1]),
+ "".join([str(s) for s in value._int][value.adjusted() + 1:]))
+ else:
+ result = "%s%s" % (
+ (value < 0 and '-' or ''),
+ "".join([str(s) for s in value._int][0:value.adjusted() + 1]))
+
+ return result
+
+ else:
+ return value
+
+ return process
+
+class REAL(sqltypes.Float):
+ """A type for ``real`` numbers."""
+
+ __visit_name__ = 'REAL'
+
+ def __init__(self):
+ super(REAL, self).__init__(precision=24)
+
+class TINYINT(sqltypes.Integer):
+ __visit_name__ = 'TINYINT'
+
+
+# MSSQL DATE/TIME types have varied behavior, sometimes returning
+# strings. MSDate/TIME check for everything, and always
+# filter bind parameters into datetime objects (required by pyodbc,
+# not sure about other dialects).
+
+class _MSDate(sqltypes.Date):
+ def bind_processor(self, dialect):
+ def process(value):
+ if type(value) == datetime.date:
+ return datetime.datetime(value.year, value.month, value.day)
+ else:
+ return value
+ return process
+
+ _reg = re.compile(r"(\d+)-(\d+)-(\d+)")
+ def result_processor(self, dialect):
+ def process(value):
+ if isinstance(value, datetime.datetime):
+ return value.date()
+ elif isinstance(value, basestring):
+ return datetime.date(*[int(x or 0) for x in self._reg.match(value).groups()])
+ else:
+ return value
+ return process
+
+class TIME(sqltypes.TIME):
+ def __init__(self, precision=None, **kwargs):
+ self.precision = precision
+ super(TIME, self).__init__()
+
+ __zero_date = datetime.date(1900, 1, 1)
+
+ def bind_processor(self, dialect):
+ def process(value):
+ if isinstance(value, datetime.datetime):
+ value = datetime.datetime.combine(self.__zero_date, value.time())
+ elif isinstance(value, datetime.time):
+ value = datetime.datetime.combine(self.__zero_date, value)
+ return value
+ return process
+
+ _reg = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d+))?")
+ def result_processor(self, dialect):
+ def process(value):
+ if isinstance(value, datetime.datetime):
+ return value.time()
+ elif isinstance(value, basestring):
+ return datetime.time(*[int(x or 0) for x in self._reg.match(value).groups()])
+ else:
+ return value
+ return process
+
+
+class _DateTimeBase(object):
+ def bind_processor(self, dialect):
+ def process(value):
+ # TODO: why ?
+ if type(value) == datetime.date:
+ return datetime.datetime(value.year, value.month, value.day)
+ else:
+ return value
+ return process
+
+class _MSDateTime(_DateTimeBase, sqltypes.DateTime):
+ pass
+
+class SMALLDATETIME(_DateTimeBase, sqltypes.DateTime):
+ __visit_name__ = 'SMALLDATETIME'
+
+class DATETIME2(_DateTimeBase, sqltypes.DateTime):
+ __visit_name__ = 'DATETIME2'
+
+ def __init__(self, precision=None, **kwargs):
+ self.precision = precision
+
+
+# TODO: is this not an Interval ?
+class DATETIMEOFFSET(sqltypes.TypeEngine):
+ __visit_name__ = 'DATETIMEOFFSET'
+
+ def __init__(self, precision=None, **kwargs):
+ self.precision = precision
+
+
+class _StringType(object):
+ """Base for MSSQL string types."""
+
+ def __init__(self, collation=None):
+ self.collation = collation
+
+ def __repr__(self):
+ attributes = inspect.getargspec(self.__init__)[0][1:]
+ attributes.extend(inspect.getargspec(_StringType.__init__)[0][1:])
+
+ params = {}
+ for attr in attributes:
+ val = getattr(self, attr)
+ if val is not None and val is not False:
+ params[attr] = val
+
+ return "%s(%s)" % (self.__class__.__name__,
+ ', '.join(['%s=%r' % (k, params[k]) for k in params]))
+
+
+class TEXT(_StringType, sqltypes.TEXT):
+ """MSSQL TEXT type, for variable-length text up to 2^31 characters."""
+
+ def __init__(self, *args, **kw):
+ """Construct a TEXT.
+
+ :param collation: Optional, a column-level collation for this string
+ value. Accepts a Windows Collation Name or a SQL Collation Name.
+
+ """
+ collation = kw.pop('collation', None)
+ _StringType.__init__(self, collation)
+ sqltypes.Text.__init__(self, *args, **kw)
+
+class NTEXT(_StringType, sqltypes.UnicodeText):
+ """MSSQL NTEXT type, for variable-length unicode text up to 2^30
+ characters."""
+
+ __visit_name__ = 'NTEXT'
+
+ def __init__(self, *args, **kwargs):
+ """Construct a NTEXT.
+
+ :param collation: Optional, a column-level collation for this string
+ value. Accepts a Windows Collation Name or a SQL Collation Name.
+
+ """
+ collation = kwargs.pop('collation', None)
+ _StringType.__init__(self, collation)
+ length = kwargs.pop('length', None)
+ sqltypes.UnicodeText.__init__(self, length, **kwargs)
+
+
+class VARCHAR(_StringType, sqltypes.VARCHAR):
+ """MSSQL VARCHAR type, for variable-length non-Unicode data with a maximum
+ of 8,000 characters."""
+
+ def __init__(self, *args, **kw):
+ """Construct a VARCHAR.
+
+ :param length: Optinal, maximum data length, in characters.
+
+ :param convert_unicode: defaults to False. If True, convert
+ ``unicode`` data sent to the database to a ``str``
+ bytestring, and convert bytestrings coming back from the
+ database into ``unicode``.
+
+ Bytestrings are encoded using the dialect's
+ :attr:`~sqlalchemy.engine.base.Dialect.encoding`, which
+ defaults to `utf-8`.
+
+ If False, may be overridden by
+ :attr:`sqlalchemy.engine.base.Dialect.convert_unicode`.
+
+ :param assert_unicode:
+
+ If None (the default), no assertion will take place unless
+ overridden by :attr:`sqlalchemy.engine.base.Dialect.assert_unicode`.
+
+ If 'warn', will issue a runtime warning if a ``str``
+ instance is used as a bind value.
+
+ If true, will raise an :exc:`sqlalchemy.exc.InvalidRequestError`.
+
+ :param collation: Optional, a column-level collation for this string
+ value. Accepts a Windows Collation Name or a SQL Collation Name.
+
+ """
+ collation = kw.pop('collation', None)
+ _StringType.__init__(self, collation)
+ sqltypes.VARCHAR.__init__(self, *args, **kw)
+
+class NVARCHAR(_StringType, sqltypes.NVARCHAR):
+ """MSSQL NVARCHAR type.
+
+ For variable-length unicode character data up to 4,000 characters."""
+
+ def __init__(self, *args, **kw):
+ """Construct a NVARCHAR.
+
+ :param length: Optional, Maximum data length, in characters.
+
+ :param collation: Optional, a column-level collation for this string
+ value. Accepts a Windows Collation Name or a SQL Collation Name.
+
+ """
+ collation = kw.pop('collation', None)
+ _StringType.__init__(self, collation)
+ sqltypes.NVARCHAR.__init__(self, *args, **kw)
+
+class CHAR(_StringType, sqltypes.CHAR):
+ """MSSQL CHAR type, for fixed-length non-Unicode data with a maximum
+ of 8,000 characters."""
+
+ def __init__(self, *args, **kw):
+ """Construct a CHAR.
+
+ :param length: Optinal, maximum data length, in characters.
+
+ :param convert_unicode: defaults to False. If True, convert
+ ``unicode`` data sent to the database to a ``str``
+ bytestring, and convert bytestrings coming back from the
+ database into ``unicode``.
+
+ Bytestrings are encoded using the dialect's
+ :attr:`~sqlalchemy.engine.base.Dialect.encoding`, which
+ defaults to `utf-8`.
+
+ If False, may be overridden by
+ :attr:`sqlalchemy.engine.base.Dialect.convert_unicode`.
+
+ :param assert_unicode:
+
+ If None (the default), no assertion will take place unless
+ overridden by :attr:`sqlalchemy.engine.base.Dialect.assert_unicode`.
+
+ If 'warn', will issue a runtime warning if a ``str``
+ instance is used as a bind value.
+
+ If true, will raise an :exc:`sqlalchemy.exc.InvalidRequestError`.
+
+ :param collation: Optional, a column-level collation for this string
+ value. Accepts a Windows Collation Name or a SQL Collation Name.
+
+ """
+ collation = kw.pop('collation', None)
+ _StringType.__init__(self, collation)
+ sqltypes.CHAR.__init__(self, *args, **kw)
+
+class NCHAR(_StringType, sqltypes.NCHAR):
+ """MSSQL NCHAR type.
+
+ For fixed-length unicode character data up to 4,000 characters."""
+
+ def __init__(self, *args, **kw):
+ """Construct an NCHAR.
+
+ :param length: Optional, Maximum data length, in characters.
+
+ :param collation: Optional, a column-level collation for this string
+ value. Accepts a Windows Collation Name or a SQL Collation Name.
+
+ """
+ collation = kw.pop('collation', None)
+ _StringType.__init__(self, collation)
+ sqltypes.NCHAR.__init__(self, *args, **kw)
+
+class BINARY(sqltypes.Binary):
+ __visit_name__ = 'BINARY'
+
+class VARBINARY(sqltypes.Binary):
+ __visit_name__ = 'VARBINARY'
+
+class IMAGE(sqltypes.Binary):
+ __visit_name__ = 'IMAGE'
+
+class BIT(sqltypes.TypeEngine):
+ __visit_name__ = 'BIT'
+
+class _MSBoolean(sqltypes.Boolean):
+ def result_processor(self, dialect):
+ def process(value):
+ if value is None:
+ return None
+ return value and True or False
+ return process
+
+ def bind_processor(self, dialect):
+ def process(value):
+ if value is True:
+ return 1
+ elif value is False:
+ return 0
+ elif value is None:
+ return None
+ else:
+ return value and True or False
+ return process
+
+class MONEY(sqltypes.TypeEngine):
+ __visit_name__ = 'MONEY'
+
+class SMALLMONEY(sqltypes.TypeEngine):
+ __visit_name__ = 'SMALLMONEY'
+
+class UNIQUEIDENTIFIER(sqltypes.TypeEngine):
+ __visit_name__ = "UNIQUEIDENTIFIER"
+
+class SQL_VARIANT(sqltypes.TypeEngine):
+ __visit_name__ = 'SQL_VARIANT'
+
+# old names.
+MSNumeric = _MSNumeric
+MSDateTime = _MSDateTime
+MSDate = _MSDate
+MSBoolean = _MSBoolean
+MSReal = REAL
+MSTinyInteger = TINYINT
+MSTime = TIME
+MSSmallDateTime = SMALLDATETIME
+MSDateTime2 = DATETIME2
+MSDateTimeOffset = DATETIMEOFFSET
+MSText = TEXT
+MSNText = NTEXT
+MSString = VARCHAR
+MSNVarchar = NVARCHAR
+MSChar = CHAR
+MSNChar = NCHAR
+MSBinary = BINARY
+MSVarBinary = VARBINARY
+MSImage = IMAGE
+MSBit = BIT
+MSMoney = MONEY
+MSSmallMoney = SMALLMONEY
+MSUniqueIdentifier = UNIQUEIDENTIFIER
+MSVariant = SQL_VARIANT
+
+colspecs = {
+ sqltypes.Numeric : _MSNumeric,
+ sqltypes.DateTime : _MSDateTime,
+ sqltypes.Date : _MSDate,
+ sqltypes.Time : TIME,
+ sqltypes.Boolean : _MSBoolean,
+}
+
+ischema_names = {
+ 'int' : INTEGER,
+ 'bigint': BIGINT,
+ 'smallint' : SMALLINT,
+ 'tinyint' : TINYINT,
+ 'varchar' : VARCHAR,
+ 'nvarchar' : NVARCHAR,
+ 'char' : CHAR,
+ 'nchar' : NCHAR,
+ 'text' : TEXT,
+ 'ntext' : NTEXT,
+ 'decimal' : DECIMAL,
+ 'numeric' : NUMERIC,
+ 'float' : FLOAT,
+ 'datetime' : DATETIME,
+ 'datetime2' : DATETIME2,
+ 'datetimeoffset' : DATETIMEOFFSET,
+ 'date': DATE,
+ 'time': TIME,
+ 'smalldatetime' : SMALLDATETIME,
+ 'binary' : BINARY,
+ 'varbinary' : VARBINARY,
+ 'bit': BIT,
+ 'real' : REAL,
+ 'image' : IMAGE,
+ 'timestamp': TIMESTAMP,
+ 'money': MONEY,
+ 'smallmoney': SMALLMONEY,
+ 'uniqueidentifier': UNIQUEIDENTIFIER,
+ 'sql_variant': SQL_VARIANT,
+}
+
+
+class MSTypeCompiler(compiler.GenericTypeCompiler):
+ def _extend(self, spec, type_):
+ """Extend a string-type declaration with standard SQL
+ COLLATE annotations.
+
+ """
+
+ if getattr(type_, 'collation', None):
+ collation = 'COLLATE %s' % type_.collation
+ else:
+ collation = None
+
+ if type_.length:
+ spec = spec + "(%d)" % type_.length
+
+ return ' '.join([c for c in (spec, collation)
+ if c is not None])
+
+ def visit_FLOAT(self, type_):
+ precision = getattr(type_, 'precision', None)
+ if precision is None:
+ return "FLOAT"
+ else:
+ return "FLOAT(%(precision)s)" % {'precision': precision}
+
+ def visit_REAL(self, type_):
+ return "REAL"
+
+ def visit_TINYINT(self, type_):
+ return "TINYINT"
+
+ def visit_DATETIMEOFFSET(self, type_):
+ if type_.precision:
+ return "DATETIMEOFFSET(%s)" % type_.precision
+ else:
+ return "DATETIMEOFFSET"
+
+ def visit_TIME(self, type_):
+ precision = getattr(type_, 'precision', None)
+ if precision:
+ return "TIME(%s)" % precision
+ else:
+ return "TIME"
+
+ def visit_DATETIME2(self, type_):
+ precision = getattr(type_, 'precision', None)
+ if precision:
+ return "DATETIME2(%s)" % precision
+ else:
+ return "DATETIME2"
+
+ def visit_SMALLDATETIME(self, type_):
+ return "SMALLDATETIME"
+
+ def visit_unicode(self, type_):
+ return self.visit_NVARCHAR(type_)
+
+ def visit_unicode_text(self, type_):
+ return self.visit_NTEXT(type_)
+
+ def visit_NTEXT(self, type_):
+ return self._extend("NTEXT", type_)
+
+ def visit_TEXT(self, type_):
+ return self._extend("TEXT", type_)
+
+ def visit_VARCHAR(self, type_):
+ return self._extend("VARCHAR", type_)
+
+ def visit_CHAR(self, type_):
+ return self._extend("CHAR", type_)
+
+ def visit_NCHAR(self, type_):
+ return self._extend("NCHAR", type_)
+
+ def visit_NVARCHAR(self, type_):
+ return self._extend("NVARCHAR", type_)
+
+ def visit_date(self, type_):
+ if self.dialect.server_version_info < MS_2008_VERSION:
+ return self.visit_DATETIME(type_)
+ else:
+ return self.visit_DATE(type_)
+
+ def visit_time(self, type_):
+ if self.dialect.server_version_info < MS_2008_VERSION:
+ return self.visit_DATETIME(type_)
+ else:
+ return self.visit_TIME(type_)
+
+ def visit_binary(self, type_):
+ if type_.length:
+ return self.visit_BINARY(type_)
+ else:
+ return self.visit_IMAGE(type_)
+
+ def visit_BINARY(self, type_):
+ if type_.length:
+ return "BINARY(%s)" % type_.length
+ else:
+ return "BINARY"
+
+ def visit_IMAGE(self, type_):
+ return "IMAGE"
+
+ def visit_VARBINARY(self, type_):
+ if type_.length:
+ return "VARBINARY(%s)" % type_.length
+ else:
+ return "VARBINARY"
+
+ def visit_boolean(self, type_):
+ return self.visit_BIT(type_)
+
+ def visit_BIT(self, type_):
+ return "BIT"
+
+ def visit_MONEY(self, type_):
+ return "MONEY"
+
+ def visit_SMALLMONEY(self, type_):
+ return 'SMALLMONEY'
+
+ def visit_UNIQUEIDENTIFIER(self, type_):
+ return "UNIQUEIDENTIFIER"
+
+ def visit_SQL_VARIANT(self, type_):
+ return 'SQL_VARIANT'
+
+class MSExecutionContext(default.DefaultExecutionContext):
+ _enable_identity_insert = False
+ _select_lastrowid = False
+ _result_proxy = None
+ _lastrowid = None
+
+ def pre_exec(self):
+ """Activate IDENTITY_INSERT if needed."""
+
+ if self.isinsert:
+ tbl = self.compiled.statement.table
+ seq_column = tbl._autoincrement_column
+ insert_has_sequence = seq_column is not None
+
+ if insert_has_sequence:
+ self._enable_identity_insert = seq_column.key in self.compiled_parameters[0]
+ else:
+ self._enable_identity_insert = False
+
+ self._select_lastrowid = insert_has_sequence and \
+ not self.compiled.returning and \
+ not self._enable_identity_insert and \
+ not self.executemany
+
+ if self._enable_identity_insert:
+ self.cursor.execute("SET IDENTITY_INSERT %s ON" %
+ self.dialect.identifier_preparer.format_table(tbl))
+
+ def post_exec(self):
+ """Disable IDENTITY_INSERT if enabled."""
+
+ if self._select_lastrowid:
+ if self.dialect.use_scope_identity:
+ self.cursor.execute("SELECT scope_identity() AS lastrowid")
+ else:
+ self.cursor.execute("SELECT @@identity AS lastrowid")
+ row = self.cursor.fetchall()[0] # fetchall() ensures the cursor is consumed without closing it
+ self._lastrowid = int(row[0])
+
+ if (self.isinsert or self.isupdate or self.isdelete) and self.compiled.returning:
+ self._result_proxy = base.FullyBufferedResultProxy(self)
+
+ if self._enable_identity_insert:
+ self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.dialect.identifier_preparer.format_table(self.compiled.statement.table))
+
+ def get_lastrowid(self):
+ return self._lastrowid
+
+ def handle_dbapi_exception(self, e):
+ if self._enable_identity_insert:
+ try:
+ self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.dialect.identifier_preparer.format_table(self.compiled.statement.table))
+ except:
+ pass
+
+ def get_result_proxy(self):
+ if self._result_proxy:
+ return self._result_proxy
+ else:
+ return base.ResultProxy(self)
+
+class MSSQLCompiler(compiler.SQLCompiler):
+
+ extract_map = compiler.SQLCompiler.extract_map.copy()
+ extract_map.update ({
+ 'doy': 'dayofyear',
+ 'dow': 'weekday',
+ 'milliseconds': 'millisecond',
+ 'microseconds': 'microsecond'
+ })
+
+ def __init__(self, *args, **kwargs):
+ super(MSSQLCompiler, self).__init__(*args, **kwargs)
+ self.tablealiases = {}
+
+ def visit_now_func(self, fn, **kw):
+ return "CURRENT_TIMESTAMP"
+
+ def visit_current_date_func(self, fn, **kw):
+ return "GETDATE()"
+
+ def visit_length_func(self, fn, **kw):
+ return "LEN%s" % self.function_argspec(fn, **kw)
+
+ def visit_char_length_func(self, fn, **kw):
+ return "LEN%s" % self.function_argspec(fn, **kw)
+
+ def visit_concat_op(self, binary):
+ return "%s + %s" % (self.process(binary.left), self.process(binary.right))
+
+ def visit_match_op(self, binary):
+ return "CONTAINS (%s, %s)" % (self.process(binary.left), self.process(binary.right))
+
+ def get_select_precolumns(self, select):
+ """ MS-SQL puts TOP, it's version of LIMIT here """
+ if select._distinct or select._limit:
+ s = select._distinct and "DISTINCT " or ""
+
+ if select._limit:
+ if not select._offset:
+ s += "TOP %s " % (select._limit,)
+ return s
+ return compiler.SQLCompiler.get_select_precolumns(self, select)
+
+ def limit_clause(self, select):
+ # Limit in mssql is after the select keyword
+ return ""
+
+ def visit_select(self, select, **kwargs):
+ """Look for ``LIMIT`` and OFFSET in a select statement, and if
+ so tries to wrap it in a subquery with ``row_number()`` criterion.
+
+ """
+ if not getattr(select, '_mssql_visit', None) and select._offset:
+ # to use ROW_NUMBER(), an ORDER BY is required.
+ orderby = self.process(select._order_by_clause)
+ if not orderby:
+ raise exc.InvalidRequestError('MSSQL requires an order_by when using an offset.')
+
+ _offset = select._offset
+ _limit = select._limit
+ select._mssql_visit = True
+ select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("mssql_rn")).order_by(None).alias()
+
+ limitselect = sql.select([c for c in select.c if c.key!='mssql_rn'])
+ limitselect.append_whereclause("mssql_rn>%d" % _offset)
+ if _limit is not None:
+ limitselect.append_whereclause("mssql_rn<=%d" % (_limit + _offset))
+ return self.process(limitselect, iswrapper=True, **kwargs)
+ else:
+ return compiler.SQLCompiler.visit_select(self, select, **kwargs)
+
+ def _schema_aliased_table(self, table):
+ if getattr(table, 'schema', None) is not None:
+ if table not in self.tablealiases:
+ self.tablealiases[table] = table.alias()
+ return self.tablealiases[table]
+ else:
+ return None
+
+ def visit_table(self, table, mssql_aliased=False, **kwargs):
+ if mssql_aliased:
+ return super(MSSQLCompiler, self).visit_table(table, **kwargs)
+
+ # alias schema-qualified tables
+ alias = self._schema_aliased_table(table)
+ if alias is not None:
+ return self.process(alias, mssql_aliased=True, **kwargs)
+ else:
+ return super(MSSQLCompiler, self).visit_table(table, **kwargs)
+
+ def visit_alias(self, alias, **kwargs):
+ # translate for schema-qualified table aliases
+ self.tablealiases[alias.original] = alias
+ kwargs['mssql_aliased'] = True
+ return super(MSSQLCompiler, self).visit_alias(alias, **kwargs)
+
+ def visit_extract(self, extract):
+ field = self.extract_map.get(extract.field, extract.field)
+ return 'DATEPART("%s", %s)' % (field, self.process(extract.expr))
+
+ def visit_rollback_to_savepoint(self, savepoint_stmt):
+ return "ROLLBACK TRANSACTION %s" % self.preparer.format_savepoint(savepoint_stmt)
+
+ def visit_column(self, column, result_map=None, **kwargs):
+ if column.table is not None and \
+ (not self.isupdate and not self.isdelete) or self.is_subquery():
+ # translate for schema-qualified table aliases
+ t = self._schema_aliased_table(column.table)
+ if t is not None:
+ converted = expression._corresponding_column_or_error(t, column)
+
+ if result_map is not None:
+ result_map[column.name.lower()] = (column.name, (column, ), column.type)
+
+ return super(MSSQLCompiler, self).visit_column(converted, result_map=None, **kwargs)
+
+ return super(MSSQLCompiler, self).visit_column(column, result_map=result_map, **kwargs)
+
+ def visit_binary(self, binary, **kwargs):
+ """Move bind parameters to the right-hand side of an operator, where
+ possible.
+
+ """
+ if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq \
+ and not isinstance(binary.right, expression._BindParamClause):
+ return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator), **kwargs)
+ else:
+ if (binary.operator is operator.eq or binary.operator is operator.ne) and (
+ (isinstance(binary.left, expression._FromGrouping) and isinstance(binary.left.element, expression._ScalarSelect)) or \
+ (isinstance(binary.right, expression._FromGrouping) and isinstance(binary.right.element, expression._ScalarSelect)) or \
+ isinstance(binary.left, expression._ScalarSelect) or isinstance(binary.right, expression._ScalarSelect)):
+ op = binary.operator == operator.eq and "IN" or "NOT IN"
+ return self.process(expression._BinaryExpression(binary.left, binary.right, op), **kwargs)
+ return super(MSSQLCompiler, self).visit_binary(binary, **kwargs)
+
+ def returning_clause(self, stmt, returning_cols):
+
+ if self.isinsert or self.isupdate:
+ target = stmt.table.alias("inserted")
+ else:
+ target = stmt.table.alias("deleted")
+
+ adapter = sql_util.ClauseAdapter(target)
+ def col_label(col):
+ adapted = adapter.traverse(c)
+ if isinstance(c, expression._Label):
+ return adapted.label(c.key)
+ else:
+ return self.label_select_column(None, adapted, asfrom=False)
+
+ columns = [
+ self.process(
+ col_label(c),
+ within_columns_clause=True,
+ result_map=self.result_map
+ )
+ for c in expression._select_iterables(returning_cols)
+ ]
+ return 'OUTPUT ' + ', '.join(columns)
+
+ def label_select_column(self, select, column, asfrom):
+ if isinstance(column, expression.Function):
+ return column.label(None)
+ else:
+ return super(MSSQLCompiler, self).label_select_column(select, column, asfrom)
+
+ def for_update_clause(self, select):
+ # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use
+ return ''
+
+ def order_by_clause(self, select):
+ order_by = self.process(select._order_by_clause)
+
+ # MSSQL only allows ORDER BY in subqueries if there is a LIMIT
+ if order_by and (not self.is_subquery() or select._limit):
+ return " ORDER BY " + order_by
+ else:
+ return ""
+
+
+class MSDDLCompiler(compiler.DDLCompiler):
+ def get_column_specification(self, column, **kwargs):
+ colspec = self.preparer.format_column(column) + " " + self.dialect.type_compiler.process(column.type)
+
+ if column.nullable is not None:
+ if not column.nullable or column.primary_key:
+ colspec += " NOT NULL"
+ else:
+ colspec += " NULL"
+
+ if not column.table:
+ raise exc.InvalidRequestError("mssql requires Table-bound columns in order to generate DDL")
+
+ seq_col = column.table._autoincrement_column
+
+ # install a IDENTITY Sequence if we have an implicit IDENTITY column
+ if seq_col is column:
+ sequence = getattr(column, 'sequence', None)
+ if sequence:
+ start, increment = sequence.start or 1, sequence.increment or 1
+ else:
+ start, increment = 1, 1
+ colspec += " IDENTITY(%s,%s)" % (start, increment)
+ else:
+ default = self.get_column_default_string(column)
+ if default is not None:
+ colspec += " DEFAULT " + default
+
+ return colspec
+
+ def visit_drop_index(self, drop):
+ return "\nDROP INDEX %s.%s" % (
+ self.preparer.quote_identifier(drop.element.table.name),
+ self.preparer.quote(self._validate_identifier(drop.element.name, False), drop.element.quote)
+ )
+
+
+class MSIdentifierPreparer(compiler.IdentifierPreparer):
+ reserved_words = RESERVED_WORDS
+
+ def __init__(self, dialect):
+ super(MSIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']')
+
+ def _escape_identifier(self, value):
+ #TODO: determine MSSQL's escaping rules
+ return value
+
+ def quote_schema(self, schema, force=True):
+ """Prepare a quoted table and schema name."""
+ result = '.'.join([self.quote(x, force) for x in schema.split('.')])
+ return result
+
+class MSDialect(default.DefaultDialect):
+ name = 'mssql'
+ supports_default_values = True
+ supports_empty_insert = False
+ execution_ctx_cls = MSExecutionContext
+ text_as_varchar = False
+ use_scope_identity = True
+ max_identifier_length = 128
+ schema_name = "dbo"
+ colspecs = colspecs
+ ischema_names = ischema_names
+
+ supports_unicode_binds = True
+ postfetch_lastrowid = True
+
+ server_version_info = ()
+
+ statement_compiler = MSSQLCompiler
+ ddl_compiler = MSDDLCompiler
+ type_compiler = MSTypeCompiler
+ preparer = MSIdentifierPreparer
+
+ def __init__(self,
+ query_timeout=None,
+ use_scope_identity=True,
+ max_identifier_length=None,
+ schema_name="dbo", **opts):
+ self.query_timeout = int(query_timeout or 0)
+ self.schema_name = schema_name
+
+ self.use_scope_identity = use_scope_identity
+ self.max_identifier_length = int(max_identifier_length or 0) or \
+ self.max_identifier_length
+ super(MSDialect, self).__init__(**opts)
+
+ def do_savepoint(self, connection, name):
+ util.warn("Savepoint support in mssql is experimental and may lead to data loss.")
+ connection.execute("IF @@TRANCOUNT = 0 BEGIN TRANSACTION")
+ connection.execute("SAVE TRANSACTION %s" % name)
+
+ def do_release_savepoint(self, connection, name):
+ pass
+
+ def initialize(self, connection):
+ super(MSDialect, self).initialize(connection)
+ if self.server_version_info >= MS_2005_VERSION and 'implicit_returning' not in self.__dict__:
+ self.implicit_returning = True
+
+ def get_default_schema_name(self, connection):
+ return self.default_schema_name
+
+ def _get_default_schema_name(self, connection):
+ user_name = connection.scalar("SELECT user_name() as user_name;")
+ if user_name is not None:
+ # now, get the default schema
+ query = """
+ SELECT default_schema_name FROM
+ sys.database_principals
+ WHERE name = ?
+ AND type = 'S'
+ """
+ try:
+ default_schema_name = connection.scalar(query, [user_name])
+ if default_schema_name is not None:
+ return default_schema_name
+ except:
+ pass
+ return self.schema_name
+
+ def table_names(self, connection, schema):
+ s = select([ischema.tables.c.table_name], ischema.tables.c.table_schema==schema)
+ return [row[0] for row in connection.execute(s)]
+
+
+ def has_table(self, connection, tablename, schema=None):
+ current_schema = schema or self.default_schema_name
+ columns = ischema.columns
+ s = sql.select([columns],
+ current_schema
+ and sql.and_(columns.c.table_name==tablename, columns.c.table_schema==current_schema)
+ or columns.c.table_name==tablename,
+ )
+
+ c = connection.execute(s)
+ row = c.fetchone()
+ return row is not None
+
+ @reflection.cache
+ def get_schema_names(self, connection, **kw):
+ s = sql.select([ischema.schemata.c.schema_name],
+ order_by=[ischema.schemata.c.schema_name]
+ )
+ schema_names = [r[0] for r in connection.execute(s)]
+ return schema_names
+
+ @reflection.cache
+ def get_table_names(self, connection, schema=None, **kw):
+ current_schema = schema or self.default_schema_name
+ tables = ischema.tables
+ s = sql.select([tables.c.table_name],
+ sql.and_(
+ tables.c.table_schema == current_schema,
+ tables.c.table_type == 'BASE TABLE'
+ ),
+ order_by=[tables.c.table_name]
+ )
+ table_names = [r[0] for r in connection.execute(s)]
+ return table_names
+
+ @reflection.cache
+ def get_view_names(self, connection, schema=None, **kw):
+ current_schema = schema or self.default_schema_name
+ tables = ischema.tables
+ s = sql.select([tables.c.table_name],
+ sql.and_(
+ tables.c.table_schema == current_schema,
+ tables.c.table_type == 'VIEW'
+ ),
+ order_by=[tables.c.table_name]
+ )
+ view_names = [r[0] for r in connection.execute(s)]
+ return view_names
+
+ # The cursor reports it is closed after executing the sp.
+ @reflection.cache
+ def get_indexes(self, connection, tablename, schema=None, **kw):
+ current_schema = schema or self.default_schema_name
+ full_tname = "%s.%s" % (current_schema, tablename)
+ indexes = []
+ s = sql.text("exec sp_helpindex '%s'" % full_tname)
+ rp = connection.execute(s)
+ if rp.closed:
+ # did not work for this setup.
+ return []
+ for row in rp:
+ if 'primary key' not in row['index_description']:
+ indexes.append({
+ 'name' : row['index_name'],
+ 'column_names' : [c.strip() for c in row['index_keys'].split(',')],
+ 'unique': 'unique' in row['index_description']
+ })
+ return indexes
+
+ @reflection.cache
+ def get_view_definition(self, connection, viewname, schema=None, **kw):
+ current_schema = schema or self.default_schema_name
+ views = ischema.views
+ s = sql.select([views.c.view_definition],
+ sql.and_(
+ views.c.table_schema == current_schema,
+ views.c.table_name == viewname
+ ),
+ )
+ rp = connection.execute(s)
+ if rp:
+ view_def = rp.scalar()
+ return view_def
+
+ @reflection.cache
+ def get_columns(self, connection, tablename, schema=None, **kw):
+ # Get base columns
+ current_schema = schema or self.default_schema_name
+ columns = ischema.columns
+ s = sql.select([columns],
+ current_schema
+ and sql.and_(columns.c.table_name==tablename, columns.c.table_schema==current_schema)
+ or columns.c.table_name==tablename,
+ order_by=[columns.c.ordinal_position])
+ c = connection.execute(s)
+ cols = []
+ while True:
+ row = c.fetchone()
+ if row is None:
+ break
+ (name, type, nullable, charlen, numericprec, numericscale, default, collation) = (
+ row[columns.c.column_name],
+ row[columns.c.data_type],
+ row[columns.c.is_nullable] == 'YES',
+ row[columns.c.character_maximum_length],
+ row[columns.c.numeric_precision],
+ row[columns.c.numeric_scale],
+ row[columns.c.column_default],
+ row[columns.c.collation_name]
+ )
+ coltype = self.ischema_names.get(type, None)
+
+ kwargs = {}
+ if coltype in (MSString, MSChar, MSNVarchar, MSNChar, MSText, MSNText, MSBinary, MSVarBinary, sqltypes.Binary):
+ kwargs['length'] = charlen
+ if collation:
+ kwargs['collation'] = collation
+ if coltype == MSText or (coltype in (MSString, MSNVarchar) and charlen == -1):
+ kwargs.pop('length')
+
+ if coltype is None:
+ util.warn("Did not recognize type '%s' of column '%s'" % (type, name))
+ coltype = sqltypes.NULLTYPE
+
+ if issubclass(coltype, sqltypes.Numeric) and coltype is not MSReal:
+ kwargs['scale'] = numericscale
+ kwargs['precision'] = numericprec
+
+ coltype = coltype(**kwargs)
+ cdict = {
+ 'name' : name,
+ 'type' : coltype,
+ 'nullable' : nullable,
+ 'default' : default,
+ 'autoincrement':False,
+ }
+ cols.append(cdict)
+ # autoincrement and identity
+ colmap = {}
+ for col in cols:
+ colmap[col['name']] = col
+ # We also run an sp_columns to check for identity columns:
+ cursor = connection.execute("sp_columns @table_name = '%s', @table_owner = '%s'" % (tablename, current_schema))
+ ic = None
+ while True:
+ row = cursor.fetchone()
+ if row is None:
+ break
+ (col_name, type_name) = row[3], row[5]
+ if type_name.endswith("identity") and col_name in colmap:
+ ic = col_name
+ colmap[col_name]['autoincrement'] = True
+ colmap[col_name]['sequence'] = dict(
+ name='%s_identity' % col_name)
+ break
+ cursor.close()
+ if ic is not None:
+ try:
+ # is this table_fullname reliable?
+ table_fullname = "%s.%s" % (current_schema, tablename)
+ cursor = connection.execute(
+ sql.text("select ident_seed(:seed), ident_incr(:incr)"),
+ {'seed':table_fullname, 'incr':table_fullname}
+ )
+ row = cursor.fetchone()
+ cursor.close()
+ if not row is None:
+ colmap[ic]['sequence'].update({
+ 'start' : int(row[0]),
+ 'increment' : int(row[1])
+ })
+ except:
+ # ignoring it, works just like before
+ pass
+ return cols
+
+ @reflection.cache
+ def get_primary_keys(self, connection, tablename, schema=None, **kw):
+ current_schema = schema or self.default_schema_name
+ pkeys = []
+ # Add constraints
+ RR = ischema.ref_constraints #information_schema.referential_constraints
+ TC = ischema.constraints #information_schema.table_constraints
+ C = ischema.key_constraints.alias('C') #information_schema.constraint_column_usage: the constrained column
+ R = ischema.key_constraints.alias('R') #information_schema.constraint_column_usage: the referenced column
+
+ # Primary key constraints
+ s = sql.select([C.c.column_name, TC.c.constraint_type],
+ sql.and_(TC.c.constraint_name == C.c.constraint_name,
+ C.c.table_name == tablename,
+ C.c.table_schema == current_schema)
+ )
+ c = connection.execute(s)
+ for row in c:
+ if 'PRIMARY' in row[TC.c.constraint_type.name]:
+ pkeys.append(row[0])
+ return pkeys
+
+ @reflection.cache
+ def get_foreign_keys(self, connection, tablename, schema=None, **kw):
+ current_schema = schema or self.default_schema_name
+ # Add constraints
+ RR = ischema.ref_constraints #information_schema.referential_constraints
+ TC = ischema.constraints #information_schema.table_constraints
+ C = ischema.key_constraints.alias('C') #information_schema.constraint_column_usage: the constrained column
+ R = ischema.key_constraints.alias('R') #information_schema.constraint_column_usage: the referenced column
+
+ # Foreign key constraints
+ s = sql.select([C.c.column_name,
+ R.c.table_schema, R.c.table_name, R.c.column_name,
+ RR.c.constraint_name, RR.c.match_option, RR.c.update_rule, RR.c.delete_rule],
+ sql.and_(C.c.table_name == tablename,
+ C.c.table_schema == current_schema,
+ C.c.constraint_name == RR.c.constraint_name,
+ R.c.constraint_name == RR.c.unique_constraint_name,
+ C.c.ordinal_position == R.c.ordinal_position
+ ),
+ order_by = [RR.c.constraint_name, R.c.ordinal_position])
+
+
+ # group rows by constraint ID, to handle multi-column FKs
+ fkeys = []
+ fknm, scols, rcols = (None, [], [])
+
+ def fkey_rec():
+ return {
+ 'name' : None,
+ 'constrained_columns' : [],
+ 'referred_schema' : None,
+ 'referred_table' : None,
+ 'referred_columns' : []
+ }
+
+ fkeys = util.defaultdict(fkey_rec)
+
+ for r in connection.execute(s).fetchall():
+ scol, rschema, rtbl, rcol, rfknm, fkmatch, fkuprule, fkdelrule = r
+
+ rec = fkeys[rfknm]
+ rec['name'] = rfknm
+ if not rec['referred_table']:
+ rec['referred_table'] = rtbl
+
+ if schema is not None or current_schema != rschema:
+ rec['referred_schema'] = rschema
+
+ local_cols, remote_cols = rec['constrained_columns'], rec['referred_columns']
+
+ local_cols.append(scol)
+ remote_cols.append(rcol)
+
+ return fkeys.values()
+
+
+# fixme. I added this for the tests to run. -Randall
+MSSQLDialect = MSDialect
--- /dev/null
+from sqlalchemy import Table, MetaData, Column, ForeignKey
+from sqlalchemy.types import String, Unicode, Integer, TypeDecorator
+
+ischema = MetaData()
+
+class CoerceUnicode(TypeDecorator):
+ impl = Unicode
+
+ def process_bind_param(self, value, dialect):
+ if isinstance(value, str):
+ value = value.decode(dialect.encoding)
+ return value
+
+schemata = Table("SCHEMATA", ischema,
+ Column("CATALOG_NAME", CoerceUnicode, key="catalog_name"),
+ Column("SCHEMA_NAME", CoerceUnicode, key="schema_name"),
+ Column("SCHEMA_OWNER", CoerceUnicode, key="schema_owner"),
+ schema="INFORMATION_SCHEMA")
+
+tables = Table("TABLES", ischema,
+ Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"),
+ Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
+ Column("TABLE_NAME", CoerceUnicode, key="table_name"),
+ Column("TABLE_TYPE", String, key="table_type"),
+ schema="INFORMATION_SCHEMA")
+
+columns = Table("COLUMNS", ischema,
+ Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
+ Column("TABLE_NAME", CoerceUnicode, key="table_name"),
+ Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
+ Column("IS_NULLABLE", Integer, key="is_nullable"),
+ Column("DATA_TYPE", String, key="data_type"),
+ Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
+ Column("CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length"),
+ Column("NUMERIC_PRECISION", Integer, key="numeric_precision"),
+ Column("NUMERIC_SCALE", Integer, key="numeric_scale"),
+ Column("COLUMN_DEFAULT", Integer, key="column_default"),
+ Column("COLLATION_NAME", String, key="collation_name"),
+ schema="INFORMATION_SCHEMA")
+
+constraints = Table("TABLE_CONSTRAINTS", ischema,
+ Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
+ Column("TABLE_NAME", CoerceUnicode, key="table_name"),
+ Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
+ Column("CONSTRAINT_TYPE", String, key="constraint_type"),
+ schema="INFORMATION_SCHEMA")
+
+column_constraints = Table("CONSTRAINT_COLUMN_USAGE", ischema,
+ Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
+ Column("TABLE_NAME", CoerceUnicode, key="table_name"),
+ Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
+ Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
+ schema="INFORMATION_SCHEMA")
+
+key_constraints = Table("KEY_COLUMN_USAGE", ischema,
+ Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
+ Column("TABLE_NAME", CoerceUnicode, key="table_name"),
+ Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
+ Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
+ Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
+ schema="INFORMATION_SCHEMA")
+
+ref_constraints = Table("REFERENTIAL_CONSTRAINTS", ischema,
+ Column("CONSTRAINT_CATALOG", CoerceUnicode, key="constraint_catalog"),
+ Column("CONSTRAINT_SCHEMA", CoerceUnicode, key="constraint_schema"),
+ Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
+ Column("UNIQUE_CONSTRAINT_CATLOG", CoerceUnicode, key="unique_constraint_catalog"), # TODO: is CATLOG misspelled ?
+ Column("UNIQUE_CONSTRAINT_SCHEMA", CoerceUnicode, key="unique_constraint_schema"),
+ Column("UNIQUE_CONSTRAINT_NAME", CoerceUnicode, key="unique_constraint_name"),
+ Column("MATCH_OPTION", String, key="match_option"),
+ Column("UPDATE_RULE", String, key="update_rule"),
+ Column("DELETE_RULE", String, key="delete_rule"),
+ schema="INFORMATION_SCHEMA")
+
+views = Table("VIEWS", ischema,
+ Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"),
+ Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
+ Column("TABLE_NAME", CoerceUnicode, key="table_name"),
+ Column("VIEW_DEFINITION", CoerceUnicode, key="view_definition"),
+ Column("CHECK_OPTION", String, key="check_option"),
+ Column("IS_UPDATABLE", String, key="is_updatable"),
+ schema="INFORMATION_SCHEMA")
+
--- /dev/null
+from sqlalchemy.dialects.mssql.base import MSDialect
+from sqlalchemy import types as sqltypes
+
+
+class MSDialect_pymssql(MSDialect):
+ supports_sane_rowcount = False
+ max_identifier_length = 30
+ driver = 'pymssql'
+
+ @classmethod
+ def import_dbapi(cls):
+ import pymssql as module
+ # pymmsql doesn't have a Binary method. we use string
+ # TODO: monkeypatching here is less than ideal
+ module.Binary = lambda st: str(st)
+ return module
+
+ def __init__(self, **params):
+ super(MSSQLDialect_pymssql, self).__init__(**params)
+ self.use_scope_identity = True
+
+ # pymssql understands only ascii
+ if self.convert_unicode:
+ util.warn("pymssql does not support unicode")
+ self.encoding = params.get('encoding', 'ascii')
+
+
+ def create_connect_args(self, url):
+ if hasattr(self, 'query_timeout'):
+ # ick, globals ? we might want to move this....
+ self.dbapi._mssql.set_query_timeout(self.query_timeout)
+
+ keys = url.query
+ if keys.get('port'):
+ # pymssql expects port as host:port, not a separate arg
+ keys['host'] = ''.join([keys.get('host', ''), ':', str(keys['port'])])
+ del keys['port']
+ return [[], keys]
+
+ def is_disconnect(self, e):
+ return isinstance(e, self.dbapi.DatabaseError) and "Error 10054" in str(e)
+
+ def do_begin(self, connection):
+ pass
+
+dialect = MSDialect_pymssql
\ No newline at end of file
--- /dev/null
+from sqlalchemy.dialects.mssql.base import MSExecutionContext, MSDialect
+from sqlalchemy.connectors.pyodbc import PyODBCConnector
+from sqlalchemy import types as sqltypes
+import re
+import sys
+
+class MSExecutionContext_pyodbc(MSExecutionContext):
+ _embedded_scope_identity = False
+
+ def pre_exec(self):
+ """where appropriate, issue "select scope_identity()" in the same statement.
+
+ Background on why "scope_identity()" is preferable to "@@identity":
+ http://msdn.microsoft.com/en-us/library/ms190315.aspx
+
+ Background on why we attempt to embed "scope_identity()" into the same
+ statement as the INSERT:
+ http://code.google.com/p/pyodbc/wiki/FAQs#How_do_I_retrieve_autogenerated/identity_values?
+
+ """
+
+ super(MSExecutionContext_pyodbc, self).pre_exec()
+
+ # don't embed the scope_identity select into an "INSERT .. DEFAULT VALUES"
+ if self._select_lastrowid and \
+ self.dialect.use_scope_identity and \
+ len(self.parameters[0]):
+ self._embedded_scope_identity = True
+
+ self.statement += "; select scope_identity()"
+
+ def post_exec(self):
+ if self._embedded_scope_identity:
+ # Fetch the last inserted id from the manipulated statement
+ # We may have to skip over a number of result sets with no data (due to triggers, etc.)
+ while True:
+ try:
+ # fetchall() ensures the cursor is consumed without closing it (FreeTDS particularly)
+ row = self.cursor.fetchall()[0]
+ break
+ except self.dialect.dbapi.Error, e:
+ # no way around this - nextset() consumes the previous set
+ # so we need to just keep flipping
+ self.cursor.nextset()
+
+ self._lastrowid = int(row[0])
+ else:
+ super(MSExecutionContext_pyodbc, self).post_exec()
+
+
+class MSDialect_pyodbc(PyODBCConnector, MSDialect):
+ supports_sane_rowcount = True
+ supports_sane_multi_rowcount = False
+
+ execution_ctx_cls = MSExecutionContext_pyodbc
+
+ pyodbc_driver_name = 'SQL Server'
+
+ def __init__(self, description_encoding='latin-1', **params):
+ super(MSDialect_pyodbc, self).__init__(**params)
+ self.description_encoding = description_encoding
+ self.use_scope_identity = self.dbapi and hasattr(self.dbapi.Cursor, 'nextset')
+
+ def initialize(self, connection):
+ super(MSDialect_pyodbc, self).initialize(connection)
+ pyodbc = self.dbapi
+
+ dbapi_con = connection.connection
+
+ self._free_tds = re.match(r".*libtdsodbc.*\.so", dbapi_con.getinfo(pyodbc.SQL_DRIVER_NAME))
+
+ # the "Py2K only" part here is theoretical.
+ # have not tried pyodbc + python3.1 yet.
+ # Py2K
+ self.supports_unicode_statements = not self._free_tds
+ self.supports_unicode_binds = not self._free_tds
+ # end Py2K
+
+dialect = MSDialect_pyodbc
--- /dev/null
+from sqlalchemy.dialects.mysql import base, mysqldb, pyodbc, zxjdbc
+
+# default dialect
+base.dialect = mysqldb.dialect
\ No newline at end of file
types when creating tables with your :class:`~sqlalchemy.Table` definitions,
then you will need to import them from this module::
- from sqlalchemy.databases import mysql
+ from sqlalchemy.dialect.mysql import base as mysql
Table('mytable', metadata,
Column('id', Integer, primary_key=True),
- Column('ittybittyblob', mysql.MSTinyBlob),
- Column('biggy', mysql.MSBigInteger(unsigned=True)))
+ Column('ittybittyblob', mysql.TINYBLOB),
+ Column('biggy', mysql.BIGINT(unsigned=True)))
All standard MySQL column types are supported. The OpenGIS types are
available for use via table reflection but have no special support or mapping
See the official MySQL documentation for detailed information about features
supported in any given server release.
-Character Sets
---------------
-
-Many MySQL server installations default to a ``latin1`` encoding for client
-connections. All data sent through the connection will be converted into
-``latin1``, even if you have ``utf8`` or another character set on your tables
-and columns. With versions 4.1 and higher, you can change the connection
-character set either through server configuration or by including the
-``charset`` parameter in the URL used for ``create_engine``. The ``charset``
-option is passed through to MySQL-Python and has the side-effect of also
-enabling ``use_unicode`` in the driver by default. For regular encoded
-strings, also pass ``use_unicode=0`` in the connection arguments::
-
- # set client encoding to utf8; all strings come back as unicode
- create_engine('mysql:///mydb?charset=utf8')
-
- # set client encoding to utf8; all strings come back as utf8 str
- create_engine('mysql:///mydb?charset=utf8&use_unicode=0')
-
Storage Engines
---------------
"""
-import datetime, decimal, inspect, re, sys
-from array import array as _array
+import datetime, inspect, re, sys
-from sqlalchemy import exc, log, schema, sql, util
+from sqlalchemy import schema as sa_schema
+from sqlalchemy import exc, log, sql, util
from sqlalchemy.sql import operators as sql_operators
from sqlalchemy.sql import functions as sql_functions
from sqlalchemy.sql import compiler
+from array import array as _array
+from sqlalchemy.engine import reflection
from sqlalchemy.engine import base as engine_base, default
from sqlalchemy import types as sqltypes
-
-__all__ = (
- 'MSBigInteger', 'MSMediumInteger', 'MSBinary', 'MSBit', 'MSBlob', 'MSBoolean',
- 'MSChar', 'MSDate', 'MSDateTime', 'MSDecimal', 'MSDouble',
- 'MSEnum', 'MSFloat', 'MSInteger', 'MSLongBlob', 'MSLongText',
- 'MSMediumBlob', 'MSMediumText', 'MSNChar', 'MSNVarChar',
- 'MSNumeric', 'MSSet', 'MSSmallInteger', 'MSString', 'MSText',
- 'MSTime', 'MSTimeStamp', 'MSTinyBlob', 'MSTinyInteger',
- 'MSTinyText', 'MSVarBinary', 'MSYear' )
-
+from sqlalchemy.types import DATE, DATETIME, BOOLEAN, TIME
RESERVED_WORDS = set(
['accessible', 'add', 'all', 'alter', 'analyze','and', 'as', 'asc',
class _NumericType(object):
"""Base for MySQL numeric types."""
- def __init__(self, kw):
+ def __init__(self, **kw):
self.unsigned = kw.pop('unsigned', False)
self.zerofill = kw.pop('zerofill', False)
+ super(_NumericType, self).__init__(**kw)
+
+class _FloatType(_NumericType, sqltypes.Float):
+ def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
+ if isinstance(self, (REAL, DOUBLE)) and \
+ (
+ (precision is None and scale is not None) or
+ (precision is not None and scale is None)
+ ):
+ raise exc.ArgumentError(
+ "You must specify both precision and scale or omit "
+ "both altogether.")
+
+ super(_FloatType, self).__init__(precision=precision, asdecimal=asdecimal, **kw)
+ self.scale = scale
- def _extend(self, spec):
- "Extend a numeric-type declaration with MySQL specific extensions."
-
- if self.unsigned:
- spec += ' UNSIGNED'
- if self.zerofill:
- spec += ' ZEROFILL'
- return spec
-
+class _IntegerType(_NumericType, sqltypes.Integer):
+ def __init__(self, display_width=None, **kw):
+ self.display_width = display_width
+ super(_IntegerType, self).__init__(**kw)
-class _StringType(object):
+class _StringType(sqltypes.String):
"""Base for MySQL string types."""
def __init__(self, charset=None, collation=None,
ascii=False, unicode=False, binary=False,
- national=False, **kwargs):
+ national=False, **kw):
self.charset = charset
# allow collate= or collation=
- self.collation = kwargs.get('collate', collation)
+ self.collation = kw.pop('collate', collation)
self.ascii = ascii
self.unicode = unicode
self.binary = binary
self.national = national
-
- def _extend(self, spec):
- """Extend a string-type declaration with standard SQL CHARACTER SET /
- COLLATE annotations and MySQL specific extensions.
- """
-
- if self.charset:
- charset = 'CHARACTER SET %s' % self.charset
- elif self.ascii:
- charset = 'ASCII'
- elif self.unicode:
- charset = 'UNICODE'
- else:
- charset = None
-
- if self.collation:
- collation = 'COLLATE %s' % self.collation
- elif self.binary:
- collation = 'BINARY'
- else:
- collation = None
-
- if self.national:
- # NATIONAL (aka NCHAR/NVARCHAR) trumps charsets.
- return ' '.join([c for c in ('NATIONAL', spec, collation)
- if c is not None])
- return ' '.join([c for c in (spec, charset, collation)
- if c is not None])
-
+ super(_StringType, self).__init__(**kw)
+
def __repr__(self):
attributes = inspect.getargspec(self.__init__)[0][1:]
attributes.extend(inspect.getargspec(_StringType.__init__)[0][1:])
', '.join(['%s=%r' % (k, params[k]) for k in params]))
-class MSNumeric(sqltypes.Numeric, _NumericType):
- """MySQL NUMERIC type."""
+class _BinaryType(sqltypes.Binary):
+ """Base for MySQL binary types."""
+
+ def result_processor(self, dialect):
+ def process(value):
+ if value is None:
+ return None
+ else:
+ return util.buffer(value)
+ return process
- def __init__(self, precision=10, scale=2, asdecimal=True, **kw):
+class NUMERIC(_NumericType, sqltypes.NUMERIC):
+ """MySQL NUMERIC type."""
+
+ __visit_name__ = 'NUMERIC'
+
+ def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
"""Construct a NUMERIC.
:param precision: Total digits in this number. If scale and precision
numeric.
"""
- _NumericType.__init__(self, kw)
- sqltypes.Numeric.__init__(self, precision, scale, asdecimal=asdecimal, **kw)
-
- def get_col_spec(self):
- if self.precision is None:
- return self._extend("NUMERIC")
- else:
- return self._extend("NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale})
-
- def bind_processor(self, dialect):
- return None
+ super(NUMERIC, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal, **kw)
- def result_processor(self, dialect):
- if not self.asdecimal:
- def process(value):
- if isinstance(value, decimal.Decimal):
- return float(value)
- else:
- return value
- return process
- else:
- return None
-
-class MSDecimal(MSNumeric):
+class DECIMAL(_NumericType, sqltypes.DECIMAL):
"""MySQL DECIMAL type."""
-
- def __init__(self, precision=10, scale=2, asdecimal=True, **kw):
+
+ __visit_name__ = 'DECIMAL'
+
+ def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
"""Construct a DECIMAL.
:param precision: Total digits in this number. If scale and precision
numeric.
"""
- super(MSDecimal, self).__init__(precision, scale, asdecimal=asdecimal, **kw)
-
- def get_col_spec(self):
- if self.precision is None:
- return self._extend("DECIMAL")
- elif self.scale is None:
- return self._extend("DECIMAL(%(precision)s)" % {'precision': self.precision})
- else:
- return self._extend("DECIMAL(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale})
-
+ super(DECIMAL, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal, **kw)
-class MSDouble(sqltypes.Float, _NumericType):
+
+class DOUBLE(_FloatType):
"""MySQL DOUBLE type."""
+ __visit_name__ = 'DOUBLE'
+
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
"""Construct a DOUBLE.
numeric.
"""
- if ((precision is None and scale is not None) or
- (precision is not None and scale is None)):
- raise exc.ArgumentError(
- "You must specify both precision and scale or omit "
- "both altogether.")
-
- _NumericType.__init__(self, kw)
- sqltypes.Float.__init__(self, asdecimal=asdecimal, **kw)
- self.scale = scale
- self.precision = precision
+ super(DOUBLE, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal, **kw)
- def get_col_spec(self):
- if self.precision is not None and self.scale is not None:
- return self._extend("DOUBLE(%(precision)s, %(scale)s)" %
- {'precision': self.precision,
- 'scale' : self.scale})
- else:
- return self._extend('DOUBLE')
-
-
-class MSReal(MSDouble):
+class REAL(_FloatType):
"""MySQL REAL type."""
+ __visit_name__ = 'REAL'
+
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
"""Construct a REAL.
numeric.
"""
- MSDouble.__init__(self, precision, scale, asdecimal, **kw)
+ super(REAL, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal, **kw)
- def get_col_spec(self):
- if self.precision is not None and self.scale is not None:
- return self._extend("REAL(%(precision)s, %(scale)s)" %
- {'precision': self.precision,
- 'scale' : self.scale})
- else:
- return self._extend('REAL')
-
-
-class MSFloat(sqltypes.Float, _NumericType):
+class FLOAT(_FloatType, sqltypes.FLOAT):
"""MySQL FLOAT type."""
+ __visit_name__ = 'FLOAT'
+
def __init__(self, precision=None, scale=None, asdecimal=False, **kw):
"""Construct a FLOAT.
numeric.
"""
- _NumericType.__init__(self, kw)
- sqltypes.Float.__init__(self, asdecimal=asdecimal, **kw)
- self.scale = scale
- self.precision = precision
-
- def get_col_spec(self):
- if self.scale is not None and self.precision is not None:
- return self._extend("FLOAT(%s, %s)" % (self.precision, self.scale))
- elif self.precision is not None:
- return self._extend("FLOAT(%s)" % (self.precision,))
- else:
- return self._extend("FLOAT")
+ super(FLOAT, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal, **kw)
def bind_processor(self, dialect):
return None
-
-class MSInteger(sqltypes.Integer, _NumericType):
+class INTEGER(_IntegerType, sqltypes.INTEGER):
"""MySQL INTEGER type."""
+ __visit_name__ = 'INTEGER'
+
def __init__(self, display_width=None, **kw):
"""Construct an INTEGER.
numeric.
"""
- if 'length' in kw:
- util.warn_deprecated("'length' is deprecated for MSInteger and subclasses. Use 'display_width'.")
- self.display_width = kw.pop('length')
- else:
- self.display_width = display_width
- _NumericType.__init__(self, kw)
- sqltypes.Integer.__init__(self, **kw)
-
- def get_col_spec(self):
- if self.display_width is not None:
- return self._extend("INTEGER(%(display_width)s)" % {'display_width': self.display_width})
- else:
- return self._extend("INTEGER")
+ super(INTEGER, self).__init__(display_width=display_width, **kw)
-
-class MSBigInteger(MSInteger):
+class BIGINT(_IntegerType, sqltypes.BIGINT):
"""MySQL BIGINTEGER type."""
+ __visit_name__ = 'BIGINT'
+
def __init__(self, display_width=None, **kw):
"""Construct a BIGINTEGER.
numeric.
"""
- super(MSBigInteger, self).__init__(display_width, **kw)
-
- def get_col_spec(self):
- if self.display_width is not None:
- return self._extend("BIGINT(%(display_width)s)" % {'display_width': self.display_width})
- else:
- return self._extend("BIGINT")
+ super(BIGINT, self).__init__(display_width=display_width, **kw)
-
-class MSMediumInteger(MSInteger):
+class MEDIUMINT(_IntegerType):
"""MySQL MEDIUMINTEGER type."""
+ __visit_name__ = 'MEDIUMINT'
+
def __init__(self, display_width=None, **kw):
"""Construct a MEDIUMINTEGER
numeric.
"""
- super(MSMediumInteger, self).__init__(display_width, **kw)
-
- def get_col_spec(self):
- if self.display_width is not None:
- return self._extend("MEDIUMINT(%(display_width)s)" % {'display_width': self.display_width})
- else:
- return self._extend("MEDIUMINT")
+ super(MEDIUMINT, self).__init__(display_width=display_width, **kw)
-
-
-class MSTinyInteger(MSInteger):
+class TINYINT(_IntegerType):
"""MySQL TINYINT type."""
+ __visit_name__ = 'TINYINT'
+
def __init__(self, display_width=None, **kw):
"""Construct a TINYINT.
numeric.
"""
- super(MSTinyInteger, self).__init__(display_width, **kw)
-
- def get_col_spec(self):
- if self.display_width is not None:
- return self._extend("TINYINT(%s)" % self.display_width)
- else:
- return self._extend("TINYINT")
-
+ super(TINYINT, self).__init__(display_width=display_width, **kw)
-class MSSmallInteger(sqltypes.Smallinteger, MSInteger):
+class SMALLINT(_IntegerType, sqltypes.SMALLINT):
"""MySQL SMALLINTEGER type."""
+ __visit_name__ = 'SMALLINT'
+
def __init__(self, display_width=None, **kw):
"""Construct a SMALLINTEGER.
numeric.
"""
- self.display_width = display_width
- _NumericType.__init__(self, kw)
- sqltypes.SmallInteger.__init__(self, **kw)
+ super(SMALLINT, self).__init__(display_width=display_width, **kw)
- def get_col_spec(self):
- if self.display_width is not None:
- return self._extend("SMALLINT(%(display_width)s)" % {'display_width': self.display_width})
- else:
- return self._extend("SMALLINT")
-
-
-class MSBit(sqltypes.TypeEngine):
+class BIT(sqltypes.TypeEngine):
"""MySQL BIT type.
This type is for MySQL 5.0.3 or greater for MyISAM, and 5.0.5 or greater for
"""
+ __visit_name__ = 'BIT'
+
def __init__(self, length=None):
"""Construct a BIT.
return value
return process
- def get_col_spec(self):
- if self.length is not None:
- return "BIT(%s)" % self.length
- else:
- return "BIT"
-
-
-class MSDateTime(sqltypes.DateTime):
- """MySQL DATETIME type."""
-
- def get_col_spec(self):
- return "DATETIME"
-
-
-class MSDate(sqltypes.Date):
- """MySQL DATE type."""
-
- def get_col_spec(self):
- return "DATE"
-
-
-class MSTime(sqltypes.Time):
+class _MSTime(sqltypes.Time):
"""MySQL TIME type."""
- def get_col_spec(self):
- return "TIME"
+ __visit_name__ = 'TIME'
def result_processor(self, dialect):
def process(value):
return None
return process
-class MSTimeStamp(sqltypes.TIMESTAMP):
- """MySQL TIMESTAMP type.
-
- To signal the orm to automatically re-select modified rows to retrieve the
- updated timestamp, add a ``server_default`` to your
- :class:`~sqlalchemy.Column` specification::
-
- from sqlalchemy.databases import mysql
- Column('updated', mysql.MSTimeStamp,
- server_default=sql.text('CURRENT_TIMESTAMP')
- )
-
- The full range of MySQL 4.1+ TIMESTAMP defaults can be specified in
- the the default::
+class TIMESTAMP(sqltypes.TIMESTAMP):
+ """MySQL TIMESTAMP type."""
+ __visit_name__ = 'TIMESTAMP'
- server_default=sql.text('CURRENT TIMESTAMP ON UPDATE CURRENT_TIMESTAMP')
-
- """
-
- def get_col_spec(self):
- return "TIMESTAMP"
-
-
-class MSYear(sqltypes.TypeEngine):
+class YEAR(sqltypes.TypeEngine):
"""MySQL YEAR type, for single byte storage of years 1901-2155."""
+ __visit_name__ = 'YEAR'
+
def __init__(self, display_width=None):
self.display_width = display_width
- def get_col_spec(self):
- if self.display_width is None:
- return "YEAR"
- else:
- return "YEAR(%s)" % self.display_width
-
-class MSText(_StringType, sqltypes.Text):
+class TEXT(_StringType, sqltypes.TEXT):
"""MySQL TEXT type, for text up to 2^16 characters."""
- def __init__(self, length=None, **kwargs):
+ __visit_name__ = 'TEXT'
+
+ def __init__(self, length=None, **kw):
"""Construct a TEXT.
:param length: Optional, if provided the server may optimize storage
only the collation of character data.
"""
- _StringType.__init__(self, **kwargs)
- sqltypes.Text.__init__(self, length,
- kwargs.get('convert_unicode', False), kwargs.get('assert_unicode', None))
+ super(TEXT, self).__init__(length=length, **kw)
- def get_col_spec(self):
- if self.length:
- return self._extend("TEXT(%d)" % self.length)
- else:
- return self._extend("TEXT")
-
-
-class MSTinyText(MSText):
+class TINYTEXT(_StringType):
"""MySQL TINYTEXT type, for text up to 2^8 characters."""
+ __visit_name__ = 'TINYTEXT'
+
def __init__(self, **kwargs):
"""Construct a TINYTEXT.
only the collation of character data.
"""
+ super(TINYTEXT, self).__init__(**kwargs)
- super(MSTinyText, self).__init__(**kwargs)
-
- def get_col_spec(self):
- return self._extend("TINYTEXT")
-
-
-class MSMediumText(MSText):
+class MEDIUMTEXT(_StringType):
"""MySQL MEDIUMTEXT type, for text up to 2^24 characters."""
+ __visit_name__ = 'MEDIUMTEXT'
+
def __init__(self, **kwargs):
"""Construct a MEDIUMTEXT.
only the collation of character data.
"""
- super(MSMediumText, self).__init__(**kwargs)
+ super(MEDIUMTEXT, self).__init__(**kwargs)
- def get_col_spec(self):
- return self._extend("MEDIUMTEXT")
-
-
-class MSLongText(MSText):
+class LONGTEXT(_StringType):
"""MySQL LONGTEXT type, for text up to 2^32 characters."""
+ __visit_name__ = 'LONGTEXT'
+
def __init__(self, **kwargs):
"""Construct a LONGTEXT.
only the collation of character data.
"""
- super(MSLongText, self).__init__(**kwargs)
-
- def get_col_spec(self):
- return self._extend("LONGTEXT")
+ super(LONGTEXT, self).__init__(**kwargs)
-
-class MSString(_StringType, sqltypes.String):
+
+class VARCHAR(_StringType, sqltypes.VARCHAR):
"""MySQL VARCHAR type, for variable-length character data."""
+ __visit_name__ = 'VARCHAR'
+
def __init__(self, length=None, **kwargs):
"""Construct a VARCHAR.
only the collation of character data.
"""
- _StringType.__init__(self, **kwargs)
- sqltypes.String.__init__(self, length,
- kwargs.get('convert_unicode', False), kwargs.get('assert_unicode', None))
-
- def get_col_spec(self):
- if self.length:
- return self._extend("VARCHAR(%d)" % self.length)
- else:
- return self._extend("VARCHAR")
+ super(VARCHAR, self).__init__(length=length, **kwargs)
-
-class MSChar(_StringType, sqltypes.CHAR):
+class CHAR(_StringType, sqltypes.CHAR):
"""MySQL CHAR type, for fixed-length character data."""
+ __visit_name__ = 'CHAR'
+
def __init__(self, length, **kwargs):
- """Construct an NCHAR.
+ """Construct a CHAR.
:param length: Maximum data length, in characters.
compatible with the national character set.
"""
- _StringType.__init__(self, **kwargs)
- sqltypes.CHAR.__init__(self, length,
- kwargs.get('convert_unicode', False))
-
- def get_col_spec(self):
- return self._extend("CHAR(%(length)s)" % {'length' : self.length})
-
+ super(CHAR, self).__init__(length=length, **kwargs)
-class MSNVarChar(_StringType, sqltypes.String):
+class NVARCHAR(_StringType, sqltypes.NVARCHAR):
"""MySQL NVARCHAR type.
For variable-length character data in the server's configured national
character set.
"""
+ __visit_name__ = 'NVARCHAR'
+
def __init__(self, length=None, **kwargs):
"""Construct an NVARCHAR.
"""
kwargs['national'] = True
- _StringType.__init__(self, **kwargs)
- sqltypes.String.__init__(self, length,
- kwargs.get('convert_unicode', False))
-
- def get_col_spec(self):
- # We'll actually generate the equiv. "NATIONAL VARCHAR" instead
- # of "NVARCHAR".
- return self._extend("VARCHAR(%(length)s)" % {'length': self.length})
+ super(NVARCHAR, self).__init__(length=length, **kwargs)
-class MSNChar(_StringType, sqltypes.CHAR):
+class NCHAR(_StringType, sqltypes.NCHAR):
"""MySQL NCHAR type.
For fixed-length character data in the server's configured national
character set.
"""
+ __visit_name__ = 'NCHAR'
+
def __init__(self, length=None, **kwargs):
"""Construct an NCHAR. Arguments are:
"""
kwargs['national'] = True
- _StringType.__init__(self, **kwargs)
- sqltypes.CHAR.__init__(self, length,
- kwargs.get('convert_unicode', False))
- def get_col_spec(self):
- # We'll actually generate the equiv. "NATIONAL CHAR" instead of "NCHAR".
- return self._extend("CHAR(%(length)s)" % {'length': self.length})
-
-
-class _BinaryType(sqltypes.Binary):
- """Base for MySQL binary types."""
+ super(NCHAR, self).__init__(length=length, **kwargs)
- def get_col_spec(self):
- if self.length:
- return "BLOB(%d)" % self.length
- else:
- return "BLOB"
- def result_processor(self, dialect):
- def process(value):
- if value is None:
- return None
- else:
- return util.buffer(value)
- return process
-class MSVarBinary(_BinaryType):
+class VARBINARY(_BinaryType):
"""MySQL VARBINARY type, for variable length binary data."""
+ __visit_name__ = 'VARBINARY'
+
def __init__(self, length=None, **kw):
"""Construct a VARBINARY. Arguments are:
:param length: Maximum data length, in characters.
"""
- super(MSVarBinary, self).__init__(length, **kw)
-
- def get_col_spec(self):
- if self.length:
- return "VARBINARY(%d)" % self.length
- else:
- return "BLOB"
-
+ super(VARBINARY, self).__init__(length=length, **kw)
-class MSBinary(_BinaryType):
+class BINARY(_BinaryType):
"""MySQL BINARY type, for fixed length binary data"""
+ __visit_name__ = 'BINARY'
+
def __init__(self, length=None, **kw):
"""Construct a BINARY.
specified, this will generate a BLOB. This usage is deprecated.
"""
- super(MSBinary, self).__init__(length, **kw)
-
- def get_col_spec(self):
- if self.length:
- return "BINARY(%d)" % self.length
- else:
- return "BLOB"
-
- def result_processor(self, dialect):
- def process(value):
- if value is None:
- return None
- else:
- return util.buffer(value)
- return process
+ super(BINARY, self).__init__(length=length, **kw)
-class MSBlob(_BinaryType):
+class BLOB(_BinaryType, sqltypes.BLOB):
"""MySQL BLOB type, for binary data up to 2^16 bytes"""
+ __visit_name__ = 'BLOB'
+
def __init__(self, length=None, **kw):
"""Construct a BLOB. Arguments are:
``length`` characters.
"""
- super(MSBlob, self).__init__(length, **kw)
-
- def get_col_spec(self):
- if self.length:
- return "BLOB(%d)" % self.length
- else:
- return "BLOB"
-
- def result_processor(self, dialect):
- def process(value):
- if value is None:
- return None
- else:
- return util.buffer(value)
- return process
-
- def __repr__(self):
- return "%s()" % self.__class__.__name__
+ super(BLOB, self).__init__(length=length, **kw)
-class MSTinyBlob(MSBlob):
+class TINYBLOB(_BinaryType):
"""MySQL TINYBLOB type, for binary data up to 2^8 bytes."""
+
+ __visit_name__ = 'TINYBLOB'
- def get_col_spec(self):
- return "TINYBLOB"
-
-
-class MSMediumBlob(MSBlob):
+class MEDIUMBLOB(_BinaryType):
"""MySQL MEDIUMBLOB type, for binary data up to 2^24 bytes."""
- def get_col_spec(self):
- return "MEDIUMBLOB"
-
+ __visit_name__ = 'MEDIUMBLOB'
-class MSLongBlob(MSBlob):
+class LONGBLOB(_BinaryType):
"""MySQL LONGBLOB type, for binary data up to 2^32 bytes."""
- def get_col_spec(self):
- return "LONGBLOB"
-
+ __visit_name__ = 'LONGBLOB'
-class MSEnum(MSString):
+class ENUM(_StringType):
"""MySQL ENUM type."""
+ __visit_name__ = 'ENUM'
+
def __init__(self, *enums, **kw):
"""Construct an ENUM.
self.strict = kw.pop('strict', False)
length = max([len(v) for v in self.enums] + [0])
- super(MSEnum, self).__init__(length, **kw)
+ super(ENUM, self).__init__(length=length, **kw)
def bind_processor(self, dialect):
- super_convert = super(MSEnum, self).bind_processor(dialect)
+ super_convert = super(ENUM, self).bind_processor(dialect)
def process(value):
if self.strict and value is not None and value not in self.enums:
raise exc.InvalidRequestError('"%s" not a valid value for '
return value
return process
- def get_col_spec(self):
- quoted_enums = []
- for e in self.enums:
- quoted_enums.append("'%s'" % e.replace("'", "''"))
- return self._extend("ENUM(%s)" % ",".join(quoted_enums))
-
-class MSSet(MSString):
+class SET(_StringType):
"""MySQL SET type."""
+ __visit_name__ = 'SET'
+
def __init__(self, *values, **kw):
"""Construct a SET.
only the collation of character data.
"""
- self.__ddl_values = values
+ self._ddl_values = values
strip_values = []
for a in values:
self.values = strip_values
length = max([len(v) for v in strip_values] + [0])
- super(MSSet, self).__init__(length, **kw)
+ super(SET, self).__init__(length=length, **kw)
def result_processor(self, dialect):
def process(value):
return process
def bind_processor(self, dialect):
- super_convert = super(MSSet, self).bind_processor(dialect)
+ super_convert = super(SET, self).bind_processor(dialect)
def process(value):
if value is None or isinstance(value, (int, long, basestring)):
pass
return value
return process
- def get_col_spec(self):
- return self._extend("SET(%s)" % ",".join(self.__ddl_values))
-
-
-class MSBoolean(sqltypes.Boolean):
+class _MSBoolean(sqltypes.Boolean):
"""MySQL BOOLEAN type."""
- def get_col_spec(self):
- return "BOOL"
+ __visit_name__ = 'BOOLEAN'
def result_processor(self, dialect):
def process(value):
return value and True or False
return process
+# old names
+MSBoolean = _MSBoolean
+MSTime = _MSTime
+MSSet = SET
+MSEnum = ENUM
+MSLongBlob = LONGBLOB
+MSMediumBlob = MEDIUMBLOB
+MSTinyBlob = TINYBLOB
+MSBlob = BLOB
+MSBinary = BINARY
+MSVarBinary = VARBINARY
+MSNChar = NCHAR
+MSNVarChar = NVARCHAR
+MSChar = CHAR
+MSString = VARCHAR
+MSLongText = LONGTEXT
+MSMediumText = MEDIUMTEXT
+MSTinyText = TINYTEXT
+MSText = TEXT
+MSYear = YEAR
+MSTimeStamp = TIMESTAMP
+MSBit = BIT
+MSSmallInteger = SMALLINT
+MSTinyInteger = TINYINT
+MSMediumInteger = MEDIUMINT
+MSBigInteger = BIGINT
+MSNumeric = NUMERIC
+MSDecimal = DECIMAL
+MSDouble = DOUBLE
+MSReal = REAL
+MSFloat = FLOAT
+MSInteger = INTEGER
+
colspecs = {
- sqltypes.Integer: MSInteger,
- sqltypes.Smallinteger: MSSmallInteger,
- sqltypes.Numeric: MSNumeric,
- sqltypes.Float: MSFloat,
- sqltypes.DateTime: MSDateTime,
- sqltypes.Date: MSDate,
- sqltypes.Time: MSTime,
- sqltypes.String: MSString,
- sqltypes.Binary: MSBlob,
- sqltypes.Boolean: MSBoolean,
- sqltypes.Text: MSText,
- sqltypes.CHAR: MSChar,
- sqltypes.NCHAR: MSNChar,
- sqltypes.TIMESTAMP: MSTimeStamp,
- sqltypes.BLOB: MSBlob,
- MSDouble: MSDouble,
- MSReal: MSReal,
- _BinaryType: _BinaryType,
+ sqltypes.Numeric: NUMERIC,
+ sqltypes.Float: FLOAT,
+ sqltypes.Binary: _BinaryType,
+ sqltypes.Boolean: _MSBoolean,
+ sqltypes.Time: _MSTime,
}
# Everything 3.23 through 5.1 excepting OpenGIS types.
ischema_names = {
- 'bigint': MSBigInteger,
- 'binary': MSBinary,
- 'bit': MSBit,
- 'blob': MSBlob,
- 'boolean':MSBoolean,
- 'char': MSChar,
- 'date': MSDate,
- 'datetime': MSDateTime,
- 'decimal': MSDecimal,
- 'double': MSDouble,
- 'enum': MSEnum,
- 'fixed': MSDecimal,
- 'float': MSFloat,
- 'int': MSInteger,
- 'integer': MSInteger,
- 'longblob': MSLongBlob,
- 'longtext': MSLongText,
- 'mediumblob': MSMediumBlob,
- 'mediumint': MSMediumInteger,
- 'mediumtext': MSMediumText,
- 'nchar': MSNChar,
- 'nvarchar': MSNVarChar,
- 'numeric': MSNumeric,
- 'set': MSSet,
- 'smallint': MSSmallInteger,
- 'text': MSText,
- 'time': MSTime,
- 'timestamp': MSTimeStamp,
- 'tinyblob': MSTinyBlob,
- 'tinyint': MSTinyInteger,
- 'tinytext': MSTinyText,
- 'varbinary': MSVarBinary,
- 'varchar': MSString,
- 'year': MSYear,
+ 'bigint': BIGINT,
+ 'binary': BINARY,
+ 'bit': BIT,
+ 'blob': BLOB,
+ 'boolean':BOOLEAN,
+ 'char': CHAR,
+ 'date': DATE,
+ 'datetime': DATETIME,
+ 'decimal': DECIMAL,
+ 'double': DOUBLE,
+ 'enum': ENUM,
+ 'fixed': DECIMAL,
+ 'float': FLOAT,
+ 'int': INTEGER,
+ 'integer': INTEGER,
+ 'longblob': LONGBLOB,
+ 'longtext': LONGTEXT,
+ 'mediumblob': MEDIUMBLOB,
+ 'mediumint': MEDIUMINT,
+ 'mediumtext': MEDIUMTEXT,
+ 'nchar': NCHAR,
+ 'nvarchar': NVARCHAR,
+ 'numeric': NUMERIC,
+ 'set': SET,
+ 'smallint': SMALLINT,
+ 'text': TEXT,
+ 'time': TIME,
+ 'timestamp': TIMESTAMP,
+ 'tinyblob': TINYBLOB,
+ 'tinyint': TINYINT,
+ 'tinytext': TINYTEXT,
+ 'varbinary': VARBINARY,
+ 'varchar': VARCHAR,
+ 'year': YEAR,
}
-
class MySQLExecutionContext(default.DefaultExecutionContext):
def post_exec(self):
- if self.compiled.isinsert and not self.executemany:
- if (not len(self._last_inserted_ids) or
- self._last_inserted_ids[0] is None):
- self._last_inserted_ids = ([self.cursor.lastrowid] +
- self._last_inserted_ids[1:])
- elif (not self.isupdate and not self.should_autocommit and
+ # TODO: i think this 'charset' in the info thing
+ # is out
+
+ if (not self.isupdate and not self.should_autocommit and
self.statement and SET_RE.match(self.statement)):
# This misses if a user forces autocommit on text('SET NAMES'),
# which is probably a programming error anyhow.
def should_autocommit_text(self, statement):
return AUTOCOMMIT_RE.match(statement)
+class MySQLCompiler(compiler.SQLCompiler):
-class MySQLDialect(default.DefaultDialect):
- """Details of the MySQL dialect. Not used directly in application code."""
- name = 'mysql'
- supports_alter = True
- supports_unicode_statements = False
- # identifiers are 64, however aliases can be 255...
- max_identifier_length = 255
- supports_sane_rowcount = True
- default_paramstyle = 'format'
-
- def __init__(self, use_ansiquotes=None, **kwargs):
- self.use_ansiquotes = use_ansiquotes
- default.DefaultDialect.__init__(self, **kwargs)
+ extract_map = compiler.SQLCompiler.extract_map.copy()
+ extract_map.update ({
+ 'milliseconds': 'millisecond',
+ })
+
+ def visit_random_func(self, fn, **kw):
+ return "rand%s" % self.function_argspec(fn)
+
+ def visit_utc_timestamp_func(self, fn, **kw):
+ return "UTC_TIMESTAMP"
+
+ def visit_concat_op(self, binary, **kw):
+ return "concat(%s, %s)" % (self.process(binary.left), self.process(binary.right))
+
+ def visit_match_op(self, binary, **kw):
+ return "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % (self.process(binary.left), self.process(binary.right))
+
+ def visit_typeclause(self, typeclause):
+ type_ = typeclause.type.dialect_impl(self.dialect)
+ if isinstance(type_, sqltypes.Integer):
+ if getattr(type_, 'unsigned', False):
+ return 'UNSIGNED INTEGER'
+ else:
+ return 'SIGNED INTEGER'
+ elif isinstance(type_, sqltypes.TIMESTAMP):
+ return 'DATETIME'
+ elif isinstance(type_, (sqltypes.DECIMAL, sqltypes.DateTime, sqltypes.Date, sqltypes.Time)):
+ return self.dialect.type_compiler.process(type_)
+ elif isinstance(type_, sqltypes.Text):
+ return 'CHAR'
+ elif (isinstance(type_, sqltypes.String) and not
+ isinstance(type_, (ENUM, SET))):
+ if getattr(type_, 'length'):
+ return 'CHAR(%s)' % type_.length
+ else:
+ return 'CHAR'
+ elif isinstance(type_, sqltypes.Binary):
+ return 'BINARY'
+ elif isinstance(type_, NUMERIC):
+ return self.dialect.type_compiler.process(type_).replace('NUMERIC', 'DECIMAL')
+ else:
+ return None
- def dbapi(cls):
- import MySQLdb as mysql
- return mysql
- dbapi = classmethod(dbapi)
-
- def create_connect_args(self, url):
- opts = url.translate_connect_args(database='db', username='user',
- password='passwd')
- opts.update(url.query)
-
- util.coerce_kw_type(opts, 'compress', bool)
- util.coerce_kw_type(opts, 'connect_timeout', int)
- util.coerce_kw_type(opts, 'client_flag', int)
- util.coerce_kw_type(opts, 'local_infile', int)
- # Note: using either of the below will cause all strings to be returned
- # as Unicode, both in raw SQL operations and with column types like
- # String and MSString.
- util.coerce_kw_type(opts, 'use_unicode', bool)
- util.coerce_kw_type(opts, 'charset', str)
-
- # Rich values 'cursorclass' and 'conv' are not supported via
- # query string.
-
- ssl = {}
- for key in ['ssl_ca', 'ssl_key', 'ssl_cert', 'ssl_capath', 'ssl_cipher']:
- if key in opts:
- ssl[key[4:]] = opts[key]
- util.coerce_kw_type(ssl, key[4:], str)
- del opts[key]
- if ssl:
- opts['ssl'] = ssl
-
- # FOUND_ROWS must be set in CLIENT_FLAGS to enable
- # supports_sane_rowcount.
- client_flag = opts.get('client_flag', 0)
- if self.dbapi is not None:
- try:
- import MySQLdb.constants.CLIENT as CLIENT_FLAGS
- client_flag |= CLIENT_FLAGS.FOUND_ROWS
- except:
- pass
- opts['client_flag'] = client_flag
- return [[], opts]
+ def visit_cast(self, cast, **kwargs):
+ # No cast until 4, no decimals until 5.
+ type_ = self.process(cast.typeclause)
+ if type_ is None:
+ return self.process(cast.clause)
- def type_descriptor(self, typeobj):
- return sqltypes.adapt_type(typeobj, colspecs)
+ return 'CAST(%s AS %s)' % (self.process(cast.clause), type_)
- def do_executemany(self, cursor, statement, parameters, context=None):
- rowcount = cursor.executemany(statement, parameters)
- if context is not None:
- context._rowcount = rowcount
+ def get_select_precolumns(self, select):
+ if isinstance(select._distinct, basestring):
+ return select._distinct.upper() + " "
+ elif select._distinct:
+ return "DISTINCT "
+ else:
+ return ""
+
+ def visit_join(self, join, asfrom=False, **kwargs):
+ # 'JOIN ... ON ...' for inner joins isn't available until 4.0.
+ # Apparently < 3.23.17 requires theta joins for inner joins
+ # (but not outer). Not generating these currently, but
+ # support can be added, preferably after dialects are
+ # refactored to be version-sensitive.
+ return ''.join(
+ (self.process(join.left, asfrom=True),
+ (join.isouter and " LEFT OUTER JOIN " or " INNER JOIN "),
+ self.process(join.right, asfrom=True),
+ " ON ",
+ self.process(join.onclause)))
+
+ def for_update_clause(self, select):
+ if select.for_update == 'read':
+ return ' LOCK IN SHARE MODE'
+ else:
+ return super(MySQLCompiler, self).for_update_clause(select)
+
+ def limit_clause(self, select):
+ # MySQL supports:
+ # LIMIT <limit>
+ # LIMIT <offset>, <limit>
+ # and in server versions > 3.3:
+ # LIMIT <limit> OFFSET <offset>
+ # The latter is more readable for offsets but we're stuck with the
+ # former until we can refine dialects by server revision.
+
+ limit, offset = select._limit, select._offset
+
+ if (limit, offset) == (None, None):
+ return ''
+ elif offset is not None:
+ # As suggested by the MySQL docs, need to apply an
+ # artificial limit if one wasn't provided
+ if limit is None:
+ limit = 18446744073709551615
+ return ' \n LIMIT %s, %s' % (offset, limit)
+ else:
+ # No offset provided, so just use the limit
+ return ' \n LIMIT %s' % (limit,)
+
+ def visit_update(self, update_stmt):
+ self.stack.append({'from': set([update_stmt.table])})
+
+ self.isupdate = True
+ colparams = self._get_colparams(update_stmt)
+
+ text = "UPDATE " + self.preparer.format_table(update_stmt.table) + \
+ " SET " + ', '.join(["%s=%s" % (self.preparer.format_column(c[0]), c[1]) for c in colparams])
+
+ if update_stmt._whereclause:
+ text += " WHERE " + self.process(update_stmt._whereclause)
+
+ limit = update_stmt.kwargs.get('mysql_limit', None)
+ if limit:
+ text += " LIMIT %s" % limit
+
+ self.stack.pop(-1)
+
+ return text
+
+# ug. "InnoDB needs indexes on foreign keys and referenced keys [...].
+# Starting with MySQL 4.1.2, these indexes are created automatically.
+# In older versions, the indexes must be created explicitly or the
+# creation of foreign key constraints fails."
+
+class MySQLDDLCompiler(compiler.DDLCompiler):
+ def get_column_specification(self, column, **kw):
+ """Builds column DDL."""
+
+ colspec = [self.preparer.format_column(column),
+ self.dialect.type_compiler.process(column.type)
+ ]
+
+ default = self.get_column_default_string(column)
+ if default is not None:
+ colspec.append('DEFAULT ' + default)
+
+ if not column.nullable:
+ colspec.append('NOT NULL')
+
+ if column.primary_key and column.autoincrement:
+ try:
+ first = [c for c in column.table.primary_key.columns
+ if (c.autoincrement and
+ isinstance(c.type, sqltypes.Integer) and
+ not c.foreign_keys)].pop(0)
+ if column is first:
+ colspec.append('AUTO_INCREMENT')
+ except IndexError:
+ pass
+
+ return ' '.join(colspec)
- def supports_unicode_statements(self):
- return True
+ def post_create_table(self, table):
+ """Build table-level CREATE options like ENGINE and COLLATE."""
+
+ table_opts = []
+ for k in table.kwargs:
+ if k.startswith('mysql_'):
+ opt = k[6:].upper()
+ joiner = '='
+ if opt in ('TABLESPACE', 'DEFAULT CHARACTER SET',
+ 'CHARACTER SET', 'COLLATE'):
+ joiner = ' '
+
+ table_opts.append(joiner.join((opt, table.kwargs[k])))
+ return ' '.join(table_opts)
+
+ def visit_drop_index(self, drop):
+ index = drop.element
+
+ return "\nDROP INDEX %s ON %s" % \
+ (self.preparer.quote(self._validate_identifier(index.name, False), index.quote),
+ self.preparer.format_table(index.table))
+
+ def visit_drop_constraint(self, drop):
+ constraint = drop.element
+ if isinstance(constraint, sa_schema.ForeignKeyConstraint):
+ qual = "FOREIGN KEY "
+ const = self.preparer.format_constraint(constraint)
+ elif isinstance(constraint, sa_schema.PrimaryKeyConstraint):
+ qual = "PRIMARY KEY "
+ const = ""
+ elif isinstance(constraint, sa_schema.UniqueConstraint):
+ qual = "INDEX "
+ const = self.preparer.format_constraint(constraint)
+ else:
+ qual = ""
+ const = self.preparer.format_constraint(constraint)
+ return "ALTER TABLE %s DROP %s%s" % \
+ (self.preparer.format_table(constraint.table),
+ qual, const)
+
+class MySQLTypeCompiler(compiler.GenericTypeCompiler):
+ def _extend_numeric(self, type_, spec):
+ "Extend a numeric-type declaration with MySQL specific extensions."
+
+ if not self._mysql_type(type_):
+ return spec
+
+ if type_.unsigned:
+ spec += ' UNSIGNED'
+ if type_.zerofill:
+ spec += ' ZEROFILL'
+ return spec
+
+ def _extend_string(self, type_, defaults, spec):
+ """Extend a string-type declaration with standard SQL CHARACTER SET /
+ COLLATE annotations and MySQL specific extensions.
+
+ """
+
+ def attr(name):
+ return getattr(type_, name, defaults.get(name))
+
+ if attr('charset'):
+ charset = 'CHARACTER SET %s' % attr('charset')
+ elif attr('ascii'):
+ charset = 'ASCII'
+ elif attr('unicode'):
+ charset = 'UNICODE'
+ else:
+ charset = None
+
+ if attr('collation'):
+ collation = 'COLLATE %s' % type_.collation
+ elif attr('binary'):
+ collation = 'BINARY'
+ else:
+ collation = None
+
+ if attr('national'):
+ # NATIONAL (aka NCHAR/NVARCHAR) trumps charsets.
+ return ' '.join([c for c in ('NATIONAL', spec, collation)
+ if c is not None])
+ return ' '.join([c for c in (spec, charset, collation)
+ if c is not None])
+
+ def _mysql_type(self, type_):
+ return isinstance(type_, (_StringType, _NumericType, _BinaryType))
+
+ def visit_NUMERIC(self, type_):
+ if type_.precision is None:
+ return self._extend_numeric(type_, "NUMERIC")
+ elif type_.scale is None:
+ return self._extend_numeric(type_, "NUMERIC(%(precision)s)" % {'precision': type_.precision})
+ else:
+ return self._extend_numeric(type_, "NUMERIC(%(precision)s, %(scale)s)" % {'precision': type_.precision, 'scale' : type_.scale})
+
+ def visit_DECIMAL(self, type_):
+ if type_.precision is None:
+ return self._extend_numeric(type_, "DECIMAL")
+ elif type_.scale is None:
+ return self._extend_numeric(type_, "DECIMAL(%(precision)s)" % {'precision': type_.precision})
+ else:
+ return self._extend_numeric(type_, "DECIMAL(%(precision)s, %(scale)s)" % {'precision': type_.precision, 'scale' : type_.scale})
+
+ def visit_DOUBLE(self, type_):
+ if type_.precision is not None and type_.scale is not None:
+ return self._extend_numeric(type_, "DOUBLE(%(precision)s, %(scale)s)" %
+ {'precision': type_.precision,
+ 'scale' : type_.scale})
+ else:
+ return self._extend_numeric(type_, 'DOUBLE')
+
+ def visit_REAL(self, type_):
+ if type_.precision is not None and type_.scale is not None:
+ return self._extend_numeric(type_, "REAL(%(precision)s, %(scale)s)" %
+ {'precision': type_.precision,
+ 'scale' : type_.scale})
+ else:
+ return self._extend_numeric(type_, 'REAL')
+
+ def visit_FLOAT(self, type_):
+ if self._mysql_type(type_) and type_.scale is not None and type_.precision is not None:
+ return self._extend_numeric(type_, "FLOAT(%s, %s)" % (type_.precision, type_.scale))
+ elif type_.precision is not None:
+ return self._extend_numeric(type_, "FLOAT(%s)" % (type_.precision,))
+ else:
+ return self._extend_numeric(type_, "FLOAT")
+
+ def visit_INTEGER(self, type_):
+ if self._mysql_type(type_) and type_.display_width is not None:
+ return self._extend_numeric(type_, "INTEGER(%(display_width)s)" % {'display_width': type_.display_width})
+ else:
+ return self._extend_numeric(type_, "INTEGER")
+
+ def visit_BIGINT(self, type_):
+ if self._mysql_type(type_) and type_.display_width is not None:
+ return self._extend_numeric(type_, "BIGINT(%(display_width)s)" % {'display_width': type_.display_width})
+ else:
+ return self._extend_numeric(type_, "BIGINT")
+
+ def visit_MEDIUMINT(self, type_):
+ if self._mysql_type(type_) and type_.display_width is not None:
+ return self._extend_numeric(type_, "MEDIUMINT(%(display_width)s)" % {'display_width': type_.display_width})
+ else:
+ return self._extend_numeric(type_, "MEDIUMINT")
+
+ def visit_TINYINT(self, type_):
+ if self._mysql_type(type_) and type_.display_width is not None:
+ return self._extend_numeric(type_, "TINYINT(%s)" % type_.display_width)
+ else:
+ return self._extend_numeric(type_, "TINYINT")
+
+ def visit_SMALLINT(self, type_):
+ if self._mysql_type(type_) and type_.display_width is not None:
+ return self._extend_numeric(type_, "SMALLINT(%(display_width)s)" % {'display_width': type_.display_width})
+ else:
+ return self._extend_numeric(type_, "SMALLINT")
+
+ def visit_BIT(self, type_):
+ if type_.length is not None:
+ return "BIT(%s)" % type_.length
+ else:
+ return "BIT"
+
+ def visit_DATETIME(self, type_):
+ return "DATETIME"
+
+ def visit_DATE(self, type_):
+ return "DATE"
+
+ def visit_TIME(self, type_):
+ return "TIME"
+
+ def visit_TIMESTAMP(self, type_):
+ return 'TIMESTAMP'
+
+ def visit_YEAR(self, type_):
+ if type_.display_width is None:
+ return "YEAR"
+ else:
+ return "YEAR(%s)" % type_.display_width
+
+ def visit_TEXT(self, type_):
+ if type_.length:
+ return self._extend_string(type_, {}, "TEXT(%d)" % type_.length)
+ else:
+ return self._extend_string(type_, {}, "TEXT")
+
+ def visit_TINYTEXT(self, type_):
+ return self._extend_string(type_, {}, "TINYTEXT")
+
+ def visit_MEDIUMTEXT(self, type_):
+ return self._extend_string(type_, {}, "MEDIUMTEXT")
+
+ def visit_LONGTEXT(self, type_):
+ return self._extend_string(type_, {}, "LONGTEXT")
+
+ def visit_VARCHAR(self, type_):
+ if type_.length:
+ return self._extend_string(type_, {}, "VARCHAR(%d)" % type_.length)
+ else:
+ return self._extend_string(type_, {}, "VARCHAR")
+
+ def visit_CHAR(self, type_):
+ return self._extend_string(type_, {'national':True}, "CHAR(%(length)s)" % {'length' : type_.length})
+
+ def visit_NVARCHAR(self, type_):
+ # We'll actually generate the equiv. "NATIONAL VARCHAR" instead
+ # of "NVARCHAR".
+ return self._extend_string(type_, {'national':True}, "VARCHAR(%(length)s)" % {'length': type_.length})
+
+ def visit_NCHAR(self, type_):
+ # We'll actually generate the equiv. "NATIONAL CHAR" instead of "NCHAR".
+ return self._extend_string(type_, {'national':True}, "CHAR(%(length)s)" % {'length': type_.length})
+
+ def visit_VARBINARY(self, type_):
+ if type_.length:
+ return "VARBINARY(%d)" % type_.length
+ else:
+ return self.visit_BLOB(type_)
+
+ def visit_binary(self, type_):
+ return self.visit_BLOB(type_)
+
+ def visit_BINARY(self, type_):
+ if type_.length:
+ return "BINARY(%d)" % type_.length
+ else:
+ return self.visit_BLOB(type_)
+
+ def visit_BLOB(self, type_):
+ if type_.length:
+ return "BLOB(%d)" % type_.length
+ else:
+ return "BLOB"
+
+ def visit_TINYBLOB(self, type_):
+ return "TINYBLOB"
+
+ def visit_MEDIUMBLOB(self, type_):
+ return "MEDIUMBLOB"
+
+ def visit_LONGBLOB(self, type_):
+ return "LONGBLOB"
+
+ def visit_ENUM(self, type_):
+ quoted_enums = []
+ for e in type_.enums:
+ quoted_enums.append("'%s'" % e.replace("'", "''"))
+ return self._extend_string(type_, {}, "ENUM(%s)" % ",".join(quoted_enums))
+
+ def visit_SET(self, type_):
+ return self._extend_string(type_, {}, "SET(%s)" % ",".join(type_._ddl_values))
+
+ def visit_BOOLEAN(self, type):
+ return "BOOL"
+
+
+class MySQLDialect(default.DefaultDialect):
+ """Details of the MySQL dialect. Not used directly in application code."""
+ name = 'mysql'
+ supports_alter = True
+ # identifiers are 64, however aliases can be 255...
+ max_identifier_length = 255
+
+ supports_sane_rowcount = True
+ supports_sane_multi_rowcount = False
+
+ default_paramstyle = 'format'
+ colspecs = colspecs
+
+ statement_compiler = MySQLCompiler
+ ddl_compiler = MySQLDDLCompiler
+ type_compiler = MySQLTypeCompiler
+ ischema_names = ischema_names
+
+ def __init__(self, use_ansiquotes=None, **kwargs):
+ default.DefaultDialect.__init__(self, **kwargs)
def do_commit(self, connection):
"""Execute a COMMIT."""
try:
connection.commit()
except:
- if self._server_version_info(connection) < (3, 23, 15):
+ if self.server_version_info < (3, 23, 15):
args = sys.exc_info()[1].args
if args and args[0] == 1064:
return
try:
connection.rollback()
except:
- if self._server_version_info(connection) < (3, 23, 15):
+ if self.server_version_info < (3, 23, 15):
args = sys.exc_info()[1].args
if args and args[0] == 1064:
return
raise
def do_begin_twophase(self, connection, xid):
- connection.execute("XA BEGIN %s", xid)
+ connection.execute(sql.text("XA BEGIN :xid"), xid=xid)
def do_prepare_twophase(self, connection, xid):
- connection.execute("XA END %s", xid)
- connection.execute("XA PREPARE %s", xid)
+ connection.execute(sql.text("XA END :xid"), xid=xid)
+ connection.execute(sql.text("XA PREPARE :xid"), xid=xid)
def do_rollback_twophase(self, connection, xid, is_prepared=True,
recover=False):
if not is_prepared:
- connection.execute("XA END %s", xid)
- connection.execute("XA ROLLBACK %s", xid)
+ connection.execute(sql.text("XA END :xid"), xid=xid)
+ connection.execute(sql.text("XA ROLLBACK :xid"), xid=xid)
def do_commit_twophase(self, connection, xid, is_prepared=True,
recover=False):
if not is_prepared:
self.do_prepare_twophase(connection, xid)
- connection.execute("XA COMMIT %s", xid)
+ connection.execute(sql.text("XA COMMIT :xid"), xid=xid)
def do_recover_twophase(self, connection):
resultset = connection.execute("XA RECOVER")
return [row['data'][0:row['gtrid_length']] for row in resultset]
- def do_ping(self, connection):
- connection.ping()
-
def is_disconnect(self, e):
if isinstance(e, self.dbapi.OperationalError):
- return e.args[0] in (2006, 2013, 2014, 2045, 2055)
+ return self._extract_error_code(e) in (2006, 2013, 2014, 2045, 2055)
elif isinstance(e, self.dbapi.InterfaceError): # if underlying connection is closed, this is the error you get
return "(0, '')" in str(e)
else:
return False
+ def _compat_fetchall(self, rp, charset=None):
+ """Proxy result rows to smooth over MySQL-Python driver inconsistencies."""
+
+ return [_DecodingRowProxy(row, charset) for row in rp.fetchall()]
+
+ def _compat_fetchone(self, rp, charset=None):
+ """Proxy a result row to smooth over MySQL-Python driver inconsistencies."""
+
+ return _DecodingRowProxy(rp.fetchone(), charset)
+
+ def _extract_error_code(self, exception):
+ raise NotImplementedError()
+
def get_default_schema_name(self, connection):
return connection.execute('SELECT DATABASE()').scalar()
- get_default_schema_name = engine_base.connection_memoize(
- ('dialect', 'default_schema_name'))(get_default_schema_name)
def table_names(self, connection, schema):
"""Return a Unicode SHOW TABLES from a given schema."""
- charset = self._detect_charset(connection)
- self._autoset_identifier_style(connection)
+ charset = self._connection_charset
rp = connection.execute("SHOW TABLES FROM %s" %
self.identifier_preparer.quote_identifier(schema))
- return [row[0] for row in _compat_fetchall(rp, charset=charset)]
+ return [row[0] for row in self._compat_fetchall(rp, charset=charset)]
def has_table(self, connection, table_name, schema=None):
# SHOW TABLE STATUS LIKE and SHOW TABLES LIKE do not function properly
# full_name = self.identifier_preparer.format_table(table,
# use_schema=True)
- self._autoset_identifier_style(connection)
full_name = '.'.join(self.identifier_preparer._quote_free_identifiers(
schema, table_name))
rs.close()
return have
except exc.SQLError, e:
- if e.orig.args[0] == 1146:
+ if self._extract_error_code(e) == 1146:
return False
raise
finally:
if rs:
rs.close()
+
+ def initialize(self, connection):
+ self.server_version_info = self._get_server_version_info(connection)
+ self._connection_charset = self._detect_charset(connection)
+ self._server_casing = self._detect_casing(connection)
+ self._server_collations = self._detect_collations(connection)
+ self._server_ansiquotes = self._detect_ansiquotes(connection)
+
+ if self._server_ansiquotes:
+ self.preparer = MySQLANSIIdentifierPreparer
+ else:
+ self.preparer = MySQLIdentifierPreparer
+ self.identifier_preparer = self.preparer(self)
+
+ @reflection.cache
+ def get_schema_names(self, connection, **kw):
+ rp = connection.execute("SHOW schemas")
+ return [r[0] for r in rp]
+
+ @reflection.cache
+ def get_table_names(self, connection, schema=None, **kw):
+ if schema is None:
+ schema = self.get_default_schema_name(connection)
+ if self.server_version_info < (5, 0, 2):
+ return self.table_names(connection, schema)
+ charset = self._connection_charset
+ rp = connection.execute("SHOW FULL TABLES FROM %s" %
+ self.identifier_preparer.quote_identifier(schema))
+
+ return [row[0] for row in self._compat_fetchall(rp, charset=charset)\
+ if row[1] == 'BASE TABLE']
+
+ @reflection.cache
+ def get_view_names(self, connection, schema=None, **kw):
+ charset = self._connection_charset
+ if self.server_version_info < (5, 0, 2):
+ raise NotImplementedError
+ if schema is None:
+ schema = self.get_default_schema_name(connection)
+ if self.server_version_info < (5, 0, 2):
+ return self.table_names(connection, schema)
+ charset = self._connection_charset
+ rp = connection.execute("SHOW FULL TABLES FROM %s" %
+ self.identifier_preparer.quote_identifier(schema))
+ return [row[0] for row in self._compat_fetchall(rp, charset=charset)\
+ if row[1] == 'VIEW']
+
+ @reflection.cache
+ def get_table_options(self, connection, table_name, schema=None, **kw):
+
+ parsed_state = self._parsed_state_or_create(connection, table_name, schema, **kw)
+ return parsed_state.table_options
+
+ @reflection.cache
+ def get_columns(self, connection, table_name, schema=None, **kw):
+ parsed_state = self._parsed_state_or_create(connection, table_name, schema, **kw)
+ return parsed_state.columns
+
+ @reflection.cache
+ def get_primary_keys(self, connection, table_name, schema=None, **kw):
+ parsed_state = self._parsed_state_or_create(connection, table_name, schema, **kw)
+ for key in parsed_state.keys:
+ if key['type'] == 'PRIMARY':
+ # There can be only one.
+ ##raise Exception, str(key)
+ return [s[0] for s in key['columns']]
+ return []
+
+ @reflection.cache
+ def get_foreign_keys(self, connection, table_name, schema=None, **kw):
+
+ parsed_state = self._parsed_state_or_create(connection, table_name, schema, **kw)
+ default_schema = None
+
+ fkeys = []
- def server_version_info(self, connection):
- """A tuple of the database server version.
-
- Formats the remote server version as a tuple of version values,
- e.g. ``(5, 0, 44)``. If there are strings in the version number
- they will be in the tuple too, so don't count on these all being
- ``int`` values.
-
- This is a fast check that does not require a round trip. It is also
- cached per-Connection.
- """
-
- return self._server_version_info(connection.connection.connection)
- server_version_info = engine_base.connection_memoize(
- ('mysql', 'server_version_info'))(server_version_info)
+ for spec in parsed_state.constraints:
+ # only FOREIGN KEYs
+ ref_name = spec['table'][-1]
+ ref_schema = len(spec['table']) > 1 and spec['table'][-2] or schema
- def _server_version_info(self, dbapi_con):
- """Convert a MySQL-python server_info string into a tuple."""
+ if not ref_schema:
+ if default_schema is None:
+ default_schema = \
+ connection.dialect.get_default_schema_name(connection)
+ if schema == default_schema:
+ ref_schema = schema
- version = []
- r = re.compile('[.\-]')
- for n in r.split(dbapi_con.get_server_info()):
- try:
- version.append(int(n))
- except ValueError:
- version.append(n)
- return tuple(version)
+ loc_names = spec['local']
+ ref_names = spec['foreign']
- def reflecttable(self, connection, table, include_columns):
- """Load column definitions from the server."""
+ con_kw = {}
+ for opt in ('name', 'onupdate', 'ondelete'):
+ if spec.get(opt, False):
+ con_kw[opt] = spec[opt]
- charset = self._detect_charset(connection)
- self._autoset_identifier_style(connection)
+ fkey_d = {
+ 'name' : spec['name'],
+ 'constrained_columns' : loc_names,
+ 'referred_schema' : ref_schema,
+ 'referred_table' : ref_name,
+ 'referred_columns' : ref_names,
+ 'options' : con_kw
+ }
+ fkeys.append(fkey_d)
+ return fkeys
+
+ @reflection.cache
+ def get_indexes(self, connection, table_name, schema=None, **kw):
+
+ parsed_state = self._parsed_state_or_create(connection, table_name, schema, **kw)
+
+ indexes = []
+ for spec in parsed_state.keys:
+ unique = False
+ flavor = spec['type']
+ if flavor == 'PRIMARY':
+ continue
+ if flavor == 'UNIQUE':
+ unique = True
+ elif flavor in (None, 'FULLTEXT', 'SPATIAL'):
+ pass
+ else:
+ self.logger.info(
+ "Converting unknown KEY type %s to a plain KEY" % flavor)
+ pass
+ index_d = {}
+ index_d['name'] = spec['name']
+ index_d['column_names'] = [s[0] for s in spec['columns']]
+ index_d['unique'] = unique
+ index_d['type'] = flavor
+ indexes.append(index_d)
+ return indexes
+
+ @reflection.cache
+ def get_view_definition(self, connection, view_name, schema=None, **kw):
+
+ charset = self._connection_charset
+ full_name = '.'.join(self.identifier_preparer._quote_free_identifiers(
+ schema, view_name))
+ sql = self._show_create_table(connection, None, charset,
+ full_name=full_name)
+ return sql
+ def _parsed_state_or_create(self, connection, table_name, schema=None, **kw):
+ return self._setup_parser(
+ connection,
+ table_name,
+ schema,
+ info_cache=kw.get('info_cache', None)
+ )
+
+ @reflection.cache
+ def _setup_parser(self, connection, table_name, schema=None, **kw):
+ charset = self._connection_charset
try:
- reflector = self.reflector
+ parser = self.parser
except AttributeError:
preparer = self.identifier_preparer
- if (self.server_version_info(connection) < (4, 1) and
- self.use_ansiquotes):
+ if (self.server_version_info < (4, 1) and
+ self._server_ansiquotes):
# ANSI_QUOTES doesn't affect SHOW CREATE TABLE on < 4.1
preparer = MySQLIdentifierPreparer(self)
-
- self.reflector = reflector = MySQLSchemaReflector(preparer)
-
- sql = self._show_create_table(connection, table, charset)
+ self.parser = parser = MySQLTableDefinitionParser(self, preparer)
+ full_name = '.'.join(self.identifier_preparer._quote_free_identifiers(
+ schema, table_name))
+ sql = self._show_create_table(connection, None, charset,
+ full_name=full_name)
if sql.startswith('CREATE ALGORITHM'):
# Adapt views to something table-like.
- columns = self._describe_table(connection, table, charset)
- sql = reflector._describe_to_create(table, columns)
-
- self._adjust_casing(connection, table)
-
- return reflector.reflect(connection, table, sql, charset,
- only=include_columns)
-
- def _adjust_casing(self, connection, table, charset=None):
+ columns = self._describe_table(connection, None, charset,
+ full_name=full_name)
+ sql = parser._describe_to_create(table_name, columns)
+ return parser.parse(sql, charset)
+
+ def _adjust_casing(self, table, charset=None):
"""Adjust Table name to the server case sensitivity, if needed."""
- casing = self._detect_casing(connection)
+ casing = self._server_casing
# For winxx database hosts. TODO: is this really needed?
if casing == 1 and table.name != table.name.lower():
lc_alias = schema._get_table_key(table.name, table.schema)
table.metadata.tables[lc_alias] = table
-
def _detect_charset(self, connection):
- """Sniff out the character set in use for connection results."""
-
- # Allow user override, won't sniff if force_charset is set.
- if ('mysql', 'force_charset') in connection.info:
- return connection.info[('mysql', 'force_charset')]
-
- # Note: MySQL-python 1.2.1c7 seems to ignore changes made
- # on a connection via set_character_set()
- if self.server_version_info(connection) < (4, 1, 0):
- try:
- return connection.connection.character_set_name()
- except AttributeError:
- # < 1.2.1 final MySQL-python drivers have no charset support.
- # a query is needed.
- pass
-
- # Prefer 'character_set_results' for the current connection over the
- # value in the driver. SET NAMES or individual variable SETs will
- # change the charset without updating the driver's view of the world.
- #
- # If it's decided that issuing that sort of SQL leaves you SOL, then
- # this can prefer the driver value.
- rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'")
- opts = dict([(row[0], row[1]) for row in _compat_fetchall(rs)])
-
- if 'character_set_results' in opts:
- return opts['character_set_results']
- try:
- return connection.connection.character_set_name()
- except AttributeError:
- # Still no charset on < 1.2.1 final...
- if 'character_set' in opts:
- return opts['character_set']
- else:
- util.warn(
- "Could not detect the connection character set with this "
- "combination of MySQL server and MySQL-python. "
- "MySQL-python >= 1.2.2 is recommended. Assuming latin1.")
- return 'latin1'
- _detect_charset = engine_base.connection_memoize(
- ('mysql', 'charset'))(_detect_charset)
-
+ raise NotImplementedError()
def _detect_casing(self, connection):
"""Sniff out identifier case sensitivity.
"""
# http://dev.mysql.com/doc/refman/5.0/en/name-case-sensitivity.html
- charset = self._detect_charset(connection)
- row = _compat_fetchone(connection.execute(
+ charset = self._connection_charset
+ row = self._compat_fetchone(connection.execute(
"SHOW VARIABLES LIKE 'lower_case_table_names'"),
charset=charset)
if not row:
cs = int(row[1])
row.close()
return cs
- _detect_casing = engine_base.connection_memoize(
- ('mysql', 'lower_case_table_names'))(_detect_casing)
def _detect_collations(self, connection):
"""Pull the active COLLATIONS list from the server.
"""
collations = {}
- if self.server_version_info(connection) < (4, 1, 0):
+ if self.server_version_info < (4, 1, 0):
pass
else:
- charset = self._detect_charset(connection)
+ charset = self._connection_charset
rs = connection.execute('SHOW COLLATION')
- for row in _compat_fetchall(rs, charset):
+ for row in self._compat_fetchall(rs, charset):
collations[row[0]] = row[1]
return collations
- _detect_collations = engine_base.connection_memoize(
- ('mysql', 'collations'))(_detect_collations)
-
- def use_ansiquotes(self, useansi):
- self._use_ansiquotes = useansi
- if useansi:
- self.preparer = MySQLANSIIdentifierPreparer
- else:
- self.preparer = MySQLIdentifierPreparer
- # icky
- if hasattr(self, 'identifier_preparer'):
- self.identifier_preparer = self.preparer(self)
- if hasattr(self, 'reflector'):
- del self.reflector
-
- use_ansiquotes = property(lambda s: s._use_ansiquotes, use_ansiquotes,
- doc="True if ANSI_QUOTES is in effect.")
-
- def _autoset_identifier_style(self, connection, charset=None):
- """Detect and adjust for the ANSI_QUOTES sql mode.
- If the dialect's use_ansiquotes is unset, query the server's sql mode
- and reset the identifier style.
-
- Note that this currently *only* runs during reflection. Ideally this
- would run the first time a connection pool connects to the database,
- but the infrastructure for that is not yet in place.
- """
-
- if self.use_ansiquotes is not None:
- return
+ def _detect_ansiquotes(self, connection):
+ """Detect and adjust for the ANSI_QUOTES sql mode."""
- row = _compat_fetchone(
+ row = self._compat_fetchone(
connection.execute("SHOW VARIABLES LIKE 'sql_mode'"),
- charset=charset)
+ charset=self._connection_charset)
+
if not row:
mode = ''
else:
mode_no = int(mode)
mode = (mode_no | 4 == mode_no) and 'ANSI_QUOTES' or ''
- self.use_ansiquotes = 'ANSI_QUOTES' in mode
+ return 'ANSI_QUOTES' in mode
def _show_create_table(self, connection, table, charset=None,
full_name=None):
try:
rp = connection.execute(st)
except exc.SQLError, e:
- if e.orig.args[0] == 1146:
+ if self._extract_error_code(e) == 1146:
raise exc.NoSuchTableError(full_name)
else:
raise
- row = _compat_fetchone(rp, charset=charset)
+ row = self._compat_fetchone(rp, charset=charset)
if not row:
raise exc.NoSuchTableError(full_name)
return row[1].strip()
try:
rp = connection.execute(st)
except exc.SQLError, e:
- if e.orig.args[0] == 1146:
+ if self._extract_error_code(e) == 1146:
raise exc.NoSuchTableError(full_name)
else:
raise
- rows = _compat_fetchall(rp, charset=charset)
+ rows = self._compat_fetchall(rp, charset=charset)
finally:
if rp:
rp.close()
return rows
-class _MySQLPythonRowProxy(object):
- """Return consistent column values for all versions of MySQL-python.
-
- Smooth over data type issues (esp. with alpha driver versions) and
- normalize strings as Unicode regardless of user-configured driver
- encoding settings.
- """
-
- # Some MySQL-python versions can return some columns as
- # sets.Set(['value']) (seriously) but thankfully that doesn't
- # seem to come up in DDL queries.
-
- def __init__(self, rowproxy, charset):
- self.rowproxy = rowproxy
- self.charset = charset
- def __getitem__(self, index):
- item = self.rowproxy[index]
- if isinstance(item, _array):
- item = item.tostring()
- if self.charset and isinstance(item, str):
- return item.decode(self.charset)
- else:
- return item
- def __getattr__(self, attr):
- item = getattr(self.rowproxy, attr)
- if isinstance(item, _array):
- item = item.tostring()
- if self.charset and isinstance(item, str):
- return item.decode(self.charset)
- else:
- return item
-
-
-class MySQLCompiler(compiler.DefaultCompiler):
- operators = compiler.DefaultCompiler.operators.copy()
- operators.update({
- sql_operators.concat_op: lambda x, y: "concat(%s, %s)" % (x, y),
- sql_operators.mod: '%%',
- sql_operators.match_op: lambda x, y: "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % (x, y)
- })
- functions = compiler.DefaultCompiler.functions.copy()
- functions.update ({
- sql_functions.random: 'rand%(expr)s',
- "utc_timestamp":"UTC_TIMESTAMP"
- })
-
- extract_map = compiler.DefaultCompiler.extract_map.copy()
- extract_map.update ({
- 'milliseconds': 'millisecond',
- })
-
- def visit_typeclause(self, typeclause):
- type_ = typeclause.type.dialect_impl(self.dialect)
- if isinstance(type_, MSInteger):
- if getattr(type_, 'unsigned', False):
- return 'UNSIGNED INTEGER'
- else:
- return 'SIGNED INTEGER'
- elif isinstance(type_, (MSDecimal, MSDateTime, MSDate, MSTime)):
- return type_.get_col_spec()
- elif isinstance(type_, MSText):
- return 'CHAR'
- elif (isinstance(type_, _StringType) and not
- isinstance(type_, (MSEnum, MSSet))):
- if getattr(type_, 'length'):
- return 'CHAR(%s)' % type_.length
- else:
- return 'CHAR'
- elif isinstance(type_, _BinaryType):
- return 'BINARY'
- elif isinstance(type_, MSNumeric):
- return type_.get_col_spec().replace('NUMERIC', 'DECIMAL')
- elif isinstance(type_, MSTimeStamp):
- return 'DATETIME'
- elif isinstance(type_, (MSDateTime, MSDate, MSTime)):
- return type_.get_col_spec()
- else:
- return None
-
- def visit_cast(self, cast, **kwargs):
- # No cast until 4, no decimals until 5.
- type_ = self.process(cast.typeclause)
- if type_ is None:
- return self.process(cast.clause)
-
- return 'CAST(%s AS %s)' % (self.process(cast.clause), type_)
-
-
- def post_process_text(self, text):
- if '%%' in text:
- util.warn("The SQLAlchemy MySQLDB dialect now automatically escapes '%' in text() expressions to '%%'.")
- return text.replace('%', '%%')
-
- def get_select_precolumns(self, select):
- if isinstance(select._distinct, basestring):
- return select._distinct.upper() + " "
- elif select._distinct:
- return "DISTINCT "
- else:
- return ""
-
- def visit_join(self, join, asfrom=False, **kwargs):
- # 'JOIN ... ON ...' for inner joins isn't available until 4.0.
- # Apparently < 3.23.17 requires theta joins for inner joins
- # (but not outer). Not generating these currently, but
- # support can be added, preferably after dialects are
- # refactored to be version-sensitive.
- return ''.join(
- (self.process(join.left, asfrom=True),
- (join.isouter and " LEFT OUTER JOIN " or " INNER JOIN "),
- self.process(join.right, asfrom=True),
- " ON ",
- self.process(join.onclause)))
-
- def for_update_clause(self, select):
- if select.for_update == 'read':
- return ' LOCK IN SHARE MODE'
- else:
- return super(MySQLCompiler, self).for_update_clause(select)
-
- def limit_clause(self, select):
- # MySQL supports:
- # LIMIT <limit>
- # LIMIT <offset>, <limit>
- # and in server versions > 3.3:
- # LIMIT <limit> OFFSET <offset>
- # The latter is more readable for offsets but we're stuck with the
- # former until we can refine dialects by server revision.
-
- limit, offset = select._limit, select._offset
-
- if (limit, offset) == (None, None):
- return ''
- elif offset is not None:
- # As suggested by the MySQL docs, need to apply an
- # artificial limit if one wasn't provided
- if limit is None:
- limit = 18446744073709551615
- return ' \n LIMIT %s, %s' % (offset, limit)
- else:
- # No offset provided, so just use the limit
- return ' \n LIMIT %s' % (limit,)
-
- def visit_update(self, update_stmt):
- self.stack.append({'from': set([update_stmt.table])})
-
- self.isupdate = True
- colparams = self._get_colparams(update_stmt)
-
- text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + ', '.join(["%s=%s" % (self.preparer.format_column(c[0]), c[1]) for c in colparams])
-
- if update_stmt._whereclause:
- text += " WHERE " + self.process(update_stmt._whereclause)
-
- limit = update_stmt.kwargs.get('mysql_limit', None)
- if limit:
- text += " LIMIT %s" % limit
-
- self.stack.pop(-1)
-
- return text
-
-# ug. "InnoDB needs indexes on foreign keys and referenced keys [...].
-# Starting with MySQL 4.1.2, these indexes are created automatically.
-# In older versions, the indexes must be created explicitly or the
-# creation of foreign key constraints fails."
-
-class MySQLSchemaGenerator(compiler.SchemaGenerator):
- def get_column_specification(self, column, first_pk=False):
- """Builds column DDL."""
-
- colspec = [self.preparer.format_column(column),
- column.type.dialect_impl(self.dialect).get_col_spec()]
-
- default = self.get_column_default_string(column)
- if default is not None:
- colspec.append('DEFAULT ' + default)
-
- if not column.nullable:
- colspec.append('NOT NULL')
-
- if column.primary_key and column.autoincrement:
- try:
- first = [c for c in column.table.primary_key.columns
- if (c.autoincrement and
- isinstance(c.type, sqltypes.Integer) and
- not c.foreign_keys)].pop(0)
- if column is first:
- colspec.append('AUTO_INCREMENT')
- except IndexError:
- pass
-
- return ' '.join(colspec)
-
- def post_create_table(self, table):
- """Build table-level CREATE options like ENGINE and COLLATE."""
-
- table_opts = []
- for k in table.kwargs:
- if k.startswith('mysql_'):
- opt = k[6:].upper()
- joiner = '='
- if opt in ('TABLESPACE', 'DEFAULT CHARACTER SET',
- 'CHARACTER SET', 'COLLATE'):
- joiner = ' '
-
- table_opts.append(joiner.join((opt, table.kwargs[k])))
- return ' '.join(table_opts)
-
-
-class MySQLSchemaDropper(compiler.SchemaDropper):
- def visit_index(self, index):
- self.append("\nDROP INDEX %s ON %s" %
- (self.preparer.quote(self._validate_identifier(index.name, False), index.quote),
- self.preparer.format_table(index.table)))
- self.execute()
-
- def drop_foreignkey(self, constraint):
- self.append("ALTER TABLE %s DROP FOREIGN KEY %s" %
- (self.preparer.format_table(constraint.table),
- self.preparer.format_constraint(constraint)))
- self.execute()
-
-
-class MySQLSchemaReflector(object):
- """Parses SHOW CREATE TABLE output."""
-
- def __init__(self, identifier_preparer):
- """Construct a MySQLSchemaReflector.
-
- identifier_preparer
- An ANSIIdentifierPreparer type, used to determine the identifier
- quoting style in effect.
- """
-
- self.preparer = identifier_preparer
+class ReflectedState(object):
+ """Stores raw information about a SHOW CREATE TABLE statement."""
+
+ def __init__(self):
+ self.columns = []
+ self.table_options = {}
+ self.table_name = None
+ self.keys = []
+ self.constraints = []
+
+class MySQLTableDefinitionParser(object):
+ """Parses the results of a SHOW CREATE TABLE statement."""
+
+ def __init__(self, dialect, preparer):
+ self.dialect = dialect
+ self.preparer = preparer
self._prep_regexes()
- def reflect(self, connection, table, show_create, charset, only=None):
- """Parse MySQL SHOW CREATE TABLE and fill in a ''Table''.
-
- show_create
- Unicode output of SHOW CREATE TABLE
-
- table
- A ''Table'', to be loaded with Columns, Indexes, etc.
- table.name will be set if not already
-
- charset
- FIXME, some constructed values (like column defaults)
- currently can't be Unicode. ''charset'' will convert them
- into the connection character set.
-
- only
- An optional sequence of column names. If provided, only
- these columns will be reflected, and any keys or constraints
- that include columns outside this set will also be omitted.
- That means that if ``only`` includes only one column in a
- 2 part primary key, the entire primary key will be omitted.
- """
-
- keys, constraints = [], []
-
- if only:
- only = set(only)
-
+ def parse(self, show_create, charset):
+ state = ReflectedState()
+ state.charset = charset
for line in re.split(r'\r?\n', show_create):
if line.startswith(' ' + self.preparer.initial_quote):
- self._add_column(table, line, charset, only)
+ self._parse_column(line, state)
# a regular table options line
elif line.startswith(') '):
- self._set_options(table, line)
+ self._parse_table_options(line, state)
# an ANSI-mode table options line
elif line == ')':
pass
elif line.startswith('CREATE '):
- self._set_name(table, line)
+ self._parse_table_name(line, state)
# Not present in real reflection, but may be if loading from a file.
elif not line:
pass
else:
- type_, spec = self.parse_constraints(line)
+ type_, spec = self._parse_constraints(line)
if type_ is None:
util.warn("Unknown schema content: %r" % line)
elif type_ == 'key':
- keys.append(spec)
+ state.keys.append(spec)
elif type_ == 'constraint':
- constraints.append(spec)
+ state.constraints.append(spec)
else:
pass
+
+ return state
+
+ def _parse_constraints(self, line):
+ """Parse a KEY or CONSTRAINT line.
+
+ line
+ A line of SHOW CREATE TABLE output
+ """
+
+ # KEY
+ m = self._re_key.match(line)
+ if m:
+ spec = m.groupdict()
+ # convert columns into name, length pairs
+ spec['columns'] = self._parse_keyexprs(spec['columns'])
+ return 'key', spec
+
+ # CONSTRAINT
+ m = self._re_constraint.match(line)
+ if m:
+ spec = m.groupdict()
+ spec['table'] = \
+ self.preparer.unformat_identifiers(spec['table'])
+ spec['local'] = [c[0]
+ for c in self._parse_keyexprs(spec['local'])]
+ spec['foreign'] = [c[0]
+ for c in self._parse_keyexprs(spec['foreign'])]
+ return 'constraint', spec
+
+ # PARTITION and SUBPARTITION
+ m = self._re_partition.match(line)
+ if m:
+ # Punt!
+ return 'partition', line
+
+ # No match.
+ return (None, line)
+
+ def _parse_table_name(self, line, state):
+ """Extract the table name.
+
+ line
+ The first line of SHOW CREATE TABLE
+ """
+
+ regex, cleanup = self._pr_name
+ m = regex.match(line)
+ if m:
+ state.table_name = cleanup(m.group('name'))
+
+ def _parse_table_options(self, line, state):
+ """Build a dictionary of all reflected table-level options.
+
+ line
+ The final line of SHOW CREATE TABLE output.
+ """
+
+ options = {}
- self._set_keys(table, keys, only)
- self._set_constraints(table, constraints, connection, only)
+ if not line or line == ')':
+ pass
+
+ else:
+ r_eq_trim = self._re_options_util['=']
+
+ for regex, cleanup in self._pr_options:
+ m = regex.search(line)
+ if not m:
+ continue
+ directive, value = m.group('directive'), m.group('val')
+ directive = r_eq_trim.sub('', directive).lower()
+ if cleanup:
+ value = cleanup(value)
+ options[directive] = value
+
+ for nope in ('auto_increment', 'data_directory', 'index_directory'):
+ options.pop(nope, None)
+
+ for opt, val in options.items():
+ state.table_options['mysql_%s' % opt] = val
- def _set_name(self, table, line):
- """Override a Table name with the reflected name.
+ def _parse_column(self, line, state):
+ """Extract column details.
- table
- A ``Table``
+ Falls back to a 'minimal support' variant if full parse fails.
line
- The first line of SHOW CREATE TABLE output.
+ Any column-bearing line from SHOW CREATE TABLE
"""
- # Don't override by default.
- if table.name is None:
- table.name = self.parse_name(line)
-
- def _add_column(self, table, line, charset, only=None):
- spec = self.parse_column(line)
+ charset = state.charset
+ spec = None
+ m = self._re_column.match(line)
+ if m:
+ spec = m.groupdict()
+ spec['full'] = True
+ else:
+ m = self._re_column_loose.match(line)
+ if m:
+ spec = m.groupdict()
+ spec['full'] = False
if not spec:
util.warn("Unknown column definition %r" % line)
return
name, type_, args, notnull = \
spec['name'], spec['coltype'], spec['arg'], spec['notnull']
- if only and name not in only:
- self.logger.info("Omitting reflected column %s.%s" %
- (table.name, name))
- return
-
# Convention says that TINYINT(1) columns == BOOLEAN
if type_ == 'tinyint' and args == '1':
type_ = 'boolean'
args = None
try:
- col_type = ischema_names[type_]
+ col_type = self.dialect.ischema_names[type_]
except KeyError:
util.warn("Did not recognize type '%s' of column '%s'" %
(type_, name))
col_args, col_kw = [], {}
# NOT NULL
+ col_kw['nullable'] = True
if spec.get('notnull', False):
col_kw['nullable'] = False
# DEFAULT
default = spec.get('default', None)
- if default is not None and default != 'NULL':
- # Defaults should be in the native charset for the moment
- default = default.encode(charset)
- if type_ == 'timestamp':
- # can't be NULL for TIMESTAMPs
- if (default[0], default[-1]) != ("'", "'"):
- default = sql.text(default)
- else:
- default = default[1:-1]
- col_args.append(schema.DefaultClause(default))
-
- table.append_column(schema.Column(name, type_instance,
- *col_args, **col_kw))
- def _set_keys(self, table, keys, only):
- """Add ``Index`` and ``PrimaryKeyConstraint`` items to a ``Table``.
+ if default == 'NULL':
+ # eliminates the need to deal with this later.
+ default = None
+
+ col_d = dict(name=name, type=type_instance, default=default)
+ col_d.update(col_kw)
+ state.columns.append(col_d)
- Most of the information gets dropped here- more is reflected than
- the schema objects can currently represent.
-
- table
- A ``Table``
+ def _describe_to_create(self, table_name, columns):
+ """Re-format DESCRIBE output as a SHOW CREATE TABLE string.
- keys
- A sequence of key specifications produced by `constraints`
+ DESCRIBE is a much simpler reflection and is sufficient for
+ reflecting views for runtime use. This method formats DDL
+ for columns only- keys are omitted.
- only
- Optional `set` of column names. If provided, keys covering
- columns not in this set will be omitted.
+ `columns` is a sequence of DESCRIBE or SHOW COLUMNS 6-tuples.
+ SHOW FULL COLUMNS FROM rows must be rearranged for use with
+ this function.
"""
- for spec in keys:
- flavor = spec['type']
- col_names = [s[0] for s in spec['columns']]
-
- if only and not set(col_names).issubset(only):
- if flavor is None:
- flavor = 'index'
- self.logger.info(
- "Omitting %s KEY for (%s), key covers ommitted columns." %
- (flavor, ', '.join(col_names)))
- continue
-
- constraint = False
- if flavor == 'PRIMARY':
- key = schema.PrimaryKeyConstraint()
- constraint = True
- elif flavor == 'UNIQUE':
- key = schema.Index(spec['name'], unique=True)
- elif flavor in (None, 'FULLTEXT', 'SPATIAL'):
- key = schema.Index(spec['name'])
- else:
- self.logger.info(
- "Converting unknown KEY type %s to a plain KEY" % flavor)
- key = schema.Index(spec['name'])
-
- for col in [table.c[name] for name in col_names]:
- key.append_column(col)
-
- if constraint:
- table.append_constraint(key)
-
- def _set_constraints(self, table, constraints, connection, only):
- """Apply constraints to a ``Table``."""
-
- default_schema = None
-
- for spec in constraints:
- # only FOREIGN KEYs
- ref_name = spec['table'][-1]
- ref_schema = len(spec['table']) > 1 and spec['table'][-2] or table.schema
-
- if not ref_schema:
- if default_schema is None:
- default_schema = connection.dialect.get_default_schema_name(
- connection)
- if table.schema == default_schema:
- ref_schema = table.schema
-
- loc_names = spec['local']
- if only and not set(loc_names).issubset(only):
- self.logger.info(
- "Omitting FOREIGN KEY for (%s), key covers ommitted "
- "columns." % (', '.join(loc_names)))
- continue
-
- ref_key = schema._get_table_key(ref_name, ref_schema)
- if ref_key in table.metadata.tables:
- ref_table = table.metadata.tables[ref_key]
- else:
- ref_table = schema.Table(
- ref_name, table.metadata, schema=ref_schema,
- autoload=True, autoload_with=connection)
-
- ref_names = spec['foreign']
-
- if ref_schema:
- refspec = [".".join([ref_schema, ref_name, column]) for column in ref_names]
- else:
- refspec = [".".join([ref_name, column]) for column in ref_names]
-
- con_kw = {}
- for opt in ('name', 'onupdate', 'ondelete'):
- if spec.get(opt, False):
- con_kw[opt] = spec[opt]
-
- key = schema.ForeignKeyConstraint(loc_names, refspec, link_to_name=True, **con_kw)
- table.append_constraint(key)
+ buffer = []
+ for row in columns:
+ (name, col_type, nullable, default, extra) = \
+ [row[i] for i in (0, 1, 2, 4, 5)]
- def _set_options(self, table, line):
- """Apply safe reflected table options to a ``Table``.
+ line = [' ']
+ line.append(self.preparer.quote_identifier(name))
+ line.append(col_type)
+ if not nullable:
+ line.append('NOT NULL')
+ if default:
+ if 'auto_increment' in default:
+ pass
+ elif (col_type.startswith('timestamp') and
+ default.startswith('C')):
+ line.append('DEFAULT')
+ line.append(default)
+ elif default == 'NULL':
+ line.append('DEFAULT')
+ line.append(default)
+ else:
+ line.append('DEFAULT')
+ line.append("'%s'" % default.replace("'", "''"))
+ if extra:
+ line.append(extra)
- table
- A ``Table``
+ buffer.append(' '.join(line))
- line
- The final line of SHOW CREATE TABLE output.
- """
+ return ''.join([('CREATE TABLE %s (\n' %
+ self.preparer.quote_identifier(table_name)),
+ ',\n'.join(buffer),
+ '\n) '])
- options = self.parse_table_options(line)
- for nope in ('auto_increment', 'data_directory', 'index_directory'):
- options.pop(nope, None)
+ def _parse_keyexprs(self, identifiers):
+ """Unpack '"col"(2),"col" ASC'-ish strings into components."""
- for opt, val in options.items():
- table.kwargs['mysql_%s' % opt] = val
+ return self._re_keyexprs.findall(identifiers)
def _prep_regexes(self):
"""Pre-compile regular expressions."""
r'(?P<val>%s)' % (re.escape(directive), regex))
self._pr_options.append(_pr_compile(regex))
+log.class_logger(MySQLTableDefinitionParser)
+log.class_logger(MySQLDialect)
- def parse_name(self, line):
- """Extract the table name.
-
- line
- The first line of SHOW CREATE TABLE
- """
-
- regex, cleanup = self._pr_name
- m = regex.match(line)
- if not m:
- return None
- return cleanup(m.group('name'))
-
- def parse_column(self, line):
- """Extract column details.
-
- Falls back to a 'minimal support' variant if full parse fails.
-
- line
- Any column-bearing line from SHOW CREATE TABLE
- """
-
- m = self._re_column.match(line)
- if m:
- spec = m.groupdict()
- spec['full'] = True
- return spec
- m = self._re_column_loose.match(line)
- if m:
- spec = m.groupdict()
- spec['full'] = False
- return spec
- return None
-
- def parse_constraints(self, line):
- """Parse a KEY or CONSTRAINT line.
-
- line
- A line of SHOW CREATE TABLE output
- """
-
- # KEY
- m = self._re_key.match(line)
- if m:
- spec = m.groupdict()
- # convert columns into name, length pairs
- spec['columns'] = self._parse_keyexprs(spec['columns'])
- return 'key', spec
-
- # CONSTRAINT
- m = self._re_constraint.match(line)
- if m:
- spec = m.groupdict()
- spec['table'] = \
- self.preparer.unformat_identifiers(spec['table'])
- spec['local'] = [c[0]
- for c in self._parse_keyexprs(spec['local'])]
- spec['foreign'] = [c[0]
- for c in self._parse_keyexprs(spec['foreign'])]
- return 'constraint', spec
-
- # PARTITION and SUBPARTITION
- m = self._re_partition.match(line)
- if m:
- # Punt!
- return 'partition', line
-
- # No match.
- return (None, line)
-
- def parse_table_options(self, line):
- """Build a dictionary of all reflected table-level options.
-
- line
- The final line of SHOW CREATE TABLE output.
- """
-
- options = {}
-
- if not line or line == ')':
- return options
-
- r_eq_trim = self._re_options_util['=']
-
- for regex, cleanup in self._pr_options:
- m = regex.search(line)
- if not m:
- continue
- directive, value = m.group('directive'), m.group('val')
- directive = r_eq_trim.sub('', directive).lower()
- if cleanup:
- value = cleanup(value)
- options[directive] = value
-
- return options
-
- def _describe_to_create(self, table, columns):
- """Re-format DESCRIBE output as a SHOW CREATE TABLE string.
-
- DESCRIBE is a much simpler reflection and is sufficient for
- reflecting views for runtime use. This method formats DDL
- for columns only- keys are omitted.
-
- `columns` is a sequence of DESCRIBE or SHOW COLUMNS 6-tuples.
- SHOW FULL COLUMNS FROM rows must be rearranged for use with
- this function.
- """
-
- buffer = []
- for row in columns:
- (name, col_type, nullable, default, extra) = \
- [row[i] for i in (0, 1, 2, 4, 5)]
-
- line = [' ']
- line.append(self.preparer.quote_identifier(name))
- line.append(col_type)
- if not nullable:
- line.append('NOT NULL')
- if default:
- if 'auto_increment' in default:
- pass
- elif (col_type.startswith('timestamp') and
- default.startswith('C')):
- line.append('DEFAULT')
- line.append(default)
- elif default == 'NULL':
- line.append('DEFAULT')
- line.append(default)
- else:
- line.append('DEFAULT')
- line.append("'%s'" % default.replace("'", "''"))
- if extra:
- line.append(extra)
- buffer.append(' '.join(line))
+class _DecodingRowProxy(object):
+ """Return unicode-decoded values based on type inspection.
- return ''.join([('CREATE TABLE %s (\n' %
- self.preparer.quote_identifier(table.name)),
- ',\n'.join(buffer),
- '\n) '])
+ Smooth over data type issues (esp. with alpha driver versions) and
+ normalize strings as Unicode regardless of user-configured driver
+ encoding settings.
- def _parse_keyexprs(self, identifiers):
- """Unpack '"col"(2),"col" ASC'-ish strings into components."""
+ """
- return self._re_keyexprs.findall(identifiers)
+ # Some MySQL-python versions can return some columns as
+ # sets.Set(['value']) (seriously) but thankfully that doesn't
+ # seem to come up in DDL queries.
-log.class_logger(MySQLSchemaReflector)
+ def __init__(self, rowproxy, charset):
+ self.rowproxy = rowproxy
+ self.charset = charset
+ def __getitem__(self, index):
+ item = self.rowproxy[index]
+ if isinstance(item, _array):
+ item = item.tostring()
+ if self.charset and isinstance(item, str):
+ return item.decode(self.charset)
+ else:
+ return item
+ def __getattr__(self, attr):
+ item = getattr(self.rowproxy, attr)
+ if isinstance(item, _array):
+ item = item.tostring()
+ if self.charset and isinstance(item, str):
+ return item.decode(self.charset)
+ else:
+ return item
class _MySQLIdentifierPreparer(compiler.IdentifierPreparer):
def __init__(self, dialect):
super(MySQLIdentifierPreparer, self).__init__(dialect, initial_quote="`")
-
+
def _escape_identifier(self, value):
return value.replace('`', '``')
pass
-
-def _compat_fetchall(rp, charset=None):
- """Proxy result rows to smooth over MySQL-Python driver inconsistencies."""
-
- return [_MySQLPythonRowProxy(row, charset) for row in rp.fetchall()]
-
-def _compat_fetchone(rp, charset=None):
- """Proxy a result row to smooth over MySQL-Python driver inconsistencies."""
-
- return _MySQLPythonRowProxy(rp.fetchone(), charset)
-
def _pr_compile(regex, cleanup=None):
"""Prepare a 2-tuple of compiled regex and callable."""
return re.compile(regex, re.I | re.UNICODE)
-dialect = MySQLDialect
-dialect.statement_compiler = MySQLCompiler
-dialect.schemagenerator = MySQLSchemaGenerator
-dialect.schemadropper = MySQLSchemaDropper
-dialect.execution_ctx_cls = MySQLExecutionContext
--- /dev/null
+"""Support for the MySQL database via the MySQL-python adapter.
+
+Character Sets
+--------------
+
+Many MySQL server installations default to a ``latin1`` encoding for client
+connections. All data sent through the connection will be converted into
+``latin1``, even if you have ``utf8`` or another character set on your tables
+and columns. With versions 4.1 and higher, you can change the connection
+character set either through server configuration or by including the
+``charset`` parameter in the URL used for ``create_engine``. The ``charset``
+option is passed through to MySQL-Python and has the side-effect of also
+enabling ``use_unicode`` in the driver by default. For regular encoded
+strings, also pass ``use_unicode=0`` in the connection arguments::
+
+ # set client encoding to utf8; all strings come back as unicode
+ create_engine('mysql:///mydb?charset=utf8')
+
+ # set client encoding to utf8; all strings come back as utf8 str
+ create_engine('mysql:///mydb?charset=utf8&use_unicode=0')
+"""
+
+import decimal
+import re
+
+from sqlalchemy.dialects.mysql.base import (DECIMAL, MySQLDialect, MySQLExecutionContext,
+ MySQLCompiler, NUMERIC, _NumericType)
+from sqlalchemy.engine import base as engine_base, default
+from sqlalchemy.sql import operators as sql_operators
+from sqlalchemy import exc, log, schema, sql, types as sqltypes, util
+
+class MySQL_mysqldbExecutionContext(MySQLExecutionContext):
+
+ @property
+ def rowcount(self):
+ if hasattr(self, '_rowcount'):
+ return self._rowcount
+ else:
+ return self.cursor.rowcount
+
+
+class MySQL_mysqldbCompiler(MySQLCompiler):
+ def visit_mod(self, binary, **kw):
+ return self.process(binary.left) + " %% " + self.process(binary.right)
+
+ def post_process_text(self, text):
+ return text.replace('%', '%%')
+
+
+class _DecimalType(_NumericType):
+ def result_processor(self, dialect):
+ if self.asdecimal:
+ return
+ def process(value):
+ if isinstance(value, decimal.Decimal):
+ return float(value)
+ else:
+ return value
+ return process
+
+
+class _MySQLdbNumeric(_DecimalType, NUMERIC):
+ pass
+
+
+class _MySQLdbDecimal(_DecimalType, DECIMAL):
+ pass
+
+
+class MySQL_mysqldb(MySQLDialect):
+ driver = 'mysqldb'
+ supports_unicode_statements = False
+ supports_sane_rowcount = True
+ supports_sane_multi_rowcount = True
+
+ default_paramstyle = 'format'
+ execution_ctx_cls = MySQL_mysqldbExecutionContext
+ statement_compiler = MySQL_mysqldbCompiler
+
+ colspecs = util.update_copy(
+ MySQLDialect.colspecs,
+ {
+ sqltypes.Numeric: _MySQLdbNumeric,
+ DECIMAL: _MySQLdbDecimal
+ }
+ )
+
+ @classmethod
+ def dbapi(cls):
+ return __import__('MySQLdb')
+
+ def do_executemany(self, cursor, statement, parameters, context=None):
+ rowcount = cursor.executemany(statement, parameters)
+ if context is not None:
+ context._rowcount = rowcount
+
+ def create_connect_args(self, url):
+ opts = url.translate_connect_args(database='db', username='user',
+ password='passwd')
+ opts.update(url.query)
+
+ util.coerce_kw_type(opts, 'compress', bool)
+ util.coerce_kw_type(opts, 'connect_timeout', int)
+ util.coerce_kw_type(opts, 'client_flag', int)
+ util.coerce_kw_type(opts, 'local_infile', int)
+ # Note: using either of the below will cause all strings to be returned
+ # as Unicode, both in raw SQL operations and with column types like
+ # String and MSString.
+ util.coerce_kw_type(opts, 'use_unicode', bool)
+ util.coerce_kw_type(opts, 'charset', str)
+
+ # Rich values 'cursorclass' and 'conv' are not supported via
+ # query string.
+
+ ssl = {}
+ for key in ['ssl_ca', 'ssl_key', 'ssl_cert', 'ssl_capath', 'ssl_cipher']:
+ if key in opts:
+ ssl[key[4:]] = opts[key]
+ util.coerce_kw_type(ssl, key[4:], str)
+ del opts[key]
+ if ssl:
+ opts['ssl'] = ssl
+
+ # FOUND_ROWS must be set in CLIENT_FLAGS to enable
+ # supports_sane_rowcount.
+ client_flag = opts.get('client_flag', 0)
+ if self.dbapi is not None:
+ try:
+ from MySQLdb.constants import CLIENT as CLIENT_FLAGS
+ client_flag |= CLIENT_FLAGS.FOUND_ROWS
+ except:
+ pass
+ opts['client_flag'] = client_flag
+ return [[], opts]
+
+ def do_ping(self, connection):
+ connection.ping()
+
+ def _get_server_version_info(self, connection):
+ dbapi_con = connection.connection
+ version = []
+ r = re.compile('[.\-]')
+ for n in r.split(dbapi_con.get_server_info()):
+ try:
+ version.append(int(n))
+ except ValueError:
+ version.append(n)
+ return tuple(version)
+
+ def _extract_error_code(self, exception):
+ try:
+ return exception.orig.args[0]
+ except AttributeError:
+ return None
+
+ def _detect_charset(self, connection):
+ """Sniff out the character set in use for connection results."""
+
+ # Note: MySQL-python 1.2.1c7 seems to ignore changes made
+ # on a connection via set_character_set()
+ if self.server_version_info < (4, 1, 0):
+ try:
+ return connection.connection.character_set_name()
+ except AttributeError:
+ # < 1.2.1 final MySQL-python drivers have no charset support.
+ # a query is needed.
+ pass
+
+ # Prefer 'character_set_results' for the current connection over the
+ # value in the driver. SET NAMES or individual variable SETs will
+ # change the charset without updating the driver's view of the world.
+ #
+ # If it's decided that issuing that sort of SQL leaves you SOL, then
+ # this can prefer the driver value.
+ rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'")
+ opts = dict([(row[0], row[1]) for row in self._compat_fetchall(rs)])
+
+ if 'character_set_results' in opts:
+ return opts['character_set_results']
+ try:
+ return connection.connection.character_set_name()
+ except AttributeError:
+ # Still no charset on < 1.2.1 final...
+ if 'character_set' in opts:
+ return opts['character_set']
+ else:
+ util.warn(
+ "Could not detect the connection character set with this "
+ "combination of MySQL server and MySQL-python. "
+ "MySQL-python >= 1.2.2 is recommended. Assuming latin1.")
+ return 'latin1'
+
+
+dialect = MySQL_mysqldb
--- /dev/null
+from sqlalchemy.dialects.mysql.base import MySQLDialect, MySQLExecutionContext
+from sqlalchemy.connectors.pyodbc import PyODBCConnector
+from sqlalchemy.engine import base as engine_base
+from sqlalchemy import util
+import re
+
+class MySQL_pyodbcExecutionContext(MySQLExecutionContext):
+
+ def get_lastrowid(self):
+ cursor = self.create_cursor()
+ cursor.execute("SELECT LAST_INSERT_ID()")
+ lastrowid = cursor.fetchone()[0]
+ cursor.close()
+ return lastrowid
+
+class MySQL_pyodbc(PyODBCConnector, MySQLDialect):
+ supports_unicode_statements = False
+ execution_ctx_cls = MySQL_pyodbcExecutionContext
+
+ pyodbc_driver_name = "MySQL"
+
+ def __init__(self, **kw):
+ # deal with http://code.google.com/p/pyodbc/issues/detail?id=25
+ kw.setdefault('convert_unicode', True)
+ MySQLDialect.__init__(self, **kw)
+ PyODBCConnector.__init__(self, **kw)
+
+ def _detect_charset(self, connection):
+ """Sniff out the character set in use for connection results."""
+
+ # Prefer 'character_set_results' for the current connection over the
+ # value in the driver. SET NAMES or individual variable SETs will
+ # change the charset without updating the driver's view of the world.
+ #
+ # If it's decided that issuing that sort of SQL leaves you SOL, then
+ # this can prefer the driver value.
+ rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'")
+ opts = dict([(row[0], row[1]) for row in self._compat_fetchall(rs)])
+ for key in ('character_set_connection', 'character_set'):
+ if opts.get(key, None):
+ return opts[key]
+
+ util.warn("Could not detect the connection character set. Assuming latin1.")
+ return 'latin1'
+
+ def _extract_error_code(self, exception):
+ m = re.compile(r"\((\d+)\)").search(str(exception.orig.args))
+ c = m.group(1)
+ if c:
+ return int(c)
+ else:
+ return None
+
+dialect = MySQL_pyodbc
--- /dev/null
+"""Support for the MySQL database via Jython's zxjdbc JDBC connector.
+
+JDBC Driver
+-----------
+
+The official MySQL JDBC driver is at
+http://dev.mysql.com/downloads/connector/j/.
+
+"""
+import re
+
+from sqlalchemy import types as sqltypes, util
+from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector
+from sqlalchemy.dialects.mysql.base import BIT, MySQLDialect, MySQLExecutionContext
+
+class _JDBCBit(BIT):
+ def result_processor(self, dialect):
+ """Converts boolean or byte arrays from MySQL Connector/J to longs."""
+ def process(value):
+ if value is None:
+ return value
+ if isinstance(value, bool):
+ return int(value)
+ v = 0L
+ for i in value:
+ v = v << 8 | (i & 0xff)
+ value = v
+ return value
+ return process
+
+
+class MySQL_jdbcExecutionContext(MySQLExecutionContext):
+ def get_lastrowid(self):
+ cursor = self.create_cursor()
+ cursor.execute("SELECT LAST_INSERT_ID()")
+ lastrowid = cursor.fetchone()[0]
+ cursor.close()
+ return lastrowid
+
+
+class MySQL_jdbc(ZxJDBCConnector, MySQLDialect):
+ execution_ctx_cls = MySQL_jdbcExecutionContext
+
+ jdbc_db_name = 'mysql'
+ jdbc_driver_name = 'com.mysql.jdbc.Driver'
+
+ colspecs = util.update_copy(
+ MySQLDialect.colspecs,
+ {
+ sqltypes.Time: sqltypes.Time,
+ BIT: _JDBCBit
+ }
+ )
+
+ def _detect_charset(self, connection):
+ """Sniff out the character set in use for connection results."""
+ # Prefer 'character_set_results' for the current connection over the
+ # value in the driver. SET NAMES or individual variable SETs will
+ # change the charset without updating the driver's view of the world.
+ #
+ # If it's decided that issuing that sort of SQL leaves you SOL, then
+ # this can prefer the driver value.
+ rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'")
+ opts = dict((row[0], row[1]) for row in self._compat_fetchall(rs))
+ for key in ('character_set_connection', 'character_set'):
+ if opts.get(key, None):
+ return opts[key]
+
+ util.warn("Could not detect the connection character set. Assuming latin1.")
+ return 'latin1'
+
+ def _driver_kwargs(self):
+ """return kw arg dict to be sent to connect()."""
+ return dict(CHARSET=self.encoding, yearIsDateType='false')
+
+ def _extract_error_code(self, exception):
+ # e.g.: DBAPIError: (Error) Table 'test.u2' doesn't exist
+ # [SQLCode: 1146], [SQLState: 42S02] 'DESCRIBE `u2`' ()
+ m = re.compile(r"\[SQLCode\: (\d+)\]").search(str(exception.orig.args))
+ c = m.group(1)
+ if c:
+ return int(c)
+
+ def _get_server_version_info(self,connection):
+ dbapi_con = connection.connection
+ version = []
+ r = re.compile('[.\-]')
+ for n in r.split(dbapi_con.dbversion):
+ try:
+ version.append(int(n))
+ except ValueError:
+ version.append(n)
+ return tuple(version)
+
+dialect = MySQL_jdbc
--- /dev/null
+from sqlalchemy.dialects.oracle import base, cx_oracle, zxjdbc
+
+base.dialect = cx_oracle.dialect
--- /dev/null
+# oracle/base.py
+# Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+"""Support for the Oracle database.
+
+Oracle version 8 through current (11g at the time of this writing) are supported.
+
+For information on connecting via specific drivers, see the documentation
+for that driver.
+
+Connect Arguments
+-----------------
+
+The dialect supports several :func:`~sqlalchemy.create_engine()` arguments which
+affect the behavior of the dialect regardless of driver in use.
+
+* *use_ansi* - Use ANSI JOIN constructs (see the section on Oracle 8). Defaults
+ to ``True``. If ``False``, Oracle-8 compatible constructs are used for joins.
+
+* *optimize_limits* - defaults to ``False``. see the section on LIMIT/OFFSET.
+
+Auto Increment Behavior
+-----------------------
+
+SQLAlchemy Table objects which include integer primary keys are usually assumed to have
+"autoincrementing" behavior, meaning they can generate their own primary key values upon
+INSERT. Since Oracle has no "autoincrement" feature, SQLAlchemy relies upon sequences
+to produce these values. With the Oracle dialect, *a sequence must always be explicitly
+specified to enable autoincrement*. This is divergent with the majority of documentation
+examples which assume the usage of an autoincrement-capable database. To specify sequences,
+use the sqlalchemy.schema.Sequence object which is passed to a Column construct::
+
+ t = Table('mytable', metadata,
+ Column('id', Integer, Sequence('id_seq'), primary_key=True),
+ Column(...), ...
+ )
+
+This step is also required when using table reflection, i.e. autoload=True::
+
+ t = Table('mytable', metadata,
+ Column('id', Integer, Sequence('id_seq'), primary_key=True),
+ autoload=True
+ )
+
+Identifier Casing
+-----------------
+
+In Oracle, the data dictionary represents all case insensitive identifier names
+using UPPERCASE text. SQLAlchemy on the other hand considers an all-lower case identifier
+name to be case insensitive. The Oracle dialect converts all case insensitive identifiers
+to and from those two formats during schema level communication, such as reflection of
+tables and indexes. Using an UPPERCASE name on the SQLAlchemy side indicates a
+case sensitive identifier, and SQLAlchemy will quote the name - this will cause mismatches
+against data dictionary data received from Oracle, so unless identifier names have been
+truly created as case sensitive (i.e. using quoted names), all lowercase names should be
+used on the SQLAlchemy side.
+
+Unicode
+-------
+
+SQLAlchemy 0.6 uses the "native unicode" mode provided as of cx_oracle 5. cx_oracle 5.0.2
+or greater is recommended for support of NCLOB. If not using cx_oracle 5, the NLS_LANG
+environment variable needs to be set in order for the oracle client library to use
+proper encoding, such as "AMERICAN_AMERICA.UTF8".
+
+Also note that Oracle supports unicode data through the NVARCHAR and NCLOB data types.
+When using the SQLAlchemy Unicode and UnicodeText types, these DDL types will be used
+within CREATE TABLE statements. Usage of VARCHAR2 and CLOB with unicode text still
+requires NLS_LANG to be set.
+
+LIMIT/OFFSET Support
+--------------------
+
+Oracle has no support for the LIMIT or OFFSET keywords. Whereas previous versions of SQLAlchemy
+used the "ROW NUMBER OVER..." construct to simulate LIMIT/OFFSET, SQLAlchemy 0.5 now uses
+a wrapped subquery approach in conjunction with ROWNUM. The exact methodology is taken from
+http://www.oracle.com/technology/oramag/oracle/06-sep/o56asktom.html . Note that the
+"FIRST ROWS()" optimization keyword mentioned is not used by default, as the user community felt
+this was stepping into the bounds of optimization that is better left on the DBA side, but this
+prefix can be added by enabling the optimize_limits=True flag on create_engine().
+
+ON UPDATE CASCADE
+-----------------
+
+Oracle doesn't have native ON UPDATE CASCADE functionality. A trigger based solution
+is available at http://asktom.oracle.com/tkyte/update_cascade/index.html .
+
+When using the SQLAlchemy ORM, the ORM has limited ability to manually issue
+cascading updates - specify ForeignKey objects using the
+"deferrable=True, initially='deferred'" keyword arguments,
+and specify "passive_updates=False" on each relation().
+
+Oracle 8 Compatibility
+----------------------
+
+When using Oracle 8, a "use_ansi=False" flag is available which converts all
+JOIN phrases into the WHERE clause, and in the case of LEFT OUTER JOIN
+makes use of Oracle's (+) operator.
+
+Synonym/DBLINK Reflection
+-------------------------
+
+When using reflection with Table objects, the dialect can optionally search for tables
+indicated by synonyms that reference DBLINK-ed tables by passing the flag
+oracle_resolve_synonyms=True as a keyword argument to the Table construct. If DBLINK
+is not in use this flag should be left off.
+
+"""
+
+import random, re
+
+from sqlalchemy import schema as sa_schema
+from sqlalchemy import util, sql, log
+from sqlalchemy.engine import default, base, reflection
+from sqlalchemy.sql import compiler, visitors, expression
+from sqlalchemy.sql import operators as sql_operators, functions as sql_functions
+from sqlalchemy import types as sqltypes
+from sqlalchemy.types import VARCHAR, NVARCHAR, CHAR, DATE, DATETIME, \
+ BLOB, CLOB, TIMESTAMP, FLOAT
+
+RESERVED_WORDS = set('''SHARE RAW DROP BETWEEN FROM DESC OPTION PRIOR LONG THEN DEFAULT ALTER IS INTO MINUS INTEGER NUMBER GRANT IDENTIFIED ALL TO ORDER ON FLOAT DATE HAVING CLUSTER NOWAIT RESOURCE ANY TABLE INDEX FOR UPDATE WHERE CHECK SMALLINT WITH DELETE BY ASC REVOKE LIKE SIZE RENAME NOCOMPRESS NULL GROUP VALUES AS IN VIEW EXCLUSIVE COMPRESS SYNONYM SELECT INSERT EXISTS NOT TRIGGER ELSE CREATE INTERSECT PCTFREE DISTINCT USER CONNECT SET MODE OF UNIQUE VARCHAR2 VARCHAR LOCK OR CHAR DECIMAL UNION PUBLIC AND START UID COMMENT'''.split())
+
+class RAW(sqltypes.Binary):
+ pass
+OracleRaw = RAW
+
+class NCLOB(sqltypes.Text):
+ __visit_name__ = 'NCLOB'
+
+VARCHAR2 = VARCHAR
+NVARCHAR2 = NVARCHAR
+
+class NUMBER(sqltypes.Numeric):
+ __visit_name__ = 'NUMBER'
+
+class BFILE(sqltypes.Binary):
+ __visit_name__ = 'BFILE'
+
+class DOUBLE_PRECISION(sqltypes.Numeric):
+ __visit_name__ = 'DOUBLE_PRECISION'
+
+class LONG(sqltypes.Text):
+ __visit_name__ = 'LONG'
+
+class _OracleBoolean(sqltypes.Boolean):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.NUMBER
+
+ def result_processor(self, dialect):
+ def process(value):
+ if value is None:
+ return None
+ return value and True or False
+ return process
+
+ def bind_processor(self, dialect):
+ def process(value):
+ if value is True:
+ return 1
+ elif value is False:
+ return 0
+ elif value is None:
+ return None
+ else:
+ return value and True or False
+ return process
+
+colspecs = {
+ sqltypes.Boolean : _OracleBoolean,
+}
+
+ischema_names = {
+ 'VARCHAR2' : VARCHAR,
+ 'NVARCHAR2' : NVARCHAR,
+ 'CHAR' : CHAR,
+ 'DATE' : DATE,
+ 'DATETIME' : DATETIME,
+ 'NUMBER' : NUMBER,
+ 'BLOB' : BLOB,
+ 'BFILE' : BFILE,
+ 'CLOB' : CLOB,
+ 'NCLOB' : NCLOB,
+ 'TIMESTAMP' : TIMESTAMP,
+ 'RAW' : RAW,
+ 'FLOAT' : FLOAT,
+ 'DOUBLE PRECISION' : DOUBLE_PRECISION,
+ 'LONG' : LONG,
+}
+
+
+class OracleTypeCompiler(compiler.GenericTypeCompiler):
+ # Note:
+ # Oracle DATE == DATETIME
+ # Oracle does not allow milliseconds in DATE
+ # Oracle does not support TIME columns
+
+ def visit_datetime(self, type_):
+ return self.visit_DATE(type_)
+
+ def visit_float(self, type_):
+ if type_.precision is None:
+ return "NUMERIC"
+ else:
+ return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': type_.precision, 'scale' : 2}
+
+ def visit_unicode(self, type_):
+ return self.visit_NVARCHAR(type_)
+
+ def visit_VARCHAR(self, type_):
+ return "VARCHAR(%(length)s)" % {'length' : type_.length}
+
+ def visit_NVARCHAR(self, type_):
+ return "NVARCHAR2(%(length)s)" % {'length' : type_.length}
+
+ def visit_text(self, type_):
+ return self.visit_CLOB(type_)
+
+ def visit_unicode_text(self, type_):
+ return self.visit_NCLOB(type_)
+
+ def visit_binary(self, type_):
+ return self.visit_BLOB(type_)
+
+ def visit_boolean(self, type_):
+ return self.visit_SMALLINT(type_)
+
+ def visit_RAW(self, type_):
+ return "RAW(%(length)s)" % {'length' : type_.length}
+
+class OracleCompiler(compiler.SQLCompiler):
+ """Oracle compiler modifies the lexical structure of Select
+ statements to work under non-ANSI configured Oracle databases, if
+ the use_ansi flag is False.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super(OracleCompiler, self).__init__(*args, **kwargs)
+ self.__wheres = {}
+ self._quoted_bind_names = {}
+
+ def visit_mod(self, binary, **kw):
+ return "mod(%s, %s)" % (self.process(binary.left), self.process(binary.right))
+
+ def visit_now_func(self, fn, **kw):
+ return "CURRENT_TIMESTAMP"
+
+ def visit_char_length_func(self, fn, **kw):
+ return "LENGTH" + self.function_argspec(fn, **kw)
+
+ def visit_match_op(self, binary, **kw):
+ return "CONTAINS (%s, %s)" % (self.process(binary.left), self.process(binary.right))
+
+ def function_argspec(self, fn, **kw):
+ if len(fn.clauses) > 0:
+ return compiler.SQLCompiler.function_argspec(self, fn, **kw)
+ else:
+ return ""
+
+ def default_from(self):
+ """Called when a ``SELECT`` statement has no froms, and no ``FROM`` clause is to be appended.
+
+ The Oracle compiler tacks a "FROM DUAL" to the statement.
+ """
+
+ return " FROM DUAL"
+
+ def visit_join(self, join, **kwargs):
+ if self.dialect.use_ansi:
+ return compiler.SQLCompiler.visit_join(self, join, **kwargs)
+ else:
+ return self.process(join.left, asfrom=True) + ", " + self.process(join.right, asfrom=True)
+
+ def _get_nonansi_join_whereclause(self, froms):
+ clauses = []
+
+ def visit_join(join):
+ if join.isouter:
+ def visit_binary(binary):
+ if binary.operator == sql_operators.eq:
+ if binary.left.table is join.right:
+ binary.left = _OuterJoinColumn(binary.left)
+ elif binary.right.table is join.right:
+ binary.right = _OuterJoinColumn(binary.right)
+ clauses.append(visitors.cloned_traverse(join.onclause, {}, {'binary':visit_binary}))
+ else:
+ clauses.append(join.onclause)
+
+ for f in froms:
+ visitors.traverse(f, {}, {'join':visit_join})
+ return sql.and_(*clauses)
+
+ def visit_outer_join_column(self, vc):
+ return self.process(vc.column) + "(+)"
+
+ def visit_sequence(self, seq):
+ return self.dialect.identifier_preparer.format_sequence(seq) + ".nextval"
+
+ def visit_alias(self, alias, asfrom=False, **kwargs):
+ """Oracle doesn't like ``FROM table AS alias``. Is the AS standard SQL??"""
+
+ if asfrom:
+ alias_name = isinstance(alias.name, expression._generated_label) and \
+ self._truncated_identifier("alias", alias.name) or alias.name
+
+ return self.process(alias.original, asfrom=asfrom, **kwargs) + " " + self.preparer.format_alias(alias, alias_name)
+ else:
+ return self.process(alias.original, **kwargs)
+
+ def returning_clause(self, stmt, returning_cols):
+
+ def create_out_param(col, i):
+ bindparam = sql.outparam("ret_%d" % i, type_=col.type)
+ self.binds[bindparam.key] = bindparam
+ return self.bindparam_string(self._truncate_bindparam(bindparam))
+
+ columnlist = list(expression._select_iterables(returning_cols))
+
+ # within_columns_clause =False so that labels (foo AS bar) don't render
+ columns = [self.process(c, within_columns_clause=False) for c in columnlist]
+
+ binds = [create_out_param(c, i) for i, c in enumerate(columnlist)]
+
+ return 'RETURNING ' + ', '.join(columns) + " INTO " + ", ".join(binds)
+
+ def _TODO_visit_compound_select(self, select):
+ """Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle."""
+ pass
+
+ def visit_select(self, select, **kwargs):
+ """Look for ``LIMIT`` and OFFSET in a select statement, and if
+ so tries to wrap it in a subquery with ``rownum`` criterion.
+ """
+
+ if not getattr(select, '_oracle_visit', None):
+ if not self.dialect.use_ansi:
+ if self.stack and 'from' in self.stack[-1]:
+ existingfroms = self.stack[-1]['from']
+ else:
+ existingfroms = None
+
+ froms = select._get_display_froms(existingfroms)
+ whereclause = self._get_nonansi_join_whereclause(froms)
+ if whereclause:
+ select = select.where(whereclause)
+ select._oracle_visit = True
+
+ if select._limit is not None or select._offset is not None:
+ # See http://www.oracle.com/technology/oramag/oracle/06-sep/o56asktom.html
+ #
+ # Generalized form of an Oracle pagination query:
+ # select ... from (
+ # select /*+ FIRST_ROWS(N) */ ...., rownum as ora_rn from (
+ # select distinct ... where ... order by ...
+ # ) where ROWNUM <= :limit+:offset
+ # ) where ora_rn > :offset
+ # Outer select and "ROWNUM as ora_rn" can be dropped if limit=0
+
+ # TODO: use annotations instead of clone + attr set ?
+ select = select._generate()
+ select._oracle_visit = True
+
+ # Wrap the middle select and add the hint
+ limitselect = sql.select([c for c in select.c])
+ if select._limit and self.dialect.optimize_limits:
+ limitselect = limitselect.prefix_with("/*+ FIRST_ROWS(%d) */" % select._limit)
+
+ limitselect._oracle_visit = True
+ limitselect._is_wrapper = True
+
+ # If needed, add the limiting clause
+ if select._limit is not None:
+ max_row = select._limit
+ if select._offset is not None:
+ max_row += select._offset
+ limitselect.append_whereclause(
+ sql.literal_column("ROWNUM")<=max_row)
+
+ # If needed, add the ora_rn, and wrap again with offset.
+ if select._offset is None:
+ select = limitselect
+ else:
+ limitselect = limitselect.column(
+ sql.literal_column("ROWNUM").label("ora_rn"))
+ limitselect._oracle_visit = True
+ limitselect._is_wrapper = True
+
+ offsetselect = sql.select(
+ [c for c in limitselect.c if c.key!='ora_rn'])
+ offsetselect._oracle_visit = True
+ offsetselect._is_wrapper = True
+
+ offsetselect.append_whereclause(
+ sql.literal_column("ora_rn")>select._offset)
+
+ select = offsetselect
+
+ kwargs['iswrapper'] = getattr(select, '_is_wrapper', False)
+ return compiler.SQLCompiler.visit_select(self, select, **kwargs)
+
+ def limit_clause(self, select):
+ return ""
+
+ def for_update_clause(self, select):
+ if select.for_update == "nowait":
+ return " FOR UPDATE NOWAIT"
+ else:
+ return super(OracleCompiler, self).for_update_clause(select)
+
+class OracleDDLCompiler(compiler.DDLCompiler):
+
+ def visit_create_sequence(self, create):
+ return "CREATE SEQUENCE %s" % self.preparer.format_sequence(create.element)
+
+ def visit_drop_sequence(self, drop):
+ return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element)
+
+ def define_constraint_cascades(self, constraint):
+ text = ""
+ if constraint.ondelete is not None:
+ text += " ON DELETE %s" % constraint.ondelete
+
+ # oracle has no ON UPDATE CASCADE -
+ # its only available via triggers http://asktom.oracle.com/tkyte/update_cascade/index.html
+ if constraint.onupdate is not None:
+ util.warn(
+ "Oracle does not contain native UPDATE CASCADE "
+ "functionality - onupdates will not be rendered for foreign keys."
+ "Consider using deferrable=True, initially='deferred' or triggers.")
+
+ return text
+
+class OracleDefaultRunner(base.DefaultRunner):
+ def visit_sequence(self, seq):
+ return self.execute_string("SELECT " +
+ self.dialect.identifier_preparer.format_sequence(seq) +
+ ".nextval FROM DUAL", ())
+
+class OracleIdentifierPreparer(compiler.IdentifierPreparer):
+
+ reserved_words = set([x.lower() for x in RESERVED_WORDS])
+ illegal_initial_characters = re.compile(r'[0-9_$]')
+
+ def _bindparam_requires_quotes(self, value):
+ """Return True if the given identifier requires quoting."""
+ lc_value = value.lower()
+ return (lc_value in self.reserved_words
+ or self.illegal_initial_characters.match(value[0])
+ or not self.legal_characters.match(unicode(value))
+ )
+
+ def format_savepoint(self, savepoint):
+ name = re.sub(r'^_+', '', savepoint.ident)
+ return super(OracleIdentifierPreparer, self).format_savepoint(savepoint, name)
+
+class OracleDialect(default.DefaultDialect):
+ name = 'oracle'
+ supports_alter = True
+ supports_unicode_statements = False
+ supports_unicode_binds = False
+ max_identifier_length = 30
+ supports_sane_rowcount = True
+ supports_sane_multi_rowcount = False
+
+ supports_sequences = True
+ sequences_optional = False
+ postfetch_lastrowid = False
+
+ default_paramstyle = 'named'
+ colspecs = colspecs
+ ischema_names = ischema_names
+ requires_name_normalize = True
+
+ supports_default_values = False
+ supports_empty_insert = False
+
+ statement_compiler = OracleCompiler
+ ddl_compiler = OracleDDLCompiler
+ type_compiler = OracleTypeCompiler
+ preparer = OracleIdentifierPreparer
+ defaultrunner = OracleDefaultRunner
+
+ reflection_options = ('oracle_resolve_synonyms', )
+
+
+ def __init__(self,
+ use_ansi=True,
+ optimize_limits=False,
+ **kwargs):
+ default.DefaultDialect.__init__(self, **kwargs)
+ self.use_ansi = use_ansi
+ self.optimize_limits = optimize_limits
+
+# TODO: implement server_version_info for oracle
+# def initialize(self, connection):
+# super(OracleDialect, self).initialize(connection)
+# self.implicit_returning = self.server_version_info > (10, ) and \
+# self.__dict__.get('implicit_returning', True)
+
+ def do_release_savepoint(self, connection, name):
+ # Oracle does not support RELEASE SAVEPOINT
+ pass
+
+ def has_table(self, connection, table_name, schema=None):
+ if not schema:
+ schema = self.get_default_schema_name(connection)
+ cursor = connection.execute(
+ sql.text("SELECT table_name FROM all_tables "
+ "WHERE table_name = :name AND owner = :schema_name"),
+ name=self.denormalize_name(table_name), schema_name=self.denormalize_name(schema))
+ return cursor.fetchone() is not None
+
+ def has_sequence(self, connection, sequence_name, schema=None):
+ if not schema:
+ schema = self.get_default_schema_name(connection)
+ cursor = connection.execute(
+ sql.text("SELECT sequence_name FROM all_sequences "
+ "WHERE sequence_name = :name AND sequence_owner = :schema_name"),
+ name=self.denormalize_name(sequence_name), schema_name=self.denormalize_name(schema))
+ return cursor.fetchone() is not None
+
+ def normalize_name(self, name):
+ if name is None:
+ return None
+ elif (name.upper() == name and
+ not self.identifier_preparer._requires_quotes(name.lower().decode(self.encoding))):
+ return name.lower().decode(self.encoding)
+ else:
+ return name.decode(self.encoding)
+
+ def denormalize_name(self, name):
+ if name is None:
+ return None
+ elif name.lower() == name and not self.identifier_preparer._requires_quotes(name.lower()):
+ return name.upper().encode(self.encoding)
+ else:
+ return name.encode(self.encoding)
+
+ def get_default_schema_name(self, connection):
+ return self.normalize_name(connection.execute('SELECT USER FROM DUAL').scalar())
+
+ def table_names(self, connection, schema):
+ # note that table_names() isnt loading DBLINKed or synonym'ed tables
+ if schema is None:
+ cursor = connection.execute(
+ "SELECT table_name FROM all_tables "
+ "WHERE nvl(tablespace_name, 'no tablespace') NOT IN ('SYSTEM', 'SYSAUX')")
+ else:
+ s = sql.text(
+ "SELECT table_name FROM all_tables "
+ "WHERE nvl(tablespace_name, 'no tablespace') NOT IN ('SYSTEM', 'SYSAUX') "
+ "AND OWNER = :owner")
+ cursor = connection.execute(s, owner=self.denormalize_name(schema))
+ return [self.normalize_name(row[0]) for row in cursor]
+
+ def _resolve_synonym(self, connection, desired_owner=None, desired_synonym=None, desired_table=None):
+ """search for a local synonym matching the given desired owner/name.
+
+ if desired_owner is None, attempts to locate a distinct owner.
+
+ returns the actual name, owner, dblink name, and synonym name if found.
+ """
+
+ q = "SELECT owner, table_owner, table_name, db_link, synonym_name FROM all_synonyms WHERE "
+ clauses = []
+ params = {}
+ if desired_synonym:
+ clauses.append("synonym_name = :synonym_name")
+ params['synonym_name'] = desired_synonym
+ if desired_owner:
+ clauses.append("table_owner = :desired_owner")
+ params['desired_owner'] = desired_owner
+ if desired_table:
+ clauses.append("table_name = :tname")
+ params['tname'] = desired_table
+
+ q += " AND ".join(clauses)
+
+ result = connection.execute(sql.text(q), **params)
+ if desired_owner:
+ row = result.fetchone()
+ if row:
+ return row['table_name'], row['table_owner'], row['db_link'], row['synonym_name']
+ else:
+ return None, None, None, None
+ else:
+ rows = result.fetchall()
+ if len(rows) > 1:
+ raise AssertionError("There are multiple tables visible to the schema, you must specify owner")
+ elif len(rows) == 1:
+ row = rows[0]
+ return row['table_name'], row['table_owner'], row['db_link'], row['synonym_name']
+ else:
+ return None, None, None, None
+
+ @reflection.cache
+ def _prepare_reflection_args(self, connection, table_name, schema=None,
+ resolve_synonyms=False, dblink='', **kw):
+
+ if resolve_synonyms:
+ actual_name, owner, dblink, synonym = self._resolve_synonym(connection, desired_owner=self.denormalize_name(schema), desired_synonym=self.denormalize_name(table_name))
+ else:
+ actual_name, owner, dblink, synonym = None, None, None, None
+ if not actual_name:
+ actual_name = self.denormalize_name(table_name)
+ if not dblink:
+ dblink = ''
+ if not owner:
+ owner = self.denormalize_name(schema or self.get_default_schema_name(connection))
+ return (actual_name, owner, dblink, synonym)
+
+ @reflection.cache
+ def get_schema_names(self, connection, **kw):
+ s = "SELECT username FROM all_users ORDER BY username"
+ cursor = connection.execute(s,)
+ return [self.normalize_name(row[0]) for row in cursor]
+
+ @reflection.cache
+ def get_table_names(self, connection, schema=None, **kw):
+ schema = self.denormalize_name(schema or self.get_default_schema_name(connection))
+ return self.table_names(connection, schema)
+
+ @reflection.cache
+ def get_view_names(self, connection, schema=None, **kw):
+ schema = self.denormalize_name(schema or self.get_default_schema_name(connection))
+ s = sql.text("SELECT view_name FROM all_views WHERE owner = :owner")
+ cursor = connection.execute(s, owner=self.denormalize_name(schema))
+ return [self.normalize_name(row[0]) for row in cursor]
+
+ @reflection.cache
+ def get_columns(self, connection, table_name, schema=None, **kw):
+ """
+
+ kw arguments can be:
+
+ oracle_resolve_synonyms
+
+ dblink
+
+ """
+
+ resolve_synonyms = kw.get('oracle_resolve_synonyms', False)
+ dblink = kw.get('dblink', '')
+ info_cache = kw.get('info_cache')
+
+ (table_name, schema, dblink, synonym) = \
+ self._prepare_reflection_args(connection, table_name, schema,
+ resolve_synonyms, dblink,
+ info_cache=info_cache)
+ columns = []
+ c = connection.execute(sql.text(
+ "SELECT column_name, data_type, data_length, data_precision, data_scale, "
+ "nullable, data_default FROM ALL_TAB_COLUMNS%(dblink)s "
+ "WHERE table_name = :table_name AND owner = :owner" % {'dblink': dblink}),
+ table_name=table_name, owner=schema)
+
+ for row in c:
+ (colname, coltype, length, precision, scale, nullable, default) = \
+ (self.normalize_name(row[0]), row[1], row[2], row[3], row[4], row[5]=='Y', row[6])
+
+ # INTEGER if the scale is 0 and precision is null
+ # NUMBER if the scale and precision are both null
+ # NUMBER(9,2) if the precision is 9 and the scale is 2
+ # NUMBER(3) if the precision is 3 and scale is 0
+ #length is ignored except for CHAR and VARCHAR2
+ if coltype == 'NUMBER' :
+ if precision is None and scale is None:
+ coltype = sqltypes.NUMERIC
+ elif precision is None and scale == 0:
+ coltype = sqltypes.INTEGER
+ else :
+ coltype = sqltypes.NUMERIC(precision, scale)
+ elif coltype=='CHAR' or coltype=='VARCHAR2':
+ coltype = self.ischema_names.get(coltype)(length)
+ else:
+ coltype = re.sub(r'\(\d+\)', '', coltype)
+ try:
+ coltype = self.ischema_names[coltype]
+ except KeyError:
+ util.warn("Did not recognize type '%s' of column '%s'" %
+ (coltype, colname))
+ coltype = sqltypes.NULLTYPE
+
+ cdict = {
+ 'name': colname,
+ 'type': coltype,
+ 'nullable': nullable,
+ 'default': default,
+ }
+ columns.append(cdict)
+ return columns
+
+ @reflection.cache
+ def get_indexes(self, connection, table_name, schema=None,
+ resolve_synonyms=False, dblink='', **kw):
+
+
+ info_cache = kw.get('info_cache')
+ (table_name, schema, dblink, synonym) = \
+ self._prepare_reflection_args(connection, table_name, schema,
+ resolve_synonyms, dblink,
+ info_cache=info_cache)
+ indexes = []
+ q = sql.text("""
+ SELECT a.index_name, a.column_name, b.uniqueness
+ FROM ALL_IND_COLUMNS%(dblink)s a
+ INNER JOIN ALL_INDEXES%(dblink)s b
+ ON a.index_name = b.index_name
+ AND a.table_owner = b.table_owner
+ AND a.table_name = b.table_name
+ WHERE a.table_name = :table_name
+ AND a.table_owner = :schema
+ ORDER BY a.index_name, a.column_position""" % {'dblink': dblink})
+ rp = connection.execute(q, table_name=self.denormalize_name(table_name),
+ schema=self.denormalize_name(schema))
+ indexes = []
+ last_index_name = None
+ pkeys = self.get_primary_keys(connection, table_name, schema,
+ resolve_synonyms=resolve_synonyms,
+ dblink=dblink,
+ info_cache=kw.get('info_cache'))
+ uniqueness = dict(NONUNIQUE=False, UNIQUE=True)
+ for rset in rp:
+ # don't include the primary key columns
+ if rset.column_name in [s.upper() for s in pkeys]:
+ continue
+ if rset.index_name != last_index_name:
+ index = dict(name=self.normalize_name(rset.index_name), column_names=[])
+ indexes.append(index)
+ index['unique'] = uniqueness.get(rset.uniqueness, False)
+ index['column_names'].append(self.normalize_name(rset.column_name))
+ last_index_name = rset.index_name
+ return indexes
+
+ @reflection.cache
+ def _get_constraint_data(self, connection, table_name, schema=None,
+ dblink='', **kw):
+
+ rp = connection.execute(
+ sql.text("""SELECT
+ ac.constraint_name,
+ ac.constraint_type,
+ loc.column_name AS local_column,
+ rem.table_name AS remote_table,
+ rem.column_name AS remote_column,
+ rem.owner AS remote_owner,
+ loc.position as loc_pos,
+ rem.position as rem_pos
+ FROM all_constraints%(dblink)s ac,
+ all_cons_columns%(dblink)s loc,
+ all_cons_columns%(dblink)s rem
+ WHERE ac.table_name = :table_name
+ AND ac.constraint_type IN ('R','P')
+ AND ac.owner = :owner
+ AND ac.owner = loc.owner
+ AND ac.constraint_name = loc.constraint_name
+ AND ac.r_owner = rem.owner(+)
+ AND ac.r_constraint_name = rem.constraint_name(+)
+ AND (rem.position IS NULL or loc.position=rem.position)
+ ORDER BY ac.constraint_name, loc.position""" % {'dblink': dblink}),
+ table_name=table_name, owner=schema)
+ constraint_data = rp.fetchall()
+ return constraint_data
+
+ @reflection.cache
+ def get_primary_keys(self, connection, table_name, schema=None, **kw):
+ """
+
+ kw arguments can be:
+
+ oracle_resolve_synonyms
+
+ dblink
+
+ """
+
+ resolve_synonyms = kw.get('oracle_resolve_synonyms', False)
+ dblink = kw.get('dblink', '')
+ info_cache = kw.get('info_cache')
+
+ (table_name, schema, dblink, synonym) = \
+ self._prepare_reflection_args(connection, table_name, schema,
+ resolve_synonyms, dblink,
+ info_cache=info_cache)
+ pkeys = []
+ constraint_data = self._get_constraint_data(connection, table_name,
+ schema, dblink,
+ info_cache=kw.get('info_cache'))
+
+ for row in constraint_data:
+ #print "ROW:" , row
+ (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = \
+ row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]])
+ if cons_type == 'P':
+ pkeys.append(local_column)
+ return pkeys
+
+ @reflection.cache
+ def get_foreign_keys(self, connection, table_name, schema=None, **kw):
+ """
+
+ kw arguments can be:
+
+ oracle_resolve_synonyms
+
+ dblink
+
+ """
+
+ requested_schema = schema # to check later on
+ resolve_synonyms = kw.get('oracle_resolve_synonyms', False)
+ dblink = kw.get('dblink', '')
+ info_cache = kw.get('info_cache')
+
+ (table_name, schema, dblink, synonym) = \
+ self._prepare_reflection_args(connection, table_name, schema,
+ resolve_synonyms, dblink,
+ info_cache=info_cache)
+
+ constraint_data = self._get_constraint_data(connection, table_name,
+ schema, dblink,
+ info_cache=kw.get('info_cache'))
+
+ def fkey_rec():
+ return {
+ 'name' : None,
+ 'constrained_columns' : [],
+ 'referred_schema' : None,
+ 'referred_table' : None,
+ 'referred_columns' : []
+ }
+
+ fkeys = util.defaultdict(fkey_rec)
+
+ for row in constraint_data:
+ (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = \
+ row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]])
+
+ if cons_type == 'R':
+ if remote_table is None:
+ # ticket 363
+ util.warn(
+ ("Got 'None' querying 'table_name' from "
+ "all_cons_columns%(dblink)s - does the user have "
+ "proper rights to the table?") % {'dblink':dblink})
+ continue
+
+ rec = fkeys[cons_name]
+ rec['name'] = cons_name
+ local_cols, remote_cols = rec['constrained_columns'], rec['referred_columns']
+
+ if not rec['referred_table']:
+ if resolve_synonyms:
+ ref_remote_name, ref_remote_owner, ref_dblink, ref_synonym = \
+ self._resolve_synonym(
+ connection,
+ desired_owner=self.denormalize_name(remote_owner),
+ desired_table=self.denormalize_name(remote_table)
+ )
+ if ref_synonym:
+ remote_table = self.normalize_name(ref_synonym)
+ remote_owner = self.normalize_name(ref_remote_owner)
+
+ rec['referred_table'] = remote_table
+
+ if requested_schema is not None or self.denormalize_name(remote_owner) != schema:
+ rec['referred_schema'] = remote_owner
+
+ local_cols.append(local_column)
+ remote_cols.append(remote_column)
+
+ return fkeys.values()
+
+ @reflection.cache
+ def get_view_definition(self, connection, view_name, schema=None,
+ resolve_synonyms=False, dblink='', **kw):
+ info_cache = kw.get('info_cache')
+ (view_name, schema, dblink, synonym) = \
+ self._prepare_reflection_args(connection, view_name, schema,
+ resolve_synonyms, dblink,
+ info_cache=info_cache)
+ s = sql.text("""
+ SELECT text FROM all_views
+ WHERE owner = :schema
+ AND view_name = :view_name
+ """)
+ rp = connection.execute(s,
+ view_name=view_name, schema=schema).scalar()
+ if rp:
+ return rp.decode(self.encoding)
+ else:
+ return None
+
+
+
+class _OuterJoinColumn(sql.ClauseElement):
+ __visit_name__ = 'outer_join_column'
+
+ def __init__(self, column):
+ self.column = column
+
+
+
--- /dev/null
+"""Support for the Oracle database via the cx_oracle driver.
+
+Driver
+------
+
+The Oracle dialect uses the cx_oracle driver, available at
+http://cx-oracle.sourceforge.net/ . The dialect has several behaviors
+which are specifically tailored towards compatibility with this module.
+
+Connecting
+----------
+
+Connecting with create_engine() uses the standard URL approach of
+``oracle://user:pass@host:port/dbname[?key=value&key=value...]``. If dbname is present, the
+host, port, and dbname tokens are converted to a TNS name using the cx_oracle
+:func:`makedsn()` function. Otherwise, the host token is taken directly as a TNS name.
+
+Additional arguments which may be specified either as query string arguments on the
+URL, or as keyword arguments to :func:`~sqlalchemy.create_engine()` are:
+
+* *allow_twophase* - enable two-phase transactions. Defaults to ``True``.
+
+* *auto_convert_lobs* - defaults to True, see the section on LOB objects.
+
+* *auto_setinputsizes* - the cx_oracle.setinputsizes() call is issued for all bind parameters.
+ This is required for LOB datatypes but can be disabled to reduce overhead. Defaults
+ to ``True``.
+
+* *mode* - This is given the string value of SYSDBA or SYSOPER, or alternatively an
+ integer value. This value is only available as a URL query string argument.
+
+* *threaded* - enable multithreaded access to cx_oracle connections. Defaults
+ to ``True``. Note that this is the opposite default of cx_oracle itself.
+
+
+LOB Objects
+-----------
+
+cx_oracle presents some challenges when fetching LOB objects. A LOB object in a result set
+is presented by cx_oracle as a cx_oracle.LOB object which has a read() method. By default,
+SQLAlchemy converts these LOB objects into Python strings. This is for two reasons. First,
+the LOB object requires an active cursor association, meaning if you were to fetch many rows
+at once such that cx_oracle had to go back to the database and fetch a new batch of rows,
+the LOB objects in the already-fetched rows are now unreadable and will raise an error.
+SQLA "pre-reads" all LOBs so that their data is fetched before further rows are read.
+The size of a "batch of rows" is controlled by the cursor.arraysize value, which SQLAlchemy
+defaults to 50 (cx_oracle normally defaults this to one).
+
+Secondly, the LOB object is not a standard DBAPI return value so SQLAlchemy seeks to
+"normalize" the results to look more like other DBAPIs.
+
+The conversion of LOB objects by this dialect is unique in SQLAlchemy in that it takes place
+for all statement executions, even plain string-based statements for which SQLA has no awareness
+of result typing. This is so that calls like fetchmany() and fetchall() can work in all cases
+without raising cursor errors. The conversion of LOB in all cases, as well as the "prefetch"
+of LOB objects, can be disabled using auto_convert_lobs=False.
+
+Two Phase Transaction Support
+-----------------------------
+
+Two Phase transactions are implemented using XA transactions. Success has been reported of them
+working successfully but this should be regarded as an experimental feature.
+
+"""
+
+from sqlalchemy.dialects.oracle.base import OracleCompiler, OracleDialect, RESERVED_WORDS
+from sqlalchemy.dialects.oracle import base as oracle
+from sqlalchemy.engine.default import DefaultExecutionContext
+from sqlalchemy.engine import base
+from sqlalchemy import types as sqltypes, util
+import datetime
+
+class _OracleDate(sqltypes.Date):
+ def bind_processor(self, dialect):
+ return None
+
+ def result_processor(self, dialect):
+ def process(value):
+ if not isinstance(value, datetime.datetime):
+ return value
+ else:
+ return value.date()
+ return process
+
+class _OracleDateTime(sqltypes.DateTime):
+ def result_processor(self, dialect):
+ def process(value):
+ if value is None or isinstance(value, datetime.datetime):
+ return value
+ else:
+ # convert cx_oracle datetime object returned pre-python 2.4
+ return datetime.datetime(value.year, value.month,
+ value.day,value.hour, value.minute, value.second)
+ return process
+
+# Note:
+# Oracle DATE == DATETIME
+# Oracle does not allow milliseconds in DATE
+# Oracle does not support TIME columns
+
+# only if cx_oracle contains TIMESTAMP
+class _OracleTimestamp(sqltypes.TIMESTAMP):
+ def result_processor(self, dialect):
+ def process(value):
+ if value is None or isinstance(value, datetime.datetime):
+ return value
+ else:
+ # convert cx_oracle datetime object returned pre-python 2.4
+ return datetime.datetime(value.year, value.month,
+ value.day,value.hour, value.minute, value.second)
+ return process
+
+class _LOBMixin(object):
+ def result_processor(self, dialect):
+ super_process = super(_LOBMixin, self).result_processor(dialect)
+ if not dialect.auto_convert_lobs:
+ return super_process
+ lob = dialect.dbapi.LOB
+ def process(value):
+ if isinstance(value, lob):
+ if super_process:
+ return super_process(value.read())
+ else:
+ return value.read()
+ else:
+ if super_process:
+ return super_process(value)
+ else:
+ return value
+ return process
+
+class _OracleText(_LOBMixin, sqltypes.Text):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.CLOB
+
+class _OracleUnicodeText(_LOBMixin, sqltypes.UnicodeText):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.NCLOB
+
+
+class _OracleBinary(_LOBMixin, sqltypes.Binary):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.BLOB
+
+ def bind_processor(self, dialect):
+ return None
+
+
+class _OracleRaw(_LOBMixin, oracle.RAW):
+ pass
+
+
+colspecs = {
+ sqltypes.DateTime : _OracleDateTime,
+ sqltypes.Date : _OracleDate,
+ sqltypes.Binary : _OracleBinary,
+ sqltypes.Boolean : oracle._OracleBoolean,
+ sqltypes.Text : _OracleText,
+ sqltypes.UnicodeText : _OracleUnicodeText,
+ sqltypes.TIMESTAMP : _OracleTimestamp,
+ oracle.RAW: _OracleRaw,
+}
+
+class Oracle_cx_oracleCompiler(OracleCompiler):
+ def bindparam_string(self, name):
+ if self.preparer._bindparam_requires_quotes(name):
+ quoted_name = '"%s"' % name
+ self._quoted_bind_names[name] = quoted_name
+ return OracleCompiler.bindparam_string(self, quoted_name)
+ else:
+ return OracleCompiler.bindparam_string(self, name)
+
+class Oracle_cx_oracleExecutionContext(DefaultExecutionContext):
+ def pre_exec(self):
+ quoted_bind_names = getattr(self.compiled, '_quoted_bind_names', {})
+ if quoted_bind_names:
+ for param in self.parameters:
+ for fromname, toname in self.compiled._quoted_bind_names.iteritems():
+ param[toname.encode(self.dialect.encoding)] = param[fromname]
+ del param[fromname]
+
+ if self.dialect.auto_setinputsizes:
+ self.set_input_sizes(quoted_bind_names, exclude_types=(self.dialect.dbapi.STRING,))
+
+ if len(self.compiled_parameters) == 1:
+ for key in self.compiled.binds:
+ bindparam = self.compiled.binds[key]
+ name = self.compiled.bind_names[bindparam]
+ value = self.compiled_parameters[0][name]
+ if bindparam.isoutparam:
+ dbtype = bindparam.type.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
+ if not hasattr(self, 'out_parameters'):
+ self.out_parameters = {}
+ self.out_parameters[name] = self.cursor.var(dbtype)
+ self.parameters[0][quoted_bind_names.get(name, name)] = self.out_parameters[name]
+
+
+ def create_cursor(self):
+ c = self._connection.connection.cursor()
+ if self.dialect.arraysize:
+ c.cursor.arraysize = self.dialect.arraysize
+ return c
+
+ def get_result_proxy(self):
+ if hasattr(self, 'out_parameters'):
+ if self.compiled_parameters is not None and len(self.compiled_parameters) == 1:
+ for bind, name in self.compiled.bind_names.iteritems():
+ if name in self.out_parameters:
+ type = bind.type
+ result_processor = type.dialect_impl(self.dialect).result_processor(self.dialect)
+ if result_processor is not None:
+ self.out_parameters[name] = result_processor(self.out_parameters[name].getvalue())
+ else:
+ self.out_parameters[name] = self.out_parameters[name].getvalue()
+ else:
+ for k in self.out_parameters:
+ self.out_parameters[k] = self.out_parameters[k].getvalue()
+
+ if self.cursor.description is not None:
+ for column in self.cursor.description:
+ type_code = column[1]
+ if type_code in self.dialect.ORACLE_BINARY_TYPES:
+ return base.BufferedColumnResultProxy(self)
+
+ if hasattr(self, 'out_parameters') and \
+ self.compiled.returning:
+
+ return ReturningResultProxy(self)
+ else:
+ return base.ResultProxy(self)
+
+class ReturningResultProxy(base.FullyBufferedResultProxy):
+ """Result proxy which stuffs the _returning clause + outparams into the fetch."""
+
+ def _cursor_description(self):
+ returning = self.context.compiled.returning
+
+ ret = []
+ for c in returning:
+ if hasattr(c, 'key'):
+ ret.append((c.key, c.type))
+ else:
+ ret.append((c.anon_label, c.type))
+ return ret
+
+ def _buffer_rows(self):
+ returning = self.context.compiled.returning
+ return [tuple(self.context.out_parameters["ret_%d" % i] for i, c in enumerate(returning))]
+
+class Oracle_cx_oracle(OracleDialect):
+ execution_ctx_cls = Oracle_cx_oracleExecutionContext
+ statement_compiler = Oracle_cx_oracleCompiler
+ driver = "cx_oracle"
+ colspecs = colspecs
+
+ def __init__(self,
+ auto_setinputsizes=True,
+ auto_convert_lobs=True,
+ threaded=True,
+ allow_twophase=True,
+ arraysize=50, **kwargs):
+ OracleDialect.__init__(self, **kwargs)
+ self.threaded = threaded
+ self.arraysize = arraysize
+ self.allow_twophase = allow_twophase
+ self.supports_timestamp = self.dbapi is None or hasattr(self.dbapi, 'TIMESTAMP' )
+ self.auto_setinputsizes = auto_setinputsizes
+ self.auto_convert_lobs = auto_convert_lobs
+
+ def vers(num):
+ return tuple([int(x) for x in num.split('.')])
+
+ if hasattr(self.dbapi, 'version'):
+ cx_oracle_ver = vers(self.dbapi.version)
+ self.supports_unicode_binds = cx_oracle_ver >= (5, 0)
+
+ if self.dbapi is None or not self.auto_convert_lobs or not 'CLOB' in self.dbapi.__dict__:
+ self.dbapi_type_map = {}
+ self.ORACLE_BINARY_TYPES = []
+ else:
+ # only use this for LOB objects. using it for strings, dates
+ # etc. leads to a little too much magic, reflection doesn't know if it should
+ # expect encoded strings or unicodes, etc.
+ self.dbapi_type_map = {
+ self.dbapi.CLOB: oracle.CLOB(),
+ self.dbapi.NCLOB:oracle.NCLOB(),
+ self.dbapi.BLOB: oracle.BLOB(),
+ self.dbapi.BINARY: oracle.RAW(),
+ }
+ self.ORACLE_BINARY_TYPES = [getattr(self.dbapi, k) for k in ["BFILE", "CLOB", "NCLOB", "BLOB"] if hasattr(self.dbapi, k)]
+
+ @classmethod
+ def dbapi(cls):
+ import cx_Oracle
+ return cx_Oracle
+
+ def create_connect_args(self, url):
+ dialect_opts = dict(url.query)
+ for opt in ('use_ansi', 'auto_setinputsizes', 'auto_convert_lobs',
+ 'threaded', 'allow_twophase'):
+ if opt in dialect_opts:
+ util.coerce_kw_type(dialect_opts, opt, bool)
+ setattr(self, opt, dialect_opts[opt])
+
+ if url.database:
+ # if we have a database, then we have a remote host
+ port = url.port
+ if port:
+ port = int(port)
+ else:
+ port = 1521
+ dsn = self.dbapi.makedsn(url.host, port, url.database)
+ else:
+ # we have a local tnsname
+ dsn = url.host
+
+ opts = dict(
+ user=url.username,
+ password=url.password,
+ dsn=dsn,
+ threaded=self.threaded,
+ twophase=self.allow_twophase,
+ )
+ if 'mode' in url.query:
+ opts['mode'] = url.query['mode']
+ if isinstance(opts['mode'], basestring):
+ mode = opts['mode'].upper()
+ if mode == 'SYSDBA':
+ opts['mode'] = self.dbapi.SYSDBA
+ elif mode == 'SYSOPER':
+ opts['mode'] = self.dbapi.SYSOPER
+ else:
+ util.coerce_kw_type(opts, 'mode', int)
+ # Can't set 'handle' or 'pool' via URL query args, use connect_args
+
+ return ([], opts)
+
+ def _get_server_version_info(self, connection):
+ return tuple(int(x) for x in connection.connection.version.split('.'))
+
+ def is_disconnect(self, e):
+ if isinstance(e, self.dbapi.InterfaceError):
+ return "not connected" in str(e)
+ else:
+ return "ORA-03114" in str(e) or "ORA-03113" in str(e)
+
+ def create_xid(self):
+ """create a two-phase transaction ID.
+
+ this id will be passed to do_begin_twophase(), do_rollback_twophase(),
+ do_commit_twophase(). its format is unspecified."""
+
+ id = random.randint(0, 2 ** 128)
+ return (0x1234, "%032x" % id, "%032x" % 9)
+
+ def do_begin_twophase(self, connection, xid):
+ connection.connection.begin(*xid)
+
+ def do_prepare_twophase(self, connection, xid):
+ connection.connection.prepare()
+
+ def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False):
+ self.do_rollback(connection.connection)
+
+ def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False):
+ self.do_commit(connection.connection)
+
+ def do_recover_twophase(self, connection):
+ pass
+
+dialect = Oracle_cx_oracle
--- /dev/null
+"""Support for the Oracle database via the zxjdbc JDBC connector."""
+import re
+
+from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector
+from sqlalchemy.dialects.oracle.base import OracleDialect
+
+class Oracle_jdbc(ZxJDBCConnector, OracleDialect):
+
+ jdbc_db_name = 'oracle'
+ jdbc_driver_name = 'oracle.jdbc.driver.OracleDriver'
+
+ def create_connect_args(self, url):
+ hostname = url.host
+ port = url.port or '1521'
+ dbname = url.database
+ jdbc_url = 'jdbc:oracle:thin:@%s:%s:%s' % (hostname, port, dbname)
+ return [[jdbc_url, url.username, url.password, self.jdbc_driver_name],
+ self._driver_kwargs()]
+
+ def _get_server_version_info(self, connection):
+ version = re.search(r'Release ([\d\.]+)', connection.connection.dbversion).group(1)
+ return tuple(int(x) for x in version.split('.'))
+
+dialect = Oracle_jdbc
--- /dev/null
+# backwards compat with the old name
+from sqlalchemy.util import warn_deprecated
+
+warn_deprecated(
+ "The SQLAlchemy PostgreSQL dialect has been renamed from 'postgres' to 'postgresql'. "
+ "The new URL format is postgresql[+driver]://<user>:<pass>@<host>/<dbname>"
+ )
+
+from sqlalchemy.dialects.postgresql import *
\ No newline at end of file
--- /dev/null
+from sqlalchemy.dialects.postgresql import base, psycopg2, pg8000, zxjdbc
+
+base.dialect = psycopg2.dialect
\ No newline at end of file
-# postgres.py
+# postgresql.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-"""Support for the PostgreSQL database.
+"""Support for the PostgreSQL database.
-Driver
-------
-
-The psycopg2 driver is supported, available at http://pypi.python.org/pypi/psycopg2/ .
-The dialect has several behaviors which are specifically tailored towards compatibility
-with this module.
-
-Note that psycopg1 is **not** supported.
-
-Connecting
-----------
-
-URLs are of the form `postgres://user:password@host:port/dbname[?key=value&key=value...]`.
-
-PostgreSQL-specific keyword arguments which are accepted by :func:`~sqlalchemy.create_engine()` are:
-
-* *server_side_cursors* - Enable the usage of "server side cursors" for SQL statements which support
- this feature. What this essentially means from a psycopg2 point of view is that the cursor is
- created using a name, e.g. `connection.cursor('some name')`, which has the effect that result rows
- are not immediately pre-fetched and buffered after statement execution, but are instead left
- on the server and only retrieved as needed. SQLAlchemy's :class:`~sqlalchemy.engine.base.ResultProxy`
- uses special row-buffering behavior when this feature is enabled, such that groups of 100 rows
- at a time are fetched over the wire to reduce conversational overhead.
+For information on connecting using specific drivers, see the documentation section
+regarding that driver.
Sequences/SERIAL
----------------
but must be explicitly enabled on a per-statement basis::
# INSERT..RETURNING
- result = table.insert(postgres_returning=[table.c.col1, table.c.col2]).\\
+ result = table.insert(postgresql_returning=[table.c.col1, table.c.col2]).\\
values(name='foo')
print result.fetchall()
# UPDATE..RETURNING
- result = table.update(postgres_returning=[table.c.col1, table.c.col2]).\\
+ result = table.update(postgresql_returning=[table.c.col1, table.c.col2]).\\
where(table.c.name=='foo').values(name='bar')
print result.fetchall()
Indexes
-------
-PostgreSQL supports partial indexes. To create them pass a postgres_where
+PostgreSQL supports partial indexes. To create them pass a postgresql_where
option to the Index constructor::
- Index('my_index', my_table.c.id, postgres_where=tbl.c.value > 10)
+ Index('my_index', my_table.c.id, postgresql_where=tbl.c.value > 10)
-Transactions
-------------
-
-The PostgreSQL dialect fully supports SAVEPOINT and two-phase commit operations.
"""
-import decimal, random, re, string
+import re
+from sqlalchemy import schema as sa_schema
from sqlalchemy import sql, schema, exc, util
-from sqlalchemy.engine import base, default
+from sqlalchemy.engine import base, default, reflection
from sqlalchemy.sql import compiler, expression
from sqlalchemy.sql import operators as sql_operators
from sqlalchemy import types as sqltypes
+from sqlalchemy.types import INTEGER, BIGINT, SMALLINT, VARCHAR, \
+ CHAR, TEXT, FLOAT, NUMERIC, \
+ TIMESTAMP, TIME, DATE, BOOLEAN
-class PGInet(sqltypes.TypeEngine):
- def get_col_spec(self):
- return "INET"
-
-class PGCidr(sqltypes.TypeEngine):
- def get_col_spec(self):
- return "CIDR"
-
-class PGMacAddr(sqltypes.TypeEngine):
- def get_col_spec(self):
- return "MACADDR"
-
-class PGNumeric(sqltypes.Numeric):
- def get_col_spec(self):
- if not self.precision:
- return "NUMERIC"
- else:
- return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale}
-
- def bind_processor(self, dialect):
- return None
-
- def result_processor(self, dialect):
- if self.asdecimal:
- return None
- else:
- def process(value):
- if isinstance(value, decimal.Decimal):
- return float(value)
- else:
- return value
- return process
-
-class PGFloat(sqltypes.Float):
- def get_col_spec(self):
- if not self.precision:
- return "FLOAT"
- else:
- return "FLOAT(%(precision)s)" % {'precision': self.precision}
-
-
-class PGInteger(sqltypes.Integer):
- def get_col_spec(self):
- return "INTEGER"
-
-class PGSmallInteger(sqltypes.Smallinteger):
- def get_col_spec(self):
- return "SMALLINT"
-
-class PGBigInteger(PGInteger):
- def get_col_spec(self):
- return "BIGINT"
-
-class PGDateTime(sqltypes.DateTime):
- def get_col_spec(self):
- return "TIMESTAMP " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
+class REAL(sqltypes.Float):
+ __visit_name__ = "REAL"
-class PGDate(sqltypes.Date):
- def get_col_spec(self):
- return "DATE"
+class BYTEA(sqltypes.Binary):
+ __visit_name__ = 'BYTEA'
-class PGTime(sqltypes.Time):
- def get_col_spec(self):
- return "TIME " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
+class DOUBLE_PRECISION(sqltypes.Float):
+ __visit_name__ = 'DOUBLE_PRECISION'
+
+class INET(sqltypes.TypeEngine):
+ __visit_name__ = "INET"
+PGInet = INET
-class PGInterval(sqltypes.TypeEngine):
- def get_col_spec(self):
- return "INTERVAL"
+class CIDR(sqltypes.TypeEngine):
+ __visit_name__ = "CIDR"
+PGCidr = CIDR
-class PGText(sqltypes.Text):
- def get_col_spec(self):
- return "TEXT"
+class MACADDR(sqltypes.TypeEngine):
+ __visit_name__ = "MACADDR"
+PGMacAddr = MACADDR
-class PGString(sqltypes.String):
- def get_col_spec(self):
- if self.length:
- return "VARCHAR(%(length)d)" % {'length' : self.length}
- else:
- return "VARCHAR"
+class INTERVAL(sqltypes.TypeEngine):
+ __visit_name__ = 'INTERVAL'
+PGInterval = INTERVAL
-class PGChar(sqltypes.CHAR):
- def get_col_spec(self):
- if self.length:
- return "CHAR(%(length)d)" % {'length' : self.length}
- else:
- return "CHAR"
+class BIT(sqltypes.TypeEngine):
+ __visit_name__ = 'BIT'
+PGBit = BIT
-class PGBinary(sqltypes.Binary):
- def get_col_spec(self):
- return "BYTEA"
+class UUID(sqltypes.TypeEngine):
+ __visit_name__ = 'UUID'
+PGUuid = UUID
-class PGBoolean(sqltypes.Boolean):
- def get_col_spec(self):
- return "BOOLEAN"
-
-class PGBit(sqltypes.TypeEngine):
- def get_col_spec(self):
- return "BIT"
-
-class PGUuid(sqltypes.TypeEngine):
- def get_col_spec(self):
- return "UUID"
+class ARRAY(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine):
+ __visit_name__ = 'ARRAY'
-class PGArray(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine):
def __init__(self, item_type, mutable=True):
if isinstance(item_type, type):
item_type = item_type()
return item
return [convert_item(item) for item in value]
return process
- def get_col_spec(self):
- return self.item_type.get_col_spec() + '[]'
+PGArray = ARRAY
colspecs = {
- sqltypes.Integer : PGInteger,
- sqltypes.Smallinteger : PGSmallInteger,
- sqltypes.Numeric : PGNumeric,
- sqltypes.Float : PGFloat,
- sqltypes.DateTime : PGDateTime,
- sqltypes.Date : PGDate,
- sqltypes.Time : PGTime,
- sqltypes.String : PGString,
- sqltypes.Binary : PGBinary,
- sqltypes.Boolean : PGBoolean,
- sqltypes.Text : PGText,
- sqltypes.CHAR: PGChar,
+ sqltypes.Interval:INTERVAL
}
ischema_names = {
- 'integer' : PGInteger,
- 'bigint' : PGBigInteger,
- 'smallint' : PGSmallInteger,
- 'character varying' : PGString,
- 'character' : PGChar,
- '"char"' : PGChar,
- 'name': PGChar,
- 'text' : PGText,
- 'numeric' : PGNumeric,
- 'float' : PGFloat,
- 'real' : PGFloat,
- 'inet': PGInet,
- 'cidr': PGCidr,
- 'uuid':PGUuid,
- 'bit':PGBit,
- 'macaddr': PGMacAddr,
- 'double precision' : PGFloat,
- 'timestamp' : PGDateTime,
- 'timestamp with time zone' : PGDateTime,
- 'timestamp without time zone' : PGDateTime,
- 'time with time zone' : PGTime,
- 'time without time zone' : PGTime,
- 'date' : PGDate,
- 'time': PGTime,
- 'bytea' : PGBinary,
- 'boolean' : PGBoolean,
- 'interval':PGInterval,
+ 'integer' : INTEGER,
+ 'bigint' : BIGINT,
+ 'smallint' : SMALLINT,
+ 'character varying' : VARCHAR,
+ 'character' : CHAR,
+ '"char"' : sqltypes.String,
+ 'name' : sqltypes.String,
+ 'text' : TEXT,
+ 'numeric' : NUMERIC,
+ 'float' : FLOAT,
+ 'real' : REAL,
+ 'inet': INET,
+ 'cidr': CIDR,
+ 'uuid': UUID,
+ 'bit':BIT,
+ 'macaddr': MACADDR,
+ 'double precision' : DOUBLE_PRECISION,
+ 'timestamp' : TIMESTAMP,
+ 'timestamp with time zone' : TIMESTAMP,
+ 'timestamp without time zone' : TIMESTAMP,
+ 'time with time zone' : TIME,
+ 'time without time zone' : TIME,
+ 'date' : DATE,
+ 'time': TIME,
+ 'bytea' : BYTEA,
+ 'boolean' : BOOLEAN,
+ 'interval':INTERVAL,
}
-# TODO: filter out 'FOR UPDATE' statements
-SERVER_SIDE_CURSOR_RE = re.compile(
- r'\s*SELECT',
- re.I | re.UNICODE)
-
-class PGExecutionContext(default.DefaultExecutionContext):
- def create_cursor(self):
- # TODO: coverage for server side cursors + select.for_update()
- is_server_side = \
- self.dialect.server_side_cursors and \
- ((self.compiled and isinstance(self.compiled.statement, expression.Selectable)
- and not getattr(self.compiled.statement, 'for_update', False)) \
- or \
- (
- (not self.compiled or isinstance(self.compiled.statement, expression._TextClause))
- and self.statement and SERVER_SIDE_CURSOR_RE.match(self.statement))
- )
- self.__is_server_side = is_server_side
- if is_server_side:
- # use server-side cursors:
- # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html
- ident = "c_%s_%s" % (hex(id(self))[2:], hex(random.randint(0, 65535))[2:])
- return self._connection.connection.cursor(ident)
+
+class PGCompiler(compiler.SQLCompiler):
+
+ def visit_match_op(self, binary, **kw):
+ return "%s @@ to_tsquery(%s)" % (self.process(binary.left), self.process(binary.right))
+
+ def visit_ilike_op(self, binary, **kw):
+ escape = binary.modifiers.get("escape", None)
+ return '%s ILIKE %s' % (self.process(binary.left), self.process(binary.right)) \
+ + (escape and ' ESCAPE \'%s\'' % escape or '')
+
+ def visit_notilike_op(self, binary, **kw):
+ escape = binary.modifiers.get("escape", None)
+ return '%s NOT ILIKE %s' % (self.process(binary.left), self.process(binary.right)) \
+ + (escape and ' ESCAPE \'%s\'' % escape or '')
+
+ def post_process_text(self, text):
+ if '%%' in text:
+ util.warn("The SQLAlchemy postgresql dialect now automatically escapes '%' in text() expressions to '%%'.")
+ return text.replace('%', '%%')
+
+ def visit_sequence(self, seq):
+ if seq.optional:
+ return None
+ else:
+ return "nextval('%s')" % self.preparer.format_sequence(seq)
+
+ def limit_clause(self, select):
+ text = ""
+ if select._limit is not None:
+ text += " \n LIMIT " + str(select._limit)
+ if select._offset is not None:
+ if select._limit is None:
+ text += " \n LIMIT ALL"
+ text += " OFFSET " + str(select._offset)
+ return text
+
+ def get_select_precolumns(self, select):
+ if select._distinct:
+ if isinstance(select._distinct, bool):
+ return "DISTINCT "
+ elif isinstance(select._distinct, (list, tuple)):
+ return "DISTINCT ON (" + ', '.join(
+ [(isinstance(col, basestring) and col or self.process(col)) for col in select._distinct]
+ )+ ") "
+ else:
+ return "DISTINCT ON (" + unicode(select._distinct) + ") "
+ else:
+ return ""
+
+ def for_update_clause(self, select):
+ if select.for_update == 'nowait':
+ return " FOR UPDATE NOWAIT"
+ else:
+ return super(PGCompiler, self).for_update_clause(select)
+
+ def returning_clause(self, stmt, returning_cols):
+
+ columns = [
+ self.process(
+ self.label_select_column(None, c, asfrom=False),
+ within_columns_clause=True,
+ result_map=self.result_map)
+ for c in expression._select_iterables(returning_cols)
+ ]
+
+ return 'RETURNING ' + ', '.join(columns)
+
+ def visit_extract(self, extract, **kwargs):
+ field = self.extract_map.get(extract.field, extract.field)
+ return "EXTRACT(%s FROM %s::timestamp)" % (
+ field, self.process(extract.expr))
+
+class PGDDLCompiler(compiler.DDLCompiler):
+ def get_column_specification(self, column, **kwargs):
+ colspec = self.preparer.format_column(column)
+ if column.primary_key and \
+ len(column.foreign_keys)==0 and \
+ column.autoincrement and \
+ isinstance(column.type, sqltypes.Integer) and \
+ not isinstance(column.type, sqltypes.SmallInteger) and \
+ (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
+ if isinstance(column.type, sqltypes.BigInteger):
+ colspec += " BIGSERIAL"
+ else:
+ colspec += " SERIAL"
else:
- return self._connection.connection.cursor()
+ colspec += " " + self.dialect.type_compiler.process(column.type)
+ default = self.get_column_default_string(column)
+ if default is not None:
+ colspec += " DEFAULT " + default
+
+ if not column.nullable:
+ colspec += " NOT NULL"
+ return colspec
+
+ def visit_create_sequence(self, create):
+ return "CREATE SEQUENCE %s" % self.preparer.format_sequence(create.element)
+
+ def visit_drop_sequence(self, drop):
+ return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element)
+
+ def visit_create_index(self, create):
+ preparer = self.preparer
+ index = create.element
+ text = "CREATE "
+ if index.unique:
+ text += "UNIQUE "
+ text += "INDEX %s ON %s (%s)" \
+ % (preparer.quote(self._validate_identifier(index.name, True), index.quote),
+ preparer.format_table(index.table),
+ ', '.join([preparer.format_column(c) for c in index.columns]))
+
+ if "postgres_where" in index.kwargs:
+ whereclause = index.kwargs['postgres_where']
+ util.warn_deprecated("The 'postgres_where' argument has been renamed to 'postgresql_where'.")
+ elif 'postgresql_where' in index.kwargs:
+ whereclause = index.kwargs['postgresql_where']
+ else:
+ whereclause = None
+
+ if whereclause is not None:
+ compiler = self._compile(whereclause, None)
+ # this might belong to the compiler class
+ inlined_clause = str(compiler) % dict(
+ [(key,bind.value) for key,bind in compiler.binds.iteritems()])
+ text += " WHERE " + inlined_clause
+ return text
+
+
+class PGDefaultRunner(base.DefaultRunner):
+ def __init__(self, context):
+ base.DefaultRunner.__init__(self, context)
+ # craete cursor which won't conflict with a server-side cursor
+ self.cursor = context._connection.connection.cursor()
- def get_result_proxy(self):
- if self.__is_server_side:
- return base.BufferedRowResultProxy(self)
+ def get_column_default(self, column, isinsert=True):
+ if column.primary_key:
+ # pre-execute passive defaults on primary keys
+ if (isinstance(column.server_default, schema.DefaultClause) and
+ column.server_default.arg is not None):
+ return self.execute_string("select %s" % column.server_default.arg)
+ elif (isinstance(column.type, sqltypes.Integer) and column.autoincrement) \
+ and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
+ sch = column.table.schema
+ # TODO: this has to build into the Sequence object so we can get the quoting
+ # logic from it
+ if sch is not None:
+ exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name)
+ else:
+ exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name)
+
+ if self.dialect.supports_unicode_statements:
+ return self.execute_string(exc)
+ else:
+ return self.execute_string(exc.encode(self.dialect.encoding))
+
+ return super(PGDefaultRunner, self).get_column_default(column)
+
+ def visit_sequence(self, seq):
+ if not seq.optional:
+ return self.execute_string(("select nextval('%s')" % \
+ self.dialect.identifier_preparer.format_sequence(seq)))
else:
- return base.ResultProxy(self)
+ return None
+
+class PGTypeCompiler(compiler.GenericTypeCompiler):
+ def visit_INET(self, type_):
+ return "INET"
+
+ def visit_CIDR(self, type_):
+ return "CIDR"
+
+ def visit_MACADDR(self, type_):
+ return "MACADDR"
+
+ def visit_FLOAT(self, type_):
+ if not type_.precision:
+ return "FLOAT"
+ else:
+ return "FLOAT(%(precision)s)" % {'precision': type_.precision}
+
+ def visit_DOUBLE_PRECISION(self, type_):
+ return "DOUBLE PRECISION"
+
+ def visit_BIGINT(self, type_):
+ return "BIGINT"
+
+ def visit_datetime(self, type_):
+ return self.visit_TIMESTAMP(type_)
+
+ def visit_TIMESTAMP(self, type_):
+ return "TIMESTAMP " + (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
+
+ def visit_TIME(self, type_):
+ return "TIME " + (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
+
+ def visit_INTERVAL(self, type_):
+ return "INTERVAL"
+
+ def visit_BIT(self, type_):
+ return "BIT"
+
+ def visit_UUID(self, type_):
+ return "UUID"
+
+ def visit_binary(self, type_):
+ return self.visit_BYTEA(type_)
+
+ def visit_BYTEA(self, type_):
+ return "BYTEA"
+
+ def visit_REAL(self, type_):
+ return "REAL"
+
+ def visit_ARRAY(self, type_):
+ return self.process(type_.item_type) + '[]'
+
+class PGIdentifierPreparer(compiler.IdentifierPreparer):
+ def _unquote_identifier(self, value):
+ if value[0] == self.initial_quote:
+ value = value[1:-1].replace('""','"')
+ return value
+
+class PGInspector(reflection.Inspector):
+
+ def __init__(self, conn):
+ reflection.Inspector.__init__(self, conn)
+
+ def get_table_oid(self, table_name, schema=None):
+ """Return the oid from `table_name` and `schema`."""
+
+ return self.dialect.get_table_oid(self.conn, table_name, schema,
+ info_cache=self.info_cache)
+
class PGDialect(default.DefaultDialect):
- name = 'postgres'
+ name = 'postgresql'
supports_alter = True
- supports_unicode_statements = False
max_identifier_length = 63
supports_sane_rowcount = True
- supports_sane_multi_rowcount = False
- preexecute_pk_sequences = True
- supports_pk_autoincrement = False
- default_paramstyle = 'pyformat'
+
+ supports_sequences = True
+ sequences_optional = True
+ preexecute_autoincrement_sequences = True
+ postfetch_lastrowid = False
+
supports_default_values = True
supports_empty_insert = False
+ default_paramstyle = 'pyformat'
+ ischema_names = ischema_names
+ colspecs = colspecs
- def __init__(self, server_side_cursors=False, **kwargs):
+ statement_compiler = PGCompiler
+ ddl_compiler = PGDDLCompiler
+ type_compiler = PGTypeCompiler
+ preparer = PGIdentifierPreparer
+ defaultrunner = PGDefaultRunner
+ inspector = PGInspector
+ isolation_level = None
+
+ def __init__(self, isolation_level=None, **kwargs):
default.DefaultDialect.__init__(self, **kwargs)
- self.server_side_cursors = server_side_cursors
-
- def dbapi(cls):
- import psycopg2 as psycopg
- return psycopg
- dbapi = classmethod(dbapi)
-
- def create_connect_args(self, url):
- opts = url.translate_connect_args(username='user')
- if 'port' in opts:
- opts['port'] = int(opts['port'])
- opts.update(url.query)
- return ([], opts)
+ self.isolation_level = isolation_level
- def type_descriptor(self, typeobj):
- return sqltypes.adapt_type(typeobj, colspecs)
+ def initialize(self, connection):
+ super(PGDialect, self).initialize(connection)
+ self.implicit_returning = self.server_version_info > (8, 3) and \
+ self.__dict__.get('implicit_returning', True)
+
+ def visit_pool(self, pool):
+ if self.isolation_level is not None:
+ class SetIsolationLevel(object):
+ def __init__(self, isolation_level):
+ self.isolation_level = isolation_level
+
+ def connect(self, conn, rec):
+ cursor = conn.cursor()
+ cursor.execute("SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL %s"
+ % self.isolation_level)
+ cursor.execute("COMMIT")
+ cursor.close()
+ pool.add_listener(SetIsolationLevel(self.isolation_level))
def do_begin_twophase(self, connection, xid):
self.do_begin(connection.connection)
def do_prepare_twophase(self, connection, xid):
- connection.execute(sql.text("PREPARE TRANSACTION :tid", bindparams=[sql.bindparam('tid', xid)]))
+ connection.execute("PREPARE TRANSACTION '%s'" % xid)
def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False):
if is_prepared:
if recover:
#FIXME: ugly hack to get out of transaction context when commiting recoverable transactions
# Must find out a way how to make the dbapi not open a transaction.
- connection.execute(sql.text("ROLLBACK"))
- connection.execute(sql.text("ROLLBACK PREPARED :tid", bindparams=[sql.bindparam('tid', xid)]))
- connection.execute(sql.text("BEGIN"))
+ connection.execute("ROLLBACK")
+ connection.execute("ROLLBACK PREPARED '%s'" % xid)
+ connection.execute("BEGIN")
self.do_rollback(connection.connection)
else:
self.do_rollback(connection.connection)
def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False):
if is_prepared:
if recover:
- connection.execute(sql.text("ROLLBACK"))
- connection.execute(sql.text("COMMIT PREPARED :tid", bindparams=[sql.bindparam('tid', xid)]))
- connection.execute(sql.text("BEGIN"))
+ connection.execute("ROLLBACK")
+ connection.execute("COMMIT PREPARED '%s'" % xid)
+ connection.execute("BEGIN")
self.do_rollback(connection.connection)
else:
self.do_commit(connection.connection)
return [row[0] for row in resultset]
def get_default_schema_name(self, connection):
- return connection.scalar("select current_schema()", None)
- get_default_schema_name = base.connection_memoize(
- ('dialect', 'default_schema_name'))(get_default_schema_name)
-
- def last_inserted_ids(self):
- if self.context.last_inserted_ids is None:
- raise exc.InvalidRequestError("no INSERT executed, or can't use cursor.lastrowid without PostgreSQL OIDs enabled")
- else:
- return self.context.last_inserted_ids
+ return connection.scalar("select current_schema()")
def has_table(self, connection, table_name, schema=None):
# seems like case gets folded in pg_class...
if schema is None:
- cursor = connection.execute("""select relname from pg_class c join pg_namespace n on n.oid=c.relnamespace where n.nspname=current_schema() and lower(relname)=%(name)s""", {'name':table_name.lower().encode(self.encoding)});
+ cursor = connection.execute(
+ sql.text("select relname from pg_class c join pg_namespace n on "
+ "n.oid=c.relnamespace where n.nspname=current_schema() and lower(relname)=:name",
+ bindparams=[sql.bindparam('name', unicode(table_name.lower()), type_=sqltypes.Unicode)]
+ )
+ )
else:
- cursor = connection.execute("""select relname from pg_class c join pg_namespace n on n.oid=c.relnamespace where n.nspname=%(schema)s and lower(relname)=%(name)s""", {'name':table_name.lower().encode(self.encoding), 'schema':schema});
- return bool( not not cursor.rowcount )
+ cursor = connection.execute(
+ sql.text("select relname from pg_class c join pg_namespace n on "
+ "n.oid=c.relnamespace where n.nspname=:schema and lower(relname)=:name",
+ bindparams=[sql.bindparam('name', unicode(table_name.lower()), type_=sqltypes.Unicode),
+ sql.bindparam('schema', unicode(schema), type_=sqltypes.Unicode)]
+ )
+ )
+ return bool(cursor.fetchone())
def has_sequence(self, connection, sequence_name):
- cursor = connection.execute('''SELECT relname FROM pg_class WHERE relkind = 'S' AND relnamespace IN ( SELECT oid FROM pg_namespace WHERE nspname NOT LIKE 'pg_%%' AND nspname != 'information_schema' AND relname = %(seqname)s);''', {'seqname': sequence_name.encode(self.encoding)})
- return bool(not not cursor.rowcount)
-
- def is_disconnect(self, e):
- if isinstance(e, self.dbapi.OperationalError):
- return 'closed the connection' in str(e) or 'connection not open' in str(e)
- elif isinstance(e, self.dbapi.InterfaceError):
- return 'connection already closed' in str(e) or 'cursor already closed' in str(e)
- elif isinstance(e, self.dbapi.ProgrammingError):
- # yes, it really says "losed", not "closed"
- return "losed the connection unexpectedly" in str(e)
- else:
- return False
+ cursor = connection.execute(
+ sql.text("SELECT relname FROM pg_class WHERE relkind = 'S' AND "
+ "relnamespace IN ( SELECT oid FROM pg_namespace WHERE nspname NOT LIKE 'pg_%' "
+ "AND nspname != 'information_schema' AND relname = :seqname)",
+ bindparams=[sql.bindparam('seqname', unicode(sequence_name), type_=sqltypes.Unicode)]
+ ))
+ return bool(cursor.fetchone())
def table_names(self, connection, schema):
- s = """
- SELECT relname
- FROM pg_class c
- WHERE relkind = 'r'
- AND '%(schema)s' = (select nspname from pg_namespace n where n.oid = c.relnamespace)
- """ % locals()
- return [row[0].decode(self.encoding) for row in connection.execute(s)]
+ result = connection.execute(
+ sql.text(u"""SELECT relname
+ FROM pg_class c
+ WHERE relkind = 'r'
+ AND '%s' = (select nspname from pg_namespace n where n.oid = c.relnamespace)""" % schema,
+ typemap = {'relname':sqltypes.Unicode}
+ )
+ )
+ return [row[0] for row in result]
- def server_version_info(self, connection):
+ def _get_server_version_info(self, connection):
v = connection.execute("select version()").scalar()
m = re.match('PostgreSQL (\d+)\.(\d+)\.(\d+)', v)
if not m:
raise AssertionError("Could not determine version from string '%s'" % v)
return tuple([int(x) for x in m.group(1, 2, 3)])
- def reflecttable(self, connection, table, include_columns):
- preparer = self.identifier_preparer
- if table.schema is not None:
+ @reflection.cache
+ def get_table_oid(self, connection, table_name, schema=None, **kw):
+ """Fetch the oid for schema.table_name.
+
+ Several reflection methods require the table oid. The idea for using
+ this method is that it can be fetched one time and cached for
+ subsequent calls.
+
+ """
+ table_oid = None
+ if schema is not None:
schema_where_clause = "n.nspname = :schema"
- schemaname = table.schema
- if isinstance(schemaname, str):
- schemaname = schemaname.decode(self.encoding)
else:
schema_where_clause = "pg_catalog.pg_table_is_visible(c.oid)"
- schemaname = None
+ query = """
+ SELECT c.oid
+ FROM pg_catalog.pg_class c
+ LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
+ WHERE (%s)
+ AND c.relname = :table_name AND c.relkind in ('r','v')
+ """ % schema_where_clause
+ # Since we're binding to unicode, table_name and schema_name must be
+ # unicode.
+ table_name = unicode(table_name)
+ if schema is not None:
+ schema = unicode(schema)
+ s = sql.text(query, bindparams=[
+ sql.bindparam('table_name', type_=sqltypes.Unicode),
+ sql.bindparam('schema', type_=sqltypes.Unicode)
+ ],
+ typemap={'oid':sqltypes.Integer}
+ )
+ c = connection.execute(s, table_name=table_name, schema=schema)
+ table_oid = c.scalar()
+ if table_oid is None:
+ raise exc.NoSuchTableError(table_name)
+ return table_oid
+
+ @reflection.cache
+ def get_schema_names(self, connection, **kw):
+ s = """
+ SELECT nspname
+ FROM pg_namespace
+ ORDER BY nspname
+ """
+ rp = connection.execute(s)
+ # what about system tables?
+ schema_names = [row[0].decode(self.encoding) for row in rp \
+ if not row[0].startswith('pg_')]
+ return schema_names
+
+ @reflection.cache
+ def get_table_names(self, connection, schema=None, **kw):
+ if schema is not None:
+ current_schema = schema
+ else:
+ current_schema = self.get_default_schema_name(connection)
+ table_names = self.table_names(connection, current_schema)
+ return table_names
+
+ @reflection.cache
+ def get_view_names(self, connection, schema=None, **kw):
+ if schema is not None:
+ current_schema = schema
+ else:
+ current_schema = self.get_default_schema_name(connection)
+ s = """
+ SELECT relname
+ FROM pg_class c
+ WHERE relkind = 'v'
+ AND '%(schema)s' = (select nspname from pg_namespace n where n.oid = c.relnamespace)
+ """ % dict(schema=current_schema)
+ view_names = [row[0].decode(self.encoding) for row in connection.execute(s)]
+ return view_names
+
+ @reflection.cache
+ def get_view_definition(self, connection, view_name, schema=None, **kw):
+ if schema is not None:
+ current_schema = schema
+ else:
+ current_schema = self.get_default_schema_name(connection)
+ s = """
+ SELECT definition FROM pg_views
+ WHERE schemaname = :schema
+ AND viewname = :view_name
+ """
+ rp = connection.execute(sql.text(s),
+ view_name=view_name, schema=current_schema)
+ if rp:
+ view_def = rp.scalar().decode(self.encoding)
+ return view_def
+ @reflection.cache
+ def get_columns(self, connection, table_name, schema=None, **kw):
+
+ table_oid = self.get_table_oid(connection, table_name, schema,
+ info_cache=kw.get('info_cache'))
SQL_COLS = """
SELECT a.attname,
pg_catalog.format_type(a.atttypid, a.atttypmod),
AS DEFAULT,
a.attnotnull, a.attnum, a.attrelid as table_oid
FROM pg_catalog.pg_attribute a
- WHERE a.attrelid = (
- SELECT c.oid
- FROM pg_catalog.pg_class c
- LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
- WHERE (%s)
- AND c.relname = :table_name AND c.relkind in ('r','v')
- ) AND a.attnum > 0 AND NOT a.attisdropped
+ WHERE a.attrelid = :table_oid
+ AND a.attnum > 0 AND NOT a.attisdropped
ORDER BY a.attnum
- """ % schema_where_clause
-
- s = sql.text(SQL_COLS, bindparams=[sql.bindparam('table_name', type_=sqltypes.Unicode), sql.bindparam('schema', type_=sqltypes.Unicode)], typemap={'attname':sqltypes.Unicode, 'default':sqltypes.Unicode})
- tablename = table.name
- if isinstance(tablename, str):
- tablename = tablename.decode(self.encoding)
- c = connection.execute(s, table_name=tablename, schema=schemaname)
+ """
+ s = sql.text(SQL_COLS,
+ bindparams=[sql.bindparam('table_oid', type_=sqltypes.Integer)],
+ typemap={'attname':sqltypes.Unicode, 'default':sqltypes.Unicode}
+ )
+ c = connection.execute(s, table_oid=table_oid)
rows = c.fetchall()
-
- if not rows:
- raise exc.NoSuchTableError(table.name)
-
domains = self._load_domains(connection)
-
+ # format columns
+ columns = []
for name, format_type, default, notnull, attnum, table_oid in rows:
- if include_columns and name not in include_columns:
- continue
-
## strip (30) from character varying(30)
attype = re.search('([^\([]+)', format_type).group(1)
nullable = not notnull
is_array = format_type.endswith('[]')
-
try:
charlen = re.search('\(([\d,]+)\)', format_type).group(1)
except:
charlen = False
-
numericprec = False
numericscale = False
if attype == 'numeric':
if attype == 'integer':
numericprec, numericscale = (32, 0)
charlen = False
-
args = []
for a in (charlen, numericprec, numericscale):
if a is None:
args.append(None)
elif a is not False:
args.append(int(a))
-
kwargs = {}
if attype == 'timestamp with time zone':
kwargs['timezone'] = True
elif attype == 'timestamp without time zone':
kwargs['timezone'] = False
-
- coltype = None
- if attype in ischema_names:
- coltype = ischema_names[attype]
+ if attype in self.ischema_names:
+ coltype = self.ischema_names[attype]
else:
if attype in domains:
domain = domains[attype]
- if domain['attype'] in ischema_names:
+ if domain['attype'] in self.ischema_names:
# A table can't override whether the domain is nullable.
nullable = domain['nullable']
-
if domain['default'] and not default:
# It can, however, override the default value, but can't set it to null.
default = domain['default']
- coltype = ischema_names[domain['attype']]
-
+ coltype = self.ischema_names[domain['attype']]
+ else:
+ coltype = None
if coltype:
coltype = coltype(*args, **kwargs)
if is_array:
util.warn("Did not recognize type '%s' of column '%s'" %
(attype, name))
coltype = sqltypes.NULLTYPE
-
- colargs = []
+ # adjust the default value
+ autoincrement = False
if default is not None:
match = re.search(r"""(nextval\(')([^']+)('.*$)""", default)
if match is not None:
+ autoincrement = True
# the default is related to a Sequence
- sch = table.schema
+ sch = schema
if '.' not in match.group(2) and sch is not None:
# unconditionally quote the schema name. this could
# later be enhanced to obey quoting rules / "quote schema"
default = match.group(1) + ('"%s"' % sch) + '.' + match.group(2) + match.group(3)
- colargs.append(schema.DefaultClause(sql.text(default)))
- table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs))
+ column_info = dict(name=name, type=coltype, nullable=nullable,
+ default=default, autoincrement=autoincrement)
+ columns.append(column_info)
+ return columns
- # Primary keys
+ @reflection.cache
+ def get_primary_keys(self, connection, table_name, schema=None, **kw):
+ table_oid = self.get_table_oid(connection, table_name, schema,
+ info_cache=kw.get('info_cache'))
PK_SQL = """
SELECT attname FROM pg_attribute
WHERE attrelid = (
SELECT indexrelid FROM pg_index i
- WHERE i.indrelid = :table
+ WHERE i.indrelid = :table_oid
AND i.indisprimary = 't')
ORDER BY attnum
"""
t = sql.text(PK_SQL, typemap={'attname':sqltypes.Unicode})
- c = connection.execute(t, table=table_oid)
- for row in c.fetchall():
- pk = row[0]
- if pk in table.c:
- col = table.c[pk]
- table.primary_key.add(col)
- if col.default is None:
- col.autoincrement = False
-
- # Foreign keys
+ c = connection.execute(t, table_oid=table_oid)
+ primary_keys = [r[0] for r in c.fetchall()]
+ return primary_keys
+
+ @reflection.cache
+ def get_foreign_keys(self, connection, table_name, schema=None, **kw):
+ preparer = self.identifier_preparer
+ table_oid = self.get_table_oid(connection, table_name, schema,
+ info_cache=kw.get('info_cache'))
FK_SQL = """
SELECT conname, pg_catalog.pg_get_constraintdef(oid, true) as condef
FROM pg_catalog.pg_constraint r
t = sql.text(FK_SQL, typemap={'conname':sqltypes.Unicode, 'condef':sqltypes.Unicode})
c = connection.execute(t, table=table_oid)
+ fkeys = []
for conname, condef in c.fetchall():
m = re.search('FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)', condef).groups()
(constrained_columns, referred_schema, referred_table, referred_columns) = m
constrained_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s*', constrained_columns)]
if referred_schema:
referred_schema = preparer._unquote_identifier(referred_schema)
- elif table.schema is not None and table.schema == self.get_default_schema_name(connection):
+ elif schema is not None and schema == self.get_default_schema_name(connection):
# no schema (i.e. its the default schema), and the table we're
# reflecting has the default schema explicit, then use that.
# i.e. try to use the user's conventions
- referred_schema = table.schema
+ referred_schema = schema
referred_table = preparer._unquote_identifier(referred_table)
referred_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s', referred_columns)]
-
- refspec = []
- if referred_schema is not None:
- schema.Table(referred_table, table.metadata, autoload=True, schema=referred_schema,
- autoload_with=connection)
- for column in referred_columns:
- refspec.append(".".join([referred_schema, referred_table, column]))
- else:
- schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection)
- for column in referred_columns:
- refspec.append(".".join([referred_table, column]))
-
- table.append_constraint(schema.ForeignKeyConstraint(constrained_columns, refspec, conname, link_to_name=True))
-
- # Indexes
+ fkey_d = {
+ 'name' : conname,
+ 'constrained_columns' : constrained_columns,
+ 'referred_schema' : referred_schema,
+ 'referred_table' : referred_table,
+ 'referred_columns' : referred_columns
+ }
+ fkeys.append(fkey_d)
+ return fkeys
+
+ @reflection.cache
+ def get_indexes(self, connection, table_name, schema, **kw):
+ table_oid = self.get_table_oid(connection, table_name, schema,
+ info_cache=kw.get('info_cache'))
IDX_SQL = """
SELECT c.relname, i.indisunique, i.indexprs, i.indpred,
a.attname
FROM pg_index i, pg_class c, pg_attribute a
- WHERE i.indrelid = :table AND i.indexrelid = c.oid
+ WHERE i.indrelid = :table_oid AND i.indexrelid = c.oid
AND a.attrelid = i.indexrelid AND i.indisprimary = 'f'
ORDER BY c.relname, a.attnum
"""
t = sql.text(IDX_SQL, typemap={'attname':sqltypes.Unicode})
- c = connection.execute(t, table=table_oid)
- indexes = {}
+ c = connection.execute(t, table_oid=table_oid)
+ index_names = {}
+ indexes = []
sv_idx_name = None
for row in c.fetchall():
idx_name, unique, expr, prd, col = row
-
if expr:
- if not idx_name == sv_idx_name:
+ if idx_name != sv_idx_name:
util.warn(
"Skipped unsupported reflection of expression-based index %s"
% idx_name)
"Predicate of partial index %s ignored during reflection"
% idx_name)
sv_idx_name = idx_name
-
- if not indexes.has_key(idx_name):
- indexes[idx_name] = [unique, []]
- indexes[idx_name][1].append(col)
-
- for name, (unique, columns) in indexes.items():
- schema.Index(name, *[table.columns[c] for c in columns],
- **dict(unique=unique))
-
-
+ if idx_name in index_names:
+ index_d = index_names[idx_name]
+ else:
+ index_d = {'column_names':[]}
+ indexes.append(index_d)
+ index_names[idx_name] = index_d
+ index_d['name'] = idx_name
+ index_d['column_names'].append(col)
+ index_d['unique'] = unique
+ return indexes
def _load_domains(self, connection):
## Load data types for domains:
return domains
-
-class PGCompiler(compiler.DefaultCompiler):
- operators = compiler.DefaultCompiler.operators.copy()
- operators.update(
- {
- sql_operators.mod : '%%',
- sql_operators.ilike_op: lambda x, y, escape=None: '%s ILIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
- sql_operators.notilike_op: lambda x, y, escape=None: '%s NOT ILIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
- sql_operators.match_op: lambda x, y: '%s @@ to_tsquery(%s)' % (x, y),
- }
- )
-
- functions = compiler.DefaultCompiler.functions.copy()
- functions.update (
- {
- 'TIMESTAMP':util.deprecated(message="Use a literal string 'timestamp <value>' instead")(lambda x:'TIMESTAMP %s' % x),
- }
- )
-
- def visit_sequence(self, seq):
- if seq.optional:
- return None
- else:
- return "nextval('%s')" % self.preparer.format_sequence(seq)
-
- def post_process_text(self, text):
- if '%%' in text:
- util.warn("The SQLAlchemy psycopg2 dialect now automatically escapes '%' in text() expressions to '%%'.")
- return text.replace('%', '%%')
-
- def limit_clause(self, select):
- text = ""
- if select._limit is not None:
- text += " \n LIMIT " + str(select._limit)
- if select._offset is not None:
- if select._limit is None:
- text += " \n LIMIT ALL"
- text += " OFFSET " + str(select._offset)
- return text
-
- def get_select_precolumns(self, select):
- if select._distinct:
- if isinstance(select._distinct, bool):
- return "DISTINCT "
- elif isinstance(select._distinct, (list, tuple)):
- return "DISTINCT ON (" + ', '.join(
- [(isinstance(col, basestring) and col or self.process(col)) for col in select._distinct]
- )+ ") "
- else:
- return "DISTINCT ON (" + unicode(select._distinct) + ") "
- else:
- return ""
-
- def for_update_clause(self, select):
- if select.for_update == 'nowait':
- return " FOR UPDATE NOWAIT"
- else:
- return super(PGCompiler, self).for_update_clause(select)
-
- def _append_returning(self, text, stmt):
- returning_cols = stmt.kwargs['postgres_returning']
- def flatten_columnlist(collist):
- for c in collist:
- if isinstance(c, expression.Selectable):
- for co in c.columns:
- yield co
- else:
- yield c
- columns = [self.process(c, within_columns_clause=True) for c in flatten_columnlist(returning_cols)]
- text += ' RETURNING ' + string.join(columns, ', ')
- return text
-
- def visit_update(self, update_stmt):
- text = super(PGCompiler, self).visit_update(update_stmt)
- if 'postgres_returning' in update_stmt.kwargs:
- return self._append_returning(text, update_stmt)
- else:
- return text
-
- def visit_insert(self, insert_stmt):
- text = super(PGCompiler, self).visit_insert(insert_stmt)
- if 'postgres_returning' in insert_stmt.kwargs:
- return self._append_returning(text, insert_stmt)
- else:
- return text
-
- def visit_extract(self, extract, **kwargs):
- field = self.extract_map.get(extract.field, extract.field)
- return "EXTRACT(%s FROM %s::timestamp)" % (
- field, self.process(extract.expr))
-
-
-class PGSchemaGenerator(compiler.SchemaGenerator):
- def get_column_specification(self, column, **kwargs):
- colspec = self.preparer.format_column(column)
- if column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and isinstance(column.type, sqltypes.Integer) and not isinstance(column.type, sqltypes.SmallInteger) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
- if isinstance(column.type, PGBigInteger):
- colspec += " BIGSERIAL"
- else:
- colspec += " SERIAL"
- else:
- colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
- default = self.get_column_default_string(column)
- if default is not None:
- colspec += " DEFAULT " + default
-
- if not column.nullable:
- colspec += " NOT NULL"
- return colspec
-
- def visit_sequence(self, sequence):
- if not sequence.optional and (not self.checkfirst or not self.dialect.has_sequence(self.connection, sequence.name)):
- self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence))
- self.execute()
-
- def visit_index(self, index):
- preparer = self.preparer
- self.append("CREATE ")
- if index.unique:
- self.append("UNIQUE ")
- self.append("INDEX %s ON %s (%s)" \
- % (preparer.quote(self._validate_identifier(index.name, True), index.quote),
- preparer.format_table(index.table),
- string.join([preparer.format_column(c) for c in index.columns], ', ')))
- whereclause = index.kwargs.get('postgres_where', None)
- if whereclause is not None:
- compiler = self._compile(whereclause, None)
- # this might belong to the compiler class
- inlined_clause = str(compiler) % dict(
- [(key,bind.value) for key,bind in compiler.binds.iteritems()])
- self.append(" WHERE " + inlined_clause)
- self.execute()
-
-class PGSchemaDropper(compiler.SchemaDropper):
- def visit_sequence(self, sequence):
- if not sequence.optional and (not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name)):
- self.append("DROP SEQUENCE %s" % self.preparer.format_sequence(sequence))
- self.execute()
-
-class PGDefaultRunner(base.DefaultRunner):
- def __init__(self, context):
- base.DefaultRunner.__init__(self, context)
- # craete cursor which won't conflict with a server-side cursor
- self.cursor = context._connection.connection.cursor()
-
- def get_column_default(self, column, isinsert=True):
- if column.primary_key:
- # pre-execute passive defaults on primary keys
- if (isinstance(column.server_default, schema.DefaultClause) and
- column.server_default.arg is not None):
- return self.execute_string("select %s" % column.server_default.arg)
- elif (isinstance(column.type, sqltypes.Integer) and column.autoincrement) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
- sch = column.table.schema
- # TODO: this has to build into the Sequence object so we can get the quoting
- # logic from it
- if sch is not None:
- exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name)
- else:
- exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name)
- return self.execute_string(exc.encode(self.dialect.encoding))
-
- return super(PGDefaultRunner, self).get_column_default(column)
-
- def visit_sequence(self, seq):
- if not seq.optional:
- return self.execute_string(("select nextval('%s')" % self.dialect.identifier_preparer.format_sequence(seq)))
- else:
- return None
-
-class PGIdentifierPreparer(compiler.IdentifierPreparer):
- def _unquote_identifier(self, value):
- if value[0] == self.initial_quote:
- value = value[1:-1].replace('""','"')
- return value
-
-dialect = PGDialect
-dialect.statement_compiler = PGCompiler
-dialect.schemagenerator = PGSchemaGenerator
-dialect.schemadropper = PGSchemaDropper
-dialect.preparer = PGIdentifierPreparer
-dialect.defaultrunner = PGDefaultRunner
-dialect.execution_ctx_cls = PGExecutionContext
--- /dev/null
+"""Support for the PostgreSQL database via the pg8000 driver.
+
+Connecting
+----------
+
+URLs are of the form `postgresql+pg8000://user@password@host:port/dbname[?key=value&key=value...]`.
+
+Unicode
+-------
+
+pg8000 requires that the postgresql client encoding be configured in the postgresql.conf file
+in order to use encodings other than ascii. Set this value to the same value as
+the "encoding" parameter on create_engine(), usually "utf-8".
+
+Interval
+--------
+
+Passing data from/to the Interval type is not supported as of yet.
+
+"""
+from sqlalchemy.engine import default
+import decimal
+from sqlalchemy import util
+from sqlalchemy import types as sqltypes
+from sqlalchemy.dialects.postgresql.base import PGDialect, PGCompiler
+
+class _PGNumeric(sqltypes.Numeric):
+ def bind_processor(self, dialect):
+ return None
+
+ def result_processor(self, dialect):
+ if self.asdecimal:
+ return None
+ else:
+ def process(value):
+ if isinstance(value, decimal.Decimal):
+ return float(value)
+ else:
+ return value
+ return process
+
+class PostgreSQL_pg8000ExecutionContext(default.DefaultExecutionContext):
+ pass
+
+class PostgreSQL_pg8000Compiler(PGCompiler):
+ def visit_mod(self, binary, **kw):
+ return self.process(binary.left) + " %% " + self.process(binary.right)
+
+
+class PostgreSQL_pg8000(PGDialect):
+ driver = 'pg8000'
+
+ supports_unicode_statements = True
+
+ supports_unicode_binds = True
+
+ default_paramstyle = 'format'
+ supports_sane_multi_rowcount = False
+ execution_ctx_cls = PostgreSQL_pg8000ExecutionContext
+ statement_compiler = PostgreSQL_pg8000Compiler
+
+ colspecs = util.update_copy(
+ PGDialect.colspecs,
+ {
+ sqltypes.Numeric : _PGNumeric,
+ sqltypes.Float: sqltypes.Float, # prevents _PGNumeric from being used
+ }
+ )
+
+ @classmethod
+ def dbapi(cls):
+ return __import__('pg8000').dbapi
+
+ def create_connect_args(self, url):
+ opts = url.translate_connect_args(username='user')
+ if 'port' in opts:
+ opts['port'] = int(opts['port'])
+ opts.update(url.query)
+ return ([], opts)
+
+ def is_disconnect(self, e):
+ return "connection is closed" in str(e)
+
+dialect = PostgreSQL_pg8000
--- /dev/null
+"""Support for the PostgreSQL database via the psycopg2 driver.
+
+Driver
+------
+
+The psycopg2 driver is supported, available at http://pypi.python.org/pypi/psycopg2/ .
+The dialect has several behaviors which are specifically tailored towards compatibility
+with this module.
+
+Note that psycopg1 is **not** supported.
+
+Connecting
+----------
+
+URLs are of the form `postgresql+psycopg2://user@password@host:port/dbname[?key=value&key=value...]`.
+
+psycopg2-specific keyword arguments which are accepted by :func:`~sqlalchemy.create_engine()` are:
+
+* *server_side_cursors* - Enable the usage of "server side cursors" for SQL statements which support
+ this feature. What this essentially means from a psycopg2 point of view is that the cursor is
+ created using a name, e.g. `connection.cursor('some name')`, which has the effect that result rows
+ are not immediately pre-fetched and buffered after statement execution, but are instead left
+ on the server and only retrieved as needed. SQLAlchemy's :class:`~sqlalchemy.engine.base.ResultProxy`
+ uses special row-buffering behavior when this feature is enabled, such that groups of 100 rows
+ at a time are fetched over the wire to reduce conversational overhead.
+
+* *isolation_level* - Sets the transaction isolation level for each transaction
+ within the engine. Valid isolation levels are `READ_COMMITTED`,
+ `READ_UNCOMMITTED`, `REPEATABLE_READ`, and `SERIALIZABLE`.
+
+Transactions
+------------
+
+The psycopg2 dialect fully supports SAVEPOINT and two-phase commit operations.
+
+
+"""
+
+import decimal, random, re
+from sqlalchemy import util
+from sqlalchemy.engine import base, default
+from sqlalchemy.sql import expression
+from sqlalchemy.sql import operators as sql_operators
+from sqlalchemy import types as sqltypes
+from sqlalchemy.dialects.postgresql.base import PGDialect, PGCompiler
+
+class _PGNumeric(sqltypes.Numeric):
+ def bind_processor(self, dialect):
+ return None
+
+ def result_processor(self, dialect):
+ if self.asdecimal:
+ return None
+ else:
+ def process(value):
+ if isinstance(value, decimal.Decimal):
+ return float(value)
+ else:
+ return value
+ return process
+
+
+# TODO: filter out 'FOR UPDATE' statements
+SERVER_SIDE_CURSOR_RE = re.compile(
+ r'\s*SELECT',
+ re.I | re.UNICODE)
+
+class PostgreSQL_psycopg2ExecutionContext(default.DefaultExecutionContext):
+ def create_cursor(self):
+ # TODO: coverage for server side cursors + select.for_update()
+ is_server_side = \
+ self.dialect.server_side_cursors and \
+ ((self.compiled and isinstance(self.compiled.statement, expression.Selectable)
+ and not getattr(self.compiled.statement, 'for_update', False)) \
+ or \
+ (
+ (not self.compiled or isinstance(self.compiled.statement, expression._TextClause))
+ and self.statement and SERVER_SIDE_CURSOR_RE.match(self.statement))
+ )
+
+ self.__is_server_side = is_server_side
+ if is_server_side:
+ # use server-side cursors:
+ # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html
+ ident = "c_%s_%s" % (hex(id(self))[2:], hex(random.randint(0, 65535))[2:])
+ return self._connection.connection.cursor(ident)
+ else:
+ return self._connection.connection.cursor()
+
+ def get_result_proxy(self):
+ if self.__is_server_side:
+ return base.BufferedRowResultProxy(self)
+ else:
+ return base.ResultProxy(self)
+
+class PostgreSQL_psycopg2Compiler(PGCompiler):
+ def visit_mod(self, binary, **kw):
+ return self.process(binary.left) + " %% " + self.process(binary.right)
+
+ def post_process_text(self, text):
+ return text.replace('%', '%%')
+
+class PostgreSQL_psycopg2(PGDialect):
+ driver = 'psycopg2'
+ supports_unicode_statements = False
+ default_paramstyle = 'pyformat'
+ supports_sane_multi_rowcount = False
+ execution_ctx_cls = PostgreSQL_psycopg2ExecutionContext
+ statement_compiler = PostgreSQL_psycopg2Compiler
+
+ colspecs = util.update_copy(
+ PGDialect.colspecs,
+ {
+ sqltypes.Numeric : _PGNumeric,
+ sqltypes.Float: sqltypes.Float, # prevents _PGNumeric from being used
+ }
+ )
+
+ def __init__(self, server_side_cursors=False, **kwargs):
+ PGDialect.__init__(self, **kwargs)
+ self.server_side_cursors = server_side_cursors
+
+ @classmethod
+ def dbapi(cls):
+ psycopg = __import__('psycopg2')
+ return psycopg
+
+ def create_connect_args(self, url):
+ opts = url.translate_connect_args(username='user')
+ if 'port' in opts:
+ opts['port'] = int(opts['port'])
+ opts.update(url.query)
+ return ([], opts)
+
+ def is_disconnect(self, e):
+ if isinstance(e, self.dbapi.OperationalError):
+ return 'closed the connection' in str(e) or 'connection not open' in str(e)
+ elif isinstance(e, self.dbapi.InterfaceError):
+ return 'connection already closed' in str(e) or 'cursor already closed' in str(e)
+ elif isinstance(e, self.dbapi.ProgrammingError):
+ # yes, it really says "losed", not "closed"
+ return "losed the connection unexpectedly" in str(e)
+ else:
+ return False
+
+dialect = PostgreSQL_psycopg2
+
--- /dev/null
+"""Support for the PostgreSQL database via py-postgresql.
+
+Connecting
+----------
+
+URLs are of the form `postgresql+pypostgresql://user@password@host:port/dbname[?key=value&key=value...]`.
+
+
+"""
+from sqlalchemy.engine import default
+import decimal
+from sqlalchemy import util
+from sqlalchemy import types as sqltypes
+from sqlalchemy.dialects.postgresql.base import PGDialect, PGDefaultRunner
+
+class PGNumeric(sqltypes.Numeric):
+ def bind_processor(self, dialect):
+ return None
+
+ def result_processor(self, dialect):
+ if self.asdecimal:
+ return None
+ else:
+ def process(value):
+ if isinstance(value, decimal.Decimal):
+ return float(value)
+ else:
+ return value
+ return process
+
+class PostgreSQL_pypostgresqlExecutionContext(default.DefaultExecutionContext):
+ pass
+
+class PostgreSQL_pypostgresqlDefaultRunner(PGDefaultRunner):
+ def execute_string(self, stmt, params=None):
+ return PGDefaultRunner.execute_string(self, stmt, params or ())
+
+class PostgreSQL_pypostgresql(PGDialect):
+ driver = 'pypostgresql'
+
+ supports_unicode_statements = True
+
+ supports_unicode_binds = True
+ description_encoding = None
+
+ defaultrunner = PostgreSQL_pypostgresqlDefaultRunner
+
+ default_paramstyle = 'format'
+
+ supports_sane_rowcount = False # alas....posting a bug now
+
+ supports_sane_multi_rowcount = False
+
+ execution_ctx_cls = PostgreSQL_pypostgresqlExecutionContext
+ colspecs = util.update_copy(
+ PGDialect.colspecs,
+ {
+ sqltypes.Numeric : PGNumeric,
+ sqltypes.Float: sqltypes.Float, # prevents PGNumeric from being used
+ }
+ )
+
+ @classmethod
+ def dbapi(cls):
+ from postgresql.driver import dbapi20
+ return dbapi20
+
+ def create_connect_args(self, url):
+ opts = url.translate_connect_args(username='user')
+ if 'port' in opts:
+ opts['port'] = int(opts['port'])
+ else:
+ opts['port'] = 5432
+ opts.update(url.query)
+ return ([], opts)
+
+ def is_disconnect(self, e):
+ return "connection is closed" in str(e)
+
+dialect = PostgreSQL_pypostgresql
--- /dev/null
+"""Support for the PostgreSQL database via the zxjdbc JDBC connector.
+
+JDBC Driver
+-----------
+
+The official Postgresql JDBC driver is at http://jdbc.postgresql.org/.
+
+"""
+from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector
+from sqlalchemy.dialects.postgresql.base import PGCompiler, PGDialect
+
+class PostgreSQL_jdbcCompiler(PGCompiler):
+
+ def post_process_text(self, text):
+ # Don't escape '%' like PGCompiler
+ return text
+
+
+class PostgreSQL_jdbc(ZxJDBCConnector, PGDialect):
+ statement_compiler = PostgreSQL_jdbcCompiler
+
+ jdbc_db_name = 'postgresql'
+ jdbc_driver_name = 'org.postgresql.Driver'
+
+ def _get_server_version_info(self, connection):
+ return tuple(int(x) for x in connection.connection.dbversion.split('.'))
+
+dialect = PostgreSQL_jdbc
--- /dev/null
+from sqlalchemy.dialects.sqlite import base, pysqlite
+
+# default dialect
+base.dialect = pysqlite.dialect
\ No newline at end of file
--- /dev/null
+# sqlite.py
+# Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+"""Support for the SQLite database.
+
+For information on connecting using a specific driver, see the documentation
+section regarding that driver.
+
+Date and Time Types
+-------------------
+
+SQLite does not have built-in DATE, TIME, or DATETIME types, and pysqlite does not provide
+out of the box functionality for translating values between Python `datetime` objects
+and a SQLite-supported format. SQLAlchemy's own :class:`~sqlalchemy.types.DateTime`
+and related types provide date formatting and parsing functionality when SQlite is used.
+The implementation classes are :class:`_SLDateTime`, :class:`_SLDate` and :class:`_SLTime`.
+These types represent dates and times as ISO formatted strings, which also nicely
+support ordering. There's no reliance on typical "libc" internals for these functions
+so historical dates are fully supported.
+
+
+"""
+
+import datetime, re, time
+
+from sqlalchemy import schema as sa_schema
+from sqlalchemy import sql, exc, pool, DefaultClause
+from sqlalchemy.engine import default
+from sqlalchemy.engine import reflection
+from sqlalchemy import types as sqltypes
+from sqlalchemy import util
+from sqlalchemy.sql import compiler, functions as sql_functions
+from sqlalchemy.util import NoneType
+
+from sqlalchemy.types import BLOB, BOOLEAN, CHAR, DATE, DATETIME, DECIMAL,\
+ FLOAT, INTEGER, NUMERIC, SMALLINT, TEXT, TIME,\
+ TIMESTAMP, VARCHAR
+
+
+class _NumericMixin(object):
+ def bind_processor(self, dialect):
+ type_ = self.asdecimal and str or float
+ def process(value):
+ if value is not None:
+ return type_(value)
+ else:
+ return value
+ return process
+
+class _SLNumeric(_NumericMixin, sqltypes.Numeric):
+ pass
+
+class _SLFloat(_NumericMixin, sqltypes.Float):
+ pass
+
+# since SQLite has no date types, we're assuming that SQLite via ODBC
+# or JDBC would similarly have no built in date support, so the "string" based logic
+# would apply to all implementing dialects.
+class _DateTimeMixin(object):
+ def _bind_processor(self, format, elements):
+ def process(value):
+ if not isinstance(value, (NoneType, datetime.date, datetime.datetime, datetime.time)):
+ raise TypeError("SQLite Date, Time, and DateTime types only accept Python datetime objects as input.")
+ elif value is not None:
+ return format % tuple([getattr(value, attr, 0) for attr in elements])
+ else:
+ return None
+ return process
+
+ def _result_processor(self, fn, regexp):
+ def process(value):
+ if value is not None:
+ return fn(*[int(x or 0) for x in regexp.match(value).groups()])
+ else:
+ return None
+ return process
+
+class _SLDateTime(_DateTimeMixin, sqltypes.DateTime):
+ __legacy_microseconds__ = False
+
+ def bind_processor(self, dialect):
+ if self.__legacy_microseconds__:
+ return self._bind_processor(
+ "%4.4d-%2.2d-%2.2d %2.2d:%2.2d:%2.2d.%s",
+ ("year", "month", "day", "hour", "minute", "second", "microsecond")
+ )
+ else:
+ return self._bind_processor(
+ "%4.4d-%2.2d-%2.2d %2.2d:%2.2d:%2.2d.%06d",
+ ("year", "month", "day", "hour", "minute", "second", "microsecond")
+ )
+
+ _reg = re.compile(r"(\d+)-(\d+)-(\d+)(?: (\d+):(\d+):(\d+)(?:\.(\d+))?)?")
+ def result_processor(self, dialect):
+ return self._result_processor(datetime.datetime, self._reg)
+
+class _SLDate(_DateTimeMixin, sqltypes.Date):
+ def bind_processor(self, dialect):
+ return self._bind_processor(
+ "%4.4d-%2.2d-%2.2d",
+ ("year", "month", "day")
+ )
+
+ _reg = re.compile(r"(\d+)-(\d+)-(\d+)")
+ def result_processor(self, dialect):
+ return self._result_processor(datetime.date, self._reg)
+
+class _SLTime(_DateTimeMixin, sqltypes.Time):
+ __legacy_microseconds__ = False
+
+ def bind_processor(self, dialect):
+ if self.__legacy_microseconds__:
+ return self._bind_processor(
+ "%2.2d:%2.2d:%2.2d.%s",
+ ("hour", "minute", "second", "microsecond")
+ )
+ else:
+ return self._bind_processor(
+ "%2.2d:%2.2d:%2.2d.%06d",
+ ("hour", "minute", "second", "microsecond")
+ )
+
+ _reg = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d+))?")
+ def result_processor(self, dialect):
+ return self._result_processor(datetime.time, self._reg)
+
+
+class _SLBoolean(sqltypes.Boolean):
+ def bind_processor(self, dialect):
+ def process(value):
+ if value is None:
+ return None
+ return value and 1 or 0
+ return process
+
+ def result_processor(self, dialect):
+ def process(value):
+ if value is None:
+ return None
+ return value == 1
+ return process
+
+colspecs = {
+ sqltypes.Boolean: _SLBoolean,
+ sqltypes.Date: _SLDate,
+ sqltypes.DateTime: _SLDateTime,
+ sqltypes.Float: _SLFloat,
+ sqltypes.Numeric: _SLNumeric,
+ sqltypes.Time: _SLTime,
+}
+
+ischema_names = {
+ 'BLOB': sqltypes.BLOB,
+ 'BOOL': sqltypes.BOOLEAN,
+ 'BOOLEAN': sqltypes.BOOLEAN,
+ 'CHAR': sqltypes.CHAR,
+ 'DATE': sqltypes.DATE,
+ 'DATETIME': sqltypes.DATETIME,
+ 'DECIMAL': sqltypes.DECIMAL,
+ 'FLOAT': sqltypes.FLOAT,
+ 'INT': sqltypes.INTEGER,
+ 'INTEGER': sqltypes.INTEGER,
+ 'NUMERIC': sqltypes.NUMERIC,
+ 'REAL': sqltypes.Numeric,
+ 'SMALLINT': sqltypes.SMALLINT,
+ 'TEXT': sqltypes.TEXT,
+ 'TIME': sqltypes.TIME,
+ 'TIMESTAMP': sqltypes.TIMESTAMP,
+ 'VARCHAR': sqltypes.VARCHAR,
+}
+
+
+
+class SQLiteCompiler(compiler.SQLCompiler):
+ extract_map = compiler.SQLCompiler.extract_map.copy()
+ extract_map.update({
+ 'month': '%m',
+ 'day': '%d',
+ 'year': '%Y',
+ 'second': '%S',
+ 'hour': '%H',
+ 'doy': '%j',
+ 'minute': '%M',
+ 'epoch': '%s',
+ 'dow': '%w',
+ 'week': '%W'
+ })
+
+ def visit_now_func(self, fn, **kw):
+ return "CURRENT_TIMESTAMP"
+
+ def visit_char_length_func(self, fn, **kw):
+ return "length%s" % self.function_argspec(fn)
+
+ def visit_cast(self, cast, **kwargs):
+ if self.dialect.supports_cast:
+ return super(SQLiteCompiler, self).visit_cast(cast)
+ else:
+ return self.process(cast.clause)
+
+ def visit_extract(self, extract):
+ try:
+ return "CAST(STRFTIME('%s', %s) AS INTEGER)" % (
+ self.extract_map[extract.field], self.process(extract.expr))
+ except KeyError:
+ raise exc.ArgumentError(
+ "%s is not a valid extract argument." % extract.field)
+
+ def limit_clause(self, select):
+ text = ""
+ if select._limit is not None:
+ text += " \n LIMIT " + str(select._limit)
+ if select._offset is not None:
+ if select._limit is None:
+ text += " \n LIMIT -1"
+ text += " OFFSET " + str(select._offset)
+ else:
+ text += " OFFSET 0"
+ return text
+
+ def for_update_clause(self, select):
+ # sqlite has no "FOR UPDATE" AFAICT
+ return ''
+
+
+class SQLiteDDLCompiler(compiler.DDLCompiler):
+
+ def get_column_specification(self, column, **kwargs):
+ colspec = self.preparer.format_column(column) + " " + self.dialect.type_compiler.process(column.type)
+ default = self.get_column_default_string(column)
+ if default is not None:
+ colspec += " DEFAULT " + default
+
+ if not column.nullable:
+ colspec += " NOT NULL"
+ return colspec
+
+class SQLiteTypeCompiler(compiler.GenericTypeCompiler):
+ def visit_binary(self, type_):
+ return self.visit_BLOB(type_)
+
+class SQLiteIdentifierPreparer(compiler.IdentifierPreparer):
+ reserved_words = set([
+ 'add', 'after', 'all', 'alter', 'analyze', 'and', 'as', 'asc',
+ 'attach', 'autoincrement', 'before', 'begin', 'between', 'by',
+ 'cascade', 'case', 'cast', 'check', 'collate', 'column', 'commit',
+ 'conflict', 'constraint', 'create', 'cross', 'current_date',
+ 'current_time', 'current_timestamp', 'database', 'default',
+ 'deferrable', 'deferred', 'delete', 'desc', 'detach', 'distinct',
+ 'drop', 'each', 'else', 'end', 'escape', 'except', 'exclusive',
+ 'explain', 'false', 'fail', 'for', 'foreign', 'from', 'full', 'glob',
+ 'group', 'having', 'if', 'ignore', 'immediate', 'in', 'index',
+ 'indexed', 'initially', 'inner', 'insert', 'instead', 'intersect', 'into', 'is',
+ 'isnull', 'join', 'key', 'left', 'like', 'limit', 'match', 'natural',
+ 'not', 'notnull', 'null', 'of', 'offset', 'on', 'or', 'order', 'outer',
+ 'plan', 'pragma', 'primary', 'query', 'raise', 'references',
+ 'reindex', 'rename', 'replace', 'restrict', 'right', 'rollback',
+ 'row', 'select', 'set', 'table', 'temp', 'temporary', 'then', 'to',
+ 'transaction', 'trigger', 'true', 'union', 'unique', 'update', 'using',
+ 'vacuum', 'values', 'view', 'virtual', 'when', 'where',
+ ])
+
+class SQLiteDialect(default.DefaultDialect):
+ name = 'sqlite'
+ supports_alter = False
+ supports_unicode_statements = True
+ supports_unicode_binds = True
+ supports_default_values = True
+ supports_empty_insert = False
+ supports_cast = True
+
+ default_paramstyle = 'qmark'
+ statement_compiler = SQLiteCompiler
+ ddl_compiler = SQLiteDDLCompiler
+ type_compiler = SQLiteTypeCompiler
+ preparer = SQLiteIdentifierPreparer
+ ischema_names = ischema_names
+ colspecs = colspecs
+ isolation_level = None
+
+ def __init__(self, isolation_level=None, **kwargs):
+ default.DefaultDialect.__init__(self, **kwargs)
+ if isolation_level and isolation_level not in ('SERIALIZABLE',
+ 'READ UNCOMMITTED'):
+ raise exc.ArgumentError("Invalid value for isolation_level. "
+ "Valid isolation levels for sqlite are 'SERIALIZABLE' and "
+ "'READ UNCOMMITTED'.")
+ self.isolation_level = isolation_level
+
+ def visit_pool(self, pool):
+ if self.isolation_level is not None:
+ class SetIsolationLevel(object):
+ def __init__(self, isolation_level):
+ if isolation_level == 'READ UNCOMMITTED':
+ self.isolation_level = 1
+ else:
+ self.isolation_level = 0
+
+ def connect(self, conn, rec):
+ cursor = conn.cursor()
+ cursor.execute("PRAGMA read_uncommitted = %d" % self.isolation_level)
+ cursor.close()
+ pool.add_listener(SetIsolationLevel(self.isolation_level))
+
+ def table_names(self, connection, schema):
+ if schema is not None:
+ qschema = self.identifier_preparer.quote_identifier(schema)
+ master = '%s.sqlite_master' % qschema
+ s = ("SELECT name FROM %s "
+ "WHERE type='table' ORDER BY name") % (master,)
+ rs = connection.execute(s)
+ else:
+ try:
+ s = ("SELECT name FROM "
+ " (SELECT * FROM sqlite_master UNION ALL "
+ " SELECT * FROM sqlite_temp_master) "
+ "WHERE type='table' ORDER BY name")
+ rs = connection.execute(s)
+ except exc.DBAPIError:
+ raise
+ s = ("SELECT name FROM sqlite_master "
+ "WHERE type='table' ORDER BY name")
+ rs = connection.execute(s)
+
+ return [row[0] for row in rs]
+
+ def has_table(self, connection, table_name, schema=None):
+ quote = self.identifier_preparer.quote_identifier
+ if schema is not None:
+ pragma = "PRAGMA %s." % quote(schema)
+ else:
+ pragma = "PRAGMA "
+ qtable = quote(table_name)
+ cursor = _pragma_cursor(connection.execute("%stable_info(%s)" % (pragma, qtable)))
+ row = cursor.fetchone()
+
+ # consume remaining rows, to work around
+ # http://www.sqlite.org/cvstrac/tktview?tn=1884
+ while cursor.fetchone() is not None:
+ pass
+
+ return (row is not None)
+
+ @reflection.cache
+ def get_table_names(self, connection, schema=None, **kw):
+ return self.table_names(connection, schema)
+
+ @reflection.cache
+ def get_view_names(self, connection, schema=None, **kw):
+ if schema is not None:
+ qschema = self.identifier_preparer.quote_identifier(schema)
+ master = '%s.sqlite_master' % qschema
+ s = ("SELECT name FROM %s "
+ "WHERE type='view' ORDER BY name") % (master,)
+ rs = connection.execute(s)
+ else:
+ try:
+ s = ("SELECT name FROM "
+ " (SELECT * FROM sqlite_master UNION ALL "
+ " SELECT * FROM sqlite_temp_master) "
+ "WHERE type='view' ORDER BY name")
+ rs = connection.execute(s)
+ except exc.DBAPIError:
+ raise
+ s = ("SELECT name FROM sqlite_master "
+ "WHERE type='view' ORDER BY name")
+ rs = connection.execute(s)
+
+ return [row[0] for row in rs]
+
+ @reflection.cache
+ def get_view_definition(self, connection, view_name, schema=None, **kw):
+ quote = self.identifier_preparer.quote_identifier
+ if schema is not None:
+ qschema = self.identifier_preparer.quote_identifier(schema)
+ master = '%s.sqlite_master' % qschema
+ s = ("SELECT sql FROM %s WHERE name = '%s'"
+ "AND type='view'") % (master, view_name)
+ rs = connection.execute(s)
+ else:
+ try:
+ s = ("SELECT sql FROM "
+ " (SELECT * FROM sqlite_master UNION ALL "
+ " SELECT * FROM sqlite_temp_master) "
+ "WHERE name = '%s' "
+ "AND type='view'") % view_name
+ rs = connection.execute(s)
+ except exc.DBAPIError:
+ raise
+ s = ("SELECT sql FROM sqlite_master WHERE name = '%s' "
+ "AND type='view'") % view_name
+ rs = connection.execute(s)
+
+ result = rs.fetchall()
+ if result:
+ return result[0].sql
+
+ @reflection.cache
+ def get_columns(self, connection, table_name, schema=None, **kw):
+ quote = self.identifier_preparer.quote_identifier
+ if schema is not None:
+ pragma = "PRAGMA %s." % quote(schema)
+ else:
+ pragma = "PRAGMA "
+ qtable = quote(table_name)
+ c = _pragma_cursor(connection.execute("%stable_info(%s)" % (pragma, qtable)))
+ found_table = False
+ columns = []
+ while True:
+ row = c.fetchone()
+ if row is None:
+ break
+ (name, type_, nullable, default, has_default, primary_key) = (row[1], row[2].upper(), not row[3], row[4], row[4] is not None, row[5])
+ name = re.sub(r'^\"|\"$', '', name)
+ if default:
+ default = re.sub(r"^\'|\'$", '', default)
+ match = re.match(r'(\w+)(\(.*?\))?', type_)
+ if match:
+ coltype = match.group(1)
+ args = match.group(2)
+ else:
+ coltype = "VARCHAR"
+ args = ''
+ try:
+ coltype = self.ischema_names[coltype]
+ except KeyError:
+ util.warn("Did not recognize type '%s' of column '%s'" %
+ (coltype, name))
+ coltype = sqltypes.NullType
+ if args is not None:
+ args = re.findall(r'(\d+)', args)
+ coltype = coltype(*[int(a) for a in args])
+
+ columns.append({
+ 'name' : name,
+ 'type' : coltype,
+ 'nullable' : nullable,
+ 'default' : default,
+ 'primary_key': primary_key
+ })
+ return columns
+
+ @reflection.cache
+ def get_primary_keys(self, connection, table_name, schema=None, **kw):
+ cols = self.get_columns(connection, table_name, schema, **kw)
+ pkeys = []
+ for col in cols:
+ if col['primary_key']:
+ pkeys.append(col['name'])
+ return pkeys
+
+ @reflection.cache
+ def get_foreign_keys(self, connection, table_name, schema=None, **kw):
+ quote = self.identifier_preparer.quote_identifier
+ if schema is not None:
+ pragma = "PRAGMA %s." % quote(schema)
+ else:
+ pragma = "PRAGMA "
+ qtable = quote(table_name)
+ c = _pragma_cursor(connection.execute("%sforeign_key_list(%s)" % (pragma, qtable)))
+ fkeys = []
+ fks = {}
+ while True:
+ row = c.fetchone()
+ if row is None:
+ break
+ (constraint_name, rtbl, lcol, rcol) = (row[0], row[2], row[3], row[4])
+ rtbl = re.sub(r'^\"|\"$', '', rtbl)
+ lcol = re.sub(r'^\"|\"$', '', lcol)
+ rcol = re.sub(r'^\"|\"$', '', rcol)
+ try:
+ fk = fks[constraint_name]
+ except KeyError:
+ fk = {
+ 'name' : constraint_name,
+ 'constrained_columns' : [],
+ 'referred_schema' : None,
+ 'referred_table' : rtbl,
+ 'referred_columns' : []
+ }
+ fkeys.append(fk)
+ fks[constraint_name] = fk
+
+ # look up the table based on the given table's engine, not 'self',
+ # since it could be a ProxyEngine
+ if lcol not in fk['constrained_columns']:
+ fk['constrained_columns'].append(lcol)
+ if rcol not in fk['referred_columns']:
+ fk['referred_columns'].append(rcol)
+ return fkeys
+
+ @reflection.cache
+ def get_indexes(self, connection, table_name, schema=None, **kw):
+ quote = self.identifier_preparer.quote_identifier
+ if schema is not None:
+ pragma = "PRAGMA %s." % quote(schema)
+ else:
+ pragma = "PRAGMA "
+ qtable = quote(table_name)
+ c = _pragma_cursor(connection.execute("%sindex_list(%s)" % (pragma, qtable)))
+ indexes = []
+ while True:
+ row = c.fetchone()
+ if row is None:
+ break
+ indexes.append(dict(name=row[1], column_names=[], unique=row[2]))
+ # loop thru unique indexes to get the column names.
+ for idx in indexes:
+ c = connection.execute("%sindex_info(%s)" % (pragma, quote(idx['name'])))
+ cols = idx['column_names']
+ while True:
+ row = c.fetchone()
+ if row is None:
+ break
+ cols.append(row[2])
+ return indexes
+
+
+def _pragma_cursor(cursor):
+ """work around SQLite issue whereby cursor.description is blank when PRAGMA returns no rows."""
+
+ if cursor.closed:
+ cursor._fetchone_impl = lambda: None
+ return cursor
--- /dev/null
+"""Support for the SQLite database via pysqlite.
+
+Note that pysqlite is the same driver as the ``sqlite3``
+module included with the Python distribution.
+
+Driver
+------
+
+When using Python 2.5 and above, the built in ``sqlite3`` driver is
+already installed and no additional installation is needed. Otherwise,
+the ``pysqlite2`` driver needs to be present. This is the same driver as
+``sqlite3``, just with a different name.
+
+The ``pysqlite2`` driver will be loaded first, and if not found, ``sqlite3``
+is loaded. This allows an explicitly installed pysqlite driver to take
+precedence over the built in one. As with all dialects, a specific
+DBAPI module may be provided to :func:`~sqlalchemy.create_engine()` to control
+this explicitly::
+
+ from sqlite3 import dbapi2 as sqlite
+ e = create_engine('sqlite+pysqlite:///file.db', module=sqlite)
+
+Full documentation on pysqlite is available at:
+`<http://www.initd.org/pub/software/pysqlite/doc/usage-guide.html>`_
+
+Connect Strings
+---------------
+
+The file specification for the SQLite database is taken as the "database" portion of
+the URL. Note that the format of a url is::
+
+ driver://user:pass@host/database
+
+This means that the actual filename to be used starts with the characters to the
+**right** of the third slash. So connecting to a relative filepath looks like::
+
+ # relative path
+ e = create_engine('sqlite:///path/to/database.db')
+
+An absolute path, which is denoted by starting with a slash, means you need **four**
+slashes::
+
+ # absolute path
+ e = create_engine('sqlite:////path/to/database.db')
+
+To use a Windows path, regular drive specifications and backslashes can be used.
+Double backslashes are probably needed::
+
+ # absolute path on Windows
+ e = create_engine('sqlite:///C:\\\\path\\\\to\\\\database.db')
+
+The sqlite ``:memory:`` identifier is the default if no filepath is present. Specify
+``sqlite://`` and nothing else::
+
+ # in-memory database
+ e = create_engine('sqlite://')
+
+Threading Behavior
+------------------
+
+Pysqlite connections do not support being moved between threads, unless
+the ``check_same_thread`` Pysqlite flag is set to ``False``. In addition,
+when using an in-memory SQLite database, the full database exists only within
+the scope of a single connection. It is reported that an in-memory
+database does not support being shared between threads regardless of the
+``check_same_thread`` flag - which means that a multithreaded
+application **cannot** share data from a ``:memory:`` database across threads
+unless access to the connection is limited to a single worker thread which communicates
+through a queueing mechanism to concurrent threads.
+
+To provide a default which accomodates SQLite's default threading capabilities
+somewhat reasonably, the SQLite dialect will specify that the :class:`~sqlalchemy.pool.SingletonThreadPool`
+be used by default. This pool maintains a single SQLite connection per thread
+that is held open up to a count of five concurrent threads. When more than five threads
+are used, a cleanup mechanism will dispose of excess unused connections.
+
+Two optional pool implementations that may be appropriate for particular SQLite usage scenarios:
+
+ * the :class:`sqlalchemy.pool.StaticPool` might be appropriate for a multithreaded
+ application using an in-memory database, assuming the threading issues inherent in
+ pysqlite are somehow accomodated for. This pool holds persistently onto a single connection
+ which is never closed, and is returned for all requests.
+
+ * the :class:`sqlalchemy.pool.NullPool` might be appropriate for an application that
+ makes use of a file-based sqlite database. This pool disables any actual "pooling"
+ behavior, and simply opens and closes real connections corresonding to the :func:`connect()`
+ and :func:`close()` methods. SQLite can "connect" to a particular file with very high
+ efficiency, so this option may actually perform better without the extra overhead
+ of :class:`SingletonThreadPool`. NullPool will of course render a ``:memory:`` connection
+ useless since the database would be lost as soon as the connection is "returned" to the pool.
+
+Unicode
+-------
+
+In contrast to SQLAlchemy's active handling of date and time types for pysqlite, pysqlite's
+default behavior regarding Unicode is that all strings are returned as Python unicode objects
+in all cases. So even if the :class:`~sqlalchemy.types.Unicode` type is
+*not* used, you will still always receive unicode data back from a result set. It is
+**strongly** recommended that you do use the :class:`~sqlalchemy.types.Unicode` type
+to represent strings, since it will raise a warning if a non-unicode Python string is
+passed from the user application. Mixing the usage of non-unicode objects with returned unicode objects can
+quickly create confusion, particularly when using the ORM as internal data is not
+always represented by an actual database result string.
+
+"""
+
+from sqlalchemy.dialects.sqlite.base import SQLiteDialect
+from sqlalchemy import schema, exc, pool
+from sqlalchemy.engine import default
+from sqlalchemy import types as sqltypes
+from sqlalchemy import util
+
+class SQLite_pysqlite(SQLiteDialect):
+ default_paramstyle = 'qmark'
+ poolclass = pool.SingletonThreadPool
+
+ # Py3K
+ #description_encoding = None
+
+ driver = 'pysqlite'
+
+ def __init__(self, **kwargs):
+ SQLiteDialect.__init__(self, **kwargs)
+ def vers(num):
+ return tuple([int(x) for x in num.split('.')])
+ if self.dbapi is not None:
+ sqlite_ver = self.dbapi.version_info
+ if sqlite_ver < (2, 1, '3'):
+ util.warn(
+ ("The installed version of pysqlite2 (%s) is out-dated "
+ "and will cause errors in some cases. Version 2.1.3 "
+ "or greater is recommended.") %
+ '.'.join([str(subver) for subver in sqlite_ver]))
+ if self.dbapi.sqlite_version_info < (3, 3, 8):
+ self.supports_default_values = False
+ self.supports_cast = (self.dbapi is None or vers(self.dbapi.sqlite_version) >= vers("3.2.3"))
+
+ @classmethod
+ def dbapi(cls):
+ try:
+ from pysqlite2 import dbapi2 as sqlite
+ except ImportError, e:
+ try:
+ from sqlite3 import dbapi2 as sqlite #try the 2.5+ stdlib name.
+ except ImportError:
+ raise e
+ return sqlite
+
+ def _get_server_version_info(self, connection):
+ return self.dbapi.sqlite_version_info
+
+ def create_connect_args(self, url):
+ if url.username or url.password or url.host or url.port:
+ raise exc.ArgumentError(
+ "Invalid SQLite URL: %s\n"
+ "Valid SQLite URL forms are:\n"
+ " sqlite:///:memory: (or, sqlite://)\n"
+ " sqlite:///relative/path/to/file.db\n"
+ " sqlite:////absolute/path/to/file.db" % (url,))
+ filename = url.database or ':memory:'
+
+ opts = url.query.copy()
+ util.coerce_kw_type(opts, 'timeout', float)
+ util.coerce_kw_type(opts, 'isolation_level', str)
+ util.coerce_kw_type(opts, 'detect_types', int)
+ util.coerce_kw_type(opts, 'check_same_thread', bool)
+ util.coerce_kw_type(opts, 'cached_statements', int)
+
+ return ([filename], opts)
+
+ def is_disconnect(self, e):
+ return isinstance(e, self.dbapi.ProgrammingError) and "Cannot operate on a closed database." in str(e)
+
+dialect = SQLite_pysqlite
--- /dev/null
+from sqlalchemy.dialects.sybase import base, pyodbc
+
+# default dialect
+base.dialect = pyodbc.dialect
\ No newline at end of file
--- /dev/null
+# sybase.py
+# Copyright (C) 2007 Fisch Asset Management AG http://www.fam.ch
+# Coding: Alexander Houben alexander.houben@thor-solutions.ch
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+"""Support for the Sybase iAnywhere database.
+
+This is not a full backend for Sybase ASE.
+
+This dialect is *not* tested on SQLAlchemy 0.6.
+
+
+Known issues / TODO:
+
+ * Uses the mx.ODBC driver from egenix (version 2.1.0)
+ * The current version of sqlalchemy.databases.sybase only supports
+ mx.ODBC.Windows (other platforms such as mx.ODBC.unixODBC still need
+ some development)
+ * Support for pyodbc has been built in but is not yet complete (needs
+ further development)
+ * Results of running tests/alltests.py:
+ Ran 934 tests in 287.032s
+ FAILED (failures=3, errors=1)
+ * Tested on 'Adaptive Server Anywhere 9' (version 9.0.1.1751)
+"""
+
+import datetime, operator
+
+from sqlalchemy import util, sql, schema, exc
+from sqlalchemy.sql import compiler, expression
+from sqlalchemy.engine import default, base
+from sqlalchemy import types as sqltypes
+from sqlalchemy.sql import operators as sql_operators
+from sqlalchemy import MetaData, Table, Column
+from sqlalchemy import String, Integer, SMALLINT, CHAR, ForeignKey
+from sqlalchemy.dialects.sybase.schema import *
+
+RESERVED_WORDS = set([
+ "add", "all", "alter", "and",
+ "any", "as", "asc", "backup",
+ "begin", "between", "bigint", "binary",
+ "bit", "bottom", "break", "by",
+ "call", "capability", "cascade", "case",
+ "cast", "char", "char_convert", "character",
+ "check", "checkpoint", "close", "comment",
+ "commit", "connect", "constraint", "contains",
+ "continue", "convert", "create", "cross",
+ "cube", "current", "current_timestamp", "current_user",
+ "cursor", "date", "dbspace", "deallocate",
+ "dec", "decimal", "declare", "default",
+ "delete", "deleting", "desc", "distinct",
+ "do", "double", "drop", "dynamic",
+ "else", "elseif", "encrypted", "end",
+ "endif", "escape", "except", "exception",
+ "exec", "execute", "existing", "exists",
+ "externlogin", "fetch", "first", "float",
+ "for", "force", "foreign", "forward",
+ "from", "full", "goto", "grant",
+ "group", "having", "holdlock", "identified",
+ "if", "in", "index", "index_lparen",
+ "inner", "inout", "insensitive", "insert",
+ "inserting", "install", "instead", "int",
+ "integer", "integrated", "intersect", "into",
+ "iq", "is", "isolation", "join",
+ "key", "lateral", "left", "like",
+ "lock", "login", "long", "match",
+ "membership", "message", "mode", "modify",
+ "natural", "new", "no", "noholdlock",
+ "not", "notify", "null", "numeric",
+ "of", "off", "on", "open",
+ "option", "options", "or", "order",
+ "others", "out", "outer", "over",
+ "passthrough", "precision", "prepare", "primary",
+ "print", "privileges", "proc", "procedure",
+ "publication", "raiserror", "readtext", "real",
+ "reference", "references", "release", "remote",
+ "remove", "rename", "reorganize", "resource",
+ "restore", "restrict", "return", "revoke",
+ "right", "rollback", "rollup", "save",
+ "savepoint", "scroll", "select", "sensitive",
+ "session", "set", "setuser", "share",
+ "smallint", "some", "sqlcode", "sqlstate",
+ "start", "stop", "subtrans", "subtransaction",
+ "synchronize", "syntax_error", "table", "temporary",
+ "then", "time", "timestamp", "tinyint",
+ "to", "top", "tran", "trigger",
+ "truncate", "tsequal", "unbounded", "union",
+ "unique", "unknown", "unsigned", "update",
+ "updating", "user", "using", "validate",
+ "values", "varbinary", "varchar", "variable",
+ "varying", "view", "wait", "waitfor",
+ "when", "where", "while", "window",
+ "with", "with_cube", "with_lparen", "with_rollup",
+ "within", "work", "writetext",
+ ])
+
+
+class SybaseImage(sqltypes.Binary):
+ __visit_name__ = 'IMAGE'
+
+class SybaseBit(sqltypes.TypeEngine):
+ __visit_name__ = 'BIT'
+
+class SybaseMoney(sqltypes.TypeEngine):
+ __visit_name__ = "MONEY"
+
+class SybaseSmallMoney(SybaseMoney):
+ __visit_name__ = "SMALLMONEY"
+
+class SybaseUniqueIdentifier(sqltypes.TypeEngine):
+ __visit_name__ = "UNIQUEIDENTIFIER"
+
+class SybaseBoolean(sqltypes.Boolean):
+ def result_processor(self, dialect):
+ def process(value):
+ if value is None:
+ return None
+ return value and True or False
+ return process
+
+ def bind_processor(self, dialect):
+ def process(value):
+ if value is True:
+ return 1
+ elif value is False:
+ return 0
+ elif value is None:
+ return None
+ else:
+ return value and True or False
+ return process
+
+class SybaseTypeCompiler(compiler.GenericTypeCompiler):
+ def visit_binary(self, type_):
+ return self.visit_IMAGE(type_)
+
+ def visit_boolean(self, type_):
+ return self.visit_BIT(type_)
+
+ def visit_IMAGE(self, type_):
+ return "IMAGE"
+
+ def visit_BIT(self, type_):
+ return "BIT"
+
+ def visit_MONEY(self, type_):
+ return "MONEY"
+
+ def visit_SMALLMONEY(self, type_):
+ return "SMALLMONEY"
+
+ def visit_UNIQUEIDENTIFIER(self, type_):
+ return "UNIQUEIDENTIFIER"
+
+colspecs = {
+ sqltypes.Binary : SybaseImage,
+ sqltypes.Boolean : SybaseBoolean,
+}
+
+ischema_names = {
+ 'integer' : sqltypes.INTEGER,
+ 'unsigned int' : sqltypes.Integer,
+ 'unsigned smallint' : sqltypes.SmallInteger,
+ 'unsigned bigint' : sqltypes.BigInteger,
+ 'bigint': sqltypes.BIGINT,
+ 'smallint' : sqltypes.SMALLINT,
+ 'tinyint' : sqltypes.SmallInteger,
+ 'varchar' : sqltypes.VARCHAR,
+ 'long varchar' : sqltypes.Text,
+ 'char' : sqltypes.CHAR,
+ 'decimal' : sqltypes.DECIMAL,
+ 'numeric' : sqltypes.NUMERIC,
+ 'float' : sqltypes.FLOAT,
+ 'double' : sqltypes.Numeric,
+ 'binary' : sqltypes.Binary,
+ 'long binary' : sqltypes.Binary,
+ 'varbinary' : sqltypes.Binary,
+ 'bit': SybaseBit,
+ 'image' : SybaseImage,
+ 'timestamp': sqltypes.TIMESTAMP,
+ 'money': SybaseMoney,
+ 'smallmoney': SybaseSmallMoney,
+ 'uniqueidentifier': SybaseUniqueIdentifier,
+
+}
+
+
+class SybaseExecutionContext(default.DefaultExecutionContext):
+
+ def post_exec(self):
+ if self.compiled.isinsert:
+ table = self.compiled.statement.table
+ # get the inserted values of the primary key
+
+ # get any sequence IDs first (using @@identity)
+ self.cursor.execute("SELECT @@identity AS lastrowid")
+ row = self.cursor.fetchone()
+ lastrowid = int(row[0])
+ if lastrowid > 0:
+ # an IDENTITY was inserted, fetch it
+ # FIXME: always insert in front ? This only works if the IDENTITY is the first column, no ?!
+ if not hasattr(self, '_last_inserted_ids') or self._last_inserted_ids is None:
+ self._last_inserted_ids = [lastrowid]
+ else:
+ self._last_inserted_ids = [lastrowid] + self._last_inserted_ids[1:]
+
+
+class SybaseSQLCompiler(compiler.SQLCompiler):
+
+ extract_map = compiler.SQLCompiler.extract_map.copy()
+ extract_map.update ({
+ 'doy': 'dayofyear',
+ 'dow': 'weekday',
+ 'milliseconds': 'millisecond'
+ })
+
+ def visit_mod(self, binary, **kw):
+ return "MOD(%s, %s)" % (self.process(binary.left), self.process(binary.right))
+
+ def bindparam_string(self, name):
+ res = super(SybaseSQLCompiler, self).bindparam_string(name)
+ if name.lower().startswith('literal'):
+ res = 'STRING(%s)' % res
+ return res
+
+ def get_select_precolumns(self, select):
+ s = select._distinct and "DISTINCT " or ""
+ if select._limit:
+ #if select._limit == 1:
+ #s += "FIRST "
+ #else:
+ #s += "TOP %s " % (select._limit,)
+ s += "TOP %s " % (select._limit,)
+ if select._offset:
+ if not select._limit:
+ # FIXME: sybase doesn't allow an offset without a limit
+ # so use a huge value for TOP here
+ s += "TOP 1000000 "
+ s += "START AT %s " % (select._offset+1,)
+ return s
+
+ def limit_clause(self, select):
+ # Limit in sybase is after the select keyword
+ return ""
+
+ def visit_binary(self, binary):
+ """Move bind parameters to the right-hand side of an operator, where possible."""
+ if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq:
+ return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator))
+ else:
+ return super(SybaseSQLCompiler, self).visit_binary(binary)
+
+ def label_select_column(self, select, column, asfrom):
+ if isinstance(column, expression.Function):
+ return column.label(None)
+ else:
+ return super(SybaseSQLCompiler, self).label_select_column(select, column, asfrom)
+
+ function_rewrites = {'current_date': 'getdate',
+ }
+ def visit_function(self, func):
+ func.name = self.function_rewrites.get(func.name, func.name)
+ res = super(SybaseSQLCompiler, self).visit_function(func)
+ if func.name.lower() == 'getdate':
+ # apply CAST operator
+ # FIXME: what about _pyodbc ?
+ cast = expression._Cast(func, SybaseDate_mxodbc)
+ # infinite recursion
+ # res = self.visit_cast(cast)
+ res = "CAST(%s AS %s)" % (res, self.process(cast.typeclause))
+ return res
+
+ def visit_extract(self, extract):
+ field = self.extract_map.get(extract.field, extract.field)
+ return 'DATEPART("%s", %s)' % (field, self.process(extract.expr))
+
+ def for_update_clause(self, select):
+ # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use
+ return ''
+
+ def order_by_clause(self, select):
+ order_by = self.process(select._order_by_clause)
+
+ # SybaseSQL only allows ORDER BY in subqueries if there is a LIMIT
+ if order_by and (not self.is_subquery() or select._limit):
+ return " ORDER BY " + order_by
+ else:
+ return ""
+
+
+class SybaseDDLCompiler(compiler.DDLCompiler):
+ def get_column_specification(self, column, **kwargs):
+
+ colspec = self.preparer.format_column(column)
+
+ if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \
+ column.autoincrement and isinstance(column.type, sqltypes.Integer):
+ if column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional):
+ column.sequence = schema.Sequence(column.name + '_seq')
+
+ if hasattr(column, 'sequence'):
+ column.table.has_sequence = column
+ #colspec += " numeric(30,0) IDENTITY"
+ colspec += " Integer IDENTITY"
+ else:
+ colspec += " " + self.dialect.type_compiler.process(column.type)
+
+ if not column.nullable:
+ colspec += " NOT NULL"
+
+ default = self.get_column_default_string(column)
+ if default is not None:
+ colspec += " DEFAULT " + default
+
+ return colspec
+
+ def visit_drop_index(self, drop):
+ index = drop.element
+ return "\nDROP INDEX %s.%s" % (
+ self.preparer.quote_identifier(index.table.name),
+ self.preparer.quote(self._validate_identifier(index.name, False), index.quote)
+ )
+
+class SybaseIdentifierPreparer(compiler.IdentifierPreparer):
+ reserved_words = RESERVED_WORDS
+
+class SybaseDialect(default.DefaultDialect):
+ name = 'sybase'
+ supports_unicode_statements = False
+ supports_sane_rowcount = False
+ supports_sane_multi_rowcount = False
+ colspecs = colspecs
+ ischema_names = ischema_names
+
+ type_compiler = SybaseTypeCompiler
+ statement_compiler = SybaseSQLCompiler
+ ddl_compiler = SybaseDDLCompiler
+ preparer = SybaseIdentifierPreparer
+
+ schema_name = "dba"
+
+ def __init__(self, **params):
+ super(SybaseDialect, self).__init__(**params)
+ self.text_as_varchar = False
+
+ def last_inserted_ids(self):
+ return self.context.last_inserted_ids
+
+ def get_default_schema_name(self, connection):
+ return self.schema_name
+
+ def table_names(self, connection, schema):
+ """Ignore the schema and the charset for now."""
+ s = sql.select([tables.c.table_name],
+ sql.not_(tables.c.table_name.like("SYS%")) and
+ tables.c.creator >= 100
+ )
+ rp = connection.execute(s)
+ return [row[0] for row in rp.fetchall()]
+
+ def has_table(self, connection, tablename, schema=None):
+ # FIXME: ignore schemas for sybase
+ s = sql.select([tables.c.table_name], tables.c.table_name == tablename)
+
+ c = connection.execute(s)
+ row = c.fetchone()
+ return row is not None
+
+ def reflecttable(self, connection, table, include_columns):
+ # Get base columns
+ if table.schema is not None:
+ current_schema = table.schema
+ else:
+ current_schema = self.get_default_schema_name(connection)
+
+ s = sql.select([columns, domains], tables.c.table_name==table.name, from_obj=[columns.join(tables).join(domains)], order_by=[columns.c.column_id])
+
+ c = connection.execute(s)
+ found_table = False
+ # makes sure we append the columns in the correct order
+ while True:
+ row = c.fetchone()
+ if row is None:
+ break
+ found_table = True
+ (name, type, nullable, charlen, numericprec, numericscale, default, primary_key, max_identity, table_id, column_id) = (
+ row[columns.c.column_name],
+ row[domains.c.domain_name],
+ row[columns.c.nulls] == 'Y',
+ row[columns.c.width],
+ row[domains.c.precision],
+ row[columns.c.scale],
+ row[columns.c.default],
+ row[columns.c.pkey] == 'Y',
+ row[columns.c.max_identity],
+ row[tables.c.table_id],
+ row[columns.c.column_id],
+ )
+ if include_columns and name not in include_columns:
+ continue
+
+ # FIXME: else problems with SybaseBinary(size)
+ if numericscale == 0:
+ numericscale = None
+
+ args = []
+ for a in (charlen, numericprec, numericscale):
+ if a is not None:
+ args.append(a)
+ coltype = self.ischema_names.get(type, None)
+ if coltype == SybaseString and charlen == -1:
+ coltype = SybaseText()
+ else:
+ if coltype is None:
+ util.warn("Did not recognize type '%s' of column '%s'" %
+ (type, name))
+ coltype = sqltypes.NULLTYPE
+ coltype = coltype(*args)
+ colargs = []
+ if default is not None:
+ colargs.append(schema.DefaultClause(sql.text(default)))
+
+ # any sequences ?
+ col = schema.Column(name, coltype, nullable=nullable, primary_key=primary_key, *colargs)
+ if int(max_identity) > 0:
+ col.sequence = schema.Sequence(name + '_identity')
+ col.sequence.start = int(max_identity)
+ col.sequence.increment = 1
+
+ # append the column
+ table.append_column(col)
+
+ # any foreign key constraint for this table ?
+ # note: no multi-column foreign keys are considered
+ s = "select st1.table_name, sc1.column_name, st2.table_name, sc2.column_name from systable as st1 join sysfkcol on st1.table_id=sysfkcol.foreign_table_id join sysforeignkey join systable as st2 on sysforeignkey.primary_table_id = st2.table_id join syscolumn as sc1 on sysfkcol.foreign_column_id=sc1.column_id and sc1.table_id=st1.table_id join syscolumn as sc2 on sysfkcol.primary_column_id=sc2.column_id and sc2.table_id=st2.table_id where st1.table_name='%(table_name)s';" % { 'table_name' : table.name }
+ c = connection.execute(s)
+ foreignKeys = {}
+ while True:
+ row = c.fetchone()
+ if row is None:
+ break
+ (foreign_table, foreign_column, primary_table, primary_column) = (
+ row[0], row[1], row[2], row[3],
+ )
+ if not primary_table in foreignKeys.keys():
+ foreignKeys[primary_table] = [['%s' % (foreign_column)], ['%s.%s'%(primary_table, primary_column)]]
+ else:
+ foreignKeys[primary_table][0].append('%s'%(foreign_column))
+ foreignKeys[primary_table][1].append('%s.%s'%(primary_table, primary_column))
+ for primary_table in foreignKeys.iterkeys():
+ #table.append_constraint(schema.ForeignKeyConstraint(['%s.%s'%(foreign_table, foreign_column)], ['%s.%s'%(primary_table,primary_column)]))
+ table.append_constraint(schema.ForeignKeyConstraint(foreignKeys[primary_table][0], foreignKeys[primary_table][1], link_to_name=True))
+
+ if not found_table:
+ raise exc.NoSuchTableError(table.name)
+
--- /dev/null
+from sqlalchemy.dialects.sybase.base import SybaseDialect, SybaseExecutionContext
+from sqlalchemy.connectors.mxodbc import MxODBCConnector
+
+class SybaseExecutionContext_mxodbc(SybaseExecutionContext):
+ pass
+
+class Sybase_mxodbc(MxODBCConnector, SybaseDialect):
+ execution_ctx_cls = SybaseExecutionContext_mxodbc
+
+dialect = Sybase_mxodbc
\ No newline at end of file
--- /dev/null
+from sqlalchemy.dialects.sybase.base import SybaseDialect, SybaseExecutionContext
+from sqlalchemy.connectors.pyodbc import PyODBCConnector
+
+class SybaseExecutionContext_pyodbc(SybaseExecutionContext):
+ pass
+
+
+class Sybase_pyodbc(PyODBCConnector, SybaseDialect):
+ execution_ctx_cls = SybaseExecutionContext_pyodbc
+
+dialect = Sybase_pyodbc
\ No newline at end of file
--- /dev/null
+from sqlalchemy import *
+
+ischema = MetaData()
+
+tables = Table("SYSTABLE", ischema,
+ Column("table_id", Integer, primary_key=True),
+ Column("file_id", SMALLINT),
+ Column("table_name", CHAR(128)),
+ Column("table_type", CHAR(10)),
+ Column("creator", Integer),
+ #schema="information_schema"
+ )
+
+domains = Table("SYSDOMAIN", ischema,
+ Column("domain_id", Integer, primary_key=True),
+ Column("domain_name", CHAR(128)),
+ Column("type_id", SMALLINT),
+ Column("precision", SMALLINT, quote=True),
+ #schema="information_schema"
+ )
+
+columns = Table("SYSCOLUMN", ischema,
+ Column("column_id", Integer, primary_key=True),
+ Column("table_id", Integer, ForeignKey(tables.c.table_id)),
+ Column("pkey", CHAR(1)),
+ Column("column_name", CHAR(128)),
+ Column("nulls", CHAR(1)),
+ Column("width", SMALLINT),
+ Column("domain_id", SMALLINT, ForeignKey(domains.c.domain_id)),
+ # FIXME: should be mx.BIGINT
+ Column("max_identity", Integer),
+ # FIXME: should be mx.ODBC.Windows.LONGVARCHAR
+ Column("default", String),
+ Column("scale", Integer),
+ #schema="information_schema"
+ )
+
+foreignkeys = Table("SYSFOREIGNKEY", ischema,
+ Column("foreign_table_id", Integer, ForeignKey(tables.c.table_id), primary_key=True),
+ Column("foreign_key_id", SMALLINT, primary_key=True),
+ Column("primary_table_id", Integer, ForeignKey(tables.c.table_id)),
+ #schema="information_schema"
+ )
+fkcols = Table("SYSFKCOL", ischema,
+ Column("foreign_table_id", Integer, ForeignKey(columns.c.table_id), primary_key=True),
+ Column("foreign_key_id", SMALLINT, ForeignKey(foreignkeys.c.foreign_key_id), primary_key=True),
+ Column("foreign_column_id", Integer, ForeignKey(columns.c.column_id), primary_key=True),
+ Column("primary_column_id", Integer),
+ #schema="information_schema"
+ )
+
--- /dev/null
+Rules for Migrating TypeEngine classes to 0.6
+---------------------------------------------
+
+1. the TypeEngine classes are used for:
+
+ a. Specifying behavior which needs to occur for bind parameters
+ or result row columns.
+
+ b. Specifying types that are entirely specific to the database
+ in use and have no analogue in the sqlalchemy.types package.
+
+ c. Specifying types where there is an analogue in sqlalchemy.types,
+ but the database in use takes vendor-specific flags for those
+ types.
+
+ d. If a TypeEngine class doesn't provide any of this, it should be
+ *removed* from the dialect.
+
+2. the TypeEngine classes are *no longer* used for generating DDL. Dialects
+now have a TypeCompiler subclass which uses the same visit_XXX model as
+other compilers.
+
+3. the "ischema_names" and "colspecs" dictionaries are now required members on
+the Dialect class.
+
+4. The names of types within dialects are now important. If a dialect-specific type
+is a subclass of an existing generic type and is only provided for bind/result behavior,
+the current mixed case naming can remain, i.e. _PGNumeric for Numeric - in this case,
+end users would never need to use _PGNumeric directly. However, if a dialect-specific
+type is specifying a type *or* arguments that are not present generically, it should
+match the real name of the type on that backend, in uppercase. E.g. postgresql.INET,
+mysql.ENUM, postgresql.ARRAY.
+
+Or follow this handy flowchart:
+
+ is the type meant to provide bind/result is the type the same name as an
+ behavior to a generic type (i.e. MixedCase) ---- no ---> UPPERCASE type in types.py ?
+ type in types.py ? | |
+ | no yes
+ yes | |
+ | | does your type need special
+ | +<--- yes --- behavior or arguments ?
+ | | |
+ | | no
+ name the type using | |
+ _MixedCase, i.e. v V
+ _OracleBoolean. it name the type don't make a
+ stays private to the dialect identically as that type, make sure the dialect's
+ and is invoked *only* via within the DB, base.py imports the types.py
+ the colspecs dict. using UPPERCASE UPPERCASE name into its namespace
+ | (i.e. BIT, NCHAR, INTERVAL).
+ | Users can import it.
+ | |
+ v v
+ subclass the closest is the name of this type
+ MixedCase type types.py, identical to an UPPERCASE
+ i.e. <--- no ------- name in types.py ?
+ class _DateTime(types.DateTime),
+ class DATETIME2(types.DateTime), |
+ class BIT(types.TypeEngine). yes
+ |
+ v
+ the type should
+ subclass the
+ UPPERCASE
+ type in types.py
+ (i.e. class BLOB(types.BLOB))
+
+
+Example 1. pysqlite needs bind/result processing for the DateTime type in types.py,
+which applies to all DateTimes and subclasses. It's named _SLDateTime and
+subclasses types.DateTime.
+
+Example 2. MS-SQL has a TIME type which takes a non-standard "precision" argument
+that is rendered within DDL. So it's named TIME in the MS-SQL dialect's base.py,
+and subclasses types.TIME. Users can then say mssql.TIME(precision=10).
+
+Example 3. MS-SQL dialects also need special bind/result processing for date
+But its DATE type doesn't render DDL differently than that of a plain
+DATE, i.e. it takes no special arguments. Therefore we are just adding behavior
+to types.Date, so it's named _MSDate in the MS-SQL dialect's base.py, and subclasses
+types.Date.
+
+Example 4. MySQL has a SET type, there's no analogue for this in types.py. So
+MySQL names it SET in the dialect's base.py, and it subclasses types.String, since
+it ultimately deals with strings.
+
+Example 5. Postgresql has a DATETIME type. The DBAPIs handle dates correctly,
+and no special arguments are used in PG's DDL beyond what types.py provides.
+Postgresql dialect therefore imports types.DATETIME into its base.py.
+
+Ideally one should be able to specify a schema using names imported completely from a
+dialect, all matching the real name on that backend:
+
+ from sqlalchemy.dialects.postgresql import base as pg
+
+ t = Table('mytable', metadata,
+ Column('id', pg.INTEGER, primary_key=True),
+ Column('name', pg.VARCHAR(300)),
+ Column('inetaddr', pg.INET)
+ )
+
+where above, the INTEGER and VARCHAR types are ultimately from sqlalchemy.types,
+but the PG dialect makes them available in its own namespace.
+
+5. "colspecs" now is a dictionary of generic or uppercased types from sqlalchemy.types
+linked to types specified in the dialect. Again, if a type in the dialect does not
+specify any special behavior for bind_processor() or result_processor() and does not
+indicate a special type only available in this database, it must be *removed* from the
+module and from this dictionary.
+
+6. "ischema_names" indicates string descriptions of types as returned from the database
+linked to TypeEngine classes.
+
+ a. The string name should be matched to the most specific type possible within
+ sqlalchemy.types, unless there is no matching type within sqlalchemy.types in which
+ case it points to a dialect type. *It doesn't matter* if the dialect has it's
+ own subclass of that type with special bind/result behavior - reflect to the types.py
+ UPPERCASE type as much as possible. With very few exceptions, all types
+ should reflect to an UPPERCASE type.
+
+ b. If the dialect contains a matching dialect-specific type that takes extra arguments
+ which the generic one does not, then point to the dialect-specific type. E.g.
+ mssql.VARCHAR takes a "collation" parameter which should be preserved.
+
+5. DDL, or what was formerly issued by "get_col_spec()", is now handled exclusively by
+a subclass of compiler.GenericTypeCompiler.
+
+ a. your TypeCompiler class will receive generic and uppercase types from
+ sqlalchemy.types. Do not assume the presence of dialect-specific attributes on
+ these types.
+
+ b. the visit_UPPERCASE methods on GenericTypeCompiler should *not* be overridden with
+ methods that produce a different DDL name. Uppercase types don't do any kind of
+ "guessing" - if visit_TIMESTAMP is called, the DDL should render as TIMESTAMP in
+ all cases, regardless of whether or not that type is legal on the backend database.
+
+ c. the visit_UPPERCASE methods *should* be overridden with methods that add additional
+ arguments and flags to those types.
+
+ d. the visit_lowercase methods are overridden to provide an interpretation of a generic
+ type. E.g. visit_binary() might be overridden to say "return self.visit_BIT(type_)".
+
+ e. visit_lowercase methods should *never* render strings directly - it should always
+ be via calling a visit_UPPERCASE() method.
within a URL.
"""
-import sqlalchemy.databases
+# not sure what this was used for
+#import sqlalchemy.databases
+
from sqlalchemy.engine.base import (
BufferedColumnResultProxy,
BufferedColumnRow,
ResultProxy,
RootTransaction,
RowProxy,
- SchemaIterator,
Transaction,
- TwoPhaseTransaction
+ TwoPhaseTransaction,
+ TypeCompiler
)
from sqlalchemy.engine import strategies
from sqlalchemy import util
'ResultProxy',
'RootTransaction',
'RowProxy',
- 'SchemaIterator',
'Transaction',
'TwoPhaseTransaction',
+ 'TypeCompiler',
'create_engine',
'engine_from_config',
)
The URL is a string in the form
``dialect://user:password@host/dbname[?key=value..]``, where
- ``dialect`` is a name such as ``mysql``, ``oracle``, ``postgres``,
+ ``dialect`` is a name such as ``mysql``, ``oracle``, ``postgresql``,
etc. Alternatively, the URL can be an instance of
:class:`~sqlalchemy.engine.url.URL`.
Defines the basic components used to interface DB-API modules with
higher-level statement-construction, connection-management, execution
and result contexts.
-
"""
-__all__ = ['BufferedColumnResultProxy', 'BufferedColumnRow', 'BufferedRowResultProxy', 'Compiled', 'Connectable',
- 'Connection', 'DefaultRunner', 'Dialect', 'Engine', 'ExecutionContext', 'NestedTransaction', 'ResultProxy',
- 'RootTransaction', 'RowProxy', 'SchemaIterator', 'StringIO', 'Transaction', 'TwoPhaseTransaction', 'connection_memoize']
+__all__ = [
+ 'BufferedColumnResultProxy', 'BufferedColumnRow', 'BufferedRowResultProxy',
+ 'Compiled', 'Connectable', 'Connection', 'DefaultRunner', 'Dialect', 'Engine',
+ 'ExecutionContext', 'NestedTransaction', 'ResultProxy', 'RootTransaction',
+ 'RowProxy', 'SchemaIterator', 'StringIO', 'Transaction', 'TwoPhaseTransaction',
+ 'connection_memoize']
-import inspect, StringIO
+import inspect, StringIO, sys
from sqlalchemy import exc, schema, util, types, log
from sqlalchemy.sql import expression
ExecutionContext, Compiled, DefaultGenerator, and TypeEngine.
All Dialects implement the following attributes:
-
+
name
- identifying name for the dialect (i.e. 'sqlite')
-
+ identifying name for the dialect from a DBAPI-neutral point of view
+ (i.e. 'sqlite')
+
+ driver
+ identifying name for the dialect's DBAPI
+
positional
True if the paramstyle for this Dialect is positional.
type of encoding to use for unicode, usually defaults to
'utf-8'.
- schemagenerator
- a :class:`~sqlalchemy.schema.SchemaVisitor` class which generates
- schemas.
-
- schemadropper
- a :class:`~sqlalchemy.schema.SchemaVisitor` class which drops schemas.
-
defaultrunner
a :class:`~sqlalchemy.schema.SchemaVisitor` class which executes
defaults.
statement_compiler
- a :class:`~sqlalchemy.engine.base.Compiled` class used to compile SQL
- statements
+ a :class:`~Compiled` class used to compile SQL statements
+
+ ddl_compiler
+ a :class:`~Compiled` class used to compile DDL statements
+
+ server_version_info
+ a tuple containing a version number for the DB backend in use.
+ This value is only available for supporting dialects, and only for
+ a dialect that's been associated with a connection pool via
+ create_engine() or otherwise had its ``initialize()`` method called
+ with a conneciton.
+
+ execution_ctx_cls
+ a :class:`ExecutionContext` class used to handle statement execution
preparer
a :class:`~sqlalchemy.sql.compiler.IdentifierPreparer` class used to
The maximum length of identifier names.
supports_unicode_statements
- Indicate whether the DB-API can receive SQL statements as Python unicode strings
+ Indicate whether the DB-API can receive SQL statements as Python
+ unicode strings
+
+ supports_unicode_binds
+ Indicate whether the DB-API can receive string bind parameters
+ as Python unicode strings
supports_sane_rowcount
- Indicate whether the dialect properly implements rowcount for ``UPDATE`` and ``DELETE`` statements.
+ Indicate whether the dialect properly implements rowcount for
+ ``UPDATE`` and ``DELETE`` statements.
supports_sane_multi_rowcount
- Indicate whether the dialect properly implements rowcount for ``UPDATE`` and ``DELETE`` statements
- when executed via executemany.
-
- preexecute_pk_sequences
- Indicate if the dialect should pre-execute sequences on primary key
- columns during an INSERT, if it's desired that the new row's primary key
- be available after execution.
-
- supports_pk_autoincrement
- Indicates if the dialect should allow the database to passively assign
- a primary key column value.
-
+ Indicate whether the dialect properly implements rowcount for
+ ``UPDATE`` and ``DELETE`` statements when executed via
+ executemany.
+
+ preexecute_autoincrement_sequences
+ True if 'implicit' primary key functions must be executed separately
+ in order to get their value. This is currently oriented towards
+ Postgresql.
+
+ implicit_returning
+ use RETURNING or equivalent during INSERT execution in order to load
+ newly generated primary keys and other column defaults in one execution,
+ which are then available via inserted_primary_key.
+ If an insert statement has returning() specified explicitly,
+ the "implicit" functionality is not used and inserted_primary_key
+ will not be available.
+
dbapi_type_map
A mapping of DB-API type objects present in this Dialect's
- DB-API implmentation mapped to TypeEngine implementations used
+ DB-API implementation mapped to TypeEngine implementations used
by the dialect.
This is used to apply types to result sets based on the DB-API
result sets against textual statements where no explicit
typemap was present.
- supports_default_values
- Indicates if the construct ``INSERT INTO tablename DEFAULT VALUES`` is supported
+ colspecs
+ A dictionary of TypeEngine classes from sqlalchemy.types mapped
+ to subclasses that are specific to the dialect class. This
+ dictionary is class-level only and is not accessed from the
+ dialect instance itself.
- description_encoding
- type of encoding to use for unicode when working with metadata
- descriptions. If set to ``None`` no encoding will be done.
- This usually defaults to 'utf-8'.
+ supports_default_values
+ Indicates if the construct ``INSERT INTO tablename DEFAULT
+ VALUES`` is supported
"""
def create_connect_args(self, url):
raise NotImplementedError()
+ @classmethod
+ def type_descriptor(cls, typeobj):
+ """Transform a generic type to a dialect-specific type.
- def type_descriptor(self, typeobj):
- """Transform a generic type to a database-specific type.
-
- Transforms the given :class:`~sqlalchemy.types.TypeEngine` instance
- from generic to database-specific.
-
- Subclasses will usually use the
+ Dialect classes will usually use the
:func:`~sqlalchemy.types.adapt_type` method in the types module to
make this job easy.
+
+ The returned result is cached *per dialect class* so can
+ contain no dialect-instance state.
"""
raise NotImplementedError()
+ def initialize(self, connection):
+ """Called during strategized creation of the dialect with a connection.
- def server_version_info(self, connection):
- """Return a tuple of the database's version number."""
+ Allows dialects to configure options based on server version info or
+ other properties.
+ """
- raise NotImplementedError()
+ pass
def reflecttable(self, connection, table, include_columns=None):
"""Load table description from the database.
raise NotImplementedError()
+ def get_columns(self, connection, table_name, schema=None, **kw):
+ """Return information about columns in `table_name`.
+
+ Given a :class:`~sqlalchemy.engine.Connection`, a string
+ `table_name`, and an optional string `schema`, return column
+ information as a list of dictionaries with these keys:
+
+ name
+ the column's name
+
+ type
+ [sqlalchemy.types#TypeEngine]
+
+ nullable
+ boolean
+
+ default
+ the column's default value
+
+ autoincrement
+ boolean
+
+ sequence
+ a dictionary of the form
+ {'name' : str, 'start' :int, 'increment': int}
+
+ Additional column attributes may be present.
+ """
+
+ raise NotImplementedError()
+
+ def get_primary_keys(self, connection, table_name, schema=None, **kw):
+ """Return information about primary keys in `table_name`.
+
+ Given a :class:`~sqlalchemy.engine.Connection`, a string
+ `table_name`, and an optional string `schema`, return primary
+ key information as a list of column names.
+ """
+
+ raise NotImplementedError()
+
+ def get_foreign_keys(self, connection, table_name, schema=None, **kw):
+ """Return information about foreign_keys in `table_name`.
+
+ Given a :class:`~sqlalchemy.engine.Connection`, a string
+ `table_name`, and an optional string `schema`, return foreign
+ key information as a list of dicts with these keys:
+
+ name
+ the constraint's name
+
+ constrained_columns
+ a list of column names that make up the foreign key
+
+ referred_schema
+ the name of the referred schema
+
+ referred_table
+ the name of the referred table
+
+ referred_columns
+ a list of column names in the referred table that correspond to
+ constrained_columns
+ """
+
+ raise NotImplementedError()
+
+ def get_table_names(self, connection, schema=None, **kw):
+ """Return a list of table names for `schema`."""
+
+ raise NotImplementedError
+
+ def get_view_names(self, connection, schema=None, **kw):
+ """Return a list of all view names available in the database.
+
+ schema:
+ Optional, retrieve names from a non-default schema.
+ """
+
+ raise NotImplementedError()
+
+ def get_view_definition(self, connection, view_name, schema=None, **kw):
+ """Return view definition.
+
+ Given a :class:`~sqlalchemy.engine.Connection`, a string
+ `view_name`, and an optional string `schema`, return the view
+ definition.
+ """
+
+ raise NotImplementedError()
+
+ def get_indexes(self, connection, table_name, schema=None, **kw):
+ """Return information about indexes in `table_name`.
+
+ Given a :class:`~sqlalchemy.engine.Connection`, a string
+ `table_name` and an optional string `schema`, return index
+ information as a list of dictionaries with these keys:
+
+ name
+ the index's name
+
+ column_names
+ list of column names in order
+
+ unique
+ boolean
+ """
+
+ raise NotImplementedError()
+
+ def normalize_name(self, name):
+ """convert the given name to lowercase if it is detected as case insensitive.
+
+ this method is only used if the dialect defines requires_name_normalize=True.
+
+ """
+ raise NotImplementedError()
+
+ def denormalize_name(self, name):
+ """convert the given name to a case insensitive identifier for the backend
+ if it is an all-lowercase name.
+
+ this method is only used if the dialect defines requires_name_normalize=True.
+
+ """
+ raise NotImplementedError()
+
def has_table(self, connection, table_name, schema=None):
"""Check the existence of a particular table in the database.
raise NotImplementedError()
def get_default_schema_name(self, connection):
- """Return the string name of the currently selected schema given a :class:`~sqlalchemy.engine.Connection`."""
+ """Return the string name of the currently selected schema given a :class:`~sqlalchemy.engine.Connection`.
+
+ DEPRECATED. moving this towards dialect.default_schema_name (not complete).
+
+ """
raise NotImplementedError()
raise NotImplementedError()
+ def visit_pool(self, pool):
+ """Executed after a pool is created."""
+
class ExecutionContext(object):
"""A messenger object for a Dialect that corresponds to a single execution.
- ExecutionContext should have these datamembers:
+ ExecutionContext should have these data members:
connection
Connection object which can be freely used by default value
True if the statement is an UPDATE.
should_autocommit
- True if the statement is a "committable" statement
+ True if the statement is a "committable" statement.
postfetch_cols
- a list of Column objects for which a server-side default
- or inline SQL expression value was fired off. applies to inserts and updates.
-
-
+ a list of Column objects for which a server-side default or
+ inline SQL expression value was fired off. Applies to inserts
+ and updates.
"""
def create_cursor(self):
"""Return a new cursor generated from this ExecutionContext's connection.
Some dialects may wish to change the behavior of
- connection.cursor(), such as postgres which may return a PG
+ connection.cursor(), such as postgresql which may return a PG
"server side" cursor.
"""
def handle_dbapi_exception(self, e):
"""Receive a DBAPI exception which occured upon execute, result fetch, etc."""
-
- raise NotImplementedError()
-
- def should_autocommit_text(self, statement):
- """Parse the given textual statement and return True if it refers to a "committable" statement"""
raise NotImplementedError()
- def last_inserted_ids(self):
- """Return the list of the primary key values for the last insert statement executed.
-
- This does not apply to straight textual clauses; only to
- ``sql.Insert`` objects compiled against a ``schema.Table``
- object. The order of items in the list is the same as that of
- the Table's 'primary_key' attribute.
- """
+ def should_autocommit_text(self, statement):
+ """Parse the given textual statement and return True if it refers to a "committable" statement"""
raise NotImplementedError()
class Compiled(object):
- """Represent a compiled SQL expression.
+ """Represent a compiled SQL or DDL expression.
The ``__str__`` method of the ``Compiled`` object should produce
the actual text of the statement. ``Compiled`` objects are
defaults.
"""
- def __init__(self, dialect, statement, column_keys=None, bind=None):
+ def __init__(self, dialect, statement, bind=None):
"""Construct a new ``Compiled`` object.
- dialect
- ``Dialect`` to compile against.
-
- statement
- ``ClauseElement`` to be compiled.
+ :param dialect: ``Dialect`` to compile against.
- column_keys
- a list of column names to be compiled into an INSERT or UPDATE
- statement.
+ :param statement: ``ClauseElement`` to be compiled.
- bind
- Optional Engine or Connection to compile this statement against.
-
+ :param bind: Optional Engine or Connection to compile this statement against.
"""
+
self.dialect = dialect
self.statement = statement
- self.column_keys = column_keys
self.bind = bind
self.can_execute = statement.supports_execution
def compile(self):
"""Produce the internal string representation of this element."""
- raise NotImplementedError()
+ self.string = self.process(self.statement)
- def __str__(self):
- """Return the string text of the generated SQL statement."""
+ def process(self, obj, **kwargs):
+ return obj._compiler_dispatch(self, **kwargs)
- raise NotImplementedError()
+ def __str__(self):
+ """Return the string text of the generated SQL or DDL."""
- @util.deprecated('Deprecated. Use construct_params(). '
- '(supports Unicode key names.)')
- def get_params(self, **params):
- return self.construct_params(params)
+ return self.string or ''
- def construct_params(self, params):
+ def construct_params(self, params=None):
"""Return the bind params for this compiled object.
- `params` is a dict of string/object pairs whos
- values will override bind values compiled in
- to the statement.
+ :param params: a dict of string/object pairs whos values will
+ override bind values compiled in to the
+ statement.
"""
+
raise NotImplementedError()
+ @property
+ def params(self):
+ """Return the bind params for this compiled object."""
+ return self.construct_params()
+
def execute(self, *multiparams, **params):
"""Execute this compiled object."""
return self.execute(*multiparams, **params).scalar()
+class TypeCompiler(object):
+ """Produces DDL specification for TypeEngine objects."""
+
+ def __init__(self, dialect):
+ self.dialect = dialect
+
+ def process(self, type_):
+ return type_._compiler_dispatch(self)
+
+
class Connectable(object):
"""Interface for an object which supports execution of SQL constructs.
-
+
The two implementations of ``Connectable`` are :class:`Connection` and
:class:`Engine`.
-
+
+ Connectable must also implement the 'dialect' member which references a
+ :class:`Dialect` instance.
"""
def contextual_connect(self):
def _execute_clauseelement(self, elem, multiparams=None, params=None):
raise NotImplementedError()
+
class Connection(Connectable):
"""Provides high-level functionality for a wrapped DB-API connection.
.. index::
single: thread safety; Connection
-
"""
def __init__(self, engine, connection=None, close_with_result=False,
Connection objects are typically constructed by an
:class:`~sqlalchemy.engine.Engine`, see the ``connect()`` and
``contextual_connect()`` methods of Engine.
-
"""
self.engine = engine
self.__savepoint_seq = 0
self.__branch = _branch
self.__invalid = False
-
+
def _branch(self):
"""Return a new Connection which references this Connection's
engine and connection; but does not have close_with_result enabled,
This is used to execute "sub" statements within a single execution,
usually an INSERT statement.
-
"""
+
return self.engine.Connection(self.engine, self.__connection, _branch=True)
@property
@property
def closed(self):
- """return True if this connection is closed."""
+ """Return True if this connection is closed."""
return not self.__invalid and '_Connection__connection' not in self.__dict__
@property
def invalidated(self):
- """return True if this connection was invalidated."""
+ """Return True if this connection was invalidated."""
return self.__invalid
def should_close_with_result(self):
"""Indicates if this Connection should be closed when a corresponding
ResultProxy is closed; this is essentially an auto-release mode.
-
"""
+
return self.__close_with_result
@property
def info(self):
"""A collection of per-DB-API connection instance properties."""
+
return self.connection.info
def connect(self):
This ``Connectable`` interface method returns self, allowing
Connections to be used interchangably with Engines in most
situations that require a bind.
-
"""
+
return self
def contextual_connect(self, **kwargs):
This ``Connectable`` interface method returns self, allowing
Connections to be used interchangably with Engines in most
situations that require a bind.
-
"""
+
return self
def invalidate(self, exception=None):
rolled back before a reconnect on this Connection can proceed. This
is to prevent applications from accidentally continuing their transactional
operations in a non-transactional state.
-
"""
+
if self.closed:
raise exc.InvalidRequestError("This Connection is closed")
:class:`~sqlalchemy.interfaces.PoolListener` for a mechanism to modify
connection state when connections leave and return to their
connection pool.
-
"""
+
self.__connection.detach()
def begin(self):
outermost transaction may ``commit``. Calls to ``commit`` on
inner transactions are ignored. Any transaction in the
hierarchy may ``rollback``, however.
-
"""
+
if self.__transaction is None:
self.__transaction = RootTransaction(self)
else:
def begin_twophase(self, xid=None):
"""Begin a two-phase or XA transaction and return a Transaction handle.
- xid
- the two phase transaction id. If not supplied, a random id
- will be generated.
+ :param xid: the two phase transaction id. If not supplied, a random id
+ will be generated.
"""
if self.__transaction is not None:
return self.execute(object, *multiparams, **params).scalar()
- def statement_compiler(self, statement, **kwargs):
- return self.dialect.statement_compiler(self.dialect, statement, bind=self, **kwargs)
-
def execute(self, object, *multiparams, **params):
"""Executes and returns a ResultProxy."""
raise exc.InvalidRequestError("Unexecutable object type: " + str(type(object)))
def __distill_params(self, multiparams, params):
- """given arguments from the calling form *multiparams, **params, return a list
+ """Given arguments from the calling form *multiparams, **params, return a list
of bind parameter structures, usually a list of dictionaries.
- in the case of 'raw' execution which accepts positional parameters,
- it may be a list of tuples or lists."""
+ In the case of 'raw' execution which accepts positional parameters,
+ it may be a list of tuples or lists.
+ """
if not multiparams:
if params:
return self._execute_clauseelement(func.select(), multiparams, params)
def _execute_default(self, default, multiparams, params):
- return self.engine.dialect.defaultrunner(self.__create_execution_context()).traverse_single(default)
+ ret = self.engine.dialect.\
+ defaultrunner(self.__create_execution_context()).\
+ traverse_single(default)
+ if self.__close_with_result:
+ self.close()
+ return ret
+
+ def _execute_ddl(self, ddl, params, multiparams):
+ context = self.__create_execution_context(
+ compiled_ddl=ddl.compile(dialect=self.dialect),
+ parameters=None
+ )
+ return self.__execute_context(context)
def _execute_clauseelement(self, elem, multiparams, params):
params = self.__distill_params(multiparams, params)
keys = []
context = self.__create_execution_context(
- compiled=elem.compile(dialect=self.dialect, column_keys=keys, inline=len(params) > 1),
+ compiled_sql=elem.compile(dialect=self.dialect, column_keys=keys, inline=len(params) > 1),
parameters=params
)
return self.__execute_context(context)
"""Execute a sql.Compiled object."""
context = self.__create_execution_context(
- compiled=compiled,
+ compiled_sql=compiled,
parameters=self.__distill_params(multiparams, params)
)
return self.__execute_context(context)
parameters = self.__distill_params(multiparams, params)
context = self.__create_execution_context(statement=statement, parameters=parameters)
return self.__execute_context(context)
-
+
def __execute_context(self, context):
if context.compiled:
context.pre_exec()
+
if context.executemany:
self._cursor_executemany(context.cursor, context.statement, context.parameters, context=context)
else:
self._cursor_execute(context.cursor, context.statement, context.parameters[0], context=context)
+
if context.compiled:
context.post_exec()
+
+ if context.isinsert and not context.executemany:
+ context.post_insert()
+
if context.should_autocommit and not self.in_transaction():
self._commit_impl()
- return context.get_result_proxy()
+
+ return context.get_result_proxy()._autoclose()
- def _execute_ddl(self, ddl, params, multiparams):
- if params:
- schema_item, params = params[0], params[1:]
- else:
- schema_item = None
- return ddl(None, schema_item, self, *params, **multiparams)
-
def _handle_dbapi_exception(self, e, statement, parameters, cursor, context):
if getattr(self, '_reentrant_error', False):
- raise exc.DBAPIError.instance(None, None, e)
+ # Py3K
+ #raise exc.DBAPIError.instance(statement, parameters, e) from e
+ # Py2K
+ raise exc.DBAPIError.instance(statement, parameters, e), None, sys.exc_info()[2]
+ # end Py2K
self._reentrant_error = True
try:
if not isinstance(e, self.dialect.dbapi.Error):
return
-
+
if context:
context.handle_dbapi_exception(e)
-
+
is_disconnect = self.dialect.is_disconnect(e)
if is_disconnect:
self.invalidate(e)
self._autorollback()
if self.__close_with_result:
self.close()
- raise exc.DBAPIError.instance(statement, parameters, e, connection_invalidated=is_disconnect)
+ # Py3K
+ #raise exc.DBAPIError.instance(statement, parameters, e, connection_invalidated=is_disconnect) from e
+ # Py2K
+ raise exc.DBAPIError.instance(statement, parameters, e, connection_invalidated=is_disconnect), None, sys.exc_info()[2]
+ # end Py2K
+
finally:
del self._reentrant_error
expression.ClauseElement: _execute_clauseelement,
Compiled: _execute_compiled,
schema.SchemaItem: _execute_default,
- schema.DDL: _execute_ddl,
+ schema.DDLElement: _execute_ddl,
basestring: _execute_text
}
def run_callable(self, callable_):
return callable_(self)
+
class Transaction(object):
"""Represent a Transaction in progress.
.. index::
single: thread safety; Transaction
-
"""
def __init__(self, connection, parent):
self.connection = connection
self._parent = parent or self
self.is_active = True
-
+
def close(self):
"""Close this transaction.
This is used to cancel a Transaction without affecting the scope of
an enclosing transaction.
"""
+
if not self._parent.is_active:
return
if self._parent is self:
else:
self.rollback()
+
class RootTransaction(Transaction):
def __init__(self, connection):
super(RootTransaction, self).__init__(connection, None)
def _do_commit(self):
self.connection._commit_impl()
+
class NestedTransaction(Transaction):
def __init__(self, connection, parent):
super(NestedTransaction, self).__init__(connection, parent)
def _do_commit(self):
self.connection._release_savepoint_impl(self._savepoint, self._parent)
+
class TwoPhaseTransaction(Transaction):
def __init__(self, connection, xid):
super(TwoPhaseTransaction, self).__init__(connection, None)
def commit(self):
self.connection._commit_twophase_impl(self.xid, self._is_prepared)
+
class Engine(Connectable):
"""
- Connects a :class:`~sqlalchemy.pool.Pool` and :class:`~sqlalchemy.engine.base.Dialect`
+ Connects a :class:`~sqlalchemy.pool.Pool` and :class:`~sqlalchemy.engine.base.Dialect`
together to provide a source of database connectivity and behavior.
"""
@property
def name(self):
"String name of the :class:`~sqlalchemy.engine.Dialect` in use by this ``Engine``."
-
+
return self.dialect.name
+ @property
+ def driver(self):
+ "Driver name of the :class:`~sqlalchemy.engine.Dialect` in use by this ``Engine``."
+
+ return self.dialect.driver
+
echo = log.echo_property()
def __repr__(self):
def create(self, entity, connection=None, **kwargs):
"""Create a table or index within this engine's database connection given a schema.Table object."""
- self._run_visitor(self.dialect.schemagenerator, entity, connection=connection, **kwargs)
+ from sqlalchemy.engine import ddl
+
+ self._run_visitor(ddl.SchemaGenerator, entity, connection=connection, **kwargs)
def drop(self, entity, connection=None, **kwargs):
"""Drop a table or index within this engine's database connection given a schema.Table object."""
- self._run_visitor(self.dialect.schemadropper, entity, connection=connection, **kwargs)
+ from sqlalchemy.engine import ddl
+
+ self._run_visitor(ddl.SchemaDropper, entity, connection=connection, **kwargs)
def _execute_default(self, default):
connection = self.contextual_connect()
connection = self.contextual_connect(close_with_result=True)
return connection._execute_compiled(compiled, multiparams, params)
- def statement_compiler(self, statement, **kwargs):
- return self.dialect.statement_compiler(self.dialect, statement, bind=self, **kwargs)
-
def connect(self, **kwargs):
"""Return a newly allocated Connection object."""
def table_names(self, schema=None, connection=None):
"""Return a list of all table names available in the database.
- schema:
- Optional, retrieve names from a non-default schema.
+ :param schema: Optional, retrieve names from a non-default schema.
- connection:
- Optional, use a specified connection. Default is the
- ``contextual_connect`` for this ``Engine``.
+ :param connection: Optional, use a specified connection. Default is the
+ ``contextual_connect`` for this ``Engine``.
"""
if connection is None:
return self.pool.unique_connection()
+
def _proxy_connection_cls(cls, proxy):
class ProxyConnection(cls):
def execute(self, object, *multiparams, **params):
return proxy.execute(self, super(ProxyConnection, self).execute, object, *multiparams, **params)
-
+
def _execute_clauseelement(self, elem, multiparams=None, params=None):
return proxy.execute(self, super(ProxyConnection, self).execute, elem, *(multiparams or []), **(params or {}))
-
+
def _cursor_execute(self, cursor, statement, parameters, context=None):
return proxy.cursor_execute(super(ProxyConnection, self)._cursor_execute, cursor, statement, parameters, context, False)
-
+
def _cursor_executemany(self, cursor, statement, parameters, context=None):
return proxy.cursor_execute(super(ProxyConnection, self)._cursor_executemany, cursor, statement, parameters, context, True)
return ProxyConnection
+
class RowProxy(object):
"""Proxy a single cursor row for a parent ResultProxy.
"""
__slots__ = ['__parent', '__row']
-
+
def __init__(self, parent, row):
"""RowProxy objects are constructed by ResultProxy objects."""
yield self.__parent._get_col(self.__row, i)
__hash__ = None
-
+
def __eq__(self, other):
return ((other is self) or
(other == tuple(self.__parent._get_col(self.__row, key)
"""Return the list of keys as strings represented by this RowProxy."""
return self.__parent.keys
-
+
def iterkeys(self):
return iter(self.__parent.keys)
-
+
def values(self):
"""Return the values represented by this RowProxy as a list."""
return list(self)
-
+
def itervalues(self):
return iter(self)
+
class BufferedColumnRow(RowProxy):
def __init__(self, parent, row):
row = [ResultProxy._get_col(parent, row, i) for i in xrange(len(row))]
"""
_process_row = RowProxy
-
+
def __init__(self, context):
- """ResultProxy objects are constructed via the execute() method on SQLEngine."""
self.context = context
self.dialect = context.dialect
self.closed = False
self.connection = context.root_connection
self._echo = context.engine._should_log_info
self._init_metadata()
-
- @property
+
+ @util.memoized_property
def rowcount(self):
- if self._rowcount is None:
- return self.context.get_rowcount()
- else:
- return self._rowcount
+ """Return the 'rowcount' for this result.
+
+ The 'rowcount' reports the number of rows affected
+ by an UPDATE or DELETE statement. It has *no* other
+ uses and is not intended to provide the number of rows
+ present from a SELECT.
+
+ Additionally, this value is only meaningful if the
+ dialect's supports_sane_rowcount flag is True for
+ single-parameter executions, or supports_sane_multi_rowcount
+ is true for multiple parameter executions - otherwise
+ results are undefined.
+
+ rowcount may not work at this time for a statement
+ that uses ``returning()``.
+
+ """
+ return self.context.rowcount
@property
def lastrowid(self):
+ """return the 'lastrowid' accessor on the DBAPI cursor.
+
+ This is a DBAPI specific method and is only functional
+ for those backends which support it, for statements
+ where it is appropriate. It's behavior is not
+ consistent across backends.
+
+ Usage of this method is normally unnecessary; the
+ inserted_primary_key method provides a
+ tuple of primary key values for a newly inserted row,
+ regardless of database backend.
+
+ """
return self.cursor.lastrowid
@property
def out_parameters(self):
return self.context.out_parameters
-
+
+ def _cursor_description(self):
+ return self.cursor.description
+
+ def _autoclose(self):
+ if self.context.isinsert:
+ if self.context._is_implicit_returning:
+ self.context._fetch_implicit_returning(self)
+ self.close()
+ elif not self.context._is_explicit_returning:
+ self.close()
+ elif self._metadata is None:
+ # no results, get rowcount
+ # (which requires open cursor on some DB's such as firebird),
+ self.rowcount
+ self.close() # autoclose
+
+ return self
+
+
def _init_metadata(self):
- metadata = self.cursor.description
+ self._metadata = metadata = self._cursor_description()
if metadata is None:
- # no results, get rowcount (which requires open cursor on some DB's such as firebird),
- # then close
- self._rowcount = self.context.get_rowcount()
- self.close()
return
-
- self._rowcount = None
+
self._props = util.populate_column_dict(None)
self._props.creator = self.__key_fallback()
self.keys = []
typemap = self.dialect.dbapi_type_map
- for i, item in enumerate(metadata):
- colname = item[0]
+ for i, (colname, coltype) in enumerate(m[0:2] for m in metadata):
+
if self.dialect.description_encoding:
colname = colname.decode(self.dialect.description_encoding)
try:
(name, obj, type_) = self.context.result_map[colname.lower()]
except KeyError:
- (name, obj, type_) = (colname, None, typemap.get(item[1], types.NULLTYPE))
+ (name, obj, type_) = (colname, None, typemap.get(coltype, types.NULLTYPE))
else:
- (name, obj, type_) = (colname, None, typemap.get(item[1], types.NULLTYPE))
+ (name, obj, type_) = (colname, None, typemap.get(coltype, types.NULLTYPE))
rec = (type_, type_.dialect_impl(self.dialect).result_processor(self.dialect), i)
if origname:
if self._props.setdefault(origname.lower(), rec) is not rec:
self._props[origname.lower()] = (type_, self.__ambiguous_processor(origname), 0)
-
+
+ if self.dialect.requires_name_normalize:
+ colname = self.dialect.normalize_name(colname)
+
self.keys.append(colname)
self._props[i] = rec
if obj:
if self._echo:
self.context.engine.logger.debug(
"Col " + repr(tuple(x[0] for x in metadata)))
-
+
def __key_fallback(self):
# create a closure without 'self' to avoid circular references
props = self._props
-
+
def fallback(key):
if isinstance(key, basestring):
key = key.lower()
def close(self):
"""Close this ResultProxy.
-
+
Closes the underlying DBAPI cursor corresponding to the execution.
+
+ Note that any data cached within this ResultProxy is still available.
+ For some types of results, this may include buffered rows.
If this ResultProxy was generated from an implicit execution,
the underlying Connection will also be closed (returns the
underlying DBAPI connection to the connection pool.)
This method is called automatically when:
-
- * all result rows are exhausted using the fetchXXX() methods.
- * cursor.description is None.
-
+
+ * all result rows are exhausted using the fetchXXX() methods.
+ * cursor.description is None.
"""
+
if not self.closed:
self.closed = True
self.cursor.close()
raise StopIteration
else:
yield row
-
- def last_inserted_ids(self):
- """Return ``last_inserted_ids()`` from the underlying ExecutionContext.
-
- See ExecutionContext for details.
+
+ @util.memoized_property
+ def inserted_primary_key(self):
+ """Return the primary key for the row just inserted.
+
+ This only applies to single row insert() constructs which
+ did not explicitly specify returning().
"""
- return self.context.last_inserted_ids()
+ if not self.context.isinsert:
+ raise exc.InvalidRequestError("Statement is not an insert() expression construct.")
+ elif self.context._is_explicit_returning:
+ raise exc.InvalidRequestError("Can't call inserted_primary_key when returning() is used.")
+
+ return self.context._inserted_primary_key
+ @util.deprecated("Use inserted_primary_key")
+ def last_inserted_ids(self):
+ """deprecated. use inserted_primary_key."""
+
+ return self.inserted_primary_key
+
def last_updated_params(self):
"""Return ``last_updated_params()`` from the underlying ExecutionContext.
See ExecutionContext for details.
-
"""
+
return self.context.last_updated_params()
def last_inserted_params(self):
"""Return ``last_inserted_params()`` from the underlying ExecutionContext.
See ExecutionContext for details.
-
"""
+
return self.context.last_inserted_params()
def lastrow_has_defaults(self):
"""Return ``lastrow_has_defaults()`` from the underlying ExecutionContext.
See ExecutionContext for details.
-
"""
+
return self.context.lastrow_has_defaults()
def postfetch_cols(self):
"""Return ``postfetch_cols()`` from the underlying ExecutionContext.
See ExecutionContext for details.
-
"""
+
return self.context.postfetch_cols
-
+
def prefetch_cols(self):
return self.context.prefetch_cols
-
+
def supports_sane_rowcount(self):
"""Return ``supports_sane_rowcount`` from the dialect."""
-
+
return self.dialect.supports_sane_rowcount
def supports_sane_multi_rowcount(self):
raise
def fetchmany(self, size=None):
- """Fetch many rows, just like DB-API ``cursor.fetchmany(size=cursor.arraysize)``."""
+ """Fetch many rows, just like DB-API ``cursor.fetchmany(size=cursor.arraysize)``.
+
+ If rows are present, the cursor remains open after this is called.
+ Else the cursor is automatically closed and an empty list is returned.
+
+ """
try:
process_row = self._process_row
raise
def fetchone(self):
- """Fetch one row, just like DB-API ``cursor.fetchone()``."""
+ """Fetch one row, just like DB-API ``cursor.fetchone()``.
+
+ If a row is present, the cursor remains open after this is called.
+ Else the cursor is automatically closed and None is returned.
+
+ """
+
try:
row = self._fetchone_impl()
if row is not None:
self.connection._handle_dbapi_exception(e, None, None, self.cursor, self.context)
raise
- def scalar(self):
- """Fetch the first column of the first row, and close the result set."""
+ def first(self):
+ """Fetch the first row and then close the result set unconditionally.
+
+ Returns None if no row is present.
+
+ """
try:
row = self._fetchone_impl()
except Exception, e:
self.connection._handle_dbapi_exception(e, None, None, self.cursor, self.context)
raise
-
+
try:
if row is not None:
- return self._process_row(self, row)[0]
+ return self._process_row(self, row)
else:
return None
finally:
self.close()
+
+
+ def scalar(self):
+ """Fetch the first column of the first row, and close the result set.
+
+ Returns None if no row is present.
+
+ """
+ row = self.first()
+ if row is not None:
+ return row[0]
+ else:
+ return None
class BufferedRowResultProxy(ResultProxy):
"""A ResultProxy with row buffering behavior.
The pre-fetching behavior fetches only one row initially, and then
grows its buffer size by a fixed amount with each successive need
for additional rows up to a size of 100.
-
"""
def _init_metadata(self):
return result
def _fetchall_impl(self):
- return self.__rowbuffer + list(self.cursor.fetchall())
+ ret = self.__rowbuffer + list(self.cursor.fetchall())
+ self.__rowbuffer[:] = []
+ return ret
+
+class FullyBufferedResultProxy(ResultProxy):
+ """A result proxy that buffers rows fully upon creation.
+
+ Used for operations where a result is to be delivered
+ after the database conversation can not be continued,
+ such as MSSQL INSERT...OUTPUT after an autocommit.
+
+ """
+ def _init_metadata(self):
+ super(FullyBufferedResultProxy, self)._init_metadata()
+ self.__rowbuffer = self._buffer_rows()
+
+ def _buffer_rows(self):
+ return self.cursor.fetchall()
+
+ def _fetchone_impl(self):
+ if self.__rowbuffer:
+ return self.__rowbuffer.pop(0)
+ else:
+ return None
+
+ def _fetchmany_impl(self, size=None):
+ result = []
+ for x in range(0, size):
+ row = self._fetchone_impl()
+ if row is None:
+ break
+ result.append(row)
+ return result
+
+ def _fetchall_impl(self):
+ ret = self.__rowbuffer
+ self.__rowbuffer = []
+ return ret
class BufferedColumnResultProxy(ResultProxy):
"""A ResultProxy with column buffering behavior.
return l
-class SchemaIterator(schema.SchemaVisitor):
- """A visitor that can gather text into a buffer and execute the contents of the buffer."""
-
- def __init__(self, connection):
- """Construct a new SchemaIterator."""
-
- self.connection = connection
- self.buffer = StringIO.StringIO()
-
- def append(self, s):
- """Append content to the SchemaIterator's query buffer."""
-
- self.buffer.write(s)
-
- def execute(self):
- """Execute the contents of the SchemaIterator's buffer."""
-
- try:
- return self.connection.execute(self.buffer.getvalue())
- finally:
- self.buffer.truncate(0)
-
class DefaultRunner(schema.SchemaVisitor):
"""A visitor which accepts ColumnDefault objects, produces the
dialect-specific SQL corresponding to their execution, and
DefaultRunners are used internally by Engines and Dialects.
Specific database modules should provide their own subclasses of
DefaultRunner to allow database-specific behavior.
-
"""
def __init__(self, context):
def execute_string(self, stmt, params=None):
"""execute a string statement, using the raw cursor, and return a scalar result."""
-
+
conn = self.context._connection
if isinstance(stmt, unicode) and not self.dialect.supports_unicode_statements:
stmt = stmt.encode(self.dialect.encoding)
Only applicable to functions which take no arguments other than a
connection. The memo will be stored in ``connection.info[key]``.
-
"""
+
@util.decorator
def decorated(fn, self, connection):
connection = connection.connect()
--- /dev/null
+# engine/ddl.py
+# Copyright (C) 2009 Michael Bayer mike_mp@zzzcomputing.com
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+"""Routines to handle CREATE/DROP workflow."""
+
+from sqlalchemy import engine, schema
+from sqlalchemy.sql import util as sql_util
+
+
+class DDLBase(schema.SchemaVisitor):
+ def __init__(self, connection):
+ self.connection = connection
+
+class SchemaGenerator(DDLBase):
+ def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
+ super(SchemaGenerator, self).__init__(connection, **kwargs)
+ self.checkfirst = checkfirst
+ self.tables = tables and set(tables) or None
+ self.preparer = dialect.identifier_preparer
+ self.dialect = dialect
+
+ def _can_create(self, table):
+ self.dialect.validate_identifier(table.name)
+ if table.schema:
+ self.dialect.validate_identifier(table.schema)
+ return not self.checkfirst or not self.dialect.has_table(self.connection, table.name, schema=table.schema)
+
+ def visit_metadata(self, metadata):
+ if self.tables:
+ tables = self.tables
+ else:
+ tables = metadata.tables.values()
+ collection = [t for t in sql_util.sort_tables(tables) if self._can_create(t)]
+
+ for listener in metadata.ddl_listeners['before-create']:
+ listener('before-create', metadata, self.connection, tables=collection)
+
+ for table in collection:
+ self.traverse_single(table)
+
+ for listener in metadata.ddl_listeners['after-create']:
+ listener('after-create', metadata, self.connection, tables=collection)
+
+ def visit_table(self, table):
+ for listener in table.ddl_listeners['before-create']:
+ listener('before-create', table, self.connection)
+
+ for column in table.columns:
+ if column.default is not None:
+ self.traverse_single(column.default)
+
+ self.connection.execute(schema.CreateTable(table))
+
+ if hasattr(table, 'indexes'):
+ for index in table.indexes:
+ self.traverse_single(index)
+
+ for listener in table.ddl_listeners['after-create']:
+ listener('after-create', table, self.connection)
+
+ def visit_sequence(self, sequence):
+ if self.dialect.supports_sequences:
+ if ((not self.dialect.sequences_optional or
+ not sequence.optional) and
+ (not self.checkfirst or
+ not self.dialect.has_sequence(self.connection, sequence.name))):
+ self.connection.execute(schema.CreateSequence(sequence))
+
+ def visit_index(self, index):
+ self.connection.execute(schema.CreateIndex(index))
+
+
+class SchemaDropper(DDLBase):
+ def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
+ super(SchemaDropper, self).__init__(connection, **kwargs)
+ self.checkfirst = checkfirst
+ self.tables = tables
+ self.preparer = dialect.identifier_preparer
+ self.dialect = dialect
+
+ def visit_metadata(self, metadata):
+ if self.tables:
+ tables = self.tables
+ else:
+ tables = metadata.tables.values()
+ collection = [t for t in reversed(sql_util.sort_tables(tables)) if self._can_drop(t)]
+
+ for listener in metadata.ddl_listeners['before-drop']:
+ listener('before-drop', metadata, self.connection, tables=collection)
+
+ for table in collection:
+ self.traverse_single(table)
+
+ for listener in metadata.ddl_listeners['after-drop']:
+ listener('after-drop', metadata, self.connection, tables=collection)
+
+ def _can_drop(self, table):
+ self.dialect.validate_identifier(table.name)
+ if table.schema:
+ self.dialect.validate_identifier(table.schema)
+ return not self.checkfirst or self.dialect.has_table(self.connection, table.name, schema=table.schema)
+
+ def visit_index(self, index):
+ self.connection.execute(schema.DropIndex(index))
+
+ def visit_table(self, table):
+ for listener in table.ddl_listeners['before-drop']:
+ listener('before-drop', table, self.connection)
+
+ for column in table.columns:
+ if column.default is not None:
+ self.traverse_single(column.default)
+
+ self.connection.execute(schema.DropTable(table))
+
+ for listener in table.ddl_listeners['after-drop']:
+ listener('after-drop', table, self.connection)
+
+ def visit_sequence(self, sequence):
+ if self.dialect.supports_sequences:
+ if ((not self.dialect.sequences_optional or
+ not sequence.optional) and
+ (not self.checkfirst or
+ self.dialect.has_sequence(self.connection, sequence.name))):
+ self.connection.execute(schema.DropSequence(sequence))
"""
import re, random
-from sqlalchemy.engine import base
+from sqlalchemy.engine import base, reflection
from sqlalchemy.sql import compiler, expression
-from sqlalchemy import exc
+from sqlalchemy import exc, types as sqltypes, util
AUTOCOMMIT_REGEXP = re.compile(r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)',
re.I | re.UNICODE)
+
class DefaultDialect(base.Dialect):
"""Default implementation of Dialect"""
- name = 'default'
- schemagenerator = compiler.SchemaGenerator
- schemadropper = compiler.SchemaDropper
- statement_compiler = compiler.DefaultCompiler
+ statement_compiler = compiler.SQLCompiler
+ ddl_compiler = compiler.DDLCompiler
+ type_compiler = compiler.GenericTypeCompiler
preparer = compiler.IdentifierPreparer
defaultrunner = base.DefaultRunner
supports_alter = True
+
+ supports_sequences = False
+ sequences_optional = False
+ preexecute_autoincrement_sequences = False
+ postfetch_lastrowid = True
+ implicit_returning = False
+
+ # Py3K
+ #supports_unicode_statements = True
+ #supports_unicode_binds = True
+ # Py2K
supports_unicode_statements = False
+ supports_unicode_binds = False
+ # end Py2K
+
+ name = 'default'
max_identifier_length = 9999
supports_sane_rowcount = True
supports_sane_multi_rowcount = True
- preexecute_pk_sequences = False
- supports_pk_autoincrement = True
dbapi_type_map = {}
default_paramstyle = 'named'
- supports_default_values = False
+ supports_default_values = False
supports_empty_insert = True
+
+ # indicates symbol names are
+ # UPPERCASEd if they are case insensitive
+ # within the database.
+ # if this is True, the methods normalize_name()
+ # and denormalize_name() must be provided.
+ requires_name_normalize = False
+
+ reflection_options = ()
def __init__(self, convert_unicode=False, assert_unicode=False,
- encoding='utf-8', paramstyle=None, dbapi=None,
+ encoding='utf-8', paramstyle=None, dbapi=None,
+ implicit_returning=None,
label_length=None, **kwargs):
self.convert_unicode = convert_unicode
self.assert_unicode = assert_unicode
self.paramstyle = self.dbapi.paramstyle
else:
self.paramstyle = self.default_paramstyle
+ if implicit_returning is not None:
+ self.implicit_returning = implicit_returning
self.positional = self.paramstyle in ('qmark', 'format', 'numeric')
self.identifier_preparer = self.preparer(self)
+ self.type_compiler = self.type_compiler(self)
+
if label_length and label_length > self.max_identifier_length:
- raise exc.ArgumentError("Label length of %d is greater than this dialect's maximum identifier length of %d" % (label_length, self.max_identifier_length))
+ raise exc.ArgumentError("Label length of %d is greater than this dialect's"
+ " maximum identifier length of %d" %
+ (label_length, self.max_identifier_length))
self.label_length = label_length
- self.description_encoding = getattr(self, 'description_encoding', encoding)
- def type_descriptor(self, typeobj):
+ if not hasattr(self, 'description_encoding'):
+ self.description_encoding = getattr(self, 'description_encoding', encoding)
+
+ # Py3K
+ ## work around dialects that might change these values
+ #self.supports_unicode_statements = True
+ #self.supports_unicode_binds = True
+
+ def initialize(self, connection):
+ if hasattr(self, '_get_server_version_info'):
+ self.server_version_info = self._get_server_version_info(connection)
+ if hasattr(self, '_get_default_schema_name'):
+ self.default_schema_name = self._get_default_schema_name(connection)
+
+ @classmethod
+ def type_descriptor(cls, typeobj):
"""Provide a database-specific ``TypeEngine`` object, given
the generic object which comes from the types module.
- Subclasses will usually use the ``adapt_type()`` method in the
- types module to make this job easy."""
+ This method looks for a dictionary called
+ ``colspecs`` as a class or instance-level variable,
+ and passes on to ``types.adapt_type()``.
- if type(typeobj) is type:
- typeobj = typeobj()
- return typeobj
+ """
+ return sqltypes.adapt_type(typeobj, cls.colspecs)
+
+ def reflecttable(self, connection, table, include_columns):
+ insp = reflection.Inspector.from_engine(connection)
+ return insp.reflecttable(table, include_columns)
def validate_identifier(self, ident):
if len(ident) > self.max_identifier_length:
- raise exc.IdentifierError("Identifier '%s' exceeds maximum length of %d characters" % (ident, self.max_identifier_length))
-
+ raise exc.IdentifierError(
+ "Identifier '%s' exceeds maximum length of %d characters" %
+ (ident, self.max_identifier_length)
+ )
+
+ def connect(self, *cargs, **cparams):
+ return self.dbapi.connect(*cargs, **cparams)
+
def do_begin(self, connection):
"""Implementations might want to put logic here for turning
autocommit on/off, etc.
"""Create a random two-phase transaction ID.
This id will be passed to do_begin_twophase(), do_rollback_twophase(),
- do_commit_twophase(). Its format is unspecified."""
+ do_commit_twophase(). Its format is unspecified.
+ """
return "_sa_%032x" % random.randint(0, 2 ** 128)
class DefaultExecutionContext(base.ExecutionContext):
- def __init__(self, dialect, connection, compiled=None, statement=None, parameters=None):
+
+ def __init__(self, dialect, connection, compiled_sql=None, compiled_ddl=None, statement=None, parameters=None):
self.dialect = dialect
self._connection = self.root_connection = connection
- self.compiled = compiled
self.engine = connection.engine
- if compiled is not None:
+ if compiled_ddl is not None:
+ self.compiled = compiled = compiled_ddl
+ if not dialect.supports_unicode_statements:
+ self.statement = unicode(compiled).encode(self.dialect.encoding)
+ else:
+ self.statement = unicode(compiled)
+ self.isinsert = self.isupdate = self.isdelete = self.executemany = False
+ self.should_autocommit = True
+ self.result_map = None
+ self.cursor = self.create_cursor()
+ self.compiled_parameters = []
+ if self.dialect.positional:
+ self.parameters = [()]
+ else:
+ self.parameters = [{}]
+ elif compiled_sql is not None:
+ self.compiled = compiled = compiled_sql
+
# compiled clauseelement. process bind params, process table defaults,
# track collections used by ResultProxy to target and process results
self.isinsert = compiled.isinsert
self.isupdate = compiled.isupdate
+ self.isdelete = compiled.isdelete
self.should_autocommit = compiled.statement._autocommit
if isinstance(compiled.statement, expression._TextClause):
self.should_autocommit = self.should_autocommit or self.should_autocommit_text(self.statement)
self.parameters = self.__convert_compiled_params(self.compiled_parameters)
elif statement is not None:
- # plain text statement.
- self.result_map = None
+ # plain text statement
+ self.result_map = self.compiled = None
self.parameters = self.__encode_param_keys(parameters)
self.executemany = len(parameters) > 1
if isinstance(statement, unicode) and not dialect.supports_unicode_statements:
self.statement = statement.encode(self.dialect.encoding)
else:
self.statement = statement
- self.isinsert = self.isupdate = False
+ self.isinsert = self.isupdate = self.isdelete = False
self.cursor = self.create_cursor()
self.should_autocommit = self.should_autocommit_text(statement)
else:
# no statement. used for standalone ColumnDefault execution.
- self.statement = None
- self.isinsert = self.isupdate = self.executemany = self.should_autocommit = False
+ self.statement = self.compiled = None
+ self.isinsert = self.isupdate = self.isdelete = self.executemany = self.should_autocommit = False
self.cursor = self.create_cursor()
-
+
+ @util.memoized_property
+ def _is_explicit_returning(self):
+ return self.compiled and \
+ getattr(self.compiled.statement, '_returning', False)
+
+ @util.memoized_property
+ def _is_implicit_returning(self):
+ return self.compiled and \
+ bool(self.compiled.returning) and \
+ not self.compiled.statement._returning
+
@property
def connection(self):
return self._connection._branch()
def __encode_param_keys(self, params):
- """apply string encoding to the keys of dictionary-based bind parameters.
+ """Apply string encoding to the keys of dictionary-based bind parameters.
- This is only used executing textual, non-compiled SQL expressions."""
+ This is only used executing textual, non-compiled SQL expressions.
+ """
if self.dialect.positional or self.dialect.supports_unicode_statements:
if params:
return [proc(d) for d in params] or [{}]
def __convert_compiled_params(self, compiled_parameters):
- """convert the dictionary of bind parameter values into a dict or list
+ """Convert the dictionary of bind parameter values into a dict or list
to be sent to the DBAPI's execute() or executemany() method.
"""
def post_exec(self):
pass
+ def get_lastrowid(self):
+ """return self.cursor.lastrowid, or equivalent, after an INSERT.
+
+ This may involve calling special cursor functions,
+ issuing a new SELECT on the cursor (or a new one),
+ or returning a stored value that was
+ calculated within post_exec().
+
+ This function will only be called for dialects
+ which support "implicit" primary key generation,
+ keep preexecute_autoincrement_sequences set to False,
+ and when no explicit id value was bound to the
+ statement.
+
+ The function is called once, directly after
+ post_exec() and before the transaction is committed
+ or ResultProxy is generated. If the post_exec()
+ method assigns a value to `self._lastrowid`, the
+ value is used in place of calling get_lastrowid().
+
+ Note that this method is *not* equivalent to the
+ ``lastrowid`` method on ``ResultProxy``, which is a
+ direct proxy to the DBAPI ``lastrowid`` accessor
+ in all cases.
+
+ """
+
+ return self.cursor.lastrowid
+
def handle_dbapi_exception(self, e):
pass
def get_result_proxy(self):
return base.ResultProxy(self)
+
+ @property
+ def rowcount(self):
+ return self.cursor.rowcount
- def get_rowcount(self):
- if hasattr(self, '_rowcount'):
- return self._rowcount
- else:
- return self.cursor.rowcount
-
def supports_sane_rowcount(self):
return self.dialect.supports_sane_rowcount
def supports_sane_multi_rowcount(self):
return self.dialect.supports_sane_multi_rowcount
-
- def last_inserted_ids(self):
- return self._last_inserted_ids
+
+ def post_insert(self):
+ if self.dialect.postfetch_lastrowid and \
+ (not len(self._inserted_primary_key) or \
+ None in self._inserted_primary_key):
+
+ table = self.compiled.statement.table
+ lastrowid = self.get_lastrowid()
+ self._inserted_primary_key = [c is table._autoincrement_column and lastrowid or v
+ for c, v in zip(table.primary_key, self._inserted_primary_key)
+ ]
+
+ def _fetch_implicit_returning(self, resultproxy):
+ table = self.compiled.statement.table
+ row = resultproxy.first()
+
+ self._inserted_primary_key = [v is not None and v or row[c]
+ for c, v in zip(table.primary_key, self._inserted_primary_key)
+ ]
def last_inserted_params(self):
return self._last_inserted_params
def lastrow_has_defaults(self):
return hasattr(self, 'postfetch_cols') and len(self.postfetch_cols)
- def set_input_sizes(self):
+ def set_input_sizes(self, translate=None, exclude_types=None):
"""Given a cursor and ClauseParameters, call the appropriate
style of ``setinputsizes()`` on the cursor, using DB-API types
from the bind parameter's ``TypeEngine`` objects.
"""
+ if not hasattr(self.compiled, 'bind_names'):
+ return
+
types = dict(
(self.compiled.bind_names[bindparam], bindparam.type)
for bindparam in self.compiled.bind_names)
for key in self.compiled.positiontup:
typeengine = types[key]
dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
- if dbtype is not None:
+ if dbtype is not None and (not exclude_types or dbtype not in exclude_types):
inputsizes.append(dbtype)
try:
self.cursor.setinputsizes(*inputsizes)
for key in self.compiled.bind_names.values():
typeengine = types[key]
dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
- if dbtype is not None:
+ if dbtype is not None and (not exclude_types or dbtype not in exclude_types):
+ if translate:
+ key = translate.get(key, key)
inputsizes[key.encode(self.dialect.encoding)] = dbtype
try:
self.cursor.setinputsizes(**inputsizes)
raise
def __process_defaults(self):
- """generate default values for compiled insert/update statements,
- and generate last_inserted_ids() collection."""
+ """Generate default values for compiled insert/update statements,
+ and generate inserted_primary_key collection.
+ """
if self.executemany:
if len(self.compiled.prefetch):
compiled_parameters[c.key] = val
if self.isinsert:
- self._last_inserted_ids = [compiled_parameters.get(c.key, None) for c in self.compiled.statement.table.primary_key]
+ self._inserted_primary_key = [compiled_parameters.get(c.key, None)
+ for c in self.compiled.statement.table.primary_key]
self._last_inserted_params = compiled_parameters
else:
self._last_updated_params = compiled_parameters
--- /dev/null
+"""Provides an abstraction for obtaining database schema information.
+
+Usage Notes:
+
+Here are some general conventions when accessing the low level inspector
+methods such as get_table_names, get_columns, etc.
+
+1. Inspector methods return lists of dicts in most cases for the following
+ reasons:
+
+ * They're both standard types that can be serialized.
+ * Using a dict instead of a tuple allows easy expansion of attributes.
+ * Using a list for the outer structure maintains order and is easy to work
+ with (e.g. list comprehension [d['name'] for d in cols]).
+
+2. Records that contain a name, such as the column name in a column record
+ use the key 'name'. So for most return values, each record will have a
+ 'name' attribute..
+"""
+
+import sqlalchemy
+from sqlalchemy import exc, sql
+from sqlalchemy import util
+from sqlalchemy.types import TypeEngine
+from sqlalchemy import schema as sa_schema
+
+
+@util.decorator
+def cache(fn, self, con, *args, **kw):
+ info_cache = kw.get('info_cache', None)
+ if info_cache is None:
+ return fn(self, con, *args, **kw)
+ key = (
+ fn.__name__,
+ tuple(a for a in args if isinstance(a, basestring)),
+ tuple((k, v) for k, v in kw.iteritems() if isinstance(v, basestring))
+ )
+ ret = info_cache.get(key)
+ if ret is None:
+ ret = fn(self, con, *args, **kw)
+ info_cache[key] = ret
+ return ret
+
+
+class Inspector(object):
+ """Performs database schema inspection.
+
+ The Inspector acts as a proxy to the dialects' reflection methods and
+ provides higher level functions for accessing database schema information.
+ """
+
+ def __init__(self, conn):
+ """Initialize the instance.
+
+ :param conn: a :class:`~sqlalchemy.engine.base.Connectable`
+ """
+
+ self.conn = conn
+ # set the engine
+ if hasattr(conn, 'engine'):
+ self.engine = conn.engine
+ else:
+ self.engine = conn
+ self.dialect = self.engine.dialect
+ self.info_cache = {}
+
+ @classmethod
+ def from_engine(cls, engine):
+ if hasattr(engine.dialect, 'inspector'):
+ return engine.dialect.inspector(engine)
+ return Inspector(engine)
+
+ @property
+ def default_schema_name(self):
+ return self.dialect.get_default_schema_name(self.conn)
+
+ def get_schema_names(self):
+ """Return all schema names.
+ """
+
+ if hasattr(self.dialect, 'get_schema_names'):
+ return self.dialect.get_schema_names(self.conn,
+ info_cache=self.info_cache)
+ return []
+
+ def get_table_names(self, schema=None, order_by=None):
+ """Return all table names in `schema`.
+
+ :param schema: Optional, retrieve names from a non-default schema.
+ :param order_by: Optional, may be the string "foreign_key" to sort
+ the result on foreign key dependencies.
+
+ This should probably not return view names or maybe it should return
+ them with an indicator t or v.
+ """
+
+ if hasattr(self.dialect, 'get_table_names'):
+ tnames = self.dialect.get_table_names(self.conn,
+ schema,
+ info_cache=self.info_cache)
+ else:
+ tnames = self.engine.table_names(schema)
+ if order_by == 'foreign_key':
+ ordered_tnames = tnames[:]
+ # Order based on foreign key dependencies.
+ for tname in tnames:
+ table_pos = tnames.index(tname)
+ fkeys = self.get_foreign_keys(tname, schema)
+ for fkey in fkeys:
+ rtable = fkey['referred_table']
+ if rtable in ordered_tnames:
+ ref_pos = ordered_tnames.index(rtable)
+ # Make sure it's lower in the list than anything it
+ # references.
+ if table_pos > ref_pos:
+ ordered_tnames.pop(table_pos) # rtable moves up 1
+ # insert just below rtable
+ ordered_tnames.index(ref_pos, tname)
+ tnames = ordered_tnames
+ return tnames
+
+ def get_table_options(self, table_name, schema=None, **kw):
+ if hasattr(self.dialect, 'get_table_options'):
+ return self.dialect.get_table_options(self.conn, table_name, schema,
+ info_cache=self.info_cache,
+ **kw)
+ return {}
+
+ def get_view_names(self, schema=None):
+ """Return all view names in `schema`.
+
+ :param schema: Optional, retrieve names from a non-default schema.
+ """
+
+ return self.dialect.get_view_names(self.conn, schema,
+ info_cache=self.info_cache)
+
+ def get_view_definition(self, view_name, schema=None):
+ """Return definition for `view_name`.
+
+ :param schema: Optional, retrieve names from a non-default schema.
+ """
+
+ return self.dialect.get_view_definition(
+ self.conn, view_name, schema, info_cache=self.info_cache)
+
+ def get_columns(self, table_name, schema=None, **kw):
+ """Return information about columns in `table_name`.
+
+ Given a string `table_name` and an optional string `schema`, return
+ column information as a list of dicts with these keys:
+
+ name
+ the column's name
+
+ type
+ :class:`~sqlalchemy.types.TypeEngine`
+
+ nullable
+ boolean
+
+ default
+ the column's default value
+
+ attrs
+ dict containing optional column attributes
+ """
+
+ col_defs = self.dialect.get_columns(self.conn, table_name, schema,
+ info_cache=self.info_cache,
+ **kw)
+ for col_def in col_defs:
+ # make this easy and only return instances for coltype
+ coltype = col_def['type']
+ if not isinstance(coltype, TypeEngine):
+ col_def['type'] = coltype()
+ return col_defs
+
+ def get_primary_keys(self, table_name, schema=None, **kw):
+ """Return information about primary keys in `table_name`.
+
+ Given a string `table_name`, and an optional string `schema`, return
+ primary key information as a list of column names.
+ """
+
+ pkeys = self.dialect.get_primary_keys(self.conn, table_name, schema,
+ info_cache=self.info_cache,
+ **kw)
+
+ return pkeys
+
+ def get_foreign_keys(self, table_name, schema=None, **kw):
+ """Return information about foreign_keys in `table_name`.
+
+ Given a string `table_name`, and an optional string `schema`, return
+ foreign key information as a list of dicts with these keys:
+
+ constrained_columns
+ a list of column names that make up the foreign key
+
+ referred_schema
+ the name of the referred schema
+
+ referred_table
+ the name of the referred table
+
+ referred_columns
+ a list of column names in the referred table that correspond to
+ constrained_columns
+ """
+
+ fk_defs = self.dialect.get_foreign_keys(self.conn, table_name, schema,
+ info_cache=self.info_cache,
+ **kw)
+ return fk_defs
+
+ def get_indexes(self, table_name, schema=None):
+ """Return information about indexes in `table_name`.
+
+ Given a string `table_name` and an optional string `schema`, return
+ index information as a list of dicts with these keys:
+
+ name
+ the index's name
+
+ column_names
+ list of column names in order
+
+ unique
+ boolean
+ """
+
+ indexes = self.dialect.get_indexes(self.conn, table_name,
+ schema,
+ info_cache=self.info_cache)
+ return indexes
+
+ def reflecttable(self, table, include_columns):
+
+ dialect = self.conn.dialect
+
+ # MySQL dialect does this. Applicable with other dialects?
+ if hasattr(dialect, '_connection_charset') \
+ and hasattr(dialect, '_adjust_casing'):
+ charset = dialect._connection_charset
+ dialect._adjust_casing(table)
+
+ # table attributes we might need.
+ reflection_options = dict(
+ (k, table.kwargs.get(k)) for k in dialect.reflection_options if k in table.kwargs)
+
+ schema = table.schema
+ table_name = table.name
+
+ # apply table options
+ tbl_opts = self.get_table_options(table_name, schema, **table.kwargs)
+ if tbl_opts:
+ table.kwargs.update(tbl_opts)
+
+ # table.kwargs will need to be passed to each reflection method. Make
+ # sure keywords are strings.
+ tblkw = table.kwargs.copy()
+ for (k, v) in tblkw.items():
+ del tblkw[k]
+ tblkw[str(k)] = v
+
+ # Py2K
+ if isinstance(schema, str):
+ schema = schema.decode(dialect.encoding)
+ if isinstance(table_name, str):
+ table_name = table_name.decode(dialect.encoding)
+ # end Py2K
+
+ # columns
+ found_table = False
+ for col_d in self.get_columns(table_name, schema, **tblkw):
+ found_table = True
+ name = col_d['name']
+ if include_columns and name not in include_columns:
+ continue
+
+ coltype = col_d['type']
+ col_kw = {
+ 'nullable':col_d['nullable'],
+ }
+ if 'autoincrement' in col_d:
+ col_kw['autoincrement'] = col_d['autoincrement']
+
+ colargs = []
+ if col_d.get('default') is not None:
+ # the "default" value is assumed to be a literal SQL expression,
+ # so is wrapped in text() so that no quoting occurs on re-issuance.
+ colargs.append(sa_schema.DefaultClause(sql.text(col_d['default'])))
+
+ if 'sequence' in col_d:
+ # TODO: whos using this ?
+ seq = col_d['sequence']
+ sequence = sa_schema.Sequence(seq['name'], 1, 1)
+ if 'start' in seq:
+ sequence.start = seq['start']
+ if 'increment' in seq:
+ sequence.increment = seq['increment']
+ colargs.append(sequence)
+
+ col = sa_schema.Column(name, coltype, *colargs, **col_kw)
+ table.append_column(col)
+
+ if not found_table:
+ raise exc.NoSuchTableError(table.name)
+
+ # Primary keys
+ primary_key_constraint = sa_schema.PrimaryKeyConstraint(*[
+ table.c[pk] for pk in self.get_primary_keys(table_name, schema, **tblkw)
+ if pk in table.c
+ ])
+
+ table.append_constraint(primary_key_constraint)
+
+ # Foreign keys
+ fkeys = self.get_foreign_keys(table_name, schema, **tblkw)
+ for fkey_d in fkeys:
+ conname = fkey_d['name']
+ constrained_columns = fkey_d['constrained_columns']
+ referred_schema = fkey_d['referred_schema']
+ referred_table = fkey_d['referred_table']
+ referred_columns = fkey_d['referred_columns']
+ refspec = []
+ if referred_schema is not None:
+ sa_schema.Table(referred_table, table.metadata,
+ autoload=True, schema=referred_schema,
+ autoload_with=self.conn,
+ **reflection_options
+ )
+ for column in referred_columns:
+ refspec.append(".".join(
+ [referred_schema, referred_table, column]))
+ else:
+ sa_schema.Table(referred_table, table.metadata, autoload=True,
+ autoload_with=self.conn,
+ **reflection_options
+ )
+ for column in referred_columns:
+ refspec.append(".".join([referred_table, column]))
+ table.append_constraint(
+ sa_schema.ForeignKeyConstraint(constrained_columns, refspec,
+ conname, link_to_name=True))
+ # Indexes
+ indexes = self.get_indexes(table_name, schema)
+ for index_d in indexes:
+ name = index_d['name']
+ columns = index_d['column_names']
+ unique = index_d['unique']
+ flavor = index_d.get('type', 'unknown type')
+ if include_columns and \
+ not set(columns).issubset(include_columns):
+ util.warn(
+ "Omitting %s KEY for (%s), key covers omitted columns." %
+ (flavor, ', '.join(columns)))
+ continue
+ sa_schema.Index(name, *[table.columns[c] for c in columns],
+ **dict(unique=unique))
``plain``, ``threadlocal``, and ``mock``.
New strategies can be added via new ``EngineStrategy`` classes.
-
"""
+
from operator import attrgetter
from sqlalchemy.engine import base, threadlocal, url
from sqlalchemy import util, exc
from sqlalchemy import pool as poollib
-
strategies = {}
+
class EngineStrategy(object):
"""An adaptor that processes input arguements and produces an Engine.
Provides a ``create`` method that receives input arguments and
produces an instance of base.Engine or a subclass.
+
"""
- def __init__(self, name):
- """Construct a new EngineStrategy object.
-
- Sets it in the list of available strategies under this name.
- """
-
- self.name = name
+ def __init__(self):
strategies[self.name] = self
def create(self, *args, **kwargs):
raise NotImplementedError()
+
class DefaultEngineStrategy(EngineStrategy):
"""Base class for built-in stratgies."""
+ pool_threadlocal = False
+
def create(self, name_or_url, **kwargs):
# create url.URL object
u = url.make_url(name_or_url)
if pool is None:
def connect():
try:
- return dbapi.connect(*cargs, **cparams)
+ return dialect.connect(*cargs, **cparams)
except Exception, e:
- raise exc.DBAPIError.instance(None, None, e)
+ # Py3K
+ #raise exc.DBAPIError.instance(None, None, e) from e
+ # Py2K
+ import sys
+ raise exc.DBAPIError.instance(None, None, e), None, sys.exc_info()[2]
+ # end Py2K
+
creator = kwargs.pop('creator', connect)
poolclass = (kwargs.pop('poolclass', None) or
tk = translate.get(k, k)
if tk in kwargs:
pool_args[k] = kwargs.pop(tk)
- pool_args.setdefault('use_threadlocal', self.pool_threadlocal())
+ pool_args.setdefault('use_threadlocal', self.pool_threadlocal)
pool = poolclass(creator, **pool_args)
else:
if isinstance(pool, poollib._DBProxy):
pool = pool
# create engine.
- engineclass = self.get_engine_cls()
+ engineclass = self.engine_cls
engine_args = {}
for k in util.get_cls_kwargs(engineclass):
if k in kwargs:
engine_args[k] = kwargs.pop(k)
+ _initialize = kwargs.pop('_initialize', True)
+
# all kwargs should be consumed
if kwargs:
raise TypeError(
dialect.__class__.__name__,
pool.__class__.__name__,
engineclass.__name__))
- return engineclass(pool, dialect, u, **engine_args)
+
+ engine = engineclass(pool, dialect, u, **engine_args)
- def pool_threadlocal(self):
- raise NotImplementedError()
+ if _initialize:
+ # some unit tests pass through _initialize=False
+ # to help mock engines work
+ class OnInit(object):
+ def first_connect(self, conn, rec):
+ c = base.Connection(engine, connection=conn)
+ dialect.initialize(c)
+ pool._on_first_connect.insert(0, OnInit())
- def get_engine_cls(self):
- raise NotImplementedError()
+ dialect.visit_pool(pool)
-class PlainEngineStrategy(DefaultEngineStrategy):
- """Strategy for configuring a regular Engine."""
+ return engine
- def __init__(self):
- DefaultEngineStrategy.__init__(self, 'plain')
-
- def pool_threadlocal(self):
- return False
- def get_engine_cls(self):
- return base.Engine
+class PlainEngineStrategy(DefaultEngineStrategy):
+ """Strategy for configuring a regular Engine."""
+ name = 'plain'
+ engine_cls = base.Engine
+
PlainEngineStrategy()
+
class ThreadLocalEngineStrategy(DefaultEngineStrategy):
"""Strategy for configuring an Engine with thredlocal behavior."""
-
- def __init__(self):
- DefaultEngineStrategy.__init__(self, 'threadlocal')
-
- def pool_threadlocal(self):
- return True
-
- def get_engine_cls(self):
- return threadlocal.TLEngine
+
+ name = 'threadlocal'
+ pool_threadlocal = True
+ engine_cls = threadlocal.TLEngine
ThreadLocalEngineStrategy()
Produces a single mock Connectable object which dispatches
statement execution to a passed-in function.
+
"""
- def __init__(self):
- EngineStrategy.__init__(self, 'mock')
-
+ name = 'mock'
+
def create(self, name_or_url, executor, **kwargs):
# create url.URL object
u = url.make_url(name_or_url)
def create(self, entity, **kwargs):
kwargs['checkfirst'] = False
- self.dialect.schemagenerator(self.dialect, self, **kwargs).traverse(entity)
+ from sqlalchemy.engine import ddl
+
+ ddl.SchemaGenerator(self.dialect, self, **kwargs).traverse(entity)
def drop(self, entity, **kwargs):
kwargs['checkfirst'] = False
- self.dialect.schemadropper(self.dialect, self, **kwargs).traverse(entity)
+ from sqlalchemy.engine import ddl
+ ddl.SchemaDropper(self.dialect, self, **kwargs).traverse(entity)
def execute(self, object, *multiparams, **params):
raise NotImplementedError()
from sqlalchemy import util
from sqlalchemy.engine import base
+
class TLSession(object):
def __init__(self, engine):
self.engine = engine
try:
return self.__transaction._increment_connect()
except AttributeError:
- return self.engine.TLConnection(self, self.engine.pool.connect(), close_with_result=close_with_result)
+ return self.engine.TLConnection(self, self.engine.pool.connect(),
+ close_with_result=close_with_result)
def reset(self):
try:
format of the URL is an RFC-1738-style string.
All initialization parameters are available as public attributes.
-
- :param drivername: the name of the database backend.
- This name will correspond to a module in sqlalchemy/databases
+
+ :param drivername: the name of the database backend.
+ This name will correspond to a module in sqlalchemy/databases
or a third party plug-in.
:param username: The user name.
:param database: The database name.
- :param query: A dictionary of options to be passed to the
+ :param query: A dictionary of options to be passed to the
dialect and/or the DBAPI upon connect.
-
+
"""
- def __init__(self, drivername, username=None, password=None, host=None, port=None, database=None, query=None):
+ def __init__(self, drivername, username=None, password=None,
+ host=None, port=None, database=None, query=None):
self.drivername = drivername
self.username = username
self.password = password
keys.sort()
s += '?' + "&".join("%s=%s" % (k, self.query[k]) for k in keys)
return s
-
+
def __hash__(self):
return hash(str(self))
-
+
def __eq__(self, other):
return \
isinstance(other, URL) and \
self.host == other.host and \
self.database == other.database and \
self.query == other.query
-
+
def get_dialect(self):
- """Return the SQLAlchemy database dialect class corresponding to this URL's driver name."""
-
+ """Return the SQLAlchemy database dialect class corresponding
+ to this URL's driver name.
+ """
+
try:
- module = getattr(__import__('sqlalchemy.databases.%s' % self.drivername).databases, self.drivername)
+ if '+' in self.drivername:
+ dialect, driver = self.drivername.split('+')
+ else:
+ dialect, driver = self.drivername, 'base'
+
+ module = __import__('sqlalchemy.dialects.%s' % (dialect, )).dialects
+ module = getattr(module, dialect)
+ module = getattr(module, driver)
+
return module.dialect
except ImportError:
if sys.exc_info()[2].tb_next is None:
if res.name == self.drivername:
return res.load()
raise
-
+
def translate_connect_args(self, names=[], **kw):
"""Translate url attributes into a dictionary of connection arguments.
from the final dictionary.
:param \**kw: Optional, alternate key names for url attributes.
-
+
:param names: Deprecated. Same purpose as the keyword-based alternate names,
but correlates the name to the original positionally.
-
"""
translated = {}
The given string is parsed according to the RFC 1738 spec. If an
existing URL object is passed, just returns the object.
-
"""
+
if isinstance(name_or_url, basestring):
return _parse_rfc1738_args(name_or_url)
else:
def _parse_rfc1738_args(name):
pattern = re.compile(r'''
- (?P<name>\w+)://
+ (?P<name>[\w\+]+)://
(?:
(?P<username>[^:/]*)
(?::(?P<password>[^/]*))?
tokens = components['database'].split('?', 2)
components['database'] = tokens[0]
query = (len(tokens) > 1 and dict(cgi.parse_qsl(tokens[1]))) or None
+ # Py2K
if query is not None:
query = dict((k.encode('ascii'), query[k]) for k in query)
+ # end Py2K
else:
query = None
components['query'] = query
"""
+ @classmethod
def instance(cls, statement, params, orig, connection_invalidated=False):
# Don't ever wrap these, just return them directly as if
# DBAPIError didn't exist.
cls = glob[name]
return cls(statement, params, orig, connection_invalidated)
- instance = classmethod(instance)
def __init__(self, statement, params, orig, connection_invalidated=False):
try:
Compilers can also be made dialect-specific. The appropriate compiler will be invoked
for the dialect in use::
- from sqlalchemy.schema import DDLElement # this is a SQLA 0.6 construct
+ from sqlalchemy.schema import DDLElement
class AlterColumn(DDLElement):
def visit_alter_column(element, compiler, **kw):
return "ALTER COLUMN %s ..." % element.column.name
- @compiles(AlterColumn, 'postgres')
+ @compiles(AlterColumn, 'postgresql')
def visit_alter_column(element, compiler, **kw):
return "ALTER TABLE %s ALTER COLUMN %s ..." % (element.table.name, element.column.name)
-The second ``visit_alter_table`` will be invoked when any ``postgres`` dialect is used.
+The second ``visit_alter_table`` will be invoked when any ``postgresql`` dialect is used.
The ``compiler`` argument is the :class:`~sqlalchemy.engine.base.Compiled` object
in use. This object can be inspected for any information about the in-progress
compilation, including ``compiler.dialect``, ``compiler.statement`` etc.
-The :class:`~sqlalchemy.sql.compiler.SQLCompiler` and :class:`~sqlalchemy.sql.compiler.DDLCompiler` (DDLCompiler is 0.6. only)
+The :class:`~sqlalchemy.sql.compiler.SQLCompiler` and :class:`~sqlalchemy.sql.compiler.DDLCompiler`
both include a ``process()`` method which can be used for compilation of embedded attributes::
class InsertFromSelect(ClauseElement):
objects. A typical application setup using :func:`~sqlalchemy.orm.scoped_session` might look
like::
- engine = create_engine('postgres://scott:tiger@localhost/test')
+ engine = create_engine('postgresql://scott:tiger@localhost/test')
Session = scoped_session(sessionmaker(autocommit=False,
autoflush=False,
bind=engine))
_raw_append = collection.adds(1)(_raw_append)
def insert(self, index, entity):
- self[index:index] = [entity]
+ super(OrderingList, self).insert(index, entity)
+ self._reorder()
def remove(self, entity):
super(OrderingList, self).remove(entity)
def __setitem__(self, index, entity):
if isinstance(index, slice):
- for i in range(index.start or 0, index.stop or 0, index.step or 1):
+ step = index.step or 1
+ start = index.start or 0
+ if start < 0:
+ start += len(self)
+ stop = index.stop or len(self)
+ if stop < 0:
+ stop += len(self)
+
+ for i in xrange(start, stop, step):
self.__setitem__(i, entity[i])
else:
self._order_entity(index, entity, True)
super(OrderingList, self).__delitem__(index)
self._reorder()
+ # Py2K
def __setslice__(self, start, end, values):
super(OrderingList, self).__setslice__(start, end, values)
self._reorder()
def __delslice__(self, start, end):
super(OrderingList, self).__delslice__(start, end)
self._reorder()
-
+ # end Py2K
+
for func_name, func in locals().items():
if (util.callable(func) and func.func_name == func_name and
not func.__doc__ and hasattr(list, func_name)):
from sqlalchemy.util import pickle
import re
import base64
-from cStringIO import StringIO
+# Py3K
+#from io import BytesIO as byte_buffer
+# Py2K
+from cStringIO import StringIO as byte_buffer
+# end Py2K
+
+# Py3K
+#def b64encode(x):
+# return base64.b64encode(x).decode('ascii')
+#def b64decode(x):
+# return base64.b64decode(x.encode('ascii'))
+# Py2K
+b64encode = base64.b64encode
+b64decode = base64.b64decode
+# end Py2K
__all__ = ['Serializer', 'Deserializer', 'dumps', 'loads']
+
+
def Serializer(*args, **kw):
pickler = pickle.Pickler(*args, **kw)
if isinstance(obj, QueryableAttribute):
cls = obj.impl.class_
key = obj.impl.key
- id = "attribute:" + key + ":" + base64.b64encode(pickle.dumps(cls))
+ id = "attribute:" + key + ":" + b64encode(pickle.dumps(cls))
elif isinstance(obj, Mapper) and not obj.non_primary:
- id = "mapper:" + base64.b64encode(pickle.dumps(obj.class_))
+ id = "mapper:" + b64encode(pickle.dumps(obj.class_))
elif isinstance(obj, Table):
id = "table:" + str(obj)
elif isinstance(obj, Column) and isinstance(obj.table, Table):
type_, args = m.group(1, 2)
if type_ == 'attribute':
key, clsarg = args.split(":")
- cls = pickle.loads(base64.b64decode(clsarg))
+ cls = pickle.loads(b64decode(clsarg))
return getattr(cls, key)
elif type_ == "mapper":
- cls = pickle.loads(base64.b64decode(args))
+ cls = pickle.loads(b64decode(args))
return class_mapper(cls)
elif type_ == "table":
return metadata.tables[args]
return unpickler
def dumps(obj):
- buf = StringIO()
+ buf = byte_buffer()
pickler = Serializer(buf)
pickler.dump(obj)
return buf.getvalue()
def loads(data, metadata=None, scoped_session=None, engine=None):
- buf = StringIO(data)
+ buf = byte_buffer(data)
unpickler = Deserializer(buf, metadata, scoped_session, engine)
return unpickler.load()
Get, filter, filter_by, order_by, limit, and the rest of the
query methods are explained in detail in the `SQLAlchemy documentation`__.
-__ http://www.sqlalchemy.org/docs/04/ormtutorial.html#datamapping_querying
+__ http://www.sqlalchemy.org/docs/05/ormtutorial.html#datamapping_querying
Modifying objects
def class_for_table(selectable, **mapper_kwargs):
selectable = expression._clause_element_as_expr(selectable)
mapname = 'Mapped' + _selectable_name(selectable)
+ # Py2K
if isinstance(mapname, unicode):
engine_encoding = selectable.metadata.bind.dialect.encoding
mapname = mapname.encode(engine_encoding)
+ # end Py2K
if isinstance(selectable, Table):
klass = TableClassType(mapname, (object,), {})
else:
def entity(self, attr, schema=None):
try:
t = self._cache[attr]
- except KeyError:
+ except KeyError, ke:
table = Table(attr, self._metadata, autoload=True, schema=schema or self.schema)
if not table.primary_key.columns:
+ # Py3K
+ #raise PKNotFoundError('table %r does not have a primary key defined [columns: %s]' % (attr, ','.join(table.c.keys()))) from ke
+ # Py2K
raise PKNotFoundError('table %r does not have a primary key defined [columns: %s]' % (attr, ','.join(table.c.keys())))
+ # end Py2K
if table.columns:
t = class_for_table(table)
else:
"""
+ def first_connect(self, dbapi_con, con_record):
+ """Called exactly once for the first DB-API connection.
+
+ dbapi_con
+ A newly connected raw DB-API connection (not a SQLAlchemy
+ ``Connection`` wrapper).
+
+ con_record
+ The ``_ConnectionRecord`` that persistently manages the connection
+
+ """
+
def checkout(self, dbapi_con, con_record, con_proxy):
"""Called when a connection is retrieved from the Pool.
"""
mapperlib._COMPILE_MUTEX.acquire()
try:
- for mapper in list(_mapper_registry):
- mapper.dispose()
+ while _mapper_registry:
+ try:
+ # can't even reliably call list(weakdict) in jython
+ mapper, b = _mapper_registry.popitem()
+ mapper.dispose()
+ except KeyError:
+ pass
finally:
mapperlib._COMPILE_MUTEX.release()
class _ProxyImpl(object):
accepts_scalar_loader = False
- dont_expire_missing = False
+ expire_missing = True
def __init__(self, key):
self.key = key
def __init__(self, class_, key,
callable_, trackparent=False, extension=None,
compare_function=None, active_history=False, parent_token=None,
- dont_expire_missing=False,
+ expire_missing=True,
**kwargs):
"""Construct an AttributeImpl.
Allows multiple AttributeImpls to all match a single
owner attribute.
- dont_expire_missing
- if True, don't add an "expiry" callable to this attribute
+ expire_missing
+ if False, don't add an "expiry" callable to this attribute
during state.expire_attributes(None), if no value is present
for this key.
active_history = True
break
self.active_history = active_history
- self.dont_expire_missing = dont_expire_missing
+ self.expire_missing = expire_missing
def hasparent(self, state, optimistic=False):
"""Return the boolean value of a `hasparent` flag attached to the given item.
self.local_attrs[key] = inst
self.install_descriptor(key, inst)
self[key] = inst
+
for cls in self.class_.__subclasses__():
- if isinstance(cls, types.ClassType):
- continue
manager = self._subclass_manager(cls)
manager.instrument_attribute(key, inst, True)
if key in self.mutable_attributes:
self.mutable_attributes.remove(key)
for cls in self.class_.__subclasses__():
- if isinstance(cls, types.ClassType):
- continue
manager = self._subclass_manager(cls)
manager.uninstrument_attribute(key, True)
func_vars = util.format_argspec_init(original__init__, grouped=False)
func_text = func_body % func_vars
+ # Py3K
+ #func_defaults = getattr(original__init__, '__defaults__', None)
+ # Py2K
func = getattr(original__init__, 'im_func', original__init__)
func_defaults = getattr(func, 'func_defaults', None)
+ # end Py2K
env = locals().copy()
exec func_text in env
if getattr(obj, '_sa_adapter', None) is not None:
return getattr(obj, '_sa_adapter')
elif setting_type == dict:
+ # Py3K
+ #return obj.values()
+ # Py2K
return getattr(obj, 'itervalues', getattr(obj, 'values'))()
+ # end Py2K
else:
return iter(obj)
def __iter__(self):
"""Iterate over entities in the collection."""
- return getattr(self._data(), '_sa_iterator')()
+
+ # Py3K requires iter() here
+ return iter(getattr(self._data(), '_sa_iterator')())
def __len__(self):
"""Count entities in the collection."""
fn(self, index, value)
else:
# slice assignment requires __delitem__, insert, __len__
- if index.stop is None:
- stop = 0
- elif index.stop < 0:
- stop = len(self) + index.stop
- else:
- stop = index.stop
step = index.step or 1
- rng = range(index.start or 0, stop, step)
+ start = index.start or 0
+ if start < 0:
+ start += len(self)
+ stop = index.stop or len(self)
+ if stop < 0:
+ stop += len(self)
+
if step == 1:
- for i in rng:
- del self[index.start]
- i = index.start
- for item in value:
- self.insert(i, item)
- i += 1
+ for i in xrange(start, stop, step):
+ if len(self) > start:
+ del self[start]
+
+ for i, item in enumerate(value):
+ self.insert(i + start, item)
else:
+ rng = range(start, stop, step)
if len(value) != len(rng):
raise ValueError(
"attempt to assign sequence of size %s to "
_tidy(__delitem__)
return __delitem__
+ # Py2K
def __setslice__(fn):
def __setslice__(self, start, end, values):
for value in self[start:end]:
fn(self, start, end)
_tidy(__delslice__)
return __delslice__
-
+ # end Py2K
+
def extend(fn):
def extend(self, iterable):
for value in iterable:
class InstrumentedDict(dict):
"""An instrumented version of the built-in dict."""
+ # Py3K
+ #__instrumentation__ = {
+ # 'iterator': 'values', }
+ # Py2K
__instrumentation__ = {
'iterator': 'itervalues', }
-
+ # end Py2K
+
__canned_instrumentation = {
list: InstrumentedList,
set: InstrumentedSet,
'iterator': '__iter__',
'_decorators': _set_decorators(), },
# decorators are required for dicts and object collections.
+ # Py3K
+ #dict: {'iterator': 'values',
+ # '_decorators': _dict_decorators(), },
+ # Py2K
dict: {'iterator': 'itervalues',
'_decorators': _dict_decorators(), },
+ # end Py2K
# < 0.4 compatible naming, deprecated- use decorators instead.
None: {}
}
return types[prop.direction](prop)
class DependencyProcessor(object):
- no_dependencies = False
+ has_dependencies = True
def __init__(self, prop):
self.prop = prop
"""a special DP that works for many-to-one relations, fires off for
child items who have changed their referenced key."""
- no_dependencies = True
+ has_dependencies = False
def register_dependencies(self, uowcommit):
pass
pass
_straight_ops = set(getattr(operators, op)
- for op in ('add', 'mul', 'sub', 'div', 'mod', 'truediv',
+ for op in ('add', 'mul', 'sub',
+ # Py2K
+ 'div',
+ # end Py2K
+ 'mod', 'truediv',
'lt', 'le', 'ne', 'gt', 'ge', 'eq'))
self._modified.discard(state)
def _dirty_states(self):
- return self._modified.union(s for s in list(self._mutable_attrs)
+ return self._modified.union(s for s in self._mutable_attrs.copy()
if s.modified)
def check_modified(self):
if self._modified:
return True
else:
- for state in list(self._mutable_attrs):
+ for state in self._mutable_attrs.copy():
if state.modified:
return True
return False
return self[key]
except KeyError:
return default
-
+
+ # Py2K
def items(self):
return list(self.iteritems())
def iteritems(self):
for state in dict.itervalues(self):
+ # end Py2K
+ # Py3K
+ #def items(self):
+ # for state in dict.values(self):
value = state.obj()
if value is not None:
yield state.key, value
+ # Py2K
+ def values(self):
+ return list(self.itervalues())
+
def itervalues(self):
for state in dict.itervalues(self):
+ # end Py2K
+ # Py3K
+ #def values(self):
+ # for state in dict.values(self):
instance = state.obj()
if instance is not None:
yield instance
- def values(self):
- return list(self.itervalues())
-
def all_states(self):
+ # Py3K
+ # return list(dict.values(self))
+
+ # Py2K
return dict.values(self)
+ # end Py2K
def prune(self):
return 0
class StrongInstanceDict(IdentityMap):
def all_states(self):
- return [attributes.instance_state(o) for o in self.values()]
+ return [attributes.instance_state(o) for o in self.itervalues()]
def contains_state(self, state):
return state.key in self and attributes.instance_state(self[state.key]) is state
ref_count = len(self)
dirty = [s.obj() for s in self.all_states() if s.check_modified()]
- keepers = weakref.WeakValueDictionary(self)
+
+ # work around http://bugs.python.org/issue6149
+ keepers = weakref.WeakValueDictionary()
+ keepers.update(self)
+
dict.clear(self)
dict.update(self, keepers)
self.modified = bool(dirty)
return not self.parent.non_primary
- def merge(self, session, source, dest, dont_load, _recursive):
+ def merge(self, session, source, dest, load, _recursive):
"""Merge the attribute represented by this ``MapperProperty``
from source to destination object"""
if self.polymorphic_on and self.polymorphic_on not in self._columntoproperty:
col = self.mapped_table.corresponding_column(self.polymorphic_on)
if not col:
- dont_instrument = True
+ instrument = False
col = self.polymorphic_on
else:
- dont_instrument = False
+ instrument = True
if self._should_exclude(col.key, col.key, local=False):
raise sa_exc.InvalidRequestError("Cannot exclude or override the discriminator column %r" % col.key)
- self._configure_property(col.key, ColumnProperty(col, _no_instrument=dont_instrument), init=False, setparent=True)
+ self._configure_property(col.key, ColumnProperty(col, _instrument=instrument), init=False, setparent=True)
def _adapt_inherited_property(self, key, prop, init):
if not self.concrete:
statement = table.insert()
for state, params, mapper, connection, value_params in insert:
c = connection.execute(statement.values(value_params), params)
- primary_key = c.last_inserted_ids()
+ primary_key = c.inserted_primary_key
if primary_key is not None:
# set primary key attributes
if state.load_options:
state.load_path = context.query._current_path + path
+ if isnew:
+ if context.options:
+ state.load_options = context.options
+ if state.load_options:
+ state.load_path = context.query._current_path + path
+
if not new_populators:
new_populators[:], existing_populators[:] = self._populators(context, path, row, adapter)
self.columns = [expression._labeled(c) for c in columns]
self.group = kwargs.pop('group', None)
self.deferred = kwargs.pop('deferred', False)
- self.no_instrument = kwargs.pop('_no_instrument', False)
+ self.instrument = kwargs.pop('_instrument', True)
self.comparator_factory = kwargs.pop('comparator_factory', self.__class__.Comparator)
self.descriptor = kwargs.pop('descriptor', None)
self.extension = kwargs.pop('extension', None)
self.__class__.__name__, ', '.join(sorted(kwargs.keys()))))
util.set_creation_order(self)
- if self.no_instrument:
+ if not self.instrument:
self.strategy_class = strategies.UninstrumentedColumnLoader
elif self.deferred:
self.strategy_class = strategies.DeferredColumnLoader
self.strategy_class = strategies.ColumnLoader
def instrument_class(self, mapper):
- if self.no_instrument:
+ if not self.instrument:
return
attributes.register_descriptor(
def setattr(self, state, value, column):
state.get_impl(self.key).set(state, state.dict, value, None)
- def merge(self, session, source, dest, dont_load, _recursive):
+ def merge(self, session, source, dest, load, _recursive):
value = attributes.instance_state(source).value_as_iterable(
self.key, passive=True)
if value:
proxy_property=self.descriptor
)
- def merge(self, session, source, dest, dont_load, _recursive):
+ def merge(self, session, source, dest, load, _recursive):
pass
log.class_logger(SynonymProperty)
def create_row_processor(self, selectcontext, path, mapper, row, adapter):
return (None, None)
- def merge(self, session, source, dest, dont_load, _recursive):
+ def merge(self, session, source, dest, load, _recursive):
pass
def __str__(self):
return str(self.parent.class_.__name__) + "." + self.key
- def merge(self, session, source, dest, dont_load, _recursive):
- if not dont_load:
+ def merge(self, session, source, dest, load, _recursive):
+ if load:
# TODO: no test coverage for recursive check
for r in self._reverse_property:
if (source, r) in _recursive:
dest_list = []
for current in instances:
_recursive[(current, self)] = True
- obj = session._merge(current, dont_load=dont_load, _recursive=_recursive)
+ obj = session._merge(current, load=load, _recursive=_recursive)
if obj is not None:
dest_list.append(obj)
- if dont_load:
+ if not load:
coll = attributes.init_collection(dest_state, self.key)
for c in dest_list:
coll.append_without_event(c)
current = instances[0]
if current is not None:
_recursive[(current, self)] = True
- obj = session._merge(current, dont_load=dont_load, _recursive=_recursive)
+ obj = session._merge(current, load=load, _recursive=_recursive)
if obj is not None:
- if dont_load:
+ if not load:
dest_state.dict[self.key] = obj
else:
setattr(dest, self.key, obj)
def value(self, column):
"""Return a scalar result corresponding to the given column expression."""
try:
+ # Py3K
+ #return self.values(column).__next__()[0]
+ # Py2K
return self.values(column).next()[0]
+ # end Py2K
except StopIteration:
return None
session._finalize_loaded(context.progress)
- for ii, (dict_, attrs) in context.partials.items():
+ for ii, (dict_, attrs) in context.partials.iteritems():
ii.commit(dict_, attrs)
for row in rows:
self.session._autoflush()
return self.session.scalar(s, params=self._params, mapper=self._mapper_zero())
- def delete(self, synchronize_session='fetch'):
+ def delete(self, synchronize_session='evaluate'):
"""Perform a bulk delete query.
Deletes rows matched by this query from the database.
False
don't synchronize the session. This option is the most efficient and is reliable
- once the session is expired, which typically occurs after a commit(). Before
- the expiration, objects may still remain in the session which were in fact deleted
- which can lead to confusing results if they are accessed via get() or already
- loaded collections.
+ once the session is expired, which typically occurs after a commit(), or explicitly
+ using expire_all(). Before the expiration, objects may still remain in the session
+ which were in fact deleted which can lead to confusing results if they are accessed
+ via get() or already loaded collections.
'fetch'
performs a select query before the delete to find objects that are matched
by the delete query and need to be removed from the session. Matched objects
- are removed from the session. 'fetch' is the default strategy.
+ are removed from the session.
'evaluate'
experimental feature. Tries to evaluate the querys criteria in Python
if synchronize_session == 'evaluate':
try:
evaluator_compiler = evaluator.EvaluatorCompiler()
- eval_condition = evaluator_compiler.process(self.whereclause)
+ eval_condition = evaluator_compiler.process(self.whereclause or expression._Null)
except evaluator.UnevaluatableError:
- synchronize_session = 'fetch'
+ raise sa_exc.InvalidRequestError("Could not evaluate current criteria in Python. "
+ "Specify 'fetch' or False for the synchronize_session parameter.")
delete_stmt = sql.delete(primary_table, context.whereclause)
return result.rowcount
- def update(self, values, synchronize_session='expire'):
+ def update(self, values, synchronize_session='evaluate'):
"""Perform a bulk update query.
Updates rows matched by this query in the database.
attributes on objects in the session. Valid values are:
False
- don't synchronize the session. Use this when you don't need to use the
- session after the update or you can be sure that none of the matched objects
- are in the session.
-
- 'expire'
+ don't synchronize the session. This option is the most efficient and is reliable
+ once the session is expired, which typically occurs after a commit(), or explicitly
+ using expire_all(). Before the expiration, updated objects may still remain in the session
+ with stale values on their attributes, which can lead to confusing results.
+
+ 'fetch'
performs a select query before the update to find objects that are matched
by the update query. The updated attributes are expired on matched objects.
'evaluate'
- experimental feature. Tries to evaluate the querys criteria in Python
+ Tries to evaluate the Query's criteria in Python
straight on the objects in the session. If evaluation of the criteria isn't
- implemented, the 'expire' strategy will be used as a fallback.
+ implemented, an exception is raised.
The expression evaluator currently doesn't account for differing string
collations between the database and Python.
The method does *not* offer in-Python cascading of relations - it is assumed that
ON UPDATE CASCADE is configured for any foreign key references which require it.
+
The Session needs to be expired (occurs automatically after commit(), or call expire_all())
in order for the state of dependent objects subject foreign key cascade to be
correctly represented.
#TODO: updates of manytoone relations need to be converted to fk assignments
#TODO: cascades need handling.
+ if synchronize_session not in [False, 'evaluate', 'fetch']:
+ raise sa_exc.ArgumentError("Valid strategies for session synchronization are False, 'evaluate' and 'fetch'")
self._no_select_modifiers("update")
- if synchronize_session not in [False, 'evaluate', 'expire']:
- raise sa_exc.ArgumentError("Valid strategies for session synchronization are False, 'evaluate' and 'expire'")
self = self.enable_eagerloads(False)
if synchronize_session == 'evaluate':
try:
evaluator_compiler = evaluator.EvaluatorCompiler()
- eval_condition = evaluator_compiler.process(self.whereclause)
+ eval_condition = evaluator_compiler.process(self.whereclause or expression._Null)
value_evaluators = {}
- for key,value in values.items():
+ for key,value in values.iteritems():
key = expression._column_as_key(key)
value_evaluators[key] = evaluator_compiler.process(expression._literal_as_binds(value))
except evaluator.UnevaluatableError:
- synchronize_session = 'expire'
+ raise sa_exc.InvalidRequestError("Could not evaluate current criteria in Python. "
+ "Specify 'fetch' or False for the synchronize_session parameter.")
update_stmt = sql.update(primary_table, context.whereclause, values)
- if synchronize_session == 'expire':
+ if synchronize_session == 'fetch':
select_stmt = context.statement.with_only_columns(primary_table.primary_key)
matched_rows = session.execute(select_stmt, params=self._params).fetchall()
# expire attributes with pending changes (there was no autoflush, so they are overwritten)
state.expire_attributes(set(evaluated_keys).difference(to_evaluate))
- elif synchronize_session == 'expire':
+ elif synchronize_session == 'fetch':
target_mapper = self._mapper_zero()
for primary_key in matched_rows:
return entity is self.entity_zero
else:
return not _is_aliased_class(self.entity_zero) and entity.base_mapper.common_parent(self.entity_zero)
-
+
def _resolve_expr_against_query_aliases(self, query, expr, context):
return query._adapt_clause(expr, False, True)
def _default__init__(ext, mapper):
def __init__(self, **kwargs):
- for key, value in kwargs.items():
+ for key, value in kwargs.iteritems():
if ext.validate:
if not mapper.get_property(key, resolve_synonyms=False,
raiseerr=False):
like::
sess = Session(binds={
- SomeMappedClass: create_engine('postgres://engine1'),
- somemapper: create_engine('postgres://engine2'),
- some_table: create_engine('postgres://engine3'),
+ SomeMappedClass: create_engine('postgresql://engine1'),
+ somemapper: create_engine('postgresql://engine2'),
+ some_table: create_engine('postgresql://engine3'),
})
Also see the ``bind_mapper()`` and ``bind_table()`` methods.
for state, m, o in cascade_states:
self._delete_impl(state)
- def merge(self, instance, dont_load=False):
+ def merge(self, instance, load=True, **kw):
"""Copy the state an instance onto the persistent instance with the same identifier.
If there is no persistent instance currently associated with the
mapped with ``cascade="merge"``.
"""
+ if 'dont_load' in kw:
+ load = not kw['dont_load']
+ util.warn_deprecated("dont_load=True has been renamed to load=False.")
+
# TODO: this should be an IdentityDict for instances, but will
# need a separate dict for PropertyLoader tuples
_recursive = {}
autoflush = self.autoflush
try:
self.autoflush = False
- return self._merge(instance, dont_load=dont_load, _recursive=_recursive)
+ return self._merge(instance, load=load, _recursive=_recursive)
finally:
self.autoflush = autoflush
- def _merge(self, instance, dont_load=False, _recursive=None):
+ def _merge(self, instance, load=True, _recursive=None):
mapper = _object_mapper(instance)
if instance in _recursive:
return _recursive[instance]
key = state.key
if key is None:
- if dont_load:
+ if not load:
raise sa_exc.InvalidRequestError(
- "merge() with dont_load=True option does not support "
+ "merge() with load=False option does not support "
"objects transient (i.e. unpersisted) objects. flush() "
"all changes on mapped instances before merging with "
- "dont_load=True.")
+ "load=False.")
key = mapper._identity_key_from_state(state)
merged = None
if key:
if key in self.identity_map:
merged = self.identity_map[key]
- elif dont_load:
+ elif not load:
if state.modified:
raise sa_exc.InvalidRequestError(
- "merge() with dont_load=True option does not support "
+ "merge() with load=False option does not support "
"objects marked as 'dirty'. flush() all changes on "
- "mapped instances before merging with dont_load=True.")
+ "mapped instances before merging with load=False.")
merged = mapper.class_manager.new_instance()
merged_state = attributes.instance_state(merged)
merged_state.key = key
_recursive[instance] = merged
for prop in mapper.iterate_properties:
- prop.merge(self, instance, merged, dont_load, _recursive)
+ prop.merge(self, instance, merged, load, _recursive)
- if dont_load:
+ if not load:
attributes.instance_state(merged).commit_all(attributes.instance_dict(merged), self.identity_map) # remove any history
if new_instance:
def detach(self):
if self.session_id:
- del self.session_id
+ try:
+ del self.session_id
+ except AttributeError:
+ pass
def dispose(self):
if self.session_id:
- del self.session_id
+ try:
+ del self.session_id
+ except AttributeError:
+ pass
del self.obj
def _cleanup(self, ref):
for key in attribute_names:
impl = self.manager[key].impl
if not filter_deferred or \
- not impl.dont_expire_missing or \
+ impl.expire_missing or \
key in dict_:
self.expired_attributes.add(key)
if impl.accepts_scalar_loader:
if useobject:
attribute_ext.append(sessionlib.UOWEventHandler(prop.key))
+
for m in mapper.polymorphic_iterator():
if prop is m._props.get(prop.key):
copy_function=self.columns[0].type.copy_value,
mutable_scalars=self.columns[0].type.is_mutable(),
callable_=self._class_level_loader,
- dont_expire_missing=True
+ expire_missing=False
)
def setup_query(self, context, entity, path, adapter, only_load_props=None, **kwargs):
if subtask not in deps_by_targettask:
continue
for dep in deps_by_targettask[subtask]:
- if dep.processor.no_dependencies or not dependency_in_cycles(dep):
+ if not dep.processor.has_dependencies or not dependency_in_cycles(dep):
continue
(processor, targettask) = (dep.processor, dep.targettask)
isdelete = taskelement.isdelete
import weakref, time, threading
from sqlalchemy import exc, log
-from sqlalchemy import queue as Queue
+from sqlalchemy import queue as sqla_queue
from sqlalchemy.util import threading, pickle, as_interface
proxies = {}
All pools and connections are disposed.
"""
- for manager in proxies.values():
+ for manager in proxies.itervalues():
manager.close()
proxies.clear()
self.echo = echo
self.listeners = []
self._on_connect = []
+ self._on_first_connect = []
self._on_checkout = []
self._on_checkin = []
"""
- listener = as_interface(
- listener, methods=('connect', 'checkout', 'checkin'))
+ listener = as_interface(listener,
+ methods=('connect', 'first_connect', 'checkout', 'checkin'))
self.listeners.append(listener)
if hasattr(listener, 'connect'):
self._on_connect.append(listener)
+ if hasattr(listener, 'first_connect'):
+ self._on_first_connect.append(listener)
if hasattr(listener, 'checkout'):
self._on_checkout.append(listener)
if hasattr(listener, 'checkin'):
self.__pool = pool
self.connection = self.__connect()
self.info = {}
+ ls = pool.__dict__.pop('_on_first_connect', None)
+ if ls is not None:
+ for l in ls:
+ l.first_connect(self.connection, self)
if pool._on_connect:
for l in pool._on_connect:
l.connect(self.connection, self)
def _finalize_fairy(connection, connection_record, pool, ref=None):
- if ref is not None and connection_record.backref is not ref:
+ _refs.discard(connection_record)
+
+ if ref is not None and (connection_record.fairy is not ref or isinstance(pool, AssertionPool)):
return
+
if connection is not None:
try:
if pool._reset_on_return:
if isinstance(e, (SystemExit, KeyboardInterrupt)):
raise
if connection_record is not None:
- connection_record.backref = None
+ connection_record.fairy = None
if pool._should_log_info:
pool.log("Connection %r being returned to pool" % connection)
if pool._on_checkin:
l.checkin(connection, connection_record)
pool.return_conn(connection_record)
+_refs = set()
+
class _ConnectionFairy(object):
"""Proxies a DB-API connection and provides return-on-dereference support."""
try:
rec = self._connection_record = pool.get()
conn = self.connection = self._connection_record.get_connection()
- self._connection_record.backref = weakref.ref(self, lambda ref:_finalize_fairy(conn, rec, pool, ref))
+ rec.fairy = weakref.ref(self, lambda ref:_finalize_fairy(conn, rec, pool, ref))
+ _refs.add(rec)
except:
self.connection = None # helps with endless __getattr__ loops later on
self._connection_record = None
"""
if self._connection_record is not None:
+ _refs.remove(self._connection_record)
+ self._connection_record.fairy = None
self._connection_record.connection = None
- self._connection_record.backref = None
self._pool.do_return_conn(self._connection_record)
self._detached_info = \
self._connection_record.info.copy()
del self._conn.current
def cleanup(self):
- for conn in list(self._all_conns):
- self._all_conns.discard(conn)
- if len(self._all_conns) <= self.size:
- return
+ while len(self._all_conns) > self.size:
+ self._all_conns.pop()
def status(self):
return "SingletonThreadPool id:%d size: %d" % (id(self), len(self._all_conns))
"""
Pool.__init__(self, creator, **params)
- self._pool = Queue.Queue(pool_size)
+ self._pool = sqla_queue.Queue(pool_size)
self._overflow = 0 - pool_size
self._max_overflow = max_overflow
self._timeout = timeout
def do_return_conn(self, conn):
try:
self._pool.put(conn, False)
- except Queue.Full:
+ except sqla_queue.Full:
if self._overflow_lock is None:
self._overflow -= 1
else:
try:
wait = self._max_overflow > -1 and self._overflow >= self._max_overflow
return self._pool.get(wait, self._timeout)
- except Queue.Empty:
+ except sqla_queue.Empty:
if self._max_overflow > -1 and self._overflow >= self._max_overflow:
if not wait:
return self.do_get()
try:
conn = self._pool.get(False)
conn.close()
- except Queue.Empty:
+ except sqla_queue.Empty:
break
self._overflow = 0 - self.size()
Pool.__init__(self, creator, **params)
self._conn = creator()
self.connection = _ConnectionRecord(self)
-
+ self.connection = None
+
def status(self):
return "StaticPool"
## TODO: modify this to handle an arbitrary connection count.
- def __init__(self, creator, **params):
- """
- Construct an AssertionPool.
-
- :param creator: a callable function that returns a DB-API
- connection object. The function will be called with
- parameters.
-
- :param recycle: If set to non -1, number of seconds between
- connection recycling, which means upon checkout, if this
- timeout is surpassed the connection will be closed and
- replaced with a newly opened connection. Defaults to -1.
-
- :param echo: If True, connections being pulled and retrieved
- from the pool will be logged to the standard output, as well
- as pool sizing information. Echoing can also be achieved by
- enabling logging for the "sqlalchemy.pool"
- namespace. Defaults to False.
-
- :param use_threadlocal: If set to True, repeated calls to
- :meth:`connect` within the same application thread will be
- guaranteed to return the same connection object, if one has
- already been retrieved from the pool and has not been
- returned yet. Offers a slight performance advantage at the
- cost of individual transactions by default. The
- :meth:`unique_connection` method is provided to bypass the
- threadlocal behavior installed into :meth:`connect`.
-
- :param reset_on_return: If true, reset the database state of
- connections returned to the pool. This is typically a
- ROLLBACK to release locks and transaction resources.
- Disable at your own peril. Defaults to True.
-
- :param listeners: A list of
- :class:`~sqlalchemy.interfaces.PoolListener`-like objects or
- dictionaries of callables that receive events when DB-API
- connections are created, checked out and checked in to the
- pool.
-
- """
- Pool.__init__(self, creator, **params)
- self.connection = _ConnectionRecord(self)
- self._conn = self.connection
-
+ def __init__(self, *args, **kw):
+ self._conn = None
+ self._checked_out = False
+ Pool.__init__(self, *args, **kw)
+
def status(self):
return "AssertionPool"
- def create_connection(self):
- raise AssertionError("Invalid")
-
def do_return_conn(self, conn):
- assert conn is self._conn and self.connection is None
- self.connection = conn
+ if not self._checked_out:
+ raise AssertionError("connection is not checked out")
+ self._checked_out = False
+ assert conn is self._conn
def do_return_invalid(self, conn):
- raise AssertionError("Invalid")
+ self._conn = None
+ self._checked_out = False
+
+ def dispose(self):
+ self._checked_out = False
+ self._conn.close()
+ def recreate(self):
+ self.log("Pool recreating")
+ return AssertionPool(self._creator, echo=self._should_log_info, listeners=self.listeners)
+
def do_get(self):
- assert self.connection is not None
- c = self.connection
- self.connection = None
- return c
+ if self._checked_out:
+ raise AssertionError("connection is already checked out")
+
+ if not self._conn:
+ self._conn = self.create_connection()
+
+ self._checked_out = True
+ return self._conn
class _DBProxy(object):
"""Layers connection pooling behavior on top of a standard DB-API module.
behavior, using RLock instead of Lock for its mutex object.
This is to support the connection pool's usage of weakref callbacks to return
-connections to the underlying Queue, which can apparently in extremely
+connections to the underlying Queue, which can in extremely
rare cases be invoked within the ``get()`` method of the Queue itself,
producing a ``put()`` inside the ``get()`` and therefore a reentrant
condition."""
"""
import re, inspect
-from sqlalchemy import types, exc, util, databases
+from sqlalchemy import types, exc, util, dialects
from sqlalchemy.sql import expression, visitors
URL = None
def __repr__(self):
return "%s()" % self.__class__.__name__
- @property
- def bind(self):
- """Return the connectable associated with this SchemaItem."""
-
- m = self.metadata
- return m and m.bind or None
-
@util.memoized_property
def info(self):
return {}
else:
return schema + "." + name
-class _TableSingleton(visitors.VisitableType):
- """A metaclass used by the ``Table`` object to provide singleton behavior."""
+class Table(SchemaItem, expression.TableClause):
+ """Represent a table in a database.
+
+ e.g.::
+
+ mytable = Table("mytable", metadata,
+ Column('mytable_id', Integer, primary_key=True),
+ Column('value', String(50))
+ )
+
+ The Table object constructs a unique instance of itself based on its
+ name within the given MetaData object. Constructor
+ arguments are as follows:
+
+ :param name: The name of this table as represented in the database.
+
+ This property, along with the *schema*, indicates the *singleton
+ identity* of this table in relation to its parent :class:`MetaData`.
+ Additional calls to :class:`Table` with the same name, metadata,
+ and schema name will return the same :class:`Table` object.
+
+ Names which contain no upper case characters
+ will be treated as case insensitive names, and will not be quoted
+ unless they are a reserved word. Names with any number of upper
+ case characters will be quoted and sent exactly. Note that this
+ behavior applies even for databases which standardize upper
+ case names as case insensitive such as Oracle.
+
+ :param metadata: a :class:`MetaData` object which will contain this
+ table. The metadata is used as a point of association of this table
+ with other tables which are referenced via foreign key. It also
+ may be used to associate this table with a particular
+ :class:`~sqlalchemy.engine.base.Connectable`.
+
+ :param \*args: Additional positional arguments are used primarily
+ to add the list of :class:`Column` objects contained within this table.
+ Similar to the style of a CREATE TABLE statement, other :class:`SchemaItem`
+ constructs may be added here, including :class:`PrimaryKeyConstraint`,
+ and :class:`ForeignKeyConstraint`.
+
+ :param autoload: Defaults to False: the Columns for this table should be reflected
+ from the database. Usually there will be no Column objects in the
+ constructor if this property is set.
+
+ :param autoload_with: If autoload==True, this is an optional Engine or Connection
+ instance to be used for the table reflection. If ``None``, the
+ underlying MetaData's bound connectable will be used.
+
+ :param implicit_returning: True by default - indicates that
+ RETURNING can be used by default to fetch newly inserted primary key
+ values, for backends which support this. Note that
+ create_engine() also provides an implicit_returning flag.
+
+ :param include_columns: A list of strings indicating a subset of columns to be loaded via
+ the ``autoload`` operation; table columns who aren't present in
+ this list will not be represented on the resulting ``Table``
+ object. Defaults to ``None`` which indicates all columns should
+ be reflected.
+
+ :param info: A dictionary which defaults to ``{}``. A space to store application
+ specific data. This must be a dictionary.
+
+ :param mustexist: When ``True``, indicates that this Table must already
+ be present in the given :class:`MetaData`` collection.
+
+ :param prefixes:
+ A list of strings to insert after CREATE in the CREATE TABLE
+ statement. They will be separated by spaces.
+
+ :param quote: Force quoting of this table's name on or off, corresponding
+ to ``True`` or ``False``. When left at its default of ``None``,
+ the column identifier will be quoted according to whether the name is
+ case sensitive (identifiers with at least one upper case character are
+ treated as case sensitive), or if it's a reserved word. This flag
+ is only needed to force quoting of a reserved word which is not known
+ by the SQLAlchemy dialect.
+
+ :param quote_schema: same as 'quote' but applies to the schema identifier.
+
+ :param schema: The *schema name* for this table, which is required if the table
+ resides in a schema other than the default selected schema for the
+ engine's database connection. Defaults to ``None``.
+
+ :param useexisting: When ``True``, indicates that if this Table is already
+ present in the given :class:`MetaData`, apply further arguments within
+ the constructor to the existing :class:`Table`. If this flag is not
+ set, an error is raised when the parameters of an existing :class:`Table`
+ are overwritten.
+
+ """
+
+ __visit_name__ = 'table'
+
+ ddl_events = ('before-create', 'after-create', 'before-drop', 'after-drop')
- def __call__(self, name, metadata, *args, **kwargs):
- schema = kwargs.get('schema', kwargs.get('owner', None))
- useexisting = kwargs.pop('useexisting', False)
- mustexist = kwargs.pop('mustexist', False)
+ def __new__(cls, name, metadata, *args, **kw):
+ schema = kw.get('schema', None)
+ useexisting = kw.pop('useexisting', False)
+ mustexist = kw.pop('mustexist', False)
key = _get_table_key(name, schema)
- try:
- table = metadata.tables[key]
- if not useexisting and table._cant_override(*args, **kwargs):
+ if key in metadata.tables:
+ if not useexisting and bool(args):
raise exc.InvalidRequestError(
"Table '%s' is already defined for this MetaData instance. "
"Specify 'useexisting=True' to redefine options and "
"columns on an existing Table object." % key)
- else:
- table._init_existing(*args, **kwargs)
+ table = metadata.tables[key]
+ table._init_existing(*args, **kw)
return table
- except KeyError:
+ else:
if mustexist:
raise exc.InvalidRequestError(
"Table '%s' not defined" % (key))
+ metadata.tables[key] = table = object.__new__(cls)
try:
- return type.__call__(self, name, metadata, *args, **kwargs)
+ table._init(name, metadata, *args, **kw)
+ return table
except:
- if key in metadata.tables:
- del metadata.tables[key]
+ metadata.tables.pop(key)
raise
-
-
-class Table(SchemaItem, expression.TableClause):
- """Represent a table in a database."""
-
- __metaclass__ = _TableSingleton
-
- __visit_name__ = 'table'
-
- ddl_events = ('before-create', 'after-create', 'before-drop', 'after-drop')
-
- def __init__(self, name, metadata, *args, **kwargs):
- """
- Construct a Table.
-
- :param name: The name of this table as represented in the database.
-
- This property, along with the *schema*, indicates the *singleton
- identity* of this table in relation to its parent :class:`MetaData`.
- Additional calls to :class:`Table` with the same name, metadata,
- and schema name will return the same :class:`Table` object.
-
- Names which contain no upper case characters
- will be treated as case insensitive names, and will not be quoted
- unless they are a reserved word. Names with any number of upper
- case characters will be quoted and sent exactly. Note that this
- behavior applies even for databases which standardize upper
- case names as case insensitive such as Oracle.
-
- :param metadata: a :class:`MetaData` object which will contain this
- table. The metadata is used as a point of association of this table
- with other tables which are referenced via foreign key. It also
- may be used to associate this table with a particular
- :class:`~sqlalchemy.engine.base.Connectable`.
-
- :param \*args: Additional positional arguments are used primarily
- to add the list of :class:`Column` objects contained within this table.
- Similar to the style of a CREATE TABLE statement, other :class:`SchemaItem`
- constructs may be added here, including :class:`PrimaryKeyConstraint`,
- and :class:`ForeignKeyConstraint`.
-
- :param autoload: Defaults to False: the Columns for this table should be reflected
- from the database. Usually there will be no Column objects in the
- constructor if this property is set.
-
- :param autoload_with: If autoload==True, this is an optional Engine or Connection
- instance to be used for the table reflection. If ``None``, the
- underlying MetaData's bound connectable will be used.
-
- :param include_columns: A list of strings indicating a subset of columns to be loaded via
- the ``autoload`` operation; table columns who aren't present in
- this list will not be represented on the resulting ``Table``
- object. Defaults to ``None`` which indicates all columns should
- be reflected.
-
- :param info: A dictionary which defaults to ``{}``. A space to store application
- specific data. This must be a dictionary.
-
- :param mustexist: When ``True``, indicates that this Table must already
- be present in the given :class:`MetaData`` collection.
-
- :param prefixes:
- A list of strings to insert after CREATE in the CREATE TABLE
- statement. They will be separated by spaces.
-
- :param quote: Force quoting of this table's name on or off, corresponding
- to ``True`` or ``False``. When left at its default of ``None``,
- the column identifier will be quoted according to whether the name is
- case sensitive (identifiers with at least one upper case character are
- treated as case sensitive), or if it's a reserved word. This flag
- is only needed to force quoting of a reserved word which is not known
- by the SQLAlchemy dialect.
-
- :param quote_schema: same as 'quote' but applies to the schema identifier.
-
- :param schema: The *schema name* for this table, which is required if the table
- resides in a schema other than the default selected schema for the
- engine's database connection. Defaults to ``None``.
-
- :param useexisting: When ``True``, indicates that if this Table is already
- present in the given :class:`MetaData`, apply further arguments within
- the constructor to the existing :class:`Table`. If this flag is not
- set, an error is raised when the parameters of an existing :class:`Table`
- are overwritten.
-
- """
+
+ def __init__(self, *args, **kw):
+ # __init__ is overridden to prevent __new__ from
+ # calling the superclass constructor.
+ pass
+
+ def _init(self, name, metadata, *args, **kwargs):
super(Table, self).__init__(name)
self.metadata = metadata
- self.schema = kwargs.pop('schema', kwargs.pop('owner', None))
+ self.schema = kwargs.pop('schema', None)
self.indexes = set()
self.constraints = set()
self._columns = expression.ColumnCollection()
- self.primary_key = PrimaryKeyConstraint()
+ self._set_primary_key(PrimaryKeyConstraint())
self._foreign_keys = util.OrderedSet()
self.ddl_listeners = util.defaultdict(list)
self.kwargs = {}
autoload_with = kwargs.pop('autoload_with', None)
include_columns = kwargs.pop('include_columns', None)
- self._set_parent(metadata)
-
+ self.implicit_returning = kwargs.pop('implicit_returning', True)
self.quote = kwargs.pop('quote', None)
self.quote_schema = kwargs.pop('quote_schema', None)
if 'info' in kwargs:
self._prefixes = kwargs.pop('prefixes', [])
- self.__extra_kwargs(**kwargs)
+ self._extra_kwargs(**kwargs)
# load column definitions from the database if 'autoload' is defined
# we do it after the table is in the singleton dictionary to support
# initialize all the column, etc. objects. done after reflection to
# allow user-overrides
- self.__post_init(*args, **kwargs)
+ self._init_items(*args)
def _init_existing(self, *args, **kwargs):
autoload = kwargs.pop('autoload', False)
if 'info' in kwargs:
self.info = kwargs.pop('info')
- self.__extra_kwargs(**kwargs)
- self.__post_init(*args, **kwargs)
-
- def _cant_override(self, *args, **kwargs):
- """Return True if any argument is not supported as an override.
-
- Takes arguments that would be sent to Table.__init__, and returns
- True if any of them would be disallowed if sent to an existing
- Table singleton.
- """
- return bool(args) or bool(set(kwargs).difference(
- ['autoload', 'autoload_with', 'schema', 'owner']))
+ self._extra_kwargs(**kwargs)
+ self._init_items(*args)
- def __extra_kwargs(self, **kwargs):
+ def _extra_kwargs(self, **kwargs):
# validate remaining kwargs that they all specify DB prefixes
if len([k for k in kwargs
- if not re.match(r'^(?:%s)_' % '|'.join(databases.__all__), k)]):
+ if not re.match(r'^(?:%s)_' % '|'.join(dialects.__all__), k)]):
raise TypeError(
- "Invalid argument(s) for Table: %s" % repr(kwargs.keys()))
+ "Invalid argument(s) for Table: %r" % kwargs.keys())
self.kwargs.update(kwargs)
- def __post_init(self, *args, **kwargs):
- self._init_items(*args)
-
- @property
- def key(self):
- return _get_table_key(self.name, self.schema)
-
def _set_primary_key(self, pk):
if getattr(self, '_primary_key', None) in self.constraints:
self.constraints.remove(self._primary_key)
self._primary_key = pk
self.constraints.add(pk)
+ for c in pk.columns:
+ c.primary_key = True
+
+ @util.memoized_property
+ def _autoincrement_column(self):
+ for col in self.primary_key:
+ if col.autoincrement and \
+ isinstance(col.type, types.Integer) and \
+ not col.foreign_keys and \
+ isinstance(col.default, (type(None), Sequence)):
+
+ return col
+
+ @property
+ def key(self):
+ return _get_table_key(self.name, self.schema)
+
+ @property
def primary_key(self):
return self._primary_key
- primary_key = property(primary_key, _set_primary_key)
def __repr__(self):
return "Table(%s)" % ', '.join(
def __str__(self):
return _get_table_key(self.description, self.schema)
+ @property
+ def bind(self):
+ """Return the connectable associated with this Table."""
+
+ return self.metadata and self.metadata.bind or None
+
def append_column(self, column):
"""Append a ``Column`` to this ``Table``."""
self, column_collections=column_collections, **kwargs)
else:
if column_collections:
- return [c for c in self.columns]
+ return list(self.columns)
else:
return []
"""Represents a column in a database table."""
__visit_name__ = 'column'
-
+
def __init__(self, *args, **kwargs):
"""
Construct a new ``Column`` object.
Contrast this argument to ``server_default`` which creates a
default generator on the database side.
-
+
:param key: An optional string identifier which will identify this ``Column``
object on the :class:`Table`. When a key is provided, this is the
only identifier referencing the ``Column`` within the application,
if args:
coltype = args[0]
- # adjust for partials
- if util.callable(coltype):
- coltype = args[0]()
-
if (isinstance(coltype, types.AbstractType) or
(isinstance(coltype, type) and
issubclass(coltype, types.AbstractType))):
type_ = args.pop(0)
super(Column, self).__init__(name, None, type_)
- self.args = args
self.key = kwargs.pop('key', name)
self.primary_key = kwargs.pop('primary_key', False)
self.nullable = kwargs.pop('nullable', not self.primary_key)
self.autoincrement = kwargs.pop('autoincrement', True)
self.constraints = set()
self.foreign_keys = util.OrderedSet()
+ self._table_events = set()
+
+ if self.default is not None:
+ if isinstance(self.default, ColumnDefault):
+ args.append(self.default)
+ else:
+ args.append(ColumnDefault(self.default))
+ if self.server_default is not None:
+ if isinstance(self.server_default, FetchedValue):
+ args.append(self.server_default)
+ else:
+ args.append(DefaultClause(self.server_default))
+ if self.onupdate is not None:
+ args.append(ColumnDefault(self.onupdate, for_update=True))
+ if self.server_onupdate is not None:
+ if isinstance(self.server_onupdate, FetchedValue):
+ args.append(self.server_default)
+ else:
+ args.append(DefaultClause(self.server_onupdate,
+ for_update=True))
+ self._init_items(*args)
+
util.set_creation_order(self)
if 'info' in kwargs:
else:
return self.description
- @property
- def bind(self):
- return self.table.bind
-
def references(self, column):
"""Return True if this Column references the given column via foreign key."""
for fk in self.foreign_keys:
"before adding to a Table.")
if self.key is None:
self.key = self.name
- self.metadata = table.metadata
+
if getattr(self, 'table', None) is not None:
raise exc.ArgumentError("this Column already has a table!")
if self.key in table._columns:
- # note the column being replaced, if any
- self._pre_existing_column = table._columns.get(self.key)
+ col = table._columns.get(self.key)
+ for fk in col.foreign_keys:
+ col.foreign_keys.remove(fk)
+ table.foreign_keys.remove(fk)
+ table.constraints.remove(fk.constraint)
+
table._columns.replace(self)
if self.primary_key:
- table.primary_key.replace(self)
+ table.primary_key._replace(self)
elif self.key in table.primary_key:
raise exc.ArgumentError(
"Trying to redefine primary-key column '%s' as a "
"non-primary-key column on table '%s'" % (
self.key, table.fullname))
- # if we think this should not raise an error, we'd instead do this:
- #table.primary_key.remove(self)
self.table = table
if self.index:
"external to the Table.")
table.append_constraint(UniqueConstraint(self.key))
- toinit = list(self.args)
- if self.default is not None:
- if isinstance(self.default, ColumnDefault):
- toinit.append(self.default)
- else:
- toinit.append(ColumnDefault(self.default))
- if self.server_default is not None:
- if isinstance(self.server_default, FetchedValue):
- toinit.append(self.server_default)
- else:
- toinit.append(DefaultClause(self.server_default))
- if self.onupdate is not None:
- toinit.append(ColumnDefault(self.onupdate, for_update=True))
- if self.server_onupdate is not None:
- if isinstance(self.server_onupdate, FetchedValue):
- toinit.append(self.server_default)
- else:
- toinit.append(DefaultClause(self.server_onupdate,
- for_update=True))
- self._init_items(*toinit)
- self.args = None
-
+ for fn in self._table_events:
+ fn(table)
+ del self._table_events
+
+ def _on_table_attach(self, fn):
+ if self.table:
+ fn(self.table)
+ else:
+ self._table_events.add(fn)
+
def copy(self, **kw):
"""Create a copy of this ``Column``, unitialized.
This is used in ``Table.tometadata``.
"""
- return Column(self.name, self.type, self.default, key = self.key, primary_key = self.primary_key, nullable = self.nullable, quote=self.quote, index=self.index, autoincrement=self.autoincrement, *[c.copy(**kw) for c in self.constraints])
+ return Column(
+ self.name,
+ self.type,
+ self.default,
+ key = self.key,
+ primary_key = self.primary_key,
+ nullable = self.nullable,
+ quote=self.quote,
+ index=self.index,
+ autoincrement=self.autoincrement,
+ *[c.copy(**kw) for c in self.constraints])
def _make_proxy(self, selectable, name=None):
"""Create a *proxy* for this column.
This is a copy of this ``Column`` referenced by a different parent
- (such as an alias or select statement).
-
+ (such as an alias or select statement). The column should
+ be used only in select scenarios, as its full DDL/default
+ information is not transferred.
+
"""
fk = [ForeignKey(f.column) for f in self.foreign_keys]
c = Column(
name or self.name,
self.type,
- self.default,
key = name or self.key,
primary_key = self.primary_key,
nullable = self.nullable,
selectable.columns.add(c)
if self.primary_key:
selectable.primary_key.add(c)
- [c._init_items(f) for f in fk]
+ for fn in c._table_events:
+ fn(selectable)
+ del c._table_events
return c
def get_children(self, schema_visitor=False, **kwargs):
__visit_name__ = 'foreign_key'
- def __init__(self, column, constraint=None, use_alter=False, name=None, onupdate=None, ondelete=None, deferrable=None, initially=None, link_to_name=False):
+ def __init__(self, column, constraint=None, use_alter=False, name=None, onupdate=None,
+ ondelete=None, deferrable=None, initially=None, link_to_name=False):
"""
- Construct a column-level FOREIGN KEY.
+ Construct a column-level FOREIGN KEY.
+
+ The :class:`ForeignKey` object when constructed generates a :class:`ForeignKeyConstraint`
+ which is associated with the parent :class:`Table` object's collection of constraints.
:param column: A single target column for the key relationship. A :class:`Column`
object or a column name as a string: ``tablename.columnkey`` or
:param link_to_name: if True, the string name given in ``column`` is the rendered
name of the referenced column, not its locally assigned ``key``.
- :param use_alter: If True, do not emit this key as part of the CREATE TABLE
- definition. Instead, use ALTER TABLE after table creation to add
- the key. Useful for circular dependencies.
-
+ :param use_alter: passed to the underlying :class:`ForeignKeyConstraint` to indicate the
+ constraint should be generated/dropped externally from the CREATE TABLE/
+ DROP TABLE statement. See that classes' constructor for details.
+
"""
self._colspec = column
def _set_parent(self, column):
if hasattr(self, 'parent'):
+ if self.parent is column:
+ return
raise exc.InvalidRequestError("This ForeignKey already has a parent !")
self.parent = column
- if hasattr(self.parent, '_pre_existing_column'):
- # remove existing FK which matches us
- for fk in self.parent._pre_existing_column.foreign_keys:
- if fk.target_fullname == self.target_fullname:
- self.parent.table.foreign_keys.remove(fk)
- self.parent.table.constraints.remove(fk.constraint)
-
- if self.constraint is None and isinstance(self.parent.table, Table):
+ self.parent.foreign_keys.add(self)
+ self.parent._on_table_attach(self._set_table)
+
+ def _set_table(self, table):
+ if self.constraint is None and isinstance(table, Table):
self.constraint = ForeignKeyConstraint(
[], [], use_alter=self.use_alter, name=self.name,
onupdate=self.onupdate, ondelete=self.ondelete,
- deferrable=self.deferrable, initially=self.initially)
- self.parent.table.append_constraint(self.constraint)
- self.constraint._append_fk(self)
-
- self.parent.foreign_keys.add(self)
- self.parent.table.foreign_keys.add(self)
-
+ deferrable=self.deferrable, initially=self.initially,
+ )
+ self.constraint._elements[self.parent] = self
+ self.constraint._set_parent(table)
+ table.foreign_keys.add(self)
+
class DefaultGenerator(SchemaItem):
"""Base class for column *default* values."""
__visit_name__ = 'default_generator'
- def __init__(self, for_update=False, metadata=None):
+ def __init__(self, for_update=False):
self.for_update = for_update
- self.metadata = util.assert_arg_type(metadata, (MetaData, type(None)), 'metadata')
def _set_parent(self, column):
self.column = column
- self.metadata = self.column.table.metadata
if self.for_update:
self.column.onupdate = self
else:
bind = _bind_or_error(self)
return bind._execute_default(self, **kwargs)
+ @property
+ def bind(self):
+ """Return the connectable associated with this default."""
+ if getattr(self, 'column', None):
+ return self.column.table.bind
+ else:
+ return None
+
def __repr__(self):
return "DefaultGenerator()"
return lambda ctx: fn()
positionals = len(argspec[0])
- if inspect.ismethod(inspectable):
+
+ # Py3K compat - no unbound methods
+ if inspect.ismethod(inspectable) or inspect.isclass(fn):
positionals -= 1
if positionals == 0:
__visit_name__ = property(_visit_name)
def __repr__(self):
- return "ColumnDefault(%s)" % repr(self.arg)
+ return "ColumnDefault(%r)" % self.arg
class Sequence(DefaultGenerator):
"""Represents a named database sequence."""
__visit_name__ = 'sequence'
def __init__(self, name, start=None, increment=None, schema=None,
- optional=False, quote=None, **kwargs):
- super(Sequence, self).__init__(**kwargs)
+ optional=False, quote=None, metadata=None, for_update=False):
+ super(Sequence, self).__init__(for_update=for_update)
self.name = name
self.start = start
self.increment = increment
self.optional = optional
self.quote = quote
self.schema = schema
- self.kwargs = kwargs
+ self.metadata = metadata
def __repr__(self):
return "Sequence(%s)" % ', '.join(
def _set_parent(self, column):
super(Sequence, self)._set_parent(column)
column.sequence = self
-
+
+ column._on_table_attach(self._set_table)
+
+ def _set_table(self, table):
+ self.metadata = table.metadata
+
+ @property
+ def bind(self):
+ if self.metadata:
+ return self.metadata.bind
+ else:
+ return None
+
def create(self, bind=None, checkfirst=True):
"""Creates this sequence in the database."""
# alias; deprecated starting 0.5.0
PassiveDefault = DefaultClause
-
class Constraint(SchemaItem):
- """A table-level SQL constraint, such as a KEY.
-
- Implements a hybrid of dict/setlike behavior with regards to the list of
- underying columns.
- """
+ """A table-level SQL constraint."""
__visit_name__ = 'constraint'
- def __init__(self, name=None, deferrable=None, initially=None):
+ def __init__(self, name=None, deferrable=None, initially=None, inline_ddl=True):
"""Create a SQL constraint.
name
initially
Optional string. If set, emit INITIALLY <value> when issuing DDL
for this constraint.
+
+ inline_ddl
+ if True, DDL for this Constraint will be generated within the span of a
+ CREATE TABLE or DROP TABLE statement, when the associated table's
+ DDL is generated. if False, no DDL is issued within that process.
+ Instead, it is expected that an AddConstraint or DropConstraint
+ construct will be used to issue DDL for this Contraint.
+ The AddConstraint/DropConstraint constructs set this flag automatically
+ as well.
"""
self.name = name
- self.columns = expression.ColumnCollection()
self.deferrable = deferrable
self.initially = initially
+ self.inline_ddl = inline_ddl
+
+ @property
+ def table(self):
+ if isinstance(self.parent, Table):
+ return self.parent
+ else:
+ raise exc.InvalidRequestError("This constraint is not bound to a table.")
+
+ def _set_parent(self, parent):
+ self.parent = parent
+ parent.constraints.add(self)
+
+ def copy(self, **kw):
+ raise NotImplementedError()
+
+class ColumnCollectionConstraint(Constraint):
+ """A constraint that proxies a ColumnCollection."""
+
+ def __init__(self, *columns, **kw):
+ """
+ \*columns
+ A sequence of column names or Column objects.
+
+ name
+ Optional, the in-database name of this constraint.
+
+ deferrable
+ Optional bool. If set, emit DEFERRABLE or NOT DEFERRABLE when
+ issuing DDL for this constraint.
+
+ initially
+ Optional string. If set, emit INITIALLY <value> when issuing DDL
+ for this constraint.
+
+ """
+ super(ColumnCollectionConstraint, self).__init__(**kw)
+ self.columns = expression.ColumnCollection()
+ self._pending_colargs = [_to_schema_column_or_string(c) for c in columns]
+ if self._pending_colargs and \
+ isinstance(self._pending_colargs[0], Column) and \
+ self._pending_colargs[0].table is not None:
+ self._set_parent(self._pending_colargs[0].table)
+
+ def _set_parent(self, table):
+ super(ColumnCollectionConstraint, self)._set_parent(table)
+ for col in self._pending_colargs:
+ if isinstance(col, basestring):
+ col = table.c[col]
+ self.columns.add(col)
def __contains__(self, x):
return x in self.columns
+ def copy(self, **kw):
+ return self.__class__(name=self.name, deferrable=self.deferrable,
+ initially=self.initially, *self.columns.keys())
+
def contains_column(self, col):
return self.columns.contains_column(col)
- def keys(self):
- return self.columns.keys()
-
- def __add__(self, other):
- return self.columns + other
-
def __iter__(self):
return iter(self.columns)
def __len__(self):
return len(self.columns)
- def copy(self, **kw):
- raise NotImplementedError()
class CheckConstraint(Constraint):
"""A table- or column-level CHECK constraint.
Can be included in the definition of a Table or Column.
"""
- def __init__(self, sqltext, name=None, deferrable=None, initially=None):
+ def __init__(self, sqltext, name=None, deferrable=None, initially=None, table=None):
"""Construct a CHECK constraint.
sqltext
initially
Optional string. If set, emit INITIALLY <value> when issuing DDL
for this constraint.
+
"""
super(CheckConstraint, self).__init__(name, deferrable, initially)
raise exc.ArgumentError(
"sqltext must be a string and will be used verbatim.")
self.sqltext = sqltext
-
+ if table:
+ self._set_parent(table)
+
def __visit_name__(self):
if isinstance(self.parent, Table):
return "check_constraint"
return "column_check_constraint"
__visit_name__ = property(__visit_name__)
- def _set_parent(self, parent):
- self.parent = parent
- parent.constraints.add(self)
-
def copy(self, **kw):
return CheckConstraint(self.sqltext, name=self.name)
"""
__visit_name__ = 'foreign_key_constraint'
- def __init__(self, columns, refcolumns, name=None, onupdate=None, ondelete=None, use_alter=False, deferrable=None, initially=None, link_to_name=False):
+ def __init__(self, columns, refcolumns, name=None, onupdate=None, ondelete=None,
+ deferrable=None, initially=None, use_alter=False, link_to_name=False, table=None):
"""Construct a composite-capable FOREIGN KEY.
:param columns: A sequence of local column names. The named columns must be defined
:param link_to_name: if True, the string name given in ``column`` is the rendered
name of the referenced column, not its locally assigned ``key``.
- :param use_alter: If True, do not emit this key as part of the CREATE TABLE
- definition. Instead, use ALTER TABLE after table creation to add
- the key. Useful for circular dependencies.
+ :param use_alter: If True, do not emit the DDL for this constraint
+ as part of the CREATE TABLE definition. Instead, generate it via an
+ ALTER TABLE statement issued after the full collection of tables have been
+ created, and drop it via an ALTER TABLE statement before the full collection
+ of tables are dropped. This is shorthand for the usage of
+ :class:`AddConstraint` and :class:`DropConstraint` applied as "after-create"
+ and "before-drop" events on the MetaData object. This is normally used to
+ generate/drop constraints on objects that are mutually dependent on each other.
"""
super(ForeignKeyConstraint, self).__init__(name, deferrable, initially)
- self.__colnames = columns
- self.__refcolnames = refcolumns
- self.elements = util.OrderedSet()
+
self.onupdate = onupdate
self.ondelete = ondelete
self.link_to_name = link_to_name
if self.name is None and use_alter:
- raise exc.ArgumentError("Alterable ForeignKey/ForeignKeyConstraint requires a name")
+ raise exc.ArgumentError("Alterable Constraint requires a name")
self.use_alter = use_alter
+ self._elements = util.OrderedDict()
+ for col, refcol in zip(columns, refcolumns):
+ self._elements[col] = ForeignKey(
+ refcol,
+ constraint=self,
+ name=self.name,
+ onupdate=self.onupdate,
+ ondelete=self.ondelete,
+ use_alter=self.use_alter,
+ link_to_name=self.link_to_name
+ )
+
+ if table:
+ self._set_parent(table)
+
+ @property
+ def columns(self):
+ return self._elements.keys()
+
+ @property
+ def elements(self):
+ return self._elements.values()
+
def _set_parent(self, table):
- self.table = table
- if self not in table.constraints:
- table.constraints.add(self)
- for (c, r) in zip(self.__colnames, self.__refcolnames):
- self.append_element(c, r)
-
- def append_element(self, col, refcol):
- fk = ForeignKey(refcol, constraint=self, name=self.name, onupdate=self.onupdate, ondelete=self.ondelete, use_alter=self.use_alter, link_to_name=self.link_to_name)
- fk._set_parent(self.table.c[col])
- self._append_fk(fk)
-
- def _append_fk(self, fk):
- self.columns.add(self.table.c[fk.parent.key])
- self.elements.add(fk)
-
+ super(ForeignKeyConstraint, self)._set_parent(table)
+ for col, fk in self._elements.iteritems():
+ if isinstance(col, basestring):
+ col = table.c[col]
+ fk._set_parent(col)
+
+ if self.use_alter:
+ def supports_alter(event, schema_item, bind, **kw):
+ return table in set(kw['tables']) and bind.dialect.supports_alter
+ AddConstraint(self, on=supports_alter).execute_at('after-create', table.metadata)
+ DropConstraint(self, on=supports_alter).execute_at('before-drop', table.metadata)
+
def copy(self, **kw):
- return ForeignKeyConstraint([x.parent.name for x in self.elements], [x._get_colspec(**kw) for x in self.elements], name=self.name, onupdate=self.onupdate, ondelete=self.ondelete, use_alter=self.use_alter)
-
-class PrimaryKeyConstraint(Constraint):
+ return ForeignKeyConstraint(
+ [x.parent.name for x in self._elements.values()],
+ [x._get_colspec(**kw) for x in self._elements.values()],
+ name=self.name,
+ onupdate=self.onupdate,
+ ondelete=self.ondelete,
+ use_alter=self.use_alter
+ )
+
+class PrimaryKeyConstraint(ColumnCollectionConstraint):
"""A table-level PRIMARY KEY constraint.
Defines a single column or composite PRIMARY KEY constraint. For a
__visit_name__ = 'primary_key_constraint'
- def __init__(self, *columns, **kwargs):
- """Construct a composite-capable PRIMARY KEY.
-
- \*columns
- A sequence of column names. All columns named must be defined and
- present within the parent Table.
-
- name
- Optional, the in-database name of the key.
-
- deferrable
- Optional bool. If set, emit DEFERRABLE or NOT DEFERRABLE when
- issuing DDL for this constraint.
-
- initially
- Optional string. If set, emit INITIALLY <value> when issuing DDL
- for this constraint.
- """
-
- constraint_args = dict(name=kwargs.pop('name', None),
- deferrable=kwargs.pop('deferrable', None),
- initially=kwargs.pop('initially', None))
- if kwargs:
- raise exc.ArgumentError(
- 'Unknown PrimaryKeyConstraint argument(s): %s' %
- ', '.join(repr(x) for x in kwargs.keys()))
-
- super(PrimaryKeyConstraint, self).__init__(**constraint_args)
- self.__colnames = list(columns)
-
def _set_parent(self, table):
- self.table = table
- table.primary_key = self
- for name in self.__colnames:
- self.add(table.c[name])
-
- def add(self, col):
- self.columns.add(col)
- col.primary_key = True
- append_column = add
+ super(PrimaryKeyConstraint, self)._set_parent(table)
+ table._set_primary_key(self)
- def replace(self, col):
+ def _replace(self, col):
self.columns.replace(col)
- def remove(self, col):
- col.primary_key = False
- del self.columns[col.key]
-
- def copy(self, **kw):
- return PrimaryKeyConstraint(name=self.name, *[c.key for c in self])
-
- __hash__ = Constraint.__hash__
-
- def __eq__(self, other):
- return self.columns == other
-
-class UniqueConstraint(Constraint):
+class UniqueConstraint(ColumnCollectionConstraint):
"""A table-level UNIQUE constraint.
Defines a single column or composite UNIQUE constraint. For a no-frills,
__visit_name__ = 'unique_constraint'
- def __init__(self, *columns, **kwargs):
- """Construct a UNIQUE constraint.
-
- \*columns
- A sequence of column names. All columns named must be defined and
- present within the parent Table.
-
- name
- Optional, the in-database name of the key.
-
- deferrable
- Optional bool. If set, emit DEFERRABLE or NOT DEFERRABLE when
- issuing DDL for this constraint.
-
- initially
- Optional string. If set, emit INITIALLY <value> when issuing DDL
- for this constraint.
- """
-
- constraint_args = dict(name=kwargs.pop('name', None),
- deferrable=kwargs.pop('deferrable', None),
- initially=kwargs.pop('initially', None))
- if kwargs:
- raise exc.ArgumentError(
- 'Unknown UniqueConstraint argument(s): %s' %
- ', '.join(repr(x) for x in kwargs.keys()))
-
- super(UniqueConstraint, self).__init__(**constraint_args)
- self.__colnames = list(columns)
-
- def _set_parent(self, table):
- self.table = table
- table.constraints.add(self)
- for c in self.__colnames:
- self.append_column(table.c[c])
-
- def append_column(self, col):
- self.columns.add(col)
-
- def copy(self, **kw):
- return UniqueConstraint(name=self.name, *self.__colnames)
-
class Index(SchemaItem):
"""A table-level INDEX.
\*columns
Columns to include in the index. All columns must belong to the same
- table, and no column may appear more than once.
+ table.
\**kwargs
Keyword arguments include:
unique
Defaults to False: create a unique index.
- postgres_where
+ postgresql_where
Defaults to None: create a partial index when using PostgreSQL
"""
self.name = name
- self.columns = []
+ self.columns = expression.ColumnCollection()
self.table = None
self.unique = kwargs.pop('unique', False)
self.kwargs = kwargs
- self._init_items(*columns)
-
- def _init_items(self, *args):
- for column in args:
- self.append_column(_to_schema_column(column))
+ for column in columns:
+ column = _to_schema_column(column)
+ if self.table is None:
+ self._set_parent(column.table)
+ elif column.table != self.table:
+ # all columns muse be from same table
+ raise exc.ArgumentError(
+ "All index columns must be from same table. "
+ "%s is from %s not %s" % (column, column.table, self.table))
+ self.columns.add(column)
def _set_parent(self, table):
self.table = table
- self.metadata = table.metadata
table.indexes.add(self)
- def append_column(self, column):
- # make sure all columns are from the same table
- # and no column is repeated
- if self.table is None:
- self._set_parent(column.table)
- elif column.table != self.table:
- # all columns muse be from same table
- raise exc.ArgumentError(
- "All index columns must be from same table. "
- "%s is from %s not %s" % (column, column.table, self.table))
- elif column.name in [ c.name for c in self.columns ]:
- raise exc.ArgumentError(
- "A column may not appear twice in the "
- "same index (%s already has column %s)" % (self.name, column))
- self.columns.append(column)
+ @property
+ def bind(self):
+ """Return the connectable associated with this Index."""
+
+ return self.table.bind
def create(self, bind=None):
if bind is None:
bind = _bind_or_error(self)
bind.drop(self)
- def __str__(self):
- return repr(self)
-
def __repr__(self):
return 'Index("%s", %s%s)' % (self.name,
', '.join(repr(c) for c in self.columns),
return self._bind is not None
- @util.deprecated('Deprecated. Use ``metadata.bind = <engine>`` or '
- '``metadata.bind = <url>``.')
- def connect(self, bind, **kwargs):
- """Bind this MetaData to an Engine.
-
- bind
- A string, ``URL``, ``Engine`` or ``Connection`` instance. If a
- string or ``URL``, will be passed to ``create_engine()`` along with
- ``\**kwargs`` to produce the engine which to connect to. Otherwise
- connects directly to the given ``Engine``.
-
- """
- global URL
- if URL is None:
- from sqlalchemy.engine.url import URL
- if isinstance(bind, (basestring, URL)):
- from sqlalchemy import create_engine
- self._bind = create_engine(bind, **kwargs)
- else:
- self._bind = bind
-
def bind(self):
"""An Engine or Connection to which this MetaData is bound.
# TODO: scan all other tables and remove FK _column
del self.tables[table.key]
- @util.deprecated('Deprecated. Use ``metadata.sorted_tables``')
- def table_iterator(self, reverse=True, tables=None):
- """Deprecated - use metadata.sorted_tables()."""
-
- from sqlalchemy.sql.util import sort_tables
- if tables is None:
- tables = self.tables.values()
- else:
- tables = set(tables).intersection(self.tables.values())
- ret = sort_tables(tables)
- if reverse:
- ret = reversed(ret)
- return iter(ret)
-
@property
def sorted_tables(self):
"""Returns a list of ``Table`` objects sorted in order of
dependency.
"""
from sqlalchemy.sql.util import sort_tables
- return sort_tables(self.tables.values())
+ return sort_tables(self.tables.itervalues())
def reflect(self, bind=None, schema=None, only=None):
"""Load all available table definitions from the database.
available = util.OrderedSet(bind.engine.table_names(schema,
connection=conn))
- current = set(self.tables.keys())
+ current = set(self.tables.iterkeys())
if only is None:
load = [name for name in available if name not in current]
"""
if bind is None:
bind = _bind_or_error(self)
- for listener in self.ddl_listeners['before-create']:
- listener('before-create', self, bind)
bind.create(self, checkfirst=checkfirst, tables=tables)
- for listener in self.ddl_listeners['after-create']:
- listener('after-create', self, bind)
def drop_all(self, bind=None, tables=None, checkfirst=True):
"""Drop all tables stored in this metadata.
"""
if bind is None:
bind = _bind_or_error(self)
- for listener in self.ddl_listeners['before-drop']:
- listener('before-drop', self, bind)
bind.drop(self, checkfirst=checkfirst, tables=tables)
- for listener in self.ddl_listeners['after-drop']:
- listener('after-drop', self, bind)
class ThreadLocalMetaData(MetaData):
"""A MetaData variant that presents a different ``bind`` in every thread.
self.__engines = {}
super(ThreadLocalMetaData, self).__init__()
- @util.deprecated('Deprecated. Use ``metadata.bind = <engine>`` or '
- '``metadata.bind = <url>``.')
- def connect(self, bind, **kwargs):
- """Bind to an Engine in the caller's thread.
-
- bind
- A string, ``URL``, ``Engine`` or ``Connection`` instance. If a
- string or ``URL``, will be passed to ``create_engine()`` along with
- ``\**kwargs`` to produce the engine which to connect to. Otherwise
- connects directly to the given ``Engine``.
- """
-
- global URL
- if URL is None:
- from sqlalchemy.engine.url import URL
-
- if isinstance(bind, (basestring, URL)):
- try:
- engine = self.__engines[bind]
- except KeyError:
- from sqlalchemy import create_engine
- engine = create_engine(bind, **kwargs)
- bind = engine
- self._bind_to(bind)
-
def bind(self):
"""The bound Engine or Connection for this thread.
def dispose(self):
"""Dispose all bound engines, in all thread contexts."""
- for e in self.__engines.values():
+ for e in self.__engines.itervalues():
if hasattr(e, 'dispose'):
e.dispose()
__traverse_options__ = {'schema_visitor':True}
-class DDL(object):
- """A literal DDL statement.
-
- Specifies literal SQL DDL to be executed by the database. DDL objects can
- be attached to ``Tables`` or ``MetaData`` instances, conditionally
- executing SQL as part of the DDL lifecycle of those schema items. Basic
- templating support allows a single DDL instance to handle repetitive tasks
- for multiple tables.
-
- Examples::
-
- tbl = Table('users', metadata, Column('uid', Integer)) # ...
- DDL('DROP TRIGGER users_trigger').execute_at('before-create', tbl)
-
- spow = DDL('ALTER TABLE %(table)s SET secretpowers TRUE', on='somedb')
- spow.execute_at('after-create', tbl)
-
- drop_spow = DDL('ALTER TABLE users SET secretpowers FALSE')
- connection.execute(drop_spow)
- """
-
- def __init__(self, statement, on=None, context=None, bind=None):
- """Create a DDL statement.
-
- statement
- A string or unicode string to be executed. Statements will be
- processed with Python's string formatting operator. See the
- ``context`` argument and the ``execute_at`` method.
-
- A literal '%' in a statement must be escaped as '%%'.
-
- SQL bind parameters are not available in DDL statements.
-
- on
- Optional filtering criteria. May be a string or a callable
- predicate. If a string, it will be compared to the name of the
- executing database dialect::
-
- DDL('something', on='postgres')
-
- If a callable, it will be invoked with three positional arguments:
-
- event
- The name of the event that has triggered this DDL, such as
- 'after-create' Will be None if the DDL is executed explicitly.
-
- schema_item
- A SchemaItem instance, such as ``Table`` or ``MetaData``. May be
- None if the DDL is executed explicitly.
-
- connection
- The ``Connection`` being used for DDL execution
-
- If the callable returns a true value, the DDL statement will be
- executed.
-
- context
- Optional dictionary, defaults to None. These values will be
- available for use in string substitutions on the DDL statement.
-
- bind
- Optional. A :class:`~sqlalchemy.engine.base.Connectable`, used by default when ``execute()``
- is invoked without a bind argument.
-
- """
-
- if not isinstance(statement, basestring):
- raise exc.ArgumentError(
- "Expected a string or unicode SQL statement, got '%r'" %
- statement)
- if (on is not None and
- (not isinstance(on, basestring) and not util.callable(on))):
- raise exc.ArgumentError(
- "Expected the name of a database dialect or a callable for "
- "'on' criteria, got type '%s'." % type(on).__name__)
-
- self.statement = statement
- self.on = on
- self.context = context or {}
- self._bind = bind
+class DDLElement(expression.ClauseElement):
+ """Base class for DDL expression constructs."""
+
+ supports_execution = True
+ _autocommit = True
+ schema_item = None
+ on = None
+
def execute(self, bind=None, schema_item=None):
"""Execute this DDL immediately.
if bind is None:
bind = _bind_or_error(self)
- # no SQL bind params are supported
+
if self._should_execute(None, schema_item, bind):
- executable = expression.text(self._expand(schema_item, bind))
- return bind.execute(executable)
+ return bind.execute(self.against(schema_item))
else:
bind.engine.logger.info("DDL execution skipped, criteria not met.")
will be executed using the same Connection and transactional context
as the Table create/drop itself. The ``.bind`` property of this
statement is ignored.
-
+
event
One of the events defined in the schema item's ``.ddl_events``;
e.g. 'before-create', 'after-create', 'before-drop' or 'after-drop'
schema_item.ddl_listeners[event].append(self)
return self
- def bind(self):
- """An Engine or Connection to which this DDL is bound.
-
- This property may be assigned an ``Engine`` or ``Connection``, or
- assigned a string or URL to automatically create a basic ``Engine``
- for this bind with ``create_engine()``.
- """
- return self._bind
-
- def _bind_to(self, bind):
- """Bind this MetaData to an Engine, Connection, string or URL."""
-
- global URL
- if URL is None:
- from sqlalchemy.engine.url import URL
+ @expression._generative
+ def against(self, schema_item):
+ """Return a copy of this DDL against a specific schema item."""
- if isinstance(bind, (basestring, URL)):
- from sqlalchemy import create_engine
- self._bind = create_engine(bind)
- else:
- self._bind = bind
- bind = property(bind, _bind_to)
+ self.schema_item = schema_item
- def __call__(self, event, schema_item, bind):
+ def __call__(self, event, schema_item, bind, **kw):
"""Execute the DDL as a ddl_listener."""
- if self._should_execute(event, schema_item, bind):
- statement = expression.text(self._expand(schema_item, bind))
- return bind.execute(statement)
+ if self._should_execute(event, schema_item, bind, **kw):
+ return bind.execute(self.against(schema_item))
- def _expand(self, schema_item, bind):
- return self.statement % self._prepare_context(schema_item, bind)
+ def _check_ddl_on(self, on):
+ if (on is not None and
+ (not isinstance(on, (basestring, tuple, list, set)) and not util.callable(on))):
+ raise exc.ArgumentError(
+ "Expected the name of a database dialect, a tuple of names, or a callable for "
+ "'on' criteria, got type '%s'." % type(on).__name__)
- def _should_execute(self, event, schema_item, bind):
+ def _should_execute(self, event, schema_item, bind, **kw):
if self.on is None:
return True
elif isinstance(self.on, basestring):
return self.on == bind.engine.name
+ elif isinstance(self.on, (tuple, list, set)):
+ return bind.engine.name in self.on
else:
- return self.on(event, schema_item, bind)
+ return self.on(event, schema_item, bind, **kw)
- def _prepare_context(self, schema_item, bind):
- # table events can substitute table and schema name
- if isinstance(schema_item, Table):
- context = self.context.copy()
+ def bind(self):
+ if self._bind:
+ return self._bind
+ def _set_bind(self, bind):
+ self._bind = bind
+ bind = property(bind, _set_bind)
- preparer = bind.dialect.identifier_preparer
- path = preparer.format_table_seq(schema_item)
- if len(path) == 1:
- table, schema = path[0], ''
- else:
- table, schema = path[-1], path[0]
+ def _generate(self):
+ s = self.__class__.__new__(self.__class__)
+ s.__dict__ = self.__dict__.copy()
+ return s
+
+ def _compiler(self, dialect, **kw):
+ """Return a compiler appropriate for this ClauseElement, given a Dialect."""
+
+ return dialect.ddl_compiler(dialect, self, **kw)
+
+class DDL(DDLElement):
+ """A literal DDL statement.
+
+ Specifies literal SQL DDL to be executed by the database. DDL objects can
+ be attached to ``Tables`` or ``MetaData`` instances, conditionally
+ executing SQL as part of the DDL lifecycle of those schema items. Basic
+ templating support allows a single DDL instance to handle repetitive tasks
+ for multiple tables.
+
+ Examples::
+
+ tbl = Table('users', metadata, Column('uid', Integer)) # ...
+ DDL('DROP TRIGGER users_trigger').execute_at('before-create', tbl)
+
+ spow = DDL('ALTER TABLE %(table)s SET secretpowers TRUE', on='somedb')
+ spow.execute_at('after-create', tbl)
+
+ drop_spow = DDL('ALTER TABLE users SET secretpowers FALSE')
+ connection.execute(drop_spow)
+ """
+
+ __visit_name__ = "ddl"
+
+ def __init__(self, statement, on=None, context=None, bind=None):
+ """Create a DDL statement.
+
+ statement
+ A string or unicode string to be executed. Statements will be
+ processed with Python's string formatting operator. See the
+ ``context`` argument and the ``execute_at`` method.
+
+ A literal '%' in a statement must be escaped as '%%'.
+
+ SQL bind parameters are not available in DDL statements.
+
+ on
+ Optional filtering criteria. May be a string, tuple or a callable
+ predicate. If a string, it will be compared to the name of the
+ executing database dialect::
+
+ DDL('something', on='postgresql')
+
+ If a tuple, specifies multiple dialect names:
+
+ DDL('something', on=('postgresql', 'mysql'))
+
+ If a callable, it will be invoked with three positional arguments
+ as well as optional keyword arguments:
+
+ event
+ The name of the event that has triggered this DDL, such as
+ 'after-create' Will be None if the DDL is executed explicitly.
+
+ schema_item
+ A SchemaItem instance, such as ``Table`` or ``MetaData``. May be
+ None if the DDL is executed explicitly.
+
+ connection
+ The ``Connection`` being used for DDL execution
+
+ **kw
+ Keyword arguments which may be sent include:
+ tables - a list of Table objects which are to be created/
+ dropped within a MetaData.create_all() or drop_all() method
+ call.
+
+ If the callable returns a true value, the DDL statement will be
+ executed.
+
+ context
+ Optional dictionary, defaults to None. These values will be
+ available for use in string substitutions on the DDL statement.
+
+ bind
+ Optional. A :class:`~sqlalchemy.engine.base.Connectable`, used by default when ``execute()``
+ is invoked without a bind argument.
+
+ """
+
+ if not isinstance(statement, basestring):
+ raise exc.ArgumentError(
+ "Expected a string or unicode SQL statement, got '%r'" %
+ statement)
+
+ self.statement = statement
+ self.context = context or {}
+
+ self._check_ddl_on(on)
+ self.on = on
+ self._bind = bind
- context.setdefault('table', table)
- context.setdefault('schema', schema)
- context.setdefault('fullname', preparer.format_table(schema_item))
- return context
- else:
- return self.context
def __repr__(self):
return '<%s@%s; %s>' % (
if getattr(self, key)]))
def _to_schema_column(element):
- if hasattr(element, '__clause_element__'):
- element = element.__clause_element__()
- if not isinstance(element, Column):
- raise exc.ArgumentError("schema.Column object expected")
- return element
+ if hasattr(element, '__clause_element__'):
+ element = element.__clause_element__()
+ if not isinstance(element, Column):
+ raise exc.ArgumentError("schema.Column object expected")
+ return element
+
+def _to_schema_column_or_string(element):
+ if hasattr(element, '__clause_element__'):
+ element = element.__clause_element__()
+ return element
+
+class _CreateDropBase(DDLElement):
+ """Base class for DDL constucts that represent CREATE and DROP or equivalents.
+
+ The common theme of _CreateDropBase is a single
+ ``element`` attribute which refers to the element
+ to be created or dropped.
+ """
+
+ def __init__(self, element, on=None, bind=None):
+ self.element = element
+ self._check_ddl_on(on)
+ self.on = on
+ self.bind = bind
+ element.inline_ddl = False
+
+class CreateTable(_CreateDropBase):
+ """Represent a CREATE TABLE statement."""
+
+ __visit_name__ = "create_table"
+
+class DropTable(_CreateDropBase):
+ """Represent a DROP TABLE statement."""
+
+ __visit_name__ = "drop_table"
+
+ def __init__(self, element, cascade=False, **kw):
+ self.cascade = cascade
+ super(DropTable, self).__init__(element, **kw)
+
+class CreateSequence(_CreateDropBase):
+ """Represent a CREATE SEQUENCE statement."""
+
+ __visit_name__ = "create_sequence"
+
+class DropSequence(_CreateDropBase):
+ """Represent a DROP SEQUENCE statement."""
+
+ __visit_name__ = "drop_sequence"
+
+class CreateIndex(_CreateDropBase):
+ """Represent a CREATE INDEX statement."""
+
+ __visit_name__ = "create_index"
+
+class DropIndex(_CreateDropBase):
+ """Represent a DROP INDEX statement."""
+
+ __visit_name__ = "drop_index"
+
+class AddConstraint(_CreateDropBase):
+ """Represent an ALTER TABLE ADD CONSTRAINT statement."""
+
+ __visit_name__ = "add_constraint"
+
+class DropConstraint(_CreateDropBase):
+ """Represent an ALTER TABLE DROP CONSTRAINT statement."""
+
+ __visit_name__ = "drop_constraint"
+
+ def __init__(self, element, cascade=False, **kw):
+ self.cascade = cascade
+ super(DropConstraint, self).__init__(element, **kw)
+
def _bind_or_error(schemaitem):
bind = schemaitem.bind
if not bind:
"""Base SQL and DDL compiler implementations.
-Provides the :class:`~sqlalchemy.sql.compiler.DefaultCompiler` class, which is
-responsible for generating all SQL query strings, as well as
-:class:`~sqlalchemy.sql.compiler.SchemaGenerator` and :class:`~sqlalchemy.sql.compiler.SchemaDropper`
-which issue CREATE and DROP DDL for tables, sequences, and indexes.
-
-The elements in this module are used by public-facing constructs like
-:class:`~sqlalchemy.sql.expression.ClauseElement` and :class:`~sqlalchemy.engine.Engine`.
-While dialect authors will want to be familiar with this module for the purpose of
-creating database-specific compilers and schema generators, the module
-is otherwise internal to SQLAlchemy.
+Classes provided include:
+
+:class:`~sqlalchemy.sql.compiler.SQLCompiler` - renders SQL
+strings
+
+:class:`~sqlalchemy.sql.compiler.DDLCompiler` - renders DDL
+(data definition language) strings
+
+:class:`~sqlalchemy.sql.compiler.GenericTypeCompiler` - renders
+type specification strings.
+
+To generate user-defined SQL strings, see
+:module:`~sqlalchemy.ext.compiler`.
+
"""
-import string, re
+import re
from sqlalchemy import schema, engine, util, exc
from sqlalchemy.sql import operators, functions, util as sql_util, visitors
from sqlalchemy.sql import expression as sql
OPERATORS = {
- operators.and_ : 'AND',
- operators.or_ : 'OR',
- operators.inv : 'NOT',
- operators.add : '+',
- operators.mul : '*',
- operators.sub : '-',
- operators.div : '/',
- operators.mod : '%',
- operators.truediv : '/',
- operators.lt : '<',
- operators.le : '<=',
- operators.ne : '!=',
- operators.gt : '>',
- operators.ge : '>=',
- operators.eq : '=',
- operators.distinct_op : 'DISTINCT',
- operators.concat_op : '||',
- operators.like_op : lambda x, y, escape=None: '%s LIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
- operators.notlike_op : lambda x, y, escape=None: '%s NOT LIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
- operators.ilike_op : lambda x, y, escape=None: "lower(%s) LIKE lower(%s)" % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
- operators.notilike_op : lambda x, y, escape=None: "lower(%s) NOT LIKE lower(%s)" % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
- operators.between_op : 'BETWEEN',
- operators.match_op : 'MATCH',
- operators.in_op : 'IN',
- operators.notin_op : 'NOT IN',
+ # binary
+ operators.and_ : ' AND ',
+ operators.or_ : ' OR ',
+ operators.add : ' + ',
+ operators.mul : ' * ',
+ operators.sub : ' - ',
+ # Py2K
+ operators.div : ' / ',
+ # end Py2K
+ operators.mod : ' % ',
+ operators.truediv : ' / ',
+ operators.lt : ' < ',
+ operators.le : ' <= ',
+ operators.ne : ' != ',
+ operators.gt : ' > ',
+ operators.ge : ' >= ',
+ operators.eq : ' = ',
+ operators.concat_op : ' || ',
+ operators.between_op : ' BETWEEN ',
+ operators.match_op : ' MATCH ',
+ operators.in_op : ' IN ',
+ operators.notin_op : ' NOT IN ',
operators.comma_op : ', ',
- operators.desc_op : 'DESC',
- operators.asc_op : 'ASC',
- operators.from_ : 'FROM',
- operators.as_ : 'AS',
- operators.exists : 'EXISTS',
- operators.is_ : 'IS',
- operators.isnot : 'IS NOT',
- operators.collate : 'COLLATE',
+ operators.from_ : ' FROM ',
+ operators.as_ : ' AS ',
+ operators.is_ : ' IS ',
+ operators.isnot : ' IS NOT ',
+ operators.collate : ' COLLATE ',
+
+ # unary
+ operators.exists : 'EXISTS ',
+ operators.distinct_op : 'DISTINCT ',
+ operators.inv : 'NOT ',
+
+ # modifiers
+ operators.desc_op : ' DESC',
+ operators.asc_op : ' ASC',
}
FUNCTIONS = {
def quote(self):
return self.element.quote
-class DefaultCompiler(engine.Compiled):
+class SQLCompiler(engine.Compiled):
"""Default implementation of Compiled.
Compiles ClauseElements into SQL strings. Uses a similar visit
"""
- operators = OPERATORS
- functions = FUNCTIONS
extract_map = EXTRACT_MAP
- # if we are insert/update/delete.
- # set to true when we visit an INSERT, UPDATE or DELETE
+ # class-level defaults which can be set at the instance
+ # level to define if this Compiled instance represents
+ # INSERT/UPDATE/DELETE
isdelete = isinsert = isupdate = False
-
+ returning = None
+
def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs):
"""Construct a new ``DefaultCompiler`` object.
statement.
"""
- engine.Compiled.__init__(self, dialect, statement, column_keys, **kwargs)
+ engine.Compiled.__init__(self, dialect, statement, **kwargs)
+
+ self.column_keys = column_keys
# compile INSERT/UPDATE defaults/sequences inlined (no pre-execute)
self.inline = inline or getattr(statement, 'inline', False)
# or dialect.max_identifier_length
self.truncated_names = {}
- def compile(self):
- self.string = self.process(self.statement)
-
- def process(self, obj, **kwargs):
- return obj._compiler_dispatch(self, **kwargs)
-
def is_subquery(self):
return len(self.stack) > 1
"""return a dictionary of bind parameter keys and values"""
if params:
- params = util.column_dict(params)
pd = {}
for bindparam, name in self.bind_names.iteritems():
for paramname in (bindparam.key, bindparam.shortname, name):
pd[self.bind_names[bindparam]] = bindparam.value
return pd
- params = property(construct_params)
+ params = property(construct_params, doc="""
+ Return the bind params for this compiled object.
+
+ """)
def default_from(self):
"""Called when a SELECT statement has no froms, and no FROM clause is to be appended.
self._truncated_identifier("colident", label.name) or label.name
if result_map is not None:
- result_map[labelname.lower()] = (label.name, (label, label.element, labelname), label.element.type)
+ result_map[labelname.lower()] = \
+ (label.name, (label, label.element, labelname), label.element.type)
- return self.process(label.element) + " " + \
- self.operator_string(operators.as_) + " " + \
+ return self.process(label.element) + \
+ OPERATORS[operators.as_] + \
self.preparer.format_label(label, labelname)
else:
return self.process(label.element)
return name
else:
if column.table.schema:
- schema_prefix = self.preparer.quote_schema(column.table.schema, column.table.quote_schema) + '.'
+ schema_prefix = self.preparer.quote_schema(
+ column.table.schema,
+ column.table.quote_schema) + '.'
else:
schema_prefix = ''
tablename = column.table.name
tablename = isinstance(tablename, sql._generated_label) and \
self._truncated_identifier("alias", tablename) or tablename
- return schema_prefix + self.preparer.quote(tablename, column.table.quote) + "." + name
+ return schema_prefix + \
+ self.preparer.quote(tablename, column.table.quote) + "." + name
def escape_literal_column(self, text):
"""provide escaping for the literal_column() construct."""
return index.name
def visit_typeclause(self, typeclause, **kwargs):
- return typeclause.type.dialect_impl(self.dialect).get_col_spec()
+ return self.dialect.type_compiler.process(typeclause.type)
def post_process_text(self, text):
return text
sep = clauselist.operator
if sep is None:
sep = " "
- elif sep is operators.comma_op:
- sep = ', '
else:
- sep = " " + self.operator_string(clauselist.operator) + " "
+ sep = OPERATORS[clauselist.operator]
return sep.join(s for s in (self.process(c) for c in clauselist.clauses)
if s is not None)
return x
def visit_cast(self, cast, **kwargs):
- return "CAST(%s AS %s)" % (self.process(cast.clause), self.process(cast.typeclause))
+ return "CAST(%s AS %s)" % \
+ (self.process(cast.clause), self.process(cast.typeclause))
def visit_extract(self, extract, **kwargs):
field = self.extract_map.get(extract.field, extract.field)
if result_map is not None:
result_map[func.name.lower()] = (func.name, None, func.type)
- name = self.function_string(func)
-
- if util.callable(name):
- return name(*[self.process(x) for x in func.clauses])
+ disp = getattr(self, "visit_%s_func" % func.name.lower(), None)
+ if disp:
+ return disp(func, **kwargs)
else:
- return ".".join(func.packagenames + [name]) % {'expr':self.function_argspec(func)}
+ name = FUNCTIONS.get(func.__class__, func.name + "%(expr)s")
+ return ".".join(func.packagenames + [name]) % \
+ {'expr':self.function_argspec(func, **kwargs)}
def function_argspec(self, func, **kwargs):
return self.process(func.clause_expr, **kwargs)
- def function_string(self, func):
- return self.functions.get(func.__class__, self.functions.get(func.name, func.name + "%(expr)s"))
-
def visit_compound_select(self, cs, asfrom=False, parens=True, **kwargs):
entry = self.stack and self.stack[-1] or {}
self.stack.append({'from':entry.get('from', None), 'iswrapper':True})
- text = string.join((self.process(c, asfrom=asfrom, parens=False, compound_index=i)
- for i, c in enumerate(cs.selects)),
- " " + cs.keyword + " ")
+ text = (" " + cs.keyword + " ").join(
+ (self.process(c, asfrom=asfrom, parens=False, compound_index=i)
+ for i, c in enumerate(cs.selects))
+ )
+
group_by = self.process(cs._group_by_clause, asfrom=asfrom)
if group_by:
text += " GROUP BY " + group_by
def visit_unary(self, unary, **kw):
s = self.process(unary.element, **kw)
if unary.operator:
- s = self.operator_string(unary.operator) + " " + s
+ s = OPERATORS[unary.operator] + s
if unary.modifier:
- s = s + " " + self.operator_string(unary.modifier)
+ s = s + OPERATORS[unary.modifier]
return s
def visit_binary(self, binary, **kwargs):
- op = self.operator_string(binary.operator)
- if util.callable(op):
- return op(self.process(binary.left), self.process(binary.right), **binary.modifiers)
- else:
- return self.process(binary.left) + " " + op + " " + self.process(binary.right)
+
+ return self._operator_dispatch(binary.operator,
+ binary,
+ lambda opstr: self.process(binary.left) + opstr + self.process(binary.right),
+ **kwargs
+ )
- def operator_string(self, operator):
- return self.operators.get(operator, str(operator))
+ def visit_like_op(self, binary, **kw):
+ escape = binary.modifiers.get("escape", None)
+ return '%s LIKE %s' % (self.process(binary.left), self.process(binary.right)) \
+ + (escape and ' ESCAPE \'%s\'' % escape or '')
+ def visit_notlike_op(self, binary, **kw):
+ escape = binary.modifiers.get("escape", None)
+ return '%s NOT LIKE %s' % (self.process(binary.left), self.process(binary.right)) \
+ + (escape and ' ESCAPE \'%s\'' % escape or '')
+
+ def visit_ilike_op(self, binary, **kw):
+ escape = binary.modifiers.get("escape", None)
+ return 'lower(%s) LIKE lower(%s)' % (self.process(binary.left), self.process(binary.right)) \
+ + (escape and ' ESCAPE \'%s\'' % escape or '')
+
+ def visit_notilike_op(self, binary, **kw):
+ escape = binary.modifiers.get("escape", None)
+ return 'lower(%s) NOT LIKE lower(%s)' % (self.process(binary.left), self.process(binary.right)) \
+ + (escape and ' ESCAPE \'%s\'' % escape or '')
+
+ def _operator_dispatch(self, operator, element, fn, **kw):
+ if util.callable(operator):
+ disp = getattr(self, "visit_%s" % operator.__name__, None)
+ if disp:
+ return disp(element, **kw)
+ else:
+ return fn(OPERATORS[operator])
+ else:
+ return fn(" " + operator + " ")
+
def visit_bindparam(self, bindparam, **kwargs):
name = self._truncate_bindparam(bindparam)
if name in self.binds:
existing = self.binds[name]
if existing is not bindparam and (existing.unique or bindparam.unique):
- raise exc.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key)
+ raise exc.CompileError(
+ "Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key
+ )
self.binds[bindparam.key] = self.binds[name] = bindparam
return self.bindparam_string(name)
if isinstance(column, sql._Label):
return column
- if select.use_labels and column._label:
+ if select and select.use_labels and column._label:
return _CompileLabel(column, column._label)
if \
column.table is not None and \
not isinstance(column.table, sql.Select):
return _CompileLabel(column, sql._generated_label(column.name))
- elif not isinstance(column, (sql._UnaryExpression, sql._TextClause, sql._BindParamClause)) \
+ elif not isinstance(column,
+ (sql._UnaryExpression, sql._TextClause, sql._BindParamClause)) \
and (not hasattr(column, 'name') or isinstance(column, sql.Function)):
return _CompileLabel(column, column.anon_label)
else:
return column
- def visit_select(self, select, asfrom=False, parens=True, iswrapper=False, compound_index=1, **kwargs):
+ def visit_select(self, select, asfrom=False, parens=True,
+ iswrapper=False, compound_index=1, **kwargs):
entry = self.stack and self.stack[-1] or {}
return text
def get_select_precolumns(self, select):
- """Called when building a ``SELECT`` statement, position is just before column list."""
-
+ """Called when building a ``SELECT`` statement, position is just before
+ column list.
+
+ """
return select._distinct and "DISTINCT " or ""
def order_by_clause(self, select):
def visit_table(self, table, asfrom=False, **kwargs):
if asfrom:
if getattr(table, "schema", None):
- return self.preparer.quote_schema(table.schema, table.quote_schema) + "." + self.preparer.quote(table.name, table.quote)
+ return self.preparer.quote_schema(table.schema, table.quote_schema) + \
+ "." + self.preparer.quote(table.name, table.quote)
else:
return self.preparer.quote(table.name, table.quote)
else:
return ""
def visit_join(self, join, asfrom=False, **kwargs):
- return (self.process(join.left, asfrom=True) + (join.isouter and " LEFT OUTER JOIN " or " JOIN ") + \
+ return (self.process(join.left, asfrom=True) + \
+ (join.isouter and " LEFT OUTER JOIN " or " JOIN ") + \
self.process(join.right, asfrom=True) + " ON " + self.process(join.onclause))
def visit_sequence(self, seq):
def visit_insert(self, insert_stmt):
self.isinsert = True
colparams = self._get_colparams(insert_stmt)
- preparer = self.preparer
-
- insert = ' '.join(["INSERT"] +
- [self.process(x) for x in insert_stmt._prefixes])
- if not colparams and not self.dialect.supports_default_values and not self.dialect.supports_empty_insert:
+ if not colparams and \
+ not self.dialect.supports_default_values and \
+ not self.dialect.supports_empty_insert:
raise exc.CompileError(
"The version of %s you are using does not support empty inserts." % self.dialect.name)
- elif not colparams and self.dialect.supports_default_values:
- return (insert + " INTO %s DEFAULT VALUES" % (
- (preparer.format_table(insert_stmt.table),)))
- else:
- return (insert + " INTO %s (%s) VALUES (%s)" %
- (preparer.format_table(insert_stmt.table),
- ', '.join([preparer.format_column(c[0])
- for c in colparams]),
- ', '.join([c[1] for c in colparams])))
+ preparer = self.preparer
+ supports_default_values = self.dialect.supports_default_values
+
+ text = "INSERT"
+
+ prefixes = [self.process(x) for x in insert_stmt._prefixes]
+ if prefixes:
+ text += " " + " ".join(prefixes)
+
+ text += " INTO " + preparer.format_table(insert_stmt.table)
+
+ if colparams or not supports_default_values:
+ text += " (%s)" % ', '.join([preparer.format_column(c[0])
+ for c in colparams])
+
+ if self.returning or insert_stmt._returning:
+ self.returning = self.returning or insert_stmt._returning
+ returning_clause = self.returning_clause(insert_stmt, self.returning)
+
+ # cheating
+ if returning_clause.startswith("OUTPUT"):
+ text += " " + returning_clause
+ returning_clause = None
+
+ if not colparams and supports_default_values:
+ text += " DEFAULT VALUES"
+ else:
+ text += " VALUES (%s)" % \
+ ', '.join([c[1] for c in colparams])
+
+ if self.returning and returning_clause:
+ text += " " + returning_clause
+
+ return text
+
def visit_update(self, update_stmt):
self.stack.append({'from': set([update_stmt.table])})
self.isupdate = True
colparams = self._get_colparams(update_stmt)
- text = ' '.join((
- "UPDATE",
- self.preparer.format_table(update_stmt.table),
- 'SET',
- ', '.join(self.preparer.quote(c[0].name, c[0].quote) + '=' + c[1]
- for c in colparams)
- ))
-
+ text = "UPDATE " + self.preparer.format_table(update_stmt.table)
+
+ text += ' SET ' + \
+ ', '.join(
+ self.preparer.quote(c[0].name, c[0].quote) + '=' + c[1]
+ for c in colparams
+ )
+
+ if update_stmt._returning:
+ self.returning = update_stmt._returning
+ returning_clause = self.returning_clause(update_stmt, update_stmt._returning)
+ if returning_clause.startswith("OUTPUT"):
+ text += " " + returning_clause
+ returning_clause = None
+
if update_stmt._whereclause:
text += " WHERE " + self.process(update_stmt._whereclause)
+ if self.returning and returning_clause:
+ text += " " + returning_clause
+
self.stack.pop(-1)
return text
self.postfetch = []
self.prefetch = []
-
+ self.returning = []
+
# no parameters in the statement, no parameters in the
# compiled params - return binds for all columns
if self.column_keys is None and stmt.parameters is None:
# create a list of column assignment clauses as tuples
values = []
+
+ need_pks = self.isinsert and \
+ not self.inline and \
+ not self.statement._returning
+
+ implicit_returning = need_pks and \
+ self.dialect.implicit_returning and \
+ stmt.table.implicit_returning
+
for c in stmt.table.columns:
if c.key in parameters:
value = parameters[c.key]
self.postfetch.append(c)
value = self.process(value.self_group())
values.append((c, value))
+
elif isinstance(c, schema.Column):
if self.isinsert:
- if (c.primary_key and self.dialect.preexecute_pk_sequences and not self.inline):
- if (((isinstance(c.default, schema.Sequence) and
- not c.default.optional) or
- not self.dialect.supports_pk_autoincrement) or
- (c.default is not None and
- not isinstance(c.default, schema.Sequence))):
- values.append((c, create_bind_param(c, None)))
- self.prefetch.append(c)
+ if c.primary_key and \
+ need_pks and \
+ (
+ c is not stmt.table._autoincrement_column or
+ not self.dialect.postfetch_lastrowid
+ ):
+
+ if implicit_returning:
+ if isinstance(c.default, schema.Sequence):
+ proc = self.process(c.default)
+ if proc is not None:
+ values.append((c, proc))
+ self.returning.append(c)
+ elif isinstance(c.default, schema.ColumnDefault) and \
+ isinstance(c.default.arg, sql.ClauseElement):
+ values.append((c, self.process(c.default.arg.self_group())))
+ self.returning.append(c)
+ elif c.default is not None:
+ values.append((c, create_bind_param(c, None)))
+ self.prefetch.append(c)
+ else:
+ self.returning.append(c)
+ else:
+ if (
+ c.default is not None and \
+ (
+ self.dialect.supports_sequences or
+ not isinstance(c.default, schema.Sequence)
+ )
+ ) or \
+ self.dialect.preexecute_autoincrement_sequences:
+
+ values.append((c, create_bind_param(c, None)))
+ self.prefetch.append(c)
+
elif isinstance(c.default, schema.ColumnDefault):
if isinstance(c.default.arg, sql.ClauseElement):
values.append((c, self.process(c.default.arg.self_group())))
+
if not c.primary_key:
# dont add primary key column to postfetch
self.postfetch.append(c)
text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table)
+ if delete_stmt._returning:
+ self.returning = delete_stmt._returning
+ returning_clause = self.returning_clause(delete_stmt, delete_stmt._returning)
+ if returning_clause.startswith("OUTPUT"):
+ text += " " + returning_clause
+ returning_clause = None
+
if delete_stmt._whereclause:
text += " WHERE " + self.process(delete_stmt._whereclause)
+ if self.returning and returning_clause:
+ text += " " + returning_clause
+
self.stack.pop(-1)
return text
def visit_release_savepoint(self, savepoint_stmt):
return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
- def __str__(self):
- return self.string or ''
-
-class DDLBase(engine.SchemaIterator):
- def find_alterables(self, tables):
- alterables = []
- class FindAlterables(schema.SchemaVisitor):
- def visit_foreign_key_constraint(self, constraint):
- if constraint.use_alter and constraint.table in tables:
- alterables.append(constraint)
- findalterables = FindAlterables()
- for table in tables:
- for c in table.constraints:
- findalterables.traverse(c)
- return alterables
-
- def _validate_identifier(self, ident, truncate):
- if truncate:
- if len(ident) > self.dialect.max_identifier_length:
- counter = getattr(self, 'counter', 0)
- self.counter = counter + 1
- return ident[0:self.dialect.max_identifier_length - 6] + "_" + hex(self.counter)[2:]
- else:
- return ident
- else:
- self.dialect.validate_identifier(ident)
- return ident
-
-
-class SchemaGenerator(DDLBase):
- def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
- super(SchemaGenerator, self).__init__(connection, **kwargs)
- self.checkfirst = checkfirst
- self.tables = tables and set(tables) or None
- self.preparer = dialect.identifier_preparer
- self.dialect = dialect
- def get_column_specification(self, column, first_pk=False):
- raise NotImplementedError()
+class DDLCompiler(engine.Compiled):
+ @property
+ def preparer(self):
+ return self.dialect.identifier_preparer
- def _can_create(self, table):
- self.dialect.validate_identifier(table.name)
- if table.schema:
- self.dialect.validate_identifier(table.schema)
- return not self.checkfirst or not self.dialect.has_table(self.connection, table.name, schema=table.schema)
+ def construct_params(self, params=None):
+ return None
+
+ def visit_ddl(self, ddl, **kwargs):
+ # table events can substitute table and schema name
+ context = ddl.context
+ if isinstance(ddl.schema_item, schema.Table):
+ context = context.copy()
+
+ preparer = self.dialect.identifier_preparer
+ path = preparer.format_table_seq(ddl.schema_item)
+ if len(path) == 1:
+ table, sch = path[0], ''
+ else:
+ table, sch = path[-1], path[0]
- def visit_metadata(self, metadata):
- if self.tables:
- tables = self.tables
- else:
- tables = metadata.tables.values()
- collection = [t for t in sql_util.sort_tables(tables) if self._can_create(t)]
- for table in collection:
- self.traverse_single(table)
- if self.dialect.supports_alter:
- for alterable in self.find_alterables(collection):
- self.add_foreignkey(alterable)
-
- def visit_table(self, table):
- for listener in table.ddl_listeners['before-create']:
- listener('before-create', table, self.connection)
+ context.setdefault('table', table)
+ context.setdefault('schema', sch)
+ context.setdefault('fullname', preparer.format_table(ddl.schema_item))
+
+ return ddl.statement % context
- for column in table.columns:
- if column.default is not None:
- self.traverse_single(column.default)
+ def visit_create_table(self, create):
+ table = create.element
+ preparer = self.dialect.identifier_preparer
- self.append("\n" + " ".join(['CREATE'] +
- table._prefixes +
+ text = "\n" + " ".join(['CREATE'] + \
+ table._prefixes + \
['TABLE',
- self.preparer.format_table(table),
- "("]))
+ preparer.format_table(table),
+ "("])
separator = "\n"
# if only one primary key, specify it along with the column
first_pk = False
for column in table.columns:
- self.append(separator)
+ text += separator
separator = ", \n"
- self.append("\t" + self.get_column_specification(column, first_pk=column.primary_key and not first_pk))
+ text += "\t" + self.get_column_specification(column, first_pk=column.primary_key and not first_pk)
if column.primary_key:
first_pk = True
- for constraint in column.constraints:
- self.traverse_single(constraint)
+ const = " ".join(self.process(constraint) for constraint in column.constraints)
+ if const:
+ text += " " + const
# On some DB order is significant: visit PK first, then the
# other constraints (engine.ReflectionTest.testbasic failed on FB2)
if table.primary_key:
- self.traverse_single(table.primary_key)
- for constraint in [c for c in table.constraints if c is not table.primary_key]:
- self.traverse_single(constraint)
+ text += ", \n\t" + self.process(table.primary_key)
+
+ const = ", \n\t".join(
+ self.process(constraint) for constraint in table.constraints
+ if constraint is not table.primary_key
+ and constraint.inline_ddl
+ and (not self.dialect.supports_alter or not getattr(constraint, 'use_alter', False))
+ )
+ if const:
+ text += ", \n\t" + const
+
+ text += "\n)%s\n\n" % self.post_create_table(table)
+ return text
+
+ def visit_drop_table(self, drop):
+ ret = "\nDROP TABLE " + self.preparer.format_table(drop.element)
+ if drop.cascade:
+ ret += " CASCADE CONSTRAINTS"
+ return ret
+
+ def visit_create_index(self, create):
+ index = create.element
+ preparer = self.preparer
+ text = "CREATE "
+ if index.unique:
+ text += "UNIQUE "
+ text += "INDEX %s ON %s (%s)" \
+ % (preparer.quote(self._validate_identifier(index.name, True), index.quote),
+ preparer.format_table(index.table),
+ ', '.join(preparer.quote(c.name, c.quote)
+ for c in index.columns))
+ return text
- self.append("\n)%s\n\n" % self.post_create_table(table))
- self.execute()
+ def visit_drop_index(self, drop):
+ index = drop.element
+ return "\nDROP INDEX " + \
+ self.preparer.quote(self._validate_identifier(index.name, False), index.quote)
- if hasattr(table, 'indexes'):
- for index in table.indexes:
- self.traverse_single(index)
+ def visit_add_constraint(self, create):
+ preparer = self.preparer
+ return "ALTER TABLE %s ADD %s" % (
+ self.preparer.format_table(create.element.table),
+ self.process(create.element)
+ )
+
+ def visit_drop_constraint(self, drop):
+ preparer = self.preparer
+ return "ALTER TABLE %s DROP CONSTRAINT %s%s" % (
+ self.preparer.format_table(drop.element.table),
+ self.preparer.format_constraint(drop.element),
+ " CASCADE" if drop.cascade else ""
+ )
+
+ def get_column_specification(self, column, **kwargs):
+ colspec = self.preparer.format_column(column) + " " + \
+ self.dialect.type_compiler.process(column.type)
+ default = self.get_column_default_string(column)
+ if default is not None:
+ colspec += " DEFAULT " + default
- for listener in table.ddl_listeners['after-create']:
- listener('after-create', table, self.connection)
+ if not column.nullable:
+ colspec += " NOT NULL"
+ return colspec
def post_create_table(self, table):
return ''
+ def _compile(self, tocompile, parameters):
+ """compile the given string/parameters using this SchemaGenerator's dialect."""
+
+ compiler = self.dialect.statement_compiler(self.dialect, tocompile, parameters)
+ compiler.compile()
+ return compiler
+
+ def _validate_identifier(self, ident, truncate):
+ if truncate:
+ if len(ident) > self.dialect.max_identifier_length:
+ counter = getattr(self, 'counter', 0)
+ self.counter = counter + 1
+ return ident[0:self.dialect.max_identifier_length - 6] + "_" + hex(self.counter)[2:]
+ else:
+ return ident
+ else:
+ self.dialect.validate_identifier(ident)
+ return ident
+
def get_column_default_string(self, column):
if isinstance(column.server_default, schema.DefaultClause):
if isinstance(column.server_default.arg, basestring):
else:
return None
- def _compile(self, tocompile, parameters):
- """compile the given string/parameters using this SchemaGenerator's dialect."""
- compiler = self.dialect.statement_compiler(self.dialect, tocompile, parameters)
- compiler.compile()
- return compiler
-
def visit_check_constraint(self, constraint):
- self.append(", \n\t")
+ text = ""
if constraint.name is not None:
- self.append("CONSTRAINT %s " %
- self.preparer.format_constraint(constraint))
- self.append(" CHECK (%s)" % constraint.sqltext)
- self.define_constraint_deferrability(constraint)
+ text += "CONSTRAINT %s " % \
+ self.preparer.format_constraint(constraint)
+ text += " CHECK (%s)" % constraint.sqltext
+ text += self.define_constraint_deferrability(constraint)
+ return text
def visit_column_check_constraint(self, constraint):
- self.append(" CHECK (%s)" % constraint.sqltext)
- self.define_constraint_deferrability(constraint)
+ text = " CHECK (%s)" % constraint.sqltext
+ text += self.define_constraint_deferrability(constraint)
+ return text
def visit_primary_key_constraint(self, constraint):
if len(constraint) == 0:
- return
- self.append(", \n\t")
+ return ''
+ text = ""
if constraint.name is not None:
- self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint))
- self.append("PRIMARY KEY ")
- self.append("(%s)" % ', '.join(self.preparer.quote(c.name, c.quote)
- for c in constraint))
- self.define_constraint_deferrability(constraint)
+ text += "CONSTRAINT %s " % self.preparer.format_constraint(constraint)
+ text += "PRIMARY KEY "
+ text += "(%s)" % ', '.join(self.preparer.quote(c.name, c.quote)
+ for c in constraint)
+ text += self.define_constraint_deferrability(constraint)
+ return text
def visit_foreign_key_constraint(self, constraint):
- if constraint.use_alter and self.dialect.supports_alter:
- return
- self.append(", \n\t ")
- self.define_foreign_key(constraint)
-
- def add_foreignkey(self, constraint):
- self.append("ALTER TABLE %s ADD " % self.preparer.format_table(constraint.table))
- self.define_foreign_key(constraint)
- self.execute()
-
- def define_foreign_key(self, constraint):
- preparer = self.preparer
+ preparer = self.dialect.identifier_preparer
+ text = ""
if constraint.name is not None:
- self.append("CONSTRAINT %s " %
- preparer.format_constraint(constraint))
- table = list(constraint.elements)[0].column.table
- self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % (
+ text += "CONSTRAINT %s " % \
+ preparer.format_constraint(constraint)
+ remote_table = list(constraint._elements.values())[0].column.table
+ text += "FOREIGN KEY(%s) REFERENCES %s (%s)" % (
', '.join(preparer.quote(f.parent.name, f.parent.quote)
- for f in constraint.elements),
- preparer.format_table(table),
+ for f in constraint._elements.values()),
+ preparer.format_table(remote_table),
', '.join(preparer.quote(f.column.name, f.column.quote)
- for f in constraint.elements)
- ))
- if constraint.ondelete is not None:
- self.append(" ON DELETE %s" % constraint.ondelete)
- if constraint.onupdate is not None:
- self.append(" ON UPDATE %s" % constraint.onupdate)
- self.define_constraint_deferrability(constraint)
+ for f in constraint._elements.values())
+ )
+ text += self.define_constraint_cascades(constraint)
+ text += self.define_constraint_deferrability(constraint)
+ return text
def visit_unique_constraint(self, constraint):
- self.append(", \n\t")
+ text = ""
if constraint.name is not None:
- self.append("CONSTRAINT %s " %
- self.preparer.format_constraint(constraint))
- self.append(" UNIQUE (%s)" % (', '.join(self.preparer.quote(c.name, c.quote) for c in constraint)))
- self.define_constraint_deferrability(constraint)
+ text += "CONSTRAINT %s " % self.preparer.format_constraint(constraint)
+ text += " UNIQUE (%s)" % (', '.join(self.preparer.quote(c.name, c.quote) for c in constraint))
+ text += self.define_constraint_deferrability(constraint)
+ return text
+ def define_constraint_cascades(self, constraint):
+ text = ""
+ if constraint.ondelete is not None:
+ text += " ON DELETE %s" % constraint.ondelete
+ if constraint.onupdate is not None:
+ text += " ON UPDATE %s" % constraint.onupdate
+ return text
+
def define_constraint_deferrability(self, constraint):
+ text = ""
if constraint.deferrable is not None:
if constraint.deferrable:
- self.append(" DEFERRABLE")
+ text += " DEFERRABLE"
else:
- self.append(" NOT DEFERRABLE")
+ text += " NOT DEFERRABLE"
if constraint.initially is not None:
- self.append(" INITIALLY %s" % constraint.initially)
+ text += " INITIALLY %s" % constraint.initially
+ return text
+
+
+class GenericTypeCompiler(engine.TypeCompiler):
+ def visit_CHAR(self, type_):
+ return "CHAR" + (type_.length and "(%d)" % type_.length or "")
- def visit_column(self, column):
- pass
+ def visit_NCHAR(self, type_):
+ return "NCHAR" + (type_.length and "(%d)" % type_.length or "")
+
+ def visit_FLOAT(self, type_):
+ return "FLOAT"
- def visit_index(self, index):
- preparer = self.preparer
- self.append("CREATE ")
- if index.unique:
- self.append("UNIQUE ")
- self.append("INDEX %s ON %s (%s)" \
- % (preparer.quote(self._validate_identifier(index.name, True), index.quote),
- preparer.format_table(index.table),
- ', '.join(preparer.quote(c.name, c.quote)
- for c in index.columns)))
- self.execute()
+ def visit_NUMERIC(self, type_):
+ if type_.precision is None:
+ return "NUMERIC"
+ else:
+ return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': type_.precision, 'scale' : type_.scale}
+ def visit_DECIMAL(self, type_):
+ return "DECIMAL"
+
+ def visit_INTEGER(self, type_):
+ return "INTEGER"
-class SchemaDropper(DDLBase):
- def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
- super(SchemaDropper, self).__init__(connection, **kwargs)
- self.checkfirst = checkfirst
- self.tables = tables
- self.preparer = dialect.identifier_preparer
- self.dialect = dialect
+ def visit_SMALLINT(self, type_):
+ return "SMALLINT"
- def visit_metadata(self, metadata):
- if self.tables:
- tables = self.tables
- else:
- tables = metadata.tables.values()
- collection = [t for t in reversed(sql_util.sort_tables(tables)) if self._can_drop(t)]
- if self.dialect.supports_alter:
- for alterable in self.find_alterables(collection):
- self.drop_foreignkey(alterable)
- for table in collection:
- self.traverse_single(table)
-
- def _can_drop(self, table):
- self.dialect.validate_identifier(table.name)
- if table.schema:
- self.dialect.validate_identifier(table.schema)
- return not self.checkfirst or self.dialect.has_table(self.connection, table.name, schema=table.schema)
-
- def visit_index(self, index):
- self.append("\nDROP INDEX " + self.preparer.quote(self._validate_identifier(index.name, False), index.quote))
- self.execute()
-
- def drop_foreignkey(self, constraint):
- self.append("ALTER TABLE %s DROP CONSTRAINT %s" % (
- self.preparer.format_table(constraint.table),
- self.preparer.format_constraint(constraint)))
- self.execute()
-
- def visit_table(self, table):
- for listener in table.ddl_listeners['before-drop']:
- listener('before-drop', table, self.connection)
+ def visit_BIGINT(self, type_):
+ return "BIGINT"
- for column in table.columns:
- if column.default is not None:
- self.traverse_single(column.default)
+ def visit_TIMESTAMP(self, type_):
+ return 'TIMESTAMP'
+
+ def visit_DATETIME(self, type_):
+ return "DATETIME"
+
+ def visit_DATE(self, type_):
+ return "DATE"
+
+ def visit_TIME(self, type_):
+ return "TIME"
+
+ def visit_CLOB(self, type_):
+ return "CLOB"
- self.append("\nDROP TABLE " + self.preparer.format_table(table))
- self.execute()
+ def visit_NCLOB(self, type_):
+ return "NCLOB"
- for listener in table.ddl_listeners['after-drop']:
- listener('after-drop', table, self.connection)
+ def visit_VARCHAR(self, type_):
+ return "VARCHAR" + (type_.length and "(%d)" % type_.length or "")
+ def visit_NVARCHAR(self, type_):
+ return "NVARCHAR" + (type_.length and "(%d)" % type_.length or "")
+ def visit_BLOB(self, type_):
+ return "BLOB"
+
+ def visit_BOOLEAN(self, type_):
+ return "BOOLEAN"
+
+ def visit_TEXT(self, type_):
+ return "TEXT"
+
+ def visit_binary(self, type_):
+ return self.visit_BLOB(type_)
+
+ def visit_boolean(self, type_):
+ return self.visit_BOOLEAN(type_)
+
+ def visit_time(self, type_):
+ return self.visit_TIME(type_)
+
+ def visit_datetime(self, type_):
+ return self.visit_DATETIME(type_)
+
+ def visit_date(self, type_):
+ return self.visit_DATE(type_)
+
+ def visit_big_integer(self, type_):
+ return self.visit_BIGINT(type_)
+
+ def visit_small_integer(self, type_):
+ return self.visit_SMALLINT(type_)
+
+ def visit_integer(self, type_):
+ return self.visit_INTEGER(type_)
+
+ def visit_float(self, type_):
+ return self.visit_FLOAT(type_)
+
+ def visit_numeric(self, type_):
+ return self.visit_NUMERIC(type_)
+
+ def visit_string(self, type_):
+ return self.visit_VARCHAR(type_)
+
+ def visit_unicode(self, type_):
+ return self.visit_VARCHAR(type_)
+
+ def visit_text(self, type_):
+ return self.visit_TEXT(type_)
+
+ def visit_unicode_text(self, type_):
+ return self.visit_TEXT(type_)
+
+ def visit_null(self, type_):
+ raise NotImplementedError("Can't generate DDL for the null type")
+
+ def visit_type_decorator(self, type_):
+ return self.process(type_.type_engine(self.dialect))
+
+ def visit_user_defined(self, type_):
+ return type_.get_col_spec()
+
class IdentifierPreparer(object):
"""Handle quoting and case-folding of identifiers based on options."""
else:
return (self.format_table(table, use_schema=False), )
+ @util.memoized_property
+ def _r_identifiers(self):
+ initial, final, escaped_final = \
+ [re.escape(s) for s in
+ (self.initial_quote, self.final_quote,
+ self._escape_identifier(self.final_quote))]
+ r = re.compile(
+ r'(?:'
+ r'(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s'
+ r'|([^\.]+))(?=\.|$))+' %
+ { 'initial': initial,
+ 'final': final,
+ 'escaped': escaped_final })
+ return r
+
def unformat_identifiers(self, identifiers):
"""Unpack 'schema.table.column'-like strings into components."""
- try:
- r = self._r_identifiers
- except AttributeError:
- initial, final, escaped_final = \
- [re.escape(s) for s in
- (self.initial_quote, self.final_quote,
- self._escape_identifier(self.final_quote))]
- r = re.compile(
- r'(?:'
- r'(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s'
- r'|([^\.]+))(?=\.|$))+' %
- { 'initial': initial,
- 'final': final,
- 'escaped': escaped_final })
- self._r_identifiers = r
-
+ r = self._r_identifiers
return [self._unescape_identifier(i)
for i in [a or b for a, b in r.findall(identifiers)]]
import itertools, re
from operator import attrgetter
-from sqlalchemy import util, exc
+from sqlalchemy import util, exc, types as sqltypes
from sqlalchemy.sql import operators
from sqlalchemy.sql.visitors import Visitable, cloned_traverse
-from sqlalchemy import types as sqltypes
import operator
functions, schema, sql_util = None, None, None
Similar functionality is also available via the ``select()``
method on any :class:`~sqlalchemy.sql.expression.FromClause`.
- The returned object is an instance of :class:`~sqlalchemy.sql.expression.Select`.
+ The returned object is an instance of
+ :class:`~sqlalchemy.sql.expression.Select`.
All arguments which accept ``ClauseElement`` arguments also accept
string arguments, which will be converted as appropriate into
"""
if 'scalar' in kwargs:
- util.warn_deprecated('scalar option is deprecated; see docs for details')
+ util.warn_deprecated(
+ 'scalar option is deprecated; see docs for details')
scalar = kwargs.pop('scalar', False)
s = Select(columns, whereclause=whereclause, from_obj=from_obj, **kwargs)
if scalar:
return s
def subquery(alias, *args, **kwargs):
- """Return an :class:`~sqlalchemy.sql.expression.Alias` object derived from a :class:`~sqlalchemy.sql.expression.Select`.
+ """Return an :class:`~sqlalchemy.sql.expression.Alias` object derived
+ from a :class:`~sqlalchemy.sql.expression.Select`.
name
alias name
\*args, \**kwargs
- all other arguments are delivered to the :func:`~sqlalchemy.sql.expression.select`
- function.
+ all other arguments are delivered to the
+ :func:`~sqlalchemy.sql.expression.select` function.
"""
return Select(*args, **kwargs).alias(alias)
table columns. Note that the :meth:`~Insert.values()` generative method
may also be used for this.
- :param prefixes: A list of modifier keywords to be inserted between INSERT and INTO.
- Alternatively, the :meth:`~Insert.prefix_with` generative method may be used.
+ :param prefixes: A list of modifier keywords to be inserted between INSERT
+ and INTO. Alternatively, the :meth:`~Insert.prefix_with` generative method
+ may be used.
- :param inline:
- if True, SQL defaults will be compiled 'inline' into the statement
- and not pre-executed.
+ :param inline: if True, SQL defaults will be compiled 'inline' into the
+ statement and not pre-executed.
If both `values` and compile-time bind parameters are present, the
compile-time bind parameters override the information specified
:param table: The table to be updated.
- :param whereclause: A ``ClauseElement`` describing the ``WHERE`` condition of the
- ``UPDATE`` statement. Note that the :meth:`~Update.where()` generative
- method may also be used for this.
+ :param whereclause: A ``ClauseElement`` describing the ``WHERE`` condition
+ of the ``UPDATE`` statement. Note that the :meth:`~Update.where()`
+ generative method may also be used for this.
:param values:
A dictionary which specifies the ``SET`` conditions of the
against the ``UPDATE`` statement.
"""
- return Update(table, whereclause=whereclause, values=values, inline=inline, **kwargs)
+ return Update(
+ table,
+ whereclause=whereclause,
+ values=values,
+ inline=inline,
+ **kwargs)
def delete(table, whereclause = None, **kwargs):
"""Return a :class:`~sqlalchemy.sql.expression.Delete` clause element.
:param table: The table to be updated.
- :param whereclause: A :class:`ClauseElement` describing the ``WHERE`` condition of the
- ``UPDATE`` statement. Note that the :meth:`~Delete.where()` generative method
- may be used instead.
+ :param whereclause: A :class:`ClauseElement` describing the ``WHERE``
+ condition of the ``UPDATE`` statement. Note that the :meth:`~Delete.where()`
+ generative method may be used instead.
"""
return Delete(table, whereclause, **kwargs)
"""Join a list of clauses together using the ``AND`` operator.
The ``&`` operator is also overloaded on all
- :class:`~sqlalchemy.sql.expression._CompareMixin` subclasses to produce the same
- result.
+ :class:`~sqlalchemy.sql.expression._CompareMixin` subclasses to produce the
+ same result.
"""
if len(clauses) == 1:
"""Join a list of clauses together using the ``OR`` operator.
The ``|`` operator is also overloaded on all
- :class:`~sqlalchemy.sql.expression._CompareMixin` subclasses to produce the same
- result.
+ :class:`~sqlalchemy.sql.expression._CompareMixin` subclasses to produce the
+ same result.
"""
if len(clauses) == 1:
"""Return a negation of the given clause, i.e. ``NOT(clause)``.
The ``~`` operator is also overloaded on all
- :class:`~sqlalchemy.sql.expression._CompareMixin` subclasses to produce the same
- result.
+ :class:`~sqlalchemy.sql.expression._CompareMixin` subclasses to produce the
+ same result.
"""
return operators.inv(_literal_as_binds(clause))
Equivalent of SQL ``clausetest BETWEEN clauseleft AND clauseright``.
- The ``between()`` method on all :class:`~sqlalchemy.sql.expression._CompareMixin` subclasses
- provides similar functionality.
+ The ``between()`` method on all
+ :class:`~sqlalchemy.sql.expression._CompareMixin` subclasses provides
+ similar functionality.
"""
ctest = _literal_as_binds(ctest)
def union(*selects, **kwargs):
"""Return a ``UNION`` of multiple selectables.
- The returned object is an instance of :class:`~sqlalchemy.sql.expression.CompoundSelect`.
+ The returned object is an instance of
+ :class:`~sqlalchemy.sql.expression.CompoundSelect`.
A similar ``union()`` method is available on all
:class:`~sqlalchemy.sql.expression.FromClause` subclasses.
def union_all(*selects, **kwargs):
"""Return a ``UNION ALL`` of multiple selectables.
- The returned object is an instance of :class:`~sqlalchemy.sql.expression.CompoundSelect`.
+ The returned object is an instance of
+ :class:`~sqlalchemy.sql.expression.CompoundSelect`.
A similar ``union_all()`` method is available on all
:class:`~sqlalchemy.sql.expression.FromClause` subclasses.
def except_(*selects, **kwargs):
"""Return an ``EXCEPT`` of multiple selectables.
- The returned object is an instance of :class:`~sqlalchemy.sql.expression.CompoundSelect`.
+ The returned object is an instance of
+ :class:`~sqlalchemy.sql.expression.CompoundSelect`.
\*selects
a list of :class:`~sqlalchemy.sql.expression.Select` instances.
def except_all(*selects, **kwargs):
"""Return an ``EXCEPT ALL`` of multiple selectables.
- The returned object is an instance of :class:`~sqlalchemy.sql.expression.CompoundSelect`.
+ The returned object is an instance of
+ :class:`~sqlalchemy.sql.expression.CompoundSelect`.
\*selects
a list of :class:`~sqlalchemy.sql.expression.Select` instances.
def intersect(*selects, **kwargs):
"""Return an ``INTERSECT`` of multiple selectables.
- The returned object is an instance of :class:`~sqlalchemy.sql.expression.CompoundSelect`.
+ The returned object is an instance of
+ :class:`~sqlalchemy.sql.expression.CompoundSelect`.
\*selects
a list of :class:`~sqlalchemy.sql.expression.Select` instances.
def intersect_all(*selects, **kwargs):
"""Return an ``INTERSECT ALL`` of multiple selectables.
- The returned object is an instance of :class:`~sqlalchemy.sql.expression.CompoundSelect`.
+ The returned object is an instance of
+ :class:`~sqlalchemy.sql.expression.CompoundSelect`.
\*selects
a list of :class:`~sqlalchemy.sql.expression.Select` instances.
def alias(selectable, alias=None):
"""Return an :class:`~sqlalchemy.sql.expression.Alias` object.
- An ``Alias`` represents any :class:`~sqlalchemy.sql.expression.FromClause` with
- an alternate name assigned within SQL, typically using the ``AS``
+ An ``Alias`` represents any :class:`~sqlalchemy.sql.expression.FromClause`
+ with an alternate name assigned within SQL, typically using the ``AS``
clause when generated, e.g. ``SELECT * FROM table AS aliasname``.
Similar functionality is available via the ``alias()`` method
return _BindParamClause(None, value, type_=type_, unique=True)
def label(name, obj):
- """Return a :class:`~sqlalchemy.sql.expression._Label` object for the given :class:`~sqlalchemy.sql.expression.ColumnElement`.
+ """Return a :class:`~sqlalchemy.sql.expression._Label` object for the given
+ :class:`~sqlalchemy.sql.expression.ColumnElement`.
A label changes the name of an element in the columns clause of a
``SELECT`` statement, typically via the ``AS`` SQL keyword.
return _Label(name, obj)
def column(text, type_=None):
- """Return a textual column clause, as would be in the columns clause of a ``SELECT`` statement.
+ """Return a textual column clause, as would be in the columns clause of a
+ ``SELECT`` statement.
- The object returned is an instance of :class:`~sqlalchemy.sql.expression.ColumnClause`,
- which represents the "syntactical" portion of the schema-level
- :class:`~sqlalchemy.schema.Column` object.
+ The object returned is an instance of
+ :class:`~sqlalchemy.sql.expression.ColumnClause`, which represents the
+ "syntactical" portion of the schema-level :class:`~sqlalchemy.schema.Column`
+ object.
text
the name of the column. Quoting rules will be applied to the
:func:`~sqlalchemy.sql.expression.column` function.
type\_
- an optional :class:`~sqlalchemy.types.TypeEngine` object which will provide
- result-set translation and additional expression semantics for this
- column. If left as None the type will be NullType.
+ an optional :class:`~sqlalchemy.types.TypeEngine` object which will
+ provide result-set translation and additional expression semantics for
+ this column. If left as None the type will be NullType.
"""
return ColumnClause(text, type_=type_, is_literal=True)
return _BindParamClause(key, value, type_=type_, unique=unique, shortname=shortname)
def outparam(key, type_=None):
- """Create an 'OUT' parameter for usage in functions (stored procedures), for databases which support them.
+ """Create an 'OUT' parameter for usage in functions (stored procedures), for
+ databases which support them.
The ``outparam`` can be used like a regular function parameter.
The "output" value will be available from the
attribute, which returns a dictionary containing the values.
"""
- return _BindParamClause(key, None, type_=type_, unique=False, isoutparam=True)
+ return _BindParamClause(
+ key, None, type_=type_, unique=False, isoutparam=True)
def text(text, bind=None, *args, **kwargs):
"""Create literal text to be inserted into a query.
return _TextClause(text, bind=bind, *args, **kwargs)
def null():
- """Return a :class:`_Null` object, which compiles to ``NULL`` in a sql statement."""
-
+ """Return a :class:`_Null` object, which compiles to ``NULL`` in a sql
+ statement.
+
+ """
return _Null()
class _FunctionGenerator(object):
if func is not None:
return func(*c, **o)
- return Function(self.__names[-1], packagenames=self.__names[0:-1], *c, **o)
+ return Function(
+ self.__names[-1], packagenames=self.__names[0:-1], *c, **o)
# "func" global - i.e. func.count()
func = _FunctionGenerator()
return element._clone()
def _expand_cloned(elements):
- """expand the given set of ClauseElements to be the set of all 'cloned' predecessors."""
-
+ """expand the given set of ClauseElements to be the set of all 'cloned'
+ predecessors.
+
+ """
return itertools.chain(*[x._cloned_set for x in elements])
+def _select_iterables(elements):
+ """expand tables into individual columns in the
+ given list of column expressions.
+
+ """
+ return itertools.chain(*[c._select_iterable for c in elements])
+
def _cloned_intersection(a, b):
"""return the intersection of sets a and b, counting
any overlap between 'cloned' predecessors.
return CompoundSelect(keyword, *selects, **kwargs)
def _is_literal(element):
- return not isinstance(element, Visitable) and not hasattr(element, '__clause_element__')
+ return not isinstance(element, Visitable) and \
+ not hasattr(element, '__clause_element__')
def _from_objects(*elements):
return itertools.chain(*[element._from_objects for element in elements])
return element
def _corresponding_column_or_error(fromclause, column, require_embedded=False):
- c = fromclause.corresponding_column(column, require_embedded=require_embedded)
+ c = fromclause.corresponding_column(column,
+ require_embedded=require_embedded)
if not c:
- raise exc.InvalidRequestError("Given column '%s', attached to table '%s', "
+ raise exc.InvalidRequestError(
+ "Given column '%s', attached to table '%s', "
"failed to locate a corresponding column from table '%s'"
- % (column, getattr(column, 'table', None), fromclause.description))
+ %
+ (column,
+ getattr(column, 'table', None),fromclause.description)
+ )
return c
def is_column(col):
"""True if ``col`` is an instance of ``ColumnElement``."""
+
return isinstance(col, ColumnElement)
class ClauseElement(Visitable):
- """Base class for elements of a programmatically constructed SQL expression."""
-
+ """Base class for elements of a programmatically constructed SQL
+ expression.
+
+ """
__visit_name__ = 'clause'
_annotations = {}
supports_execution = False
_from_objects = []
-
+ _bind = None
+
def _clone(self):
"""Create a shallow copy of this ClauseElement.
@util.memoized_property
def _cloned_set(self):
- """Return the set consisting all cloned anscestors of this ClauseElement.
+ """Return the set consisting all cloned anscestors of this
+ ClauseElement.
Includes this ClauseElement. This accessor tends to be used for
FromClause objects to identify 'equivalent' FROM clauses, regardless
return d
def _annotate(self, values):
- """return a copy of this ClauseElement with the given annotations dictionary."""
-
+ """return a copy of this ClauseElement with the given annotations
+ dictionary.
+
+ """
global Annotated
if Annotated is None:
from sqlalchemy.sql.util import Annotated
return Annotated(self, values)
def _deannotate(self):
- """return a copy of this ClauseElement with an empty annotations dictionary."""
+ """return a copy of this ClauseElement with an empty annotations
+ dictionary.
+
+ """
return self._clone()
def unique_params(self, *optionaldict, **kwargs):
if len(optionaldict) == 1:
kwargs.update(optionaldict[0])
elif len(optionaldict) > 1:
- raise exc.ArgumentError("params() takes zero or one positional dictionary argument")
+ raise exc.ArgumentError(
+ "params() takes zero or one positional dictionary argument")
def visit_bindparam(bind):
if bind.key in kwargs:
def self_group(self, against=None):
return self
+ # TODO: remove .bind as a method from the root ClauseElement.
+ # we should only be deriving binds from FromClause elements
+ # and certain SchemaItem subclasses.
+ # the "search_for_bind" functionality can still be used by
+ # execute(), however.
@property
def bind(self):
- """Returns the Engine or Connection to which this ClauseElement is bound, or None if none found."""
+ """Returns the Engine or Connection to which this ClauseElement is
+ bound, or None if none found.
+
+ """
+ if self._bind is not None:
+ return self._bind
- try:
- if self._bind is not None:
- return self._bind
- except AttributeError:
- pass
for f in _from_objects(self):
if f is self:
continue
return e._execute_clauseelement(self, multiparams, params)
def scalar(self, *multiparams, **params):
- """Compile and execute this ``ClauseElement``, returning the result's scalar representation."""
-
+ """Compile and execute this ``ClauseElement``, returning the result's
+ scalar representation.
+
+ """
return self.execute(*multiparams, **params).scalar()
- def compile(self, bind=None, column_keys=None, compiler=None, dialect=None, inline=False):
+ def compile(self, bind=None, dialect=None, **kw):
"""Compile this SQL expression.
The return value is a :class:`~sqlalchemy.engine.Compiled` object.
- Calling `str()` or `unicode()` on the returned value will yield
- a string representation of the result. The :class:`~sqlalchemy.engine.Compiled`
- object also can return a dictionary of bind parameter names and
- values using the `params` accessor.
+ Calling `str()` or `unicode()` on the returned value will yield a string
+ representation of the result. The :class:`~sqlalchemy.engine.Compiled`
+ object also can return a dictionary of bind parameter names and values
+ using the `params` accessor.
:param bind: An ``Engine`` or ``Connection`` from which a
``Compiled`` will be acquired. This argument
takes precedence over this ``ClauseElement``'s
bound engine, if any.
- :param column_keys: Used for INSERT and UPDATE statements, a list of
- column names which should be present in the VALUES clause
- of the compiled statement. If ``None``, all columns
- from the target table object are rendered.
-
- :param compiler: A ``Compiled`` instance which will be used to compile
- this expression. This argument takes precedence
- over the `bind` and `dialect` arguments as well as
- this ``ClauseElement``'s bound engine, if
- any.
-
:param dialect: A ``Dialect`` instance frmo which a ``Compiled``
will be acquired. This argument takes precedence
over the `bind` argument as well as this
``ClauseElement``'s bound engine, if any.
- :param inline: Used for INSERT statements, for a dialect which does
- not support inline retrieval of newly generated
- primary key columns, will force the expression used
- to create the new primary key value to be rendered
- inline within the INSERT statement's VALUES clause.
- This typically refers to Sequence execution but
- may also refer to any server-side default generation
- function associated with a primary key `Column`.
+ \**kw
+
+ Keyword arguments are passed along to the compiler,
+ which can affect the string produced.
+
+ Keywords for a statement compiler are:
+
+ column_keys
+ Used for INSERT and UPDATE statements, a list of
+ column names which should be present in the VALUES clause
+ of the compiled statement. If ``None``, all columns
+ from the target table object are rendered.
+
+ inline
+ Used for INSERT statements, for a dialect which does
+ not support inline retrieval of newly generated
+ primary key columns, will force the expression used
+ to create the new primary key value to be rendered
+ inline within the INSERT statement's VALUES clause.
+ This typically refers to Sequence execution but
+ may also refer to any server-side default generation
+ function associated with a primary key `Column`.
"""
- if compiler is None:
- if dialect is not None:
- compiler = dialect.statement_compiler(dialect, self, column_keys=column_keys, inline=inline)
- elif bind is not None:
- compiler = bind.statement_compiler(self, column_keys=column_keys, inline=inline)
- elif self.bind is not None:
- compiler = self.bind.statement_compiler(self, column_keys=column_keys, inline=inline)
+
+ if not dialect:
+ if bind:
+ dialect = bind.dialect
+ elif self.bind:
+ dialect = self.bind.dialect
+ bind = self.bind
else:
global DefaultDialect
if DefaultDialect is None:
from sqlalchemy.engine.default import DefaultDialect
dialect = DefaultDialect()
- compiler = dialect.statement_compiler(dialect, self, column_keys=column_keys, inline=inline)
+ compiler = self._compiler(dialect, bind=bind, **kw)
compiler.compile()
return compiler
-
+
+ def _compiler(self, dialect, **kw):
+ """Return a compiler appropriate for this ClauseElement, given a Dialect."""
+
+ return dialect.statement_compiler(dialect, self, **kw)
+
def __str__(self):
+ # Py3K
+ #return unicode(self.compile())
+ # Py2K
return unicode(self.compile()).encode('ascii', 'backslashreplace')
+ # end Py2K
def __and__(self, other):
return and_(self, other)
def __invert__(self):
return self._negate()
+ if util.jython:
+ def __hash__(self):
+ """Return a distinct hash code.
+
+ ClauseElements may have special equality comparisons which
+ makes us rely on them having unique hash codes for use in
+ hash-based collections. Stock __hash__ doesn't guarantee
+ unique values on platforms with moving GCs.
+ """
+ return id(self)
+
def _negate(self):
if hasattr(self, 'negation_clause'):
return self.negation_clause
else:
- return _UnaryExpression(self.self_group(against=operators.inv), operator=operators.inv, negate=None)
+ return _UnaryExpression(
+ self.self_group(against=operators.inv),
+ operator=operators.inv,
+ negate=None)
def __repr__(self):
friendly = getattr(self, 'description', None)
class _Immutable(object):
"""mark a ClauseElement as 'immutable' when expressions are cloned."""
+ def unique_params(self, *optionaldict, **kwargs):
+ raise NotImplementedError("Immutable objects do not support copying")
+
+ def params(self, *optionaldict, **kwargs):
+ raise NotImplementedError("Immutable objects do not support copying")
+
def _clone(self):
return self
def __truediv__(self, other):
return self.operate(operators.truediv, other)
+ def __rtruediv__(self, other):
+ return self.reverse_operate(operators.truediv, other)
+
class _CompareMixin(ColumnOperators):
"""Defines comparison and math operations for ``ClauseElement`` instances."""
operators.add : (__operate,),
operators.mul : (__operate,),
operators.sub : (__operate,),
+ # Py2K
operators.div : (__operate,),
+ # end Py2K
operators.mod : (__operate,),
operators.truediv : (__operate,),
operators.lt : (__compare, operators.ge),
def __init__(self, *cols):
super(ColumnCollection, self).__init__()
- [self.add(c) for c in cols]
+ self.update((c.key, c) for c in cols)
def __str__(self):
return repr([str(c) for c in self])
__visit_name__ = 'selectable'
class FromClause(Selectable):
- """Represent an element that can be used within the ``FROM`` clause of a ``SELECT`` statement."""
-
+ """Represent an element that can be used within the ``FROM``
+ clause of a ``SELECT`` statement.
+
+ """
__visit_name__ = 'fromclause'
named_with_column = False
_hide_froms = []
col = list(self.primary_key)[0]
else:
col = list(self.columns)[0]
- return select([func.count(col).label('tbl_row_count')], whereclause, from_obj=[self], **params)
+ return select(
+ [func.count(col).label('tbl_row_count')],
+ whereclause,
+ from_obj=[self],
+ **params)
def select(self, whereclause=None, **params):
"""return a SELECT of this ``FromClause``."""
return fromclause in self._cloned_set
def replace_selectable(self, old, alias):
- """replace all occurences of FromClause 'old' with the given Alias object, returning a copy of this ``FromClause``."""
-
+ """replace all occurences of FromClause 'old' with the given Alias
+ object, returning a copy of this ``FromClause``.
+
+ """
global ClauseAdapter
if ClauseAdapter is None:
from sqlalchemy.sql.util import ClauseAdapter
col, intersect = c, i
elif len(i) > len(intersect):
# 'c' has a larger field of correspondence than 'col'.
- # i.e. selectable.c.a1_x->a1.c.x->table.c.x matches a1.c.x->table.c.x better than
+ # i.e. selectable.c.a1_x->a1.c.x->table.c.x matches
+ # a1.c.x->table.c.x better than
# selectable.c.x->table.c.x does.
col, intersect = c, i
elif i == intersect:
# they have the same field of correspondence.
- # see which proxy_set has fewer columns in it, which indicates a
- # closer relationship with the root column. Also take into account the
- # "weight" attribute which CompoundSelect() uses to give higher precedence to
- # columns based on vertical position in the compound statement, and discard columns
- # that have no reference to the target column (also occurs with CompoundSelect)
+ # see which proxy_set has fewer columns in it, which indicates
+ # a closer relationship with the root column. Also take into
+ # account the "weight" attribute which CompoundSelect() uses to
+ # give higher precedence to columns based on vertical position
+ # in the compound statement, and discard columns that have no
+ # reference to the target column (also occurs with
+ # CompoundSelect)
col_distance = util.reduce(operator.add,
- [sc._annotations.get('weight', 1) for sc in col.proxy_set if sc.shares_lineage(column)]
+ [sc._annotations.get('weight', 1)
+ for sc in col.proxy_set
+ if sc.shares_lineage(column)]
)
c_distance = util.reduce(operator.add,
- [sc._annotations.get('weight', 1) for sc in c.proxy_set if sc.shares_lineage(column)]
+ [sc._annotations.get('weight', 1)
+ for sc in c.proxy_set
+ if sc.shares_lineage(column)]
)
- if \
- c_distance < col_distance:
+ if c_distance < col_distance:
col, intersect = c, i
return col
the same type.
"""
- return isinstance(other, _BindParamClause) and other.type.__class__ == self.type.__class__ and self.value == other.value
+ return isinstance(other, _BindParamClause) and \
+ other.type.__class__ == self.type.__class__ and \
+ self.value == other.value
def __getstate__(self):
"""execute a deferred value for serialization purposes."""
return d
def __repr__(self):
- return "_BindParamClause(%s, %s, type_=%s)" % (repr(self.key), repr(self.value), repr(self.type))
+ return "_BindParamClause(%r, %r, type_=%r)" % (
+ self.key, self.value, self.type
+ )
class _TypeClause(ClauseElement):
"""Handle a type keyword in a SQL statement.
_hide_froms = []
- def __init__(self, text = "", bind=None, bindparams=None, typemap=None, autocommit=False):
+ def __init__(self, text = "", bind=None,
+ bindparams=None, typemap=None, autocommit=False):
self._bind = bind
self.bindparams = {}
self.typemap = typemap
return list(itertools.chain(*[c._from_objects for c in self.clauses]))
def self_group(self, against=None):
- if self.group and self.operator is not against and operators.is_precedent(self.operator, against):
+ if self.group and self.operator is not against and \
+ operators.is_precedent(self.operator, against):
return _Grouping(self)
else:
return self
pass
if value:
- whenlist = [(_literal_as_binds(c).self_group(), _literal_as_binds(r)) for (c, r) in whens]
+ whenlist = [
+ (_literal_as_binds(c).self_group(), _literal_as_binds(r)) for (c, r) in whens
+ ]
else:
- whenlist = [(_no_literals(c).self_group(), _literal_as_binds(r)) for (c, r) in whens]
+ whenlist = [
+ (_no_literals(c).self_group(), _literal_as_binds(r)) for (c, r) in whens
+ ]
if whenlist:
type_ = list(whenlist[-1])[-1].type
return e
def select_from(self, clause):
- """return a new exists() construct with the given expression set as its FROM clause."""
-
+ """return a new exists() construct with the given expression set as its FROM
+ clause.
+
+ """
e = self._clone()
e.element = self.element.select_from(clause).self_group()
return e
def where(self, clause):
- """return a new exists() construct with the given expression added to its WHERE clause, joined
- to the existing clause via AND, if any."""
-
+ """return a new exists() construct with the given expression added to its WHERE
+ clause, joined to the existing clause via AND, if any.
+
+ """
e = self._clone()
e.element = self.element.where(clause).self_group()
return e
id(self.right))
def is_derived_from(self, fromclause):
- return fromclause is self or self.left.is_derived_from(fromclause) or self.right.is_derived_from(fromclause)
+ return fromclause is self or \
+ self.left.is_derived_from(fromclause) or\
+ self.right.is_derived_from(fromclause)
def self_group(self, against=None):
return _FromGrouping(self)
@property
def description(self):
+ # Py3K
+ #return self.name
+ # Py2K
return self.name.encode('ascii', 'backslashreplace')
+ # end Py2K
def as_scalar(self):
try:
def __init__(self, name, element, type_=None):
while isinstance(element, _Label):
element = element.element
- self.name = self.key = self._label = name or _generated_label("%%(%d %s)s" % (id(self), getattr(element, 'name', 'anon')))
+ self.name = self.key = self._label = name or \
+ _generated_label("%%(%d %s)s" % (
+ id(self), getattr(element, 'name', 'anon'))
+ )
self._element = element
self._type = type_
self.quote = element.quote
@util.memoized_property
def type(self):
- return sqltypes.to_instance(self._type or getattr(self._element, 'type', None))
+ return sqltypes.to_instance(
+ self._type or getattr(self._element, 'type', None)
+ )
@util.memoized_property
def element(self):
@util.memoized_property
def description(self):
+ # Py3K
+ #return self.name
+ # Py2K
return self.name.encode('ascii', 'backslashreplace')
+ # end Py2K
@util.memoized_property
def _label(self):
# propagate the "is_literal" flag only if we are keeping our name,
# otherwise its considered to be a label
is_literal = self.is_literal and (name is None or name == self.name)
- c = ColumnClause(name or self.name, selectable=selectable, type_=self.type, is_literal=is_literal)
+ c = ColumnClause(
+ name or self.name,
+ selectable=selectable,
+ type_=self.type,
+ is_literal=is_literal
+ )
c.proxies = [self]
if attach:
selectable.columns[c.name] = c
@util.memoized_property
def description(self):
+ # Py3K
+ #return self.name
+ # Py2K
return self.name.encode('ascii', 'backslashreplace')
+ # end Py2K
def append_column(self, c):
self._columns[c.name] = c
col = list(self.primary_key)[0]
else:
col = list(self.columns)[0]
- return select([func.count(col).label('tbl_row_count')], whereclause, from_obj=[self], **params)
+ return select(
+ [func.count(col).label('tbl_row_count')],
+ whereclause,
+ from_obj=[self],
+ **params)
def insert(self, values=None, inline=False, **kwargs):
"""Generate an :func:`~sqlalchemy.sql.expression.insert()` construct."""
def update(self, whereclause=None, values=None, inline=False, **kwargs):
"""Generate an :func:`~sqlalchemy.sql.expression.update()` construct."""
- return update(self, whereclause=whereclause, values=values, inline=inline, **kwargs)
+ return update(self, whereclause=whereclause,
+ values=values, inline=inline, **kwargs)
def delete(self, whereclause=None, **kwargs):
"""Generate a :func:`~sqlalchemy.sql.expression.delete()` construct."""
Typically, a select statement which has only one column in its columns clause
is eligible to be used as a scalar expression.
- The returned object is an instance of :class:`~sqlalchemy.sql.expression._ScalarSelect`.
+ The returned object is an instance of
+ :class:`~sqlalchemy.sql.expression._ScalarSelect`.
"""
return _ScalarSelect(self)
def apply_labels(self):
"""return a new selectable with the 'use_labels' flag set to True.
- This will result in column expressions being generated using labels against their table
- name, such as "SELECT somecolumn AS tablename_somecolumn". This allows selectables which
- contain multiple FROM clauses to produce a unique set of column names regardless of name conflicts
- among the individual FROM clauses.
+ This will result in column expressions being generated using labels against their
+ table name, such as "SELECT somecolumn AS tablename_somecolumn". This allows
+ selectables which contain multiple FROM clauses to produce a unique set of column
+ names regardless of name conflicts among the individual FROM clauses.
"""
self.use_labels = True
return list(self.inner_columns)[0]._make_proxy(selectable, name)
class CompoundSelect(_SelectBaseMixin, FromClause):
- """Forms the basis of ``UNION``, ``UNION ALL``, and other SELECT-based set operations."""
+ """Forms the basis of ``UNION``, ``UNION ALL``, and other
+ SELECT-based set operations."""
__visit_name__ = 'compound_select'
elif len(s.c) != numcols:
raise exc.ArgumentError(
"All selectables passed to CompoundSelect must "
- "have identical numbers of columns; select #%d has %d columns, select #%d has %d" %
+ "have identical numbers of columns; select #%d has %d columns,"
+ " select #%d has %d" %
(1, len(self.selects[0].c), n+1, len(s.c))
)
__visit_name__ = 'select'
- def __init__(self, columns, whereclause=None, from_obj=None, distinct=False, having=None, correlate=True, prefixes=None, **kwargs):
+ def __init__(self,
+ columns,
+ whereclause=None,
+ from_obj=None,
+ distinct=False,
+ having=None,
+ correlate=True,
+ prefixes=None,
+ **kwargs):
"""Construct a Select object.
The public constructor for Select is the
if columns:
self._raw_columns = [
- isinstance(c, _ScalarSelect) and c.self_group(against=operators.comma_op) or c
- for c in
- [_literal_as_column(c) for c in columns]
+ isinstance(c, _ScalarSelect) and
+ c.self_group(against=operators.comma_op) or c
+ for c in [_literal_as_column(c) for c in columns]
]
self._froms.update(_from_objects(*self._raw_columns))
be rendered into the columns clause of the resulting SELECT statement.
"""
-
- return itertools.chain(*[c._select_iterable for c in self._raw_columns])
+ return _select_iterables(self._raw_columns)
def is_derived_from(self, fromclause):
if self in fromclause._cloned_set:
self._reset_exported()
from_cloned = dict((f, clone(f))
for f in self._froms.union(self._correlate))
- self._froms = set(from_cloned[f] for f in self._froms)
+ self._froms = util.OrderedSet(from_cloned[f] for f in self._froms)
self._correlate = set(from_cloned[f] for f in self._correlate)
self._raw_columns = [clone(c) for c in self._raw_columns]
for attr in ('_whereclause', '_having', '_order_by_clause', '_group_by_clause'):
return (column_collections and list(self.columns) or []) + \
self._raw_columns + list(self._froms) + \
- [x for x in (self._whereclause, self._having, self._order_by_clause, self._group_by_clause) if x is not None]
+ [x for x in
+ (self._whereclause, self._having,
+ self._order_by_clause, self._group_by_clause)
+ if x is not None]
@_generative
def column(self, column):
- """return a new select() construct with the given column expression added to its columns clause."""
+ """return a new select() construct with the given column expression
+ added to its columns clause.
+
+ """
column = _literal_as_column(column)
@_generative
def with_only_columns(self, columns):
- """return a new select() construct with its columns clause replaced with the given columns."""
+ """return a new select() construct with its columns clause replaced
+ with the given columns.
+
+ """
self._raw_columns = [
- isinstance(c, _ScalarSelect) and c.self_group(against=operators.comma_op) or c
- for c in
- [_literal_as_column(c) for c in columns]
+ isinstance(c, _ScalarSelect) and
+ c.self_group(against=operators.comma_op) or c
+ for c in [_literal_as_column(c) for c in columns]
]
@_generative
def where(self, whereclause):
- """return a new select() construct with the given expression added to its WHERE clause, joined
- to the existing clause via AND, if any."""
+ """return a new select() construct with the given expression added to its
+ WHERE clause, joined to the existing clause via AND, if any.
+
+ """
self.append_whereclause(whereclause)
@_generative
def having(self, having):
- """return a new select() construct with the given expression added to its HAVING clause, joined
- to the existing clause via AND, if any."""
-
+ """return a new select() construct with the given expression added to its HAVING
+ clause, joined to the existing clause via AND, if any.
+
+ """
self.append_having(having)
@_generative
def distinct(self):
- """return a new select() construct which will apply DISTINCT to its columns clause."""
-
+ """return a new select() construct which will apply DISTINCT to its columns
+ clause.
+
+ """
self._distinct = True
@_generative
def prefix_with(self, clause):
- """return a new select() construct which will apply the given expression to the start of its
- columns clause, not using any commas."""
+ """return a new select() construct which will apply the given expression to the
+ start of its columns clause, not using any commas.
+ """
clause = _literal_as_text(clause)
self._prefixes = self._prefixes + [clause]
@_generative
def select_from(self, fromclause):
- """return a new select() construct with the given FROM expression applied to its list of
- FROM objects."""
+ """return a new select() construct with the given FROM expression applied to its
+ list of FROM objects.
+ """
fromclause = _literal_as_text(fromclause)
self._froms = self._froms.union([fromclause])
@_generative
def correlate(self, *fromclauses):
- """return a new select() construct which will correlate the given FROM clauses to that
- of an enclosing select(), if a match is found.
-
- By "match", the given fromclause must be present in this select's list of FROM objects
- and also present in an enclosing select's list of FROM objects.
-
- Calling this method turns off the select's default behavior of "auto-correlation". Normally,
- select() auto-correlates all of its FROM clauses to those of an embedded select when
- compiled.
-
- If the fromclause is None, correlation is disabled for the returned select().
+ """return a new select() construct which will correlate the given FROM clauses to
+ that of an enclosing select(), if a match is found.
+
+ By "match", the given fromclause must be present in this select's list of FROM
+ objects and also present in an enclosing select's list of FROM objects.
+
+ Calling this method turns off the select's default behavior of
+ "auto-correlation". Normally, select() auto-correlates all of its FROM clauses to
+ those of an embedded select when compiled.
+
+ If the fromclause is None, correlation is disabled for the returned select().
"""
self._should_correlate = False
self._correlate = self._correlate.union([fromclause])
def append_column(self, column):
- """append the given column expression to the columns clause of this select() construct."""
-
+ """append the given column expression to the columns clause of this select()
+ construct.
+
+ """
column = _literal_as_column(column)
if isinstance(column, _ScalarSelect):
self._reset_exported()
def append_prefix(self, clause):
- """append the given columns clause prefix expression to this select() construct."""
-
+ """append the given columns clause prefix expression to this select()
+ construct.
+
+ """
clause = _literal_as_text(clause)
self._prefixes = self._prefixes.union([clause])
self._having = _literal_as_text(having)
def append_from(self, fromclause):
- """append the given FromClause expression to this select() construct's FROM clause.
+ """append the given FromClause expression to this select() construct's FROM
+ clause.
"""
if _is_literal(fromclause):
return union(self, other, **kwargs)
def union_all(self, other, **kwargs):
- """return a SQL UNION ALL of this select() construct against the given selectable."""
-
+ """return a SQL UNION ALL of this select() construct against the given
+ selectable.
+
+ """
return union_all(self, other, **kwargs)
def except_(self, other, **kwargs):
return except_(self, other, **kwargs)
def except_all(self, other, **kwargs):
- """return a SQL EXCEPT ALL of this select() construct against the given selectable."""
-
+ """return a SQL EXCEPT ALL of this select() construct against the given
+ selectable.
+
+ """
return except_all(self, other, **kwargs)
def intersect(self, other, **kwargs):
- """return a SQL INTERSECT of this select() construct against the given selectable."""
-
+ """return a SQL INTERSECT of this select() construct against the given
+ selectable.
+
+ """
return intersect(self, other, **kwargs)
def intersect_all(self, other, **kwargs):
- """return a SQL INTERSECT ALL of this select() construct against the given selectable."""
-
+ """return a SQL INTERSECT ALL of this select() construct against the given
+ selectable.
+
+ """
return intersect_all(self, other, **kwargs)
def bind(self):
supports_execution = True
_autocommit = True
-
+
def _generate(self):
s = self.__class__.__new__(self.__class__)
s.__dict__ = self.__dict__.copy()
return parameters
def params(self, *arg, **kw):
- raise NotImplementedError("params() is not supported for INSERT/UPDATE/DELETE statements."
- " To set the values for an INSERT or UPDATE statement, use stmt.values(**parameters).")
+ raise NotImplementedError(
+ "params() is not supported for INSERT/UPDATE/DELETE statements."
+ " To set the values for an INSERT or UPDATE statement, use"
+ " stmt.values(**parameters).")
def bind(self):
return self._bind or self.table.bind
self._bind = bind
bind = property(bind, _set_bind)
+ _returning_re = re.compile(r'(?:firebird|postgres(?:ql)?)_returning')
+ def _process_deprecated_kw(self, kwargs):
+ for k in list(kwargs):
+ m = self._returning_re.match(k)
+ if m:
+ self._returning = kwargs.pop(k)
+ util.warn_deprecated(
+ "The %r argument is deprecated. Please use statement.returning(col1, col2, ...)" % k
+ )
+ return kwargs
+
+ @_generative
+ def returning(self, *cols):
+ """Add a RETURNING or equivalent clause to this statement.
+
+ The given list of columns represent columns within the table
+ that is the target of the INSERT, UPDATE, or DELETE. Each
+ element can be any column expression. ``Table`` objects
+ will be expanded into their individual columns.
+
+ Upon compilation, a RETURNING clause, or database equivalent,
+ will be rendered within the statement. For INSERT and UPDATE,
+ the values are the newly inserted/updated values. For DELETE,
+ the values are those of the rows which were deleted.
+
+ Upon execution, the values of the columns to be returned
+ are made available via the result set and can be iterated
+ using ``fetchone()`` and similar. For DBAPIs which do not
+ natively support returning values (i.e. cx_oracle),
+ SQLAlchemy will approximate this behavior at the result level
+ so that a reasonable amount of behavioral neutrality is
+ provided.
+
+ Note that not all databases/DBAPIs
+ support RETURNING. For those backends with no support,
+ an exception is raised upon compilation and/or execution.
+ For those who do support it, the functionality across backends
+ varies greatly, including restrictions on executemany()
+ and other statements which return multiple rows. Please
+ read the documentation notes for the database in use in
+ order to determine the availability of RETURNING.
+
+ """
+ self._returning = cols
+
class _ValuesBase(_UpdateBase):
__visit_name__ = 'values_base'
@_generative
def values(self, *args, **kwargs):
- """specify the VALUES clause for an INSERT statement, or the SET clause for an UPDATE.
+ """specify the VALUES clause for an INSERT statement, or the SET clause for an
+ UPDATE.
\**kwargs
key=<somevalue> arguments
\*args
- A single dictionary can be sent as the first positional argument. This allows
- non-string based keys, such as Column objects, to be used.
+ A single dictionary can be sent as the first positional argument. This
+ allows non-string based keys, such as Column objects, to be used.
"""
if args:
"""
__visit_name__ = 'insert'
- def __init__(self, table, values=None, inline=False, bind=None, prefixes=None, **kwargs):
+ def __init__(self,
+ table,
+ values=None,
+ inline=False,
+ bind=None,
+ prefixes=None,
+ returning=None,
+ **kwargs):
_ValuesBase.__init__(self, table, values)
self._bind = bind
self.select = None
self.inline = inline
+ self._returning = returning
if prefixes:
self._prefixes = [_literal_as_text(p) for p in prefixes]
else:
self._prefixes = []
- self.kwargs = kwargs
+
+ self.kwargs = self._process_deprecated_kw(kwargs)
def get_children(self, **kwargs):
if self.select is not None:
"""
__visit_name__ = 'update'
- def __init__(self, table, whereclause, values=None, inline=False, bind=None, **kwargs):
+ def __init__(self,
+ table,
+ whereclause,
+ values=None,
+ inline=False,
+ bind=None,
+ returning=None,
+ **kwargs):
_ValuesBase.__init__(self, table, values)
self._bind = bind
+ self._returning = returning
if whereclause:
self._whereclause = _literal_as_text(whereclause)
else:
self._whereclause = None
self.inline = inline
- self.kwargs = kwargs
+
+ self.kwargs = self._process_deprecated_kw(kwargs)
def get_children(self, **kwargs):
if self._whereclause is not None:
@_generative
def where(self, whereclause):
- """return a new update() construct with the given expression added to its WHERE clause, joined
- to the existing clause via AND, if any."""
-
+ """return a new update() construct with the given expression added to its WHERE
+ clause, joined to the existing clause via AND, if any.
+
+ """
if self._whereclause is not None:
self._whereclause = and_(self._whereclause, _literal_as_text(whereclause))
else:
__visit_name__ = 'delete'
- def __init__(self, table, whereclause, bind=None, **kwargs):
+ def __init__(self,
+ table,
+ whereclause,
+ bind=None,
+ returning =None,
+ **kwargs):
self._bind = bind
self.table = table
+ self._returning = returning
+
if whereclause:
self._whereclause = _literal_as_text(whereclause)
else:
self._whereclause = None
- self.kwargs = kwargs
+ self.kwargs = self._process_deprecated_kw(kwargs)
def get_children(self, **kwargs):
if self._whereclause is not None:
"""Defines operators used in SQL expressions."""
from operator import (
- and_, or_, inv, add, mul, sub, div, mod, truediv, lt, le, ne, gt, ge, eq
+ and_, or_, inv, add, mul, sub, mod, truediv, lt, le, ne, gt, ge, eq
)
+
+# Py2K
+from operator import (div,)
+# end Py2K
+
from sqlalchemy.util import symbol
_PRECEDENCE = {
from_: 15,
mul: 7,
+ truediv: 7,
+ # Py2K
div: 7,
+ # end Py2K
mod: 7,
add: 6,
sub: 6,
"""
def __init__(cls, clsname, bases, clsdict):
- if cls.__name__ == 'Visitable':
+ if cls.__name__ == 'Visitable' or not hasattr(cls, '__visit_name__'):
super(VisitableType, cls).__init__(clsname, bases, clsdict)
return
- assert hasattr(cls, '__visit_name__'), "`Visitable` descendants " \
- "should define `__visit_name__`"
-
# set up an optimized visit dispatch function
# for use by the compiler
visit_name = cls.__visit_name__
from sqlalchemy.engine.default import DefaultDialect
from sqlalchemy.engine.base import Connection
from sqlalchemy import util
-import testing
import re
class AssertRule(object):
-import optparse, os, sys, re, ConfigParser, StringIO, time, warnings
+import optparse, os, sys, re, ConfigParser, time, warnings
+
+# 2to3
+import StringIO
+
logging = None
__all__ = 'parser', 'configure', 'options',
[db]
sqlite=sqlite:///:memory:
sqlite_file=sqlite:///querytest.db
-postgres=postgres://scott:tiger@127.0.0.1:5432/test
+postgresql=postgresql://scott:tiger@127.0.0.1:5432/test
+postgres=postgresql://scott:tiger@127.0.0.1:5432/test
+pg8000=postgresql+pg8000://scott:tiger@127.0.0.1:5432/test
+postgresql_jython=postgresql+zxjdbc://scott:tiger@127.0.0.1:5432/test
+mysql_jython=mysql+zxjdbc://scott:tiger@127.0.0.1:5432/test
mysql=mysql://scott:tiger@127.0.0.1:3306/test
oracle=oracle://scott:tiger@127.0.0.1:1521
oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0
from sqlalchemy.test import engines
from sqlalchemy import schema
- try:
- # also create alt schemas etc. here?
- if options.dropfirst:
- e = engines.utf8_engine()
- existing = e.table_names()
- if existing:
- print "Dropping existing tables in database: " + db_url
- try:
- print "Tables: %s" % ', '.join(existing)
- except:
- pass
- print "Abort within 5 seconds..."
- time.sleep(5)
- md = schema.MetaData(e, reflect=True)
- md.drop_all()
- e.dispose()
- except (KeyboardInterrupt, SystemExit):
- raise
- except Exception, e:
- warnings.warn(RuntimeWarning(
- "Error checking for existing tables in testing "
- "database: %s" % e))
+ # also create alt schemas etc. here?
+ if options.dropfirst:
+ e = engines.utf8_engine()
+ existing = e.table_names()
+ if existing:
+ print "Dropping existing tables in database: " + db_url
+ try:
+ print "Tables: %s" % ', '.join(existing)
+ except:
+ pass
+ print "Abort within 5 seconds..."
+ time.sleep(5)
+ md = schema.MetaData(e, reflect=True)
+ md.drop_all()
+ e.dispose()
+
post_configure['prep_db'] = _prep_testing_database
def _set_table_options(options, file_config):
from collections import deque
import config
from sqlalchemy.util import function_named, callable
+import re
class ConnectionKiller(object):
def __init__(self):
self.proxy_refs[con_proxy] = True
def _apply_all(self, methods):
- for rec in self.proxy_refs:
+ # must copy keys atomically
+ for rec in self.proxy_refs.keys():
if rec is not None and rec.is_valid:
try:
for name in methods:
testing_reaper = ConnectionKiller()
+def drop_all_tables(metadata):
+ testing_reaper.close_all()
+ metadata.drop_all()
+
def assert_conns_closed(fn):
def decorated(*args, **kw):
try:
testing_reaper.rollback_all()
return function_named(decorated, fn.__name__)
+def close_first(fn):
+ """Decorator that closes all connections before fn execution."""
+ def decorated(*args, **kw):
+ testing_reaper.close_all()
+ fn(*args, **kw)
+ return function_named(decorated, fn.__name__)
+
+
def close_open_connections(fn):
"""Decorator that closes all connections after fn execution."""
def all_dialects():
import sqlalchemy.databases as d
for name in d.__all__:
- mod = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name)
+ # TEMPORARY
+ mod = getattr(d, name, None)
+ if not mod:
+ mod = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name)
yield mod.dialect()
class ReconnectFixture(object):
listeners.append(testing_reaper)
engine = create_engine(url, **options)
-
+
+ # may want to call this, results
+ # in first-connect initializers
+ #engine.connect()
+
return engine
def utf8_engine(url=None, options=None):
from sqlalchemy.engine import url as engine_url
- if config.db.name == 'mysql':
+ if config.db.driver == 'mysqldb':
dbapi_ver = config.db.dialect.dbapi.version_info
if (dbapi_ver < (1, 2, 1) or
dbapi_ver in ((1, 2, 1, 'gamma', 1), (1, 2, 1, 'gamma', 2),
return testing_engine(url, options)
-def mock_engine(db=None):
- """Provides a mocking engine based on the current testing.db."""
+def mock_engine(dialect_name=None):
+ """Provides a mocking engine based on the current testing.db.
+
+ This is normally used to test DDL generation flow as emitted
+ by an Engine.
+
+ It should not be used in other cases, as assert_compile() and
+ assert_sql_execution() are much better choices with fewer
+ moving parts.
+
+ """
from sqlalchemy import create_engine
- dbi = db or config.db
+ if not dialect_name:
+ dialect_name = config.db.name
+
buffer = []
def executor(sql, *a, **kw):
buffer.append(sql)
- engine = create_engine(dbi.name + '://',
+ def assert_sql(stmts):
+ recv = [re.sub(r'[\n\t]', '', str(s)) for s in buffer]
+ assert recv == stmts, recv
+
+ engine = create_engine(dialect_name + '://',
strategy='mock', executor=executor)
assert not hasattr(engine, 'mock')
engine.mock = buffer
+ engine.assert_sql = assert_sql
return engine
class ReplayableSession(object):
Natives = set([getattr(types, t)
for t in dir(types) if not t.startswith('_')]). \
difference([getattr(types, t)
+ # Py3K
+ #for t in ('FunctionType', 'BuiltinFunctionType',
+ # 'MethodType', 'BuiltinMethodType',
+ # 'LambdaType', )])
+
+ # Py2K
for t in ('FunctionType', 'BuiltinFunctionType',
'MethodType', 'BuiltinMethodType',
'LambdaType', 'UnboundMethodType',)])
+ # end Py2K
def __init__(self):
self.buffer = deque()
_set_table_options, _reverse_topological, _log
from sqlalchemy.test import testing, config, requires
from nose.plugins import Plugin
-from nose.util import tolist
+from sqlalchemy import util
import nose.case
log = logging.getLogger('nose.plugins.sqlalchemy')
def options(self, parser, env=os.environ):
Plugin.options(self, parser, env)
opt = parser.add_option
- #opt("--verbose", action="store_true", dest="verbose",
- #help="enable stdout echoing/printing")
- #opt("--quiet", action="store_true", dest="quiet", help="suppress output")
opt("--log-info", action="callback", type="string", callback=_log,
help="turn on info logging for <LOG> (multiple OK)")
opt("--log-debug", action="callback", type="string", callback=_log,
def configure(self, options, conf):
Plugin.configure(self, options, conf)
-
- import testing, requires
+ self.options = options
+
+ def begin(self):
testing.db = db
testing.requires = requires
# Lazy setup of other options (post coverage)
for fn in post_configure:
- fn(options, file_config)
-
+ fn(self.options, file_config)
+
def describeTest(self, test):
return ""
if check(test_suite)() != 'ok':
# The requirement will perform messaging.
return True
- if (hasattr(cls, '__unsupported_on__') and
- testing.db.name in cls.__unsupported_on__):
- print "'%s' unsupported on DB implementation '%s'" % (
- cls.__class__.__name__, testing.db.name)
- return True
- if (getattr(cls, '__only_on__', None) not in (None, testing.db.name)):
- print "'%s' unsupported on DB implementation '%s'" % (
- cls.__class__.__name__, testing.db.name)
- return True
+
+ if cls.__unsupported_on__:
+ spec = testing.db_spec(*cls.__unsupported_on__)
+ if spec(testing.db):
+ print "'%s' unsupported on DB implementation '%s'" % (
+ cls.__class__.__name__, testing.db.name)
+ return True
+ if getattr(cls, '__only_on__', None):
+ spec = testing.db_spec(*util.to_list(cls.__only_on__))
+ if not spec(testing.db):
+ print "'%s' unsupported on DB implementation '%s'" % (
+ cls.__class__.__name__, testing.db.name)
+ return True
+
if (getattr(cls, '__skip_if__', False)):
for c in getattr(cls, '__skip_if__'):
if c():
return True
return False
- #def begin(self):
- #pass
-
def beforeTest(self, test):
testing.resetwarnings()
def afterTest(self, test):
testing.resetwarnings()
+ def afterContext(self):
+ testing.global_cleanup_assertions()
+
#def handleError(self, test, err):
#pass
"""
import os, sys
-from sqlalchemy.util import function_named
-import config
+from sqlalchemy.test import config
+from sqlalchemy.test.util import function_named, gc_collect
+from nose import SkipTest
__all__ = 'profiled', 'function_call_count', 'conditional_call_count'
def _profile(filename, fn, *args, **kw):
global profiler
if not profiler:
- profiler = 'hotshot'
if sys.version_info > (2, 5):
try:
import cProfile
profiler = 'cProfile'
except ImportError:
pass
+ if not profiler:
+ try:
+ import hotshot
+ profiler = 'hotshot'
+ except ImportError:
+ profiler = 'skip'
- if profiler == 'cProfile':
+ if profiler == 'skip':
+ raise SkipTest('Profiling not supported on this platform')
+ elif profiler == 'cProfile':
return _profile_cProfile(filename, fn, *args, **kw)
else:
return _profile_hotshot(filename, fn, *args, **kw)
import cProfile, gc, pstats, time
load_stats = lambda: pstats.Stats(filename)
- gc.collect()
+ gc_collect()
began = time.time()
cProfile.runctx('result = fn(*args, **kw)', globals(), locals(),
import gc, hotshot, hotshot.stats, time
load_stats = lambda: hotshot.stats.load(filename)
- gc.collect()
+ gc_collect()
prof = hotshot.Profile(filename)
began = time.time()
prof.start()
no_support('sqlite', 'not supported by database'),
)
+
+def unbounded_varchar(fn):
+ """Target database must support VARCHAR with no length"""
+ return _chain_decorators_on(
+ fn,
+ no_support('firebird', 'not supported by database'),
+ no_support('oracle', 'not supported by database'),
+ no_support('mysql', 'not supported by database'),
+ )
+
+def boolean_col_expressions(fn):
+ """Target database must support boolean expressions as columns"""
+ return _chain_decorators_on(
+ fn,
+ no_support('firebird', 'not supported by database'),
+ no_support('oracle', 'not supported by database'),
+ no_support('mssql', 'not supported by database'),
+ )
+
def identity(fn):
"""Target database must support GENERATED AS IDENTITY or a facsimile.
fn,
no_support('firebird', 'not supported by database'),
no_support('oracle', 'not supported by database'),
- no_support('postgres', 'not supported by database'),
+ no_support('postgresql', 'not supported by database'),
no_support('sybase', 'not supported by database'),
)
# no access to same table
no_support('mysql', 'requires SUPER priv'),
exclude('mysql', '<', (5, 0, 10), 'not supported by database'),
- no_support('postgres', 'not supported by database: no statements'),
+
+ # huh? TODO: implement triggers for PG tests, remove this
+ no_support('postgresql', 'PG triggers need to be implemented for tests'),
)
+def correlated_outer_joins(fn):
+ """Target must support an outer join to a subquery which correlates to the parent."""
+
+ return _chain_decorators_on(
+ fn,
+ no_support('oracle', 'Raises "ORA-01799: a column may not be outer-joined to a subquery"')
+ )
+
def savepoints(fn):
"""Target database must support savepoints."""
return _chain_decorators_on(
exclude('mysql', '<', (5, 0, 3), 'not supported by database'),
)
+def schemas(fn):
+ """Target database must support external schemas, and have one named 'test_schema'."""
+
+ return _chain_decorators_on(
+ fn,
+ no_support('sqlite', 'no schema support'),
+ no_support('firebird', 'no schema support')
+ )
+
def sequences(fn):
"""Target database must support SEQUENCEs."""
return _chain_decorators_on(
exclude('mysql', '<', (4, 1, 1), 'no subquery support'),
)
+def returning(fn):
+ return _chain_decorators_on(
+ fn,
+ no_support('access', 'not supported by database'),
+ no_support('sqlite', 'not supported by database'),
+ no_support('mysql', 'not supported by database'),
+ no_support('maxdb', 'not supported by database'),
+ no_support('sybase', 'not supported by database'),
+ no_support('informix', 'not supported by database'),
+ )
+
def two_phase_transactions(fn):
"""Target database must support two-phase transactions."""
return _chain_decorators_on(
no_support('oracle', 'no SA implementation'),
no_support('sqlite', 'not supported by database'),
no_support('sybase', 'FIXME: guessing, needs confirmation'),
+ no_support('postgresql+zxjdbc', 'FIXME: JDBC driver confuses the transaction state, may '
+ 'need separate XA implementation'),
exclude('mysql', '<', (5, 0, 3), 'not supported by database'),
)
# expand to ForeignKeyConstraint too.
fks = [fk
for col in args if isinstance(col, schema.Column)
- for fk in col.args if isinstance(fk, schema.ForeignKey)]
+ for fk in col.foreign_keys]
for fk in fks:
# root around in raw spec
if fk.onupdate is None:
fk.onupdate = 'CASCADE'
- if testing.against('firebird', 'oracle'):
- pk_seqs = [col for col in args
- if (isinstance(col, schema.Column)
- and col.primary_key
- and getattr(col, '_needs_autoincrement', False))]
- for c in pk_seqs:
- c.args.append(schema.Sequence(args[0] + '_' + c.name + '_seq', optional=True))
return schema.Table(*args, **kw)
test_opts = dict([(k,kw.pop(k)) for k in kw.keys()
if k.startswith('test_')])
- c = schema.Column(*args, **kw)
- if testing.against('firebird', 'oracle'):
- if 'test_needs_autoincrement' in test_opts:
- c._needs_autoincrement = True
- return c
+ col = schema.Column(*args, **kw)
+ if 'test_needs_autoincrement' in test_opts and \
+ kw.get('primary_key', False) and \
+ testing.against('firebird', 'oracle'):
+ def add_seq(tbl):
+ col._init_items(
+ schema.Sequence(_truncate_name(testing.db.dialect, tbl.name + '_' + col.name + '_seq'), optional=True)
+ )
+ col._on_table_attach(add_seq)
+ return col
+
+def _truncate_name(dialect, name):
+ if len(name) > dialect.max_identifier_length:
+ return name[0:max(dialect.max_identifier_length - 6, 0)] + "_" + hex(hash(name) % 64)[2:]
+ else:
+ return name
+
\ No newline at end of file
import warnings
from cStringIO import StringIO
-from sqlalchemy.test import config, assertsql
+from sqlalchemy.test import config, assertsql, util as testutil
from sqlalchemy.util import function_named
+from engines import drop_all_tables
-from sqlalchemy import exc as sa_exc, util, types as sqltypes, schema
+from sqlalchemy import exc as sa_exc, util, types as sqltypes, schema, pool
+from nose import SkipTest
_ops = { '<': operator.lt,
'>': operator.gt,
"Unexpected success for future test '%s'" % fn_name)
return function_named(decorated, fn_name)
+def db_spec(*dbs):
+ dialects = set([x for x in dbs if '+' not in x])
+ drivers = set([x[1:] for x in dbs if x.startswith('+')])
+ specs = set([tuple(x.split('+')) for x in dbs if '+' in x and x not in drivers])
+
+ def check(engine):
+ return engine.name in dialects or \
+ engine.driver in drivers or \
+ (engine.name, engine.driver) in specs
+
+ return check
+
+
def fails_on(dbs, reason):
"""Mark a test as expected to fail on the specified database
implementation.
succeeds, a failure is reported.
"""
+ spec = db_spec(dbs)
+
def decorate(fn):
fn_name = fn.__name__
def maybe(*args, **kw):
- if config.db.name != dbs:
+ if not spec(config.db):
return fn(*args, **kw)
else:
try:
fn(*args, **kw)
except Exception, ex:
print ("'%s' failed as expected on DB implementation "
- "'%s': %s" % (
- fn_name, config.db.name, reason))
+ "'%s+%s': %s" % (
+ fn_name, config.db.name, config.db.driver, reason))
return True
else:
raise AssertionError(
- "Unexpected success for '%s' on DB implementation '%s'" %
- (fn_name, config.db.name))
+ "Unexpected success for '%s' on DB implementation '%s+%s'" %
+ (fn_name, config.db.name, config.db.driver))
return function_named(maybe, fn_name)
return decorate
databases except those listed.
"""
+ spec = db_spec(*dbs)
+
def decorate(fn):
fn_name = fn.__name__
def maybe(*args, **kw):
- if config.db.name in dbs:
+ if spec(config.db):
return fn(*args, **kw)
else:
try:
fn(*args, **kw)
except Exception, ex:
print ("'%s' failed as expected on DB implementation "
- "'%s': %s" % (
- fn_name, config.db.name, str(ex)))
+ "'%s+%s': %s" % (
+ fn_name, config.db.name, config.db.driver, str(ex)))
return True
else:
raise AssertionError(
- "Unexpected success for '%s' on DB implementation '%s'" %
- (fn_name, config.db.name))
+ "Unexpected success for '%s' on DB implementation '%s+%s'" %
+ (fn_name, config.db.name, config.db.driver))
return function_named(maybe, fn_name)
return decorate
"""
carp = _should_carp_about_exclusion(reason)
+ spec = db_spec(db)
def decorate(fn):
fn_name = fn.__name__
def maybe(*args, **kw):
- if config.db.name == db:
- msg = "'%s' unsupported on DB implementation '%s': %s" % (
- fn_name, config.db.name, reason)
+ if spec(config.db):
+ msg = "'%s' unsupported on DB implementation '%s+%s': %s" % (
+ fn_name, config.db.name, config.db.driver, reason)
print msg
if carp:
print >> sys.stderr, msg
"""
carp = _should_carp_about_exclusion(reason)
+ spec = db_spec(db)
def decorate(fn):
fn_name = fn.__name__
def maybe(*args, **kw):
- if config.db.name == db:
- msg = "'%s' unsupported on DB implementation '%s': %s" % (
- fn_name, config.db.name, reason)
+ if spec(config.db):
+ msg = "'%s' unsupported on DB implementation '%s+%s': %s" % (
+ fn_name, config.db.name, config.db.driver, reason)
print msg
if carp:
print >> sys.stderr, msg
"""
carp = _should_carp_about_exclusion(reason)
+
def decorate(fn):
fn_name = fn.__name__
def maybe(*args, **kw):
_is_excluded('yikesdb', 'in', ((0, 3, 'alpha2'), (0, 3, 'alpha3')))
"""
- if config.db.name != db:
+ vendor_spec = db_spec(db)
+
+ if not vendor_spec(config.db):
return False
version = _server_version()
if bind is None:
bind = config.db
- return bind.dialect.server_version_info(bind.contextual_connect())
+
+ # force metadata to be retrieved
+ conn = bind.connect()
+ version = getattr(bind.dialect, 'server_version_info', ())
+ conn.close()
+ return version
def skip_if(predicate, reason=None):
"""Skip a test if predicate is true."""
if predicate():
msg = "'%s' skipped on DB %s version '%s': %s" % (
fn_name, config.db.name, _server_version(), reason)
- print msg
- return True
+ raise SkipTest(msg)
else:
return fn(*args, **kw)
return function_named(maybe, fn_name)
strings; these will be matched to the root of the warning description by
warnings.filterwarnings().
"""
+ spec = db_spec(db)
+
def decorate(fn):
def maybe(*args, **kw):
if isinstance(db, basestring):
- if config.db.name != db:
+ if not spec(config.db):
return fn(*args, **kw)
else:
wrapped = emits_warning(*warnings)(fn)
if sys.version_info < (2, 4):
warnings.filterwarnings('ignore', category=FutureWarning)
+def global_cleanup_assertions():
+ """Check things that have to be finalized at the end of a test suite.
+
+ Hardcoded at the moment, a modular system can be built here
+ to support things like PG prepared transactions, tables all
+ dropped, etc.
+
+ """
+
+ testutil.lazy_gc()
+ assert not pool._refs
+
+
def against(*queries):
"""Boolean predicate, compares to testing database configuration.
Also supports comparison to database version when provided with one or
more 3-tuples of dialect name, operator, and version specification::
- testing.against('mysql', 'postgres')
+ testing.against('mysql', 'postgresql')
testing.against(('mysql', '>=', (5, 0, 0))
"""
for query in queries:
if isinstance(query, basestring):
- if config.db.name == query:
+ if db_spec(query)(config.db):
return True
else:
name, op, spec = query
- if config.db.name != name:
+ if not db_spec(name)(config.db):
continue
- have = config.db.dialect.server_version_info(
- config.db.contextual_connect())
+ have = _server_version()
oper = hasattr(op, '__call__') and op or _ops[op]
if oper(have, spec):
if dialect is None:
dialect = getattr(self, '__dialect__', None)
- if params is None:
- keys = None
- else:
- keys = params.keys()
+ kw = {}
+ if params is not None:
+ kw['column_keys'] = params.keys()
- c = clause.compile(column_keys=keys, dialect=dialect)
+ c = clause.compile(dialect=dialect, **kw)
- print "\nSQL String:\n" + str(c) + repr(c.params)
+ print "\nSQL String:\n" + str(c) + repr(getattr(c, 'params', {}))
- cc = re.sub(r'\n', '', str(c))
+ cc = re.sub(r'[\n\t]', '', str(c))
eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect))
class ComparesTables(object):
def assert_tables_equal(self, table, reflected_table):
- base_mro = sqltypes.TypeEngine.__mro__
assert len(table.c) == len(reflected_table.c)
for c, reflected_c in zip(table.c, reflected_table.c):
eq_(c.name, reflected_c.name)
assert reflected_c is reflected_table.c[c.name]
eq_(c.primary_key, reflected_c.primary_key)
eq_(c.nullable, reflected_c.nullable)
- assert len(
- set(type(reflected_c.type).__mro__).difference(base_mro).intersection(
- set(type(c.type).__mro__).difference(base_mro)
- )
- ) > 0, "Type '%s' doesn't correspond to type '%s'" % (reflected_c.type, c.type)
+ self.assert_types_base(reflected_c, c)
if isinstance(c.type, sqltypes.String):
eq_(c.type.length, reflected_c.type.length)
elif against(('mysql', '<', (5, 0))):
# ignore reflection of bogus db-generated DefaultClause()
pass
- elif not c.primary_key or not against('postgres'):
- print repr(c)
+ elif not c.primary_key or not against('postgresql', 'mssql'):
+ #print repr(c)
assert reflected_c.default is None, reflected_c.default
assert len(table.primary_key) == len(reflected_table.primary_key)
for c in table.primary_key:
assert reflected_table.primary_key.columns[c.name]
-
+
+ def assert_types_base(self, c1, c2):
+ base_mro = sqltypes.TypeEngine.__mro__
+ assert len(
+ set(type(c1.type).__mro__).difference(base_mro).intersection(
+ set(type(c2.type).__mro__).difference(base_mro)
+ )
+ ) > 0, "On column %r, type '%s' doesn't correspond to type '%s'" % (c1.name, c1.type, c2.type)
class AssertsExecutionResults(object):
def assert_result(self, result, class_, *objects):
assertsql.asserter.clear_rules()
def assert_sql(self, db, callable_, list_, with_sequences=None):
- if with_sequences is not None and config.db.name in ('firebird', 'oracle', 'postgres'):
+ if with_sequences is not None and config.db.name in ('firebird', 'oracle', 'postgresql'):
rules = with_sequences
else:
rules = list_
--- /dev/null
+from sqlalchemy.util import jython, function_named
+
+import gc
+import time
+
+if jython:
+ def gc_collect(*args):
+ """aggressive gc.collect for tests."""
+ gc.collect()
+ time.sleep(0.1)
+ gc.collect()
+ gc.collect()
+ return 0
+
+ # "lazy" gc, for VM's that don't GC on refcount == 0
+ lazy_gc = gc_collect
+
+else:
+ # assume CPython - straight gc.collect, lazy_gc() is a pass
+ gc_collect = gc.collect
+ def lazy_gc():
+ pass
+
+
edges = _EdgeCollection()
for item in list(allitems) + [t[0] for t in tuples] + [t[1] for t in tuples]:
- if id(item) not in nodes:
- node = _Node(item)
- nodes[item] = node
+ item_id = id(item)
+ if item_id not in nodes:
+ nodes[item_id] = _Node(item)
for t in tuples:
+ id0, id1 = id(t[0]), id(t[1])
if t[0] is t[1]:
if allow_cycles:
- n = nodes[t[0]]
+ n = nodes[id0]
n.cycles = set([n])
elif not ignore_self_cycles:
raise CircularDependencyError("Self-referential dependency detected " + repr(t))
continue
- childnode = nodes[t[1]]
- parentnode = nodes[t[0]]
+ childnode = nodes[id1]
+ parentnode = nodes[id0]
edges.add((parentnode, childnode))
queue = []
node = queue.pop()
if not hasattr(node, '_cyclical'):
output.append(node)
- del nodes[node.item]
+ del nodes[id(node.item)]
for childnode in edges.pop_node(node):
queue.append(childnode)
return output
for parent in edges.get_parents():
traverse(parent)
- # sets are not hashable, so uniquify with id
- unique_cycles = dict((id(s), s) for s in cycles.values()).values()
+ unique_cycles = set(tuple(s) for s in cycles.values())
+
for cycle in unique_cycles:
edgecollection = [edge for edge in edges
if edge[0] in cycle and edge[1] in cycle]
For more information see the SQLAlchemy documentation on types.
"""
-__all__ = [ 'TypeEngine', 'TypeDecorator', 'AbstractType',
- 'INT', 'CHAR', 'VARCHAR', 'NCHAR', 'TEXT', 'Text', 'FLOAT',
- 'NUMERIC', 'DECIMAL', 'TIMESTAMP', 'DATETIME', 'CLOB', 'BLOB',
- 'BOOLEAN', 'SMALLINT', 'DATE', 'TIME',
- 'String', 'Integer', 'SmallInteger','Smallinteger',
+__all__ = [ 'TypeEngine', 'TypeDecorator', 'AbstractType', 'UserDefinedType',
+ 'INT', 'CHAR', 'VARCHAR', 'NCHAR', 'NVARCHAR','TEXT', 'Text', 'FLOAT',
+ 'NUMERIC', 'DECIMAL', 'TIMESTAMP', 'DATETIME', 'CLOB', 'NCLOB', 'BLOB',
+ 'BOOLEAN', 'SMALLINT', 'INTEGER','DATE', 'TIME',
+ 'String', 'Integer', 'SmallInteger',
'Numeric', 'Float', 'DateTime', 'Date', 'Time', 'Binary',
'Boolean', 'Unicode', 'MutableType', 'Concatenable', 'UnicodeText', 'PickleType', 'Interval',
'type_map'
import inspect
import datetime as dt
from decimal import Decimal as _python_Decimal
-import weakref
+
from sqlalchemy import exc
from sqlalchemy.util import pickle
+from sqlalchemy.sql.visitors import Visitable
import sqlalchemy.util as util
NoneType = type(None)
+if util.jython:
+ import array
-class AbstractType(object):
+class AbstractType(Visitable):
def __init__(self, *args, **kwargs):
pass
+ def compile(self, dialect):
+ return dialect.type_compiler.process(self)
+
def copy_value(self, value):
return value
for k in inspect.getargspec(self.__init__)[0][1:]))
class TypeEngine(AbstractType):
- """Base for built-in types.
-
- May be sub-classed to create entirely new types. Example::
-
- import sqlalchemy.types as types
-
- class MyType(types.TypeEngine):
- def __init__(self, precision = 8):
- self.precision = precision
-
- def get_col_spec(self):
- return "MYTYPE(%s)" % self.precision
-
- def bind_processor(self, dialect):
- def process(value):
- return value
- return process
-
- def result_processor(self, dialect):
- def process(value):
- return value
- return process
-
- Once the type is made, it's immediately usable::
-
- table = Table('foo', meta,
- Column('id', Integer, primary_key=True),
- Column('data', MyType(16))
- )
-
- """
+ """Base for built-in types."""
+ @util.memoized_property
+ def _impl_dict(self):
+ return {}
+
def dialect_impl(self, dialect, **kwargs):
try:
- return self._impl_dict[dialect]
- except AttributeError:
- self._impl_dict = weakref.WeakKeyDictionary() # will be optimized in 0.6
- return self._impl_dict.setdefault(dialect, dialect.type_descriptor(self))
+ return self._impl_dict[dialect.__class__]
except KeyError:
- return self._impl_dict.setdefault(dialect, dialect.type_descriptor(self))
+ return self._impl_dict.setdefault(dialect.__class__, dialect.__class__.type_descriptor(self))
def __getstate__(self):
d = self.__dict__.copy()
d.pop('_impl_dict', None)
return d
- def get_col_spec(self):
- """Return the DDL representation for this type."""
- raise NotImplementedError()
-
def bind_processor(self, dialect):
"""Return a conversion function for processing bind values.
def adapt(self, cls):
return cls()
- def get_search_list(self):
- """return a list of classes to test for a match
- when adapting this type to a dialect-specific type.
+class UserDefinedType(TypeEngine):
+ """Base for user defined types.
+
+ This should be the base of new types. Note that
+ for most cases, :class:`TypeDecorator` is probably
+ more appropriate.
- """
+ import sqlalchemy.types as types
- return self.__class__.__mro__[0:-1]
+ class MyType(types.UserDefinedType):
+ def __init__(self, precision = 8):
+ self.precision = precision
+ def get_col_spec(self):
+ return "MYTYPE(%s)" % self.precision
+
+ def bind_processor(self, dialect):
+ def process(value):
+ return value
+ return process
+
+ def result_processor(self, dialect):
+ def process(value):
+ return value
+ return process
+
+ Once the type is made, it's immediately usable::
+
+ table = Table('foo', meta,
+ Column('id', Integer, primary_key=True),
+ Column('data', MyType(16))
+ )
+
+ """
+ __visit_name__ = "user_defined"
+
class TypeDecorator(AbstractType):
"""Allows the creation of types which add additional functionality
to an existing type.
"""
+ __visit_name__ = "type_decorator"
+
def __init__(self, *args, **kwargs):
if not hasattr(self.__class__, 'impl'):
- raise AssertionError("TypeDecorator implementations require a class-level variable 'impl' which refers to the class of type being decorated")
+ raise AssertionError("TypeDecorator implementations require a class-level "
+ "variable 'impl' which refers to the class of type being decorated")
self.impl = self.__class__.impl(*args, **kwargs)
- def dialect_impl(self, dialect, **kwargs):
+ def dialect_impl(self, dialect):
try:
- return self._impl_dict[dialect]
+ return self._impl_dict[dialect.__class__]
except AttributeError:
- self._impl_dict = weakref.WeakKeyDictionary() # will be optimized in 0.6
+ self._impl_dict = {}
except KeyError:
pass
+ # adapt the TypeDecorator first, in
+ # the case that the dialect maps the TD
+ # to one of its native types (i.e. PGInterval)
+ adapted = dialect.__class__.type_descriptor(self)
+ if adapted is not self:
+ self._impl_dict[dialect] = adapted
+ return adapted
+
+ # otherwise adapt the impl type, link
+ # to a copy of this TypeDecorator and return
+ # that.
typedesc = self.load_dialect_impl(dialect)
tt = self.copy()
if not isinstance(tt, self.__class__):
self._impl_dict[dialect] = tt
return tt
+ def type_engine(self, dialect):
+ impl = self.dialect_impl(dialect)
+ if not isinstance(impl, TypeDecorator):
+ return impl
+ else:
+ return impl.impl
+
def load_dialect_impl(self, dialect):
"""Loads the dialect-specific implementation of this type.
by default calls dialect.type_descriptor(self.impl), but
can be overridden to provide different behavior.
+
"""
-
if isinstance(self.impl, TypeDecorator):
return self.impl.dialect_impl(dialect)
else:
- return dialect.type_descriptor(self.impl)
+ return dialect.__class__.type_descriptor(self.impl)
def __getattr__(self, key):
"""Proxy all other undefined accessors to the underlying implementation."""
return getattr(self.impl, key)
- def get_col_spec(self):
- return self.impl.get_col_spec()
-
def process_bind_param(self, value, dialect):
raise NotImplementedError()
def adapt_type(typeobj, colspecs):
if isinstance(typeobj, type):
typeobj = typeobj()
- for t in typeobj.get_search_list():
+ for t in typeobj.__class__.__mro__[0:-1]:
try:
impltype = colspecs[t]
break
encountered during a :meth:`~sqlalchemy.Table.create` operation.
"""
-
- def get_col_spec(self):
- raise NotImplementedError()
+ __visit_name__ = 'null'
NullTypeEngine = NullType
"""
+ __visit_name__ = 'string'
+
def __init__(self, length=None, convert_unicode=False, assert_unicode=None):
"""
Create a string-holding type.
self.assert_unicode = assert_unicode
def adapt(self, impltype):
- return impltype(length=self.length, convert_unicode=self.convert_unicode, assert_unicode=self.assert_unicode)
+ return impltype(
+ length=self.length,
+ convert_unicode=self.convert_unicode,
+ assert_unicode=self.assert_unicode)
def bind_processor(self, dialect):
if self.convert_unicode or dialect.convert_unicode:
assert_unicode = dialect.assert_unicode
else:
assert_unicode = self.assert_unicode
- def process(value):
- if isinstance(value, unicode):
- return value.encode(dialect.encoding)
- elif assert_unicode and not isinstance(value, (unicode, NoneType)):
- if assert_unicode == 'warn':
- util.warn("Unicode type received non-unicode bind "
- "param value %r" % value)
+
+ if dialect.supports_unicode_binds and assert_unicode:
+ def process(value):
+ if not isinstance(value, (unicode, NoneType)):
+ if assert_unicode == 'warn':
+ util.warn("Unicode type received non-unicode bind "
+ "param value %r" % value)
+ return value
+ else:
+ raise exc.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value)
+ else:
return value
+ elif dialect.supports_unicode_binds:
+ return None
+ else:
+ def process(value):
+ if isinstance(value, unicode):
+ return value.encode(dialect.encoding)
+ elif assert_unicode and not isinstance(value, (unicode, NoneType)):
+ if assert_unicode == 'warn':
+ util.warn("Unicode type received non-unicode bind "
+ "param value %r" % value)
+ return value
+ else:
+ raise exc.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value)
else:
- raise exc.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value)
- else:
- return value
+ return value
return process
else:
return None
params (and the reverse for result sets.)
"""
+ __visit_name__ = 'text'
class Unicode(String):
"""A variable length Unicode string.
:attr:`~sqlalchemy.engine.base.Dialect.encoding`, which defaults
to `utf-8`.
- A synonym for String(length, convert_unicode=True, assert_unicode='warn').
-
"""
+ __visit_name__ = 'unicode'
+
def __init__(self, length=None, **kwargs):
"""
Create a Unicode-converting String type.
super(Unicode, self).__init__(length=length, **kwargs)
class UnicodeText(Text):
- """A synonym for Text(convert_unicode=True, assert_unicode='warn')."""
+ """An unbounded-length Unicode string.
+
+ See :class:`Unicode` for details on the unicode
+ behavior of this object.
+
+ """
+
+ __visit_name__ = 'unicode_text'
def __init__(self, length=None, **kwargs):
"""
class Integer(TypeEngine):
"""A type for ``int`` integers."""
-
+
+ __visit_name__ = 'integer'
+
def get_dbapi_type(self, dbapi):
return dbapi.NUMBER
"""
-Smallinteger = SmallInteger
+ __visit_name__ = 'small_integer'
+
+class BigInteger(Integer):
+ """A type for bigger ``int`` integers.
+
+ Typically generates a ``BIGINT`` in DDL, and otherwise acts like
+ a normal :class:`Integer` on the Python side.
+
+ """
+
+ __visit_name__ = 'big_integer'
class Numeric(TypeEngine):
"""A type for fixed precision numbers.
"""
- def __init__(self, precision=10, scale=2, asdecimal=True, length=None):
+ __visit_name__ = 'numeric'
+
+ def __init__(self, precision=None, scale=None, asdecimal=True):
"""
Construct a Numeric.
use.
"""
- if length:
- util.warn_deprecated("'length' is deprecated for Numeric. Use 'scale'.")
- scale = length
self.precision = precision
self.scale = scale
self.asdecimal = asdecimal
class Float(Numeric):
"""A type for ``float`` numbers."""
- def __init__(self, precision=10, asdecimal=False, **kwargs):
+ __visit_name__ = 'float'
+
+ def __init__(self, precision=None, asdecimal=False, **kwargs):
"""
Construct a Float.
converted back to datetime objects when rows are returned.
"""
-
+
+ __visit_name__ = 'datetime'
+
def __init__(self, timezone=False):
self.timezone = timezone
class Date(TypeEngine):
"""A type for ``datetime.date()`` objects."""
+ __visit_name__ = 'date'
+
def get_dbapi_type(self, dbapi):
return dbapi.DATETIME
class Time(TypeEngine):
"""A type for ``datetime.time()`` objects."""
+ __visit_name__ = 'time'
+
def __init__(self, timezone=False):
self.timezone = timezone
"""
+ __visit_name__ = 'binary'
+
def __init__(self, length=None):
"""
Construct a Binary type.
loads = self.pickler.loads
if value is None:
return None
- return loads(str(value))
+ # Py3K
+ #return loads(value)
+ # Py2K
+ if util.jython and isinstance(value, array.ArrayType):
+ value = value.tostring()
+ else:
+ value = str(value)
+ return loads(value)
+ # end Py2K
def copy_value(self, value):
if self.mutable:
def compare_values(self, x, y):
if self.comparator:
return self.comparator(x, y)
- elif self.mutable and not hasattr(x, '__eq__') and x is not None:
- util.warn_deprecated("Objects stored with PickleType when mutable=True must implement __eq__() for reliable comparison.")
- return self.pickler.dumps(x, self.protocol) == self.pickler.dumps(y, self.protocol)
else:
return x == y
"""
+ __visit_name__ = 'boolean'
class Interval(TypeDecorator):
"""A type for ``datetime.timedelta()`` objects.
"""
- impl = TypeEngine
-
- def __init__(self):
- super(Interval, self).__init__()
- import sqlalchemy.databases.postgres as pg
- self.__supported = {pg.PGDialect:pg.PGInterval}
- del pg
-
- def load_dialect_impl(self, dialect):
- if dialect.__class__ in self.__supported:
- return self.__supported[dialect.__class__]()
- else:
- return dialect.type_descriptor(DateTime)
+ impl = DateTime
def process_bind_param(self, value, dialect):
- if dialect.__class__ in self.__supported:
- return value
- else:
- if value is None:
- return None
- return dt.datetime.utcfromtimestamp(0) + value
+ if value is None:
+ return None
+ return dt.datetime.utcfromtimestamp(0) + value
def process_result_value(self, value, dialect):
- if dialect.__class__ in self.__supported:
- return value
- else:
- if value is None:
- return None
- return value - dt.datetime.utcfromtimestamp(0)
+ if value is None:
+ return None
+ return value - dt.datetime.utcfromtimestamp(0)
class FLOAT(Float):
"""The SQL FLOAT type."""
+ __visit_name__ = 'FLOAT'
class NUMERIC(Numeric):
"""The SQL NUMERIC type."""
+ __visit_name__ = 'NUMERIC'
+
class DECIMAL(Numeric):
"""The SQL DECIMAL type."""
+ __visit_name__ = 'DECIMAL'
-class INT(Integer):
+
+class INTEGER(Integer):
"""The SQL INT or INTEGER type."""
+ __visit_name__ = 'INTEGER'
+INT = INTEGER
-INTEGER = INT
-class SMALLINT(Smallinteger):
+class SMALLINT(SmallInteger):
"""The SQL SMALLINT type."""
+ __visit_name__ = 'SMALLINT'
+
+
+class BIGINT(BigInteger):
+ """The SQL BIGINT type."""
+
+ __visit_name__ = 'BIGINT'
class TIMESTAMP(DateTime):
"""The SQL TIMESTAMP type."""
+ __visit_name__ = 'TIMESTAMP'
+
+ def get_dbapi_type(self, dbapi):
+ return dbapi.TIMESTAMP
class DATETIME(DateTime):
"""The SQL DATETIME type."""
+ __visit_name__ = 'DATETIME'
+
class DATE(Date):
"""The SQL DATE type."""
+ __visit_name__ = 'DATE'
+
class TIME(Time):
"""The SQL TIME type."""
+ __visit_name__ = 'TIME'
-TEXT = Text
+class TEXT(Text):
+ """The SQL TEXT type."""
+
+ __visit_name__ = 'TEXT'
class CLOB(Text):
- """The SQL CLOB type."""
+ """The CLOB type.
+
+ This type is found in Oracle and Informix.
+ """
+ __visit_name__ = 'CLOB'
class VARCHAR(String):
"""The SQL VARCHAR type."""
+ __visit_name__ = 'VARCHAR'
+
+class NVARCHAR(Unicode):
+ """The SQL NVARCHAR type."""
+
+ __visit_name__ = 'NVARCHAR'
class CHAR(String):
"""The SQL CHAR type."""
+ __visit_name__ = 'CHAR'
+
class NCHAR(Unicode):
"""The SQL NCHAR type."""
+ __visit_name__ = 'NCHAR'
+
class BLOB(Binary):
"""The SQL BLOB type."""
+ __visit_name__ = 'BLOB'
+
class BOOLEAN(Boolean):
"""The SQL BOOLEAN type."""
+ __visit_name__ = 'BOOLEAN'
+
NULLTYPE = NullType()
# using VARCHAR/NCHAR so that we dont get the genericized "String"
# type which usually resolves to TEXT/CLOB
type_map = {
- str : VARCHAR,
- unicode : NCHAR,
+ str: String,
+ # Py2K
+ unicode : String,
+ # end Py2K
int : Integer,
float : Numeric,
bool: Boolean,
dt.datetime : DateTime,
dt.time : Time,
dt.timedelta : Interval,
- type(None): NullType
+ NoneType: NullType
}
+
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-import inspect, itertools, operator, sys, warnings, weakref
+import inspect, itertools, operator, sys, warnings, weakref, gc
+# Py2K
import __builtin__
+# end Py2K
types = __import__('types')
from sqlalchemy import exc
import dummy_threading as threading
py3k = getattr(sys, 'py3kwarning', False) or sys.version_info >= (3, 0)
+jython = sys.platform.startswith('java')
if py3k:
set_types = set
EMPTY_SET = frozenset()
+NoneType = type(None)
+
if py3k:
import pickle
else:
except ImportError:
import pickle
+# Py2K
# a controversial feature, required by MySQLdb currently
def buffer(x):
return x
buffer = getattr(__builtin__, 'buffer', buffer)
+# end Py2K
if sys.version_info >= (2, 5):
class PopulateDict(dict):
if py3k:
def callable(fn):
return hasattr(fn, '__call__')
-else:
- callable = __builtin__.callable
+ def cmp(a, b):
+ return (a > b) - (a < b)
-if py3k:
from functools import reduce
else:
+ callable = __builtin__.callable
+ cmp = __builtin__.cmp
reduce = __builtin__.reduce
try:
def decode_slice(slc):
return (slc.start, slc.stop, slc.step)
+def update_copy(d, _new=None, **kw):
+ """Copy the given dict and update with the given values."""
+
+ d = d.copy()
+ if _new:
+ d.update(_new)
+ d.update(**kw)
+ return d
+
def flatten_iterator(x):
"""Given an iterator of which further sub-elements may also be
iterators, flatten the sub-elements into a single iterator.
class_ = stack.pop()
ctr = class_.__dict__.get('__init__', False)
if not ctr or not isinstance(ctr, types.FunctionType):
+ stack.update(class_.__bases__)
continue
names, _, has_kw, _ = inspect.getargspec(ctr)
args.update(names)
will not be descended.
"""
+ # Py2K
if isinstance(cls, types.ClassType):
return list()
+ # end Py2K
hier = set([cls])
process = list(cls.__mro__)
while process:
c = process.pop()
+ # Py2K
if isinstance(c, types.ClassType):
continue
for b in (_ for _ in c.__bases__
if _ not in hier and not isinstance(_, types.ClassType)):
+ # end Py2K
+ # Py3K
+ #for b in (_ for _ in c.__bases__
+ # if _ not in hier):
process.append(b)
hier.add(b)
+ # Py3K
+ #if c.__module__ == 'builtins' or not hasattr(c, '__subclasses__'):
+ # continue
+ # Py2K
if c.__module__ == '__builtin__' or not hasattr(c, '__subclasses__'):
continue
+ # end Py2K
for s in [_ for _ in c.__subclasses__() if _ not in hier]:
process.append(s)
hier.add(s)
return self._data.keys()
def has_key(self, key):
- return self._data.has_key(key)
+ return key in self._data
def clear(self):
self._data.clear()
-
class OrderedDict(dict):
"""A dict that returns keys/values/items in the order they were added."""
def __setitem__(self, key, object):
if key not in self:
- self._list.append(key)
+ try:
+ self._list.append(key)
+ except AttributeError:
+ # work around Python pickle loads() with
+ # dict subclass (seems to ignore __setstate__?)
+ self._list = [key]
dict.__setitem__(self, key, object)
def __delitem__(self, key):
if len(self) > len(other):
return False
- for m in itertools.ifilterfalse(other._members.has_key,
+ for m in itertools.ifilterfalse(other._members.__contains__,
self._members.iterkeys()):
return False
return True
return item
def clear(self):
+ # Py2K
+ # in 3k, MutableMapping calls popitem()
self._weakrefs.clear()
self.by_id.clear()
+ # end Py2K
weakref.WeakKeyDictionary.clear(self)
def update(self, *a, **kw):
--- /dev/null
+"""SQLAlchemy 2to3 tool.
+
+Relax ! This just calls the regular 2to3 tool with a preprocessor bolted onto it.
+
+
+I originally wanted to write a custom fixer to accomplish this
+but the Fixer classes seem like they can only understand
+the grammar file included with 2to3, and the grammar does not
+seem to include Python comments (and of course, huge hacks needed
+to get out-of-package fixers in there). While that may be
+an option later on this is a pretty simple approach for
+what is a pretty simple problem.
+
+"""
+
+from lib2to3 import main, refactor
+
+import re
+
+py3k_pattern = re.compile(r'\s*# Py3K')
+comment_pattern = re.compile(r'(\s*)#(?! ?Py2K)(.*)')
+py2k_pattern = re.compile(r'\s*# Py2K')
+end_py2k_pattern = re.compile(r'\s*# end Py2K')
+
+def preprocess(data):
+ lines = data.split('\n')
+ def consume_normal():
+ while lines:
+ line = lines.pop(0)
+ if py3k_pattern.match(line):
+ for line in consume_py3k():
+ yield line
+ elif py2k_pattern.match(line):
+ for line in consume_py2k():
+ yield line
+ else:
+ yield line
+
+ def consume_py3k():
+ yield "# start Py3K"
+ while lines:
+ line = lines.pop(0)
+ m = comment_pattern.match(line)
+ if m:
+ yield "%s%s" % m.group(1, 2)
+ else:
+ # pushback
+ lines.insert(0, line)
+ break
+ yield "# end Py3K"
+
+ def consume_py2k():
+ yield "# start Py2K"
+ while lines:
+ line = lines.pop(0)
+ if not end_py2k_pattern.match(line):
+ yield "#%s" % line
+ else:
+ break
+ yield "# end Py2K"
+
+ return "\n".join(consume_normal())
+
+old_refactor_string = main.StdoutRefactoringTool.refactor_string
+
+def refactor_string(self, data, name):
+ newdata = preprocess(data)
+ tree = old_refactor_string(self, newdata, name)
+ if tree:
+ if newdata != data:
+ tree.was_changed = True
+ return tree
+
+main.StdoutRefactoringTool.refactor_string = refactor_string
+
+main.main("lib2to3.fixes")
if sys.version_info < (2, 4):
raise Exception("SQLAlchemy requires Python 2.4 or higher.")
-v = file(os.path.join(os.path.dirname(__file__), 'lib', 'sqlalchemy', '__init__.py'))
+v = open(os.path.join(os.path.dirname(__file__), 'lib', 'sqlalchemy', '__init__.py'))
VERSION = re.compile(r".*__version__ = '(.*?)'", re.S).match(v.read()).group(1)
v.close()
packages = find_packages('lib'),
package_dir = {'':'lib'},
license = "MIT License",
-
+ tests_require = ['nose >= 0.10'],
+ test_suite = "nose.collector",
entry_points = {
'nose.plugins.0.10': [
'sqlalchemy = sqlalchemy.test.noseplugin:NoseSQLAlchemy',
--- /dev/null
+"""
+nose runner script.
+
+Only use this script if setuptools is not available, i.e. such as
+on Python 3K. Otherwise consult README.unittests for the
+recommended methods of running tests.
+
+"""
+
+import nose
+from sqlalchemy.test.noseplugin import NoseSQLAlchemy
+from sqlalchemy.util import py3k
+
+if __name__ == '__main__':
+ if py3k:
+ # this version breaks verbose output,
+ # but is the only API that nose3 currently supports
+ nose.main(plugins=[NoseSQLAlchemy()])
+ else:
+ # this is the "correct" API
+ nose.main(addplugins=[NoseSQLAlchemy()])
+
Column('c1', Integer, primary_key=True),
Column('c2', String(30)))
- @profiling.function_call_count(68, {'2.4': 42})
+ @profiling.function_call_count(72, {'2.4': 42, '3.0':77})
def test_insert(self):
t1.insert().compile()
- @profiling.function_call_count(68, {'2.4': 45})
+ @profiling.function_call_count(72, {'2.4': 45})
def test_update(self):
t1.update().compile()
- @profiling.function_call_count(185, versions={'2.4':118})
+ @profiling.function_call_count(195, versions={'2.4':118, '3.0':208})
def test_select(self):
s = select([t1], t1.c.c2==t2.c.c1)
s.compile()
from sqlalchemy.test.testing import eq_
-import gc
from sqlalchemy.orm import mapper, relation, create_session, clear_mappers, sessionmaker
from sqlalchemy.orm.mapper import _mapper_registry
from sqlalchemy.orm.session import _sessions
+from sqlalchemy.util import jython
import operator
from sqlalchemy.test import testing
from sqlalchemy import MetaData, Integer, String, ForeignKey, PickleType
-from sqlalchemy.test.schema import Table
-from sqlalchemy.test.schema import Column
+from sqlalchemy.test.schema import Table, Column
import sqlalchemy as sa
from sqlalchemy.sql import column
+from sqlalchemy.test.util import gc_collect
+import gc
from test.orm import _base
+if jython:
+ from nose import SkipTest
+ raise SkipTest("Profiling not supported on this platform")
+
class A(_base.ComparableEntity):
pass
# run the test 50 times. if length of gc.get_objects()
# keeps growing, assert false
def profile(*args):
- gc.collect()
+ gc_collect()
samples = [0 for x in range(0, 50)]
for x in range(0, 50):
func(*args)
- gc.collect()
+ gc_collect()
samples[x] = len(gc.get_objects())
print "sample gc sizes:", samples
def assert_no_mappers():
clear_mappers()
- gc.collect()
+ gc_collect()
assert len(_mapper_registry) == 0
class EnsureZeroed(_base.ORMTest):
class MemUsageTest(EnsureZeroed):
# ensure a pure growing test trips the assertion
- @testing.fails_if(lambda:True)
+ @testing.fails_if(lambda: True)
def test_fixture(self):
class Foo(object):
pass
metadata = MetaData(testing.db)
table1 = Table("mytable", metadata,
- Column('col1', Integer, primary_key=True),
+ Column('col1', Integer, primary_key=True, test_needs_autoincrement=True),
Column('col2', String(30)))
table2 = Table("mytable2", metadata,
- Column('col1', Integer, primary_key=True),
+ Column('col1', Integer, primary_key=True, test_needs_autoincrement=True),
Column('col2', String(30)),
Column('col3', Integer, ForeignKey("mytable.col1")))
metadata = MetaData(testing.db)
table1 = Table("mytable", metadata,
- Column('col1', Integer, primary_key=True),
+ Column('col1', Integer, primary_key=True, test_needs_autoincrement=True),
Column('col2', String(30)))
table2 = Table("mytable2", metadata,
- Column('col1', Integer, primary_key=True),
+ Column('col1', Integer, primary_key=True, test_needs_autoincrement=True),
Column('col2', String(30)),
Column('col3', Integer, ForeignKey("mytable.col1")))
metadata = MetaData(testing.db)
table1 = Table("mytable", metadata,
- Column('col1', Integer, primary_key=True),
+ Column('col1', Integer, primary_key=True, test_needs_autoincrement=True),
Column('col2', String(30))
)
table2 = Table("mytable2", metadata,
Column('col1', Integer, ForeignKey('mytable.col1'),
- primary_key=True),
+ primary_key=True, test_needs_autoincrement=True),
Column('col3', String(30)),
)
metadata = MetaData(testing.db)
table1 = Table("mytable", metadata,
- Column('col1', Integer, primary_key=True),
+ Column('col1', Integer, primary_key=True, test_needs_autoincrement=True),
Column('col2', String(30))
)
table2 = Table("mytable2", metadata,
- Column('col1', Integer, primary_key=True),
+ Column('col1', Integer, primary_key=True, test_needs_autoincrement=True),
Column('col2', String(30)),
)
metadata = MetaData(testing.db)
table1 = Table("table1", metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(30))
)
table2 = Table("table2", metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(30)),
Column('t1id', Integer, ForeignKey('table1.id'))
)
metadata = MetaData(testing.db)
table1 = Table("mytable", metadata,
- Column('col1', Integer, primary_key=True),
+ Column('col1', Integer, primary_key=True, test_needs_autoincrement=True),
Column('col2', PickleType(comparator=operator.eq))
)
testing.eq_(len(session.identity_map._mutable_attrs), 12)
testing.eq_(len(session.identity_map), 12)
obj = None
- gc.collect()
+ gc_collect()
testing.eq_(len(session.identity_map._mutable_attrs), 0)
testing.eq_(len(session.identity_map), 0)
metadata.drop_all()
def test_type_compile(self):
- from sqlalchemy.databases.sqlite import SQLiteDialect
+ from sqlalchemy.dialects.sqlite.base import dialect as SQLiteDialect
cast = sa.cast(column('x'), sa.Integer)
@profile_memory
def go():
class QueuePoolTest(TestBase, AssertsExecutionResults):
class Connection(object):
+ def rollback(self):
+ pass
+
def close(self):
pass
use_threadlocal=True)
- @profiling.function_call_count(54, {'2.4': 38})
+ @profiling.function_call_count(54, {'2.4': 38, '3.0':57})
def test_first_connect(self):
conn = pool.connect()
conn = pool.connect()
conn.close()
- @profiling.function_call_count(31, {'2.4': 21})
+ @profiling.function_call_count(29, {'2.4': 21})
def go():
conn2 = pool.connect()
return conn2
def test_second_samethread_connect(self):
conn = pool.connect()
- @profiling.function_call_count(5, {'2.4': 3})
+ @profiling.function_call_count(5, {'2.4': 3, '3.0':6})
def go():
return pool.connect()
c2 = go()
"""
- __only_on__ = 'postgres'
+ __only_on__ = 'postgresql+psycopg2'
__skip_if__ = ((lambda: sys.version_info < (2, 4)), )
def test_baseline_0_setup(self):
Opens=datetime.time(8, 15, 59),
LastEscape=datetime.datetime(2004, 7, 29, 5, 6, 7),
Admission=4.95,
- ).last_inserted_ids()[0]
+ ).inserted_primary_key[0]
sdz = Zoo.insert().execute(Name =u'San Diego Zoo',
Founded = datetime.date(1935, 9, 13),
Opens = datetime.time(9, 0, 0),
Admission = 0,
- ).last_inserted_ids()[0]
+ ).inserted_primary_key[0]
- Zoo.insert().execute(
+ Zoo.insert(inline=True).execute(
Name = u'Montr\xe9al Biod\xf4me',
Founded = datetime.date(1992, 6, 19),
Opens = datetime.time(9, 0, 0),
)
seaworld = Zoo.insert().execute(
- Name =u'Sea_World', Admission = 60).last_inserted_ids()[0]
+ Name =u'Sea_World', Admission = 60).inserted_primary_key[0]
# Let's add a crazy futuristic Zoo to test large date values.
lp = Zoo.insert().execute(Name =u'Luna Park',
Founded = datetime.date(2072, 7, 17),
Opens = datetime.time(0, 0, 0),
Admission = 134.95,
- ).last_inserted_ids()[0]
+ ).inserted_primary_key[0]
# Animals
leopardid = Animal.insert().execute(Species=u'Leopard', Lifespan=73.5,
- ).last_inserted_ids()[0]
+ ).inserted_primary_key[0]
Animal.update(Animal.c.ID==leopardid).execute(ZooID=wap,
LastEscape=datetime.datetime(2004, 12, 21, 8, 15, 0, 999907))
- lion = Animal.insert().execute(Species=u'Lion', ZooID=wap).last_inserted_ids()[0]
+ lion = Animal.insert().execute(Species=u'Lion', ZooID=wap).inserted_primary_key[0]
Animal.insert().execute(Species=u'Slug', Legs=1, Lifespan=.75)
tiger = Animal.insert().execute(Species=u'Tiger', ZooID=sdz
- ).last_inserted_ids()[0]
+ ).inserted_primary_key[0]
# Override Legs.default with itself just to make sure it works.
- Animal.insert().execute(Species=u'Bear', Legs=4)
- Animal.insert().execute(Species=u'Ostrich', Legs=2, Lifespan=103.2)
- Animal.insert().execute(Species=u'Centipede', Legs=100)
+ Animal.insert(inline=True).execute(Species=u'Bear', Legs=4)
+ Animal.insert(inline=True).execute(Species=u'Ostrich', Legs=2, Lifespan=103.2)
+ Animal.insert(inline=True).execute(Species=u'Centipede', Legs=100)
emp = Animal.insert().execute(Species=u'Emperor Penguin', Legs=2,
- ZooID=seaworld).last_inserted_ids()[0]
+ ZooID=seaworld).inserted_primary_key[0]
adelie = Animal.insert().execute(Species=u'Adelie Penguin', Legs=2,
- ZooID=seaworld).last_inserted_ids()[0]
+ ZooID=seaworld).inserted_primary_key[0]
- Animal.insert().execute(Species=u'Millipede', Legs=1000000, ZooID=sdz)
+ Animal.insert(inline=True).execute(Species=u'Millipede', Legs=1000000, ZooID=sdz)
# Add a mother and child to test relationships
bai_yun = Animal.insert().execute(Species=u'Ape', Name=u'Bai Yun',
- Legs=2).last_inserted_ids()[0]
- Animal.insert().execute(Species=u'Ape', Name=u'Hua Mei', Legs=2,
+ Legs=2).inserted_primary_key[0]
+ Animal.insert(inline=True).execute(Species=u'Ape', Name=u'Hua Mei', Legs=2,
MotherID=bai_yun)
def test_baseline_2_insert(self):
Animal = metadata.tables['Animal']
- i = Animal.insert()
+ i = Animal.insert(inline=True)
for x in xrange(ITERATIONS):
tick = i.execute(Species=u'Tick', Name=u'Tick %d' % x, Legs=8)
def fullobject(select):
"""Iterate over the full result row."""
- return list(select.execute().fetchone())
+ return list(select.execute().first())
for x in xrange(ITERATIONS):
# Zoos
for x in xrange(ITERATIONS):
# Edit
- SDZ = Zoo.select(Zoo.c.Name==u'San Diego Zoo').execute().fetchone()
+ SDZ = Zoo.select(Zoo.c.Name==u'San Diego Zoo').execute().first()
Zoo.update(Zoo.c.ID==SDZ['ID']).execute(
Name=u'The San Diego Zoo',
Founded = datetime.date(1900, 1, 1),
Admission = "35.00")
# Test edits
- SDZ = Zoo.select(Zoo.c.Name==u'The San Diego Zoo').execute().fetchone()
+ SDZ = Zoo.select(Zoo.c.Name==u'The San Diego Zoo').execute().first()
assert SDZ['Founded'] == datetime.date(1900, 1, 1), SDZ['Founded']
# Change it back
Admission = "0")
# Test re-edits
- SDZ = Zoo.select(Zoo.c.Name==u'San Diego Zoo').execute().fetchone()
+ SDZ = Zoo.select(Zoo.c.Name==u'San Diego Zoo').execute().first()
assert SDZ['Founded'] == datetime.date(1935, 9, 13)
def test_baseline_7_multiview(self):
global metadata
player = lambda: dbapi_session.player()
- engine = create_engine('postgres:///', creator=player)
+ engine = create_engine('postgresql:///', creator=player)
metadata = MetaData(engine)
- @profiling.function_call_count(3230, {'2.4': 1796})
+ @profiling.function_call_count(2991, {'2.4': 1796})
def test_profile_1_create_tables(self):
self.test_baseline_1_create_tables()
def test_profile_1a_populate(self):
self.test_baseline_1a_populate()
- @profiling.function_call_count(322, {'2.4': 202})
+ @profiling.function_call_count(305, {'2.4': 202})
def test_profile_2_insert(self):
self.test_baseline_2_insert()
"""
- __only_on__ = 'postgres'
+ __only_on__ = 'postgresql+psycopg2'
__skip_if__ = ((lambda: sys.version_info < (2, 5)), ) # TODO: get 2.4 support
def test_baseline_0_setup(self):
global metadata, session
player = lambda: dbapi_session.player()
- engine = create_engine('postgres:///', creator=player)
+ engine = create_engine('postgresql:///', creator=player)
metadata = MetaData(engine)
session = sessionmaker()()
self.assert_sort(tuples, head)
def testbigsort(self):
- tuples = []
- for i in range(0,1500, 2):
- tuples.append((i, i+1))
+ tuples = [(i, i + 1) for i in range(0, 1500, 2)]
head = topological.sort_as_tree(tuples, [])
+ def testids(self):
+ # ticket:1380 regression: would raise a KeyError
+ topological.sort([(id(i), i) for i in range(3)], [])
+
+
"""Tests exceptions and DB-API exception wrapping."""
-import exceptions as stdlib_exceptions
from sqlalchemy import exc as sa_exceptions
from sqlalchemy.test import TestBase
+# Py3K
+#StandardError = BaseException
+# Py2K
+from exceptions import StandardError, KeyboardInterrupt, SystemExit
+# end Py2K
-class Error(stdlib_exceptions.StandardError):
+class Error(StandardError):
"""This class will be old-style on <= 2.4 and new-style on >= 2.5."""
class DatabaseError(Error):
pass
def test_db_error_keyboard_interrupt(self):
try:
raise sa_exceptions.DBAPIError.instance(
- '', [], stdlib_exceptions.KeyboardInterrupt())
+ '', [], KeyboardInterrupt())
except sa_exceptions.DBAPIError:
self.assert_(False)
- except stdlib_exceptions.KeyboardInterrupt:
+ except KeyboardInterrupt:
self.assert_(True)
def test_db_error_system_exit(self):
try:
raise sa_exceptions.DBAPIError.instance(
- '', [], stdlib_exceptions.SystemExit())
+ '', [], SystemExit())
except sa_exceptions.DBAPIError:
self.assert_(False)
- except stdlib_exceptions.SystemExit:
+ except SystemExit:
self.assert_(True)
from sqlalchemy import util, sql, exc
from sqlalchemy.test import TestBase
from sqlalchemy.test.testing import eq_, is_, ne_
+from sqlalchemy.test.util import gc_collect
class OrderedDictTest(TestBase):
def test_odict(self):
except TypeError:
assert True
- assert_raises(TypeError, cmp, ids)
+ assert_raises(TypeError, util.cmp, ids)
assert_raises(TypeError, hash, ids)
def test_difference(self):
d = subdict(a=1,b=2,c=3)
self._ok(d)
+ # Py2K
def test_UserDict(self):
import UserDict
d = UserDict.UserDict(a=1,b=2,c=3)
self._ok(d)
-
+ # end Py2K
+
def test_object(self):
self._notok(object())
return iter(self.baseline)
self._ok(duck1())
+ # Py2K
def test_duck_2(self):
class duck2(object):
def items(duck):
return list(self.baseline)
self._ok(duck2())
+ # end Py2K
+ # Py2K
def test_duck_3(self):
class duck3(object):
def iterkeys(duck):
def __getitem__(duck, key):
return dict(a=1,b=2,c=3).get(key)
self._ok(duck3())
+ # end Py2K
def test_duck_4(self):
class duck4(object):
class DuckTypeCollectionTest(TestBase):
def test_sets(self):
+ # Py2K
import sets
+ # end Py2K
class SetLike(object):
def add(self):
pass
class ForcedSet(list):
__emulates__ = set
-
+
for type_ in (set,
+ # Py2K
sets.Set,
+ # end Py2K
SetLike,
ForcedSet):
eq_(util.duck_type_collection(type_), set)
eq_(util.duck_type_collection(instance), set)
for type_ in (frozenset,
- sets.ImmutableSet):
+ # Py2K
+ sets.ImmutableSet
+ # end Py2K
+ ):
is_(util.duck_type_collection(type_), None)
instance = type_()
is_(util.duck_type_collection(instance), None)
-
class ArgInspectionTest(TestBase):
def test_get_cls_kwargs(self):
class A(object):
assert len(data) == len(wim) == len(wim.by_id)
del data[:]
+ gc_collect()
+
eq_(wim, {})
eq_(wim.by_id, {})
eq_(wim._weakrefs, {})
oid = id(data[0])
del data[0]
+ gc_collect()
assert len(data) == len(wim) == len(wim.by_id)
assert oid not in wim.by_id
th.start()
cv.wait()
cv.release()
+ gc_collect()
eq_(wim, {})
eq_(wim.by_id, {})
eq_(set(util.class_hierarchy(A)), set((A, B, C, object)))
eq_(set(util.class_hierarchy(B)), set((A, B, C, object)))
-
+
+ # Py2K
def test_oldstyle_mixin(self):
class A(object):
pass
eq_(set(util.class_hierarchy(B)), set((A, B, object)))
eq_(set(util.class_hierarchy(Mixin)), set())
eq_(set(util.class_hierarchy(A)), set((A, B, object)))
-
+ # end Py2K
con.execute('DROP GENERATOR gen_testtable_id')
def test_table_is_reflected(self):
+ from sqlalchemy.types import Integer, Text, Binary, String, Date, Time, DateTime
metadata = MetaData(testing.db)
table = Table('testtable', metadata, autoload=True)
eq_(set(table.columns.keys()),
"Columns of reflected table didn't equal expected columns")
eq_(table.c.question.primary_key, True)
eq_(table.c.question.sequence.name, 'gen_testtable_id')
- eq_(table.c.question.type.__class__, firebird.FBInteger)
+ assert isinstance(table.c.question.type, Integer)
eq_(table.c.question.server_default.arg.text, "42")
- eq_(table.c.answer.type.__class__, firebird.FBString)
+ assert isinstance(table.c.answer.type, String)
eq_(table.c.answer.server_default.arg.text, "'no answer'")
- eq_(table.c.remark.type.__class__, firebird.FBText)
+ assert isinstance(table.c.remark.type, Text)
eq_(table.c.remark.server_default.arg.text, "''")
- eq_(table.c.photo.type.__class__, firebird.FBBinary)
+ assert isinstance(table.c.photo.type, Binary)
# The following assume a Dialect 3 database
- eq_(table.c.d.type.__class__, firebird.FBDate)
- eq_(table.c.t.type.__class__, firebird.FBTime)
- eq_(table.c.dt.type.__class__, firebird.FBDateTime)
+ assert isinstance(table.c.d.type, Date)
+ assert isinstance(table.c.t.type, Time)
+ assert isinstance(table.c.dt.type, DateTime)
class CompileTest(TestBase, AssertsCompiledSQL):
def test_alias(self):
t = table('sometable', column('col1'), column('col2'))
s = select([t.alias()])
- self.assert_compile(s, "SELECT sometable_1.col1, sometable_1.col2 FROM sometable sometable_1")
+ self.assert_compile(s, "SELECT sometable_1.col1, sometable_1.col2 FROM sometable AS sometable_1")
+
+ dialect = firebird.FBDialect()
+ dialect._version_two = False
+ self.assert_compile(s, "SELECT sometable_1.col1, sometable_1.col2 FROM sometable sometable_1",
+ dialect = dialect
+ )
def test_function(self):
self.assert_compile(func.foo(1, 2), "foo(:foo_1, :foo_2)")
column('description', String(128)),
)
- u = update(table1, values=dict(name='foo'), firebird_returning=[table1.c.myid, table1.c.name])
+ u = update(table1, values=dict(name='foo')).returning(table1.c.myid, table1.c.name)
self.assert_compile(u, "UPDATE mytable SET name=:name RETURNING mytable.myid, mytable.name")
- u = update(table1, values=dict(name='foo'), firebird_returning=[table1])
+ u = update(table1, values=dict(name='foo')).returning(table1)
self.assert_compile(u, "UPDATE mytable SET name=:name "\
"RETURNING mytable.myid, mytable.name, mytable.description")
- u = update(table1, values=dict(name='foo'), firebird_returning=[func.length(table1.c.name)])
- self.assert_compile(u, "UPDATE mytable SET name=:name RETURNING char_length(mytable.name)")
+ u = update(table1, values=dict(name='foo')).returning(func.length(table1.c.name))
+ self.assert_compile(u, "UPDATE mytable SET name=:name RETURNING char_length(mytable.name) AS length_1")
def test_insert_returning(self):
table1 = table('mytable',
column('description', String(128)),
)
- i = insert(table1, values=dict(name='foo'), firebird_returning=[table1.c.myid, table1.c.name])
+ i = insert(table1, values=dict(name='foo')).returning(table1.c.myid, table1.c.name)
self.assert_compile(i, "INSERT INTO mytable (name) VALUES (:name) RETURNING mytable.myid, mytable.name")
- i = insert(table1, values=dict(name='foo'), firebird_returning=[table1])
+ i = insert(table1, values=dict(name='foo')).returning(table1)
self.assert_compile(i, "INSERT INTO mytable (name) VALUES (:name) "\
"RETURNING mytable.myid, mytable.name, mytable.description")
- i = insert(table1, values=dict(name='foo'), firebird_returning=[func.length(table1.c.name)])
- self.assert_compile(i, "INSERT INTO mytable (name) VALUES (:name) RETURNING char_length(mytable.name)")
+ i = insert(table1, values=dict(name='foo')).returning(func.length(table1.c.name))
+ self.assert_compile(i, "INSERT INTO mytable (name) VALUES (:name) RETURNING char_length(mytable.name) AS length_1")
-class ReturningTest(TestBase, AssertsExecutionResults):
- __only_on__ = 'firebird'
-
- @testing.exclude('firebird', '<', (2, 1), '2.1+ feature')
- def test_update_returning(self):
- meta = MetaData(testing.db)
- table = Table('tables', meta,
- Column('id', Integer, Sequence('gen_tables_id'), primary_key=True),
- Column('persons', Integer),
- Column('full', Boolean)
- )
- table.create()
- try:
- table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}])
-
- result = table.update(table.c.persons > 4, dict(full=True), firebird_returning=[table.c.id]).execute()
- eq_(result.fetchall(), [(1,)])
-
- result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute()
- eq_(result2.fetchall(), [(1,True),(2,False)])
- finally:
- table.drop()
-
- @testing.exclude('firebird', '<', (2, 0), '2.0+ feature')
- def test_insert_returning(self):
- meta = MetaData(testing.db)
- table = Table('tables', meta,
- Column('id', Integer, Sequence('gen_tables_id'), primary_key=True),
- Column('persons', Integer),
- Column('full', Boolean)
- )
- table.create()
- try:
- result = table.insert(firebird_returning=[table.c.id]).execute({'persons': 1, 'full': False})
-
- eq_(result.fetchall(), [(1,)])
-
- # Multiple inserts only return the last row
- result2 = table.insert(firebird_returning=[table]).execute(
- [{'persons': 2, 'full': False}, {'persons': 3, 'full': True}])
-
- eq_(result2.fetchall(), [(3,3,True)])
-
- result3 = table.insert(firebird_returning=[table.c.id]).execute({'persons': 4, 'full': False})
- eq_([dict(row) for row in result3], [{'ID':4}])
-
- result4 = testing.db.execute('insert into tables (id, persons, "full") values (5, 10, 1) returning persons')
- eq_([dict(row) for row in result4], [{'PERSONS': 10}])
- finally:
- table.drop()
-
- @testing.exclude('firebird', '<', (2, 1), '2.1+ feature')
- def test_delete_returning(self):
- meta = MetaData(testing.db)
- table = Table('tables', meta,
- Column('id', Integer, Sequence('gen_tables_id'), primary_key=True),
- Column('persons', Integer),
- Column('full', Boolean)
- )
- table.create()
- try:
- table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}])
-
- result = table.delete(table.c.persons > 4, firebird_returning=[table.c.id]).execute()
- eq_(result.fetchall(), [(1,)])
-
- result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute()
- eq_(result2.fetchall(), [(2,False),])
- finally:
- table.drop()
-class MiscFBTests(TestBase):
+class MiscTest(TestBase):
__only_on__ = 'firebird'
def test_strlen(self):
try:
t.insert(values=dict(name='dante')).execute()
t.insert(values=dict(name='alighieri')).execute()
- select([func.count(t.c.id)],func.length(t.c.name)==5).execute().fetchone()[0] == 1
+ select([func.count(t.c.id)],func.length(t.c.name)==5).execute().first()[0] == 1
finally:
meta.drop_all()
def test_server_version_info(self):
- version = testing.db.dialect.server_version_info(testing.db.connect())
+ version = testing.db.dialect.server_version_info
assert len(version) == 3, "Got strange version info: %s" % repr(version)
+ def test_percents_in_text(self):
+ for expr, result in (
+ (text("select '%' from rdb$database"), '%'),
+ (text("select '%%' from rdb$database"), '%%'),
+ (text("select '%%%' from rdb$database"), '%%%'),
+ (text("select 'hello % world' from rdb$database"), "hello % world")
+ ):
+ eq_(testing.db.scalar(expr), result)
class CompileTest(TestBase, AssertsCompiledSQL):
- __dialect__ = informix.InfoDialect()
+ __only_on__ = 'informix'
+ __dialect__ = informix.InformixDialect()
def test_statements(self):
meta =MetaData()
vals = []
for i in xrange(3):
cr.execute('SELECT busto.NEXTVAL FROM DUAL')
- vals.append(cr.fetchone()[0])
+ vals.append(cr.first()[0])
# should be 1,2,3, but no...
self.assert_(vals != [1,2,3])
from sqlalchemy.test.testing import eq_
import datetime, os, re
from sqlalchemy import *
-from sqlalchemy import types, exc
+from sqlalchemy import types, exc, schema
from sqlalchemy.orm import *
from sqlalchemy.sql import table, column
from sqlalchemy.databases import mssql
-import sqlalchemy.engine.url as url
+from sqlalchemy.dialects.mssql import pyodbc
+from sqlalchemy.engine import url
from sqlalchemy.test import *
from sqlalchemy.test.testing import eq_
class CompileTest(TestBase, AssertsCompiledSQL):
- __dialect__ = mssql.MSSQLDialect()
+ __dialect__ = mssql.dialect()
def test_insert(self):
t = table('sometable', column('somecolumn'))
select([extract(field, t.c.col1)]),
'SELECT DATEPART("%s", t.col1) AS anon_1 FROM t' % field)
+ def test_update_returning(self):
+ table1 = table('mytable',
+ column('myid', Integer),
+ column('name', String(128)),
+ column('description', String(128)),
+ )
+
+ u = update(table1, values=dict(name='foo')).returning(table1.c.myid, table1.c.name)
+ self.assert_compile(u, "UPDATE mytable SET name=:name OUTPUT inserted.myid, inserted.name")
+
+ u = update(table1, values=dict(name='foo')).returning(table1)
+ self.assert_compile(u, "UPDATE mytable SET name=:name OUTPUT inserted.myid, "
+ "inserted.name, inserted.description")
+
+ u = update(table1, values=dict(name='foo')).returning(table1).where(table1.c.name=='bar')
+ self.assert_compile(u, "UPDATE mytable SET name=:name OUTPUT inserted.myid, "
+ "inserted.name, inserted.description WHERE mytable.name = :name_1")
+
+ u = update(table1, values=dict(name='foo')).returning(func.length(table1.c.name))
+ self.assert_compile(u, "UPDATE mytable SET name=:name OUTPUT LEN(inserted.name) AS length_1")
+
+ def test_insert_returning(self):
+ table1 = table('mytable',
+ column('myid', Integer),
+ column('name', String(128)),
+ column('description', String(128)),
+ )
+
+ i = insert(table1, values=dict(name='foo')).returning(table1.c.myid, table1.c.name)
+ self.assert_compile(i, "INSERT INTO mytable (name) OUTPUT inserted.myid, inserted.name VALUES (:name)")
+
+ i = insert(table1, values=dict(name='foo')).returning(table1)
+ self.assert_compile(i, "INSERT INTO mytable (name) OUTPUT inserted.myid, "
+ "inserted.name, inserted.description VALUES (:name)")
+
+ i = insert(table1, values=dict(name='foo')).returning(func.length(table1.c.name))
+ self.assert_compile(i, "INSERT INTO mytable (name) OUTPUT LEN(inserted.name) AS length_1 VALUES (:name)")
+
+
class IdentityInsertTest(TestBase, AssertsCompiledSQL):
__only_on__ = 'mssql'
eq_([(9, 'Python')], list(cats))
result = cattable.insert().values(description='PHP').execute()
- eq_([10], result.last_inserted_ids())
+ eq_([10], result.inserted_primary_key)
lastcat = cattable.select().order_by(desc(cattable.c.id)).execute()
- eq_((10, 'PHP'), lastcat.fetchone())
+ eq_((10, 'PHP'), lastcat.first())
def test_executemany(self):
cattable.insert().execute([
eq_([(91, 'Smalltalk'), (90, 'PHP')], list(lastcats))
-class ReflectionTest(TestBase):
+class ReflectionTest(TestBase, ComparesTables):
__only_on__ = 'mssql'
- def testidentity(self):
+ def test_basic_reflection(self):
+ meta = MetaData(testing.db)
+
+ users = Table('engine_users', meta,
+ Column('user_id', types.INT, primary_key=True),
+ Column('user_name', types.VARCHAR(20), nullable=False),
+ Column('test1', types.CHAR(5), nullable=False),
+ Column('test2', types.Float(5), nullable=False),
+ Column('test3', types.Text),
+ Column('test4', types.Numeric, nullable = False),
+ Column('test5', types.DateTime),
+ Column('parent_user_id', types.Integer,
+ ForeignKey('engine_users.user_id')),
+ Column('test6', types.DateTime, nullable=False),
+ Column('test7', types.Text),
+ Column('test8', types.Binary),
+ Column('test_passivedefault2', types.Integer, server_default='5'),
+ Column('test9', types.Binary(100)),
+ Column('test_numeric', types.Numeric()),
+ test_needs_fk=True,
+ )
+
+ addresses = Table('engine_email_addresses', meta,
+ Column('address_id', types.Integer, primary_key = True),
+ Column('remote_user_id', types.Integer, ForeignKey(users.c.user_id)),
+ Column('email_address', types.String(20)),
+ test_needs_fk=True,
+ )
+ meta.create_all()
+
+ try:
+ meta2 = MetaData()
+ reflected_users = Table('engine_users', meta2, autoload=True,
+ autoload_with=testing.db)
+ reflected_addresses = Table('engine_email_addresses', meta2,
+ autoload=True, autoload_with=testing.db)
+ self.assert_tables_equal(users, reflected_users)
+ self.assert_tables_equal(addresses, reflected_addresses)
+ finally:
+ meta.drop_all()
+
+ def test_identity(self):
meta = MetaData(testing.db)
table = Table(
'identity_test', meta,
meta = MetaData(testing.db)
t1 = Table('unitest_table', meta,
Column('id', Integer, primary_key=True),
- Column('descr', mssql.MSText(200, convert_unicode=True)))
+ Column('descr', mssql.MSText(convert_unicode=True)))
meta.create_all()
con = testing.db.connect()
con.execute(u"insert into unitest_table values ('bien mangé')".encode('UTF-8'))
try:
- r = t1.select().execute().fetchone()
+ r = t1.select().execute().first()
assert isinstance(r[1], unicode), '%s is %s instead of unicode, working on %s' % (
r[1], type(r[1]), meta.bind)
meta = MetaData(testing.db)
t1 = Table('t1', meta,
Column('id', Integer, Sequence('fred', 100, 1), primary_key=True),
- Column('descr', String(200)))
+ Column('descr', String(200)),
+ implicit_returning = False
+ )
t2 = Table('t2', meta,
Column('id', Integer, Sequence('fred', 200, 1), primary_key=True),
Column('descr', String(200)))
try:
tr = con.begin()
r = con.execute(t2.insert(), descr='hello')
- self.assert_(r.last_inserted_ids() == [200])
+ self.assert_(r.inserted_primary_key == [200])
r = con.execute(t1.insert(), descr='hello')
- self.assert_(r.last_inserted_ids() == [100])
+ self.assert_(r.inserted_primary_key == [100])
finally:
tr.commit()
tbl.drop()
con.execute('drop schema paj')
+ def test_returning_no_autoinc(self):
+ meta = MetaData(testing.db)
+
+ table = Table('t1', meta, Column('id', Integer, primary_key=True), Column('data', String(50)))
+ table.create()
+ try:
+ result = table.insert().values(id=1, data=func.lower("SomeString")).returning(table.c.id, table.c.data).execute()
+ eq_(result.fetchall(), [(1, 'somestring',)])
+ finally:
+ # this will hang if the "SET IDENTITY_INSERT t1 OFF" occurs before the
+ # result is fetched
+ table.drop()
+
def test_delete_schema(self):
meta = MetaData(testing.db)
con = testing.db.connect()
)
self.column = t.c.test_column
+ dialect = mssql.dialect()
+ self.ddl_compiler = dialect.ddl_compiler(dialect, schema.CreateTable(t))
+
+ def _column_spec(self):
+ return self.ddl_compiler.get_column_specification(self.column)
+
def test_that_mssql_default_nullability_emits_null(self):
- schemagenerator = \
- mssql.MSSQLDialect().schemagenerator(mssql.MSSQLDialect(), None)
- column_specification = \
- schemagenerator.get_column_specification(self.column)
- eq_("test_column VARCHAR NULL", column_specification)
+ eq_("test_column VARCHAR NULL", self._column_spec())
def test_that_mssql_none_nullability_does_not_emit_nullability(self):
- schemagenerator = \
- mssql.MSSQLDialect().schemagenerator(mssql.MSSQLDialect(), None)
self.column.nullable = None
- column_specification = \
- schemagenerator.get_column_specification(self.column)
- eq_("test_column VARCHAR", column_specification)
+ eq_("test_column VARCHAR", self._column_spec())
def test_that_mssql_specified_nullable_emits_null(self):
- schemagenerator = \
- mssql.MSSQLDialect().schemagenerator(mssql.MSSQLDialect(), None)
self.column.nullable = True
- column_specification = \
- schemagenerator.get_column_specification(self.column)
- eq_("test_column VARCHAR NULL", column_specification)
+ eq_("test_column VARCHAR NULL", self._column_spec())
def test_that_mssql_specified_not_nullable_emits_not_null(self):
- schemagenerator = \
- mssql.MSSQLDialect().schemagenerator(mssql.MSSQLDialect(), None)
self.column.nullable = False
- column_specification = \
- schemagenerator.get_column_specification(self.column)
- eq_("test_column VARCHAR NOT NULL", column_specification)
+ eq_("test_column VARCHAR NOT NULL", self._column_spec())
def full_text_search_missing():
class ParseConnectTest(TestBase, AssertsCompiledSQL):
__only_on__ = 'mssql'
+ @classmethod
+ def setup_class(cls):
+ global dialect
+ dialect = pyodbc.MSDialect_pyodbc()
+
def test_pyodbc_connect_dsn_trusted(self):
u = url.make_url('mssql://mydsn')
- dialect = mssql.MSSQLDialect_pyodbc()
connection = dialect.create_connect_args(u)
eq_([['dsn=mydsn;TrustedConnection=Yes'], {}], connection)
def test_pyodbc_connect_old_style_dsn_trusted(self):
u = url.make_url('mssql:///?dsn=mydsn')
- dialect = mssql.MSSQLDialect_pyodbc()
connection = dialect.create_connect_args(u)
eq_([['dsn=mydsn;TrustedConnection=Yes'], {}], connection)
def test_pyodbc_connect_dsn_non_trusted(self):
u = url.make_url('mssql://username:password@mydsn')
- dialect = mssql.MSSQLDialect_pyodbc()
connection = dialect.create_connect_args(u)
eq_([['dsn=mydsn;UID=username;PWD=password'], {}], connection)
def test_pyodbc_connect_dsn_extra(self):
u = url.make_url('mssql://username:password@mydsn/?LANGUAGE=us_english&foo=bar')
- dialect = mssql.MSSQLDialect_pyodbc()
connection = dialect.create_connect_args(u)
eq_([['dsn=mydsn;UID=username;PWD=password;LANGUAGE=us_english;foo=bar'], {}], connection)
def test_pyodbc_connect(self):
u = url.make_url('mssql://username:password@hostspec/database')
- dialect = mssql.MSSQLDialect_pyodbc()
connection = dialect.create_connect_args(u)
eq_([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password'], {}], connection)
def test_pyodbc_connect_comma_port(self):
u = url.make_url('mssql://username:password@hostspec:12345/database')
- dialect = mssql.MSSQLDialect_pyodbc()
connection = dialect.create_connect_args(u)
eq_([['DRIVER={SQL Server};Server=hostspec,12345;Database=database;UID=username;PWD=password'], {}], connection)
def test_pyodbc_connect_config_port(self):
u = url.make_url('mssql://username:password@hostspec/database?port=12345')
- dialect = mssql.MSSQLDialect_pyodbc()
connection = dialect.create_connect_args(u)
eq_([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password;port=12345'], {}], connection)
def test_pyodbc_extra_connect(self):
u = url.make_url('mssql://username:password@hostspec/database?LANGUAGE=us_english&foo=bar')
- dialect = mssql.MSSQLDialect_pyodbc()
connection = dialect.create_connect_args(u)
eq_([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password;foo=bar;LANGUAGE=us_english'], {}], connection)
def test_pyodbc_odbc_connect(self):
u = url.make_url('mssql:///?odbc_connect=DRIVER%3D%7BSQL+Server%7D%3BServer%3Dhostspec%3BDatabase%3Ddatabase%3BUID%3Dusername%3BPWD%3Dpassword')
- dialect = mssql.MSSQLDialect_pyodbc()
connection = dialect.create_connect_args(u)
eq_([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password'], {}], connection)
def test_pyodbc_odbc_connect_with_dsn(self):
u = url.make_url('mssql:///?odbc_connect=dsn%3Dmydsn%3BDatabase%3Ddatabase%3BUID%3Dusername%3BPWD%3Dpassword')
- dialect = mssql.MSSQLDialect_pyodbc()
connection = dialect.create_connect_args(u)
eq_([['dsn=mydsn;Database=database;UID=username;PWD=password'], {}], connection)
def test_pyodbc_odbc_connect_ignores_other_values(self):
u = url.make_url('mssql://userdiff:passdiff@localhost/dbdiff?odbc_connect=DRIVER%3D%7BSQL+Server%7D%3BServer%3Dhostspec%3BDatabase%3Ddatabase%3BUID%3Dusername%3BPWD%3Dpassword')
- dialect = mssql.MSSQLDialect_pyodbc()
connection = dialect.create_connect_args(u)
eq_([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password'], {}], connection)
-class TypesTest(TestBase):
+class TypesTest(TestBase, AssertsExecutionResults, ComparesTables):
__only_on__ = 'mssql'
@classmethod
def setup_class(cls):
- global numeric_table, metadata
+ global metadata
metadata = MetaData(testing.db)
def teardown(self):
)
metadata.create_all()
- try:
- test_items = [decimal.Decimal(d) for d in '1500000.00000000000000000000',
- '-1500000.00000000000000000000', '1500000',
- '0.0000000000000000002', '0.2', '-0.0000000000000000002', '-2E-2',
- '156666.458923543', '-156666.458923543', '1', '-1', '-1234', '1234',
- '2E-12', '4E8', '3E-6', '3E-7', '4.1', '1E-1', '1E-2', '1E-3',
- '1E-4', '1E-5', '1E-6', '1E-7', '1E-1', '1E-8', '0.2732E2', '-0.2432E2', '4.35656E2',
- '-02452E-2', '45125E-2',
- '1234.58965E-2', '1.521E+15', '-1E-25', '1E-25', '1254E-25', '-1203E-25',
- '0', '-0.00', '-0', '4585E12', '000000000000000000012', '000000000000.32E12',
- '00000000000000.1E+12', '000000000000.2E-32']
+ test_items = [decimal.Decimal(d) for d in '1500000.00000000000000000000',
+ '-1500000.00000000000000000000', '1500000',
+ '0.0000000000000000002', '0.2', '-0.0000000000000000002', '-2E-2',
+ '156666.458923543', '-156666.458923543', '1', '-1', '-1234', '1234',
+ '2E-12', '4E8', '3E-6', '3E-7', '4.1', '1E-1', '1E-2', '1E-3',
+ '1E-4', '1E-5', '1E-6', '1E-7', '1E-1', '1E-8', '0.2732E2', '-0.2432E2', '4.35656E2',
+ '-02452E-2', '45125E-2',
+ '1234.58965E-2', '1.521E+15', '-1E-25', '1E-25', '1254E-25', '-1203E-25',
+ '0', '-0.00', '-0', '4585E12', '000000000000000000012', '000000000000.32E12',
+ '00000000000000.1E+12', '000000000000.2E-32']
- for value in test_items:
- numeric_table.insert().execute(numericcol=value)
+ for value in test_items:
+ numeric_table.insert().execute(numericcol=value)
- for value in select([numeric_table.c.numericcol]).execute():
- assert value[0] in test_items, "%s not in test_items" % value[0]
-
- except Exception, e:
- raise e
+ for value in select([numeric_table.c.numericcol]).execute():
+ assert value[0] in test_items, "%s not in test_items" % value[0]
def test_float(self):
float_table = Table('float_table', metadata,
raise e
-class TypesTest2(TestBase, AssertsExecutionResults):
- "Test Microsoft SQL Server column types"
-
- __only_on__ = 'mssql'
-
def test_money(self):
"Exercise type specification for money types."
'SMALLMONEY'),
]
- table_args = ['test_mssql_money', MetaData(testing.db)]
+ table_args = ['test_mssql_money', metadata]
for index, spec in enumerate(columns):
type_, args, kw, res = spec
table_args.append(Column('c%s' % index, type_(*args, **kw), nullable=None))
money_table = Table(*table_args)
- gen = testing.db.dialect.schemagenerator(testing.db.dialect, testing.db, None, None)
+ dialect = mssql.dialect()
+ gen = dialect.ddl_compiler(dialect, schema.CreateTable(money_table))
for col in money_table.c:
index = int(col.name[1:])
(mssql.MSDateTime, [], {},
'DATETIME', []),
+ (types.DATE, [], {},
+ 'DATE', ['>=', (10,)]),
+ (types.Date, [], {},
+ 'DATE', ['>=', (10,)]),
+ (types.Date, [], {},
+ 'DATETIME', ['<', (10,)], mssql.MSDateTime),
(mssql.MSDate, [], {},
'DATE', ['>=', (10,)]),
(mssql.MSDate, [], {},
'DATETIME', ['<', (10,)], mssql.MSDateTime),
+ (types.TIME, [], {},
+ 'TIME', ['>=', (10,)]),
+ (types.Time, [], {},
+ 'TIME', ['>=', (10,)]),
(mssql.MSTime, [], {},
'TIME', ['>=', (10,)]),
(mssql.MSTime, [1], {},
'TIME(1)', ['>=', (10,)]),
+ (types.Time, [], {},
+ 'DATETIME', ['<', (10,)], mssql.MSDateTime),
(mssql.MSTime, [], {},
'DATETIME', ['<', (10,)], mssql.MSDateTime),
]
- table_args = ['test_mssql_dates', MetaData(testing.db)]
+ table_args = ['test_mssql_dates', metadata]
for index, spec in enumerate(columns):
type_, args, kw, res, requires = spec[0:5]
if (requires and testing._is_excluded('mssql', *requires)) or not requires:
table_args.append(Column('c%s' % index, type_(*args, **kw), nullable=None))
dates_table = Table(*table_args)
- gen = testing.db.dialect.schemagenerator(testing.db.dialect, testing.db, None, None)
+ gen = testing.db.dialect.ddl_compiler(testing.db.dialect, schema.CreateTable(dates_table))
for col in dates_table.c:
index = int(col.name[1:])
"%s %s" % (col.name, columns[index][3]))
self.assert_(repr(col))
- try:
- dates_table.create(checkfirst=True)
- assert True
- except:
- raise
+ dates_table.create(checkfirst=True)
reflected_dates = Table('test_mssql_dates', MetaData(testing.db), autoload=True)
for col in reflected_dates.c:
- index = int(col.name[1:])
- testing.eq_(testing.db.dialect.type_descriptor(col.type).__class__,
- len(columns[index]) > 5 and columns[index][5] or columns[index][0])
- dates_table.drop()
-
- def test_dates2(self):
- meta = MetaData(testing.db)
- t = Table('test_dates', meta,
- Column('id', Integer,
- Sequence('datetest_id_seq', optional=True),
- primary_key=True),
- Column('adate', Date),
- Column('atime', Time),
- Column('adatetime', DateTime))
- t.create(checkfirst=True)
- try:
- d1 = datetime.date(2007, 10, 30)
- t1 = datetime.time(11, 2, 32)
- d2 = datetime.datetime(2007, 10, 30, 11, 2, 32)
- t.insert().execute(adate=d1, adatetime=d2, atime=t1)
- t.insert().execute(adate=d2, adatetime=d2, atime=d2)
+ self.assert_types_base(col, dates_table.c[col.key])
- x = t.select().execute().fetchall()[0]
- self.assert_(x.adate.__class__ == datetime.date)
- self.assert_(x.atime.__class__ == datetime.time)
- self.assert_(x.adatetime.__class__ == datetime.datetime)
+ def test_date_roundtrip(self):
+ t = Table('test_dates', metadata,
+ Column('id', Integer,
+ Sequence('datetest_id_seq', optional=True),
+ primary_key=True),
+ Column('adate', Date),
+ Column('atime', Time),
+ Column('adatetime', DateTime))
+ metadata.create_all()
+ d1 = datetime.date(2007, 10, 30)
+ t1 = datetime.time(11, 2, 32)
+ d2 = datetime.datetime(2007, 10, 30, 11, 2, 32)
+ t.insert().execute(adate=d1, adatetime=d2, atime=t1)
+ t.insert().execute(adate=d2, adatetime=d2, atime=d2)
- t.delete().execute()
+ x = t.select().execute().fetchall()[0]
+ self.assert_(x.adate.__class__ == datetime.date)
+ self.assert_(x.atime.__class__ == datetime.time)
+ self.assert_(x.adatetime.__class__ == datetime.datetime)
- t.insert().execute(adate=d1, adatetime=d2, atime=t1)
+ t.delete().execute()
- eq_(select([t.c.adate, t.c.atime, t.c.adatetime], t.c.adate==d1).execute().fetchall(), [(d1, t1, d2)])
+ t.insert().execute(adate=d1, adatetime=d2, atime=t1)
- finally:
- t.drop(checkfirst=True)
+ eq_(select([t.c.adate, t.c.atime, t.c.adatetime], t.c.adate==d1).execute().fetchall(), [(d1, t1, d2)])
def test_binary(self):
"Exercise type specification for binary types."
# column type, args, kwargs, expected ddl
(mssql.MSBinary, [], {},
'BINARY'),
+ (types.Binary, [10], {},
+ 'BINARY(10)'),
+
(mssql.MSBinary, [10], {},
'BINARY(10)'),
'BINARY(10)')
]
- table_args = ['test_mssql_binary', MetaData(testing.db)]
+ table_args = ['test_mssql_binary', metadata]
for index, spec in enumerate(columns):
type_, args, kw, res = spec
table_args.append(Column('c%s' % index, type_(*args, **kw), nullable=None))
binary_table = Table(*table_args)
- gen = testing.db.dialect.schemagenerator(testing.db.dialect, testing.db, None, None)
+ dialect = mssql.dialect()
+ gen = dialect.ddl_compiler(dialect, schema.CreateTable(binary_table))
for col in binary_table.c:
index = int(col.name[1:])
"%s %s" % (col.name, columns[index][3]))
self.assert_(repr(col))
- try:
- binary_table.create(checkfirst=True)
- assert True
- except:
- raise
+ metadata.create_all()
reflected_binary = Table('test_mssql_binary', MetaData(testing.db), autoload=True)
for col in reflected_binary.c:
- # don't test the MSGenericBinary since it's a special case and
- # reflected it will map to a MSImage or MSBinary depending
- if not testing.db.dialect.type_descriptor(binary_table.c[col.name].type).__class__ == mssql.MSGenericBinary:
- testing.eq_(testing.db.dialect.type_descriptor(col.type).__class__,
- testing.db.dialect.type_descriptor(binary_table.c[col.name].type).__class__)
+ c1 =testing.db.dialect.type_descriptor(col.type).__class__
+ c2 =testing.db.dialect.type_descriptor(binary_table.c[col.name].type).__class__
+ assert issubclass(c1, c2), "%r is not a subclass of %r" % (c1, c2)
if binary_table.c[col.name].type.length:
testing.eq_(col.type.length, binary_table.c[col.name].type.length)
- binary_table.drop()
def test_boolean(self):
"Exercise type specification for boolean type."
'BIT'),
]
- table_args = ['test_mssql_boolean', MetaData(testing.db)]
+ table_args = ['test_mssql_boolean', metadata]
for index, spec in enumerate(columns):
type_, args, kw, res = spec
table_args.append(Column('c%s' % index, type_(*args, **kw), nullable=None))
boolean_table = Table(*table_args)
- gen = testing.db.dialect.schemagenerator(testing.db.dialect, testing.db, None, None)
+ dialect = mssql.dialect()
+ gen = dialect.ddl_compiler(dialect, schema.CreateTable(boolean_table))
for col in boolean_table.c:
index = int(col.name[1:])
"%s %s" % (col.name, columns[index][3]))
self.assert_(repr(col))
- try:
- boolean_table.create(checkfirst=True)
- assert True
- except:
- raise
- boolean_table.drop()
+ metadata.create_all()
def test_numeric(self):
"Exercise type specification and options for numeric types."
columns = [
# column type, args, kwargs, expected ddl
(mssql.MSNumeric, [], {},
- 'NUMERIC(10, 2)'),
+ 'NUMERIC'),
(mssql.MSNumeric, [None], {},
'NUMERIC'),
- (mssql.MSNumeric, [12], {},
- 'NUMERIC(12, 2)'),
(mssql.MSNumeric, [12, 4], {},
'NUMERIC(12, 4)'),
- (mssql.MSFloat, [], {},
- 'FLOAT(10)'),
- (mssql.MSFloat, [None], {},
+ (types.Float, [], {},
+ 'FLOAT'),
+ (types.Float, [None], {},
'FLOAT'),
- (mssql.MSFloat, [12], {},
+ (types.Float, [12], {},
'FLOAT(12)'),
(mssql.MSReal, [], {},
'REAL'),
- (mssql.MSInteger, [], {},
+ (types.Integer, [], {},
'INTEGER'),
- (mssql.MSBigInteger, [], {},
+ (types.BigInteger, [], {},
'BIGINT'),
(mssql.MSTinyInteger, [], {},
'TINYINT'),
- (mssql.MSSmallInteger, [], {},
+ (types.SmallInteger, [], {},
'SMALLINT'),
]
- table_args = ['test_mssql_numeric', MetaData(testing.db)]
+ table_args = ['test_mssql_numeric', metadata]
for index, spec in enumerate(columns):
type_, args, kw, res = spec
table_args.append(Column('c%s' % index, type_(*args, **kw), nullable=None))
numeric_table = Table(*table_args)
- gen = testing.db.dialect.schemagenerator(testing.db.dialect, testing.db, None, None)
+ dialect = mssql.dialect()
+ gen = dialect.ddl_compiler(dialect, schema.CreateTable(numeric_table))
for col in numeric_table.c:
index = int(col.name[1:])
"%s %s" % (col.name, columns[index][3]))
self.assert_(repr(col))
- try:
- numeric_table.create(checkfirst=True)
- assert True
- except:
- raise
- numeric_table.drop()
+ metadata.create_all()
def test_char(self):
"""Exercise COLLATE-ish options on string types."""
- # modify the text_as_varchar setting since we are not testing that behavior here
- text_as_varchar = testing.db.dialect.text_as_varchar
- testing.db.dialect.text_as_varchar = False
-
columns = [
(mssql.MSChar, [], {},
'CHAR'),
'NTEXT COLLATE Latin1_General_CI_AS'),
]
- table_args = ['test_mssql_charset', MetaData(testing.db)]
+ table_args = ['test_mssql_charset', metadata]
for index, spec in enumerate(columns):
type_, args, kw, res = spec
table_args.append(Column('c%s' % index, type_(*args, **kw), nullable=None))
charset_table = Table(*table_args)
- gen = testing.db.dialect.schemagenerator(testing.db.dialect, testing.db, None, None)
+ dialect = mssql.dialect()
+ gen = dialect.ddl_compiler(dialect, schema.CreateTable(charset_table))
for col in charset_table.c:
index = int(col.name[1:])
"%s %s" % (col.name, columns[index][3]))
self.assert_(repr(col))
- try:
- charset_table.create(checkfirst=True)
- assert True
- except:
- raise
- charset_table.drop()
-
- testing.db.dialect.text_as_varchar = text_as_varchar
+ metadata.create_all()
def test_timestamp(self):
"""Exercise TIMESTAMP column."""
- meta = MetaData(testing.db)
-
- try:
- columns = [
- (TIMESTAMP,
- 'TIMESTAMP'),
- (mssql.MSTimeStamp,
- 'TIMESTAMP'),
- ]
- for idx, (spec, expected) in enumerate(columns):
- t = Table('mssql_ts%s' % idx, meta,
- Column('id', Integer, primary_key=True),
- Column('t', spec, nullable=None))
- testing.eq_(colspec(t.c.t), "t %s" % expected)
- self.assert_(repr(t.c.t))
- try:
- t.create(checkfirst=True)
- assert True
- except:
- raise
- t.drop()
- finally:
- meta.drop_all()
+ dialect = mssql.dialect()
+ spec, expected = (TIMESTAMP,'TIMESTAMP')
+ t = Table('mssql_ts', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('t', spec, nullable=None))
+ gen = dialect.ddl_compiler(dialect, schema.CreateTable(t))
+ testing.eq_(gen.get_column_specification(t.c.t), "t %s" % expected)
+ self.assert_(repr(t.c.t))
+ t.create(checkfirst=True)
+
def test_autoincrement(self):
- meta = MetaData(testing.db)
- try:
- Table('ai_1', meta,
- Column('int_y', Integer, primary_key=True),
- Column('int_n', Integer, DefaultClause('0'),
- primary_key=True))
- Table('ai_2', meta,
- Column('int_y', Integer, primary_key=True),
- Column('int_n', Integer, DefaultClause('0'),
- primary_key=True))
- Table('ai_3', meta,
- Column('int_n', Integer, DefaultClause('0'),
- primary_key=True, autoincrement=False),
- Column('int_y', Integer, primary_key=True))
- Table('ai_4', meta,
- Column('int_n', Integer, DefaultClause('0'),
- primary_key=True, autoincrement=False),
- Column('int_n2', Integer, DefaultClause('0'),
- primary_key=True, autoincrement=False))
- Table('ai_5', meta,
- Column('int_y', Integer, primary_key=True),
- Column('int_n', Integer, DefaultClause('0'),
- primary_key=True, autoincrement=False))
- Table('ai_6', meta,
- Column('o1', String(1), DefaultClause('x'),
- primary_key=True),
- Column('int_y', Integer, primary_key=True))
- Table('ai_7', meta,
- Column('o1', String(1), DefaultClause('x'),
- primary_key=True),
- Column('o2', String(1), DefaultClause('x'),
- primary_key=True),
- Column('int_y', Integer, primary_key=True))
- Table('ai_8', meta,
- Column('o1', String(1), DefaultClause('x'),
- primary_key=True),
- Column('o2', String(1), DefaultClause('x'),
- primary_key=True))
- meta.create_all()
-
- table_names = ['ai_1', 'ai_2', 'ai_3', 'ai_4',
- 'ai_5', 'ai_6', 'ai_7', 'ai_8']
- mr = MetaData(testing.db)
- mr.reflect(only=table_names)
-
- for tbl in [mr.tables[name] for name in table_names]:
- for c in tbl.c:
- if c.name.startswith('int_y'):
- assert c.autoincrement
- elif c.name.startswith('int_n'):
- assert not c.autoincrement
- tbl.insert().execute()
+ Table('ai_1', metadata,
+ Column('int_y', Integer, primary_key=True),
+ Column('int_n', Integer, DefaultClause('0'),
+ primary_key=True))
+ Table('ai_2', metadata,
+ Column('int_y', Integer, primary_key=True),
+ Column('int_n', Integer, DefaultClause('0'),
+ primary_key=True))
+ Table('ai_3', metadata,
+ Column('int_n', Integer, DefaultClause('0'),
+ primary_key=True, autoincrement=False),
+ Column('int_y', Integer, primary_key=True))
+ Table('ai_4', metadata,
+ Column('int_n', Integer, DefaultClause('0'),
+ primary_key=True, autoincrement=False),
+ Column('int_n2', Integer, DefaultClause('0'),
+ primary_key=True, autoincrement=False))
+ Table('ai_5', metadata,
+ Column('int_y', Integer, primary_key=True),
+ Column('int_n', Integer, DefaultClause('0'),
+ primary_key=True, autoincrement=False))
+ Table('ai_6', metadata,
+ Column('o1', String(1), DefaultClause('x'),
+ primary_key=True),
+ Column('int_y', Integer, primary_key=True))
+ Table('ai_7', metadata,
+ Column('o1', String(1), DefaultClause('x'),
+ primary_key=True),
+ Column('o2', String(1), DefaultClause('x'),
+ primary_key=True),
+ Column('int_y', Integer, primary_key=True))
+ Table('ai_8', metadata,
+ Column('o1', String(1), DefaultClause('x'),
+ primary_key=True),
+ Column('o2', String(1), DefaultClause('x'),
+ primary_key=True))
+ metadata.create_all()
+
+ table_names = ['ai_1', 'ai_2', 'ai_3', 'ai_4',
+ 'ai_5', 'ai_6', 'ai_7', 'ai_8']
+ mr = MetaData(testing.db)
+
+ for name in table_names:
+ tbl = Table(name, mr, autoload=True)
+ for c in tbl.c:
+ if c.name.startswith('int_y'):
+ assert c.autoincrement
+ elif c.name.startswith('int_n'):
+ assert not c.autoincrement
+
+ for counter, engine in enumerate([
+ engines.testing_engine(options={'implicit_returning':False}),
+ engines.testing_engine(options={'implicit_returning':True}),
+ ]
+ ):
+ engine.execute(tbl.insert())
if 'int_y' in tbl.c:
- assert select([tbl.c.int_y]).scalar() == 1
- assert list(tbl.select().execute().fetchone()).count(1) == 1
+ assert engine.scalar(select([tbl.c.int_y])) == counter + 1
+ assert list(engine.execute(tbl.select()).first()).count(counter + 1) == 1
else:
- assert 1 not in list(tbl.select().execute().fetchone())
- finally:
- meta.drop_all()
-
-def colspec(c):
- return testing.db.dialect.schemagenerator(testing.db.dialect,
- testing.db, None, None).get_column_specification(c)
-
+ assert 1 not in list(engine.execute(tbl.select()).first())
+ engine.execute(tbl.delete())
class BinaryTest(TestBase, AssertsExecutionResults):
"""Test the Binary and VarBinary types"""
+
+ __only_on__ = 'mssql'
+
@classmethod
def setup_class(cls):
global binary_table, MyPickleType
stream2 =self.load_stream('binary_data_two.dat')
binary_table.insert().execute(primary_id=1, misc='binary_data_one.dat', data=stream1, data_image=stream1, data_slice=stream1[0:100], pickled=testobj1, mypickle=testobj3)
binary_table.insert().execute(primary_id=2, misc='binary_data_two.dat', data=stream2, data_image=stream2, data_slice=stream2[0:99], pickled=testobj2)
+
+ # TODO: pyodbc does not seem to accept "None" for a VARBINARY column (data=None).
+ # error: [Microsoft][ODBC SQL Server Driver][SQL Server]Implicit conversion from
+ # data type varchar to varbinary is not allowed. Use the CONVERT function to run this query. (257)
+ #binary_table.insert().execute(primary_id=3, misc='binary_data_two.dat', data=None, data_image=None, data_slice=stream2[0:99], pickled=None)
binary_table.insert().execute(primary_id=3, misc='binary_data_two.dat', data_image=None, data_slice=stream2[0:99], pickled=None)
for stmt in (
from sqlalchemy.test.testing import eq_
+
+# Py2K
import sets
+# end Py2K
+
from sqlalchemy import *
from sqlalchemy import sql, exc
-from sqlalchemy.databases import mysql
+from sqlalchemy.dialects.mysql import base as mysql
from sqlalchemy.test.testing import eq_
from sqlalchemy.test import *
# column type, args, kwargs, expected ddl
# e.g. Column(Integer(10, unsigned=True)) == 'INTEGER(10) UNSIGNED'
(mysql.MSNumeric, [], {},
- 'NUMERIC(10, 2)'),
+ 'NUMERIC'),
(mysql.MSNumeric, [None], {},
'NUMERIC'),
(mysql.MSNumeric, [12], {},
- 'NUMERIC(12, 2)'),
+ 'NUMERIC(12)'),
(mysql.MSNumeric, [12, 4], {'unsigned':True},
'NUMERIC(12, 4) UNSIGNED'),
(mysql.MSNumeric, [12, 4], {'zerofill':True},
'NUMERIC(12, 4) UNSIGNED ZEROFILL'),
(mysql.MSDecimal, [], {},
- 'DECIMAL(10, 2)'),
+ 'DECIMAL'),
(mysql.MSDecimal, [None], {},
'DECIMAL'),
(mysql.MSDecimal, [12], {},
- 'DECIMAL(12, 2)'),
+ 'DECIMAL(12)'),
(mysql.MSDecimal, [12, None], {},
'DECIMAL(12)'),
(mysql.MSDecimal, [12, 4], {'unsigned':True},
table_args.append(Column('c%s' % index, type_(*args, **kw)))
numeric_table = Table(*table_args)
- gen = testing.db.dialect.schemagenerator(testing.db.dialect, testing.db, None, None)
+ gen = testing.db.dialect.ddl_compiler(testing.db.dialect, numeric_table)
for col in numeric_table.c:
index = int(col.name[1:])
- self.assert_eq(gen.get_column_specification(col),
+ eq_(gen.get_column_specification(col),
"%s %s" % (col.name, columns[index][3]))
self.assert_(repr(col))
table_args.append(Column('c%s' % index, type_(*args, **kw)))
charset_table = Table(*table_args)
- gen = testing.db.dialect.schemagenerator(testing.db.dialect, testing.db, None, None)
+ gen = testing.db.dialect.ddl_compiler(testing.db.dialect, charset_table)
for col in charset_table.c:
index = int(col.name[1:])
- self.assert_eq(gen.get_column_specification(col),
+ eq_(gen.get_column_specification(col),
"%s %s" % (col.name, columns[index][3]))
self.assert_(repr(col))
Column('b7', mysql.MSBit(63)),
Column('b8', mysql.MSBit(64)))
- self.assert_eq(colspec(bit_table.c.b1), 'b1 BIT')
- self.assert_eq(colspec(bit_table.c.b2), 'b2 BIT')
- self.assert_eq(colspec(bit_table.c.b3), 'b3 BIT NOT NULL')
- self.assert_eq(colspec(bit_table.c.b4), 'b4 BIT(1)')
- self.assert_eq(colspec(bit_table.c.b5), 'b5 BIT(8)')
- self.assert_eq(colspec(bit_table.c.b6), 'b6 BIT(32)')
- self.assert_eq(colspec(bit_table.c.b7), 'b7 BIT(63)')
- self.assert_eq(colspec(bit_table.c.b8), 'b8 BIT(64)')
+ eq_(colspec(bit_table.c.b1), 'b1 BIT')
+ eq_(colspec(bit_table.c.b2), 'b2 BIT')
+ eq_(colspec(bit_table.c.b3), 'b3 BIT NOT NULL')
+ eq_(colspec(bit_table.c.b4), 'b4 BIT(1)')
+ eq_(colspec(bit_table.c.b5), 'b5 BIT(8)')
+ eq_(colspec(bit_table.c.b6), 'b6 BIT(32)')
+ eq_(colspec(bit_table.c.b7), 'b7 BIT(63)')
+ eq_(colspec(bit_table.c.b8), 'b8 BIT(64)')
for col in bit_table.c:
self.assert_(repr(col))
def roundtrip(store, expected=None):
expected = expected or store
table.insert(store).execute()
- row = list(table.select().execute())[0]
+ row = table.select().execute().first()
try:
self.assert_(list(row) == expected)
except:
print "Expected %s" % expected
print "Found %s" % list(row)
raise
- table.delete().execute()
+ table.delete().execute().close()
roundtrip([0] * 8)
roundtrip([None, None, 0, None, None, None, None, None])
Column('b3', mysql.MSTinyInteger(1)),
Column('b4', mysql.MSTinyInteger))
- self.assert_eq(colspec(bool_table.c.b1), 'b1 BOOL')
- self.assert_eq(colspec(bool_table.c.b2), 'b2 BOOL')
- self.assert_eq(colspec(bool_table.c.b3), 'b3 TINYINT(1)')
- self.assert_eq(colspec(bool_table.c.b4), 'b4 TINYINT')
+ eq_(colspec(bool_table.c.b1), 'b1 BOOL')
+ eq_(colspec(bool_table.c.b2), 'b2 BOOL')
+ eq_(colspec(bool_table.c.b3), 'b3 TINYINT(1)')
+ eq_(colspec(bool_table.c.b4), 'b4 TINYINT')
for col in bool_table.c:
self.assert_(repr(col))
def roundtrip(store, expected=None):
expected = expected or store
table.insert(store).execute()
- row = list(table.select().execute())[0]
+ row = table.select().execute().first()
try:
self.assert_(list(row) == expected)
for i, val in enumerate(expected):
print "Expected %s" % expected
print "Found %s" % list(row)
raise
- table.delete().execute()
+ table.delete().execute().close()
roundtrip([None, None, None, None])
meta2 = MetaData(testing.db)
# replace with reflected
table = Table('mysql_bool', meta2, autoload=True)
- self.assert_eq(colspec(table.c.b3), 'b3 BOOL')
+ eq_(colspec(table.c.b3), 'b3 BOOL')
roundtrip([None, None, None, None])
roundtrip([True, True, 1, 1], [True, True, True, 1])
t = Table('mysql_ts%s' % idx, meta,
Column('id', Integer, primary_key=True),
Column('t', *spec))
- self.assert_eq(colspec(t.c.t), "t %s" % expected)
+ eq_(colspec(t.c.t), "t %s" % expected)
self.assert_(repr(t.c.t))
t.create()
r = Table('mysql_ts%s' % idx, MetaData(testing.db),
for table in year_table, reflected:
table.insert(['1950', '50', None, 50, 1950]).execute()
- row = list(table.select().execute())[0]
- self.assert_eq(list(row), [1950, 2050, None, 50, 1950])
+ row = table.select().execute().first()
+ eq_(list(row), [1950, 2050, None, 50, 1950])
table.delete().execute()
self.assert_(colspec(table.c.y1).startswith('y1 YEAR'))
- self.assert_eq(colspec(table.c.y4), 'y4 YEAR(2)')
- self.assert_eq(colspec(table.c.y5), 'y5 YEAR(4)')
+ eq_(colspec(table.c.y4), 'y4 YEAR(2)')
+ eq_(colspec(table.c.y5), 'y5 YEAR(4)')
finally:
meta.drop_all()
Column('s2', mysql.MSSet("'a'")),
Column('s3', mysql.MSSet("'5'", "'7'", "'9'")))
- self.assert_eq(colspec(set_table.c.s1), "s1 SET('dq','sq')")
- self.assert_eq(colspec(set_table.c.s2), "s2 SET('a')")
- self.assert_eq(colspec(set_table.c.s3), "s3 SET('5','7','9')")
+ eq_(colspec(set_table.c.s1), "s1 SET('dq','sq')")
+ eq_(colspec(set_table.c.s2), "s2 SET('a')")
+ eq_(colspec(set_table.c.s3), "s3 SET('5','7','9')")
for col in set_table.c:
self.assert_(repr(col))
def roundtrip(store, expected=None):
expected = expected or store
table.insert(store).execute()
- row = list(table.select().execute())[0]
+ row = table.select().execute().first()
try:
self.assert_(list(row) == expected)
except:
{'s3':set(['5', '7'])},
{'s3':set(['5', '7', '9'])},
{'s3':set(['7', '9'])})
- rows = list(select(
+ rows = select(
[set_table.c.s3],
- set_table.c.s3.in_([set(['5']), set(['5', '7'])])).execute())
+ set_table.c.s3.in_([set(['5']), set(['5', '7']), set(['7', '5'])])
+ ).execute().fetchall()
found = set([frozenset(row[0]) for row in rows])
- eq_(found,
- set([frozenset(['5']), frozenset(['5', '7'])]))
+ eq_(found, set([frozenset(['5']), frozenset(['5', '7'])]))
finally:
meta.drop_all()
Column('e6', mysql.MSEnum("'a'", "b")),
)
- self.assert_eq(colspec(enum_table.c.e1),
+ eq_(colspec(enum_table.c.e1),
"e1 ENUM('a','b')")
- self.assert_eq(colspec(enum_table.c.e2),
+ eq_(colspec(enum_table.c.e2),
"e2 ENUM('a','b') NOT NULL")
- self.assert_eq(colspec(enum_table.c.e3),
+ eq_(colspec(enum_table.c.e3),
"e3 ENUM('a','b')")
- self.assert_eq(colspec(enum_table.c.e4),
+ eq_(colspec(enum_table.c.e4),
"e4 ENUM('a','b') NOT NULL")
- self.assert_eq(colspec(enum_table.c.e5),
+ eq_(colspec(enum_table.c.e5),
"e5 ENUM('a','b')")
- self.assert_eq(colspec(enum_table.c.e6),
+ eq_(colspec(enum_table.c.e6),
"e6 ENUM('''a''','b')")
enum_table.drop(checkfirst=True)
enum_table.create()
# This is known to fail with MySQLDB 1.2.2 beta versions
# which return these as sets.Set(['a']), sets.Set(['b'])
# (even on Pythons with __builtin__.set)
- if testing.db.dialect.dbapi.version_info < (1, 2, 2, 'beta', 3) and \
- testing.db.dialect.dbapi.version_info >= (1, 2, 2):
+ if (not testing.against('+zxjdbc') and
+ testing.db.dialect.dbapi.version_info < (1, 2, 2, 'beta', 3) and
+ testing.db.dialect.dbapi.version_info >= (1, 2, 2)):
# these mysqldb seem to always uses 'sets', even on later pythons
import sets
def convert(value):
e.append(tuple([convert(c) for c in row]))
expected = e
- self.assert_eq(res, expected)
+ eq_(res, expected)
enum_table.drop()
@testing.exclude('mysql', '<', (4,), "3.23 can't handle an ENUM of ''")
finally:
enum_table.drop()
+
+
+class ReflectionTest(TestBase, AssertsExecutionResults):
+
+ __only_on__ = 'mysql'
+
def test_default_reflection(self):
"""Test reflection of column defaults."""
def_table = Table('mysql_def', MetaData(testing.db),
Column('c1', String(10), DefaultClause('')),
Column('c2', String(10), DefaultClause('0')),
- Column('c3', String(10), DefaultClause('abc')))
+ Column('c3', String(10), DefaultClause('abc')),
+ Column('c4', TIMESTAMP, DefaultClause('2009-04-05 12:00:00')),
+ Column('c5', TIMESTAMP, ),
+
+ )
+ def_table.create()
try:
- def_table.create()
reflected = Table('mysql_def', MetaData(testing.db),
- autoload=True)
- for t in def_table, reflected:
- assert t.c.c1.server_default.arg == ''
- assert t.c.c2.server_default.arg == '0'
- assert t.c.c3.server_default.arg == 'abc'
+ autoload=True)
finally:
def_table.drop()
+
+ assert def_table.c.c1.server_default.arg == ''
+ assert def_table.c.c2.server_default.arg == '0'
+ assert def_table.c.c3.server_default.arg == 'abc'
+ assert def_table.c.c4.server_default.arg == '2009-04-05 12:00:00'
+
+ assert str(reflected.c.c1.server_default.arg) == "''"
+ assert str(reflected.c.c2.server_default.arg) == "'0'"
+ assert str(reflected.c.c3.server_default.arg) == "'abc'"
+ assert str(reflected.c.c4.server_default.arg) == "'2009-04-05 12:00:00'"
+
+ reflected.create()
+ try:
+ reflected2 = Table('mysql_def', MetaData(testing.db), autoload=True)
+ finally:
+ reflected.drop()
+ assert str(reflected2.c.c1.server_default.arg) == "''"
+ assert str(reflected2.c.c2.server_default.arg) == "'0'"
+ assert str(reflected2.c.c3.server_default.arg) == "'abc'"
+ assert str(reflected2.c.c4.server_default.arg) == "'2009-04-05 12:00:00'"
+
def test_reflection_on_include_columns(self):
"""Test reflection of include_columns to be sure they respect case."""
( mysql.MSSmallInteger(4), mysql.MSSmallInteger(4), ),
( mysql.MSMediumInteger(), mysql.MSMediumInteger(), ),
( mysql.MSMediumInteger(8), mysql.MSMediumInteger(8), ),
- ( Binary(3), mysql.MSBlob(3), ),
- ( Binary(), mysql.MSBlob() ),
+ ( Binary(3), mysql.TINYBLOB(), ),
+ ( Binary(), mysql.BLOB() ),
( mysql.MSBinary(3), mysql.MSBinary(3), ),
( mysql.MSVarBinary(3),),
( mysql.MSVarBinary(), mysql.MSBlob()),
# in a view, e.g. char -> varchar, tinyblob -> mediumblob
#
# Not sure exactly which point version has the fix.
- if db.dialect.server_version_info(db.connect()) < (5, 0, 11):
+ if db.dialect.server_version_info < (5, 0, 11):
tables = rt,
else:
tables = rt, rv
for table in tables:
for i, reflected in enumerate(table.c):
- assert isinstance(reflected.type, type(expected[i]))
+ assert isinstance(reflected.type, type(expected[i])), \
+ "element %d: %r not instance of %r" % (i, reflected.type, type(expected[i]))
finally:
db.execute('DROP VIEW mysql_types_v')
finally:
tbl.insert().execute()
if 'int_y' in tbl.c:
assert select([tbl.c.int_y]).scalar() == 1
- assert list(tbl.select().execute().fetchone()).count(1) == 1
+ assert list(tbl.select().execute().first()).count(1) == 1
else:
- assert 1 not in list(tbl.select().execute().fetchone())
+ assert 1 not in list(tbl.select().execute().first())
finally:
meta.drop_all()
- def assert_eq(self, got, wanted):
- if got != wanted:
- print "Expected %s" % wanted
- print "Found %s" % got
- eq_(got, wanted)
class SQLTest(TestBase, AssertsCompiledSQL):
(m.MSBit, "t.col"),
# this is kind of sucky. thank you default arguments!
- (NUMERIC, "CAST(t.col AS DECIMAL(10, 2))"),
- (DECIMAL, "CAST(t.col AS DECIMAL(10, 2))"),
- (Numeric, "CAST(t.col AS DECIMAL(10, 2))"),
- (m.MSNumeric, "CAST(t.col AS DECIMAL(10, 2))"),
- (m.MSDecimal, "CAST(t.col AS DECIMAL(10, 2))"),
+ (NUMERIC, "CAST(t.col AS DECIMAL)"),
+ (DECIMAL, "CAST(t.col AS DECIMAL)"),
+ (Numeric, "CAST(t.col AS DECIMAL)"),
+ (m.MSNumeric, "CAST(t.col AS DECIMAL)"),
+ (m.MSDecimal, "CAST(t.col AS DECIMAL)"),
(FLOAT, "t.col"),
(Float, "t.col"),
(DateTime, "CAST(t.col AS DATETIME)"),
(Date, "CAST(t.col AS DATE)"),
(Time, "CAST(t.col AS TIME)"),
- (m.MSDateTime, "CAST(t.col AS DATETIME)"),
- (m.MSDate, "CAST(t.col AS DATE)"),
+ (DateTime, "CAST(t.col AS DATETIME)"),
+ (Date, "CAST(t.col AS DATE)"),
(m.MSTime, "CAST(t.col AS TIME)"),
(m.MSTimeStamp, "CAST(t.col AS DATETIME)"),
(m.MSYear, "t.col"),
class RawReflectionTest(TestBase):
def setup(self):
- self.dialect = mysql.dialect()
- self.reflector = mysql.MySQLSchemaReflector(
- self.dialect.identifier_preparer)
+ dialect = mysql.dialect()
+ self.parser = mysql.MySQLTableDefinitionParser(dialect, dialect.identifier_preparer)
def test_key_reflection(self):
- regex = self.reflector._re_key
+ regex = self.parser._re_key
assert regex.match(' PRIMARY KEY (`id`),')
assert regex.match(' PRIMARY KEY USING BTREE (`id`),')
cx = engine.connect()
meta = MetaData()
-
- assert ('mysql', 'charset') not in cx.info
- assert ('mysql', 'force_charset') not in cx.info
-
- cx.execute(text("SELECT 1")).fetchall()
- assert ('mysql', 'charset') not in cx.info
-
- meta.reflect(cx)
- assert ('mysql', 'charset') in cx.info
-
- cx.execute(text("SET @squiznart=123"))
- assert ('mysql', 'charset') in cx.info
-
- # the charset invalidation is very conservative
- cx.execute(text("SET TIMESTAMP = DEFAULT"))
- assert ('mysql', 'charset') not in cx.info
-
- cx.info[('mysql', 'force_charset')] = 'latin1'
-
- assert engine.dialect._detect_charset(cx) == 'latin1'
- assert cx.info[('mysql', 'charset')] == 'latin1'
-
- del cx.info[('mysql', 'force_charset')]
- del cx.info[('mysql', 'charset')]
+ charset = engine.dialect._detect_charset(cx)
meta.reflect(cx)
- assert ('mysql', 'charset') in cx.info
-
- # String execution doesn't go through the detector.
- cx.execute("SET TIMESTAMP = DEFAULT")
- assert ('mysql', 'charset') in cx.info
+ eq_(cx.dialect._connection_charset, charset)
+ cx.close()
class MatchTest(TestBase, AssertsCompiledSQL):
metadata.drop_all()
def test_expression(self):
+ format = testing.db.dialect.paramstyle == 'format' and '%s' or '?'
self.assert_compile(
matchtable.c.title.match('somstr'),
- "MATCH (matchtable.title) AGAINST (%s IN BOOLEAN MODE)")
+ "MATCH (matchtable.title) AGAINST (%s IN BOOLEAN MODE)" % format)
def test_simple_match(self):
results = (matchtable.select().
def colspec(c):
- return testing.db.dialect.schemagenerator(testing.db.dialect,
- testing.db, None, None).get_column_specification(c)
+ return testing.db.dialect.ddl_compiler(testing.db.dialect, c.table).get_column_specification(c)
from sqlalchemy.test.testing import eq_
from sqlalchemy import *
+from sqlalchemy import types as sqltypes
from sqlalchemy.sql import table, column
-from sqlalchemy.databases import oracle
from sqlalchemy.test import *
from sqlalchemy.test.testing import eq_
from sqlalchemy.test.engines import testing_engine
+from sqlalchemy.dialects.oracle import cx_oracle, base as oracle
from sqlalchemy.engine import default
+from sqlalchemy.util import jython
import os
meta = MetaData()
parent = Table('parent', meta, Column('id', Integer, primary_key=True),
Column('name', String(50)),
- owner='ed')
+ schema='ed')
child = Table('child', meta, Column('id', Integer, primary_key=True),
Column('parent_id', Integer, ForeignKey('ed.parent.id')),
- owner = 'ed')
+ schema = 'ed')
self.assert_compile(parent.join(child), "ed.parent JOIN ed.child ON ed.parent.id = ed.child.parent_id")
b = bindparam("foo", u"hello world!")
assert b.type.dialect_impl(dialect).get_dbapi_type(dbapi) == 'STRING'
+ def test_type_adapt(self):
+ dialect = cx_oracle.dialect()
+
+ for start, test in [
+ (DateTime(), cx_oracle._OracleDateTime),
+ (TIMESTAMP(), cx_oracle._OracleTimestamp),
+ (oracle.OracleRaw(), cx_oracle._OracleRaw),
+ (String(), String),
+ (VARCHAR(), VARCHAR),
+ (String(50), String),
+ (Unicode(), Unicode),
+ (Text(), cx_oracle._OracleText),
+ (UnicodeText(), cx_oracle._OracleUnicodeText),
+ (NCHAR(), NCHAR),
+ (oracle.RAW(50), cx_oracle._OracleRaw),
+ ]:
+ assert isinstance(start.dialect_impl(dialect), test), "wanted %r got %r" % (test, start.dialect_impl(dialect))
+
+
def test_reflect_raw(self):
types_table = Table(
'all_types', MetaData(testing.db),
def test_reflect_nvarchar(self):
metadata = MetaData(testing.db)
t = Table('t', metadata,
- Column('data', oracle.OracleNVarchar(255))
+ Column('data', sqltypes.NVARCHAR(255))
)
metadata.create_all()
try:
m2 = MetaData(testing.db)
t2 = Table('t', m2, autoload=True)
- assert isinstance(t2.c.data.type, oracle.OracleNVarchar)
+ assert isinstance(t2.c.data.type, sqltypes.NVARCHAR)
data = u'm’a réveillé.'
t2.insert().execute(data=data)
- eq_(t2.select().execute().fetchone()['data'], data)
+ eq_(t2.select().execute().first()['data'], data)
finally:
metadata.drop_all()
t.create(engine)
try:
engine.execute(t.insert(), id=1, data='this is text', bindata='this is binary')
- row = engine.execute(t.select()).fetchone()
+ row = engine.execute(t.select()).first()
eq_(row['data'].read(), 'this is text')
eq_(row['bindata'].read(), 'this is binary')
finally:
Column('data', Binary)
)
meta.create_all()
-
stream = os.path.join(os.path.dirname(__file__), "..", 'binary_data_one.dat')
stream = file(stream).read(12000)
meta.drop_all()
def test_fetch(self):
- eq_(
- binary_table.select().execute().fetchall() ,
- [(i, stream) for i in range(1, 11)],
- )
+ result = binary_table.select().execute().fetchall()
+ if jython:
+ result = [(i, value.tostring()) for i, value in result]
+ eq_(result, [(i, stream) for i in range(1, 11)])
+ @testing.fails_on('+zxjdbc', 'FIXME: zxjdbc should support this')
def test_fetch_single_arraysize(self):
eng = testing_engine(options={'arraysize':1})
- eq_(
- eng.execute(binary_table.select()).fetchall(),
- [(i, stream) for i in range(1, 11)],
- )
+ result = eng.execute(binary_table.select()).fetchall(),
+ if jython:
+ result = [(i, value.tostring()) for i, value in result]
+ eq_(result, [(i, stream) for i in range(1, 11)])
class SequenceTest(TestBase, AssertsCompiledSQL):
def test_basic(self):
from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
+from sqlalchemy.test import engines
import datetime
from sqlalchemy import *
from sqlalchemy.orm import *
-from sqlalchemy import exc
-from sqlalchemy.databases import postgres
+from sqlalchemy import exc, schema
+from sqlalchemy.dialects.postgresql import base as postgresql
from sqlalchemy.engine.strategies import MockEngineStrategy
from sqlalchemy.test import *
from sqlalchemy.sql import table, column
-
+from sqlalchemy.test.testing import eq_
class SequenceTest(TestBase, AssertsCompiledSQL):
def test_basic(self):
seq = Sequence("my_seq_no_schema")
- dialect = postgres.PGDialect()
+ dialect = postgresql.PGDialect()
assert dialect.identifier_preparer.format_sequence(seq) == "my_seq_no_schema"
seq = Sequence("my_seq", schema="some_schema")
assert dialect.identifier_preparer.format_sequence(seq) == '"Some_Schema"."My_Seq"'
class CompileTest(TestBase, AssertsCompiledSQL):
- __dialect__ = postgres.dialect()
+ __dialect__ = postgresql.dialect()
def test_update_returning(self):
- dialect = postgres.dialect()
+ dialect = postgresql.dialect()
table1 = table('mytable',
column('myid', Integer),
column('name', String(128)),
column('description', String(128)),
)
- u = update(table1, values=dict(name='foo'), postgres_returning=[table1.c.myid, table1.c.name])
+ u = update(table1, values=dict(name='foo')).returning(table1.c.myid, table1.c.name)
self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING mytable.myid, mytable.name", dialect=dialect)
- u = update(table1, values=dict(name='foo'), postgres_returning=[table1])
+ u = update(table1, values=dict(name='foo')).returning(table1)
self.assert_compile(u, "UPDATE mytable SET name=%(name)s "\
"RETURNING mytable.myid, mytable.name, mytable.description", dialect=dialect)
- u = update(table1, values=dict(name='foo'), postgres_returning=[func.length(table1.c.name)])
- self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING length(mytable.name)", dialect=dialect)
+ u = update(table1, values=dict(name='foo')).returning(func.length(table1.c.name))
+ self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING length(mytable.name) AS length_1", dialect=dialect)
+
def test_insert_returning(self):
- dialect = postgres.dialect()
+ dialect = postgresql.dialect()
table1 = table('mytable',
column('myid', Integer),
column('name', String(128)),
column('description', String(128)),
)
- i = insert(table1, values=dict(name='foo'), postgres_returning=[table1.c.myid, table1.c.name])
+ i = insert(table1, values=dict(name='foo')).returning(table1.c.myid, table1.c.name)
self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING mytable.myid, mytable.name", dialect=dialect)
- i = insert(table1, values=dict(name='foo'), postgres_returning=[table1])
+ i = insert(table1, values=dict(name='foo')).returning(table1)
self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) "\
"RETURNING mytable.myid, mytable.name, mytable.description", dialect=dialect)
- i = insert(table1, values=dict(name='foo'), postgres_returning=[func.length(table1.c.name)])
- self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING length(mytable.name)", dialect=dialect)
+ i = insert(table1, values=dict(name='foo')).returning(func.length(table1.c.name))
+ self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING length(mytable.name) AS length_1", dialect=dialect)
+
+ @testing.uses_deprecated(r".*argument is deprecated. Please use statement.returning.*")
+ def test_old_returning_names(self):
+ dialect = postgresql.dialect()
+ table1 = table('mytable',
+ column('myid', Integer),
+ column('name', String(128)),
+ column('description', String(128)),
+ )
+
+ u = update(table1, values=dict(name='foo'), postgres_returning=[table1.c.myid, table1.c.name])
+ self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING mytable.myid, mytable.name", dialect=dialect)
+
+ u = update(table1, values=dict(name='foo'), postgresql_returning=[table1.c.myid, table1.c.name])
+ self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING mytable.myid, mytable.name", dialect=dialect)
+
+ i = insert(table1, values=dict(name='foo'), postgres_returning=[table1.c.myid, table1.c.name])
+ self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING mytable.myid, mytable.name", dialect=dialect)
+
+ def test_create_partial_index(self):
+ tbl = Table('testtbl', MetaData(), Column('data',Integer))
+ idx = Index('test_idx1', tbl.c.data, postgresql_where=and_(tbl.c.data > 5, tbl.c.data < 10))
+
+ self.assert_compile(schema.CreateIndex(idx),
+ "CREATE INDEX test_idx1 ON testtbl (data) WHERE testtbl.data > 5 AND testtbl.data < 10", dialect=postgresql.dialect())
+
+ @testing.uses_deprecated(r".*'postgres_where' argument has been renamed.*")
+ def test_old_create_partial_index(self):
+ tbl = Table('testtbl', MetaData(), Column('data',Integer))
+ idx = Index('test_idx1', tbl.c.data, postgres_where=and_(tbl.c.data > 5, tbl.c.data < 10))
+
+ self.assert_compile(schema.CreateIndex(idx),
+ "CREATE INDEX test_idx1 ON testtbl (data) WHERE testtbl.data > 5 AND testtbl.data < 10", dialect=postgresql.dialect())
def test_extract(self):
t = table('t', column('col1'))
"FROM t" % field)
-class ReturningTest(TestBase, AssertsExecutionResults):
- __only_on__ = 'postgres'
-
- @testing.exclude('postgres', '<', (8, 2), '8.3+ feature')
- def test_update_returning(self):
- meta = MetaData(testing.db)
- table = Table('tables', meta,
- Column('id', Integer, primary_key=True),
- Column('persons', Integer),
- Column('full', Boolean)
- )
- table.create()
- try:
- table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}])
-
- result = table.update(table.c.persons > 4, dict(full=True), postgres_returning=[table.c.id]).execute()
- eq_(result.fetchall(), [(1,)])
-
- result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute()
- eq_(result2.fetchall(), [(1,True),(2,False)])
- finally:
- table.drop()
-
- @testing.exclude('postgres', '<', (8, 2), '8.3+ feature')
- def test_insert_returning(self):
- meta = MetaData(testing.db)
- table = Table('tables', meta,
- Column('id', Integer, primary_key=True),
- Column('persons', Integer),
- Column('full', Boolean)
- )
- table.create()
- try:
- result = table.insert(postgres_returning=[table.c.id]).execute({'persons': 1, 'full': False})
-
- eq_(result.fetchall(), [(1,)])
-
- @testing.fails_on('postgres', 'Known limitation of psycopg2')
- def test_executemany():
- # return value is documented as failing with psycopg2/executemany
- result2 = table.insert(postgres_returning=[table]).execute(
- [{'persons': 2, 'full': False}, {'persons': 3, 'full': True}])
- eq_(result2.fetchall(), [(2, 2, False), (3,3,True)])
-
- test_executemany()
-
- result3 = table.insert(postgres_returning=[(table.c.id*2).label('double_id')]).execute({'persons': 4, 'full': False})
- eq_([dict(row) for row in result3], [{'double_id':8}])
-
- result4 = testing.db.execute('insert into tables (id, persons, "full") values (5, 10, true) returning persons')
- eq_([dict(row) for row in result4], [{'persons': 10}])
- finally:
- table.drop()
-
-
class InsertTest(TestBase, AssertsExecutionResults):
- __only_on__ = 'postgres'
+ __only_on__ = 'postgresql'
@classmethod
def setup_class(cls):
global metadata
+ cls.engine= testing.db
metadata = MetaData(testing.db)
def teardown(self):
metadata.drop_all()
metadata.tables.clear()
+ if self.engine is not testing.db:
+ self.engine.dispose()
def test_compiled_insert(self):
table = Table('testtable', metadata,
metadata.create_all()
- ins = table.insert(values={'data':bindparam('x')}).compile()
+ ins = table.insert(inline=True, values={'data':bindparam('x')}).compile()
ins.execute({'x':"five"}, {'x':"seven"})
assert table.select().execute().fetchall() == [(1, 'five'), (2, 'seven')]
metadata.create_all()
self._assert_data_with_sequence(table, "my_seq")
+ def test_sequence_returning_insert(self):
+ table = Table('testtable', metadata,
+ Column('id', Integer, Sequence('my_seq'), primary_key=True),
+ Column('data', String(30)))
+ metadata.create_all()
+ self._assert_data_with_sequence_returning(table, "my_seq")
+
def test_opt_sequence_insert(self):
table = Table('testtable', metadata,
Column('id', Integer, Sequence('my_seq', optional=True), primary_key=True),
metadata.create_all()
self._assert_data_autoincrement(table)
+ def test_opt_sequence_returning_insert(self):
+ table = Table('testtable', metadata,
+ Column('id', Integer, Sequence('my_seq', optional=True), primary_key=True),
+ Column('data', String(30)))
+ metadata.create_all()
+ self._assert_data_autoincrement_returning(table)
+
def test_autoincrement_insert(self):
table = Table('testtable', metadata,
Column('id', Integer, primary_key=True),
metadata.create_all()
self._assert_data_autoincrement(table)
+ def test_autoincrement_returning_insert(self):
+ table = Table('testtable', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', String(30)))
+ metadata.create_all()
+ self._assert_data_autoincrement_returning(table)
+
def test_noautoincrement_insert(self):
table = Table('testtable', metadata,
Column('id', Integer, primary_key=True, autoincrement=False),
self._assert_data_noautoincrement(table)
def _assert_data_autoincrement(self, table):
+ self.engine = engines.testing_engine(options={'implicit_returning':False})
+ metadata.bind = self.engine
+
def go():
# execute with explicit id
r = table.insert().execute({'id':30, 'data':'d1'})
- assert r.last_inserted_ids() == [30]
+ assert r.inserted_primary_key == [30]
# execute with prefetch id
r = table.insert().execute({'data':'d2'})
- assert r.last_inserted_ids() == [1]
+ assert r.inserted_primary_key == [1]
# executemany with explicit ids
table.insert().execute({'id':31, 'data':'d3'}, {'id':32, 'data':'d4'})
# note that the test framework doesnt capture the "preexecute" of a seqeuence
# or default. we just see it in the bind params.
- self.assert_sql(testing.db, go, [], with_sequences=[
+ self.assert_sql(self.engine, go, [], with_sequences=[
(
"INSERT INTO testtable (id, data) VALUES (:id, :data)",
{'id':30, 'data':'d1'}
# test the same series of events using a reflected
# version of the table
- m2 = MetaData(testing.db)
+ m2 = MetaData(self.engine)
table = Table(table.name, m2, autoload=True)
def go():
table.insert().execute({'id':30, 'data':'d1'})
r = table.insert().execute({'data':'d2'})
- assert r.last_inserted_ids() == [5]
+ assert r.inserted_primary_key == [5]
table.insert().execute({'id':31, 'data':'d3'}, {'id':32, 'data':'d4'})
table.insert().execute({'data':'d5'}, {'data':'d6'})
table.insert(inline=True).execute({'id':33, 'data':'d7'})
table.insert(inline=True).execute({'data':'d8'})
- self.assert_sql(testing.db, go, [], with_sequences=[
+ self.assert_sql(self.engine, go, [], with_sequences=[
(
"INSERT INTO testtable (id, data) VALUES (:id, :data)",
{'id':30, 'data':'d1'}
]
table.delete().execute()
+ def _assert_data_autoincrement_returning(self, table):
+ self.engine = engines.testing_engine(options={'implicit_returning':True})
+ metadata.bind = self.engine
+
+ def go():
+ # execute with explicit id
+ r = table.insert().execute({'id':30, 'data':'d1'})
+ assert r.inserted_primary_key == [30]
+
+ # execute with prefetch id
+ r = table.insert().execute({'data':'d2'})
+ assert r.inserted_primary_key == [1]
+
+ # executemany with explicit ids
+ table.insert().execute({'id':31, 'data':'d3'}, {'id':32, 'data':'d4'})
+
+ # executemany, uses SERIAL
+ table.insert().execute({'data':'d5'}, {'data':'d6'})
+
+ # single execute, explicit id, inline
+ table.insert(inline=True).execute({'id':33, 'data':'d7'})
+
+ # single execute, inline, uses SERIAL
+ table.insert(inline=True).execute({'data':'d8'})
+
+ self.assert_sql(self.engine, go, [], with_sequences=[
+ (
+ "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+ {'id':30, 'data':'d1'}
+ ),
+ (
+ "INSERT INTO testtable (data) VALUES (:data) RETURNING testtable.id",
+ {'data': 'd2'}
+ ),
+ (
+ "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+ [{'id':31, 'data':'d3'}, {'id':32, 'data':'d4'}]
+ ),
+ (
+ "INSERT INTO testtable (data) VALUES (:data)",
+ [{'data':'d5'}, {'data':'d6'}]
+ ),
+ (
+ "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+ [{'id':33, 'data':'d7'}]
+ ),
+ (
+ "INSERT INTO testtable (data) VALUES (:data)",
+ [{'data':'d8'}]
+ ),
+ ])
+
+ assert table.select().execute().fetchall() == [
+ (30, 'd1'),
+ (1, 'd2'),
+ (31, 'd3'),
+ (32, 'd4'),
+ (2, 'd5'),
+ (3, 'd6'),
+ (33, 'd7'),
+ (4, 'd8'),
+ ]
+ table.delete().execute()
+
+ # test the same series of events using a reflected
+ # version of the table
+ m2 = MetaData(self.engine)
+ table = Table(table.name, m2, autoload=True)
+
+ def go():
+ table.insert().execute({'id':30, 'data':'d1'})
+ r = table.insert().execute({'data':'d2'})
+ assert r.inserted_primary_key == [5]
+ table.insert().execute({'id':31, 'data':'d3'}, {'id':32, 'data':'d4'})
+ table.insert().execute({'data':'d5'}, {'data':'d6'})
+ table.insert(inline=True).execute({'id':33, 'data':'d7'})
+ table.insert(inline=True).execute({'data':'d8'})
+
+ self.assert_sql(self.engine, go, [], with_sequences=[
+ (
+ "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+ {'id':30, 'data':'d1'}
+ ),
+ (
+ "INSERT INTO testtable (data) VALUES (:data) RETURNING testtable.id",
+ {'data':'d2'}
+ ),
+ (
+ "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+ [{'id':31, 'data':'d3'}, {'id':32, 'data':'d4'}]
+ ),
+ (
+ "INSERT INTO testtable (data) VALUES (:data)",
+ [{'data':'d5'}, {'data':'d6'}]
+ ),
+ (
+ "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+ [{'id':33, 'data':'d7'}]
+ ),
+ (
+ "INSERT INTO testtable (data) VALUES (:data)",
+ [{'data':'d8'}]
+ ),
+ ])
+
+ assert table.select().execute().fetchall() == [
+ (30, 'd1'),
+ (5, 'd2'),
+ (31, 'd3'),
+ (32, 'd4'),
+ (6, 'd5'),
+ (7, 'd6'),
+ (33, 'd7'),
+ (8, 'd8'),
+ ]
+ table.delete().execute()
+
def _assert_data_with_sequence(self, table, seqname):
+ self.engine = engines.testing_engine(options={'implicit_returning':False})
+ metadata.bind = self.engine
+
def go():
table.insert().execute({'id':30, 'data':'d1'})
table.insert().execute({'data':'d2'})
table.insert(inline=True).execute({'id':33, 'data':'d7'})
table.insert(inline=True).execute({'data':'d8'})
- self.assert_sql(testing.db, go, [], with_sequences=[
+ self.assert_sql(self.engine, go, [], with_sequences=[
(
"INSERT INTO testtable (id, data) VALUES (:id, :data)",
{'id':30, 'data':'d1'}
# cant test reflection here since the Sequence must be
# explicitly specified
+ def _assert_data_with_sequence_returning(self, table, seqname):
+ self.engine = engines.testing_engine(options={'implicit_returning':True})
+ metadata.bind = self.engine
+
+ def go():
+ table.insert().execute({'id':30, 'data':'d1'})
+ table.insert().execute({'data':'d2'})
+ table.insert().execute({'id':31, 'data':'d3'}, {'id':32, 'data':'d4'})
+ table.insert().execute({'data':'d5'}, {'data':'d6'})
+ table.insert(inline=True).execute({'id':33, 'data':'d7'})
+ table.insert(inline=True).execute({'data':'d8'})
+
+ self.assert_sql(self.engine, go, [], with_sequences=[
+ (
+ "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+ {'id':30, 'data':'d1'}
+ ),
+ (
+ "INSERT INTO testtable (id, data) VALUES (nextval('my_seq'), :data) RETURNING testtable.id",
+ {'data':'d2'}
+ ),
+ (
+ "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+ [{'id':31, 'data':'d3'}, {'id':32, 'data':'d4'}]
+ ),
+ (
+ "INSERT INTO testtable (id, data) VALUES (nextval('%s'), :data)" % seqname,
+ [{'data':'d5'}, {'data':'d6'}]
+ ),
+ (
+ "INSERT INTO testtable (id, data) VALUES (:id, :data)",
+ [{'id':33, 'data':'d7'}]
+ ),
+ (
+ "INSERT INTO testtable (id, data) VALUES (nextval('%s'), :data)" % seqname,
+ [{'data':'d8'}]
+ ),
+ ])
+
+ assert table.select().execute().fetchall() == [
+ (30, 'd1'),
+ (1, 'd2'),
+ (31, 'd3'),
+ (32, 'd4'),
+ (2, 'd5'),
+ (3, 'd6'),
+ (33, 'd7'),
+ (4, 'd8'),
+ ]
+
+ # cant test reflection here since the Sequence must be
+ # explicitly specified
+
def _assert_data_noautoincrement(self, table):
+ self.engine = engines.testing_engine(options={'implicit_returning':False})
+ metadata.bind = self.engine
+
table.insert().execute({'id':30, 'data':'d1'})
- try:
- table.insert().execute({'data':'d2'})
- assert False
- except exc.IntegrityError, e:
- assert "violates not-null constraint" in str(e)
- try:
- table.insert().execute({'data':'d2'}, {'data':'d3'})
- assert False
- except exc.IntegrityError, e:
- assert "violates not-null constraint" in str(e)
+
+ if self.engine.driver == 'pg8000':
+ exception_cls = exc.ProgrammingError
+ else:
+ exception_cls = exc.IntegrityError
+
+ assert_raises_message(exception_cls, "violates not-null constraint", table.insert().execute, {'data':'d2'})
+ assert_raises_message(exception_cls, "violates not-null constraint", table.insert().execute, {'data':'d2'}, {'data':'d3'})
+
+ assert_raises_message(exception_cls, "violates not-null constraint", table.insert().execute, {'data':'d2'})
+
+ assert_raises_message(exception_cls, "violates not-null constraint", table.insert().execute, {'data':'d2'}, {'data':'d3'})
table.insert().execute({'id':31, 'data':'d2'}, {'id':32, 'data':'d3'})
table.insert(inline=True).execute({'id':33, 'data':'d4'})
# test the same series of events using a reflected
# version of the table
- m2 = MetaData(testing.db)
+ m2 = MetaData(self.engine)
table = Table(table.name, m2, autoload=True)
table.insert().execute({'id':30, 'data':'d1'})
- try:
- table.insert().execute({'data':'d2'})
- assert False
- except exc.IntegrityError, e:
- assert "violates not-null constraint" in str(e)
- try:
- table.insert().execute({'data':'d2'}, {'data':'d3'})
- assert False
- except exc.IntegrityError, e:
- assert "violates not-null constraint" in str(e)
+
+ assert_raises_message(exception_cls, "violates not-null constraint", table.insert().execute, {'data':'d2'})
+ assert_raises_message(exception_cls, "violates not-null constraint", table.insert().execute, {'data':'d2'}, {'data':'d3'})
table.insert().execute({'id':31, 'data':'d2'}, {'id':32, 'data':'d3'})
table.insert(inline=True).execute({'id':33, 'data':'d4'})
class DomainReflectionTest(TestBase, AssertsExecutionResults):
"Test PostgreSQL domains"
- __only_on__ = 'postgres'
+ __only_on__ = 'postgresql'
@classmethod
def setup_class(cls):
con = testing.db.connect()
for ddl in ('CREATE DOMAIN testdomain INTEGER NOT NULL DEFAULT 42',
- 'CREATE DOMAIN alt_schema.testdomain INTEGER DEFAULT 0'):
+ 'CREATE DOMAIN test_schema.testdomain INTEGER DEFAULT 0'):
try:
con.execute(ddl)
except exc.SQLError, e:
if not "already exists" in str(e):
raise e
con.execute('CREATE TABLE testtable (question integer, answer testdomain)')
- con.execute('CREATE TABLE alt_schema.testtable(question integer, answer alt_schema.testdomain, anything integer)')
- con.execute('CREATE TABLE crosschema (question integer, answer alt_schema.testdomain)')
+ con.execute('CREATE TABLE test_schema.testtable(question integer, answer test_schema.testdomain, anything integer)')
+ con.execute('CREATE TABLE crosschema (question integer, answer test_schema.testdomain)')
@classmethod
def teardown_class(cls):
con = testing.db.connect()
con.execute('DROP TABLE testtable')
- con.execute('DROP TABLE alt_schema.testtable')
+ con.execute('DROP TABLE test_schema.testtable')
con.execute('DROP TABLE crosschema')
con.execute('DROP DOMAIN testdomain')
- con.execute('DROP DOMAIN alt_schema.testdomain')
+ con.execute('DROP DOMAIN test_schema.testdomain')
def test_table_is_reflected(self):
metadata = MetaData(testing.db)
table = Table('testtable', metadata, autoload=True)
eq_(set(table.columns.keys()), set(['question', 'answer']), "Columns of reflected table didn't equal expected columns")
- eq_(table.c.answer.type.__class__, postgres.PGInteger)
+ assert isinstance(table.c.answer.type, Integer)
def test_domain_is_reflected(self):
metadata = MetaData(testing.db)
eq_(str(table.columns.answer.server_default.arg), '42', "Reflected default value didn't equal expected value")
assert not table.columns.answer.nullable, "Expected reflected column to not be nullable."
- def test_table_is_reflected_alt_schema(self):
+ def test_table_is_reflected_test_schema(self):
metadata = MetaData(testing.db)
- table = Table('testtable', metadata, autoload=True, schema='alt_schema')
+ table = Table('testtable', metadata, autoload=True, schema='test_schema')
eq_(set(table.columns.keys()), set(['question', 'answer', 'anything']), "Columns of reflected table didn't equal expected columns")
- eq_(table.c.anything.type.__class__, postgres.PGInteger)
+ assert isinstance(table.c.anything.type, Integer)
def test_schema_domain_is_reflected(self):
metadata = MetaData(testing.db)
- table = Table('testtable', metadata, autoload=True, schema='alt_schema')
+ table = Table('testtable', metadata, autoload=True, schema='test_schema')
eq_(str(table.columns.answer.server_default.arg), '0', "Reflected default value didn't equal expected value")
assert table.columns.answer.nullable, "Expected reflected column to be nullable."
assert table.columns.answer.nullable, "Expected reflected column to be nullable."
def test_unknown_types(self):
- from sqlalchemy.databases import postgres
+ from sqlalchemy.databases import postgresql
- ischema_names = postgres.ischema_names
- postgres.ischema_names = {}
+ ischema_names = postgresql.PGDialect.ischema_names
+ postgresql.PGDialect.ischema_names = {}
try:
m2 = MetaData(testing.db)
assert_raises(exc.SAWarning, Table, "testtable", m2, autoload=True)
assert t3.c.answer.type.__class__ == sa.types.NullType
finally:
- postgres.ischema_names = ischema_names
+ postgresql.PGDialect.ischema_names = ischema_names
-class MiscTest(TestBase, AssertsExecutionResults):
- __only_on__ = 'postgres'
+class MiscTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL):
+ __only_on__ = 'postgresql'
def test_date_reflection(self):
m1 = MetaData(testing.db)
'FROM mytable')
def test_schema_reflection(self):
- """note: this test requires that the 'alt_schema' schema be separate and accessible by the test user"""
+ """note: this test requires that the 'test_schema' schema be separate and accessible by the test user"""
meta1 = MetaData(testing.db)
users = Table('users', meta1,
Column('user_id', Integer, primary_key = True),
Column('user_name', String(30), nullable = False),
- schema="alt_schema"
+ schema="test_schema"
)
addresses = Table('email_addresses', meta1,
Column('address_id', Integer, primary_key = True),
Column('remote_user_id', Integer, ForeignKey(users.c.user_id)),
Column('email_address', String(20)),
- schema="alt_schema"
+ schema="test_schema"
)
meta1.create_all()
try:
meta2 = MetaData(testing.db)
- addresses = Table('email_addresses', meta2, autoload=True, schema="alt_schema")
- users = Table('users', meta2, mustexist=True, schema="alt_schema")
+ addresses = Table('email_addresses', meta2, autoload=True, schema="test_schema")
+ users = Table('users', meta2, mustexist=True, schema="test_schema")
print users
print addresses
referer = Table("referer", meta1,
Column("id", Integer, primary_key=True),
Column("ref", Integer, ForeignKey('subject.id')),
- schema="alt_schema")
+ schema="test_schema")
meta1.create_all()
try:
meta2 = MetaData(testing.db)
subject = Table("subject", meta2, autoload=True)
- referer = Table("referer", meta2, schema="alt_schema", autoload=True)
+ referer = Table("referer", meta2, schema="test_schema", autoload=True)
print str(subject.join(referer).onclause)
self.assert_((subject.c.id==referer.c.ref).compare(subject.join(referer).onclause))
finally:
meta1 = MetaData(testing.db)
subject = Table("subject", meta1,
Column("id", Integer, primary_key=True),
- schema='alt_schema_2'
+ schema='test_schema_2'
)
referer = Table("referer", meta1,
Column("id", Integer, primary_key=True),
- Column("ref", Integer, ForeignKey('alt_schema_2.subject.id')),
- schema="alt_schema")
+ Column("ref", Integer, ForeignKey('test_schema_2.subject.id')),
+ schema="test_schema")
meta1.create_all()
try:
meta2 = MetaData(testing.db)
- subject = Table("subject", meta2, autoload=True, schema="alt_schema_2")
- referer = Table("referer", meta2, schema="alt_schema", autoload=True)
+ subject = Table("subject", meta2, autoload=True, schema="test_schema_2")
+ referer = Table("referer", meta2, schema="test_schema", autoload=True)
print str(subject.join(referer).onclause)
self.assert_((subject.c.id==referer.c.ref).compare(subject.join(referer).onclause))
finally:
meta = MetaData(testing.db)
users = Table('users', meta,
Column('id', Integer, primary_key=True),
- Column('name', String(50)), schema='alt_schema')
+ Column('name', String(50)), schema='test_schema')
users.create()
try:
users.insert().execute(id=1, name='name1')
user_name VARCHAR NOT NULL,
user_password VARCHAR NOT NULL
);
- """, None)
+ """)
t = Table("speedy_users", meta, autoload=True)
r = t.insert().execute(user_name='user', user_password='lala')
- assert r.last_inserted_ids() == [1]
+ assert r.inserted_primary_key == [1]
l = t.select().execute().fetchall()
assert l == [(1, 'user', 'lala')]
finally:
- testing.db.execute("drop table speedy_users", None)
+ testing.db.execute("drop table speedy_users")
@testing.emits_warning()
def test_index_reflection(self):
testing.db.execute("""
create index idx1 on party ((id || name))
- """, None)
+ """)
testing.db.execute("""
create unique index idx2 on party (id) where name = 'test'
- """, None)
+ """)
testing.db.execute("""
create index idx3 on party using btree
warnings.warn = capture_warnings._orig_showwarning
m1.drop_all()
- def test_create_partial_index(self):
- tbl = Table('testtbl', MetaData(), Column('data',Integer))
- idx = Index('test_idx1', tbl.c.data, postgres_where=and_(tbl.c.data > 5, tbl.c.data < 10))
-
- executed_sql = []
- mock_strategy = MockEngineStrategy()
- mock_conn = mock_strategy.create('postgres://', executed_sql.append)
+ def test_set_isolation_level(self):
+ """Test setting the isolation level with create_engine"""
+ eng = create_engine(testing.db.url)
+ eq_(
+ eng.execute("show transaction isolation level").scalar(),
+ 'read committed')
+ eng = create_engine(testing.db.url, isolation_level="SERIALIZABLE")
+ eq_(
+ eng.execute("show transaction isolation level").scalar(),
+ 'serializable')
+ eng = create_engine(testing.db.url, isolation_level="FOO")
- idx.create(mock_conn)
+ if testing.db.driver == 'zxjdbc':
+ exception_cls = eng.dialect.dbapi.Error
+ else:
+ exception_cls = eng.dialect.dbapi.ProgrammingError
+ assert_raises(exception_cls, eng.execute, "show transaction isolation level")
- assert executed_sql == ['CREATE INDEX test_idx1 ON testtbl (data) WHERE testtbl.data > 5 AND testtbl.data < 10']
class TimezoneTest(TestBase, AssertsExecutionResults):
"""Test timezone-aware datetimes.
- psycopg will return a datetime with a tzinfo attached to it, if postgres
+ psycopg will return a datetime with a tzinfo attached to it, if postgresql
returns it. python then will not let you compare a datetime with a tzinfo
to a datetime that doesnt have one. this test illustrates two ways to
have datetime types with and without timezone info.
"""
- __only_on__ = 'postgres'
+ __only_on__ = 'postgresql'
@classmethod
def setup_class(cls):
global tztable, notztable, metadata
metadata = MetaData(testing.db)
- # current_timestamp() in postgres is assumed to return TIMESTAMP WITH TIMEZONE
+ # current_timestamp() in postgresql is assumed to return TIMESTAMP WITH TIMEZONE
tztable = Table('tztable', metadata,
Column("id", Integer, primary_key=True),
Column("date", DateTime(timezone=True), onupdate=func.current_timestamp()),
somedate = testing.db.connect().scalar(func.current_timestamp().select())
tztable.insert().execute(id=1, name='row1', date=somedate)
c = tztable.update(tztable.c.id==1).execute(name='newname')
- print tztable.select(tztable.c.id==1).execute().fetchone()
+ print tztable.select(tztable.c.id==1).execute().first()
def test_without_timezone(self):
# get a date without a tzinfo
somedate = datetime.datetime(2005, 10,20, 11, 52, 00)
notztable.insert().execute(id=1, name='row1', date=somedate)
c = notztable.update(notztable.c.id==1).execute(name='newname')
- print notztable.select(tztable.c.id==1).execute().fetchone()
+ print notztable.select(tztable.c.id==1).execute().first()
class ArrayTest(TestBase, AssertsExecutionResults):
- __only_on__ = 'postgres'
+ __only_on__ = 'postgresql'
@classmethod
def setup_class(cls):
arrtable = Table('arrtable', metadata,
Column('id', Integer, primary_key=True),
- Column('intarr', postgres.PGArray(Integer)),
- Column('strarr', postgres.PGArray(String(convert_unicode=True)), nullable=False)
+ Column('intarr', postgresql.PGArray(Integer)),
+ Column('strarr', postgresql.PGArray(String(convert_unicode=True)), nullable=False)
)
metadata.create_all()
+
+ def teardown(self):
+ arrtable.delete().execute()
+
@classmethod
def teardown_class(cls):
metadata.drop_all()
def test_reflect_array_column(self):
metadata2 = MetaData(testing.db)
tbl = Table('arrtable', metadata2, autoload=True)
- assert isinstance(tbl.c.intarr.type, postgres.PGArray)
- assert isinstance(tbl.c.strarr.type, postgres.PGArray)
+ assert isinstance(tbl.c.intarr.type, postgresql.PGArray)
+ assert isinstance(tbl.c.strarr.type, postgresql.PGArray)
assert isinstance(tbl.c.intarr.type.item_type, Integer)
assert isinstance(tbl.c.strarr.type.item_type, String)
+ @testing.fails_on('postgresql+zxjdbc', 'zxjdbc has no support for PG arrays')
def test_insert_array(self):
arrtable.insert().execute(intarr=[1,2,3], strarr=['abc', 'def'])
results = arrtable.select().execute().fetchall()
eq_(len(results), 1)
eq_(results[0]['intarr'], [1,2,3])
eq_(results[0]['strarr'], ['abc','def'])
- arrtable.delete().execute()
+ @testing.fails_on('postgresql+pg8000', 'pg8000 has poor support for PG arrays')
+ @testing.fails_on('postgresql+zxjdbc', 'zxjdbc has no support for PG arrays')
def test_array_where(self):
arrtable.insert().execute(intarr=[1,2,3], strarr=['abc', 'def'])
arrtable.insert().execute(intarr=[4,5,6], strarr='ABC')
results = arrtable.select().where(arrtable.c.intarr == [1,2,3]).execute().fetchall()
eq_(len(results), 1)
eq_(results[0]['intarr'], [1,2,3])
- arrtable.delete().execute()
+ @testing.fails_on('postgresql+pg8000', 'pg8000 has poor support for PG arrays')
+ @testing.fails_on('postgresql+zxjdbc', 'zxjdbc has no support for PG arrays')
def test_array_concat(self):
arrtable.insert().execute(intarr=[1,2,3], strarr=['abc', 'def'])
results = select([arrtable.c.intarr + [4,5,6]]).execute().fetchall()
eq_(len(results), 1)
eq_(results[0][0], [1,2,3,4,5,6])
- arrtable.delete().execute()
+ @testing.fails_on('postgresql+pg8000', 'pg8000 has poor support for PG arrays')
+ @testing.fails_on('postgresql+zxjdbc', 'zxjdbc has no support for PG arrays')
def test_array_subtype_resultprocessor(self):
arrtable.insert().execute(intarr=[4,5,6], strarr=[[u'm\xe4\xe4'], [u'm\xf6\xf6']])
arrtable.insert().execute(intarr=[1,2,3], strarr=[u'm\xe4\xe4', u'm\xf6\xf6'])
eq_(len(results), 2)
eq_(results[0]['strarr'], [u'm\xe4\xe4', u'm\xf6\xf6'])
eq_(results[1]['strarr'], [[u'm\xe4\xe4'], [u'm\xf6\xf6']])
- arrtable.delete().execute()
+ @testing.fails_on('postgresql+pg8000', 'pg8000 has poor support for PG arrays')
+ @testing.fails_on('postgresql+zxjdbc', 'zxjdbc has no support for PG arrays')
def test_array_mutability(self):
class Foo(object): pass
footable = Table('foo', metadata,
Column('id', Integer, primary_key=True),
- Column('intarr', postgres.PGArray(Integer), nullable=True)
+ Column('intarr', postgresql.PGArray(Integer), nullable=True)
)
mapper(Foo, footable)
metadata.create_all()
sess.add(foo)
sess.flush()
-class TimeStampTest(TestBase, AssertsExecutionResults):
- __only_on__ = 'postgres'
-
- @testing.uses_deprecated()
+class TimestampTest(TestBase, AssertsExecutionResults):
+ __only_on__ = 'postgresql'
+
def test_timestamp(self):
engine = testing.db
connection = engine.connect()
- s = select([func.TIMESTAMP("12/25/07").label("ts")])
- result = connection.execute(s).fetchone()
+
+ s = select(["timestamp '2007-12-25'"])
+ result = connection.execute(s).first()
eq_(result[0], datetime.datetime(2007, 12, 25, 0, 0))
class ServerSideCursorsTest(TestBase, AssertsExecutionResults):
- __only_on__ = 'postgres'
+ __only_on__ = 'postgresql+psycopg2'
@classmethod
def setup_class(cls):
class SpecialTypesTest(TestBase, ComparesTables):
"""test DDL and reflection of PG-specific types """
- __only_on__ = 'postgres'
- __excluded_on__ = (('postgres', '<', (8, 3, 0)),)
+ __only_on__ = 'postgresql'
+ __excluded_on__ = (('postgresql', '<', (8, 3, 0)),)
@classmethod
def setup_class(cls):
metadata = MetaData(testing.db)
table = Table('sometable', metadata,
- Column('id', postgres.PGUuid, primary_key=True),
- Column('flag', postgres.PGBit),
- Column('addr', postgres.PGInet),
- Column('addr2', postgres.PGMacAddr),
- Column('addr3', postgres.PGCidr)
+ Column('id', postgresql.PGUuid, primary_key=True),
+ Column('flag', postgresql.PGBit),
+ Column('addr', postgresql.PGInet),
+ Column('addr2', postgresql.PGMacAddr),
+ Column('addr3', postgresql.PGCidr)
)
metadata.create_all()
class MatchTest(TestBase, AssertsCompiledSQL):
- __only_on__ = 'postgres'
- __excluded_on__ = (('postgres', '<', (8, 3, 0)),)
+ __only_on__ = 'postgresql'
+ __excluded_on__ = (('postgresql', '<', (8, 3, 0)),)
@classmethod
def setup_class(cls):
def teardown_class(cls):
metadata.drop_all()
- def test_expression(self):
+ @testing.fails_on('postgresql+pg8000', 'uses positional')
+ @testing.fails_on('postgresql+zxjdbc', 'uses qmark')
+ def test_expression_pyformat(self):
self.assert_compile(matchtable.c.title.match('somstr'), "matchtable.title @@ to_tsquery(%(title_1)s)")
+ @testing.fails_on('postgresql+psycopg2', 'uses pyformat')
+ @testing.fails_on('postgresql+zxjdbc', 'uses qmark')
+ def test_expression_positional(self):
+ self.assert_compile(matchtable.c.title.match('somstr'), "matchtable.title @@ to_tsquery(%s)")
+
def test_simple_match(self):
results = matchtable.select().where(matchtable.c.title.match('python')).order_by(matchtable.c.id).execute().fetchall()
eq_([2, 5], [r.id for r in results])
import datetime
from sqlalchemy import *
from sqlalchemy import exc, sql
-from sqlalchemy.databases import sqlite
+from sqlalchemy.dialects.sqlite import base as sqlite, pysqlite as pysqlite_dialect
from sqlalchemy.test import *
meta = MetaData(testing.db)
t = Table('bool_table', meta,
Column('id', Integer, primary_key=True),
- Column('boo', sqlite.SLBoolean))
+ Column('boo', Boolean))
try:
meta.create_all()
def test_time_microseconds(self):
dt = datetime.datetime(2008, 6, 27, 12, 0, 0, 125) # 125 usec
eq_(str(dt), '2008-06-27 12:00:00.000125')
- sldt = sqlite.SLDateTime()
+ sldt = sqlite._SLDateTime()
bp = sldt.bind_processor(None)
eq_(bp(dt), '2008-06-27 12:00:00.000125')
bindproc = t.dialect_impl(dialect).bind_processor(dialect)
assert not bindproc or isinstance(bindproc(u"some string"), unicode)
- @testing.uses_deprecated('Using String type with no length')
def test_type_reflection(self):
# (ask_for, roundtripped_as_if_different)
- specs = [( String(), sqlite.SLString(), ),
- ( String(1), sqlite.SLString(1), ),
- ( String(3), sqlite.SLString(3), ),
- ( Text(), sqlite.SLText(), ),
- ( Unicode(), sqlite.SLString(), ),
- ( Unicode(1), sqlite.SLString(1), ),
- ( Unicode(3), sqlite.SLString(3), ),
- ( UnicodeText(), sqlite.SLText(), ),
- ( CLOB, sqlite.SLText(), ),
- ( sqlite.SLChar(1), ),
- ( CHAR(3), sqlite.SLChar(3), ),
- ( NCHAR(2), sqlite.SLChar(2), ),
- ( SmallInteger(), sqlite.SLSmallInteger(), ),
- ( sqlite.SLSmallInteger(), ),
- ( Binary(3), sqlite.SLBinary(), ),
- ( Binary(), sqlite.SLBinary() ),
- ( sqlite.SLBinary(3), sqlite.SLBinary(), ),
- ( NUMERIC, sqlite.SLNumeric(), ),
- ( NUMERIC(10,2), sqlite.SLNumeric(10,2), ),
- ( Numeric, sqlite.SLNumeric(), ),
- ( Numeric(10, 2), sqlite.SLNumeric(10, 2), ),
- ( DECIMAL, sqlite.SLNumeric(), ),
- ( DECIMAL(10, 2), sqlite.SLNumeric(10, 2), ),
- ( Float, sqlite.SLFloat(), ),
- ( sqlite.SLNumeric(), ),
- ( INT, sqlite.SLInteger(), ),
- ( Integer, sqlite.SLInteger(), ),
- ( sqlite.SLInteger(), ),
- ( TIMESTAMP, sqlite.SLDateTime(), ),
- ( DATETIME, sqlite.SLDateTime(), ),
- ( DateTime, sqlite.SLDateTime(), ),
- ( sqlite.SLDateTime(), ),
- ( DATE, sqlite.SLDate(), ),
- ( Date, sqlite.SLDate(), ),
- ( sqlite.SLDate(), ),
- ( TIME, sqlite.SLTime(), ),
- ( Time, sqlite.SLTime(), ),
- ( sqlite.SLTime(), ),
- ( BOOLEAN, sqlite.SLBoolean(), ),
- ( Boolean, sqlite.SLBoolean(), ),
- ( sqlite.SLBoolean(), ),
+ specs = [( String(), String(), ),
+ ( String(1), String(1), ),
+ ( String(3), String(3), ),
+ ( Text(), Text(), ),
+ ( Unicode(), String(), ),
+ ( Unicode(1), String(1), ),
+ ( Unicode(3), String(3), ),
+ ( UnicodeText(), Text(), ),
+ ( CHAR(1), ),
+ ( CHAR(3), CHAR(3), ),
+ ( NUMERIC, NUMERIC(), ),
+ ( NUMERIC(10,2), NUMERIC(10,2), ),
+ ( Numeric, NUMERIC(), ),
+ ( Numeric(10, 2), NUMERIC(10, 2), ),
+ ( DECIMAL, DECIMAL(), ),
+ ( DECIMAL(10, 2), DECIMAL(10, 2), ),
+ ( Float, Float(), ),
+ ( NUMERIC(), ),
+ ( TIMESTAMP, TIMESTAMP(), ),
+ ( DATETIME, DATETIME(), ),
+ ( DateTime, DateTime(), ),
+ ( DateTime(), ),
+ ( DATE, DATE(), ),
+ ( Date, Date(), ),
+ ( TIME, TIME(), ),
+ ( Time, Time(), ),
+ ( BOOLEAN, BOOLEAN(), ),
+ ( Boolean, Boolean(), ),
]
columns = [Column('c%i' % (i + 1), t[0]) for i, t in enumerate(specs)]
db = testing.db
m = MetaData(db)
t_table = Table('types', m, *columns)
+ m.create_all()
try:
- m.create_all()
-
m2 = MetaData(db)
rt = Table('types', m2, autoload=True)
try:
expected = [len(c) > 1 and c[1] or c[0] for c in specs]
for table in rt, rv:
for i, reflected in enumerate(table.c):
- assert isinstance(reflected.type, type(expected[i])), type(expected[i])
+ assert isinstance(reflected.type, type(expected[i])), "%d: %r" % (i, type(expected[i]))
finally:
db.execute('DROP VIEW types_v')
finally:
rt = Table('t_defaults', m2, autoload=True)
expected = [c[1] for c in specs]
for i, reflected in enumerate(rt.c):
- eq_(reflected.server_default.arg.text, expected[i])
+ eq_(str(reflected.server_default.arg), expected[i])
finally:
m.drop_all()
db = testing.db
m = MetaData(db)
- expected = ["'my_default'", '0']
+ expected = ["my_default", '0']
table = """CREATE TABLE r_defaults (
data VARCHAR(40) DEFAULT 'my_default',
val INTEGER NOT NULL DEFAULT 0
rt = Table('r_defaults', m, autoload=True)
for i, reflected in enumerate(rt.c):
- eq_(reflected.server_default.arg.text, expected[i])
+ eq_(str(reflected.server_default.arg), expected[i])
finally:
db.execute("DROP TABLE r_defaults")
def test_attached_as_schema(self):
cx = testing.db.connect()
try:
- cx.execute('ATTACH DATABASE ":memory:" AS alt_schema')
+ cx.execute('ATTACH DATABASE ":memory:" AS test_schema')
dialect = cx.dialect
- assert dialect.table_names(cx, 'alt_schema') == []
+ assert dialect.table_names(cx, 'test_schema') == []
meta = MetaData(cx)
Table('created', meta, Column('id', Integer),
- schema='alt_schema')
+ schema='test_schema')
alt_master = Table('sqlite_master', meta, autoload=True,
- schema='alt_schema')
+ schema='test_schema')
meta.create_all(cx)
- eq_(dialect.table_names(cx, 'alt_schema'),
+ eq_(dialect.table_names(cx, 'test_schema'),
['created'])
assert len(alt_master.c) > 0
meta.clear()
reflected = Table('created', meta, autoload=True,
- schema='alt_schema')
+ schema='test_schema')
assert len(reflected.c) == 1
cx.execute(reflected.insert(), dict(id=1))
# note that sqlite_master is cleared, above
meta.drop_all()
- assert dialect.table_names(cx, 'alt_schema') == []
+ assert dialect.table_names(cx, 'test_schema') == []
finally:
- cx.execute('DETACH DATABASE alt_schema')
+ cx.execute('DETACH DATABASE test_schema')
@testing.exclude('sqlite', '<', (2, 6), 'no database support')
def test_temp_table_reflection(self):
pass
raise
+ def test_set_isolation_level(self):
+ """Test setting the read uncommitted/serializable levels"""
+ eng = create_engine(testing.db.url)
+ eq_(eng.execute("PRAGMA read_uncommitted").scalar(), 0)
+
+ eng = create_engine(testing.db.url, isolation_level="READ UNCOMMITTED")
+ eq_(eng.execute("PRAGMA read_uncommitted").scalar(), 1)
+
+ eng = create_engine(testing.db.url, isolation_level="SERIALIZABLE")
+ eq_(eng.execute("PRAGMA read_uncommitted").scalar(), 0)
+
+ assert_raises(exc.ArgumentError, create_engine, testing.db.url,
+ isolation_level="FOO")
+
class SQLTest(TestBase, AssertsCompiledSQL):
"""Tests SQLite-dialect specific compilation."""
table = Table('test_table', metadata,
Column('foo', Integer))
- metadata.connect(bind)
+ metadata.bind = bind
assert metadata.bind is table.bind is bind
metadata.create_all()
try:
e = elem(bind=bind)
assert e.bind is bind
- e.execute()
+ e.execute().close()
finally:
if isinstance(bind, engine.Connection):
bind.close()
-from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
-from sqlalchemy.schema import DDL
+from sqlalchemy.test.testing import assert_raises, assert_raises_message
+from sqlalchemy.schema import DDL, CheckConstraint, AddConstraint, DropConstraint
from sqlalchemy import create_engine
from sqlalchemy import MetaData, Integer, String
from sqlalchemy.test.schema import Table
from sqlalchemy.test.schema import Column
import sqlalchemy as tsa
from sqlalchemy.test import TestBase, testing, engines
-
+from sqlalchemy.test.testing import AssertsCompiledSQL
+from nose import SkipTest
class DDLEventTest(TestBase):
class Canary(object):
self.schema_item = schema_item
self.bind = bind
- def before_create(self, action, schema_item, bind):
+ def before_create(self, action, schema_item, bind, **kw):
assert self.state is None
assert schema_item is self.schema_item
assert bind is self.bind
self.state = action
- def after_create(self, action, schema_item, bind):
+ def after_create(self, action, schema_item, bind, **kw):
assert self.state in ('before-create', 'skipped')
assert schema_item is self.schema_item
assert bind is self.bind
self.state = action
- def before_drop(self, action, schema_item, bind):
+ def before_drop(self, action, schema_item, bind, **kw):
assert self.state is None
assert schema_item is self.schema_item
assert bind is self.bind
self.state = action
- def after_drop(self, action, schema_item, bind):
+ def after_drop(self, action, schema_item, bind, **kw):
assert self.state in ('before-drop', 'skipped')
assert schema_item is self.schema_item
assert bind is self.bind
assert 'klptzyxm' not in strings
assert 'xyzzy' in strings
assert 'fnord' in strings
-
+
+ def test_conditional_constraint(self):
+ metadata, users, engine = self.metadata, self.users, self.engine
+ nonpg_mock = engines.mock_engine(dialect_name='sqlite')
+ pg_mock = engines.mock_engine(dialect_name='postgresql')
+
+ constraint = CheckConstraint('a < b',name="my_test_constraint", table=users)
+
+ # by placing the constraint in an Add/Drop construct,
+ # the 'inline_ddl' flag is set to False
+ AddConstraint(constraint, on='postgresql').execute_at("after-create", users)
+ DropConstraint(constraint, on='postgresql').execute_at("before-drop", users)
+
+ metadata.create_all(bind=nonpg_mock)
+ strings = " ".join(str(x) for x in nonpg_mock.mock)
+ assert "my_test_constraint" not in strings
+ metadata.drop_all(bind=nonpg_mock)
+ strings = " ".join(str(x) for x in nonpg_mock.mock)
+ assert "my_test_constraint" not in strings
+
+ metadata.create_all(bind=pg_mock)
+ strings = " ".join(str(x) for x in pg_mock.mock)
+ assert "my_test_constraint" in strings
+ metadata.drop_all(bind=pg_mock)
+ strings = " ".join(str(x) for x in pg_mock.mock)
+ assert "my_test_constraint" in strings
+
def test_metadata(self):
metadata, engine = self.metadata, self.engine
DDL('mxyzptlk').execute_at('before-create', metadata)
assert 'fnord' in strings
def test_ddl_execute(self):
- engine = create_engine('sqlite:///')
+ try:
+ engine = create_engine('sqlite:///')
+ except ImportError:
+ raise SkipTest('Requires sqlite')
cx = engine.connect()
table = self.users
ddl = DDL('SELECT 1')
r = eval(py)
assert list(r) == [(1,)], py
-class DDLTest(TestBase):
+class DDLTest(TestBase, AssertsCompiledSQL):
def mock_engine(self):
executor = lambda *a, **kw: None
engine = create_engine(testing.db.name + '://',
def test_tokens(self):
m = MetaData()
- bind = self.mock_engine()
sane_alone = Table('t', m, Column('id', Integer))
sane_schema = Table('t', m, Column('id', Integer), schema='s')
insane_alone = Table('t t', m, Column('id', Integer))
ddl = DDL('%(schema)s-%(table)s-%(fullname)s')
- eq_(ddl._expand(sane_alone, bind), '-t-t')
- eq_(ddl._expand(sane_schema, bind), 's-t-s.t')
- eq_(ddl._expand(insane_alone, bind), '-"t t"-"t t"')
- eq_(ddl._expand(insane_schema, bind),
- '"s s"-"t t"-"s s"."t t"')
+ dialect = self.mock_engine().dialect
+ self.assert_compile(ddl.against(sane_alone), '-t-t', dialect=dialect)
+ self.assert_compile(ddl.against(sane_schema), 's-t-s.t', dialect=dialect)
+ self.assert_compile(ddl.against(insane_alone), '-"t t"-"t t"', dialect=dialect)
+ self.assert_compile(ddl.against(insane_schema), '"s s"-"t t"-"s s"."t t"', dialect=dialect)
# overrides are used piece-meal and verbatim.
ddl = DDL('%(schema)s-%(table)s-%(fullname)s-%(bonus)s',
context={'schema':'S S', 'table': 'T T', 'bonus': 'b'})
- eq_(ddl._expand(sane_alone, bind), 'S S-T T-t-b')
- eq_(ddl._expand(sane_schema, bind), 'S S-T T-s.t-b')
- eq_(ddl._expand(insane_alone, bind), 'S S-T T-"t t"-b')
- eq_(ddl._expand(insane_schema, bind),
- 'S S-T T-"s s"."t t"-b')
+
+ self.assert_compile(ddl.against(sane_alone), 'S S-T T-t-b', dialect=dialect)
+ self.assert_compile(ddl.against(sane_schema), 'S S-T T-s.t-b', dialect=dialect)
+ self.assert_compile(ddl.against(insane_alone), 'S S-T T-"t t"-b', dialect=dialect)
+ self.assert_compile(ddl.against(insane_schema), 'S S-T T-"s s"."t t"-b', dialect=dialect)
+
def test_filter(self):
cx = self.mock_engine()
global users, metadata
metadata = MetaData(testing.db)
users = Table('users', metadata,
- Column('user_id', INT, primary_key = True),
+ Column('user_id', INT, primary_key = True, test_needs_autoincrement=True),
Column('user_name', VARCHAR(20)),
)
metadata.create_all()
+ @engines.close_first
def teardown(self):
testing.db.connect().execute(users.delete())
+
@classmethod
def teardown_class(cls):
metadata.drop_all()
- @testing.fails_on_everything_except('firebird', 'maxdb', 'sqlite')
+ @testing.fails_on_everything_except('firebird', 'maxdb', 'sqlite', 'mysql+pyodbc', '+zxjdbc')
def test_raw_qmark(self):
for conn in (testing.db, testing.db.connect()):
conn.execute("insert into users (user_id, user_name) values (?, ?)", (1,"jack"))
assert res.fetchall() == [(1, "jack"), (2, "fred"), (3, "ed"), (4, "horse"), (5, "barney"), (6, "donkey"), (7, 'sally')]
conn.execute("delete from users")
- @testing.fails_on_everything_except('mysql', 'postgres')
+ @testing.fails_on_everything_except('mysql+mysqldb', 'postgresql')
+ @testing.fails_on('postgresql+zxjdbc', 'sprintf not supported')
# some psycopg2 versions bomb this.
def test_raw_sprintf(self):
for conn in (testing.db, testing.db.connect()):
# pyformat is supported for mysql, but skipping because a few driver
# versions have a bug that bombs out on this test. (1.2.2b3, 1.2.2c1, 1.2.2)
- @testing.skip_if(lambda: testing.against('mysql'), 'db-api flaky')
- @testing.fails_on_everything_except('postgres')
+ @testing.skip_if(lambda: testing.against('mysql+mysqldb'), 'db-api flaky')
+ @testing.fails_on_everything_except('postgresql+psycopg2')
def test_raw_python(self):
for conn in (testing.db, testing.db.connect()):
conn.execute("insert into users (user_id, user_name) values (%(id)s, %(name)s)", {'id':1, 'name':'jack'})
assert res.fetchall() == [(1, "jack"), (2, "ed"), (3, "horse"), (4, 'sally')]
conn.execute("delete from users")
- @testing.fails_on_everything_except('sqlite', 'oracle')
+ @testing.fails_on_everything_except('sqlite', 'oracle+cx_oracle')
def test_raw_named(self):
for conn in (testing.db, testing.db.connect()):
conn.execute("insert into users (user_id, user_name) values (:id, :name)", {'id':1, 'name':'jack'})
except tsa.exc.DBAPIError:
assert True
- @testing.fails_on('mssql', 'rowcount returns -1')
def test_empty_insert(self):
"""test that execute() interprets [] as a list with no params"""
result = testing.db.execute(users.insert().values(user_name=bindparam('name')), [])
- eq_(result.rowcount, 1)
+ eq_(testing.db.execute(users.select()).fetchall(), [
+ (1, None)
+ ])
class ProxyConnectionTest(TestBase):
@testing.fails_on('firebird', 'Data type unknown')
return execute(clauseelement, *multiparams, **params)
def cursor_execute(self, execute, cursor, statement, parameters, context, executemany):
+ print "CE", statement, parameters
cursor_stmts.append(
(statement, parameters, None)
)
break
for engine in (
- engines.testing_engine(options=dict(proxy=MyProxy())),
- engines.testing_engine(options=dict(proxy=MyProxy(), strategy='threadlocal'))
+ engines.testing_engine(options=dict(implicit_returning=False, proxy=MyProxy())),
+ engines.testing_engine(options=dict(implicit_returning=False, proxy=MyProxy(), strategy='threadlocal'))
):
m = MetaData(engine)
t1.insert().execute(c1=6)
assert engine.execute("select * from t1").fetchall() == [(5, 'some data'), (6, 'foo')]
finally:
+ pass
m.drop_all()
engine.dispose()
("DROP TABLE t1", {}, None)
]
- if engine.dialect.preexecute_pk_sequences:
+ if True: # or engine.dialect.preexecute_pk_sequences:
cursor = [
- ("CREATE TABLE t1", {}, None),
+ ("CREATE TABLE t1", {}, ()),
("INSERT INTO t1 (c1, c2)", {'c2': 'some data', 'c1': 5}, [5, 'some data']),
("SELECT lower", {'lower_2':'Foo'}, ['Foo']),
("INSERT INTO t1 (c1, c2)", {'c2': 'foo', 'c1': 6}, [6, 'foo']),
- ("select * from t1", {}, None),
- ("DROP TABLE t1", {}, None)
+ ("select * from t1", {}, ()),
+ ("DROP TABLE t1", {}, ())
]
else:
cursor = [
from sqlalchemy.test.testing import assert_raises, assert_raises_message
import pickle
-from sqlalchemy import MetaData
-from sqlalchemy import Integer, String, UniqueConstraint, CheckConstraint, ForeignKey
+from sqlalchemy import Integer, String, UniqueConstraint, CheckConstraint, ForeignKey, MetaData
from sqlalchemy.test.schema import Table
from sqlalchemy.test.schema import Column
+from sqlalchemy import schema
import sqlalchemy as tsa
-from sqlalchemy.test import TestBase, ComparesTables, testing, engines
+from sqlalchemy.test import TestBase, ComparesTables, AssertsCompiledSQL, testing, engines
from sqlalchemy.test.testing import eq_
class MetaDataTest(TestBase, ComparesTables):
meta.create_all(testing.db)
try:
- for test, has_constraints in ((test_to_metadata, True), (test_pickle, True), (test_pickle_via_reflect, False)):
+ for test, has_constraints in ((test_to_metadata, True), (test_pickle, True),(test_pickle_via_reflect, False)):
table_c, table2_c = test()
self.assert_tables_equal(table, table_c)
self.assert_tables_equal(table2, table2_c)
MetaData(testing.db), autoload=True)
-class TableOptionsTest(TestBase):
- def setup(self):
- self.engine = engines.mock_engine()
- self.metadata = MetaData(self.engine)
-
+class TableOptionsTest(TestBase, AssertsCompiledSQL):
def test_prefixes(self):
- table1 = Table("temporary_table_1", self.metadata,
+ table1 = Table("temporary_table_1", MetaData(),
Column("col1", Integer),
prefixes = ["TEMPORARY"])
- table1.create()
- assert [str(x) for x in self.engine.mock if 'CREATE TEMPORARY TABLE' in str(x)]
- del self.engine.mock[:]
- table2 = Table("temporary_table_2", self.metadata,
+
+ self.assert_compile(
+ schema.CreateTable(table1),
+ "CREATE TEMPORARY TABLE temporary_table_1 (col1 INTEGER)"
+ )
+
+ table2 = Table("temporary_table_2", MetaData(),
Column("col1", Integer),
prefixes = ["VIRTUAL"])
- table2.create()
- assert [str(x) for x in self.engine.mock if 'CREATE VIRTUAL TABLE' in str(x)]
+ self.assert_compile(
+ schema.CreateTable(table2),
+ "CREATE VIRTUAL TABLE temporary_table_2 (col1 INTEGER)"
+ )
def test_table_info(self):
-
- t1 = Table('foo', self.metadata, info={'x':'y'})
- t2 = Table('bar', self.metadata, info={})
- t3 = Table('bat', self.metadata)
+ metadata = MetaData()
+ t1 = Table('foo', metadata, info={'x':'y'})
+ t2 = Table('bar', metadata, info={})
+ t3 = Table('bat', metadata)
assert t1.info == {'x':'y'}
assert t2.info == {}
assert t3.info == {}
-import ConfigParser, StringIO
+from sqlalchemy.test.testing import assert_raises, assert_raises_message
+import ConfigParser
+import StringIO
import sqlalchemy.engine.url as url
from sqlalchemy import create_engine, engine_from_config
import sqlalchemy as tsa
'dbtype://username:apples%2Foranges@hostspec/mydatabase',
):
u = url.make_url(text)
- print u, text
- print "username=", u.username, "password=", u.password, "database=", u.database, "host=", u.host
assert u.drivername == 'dbtype'
assert u.username == 'username' or u.username is None
assert u.password == 'password' or u.password == 'apples/oranges' or u.password is None
def test_connect_query(self):
dbapi = MockDBAPI(foober='12', lala='18', fooz='somevalue')
- # start the postgres dialect, but put our mock DBAPI as the module instead of psycopg
- e = create_engine('postgres://scott:tiger@somehost/test?foober=12&lala=18&fooz=somevalue', module=dbapi)
+ e = create_engine(
+ 'postgresql://scott:tiger@somehost/test?foober=12&lala=18&fooz=somevalue',
+ module=dbapi,
+ _initialize=False
+ )
c = e.connect()
def test_kwargs(self):
dbapi = MockDBAPI(foober=12, lala=18, hoho={'this':'dict'}, fooz='somevalue')
- # start the postgres dialect, but put our mock DBAPI as the module instead of psycopg
- e = create_engine('postgres://scott:tiger@somehost/test?fooz=somevalue', connect_args={'foober':12, 'lala':18, 'hoho':{'this':'dict'}}, module=dbapi)
+ e = create_engine(
+ 'postgresql://scott:tiger@somehost/test?fooz=somevalue',
+ connect_args={'foober':12, 'lala':18, 'hoho':{'this':'dict'}},
+ module=dbapi,
+ _initialize=False
+ )
c = e.connect()
def test_coerce_config(self):
raw = r"""
[prefixed]
-sqlalchemy.url=postgres://scott:tiger@somehost/test?fooz=somevalue
+sqlalchemy.url=postgresql://scott:tiger@somehost/test?fooz=somevalue
sqlalchemy.convert_unicode=0
sqlalchemy.echo=false
sqlalchemy.echo_pool=1
sqlalchemy.pool_threadlocal=1
sqlalchemy.pool_timeout=10
[plain]
-url=postgres://scott:tiger@somehost/test?fooz=somevalue
+url=postgresql://scott:tiger@somehost/test?fooz=somevalue
convert_unicode=0
echo=0
echo_pool=1
ini.readfp(StringIO.StringIO(raw))
expected = {
- 'url': 'postgres://scott:tiger@somehost/test?fooz=somevalue',
+ 'url': 'postgresql://scott:tiger@somehost/test?fooz=somevalue',
'convert_unicode': 0,
'echo': False,
'echo_pool': True,
self.assert_(tsa.engine._coerce_config(plain, '') == expected)
def test_engine_from_config(self):
- dbapi = MockDBAPI()
+ dbapi = mock_dbapi
config = {
- 'sqlalchemy.url':'postgres://scott:tiger@somehost/test?fooz=somevalue',
+ 'sqlalchemy.url':'postgresql://scott:tiger@somehost/test?fooz=somevalue',
'sqlalchemy.pool_recycle':'50',
'sqlalchemy.echo':'true'
}
e = engine_from_config(config, module=dbapi)
assert e.pool._recycle == 50
- assert e.url == url.make_url('postgres://scott:tiger@somehost/test?fooz=somevalue')
+ assert e.url == url.make_url('postgresql://scott:tiger@somehost/test?fooz=somevalue')
assert e.echo is True
def test_custom(self):
def connect():
return dbapi.connect(foober=12, lala=18, fooz='somevalue', hoho={'this':'dict'})
- # start the postgres dialect, but put our mock DBAPI as the module instead of psycopg
- e = create_engine('postgres://', creator=connect, module=dbapi)
+ # start the postgresql dialect, but put our mock DBAPI as the module instead of psycopg
+ e = create_engine('postgresql://', creator=connect, module=dbapi, _initialize=False)
c = e.connect()
def test_recycle(self):
dbapi = MockDBAPI(foober=12, lala=18, hoho={'this':'dict'}, fooz='somevalue')
- e = create_engine('postgres://', pool_recycle=472, module=dbapi)
+ e = create_engine('postgresql://', pool_recycle=472, module=dbapi, _initialize=False)
assert e.pool._recycle == 472
def test_badargs(self):
- # good arg, use MockDBAPI to prevent oracle import errors
- e = create_engine('oracle://', use_ansi=True, module=MockDBAPI())
-
- try:
- e = create_engine("foobar://", module=MockDBAPI())
- assert False
- except ImportError:
- assert True
+ assert_raises(ImportError, create_engine, "foobar://", module=mock_dbapi)
# bad arg
- try:
- e = create_engine('postgres://', use_ansi=True, module=MockDBAPI())
- assert False
- except TypeError:
- assert True
+ assert_raises(TypeError, create_engine, 'postgresql://', use_ansi=True, module=mock_dbapi)
# bad arg
- try:
- e = create_engine('oracle://', lala=5, use_ansi=True, module=MockDBAPI())
- assert False
- except TypeError:
- assert True
+ assert_raises(TypeError, create_engine, 'oracle://', lala=5, use_ansi=True, module=mock_dbapi)
- try:
- e = create_engine('postgres://', lala=5, module=MockDBAPI())
- assert False
- except TypeError:
- assert True
+ assert_raises(TypeError, create_engine, 'postgresql://', lala=5, module=mock_dbapi)
- try:
- e = create_engine('sqlite://', lala=5)
- assert False
- except TypeError:
- assert True
+ assert_raises(TypeError, create_engine,'sqlite://', lala=5, module=mock_sqlite_dbapi)
- try:
- e = create_engine('mysql://', use_unicode=True, module=MockDBAPI())
- assert False
- except TypeError:
- assert True
-
- try:
- # sqlite uses SingletonThreadPool which doesnt have max_overflow
- e = create_engine('sqlite://', max_overflow=5)
- assert False
- except TypeError:
- assert True
+ assert_raises(TypeError, create_engine, 'mysql+mysqldb://', use_unicode=True, module=mock_dbapi)
- e = create_engine('mysql://', module=MockDBAPI(), connect_args={'use_unicode':True}, convert_unicode=True)
+ # sqlite uses SingletonThreadPool which doesnt have max_overflow
+ assert_raises(TypeError, create_engine, 'sqlite://', max_overflow=5,
+ module=mock_sqlite_dbapi)
- e = create_engine('sqlite://', connect_args={'use_unicode':True}, convert_unicode=True)
try:
- c = e.connect()
- assert False
- except tsa.exc.DBAPIError:
- assert True
+ e = create_engine('sqlite://', connect_args={'use_unicode':True}, convert_unicode=True)
+ except ImportError:
+ # no sqlite
+ pass
+ else:
+ # raises DBAPIerror due to use_unicode not a sqlite arg
+ assert_raises(tsa.exc.DBAPIError, e.connect)
def test_urlattr(self):
"""test the url attribute on ``Engine``."""
- e = create_engine('mysql://scott:tiger@localhost/test', module=MockDBAPI())
+ e = create_engine('mysql://scott:tiger@localhost/test', module=mock_dbapi, _initialize=False)
u = url.make_url('mysql://scott:tiger@localhost/test')
- e2 = create_engine(u, module=MockDBAPI())
+ e2 = create_engine(u, module=mock_dbapi, _initialize=False)
assert e.url.drivername == e2.url.drivername == 'mysql'
assert e.url.username == e2.url.username == 'scott'
assert e2.url is u
def test_poolargs(self):
"""test that connection pool args make it thru"""
- e = create_engine('postgres://', creator=None, pool_recycle=50, echo_pool=None, module=MockDBAPI())
+ e = create_engine('postgresql://', creator=None, pool_recycle=50, echo_pool=None, module=mock_dbapi, _initialize=False)
assert e.pool._recycle == 50
# these args work for QueuePool
- e = create_engine('postgres://', max_overflow=8, pool_timeout=60, poolclass=tsa.pool.QueuePool, module=MockDBAPI())
+ e = create_engine('postgresql://', max_overflow=8, pool_timeout=60, poolclass=tsa.pool.QueuePool, module=mock_dbapi)
- try:
- # but not SingletonThreadPool
- e = create_engine('sqlite://', max_overflow=8, pool_timeout=60, poolclass=tsa.pool.SingletonThreadPool)
- assert False
- except TypeError:
- assert True
+ # but not SingletonThreadPool
+ assert_raises(TypeError, create_engine, 'sqlite://', max_overflow=8, pool_timeout=60,
+ poolclass=tsa.pool.SingletonThreadPool, module=mock_sqlite_dbapi)
class MockDBAPI(object):
def __init__(self, **kwargs):
self.kwargs = kwargs
self.paramstyle = 'named'
- def connect(self, **kwargs):
- print kwargs, self.kwargs
+ def connect(self, *args, **kwargs):
for k in self.kwargs:
assert k in kwargs, "key %s not present in dictionary" % k
assert kwargs[k]==self.kwargs[k], "value %s does not match %s" % (kwargs[k], self.kwargs[k])
return MockConnection()
class MockConnection(object):
+ def get_server_info(self):
+ return "5.0"
def close(self):
pass
def cursor(self):
def close(self):
pass
mock_dbapi = MockDBAPI()
-
+mock_sqlite_dbapi = msd = MockDBAPI()
+msd.version_info = msd.sqlite_version_info = (99, 9, 9)
+msd.sqlite_version = '99.9.9'
-import threading, time, gc
-from sqlalchemy import pool, interfaces
+import threading, time
+from sqlalchemy import pool, interfaces, create_engine, select
import sqlalchemy as tsa
-from sqlalchemy.test import TestBase
+from sqlalchemy.test import TestBase, testing
+from sqlalchemy.test.util import gc_collect, lazy_gc
mcid = 1
connection2 = manager.connect('foo.db')
connection3 = manager.connect('bar.db')
- print "connection " + repr(connection)
self.assert_(connection.cursor() is not None)
self.assert_(connection is connection2)
self.assert_(connection2 is not connection3)
connection = manager.connect('foo.db')
connection2 = manager.connect('foo.db')
- print "connection " + repr(connection)
-
self.assert_(connection.cursor() is not None)
self.assert_(connection is not connection2)
c2.close()
else:
c2 = None
-
+ lazy_gc()
+
if useclose:
c1 = p.connect()
c2 = p.connect()
# extra tests with QueuePool to ensure connections get __del__()ed when dereferenced
if isinstance(p, pool.QueuePool):
+ lazy_gc()
+
self.assert_(p.checkedout() == 0)
c1 = p.connect()
c2 = p.connect()
else:
c2 = None
c1 = None
+ lazy_gc()
self.assert_(p.checkedout() == 0)
def test_properties(self):
def __init__(self):
if hasattr(self, 'connect'):
self.connect = self.inst_connect
+ if hasattr(self, 'first_connect'):
+ self.first_connect = self.inst_first_connect
if hasattr(self, 'checkout'):
self.checkout = self.inst_checkout
if hasattr(self, 'checkin'):
self.clear()
def clear(self):
self.connected = []
+ self.first_connected = []
self.checked_out = []
self.checked_in = []
- def assert_total(innerself, conn, cout, cin):
+ def assert_total(innerself, conn, fconn, cout, cin):
self.assert_(len(innerself.connected) == conn)
+ self.assert_(len(innerself.first_connected) == fconn)
self.assert_(len(innerself.checked_out) == cout)
self.assert_(len(innerself.checked_in) == cin)
- def assert_in(innerself, item, in_conn, in_cout, in_cin):
+ def assert_in(innerself, item, in_conn, in_fconn, in_cout, in_cin):
self.assert_((item in innerself.connected) == in_conn)
+ self.assert_((item in innerself.first_connected) == in_fconn)
self.assert_((item in innerself.checked_out) == in_cout)
self.assert_((item in innerself.checked_in) == in_cin)
def inst_connect(self, con, record):
assert con is not None
assert record is not None
self.connected.append(con)
+ def inst_first_connect(self, con, record):
+ print "first_connect(%s, %s)" % (con, record)
+ assert con is not None
+ assert record is not None
+ self.first_connected.append(con)
def inst_checkout(self, con, record, proxy):
print "checkout(%s, %s, %s)" % (con, record, proxy)
assert con is not None
class ListenConnect(InstrumentingListener):
def connect(self, con, record):
pass
+ class ListenFirstConnect(InstrumentingListener):
+ def first_connect(self, con, record):
+ pass
class ListenCheckOut(InstrumentingListener):
def checkout(self, con, record, proxy, num):
pass
return pool.QueuePool(creator=lambda: dbapi.connect('foo.db'),
use_threadlocal=False, **kw)
- def assert_listeners(p, total, conn, cout, cin):
+ def assert_listeners(p, total, conn, fconn, cout, cin):
for instance in (p, p.recreate()):
self.assert_(len(instance.listeners) == total)
self.assert_(len(instance._on_connect) == conn)
+ self.assert_(len(instance._on_first_connect) == fconn)
self.assert_(len(instance._on_checkout) == cout)
self.assert_(len(instance._on_checkin) == cin)
p = _pool()
- assert_listeners(p, 0, 0, 0, 0)
+ assert_listeners(p, 0, 0, 0, 0, 0)
p.add_listener(ListenAll())
- assert_listeners(p, 1, 1, 1, 1)
+ assert_listeners(p, 1, 1, 1, 1, 1)
p.add_listener(ListenConnect())
- assert_listeners(p, 2, 2, 1, 1)
+ assert_listeners(p, 2, 2, 1, 1, 1)
+
+ p.add_listener(ListenFirstConnect())
+ assert_listeners(p, 3, 2, 2, 1, 1)
p.add_listener(ListenCheckOut())
- assert_listeners(p, 3, 2, 2, 1)
+ assert_listeners(p, 4, 2, 2, 2, 1)
p.add_listener(ListenCheckIn())
- assert_listeners(p, 4, 2, 2, 2)
+ assert_listeners(p, 5, 2, 2, 2, 2)
del p
- print "----"
snoop = ListenAll()
p = _pool(listeners=[snoop])
- assert_listeners(p, 1, 1, 1, 1)
+ assert_listeners(p, 1, 1, 1, 1, 1)
c = p.connect()
- snoop.assert_total(1, 1, 0)
+ snoop.assert_total(1, 1, 1, 0)
cc = c.connection
- snoop.assert_in(cc, True, True, False)
+ snoop.assert_in(cc, True, True, True, False)
c.close()
- snoop.assert_in(cc, True, True, True)
+ snoop.assert_in(cc, True, True, True, True)
del c, cc
snoop.clear()
# this one depends on immediate gc
c = p.connect()
cc = c.connection
- snoop.assert_in(cc, False, True, False)
- snoop.assert_total(0, 1, 0)
+ snoop.assert_in(cc, False, False, True, False)
+ snoop.assert_total(0, 0, 1, 0)
del c, cc
- snoop.assert_total(0, 1, 1)
+ lazy_gc()
+ snoop.assert_total(0, 0, 1, 1)
p.dispose()
snoop.clear()
c = p.connect()
c.close()
c = p.connect()
- snoop.assert_total(1, 2, 1)
+ snoop.assert_total(1, 0, 2, 1)
c.close()
- snoop.assert_total(1, 2, 2)
+ snoop.assert_total(1, 0, 2, 2)
# invalidation
p.dispose()
snoop.clear()
c = p.connect()
- snoop.assert_total(1, 1, 0)
+ snoop.assert_total(1, 0, 1, 0)
c.invalidate()
- snoop.assert_total(1, 1, 1)
+ snoop.assert_total(1, 0, 1, 1)
c.close()
- snoop.assert_total(1, 1, 1)
+ snoop.assert_total(1, 0, 1, 1)
del c
- snoop.assert_total(1, 1, 1)
+ lazy_gc()
+ snoop.assert_total(1, 0, 1, 1)
c = p.connect()
- snoop.assert_total(2, 2, 1)
+ snoop.assert_total(2, 0, 2, 1)
c.close()
del c
- snoop.assert_total(2, 2, 2)
+ lazy_gc()
+ snoop.assert_total(2, 0, 2, 2)
# detached
p.dispose()
snoop.clear()
c = p.connect()
- snoop.assert_total(1, 1, 0)
+ snoop.assert_total(1, 0, 1, 0)
c.detach()
- snoop.assert_total(1, 1, 0)
+ snoop.assert_total(1, 0, 1, 0)
c.close()
del c
- snoop.assert_total(1, 1, 0)
+ snoop.assert_total(1, 0, 1, 0)
c = p.connect()
- snoop.assert_total(2, 2, 0)
+ snoop.assert_total(2, 0, 2, 0)
c.close()
del c
- snoop.assert_total(2, 2, 1)
+ snoop.assert_total(2, 0, 2, 1)
def test_listeners_callables(self):
dbapi = MockDBAPI()
c.close()
assert counts == [1, 2, 3]
+ def test_listener_after_oninit(self):
+ """Test that listeners are called after OnInit is removed"""
+ called = []
+ def listener(*args):
+ called.append(True)
+ listener.connect = listener
+ engine = create_engine(testing.db.url)
+ engine.pool.add_listener(listener)
+ engine.execute(select([1]))
+ assert called, "Listener not called on connect"
+
+
class QueuePoolTest(PoolTestBase):
- def testqueuepool_del(self):
- self._do_testqueuepool(useclose=False)
-
- def testqueuepool_close(self):
- self._do_testqueuepool(useclose=True)
-
- def _do_testqueuepool(self, useclose=False):
- p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = False)
-
- def status(pool):
- tup = (pool.size(), pool.checkedin(), pool.overflow(), pool.checkedout())
- print "Pool size: %d Connections in pool: %d Current Overflow: %d Current Checked out connections: %d" % tup
- return tup
-
- c1 = p.connect()
- self.assert_(status(p) == (3,0,-2,1))
- c2 = p.connect()
- self.assert_(status(p) == (3,0,-1,2))
- c3 = p.connect()
- self.assert_(status(p) == (3,0,0,3))
- c4 = p.connect()
- self.assert_(status(p) == (3,0,1,4))
- c5 = p.connect()
- self.assert_(status(p) == (3,0,2,5))
- c6 = p.connect()
- self.assert_(status(p) == (3,0,3,6))
- if useclose:
- c4.close()
- c3.close()
- c2.close()
- else:
- c4 = c3 = c2 = None
- self.assert_(status(p) == (3,3,3,3))
- if useclose:
- c1.close()
- c5.close()
- c6.close()
- else:
- c1 = c5 = c6 = None
- self.assert_(status(p) == (3,3,0,0))
- c1 = p.connect()
- c2 = p.connect()
- self.assert_(status(p) == (3, 1, 0, 2), status(p))
- if useclose:
- c2.close()
- else:
- c2 = None
- self.assert_(status(p) == (3, 2, 0, 1))
-
- def test_timeout(self):
- p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = 0, use_threadlocal = False, timeout=2)
- c1 = p.connect()
- c2 = p.connect()
- c3 = p.connect()
- now = time.time()
- try:
- c4 = p.connect()
- assert False
- except tsa.exc.TimeoutError, e:
- assert int(time.time() - now) == 2
-
- def test_timeout_race(self):
- # test a race condition where the initial connecting threads all race
- # to queue.Empty, then block on the mutex. each thread consumes a
- # connection as they go in. when the limit is reached, the remaining
- # threads go in, and get TimeoutError; even though they never got to
- # wait for the timeout on queue.get(). the fix involves checking the
- # timeout again within the mutex, and if so, unlocking and throwing
- # them back to the start of do_get()
- p = pool.QueuePool(creator = lambda: mock_dbapi.connect(delay=.05), pool_size = 2, max_overflow = 1, use_threadlocal = False, timeout=3)
- timeouts = []
- def checkout():
- for x in xrange(1):
- now = time.time()
- try:
- c1 = p.connect()
- except tsa.exc.TimeoutError, e:
- timeouts.append(int(time.time()) - now)
- continue
- time.sleep(4)
- c1.close()
-
- threads = []
- for i in xrange(10):
- th = threading.Thread(target=checkout)
- th.start()
- threads.append(th)
- for th in threads:
- th.join()
-
- print timeouts
- assert len(timeouts) > 0
- for t in timeouts:
- assert abs(t - 3) < 1, "Not all timeouts were 3 seconds: " + repr(timeouts)
-
- def _test_overflow(self, thread_count, max_overflow):
- def creator():
- time.sleep(.05)
- return mock_dbapi.connect()
-
- p = pool.QueuePool(creator=creator,
- pool_size=3, timeout=2,
- max_overflow=max_overflow)
- peaks = []
- def whammy():
- for i in range(10):
- try:
- con = p.connect()
- time.sleep(.005)
- peaks.append(p.overflow())
- con.close()
- del con
- except tsa.exc.TimeoutError:
- pass
- threads = []
- for i in xrange(thread_count):
- th = threading.Thread(target=whammy)
- th.start()
- threads.append(th)
- for th in threads:
- th.join()
-
- self.assert_(max(peaks) <= max_overflow)
-
- def test_no_overflow(self):
- self._test_overflow(40, 0)
-
- def test_max_overflow(self):
- self._test_overflow(40, 5)
-
- def test_mixed_close(self):
- p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True)
- c1 = p.connect()
- c2 = p.connect()
- assert c1 is c2
- c1.close()
- c2 = None
- assert p.checkedout() == 1
- c1 = None
- assert p.checkedout() == 0
-
- def test_weakref_kaboom(self):
- p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True)
- c1 = p.connect()
- c2 = p.connect()
- c1.close()
- c2 = None
- del c1
- del c2
- gc.collect()
- assert p.checkedout() == 0
- c3 = p.connect()
- assert c3 is not None
-
- def test_trick_the_counter(self):
- """this is a "flaw" in the connection pool; since threadlocal uses a single ConnectionFairy per thread
- with an open/close counter, you can fool the counter into giving you a ConnectionFairy with an
- ambiguous counter. i.e. its not true reference counting."""
- p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True)
- c1 = p.connect()
- c2 = p.connect()
- assert c1 is c2
- c1.close()
- c2 = p.connect()
- c2.close()
- self.assert_(p.checkedout() != 0)
-
- c2.close()
- self.assert_(p.checkedout() == 0)
-
- def test_recycle(self):
- p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 1, max_overflow = 0, use_threadlocal = False, recycle=3)
-
- c1 = p.connect()
- c_id = id(c1.connection)
- c1.close()
- c2 = p.connect()
- assert id(c2.connection) == c_id
- c2.close()
- time.sleep(4)
- c3= p.connect()
- assert id(c3.connection) != c_id
-
- def test_invalidate(self):
- dbapi = MockDBAPI()
- p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False)
- c1 = p.connect()
- c_id = c1.connection.id
- c1.close(); c1=None
- c1 = p.connect()
- assert c1.connection.id == c_id
- c1.invalidate()
- c1 = None
-
- c1 = p.connect()
- assert c1.connection.id != c_id
-
- def test_recreate(self):
- dbapi = MockDBAPI()
- p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False)
- p2 = p.recreate()
- assert p2.size() == 1
- assert p2._use_threadlocal is False
- assert p2._max_overflow == 0
-
- def test_reconnect(self):
- """tests reconnect operations at the pool level. SA's engine/dialect includes another
- layer of reconnect support for 'database was lost' errors."""
+ def testqueuepool_del(self):
+ self._do_testqueuepool(useclose=False)
+
+ def testqueuepool_close(self):
+ self._do_testqueuepool(useclose=True)
+
+ def _do_testqueuepool(self, useclose=False):
+ p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = False)
+
+ def status(pool):
+ tup = (pool.size(), pool.checkedin(), pool.overflow(), pool.checkedout())
+ print "Pool size: %d Connections in pool: %d Current Overflow: %d Current Checked out connections: %d" % tup
+ return tup
+
+ c1 = p.connect()
+ self.assert_(status(p) == (3,0,-2,1))
+ c2 = p.connect()
+ self.assert_(status(p) == (3,0,-1,2))
+ c3 = p.connect()
+ self.assert_(status(p) == (3,0,0,3))
+ c4 = p.connect()
+ self.assert_(status(p) == (3,0,1,4))
+ c5 = p.connect()
+ self.assert_(status(p) == (3,0,2,5))
+ c6 = p.connect()
+ self.assert_(status(p) == (3,0,3,6))
+ if useclose:
+ c4.close()
+ c3.close()
+ c2.close()
+ else:
+ c4 = c3 = c2 = None
+ lazy_gc()
+
+ self.assert_(status(p) == (3,3,3,3))
+ if useclose:
+ c1.close()
+ c5.close()
+ c6.close()
+ else:
+ c1 = c5 = c6 = None
+ lazy_gc()
+
+ self.assert_(status(p) == (3,3,0,0))
+
+ c1 = p.connect()
+ c2 = p.connect()
+ self.assert_(status(p) == (3, 1, 0, 2), status(p))
+ if useclose:
+ c2.close()
+ else:
+ c2 = None
+ lazy_gc()
+
+ self.assert_(status(p) == (3, 2, 0, 1))
+
+ c1.close()
+
+ lazy_gc()
+ assert not pool._refs
- dbapi = MockDBAPI()
- p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False)
- c1 = p.connect()
- c_id = c1.connection.id
- c1.close(); c1=None
-
- c1 = p.connect()
- assert c1.connection.id == c_id
- dbapi.raise_error = True
- c1.invalidate()
- c1 = None
-
- c1 = p.connect()
- assert c1.connection.id != c_id
-
- def test_detach(self):
- dbapi = MockDBAPI()
- p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False)
-
- c1 = p.connect()
- c1.detach()
- c_id = c1.connection.id
-
- c2 = p.connect()
- assert c2.connection.id != c1.connection.id
- dbapi.raise_error = True
-
- c2.invalidate()
- c2 = None
-
- c2 = p.connect()
- assert c2.connection.id != c1.connection.id
-
- con = c1.connection
-
- assert not con.closed
- c1.close()
- assert con.closed
-
- def test_threadfairy(self):
- p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True)
- c1 = p.connect()
- c1.close()
- c2 = p.connect()
- assert c2.connection is not None
+ def test_timeout(self):
+ p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = 0, use_threadlocal = False, timeout=2)
+ c1 = p.connect()
+ c2 = p.connect()
+ c3 = p.connect()
+ now = time.time()
+ try:
+ c4 = p.connect()
+ assert False
+ except tsa.exc.TimeoutError, e:
+ assert int(time.time() - now) == 2
+
+ def test_timeout_race(self):
+ # test a race condition where the initial connecting threads all race
+ # to queue.Empty, then block on the mutex. each thread consumes a
+ # connection as they go in. when the limit is reached, the remaining
+ # threads go in, and get TimeoutError; even though they never got to
+ # wait for the timeout on queue.get(). the fix involves checking the
+ # timeout again within the mutex, and if so, unlocking and throwing
+ # them back to the start of do_get()
+ p = pool.QueuePool(creator = lambda: mock_dbapi.connect(delay=.05), pool_size = 2, max_overflow = 1, use_threadlocal = False, timeout=3)
+ timeouts = []
+ def checkout():
+ for x in xrange(1):
+ now = time.time()
+ try:
+ c1 = p.connect()
+ except tsa.exc.TimeoutError, e:
+ timeouts.append(int(time.time()) - now)
+ continue
+ time.sleep(4)
+ c1.close()
+
+ threads = []
+ for i in xrange(10):
+ th = threading.Thread(target=checkout)
+ th.start()
+ threads.append(th)
+ for th in threads:
+ th.join()
+
+ print timeouts
+ assert len(timeouts) > 0
+ for t in timeouts:
+ assert abs(t - 3) < 1, "Not all timeouts were 3 seconds: " + repr(timeouts)
+
+ def _test_overflow(self, thread_count, max_overflow):
+ def creator():
+ time.sleep(.05)
+ return mock_dbapi.connect()
+
+ p = pool.QueuePool(creator=creator,
+ pool_size=3, timeout=2,
+ max_overflow=max_overflow)
+ peaks = []
+ def whammy():
+ for i in range(10):
+ try:
+ con = p.connect()
+ time.sleep(.005)
+ peaks.append(p.overflow())
+ con.close()
+ del con
+ except tsa.exc.TimeoutError:
+ pass
+ threads = []
+ for i in xrange(thread_count):
+ th = threading.Thread(target=whammy)
+ th.start()
+ threads.append(th)
+ for th in threads:
+ th.join()
+
+ self.assert_(max(peaks) <= max_overflow)
+
+ lazy_gc()
+ assert not pool._refs
+
+ def test_no_overflow(self):
+ self._test_overflow(40, 0)
+
+ def test_max_overflow(self):
+ self._test_overflow(40, 5)
+
+ def test_mixed_close(self):
+ p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True)
+ c1 = p.connect()
+ c2 = p.connect()
+ assert c1 is c2
+ c1.close()
+ c2 = None
+ assert p.checkedout() == 1
+ c1 = None
+ lazy_gc()
+ assert p.checkedout() == 0
+
+ lazy_gc()
+ assert not pool._refs
+
+ def test_weakref_kaboom(self):
+ p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True)
+ c1 = p.connect()
+ c2 = p.connect()
+ c1.close()
+ c2 = None
+ del c1
+ del c2
+ gc_collect()
+ assert p.checkedout() == 0
+ c3 = p.connect()
+ assert c3 is not None
+
+ def test_trick_the_counter(self):
+ """this is a "flaw" in the connection pool; since threadlocal uses a single ConnectionFairy per thread
+ with an open/close counter, you can fool the counter into giving you a ConnectionFairy with an
+ ambiguous counter. i.e. its not true reference counting."""
+ p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True)
+ c1 = p.connect()
+ c2 = p.connect()
+ assert c1 is c2
+ c1.close()
+ c2 = p.connect()
+ c2.close()
+ self.assert_(p.checkedout() != 0)
+
+ c2.close()
+ self.assert_(p.checkedout() == 0)
+
+ def test_recycle(self):
+ p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 1, max_overflow = 0, use_threadlocal = False, recycle=3)
+
+ c1 = p.connect()
+ c_id = id(c1.connection)
+ c1.close()
+ c2 = p.connect()
+ assert id(c2.connection) == c_id
+ c2.close()
+ time.sleep(4)
+ c3= p.connect()
+ assert id(c3.connection) != c_id
+
+ def test_invalidate(self):
+ dbapi = MockDBAPI()
+ p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False)
+ c1 = p.connect()
+ c_id = c1.connection.id
+ c1.close(); c1=None
+ c1 = p.connect()
+ assert c1.connection.id == c_id
+ c1.invalidate()
+ c1 = None
+
+ c1 = p.connect()
+ assert c1.connection.id != c_id
+
+ def test_recreate(self):
+ dbapi = MockDBAPI()
+ p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False)
+ p2 = p.recreate()
+ assert p2.size() == 1
+ assert p2._use_threadlocal is False
+ assert p2._max_overflow == 0
+
+ def test_reconnect(self):
+ """tests reconnect operations at the pool level. SA's engine/dialect includes another
+ layer of reconnect support for 'database was lost' errors."""
+
+ dbapi = MockDBAPI()
+ p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False)
+ c1 = p.connect()
+ c_id = c1.connection.id
+ c1.close(); c1=None
+
+ c1 = p.connect()
+ assert c1.connection.id == c_id
+ dbapi.raise_error = True
+ c1.invalidate()
+ c1 = None
+
+ c1 = p.connect()
+ assert c1.connection.id != c_id
+
+ def test_detach(self):
+ dbapi = MockDBAPI()
+ p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False)
+
+ c1 = p.connect()
+ c1.detach()
+ c_id = c1.connection.id
+
+ c2 = p.connect()
+ assert c2.connection.id != c1.connection.id
+ dbapi.raise_error = True
+
+ c2.invalidate()
+ c2 = None
+
+ c2 = p.connect()
+ assert c2.connection.id != c1.connection.id
+
+ con = c1.connection
+
+ assert not con.closed
+ c1.close()
+ assert con.closed
+
+ def test_threadfairy(self):
+ p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True)
+ c1 = p.connect()
+ c1.close()
+ c2 = p.connect()
+ assert c2.connection is not None
class SingletonThreadPoolTest(PoolTestBase):
def test_cleanup(self):
from sqlalchemy.test.testing import eq_
+import time
import weakref
from sqlalchemy import select, MetaData, Integer, String, pool
from sqlalchemy.test.schema import Table
from sqlalchemy.test.schema import Column
import sqlalchemy as tsa
from sqlalchemy.test import TestBase, testing, engines
-import time
-import gc
+from sqlalchemy.test.util import gc_collect
+
class MockDisconnect(Exception):
pass
dbapi = MockDBAPI()
# create engine using our current dburi
- db = tsa.create_engine('postgres://foo:bar@localhost/test', module=dbapi)
+ db = tsa.create_engine('postgresql://foo:bar@localhost/test', module=dbapi, _initialize=False)
# monkeypatch disconnect checker
db.dialect.is_disconnect = lambda e: isinstance(e, MockDisconnect)
assert id(db.pool) != pid
# ensure all connections closed (pool was recycled)
- gc.collect()
+ gc_collect()
assert len(dbapi.connections) == 0
conn =db.connect()
pass
# assert was invalidated
- gc.collect()
+ gc_collect()
assert len(dbapi.connections) == 0
assert not conn.closed
assert conn.invalidated
assert conn.invalidated
# ensure all connections closed (pool was recycled)
- gc.collect()
+ gc_collect()
assert len(dbapi.connections) == 0
# test reconnects
meta.drop_all()
engine.dispose()
- @testing.fails_on('mysql', 'FIXME: unknown')
+ @testing.fails_on('+mysqldb', "Buffers the result set and doesn't check for connection close")
+ @testing.fails_on('+pg8000', "Buffers the result set and doesn't check for connection close")
def test_invalidate_on_results(self):
conn = engine.connect()
engine.test_shutdown()
try:
- result.fetchone()
+ print "ghost result: %r" % result.fetchone()
assert False
except tsa.exc.DBAPIError, e:
if not e.connection_invalidated:
from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
import StringIO, unicodedata
import sqlalchemy as sa
+from sqlalchemy import types as sql_types
+from sqlalchemy import schema
+from sqlalchemy.engine.reflection import Inspector
from sqlalchemy import MetaData
from sqlalchemy.test.schema import Table
from sqlalchemy.test.schema import Column
import sqlalchemy as tsa
from sqlalchemy.test import TestBase, ComparesTables, testing, engines
+create_inspector = Inspector.from_engine
metadata, users = None, None
class ReflectionTest(TestBase, ComparesTables):
+ @testing.exclude('mssql', '<', (10, 0, 0), 'Date is only supported on MSSQL 2008+')
@testing.exclude('mysql', '<', (4, 1, 1), 'early types are squirrely')
def test_basic_reflection(self):
meta = MetaData(testing.db)
Column('test1', sa.CHAR(5), nullable=False),
Column('test2', sa.Float(5), nullable=False),
Column('test3', sa.Text),
- Column('test4', sa.Numeric, nullable = False),
- Column('test5', sa.DateTime),
+ Column('test4', sa.Numeric(10, 2), nullable = False),
+ Column('test5', sa.Date),
Column('parent_user_id', sa.Integer,
sa.ForeignKey('engine_users.user_id')),
- Column('test6', sa.DateTime, nullable=False),
+ Column('test6', sa.Date, nullable=False),
Column('test7', sa.Text),
Column('test8', sa.Binary),
Column('test_passivedefault2', sa.Integer, server_default='5'),
Column('test9', sa.Binary(100)),
- Column('test_numeric', sa.Numeric()),
+ Column('test10', sa.Numeric(10, 2)),
test_needs_fk=True,
)
self.assert_tables_equal(users, reflected_users)
self.assert_tables_equal(addresses, reflected_addresses)
finally:
- addresses.drop()
- users.drop()
-
+ meta.drop_all()
+
+ def test_two_foreign_keys(self):
+ meta = MetaData(testing.db)
+ t1 = Table('t1', meta,
+ Column('id', sa.Integer, primary_key=True),
+ Column('t2id', sa.Integer, sa.ForeignKey('t2.id')),
+ Column('t3id', sa.Integer, sa.ForeignKey('t3.id')),
+ test_needs_fk=True
+ )
+ t2 = Table('t2', meta,
+ Column('id', sa.Integer, primary_key=True),
+ test_needs_fk=True
+ )
+ t3 = Table('t3', meta,
+ Column('id', sa.Integer, primary_key=True),
+ test_needs_fk=True
+ )
+ meta.create_all()
+ try:
+ meta2 = MetaData()
+ t1r, t2r, t3r = [Table(x, meta2, autoload=True, autoload_with=testing.db) for x in ('t1', 't2', 't3')]
+
+ assert t1r.c.t2id.references(t2r.c.id)
+ assert t1r.c.t3id.references(t3r.c.id)
+
+ finally:
+ meta.drop_all()
+
def test_include_columns(self):
meta = MetaData(testing.db)
foo = Table('foo', meta, *[Column(n, sa.String(30))
finally:
meta.drop_all()
+ @testing.emits_warning(r".*omitted columns")
+ def test_include_columns_indexes(self):
+ m = MetaData(testing.db)
+
+ t1 = Table('t1', m, Column('a', sa.Integer), Column('b', sa.Integer))
+ sa.Index('foobar', t1.c.a, t1.c.b)
+ sa.Index('bat', t1.c.a)
+ m.create_all()
+ try:
+ m2 = MetaData(testing.db)
+ t2 = Table('t1', m2, autoload=True)
+ assert len(t2.indexes) == 2
+ m2 = MetaData(testing.db)
+ t2 = Table('t1', m2, autoload=True, include_columns=['a'])
+ assert len(t2.indexes) == 1
+
+ m2 = MetaData(testing.db)
+ t2 = Table('t1', m2, autoload=True, include_columns=['a', 'b'])
+ assert len(t2.indexes) == 2
+ finally:
+ m.drop_all()
+
+ def test_autoincrement_col(self):
+ """test that 'autoincrement' is reflected according to sqla's policy.
+
+ Don't mark this test as unsupported for any backend !
+
+ (technically it fails with MySQL InnoDB since "id" comes before "id2")
+
+ """
+
+ meta = MetaData(testing.db)
+ t1 = Table('test', meta,
+ Column('id', sa.Integer, primary_key=True),
+ Column('data', sa.String(50)),
+ )
+ t2 = Table('test2', meta,
+ Column('id', sa.Integer, sa.ForeignKey('test.id'), primary_key=True),
+ Column('id2', sa.Integer, primary_key=True),
+ Column('data', sa.String(50)),
+ )
+ meta.create_all()
+ try:
+ m2 = MetaData(testing.db)
+ t1a = Table('test', m2, autoload=True)
+ assert t1a._autoincrement_column is t1a.c.id
+
+ t2a = Table('test2', m2, autoload=True)
+ assert t2a._autoincrement_column is t2a.c.id2
+
+ finally:
+ meta.drop_all()
+
def test_unknown_types(self):
meta = MetaData(testing.db)
t = Table("test", meta,
Column('foo', sa.DateTime))
- import sys
- dialect_module = sys.modules[testing.db.dialect.__module__]
-
- # we're relying on the presence of "ischema_names" in the
- # dialect module, else we can't test this. we need to be able
- # to get the dialect to not be aware of some type so we temporarily
- # monkeypatch. not sure what a better way for this could be,
- # except for an established dialect hook or dialect-specific tests
- if not hasattr(dialect_module, 'ischema_names'):
- return
-
- ischema_names = dialect_module.ischema_names
+ ischema_names = testing.db.dialect.ischema_names
t.create()
- dialect_module.ischema_names = {}
+ testing.db.dialect.ischema_names = {}
try:
m2 = MetaData(testing.db)
assert_raises(tsa.exc.SAWarning, Table, "test", m2, autoload=True)
assert t3.c.foo.type.__class__ == sa.types.NullType
finally:
- dialect_module.ischema_names = ischema_names
+ testing.db.dialect.ischema_names = ischema_names
t.drop()
def test_basic_override(self):
m9.reflect()
self.assert_(not m9.tables)
- @testing.fails_on_everything_except('postgres', 'mysql')
def test_index_reflection(self):
m1 = MetaData(testing.db)
t1 = Table('party', m1,
def test_basic(self):
try:
# the 'convert_unicode' should not get in the way of the reflection
- # process. reflecttable for oracle, postgres (others?) expect non-unicode
+ # process. reflecttable for oracle, postgresql (others?) expect non-unicode
# strings in result sets/bind params
bind = engines.utf8_engine(options={'convert_unicode':True})
metadata = MetaData(bind)
metadata.create_all()
reflected = set(bind.table_names())
- if not names.issubset(reflected):
+ # Jython 2.5 on Java 5 lacks unicodedata.normalize
+ if not names.issubset(reflected) and hasattr(unicodedata, 'normalize'):
# Python source files in the utf-8 coding seem to normalize
# literals as NFC (and the above are explicitly NFC). Maybe
# this database normalizes NFD on reflection.
Column('col1', sa.Integer, primary_key=True),
Column('col2', sa.Integer, sa.ForeignKey('someschema.table1.col1')),
schema='someschema')
- # ensure this doesnt crash
- print [t for t in metadata.sorted_tables]
- buf = StringIO.StringIO()
- def foo(s, p=None):
- buf.write(s)
- gen = sa.create_engine(testing.db.name + "://", strategy="mock", executor=foo)
- gen = gen.dialect.schemagenerator(gen.dialect, gen)
- gen.traverse(table1)
- gen.traverse(table2)
- buf = buf.getvalue()
- print buf
+
+ t1 = str(schema.CreateTable(table1).compile(bind=testing.db))
+ t2 = str(schema.CreateTable(table2).compile(bind=testing.db))
if testing.db.dialect.preparer(testing.db.dialect).omit_schema:
- assert buf.index("CREATE TABLE table1") > -1
- assert buf.index("CREATE TABLE table2") > -1
+ assert t1.index("CREATE TABLE table1") > -1
+ assert t2.index("CREATE TABLE table2") > -1
else:
- assert buf.index("CREATE TABLE someschema.table1") > -1
- assert buf.index("CREATE TABLE someschema.table2") > -1
+ assert t1.index("CREATE TABLE someschema.table1") > -1
+ assert t2.index("CREATE TABLE someschema.table2") > -1
@testing.crashes('firebird', 'No schema support')
@testing.fails_on('sqlite', 'FIXME: unknown')
def test_explicit_default_schema(self):
engine = testing.db
- if testing.against('mysql'):
+ if testing.against('mysql+mysqldb'):
schema = testing.db.url.database
- elif testing.against('postgres'):
+ elif testing.against('postgresql'):
schema = 'public'
elif testing.against('sqlite'):
# Works for CREATE TABLE main.foo, SELECT FROM main.foo, etc.,
metadata.drop_all(bind=testing.db)
eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'), False)
+# Tests related to engine.reflection
+
+def get_schema():
+ if testing.against('oracle'):
+ return 'scott'
+ return 'test_schema'
+
+def createTables(meta, schema=None):
+ if schema:
+ parent_user_id = Column('parent_user_id', sa.Integer,
+ sa.ForeignKey('%s.users.user_id' % schema)
+ )
+ else:
+ parent_user_id = Column('parent_user_id', sa.Integer,
+ sa.ForeignKey('users.user_id')
+ )
+
+ users = Table('users', meta,
+ Column('user_id', sa.INT, primary_key=True),
+ Column('user_name', sa.VARCHAR(20), nullable=False),
+ Column('test1', sa.CHAR(5), nullable=False),
+ Column('test2', sa.Float(5), nullable=False),
+ Column('test3', sa.Text),
+ Column('test4', sa.Numeric(10, 2), nullable = False),
+ Column('test5', sa.DateTime),
+ Column('test5-1', sa.TIMESTAMP),
+ parent_user_id,
+ Column('test6', sa.DateTime, nullable=False),
+ Column('test7', sa.Text),
+ Column('test8', sa.Binary),
+ Column('test_passivedefault2', sa.Integer, server_default='5'),
+ Column('test9', sa.Binary(100)),
+ Column('test10', sa.Numeric(10, 2)),
+ schema=schema,
+ test_needs_fk=True,
+ )
+ addresses = Table('email_addresses', meta,
+ Column('address_id', sa.Integer, primary_key = True),
+ Column('remote_user_id', sa.Integer,
+ sa.ForeignKey(users.c.user_id)),
+ Column('email_address', sa.String(20)),
+ schema=schema,
+ test_needs_fk=True,
+ )
+ return (users, addresses)
+
+def createIndexes(con, schema=None):
+ fullname = 'users'
+ if schema:
+ fullname = "%s.%s" % (schema, 'users')
+ query = "CREATE INDEX users_t_idx ON %s (test1, test2)" % fullname
+ con.execute(sa.sql.text(query))
+
+def createViews(con, schema=None):
+ for table_name in ('users', 'email_addresses'):
+ fullname = table_name
+ if schema:
+ fullname = "%s.%s" % (schema, table_name)
+ view_name = fullname + '_v'
+ query = "CREATE VIEW %s AS SELECT * FROM %s" % (view_name,
+ fullname)
+ con.execute(sa.sql.text(query))
+
+def dropViews(con, schema=None):
+ for table_name in ('email_addresses', 'users'):
+ fullname = table_name
+ if schema:
+ fullname = "%s.%s" % (schema, table_name)
+ view_name = fullname + '_v'
+ query = "DROP VIEW %s" % view_name
+ con.execute(sa.sql.text(query))
+
+
+class ComponentReflectionTest(TestBase):
+
+ @testing.requires.schemas
+ def test_get_schema_names(self):
+ meta = MetaData(testing.db)
+ insp = Inspector(meta.bind)
+
+ self.assert_(get_schema() in insp.get_schema_names())
+
+ def _test_get_table_names(self, schema=None, table_type='table',
+ order_by=None):
+ meta = MetaData(testing.db)
+ (users, addresses) = createTables(meta, schema)
+ meta.create_all()
+ createViews(meta.bind, schema)
+ try:
+ insp = Inspector(meta.bind)
+ if table_type == 'view':
+ table_names = insp.get_view_names(schema)
+ table_names.sort()
+ answer = ['email_addresses_v', 'users_v']
+ else:
+ table_names = insp.get_table_names(schema,
+ order_by=order_by)
+ table_names.sort()
+ if order_by == 'foreign_key':
+ answer = ['users', 'email_addresses']
+ else:
+ answer = ['email_addresses', 'users']
+ eq_(table_names, answer)
+ finally:
+ dropViews(meta.bind, schema)
+ addresses.drop()
+ users.drop()
+
+ def test_get_table_names(self):
+ self._test_get_table_names()
+
+ @testing.requires.schemas
+ def test_get_table_names_with_schema(self):
+ self._test_get_table_names(get_schema())
+
+ def test_get_view_names(self):
+ self._test_get_table_names(table_type='view')
+
+ @testing.requires.schemas
+ def test_get_view_names_with_schema(self):
+ self._test_get_table_names(get_schema(), table_type='view')
+
+ def _test_get_columns(self, schema=None, table_type='table'):
+ meta = MetaData(testing.db)
+ (users, addresses) = createTables(meta, schema)
+ table_names = ['users', 'email_addresses']
+ meta.create_all()
+ if table_type == 'view':
+ createViews(meta.bind, schema)
+ table_names = ['users_v', 'email_addresses_v']
+ try:
+ insp = Inspector(meta.bind)
+ for (table_name, table) in zip(table_names, (users, addresses)):
+ schema_name = schema
+ cols = insp.get_columns(table_name, schema=schema_name)
+ self.assert_(len(cols) > 0, len(cols))
+ # should be in order
+ for (i, col) in enumerate(table.columns):
+ eq_(col.name, cols[i]['name'])
+ ctype = cols[i]['type'].__class__
+ ctype_def = col.type
+ if isinstance(ctype_def, sa.types.TypeEngine):
+ ctype_def = ctype_def.__class__
+
+ # Oracle returns Date for DateTime.
+ if testing.against('oracle') \
+ and ctype_def in (sql_types.Date, sql_types.DateTime):
+ ctype_def = sql_types.Date
+
+ # assert that the desired type and return type
+ # share a base within one of the generic types.
+ self.assert_(
+ len(
+ set(
+ ctype.__mro__
+ ).intersection(ctype_def.__mro__)
+ .intersection([sql_types.Integer, sql_types.Numeric,
+ sql_types.DateTime, sql_types.Date, sql_types.Time,
+ sql_types.String, sql_types.Binary])
+ ) > 0
+ ,("%s(%s), %s(%s)" % (col.name, col.type, cols[i]['name'],
+ ctype)))
+ finally:
+ if table_type == 'view':
+ dropViews(meta.bind, schema)
+ addresses.drop()
+ users.drop()
+
+ def test_get_columns(self):
+ self._test_get_columns()
+
+ @testing.requires.schemas
+ def test_get_columns_with_schema(self):
+ self._test_get_columns(schema=get_schema())
+
+ def test_get_view_columns(self):
+ self._test_get_columns(table_type='view')
+
+ @testing.requires.schemas
+ def test_get_view_columns_with_schema(self):
+ self._test_get_columns(schema=get_schema(), table_type='view')
+
+ def _test_get_primary_keys(self, schema=None):
+ meta = MetaData(testing.db)
+ (users, addresses) = createTables(meta, schema)
+ meta.create_all()
+ insp = Inspector(meta.bind)
+ try:
+ users_pkeys = insp.get_primary_keys(users.name,
+ schema=schema)
+ eq_(users_pkeys, ['user_id'])
+ addr_pkeys = insp.get_primary_keys(addresses.name,
+ schema=schema)
+ eq_(addr_pkeys, ['address_id'])
+
+ finally:
+ addresses.drop()
+ users.drop()
+
+ def test_get_primary_keys(self):
+ self._test_get_primary_keys()
+
+ @testing.fails_on('sqlite', 'no schemas')
+ def test_get_primary_keys_with_schema(self):
+ self._test_get_primary_keys(schema=get_schema())
+
+ def _test_get_foreign_keys(self, schema=None):
+ meta = MetaData(testing.db)
+ (users, addresses) = createTables(meta, schema)
+ meta.create_all()
+ insp = Inspector(meta.bind)
+ try:
+ expected_schema = schema
+ # users
+ users_fkeys = insp.get_foreign_keys(users.name,
+ schema=schema)
+ fkey1 = users_fkeys[0]
+ self.assert_(fkey1['name'] is not None)
+ eq_(fkey1['referred_schema'], expected_schema)
+ eq_(fkey1['referred_table'], users.name)
+ eq_(fkey1['referred_columns'], ['user_id', ])
+ eq_(fkey1['constrained_columns'], ['parent_user_id'])
+ #addresses
+ addr_fkeys = insp.get_foreign_keys(addresses.name,
+ schema=schema)
+ fkey1 = addr_fkeys[0]
+ self.assert_(fkey1['name'] is not None)
+ eq_(fkey1['referred_schema'], expected_schema)
+ eq_(fkey1['referred_table'], users.name)
+ eq_(fkey1['referred_columns'], ['user_id', ])
+ eq_(fkey1['constrained_columns'], ['remote_user_id'])
+ finally:
+ addresses.drop()
+ users.drop()
+
+ def test_get_foreign_keys(self):
+ self._test_get_foreign_keys()
+
+ @testing.requires.schemas
+ def test_get_foreign_keys_with_schema(self):
+ self._test_get_foreign_keys(schema=get_schema())
+
+ def _test_get_indexes(self, schema=None):
+ meta = MetaData(testing.db)
+ (users, addresses) = createTables(meta, schema)
+ meta.create_all()
+ createIndexes(meta.bind, schema)
+ try:
+ # The database may decide to create indexes for foreign keys, etc.
+ # so there may be more indexes than expected.
+ insp = Inspector(meta.bind)
+ indexes = insp.get_indexes('users', schema=schema)
+ indexes.sort()
+ expected_indexes = [
+ {'unique': False,
+ 'column_names': ['test1', 'test2'],
+ 'name': 'users_t_idx'}]
+ index_names = [d['name'] for d in indexes]
+ for e_index in expected_indexes:
+ assert e_index['name'] in index_names
+ index = indexes[index_names.index(e_index['name'])]
+ for key in e_index:
+ eq_(e_index[key], index[key])
+
+ finally:
+ addresses.drop()
+ users.drop()
+
+ def test_get_indexes(self):
+ self._test_get_indexes()
+
+ @testing.requires.schemas
+ def test_get_indexes_with_schema(self):
+ self._test_get_indexes(schema=get_schema())
+
+ def _test_get_view_definition(self, schema=None):
+ meta = MetaData(testing.db)
+ (users, addresses) = createTables(meta, schema)
+ meta.create_all()
+ createViews(meta.bind, schema)
+ view_name1 = 'users_v'
+ view_name2 = 'email_addresses_v'
+ try:
+ insp = Inspector(meta.bind)
+ v1 = insp.get_view_definition(view_name1, schema=schema)
+ self.assert_(v1)
+ v2 = insp.get_view_definition(view_name2, schema=schema)
+ self.assert_(v2)
+ finally:
+ dropViews(meta.bind, schema)
+ addresses.drop()
+ users.drop()
+
+ def test_get_view_definition(self):
+ self._test_get_view_definition()
+
+ @testing.requires.schemas
+ def test_get_view_definition_with_schema(self):
+ self._test_get_view_definition(schema=get_schema())
+
+ def _test_get_table_oid(self, table_name, schema=None):
+ if testing.against('postgresql'):
+ meta = MetaData(testing.db)
+ (users, addresses) = createTables(meta, schema)
+ meta.create_all()
+ try:
+ insp = create_inspector(meta.bind)
+ oid = insp.get_table_oid(table_name, schema)
+ self.assert_(isinstance(oid, (int, long)))
+ finally:
+ addresses.drop()
+ users.drop()
+
+ def test_get_table_oid(self):
+ self._test_get_table_oid('users')
+
+ @testing.requires.schemas
+ def test_get_table_oid_with_schema(self):
+ self._test_get_table_oid('users', schema=get_schema())
+
users.create(testing.db)
def teardown(self):
- testing.db.connect().execute(users.delete())
+ testing.db.execute(users.delete()).close()
+
@classmethod
def teardown_class(cls):
users.drop(testing.db)
result = connection.execute("select * from query_users")
assert len(result.fetchall()) == 3
transaction.commit()
+ connection.close()
def test_rollback(self):
"""test a basic rollback"""
connection.close()
@testing.requires.savepoints
+ @testing.crashes('oracle+zxjdbc', 'Errors out and causes subsequent tests to deadlock')
def test_nested_subtransaction_commit(self):
connection = testing.db.connect()
transaction = connection.begin()
connection.close()
@testing.requires.two_phase_transactions
+ @testing.crashes('mysql+zxjdbc', 'Deadlocks, causing subsequent tests to fail')
@testing.fails_on('mysql', 'FIXME: unknown')
def test_two_phase_recover(self):
# MySQL recovery doesn't currently seem to work correctly
Requires PostgreSQL so that we may define a custom function which modifies the database.
"""
- __only_on__ = 'postgres'
+ __only_on__ = 'postgresql'
@classmethod
def setup_class(cls):
testing.db.execute("create function insert_foo(varchar) returns integer as 'insert into foo(data) values ($1);select 1;' language sql")
def teardown(self):
- foo.delete().execute()
+ foo.delete().execute().close()
@classmethod
def teardown_class(cls):
test_needs_acid=True,
)
users.create(tlengine)
+
def teardown(self):
- tlengine.execute(users.delete())
+ tlengine.execute(users.delete()).close()
+
@classmethod
def teardown_class(cls):
users.drop(tlengine)
try:
assert len(result.fetchall()) == 0
finally:
+ c.close()
external_connection.close()
def test_rollback(self):
external_connection.close()
def test_commits(self):
- assert tlengine.connect().execute("select count(1) from query_users").scalar() == 0
+ connection = tlengine.connect()
+ assert connection.execute("select count(1) from query_users").scalar() == 0
+ connection.close()
connection = tlengine.contextual_connect()
transaction = connection.begin()
l = result.fetchall()
assert len(l) == 3, "expected 3 got %d" % len(l)
transaction.commit()
+ connection.close()
def test_rollback_off_conn(self):
# test that a TLTransaction opened off a TLConnection allows that
try:
assert len(result.fetchall()) == 0
finally:
+ conn.close()
external_connection.close()
def test_morerollback_off_conn(self):
try:
assert len(result.fetchall()) == 0
finally:
+ conn.close()
+ conn2.close()
external_connection.close()
def test_commit_off_connection(self):
try:
assert len(result.fetchall()) == 3
finally:
+ conn.close()
external_connection.close()
def test_nesting(self):
test_needs_acid=True,
)
counters.create(testing.db)
+
def teardown(self):
- testing.db.connect().execute(counters.delete())
+ testing.db.execute(counters.delete()).close()
+
@classmethod
def teardown_class(cls):
counters.drop(testing.db)
for i in xrange(count):
trans = con.begin()
try:
- existing = con.execute(sel).fetchone()
+ existing = con.execute(sel).first()
incr = existing['counter_value'] + 1
time.sleep(delay)
values={'counter_value':incr}))
time.sleep(delay)
- readback = con.execute(sel).fetchone()
+ readback = con.execute(sel).first()
if (readback['counter_value'] != incr):
raise AssertionError("Got %s post-update, expected %s" %
(readback['counter_value'], incr))
self.assert_(len(errors) == 0)
sel = counters.select(whereclause=counters.c.counter_id==1)
- final = db.execute(sel).fetchone()
+ final = db.execute(sel).first()
self.assert_(final['counter_value'] == iterations * thread_count)
def overlap(self, ids, errors, update_style):
from sqlalchemy.test.testing import eq_, assert_raises
import copy
-import gc
import pickle
from sqlalchemy import *
from sqlalchemy.orm.collections import collection
from sqlalchemy.ext.associationproxy import *
from sqlalchemy.test import *
+from sqlalchemy.test.util import gc_collect
class DictCollection(dict):
add_child('p1', 'c1')
- gc.collect()
+ gc_collect()
add_child('p1', 'c2')
session.flush()
p.kids.extend(['c1', 'c2'])
p_copy = copy.copy(p)
del p
- gc.collect()
+ gc_collect()
assert set(p_copy.kids) == set(['c1', 'c2']), p.kids
from sqlalchemy import *
+from sqlalchemy.types import TypeEngine
from sqlalchemy.sql.expression import ClauseElement, ColumnClause
+from sqlalchemy.schema import DDLElement
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import table, column
from sqlalchemy.test import *
select([MyThingy('x'), MyThingy('y')]).where(MyThingy() == 5),
"SELECT >>x<<, >>y<< WHERE >>MYTHINGY!<< = :MYTHINGY!_1"
)
+
+ def test_types(self):
+ class MyType(TypeEngine):
+ pass
+
+ @compiles(MyType, 'sqlite')
+ def visit_type(type, compiler, **kw):
+ return "SQLITE_FOO"
+
+ @compiles(MyType, 'postgresql')
+ def visit_type(type, compiler, **kw):
+ return "POSTGRES_FOO"
+
+ from sqlalchemy.dialects.sqlite import base as sqlite
+ from sqlalchemy.dialects.postgresql import base as postgresql
+ self.assert_compile(
+ MyType(),
+ "SQLITE_FOO",
+ dialect=sqlite.dialect()
+ )
+
+ self.assert_compile(
+ MyType(),
+ "POSTGRES_FOO",
+ dialect=postgresql.dialect()
+ )
+
+
def test_stateful(self):
class MyThingy(ColumnClause):
def __init__(self):
)
def test_dialect_specific(self):
- class AddThingy(ClauseElement):
+ class AddThingy(DDLElement):
__visit_name__ = 'add_thingy'
- class DropThingy(ClauseElement):
+ class DropThingy(DDLElement):
__visit_name__ = 'drop_thingy'
@compiles(AddThingy, 'sqlite')
"DROP THINGY"
)
- from sqlalchemy.databases import sqlite as base
+ from sqlalchemy.dialects.sqlite import base
self.assert_compile(AddThingy(),
"ADD SPECIAL SL THINGY",
dialect=base.dialect()
import sqlalchemy as sa
from sqlalchemy.test import testing
from sqlalchemy import MetaData, Integer, String, ForeignKey, ForeignKeyConstraint, asc, Index
-from sqlalchemy.test.schema import Table
-from sqlalchemy.test.schema import Column
+from sqlalchemy.test.schema import Table, Column
from sqlalchemy.orm import relation, create_session, class_mapper, eagerload, compile_mappers, backref, clear_mappers, polymorphic_union, deferred
from sqlalchemy.test.testing import eq_
class User(Base, ComparableEntity):
__tablename__ = 'users'
- id = Column('id', Integer, primary_key=True)
+ id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
name = Column('name', String(50))
addresses = relation("Address", backref="user")
class Address(Base, ComparableEntity):
__tablename__ = 'addresses'
- id = Column(Integer, primary_key=True)
+ id = Column(Integer, primary_key=True, test_needs_autoincrement=True)
email = Column(String(50), key='_email')
user_id = Column('user_id', Integer, ForeignKey('users.id'),
key='_user_id')
class User(Base, ComparableEntity):
__tablename__ = 'users'
- id = Column(Integer, primary_key=True)
+ id = Column(Integer, primary_key=True, test_needs_autoincrement=True)
name = Column(String(50))
addresses = relation("Address", order_by="desc(Address.email)",
primaryjoin="User.id==Address.user_id", foreign_keys="[Address.user_id]",
class Address(Base, ComparableEntity):
__tablename__ = 'addresses'
- id = Column(Integer, primary_key=True)
+ id = Column(Integer, primary_key=True, test_needs_autoincrement=True)
email = Column(String(50))
user_id = Column(Integer) # note no foreign key
def test_uncompiled_attributes_in_relation(self):
class Address(Base, ComparableEntity):
__tablename__ = 'addresses'
- id = Column(Integer, primary_key=True)
+ id = Column(Integer, primary_key=True, test_needs_autoincrement=True)
email = Column(String(50))
user_id = Column(Integer, ForeignKey('users.id'))
class User(Base, ComparableEntity):
__tablename__ = 'users'
- id = Column(Integer, primary_key=True)
+ id = Column(Integer, primary_key=True, test_needs_autoincrement=True)
name = Column(String(50))
addresses = relation("Address", order_by=Address.email,
foreign_keys=Address.user_id,
class User(Base, ComparableEntity):
__tablename__ = 'users'
- id = Column('id', Integer, primary_key=True)
+ id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
User.name = Column('name', String(50))
User.addresses = relation("Address", backref="user")
class Address(Base, ComparableEntity):
__tablename__ = 'addresses'
- id = Column(Integer, primary_key=True)
+ id = Column(Integer, primary_key=True, test_needs_autoincrement=True)
Address.email = Column(String(50), key='_email')
Address.user_id = Column('user_id', Integer, ForeignKey('users.id'),
key='_user_id')
class Address(Base, ComparableEntity):
__tablename__ = 'addresses'
- id = Column('id', Integer, primary_key=True)
+ id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
email = Column('email', String(50))
user_id = Column('user_id', Integer, ForeignKey('users.id'))
class User(Base, ComparableEntity):
__tablename__ = 'users'
- id = Column('id', Integer, primary_key=True)
+ id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
name = Column('name', String(50))
addresses = relation("Address", order_by=Address.email)
class Address(Base, ComparableEntity):
__tablename__ = 'addresses'
- id = Column('id', Integer, primary_key=True)
+ id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
email = Column('email', String(50))
user_id = Column('user_id', Integer, ForeignKey('users.id'))
class User(Base, ComparableEntity):
__tablename__ = 'users'
- id = Column('id', Integer, primary_key=True)
+ id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
name = Column('name', String(50))
addresses = relation("Address", order_by=(Address.email, Address.id))
class User(ComparableEntity):
__tablename__ = 'users'
- id = Column('id', Integer, primary_key=True)
+ id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
name = Column('name', String(50))
addresses = relation("Address", backref="user")
class Address(ComparableEntity):
__tablename__ = 'addresses'
- id = Column('id', Integer, primary_key=True)
+ id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
email = Column('email', String(50))
user_id = Column('user_id', Integer, ForeignKey('users.id'))
class User(Base, ComparableEntity):
__tablename__ = 'users'
- id = Column('id', Integer, primary_key=True)
+ id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
name = Column('name', String(50))
addresses = relation("Address", backref="user")
class Address(Base, ComparableEntity):
__tablename__ = 'addresses'
- id = Column('id', Integer, primary_key=True)
+ id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
email = Column('email', String(50))
user_id = Column('user_id', Integer, ForeignKey('users.id'))
class User(Base, ComparableEntity):
__tablename__ = 'users'
- id = Column('id', Integer, primary_key=True)
+ id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
name = Column('name', String(50))
User.a = Column('a', String(10))
def test_column_properties(self):
class Address(Base, ComparableEntity):
__tablename__ = 'addresses'
- id = Column(Integer, primary_key=True)
+ id = Column(Integer, primary_key=True, test_needs_autoincrement=True)
email = Column(String(50))
user_id = Column(Integer, ForeignKey('users.id'))
class User(Base, ComparableEntity):
__tablename__ = 'users'
- id = Column('id', Integer, primary_key=True)
+ id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
name = Column('name', String(50))
adr_count = sa.orm.column_property(
sa.select([sa.func.count(Address.id)], Address.user_id == id).
class User(Base, ComparableEntity):
__tablename__ = 'users'
- id = Column(Integer, primary_key=True)
+ id = Column(Integer, primary_key=True, test_needs_autoincrement=True)
name = sa.orm.deferred(Column(String(50)))
Base.metadata.create_all()
class User(Base, ComparableEntity):
__tablename__ = 'users'
- id = Column('id', Integer, primary_key=True)
+ id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
_name = Column('name', String(50))
def _set_name(self, name):
self._name = "SOMENAME " + name
class User(Base, ComparableEntity):
__tablename__ = 'users'
- id = Column('id', Integer, primary_key=True)
+ id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
_name = Column('name', String(50))
name = sa.orm.synonym('_name', comparator_factory=CustomCompare)
class User(Base, ComparableEntity):
__tablename__ = 'users'
- id = Column('id', Integer, primary_key=True)
+ id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
_name = Column('name', String(50))
def _set_name(self, name):
self._name = "SOMENAME " + name
class User(Base, ComparableEntity):
__tablename__ = 'users'
- id = Column('id', Integer, primary_key=True)
+ id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
name = Column('name', String(50))
addresses = relation("Address", backref="user")
class Address(Base, ComparableEntity):
__tablename__ = 'addresses'
- id = Column('id', Integer, primary_key=True)
+ id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
email = Column('email', String(50))
user_id = Column('user_id', Integer, ForeignKey(User.id))
class Address(Base, ComparableEntity):
__tablename__ = 'addresses'
- id = Column('id', Integer, primary_key=True)
+ id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
email = Column('email', String(50))
user_id = Column('user_id', Integer, ForeignKey('users.id'))
class User(Base, ComparableEntity):
__tablename__ = 'users'
- id = Column('id', Integer, primary_key=True)
+ id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
name = Column('name', String(50))
addresses = relation("Address", backref="user",
primaryjoin=id == Address.user_id)
def test_with_explicit_autoloaded(self):
meta = MetaData(testing.db)
t1 = Table('t1', meta,
- Column('id', String(50), primary_key=True),
+ Column('id', String(50), primary_key=True, test_needs_autoincrement=True),
Column('data', String(50)))
meta.create_all()
try:
finally:
meta.drop_all()
+ def test_synonym_for(self):
+ class User(Base, ComparableEntity):
+ __tablename__ = 'users'
+
+ id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
+ name = Column('name', String(50))
+
+ @decl.synonym_for('name')
+ @property
+ def namesyn(self):
+ return self.name
+
+ Base.metadata.create_all()
+
+ sess = create_session()
+ u1 = User(name='someuser')
+ eq_(u1.name, "someuser")
+ eq_(u1.namesyn, 'someuser')
+ sess.add(u1)
+ sess.flush()
+
+ rt = sess.query(User).filter(User.namesyn == 'someuser').one()
+ eq_(rt, u1)
+
+ def test_comparable_using(self):
+ class NameComparator(sa.orm.PropComparator):
+ @property
+ def upperself(self):
+ cls = self.prop.parent.class_
+ col = getattr(cls, 'name')
+ return sa.func.upper(col)
+
+ def operate(self, op, other, **kw):
+ return op(self.upperself, other, **kw)
+
+ class User(Base, ComparableEntity):
+ __tablename__ = 'users'
+
+ id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
+ name = Column('name', String(50))
+
+ @decl.comparable_using(NameComparator)
+ @property
+ def uc_name(self):
+ return self.name is not None and self.name.upper() or None
+
+ Base.metadata.create_all()
+
+ sess = create_session()
+ u1 = User(name='someuser')
+ eq_(u1.name, "someuser", u1.name)
+ eq_(u1.uc_name, 'SOMEUSER', u1.uc_name)
+ sess.add(u1)
+ sess.flush()
+ sess.expunge_all()
+
+ rt = sess.query(User).filter(User.uc_name == 'SOMEUSER').one()
+ eq_(rt, u1)
+ sess.expunge_all()
+
+ rt = sess.query(User).filter(User.uc_name.startswith('SOMEUSE')).one()
+ eq_(rt, u1)
+
+
class DeclarativeInheritanceTest(DeclarativeTestBase):
def test_custom_join_condition(self):
class Foo(Base):
def test_joined(self):
class Company(Base, ComparableEntity):
__tablename__ = 'companies'
- id = Column('id', Integer, primary_key=True)
+ id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
name = Column('name', String(50))
employees = relation("Person")
class Person(Base, ComparableEntity):
__tablename__ = 'people'
- id = Column('id', Integer, primary_key=True)
+ id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
company_id = Column('company_id', Integer,
ForeignKey('companies.id'))
name = Column('name', String(50))
class Company(Base, ComparableEntity):
__tablename__ = 'companies'
- id = Column('id', Integer, primary_key=True)
+ id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
name = Column('name', String(50))
employees = relation("Person")
class Person(Base, ComparableEntity):
__tablename__ = 'people'
- id = Column('id', Integer, primary_key=True)
+ id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
company_id = Column('company_id', Integer,
ForeignKey('companies.id'))
name = Column('name', String(50))
class Company(Base, ComparableEntity):
__tablename__ = 'companies'
- id = Column('id', Integer, primary_key=True)
+ id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
name = Column('name', String(50))
employees = relation("Person")
class Person(Base, ComparableEntity):
__tablename__ = 'people'
- id = Column(Integer, primary_key=True)
+ id = Column(Integer, primary_key=True, test_needs_autoincrement=True)
company_id = Column(Integer,
ForeignKey('companies.id'))
name = Column(String(50))
def test_joined_from_single(self):
class Company(Base, ComparableEntity):
__tablename__ = 'companies'
- id = Column('id', Integer, primary_key=True)
+ id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
name = Column('name', String(50))
employees = relation("Person")
class Person(Base, ComparableEntity):
__tablename__ = 'people'
- id = Column(Integer, primary_key=True)
+ id = Column(Integer, primary_key=True, test_needs_autoincrement=True)
company_id = Column(Integer, ForeignKey('companies.id'))
name = Column(String(50))
discriminator = Column('type', String(50))
def test_add_deferred(self):
class Person(Base, ComparableEntity):
__tablename__ = 'people'
- id = Column('id', Integer, primary_key=True)
+ id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
Person.name = deferred(Column(String(10)))
Person(name='ratbert')
]
)
+ sess.expunge_all()
+
person = sess.query(Person).filter(Person.name == 'ratbert').one()
assert 'name' not in person.__dict__
class Person(Base, ComparableEntity):
__tablename__ = 'people'
- id = Column(Integer, primary_key=True)
+ id = Column(Integer, primary_key=True, test_needs_autoincrement=True)
name = Column(String(50))
discriminator = Column('type', String(50))
__mapper_args__ = {'polymorphic_on':discriminator}
class Language(Base, ComparableEntity):
__tablename__ = 'languages'
- id = Column(Integer, primary_key=True)
+ id = Column(Integer, primary_key=True, test_needs_autoincrement=True)
name = Column(String(50))
assert not hasattr(Person, 'primary_language_id')
def test_concrete(self):
engineers = Table('engineers', Base.metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(50)),
Column('primary_language', String(50))
)
managers = Table('managers', Base.metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(50)),
Column('golf_swing', String(50))
)
class User(Base, ComparableEntity):
__tablename__ = 'users'
- id = Column(Integer, primary_key=True)
+ id = Column(Integer, primary_key=True, test_needs_autoincrement=True)
name = Column(String(50))
class Address(Base, ComparableEntity):
__tablename__ = 'addresses'
- id = Column(Integer, primary_key=True)
+ id = Column(Integer, primary_key=True, test_needs_autoincrement=True)
email = Column(String(50))
user_id = Column(Integer, ForeignKey('users.id'))
if inline:
reflection_metadata = MetaData(testing.db)
Table('users', reflection_metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(50)),
test_needs_fk=True)
Table('addresses', reflection_metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('email', String(50)),
Column('user_id', Integer, ForeignKey('users.id')),
test_needs_fk=True)
Table('imhandles', reflection_metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('user_id', Integer),
Column('network', String(50)),
Column('handle', String(50)),
class User(Base, ComparableEntity):
__tablename__ = 'users'
__autoload__ = True
+ if testing.against('oracle', 'firebird'):
+ id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
addresses = relation("Address", backref="user")
class Address(Base, ComparableEntity):
__tablename__ = 'addresses'
__autoload__ = True
+ if testing.against('oracle', 'firebird'):
+ id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
+
u1 = User(name='u1', addresses=[
Address(email='one'),
Address(email='two'),
class User(Base, ComparableEntity):
__tablename__ = 'users'
__autoload__ = True
+ if testing.against('oracle', 'firebird'):
+ id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
nom = Column('name', String(50), key='nom')
addresses = relation("Address", backref="user")
class Address(Base, ComparableEntity):
__tablename__ = 'addresses'
__autoload__ = True
+ if testing.against('oracle', 'firebird'):
+ id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
u1 = User(nom='u1', addresses=[
Address(email='one'),
class IMHandle(Base, ComparableEntity):
__tablename__ = 'imhandles'
__autoload__ = True
+ if testing.against('oracle', 'firebird'):
+ id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
user_id = Column('user_id', Integer,
ForeignKey('users.id'))
class User(Base, ComparableEntity):
__tablename__ = 'users'
__autoload__ = True
+ if testing.against('oracle', 'firebird'):
+ id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
handles = relation("IMHandle", backref="user")
u1 = User(name='u1', handles=[
eq_(a1, IMHandle(network='lol', handle='zomg'))
eq_(a1.user, User(name='u1'))
- def test_synonym_for(self):
- class User(Base, ComparableEntity):
- __tablename__ = 'users'
-
- id = Column('id', Integer, primary_key=True)
- name = Column('name', String(50))
-
- @decl.synonym_for('name')
- @property
- def namesyn(self):
- return self.name
-
- Base.metadata.create_all()
-
- sess = create_session()
- u1 = User(name='someuser')
- eq_(u1.name, "someuser")
- eq_(u1.namesyn, 'someuser')
- sess.add(u1)
- sess.flush()
-
- rt = sess.query(User).filter(User.namesyn == 'someuser').one()
- eq_(rt, u1)
-
- def test_comparable_using(self):
- class NameComparator(sa.orm.PropComparator):
- @property
- def upperself(self):
- cls = self.prop.parent.class_
- col = getattr(cls, 'name')
- return sa.func.upper(col)
-
- def operate(self, op, other, **kw):
- return op(self.upperself, other, **kw)
-
- class User(Base, ComparableEntity):
- __tablename__ = 'users'
-
- id = Column('id', Integer, primary_key=True)
- name = Column('name', String(50))
-
- @decl.comparable_using(NameComparator)
- @property
- def uc_name(self):
- return self.name is not None and self.name.upper() or None
-
- Base.metadata.create_all()
-
- sess = create_session()
- u1 = User(name='someuser')
- eq_(u1.name, "someuser", u1.name)
- eq_(u1.uc_name, 'SOMEUSER', u1.uc_name)
- sess.add(u1)
- sess.flush()
- sess.expunge_all()
-
- rt = sess.query(User).filter(User.uc_name == 'SOMEUSER').one()
- eq_(rt, u1)
- sess.expunge_all()
-
- rt = sess.query(User).filter(User.uc_name.startswith('SOMEUSE')).one()
- eq_(rt, u1)
-
-
-if __name__ == '__main__':
- testing.main()
[(7, u'jack'), (8, u'ed'), (8, u'ed'), (8, u'ed'), (9, u'fred')]
)
- # fails due to pure Python pickle bug: http://bugs.python.org/issue998998
- @testing.fails_if(lambda: util.py3k)
def test_query(self):
q = Session.query(User).filter(User.name=='ed').options(eagerload(User.addresses))
eq_(q.all(), [User(name='ed', addresses=[Address(id=2), Address(id=3), Address(id=4)])])
-import gc
import inspect
import sys
import types
import sqlalchemy as sa
+import sqlalchemy.exceptions as sa_exc
from sqlalchemy.test import config, testing
from sqlalchemy.test.testing import resolve_artifact_names, adict
+from sqlalchemy.test.engines import drop_all_tables
from sqlalchemy.util import function_named
if attr.startswith('_'):
continue
value = getattr(a, attr)
- if (hasattr(value, '__iter__') and
- not isinstance(value, basestring)):
- try:
- # catch AttributeError so that lazy loaders trigger
- battr = getattr(b, attr)
- except AttributeError:
- return False
+ try:
+ # handle lazy loader errors
+ battr = getattr(b, attr)
+ except (AttributeError, sa_exc.UnboundExecutionError):
+ return False
+
+ if hasattr(value, '__iter__'):
if list(value) != list(battr):
return False
else:
- if value is not None:
- if value != getattr(b, attr, None):
- return False
+ if value is not None and value != battr:
+ return False
return True
finally:
_recursion_stack.remove(id(self))
def setup(self):
if self.run_define_tables == 'each':
self.tables.clear()
- self.metadata.drop_all()
+ drop_all_tables(self.metadata)
self.metadata.clear()
self.define_tables(self.metadata)
self.metadata.create_all()
for cl in cls.classes.values():
cls.unregister_class(cl)
ORMTest.teardown_class()
- cls.metadata.drop_all()
+ drop_all_tables(cls.metadata)
cls.metadata.bind = None
@classmethod
orders = fixture_table(
Table('orders', fixture_metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('user_id', None, ForeignKey('users.id')),
Column('address_id', None, ForeignKey('addresses.id')),
Column('description', String(30)),
dingalings = fixture_table(
Table("dingalings", fixture_metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('address_id', None, ForeignKey('addresses.id')),
Column('data', String(30)),
test_needs_acid=True,
from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE
from sqlalchemy.test import testing
+from sqlalchemy.test.schema import Table, Column
from test.orm import _base
def define_tables(cls, metadata):
global ta, tb, tc
ta = ["a", metadata]
- ta.append(Column('id', Integer, primary_key=True)),
+ ta.append(Column('id', Integer, primary_key=True, test_needs_autoincrement=True)),
ta.append(Column('a_data', String(30)))
if "a"== parent and direction == MANYTOONE:
ta.append(Column('child_id', Integer, ForeignKey("%s.id" % child, use_alter=True, name="foo")))
from sqlalchemy.util import function_named
from test.orm import _base, _fixtures
+from sqlalchemy.test.schema import Table, Column
class ABCTest(_base.MappedTest):
@classmethod
def define_tables(cls, metadata):
global a, b, c
a = Table('a', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('adata', String(30)),
Column('type', String(30)),
)
C(cdata='c1', bdata='c1', adata='c1'),
C(cdata='c2', bdata='c2', adata='c2'),
C(cdata='c2', bdata='c2', adata='c2'),
- ] == sess.query(A).all()
+ ] == sess.query(A).order_by(A.id).all()
assert [
B(bdata='b1', adata='b1'),
from sqlalchemy.test import testing, engines
from sqlalchemy.util import function_named
from test.orm import _base, _fixtures
+from sqlalchemy.test.schema import Table, Column
class O2MTest(_base.MappedTest):
"""deals with inheritance and one-to-many relationships"""
def define_tables(cls, metadata):
global foo, bar, blub
foo = Table('foo', metadata,
- Column('id', Integer, Sequence('foo_seq', optional=True),
- primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(20)))
bar = Table('bar', metadata,
def define_tables(cls, metadata):
global t1
t1 = Table('t1', metadata,
- Column('id', Integer, primary_key=True),
- Column('type', Boolean, nullable=False)
- )
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
+ Column('type', Boolean, nullable=False))
def test_false_on_sub(self):
class Foo(object):pass
def define_tables(cls, metadata):
global t1, t2
t1 = Table('t1', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('type', String(10), nullable=False),
Column('info', String(255)))
t2 = Table('t2', metadata,
def define_tables(cls, metadata):
global t1, t2, t3, t4
t1= Table('t1', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(30))
)
t2 = Table('t2', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('t1id', Integer, ForeignKey('t1.id')),
Column('type', String(30)),
Column('data', String(30))
Column('moredata', String(30)))
t4 = Table('t4', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('t3id', Integer, ForeignKey('t3.id')),
Column('data', String(30)))
def define_tables(cls, metadata):
global foo, bar, blub
foo = Table('foo', metadata,
- Column('id', Integer, Sequence('foo_seq', optional=True),
- primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('type', String(30)),
Column('data', String(20)))
Column('data', String(20)))
blub = Table('blub', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('foo_id', Integer, ForeignKey('foo.id')),
Column('bar_id', Integer, ForeignKey('bar.id')),
Column('data', String(20)))
def define_tables(cls, metadata):
global foo, bar, bar_foo
foo = Table('foo', metadata,
- Column('id', Integer, Sequence('foo_seq', optional=True),
- primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(30)))
bar = Table('bar', metadata,
Column('id', Integer, ForeignKey('foo.id'), primary_key=True),
def define_tables(cls, metadata):
global users, roles, user_roles, admins
users = Table('users', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('email', String(128)),
Column('password', String(16)),
)
roles = Table('role', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('description', String(32))
)
)
admins = Table('admin', metadata,
- Column('admin_id', Integer, primary_key=True),
+ Column('admin_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('user_id', Integer, ForeignKey('users.id'))
)
def define_tables(cls, metadata):
global base, subtable, stuff
base = Table('base', metadata,
- Column('id', Integer, Sequence('version_test_seq', optional=True), primary_key=True ),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('version_id', Integer, nullable=False),
Column('value', String(40)),
Column('discriminator', Integer, nullable=False)
Column('subdata', String(50))
)
stuff = Table('stuff', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('parent', Integer, ForeignKey('base.id'))
)
- @testing.fails_on('mssql', 'FIXME: the flush still happens with the concurrency issue.')
@engines.close_open_connections
def test_save_update(self):
class Base(_fixtures.Base):
try:
sess2.flush()
- assert False
+ assert not testing.db.dialect.supports_sane_rowcount
except orm_exc.ConcurrentModificationError, e:
assert True
sess2.refresh(s2)
- assert s2.subdata == 'sess1 subdata'
+ if testing.db.dialect.supports_sane_rowcount:
+ assert s2.subdata == 'sess1 subdata'
s2.subdata = 'sess2 subdata'
sess2.flush()
- @testing.fails_on('mssql', 'FIXME: the flush still happens with the concurrency issue.')
def test_delete(self):
class Base(_fixtures.Base):
pass
try:
s1.subdata = 'some new subdata'
sess.flush()
- assert False
+ assert not testing.db.dialect.supports_sane_rowcount
except orm_exc.ConcurrentModificationError, e:
assert True
global person_table, employee_table, Person, Employee
person_table = Table("persons", metadata,
- Column("id", Integer, primary_key=True),
+ Column("id", Integer, primary_key=True, test_needs_autoincrement=True),
Column("name", String(80)),
)
employee_table = Table("employees", metadata,
- Column("id", Integer, primary_key=True),
+ Column("id", Integer, primary_key=True, test_needs_autoincrement=True),
Column("salary", Integer),
Column("person_id", Integer, ForeignKey("persons.id")),
)
global _a_table, _b_table, _c_table
_a_table = Table('a', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data1', String(128))
)
global base, subtable
base = Table('base', metadata,
- Column('base_id', Integer, primary_key=True),
+ Column('base_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(255)),
Column('sqlite_fixer', String(10))
)
def define_tables(cls, metadata):
global base, sub
base = Table('base', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(50)),
Column('type', String(50))
)
@classmethod
def define_tables(cls, metadata):
parents = Table('parents', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(60)))
children = Table('children', metadata,
def define_tables(cls, metadata):
global single, parent
single = Table('single', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('type', String(50), nullable=False),
Column('data', String(50)),
Column('parent_id', Integer, ForeignKey('parent.id'), nullable=False),
)
parent = Table('parent', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(50))
)
from test.orm import _base
from sqlalchemy.orm import attributes
from sqlalchemy.test.testing import eq_
+from sqlalchemy.test.schema import Table, Column
class Employee(object):
def __init__(self, name):
global managers_table, engineers_table, hackers_table, companies, employees_table
companies = Table('companies', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(50)))
employees_table = Table('employees', metadata,
- Column('employee_id', Integer, primary_key=True),
+ Column('employee_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(50)),
Column('company_id', Integer, ForeignKey('companies.id'))
)
managers_table = Table('managers', metadata,
- Column('employee_id', Integer, primary_key=True),
+ Column('employee_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(50)),
Column('manager_data', String(50)),
Column('company_id', Integer, ForeignKey('companies.id'))
)
engineers_table = Table('engineers', metadata,
- Column('employee_id', Integer, primary_key=True),
+ Column('employee_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(50)),
Column('engineer_info', String(50)),
Column('company_id', Integer, ForeignKey('companies.id'))
)
hackers_table = Table('hackers', metadata,
- Column('employee_id', Integer, primary_key=True),
+ Column('employee_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(50)),
Column('engineer_info', String(50)),
Column('company_id', Integer, ForeignKey('companies.id')),
@classmethod
def define_tables(cls, metadata):
Table('a_table', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('some_c_id', Integer, ForeignKey('c_table.id')),
Column('aname', String(50)),
)
Table('b_table', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('some_c_id', Integer, ForeignKey('c_table.id')),
Column('bname', String(50)),
)
Table('c_table', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('cname', String(50)),
)
def define_tables(cls, metadata):
global offices_table, refugees_table
refugees_table = Table('refugee', metadata,
- Column('refugee_fid', Integer, primary_key=True),
+ Column('refugee_fid', Integer, primary_key=True, test_needs_autoincrement=True),
Column('refugee_name', Unicode(30), key='name'))
offices_table = Table('office', metadata,
- Column('office_fid', Integer, primary_key=True),
+ Column('office_fid', Integer, primary_key=True, test_needs_autoincrement=True),
Column('office_name', Unicode(30), key='name'))
@classmethod
from sqlalchemy.test import testing
from sqlalchemy.util import function_named
from test.orm import _base
+from sqlalchemy.test.schema import Table, Column
class BaseObject(object):
def __init__(self, *args, **kwargs):
global publication_table, issue_table, location_table, location_name_table, magazine_table, \
page_table, magazine_page_table, classified_page_table, page_size_table
- zerodefault = {} #{'default':0}
publication_table = Table('publication', metadata,
- Column('id', Integer, primary_key=True, default=None),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(45), default=''),
)
issue_table = Table('issue', metadata,
- Column('id', Integer, primary_key=True, default=None),
- Column('publication_id', Integer, ForeignKey('publication.id'), **zerodefault),
- Column('issue', Integer, **zerodefault),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
+ Column('publication_id', Integer, ForeignKey('publication.id')),
+ Column('issue', Integer),
)
location_table = Table('location', metadata,
- Column('id', Integer, primary_key=True, default=None),
- Column('issue_id', Integer, ForeignKey('issue.id'), **zerodefault),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
+ Column('issue_id', Integer, ForeignKey('issue.id')),
Column('ref', CHAR(3), default=''),
- Column('location_name_id', Integer, ForeignKey('location_name.id'), **zerodefault),
+ Column('location_name_id', Integer, ForeignKey('location_name.id')),
)
location_name_table = Table('location_name', metadata,
- Column('id', Integer, primary_key=True, default=None),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(45), default=''),
)
magazine_table = Table('magazine', metadata,
- Column('id', Integer, primary_key=True, default=None),
- Column('location_id', Integer, ForeignKey('location.id'), **zerodefault),
- Column('page_size_id', Integer, ForeignKey('page_size.id'), **zerodefault),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
+ Column('location_id', Integer, ForeignKey('location.id')),
+ Column('page_size_id', Integer, ForeignKey('page_size.id')),
)
page_table = Table('page', metadata,
- Column('id', Integer, primary_key=True, default=None),
- Column('page_no', Integer, **zerodefault),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
+ Column('page_no', Integer),
Column('type', CHAR(1), default='p'),
)
magazine_page_table = Table('magazine_page', metadata,
- Column('page_id', Integer, ForeignKey('page.id'), primary_key=True, **zerodefault),
- Column('magazine_id', Integer, ForeignKey('magazine.id'), **zerodefault),
- Column('orders', TEXT, default=''),
+ Column('page_id', Integer, ForeignKey('page.id'), primary_key=True),
+ Column('magazine_id', Integer, ForeignKey('magazine.id')),
+ Column('orders', Text, default=''),
)
classified_page_table = Table('classified_page', metadata,
- Column('magazine_page_id', Integer, ForeignKey('magazine_page.page_id'), primary_key=True, **zerodefault),
+ Column('magazine_page_id', Integer, ForeignKey('magazine_page.page_id'), primary_key=True),
Column('titles', String(45), default=''),
)
page_size_table = Table('page_size', metadata,
- Column('id', Integer, primary_key=True, default=None),
- Column('width', Integer, **zerodefault),
- Column('height', Integer, **zerodefault),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
+ Column('width', Integer),
+ Column('height', Integer),
Column('name', String(45), default=''),
)
'magazine': relation(Magazine, backref=backref('pages', order_by=page_table.c.page_no))
})
- classified_page_mapper = mapper(ClassifiedPage, classified_page_table, inherits=magazine_page_mapper, polymorphic_identity='c', primary_key=[page_table.c.id])
- #compile_mappers()
- #print [str(s) for s in classified_page_mapper.primary_key]
- #print classified_page_mapper.columntoproperty[page_table.c.id]
+ classified_page_mapper = mapper(ClassifiedPage,
+ classified_page_table,
+ inherits=magazine_page_mapper,
+ polymorphic_identity='c',
+ primary_key=[page_table.c.id])
session = create_session()
b.foos.append(Foo("foo #1"))
b.foos.append(Foo("foo #2"))
sess.flush()
- compare = repr(b) + repr(sorted([repr(o) for o in b.foos]))
+ compare = [repr(b)] + sorted([repr(o) for o in b.foos])
sess.expunge_all()
l = sess.query(Bar).all()
print repr(l[0]) + repr(l[0].foos)
- found = repr(l[0]) + repr(sorted([repr(o) for o in l[0].foos]))
+ found = [repr(l[0])] + sorted([repr(o) for o in l[0].foos])
eq_(found, compare)
@testing.fails_on('maxdb', 'FIXME: unknown')
from test.orm import _base
from sqlalchemy.test import testing
+from sqlalchemy.test.schema import Table, Column
class PolymorphicCircularTest(_base.MappedTest):
def define_tables(cls, metadata):
global Table1, Table1B, Table2, Table3, Data
table1 = Table('table1', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('related_id', Integer, ForeignKey('table1.id'), nullable=True),
Column('type', String(30)),
Column('name', String(30))
)
data = Table('data', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('node_id', Integer, ForeignKey('table1.id')),
Column('data', String(30))
)
polymorphic_on=table1.c.type,
polymorphic_identity='table1',
properties={
- 'next': relation(Table1,
+ 'nxt': relation(Table1,
backref=backref('prev', foreignkey=join.c.id, uselist=False),
uselist=False, primaryjoin=join.c.id==join.c.related_id),
'data':relation(mapper(Data, data))
# currently, the "eager" relationships degrade to lazy relationships
# due to the polymorphic load.
- # the "next" relation used to have a "lazy=False" on it, but the EagerLoader raises the "self-referential"
+ # the "nxt" relation used to have a "lazy=False" on it, but the EagerLoader raises the "self-referential"
# exception now. since eager loading would never work for that relation anyway, its better that the user
# gets an exception instead of it silently not eager loading.
+ # NOTE: using "nxt" instead of "next" to avoid 2to3 turning it into __next__() for some reason.
table1_mapper = mapper(Table1, table1,
#select_table=join,
polymorphic_on=table1.c.type,
polymorphic_identity='table1',
properties={
- 'next': relation(Table1,
+ 'nxt': relation(Table1,
backref=backref('prev', remote_side=table1.c.id, uselist=False),
uselist=False, primaryjoin=table1.c.id==table1.c.related_id),
'data':relation(mapper(Data, data), lazy=False, order_by=data.c.id)
else:
newobj = c
if obj is not None:
- obj.next = newobj
+ obj.nxt = newobj
else:
t = newobj
obj = newobj
node = t
while (node):
assertlist.append(node)
- n = node.next
+ n = node.nxt
if n is not None:
assert n.prev is node
node = n
assertlist = []
while (node):
assertlist.append(node)
- n = node.next
+ n = node.nxt
if n is not None:
assert n.prev is node
node = n
assertlist.insert(0, node)
n = node.prev
if n is not None:
- assert n.next is node
+ assert n.nxt is node
node = n
backwards = repr(assertlist)
from sqlalchemy.util import function_named
from test.orm import _base, _fixtures
from sqlalchemy.test.testing import eq_
+from sqlalchemy.test.schema import Table, Column
class AttrSettable(object):
def __init__(self, **kwargs):
def define_tables(cls, metadata):
global people, managers, data
people = Table('people', metadata,
- Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True),
+ Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(50)),
Column('type', String(30)))
def define_tables(cls, metadata):
global people, managers, data
people = Table('people', metadata,
- Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True),
+ Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('colleague_id', Integer, ForeignKey('people.person_id')),
Column('name', String(50)),
Column('type', String(30)))
def define_tables(cls, metadata):
global people, engineers, managers, cars
people = Table('people', metadata,
- Column('person_id', Integer, primary_key=True),
+ Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(50)))
engineers = Table('engineers', metadata,
Column('longer_status', String(70)))
cars = Table('cars', metadata,
- Column('car_id', Integer, primary_key=True),
+ Column('car_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('owner', Integer, ForeignKey('people.person_id')))
def testmanytoonepolymorphic(self):
def define_tables(cls, metadata):
global people, engineers, managers, cars
people = Table('people', metadata,
- Column('person_id', Integer, primary_key=True),
+ Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(50)),
Column('type', String(50)))
Column('longer_status', String(70)))
cars = Table('cars', metadata,
- Column('car_id', Integer, primary_key=True),
+ Column('car_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('owner', Integer, ForeignKey('people.person_id')))
def testeagerempty(self):
def define_tables(cls, metadata):
global people, managers, data
people = Table('people', metadata,
- Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True),
+ Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(50)),
)
def define_tables(cls, metadata):
global people, engineers, managers, cars, offroad_cars
cars = Table('cars', metadata,
- Column('car_id', Integer, primary_key=True),
+ Column('car_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(30)))
offroad_cars = Table('offroad_cars', metadata,
Column('car_id',Integer, ForeignKey('cars.car_id'),nullable=False,primary_key=True))
people = Table('people', metadata,
- Column('person_id', Integer, primary_key=True),
+ Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('car_id', Integer, ForeignKey('cars.car_id'), nullable=False),
Column('name', String(50)))
def define_tables(cls, metadata):
global taggable, users
taggable = Table('taggable', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('type', String(30)),
Column('owner_id', Integer, ForeignKey('taggable.id')),
)
metadata = MetaData(testing.db)
# table definitions
status = Table('status', metadata,
- Column('status_id', Integer, primary_key=True),
+ Column('status_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(20)))
people = Table('people', metadata,
- Column('person_id', Integer, primary_key=True),
+ Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('status_id', Integer, ForeignKey('status.status_id'), nullable=False),
Column('name', String(50)))
Column('category', String(70)))
cars = Table('cars', metadata,
- Column('car_id', Integer, primary_key=True),
+ Column('car_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('status_id', Integer, ForeignKey('status.status_id'), nullable=False),
Column('owner', Integer, ForeignKey('people.person_id'), nullable=False))
e = exists([Car.owner], Car.owner==employee_join.c.person_id)
Query(Person)._adapt_clause(employee_join, False, False)
- r = session.query(Person).filter(Person.name.like('%2')).join('status').filter_by(name="active")
- assert str(list(r)) == "[Manager M2, category YYYYYYYYY, status Status active, Engineer E2, field X, status Status active]"
+ r = session.query(Person).filter(Person.name.like('%2')).join('status').filter_by(name="active").order_by(Person.person_id)
+ eq_(str(list(r)), "[Manager M2, category YYYYYYYYY, status Status active, Engineer E2, field X, status Status active]")
r = session.query(Engineer).join('status').filter(Person.name.in_(['E2', 'E3', 'E4', 'M4', 'M2', 'M1']) & (status.c.name=="active")).order_by(Person.name)
- assert str(list(r)) == "[Engineer E2, field X, status Status active, Engineer E3, field X, status Status active]"
+ eq_(str(list(r)), "[Engineer E2, field X, status Status active, Engineer E3, field X, status Status active]")
r = session.query(Person).filter(exists([1], Car.owner==Person.person_id))
- assert str(list(r)) == "[Engineer E4, field X, status Status dead]"
+ eq_(str(list(r)), "[Engineer E4, field X, status Status dead]")
class MultiLevelTest(_base.MappedTest):
@classmethod
global table_Employee, table_Engineer, table_Manager
table_Employee = Table( 'Employee', metadata,
Column( 'name', type_= String(100), ),
- Column( 'id', primary_key= True, type_= Integer, ),
+ Column( 'id', primary_key= True, type_= Integer, test_needs_autoincrement=True),
Column( 'atype', type_= String(100), ),
)
global base_item_table, item_table, base_item_collection_table, collection_table
base_item_table = Table(
'base_item', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('child_name', String(255), default=None))
item_table = Table(
collection_table = Table(
'collection', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', Unicode(255)))
def test_pjoin_compile(self):
def define_tables(cls, metadata):
global t1, t2
t1 = Table('t1', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('type', String(30), nullable=False),
Column('data', String(30)))
# note that the primary key column in t2 is named differently
global people, employees, tags, peopleTags
people = Table('people', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('_type', String(30), nullable=False),
)
)
tags = Table('tags', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('label', String(50), nullable=False),
)
def define_tables(cls, metadata):
global tablea, tableb, tablec, tabled
tablea = Table('tablea', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('adata', String(50)),
)
tableb = Table('tableb', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('aid', Integer, ForeignKey('tablea.id')),
Column('data', String(50)),
)
from sqlalchemy import *
from sqlalchemy.orm import *
-
from sqlalchemy.test import testing
from test.orm import _base
-
+from sqlalchemy.test.schema import Table, Column
class InheritTest(_base.MappedTest):
"""tests some various inheritance round trips involving a particular set of polymorphic inheritance relationships"""
global Product, Detail, Assembly, SpecLine, Document, RasterDocument
products_table = Table('products', metadata,
- Column('product_id', Integer, primary_key=True),
+ Column('product_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('product_type', String(128)),
Column('name', String(128)),
Column('mark', String(128)),
)
specification_table = Table('specification', metadata,
- Column('spec_line_id', Integer, primary_key=True),
+ Column('spec_line_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('master_id', Integer, ForeignKey("products.product_id"),
nullable=True),
Column('slave_id', Integer, ForeignKey("products.product_id"),
)
documents_table = Table('documents', metadata,
- Column('document_id', Integer, primary_key=True),
+ Column('document_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('document_type', String(128)),
Column('product_id', Integer, ForeignKey('products.product_id')),
Column('create_date', DateTime, default=lambda:datetime.now()),
from sqlalchemy.test import AssertsCompiledSQL, testing
from test.orm import _base, _fixtures
from sqlalchemy.test.testing import eq_
+from sqlalchemy.test.schema import Table, Column
class Company(_fixtures.Base):
pass
global companies, people, engineers, managers, boss, paperwork, machines
companies = Table('companies', metadata,
- Column('company_id', Integer, Sequence('company_id_seq', optional=True), primary_key=True),
+ Column('company_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(50)))
people = Table('people', metadata,
- Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True),
+ Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('company_id', Integer, ForeignKey('companies.company_id')),
Column('name', String(50)),
Column('type', String(30)))
)
machines = Table('machines', metadata,
- Column('machine_id', Integer, primary_key=True),
+ Column('machine_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(50)),
Column('engineer_id', Integer, ForeignKey('engineers.person_id')))
)
paperwork = Table('paperwork', metadata,
- Column('paperwork_id', Integer, primary_key=True),
+ Column('paperwork_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('description', String(50)),
Column('person_id', Integer, ForeignKey('people.person_id')))
def define_tables(cls, metadata):
global people, engineers
people = Table('people', metadata,
- Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True),
+ Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(50)),
Column('type', String(30)))
def define_tables(cls, metadata):
global people, engineers, managers
people = Table('people', metadata,
- Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True),
+ Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(50)),
Column('type', String(30)))
global people, engineers, organizations, engineers_to_org
organizations = Table('organizations', metadata,
- Column('id', Integer, Sequence('org_id_seq', optional=True), primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(50)),
)
engineers_to_org = Table('engineers_org', metadata,
)
people = Table('people', metadata,
- Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True),
+ Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(50)),
Column('type', String(30)))
class Parent(Base):
__tablename__ = 'parent'
- id = Column(Integer, primary_key=True)
+ id = Column(Integer, primary_key=True, test_needs_autoincrement=True)
cls = Column(String(50))
__mapper_args__ = dict(polymorphic_on = cls )
s = sessionmaker(bind=testing.db)()
- assert [Baz(), Baz(), Bar(), Bar()] == s.query(Foo).all()
+ assert [Baz(), Baz(), Bar(), Bar()] == s.query(Foo).order_by(Foo.b.desc()).all()
assert [Bar(), Bar()] == s.query(Bar).all()
from sqlalchemy.test import testing
from test.orm import _fixtures
from test.orm._base import MappedTest, ComparableEntity
+from sqlalchemy.test.schema import Table, Column
class SingleInheritanceTest(MappedTest):
@classmethod
def define_tables(cls, metadata):
Table('employees', metadata,
- Column('employee_id', Integer, primary_key=True),
+ Column('employee_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(50)),
Column('manager_data', String(50)),
Column('engineer_info', String(50)),
Column('type', String(20)))
Table('reports', metadata,
- Column('report_id', Integer, primary_key=True),
+ Column('report_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('employee_id', ForeignKey('employees.employee_id')),
Column('name', String(50)),
)
@classmethod
def define_tables(cls, metadata):
Table('employees', metadata,
- Column('employee_id', Integer, primary_key=True),
+ Column('employee_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(50)),
Column('manager_data', String(50)),
Column('engineer_info', String(50)),
)
Table('companies', metadata,
- Column('company_id', Integer, primary_key=True),
+ Column('company_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(50)),
)
global persons_table, employees_table
persons_table = Table('persons', metadata,
- Column('person_id', Integer, primary_key=True),
+ Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(50)),
Column('type', String(20), nullable=False)
)
from sqlalchemy.sql import operators
from sqlalchemy.test import *
from sqlalchemy.test.testing import eq_
+from nose import SkipTest
# TODO: ShardTest can be turned into a base for further subclasses
def setup_class(cls):
global db1, db2, db3, db4, weather_locations, weather_reports
- db1 = create_engine('sqlite:///shard1.db')
+ try:
+ db1 = create_engine('sqlite:///shard1.db')
+ except ImportError:
+ raise SkipTest('Requires sqlite')
db2 = create_engine('sqlite:///shard2.db')
db3 = create_engine('sqlite:///shard3.db')
db4 = create_engine('sqlite:///shard4.db')
from sqlalchemy.test import testing
from sqlalchemy import Integer, String, ForeignKey
-from sqlalchemy.test.schema import Table
-from sqlalchemy.test.schema import Column
+from sqlalchemy.test.schema import Table, Column
from sqlalchemy.orm import mapper, relation, create_session
from test.orm import _base
from sqlalchemy.test.testing import eq_
@classmethod
def define_tables(cls, metadata):
Table('items', metadata,
- Column('item_id', Integer, primary_key=True),
+ Column('item_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(40)))
Table('item_keywords', metadata,
Column('item_id', Integer, ForeignKey('items.item_id')),
Column('keyword_id', Integer, ForeignKey('keywords.keyword_id')),
Column('data', String(40)))
Table('keywords', metadata,
- Column('keyword_id', Integer, primary_key=True),
+ Column('keyword_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(40)))
@classmethod
import sqlalchemy as sa
from sqlalchemy.test import testing
from sqlalchemy import Integer, String, ForeignKey
-from sqlalchemy.test.schema import Table
-from sqlalchemy.test.schema import Column
+from sqlalchemy.test.schema import Table, Column
from sqlalchemy.orm import mapper, relation, backref, create_session
from sqlalchemy.test.testing import eq_
from test.orm import _base
cls.other_artifacts['false'] = false
Table('owners', metadata ,
- Column('id', Integer, primary_key=True, nullable=False),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(30)))
Table('categories', metadata,
- Column('id', Integer, primary_key=True, nullable=False),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(20)))
Table('tests', metadata ,
- Column('id', Integer, primary_key=True, nullable=False ),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('owner_id', Integer, ForeignKey('owners.id'),
nullable=False),
Column('category_id', Integer, ForeignKey('categories.id'),
nullable=False))
Table('options', metadata ,
- Column('test_id', Integer, ForeignKey('tests.id'),
- primary_key=True, nullable=False),
- Column('owner_id', Integer, ForeignKey('owners.id'),
- primary_key=True, nullable=False),
- Column('someoption', sa.Boolean, server_default=false,
- nullable=False))
+ Column('test_id', Integer, ForeignKey('tests.id'), primary_key=True),
+ Column('owner_id', Integer, ForeignKey('owners.id'), primary_key=True),
+ Column('someoption', sa.Boolean, server_default=false, nullable=False))
@classmethod
def setup_classes(cls):
Column('data', String(50), primary_key=True))
Table('middle', metadata,
- Column('id', Integer, primary_key = True),
+ Column('id', Integer, primary_key = True, test_needs_autoincrement=True),
Column('data', String(50)))
Table('right', metadata,
@classmethod
def define_tables(cls, metadata):
Table('datas', metadata,
- Column('id', Integer, primary_key=True, nullable=False),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('a', Integer, nullable=False))
Table('foo', metadata,
- Column('data_id', Integer,
- ForeignKey('datas.id'),
- nullable=False, primary_key=True),
+ Column('data_id', Integer, ForeignKey('datas.id'),primary_key=True),
Column('bar', Integer))
Table('stats', metadata,
- Column('id', Integer, primary_key=True, nullable=False ),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data_id', Integer, ForeignKey('datas.id')),
Column('somedata', Integer, nullable=False ))
@classmethod
def define_tables(cls, metadata):
Table('departments', metadata,
- Column('department_id', Integer, primary_key=True),
+ Column('department_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(50)))
Table('employees', metadata,
- Column('person_id', Integer, primary_key=True),
+ Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(50)),
Column('department_id', Integer,
ForeignKey('departments.department_id')))
Column('x', String(30)))
Table('derived', metadata,
- Column('uid', String(30), ForeignKey('base.uid'),
- primary_key=True),
+ Column('uid', String(30), ForeignKey('base.uid'), primary_key=True),
Column('y', String(30)))
Table('derivedII', metadata,
- Column('uid', String(30), ForeignKey('base.uid'),
- primary_key=True),
+ Column('uid', String(30), ForeignKey('base.uid'), primary_key=True),
Column('z', String(30)))
Table('comments', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('uid', String(30), ForeignKey('base.uid')),
Column('comment', String(30)))
@classmethod
def define_tables(cls, metadata):
Table('design_types', metadata,
- Column('design_type_id', Integer, primary_key=True))
+ Column('design_type_id', Integer, primary_key=True, test_needs_autoincrement=True))
Table('design', metadata,
- Column('design_id', Integer, primary_key=True),
+ Column('design_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('design_type_id', Integer,
ForeignKey('design_types.design_type_id')))
Table('parts', metadata,
- Column('part_id', Integer, primary_key=True),
+ Column('part_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('design_id', Integer, ForeignKey('design.design_id')),
Column('design_type_id', Integer,
ForeignKey('design_types.design_type_id')))
Table('inherited_part', metadata,
- Column('ip_id', Integer, primary_key=True),
+ Column('ip_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('part_id', Integer, ForeignKey('parts.part_id')),
Column('design_id', Integer, ForeignKey('design.design_id')))
@classmethod
def define_tables(cls, metadata):
Table('companies', metadata,
- Column('company_id', Integer, primary_key=True,
- test_needs_autoincrement=True),
+ Column('company_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('company_name', String(40)))
Table('addresses', metadata,
- Column('address_id', Integer, primary_key=True,
- test_needs_autoincrement=True),
+ Column('address_id', Integer, primary_key=True,test_needs_autoincrement=True),
Column('company_id', Integer, ForeignKey("companies.company_id")),
Column('address', String(40)))
Table('phone_numbers', metadata,
- Column('phone_id', Integer, primary_key=True,
- test_needs_autoincrement=True),
+ Column('phone_id', Integer, primary_key=True,test_needs_autoincrement=True),
Column('address_id', Integer, ForeignKey('addresses.address_id')),
Column('type', String(20)),
Column('number', String(10)))
Table('invoices', metadata,
- Column('invoice_id', Integer, primary_key=True,
- test_needs_autoincrement=True),
+ Column('invoice_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('company_id', Integer, ForeignKey("companies.company_id")),
Column('date', sa.DateTime))
Table('items', metadata,
- Column('item_id', Integer, primary_key=True,
- test_needs_autoincrement=True),
+ Column('item_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('invoice_id', Integer, ForeignKey('invoices.invoice_id')),
Column('code', String(20)),
Column('qty', Integer))
@classmethod
def define_tables(cls, metadata):
Table('prj', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('created', sa.DateTime ),
Column('title', sa.Unicode(100)))
Table('task', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('status_id', Integer,
ForeignKey('task_status.id'), nullable=False),
Column('title', sa.Unicode(100)),
Column('prj_id', Integer , ForeignKey('prj.id'), nullable=False))
Table('task_status', metadata,
- Column('id', Integer, primary_key=True))
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True))
Table('task_type', metadata,
- Column('id', Integer, primary_key=True))
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True))
Table('msg', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('posted', sa.DateTime, index=True,),
Column('type_id', Integer, ForeignKey('msg_type.id')),
Column('task_id', Integer, ForeignKey('task.id')))
Table('msg_type', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', sa.Unicode(20)),
Column('display_name', sa.Unicode(20)))
@classmethod
def define_tables(cls, metadata):
Table('accounts', metadata,
- Column('account_id', Integer, primary_key=True),
+ Column('account_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(40)))
Table('transactions', metadata,
- Column('transaction_id', Integer, primary_key=True),
+ Column('transaction_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(40)))
Table('entries', metadata,
- Column('entry_id', Integer, primary_key=True),
+ Column('entry_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(40)),
Column('account_id', Integer,
ForeignKey('accounts.account_id')),
from sqlalchemy.test import *
from sqlalchemy.test.testing import eq_
from test.orm import _base
-import gc
+from sqlalchemy.test.util import gc_collect
+from sqlalchemy.util import cmp, jython
# global for pickling tests
MyTest = None
del o2.__dict__['mt2']
o2.__dict__[o_mt2_str] = former
- self.assert_(pk_o == pk_o2)
+ # Relies on dict ordering
+ if not jython:
+ self.assert_(pk_o == pk_o2)
# the above is kind of distrurbing, so let's do it again a little
# differently. the string-id in serialization thing is just an
o4 = pickle.loads(pk_o3)
pk_o4 = pickle.dumps(o4)
- self.assert_(pk_o3 == pk_o4)
+ # Relies on dict ordering
+ if not jython:
+ self.assert_(pk_o3 == pk_o4)
# and lastly make sure we still have our data after all that.
# identical serialzation is great, *if* it's complete :)
f.bar = "foo"
assert state.dict == {'bar':'foo', state.manager.STATE_ATTR:state}
del f
- gc.collect()
+ gc_collect()
assert state.obj() is None
assert state.dict == {}
from sqlalchemy.test.testing import assert_raises, assert_raises_message
from sqlalchemy import Integer, String, ForeignKey, Sequence, exc as sa_exc
-from sqlalchemy.test.schema import Table
-from sqlalchemy.test.schema import Column
+from sqlalchemy.test.schema import Table, Column
from sqlalchemy.orm import mapper, relation, create_session, class_mapper, backref
from sqlalchemy.orm import attributes, exc as orm_exc
from sqlalchemy.test import testing
mapper(User, users, properties = dict(
addresses = relation(Address, cascade="all, delete-orphan", backref="user"),
orders = relation(
- mapper(Order, orders), cascade="all, delete-orphan")
+ mapper(Order, orders), cascade="all, delete-orphan", order_by=orders.c.id)
))
mapper(Dingaling,dingalings, properties={
'address':relation(Address)
orders=[Order(description="order 3"),
Order(description="order 4")]))
- eq_(sess.query(Order).all(),
+ eq_(sess.query(Order).order_by(Order.id).all(),
[Order(description="order 3"), Order(description="order 4")])
o5 = Order(description="order 5")
sess.add(o5)
- try:
- sess.flush()
- assert False
- except orm_exc.FlushError, e:
- assert "is an orphan" in str(e)
+ assert_raises_message(orm_exc.FlushError, "is an orphan", sess.flush)
@testing.resolve_artifact_names
@classmethod
def define_tables(cls, metadata):
Table("extra", metadata,
- Column("id", Integer, Sequence("extra_id_seq", optional=True),
- primary_key=True),
+ Column("id", Integer, primary_key=True, test_needs_autoincrement=True),
Column("prefs_id", Integer, ForeignKey("prefs.id")))
Table('prefs', metadata,
- Column('id', Integer, Sequence('prefs_id_seq', optional=True),
- primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(40)))
Table('users', metadata,
- Column('id', Integer, Sequence('user_id_seq', optional=True),
- primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(40)),
Column('pref_id', Integer, ForeignKey('prefs.id')))
jack.pref = newpref
jack.pref = newpref
sess.flush()
- eq_(sess.query(Pref).all(),
+ eq_(sess.query(Pref).order_by(Pref.id).all(),
[Pref(data="pref 1"), Pref(data="pref 3"), Pref(data="newpref")])
class M2OCascadeDeleteTest(_base.MappedTest):
@classmethod
def define_tables(cls, metadata):
Table('t1', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(50)),
Column('t2id', Integer, ForeignKey('t2.id')))
Table('t2', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(50)),
Column('t3id', Integer, ForeignKey('t3.id')))
Table('t3', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(50)))
@classmethod
@classmethod
def define_tables(cls, metadata):
Table('t1', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(50)),
Column('t2id', Integer, ForeignKey('t2.id')))
Table('t2', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(50)),
Column('t3id', Integer, ForeignKey('t3.id')))
Table('t3', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(50)))
@classmethod
@classmethod
def define_tables(cls, metadata):
Table('a', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(30)),
test_needs_fk=True
)
Table('b', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(30)),
test_needs_fk=True
)
Table('c', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(30)),
Column('bid', Integer, ForeignKey('b.id')),
test_needs_fk=True
@classmethod
def define_tables(cls, metadata):
Table('users', metadata,
- Column('user_id', Integer,
- Sequence('user_id_seq', optional=True),
- primary_key=True),
+ Column('user_id', Integer,primary_key=True, test_needs_autoincrement=True),
Column('name', String(40)))
Table('addresses', metadata,
- Column('address_id', Integer,
- Sequence('address_id_seq', optional=True),
- primary_key=True),
+ Column('address_id', Integer,primary_key=True, test_needs_autoincrement=True),
Column('user_id', Integer, ForeignKey('users.user_id')),
Column('email_address', String(40)))
@classmethod
def define_tables(cls, meta):
Table('orders', meta,
- Column('id', Integer, Sequence('order_id_seq'),
- primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(50)))
Table('items', meta,
- Column('id', Integer, Sequence('item_id_seq'),
- primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('order_id', Integer, ForeignKey('orders.id'),
nullable=False),
Column('name', String(50)))
Table('attributes', meta,
- Column('id', Integer, Sequence('attribute_id_seq'),
- primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('item_id', Integer, ForeignKey('items.id'),
nullable=False),
Column('name', String(50)))
@classmethod
def define_tables(cls, meta):
Table('sales_reps', meta,
- Column('sales_rep_id', Integer,
- Sequence('sales_rep_id_seq'),
- primary_key=True),
+ Column('sales_rep_id', Integer,primary_key=True, test_needs_autoincrement=True),
Column('name', String(50)))
Table('accounts', meta,
- Column('account_id', Integer,
- Sequence('account_id_seq'),
- primary_key=True),
+ Column('account_id', Integer,primary_key=True, test_needs_autoincrement=True),
Column('balance', Integer))
Table('customers', meta,
- Column('customer_id', Integer,
- Sequence('customer_id_seq'),
- primary_key=True),
+ Column('customer_id', Integer,primary_key=True, test_needs_autoincrement=True),
Column('name', String(50)),
Column('sales_rep_id', Integer,
ForeignKey('sales_reps.sales_rep_id')),
@classmethod
def define_tables(cls, metadata):
Table('addresses', metadata,
- Column('address_id', Integer, primary_key=True),
+ Column('address_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('street', String(30)),
)
Table('homes', metadata,
- Column('home_id', Integer, primary_key=True, key="id"),
+ Column('home_id', Integer, primary_key=True, key="id", test_needs_autoincrement=True),
Column('description', String(30)),
Column('address_id', Integer, ForeignKey('addresses.address_id'),
nullable=False),
)
Table('businesses', metadata,
- Column('business_id', Integer, primary_key=True, key="id"),
+ Column('business_id', Integer, primary_key=True, key="id", test_needs_autoincrement=True),
Column('description', String(30), key="description"),
Column('address_id', Integer, ForeignKey('addresses.address_id'),
nullable=False),
@classmethod
def define_tables(cls, metadata):
Table('table_a', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(30)))
Table('table_b', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(30)),
Column('a_id', Integer, ForeignKey('table_a.id')))
@classmethod
def define_tables(cls, metadata):
Table("base", metadata,
- Column("id", Integer, primary_key=True),
+ Column("id", Integer, primary_key=True, test_needs_autoincrement=True),
Column("descr", String(50))
)
Table("noninh_child", metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('base_id', Integer, ForeignKey('base.id'))
)
import sqlalchemy as sa
from sqlalchemy.test import testing
from sqlalchemy import Integer, String, ForeignKey
-from sqlalchemy.test.schema import Table
-from sqlalchemy.test.schema import Column
+from sqlalchemy.test.schema import Table, Column
from sqlalchemy import util, exc as sa_exc
-from sqlalchemy.orm import create_session, mapper, relation, attributes
+from sqlalchemy.orm import create_session, mapper, relation, attributes
from test.orm import _base
-from sqlalchemy.test.testing import eq_
+from sqlalchemy.test.testing import eq_, assert_raises
class Canary(sa.orm.interfaces.AttributeExtension):
def __init__(self):
control[slice(0,-1)] = values
assert_eq()
+ values = [creator(),creator(),creator()]
+ control[:] = values
+ direct[:] = values
+ def invalid():
+ direct[slice(0, 6, 2)] = [creator()]
+ assert_raises(ValueError, invalid)
+
if hasattr(direct, '__delitem__'):
e = creator()
direct.append(e)
del direct[::2]
del control[::2]
assert_eq()
-
+
if hasattr(direct, 'remove'):
e = creator()
direct.append(e)
direct.remove(e)
control.remove(e)
assert_eq()
-
- if hasattr(direct, '__setslice__'):
+
+ if hasattr(direct, '__setitem__') or hasattr(direct, '__setslice__'):
+
+ values = [creator(), creator()]
+ direct[:] = values
+ control[:] = values
+ assert_eq()
+
+ # test slice assignment where
+ # slice size goes over the number of items
+ values = [creator(), creator()]
+ direct[1:3] = values
+ control[1:3] = values
+ assert_eq()
+
values = [creator(), creator()]
direct[0:1] = values
control[0:1] = values
direct[1::2] = values
control[1::2] = values
assert_eq()
+
+ values = [creator(), creator()]
+ direct[-1:-3] = values
+ control[-1:-3] = values
+ assert_eq()
- if hasattr(direct, '__delslice__'):
+ values = [creator(), creator()]
+ direct[-2:-1] = values
+ control[-2:-1] = values
+ assert_eq()
+
+
+ if hasattr(direct, '__delitem__') or hasattr(direct, '__delslice__'):
for i in range(1, 4):
e = creator()
direct.append(e)
del direct[:]
del control[:]
assert_eq()
-
+
if hasattr(direct, 'extend'):
values = [creator(), creator(), creator()]
self._test_list(list)
self._test_list_bulk(list)
+ def test_list_setitem_with_slices(self):
+
+ # this is a "list" that has no __setslice__
+ # or __delslice__ methods. The __setitem__
+ # and __delitem__ must therefore accept
+ # slice objects (i.e. as in py3k)
+ class ListLike(object):
+ def __init__(self):
+ self.data = list()
+ def append(self, item):
+ self.data.append(item)
+ def remove(self, item):
+ self.data.remove(item)
+ def insert(self, index, item):
+ self.data.insert(index, item)
+ def pop(self, index=-1):
+ return self.data.pop(index)
+ def extend(self):
+ assert False
+ def __len__(self):
+ return len(self.data)
+ def __setitem__(self, key, value):
+ self.data[key] = value
+ def __getitem__(self, key):
+ return self.data[key]
+ def __delitem__(self, key):
+ del self.data[key]
+ def __iter__(self):
+ return iter(self.data)
+ __hash__ = object.__hash__
+ def __eq__(self, other):
+ return self.data == other
+ def __repr__(self):
+ return 'ListLike(%s)' % repr(self.data)
+
+ self._test_adapter(ListLike)
+ self._test_list(ListLike)
+ self._test_list_bulk(ListLike)
+
def test_list_subclass(self):
class MyList(list):
pass
@classmethod
def define_tables(cls, metadata):
Table('parents', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('label', String(128)))
Table('children', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('parent_id', Integer, ForeignKey('parents.id'),
nullable=False),
Column('a', String(128)),
class Foo(BaseObject):
__tablename__ = "foo"
- id = Column(Integer(), primary_key=True)
+ id = Column(Integer(), primary_key=True, test_needs_autoincrement=True)
bar_id = Column(Integer, ForeignKey('bar.id'))
class Bar(BaseObject):
__tablename__ = "bar"
- id = Column(Integer(), primary_key=True)
+ id = Column(Integer(), primary_key=True, test_needs_autoincrement=True)
foos = relation(Foo, collection_class=collections.column_mapped_collection(Foo.id))
foos2 = relation(Foo, collection_class=collections.column_mapped_collection((Foo.id, Foo.bar_id)))
collection_class = lambda: Ordered2(lambda v: (v.a, v.b))
self._test_composite_mapped(collection_class)
-# TODO: are these tests redundant vs. the above tests ?
-# remove if so
class CustomCollectionsTest(_base.MappedTest):
+ """test the integration of collections with mapped classes."""
@classmethod
def define_tables(cls, metadata):
Table('sometable', metadata,
- Column('col1',Integer, primary_key=True),
+ Column('col1',Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(30)))
Table('someothertable', metadata,
- Column('col1', Integer, primary_key=True),
+ Column('col1', Integer, primary_key=True, test_needs_autoincrement=True),
Column('scol1', Integer,
ForeignKey('sometable.col1')),
Column('data', String(20)))
replaced = set([id(b) for b in f.bars.values()])
self.assert_(existing != replaced)
- @testing.resolve_artifact_names
def test_list(self):
+ self._test_list(list)
+
+ def test_list_no_setslice(self):
+ class ListLike(object):
+ def __init__(self):
+ self.data = list()
+ def append(self, item):
+ self.data.append(item)
+ def remove(self, item):
+ self.data.remove(item)
+ def insert(self, index, item):
+ self.data.insert(index, item)
+ def pop(self, index=-1):
+ return self.data.pop(index)
+ def extend(self):
+ assert False
+ def __len__(self):
+ return len(self.data)
+ def __setitem__(self, key, value):
+ self.data[key] = value
+ def __getitem__(self, key):
+ return self.data[key]
+ def __delitem__(self, key):
+ del self.data[key]
+ def __iter__(self):
+ return iter(self.data)
+ __hash__ = object.__hash__
+ def __eq__(self, other):
+ return self.data == other
+ def __repr__(self):
+ return 'ListLike(%s)' % repr(self.data)
+
+ self._test_list(ListLike)
+
+ @testing.resolve_artifact_names
+ def _test_list(self, listcls):
class Parent(object):
pass
class Child(object):
pass
mapper(Parent, sometable, properties={
- 'children':relation(Child, collection_class=list)
+ 'children':relation(Child, collection_class=listcls)
})
mapper(Child, someothertable)
"""
from sqlalchemy.test import testing
from sqlalchemy import Integer, String, ForeignKey
-from sqlalchemy.test.schema import Table
-from sqlalchemy.test.schema import Column
+from sqlalchemy.test.schema import Table, Column
from sqlalchemy.orm import mapper, relation, backref, create_session
from sqlalchemy.test.testing import eq_
from sqlalchemy.test.assertsql import RegexSQL, ExactSQL, CompiledSQL, AllOf
@classmethod
def define_tables(cls, metadata):
Table('item', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('uuid', String(32), unique=True, nullable=False),
Column('parent_uuid', String(32), ForeignKey('item.uuid'),
nullable=True))
@classmethod
def define_tables(cls, metadata):
Table("parent", metadata,
- Column("id", Integer, primary_key=True),
+ Column("id", Integer, primary_key=True, test_needs_autoincrement=True),
Column("parent_data", String(50)),
Column("type", String(10)))
Table("child1", metadata,
- Column("id", Integer, ForeignKey("parent.id"),
- primary_key=True),
+ Column("id", Integer, ForeignKey("parent.id"), primary_key=True),
Column("child1_data", String(50)))
Table("child2", metadata,
- Column("id", Integer, ForeignKey("parent.id"),
- primary_key=True),
+ Column("id", Integer, ForeignKey("parent.id"), primary_key=True),
Column("child1_id", Integer, ForeignKey("child1.id"),
nullable=False),
Column("child2_data", String(50)))
@classmethod
def define_tables(cls, metadata):
Table('a', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(30)),
Column('cid', Integer, ForeignKey('c.id')))
Column('data', String(30)))
Table('c', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(30)),
Column('aid', Integer,
ForeignKey('a.id', use_alter=True, name="foo")))
@classmethod
def define_tables(cls, metadata):
Table('t1', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(30)),
Column('t2id', Integer, ForeignKey('t2.id')))
Table('t2', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(30)),
Column('t1id', Integer,
ForeignKey('t1.id', use_alter=True, name="foo_fk")))
Table('t3', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(30)),
Column('t1id', Integer, ForeignKey('t1.id'), nullable=False),
Column('t2id', Integer, ForeignKey('t2.id'), nullable=False))
@classmethod
def define_tables(cls, metadata):
Table('t1', metadata,
- Column('c1', Integer, primary_key=True,
- test_needs_autoincrement=True),
+ Column('c1', Integer, primary_key=True, test_needs_autoincrement=True),
Column('c2', Integer, ForeignKey('t2.c1')))
Table('t2', metadata,
- Column('c1', Integer, primary_key=True,
- test_needs_autoincrement=True),
+ Column('c1', Integer, primary_key=True, test_needs_autoincrement=True),
Column('c2', Integer,
ForeignKey('t1.c1', use_alter=True, name='t1c1_fk')))
@classmethod
def define_tables(cls, metadata):
Table('t1', metadata,
- Column('c1', Integer, primary_key=True),
+ Column('c1', Integer, primary_key=True, test_needs_autoincrement=True),
Column('c2', Integer, ForeignKey('t2.c1')),
test_needs_autoincrement=True)
Table('t2', metadata,
- Column('c1', Integer, primary_key=True),
+ Column('c1', Integer, primary_key=True, test_needs_autoincrement=True),
Column('c2', Integer,
ForeignKey('t1.c1', use_alter=True, name='t1c1_fq')),
test_needs_autoincrement=True)
Table('t1_data', metadata,
- Column('c1', Integer, primary_key=True),
+ Column('c1', Integer, primary_key=True, test_needs_autoincrement=True),
Column('t1id', Integer, ForeignKey('t1.c1')),
Column('data', String(20)),
test_needs_autoincrement=True)
@classmethod
def define_tables(cls, metadata):
Table('ball', metadata,
- Column('id', Integer, primary_key=True,
- test_needs_autoincrement=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('person_id', Integer,
ForeignKey('person.id', use_alter=True, name='fk_person_id')),
Column('data', String(30)))
Table('person', metadata,
- Column('id', Integer, primary_key=True,
- test_needs_autoincrement=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('favorite_ball_id', Integer, ForeignKey('ball.id')),
Column('data', String(30)))
@classmethod
def define_tables(cls, metadata):
Table("a_table", metadata,
- Column("id", Integer(), primary_key=True),
+ Column("id", Integer(), primary_key=True, test_needs_autoincrement=True),
Column("fui", String(128)),
Column("b", Integer(), ForeignKey("a_table.id")))
import sqlalchemy as sa
from sqlalchemy.test import testing
from sqlalchemy import Integer, String, ForeignKey
-from sqlalchemy.test.schema import Table
-from sqlalchemy.test.schema import Column
+from sqlalchemy.test.schema import Table, Column
from sqlalchemy.orm import mapper, relation, create_session
from test.orm import _base
from sqlalchemy.test.testing import eq_
@classmethod
def define_tables(cls, metadata):
dt = Table('dt', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('col1', String(20)),
Column('col2', String(20),
server_default=sa.schema.FetchedValue()),
"UPDATE dt SET col2='ins', col4='ins' "
"WHERE dt.id IN (SELECT id FROM inserted);",
on='mssql'),
- ):
- if testing.against(ins.on):
- break
- else:
- ins = sa.DDL("CREATE TRIGGER dt_ins BEFORE INSERT ON dt "
+ sa.DDL("CREATE TRIGGER dt_ins BEFORE INSERT "
+ "ON dt "
+ "FOR EACH ROW "
+ "BEGIN "
+ ":NEW.col2 := 'ins'; :NEW.col4 := 'ins'; END;",
+ on='oracle'),
+ sa.DDL("CREATE TRIGGER dt_ins BEFORE INSERT ON dt "
"FOR EACH ROW BEGIN "
- "SET NEW.col2='ins'; SET NEW.col4='ins'; END")
- ins.execute_at('after-create', dt)
+ "SET NEW.col2='ins'; SET NEW.col4='ins'; END",
+ on=lambda event, schema_item, bind, **kw:
+ bind.engine.name not in ('oracle', 'mssql', 'sqlite')
+ ),
+ ):
+ ins.execute_at('after-create', dt)
+
sa.DDL("DROP TRIGGER dt_ins").execute_at('before-drop', dt)
-
for up in (
sa.DDL("CREATE TRIGGER dt_up AFTER UPDATE ON dt "
"FOR EACH ROW BEGIN "
"UPDATE dt SET col3='up', col4='up' "
"WHERE dt.id IN (SELECT id FROM deleted);",
on='mssql'),
- ):
- if testing.against(up.on):
- break
- else:
- up = sa.DDL("CREATE TRIGGER dt_up BEFORE UPDATE ON dt "
+ sa.DDL("CREATE TRIGGER dt_up BEFORE UPDATE ON dt "
+ "FOR EACH ROW BEGIN "
+ ":NEW.col3 := 'up'; :NEW.col4 := 'up'; END;",
+ on='oracle'),
+ sa.DDL("CREATE TRIGGER dt_up BEFORE UPDATE ON dt "
"FOR EACH ROW BEGIN "
- "SET NEW.col3='up'; SET NEW.col4='up'; END")
- up.execute_at('after-create', dt)
+ "SET NEW.col3='up'; SET NEW.col4='up'; END",
+ on=lambda event, schema_item, bind, **kw:
+ bind.engine.name not in ('oracle', 'mssql', 'sqlite')
+ ),
+ ):
+ up.execute_at('after-create', dt)
+
sa.DDL("DROP TRIGGER dt_up").execute_at('before-drop', dt)
@classmethod
def define_tables(cls, metadata):
dt = Table('dt', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('col1', String(20), default="hello"),
)
from sqlalchemy.orm import dynamic_loader, backref
from sqlalchemy.test import testing
from sqlalchemy import Integer, String, ForeignKey, desc, select, func
-from sqlalchemy.test.schema import Table
-from sqlalchemy.test.schema import Column
+from sqlalchemy.test.schema import Table, Column
from sqlalchemy.orm import mapper, relation, create_session, Query, attributes
from sqlalchemy.orm.dynamic import AppenderMixin
from sqlalchemy.test.testing import eq_
sess.flush()
sess.commit()
u1.addresses.append(Address(email_address='foo@bar.com'))
- eq_(u1.addresses.all(), [Address(email_address='lala@hoho.com'), Address(email_address='foo@bar.com')])
+ eq_(u1.addresses.order_by(Address.id).all(),
+ [Address(email_address='lala@hoho.com'), Address(email_address='foo@bar.com')])
sess.rollback()
eq_(u1.addresses.all(), [Address(email_address='lala@hoho.com')])
@classmethod
def define_tables(cls, metadata):
Table('users', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(40)),
Column('fullname', String(100)),
Column('password', String(15)))
Table('addresses', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('email_address', String(100), nullable=False),
Column('user_id', Integer, ForeignKey('users.id')))
from sqlalchemy.test import testing
from sqlalchemy.orm import eagerload, deferred, undefer
from sqlalchemy import Integer, String, Date, ForeignKey, and_, select, func
-from sqlalchemy.test.schema import Table
-from sqlalchemy.test.schema import Column
+from sqlalchemy.test.schema import Table, Column
from sqlalchemy.orm import mapper, relation, create_session, lazyload, aliased
from sqlalchemy.test.testing import eq_
from sqlalchemy.test.assertsql import CompiledSQL
})
mapper(User, users, properties={
'addresses':relation(mapper(Address, addresses), lazy=False, order_by=addresses.c.id),
- 'orders':relation(Order, lazy=True)
+ 'orders':relation(Order, lazy=True, order_by=orders.c.id)
})
sess = create_session()
q = sess.query(User)
- if testing.against('mysql'):
- l = q.limit(2).all()
- assert self.static.user_all_result[:2] == l
- else:
- l = q.order_by(User.id).limit(2).offset(1).all()
- print self.static.user_all_result[1:3]
- print l
- assert self.static.user_all_result[1:3] == l
+ l = q.order_by(User.id).limit(2).offset(1).all()
+ eq_(self.static.user_all_result[1:3], l)
@testing.resolve_artifact_names
def test_distinct(self):
s = sa.union_all(u2.select(use_labels=True), u2.select(use_labels=True), u2.select(use_labels=True)).alias('u')
mapper(User, users, properties={
- 'addresses':relation(mapper(Address, addresses), lazy=False),
+ 'addresses':relation(mapper(Address, addresses), lazy=False, order_by=addresses.c.id),
})
sess = create_session()
q = sess.query(User)
def go():
- l = q.filter(s.c.u2_id==User.id).distinct().all()
- assert self.static.user_address_result == l
+ l = q.filter(s.c.u2_id==User.id).distinct().order_by(User.id).all()
+ eq_(self.static.user_address_result, l)
self.assert_sql_count(testing.db, go, 1)
@testing.fails_on('maxdb', 'FIXME: unknown')
mapper(Order, orders)
mapper(User, users, properties={
- 'orders':relation(Order, backref='user', lazy=False),
- 'max_order':relation(mapper(Order, max_orders, non_primary=True), lazy=False, uselist=False)
+ 'orders':relation(Order, backref='user', lazy=False, order_by=orders.c.id),
+ 'max_order':relation(
+ mapper(Order, max_orders, non_primary=True),
+ lazy=False, uselist=False)
})
+
q = create_session().query(User)
def go():
max_order=Order(id=4)
),
User(id=10),
- ] == q.all()
+ ] == q.order_by(User.id).all()
self.assert_sql_count(testing.db, go, 1)
@testing.resolve_artifact_names
@classmethod
def define_tables(cls, metadata):
Table('m2m', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('aid', Integer, ForeignKey('a.id')),
Column('bid', Integer, ForeignKey('b.id')))
Table('a', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(50)))
Table('b', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(50)))
@classmethod
@classmethod
def define_tables(cls, metadata):
Table('nodes', metadata,
- Column('id', Integer, sa.Sequence('node_id_seq', optional=True),
- primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('parent_id', Integer, ForeignKey('nodes.id')),
Column('data', String(30)))
@classmethod
def define_tables(cls, metadata):
Table('a_table', metadata,
- Column('id', Integer, primary_key=True)
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True)
)
Table('b_table', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('parent_b1_id', Integer, ForeignKey('b_table.id')),
Column('parent_a_id', Integer, ForeignKey('a_table.id')),
Column('parent_b2_id', Integer, ForeignKey('b_table.id')))
@classmethod
def define_tables(cls, metadata):
Table('widget', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', sa.Unicode(40), nullable=False, unique=True),
)
)
self.assert_sql_count(testing.db, go, 1)
- @testing.exclude('sqlite', '>', (0, 0, 0), "sqlite flat out blows it on the multiple JOINs")
+ @testing.exclude('sqlite', '>', (0, ), "sqlite flat out blows it on the multiple JOINs")
@testing.resolve_artifact_names
def test_two_entities_with_joins(self):
sess = create_session()
@classmethod
def define_tables(cls, metadata):
Table('t1', metadata,
- Column('c1', Integer, primary_key=True),
+ Column('c1', Integer, primary_key=True, test_needs_autoincrement=True),
Column('c2', String(30)),
Column('type', String(30))
)
Table('t2', metadata,
- Column('c1', Integer, primary_key=True),
+ Column('c1', Integer, primary_key=True, test_needs_autoincrement=True),
Column('c2', String(30)),
Column('type', String(30)),
Column('t1.id', Integer, ForeignKey('t1.c1')))
@classmethod
def define_tables(cls, metadata):
Table('users_table', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(16))
)
Table('tags_table', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('user_id', Integer, ForeignKey("users_table.id")),
Column('score1', sa.Float),
Column('score2', sa.Float),
Exercises a variety of ways to configure this.
"""
+
+ # another argument for eagerload learning about inner joins
+
+ __requires__ = ('correlated_outer_joins', )
@classmethod
def define_tables(cls, metadata):
users = Table('users', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(50))
)
stuff = Table('stuff', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('date', Date),
Column('user_id', Integer, ForeignKey('users.id')))
if ondate:
# the more 'relational' way to do this, join on the max date
- stuff_view = select([func.max(salias.c.date).label('max_date')]).where(salias.c.user_id==users.c.id).correlate(users)
+ stuff_view = select([func.max(salias.c.date).label('max_date')]).\
+ where(salias.c.user_id==users.c.id).correlate(users)
else:
# a common method with the MySQL crowd, which actually might perform better in some
# cases - subquery does a limit with order by DESC, join on the id
- stuff_view = select([salias.c.id]).where(salias.c.user_id==users.c.id).correlate(users).order_by(salias.c.date.desc()).limit(1)
+ stuff_view = select([salias.c.id]).where(salias.c.user_id==users.c.id).\
+ correlate(users).order_by(salias.c.date.desc()).limit(1)
if labeled == 'label':
stuff_view = stuff_view.label('foo')
"""Attribute/instance expiration, deferral of attributes, etc."""
from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
-import gc
+from sqlalchemy.test.util import gc_collect
import sqlalchemy as sa
from sqlalchemy.test import testing
from sqlalchemy import Integer, String, ForeignKey, exc as sa_exc
assert self.static.user_address_result == userlist
assert len(list(sess)) == 9
sess.expire_all()
- gc.collect()
+ gc_collect()
assert len(list(sess)) == 4 # since addresses were gc'ed
userlist = sess.query(User).order_by(User.id).all()
assert sess.query(func.min(foo.c.bar)).filter(foo.c.bar<30).one() == (0,)
assert sess.query(func.max(foo.c.bar)).filter(foo.c.bar<30).one() == (29,)
+ # Py3K
+ #assert query.filter(foo.c.bar<30).values(sa.func.max(foo.c.bar)).__next__()[0] == 29
+ #assert query.filter(foo.c.bar<30).values(sa.func.max(foo.c.bar)).__next__()[0] == 29
+ # Py2K
assert query.filter(foo.c.bar<30).values(sa.func.max(foo.c.bar)).next()[0] == 29
assert query.filter(foo.c.bar<30).values(sa.func.max(foo.c.bar)).next()[0] == 29
-
+ # end Py2K
+
@testing.resolve_artifact_names
def test_aggregate_1(self):
- if (testing.against('mysql') and
+ if (testing.against('mysql') and not testing.against('+zxjdbc') and
testing.db.dialect.dbapi.version_info[:4] == (1, 2, 1, 'gamma')):
return
def test_aggregate_3(self):
query = create_session().query(Foo)
+ # Py3K
+ #avg_f = query.filter(foo.c.bar<30).values(sa.func.avg(foo.c.bar)).__next__()[0]
+ # Py2K
avg_f = query.filter(foo.c.bar<30).values(sa.func.avg(foo.c.bar)).next()[0]
+ # end Py2K
assert round(avg_f, 1) == 14.5
+ # Py3K
+ #avg_o = query.filter(foo.c.bar<30).values(sa.func.avg(foo.c.bar)).__next__()[0]
+ # Py2K
avg_o = query.filter(foo.c.bar<30).values(sa.func.avg(foo.c.bar)).next()[0]
+ # end Py2K
assert round(avg_o, 1) == 14.5
@testing.resolve_artifact_names
class B(A):
__sa_instrumentation_manager__ = staticmethod(mgr_factory)
- assert_raises(TypeError, attributes.register_class, B)
+ assert_raises_message(TypeError, "multiple instrumentation implementations", attributes.register_class, B)
def test_single_up(self):
class B(A):
__sa_instrumentation_manager__ = staticmethod(mgr_factory)
attributes.register_class(B)
- assert_raises(TypeError, attributes.register_class, A)
+
+ assert_raises_message(TypeError, "multiple instrumentation implementations", attributes.register_class, A)
def test_diamond_b1(self):
mgr_factory = lambda cls: attributes.ClassManager(cls)
class A(object): pass
class B1(A): pass
class B2(A):
- __sa_instrumentation_manager__ = mgr_factory
+ __sa_instrumentation_manager__ = staticmethod(mgr_factory)
class C(object): pass
- assert_raises(TypeError, attributes.register_class, B1)
+ assert_raises_message(TypeError, "multiple instrumentation implementations", attributes.register_class, B1)
def test_diamond_b2(self):
mgr_factory = lambda cls: attributes.ClassManager(cls)
class A(object): pass
class B1(A): pass
class B2(A):
- __sa_instrumentation_manager__ = mgr_factory
+ __sa_instrumentation_manager__ = staticmethod(mgr_factory)
class C(object): pass
- assert_raises(TypeError, attributes.register_class, B2)
+ attributes.register_class(B2)
+ assert_raises_message(TypeError, "multiple instrumentation implementations", attributes.register_class, B1)
def test_diamond_c_b(self):
mgr_factory = lambda cls: attributes.ClassManager(cls)
class A(object): pass
class B1(A): pass
class B2(A):
- __sa_instrumentation_manager__ = mgr_factory
+ __sa_instrumentation_manager__ = staticmethod(mgr_factory)
class C(object): pass
attributes.register_class(C)
- assert_raises(TypeError, attributes.register_class, B1)
+ assert_raises_message(TypeError, "multiple instrumentation implementations", attributes.register_class, B1)
class OnLoadTest(_base.ORMTest):
"""Check that Events.on_load is not hit in regular attributes operations."""
# use a union all to get a lot of rows to join against
u2 = users.alias('u2')
s = sa.union_all(u2.select(use_labels=True), u2.select(use_labels=True), u2.select(use_labels=True)).alias('u')
- print [key for key in s.c.keys()]
- l = q.filter(s.c.u2_id==User.id).distinct().all()
- assert self.static.user_all_result == l
+ l = q.filter(s.c.u2_id==User.id).order_by(User.id).distinct().all()
+ eq_(self.static.user_all_result, l)
@testing.resolve_artifact_names
def test_one_to_many_scalar(self):
from sqlalchemy.test.testing import assert_raises, assert_raises_message
import sqlalchemy as sa
from sqlalchemy.test import testing, pickleable
-from sqlalchemy import MetaData, Integer, String, ForeignKey, func
-from sqlalchemy.test.schema import Table
-from sqlalchemy.test.schema import Column
+from sqlalchemy import MetaData, Integer, String, ForeignKey, func, util
+from sqlalchemy.test.schema import Table, Column
from sqlalchemy.engine import default
from sqlalchemy.orm import mapper, relation, backref, create_session, class_mapper, compile_mappers, reconstructor, validates, aliased
from sqlalchemy.orm import defer, deferred, synonym, attributes, column_property, composite, relation, dynamic_loader, comparable_property
@testing.resolve_artifact_names
def test_self_ref_synonym(self):
t = Table('nodes', MetaData(),
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('parent_id', Integer, ForeignKey('nodes.id')))
class Node(object):
@testing.resolve_artifact_names
def test_prop_filters(self):
t = Table('person', MetaData(),
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('type', String(128)),
Column('name', String(128)),
Column('employee_number', Integer),
@testing.resolve_artifact_names
def test_comparable_column(self):
class MyComparator(sa.orm.properties.ColumnProperty.Comparator):
+ __hash__ = None
def __eq__(self, other):
# lower case comparison
return func.lower(self.__clause_element__()) == func.lower(other)
@testing.resolve_artifact_names
def test_group(self):
"""Deferred load with a group"""
- mapper(Order, orders, properties={
- 'userident': deferred(orders.c.user_id, group='primary'),
- 'addrident': deferred(orders.c.address_id, group='primary'),
- 'description': deferred(orders.c.description, group='primary'),
- 'opened': deferred(orders.c.isopen, group='primary')
- })
+ mapper(Order, orders, properties=util.OrderedDict([
+ ('userident', deferred(orders.c.user_id, group='primary')),
+ ('addrident', deferred(orders.c.address_id, group='primary')),
+ ('description', deferred(orders.c.description, group='primary')),
+ ('opened', deferred(orders.c.isopen, group='primary'))
+ ]))
sess = create_session()
q = sess.query(Order).order_by(Order.id)
@testing.resolve_artifact_names
def test_undefer_group(self):
- mapper(Order, orders, properties={
- 'userident':deferred(orders.c.user_id, group='primary'),
- 'description':deferred(orders.c.description, group='primary'),
- 'opened':deferred(orders.c.isopen, group='primary')})
+ mapper(Order, orders, properties=util.OrderedDict([
+ ('userident',deferred(orders.c.user_id, group='primary')),
+ ('description',deferred(orders.c.description, group='primary')),
+ ('opened',deferred(orders.c.isopen, group='primary'))
+ ]
+ ))
sess = create_session()
q = sess.query(Order).order_by(Order.id)
@classmethod
def define_tables(cls, metadata):
Table("thing", metadata,
- Column("id", Integer, primary_key=True),
+ Column("id", Integer, primary_key=True, test_needs_autoincrement=True),
Column("name", String(20)))
Table("human", metadata,
- Column("id", Integer, primary_key=True),
+ Column("id", Integer, primary_key=True, test_needs_autoincrement=True),
Column("thing_id", Integer, ForeignKey("thing.id")),
Column("name", String(20)))
@classmethod
def define_tables(cls, metadata):
Table('graphs', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('version_id', Integer, primary_key=True, nullable=True),
Column('name', String(30)))
Table('edges', metadata,
- Column('id', Integer, primary_key=True,
- test_needs_autoincrement=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('graph_id', Integer, nullable=False),
Column('graph_version_id', Integer, nullable=False),
Column('x1', Integer),
['graphs.id', 'graphs.version_id']))
Table('foobars', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('x1', Integer, default=2),
Column('x2', Integer),
Column('x3', Integer, default=15),
# test pk with one column NULL
# TODO: can't seem to get NULL in for a PK value
- # in either mysql or postgres, autoincrement=False etc.
+ # in either mysql or postgresql, autoincrement=False etc.
# notwithstanding
@testing.fails_on_everything_except("sqlite")
def go():
@classmethod
def define_tables(cls, metadata):
Table('ht1', metadata,
- Column('id', Integer, primary_key=True,
- test_needs_autoincrement=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('value', String(10)))
Table('ht2', metadata,
- Column('id', Integer, primary_key=True,
- test_needs_autoincrement=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('ht1_id', Integer, ForeignKey('ht1.id')),
Column('value', String(10)))
Table('ht3', metadata,
- Column('id', Integer, primary_key=True,
- test_needs_autoincrement=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('value', String(10)))
Table('ht4', metadata,
- Column('ht1_id', Integer, ForeignKey('ht1.id'),
- primary_key=True),
- Column('ht3_id', Integer, ForeignKey('ht3.id'),
- primary_key=True))
+ Column('ht1_id', Integer, ForeignKey('ht1.id'), primary_key=True),
+ Column('ht3_id', Integer, ForeignKey('ht3.id'), primary_key=True))
Table('ht5', metadata,
- Column('ht1_id', Integer, ForeignKey('ht1.id'),
- primary_key=True))
+ Column('ht1_id', Integer, ForeignKey('ht1.id'), primary_key=True))
Table('ht6', metadata,
- Column('ht1a_id', Integer, ForeignKey('ht1.id'),
- primary_key=True),
- Column('ht1b_id', Integer, ForeignKey('ht1.id'),
- primary_key=True),
+ Column('ht1a_id', Integer, ForeignKey('ht1.id'), primary_key=True),
+ Column('ht1b_id', Integer, ForeignKey('ht1.id'), primary_key=True),
Column('value', String(10)))
+ # Py2K
@testing.resolve_artifact_names
def test_baseclass(self):
class OldStyle:
# TODO: is weakref support detectable without an instance?
#self.assertRaises(sa.exc.ArgumentError, mapper, NoWeakrefSupport, t2)
-
+ # end Py2K
+
@testing.resolve_artifact_names
def test_comparison_overrides(self):
"""Simple tests to ensure users can supply comparison __methods__.
@classmethod
def define_tables(cls, metadata):
Table('cartographers', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(50)),
Column('alias', String(50)),
Column('quip', String(100)))
Table('maps', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('cart_id', Integer,
ForeignKey('cartographers.id')),
Column('state', String(2)),
for reserved in (sa.orm.attributes.ClassManager.STATE_ATTR,
sa.orm.attributes.ClassManager.MANAGER_ATTR):
t = Table('t', sa.MetaData(),
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column(reserved, Integer))
class T(object):
pass
from sqlalchemy.test.testing import assert_raises, assert_raises_message
import sqlalchemy as sa
-from sqlalchemy import Table, Column, Integer, PickleType
+from sqlalchemy import Integer, PickleType
import operator
from sqlalchemy.test import testing
from sqlalchemy.util import OrderedSet
from sqlalchemy.orm import mapper, relation, create_session, PropComparator, synonym, comparable_property, sessionmaker
from sqlalchemy.test.testing import eq_, ne_
from test.orm import _base, _fixtures
-
+from sqlalchemy.test.schema import Table, Column
class MergeTest(_fixtures.FixtureTest):
"""Session.merge() functionality"""
'addresses':relation(Address,
backref='user',
collection_class=OrderedSet,
+ order_by=addresses.c.id,
cascade="all, delete-orphan")
})
mapper(Address, addresses)
mapper(User, users, properties={
'addresses':relation(Address,
backref='user',
+ order_by=addresses.c.id,
collection_class=OrderedSet)})
mapper(Address, addresses)
on_load = self.on_load_tracker(User)
# test with "dontload" merge
sess5 = create_session()
- u = sess5.merge(u, dont_load=True)
+ u = sess5.merge(u, load=False)
assert len(u.addresses)
for a in u.addresses:
assert a.user is u
def go():
sess5.flush()
# no changes; therefore flush should do nothing
- # but also, dont_load wipes out any difference in committed state,
+ # but also, load=False wipes out any difference in committed state,
# so no flush at all
self.assert_sql_count(testing.db, go, 0)
eq_(on_load.called, 15)
sess4 = create_session()
- u = sess4.merge(u, dont_load=True)
+ u = sess4.merge(u, load=False)
# post merge change
u.addresses[1].email_address='afafds'
def go():
assert u3 is u
@testing.resolve_artifact_names
- def test_transient_dontload(self):
+ def test_transient_no_load(self):
mapper(User, users)
sess = create_session()
u = User()
- assert_raises_message(sa.exc.InvalidRequestError, "dont_load=True option does not support", sess.merge, u, dont_load=True)
+ assert_raises_message(sa.exc.InvalidRequestError, "load=False option does not support", sess.merge, u, load=False)
+ @testing.resolve_artifact_names
+ def test_dont_load_deprecated(self):
+ mapper(User, users)
+
+ sess = create_session()
+ u = User(name='ed')
+ sess.add(u)
+ sess.flush()
+ u = sess.query(User).first()
+ sess.expunge(u)
+ sess.execute(users.update().values(name='jack'))
+ @testing.uses_deprecated("dont_load=True has been renamed")
+ def go():
+ u1 = sess.merge(u, dont_load=True)
+ assert u1 in sess
+ assert u1.name=='ed'
+ assert u1 not in sess.dirty
+ go()
@testing.resolve_artifact_names
- def test_dontload_with_backrefs(self):
- """dontload populates relations in both directions without requiring a load"""
+ def test_no_load_with_backrefs(self):
+ """load=False populates relations in both directions without requiring a load"""
mapper(User, users, properties={
'addresses':relation(mapper(Address, addresses), backref='user')
})
assert 'user' in u.addresses[1].__dict__
sess = create_session()
- u2 = sess.merge(u, dont_load=True)
+ u2 = sess.merge(u, load=False)
assert 'user' in u2.addresses[1].__dict__
eq_(u2.addresses[1].user, User(id=7, name='fred'))
sess.close()
sess = create_session()
- u = sess.merge(u2, dont_load=True)
+ u = sess.merge(u2, load=False)
assert 'user' not in u.addresses[1].__dict__
eq_(u.addresses[1].user, User(id=7, name='fred'))
def test_dontload_with_eager(self):
"""
- This test illustrates that with dont_load=True, we can't just copy the
+ This test illustrates that with load=False, we can't just copy the
committed_state of the merged instance over; since it references
collection objects which themselves are to be merged. This
committed_state would instead need to be piecemeal 'converted' to
represent the correct objects. However, at the moment I'd rather not
- support this use case; if you are merging with dont_load=True, you're
+ support this use case; if you are merging with load=False, you're
typically dealing with caching and the merged objects shouldnt be
'dirty'.
u2 = sess2.query(User).options(sa.orm.eagerload('addresses')).get(7)
sess3 = create_session()
- u3 = sess3.merge(u2, dont_load=True)
+ u3 = sess3.merge(u2, load=False)
def go():
sess3.flush()
self.assert_sql_count(testing.db, go, 0)
@testing.resolve_artifact_names
- def test_dont_load_disallows_dirty(self):
- """dont_load doesnt support 'dirty' objects right now
+ def test_no_load_disallows_dirty(self):
+ """load=False doesnt support 'dirty' objects right now
- (see test_dont_load_with_eager()). Therefore lets assert it.
+ (see test_no_load_with_eager()). Therefore lets assert it.
"""
mapper(User, users)
u.name = 'ed'
sess2 = create_session()
try:
- sess2.merge(u, dont_load=True)
+ sess2.merge(u, load=False)
assert False
except sa.exc.InvalidRequestError, e:
- assert ("merge() with dont_load=True option does not support "
+ assert ("merge() with load=False option does not support "
"objects marked as 'dirty'. flush() all changes on mapped "
- "instances before merging with dont_load=True.") in str(e)
+ "instances before merging with load=False.") in str(e)
u2 = sess2.query(User).get(7)
sess3 = create_session()
- u3 = sess3.merge(u2, dont_load=True)
+ u3 = sess3.merge(u2, load=False)
assert not sess3.dirty
def go():
sess3.flush()
@testing.resolve_artifact_names
- def test_dont_load_sets_backrefs(self):
+ def test_no_load_sets_backrefs(self):
mapper(User, users, properties={
'addresses':relation(mapper(Address, addresses),backref='user')})
assert u.addresses[0].user is u
sess2 = create_session()
- u2 = sess2.merge(u, dont_load=True)
+ u2 = sess2.merge(u, load=False)
assert not sess2.dirty
def go():
assert u2.addresses[0].user is u2
self.assert_sql_count(testing.db, go, 0)
@testing.resolve_artifact_names
- def test_dont_load_preserves_parents(self):
- """Merge with dont_load does not trigger a 'delete-orphan' operation.
+ def test_no_load_preserves_parents(self):
+ """Merge with load=False does not trigger a 'delete-orphan' operation.
- merge with dont_load sets attributes without using events. this means
+ merge with load=False sets attributes without using events. this means
the 'hasparent' flag is not propagated to the newly merged instance.
in fact this works out OK, because the '_state.parents' collection on
the newly merged instance is empty; since the mapper doesn't see an
assert u.addresses[0].user is u
sess2 = create_session()
- u2 = sess2.merge(u, dont_load=True)
+ u2 = sess2.merge(u, load=False)
assert not sess2.dirty
a2 = u2.addresses[0]
a2.email_address='somenewaddress'
# this use case is not supported; this is with a pending Address on
# the pre-merged object, and we currently dont support 'dirty' objects
- # being merged with dont_load=True. in this case, the empty
+ # being merged with load=False. in this case, the empty
# '_state.parents' collection would be an issue, since the optimistic
# flag is False in _is_orphan() for pending instances. so if we start
- # supporting 'dirty' with dont_load=True, this test will need to pass
+ # supporting 'dirty' with load=False, this test will need to pass
sess = create_session()
u = sess.query(User).get(7)
u.addresses.append(Address())
sess2 = create_session()
try:
- u2 = sess2.merge(u, dont_load=True)
+ u2 = sess2.merge(u, load=False)
assert False
- # if dont_load is changed to support dirty objects, this code
+ # if load=False is changed to support dirty objects, this code
# needs to pass
a2 = u2.addresses[0]
a2.email_address='somenewaddress'
eq_(sess2.query(User).get(u2.id).addresses[0].email_address,
'somenewaddress')
except sa.exc.InvalidRequestError, e:
- assert "dont_load=True option does not support" in str(e)
+ assert "load=False option does not support" in str(e)
@testing.resolve_artifact_names
def test_synonym_comparable(self):
@classmethod
def define_tables(cls, metadata):
Table("data", metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', PickleType(comparator=operator.eq))
)
import sqlalchemy as sa
from sqlalchemy.test import testing
from sqlalchemy import Integer, String, ForeignKey
-from sqlalchemy.test.schema import Table
-from sqlalchemy.test.schema import Column
+from sqlalchemy.test.schema import Table, Column
from sqlalchemy.orm import mapper, relation, create_session
from sqlalchemy.test.testing import eq_
from test.orm import _base
@classmethod
def define_tables(cls, metadata):
+ if testing.against('oracle'):
+ fk_args = dict(deferrable=True, initially='deferred')
+ else:
+ fk_args = dict(onupdate='cascade')
+
users = Table('users', metadata,
Column('username', String(50), primary_key=True),
Column('fullname', String(100)),
addresses = Table('addresses', metadata,
Column('email', String(50), primary_key=True),
- Column('username', String(50), ForeignKey('users.username', onupdate="cascade")),
+ Column('username', String(50), ForeignKey('users.username', **fk_args)),
test_needs_fk=True)
items = Table('items', metadata,
test_needs_fk=True)
users_to_items = Table('users_to_items', metadata,
- Column('username', String(50), ForeignKey('users.username', onupdate='cascade'), primary_key=True),
- Column('itemname', String(50), ForeignKey('items.itemname', onupdate='cascade'), primary_key=True),
+ Column('username', String(50), ForeignKey('users.username', **fk_args), primary_key=True),
+ Column('itemname', String(50), ForeignKey('items.itemname', **fk_args), primary_key=True),
test_needs_fk=True)
@classmethod
@testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE')
+ @testing.fails_on('oracle', 'oracle doesnt support ON UPDATE CASCADE')
def test_onetomany_passive(self):
self._test_onetomany(True)
@testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE')
+ @testing.fails_on('oracle', 'oracle doesnt support ON UPDATE CASCADE')
def test_manytoone_passive(self):
self._test_manytoone(True)
eq_([Address(username='ed'), Address(username='ed')], sess.query(Address).all())
@testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE')
+ @testing.fails_on('oracle', 'oracle doesnt support ON UPDATE CASCADE')
def test_onetoone_passive(self):
self._test_onetoone(True)
eq_([Address(username='ed')], sess.query(Address).all())
@testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE')
+ @testing.fails_on('oracle', 'oracle doesnt support ON UPDATE CASCADE')
def test_bidirectional_passive(self):
self._test_bidirectional(True)
@testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE')
+ @testing.fails_on('oracle', 'oracle doesnt support ON UPDATE CASCADE')
def test_manytomany_passive(self):
self._test_manytomany(True)
- @testing.fails_on('mysql', 'the executemany() of the association table fails to report the correct row count')
+ # mysqldb executemany() of the association table fails to report the correct row count
+ @testing.fails_if(lambda: testing.against('mysql') and not testing.against('+zxjdbc'))
def test_manytomany_nonpassive(self):
self._test_manytomany(False)
@classmethod
def define_tables(cls, metadata):
+ if testing.against('oracle'):
+ fk_args = dict(deferrable=True, initially='deferred')
+ else:
+ fk_args = dict(onupdate='cascade')
+
Table('nodes', metadata,
Column('name', String(50), primary_key=True),
Column('parent', String(50),
- ForeignKey('nodes.name', onupdate='cascade')))
+ ForeignKey('nodes.name', **fk_args)))
@classmethod
def setup_classes(cls):
class NonPKCascadeTest(_base.MappedTest):
@classmethod
def define_tables(cls, metadata):
+ if testing.against('oracle'):
+ fk_args = dict(deferrable=True, initially='deferred')
+ else:
+ fk_args = dict(onupdate='cascade')
+
Table('users', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('username', String(50), unique=True),
Column('fullname', String(100)),
test_needs_fk=True)
Table('addresses', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('email', String(50)),
Column('username', String(50),
- ForeignKey('users.username', onupdate="cascade")),
+ ForeignKey('users.username', **fk_args)),
test_needs_fk=True
)
pass
@testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE')
+ @testing.fails_on('oracle', 'oracle doesnt support ON UPDATE CASCADE')
def test_onetomany_passive(self):
self._test_onetomany(True)
import sqlalchemy as sa
from sqlalchemy.test import testing
from sqlalchemy import Integer, String, ForeignKey
-from sqlalchemy.test.schema import Table
-from sqlalchemy.test.schema import Column
+from sqlalchemy.test.schema import Table, Column
from sqlalchemy.orm import mapper, relation, create_session
from test.orm import _base
@classmethod
def define_tables(cls, metadata):
Table('jack', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('number', String(50)),
Column('status', String(20)),
Column('subroom', String(5)))
Table('port', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(30)),
Column('description', String(100)),
Column('jack_id', Integer, ForeignKey("jack.id")))
import sqlalchemy as sa
from sqlalchemy.test import testing
from sqlalchemy import Integer, String, ForeignKey
-from sqlalchemy.test.schema import Table
-from sqlalchemy.test.schema import Column
+from sqlalchemy.test.schema import Table, Column
from sqlalchemy.orm import mapper, relation, create_session, attributes
from test.orm import _base, _fixtures
u2 = pickle.loads(pickle.dumps(u1))
sess2 = create_session()
- u2 = sess2.merge(u2, dont_load=True)
+ u2 = sess2.merge(u2, load=False)
eq_(u2.name, 'ed')
eq_(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')]))
u2 = pickle.loads(pickle.dumps(u1))
sess2 = create_session()
- u2 = sess2.merge(u2, dont_load=True)
+ u2 = sess2.merge(u2, load=False)
eq_(u2.name, 'ed')
assert 'addresses' not in u2.__dict__
ad = u2.addresses[0]
@classmethod
def define_tables(cls, metadata):
Table('users', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(30)),
Column('type', String(30)))
Table('email_users', metadata,
pass
s = users.select(users.c.id!=12).alias('users')
m = mapper(SomeUser, s)
- print s.primary_key
- print m.primary_key
assert s.primary_key == m.primary_key
- row = s.select(use_labels=True).execute().fetchone()
- print row[s.primary_key[0]]
-
sess = create_session()
assert sess.query(SomeUser).get(7).name == 'jack'
@testing.requires.unicode_connections
def test_unicode(self):
"""test that Query.get properly sets up the type for the bind parameter. using unicode would normally fail
- on postgres, mysql and oracle unless it is converted to an encoded string"""
+ on postgresql, mysql and oracle unless it is converted to an encoded string"""
metadata = MetaData(engines.utf8_engine())
table = Table('unicode_data', metadata,
- Column('id', Unicode(40), primary_key=True),
+ Column('id', Unicode(40), primary_key=True, test_needs_autoincrement=True),
Column('data', Unicode(40)))
try:
metadata.create_all()
+ # Py3K
+ #ustring = 'petit voix m\xe2\x80\x99a'
+ # Py2K
ustring = 'petit voix m\xe2\x80\x99a'.decode('utf-8')
+ # end Py2K
+
table.insert().execute(id=ustring, data=ustring)
class LocalFoo(Base):
pass
assert u.addresses[0].email_address == 'jack@bean.com'
assert u.orders[1].items[2].description == 'item 5'
- @testing.fails_on_everything_except('sqlite', 'mssql')
+ @testing.fails_on_everything_except('sqlite', '+pyodbc', '+zxjdbc')
def test_query_str(self):
s = create_session()
q = s.query(User).filter(User.id==1)
def test_arithmetic(self):
create_session().query(User)
for (py_op, sql_op) in ((operator.add, '+'), (operator.mul, '*'),
- (operator.sub, '-'), (operator.div, '/'),
+ (operator.sub, '-'),
+ # Py3k
+ #(operator.truediv, '/'),
+ # Py2K
+ (operator.div, '/'),
+ # end Py2K
):
for (lhs, rhs, res) in (
(5, User.id, ':id_1 %s users.id'),
sess = create_session()
self.assert_compile(sess.query(users).select_from(users.select()).with_labels().statement,
- "SELECT users.id AS users_id, users.name AS users_name FROM users, (SELECT users.id AS id, users.name AS name FROM users) AS anon_1")
+ "SELECT users.id AS users_id, users.name AS users_name FROM users, "
+ "(SELECT users.id AS id, users.name AS name FROM users) AS anon_1",
+ dialect=default.DefaultDialect()
+ )
self.assert_compile(sess.query(users, exists([1], from_obj=addresses)).with_labels().statement,
- "SELECT users.id AS users_id, users.name AS users_name, EXISTS (SELECT 1 FROM addresses) AS anon_1 FROM users")
+ "SELECT users.id AS users_id, users.name AS users_name, EXISTS "
+ "(SELECT 1 FROM addresses) AS anon_1 FROM users",
+ dialect=default.DefaultDialect()
+ )
# a little tedious here, adding labels to work around Query's auto-labelling.
# also correlate needed explicitly. hmmm.....
sess = create_session()
assert [Address(id=5)] == sess.query(Address).filter(Address.user.has(name='fred')).all()
- assert [Address(id=2), Address(id=3), Address(id=4), Address(id=5)] == sess.query(Address).filter(Address.user.has(User.name.like('%ed%'))).all()
+ assert [Address(id=2), Address(id=3), Address(id=4), Address(id=5)] == \
+ sess.query(Address).filter(Address.user.has(User.name.like('%ed%'))).order_by(Address.id).all()
- assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).filter(Address.user.has(User.name.like('%ed%'), id=8)).all()
+ assert [Address(id=2), Address(id=3), Address(id=4)] == \
+ sess.query(Address).filter(Address.user.has(User.name.like('%ed%'), id=8)).order_by(Address.id).all()
# test has() doesn't overcorrelate
- assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).join("user").filter(Address.user.has(User.name.like('%ed%'), id=8)).all()
+ assert [Address(id=2), Address(id=3), Address(id=4)] == \
+ sess.query(Address).join("user").filter(Address.user.has(User.name.like('%ed%'), id=8)).order_by(Address.id).all()
# test has() doesnt' get subquery contents adapted by aliased join
- assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).join("user", aliased=True).filter(Address.user.has(User.name.like('%ed%'), id=8)).all()
+ assert [Address(id=2), Address(id=3), Address(id=4)] == \
+ sess.query(Address).join("user", aliased=True).filter(Address.user.has(User.name.like('%ed%'), id=8)).order_by(Address.id).all()
dingaling = sess.query(Dingaling).get(2)
assert [User(id=9)] == sess.query(User).filter(User.addresses.any(Address.dingaling==dingaling)).all()
assert [Address(id=5)] == sess.query(Address).filter(Address.dingaling==dingaling).all()
# m2m
- eq_(sess.query(Item).filter(Item.keywords==None).all(), [Item(id=4), Item(id=5)])
- eq_(sess.query(Item).filter(Item.keywords!=None).all(), [Item(id=1),Item(id=2), Item(id=3)])
+ eq_(sess.query(Item).filter(Item.keywords==None).order_by(Item.id).all(), [Item(id=4), Item(id=5)])
+ eq_(sess.query(Item).filter(Item.keywords!=None).order_by(Item.id).all(), [Item(id=1),Item(id=2), Item(id=3)])
def test_filter_by(self):
sess = create_session()
sess = create_session()
# o2o
- eq_([Address(id=1), Address(id=3), Address(id=4)], sess.query(Address).filter(Address.dingaling==None).all())
- eq_([Address(id=2), Address(id=5)], sess.query(Address).filter(Address.dingaling != None).all())
+ eq_([Address(id=1), Address(id=3), Address(id=4)],
+ sess.query(Address).filter(Address.dingaling==None).order_by(Address.id).all())
+ eq_([Address(id=2), Address(id=5)], sess.query(Address).filter(Address.dingaling != None).order_by(Address.id).all())
# m2o
eq_([Order(id=5)], sess.query(Order).filter(Order.address==None).all())
s = create_session()
+ oracle_as = not testing.against('oracle') and "AS " or ""
+
self.assert_compile(
s.query(User).options(eagerload(User.addresses)).from_self().statement,
"SELECT anon_1.users_id, anon_1.users_name, addresses_1.id, addresses_1.user_id, "\
- "addresses_1.email_address FROM (SELECT users.id AS users_id, users.name AS users_name FROM users) AS anon_1 "\
- "LEFT OUTER JOIN addresses AS addresses_1 ON anon_1.users_id = addresses_1.user_id ORDER BY addresses_1.id"
+ "addresses_1.email_address FROM (SELECT users.id AS users_id, users.name AS users_name FROM users) %(oracle_as)sanon_1 "\
+ "LEFT OUTER JOIN addresses %(oracle_as)saddresses_1 ON anon_1.users_id = addresses_1.user_id ORDER BY addresses_1.id" % {
+ 'oracle_as':oracle_as
+ }
)
def test_aliases(self):
class DistinctTest(QueryTest):
def test_basic(self):
- assert [User(id=7), User(id=8), User(id=9),User(id=10)] == create_session().query(User).distinct().all()
- assert [User(id=7), User(id=9), User(id=8),User(id=10)] == create_session().query(User).distinct().order_by(desc(User.name)).all()
+ eq_(
+ [User(id=7), User(id=8), User(id=9),User(id=10)],
+ create_session().query(User).order_by(User.id).distinct().all()
+ )
+ eq_(
+ [User(id=7), User(id=9), User(id=8),User(id=10)],
+ create_session().query(User).distinct().order_by(desc(User.name)).all()
+ )
def test_joined(self):
"""test that orderbys from a joined table get placed into the columns clause when DISTINCT is used"""
class YieldTest(QueryTest):
def test_basic(self):
- import gc
sess = create_session()
q = iter(sess.query(User).yield_per(1).from_statement("select * from users"))
def define_tables(cls, metadata):
global t1, t2, t1t2_1, t1t2_2
t1 = Table('t1', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(30))
)
t2 = Table('t2', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(30))
)
eq_(list(q2), [(u'jack',), (u'ed',)])
q = sess.query(User)
- q2 = q.order_by(User.id).values(User.name, User.name + " " + cast(User.id, String))
+ q2 = q.order_by(User.id).values(User.name, User.name + " " + cast(User.id, String(50)))
eq_(list(q2), [(u'jack', u'jack 7'), (u'ed', u'ed 8'), (u'fred', u'fred 9'), (u'chuck', u'chuck 10')])
q2 = q.join('addresses').filter(User.name.like('%e%')).order_by(User.id, Address.id).values(User.name, Address.email_address)
eq_(list(q2), [(u'jack', u'jack', u'jack'), (u'jack', u'jack', u'ed'), (u'jack', u'jack', u'fred'), (u'jack', u'jack', u'chuck'), (u'ed', u'ed', u'jack'), (u'ed', u'ed', u'ed'), (u'ed', u'ed', u'fred'), (u'ed', u'ed', u'chuck')])
@testing.fails_on('mssql', 'FIXME: unknown')
+ @testing.fails_on('oracle', "Oracle doesn't support boolean expressions as columns")
+ @testing.fails_on('postgresql+pg8000', "pg8000 parses the SQL itself before passing on to PG, doesn't parse this")
+ @testing.fails_on('postgresql+zxjdbc', "zxjdbc parses the SQL itself before passing on to PG, doesn't parse this")
def test_values_with_boolean_selects(self):
"""Tests a values clause that works with select boolean evaluations"""
sess = create_session()
q2 = q.group_by([User.name.like('%j%')]).order_by(desc(User.name.like('%j%'))).values(User.name.like('%j%'), func.count(User.name.like('%j%')))
eq_(list(q2), [(True, 1), (False, 3)])
+ q2 = q.order_by(desc(User.name.like('%j%'))).values(User.name.like('%j%'))
+ eq_(list(q2), [(True,), (False,), (False,), (False,)])
+
+
def test_correlated_subquery(self):
"""test that a subquery constructed from ORM attributes doesn't leak out
those entities to the outermost query.
sess = create_session()
eq_(sess.query(Node).filter(Node.children.any(Node.data=='n1')).all(), [])
eq_(sess.query(Node).filter(Node.children.any(Node.data=='n12')).all(), [Node(data='n1')])
- eq_(sess.query(Node).filter(~Node.children.any()).all(), [Node(data='n11'), Node(data='n13'),Node(data='n121'),Node(data='n122'),Node(data='n123'),])
+ eq_(sess.query(Node).filter(~Node.children.any()).order_by(Node.id).all(),
+ [Node(data='n11'), Node(data='n13'),Node(data='n121'),Node(data='n122'),Node(data='n123'),])
def test_has(self):
sess = create_session()
- eq_(sess.query(Node).filter(Node.parent.has(Node.data=='n12')).all(), [Node(data='n121'),Node(data='n122'),Node(data='n123')])
+ eq_(sess.query(Node).filter(Node.parent.has(Node.data=='n12')).order_by(Node.id).all(),
+ [Node(data='n121'),Node(data='n122'),Node(data='n123')])
eq_(sess.query(Node).filter(Node.parent.has(Node.data=='n122')).all(), [])
eq_(sess.query(Node).filter(~Node.parent.has()).all(), [Node(data='n1')])
for x in range(2):
sess.expunge_all()
def go():
- eq_(sess.query(Address).options(eagerload('user')).all(), address_result)
+ eq_(sess.query(Address).options(eagerload('user')).order_by(Address.id).all(), address_result)
self.assert_sql_count(testing.db, go, 1)
ualias = aliased(User)
)
ua = aliased(User)
- eq_(sess.query(Address, ua.concat, ua.count).select_from(join(Address, ua, 'user')).options(eagerload(Address.user)).all(),
+ eq_(sess.query(Address, ua.concat, ua.count).
+ select_from(join(Address, ua, 'user')).
+ options(eagerload(Address.user)).order_by(Address.id).all(),
[
(Address(id=1, user=User(id=7, concat=14, count=1)), 14, 1),
(Address(id=2, user=User(id=8, concat=16, count=3)), 16, 3),
def define_tables(cls, metadata):
global base, sub1, sub2
base = Table('base', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(50))
)
@classmethod
def define_tables(cls, metadata):
Table('users', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(32)),
Column('age', Integer))
Table('documents', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('user_id', None, ForeignKey('users.id')),
Column('title', String(32)))
sess = create_session(bind=testing.db, autocommit=False)
john,jack,jill,jane = sess.query(User).order_by(User.id).all()
- sess.query(User).filter('name = :name').params(name='john').delete()
+ sess.query(User).filter('name = :name').params(name='john').delete('fetch')
assert john not in sess
eq_(sess.query(User).order_by(User.id).all(), [jack,jill,jane])
@testing.fails_on('mysql', 'FIXME: unknown')
@testing.resolve_artifact_names
- def test_delete_fallback(self):
+ def test_delete_invalid_evaluation(self):
sess = create_session(bind=testing.db, autocommit=False)
john,jack,jill,jane = sess.query(User).order_by(User.id).all()
- sess.query(User).filter(User.name == select([func.max(User.name)])).delete(synchronize_session='evaluate')
+ assert_raises(sa_exc.InvalidRequestError,
+ sess.query(User).filter(User.name == select([func.max(User.name)])).delete, synchronize_session='evaluate'
+ )
+
+ sess.query(User).filter(User.name == select([func.max(User.name)])).delete(synchronize_session='fetch')
+
assert john not in sess
eq_(sess.query(User).order_by(User.id).all(), [jack,jill,jane])
john,jack,jill,jane = sess.query(User).order_by(User.id).all()
- sess.query(User).filter('age > :x').params(x=29).update({'age': User.age - 10}, synchronize_session='evaluate')
+ sess.query(User).filter('age > :x').params(x=29).update({'age': User.age - 10}, synchronize_session='fetch')
eq_([john.age, jack.age, jill.age, jane.age], [25,37,29,27])
eq_(sess.query(User.age).order_by(User.id).all(), zip([25,37,29,27]))
sess = create_session(bind=testing.db, autocommit=False)
john,jack,jill,jane = sess.query(User).order_by(User.id).all()
- sess.query(User).filter(User.age > 29).update({'age': User.age - 10}, synchronize_session='expire')
+ sess.query(User).filter(User.age > 29).update({'age': User.age - 10}, synchronize_session='fetch')
eq_([john.age, jack.age, jill.age, jane.age], [25,37,29,27])
eq_(sess.query(User.age).order_by(User.id).all(), zip([25,37,29,27]))
+ @testing.fails_if(lambda: not testing.db.dialect.supports_sane_rowcount)
@testing.resolve_artifact_names
def test_update_returns_rowcount(self):
sess = create_session(bind=testing.db, autocommit=False)
rowcount = sess.query(User).filter(User.age > 29).update({'age': User.age - 10})
eq_(rowcount, 2)
+ @testing.fails_if(lambda: not testing.db.dialect.supports_sane_rowcount)
@testing.resolve_artifact_names
def test_delete_returns_rowcount(self):
sess = create_session(bind=testing.db, autocommit=False)
sess = create_session(bind=testing.db, autocommit=False)
foo,bar,baz = sess.query(Document).order_by(Document.id).all()
- sess.query(Document).filter(Document.user_id == 1).update({'title': Document.title+Document.title}, synchronize_session='expire')
+ sess.query(Document).filter(Document.user_id == 1).update({'title': Document.title+Document.title}, synchronize_session='fetch')
eq_([foo.title, bar.title, baz.title], ['foofoo','barbar', 'baz'])
eq_(sess.query(Document.title).order_by(Document.id).all(), zip(['foofoo','barbar', 'baz']))
sess = create_session(bind=testing.db, autocommit=False)
john,jack,jill,jane = sess.query(User).order_by(User.id).all()
- sess.query(User).options(eagerload(User.documents)).filter(User.age > 29).update({'age': User.age - 10}, synchronize_session='expire')
+ sess.query(User).options(eagerload(User.documents)).filter(User.age > 29).update({'age': User.age - 10}, synchronize_session='fetch')
eq_([john.age, jack.age, jill.age, jane.age], [25,37,29,27])
eq_(sess.query(User.age).order_by(User.id).all(), zip([25,37,29,27]))
import sqlalchemy as sa
from sqlalchemy.test import testing
from sqlalchemy import Integer, String, ForeignKey, MetaData, and_
-from sqlalchemy.test.schema import Table
-from sqlalchemy.test.schema import Column
+from sqlalchemy.test.schema import Table, Column
from sqlalchemy.orm import mapper, relation, backref, create_session, compile_mappers, clear_mappers, sessionmaker
from sqlalchemy.test.testing import eq_, startswith_
from test.orm import _base, _fixtures
@classmethod
def define_tables(cls, metadata):
Table("tbl_a", metadata,
- Column("id", Integer, primary_key=True),
+ Column("id", Integer, primary_key=True, test_needs_autoincrement=True),
Column("name", String(128)))
Table("tbl_b", metadata,
- Column("id", Integer, primary_key=True),
+ Column("id", Integer, primary_key=True, test_needs_autoincrement=True),
Column("name", String(128)))
Table("tbl_c", metadata,
- Column("id", Integer, primary_key=True),
+ Column("id", Integer, primary_key=True, test_needs_autoincrement=True),
Column("tbl_a_id", Integer, ForeignKey("tbl_a.id"), nullable=False),
Column("name", String(128)))
Table("tbl_d", metadata,
- Column("id", Integer, primary_key=True),
+ Column("id", Integer, primary_key=True, test_needs_autoincrement=True),
Column("tbl_c_id", Integer, ForeignKey("tbl_c.id"), nullable=False),
Column("tbl_b_id", Integer, ForeignKey("tbl_b.id")),
Column("name", String(128)))
@classmethod
def define_tables(cls, metadata):
Table('company_t', metadata,
- Column('company_id', Integer, primary_key=True),
+ Column('company_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', sa.Unicode(30)))
Table('employee_t', metadata,
@classmethod
def define_tables(cls, metadata):
Table("tableA", metadata,
- Column("id",Integer,primary_key=True),
+ Column("id",Integer,primary_key=True, test_needs_autoincrement=True),
Column("foo",Integer,),
test_needs_fk=True)
Table("tableB",metadata,
@testing.fails_on_everything_except('sqlite', 'mysql')
@testing.resolve_artifact_names
def test_nullPKsOK_BtoA(self):
- # postgres cant handle a nullable PK column...?
+ # postgresql cant handle a nullable PK column...?
tableC = Table('tablec', tableA.metadata,
Column('id', Integer, primary_key=True),
Column('a_id', Integer, ForeignKey('tableA.id'),
@classmethod
def define_tables(cls, metadata):
- Table('tags', metadata, Column("id", Integer, primary_key=True),
+ Table('tags', metadata, Column("id", Integer, primary_key=True, test_needs_autoincrement=True),
Column("data", String(50)),
)
Table('tag_foo', metadata,
- Column("id", Integer, primary_key=True),
+ Column("id", Integer, primary_key=True, test_needs_autoincrement=True),
Column('tagid', Integer),
Column("data", String(50)),
)
@classmethod
def define_tables(cls, metadata):
Table('users', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(50))
)
Table('addresses', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('user_id', Integer),
Column('email', String(50))
)
@classmethod
def define_tables(cls, metadata):
subscriber_table = Table('subscriber', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('dummy', String(10)) # to appease older sqlite version
)
@classmethod
def define_tables(cls, metadata):
Table("a", metadata,
- Column('aid', Integer, primary_key=True),
+ Column('aid', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(30)))
Table("b", metadata,
- Column('bid', Integer, primary_key=True),
+ Column('bid', Integer, primary_key=True, test_needs_autoincrement=True),
Column("a_id", Integer, ForeignKey("a.aid")),
Column('data', String(30)))
Table("c", metadata,
- Column('cid', Integer, primary_key=True),
+ Column('cid', Integer, primary_key=True, test_needs_autoincrement=True),
Column("b_id", Integer, ForeignKey("b.bid")),
Column('data', String(30)))
Table("d", metadata,
- Column('did', Integer, primary_key=True),
+ Column('did', Integer, primary_key=True, test_needs_autoincrement=True),
Column("a_id", Integer, ForeignKey("a.aid")),
Column('data', String(30)))
@classmethod
def define_tables(cls, metadata):
Table("t1", metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(40)))
Table("t2", metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(40)),
Column('t1id', Integer, ForeignKey('t1.id')))
Table("t3", metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(40)),
Column('t2id', Integer, ForeignKey('t2.id')))
@classmethod
def define_tables(cls, metadata):
Table("t1", metadata,
- Column('t1id', Integer, primary_key=True),
+ Column('t1id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(40)))
Table("t2", metadata,
- Column('t2id', Integer, primary_key=True),
+ Column('t2id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(40)),
Column('t1id_ref', Integer, ForeignKey('t1.t1id')))
Table("t3", metadata,
- Column('t3id', Integer, primary_key=True),
+ Column('t3id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(40)),
Column('t2id_ref', Integer, ForeignKey('t2.t2id')))
@classmethod
def define_tables(cls, metadata):
Table('foos', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('bid1', Integer,ForeignKey('bars.id')),
Column('bid2', Integer,ForeignKey('bars.id')))
Table('bars', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(50)))
@testing.resolve_artifact_names
@classmethod
def define_tables(cls, metadata):
Table('foos', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(50)))
- Table('bars', metadata, Column('id', Integer, primary_key=True),
+ Table('bars', metadata, Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('fid1', Integer, ForeignKey('foos.id')),
Column('fid2', Integer, ForeignKey('foos.id')),
Column('data', String(50)))
@classmethod
def define_tables(cls, metadata):
Table('t1', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(50)))
Table('t2', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(50)),
Column('t1id', Integer, ForeignKey('t1.id')))
Table('t3', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(50)))
Table('t2tot3', metadata,
Column('t2id', Integer, ForeignKey('t2.id')),
@classmethod
def define_tables(cls, metadata):
Table('t1', metadata,
- Column('id', String(50), primary_key=True),
+ Column('id', String(50), primary_key=True, test_needs_autoincrement=True),
Column('data', String(50)))
Table('t2', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(50)),
Column('t1id', String(50)))
from sqlalchemy.test import testing
from sqlalchemy.orm import scoped_session
from sqlalchemy import Integer, String, ForeignKey
-from sqlalchemy.test.schema import Table
-from sqlalchemy.test.schema import Column
+from sqlalchemy.test.schema import Table, Column
from sqlalchemy.orm import mapper, relation, query
from sqlalchemy.test.testing import eq_
from test.orm import _base
+
class _ScopedTest(_base.MappedTest):
"""Adds another lookup bucket to emulate Session globals."""
@classmethod
def define_tables(cls, metadata):
Table('table1', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(30)))
Table('table2', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('someid', None, ForeignKey('table1.id')))
@testing.resolve_artifact_names
@classmethod
def define_tables(cls, metadata):
Table('table1', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(30)))
Table('table2', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('someid', None, ForeignKey('table1.id')))
@classmethod
@classmethod
def define_tables(cls, metadata):
Table('table1', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(30)),
Column('type', String(30)))
Table('table2', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('someid', None, ForeignKey('table1.id')),
Column('somedata', String(30)))
import sqlalchemy as sa
from sqlalchemy.test import testing
from sqlalchemy import String, Integer, select
-from sqlalchemy.test.schema import Table
-from sqlalchemy.test.schema import Column
+from sqlalchemy.test.schema import Table, Column
from sqlalchemy.orm import mapper, create_session
from sqlalchemy.test.testing import eq_
from test.orm import _base
@classmethod
def define_tables(cls, metadata):
Table('common', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', Integer),
Column('extra', String(45)))
from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
-import gc
+from sqlalchemy.test.util import gc_collect
import inspect
import pickle
from sqlalchemy.orm import create_session, sessionmaker, attributes
import sqlalchemy as sa
from sqlalchemy.test import engines, testing, config
from sqlalchemy import Integer, String, Sequence
-from sqlalchemy.test.schema import Table
-from sqlalchemy.test.schema import Column
+from sqlalchemy.test.schema import Table, Column
from sqlalchemy.orm import mapper, relation, backref, eagerload
from sqlalchemy.test.testing import eq_
from test.engine import _base as engine_base
u = sess.query(User).get(u.id)
q = sess.query(Address).filter(Address.user==u)
del u
- gc.collect()
+ gc_collect()
eq_(q.one(), Address(email_address='foo'))
session = create_session(bind=testing.db)
session.begin()
- session.connection().execute("insert into users (name) values ('user1')")
+ session.connection().execute(users.insert().values(name='user1'))
session.begin(subtransactions=True)
session.begin_nested()
- session.connection().execute("insert into users (name) values ('user2')")
+ session.connection().execute(users.insert().values(name='user2'))
assert session.connection().execute("select count(1) from users").scalar() == 2
session.rollback()
assert session.connection().execute("select count(1) from users").scalar() == 1
- session.connection().execute("insert into users (name) values ('user3')")
+ session.connection().execute(users.insert().values(name='user3'))
session.commit()
assert session.connection().execute("select count(1) from users").scalar() == 2
user = s.query(User).one()
del user
- gc.collect()
+ gc_collect()
assert len(s.identity_map) == 0
user = s.query(User).one()
user.name = 'fred'
del user
- gc.collect()
+ gc_collect()
assert len(s.identity_map) == 1
assert len(s.dirty) == 1
assert None not in s.dirty
s.flush()
- gc.collect()
+ gc_collect()
assert not s.dirty
assert not s.identity_map
s.add(u2)
del u2
- gc.collect()
+ gc_collect()
assert len(s.identity_map) == 1
assert len(s.dirty) == 1
assert None not in s.dirty
s.flush()
- gc.collect()
+ gc_collect()
assert not s.dirty
assert not s.identity_map
eq_(user, User(name="ed", addresses=[Address(email_address="ed1")]))
del user
- gc.collect()
+ gc_collect()
assert len(s.identity_map) == 0
user = s.query(User).options(eagerload(User.addresses)).one()
user.addresses[0].email_address='ed2'
user.addresses[0].user # lazyload
del user
- gc.collect()
+ gc_collect()
assert len(s.identity_map) == 2
s.commit()
eq_(user, User(name="ed", address=Address(email_address="ed1")))
del user
- gc.collect()
+ gc_collect()
assert len(s.identity_map) == 0
user = s.query(User).options(eagerload(User.address)).one()
user.address.user # lazyload
del user
- gc.collect()
+ gc_collect()
assert len(s.identity_map) == 2
s.commit()
user = s.query(User).one()
user = None
print s.identity_map
- import gc
- gc.collect()
+ gc_collect()
assert len(s.identity_map) == 1
user = s.query(User).one()
s.flush()
eq_(users.select().execute().fetchall(), [(user.id, 'u2')])
-
+ @testing.fails_on('+zxjdbc', 'http://www.sqlalchemy.org/trac/ticket/1473')
@testing.resolve_artifact_names
def test_prune(self):
s = create_session(weak_identity_map=False)
self.assert_(len(s.identity_map) == 0)
self.assert_(s.prune() == 0)
s.flush()
- import gc
- gc.collect()
+ gc_collect()
self.assert_(s.prune() == 9)
self.assert_(len(s.identity_map) == 1)
def define_tables(cls, metadata):
global t1
t1 = Table('t1', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(50))
)
def _map_it(self, cls):
return mapper(cls, Table('t', sa.MetaData(),
- Column('id', Integer, primary_key=True)))
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True)))
@testing.uses_deprecated()
def _test_instance_guards(self, user_arg):
@classmethod
def define_tables(cls, metadata):
Table('users', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(20)),
test_needs_acid=True)
from sqlalchemy.orm import attributes
from sqlalchemy import exc as sa_exc
from sqlalchemy.orm import *
-
+from sqlalchemy.test.util import gc_collect
from sqlalchemy.test import testing
from test.orm import _base
from test.orm._fixtures import FixtureTest, User, Address, users, addresses
-import gc
-
class TransactionTest(FixtureTest):
run_setup_mappers = 'once'
run_inserts = None
def setup_mappers(cls):
mapper(User, users, properties={
'addresses':relation(Address, backref='user',
- cascade="all, delete-orphan"),
+ cascade="all, delete-orphan", order_by=addresses.c.id),
})
mapper(Address, addresses)
assert u1_state not in s.identity_map.all_states()
assert u1_state not in s._deleted
del u1
- gc.collect()
+ gc_collect()
assert u1_state.obj() is None
s.rollback()
"WHERE mutable_t.id = :mutable_t_id",
{'mutable_t_id': f1.id, 'val': u'hi', 'data':f1.data})])
-
@testing.resolve_artifact_names
def test_resurrect(self):
f1 = Foo()
f1.data.y = 19
del f1
-
+
gc.collect()
assert len(session.identity_map) == 1
-
- session.commit()
-
- assert session.query(Foo).one().data == pickleable.Bar(4, 19)
-
-
- @testing.uses_deprecated()
- @testing.resolve_artifact_names
- def test_nocomparison(self):
- """Changes are detected on MutableTypes lacking an __eq__ method."""
- f1 = Foo()
- f1.data = pickleable.BarWithoutCompare(4,5)
- session = create_session(autocommit=False)
- session.add(f1)
session.commit()
- self.sql_count_(0, session.commit)
- session.close()
-
- session = create_session(autocommit=False)
- f2 = session.query(Foo).filter_by(id=f1.id).one()
- self.sql_count_(0, session.commit)
-
- f2.data.y = 19
- self.sql_count_(1, session.commit)
- session.close()
-
- session = create_session(autocommit=False)
- f3 = session.query(Foo).filter_by(id=f1.id).one()
- eq_((f3.data.x, f3.data.y), (4,19))
- self.sql_count_(0, session.commit)
- session.close()
+ assert session.query(Foo).one().data == pickleable.Bar(4, 19)
@testing.resolve_artifact_names
def test_unicode(self):
@classmethod
def define_tables(cls, metadata):
- use_string_defaults = testing.against('postgres', 'oracle', 'sqlite', 'mssql')
+ use_string_defaults = testing.against('postgresql', 'oracle', 'sqlite', 'mssql')
if use_string_defaults:
hohotype = String(30)
Column('id', Integer, primary_key=True,
test_needs_autoincrement=True),
Column('hoho', hohotype, server_default=str(hohoval)),
- Column('counter', Integer, default=sa.func.char_length("1234567")),
- Column('foober', String(30), default="im foober",
- onupdate="im the update"))
+ Column('counter', Integer, default=sa.func.char_length("1234567", type_=Integer)),
+ Column('foober', String(30), default="im foober", onupdate="im the update"))
st = Table('secondary_table', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(50)))
- if testing.against('postgres', 'oracle'):
+ if testing.against('postgresql', 'oracle'):
dt.append_column(
Column('secondary_id', Integer, sa.Sequence('sec_id_seq'),
unique=True))
# "post-update"
mapper(Hoho, default_t)
- h1 = Hoho(hoho="15", counter="15")
+ h1 = Hoho(hoho="15", counter=15)
session = create_session()
session.add(h1)
session.flush()
def go():
eq_(h1.hoho, "15")
- eq_(h1.counter, "15")
+ eq_(h1.counter, 15)
eq_(h1.foober, "im foober")
self.sql_count_(0, go)
"""A server-side default can be used as the target of a foreign key"""
mapper(Hoho, default_t, properties={
- 'secondaries':relation(Secondary)})
+ 'secondaries':relation(Secondary, order_by=secondary_table.c.id)})
mapper(Secondary, secondary_table)
h1 = Hoho()
@classmethod
def define_tables(cls, metadata):
Table('data', metadata,
- Column('id', Integer, primary_key=True),
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('a', String(50)),
Column('b', String(50))
)
l = sa.select([users, addresses],
sa.and_(users.c.id==addresses.c.user_id,
addresses.c.id==a.id)).execute()
- eq_(l.fetchone().values(),
+ eq_(l.first().values(),
[a.user.id, 'asdf8d', a.id, a.user_id, 'theater@foo.com'])
@testing.resolve_artifact_names
sess.add(o5)
sess.flush()
- assert list(sess.execute(t5.select(), mapper=T5)) == [(1, 'some t5')]
- assert list(sess.execute(t6.select(), mapper=T5)) == [(1, 'some t6', 1), (2, 'some other t6', 1)]
+ eq_(
+ list(sess.execute(t5.select(), mapper=T5)),
+ [(1, 'some t5')]
+ )
+ eq_(
+ list(sess.execute(t6.select().order_by(t6.c.id), mapper=T5)),
+ [(1, 'some t6', 1), (2, 'some other t6', 1)]
+ )
o6 = T5(data='some other t5', id=o5.id, t6s=[
T6(data='third t6', id=3),
sess.add(o6)
sess.flush()
- assert list(sess.execute(t5.select(), mapper=T5)) == [(1, 'some other t5')]
- assert list(sess.execute(t6.select(), mapper=T5)) == [(3, 'third t6', 1), (4, 'fourth t6', 1)]
+ eq_(
+ list(sess.execute(t5.select(), mapper=T5)),
+ [(1, 'some other t5')]
+ )
+ eq_(
+ list(sess.execute(t6.select().order_by(t6.c.id), mapper=T5)),
+ [(3, 'third t6', 1), (4, 'fourth t6', 1)]
+ )
@testing.resolve_artifact_names
def test_manytomany(self):
# todo: on 8.3 at least, the failed commit seems to close the cursor?
# needs investigation. leaving in the DDL above now to help verify
# that the new deferrable support on FK isn't involved in this issue.
- if testing.against('postgres'):
+ if testing.against('postgresql'):
t1.bind.engine.dispose()
assert 'populate_instance' not in carrier
carrier.append(interfaces.MapperExtension)
+
+ # Py3K
+ #assert 'populate_instance' not in carrier
+ # Py2K
assert 'populate_instance' in carrier
-
+ # end Py2K
+
assert carrier.interface
for m in carrier.interface:
assert getattr(interfaces.MapperExtension, m)
alias = aliased(Point)
assert Point.zero
+ # Py2K
+ # TODO: what is this testing ??
assert not getattr(alias, 'zero')
+ # end Py2K
def test_classmethods(self):
class Point(object):
self.func = func
def __get__(self, instance, owner):
if instance is None:
+ # Py3K
+ #args = (self.func, owner)
+ # Py2K
args = (self.func, owner, owner.__class__)
+ # end Py2K
else:
+ # Py3K
+ #args = (self.func, instance)
+ # Py2K
args = (self.func, instance, owner)
+ # end Py2K
return types.MethodType(*args)
class PropertyDescriptor(object):
import sys, time
from sqlalchemy import *
from sqlalchemy.orm import *
-from testlib import profiling
+from sqlalchemy.test import profiling
db = create_engine('sqlite://')
metadata = MetaData(db)
from sqlalchemy.orm import attributes
import time
-import gc
manage_attributes = True
init_attributes = manage_attributes and True
attributes.manage(a)
a.email = 'foo@bar.com'
u.addresses.append(a)
-# gc.collect()
# print len(managed_attributes)
# managed_attributes.clear()
total = time.time() - now
import testenv; testenv.simple_setup()
-import gc
import random, string
from sqlalchemy.orm import attributes
+from sqlalchemy.test.util import gc_collect
# with this test, run top. make sure the Python process doenst grow in size arbitrarily.
a.user = u
print "clearing"
#managed_attributes.clear()
- gc.collect()
+ gc_collect()
-import testenv; testenv.configure_for_tests()
from sqlalchemy import *
from sqlalchemy.orm import *
-from testlib import *
+from sqlalchemy.test import *
NUM = 500
DIVISOR = 50
-import testenv; testenv.configure_for_tests()
import time
-#import gc
#import sqlalchemy.orm.attributes as attributes
from sqlalchemy import *
from sqlalchemy.orm import *
-from testlib import *
+from sqlalchemy.test import *
"""
NUM = 2500
class LoadTest(TestBase, AssertsExecutionResults):
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global items, meta
meta = MetaData(testing.db)
items = Table('items', meta,
Column('item_id', Integer, primary_key=True),
Column('value', String(100)))
items.create()
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
items.drop()
- def setUp(self):
+ def setup(self):
for x in range(1,NUM/500+1):
l = []
for y in range(x*500-500 + 1, x*500 + 1):
query = sess.query(Item)
for x in range (1,NUM/100):
# this is not needed with cpython which clears non-circular refs immediately
- #gc.collect()
+ #gc_collect()
l = query.filter(items.c.item_id.between(x*100 - 100 + 1, x*100)).all()
assert len(l) == 100
print "loaded ", len(l), " items "
print "total time ", total
-if __name__ == "__main__":
- testenv.main()
-import testenv; testenv.configure_for_tests()
+import gc
import types
from sqlalchemy import *
from sqlalchemy.orm import *
-from testlib import *
+from sqlalchemy.test import *
NUM = 2500
class SaveTest(TestBase, AssertsExecutionResults):
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global items, metadata
metadata = MetaData(testing.db)
items = Table('items', metadata,
Column('item_id', Integer, primary_key=True),
Column('value', String(100)))
items.create()
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
clear_mappers()
metadata.drop_all()
print x
-if __name__ == "__main__":
- testenv.main()
import testenv; testenv.simple_setup()
-import time, gc, resource
+import time, resource
from sqlalchemy import *
from sqlalchemy.orm import *
+from sqlalchemy.test.util import gc_collect
db = create_engine('sqlite://')
usage.snap = lambda stats=None: setattr(
usage, 'last', stats or resource.getrusage(resource.RUSAGE_SELF))
- gc.collect()
+ gc_collect()
usage.snap()
t = time.clock()
sqlite_select(RawPerson)
t2 = time.clock()
usage('sqlite select/native')
- gc.collect()
+ gc_collect()
usage.snap()
t = time.clock()
sqlite_select(Person)
t2 = time.clock()
usage('sqlite select/instrumented')
- gc.collect()
+ gc_collect()
usage.snap()
t = time.clock()
sql_select(RawPerson)
t2 = time.clock()
usage('sqlalchemy.sql select/native')
- gc.collect()
+ gc_collect()
usage.snap()
t = time.clock()
sql_select(Person)
t2 = time.clock()
usage('sqlalchemy.sql select/instrumented')
- gc.collect()
+ gc_collect()
usage.snap()
t = time.clock()
orm_select()
-import testenv; testenv.configure_for_tests()
-import time, gc, resource
+import time, resource
from sqlalchemy import *
from sqlalchemy.orm import *
-from testlib import *
+from sqlalchemy.test import *
+from sqlalchemy.test.util import gc_collect
+
NUM = 100
session = create_session()
- gc.collect()
+ gc_collect()
usage.snap()
t = time.clock()
people = orm_select(session)
t2 = time.clock()
usage('load objects')
- gc.collect()
+ gc_collect()
usage.snap()
t = time.clock()
update_and_flush(session, people)
-import testenv; testenv.configure_for_tests()
import time
from datetime import datetime
from sqlalchemy import *
from sqlalchemy.orm import *
-from testlib import *
-from testlib.profiling import profiled
+from sqlalchemy.test import *
+from sqlalchemy.test.profiling import profiled
class Item(object):
def __repr__(self):
# load test of connection pool
-import testenv; testenv.configure_for_tests()
import thread, time
from sqlalchemy import *
import sqlalchemy.pool as pool
-from testlib import testing
+from sqlalchemy.test import testing
db = create_engine(testing.db.url, pool_timeout=30, echo_pool=True)
metadata = MetaData(db)
-import testenv; testenv.configure_for_tests()
from sqlalchemy import *
from sqlalchemy.orm import *
-import gc
-from testlib import TestBase, AssertsExecutionResults, profiling, testing
-from orm import _fixtures
+from sqlalchemy.test.compat import gc_collect
+from sqlalchemy.test import TestBase, AssertsExecutionResults, profiling, testing
+from test.orm import _fixtures
# in this test we are specifically looking for time spent in the attributes.InstanceState.__cleanup() method.
ITERATIONS = 100
class SessionTest(TestBase, AssertsExecutionResults):
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global t1, t2, metadata,T1, T2
metadata = MetaData(testing.db)
t1 = Table('t1', metadata,
})
mapper(T2, t2)
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
metadata.drop_all()
clear_mappers()
sess.close()
del sess
- gc.collect()
+ gc_collect()
@profiling.profiled('dirty', report=True)
def test_session_dirty(self):
t2.c2 = 'this is some modified text'
del t1s
- gc.collect()
+ gc_collect()
sess.close()
del sess
- gc.collect()
+ gc_collect()
@profiling.profiled('noclose', report=True)
def test_session_noclose(self):
t1s[index].t2s
del sess
- gc.collect()
-
+ gc_collect()
-if __name__ == '__main__':
- testenv.main()
#!/usr/bin/python
"""Uses ``wsgiref``, standard in Python 2.5 and also in the cheeseshop."""
-import testenv; testenv.configure_for_tests()
from sqlalchemy import *
from sqlalchemy.orm import *
import thread
-from testlib import *
+from sqlalchemy.test import *
port = 8000
-from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
+from sqlalchemy.test.testing import assert_raises, assert_raises_message
from sqlalchemy import *
-from sqlalchemy import exc
+from sqlalchemy import exc, schema
from sqlalchemy.test import *
from sqlalchemy.test import config, engines
+from sqlalchemy.engine import ddl
+from sqlalchemy.test.testing import eq_
+from sqlalchemy.test.assertsql import AllOf, RegexSQL, ExactSQL, CompiledSQL
+from sqlalchemy.dialects.postgresql import base as postgresql
-class ConstraintTest(TestBase, AssertsExecutionResults):
+class ConstraintTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL):
def setup(self):
global metadata
def test_double_fk_usage_raises(self):
f = ForeignKey('b.id')
- assert_raises(exc.InvalidRequestError, Table, "a", metadata,
- Column('x', Integer, f),
- Column('y', Integer, f)
- )
-
+ Column('x', Integer, f)
+ assert_raises(exc.InvalidRequestError, Column, "y", Integer, f)
def test_circular_constraint(self):
a = Table("a", metadata,
metadata.create_all()
foo.insert().execute(id=1,x=9,y=5)
- try:
- foo.insert().execute(id=2,x=5,y=9)
- assert False
- except exc.SQLError:
- assert True
-
+ assert_raises(exc.SQLError, foo.insert().execute, id=2,x=5,y=9)
bar.insert().execute(id=1,x=10)
- try:
- bar.insert().execute(id=2,x=5)
- assert False
- except exc.SQLError:
- assert True
+ assert_raises(exc.SQLError, bar.insert().execute, id=2,x=5)
def test_unique_constraint(self):
foo = Table('foo', metadata,
foo.insert().execute(id=2, value='value2')
bar.insert().execute(id=1, value='a', value2='a')
bar.insert().execute(id=2, value='a', value2='b')
- try:
- foo.insert().execute(id=3, value='value1')
- assert False
- except exc.SQLError:
- assert True
- try:
- bar.insert().execute(id=3, value='a', value2='b')
- assert False
- except exc.SQLError:
- assert True
+ assert_raises(exc.SQLError, foo.insert().execute, id=3, value='value1')
+ assert_raises(exc.SQLError, bar.insert().execute, id=3, value='a', value2='b')
def test_index_create(self):
employees = Table('employees', metadata,
Index('sport_announcer', events.c.sport, events.c.announcer, unique=True)
Index('idx_winners', events.c.winner)
- index_names = [ ix.name for ix in events.indexes ]
- assert 'ix_events_name' in index_names
- assert 'ix_events_location' in index_names
- assert 'sport_announcer' in index_names
- assert 'idx_winners' in index_names
- assert len(index_names) == 4
-
- capt = []
- connection = testing.db.connect()
- # TODO: hacky, put a real connection proxy in
- ex = connection._Connection__execute_context
- def proxy(context):
- capt.append(context.statement)
- capt.append(repr(context.parameters))
- ex(context)
- connection._Connection__execute_context = proxy
- schemagen = testing.db.dialect.schemagenerator(testing.db.dialect, connection)
- schemagen.traverse(events)
-
- assert capt[0].strip().startswith('CREATE TABLE events')
-
- s = set([capt[x].strip() for x in [2,4,6,8]])
-
- assert s == set([
- 'CREATE UNIQUE INDEX ix_events_name ON events (name)',
- 'CREATE INDEX ix_events_location ON events (location)',
- 'CREATE UNIQUE INDEX sport_announcer ON events (sport, announcer)',
- 'CREATE INDEX idx_winners ON events (winner)'
- ])
+ eq_(
+ set([ ix.name for ix in events.indexes ]),
+ set(['ix_events_name', 'ix_events_location', 'sport_announcer', 'idx_winners'])
+ )
+
+ self.assert_sql_execution(
+ testing.db,
+ lambda: events.create(testing.db),
+ RegexSQL("^CREATE TABLE events"),
+ AllOf(
+ ExactSQL('CREATE UNIQUE INDEX ix_events_name ON events (name)'),
+ ExactSQL('CREATE INDEX ix_events_location ON events (location)'),
+ ExactSQL('CREATE UNIQUE INDEX sport_announcer ON events (sport, announcer)'),
+ ExactSQL('CREATE INDEX idx_winners ON events (winner)')
+ )
+ )
# verify that the table is functional
events.insert().execute(id=1, name='hockey finals', location='rink',
dialect = testing.db.dialect.__class__()
dialect.max_identifier_length = 20
- schemagen = dialect.schemagenerator(dialect, None)
- schemagen.execute = lambda : None
-
t1 = Table("sometable", MetaData(), Column("foo", Integer))
- schemagen.visit_index(Index("this_name_is_too_long_for_what_were_doing", t1.c.foo))
- eq_(schemagen.buffer.getvalue(), "CREATE INDEX this_name_is_t_1 ON sometable (foo)")
- schemagen.buffer.truncate(0)
- schemagen.visit_index(Index("this_other_name_is_too_long_for_what_were_doing", t1.c.foo))
- eq_(schemagen.buffer.getvalue(), "CREATE INDEX this_other_nam_2 ON sometable (foo)")
-
- schemadrop = dialect.schemadropper(dialect, None)
- schemadrop.execute = lambda: None
- assert_raises(exc.IdentifierError, schemadrop.visit_index, Index("this_name_is_too_long_for_what_were_doing", t1.c.foo))
+ self.assert_compile(
+ schema.CreateIndex(Index("this_name_is_too_long_for_what_were_doing", t1.c.foo)),
+ "CREATE INDEX this_name_is_t_1 ON sometable (foo)",
+ dialect=dialect
+ )
+
+ self.assert_compile(
+ schema.CreateIndex(Index("this_other_name_is_too_long_for_what_were_doing", t1.c.foo)),
+ "CREATE INDEX this_other_nam_1 ON sometable (foo)",
+ dialect=dialect
+ )
-class ConstraintCompilationTest(TestBase, AssertsExecutionResults):
- class accum(object):
- def __init__(self):
- self.statements = []
- def __call__(self, sql, *a, **kw):
- self.statements.append(sql)
- def __contains__(self, substring):
- for s in self.statements:
- if substring in s:
- return True
- return False
- def __str__(self):
- return '\n'.join([repr(x) for x in self.statements])
- def clear(self):
- del self.statements[:]
-
- def setup(self):
- self.sql = self.accum()
- opts = config.db_opts.copy()
- opts['strategy'] = 'mock'
- opts['executor'] = self.sql
- self.engine = engines.testing_engine(options=opts)
-
+class ConstraintCompilationTest(TestBase, AssertsCompiledSQL):
def _test_deferrable(self, constraint_factory):
- meta = MetaData(self.engine)
- t = Table('tbl', meta,
+ t = Table('tbl', MetaData(),
Column('a', Integer),
Column('b', Integer),
constraint_factory(deferrable=True))
- t.create()
- assert 'DEFERRABLE' in self.sql, self.sql
- assert 'NOT DEFERRABLE' not in self.sql, self.sql
- self.sql.clear()
- meta.clear()
-
- t = Table('tbl', meta,
+
+ sql = str(schema.CreateTable(t).compile(bind=testing.db))
+ assert 'DEFERRABLE' in sql, sql
+ assert 'NOT DEFERRABLE' not in sql, sql
+
+ t = Table('tbl', MetaData(),
Column('a', Integer),
Column('b', Integer),
constraint_factory(deferrable=False))
- t.create()
- assert 'NOT DEFERRABLE' in self.sql
- self.sql.clear()
- meta.clear()
- t = Table('tbl', meta,
+ sql = str(schema.CreateTable(t).compile(bind=testing.db))
+ assert 'NOT DEFERRABLE' in sql
+
+
+ t = Table('tbl', MetaData(),
Column('a', Integer),
Column('b', Integer),
constraint_factory(deferrable=True, initially='IMMEDIATE'))
- t.create()
- assert 'NOT DEFERRABLE' not in self.sql
- assert 'INITIALLY IMMEDIATE' in self.sql
- self.sql.clear()
- meta.clear()
+ sql = str(schema.CreateTable(t).compile(bind=testing.db))
+ assert 'NOT DEFERRABLE' not in sql
+ assert 'INITIALLY IMMEDIATE' in sql
- t = Table('tbl', meta,
+ t = Table('tbl', MetaData(),
Column('a', Integer),
Column('b', Integer),
constraint_factory(deferrable=True, initially='DEFERRED'))
- t.create()
+ sql = str(schema.CreateTable(t).compile(bind=testing.db))
- assert 'NOT DEFERRABLE' not in self.sql
- assert 'INITIALLY DEFERRED' in self.sql, self.sql
+ assert 'NOT DEFERRABLE' not in sql
+ assert 'INITIALLY DEFERRED' in sql
def test_deferrable_pk(self):
factory = lambda **kw: PrimaryKeyConstraint('a', **kw)
self._test_deferrable(factory)
def test_deferrable_column_fk(self):
- meta = MetaData(self.engine)
- t = Table('tbl', meta,
+ t = Table('tbl', MetaData(),
Column('a', Integer),
Column('b', Integer,
ForeignKey('tbl.a', deferrable=True,
initially='DEFERRED')))
- t.create()
- assert 'DEFERRABLE' in self.sql, self.sql
- assert 'INITIALLY DEFERRED' in self.sql, self.sql
+
+ self.assert_compile(
+ schema.CreateTable(t),
+ "CREATE TABLE tbl (a INTEGER, b INTEGER, FOREIGN KEY(b) REFERENCES tbl (a) DEFERRABLE INITIALLY DEFERRED)",
+ )
def test_deferrable_unique(self):
factory = lambda **kw: UniqueConstraint('b', **kw)
self._test_deferrable(factory)
def test_deferrable_column_check(self):
- meta = MetaData(self.engine)
- t = Table('tbl', meta,
+ t = Table('tbl', MetaData(),
Column('a', Integer),
Column('b', Integer,
CheckConstraint('a < b',
deferrable=True,
initially='DEFERRED')))
- t.create()
- assert 'DEFERRABLE' in self.sql, self.sql
- assert 'INITIALLY DEFERRED' in self.sql, self.sql
+
+ self.assert_compile(
+ schema.CreateTable(t),
+ "CREATE TABLE tbl (a INTEGER, b INTEGER CHECK (a < b) DEFERRABLE INITIALLY DEFERRED)"
+ )
+
+ def test_use_alter(self):
+ m = MetaData()
+ t = Table('t', m,
+ Column('a', Integer),
+ )
+
+ t2 = Table('t2', m,
+ Column('a', Integer, ForeignKey('t.a', use_alter=True, name='fk_ta')),
+ Column('b', Integer, ForeignKey('t.a', name='fk_tb')), # to ensure create ordering ...
+ )
+
+ e = engines.mock_engine(dialect_name='postgresql')
+ m.create_all(e)
+ m.drop_all(e)
+
+ e.assert_sql([
+ 'CREATE TABLE t (a INTEGER)',
+ 'CREATE TABLE t2 (a INTEGER, b INTEGER, CONSTRAINT fk_tb FOREIGN KEY(b) REFERENCES t (a))',
+ 'ALTER TABLE t2 ADD CONSTRAINT fk_ta FOREIGN KEY(a) REFERENCES t (a)',
+ 'ALTER TABLE t2 DROP CONSTRAINT fk_ta',
+ 'DROP TABLE t2',
+ 'DROP TABLE t'
+ ])
+
+
+ def test_add_drop_constraint(self):
+ m = MetaData()
+
+ t = Table('tbl', m,
+ Column('a', Integer),
+ Column('b', Integer)
+ )
+
+ t2 = Table('t2', m,
+ Column('a', Integer),
+ Column('b', Integer)
+ )
+
+ constraint = CheckConstraint('a < b',name="my_test_constraint", deferrable=True,initially='DEFERRED', table=t)
+ self.assert_compile(
+ schema.AddConstraint(constraint),
+ "ALTER TABLE tbl ADD CONSTRAINT my_test_constraint CHECK (a < b) DEFERRABLE INITIALLY DEFERRED"
+ )
+
+ self.assert_compile(
+ schema.DropConstraint(constraint),
+ "ALTER TABLE tbl DROP CONSTRAINT my_test_constraint"
+ )
+
+ self.assert_compile(
+ schema.DropConstraint(constraint, cascade=True),
+ "ALTER TABLE tbl DROP CONSTRAINT my_test_constraint CASCADE"
+ )
+ constraint = ForeignKeyConstraint(["b"], ["t2.a"])
+ t.append_constraint(constraint)
+ self.assert_compile(
+ schema.AddConstraint(constraint),
+ "ALTER TABLE tbl ADD FOREIGN KEY(b) REFERENCES t2 (a)"
+ )
+ constraint = ForeignKeyConstraint([t.c.a], [t2.c.b])
+ t.append_constraint(constraint)
+ self.assert_compile(
+ schema.AddConstraint(constraint),
+ "ALTER TABLE tbl ADD FOREIGN KEY(a) REFERENCES t2 (b)"
+ )
+
+ constraint = UniqueConstraint("a", "b", name="uq_cst")
+ t2.append_constraint(constraint)
+ self.assert_compile(
+ schema.AddConstraint(constraint),
+ "ALTER TABLE t2 ADD CONSTRAINT uq_cst UNIQUE (a, b)"
+ )
+
+ constraint = UniqueConstraint(t2.c.a, t2.c.b, name="uq_cs2")
+ self.assert_compile(
+ schema.AddConstraint(constraint),
+ "ALTER TABLE t2 ADD CONSTRAINT uq_cs2 UNIQUE (a, b)"
+ )
+
+ assert t.c.a.primary_key is False
+ constraint = PrimaryKeyConstraint(t.c.a)
+ assert t.c.a.primary_key is True
+ self.assert_compile(
+ schema.AddConstraint(constraint),
+ "ALTER TABLE tbl ADD PRIMARY KEY (a)"
+ )
+
+
from sqlalchemy import Sequence, Column, func
from sqlalchemy.sql import select, text
import sqlalchemy as sa
-from sqlalchemy.test import testing
+from sqlalchemy.test import testing, engines
from sqlalchemy import MetaData, Integer, String, ForeignKey, Boolean
from sqlalchemy.test.schema import Table
from sqlalchemy.test.testing import eq_
# since its a "branched" connection
conn.close()
- use_function_defaults = testing.against('postgres', 'mssql', 'maxdb')
+ use_function_defaults = testing.against('postgresql', 'mssql', 'maxdb')
is_oracle = testing.against('oracle')
# select "count(1)" returns different results on different DBs also
assert_raises_message(sa.exc.ArgumentError,
ex_msg,
sa.ColumnDefault, fn)
-
+
def test_arg_signature(self):
def fn1(): pass
def fn2(): pass
assert r.lastrow_has_defaults()
eq_(set(r.context.postfetch_cols),
set([t.c.col3, t.c.col5, t.c.col4, t.c.col6]))
-
+
eq_(t.select(t.c.col1==54).execute().fetchall(),
[(54, 'imthedefault', f, ts, ts, ctexec, True, False,
12, today, None)])
@testing.fails_on('firebird', 'Data type unknown')
def test_insertmany(self):
# MySQL-Python 1.2.2 breaks functions in execute_many :(
- if (testing.against('mysql') and
+ if (testing.against('mysql') and not testing.against('+zxjdbc') and
testing.db.dialect.dbapi.version_info[:3] == (1, 2, 2)):
return
def test_insert_values(self):
t.insert(values={'col3':50}).execute()
l = t.select().execute()
- eq_(50, l.fetchone()['col3'])
+ eq_(50, l.first()['col3'])
@testing.fails_on('firebird', 'Data type unknown')
def test_updatemany(self):
# MySQL-Python 1.2.2 breaks functions in execute_many :(
- if (testing.against('mysql') and
+ if (testing.against('mysql') and not testing.against('+zxjdbc') and
testing.db.dialect.dbapi.version_info[:3] == (1, 2, 2)):
return
@testing.fails_on('firebird', 'Data type unknown')
def test_update(self):
r = t.insert().execute()
- pk = r.last_inserted_ids()[0]
+ pk = r.inserted_primary_key[0]
t.update(t.c.col1==pk).execute(col4=None, col5=None)
ctexec = currenttime.scalar()
l = t.select(t.c.col1==pk).execute()
- l = l.fetchone()
+ l = l.first()
eq_(l,
(pk, 'im the update', f2, None, None, ctexec, True, False,
13, datetime.date.today(), 'py'))
@testing.fails_on('firebird', 'Data type unknown')
def test_update_values(self):
r = t.insert().execute()
- pk = r.last_inserted_ids()[0]
+ pk = r.inserted_primary_key[0]
t.update(t.c.col1==pk, values={'col3': 55}).execute()
l = t.select(t.c.col1==pk).execute()
- l = l.fetchone()
+ l = l.first()
eq_(55, l['col3'])
- @testing.fails_on_everything_except('postgres')
- def test_passive_override(self):
- """
- Primarily for postgres, tests that when we get a primary key column
- back from reflecting a table which has a default value on it, we
- pre-execute that DefaultClause upon insert, even though DefaultClause
- says "let the database execute this", because in postgres we must have
- all the primary key values in memory before insert; otherwise we can't
- locate the just inserted row.
-
- """
- # TODO: move this to dialect/postgres
- try:
- meta = MetaData(testing.db)
- testing.db.execute("""
- CREATE TABLE speedy_users
- (
- speedy_user_id SERIAL PRIMARY KEY,
-
- user_name VARCHAR NOT NULL,
- user_password VARCHAR NOT NULL
- );
- """, None)
-
- t = Table("speedy_users", meta, autoload=True)
- t.insert().execute(user_name='user', user_password='lala')
- l = t.select().execute().fetchall()
- eq_(l, [(1, 'user', 'lala')])
- finally:
- testing.db.execute("drop table speedy_users", None)
-
class PKDefaultTest(_base.TablesTest):
__requires__ = ('subqueries',)
Column('id', Integer, primary_key=True,
default=sa.select([func.max(t2.c.nextid)]).as_scalar()),
Column('data', String(30)))
-
- @testing.fails_on('mssql', 'FIXME: unknown')
+
+ @testing.requires.returning
+ def test_with_implicit_returning(self):
+ self._test(True)
+
+ def test_regular(self):
+ self._test(False)
+
@testing.resolve_artifact_names
- def test_basic(self):
- t2.insert().execute(nextid=1)
- r = t1.insert().execute(data='hi')
- eq_([1], r.last_inserted_ids())
-
- t2.insert().execute(nextid=2)
- r = t1.insert().execute(data='there')
- eq_([2], r.last_inserted_ids())
+ def _test(self, returning):
+ if not returning and not testing.db.dialect.implicit_returning:
+ engine = testing.db
+ else:
+ engine = engines.testing_engine(options={'implicit_returning':returning})
+ engine.execute(t2.insert(), nextid=1)
+ r = engine.execute(t1.insert(), data='hi')
+ eq_([1], r.inserted_primary_key)
+ engine.execute(t2.insert(), nextid=2)
+ r = engine.execute(t1.insert(), data='there')
+ eq_([2], r.inserted_primary_key)
class PKIncrementTest(_base.TablesTest):
run_define_tables = 'each'
def _test_autoincrement(self, bind):
ids = set()
rs = bind.execute(aitable.insert(), int1=1)
- last = rs.last_inserted_ids()[0]
+ last = rs.inserted_primary_key[0]
self.assert_(last)
self.assert_(last not in ids)
ids.add(last)
rs = bind.execute(aitable.insert(), str1='row 2')
- last = rs.last_inserted_ids()[0]
+ last = rs.inserted_primary_key[0]
self.assert_(last)
self.assert_(last not in ids)
ids.add(last)
rs = bind.execute(aitable.insert(), int1=3, str1='row 3')
- last = rs.last_inserted_ids()[0]
+ last = rs.inserted_primary_key[0]
self.assert_(last)
self.assert_(last not in ids)
ids.add(last)
rs = bind.execute(aitable.insert(values={'int1':func.length('four')}))
- last = rs.last_inserted_ids()[0]
+ last = rs.inserted_primary_key[0]
self.assert_(last)
self.assert_(last not in ids)
ids.add(last)
+ eq_(ids, set([1,2,3,4]))
+
eq_(list(bind.execute(aitable.select().order_by(aitable.c.id))),
[(1, 1, None), (2, None, 'row 2'), (3, 3, 'row 3'), (4, 4, None)])
single.create()
r = single.insert().execute()
- id_ = r.last_inserted_ids()[0]
- assert id_ is not None
+ id_ = r.inserted_primary_key[0]
+ eq_(id_, 1)
eq_(1, sa.select([func.count(sa.text('*'))], from_obj=single).scalar())
def test_autoincrement_fk(self):
nodes.create()
r = nodes.insert().execute(data='foo')
- id_ = r.last_inserted_ids()[0]
+ id_ = r.inserted_primary_key[0]
nodes.insert().execute(data='bar', parent_id=id_)
@testing.fails_on('sqlite', 'FIXME: unknown')
try:
- # postgres + mysql strict will fail on first row,
+ # postgresql + mysql strict will fail on first row,
# mysql in legacy mode fails on second row
nonai.insert().execute(data='row 1')
nonai.insert().execute(data='row 2')
def testseqnonpk(self):
"""test sequences fire off as defaults on non-pk columns"""
- result = sometable.insert().execute(name="somename")
+ engine = engines.testing_engine(options={'implicit_returning':False})
+ result = engine.execute(sometable.insert(), name="somename")
assert 'id' in result.postfetch_cols()
- result = sometable.insert().execute(name="someother")
+ result = engine.execute(sometable.insert(), name="someother")
assert 'id' in result.postfetch_cols()
sometable.insert().execute(
{'name':'name3'},
{'name':'name4'})
- eq_(sometable.select().execute().fetchall(),
+ eq_(sometable.select().order_by(sometable.c.id).execute().fetchall(),
[(1, "somename", 1),
(2, "someother", 2),
(3, "name3", 3),
cartitems.insert().execute(description='there')
r = cartitems.insert().execute(description='lala')
- assert r.last_inserted_ids() and r.last_inserted_ids()[0] is not None
- id_ = r.last_inserted_ids()[0]
+ assert r.inserted_primary_key and r.inserted_primary_key[0] is not None
+ id_ = r.inserted_primary_key[0]
eq_(1,
sa.select([func.count(cartitems.c.cart_id)],
bindtemplate = BIND_TEMPLATES[dialect.paramstyle]
self.assert_compile(func.current_timestamp(), "CURRENT_TIMESTAMP", dialect=dialect)
self.assert_compile(func.localtime(), "LOCALTIME", dialect=dialect)
- if isinstance(dialect, firebird.dialect):
+ if isinstance(dialect, (firebird.dialect, maxdb.dialect, oracle.dialect)):
self.assert_compile(func.nosuchfunction(), "nosuchfunction", dialect=dialect)
else:
self.assert_compile(func.nosuchfunction(), "nosuchfunction()", dialect=dialect)
for ret, dialect in [
('CURRENT_TIMESTAMP', sqlite.dialect()),
- ('now()', postgres.dialect()),
+ ('now()', postgresql.dialect()),
('now()', mysql.dialect()),
('CURRENT_TIMESTAMP', oracle.dialect())
]:
for ret, dialect in [
('random()', sqlite.dialect()),
- ('random()', postgres.dialect()),
+ ('random()', postgresql.dialect()),
('rand()', mysql.dialect()),
- ('random()', oracle.dialect())
+ ('random', oracle.dialect())
]:
self.assert_compile(func.random(), ret, dialect=dialect)
class ExecuteTest(TestBase):
-
+ @engines.close_first
+ def tearDown(self):
+ pass
+
def test_standalone_execute(self):
x = testing.db.func.current_date().execute().scalar()
y = testing.db.func.current_date().select().execute().scalar()
conn.close()
assert (x == y == z) is True
+ @engines.close_first
def test_update(self):
"""
Tests sending functions and SQL expressions to the VALUES and SET
meta.create_all()
try:
t.insert(values=dict(value=func.length("one"))).execute()
- assert t.select().execute().fetchone()['value'] == 3
+ assert t.select().execute().first()['value'] == 3
t.update(values=dict(value=func.length("asfda"))).execute()
- assert t.select().execute().fetchone()['value'] == 5
+ assert t.select().execute().first()['value'] == 5
r = t.insert(values=dict(value=func.length("sfsaafsda"))).execute()
- id = r.last_inserted_ids()[0]
- assert t.select(t.c.id==id).execute().fetchone()['value'] == 9
+ id = r.inserted_primary_key[0]
+ assert t.select(t.c.id==id).execute().first()['value'] == 9
t.update(values={t.c.value:func.length("asdf")}).execute()
- assert t.select().execute().fetchone()['value'] == 4
+ assert t.select().execute().first()['value'] == 4
print "--------------------------"
t2.insert().execute()
t2.insert(values=dict(value=func.length("one"))).execute()
t2.delete().execute()
t2.insert(values=dict(value=func.length("one") + 8)).execute()
- assert t2.select().execute().fetchone()['value'] == 11
+ assert t2.select().execute().first()['value'] == 11
t2.update(values=dict(value=func.length("asfda"))).execute()
- assert select([t2.c.value, t2.c.stuff]).execute().fetchone() == (5, "thisisstuff")
+ assert select([t2.c.value, t2.c.stuff]).execute().first() == (5, "thisisstuff")
t2.update(values={t2.c.value:func.length("asfdaasdf"), t2.c.stuff:"foo"}).execute()
- print "HI", select([t2.c.value, t2.c.stuff]).execute().fetchone()
- assert select([t2.c.value, t2.c.stuff]).execute().fetchone() == (9, "foo")
+ print "HI", select([t2.c.value, t2.c.stuff]).execute().first()
+ assert select([t2.c.value, t2.c.stuff]).execute().first() == (9, "foo")
finally:
meta.drop_all()
- @testing.fails_on_everything_except('postgres')
+ @testing.fails_on_everything_except('postgresql')
def test_as_from(self):
# TODO: shouldnt this work on oracle too ?
x = testing.db.func.current_date().execute().scalar()
# construct a column-based FROM object out of a function, like in [ticket:172]
s = select([sql.column('date', type_=DateTime)], from_obj=[testing.db.func.current_date()])
- q = s.execute().fetchone()[s.c.date]
+ q = s.execute().first()[s.c.date]
r = s.alias('datequery').select().scalar()
assert x == y == z == w == q == r
'd': datetime.date(2010, 5, 1) })
rs = select([extract('year', table.c.dt),
extract('month', table.c.d)]).execute()
- row = rs.fetchone()
+ row = rs.first()
assert row[0] == 2010
assert row[1] == 5
rs.close()
maxlen = testing.db.dialect.max_identifier_length
testing.db.dialect.max_identifier_length = IDENT_LENGTH
+ @engines.close_first
def teardown(self):
table1.delete().execute()
], repr(result)
def test_table_alias_names(self):
- self.assert_compile(
- table2.alias().select(),
- "SELECT table_with_exactly_29_c_1.this_is_the_primarykey_column, table_with_exactly_29_c_1.this_is_the_data_column FROM table_with_exactly_29_characs AS table_with_exactly_29_c_1"
- )
+ if testing.against('oracle'):
+ self.assert_compile(
+ table2.alias().select(),
+ "SELECT table_with_exactly_29_c_1.this_is_the_primarykey_column, table_with_exactly_29_c_1.this_is_the_data_column FROM table_with_exactly_29_characs table_with_exactly_29_c_1"
+ )
+ else:
+ self.assert_compile(
+ table2.alias().select(),
+ "SELECT table_with_exactly_29_c_1.this_is_the_primarykey_column, table_with_exactly_29_c_1.this_is_the_data_column FROM table_with_exactly_29_characs AS table_with_exactly_29_c_1"
+ )
ta = table2.alias()
dialect = default.DefaultDialect()
+from sqlalchemy.test.testing import eq_
import datetime
from sqlalchemy import *
from sqlalchemy import exc, sql
from sqlalchemy.engine import default
from sqlalchemy.test import *
-from sqlalchemy.test.testing import eq_
+from sqlalchemy.test.testing import eq_, assert_raises_message
+from sqlalchemy.test.schema import Table, Column
class QueryTest(TestBase):
global users, users2, addresses, metadata
metadata = MetaData(testing.db)
users = Table('query_users', metadata,
- Column('user_id', INT, primary_key = True),
+ Column('user_id', INT, primary_key=True, test_needs_autoincrement=True),
Column('user_name', VARCHAR(20)),
)
addresses = Table('query_addresses', metadata,
- Column('address_id', Integer, primary_key=True),
+ Column('address_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('user_id', Integer, ForeignKey('query_users.user_id')),
Column('address', String(30)))
)
metadata.create_all()
- def tearDown(self):
+ @engines.close_first
+ def teardown(self):
addresses.delete().execute()
users.delete().execute()
users2.delete().execute()
assert users.count().scalar() == 1
users.update(users.c.user_id == 7).execute(user_name = 'fred')
- assert users.select(users.c.user_id==7).execute().fetchone()['user_name'] == 'fred'
+ assert users.select(users.c.user_id==7).execute().first()['user_name'] == 'fred'
def test_lastrow_accessor(self):
- """Tests the last_inserted_ids() and lastrow_has_id() functions."""
+ """Tests the inserted_primary_key and lastrow_has_id() functions."""
- def insert_values(table, values):
+ def insert_values(engine, table, values):
"""
Inserts a row into a table, returns the full list of values
INSERTed including defaults that fired off on the DB side and
detects rows that had defaults and post-fetches.
"""
- result = table.insert().execute(**values)
+ result = engine.execute(table.insert(), **values)
ret = values.copy()
- for col, id in zip(table.primary_key, result.last_inserted_ids()):
+ for col, id in zip(table.primary_key, result.inserted_primary_key):
ret[col.key] = id
if result.lastrow_has_defaults():
- criterion = and_(*[col==id for col, id in zip(table.primary_key, result.last_inserted_ids())])
- row = table.select(criterion).execute().fetchone()
+ criterion = and_(*[col==id for col, id in zip(table.primary_key, result.inserted_primary_key)])
+ row = engine.execute(table.select(criterion)).first()
for c in table.c:
ret[c.key] = row[c]
return ret
- for supported, table, values, assertvalues in [
- (
- {'unsupported':['sqlite']},
- Table("t1", metadata,
- Column('id', Integer, Sequence('t1_id_seq', optional=True), primary_key=True),
- Column('foo', String(30), primary_key=True)),
- {'foo':'hi'},
- {'id':1, 'foo':'hi'}
- ),
- (
- {'unsupported':['sqlite']},
- Table("t2", metadata,
- Column('id', Integer, Sequence('t2_id_seq', optional=True), primary_key=True),
- Column('foo', String(30), primary_key=True),
- Column('bar', String(30), server_default='hi')
+ if testing.against('firebird', 'postgresql', 'oracle', 'mssql'):
+ test_engines = [
+ engines.testing_engine(options={'implicit_returning':False}),
+ engines.testing_engine(options={'implicit_returning':True}),
+ ]
+ else:
+ test_engines = [testing.db]
+
+ for engine in test_engines:
+ metadata = MetaData()
+ for supported, table, values, assertvalues in [
+ (
+ {'unsupported':['sqlite']},
+ Table("t1", metadata,
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
+ Column('foo', String(30), primary_key=True)),
+ {'foo':'hi'},
+ {'id':1, 'foo':'hi'}
),
- {'foo':'hi'},
- {'id':1, 'foo':'hi', 'bar':'hi'}
- ),
- (
- {'unsupported':[]},
- Table("t3", metadata,
- Column("id", String(40), primary_key=True),
- Column('foo', String(30), primary_key=True),
- Column("bar", String(30))
+ (
+ {'unsupported':['sqlite']},
+ Table("t2", metadata,
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
+ Column('foo', String(30), primary_key=True),
+ Column('bar', String(30), server_default='hi')
),
- {'id':'hi', 'foo':'thisisfoo', 'bar':"thisisbar"},
- {'id':'hi', 'foo':'thisisfoo', 'bar':"thisisbar"}
- ),
- (
- {'unsupported':[]},
- Table("t4", metadata,
- Column('id', Integer, Sequence('t4_id_seq', optional=True), primary_key=True),
- Column('foo', String(30), primary_key=True),
- Column('bar', String(30), server_default='hi')
+ {'foo':'hi'},
+ {'id':1, 'foo':'hi', 'bar':'hi'}
),
- {'foo':'hi', 'id':1},
- {'id':1, 'foo':'hi', 'bar':'hi'}
- ),
- (
- {'unsupported':[]},
- Table("t5", metadata,
- Column('id', String(10), primary_key=True),
- Column('bar', String(30), server_default='hi')
+ (
+ {'unsupported':[]},
+ Table("t3", metadata,
+ Column("id", String(40), primary_key=True),
+ Column('foo', String(30), primary_key=True),
+ Column("bar", String(30))
+ ),
+ {'id':'hi', 'foo':'thisisfoo', 'bar':"thisisbar"},
+ {'id':'hi', 'foo':'thisisfoo', 'bar':"thisisbar"}
),
- {'id':'id1'},
- {'id':'id1', 'bar':'hi'},
- ),
- ]:
- if testing.db.name in supported['unsupported']:
- continue
- try:
- table.create()
- i = insert_values(table, values)
- assert i == assertvalues, repr(i) + " " + repr(assertvalues)
- finally:
- table.drop()
+ (
+ {'unsupported':[]},
+ Table("t4", metadata,
+ Column('id', Integer, Sequence('t4_id_seq', optional=True), primary_key=True),
+ Column('foo', String(30), primary_key=True),
+ Column('bar', String(30), server_default='hi')
+ ),
+ {'foo':'hi', 'id':1},
+ {'id':1, 'foo':'hi', 'bar':'hi'}
+ ),
+ (
+ {'unsupported':[]},
+ Table("t5", metadata,
+ Column('id', String(10), primary_key=True),
+ Column('bar', String(30), server_default='hi')
+ ),
+ {'id':'id1'},
+ {'id':'id1', 'bar':'hi'},
+ ),
+ ]:
+ if testing.db.name in supported['unsupported']:
+ continue
+ try:
+ table.create(bind=engine, checkfirst=True)
+ i = insert_values(engine, table, values)
+ assert i == assertvalues, "tablename: %s %r %r" % (table.name, repr(i), repr(assertvalues))
+ finally:
+ table.drop(bind=engine)
+
+ @testing.fails_on('sqlite', "sqlite autoincremnt doesn't work with composite pks")
+ def test_misordered_lastrow(self):
+ related = Table('related', metadata,
+ Column('id', Integer, primary_key=True)
+ )
+ t6 = Table("t6", metadata,
+ Column('manual_id', Integer, ForeignKey('related.id'), primary_key=True),
+ Column('auto_id', Integer, primary_key=True, test_needs_autoincrement=True),
+ )
+ metadata.create_all()
+ r = related.insert().values(id=12).execute()
+ id = r.inserted_primary_key[0]
+ assert id==12
+
+ r = t6.insert().values(manual_id=id).execute()
+ eq_(r.inserted_primary_key, [12, 1])
+
+ def test_autoclose_on_insert(self):
+ if testing.against('firebird', 'postgresql', 'oracle', 'mssql'):
+ test_engines = [
+ engines.testing_engine(options={'implicit_returning':False}),
+ engines.testing_engine(options={'implicit_returning':True}),
+ ]
+ else:
+ test_engines = [testing.db]
+
+ for engine in test_engines:
+
+ r = engine.execute(users.insert(),
+ {'user_name':'jack'},
+ )
+ assert r.closed
+
def test_row_iteration(self):
users.insert().execute(
{'user_id':7, 'user_name':'jack'},
l.append(row)
self.assert_(len(l) == 3)
- @testing.fails_on('firebird', 'Data type unknown')
+ @testing.fails_on('firebird', "kinterbasdb doesn't send full type information")
@testing.requires.subqueries
def test_anonymous_rows(self):
users.insert().execute(
assert row['anon_1'] == 8
assert row['anon_2'] == 10
+ @testing.fails_on('firebird', "kinterbasdb doesn't send full type information")
def test_order_by_label(self):
"""test that a label within an ORDER BY works on each backend.
select([concat]).order_by(concat).execute().fetchall(),
[("test: ed",), ("test: fred",), ("test: jack",)]
)
+
+ eq_(
+ select([concat]).order_by(concat).execute().fetchall(),
+ [("test: ed",), ("test: fred",), ("test: jack",)]
+ )
concat = ("test: " + users.c.user_name).label('thedata')
eq_(
def test_row_comparison(self):
users.insert().execute(user_id = 7, user_name = 'jack')
- rp = users.select().execute().fetchone()
+ rp = users.select().execute().first()
self.assert_(rp == rp)
self.assert_(not(rp != rp))
self.assert_(not (rp != equal))
self.assert_(not (equal != equal))
- @testing.fails_on('mssql', 'No support for boolean logic in column select.')
- @testing.fails_on('oracle', 'FIXME: unknown')
+ @testing.requires.boolean_col_expressions
def test_or_and_as_columns(self):
true, false = literal(True), literal(False)
eq_(testing.db.execute(select([or_(false, false)])).scalar(), False)
eq_(testing.db.execute(select([not_(or_(false, false))])).scalar(), True)
- row = testing.db.execute(select([or_(false, false).label("x"), and_(true, false).label("y")])).fetchone()
+ row = testing.db.execute(select([or_(false, false).label("x"), and_(true, false).label("y")])).first()
assert row.x == False
assert row.y == False
- row = testing.db.execute(select([or_(true, false).label("x"), and_(true, false).label("y")])).fetchone()
+ row = testing.db.execute(select([or_(true, false).label("x"), and_(true, false).label("y")])).first()
assert row.x == True
assert row.y == False
eq_(expr.execute().fetchall(), result)
+ @testing.fails_on("firebird", "see dialect.test_firebird:MiscTest.test_percents_in_text")
+ @testing.fails_on("oracle", "neither % nor %% are accepted")
+ @testing.fails_on("+pg8000", "can't interpret result column from '%%'")
@testing.emits_warning('.*now automatically escapes.*')
def test_percents_in_text(self):
for expr, result in (
eq_(select([users.c.user_id]).where(users.c.user_name.ilike('TWO')).execute().fetchall(), [(2, )])
- if testing.against('postgres'):
+ if testing.against('postgresql'):
eq_(select([users.c.user_id]).where(users.c.user_name.like('one')).execute().fetchall(), [(1, )])
eq_(select([users.c.user_id]).where(users.c.user_name.like('TWO')).execute().fetchall(), [])
s = select([datetable.alias('x').c.today]).as_scalar()
s2 = select([datetable.c.id, s.label('somelabel')])
#print s2.c.somelabel.type
- assert isinstance(s2.execute().fetchone()['somelabel'], datetime.datetime)
+ assert isinstance(s2.execute().first()['somelabel'], datetime.datetime)
finally:
datetable.drop()
users.insert().execute(user_id=2, user_name='jack')
addresses.insert().execute(address_id=1, user_id=2, address='foo@bar.com')
- r = users.select(users.c.user_id==2).execute().fetchone()
+ r = users.select(users.c.user_id==2).execute().first()
self.assert_(r.user_id == r['user_id'] == r[users.c.user_id] == 2)
self.assert_(r.user_name == r['user_name'] == r[users.c.user_name] == 'jack')
-
- r = text("select * from query_users where user_id=2", bind=testing.db).execute().fetchone()
+
+ r = text("select * from query_users where user_id=2", bind=testing.db).execute().first()
self.assert_(r.user_id == r['user_id'] == r[users.c.user_id] == 2)
self.assert_(r.user_name == r['user_name'] == r[users.c.user_name] == 'jack')
-
+
# test slices
- r = text("select * from query_addresses", bind=testing.db).execute().fetchone()
+ r = text("select * from query_addresses", bind=testing.db).execute().first()
self.assert_(r[0:1] == (1,))
self.assert_(r[1:] == (2, 'foo@bar.com'))
self.assert_(r[:-1] == (1, 2))
-
+
# test a little sqlite weirdness - with the UNION, cols come back as "query_users.user_id" in cursor.description
r = text("select query_users.user_id, query_users.user_name from query_users "
- "UNION select query_users.user_id, query_users.user_name from query_users", bind=testing.db).execute().fetchone()
+ "UNION select query_users.user_id, query_users.user_name from query_users", bind=testing.db).execute().first()
self.assert_(r['user_id']) == 1
self.assert_(r['user_name']) == "john"
# test using literal tablename.colname
- r = text('select query_users.user_id AS "query_users.user_id", query_users.user_name AS "query_users.user_name" from query_users', bind=testing.db).execute().fetchone()
+ r = text('select query_users.user_id AS "query_users.user_id", '
+ 'query_users.user_name AS "query_users.user_name" from query_users',
+ bind=testing.db).execute().first()
self.assert_(r['query_users.user_id']) == 1
self.assert_(r['query_users.user_name']) == "john"
# unary experssions
- r = select([users.c.user_name.distinct()]).order_by(users.c.user_name).execute().fetchone()
+ r = select([users.c.user_name.distinct()]).order_by(users.c.user_name).execute().first()
eq_(r[users.c.user_name], 'jack')
eq_(r.user_name, 'jack')
- r.close()
+
+ def test_result_case_sensitivity(self):
+ """test name normalization for result sets."""
+ row = testing.db.execute(
+ select([
+ literal_column("1").label("case_insensitive"),
+ literal_column("2").label("CaseSensitive")
+ ])
+ ).first()
+
+ assert row.keys() == ["case_insensitive", "CaseSensitive"]
+
def test_row_as_args(self):
users.insert().execute(user_id=1, user_name='john')
- r = users.select(users.c.user_id==1).execute().fetchone()
+ r = users.select(users.c.user_id==1).execute().first()
users.delete().execute()
users.insert().execute(r)
- assert users.select().execute().fetchall() == [(1, 'john')]
-
+ eq_(users.select().execute().fetchall(), [(1, 'john')])
+
def test_result_as_args(self):
users.insert().execute([dict(user_id=1, user_name='john'), dict(user_id=2, user_name='ed')])
r = users.select().execute()
def test_ambiguous_column(self):
users.insert().execute(user_id=1, user_name='john')
- r = users.outerjoin(addresses).select().execute().fetchone()
- try:
- print r['user_id']
- assert False
- except exc.InvalidRequestError, e:
- assert str(e) == "Ambiguous column name 'user_id' in result set! try 'use_labels' option on select statement." or \
- str(e) == "Ambiguous column name 'USER_ID' in result set! try 'use_labels' option on select statement."
+ r = users.outerjoin(addresses).select().execute().first()
+ assert_raises_message(
+ exc.InvalidRequestError,
+ "Ambiguous column name",
+ lambda: r['user_id']
+ )
@testing.requires.subqueries
def test_column_label_targeting(self):
users.select().alias('foo'),
users.select().alias(users.name),
):
- row = s.select(use_labels=True).execute().fetchone()
+ row = s.select(use_labels=True).execute().first()
assert row[s.c.user_id] == 7
assert row[s.c.user_name] == 'ed'
def test_keys(self):
users.insert().execute(user_id=1, user_name='foo')
- r = users.select().execute().fetchone()
+ r = users.select().execute().first()
eq_([x.lower() for x in r.keys()], ['user_id', 'user_name'])
def test_items(self):
users.insert().execute(user_id=1, user_name='foo')
- r = users.select().execute().fetchone()
+ r = users.select().execute().first()
eq_([(x[0].lower(), x[1]) for x in r.items()], [('user_id', 1), ('user_name', 'foo')])
def test_len(self):
users.insert().execute(user_id=1, user_name='foo')
- r = users.select().execute().fetchone()
+ r = users.select().execute().first()
eq_(len(r), 2)
- r.close()
- r = testing.db.execute('select user_name, user_id from query_users').fetchone()
+
+ r = testing.db.execute('select user_name, user_id from query_users').first()
eq_(len(r), 2)
- r.close()
- r = testing.db.execute('select user_name from query_users').fetchone()
+ r = testing.db.execute('select user_name from query_users').first()
eq_(len(r), 1)
- r.close()
def test_cant_execute_join(self):
try:
def test_column_order_with_simple_query(self):
# should return values in column definition order
users.insert().execute(user_id=1, user_name='foo')
- r = users.select(users.c.user_id==1).execute().fetchone()
+ r = users.select(users.c.user_id==1).execute().first()
eq_(r[0], 1)
eq_(r[1], 'foo')
eq_([x.lower() for x in r.keys()], ['user_id', 'user_name'])
def test_column_order_with_text_query(self):
# should return values in query order
users.insert().execute(user_id=1, user_name='foo')
- r = testing.db.execute('select user_name, user_id from query_users').fetchone()
+ r = testing.db.execute('select user_name, user_id from query_users').first()
eq_(r[0], 'foo')
eq_(r[1], 1)
eq_([x.lower() for x in r.keys()], ['user_name', 'user_id'])
shadowed.create(checkfirst=True)
try:
shadowed.insert().execute(shadow_id=1, shadow_name='The Shadow', parent='The Light', row='Without light there is no shadow', __parent='Hidden parent', __row='Hidden row')
- r = shadowed.select(shadowed.c.shadow_id==1).execute().fetchone()
+ r = shadowed.select(shadowed.c.shadow_id==1).execute().first()
self.assert_(r.shadow_id == r['shadow_id'] == r[shadowed.c.shadow_id] == 1)
self.assert_(r.shadow_name == r['shadow_name'] == r[shadowed.c.shadow_name] == 'The Shadow')
self.assert_(r.parent == r['parent'] == r[shadowed.c.parent] == 'The Light')
# Null values are not outside any set
assert len(r) == 0
- u = bindparam('search_key')
+ @testing.fails_on('firebird', "kinterbasdb doesn't send full type information")
+ def test_bind_in(self):
+ users.insert().execute(user_id = 7, user_name = 'jack')
+ users.insert().execute(user_id = 8, user_name = 'fred')
+ users.insert().execute(user_id = 9, user_name = None)
- s = users.select(u.in_([]))
- r = s.execute(search_key='john').fetchall()
- assert len(r) == 0
- r = s.execute(search_key=None).fetchall()
- assert len(r) == 0
+ u = bindparam('search_key')
s = users.select(not_(u.in_([])))
r = s.execute(search_key='john').fetchall()
class PercentSchemaNamesTest(TestBase):
"""tests using percent signs, spaces in table and column names.
- Doesn't pass for mysql, postgres, but this is really a
+ Doesn't pass for mysql, postgresql, but this is really a
SQLAlchemy bug - we should be escaping out %% signs for this
operation the same way we do for text() and column labels.
"""
+
@classmethod
@testing.crashes('mysql', 'mysqldb calls name % (params)')
- @testing.crashes('postgres', 'postgres calls name % (params)')
+ @testing.crashes('postgresql', 'postgresql calls name % (params)')
def setup_class(cls):
global percent_table, metadata
metadata = MetaData(testing.db)
@classmethod
@testing.crashes('mysql', 'mysqldb calls name % (params)')
- @testing.crashes('postgres', 'postgres calls name % (params)')
+ @testing.crashes('postgresql', 'postgresql calls name % (params)')
def teardown_class(cls):
metadata.drop_all()
@testing.crashes('mysql', 'mysqldb calls name % (params)')
- @testing.crashes('postgres', 'postgres calls name % (params)')
+ @testing.crashes('postgresql', 'postgresql calls name % (params)')
def test_roundtrip(self):
percent_table.insert().execute(
{'percent%':5, '%(oneofthese)s':7, 'spaces % more spaces':12},
percent_table.update().values({percent_table.c['%(oneofthese)s']:9, percent_table.c['spaces % more spaces']:15}).execute()
eq_(
- percent_table.select().order_by(percent_table.c['%(oneofthese)s']).execute().fetchall(),
+ percent_table.select().order_by(percent_table.c['percent%']).execute().fetchall(),
[
(5, 9, 15),
(7, 9, 15),
dict(col2="t3col2r2", col3="bbb", col4="aaa"),
dict(col2="t3col2r3", col3="ccc", col4="bbb"),
])
-
+
+ @engines.close_first
+ def teardown(self):
+ pass
+
@classmethod
def teardown_class(cls):
metadata.drop_all()
found2 = self._fetchall_sorted(u.alias('bar').select().execute())
eq_(found2, wanted)
+ @testing.fails_on('firebird', "doesn't like ORDER BY with UNIONs")
def test_union_ordered(self):
(s1, s2) = (
select([t1.c.col3.label('col3'), t1.c.col4.label('col4')],
('ccc', 'aaa')]
eq_(u.execute().fetchall(), wanted)
+ @testing.fails_on('firebird', "doesn't like ORDER BY with UNIONs")
@testing.fails_on('maxdb', 'FIXME: unknown')
@testing.requires.subqueries
def test_union_ordered_alias(self):
eq_(u.alias('bar').select().execute().fetchall(), wanted)
@testing.crashes('oracle', 'FIXME: unknown, verify not fails_on')
+ @testing.fails_on('firebird', "has trouble extracting anonymous column from union subquery")
@testing.fails_on('mysql', 'FIXME: unknown')
@testing.fails_on('sqlite', 'FIXME: unknown')
def test_union_all(self):
found2 = self._fetchall_sorted(e.alias('foo').select().execute())
eq_(found2, wanted)
+ def test_union_all_lightweight(self):
+ """like test_union_all, but breaks the sub-union into
+ a subquery with an explicit column reference on the outside,
+ more palatable to a wider variety of engines.
+
+ """
+ u = union(
+ select([t1.c.col3]),
+ select([t1.c.col3]),
+ ).alias()
+
+ e = union_all(
+ select([t1.c.col3]),
+ select([u.c.col3])
+ )
+
+ wanted = [('aaa',),('aaa',),('bbb',), ('bbb',), ('ccc',),('ccc',)]
+ found1 = self._fetchall_sorted(e.execute())
+ eq_(found1, wanted)
+
+ found2 = self._fetchall_sorted(e.alias('foo').select().execute())
+ eq_(found2, wanted)
+
@testing.crashes('firebird', 'Does not support intersect')
@testing.crashes('sybase', 'FIXME: unknown, verify not fails_on')
@testing.fails_on('mysql', 'FIXME: unknown')
order_by=flds.c.idcol).execute().fetchall(),
[(2,),(1,)]
)
+
+
+
def testlabels(self):
"""test the quoting of labels.
- if labels arent quoted, a query in postgres in particular will fail since it produces:
+ if labels arent quoted, a query in postgresql in particular will fail since it produces:
SELECT LaLa.lowercase, LaLa."UPPERCASE", LaLa."MixedCase", LaLa."ASC"
FROM (SELECT DISTINCT "WorstCase1".lowercase AS lowercase, "WorstCase1"."UPPERCASE" AS UPPERCASE, "WorstCase1"."MixedCase" AS MixedCase, "WorstCase1"."ASC" AS ASC \nFROM "WorstCase1") AS LaLa
--- /dev/null
+from sqlalchemy.test.testing import eq_
+from sqlalchemy import *
+from sqlalchemy.test import *
+from sqlalchemy.test.schema import Table, Column
+from sqlalchemy.types import TypeDecorator
+
+
+class ReturningTest(TestBase, AssertsExecutionResults):
+ __unsupported_on__ = ('sqlite', 'mysql', 'maxdb', 'sybase', 'access')
+
+ def setup(self):
+ meta = MetaData(testing.db)
+ global table, GoofyType
+
+ class GoofyType(TypeDecorator):
+ impl = String
+
+ def process_bind_param(self, value, dialect):
+ if value is None:
+ return None
+ return "FOO" + value
+
+ def process_result_value(self, value, dialect):
+ if value is None:
+ return None
+ return value + "BAR"
+
+ table = Table('tables', meta,
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
+ Column('persons', Integer),
+ Column('full', Boolean),
+ Column('goofy', GoofyType(50))
+ )
+ table.create(checkfirst=True)
+
+ def teardown(self):
+ table.drop()
+
+ @testing.exclude('firebird', '<', (2, 0), '2.0+ feature')
+ @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature')
+ def test_column_targeting(self):
+ result = table.insert().returning(table.c.id, table.c.full).execute({'persons': 1, 'full': False})
+
+ row = result.first()
+ assert row[table.c.id] == row['id'] == 1
+ assert row[table.c.full] == row['full'] == False
+
+ result = table.insert().values(persons=5, full=True, goofy="somegoofy").\
+ returning(table.c.persons, table.c.full, table.c.goofy).execute()
+ row = result.first()
+ assert row[table.c.persons] == row['persons'] == 5
+ assert row[table.c.full] == row['full'] == True
+ assert row[table.c.goofy] == row['goofy'] == "FOOsomegoofyBAR"
+
+ @testing.fails_on('firebird', "fb can't handle returning x AS y")
+ @testing.exclude('firebird', '<', (2, 0), '2.0+ feature')
+ @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature')
+ def test_labeling(self):
+ result = table.insert().values(persons=6).\
+ returning(table.c.persons.label('lala')).execute()
+ row = result.first()
+ assert row['lala'] == 6
+
+ @testing.fails_on('firebird', "fb/kintersbasdb can't handle the bind params")
+ @testing.exclude('firebird', '<', (2, 0), '2.0+ feature')
+ @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature')
+ def test_anon_expressions(self):
+ result = table.insert().values(goofy="someOTHERgoofy").\
+ returning(func.lower(table.c.goofy, type_=GoofyType)).execute()
+ row = result.first()
+ assert row[0] == "foosomeothergoofyBAR"
+
+ result = table.insert().values(persons=12).\
+ returning(table.c.persons + 18).execute()
+ row = result.first()
+ assert row[0] == 30
+
+ @testing.exclude('firebird', '<', (2, 1), '2.1+ feature')
+ @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature')
+ def test_update_returning(self):
+ table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}])
+
+ result = table.update(table.c.persons > 4, dict(full=True)).returning(table.c.id).execute()
+ eq_(result.fetchall(), [(1,)])
+
+ result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute()
+ eq_(result2.fetchall(), [(1,True),(2,False)])
+
+ @testing.exclude('firebird', '<', (2, 0), '2.0+ feature')
+ @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature')
+ def test_insert_returning(self):
+ result = table.insert().returning(table.c.id).execute({'persons': 1, 'full': False})
+
+ eq_(result.fetchall(), [(1,)])
+
+ @testing.fails_on('postgresql', '')
+ @testing.fails_on('oracle', '')
+ def test_executemany():
+ # return value is documented as failing with psycopg2/executemany
+ result2 = table.insert().returning(table).execute(
+ [{'persons': 2, 'full': False}, {'persons': 3, 'full': True}])
+
+ if testing.against('firebird', 'mssql'):
+ # Multiple inserts only return the last row
+ eq_(result2.fetchall(), [(3,3,True, None)])
+ else:
+ # nobody does this as far as we know (pg8000?)
+ eq_(result2.fetchall(), [(2, 2, False, None), (3,3,True, None)])
+
+ test_executemany()
+
+ result3 = table.insert().returning(table.c.id).execute({'persons': 4, 'full': False})
+ eq_([dict(row) for row in result3], [{'id': 4}])
+
+
+ @testing.exclude('firebird', '<', (2, 1), '2.1+ feature')
+ @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature')
+ @testing.fails_on_everything_except('postgresql', 'firebird')
+ def test_literal_returning(self):
+ if testing.against("postgresql"):
+ literal_true = "true"
+ else:
+ literal_true = "1"
+
+ result4 = testing.db.execute('insert into tables (id, persons, "full") '
+ 'values (5, 10, %s) returning persons' % literal_true)
+ eq_([dict(row) for row in result4], [{'persons': 10}])
+
+ @testing.exclude('firebird', '<', (2, 1), '2.1+ feature')
+ @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature')
+ def test_delete_returning(self):
+ table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}])
+
+ result = table.delete(table.c.persons > 4).returning(table.c.id).execute()
+ eq_(result.fetchall(), [(1,)])
+
+ result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute()
+ eq_(result2.fetchall(), [(2,False),])
+
+class SequenceReturningTest(TestBase):
+ __unsupported_on__ = ('sqlite', 'mysql', 'maxdb', 'sybase', 'access', 'mssql')
+
+ def setup(self):
+ meta = MetaData(testing.db)
+ global table, seq
+ seq = Sequence('tid_seq')
+ table = Table('tables', meta,
+ Column('id', Integer, seq, primary_key=True),
+ Column('data', String(50))
+ )
+ table.create(checkfirst=True)
+
+ def teardown(self):
+ table.drop()
+
+ def test_insert(self):
+ r = table.insert().values(data='hi').returning(table.c.id).execute()
+ assert r.first() == (1, )
+ assert seq.execute() == 2
from sqlalchemy.sql import table, column, label, compiler
from sqlalchemy.sql.expression import ClauseList
from sqlalchemy.engine import default
-from sqlalchemy.databases import sqlite, postgres, mysql, oracle, firebird, mssql
+from sqlalchemy.databases import *
from sqlalchemy.test import *
table1 = table('mytable',
)
self.assert_compile(
- select([cast("data", sqlite.SLInteger)], use_labels=True), # this will work with plain Integer in 0.6
+ select([cast("data", Integer)], use_labels=True), # this will work with plain Integer in 0.6
"SELECT CAST(:param_1 AS INTEGER) AS anon_1"
)
-
-
def test_nested_uselabels(self):
"""test nested anonymous label generation. this
essentially tests the ANONYMOUS_LABEL regex.
def test_operators(self):
for (py_op, sql_op) in ((operator.add, '+'), (operator.mul, '*'),
- (operator.sub, '-'), (operator.div, '/'),
+ (operator.sub, '-'),
+ # Py3K
+ #(operator.truediv, '/'),
+ # Py2K
+ (operator.div, '/'),
+ # end Py2K
):
for (lhs, rhs, res) in (
(5, table1.c.myid, ':myid_1 %s mytable.myid'),
(~table1.c.myid.like('somstr', escape='\\'), "mytable.myid NOT LIKE :myid_1 ESCAPE '\\'", None),
(table1.c.myid.ilike('somstr', escape='\\'), "lower(mytable.myid) LIKE lower(:myid_1) ESCAPE '\\'", None),
(~table1.c.myid.ilike('somstr', escape='\\'), "lower(mytable.myid) NOT LIKE lower(:myid_1) ESCAPE '\\'", None),
- (table1.c.myid.ilike('somstr', escape='\\'), "mytable.myid ILIKE %(myid_1)s ESCAPE '\\'", postgres.PGDialect()),
- (~table1.c.myid.ilike('somstr', escape='\\'), "mytable.myid NOT ILIKE %(myid_1)s ESCAPE '\\'", postgres.PGDialect()),
+ (table1.c.myid.ilike('somstr', escape='\\'), "mytable.myid ILIKE %(myid_1)s ESCAPE '\\'", postgresql.PGDialect()),
+ (~table1.c.myid.ilike('somstr', escape='\\'), "mytable.myid NOT ILIKE %(myid_1)s ESCAPE '\\'", postgresql.PGDialect()),
(table1.c.name.ilike('%something%'), "lower(mytable.name) LIKE lower(:name_1)", None),
- (table1.c.name.ilike('%something%'), "mytable.name ILIKE %(name_1)s", postgres.PGDialect()),
+ (table1.c.name.ilike('%something%'), "mytable.name ILIKE %(name_1)s", postgresql.PGDialect()),
(~table1.c.name.ilike('%something%'), "lower(mytable.name) NOT LIKE lower(:name_1)", None),
- (~table1.c.name.ilike('%something%'), "mytable.name NOT ILIKE %(name_1)s", postgres.PGDialect()),
+ (~table1.c.name.ilike('%something%'), "mytable.name NOT ILIKE %(name_1)s", postgresql.PGDialect()),
]:
self.assert_compile(expr, check, dialect=dialect)
def test_match(self):
for expr, check, dialect in [
(table1.c.myid.match('somstr'), "mytable.myid MATCH ?", sqlite.SQLiteDialect()),
- (table1.c.myid.match('somstr'), "MATCH (mytable.myid) AGAINST (%s IN BOOLEAN MODE)", mysql.MySQLDialect()),
- (table1.c.myid.match('somstr'), "CONTAINS (mytable.myid, :myid_1)", mssql.MSSQLDialect()),
- (table1.c.myid.match('somstr'), "mytable.myid @@ to_tsquery(%(myid_1)s)", postgres.PGDialect()),
- (table1.c.myid.match('somstr'), "CONTAINS (mytable.myid, :myid_1)", oracle.OracleDialect()),
+ (table1.c.myid.match('somstr'), "MATCH (mytable.myid) AGAINST (%s IN BOOLEAN MODE)", mysql.dialect()),
+ (table1.c.myid.match('somstr'), "CONTAINS (mytable.myid, :myid_1)", mssql.dialect()),
+ (table1.c.myid.match('somstr'), "mytable.myid @@ to_tsquery(%(myid_1)s)", postgresql.dialect()),
+ (table1.c.myid.match('somstr'), "CONTAINS (mytable.myid, :myid_1)", oracle.dialect()),
]:
self.assert_compile(expr, check, dialect=dialect)
select([table1.alias('foo')])
,"SELECT foo.myid, foo.name, foo.description FROM mytable AS foo")
- for dialect in (firebird.dialect(), oracle.dialect()):
+ for dialect in (oracle.dialect(),):
self.assert_compile(
select([table1.alias('foo')])
,"SELECT foo.myid, foo.name, foo.description FROM mytable foo"
params={},
)
- dialect = postgres.dialect()
+ dialect = postgresql.dialect()
self.assert_compile(
text("select * from foo where lala=:bar and hoho=:whee", bindparams=[bindparam('bar',4), bindparam('whee',7)]),
"select * from foo where lala=%(bar)s and hoho=%(whee)s",
self.assert_compile(stmt, expected_positional_stmt, dialect=sqlite.dialect())
nonpositional = stmt.compile()
positional = stmt.compile(dialect=sqlite.dialect())
- pp = positional.get_params()
+ pp = positional.params
assert [pp[k] for k in positional.positiontup] == expected_default_params_list
- assert nonpositional.get_params(**test_param_dict) == expected_test_params_dict, "expected :%s got %s" % (str(expected_test_params_dict), str(nonpositional.get_params(**test_param_dict)))
- pp = positional.get_params(**test_param_dict)
+ assert nonpositional.construct_params(test_param_dict) == expected_test_params_dict, "expected :%s got %s" % (str(expected_test_params_dict), str(nonpositional.get_params(**test_param_dict)))
+ pp = positional.construct_params(test_param_dict)
assert [pp[k] for k in positional.positiontup] == expected_test_params_list
# check that params() doesnt modify original statement
":myid_1) AS anon_1 FROM mytable WHERE mytable.myid = (SELECT mytable.myid FROM mytable WHERE mytable.myid = :myid_1)")
positional = s2.compile(dialect=sqlite.dialect())
- pp = positional.get_params()
+ pp = positional.params
assert [pp[k] for k in positional.positiontup] == [12, 12]
# check that conflicts with "unique" params are caught
params = dict(('in%d' % i, i) for i in range(total_params))
sql = 'text clause %s' % ', '.join(in_clause)
t = text(sql)
- assert len(t.bindparams) == total_params
+ eq_(len(t.bindparams), total_params)
c = t.compile()
pp = c.construct_params(params)
- assert len(set(pp)) == total_params
- assert len(set(pp.values())) == total_params
+ eq_(len(set(pp)), total_params, '%s %s' % (len(set(pp)), len(pp)))
+ eq_(len(set(pp.values())), total_params)
def test_bind_as_col(self):
eq_(str(cast(tbl.c.v1, Numeric).compile(dialect=dialect)), 'CAST(casttest.v1 AS %s)' %expected_results[0])
eq_(str(cast(tbl.c.v1, Numeric(12, 9)).compile(dialect=dialect)), 'CAST(casttest.v1 AS %s)' %expected_results[1])
eq_(str(cast(tbl.c.ts, Date).compile(dialect=dialect)), 'CAST(casttest.ts AS %s)' %expected_results[2])
- eq_(str(cast(1234, TEXT).compile(dialect=dialect)), 'CAST(%s AS %s)' %(literal, expected_results[3]))
+ eq_(str(cast(1234, Text).compile(dialect=dialect)), 'CAST(%s AS %s)' %(literal, expected_results[3]))
eq_(str(cast('test', String(20)).compile(dialect=dialect)), 'CAST(%s AS %s)' %(literal, expected_results[4]))
# fixme: shoving all of this dialect-specific stuff in one test
# is now officialy completely ridiculous AND non-obviously omits
# coverage on other dialects.
sel = select([tbl, cast(tbl.c.v1, Numeric)]).compile(dialect=dialect)
if isinstance(dialect, type(mysql.dialect())):
- eq_(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS DECIMAL(10, 2)) AS anon_1 \nFROM casttest")
+ eq_(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS DECIMAL) AS anon_1 \nFROM casttest")
else:
- eq_(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS NUMERIC(10, 2)) AS anon_1 \nFROM casttest")
+ eq_(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS NUMERIC) AS anon_1 \nFROM casttest")
# first test with PostgreSQL engine
- check_results(postgres.dialect(), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '%(param_1)s')
+ check_results(postgresql.dialect(), ['NUMERIC', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '%(param_1)s')
# then the Oracle engine
- check_results(oracle.dialect(), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'CLOB', 'VARCHAR(20)'], ':param_1')
+ check_results(oracle.dialect(), ['NUMERIC', 'NUMERIC(12, 9)', 'DATE', 'CLOB', 'VARCHAR(20)'], ':param_1')
# then the sqlite engine
- check_results(sqlite.dialect(), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '?')
+ check_results(sqlite.dialect(), ['NUMERIC', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '?')
# then the MySQL engine
- check_results(mysql.dialect(), ['DECIMAL(10, 2)', 'DECIMAL(12, 9)', 'DATE', 'CHAR', 'CHAR(20)'], '%s')
+ check_results(mysql.dialect(), ['DECIMAL', 'DECIMAL(12, 9)', 'DATE', 'CHAR', 'CHAR(20)'], '%s')
self.assert_compile(cast(text('NULL'), Integer), "CAST(NULL AS INTEGER)", dialect=sqlite.dialect())
self.assert_compile(cast(null(), Integer), "CAST(NULL AS INTEGER)", dialect=sqlite.dialect())
s1 = select([table1.c.myid, table1.c.myid.label('foobar'), func.hoho(table1.c.name), func.lala(table1.c.name).label('gg')])
assert s1.c.keys() == ['myid', 'foobar', 'hoho(mytable.name)', 'gg']
- from sqlalchemy.databases.sqlite import SLNumeric
meta = MetaData()
t1 = Table('mytable', meta, Column('col1', Integer))
(table1.c.name, 'name', 'mytable.name', None),
(table1.c.myid==12, 'mytable.myid = :myid_1', 'mytable.myid = :myid_1', 'anon_1'),
(func.hoho(table1.c.myid), 'hoho(mytable.myid)', 'hoho(mytable.myid)', 'hoho_1'),
- (cast(table1.c.name, SLNumeric), 'CAST(mytable.name AS NUMERIC(10, 2))', 'CAST(mytable.name AS NUMERIC(10, 2))', 'anon_1'),
+ (cast(table1.c.name, Numeric), 'CAST(mytable.name AS NUMERIC)', 'CAST(mytable.name AS NUMERIC)', 'anon_1'),
(t1.c.col1, 'col1', 'mytable.col1', None),
(column('some wacky thing'), 'some wacky thing', '"some wacky thing"', '')
):
Column('magazine_page_id', Integer, ForeignKey('magazine_page.page_id'), primary_key=True),
)
- # this is essentially the union formed by the ORM's polymorphic_union function.
+ # this is essentially the union formed by the ORM's polymorphic_union function.
# we define two versions with different ordering of selects.
# the first selectable has the "real" column classified_page.magazine_page_id
magazine_page_table.c.page_id,
cast(null(), Integer).label('magazine_page_id')
]).select_from(page_table.join(magazine_page_table)),
-
).alias('pjoin')
eq_(
+# coding: utf-8
from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
import decimal
import datetime, os, re
from sqlalchemy import *
-from sqlalchemy import exc, types, util
+from sqlalchemy import exc, types, util, schema
from sqlalchemy.sql import operators
from sqlalchemy.test.testing import eq_
import sqlalchemy.engine.url as url
-from sqlalchemy.databases import mssql, oracle, mysql, postgres, firebird
+from sqlalchemy.databases import *
+
from sqlalchemy.test import *
class AdaptTest(TestBase):
- def testadapt(self):
- e1 = url.URL('postgres').get_dialect()()
- e2 = url.URL('mysql').get_dialect()()
- e3 = url.URL('sqlite').get_dialect()()
- e4 = url.URL('firebird').get_dialect()()
-
- type = String(40)
-
- t1 = type.dialect_impl(e1)
- t2 = type.dialect_impl(e2)
- t3 = type.dialect_impl(e3)
- t4 = type.dialect_impl(e4)
-
- impls = [t1, t2, t3, t4]
- for i,ta in enumerate(impls):
- for j,tb in enumerate(impls):
- if i == j:
- assert ta == tb # call me paranoid... :)
+ def test_uppercase_rendering(self):
+ """Test that uppercase types from types.py always render as their type.
+
+ As of SQLA 0.6, using an uppercase type means you want specifically that
+ type. If the database in use doesn't support that DDL, it (the DB backend)
+ should raise an error - it means you should be using a lowercased (genericized) type.
+
+ """
+
+ for dialect in [
+ oracle.dialect(),
+ mysql.dialect(),
+ postgresql.dialect(),
+ sqlite.dialect(),
+ sybase.dialect(),
+ informix.dialect(),
+ maxdb.dialect(),
+ mssql.dialect()]: # TODO when dialects are complete: engines.all_dialects():
+ for type_, expected in (
+ (FLOAT, "FLOAT"),
+ (NUMERIC, "NUMERIC"),
+ (DECIMAL, "DECIMAL"),
+ (INTEGER, "INTEGER"),
+ (SMALLINT, "SMALLINT"),
+ (TIMESTAMP, "TIMESTAMP"),
+ (DATETIME, "DATETIME"),
+ (DATE, "DATE"),
+ (TIME, "TIME"),
+ (CLOB, "CLOB"),
+ (VARCHAR, "VARCHAR"),
+ (NVARCHAR, ("NVARCHAR", "NATIONAL VARCHAR")),
+ (CHAR, "CHAR"),
+ (NCHAR, ("NCHAR", "NATIONAL CHAR")),
+ (BLOB, "BLOB"),
+ (BOOLEAN, ("BOOLEAN", "BOOL"))
+ ):
+ if isinstance(expected, str):
+ expected = (expected, )
+ for exp in expected:
+ compiled = type_().compile(dialect=dialect)
+ if exp in compiled:
+ break
else:
- assert ta != tb
-
- def testmsnvarchar(self):
- dialect = mssql.MSSQLDialect()
- # run the test twice to ensure the caching step works too
- for x in range(0, 1):
- col = Column('', Unicode(length=10))
- dialect_type = col.type.dialect_impl(dialect)
- assert isinstance(dialect_type, mssql.MSNVarchar)
- assert dialect_type.get_col_spec() == 'NVARCHAR(10)'
-
-
- def testoracletimestamp(self):
- dialect = oracle.OracleDialect()
- t1 = oracle.OracleTimestamp
- t2 = oracle.OracleTimestamp()
- t3 = types.TIMESTAMP
- assert isinstance(dialect.type_descriptor(t1), oracle.OracleTimestamp)
- assert isinstance(dialect.type_descriptor(t2), oracle.OracleTimestamp)
- assert isinstance(dialect.type_descriptor(t3), oracle.OracleTimestamp)
-
- def testmysqlbinary(self):
- dialect = mysql.MySQLDialect()
- t1 = mysql.MSVarBinary
- t2 = mysql.MSVarBinary()
- assert isinstance(dialect.type_descriptor(t1), mysql.MSVarBinary)
- assert isinstance(dialect.type_descriptor(t2), mysql.MSVarBinary)
-
- def teststringadapt(self):
- """test that String with no size becomes TEXT, *all* others stay as varchar/String"""
-
- oracle_dialect = oracle.OracleDialect()
- mysql_dialect = mysql.MySQLDialect()
- postgres_dialect = postgres.PGDialect()
- firebird_dialect = firebird.FBDialect()
-
- for dialect, start, test in [
- (oracle_dialect, String(), oracle.OracleString),
- (oracle_dialect, VARCHAR(), oracle.OracleString),
- (oracle_dialect, String(50), oracle.OracleString),
- (oracle_dialect, Unicode(), oracle.OracleString),
- (oracle_dialect, UnicodeText(), oracle.OracleText),
- (oracle_dialect, NCHAR(), oracle.OracleString),
- (oracle_dialect, oracle.OracleRaw(50), oracle.OracleRaw),
- (mysql_dialect, String(), mysql.MSString),
- (mysql_dialect, VARCHAR(), mysql.MSString),
- (mysql_dialect, String(50), mysql.MSString),
- (mysql_dialect, Unicode(), mysql.MSString),
- (mysql_dialect, UnicodeText(), mysql.MSText),
- (mysql_dialect, NCHAR(), mysql.MSNChar),
- (postgres_dialect, String(), postgres.PGString),
- (postgres_dialect, VARCHAR(), postgres.PGString),
- (postgres_dialect, String(50), postgres.PGString),
- (postgres_dialect, Unicode(), postgres.PGString),
- (postgres_dialect, UnicodeText(), postgres.PGText),
- (postgres_dialect, NCHAR(), postgres.PGString),
- (firebird_dialect, String(), firebird.FBString),
- (firebird_dialect, VARCHAR(), firebird.FBString),
- (firebird_dialect, String(50), firebird.FBString),
- (firebird_dialect, Unicode(), firebird.FBString),
- (firebird_dialect, UnicodeText(), firebird.FBText),
- (firebird_dialect, NCHAR(), firebird.FBString),
- ]:
- assert isinstance(start.dialect_impl(dialect), test), "wanted %r got %r" % (test, start.dialect_impl(dialect))
-
-
+ assert False, "%r matches none of %r for dialect %s" % (compiled, expected, dialect.name)
+
class UserDefinedTest(TestBase):
"""tests user-defined types."""
def setup_class(cls):
global users, metadata
- class MyType(types.TypeEngine):
+ class MyType(types.UserDefinedType):
def get_col_spec(self):
return "VARCHAR(100)"
def bind_processor(self, dialect):
for aCol in testTable.c:
eq_(
expectedResults[aCol.name],
- db.dialect.schemagenerator(db.dialect, db, None, None).\
+ db.dialect.ddl_compiler(db.dialect, schema.CreateTable(testTable)).\
get_column_specification(aCol))
class UnicodeTest(TestBase, AssertsExecutionResults):
"""tests the Unicode type. also tests the TypeDecorator with instances in the types package."""
+
@classmethod
def setup_class(cls):
- global unicode_table
+ global unicode_table, metadata
metadata = MetaData(testing.db)
unicode_table = Table('unicode_table', metadata,
Column('id', Integer, Sequence('uni_id_seq', optional=True), primary_key=True),
Column('unicode_varchar', Unicode(250)),
Column('unicode_text', UnicodeText),
- Column('plain_varchar', String(250))
)
- unicode_table.create()
+ metadata.create_all()
+
@classmethod
def teardown_class(cls):
- unicode_table.drop()
+ metadata.drop_all()
+ @engines.close_first
def teardown(self):
unicode_table.delete().execute()
def test_round_trip(self):
- assert unicode_table.c.unicode_varchar.type.length == 250
- rawdata = 'Alors vous imaginez ma surprise, au lever du jour, quand une dr\xc3\xb4le de petit voix m\xe2\x80\x99a r\xc3\xa9veill\xc3\xa9. Elle disait: \xc2\xab S\xe2\x80\x99il vous pla\xc3\xaet\xe2\x80\xa6 dessine-moi un mouton! \xc2\xbb\n'
- unicodedata = rawdata.decode('utf-8')
- if testing.against('sqlite'):
- rawdata = "something"
-
- unicode_table.insert().execute(unicode_varchar=unicodedata,
- unicode_text=unicodedata,
- plain_varchar=rawdata)
- x = unicode_table.select().execute().fetchone()
+ unicodedata = u"Alors vous imaginez ma surprise, au lever du jour, quand une drôle de petit voix m’a réveillé. Elle disait: « S’il vous plaît… dessine-moi un mouton! »"
+
+ unicode_table.insert().execute(unicode_varchar=unicodedata,unicode_text=unicodedata)
+
+ x = unicode_table.select().execute().first()
self.assert_(isinstance(x['unicode_varchar'], unicode) and x['unicode_varchar'] == unicodedata)
self.assert_(isinstance(x['unicode_text'], unicode) and x['unicode_text'] == unicodedata)
- if isinstance(x['plain_varchar'], unicode):
- # SQLLite and MSSQL return non-unicode data as unicode
- self.assert_(testing.against('sqlite', 'mssql'))
- if not testing.against('sqlite'):
- self.assert_(x['plain_varchar'] == unicodedata)
- else:
- self.assert_(not isinstance(x['plain_varchar'], unicode) and x['plain_varchar'] == rawdata)
- def test_union(self):
- """ensure compiler processing works for UNIONs"""
+ def test_round_trip_executemany(self):
+ # cx_oracle was producing different behavior for cursor.executemany()
+ # vs. cursor.execute()
+
+ unicodedata = u"Alors vous imaginez ma surprise, au lever du jour, quand une drôle de petit voix m’a réveillé. Elle disait: « S’il vous plaît… dessine-moi un mouton! »"
- rawdata = 'Alors vous imaginez ma surprise, au lever du jour, quand une dr\xc3\xb4le de petit voix m\xe2\x80\x99a r\xc3\xa9veill\xc3\xa9. Elle disait: \xc2\xab S\xe2\x80\x99il vous pla\xc3\xaet\xe2\x80\xa6 dessine-moi un mouton! \xc2\xbb\n'
- unicodedata = rawdata.decode('utf-8')
- if testing.against('sqlite'):
- rawdata = "something"
- unicode_table.insert().execute(unicode_varchar=unicodedata,
- unicode_text=unicodedata,
- plain_varchar=rawdata)
-
- x = union(select([unicode_table.c.unicode_varchar]), select([unicode_table.c.unicode_varchar])).execute().fetchone()
+ unicode_table.insert().execute(
+ dict(unicode_varchar=unicodedata,unicode_text=unicodedata),
+ dict(unicode_varchar=unicodedata,unicode_text=unicodedata)
+ )
+
+ x = unicode_table.select().execute().first()
self.assert_(isinstance(x['unicode_varchar'], unicode) and x['unicode_varchar'] == unicodedata)
+ self.assert_(isinstance(x['unicode_text'], unicode) and x['unicode_text'] == unicodedata)
- def test_assertions(self):
- try:
- unicode_table.insert().execute(unicode_varchar='not unicode')
- assert False
- except exc.SAWarning, e:
- assert str(e) == "Unicode type received non-unicode bind param value 'not unicode'", str(e)
+ def test_union(self):
+ """ensure compiler processing works for UNIONs"""
- unicode_engine = engines.utf8_engine(options={'convert_unicode':True,
- 'assert_unicode':True})
- try:
- try:
- unicode_engine.execute(unicode_table.insert(), plain_varchar='im not unicode')
- assert False
- except exc.InvalidRequestError, e:
- assert str(e) == "Unicode type received non-unicode bind param value 'im not unicode'"
-
- @testing.emits_warning('.*non-unicode bind')
- def warns():
- # test that data still goes in if warning is emitted....
- unicode_table.insert().execute(unicode_varchar='not unicode')
- assert (select([unicode_table.c.unicode_varchar]).execute().fetchall() == [('not unicode', )])
- warns()
+ unicodedata = u"Alors vous imaginez ma surprise, au lever du jour, quand une drôle de petit voix m’a réveillé. Elle disait: « S’il vous plaît… dessine-moi un mouton! »"
- finally:
- unicode_engine.dispose()
+ unicode_table.insert().execute(unicode_varchar=unicodedata,unicode_text=unicodedata)
+
+ x = union(select([unicode_table.c.unicode_varchar]), select([unicode_table.c.unicode_varchar])).execute().first()
+ self.assert_(isinstance(x['unicode_varchar'], unicode) and x['unicode_varchar'] == unicodedata)
- @testing.fails_on('oracle', 'FIXME: unknown')
+ @testing.fails_on('oracle', 'oracle converts empty strings to a blank space')
def test_blank_strings(self):
unicode_table.insert().execute(unicode_varchar=u'')
assert select([unicode_table.c.unicode_varchar]).scalar() == u''
- def test_engine_parameter(self):
- """tests engine-wide unicode conversion"""
- prev_unicode = testing.db.engine.dialect.convert_unicode
- prev_assert = testing.db.engine.dialect.assert_unicode
- try:
- testing.db.engine.dialect.convert_unicode = True
- testing.db.engine.dialect.assert_unicode = False
- rawdata = 'Alors vous imaginez ma surprise, au lever du jour, quand une dr\xc3\xb4le de petit voix m\xe2\x80\x99a r\xc3\xa9veill\xc3\xa9. Elle disait: \xc2\xab S\xe2\x80\x99il vous pla\xc3\xaet\xe2\x80\xa6 dessine-moi un mouton! \xc2\xbb\n'
- unicodedata = rawdata.decode('utf-8')
- if testing.against('sqlite', 'mssql'):
- rawdata = "something"
- unicode_table.insert().execute(unicode_varchar=unicodedata,
- unicode_text=unicodedata,
- plain_varchar=rawdata)
- x = unicode_table.select().execute().fetchone()
- self.assert_(isinstance(x['unicode_varchar'], unicode) and x['unicode_varchar'] == unicodedata)
- self.assert_(isinstance(x['unicode_text'], unicode) and x['unicode_text'] == unicodedata)
- if not testing.against('sqlite', 'mssql'):
- self.assert_(isinstance(x['plain_varchar'], unicode) and x['plain_varchar'] == unicodedata)
- finally:
- testing.db.engine.dialect.convert_unicode = prev_unicode
- testing.db.engine.dialect.convert_unicode = prev_assert
-
- @testing.crashes('oracle', 'FIXME: unknown, verify not fails_on')
- @testing.fails_on('firebird', 'Data type unknown')
- def test_length_function(self):
- """checks the database correctly understands the length of a unicode string"""
- teststr = u'aaa\x1234'
- self.assert_(testing.db.func.length(teststr).scalar() == len(teststr))
+ def test_parameters(self):
+ """test the dialect convert_unicode parameters."""
+
+ unicodedata = u"Alors vous imaginez ma surprise, au lever du jour, quand une drôle de petit voix m’a réveillé. Elle disait: « S’il vous plaît… dessine-moi un mouton! »"
+
+ u = Unicode(assert_unicode=True)
+ uni = u.dialect_impl(testing.db.dialect).bind_processor(testing.db.dialect)
+ # Py3K
+ #assert_raises(exc.InvalidRequestError, uni, b'x')
+ # Py2K
+ assert_raises(exc.InvalidRequestError, uni, 'x')
+ # end Py2K
+
+ u = Unicode()
+ uni = u.dialect_impl(testing.db.dialect).bind_processor(testing.db.dialect)
+ # Py3K
+ #assert_raises(exc.SAWarning, uni, b'x')
+ # Py2K
+ assert_raises(exc.SAWarning, uni, 'x')
+ # end Py2K
+
+ unicode_engine = engines.utf8_engine(options={'convert_unicode':True,'assert_unicode':True})
+ unicode_engine.dialect.supports_unicode_binds = False
+
+ s = String()
+ uni = s.dialect_impl(unicode_engine.dialect).bind_processor(unicode_engine.dialect)
+ # Py3K
+ #assert_raises(exc.InvalidRequestError, uni, b'x')
+ #assert isinstance(uni(unicodedata), bytes)
+ # Py2K
+ assert_raises(exc.InvalidRequestError, uni, 'x')
+ assert isinstance(uni(unicodedata), str)
+ # end Py2K
+
+ assert uni(unicodedata) == unicodedata.encode('utf-8')
class BinaryTest(TestBase, AssertsExecutionResults):
__excluded_on__ = (
return value
binary_table = Table('binary_table', MetaData(testing.db),
- Column('primary_id', Integer, Sequence('binary_id_seq', optional=True), primary_key=True),
- Column('data', Binary),
- Column('data_slice', Binary(100)),
- Column('misc', String(30)),
- # construct PickleType with non-native pickle module, since cPickle uses relative module
- # loading and confuses this test's parent package 'sql' with the 'sqlalchemy.sql' package relative
- # to the 'types' module
- Column('pickled', PickleType),
- Column('mypickle', MyPickleType)
+ Column('primary_id', Integer, Sequence('binary_id_seq', optional=True), primary_key=True),
+ Column('data', Binary),
+ Column('data_slice', Binary(100)),
+ Column('misc', String(30)),
+ # construct PickleType with non-native pickle module, since cPickle uses relative module
+ # loading and confuses this test's parent package 'sql' with the 'sqlalchemy.sql' package relative
+ # to the 'types' module
+ Column('pickled', PickleType),
+ Column('mypickle', MyPickleType)
)
binary_table.create()
+ @engines.close_first
def teardown(self):
binary_table.delete().execute()
def teardown_class(cls):
binary_table.drop()
- @testing.fails_on('mssql', 'MSSQl BINARY type right pads the fixed length with \x00')
- def testbinary(self):
+ def test_round_trip(self):
testobj1 = pickleable.Foo('im foo 1')
testobj2 = pickleable.Foo('im foo 2')
testobj3 = pickleable.Foo('im foo 3')
stream1 =self.load_stream('binary_data_one.dat')
stream2 =self.load_stream('binary_data_two.dat')
- binary_table.insert().execute(primary_id=1, misc='binary_data_one.dat', data=stream1, data_slice=stream1[0:100], pickled=testobj1, mypickle=testobj3)
- binary_table.insert().execute(primary_id=2, misc='binary_data_two.dat', data=stream2, data_slice=stream2[0:99], pickled=testobj2)
- binary_table.insert().execute(primary_id=3, misc='binary_data_two.dat', data=None, data_slice=stream2[0:99], pickled=None)
+ binary_table.insert().execute(
+ primary_id=1,
+ misc='binary_data_one.dat',
+ data=stream1,
+ data_slice=stream1[0:100],
+ pickled=testobj1,
+ mypickle=testobj3)
+ binary_table.insert().execute(
+ primary_id=2,
+ misc='binary_data_two.dat',
+ data=stream2,
+ data_slice=stream2[0:99],
+ pickled=testobj2)
+ binary_table.insert().execute(
+ primary_id=3,
+ misc='binary_data_two.dat',
+ data=None,
+ data_slice=stream2[0:99],
+ pickled=None)
for stmt in (
binary_table.select(order_by=binary_table.c.primary_id),
text("select * from binary_table order by binary_table.primary_id", typemap={'pickled':PickleType, 'mypickle':MyPickleType}, bind=testing.db)
):
+ eq_data = lambda x, y: eq_(list(x), list(y))
+ if util.jython:
+ _eq_data = eq_data
+ def eq_data(x, y):
+ # Jython currently returns arrays
+ from array import ArrayType
+ if isinstance(y, ArrayType):
+ return eq_(x, y.tostring())
+ return _eq_data(x, y)
l = stmt.execute().fetchall()
- eq_(list(stream1), list(l[0]['data']))
- eq_(list(stream1[0:100]), list(l[0]['data_slice']))
- eq_(list(stream2), list(l[1]['data']))
+ eq_data(stream1, l[0]['data'])
+ eq_data(stream1[0:100], l[0]['data_slice'])
+ eq_data(stream2, l[1]['data'])
eq_(testobj1, l[0]['pickled'])
eq_(testobj2, l[1]['pickled'])
eq_(testobj3.moredata, l[0]['mypickle'].moredata)
eq_(l[0]['mypickle'].stuff, 'this is the right stuff')
- def load_stream(self, name, len=12579):
+ def load_stream(self, name):
f = os.path.join(os.path.dirname(__file__), "..", name)
- # put a number less than the typical MySQL default BLOB size
- return file(f).read(len)
+ return open(f, mode='rb').read()
class ExpressionTest(TestBase, AssertsExecutionResults):
@classmethod
def setup_class(cls):
global test_table, meta
- class MyCustomType(types.TypeEngine):
+ class MyCustomType(types.UserDefinedType):
def get_col_spec(self):
return "INT"
def bind_processor(self, dialect):
db = testing.db
if testing.against('oracle'):
- import sqlalchemy.databases.oracle as oracle
insert_data = [
(7, 'jack',
datetime.datetime(2005, 11, 10, 0, 0),
time_micro = 999
# Missing or poor microsecond support:
- if testing.against('mssql', 'mysql', 'firebird'):
+ if testing.against('mssql', 'mysql', 'firebird', '+zxjdbc'):
datetime_micro, time_micro = 0, 0
# No microseconds for TIME
elif testing.against('maxdb'):
Column('user_date', Date),
Column('user_time', Time)]
- if testing.against('sqlite', 'postgres'):
+ if testing.against('sqlite', 'postgresql'):
insert_data.append(
(11, 'historic',
datetime.datetime(1850, 11, 10, 11, 52, 35, datetime_micro),
t.drop(checkfirst=True)
class StringTest(TestBase, AssertsExecutionResults):
- @testing.fails_on('mysql', 'FIXME: unknown')
- @testing.fails_on('oracle', 'FIXME: unknown')
+
+ @testing.requires.unbounded_varchar
def test_nolength_string(self):
metadata = MetaData(testing.db)
foo = Table('foo', metadata, Column('one', String))
metadata = MetaData(testing.db)
numeric_table = Table('numeric_table', metadata,
Column('id', Integer, Sequence('numeric_id_seq', optional=True), primary_key=True),
- Column('numericcol', Numeric(asdecimal=False)),
- Column('floatcol', Float),
- Column('ncasdec', Numeric),
- Column('fcasdec', Float(asdecimal=True))
+ Column('numericcol', Numeric(precision=10, scale=2, asdecimal=False)),
+ Column('floatcol', Float(precision=10, )),
+ Column('ncasdec', Numeric(precision=10, scale=2)),
+ Column('fcasdec', Float(precision=10, asdecimal=True))
)
metadata.create_all()
def teardown_class(cls):
metadata.drop_all()
+ @engines.close_first
def teardown(self):
numeric_table.delete().execute()
from decimal import Decimal
numeric_table.insert().execute(
numericcol=3.5, floatcol=5.6, ncasdec=12.4, fcasdec=15.75)
+
numeric_table.insert().execute(
numericcol=Decimal("3.5"), floatcol=Decimal("5.6"),
ncasdec=Decimal("12.4"), fcasdec=Decimal("15.75"))
assert isinstance(row['ncasdec'], decimal.Decimal)
assert isinstance(row['fcasdec'], decimal.Decimal)
- def test_length_deprecation(self):
- assert_raises(exc.SADeprecationWarning, Numeric, length=8)
-
- @testing.uses_deprecated(".*is deprecated for Numeric")
- def go():
- n = Numeric(length=12)
- assert n.scale == 12
- go()
-
- n = Numeric(scale=12)
- for dialect in engines.all_dialects():
- n2 = dialect.type_descriptor(n)
- eq_(n2.scale, 12, dialect.name)
-
- # test colspec generates successfully using 'scale'
- assert n2.get_col_spec()
-
- # test constructor of the dialect-specific type
- n3 = n2.__class__(scale=5)
- eq_(n3.scale, 5, dialect.name)
-
- @testing.uses_deprecated(".*is deprecated for Numeric")
- def go():
- n3 = n2.__class__(length=6)
- eq_(n3.scale, 6, dialect.name)
- go()
-
class IntervalTest(TestBase, AssertsExecutionResults):
@classmethod
)
metadata.create_all()
+ @engines.close_first
def teardown(self):
interval_table.delete().execute()
def teardown_class(cls):
metadata.drop_all()
+ @testing.fails_on("+pg8000", "Not yet known how to pass values of the INTERVAL type")
+ @testing.fails_on("postgresql+zxjdbc", "Not yet known how to pass values of the INTERVAL type")
def test_roundtrip(self):
delta = datetime.datetime(2006, 10, 5) - datetime.datetime(2005, 8, 17)
interval_table.insert().execute(interval=delta)
- assert interval_table.select().execute().fetchone()['interval'] == delta
+ assert interval_table.select().execute().first()['interval'] == delta
def test_null(self):
interval_table.insert().execute(id=1, inverval=None)
- assert interval_table.select().execute().fetchone()['interval'] is None
+ assert interval_table.select().execute().first()['interval'] is None
class BooleanTest(TestBase, AssertsExecutionResults):
@classmethod
assert(res2==[(2, False)])
class PickleTest(TestBase):
- def test_noeq_deprecation(self):
- p1 = PickleType()
-
- assert_raises(DeprecationWarning,
- p1.compare_values, pickleable.BarWithoutCompare(1, 2), pickleable.BarWithoutCompare(1, 2)
- )
-
- assert_raises(DeprecationWarning,
- p1.compare_values, pickleable.OldSchoolWithoutCompare(1, 2), pickleable.OldSchoolWithoutCompare(1, 2)
- )
-
- @testing.uses_deprecated()
- def go():
- # test actual dumps comparison
- assert p1.compare_values(pickleable.BarWithoutCompare(1, 2), pickleable.BarWithoutCompare(1, 2))
- assert p1.compare_values(pickleable.OldSchoolWithoutCompare(1, 2), pickleable.OldSchoolWithoutCompare(1, 2))
- go()
-
- assert p1.compare_values({1:2, 3:4}, {3:4, 1:2})
-
- p2 = PickleType(mutable=False)
- assert not p2.compare_values(pickleable.BarWithoutCompare(1, 2), pickleable.BarWithoutCompare(1, 2))
- assert not p2.compare_values(pickleable.OldSchoolWithoutCompare(1, 2), pickleable.OldSchoolWithoutCompare(1, 2))
-
def test_eq_comparison(self):
p1 = PickleType()
)
metadata.create_all()
+ @engines.close_first
def teardown(self):
if metadata.tables:
t3.delete().execute()
# reset the identifier preparer, so that we can force it to cache
# a unicode identifier
engine.dialect.identifier_preparer = engine.dialect.preparer(engine.dialect)
- select([column(u'special_col')]).select_from(t1).execute()
+ select([column(u'special_col')]).select_from(t1).execute().close()
assert isinstance(engine.dialect.identifier_preparer.format_sequence(Sequence('special_col')), unicode)
# now execute, run the sequence. it should run in u"Special_col.nextid" or similar as
- # a unicode object; cx_oracle asserts that this is None or a String (postgres lets it pass thru).
+ # a unicode object; cx_oracle asserts that this is None or a String (postgresql lets it pass thru).
# ensure that base.DefaultRunner is encoding.
t1.insert().execute(data='foo')
finally:
"""mapper.py - defines mappers for domain objects, mapping operations"""
-import tables, user
-from blog import *
+from test.zblog import tables, user
+from test.zblog.blog import *
from sqlalchemy import *
from sqlalchemy.orm import *
import sqlalchemy.util as util
"""application table metadata objects are described here."""
from sqlalchemy import *
-
+from sqlalchemy.test.schema import Table, Column
metadata = MetaData()
users = Table('users', metadata,
- Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key=True),
+ Column('user_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('user_name', String(30), nullable=False),
Column('fullname', String(100), nullable=False),
Column('password', String(40), nullable=False),
)
blogs = Table('blogs', metadata,
- Column('blog_id', Integer, Sequence('blog_id_seq', optional=True), primary_key=True),
+ Column('blog_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('owner_id', Integer, ForeignKey('users.user_id'), nullable=False),
Column('name', String(100), nullable=False),
Column('description', String(500))
)
posts = Table('posts', metadata,
- Column('post_id', Integer, Sequence('post_id_seq', optional=True), primary_key=True),
+ Column('post_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('blog_id', Integer, ForeignKey('blogs.blog_id'), nullable=False),
Column('user_id', Integer, ForeignKey('users.user_id'), nullable=False),
Column('datetime', DateTime, nullable=False),
)
topics = Table('topics', metadata,
- Column('topic_id', Integer, primary_key=True),
+ Column('topic_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('keyword', String(50), nullable=False),
Column('description', String(500))
)
)
comments = Table('comments', metadata,
- Column('comment_id', Integer, primary_key=True),
+ Column('comment_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('user_id', Integer, ForeignKey('users.user_id'), nullable=False),
Column('post_id', Integer, ForeignKey('posts.post_id'), nullable=False),
Column('datetime', DateTime, nullable=False),
from sqlalchemy import *
from sqlalchemy.orm import *
from sqlalchemy.test import *
-import mappers, tables
-from user import *
-from blog import *
+from test.zblog import mappers, tables
+from test.zblog.user import *
+from test.zblog.blog import *
class ZBlogTest(TestBase, AssertsExecutionResults):
def cryptpw(password, salt=None):
if salt is None:
- salt = string.join([chr(random.randint(ord('a'), ord('z'))),
- chr(random.randint(ord('a'), ord('z')))],'')
- return sha(password + salt).hexdigest()
+ salt = "".join([chr(random.randint(ord('a'), ord('z'))),
+ chr(random.randint(ord('a'), ord('z')))])
+ return sha((password+ salt).encode('ascii')).hexdigest()
def checkpw(password, dbpw):
return cryptpw(password, dbpw[:2]) == dbpw