From: Mike Bayer Date: Thu, 6 Aug 2009 21:11:27 +0000 (+0000) Subject: merge 0.6 series to trunk. X-Git-Tag: rel_0_6beta1~361 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=8fc5005dfe3eb66a46470ad8a8c7b95fc4d6bdca;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git merge 0.6 series to trunk. --- diff --git a/06CHANGES b/06CHANGES new file mode 100644 index 0000000000..4c7f9ca693 --- /dev/null +++ b/06CHANGES @@ -0,0 +1,177 @@ +- orm + - the 'expire' option on query.update() has been renamed to 'fetch', thus matching + that of query.delete() + - query.update() and query.delete() both default to 'evaluate' for the synchronize + strategy. + - the 'synchronize' strategy for update() and delete() raises an error on failure. + There is no implicit fallback onto "fetch". Failure of evaluation is based + on the structure of criteria, so success/failure is deterministic based on + code structure. + - the "dont_load=True" flag on Session.merge() is deprecated and is now + "load=False". + +- sql + - returning() support is native to insert(), update(), delete(). Implementations + of varying levels of functionality exist for Postgresql, Firebird, MSSQL and + Oracle. returning() can be called explicitly with column expressions which + are then returned in the resultset, usually via fetchone() or first(). + + insert() constructs will also use RETURNING implicitly to get newly + generated primary key values, if the database version in use supports it + (a version number check is performed). This occurs if no end-user + returning() was specified. + + - Databases which rely upon postfetch of "last inserted id" to get at a + generated sequence value (i.e. MySQL, MS-SQL) now work correctly + when there is a composite primary key where the "autoincrement" column + is not the first primary key column in the table. + + - the last_inserted_ids() method has been renamed to the descriptor + "inserted_primary_key". + +- engines + - transaction isolation level may be specified with + create_engine(... isolation_level="..."); available on + postgresql and sqlite. [ticket:443] + - added first() method to ResultProxy, returns first row and closes + result set immediately. + +- schema + - deprecated metadata.connect() and threadlocalmetadata.connect() have been + removed - send the "bind" attribute to bind a metadata. + - deprecated metadata.table_iterator() method removed (use sorted_tables) + - the "metadata" argument is removed from DefaultGenerator and subclasses, + but remains locally present on Sequence, which is a standalone construct + in DDL. + - Removed public mutability from Index and Constraint objects: + - ForeignKeyConstraint.append_element() + - Index.append_column() + - UniqueConstraint.append_column() + - PrimaryKeyConstraint.add() + - PrimaryKeyConstraint.remove() + These should be constructed declaratively (i.e. in one construction). + - UniqueConstraint, Index, PrimaryKeyConstraint all accept lists + of column names or column objects as arguments. + - Other removed things: + - Table.key (no idea what this was for) + - Table.primary_key is not assignable - use table.append_constraint(PrimaryKeyConstraint(...)) + - Column.bind (get via column.table.bind) + - Column.metadata (get via column.table.metadata) + - the use_alter flag on ForeignKey is now a shortcut option for operations that + can be hand-constructed using the DDL() event system. A side effect of this refactor + is that ForeignKeyConstraint objects with use_alter=True will *not* be emitted on + SQLite, which does not support ALTER for foreign keys. This has no effect on SQLite's + behavior since SQLite does not actually honor FOREIGN KEY constraints. + +- DDL + - the DDL() system has been greatly expanded: + - CreateTable() + - DropTable() + - AddConstraint() + - DropConstraint() + - CreateIndex() + - DropIndex() + - CreateSequence() + - DropSequence() + - these support "on" and "execute-at()" just like + plain DDL() does. + - the "on" callable passed to DDL() needs to accept **kw arguments. + In the case of MetaData before/after create/drop, the list of + Table objects for which CREATE/DROP DDL is to be issued is passed + as the kw argument "tables". This is necessary for metadata-level + DDL that is dependent on the presence of specific tables. + +- dialect refactor + - the "owner" keyword argument is removed from Table. Use "schema" to + represent any namespaces to be prepended to the table name. + - server_version_info becomes a static attribute. + - dialects receive an initialize() event on initial connection to + determine connection properties. + - dialects receive a visit_pool event have an opportunity to + establish pool listeners. + - cached TypeEngine classes are cached per-dialect class instead of per-dialect. + - Deprecated Dialect.get_params() removed. + - Dialect.get_rowcount() has been renamed to a descriptor "rowcount", and calls + cursor.rowcount directly. Dialects which need to hardwire a rowcount in for + certain calls should override the method to provide different behavior. + - functions and operators generated by the compiler now use (almost) regular + dispatch functions of the form "visit_" and "visit__fn" + to provide customed processing. This replaces the need to copy the "functions" + and "operators" dictionaries in compiler subclasses with straightforward + visitor methods, and also allows compiler subclasses complete control over + rendering, as the full _Function or _BinaryExpression object is passed in. + +- postgresql + - the "postgres" dialect is now named "postgresql" ! Connection strings look + like: + + postgresql://scott:tiger@localhost/test + postgresql+pg8000://scott:tiger@localhost/test + + The "postgres" name remains for backwards compatiblity in the following ways: + + - There is a "postgres.py" dummy dialect which allows old URLs to work, + i.e. postgres://scott:tiger@localhost/test + + - The "postgres" name can be imported from the old "databases" module, + i.e. "from sqlalchemy.databases import postgres" as well as "dialects", + "from sqlalchemy.dialects.postgres import base as pg", will send + a deprecation warning. + + - Special expression arguments are now named "postgresql_returning" + and "postgresql_where", but the older "postgres_returning" and + "postgres_where" names still work with a deprecation warning. + +- mysql + - all the _detect_XXX() functions now run once underneath dialect.initialize() + +- oracle + - support for cx_Oracle's "native unicode" mode which does not require NLS_LANG + to be set. Use the latest 5.0.2 or later of cx_oracle. + - an NCLOB type is added to the base types. + - func.char_length is a generic function for LENGTH + - ForeignKey() which includes onupdate= will emit a warning, not + emit ON UPDATE CASCADE which is unsupported by oracle + - the keys() method of RowProxy() now returns the result column names *normalized* + to be SQLAlchemy case insensitive names. This means they will be lower case + for case insensitive names, whereas the DBAPI would normally return them + as UPPERCASE names. This allows row keys() to be compatible with further + SQLAlchemy operations. + +- firebird + - the keys() method of RowProxy() now returns the result column names *normalized* + to be SQLAlchemy case insensitive names. This means they will be lower case + for case insensitive names, whereas the DBAPI would normally return them + as UPPERCASE names. This allows row keys() to be compatible with further + SQLAlchemy operations. + +- new dialects + - postgresql+pg8000 + - postgresql+pypostgresql (partial) + - postgresql+zxjdbc + - mysql+pyodbc + - mysql+zxjdbc + +- mssql + - the "has_window_funcs" flag is removed. LIMIT/OFFSET usage will use ROW NUMBER as always, + and if on an older version of SQL Server, the operation fails. The behavior is exactly + the same except the error is raised by SQL server instead of the dialect, and no + flag setting is required to enable it. + - the "auto_identity_insert" flag is removed. This feature always takes effect + when an INSERT statement overrides a column that is known to have a sequence on it. + As with "has_window_funcs", if the underlying driver doesn't support this, then you + can't do this operation in any case, so there's no point in having a flag. + - using new dialect.initialize() feature to set up version-dependent behavior. + +- types + - PickleType now uses == for comparison of values when mutable=True, + unless the "comparator" argument with a comparsion function is specified to the type. + Objects being pickled will be compared based on identity (which defeats the purpose + of mutable=True) if __eq__() is not overridden or a comparison function is not provided. + - The default "precision" and "scale" arguments of Numeric and Float have been removed + and now default to None. NUMERIC and FLOAT will be rendered with no numeric arguments + by default unless these values are provided. + - AbstractType.get_search_list() is removed - the games that was used for are no + longer necessary. + + \ No newline at end of file diff --git a/CHANGES b/CHANGES index 3ff1ee032b..c073159da8 100644 --- a/CHANGES +++ b/CHANGES @@ -199,7 +199,7 @@ CHANGES - Repaired the printing of SQL exceptions which are not based on parameters or are not executemany() style. -- postgres +- postgresql - Deprecated the hardcoded TIMESTAMP function, which when used as func.TIMESTAMP(value) would render "TIMESTAMP value". This breaks on some platforms as PostgreSQL doesn't allow @@ -524,7 +524,7 @@ CHANGES fail on recent versions of pysqlite which raise an error when fetchone() called with no rows present. -- postgres +- postgresql - Index reflection won't fail when an index with multiple expressions is encountered. diff --git a/README.unittests b/README.unittests index 99edfcacec..92a7521d02 100644 --- a/README.unittests +++ b/README.unittests @@ -10,10 +10,19 @@ downloads for nose are available at: http://somethingaboutorange.com/mrl/projects/nose/0.11.1/index.html - SQLAlchemy implements a nose plugin that must be present when tests are run. This plugin is available when SQLAlchemy is installed via setuptools. +INSTANT TEST RUNNER +------------------- + +A plain vanilla run of all tests using sqlite can be run via setup.py: + + $ python setup.py test + +Setuptools will take care of the rest ! To run nose directly and have +its full set of options available, read on... + SETUP ----- @@ -67,11 +76,14 @@ DATABASE TARGETS Tests will target an in-memory SQLite database by default. To test against another database, use the --dburi option with any standard SQLAlchemy URL: - --dburi=postgres://user:password@localhost/test + --dburi=postgresql://user:password@localhost/test -Use an empty database and a database user with general DBA privileges. The -test suite will be creating and dropping many tables and other DDL, and -preexisting tables will interfere with the tests +Use an empty database and a database user with general DBA privileges. +The test suite will be creating and dropping many tables and other DDL, and +preexisting tables will interfere with the tests. + +IMPORTANT !: please see TIPS at the end if your are testing on POSTGRESQL, +ORACLE, or MSSQL - additional steps are required to prepare a test database. If you'll be running the tests frequently, database aliases can save a lot of typing. The --dbs option lists the built-in aliases and their matching URLs: @@ -80,19 +92,19 @@ typing. The --dbs option lists the built-in aliases and their matching URLs: Available --db options (use --dburi to override) mysql mysql://scott:tiger@127.0.0.1:3306/test oracle oracle://scott:tiger@127.0.0.1:1521 - postgres postgres://scott:tiger@127.0.0.1:5432/test + postgresql postgresql://scott:tiger@127.0.0.1:5432/test [...] To run tests against an aliased database: - $ nosetests --db=postgres + $ nosetests --db=postgresql To customize the URLs with your own users or hostnames, make a simple .ini file called `test.cfg` at the top level of the SQLAlchemy source distribution or a `.satest.cfg` in your home directory: [db] - postgres=postgres://myuser:mypass@localhost/mydb + postgresql=postgresql://myuser:mypass@localhost/mydb Your custom entries will override the defaults and you'll see them reflected in the output of --dbs. @@ -159,13 +171,23 @@ IRC! TIPS ---- -PostgreSQL: The tests require an 'alt_schema' and 'alt_schema_2' to be present in + +PostgreSQL: The tests require an 'test_schema' and 'test_schema_2' to be present in the testing database. -PostgreSQL: When running the tests on postgres, postgres can get slower and -slower each time you run the tests. This seems to be related to the constant -creation/dropping of tables. Running a "VACUUM FULL" on the database will -speed it up again. +Oracle: the database owner should be named "scott" (this will be fixed), +and an additional "owner" named "ed" is required: + +1. create a user 'ed' in the oracle database. +2. in 'ed', issue the following statements: + create table parent(id integer primary key, data varchar2(50)); + create table child(id integer primary key, data varchar2(50), parent_id integer references parent(id)); + create synonym ptable for parent; + create synonym ctable for child; + grant all on parent to scott; (or to whoever you run the oracle tests as) + grant all on child to scott; (same) + grant all on ptable to scott; + grant all on ctable to scott; MSSQL: Tests that involve multiple connections require Snapshot Isolation ability implented on the test database in order to prevent deadlocks that will diff --git a/README_THIS_IS_06 b/README_THIS_IS_06 new file mode 100644 index 0000000000..b83f886661 --- /dev/null +++ b/README_THIS_IS_06 @@ -0,0 +1,11 @@ +Trunk is now moved to SQLAlchemy 0.6. + +An ongoing wiki page of changes etc. is at: + +http://www.sqlalchemy.org/trac/wiki/06Migration + +SQLAlchemy 0.5 is now in a maintenance branch. Get it at: + +http://svn.sqlalchemy.org/sqlalchemy/branches/rel_0_5 + + diff --git a/convert.py b/convert.py index b574c27a92..cb2c8c1a75 100644 --- a/convert.py +++ b/convert.py @@ -224,5 +224,7 @@ handlers.append((re.compile(r".*"), default)) if __name__ == '__main__': - convert("test/orm/inheritance/abc_inheritance.py") + import sys + for f in sys.argv[1:]: + convert(f) # walk() diff --git a/doc/build/copyright.rst b/doc/build/copyright.rst index 227a54c9c8..501b4ee757 100644 --- a/doc/build/copyright.rst +++ b/doc/build/copyright.rst @@ -4,7 +4,7 @@ Appendix: Copyright This is the MIT license: ``_ -Copyright (c) 2005, 2006, 2007, 2008 Michael Bayer and contributors. SQLAlchemy is a trademark of Michael +Copyright (c) 2005, 2006, 2007, 2008, 2009 Michael Bayer and contributors. SQLAlchemy is a trademark of Michael Bayer. Permission is hereby granted, free of charge, to any person obtaining a copy of this diff --git a/doc/build/dbengine.rst b/doc/build/dbengine.rst index c7f924a9dc..df1088bcd2 100644 --- a/doc/build/dbengine.rst +++ b/doc/build/dbengine.rst @@ -19,9 +19,9 @@ Where above, a :class:`~sqlalchemy.engine.Engine` references both a :class:`~sq Creating an engine is just a matter of issuing a single call, :func:`create_engine()`:: - engine = create_engine('postgres://scott:tiger@localhost:5432/mydatabase') + engine = create_engine('postgresql://scott:tiger@localhost:5432/mydatabase') -The above engine invokes the ``postgres`` dialect and a connection pool which references ``localhost:5432``. +The above engine invokes the ``postgresql`` dialect and a connection pool which references ``localhost:5432``. The engine can be used directly to issue SQL to the database. The most generic way is to use connections, which you get via the ``connect()`` method:: @@ -52,7 +52,7 @@ The ``Engine`` and ``Connection`` can do a lot more than what we illustrated abo Supported Databases ==================== -Recall that the ``Dialect`` is used to describe how to talk to a specific kind of database. Dialects are included with SQLAlchemy for many different backends; these can be seen as a Python package within the :mod:`~sqlalchemy.databases` package. Each dialect requires the appropriate DBAPI drivers to be installed separately. +Recall that the ``Dialect`` is used to describe how to talk to a specific kind of database. Dialects are included with SQLAlchemy for many different backends; these can be seen as a Python package within the :mod:`~sqlalchemy.dialect` package. Each dialect requires the appropriate DBAPI drivers to be installed separately. Dialects included with SQLAlchemy fall under one of three categories: supported, experimental, and third party. Supported drivers are those which work against the most common databases available in the open source world, including SQLite, PostgreSQL, MySQL, and Firebird. Very popular commercial databases which provide easy access to test platforms are also supported, these currently include MSSQL and Oracle. These dialects are tested frequently and the level of support should be close to 100% for each. @@ -63,23 +63,22 @@ There are also third-party dialects available - currently IBM offers a DB2/Infor Downloads for each DBAPI at the time of this writing are as follows: * Supported Dialects - - - PostgreSQL: `psycopg2 `_ + - PostgreSQL: `psycopg2 `_ `pg8000 `_ + - PostgreSQL on Jython: `PostgreSQL JDBC Driver `_ - SQLite: `sqlite3 `_ (included in Python 2.5 or greater) `pysqlite `_ - MySQL: `MySQLDB (a.k.a. mysql-python) `_ + - MySQL on Jython: `JDBC Driver for MySQL `_ - Oracle: `cx_Oracle `_ - Firebird: `kinterbasdb `_ - MS-SQL, MSAccess: `pyodbc `_ (recommended) `adodbapi `_ `pymssql `_ * Experimental Dialects - - MSAccess: `pyodbc `_ - Informix: `informixdb `_ - Sybase: TODO - MAXDB: TODO * Third Party Dialects - - DB2/Informix IDS: `ibm-db `_ The SQLAlchemy Wiki contains a page of database notes, describing whatever quirks and behaviors have been observed. Its a good place to check for issues with specific databases. `Database Notes `_ @@ -89,31 +88,42 @@ create_engine() URL Arguments SQLAlchemy indicates the source of an Engine strictly via `RFC-1738 `_ style URLs, combined with optional keyword arguments to specify options for the Engine. The form of the URL is: - driver://username:password@host:port/database + dialect+driver://username:password@host:port/database + +Dialect names include the identifying name of the SQLAlchemy dialect which include ``sqlite``, ``mysql``, ``postgresql``, ``oracle``, ``mssql``, and ``firebird``. The drivername is the name of the DBAPI to be used to connect to the database using all lowercase letters. If not specified, a "default" DBAPI will be imported if available - this default is typically the most widely known driver available for that backend (i.e. cx_oracle, pysqlite/sqlite3, psycopg2, mysqldb). For Jython connections, the driver is always `zxjdbc`, which is the JDBC-DBAPI bridge included with Jython. + +.. sourcecode:: python+sql + + # postgresql - psycopg2 is the default driver. + pg_db = create_engine('postgresql://scott:tiger@localhost/mydatabase') + pg_db = create_engine('postgresql+psycopg2://scott:tiger@localhost/mydatabase') + pg_db = create_engine('postgresql+pg8000://scott:tiger@localhost/mydatabase') -Dialect names include the identifying name of the SQLAlchemy dialect which include ``sqlite``, ``mysql``, ``postgres``, ``oracle``, ``mssql``, and ``firebird``. In SQLAlchemy 0.5 and earlier, the DBAPI implementation is automatically selected if more than one are available - currently this includes only MSSQL (pyodbc is the default, then adodbapi, then pymssql) and SQLite (sqlite3 is the default, or pysqlite if sqlite3 is not availble). When using MSSQL, ``create_engine()`` accepts a ``module`` argument which specifies the name of the desired DBAPI to be used, overriding the default behavior. + # postgresql on Jython + pg_db = create_engine('postgresql+zxjdbc://scott:tiger@localhost/mydatabase') + + # mysql - MySQLdb (mysql-python) is the default driver + mysql_db = create_engine('mysql://scott:tiger@localhost/foo') + mysql_db = create_engine('mysql+mysqldb://scott:tiger@localhost/foo') + + # mysql on Jython + mysql_db = create_engine('mysql+zxjdbc://localhost/foo') - .. sourcecode:: python+sql - - # postgresql - pg_db = create_engine('postgres://scott:tiger@localhost/mydatabase') + # mysql with pyodbc (buggy) + mysql_db = create_engine('mysql+pyodbc://scott:tiger@some_dsn') - # mysql - mysql_db = create_engine('mysql://scott:tiger@localhost/mydatabase') - - # oracle + # oracle - cx_oracle is the default driver oracle_db = create_engine('oracle://scott:tiger@127.0.0.1:1521/sidname') - + # oracle via TNS name - oracle_db = create_engine('oracle://scott:tiger@tnsname') - + oracle_db = create_engine('oracle+cx_oracle://scott:tiger@tnsname') + # mssql using ODBC datasource names. PyODBC is the default driver. mssql_db = create_engine('mssql://mydsn') - mssql_db = create_engine('mssql://scott:tiger@mydsn') - - # firebird - firebird_db = create_engine('firebird://scott:tiger@localhost/sometest.gdm') - + mssql_db = create_engine('mssql+pyodbc://mydsn') + mssql_db = create_engine('mssql+adodbapi://mydsn') + mssql_db = create_engine('mssql+pyodbc://username:password@mydsn') + SQLite connects to file based databases. The same URL format is used, omitting the hostname, and using the "file" portion as the filename of the database. This has the effect of four slashes being present for an absolute file path:: # sqlite:/// @@ -132,12 +142,11 @@ The :class:`~sqlalchemy.engine.base.Engine` will ask the connection pool for a c Custom DBAPI connect() arguments -------------------------------- - Custom arguments used when issuing the ``connect()`` call to the underlying DBAPI may be issued in three distinct ways. String-based arguments can be passed directly from the URL string as query arguments: .. sourcecode:: python+sql - db = create_engine('postgres://scott:tiger@localhost/test?argument1=foo&argument2=bar') + db = create_engine('postgresql://scott:tiger@localhost/test?argument1=foo&argument2=bar') If SQLAlchemy's database connector is aware of a particular query argument, it may convert its type from string to its proper type. @@ -145,7 +154,7 @@ If SQLAlchemy's database connector is aware of a particular query argument, it m .. sourcecode:: python+sql - db = create_engine('postgres://scott:tiger@localhost/test', connect_args = {'argument1':17, 'argument2':'bar'}) + db = create_engine('postgresql://scott:tiger@localhost/test', connect_args = {'argument1':17, 'argument2':'bar'}) The most customizable connection method of all is to pass a ``creator`` argument, which specifies a callable that returns a DBAPI connection: @@ -154,7 +163,7 @@ The most customizable connection method of all is to pass a ``creator`` argument def connect(): return psycopg.connect(user='scott', host='localhost') - db = create_engine('postgres://', creator=connect) + db = create_engine('postgresql://', creator=connect) .. _create_engine_args: @@ -165,7 +174,7 @@ Keyword options can also be specified to ``create_engine()``, following the stri .. sourcecode:: python+sql - db = create_engine('postgres://...', encoding='latin1', echo=True) + db = create_engine('postgresql://...', encoding='latin1', echo=True) Options common to all database dialects are described at :func:`~sqlalchemy.create_engine`. diff --git a/doc/build/metadata.rst b/doc/build/metadata.rst index f79c637ee0..464b764bf1 100644 --- a/doc/build/metadata.rst +++ b/doc/build/metadata.rst @@ -396,7 +396,7 @@ The above SQL functions are usually executed "inline" with the INSERT or UPDATE * the ``inline=True`` flag is not set on the ``Insert()`` or ``Update()`` construct. -For a statement execution which is not an executemany, the returned ``ResultProxy`` will contain a collection accessible via ``result.postfetch_cols()`` which contains a list of all ``Column`` objects which had an inline-executed default. Similarly, all parameters which were bound to the statement, including all Python and SQL expressions which were pre-executed, are present in the ``last_inserted_params()`` or ``last_updated_params()`` collections on ``ResultProxy``. The ``last_inserted_ids()`` collection contains a list of primary key values for the row inserted. +For a statement execution which is not an executemany, the returned ``ResultProxy`` will contain a collection accessible via ``result.postfetch_cols()`` which contains a list of all ``Column`` objects which had an inline-executed default. Similarly, all parameters which were bound to the statement, including all Python and SQL expressions which were pre-executed, are present in the ``last_inserted_params()`` or ``last_updated_params()`` collections on ``ResultProxy``. The ``inserted_primary_key`` collection contains a list of primary key values for the row inserted. DDL-Level Defaults ------------------- diff --git a/doc/build/reference/dialects/access.rst b/doc/build/reference/dialects/access.rst index cd635aaa09..52a2ee3710 100644 --- a/doc/build/reference/dialects/access.rst +++ b/doc/build/reference/dialects/access.rst @@ -1,4 +1,4 @@ -Access -====== +Microsoft Access +================ -.. automodule:: sqlalchemy.databases.access +.. automodule:: sqlalchemy.dialects.access.base diff --git a/doc/build/reference/dialects/firebird.rst b/doc/build/reference/dialects/firebird.rst index 19a2c4f918..54c38f49b0 100644 --- a/doc/build/reference/dialects/firebird.rst +++ b/doc/build/reference/dialects/firebird.rst @@ -1,4 +1,4 @@ Firebird ======== -.. automodule:: sqlalchemy.databases.firebird +.. automodule:: sqlalchemy.dialects.firebird.base diff --git a/doc/build/reference/dialects/index.rst b/doc/build/reference/dialects/index.rst index fe9f253952..f9c4df5ce8 100644 --- a/doc/build/reference/dialects/index.rst +++ b/doc/build/reference/dialects/index.rst @@ -3,17 +3,33 @@ sqlalchemy.databases ==================== +Supported Databases +------------------- + +These backends are fully operational with +current versions of SQLAlchemy. + .. toctree:: :glob: - access firebird - informix - maxdb mssql mysql oracle - postgres + postgresql sqlite + +Unsupported Databases +--------------------- + +These backends are untested and may not be completely +ported to current versions of SQLAlchemy. + +.. toctree:: + :glob: + + access + informix + maxdb sybase diff --git a/doc/build/reference/dialects/informix.rst b/doc/build/reference/dialects/informix.rst index 9f787e3c29..7cf271d0b7 100644 --- a/doc/build/reference/dialects/informix.rst +++ b/doc/build/reference/dialects/informix.rst @@ -1,4 +1,4 @@ Informix ======== -.. automodule:: sqlalchemy.databases.informix +.. automodule:: sqlalchemy.dialects.informix.base diff --git a/doc/build/reference/dialects/maxdb.rst b/doc/build/reference/dialects/maxdb.rst index b137da917c..3edd55a775 100644 --- a/doc/build/reference/dialects/maxdb.rst +++ b/doc/build/reference/dialects/maxdb.rst @@ -1,4 +1,4 @@ MaxDB ===== -.. automodule:: sqlalchemy.databases.maxdb +.. automodule:: sqlalchemy.dialects.maxdb.base diff --git a/doc/build/reference/dialects/mssql.rst b/doc/build/reference/dialects/mssql.rst index a55ab85a95..68c0f0462c 100644 --- a/doc/build/reference/dialects/mssql.rst +++ b/doc/build/reference/dialects/mssql.rst @@ -1,4 +1,18 @@ -SQL Server -========== +Microsoft SQL Server +==================== + +.. automodule:: sqlalchemy.dialects.mssql.base + +PyODBC +------ +.. automodule:: sqlalchemy.dialects.mssql.pyodbc + +AdoDBAPI +-------- +.. automodule:: sqlalchemy.dialects.mssql.adodbapi + +pymssql +------- +.. automodule:: sqlalchemy.dialects.mssql.pymssql + -.. automodule:: sqlalchemy.databases.mssql diff --git a/doc/build/reference/dialects/mysql.rst b/doc/build/reference/dialects/mysql.rst index 28f905343f..839b8cae07 100644 --- a/doc/build/reference/dialects/mysql.rst +++ b/doc/build/reference/dialects/mysql.rst @@ -1,140 +1,149 @@ MySQL ===== -.. automodule:: sqlalchemy.databases.mysql +.. automodule:: sqlalchemy.dialects.mysql.base MySQL Column Types ------------------ -.. autoclass:: MSNumeric +.. autoclass:: NUMERIC :members: __init__ :show-inheritance: -.. autoclass:: MSDecimal +.. autoclass:: DECIMAL :members: __init__ :show-inheritance: -.. autoclass:: MSDouble +.. autoclass:: DOUBLE :members: __init__ :show-inheritance: -.. autoclass:: MSReal +.. autoclass:: REAL :members: __init__ :show-inheritance: -.. autoclass:: MSFloat +.. autoclass:: FLOAT :members: __init__ :show-inheritance: -.. autoclass:: MSInteger +.. autoclass:: INTEGER :members: __init__ :show-inheritance: -.. autoclass:: MSBigInteger +.. autoclass:: BIGINT :members: __init__ :show-inheritance: -.. autoclass:: MSMediumInteger +.. autoclass:: MEDIUMINT :members: __init__ :show-inheritance: -.. autoclass:: MSTinyInteger +.. autoclass:: TINYINT :members: __init__ :show-inheritance: -.. autoclass:: MSSmallInteger +.. autoclass:: SMALLINT :members: __init__ :show-inheritance: -.. autoclass:: MSBit +.. autoclass:: BIT :members: __init__ :show-inheritance: -.. autoclass:: MSDateTime +.. autoclass:: DATETIME :members: __init__ :show-inheritance: -.. autoclass:: MSDate +.. autoclass:: DATE :members: __init__ :show-inheritance: -.. autoclass:: MSTime +.. autoclass:: TIME :members: __init__ :show-inheritance: -.. autoclass:: MSTimeStamp +.. autoclass:: TIMESTAMP :members: __init__ :show-inheritance: -.. autoclass:: MSYear +.. autoclass:: YEAR :members: __init__ :show-inheritance: -.. autoclass:: MSText +.. autoclass:: TEXT :members: __init__ :show-inheritance: -.. autoclass:: MSTinyText +.. autoclass:: TINYTEXT :members: __init__ :show-inheritance: -.. autoclass:: MSMediumText +.. autoclass:: MEDIUMTEXT :members: __init__ :show-inheritance: -.. autoclass:: MSLongText +.. autoclass:: LONGTEXT :members: __init__ :show-inheritance: -.. autoclass:: MSString +.. autoclass:: VARCHAR :members: __init__ :show-inheritance: -.. autoclass:: MSChar +.. autoclass:: CHAR :members: __init__ :show-inheritance: -.. autoclass:: MSNVarChar +.. autoclass:: NVARCHAR :members: __init__ :show-inheritance: -.. autoclass:: MSNChar +.. autoclass:: NCHAR :members: __init__ :show-inheritance: -.. autoclass:: MSVarBinary +.. autoclass:: VARBINARY :members: __init__ :show-inheritance: -.. autoclass:: MSBinary +.. autoclass:: BINARY :members: __init__ :show-inheritance: -.. autoclass:: MSBlob +.. autoclass:: BLOB :members: __init__ :show-inheritance: -.. autoclass:: MSTinyBlob +.. autoclass:: TINYBLOB :members: __init__ :show-inheritance: -.. autoclass:: MSMediumBlob +.. autoclass:: MEDIUMBLOB :members: __init__ :show-inheritance: -.. autoclass:: MSLongBlob +.. autoclass:: LONGBLOB :members: __init__ :show-inheritance: -.. autoclass:: MSEnum +.. autoclass:: ENUM :members: __init__ :show-inheritance: -.. autoclass:: MSSet +.. autoclass:: SET :members: __init__ :show-inheritance: -.. autoclass:: MSBoolean +.. autoclass:: BOOLEAN :members: __init__ :show-inheritance: +MySQLdb Notes +-------------- + +.. automodule:: sqlalchemy.dialects.mysql.mysqldb + +zxjdbc Notes +-------------- + +.. automodule:: sqlalchemy.dialects.mysql.zxjdbc diff --git a/doc/build/reference/dialects/oracle.rst b/doc/build/reference/dialects/oracle.rst index 188f6f4383..584dfbf814 100644 --- a/doc/build/reference/dialects/oracle.rst +++ b/doc/build/reference/dialects/oracle.rst @@ -1,4 +1,10 @@ Oracle ====== -.. automodule:: sqlalchemy.databases.oracle +.. automodule:: sqlalchemy.dialects.oracle.base + +cx_Oracle Notes +--------------- + +.. automodule:: sqlalchemy.dialects.oracle.cx_oracle + diff --git a/doc/build/reference/dialects/postgres.rst b/doc/build/reference/dialects/postgres.rst deleted file mode 100644 index 7cf072383e..0000000000 --- a/doc/build/reference/dialects/postgres.rst +++ /dev/null @@ -1,4 +0,0 @@ -PostgreSQL -========== - -.. automodule:: sqlalchemy.databases.postgres diff --git a/doc/build/reference/dialects/postgresql.rst b/doc/build/reference/dialects/postgresql.rst new file mode 100644 index 0000000000..7e00645d82 --- /dev/null +++ b/doc/build/reference/dialects/postgresql.rst @@ -0,0 +1,20 @@ +PostgreSQL +========== + +.. automodule:: sqlalchemy.dialects.postgresql.base + +psycopg2 Notes +-------------- + +.. automodule:: sqlalchemy.dialects.postgresql.psycopg2 + + +pg8000 Notes +-------------- + +.. automodule:: sqlalchemy.dialects.postgresql.pg8000 + +zxjdbc Notes +-------------- + +.. automodule:: sqlalchemy.dialects.postgresql.zxjdbc diff --git a/doc/build/reference/dialects/sqlite.rst b/doc/build/reference/dialects/sqlite.rst index 118c239b1d..8361876c38 100644 --- a/doc/build/reference/dialects/sqlite.rst +++ b/doc/build/reference/dialects/sqlite.rst @@ -1,5 +1,9 @@ SQLite ====== -.. automodule:: sqlalchemy.databases.sqlite +.. automodule:: sqlalchemy.dialects.sqlite.base +Pysqlite +-------- + +.. automodule:: sqlalchemy.dialects.sqlite.pysqlite \ No newline at end of file diff --git a/doc/build/reference/dialects/sybase.rst b/doc/build/reference/dialects/sybase.rst index fac1a1f6b4..1b7651d2cf 100644 --- a/doc/build/reference/dialects/sybase.rst +++ b/doc/build/reference/dialects/sybase.rst @@ -1,4 +1,4 @@ Sybase ====== -.. automodule:: sqlalchemy.databases.sybase +.. automodule:: sqlalchemy.dialects.sybase.base diff --git a/doc/build/reference/sqlalchemy/connections.rst b/doc/build/reference/sqlalchemy/connections.rst index 2f861816c3..394fa864ce 100644 --- a/doc/build/reference/sqlalchemy/connections.rst +++ b/doc/build/reference/sqlalchemy/connections.rst @@ -65,7 +65,3 @@ Internals .. autoclass:: ExecutionContext :members: -.. autoclass:: SchemaIterator - :members: - :show-inheritance: - diff --git a/doc/build/reference/sqlalchemy/pooling.rst b/doc/build/reference/sqlalchemy/pooling.rst index 91e9681978..d37425e3a6 100644 --- a/doc/build/reference/sqlalchemy/pooling.rst +++ b/doc/build/reference/sqlalchemy/pooling.rst @@ -32,7 +32,7 @@ directly to :func:`~sqlalchemy.create_engine` as keyword arguments: ``pool_size``, ``max_overflow``, ``pool_recycle`` and ``pool_timeout``. For example:: - engine = create_engine('postgres://me@localhost/mydb', + engine = create_engine('postgresql://me@localhost/mydb', pool_size=20, max_overflow=0) In the case of SQLite, a :class:`SingletonThreadPool` is provided instead, diff --git a/doc/build/reference/sqlalchemy/types.rst b/doc/build/reference/sqlalchemy/types.rst index afe509d74b..6eb532fe18 100644 --- a/doc/build/reference/sqlalchemy/types.rst +++ b/doc/build/reference/sqlalchemy/types.rst @@ -153,22 +153,36 @@ reference for the database you're interested in. For example, MySQL has a ``BIGINTEGER`` type and PostgreSQL has an ``INET`` type. To use these, import them from the module explicitly:: - from sqlalchemy.databases.mysql import MSBigInteger, MSEnum + from sqlalchemy.dialect.mysql import dialect as mysql table = Table('foo', meta, - Column('id', MSBigInteger), - Column('enumerates', MSEnum('a', 'b', 'c')) + Column('id', mysql.BIGINTEGER), + Column('enumerates', mysql.ENUM('a', 'b', 'c')) ) Or some PostgreSQL types:: - from sqlalchemy.databases.postgres import PGInet, PGArray + from sqlalchemy.dialect.postgresql import dialect as postgresql table = Table('foo', meta, - Column('ipaddress', PGInet), - Column('elements', PGArray(str)) + Column('ipaddress', postgresql.INET), + Column('elements', postgresql.ARRAY(str)) ) +Each dialect should provide the full set of typenames supported by +that backend, so that a backend-specific schema can be created without +the need to locate types:: + + from sqlalchemy.dialects.postgresql import dialect as pg + + t = Table('mytable', metadata, + Column('id', pg.INTEGER, primary_key=True), + Column('name', pg.VARCHAR(300)), + Column('inetaddr', pg.INET) + ) + +Where above, the INTEGER and VARCHAR types are ultimately from +sqlalchemy.types, but the Postgresql dialect makes them available. Custom Types ------------ @@ -181,7 +195,7 @@ The simplest method is implementing a :class:`TypeDecorator`, a helper class that makes it easy to augment the bind parameter and result processing capabilities of one of the built in types. -To build a type object from scratch, subclass `:class:TypeEngine`. +To build a type object from scratch, subclass `:class:UserDefinedType`. .. autoclass:: TypeDecorator :members: @@ -189,6 +203,12 @@ To build a type object from scratch, subclass `:class:TypeEngine`. :inherited-members: :show-inheritance: +.. autoclass:: UserDefinedType + :members: + :undoc-members: + :inherited-members: + :show-inheritance: + .. autoclass:: TypeEngine :members: :undoc-members: diff --git a/doc/build/session.rst b/doc/build/session.rst index b2b66c32fe..c704dc7928 100644 --- a/doc/build/session.rst +++ b/doc/build/session.rst @@ -54,7 +54,7 @@ In our previous example regarding ``sessionmaker()``, we specified a ``bind`` fo Session = sessionmaker() # later, we create the engine - engine = create_engine('postgres://...') + engine = create_engine('postgresql://...') # associate it with our custom Session class Session.configure(bind=engine) @@ -74,7 +74,7 @@ The ``Session`` can also be explicitly bound to an individual database ``Connect # global application scope. create Session class, engine Session = sessionmaker() - engine = create_engine('postgres://...') + engine = create_engine('postgresql://...') ... @@ -219,7 +219,7 @@ With ``merge()``, the given instance is not placed within the session, and can b * An application which reads an object structure from a file and wishes to save it to the database might parse the file, build up the structure, and then use ``merge()`` to save it to the database, ensuring that the data within the file is used to formulate the primary key of each element of the structure. Later, when the file has changed, the same process can be re-run, producing a slightly different object structure, which can then be ``merged()`` in again, and the ``Session`` will automatically update the database to reflect those changes. * A web application stores mapped entities within an HTTP session object. When each request starts up, the serialized data can be merged into the session, so that the original entity may be safely shared among requests and threads. -``merge()`` is frequently used by applications which implement their own second level caches. This refers to an application which uses an in memory dictionary, or an tool like Memcached to store objects over long running spans of time. When such an object needs to exist within a ``Session``, ``merge()`` is a good choice since it leaves the original cached object untouched. For this use case, merge provides a keyword option called ``dont_load=True``. When this boolean flag is set to ``True``, ``merge()`` will not issue any SQL to reconcile the given object against the current state of the database, thereby reducing query overhead. The limitation is that the given object and all of its children may not contain any pending changes, and it's also of course possible that newer information in the database will not be present on the merged object, since no load is issued. +``merge()`` is frequently used by applications which implement their own second level caches. This refers to an application which uses an in memory dictionary, or an tool like Memcached to store objects over long running spans of time. When such an object needs to exist within a ``Session``, ``merge()`` is a good choice since it leaves the original cached object untouched. For this use case, merge provides a keyword option called ``load=False``. When this boolean flag is set to ``False``, ``merge()`` will not issue any SQL to reconcile the given object against the current state of the database, thereby reducing query overhead. The limitation is that the given object and all of its children may not contain any pending changes, and it's also of course possible that newer information in the database will not be present on the merged object, since no load is issued. Deleting -------- @@ -459,8 +459,8 @@ Enabling Two-Phase Commit Finally, for MySQL, PostgreSQL, and soon Oracle as well, the session can be instructed to use two-phase commit semantics. This will coordinate the committing of transactions across databases so that the transaction is either committed or rolled back in all databases. You can also ``prepare()`` the session for interacting with transactions not managed by SQLAlchemy. To use two phase transactions set the flag ``twophase=True`` on the session:: - engine1 = create_engine('postgres://db1') - engine2 = create_engine('postgres://db2') + engine1 = create_engine('postgresql://db1') + engine2 = create_engine('postgresql://db2') Session = sessionmaker(twophase=True) @@ -549,7 +549,7 @@ Note that above, we issue a ``commit()`` both on the ``Session`` as well as the When using the ``threadlocal`` engine context, the process above is simplified; the ``Session`` uses the same connection/transaction as everyone else in the current thread, whether or not you explicitly bind it:: - engine = create_engine('postgres://mydb', strategy="threadlocal") + engine = create_engine('postgresql://mydb', strategy="threadlocal") engine.begin() session = Session() # session takes place in the transaction like everyone else @@ -652,8 +652,8 @@ Vertical Partitioning Vertical partitioning places different kinds of objects, or different tables, across multiple databases:: - engine1 = create_engine('postgres://db1') - engine2 = create_engine('postgres://db2') + engine1 = create_engine('postgresql://db1') + engine2 = create_engine('postgresql://db2') Session = sessionmaker(twophase=True) diff --git a/doc/build/sqlexpression.rst b/doc/build/sqlexpression.rst index 387013cacc..2bcaa631d2 100644 --- a/doc/build/sqlexpression.rst +++ b/doc/build/sqlexpression.rst @@ -143,10 +143,10 @@ What about the ``result`` variable we got when we called ``execute()`` ? As the .. sourcecode:: pycon+sql - >>> result.last_inserted_ids() + >>> result.inserted_primary_key [1] -The value of ``1`` was automatically generated by SQLite, but only because we did not specify the ``id`` column in our ``Insert`` statement; otherwise, our explicit value would have been used. In either case, SQLAlchemy always knows how to get at a newly generated primary key value, even though the method of generating them is different across different databases; each databases' ``Dialect`` knows the specific steps needed to determine the correct value (or values; note that ``last_inserted_ids()`` returns a list so that it supports composite primary keys). +The value of ``1`` was automatically generated by SQLite, but only because we did not specify the ``id`` column in our ``Insert`` statement; otherwise, our explicit value would have been used. In either case, SQLAlchemy always knows how to get at a newly generated primary key value, even though the method of generating them is different across different databases; each databases' ``Dialect`` knows the specific steps needed to determine the correct value (or values; note that ``inserted_primary_key`` returns a list so that it supports composite primary keys). Executing Multiple Statements ============================== diff --git a/doc/build/testdocs.py b/doc/build/testdocs.py index 0a344e98ef..1f57e32720 100644 --- a/doc/build/testdocs.py +++ b/doc/build/testdocs.py @@ -55,7 +55,7 @@ def teststring(s, name, globs=None, verbose=None, report=True, return runner.failures, runner.tries def replace_file(s, newfile): - engine = r"'(sqlite|postgres|mysql):///.*'" + engine = r"'(sqlite|postgresql|mysql):///.*'" engine = re.compile(engine, re.MULTILINE) s, n = re.subn(engine, "'sqlite:///" + newfile + "'", s) if not n: diff --git a/examples/postgis/postgis.py b/examples/postgis/postgis.py index c482d82560..8e687d7f8c 100644 --- a/examples/postgis/postgis.py +++ b/examples/postgis/postgis.py @@ -231,7 +231,7 @@ if __name__ == '__main__': from sqlalchemy.orm import sessionmaker, column_property from sqlalchemy.ext.declarative import declarative_base - engine = create_engine('postgres://scott:tiger@localhost/gistest', echo=True) + engine = create_engine('postgresql://scott:tiger@localhost/gistest', echo=True) metadata = MetaData(engine) Base = declarative_base(metadata=metadata) diff --git a/examples/query_caching/query_caching.py b/examples/query_caching/query_caching.py index 92d48e2d78..00a4cc3ec2 100644 --- a/examples/query_caching/query_caching.py +++ b/examples/query_caching/query_caching.py @@ -27,7 +27,7 @@ class CachingQuery(Query): self.session.expunge(x) _cache[self.cachekey] = ret - return iter(self.session.merge(x, dont_load=True) for x in ret) + return iter(self.session.merge(x, load=False) for x in ret) else: return Query.__iter__(self) diff --git a/ez_setup.py b/ez_setup.py new file mode 100644 index 0000000000..d24e845e58 --- /dev/null +++ b/ez_setup.py @@ -0,0 +1,276 @@ +#!python +"""Bootstrap setuptools installation + +If you want to use setuptools in your package's setup.py, just include this +file in the same directory with it, and add this to the top of your setup.py:: + + from ez_setup import use_setuptools + use_setuptools() + +If you want to require a specific version of setuptools, set a download +mirror, or use an alternate download directory, you can do so by supplying +the appropriate options to ``use_setuptools()``. + +This file can also be run as a script to install or upgrade setuptools. +""" +import sys +DEFAULT_VERSION = "0.6c9" +DEFAULT_URL = "http://pypi.python.org/packages/%s/s/setuptools/" % sys.version[:3] + +md5_data = { + 'setuptools-0.6b1-py2.3.egg': '8822caf901250d848b996b7f25c6e6ca', + 'setuptools-0.6b1-py2.4.egg': 'b79a8a403e4502fbb85ee3f1941735cb', + 'setuptools-0.6b2-py2.3.egg': '5657759d8a6d8fc44070a9d07272d99b', + 'setuptools-0.6b2-py2.4.egg': '4996a8d169d2be661fa32a6e52e4f82a', + 'setuptools-0.6b3-py2.3.egg': 'bb31c0fc7399a63579975cad9f5a0618', + 'setuptools-0.6b3-py2.4.egg': '38a8c6b3d6ecd22247f179f7da669fac', + 'setuptools-0.6b4-py2.3.egg': '62045a24ed4e1ebc77fe039aa4e6f7e5', + 'setuptools-0.6b4-py2.4.egg': '4cb2a185d228dacffb2d17f103b3b1c4', + 'setuptools-0.6c1-py2.3.egg': 'b3f2b5539d65cb7f74ad79127f1a908c', + 'setuptools-0.6c1-py2.4.egg': 'b45adeda0667d2d2ffe14009364f2a4b', + 'setuptools-0.6c2-py2.3.egg': 'f0064bf6aa2b7d0f3ba0b43f20817c27', + 'setuptools-0.6c2-py2.4.egg': '616192eec35f47e8ea16cd6a122b7277', + 'setuptools-0.6c3-py2.3.egg': 'f181fa125dfe85a259c9cd6f1d7b78fa', + 'setuptools-0.6c3-py2.4.egg': 'e0ed74682c998bfb73bf803a50e7b71e', + 'setuptools-0.6c3-py2.5.egg': 'abef16fdd61955514841c7c6bd98965e', + 'setuptools-0.6c4-py2.3.egg': 'b0b9131acab32022bfac7f44c5d7971f', + 'setuptools-0.6c4-py2.4.egg': '2a1f9656d4fbf3c97bf946c0a124e6e2', + 'setuptools-0.6c4-py2.5.egg': '8f5a052e32cdb9c72bcf4b5526f28afc', + 'setuptools-0.6c5-py2.3.egg': 'ee9fd80965da04f2f3e6b3576e9d8167', + 'setuptools-0.6c5-py2.4.egg': 'afe2adf1c01701ee841761f5bcd8aa64', + 'setuptools-0.6c5-py2.5.egg': 'a8d3f61494ccaa8714dfed37bccd3d5d', + 'setuptools-0.6c6-py2.3.egg': '35686b78116a668847237b69d549ec20', + 'setuptools-0.6c6-py2.4.egg': '3c56af57be3225019260a644430065ab', + 'setuptools-0.6c6-py2.5.egg': 'b2f8a7520709a5b34f80946de5f02f53', + 'setuptools-0.6c7-py2.3.egg': '209fdf9adc3a615e5115b725658e13e2', + 'setuptools-0.6c7-py2.4.egg': '5a8f954807d46a0fb67cf1f26c55a82e', + 'setuptools-0.6c7-py2.5.egg': '45d2ad28f9750e7434111fde831e8372', + 'setuptools-0.6c8-py2.3.egg': '50759d29b349db8cfd807ba8303f1902', + 'setuptools-0.6c8-py2.4.egg': 'cba38d74f7d483c06e9daa6070cce6de', + 'setuptools-0.6c8-py2.5.egg': '1721747ee329dc150590a58b3e1ac95b', + 'setuptools-0.6c9-py2.3.egg': 'a83c4020414807b496e4cfbe08507c03', + 'setuptools-0.6c9-py2.4.egg': '260a2be2e5388d66bdaee06abec6342a', + 'setuptools-0.6c9-py2.5.egg': 'fe67c3e5a17b12c0e7c541b7ea43a8e6', + 'setuptools-0.6c9-py2.6.egg': 'ca37b1ff16fa2ede6e19383e7b59245a', +} + +import sys, os +try: from hashlib import md5 +except ImportError: from md5 import md5 + +def _validate_md5(egg_name, data): + if egg_name in md5_data: + digest = md5(data).hexdigest() + if digest != md5_data[egg_name]: + print >>sys.stderr, ( + "md5 validation of %s failed! (Possible download problem?)" + % egg_name + ) + sys.exit(2) + return data + +def use_setuptools( + version=DEFAULT_VERSION, download_base=DEFAULT_URL, to_dir=os.curdir, + download_delay=15 +): + """Automatically find/download setuptools and make it available on sys.path + + `version` should be a valid setuptools version number that is available + as an egg for download under the `download_base` URL (which should end with + a '/'). `to_dir` is the directory where setuptools will be downloaded, if + it is not already available. If `download_delay` is specified, it should + be the number of seconds that will be paused before initiating a download, + should one be required. If an older version of setuptools is installed, + this routine will print a message to ``sys.stderr`` and raise SystemExit in + an attempt to abort the calling script. + """ + was_imported = 'pkg_resources' in sys.modules or 'setuptools' in sys.modules + def do_download(): + egg = download_setuptools(version, download_base, to_dir, download_delay) + sys.path.insert(0, egg) + import setuptools; setuptools.bootstrap_install_from = egg + try: + import pkg_resources + except ImportError: + return do_download() + try: + pkg_resources.require("setuptools>="+version); return + except pkg_resources.VersionConflict, e: + if was_imported: + print >>sys.stderr, ( + "The required version of setuptools (>=%s) is not available, and\n" + "can't be installed while this script is running. Please install\n" + " a more recent version first, using 'easy_install -U setuptools'." + "\n\n(Currently using %r)" + ) % (version, e.args[0]) + sys.exit(2) + else: + del pkg_resources, sys.modules['pkg_resources'] # reload ok + return do_download() + except pkg_resources.DistributionNotFound: + return do_download() + +def download_setuptools( + version=DEFAULT_VERSION, download_base=DEFAULT_URL, to_dir=os.curdir, + delay = 15 +): + """Download setuptools from a specified location and return its filename + + `version` should be a valid setuptools version number that is available + as an egg for download under the `download_base` URL (which should end + with a '/'). `to_dir` is the directory where the egg will be downloaded. + `delay` is the number of seconds to pause before an actual download attempt. + """ + import urllib2, shutil + egg_name = "setuptools-%s-py%s.egg" % (version,sys.version[:3]) + url = download_base + egg_name + saveto = os.path.join(to_dir, egg_name) + src = dst = None + if not os.path.exists(saveto): # Avoid repeated downloads + try: + from distutils import log + if delay: + log.warn(""" +--------------------------------------------------------------------------- +This script requires setuptools version %s to run (even to display +help). I will attempt to download it for you (from +%s), but +you may need to enable firewall access for this script first. +I will start the download in %d seconds. + +(Note: if this machine does not have network access, please obtain the file + + %s + +and place it in this directory before rerunning this script.) +---------------------------------------------------------------------------""", + version, download_base, delay, url + ); from time import sleep; sleep(delay) + log.warn("Downloading %s", url) + src = urllib2.urlopen(url) + # Read/write all in one block, so we don't create a corrupt file + # if the download is interrupted. + data = _validate_md5(egg_name, src.read()) + dst = open(saveto,"wb"); dst.write(data) + finally: + if src: src.close() + if dst: dst.close() + return os.path.realpath(saveto) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +def main(argv, version=DEFAULT_VERSION): + """Install or upgrade setuptools and EasyInstall""" + try: + import setuptools + except ImportError: + egg = None + try: + egg = download_setuptools(version, delay=0) + sys.path.insert(0,egg) + from setuptools.command.easy_install import main + return main(list(argv)+[egg]) # we're done here + finally: + if egg and os.path.exists(egg): + os.unlink(egg) + else: + if setuptools.__version__ == '0.0.1': + print >>sys.stderr, ( + "You have an obsolete version of setuptools installed. Please\n" + "remove it from your system entirely before rerunning this script." + ) + sys.exit(2) + + req = "setuptools>="+version + import pkg_resources + try: + pkg_resources.require(req) + except pkg_resources.VersionConflict: + try: + from setuptools.command.easy_install import main + except ImportError: + from easy_install import main + main(list(argv)+[download_setuptools(delay=0)]) + sys.exit(0) # try to force an exit + else: + if argv: + from setuptools.command.easy_install import main + main(argv) + else: + print "Setuptools version",version,"or greater has been installed." + print '(Run "ez_setup.py -U setuptools" to reinstall or upgrade.)' + +def update_md5(filenames): + """Update our built-in md5 registry""" + + import re + + for name in filenames: + base = os.path.basename(name) + f = open(name,'rb') + md5_data[base] = md5(f.read()).hexdigest() + f.close() + + data = [" %r: %r,\n" % it for it in md5_data.items()] + data.sort() + repl = "".join(data) + + import inspect + srcfile = inspect.getsourcefile(sys.modules[__name__]) + f = open(srcfile, 'rb'); src = f.read(); f.close() + + match = re.search("\nmd5_data = {\n([^}]+)}", src) + if not match: + print >>sys.stderr, "Internal error!" + sys.exit(2) + + src = src[:match.start(1)] + repl + src[match.end(1):] + f = open(srcfile,'w') + f.write(src) + f.close() + + +if __name__=='__main__': + if len(sys.argv)>2 and sys.argv[1]=='--md5update': + update_md5(sys.argv[2:]) + else: + main(sys.argv[1:]) + + + + + + diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py index ddbbb7b7ed..31469ee5ae 100644 --- a/lib/sqlalchemy/__init__.py +++ b/lib/sqlalchemy/__init__.py @@ -10,40 +10,6 @@ import sys import sqlalchemy.exc as exceptions sys.modules['sqlalchemy.exceptions'] = exceptions -from sqlalchemy.types import ( - BLOB, - BOOLEAN, - Binary, - Boolean, - CHAR, - CLOB, - DATE, - DATETIME, - DECIMAL, - Date, - DateTime, - FLOAT, - Float, - INT, - Integer, - Interval, - NCHAR, - NUMERIC, - Numeric, - PickleType, - SMALLINT, - SmallInteger, - String, - TEXT, - TIME, - TIMESTAMP, - Text, - Time, - Unicode, - UnicodeText, - VARCHAR, - ) - from sqlalchemy.sql import ( alias, and_, @@ -81,6 +47,43 @@ from sqlalchemy.sql import ( update, ) +from sqlalchemy.types import ( + BLOB, + BOOLEAN, + Binary, + Boolean, + CHAR, + CLOB, + DATE, + DATETIME, + DECIMAL, + Date, + DateTime, + FLOAT, + Float, + INT, + INTEGER, + Integer, + Interval, + NCHAR, + NVARCHAR, + NUMERIC, + Numeric, + PickleType, + SMALLINT, + SmallInteger, + String, + TEXT, + TIME, + TIMESTAMP, + Text, + Time, + Unicode, + UnicodeText, + VARCHAR, + ) + + from sqlalchemy.schema import ( CheckConstraint, Column, @@ -107,6 +110,6 @@ from sqlalchemy.engine import create_engine, engine_from_config __all__ = sorted(name for name, obj in locals().items() if not (name.startswith('_') or inspect.ismodule(obj))) -__version__ = '0.5.5' +__version__ = '0.6beta1' del inspect, sys diff --git a/lib/sqlalchemy/connectors/__init__.py b/lib/sqlalchemy/connectors/__init__.py new file mode 100644 index 0000000000..f1383ad829 --- /dev/null +++ b/lib/sqlalchemy/connectors/__init__.py @@ -0,0 +1,6 @@ + + +class Connector(object): + pass + + \ No newline at end of file diff --git a/lib/sqlalchemy/connectors/mxodbc.py b/lib/sqlalchemy/connectors/mxodbc.py new file mode 100644 index 0000000000..a0f3f02161 --- /dev/null +++ b/lib/sqlalchemy/connectors/mxodbc.py @@ -0,0 +1,24 @@ +from sqlalchemy.connectors import Connector + +class MxODBCConnector(Connector): + driver='mxodbc' + supports_sane_rowcount = False + supports_sane_multi_rowcount = False + supports_unicode_statements = False + supports_unicode_binds = False + + @classmethod + def import_dbapi(cls): + import mxODBC as module + return module + + def create_connect_args(self, url): + '''Return a tuple of *args,**kwargs''' + # FIXME: handle mx.odbc.Windows proprietary args + opts = url.translate_connect_args(username='user') + opts.update(url.query) + argsDict = {} + argsDict['user'] = opts['user'] + argsDict['password'] = opts['password'] + connArgs = [[opts['dsn']], argsDict] + return connArgs diff --git a/lib/sqlalchemy/connectors/pyodbc.py b/lib/sqlalchemy/connectors/pyodbc.py new file mode 100644 index 0000000000..4f8d6d517f --- /dev/null +++ b/lib/sqlalchemy/connectors/pyodbc.py @@ -0,0 +1,80 @@ +from sqlalchemy.connectors import Connector + +import sys +import re +import urllib + +class PyODBCConnector(Connector): + driver='pyodbc' + supports_sane_rowcount = False + supports_sane_multi_rowcount = False + # PyODBC unicode is broken on UCS-4 builds + supports_unicode = sys.maxunicode == 65535 + supports_unicode_statements = supports_unicode + default_paramstyle = 'named' + + # for non-DSN connections, this should + # hold the desired driver name + pyodbc_driver_name = None + + @classmethod + def dbapi(cls): + return __import__('pyodbc') + + def create_connect_args(self, url): + opts = url.translate_connect_args(username='user') + opts.update(url.query) + + keys = opts + query = url.query + + if 'odbc_connect' in keys: + connectors = [urllib.unquote_plus(keys.pop('odbc_connect'))] + else: + dsn_connection = 'dsn' in keys or ('host' in keys and 'database' not in keys) + if dsn_connection: + connectors= ['dsn=%s' % (keys.pop('host', '') or keys.pop('dsn', ''))] + else: + port = '' + if 'port' in keys and not 'port' in query: + port = ',%d' % int(keys.pop('port')) + + connectors = ["DRIVER={%s}" % keys.pop('driver', self.pyodbc_driver_name), + 'Server=%s%s' % (keys.pop('host', ''), port), + 'Database=%s' % keys.pop('database', '') ] + + user = keys.pop("user", None) + if user: + connectors.append("UID=%s" % user) + connectors.append("PWD=%s" % keys.pop('password', '')) + else: + connectors.append("TrustedConnection=Yes") + + # if set to 'Yes', the ODBC layer will try to automagically convert + # textual data from your database encoding to your client encoding + # This should obviously be set to 'No' if you query a cp1253 encoded + # database from a latin1 client... + if 'odbc_autotranslate' in keys: + connectors.append("AutoTranslate=%s" % keys.pop("odbc_autotranslate")) + + connectors.extend(['%s=%s' % (k,v) for k,v in keys.iteritems()]) + return [[";".join (connectors)], {}] + + def is_disconnect(self, e): + if isinstance(e, self.dbapi.ProgrammingError): + return "The cursor's connection has been closed." in str(e) or 'Attempt to use a closed connection.' in str(e) + elif isinstance(e, self.dbapi.Error): + return '[08S01]' in str(e) + else: + return False + + def _get_server_version_info(self, connection): + dbapi_con = connection.connection + version = [] + r = re.compile('[.\-]') + for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)): + try: + version.append(int(n)) + except ValueError: + version.append(n) + return tuple(version) diff --git a/lib/sqlalchemy/connectors/zxJDBC.py b/lib/sqlalchemy/connectors/zxJDBC.py new file mode 100644 index 0000000000..3cdfeb32e4 --- /dev/null +++ b/lib/sqlalchemy/connectors/zxJDBC.py @@ -0,0 +1,43 @@ +import sys +from sqlalchemy.connectors import Connector + +class ZxJDBCConnector(Connector): + driver = 'zxjdbc' + + supports_sane_rowcount = False + supports_sane_multi_rowcount = False + + supports_unicode_binds = True + supports_unicode_statements = sys.version > '2.5.0+' + description_encoding = None + default_paramstyle = 'qmark' + + jdbc_db_name = None + jdbc_driver_name = None + + @classmethod + def dbapi(cls): + from com.ziclix.python.sql import zxJDBC + return zxJDBC + + def _driver_kwargs(self): + """return kw arg dict to be sent to connect().""" + return {} + + def create_connect_args(self, url): + hostname = url.host + dbname = url.database + d, u, p, v = ("jdbc:%s://%s/%s" % (self.jdbc_db_name, hostname, dbname), + url.username, url.password, self.jdbc_driver_name) + return [[d, u, p, v], self._driver_kwargs()] + + def is_disconnect(self, e): + if not isinstance(e, self.dbapi.ProgrammingError): + return False + e = str(e) + return 'connection is closed' in e or 'cursor is closed' in e + + def _get_server_version_info(self, connection): + # use connection.connection.dbversion, and parse appropriately + # to get a tuple + raise NotImplementedError() diff --git a/lib/sqlalchemy/databases/__init__.py b/lib/sqlalchemy/databases/__init__.py index 6588be0ae7..16cabd47f8 100644 --- a/lib/sqlalchemy/databases/__init__.py +++ b/lib/sqlalchemy/databases/__init__.py @@ -4,6 +4,20 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php +from sqlalchemy.dialects.sqlite import base as sqlite +from sqlalchemy.dialects.postgresql import base as postgresql +postgres = postgresql +from sqlalchemy.dialects.mysql import base as mysql +from sqlalchemy.dialects.oracle import base as oracle +from sqlalchemy.dialects.firebird import base as firebird +from sqlalchemy.dialects.maxdb import base as maxdb +from sqlalchemy.dialects.informix import base as informix +from sqlalchemy.dialects.mssql import base as mssql +from sqlalchemy.dialects.access import base as access +from sqlalchemy.dialects.sybase import base as sybase + + + __all__ = ( 'access', @@ -12,8 +26,8 @@ __all__ = ( 'maxdb', 'mssql', 'mysql', - 'oracle', - 'postgres', + 'postgresql', 'sqlite', + 'oracle', 'sybase', ) diff --git a/lib/sqlalchemy/databases/firebird.py b/lib/sqlalchemy/databases/firebird.py deleted file mode 100644 index 8a8d02d4a1..0000000000 --- a/lib/sqlalchemy/databases/firebird.py +++ /dev/null @@ -1,768 +0,0 @@ -# firebird.py -# Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com -# -# This module is part of SQLAlchemy and is released under -# the MIT License: http://www.opensource.org/licenses/mit-license.php - -""" -Firebird backend -================ - -This module implements the Firebird backend, thru the kinterbasdb_ -DBAPI module. - -Firebird dialects ------------------ - -Firebird offers two distinct dialects_ (not to be confused with the -SA ``Dialect`` thing): - -dialect 1 - This is the old syntax and behaviour, inherited from Interbase pre-6.0. - -dialect 3 - This is the newer and supported syntax, introduced in Interbase 6.0. - -From the user point of view, the biggest change is in date/time -handling: under dialect 1, there's a single kind of field, ``DATE`` -with a synonim ``DATETIME``, that holds a `timestamp` value, that is a -date with hour, minute, second. Under dialect 3 there are three kinds, -a ``DATE`` that holds a date, a ``TIME`` that holds a *time of the -day* value and a ``TIMESTAMP``, equivalent to the old ``DATE``. - -The problem is that the dialect of a Firebird database is a property -of the database itself [#]_ (that is, any single database has been -created with one dialect or the other: there is no way to change the -after creation). SQLAlchemy has a single instance of the class that -controls all the connections to a particular kind of database, so it -cannot easily differentiate between the two modes, and in particular -it **cannot** simultaneously talk with two distinct Firebird databases -with different dialects. - -By default this module is biased toward dialect 3, but you can easily -tweak it to handle dialect 1 if needed:: - - from sqlalchemy import types as sqltypes - from sqlalchemy.databases.firebird import FBDate, colspecs, ischema_names - - # Adjust the mapping of the timestamp kind - ischema_names['TIMESTAMP'] = FBDate - colspecs[sqltypes.DateTime] = FBDate, - -Other aspects may be version-specific. You can use the ``server_version_info()`` method -on the ``FBDialect`` class to do whatever is needed:: - - from sqlalchemy.databases.firebird import FBCompiler - - if engine.dialect.server_version_info(connection) < (2,0): - # Change the name of the function ``length`` to use the UDF version - # instead of ``char_length`` - FBCompiler.LENGTH_FUNCTION_NAME = 'strlen' - -Pooling connections -------------------- - -The default strategy used by SQLAlchemy to pool the database connections -in particular cases may raise an ``OperationalError`` with a message -`"object XYZ is in use"`. This happens on Firebird when there are two -connections to the database, one is using, or has used, a particular table -and the other tries to drop or alter the same table. To garantee DDL -operations success Firebird recommend doing them as the single connected user. - -In case your SA application effectively needs to do DDL operations while other -connections are active, the following setting may alleviate the problem:: - - from sqlalchemy import pool - from sqlalchemy.databases.firebird import dialect - - # Force SA to use a single connection per thread - dialect.poolclass = pool.SingletonThreadPool - -RETURNING support ------------------ - -Firebird 2.0 supports returning a result set from inserts, and 2.1 extends -that to deletes and updates. - -To use this pass the column/expression list to the ``firebird_returning`` -parameter when creating the queries:: - - raises = tbl.update(empl.c.sales > 100, values=dict(salary=empl.c.salary * 1.1), - firebird_returning=[empl.c.id, empl.c.salary]).execute().fetchall() - - -.. [#] Well, that is not the whole story, as the client may still ask - a different (lower) dialect... - -.. _dialects: http://mc-computing.com/Databases/Firebird/SQL_Dialect.html -.. _kinterbasdb: http://sourceforge.net/projects/kinterbasdb -""" - - -import datetime, decimal, re - -from sqlalchemy import exc, schema, types as sqltypes, sql, util -from sqlalchemy.engine import base, default - - -_initialized_kb = False - - -class FBNumeric(sqltypes.Numeric): - """Handle ``NUMERIC(precision,scale)`` datatype.""" - - def get_col_spec(self): - if self.precision is None: - return "NUMERIC" - else: - return "NUMERIC(%(precision)s, %(scale)s)" % { 'precision': self.precision, - 'scale' : self.scale } - - def bind_processor(self, dialect): - return None - - def result_processor(self, dialect): - if self.asdecimal: - return None - else: - def process(value): - if isinstance(value, decimal.Decimal): - return float(value) - else: - return value - return process - - -class FBFloat(sqltypes.Float): - """Handle ``FLOAT(precision)`` datatype.""" - - def get_col_spec(self): - if not self.precision: - return "FLOAT" - else: - return "FLOAT(%(precision)s)" % {'precision': self.precision} - - -class FBInteger(sqltypes.Integer): - """Handle ``INTEGER`` datatype.""" - - def get_col_spec(self): - return "INTEGER" - - -class FBSmallInteger(sqltypes.Smallinteger): - """Handle ``SMALLINT`` datatype.""" - - def get_col_spec(self): - return "SMALLINT" - - -class FBDateTime(sqltypes.DateTime): - """Handle ``TIMESTAMP`` datatype.""" - - def get_col_spec(self): - return "TIMESTAMP" - - def bind_processor(self, dialect): - def process(value): - if value is None or isinstance(value, datetime.datetime): - return value - else: - return datetime.datetime(year=value.year, - month=value.month, - day=value.day) - return process - - -class FBDate(sqltypes.DateTime): - """Handle ``DATE`` datatype.""" - - def get_col_spec(self): - return "DATE" - - -class FBTime(sqltypes.Time): - """Handle ``TIME`` datatype.""" - - def get_col_spec(self): - return "TIME" - - -class FBText(sqltypes.Text): - """Handle ``BLOB SUB_TYPE 1`` datatype (aka *textual* blob).""" - - def get_col_spec(self): - return "BLOB SUB_TYPE 1" - - -class FBString(sqltypes.String): - """Handle ``VARCHAR(length)`` datatype.""" - - def get_col_spec(self): - if self.length: - return "VARCHAR(%(length)s)" % {'length' : self.length} - else: - return "BLOB SUB_TYPE 1" - - -class FBChar(sqltypes.CHAR): - """Handle ``CHAR(length)`` datatype.""" - - def get_col_spec(self): - if self.length: - return "CHAR(%(length)s)" % {'length' : self.length} - else: - return "BLOB SUB_TYPE 1" - - -class FBBinary(sqltypes.Binary): - """Handle ``BLOB SUB_TYPE 0`` datatype (aka *binary* blob).""" - - def get_col_spec(self): - return "BLOB SUB_TYPE 0" - - -class FBBoolean(sqltypes.Boolean): - """Handle boolean values as a ``SMALLINT`` datatype.""" - - def get_col_spec(self): - return "SMALLINT" - - -colspecs = { - sqltypes.Integer : FBInteger, - sqltypes.Smallinteger : FBSmallInteger, - sqltypes.Numeric : FBNumeric, - sqltypes.Float : FBFloat, - sqltypes.DateTime : FBDateTime, - sqltypes.Date : FBDate, - sqltypes.Time : FBTime, - sqltypes.String : FBString, - sqltypes.Binary : FBBinary, - sqltypes.Boolean : FBBoolean, - sqltypes.Text : FBText, - sqltypes.CHAR: FBChar, -} - - -ischema_names = { - 'SHORT': lambda r: FBSmallInteger(), - 'LONG': lambda r: FBInteger(), - 'QUAD': lambda r: FBFloat(), - 'FLOAT': lambda r: FBFloat(), - 'DATE': lambda r: FBDate(), - 'TIME': lambda r: FBTime(), - 'TEXT': lambda r: FBString(r['flen']), - 'INT64': lambda r: FBNumeric(precision=r['fprec'], scale=r['fscale'] * -1), # This generically handles NUMERIC() - 'DOUBLE': lambda r: FBFloat(), - 'TIMESTAMP': lambda r: FBDateTime(), - 'VARYING': lambda r: FBString(r['flen']), - 'CSTRING': lambda r: FBChar(r['flen']), - 'BLOB': lambda r: r['stype']==1 and FBText() or FBBinary() - } - -RETURNING_KW_NAME = 'firebird_returning' - -class FBExecutionContext(default.DefaultExecutionContext): - pass - - -class FBDialect(default.DefaultDialect): - """Firebird dialect""" - name = 'firebird' - supports_sane_rowcount = False - supports_sane_multi_rowcount = False - max_identifier_length = 31 - preexecute_pk_sequences = True - supports_pk_autoincrement = False - - def __init__(self, type_conv=200, concurrency_level=1, **kwargs): - default.DefaultDialect.__init__(self, **kwargs) - - self.type_conv = type_conv - self.concurrency_level = concurrency_level - - def dbapi(cls): - import kinterbasdb - return kinterbasdb - dbapi = classmethod(dbapi) - - def create_connect_args(self, url): - opts = url.translate_connect_args(username='user') - if opts.get('port'): - opts['host'] = "%s/%s" % (opts['host'], opts['port']) - del opts['port'] - opts.update(url.query) - - type_conv = opts.pop('type_conv', self.type_conv) - concurrency_level = opts.pop('concurrency_level', self.concurrency_level) - global _initialized_kb - if not _initialized_kb and self.dbapi is not None: - _initialized_kb = True - self.dbapi.init(type_conv=type_conv, concurrency_level=concurrency_level) - return ([], opts) - - def type_descriptor(self, typeobj): - return sqltypes.adapt_type(typeobj, colspecs) - - def server_version_info(self, connection): - """Get the version of the Firebird server used by a connection. - - Returns a tuple of (`major`, `minor`, `build`), three integers - representing the version of the attached server. - """ - - # This is the simpler approach (the other uses the services api), - # that for backward compatibility reasons returns a string like - # LI-V6.3.3.12981 Firebird 2.0 - # where the first version is a fake one resembling the old - # Interbase signature. This is more than enough for our purposes, - # as this is mainly (only?) used by the testsuite. - - from re import match - - fbconn = connection.connection.connection - version = fbconn.server_version - m = match('\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+) \w+ (\d+)\.(\d+)', version) - if not m: - raise AssertionError("Could not determine version from string '%s'" % version) - return tuple([int(x) for x in m.group(5, 6, 4)]) - - def _normalize_name(self, name): - """Convert the name to lowercase if it is possible""" - - # Remove trailing spaces: FB uses a CHAR() type, - # that is padded with spaces - name = name and name.rstrip() - if name is None: - return None - elif name.upper() == name and not self.identifier_preparer._requires_quotes(name.lower()): - return name.lower() - else: - return name - - def _denormalize_name(self, name): - """Revert a *normalized* name to its uppercase equivalent""" - - if name is None: - return None - elif name.lower() == name and not self.identifier_preparer._requires_quotes(name.lower()): - return name.upper() - else: - return name - - def table_names(self, connection, schema): - """Return a list of *normalized* table names omitting system relations.""" - - s = """ - SELECT r.rdb$relation_name - FROM rdb$relations r - WHERE r.rdb$system_flag=0 - """ - return [self._normalize_name(row[0]) for row in connection.execute(s)] - - def has_table(self, connection, table_name, schema=None): - """Return ``True`` if the given table exists, ignoring the `schema`.""" - - tblqry = """ - SELECT 1 FROM rdb$database - WHERE EXISTS (SELECT rdb$relation_name - FROM rdb$relations - WHERE rdb$relation_name=?) - """ - c = connection.execute(tblqry, [self._denormalize_name(table_name)]) - row = c.fetchone() - if row is not None: - return True - else: - return False - - def has_sequence(self, connection, sequence_name): - """Return ``True`` if the given sequence (generator) exists.""" - - genqry = """ - SELECT 1 FROM rdb$database - WHERE EXISTS (SELECT rdb$generator_name - FROM rdb$generators - WHERE rdb$generator_name=?) - """ - c = connection.execute(genqry, [self._denormalize_name(sequence_name)]) - row = c.fetchone() - if row is not None: - return True - else: - return False - - def is_disconnect(self, e): - if isinstance(e, self.dbapi.OperationalError): - return 'Unable to complete network request to host' in str(e) - elif isinstance(e, self.dbapi.ProgrammingError): - msg = str(e) - return ('Invalid connection state' in msg or - 'Invalid cursor state' in msg) - else: - return False - - def reflecttable(self, connection, table, include_columns): - # Query to extract the details of all the fields of the given table - tblqry = """ - SELECT DISTINCT r.rdb$field_name AS fname, - r.rdb$null_flag AS null_flag, - t.rdb$type_name AS ftype, - f.rdb$field_sub_type AS stype, - f.rdb$field_length AS flen, - f.rdb$field_precision AS fprec, - f.rdb$field_scale AS fscale, - COALESCE(r.rdb$default_source, f.rdb$default_source) AS fdefault - FROM rdb$relation_fields r - JOIN rdb$fields f ON r.rdb$field_source=f.rdb$field_name - JOIN rdb$types t ON t.rdb$type=f.rdb$field_type AND t.rdb$field_name='RDB$FIELD_TYPE' - WHERE f.rdb$system_flag=0 AND r.rdb$relation_name=? - ORDER BY r.rdb$field_position - """ - # Query to extract the PK/FK constrained fields of the given table - keyqry = """ - SELECT se.rdb$field_name AS fname - FROM rdb$relation_constraints rc - JOIN rdb$index_segments se ON rc.rdb$index_name=se.rdb$index_name - WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=? - """ - # Query to extract the details of each UK/FK of the given table - fkqry = """ - SELECT rc.rdb$constraint_name AS cname, - cse.rdb$field_name AS fname, - ix2.rdb$relation_name AS targetrname, - se.rdb$field_name AS targetfname - FROM rdb$relation_constraints rc - JOIN rdb$indices ix1 ON ix1.rdb$index_name=rc.rdb$index_name - JOIN rdb$indices ix2 ON ix2.rdb$index_name=ix1.rdb$foreign_key - JOIN rdb$index_segments cse ON cse.rdb$index_name=ix1.rdb$index_name - JOIN rdb$index_segments se ON se.rdb$index_name=ix2.rdb$index_name AND se.rdb$field_position=cse.rdb$field_position - WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=? - ORDER BY se.rdb$index_name, se.rdb$field_position - """ - # Heuristic-query to determine the generator associated to a PK field - genqry = """ - SELECT trigdep.rdb$depended_on_name AS fgenerator - FROM rdb$dependencies tabdep - JOIN rdb$dependencies trigdep ON (tabdep.rdb$dependent_name=trigdep.rdb$dependent_name - AND trigdep.rdb$depended_on_type=14 - AND trigdep.rdb$dependent_type=2) - JOIN rdb$triggers trig ON (trig.rdb$trigger_name=tabdep.rdb$dependent_name) - WHERE tabdep.rdb$depended_on_name=? - AND tabdep.rdb$depended_on_type=0 - AND trig.rdb$trigger_type=1 - AND tabdep.rdb$field_name=? - AND (SELECT count(*) - FROM rdb$dependencies trigdep2 - WHERE trigdep2.rdb$dependent_name = trigdep.rdb$dependent_name) = 2 - """ - - tablename = self._denormalize_name(table.name) - - # get primary key fields - c = connection.execute(keyqry, ["PRIMARY KEY", tablename]) - pkfields = [self._normalize_name(r['fname']) for r in c.fetchall()] - - # get all of the fields for this table - c = connection.execute(tblqry, [tablename]) - - found_table = False - while True: - row = c.fetchone() - if row is None: - break - found_table = True - - name = self._normalize_name(row['fname']) - if include_columns and name not in include_columns: - continue - args = [name] - - kw = {} - # get the data type - coltype = ischema_names.get(row['ftype'].rstrip()) - if coltype is None: - util.warn("Did not recognize type '%s' of column '%s'" % - (str(row['ftype']), name)) - coltype = sqltypes.NULLTYPE - else: - coltype = coltype(row) - args.append(coltype) - - # is it a primary key? - kw['primary_key'] = name in pkfields - - # is it nullable? - kw['nullable'] = not bool(row['null_flag']) - - # does it have a default value? - if row['fdefault'] is not None: - # the value comes down as "DEFAULT 'value'" - assert row['fdefault'].upper().startswith('DEFAULT '), row - defvalue = row['fdefault'][8:] - args.append(schema.DefaultClause(sql.text(defvalue))) - - col = schema.Column(*args, **kw) - if kw['primary_key']: - # if the PK is a single field, try to see if its linked to - # a sequence thru a trigger - if len(pkfields)==1: - genc = connection.execute(genqry, [tablename, row['fname']]) - genr = genc.fetchone() - if genr is not None: - col.sequence = schema.Sequence(self._normalize_name(genr['fgenerator'])) - - table.append_column(col) - - if not found_table: - raise exc.NoSuchTableError(table.name) - - # get the foreign keys - c = connection.execute(fkqry, ["FOREIGN KEY", tablename]) - fks = {} - while True: - row = c.fetchone() - if not row: - break - - cname = self._normalize_name(row['cname']) - try: - fk = fks[cname] - except KeyError: - fks[cname] = fk = ([], []) - rname = self._normalize_name(row['targetrname']) - schema.Table(rname, table.metadata, autoload=True, autoload_with=connection) - fname = self._normalize_name(row['fname']) - refspec = rname + '.' + self._normalize_name(row['targetfname']) - fk[0].append(fname) - fk[1].append(refspec) - - for name, value in fks.iteritems(): - table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], name=name, link_to_name=True)) - - def do_execute(self, cursor, statement, parameters, **kwargs): - # kinterbase does not accept a None, but wants an empty list - # when there are no arguments. - cursor.execute(statement, parameters or []) - - def do_rollback(self, connection): - # Use the retaining feature, that keeps the transaction going - connection.rollback(True) - - def do_commit(self, connection): - # Use the retaining feature, that keeps the transaction going - connection.commit(True) - - -def _substring(s, start, length=None): - "Helper function to handle Firebird 2 SUBSTRING builtin" - - if length is None: - return "SUBSTRING(%s FROM %s)" % (s, start) - else: - return "SUBSTRING(%s FROM %s FOR %s)" % (s, start, length) - - -class FBCompiler(sql.compiler.DefaultCompiler): - """Firebird specific idiosincrasies""" - - # Firebird lacks a builtin modulo operator, but there is - # an equivalent function in the ib_udf library. - operators = sql.compiler.DefaultCompiler.operators.copy() - operators.update({ - sql.operators.mod : lambda x, y:"mod(%s, %s)" % (x, y) - }) - - def visit_alias(self, alias, asfrom=False, **kwargs): - # Override to not use the AS keyword which FB 1.5 does not like - if asfrom: - return self.process(alias.original, asfrom=True, **kwargs) + " " + self.preparer.format_alias(alias, self._anonymize(alias.name)) - else: - return self.process(alias.original, **kwargs) - - functions = sql.compiler.DefaultCompiler.functions.copy() - functions['substring'] = _substring - - def function_argspec(self, func): - if func.clauses: - return self.process(func.clause_expr) - else: - return "" - - def default_from(self): - return " FROM rdb$database" - - def visit_sequence(self, seq): - return "gen_id(%s, 1)" % self.preparer.format_sequence(seq) - - def get_select_precolumns(self, select): - """Called when building a ``SELECT`` statement, position is just - before column list Firebird puts the limit and offset right - after the ``SELECT``... - """ - - result = "" - if select._limit: - result += "FIRST %d " % select._limit - if select._offset: - result +="SKIP %d " % select._offset - if select._distinct: - result += "DISTINCT " - return result - - def limit_clause(self, select): - """Already taken care of in the `get_select_precolumns` method.""" - - return "" - - LENGTH_FUNCTION_NAME = 'char_length' - def function_string(self, func): - """Substitute the ``length`` function. - - On newer FB there is a ``char_length`` function, while older - ones need the ``strlen`` UDF. - """ - - if func.name == 'length': - return self.LENGTH_FUNCTION_NAME + '%(expr)s' - return super(FBCompiler, self).function_string(func) - - def _append_returning(self, text, stmt): - returning_cols = stmt.kwargs[RETURNING_KW_NAME] - def flatten_columnlist(collist): - for c in collist: - if isinstance(c, sql.expression.Selectable): - for co in c.columns: - yield co - else: - yield c - columns = [self.process(c, within_columns_clause=True) - for c in flatten_columnlist(returning_cols)] - text += ' RETURNING ' + ', '.join(columns) - return text - - def visit_update(self, update_stmt): - text = super(FBCompiler, self).visit_update(update_stmt) - if RETURNING_KW_NAME in update_stmt.kwargs: - return self._append_returning(text, update_stmt) - else: - return text - - def visit_insert(self, insert_stmt): - text = super(FBCompiler, self).visit_insert(insert_stmt) - if RETURNING_KW_NAME in insert_stmt.kwargs: - return self._append_returning(text, insert_stmt) - else: - return text - - def visit_delete(self, delete_stmt): - text = super(FBCompiler, self).visit_delete(delete_stmt) - if RETURNING_KW_NAME in delete_stmt.kwargs: - return self._append_returning(text, delete_stmt) - else: - return text - - -class FBSchemaGenerator(sql.compiler.SchemaGenerator): - """Firebird syntactic idiosincrasies""" - - def get_column_specification(self, column, **kwargs): - colspec = self.preparer.format_column(column) - colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec() - - default = self.get_column_default_string(column) - if default is not None: - colspec += " DEFAULT " + default - - if not column.nullable or column.primary_key: - colspec += " NOT NULL" - - return colspec - - def visit_sequence(self, sequence): - """Generate a ``CREATE GENERATOR`` statement for the sequence.""" - - if not self.checkfirst or not self.dialect.has_sequence(self.connection, sequence.name): - self.append("CREATE GENERATOR %s" % self.preparer.format_sequence(sequence)) - self.execute() - - -class FBSchemaDropper(sql.compiler.SchemaDropper): - """Firebird syntactic idiosincrasies""" - - def visit_sequence(self, sequence): - """Generate a ``DROP GENERATOR`` statement for the sequence.""" - - if not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name): - self.append("DROP GENERATOR %s" % self.preparer.format_sequence(sequence)) - self.execute() - - -class FBDefaultRunner(base.DefaultRunner): - """Firebird specific idiosincrasies""" - - def visit_sequence(self, seq): - """Get the next value from the sequence using ``gen_id()``.""" - - return self.execute_string("SELECT gen_id(%s, 1) FROM rdb$database" % \ - self.dialect.identifier_preparer.format_sequence(seq)) - - -RESERVED_WORDS = set( - ["action", "active", "add", "admin", "after", "all", "alter", "and", "any", - "as", "asc", "ascending", "at", "auto", "autoddl", "avg", "based", "basename", - "base_name", "before", "begin", "between", "bigint", "blob", "blobedit", "buffer", - "by", "cache", "cascade", "case", "cast", "char", "character", "character_length", - "char_length", "check", "check_point_len", "check_point_length", "close", "collate", - "collation", "column", "commit", "committed", "compiletime", "computed", "conditional", - "connect", "constraint", "containing", "continue", "count", "create", "cstring", - "current", "current_connection", "current_date", "current_role", "current_time", - "current_timestamp", "current_transaction", "current_user", "cursor", "database", - "date", "day", "db_key", "debug", "dec", "decimal", "declare", "default", "delete", - "desc", "descending", "describe", "descriptor", "disconnect", "display", "distinct", - "do", "domain", "double", "drop", "echo", "edit", "else", "end", "entry_point", - "escape", "event", "exception", "execute", "exists", "exit", "extern", "external", - "extract", "fetch", "file", "filter", "float", "for", "foreign", "found", "free_it", - "from", "full", "function", "gdscode", "generator", "gen_id", "global", "goto", - "grant", "group", "group_commit_", "group_commit_wait", "having", "help", "hour", - "if", "immediate", "in", "inactive", "index", "indicator", "init", "inner", "input", - "input_type", "insert", "int", "integer", "into", "is", "isolation", "isql", "join", - "key", "lc_messages", "lc_type", "left", "length", "lev", "level", "like", "logfile", - "log_buffer_size", "log_buf_size", "long", "manual", "max", "maximum", "maximum_segment", - "max_segment", "merge", "message", "min", "minimum", "minute", "module_name", "month", - "names", "national", "natural", "nchar", "no", "noauto", "not", "null", "numeric", - "num_log_buffers", "num_log_bufs", "octet_length", "of", "on", "only", "open", "option", - "or", "order", "outer", "output", "output_type", "overflow", "page", "pagelength", - "pages", "page_size", "parameter", "password", "plan", "position", "post_event", - "precision", "prepare", "primary", "privileges", "procedure", "protected", "public", - "quit", "raw_partitions", "rdb$db_key", "read", "real", "record_version", "recreate", - "references", "release", "release", "reserv", "reserving", "restrict", "retain", - "return", "returning_values", "returns", "revoke", "right", "role", "rollback", - "row_count", "runtime", "savepoint", "schema", "second", "segment", "select", - "set", "shadow", "shared", "shell", "show", "singular", "size", "smallint", - "snapshot", "some", "sort", "sqlcode", "sqlerror", "sqlwarning", "stability", - "starting", "starts", "statement", "static", "statistics", "sub_type", "sum", - "suspend", "table", "terminator", "then", "time", "timestamp", "to", "transaction", - "translate", "translation", "trigger", "trim", "type", "uncommitted", "union", - "unique", "update", "upper", "user", "using", "value", "values", "varchar", - "variable", "varying", "version", "view", "wait", "wait_time", "weekday", "when", - "whenever", "where", "while", "with", "work", "write", "year", "yearday" ]) - - -class FBIdentifierPreparer(sql.compiler.IdentifierPreparer): - """Install Firebird specific reserved words.""" - - reserved_words = RESERVED_WORDS - - def __init__(self, dialect): - super(FBIdentifierPreparer, self).__init__(dialect, omit_schema=True) - - -dialect = FBDialect -dialect.statement_compiler = FBCompiler -dialect.schemagenerator = FBSchemaGenerator -dialect.schemadropper = FBSchemaDropper -dialect.defaultrunner = FBDefaultRunner -dialect.preparer = FBIdentifierPreparer -dialect.execution_ctx_cls = FBExecutionContext diff --git a/lib/sqlalchemy/databases/information_schema.py b/lib/sqlalchemy/databases/information_schema.py deleted file mode 100644 index a7d4101cdb..0000000000 --- a/lib/sqlalchemy/databases/information_schema.py +++ /dev/null @@ -1,193 +0,0 @@ -""" -information schema implementation. - -This module is deprecated and will not be present in this form in SQLAlchemy 0.6. - -""" -from sqlalchemy import util - -util.warn_deprecated("the information_schema module is deprecated.") - -import sqlalchemy.sql as sql -import sqlalchemy.exc as exc -from sqlalchemy import select, MetaData, Table, Column, String, Integer -from sqlalchemy.schema import DefaultClause, ForeignKeyConstraint - -ischema = MetaData() - -schemata = Table("schemata", ischema, - Column("catalog_name", String), - Column("schema_name", String), - Column("schema_owner", String), - schema="information_schema") - -tables = Table("tables", ischema, - Column("table_catalog", String), - Column("table_schema", String), - Column("table_name", String), - Column("table_type", String), - schema="information_schema") - -columns = Table("columns", ischema, - Column("table_schema", String), - Column("table_name", String), - Column("column_name", String), - Column("is_nullable", Integer), - Column("data_type", String), - Column("ordinal_position", Integer), - Column("character_maximum_length", Integer), - Column("numeric_precision", Integer), - Column("numeric_scale", Integer), - Column("column_default", Integer), - Column("collation_name", String), - schema="information_schema") - -constraints = Table("table_constraints", ischema, - Column("table_schema", String), - Column("table_name", String), - Column("constraint_name", String), - Column("constraint_type", String), - schema="information_schema") - -column_constraints = Table("constraint_column_usage", ischema, - Column("table_schema", String), - Column("table_name", String), - Column("column_name", String), - Column("constraint_name", String), - schema="information_schema") - -pg_key_constraints = Table("key_column_usage", ischema, - Column("table_schema", String), - Column("table_name", String), - Column("column_name", String), - Column("constraint_name", String), - Column("ordinal_position", Integer), - schema="information_schema") - -#mysql_key_constraints = Table("key_column_usage", ischema, -# Column("table_schema", String), -# Column("table_name", String), -# Column("column_name", String), -# Column("constraint_name", String), -# Column("referenced_table_schema", String), -# Column("referenced_table_name", String), -# Column("referenced_column_name", String), -# schema="information_schema") - -key_constraints = pg_key_constraints - -ref_constraints = Table("referential_constraints", ischema, - Column("constraint_catalog", String), - Column("constraint_schema", String), - Column("constraint_name", String), - Column("unique_constraint_catlog", String), - Column("unique_constraint_schema", String), - Column("unique_constraint_name", String), - Column("match_option", String), - Column("update_rule", String), - Column("delete_rule", String), - schema="information_schema") - - -def table_names(connection, schema): - s = select([tables.c.table_name], tables.c.table_schema==schema) - return [row[0] for row in connection.execute(s)] - - -def reflecttable(connection, table, include_columns, ischema_names): - key_constraints = pg_key_constraints - - if table.schema is not None: - current_schema = table.schema - else: - current_schema = connection.default_schema_name() - - s = select([columns], - sql.and_(columns.c.table_name==table.name, - columns.c.table_schema==current_schema), - order_by=[columns.c.ordinal_position]) - - c = connection.execute(s) - found_table = False - while True: - row = c.fetchone() - if row is None: - break - #print "row! " + repr(row) - # continue - found_table = True - (name, type, nullable, charlen, numericprec, numericscale, default) = ( - row[columns.c.column_name], - row[columns.c.data_type], - row[columns.c.is_nullable] == 'YES', - row[columns.c.character_maximum_length], - row[columns.c.numeric_precision], - row[columns.c.numeric_scale], - row[columns.c.column_default] - ) - if include_columns and name not in include_columns: - continue - - args = [] - for a in (charlen, numericprec, numericscale): - if a is not None: - args.append(a) - coltype = ischema_names[type] - #print "coltype " + repr(coltype) + " args " + repr(args) - coltype = coltype(*args) - colargs = [] - if default is not None: - colargs.append(DefaultClause(sql.text(default))) - table.append_column(Column(name, coltype, nullable=nullable, *colargs)) - - if not found_table: - raise exc.NoSuchTableError(table.name) - - # we are relying on the natural ordering of the constraint_column_usage table to return the referenced columns - # in an order that corresponds to the ordinal_position in the key_constraints table, otherwise composite foreign keys - # wont reflect properly. dont see a way around this based on whats available from information_schema - s = select([constraints.c.constraint_name, constraints.c.constraint_type, constraints.c.table_name, key_constraints], use_labels=True, from_obj=[constraints.join(column_constraints, column_constraints.c.constraint_name==constraints.c.constraint_name).join(key_constraints, key_constraints.c.constraint_name==column_constraints.c.constraint_name)], order_by=[key_constraints.c.ordinal_position]) - s.append_column(column_constraints) - s.append_whereclause(constraints.c.table_name==table.name) - s.append_whereclause(constraints.c.table_schema==current_schema) - colmap = [constraints.c.constraint_type, key_constraints.c.column_name, column_constraints.c.table_schema, column_constraints.c.table_name, column_constraints.c.column_name, constraints.c.constraint_name, key_constraints.c.ordinal_position] - c = connection.execute(s) - - fks = {} - while True: - row = c.fetchone() - if row is None: - break - (type, constrained_column, referred_schema, referred_table, referred_column, constraint_name, ordinal_position) = ( - row[colmap[0]], - row[colmap[1]], - row[colmap[2]], - row[colmap[3]], - row[colmap[4]], - row[colmap[5]], - row[colmap[6]] - ) - #print "type %s on column %s to remote %s.%s.%s" % (type, constrained_column, referred_schema, referred_table, referred_column) - if type == 'PRIMARY KEY': - table.primary_key.add(table.c[constrained_column]) - elif type == 'FOREIGN KEY': - try: - fk = fks[constraint_name] - except KeyError: - fk = ([], []) - fks[constraint_name] = fk - if current_schema == referred_schema: - referred_schema = table.schema - if referred_schema is not None: - Table(referred_table, table.metadata, autoload=True, schema=referred_schema, autoload_with=connection) - refspec = ".".join([referred_schema, referred_table, referred_column]) - else: - Table(referred_table, table.metadata, autoload=True, autoload_with=connection) - refspec = ".".join([referred_table, referred_column]) - if constrained_column not in fk[0]: - fk[0].append(constrained_column) - if refspec not in fk[1]: - fk[1].append(refspec) - - for name, value in fks.iteritems(): - table.append_constraint(ForeignKeyConstraint(value[0], value[1], name=name)) diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py deleted file mode 100644 index d963b74770..0000000000 --- a/lib/sqlalchemy/databases/mssql.py +++ /dev/null @@ -1,1771 +0,0 @@ -# mssql.py - -"""Support for the Microsoft SQL Server database. - -Driver ------- - -The MSSQL dialect will work with three different available drivers: - -* *pyodbc* - http://pyodbc.sourceforge.net/. This is the recommeded - driver. - -* *pymssql* - http://pymssql.sourceforge.net/ - -* *adodbapi* - http://adodbapi.sourceforge.net/ - -Drivers are loaded in the order listed above based on availability. - -If you need to load a specific driver pass ``module_name`` when -creating the engine:: - - engine = create_engine('mssql://dsn', module_name='pymssql') - -``module_name`` currently accepts: ``pyodbc``, ``pymssql``, and -``adodbapi``. - -Currently the pyodbc driver offers the greatest level of -compatibility. - -Connecting ----------- - -Connecting with create_engine() uses the standard URL approach of -``mssql://user:pass@host/dbname[?key=value&key=value...]``. - -If the database name is present, the tokens are converted to a -connection string with the specified values. If the database is not -present, then the host token is taken directly as the DSN name. - -Examples of pyodbc connection string URLs: - -* *mssql://mydsn* - connects using the specified DSN named ``mydsn``. - The connection string that is created will appear like:: - - dsn=mydsn;TrustedConnection=Yes - -* *mssql://user:pass@mydsn* - connects using the DSN named - ``mydsn`` passing in the ``UID`` and ``PWD`` information. The - connection string that is created will appear like:: - - dsn=mydsn;UID=user;PWD=pass - -* *mssql://user:pass@mydsn/?LANGUAGE=us_english* - connects - using the DSN named ``mydsn`` passing in the ``UID`` and ``PWD`` - information, plus the additional connection configuration option - ``LANGUAGE``. The connection string that is created will appear - like:: - - dsn=mydsn;UID=user;PWD=pass;LANGUAGE=us_english - -* *mssql://user:pass@host/db* - connects using a connection string - dynamically created that would appear like:: - - DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass - -* *mssql://user:pass@host:123/db* - connects using a connection - string that is dynamically created, which also includes the port - information using the comma syntax. If your connection string - requires the port information to be passed as a ``port`` keyword - see the next example. This will create the following connection - string:: - - DRIVER={SQL Server};Server=host,123;Database=db;UID=user;PWD=pass - -* *mssql://user:pass@host/db?port=123* - connects using a connection - string that is dynamically created that includes the port - information as a separate ``port`` keyword. This will create the - following connection string:: - - DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass;port=123 - -If you require a connection string that is outside the options -presented above, use the ``odbc_connect`` keyword to pass in a -urlencoded connection string. What gets passed in will be urldecoded -and passed directly. - -For example:: - - mssql:///?odbc_connect=dsn%3Dmydsn%3BDatabase%3Ddb - -would create the following connection string:: - - dsn=mydsn;Database=db - -Encoding your connection string can be easily accomplished through -the python shell. For example:: - - >>> import urllib - >>> urllib.quote_plus('dsn=mydsn;Database=db') - 'dsn%3Dmydsn%3BDatabase%3Ddb' - -Additional arguments which may be specified either as query string -arguments on the URL, or as keyword argument to -:func:`~sqlalchemy.create_engine()` are: - -* *auto_identity_insert* - enables support for IDENTITY inserts by - automatically turning IDENTITY INSERT ON and OFF as required. - Defaults to ``True``. - -* *query_timeout* - allows you to override the default query timeout. - Defaults to ``None``. This is only supported on pymssql. - -* *text_as_varchar* - if enabled this will treat all TEXT column - types as their equivalent VARCHAR(max) type. This is often used if - you need to compare a VARCHAR to a TEXT field, which is not - supported directly on MSSQL. Defaults to ``False``. - -* *use_scope_identity* - allows you to specify that SCOPE_IDENTITY - should be used in place of the non-scoped version @@IDENTITY. - Defaults to ``False``. On pymssql this defaults to ``True``, and on - pyodbc this defaults to ``True`` if the version of pyodbc being - used supports it. - -* *has_window_funcs* - indicates whether or not window functions - (LIMIT and OFFSET) are supported on the version of MSSQL being - used. If you're running MSSQL 2005 or later turn this on to get - OFFSET support. Defaults to ``False``. - -* *max_identifier_length* - allows you to se the maximum length of - identfiers supported by the database. Defaults to 128. For pymssql - the default is 30. - -* *schema_name* - use to set the schema name. Defaults to ``dbo``. - -Auto Increment Behavior ------------------------ - -``IDENTITY`` columns are supported by using SQLAlchemy -``schema.Sequence()`` objects. In other words:: - - Table('test', mss_engine, - Column('id', Integer, - Sequence('blah',100,10), primary_key=True), - Column('name', String(20)) - ).create() - -would yield:: - - CREATE TABLE test ( - id INTEGER NOT NULL IDENTITY(100,10) PRIMARY KEY, - name VARCHAR(20) NULL, - ) - -Note that the ``start`` and ``increment`` values for sequences are -optional and will default to 1,1. - -* Support for ``SET IDENTITY_INSERT ON`` mode (automagic on / off for - ``INSERT`` s) - -* Support for auto-fetching of ``@@IDENTITY/@@SCOPE_IDENTITY()`` on - ``INSERT`` - -Collation Support ------------------ - -MSSQL specific string types support a collation parameter that -creates a column-level specific collation for the column. The -collation parameter accepts a Windows Collation Name or a SQL -Collation Name. Supported types are MSChar, MSNChar, MSString, -MSNVarchar, MSText, and MSNText. For example:: - - Column('login', String(32, collation='Latin1_General_CI_AS')) - -will yield:: - - login VARCHAR(32) COLLATE Latin1_General_CI_AS NULL - -LIMIT/OFFSET Support --------------------- - -MSSQL has no support for the LIMIT or OFFSET keysowrds. LIMIT is -supported directly through the ``TOP`` Transact SQL keyword:: - - select.limit - -will yield:: - - SELECT TOP n - -If the ``has_window_funcs`` flag is set then LIMIT with OFFSET -support is available through the ``ROW_NUMBER OVER`` construct. This -construct requires an ``ORDER BY`` to be specified as well and is -only available on MSSQL 2005 and later. - -Nullability ------------ -MSSQL has support for three levels of column nullability. The default -nullability allows nulls and is explicit in the CREATE TABLE -construct:: - - name VARCHAR(20) NULL - -If ``nullable=None`` is specified then no specification is made. In -other words the database's configured default is used. This will -render:: - - name VARCHAR(20) - -If ``nullable`` is ``True`` or ``False`` then the column will be -``NULL` or ``NOT NULL`` respectively. - -Date / Time Handling --------------------- -For MSSQL versions that support the ``DATE`` and ``TIME`` types -(MSSQL 2008+) the data type is used. For versions that do not -support the ``DATE`` and ``TIME`` types a ``DATETIME`` type is used -instead and the MSSQL dialect handles converting the results -properly. This means ``Date()`` and ``Time()`` are fully supported -on all versions of MSSQL. If you do not desire this behavior then -do not use the ``Date()`` or ``Time()`` types. - -Compatibility Levels --------------------- -MSSQL supports the notion of setting compatibility levels at the -database level. This allows, for instance, to run a database that -is compatibile with SQL2000 while running on a SQL2005 database -server. ``server_version_info`` will always retrun the database -server version information (in this case SQL2005) and not the -compatibiility level information. Because of this, if running under -a backwards compatibility mode SQAlchemy may attempt to use T-SQL -statements that are unable to be parsed by the database server. - -Known Issues ------------- - -* No support for more than one ``IDENTITY`` column per table - -* pymssql has problems with binary and unicode data that this module - does **not** work around - -""" -import datetime, decimal, inspect, operator, re, sys, urllib - -from sqlalchemy import sql, schema, exc, util -from sqlalchemy import Table, MetaData, Column, ForeignKey, String, Integer -from sqlalchemy.sql import select, compiler, expression, operators as sql_operators, functions as sql_functions -from sqlalchemy.engine import default, base -from sqlalchemy import types as sqltypes -from decimal import Decimal as _python_Decimal - - -RESERVED_WORDS = set( - ['add', 'all', 'alter', 'and', 'any', 'as', 'asc', 'authorization', - 'backup', 'begin', 'between', 'break', 'browse', 'bulk', 'by', 'cascade', - 'case', 'check', 'checkpoint', 'close', 'clustered', 'coalesce', - 'collate', 'column', 'commit', 'compute', 'constraint', 'contains', - 'containstable', 'continue', 'convert', 'create', 'cross', 'current', - 'current_date', 'current_time', 'current_timestamp', 'current_user', - 'cursor', 'database', 'dbcc', 'deallocate', 'declare', 'default', - 'delete', 'deny', 'desc', 'disk', 'distinct', 'distributed', 'double', - 'drop', 'dump', 'else', 'end', 'errlvl', 'escape', 'except', 'exec', - 'execute', 'exists', 'exit', 'external', 'fetch', 'file', 'fillfactor', - 'for', 'foreign', 'freetext', 'freetexttable', 'from', 'full', - 'function', 'goto', 'grant', 'group', 'having', 'holdlock', 'identity', - 'identity_insert', 'identitycol', 'if', 'in', 'index', 'inner', 'insert', - 'intersect', 'into', 'is', 'join', 'key', 'kill', 'left', 'like', - 'lineno', 'load', 'merge', 'national', 'nocheck', 'nonclustered', 'not', - 'null', 'nullif', 'of', 'off', 'offsets', 'on', 'open', 'opendatasource', - 'openquery', 'openrowset', 'openxml', 'option', 'or', 'order', 'outer', - 'over', 'percent', 'pivot', 'plan', 'precision', 'primary', 'print', - 'proc', 'procedure', 'public', 'raiserror', 'read', 'readtext', - 'reconfigure', 'references', 'replication', 'restore', 'restrict', - 'return', 'revert', 'revoke', 'right', 'rollback', 'rowcount', - 'rowguidcol', 'rule', 'save', 'schema', 'securityaudit', 'select', - 'session_user', 'set', 'setuser', 'shutdown', 'some', 'statistics', - 'system_user', 'table', 'tablesample', 'textsize', 'then', 'to', 'top', - 'tran', 'transaction', 'trigger', 'truncate', 'tsequal', 'union', - 'unique', 'unpivot', 'update', 'updatetext', 'use', 'user', 'values', - 'varying', 'view', 'waitfor', 'when', 'where', 'while', 'with', - 'writetext', - ]) - - -class _StringType(object): - """Base for MSSQL string types.""" - - def __init__(self, collation=None, **kwargs): - self.collation = kwargs.get('collate', collation) - - def _extend(self, spec): - """Extend a string-type declaration with standard SQL - COLLATE annotations. - """ - - if self.collation: - collation = 'COLLATE %s' % self.collation - else: - collation = None - - return ' '.join([c for c in (spec, collation) - if c is not None]) - - def __repr__(self): - attributes = inspect.getargspec(self.__init__)[0][1:] - attributes.extend(inspect.getargspec(_StringType.__init__)[0][1:]) - - params = {} - for attr in attributes: - val = getattr(self, attr) - if val is not None and val is not False: - params[attr] = val - - return "%s(%s)" % (self.__class__.__name__, - ', '.join(['%s=%r' % (k, params[k]) for k in params])) - - def bind_processor(self, dialect): - if self.convert_unicode or dialect.convert_unicode: - if self.assert_unicode is None: - assert_unicode = dialect.assert_unicode - else: - assert_unicode = self.assert_unicode - - if not assert_unicode: - return None - - def process(value): - if not isinstance(value, (unicode, sqltypes.NoneType)): - if assert_unicode == 'warn': - util.warn("Unicode type received non-unicode bind " - "param value %r" % value) - return value - else: - raise exc.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value) - else: - return value - return process - else: - return None - - -class MSNumeric(sqltypes.Numeric): - def result_processor(self, dialect): - if self.asdecimal: - def process(value): - if value is not None: - return _python_Decimal(str(value)) - else: - return value - return process - else: - def process(value): - return float(value) - return process - - def bind_processor(self, dialect): - def process(value): - if value is None: - # Not sure that this exception is needed - return value - - elif isinstance(value, decimal.Decimal): - if value.adjusted() < 0: - result = "%s0.%s%s" % ( - (value < 0 and '-' or ''), - '0' * (abs(value.adjusted()) - 1), - "".join([str(nint) for nint in value._int])) - - else: - if 'E' in str(value): - result = "%s%s%s" % ( - (value < 0 and '-' or ''), - "".join([str(s) for s in value._int]), - "0" * (value.adjusted() - (len(value._int)-1))) - else: - if (len(value._int) - 1) > value.adjusted(): - result = "%s%s.%s" % ( - (value < 0 and '-' or ''), - "".join([str(s) for s in value._int][0:value.adjusted() + 1]), - "".join([str(s) for s in value._int][value.adjusted() + 1:])) - else: - result = "%s%s" % ( - (value < 0 and '-' or ''), - "".join([str(s) for s in value._int][0:value.adjusted() + 1])) - - return result - - else: - return value - - return process - - def get_col_spec(self): - if self.precision is None: - return "NUMERIC" - else: - return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale} - - -class MSFloat(sqltypes.Float): - def get_col_spec(self): - if self.precision is None: - return "FLOAT" - else: - return "FLOAT(%(precision)s)" % {'precision': self.precision} - - -class MSReal(MSFloat): - """A type for ``real`` numbers.""" - - def __init__(self): - """ - Construct a Real. - - """ - super(MSReal, self).__init__(precision=24) - - def adapt(self, impltype): - return impltype() - - def get_col_spec(self): - return "REAL" - - -class MSInteger(sqltypes.Integer): - def get_col_spec(self): - return "INTEGER" - - -class MSBigInteger(MSInteger): - def get_col_spec(self): - return "BIGINT" - - -class MSTinyInteger(MSInteger): - def get_col_spec(self): - return "TINYINT" - - -class MSSmallInteger(MSInteger): - def get_col_spec(self): - return "SMALLINT" - - -class _DateTimeType(object): - """Base for MSSQL datetime types.""" - - def bind_processor(self, dialect): - # if we receive just a date we can manipulate it - # into a datetime since the db-api may not do this. - def process(value): - if type(value) is datetime.date: - return datetime.datetime(value.year, value.month, value.day) - return value - return process - - -class MSDateTime(_DateTimeType, sqltypes.DateTime): - def get_col_spec(self): - return "DATETIME" - - -class MSDate(sqltypes.Date): - def get_col_spec(self): - return "DATE" - - -class MSTime(sqltypes.Time): - def __init__(self, precision=None, **kwargs): - self.precision = precision - super(MSTime, self).__init__() - - def get_col_spec(self): - if self.precision: - return "TIME(%s)" % self.precision - else: - return "TIME" - - -class MSSmallDateTime(_DateTimeType, sqltypes.TypeEngine): - def get_col_spec(self): - return "SMALLDATETIME" - - -class MSDateTime2(_DateTimeType, sqltypes.TypeEngine): - def __init__(self, precision=None, **kwargs): - self.precision = precision - - def get_col_spec(self): - if self.precision: - return "DATETIME2(%s)" % self.precision - else: - return "DATETIME2" - - -class MSDateTimeOffset(_DateTimeType, sqltypes.TypeEngine): - def __init__(self, precision=None, **kwargs): - self.precision = precision - - def get_col_spec(self): - if self.precision: - return "DATETIMEOFFSET(%s)" % self.precision - else: - return "DATETIMEOFFSET" - - -class MSDateTimeAsDate(_DateTimeType, MSDate): - """ This is an implementation of the Date type for versions of MSSQL that - do not support that specific type. In order to make it work a ``DATETIME`` - column specification is used and the results get converted back to just - the date portion. - - """ - - def get_col_spec(self): - return "DATETIME" - - def result_processor(self, dialect): - def process(value): - # If the DBAPI returns the value as datetime.datetime(), truncate - # it back to datetime.date() - if type(value) is datetime.datetime: - return value.date() - return value - return process - - -class MSDateTimeAsTime(MSTime): - """ This is an implementation of the Time type for versions of MSSQL that - do not support that specific type. In order to make it work a ``DATETIME`` - column specification is used and the results get converted back to just - the time portion. - - """ - - __zero_date = datetime.date(1900, 1, 1) - - def get_col_spec(self): - return "DATETIME" - - def bind_processor(self, dialect): - def process(value): - if type(value) is datetime.datetime: - value = datetime.datetime.combine(self.__zero_date, value.time()) - elif type(value) is datetime.time: - value = datetime.datetime.combine(self.__zero_date, value) - return value - return process - - def result_processor(self, dialect): - def process(value): - if type(value) is datetime.datetime: - return value.time() - elif type(value) is datetime.date: - return datetime.time(0, 0, 0) - return value - return process - - -class MSDateTime_adodbapi(MSDateTime): - def result_processor(self, dialect): - def process(value): - # adodbapi will return datetimes with empty time values as datetime.date() objects. - # Promote them back to full datetime.datetime() - if type(value) is datetime.date: - return datetime.datetime(value.year, value.month, value.day) - return value - return process - - -class MSText(_StringType, sqltypes.Text): - """MSSQL TEXT type, for variable-length text up to 2^31 characters.""" - - def __init__(self, *args, **kwargs): - """Construct a TEXT. - - :param collation: Optional, a column-level collation for this string - value. Accepts a Windows Collation Name or a SQL Collation Name. - - """ - _StringType.__init__(self, **kwargs) - sqltypes.Text.__init__(self, None, - convert_unicode=kwargs.get('convert_unicode', False), - assert_unicode=kwargs.get('assert_unicode', None)) - - def get_col_spec(self): - if self.dialect.text_as_varchar: - return self._extend("VARCHAR(max)") - else: - return self._extend("TEXT") - - -class MSNText(_StringType, sqltypes.UnicodeText): - """MSSQL NTEXT type, for variable-length unicode text up to 2^30 - characters.""" - - def __init__(self, *args, **kwargs): - """Construct a NTEXT. - - :param collation: Optional, a column-level collation for this string - value. Accepts a Windows Collation Name or a SQL Collation Name. - - """ - _StringType.__init__(self, **kwargs) - sqltypes.UnicodeText.__init__(self, None, - convert_unicode=kwargs.get('convert_unicode', True), - assert_unicode=kwargs.get('assert_unicode', 'warn')) - - def get_col_spec(self): - if self.dialect.text_as_varchar: - return self._extend("NVARCHAR(max)") - else: - return self._extend("NTEXT") - - -class MSString(_StringType, sqltypes.String): - """MSSQL VARCHAR type, for variable-length non-Unicode data with a maximum - of 8,000 characters.""" - - def __init__(self, length=None, convert_unicode=False, assert_unicode=None, **kwargs): - """Construct a VARCHAR. - - :param length: Optinal, maximum data length, in characters. - - :param convert_unicode: defaults to False. If True, convert - ``unicode`` data sent to the database to a ``str`` - bytestring, and convert bytestrings coming back from the - database into ``unicode``. - - Bytestrings are encoded using the dialect's - :attr:`~sqlalchemy.engine.base.Dialect.encoding`, which - defaults to `utf-8`. - - If False, may be overridden by - :attr:`sqlalchemy.engine.base.Dialect.convert_unicode`. - - :param assert_unicode: - - If None (the default), no assertion will take place unless - overridden by :attr:`sqlalchemy.engine.base.Dialect.assert_unicode`. - - If 'warn', will issue a runtime warning if a ``str`` - instance is used as a bind value. - - If true, will raise an :exc:`sqlalchemy.exc.InvalidRequestError`. - - :param collation: Optional, a column-level collation for this string - value. Accepts a Windows Collation Name or a SQL Collation Name. - - """ - _StringType.__init__(self, **kwargs) - sqltypes.String.__init__(self, length=length, - convert_unicode=convert_unicode, - assert_unicode=assert_unicode) - - def get_col_spec(self): - if self.length: - return self._extend("VARCHAR(%s)" % self.length) - else: - return self._extend("VARCHAR") - - -class MSNVarchar(_StringType, sqltypes.Unicode): - """MSSQL NVARCHAR type. - - For variable-length unicode character data up to 4,000 characters.""" - - def __init__(self, length=None, **kwargs): - """Construct a NVARCHAR. - - :param length: Optional, Maximum data length, in characters. - - :param collation: Optional, a column-level collation for this string - value. Accepts a Windows Collation Name or a SQL Collation Name. - - """ - _StringType.__init__(self, **kwargs) - sqltypes.Unicode.__init__(self, length=length, - convert_unicode=kwargs.get('convert_unicode', True), - assert_unicode=kwargs.get('assert_unicode', 'warn')) - - def adapt(self, impltype): - return impltype(length=self.length, - convert_unicode=self.convert_unicode, - assert_unicode=self.assert_unicode, - collation=self.collation) - - def get_col_spec(self): - if self.length: - return self._extend("NVARCHAR(%(length)s)" % {'length' : self.length}) - else: - return self._extend("NVARCHAR") - - -class MSChar(_StringType, sqltypes.CHAR): - """MSSQL CHAR type, for fixed-length non-Unicode data with a maximum - of 8,000 characters.""" - - def __init__(self, length=None, convert_unicode=False, assert_unicode=None, **kwargs): - """Construct a CHAR. - - :param length: Optinal, maximum data length, in characters. - - :param convert_unicode: defaults to False. If True, convert - ``unicode`` data sent to the database to a ``str`` - bytestring, and convert bytestrings coming back from the - database into ``unicode``. - - Bytestrings are encoded using the dialect's - :attr:`~sqlalchemy.engine.base.Dialect.encoding`, which - defaults to `utf-8`. - - If False, may be overridden by - :attr:`sqlalchemy.engine.base.Dialect.convert_unicode`. - - :param assert_unicode: - - If None (the default), no assertion will take place unless - overridden by :attr:`sqlalchemy.engine.base.Dialect.assert_unicode`. - - If 'warn', will issue a runtime warning if a ``str`` - instance is used as a bind value. - - If true, will raise an :exc:`sqlalchemy.exc.InvalidRequestError`. - - :param collation: Optional, a column-level collation for this string - value. Accepts a Windows Collation Name or a SQL Collation Name. - - """ - _StringType.__init__(self, **kwargs) - sqltypes.CHAR.__init__(self, length=length, - convert_unicode=convert_unicode, - assert_unicode=assert_unicode) - - def get_col_spec(self): - if self.length: - return self._extend("CHAR(%s)" % self.length) - else: - return self._extend("CHAR") - - -class MSNChar(_StringType, sqltypes.NCHAR): - """MSSQL NCHAR type. - - For fixed-length unicode character data up to 4,000 characters.""" - - def __init__(self, length=None, **kwargs): - """Construct an NCHAR. - - :param length: Optional, Maximum data length, in characters. - - :param collation: Optional, a column-level collation for this string - value. Accepts a Windows Collation Name or a SQL Collation Name. - - """ - _StringType.__init__(self, **kwargs) - sqltypes.NCHAR.__init__(self, length=length, - convert_unicode=kwargs.get('convert_unicode', True), - assert_unicode=kwargs.get('assert_unicode', 'warn')) - - def get_col_spec(self): - if self.length: - return self._extend("NCHAR(%(length)s)" % {'length' : self.length}) - else: - return self._extend("NCHAR") - - -class MSGenericBinary(sqltypes.Binary): - """The Binary type assumes that a Binary specification without a length - is an unbound Binary type whereas one with a length specification results - in a fixed length Binary type. - - If you want standard MSSQL ``BINARY`` behavior use the ``MSBinary`` type. - - """ - - def get_col_spec(self): - if self.length: - return "BINARY(%s)" % self.length - else: - return "IMAGE" - - -class MSBinary(MSGenericBinary): - def get_col_spec(self): - if self.length: - return "BINARY(%s)" % self.length - else: - return "BINARY" - - -class MSVarBinary(MSGenericBinary): - def get_col_spec(self): - if self.length: - return "VARBINARY(%s)" % self.length - else: - return "VARBINARY" - - -class MSImage(MSGenericBinary): - def get_col_spec(self): - return "IMAGE" - - -class MSBoolean(sqltypes.Boolean): - def get_col_spec(self): - return "BIT" - - def result_processor(self, dialect): - def process(value): - if value is None: - return None - return value and True or False - return process - - def bind_processor(self, dialect): - def process(value): - if value is True: - return 1 - elif value is False: - return 0 - elif value is None: - return None - else: - return value and True or False - return process - - -class MSTimeStamp(sqltypes.TIMESTAMP): - def get_col_spec(self): - return "TIMESTAMP" - - -class MSMoney(sqltypes.TypeEngine): - def get_col_spec(self): - return "MONEY" - - -class MSSmallMoney(MSMoney): - def get_col_spec(self): - return "SMALLMONEY" - - -class MSUniqueIdentifier(sqltypes.TypeEngine): - def get_col_spec(self): - return "UNIQUEIDENTIFIER" - - -class MSVariant(sqltypes.TypeEngine): - def get_col_spec(self): - return "SQL_VARIANT" - -ischema = MetaData() - -schemata = Table("SCHEMATA", ischema, - Column("CATALOG_NAME", String, key="catalog_name"), - Column("SCHEMA_NAME", String, key="schema_name"), - Column("SCHEMA_OWNER", String, key="schema_owner"), - schema="INFORMATION_SCHEMA") - -tables = Table("TABLES", ischema, - Column("TABLE_CATALOG", String, key="table_catalog"), - Column("TABLE_SCHEMA", String, key="table_schema"), - Column("TABLE_NAME", String, key="table_name"), - Column("TABLE_TYPE", String, key="table_type"), - schema="INFORMATION_SCHEMA") - -columns = Table("COLUMNS", ischema, - Column("TABLE_SCHEMA", String, key="table_schema"), - Column("TABLE_NAME", String, key="table_name"), - Column("COLUMN_NAME", String, key="column_name"), - Column("IS_NULLABLE", Integer, key="is_nullable"), - Column("DATA_TYPE", String, key="data_type"), - Column("ORDINAL_POSITION", Integer, key="ordinal_position"), - Column("CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length"), - Column("NUMERIC_PRECISION", Integer, key="numeric_precision"), - Column("NUMERIC_SCALE", Integer, key="numeric_scale"), - Column("COLUMN_DEFAULT", Integer, key="column_default"), - Column("COLLATION_NAME", String, key="collation_name"), - schema="INFORMATION_SCHEMA") - -constraints = Table("TABLE_CONSTRAINTS", ischema, - Column("TABLE_SCHEMA", String, key="table_schema"), - Column("TABLE_NAME", String, key="table_name"), - Column("CONSTRAINT_NAME", String, key="constraint_name"), - Column("CONSTRAINT_TYPE", String, key="constraint_type"), - schema="INFORMATION_SCHEMA") - -column_constraints = Table("CONSTRAINT_COLUMN_USAGE", ischema, - Column("TABLE_SCHEMA", String, key="table_schema"), - Column("TABLE_NAME", String, key="table_name"), - Column("COLUMN_NAME", String, key="column_name"), - Column("CONSTRAINT_NAME", String, key="constraint_name"), - schema="INFORMATION_SCHEMA") - -key_constraints = Table("KEY_COLUMN_USAGE", ischema, - Column("TABLE_SCHEMA", String, key="table_schema"), - Column("TABLE_NAME", String, key="table_name"), - Column("COLUMN_NAME", String, key="column_name"), - Column("CONSTRAINT_NAME", String, key="constraint_name"), - Column("ORDINAL_POSITION", Integer, key="ordinal_position"), - schema="INFORMATION_SCHEMA") - -ref_constraints = Table("REFERENTIAL_CONSTRAINTS", ischema, - Column("CONSTRAINT_CATALOG", String, key="constraint_catalog"), - Column("CONSTRAINT_SCHEMA", String, key="constraint_schema"), - Column("CONSTRAINT_NAME", String, key="constraint_name"), - Column("UNIQUE_CONSTRAINT_CATLOG", String, key="unique_constraint_catalog"), - Column("UNIQUE_CONSTRAINT_SCHEMA", String, key="unique_constraint_schema"), - Column("UNIQUE_CONSTRAINT_NAME", String, key="unique_constraint_name"), - Column("MATCH_OPTION", String, key="match_option"), - Column("UPDATE_RULE", String, key="update_rule"), - Column("DELETE_RULE", String, key="delete_rule"), - schema="INFORMATION_SCHEMA") - -def _has_implicit_sequence(column): - return column.primary_key and \ - column.autoincrement and \ - isinstance(column.type, sqltypes.Integer) and \ - not column.foreign_keys and \ - ( - column.default is None or - ( - isinstance(column.default, schema.Sequence) and - column.default.optional) - ) - -def _table_sequence_column(tbl): - if not hasattr(tbl, '_ms_has_sequence'): - tbl._ms_has_sequence = None - for column in tbl.c: - if getattr(column, 'sequence', False) or _has_implicit_sequence(column): - tbl._ms_has_sequence = column - break - return tbl._ms_has_sequence - -class MSSQLExecutionContext(default.DefaultExecutionContext): - IINSERT = False - HASIDENT = False - - def pre_exec(self): - """Activate IDENTITY_INSERT if needed.""" - - if self.compiled.isinsert: - tbl = self.compiled.statement.table - seq_column = _table_sequence_column(tbl) - self.HASIDENT = bool(seq_column) - if self.dialect.auto_identity_insert and self.HASIDENT: - self.IINSERT = tbl._ms_has_sequence.key in self.compiled_parameters[0] - else: - self.IINSERT = False - - if self.IINSERT: - self.cursor.execute("SET IDENTITY_INSERT %s ON" % - self.dialect.identifier_preparer.format_table(self.compiled.statement.table)) - - def handle_dbapi_exception(self, e): - if self.IINSERT: - try: - self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.dialect.identifier_preparer.format_table(self.compiled.statement.table)) - except: - pass - - def post_exec(self): - """Disable IDENTITY_INSERT if enabled.""" - - if self.compiled.isinsert and not self.executemany and self.HASIDENT and not self.IINSERT: - if not self._last_inserted_ids or self._last_inserted_ids[0] is None: - if self.dialect.use_scope_identity: - self.cursor.execute("SELECT scope_identity() AS lastrowid") - else: - self.cursor.execute("SELECT @@identity AS lastrowid") - row = self.cursor.fetchone() - self._last_inserted_ids = [int(row[0])] + self._last_inserted_ids[1:] - - if self.IINSERT: - self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.dialect.identifier_preparer.format_table(self.compiled.statement.table)) - - -class MSSQLExecutionContext_pyodbc (MSSQLExecutionContext): - def pre_exec(self): - """where appropriate, issue "select scope_identity()" in the same statement""" - super(MSSQLExecutionContext_pyodbc, self).pre_exec() - if self.compiled.isinsert and self.HASIDENT and not self.IINSERT \ - and len(self.parameters) == 1 and self.dialect.use_scope_identity: - self.statement += "; select scope_identity()" - - def post_exec(self): - if self.HASIDENT and not self.IINSERT and self.dialect.use_scope_identity and not self.executemany: - import pyodbc - # Fetch the last inserted id from the manipulated statement - # We may have to skip over a number of result sets with no data (due to triggers, etc.) - while True: - try: - row = self.cursor.fetchone() - break - except pyodbc.Error, e: - self.cursor.nextset() - self._last_inserted_ids = [int(row[0])] - else: - super(MSSQLExecutionContext_pyodbc, self).post_exec() - -class MSSQLDialect(default.DefaultDialect): - name = 'mssql' - supports_default_values = True - supports_empty_insert = False - auto_identity_insert = True - execution_ctx_cls = MSSQLExecutionContext - text_as_varchar = False - use_scope_identity = False - has_window_funcs = False - max_identifier_length = 128 - schema_name = "dbo" - - colspecs = { - sqltypes.Unicode : MSNVarchar, - sqltypes.Integer : MSInteger, - sqltypes.Smallinteger: MSSmallInteger, - sqltypes.Numeric : MSNumeric, - sqltypes.Float : MSFloat, - sqltypes.DateTime : MSDateTime, - sqltypes.Date : MSDate, - sqltypes.Time : MSTime, - sqltypes.String : MSString, - sqltypes.Binary : MSGenericBinary, - sqltypes.Boolean : MSBoolean, - sqltypes.Text : MSText, - sqltypes.UnicodeText : MSNText, - sqltypes.CHAR: MSChar, - sqltypes.NCHAR: MSNChar, - sqltypes.TIMESTAMP: MSTimeStamp, - } - - ischema_names = { - 'int' : MSInteger, - 'bigint': MSBigInteger, - 'smallint' : MSSmallInteger, - 'tinyint' : MSTinyInteger, - 'varchar' : MSString, - 'nvarchar' : MSNVarchar, - 'char' : MSChar, - 'nchar' : MSNChar, - 'text' : MSText, - 'ntext' : MSNText, - 'decimal' : MSNumeric, - 'numeric' : MSNumeric, - 'float' : MSFloat, - 'datetime' : MSDateTime, - 'datetime2' : MSDateTime2, - 'datetimeoffset' : MSDateTimeOffset, - 'date': MSDate, - 'time': MSTime, - 'smalldatetime' : MSSmallDateTime, - 'binary' : MSBinary, - 'varbinary' : MSVarBinary, - 'bit': MSBoolean, - 'real' : MSFloat, - 'image' : MSImage, - 'timestamp': MSTimeStamp, - 'money': MSMoney, - 'smallmoney': MSSmallMoney, - 'uniqueidentifier': MSUniqueIdentifier, - 'sql_variant': MSVariant, - } - - def __new__(cls, *args, **kwargs): - if cls is not MSSQLDialect: - # this gets called with the dialect specific class - return super(MSSQLDialect, cls).__new__(cls) - dbapi = kwargs.get('dbapi', None) - if dbapi: - dialect = dialect_mapping.get(dbapi.__name__) - return dialect(**kwargs) - else: - return object.__new__(cls) - - def __init__(self, - auto_identity_insert=True, query_timeout=None, - text_as_varchar=False, use_scope_identity=False, - has_window_funcs=False, max_identifier_length=None, - schema_name="dbo", **opts): - self.auto_identity_insert = bool(auto_identity_insert) - self.query_timeout = int(query_timeout or 0) - self.schema_name = schema_name - - # to-do: the options below should use server version introspection to set themselves on connection - self.text_as_varchar = bool(text_as_varchar) - self.use_scope_identity = bool(use_scope_identity) - self.has_window_funcs = bool(has_window_funcs) - self.max_identifier_length = int(max_identifier_length or 0) or \ - self.max_identifier_length - super(MSSQLDialect, self).__init__(**opts) - - @classmethod - def dbapi(cls, module_name=None): - if module_name: - try: - dialect_cls = dialect_mapping[module_name] - return dialect_cls.import_dbapi() - except KeyError: - raise exc.InvalidRequestError("Unsupported MSSQL module '%s' requested (must be adodbpi, pymssql or pyodbc)" % module_name) - else: - for dialect_cls in [MSSQLDialect_pyodbc, MSSQLDialect_pymssql, MSSQLDialect_adodbapi]: - try: - return dialect_cls.import_dbapi() - except ImportError, e: - pass - else: - raise ImportError('No DBAPI module detected for MSSQL - please install pyodbc, pymssql, or adodbapi') - - @base.connection_memoize(('mssql', 'server_version_info')) - def server_version_info(self, connection): - """A tuple of the database server version. - - Formats the remote server version as a tuple of version values, - e.g. ``(9, 0, 1399)``. If there are strings in the version number - they will be in the tuple too, so don't count on these all being - ``int`` values. - - This is a fast check that does not require a round trip. It is also - cached per-Connection. - """ - return connection.dialect._server_version_info(connection.connection) - - def _server_version_info(self, dbapi_con): - """Return a tuple of the database's version number.""" - raise NotImplementedError() - - def create_connect_args(self, url): - opts = url.translate_connect_args(username='user') - opts.update(url.query) - if 'auto_identity_insert' in opts: - self.auto_identity_insert = bool(int(opts.pop('auto_identity_insert'))) - if 'query_timeout' in opts: - self.query_timeout = int(opts.pop('query_timeout')) - if 'text_as_varchar' in opts: - self.text_as_varchar = bool(int(opts.pop('text_as_varchar'))) - if 'use_scope_identity' in opts: - self.use_scope_identity = bool(int(opts.pop('use_scope_identity'))) - if 'has_window_funcs' in opts: - self.has_window_funcs = bool(int(opts.pop('has_window_funcs'))) - return self.make_connect_string(opts, url.query) - - def type_descriptor(self, typeobj): - newobj = sqltypes.adapt_type(typeobj, self.colspecs) - # Some types need to know about the dialect - if isinstance(newobj, (MSText, MSNText)): - newobj.dialect = self - return newobj - - def do_savepoint(self, connection, name): - util.warn("Savepoint support in mssql is experimental and may lead to data loss.") - connection.execute("IF @@TRANCOUNT = 0 BEGIN TRANSACTION") - connection.execute("SAVE TRANSACTION %s" % name) - - def do_release_savepoint(self, connection, name): - pass - - @base.connection_memoize(('dialect', 'default_schema_name')) - def get_default_schema_name(self, connection): - query = "SELECT user_name() as user_name;" - user_name = connection.scalar(sql.text(query)) - if user_name is not None: - # now, get the default schema - query = """ - SELECT default_schema_name FROM - sys.database_principals - WHERE name = :user_name - AND type = 'S' - """ - try: - default_schema_name = connection.scalar(sql.text(query), - user_name=user_name) - if default_schema_name is not None: - return default_schema_name - except: - pass - return self.schema_name - - def table_names(self, connection, schema): - s = select([tables.c.table_name], tables.c.table_schema==schema) - return [row[0] for row in connection.execute(s)] - - - def has_table(self, connection, tablename, schema=None): - - current_schema = schema or self.get_default_schema_name(connection) - s = sql.select([columns], - current_schema - and sql.and_(columns.c.table_name==tablename, columns.c.table_schema==current_schema) - or columns.c.table_name==tablename, - ) - - c = connection.execute(s) - row = c.fetchone() - return row is not None - - def reflecttable(self, connection, table, include_columns): - # Get base columns - if table.schema is not None: - current_schema = table.schema - else: - current_schema = self.get_default_schema_name(connection) - - s = sql.select([columns], - current_schema - and sql.and_(columns.c.table_name==table.name, columns.c.table_schema==current_schema) - or columns.c.table_name==table.name, - order_by=[columns.c.ordinal_position]) - - c = connection.execute(s) - found_table = False - while True: - row = c.fetchone() - if row is None: - break - found_table = True - (name, type, nullable, charlen, numericprec, numericscale, default, collation) = ( - row[columns.c.column_name], - row[columns.c.data_type], - row[columns.c.is_nullable] == 'YES', - row[columns.c.character_maximum_length], - row[columns.c.numeric_precision], - row[columns.c.numeric_scale], - row[columns.c.column_default], - row[columns.c.collation_name] - ) - if include_columns and name not in include_columns: - continue - - coltype = self.ischema_names.get(type, None) - - kwargs = {} - if coltype in (MSString, MSChar, MSNVarchar, MSNChar, MSText, MSNText, MSBinary, MSVarBinary, sqltypes.Binary): - kwargs['length'] = charlen - if collation: - kwargs['collation'] = collation - if coltype == MSText or (coltype in (MSString, MSNVarchar) and charlen == -1): - kwargs.pop('length') - - if issubclass(coltype, sqltypes.Numeric): - kwargs['scale'] = numericscale - kwargs['precision'] = numericprec - - if coltype is None: - util.warn("Did not recognize type '%s' of column '%s'" % (type, name)) - coltype = sqltypes.NULLTYPE - - coltype = coltype(**kwargs) - colargs = [] - if default is not None: - colargs.append(schema.DefaultClause(sql.text(default))) - table.append_column(schema.Column(name, coltype, nullable=nullable, autoincrement=False, *colargs)) - - if not found_table: - raise exc.NoSuchTableError(table.name) - - # We also run an sp_columns to check for identity columns: - cursor = connection.execute("sp_columns @table_name = '%s', @table_owner = '%s'" % (table.name, current_schema)) - ic = None - while True: - row = cursor.fetchone() - if row is None: - break - col_name, type_name = row[3], row[5] - if type_name.endswith("identity") and col_name in table.c: - ic = table.c[col_name] - ic.autoincrement = True - # setup a psuedo-sequence to represent the identity attribute - we interpret this at table.create() time as the identity attribute - ic.sequence = schema.Sequence(ic.name + '_identity', 1, 1) - # MSSQL: only one identity per table allowed - cursor.close() - break - if not ic is None: - try: - cursor = connection.execute("select ident_seed(?), ident_incr(?)", table.fullname, table.fullname) - row = cursor.fetchone() - cursor.close() - if not row is None: - ic.sequence.start = int(row[0]) - ic.sequence.increment = int(row[1]) - except: - # ignoring it, works just like before - pass - - # Add constraints - RR = ref_constraints - TC = constraints - C = key_constraints.alias('C') #information_schema.constraint_column_usage: the constrained column - R = key_constraints.alias('R') #information_schema.constraint_column_usage: the referenced column - - # Primary key constraints - s = sql.select([C.c.column_name, TC.c.constraint_type], sql.and_(TC.c.constraint_name == C.c.constraint_name, - C.c.table_name == table.name, - C.c.table_schema == (table.schema or current_schema))) - c = connection.execute(s) - for row in c: - if 'PRIMARY' in row[TC.c.constraint_type.name] and row[0] in table.c: - table.primary_key.add(table.c[row[0]]) - - # Foreign key constraints - s = sql.select([C.c.column_name, - R.c.table_schema, R.c.table_name, R.c.column_name, - RR.c.constraint_name, RR.c.match_option, RR.c.update_rule, RR.c.delete_rule], - sql.and_(C.c.table_name == table.name, - C.c.table_schema == (table.schema or current_schema), - C.c.constraint_name == RR.c.constraint_name, - R.c.constraint_name == RR.c.unique_constraint_name, - C.c.ordinal_position == R.c.ordinal_position - ), - order_by = [RR.c.constraint_name, R.c.ordinal_position]) - rows = connection.execute(s).fetchall() - - def _gen_fkref(table, rschema, rtbl, rcol): - if rschema == current_schema and not table.schema: - return '.'.join([rtbl, rcol]) - else: - return '.'.join([rschema, rtbl, rcol]) - - # group rows by constraint ID, to handle multi-column FKs - fknm, scols, rcols = (None, [], []) - for r in rows: - scol, rschema, rtbl, rcol, rfknm, fkmatch, fkuprule, fkdelrule = r - # if the reflected schema is the default schema then don't set it because this will - # play into the metadata key causing duplicates. - if rschema == current_schema and not table.schema: - schema.Table(rtbl, table.metadata, autoload=True, autoload_with=connection) - else: - schema.Table(rtbl, table.metadata, schema=rschema, autoload=True, autoload_with=connection) - if rfknm != fknm: - if fknm: - table.append_constraint(schema.ForeignKeyConstraint(scols, [_gen_fkref(table, s, t, c) for s, t, c in rcols], fknm, link_to_name=True)) - fknm, scols, rcols = (rfknm, [], []) - if not scol in scols: - scols.append(scol) - if not (rschema, rtbl, rcol) in rcols: - rcols.append((rschema, rtbl, rcol)) - - if fknm and scols: - table.append_constraint(schema.ForeignKeyConstraint(scols, [_gen_fkref(table, s, t, c) for s, t, c in rcols], fknm, link_to_name=True)) - - -class MSSQLDialect_pymssql(MSSQLDialect): - supports_sane_rowcount = False - max_identifier_length = 30 - - @classmethod - def import_dbapi(cls): - import pymssql as module - # pymmsql doesn't have a Binary method. we use string - # TODO: monkeypatching here is less than ideal - module.Binary = lambda st: str(st) - try: - module.version_info = tuple(map(int, module.__version__.split('.'))) - except: - module.version_info = (0, 0, 0) - return module - - def __init__(self, **params): - super(MSSQLDialect_pymssql, self).__init__(**params) - self.use_scope_identity = True - - # pymssql understands only ascii - if self.convert_unicode: - util.warn("pymssql does not support unicode") - self.encoding = params.get('encoding', 'ascii') - - self.colspecs = MSSQLDialect.colspecs.copy() - self.ischema_names = MSSQLDialect.ischema_names.copy() - self.ischema_names['date'] = MSDateTimeAsDate - self.colspecs[sqltypes.Date] = MSDateTimeAsDate - self.ischema_names['time'] = MSDateTimeAsTime - self.colspecs[sqltypes.Time] = MSDateTimeAsTime - - def create_connect_args(self, url): - r = super(MSSQLDialect_pymssql, self).create_connect_args(url) - if hasattr(self, 'query_timeout'): - if self.dbapi.version_info > (0, 8, 0): - r[1]['timeout'] = self.query_timeout - else: - self.dbapi._mssql.set_query_timeout(self.query_timeout) - return r - - def make_connect_string(self, keys, query): - if keys.get('port'): - # pymssql expects port as host:port, not a separate arg - keys['host'] = ''.join([keys.get('host', ''), ':', str(keys['port'])]) - del keys['port'] - return [[], keys] - - def is_disconnect(self, e): - return isinstance(e, self.dbapi.DatabaseError) and "Error 10054" in str(e) - - def do_begin(self, connection): - pass - - -class MSSQLDialect_pyodbc(MSSQLDialect): - supports_sane_rowcount = False - supports_sane_multi_rowcount = False - # PyODBC unicode is broken on UCS-4 builds - supports_unicode = sys.maxunicode == 65535 - supports_unicode_statements = supports_unicode - execution_ctx_cls = MSSQLExecutionContext_pyodbc - - def __init__(self, description_encoding='latin-1', **params): - super(MSSQLDialect_pyodbc, self).__init__(**params) - self.description_encoding = description_encoding - - if self.server_version_info < (10,): - self.colspecs = MSSQLDialect.colspecs.copy() - self.ischema_names = MSSQLDialect.ischema_names.copy() - self.ischema_names['date'] = MSDateTimeAsDate - self.colspecs[sqltypes.Date] = MSDateTimeAsDate - self.ischema_names['time'] = MSDateTimeAsTime - self.colspecs[sqltypes.Time] = MSDateTimeAsTime - - # FIXME: scope_identity sniff should look at server version, not the ODBC driver - # whether use_scope_identity will work depends on the version of pyodbc - try: - import pyodbc - self.use_scope_identity = hasattr(pyodbc.Cursor, 'nextset') - except: - pass - - @classmethod - def import_dbapi(cls): - import pyodbc as module - return module - - def make_connect_string(self, keys, query): - if 'max_identifier_length' in keys: - self.max_identifier_length = int(keys.pop('max_identifier_length')) - - if 'odbc_connect' in keys: - connectors = [urllib.unquote_plus(keys.pop('odbc_connect'))] - else: - dsn_connection = 'dsn' in keys or ('host' in keys and 'database' not in keys) - if dsn_connection: - connectors= ['dsn=%s' % (keys.pop('host', '') or keys.pop('dsn', ''))] - else: - port = '' - if 'port' in keys and not 'port' in query: - port = ',%d' % int(keys.pop('port')) - - connectors = ["DRIVER={%s}" % keys.pop('driver', 'SQL Server'), - 'Server=%s%s' % (keys.pop('host', ''), port), - 'Database=%s' % keys.pop('database', '') ] - - user = keys.pop("user", None) - if user: - connectors.append("UID=%s" % user) - connectors.append("PWD=%s" % keys.pop('password', '')) - else: - connectors.append("TrustedConnection=Yes") - - # if set to 'Yes', the ODBC layer will try to automagically convert - # textual data from your database encoding to your client encoding - # This should obviously be set to 'No' if you query a cp1253 encoded - # database from a latin1 client... - if 'odbc_autotranslate' in keys: - connectors.append("AutoTranslate=%s" % keys.pop("odbc_autotranslate")) - - connectors.extend(['%s=%s' % (k,v) for k,v in keys.iteritems()]) - - return [[";".join (connectors)], {}] - - def is_disconnect(self, e): - if isinstance(e, self.dbapi.ProgrammingError): - return "The cursor's connection has been closed." in str(e) or 'Attempt to use a closed connection.' in str(e) - elif isinstance(e, self.dbapi.Error): - return '[08S01]' in str(e) - else: - return False - - - def _server_version_info(self, dbapi_con): - """Convert a pyodbc SQL_DBMS_VER string into a tuple.""" - version = [] - r = re.compile('[.\-]') - for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)): - try: - version.append(int(n)) - except ValueError: - version.append(n) - return tuple(version) - -class MSSQLDialect_adodbapi(MSSQLDialect): - supports_sane_rowcount = True - supports_sane_multi_rowcount = True - supports_unicode = sys.maxunicode == 65535 - supports_unicode_statements = True - - @classmethod - def import_dbapi(cls): - import adodbapi as module - return module - - colspecs = MSSQLDialect.colspecs.copy() - colspecs[sqltypes.DateTime] = MSDateTime_adodbapi - - ischema_names = MSSQLDialect.ischema_names.copy() - ischema_names['datetime'] = MSDateTime_adodbapi - - def make_connect_string(self, keys, query): - connectors = ["Provider=SQLOLEDB"] - if 'port' in keys: - connectors.append ("Data Source=%s, %s" % (keys.get("host"), keys.get("port"))) - else: - connectors.append ("Data Source=%s" % keys.get("host")) - connectors.append ("Initial Catalog=%s" % keys.get("database")) - user = keys.get("user") - if user: - connectors.append("User Id=%s" % user) - connectors.append("Password=%s" % keys.get("password", "")) - else: - connectors.append("Integrated Security=SSPI") - return [[";".join (connectors)], {}] - - def is_disconnect(self, e): - return isinstance(e, self.dbapi.adodbapi.DatabaseError) and "'connection failure'" in str(e) - - -dialect_mapping = { - 'pymssql': MSSQLDialect_pymssql, - 'pyodbc': MSSQLDialect_pyodbc, - 'adodbapi': MSSQLDialect_adodbapi - } - - -class MSSQLCompiler(compiler.DefaultCompiler): - operators = compiler.OPERATORS.copy() - operators.update({ - sql_operators.concat_op: '+', - sql_operators.match_op: lambda x, y: "CONTAINS (%s, %s)" % (x, y) - }) - - functions = compiler.DefaultCompiler.functions.copy() - functions.update ( - { - sql_functions.now: 'CURRENT_TIMESTAMP', - sql_functions.current_date: 'GETDATE()', - 'length': lambda x: "LEN(%s)" % x, - sql_functions.char_length: lambda x: "LEN(%s)" % x - } - ) - - extract_map = compiler.DefaultCompiler.extract_map.copy() - extract_map.update ({ - 'doy': 'dayofyear', - 'dow': 'weekday', - 'milliseconds': 'millisecond', - 'microseconds': 'microsecond' - }) - - def __init__(self, *args, **kwargs): - super(MSSQLCompiler, self).__init__(*args, **kwargs) - self.tablealiases = {} - - def get_select_precolumns(self, select): - """ MS-SQL puts TOP, it's version of LIMIT here """ - if select._distinct or select._limit: - s = select._distinct and "DISTINCT " or "" - - if select._limit: - if not select._offset: - s += "TOP %s " % (select._limit,) - else: - if not self.dialect.has_window_funcs: - raise exc.InvalidRequestError('MSSQL does not support LIMIT with an offset') - return s - return compiler.DefaultCompiler.get_select_precolumns(self, select) - - def limit_clause(self, select): - # Limit in mssql is after the select keyword - return "" - - def visit_select(self, select, **kwargs): - """Look for ``LIMIT`` and OFFSET in a select statement, and if - so tries to wrap it in a subquery with ``row_number()`` criterion. - - """ - if self.dialect.has_window_funcs and not getattr(select, '_mssql_visit', None) and select._offset: - # to use ROW_NUMBER(), an ORDER BY is required. - orderby = self.process(select._order_by_clause) - if not orderby: - raise exc.InvalidRequestError('MSSQL requires an order_by when using an offset.') - - _offset = select._offset - _limit = select._limit - select._mssql_visit = True - select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("mssql_rn")).order_by(None).alias() - - limitselect = sql.select([c for c in select.c if c.key!='mssql_rn']) - limitselect.append_whereclause("mssql_rn>%d" % _offset) - if _limit is not None: - limitselect.append_whereclause("mssql_rn<=%d" % (_limit + _offset)) - return self.process(limitselect, iswrapper=True, **kwargs) - else: - return compiler.DefaultCompiler.visit_select(self, select, **kwargs) - - def _schema_aliased_table(self, table): - if getattr(table, 'schema', None) is not None: - if table not in self.tablealiases: - self.tablealiases[table] = table.alias() - return self.tablealiases[table] - else: - return None - - def visit_table(self, table, mssql_aliased=False, **kwargs): - if mssql_aliased: - return super(MSSQLCompiler, self).visit_table(table, **kwargs) - - # alias schema-qualified tables - alias = self._schema_aliased_table(table) - if alias is not None: - return self.process(alias, mssql_aliased=True, **kwargs) - else: - return super(MSSQLCompiler, self).visit_table(table, **kwargs) - - def visit_alias(self, alias, **kwargs): - # translate for schema-qualified table aliases - self.tablealiases[alias.original] = alias - kwargs['mssql_aliased'] = True - return super(MSSQLCompiler, self).visit_alias(alias, **kwargs) - - def visit_extract(self, extract): - field = self.extract_map.get(extract.field, extract.field) - return 'DATEPART("%s", %s)' % (field, self.process(extract.expr)) - - def visit_rollback_to_savepoint(self, savepoint_stmt): - return "ROLLBACK TRANSACTION %s" % self.preparer.format_savepoint(savepoint_stmt) - - def visit_column(self, column, result_map=None, **kwargs): - if column.table is not None and \ - (not self.isupdate and not self.isdelete) or self.is_subquery(): - # translate for schema-qualified table aliases - t = self._schema_aliased_table(column.table) - if t is not None: - converted = expression._corresponding_column_or_error(t, column) - - if result_map is not None: - result_map[column.name.lower()] = (column.name, (column, ), column.type) - - return super(MSSQLCompiler, self).visit_column(converted, result_map=None, **kwargs) - - return super(MSSQLCompiler, self).visit_column(column, result_map=result_map, **kwargs) - - def visit_binary(self, binary, **kwargs): - """Move bind parameters to the right-hand side of an operator, where - possible. - - """ - if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq \ - and not isinstance(binary.right, expression._BindParamClause): - return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator), **kwargs) - else: - if (binary.operator is operator.eq or binary.operator is operator.ne) and ( - (isinstance(binary.left, expression._FromGrouping) and isinstance(binary.left.element, expression._ScalarSelect)) or \ - (isinstance(binary.right, expression._FromGrouping) and isinstance(binary.right.element, expression._ScalarSelect)) or \ - isinstance(binary.left, expression._ScalarSelect) or isinstance(binary.right, expression._ScalarSelect)): - op = binary.operator == operator.eq and "IN" or "NOT IN" - return self.process(expression._BinaryExpression(binary.left, binary.right, op), **kwargs) - return super(MSSQLCompiler, self).visit_binary(binary, **kwargs) - - def visit_insert(self, insert_stmt): - insert_select = False - if insert_stmt.parameters: - insert_select = [p for p in insert_stmt.parameters.values() if isinstance(p, sql.Select)] - if insert_select: - self.isinsert = True - colparams = self._get_colparams(insert_stmt) - preparer = self.preparer - - insert = ' '.join(["INSERT"] + - [self.process(x) for x in insert_stmt._prefixes]) - - if not colparams and not self.dialect.supports_default_values and not self.dialect.supports_empty_insert: - raise exc.CompileError( - "The version of %s you are using does not support empty inserts." % self.dialect.name) - elif not colparams and self.dialect.supports_default_values: - return (insert + " INTO %s DEFAULT VALUES" % ( - (preparer.format_table(insert_stmt.table),))) - else: - return (insert + " INTO %s (%s) SELECT %s" % - (preparer.format_table(insert_stmt.table), - ', '.join([preparer.format_column(c[0]) - for c in colparams]), - ', '.join([c[1] for c in colparams]))) - else: - return super(MSSQLCompiler, self).visit_insert(insert_stmt) - - def label_select_column(self, select, column, asfrom): - if isinstance(column, expression.Function): - return column.label(None) - else: - return super(MSSQLCompiler, self).label_select_column(select, column, asfrom) - - def for_update_clause(self, select): - # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use - return '' - - def order_by_clause(self, select): - order_by = self.process(select._order_by_clause) - - # MSSQL only allows ORDER BY in subqueries if there is a LIMIT - if order_by and (not self.is_subquery() or select._limit): - return " ORDER BY " + order_by - else: - return "" - - -class MSSQLSchemaGenerator(compiler.SchemaGenerator): - def get_column_specification(self, column, **kwargs): - colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec() - - if column.nullable is not None: - if not column.nullable or column.primary_key: - colspec += " NOT NULL" - else: - colspec += " NULL" - - if not column.table: - raise exc.InvalidRequestError("mssql requires Table-bound columns in order to generate DDL") - - seq_col = _table_sequence_column(column.table) - - # install a IDENTITY Sequence if we have an implicit IDENTITY column - if seq_col is column: - sequence = getattr(column, 'sequence', None) - if sequence: - start, increment = sequence.start or 1, sequence.increment or 1 - else: - start, increment = 1, 1 - colspec += " IDENTITY(%s,%s)" % (start, increment) - else: - default = self.get_column_default_string(column) - if default is not None: - colspec += " DEFAULT " + default - - return colspec - -class MSSQLSchemaDropper(compiler.SchemaDropper): - def visit_index(self, index): - self.append("\nDROP INDEX %s.%s" % ( - self.preparer.quote_identifier(index.table.name), - self.preparer.quote(self._validate_identifier(index.name, False), index.quote) - )) - self.execute() - - -class MSSQLIdentifierPreparer(compiler.IdentifierPreparer): - reserved_words = RESERVED_WORDS - - def __init__(self, dialect): - super(MSSQLIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']') - - def _escape_identifier(self, value): - #TODO: determine MSSQL's escaping rules - return value - - def quote_schema(self, schema, force=True): - """Prepare a quoted table and schema name.""" - result = '.'.join([self.quote(x, force) for x in schema.split('.')]) - return result - -dialect = MSSQLDialect -dialect.statement_compiler = MSSQLCompiler -dialect.schemagenerator = MSSQLSchemaGenerator -dialect.schemadropper = MSSQLSchemaDropper -dialect.preparer = MSSQLIdentifierPreparer - diff --git a/lib/sqlalchemy/databases/mxODBC.py b/lib/sqlalchemy/databases/mxODBC.py deleted file mode 100644 index 92f533633c..0000000000 --- a/lib/sqlalchemy/databases/mxODBC.py +++ /dev/null @@ -1,60 +0,0 @@ -# mxODBC.py -# Copyright (C) 2007 Fisch Asset Management AG http://www.fam.ch -# Coding: Alexander Houben alexander.houben@thor-solutions.ch -# -# This module is part of SQLAlchemy and is released under -# the MIT License: http://www.opensource.org/licenses/mit-license.php - -""" -A wrapper for a mx.ODBC.Windows DB-API connection. - -Makes sure the mx module is configured to return datetime objects instead -of mx.DateTime.DateTime objects. -""" - -from mx.ODBC.Windows import * - - -class Cursor: - def __init__(self, cursor): - self.cursor = cursor - - def __getattr__(self, attr): - res = getattr(self.cursor, attr) - return res - - def execute(self, *args, **kwargs): - res = self.cursor.execute(*args, **kwargs) - return res - - -class Connection: - def myErrorHandler(self, connection, cursor, errorclass, errorvalue): - err0, err1, err2, err3 = errorvalue - #print ", ".join(["Err%d: %s"%(x, errorvalue[x]) for x in range(4)]) - if int(err1) == 109: - # Ignore "Null value eliminated in aggregate function", this is not an error - return - raise errorclass, errorvalue - - def __init__(self, conn): - self.conn = conn - # install a mx ODBC error handler - self.conn.errorhandler = self.myErrorHandler - - def __getattr__(self, attr): - res = getattr(self.conn, attr) - return res - - def cursor(self, *args, **kwargs): - res = Cursor(self.conn.cursor(*args, **kwargs)) - return res - - -# override 'connect' call -def connect(*args, **kwargs): - import mx.ODBC.Windows - conn = mx.ODBC.Windows.Connect(*args, **kwargs) - conn.datetimeformat = mx.ODBC.Windows.PYDATETIME_DATETIMEFORMAT - return Connection(conn) -Connect = connect diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py deleted file mode 100644 index 852cab448e..0000000000 --- a/lib/sqlalchemy/databases/oracle.py +++ /dev/null @@ -1,904 +0,0 @@ -# oracle.py -# Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com -# -# This module is part of SQLAlchemy and is released under -# the MIT License: http://www.opensource.org/licenses/mit-license.php -"""Support for the Oracle database. - -Oracle version 8 through current (11g at the time of this writing) are supported. - -Driver ------- - -The Oracle dialect uses the cx_oracle driver, available at -http://cx-oracle.sourceforge.net/ . The dialect has several behaviors -which are specifically tailored towards compatibility with this module. - -Connecting ----------- - -Connecting with create_engine() uses the standard URL approach of -``oracle://user:pass@host:port/dbname[?key=value&key=value...]``. If dbname is present, the -host, port, and dbname tokens are converted to a TNS name using the cx_oracle -:func:`makedsn()` function. Otherwise, the host token is taken directly as a TNS name. - -Additional arguments which may be specified either as query string arguments on the -URL, or as keyword arguments to :func:`~sqlalchemy.create_engine()` are: - -* *allow_twophase* - enable two-phase transactions. Defaults to ``True``. - -* *auto_convert_lobs* - defaults to True, see the section on LOB objects. - -* *auto_setinputsizes* - the cx_oracle.setinputsizes() call is issued for all bind parameters. - This is required for LOB datatypes but can be disabled to reduce overhead. Defaults - to ``True``. - -* *mode* - This is given the string value of SYSDBA or SYSOPER, or alternatively an - integer value. This value is only available as a URL query string argument. - -* *threaded* - enable multithreaded access to cx_oracle connections. Defaults - to ``True``. Note that this is the opposite default of cx_oracle itself. - -* *use_ansi* - Use ANSI JOIN constructs (see the section on Oracle 8). Defaults - to ``True``. If ``False``, Oracle-8 compatible constructs are used for joins. - -* *optimize_limits* - defaults to ``False``. see the section on LIMIT/OFFSET. - -Auto Increment Behavior ------------------------ - -SQLAlchemy Table objects which include integer primary keys are usually assumed to have -"autoincrementing" behavior, meaning they can generate their own primary key values upon -INSERT. Since Oracle has no "autoincrement" feature, SQLAlchemy relies upon sequences -to produce these values. With the Oracle dialect, *a sequence must always be explicitly -specified to enable autoincrement*. This is divergent with the majority of documentation -examples which assume the usage of an autoincrement-capable database. To specify sequences, -use the sqlalchemy.schema.Sequence object which is passed to a Column construct:: - - t = Table('mytable', metadata, - Column('id', Integer, Sequence('id_seq'), primary_key=True), - Column(...), ... - ) - -This step is also required when using table reflection, i.e. autoload=True:: - - t = Table('mytable', metadata, - Column('id', Integer, Sequence('id_seq'), primary_key=True), - autoload=True - ) - -LOB Objects ------------ - -cx_oracle presents some challenges when fetching LOB objects. A LOB object in a result set -is presented by cx_oracle as a cx_oracle.LOB object which has a read() method. By default, -SQLAlchemy converts these LOB objects into Python strings. This is for two reasons. First, -the LOB object requires an active cursor association, meaning if you were to fetch many rows -at once such that cx_oracle had to go back to the database and fetch a new batch of rows, -the LOB objects in the already-fetched rows are now unreadable and will raise an error. -SQLA "pre-reads" all LOBs so that their data is fetched before further rows are read. -The size of a "batch of rows" is controlled by the cursor.arraysize value, which SQLAlchemy -defaults to 50 (cx_oracle normally defaults this to one). - -Secondly, the LOB object is not a standard DBAPI return value so SQLAlchemy seeks to -"normalize" the results to look more like other DBAPIs. - -The conversion of LOB objects by this dialect is unique in SQLAlchemy in that it takes place -for all statement executions, even plain string-based statements for which SQLA has no awareness -of result typing. This is so that calls like fetchmany() and fetchall() can work in all cases -without raising cursor errors. The conversion of LOB in all cases, as well as the "prefetch" -of LOB objects, can be disabled using auto_convert_lobs=False. - -LIMIT/OFFSET Support --------------------- - -Oracle has no support for the LIMIT or OFFSET keywords. Whereas previous versions of SQLAlchemy -used the "ROW NUMBER OVER..." construct to simulate LIMIT/OFFSET, SQLAlchemy 0.5 now uses -a wrapped subquery approach in conjunction with ROWNUM. The exact methodology is taken from -http://www.oracle.com/technology/oramag/oracle/06-sep/o56asktom.html . Note that the -"FIRST ROWS()" optimization keyword mentioned is not used by default, as the user community felt -this was stepping into the bounds of optimization that is better left on the DBA side, but this -prefix can be added by enabling the optimize_limits=True flag on create_engine(). - -Two Phase Transaction Support ------------------------------ - -Two Phase transactions are implemented using XA transactions. Success has been reported of them -working successfully but this should be regarded as an experimental feature. - -Oracle 8 Compatibility ----------------------- - -When using Oracle 8, a "use_ansi=False" flag is available which converts all -JOIN phrases into the WHERE clause, and in the case of LEFT OUTER JOIN -makes use of Oracle's (+) operator. - -Synonym/DBLINK Reflection -------------------------- - -When using reflection with Table objects, the dialect can optionally search for tables -indicated by synonyms that reference DBLINK-ed tables by passing the flag -oracle_resolve_synonyms=True as a keyword argument to the Table construct. If DBLINK -is not in use this flag should be left off. - -""" - -import datetime, random, re - -from sqlalchemy import util, sql, schema, log -from sqlalchemy.engine import default, base -from sqlalchemy.sql import compiler, visitors, expression -from sqlalchemy.sql import operators as sql_operators, functions as sql_functions -from sqlalchemy import types as sqltypes - - -class OracleNumeric(sqltypes.Numeric): - def get_col_spec(self): - if self.precision is None: - return "NUMERIC" - else: - return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale} - -class OracleInteger(sqltypes.Integer): - def get_col_spec(self): - return "INTEGER" - -class OracleSmallInteger(sqltypes.Smallinteger): - def get_col_spec(self): - return "SMALLINT" - -class OracleDate(sqltypes.Date): - def get_col_spec(self): - return "DATE" - def bind_processor(self, dialect): - return None - - def result_processor(self, dialect): - def process(value): - if not isinstance(value, datetime.datetime): - return value - else: - return value.date() - return process - -class OracleDateTime(sqltypes.DateTime): - def get_col_spec(self): - return "DATE" - - def result_processor(self, dialect): - def process(value): - if value is None or isinstance(value, datetime.datetime): - return value - else: - # convert cx_oracle datetime object returned pre-python 2.4 - return datetime.datetime(value.year, value.month, - value.day,value.hour, value.minute, value.second) - return process - -# Note: -# Oracle DATE == DATETIME -# Oracle does not allow milliseconds in DATE -# Oracle does not support TIME columns - -# only if cx_oracle contains TIMESTAMP -class OracleTimestamp(sqltypes.TIMESTAMP): - def get_col_spec(self): - return "TIMESTAMP" - - def get_dbapi_type(self, dialect): - return dialect.TIMESTAMP - - def result_processor(self, dialect): - def process(value): - if value is None or isinstance(value, datetime.datetime): - return value - else: - # convert cx_oracle datetime object returned pre-python 2.4 - return datetime.datetime(value.year, value.month, - value.day,value.hour, value.minute, value.second) - return process - -class OracleString(sqltypes.String): - def get_col_spec(self): - return "VARCHAR(%(length)s)" % {'length' : self.length} - -class OracleNVarchar(sqltypes.Unicode, OracleString): - def get_col_spec(self): - return "NVARCHAR2(%(length)s)" % {'length' : self.length} - -class OracleText(sqltypes.Text): - def get_dbapi_type(self, dbapi): - return dbapi.CLOB - - def get_col_spec(self): - return "CLOB" - - def result_processor(self, dialect): - super_process = super(OracleText, self).result_processor(dialect) - if not dialect.auto_convert_lobs: - return super_process - lob = dialect.dbapi.LOB - def process(value): - if isinstance(value, lob): - if super_process: - return super_process(value.read()) - else: - return value.read() - else: - if super_process: - return super_process(value) - else: - return value - return process - - -class OracleChar(sqltypes.CHAR): - def get_col_spec(self): - return "CHAR(%(length)s)" % {'length' : self.length} - -class OracleBinary(sqltypes.Binary): - def get_dbapi_type(self, dbapi): - return dbapi.BLOB - - def get_col_spec(self): - return "BLOB" - - def bind_processor(self, dialect): - return None - - def result_processor(self, dialect): - if not dialect.auto_convert_lobs: - return None - lob = dialect.dbapi.LOB - def process(value): - if isinstance(value, lob): - return value.read() - else: - return value - return process - -class OracleRaw(OracleBinary): - def get_col_spec(self): - return "RAW(%(length)s)" % {'length' : self.length} - -class OracleBoolean(sqltypes.Boolean): - def get_col_spec(self): - return "SMALLINT" - - def result_processor(self, dialect): - def process(value): - if value is None: - return None - return value and True or False - return process - - def bind_processor(self, dialect): - def process(value): - if value is True: - return 1 - elif value is False: - return 0 - elif value is None: - return None - else: - return value and True or False - return process - -colspecs = { - sqltypes.Integer : OracleInteger, - sqltypes.Smallinteger : OracleSmallInteger, - sqltypes.Numeric : OracleNumeric, - sqltypes.Float : OracleNumeric, - sqltypes.DateTime : OracleDateTime, - sqltypes.Date : OracleDate, - sqltypes.String : OracleString, - sqltypes.Binary : OracleBinary, - sqltypes.Boolean : OracleBoolean, - sqltypes.Text : OracleText, - sqltypes.TIMESTAMP : OracleTimestamp, - sqltypes.CHAR: OracleChar, -} - -ischema_names = { - 'VARCHAR2' : OracleString, - 'NVARCHAR2' : OracleNVarchar, - 'CHAR' : OracleString, - 'DATE' : OracleDateTime, - 'DATETIME' : OracleDateTime, - 'NUMBER' : OracleNumeric, - 'BLOB' : OracleBinary, - 'BFILE' : OracleBinary, - 'CLOB' : OracleText, - 'TIMESTAMP' : OracleTimestamp, - 'RAW' : OracleRaw, - 'FLOAT' : OracleNumeric, - 'DOUBLE PRECISION' : OracleNumeric, - 'LONG' : OracleText, -} - -class OracleExecutionContext(default.DefaultExecutionContext): - def pre_exec(self): - super(OracleExecutionContext, self).pre_exec() - if self.dialect.auto_setinputsizes: - self.set_input_sizes() - if self.compiled_parameters is not None and len(self.compiled_parameters) == 1: - for key in self.compiled.binds: - bindparam = self.compiled.binds[key] - name = self.compiled.bind_names[bindparam] - value = self.compiled_parameters[0][name] - if bindparam.isoutparam: - dbtype = bindparam.type.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) - if not hasattr(self, 'out_parameters'): - self.out_parameters = {} - self.out_parameters[name] = self.cursor.var(dbtype) - self.parameters[0][name] = self.out_parameters[name] - - def create_cursor(self): - c = self._connection.connection.cursor() - if self.dialect.arraysize: - c.cursor.arraysize = self.dialect.arraysize - return c - - def get_result_proxy(self): - if hasattr(self, 'out_parameters'): - if self.compiled_parameters is not None and len(self.compiled_parameters) == 1: - for bind, name in self.compiled.bind_names.iteritems(): - if name in self.out_parameters: - type = bind.type - result_processor = type.dialect_impl(self.dialect).result_processor(self.dialect) - if result_processor is not None: - self.out_parameters[name] = result_processor(self.out_parameters[name].getvalue()) - else: - self.out_parameters[name] = self.out_parameters[name].getvalue() - else: - for k in self.out_parameters: - self.out_parameters[k] = self.out_parameters[k].getvalue() - - if self.cursor.description is not None: - for column in self.cursor.description: - type_code = column[1] - if type_code in self.dialect.ORACLE_BINARY_TYPES: - return base.BufferedColumnResultProxy(self) - - return base.ResultProxy(self) - -class OracleDialect(default.DefaultDialect): - name = 'oracle' - supports_alter = True - supports_unicode_statements = False - max_identifier_length = 30 - supports_sane_rowcount = True - supports_sane_multi_rowcount = False - preexecute_pk_sequences = True - supports_pk_autoincrement = False - default_paramstyle = 'named' - - def __init__(self, use_ansi=True, auto_setinputsizes=True, auto_convert_lobs=True, threaded=True, allow_twophase=True, optimize_limits=False, arraysize=50, **kwargs): - default.DefaultDialect.__init__(self, **kwargs) - self.use_ansi = use_ansi - self.threaded = threaded - self.arraysize = arraysize - self.allow_twophase = allow_twophase - self.optimize_limits = optimize_limits - self.supports_timestamp = self.dbapi is None or hasattr(self.dbapi, 'TIMESTAMP' ) - self.auto_setinputsizes = auto_setinputsizes - self.auto_convert_lobs = auto_convert_lobs - if self.dbapi is None or not self.auto_convert_lobs or not 'CLOB' in self.dbapi.__dict__: - self.dbapi_type_map = {} - self.ORACLE_BINARY_TYPES = [] - else: - # only use this for LOB objects. using it for strings, dates - # etc. leads to a little too much magic, reflection doesn't know if it should - # expect encoded strings or unicodes, etc. - self.dbapi_type_map = { - self.dbapi.CLOB: OracleText(), - self.dbapi.BLOB: OracleBinary(), - self.dbapi.BINARY: OracleRaw(), - } - self.ORACLE_BINARY_TYPES = [getattr(self.dbapi, k) for k in ["BFILE", "CLOB", "NCLOB", "BLOB"] if hasattr(self.dbapi, k)] - - def dbapi(cls): - import cx_Oracle - return cx_Oracle - dbapi = classmethod(dbapi) - - def create_connect_args(self, url): - dialect_opts = dict(url.query) - for opt in ('use_ansi', 'auto_setinputsizes', 'auto_convert_lobs', - 'threaded', 'allow_twophase'): - if opt in dialect_opts: - util.coerce_kw_type(dialect_opts, opt, bool) - setattr(self, opt, dialect_opts[opt]) - - if url.database: - # if we have a database, then we have a remote host - port = url.port - if port: - port = int(port) - else: - port = 1521 - dsn = self.dbapi.makedsn(url.host, port, url.database) - else: - # we have a local tnsname - dsn = url.host - - opts = dict( - user=url.username, - password=url.password, - dsn=dsn, - threaded=self.threaded, - twophase=self.allow_twophase, - ) - if 'mode' in url.query: - opts['mode'] = url.query['mode'] - if isinstance(opts['mode'], basestring): - mode = opts['mode'].upper() - if mode == 'SYSDBA': - opts['mode'] = self.dbapi.SYSDBA - elif mode == 'SYSOPER': - opts['mode'] = self.dbapi.SYSOPER - else: - util.coerce_kw_type(opts, 'mode', int) - # Can't set 'handle' or 'pool' via URL query args, use connect_args - - return ([], opts) - - def is_disconnect(self, e): - if isinstance(e, self.dbapi.InterfaceError): - return "not connected" in str(e) - else: - return "ORA-03114" in str(e) or "ORA-03113" in str(e) - - def type_descriptor(self, typeobj): - return sqltypes.adapt_type(typeobj, colspecs) - - def create_xid(self): - """create a two-phase transaction ID. - - this id will be passed to do_begin_twophase(), do_rollback_twophase(), - do_commit_twophase(). its format is unspecified.""" - - id = random.randint(0, 2 ** 128) - return (0x1234, "%032x" % id, "%032x" % 9) - - def do_release_savepoint(self, connection, name): - # Oracle does not support RELEASE SAVEPOINT - pass - - def do_begin_twophase(self, connection, xid): - connection.connection.begin(*xid) - - def do_prepare_twophase(self, connection, xid): - connection.connection.prepare() - - def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False): - self.do_rollback(connection.connection) - - def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False): - self.do_commit(connection.connection) - - def do_recover_twophase(self, connection): - pass - - def has_table(self, connection, table_name, schema=None): - if not schema: - schema = self.get_default_schema_name(connection) - cursor = connection.execute("""select table_name from all_tables where table_name=:name and owner=:schema_name""", {'name':self._denormalize_name(table_name), 'schema_name':self._denormalize_name(schema)}) - return cursor.fetchone() is not None - - def has_sequence(self, connection, sequence_name, schema=None): - if not schema: - schema = self.get_default_schema_name(connection) - cursor = connection.execute("""select sequence_name from all_sequences where sequence_name=:name and sequence_owner=:schema_name""", {'name':self._denormalize_name(sequence_name), 'schema_name':self._denormalize_name(schema)}) - return cursor.fetchone() is not None - - def _normalize_name(self, name): - if name is None: - return None - elif name.upper() == name and not self.identifier_preparer._requires_quotes(name.lower().decode(self.encoding)): - return name.lower().decode(self.encoding) - else: - return name.decode(self.encoding) - - def _denormalize_name(self, name): - if name is None: - return None - elif name.lower() == name and not self.identifier_preparer._requires_quotes(name.lower()): - return name.upper().encode(self.encoding) - else: - return name.encode(self.encoding) - - def get_default_schema_name(self, connection): - return self._normalize_name(connection.execute('SELECT USER FROM DUAL').scalar()) - get_default_schema_name = base.connection_memoize( - ('dialect', 'default_schema_name'))(get_default_schema_name) - - def table_names(self, connection, schema): - # note that table_names() isnt loading DBLINKed or synonym'ed tables - if schema is None: - s = "select table_name from all_tables where nvl(tablespace_name, 'no tablespace') NOT IN ('SYSTEM', 'SYSAUX')" - cursor = connection.execute(s) - else: - s = "select table_name from all_tables where nvl(tablespace_name, 'no tablespace') NOT IN ('SYSTEM','SYSAUX') AND OWNER = :owner" - cursor = connection.execute(s, {'owner': self._denormalize_name(schema)}) - return [self._normalize_name(row[0]) for row in cursor] - - def _resolve_synonym(self, connection, desired_owner=None, desired_synonym=None, desired_table=None): - """search for a local synonym matching the given desired owner/name. - - if desired_owner is None, attempts to locate a distinct owner. - - returns the actual name, owner, dblink name, and synonym name if found. - """ - - sql = """select OWNER, TABLE_OWNER, TABLE_NAME, DB_LINK, SYNONYM_NAME - from ALL_SYNONYMS WHERE """ - - clauses = [] - params = {} - if desired_synonym: - clauses.append("SYNONYM_NAME=:synonym_name") - params['synonym_name'] = desired_synonym - if desired_owner: - clauses.append("TABLE_OWNER=:desired_owner") - params['desired_owner'] = desired_owner - if desired_table: - clauses.append("TABLE_NAME=:tname") - params['tname'] = desired_table - - sql += " AND ".join(clauses) - - result = connection.execute(sql, **params) - if desired_owner: - row = result.fetchone() - if row: - return row['TABLE_NAME'], row['TABLE_OWNER'], row['DB_LINK'], row['SYNONYM_NAME'] - else: - return None, None, None, None - else: - rows = result.fetchall() - if len(rows) > 1: - raise AssertionError("There are multiple tables visible to the schema, you must specify owner") - elif len(rows) == 1: - row = rows[0] - return row['TABLE_NAME'], row['TABLE_OWNER'], row['DB_LINK'], row['SYNONYM_NAME'] - else: - return None, None, None, None - - def reflecttable(self, connection, table, include_columns): - preparer = self.identifier_preparer - - resolve_synonyms = table.kwargs.get('oracle_resolve_synonyms', False) - - if resolve_synonyms: - actual_name, owner, dblink, synonym = self._resolve_synonym(connection, desired_owner=self._denormalize_name(table.schema), desired_synonym=self._denormalize_name(table.name)) - else: - actual_name, owner, dblink, synonym = None, None, None, None - - if not actual_name: - actual_name = self._denormalize_name(table.name) - if not dblink: - dblink = '' - if not owner: - owner = self._denormalize_name(table.schema or self.get_default_schema_name(connection)) - - c = connection.execute ("select COLUMN_NAME, DATA_TYPE, DATA_LENGTH, DATA_PRECISION, DATA_SCALE, NULLABLE, DATA_DEFAULT from ALL_TAB_COLUMNS%(dblink)s where TABLE_NAME = :table_name and OWNER = :owner" % {'dblink':dblink}, {'table_name':actual_name, 'owner':owner}) - - while True: - row = c.fetchone() - if row is None: - break - - (colname, coltype, length, precision, scale, nullable, default) = (self._normalize_name(row[0]), row[1], row[2], row[3], row[4], row[5]=='Y', row[6]) - - if include_columns and colname not in include_columns: - continue - - # INTEGER if the scale is 0 and precision is null - # NUMBER if the scale and precision are both null - # NUMBER(9,2) if the precision is 9 and the scale is 2 - # NUMBER(3) if the precision is 3 and scale is 0 - #length is ignored except for CHAR and VARCHAR2 - if coltype == 'NUMBER' : - if precision is None and scale is None: - coltype = OracleNumeric - elif precision is None and scale == 0 : - coltype = OracleInteger - else : - coltype = OracleNumeric(precision, scale) - elif coltype=='CHAR' or coltype=='VARCHAR2': - coltype = ischema_names.get(coltype, OracleString)(length) - else: - coltype = re.sub(r'\(\d+\)', '', coltype) - try: - coltype = ischema_names[coltype] - except KeyError: - util.warn("Did not recognize type '%s' of column '%s'" % - (coltype, colname)) - coltype = sqltypes.NULLTYPE - - colargs = [] - if default is not None: - colargs.append(schema.DefaultClause(sql.text(default))) - - table.append_column(schema.Column(colname, coltype, nullable=nullable, *colargs)) - - if not table.columns: - raise AssertionError("Couldn't find any column information for table %s" % actual_name) - - c = connection.execute("""SELECT - ac.constraint_name, - ac.constraint_type, - loc.column_name AS local_column, - rem.table_name AS remote_table, - rem.column_name AS remote_column, - rem.owner AS remote_owner - FROM all_constraints%(dblink)s ac, - all_cons_columns%(dblink)s loc, - all_cons_columns%(dblink)s rem - WHERE ac.table_name = :table_name - AND ac.constraint_type IN ('R','P') - AND ac.owner = :owner - AND ac.owner = loc.owner - AND ac.constraint_name = loc.constraint_name - AND ac.r_owner = rem.owner(+) - AND ac.r_constraint_name = rem.constraint_name(+) - -- order multiple primary keys correctly - ORDER BY ac.constraint_name, loc.position, rem.position""" - % {'dblink':dblink}, {'table_name' : actual_name, 'owner' : owner}) - - fks = {} - while True: - row = c.fetchone() - if row is None: - break - #print "ROW:" , row - (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = row[0:2] + tuple([self._normalize_name(x) for x in row[2:]]) - if cons_type == 'P': - table.primary_key.add(table.c[local_column]) - elif cons_type == 'R': - try: - fk = fks[cons_name] - except KeyError: - fk = ([], []) - fks[cons_name] = fk - if remote_table is None: - # ticket 363 - util.warn( - ("Got 'None' querying 'table_name' from " - "all_cons_columns%(dblink)s - does the user have " - "proper rights to the table?") % {'dblink':dblink}) - continue - - if resolve_synonyms: - ref_remote_name, ref_remote_owner, ref_dblink, ref_synonym = self._resolve_synonym(connection, desired_owner=self._denormalize_name(remote_owner), desired_table=self._denormalize_name(remote_table)) - if ref_synonym: - remote_table = self._normalize_name(ref_synonym) - remote_owner = self._normalize_name(ref_remote_owner) - - if not table.schema and self._denormalize_name(remote_owner) == owner: - refspec = ".".join([remote_table, remote_column]) - t = schema.Table(remote_table, table.metadata, autoload=True, autoload_with=connection, oracle_resolve_synonyms=resolve_synonyms, useexisting=True) - else: - refspec = ".".join([x for x in [remote_owner, remote_table, remote_column] if x]) - t = schema.Table(remote_table, table.metadata, autoload=True, autoload_with=connection, schema=remote_owner, oracle_resolve_synonyms=resolve_synonyms, useexisting=True) - - if local_column not in fk[0]: - fk[0].append(local_column) - if refspec not in fk[1]: - fk[1].append(refspec) - - for name, value in fks.iteritems(): - table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], name=name, link_to_name=True)) - - -class _OuterJoinColumn(sql.ClauseElement): - __visit_name__ = 'outer_join_column' - - def __init__(self, column): - self.column = column - -class OracleCompiler(compiler.DefaultCompiler): - """Oracle compiler modifies the lexical structure of Select - statements to work under non-ANSI configured Oracle databases, if - the use_ansi flag is False. - """ - - operators = compiler.DefaultCompiler.operators.copy() - operators.update( - { - sql_operators.mod : lambda x, y:"mod(%s, %s)" % (x, y), - sql_operators.match_op: lambda x, y: "CONTAINS (%s, %s)" % (x, y) - } - ) - - functions = compiler.DefaultCompiler.functions.copy() - functions.update ( - { - sql_functions.now : 'CURRENT_TIMESTAMP' - } - ) - - def __init__(self, *args, **kwargs): - super(OracleCompiler, self).__init__(*args, **kwargs) - self.__wheres = {} - - def default_from(self): - """Called when a ``SELECT`` statement has no froms, and no ``FROM`` clause is to be appended. - - The Oracle compiler tacks a "FROM DUAL" to the statement. - """ - - return " FROM DUAL" - - def apply_function_parens(self, func): - return len(func.clauses) > 0 - - def visit_join(self, join, **kwargs): - if self.dialect.use_ansi: - return compiler.DefaultCompiler.visit_join(self, join, **kwargs) - else: - return self.process(join.left, asfrom=True) + ", " + self.process(join.right, asfrom=True) - - def _get_nonansi_join_whereclause(self, froms): - clauses = [] - - def visit_join(join): - if join.isouter: - def visit_binary(binary): - if binary.operator == sql_operators.eq: - if binary.left.table is join.right: - binary.left = _OuterJoinColumn(binary.left) - elif binary.right.table is join.right: - binary.right = _OuterJoinColumn(binary.right) - clauses.append(visitors.cloned_traverse(join.onclause, {}, {'binary':visit_binary})) - else: - clauses.append(join.onclause) - - for f in froms: - visitors.traverse(f, {}, {'join':visit_join}) - return sql.and_(*clauses) - - def visit_outer_join_column(self, vc): - return self.process(vc.column) + "(+)" - - def visit_sequence(self, seq): - return self.dialect.identifier_preparer.format_sequence(seq) + ".nextval" - - def visit_alias(self, alias, asfrom=False, **kwargs): - """Oracle doesn't like ``FROM table AS alias``. Is the AS standard SQL??""" - - if asfrom: - alias_name = isinstance(alias.name, expression._generated_label) and \ - self._truncated_identifier("alias", alias.name) or alias.name - - return self.process(alias.original, asfrom=True, **kwargs) + " " +\ - self.preparer.format_alias(alias, alias_name) - else: - return self.process(alias.original, **kwargs) - - def _TODO_visit_compound_select(self, select): - """Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle.""" - pass - - def visit_select(self, select, **kwargs): - """Look for ``LIMIT`` and OFFSET in a select statement, and if - so tries to wrap it in a subquery with ``rownum`` criterion. - """ - - if not getattr(select, '_oracle_visit', None): - if not self.dialect.use_ansi: - if self.stack and 'from' in self.stack[-1]: - existingfroms = self.stack[-1]['from'] - else: - existingfroms = None - - froms = select._get_display_froms(existingfroms) - whereclause = self._get_nonansi_join_whereclause(froms) - if whereclause: - select = select.where(whereclause) - select._oracle_visit = True - - if select._limit is not None or select._offset is not None: - # See http://www.oracle.com/technology/oramag/oracle/06-sep/o56asktom.html - # - # Generalized form of an Oracle pagination query: - # select ... from ( - # select /*+ FIRST_ROWS(N) */ ...., rownum as ora_rn from ( - # select distinct ... where ... order by ... - # ) where ROWNUM <= :limit+:offset - # ) where ora_rn > :offset - # Outer select and "ROWNUM as ora_rn" can be dropped if limit=0 - - # TODO: use annotations instead of clone + attr set ? - select = select._generate() - select._oracle_visit = True - - # Wrap the middle select and add the hint - limitselect = sql.select([c for c in select.c]) - if select._limit and self.dialect.optimize_limits: - limitselect = limitselect.prefix_with("/*+ FIRST_ROWS(%d) */" % select._limit) - - limitselect._oracle_visit = True - limitselect._is_wrapper = True - - # If needed, add the limiting clause - if select._limit is not None: - max_row = select._limit - if select._offset is not None: - max_row += select._offset - limitselect.append_whereclause( - sql.literal_column("ROWNUM")<=max_row) - - # If needed, add the ora_rn, and wrap again with offset. - if select._offset is None: - select = limitselect - else: - limitselect = limitselect.column( - sql.literal_column("ROWNUM").label("ora_rn")) - limitselect._oracle_visit = True - limitselect._is_wrapper = True - - offsetselect = sql.select( - [c for c in limitselect.c if c.key!='ora_rn']) - offsetselect._oracle_visit = True - offsetselect._is_wrapper = True - - offsetselect.append_whereclause( - sql.literal_column("ora_rn")>select._offset) - - select = offsetselect - - kwargs['iswrapper'] = getattr(select, '_is_wrapper', False) - return compiler.DefaultCompiler.visit_select(self, select, **kwargs) - - def limit_clause(self, select): - return "" - - def for_update_clause(self, select): - if select.for_update == "nowait": - return " FOR UPDATE NOWAIT" - else: - return super(OracleCompiler, self).for_update_clause(select) - - -class OracleSchemaGenerator(compiler.SchemaGenerator): - def get_column_specification(self, column, **kwargs): - colspec = self.preparer.format_column(column) - colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec() - default = self.get_column_default_string(column) - if default is not None: - colspec += " DEFAULT " + default - - if not column.nullable: - colspec += " NOT NULL" - return colspec - - def visit_sequence(self, sequence): - if not self.checkfirst or not self.dialect.has_sequence(self.connection, sequence.name, sequence.schema): - self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence)) - self.execute() - -class OracleSchemaDropper(compiler.SchemaDropper): - def visit_sequence(self, sequence): - if not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name, sequence.schema): - self.append("DROP SEQUENCE %s" % self.preparer.format_sequence(sequence)) - self.execute() - -class OracleDefaultRunner(base.DefaultRunner): - def visit_sequence(self, seq): - return self.execute_string("SELECT " + self.dialect.identifier_preparer.format_sequence(seq) + ".nextval FROM DUAL", {}) - -class OracleIdentifierPreparer(compiler.IdentifierPreparer): - def format_savepoint(self, savepoint): - name = re.sub(r'^_+', '', savepoint.ident) - return super(OracleIdentifierPreparer, self).format_savepoint(savepoint, name) - - -dialect = OracleDialect -dialect.statement_compiler = OracleCompiler -dialect.schemagenerator = OracleSchemaGenerator -dialect.schemadropper = OracleSchemaDropper -dialect.preparer = OracleIdentifierPreparer -dialect.defaultrunner = OracleDefaultRunner -dialect.execution_ctx_cls = OracleExecutionContext diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py deleted file mode 100644 index 8952b2b1da..0000000000 --- a/lib/sqlalchemy/databases/sqlite.py +++ /dev/null @@ -1,646 +0,0 @@ -# sqlite.py -# Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com -# -# This module is part of SQLAlchemy and is released under -# the MIT License: http://www.opensource.org/licenses/mit-license.php - -"""Support for the SQLite database. - -Driver ------- - -When using Python 2.5 and above, the built in ``sqlite3`` driver is -already installed and no additional installation is needed. Otherwise, -the ``pysqlite2`` driver needs to be present. This is the same driver as -``sqlite3``, just with a different name. - -The ``pysqlite2`` driver will be loaded first, and if not found, ``sqlite3`` -is loaded. This allows an explicitly installed pysqlite driver to take -precedence over the built in one. As with all dialects, a specific -DBAPI module may be provided to :func:`~sqlalchemy.create_engine()` to control -this explicitly:: - - from sqlite3 import dbapi2 as sqlite - e = create_engine('sqlite:///file.db', module=sqlite) - -Full documentation on pysqlite is available at: -``_ - -Connect Strings ---------------- - -The file specification for the SQLite database is taken as the "database" portion of -the URL. Note that the format of a url is:: - - driver://user:pass@host/database - -This means that the actual filename to be used starts with the characters to the -**right** of the third slash. So connecting to a relative filepath looks like:: - - # relative path - e = create_engine('sqlite:///path/to/database.db') - -An absolute path, which is denoted by starting with a slash, means you need **four** -slashes:: - - # absolute path - e = create_engine('sqlite:////path/to/database.db') - -To use a Windows path, regular drive specifications and backslashes can be used. -Double backslashes are probably needed:: - - # absolute path on Windows - e = create_engine('sqlite:///C:\\\\path\\\\to\\\\database.db') - -The sqlite ``:memory:`` identifier is the default if no filepath is present. Specify -``sqlite://`` and nothing else:: - - # in-memory database - e = create_engine('sqlite://') - -Threading Behavior ------------------- - -Pysqlite connections do not support being moved between threads, unless -the ``check_same_thread`` Pysqlite flag is set to ``False``. In addition, -when using an in-memory SQLite database, the full database exists only within -the scope of a single connection. It is reported that an in-memory -database does not support being shared between threads regardless of the -``check_same_thread`` flag - which means that a multithreaded -application **cannot** share data from a ``:memory:`` database across threads -unless access to the connection is limited to a single worker thread which communicates -through a queueing mechanism to concurrent threads. - -To provide a default which accomodates SQLite's default threading capabilities -somewhat reasonably, the SQLite dialect will specify that the :class:`~sqlalchemy.pool.SingletonThreadPool` -be used by default. This pool maintains a single SQLite connection per thread -that is held open up to a count of five concurrent threads. When more than five threads -are used, a cleanup mechanism will dispose of excess unused connections. - -Two optional pool implementations that may be appropriate for particular SQLite usage scenarios: - - * the :class:`sqlalchemy.pool.StaticPool` might be appropriate for a multithreaded - application using an in-memory database, assuming the threading issues inherent in - pysqlite are somehow accomodated for. This pool holds persistently onto a single connection - which is never closed, and is returned for all requests. - - * the :class:`sqlalchemy.pool.NullPool` might be appropriate for an application that - makes use of a file-based sqlite database. This pool disables any actual "pooling" - behavior, and simply opens and closes real connections corresonding to the :func:`connect()` - and :func:`close()` methods. SQLite can "connect" to a particular file with very high - efficiency, so this option may actually perform better without the extra overhead - of :class:`SingletonThreadPool`. NullPool will of course render a ``:memory:`` connection - useless since the database would be lost as soon as the connection is "returned" to the pool. - -Date and Time Types -------------------- - -SQLite does not have built-in DATE, TIME, or DATETIME types, and pysqlite does not provide -out of the box functionality for translating values between Python `datetime` objects -and a SQLite-supported format. SQLAlchemy's own :class:`~sqlalchemy.types.DateTime` -and related types provide date formatting and parsing functionality when SQlite is used. -The implementation classes are :class:`SLDateTime`, :class:`SLDate` and :class:`SLTime`. -These types represent dates and times as ISO formatted strings, which also nicely -support ordering. There's no reliance on typical "libc" internals for these functions -so historical dates are fully supported. - -Unicode -------- - -In contrast to SQLAlchemy's active handling of date and time types for pysqlite, pysqlite's -default behavior regarding Unicode is that all strings are returned as Python unicode objects -in all cases. So even if the :class:`~sqlalchemy.types.Unicode` type is -*not* used, you will still always receive unicode data back from a result set. It is -**strongly** recommended that you do use the :class:`~sqlalchemy.types.Unicode` type -to represent strings, since it will raise a warning if a non-unicode Python string is -passed from the user application. Mixing the usage of non-unicode objects with returned unicode objects can -quickly create confusion, particularly when using the ORM as internal data is not -always represented by an actual database result string. - -""" - - -import datetime, re, time - -from sqlalchemy import sql, schema, exc, pool, DefaultClause -from sqlalchemy.engine import default -import sqlalchemy.types as sqltypes -import sqlalchemy.util as util -from sqlalchemy.sql import compiler, functions as sql_functions -from types import NoneType - -class SLNumeric(sqltypes.Numeric): - def bind_processor(self, dialect): - type_ = self.asdecimal and str or float - def process(value): - if value is not None: - return type_(value) - else: - return value - return process - - def get_col_spec(self): - if self.precision is None: - return "NUMERIC" - else: - return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale} - -class SLFloat(sqltypes.Float): - def bind_processor(self, dialect): - type_ = self.asdecimal and str or float - def process(value): - if value is not None: - return type_(value) - else: - return value - return process - - def get_col_spec(self): - return "FLOAT" - -class SLInteger(sqltypes.Integer): - def get_col_spec(self): - return "INTEGER" - -class SLSmallInteger(sqltypes.Smallinteger): - def get_col_spec(self): - return "SMALLINT" - -class DateTimeMixin(object): - def _bind_processor(self, format, elements): - def process(value): - if not isinstance(value, (NoneType, datetime.date, datetime.datetime, datetime.time)): - raise TypeError("SQLite Date, Time, and DateTime types only accept Python datetime objects as input.") - elif value is not None: - return format % tuple([getattr(value, attr, 0) for attr in elements]) - else: - return None - return process - - def _result_processor(self, fn, regexp): - def process(value): - if value is not None: - return fn(*[int(x or 0) for x in regexp.match(value).groups()]) - else: - return None - return process - -class SLDateTime(DateTimeMixin, sqltypes.DateTime): - __legacy_microseconds__ = False - - def get_col_spec(self): - return "TIMESTAMP" - - def bind_processor(self, dialect): - if self.__legacy_microseconds__: - return self._bind_processor( - "%4.4d-%2.2d-%2.2d %2.2d:%2.2d:%2.2d.%s", - ("year", "month", "day", "hour", "minute", "second", "microsecond") - ) - else: - return self._bind_processor( - "%4.4d-%2.2d-%2.2d %2.2d:%2.2d:%2.2d.%06d", - ("year", "month", "day", "hour", "minute", "second", "microsecond") - ) - - _reg = re.compile(r"(\d+)-(\d+)-(\d+)(?: (\d+):(\d+):(\d+)(?:\.(\d+))?)?") - def result_processor(self, dialect): - return self._result_processor(datetime.datetime, self._reg) - -class SLDate(DateTimeMixin, sqltypes.Date): - def get_col_spec(self): - return "DATE" - - def bind_processor(self, dialect): - return self._bind_processor( - "%4.4d-%2.2d-%2.2d", - ("year", "month", "day") - ) - - _reg = re.compile(r"(\d+)-(\d+)-(\d+)") - def result_processor(self, dialect): - return self._result_processor(datetime.date, self._reg) - -class SLTime(DateTimeMixin, sqltypes.Time): - __legacy_microseconds__ = False - - def get_col_spec(self): - return "TIME" - - def bind_processor(self, dialect): - if self.__legacy_microseconds__: - return self._bind_processor( - "%2.2d:%2.2d:%2.2d.%s", - ("hour", "minute", "second", "microsecond") - ) - else: - return self._bind_processor( - "%2.2d:%2.2d:%2.2d.%06d", - ("hour", "minute", "second", "microsecond") - ) - - _reg = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d+))?") - def result_processor(self, dialect): - return self._result_processor(datetime.time, self._reg) - -class SLUnicodeMixin(object): - def bind_processor(self, dialect): - if self.convert_unicode or dialect.convert_unicode: - if self.assert_unicode is None: - assert_unicode = dialect.assert_unicode - else: - assert_unicode = self.assert_unicode - - if not assert_unicode: - return None - - def process(value): - if not isinstance(value, (unicode, NoneType)): - if assert_unicode == 'warn': - util.warn("Unicode type received non-unicode bind " - "param value %r" % value) - return value - else: - raise exc.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value) - else: - return value - return process - else: - return None - - def result_processor(self, dialect): - return None - -class SLText(SLUnicodeMixin, sqltypes.Text): - def get_col_spec(self): - return "TEXT" - -class SLString(SLUnicodeMixin, sqltypes.String): - def get_col_spec(self): - return "VARCHAR" + (self.length and "(%d)" % self.length or "") - -class SLChar(SLUnicodeMixin, sqltypes.CHAR): - def get_col_spec(self): - return "CHAR" + (self.length and "(%d)" % self.length or "") - -class SLBinary(sqltypes.Binary): - def get_col_spec(self): - return "BLOB" - -class SLBoolean(sqltypes.Boolean): - def get_col_spec(self): - return "BOOLEAN" - - def bind_processor(self, dialect): - def process(value): - if value is None: - return None - return value and 1 or 0 - return process - - def result_processor(self, dialect): - def process(value): - if value is None: - return None - return value == 1 - return process - -colspecs = { - sqltypes.Binary: SLBinary, - sqltypes.Boolean: SLBoolean, - sqltypes.CHAR: SLChar, - sqltypes.Date: SLDate, - sqltypes.DateTime: SLDateTime, - sqltypes.Float: SLFloat, - sqltypes.Integer: SLInteger, - sqltypes.NCHAR: SLChar, - sqltypes.Numeric: SLNumeric, - sqltypes.Smallinteger: SLSmallInteger, - sqltypes.String: SLString, - sqltypes.Text: SLText, - sqltypes.Time: SLTime, -} - -ischema_names = { - 'BLOB': SLBinary, - 'BOOL': SLBoolean, - 'BOOLEAN': SLBoolean, - 'CHAR': SLChar, - 'DATE': SLDate, - 'DATETIME': SLDateTime, - 'DECIMAL': SLNumeric, - 'FLOAT': SLFloat, - 'INT': SLInteger, - 'INTEGER': SLInteger, - 'NUMERIC': SLNumeric, - 'REAL': SLNumeric, - 'SMALLINT': SLSmallInteger, - 'TEXT': SLText, - 'TIME': SLTime, - 'TIMESTAMP': SLDateTime, - 'VARCHAR': SLString, -} - -class SQLiteExecutionContext(default.DefaultExecutionContext): - def post_exec(self): - if self.compiled.isinsert and not self.executemany: - if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None: - self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:] - -class SQLiteDialect(default.DefaultDialect): - name = 'sqlite' - supports_alter = False - supports_unicode_statements = True - default_paramstyle = 'qmark' - supports_default_values = True - supports_empty_insert = False - - def __init__(self, **kwargs): - default.DefaultDialect.__init__(self, **kwargs) - def vers(num): - return tuple([int(x) for x in num.split('.')]) - if self.dbapi is not None: - sqlite_ver = self.dbapi.version_info - if sqlite_ver < (2, 1, '3'): - util.warn( - ("The installed version of pysqlite2 (%s) is out-dated " - "and will cause errors in some cases. Version 2.1.3 " - "or greater is recommended.") % - '.'.join([str(subver) for subver in sqlite_ver])) - if self.dbapi.sqlite_version_info < (3, 3, 8): - self.supports_default_values = False - self.supports_cast = (self.dbapi is None or vers(self.dbapi.sqlite_version) >= vers("3.2.3")) - - def dbapi(cls): - try: - from pysqlite2 import dbapi2 as sqlite - except ImportError, e: - try: - from sqlite3 import dbapi2 as sqlite #try the 2.5+ stdlib name. - except ImportError: - raise e - return sqlite - dbapi = classmethod(dbapi) - - def server_version_info(self, connection): - return self.dbapi.sqlite_version_info - - def create_connect_args(self, url): - if url.username or url.password or url.host or url.port: - raise exc.ArgumentError( - "Invalid SQLite URL: %s\n" - "Valid SQLite URL forms are:\n" - " sqlite:///:memory: (or, sqlite://)\n" - " sqlite:///relative/path/to/file.db\n" - " sqlite:////absolute/path/to/file.db" % (url,)) - filename = url.database or ':memory:' - - opts = url.query.copy() - util.coerce_kw_type(opts, 'timeout', float) - util.coerce_kw_type(opts, 'isolation_level', str) - util.coerce_kw_type(opts, 'detect_types', int) - util.coerce_kw_type(opts, 'check_same_thread', bool) - util.coerce_kw_type(opts, 'cached_statements', int) - - return ([filename], opts) - - def type_descriptor(self, typeobj): - return sqltypes.adapt_type(typeobj, colspecs) - - def is_disconnect(self, e): - return isinstance(e, self.dbapi.ProgrammingError) and "Cannot operate on a closed database." in str(e) - - def table_names(self, connection, schema): - if schema is not None: - qschema = self.identifier_preparer.quote_identifier(schema) - master = '%s.sqlite_master' % qschema - s = ("SELECT name FROM %s " - "WHERE type='table' ORDER BY name") % (master,) - rs = connection.execute(s) - else: - try: - s = ("SELECT name FROM " - " (SELECT * FROM sqlite_master UNION ALL " - " SELECT * FROM sqlite_temp_master) " - "WHERE type='table' ORDER BY name") - rs = connection.execute(s) - except exc.DBAPIError: - raise - s = ("SELECT name FROM sqlite_master " - "WHERE type='table' ORDER BY name") - rs = connection.execute(s) - - return [row[0] for row in rs] - - def has_table(self, connection, table_name, schema=None): - quote = self.identifier_preparer.quote_identifier - if schema is not None: - pragma = "PRAGMA %s." % quote(schema) - else: - pragma = "PRAGMA " - qtable = quote(table_name) - cursor = _pragma_cursor(connection.execute("%stable_info(%s)" % (pragma, qtable))) - - row = cursor.fetchone() - - # consume remaining rows, to work around - # http://www.sqlite.org/cvstrac/tktview?tn=1884 - while cursor.fetchone() is not None: - pass - - return (row is not None) - - def reflecttable(self, connection, table, include_columns): - preparer = self.identifier_preparer - if table.schema is None: - pragma = "PRAGMA " - else: - pragma = "PRAGMA %s." % preparer.quote_identifier(table.schema) - qtable = preparer.format_table(table, False) - - c = _pragma_cursor(connection.execute("%stable_info(%s)" % (pragma, qtable))) - found_table = False - while True: - row = c.fetchone() - if row is None: - break - - found_table = True - (name, type_, nullable, default, has_default, primary_key) = (row[1], row[2].upper(), not row[3], row[4], row[4] is not None, row[5]) - name = re.sub(r'^\"|\"$', '', name) - if include_columns and name not in include_columns: - continue - match = re.match(r'(\w+)(\(.*?\))?', type_) - if match: - coltype = match.group(1) - args = match.group(2) - else: - coltype = "VARCHAR" - args = '' - - try: - coltype = ischema_names[coltype] - except KeyError: - util.warn("Did not recognize type '%s' of column '%s'" % - (coltype, name)) - coltype = sqltypes.NullType - - if args is not None: - args = re.findall(r'(\d+)', args) - coltype = coltype(*[int(a) for a in args]) - - colargs = [] - if has_default: - colargs.append(DefaultClause(sql.text(default))) - table.append_column(schema.Column(name, coltype, primary_key = primary_key, nullable = nullable, *colargs)) - - if not found_table: - raise exc.NoSuchTableError(table.name) - - c = _pragma_cursor(connection.execute("%sforeign_key_list(%s)" % (pragma, qtable))) - fks = {} - while True: - row = c.fetchone() - if row is None: - break - (constraint_name, tablename, localcol, remotecol) = (row[0], row[2], row[3], row[4]) - tablename = re.sub(r'^\"|\"$', '', tablename) - localcol = re.sub(r'^\"|\"$', '', localcol) - remotecol = re.sub(r'^\"|\"$', '', remotecol) - try: - fk = fks[constraint_name] - except KeyError: - fk = ([], []) - fks[constraint_name] = fk - - # look up the table based on the given table's engine, not 'self', - # since it could be a ProxyEngine - remotetable = schema.Table(tablename, table.metadata, autoload=True, autoload_with=connection) - constrained_column = table.c[localcol].name - refspec = ".".join([tablename, remotecol]) - if constrained_column not in fk[0]: - fk[0].append(constrained_column) - if refspec not in fk[1]: - fk[1].append(refspec) - for name, value in fks.iteritems(): - table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], link_to_name=True)) - # check for UNIQUE indexes - c = _pragma_cursor(connection.execute("%sindex_list(%s)" % (pragma, qtable))) - unique_indexes = [] - while True: - row = c.fetchone() - if row is None: - break - if (row[2] == 1): - unique_indexes.append(row[1]) - # loop thru unique indexes for one that includes the primary key - for idx in unique_indexes: - c = connection.execute("%sindex_info(%s)" % (pragma, idx)) - cols = [] - while True: - row = c.fetchone() - if row is None: - break - cols.append(row[2]) - -def _pragma_cursor(cursor): - if cursor.closed: - cursor._fetchone_impl = lambda: None - return cursor - -class SQLiteCompiler(compiler.DefaultCompiler): - functions = compiler.DefaultCompiler.functions.copy() - functions.update ( - { - sql_functions.now: 'CURRENT_TIMESTAMP', - sql_functions.char_length: 'length%(expr)s' - } - ) - - extract_map = compiler.DefaultCompiler.extract_map.copy() - extract_map.update({ - 'month': '%m', - 'day': '%d', - 'year': '%Y', - 'second': '%S', - 'hour': '%H', - 'doy': '%j', - 'minute': '%M', - 'epoch': '%s', - 'dow': '%w', - 'week': '%W' - }) - - def visit_cast(self, cast, **kwargs): - if self.dialect.supports_cast: - return super(SQLiteCompiler, self).visit_cast(cast) - else: - return self.process(cast.clause) - - def visit_extract(self, extract): - try: - return "CAST(STRFTIME('%s', %s) AS INTEGER)" % ( - self.extract_map[extract.field], self.process(extract.expr)) - except KeyError: - raise exc.ArgumentError( - "%s is not a valid extract argument." % extract.field) - - def limit_clause(self, select): - text = "" - if select._limit is not None: - text += " \n LIMIT " + str(select._limit) - if select._offset is not None: - if select._limit is None: - text += " \n LIMIT -1" - text += " OFFSET " + str(select._offset) - else: - text += " OFFSET 0" - return text - - def for_update_clause(self, select): - # sqlite has no "FOR UPDATE" AFAICT - return '' - - -class SQLiteSchemaGenerator(compiler.SchemaGenerator): - - def get_column_specification(self, column, **kwargs): - colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec() - default = self.get_column_default_string(column) - if default is not None: - colspec += " DEFAULT " + default - - if not column.nullable: - colspec += " NOT NULL" - return colspec - -class SQLiteIdentifierPreparer(compiler.IdentifierPreparer): - reserved_words = set([ - 'add', 'after', 'all', 'alter', 'analyze', 'and', 'as', 'asc', - 'attach', 'autoincrement', 'before', 'begin', 'between', 'by', - 'cascade', 'case', 'cast', 'check', 'collate', 'column', 'commit', - 'conflict', 'constraint', 'create', 'cross', 'current_date', - 'current_time', 'current_timestamp', 'database', 'default', - 'deferrable', 'deferred', 'delete', 'desc', 'detach', 'distinct', - 'drop', 'each', 'else', 'end', 'escape', 'except', 'exclusive', - 'explain', 'false', 'fail', 'for', 'foreign', 'from', 'full', 'glob', - 'group', 'having', 'if', 'ignore', 'immediate', 'in', 'index', - 'initially', 'inner', 'insert', 'instead', 'intersect', 'into', 'is', - 'isnull', 'join', 'key', 'left', 'like', 'limit', 'match', 'natural', - 'not', 'notnull', 'null', 'of', 'offset', 'on', 'or', 'order', 'outer', - 'plan', 'pragma', 'primary', 'query', 'raise', 'references', - 'reindex', 'rename', 'replace', 'restrict', 'right', 'rollback', - 'row', 'select', 'set', 'table', 'temp', 'temporary', 'then', 'to', - 'transaction', 'trigger', 'true', 'union', 'unique', 'update', 'using', - 'vacuum', 'values', 'view', 'virtual', 'when', 'where', 'indexed', - ]) - - def __init__(self, dialect): - super(SQLiteIdentifierPreparer, self).__init__(dialect) - -dialect = SQLiteDialect -dialect.poolclass = pool.SingletonThreadPool -dialect.statement_compiler = SQLiteCompiler -dialect.schemagenerator = SQLiteSchemaGenerator -dialect.preparer = SQLiteIdentifierPreparer -dialect.execution_ctx_cls = SQLiteExecutionContext diff --git a/lib/sqlalchemy/databases/sybase.py b/lib/sqlalchemy/databases/sybase.py deleted file mode 100644 index f5b48e1479..0000000000 --- a/lib/sqlalchemy/databases/sybase.py +++ /dev/null @@ -1,875 +0,0 @@ -# sybase.py -# Copyright (C) 2007 Fisch Asset Management AG http://www.fam.ch -# Coding: Alexander Houben alexander.houben@thor-solutions.ch -# -# This module is part of SQLAlchemy and is released under -# the MIT License: http://www.opensource.org/licenses/mit-license.php - -""" -Sybase database backend. - -Known issues / TODO: - - * Uses the mx.ODBC driver from egenix (version 2.1.0) - * The current version of sqlalchemy.databases.sybase only supports - mx.ODBC.Windows (other platforms such as mx.ODBC.unixODBC still need - some development) - * Support for pyodbc has been built in but is not yet complete (needs - further development) - * Results of running tests/alltests.py: - Ran 934 tests in 287.032s - FAILED (failures=3, errors=1) - * Tested on 'Adaptive Server Anywhere 9' (version 9.0.1.1751) -""" - -import datetime, operator - -from sqlalchemy import util, sql, schema, exc -from sqlalchemy.sql import compiler, expression -from sqlalchemy.engine import default, base -from sqlalchemy import types as sqltypes -from sqlalchemy.sql import operators as sql_operators -from sqlalchemy import MetaData, Table, Column -from sqlalchemy import String, Integer, SMALLINT, CHAR, ForeignKey - - -__all__ = [ - 'SybaseTypeError' - 'SybaseNumeric', 'SybaseFloat', 'SybaseInteger', 'SybaseBigInteger', - 'SybaseTinyInteger', 'SybaseSmallInteger', - 'SybaseDateTime_mxodbc', 'SybaseDateTime_pyodbc', - 'SybaseDate_mxodbc', 'SybaseDate_pyodbc', - 'SybaseTime_mxodbc', 'SybaseTime_pyodbc', - 'SybaseText', 'SybaseString', 'SybaseChar', 'SybaseBinary', - 'SybaseBoolean', 'SybaseTimeStamp', 'SybaseMoney', 'SybaseSmallMoney', - 'SybaseUniqueIdentifier', - ] - - -RESERVED_WORDS = set([ - "add", "all", "alter", "and", - "any", "as", "asc", "backup", - "begin", "between", "bigint", "binary", - "bit", "bottom", "break", "by", - "call", "capability", "cascade", "case", - "cast", "char", "char_convert", "character", - "check", "checkpoint", "close", "comment", - "commit", "connect", "constraint", "contains", - "continue", "convert", "create", "cross", - "cube", "current", "current_timestamp", "current_user", - "cursor", "date", "dbspace", "deallocate", - "dec", "decimal", "declare", "default", - "delete", "deleting", "desc", "distinct", - "do", "double", "drop", "dynamic", - "else", "elseif", "encrypted", "end", - "endif", "escape", "except", "exception", - "exec", "execute", "existing", "exists", - "externlogin", "fetch", "first", "float", - "for", "force", "foreign", "forward", - "from", "full", "goto", "grant", - "group", "having", "holdlock", "identified", - "if", "in", "index", "index_lparen", - "inner", "inout", "insensitive", "insert", - "inserting", "install", "instead", "int", - "integer", "integrated", "intersect", "into", - "iq", "is", "isolation", "join", - "key", "lateral", "left", "like", - "lock", "login", "long", "match", - "membership", "message", "mode", "modify", - "natural", "new", "no", "noholdlock", - "not", "notify", "null", "numeric", - "of", "off", "on", "open", - "option", "options", "or", "order", - "others", "out", "outer", "over", - "passthrough", "precision", "prepare", "primary", - "print", "privileges", "proc", "procedure", - "publication", "raiserror", "readtext", "real", - "reference", "references", "release", "remote", - "remove", "rename", "reorganize", "resource", - "restore", "restrict", "return", "revoke", - "right", "rollback", "rollup", "save", - "savepoint", "scroll", "select", "sensitive", - "session", "set", "setuser", "share", - "smallint", "some", "sqlcode", "sqlstate", - "start", "stop", "subtrans", "subtransaction", - "synchronize", "syntax_error", "table", "temporary", - "then", "time", "timestamp", "tinyint", - "to", "top", "tran", "trigger", - "truncate", "tsequal", "unbounded", "union", - "unique", "unknown", "unsigned", "update", - "updating", "user", "using", "validate", - "values", "varbinary", "varchar", "variable", - "varying", "view", "wait", "waitfor", - "when", "where", "while", "window", - "with", "with_cube", "with_lparen", "with_rollup", - "within", "work", "writetext", - ]) - -ischema = MetaData() - -tables = Table("SYSTABLE", ischema, - Column("table_id", Integer, primary_key=True), - Column("file_id", SMALLINT), - Column("table_name", CHAR(128)), - Column("table_type", CHAR(10)), - Column("creator", Integer), - #schema="information_schema" - ) - -domains = Table("SYSDOMAIN", ischema, - Column("domain_id", Integer, primary_key=True), - Column("domain_name", CHAR(128)), - Column("type_id", SMALLINT), - Column("precision", SMALLINT, quote=True), - #schema="information_schema" - ) - -columns = Table("SYSCOLUMN", ischema, - Column("column_id", Integer, primary_key=True), - Column("table_id", Integer, ForeignKey(tables.c.table_id)), - Column("pkey", CHAR(1)), - Column("column_name", CHAR(128)), - Column("nulls", CHAR(1)), - Column("width", SMALLINT), - Column("domain_id", SMALLINT, ForeignKey(domains.c.domain_id)), - # FIXME: should be mx.BIGINT - Column("max_identity", Integer), - # FIXME: should be mx.ODBC.Windows.LONGVARCHAR - Column("default", String), - Column("scale", Integer), - #schema="information_schema" - ) - -foreignkeys = Table("SYSFOREIGNKEY", ischema, - Column("foreign_table_id", Integer, ForeignKey(tables.c.table_id), primary_key=True), - Column("foreign_key_id", SMALLINT, primary_key=True), - Column("primary_table_id", Integer, ForeignKey(tables.c.table_id)), - #schema="information_schema" - ) -fkcols = Table("SYSFKCOL", ischema, - Column("foreign_table_id", Integer, ForeignKey(columns.c.table_id), primary_key=True), - Column("foreign_key_id", SMALLINT, ForeignKey(foreignkeys.c.foreign_key_id), primary_key=True), - Column("foreign_column_id", Integer, ForeignKey(columns.c.column_id), primary_key=True), - Column("primary_column_id", Integer), - #schema="information_schema" - ) - -class SybaseTypeError(sqltypes.TypeEngine): - def result_processor(self, dialect): - return None - - def bind_processor(self, dialect): - def process(value): - raise exc.InvalidRequestError("Data type not supported", [value]) - return process - - def get_col_spec(self): - raise exc.CompileError("Data type not supported") - -class SybaseNumeric(sqltypes.Numeric): - def get_col_spec(self): - if self.scale is None: - if self.precision is None: - return "NUMERIC" - else: - return "NUMERIC(%(precision)s)" % {'precision' : self.precision} - else: - return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale} - -class SybaseFloat(sqltypes.FLOAT, SybaseNumeric): - def __init__(self, precision = 10, asdecimal = False, scale = 2, **kwargs): - super(sqltypes.FLOAT, self).__init__(precision, asdecimal, **kwargs) - self.scale = scale - - def get_col_spec(self): - # if asdecimal is True, handle same way as SybaseNumeric - if self.asdecimal: - return SybaseNumeric.get_col_spec(self) - if self.precision is None: - return "FLOAT" - else: - return "FLOAT(%(precision)s)" % {'precision': self.precision} - - def result_processor(self, dialect): - def process(value): - if value is None: - return None - return float(value) - if self.asdecimal: - return SybaseNumeric.result_processor(self, dialect) - return process - -class SybaseInteger(sqltypes.Integer): - def get_col_spec(self): - return "INTEGER" - -class SybaseBigInteger(SybaseInteger): - def get_col_spec(self): - return "BIGINT" - -class SybaseTinyInteger(SybaseInteger): - def get_col_spec(self): - return "TINYINT" - -class SybaseSmallInteger(SybaseInteger): - def get_col_spec(self): - return "SMALLINT" - -class SybaseDateTime_mxodbc(sqltypes.DateTime): - def __init__(self, *a, **kw): - super(SybaseDateTime_mxodbc, self).__init__(False) - - def get_col_spec(self): - return "DATETIME" - -class SybaseDateTime_pyodbc(sqltypes.DateTime): - def __init__(self, *a, **kw): - super(SybaseDateTime_pyodbc, self).__init__(False) - - def get_col_spec(self): - return "DATETIME" - - def result_processor(self, dialect): - def process(value): - if value is None: - return None - # Convert the datetime.datetime back to datetime.time - return value - return process - - def bind_processor(self, dialect): - def process(value): - if value is None: - return None - return value - return process - -class SybaseDate_mxodbc(sqltypes.Date): - def __init__(self, *a, **kw): - super(SybaseDate_mxodbc, self).__init__(False) - - def get_col_spec(self): - return "DATE" - -class SybaseDate_pyodbc(sqltypes.Date): - def __init__(self, *a, **kw): - super(SybaseDate_pyodbc, self).__init__(False) - - def get_col_spec(self): - return "DATE" - -class SybaseTime_mxodbc(sqltypes.Time): - def __init__(self, *a, **kw): - super(SybaseTime_mxodbc, self).__init__(False) - - def get_col_spec(self): - return "DATETIME" - - def result_processor(self, dialect): - def process(value): - if value is None: - return None - # Convert the datetime.datetime back to datetime.time - return datetime.time(value.hour, value.minute, value.second, value.microsecond) - return process - -class SybaseTime_pyodbc(sqltypes.Time): - def __init__(self, *a, **kw): - super(SybaseTime_pyodbc, self).__init__(False) - - def get_col_spec(self): - return "DATETIME" - - def result_processor(self, dialect): - def process(value): - if value is None: - return None - # Convert the datetime.datetime back to datetime.time - return datetime.time(value.hour, value.minute, value.second, value.microsecond) - return process - - def bind_processor(self, dialect): - def process(value): - if value is None: - return None - return datetime.datetime(1970, 1, 1, value.hour, value.minute, value.second, value.microsecond) - return process - -class SybaseText(sqltypes.Text): - def get_col_spec(self): - return "TEXT" - -class SybaseString(sqltypes.String): - def get_col_spec(self): - return "VARCHAR(%(length)s)" % {'length' : self.length} - -class SybaseChar(sqltypes.CHAR): - def get_col_spec(self): - return "CHAR(%(length)s)" % {'length' : self.length} - -class SybaseBinary(sqltypes.Binary): - def get_col_spec(self): - return "IMAGE" - -class SybaseBoolean(sqltypes.Boolean): - def get_col_spec(self): - return "BIT" - - def result_processor(self, dialect): - def process(value): - if value is None: - return None - return value and True or False - return process - - def bind_processor(self, dialect): - def process(value): - if value is True: - return 1 - elif value is False: - return 0 - elif value is None: - return None - else: - return value and True or False - return process - -class SybaseTimeStamp(sqltypes.TIMESTAMP): - def get_col_spec(self): - return "TIMESTAMP" - -class SybaseMoney(sqltypes.TypeEngine): - def get_col_spec(self): - return "MONEY" - -class SybaseSmallMoney(SybaseMoney): - def get_col_spec(self): - return "SMALLMONEY" - -class SybaseUniqueIdentifier(sqltypes.TypeEngine): - def get_col_spec(self): - return "UNIQUEIDENTIFIER" - -class SybaseSQLExecutionContext(default.DefaultExecutionContext): - pass - -class SybaseSQLExecutionContext_mxodbc(SybaseSQLExecutionContext): - - def __init__(self, dialect, connection, compiled=None, statement=None, parameters=None): - super(SybaseSQLExecutionContext_mxodbc, self).__init__(dialect, connection, compiled, statement, parameters) - - def pre_exec(self): - super(SybaseSQLExecutionContext_mxodbc, self).pre_exec() - - def post_exec(self): - if self.compiled.isinsert: - table = self.compiled.statement.table - # get the inserted values of the primary key - - # get any sequence IDs first (using @@identity) - self.cursor.execute("SELECT @@identity AS lastrowid") - row = self.cursor.fetchone() - lastrowid = int(row[0]) - if lastrowid > 0: - # an IDENTITY was inserted, fetch it - # FIXME: always insert in front ? This only works if the IDENTITY is the first column, no ?! - if not hasattr(self, '_last_inserted_ids') or self._last_inserted_ids is None: - self._last_inserted_ids = [lastrowid] - else: - self._last_inserted_ids = [lastrowid] + self._last_inserted_ids[1:] - super(SybaseSQLExecutionContext_mxodbc, self).post_exec() - -class SybaseSQLExecutionContext_pyodbc(SybaseSQLExecutionContext): - def __init__(self, dialect, connection, compiled=None, statement=None, parameters=None): - super(SybaseSQLExecutionContext_pyodbc, self).__init__(dialect, connection, compiled, statement, parameters) - - def pre_exec(self): - super(SybaseSQLExecutionContext_pyodbc, self).pre_exec() - - def post_exec(self): - if self.compiled.isinsert: - table = self.compiled.statement.table - # get the inserted values of the primary key - - # get any sequence IDs first (using @@identity) - self.cursor.execute("SELECT @@identity AS lastrowid") - row = self.cursor.fetchone() - lastrowid = int(row[0]) - if lastrowid > 0: - # an IDENTITY was inserted, fetch it - # FIXME: always insert in front ? This only works if the IDENTITY is the first column, no ?! - if not hasattr(self, '_last_inserted_ids') or self._last_inserted_ids is None: - self._last_inserted_ids = [lastrowid] - else: - self._last_inserted_ids = [lastrowid] + self._last_inserted_ids[1:] - super(SybaseSQLExecutionContext_pyodbc, self).post_exec() - -class SybaseSQLDialect(default.DefaultDialect): - colspecs = { - # FIXME: unicode support - #sqltypes.Unicode : SybaseUnicode, - sqltypes.Integer : SybaseInteger, - sqltypes.SmallInteger : SybaseSmallInteger, - sqltypes.Numeric : SybaseNumeric, - sqltypes.Float : SybaseFloat, - sqltypes.String : SybaseString, - sqltypes.Binary : SybaseBinary, - sqltypes.Boolean : SybaseBoolean, - sqltypes.Text : SybaseText, - sqltypes.CHAR : SybaseChar, - sqltypes.TIMESTAMP : SybaseTimeStamp, - sqltypes.FLOAT : SybaseFloat, - } - - ischema_names = { - 'integer' : SybaseInteger, - 'unsigned int' : SybaseInteger, - 'unsigned smallint' : SybaseInteger, - 'unsigned bigint' : SybaseInteger, - 'bigint': SybaseBigInteger, - 'smallint' : SybaseSmallInteger, - 'tinyint' : SybaseTinyInteger, - 'varchar' : SybaseString, - 'long varchar' : SybaseText, - 'char' : SybaseChar, - 'decimal' : SybaseNumeric, - 'numeric' : SybaseNumeric, - 'float' : SybaseFloat, - 'double' : SybaseFloat, - 'binary' : SybaseBinary, - 'long binary' : SybaseBinary, - 'varbinary' : SybaseBinary, - 'bit': SybaseBoolean, - 'image' : SybaseBinary, - 'timestamp': SybaseTimeStamp, - 'money': SybaseMoney, - 'smallmoney': SybaseSmallMoney, - 'uniqueidentifier': SybaseUniqueIdentifier, - - 'java.lang.Object' : SybaseTypeError, - 'java serialization' : SybaseTypeError, - } - - name = 'sybase' - # Sybase backend peculiarities - supports_unicode_statements = False - supports_sane_rowcount = False - supports_sane_multi_rowcount = False - execution_ctx_cls = SybaseSQLExecutionContext - - def __new__(cls, dbapi=None, *args, **kwargs): - if cls != SybaseSQLDialect: - return super(SybaseSQLDialect, cls).__new__(cls, *args, **kwargs) - if dbapi: - print dbapi.__name__ - dialect = dialect_mapping.get(dbapi.__name__) - return dialect(*args, **kwargs) - else: - return object.__new__(cls, *args, **kwargs) - - def __init__(self, **params): - super(SybaseSQLDialect, self).__init__(**params) - self.text_as_varchar = False - # FIXME: what is the default schema for sybase connections (DBA?) ? - self.set_default_schema_name("dba") - - def dbapi(cls, module_name=None): - if module_name: - try: - dialect_cls = dialect_mapping[module_name] - return dialect_cls.import_dbapi() - except KeyError: - raise exc.InvalidRequestError("Unsupported SybaseSQL module '%s' requested (must be " + " or ".join([x for x in dialect_mapping.keys()]) + ")" % module_name) - else: - for dialect_cls in dialect_mapping.values(): - try: - return dialect_cls.import_dbapi() - except ImportError, e: - pass - else: - raise ImportError('No DBAPI module detected for SybaseSQL - please install mxodbc') - dbapi = classmethod(dbapi) - - def type_descriptor(self, typeobj): - newobj = sqltypes.adapt_type(typeobj, self.colspecs) - return newobj - - def last_inserted_ids(self): - return self.context.last_inserted_ids - - def get_default_schema_name(self, connection): - return self.schema_name - - def set_default_schema_name(self, schema_name): - self.schema_name = schema_name - - def do_execute(self, cursor, statement, params, **kwargs): - params = tuple(params) - super(SybaseSQLDialect, self).do_execute(cursor, statement, params, **kwargs) - - # FIXME: remove ? - def _execute(self, c, statement, parameters): - try: - if parameters == {}: - parameters = () - c.execute(statement, parameters) - self.context.rowcount = c.rowcount - c.DBPROP_COMMITPRESERVE = "Y" - except Exception, e: - raise exc.DBAPIError.instance(statement, parameters, e) - - def table_names(self, connection, schema): - """Ignore the schema and the charset for now.""" - s = sql.select([tables.c.table_name], - sql.not_(tables.c.table_name.like("SYS%")) and - tables.c.creator >= 100 - ) - rp = connection.execute(s) - return [row[0] for row in rp.fetchall()] - - def has_table(self, connection, tablename, schema=None): - # FIXME: ignore schemas for sybase - s = sql.select([tables.c.table_name], tables.c.table_name == tablename) - - c = connection.execute(s) - row = c.fetchone() - print "has_table: " + tablename + ": " + str(bool(row is not None)) - return row is not None - - def reflecttable(self, connection, table, include_columns): - # Get base columns - if table.schema is not None: - current_schema = table.schema - else: - current_schema = self.get_default_schema_name(connection) - - s = sql.select([columns, domains], tables.c.table_name==table.name, from_obj=[columns.join(tables).join(domains)], order_by=[columns.c.column_id]) - - c = connection.execute(s) - found_table = False - # makes sure we append the columns in the correct order - while True: - row = c.fetchone() - if row is None: - break - found_table = True - (name, type, nullable, charlen, numericprec, numericscale, default, primary_key, max_identity, table_id, column_id) = ( - row[columns.c.column_name], - row[domains.c.domain_name], - row[columns.c.nulls] == 'Y', - row[columns.c.width], - row[domains.c.precision], - row[columns.c.scale], - row[columns.c.default], - row[columns.c.pkey] == 'Y', - row[columns.c.max_identity], - row[tables.c.table_id], - row[columns.c.column_id], - ) - if include_columns and name not in include_columns: - continue - - # FIXME: else problems with SybaseBinary(size) - if numericscale == 0: - numericscale = None - - args = [] - for a in (charlen, numericprec, numericscale): - if a is not None: - args.append(a) - coltype = self.ischema_names.get(type, None) - if coltype == SybaseString and charlen == -1: - coltype = SybaseText() - else: - if coltype is None: - util.warn("Did not recognize type '%s' of column '%s'" % - (type, name)) - coltype = sqltypes.NULLTYPE - coltype = coltype(*args) - colargs = [] - if default is not None: - colargs.append(schema.DefaultClause(sql.text(default))) - - # any sequences ? - col = schema.Column(name, coltype, nullable=nullable, primary_key=primary_key, *colargs) - if int(max_identity) > 0: - col.sequence = schema.Sequence(name + '_identity') - col.sequence.start = int(max_identity) - col.sequence.increment = 1 - - # append the column - table.append_column(col) - - # any foreign key constraint for this table ? - # note: no multi-column foreign keys are considered - s = "select st1.table_name, sc1.column_name, st2.table_name, sc2.column_name from systable as st1 join sysfkcol on st1.table_id=sysfkcol.foreign_table_id join sysforeignkey join systable as st2 on sysforeignkey.primary_table_id = st2.table_id join syscolumn as sc1 on sysfkcol.foreign_column_id=sc1.column_id and sc1.table_id=st1.table_id join syscolumn as sc2 on sysfkcol.primary_column_id=sc2.column_id and sc2.table_id=st2.table_id where st1.table_name='%(table_name)s';" % { 'table_name' : table.name } - c = connection.execute(s) - foreignKeys = {} - while True: - row = c.fetchone() - if row is None: - break - (foreign_table, foreign_column, primary_table, primary_column) = ( - row[0], row[1], row[2], row[3], - ) - if not primary_table in foreignKeys.keys(): - foreignKeys[primary_table] = [['%s' % (foreign_column)], ['%s.%s'%(primary_table, primary_column)]] - else: - foreignKeys[primary_table][0].append('%s'%(foreign_column)) - foreignKeys[primary_table][1].append('%s.%s'%(primary_table, primary_column)) - for primary_table in foreignKeys.keys(): - #table.append_constraint(schema.ForeignKeyConstraint(['%s.%s'%(foreign_table, foreign_column)], ['%s.%s'%(primary_table,primary_column)])) - table.append_constraint(schema.ForeignKeyConstraint(foreignKeys[primary_table][0], foreignKeys[primary_table][1], link_to_name=True)) - - if not found_table: - raise exc.NoSuchTableError(table.name) - - -class SybaseSQLDialect_mxodbc(SybaseSQLDialect): - execution_ctx_cls = SybaseSQLExecutionContext_mxodbc - - def __init__(self, **params): - super(SybaseSQLDialect_mxodbc, self).__init__(**params) - - self.dbapi_type_map = {'getdate' : SybaseDate_mxodbc()} - - def import_dbapi(cls): - #import mx.ODBC.Windows as module - import mxODBC as module - return module - import_dbapi = classmethod(import_dbapi) - - colspecs = SybaseSQLDialect.colspecs.copy() - colspecs[sqltypes.Time] = SybaseTime_mxodbc - colspecs[sqltypes.Date] = SybaseDate_mxodbc - colspecs[sqltypes.DateTime] = SybaseDateTime_mxodbc - - ischema_names = SybaseSQLDialect.ischema_names.copy() - ischema_names['time'] = SybaseTime_mxodbc - ischema_names['date'] = SybaseDate_mxodbc - ischema_names['datetime'] = SybaseDateTime_mxodbc - ischema_names['smalldatetime'] = SybaseDateTime_mxodbc - - def is_disconnect(self, e): - # FIXME: optimize - #return isinstance(e, self.dbapi.Error) and '[08S01]' in str(e) - #return True - return False - - def do_execute(self, cursor, statement, parameters, context=None, **kwargs): - super(SybaseSQLDialect_mxodbc, self).do_execute(cursor, statement, parameters, context=context, **kwargs) - - def create_connect_args(self, url): - '''Return a tuple of *args,**kwargs''' - # FIXME: handle mx.odbc.Windows proprietary args - opts = url.translate_connect_args(username='user') - opts.update(url.query) - argsDict = {} - argsDict['user'] = opts['user'] - argsDict['password'] = opts['password'] - connArgs = [[opts['dsn']], argsDict] - return connArgs - - -class SybaseSQLDialect_pyodbc(SybaseSQLDialect): - execution_ctx_cls = SybaseSQLExecutionContext_pyodbc - - def __init__(self, **params): - super(SybaseSQLDialect_pyodbc, self).__init__(**params) - self.dbapi_type_map = {'getdate' : SybaseDate_pyodbc()} - - def import_dbapi(cls): - import mypyodbc as module - return module - import_dbapi = classmethod(import_dbapi) - - colspecs = SybaseSQLDialect.colspecs.copy() - colspecs[sqltypes.Time] = SybaseTime_pyodbc - colspecs[sqltypes.Date] = SybaseDate_pyodbc - colspecs[sqltypes.DateTime] = SybaseDateTime_pyodbc - - ischema_names = SybaseSQLDialect.ischema_names.copy() - ischema_names['time'] = SybaseTime_pyodbc - ischema_names['date'] = SybaseDate_pyodbc - ischema_names['datetime'] = SybaseDateTime_pyodbc - ischema_names['smalldatetime'] = SybaseDateTime_pyodbc - - def is_disconnect(self, e): - # FIXME: optimize - #return isinstance(e, self.dbapi.Error) and '[08S01]' in str(e) - #return True - return False - - def do_execute(self, cursor, statement, parameters, context=None, **kwargs): - super(SybaseSQLDialect_pyodbc, self).do_execute(cursor, statement, parameters, context=context, **kwargs) - - def create_connect_args(self, url): - '''Return a tuple of *args,**kwargs''' - # FIXME: handle pyodbc proprietary args - opts = url.translate_connect_args(username='user') - opts.update(url.query) - - self.autocommit = False - if 'autocommit' in opts: - self.autocommit = bool(int(opts.pop('autocommit'))) - - argsDict = {} - argsDict['UID'] = opts['user'] - argsDict['PWD'] = opts['password'] - argsDict['DSN'] = opts['dsn'] - connArgs = [[';'.join(["%s=%s"%(key, argsDict[key]) for key in argsDict])], {'autocommit' : self.autocommit}] - return connArgs - - -dialect_mapping = { - 'sqlalchemy.databases.mxODBC' : SybaseSQLDialect_mxodbc, -# 'pyodbc' : SybaseSQLDialect_pyodbc, - } - - -class SybaseSQLCompiler(compiler.DefaultCompiler): - operators = compiler.DefaultCompiler.operators.copy() - operators.update({ - sql_operators.mod: lambda x, y: "MOD(%s, %s)" % (x, y), - }) - - extract_map = compiler.DefaultCompiler.extract_map.copy() - extract_map.update ({ - 'doy': 'dayofyear', - 'dow': 'weekday', - 'milliseconds': 'millisecond' - }) - - - def bindparam_string(self, name): - res = super(SybaseSQLCompiler, self).bindparam_string(name) - if name.lower().startswith('literal'): - res = 'STRING(%s)' % res - return res - - def get_select_precolumns(self, select): - s = select._distinct and "DISTINCT " or "" - if select._limit: - #if select._limit == 1: - #s += "FIRST " - #else: - #s += "TOP %s " % (select._limit,) - s += "TOP %s " % (select._limit,) - if select._offset: - if not select._limit: - # FIXME: sybase doesn't allow an offset without a limit - # so use a huge value for TOP here - s += "TOP 1000000 " - s += "START AT %s " % (select._offset+1,) - return s - - def limit_clause(self, select): - # Limit in sybase is after the select keyword - return "" - - def visit_binary(self, binary): - """Move bind parameters to the right-hand side of an operator, where possible.""" - if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq: - return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator)) - else: - return super(SybaseSQLCompiler, self).visit_binary(binary) - - def label_select_column(self, select, column, asfrom): - if isinstance(column, expression.Function): - return column.label(None) - else: - return super(SybaseSQLCompiler, self).label_select_column(select, column, asfrom) - - function_rewrites = {'current_date': 'getdate', - } - def visit_function(self, func): - func.name = self.function_rewrites.get(func.name, func.name) - res = super(SybaseSQLCompiler, self).visit_function(func) - if func.name.lower() == 'getdate': - # apply CAST operator - # FIXME: what about _pyodbc ? - cast = expression._Cast(func, SybaseDate_mxodbc) - # infinite recursion - # res = self.visit_cast(cast) - res = "CAST(%s AS %s)" % (res, self.process(cast.typeclause)) - return res - - def visit_extract(self, extract): - field = self.extract_map.get(extract.field, extract.field) - return 'DATEPART("%s", %s)' % (field, self.process(extract.expr)) - - def for_update_clause(self, select): - # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use - return '' - - def order_by_clause(self, select): - order_by = self.process(select._order_by_clause) - - # SybaseSQL only allows ORDER BY in subqueries if there is a LIMIT - if order_by and (not self.is_subquery() or select._limit): - return " ORDER BY " + order_by - else: - return "" - - -class SybaseSQLSchemaGenerator(compiler.SchemaGenerator): - def get_column_specification(self, column, **kwargs): - - colspec = self.preparer.format_column(column) - - if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \ - column.autoincrement and isinstance(column.type, sqltypes.Integer): - if column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional): - column.sequence = schema.Sequence(column.name + '_seq') - - if hasattr(column, 'sequence'): - column.table.has_sequence = column - #colspec += " numeric(30,0) IDENTITY" - colspec += " Integer IDENTITY" - else: - colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec() - - if not column.nullable: - colspec += " NOT NULL" - - default = self.get_column_default_string(column) - if default is not None: - colspec += " DEFAULT " + default - - return colspec - - -class SybaseSQLSchemaDropper(compiler.SchemaDropper): - def visit_index(self, index): - self.append("\nDROP INDEX %s.%s" % ( - self.preparer.quote_identifier(index.table.name), - self.preparer.quote(self._validate_identifier(index.name, False), index.quote) - )) - self.execute() - - -class SybaseSQLDefaultRunner(base.DefaultRunner): - pass - - -class SybaseSQLIdentifierPreparer(compiler.IdentifierPreparer): - reserved_words = RESERVED_WORDS - - def __init__(self, dialect): - super(SybaseSQLIdentifierPreparer, self).__init__(dialect) - - def _escape_identifier(self, value): - #TODO: determin SybaseSQL's escapeing rules - return value - - def _fold_identifier_case(self, value): - #TODO: determin SybaseSQL's case folding rules - return value - - -dialect = SybaseSQLDialect -dialect.statement_compiler = SybaseSQLCompiler -dialect.schemagenerator = SybaseSQLSchemaGenerator -dialect.schemadropper = SybaseSQLSchemaDropper -dialect.preparer = SybaseSQLIdentifierPreparer -dialect.defaultrunner = SybaseSQLDefaultRunner diff --git a/lib/sqlalchemy/dialects/__init__.py b/lib/sqlalchemy/dialects/__init__.py new file mode 100644 index 0000000000..91ca91fafd --- /dev/null +++ b/lib/sqlalchemy/dialects/__init__.py @@ -0,0 +1,12 @@ +__all__ = ( +# 'access', +# 'firebird', +# 'informix', +# 'maxdb', +# 'mssql', + 'mysql', + 'oracle', + 'postgresql', + 'sqlite', +# 'sybase', + ) diff --git a/lib/sqlalchemy/dialects/access/__init__.py b/lib/sqlalchemy/dialects/access/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lib/sqlalchemy/databases/access.py b/lib/sqlalchemy/dialects/access/base.py similarity index 96% rename from lib/sqlalchemy/databases/access.py rename to lib/sqlalchemy/dialects/access/base.py index 56c28b8cc6..ed8297137a 100644 --- a/lib/sqlalchemy/databases/access.py +++ b/lib/sqlalchemy/dialects/access/base.py @@ -5,6 +5,13 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php +""" +Support for the Microsoft Access database. + +This dialect is *not* tested on SQLAlchemy 0.6. + + +""" from sqlalchemy import sql, schema, types, exc, pool from sqlalchemy.sql import compiler, expression from sqlalchemy.engine import default, base @@ -46,7 +53,7 @@ class AcTinyInteger(types.Integer): def get_col_spec(self): return "TINYINT" -class AcSmallInteger(types.Smallinteger): +class AcSmallInteger(types.SmallInteger): def get_col_spec(self): return "SMALLINT" @@ -155,7 +162,7 @@ class AccessDialect(default.DefaultDialect): colspecs = { types.Unicode : AcUnicode, types.Integer : AcInteger, - types.Smallinteger: AcSmallInteger, + types.SmallInteger: AcSmallInteger, types.Numeric : AcNumeric, types.Float : AcFloat, types.DateTime : AcDateTime, @@ -327,8 +334,8 @@ class AccessDialect(default.DefaultDialect): return names -class AccessCompiler(compiler.DefaultCompiler): - extract_map = compiler.DefaultCompiler.extract_map.copy() +class AccessCompiler(compiler.SQLCompiler): + extract_map = compiler.SQLCompiler.extract_map.copy() extract_map.update ({ 'month': 'm', 'day': 'd', @@ -341,7 +348,7 @@ class AccessCompiler(compiler.DefaultCompiler): 'dow': 'w', 'week': 'ww' }) - + def visit_select_precolumns(self, select): """Access puts TOP, it's version of LIMIT here """ s = select.distinct and "DISTINCT " or "" @@ -393,8 +400,7 @@ class AccessCompiler(compiler.DefaultCompiler): field = self.extract_map.get(extract.field, extract.field) return 'DATEPART("%s", %s)' % (field, self.process(extract.expr)) - -class AccessSchemaGenerator(compiler.SchemaGenerator): +class AccessDDLCompiler(compiler.DDLCompiler): def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec() @@ -417,14 +423,9 @@ class AccessSchemaGenerator(compiler.SchemaGenerator): return colspec -class AccessSchemaDropper(compiler.SchemaDropper): - def visit_index(self, index): - + def visit_drop_index(self, drop): + index = drop.element self.append("\nDROP INDEX [%s].[%s]" % (index.table.name, self._validate_identifier(index.name, False))) - self.execute() - -class AccessDefaultRunner(base.DefaultRunner): - pass class AccessIdentifierPreparer(compiler.IdentifierPreparer): reserved_words = compiler.RESERVED_WORDS.copy() @@ -436,8 +437,6 @@ class AccessIdentifierPreparer(compiler.IdentifierPreparer): dialect = AccessDialect dialect.poolclass = pool.SingletonThreadPool dialect.statement_compiler = AccessCompiler -dialect.schemagenerator = AccessSchemaGenerator -dialect.schemadropper = AccessSchemaDropper +dialect.ddlcompiler = AccessDDLCompiler dialect.preparer = AccessIdentifierPreparer -dialect.defaultrunner = AccessDefaultRunner -dialect.execution_ctx_cls = AccessExecutionContext +dialect.execution_ctx_cls = AccessExecutionContext \ No newline at end of file diff --git a/lib/sqlalchemy/dialects/firebird/__init__.py b/lib/sqlalchemy/dialects/firebird/__init__.py new file mode 100644 index 0000000000..6b1b80db21 --- /dev/null +++ b/lib/sqlalchemy/dialects/firebird/__init__.py @@ -0,0 +1,3 @@ +from sqlalchemy.dialects.firebird import base, kinterbasdb + +base.dialect = kinterbasdb.dialect \ No newline at end of file diff --git a/lib/sqlalchemy/dialects/firebird/base.py b/lib/sqlalchemy/dialects/firebird/base.py new file mode 100644 index 0000000000..57b89ed058 --- /dev/null +++ b/lib/sqlalchemy/dialects/firebird/base.py @@ -0,0 +1,626 @@ +# firebird.py +# Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" +Support for the Firebird database. + +Connectivity is usually supplied via the kinterbasdb_ +DBAPI module. + +Firebird dialects +----------------- + +Firebird offers two distinct dialects_ (not to be confused with a +SQLAlchemy ``Dialect``): + +dialect 1 + This is the old syntax and behaviour, inherited from Interbase pre-6.0. + +dialect 3 + This is the newer and supported syntax, introduced in Interbase 6.0. + +The SQLAlchemy Firebird dialect detects these versions and +adjusts its representation of SQL accordingly. However, +support for dialect 1 is not well tested and probably has +incompatibilities. + +Firebird Locking Behavior +------------------------- + +Firebird locks tables aggressively. For this reason, a DROP TABLE may +hang until other transactions are released. SQLAlchemy does its best +to release transactions as quickly as possible. The most common cause +of hanging transactions is a non-fully consumed result set, i.e.:: + + result = engine.execute("select * from table") + row = result.fetchone() + return + +Where above, the ``ResultProxy`` has not been fully consumed. The +connection will be returned to the pool and the transactional state +rolled back once the Python garbage collector reclaims the objects +which hold onto the connection, which often occurs asynchronously. +The above use case can be alleviated by calling ``first()`` on the +``ResultProxy`` which will fetch the first row and immediately close +all remaining cursor/connection resources. + +RETURNING support +----------------- + +Firebird 2.0 supports returning a result set from inserts, and 2.1 extends +that to deletes and updates. + +To use this pass the column/expression list to the ``firebird_returning`` +parameter when creating the queries:: + + raises = tbl.update(empl.c.sales > 100, values=dict(salary=empl.c.salary * 1.1), + firebird_returning=[empl.c.id, empl.c.salary]).execute().fetchall() + + +.. [#] Well, that is not the whole story, as the client may still ask + a different (lower) dialect... + +.. _dialects: http://mc-computing.com/Databases/Firebird/SQL_Dialect.html +.. _kinterbasdb: http://sourceforge.net/projects/kinterbasdb + +""" + + +import datetime, decimal, re + +from sqlalchemy import schema as sa_schema +from sqlalchemy import exc, types as sqltypes, sql, util +from sqlalchemy.sql import expression +from sqlalchemy.engine import base, default, reflection +from sqlalchemy.sql import compiler + +from sqlalchemy.types import (BIGINT, BLOB, BOOLEAN, CHAR, DATE, + FLOAT, INTEGER, NUMERIC, SMALLINT, + TEXT, TIME, TIMESTAMP, VARCHAR) + + +RESERVED_WORDS = set( + ["action", "active", "add", "admin", "after", "all", "alter", "and", "any", + "as", "asc", "ascending", "at", "auto", "autoddl", "avg", "based", "basename", + "base_name", "before", "begin", "between", "bigint", "blob", "blobedit", "buffer", + "by", "cache", "cascade", "case", "cast", "char", "character", "character_length", + "char_length", "check", "check_point_len", "check_point_length", "close", "collate", + "collation", "column", "commit", "committed", "compiletime", "computed", "conditional", + "connect", "constraint", "containing", "continue", "count", "create", "cstring", + "current", "current_connection", "current_date", "current_role", "current_time", + "current_timestamp", "current_transaction", "current_user", "cursor", "database", + "date", "day", "db_key", "debug", "dec", "decimal", "declare", "default", "delete", + "desc", "descending", "describe", "descriptor", "disconnect", "display", "distinct", + "do", "domain", "double", "drop", "echo", "edit", "else", "end", "entry_point", + "escape", "event", "exception", "execute", "exists", "exit", "extern", "external", + "extract", "fetch", "file", "filter", "float", "for", "foreign", "found", "free_it", + "from", "full", "function", "gdscode", "generator", "gen_id", "global", "goto", + "grant", "group", "group_commit_", "group_commit_wait", "having", "help", "hour", + "if", "immediate", "in", "inactive", "index", "indicator", "init", "inner", "input", + "input_type", "insert", "int", "integer", "into", "is", "isolation", "isql", "join", + "key", "lc_messages", "lc_type", "left", "length", "lev", "level", "like", "logfile", + "log_buffer_size", "log_buf_size", "long", "manual", "max", "maximum", "maximum_segment", + "max_segment", "merge", "message", "min", "minimum", "minute", "module_name", "month", + "names", "national", "natural", "nchar", "no", "noauto", "not", "null", "numeric", + "num_log_buffers", "num_log_bufs", "octet_length", "of", "on", "only", "open", "option", + "or", "order", "outer", "output", "output_type", "overflow", "page", "pagelength", + "pages", "page_size", "parameter", "password", "plan", "position", "post_event", + "precision", "prepare", "primary", "privileges", "procedure", "protected", "public", + "quit", "raw_partitions", "rdb$db_key", "read", "real", "record_version", "recreate", + "references", "release", "release", "reserv", "reserving", "restrict", "retain", + "return", "returning_values", "returns", "revoke", "right", "role", "rollback", + "row_count", "runtime", "savepoint", "schema", "second", "segment", "select", + "set", "shadow", "shared", "shell", "show", "singular", "size", "smallint", + "snapshot", "some", "sort", "sqlcode", "sqlerror", "sqlwarning", "stability", + "starting", "starts", "statement", "static", "statistics", "sub_type", "sum", + "suspend", "table", "terminator", "then", "time", "timestamp", "to", "transaction", + "translate", "translation", "trigger", "trim", "type", "uncommitted", "union", + "unique", "update", "upper", "user", "using", "value", "values", "varchar", + "variable", "varying", "version", "view", "wait", "wait_time", "weekday", "when", + "whenever", "where", "while", "with", "work", "write", "year", "yearday" ]) + + +class _FBBoolean(sqltypes.Boolean): + def result_processor(self, dialect): + def process(value): + if value is None: + return None + return value and True or False + return process + + def bind_processor(self, dialect): + def process(value): + if value is True: + return 1 + elif value is False: + return 0 + elif value is None: + return None + else: + return value and True or False + return process + + +colspecs = { + sqltypes.Boolean: _FBBoolean, +} + +ischema_names = { + 'SHORT': SMALLINT, + 'LONG': BIGINT, + 'QUAD': FLOAT, + 'FLOAT': FLOAT, + 'DATE': DATE, + 'TIME': TIME, + 'TEXT': TEXT, + 'INT64': NUMERIC, + 'DOUBLE': FLOAT, + 'TIMESTAMP': TIMESTAMP, + 'VARYING': VARCHAR, + 'CSTRING': CHAR, + 'BLOB': BLOB, + } + + +# TODO: date conversion types (should be implemented as _FBDateTime, _FBDate, etc. +# as bind/result functionality is required) + +class FBTypeCompiler(compiler.GenericTypeCompiler): + def visit_boolean(self, type_): + return self.visit_SMALLINT(type_) + + def visit_datetime(self, type_): + return self.visit_TIMESTAMP(type_) + + def visit_TEXT(self, type_): + return "BLOB SUB_TYPE 1" + + def visit_BLOB(self, type_): + return "BLOB SUB_TYPE 0" + + +class FBCompiler(sql.compiler.SQLCompiler): + """Firebird specific idiosincrasies""" + + def visit_mod(self, binary, **kw): + # Firebird lacks a builtin modulo operator, but there is + # an equivalent function in the ib_udf library. + return "mod(%s, %s)" % (self.process(binary.left), self.process(binary.right)) + + def visit_alias(self, alias, asfrom=False, **kwargs): + if self.dialect._version_two: + return super(FBCompiler, self).visit_alias(alias, asfrom=asfrom, **kwargs) + else: + # Override to not use the AS keyword which FB 1.5 does not like + if asfrom: + alias_name = isinstance(alias.name, expression._generated_label) and \ + self._truncated_identifier("alias", alias.name) or alias.name + + return self.process(alias.original, asfrom=asfrom, **kwargs) + " " + \ + self.preparer.format_alias(alias, alias_name) + else: + return self.process(alias.original, **kwargs) + + def visit_substring_func(self, func, **kw): + s = self.process(func.clauses.clauses[0]) + start = self.process(func.clauses.clauses[1]) + if len(func.clauses.clauses) > 2: + length = self.process(func.clauses.clauses[2]) + return "SUBSTRING(%s FROM %s FOR %s)" % (s, start, length) + else: + return "SUBSTRING(%s FROM %s)" % (s, start) + + def visit_length_func(self, function, **kw): + if self.dialect._version_two: + return "char_length" + self.function_argspec(function) + else: + return "strlen" + self.function_argspec(function) + + visit_char_length_func = visit_length_func + + def function_argspec(self, func, **kw): + if func.clauses: + return self.process(func.clause_expr) + else: + return "" + + def default_from(self): + return " FROM rdb$database" + + def visit_sequence(self, seq): + return "gen_id(%s, 1)" % self.preparer.format_sequence(seq) + + def get_select_precolumns(self, select): + """Called when building a ``SELECT`` statement, position is just + before column list Firebird puts the limit and offset right + after the ``SELECT``... + """ + + result = "" + if select._limit: + result += "FIRST %d " % select._limit + if select._offset: + result +="SKIP %d " % select._offset + if select._distinct: + result += "DISTINCT " + return result + + def limit_clause(self, select): + """Already taken care of in the `get_select_precolumns` method.""" + + return "" + + def returning_clause(self, stmt, returning_cols): + + columns = [ + self.process( + self.label_select_column(None, c, asfrom=False), + within_columns_clause=True, + result_map=self.result_map + ) + for c in expression._select_iterables(returning_cols) + ] + return 'RETURNING ' + ', '.join(columns) + + +class FBDDLCompiler(sql.compiler.DDLCompiler): + """Firebird syntactic idiosincrasies""" + + def visit_create_sequence(self, create): + """Generate a ``CREATE GENERATOR`` statement for the sequence.""" + + if self.dialect._version_two: + return "CREATE SEQUENCE %s" % self.preparer.format_sequence(create.element) + else: + return "CREATE GENERATOR %s" % self.preparer.format_sequence(create.element) + + def visit_drop_sequence(self, drop): + """Generate a ``DROP GENERATOR`` statement for the sequence.""" + + if self.dialect._version_two: + return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element) + else: + return "DROP GENERATOR %s" % self.preparer.format_sequence(drop.element) + + +class FBDefaultRunner(base.DefaultRunner): + """Firebird specific idiosincrasies""" + + def visit_sequence(self, seq): + """Get the next value from the sequence using ``gen_id()``.""" + + return self.execute_string("SELECT gen_id(%s, 1) FROM rdb$database" % \ + self.dialect.identifier_preparer.format_sequence(seq)) + + +class FBIdentifierPreparer(sql.compiler.IdentifierPreparer): + """Install Firebird specific reserved words.""" + + reserved_words = RESERVED_WORDS + + def __init__(self, dialect): + super(FBIdentifierPreparer, self).__init__(dialect, omit_schema=True) + + +class FBDialect(default.DefaultDialect): + """Firebird dialect""" + + name = 'firebird' + + max_identifier_length = 31 + + supports_sequences = True + sequences_optional = False + supports_default_values = True + postfetch_lastrowid = False + + requires_name_normalize = True + supports_empty_insert = False + + statement_compiler = FBCompiler + ddl_compiler = FBDDLCompiler + defaultrunner = FBDefaultRunner + preparer = FBIdentifierPreparer + type_compiler = FBTypeCompiler + + colspecs = colspecs + ischema_names = ischema_names + + # defaults to dialect ver. 3, + # will be autodetected off upon + # first connect + _version_two = True + + def initialize(self, connection): + super(FBDialect, self).initialize(connection) + self._version_two = self.server_version_info > (2, ) + if not self._version_two: + # TODO: whatever other pre < 2.0 stuff goes here + self.ischema_names = ischema_names.copy() + self.ischema_names['TIMESTAMP'] = sqltypes.DATE + self.colspecs = { + sqltypes.DateTime: sqltypes.DATE + } + else: + self.implicit_returning = True + + def normalize_name(self, name): + # Remove trailing spaces: FB uses a CHAR() type, + # that is padded with spaces + name = name and name.rstrip() + if name is None: + return None + elif name.upper() == name and \ + not self.identifier_preparer._requires_quotes(name.lower()): + return name.lower() + else: + return name + + def denormalize_name(self, name): + if name is None: + return None + elif name.lower() == name and \ + not self.identifier_preparer._requires_quotes(name.lower()): + return name.upper() + else: + return name + + def has_table(self, connection, table_name, schema=None): + """Return ``True`` if the given table exists, ignoring the `schema`.""" + + tblqry = """ + SELECT 1 FROM rdb$database + WHERE EXISTS (SELECT rdb$relation_name + FROM rdb$relations + WHERE rdb$relation_name=?) + """ + c = connection.execute(tblqry, [self.denormalize_name(table_name)]) + return c.first() is not None + + def has_sequence(self, connection, sequence_name): + """Return ``True`` if the given sequence (generator) exists.""" + + genqry = """ + SELECT 1 FROM rdb$database + WHERE EXISTS (SELECT rdb$generator_name + FROM rdb$generators + WHERE rdb$generator_name=?) + """ + c = connection.execute(genqry, [self.denormalize_name(sequence_name)]) + return c.first() is not None + + def table_names(self, connection, schema): + s = """ + SELECT DISTINCT rdb$relation_name + FROM rdb$relation_fields + WHERE rdb$system_flag=0 AND rdb$view_context IS NULL + """ + return [self.normalize_name(row[0]) for row in connection.execute(s)] + + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + return self.table_names(connection, schema) + + @reflection.cache + def get_view_names(self, connection, schema=None, **kw): + s = """ + SELECT distinct rdb$view_name + FROM rdb$view_relations + """ + return [self.normalize_name(row[0]) for row in connection.execute(s)] + + @reflection.cache + def get_view_definition(self, connection, view_name, schema=None, **kw): + qry = """ + SELECT rdb$view_source AS view_source + FROM rdb$relations + WHERE rdb$relation_name=? + """ + rp = connection.execute(qry, [self.denormalize_name(view_name)]) + row = rp.first() + if row: + return row['view_source'] + else: + return None + + @reflection.cache + def get_primary_keys(self, connection, table_name, schema=None, **kw): + # Query to extract the PK/FK constrained fields of the given table + keyqry = """ + SELECT se.rdb$field_name AS fname + FROM rdb$relation_constraints rc + JOIN rdb$index_segments se ON rc.rdb$index_name=se.rdb$index_name + WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=? + """ + tablename = self.denormalize_name(table_name) + # get primary key fields + c = connection.execute(keyqry, ["PRIMARY KEY", tablename]) + pkfields = [self.normalize_name(r['fname']) for r in c.fetchall()] + return pkfields + + @reflection.cache + def get_column_sequence(self, connection, table_name, column_name, schema=None, **kw): + tablename = self.denormalize_name(table_name) + colname = self.denormalize_name(column_name) + # Heuristic-query to determine the generator associated to a PK field + genqry = """ + SELECT trigdep.rdb$depended_on_name AS fgenerator + FROM rdb$dependencies tabdep + JOIN rdb$dependencies trigdep + ON tabdep.rdb$dependent_name=trigdep.rdb$dependent_name + AND trigdep.rdb$depended_on_type=14 + AND trigdep.rdb$dependent_type=2 + JOIN rdb$triggers trig ON trig.rdb$trigger_name=tabdep.rdb$dependent_name + WHERE tabdep.rdb$depended_on_name=? + AND tabdep.rdb$depended_on_type=0 + AND trig.rdb$trigger_type=1 + AND tabdep.rdb$field_name=? + AND (SELECT count(*) + FROM rdb$dependencies trigdep2 + WHERE trigdep2.rdb$dependent_name = trigdep.rdb$dependent_name) = 2 + """ + genc = connection.execute(genqry, [tablename, colname]) + genr = genc.fetchone() + if genr is not None: + return dict(name=self.normalize_name(genr['fgenerator'])) + + @reflection.cache + def get_columns(self, connection, table_name, schema=None, **kw): + # Query to extract the details of all the fields of the given table + tblqry = """ + SELECT DISTINCT r.rdb$field_name AS fname, + r.rdb$null_flag AS null_flag, + t.rdb$type_name AS ftype, + f.rdb$field_sub_type AS stype, + f.rdb$field_length AS flen, + f.rdb$field_precision AS fprec, + f.rdb$field_scale AS fscale, + COALESCE(r.rdb$default_source, f.rdb$default_source) AS fdefault + FROM rdb$relation_fields r + JOIN rdb$fields f ON r.rdb$field_source=f.rdb$field_name + JOIN rdb$types t + ON t.rdb$type=f.rdb$field_type AND t.rdb$field_name='RDB$FIELD_TYPE' + WHERE f.rdb$system_flag=0 AND r.rdb$relation_name=? + ORDER BY r.rdb$field_position + """ + # get the PK, used to determine the eventual associated sequence + pkey_cols = self.get_primary_keys(connection, table_name) + + tablename = self.denormalize_name(table_name) + # get all of the fields for this table + c = connection.execute(tblqry, [tablename]) + cols = [] + while True: + row = c.fetchone() + if row is None: + break + name = self.normalize_name(row['fname']) + # get the data type + + colspec = row['ftype'].rstrip() + coltype = self.ischema_names.get(colspec) + if coltype is None: + util.warn("Did not recognize type '%s' of column '%s'" % + (colspec, name)) + coltype = sqltypes.NULLTYPE + elif colspec == 'INT64': + coltype = coltype(precision=row['fprec'], scale=row['fscale'] * -1) + elif colspec in ('VARYING', 'CSTRING'): + coltype = coltype(row['flen']) + elif colspec == 'TEXT': + coltype = TEXT(row['flen']) + elif colspec == 'BLOB': + if row['stype'] == 1: + coltype = TEXT() + else: + coltype = BLOB() + else: + coltype = coltype(row) + + # does it have a default value? + defvalue = None + if row['fdefault'] is not None: + # the value comes down as "DEFAULT 'value'" + assert row['fdefault'].upper().startswith('DEFAULT '), row + defvalue = row['fdefault'][8:] + col_d = { + 'name' : name, + 'type' : coltype, + 'nullable' : not bool(row['null_flag']), + 'default' : defvalue + } + + # if the PK is a single field, try to see if its linked to + # a sequence thru a trigger + if len(pkey_cols)==1 and name==pkey_cols[0]: + seq_d = self.get_column_sequence(connection, tablename, name) + if seq_d is not None: + col_d['sequence'] = seq_d + + cols.append(col_d) + return cols + + @reflection.cache + def get_foreign_keys(self, connection, table_name, schema=None, **kw): + # Query to extract the details of each UK/FK of the given table + fkqry = """ + SELECT rc.rdb$constraint_name AS cname, + cse.rdb$field_name AS fname, + ix2.rdb$relation_name AS targetrname, + se.rdb$field_name AS targetfname + FROM rdb$relation_constraints rc + JOIN rdb$indices ix1 ON ix1.rdb$index_name=rc.rdb$index_name + JOIN rdb$indices ix2 ON ix2.rdb$index_name=ix1.rdb$foreign_key + JOIN rdb$index_segments cse ON cse.rdb$index_name=ix1.rdb$index_name + JOIN rdb$index_segments se + ON se.rdb$index_name=ix2.rdb$index_name + AND se.rdb$field_position=cse.rdb$field_position + WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=? + ORDER BY se.rdb$index_name, se.rdb$field_position + """ + tablename = self.denormalize_name(table_name) + + c = connection.execute(fkqry, ["FOREIGN KEY", tablename]) + fks = util.defaultdict(lambda:{ + 'name' : None, + 'constrained_columns' : [], + 'referred_schema' : None, + 'referred_table' : None, + 'referred_columns' : [] + }) + + for row in c: + cname = self.normalize_name(row['cname']) + fk = fks[cname] + if not fk['name']: + fk['name'] = cname + fk['referred_table'] = self.normalize_name(row['targetrname']) + fk['constrained_columns'].append(self.normalize_name(row['fname'])) + fk['referred_columns'].append( + self.normalize_name(row['targetfname'])) + return fks.values() + + @reflection.cache + def get_indexes(self, connection, table_name, schema=None, **kw): + qry = """ + SELECT ix.rdb$index_name AS index_name, + ix.rdb$unique_flag AS unique_flag, + ic.rdb$field_name AS field_name + FROM rdb$indices ix + JOIN rdb$index_segments ic + ON ix.rdb$index_name=ic.rdb$index_name + LEFT OUTER JOIN rdb$relation_constraints + ON rdb$relation_constraints.rdb$index_name = ic.rdb$index_name + WHERE ix.rdb$relation_name=? AND ix.rdb$foreign_key IS NULL + AND rdb$relation_constraints.rdb$constraint_type IS NULL + ORDER BY index_name, field_name + """ + c = connection.execute(qry, [self.denormalize_name(table_name)]) + + indexes = util.defaultdict(dict) + for row in c: + indexrec = indexes[row['index_name']] + if 'name' not in indexrec: + indexrec['name'] = self.normalize_name(row['index_name']) + indexrec['column_names'] = [] + indexrec['unique'] = bool(row['unique_flag']) + + indexrec['column_names'].append(self.normalize_name(row['field_name'])) + + return indexes.values() + + def do_execute(self, cursor, statement, parameters, **kwargs): + # kinterbase does not accept a None, but wants an empty list + # when there are no arguments. + cursor.execute(statement, parameters or []) + + def do_rollback(self, connection): + # Use the retaining feature, that keeps the transaction going + connection.rollback(True) + + def do_commit(self, connection): + # Use the retaining feature, that keeps the transaction going + connection.commit(True) diff --git a/lib/sqlalchemy/dialects/firebird/kinterbasdb.py b/lib/sqlalchemy/dialects/firebird/kinterbasdb.py new file mode 100644 index 0000000000..7d30f87f50 --- /dev/null +++ b/lib/sqlalchemy/dialects/firebird/kinterbasdb.py @@ -0,0 +1,70 @@ +from sqlalchemy.dialects.firebird.base import FBDialect, FBCompiler +from sqlalchemy.engine.default import DefaultExecutionContext + +_initialized_kb = False + +class Firebird_kinterbasdb(FBDialect): + driver = 'kinterbasdb' + supports_sane_rowcount = False + supports_sane_multi_rowcount = False + + def __init__(self, type_conv=200, concurrency_level=1, **kwargs): + super(Firebird_kinterbasdb, self).__init__(**kwargs) + + self.type_conv = type_conv + self.concurrency_level = concurrency_level + + @classmethod + def dbapi(cls): + k = __import__('kinterbasdb') + return k + + def create_connect_args(self, url): + opts = url.translate_connect_args(username='user') + if opts.get('port'): + opts['host'] = "%s/%s" % (opts['host'], opts['port']) + del opts['port'] + opts.update(url.query) + + type_conv = opts.pop('type_conv', self.type_conv) + concurrency_level = opts.pop('concurrency_level', self.concurrency_level) + global _initialized_kb + if not _initialized_kb and self.dbapi is not None: + _initialized_kb = True + self.dbapi.init(type_conv=type_conv, concurrency_level=concurrency_level) + return ([], opts) + + def _get_server_version_info(self, connection): + """Get the version of the Firebird server used by a connection. + + Returns a tuple of (`major`, `minor`, `build`), three integers + representing the version of the attached server. + """ + + # This is the simpler approach (the other uses the services api), + # that for backward compatibility reasons returns a string like + # LI-V6.3.3.12981 Firebird 2.0 + # where the first version is a fake one resembling the old + # Interbase signature. This is more than enough for our purposes, + # as this is mainly (only?) used by the testsuite. + + from re import match + + fbconn = connection.connection + version = fbconn.server_version + m = match('\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+) \w+ (\d+)\.(\d+)', version) + if not m: + raise AssertionError("Could not determine version from string '%s'" % version) + return tuple([int(x) for x in m.group(5, 6, 4)]) + + def is_disconnect(self, e): + if isinstance(e, self.dbapi.OperationalError): + return 'Unable to complete network request to host' in str(e) + elif isinstance(e, self.dbapi.ProgrammingError): + msg = str(e) + return ('Invalid connection state' in msg or + 'Invalid cursor state' in msg) + else: + return False + +dialect = Firebird_kinterbasdb diff --git a/lib/sqlalchemy/dialects/informix/__init__.py b/lib/sqlalchemy/dialects/informix/__init__.py new file mode 100644 index 0000000000..f2fcc76d4c --- /dev/null +++ b/lib/sqlalchemy/dialects/informix/__init__.py @@ -0,0 +1,3 @@ +from sqlalchemy.dialects.informix import base, informixdb + +base.dialect = informixdb.dialect \ No newline at end of file diff --git a/lib/sqlalchemy/databases/informix.py b/lib/sqlalchemy/dialects/informix/base.py similarity index 59% rename from lib/sqlalchemy/databases/informix.py rename to lib/sqlalchemy/dialects/informix/base.py index 4476af3b9c..b69748fcf1 100644 --- a/lib/sqlalchemy/databases/informix.py +++ b/lib/sqlalchemy/dialects/informix/base.py @@ -5,6 +5,12 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php +"""Support for the Informix database. + +This dialect is *not* tested on SQLAlchemy 0.6. + +""" + import datetime @@ -14,55 +20,7 @@ from sqlalchemy.engine import default from sqlalchemy import types as sqltypes -# for offset - -class informix_cursor(object): - def __init__( self , con ): - self.__cursor = con.cursor() - self.rowcount = 0 - - def offset( self , n ): - if n > 0: - self.fetchmany( n ) - self.rowcount = self.__cursor.rowcount - n - if self.rowcount < 0: - self.rowcount = 0 - else: - self.rowcount = self.__cursor.rowcount - - def execute( self , sql , params ): - if params is None or len( params ) == 0: - params = [] - - return self.__cursor.execute( sql , params ) - - def __getattr__( self , name ): - if name not in ( 'offset' , '__cursor' , 'rowcount' , '__del__' , 'execute' ): - return getattr( self.__cursor , name ) - -class InfoNumeric(sqltypes.Numeric): - def get_col_spec(self): - if not self.precision: - return 'NUMERIC' - else: - return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale} - -class InfoInteger(sqltypes.Integer): - def get_col_spec(self): - return "INTEGER" - -class InfoSmallInteger(sqltypes.Smallinteger): - def get_col_spec(self): - return "SMALLINT" - -class InfoDate(sqltypes.Date): - def get_col_spec( self ): - return "DATE" - class InfoDateTime(sqltypes.DateTime ): - def get_col_spec(self): - return "DATETIME YEAR TO SECOND" - def bind_processor(self, dialect): def process(value): if value is not None: @@ -72,9 +30,6 @@ class InfoDateTime(sqltypes.DateTime ): return process class InfoTime(sqltypes.Time ): - def get_col_spec(self): - return "DATETIME HOUR TO SECOND" - def bind_processor(self, dialect): def process(value): if value is not None: @@ -91,35 +46,8 @@ class InfoTime(sqltypes.Time ): return value return process -class InfoText(sqltypes.String): - def get_col_spec(self): - return "VARCHAR(255)" - -class InfoString(sqltypes.String): - def get_col_spec(self): - return "VARCHAR(%(length)s)" % {'length' : self.length} - - def bind_processor(self, dialect): - def process(value): - if value == '': - return None - else: - return value - return process - -class InfoChar(sqltypes.CHAR): - def get_col_spec(self): - return "CHAR(%(length)s)" % {'length' : self.length} - -class InfoBinary(sqltypes.Binary): - def get_col_spec(self): - return "BYTE" class InfoBoolean(sqltypes.Boolean): - default_type = 'NUM' - def get_col_spec(self): - return "SMALLINT" - def result_processor(self, dialect): def process(value): if value is None: @@ -140,104 +68,156 @@ class InfoBoolean(sqltypes.Boolean): return process colspecs = { - sqltypes.Integer : InfoInteger, - sqltypes.Smallinteger : InfoSmallInteger, - sqltypes.Numeric : InfoNumeric, - sqltypes.Float : InfoNumeric, sqltypes.DateTime : InfoDateTime, - sqltypes.Date : InfoDate, sqltypes.Time: InfoTime, - sqltypes.String : InfoString, - sqltypes.Binary : InfoBinary, sqltypes.Boolean : InfoBoolean, - sqltypes.Text : InfoText, - sqltypes.CHAR: InfoChar, } ischema_names = { - 0 : InfoString, # CHAR - 1 : InfoSmallInteger, # SMALLINT - 2 : InfoInteger, # INT - 3 : InfoNumeric, # Float - 3 : InfoNumeric, # SmallFloat - 5 : InfoNumeric, # DECIMAL - 6 : InfoInteger, # Serial - 7 : InfoDate, # DATE - 8 : InfoNumeric, # MONEY - 10 : InfoDateTime, # DATETIME - 11 : InfoBinary, # BYTE - 12 : InfoText, # TEXT - 13 : InfoString, # VARCHAR - 15 : InfoString, # NCHAR - 16 : InfoString, # NVARCHAR - 17 : InfoInteger, # INT8 - 18 : InfoInteger, # Serial8 - 43 : InfoString, # LVARCHAR - -1 : InfoBinary, # BLOB - -1 : InfoText, # CLOB + 0 : sqltypes.CHAR, # CHAR + 1 : sqltypes.SMALLINT, # SMALLINT + 2 : sqltypes.INTEGER, # INT + 3 : sqltypes.FLOAT, # Float + 3 : sqltypes.Float, # SmallFloat + 5 : sqltypes.DECIMAL, # DECIMAL + 6 : sqltypes.Integer, # Serial + 7 : sqltypes.DATE, # DATE + 8 : sqltypes.Numeric, # MONEY + 10 : sqltypes.DATETIME, # DATETIME + 11 : sqltypes.Binary, # BYTE + 12 : sqltypes.TEXT, # TEXT + 13 : sqltypes.VARCHAR, # VARCHAR + 15 : sqltypes.NCHAR, # NCHAR + 16 : sqltypes.NVARCHAR, # NVARCHAR + 17 : sqltypes.Integer, # INT8 + 18 : sqltypes.Integer, # Serial8 + 43 : sqltypes.String, # LVARCHAR + -1 : sqltypes.BLOB, # BLOB + -1 : sqltypes.CLOB, # CLOB } -class InfoExecutionContext(default.DefaultExecutionContext): - # cursor.sqlerrd - # 0 - estimated number of rows returned - # 1 - serial value after insert or ISAM error code - # 2 - number of rows processed - # 3 - estimated cost - # 4 - offset of the error into the SQL statement - # 5 - rowid after insert - def post_exec(self): - if getattr(self.compiled, "isinsert", False) and self.last_inserted_ids() is None: - self._last_inserted_ids = [self.cursor.sqlerrd[1]] - elif hasattr( self.compiled , 'offset' ): - self.cursor.offset( self.compiled.offset ) - super(InfoExecutionContext, self).post_exec() - - def create_cursor( self ): - return informix_cursor( self.connection.connection ) - -class InfoDialect(default.DefaultDialect): - name = 'informix' - default_paramstyle = 'qmark' - # for informix 7.31 - max_identifier_length = 18 +class InfoTypeCompiler(compiler.GenericTypeCompiler): + def visit_DATETIME(self, type_): + return "DATETIME YEAR TO SECOND" + + def visit_TIME(self, type_): + return "DATETIME HOUR TO SECOND" + + def visit_binary(self, type_): + return "BYTE" + + def visit_boolean(self, type_): + return "SMALLINT" + +class InfoSQLCompiler(compiler.SQLCompiler): - def __init__(self, use_ansi=True, **kwargs): - self.use_ansi = use_ansi - default.DefaultDialect.__init__(self, **kwargs) + def __init__(self, *args, **kwargs): + self.limit = 0 + self.offset = 0 - def dbapi(cls): - import informixdb - return informixdb - dbapi = classmethod(dbapi) + compiler.SQLCompiler.__init__( self , *args, **kwargs ) + + def default_from(self): + return " from systables where tabname = 'systables' " - def is_disconnect(self, e): - if isinstance(e, self.dbapi.OperationalError): - return 'closed the connection' in str(e) or 'connection not open' in str(e) + def get_select_precolumns( self , select ): + s = select._distinct and "DISTINCT " or "" + # only has limit + if select._limit: + off = select._offset or 0 + s += " FIRST %s " % ( select._limit + off ) else: - return False + s += "" + return s - def do_begin(self , connect ): - cu = connect.cursor() - cu.execute( 'SET LOCK MODE TO WAIT' ) - #cu.execute( 'SET ISOLATION TO REPEATABLE READ' ) + def visit_select(self, select): + if select._offset: + self.offset = select._offset + self.limit = select._limit or 0 + # the column in order by clause must in select too + + def __label( c ): + try: + return c._label.lower() + except: + return '' - def type_descriptor(self, typeobj): - return sqltypes.adapt_type(typeobj, colspecs) + # TODO: dont modify the original select, generate a new one + a = [ __label(c) for c in select._raw_columns ] + for c in select._order_by_clause.clauses: + if ( __label(c) not in a ): + select.append_column( c ) + + return compiler.SQLCompiler.visit_select(self, select) + + def limit_clause(self, select): + return "" - def create_connect_args(self, url): - if url.host: - dsn = '%s@%s' % ( url.database , url.host ) + def visit_function( self , func ): + if func.name.lower() == 'current_date': + return "today" + elif func.name.lower() == 'current_time': + return "CURRENT HOUR TO SECOND" + elif func.name.lower() in ( 'current_timestamp' , 'now' ): + return "CURRENT YEAR TO SECOND" else: - dsn = url.database + return compiler.SQLCompiler.visit_function( self , func ) - if url.username: - opt = { 'user':url.username , 'password': url.password } + def visit_clauselist(self, list, **kwargs): + return ', '.join([s for s in [self.process(c) for c in list.clauses] if s is not None]) + +class InfoDDLCompiler(compiler.DDLCompiler): + def get_column_specification(self, column, first_pk=False): + colspec = self.preparer.format_column(column) + if column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and \ + isinstance(column.type, sqltypes.Integer) and not getattr( self , 'has_serial' , False ) and first_pk: + colspec += " SERIAL" + self.has_serial = True else: - opt = {} + colspec += " " + self.dialect.type_compiler.process(column.type) + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + + if not column.nullable: + colspec += " NOT NULL" + + return colspec + + def post_create_table(self, table): + if hasattr( self , 'has_serial' ): + del self.has_serial + return '' + +class InfoIdentifierPreparer(compiler.IdentifierPreparer): + def __init__(self, dialect): + super(InfoIdentifierPreparer, self).__init__(dialect, initial_quote="'") + + def format_constraint(self, constraint): + # informix doesnt support names for constraints + return '' + + def _requires_quotes(self, value): + return False + +class InformixDialect(default.DefaultDialect): + name = 'informix' + # for informix 7.31 + max_identifier_length = 18 + type_compiler = InfoTypeCompiler + poolclass = pool.SingletonThreadPool + statement_compiler = InfoSQLCompiler + ddl_compiler = InfoDDLCompiler + preparer = InfoIdentifierPreparer + colspecs = colspecs + ischema_names = ischema_names - return ([dsn], opt) + def do_begin(self , connect ): + cu = connect.cursor() + cu.execute( 'SET LOCK MODE TO WAIT' ) + #cu.execute( 'SET ISOLATION TO REPEATABLE READ' ) def table_names(self, connection, schema): s = "select tabname from systables" @@ -352,142 +332,3 @@ class InfoDialect(default.DefaultDialect): for cons_name, cons_type, local_column in rows: table.primary_key.add( table.c[local_column] ) -class InfoCompiler(compiler.DefaultCompiler): - """Info compiler modifies the lexical structure of Select statements to work under - non-ANSI configured Oracle databases, if the use_ansi flag is False.""" - - def __init__(self, *args, **kwargs): - self.limit = 0 - self.offset = 0 - - compiler.DefaultCompiler.__init__( self , *args, **kwargs ) - - def default_from(self): - return " from systables where tabname = 'systables' " - - def get_select_precolumns( self , select ): - s = select._distinct and "DISTINCT " or "" - # only has limit - if select._limit: - off = select._offset or 0 - s += " FIRST %s " % ( select._limit + off ) - else: - s += "" - return s - - def visit_select(self, select): - if select._offset: - self.offset = select._offset - self.limit = select._limit or 0 - # the column in order by clause must in select too - - def __label( c ): - try: - return c._label.lower() - except: - return '' - - # TODO: dont modify the original select, generate a new one - a = [ __label(c) for c in select._raw_columns ] - for c in select._order_by_clause.clauses: - if ( __label(c) not in a ): - select.append_column( c ) - - return compiler.DefaultCompiler.visit_select(self, select) - - def limit_clause(self, select): - return "" - - def visit_function( self , func ): - if func.name.lower() == 'current_date': - return "today" - elif func.name.lower() == 'current_time': - return "CURRENT HOUR TO SECOND" - elif func.name.lower() in ( 'current_timestamp' , 'now' ): - return "CURRENT YEAR TO SECOND" - else: - return compiler.DefaultCompiler.visit_function( self , func ) - - def visit_clauselist(self, list, **kwargs): - return ', '.join([s for s in [self.process(c) for c in list.clauses] if s is not None]) - -class InfoSchemaGenerator(compiler.SchemaGenerator): - def get_column_specification(self, column, first_pk=False): - colspec = self.preparer.format_column(column) - if column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and \ - isinstance(column.type, sqltypes.Integer) and not getattr( self , 'has_serial' , False ) and first_pk: - colspec += " SERIAL" - self.has_serial = True - else: - colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec() - default = self.get_column_default_string(column) - if default is not None: - colspec += " DEFAULT " + default - - if not column.nullable: - colspec += " NOT NULL" - - return colspec - - def post_create_table(self, table): - if hasattr( self , 'has_serial' ): - del self.has_serial - return '' - - def visit_primary_key_constraint(self, constraint): - # for informix 7.31 not support constraint name - name = constraint.name - constraint.name = None - super(InfoSchemaGenerator, self).visit_primary_key_constraint(constraint) - constraint.name = name - - def visit_unique_constraint(self, constraint): - # for informix 7.31 not support constraint name - name = constraint.name - constraint.name = None - super(InfoSchemaGenerator, self).visit_unique_constraint(constraint) - constraint.name = name - - def visit_foreign_key_constraint( self , constraint ): - if constraint.name is not None: - constraint.use_alter = True - else: - super( InfoSchemaGenerator , self ).visit_foreign_key_constraint( constraint ) - - def define_foreign_key(self, constraint): - # for informix 7.31 not support constraint name - if constraint.use_alter: - name = constraint.name - constraint.name = None - self.append( "CONSTRAINT " ) - super(InfoSchemaGenerator, self).define_foreign_key(constraint) - constraint.name = name - if name is not None: - self.append( " CONSTRAINT " + name ) - else: - super(InfoSchemaGenerator, self).define_foreign_key(constraint) - - def visit_index(self, index): - if len( index.columns ) == 1 and index.columns[0].foreign_key: - return - super(InfoSchemaGenerator, self).visit_index(index) - -class InfoIdentifierPreparer(compiler.IdentifierPreparer): - def __init__(self, dialect): - super(InfoIdentifierPreparer, self).__init__(dialect, initial_quote="'") - - def _requires_quotes(self, value): - return False - -class InfoSchemaDropper(compiler.SchemaDropper): - def drop_foreignkey(self, constraint): - if constraint.name is not None: - super( InfoSchemaDropper , self ).drop_foreignkey( constraint ) - -dialect = InfoDialect -poolclass = pool.SingletonThreadPool -dialect.statement_compiler = InfoCompiler -dialect.schemagenerator = InfoSchemaGenerator -dialect.schemadropper = InfoSchemaDropper -dialect.preparer = InfoIdentifierPreparer -dialect.execution_ctx_cls = InfoExecutionContext \ No newline at end of file diff --git a/lib/sqlalchemy/dialects/informix/informixdb.py b/lib/sqlalchemy/dialects/informix/informixdb.py new file mode 100644 index 0000000000..4e929e024d --- /dev/null +++ b/lib/sqlalchemy/dialects/informix/informixdb.py @@ -0,0 +1,79 @@ +from sqlalchemy.dialects.informix.base import InformixDialect +from sqlalchemy.engine import default + +# for offset + +class informix_cursor(object): + def __init__( self , con ): + self.__cursor = con.cursor() + self.rowcount = 0 + + def offset( self , n ): + if n > 0: + self.fetchmany( n ) + self.rowcount = self.__cursor.rowcount - n + if self.rowcount < 0: + self.rowcount = 0 + else: + self.rowcount = self.__cursor.rowcount + + def execute( self , sql , params ): + if params is None or len( params ) == 0: + params = [] + + return self.__cursor.execute( sql , params ) + + def __getattr__( self , name ): + if name not in ( 'offset' , '__cursor' , 'rowcount' , '__del__' , 'execute' ): + return getattr( self.__cursor , name ) + + +class InfoExecutionContext(default.DefaultExecutionContext): + # cursor.sqlerrd + # 0 - estimated number of rows returned + # 1 - serial value after insert or ISAM error code + # 2 - number of rows processed + # 3 - estimated cost + # 4 - offset of the error into the SQL statement + # 5 - rowid after insert + def post_exec(self): + if getattr(self.compiled, "isinsert", False) and self.inserted_primary_key is None: + self._last_inserted_ids = [self.cursor.sqlerrd[1]] + elif hasattr( self.compiled , 'offset' ): + self.cursor.offset( self.compiled.offset ) + + def create_cursor( self ): + return informix_cursor( self.connection.connection ) + + +class Informix_informixdb(InformixDialect): + driver = 'informixdb' + default_paramstyle = 'qmark' + execution_context_cls = InfoExecutionContext + + @classmethod + def dbapi(cls): + return __import__('informixdb') + + def create_connect_args(self, url): + if url.host: + dsn = '%s@%s' % ( url.database , url.host ) + else: + dsn = url.database + + if url.username: + opt = { 'user':url.username , 'password': url.password } + else: + opt = {} + + return ([dsn], opt) + + + def is_disconnect(self, e): + if isinstance(e, self.dbapi.OperationalError): + return 'closed the connection' in str(e) or 'connection not open' in str(e) + else: + return False + + +dialect = Informix_informixdb \ No newline at end of file diff --git a/lib/sqlalchemy/dialects/maxdb/__init__.py b/lib/sqlalchemy/dialects/maxdb/__init__.py new file mode 100644 index 0000000000..3f12448b79 --- /dev/null +++ b/lib/sqlalchemy/dialects/maxdb/__init__.py @@ -0,0 +1,3 @@ +from sqlalchemy.dialects.maxdb import base, sapdb + +base.dialect = sapdb.dialect \ No newline at end of file diff --git a/lib/sqlalchemy/databases/maxdb.py b/lib/sqlalchemy/dialects/maxdb/base.py similarity index 90% rename from lib/sqlalchemy/databases/maxdb.py rename to lib/sqlalchemy/dialects/maxdb/base.py index 693295054e..1ec95e03b4 100644 --- a/lib/sqlalchemy/databases/maxdb.py +++ b/lib/sqlalchemy/dialects/maxdb/base.py @@ -5,7 +5,7 @@ """Support for the MaxDB database. -TODO: More module docs! MaxDB support is currently experimental. +This dialect is *not* tested on SQLAlchemy 0.6. Overview -------- @@ -65,13 +65,6 @@ from sqlalchemy.engine import base as engine_base, default from sqlalchemy import types as sqltypes -__all__ = [ - 'MaxString', 'MaxUnicode', 'MaxChar', 'MaxText', 'MaxInteger', - 'MaxSmallInteger', 'MaxNumeric', 'MaxFloat', 'MaxTimestamp', - 'MaxDate', 'MaxTime', 'MaxBoolean', 'MaxBlob', - ] - - class _StringType(sqltypes.String): _type = None @@ -79,16 +72,6 @@ class _StringType(sqltypes.String): super(_StringType, self).__init__(length=length, **kw) self.encoding = encoding - def get_col_spec(self): - if self.length is None: - spec = 'LONG' - else: - spec = '%s(%s)' % (self._type, self.length) - - if self.encoding is not None: - spec = ' '.join([spec, self.encoding.upper()]) - return spec - def bind_processor(self, dialect): if self.encoding == 'unicode': return None @@ -156,16 +139,6 @@ class MaxText(_StringType): return spec -class MaxInteger(sqltypes.Integer): - def get_col_spec(self): - return 'INTEGER' - - -class MaxSmallInteger(MaxInteger): - def get_col_spec(self): - return 'SMALLINT' - - class MaxNumeric(sqltypes.Numeric): """The FIXED (also NUMERIC, DECIMAL) data type.""" @@ -177,29 +150,7 @@ class MaxNumeric(sqltypes.Numeric): def bind_processor(self, dialect): return None - def get_col_spec(self): - if self.scale and self.precision: - return 'FIXED(%s, %s)' % (self.precision, self.scale) - elif self.precision: - return 'FIXED(%s)' % self.precision - else: - return 'INTEGER' - - -class MaxFloat(sqltypes.Float): - """The FLOAT data type.""" - - def get_col_spec(self): - if self.precision is None: - return 'FLOAT' - else: - return 'FLOAT(%s)' % (self.precision,) - - class MaxTimestamp(sqltypes.DateTime): - def get_col_spec(self): - return 'TIMESTAMP' - def bind_processor(self, dialect): def process(value): if value is None: @@ -242,9 +193,6 @@ class MaxTimestamp(sqltypes.DateTime): class MaxDate(sqltypes.Date): - def get_col_spec(self): - return 'DATE' - def bind_processor(self, dialect): def process(value): if value is None: @@ -279,9 +227,6 @@ class MaxDate(sqltypes.Date): class MaxTime(sqltypes.Time): - def get_col_spec(self): - return 'TIME' - def bind_processor(self, dialect): def process(value): if value is None: @@ -316,15 +261,7 @@ class MaxTime(sqltypes.Time): return process -class MaxBoolean(sqltypes.Boolean): - def get_col_spec(self): - return 'BOOLEAN' - - class MaxBlob(sqltypes.Binary): - def get_col_spec(self): - return 'LONG BYTE' - def bind_processor(self, dialect): def process(value): if value is None: @@ -341,18 +278,54 @@ class MaxBlob(sqltypes.Binary): return value.read(value.remainingLength()) return process +class MaxDBTypeCompiler(compiler.GenericTypeCompiler): + def _string_spec(self, string_spec, type_): + if type_.length is None: + spec = 'LONG' + else: + spec = '%s(%s)' % (string_spec, type_.length) + + if getattr(type_, 'encoding'): + spec = ' '.join([spec, getattr(type_, 'encoding').upper()]) + return spec + + def visit_text(self, type_): + spec = 'LONG' + if getattr(type_, 'encoding', None): + spec = ' '.join((spec, type_.encoding)) + elif type_.convert_unicode: + spec = ' '.join((spec, 'UNICODE')) + + return spec + def visit_char(self, type_): + return self._string_spec("CHAR", type_) + + def visit_string(self, type_): + return self._string_spec("VARCHAR", type_) + + def visit_binary(self, type_): + return "LONG BYTE" + + def visit_numeric(self, type_): + if type_.scale and type_.precision: + return 'FIXED(%s, %s)' % (type_.precision, type_.scale) + elif type_.precision: + return 'FIXED(%s)' % type_.precision + else: + return 'INTEGER' + + def visit_BOOLEAN(self, type_): + return "BOOLEAN" + colspecs = { - sqltypes.Integer: MaxInteger, - sqltypes.Smallinteger: MaxSmallInteger, sqltypes.Numeric: MaxNumeric, - sqltypes.Float: MaxFloat, sqltypes.DateTime: MaxTimestamp, sqltypes.Date: MaxDate, sqltypes.Time: MaxTime, sqltypes.String: MaxString, + sqltypes.Unicode:MaxUnicode, sqltypes.Binary: MaxBlob, - sqltypes.Boolean: MaxBoolean, sqltypes.Text: MaxText, sqltypes.CHAR: MaxChar, sqltypes.TIMESTAMP: MaxTimestamp, @@ -361,25 +334,25 @@ colspecs = { } ischema_names = { - 'boolean': MaxBoolean, - 'char': MaxChar, - 'character': MaxChar, - 'date': MaxDate, - 'fixed': MaxNumeric, - 'float': MaxFloat, - 'int': MaxInteger, - 'integer': MaxInteger, - 'long binary': MaxBlob, - 'long unicode': MaxText, - 'long': MaxText, - 'long': MaxText, - 'smallint': MaxSmallInteger, - 'time': MaxTime, - 'timestamp': MaxTimestamp, - 'varchar': MaxString, + 'boolean': sqltypes.BOOLEAN, + 'char': sqltypes.CHAR, + 'character': sqltypes.CHAR, + 'date': sqltypes.DATE, + 'fixed': sqltypes.Numeric, + 'float': sqltypes.FLOAT, + 'int': sqltypes.INT, + 'integer': sqltypes.INT, + 'long binary': sqltypes.BLOB, + 'long unicode': sqltypes.Text, + 'long': sqltypes.Text, + 'long': sqltypes.Text, + 'smallint': sqltypes.SmallInteger, + 'time': sqltypes.Time, + 'timestamp': sqltypes.TIMESTAMP, + 'varchar': sqltypes.VARCHAR, } - +# TODO: migrate this to sapdb.py class MaxDBExecutionContext(default.DefaultExecutionContext): def post_exec(self): # DB-API bug: if there were any functions as values, @@ -421,6 +394,12 @@ class MaxDBExecutionContext(default.DefaultExecutionContext): return MaxDBResultProxy(self) return engine_base.ResultProxy(self) + @property + def rowcount(self): + if hasattr(self, '_rowcount'): + return self._rowcount + else: + return self.cursor.rowcount class MaxDBCachedColumnRow(engine_base.RowProxy): """A RowProxy that only runs result_processors once per column.""" @@ -454,272 +433,17 @@ class MaxDBCachedColumnRow(engine_base.RowProxy): else: return self._get_col(key) - def __getattr__(self, name): - try: - return self._get_col(name) - except KeyError: - raise AttributeError(name) - - -class MaxDBResultProxy(engine_base.ResultProxy): - _process_row = MaxDBCachedColumnRow - - -class MaxDBDialect(default.DefaultDialect): - name = 'maxdb' - supports_alter = True - supports_unicode_statements = True - max_identifier_length = 32 - supports_sane_rowcount = True - supports_sane_multi_rowcount = False - preexecute_pk_sequences = True - - # MaxDB-specific - datetimeformat = 'internal' - - def __init__(self, _raise_known_sql_errors=False, **kw): - super(MaxDBDialect, self).__init__(**kw) - self._raise_known = _raise_known_sql_errors - - if self.dbapi is None: - self.dbapi_type_map = {} - else: - self.dbapi_type_map = { - 'Long Binary': MaxBlob(), - 'Long byte_t': MaxBlob(), - 'Long Unicode': MaxText(), - 'Timestamp': MaxTimestamp(), - 'Date': MaxDate(), - 'Time': MaxTime(), - datetime.datetime: MaxTimestamp(), - datetime.date: MaxDate(), - datetime.time: MaxTime(), - } - - def dbapi(cls): - from sapdb import dbapi as _dbapi - return _dbapi - dbapi = classmethod(dbapi) - - def create_connect_args(self, url): - opts = url.translate_connect_args(username='user') - opts.update(url.query) - return [], opts - - def type_descriptor(self, typeobj): - if isinstance(typeobj, type): - typeobj = typeobj() - if isinstance(typeobj, sqltypes.Unicode): - return typeobj.adapt(MaxUnicode) - else: - return sqltypes.adapt_type(typeobj, colspecs) - - def do_execute(self, cursor, statement, parameters, context=None): - res = cursor.execute(statement, parameters) - if isinstance(res, int) and context is not None: - context._rowcount = res - - def do_release_savepoint(self, connection, name): - # Does MaxDB truly support RELEASE SAVEPOINT ? All my attempts - # produce "SUBTRANS COMMIT/ROLLBACK not allowed without SUBTRANS - # BEGIN SQLSTATE: I7065" - # Note that ROLLBACK TO works fine. In theory, a RELEASE should - # just free up some transactional resources early, before the overall - # COMMIT/ROLLBACK so omitting it should be relatively ok. - pass - - def get_default_schema_name(self, connection): - try: - return self._default_schema_name - except AttributeError: - name = self.identifier_preparer._normalize_name( - connection.execute('SELECT CURRENT_SCHEMA FROM DUAL').scalar()) - self._default_schema_name = name - return name - - def has_table(self, connection, table_name, schema=None): - denormalize = self.identifier_preparer._denormalize_name - bind = [denormalize(table_name)] - if schema is None: - sql = ("SELECT tablename FROM TABLES " - "WHERE TABLES.TABLENAME=? AND" - " TABLES.SCHEMANAME=CURRENT_SCHEMA ") - else: - sql = ("SELECT tablename FROM TABLES " - "WHERE TABLES.TABLENAME = ? AND" - " TABLES.SCHEMANAME=? ") - bind.append(denormalize(schema)) - - rp = connection.execute(sql, bind) - found = bool(rp.fetchone()) - rp.close() - return found - - def table_names(self, connection, schema): - if schema is None: - sql = (" SELECT TABLENAME FROM TABLES WHERE " - " SCHEMANAME=CURRENT_SCHEMA ") - rs = connection.execute(sql) - else: - sql = (" SELECT TABLENAME FROM TABLES WHERE " - " SCHEMANAME=? ") - matchname = self.identifier_preparer._denormalize_name(schema) - rs = connection.execute(sql, matchname) - normalize = self.identifier_preparer._normalize_name - return [normalize(row[0]) for row in rs] - - def reflecttable(self, connection, table, include_columns): - denormalize = self.identifier_preparer._denormalize_name - normalize = self.identifier_preparer._normalize_name - - st = ('SELECT COLUMNNAME, MODE, DATATYPE, CODETYPE, LEN, DEC, ' - ' NULLABLE, "DEFAULT", DEFAULTFUNCTION ' - 'FROM COLUMNS ' - 'WHERE TABLENAME=? AND SCHEMANAME=%s ' - 'ORDER BY POS') - - fk = ('SELECT COLUMNNAME, FKEYNAME, ' - ' REFSCHEMANAME, REFTABLENAME, REFCOLUMNNAME, RULE, ' - ' (CASE WHEN REFSCHEMANAME = CURRENT_SCHEMA ' - ' THEN 1 ELSE 0 END) AS in_schema ' - 'FROM FOREIGNKEYCOLUMNS ' - 'WHERE TABLENAME=? AND SCHEMANAME=%s ' - 'ORDER BY FKEYNAME ') - - params = [denormalize(table.name)] - if not table.schema: - st = st % 'CURRENT_SCHEMA' - fk = fk % 'CURRENT_SCHEMA' - else: - st = st % '?' - fk = fk % '?' - params.append(denormalize(table.schema)) - - rows = connection.execute(st, params).fetchall() - if not rows: - raise exc.NoSuchTableError(table.fullname) - - include_columns = set(include_columns or []) - - for row in rows: - (name, mode, col_type, encoding, length, scale, - nullable, constant_def, func_def) = row - - name = normalize(name) - - if include_columns and name not in include_columns: - continue - - type_args, type_kw = [], {} - if col_type == 'FIXED': - type_args = length, scale - # Convert FIXED(10) DEFAULT SERIAL to our Integer - if (scale == 0 and - func_def is not None and func_def.startswith('SERIAL')): - col_type = 'INTEGER' - type_args = length, - elif col_type in 'FLOAT': - type_args = length, - elif col_type in ('CHAR', 'VARCHAR'): - type_args = length, - type_kw['encoding'] = encoding - elif col_type == 'LONG': - type_kw['encoding'] = encoding - - try: - type_cls = ischema_names[col_type.lower()] - type_instance = type_cls(*type_args, **type_kw) - except KeyError: - util.warn("Did not recognize type '%s' of column '%s'" % - (col_type, name)) - type_instance = sqltypes.NullType - - col_kw = {'autoincrement': False} - col_kw['nullable'] = (nullable == 'YES') - col_kw['primary_key'] = (mode == 'KEY') - - if func_def is not None: - if func_def.startswith('SERIAL'): - if col_kw['primary_key']: - # No special default- let the standard autoincrement - # support handle SERIAL pk columns. - col_kw['autoincrement'] = True - else: - # strip current numbering - col_kw['server_default'] = schema.DefaultClause( - sql.text('SERIAL')) - col_kw['autoincrement'] = True - else: - col_kw['server_default'] = schema.DefaultClause( - sql.text(func_def)) - elif constant_def is not None: - col_kw['server_default'] = schema.DefaultClause(sql.text( - "'%s'" % constant_def.replace("'", "''"))) - - table.append_column(schema.Column(name, type_instance, **col_kw)) - - fk_sets = itertools.groupby(connection.execute(fk, params), - lambda row: row.FKEYNAME) - for fkeyname, fkey in fk_sets: - fkey = list(fkey) - if include_columns: - key_cols = set([r.COLUMNNAME for r in fkey]) - if key_cols != include_columns: - continue - - columns, referants = [], [] - quote = self.identifier_preparer._maybe_quote_identifier - - for row in fkey: - columns.append(normalize(row.COLUMNNAME)) - if table.schema or not row.in_schema: - referants.append('.'.join( - [quote(normalize(row[c])) - for c in ('REFSCHEMANAME', 'REFTABLENAME', - 'REFCOLUMNNAME')])) - else: - referants.append('.'.join( - [quote(normalize(row[c])) - for c in ('REFTABLENAME', 'REFCOLUMNNAME')])) - - constraint_kw = {'name': fkeyname.lower()} - if fkey[0].RULE is not None: - rule = fkey[0].RULE - if rule.startswith('DELETE '): - rule = rule[7:] - constraint_kw['ondelete'] = rule - - table_kw = {} - if table.schema or not row.in_schema: - table_kw['schema'] = normalize(fkey[0].REFSCHEMANAME) - - ref_key = schema._get_table_key(normalize(fkey[0].REFTABLENAME), - table_kw.get('schema')) - if ref_key not in table.metadata.tables: - schema.Table(normalize(fkey[0].REFTABLENAME), - table.metadata, - autoload=True, autoload_with=connection, - **table_kw) - - constraint = schema.ForeignKeyConstraint(columns, referants, link_to_name=True, - **constraint_kw) - table.append_constraint(constraint) - - def has_sequence(self, connection, name): - # [ticket:726] makes this schema-aware. - denormalize = self.identifier_preparer._denormalize_name - sql = ("SELECT sequence_name FROM SEQUENCES " - "WHERE SEQUENCE_NAME=? ") + def __getattr__(self, name): + try: + return self._get_col(name) + except KeyError: + raise AttributeError(name) - rp = connection.execute(sql, denormalize(name)) - found = bool(rp.fetchone()) - rp.close() - return found +class MaxDBResultProxy(engine_base.ResultProxy): + _process_row = MaxDBCachedColumnRow -class MaxDBCompiler(compiler.DefaultCompiler): - operators = compiler.DefaultCompiler.operators.copy() - operators[sql_operators.mod] = lambda x, y: 'mod(%s, %s)' % (x, y) +class MaxDBCompiler(compiler.SQLCompiler): function_conversion = { 'CURRENT_DATE': 'DATE', @@ -734,6 +458,9 @@ class MaxDBCompiler(compiler.DefaultCompiler): 'TIMEZONE', 'TRANSACTION', 'TRUE', 'USER', 'UID', 'USERGROUP', 'UTCDATE', 'UTCDIFF']) + def visit_mod(self, binary, **kw): + return "mod(%s, %s)" % (self.process(binary.left), self.process(binary.right)) + def default_from(self): return ' FROM DUAL' @@ -756,11 +483,13 @@ class MaxDBCompiler(compiler.DefaultCompiler): else: return " WITH LOCK EXCLUSIVE" - def apply_function_parens(self, func): - if func.name.upper() in self.bare_functions: - return len(func.clauses) > 0 + def function_argspec(self, fn, **kw): + if fn.name.upper() in self.bare_functions: + return "" + elif len(fn.clauses) > 0: + return compiler.SQLCompiler.function_argspec(self, fn, **kw) else: - return True + return "" def visit_function(self, fn, **kw): transform = self.function_conversion.get(fn.name.upper(), None) @@ -947,10 +676,10 @@ class MaxDBIdentifierPreparer(compiler.IdentifierPreparer): return name -class MaxDBSchemaGenerator(compiler.SchemaGenerator): +class MaxDBDDLCompiler(compiler.DDLCompiler): def get_column_specification(self, column, **kw): colspec = [self.preparer.format_column(column), - column.type.dialect_impl(self.dialect).get_col_spec()] + self.dialect.type_compiler.process(column.type)] if not column.nullable: colspec.append('NOT NULL') @@ -996,7 +725,7 @@ class MaxDBSchemaGenerator(compiler.SchemaGenerator): else: return None - def visit_sequence(self, sequence): + def visit_create_sequence(self, create): """Creates a SEQUENCE. TODO: move to module doc? @@ -1024,7 +753,8 @@ class MaxDBSchemaGenerator(compiler.SchemaGenerator): maxdb_no_cache Defaults to False. If true, sets NOCACHE. """ - + sequence = create.element + if (not sequence.optional and (not self.checkfirst or not self.dialect.has_sequence(self.connection, sequence.name))): @@ -1061,18 +791,250 @@ class MaxDBSchemaGenerator(compiler.SchemaGenerator): elif opts.get('no_cache', False): ddl.append('NOCACHE') - self.append(' '.join(ddl)) - self.execute() + return ' '.join(ddl) -class MaxDBSchemaDropper(compiler.SchemaDropper): - def visit_sequence(self, sequence): - if (not sequence.optional and - (not self.checkfirst or - self.dialect.has_sequence(self.connection, sequence.name))): - self.append("DROP SEQUENCE %s" % - self.preparer.format_sequence(sequence)) - self.execute() +class MaxDBDialect(default.DefaultDialect): + name = 'maxdb' + supports_alter = True + supports_unicode_statements = True + max_identifier_length = 32 + supports_sane_rowcount = True + supports_sane_multi_rowcount = False + + preparer = MaxDBIdentifierPreparer + statement_compiler = MaxDBCompiler + ddl_compiler = MaxDBDDLCompiler + defaultrunner = MaxDBDefaultRunner + execution_ctx_cls = MaxDBExecutionContext + + colspecs = colspecs + ischema_names = ischema_names + + # MaxDB-specific + datetimeformat = 'internal' + + def __init__(self, _raise_known_sql_errors=False, **kw): + super(MaxDBDialect, self).__init__(**kw) + self._raise_known = _raise_known_sql_errors + + if self.dbapi is None: + self.dbapi_type_map = {} + else: + self.dbapi_type_map = { + 'Long Binary': MaxBlob(), + 'Long byte_t': MaxBlob(), + 'Long Unicode': MaxText(), + 'Timestamp': MaxTimestamp(), + 'Date': MaxDate(), + 'Time': MaxTime(), + datetime.datetime: MaxTimestamp(), + datetime.date: MaxDate(), + datetime.time: MaxTime(), + } + + def do_execute(self, cursor, statement, parameters, context=None): + res = cursor.execute(statement, parameters) + if isinstance(res, int) and context is not None: + context._rowcount = res + + def do_release_savepoint(self, connection, name): + # Does MaxDB truly support RELEASE SAVEPOINT ? All my attempts + # produce "SUBTRANS COMMIT/ROLLBACK not allowed without SUBTRANS + # BEGIN SQLSTATE: I7065" + # Note that ROLLBACK TO works fine. In theory, a RELEASE should + # just free up some transactional resources early, before the overall + # COMMIT/ROLLBACK so omitting it should be relatively ok. + pass + + def get_default_schema_name(self, connection): + try: + return self._default_schema_name + except AttributeError: + name = self.identifier_preparer._normalize_name( + connection.execute('SELECT CURRENT_SCHEMA FROM DUAL').scalar()) + self._default_schema_name = name + return name + + def has_table(self, connection, table_name, schema=None): + denormalize = self.identifier_preparer._denormalize_name + bind = [denormalize(table_name)] + if schema is None: + sql = ("SELECT tablename FROM TABLES " + "WHERE TABLES.TABLENAME=? AND" + " TABLES.SCHEMANAME=CURRENT_SCHEMA ") + else: + sql = ("SELECT tablename FROM TABLES " + "WHERE TABLES.TABLENAME = ? AND" + " TABLES.SCHEMANAME=? ") + bind.append(denormalize(schema)) + + rp = connection.execute(sql, bind) + found = bool(rp.fetchone()) + rp.close() + return found + + def table_names(self, connection, schema): + if schema is None: + sql = (" SELECT TABLENAME FROM TABLES WHERE " + " SCHEMANAME=CURRENT_SCHEMA ") + rs = connection.execute(sql) + else: + sql = (" SELECT TABLENAME FROM TABLES WHERE " + " SCHEMANAME=? ") + matchname = self.identifier_preparer._denormalize_name(schema) + rs = connection.execute(sql, matchname) + normalize = self.identifier_preparer._normalize_name + return [normalize(row[0]) for row in rs] + + def reflecttable(self, connection, table, include_columns): + denormalize = self.identifier_preparer._denormalize_name + normalize = self.identifier_preparer._normalize_name + + st = ('SELECT COLUMNNAME, MODE, DATATYPE, CODETYPE, LEN, DEC, ' + ' NULLABLE, "DEFAULT", DEFAULTFUNCTION ' + 'FROM COLUMNS ' + 'WHERE TABLENAME=? AND SCHEMANAME=%s ' + 'ORDER BY POS') + + fk = ('SELECT COLUMNNAME, FKEYNAME, ' + ' REFSCHEMANAME, REFTABLENAME, REFCOLUMNNAME, RULE, ' + ' (CASE WHEN REFSCHEMANAME = CURRENT_SCHEMA ' + ' THEN 1 ELSE 0 END) AS in_schema ' + 'FROM FOREIGNKEYCOLUMNS ' + 'WHERE TABLENAME=? AND SCHEMANAME=%s ' + 'ORDER BY FKEYNAME ') + + params = [denormalize(table.name)] + if not table.schema: + st = st % 'CURRENT_SCHEMA' + fk = fk % 'CURRENT_SCHEMA' + else: + st = st % '?' + fk = fk % '?' + params.append(denormalize(table.schema)) + + rows = connection.execute(st, params).fetchall() + if not rows: + raise exc.NoSuchTableError(table.fullname) + + include_columns = set(include_columns or []) + + for row in rows: + (name, mode, col_type, encoding, length, scale, + nullable, constant_def, func_def) = row + + name = normalize(name) + + if include_columns and name not in include_columns: + continue + + type_args, type_kw = [], {} + if col_type == 'FIXED': + type_args = length, scale + # Convert FIXED(10) DEFAULT SERIAL to our Integer + if (scale == 0 and + func_def is not None and func_def.startswith('SERIAL')): + col_type = 'INTEGER' + type_args = length, + elif col_type in 'FLOAT': + type_args = length, + elif col_type in ('CHAR', 'VARCHAR'): + type_args = length, + type_kw['encoding'] = encoding + elif col_type == 'LONG': + type_kw['encoding'] = encoding + + try: + type_cls = ischema_names[col_type.lower()] + type_instance = type_cls(*type_args, **type_kw) + except KeyError: + util.warn("Did not recognize type '%s' of column '%s'" % + (col_type, name)) + type_instance = sqltypes.NullType + + col_kw = {'autoincrement': False} + col_kw['nullable'] = (nullable == 'YES') + col_kw['primary_key'] = (mode == 'KEY') + + if func_def is not None: + if func_def.startswith('SERIAL'): + if col_kw['primary_key']: + # No special default- let the standard autoincrement + # support handle SERIAL pk columns. + col_kw['autoincrement'] = True + else: + # strip current numbering + col_kw['server_default'] = schema.DefaultClause( + sql.text('SERIAL')) + col_kw['autoincrement'] = True + else: + col_kw['server_default'] = schema.DefaultClause( + sql.text(func_def)) + elif constant_def is not None: + col_kw['server_default'] = schema.DefaultClause(sql.text( + "'%s'" % constant_def.replace("'", "''"))) + + table.append_column(schema.Column(name, type_instance, **col_kw)) + + fk_sets = itertools.groupby(connection.execute(fk, params), + lambda row: row.FKEYNAME) + for fkeyname, fkey in fk_sets: + fkey = list(fkey) + if include_columns: + key_cols = set([r.COLUMNNAME for r in fkey]) + if key_cols != include_columns: + continue + + columns, referants = [], [] + quote = self.identifier_preparer._maybe_quote_identifier + + for row in fkey: + columns.append(normalize(row.COLUMNNAME)) + if table.schema or not row.in_schema: + referants.append('.'.join( + [quote(normalize(row[c])) + for c in ('REFSCHEMANAME', 'REFTABLENAME', + 'REFCOLUMNNAME')])) + else: + referants.append('.'.join( + [quote(normalize(row[c])) + for c in ('REFTABLENAME', 'REFCOLUMNNAME')])) + + constraint_kw = {'name': fkeyname.lower()} + if fkey[0].RULE is not None: + rule = fkey[0].RULE + if rule.startswith('DELETE '): + rule = rule[7:] + constraint_kw['ondelete'] = rule + + table_kw = {} + if table.schema or not row.in_schema: + table_kw['schema'] = normalize(fkey[0].REFSCHEMANAME) + + ref_key = schema._get_table_key(normalize(fkey[0].REFTABLENAME), + table_kw.get('schema')) + if ref_key not in table.metadata.tables: + schema.Table(normalize(fkey[0].REFTABLENAME), + table.metadata, + autoload=True, autoload_with=connection, + **table_kw) + + constraint = schema.ForeignKeyConstraint(columns, referants, link_to_name=True, + **constraint_kw) + table.append_constraint(constraint) + + def has_sequence(self, connection, name): + # [ticket:726] makes this schema-aware. + denormalize = self.identifier_preparer._denormalize_name + sql = ("SELECT sequence_name FROM SEQUENCES " + "WHERE SEQUENCE_NAME=? ") + + rp = connection.execute(sql, denormalize(name)) + found = bool(rp.fetchone()) + rp.close() + return found + def _autoserial_column(table): @@ -1090,10 +1052,3 @@ def _autoserial_column(table): return None, None -dialect = MaxDBDialect -dialect.preparer = MaxDBIdentifierPreparer -dialect.statement_compiler = MaxDBCompiler -dialect.schemagenerator = MaxDBSchemaGenerator -dialect.schemadropper = MaxDBSchemaDropper -dialect.defaultrunner = MaxDBDefaultRunner -dialect.execution_ctx_cls = MaxDBExecutionContext \ No newline at end of file diff --git a/lib/sqlalchemy/dialects/maxdb/sapdb.py b/lib/sqlalchemy/dialects/maxdb/sapdb.py new file mode 100644 index 0000000000..10e61228e9 --- /dev/null +++ b/lib/sqlalchemy/dialects/maxdb/sapdb.py @@ -0,0 +1,17 @@ +from sqlalchemy.dialects.maxdb.base import MaxDBDialect + +class MaxDB_sapdb(MaxDBDialect): + driver = 'sapdb' + + @classmethod + def dbapi(cls): + from sapdb import dbapi as _dbapi + return _dbapi + + def create_connect_args(self, url): + opts = url.translate_connect_args(username='user') + opts.update(url.query) + return [], opts + + +dialect = MaxDB_sapdb \ No newline at end of file diff --git a/lib/sqlalchemy/dialects/mssql/__init__.py b/lib/sqlalchemy/dialects/mssql/__init__.py new file mode 100644 index 0000000000..e3a829047c --- /dev/null +++ b/lib/sqlalchemy/dialects/mssql/__init__.py @@ -0,0 +1,3 @@ +from sqlalchemy.dialects.mssql import base, pyodbc, adodbapi, pymssql + +base.dialect = pyodbc.dialect \ No newline at end of file diff --git a/lib/sqlalchemy/dialects/mssql/adodbapi.py b/lib/sqlalchemy/dialects/mssql/adodbapi.py new file mode 100644 index 0000000000..10b8b33b30 --- /dev/null +++ b/lib/sqlalchemy/dialects/mssql/adodbapi.py @@ -0,0 +1,51 @@ +from sqlalchemy import types as sqltypes +from sqlalchemy.dialects.mssql.base import MSDateTime, MSDialect +import sys + +class MSDateTime_adodbapi(MSDateTime): + def result_processor(self, dialect): + def process(value): + # adodbapi will return datetimes with empty time values as datetime.date() objects. + # Promote them back to full datetime.datetime() + if type(value) is datetime.date: + return datetime.datetime(value.year, value.month, value.day) + return value + return process + + +class MSDialect_adodbapi(MSDialect): + supports_sane_rowcount = True + supports_sane_multi_rowcount = True + supports_unicode = sys.maxunicode == 65535 + supports_unicode_statements = True + driver = 'adodbapi' + + @classmethod + def import_dbapi(cls): + import adodbapi as module + return module + + colspecs = MSDialect.colspecs.copy() + colspecs[sqltypes.DateTime] = MSDateTime_adodbapi + + def create_connect_args(self, url): + keys = url.query + + connectors = ["Provider=SQLOLEDB"] + if 'port' in keys: + connectors.append ("Data Source=%s, %s" % (keys.get("host"), keys.get("port"))) + else: + connectors.append ("Data Source=%s" % keys.get("host")) + connectors.append ("Initial Catalog=%s" % keys.get("database")) + user = keys.get("user") + if user: + connectors.append("User Id=%s" % user) + connectors.append("Password=%s" % keys.get("password", "")) + else: + connectors.append("Integrated Security=SSPI") + return [[";".join (connectors)], {}] + + def is_disconnect(self, e): + return isinstance(e, self.dbapi.adodbapi.DatabaseError) and "'connection failure'" in str(e) + +dialect = MSDialect_adodbapi diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py new file mode 100644 index 0000000000..cd031af401 --- /dev/null +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -0,0 +1,1448 @@ +# mssql.py + +"""Support for the Microsoft SQL Server database. + +Driver +------ + +The MSSQL dialect will work with three different available drivers: + +* *pyodbc* - http://pyodbc.sourceforge.net/. This is the recommeded + driver. + +* *pymssql* - http://pymssql.sourceforge.net/ + +* *adodbapi* - http://adodbapi.sourceforge.net/ + +Drivers are loaded in the order listed above based on availability. + +If you need to load a specific driver pass ``module_name`` when +creating the engine:: + + engine = create_engine('mssql+module_name://dsn') + +``module_name`` currently accepts: ``pyodbc``, ``pymssql``, and +``adodbapi``. + +Currently the pyodbc driver offers the greatest level of +compatibility. + +Connecting +---------- + +Connecting with create_engine() uses the standard URL approach of +``mssql://user:pass@host/dbname[?key=value&key=value...]``. + +If the database name is present, the tokens are converted to a +connection string with the specified values. If the database is not +present, then the host token is taken directly as the DSN name. + +Examples of pyodbc connection string URLs: + +* *mssql+pyodbc://mydsn* - connects using the specified DSN named ``mydsn``. + The connection string that is created will appear like:: + + dsn=mydsn;TrustedConnection=Yes + +* *mssql+pyodbc://user:pass@mydsn* - connects using the DSN named + ``mydsn`` passing in the ``UID`` and ``PWD`` information. The + connection string that is created will appear like:: + + dsn=mydsn;UID=user;PWD=pass + +* *mssql+pyodbc://user:pass@mydsn/?LANGUAGE=us_english* - connects + using the DSN named ``mydsn`` passing in the ``UID`` and ``PWD`` + information, plus the additional connection configuration option + ``LANGUAGE``. The connection string that is created will appear + like:: + + dsn=mydsn;UID=user;PWD=pass;LANGUAGE=us_english + +* *mssql+pyodbc://user:pass@host/db* - connects using a connection string + dynamically created that would appear like:: + + DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass + +* *mssql+pyodbc://user:pass@host:123/db* - connects using a connection + string that is dynamically created, which also includes the port + information using the comma syntax. If your connection string + requires the port information to be passed as a ``port`` keyword + see the next example. This will create the following connection + string:: + + DRIVER={SQL Server};Server=host,123;Database=db;UID=user;PWD=pass + +* *mssql+pyodbc://user:pass@host/db?port=123* - connects using a connection + string that is dynamically created that includes the port + information as a separate ``port`` keyword. This will create the + following connection string:: + + DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass;port=123 + +If you require a connection string that is outside the options +presented above, use the ``odbc_connect`` keyword to pass in a +urlencoded connection string. What gets passed in will be urldecoded +and passed directly. + +For example:: + + mssql+pyodbc:///?odbc_connect=dsn%3Dmydsn%3BDatabase%3Ddb + +would create the following connection string:: + + dsn=mydsn;Database=db + +Encoding your connection string can be easily accomplished through +the python shell. For example:: + + >>> import urllib + >>> urllib.quote_plus('dsn=mydsn;Database=db') + 'dsn%3Dmydsn%3BDatabase%3Ddb' + +Additional arguments which may be specified either as query string +arguments on the URL, or as keyword argument to +:func:`~sqlalchemy.create_engine()` are: + +* *query_timeout* - allows you to override the default query timeout. + Defaults to ``None``. This is only supported on pymssql. + +* *use_scope_identity* - allows you to specify that SCOPE_IDENTITY + should be used in place of the non-scoped version @@IDENTITY. + Defaults to True. + +* *max_identifier_length* - allows you to se the maximum length of + identfiers supported by the database. Defaults to 128. For pymssql + the default is 30. + +* *schema_name* - use to set the schema name. Defaults to ``dbo``. + +Auto Increment Behavior +----------------------- + +``IDENTITY`` columns are supported by using SQLAlchemy +``schema.Sequence()`` objects. In other words:: + + Table('test', mss_engine, + Column('id', Integer, + Sequence('blah',100,10), primary_key=True), + Column('name', String(20)) + ).create() + +would yield:: + + CREATE TABLE test ( + id INTEGER NOT NULL IDENTITY(100,10) PRIMARY KEY, + name VARCHAR(20) NULL, + ) + +Note that the ``start`` and ``increment`` values for sequences are +optional and will default to 1,1. + +* Support for ``SET IDENTITY_INSERT ON`` mode (automagic on / off for + ``INSERT`` s) + +* Support for auto-fetching of ``@@IDENTITY/@@SCOPE_IDENTITY()`` on + ``INSERT`` + +Collation Support +----------------- + +MSSQL specific string types support a collation parameter that +creates a column-level specific collation for the column. The +collation parameter accepts a Windows Collation Name or a SQL +Collation Name. Supported types are MSChar, MSNChar, MSString, +MSNVarchar, MSText, and MSNText. For example:: + + Column('login', String(32, collation='Latin1_General_CI_AS')) + +will yield:: + + login VARCHAR(32) COLLATE Latin1_General_CI_AS NULL + +LIMIT/OFFSET Support +-------------------- + +MSSQL has no support for the LIMIT or OFFSET keysowrds. LIMIT is +supported directly through the ``TOP`` Transact SQL keyword:: + + select.limit + +will yield:: + + SELECT TOP n + +If using SQL Server 2005 or above, LIMIT with OFFSET +support is available through the ``ROW_NUMBER OVER`` construct. +For versions below 2005, LIMIT with OFFSET usage will fail. + +Nullability +----------- +MSSQL has support for three levels of column nullability. The default +nullability allows nulls and is explicit in the CREATE TABLE +construct:: + + name VARCHAR(20) NULL + +If ``nullable=None`` is specified then no specification is made. In +other words the database's configured default is used. This will +render:: + + name VARCHAR(20) + +If ``nullable`` is ``True`` or ``False`` then the column will be +``NULL` or ``NOT NULL`` respectively. + +Date / Time Handling +-------------------- +DATE and TIME are supported. Bind parameters are converted +to datetime.datetime() objects as required by most MSSQL drivers, +and results are processed from strings if needed. +The DATE and TIME types are not available for MSSQL 2005 and +previous - if a server version below 2008 is detected, DDL +for these types will be issued as DATETIME. + +Compatibility Levels +-------------------- +MSSQL supports the notion of setting compatibility levels at the +database level. This allows, for instance, to run a database that +is compatibile with SQL2000 while running on a SQL2005 database +server. ``server_version_info`` will always retrun the database +server version information (in this case SQL2005) and not the +compatibiility level information. Because of this, if running under +a backwards compatibility mode SQAlchemy may attempt to use T-SQL +statements that are unable to be parsed by the database server. + +Known Issues +------------ + +* No support for more than one ``IDENTITY`` column per table + +* pymssql has problems with binary and unicode data that this module + does **not** work around + +""" +import datetime, decimal, inspect, operator, sys, re +import itertools + +from sqlalchemy import sql, schema as sa_schema, exc, util +from sqlalchemy.sql import select, compiler, expression, \ + operators as sql_operators, \ + functions as sql_functions, util as sql_util +from sqlalchemy.engine import default, base, reflection +from sqlalchemy import types as sqltypes +from decimal import Decimal as _python_Decimal +from sqlalchemy.types import INTEGER, BIGINT, SMALLINT, DECIMAL, NUMERIC, \ + FLOAT, TIMESTAMP, DATETIME, DATE + + +from sqlalchemy.dialects.mssql import information_schema as ischema + +MS_2008_VERSION = (10,) +MS_2005_VERSION = (9,) +MS_2000_VERSION = (8,) + +RESERVED_WORDS = set( + ['add', 'all', 'alter', 'and', 'any', 'as', 'asc', 'authorization', + 'backup', 'begin', 'between', 'break', 'browse', 'bulk', 'by', 'cascade', + 'case', 'check', 'checkpoint', 'close', 'clustered', 'coalesce', + 'collate', 'column', 'commit', 'compute', 'constraint', 'contains', + 'containstable', 'continue', 'convert', 'create', 'cross', 'current', + 'current_date', 'current_time', 'current_timestamp', 'current_user', + 'cursor', 'database', 'dbcc', 'deallocate', 'declare', 'default', + 'delete', 'deny', 'desc', 'disk', 'distinct', 'distributed', 'double', + 'drop', 'dump', 'else', 'end', 'errlvl', 'escape', 'except', 'exec', + 'execute', 'exists', 'exit', 'external', 'fetch', 'file', 'fillfactor', + 'for', 'foreign', 'freetext', 'freetexttable', 'from', 'full', + 'function', 'goto', 'grant', 'group', 'having', 'holdlock', 'identity', + 'identity_insert', 'identitycol', 'if', 'in', 'index', 'inner', 'insert', + 'intersect', 'into', 'is', 'join', 'key', 'kill', 'left', 'like', + 'lineno', 'load', 'merge', 'national', 'nocheck', 'nonclustered', 'not', + 'null', 'nullif', 'of', 'off', 'offsets', 'on', 'open', 'opendatasource', + 'openquery', 'openrowset', 'openxml', 'option', 'or', 'order', 'outer', + 'over', 'percent', 'pivot', 'plan', 'precision', 'primary', 'print', + 'proc', 'procedure', 'public', 'raiserror', 'read', 'readtext', + 'reconfigure', 'references', 'replication', 'restore', 'restrict', + 'return', 'revert', 'revoke', 'right', 'rollback', 'rowcount', + 'rowguidcol', 'rule', 'save', 'schema', 'securityaudit', 'select', + 'session_user', 'set', 'setuser', 'shutdown', 'some', 'statistics', + 'system_user', 'table', 'tablesample', 'textsize', 'then', 'to', 'top', + 'tran', 'transaction', 'trigger', 'truncate', 'tsequal', 'union', + 'unique', 'unpivot', 'update', 'updatetext', 'use', 'user', 'values', + 'varying', 'view', 'waitfor', 'when', 'where', 'while', 'with', + 'writetext', + ]) + + +class _MSNumeric(sqltypes.Numeric): + def result_processor(self, dialect): + if self.asdecimal: + def process(value): + if value is not None: + return _python_Decimal(str(value)) + else: + return value + return process + else: + def process(value): + return float(value) + return process + + def bind_processor(self, dialect): + def process(value): + if isinstance(value, decimal.Decimal): + if value.adjusted() < 0: + result = "%s0.%s%s" % ( + (value < 0 and '-' or ''), + '0' * (abs(value.adjusted()) - 1), + "".join([str(nint) for nint in value._int])) + + else: + if 'E' in str(value): + result = "%s%s%s" % ( + (value < 0 and '-' or ''), + "".join([str(s) for s in value._int]), + "0" * (value.adjusted() - (len(value._int)-1))) + else: + if (len(value._int) - 1) > value.adjusted(): + result = "%s%s.%s" % ( + (value < 0 and '-' or ''), + "".join([str(s) for s in value._int][0:value.adjusted() + 1]), + "".join([str(s) for s in value._int][value.adjusted() + 1:])) + else: + result = "%s%s" % ( + (value < 0 and '-' or ''), + "".join([str(s) for s in value._int][0:value.adjusted() + 1])) + + return result + + else: + return value + + return process + +class REAL(sqltypes.Float): + """A type for ``real`` numbers.""" + + __visit_name__ = 'REAL' + + def __init__(self): + super(REAL, self).__init__(precision=24) + +class TINYINT(sqltypes.Integer): + __visit_name__ = 'TINYINT' + + +# MSSQL DATE/TIME types have varied behavior, sometimes returning +# strings. MSDate/TIME check for everything, and always +# filter bind parameters into datetime objects (required by pyodbc, +# not sure about other dialects). + +class _MSDate(sqltypes.Date): + def bind_processor(self, dialect): + def process(value): + if type(value) == datetime.date: + return datetime.datetime(value.year, value.month, value.day) + else: + return value + return process + + _reg = re.compile(r"(\d+)-(\d+)-(\d+)") + def result_processor(self, dialect): + def process(value): + if isinstance(value, datetime.datetime): + return value.date() + elif isinstance(value, basestring): + return datetime.date(*[int(x or 0) for x in self._reg.match(value).groups()]) + else: + return value + return process + +class TIME(sqltypes.TIME): + def __init__(self, precision=None, **kwargs): + self.precision = precision + super(TIME, self).__init__() + + __zero_date = datetime.date(1900, 1, 1) + + def bind_processor(self, dialect): + def process(value): + if isinstance(value, datetime.datetime): + value = datetime.datetime.combine(self.__zero_date, value.time()) + elif isinstance(value, datetime.time): + value = datetime.datetime.combine(self.__zero_date, value) + return value + return process + + _reg = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d+))?") + def result_processor(self, dialect): + def process(value): + if isinstance(value, datetime.datetime): + return value.time() + elif isinstance(value, basestring): + return datetime.time(*[int(x or 0) for x in self._reg.match(value).groups()]) + else: + return value + return process + + +class _DateTimeBase(object): + def bind_processor(self, dialect): + def process(value): + # TODO: why ? + if type(value) == datetime.date: + return datetime.datetime(value.year, value.month, value.day) + else: + return value + return process + +class _MSDateTime(_DateTimeBase, sqltypes.DateTime): + pass + +class SMALLDATETIME(_DateTimeBase, sqltypes.DateTime): + __visit_name__ = 'SMALLDATETIME' + +class DATETIME2(_DateTimeBase, sqltypes.DateTime): + __visit_name__ = 'DATETIME2' + + def __init__(self, precision=None, **kwargs): + self.precision = precision + + +# TODO: is this not an Interval ? +class DATETIMEOFFSET(sqltypes.TypeEngine): + __visit_name__ = 'DATETIMEOFFSET' + + def __init__(self, precision=None, **kwargs): + self.precision = precision + + +class _StringType(object): + """Base for MSSQL string types.""" + + def __init__(self, collation=None): + self.collation = collation + + def __repr__(self): + attributes = inspect.getargspec(self.__init__)[0][1:] + attributes.extend(inspect.getargspec(_StringType.__init__)[0][1:]) + + params = {} + for attr in attributes: + val = getattr(self, attr) + if val is not None and val is not False: + params[attr] = val + + return "%s(%s)" % (self.__class__.__name__, + ', '.join(['%s=%r' % (k, params[k]) for k in params])) + + +class TEXT(_StringType, sqltypes.TEXT): + """MSSQL TEXT type, for variable-length text up to 2^31 characters.""" + + def __init__(self, *args, **kw): + """Construct a TEXT. + + :param collation: Optional, a column-level collation for this string + value. Accepts a Windows Collation Name or a SQL Collation Name. + + """ + collation = kw.pop('collation', None) + _StringType.__init__(self, collation) + sqltypes.Text.__init__(self, *args, **kw) + +class NTEXT(_StringType, sqltypes.UnicodeText): + """MSSQL NTEXT type, for variable-length unicode text up to 2^30 + characters.""" + + __visit_name__ = 'NTEXT' + + def __init__(self, *args, **kwargs): + """Construct a NTEXT. + + :param collation: Optional, a column-level collation for this string + value. Accepts a Windows Collation Name or a SQL Collation Name. + + """ + collation = kwargs.pop('collation', None) + _StringType.__init__(self, collation) + length = kwargs.pop('length', None) + sqltypes.UnicodeText.__init__(self, length, **kwargs) + + +class VARCHAR(_StringType, sqltypes.VARCHAR): + """MSSQL VARCHAR type, for variable-length non-Unicode data with a maximum + of 8,000 characters.""" + + def __init__(self, *args, **kw): + """Construct a VARCHAR. + + :param length: Optinal, maximum data length, in characters. + + :param convert_unicode: defaults to False. If True, convert + ``unicode`` data sent to the database to a ``str`` + bytestring, and convert bytestrings coming back from the + database into ``unicode``. + + Bytestrings are encoded using the dialect's + :attr:`~sqlalchemy.engine.base.Dialect.encoding`, which + defaults to `utf-8`. + + If False, may be overridden by + :attr:`sqlalchemy.engine.base.Dialect.convert_unicode`. + + :param assert_unicode: + + If None (the default), no assertion will take place unless + overridden by :attr:`sqlalchemy.engine.base.Dialect.assert_unicode`. + + If 'warn', will issue a runtime warning if a ``str`` + instance is used as a bind value. + + If true, will raise an :exc:`sqlalchemy.exc.InvalidRequestError`. + + :param collation: Optional, a column-level collation for this string + value. Accepts a Windows Collation Name or a SQL Collation Name. + + """ + collation = kw.pop('collation', None) + _StringType.__init__(self, collation) + sqltypes.VARCHAR.__init__(self, *args, **kw) + +class NVARCHAR(_StringType, sqltypes.NVARCHAR): + """MSSQL NVARCHAR type. + + For variable-length unicode character data up to 4,000 characters.""" + + def __init__(self, *args, **kw): + """Construct a NVARCHAR. + + :param length: Optional, Maximum data length, in characters. + + :param collation: Optional, a column-level collation for this string + value. Accepts a Windows Collation Name or a SQL Collation Name. + + """ + collation = kw.pop('collation', None) + _StringType.__init__(self, collation) + sqltypes.NVARCHAR.__init__(self, *args, **kw) + +class CHAR(_StringType, sqltypes.CHAR): + """MSSQL CHAR type, for fixed-length non-Unicode data with a maximum + of 8,000 characters.""" + + def __init__(self, *args, **kw): + """Construct a CHAR. + + :param length: Optinal, maximum data length, in characters. + + :param convert_unicode: defaults to False. If True, convert + ``unicode`` data sent to the database to a ``str`` + bytestring, and convert bytestrings coming back from the + database into ``unicode``. + + Bytestrings are encoded using the dialect's + :attr:`~sqlalchemy.engine.base.Dialect.encoding`, which + defaults to `utf-8`. + + If False, may be overridden by + :attr:`sqlalchemy.engine.base.Dialect.convert_unicode`. + + :param assert_unicode: + + If None (the default), no assertion will take place unless + overridden by :attr:`sqlalchemy.engine.base.Dialect.assert_unicode`. + + If 'warn', will issue a runtime warning if a ``str`` + instance is used as a bind value. + + If true, will raise an :exc:`sqlalchemy.exc.InvalidRequestError`. + + :param collation: Optional, a column-level collation for this string + value. Accepts a Windows Collation Name or a SQL Collation Name. + + """ + collation = kw.pop('collation', None) + _StringType.__init__(self, collation) + sqltypes.CHAR.__init__(self, *args, **kw) + +class NCHAR(_StringType, sqltypes.NCHAR): + """MSSQL NCHAR type. + + For fixed-length unicode character data up to 4,000 characters.""" + + def __init__(self, *args, **kw): + """Construct an NCHAR. + + :param length: Optional, Maximum data length, in characters. + + :param collation: Optional, a column-level collation for this string + value. Accepts a Windows Collation Name or a SQL Collation Name. + + """ + collation = kw.pop('collation', None) + _StringType.__init__(self, collation) + sqltypes.NCHAR.__init__(self, *args, **kw) + +class BINARY(sqltypes.Binary): + __visit_name__ = 'BINARY' + +class VARBINARY(sqltypes.Binary): + __visit_name__ = 'VARBINARY' + +class IMAGE(sqltypes.Binary): + __visit_name__ = 'IMAGE' + +class BIT(sqltypes.TypeEngine): + __visit_name__ = 'BIT' + +class _MSBoolean(sqltypes.Boolean): + def result_processor(self, dialect): + def process(value): + if value is None: + return None + return value and True or False + return process + + def bind_processor(self, dialect): + def process(value): + if value is True: + return 1 + elif value is False: + return 0 + elif value is None: + return None + else: + return value and True or False + return process + +class MONEY(sqltypes.TypeEngine): + __visit_name__ = 'MONEY' + +class SMALLMONEY(sqltypes.TypeEngine): + __visit_name__ = 'SMALLMONEY' + +class UNIQUEIDENTIFIER(sqltypes.TypeEngine): + __visit_name__ = "UNIQUEIDENTIFIER" + +class SQL_VARIANT(sqltypes.TypeEngine): + __visit_name__ = 'SQL_VARIANT' + +# old names. +MSNumeric = _MSNumeric +MSDateTime = _MSDateTime +MSDate = _MSDate +MSBoolean = _MSBoolean +MSReal = REAL +MSTinyInteger = TINYINT +MSTime = TIME +MSSmallDateTime = SMALLDATETIME +MSDateTime2 = DATETIME2 +MSDateTimeOffset = DATETIMEOFFSET +MSText = TEXT +MSNText = NTEXT +MSString = VARCHAR +MSNVarchar = NVARCHAR +MSChar = CHAR +MSNChar = NCHAR +MSBinary = BINARY +MSVarBinary = VARBINARY +MSImage = IMAGE +MSBit = BIT +MSMoney = MONEY +MSSmallMoney = SMALLMONEY +MSUniqueIdentifier = UNIQUEIDENTIFIER +MSVariant = SQL_VARIANT + +colspecs = { + sqltypes.Numeric : _MSNumeric, + sqltypes.DateTime : _MSDateTime, + sqltypes.Date : _MSDate, + sqltypes.Time : TIME, + sqltypes.Boolean : _MSBoolean, +} + +ischema_names = { + 'int' : INTEGER, + 'bigint': BIGINT, + 'smallint' : SMALLINT, + 'tinyint' : TINYINT, + 'varchar' : VARCHAR, + 'nvarchar' : NVARCHAR, + 'char' : CHAR, + 'nchar' : NCHAR, + 'text' : TEXT, + 'ntext' : NTEXT, + 'decimal' : DECIMAL, + 'numeric' : NUMERIC, + 'float' : FLOAT, + 'datetime' : DATETIME, + 'datetime2' : DATETIME2, + 'datetimeoffset' : DATETIMEOFFSET, + 'date': DATE, + 'time': TIME, + 'smalldatetime' : SMALLDATETIME, + 'binary' : BINARY, + 'varbinary' : VARBINARY, + 'bit': BIT, + 'real' : REAL, + 'image' : IMAGE, + 'timestamp': TIMESTAMP, + 'money': MONEY, + 'smallmoney': SMALLMONEY, + 'uniqueidentifier': UNIQUEIDENTIFIER, + 'sql_variant': SQL_VARIANT, +} + + +class MSTypeCompiler(compiler.GenericTypeCompiler): + def _extend(self, spec, type_): + """Extend a string-type declaration with standard SQL + COLLATE annotations. + + """ + + if getattr(type_, 'collation', None): + collation = 'COLLATE %s' % type_.collation + else: + collation = None + + if type_.length: + spec = spec + "(%d)" % type_.length + + return ' '.join([c for c in (spec, collation) + if c is not None]) + + def visit_FLOAT(self, type_): + precision = getattr(type_, 'precision', None) + if precision is None: + return "FLOAT" + else: + return "FLOAT(%(precision)s)" % {'precision': precision} + + def visit_REAL(self, type_): + return "REAL" + + def visit_TINYINT(self, type_): + return "TINYINT" + + def visit_DATETIMEOFFSET(self, type_): + if type_.precision: + return "DATETIMEOFFSET(%s)" % type_.precision + else: + return "DATETIMEOFFSET" + + def visit_TIME(self, type_): + precision = getattr(type_, 'precision', None) + if precision: + return "TIME(%s)" % precision + else: + return "TIME" + + def visit_DATETIME2(self, type_): + precision = getattr(type_, 'precision', None) + if precision: + return "DATETIME2(%s)" % precision + else: + return "DATETIME2" + + def visit_SMALLDATETIME(self, type_): + return "SMALLDATETIME" + + def visit_unicode(self, type_): + return self.visit_NVARCHAR(type_) + + def visit_unicode_text(self, type_): + return self.visit_NTEXT(type_) + + def visit_NTEXT(self, type_): + return self._extend("NTEXT", type_) + + def visit_TEXT(self, type_): + return self._extend("TEXT", type_) + + def visit_VARCHAR(self, type_): + return self._extend("VARCHAR", type_) + + def visit_CHAR(self, type_): + return self._extend("CHAR", type_) + + def visit_NCHAR(self, type_): + return self._extend("NCHAR", type_) + + def visit_NVARCHAR(self, type_): + return self._extend("NVARCHAR", type_) + + def visit_date(self, type_): + if self.dialect.server_version_info < MS_2008_VERSION: + return self.visit_DATETIME(type_) + else: + return self.visit_DATE(type_) + + def visit_time(self, type_): + if self.dialect.server_version_info < MS_2008_VERSION: + return self.visit_DATETIME(type_) + else: + return self.visit_TIME(type_) + + def visit_binary(self, type_): + if type_.length: + return self.visit_BINARY(type_) + else: + return self.visit_IMAGE(type_) + + def visit_BINARY(self, type_): + if type_.length: + return "BINARY(%s)" % type_.length + else: + return "BINARY" + + def visit_IMAGE(self, type_): + return "IMAGE" + + def visit_VARBINARY(self, type_): + if type_.length: + return "VARBINARY(%s)" % type_.length + else: + return "VARBINARY" + + def visit_boolean(self, type_): + return self.visit_BIT(type_) + + def visit_BIT(self, type_): + return "BIT" + + def visit_MONEY(self, type_): + return "MONEY" + + def visit_SMALLMONEY(self, type_): + return 'SMALLMONEY' + + def visit_UNIQUEIDENTIFIER(self, type_): + return "UNIQUEIDENTIFIER" + + def visit_SQL_VARIANT(self, type_): + return 'SQL_VARIANT' + +class MSExecutionContext(default.DefaultExecutionContext): + _enable_identity_insert = False + _select_lastrowid = False + _result_proxy = None + _lastrowid = None + + def pre_exec(self): + """Activate IDENTITY_INSERT if needed.""" + + if self.isinsert: + tbl = self.compiled.statement.table + seq_column = tbl._autoincrement_column + insert_has_sequence = seq_column is not None + + if insert_has_sequence: + self._enable_identity_insert = seq_column.key in self.compiled_parameters[0] + else: + self._enable_identity_insert = False + + self._select_lastrowid = insert_has_sequence and \ + not self.compiled.returning and \ + not self._enable_identity_insert and \ + not self.executemany + + if self._enable_identity_insert: + self.cursor.execute("SET IDENTITY_INSERT %s ON" % + self.dialect.identifier_preparer.format_table(tbl)) + + def post_exec(self): + """Disable IDENTITY_INSERT if enabled.""" + + if self._select_lastrowid: + if self.dialect.use_scope_identity: + self.cursor.execute("SELECT scope_identity() AS lastrowid") + else: + self.cursor.execute("SELECT @@identity AS lastrowid") + row = self.cursor.fetchall()[0] # fetchall() ensures the cursor is consumed without closing it + self._lastrowid = int(row[0]) + + if (self.isinsert or self.isupdate or self.isdelete) and self.compiled.returning: + self._result_proxy = base.FullyBufferedResultProxy(self) + + if self._enable_identity_insert: + self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.dialect.identifier_preparer.format_table(self.compiled.statement.table)) + + def get_lastrowid(self): + return self._lastrowid + + def handle_dbapi_exception(self, e): + if self._enable_identity_insert: + try: + self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.dialect.identifier_preparer.format_table(self.compiled.statement.table)) + except: + pass + + def get_result_proxy(self): + if self._result_proxy: + return self._result_proxy + else: + return base.ResultProxy(self) + +class MSSQLCompiler(compiler.SQLCompiler): + + extract_map = compiler.SQLCompiler.extract_map.copy() + extract_map.update ({ + 'doy': 'dayofyear', + 'dow': 'weekday', + 'milliseconds': 'millisecond', + 'microseconds': 'microsecond' + }) + + def __init__(self, *args, **kwargs): + super(MSSQLCompiler, self).__init__(*args, **kwargs) + self.tablealiases = {} + + def visit_now_func(self, fn, **kw): + return "CURRENT_TIMESTAMP" + + def visit_current_date_func(self, fn, **kw): + return "GETDATE()" + + def visit_length_func(self, fn, **kw): + return "LEN%s" % self.function_argspec(fn, **kw) + + def visit_char_length_func(self, fn, **kw): + return "LEN%s" % self.function_argspec(fn, **kw) + + def visit_concat_op(self, binary): + return "%s + %s" % (self.process(binary.left), self.process(binary.right)) + + def visit_match_op(self, binary): + return "CONTAINS (%s, %s)" % (self.process(binary.left), self.process(binary.right)) + + def get_select_precolumns(self, select): + """ MS-SQL puts TOP, it's version of LIMIT here """ + if select._distinct or select._limit: + s = select._distinct and "DISTINCT " or "" + + if select._limit: + if not select._offset: + s += "TOP %s " % (select._limit,) + return s + return compiler.SQLCompiler.get_select_precolumns(self, select) + + def limit_clause(self, select): + # Limit in mssql is after the select keyword + return "" + + def visit_select(self, select, **kwargs): + """Look for ``LIMIT`` and OFFSET in a select statement, and if + so tries to wrap it in a subquery with ``row_number()`` criterion. + + """ + if not getattr(select, '_mssql_visit', None) and select._offset: + # to use ROW_NUMBER(), an ORDER BY is required. + orderby = self.process(select._order_by_clause) + if not orderby: + raise exc.InvalidRequestError('MSSQL requires an order_by when using an offset.') + + _offset = select._offset + _limit = select._limit + select._mssql_visit = True + select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("mssql_rn")).order_by(None).alias() + + limitselect = sql.select([c for c in select.c if c.key!='mssql_rn']) + limitselect.append_whereclause("mssql_rn>%d" % _offset) + if _limit is not None: + limitselect.append_whereclause("mssql_rn<=%d" % (_limit + _offset)) + return self.process(limitselect, iswrapper=True, **kwargs) + else: + return compiler.SQLCompiler.visit_select(self, select, **kwargs) + + def _schema_aliased_table(self, table): + if getattr(table, 'schema', None) is not None: + if table not in self.tablealiases: + self.tablealiases[table] = table.alias() + return self.tablealiases[table] + else: + return None + + def visit_table(self, table, mssql_aliased=False, **kwargs): + if mssql_aliased: + return super(MSSQLCompiler, self).visit_table(table, **kwargs) + + # alias schema-qualified tables + alias = self._schema_aliased_table(table) + if alias is not None: + return self.process(alias, mssql_aliased=True, **kwargs) + else: + return super(MSSQLCompiler, self).visit_table(table, **kwargs) + + def visit_alias(self, alias, **kwargs): + # translate for schema-qualified table aliases + self.tablealiases[alias.original] = alias + kwargs['mssql_aliased'] = True + return super(MSSQLCompiler, self).visit_alias(alias, **kwargs) + + def visit_extract(self, extract): + field = self.extract_map.get(extract.field, extract.field) + return 'DATEPART("%s", %s)' % (field, self.process(extract.expr)) + + def visit_rollback_to_savepoint(self, savepoint_stmt): + return "ROLLBACK TRANSACTION %s" % self.preparer.format_savepoint(savepoint_stmt) + + def visit_column(self, column, result_map=None, **kwargs): + if column.table is not None and \ + (not self.isupdate and not self.isdelete) or self.is_subquery(): + # translate for schema-qualified table aliases + t = self._schema_aliased_table(column.table) + if t is not None: + converted = expression._corresponding_column_or_error(t, column) + + if result_map is not None: + result_map[column.name.lower()] = (column.name, (column, ), column.type) + + return super(MSSQLCompiler, self).visit_column(converted, result_map=None, **kwargs) + + return super(MSSQLCompiler, self).visit_column(column, result_map=result_map, **kwargs) + + def visit_binary(self, binary, **kwargs): + """Move bind parameters to the right-hand side of an operator, where + possible. + + """ + if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq \ + and not isinstance(binary.right, expression._BindParamClause): + return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator), **kwargs) + else: + if (binary.operator is operator.eq or binary.operator is operator.ne) and ( + (isinstance(binary.left, expression._FromGrouping) and isinstance(binary.left.element, expression._ScalarSelect)) or \ + (isinstance(binary.right, expression._FromGrouping) and isinstance(binary.right.element, expression._ScalarSelect)) or \ + isinstance(binary.left, expression._ScalarSelect) or isinstance(binary.right, expression._ScalarSelect)): + op = binary.operator == operator.eq and "IN" or "NOT IN" + return self.process(expression._BinaryExpression(binary.left, binary.right, op), **kwargs) + return super(MSSQLCompiler, self).visit_binary(binary, **kwargs) + + def returning_clause(self, stmt, returning_cols): + + if self.isinsert or self.isupdate: + target = stmt.table.alias("inserted") + else: + target = stmt.table.alias("deleted") + + adapter = sql_util.ClauseAdapter(target) + def col_label(col): + adapted = adapter.traverse(c) + if isinstance(c, expression._Label): + return adapted.label(c.key) + else: + return self.label_select_column(None, adapted, asfrom=False) + + columns = [ + self.process( + col_label(c), + within_columns_clause=True, + result_map=self.result_map + ) + for c in expression._select_iterables(returning_cols) + ] + return 'OUTPUT ' + ', '.join(columns) + + def label_select_column(self, select, column, asfrom): + if isinstance(column, expression.Function): + return column.label(None) + else: + return super(MSSQLCompiler, self).label_select_column(select, column, asfrom) + + def for_update_clause(self, select): + # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use + return '' + + def order_by_clause(self, select): + order_by = self.process(select._order_by_clause) + + # MSSQL only allows ORDER BY in subqueries if there is a LIMIT + if order_by and (not self.is_subquery() or select._limit): + return " ORDER BY " + order_by + else: + return "" + + +class MSDDLCompiler(compiler.DDLCompiler): + def get_column_specification(self, column, **kwargs): + colspec = self.preparer.format_column(column) + " " + self.dialect.type_compiler.process(column.type) + + if column.nullable is not None: + if not column.nullable or column.primary_key: + colspec += " NOT NULL" + else: + colspec += " NULL" + + if not column.table: + raise exc.InvalidRequestError("mssql requires Table-bound columns in order to generate DDL") + + seq_col = column.table._autoincrement_column + + # install a IDENTITY Sequence if we have an implicit IDENTITY column + if seq_col is column: + sequence = getattr(column, 'sequence', None) + if sequence: + start, increment = sequence.start or 1, sequence.increment or 1 + else: + start, increment = 1, 1 + colspec += " IDENTITY(%s,%s)" % (start, increment) + else: + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + + return colspec + + def visit_drop_index(self, drop): + return "\nDROP INDEX %s.%s" % ( + self.preparer.quote_identifier(drop.element.table.name), + self.preparer.quote(self._validate_identifier(drop.element.name, False), drop.element.quote) + ) + + +class MSIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = RESERVED_WORDS + + def __init__(self, dialect): + super(MSIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']') + + def _escape_identifier(self, value): + #TODO: determine MSSQL's escaping rules + return value + + def quote_schema(self, schema, force=True): + """Prepare a quoted table and schema name.""" + result = '.'.join([self.quote(x, force) for x in schema.split('.')]) + return result + +class MSDialect(default.DefaultDialect): + name = 'mssql' + supports_default_values = True + supports_empty_insert = False + execution_ctx_cls = MSExecutionContext + text_as_varchar = False + use_scope_identity = True + max_identifier_length = 128 + schema_name = "dbo" + colspecs = colspecs + ischema_names = ischema_names + + supports_unicode_binds = True + postfetch_lastrowid = True + + server_version_info = () + + statement_compiler = MSSQLCompiler + ddl_compiler = MSDDLCompiler + type_compiler = MSTypeCompiler + preparer = MSIdentifierPreparer + + def __init__(self, + query_timeout=None, + use_scope_identity=True, + max_identifier_length=None, + schema_name="dbo", **opts): + self.query_timeout = int(query_timeout or 0) + self.schema_name = schema_name + + self.use_scope_identity = use_scope_identity + self.max_identifier_length = int(max_identifier_length or 0) or \ + self.max_identifier_length + super(MSDialect, self).__init__(**opts) + + def do_savepoint(self, connection, name): + util.warn("Savepoint support in mssql is experimental and may lead to data loss.") + connection.execute("IF @@TRANCOUNT = 0 BEGIN TRANSACTION") + connection.execute("SAVE TRANSACTION %s" % name) + + def do_release_savepoint(self, connection, name): + pass + + def initialize(self, connection): + super(MSDialect, self).initialize(connection) + if self.server_version_info >= MS_2005_VERSION and 'implicit_returning' not in self.__dict__: + self.implicit_returning = True + + def get_default_schema_name(self, connection): + return self.default_schema_name + + def _get_default_schema_name(self, connection): + user_name = connection.scalar("SELECT user_name() as user_name;") + if user_name is not None: + # now, get the default schema + query = """ + SELECT default_schema_name FROM + sys.database_principals + WHERE name = ? + AND type = 'S' + """ + try: + default_schema_name = connection.scalar(query, [user_name]) + if default_schema_name is not None: + return default_schema_name + except: + pass + return self.schema_name + + def table_names(self, connection, schema): + s = select([ischema.tables.c.table_name], ischema.tables.c.table_schema==schema) + return [row[0] for row in connection.execute(s)] + + + def has_table(self, connection, tablename, schema=None): + current_schema = schema or self.default_schema_name + columns = ischema.columns + s = sql.select([columns], + current_schema + and sql.and_(columns.c.table_name==tablename, columns.c.table_schema==current_schema) + or columns.c.table_name==tablename, + ) + + c = connection.execute(s) + row = c.fetchone() + return row is not None + + @reflection.cache + def get_schema_names(self, connection, **kw): + s = sql.select([ischema.schemata.c.schema_name], + order_by=[ischema.schemata.c.schema_name] + ) + schema_names = [r[0] for r in connection.execute(s)] + return schema_names + + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + current_schema = schema or self.default_schema_name + tables = ischema.tables + s = sql.select([tables.c.table_name], + sql.and_( + tables.c.table_schema == current_schema, + tables.c.table_type == 'BASE TABLE' + ), + order_by=[tables.c.table_name] + ) + table_names = [r[0] for r in connection.execute(s)] + return table_names + + @reflection.cache + def get_view_names(self, connection, schema=None, **kw): + current_schema = schema or self.default_schema_name + tables = ischema.tables + s = sql.select([tables.c.table_name], + sql.and_( + tables.c.table_schema == current_schema, + tables.c.table_type == 'VIEW' + ), + order_by=[tables.c.table_name] + ) + view_names = [r[0] for r in connection.execute(s)] + return view_names + + # The cursor reports it is closed after executing the sp. + @reflection.cache + def get_indexes(self, connection, tablename, schema=None, **kw): + current_schema = schema or self.default_schema_name + full_tname = "%s.%s" % (current_schema, tablename) + indexes = [] + s = sql.text("exec sp_helpindex '%s'" % full_tname) + rp = connection.execute(s) + if rp.closed: + # did not work for this setup. + return [] + for row in rp: + if 'primary key' not in row['index_description']: + indexes.append({ + 'name' : row['index_name'], + 'column_names' : [c.strip() for c in row['index_keys'].split(',')], + 'unique': 'unique' in row['index_description'] + }) + return indexes + + @reflection.cache + def get_view_definition(self, connection, viewname, schema=None, **kw): + current_schema = schema or self.default_schema_name + views = ischema.views + s = sql.select([views.c.view_definition], + sql.and_( + views.c.table_schema == current_schema, + views.c.table_name == viewname + ), + ) + rp = connection.execute(s) + if rp: + view_def = rp.scalar() + return view_def + + @reflection.cache + def get_columns(self, connection, tablename, schema=None, **kw): + # Get base columns + current_schema = schema or self.default_schema_name + columns = ischema.columns + s = sql.select([columns], + current_schema + and sql.and_(columns.c.table_name==tablename, columns.c.table_schema==current_schema) + or columns.c.table_name==tablename, + order_by=[columns.c.ordinal_position]) + c = connection.execute(s) + cols = [] + while True: + row = c.fetchone() + if row is None: + break + (name, type, nullable, charlen, numericprec, numericscale, default, collation) = ( + row[columns.c.column_name], + row[columns.c.data_type], + row[columns.c.is_nullable] == 'YES', + row[columns.c.character_maximum_length], + row[columns.c.numeric_precision], + row[columns.c.numeric_scale], + row[columns.c.column_default], + row[columns.c.collation_name] + ) + coltype = self.ischema_names.get(type, None) + + kwargs = {} + if coltype in (MSString, MSChar, MSNVarchar, MSNChar, MSText, MSNText, MSBinary, MSVarBinary, sqltypes.Binary): + kwargs['length'] = charlen + if collation: + kwargs['collation'] = collation + if coltype == MSText or (coltype in (MSString, MSNVarchar) and charlen == -1): + kwargs.pop('length') + + if coltype is None: + util.warn("Did not recognize type '%s' of column '%s'" % (type, name)) + coltype = sqltypes.NULLTYPE + + if issubclass(coltype, sqltypes.Numeric) and coltype is not MSReal: + kwargs['scale'] = numericscale + kwargs['precision'] = numericprec + + coltype = coltype(**kwargs) + cdict = { + 'name' : name, + 'type' : coltype, + 'nullable' : nullable, + 'default' : default, + 'autoincrement':False, + } + cols.append(cdict) + # autoincrement and identity + colmap = {} + for col in cols: + colmap[col['name']] = col + # We also run an sp_columns to check for identity columns: + cursor = connection.execute("sp_columns @table_name = '%s', @table_owner = '%s'" % (tablename, current_schema)) + ic = None + while True: + row = cursor.fetchone() + if row is None: + break + (col_name, type_name) = row[3], row[5] + if type_name.endswith("identity") and col_name in colmap: + ic = col_name + colmap[col_name]['autoincrement'] = True + colmap[col_name]['sequence'] = dict( + name='%s_identity' % col_name) + break + cursor.close() + if ic is not None: + try: + # is this table_fullname reliable? + table_fullname = "%s.%s" % (current_schema, tablename) + cursor = connection.execute( + sql.text("select ident_seed(:seed), ident_incr(:incr)"), + {'seed':table_fullname, 'incr':table_fullname} + ) + row = cursor.fetchone() + cursor.close() + if not row is None: + colmap[ic]['sequence'].update({ + 'start' : int(row[0]), + 'increment' : int(row[1]) + }) + except: + # ignoring it, works just like before + pass + return cols + + @reflection.cache + def get_primary_keys(self, connection, tablename, schema=None, **kw): + current_schema = schema or self.default_schema_name + pkeys = [] + # Add constraints + RR = ischema.ref_constraints #information_schema.referential_constraints + TC = ischema.constraints #information_schema.table_constraints + C = ischema.key_constraints.alias('C') #information_schema.constraint_column_usage: the constrained column + R = ischema.key_constraints.alias('R') #information_schema.constraint_column_usage: the referenced column + + # Primary key constraints + s = sql.select([C.c.column_name, TC.c.constraint_type], + sql.and_(TC.c.constraint_name == C.c.constraint_name, + C.c.table_name == tablename, + C.c.table_schema == current_schema) + ) + c = connection.execute(s) + for row in c: + if 'PRIMARY' in row[TC.c.constraint_type.name]: + pkeys.append(row[0]) + return pkeys + + @reflection.cache + def get_foreign_keys(self, connection, tablename, schema=None, **kw): + current_schema = schema or self.default_schema_name + # Add constraints + RR = ischema.ref_constraints #information_schema.referential_constraints + TC = ischema.constraints #information_schema.table_constraints + C = ischema.key_constraints.alias('C') #information_schema.constraint_column_usage: the constrained column + R = ischema.key_constraints.alias('R') #information_schema.constraint_column_usage: the referenced column + + # Foreign key constraints + s = sql.select([C.c.column_name, + R.c.table_schema, R.c.table_name, R.c.column_name, + RR.c.constraint_name, RR.c.match_option, RR.c.update_rule, RR.c.delete_rule], + sql.and_(C.c.table_name == tablename, + C.c.table_schema == current_schema, + C.c.constraint_name == RR.c.constraint_name, + R.c.constraint_name == RR.c.unique_constraint_name, + C.c.ordinal_position == R.c.ordinal_position + ), + order_by = [RR.c.constraint_name, R.c.ordinal_position]) + + + # group rows by constraint ID, to handle multi-column FKs + fkeys = [] + fknm, scols, rcols = (None, [], []) + + def fkey_rec(): + return { + 'name' : None, + 'constrained_columns' : [], + 'referred_schema' : None, + 'referred_table' : None, + 'referred_columns' : [] + } + + fkeys = util.defaultdict(fkey_rec) + + for r in connection.execute(s).fetchall(): + scol, rschema, rtbl, rcol, rfknm, fkmatch, fkuprule, fkdelrule = r + + rec = fkeys[rfknm] + rec['name'] = rfknm + if not rec['referred_table']: + rec['referred_table'] = rtbl + + if schema is not None or current_schema != rschema: + rec['referred_schema'] = rschema + + local_cols, remote_cols = rec['constrained_columns'], rec['referred_columns'] + + local_cols.append(scol) + remote_cols.append(rcol) + + return fkeys.values() + + +# fixme. I added this for the tests to run. -Randall +MSSQLDialect = MSDialect diff --git a/lib/sqlalchemy/dialects/mssql/information_schema.py b/lib/sqlalchemy/dialects/mssql/information_schema.py new file mode 100644 index 0000000000..bb6ff315a7 --- /dev/null +++ b/lib/sqlalchemy/dialects/mssql/information_schema.py @@ -0,0 +1,83 @@ +from sqlalchemy import Table, MetaData, Column, ForeignKey +from sqlalchemy.types import String, Unicode, Integer, TypeDecorator + +ischema = MetaData() + +class CoerceUnicode(TypeDecorator): + impl = Unicode + + def process_bind_param(self, value, dialect): + if isinstance(value, str): + value = value.decode(dialect.encoding) + return value + +schemata = Table("SCHEMATA", ischema, + Column("CATALOG_NAME", CoerceUnicode, key="catalog_name"), + Column("SCHEMA_NAME", CoerceUnicode, key="schema_name"), + Column("SCHEMA_OWNER", CoerceUnicode, key="schema_owner"), + schema="INFORMATION_SCHEMA") + +tables = Table("TABLES", ischema, + Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"), + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("TABLE_TYPE", String, key="table_type"), + schema="INFORMATION_SCHEMA") + +columns = Table("COLUMNS", ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("COLUMN_NAME", CoerceUnicode, key="column_name"), + Column("IS_NULLABLE", Integer, key="is_nullable"), + Column("DATA_TYPE", String, key="data_type"), + Column("ORDINAL_POSITION", Integer, key="ordinal_position"), + Column("CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length"), + Column("NUMERIC_PRECISION", Integer, key="numeric_precision"), + Column("NUMERIC_SCALE", Integer, key="numeric_scale"), + Column("COLUMN_DEFAULT", Integer, key="column_default"), + Column("COLLATION_NAME", String, key="collation_name"), + schema="INFORMATION_SCHEMA") + +constraints = Table("TABLE_CONSTRAINTS", ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + Column("CONSTRAINT_TYPE", String, key="constraint_type"), + schema="INFORMATION_SCHEMA") + +column_constraints = Table("CONSTRAINT_COLUMN_USAGE", ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("COLUMN_NAME", CoerceUnicode, key="column_name"), + Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + schema="INFORMATION_SCHEMA") + +key_constraints = Table("KEY_COLUMN_USAGE", ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("COLUMN_NAME", CoerceUnicode, key="column_name"), + Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + Column("ORDINAL_POSITION", Integer, key="ordinal_position"), + schema="INFORMATION_SCHEMA") + +ref_constraints = Table("REFERENTIAL_CONSTRAINTS", ischema, + Column("CONSTRAINT_CATALOG", CoerceUnicode, key="constraint_catalog"), + Column("CONSTRAINT_SCHEMA", CoerceUnicode, key="constraint_schema"), + Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + Column("UNIQUE_CONSTRAINT_CATLOG", CoerceUnicode, key="unique_constraint_catalog"), # TODO: is CATLOG misspelled ? + Column("UNIQUE_CONSTRAINT_SCHEMA", CoerceUnicode, key="unique_constraint_schema"), + Column("UNIQUE_CONSTRAINT_NAME", CoerceUnicode, key="unique_constraint_name"), + Column("MATCH_OPTION", String, key="match_option"), + Column("UPDATE_RULE", String, key="update_rule"), + Column("DELETE_RULE", String, key="delete_rule"), + schema="INFORMATION_SCHEMA") + +views = Table("VIEWS", ischema, + Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"), + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("VIEW_DEFINITION", CoerceUnicode, key="view_definition"), + Column("CHECK_OPTION", String, key="check_option"), + Column("IS_UPDATABLE", String, key="is_updatable"), + schema="INFORMATION_SCHEMA") + diff --git a/lib/sqlalchemy/dialects/mssql/pymssql.py b/lib/sqlalchemy/dialects/mssql/pymssql.py new file mode 100644 index 0000000000..0961c2e760 --- /dev/null +++ b/lib/sqlalchemy/dialects/mssql/pymssql.py @@ -0,0 +1,46 @@ +from sqlalchemy.dialects.mssql.base import MSDialect +from sqlalchemy import types as sqltypes + + +class MSDialect_pymssql(MSDialect): + supports_sane_rowcount = False + max_identifier_length = 30 + driver = 'pymssql' + + @classmethod + def import_dbapi(cls): + import pymssql as module + # pymmsql doesn't have a Binary method. we use string + # TODO: monkeypatching here is less than ideal + module.Binary = lambda st: str(st) + return module + + def __init__(self, **params): + super(MSSQLDialect_pymssql, self).__init__(**params) + self.use_scope_identity = True + + # pymssql understands only ascii + if self.convert_unicode: + util.warn("pymssql does not support unicode") + self.encoding = params.get('encoding', 'ascii') + + + def create_connect_args(self, url): + if hasattr(self, 'query_timeout'): + # ick, globals ? we might want to move this.... + self.dbapi._mssql.set_query_timeout(self.query_timeout) + + keys = url.query + if keys.get('port'): + # pymssql expects port as host:port, not a separate arg + keys['host'] = ''.join([keys.get('host', ''), ':', str(keys['port'])]) + del keys['port'] + return [[], keys] + + def is_disconnect(self, e): + return isinstance(e, self.dbapi.DatabaseError) and "Error 10054" in str(e) + + def do_begin(self, connection): + pass + +dialect = MSDialect_pymssql \ No newline at end of file diff --git a/lib/sqlalchemy/dialects/mssql/pyodbc.py b/lib/sqlalchemy/dialects/mssql/pyodbc.py new file mode 100644 index 0000000000..9a2a9e4e78 --- /dev/null +++ b/lib/sqlalchemy/dialects/mssql/pyodbc.py @@ -0,0 +1,79 @@ +from sqlalchemy.dialects.mssql.base import MSExecutionContext, MSDialect +from sqlalchemy.connectors.pyodbc import PyODBCConnector +from sqlalchemy import types as sqltypes +import re +import sys + +class MSExecutionContext_pyodbc(MSExecutionContext): + _embedded_scope_identity = False + + def pre_exec(self): + """where appropriate, issue "select scope_identity()" in the same statement. + + Background on why "scope_identity()" is preferable to "@@identity": + http://msdn.microsoft.com/en-us/library/ms190315.aspx + + Background on why we attempt to embed "scope_identity()" into the same + statement as the INSERT: + http://code.google.com/p/pyodbc/wiki/FAQs#How_do_I_retrieve_autogenerated/identity_values? + + """ + + super(MSExecutionContext_pyodbc, self).pre_exec() + + # don't embed the scope_identity select into an "INSERT .. DEFAULT VALUES" + if self._select_lastrowid and \ + self.dialect.use_scope_identity and \ + len(self.parameters[0]): + self._embedded_scope_identity = True + + self.statement += "; select scope_identity()" + + def post_exec(self): + if self._embedded_scope_identity: + # Fetch the last inserted id from the manipulated statement + # We may have to skip over a number of result sets with no data (due to triggers, etc.) + while True: + try: + # fetchall() ensures the cursor is consumed without closing it (FreeTDS particularly) + row = self.cursor.fetchall()[0] + break + except self.dialect.dbapi.Error, e: + # no way around this - nextset() consumes the previous set + # so we need to just keep flipping + self.cursor.nextset() + + self._lastrowid = int(row[0]) + else: + super(MSExecutionContext_pyodbc, self).post_exec() + + +class MSDialect_pyodbc(PyODBCConnector, MSDialect): + supports_sane_rowcount = True + supports_sane_multi_rowcount = False + + execution_ctx_cls = MSExecutionContext_pyodbc + + pyodbc_driver_name = 'SQL Server' + + def __init__(self, description_encoding='latin-1', **params): + super(MSDialect_pyodbc, self).__init__(**params) + self.description_encoding = description_encoding + self.use_scope_identity = self.dbapi and hasattr(self.dbapi.Cursor, 'nextset') + + def initialize(self, connection): + super(MSDialect_pyodbc, self).initialize(connection) + pyodbc = self.dbapi + + dbapi_con = connection.connection + + self._free_tds = re.match(r".*libtdsodbc.*\.so", dbapi_con.getinfo(pyodbc.SQL_DRIVER_NAME)) + + # the "Py2K only" part here is theoretical. + # have not tried pyodbc + python3.1 yet. + # Py2K + self.supports_unicode_statements = not self._free_tds + self.supports_unicode_binds = not self._free_tds + # end Py2K + +dialect = MSDialect_pyodbc diff --git a/lib/sqlalchemy/dialects/mysql/__init__.py b/lib/sqlalchemy/dialects/mysql/__init__.py new file mode 100644 index 0000000000..4106a299be --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/__init__.py @@ -0,0 +1,4 @@ +from sqlalchemy.dialects.mysql import base, mysqldb, pyodbc, zxjdbc + +# default dialect +base.dialect = mysqldb.dialect \ No newline at end of file diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/dialects/mysql/base.py similarity index 66% rename from lib/sqlalchemy/databases/mysql.py rename to lib/sqlalchemy/dialects/mysql/base.py index ba6b026ea2..1c5c251e54 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -19,12 +19,12 @@ But if you would like to use one of the MySQL-specific or enhanced column types when creating tables with your :class:`~sqlalchemy.Table` definitions, then you will need to import them from this module:: - from sqlalchemy.databases import mysql + from sqlalchemy.dialect.mysql import base as mysql Table('mytable', metadata, Column('id', Integer, primary_key=True), - Column('ittybittyblob', mysql.MSTinyBlob), - Column('biggy', mysql.MSBigInteger(unsigned=True))) + Column('ittybittyblob', mysql.TINYBLOB), + Column('biggy', mysql.BIGINT(unsigned=True))) All standard MySQL column types are supported. The OpenGIS types are available for use via table reflection but have no special support or mapping @@ -64,25 +64,6 @@ Nested Transactions 5.0.3 See the official MySQL documentation for detailed information about features supported in any given server release. -Character Sets --------------- - -Many MySQL server installations default to a ``latin1`` encoding for client -connections. All data sent through the connection will be converted into -``latin1``, even if you have ``utf8`` or another character set on your tables -and columns. With versions 4.1 and higher, you can change the connection -character set either through server configuration or by including the -``charset`` parameter in the URL used for ``create_engine``. The ``charset`` -option is passed through to MySQL-Python and has the side-effect of also -enabling ``use_unicode`` in the driver by default. For regular encoded -strings, also pass ``use_unicode=0`` in the connection arguments:: - - # set client encoding to utf8; all strings come back as unicode - create_engine('mysql:///mydb?charset=utf8') - - # set client encoding to utf8; all strings come back as utf8 str - create_engine('mysql:///mydb?charset=utf8&use_unicode=0') - Storage Engines --------------- @@ -196,27 +177,20 @@ timely information affecting MySQL in SQLAlchemy. """ -import datetime, decimal, inspect, re, sys -from array import array as _array +import datetime, inspect, re, sys -from sqlalchemy import exc, log, schema, sql, util +from sqlalchemy import schema as sa_schema +from sqlalchemy import exc, log, sql, util from sqlalchemy.sql import operators as sql_operators from sqlalchemy.sql import functions as sql_functions from sqlalchemy.sql import compiler +from array import array as _array +from sqlalchemy.engine import reflection from sqlalchemy.engine import base as engine_base, default from sqlalchemy import types as sqltypes - -__all__ = ( - 'MSBigInteger', 'MSMediumInteger', 'MSBinary', 'MSBit', 'MSBlob', 'MSBoolean', - 'MSChar', 'MSDate', 'MSDateTime', 'MSDecimal', 'MSDouble', - 'MSEnum', 'MSFloat', 'MSInteger', 'MSLongBlob', 'MSLongText', - 'MSMediumBlob', 'MSMediumText', 'MSNChar', 'MSNVarChar', - 'MSNumeric', 'MSSet', 'MSSmallInteger', 'MSString', 'MSText', - 'MSTime', 'MSTimeStamp', 'MSTinyBlob', 'MSTinyInteger', - 'MSTinyText', 'MSVarBinary', 'MSYear' ) - +from sqlalchemy.types import DATE, DATETIME, BOOLEAN, TIME RESERVED_WORDS = set( ['accessible', 'add', 'all', 'alter', 'analyze','and', 'as', 'asc', @@ -271,62 +245,45 @@ SET_RE = re.compile( class _NumericType(object): """Base for MySQL numeric types.""" - def __init__(self, kw): + def __init__(self, **kw): self.unsigned = kw.pop('unsigned', False) self.zerofill = kw.pop('zerofill', False) + super(_NumericType, self).__init__(**kw) + +class _FloatType(_NumericType, sqltypes.Float): + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): + if isinstance(self, (REAL, DOUBLE)) and \ + ( + (precision is None and scale is not None) or + (precision is not None and scale is None) + ): + raise exc.ArgumentError( + "You must specify both precision and scale or omit " + "both altogether.") + + super(_FloatType, self).__init__(precision=precision, asdecimal=asdecimal, **kw) + self.scale = scale - def _extend(self, spec): - "Extend a numeric-type declaration with MySQL specific extensions." - - if self.unsigned: - spec += ' UNSIGNED' - if self.zerofill: - spec += ' ZEROFILL' - return spec - +class _IntegerType(_NumericType, sqltypes.Integer): + def __init__(self, display_width=None, **kw): + self.display_width = display_width + super(_IntegerType, self).__init__(**kw) -class _StringType(object): +class _StringType(sqltypes.String): """Base for MySQL string types.""" def __init__(self, charset=None, collation=None, ascii=False, unicode=False, binary=False, - national=False, **kwargs): + national=False, **kw): self.charset = charset # allow collate= or collation= - self.collation = kwargs.get('collate', collation) + self.collation = kw.pop('collate', collation) self.ascii = ascii self.unicode = unicode self.binary = binary self.national = national - - def _extend(self, spec): - """Extend a string-type declaration with standard SQL CHARACTER SET / - COLLATE annotations and MySQL specific extensions. - """ - - if self.charset: - charset = 'CHARACTER SET %s' % self.charset - elif self.ascii: - charset = 'ASCII' - elif self.unicode: - charset = 'UNICODE' - else: - charset = None - - if self.collation: - collation = 'COLLATE %s' % self.collation - elif self.binary: - collation = 'BINARY' - else: - collation = None - - if self.national: - # NATIONAL (aka NCHAR/NVARCHAR) trumps charsets. - return ' '.join([c for c in ('NATIONAL', spec, collation) - if c is not None]) - return ' '.join([c for c in (spec, charset, collation) - if c is not None]) - + super(_StringType, self).__init__(**kw) + def __repr__(self): attributes = inspect.getargspec(self.__init__)[0][1:] attributes.extend(inspect.getargspec(_StringType.__init__)[0][1:]) @@ -341,10 +298,23 @@ class _StringType(object): ', '.join(['%s=%r' % (k, params[k]) for k in params])) -class MSNumeric(sqltypes.Numeric, _NumericType): - """MySQL NUMERIC type.""" +class _BinaryType(sqltypes.Binary): + """Base for MySQL binary types.""" + + def result_processor(self, dialect): + def process(value): + if value is None: + return None + else: + return util.buffer(value) + return process - def __init__(self, precision=10, scale=2, asdecimal=True, **kw): +class NUMERIC(_NumericType, sqltypes.NUMERIC): + """MySQL NUMERIC type.""" + + __visit_name__ = 'NUMERIC' + + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): """Construct a NUMERIC. :param precision: Total digits in this number. If scale and precision @@ -360,34 +330,15 @@ class MSNumeric(sqltypes.Numeric, _NumericType): numeric. """ - _NumericType.__init__(self, kw) - sqltypes.Numeric.__init__(self, precision, scale, asdecimal=asdecimal, **kw) - - def get_col_spec(self): - if self.precision is None: - return self._extend("NUMERIC") - else: - return self._extend("NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale}) - - def bind_processor(self, dialect): - return None + super(NUMERIC, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal, **kw) - def result_processor(self, dialect): - if not self.asdecimal: - def process(value): - if isinstance(value, decimal.Decimal): - return float(value) - else: - return value - return process - else: - return None - -class MSDecimal(MSNumeric): +class DECIMAL(_NumericType, sqltypes.DECIMAL): """MySQL DECIMAL type.""" - - def __init__(self, precision=10, scale=2, asdecimal=True, **kw): + + __visit_name__ = 'DECIMAL' + + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): """Construct a DECIMAL. :param precision: Total digits in this number. If scale and precision @@ -403,20 +354,14 @@ class MSDecimal(MSNumeric): numeric. """ - super(MSDecimal, self).__init__(precision, scale, asdecimal=asdecimal, **kw) - - def get_col_spec(self): - if self.precision is None: - return self._extend("DECIMAL") - elif self.scale is None: - return self._extend("DECIMAL(%(precision)s)" % {'precision': self.precision}) - else: - return self._extend("DECIMAL(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale}) - + super(DECIMAL, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal, **kw) -class MSDouble(sqltypes.Float, _NumericType): + +class DOUBLE(_FloatType): """MySQL DOUBLE type.""" + __visit_name__ = 'DOUBLE' + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): """Construct a DOUBLE. @@ -433,29 +378,13 @@ class MSDouble(sqltypes.Float, _NumericType): numeric. """ - if ((precision is None and scale is not None) or - (precision is not None and scale is None)): - raise exc.ArgumentError( - "You must specify both precision and scale or omit " - "both altogether.") - - _NumericType.__init__(self, kw) - sqltypes.Float.__init__(self, asdecimal=asdecimal, **kw) - self.scale = scale - self.precision = precision + super(DOUBLE, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal, **kw) - def get_col_spec(self): - if self.precision is not None and self.scale is not None: - return self._extend("DOUBLE(%(precision)s, %(scale)s)" % - {'precision': self.precision, - 'scale' : self.scale}) - else: - return self._extend('DOUBLE') - - -class MSReal(MSDouble): +class REAL(_FloatType): """MySQL REAL type.""" + __visit_name__ = 'REAL' + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): """Construct a REAL. @@ -472,20 +401,13 @@ class MSReal(MSDouble): numeric. """ - MSDouble.__init__(self, precision, scale, asdecimal, **kw) + super(REAL, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal, **kw) - def get_col_spec(self): - if self.precision is not None and self.scale is not None: - return self._extend("REAL(%(precision)s, %(scale)s)" % - {'precision': self.precision, - 'scale' : self.scale}) - else: - return self._extend('REAL') - - -class MSFloat(sqltypes.Float, _NumericType): +class FLOAT(_FloatType, sqltypes.FLOAT): """MySQL FLOAT type.""" + __visit_name__ = 'FLOAT' + def __init__(self, precision=None, scale=None, asdecimal=False, **kw): """Construct a FLOAT. @@ -502,26 +424,16 @@ class MSFloat(sqltypes.Float, _NumericType): numeric. """ - _NumericType.__init__(self, kw) - sqltypes.Float.__init__(self, asdecimal=asdecimal, **kw) - self.scale = scale - self.precision = precision - - def get_col_spec(self): - if self.scale is not None and self.precision is not None: - return self._extend("FLOAT(%s, %s)" % (self.precision, self.scale)) - elif self.precision is not None: - return self._extend("FLOAT(%s)" % (self.precision,)) - else: - return self._extend("FLOAT") + super(FLOAT, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal, **kw) def bind_processor(self, dialect): return None - -class MSInteger(sqltypes.Integer, _NumericType): +class INTEGER(_IntegerType, sqltypes.INTEGER): """MySQL INTEGER type.""" + __visit_name__ = 'INTEGER' + def __init__(self, display_width=None, **kw): """Construct an INTEGER. @@ -535,24 +447,13 @@ class MSInteger(sqltypes.Integer, _NumericType): numeric. """ - if 'length' in kw: - util.warn_deprecated("'length' is deprecated for MSInteger and subclasses. Use 'display_width'.") - self.display_width = kw.pop('length') - else: - self.display_width = display_width - _NumericType.__init__(self, kw) - sqltypes.Integer.__init__(self, **kw) - - def get_col_spec(self): - if self.display_width is not None: - return self._extend("INTEGER(%(display_width)s)" % {'display_width': self.display_width}) - else: - return self._extend("INTEGER") + super(INTEGER, self).__init__(display_width=display_width, **kw) - -class MSBigInteger(MSInteger): +class BIGINT(_IntegerType, sqltypes.BIGINT): """MySQL BIGINTEGER type.""" + __visit_name__ = 'BIGINT' + def __init__(self, display_width=None, **kw): """Construct a BIGINTEGER. @@ -566,18 +467,13 @@ class MSBigInteger(MSInteger): numeric. """ - super(MSBigInteger, self).__init__(display_width, **kw) - - def get_col_spec(self): - if self.display_width is not None: - return self._extend("BIGINT(%(display_width)s)" % {'display_width': self.display_width}) - else: - return self._extend("BIGINT") + super(BIGINT, self).__init__(display_width=display_width, **kw) - -class MSMediumInteger(MSInteger): +class MEDIUMINT(_IntegerType): """MySQL MEDIUMINTEGER type.""" + __visit_name__ = 'MEDIUMINT' + def __init__(self, display_width=None, **kw): """Construct a MEDIUMINTEGER @@ -591,19 +487,13 @@ class MSMediumInteger(MSInteger): numeric. """ - super(MSMediumInteger, self).__init__(display_width, **kw) - - def get_col_spec(self): - if self.display_width is not None: - return self._extend("MEDIUMINT(%(display_width)s)" % {'display_width': self.display_width}) - else: - return self._extend("MEDIUMINT") + super(MEDIUMINT, self).__init__(display_width=display_width, **kw) - - -class MSTinyInteger(MSInteger): +class TINYINT(_IntegerType): """MySQL TINYINT type.""" + __visit_name__ = 'TINYINT' + def __init__(self, display_width=None, **kw): """Construct a TINYINT. @@ -621,18 +511,13 @@ class MSTinyInteger(MSInteger): numeric. """ - super(MSTinyInteger, self).__init__(display_width, **kw) - - def get_col_spec(self): - if self.display_width is not None: - return self._extend("TINYINT(%s)" % self.display_width) - else: - return self._extend("TINYINT") - + super(TINYINT, self).__init__(display_width=display_width, **kw) -class MSSmallInteger(sqltypes.Smallinteger, MSInteger): +class SMALLINT(_IntegerType, sqltypes.SMALLINT): """MySQL SMALLINTEGER type.""" + __visit_name__ = 'SMALLINT' + def __init__(self, display_width=None, **kw): """Construct a SMALLINTEGER. @@ -646,18 +531,9 @@ class MSSmallInteger(sqltypes.Smallinteger, MSInteger): numeric. """ - self.display_width = display_width - _NumericType.__init__(self, kw) - sqltypes.SmallInteger.__init__(self, **kw) + super(SMALLINT, self).__init__(display_width=display_width, **kw) - def get_col_spec(self): - if self.display_width is not None: - return self._extend("SMALLINT(%(display_width)s)" % {'display_width': self.display_width}) - else: - return self._extend("SMALLINT") - - -class MSBit(sqltypes.TypeEngine): +class BIT(sqltypes.TypeEngine): """MySQL BIT type. This type is for MySQL 5.0.3 or greater for MyISAM, and 5.0.5 or greater for @@ -666,6 +542,8 @@ class MSBit(sqltypes.TypeEngine): """ + __visit_name__ = 'BIT' + def __init__(self, length=None): """Construct a BIT. @@ -685,32 +563,10 @@ class MSBit(sqltypes.TypeEngine): return value return process - def get_col_spec(self): - if self.length is not None: - return "BIT(%s)" % self.length - else: - return "BIT" - - -class MSDateTime(sqltypes.DateTime): - """MySQL DATETIME type.""" - - def get_col_spec(self): - return "DATETIME" - - -class MSDate(sqltypes.Date): - """MySQL DATE type.""" - - def get_col_spec(self): - return "DATE" - - -class MSTime(sqltypes.Time): +class _MSTime(sqltypes.Time): """MySQL TIME type.""" - def get_col_spec(self): - return "TIME" + __visit_name__ = 'TIME' def result_processor(self, dialect): def process(value): @@ -721,45 +577,24 @@ class MSTime(sqltypes.Time): return None return process -class MSTimeStamp(sqltypes.TIMESTAMP): - """MySQL TIMESTAMP type. - - To signal the orm to automatically re-select modified rows to retrieve the - updated timestamp, add a ``server_default`` to your - :class:`~sqlalchemy.Column` specification:: - - from sqlalchemy.databases import mysql - Column('updated', mysql.MSTimeStamp, - server_default=sql.text('CURRENT_TIMESTAMP') - ) - - The full range of MySQL 4.1+ TIMESTAMP defaults can be specified in - the the default:: +class TIMESTAMP(sqltypes.TIMESTAMP): + """MySQL TIMESTAMP type.""" + __visit_name__ = 'TIMESTAMP' - server_default=sql.text('CURRENT TIMESTAMP ON UPDATE CURRENT_TIMESTAMP') - - """ - - def get_col_spec(self): - return "TIMESTAMP" - - -class MSYear(sqltypes.TypeEngine): +class YEAR(sqltypes.TypeEngine): """MySQL YEAR type, for single byte storage of years 1901-2155.""" + __visit_name__ = 'YEAR' + def __init__(self, display_width=None): self.display_width = display_width - def get_col_spec(self): - if self.display_width is None: - return "YEAR" - else: - return "YEAR(%s)" % self.display_width - -class MSText(_StringType, sqltypes.Text): +class TEXT(_StringType, sqltypes.TEXT): """MySQL TEXT type, for text up to 2^16 characters.""" - def __init__(self, length=None, **kwargs): + __visit_name__ = 'TEXT' + + def __init__(self, length=None, **kw): """Construct a TEXT. :param length: Optional, if provided the server may optimize storage @@ -787,20 +622,13 @@ class MSText(_StringType, sqltypes.Text): only the collation of character data. """ - _StringType.__init__(self, **kwargs) - sqltypes.Text.__init__(self, length, - kwargs.get('convert_unicode', False), kwargs.get('assert_unicode', None)) + super(TEXT, self).__init__(length=length, **kw) - def get_col_spec(self): - if self.length: - return self._extend("TEXT(%d)" % self.length) - else: - return self._extend("TEXT") - - -class MSTinyText(MSText): +class TINYTEXT(_StringType): """MySQL TINYTEXT type, for text up to 2^8 characters.""" + __visit_name__ = 'TINYTEXT' + def __init__(self, **kwargs): """Construct a TINYTEXT. @@ -825,16 +653,13 @@ class MSTinyText(MSText): only the collation of character data. """ + super(TINYTEXT, self).__init__(**kwargs) - super(MSTinyText, self).__init__(**kwargs) - - def get_col_spec(self): - return self._extend("TINYTEXT") - - -class MSMediumText(MSText): +class MEDIUMTEXT(_StringType): """MySQL MEDIUMTEXT type, for text up to 2^24 characters.""" + __visit_name__ = 'MEDIUMTEXT' + def __init__(self, **kwargs): """Construct a MEDIUMTEXT. @@ -859,15 +684,13 @@ class MSMediumText(MSText): only the collation of character data. """ - super(MSMediumText, self).__init__(**kwargs) + super(MEDIUMTEXT, self).__init__(**kwargs) - def get_col_spec(self): - return self._extend("MEDIUMTEXT") - - -class MSLongText(MSText): +class LONGTEXT(_StringType): """MySQL LONGTEXT type, for text up to 2^32 characters.""" + __visit_name__ = 'LONGTEXT' + def __init__(self, **kwargs): """Construct a LONGTEXT. @@ -892,15 +715,14 @@ class MSLongText(MSText): only the collation of character data. """ - super(MSLongText, self).__init__(**kwargs) - - def get_col_spec(self): - return self._extend("LONGTEXT") + super(LONGTEXT, self).__init__(**kwargs) - -class MSString(_StringType, sqltypes.String): + +class VARCHAR(_StringType, sqltypes.VARCHAR): """MySQL VARCHAR type, for variable-length character data.""" + __visit_name__ = 'VARCHAR' + def __init__(self, length=None, **kwargs): """Construct a VARCHAR. @@ -925,22 +747,15 @@ class MSString(_StringType, sqltypes.String): only the collation of character data. """ - _StringType.__init__(self, **kwargs) - sqltypes.String.__init__(self, length, - kwargs.get('convert_unicode', False), kwargs.get('assert_unicode', None)) - - def get_col_spec(self): - if self.length: - return self._extend("VARCHAR(%d)" % self.length) - else: - return self._extend("VARCHAR") + super(VARCHAR, self).__init__(length=length, **kwargs) - -class MSChar(_StringType, sqltypes.CHAR): +class CHAR(_StringType, sqltypes.CHAR): """MySQL CHAR type, for fixed-length character data.""" + __visit_name__ = 'CHAR' + def __init__(self, length, **kwargs): - """Construct an NCHAR. + """Construct a CHAR. :param length: Maximum data length, in characters. @@ -952,21 +767,17 @@ class MSChar(_StringType, sqltypes.CHAR): compatible with the national character set. """ - _StringType.__init__(self, **kwargs) - sqltypes.CHAR.__init__(self, length, - kwargs.get('convert_unicode', False)) - - def get_col_spec(self): - return self._extend("CHAR(%(length)s)" % {'length' : self.length}) - + super(CHAR, self).__init__(length=length, **kwargs) -class MSNVarChar(_StringType, sqltypes.String): +class NVARCHAR(_StringType, sqltypes.NVARCHAR): """MySQL NVARCHAR type. For variable-length character data in the server's configured national character set. """ + __visit_name__ = 'NVARCHAR' + def __init__(self, length=None, **kwargs): """Construct an NVARCHAR. @@ -981,23 +792,18 @@ class MSNVarChar(_StringType, sqltypes.String): """ kwargs['national'] = True - _StringType.__init__(self, **kwargs) - sqltypes.String.__init__(self, length, - kwargs.get('convert_unicode', False)) - - def get_col_spec(self): - # We'll actually generate the equiv. "NATIONAL VARCHAR" instead - # of "NVARCHAR". - return self._extend("VARCHAR(%(length)s)" % {'length': self.length}) + super(NVARCHAR, self).__init__(length=length, **kwargs) -class MSNChar(_StringType, sqltypes.CHAR): +class NCHAR(_StringType, sqltypes.NCHAR): """MySQL NCHAR type. For fixed-length character data in the server's configured national character set. """ + __visit_name__ = 'NCHAR' + def __init__(self, length=None, **kwargs): """Construct an NCHAR. Arguments are: @@ -1012,52 +818,28 @@ class MSNChar(_StringType, sqltypes.CHAR): """ kwargs['national'] = True - _StringType.__init__(self, **kwargs) - sqltypes.CHAR.__init__(self, length, - kwargs.get('convert_unicode', False)) - def get_col_spec(self): - # We'll actually generate the equiv. "NATIONAL CHAR" instead of "NCHAR". - return self._extend("CHAR(%(length)s)" % {'length': self.length}) - - -class _BinaryType(sqltypes.Binary): - """Base for MySQL binary types.""" + super(NCHAR, self).__init__(length=length, **kwargs) - def get_col_spec(self): - if self.length: - return "BLOB(%d)" % self.length - else: - return "BLOB" - def result_processor(self, dialect): - def process(value): - if value is None: - return None - else: - return util.buffer(value) - return process -class MSVarBinary(_BinaryType): +class VARBINARY(_BinaryType): """MySQL VARBINARY type, for variable length binary data.""" + __visit_name__ = 'VARBINARY' + def __init__(self, length=None, **kw): """Construct a VARBINARY. Arguments are: :param length: Maximum data length, in characters. """ - super(MSVarBinary, self).__init__(length, **kw) - - def get_col_spec(self): - if self.length: - return "VARBINARY(%d)" % self.length - else: - return "BLOB" - + super(VARBINARY, self).__init__(length=length, **kw) -class MSBinary(_BinaryType): +class BINARY(_BinaryType): """MySQL BINARY type, for fixed length binary data""" + __visit_name__ = 'BINARY' + def __init__(self, length=None, **kw): """Construct a BINARY. @@ -1068,25 +850,13 @@ class MSBinary(_BinaryType): specified, this will generate a BLOB. This usage is deprecated. """ - super(MSBinary, self).__init__(length, **kw) - - def get_col_spec(self): - if self.length: - return "BINARY(%d)" % self.length - else: - return "BLOB" - - def result_processor(self, dialect): - def process(value): - if value is None: - return None - else: - return util.buffer(value) - return process + super(BINARY, self).__init__(length=length, **kw) -class MSBlob(_BinaryType): +class BLOB(_BinaryType, sqltypes.BLOB): """MySQL BLOB type, for binary data up to 2^16 bytes""" + __visit_name__ = 'BLOB' + def __init__(self, length=None, **kw): """Construct a BLOB. Arguments are: @@ -1095,50 +865,29 @@ class MSBlob(_BinaryType): ``length`` characters. """ - super(MSBlob, self).__init__(length, **kw) - - def get_col_spec(self): - if self.length: - return "BLOB(%d)" % self.length - else: - return "BLOB" - - def result_processor(self, dialect): - def process(value): - if value is None: - return None - else: - return util.buffer(value) - return process - - def __repr__(self): - return "%s()" % self.__class__.__name__ + super(BLOB, self).__init__(length=length, **kw) -class MSTinyBlob(MSBlob): +class TINYBLOB(_BinaryType): """MySQL TINYBLOB type, for binary data up to 2^8 bytes.""" + + __visit_name__ = 'TINYBLOB' - def get_col_spec(self): - return "TINYBLOB" - - -class MSMediumBlob(MSBlob): +class MEDIUMBLOB(_BinaryType): """MySQL MEDIUMBLOB type, for binary data up to 2^24 bytes.""" - def get_col_spec(self): - return "MEDIUMBLOB" - + __visit_name__ = 'MEDIUMBLOB' -class MSLongBlob(MSBlob): +class LONGBLOB(_BinaryType): """MySQL LONGBLOB type, for binary data up to 2^32 bytes.""" - def get_col_spec(self): - return "LONGBLOB" - + __visit_name__ = 'LONGBLOB' -class MSEnum(MSString): +class ENUM(_StringType): """MySQL ENUM type.""" + __visit_name__ = 'ENUM' + def __init__(self, *enums, **kw): """Construct an ENUM. @@ -1225,10 +974,10 @@ class MSEnum(MSString): self.strict = kw.pop('strict', False) length = max([len(v) for v in self.enums] + [0]) - super(MSEnum, self).__init__(length, **kw) + super(ENUM, self).__init__(length=length, **kw) def bind_processor(self, dialect): - super_convert = super(MSEnum, self).bind_processor(dialect) + super_convert = super(ENUM, self).bind_processor(dialect) def process(value): if self.strict and value is not None and value not in self.enums: raise exc.InvalidRequestError('"%s" not a valid value for ' @@ -1239,15 +988,11 @@ class MSEnum(MSString): return value return process - def get_col_spec(self): - quoted_enums = [] - for e in self.enums: - quoted_enums.append("'%s'" % e.replace("'", "''")) - return self._extend("ENUM(%s)" % ",".join(quoted_enums)) - -class MSSet(MSString): +class SET(_StringType): """MySQL SET type.""" + __visit_name__ = 'SET' + def __init__(self, *values, **kw): """Construct a SET. @@ -1281,7 +1026,7 @@ class MSSet(MSString): only the collation of character data. """ - self.__ddl_values = values + self._ddl_values = values strip_values = [] for a in values: @@ -1292,7 +1037,7 @@ class MSSet(MSString): self.values = strip_values length = max([len(v) for v in strip_values] + [0]) - super(MSSet, self).__init__(length, **kw) + super(SET, self).__init__(length=length, **kw) def result_processor(self, dialect): def process(value): @@ -1316,7 +1061,7 @@ class MSSet(MSString): return process def bind_processor(self, dialect): - super_convert = super(MSSet, self).bind_processor(dialect) + super_convert = super(SET, self).bind_processor(dialect) def process(value): if value is None or isinstance(value, (int, long, basestring)): pass @@ -1332,15 +1077,10 @@ class MSSet(MSString): return value return process - def get_col_spec(self): - return self._extend("SET(%s)" % ",".join(self.__ddl_values)) - - -class MSBoolean(sqltypes.Boolean): +class _MSBoolean(sqltypes.Boolean): """MySQL BOOLEAN type.""" - def get_col_spec(self): - return "BOOL" + __visit_name__ = 'BOOLEAN' def result_processor(self, dialect): def process(value): @@ -1361,74 +1101,91 @@ class MSBoolean(sqltypes.Boolean): return value and True or False return process +# old names +MSBoolean = _MSBoolean +MSTime = _MSTime +MSSet = SET +MSEnum = ENUM +MSLongBlob = LONGBLOB +MSMediumBlob = MEDIUMBLOB +MSTinyBlob = TINYBLOB +MSBlob = BLOB +MSBinary = BINARY +MSVarBinary = VARBINARY +MSNChar = NCHAR +MSNVarChar = NVARCHAR +MSChar = CHAR +MSString = VARCHAR +MSLongText = LONGTEXT +MSMediumText = MEDIUMTEXT +MSTinyText = TINYTEXT +MSText = TEXT +MSYear = YEAR +MSTimeStamp = TIMESTAMP +MSBit = BIT +MSSmallInteger = SMALLINT +MSTinyInteger = TINYINT +MSMediumInteger = MEDIUMINT +MSBigInteger = BIGINT +MSNumeric = NUMERIC +MSDecimal = DECIMAL +MSDouble = DOUBLE +MSReal = REAL +MSFloat = FLOAT +MSInteger = INTEGER + colspecs = { - sqltypes.Integer: MSInteger, - sqltypes.Smallinteger: MSSmallInteger, - sqltypes.Numeric: MSNumeric, - sqltypes.Float: MSFloat, - sqltypes.DateTime: MSDateTime, - sqltypes.Date: MSDate, - sqltypes.Time: MSTime, - sqltypes.String: MSString, - sqltypes.Binary: MSBlob, - sqltypes.Boolean: MSBoolean, - sqltypes.Text: MSText, - sqltypes.CHAR: MSChar, - sqltypes.NCHAR: MSNChar, - sqltypes.TIMESTAMP: MSTimeStamp, - sqltypes.BLOB: MSBlob, - MSDouble: MSDouble, - MSReal: MSReal, - _BinaryType: _BinaryType, + sqltypes.Numeric: NUMERIC, + sqltypes.Float: FLOAT, + sqltypes.Binary: _BinaryType, + sqltypes.Boolean: _MSBoolean, + sqltypes.Time: _MSTime, } # Everything 3.23 through 5.1 excepting OpenGIS types. ischema_names = { - 'bigint': MSBigInteger, - 'binary': MSBinary, - 'bit': MSBit, - 'blob': MSBlob, - 'boolean':MSBoolean, - 'char': MSChar, - 'date': MSDate, - 'datetime': MSDateTime, - 'decimal': MSDecimal, - 'double': MSDouble, - 'enum': MSEnum, - 'fixed': MSDecimal, - 'float': MSFloat, - 'int': MSInteger, - 'integer': MSInteger, - 'longblob': MSLongBlob, - 'longtext': MSLongText, - 'mediumblob': MSMediumBlob, - 'mediumint': MSMediumInteger, - 'mediumtext': MSMediumText, - 'nchar': MSNChar, - 'nvarchar': MSNVarChar, - 'numeric': MSNumeric, - 'set': MSSet, - 'smallint': MSSmallInteger, - 'text': MSText, - 'time': MSTime, - 'timestamp': MSTimeStamp, - 'tinyblob': MSTinyBlob, - 'tinyint': MSTinyInteger, - 'tinytext': MSTinyText, - 'varbinary': MSVarBinary, - 'varchar': MSString, - 'year': MSYear, + 'bigint': BIGINT, + 'binary': BINARY, + 'bit': BIT, + 'blob': BLOB, + 'boolean':BOOLEAN, + 'char': CHAR, + 'date': DATE, + 'datetime': DATETIME, + 'decimal': DECIMAL, + 'double': DOUBLE, + 'enum': ENUM, + 'fixed': DECIMAL, + 'float': FLOAT, + 'int': INTEGER, + 'integer': INTEGER, + 'longblob': LONGBLOB, + 'longtext': LONGTEXT, + 'mediumblob': MEDIUMBLOB, + 'mediumint': MEDIUMINT, + 'mediumtext': MEDIUMTEXT, + 'nchar': NCHAR, + 'nvarchar': NVARCHAR, + 'numeric': NUMERIC, + 'set': SET, + 'smallint': SMALLINT, + 'text': TEXT, + 'time': TIME, + 'timestamp': TIMESTAMP, + 'tinyblob': TINYBLOB, + 'tinyint': TINYINT, + 'tinytext': TINYTEXT, + 'varbinary': VARBINARY, + 'varchar': VARCHAR, + 'year': YEAR, } - class MySQLExecutionContext(default.DefaultExecutionContext): def post_exec(self): - if self.compiled.isinsert and not self.executemany: - if (not len(self._last_inserted_ids) or - self._last_inserted_ids[0] is None): - self._last_inserted_ids = ([self.cursor.lastrowid] + - self._last_inserted_ids[1:]) - elif (not self.isupdate and not self.should_autocommit and + # TODO: i think this 'charset' in the info thing + # is out + + if (not self.isupdate and not self.should_autocommit and self.statement and SET_RE.match(self.statement)): # This misses if a user forces autocommit on text('SET NAMES'), # which is probably a programming error anyhow. @@ -1437,75 +1194,440 @@ class MySQLExecutionContext(default.DefaultExecutionContext): def should_autocommit_text(self, statement): return AUTOCOMMIT_RE.match(statement) +class MySQLCompiler(compiler.SQLCompiler): -class MySQLDialect(default.DefaultDialect): - """Details of the MySQL dialect. Not used directly in application code.""" - name = 'mysql' - supports_alter = True - supports_unicode_statements = False - # identifiers are 64, however aliases can be 255... - max_identifier_length = 255 - supports_sane_rowcount = True - default_paramstyle = 'format' - - def __init__(self, use_ansiquotes=None, **kwargs): - self.use_ansiquotes = use_ansiquotes - default.DefaultDialect.__init__(self, **kwargs) + extract_map = compiler.SQLCompiler.extract_map.copy() + extract_map.update ({ + 'milliseconds': 'millisecond', + }) + + def visit_random_func(self, fn, **kw): + return "rand%s" % self.function_argspec(fn) + + def visit_utc_timestamp_func(self, fn, **kw): + return "UTC_TIMESTAMP" + + def visit_concat_op(self, binary, **kw): + return "concat(%s, %s)" % (self.process(binary.left), self.process(binary.right)) + + def visit_match_op(self, binary, **kw): + return "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % (self.process(binary.left), self.process(binary.right)) + + def visit_typeclause(self, typeclause): + type_ = typeclause.type.dialect_impl(self.dialect) + if isinstance(type_, sqltypes.Integer): + if getattr(type_, 'unsigned', False): + return 'UNSIGNED INTEGER' + else: + return 'SIGNED INTEGER' + elif isinstance(type_, sqltypes.TIMESTAMP): + return 'DATETIME' + elif isinstance(type_, (sqltypes.DECIMAL, sqltypes.DateTime, sqltypes.Date, sqltypes.Time)): + return self.dialect.type_compiler.process(type_) + elif isinstance(type_, sqltypes.Text): + return 'CHAR' + elif (isinstance(type_, sqltypes.String) and not + isinstance(type_, (ENUM, SET))): + if getattr(type_, 'length'): + return 'CHAR(%s)' % type_.length + else: + return 'CHAR' + elif isinstance(type_, sqltypes.Binary): + return 'BINARY' + elif isinstance(type_, NUMERIC): + return self.dialect.type_compiler.process(type_).replace('NUMERIC', 'DECIMAL') + else: + return None - def dbapi(cls): - import MySQLdb as mysql - return mysql - dbapi = classmethod(dbapi) - - def create_connect_args(self, url): - opts = url.translate_connect_args(database='db', username='user', - password='passwd') - opts.update(url.query) - - util.coerce_kw_type(opts, 'compress', bool) - util.coerce_kw_type(opts, 'connect_timeout', int) - util.coerce_kw_type(opts, 'client_flag', int) - util.coerce_kw_type(opts, 'local_infile', int) - # Note: using either of the below will cause all strings to be returned - # as Unicode, both in raw SQL operations and with column types like - # String and MSString. - util.coerce_kw_type(opts, 'use_unicode', bool) - util.coerce_kw_type(opts, 'charset', str) - - # Rich values 'cursorclass' and 'conv' are not supported via - # query string. - - ssl = {} - for key in ['ssl_ca', 'ssl_key', 'ssl_cert', 'ssl_capath', 'ssl_cipher']: - if key in opts: - ssl[key[4:]] = opts[key] - util.coerce_kw_type(ssl, key[4:], str) - del opts[key] - if ssl: - opts['ssl'] = ssl - - # FOUND_ROWS must be set in CLIENT_FLAGS to enable - # supports_sane_rowcount. - client_flag = opts.get('client_flag', 0) - if self.dbapi is not None: - try: - import MySQLdb.constants.CLIENT as CLIENT_FLAGS - client_flag |= CLIENT_FLAGS.FOUND_ROWS - except: - pass - opts['client_flag'] = client_flag - return [[], opts] + def visit_cast(self, cast, **kwargs): + # No cast until 4, no decimals until 5. + type_ = self.process(cast.typeclause) + if type_ is None: + return self.process(cast.clause) - def type_descriptor(self, typeobj): - return sqltypes.adapt_type(typeobj, colspecs) + return 'CAST(%s AS %s)' % (self.process(cast.clause), type_) - def do_executemany(self, cursor, statement, parameters, context=None): - rowcount = cursor.executemany(statement, parameters) - if context is not None: - context._rowcount = rowcount + def get_select_precolumns(self, select): + if isinstance(select._distinct, basestring): + return select._distinct.upper() + " " + elif select._distinct: + return "DISTINCT " + else: + return "" + + def visit_join(self, join, asfrom=False, **kwargs): + # 'JOIN ... ON ...' for inner joins isn't available until 4.0. + # Apparently < 3.23.17 requires theta joins for inner joins + # (but not outer). Not generating these currently, but + # support can be added, preferably after dialects are + # refactored to be version-sensitive. + return ''.join( + (self.process(join.left, asfrom=True), + (join.isouter and " LEFT OUTER JOIN " or " INNER JOIN "), + self.process(join.right, asfrom=True), + " ON ", + self.process(join.onclause))) + + def for_update_clause(self, select): + if select.for_update == 'read': + return ' LOCK IN SHARE MODE' + else: + return super(MySQLCompiler, self).for_update_clause(select) + + def limit_clause(self, select): + # MySQL supports: + # LIMIT + # LIMIT , + # and in server versions > 3.3: + # LIMIT OFFSET + # The latter is more readable for offsets but we're stuck with the + # former until we can refine dialects by server revision. + + limit, offset = select._limit, select._offset + + if (limit, offset) == (None, None): + return '' + elif offset is not None: + # As suggested by the MySQL docs, need to apply an + # artificial limit if one wasn't provided + if limit is None: + limit = 18446744073709551615 + return ' \n LIMIT %s, %s' % (offset, limit) + else: + # No offset provided, so just use the limit + return ' \n LIMIT %s' % (limit,) + + def visit_update(self, update_stmt): + self.stack.append({'from': set([update_stmt.table])}) + + self.isupdate = True + colparams = self._get_colparams(update_stmt) + + text = "UPDATE " + self.preparer.format_table(update_stmt.table) + \ + " SET " + ', '.join(["%s=%s" % (self.preparer.format_column(c[0]), c[1]) for c in colparams]) + + if update_stmt._whereclause: + text += " WHERE " + self.process(update_stmt._whereclause) + + limit = update_stmt.kwargs.get('mysql_limit', None) + if limit: + text += " LIMIT %s" % limit + + self.stack.pop(-1) + + return text + +# ug. "InnoDB needs indexes on foreign keys and referenced keys [...]. +# Starting with MySQL 4.1.2, these indexes are created automatically. +# In older versions, the indexes must be created explicitly or the +# creation of foreign key constraints fails." + +class MySQLDDLCompiler(compiler.DDLCompiler): + def get_column_specification(self, column, **kw): + """Builds column DDL.""" + + colspec = [self.preparer.format_column(column), + self.dialect.type_compiler.process(column.type) + ] + + default = self.get_column_default_string(column) + if default is not None: + colspec.append('DEFAULT ' + default) + + if not column.nullable: + colspec.append('NOT NULL') + + if column.primary_key and column.autoincrement: + try: + first = [c for c in column.table.primary_key.columns + if (c.autoincrement and + isinstance(c.type, sqltypes.Integer) and + not c.foreign_keys)].pop(0) + if column is first: + colspec.append('AUTO_INCREMENT') + except IndexError: + pass + + return ' '.join(colspec) - def supports_unicode_statements(self): - return True + def post_create_table(self, table): + """Build table-level CREATE options like ENGINE and COLLATE.""" + + table_opts = [] + for k in table.kwargs: + if k.startswith('mysql_'): + opt = k[6:].upper() + joiner = '=' + if opt in ('TABLESPACE', 'DEFAULT CHARACTER SET', + 'CHARACTER SET', 'COLLATE'): + joiner = ' ' + + table_opts.append(joiner.join((opt, table.kwargs[k]))) + return ' '.join(table_opts) + + def visit_drop_index(self, drop): + index = drop.element + + return "\nDROP INDEX %s ON %s" % \ + (self.preparer.quote(self._validate_identifier(index.name, False), index.quote), + self.preparer.format_table(index.table)) + + def visit_drop_constraint(self, drop): + constraint = drop.element + if isinstance(constraint, sa_schema.ForeignKeyConstraint): + qual = "FOREIGN KEY " + const = self.preparer.format_constraint(constraint) + elif isinstance(constraint, sa_schema.PrimaryKeyConstraint): + qual = "PRIMARY KEY " + const = "" + elif isinstance(constraint, sa_schema.UniqueConstraint): + qual = "INDEX " + const = self.preparer.format_constraint(constraint) + else: + qual = "" + const = self.preparer.format_constraint(constraint) + return "ALTER TABLE %s DROP %s%s" % \ + (self.preparer.format_table(constraint.table), + qual, const) + +class MySQLTypeCompiler(compiler.GenericTypeCompiler): + def _extend_numeric(self, type_, spec): + "Extend a numeric-type declaration with MySQL specific extensions." + + if not self._mysql_type(type_): + return spec + + if type_.unsigned: + spec += ' UNSIGNED' + if type_.zerofill: + spec += ' ZEROFILL' + return spec + + def _extend_string(self, type_, defaults, spec): + """Extend a string-type declaration with standard SQL CHARACTER SET / + COLLATE annotations and MySQL specific extensions. + + """ + + def attr(name): + return getattr(type_, name, defaults.get(name)) + + if attr('charset'): + charset = 'CHARACTER SET %s' % attr('charset') + elif attr('ascii'): + charset = 'ASCII' + elif attr('unicode'): + charset = 'UNICODE' + else: + charset = None + + if attr('collation'): + collation = 'COLLATE %s' % type_.collation + elif attr('binary'): + collation = 'BINARY' + else: + collation = None + + if attr('national'): + # NATIONAL (aka NCHAR/NVARCHAR) trumps charsets. + return ' '.join([c for c in ('NATIONAL', spec, collation) + if c is not None]) + return ' '.join([c for c in (spec, charset, collation) + if c is not None]) + + def _mysql_type(self, type_): + return isinstance(type_, (_StringType, _NumericType, _BinaryType)) + + def visit_NUMERIC(self, type_): + if type_.precision is None: + return self._extend_numeric(type_, "NUMERIC") + elif type_.scale is None: + return self._extend_numeric(type_, "NUMERIC(%(precision)s)" % {'precision': type_.precision}) + else: + return self._extend_numeric(type_, "NUMERIC(%(precision)s, %(scale)s)" % {'precision': type_.precision, 'scale' : type_.scale}) + + def visit_DECIMAL(self, type_): + if type_.precision is None: + return self._extend_numeric(type_, "DECIMAL") + elif type_.scale is None: + return self._extend_numeric(type_, "DECIMAL(%(precision)s)" % {'precision': type_.precision}) + else: + return self._extend_numeric(type_, "DECIMAL(%(precision)s, %(scale)s)" % {'precision': type_.precision, 'scale' : type_.scale}) + + def visit_DOUBLE(self, type_): + if type_.precision is not None and type_.scale is not None: + return self._extend_numeric(type_, "DOUBLE(%(precision)s, %(scale)s)" % + {'precision': type_.precision, + 'scale' : type_.scale}) + else: + return self._extend_numeric(type_, 'DOUBLE') + + def visit_REAL(self, type_): + if type_.precision is not None and type_.scale is not None: + return self._extend_numeric(type_, "REAL(%(precision)s, %(scale)s)" % + {'precision': type_.precision, + 'scale' : type_.scale}) + else: + return self._extend_numeric(type_, 'REAL') + + def visit_FLOAT(self, type_): + if self._mysql_type(type_) and type_.scale is not None and type_.precision is not None: + return self._extend_numeric(type_, "FLOAT(%s, %s)" % (type_.precision, type_.scale)) + elif type_.precision is not None: + return self._extend_numeric(type_, "FLOAT(%s)" % (type_.precision,)) + else: + return self._extend_numeric(type_, "FLOAT") + + def visit_INTEGER(self, type_): + if self._mysql_type(type_) and type_.display_width is not None: + return self._extend_numeric(type_, "INTEGER(%(display_width)s)" % {'display_width': type_.display_width}) + else: + return self._extend_numeric(type_, "INTEGER") + + def visit_BIGINT(self, type_): + if self._mysql_type(type_) and type_.display_width is not None: + return self._extend_numeric(type_, "BIGINT(%(display_width)s)" % {'display_width': type_.display_width}) + else: + return self._extend_numeric(type_, "BIGINT") + + def visit_MEDIUMINT(self, type_): + if self._mysql_type(type_) and type_.display_width is not None: + return self._extend_numeric(type_, "MEDIUMINT(%(display_width)s)" % {'display_width': type_.display_width}) + else: + return self._extend_numeric(type_, "MEDIUMINT") + + def visit_TINYINT(self, type_): + if self._mysql_type(type_) and type_.display_width is not None: + return self._extend_numeric(type_, "TINYINT(%s)" % type_.display_width) + else: + return self._extend_numeric(type_, "TINYINT") + + def visit_SMALLINT(self, type_): + if self._mysql_type(type_) and type_.display_width is not None: + return self._extend_numeric(type_, "SMALLINT(%(display_width)s)" % {'display_width': type_.display_width}) + else: + return self._extend_numeric(type_, "SMALLINT") + + def visit_BIT(self, type_): + if type_.length is not None: + return "BIT(%s)" % type_.length + else: + return "BIT" + + def visit_DATETIME(self, type_): + return "DATETIME" + + def visit_DATE(self, type_): + return "DATE" + + def visit_TIME(self, type_): + return "TIME" + + def visit_TIMESTAMP(self, type_): + return 'TIMESTAMP' + + def visit_YEAR(self, type_): + if type_.display_width is None: + return "YEAR" + else: + return "YEAR(%s)" % type_.display_width + + def visit_TEXT(self, type_): + if type_.length: + return self._extend_string(type_, {}, "TEXT(%d)" % type_.length) + else: + return self._extend_string(type_, {}, "TEXT") + + def visit_TINYTEXT(self, type_): + return self._extend_string(type_, {}, "TINYTEXT") + + def visit_MEDIUMTEXT(self, type_): + return self._extend_string(type_, {}, "MEDIUMTEXT") + + def visit_LONGTEXT(self, type_): + return self._extend_string(type_, {}, "LONGTEXT") + + def visit_VARCHAR(self, type_): + if type_.length: + return self._extend_string(type_, {}, "VARCHAR(%d)" % type_.length) + else: + return self._extend_string(type_, {}, "VARCHAR") + + def visit_CHAR(self, type_): + return self._extend_string(type_, {'national':True}, "CHAR(%(length)s)" % {'length' : type_.length}) + + def visit_NVARCHAR(self, type_): + # We'll actually generate the equiv. "NATIONAL VARCHAR" instead + # of "NVARCHAR". + return self._extend_string(type_, {'national':True}, "VARCHAR(%(length)s)" % {'length': type_.length}) + + def visit_NCHAR(self, type_): + # We'll actually generate the equiv. "NATIONAL CHAR" instead of "NCHAR". + return self._extend_string(type_, {'national':True}, "CHAR(%(length)s)" % {'length': type_.length}) + + def visit_VARBINARY(self, type_): + if type_.length: + return "VARBINARY(%d)" % type_.length + else: + return self.visit_BLOB(type_) + + def visit_binary(self, type_): + return self.visit_BLOB(type_) + + def visit_BINARY(self, type_): + if type_.length: + return "BINARY(%d)" % type_.length + else: + return self.visit_BLOB(type_) + + def visit_BLOB(self, type_): + if type_.length: + return "BLOB(%d)" % type_.length + else: + return "BLOB" + + def visit_TINYBLOB(self, type_): + return "TINYBLOB" + + def visit_MEDIUMBLOB(self, type_): + return "MEDIUMBLOB" + + def visit_LONGBLOB(self, type_): + return "LONGBLOB" + + def visit_ENUM(self, type_): + quoted_enums = [] + for e in type_.enums: + quoted_enums.append("'%s'" % e.replace("'", "''")) + return self._extend_string(type_, {}, "ENUM(%s)" % ",".join(quoted_enums)) + + def visit_SET(self, type_): + return self._extend_string(type_, {}, "SET(%s)" % ",".join(type_._ddl_values)) + + def visit_BOOLEAN(self, type): + return "BOOL" + + +class MySQLDialect(default.DefaultDialect): + """Details of the MySQL dialect. Not used directly in application code.""" + name = 'mysql' + supports_alter = True + # identifiers are 64, however aliases can be 255... + max_identifier_length = 255 + + supports_sane_rowcount = True + supports_sane_multi_rowcount = False + + default_paramstyle = 'format' + colspecs = colspecs + + statement_compiler = MySQLCompiler + ddl_compiler = MySQLDDLCompiler + type_compiler = MySQLTypeCompiler + ischema_names = ischema_names + + def __init__(self, use_ansiquotes=None, **kwargs): + default.DefaultDialect.__init__(self, **kwargs) def do_commit(self, connection): """Execute a COMMIT.""" @@ -1518,7 +1640,7 @@ class MySQLDialect(default.DefaultDialect): try: connection.commit() except: - if self._server_version_info(connection) < (3, 23, 15): + if self.server_version_info < (3, 23, 15): args = sys.exc_info()[1].args if args and args[0] == 1064: return @@ -1530,59 +1652,66 @@ class MySQLDialect(default.DefaultDialect): try: connection.rollback() except: - if self._server_version_info(connection) < (3, 23, 15): + if self.server_version_info < (3, 23, 15): args = sys.exc_info()[1].args if args and args[0] == 1064: return raise def do_begin_twophase(self, connection, xid): - connection.execute("XA BEGIN %s", xid) + connection.execute(sql.text("XA BEGIN :xid"), xid=xid) def do_prepare_twophase(self, connection, xid): - connection.execute("XA END %s", xid) - connection.execute("XA PREPARE %s", xid) + connection.execute(sql.text("XA END :xid"), xid=xid) + connection.execute(sql.text("XA PREPARE :xid"), xid=xid) def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False): if not is_prepared: - connection.execute("XA END %s", xid) - connection.execute("XA ROLLBACK %s", xid) + connection.execute(sql.text("XA END :xid"), xid=xid) + connection.execute(sql.text("XA ROLLBACK :xid"), xid=xid) def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False): if not is_prepared: self.do_prepare_twophase(connection, xid) - connection.execute("XA COMMIT %s", xid) + connection.execute(sql.text("XA COMMIT :xid"), xid=xid) def do_recover_twophase(self, connection): resultset = connection.execute("XA RECOVER") return [row['data'][0:row['gtrid_length']] for row in resultset] - def do_ping(self, connection): - connection.ping() - def is_disconnect(self, e): if isinstance(e, self.dbapi.OperationalError): - return e.args[0] in (2006, 2013, 2014, 2045, 2055) + return self._extract_error_code(e) in (2006, 2013, 2014, 2045, 2055) elif isinstance(e, self.dbapi.InterfaceError): # if underlying connection is closed, this is the error you get return "(0, '')" in str(e) else: return False + def _compat_fetchall(self, rp, charset=None): + """Proxy result rows to smooth over MySQL-Python driver inconsistencies.""" + + return [_DecodingRowProxy(row, charset) for row in rp.fetchall()] + + def _compat_fetchone(self, rp, charset=None): + """Proxy a result row to smooth over MySQL-Python driver inconsistencies.""" + + return _DecodingRowProxy(rp.fetchone(), charset) + + def _extract_error_code(self, exception): + raise NotImplementedError() + def get_default_schema_name(self, connection): return connection.execute('SELECT DATABASE()').scalar() - get_default_schema_name = engine_base.connection_memoize( - ('dialect', 'default_schema_name'))(get_default_schema_name) def table_names(self, connection, schema): """Return a Unicode SHOW TABLES from a given schema.""" - charset = self._detect_charset(connection) - self._autoset_identifier_style(connection) + charset = self._connection_charset rp = connection.execute("SHOW TABLES FROM %s" % self.identifier_preparer.quote_identifier(schema)) - return [row[0] for row in _compat_fetchall(rp, charset=charset)] + return [row[0] for row in self._compat_fetchall(rp, charset=charset)] def has_table(self, connection, table_name, schema=None): # SHOW TABLE STATUS LIKE and SHOW TABLES LIKE do not function properly @@ -1595,7 +1724,6 @@ class MySQLDialect(default.DefaultDialect): # full_name = self.identifier_preparer.format_table(table, # use_schema=True) - self._autoset_identifier_style(connection) full_name = '.'.join(self.identifier_preparer._quote_free_identifiers( schema, table_name)) @@ -1609,73 +1737,191 @@ class MySQLDialect(default.DefaultDialect): rs.close() return have except exc.SQLError, e: - if e.orig.args[0] == 1146: + if self._extract_error_code(e) == 1146: return False raise finally: if rs: rs.close() + + def initialize(self, connection): + self.server_version_info = self._get_server_version_info(connection) + self._connection_charset = self._detect_charset(connection) + self._server_casing = self._detect_casing(connection) + self._server_collations = self._detect_collations(connection) + self._server_ansiquotes = self._detect_ansiquotes(connection) + + if self._server_ansiquotes: + self.preparer = MySQLANSIIdentifierPreparer + else: + self.preparer = MySQLIdentifierPreparer + self.identifier_preparer = self.preparer(self) + + @reflection.cache + def get_schema_names(self, connection, **kw): + rp = connection.execute("SHOW schemas") + return [r[0] for r in rp] + + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + if schema is None: + schema = self.get_default_schema_name(connection) + if self.server_version_info < (5, 0, 2): + return self.table_names(connection, schema) + charset = self._connection_charset + rp = connection.execute("SHOW FULL TABLES FROM %s" % + self.identifier_preparer.quote_identifier(schema)) + + return [row[0] for row in self._compat_fetchall(rp, charset=charset)\ + if row[1] == 'BASE TABLE'] + + @reflection.cache + def get_view_names(self, connection, schema=None, **kw): + charset = self._connection_charset + if self.server_version_info < (5, 0, 2): + raise NotImplementedError + if schema is None: + schema = self.get_default_schema_name(connection) + if self.server_version_info < (5, 0, 2): + return self.table_names(connection, schema) + charset = self._connection_charset + rp = connection.execute("SHOW FULL TABLES FROM %s" % + self.identifier_preparer.quote_identifier(schema)) + return [row[0] for row in self._compat_fetchall(rp, charset=charset)\ + if row[1] == 'VIEW'] + + @reflection.cache + def get_table_options(self, connection, table_name, schema=None, **kw): + + parsed_state = self._parsed_state_or_create(connection, table_name, schema, **kw) + return parsed_state.table_options + + @reflection.cache + def get_columns(self, connection, table_name, schema=None, **kw): + parsed_state = self._parsed_state_or_create(connection, table_name, schema, **kw) + return parsed_state.columns + + @reflection.cache + def get_primary_keys(self, connection, table_name, schema=None, **kw): + parsed_state = self._parsed_state_or_create(connection, table_name, schema, **kw) + for key in parsed_state.keys: + if key['type'] == 'PRIMARY': + # There can be only one. + ##raise Exception, str(key) + return [s[0] for s in key['columns']] + return [] + + @reflection.cache + def get_foreign_keys(self, connection, table_name, schema=None, **kw): + + parsed_state = self._parsed_state_or_create(connection, table_name, schema, **kw) + default_schema = None + + fkeys = [] - def server_version_info(self, connection): - """A tuple of the database server version. - - Formats the remote server version as a tuple of version values, - e.g. ``(5, 0, 44)``. If there are strings in the version number - they will be in the tuple too, so don't count on these all being - ``int`` values. - - This is a fast check that does not require a round trip. It is also - cached per-Connection. - """ - - return self._server_version_info(connection.connection.connection) - server_version_info = engine_base.connection_memoize( - ('mysql', 'server_version_info'))(server_version_info) + for spec in parsed_state.constraints: + # only FOREIGN KEYs + ref_name = spec['table'][-1] + ref_schema = len(spec['table']) > 1 and spec['table'][-2] or schema - def _server_version_info(self, dbapi_con): - """Convert a MySQL-python server_info string into a tuple.""" + if not ref_schema: + if default_schema is None: + default_schema = \ + connection.dialect.get_default_schema_name(connection) + if schema == default_schema: + ref_schema = schema - version = [] - r = re.compile('[.\-]') - for n in r.split(dbapi_con.get_server_info()): - try: - version.append(int(n)) - except ValueError: - version.append(n) - return tuple(version) + loc_names = spec['local'] + ref_names = spec['foreign'] - def reflecttable(self, connection, table, include_columns): - """Load column definitions from the server.""" + con_kw = {} + for opt in ('name', 'onupdate', 'ondelete'): + if spec.get(opt, False): + con_kw[opt] = spec[opt] - charset = self._detect_charset(connection) - self._autoset_identifier_style(connection) + fkey_d = { + 'name' : spec['name'], + 'constrained_columns' : loc_names, + 'referred_schema' : ref_schema, + 'referred_table' : ref_name, + 'referred_columns' : ref_names, + 'options' : con_kw + } + fkeys.append(fkey_d) + return fkeys + + @reflection.cache + def get_indexes(self, connection, table_name, schema=None, **kw): + + parsed_state = self._parsed_state_or_create(connection, table_name, schema, **kw) + + indexes = [] + for spec in parsed_state.keys: + unique = False + flavor = spec['type'] + if flavor == 'PRIMARY': + continue + if flavor == 'UNIQUE': + unique = True + elif flavor in (None, 'FULLTEXT', 'SPATIAL'): + pass + else: + self.logger.info( + "Converting unknown KEY type %s to a plain KEY" % flavor) + pass + index_d = {} + index_d['name'] = spec['name'] + index_d['column_names'] = [s[0] for s in spec['columns']] + index_d['unique'] = unique + index_d['type'] = flavor + indexes.append(index_d) + return indexes + + @reflection.cache + def get_view_definition(self, connection, view_name, schema=None, **kw): + + charset = self._connection_charset + full_name = '.'.join(self.identifier_preparer._quote_free_identifiers( + schema, view_name)) + sql = self._show_create_table(connection, None, charset, + full_name=full_name) + return sql + def _parsed_state_or_create(self, connection, table_name, schema=None, **kw): + return self._setup_parser( + connection, + table_name, + schema, + info_cache=kw.get('info_cache', None) + ) + + @reflection.cache + def _setup_parser(self, connection, table_name, schema=None, **kw): + charset = self._connection_charset try: - reflector = self.reflector + parser = self.parser except AttributeError: preparer = self.identifier_preparer - if (self.server_version_info(connection) < (4, 1) and - self.use_ansiquotes): + if (self.server_version_info < (4, 1) and + self._server_ansiquotes): # ANSI_QUOTES doesn't affect SHOW CREATE TABLE on < 4.1 preparer = MySQLIdentifierPreparer(self) - - self.reflector = reflector = MySQLSchemaReflector(preparer) - - sql = self._show_create_table(connection, table, charset) + self.parser = parser = MySQLTableDefinitionParser(self, preparer) + full_name = '.'.join(self.identifier_preparer._quote_free_identifiers( + schema, table_name)) + sql = self._show_create_table(connection, None, charset, + full_name=full_name) if sql.startswith('CREATE ALGORITHM'): # Adapt views to something table-like. - columns = self._describe_table(connection, table, charset) - sql = reflector._describe_to_create(table, columns) - - self._adjust_casing(connection, table) - - return reflector.reflect(connection, table, sql, charset, - only=include_columns) - - def _adjust_casing(self, connection, table, charset=None): + columns = self._describe_table(connection, None, charset, + full_name=full_name) + sql = parser._describe_to_create(table_name, columns) + return parser.parse(sql, charset) + + def _adjust_casing(self, table, charset=None): """Adjust Table name to the server case sensitivity, if needed.""" - casing = self._detect_casing(connection) + casing = self._server_casing # For winxx database hosts. TODO: is this really needed? if casing == 1 and table.name != table.name.lower(): @@ -1683,50 +1929,8 @@ class MySQLDialect(default.DefaultDialect): lc_alias = schema._get_table_key(table.name, table.schema) table.metadata.tables[lc_alias] = table - def _detect_charset(self, connection): - """Sniff out the character set in use for connection results.""" - - # Allow user override, won't sniff if force_charset is set. - if ('mysql', 'force_charset') in connection.info: - return connection.info[('mysql', 'force_charset')] - - # Note: MySQL-python 1.2.1c7 seems to ignore changes made - # on a connection via set_character_set() - if self.server_version_info(connection) < (4, 1, 0): - try: - return connection.connection.character_set_name() - except AttributeError: - # < 1.2.1 final MySQL-python drivers have no charset support. - # a query is needed. - pass - - # Prefer 'character_set_results' for the current connection over the - # value in the driver. SET NAMES or individual variable SETs will - # change the charset without updating the driver's view of the world. - # - # If it's decided that issuing that sort of SQL leaves you SOL, then - # this can prefer the driver value. - rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'") - opts = dict([(row[0], row[1]) for row in _compat_fetchall(rs)]) - - if 'character_set_results' in opts: - return opts['character_set_results'] - try: - return connection.connection.character_set_name() - except AttributeError: - # Still no charset on < 1.2.1 final... - if 'character_set' in opts: - return opts['character_set'] - else: - util.warn( - "Could not detect the connection character set with this " - "combination of MySQL server and MySQL-python. " - "MySQL-python >= 1.2.2 is recommended. Assuming latin1.") - return 'latin1' - _detect_charset = engine_base.connection_memoize( - ('mysql', 'charset'))(_detect_charset) - + raise NotImplementedError() def _detect_casing(self, connection): """Sniff out identifier case sensitivity. @@ -1737,8 +1941,8 @@ class MySQLDialect(default.DefaultDialect): """ # http://dev.mysql.com/doc/refman/5.0/en/name-case-sensitivity.html - charset = self._detect_charset(connection) - row = _compat_fetchone(connection.execute( + charset = self._connection_charset + row = self._compat_fetchone(connection.execute( "SHOW VARIABLES LIKE 'lower_case_table_names'"), charset=charset) if not row: @@ -1754,8 +1958,6 @@ class MySQLDialect(default.DefaultDialect): cs = int(row[1]) row.close() return cs - _detect_casing = engine_base.connection_memoize( - ('mysql', 'lower_case_table_names'))(_detect_casing) def _detect_collations(self, connection): """Pull the active COLLATIONS list from the server. @@ -1764,49 +1966,22 @@ class MySQLDialect(default.DefaultDialect): """ collations = {} - if self.server_version_info(connection) < (4, 1, 0): + if self.server_version_info < (4, 1, 0): pass else: - charset = self._detect_charset(connection) + charset = self._connection_charset rs = connection.execute('SHOW COLLATION') - for row in _compat_fetchall(rs, charset): + for row in self._compat_fetchall(rs, charset): collations[row[0]] = row[1] return collations - _detect_collations = engine_base.connection_memoize( - ('mysql', 'collations'))(_detect_collations) - - def use_ansiquotes(self, useansi): - self._use_ansiquotes = useansi - if useansi: - self.preparer = MySQLANSIIdentifierPreparer - else: - self.preparer = MySQLIdentifierPreparer - # icky - if hasattr(self, 'identifier_preparer'): - self.identifier_preparer = self.preparer(self) - if hasattr(self, 'reflector'): - del self.reflector - - use_ansiquotes = property(lambda s: s._use_ansiquotes, use_ansiquotes, - doc="True if ANSI_QUOTES is in effect.") - - def _autoset_identifier_style(self, connection, charset=None): - """Detect and adjust for the ANSI_QUOTES sql mode. - If the dialect's use_ansiquotes is unset, query the server's sql mode - and reset the identifier style. - - Note that this currently *only* runs during reflection. Ideally this - would run the first time a connection pool connects to the database, - but the infrastructure for that is not yet in place. - """ - - if self.use_ansiquotes is not None: - return + def _detect_ansiquotes(self, connection): + """Detect and adjust for the ANSI_QUOTES sql mode.""" - row = _compat_fetchone( + row = self._compat_fetchone( connection.execute("SHOW VARIABLES LIKE 'sql_mode'"), - charset=charset) + charset=self._connection_charset) + if not row: mode = '' else: @@ -1816,7 +1991,7 @@ class MySQLDialect(default.DefaultDialect): mode_no = int(mode) mode = (mode_no | 4 == mode_no) and 'ANSI_QUOTES' or '' - self.use_ansiquotes = 'ANSI_QUOTES' in mode + return 'ANSI_QUOTES' in mode def _show_create_table(self, connection, table, charset=None, full_name=None): @@ -1831,11 +2006,11 @@ class MySQLDialect(default.DefaultDialect): try: rp = connection.execute(st) except exc.SQLError, e: - if e.orig.args[0] == 1146: + if self._extract_error_code(e) == 1146: raise exc.NoSuchTableError(full_name) else: raise - row = _compat_fetchone(rp, charset=charset) + row = self._compat_fetchone(rp, charset=charset) if not row: raise exc.NoSuchTableError(full_name) return row[1].strip() @@ -1858,326 +2033,163 @@ class MySQLDialect(default.DefaultDialect): try: rp = connection.execute(st) except exc.SQLError, e: - if e.orig.args[0] == 1146: + if self._extract_error_code(e) == 1146: raise exc.NoSuchTableError(full_name) else: raise - rows = _compat_fetchall(rp, charset=charset) + rows = self._compat_fetchall(rp, charset=charset) finally: if rp: rp.close() return rows -class _MySQLPythonRowProxy(object): - """Return consistent column values for all versions of MySQL-python. - - Smooth over data type issues (esp. with alpha driver versions) and - normalize strings as Unicode regardless of user-configured driver - encoding settings. - """ - - # Some MySQL-python versions can return some columns as - # sets.Set(['value']) (seriously) but thankfully that doesn't - # seem to come up in DDL queries. - - def __init__(self, rowproxy, charset): - self.rowproxy = rowproxy - self.charset = charset - def __getitem__(self, index): - item = self.rowproxy[index] - if isinstance(item, _array): - item = item.tostring() - if self.charset and isinstance(item, str): - return item.decode(self.charset) - else: - return item - def __getattr__(self, attr): - item = getattr(self.rowproxy, attr) - if isinstance(item, _array): - item = item.tostring() - if self.charset and isinstance(item, str): - return item.decode(self.charset) - else: - return item - - -class MySQLCompiler(compiler.DefaultCompiler): - operators = compiler.DefaultCompiler.operators.copy() - operators.update({ - sql_operators.concat_op: lambda x, y: "concat(%s, %s)" % (x, y), - sql_operators.mod: '%%', - sql_operators.match_op: lambda x, y: "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % (x, y) - }) - functions = compiler.DefaultCompiler.functions.copy() - functions.update ({ - sql_functions.random: 'rand%(expr)s', - "utc_timestamp":"UTC_TIMESTAMP" - }) - - extract_map = compiler.DefaultCompiler.extract_map.copy() - extract_map.update ({ - 'milliseconds': 'millisecond', - }) - - def visit_typeclause(self, typeclause): - type_ = typeclause.type.dialect_impl(self.dialect) - if isinstance(type_, MSInteger): - if getattr(type_, 'unsigned', False): - return 'UNSIGNED INTEGER' - else: - return 'SIGNED INTEGER' - elif isinstance(type_, (MSDecimal, MSDateTime, MSDate, MSTime)): - return type_.get_col_spec() - elif isinstance(type_, MSText): - return 'CHAR' - elif (isinstance(type_, _StringType) and not - isinstance(type_, (MSEnum, MSSet))): - if getattr(type_, 'length'): - return 'CHAR(%s)' % type_.length - else: - return 'CHAR' - elif isinstance(type_, _BinaryType): - return 'BINARY' - elif isinstance(type_, MSNumeric): - return type_.get_col_spec().replace('NUMERIC', 'DECIMAL') - elif isinstance(type_, MSTimeStamp): - return 'DATETIME' - elif isinstance(type_, (MSDateTime, MSDate, MSTime)): - return type_.get_col_spec() - else: - return None - - def visit_cast(self, cast, **kwargs): - # No cast until 4, no decimals until 5. - type_ = self.process(cast.typeclause) - if type_ is None: - return self.process(cast.clause) - - return 'CAST(%s AS %s)' % (self.process(cast.clause), type_) - - - def post_process_text(self, text): - if '%%' in text: - util.warn("The SQLAlchemy MySQLDB dialect now automatically escapes '%' in text() expressions to '%%'.") - return text.replace('%', '%%') - - def get_select_precolumns(self, select): - if isinstance(select._distinct, basestring): - return select._distinct.upper() + " " - elif select._distinct: - return "DISTINCT " - else: - return "" - - def visit_join(self, join, asfrom=False, **kwargs): - # 'JOIN ... ON ...' for inner joins isn't available until 4.0. - # Apparently < 3.23.17 requires theta joins for inner joins - # (but not outer). Not generating these currently, but - # support can be added, preferably after dialects are - # refactored to be version-sensitive. - return ''.join( - (self.process(join.left, asfrom=True), - (join.isouter and " LEFT OUTER JOIN " or " INNER JOIN "), - self.process(join.right, asfrom=True), - " ON ", - self.process(join.onclause))) - - def for_update_clause(self, select): - if select.for_update == 'read': - return ' LOCK IN SHARE MODE' - else: - return super(MySQLCompiler, self).for_update_clause(select) - - def limit_clause(self, select): - # MySQL supports: - # LIMIT - # LIMIT , - # and in server versions > 3.3: - # LIMIT OFFSET - # The latter is more readable for offsets but we're stuck with the - # former until we can refine dialects by server revision. - - limit, offset = select._limit, select._offset - - if (limit, offset) == (None, None): - return '' - elif offset is not None: - # As suggested by the MySQL docs, need to apply an - # artificial limit if one wasn't provided - if limit is None: - limit = 18446744073709551615 - return ' \n LIMIT %s, %s' % (offset, limit) - else: - # No offset provided, so just use the limit - return ' \n LIMIT %s' % (limit,) - - def visit_update(self, update_stmt): - self.stack.append({'from': set([update_stmt.table])}) - - self.isupdate = True - colparams = self._get_colparams(update_stmt) - - text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + ', '.join(["%s=%s" % (self.preparer.format_column(c[0]), c[1]) for c in colparams]) - - if update_stmt._whereclause: - text += " WHERE " + self.process(update_stmt._whereclause) - - limit = update_stmt.kwargs.get('mysql_limit', None) - if limit: - text += " LIMIT %s" % limit - - self.stack.pop(-1) - - return text - -# ug. "InnoDB needs indexes on foreign keys and referenced keys [...]. -# Starting with MySQL 4.1.2, these indexes are created automatically. -# In older versions, the indexes must be created explicitly or the -# creation of foreign key constraints fails." - -class MySQLSchemaGenerator(compiler.SchemaGenerator): - def get_column_specification(self, column, first_pk=False): - """Builds column DDL.""" - - colspec = [self.preparer.format_column(column), - column.type.dialect_impl(self.dialect).get_col_spec()] - - default = self.get_column_default_string(column) - if default is not None: - colspec.append('DEFAULT ' + default) - - if not column.nullable: - colspec.append('NOT NULL') - - if column.primary_key and column.autoincrement: - try: - first = [c for c in column.table.primary_key.columns - if (c.autoincrement and - isinstance(c.type, sqltypes.Integer) and - not c.foreign_keys)].pop(0) - if column is first: - colspec.append('AUTO_INCREMENT') - except IndexError: - pass - - return ' '.join(colspec) - - def post_create_table(self, table): - """Build table-level CREATE options like ENGINE and COLLATE.""" - - table_opts = [] - for k in table.kwargs: - if k.startswith('mysql_'): - opt = k[6:].upper() - joiner = '=' - if opt in ('TABLESPACE', 'DEFAULT CHARACTER SET', - 'CHARACTER SET', 'COLLATE'): - joiner = ' ' - - table_opts.append(joiner.join((opt, table.kwargs[k]))) - return ' '.join(table_opts) - - -class MySQLSchemaDropper(compiler.SchemaDropper): - def visit_index(self, index): - self.append("\nDROP INDEX %s ON %s" % - (self.preparer.quote(self._validate_identifier(index.name, False), index.quote), - self.preparer.format_table(index.table))) - self.execute() - - def drop_foreignkey(self, constraint): - self.append("ALTER TABLE %s DROP FOREIGN KEY %s" % - (self.preparer.format_table(constraint.table), - self.preparer.format_constraint(constraint))) - self.execute() - - -class MySQLSchemaReflector(object): - """Parses SHOW CREATE TABLE output.""" - - def __init__(self, identifier_preparer): - """Construct a MySQLSchemaReflector. - - identifier_preparer - An ANSIIdentifierPreparer type, used to determine the identifier - quoting style in effect. - """ - - self.preparer = identifier_preparer +class ReflectedState(object): + """Stores raw information about a SHOW CREATE TABLE statement.""" + + def __init__(self): + self.columns = [] + self.table_options = {} + self.table_name = None + self.keys = [] + self.constraints = [] + +class MySQLTableDefinitionParser(object): + """Parses the results of a SHOW CREATE TABLE statement.""" + + def __init__(self, dialect, preparer): + self.dialect = dialect + self.preparer = preparer self._prep_regexes() - def reflect(self, connection, table, show_create, charset, only=None): - """Parse MySQL SHOW CREATE TABLE and fill in a ''Table''. - - show_create - Unicode output of SHOW CREATE TABLE - - table - A ''Table'', to be loaded with Columns, Indexes, etc. - table.name will be set if not already - - charset - FIXME, some constructed values (like column defaults) - currently can't be Unicode. ''charset'' will convert them - into the connection character set. - - only - An optional sequence of column names. If provided, only - these columns will be reflected, and any keys or constraints - that include columns outside this set will also be omitted. - That means that if ``only`` includes only one column in a - 2 part primary key, the entire primary key will be omitted. - """ - - keys, constraints = [], [] - - if only: - only = set(only) - + def parse(self, show_create, charset): + state = ReflectedState() + state.charset = charset for line in re.split(r'\r?\n', show_create): if line.startswith(' ' + self.preparer.initial_quote): - self._add_column(table, line, charset, only) + self._parse_column(line, state) # a regular table options line elif line.startswith(') '): - self._set_options(table, line) + self._parse_table_options(line, state) # an ANSI-mode table options line elif line == ')': pass elif line.startswith('CREATE '): - self._set_name(table, line) + self._parse_table_name(line, state) # Not present in real reflection, but may be if loading from a file. elif not line: pass else: - type_, spec = self.parse_constraints(line) + type_, spec = self._parse_constraints(line) if type_ is None: util.warn("Unknown schema content: %r" % line) elif type_ == 'key': - keys.append(spec) + state.keys.append(spec) elif type_ == 'constraint': - constraints.append(spec) + state.constraints.append(spec) else: pass + + return state + + def _parse_constraints(self, line): + """Parse a KEY or CONSTRAINT line. + + line + A line of SHOW CREATE TABLE output + """ + + # KEY + m = self._re_key.match(line) + if m: + spec = m.groupdict() + # convert columns into name, length pairs + spec['columns'] = self._parse_keyexprs(spec['columns']) + return 'key', spec + + # CONSTRAINT + m = self._re_constraint.match(line) + if m: + spec = m.groupdict() + spec['table'] = \ + self.preparer.unformat_identifiers(spec['table']) + spec['local'] = [c[0] + for c in self._parse_keyexprs(spec['local'])] + spec['foreign'] = [c[0] + for c in self._parse_keyexprs(spec['foreign'])] + return 'constraint', spec + + # PARTITION and SUBPARTITION + m = self._re_partition.match(line) + if m: + # Punt! + return 'partition', line + + # No match. + return (None, line) + + def _parse_table_name(self, line, state): + """Extract the table name. + + line + The first line of SHOW CREATE TABLE + """ + + regex, cleanup = self._pr_name + m = regex.match(line) + if m: + state.table_name = cleanup(m.group('name')) + + def _parse_table_options(self, line, state): + """Build a dictionary of all reflected table-level options. + + line + The final line of SHOW CREATE TABLE output. + """ + + options = {} - self._set_keys(table, keys, only) - self._set_constraints(table, constraints, connection, only) + if not line or line == ')': + pass + + else: + r_eq_trim = self._re_options_util['='] + + for regex, cleanup in self._pr_options: + m = regex.search(line) + if not m: + continue + directive, value = m.group('directive'), m.group('val') + directive = r_eq_trim.sub('', directive).lower() + if cleanup: + value = cleanup(value) + options[directive] = value + + for nope in ('auto_increment', 'data_directory', 'index_directory'): + options.pop(nope, None) + + for opt, val in options.items(): + state.table_options['mysql_%s' % opt] = val - def _set_name(self, table, line): - """Override a Table name with the reflected name. + def _parse_column(self, line, state): + """Extract column details. - table - A ``Table`` + Falls back to a 'minimal support' variant if full parse fails. line - The first line of SHOW CREATE TABLE output. + Any column-bearing line from SHOW CREATE TABLE """ - # Don't override by default. - if table.name is None: - table.name = self.parse_name(line) - - def _add_column(self, table, line, charset, only=None): - spec = self.parse_column(line) + charset = state.charset + spec = None + m = self._re_column.match(line) + if m: + spec = m.groupdict() + spec['full'] = True + else: + m = self._re_column_loose.match(line) + if m: + spec = m.groupdict() + spec['full'] = False if not spec: util.warn("Unknown column definition %r" % line) return @@ -2187,18 +2199,13 @@ class MySQLSchemaReflector(object): name, type_, args, notnull = \ spec['name'], spec['coltype'], spec['arg'], spec['notnull'] - if only and name not in only: - self.logger.info("Omitting reflected column %s.%s" % - (table.name, name)) - return - # Convention says that TINYINT(1) columns == BOOLEAN if type_ == 'tinyint' and args == '1': type_ = 'boolean' args = None try: - col_type = ischema_names[type_] + col_type = self.dialect.ischema_names[type_] except KeyError: util.warn("Did not recognize type '%s' of column '%s'" % (type_, name)) @@ -2229,6 +2236,7 @@ class MySQLSchemaReflector(object): col_args, col_kw = [], {} # NOT NULL + col_kw['nullable'] = True if spec.get('notnull', False): col_kw['nullable'] = False @@ -2240,131 +2248,64 @@ class MySQLSchemaReflector(object): # DEFAULT default = spec.get('default', None) - if default is not None and default != 'NULL': - # Defaults should be in the native charset for the moment - default = default.encode(charset) - if type_ == 'timestamp': - # can't be NULL for TIMESTAMPs - if (default[0], default[-1]) != ("'", "'"): - default = sql.text(default) - else: - default = default[1:-1] - col_args.append(schema.DefaultClause(default)) - - table.append_column(schema.Column(name, type_instance, - *col_args, **col_kw)) - def _set_keys(self, table, keys, only): - """Add ``Index`` and ``PrimaryKeyConstraint`` items to a ``Table``. + if default == 'NULL': + # eliminates the need to deal with this later. + default = None + + col_d = dict(name=name, type=type_instance, default=default) + col_d.update(col_kw) + state.columns.append(col_d) - Most of the information gets dropped here- more is reflected than - the schema objects can currently represent. - - table - A ``Table`` + def _describe_to_create(self, table_name, columns): + """Re-format DESCRIBE output as a SHOW CREATE TABLE string. - keys - A sequence of key specifications produced by `constraints` + DESCRIBE is a much simpler reflection and is sufficient for + reflecting views for runtime use. This method formats DDL + for columns only- keys are omitted. - only - Optional `set` of column names. If provided, keys covering - columns not in this set will be omitted. + `columns` is a sequence of DESCRIBE or SHOW COLUMNS 6-tuples. + SHOW FULL COLUMNS FROM rows must be rearranged for use with + this function. """ - for spec in keys: - flavor = spec['type'] - col_names = [s[0] for s in spec['columns']] - - if only and not set(col_names).issubset(only): - if flavor is None: - flavor = 'index' - self.logger.info( - "Omitting %s KEY for (%s), key covers ommitted columns." % - (flavor, ', '.join(col_names))) - continue - - constraint = False - if flavor == 'PRIMARY': - key = schema.PrimaryKeyConstraint() - constraint = True - elif flavor == 'UNIQUE': - key = schema.Index(spec['name'], unique=True) - elif flavor in (None, 'FULLTEXT', 'SPATIAL'): - key = schema.Index(spec['name']) - else: - self.logger.info( - "Converting unknown KEY type %s to a plain KEY" % flavor) - key = schema.Index(spec['name']) - - for col in [table.c[name] for name in col_names]: - key.append_column(col) - - if constraint: - table.append_constraint(key) - - def _set_constraints(self, table, constraints, connection, only): - """Apply constraints to a ``Table``.""" - - default_schema = None - - for spec in constraints: - # only FOREIGN KEYs - ref_name = spec['table'][-1] - ref_schema = len(spec['table']) > 1 and spec['table'][-2] or table.schema - - if not ref_schema: - if default_schema is None: - default_schema = connection.dialect.get_default_schema_name( - connection) - if table.schema == default_schema: - ref_schema = table.schema - - loc_names = spec['local'] - if only and not set(loc_names).issubset(only): - self.logger.info( - "Omitting FOREIGN KEY for (%s), key covers ommitted " - "columns." % (', '.join(loc_names))) - continue - - ref_key = schema._get_table_key(ref_name, ref_schema) - if ref_key in table.metadata.tables: - ref_table = table.metadata.tables[ref_key] - else: - ref_table = schema.Table( - ref_name, table.metadata, schema=ref_schema, - autoload=True, autoload_with=connection) - - ref_names = spec['foreign'] - - if ref_schema: - refspec = [".".join([ref_schema, ref_name, column]) for column in ref_names] - else: - refspec = [".".join([ref_name, column]) for column in ref_names] - - con_kw = {} - for opt in ('name', 'onupdate', 'ondelete'): - if spec.get(opt, False): - con_kw[opt] = spec[opt] - - key = schema.ForeignKeyConstraint(loc_names, refspec, link_to_name=True, **con_kw) - table.append_constraint(key) + buffer = [] + for row in columns: + (name, col_type, nullable, default, extra) = \ + [row[i] for i in (0, 1, 2, 4, 5)] - def _set_options(self, table, line): - """Apply safe reflected table options to a ``Table``. + line = [' '] + line.append(self.preparer.quote_identifier(name)) + line.append(col_type) + if not nullable: + line.append('NOT NULL') + if default: + if 'auto_increment' in default: + pass + elif (col_type.startswith('timestamp') and + default.startswith('C')): + line.append('DEFAULT') + line.append(default) + elif default == 'NULL': + line.append('DEFAULT') + line.append(default) + else: + line.append('DEFAULT') + line.append("'%s'" % default.replace("'", "''")) + if extra: + line.append(extra) - table - A ``Table`` + buffer.append(' '.join(line)) - line - The final line of SHOW CREATE TABLE output. - """ + return ''.join([('CREATE TABLE %s (\n' % + self.preparer.quote_identifier(table_name)), + ',\n'.join(buffer), + '\n) ']) - options = self.parse_table_options(line) - for nope in ('auto_increment', 'data_directory', 'index_directory'): - options.pop(nope, None) + def _parse_keyexprs(self, identifiers): + """Unpack '"col"(2),"col" ASC'-ish strings into components.""" - for opt, val in options.items(): - table.kwargs['mysql_%s' % opt] = val + return self._re_keyexprs.findall(identifiers) def _prep_regexes(self): """Pre-compile regular expressions.""" @@ -2522,154 +2463,42 @@ class MySQLSchemaReflector(object): r'(?P%s)' % (re.escape(directive), regex)) self._pr_options.append(_pr_compile(regex)) +log.class_logger(MySQLTableDefinitionParser) +log.class_logger(MySQLDialect) - def parse_name(self, line): - """Extract the table name. - - line - The first line of SHOW CREATE TABLE - """ - - regex, cleanup = self._pr_name - m = regex.match(line) - if not m: - return None - return cleanup(m.group('name')) - - def parse_column(self, line): - """Extract column details. - - Falls back to a 'minimal support' variant if full parse fails. - - line - Any column-bearing line from SHOW CREATE TABLE - """ - - m = self._re_column.match(line) - if m: - spec = m.groupdict() - spec['full'] = True - return spec - m = self._re_column_loose.match(line) - if m: - spec = m.groupdict() - spec['full'] = False - return spec - return None - - def parse_constraints(self, line): - """Parse a KEY or CONSTRAINT line. - - line - A line of SHOW CREATE TABLE output - """ - - # KEY - m = self._re_key.match(line) - if m: - spec = m.groupdict() - # convert columns into name, length pairs - spec['columns'] = self._parse_keyexprs(spec['columns']) - return 'key', spec - - # CONSTRAINT - m = self._re_constraint.match(line) - if m: - spec = m.groupdict() - spec['table'] = \ - self.preparer.unformat_identifiers(spec['table']) - spec['local'] = [c[0] - for c in self._parse_keyexprs(spec['local'])] - spec['foreign'] = [c[0] - for c in self._parse_keyexprs(spec['foreign'])] - return 'constraint', spec - - # PARTITION and SUBPARTITION - m = self._re_partition.match(line) - if m: - # Punt! - return 'partition', line - - # No match. - return (None, line) - - def parse_table_options(self, line): - """Build a dictionary of all reflected table-level options. - - line - The final line of SHOW CREATE TABLE output. - """ - - options = {} - - if not line or line == ')': - return options - - r_eq_trim = self._re_options_util['='] - - for regex, cleanup in self._pr_options: - m = regex.search(line) - if not m: - continue - directive, value = m.group('directive'), m.group('val') - directive = r_eq_trim.sub('', directive).lower() - if cleanup: - value = cleanup(value) - options[directive] = value - - return options - - def _describe_to_create(self, table, columns): - """Re-format DESCRIBE output as a SHOW CREATE TABLE string. - - DESCRIBE is a much simpler reflection and is sufficient for - reflecting views for runtime use. This method formats DDL - for columns only- keys are omitted. - - `columns` is a sequence of DESCRIBE or SHOW COLUMNS 6-tuples. - SHOW FULL COLUMNS FROM rows must be rearranged for use with - this function. - """ - - buffer = [] - for row in columns: - (name, col_type, nullable, default, extra) = \ - [row[i] for i in (0, 1, 2, 4, 5)] - - line = [' '] - line.append(self.preparer.quote_identifier(name)) - line.append(col_type) - if not nullable: - line.append('NOT NULL') - if default: - if 'auto_increment' in default: - pass - elif (col_type.startswith('timestamp') and - default.startswith('C')): - line.append('DEFAULT') - line.append(default) - elif default == 'NULL': - line.append('DEFAULT') - line.append(default) - else: - line.append('DEFAULT') - line.append("'%s'" % default.replace("'", "''")) - if extra: - line.append(extra) - buffer.append(' '.join(line)) +class _DecodingRowProxy(object): + """Return unicode-decoded values based on type inspection. - return ''.join([('CREATE TABLE %s (\n' % - self.preparer.quote_identifier(table.name)), - ',\n'.join(buffer), - '\n) ']) + Smooth over data type issues (esp. with alpha driver versions) and + normalize strings as Unicode regardless of user-configured driver + encoding settings. - def _parse_keyexprs(self, identifiers): - """Unpack '"col"(2),"col" ASC'-ish strings into components.""" + """ - return self._re_keyexprs.findall(identifiers) + # Some MySQL-python versions can return some columns as + # sets.Set(['value']) (seriously) but thankfully that doesn't + # seem to come up in DDL queries. -log.class_logger(MySQLSchemaReflector) + def __init__(self, rowproxy, charset): + self.rowproxy = rowproxy + self.charset = charset + def __getitem__(self, index): + item = self.rowproxy[index] + if isinstance(item, _array): + item = item.tostring() + if self.charset and isinstance(item, str): + return item.decode(self.charset) + else: + return item + def __getattr__(self, attr): + item = getattr(self.rowproxy, attr) + if isinstance(item, _array): + item = item.tostring() + if self.charset and isinstance(item, str): + return item.decode(self.charset) + else: + return item class _MySQLIdentifierPreparer(compiler.IdentifierPreparer): @@ -2691,7 +2520,7 @@ class MySQLIdentifierPreparer(_MySQLIdentifierPreparer): def __init__(self, dialect): super(MySQLIdentifierPreparer, self).__init__(dialect, initial_quote="`") - + def _escape_identifier(self, value): return value.replace('`', '``') @@ -2704,17 +2533,6 @@ class MySQLANSIIdentifierPreparer(_MySQLIdentifierPreparer): pass - -def _compat_fetchall(rp, charset=None): - """Proxy result rows to smooth over MySQL-Python driver inconsistencies.""" - - return [_MySQLPythonRowProxy(row, charset) for row in rp.fetchall()] - -def _compat_fetchone(rp, charset=None): - """Proxy a result row to smooth over MySQL-Python driver inconsistencies.""" - - return _MySQLPythonRowProxy(rp.fetchone(), charset) - def _pr_compile(regex, cleanup=None): """Prepare a 2-tuple of compiled regex and callable.""" @@ -2725,8 +2543,3 @@ def _re_compile(regex): return re.compile(regex, re.I | re.UNICODE) -dialect = MySQLDialect -dialect.statement_compiler = MySQLCompiler -dialect.schemagenerator = MySQLSchemaGenerator -dialect.schemadropper = MySQLSchemaDropper -dialect.execution_ctx_cls = MySQLExecutionContext diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py new file mode 100644 index 0000000000..6ecfc4b845 --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/mysqldb.py @@ -0,0 +1,194 @@ +"""Support for the MySQL database via the MySQL-python adapter. + +Character Sets +-------------- + +Many MySQL server installations default to a ``latin1`` encoding for client +connections. All data sent through the connection will be converted into +``latin1``, even if you have ``utf8`` or another character set on your tables +and columns. With versions 4.1 and higher, you can change the connection +character set either through server configuration or by including the +``charset`` parameter in the URL used for ``create_engine``. The ``charset`` +option is passed through to MySQL-Python and has the side-effect of also +enabling ``use_unicode`` in the driver by default. For regular encoded +strings, also pass ``use_unicode=0`` in the connection arguments:: + + # set client encoding to utf8; all strings come back as unicode + create_engine('mysql:///mydb?charset=utf8') + + # set client encoding to utf8; all strings come back as utf8 str + create_engine('mysql:///mydb?charset=utf8&use_unicode=0') +""" + +import decimal +import re + +from sqlalchemy.dialects.mysql.base import (DECIMAL, MySQLDialect, MySQLExecutionContext, + MySQLCompiler, NUMERIC, _NumericType) +from sqlalchemy.engine import base as engine_base, default +from sqlalchemy.sql import operators as sql_operators +from sqlalchemy import exc, log, schema, sql, types as sqltypes, util + +class MySQL_mysqldbExecutionContext(MySQLExecutionContext): + + @property + def rowcount(self): + if hasattr(self, '_rowcount'): + return self._rowcount + else: + return self.cursor.rowcount + + +class MySQL_mysqldbCompiler(MySQLCompiler): + def visit_mod(self, binary, **kw): + return self.process(binary.left) + " %% " + self.process(binary.right) + + def post_process_text(self, text): + return text.replace('%', '%%') + + +class _DecimalType(_NumericType): + def result_processor(self, dialect): + if self.asdecimal: + return + def process(value): + if isinstance(value, decimal.Decimal): + return float(value) + else: + return value + return process + + +class _MySQLdbNumeric(_DecimalType, NUMERIC): + pass + + +class _MySQLdbDecimal(_DecimalType, DECIMAL): + pass + + +class MySQL_mysqldb(MySQLDialect): + driver = 'mysqldb' + supports_unicode_statements = False + supports_sane_rowcount = True + supports_sane_multi_rowcount = True + + default_paramstyle = 'format' + execution_ctx_cls = MySQL_mysqldbExecutionContext + statement_compiler = MySQL_mysqldbCompiler + + colspecs = util.update_copy( + MySQLDialect.colspecs, + { + sqltypes.Numeric: _MySQLdbNumeric, + DECIMAL: _MySQLdbDecimal + } + ) + + @classmethod + def dbapi(cls): + return __import__('MySQLdb') + + def do_executemany(self, cursor, statement, parameters, context=None): + rowcount = cursor.executemany(statement, parameters) + if context is not None: + context._rowcount = rowcount + + def create_connect_args(self, url): + opts = url.translate_connect_args(database='db', username='user', + password='passwd') + opts.update(url.query) + + util.coerce_kw_type(opts, 'compress', bool) + util.coerce_kw_type(opts, 'connect_timeout', int) + util.coerce_kw_type(opts, 'client_flag', int) + util.coerce_kw_type(opts, 'local_infile', int) + # Note: using either of the below will cause all strings to be returned + # as Unicode, both in raw SQL operations and with column types like + # String and MSString. + util.coerce_kw_type(opts, 'use_unicode', bool) + util.coerce_kw_type(opts, 'charset', str) + + # Rich values 'cursorclass' and 'conv' are not supported via + # query string. + + ssl = {} + for key in ['ssl_ca', 'ssl_key', 'ssl_cert', 'ssl_capath', 'ssl_cipher']: + if key in opts: + ssl[key[4:]] = opts[key] + util.coerce_kw_type(ssl, key[4:], str) + del opts[key] + if ssl: + opts['ssl'] = ssl + + # FOUND_ROWS must be set in CLIENT_FLAGS to enable + # supports_sane_rowcount. + client_flag = opts.get('client_flag', 0) + if self.dbapi is not None: + try: + from MySQLdb.constants import CLIENT as CLIENT_FLAGS + client_flag |= CLIENT_FLAGS.FOUND_ROWS + except: + pass + opts['client_flag'] = client_flag + return [[], opts] + + def do_ping(self, connection): + connection.ping() + + def _get_server_version_info(self, connection): + dbapi_con = connection.connection + version = [] + r = re.compile('[.\-]') + for n in r.split(dbapi_con.get_server_info()): + try: + version.append(int(n)) + except ValueError: + version.append(n) + return tuple(version) + + def _extract_error_code(self, exception): + try: + return exception.orig.args[0] + except AttributeError: + return None + + def _detect_charset(self, connection): + """Sniff out the character set in use for connection results.""" + + # Note: MySQL-python 1.2.1c7 seems to ignore changes made + # on a connection via set_character_set() + if self.server_version_info < (4, 1, 0): + try: + return connection.connection.character_set_name() + except AttributeError: + # < 1.2.1 final MySQL-python drivers have no charset support. + # a query is needed. + pass + + # Prefer 'character_set_results' for the current connection over the + # value in the driver. SET NAMES or individual variable SETs will + # change the charset without updating the driver's view of the world. + # + # If it's decided that issuing that sort of SQL leaves you SOL, then + # this can prefer the driver value. + rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'") + opts = dict([(row[0], row[1]) for row in self._compat_fetchall(rs)]) + + if 'character_set_results' in opts: + return opts['character_set_results'] + try: + return connection.connection.character_set_name() + except AttributeError: + # Still no charset on < 1.2.1 final... + if 'character_set' in opts: + return opts['character_set'] + else: + util.warn( + "Could not detect the connection character set with this " + "combination of MySQL server and MySQL-python. " + "MySQL-python >= 1.2.2 is recommended. Assuming latin1.") + return 'latin1' + + +dialect = MySQL_mysqldb diff --git a/lib/sqlalchemy/dialects/mysql/pyodbc.py b/lib/sqlalchemy/dialects/mysql/pyodbc.py new file mode 100644 index 0000000000..1ea7ec8646 --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/pyodbc.py @@ -0,0 +1,54 @@ +from sqlalchemy.dialects.mysql.base import MySQLDialect, MySQLExecutionContext +from sqlalchemy.connectors.pyodbc import PyODBCConnector +from sqlalchemy.engine import base as engine_base +from sqlalchemy import util +import re + +class MySQL_pyodbcExecutionContext(MySQLExecutionContext): + + def get_lastrowid(self): + cursor = self.create_cursor() + cursor.execute("SELECT LAST_INSERT_ID()") + lastrowid = cursor.fetchone()[0] + cursor.close() + return lastrowid + +class MySQL_pyodbc(PyODBCConnector, MySQLDialect): + supports_unicode_statements = False + execution_ctx_cls = MySQL_pyodbcExecutionContext + + pyodbc_driver_name = "MySQL" + + def __init__(self, **kw): + # deal with http://code.google.com/p/pyodbc/issues/detail?id=25 + kw.setdefault('convert_unicode', True) + MySQLDialect.__init__(self, **kw) + PyODBCConnector.__init__(self, **kw) + + def _detect_charset(self, connection): + """Sniff out the character set in use for connection results.""" + + # Prefer 'character_set_results' for the current connection over the + # value in the driver. SET NAMES or individual variable SETs will + # change the charset without updating the driver's view of the world. + # + # If it's decided that issuing that sort of SQL leaves you SOL, then + # this can prefer the driver value. + rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'") + opts = dict([(row[0], row[1]) for row in self._compat_fetchall(rs)]) + for key in ('character_set_connection', 'character_set'): + if opts.get(key, None): + return opts[key] + + util.warn("Could not detect the connection character set. Assuming latin1.") + return 'latin1' + + def _extract_error_code(self, exception): + m = re.compile(r"\((\d+)\)").search(str(exception.orig.args)) + c = m.group(1) + if c: + return int(c) + else: + return None + +dialect = MySQL_pyodbc diff --git a/lib/sqlalchemy/dialects/mysql/zxjdbc.py b/lib/sqlalchemy/dialects/mysql/zxjdbc.py new file mode 100644 index 0000000000..6cdc6f4386 --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/zxjdbc.py @@ -0,0 +1,95 @@ +"""Support for the MySQL database via Jython's zxjdbc JDBC connector. + +JDBC Driver +----------- + +The official MySQL JDBC driver is at +http://dev.mysql.com/downloads/connector/j/. + +""" +import re + +from sqlalchemy import types as sqltypes, util +from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector +from sqlalchemy.dialects.mysql.base import BIT, MySQLDialect, MySQLExecutionContext + +class _JDBCBit(BIT): + def result_processor(self, dialect): + """Converts boolean or byte arrays from MySQL Connector/J to longs.""" + def process(value): + if value is None: + return value + if isinstance(value, bool): + return int(value) + v = 0L + for i in value: + v = v << 8 | (i & 0xff) + value = v + return value + return process + + +class MySQL_jdbcExecutionContext(MySQLExecutionContext): + def get_lastrowid(self): + cursor = self.create_cursor() + cursor.execute("SELECT LAST_INSERT_ID()") + lastrowid = cursor.fetchone()[0] + cursor.close() + return lastrowid + + +class MySQL_jdbc(ZxJDBCConnector, MySQLDialect): + execution_ctx_cls = MySQL_jdbcExecutionContext + + jdbc_db_name = 'mysql' + jdbc_driver_name = 'com.mysql.jdbc.Driver' + + colspecs = util.update_copy( + MySQLDialect.colspecs, + { + sqltypes.Time: sqltypes.Time, + BIT: _JDBCBit + } + ) + + def _detect_charset(self, connection): + """Sniff out the character set in use for connection results.""" + # Prefer 'character_set_results' for the current connection over the + # value in the driver. SET NAMES or individual variable SETs will + # change the charset without updating the driver's view of the world. + # + # If it's decided that issuing that sort of SQL leaves you SOL, then + # this can prefer the driver value. + rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'") + opts = dict((row[0], row[1]) for row in self._compat_fetchall(rs)) + for key in ('character_set_connection', 'character_set'): + if opts.get(key, None): + return opts[key] + + util.warn("Could not detect the connection character set. Assuming latin1.") + return 'latin1' + + def _driver_kwargs(self): + """return kw arg dict to be sent to connect().""" + return dict(CHARSET=self.encoding, yearIsDateType='false') + + def _extract_error_code(self, exception): + # e.g.: DBAPIError: (Error) Table 'test.u2' doesn't exist + # [SQLCode: 1146], [SQLState: 42S02] 'DESCRIBE `u2`' () + m = re.compile(r"\[SQLCode\: (\d+)\]").search(str(exception.orig.args)) + c = m.group(1) + if c: + return int(c) + + def _get_server_version_info(self,connection): + dbapi_con = connection.connection + version = [] + r = re.compile('[.\-]') + for n in r.split(dbapi_con.dbversion): + try: + version.append(int(n)) + except ValueError: + version.append(n) + return tuple(version) + +dialect = MySQL_jdbc diff --git a/lib/sqlalchemy/dialects/oracle/__init__.py b/lib/sqlalchemy/dialects/oracle/__init__.py new file mode 100644 index 0000000000..3b4379ab70 --- /dev/null +++ b/lib/sqlalchemy/dialects/oracle/__init__.py @@ -0,0 +1,3 @@ +from sqlalchemy.dialects.oracle import base, cx_oracle, zxjdbc + +base.dialect = cx_oracle.dialect diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py new file mode 100644 index 0000000000..419ccedb16 --- /dev/null +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -0,0 +1,904 @@ +# oracle/base.py +# Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php +"""Support for the Oracle database. + +Oracle version 8 through current (11g at the time of this writing) are supported. + +For information on connecting via specific drivers, see the documentation +for that driver. + +Connect Arguments +----------------- + +The dialect supports several :func:`~sqlalchemy.create_engine()` arguments which +affect the behavior of the dialect regardless of driver in use. + +* *use_ansi* - Use ANSI JOIN constructs (see the section on Oracle 8). Defaults + to ``True``. If ``False``, Oracle-8 compatible constructs are used for joins. + +* *optimize_limits* - defaults to ``False``. see the section on LIMIT/OFFSET. + +Auto Increment Behavior +----------------------- + +SQLAlchemy Table objects which include integer primary keys are usually assumed to have +"autoincrementing" behavior, meaning they can generate their own primary key values upon +INSERT. Since Oracle has no "autoincrement" feature, SQLAlchemy relies upon sequences +to produce these values. With the Oracle dialect, *a sequence must always be explicitly +specified to enable autoincrement*. This is divergent with the majority of documentation +examples which assume the usage of an autoincrement-capable database. To specify sequences, +use the sqlalchemy.schema.Sequence object which is passed to a Column construct:: + + t = Table('mytable', metadata, + Column('id', Integer, Sequence('id_seq'), primary_key=True), + Column(...), ... + ) + +This step is also required when using table reflection, i.e. autoload=True:: + + t = Table('mytable', metadata, + Column('id', Integer, Sequence('id_seq'), primary_key=True), + autoload=True + ) + +Identifier Casing +----------------- + +In Oracle, the data dictionary represents all case insensitive identifier names +using UPPERCASE text. SQLAlchemy on the other hand considers an all-lower case identifier +name to be case insensitive. The Oracle dialect converts all case insensitive identifiers +to and from those two formats during schema level communication, such as reflection of +tables and indexes. Using an UPPERCASE name on the SQLAlchemy side indicates a +case sensitive identifier, and SQLAlchemy will quote the name - this will cause mismatches +against data dictionary data received from Oracle, so unless identifier names have been +truly created as case sensitive (i.e. using quoted names), all lowercase names should be +used on the SQLAlchemy side. + +Unicode +------- + +SQLAlchemy 0.6 uses the "native unicode" mode provided as of cx_oracle 5. cx_oracle 5.0.2 +or greater is recommended for support of NCLOB. If not using cx_oracle 5, the NLS_LANG +environment variable needs to be set in order for the oracle client library to use +proper encoding, such as "AMERICAN_AMERICA.UTF8". + +Also note that Oracle supports unicode data through the NVARCHAR and NCLOB data types. +When using the SQLAlchemy Unicode and UnicodeText types, these DDL types will be used +within CREATE TABLE statements. Usage of VARCHAR2 and CLOB with unicode text still +requires NLS_LANG to be set. + +LIMIT/OFFSET Support +-------------------- + +Oracle has no support for the LIMIT or OFFSET keywords. Whereas previous versions of SQLAlchemy +used the "ROW NUMBER OVER..." construct to simulate LIMIT/OFFSET, SQLAlchemy 0.5 now uses +a wrapped subquery approach in conjunction with ROWNUM. The exact methodology is taken from +http://www.oracle.com/technology/oramag/oracle/06-sep/o56asktom.html . Note that the +"FIRST ROWS()" optimization keyword mentioned is not used by default, as the user community felt +this was stepping into the bounds of optimization that is better left on the DBA side, but this +prefix can be added by enabling the optimize_limits=True flag on create_engine(). + +ON UPDATE CASCADE +----------------- + +Oracle doesn't have native ON UPDATE CASCADE functionality. A trigger based solution +is available at http://asktom.oracle.com/tkyte/update_cascade/index.html . + +When using the SQLAlchemy ORM, the ORM has limited ability to manually issue +cascading updates - specify ForeignKey objects using the +"deferrable=True, initially='deferred'" keyword arguments, +and specify "passive_updates=False" on each relation(). + +Oracle 8 Compatibility +---------------------- + +When using Oracle 8, a "use_ansi=False" flag is available which converts all +JOIN phrases into the WHERE clause, and in the case of LEFT OUTER JOIN +makes use of Oracle's (+) operator. + +Synonym/DBLINK Reflection +------------------------- + +When using reflection with Table objects, the dialect can optionally search for tables +indicated by synonyms that reference DBLINK-ed tables by passing the flag +oracle_resolve_synonyms=True as a keyword argument to the Table construct. If DBLINK +is not in use this flag should be left off. + +""" + +import random, re + +from sqlalchemy import schema as sa_schema +from sqlalchemy import util, sql, log +from sqlalchemy.engine import default, base, reflection +from sqlalchemy.sql import compiler, visitors, expression +from sqlalchemy.sql import operators as sql_operators, functions as sql_functions +from sqlalchemy import types as sqltypes +from sqlalchemy.types import VARCHAR, NVARCHAR, CHAR, DATE, DATETIME, \ + BLOB, CLOB, TIMESTAMP, FLOAT + +RESERVED_WORDS = set('''SHARE RAW DROP BETWEEN FROM DESC OPTION PRIOR LONG THEN DEFAULT ALTER IS INTO MINUS INTEGER NUMBER GRANT IDENTIFIED ALL TO ORDER ON FLOAT DATE HAVING CLUSTER NOWAIT RESOURCE ANY TABLE INDEX FOR UPDATE WHERE CHECK SMALLINT WITH DELETE BY ASC REVOKE LIKE SIZE RENAME NOCOMPRESS NULL GROUP VALUES AS IN VIEW EXCLUSIVE COMPRESS SYNONYM SELECT INSERT EXISTS NOT TRIGGER ELSE CREATE INTERSECT PCTFREE DISTINCT USER CONNECT SET MODE OF UNIQUE VARCHAR2 VARCHAR LOCK OR CHAR DECIMAL UNION PUBLIC AND START UID COMMENT'''.split()) + +class RAW(sqltypes.Binary): + pass +OracleRaw = RAW + +class NCLOB(sqltypes.Text): + __visit_name__ = 'NCLOB' + +VARCHAR2 = VARCHAR +NVARCHAR2 = NVARCHAR + +class NUMBER(sqltypes.Numeric): + __visit_name__ = 'NUMBER' + +class BFILE(sqltypes.Binary): + __visit_name__ = 'BFILE' + +class DOUBLE_PRECISION(sqltypes.Numeric): + __visit_name__ = 'DOUBLE_PRECISION' + +class LONG(sqltypes.Text): + __visit_name__ = 'LONG' + +class _OracleBoolean(sqltypes.Boolean): + def get_dbapi_type(self, dbapi): + return dbapi.NUMBER + + def result_processor(self, dialect): + def process(value): + if value is None: + return None + return value and True or False + return process + + def bind_processor(self, dialect): + def process(value): + if value is True: + return 1 + elif value is False: + return 0 + elif value is None: + return None + else: + return value and True or False + return process + +colspecs = { + sqltypes.Boolean : _OracleBoolean, +} + +ischema_names = { + 'VARCHAR2' : VARCHAR, + 'NVARCHAR2' : NVARCHAR, + 'CHAR' : CHAR, + 'DATE' : DATE, + 'DATETIME' : DATETIME, + 'NUMBER' : NUMBER, + 'BLOB' : BLOB, + 'BFILE' : BFILE, + 'CLOB' : CLOB, + 'NCLOB' : NCLOB, + 'TIMESTAMP' : TIMESTAMP, + 'RAW' : RAW, + 'FLOAT' : FLOAT, + 'DOUBLE PRECISION' : DOUBLE_PRECISION, + 'LONG' : LONG, +} + + +class OracleTypeCompiler(compiler.GenericTypeCompiler): + # Note: + # Oracle DATE == DATETIME + # Oracle does not allow milliseconds in DATE + # Oracle does not support TIME columns + + def visit_datetime(self, type_): + return self.visit_DATE(type_) + + def visit_float(self, type_): + if type_.precision is None: + return "NUMERIC" + else: + return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': type_.precision, 'scale' : 2} + + def visit_unicode(self, type_): + return self.visit_NVARCHAR(type_) + + def visit_VARCHAR(self, type_): + return "VARCHAR(%(length)s)" % {'length' : type_.length} + + def visit_NVARCHAR(self, type_): + return "NVARCHAR2(%(length)s)" % {'length' : type_.length} + + def visit_text(self, type_): + return self.visit_CLOB(type_) + + def visit_unicode_text(self, type_): + return self.visit_NCLOB(type_) + + def visit_binary(self, type_): + return self.visit_BLOB(type_) + + def visit_boolean(self, type_): + return self.visit_SMALLINT(type_) + + def visit_RAW(self, type_): + return "RAW(%(length)s)" % {'length' : type_.length} + +class OracleCompiler(compiler.SQLCompiler): + """Oracle compiler modifies the lexical structure of Select + statements to work under non-ANSI configured Oracle databases, if + the use_ansi flag is False. + """ + + def __init__(self, *args, **kwargs): + super(OracleCompiler, self).__init__(*args, **kwargs) + self.__wheres = {} + self._quoted_bind_names = {} + + def visit_mod(self, binary, **kw): + return "mod(%s, %s)" % (self.process(binary.left), self.process(binary.right)) + + def visit_now_func(self, fn, **kw): + return "CURRENT_TIMESTAMP" + + def visit_char_length_func(self, fn, **kw): + return "LENGTH" + self.function_argspec(fn, **kw) + + def visit_match_op(self, binary, **kw): + return "CONTAINS (%s, %s)" % (self.process(binary.left), self.process(binary.right)) + + def function_argspec(self, fn, **kw): + if len(fn.clauses) > 0: + return compiler.SQLCompiler.function_argspec(self, fn, **kw) + else: + return "" + + def default_from(self): + """Called when a ``SELECT`` statement has no froms, and no ``FROM`` clause is to be appended. + + The Oracle compiler tacks a "FROM DUAL" to the statement. + """ + + return " FROM DUAL" + + def visit_join(self, join, **kwargs): + if self.dialect.use_ansi: + return compiler.SQLCompiler.visit_join(self, join, **kwargs) + else: + return self.process(join.left, asfrom=True) + ", " + self.process(join.right, asfrom=True) + + def _get_nonansi_join_whereclause(self, froms): + clauses = [] + + def visit_join(join): + if join.isouter: + def visit_binary(binary): + if binary.operator == sql_operators.eq: + if binary.left.table is join.right: + binary.left = _OuterJoinColumn(binary.left) + elif binary.right.table is join.right: + binary.right = _OuterJoinColumn(binary.right) + clauses.append(visitors.cloned_traverse(join.onclause, {}, {'binary':visit_binary})) + else: + clauses.append(join.onclause) + + for f in froms: + visitors.traverse(f, {}, {'join':visit_join}) + return sql.and_(*clauses) + + def visit_outer_join_column(self, vc): + return self.process(vc.column) + "(+)" + + def visit_sequence(self, seq): + return self.dialect.identifier_preparer.format_sequence(seq) + ".nextval" + + def visit_alias(self, alias, asfrom=False, **kwargs): + """Oracle doesn't like ``FROM table AS alias``. Is the AS standard SQL??""" + + if asfrom: + alias_name = isinstance(alias.name, expression._generated_label) and \ + self._truncated_identifier("alias", alias.name) or alias.name + + return self.process(alias.original, asfrom=asfrom, **kwargs) + " " + self.preparer.format_alias(alias, alias_name) + else: + return self.process(alias.original, **kwargs) + + def returning_clause(self, stmt, returning_cols): + + def create_out_param(col, i): + bindparam = sql.outparam("ret_%d" % i, type_=col.type) + self.binds[bindparam.key] = bindparam + return self.bindparam_string(self._truncate_bindparam(bindparam)) + + columnlist = list(expression._select_iterables(returning_cols)) + + # within_columns_clause =False so that labels (foo AS bar) don't render + columns = [self.process(c, within_columns_clause=False) for c in columnlist] + + binds = [create_out_param(c, i) for i, c in enumerate(columnlist)] + + return 'RETURNING ' + ', '.join(columns) + " INTO " + ", ".join(binds) + + def _TODO_visit_compound_select(self, select): + """Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle.""" + pass + + def visit_select(self, select, **kwargs): + """Look for ``LIMIT`` and OFFSET in a select statement, and if + so tries to wrap it in a subquery with ``rownum`` criterion. + """ + + if not getattr(select, '_oracle_visit', None): + if not self.dialect.use_ansi: + if self.stack and 'from' in self.stack[-1]: + existingfroms = self.stack[-1]['from'] + else: + existingfroms = None + + froms = select._get_display_froms(existingfroms) + whereclause = self._get_nonansi_join_whereclause(froms) + if whereclause: + select = select.where(whereclause) + select._oracle_visit = True + + if select._limit is not None or select._offset is not None: + # See http://www.oracle.com/technology/oramag/oracle/06-sep/o56asktom.html + # + # Generalized form of an Oracle pagination query: + # select ... from ( + # select /*+ FIRST_ROWS(N) */ ...., rownum as ora_rn from ( + # select distinct ... where ... order by ... + # ) where ROWNUM <= :limit+:offset + # ) where ora_rn > :offset + # Outer select and "ROWNUM as ora_rn" can be dropped if limit=0 + + # TODO: use annotations instead of clone + attr set ? + select = select._generate() + select._oracle_visit = True + + # Wrap the middle select and add the hint + limitselect = sql.select([c for c in select.c]) + if select._limit and self.dialect.optimize_limits: + limitselect = limitselect.prefix_with("/*+ FIRST_ROWS(%d) */" % select._limit) + + limitselect._oracle_visit = True + limitselect._is_wrapper = True + + # If needed, add the limiting clause + if select._limit is not None: + max_row = select._limit + if select._offset is not None: + max_row += select._offset + limitselect.append_whereclause( + sql.literal_column("ROWNUM")<=max_row) + + # If needed, add the ora_rn, and wrap again with offset. + if select._offset is None: + select = limitselect + else: + limitselect = limitselect.column( + sql.literal_column("ROWNUM").label("ora_rn")) + limitselect._oracle_visit = True + limitselect._is_wrapper = True + + offsetselect = sql.select( + [c for c in limitselect.c if c.key!='ora_rn']) + offsetselect._oracle_visit = True + offsetselect._is_wrapper = True + + offsetselect.append_whereclause( + sql.literal_column("ora_rn")>select._offset) + + select = offsetselect + + kwargs['iswrapper'] = getattr(select, '_is_wrapper', False) + return compiler.SQLCompiler.visit_select(self, select, **kwargs) + + def limit_clause(self, select): + return "" + + def for_update_clause(self, select): + if select.for_update == "nowait": + return " FOR UPDATE NOWAIT" + else: + return super(OracleCompiler, self).for_update_clause(select) + +class OracleDDLCompiler(compiler.DDLCompiler): + + def visit_create_sequence(self, create): + return "CREATE SEQUENCE %s" % self.preparer.format_sequence(create.element) + + def visit_drop_sequence(self, drop): + return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element) + + def define_constraint_cascades(self, constraint): + text = "" + if constraint.ondelete is not None: + text += " ON DELETE %s" % constraint.ondelete + + # oracle has no ON UPDATE CASCADE - + # its only available via triggers http://asktom.oracle.com/tkyte/update_cascade/index.html + if constraint.onupdate is not None: + util.warn( + "Oracle does not contain native UPDATE CASCADE " + "functionality - onupdates will not be rendered for foreign keys." + "Consider using deferrable=True, initially='deferred' or triggers.") + + return text + +class OracleDefaultRunner(base.DefaultRunner): + def visit_sequence(self, seq): + return self.execute_string("SELECT " + + self.dialect.identifier_preparer.format_sequence(seq) + + ".nextval FROM DUAL", ()) + +class OracleIdentifierPreparer(compiler.IdentifierPreparer): + + reserved_words = set([x.lower() for x in RESERVED_WORDS]) + illegal_initial_characters = re.compile(r'[0-9_$]') + + def _bindparam_requires_quotes(self, value): + """Return True if the given identifier requires quoting.""" + lc_value = value.lower() + return (lc_value in self.reserved_words + or self.illegal_initial_characters.match(value[0]) + or not self.legal_characters.match(unicode(value)) + ) + + def format_savepoint(self, savepoint): + name = re.sub(r'^_+', '', savepoint.ident) + return super(OracleIdentifierPreparer, self).format_savepoint(savepoint, name) + +class OracleDialect(default.DefaultDialect): + name = 'oracle' + supports_alter = True + supports_unicode_statements = False + supports_unicode_binds = False + max_identifier_length = 30 + supports_sane_rowcount = True + supports_sane_multi_rowcount = False + + supports_sequences = True + sequences_optional = False + postfetch_lastrowid = False + + default_paramstyle = 'named' + colspecs = colspecs + ischema_names = ischema_names + requires_name_normalize = True + + supports_default_values = False + supports_empty_insert = False + + statement_compiler = OracleCompiler + ddl_compiler = OracleDDLCompiler + type_compiler = OracleTypeCompiler + preparer = OracleIdentifierPreparer + defaultrunner = OracleDefaultRunner + + reflection_options = ('oracle_resolve_synonyms', ) + + + def __init__(self, + use_ansi=True, + optimize_limits=False, + **kwargs): + default.DefaultDialect.__init__(self, **kwargs) + self.use_ansi = use_ansi + self.optimize_limits = optimize_limits + +# TODO: implement server_version_info for oracle +# def initialize(self, connection): +# super(OracleDialect, self).initialize(connection) +# self.implicit_returning = self.server_version_info > (10, ) and \ +# self.__dict__.get('implicit_returning', True) + + def do_release_savepoint(self, connection, name): + # Oracle does not support RELEASE SAVEPOINT + pass + + def has_table(self, connection, table_name, schema=None): + if not schema: + schema = self.get_default_schema_name(connection) + cursor = connection.execute( + sql.text("SELECT table_name FROM all_tables " + "WHERE table_name = :name AND owner = :schema_name"), + name=self.denormalize_name(table_name), schema_name=self.denormalize_name(schema)) + return cursor.fetchone() is not None + + def has_sequence(self, connection, sequence_name, schema=None): + if not schema: + schema = self.get_default_schema_name(connection) + cursor = connection.execute( + sql.text("SELECT sequence_name FROM all_sequences " + "WHERE sequence_name = :name AND sequence_owner = :schema_name"), + name=self.denormalize_name(sequence_name), schema_name=self.denormalize_name(schema)) + return cursor.fetchone() is not None + + def normalize_name(self, name): + if name is None: + return None + elif (name.upper() == name and + not self.identifier_preparer._requires_quotes(name.lower().decode(self.encoding))): + return name.lower().decode(self.encoding) + else: + return name.decode(self.encoding) + + def denormalize_name(self, name): + if name is None: + return None + elif name.lower() == name and not self.identifier_preparer._requires_quotes(name.lower()): + return name.upper().encode(self.encoding) + else: + return name.encode(self.encoding) + + def get_default_schema_name(self, connection): + return self.normalize_name(connection.execute('SELECT USER FROM DUAL').scalar()) + + def table_names(self, connection, schema): + # note that table_names() isnt loading DBLINKed or synonym'ed tables + if schema is None: + cursor = connection.execute( + "SELECT table_name FROM all_tables " + "WHERE nvl(tablespace_name, 'no tablespace') NOT IN ('SYSTEM', 'SYSAUX')") + else: + s = sql.text( + "SELECT table_name FROM all_tables " + "WHERE nvl(tablespace_name, 'no tablespace') NOT IN ('SYSTEM', 'SYSAUX') " + "AND OWNER = :owner") + cursor = connection.execute(s, owner=self.denormalize_name(schema)) + return [self.normalize_name(row[0]) for row in cursor] + + def _resolve_synonym(self, connection, desired_owner=None, desired_synonym=None, desired_table=None): + """search for a local synonym matching the given desired owner/name. + + if desired_owner is None, attempts to locate a distinct owner. + + returns the actual name, owner, dblink name, and synonym name if found. + """ + + q = "SELECT owner, table_owner, table_name, db_link, synonym_name FROM all_synonyms WHERE " + clauses = [] + params = {} + if desired_synonym: + clauses.append("synonym_name = :synonym_name") + params['synonym_name'] = desired_synonym + if desired_owner: + clauses.append("table_owner = :desired_owner") + params['desired_owner'] = desired_owner + if desired_table: + clauses.append("table_name = :tname") + params['tname'] = desired_table + + q += " AND ".join(clauses) + + result = connection.execute(sql.text(q), **params) + if desired_owner: + row = result.fetchone() + if row: + return row['table_name'], row['table_owner'], row['db_link'], row['synonym_name'] + else: + return None, None, None, None + else: + rows = result.fetchall() + if len(rows) > 1: + raise AssertionError("There are multiple tables visible to the schema, you must specify owner") + elif len(rows) == 1: + row = rows[0] + return row['table_name'], row['table_owner'], row['db_link'], row['synonym_name'] + else: + return None, None, None, None + + @reflection.cache + def _prepare_reflection_args(self, connection, table_name, schema=None, + resolve_synonyms=False, dblink='', **kw): + + if resolve_synonyms: + actual_name, owner, dblink, synonym = self._resolve_synonym(connection, desired_owner=self.denormalize_name(schema), desired_synonym=self.denormalize_name(table_name)) + else: + actual_name, owner, dblink, synonym = None, None, None, None + if not actual_name: + actual_name = self.denormalize_name(table_name) + if not dblink: + dblink = '' + if not owner: + owner = self.denormalize_name(schema or self.get_default_schema_name(connection)) + return (actual_name, owner, dblink, synonym) + + @reflection.cache + def get_schema_names(self, connection, **kw): + s = "SELECT username FROM all_users ORDER BY username" + cursor = connection.execute(s,) + return [self.normalize_name(row[0]) for row in cursor] + + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + schema = self.denormalize_name(schema or self.get_default_schema_name(connection)) + return self.table_names(connection, schema) + + @reflection.cache + def get_view_names(self, connection, schema=None, **kw): + schema = self.denormalize_name(schema or self.get_default_schema_name(connection)) + s = sql.text("SELECT view_name FROM all_views WHERE owner = :owner") + cursor = connection.execute(s, owner=self.denormalize_name(schema)) + return [self.normalize_name(row[0]) for row in cursor] + + @reflection.cache + def get_columns(self, connection, table_name, schema=None, **kw): + """ + + kw arguments can be: + + oracle_resolve_synonyms + + dblink + + """ + + resolve_synonyms = kw.get('oracle_resolve_synonyms', False) + dblink = kw.get('dblink', '') + info_cache = kw.get('info_cache') + + (table_name, schema, dblink, synonym) = \ + self._prepare_reflection_args(connection, table_name, schema, + resolve_synonyms, dblink, + info_cache=info_cache) + columns = [] + c = connection.execute(sql.text( + "SELECT column_name, data_type, data_length, data_precision, data_scale, " + "nullable, data_default FROM ALL_TAB_COLUMNS%(dblink)s " + "WHERE table_name = :table_name AND owner = :owner" % {'dblink': dblink}), + table_name=table_name, owner=schema) + + for row in c: + (colname, coltype, length, precision, scale, nullable, default) = \ + (self.normalize_name(row[0]), row[1], row[2], row[3], row[4], row[5]=='Y', row[6]) + + # INTEGER if the scale is 0 and precision is null + # NUMBER if the scale and precision are both null + # NUMBER(9,2) if the precision is 9 and the scale is 2 + # NUMBER(3) if the precision is 3 and scale is 0 + #length is ignored except for CHAR and VARCHAR2 + if coltype == 'NUMBER' : + if precision is None and scale is None: + coltype = sqltypes.NUMERIC + elif precision is None and scale == 0: + coltype = sqltypes.INTEGER + else : + coltype = sqltypes.NUMERIC(precision, scale) + elif coltype=='CHAR' or coltype=='VARCHAR2': + coltype = self.ischema_names.get(coltype)(length) + else: + coltype = re.sub(r'\(\d+\)', '', coltype) + try: + coltype = self.ischema_names[coltype] + except KeyError: + util.warn("Did not recognize type '%s' of column '%s'" % + (coltype, colname)) + coltype = sqltypes.NULLTYPE + + cdict = { + 'name': colname, + 'type': coltype, + 'nullable': nullable, + 'default': default, + } + columns.append(cdict) + return columns + + @reflection.cache + def get_indexes(self, connection, table_name, schema=None, + resolve_synonyms=False, dblink='', **kw): + + + info_cache = kw.get('info_cache') + (table_name, schema, dblink, synonym) = \ + self._prepare_reflection_args(connection, table_name, schema, + resolve_synonyms, dblink, + info_cache=info_cache) + indexes = [] + q = sql.text(""" + SELECT a.index_name, a.column_name, b.uniqueness + FROM ALL_IND_COLUMNS%(dblink)s a + INNER JOIN ALL_INDEXES%(dblink)s b + ON a.index_name = b.index_name + AND a.table_owner = b.table_owner + AND a.table_name = b.table_name + WHERE a.table_name = :table_name + AND a.table_owner = :schema + ORDER BY a.index_name, a.column_position""" % {'dblink': dblink}) + rp = connection.execute(q, table_name=self.denormalize_name(table_name), + schema=self.denormalize_name(schema)) + indexes = [] + last_index_name = None + pkeys = self.get_primary_keys(connection, table_name, schema, + resolve_synonyms=resolve_synonyms, + dblink=dblink, + info_cache=kw.get('info_cache')) + uniqueness = dict(NONUNIQUE=False, UNIQUE=True) + for rset in rp: + # don't include the primary key columns + if rset.column_name in [s.upper() for s in pkeys]: + continue + if rset.index_name != last_index_name: + index = dict(name=self.normalize_name(rset.index_name), column_names=[]) + indexes.append(index) + index['unique'] = uniqueness.get(rset.uniqueness, False) + index['column_names'].append(self.normalize_name(rset.column_name)) + last_index_name = rset.index_name + return indexes + + @reflection.cache + def _get_constraint_data(self, connection, table_name, schema=None, + dblink='', **kw): + + rp = connection.execute( + sql.text("""SELECT + ac.constraint_name, + ac.constraint_type, + loc.column_name AS local_column, + rem.table_name AS remote_table, + rem.column_name AS remote_column, + rem.owner AS remote_owner, + loc.position as loc_pos, + rem.position as rem_pos + FROM all_constraints%(dblink)s ac, + all_cons_columns%(dblink)s loc, + all_cons_columns%(dblink)s rem + WHERE ac.table_name = :table_name + AND ac.constraint_type IN ('R','P') + AND ac.owner = :owner + AND ac.owner = loc.owner + AND ac.constraint_name = loc.constraint_name + AND ac.r_owner = rem.owner(+) + AND ac.r_constraint_name = rem.constraint_name(+) + AND (rem.position IS NULL or loc.position=rem.position) + ORDER BY ac.constraint_name, loc.position""" % {'dblink': dblink}), + table_name=table_name, owner=schema) + constraint_data = rp.fetchall() + return constraint_data + + @reflection.cache + def get_primary_keys(self, connection, table_name, schema=None, **kw): + """ + + kw arguments can be: + + oracle_resolve_synonyms + + dblink + + """ + + resolve_synonyms = kw.get('oracle_resolve_synonyms', False) + dblink = kw.get('dblink', '') + info_cache = kw.get('info_cache') + + (table_name, schema, dblink, synonym) = \ + self._prepare_reflection_args(connection, table_name, schema, + resolve_synonyms, dblink, + info_cache=info_cache) + pkeys = [] + constraint_data = self._get_constraint_data(connection, table_name, + schema, dblink, + info_cache=kw.get('info_cache')) + + for row in constraint_data: + #print "ROW:" , row + (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = \ + row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]]) + if cons_type == 'P': + pkeys.append(local_column) + return pkeys + + @reflection.cache + def get_foreign_keys(self, connection, table_name, schema=None, **kw): + """ + + kw arguments can be: + + oracle_resolve_synonyms + + dblink + + """ + + requested_schema = schema # to check later on + resolve_synonyms = kw.get('oracle_resolve_synonyms', False) + dblink = kw.get('dblink', '') + info_cache = kw.get('info_cache') + + (table_name, schema, dblink, synonym) = \ + self._prepare_reflection_args(connection, table_name, schema, + resolve_synonyms, dblink, + info_cache=info_cache) + + constraint_data = self._get_constraint_data(connection, table_name, + schema, dblink, + info_cache=kw.get('info_cache')) + + def fkey_rec(): + return { + 'name' : None, + 'constrained_columns' : [], + 'referred_schema' : None, + 'referred_table' : None, + 'referred_columns' : [] + } + + fkeys = util.defaultdict(fkey_rec) + + for row in constraint_data: + (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = \ + row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]]) + + if cons_type == 'R': + if remote_table is None: + # ticket 363 + util.warn( + ("Got 'None' querying 'table_name' from " + "all_cons_columns%(dblink)s - does the user have " + "proper rights to the table?") % {'dblink':dblink}) + continue + + rec = fkeys[cons_name] + rec['name'] = cons_name + local_cols, remote_cols = rec['constrained_columns'], rec['referred_columns'] + + if not rec['referred_table']: + if resolve_synonyms: + ref_remote_name, ref_remote_owner, ref_dblink, ref_synonym = \ + self._resolve_synonym( + connection, + desired_owner=self.denormalize_name(remote_owner), + desired_table=self.denormalize_name(remote_table) + ) + if ref_synonym: + remote_table = self.normalize_name(ref_synonym) + remote_owner = self.normalize_name(ref_remote_owner) + + rec['referred_table'] = remote_table + + if requested_schema is not None or self.denormalize_name(remote_owner) != schema: + rec['referred_schema'] = remote_owner + + local_cols.append(local_column) + remote_cols.append(remote_column) + + return fkeys.values() + + @reflection.cache + def get_view_definition(self, connection, view_name, schema=None, + resolve_synonyms=False, dblink='', **kw): + info_cache = kw.get('info_cache') + (view_name, schema, dblink, synonym) = \ + self._prepare_reflection_args(connection, view_name, schema, + resolve_synonyms, dblink, + info_cache=info_cache) + s = sql.text(""" + SELECT text FROM all_views + WHERE owner = :schema + AND view_name = :view_name + """) + rp = connection.execute(s, + view_name=view_name, schema=schema).scalar() + if rp: + return rp.decode(self.encoding) + else: + return None + + + +class _OuterJoinColumn(sql.ClauseElement): + __visit_name__ = 'outer_join_column' + + def __init__(self, column): + self.column = column + + + diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py new file mode 100644 index 0000000000..d8a0c445a4 --- /dev/null +++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py @@ -0,0 +1,371 @@ +"""Support for the Oracle database via the cx_oracle driver. + +Driver +------ + +The Oracle dialect uses the cx_oracle driver, available at +http://cx-oracle.sourceforge.net/ . The dialect has several behaviors +which are specifically tailored towards compatibility with this module. + +Connecting +---------- + +Connecting with create_engine() uses the standard URL approach of +``oracle://user:pass@host:port/dbname[?key=value&key=value...]``. If dbname is present, the +host, port, and dbname tokens are converted to a TNS name using the cx_oracle +:func:`makedsn()` function. Otherwise, the host token is taken directly as a TNS name. + +Additional arguments which may be specified either as query string arguments on the +URL, or as keyword arguments to :func:`~sqlalchemy.create_engine()` are: + +* *allow_twophase* - enable two-phase transactions. Defaults to ``True``. + +* *auto_convert_lobs* - defaults to True, see the section on LOB objects. + +* *auto_setinputsizes* - the cx_oracle.setinputsizes() call is issued for all bind parameters. + This is required for LOB datatypes but can be disabled to reduce overhead. Defaults + to ``True``. + +* *mode* - This is given the string value of SYSDBA or SYSOPER, or alternatively an + integer value. This value is only available as a URL query string argument. + +* *threaded* - enable multithreaded access to cx_oracle connections. Defaults + to ``True``. Note that this is the opposite default of cx_oracle itself. + + +LOB Objects +----------- + +cx_oracle presents some challenges when fetching LOB objects. A LOB object in a result set +is presented by cx_oracle as a cx_oracle.LOB object which has a read() method. By default, +SQLAlchemy converts these LOB objects into Python strings. This is for two reasons. First, +the LOB object requires an active cursor association, meaning if you were to fetch many rows +at once such that cx_oracle had to go back to the database and fetch a new batch of rows, +the LOB objects in the already-fetched rows are now unreadable and will raise an error. +SQLA "pre-reads" all LOBs so that their data is fetched before further rows are read. +The size of a "batch of rows" is controlled by the cursor.arraysize value, which SQLAlchemy +defaults to 50 (cx_oracle normally defaults this to one). + +Secondly, the LOB object is not a standard DBAPI return value so SQLAlchemy seeks to +"normalize" the results to look more like other DBAPIs. + +The conversion of LOB objects by this dialect is unique in SQLAlchemy in that it takes place +for all statement executions, even plain string-based statements for which SQLA has no awareness +of result typing. This is so that calls like fetchmany() and fetchall() can work in all cases +without raising cursor errors. The conversion of LOB in all cases, as well as the "prefetch" +of LOB objects, can be disabled using auto_convert_lobs=False. + +Two Phase Transaction Support +----------------------------- + +Two Phase transactions are implemented using XA transactions. Success has been reported of them +working successfully but this should be regarded as an experimental feature. + +""" + +from sqlalchemy.dialects.oracle.base import OracleCompiler, OracleDialect, RESERVED_WORDS +from sqlalchemy.dialects.oracle import base as oracle +from sqlalchemy.engine.default import DefaultExecutionContext +from sqlalchemy.engine import base +from sqlalchemy import types as sqltypes, util +import datetime + +class _OracleDate(sqltypes.Date): + def bind_processor(self, dialect): + return None + + def result_processor(self, dialect): + def process(value): + if not isinstance(value, datetime.datetime): + return value + else: + return value.date() + return process + +class _OracleDateTime(sqltypes.DateTime): + def result_processor(self, dialect): + def process(value): + if value is None or isinstance(value, datetime.datetime): + return value + else: + # convert cx_oracle datetime object returned pre-python 2.4 + return datetime.datetime(value.year, value.month, + value.day,value.hour, value.minute, value.second) + return process + +# Note: +# Oracle DATE == DATETIME +# Oracle does not allow milliseconds in DATE +# Oracle does not support TIME columns + +# only if cx_oracle contains TIMESTAMP +class _OracleTimestamp(sqltypes.TIMESTAMP): + def result_processor(self, dialect): + def process(value): + if value is None or isinstance(value, datetime.datetime): + return value + else: + # convert cx_oracle datetime object returned pre-python 2.4 + return datetime.datetime(value.year, value.month, + value.day,value.hour, value.minute, value.second) + return process + +class _LOBMixin(object): + def result_processor(self, dialect): + super_process = super(_LOBMixin, self).result_processor(dialect) + if not dialect.auto_convert_lobs: + return super_process + lob = dialect.dbapi.LOB + def process(value): + if isinstance(value, lob): + if super_process: + return super_process(value.read()) + else: + return value.read() + else: + if super_process: + return super_process(value) + else: + return value + return process + +class _OracleText(_LOBMixin, sqltypes.Text): + def get_dbapi_type(self, dbapi): + return dbapi.CLOB + +class _OracleUnicodeText(_LOBMixin, sqltypes.UnicodeText): + def get_dbapi_type(self, dbapi): + return dbapi.NCLOB + + +class _OracleBinary(_LOBMixin, sqltypes.Binary): + def get_dbapi_type(self, dbapi): + return dbapi.BLOB + + def bind_processor(self, dialect): + return None + + +class _OracleRaw(_LOBMixin, oracle.RAW): + pass + + +colspecs = { + sqltypes.DateTime : _OracleDateTime, + sqltypes.Date : _OracleDate, + sqltypes.Binary : _OracleBinary, + sqltypes.Boolean : oracle._OracleBoolean, + sqltypes.Text : _OracleText, + sqltypes.UnicodeText : _OracleUnicodeText, + sqltypes.TIMESTAMP : _OracleTimestamp, + oracle.RAW: _OracleRaw, +} + +class Oracle_cx_oracleCompiler(OracleCompiler): + def bindparam_string(self, name): + if self.preparer._bindparam_requires_quotes(name): + quoted_name = '"%s"' % name + self._quoted_bind_names[name] = quoted_name + return OracleCompiler.bindparam_string(self, quoted_name) + else: + return OracleCompiler.bindparam_string(self, name) + +class Oracle_cx_oracleExecutionContext(DefaultExecutionContext): + def pre_exec(self): + quoted_bind_names = getattr(self.compiled, '_quoted_bind_names', {}) + if quoted_bind_names: + for param in self.parameters: + for fromname, toname in self.compiled._quoted_bind_names.iteritems(): + param[toname.encode(self.dialect.encoding)] = param[fromname] + del param[fromname] + + if self.dialect.auto_setinputsizes: + self.set_input_sizes(quoted_bind_names, exclude_types=(self.dialect.dbapi.STRING,)) + + if len(self.compiled_parameters) == 1: + for key in self.compiled.binds: + bindparam = self.compiled.binds[key] + name = self.compiled.bind_names[bindparam] + value = self.compiled_parameters[0][name] + if bindparam.isoutparam: + dbtype = bindparam.type.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) + if not hasattr(self, 'out_parameters'): + self.out_parameters = {} + self.out_parameters[name] = self.cursor.var(dbtype) + self.parameters[0][quoted_bind_names.get(name, name)] = self.out_parameters[name] + + + def create_cursor(self): + c = self._connection.connection.cursor() + if self.dialect.arraysize: + c.cursor.arraysize = self.dialect.arraysize + return c + + def get_result_proxy(self): + if hasattr(self, 'out_parameters'): + if self.compiled_parameters is not None and len(self.compiled_parameters) == 1: + for bind, name in self.compiled.bind_names.iteritems(): + if name in self.out_parameters: + type = bind.type + result_processor = type.dialect_impl(self.dialect).result_processor(self.dialect) + if result_processor is not None: + self.out_parameters[name] = result_processor(self.out_parameters[name].getvalue()) + else: + self.out_parameters[name] = self.out_parameters[name].getvalue() + else: + for k in self.out_parameters: + self.out_parameters[k] = self.out_parameters[k].getvalue() + + if self.cursor.description is not None: + for column in self.cursor.description: + type_code = column[1] + if type_code in self.dialect.ORACLE_BINARY_TYPES: + return base.BufferedColumnResultProxy(self) + + if hasattr(self, 'out_parameters') and \ + self.compiled.returning: + + return ReturningResultProxy(self) + else: + return base.ResultProxy(self) + +class ReturningResultProxy(base.FullyBufferedResultProxy): + """Result proxy which stuffs the _returning clause + outparams into the fetch.""" + + def _cursor_description(self): + returning = self.context.compiled.returning + + ret = [] + for c in returning: + if hasattr(c, 'key'): + ret.append((c.key, c.type)) + else: + ret.append((c.anon_label, c.type)) + return ret + + def _buffer_rows(self): + returning = self.context.compiled.returning + return [tuple(self.context.out_parameters["ret_%d" % i] for i, c in enumerate(returning))] + +class Oracle_cx_oracle(OracleDialect): + execution_ctx_cls = Oracle_cx_oracleExecutionContext + statement_compiler = Oracle_cx_oracleCompiler + driver = "cx_oracle" + colspecs = colspecs + + def __init__(self, + auto_setinputsizes=True, + auto_convert_lobs=True, + threaded=True, + allow_twophase=True, + arraysize=50, **kwargs): + OracleDialect.__init__(self, **kwargs) + self.threaded = threaded + self.arraysize = arraysize + self.allow_twophase = allow_twophase + self.supports_timestamp = self.dbapi is None or hasattr(self.dbapi, 'TIMESTAMP' ) + self.auto_setinputsizes = auto_setinputsizes + self.auto_convert_lobs = auto_convert_lobs + + def vers(num): + return tuple([int(x) for x in num.split('.')]) + + if hasattr(self.dbapi, 'version'): + cx_oracle_ver = vers(self.dbapi.version) + self.supports_unicode_binds = cx_oracle_ver >= (5, 0) + + if self.dbapi is None or not self.auto_convert_lobs or not 'CLOB' in self.dbapi.__dict__: + self.dbapi_type_map = {} + self.ORACLE_BINARY_TYPES = [] + else: + # only use this for LOB objects. using it for strings, dates + # etc. leads to a little too much magic, reflection doesn't know if it should + # expect encoded strings or unicodes, etc. + self.dbapi_type_map = { + self.dbapi.CLOB: oracle.CLOB(), + self.dbapi.NCLOB:oracle.NCLOB(), + self.dbapi.BLOB: oracle.BLOB(), + self.dbapi.BINARY: oracle.RAW(), + } + self.ORACLE_BINARY_TYPES = [getattr(self.dbapi, k) for k in ["BFILE", "CLOB", "NCLOB", "BLOB"] if hasattr(self.dbapi, k)] + + @classmethod + def dbapi(cls): + import cx_Oracle + return cx_Oracle + + def create_connect_args(self, url): + dialect_opts = dict(url.query) + for opt in ('use_ansi', 'auto_setinputsizes', 'auto_convert_lobs', + 'threaded', 'allow_twophase'): + if opt in dialect_opts: + util.coerce_kw_type(dialect_opts, opt, bool) + setattr(self, opt, dialect_opts[opt]) + + if url.database: + # if we have a database, then we have a remote host + port = url.port + if port: + port = int(port) + else: + port = 1521 + dsn = self.dbapi.makedsn(url.host, port, url.database) + else: + # we have a local tnsname + dsn = url.host + + opts = dict( + user=url.username, + password=url.password, + dsn=dsn, + threaded=self.threaded, + twophase=self.allow_twophase, + ) + if 'mode' in url.query: + opts['mode'] = url.query['mode'] + if isinstance(opts['mode'], basestring): + mode = opts['mode'].upper() + if mode == 'SYSDBA': + opts['mode'] = self.dbapi.SYSDBA + elif mode == 'SYSOPER': + opts['mode'] = self.dbapi.SYSOPER + else: + util.coerce_kw_type(opts, 'mode', int) + # Can't set 'handle' or 'pool' via URL query args, use connect_args + + return ([], opts) + + def _get_server_version_info(self, connection): + return tuple(int(x) for x in connection.connection.version.split('.')) + + def is_disconnect(self, e): + if isinstance(e, self.dbapi.InterfaceError): + return "not connected" in str(e) + else: + return "ORA-03114" in str(e) or "ORA-03113" in str(e) + + def create_xid(self): + """create a two-phase transaction ID. + + this id will be passed to do_begin_twophase(), do_rollback_twophase(), + do_commit_twophase(). its format is unspecified.""" + + id = random.randint(0, 2 ** 128) + return (0x1234, "%032x" % id, "%032x" % 9) + + def do_begin_twophase(self, connection, xid): + connection.connection.begin(*xid) + + def do_prepare_twophase(self, connection, xid): + connection.connection.prepare() + + def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False): + self.do_rollback(connection.connection) + + def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False): + self.do_commit(connection.connection) + + def do_recover_twophase(self, connection): + pass + +dialect = Oracle_cx_oracle diff --git a/lib/sqlalchemy/dialects/oracle/zxjdbc.py b/lib/sqlalchemy/dialects/oracle/zxjdbc.py new file mode 100644 index 0000000000..a0ad088b2d --- /dev/null +++ b/lib/sqlalchemy/dialects/oracle/zxjdbc.py @@ -0,0 +1,24 @@ +"""Support for the Oracle database via the zxjdbc JDBC connector.""" +import re + +from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector +from sqlalchemy.dialects.oracle.base import OracleDialect + +class Oracle_jdbc(ZxJDBCConnector, OracleDialect): + + jdbc_db_name = 'oracle' + jdbc_driver_name = 'oracle.jdbc.driver.OracleDriver' + + def create_connect_args(self, url): + hostname = url.host + port = url.port or '1521' + dbname = url.database + jdbc_url = 'jdbc:oracle:thin:@%s:%s:%s' % (hostname, port, dbname) + return [[jdbc_url, url.username, url.password, self.jdbc_driver_name], + self._driver_kwargs()] + + def _get_server_version_info(self, connection): + version = re.search(r'Release ([\d\.]+)', connection.connection.dbversion).group(1) + return tuple(int(x) for x in version.split('.')) + +dialect = Oracle_jdbc diff --git a/lib/sqlalchemy/dialects/postgres.py b/lib/sqlalchemy/dialects/postgres.py new file mode 100644 index 0000000000..e66989fa7d --- /dev/null +++ b/lib/sqlalchemy/dialects/postgres.py @@ -0,0 +1,9 @@ +# backwards compat with the old name +from sqlalchemy.util import warn_deprecated + +warn_deprecated( + "The SQLAlchemy PostgreSQL dialect has been renamed from 'postgres' to 'postgresql'. " + "The new URL format is postgresql[+driver]://:@/" + ) + +from sqlalchemy.dialects.postgresql import * \ No newline at end of file diff --git a/lib/sqlalchemy/dialects/postgresql/__init__.py b/lib/sqlalchemy/dialects/postgresql/__init__.py new file mode 100644 index 0000000000..af9430a2b0 --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/__init__.py @@ -0,0 +1,3 @@ +from sqlalchemy.dialects.postgresql import base, psycopg2, pg8000, zxjdbc + +base.dialect = psycopg2.dialect \ No newline at end of file diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/dialects/postgresql/base.py similarity index 51% rename from lib/sqlalchemy/databases/postgres.py rename to lib/sqlalchemy/dialects/postgresql/base.py index 154d971e35..874907abc1 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1,34 +1,13 @@ -# postgres.py +# postgresql.py # Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -"""Support for the PostgreSQL database. +"""Support for the PostgreSQL database. -Driver ------- - -The psycopg2 driver is supported, available at http://pypi.python.org/pypi/psycopg2/ . -The dialect has several behaviors which are specifically tailored towards compatibility -with this module. - -Note that psycopg1 is **not** supported. - -Connecting ----------- - -URLs are of the form `postgres://user:password@host:port/dbname[?key=value&key=value...]`. - -PostgreSQL-specific keyword arguments which are accepted by :func:`~sqlalchemy.create_engine()` are: - -* *server_side_cursors* - Enable the usage of "server side cursors" for SQL statements which support - this feature. What this essentially means from a psycopg2 point of view is that the cursor is - created using a name, e.g. `connection.cursor('some name')`, which has the effect that result rows - are not immediately pre-fetched and buffered after statement execution, but are instead left - on the server and only retrieved as needed. SQLAlchemy's :class:`~sqlalchemy.engine.base.ResultProxy` - uses special row-buffering behavior when this feature is enabled, such that groups of 100 rows - at a time are fetched over the wire to reduce conversational overhead. +For information on connecting using specific drivers, see the documentation section +regarding that driver. Sequences/SERIAL ---------------- @@ -64,144 +43,76 @@ The dialect supports PG 8.3's ``INSERT..RETURNING`` and ``UPDATE..RETURNING`` sy but must be explicitly enabled on a per-statement basis:: # INSERT..RETURNING - result = table.insert(postgres_returning=[table.c.col1, table.c.col2]).\\ + result = table.insert(postgresql_returning=[table.c.col1, table.c.col2]).\\ values(name='foo') print result.fetchall() # UPDATE..RETURNING - result = table.update(postgres_returning=[table.c.col1, table.c.col2]).\\ + result = table.update(postgresql_returning=[table.c.col1, table.c.col2]).\\ where(table.c.name=='foo').values(name='bar') print result.fetchall() Indexes ------- -PostgreSQL supports partial indexes. To create them pass a postgres_where +PostgreSQL supports partial indexes. To create them pass a postgresql_where option to the Index constructor:: - Index('my_index', my_table.c.id, postgres_where=tbl.c.value > 10) + Index('my_index', my_table.c.id, postgresql_where=tbl.c.value > 10) -Transactions ------------- - -The PostgreSQL dialect fully supports SAVEPOINT and two-phase commit operations. """ -import decimal, random, re, string +import re +from sqlalchemy import schema as sa_schema from sqlalchemy import sql, schema, exc, util -from sqlalchemy.engine import base, default +from sqlalchemy.engine import base, default, reflection from sqlalchemy.sql import compiler, expression from sqlalchemy.sql import operators as sql_operators from sqlalchemy import types as sqltypes +from sqlalchemy.types import INTEGER, BIGINT, SMALLINT, VARCHAR, \ + CHAR, TEXT, FLOAT, NUMERIC, \ + TIMESTAMP, TIME, DATE, BOOLEAN -class PGInet(sqltypes.TypeEngine): - def get_col_spec(self): - return "INET" - -class PGCidr(sqltypes.TypeEngine): - def get_col_spec(self): - return "CIDR" - -class PGMacAddr(sqltypes.TypeEngine): - def get_col_spec(self): - return "MACADDR" - -class PGNumeric(sqltypes.Numeric): - def get_col_spec(self): - if not self.precision: - return "NUMERIC" - else: - return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale} - - def bind_processor(self, dialect): - return None - - def result_processor(self, dialect): - if self.asdecimal: - return None - else: - def process(value): - if isinstance(value, decimal.Decimal): - return float(value) - else: - return value - return process - -class PGFloat(sqltypes.Float): - def get_col_spec(self): - if not self.precision: - return "FLOAT" - else: - return "FLOAT(%(precision)s)" % {'precision': self.precision} - - -class PGInteger(sqltypes.Integer): - def get_col_spec(self): - return "INTEGER" - -class PGSmallInteger(sqltypes.Smallinteger): - def get_col_spec(self): - return "SMALLINT" - -class PGBigInteger(PGInteger): - def get_col_spec(self): - return "BIGINT" - -class PGDateTime(sqltypes.DateTime): - def get_col_spec(self): - return "TIMESTAMP " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE" +class REAL(sqltypes.Float): + __visit_name__ = "REAL" -class PGDate(sqltypes.Date): - def get_col_spec(self): - return "DATE" +class BYTEA(sqltypes.Binary): + __visit_name__ = 'BYTEA' -class PGTime(sqltypes.Time): - def get_col_spec(self): - return "TIME " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE" +class DOUBLE_PRECISION(sqltypes.Float): + __visit_name__ = 'DOUBLE_PRECISION' + +class INET(sqltypes.TypeEngine): + __visit_name__ = "INET" +PGInet = INET -class PGInterval(sqltypes.TypeEngine): - def get_col_spec(self): - return "INTERVAL" +class CIDR(sqltypes.TypeEngine): + __visit_name__ = "CIDR" +PGCidr = CIDR -class PGText(sqltypes.Text): - def get_col_spec(self): - return "TEXT" +class MACADDR(sqltypes.TypeEngine): + __visit_name__ = "MACADDR" +PGMacAddr = MACADDR -class PGString(sqltypes.String): - def get_col_spec(self): - if self.length: - return "VARCHAR(%(length)d)" % {'length' : self.length} - else: - return "VARCHAR" +class INTERVAL(sqltypes.TypeEngine): + __visit_name__ = 'INTERVAL' +PGInterval = INTERVAL -class PGChar(sqltypes.CHAR): - def get_col_spec(self): - if self.length: - return "CHAR(%(length)d)" % {'length' : self.length} - else: - return "CHAR" +class BIT(sqltypes.TypeEngine): + __visit_name__ = 'BIT' +PGBit = BIT -class PGBinary(sqltypes.Binary): - def get_col_spec(self): - return "BYTEA" +class UUID(sqltypes.TypeEngine): + __visit_name__ = 'UUID' +PGUuid = UUID -class PGBoolean(sqltypes.Boolean): - def get_col_spec(self): - return "BOOLEAN" - -class PGBit(sqltypes.TypeEngine): - def get_col_spec(self): - return "BIT" - -class PGUuid(sqltypes.TypeEngine): - def get_col_spec(self): - return "UUID" +class ARRAY(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine): + __visit_name__ = 'ARRAY' -class PGArray(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine): def __init__(self, item_type, mutable=True): if isinstance(item_type, type): item_type = item_type() @@ -259,133 +170,341 @@ class PGArray(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine): return item return [convert_item(item) for item in value] return process - def get_col_spec(self): - return self.item_type.get_col_spec() + '[]' +PGArray = ARRAY colspecs = { - sqltypes.Integer : PGInteger, - sqltypes.Smallinteger : PGSmallInteger, - sqltypes.Numeric : PGNumeric, - sqltypes.Float : PGFloat, - sqltypes.DateTime : PGDateTime, - sqltypes.Date : PGDate, - sqltypes.Time : PGTime, - sqltypes.String : PGString, - sqltypes.Binary : PGBinary, - sqltypes.Boolean : PGBoolean, - sqltypes.Text : PGText, - sqltypes.CHAR: PGChar, + sqltypes.Interval:INTERVAL } ischema_names = { - 'integer' : PGInteger, - 'bigint' : PGBigInteger, - 'smallint' : PGSmallInteger, - 'character varying' : PGString, - 'character' : PGChar, - '"char"' : PGChar, - 'name': PGChar, - 'text' : PGText, - 'numeric' : PGNumeric, - 'float' : PGFloat, - 'real' : PGFloat, - 'inet': PGInet, - 'cidr': PGCidr, - 'uuid':PGUuid, - 'bit':PGBit, - 'macaddr': PGMacAddr, - 'double precision' : PGFloat, - 'timestamp' : PGDateTime, - 'timestamp with time zone' : PGDateTime, - 'timestamp without time zone' : PGDateTime, - 'time with time zone' : PGTime, - 'time without time zone' : PGTime, - 'date' : PGDate, - 'time': PGTime, - 'bytea' : PGBinary, - 'boolean' : PGBoolean, - 'interval':PGInterval, + 'integer' : INTEGER, + 'bigint' : BIGINT, + 'smallint' : SMALLINT, + 'character varying' : VARCHAR, + 'character' : CHAR, + '"char"' : sqltypes.String, + 'name' : sqltypes.String, + 'text' : TEXT, + 'numeric' : NUMERIC, + 'float' : FLOAT, + 'real' : REAL, + 'inet': INET, + 'cidr': CIDR, + 'uuid': UUID, + 'bit':BIT, + 'macaddr': MACADDR, + 'double precision' : DOUBLE_PRECISION, + 'timestamp' : TIMESTAMP, + 'timestamp with time zone' : TIMESTAMP, + 'timestamp without time zone' : TIMESTAMP, + 'time with time zone' : TIME, + 'time without time zone' : TIME, + 'date' : DATE, + 'time': TIME, + 'bytea' : BYTEA, + 'boolean' : BOOLEAN, + 'interval':INTERVAL, } -# TODO: filter out 'FOR UPDATE' statements -SERVER_SIDE_CURSOR_RE = re.compile( - r'\s*SELECT', - re.I | re.UNICODE) - -class PGExecutionContext(default.DefaultExecutionContext): - def create_cursor(self): - # TODO: coverage for server side cursors + select.for_update() - is_server_side = \ - self.dialect.server_side_cursors and \ - ((self.compiled and isinstance(self.compiled.statement, expression.Selectable) - and not getattr(self.compiled.statement, 'for_update', False)) \ - or \ - ( - (not self.compiled or isinstance(self.compiled.statement, expression._TextClause)) - and self.statement and SERVER_SIDE_CURSOR_RE.match(self.statement)) - ) - self.__is_server_side = is_server_side - if is_server_side: - # use server-side cursors: - # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html - ident = "c_%s_%s" % (hex(id(self))[2:], hex(random.randint(0, 65535))[2:]) - return self._connection.connection.cursor(ident) + +class PGCompiler(compiler.SQLCompiler): + + def visit_match_op(self, binary, **kw): + return "%s @@ to_tsquery(%s)" % (self.process(binary.left), self.process(binary.right)) + + def visit_ilike_op(self, binary, **kw): + escape = binary.modifiers.get("escape", None) + return '%s ILIKE %s' % (self.process(binary.left), self.process(binary.right)) \ + + (escape and ' ESCAPE \'%s\'' % escape or '') + + def visit_notilike_op(self, binary, **kw): + escape = binary.modifiers.get("escape", None) + return '%s NOT ILIKE %s' % (self.process(binary.left), self.process(binary.right)) \ + + (escape and ' ESCAPE \'%s\'' % escape or '') + + def post_process_text(self, text): + if '%%' in text: + util.warn("The SQLAlchemy postgresql dialect now automatically escapes '%' in text() expressions to '%%'.") + return text.replace('%', '%%') + + def visit_sequence(self, seq): + if seq.optional: + return None + else: + return "nextval('%s')" % self.preparer.format_sequence(seq) + + def limit_clause(self, select): + text = "" + if select._limit is not None: + text += " \n LIMIT " + str(select._limit) + if select._offset is not None: + if select._limit is None: + text += " \n LIMIT ALL" + text += " OFFSET " + str(select._offset) + return text + + def get_select_precolumns(self, select): + if select._distinct: + if isinstance(select._distinct, bool): + return "DISTINCT " + elif isinstance(select._distinct, (list, tuple)): + return "DISTINCT ON (" + ', '.join( + [(isinstance(col, basestring) and col or self.process(col)) for col in select._distinct] + )+ ") " + else: + return "DISTINCT ON (" + unicode(select._distinct) + ") " + else: + return "" + + def for_update_clause(self, select): + if select.for_update == 'nowait': + return " FOR UPDATE NOWAIT" + else: + return super(PGCompiler, self).for_update_clause(select) + + def returning_clause(self, stmt, returning_cols): + + columns = [ + self.process( + self.label_select_column(None, c, asfrom=False), + within_columns_clause=True, + result_map=self.result_map) + for c in expression._select_iterables(returning_cols) + ] + + return 'RETURNING ' + ', '.join(columns) + + def visit_extract(self, extract, **kwargs): + field = self.extract_map.get(extract.field, extract.field) + return "EXTRACT(%s FROM %s::timestamp)" % ( + field, self.process(extract.expr)) + +class PGDDLCompiler(compiler.DDLCompiler): + def get_column_specification(self, column, **kwargs): + colspec = self.preparer.format_column(column) + if column.primary_key and \ + len(column.foreign_keys)==0 and \ + column.autoincrement and \ + isinstance(column.type, sqltypes.Integer) and \ + not isinstance(column.type, sqltypes.SmallInteger) and \ + (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): + if isinstance(column.type, sqltypes.BigInteger): + colspec += " BIGSERIAL" + else: + colspec += " SERIAL" else: - return self._connection.connection.cursor() + colspec += " " + self.dialect.type_compiler.process(column.type) + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + + if not column.nullable: + colspec += " NOT NULL" + return colspec + + def visit_create_sequence(self, create): + return "CREATE SEQUENCE %s" % self.preparer.format_sequence(create.element) + + def visit_drop_sequence(self, drop): + return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element) + + def visit_create_index(self, create): + preparer = self.preparer + index = create.element + text = "CREATE " + if index.unique: + text += "UNIQUE " + text += "INDEX %s ON %s (%s)" \ + % (preparer.quote(self._validate_identifier(index.name, True), index.quote), + preparer.format_table(index.table), + ', '.join([preparer.format_column(c) for c in index.columns])) + + if "postgres_where" in index.kwargs: + whereclause = index.kwargs['postgres_where'] + util.warn_deprecated("The 'postgres_where' argument has been renamed to 'postgresql_where'.") + elif 'postgresql_where' in index.kwargs: + whereclause = index.kwargs['postgresql_where'] + else: + whereclause = None + + if whereclause is not None: + compiler = self._compile(whereclause, None) + # this might belong to the compiler class + inlined_clause = str(compiler) % dict( + [(key,bind.value) for key,bind in compiler.binds.iteritems()]) + text += " WHERE " + inlined_clause + return text + + +class PGDefaultRunner(base.DefaultRunner): + def __init__(self, context): + base.DefaultRunner.__init__(self, context) + # craete cursor which won't conflict with a server-side cursor + self.cursor = context._connection.connection.cursor() - def get_result_proxy(self): - if self.__is_server_side: - return base.BufferedRowResultProxy(self) + def get_column_default(self, column, isinsert=True): + if column.primary_key: + # pre-execute passive defaults on primary keys + if (isinstance(column.server_default, schema.DefaultClause) and + column.server_default.arg is not None): + return self.execute_string("select %s" % column.server_default.arg) + elif (isinstance(column.type, sqltypes.Integer) and column.autoincrement) \ + and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): + sch = column.table.schema + # TODO: this has to build into the Sequence object so we can get the quoting + # logic from it + if sch is not None: + exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name) + else: + exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name) + + if self.dialect.supports_unicode_statements: + return self.execute_string(exc) + else: + return self.execute_string(exc.encode(self.dialect.encoding)) + + return super(PGDefaultRunner, self).get_column_default(column) + + def visit_sequence(self, seq): + if not seq.optional: + return self.execute_string(("select nextval('%s')" % \ + self.dialect.identifier_preparer.format_sequence(seq))) else: - return base.ResultProxy(self) + return None + +class PGTypeCompiler(compiler.GenericTypeCompiler): + def visit_INET(self, type_): + return "INET" + + def visit_CIDR(self, type_): + return "CIDR" + + def visit_MACADDR(self, type_): + return "MACADDR" + + def visit_FLOAT(self, type_): + if not type_.precision: + return "FLOAT" + else: + return "FLOAT(%(precision)s)" % {'precision': type_.precision} + + def visit_DOUBLE_PRECISION(self, type_): + return "DOUBLE PRECISION" + + def visit_BIGINT(self, type_): + return "BIGINT" + + def visit_datetime(self, type_): + return self.visit_TIMESTAMP(type_) + + def visit_TIMESTAMP(self, type_): + return "TIMESTAMP " + (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE" + + def visit_TIME(self, type_): + return "TIME " + (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE" + + def visit_INTERVAL(self, type_): + return "INTERVAL" + + def visit_BIT(self, type_): + return "BIT" + + def visit_UUID(self, type_): + return "UUID" + + def visit_binary(self, type_): + return self.visit_BYTEA(type_) + + def visit_BYTEA(self, type_): + return "BYTEA" + + def visit_REAL(self, type_): + return "REAL" + + def visit_ARRAY(self, type_): + return self.process(type_.item_type) + '[]' + +class PGIdentifierPreparer(compiler.IdentifierPreparer): + def _unquote_identifier(self, value): + if value[0] == self.initial_quote: + value = value[1:-1].replace('""','"') + return value + +class PGInspector(reflection.Inspector): + + def __init__(self, conn): + reflection.Inspector.__init__(self, conn) + + def get_table_oid(self, table_name, schema=None): + """Return the oid from `table_name` and `schema`.""" + + return self.dialect.get_table_oid(self.conn, table_name, schema, + info_cache=self.info_cache) + class PGDialect(default.DefaultDialect): - name = 'postgres' + name = 'postgresql' supports_alter = True - supports_unicode_statements = False max_identifier_length = 63 supports_sane_rowcount = True - supports_sane_multi_rowcount = False - preexecute_pk_sequences = True - supports_pk_autoincrement = False - default_paramstyle = 'pyformat' + + supports_sequences = True + sequences_optional = True + preexecute_autoincrement_sequences = True + postfetch_lastrowid = False + supports_default_values = True supports_empty_insert = False + default_paramstyle = 'pyformat' + ischema_names = ischema_names + colspecs = colspecs - def __init__(self, server_side_cursors=False, **kwargs): + statement_compiler = PGCompiler + ddl_compiler = PGDDLCompiler + type_compiler = PGTypeCompiler + preparer = PGIdentifierPreparer + defaultrunner = PGDefaultRunner + inspector = PGInspector + isolation_level = None + + def __init__(self, isolation_level=None, **kwargs): default.DefaultDialect.__init__(self, **kwargs) - self.server_side_cursors = server_side_cursors - - def dbapi(cls): - import psycopg2 as psycopg - return psycopg - dbapi = classmethod(dbapi) - - def create_connect_args(self, url): - opts = url.translate_connect_args(username='user') - if 'port' in opts: - opts['port'] = int(opts['port']) - opts.update(url.query) - return ([], opts) + self.isolation_level = isolation_level - def type_descriptor(self, typeobj): - return sqltypes.adapt_type(typeobj, colspecs) + def initialize(self, connection): + super(PGDialect, self).initialize(connection) + self.implicit_returning = self.server_version_info > (8, 3) and \ + self.__dict__.get('implicit_returning', True) + + def visit_pool(self, pool): + if self.isolation_level is not None: + class SetIsolationLevel(object): + def __init__(self, isolation_level): + self.isolation_level = isolation_level + + def connect(self, conn, rec): + cursor = conn.cursor() + cursor.execute("SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL %s" + % self.isolation_level) + cursor.execute("COMMIT") + cursor.close() + pool.add_listener(SetIsolationLevel(self.isolation_level)) def do_begin_twophase(self, connection, xid): self.do_begin(connection.connection) def do_prepare_twophase(self, connection, xid): - connection.execute(sql.text("PREPARE TRANSACTION :tid", bindparams=[sql.bindparam('tid', xid)])) + connection.execute("PREPARE TRANSACTION '%s'" % xid) def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False): if is_prepared: if recover: #FIXME: ugly hack to get out of transaction context when commiting recoverable transactions # Must find out a way how to make the dbapi not open a transaction. - connection.execute(sql.text("ROLLBACK")) - connection.execute(sql.text("ROLLBACK PREPARED :tid", bindparams=[sql.bindparam('tid', xid)])) - connection.execute(sql.text("BEGIN")) + connection.execute("ROLLBACK") + connection.execute("ROLLBACK PREPARED '%s'" % xid) + connection.execute("BEGIN") self.do_rollback(connection.connection) else: self.do_rollback(connection.connection) @@ -393,9 +512,9 @@ class PGDialect(default.DefaultDialect): def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False): if is_prepared: if recover: - connection.execute(sql.text("ROLLBACK")) - connection.execute(sql.text("COMMIT PREPARED :tid", bindparams=[sql.bindparam('tid', xid)])) - connection.execute(sql.text("BEGIN")) + connection.execute("ROLLBACK") + connection.execute("COMMIT PREPARED '%s'" % xid) + connection.execute("BEGIN") self.do_rollback(connection.connection) else: self.do_commit(connection.connection) @@ -405,66 +524,151 @@ class PGDialect(default.DefaultDialect): return [row[0] for row in resultset] def get_default_schema_name(self, connection): - return connection.scalar("select current_schema()", None) - get_default_schema_name = base.connection_memoize( - ('dialect', 'default_schema_name'))(get_default_schema_name) - - def last_inserted_ids(self): - if self.context.last_inserted_ids is None: - raise exc.InvalidRequestError("no INSERT executed, or can't use cursor.lastrowid without PostgreSQL OIDs enabled") - else: - return self.context.last_inserted_ids + return connection.scalar("select current_schema()") def has_table(self, connection, table_name, schema=None): # seems like case gets folded in pg_class... if schema is None: - cursor = connection.execute("""select relname from pg_class c join pg_namespace n on n.oid=c.relnamespace where n.nspname=current_schema() and lower(relname)=%(name)s""", {'name':table_name.lower().encode(self.encoding)}); + cursor = connection.execute( + sql.text("select relname from pg_class c join pg_namespace n on " + "n.oid=c.relnamespace where n.nspname=current_schema() and lower(relname)=:name", + bindparams=[sql.bindparam('name', unicode(table_name.lower()), type_=sqltypes.Unicode)] + ) + ) else: - cursor = connection.execute("""select relname from pg_class c join pg_namespace n on n.oid=c.relnamespace where n.nspname=%(schema)s and lower(relname)=%(name)s""", {'name':table_name.lower().encode(self.encoding), 'schema':schema}); - return bool( not not cursor.rowcount ) + cursor = connection.execute( + sql.text("select relname from pg_class c join pg_namespace n on " + "n.oid=c.relnamespace where n.nspname=:schema and lower(relname)=:name", + bindparams=[sql.bindparam('name', unicode(table_name.lower()), type_=sqltypes.Unicode), + sql.bindparam('schema', unicode(schema), type_=sqltypes.Unicode)] + ) + ) + return bool(cursor.fetchone()) def has_sequence(self, connection, sequence_name): - cursor = connection.execute('''SELECT relname FROM pg_class WHERE relkind = 'S' AND relnamespace IN ( SELECT oid FROM pg_namespace WHERE nspname NOT LIKE 'pg_%%' AND nspname != 'information_schema' AND relname = %(seqname)s);''', {'seqname': sequence_name.encode(self.encoding)}) - return bool(not not cursor.rowcount) - - def is_disconnect(self, e): - if isinstance(e, self.dbapi.OperationalError): - return 'closed the connection' in str(e) or 'connection not open' in str(e) - elif isinstance(e, self.dbapi.InterfaceError): - return 'connection already closed' in str(e) or 'cursor already closed' in str(e) - elif isinstance(e, self.dbapi.ProgrammingError): - # yes, it really says "losed", not "closed" - return "losed the connection unexpectedly" in str(e) - else: - return False + cursor = connection.execute( + sql.text("SELECT relname FROM pg_class WHERE relkind = 'S' AND " + "relnamespace IN ( SELECT oid FROM pg_namespace WHERE nspname NOT LIKE 'pg_%' " + "AND nspname != 'information_schema' AND relname = :seqname)", + bindparams=[sql.bindparam('seqname', unicode(sequence_name), type_=sqltypes.Unicode)] + )) + return bool(cursor.fetchone()) def table_names(self, connection, schema): - s = """ - SELECT relname - FROM pg_class c - WHERE relkind = 'r' - AND '%(schema)s' = (select nspname from pg_namespace n where n.oid = c.relnamespace) - """ % locals() - return [row[0].decode(self.encoding) for row in connection.execute(s)] + result = connection.execute( + sql.text(u"""SELECT relname + FROM pg_class c + WHERE relkind = 'r' + AND '%s' = (select nspname from pg_namespace n where n.oid = c.relnamespace)""" % schema, + typemap = {'relname':sqltypes.Unicode} + ) + ) + return [row[0] for row in result] - def server_version_info(self, connection): + def _get_server_version_info(self, connection): v = connection.execute("select version()").scalar() m = re.match('PostgreSQL (\d+)\.(\d+)\.(\d+)', v) if not m: raise AssertionError("Could not determine version from string '%s'" % v) return tuple([int(x) for x in m.group(1, 2, 3)]) - def reflecttable(self, connection, table, include_columns): - preparer = self.identifier_preparer - if table.schema is not None: + @reflection.cache + def get_table_oid(self, connection, table_name, schema=None, **kw): + """Fetch the oid for schema.table_name. + + Several reflection methods require the table oid. The idea for using + this method is that it can be fetched one time and cached for + subsequent calls. + + """ + table_oid = None + if schema is not None: schema_where_clause = "n.nspname = :schema" - schemaname = table.schema - if isinstance(schemaname, str): - schemaname = schemaname.decode(self.encoding) else: schema_where_clause = "pg_catalog.pg_table_is_visible(c.oid)" - schemaname = None + query = """ + SELECT c.oid + FROM pg_catalog.pg_class c + LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace + WHERE (%s) + AND c.relname = :table_name AND c.relkind in ('r','v') + """ % schema_where_clause + # Since we're binding to unicode, table_name and schema_name must be + # unicode. + table_name = unicode(table_name) + if schema is not None: + schema = unicode(schema) + s = sql.text(query, bindparams=[ + sql.bindparam('table_name', type_=sqltypes.Unicode), + sql.bindparam('schema', type_=sqltypes.Unicode) + ], + typemap={'oid':sqltypes.Integer} + ) + c = connection.execute(s, table_name=table_name, schema=schema) + table_oid = c.scalar() + if table_oid is None: + raise exc.NoSuchTableError(table_name) + return table_oid + + @reflection.cache + def get_schema_names(self, connection, **kw): + s = """ + SELECT nspname + FROM pg_namespace + ORDER BY nspname + """ + rp = connection.execute(s) + # what about system tables? + schema_names = [row[0].decode(self.encoding) for row in rp \ + if not row[0].startswith('pg_')] + return schema_names + + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + if schema is not None: + current_schema = schema + else: + current_schema = self.get_default_schema_name(connection) + table_names = self.table_names(connection, current_schema) + return table_names + + @reflection.cache + def get_view_names(self, connection, schema=None, **kw): + if schema is not None: + current_schema = schema + else: + current_schema = self.get_default_schema_name(connection) + s = """ + SELECT relname + FROM pg_class c + WHERE relkind = 'v' + AND '%(schema)s' = (select nspname from pg_namespace n where n.oid = c.relnamespace) + """ % dict(schema=current_schema) + view_names = [row[0].decode(self.encoding) for row in connection.execute(s)] + return view_names + + @reflection.cache + def get_view_definition(self, connection, view_name, schema=None, **kw): + if schema is not None: + current_schema = schema + else: + current_schema = self.get_default_schema_name(connection) + s = """ + SELECT definition FROM pg_views + WHERE schemaname = :schema + AND viewname = :view_name + """ + rp = connection.execute(sql.text(s), + view_name=view_name, schema=current_schema) + if rp: + view_def = rp.scalar().decode(self.encoding) + return view_def + @reflection.cache + def get_columns(self, connection, table_name, schema=None, **kw): + + table_oid = self.get_table_oid(connection, table_name, schema, + info_cache=kw.get('info_cache')) SQL_COLS = """ SELECT a.attname, pg_catalog.format_type(a.atttypid, a.atttypmod), @@ -473,42 +677,28 @@ class PGDialect(default.DefaultDialect): AS DEFAULT, a.attnotnull, a.attnum, a.attrelid as table_oid FROM pg_catalog.pg_attribute a - WHERE a.attrelid = ( - SELECT c.oid - FROM pg_catalog.pg_class c - LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace - WHERE (%s) - AND c.relname = :table_name AND c.relkind in ('r','v') - ) AND a.attnum > 0 AND NOT a.attisdropped + WHERE a.attrelid = :table_oid + AND a.attnum > 0 AND NOT a.attisdropped ORDER BY a.attnum - """ % schema_where_clause - - s = sql.text(SQL_COLS, bindparams=[sql.bindparam('table_name', type_=sqltypes.Unicode), sql.bindparam('schema', type_=sqltypes.Unicode)], typemap={'attname':sqltypes.Unicode, 'default':sqltypes.Unicode}) - tablename = table.name - if isinstance(tablename, str): - tablename = tablename.decode(self.encoding) - c = connection.execute(s, table_name=tablename, schema=schemaname) + """ + s = sql.text(SQL_COLS, + bindparams=[sql.bindparam('table_oid', type_=sqltypes.Integer)], + typemap={'attname':sqltypes.Unicode, 'default':sqltypes.Unicode} + ) + c = connection.execute(s, table_oid=table_oid) rows = c.fetchall() - - if not rows: - raise exc.NoSuchTableError(table.name) - domains = self._load_domains(connection) - + # format columns + columns = [] for name, format_type, default, notnull, attnum, table_oid in rows: - if include_columns and name not in include_columns: - continue - ## strip (30) from character varying(30) attype = re.search('([^\([]+)', format_type).group(1) nullable = not notnull is_array = format_type.endswith('[]') - try: charlen = re.search('\(([\d,]+)\)', format_type).group(1) except: charlen = False - numericprec = False numericscale = False if attype == 'numeric': @@ -523,35 +713,31 @@ class PGDialect(default.DefaultDialect): if attype == 'integer': numericprec, numericscale = (32, 0) charlen = False - args = [] for a in (charlen, numericprec, numericscale): if a is None: args.append(None) elif a is not False: args.append(int(a)) - kwargs = {} if attype == 'timestamp with time zone': kwargs['timezone'] = True elif attype == 'timestamp without time zone': kwargs['timezone'] = False - - coltype = None - if attype in ischema_names: - coltype = ischema_names[attype] + if attype in self.ischema_names: + coltype = self.ischema_names[attype] else: if attype in domains: domain = domains[attype] - if domain['attype'] in ischema_names: + if domain['attype'] in self.ischema_names: # A table can't override whether the domain is nullable. nullable = domain['nullable'] - if domain['default'] and not default: # It can, however, override the default value, but can't set it to null. default = domain['default'] - coltype = ischema_names[domain['attype']] - + coltype = self.ischema_names[domain['attype']] + else: + coltype = None if coltype: coltype = coltype(*args, **kwargs) if is_array: @@ -560,41 +746,46 @@ class PGDialect(default.DefaultDialect): util.warn("Did not recognize type '%s' of column '%s'" % (attype, name)) coltype = sqltypes.NULLTYPE - - colargs = [] + # adjust the default value + autoincrement = False if default is not None: match = re.search(r"""(nextval\(')([^']+)('.*$)""", default) if match is not None: + autoincrement = True # the default is related to a Sequence - sch = table.schema + sch = schema if '.' not in match.group(2) and sch is not None: # unconditionally quote the schema name. this could # later be enhanced to obey quoting rules / "quote schema" default = match.group(1) + ('"%s"' % sch) + '.' + match.group(2) + match.group(3) - colargs.append(schema.DefaultClause(sql.text(default))) - table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs)) + column_info = dict(name=name, type=coltype, nullable=nullable, + default=default, autoincrement=autoincrement) + columns.append(column_info) + return columns - # Primary keys + @reflection.cache + def get_primary_keys(self, connection, table_name, schema=None, **kw): + table_oid = self.get_table_oid(connection, table_name, schema, + info_cache=kw.get('info_cache')) PK_SQL = """ SELECT attname FROM pg_attribute WHERE attrelid = ( SELECT indexrelid FROM pg_index i - WHERE i.indrelid = :table + WHERE i.indrelid = :table_oid AND i.indisprimary = 't') ORDER BY attnum """ t = sql.text(PK_SQL, typemap={'attname':sqltypes.Unicode}) - c = connection.execute(t, table=table_oid) - for row in c.fetchall(): - pk = row[0] - if pk in table.c: - col = table.c[pk] - table.primary_key.add(col) - if col.default is None: - col.autoincrement = False - - # Foreign keys + c = connection.execute(t, table_oid=table_oid) + primary_keys = [r[0] for r in c.fetchall()] + return primary_keys + + @reflection.cache + def get_foreign_keys(self, connection, table_name, schema=None, **kw): + preparer = self.identifier_preparer + table_oid = self.get_table_oid(connection, table_name, schema, + info_cache=kw.get('info_cache')) FK_SQL = """ SELECT conname, pg_catalog.pg_get_constraintdef(oid, true) as condef FROM pg_catalog.pg_constraint r @@ -604,51 +795,51 @@ class PGDialect(default.DefaultDialect): t = sql.text(FK_SQL, typemap={'conname':sqltypes.Unicode, 'condef':sqltypes.Unicode}) c = connection.execute(t, table=table_oid) + fkeys = [] for conname, condef in c.fetchall(): m = re.search('FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)', condef).groups() (constrained_columns, referred_schema, referred_table, referred_columns) = m constrained_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s*', constrained_columns)] if referred_schema: referred_schema = preparer._unquote_identifier(referred_schema) - elif table.schema is not None and table.schema == self.get_default_schema_name(connection): + elif schema is not None and schema == self.get_default_schema_name(connection): # no schema (i.e. its the default schema), and the table we're # reflecting has the default schema explicit, then use that. # i.e. try to use the user's conventions - referred_schema = table.schema + referred_schema = schema referred_table = preparer._unquote_identifier(referred_table) referred_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s', referred_columns)] - - refspec = [] - if referred_schema is not None: - schema.Table(referred_table, table.metadata, autoload=True, schema=referred_schema, - autoload_with=connection) - for column in referred_columns: - refspec.append(".".join([referred_schema, referred_table, column])) - else: - schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection) - for column in referred_columns: - refspec.append(".".join([referred_table, column])) - - table.append_constraint(schema.ForeignKeyConstraint(constrained_columns, refspec, conname, link_to_name=True)) - - # Indexes + fkey_d = { + 'name' : conname, + 'constrained_columns' : constrained_columns, + 'referred_schema' : referred_schema, + 'referred_table' : referred_table, + 'referred_columns' : referred_columns + } + fkeys.append(fkey_d) + return fkeys + + @reflection.cache + def get_indexes(self, connection, table_name, schema, **kw): + table_oid = self.get_table_oid(connection, table_name, schema, + info_cache=kw.get('info_cache')) IDX_SQL = """ SELECT c.relname, i.indisunique, i.indexprs, i.indpred, a.attname FROM pg_index i, pg_class c, pg_attribute a - WHERE i.indrelid = :table AND i.indexrelid = c.oid + WHERE i.indrelid = :table_oid AND i.indexrelid = c.oid AND a.attrelid = i.indexrelid AND i.indisprimary = 'f' ORDER BY c.relname, a.attnum """ t = sql.text(IDX_SQL, typemap={'attname':sqltypes.Unicode}) - c = connection.execute(t, table=table_oid) - indexes = {} + c = connection.execute(t, table_oid=table_oid) + index_names = {} + indexes = [] sv_idx_name = None for row in c.fetchall(): idx_name, unique, expr, prd, col = row - if expr: - if not idx_name == sv_idx_name: + if idx_name != sv_idx_name: util.warn( "Skipped unsupported reflection of expression-based index %s" % idx_name) @@ -659,16 +850,16 @@ class PGDialect(default.DefaultDialect): "Predicate of partial index %s ignored during reflection" % idx_name) sv_idx_name = idx_name - - if not indexes.has_key(idx_name): - indexes[idx_name] = [unique, []] - indexes[idx_name][1].append(col) - - for name, (unique, columns) in indexes.items(): - schema.Index(name, *[table.columns[c] for c in columns], - **dict(unique=unique)) - - + if idx_name in index_names: + index_d = index_names[idx_name] + else: + index_d = {'column_names':[]} + indexes.append(index_d) + index_names[idx_name] = index_d + index_d['name'] = idx_name + index_d['column_names'].append(col) + index_d['unique'] = unique + return indexes def _load_domains(self, connection): ## Load data types for domains: @@ -705,185 +896,3 @@ class PGDialect(default.DefaultDialect): return domains - -class PGCompiler(compiler.DefaultCompiler): - operators = compiler.DefaultCompiler.operators.copy() - operators.update( - { - sql_operators.mod : '%%', - sql_operators.ilike_op: lambda x, y, escape=None: '%s ILIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''), - sql_operators.notilike_op: lambda x, y, escape=None: '%s NOT ILIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''), - sql_operators.match_op: lambda x, y: '%s @@ to_tsquery(%s)' % (x, y), - } - ) - - functions = compiler.DefaultCompiler.functions.copy() - functions.update ( - { - 'TIMESTAMP':util.deprecated(message="Use a literal string 'timestamp ' instead")(lambda x:'TIMESTAMP %s' % x), - } - ) - - def visit_sequence(self, seq): - if seq.optional: - return None - else: - return "nextval('%s')" % self.preparer.format_sequence(seq) - - def post_process_text(self, text): - if '%%' in text: - util.warn("The SQLAlchemy psycopg2 dialect now automatically escapes '%' in text() expressions to '%%'.") - return text.replace('%', '%%') - - def limit_clause(self, select): - text = "" - if select._limit is not None: - text += " \n LIMIT " + str(select._limit) - if select._offset is not None: - if select._limit is None: - text += " \n LIMIT ALL" - text += " OFFSET " + str(select._offset) - return text - - def get_select_precolumns(self, select): - if select._distinct: - if isinstance(select._distinct, bool): - return "DISTINCT " - elif isinstance(select._distinct, (list, tuple)): - return "DISTINCT ON (" + ', '.join( - [(isinstance(col, basestring) and col or self.process(col)) for col in select._distinct] - )+ ") " - else: - return "DISTINCT ON (" + unicode(select._distinct) + ") " - else: - return "" - - def for_update_clause(self, select): - if select.for_update == 'nowait': - return " FOR UPDATE NOWAIT" - else: - return super(PGCompiler, self).for_update_clause(select) - - def _append_returning(self, text, stmt): - returning_cols = stmt.kwargs['postgres_returning'] - def flatten_columnlist(collist): - for c in collist: - if isinstance(c, expression.Selectable): - for co in c.columns: - yield co - else: - yield c - columns = [self.process(c, within_columns_clause=True) for c in flatten_columnlist(returning_cols)] - text += ' RETURNING ' + string.join(columns, ', ') - return text - - def visit_update(self, update_stmt): - text = super(PGCompiler, self).visit_update(update_stmt) - if 'postgres_returning' in update_stmt.kwargs: - return self._append_returning(text, update_stmt) - else: - return text - - def visit_insert(self, insert_stmt): - text = super(PGCompiler, self).visit_insert(insert_stmt) - if 'postgres_returning' in insert_stmt.kwargs: - return self._append_returning(text, insert_stmt) - else: - return text - - def visit_extract(self, extract, **kwargs): - field = self.extract_map.get(extract.field, extract.field) - return "EXTRACT(%s FROM %s::timestamp)" % ( - field, self.process(extract.expr)) - - -class PGSchemaGenerator(compiler.SchemaGenerator): - def get_column_specification(self, column, **kwargs): - colspec = self.preparer.format_column(column) - if column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and isinstance(column.type, sqltypes.Integer) and not isinstance(column.type, sqltypes.SmallInteger) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): - if isinstance(column.type, PGBigInteger): - colspec += " BIGSERIAL" - else: - colspec += " SERIAL" - else: - colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec() - default = self.get_column_default_string(column) - if default is not None: - colspec += " DEFAULT " + default - - if not column.nullable: - colspec += " NOT NULL" - return colspec - - def visit_sequence(self, sequence): - if not sequence.optional and (not self.checkfirst or not self.dialect.has_sequence(self.connection, sequence.name)): - self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence)) - self.execute() - - def visit_index(self, index): - preparer = self.preparer - self.append("CREATE ") - if index.unique: - self.append("UNIQUE ") - self.append("INDEX %s ON %s (%s)" \ - % (preparer.quote(self._validate_identifier(index.name, True), index.quote), - preparer.format_table(index.table), - string.join([preparer.format_column(c) for c in index.columns], ', '))) - whereclause = index.kwargs.get('postgres_where', None) - if whereclause is not None: - compiler = self._compile(whereclause, None) - # this might belong to the compiler class - inlined_clause = str(compiler) % dict( - [(key,bind.value) for key,bind in compiler.binds.iteritems()]) - self.append(" WHERE " + inlined_clause) - self.execute() - -class PGSchemaDropper(compiler.SchemaDropper): - def visit_sequence(self, sequence): - if not sequence.optional and (not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name)): - self.append("DROP SEQUENCE %s" % self.preparer.format_sequence(sequence)) - self.execute() - -class PGDefaultRunner(base.DefaultRunner): - def __init__(self, context): - base.DefaultRunner.__init__(self, context) - # craete cursor which won't conflict with a server-side cursor - self.cursor = context._connection.connection.cursor() - - def get_column_default(self, column, isinsert=True): - if column.primary_key: - # pre-execute passive defaults on primary keys - if (isinstance(column.server_default, schema.DefaultClause) and - column.server_default.arg is not None): - return self.execute_string("select %s" % column.server_default.arg) - elif (isinstance(column.type, sqltypes.Integer) and column.autoincrement) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): - sch = column.table.schema - # TODO: this has to build into the Sequence object so we can get the quoting - # logic from it - if sch is not None: - exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name) - else: - exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name) - return self.execute_string(exc.encode(self.dialect.encoding)) - - return super(PGDefaultRunner, self).get_column_default(column) - - def visit_sequence(self, seq): - if not seq.optional: - return self.execute_string(("select nextval('%s')" % self.dialect.identifier_preparer.format_sequence(seq))) - else: - return None - -class PGIdentifierPreparer(compiler.IdentifierPreparer): - def _unquote_identifier(self, value): - if value[0] == self.initial_quote: - value = value[1:-1].replace('""','"') - return value - -dialect = PGDialect -dialect.statement_compiler = PGCompiler -dialect.schemagenerator = PGSchemaGenerator -dialect.schemadropper = PGSchemaDropper -dialect.preparer = PGIdentifierPreparer -dialect.defaultrunner = PGDefaultRunner -dialect.execution_ctx_cls = PGExecutionContext diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py new file mode 100644 index 0000000000..e8dd03113f --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/pg8000.py @@ -0,0 +1,84 @@ +"""Support for the PostgreSQL database via the pg8000 driver. + +Connecting +---------- + +URLs are of the form `postgresql+pg8000://user@password@host:port/dbname[?key=value&key=value...]`. + +Unicode +------- + +pg8000 requires that the postgresql client encoding be configured in the postgresql.conf file +in order to use encodings other than ascii. Set this value to the same value as +the "encoding" parameter on create_engine(), usually "utf-8". + +Interval +-------- + +Passing data from/to the Interval type is not supported as of yet. + +""" +from sqlalchemy.engine import default +import decimal +from sqlalchemy import util +from sqlalchemy import types as sqltypes +from sqlalchemy.dialects.postgresql.base import PGDialect, PGCompiler + +class _PGNumeric(sqltypes.Numeric): + def bind_processor(self, dialect): + return None + + def result_processor(self, dialect): + if self.asdecimal: + return None + else: + def process(value): + if isinstance(value, decimal.Decimal): + return float(value) + else: + return value + return process + +class PostgreSQL_pg8000ExecutionContext(default.DefaultExecutionContext): + pass + +class PostgreSQL_pg8000Compiler(PGCompiler): + def visit_mod(self, binary, **kw): + return self.process(binary.left) + " %% " + self.process(binary.right) + + +class PostgreSQL_pg8000(PGDialect): + driver = 'pg8000' + + supports_unicode_statements = True + + supports_unicode_binds = True + + default_paramstyle = 'format' + supports_sane_multi_rowcount = False + execution_ctx_cls = PostgreSQL_pg8000ExecutionContext + statement_compiler = PostgreSQL_pg8000Compiler + + colspecs = util.update_copy( + PGDialect.colspecs, + { + sqltypes.Numeric : _PGNumeric, + sqltypes.Float: sqltypes.Float, # prevents _PGNumeric from being used + } + ) + + @classmethod + def dbapi(cls): + return __import__('pg8000').dbapi + + def create_connect_args(self, url): + opts = url.translate_connect_args(username='user') + if 'port' in opts: + opts['port'] = int(opts['port']) + opts.update(url.query) + return ([], opts) + + def is_disconnect(self, e): + return "connection is closed" in str(e) + +dialect = PostgreSQL_pg8000 diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py new file mode 100644 index 0000000000..a428878ae0 --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -0,0 +1,147 @@ +"""Support for the PostgreSQL database via the psycopg2 driver. + +Driver +------ + +The psycopg2 driver is supported, available at http://pypi.python.org/pypi/psycopg2/ . +The dialect has several behaviors which are specifically tailored towards compatibility +with this module. + +Note that psycopg1 is **not** supported. + +Connecting +---------- + +URLs are of the form `postgresql+psycopg2://user@password@host:port/dbname[?key=value&key=value...]`. + +psycopg2-specific keyword arguments which are accepted by :func:`~sqlalchemy.create_engine()` are: + +* *server_side_cursors* - Enable the usage of "server side cursors" for SQL statements which support + this feature. What this essentially means from a psycopg2 point of view is that the cursor is + created using a name, e.g. `connection.cursor('some name')`, which has the effect that result rows + are not immediately pre-fetched and buffered after statement execution, but are instead left + on the server and only retrieved as needed. SQLAlchemy's :class:`~sqlalchemy.engine.base.ResultProxy` + uses special row-buffering behavior when this feature is enabled, such that groups of 100 rows + at a time are fetched over the wire to reduce conversational overhead. + +* *isolation_level* - Sets the transaction isolation level for each transaction + within the engine. Valid isolation levels are `READ_COMMITTED`, + `READ_UNCOMMITTED`, `REPEATABLE_READ`, and `SERIALIZABLE`. + +Transactions +------------ + +The psycopg2 dialect fully supports SAVEPOINT and two-phase commit operations. + + +""" + +import decimal, random, re +from sqlalchemy import util +from sqlalchemy.engine import base, default +from sqlalchemy.sql import expression +from sqlalchemy.sql import operators as sql_operators +from sqlalchemy import types as sqltypes +from sqlalchemy.dialects.postgresql.base import PGDialect, PGCompiler + +class _PGNumeric(sqltypes.Numeric): + def bind_processor(self, dialect): + return None + + def result_processor(self, dialect): + if self.asdecimal: + return None + else: + def process(value): + if isinstance(value, decimal.Decimal): + return float(value) + else: + return value + return process + + +# TODO: filter out 'FOR UPDATE' statements +SERVER_SIDE_CURSOR_RE = re.compile( + r'\s*SELECT', + re.I | re.UNICODE) + +class PostgreSQL_psycopg2ExecutionContext(default.DefaultExecutionContext): + def create_cursor(self): + # TODO: coverage for server side cursors + select.for_update() + is_server_side = \ + self.dialect.server_side_cursors and \ + ((self.compiled and isinstance(self.compiled.statement, expression.Selectable) + and not getattr(self.compiled.statement, 'for_update', False)) \ + or \ + ( + (not self.compiled or isinstance(self.compiled.statement, expression._TextClause)) + and self.statement and SERVER_SIDE_CURSOR_RE.match(self.statement)) + ) + + self.__is_server_side = is_server_side + if is_server_side: + # use server-side cursors: + # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html + ident = "c_%s_%s" % (hex(id(self))[2:], hex(random.randint(0, 65535))[2:]) + return self._connection.connection.cursor(ident) + else: + return self._connection.connection.cursor() + + def get_result_proxy(self): + if self.__is_server_side: + return base.BufferedRowResultProxy(self) + else: + return base.ResultProxy(self) + +class PostgreSQL_psycopg2Compiler(PGCompiler): + def visit_mod(self, binary, **kw): + return self.process(binary.left) + " %% " + self.process(binary.right) + + def post_process_text(self, text): + return text.replace('%', '%%') + +class PostgreSQL_psycopg2(PGDialect): + driver = 'psycopg2' + supports_unicode_statements = False + default_paramstyle = 'pyformat' + supports_sane_multi_rowcount = False + execution_ctx_cls = PostgreSQL_psycopg2ExecutionContext + statement_compiler = PostgreSQL_psycopg2Compiler + + colspecs = util.update_copy( + PGDialect.colspecs, + { + sqltypes.Numeric : _PGNumeric, + sqltypes.Float: sqltypes.Float, # prevents _PGNumeric from being used + } + ) + + def __init__(self, server_side_cursors=False, **kwargs): + PGDialect.__init__(self, **kwargs) + self.server_side_cursors = server_side_cursors + + @classmethod + def dbapi(cls): + psycopg = __import__('psycopg2') + return psycopg + + def create_connect_args(self, url): + opts = url.translate_connect_args(username='user') + if 'port' in opts: + opts['port'] = int(opts['port']) + opts.update(url.query) + return ([], opts) + + def is_disconnect(self, e): + if isinstance(e, self.dbapi.OperationalError): + return 'closed the connection' in str(e) or 'connection not open' in str(e) + elif isinstance(e, self.dbapi.InterfaceError): + return 'connection already closed' in str(e) or 'cursor already closed' in str(e) + elif isinstance(e, self.dbapi.ProgrammingError): + # yes, it really says "losed", not "closed" + return "losed the connection unexpectedly" in str(e) + else: + return False + +dialect = PostgreSQL_psycopg2 + diff --git a/lib/sqlalchemy/dialects/postgresql/pypostgresql.py b/lib/sqlalchemy/dialects/postgresql/pypostgresql.py new file mode 100644 index 0000000000..975006d927 --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/pypostgresql.py @@ -0,0 +1,80 @@ +"""Support for the PostgreSQL database via py-postgresql. + +Connecting +---------- + +URLs are of the form `postgresql+pypostgresql://user@password@host:port/dbname[?key=value&key=value...]`. + + +""" +from sqlalchemy.engine import default +import decimal +from sqlalchemy import util +from sqlalchemy import types as sqltypes +from sqlalchemy.dialects.postgresql.base import PGDialect, PGDefaultRunner + +class PGNumeric(sqltypes.Numeric): + def bind_processor(self, dialect): + return None + + def result_processor(self, dialect): + if self.asdecimal: + return None + else: + def process(value): + if isinstance(value, decimal.Decimal): + return float(value) + else: + return value + return process + +class PostgreSQL_pypostgresqlExecutionContext(default.DefaultExecutionContext): + pass + +class PostgreSQL_pypostgresqlDefaultRunner(PGDefaultRunner): + def execute_string(self, stmt, params=None): + return PGDefaultRunner.execute_string(self, stmt, params or ()) + +class PostgreSQL_pypostgresql(PGDialect): + driver = 'pypostgresql' + + supports_unicode_statements = True + + supports_unicode_binds = True + description_encoding = None + + defaultrunner = PostgreSQL_pypostgresqlDefaultRunner + + default_paramstyle = 'format' + + supports_sane_rowcount = False # alas....posting a bug now + + supports_sane_multi_rowcount = False + + execution_ctx_cls = PostgreSQL_pypostgresqlExecutionContext + colspecs = util.update_copy( + PGDialect.colspecs, + { + sqltypes.Numeric : PGNumeric, + sqltypes.Float: sqltypes.Float, # prevents PGNumeric from being used + } + ) + + @classmethod + def dbapi(cls): + from postgresql.driver import dbapi20 + return dbapi20 + + def create_connect_args(self, url): + opts = url.translate_connect_args(username='user') + if 'port' in opts: + opts['port'] = int(opts['port']) + else: + opts['port'] = 5432 + opts.update(url.query) + return ([], opts) + + def is_disconnect(self, e): + return "connection is closed" in str(e) + +dialect = PostgreSQL_pypostgresql diff --git a/lib/sqlalchemy/dialects/postgresql/zxjdbc.py b/lib/sqlalchemy/dialects/postgresql/zxjdbc.py new file mode 100644 index 0000000000..b707d2d9eb --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/zxjdbc.py @@ -0,0 +1,28 @@ +"""Support for the PostgreSQL database via the zxjdbc JDBC connector. + +JDBC Driver +----------- + +The official Postgresql JDBC driver is at http://jdbc.postgresql.org/. + +""" +from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector +from sqlalchemy.dialects.postgresql.base import PGCompiler, PGDialect + +class PostgreSQL_jdbcCompiler(PGCompiler): + + def post_process_text(self, text): + # Don't escape '%' like PGCompiler + return text + + +class PostgreSQL_jdbc(ZxJDBCConnector, PGDialect): + statement_compiler = PostgreSQL_jdbcCompiler + + jdbc_db_name = 'postgresql' + jdbc_driver_name = 'org.postgresql.Driver' + + def _get_server_version_info(self, connection): + return tuple(int(x) for x in connection.connection.dbversion.split('.')) + +dialect = PostgreSQL_jdbc diff --git a/lib/sqlalchemy/dialects/sqlite/__init__.py b/lib/sqlalchemy/dialects/sqlite/__init__.py new file mode 100644 index 0000000000..3cc08870f2 --- /dev/null +++ b/lib/sqlalchemy/dialects/sqlite/__init__.py @@ -0,0 +1,4 @@ +from sqlalchemy.dialects.sqlite import base, pysqlite + +# default dialect +base.dialect = pysqlite.dialect \ No newline at end of file diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py new file mode 100644 index 0000000000..8dea91d0ab --- /dev/null +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -0,0 +1,526 @@ +# sqlite.py +# Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php +"""Support for the SQLite database. + +For information on connecting using a specific driver, see the documentation +section regarding that driver. + +Date and Time Types +------------------- + +SQLite does not have built-in DATE, TIME, or DATETIME types, and pysqlite does not provide +out of the box functionality for translating values between Python `datetime` objects +and a SQLite-supported format. SQLAlchemy's own :class:`~sqlalchemy.types.DateTime` +and related types provide date formatting and parsing functionality when SQlite is used. +The implementation classes are :class:`_SLDateTime`, :class:`_SLDate` and :class:`_SLTime`. +These types represent dates and times as ISO formatted strings, which also nicely +support ordering. There's no reliance on typical "libc" internals for these functions +so historical dates are fully supported. + + +""" + +import datetime, re, time + +from sqlalchemy import schema as sa_schema +from sqlalchemy import sql, exc, pool, DefaultClause +from sqlalchemy.engine import default +from sqlalchemy.engine import reflection +from sqlalchemy import types as sqltypes +from sqlalchemy import util +from sqlalchemy.sql import compiler, functions as sql_functions +from sqlalchemy.util import NoneType + +from sqlalchemy.types import BLOB, BOOLEAN, CHAR, DATE, DATETIME, DECIMAL,\ + FLOAT, INTEGER, NUMERIC, SMALLINT, TEXT, TIME,\ + TIMESTAMP, VARCHAR + + +class _NumericMixin(object): + def bind_processor(self, dialect): + type_ = self.asdecimal and str or float + def process(value): + if value is not None: + return type_(value) + else: + return value + return process + +class _SLNumeric(_NumericMixin, sqltypes.Numeric): + pass + +class _SLFloat(_NumericMixin, sqltypes.Float): + pass + +# since SQLite has no date types, we're assuming that SQLite via ODBC +# or JDBC would similarly have no built in date support, so the "string" based logic +# would apply to all implementing dialects. +class _DateTimeMixin(object): + def _bind_processor(self, format, elements): + def process(value): + if not isinstance(value, (NoneType, datetime.date, datetime.datetime, datetime.time)): + raise TypeError("SQLite Date, Time, and DateTime types only accept Python datetime objects as input.") + elif value is not None: + return format % tuple([getattr(value, attr, 0) for attr in elements]) + else: + return None + return process + + def _result_processor(self, fn, regexp): + def process(value): + if value is not None: + return fn(*[int(x or 0) for x in regexp.match(value).groups()]) + else: + return None + return process + +class _SLDateTime(_DateTimeMixin, sqltypes.DateTime): + __legacy_microseconds__ = False + + def bind_processor(self, dialect): + if self.__legacy_microseconds__: + return self._bind_processor( + "%4.4d-%2.2d-%2.2d %2.2d:%2.2d:%2.2d.%s", + ("year", "month", "day", "hour", "minute", "second", "microsecond") + ) + else: + return self._bind_processor( + "%4.4d-%2.2d-%2.2d %2.2d:%2.2d:%2.2d.%06d", + ("year", "month", "day", "hour", "minute", "second", "microsecond") + ) + + _reg = re.compile(r"(\d+)-(\d+)-(\d+)(?: (\d+):(\d+):(\d+)(?:\.(\d+))?)?") + def result_processor(self, dialect): + return self._result_processor(datetime.datetime, self._reg) + +class _SLDate(_DateTimeMixin, sqltypes.Date): + def bind_processor(self, dialect): + return self._bind_processor( + "%4.4d-%2.2d-%2.2d", + ("year", "month", "day") + ) + + _reg = re.compile(r"(\d+)-(\d+)-(\d+)") + def result_processor(self, dialect): + return self._result_processor(datetime.date, self._reg) + +class _SLTime(_DateTimeMixin, sqltypes.Time): + __legacy_microseconds__ = False + + def bind_processor(self, dialect): + if self.__legacy_microseconds__: + return self._bind_processor( + "%2.2d:%2.2d:%2.2d.%s", + ("hour", "minute", "second", "microsecond") + ) + else: + return self._bind_processor( + "%2.2d:%2.2d:%2.2d.%06d", + ("hour", "minute", "second", "microsecond") + ) + + _reg = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d+))?") + def result_processor(self, dialect): + return self._result_processor(datetime.time, self._reg) + + +class _SLBoolean(sqltypes.Boolean): + def bind_processor(self, dialect): + def process(value): + if value is None: + return None + return value and 1 or 0 + return process + + def result_processor(self, dialect): + def process(value): + if value is None: + return None + return value == 1 + return process + +colspecs = { + sqltypes.Boolean: _SLBoolean, + sqltypes.Date: _SLDate, + sqltypes.DateTime: _SLDateTime, + sqltypes.Float: _SLFloat, + sqltypes.Numeric: _SLNumeric, + sqltypes.Time: _SLTime, +} + +ischema_names = { + 'BLOB': sqltypes.BLOB, + 'BOOL': sqltypes.BOOLEAN, + 'BOOLEAN': sqltypes.BOOLEAN, + 'CHAR': sqltypes.CHAR, + 'DATE': sqltypes.DATE, + 'DATETIME': sqltypes.DATETIME, + 'DECIMAL': sqltypes.DECIMAL, + 'FLOAT': sqltypes.FLOAT, + 'INT': sqltypes.INTEGER, + 'INTEGER': sqltypes.INTEGER, + 'NUMERIC': sqltypes.NUMERIC, + 'REAL': sqltypes.Numeric, + 'SMALLINT': sqltypes.SMALLINT, + 'TEXT': sqltypes.TEXT, + 'TIME': sqltypes.TIME, + 'TIMESTAMP': sqltypes.TIMESTAMP, + 'VARCHAR': sqltypes.VARCHAR, +} + + + +class SQLiteCompiler(compiler.SQLCompiler): + extract_map = compiler.SQLCompiler.extract_map.copy() + extract_map.update({ + 'month': '%m', + 'day': '%d', + 'year': '%Y', + 'second': '%S', + 'hour': '%H', + 'doy': '%j', + 'minute': '%M', + 'epoch': '%s', + 'dow': '%w', + 'week': '%W' + }) + + def visit_now_func(self, fn, **kw): + return "CURRENT_TIMESTAMP" + + def visit_char_length_func(self, fn, **kw): + return "length%s" % self.function_argspec(fn) + + def visit_cast(self, cast, **kwargs): + if self.dialect.supports_cast: + return super(SQLiteCompiler, self).visit_cast(cast) + else: + return self.process(cast.clause) + + def visit_extract(self, extract): + try: + return "CAST(STRFTIME('%s', %s) AS INTEGER)" % ( + self.extract_map[extract.field], self.process(extract.expr)) + except KeyError: + raise exc.ArgumentError( + "%s is not a valid extract argument." % extract.field) + + def limit_clause(self, select): + text = "" + if select._limit is not None: + text += " \n LIMIT " + str(select._limit) + if select._offset is not None: + if select._limit is None: + text += " \n LIMIT -1" + text += " OFFSET " + str(select._offset) + else: + text += " OFFSET 0" + return text + + def for_update_clause(self, select): + # sqlite has no "FOR UPDATE" AFAICT + return '' + + +class SQLiteDDLCompiler(compiler.DDLCompiler): + + def get_column_specification(self, column, **kwargs): + colspec = self.preparer.format_column(column) + " " + self.dialect.type_compiler.process(column.type) + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + + if not column.nullable: + colspec += " NOT NULL" + return colspec + +class SQLiteTypeCompiler(compiler.GenericTypeCompiler): + def visit_binary(self, type_): + return self.visit_BLOB(type_) + +class SQLiteIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = set([ + 'add', 'after', 'all', 'alter', 'analyze', 'and', 'as', 'asc', + 'attach', 'autoincrement', 'before', 'begin', 'between', 'by', + 'cascade', 'case', 'cast', 'check', 'collate', 'column', 'commit', + 'conflict', 'constraint', 'create', 'cross', 'current_date', + 'current_time', 'current_timestamp', 'database', 'default', + 'deferrable', 'deferred', 'delete', 'desc', 'detach', 'distinct', + 'drop', 'each', 'else', 'end', 'escape', 'except', 'exclusive', + 'explain', 'false', 'fail', 'for', 'foreign', 'from', 'full', 'glob', + 'group', 'having', 'if', 'ignore', 'immediate', 'in', 'index', + 'indexed', 'initially', 'inner', 'insert', 'instead', 'intersect', 'into', 'is', + 'isnull', 'join', 'key', 'left', 'like', 'limit', 'match', 'natural', + 'not', 'notnull', 'null', 'of', 'offset', 'on', 'or', 'order', 'outer', + 'plan', 'pragma', 'primary', 'query', 'raise', 'references', + 'reindex', 'rename', 'replace', 'restrict', 'right', 'rollback', + 'row', 'select', 'set', 'table', 'temp', 'temporary', 'then', 'to', + 'transaction', 'trigger', 'true', 'union', 'unique', 'update', 'using', + 'vacuum', 'values', 'view', 'virtual', 'when', 'where', + ]) + +class SQLiteDialect(default.DefaultDialect): + name = 'sqlite' + supports_alter = False + supports_unicode_statements = True + supports_unicode_binds = True + supports_default_values = True + supports_empty_insert = False + supports_cast = True + + default_paramstyle = 'qmark' + statement_compiler = SQLiteCompiler + ddl_compiler = SQLiteDDLCompiler + type_compiler = SQLiteTypeCompiler + preparer = SQLiteIdentifierPreparer + ischema_names = ischema_names + colspecs = colspecs + isolation_level = None + + def __init__(self, isolation_level=None, **kwargs): + default.DefaultDialect.__init__(self, **kwargs) + if isolation_level and isolation_level not in ('SERIALIZABLE', + 'READ UNCOMMITTED'): + raise exc.ArgumentError("Invalid value for isolation_level. " + "Valid isolation levels for sqlite are 'SERIALIZABLE' and " + "'READ UNCOMMITTED'.") + self.isolation_level = isolation_level + + def visit_pool(self, pool): + if self.isolation_level is not None: + class SetIsolationLevel(object): + def __init__(self, isolation_level): + if isolation_level == 'READ UNCOMMITTED': + self.isolation_level = 1 + else: + self.isolation_level = 0 + + def connect(self, conn, rec): + cursor = conn.cursor() + cursor.execute("PRAGMA read_uncommitted = %d" % self.isolation_level) + cursor.close() + pool.add_listener(SetIsolationLevel(self.isolation_level)) + + def table_names(self, connection, schema): + if schema is not None: + qschema = self.identifier_preparer.quote_identifier(schema) + master = '%s.sqlite_master' % qschema + s = ("SELECT name FROM %s " + "WHERE type='table' ORDER BY name") % (master,) + rs = connection.execute(s) + else: + try: + s = ("SELECT name FROM " + " (SELECT * FROM sqlite_master UNION ALL " + " SELECT * FROM sqlite_temp_master) " + "WHERE type='table' ORDER BY name") + rs = connection.execute(s) + except exc.DBAPIError: + raise + s = ("SELECT name FROM sqlite_master " + "WHERE type='table' ORDER BY name") + rs = connection.execute(s) + + return [row[0] for row in rs] + + def has_table(self, connection, table_name, schema=None): + quote = self.identifier_preparer.quote_identifier + if schema is not None: + pragma = "PRAGMA %s." % quote(schema) + else: + pragma = "PRAGMA " + qtable = quote(table_name) + cursor = _pragma_cursor(connection.execute("%stable_info(%s)" % (pragma, qtable))) + row = cursor.fetchone() + + # consume remaining rows, to work around + # http://www.sqlite.org/cvstrac/tktview?tn=1884 + while cursor.fetchone() is not None: + pass + + return (row is not None) + + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + return self.table_names(connection, schema) + + @reflection.cache + def get_view_names(self, connection, schema=None, **kw): + if schema is not None: + qschema = self.identifier_preparer.quote_identifier(schema) + master = '%s.sqlite_master' % qschema + s = ("SELECT name FROM %s " + "WHERE type='view' ORDER BY name") % (master,) + rs = connection.execute(s) + else: + try: + s = ("SELECT name FROM " + " (SELECT * FROM sqlite_master UNION ALL " + " SELECT * FROM sqlite_temp_master) " + "WHERE type='view' ORDER BY name") + rs = connection.execute(s) + except exc.DBAPIError: + raise + s = ("SELECT name FROM sqlite_master " + "WHERE type='view' ORDER BY name") + rs = connection.execute(s) + + return [row[0] for row in rs] + + @reflection.cache + def get_view_definition(self, connection, view_name, schema=None, **kw): + quote = self.identifier_preparer.quote_identifier + if schema is not None: + qschema = self.identifier_preparer.quote_identifier(schema) + master = '%s.sqlite_master' % qschema + s = ("SELECT sql FROM %s WHERE name = '%s'" + "AND type='view'") % (master, view_name) + rs = connection.execute(s) + else: + try: + s = ("SELECT sql FROM " + " (SELECT * FROM sqlite_master UNION ALL " + " SELECT * FROM sqlite_temp_master) " + "WHERE name = '%s' " + "AND type='view'") % view_name + rs = connection.execute(s) + except exc.DBAPIError: + raise + s = ("SELECT sql FROM sqlite_master WHERE name = '%s' " + "AND type='view'") % view_name + rs = connection.execute(s) + + result = rs.fetchall() + if result: + return result[0].sql + + @reflection.cache + def get_columns(self, connection, table_name, schema=None, **kw): + quote = self.identifier_preparer.quote_identifier + if schema is not None: + pragma = "PRAGMA %s." % quote(schema) + else: + pragma = "PRAGMA " + qtable = quote(table_name) + c = _pragma_cursor(connection.execute("%stable_info(%s)" % (pragma, qtable))) + found_table = False + columns = [] + while True: + row = c.fetchone() + if row is None: + break + (name, type_, nullable, default, has_default, primary_key) = (row[1], row[2].upper(), not row[3], row[4], row[4] is not None, row[5]) + name = re.sub(r'^\"|\"$', '', name) + if default: + default = re.sub(r"^\'|\'$", '', default) + match = re.match(r'(\w+)(\(.*?\))?', type_) + if match: + coltype = match.group(1) + args = match.group(2) + else: + coltype = "VARCHAR" + args = '' + try: + coltype = self.ischema_names[coltype] + except KeyError: + util.warn("Did not recognize type '%s' of column '%s'" % + (coltype, name)) + coltype = sqltypes.NullType + if args is not None: + args = re.findall(r'(\d+)', args) + coltype = coltype(*[int(a) for a in args]) + + columns.append({ + 'name' : name, + 'type' : coltype, + 'nullable' : nullable, + 'default' : default, + 'primary_key': primary_key + }) + return columns + + @reflection.cache + def get_primary_keys(self, connection, table_name, schema=None, **kw): + cols = self.get_columns(connection, table_name, schema, **kw) + pkeys = [] + for col in cols: + if col['primary_key']: + pkeys.append(col['name']) + return pkeys + + @reflection.cache + def get_foreign_keys(self, connection, table_name, schema=None, **kw): + quote = self.identifier_preparer.quote_identifier + if schema is not None: + pragma = "PRAGMA %s." % quote(schema) + else: + pragma = "PRAGMA " + qtable = quote(table_name) + c = _pragma_cursor(connection.execute("%sforeign_key_list(%s)" % (pragma, qtable))) + fkeys = [] + fks = {} + while True: + row = c.fetchone() + if row is None: + break + (constraint_name, rtbl, lcol, rcol) = (row[0], row[2], row[3], row[4]) + rtbl = re.sub(r'^\"|\"$', '', rtbl) + lcol = re.sub(r'^\"|\"$', '', lcol) + rcol = re.sub(r'^\"|\"$', '', rcol) + try: + fk = fks[constraint_name] + except KeyError: + fk = { + 'name' : constraint_name, + 'constrained_columns' : [], + 'referred_schema' : None, + 'referred_table' : rtbl, + 'referred_columns' : [] + } + fkeys.append(fk) + fks[constraint_name] = fk + + # look up the table based on the given table's engine, not 'self', + # since it could be a ProxyEngine + if lcol not in fk['constrained_columns']: + fk['constrained_columns'].append(lcol) + if rcol not in fk['referred_columns']: + fk['referred_columns'].append(rcol) + return fkeys + + @reflection.cache + def get_indexes(self, connection, table_name, schema=None, **kw): + quote = self.identifier_preparer.quote_identifier + if schema is not None: + pragma = "PRAGMA %s." % quote(schema) + else: + pragma = "PRAGMA " + qtable = quote(table_name) + c = _pragma_cursor(connection.execute("%sindex_list(%s)" % (pragma, qtable))) + indexes = [] + while True: + row = c.fetchone() + if row is None: + break + indexes.append(dict(name=row[1], column_names=[], unique=row[2])) + # loop thru unique indexes to get the column names. + for idx in indexes: + c = connection.execute("%sindex_info(%s)" % (pragma, quote(idx['name']))) + cols = idx['column_names'] + while True: + row = c.fetchone() + if row is None: + break + cols.append(row[2]) + return indexes + + +def _pragma_cursor(cursor): + """work around SQLite issue whereby cursor.description is blank when PRAGMA returns no rows.""" + + if cursor.closed: + cursor._fetchone_impl = lambda: None + return cursor diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlite.py b/lib/sqlalchemy/dialects/sqlite/pysqlite.py new file mode 100644 index 0000000000..a1873f33a8 --- /dev/null +++ b/lib/sqlalchemy/dialects/sqlite/pysqlite.py @@ -0,0 +1,174 @@ +"""Support for the SQLite database via pysqlite. + +Note that pysqlite is the same driver as the ``sqlite3`` +module included with the Python distribution. + +Driver +------ + +When using Python 2.5 and above, the built in ``sqlite3`` driver is +already installed and no additional installation is needed. Otherwise, +the ``pysqlite2`` driver needs to be present. This is the same driver as +``sqlite3``, just with a different name. + +The ``pysqlite2`` driver will be loaded first, and if not found, ``sqlite3`` +is loaded. This allows an explicitly installed pysqlite driver to take +precedence over the built in one. As with all dialects, a specific +DBAPI module may be provided to :func:`~sqlalchemy.create_engine()` to control +this explicitly:: + + from sqlite3 import dbapi2 as sqlite + e = create_engine('sqlite+pysqlite:///file.db', module=sqlite) + +Full documentation on pysqlite is available at: +``_ + +Connect Strings +--------------- + +The file specification for the SQLite database is taken as the "database" portion of +the URL. Note that the format of a url is:: + + driver://user:pass@host/database + +This means that the actual filename to be used starts with the characters to the +**right** of the third slash. So connecting to a relative filepath looks like:: + + # relative path + e = create_engine('sqlite:///path/to/database.db') + +An absolute path, which is denoted by starting with a slash, means you need **four** +slashes:: + + # absolute path + e = create_engine('sqlite:////path/to/database.db') + +To use a Windows path, regular drive specifications and backslashes can be used. +Double backslashes are probably needed:: + + # absolute path on Windows + e = create_engine('sqlite:///C:\\\\path\\\\to\\\\database.db') + +The sqlite ``:memory:`` identifier is the default if no filepath is present. Specify +``sqlite://`` and nothing else:: + + # in-memory database + e = create_engine('sqlite://') + +Threading Behavior +------------------ + +Pysqlite connections do not support being moved between threads, unless +the ``check_same_thread`` Pysqlite flag is set to ``False``. In addition, +when using an in-memory SQLite database, the full database exists only within +the scope of a single connection. It is reported that an in-memory +database does not support being shared between threads regardless of the +``check_same_thread`` flag - which means that a multithreaded +application **cannot** share data from a ``:memory:`` database across threads +unless access to the connection is limited to a single worker thread which communicates +through a queueing mechanism to concurrent threads. + +To provide a default which accomodates SQLite's default threading capabilities +somewhat reasonably, the SQLite dialect will specify that the :class:`~sqlalchemy.pool.SingletonThreadPool` +be used by default. This pool maintains a single SQLite connection per thread +that is held open up to a count of five concurrent threads. When more than five threads +are used, a cleanup mechanism will dispose of excess unused connections. + +Two optional pool implementations that may be appropriate for particular SQLite usage scenarios: + + * the :class:`sqlalchemy.pool.StaticPool` might be appropriate for a multithreaded + application using an in-memory database, assuming the threading issues inherent in + pysqlite are somehow accomodated for. This pool holds persistently onto a single connection + which is never closed, and is returned for all requests. + + * the :class:`sqlalchemy.pool.NullPool` might be appropriate for an application that + makes use of a file-based sqlite database. This pool disables any actual "pooling" + behavior, and simply opens and closes real connections corresonding to the :func:`connect()` + and :func:`close()` methods. SQLite can "connect" to a particular file with very high + efficiency, so this option may actually perform better without the extra overhead + of :class:`SingletonThreadPool`. NullPool will of course render a ``:memory:`` connection + useless since the database would be lost as soon as the connection is "returned" to the pool. + +Unicode +------- + +In contrast to SQLAlchemy's active handling of date and time types for pysqlite, pysqlite's +default behavior regarding Unicode is that all strings are returned as Python unicode objects +in all cases. So even if the :class:`~sqlalchemy.types.Unicode` type is +*not* used, you will still always receive unicode data back from a result set. It is +**strongly** recommended that you do use the :class:`~sqlalchemy.types.Unicode` type +to represent strings, since it will raise a warning if a non-unicode Python string is +passed from the user application. Mixing the usage of non-unicode objects with returned unicode objects can +quickly create confusion, particularly when using the ORM as internal data is not +always represented by an actual database result string. + +""" + +from sqlalchemy.dialects.sqlite.base import SQLiteDialect +from sqlalchemy import schema, exc, pool +from sqlalchemy.engine import default +from sqlalchemy import types as sqltypes +from sqlalchemy import util + +class SQLite_pysqlite(SQLiteDialect): + default_paramstyle = 'qmark' + poolclass = pool.SingletonThreadPool + + # Py3K + #description_encoding = None + + driver = 'pysqlite' + + def __init__(self, **kwargs): + SQLiteDialect.__init__(self, **kwargs) + def vers(num): + return tuple([int(x) for x in num.split('.')]) + if self.dbapi is not None: + sqlite_ver = self.dbapi.version_info + if sqlite_ver < (2, 1, '3'): + util.warn( + ("The installed version of pysqlite2 (%s) is out-dated " + "and will cause errors in some cases. Version 2.1.3 " + "or greater is recommended.") % + '.'.join([str(subver) for subver in sqlite_ver])) + if self.dbapi.sqlite_version_info < (3, 3, 8): + self.supports_default_values = False + self.supports_cast = (self.dbapi is None or vers(self.dbapi.sqlite_version) >= vers("3.2.3")) + + @classmethod + def dbapi(cls): + try: + from pysqlite2 import dbapi2 as sqlite + except ImportError, e: + try: + from sqlite3 import dbapi2 as sqlite #try the 2.5+ stdlib name. + except ImportError: + raise e + return sqlite + + def _get_server_version_info(self, connection): + return self.dbapi.sqlite_version_info + + def create_connect_args(self, url): + if url.username or url.password or url.host or url.port: + raise exc.ArgumentError( + "Invalid SQLite URL: %s\n" + "Valid SQLite URL forms are:\n" + " sqlite:///:memory: (or, sqlite://)\n" + " sqlite:///relative/path/to/file.db\n" + " sqlite:////absolute/path/to/file.db" % (url,)) + filename = url.database or ':memory:' + + opts = url.query.copy() + util.coerce_kw_type(opts, 'timeout', float) + util.coerce_kw_type(opts, 'isolation_level', str) + util.coerce_kw_type(opts, 'detect_types', int) + util.coerce_kw_type(opts, 'check_same_thread', bool) + util.coerce_kw_type(opts, 'cached_statements', int) + + return ([filename], opts) + + def is_disconnect(self, e): + return isinstance(e, self.dbapi.ProgrammingError) and "Cannot operate on a closed database." in str(e) + +dialect = SQLite_pysqlite diff --git a/lib/sqlalchemy/dialects/sybase/__init__.py b/lib/sqlalchemy/dialects/sybase/__init__.py new file mode 100644 index 0000000000..f8baf339e8 --- /dev/null +++ b/lib/sqlalchemy/dialects/sybase/__init__.py @@ -0,0 +1,4 @@ +from sqlalchemy.dialects.sybase import base, pyodbc + +# default dialect +base.dialect = pyodbc.dialect \ No newline at end of file diff --git a/lib/sqlalchemy/dialects/sybase/base.py b/lib/sqlalchemy/dialects/sybase/base.py new file mode 100644 index 0000000000..6f8c648e45 --- /dev/null +++ b/lib/sqlalchemy/dialects/sybase/base.py @@ -0,0 +1,458 @@ +# sybase.py +# Copyright (C) 2007 Fisch Asset Management AG http://www.fam.ch +# Coding: Alexander Houben alexander.houben@thor-solutions.ch +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Support for the Sybase iAnywhere database. + +This is not a full backend for Sybase ASE. + +This dialect is *not* tested on SQLAlchemy 0.6. + + +Known issues / TODO: + + * Uses the mx.ODBC driver from egenix (version 2.1.0) + * The current version of sqlalchemy.databases.sybase only supports + mx.ODBC.Windows (other platforms such as mx.ODBC.unixODBC still need + some development) + * Support for pyodbc has been built in but is not yet complete (needs + further development) + * Results of running tests/alltests.py: + Ran 934 tests in 287.032s + FAILED (failures=3, errors=1) + * Tested on 'Adaptive Server Anywhere 9' (version 9.0.1.1751) +""" + +import datetime, operator + +from sqlalchemy import util, sql, schema, exc +from sqlalchemy.sql import compiler, expression +from sqlalchemy.engine import default, base +from sqlalchemy import types as sqltypes +from sqlalchemy.sql import operators as sql_operators +from sqlalchemy import MetaData, Table, Column +from sqlalchemy import String, Integer, SMALLINT, CHAR, ForeignKey +from sqlalchemy.dialects.sybase.schema import * + +RESERVED_WORDS = set([ + "add", "all", "alter", "and", + "any", "as", "asc", "backup", + "begin", "between", "bigint", "binary", + "bit", "bottom", "break", "by", + "call", "capability", "cascade", "case", + "cast", "char", "char_convert", "character", + "check", "checkpoint", "close", "comment", + "commit", "connect", "constraint", "contains", + "continue", "convert", "create", "cross", + "cube", "current", "current_timestamp", "current_user", + "cursor", "date", "dbspace", "deallocate", + "dec", "decimal", "declare", "default", + "delete", "deleting", "desc", "distinct", + "do", "double", "drop", "dynamic", + "else", "elseif", "encrypted", "end", + "endif", "escape", "except", "exception", + "exec", "execute", "existing", "exists", + "externlogin", "fetch", "first", "float", + "for", "force", "foreign", "forward", + "from", "full", "goto", "grant", + "group", "having", "holdlock", "identified", + "if", "in", "index", "index_lparen", + "inner", "inout", "insensitive", "insert", + "inserting", "install", "instead", "int", + "integer", "integrated", "intersect", "into", + "iq", "is", "isolation", "join", + "key", "lateral", "left", "like", + "lock", "login", "long", "match", + "membership", "message", "mode", "modify", + "natural", "new", "no", "noholdlock", + "not", "notify", "null", "numeric", + "of", "off", "on", "open", + "option", "options", "or", "order", + "others", "out", "outer", "over", + "passthrough", "precision", "prepare", "primary", + "print", "privileges", "proc", "procedure", + "publication", "raiserror", "readtext", "real", + "reference", "references", "release", "remote", + "remove", "rename", "reorganize", "resource", + "restore", "restrict", "return", "revoke", + "right", "rollback", "rollup", "save", + "savepoint", "scroll", "select", "sensitive", + "session", "set", "setuser", "share", + "smallint", "some", "sqlcode", "sqlstate", + "start", "stop", "subtrans", "subtransaction", + "synchronize", "syntax_error", "table", "temporary", + "then", "time", "timestamp", "tinyint", + "to", "top", "tran", "trigger", + "truncate", "tsequal", "unbounded", "union", + "unique", "unknown", "unsigned", "update", + "updating", "user", "using", "validate", + "values", "varbinary", "varchar", "variable", + "varying", "view", "wait", "waitfor", + "when", "where", "while", "window", + "with", "with_cube", "with_lparen", "with_rollup", + "within", "work", "writetext", + ]) + + +class SybaseImage(sqltypes.Binary): + __visit_name__ = 'IMAGE' + +class SybaseBit(sqltypes.TypeEngine): + __visit_name__ = 'BIT' + +class SybaseMoney(sqltypes.TypeEngine): + __visit_name__ = "MONEY" + +class SybaseSmallMoney(SybaseMoney): + __visit_name__ = "SMALLMONEY" + +class SybaseUniqueIdentifier(sqltypes.TypeEngine): + __visit_name__ = "UNIQUEIDENTIFIER" + +class SybaseBoolean(sqltypes.Boolean): + def result_processor(self, dialect): + def process(value): + if value is None: + return None + return value and True or False + return process + + def bind_processor(self, dialect): + def process(value): + if value is True: + return 1 + elif value is False: + return 0 + elif value is None: + return None + else: + return value and True or False + return process + +class SybaseTypeCompiler(compiler.GenericTypeCompiler): + def visit_binary(self, type_): + return self.visit_IMAGE(type_) + + def visit_boolean(self, type_): + return self.visit_BIT(type_) + + def visit_IMAGE(self, type_): + return "IMAGE" + + def visit_BIT(self, type_): + return "BIT" + + def visit_MONEY(self, type_): + return "MONEY" + + def visit_SMALLMONEY(self, type_): + return "SMALLMONEY" + + def visit_UNIQUEIDENTIFIER(self, type_): + return "UNIQUEIDENTIFIER" + +colspecs = { + sqltypes.Binary : SybaseImage, + sqltypes.Boolean : SybaseBoolean, +} + +ischema_names = { + 'integer' : sqltypes.INTEGER, + 'unsigned int' : sqltypes.Integer, + 'unsigned smallint' : sqltypes.SmallInteger, + 'unsigned bigint' : sqltypes.BigInteger, + 'bigint': sqltypes.BIGINT, + 'smallint' : sqltypes.SMALLINT, + 'tinyint' : sqltypes.SmallInteger, + 'varchar' : sqltypes.VARCHAR, + 'long varchar' : sqltypes.Text, + 'char' : sqltypes.CHAR, + 'decimal' : sqltypes.DECIMAL, + 'numeric' : sqltypes.NUMERIC, + 'float' : sqltypes.FLOAT, + 'double' : sqltypes.Numeric, + 'binary' : sqltypes.Binary, + 'long binary' : sqltypes.Binary, + 'varbinary' : sqltypes.Binary, + 'bit': SybaseBit, + 'image' : SybaseImage, + 'timestamp': sqltypes.TIMESTAMP, + 'money': SybaseMoney, + 'smallmoney': SybaseSmallMoney, + 'uniqueidentifier': SybaseUniqueIdentifier, + +} + + +class SybaseExecutionContext(default.DefaultExecutionContext): + + def post_exec(self): + if self.compiled.isinsert: + table = self.compiled.statement.table + # get the inserted values of the primary key + + # get any sequence IDs first (using @@identity) + self.cursor.execute("SELECT @@identity AS lastrowid") + row = self.cursor.fetchone() + lastrowid = int(row[0]) + if lastrowid > 0: + # an IDENTITY was inserted, fetch it + # FIXME: always insert in front ? This only works if the IDENTITY is the first column, no ?! + if not hasattr(self, '_last_inserted_ids') or self._last_inserted_ids is None: + self._last_inserted_ids = [lastrowid] + else: + self._last_inserted_ids = [lastrowid] + self._last_inserted_ids[1:] + + +class SybaseSQLCompiler(compiler.SQLCompiler): + + extract_map = compiler.SQLCompiler.extract_map.copy() + extract_map.update ({ + 'doy': 'dayofyear', + 'dow': 'weekday', + 'milliseconds': 'millisecond' + }) + + def visit_mod(self, binary, **kw): + return "MOD(%s, %s)" % (self.process(binary.left), self.process(binary.right)) + + def bindparam_string(self, name): + res = super(SybaseSQLCompiler, self).bindparam_string(name) + if name.lower().startswith('literal'): + res = 'STRING(%s)' % res + return res + + def get_select_precolumns(self, select): + s = select._distinct and "DISTINCT " or "" + if select._limit: + #if select._limit == 1: + #s += "FIRST " + #else: + #s += "TOP %s " % (select._limit,) + s += "TOP %s " % (select._limit,) + if select._offset: + if not select._limit: + # FIXME: sybase doesn't allow an offset without a limit + # so use a huge value for TOP here + s += "TOP 1000000 " + s += "START AT %s " % (select._offset+1,) + return s + + def limit_clause(self, select): + # Limit in sybase is after the select keyword + return "" + + def visit_binary(self, binary): + """Move bind parameters to the right-hand side of an operator, where possible.""" + if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq: + return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator)) + else: + return super(SybaseSQLCompiler, self).visit_binary(binary) + + def label_select_column(self, select, column, asfrom): + if isinstance(column, expression.Function): + return column.label(None) + else: + return super(SybaseSQLCompiler, self).label_select_column(select, column, asfrom) + + function_rewrites = {'current_date': 'getdate', + } + def visit_function(self, func): + func.name = self.function_rewrites.get(func.name, func.name) + res = super(SybaseSQLCompiler, self).visit_function(func) + if func.name.lower() == 'getdate': + # apply CAST operator + # FIXME: what about _pyodbc ? + cast = expression._Cast(func, SybaseDate_mxodbc) + # infinite recursion + # res = self.visit_cast(cast) + res = "CAST(%s AS %s)" % (res, self.process(cast.typeclause)) + return res + + def visit_extract(self, extract): + field = self.extract_map.get(extract.field, extract.field) + return 'DATEPART("%s", %s)' % (field, self.process(extract.expr)) + + def for_update_clause(self, select): + # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use + return '' + + def order_by_clause(self, select): + order_by = self.process(select._order_by_clause) + + # SybaseSQL only allows ORDER BY in subqueries if there is a LIMIT + if order_by and (not self.is_subquery() or select._limit): + return " ORDER BY " + order_by + else: + return "" + + +class SybaseDDLCompiler(compiler.DDLCompiler): + def get_column_specification(self, column, **kwargs): + + colspec = self.preparer.format_column(column) + + if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \ + column.autoincrement and isinstance(column.type, sqltypes.Integer): + if column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional): + column.sequence = schema.Sequence(column.name + '_seq') + + if hasattr(column, 'sequence'): + column.table.has_sequence = column + #colspec += " numeric(30,0) IDENTITY" + colspec += " Integer IDENTITY" + else: + colspec += " " + self.dialect.type_compiler.process(column.type) + + if not column.nullable: + colspec += " NOT NULL" + + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + + return colspec + + def visit_drop_index(self, drop): + index = drop.element + return "\nDROP INDEX %s.%s" % ( + self.preparer.quote_identifier(index.table.name), + self.preparer.quote(self._validate_identifier(index.name, False), index.quote) + ) + +class SybaseIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = RESERVED_WORDS + +class SybaseDialect(default.DefaultDialect): + name = 'sybase' + supports_unicode_statements = False + supports_sane_rowcount = False + supports_sane_multi_rowcount = False + colspecs = colspecs + ischema_names = ischema_names + + type_compiler = SybaseTypeCompiler + statement_compiler = SybaseSQLCompiler + ddl_compiler = SybaseDDLCompiler + preparer = SybaseIdentifierPreparer + + schema_name = "dba" + + def __init__(self, **params): + super(SybaseDialect, self).__init__(**params) + self.text_as_varchar = False + + def last_inserted_ids(self): + return self.context.last_inserted_ids + + def get_default_schema_name(self, connection): + return self.schema_name + + def table_names(self, connection, schema): + """Ignore the schema and the charset for now.""" + s = sql.select([tables.c.table_name], + sql.not_(tables.c.table_name.like("SYS%")) and + tables.c.creator >= 100 + ) + rp = connection.execute(s) + return [row[0] for row in rp.fetchall()] + + def has_table(self, connection, tablename, schema=None): + # FIXME: ignore schemas for sybase + s = sql.select([tables.c.table_name], tables.c.table_name == tablename) + + c = connection.execute(s) + row = c.fetchone() + return row is not None + + def reflecttable(self, connection, table, include_columns): + # Get base columns + if table.schema is not None: + current_schema = table.schema + else: + current_schema = self.get_default_schema_name(connection) + + s = sql.select([columns, domains], tables.c.table_name==table.name, from_obj=[columns.join(tables).join(domains)], order_by=[columns.c.column_id]) + + c = connection.execute(s) + found_table = False + # makes sure we append the columns in the correct order + while True: + row = c.fetchone() + if row is None: + break + found_table = True + (name, type, nullable, charlen, numericprec, numericscale, default, primary_key, max_identity, table_id, column_id) = ( + row[columns.c.column_name], + row[domains.c.domain_name], + row[columns.c.nulls] == 'Y', + row[columns.c.width], + row[domains.c.precision], + row[columns.c.scale], + row[columns.c.default], + row[columns.c.pkey] == 'Y', + row[columns.c.max_identity], + row[tables.c.table_id], + row[columns.c.column_id], + ) + if include_columns and name not in include_columns: + continue + + # FIXME: else problems with SybaseBinary(size) + if numericscale == 0: + numericscale = None + + args = [] + for a in (charlen, numericprec, numericscale): + if a is not None: + args.append(a) + coltype = self.ischema_names.get(type, None) + if coltype == SybaseString and charlen == -1: + coltype = SybaseText() + else: + if coltype is None: + util.warn("Did not recognize type '%s' of column '%s'" % + (type, name)) + coltype = sqltypes.NULLTYPE + coltype = coltype(*args) + colargs = [] + if default is not None: + colargs.append(schema.DefaultClause(sql.text(default))) + + # any sequences ? + col = schema.Column(name, coltype, nullable=nullable, primary_key=primary_key, *colargs) + if int(max_identity) > 0: + col.sequence = schema.Sequence(name + '_identity') + col.sequence.start = int(max_identity) + col.sequence.increment = 1 + + # append the column + table.append_column(col) + + # any foreign key constraint for this table ? + # note: no multi-column foreign keys are considered + s = "select st1.table_name, sc1.column_name, st2.table_name, sc2.column_name from systable as st1 join sysfkcol on st1.table_id=sysfkcol.foreign_table_id join sysforeignkey join systable as st2 on sysforeignkey.primary_table_id = st2.table_id join syscolumn as sc1 on sysfkcol.foreign_column_id=sc1.column_id and sc1.table_id=st1.table_id join syscolumn as sc2 on sysfkcol.primary_column_id=sc2.column_id and sc2.table_id=st2.table_id where st1.table_name='%(table_name)s';" % { 'table_name' : table.name } + c = connection.execute(s) + foreignKeys = {} + while True: + row = c.fetchone() + if row is None: + break + (foreign_table, foreign_column, primary_table, primary_column) = ( + row[0], row[1], row[2], row[3], + ) + if not primary_table in foreignKeys.keys(): + foreignKeys[primary_table] = [['%s' % (foreign_column)], ['%s.%s'%(primary_table, primary_column)]] + else: + foreignKeys[primary_table][0].append('%s'%(foreign_column)) + foreignKeys[primary_table][1].append('%s.%s'%(primary_table, primary_column)) + for primary_table in foreignKeys.iterkeys(): + #table.append_constraint(schema.ForeignKeyConstraint(['%s.%s'%(foreign_table, foreign_column)], ['%s.%s'%(primary_table,primary_column)])) + table.append_constraint(schema.ForeignKeyConstraint(foreignKeys[primary_table][0], foreignKeys[primary_table][1], link_to_name=True)) + + if not found_table: + raise exc.NoSuchTableError(table.name) + diff --git a/lib/sqlalchemy/dialects/sybase/mxodbc.py b/lib/sqlalchemy/dialects/sybase/mxodbc.py new file mode 100644 index 0000000000..86a23d5bcd --- /dev/null +++ b/lib/sqlalchemy/dialects/sybase/mxodbc.py @@ -0,0 +1,10 @@ +from sqlalchemy.dialects.sybase.base import SybaseDialect, SybaseExecutionContext +from sqlalchemy.connectors.mxodbc import MxODBCConnector + +class SybaseExecutionContext_mxodbc(SybaseExecutionContext): + pass + +class Sybase_mxodbc(MxODBCConnector, SybaseDialect): + execution_ctx_cls = SybaseExecutionContext_mxodbc + +dialect = Sybase_mxodbc \ No newline at end of file diff --git a/lib/sqlalchemy/dialects/sybase/pyodbc.py b/lib/sqlalchemy/dialects/sybase/pyodbc.py new file mode 100644 index 0000000000..61c6f32928 --- /dev/null +++ b/lib/sqlalchemy/dialects/sybase/pyodbc.py @@ -0,0 +1,11 @@ +from sqlalchemy.dialects.sybase.base import SybaseDialect, SybaseExecutionContext +from sqlalchemy.connectors.pyodbc import PyODBCConnector + +class SybaseExecutionContext_pyodbc(SybaseExecutionContext): + pass + + +class Sybase_pyodbc(PyODBCConnector, SybaseDialect): + execution_ctx_cls = SybaseExecutionContext_pyodbc + +dialect = Sybase_pyodbc \ No newline at end of file diff --git a/lib/sqlalchemy/dialects/sybase/schema.py b/lib/sqlalchemy/dialects/sybase/schema.py new file mode 100644 index 0000000000..15ac6b27bd --- /dev/null +++ b/lib/sqlalchemy/dialects/sybase/schema.py @@ -0,0 +1,51 @@ +from sqlalchemy import * + +ischema = MetaData() + +tables = Table("SYSTABLE", ischema, + Column("table_id", Integer, primary_key=True), + Column("file_id", SMALLINT), + Column("table_name", CHAR(128)), + Column("table_type", CHAR(10)), + Column("creator", Integer), + #schema="information_schema" + ) + +domains = Table("SYSDOMAIN", ischema, + Column("domain_id", Integer, primary_key=True), + Column("domain_name", CHAR(128)), + Column("type_id", SMALLINT), + Column("precision", SMALLINT, quote=True), + #schema="information_schema" + ) + +columns = Table("SYSCOLUMN", ischema, + Column("column_id", Integer, primary_key=True), + Column("table_id", Integer, ForeignKey(tables.c.table_id)), + Column("pkey", CHAR(1)), + Column("column_name", CHAR(128)), + Column("nulls", CHAR(1)), + Column("width", SMALLINT), + Column("domain_id", SMALLINT, ForeignKey(domains.c.domain_id)), + # FIXME: should be mx.BIGINT + Column("max_identity", Integer), + # FIXME: should be mx.ODBC.Windows.LONGVARCHAR + Column("default", String), + Column("scale", Integer), + #schema="information_schema" + ) + +foreignkeys = Table("SYSFOREIGNKEY", ischema, + Column("foreign_table_id", Integer, ForeignKey(tables.c.table_id), primary_key=True), + Column("foreign_key_id", SMALLINT, primary_key=True), + Column("primary_table_id", Integer, ForeignKey(tables.c.table_id)), + #schema="information_schema" + ) +fkcols = Table("SYSFKCOL", ischema, + Column("foreign_table_id", Integer, ForeignKey(columns.c.table_id), primary_key=True), + Column("foreign_key_id", SMALLINT, ForeignKey(foreignkeys.c.foreign_key_id), primary_key=True), + Column("foreign_column_id", Integer, ForeignKey(columns.c.column_id), primary_key=True), + Column("primary_column_id", Integer), + #schema="information_schema" + ) + diff --git a/lib/sqlalchemy/dialects/type_migration_guidelines.txt b/lib/sqlalchemy/dialects/type_migration_guidelines.txt new file mode 100644 index 0000000000..8ed1a17975 --- /dev/null +++ b/lib/sqlalchemy/dialects/type_migration_guidelines.txt @@ -0,0 +1,145 @@ +Rules for Migrating TypeEngine classes to 0.6 +--------------------------------------------- + +1. the TypeEngine classes are used for: + + a. Specifying behavior which needs to occur for bind parameters + or result row columns. + + b. Specifying types that are entirely specific to the database + in use and have no analogue in the sqlalchemy.types package. + + c. Specifying types where there is an analogue in sqlalchemy.types, + but the database in use takes vendor-specific flags for those + types. + + d. If a TypeEngine class doesn't provide any of this, it should be + *removed* from the dialect. + +2. the TypeEngine classes are *no longer* used for generating DDL. Dialects +now have a TypeCompiler subclass which uses the same visit_XXX model as +other compilers. + +3. the "ischema_names" and "colspecs" dictionaries are now required members on +the Dialect class. + +4. The names of types within dialects are now important. If a dialect-specific type +is a subclass of an existing generic type and is only provided for bind/result behavior, +the current mixed case naming can remain, i.e. _PGNumeric for Numeric - in this case, +end users would never need to use _PGNumeric directly. However, if a dialect-specific +type is specifying a type *or* arguments that are not present generically, it should +match the real name of the type on that backend, in uppercase. E.g. postgresql.INET, +mysql.ENUM, postgresql.ARRAY. + +Or follow this handy flowchart: + + is the type meant to provide bind/result is the type the same name as an + behavior to a generic type (i.e. MixedCase) ---- no ---> UPPERCASE type in types.py ? + type in types.py ? | | + | no yes + yes | | + | | does your type need special + | +<--- yes --- behavior or arguments ? + | | | + | | no + name the type using | | + _MixedCase, i.e. v V + _OracleBoolean. it name the type don't make a + stays private to the dialect identically as that type, make sure the dialect's + and is invoked *only* via within the DB, base.py imports the types.py + the colspecs dict. using UPPERCASE UPPERCASE name into its namespace + | (i.e. BIT, NCHAR, INTERVAL). + | Users can import it. + | | + v v + subclass the closest is the name of this type + MixedCase type types.py, identical to an UPPERCASE + i.e. <--- no ------- name in types.py ? + class _DateTime(types.DateTime), + class DATETIME2(types.DateTime), | + class BIT(types.TypeEngine). yes + | + v + the type should + subclass the + UPPERCASE + type in types.py + (i.e. class BLOB(types.BLOB)) + + +Example 1. pysqlite needs bind/result processing for the DateTime type in types.py, +which applies to all DateTimes and subclasses. It's named _SLDateTime and +subclasses types.DateTime. + +Example 2. MS-SQL has a TIME type which takes a non-standard "precision" argument +that is rendered within DDL. So it's named TIME in the MS-SQL dialect's base.py, +and subclasses types.TIME. Users can then say mssql.TIME(precision=10). + +Example 3. MS-SQL dialects also need special bind/result processing for date +But its DATE type doesn't render DDL differently than that of a plain +DATE, i.e. it takes no special arguments. Therefore we are just adding behavior +to types.Date, so it's named _MSDate in the MS-SQL dialect's base.py, and subclasses +types.Date. + +Example 4. MySQL has a SET type, there's no analogue for this in types.py. So +MySQL names it SET in the dialect's base.py, and it subclasses types.String, since +it ultimately deals with strings. + +Example 5. Postgresql has a DATETIME type. The DBAPIs handle dates correctly, +and no special arguments are used in PG's DDL beyond what types.py provides. +Postgresql dialect therefore imports types.DATETIME into its base.py. + +Ideally one should be able to specify a schema using names imported completely from a +dialect, all matching the real name on that backend: + + from sqlalchemy.dialects.postgresql import base as pg + + t = Table('mytable', metadata, + Column('id', pg.INTEGER, primary_key=True), + Column('name', pg.VARCHAR(300)), + Column('inetaddr', pg.INET) + ) + +where above, the INTEGER and VARCHAR types are ultimately from sqlalchemy.types, +but the PG dialect makes them available in its own namespace. + +5. "colspecs" now is a dictionary of generic or uppercased types from sqlalchemy.types +linked to types specified in the dialect. Again, if a type in the dialect does not +specify any special behavior for bind_processor() or result_processor() and does not +indicate a special type only available in this database, it must be *removed* from the +module and from this dictionary. + +6. "ischema_names" indicates string descriptions of types as returned from the database +linked to TypeEngine classes. + + a. The string name should be matched to the most specific type possible within + sqlalchemy.types, unless there is no matching type within sqlalchemy.types in which + case it points to a dialect type. *It doesn't matter* if the dialect has it's + own subclass of that type with special bind/result behavior - reflect to the types.py + UPPERCASE type as much as possible. With very few exceptions, all types + should reflect to an UPPERCASE type. + + b. If the dialect contains a matching dialect-specific type that takes extra arguments + which the generic one does not, then point to the dialect-specific type. E.g. + mssql.VARCHAR takes a "collation" parameter which should be preserved. + +5. DDL, or what was formerly issued by "get_col_spec()", is now handled exclusively by +a subclass of compiler.GenericTypeCompiler. + + a. your TypeCompiler class will receive generic and uppercase types from + sqlalchemy.types. Do not assume the presence of dialect-specific attributes on + these types. + + b. the visit_UPPERCASE methods on GenericTypeCompiler should *not* be overridden with + methods that produce a different DDL name. Uppercase types don't do any kind of + "guessing" - if visit_TIMESTAMP is called, the DDL should render as TIMESTAMP in + all cases, regardless of whether or not that type is legal on the backend database. + + c. the visit_UPPERCASE methods *should* be overridden with methods that add additional + arguments and flags to those types. + + d. the visit_lowercase methods are overridden to provide an interpretation of a generic + type. E.g. visit_binary() might be overridden to say "return self.visit_BIT(type_)". + + e. visit_lowercase methods should *never* render strings directly - it should always + be via calling a visit_UPPERCASE() method. diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py index bb2b1b5be4..694a2f71fa 100644 --- a/lib/sqlalchemy/engine/__init__.py +++ b/lib/sqlalchemy/engine/__init__.py @@ -50,7 +50,9 @@ url.py within a URL. """ -import sqlalchemy.databases +# not sure what this was used for +#import sqlalchemy.databases + from sqlalchemy.engine.base import ( BufferedColumnResultProxy, BufferedColumnRow, @@ -66,9 +68,9 @@ from sqlalchemy.engine.base import ( ResultProxy, RootTransaction, RowProxy, - SchemaIterator, Transaction, - TwoPhaseTransaction + TwoPhaseTransaction, + TypeCompiler ) from sqlalchemy.engine import strategies from sqlalchemy import util @@ -89,9 +91,9 @@ __all__ = ( 'ResultProxy', 'RootTransaction', 'RowProxy', - 'SchemaIterator', 'Transaction', 'TwoPhaseTransaction', + 'TypeCompiler', 'create_engine', 'engine_from_config', ) @@ -108,7 +110,7 @@ def create_engine(*args, **kwargs): The URL is a string in the form ``dialect://user:password@host/dbname[?key=value..]``, where - ``dialect`` is a name such as ``mysql``, ``oracle``, ``postgres``, + ``dialect`` is a name such as ``mysql``, ``oracle``, ``postgresql``, etc. Alternatively, the URL can be an instance of :class:`~sqlalchemy.engine.url.URL`. diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 39085c3596..0a0b0ff0ca 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -10,14 +10,16 @@ Defines the basic components used to interface DB-API modules with higher-level statement-construction, connection-management, execution and result contexts. - """ -__all__ = ['BufferedColumnResultProxy', 'BufferedColumnRow', 'BufferedRowResultProxy', 'Compiled', 'Connectable', - 'Connection', 'DefaultRunner', 'Dialect', 'Engine', 'ExecutionContext', 'NestedTransaction', 'ResultProxy', - 'RootTransaction', 'RowProxy', 'SchemaIterator', 'StringIO', 'Transaction', 'TwoPhaseTransaction', 'connection_memoize'] +__all__ = [ + 'BufferedColumnResultProxy', 'BufferedColumnRow', 'BufferedRowResultProxy', + 'Compiled', 'Connectable', 'Connection', 'DefaultRunner', 'Dialect', 'Engine', + 'ExecutionContext', 'NestedTransaction', 'ResultProxy', 'RootTransaction', + 'RowProxy', 'SchemaIterator', 'StringIO', 'Transaction', 'TwoPhaseTransaction', + 'connection_memoize'] -import inspect, StringIO +import inspect, StringIO, sys from sqlalchemy import exc, schema, util, types, log from sqlalchemy.sql import expression @@ -32,10 +34,14 @@ class Dialect(object): ExecutionContext, Compiled, DefaultGenerator, and TypeEngine. All Dialects implement the following attributes: - + name - identifying name for the dialect (i.e. 'sqlite') - + identifying name for the dialect from a DBAPI-neutral point of view + (i.e. 'sqlite') + + driver + identifying name for the dialect's DBAPI + positional True if the paramstyle for this Dialect is positional. @@ -51,20 +57,25 @@ class Dialect(object): type of encoding to use for unicode, usually defaults to 'utf-8'. - schemagenerator - a :class:`~sqlalchemy.schema.SchemaVisitor` class which generates - schemas. - - schemadropper - a :class:`~sqlalchemy.schema.SchemaVisitor` class which drops schemas. - defaultrunner a :class:`~sqlalchemy.schema.SchemaVisitor` class which executes defaults. statement_compiler - a :class:`~sqlalchemy.engine.base.Compiled` class used to compile SQL - statements + a :class:`~Compiled` class used to compile SQL statements + + ddl_compiler + a :class:`~Compiled` class used to compile DDL statements + + server_version_info + a tuple containing a version number for the DB backend in use. + This value is only available for supporting dialects, and only for + a dialect that's been associated with a connection pool via + create_engine() or otherwise had its ``initialize()`` method called + with a conneciton. + + execution_ctx_cls + a :class:`ExecutionContext` class used to handle statement execution preparer a :class:`~sqlalchemy.sql.compiler.IdentifierPreparer` class used to @@ -77,27 +88,38 @@ class Dialect(object): The maximum length of identifier names. supports_unicode_statements - Indicate whether the DB-API can receive SQL statements as Python unicode strings + Indicate whether the DB-API can receive SQL statements as Python + unicode strings + + supports_unicode_binds + Indicate whether the DB-API can receive string bind parameters + as Python unicode strings supports_sane_rowcount - Indicate whether the dialect properly implements rowcount for ``UPDATE`` and ``DELETE`` statements. + Indicate whether the dialect properly implements rowcount for + ``UPDATE`` and ``DELETE`` statements. supports_sane_multi_rowcount - Indicate whether the dialect properly implements rowcount for ``UPDATE`` and ``DELETE`` statements - when executed via executemany. - - preexecute_pk_sequences - Indicate if the dialect should pre-execute sequences on primary key - columns during an INSERT, if it's desired that the new row's primary key - be available after execution. - - supports_pk_autoincrement - Indicates if the dialect should allow the database to passively assign - a primary key column value. - + Indicate whether the dialect properly implements rowcount for + ``UPDATE`` and ``DELETE`` statements when executed via + executemany. + + preexecute_autoincrement_sequences + True if 'implicit' primary key functions must be executed separately + in order to get their value. This is currently oriented towards + Postgresql. + + implicit_returning + use RETURNING or equivalent during INSERT execution in order to load + newly generated primary keys and other column defaults in one execution, + which are then available via inserted_primary_key. + If an insert statement has returning() specified explicitly, + the "implicit" functionality is not used and inserted_primary_key + will not be available. + dbapi_type_map A mapping of DB-API type objects present in this Dialect's - DB-API implmentation mapped to TypeEngine implementations used + DB-API implementation mapped to TypeEngine implementations used by the dialect. This is used to apply types to result sets based on the DB-API @@ -105,13 +127,15 @@ class Dialect(object): result sets against textual statements where no explicit typemap was present. - supports_default_values - Indicates if the construct ``INSERT INTO tablename DEFAULT VALUES`` is supported + colspecs + A dictionary of TypeEngine classes from sqlalchemy.types mapped + to subclasses that are specific to the dialect class. This + dictionary is class-level only and is not accessed from the + dialect instance itself. - description_encoding - type of encoding to use for unicode when working with metadata - descriptions. If set to ``None`` no encoding will be done. - This usually defaults to 'utf-8'. + supports_default_values + Indicates if the construct ``INSERT INTO tablename DEFAULT + VALUES`` is supported """ def create_connect_args(self, url): @@ -124,25 +148,28 @@ class Dialect(object): raise NotImplementedError() + @classmethod + def type_descriptor(cls, typeobj): + """Transform a generic type to a dialect-specific type. - def type_descriptor(self, typeobj): - """Transform a generic type to a database-specific type. - - Transforms the given :class:`~sqlalchemy.types.TypeEngine` instance - from generic to database-specific. - - Subclasses will usually use the + Dialect classes will usually use the :func:`~sqlalchemy.types.adapt_type` method in the types module to make this job easy. + + The returned result is cached *per dialect class* so can + contain no dialect-instance state. """ raise NotImplementedError() + def initialize(self, connection): + """Called during strategized creation of the dialect with a connection. - def server_version_info(self, connection): - """Return a tuple of the database's version number.""" + Allows dialects to configure options based on server version info or + other properties. + """ - raise NotImplementedError() + pass def reflecttable(self, connection, table, include_columns=None): """Load table description from the database. @@ -156,6 +183,133 @@ class Dialect(object): raise NotImplementedError() + def get_columns(self, connection, table_name, schema=None, **kw): + """Return information about columns in `table_name`. + + Given a :class:`~sqlalchemy.engine.Connection`, a string + `table_name`, and an optional string `schema`, return column + information as a list of dictionaries with these keys: + + name + the column's name + + type + [sqlalchemy.types#TypeEngine] + + nullable + boolean + + default + the column's default value + + autoincrement + boolean + + sequence + a dictionary of the form + {'name' : str, 'start' :int, 'increment': int} + + Additional column attributes may be present. + """ + + raise NotImplementedError() + + def get_primary_keys(self, connection, table_name, schema=None, **kw): + """Return information about primary keys in `table_name`. + + Given a :class:`~sqlalchemy.engine.Connection`, a string + `table_name`, and an optional string `schema`, return primary + key information as a list of column names. + """ + + raise NotImplementedError() + + def get_foreign_keys(self, connection, table_name, schema=None, **kw): + """Return information about foreign_keys in `table_name`. + + Given a :class:`~sqlalchemy.engine.Connection`, a string + `table_name`, and an optional string `schema`, return foreign + key information as a list of dicts with these keys: + + name + the constraint's name + + constrained_columns + a list of column names that make up the foreign key + + referred_schema + the name of the referred schema + + referred_table + the name of the referred table + + referred_columns + a list of column names in the referred table that correspond to + constrained_columns + """ + + raise NotImplementedError() + + def get_table_names(self, connection, schema=None, **kw): + """Return a list of table names for `schema`.""" + + raise NotImplementedError + + def get_view_names(self, connection, schema=None, **kw): + """Return a list of all view names available in the database. + + schema: + Optional, retrieve names from a non-default schema. + """ + + raise NotImplementedError() + + def get_view_definition(self, connection, view_name, schema=None, **kw): + """Return view definition. + + Given a :class:`~sqlalchemy.engine.Connection`, a string + `view_name`, and an optional string `schema`, return the view + definition. + """ + + raise NotImplementedError() + + def get_indexes(self, connection, table_name, schema=None, **kw): + """Return information about indexes in `table_name`. + + Given a :class:`~sqlalchemy.engine.Connection`, a string + `table_name` and an optional string `schema`, return index + information as a list of dictionaries with these keys: + + name + the index's name + + column_names + list of column names in order + + unique + boolean + """ + + raise NotImplementedError() + + def normalize_name(self, name): + """convert the given name to lowercase if it is detected as case insensitive. + + this method is only used if the dialect defines requires_name_normalize=True. + + """ + raise NotImplementedError() + + def denormalize_name(self, name): + """convert the given name to a case insensitive identifier for the backend + if it is an all-lowercase name. + + this method is only used if the dialect defines requires_name_normalize=True. + + """ + raise NotImplementedError() + def has_table(self, connection, table_name, schema=None): """Check the existence of a particular table in the database. @@ -178,7 +332,11 @@ class Dialect(object): raise NotImplementedError() def get_default_schema_name(self, connection): - """Return the string name of the currently selected schema given a :class:`~sqlalchemy.engine.Connection`.""" + """Return the string name of the currently selected schema given a :class:`~sqlalchemy.engine.Connection`. + + DEPRECATED. moving this towards dialect.default_schema_name (not complete). + + """ raise NotImplementedError() @@ -262,11 +420,14 @@ class Dialect(object): raise NotImplementedError() + def visit_pool(self, pool): + """Executed after a pool is created.""" + class ExecutionContext(object): """A messenger object for a Dialect that corresponds to a single execution. - ExecutionContext should have these datamembers: + ExecutionContext should have these data members: connection Connection object which can be freely used by default value @@ -308,20 +469,19 @@ class ExecutionContext(object): True if the statement is an UPDATE. should_autocommit - True if the statement is a "committable" statement + True if the statement is a "committable" statement. postfetch_cols - a list of Column objects for which a server-side default - or inline SQL expression value was fired off. applies to inserts and updates. - - + a list of Column objects for which a server-side default or + inline SQL expression value was fired off. Applies to inserts + and updates. """ def create_cursor(self): """Return a new cursor generated from this ExecutionContext's connection. Some dialects may wish to change the behavior of - connection.cursor(), such as postgres which may return a PG + connection.cursor(), such as postgresql which may return a PG "server side" cursor. """ @@ -357,22 +517,11 @@ class ExecutionContext(object): def handle_dbapi_exception(self, e): """Receive a DBAPI exception which occured upon execute, result fetch, etc.""" - - raise NotImplementedError() - - def should_autocommit_text(self, statement): - """Parse the given textual statement and return True if it refers to a "committable" statement""" raise NotImplementedError() - def last_inserted_ids(self): - """Return the list of the primary key values for the last insert statement executed. - - This does not apply to straight textual clauses; only to - ``sql.Insert`` objects compiled against a ``schema.Table`` - object. The order of items in the list is the same as that of - the Table's 'primary_key' attribute. - """ + def should_autocommit_text(self, statement): + """Parse the given textual statement and return True if it refers to a "committable" statement""" raise NotImplementedError() @@ -401,7 +550,7 @@ class ExecutionContext(object): class Compiled(object): - """Represent a compiled SQL expression. + """Represent a compiled SQL or DDL expression. The ``__str__`` method of the ``Compiled`` object should produce the actual text of the statement. ``Compiled`` objects are @@ -413,53 +562,49 @@ class Compiled(object): defaults. """ - def __init__(self, dialect, statement, column_keys=None, bind=None): + def __init__(self, dialect, statement, bind=None): """Construct a new ``Compiled`` object. - dialect - ``Dialect`` to compile against. - - statement - ``ClauseElement`` to be compiled. + :param dialect: ``Dialect`` to compile against. - column_keys - a list of column names to be compiled into an INSERT or UPDATE - statement. + :param statement: ``ClauseElement`` to be compiled. - bind - Optional Engine or Connection to compile this statement against. - + :param bind: Optional Engine or Connection to compile this statement against. """ + self.dialect = dialect self.statement = statement - self.column_keys = column_keys self.bind = bind self.can_execute = statement.supports_execution def compile(self): """Produce the internal string representation of this element.""" - raise NotImplementedError() + self.string = self.process(self.statement) - def __str__(self): - """Return the string text of the generated SQL statement.""" + def process(self, obj, **kwargs): + return obj._compiler_dispatch(self, **kwargs) - raise NotImplementedError() + def __str__(self): + """Return the string text of the generated SQL or DDL.""" - @util.deprecated('Deprecated. Use construct_params(). ' - '(supports Unicode key names.)') - def get_params(self, **params): - return self.construct_params(params) + return self.string or '' - def construct_params(self, params): + def construct_params(self, params=None): """Return the bind params for this compiled object. - `params` is a dict of string/object pairs whos - values will override bind values compiled in - to the statement. + :param params: a dict of string/object pairs whos values will + override bind values compiled in to the + statement. """ + raise NotImplementedError() + @property + def params(self): + """Return the bind params for this compiled object.""" + return self.construct_params() + def execute(self, *multiparams, **params): """Execute this compiled object.""" @@ -474,12 +619,24 @@ class Compiled(object): return self.execute(*multiparams, **params).scalar() +class TypeCompiler(object): + """Produces DDL specification for TypeEngine objects.""" + + def __init__(self, dialect): + self.dialect = dialect + + def process(self, type_): + return type_._compiler_dispatch(self) + + class Connectable(object): """Interface for an object which supports execution of SQL constructs. - + The two implementations of ``Connectable`` are :class:`Connection` and :class:`Engine`. - + + Connectable must also implement the 'dialect' member which references a + :class:`Dialect` instance. """ def contextual_connect(self): @@ -503,6 +660,7 @@ class Connectable(object): def _execute_clauseelement(self, elem, multiparams=None, params=None): raise NotImplementedError() + class Connection(Connectable): """Provides high-level functionality for a wrapped DB-API connection. @@ -514,7 +672,6 @@ class Connection(Connectable): .. index:: single: thread safety; Connection - """ def __init__(self, engine, connection=None, close_with_result=False, @@ -524,7 +681,6 @@ class Connection(Connectable): Connection objects are typically constructed by an :class:`~sqlalchemy.engine.Engine`, see the ``connect()`` and ``contextual_connect()`` methods of Engine. - """ self.engine = engine @@ -534,7 +690,7 @@ class Connection(Connectable): self.__savepoint_seq = 0 self.__branch = _branch self.__invalid = False - + def _branch(self): """Return a new Connection which references this Connection's engine and connection; but does not have close_with_result enabled, @@ -542,8 +698,8 @@ class Connection(Connectable): This is used to execute "sub" statements within a single execution, usually an INSERT statement. - """ + return self.engine.Connection(self.engine, self.__connection, _branch=True) @property @@ -554,13 +710,13 @@ class Connection(Connectable): @property def closed(self): - """return True if this connection is closed.""" + """Return True if this connection is closed.""" return not self.__invalid and '_Connection__connection' not in self.__dict__ @property def invalidated(self): - """return True if this connection was invalidated.""" + """Return True if this connection was invalidated.""" return self.__invalid @@ -583,13 +739,14 @@ class Connection(Connectable): def should_close_with_result(self): """Indicates if this Connection should be closed when a corresponding ResultProxy is closed; this is essentially an auto-release mode. - """ + return self.__close_with_result @property def info(self): """A collection of per-DB-API connection instance properties.""" + return self.connection.info def connect(self): @@ -598,8 +755,8 @@ class Connection(Connectable): This ``Connectable`` interface method returns self, allowing Connections to be used interchangably with Engines in most situations that require a bind. - """ + return self def contextual_connect(self, **kwargs): @@ -608,8 +765,8 @@ class Connection(Connectable): This ``Connectable`` interface method returns self, allowing Connections to be used interchangably with Engines in most situations that require a bind. - """ + return self def invalidate(self, exception=None): @@ -627,8 +784,8 @@ class Connection(Connectable): rolled back before a reconnect on this Connection can proceed. This is to prevent applications from accidentally continuing their transactional operations in a non-transactional state. - """ + if self.closed: raise exc.InvalidRequestError("This Connection is closed") @@ -651,8 +808,8 @@ class Connection(Connectable): :class:`~sqlalchemy.interfaces.PoolListener` for a mechanism to modify connection state when connections leave and return to their connection pool. - """ + self.__connection.detach() def begin(self): @@ -663,8 +820,8 @@ class Connection(Connectable): outermost transaction may ``commit``. Calls to ``commit`` on inner transactions are ignored. Any transaction in the hierarchy may ``rollback``, however. - """ + if self.__transaction is None: self.__transaction = RootTransaction(self) else: @@ -690,9 +847,8 @@ class Connection(Connectable): def begin_twophase(self, xid=None): """Begin a two-phase or XA transaction and return a Transaction handle. - xid - the two phase transaction id. If not supplied, a random id - will be generated. + :param xid: the two phase transaction id. If not supplied, a random id + will be generated. """ if self.__transaction is not None: @@ -813,9 +969,6 @@ class Connection(Connectable): return self.execute(object, *multiparams, **params).scalar() - def statement_compiler(self, statement, **kwargs): - return self.dialect.statement_compiler(self.dialect, statement, bind=self, **kwargs) - def execute(self, object, *multiparams, **params): """Executes and returns a ResultProxy.""" @@ -826,11 +979,12 @@ class Connection(Connectable): raise exc.InvalidRequestError("Unexecutable object type: " + str(type(object))) def __distill_params(self, multiparams, params): - """given arguments from the calling form *multiparams, **params, return a list + """Given arguments from the calling form *multiparams, **params, return a list of bind parameter structures, usually a list of dictionaries. - in the case of 'raw' execution which accepts positional parameters, - it may be a list of tuples or lists.""" + In the case of 'raw' execution which accepts positional parameters, + it may be a list of tuples or lists. + """ if not multiparams: if params: @@ -858,7 +1012,19 @@ class Connection(Connectable): return self._execute_clauseelement(func.select(), multiparams, params) def _execute_default(self, default, multiparams, params): - return self.engine.dialect.defaultrunner(self.__create_execution_context()).traverse_single(default) + ret = self.engine.dialect.\ + defaultrunner(self.__create_execution_context()).\ + traverse_single(default) + if self.__close_with_result: + self.close() + return ret + + def _execute_ddl(self, ddl, params, multiparams): + context = self.__create_execution_context( + compiled_ddl=ddl.compile(dialect=self.dialect), + parameters=None + ) + return self.__execute_context(context) def _execute_clauseelement(self, elem, multiparams, params): params = self.__distill_params(multiparams, params) @@ -868,7 +1034,7 @@ class Connection(Connectable): keys = [] context = self.__create_execution_context( - compiled=elem.compile(dialect=self.dialect, column_keys=keys, inline=len(params) > 1), + compiled_sql=elem.compile(dialect=self.dialect, column_keys=keys, inline=len(params) > 1), parameters=params ) return self.__execute_context(context) @@ -877,7 +1043,7 @@ class Connection(Connectable): """Execute a sql.Compiled object.""" context = self.__create_execution_context( - compiled=compiled, + compiled_sql=compiled, parameters=self.__distill_params(multiparams, params) ) return self.__execute_context(context) @@ -886,38 +1052,42 @@ class Connection(Connectable): parameters = self.__distill_params(multiparams, params) context = self.__create_execution_context(statement=statement, parameters=parameters) return self.__execute_context(context) - + def __execute_context(self, context): if context.compiled: context.pre_exec() + if context.executemany: self._cursor_executemany(context.cursor, context.statement, context.parameters, context=context) else: self._cursor_execute(context.cursor, context.statement, context.parameters[0], context=context) + if context.compiled: context.post_exec() + + if context.isinsert and not context.executemany: + context.post_insert() + if context.should_autocommit and not self.in_transaction(): self._commit_impl() - return context.get_result_proxy() + + return context.get_result_proxy()._autoclose() - def _execute_ddl(self, ddl, params, multiparams): - if params: - schema_item, params = params[0], params[1:] - else: - schema_item = None - return ddl(None, schema_item, self, *params, **multiparams) - def _handle_dbapi_exception(self, e, statement, parameters, cursor, context): if getattr(self, '_reentrant_error', False): - raise exc.DBAPIError.instance(None, None, e) + # Py3K + #raise exc.DBAPIError.instance(statement, parameters, e) from e + # Py2K + raise exc.DBAPIError.instance(statement, parameters, e), None, sys.exc_info()[2] + # end Py2K self._reentrant_error = True try: if not isinstance(e, self.dialect.dbapi.Error): return - + if context: context.handle_dbapi_exception(e) - + is_disconnect = self.dialect.is_disconnect(e) if is_disconnect: self.invalidate(e) @@ -928,7 +1098,12 @@ class Connection(Connectable): self._autorollback() if self.__close_with_result: self.close() - raise exc.DBAPIError.instance(statement, parameters, e, connection_invalidated=is_disconnect) + # Py3K + #raise exc.DBAPIError.instance(statement, parameters, e, connection_invalidated=is_disconnect) from e + # Py2K + raise exc.DBAPIError.instance(statement, parameters, e, connection_invalidated=is_disconnect), None, sys.exc_info()[2] + # end Py2K + finally: del self._reentrant_error @@ -966,7 +1141,7 @@ class Connection(Connectable): expression.ClauseElement: _execute_clauseelement, Compiled: _execute_compiled, schema.SchemaItem: _execute_default, - schema.DDL: _execute_ddl, + schema.DDLElement: _execute_ddl, basestring: _execute_text } @@ -991,6 +1166,7 @@ class Connection(Connectable): def run_callable(self, callable_): return callable_(self) + class Transaction(object): """Represent a Transaction in progress. @@ -998,14 +1174,13 @@ class Transaction(object): .. index:: single: thread safety; Transaction - """ def __init__(self, connection, parent): self.connection = connection self._parent = parent or self self.is_active = True - + def close(self): """Close this transaction. @@ -1016,6 +1191,7 @@ class Transaction(object): This is used to cancel a Transaction without affecting the scope of an enclosing transaction. """ + if not self._parent.is_active: return if self._parent is self: @@ -1048,6 +1224,7 @@ class Transaction(object): else: self.rollback() + class RootTransaction(Transaction): def __init__(self, connection): super(RootTransaction, self).__init__(connection, None) @@ -1059,6 +1236,7 @@ class RootTransaction(Transaction): def _do_commit(self): self.connection._commit_impl() + class NestedTransaction(Transaction): def __init__(self, connection, parent): super(NestedTransaction, self).__init__(connection, parent) @@ -1070,6 +1248,7 @@ class NestedTransaction(Transaction): def _do_commit(self): self.connection._release_savepoint_impl(self._savepoint, self._parent) + class TwoPhaseTransaction(Transaction): def __init__(self, connection, xid): super(TwoPhaseTransaction, self).__init__(connection, None) @@ -1089,9 +1268,10 @@ class TwoPhaseTransaction(Transaction): def commit(self): self.connection._commit_twophase_impl(self.xid, self._is_prepared) + class Engine(Connectable): """ - Connects a :class:`~sqlalchemy.pool.Pool` and :class:`~sqlalchemy.engine.base.Dialect` + Connects a :class:`~sqlalchemy.pool.Pool` and :class:`~sqlalchemy.engine.base.Dialect` together to provide a source of database connectivity and behavior. """ @@ -1111,9 +1291,15 @@ class Engine(Connectable): @property def name(self): "String name of the :class:`~sqlalchemy.engine.Dialect` in use by this ``Engine``." - + return self.dialect.name + @property + def driver(self): + "Driver name of the :class:`~sqlalchemy.engine.Dialect` in use by this ``Engine``." + + return self.dialect.driver + echo = log.echo_property() def __repr__(self): @@ -1126,12 +1312,16 @@ class Engine(Connectable): def create(self, entity, connection=None, **kwargs): """Create a table or index within this engine's database connection given a schema.Table object.""" - self._run_visitor(self.dialect.schemagenerator, entity, connection=connection, **kwargs) + from sqlalchemy.engine import ddl + + self._run_visitor(ddl.SchemaGenerator, entity, connection=connection, **kwargs) def drop(self, entity, connection=None, **kwargs): """Drop a table or index within this engine's database connection given a schema.Table object.""" - self._run_visitor(self.dialect.schemadropper, entity, connection=connection, **kwargs) + from sqlalchemy.engine import ddl + + self._run_visitor(ddl.SchemaDropper, entity, connection=connection, **kwargs) def _execute_default(self, default): connection = self.contextual_connect() @@ -1212,9 +1402,6 @@ class Engine(Connectable): connection = self.contextual_connect(close_with_result=True) return connection._execute_compiled(compiled, multiparams, params) - def statement_compiler(self, statement, **kwargs): - return self.dialect.statement_compiler(self.dialect, statement, bind=self, **kwargs) - def connect(self, **kwargs): """Return a newly allocated Connection object.""" @@ -1231,12 +1418,10 @@ class Engine(Connectable): def table_names(self, schema=None, connection=None): """Return a list of all table names available in the database. - schema: - Optional, retrieve names from a non-default schema. + :param schema: Optional, retrieve names from a non-default schema. - connection: - Optional, use a specified connection. Default is the - ``contextual_connect`` for this ``Engine``. + :param connection: Optional, use a specified connection. Default is the + ``contextual_connect`` for this ``Engine``. """ if connection is None: @@ -1275,22 +1460,24 @@ class Engine(Connectable): return self.pool.unique_connection() + def _proxy_connection_cls(cls, proxy): class ProxyConnection(cls): def execute(self, object, *multiparams, **params): return proxy.execute(self, super(ProxyConnection, self).execute, object, *multiparams, **params) - + def _execute_clauseelement(self, elem, multiparams=None, params=None): return proxy.execute(self, super(ProxyConnection, self).execute, elem, *(multiparams or []), **(params or {})) - + def _cursor_execute(self, cursor, statement, parameters, context=None): return proxy.cursor_execute(super(ProxyConnection, self)._cursor_execute, cursor, statement, parameters, context, False) - + def _cursor_executemany(self, cursor, statement, parameters, context=None): return proxy.cursor_execute(super(ProxyConnection, self)._cursor_executemany, cursor, statement, parameters, context, True) return ProxyConnection + class RowProxy(object): """Proxy a single cursor row for a parent ResultProxy. @@ -1302,7 +1489,7 @@ class RowProxy(object): """ __slots__ = ['__parent', '__row'] - + def __init__(self, parent, row): """RowProxy objects are constructed by ResultProxy objects.""" @@ -1327,7 +1514,7 @@ class RowProxy(object): yield self.__parent._get_col(self.__row, i) __hash__ = None - + def __eq__(self, other): return ((other is self) or (other == tuple(self.__parent._get_col(self.__row, key) @@ -1362,18 +1549,19 @@ class RowProxy(object): """Return the list of keys as strings represented by this RowProxy.""" return self.__parent.keys - + def iterkeys(self): return iter(self.__parent.keys) - + def values(self): """Return the values represented by this RowProxy as a list.""" return list(self) - + def itervalues(self): return iter(self) + class BufferedColumnRow(RowProxy): def __init__(self, parent, row): row = [ResultProxy._get_col(parent, row, i) for i in xrange(len(row))] @@ -1403,9 +1591,8 @@ class ResultProxy(object): """ _process_row = RowProxy - + def __init__(self, context): - """ResultProxy objects are constructed via the execute() method on SQLEngine.""" self.context = context self.dialect = context.dialect self.closed = False @@ -1413,40 +1600,81 @@ class ResultProxy(object): self.connection = context.root_connection self._echo = context.engine._should_log_info self._init_metadata() - - @property + + @util.memoized_property def rowcount(self): - if self._rowcount is None: - return self.context.get_rowcount() - else: - return self._rowcount + """Return the 'rowcount' for this result. + + The 'rowcount' reports the number of rows affected + by an UPDATE or DELETE statement. It has *no* other + uses and is not intended to provide the number of rows + present from a SELECT. + + Additionally, this value is only meaningful if the + dialect's supports_sane_rowcount flag is True for + single-parameter executions, or supports_sane_multi_rowcount + is true for multiple parameter executions - otherwise + results are undefined. + + rowcount may not work at this time for a statement + that uses ``returning()``. + + """ + return self.context.rowcount @property def lastrowid(self): + """return the 'lastrowid' accessor on the DBAPI cursor. + + This is a DBAPI specific method and is only functional + for those backends which support it, for statements + where it is appropriate. It's behavior is not + consistent across backends. + + Usage of this method is normally unnecessary; the + inserted_primary_key method provides a + tuple of primary key values for a newly inserted row, + regardless of database backend. + + """ return self.cursor.lastrowid @property def out_parameters(self): return self.context.out_parameters - + + def _cursor_description(self): + return self.cursor.description + + def _autoclose(self): + if self.context.isinsert: + if self.context._is_implicit_returning: + self.context._fetch_implicit_returning(self) + self.close() + elif not self.context._is_explicit_returning: + self.close() + elif self._metadata is None: + # no results, get rowcount + # (which requires open cursor on some DB's such as firebird), + self.rowcount + self.close() # autoclose + + return self + + def _init_metadata(self): - metadata = self.cursor.description + self._metadata = metadata = self._cursor_description() if metadata is None: - # no results, get rowcount (which requires open cursor on some DB's such as firebird), - # then close - self._rowcount = self.context.get_rowcount() - self.close() return - - self._rowcount = None + self._props = util.populate_column_dict(None) self._props.creator = self.__key_fallback() self.keys = [] typemap = self.dialect.dbapi_type_map - for i, item in enumerate(metadata): - colname = item[0] + for i, (colname, coltype) in enumerate(m[0:2] for m in metadata): + if self.dialect.description_encoding: colname = colname.decode(self.dialect.description_encoding) @@ -1461,9 +1689,9 @@ class ResultProxy(object): try: (name, obj, type_) = self.context.result_map[colname.lower()] except KeyError: - (name, obj, type_) = (colname, None, typemap.get(item[1], types.NULLTYPE)) + (name, obj, type_) = (colname, None, typemap.get(coltype, types.NULLTYPE)) else: - (name, obj, type_) = (colname, None, typemap.get(item[1], types.NULLTYPE)) + (name, obj, type_) = (colname, None, typemap.get(coltype, types.NULLTYPE)) rec = (type_, type_.dialect_impl(self.dialect).result_processor(self.dialect), i) @@ -1474,7 +1702,10 @@ class ResultProxy(object): if origname: if self._props.setdefault(origname.lower(), rec) is not rec: self._props[origname.lower()] = (type_, self.__ambiguous_processor(origname), 0) - + + if self.dialect.requires_name_normalize: + colname = self.dialect.normalize_name(colname) + self.keys.append(colname) self._props[i] = rec if obj: @@ -1484,11 +1715,11 @@ class ResultProxy(object): if self._echo: self.context.engine.logger.debug( "Col " + repr(tuple(x[0] for x in metadata))) - + def __key_fallback(self): # create a closure without 'self' to avoid circular references props = self._props - + def fallback(key): if isinstance(key, basestring): key = key.lower() @@ -1515,19 +1746,22 @@ class ResultProxy(object): def close(self): """Close this ResultProxy. - + Closes the underlying DBAPI cursor corresponding to the execution. + + Note that any data cached within this ResultProxy is still available. + For some types of results, this may include buffered rows. If this ResultProxy was generated from an implicit execution, the underlying Connection will also be closed (returns the underlying DBAPI connection to the connection pool.) This method is called automatically when: - - * all result rows are exhausted using the fetchXXX() methods. - * cursor.description is None. - + + * all result rows are exhausted using the fetchXXX() methods. + * cursor.description is None. """ + if not self.closed: self.closed = True self.cursor.close() @@ -1550,53 +1784,66 @@ class ResultProxy(object): raise StopIteration else: yield row - - def last_inserted_ids(self): - """Return ``last_inserted_ids()`` from the underlying ExecutionContext. - - See ExecutionContext for details. + + @util.memoized_property + def inserted_primary_key(self): + """Return the primary key for the row just inserted. + + This only applies to single row insert() constructs which + did not explicitly specify returning(). """ - return self.context.last_inserted_ids() + if not self.context.isinsert: + raise exc.InvalidRequestError("Statement is not an insert() expression construct.") + elif self.context._is_explicit_returning: + raise exc.InvalidRequestError("Can't call inserted_primary_key when returning() is used.") + + return self.context._inserted_primary_key + @util.deprecated("Use inserted_primary_key") + def last_inserted_ids(self): + """deprecated. use inserted_primary_key.""" + + return self.inserted_primary_key + def last_updated_params(self): """Return ``last_updated_params()`` from the underlying ExecutionContext. See ExecutionContext for details. - """ + return self.context.last_updated_params() def last_inserted_params(self): """Return ``last_inserted_params()`` from the underlying ExecutionContext. See ExecutionContext for details. - """ + return self.context.last_inserted_params() def lastrow_has_defaults(self): """Return ``lastrow_has_defaults()`` from the underlying ExecutionContext. See ExecutionContext for details. - """ + return self.context.lastrow_has_defaults() def postfetch_cols(self): """Return ``postfetch_cols()`` from the underlying ExecutionContext. See ExecutionContext for details. - """ + return self.context.postfetch_cols - + def prefetch_cols(self): return self.context.prefetch_cols - + def supports_sane_rowcount(self): """Return ``supports_sane_rowcount`` from the dialect.""" - + return self.dialect.supports_sane_rowcount def supports_sane_multi_rowcount(self): @@ -1643,7 +1890,12 @@ class ResultProxy(object): raise def fetchmany(self, size=None): - """Fetch many rows, just like DB-API ``cursor.fetchmany(size=cursor.arraysize)``.""" + """Fetch many rows, just like DB-API ``cursor.fetchmany(size=cursor.arraysize)``. + + If rows are present, the cursor remains open after this is called. + Else the cursor is automatically closed and an empty list is returned. + + """ try: process_row = self._process_row @@ -1656,7 +1908,13 @@ class ResultProxy(object): raise def fetchone(self): - """Fetch one row, just like DB-API ``cursor.fetchone()``.""" + """Fetch one row, just like DB-API ``cursor.fetchone()``. + + If a row is present, the cursor remains open after this is called. + Else the cursor is automatically closed and None is returned. + + """ + try: row = self._fetchone_impl() if row is not None: @@ -1668,21 +1926,38 @@ class ResultProxy(object): self.connection._handle_dbapi_exception(e, None, None, self.cursor, self.context) raise - def scalar(self): - """Fetch the first column of the first row, and close the result set.""" + def first(self): + """Fetch the first row and then close the result set unconditionally. + + Returns None if no row is present. + + """ try: row = self._fetchone_impl() except Exception, e: self.connection._handle_dbapi_exception(e, None, None, self.cursor, self.context) raise - + try: if row is not None: - return self._process_row(self, row)[0] + return self._process_row(self, row) else: return None finally: self.close() + + + def scalar(self): + """Fetch the first column of the first row, and close the result set. + + Returns None if no row is present. + + """ + row = self.first() + if row is not None: + return row[0] + else: + return None class BufferedRowResultProxy(ResultProxy): """A ResultProxy with row buffering behavior. @@ -1697,7 +1972,6 @@ class BufferedRowResultProxy(ResultProxy): The pre-fetching behavior fetches only one row initially, and then grows its buffer size by a fixed amount with each successive need for additional rows up to a size of 100. - """ def _init_metadata(self): @@ -1740,7 +2014,44 @@ class BufferedRowResultProxy(ResultProxy): return result def _fetchall_impl(self): - return self.__rowbuffer + list(self.cursor.fetchall()) + ret = self.__rowbuffer + list(self.cursor.fetchall()) + self.__rowbuffer[:] = [] + return ret + +class FullyBufferedResultProxy(ResultProxy): + """A result proxy that buffers rows fully upon creation. + + Used for operations where a result is to be delivered + after the database conversation can not be continued, + such as MSSQL INSERT...OUTPUT after an autocommit. + + """ + def _init_metadata(self): + super(FullyBufferedResultProxy, self)._init_metadata() + self.__rowbuffer = self._buffer_rows() + + def _buffer_rows(self): + return self.cursor.fetchall() + + def _fetchone_impl(self): + if self.__rowbuffer: + return self.__rowbuffer.pop(0) + else: + return None + + def _fetchmany_impl(self, size=None): + result = [] + for x in range(0, size): + row = self._fetchone_impl() + if row is None: + break + result.append(row) + return result + + def _fetchall_impl(self): + ret = self.__rowbuffer + self.__rowbuffer = [] + return ret class BufferedColumnResultProxy(ResultProxy): """A ResultProxy with column buffering behavior. @@ -1791,28 +2102,6 @@ class BufferedColumnResultProxy(ResultProxy): return l -class SchemaIterator(schema.SchemaVisitor): - """A visitor that can gather text into a buffer and execute the contents of the buffer.""" - - def __init__(self, connection): - """Construct a new SchemaIterator.""" - - self.connection = connection - self.buffer = StringIO.StringIO() - - def append(self, s): - """Append content to the SchemaIterator's query buffer.""" - - self.buffer.write(s) - - def execute(self): - """Execute the contents of the SchemaIterator's buffer.""" - - try: - return self.connection.execute(self.buffer.getvalue()) - finally: - self.buffer.truncate(0) - class DefaultRunner(schema.SchemaVisitor): """A visitor which accepts ColumnDefault objects, produces the dialect-specific SQL corresponding to their execution, and @@ -1821,7 +2110,6 @@ class DefaultRunner(schema.SchemaVisitor): DefaultRunners are used internally by Engines and Dialects. Specific database modules should provide their own subclasses of DefaultRunner to allow database-specific behavior. - """ def __init__(self, context): @@ -1854,7 +2142,7 @@ class DefaultRunner(schema.SchemaVisitor): def execute_string(self, stmt, params=None): """execute a string statement, using the raw cursor, and return a scalar result.""" - + conn = self.context._connection if isinstance(stmt, unicode) and not self.dialect.supports_unicode_statements: stmt = stmt.encode(self.dialect.encoding) @@ -1883,8 +2171,8 @@ def connection_memoize(key): Only applicable to functions which take no arguments other than a connection. The memo will be stored in ``connection.info[key]``. - """ + @util.decorator def decorated(fn, self, connection): connection = connection.connect() diff --git a/lib/sqlalchemy/engine/ddl.py b/lib/sqlalchemy/engine/ddl.py new file mode 100644 index 0000000000..6e7253e9a7 --- /dev/null +++ b/lib/sqlalchemy/engine/ddl.py @@ -0,0 +1,128 @@ +# engine/ddl.py +# Copyright (C) 2009 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Routines to handle CREATE/DROP workflow.""" + +from sqlalchemy import engine, schema +from sqlalchemy.sql import util as sql_util + + +class DDLBase(schema.SchemaVisitor): + def __init__(self, connection): + self.connection = connection + +class SchemaGenerator(DDLBase): + def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs): + super(SchemaGenerator, self).__init__(connection, **kwargs) + self.checkfirst = checkfirst + self.tables = tables and set(tables) or None + self.preparer = dialect.identifier_preparer + self.dialect = dialect + + def _can_create(self, table): + self.dialect.validate_identifier(table.name) + if table.schema: + self.dialect.validate_identifier(table.schema) + return not self.checkfirst or not self.dialect.has_table(self.connection, table.name, schema=table.schema) + + def visit_metadata(self, metadata): + if self.tables: + tables = self.tables + else: + tables = metadata.tables.values() + collection = [t for t in sql_util.sort_tables(tables) if self._can_create(t)] + + for listener in metadata.ddl_listeners['before-create']: + listener('before-create', metadata, self.connection, tables=collection) + + for table in collection: + self.traverse_single(table) + + for listener in metadata.ddl_listeners['after-create']: + listener('after-create', metadata, self.connection, tables=collection) + + def visit_table(self, table): + for listener in table.ddl_listeners['before-create']: + listener('before-create', table, self.connection) + + for column in table.columns: + if column.default is not None: + self.traverse_single(column.default) + + self.connection.execute(schema.CreateTable(table)) + + if hasattr(table, 'indexes'): + for index in table.indexes: + self.traverse_single(index) + + for listener in table.ddl_listeners['after-create']: + listener('after-create', table, self.connection) + + def visit_sequence(self, sequence): + if self.dialect.supports_sequences: + if ((not self.dialect.sequences_optional or + not sequence.optional) and + (not self.checkfirst or + not self.dialect.has_sequence(self.connection, sequence.name))): + self.connection.execute(schema.CreateSequence(sequence)) + + def visit_index(self, index): + self.connection.execute(schema.CreateIndex(index)) + + +class SchemaDropper(DDLBase): + def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs): + super(SchemaDropper, self).__init__(connection, **kwargs) + self.checkfirst = checkfirst + self.tables = tables + self.preparer = dialect.identifier_preparer + self.dialect = dialect + + def visit_metadata(self, metadata): + if self.tables: + tables = self.tables + else: + tables = metadata.tables.values() + collection = [t for t in reversed(sql_util.sort_tables(tables)) if self._can_drop(t)] + + for listener in metadata.ddl_listeners['before-drop']: + listener('before-drop', metadata, self.connection, tables=collection) + + for table in collection: + self.traverse_single(table) + + for listener in metadata.ddl_listeners['after-drop']: + listener('after-drop', metadata, self.connection, tables=collection) + + def _can_drop(self, table): + self.dialect.validate_identifier(table.name) + if table.schema: + self.dialect.validate_identifier(table.schema) + return not self.checkfirst or self.dialect.has_table(self.connection, table.name, schema=table.schema) + + def visit_index(self, index): + self.connection.execute(schema.DropIndex(index)) + + def visit_table(self, table): + for listener in table.ddl_listeners['before-drop']: + listener('before-drop', table, self.connection) + + for column in table.columns: + if column.default is not None: + self.traverse_single(column.default) + + self.connection.execute(schema.DropTable(table)) + + for listener in table.ddl_listeners['after-drop']: + listener('after-drop', table, self.connection) + + def visit_sequence(self, sequence): + if self.dialect.supports_sequences: + if ((not self.dialect.sequences_optional or + not sequence.optional) and + (not self.checkfirst or + self.dialect.has_sequence(self.connection, sequence.name))): + self.connection.execute(schema.DropSequence(sequence)) diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 728b932a2e..935d1e087d 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -13,36 +13,59 @@ as the base class for their own corresponding classes. """ import re, random -from sqlalchemy.engine import base +from sqlalchemy.engine import base, reflection from sqlalchemy.sql import compiler, expression -from sqlalchemy import exc +from sqlalchemy import exc, types as sqltypes, util AUTOCOMMIT_REGEXP = re.compile(r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)', re.I | re.UNICODE) + class DefaultDialect(base.Dialect): """Default implementation of Dialect""" - name = 'default' - schemagenerator = compiler.SchemaGenerator - schemadropper = compiler.SchemaDropper - statement_compiler = compiler.DefaultCompiler + statement_compiler = compiler.SQLCompiler + ddl_compiler = compiler.DDLCompiler + type_compiler = compiler.GenericTypeCompiler preparer = compiler.IdentifierPreparer defaultrunner = base.DefaultRunner supports_alter = True + + supports_sequences = False + sequences_optional = False + preexecute_autoincrement_sequences = False + postfetch_lastrowid = True + implicit_returning = False + + # Py3K + #supports_unicode_statements = True + #supports_unicode_binds = True + # Py2K supports_unicode_statements = False + supports_unicode_binds = False + # end Py2K + + name = 'default' max_identifier_length = 9999 supports_sane_rowcount = True supports_sane_multi_rowcount = True - preexecute_pk_sequences = False - supports_pk_autoincrement = True dbapi_type_map = {} default_paramstyle = 'named' - supports_default_values = False + supports_default_values = False supports_empty_insert = True + + # indicates symbol names are + # UPPERCASEd if they are case insensitive + # within the database. + # if this is True, the methods normalize_name() + # and denormalize_name() must be provided. + requires_name_normalize = False + + reflection_options = () def __init__(self, convert_unicode=False, assert_unicode=False, - encoding='utf-8', paramstyle=None, dbapi=None, + encoding='utf-8', paramstyle=None, dbapi=None, + implicit_returning=None, label_length=None, **kwargs): self.convert_unicode = convert_unicode self.assert_unicode = assert_unicode @@ -56,28 +79,58 @@ class DefaultDialect(base.Dialect): self.paramstyle = self.dbapi.paramstyle else: self.paramstyle = self.default_paramstyle + if implicit_returning is not None: + self.implicit_returning = implicit_returning self.positional = self.paramstyle in ('qmark', 'format', 'numeric') self.identifier_preparer = self.preparer(self) + self.type_compiler = self.type_compiler(self) + if label_length and label_length > self.max_identifier_length: - raise exc.ArgumentError("Label length of %d is greater than this dialect's maximum identifier length of %d" % (label_length, self.max_identifier_length)) + raise exc.ArgumentError("Label length of %d is greater than this dialect's" + " maximum identifier length of %d" % + (label_length, self.max_identifier_length)) self.label_length = label_length - self.description_encoding = getattr(self, 'description_encoding', encoding) - def type_descriptor(self, typeobj): + if not hasattr(self, 'description_encoding'): + self.description_encoding = getattr(self, 'description_encoding', encoding) + + # Py3K + ## work around dialects that might change these values + #self.supports_unicode_statements = True + #self.supports_unicode_binds = True + + def initialize(self, connection): + if hasattr(self, '_get_server_version_info'): + self.server_version_info = self._get_server_version_info(connection) + if hasattr(self, '_get_default_schema_name'): + self.default_schema_name = self._get_default_schema_name(connection) + + @classmethod + def type_descriptor(cls, typeobj): """Provide a database-specific ``TypeEngine`` object, given the generic object which comes from the types module. - Subclasses will usually use the ``adapt_type()`` method in the - types module to make this job easy.""" + This method looks for a dictionary called + ``colspecs`` as a class or instance-level variable, + and passes on to ``types.adapt_type()``. - if type(typeobj) is type: - typeobj = typeobj() - return typeobj + """ + return sqltypes.adapt_type(typeobj, cls.colspecs) + + def reflecttable(self, connection, table, include_columns): + insp = reflection.Inspector.from_engine(connection) + return insp.reflecttable(table, include_columns) def validate_identifier(self, ident): if len(ident) > self.max_identifier_length: - raise exc.IdentifierError("Identifier '%s' exceeds maximum length of %d characters" % (ident, self.max_identifier_length)) - + raise exc.IdentifierError( + "Identifier '%s' exceeds maximum length of %d characters" % + (ident, self.max_identifier_length) + ) + + def connect(self, *cargs, **cparams): + return self.dbapi.connect(*cargs, **cparams) + def do_begin(self, connection): """Implementations might want to put logic here for turning autocommit on/off, etc. @@ -103,7 +156,8 @@ class DefaultDialect(base.Dialect): """Create a random two-phase transaction ID. This id will be passed to do_begin_twophase(), do_rollback_twophase(), - do_commit_twophase(). Its format is unspecified.""" + do_commit_twophase(). Its format is unspecified. + """ return "_sa_%032x" % random.randint(0, 2 ** 128) @@ -127,13 +181,30 @@ class DefaultDialect(base.Dialect): class DefaultExecutionContext(base.ExecutionContext): - def __init__(self, dialect, connection, compiled=None, statement=None, parameters=None): + + def __init__(self, dialect, connection, compiled_sql=None, compiled_ddl=None, statement=None, parameters=None): self.dialect = dialect self._connection = self.root_connection = connection - self.compiled = compiled self.engine = connection.engine - if compiled is not None: + if compiled_ddl is not None: + self.compiled = compiled = compiled_ddl + if not dialect.supports_unicode_statements: + self.statement = unicode(compiled).encode(self.dialect.encoding) + else: + self.statement = unicode(compiled) + self.isinsert = self.isupdate = self.isdelete = self.executemany = False + self.should_autocommit = True + self.result_map = None + self.cursor = self.create_cursor() + self.compiled_parameters = [] + if self.dialect.positional: + self.parameters = [()] + else: + self.parameters = [{}] + elif compiled_sql is not None: + self.compiled = compiled = compiled_sql + # compiled clauseelement. process bind params, process table defaults, # track collections used by ResultProxy to target and process results @@ -156,6 +227,7 @@ class DefaultExecutionContext(base.ExecutionContext): self.isinsert = compiled.isinsert self.isupdate = compiled.isupdate + self.isdelete = compiled.isdelete self.should_autocommit = compiled.statement._autocommit if isinstance(compiled.statement, expression._TextClause): self.should_autocommit = self.should_autocommit or self.should_autocommit_text(self.statement) @@ -173,31 +245,43 @@ class DefaultExecutionContext(base.ExecutionContext): self.parameters = self.__convert_compiled_params(self.compiled_parameters) elif statement is not None: - # plain text statement. - self.result_map = None + # plain text statement + self.result_map = self.compiled = None self.parameters = self.__encode_param_keys(parameters) self.executemany = len(parameters) > 1 if isinstance(statement, unicode) and not dialect.supports_unicode_statements: self.statement = statement.encode(self.dialect.encoding) else: self.statement = statement - self.isinsert = self.isupdate = False + self.isinsert = self.isupdate = self.isdelete = False self.cursor = self.create_cursor() self.should_autocommit = self.should_autocommit_text(statement) else: # no statement. used for standalone ColumnDefault execution. - self.statement = None - self.isinsert = self.isupdate = self.executemany = self.should_autocommit = False + self.statement = self.compiled = None + self.isinsert = self.isupdate = self.isdelete = self.executemany = self.should_autocommit = False self.cursor = self.create_cursor() - + + @util.memoized_property + def _is_explicit_returning(self): + return self.compiled and \ + getattr(self.compiled.statement, '_returning', False) + + @util.memoized_property + def _is_implicit_returning(self): + return self.compiled and \ + bool(self.compiled.returning) and \ + not self.compiled.statement._returning + @property def connection(self): return self._connection._branch() def __encode_param_keys(self, params): - """apply string encoding to the keys of dictionary-based bind parameters. + """Apply string encoding to the keys of dictionary-based bind parameters. - This is only used executing textual, non-compiled SQL expressions.""" + This is only used executing textual, non-compiled SQL expressions. + """ if self.dialect.positional or self.dialect.supports_unicode_statements: if params: @@ -216,7 +300,7 @@ class DefaultExecutionContext(base.ExecutionContext): return [proc(d) for d in params] or [{}] def __convert_compiled_params(self, compiled_parameters): - """convert the dictionary of bind parameter values into a dict or list + """Convert the dictionary of bind parameter values into a dict or list to be sent to the DBAPI's execute() or executemany() method. """ @@ -263,26 +347,69 @@ class DefaultExecutionContext(base.ExecutionContext): def post_exec(self): pass + def get_lastrowid(self): + """return self.cursor.lastrowid, or equivalent, after an INSERT. + + This may involve calling special cursor functions, + issuing a new SELECT on the cursor (or a new one), + or returning a stored value that was + calculated within post_exec(). + + This function will only be called for dialects + which support "implicit" primary key generation, + keep preexecute_autoincrement_sequences set to False, + and when no explicit id value was bound to the + statement. + + The function is called once, directly after + post_exec() and before the transaction is committed + or ResultProxy is generated. If the post_exec() + method assigns a value to `self._lastrowid`, the + value is used in place of calling get_lastrowid(). + + Note that this method is *not* equivalent to the + ``lastrowid`` method on ``ResultProxy``, which is a + direct proxy to the DBAPI ``lastrowid`` accessor + in all cases. + + """ + + return self.cursor.lastrowid + def handle_dbapi_exception(self, e): pass def get_result_proxy(self): return base.ResultProxy(self) + + @property + def rowcount(self): + return self.cursor.rowcount - def get_rowcount(self): - if hasattr(self, '_rowcount'): - return self._rowcount - else: - return self.cursor.rowcount - def supports_sane_rowcount(self): return self.dialect.supports_sane_rowcount def supports_sane_multi_rowcount(self): return self.dialect.supports_sane_multi_rowcount - - def last_inserted_ids(self): - return self._last_inserted_ids + + def post_insert(self): + if self.dialect.postfetch_lastrowid and \ + (not len(self._inserted_primary_key) or \ + None in self._inserted_primary_key): + + table = self.compiled.statement.table + lastrowid = self.get_lastrowid() + self._inserted_primary_key = [c is table._autoincrement_column and lastrowid or v + for c, v in zip(table.primary_key, self._inserted_primary_key) + ] + + def _fetch_implicit_returning(self, resultproxy): + table = self.compiled.statement.table + row = resultproxy.first() + + self._inserted_primary_key = [v is not None and v or row[c] + for c, v in zip(table.primary_key, self._inserted_primary_key) + ] def last_inserted_params(self): return self._last_inserted_params @@ -293,12 +420,15 @@ class DefaultExecutionContext(base.ExecutionContext): def lastrow_has_defaults(self): return hasattr(self, 'postfetch_cols') and len(self.postfetch_cols) - def set_input_sizes(self): + def set_input_sizes(self, translate=None, exclude_types=None): """Given a cursor and ClauseParameters, call the appropriate style of ``setinputsizes()`` on the cursor, using DB-API types from the bind parameter's ``TypeEngine`` objects. """ + if not hasattr(self.compiled, 'bind_names'): + return + types = dict( (self.compiled.bind_names[bindparam], bindparam.type) for bindparam in self.compiled.bind_names) @@ -308,7 +438,7 @@ class DefaultExecutionContext(base.ExecutionContext): for key in self.compiled.positiontup: typeengine = types[key] dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) - if dbtype is not None: + if dbtype is not None and (not exclude_types or dbtype not in exclude_types): inputsizes.append(dbtype) try: self.cursor.setinputsizes(*inputsizes) @@ -320,7 +450,9 @@ class DefaultExecutionContext(base.ExecutionContext): for key in self.compiled.bind_names.values(): typeengine = types[key] dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) - if dbtype is not None: + if dbtype is not None and (not exclude_types or dbtype not in exclude_types): + if translate: + key = translate.get(key, key) inputsizes[key.encode(self.dialect.encoding)] = dbtype try: self.cursor.setinputsizes(**inputsizes) @@ -329,8 +461,9 @@ class DefaultExecutionContext(base.ExecutionContext): raise def __process_defaults(self): - """generate default values for compiled insert/update statements, - and generate last_inserted_ids() collection.""" + """Generate default values for compiled insert/update statements, + and generate inserted_primary_key collection. + """ if self.executemany: if len(self.compiled.prefetch): @@ -364,7 +497,8 @@ class DefaultExecutionContext(base.ExecutionContext): compiled_parameters[c.key] = val if self.isinsert: - self._last_inserted_ids = [compiled_parameters.get(c.key, None) for c in self.compiled.statement.table.primary_key] + self._inserted_primary_key = [compiled_parameters.get(c.key, None) + for c in self.compiled.statement.table.primary_key] self._last_inserted_params = compiled_parameters else: self._last_updated_params = compiled_parameters diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py new file mode 100644 index 0000000000..173e0fab00 --- /dev/null +++ b/lib/sqlalchemy/engine/reflection.py @@ -0,0 +1,361 @@ +"""Provides an abstraction for obtaining database schema information. + +Usage Notes: + +Here are some general conventions when accessing the low level inspector +methods such as get_table_names, get_columns, etc. + +1. Inspector methods return lists of dicts in most cases for the following + reasons: + + * They're both standard types that can be serialized. + * Using a dict instead of a tuple allows easy expansion of attributes. + * Using a list for the outer structure maintains order and is easy to work + with (e.g. list comprehension [d['name'] for d in cols]). + +2. Records that contain a name, such as the column name in a column record + use the key 'name'. So for most return values, each record will have a + 'name' attribute.. +""" + +import sqlalchemy +from sqlalchemy import exc, sql +from sqlalchemy import util +from sqlalchemy.types import TypeEngine +from sqlalchemy import schema as sa_schema + + +@util.decorator +def cache(fn, self, con, *args, **kw): + info_cache = kw.get('info_cache', None) + if info_cache is None: + return fn(self, con, *args, **kw) + key = ( + fn.__name__, + tuple(a for a in args if isinstance(a, basestring)), + tuple((k, v) for k, v in kw.iteritems() if isinstance(v, basestring)) + ) + ret = info_cache.get(key) + if ret is None: + ret = fn(self, con, *args, **kw) + info_cache[key] = ret + return ret + + +class Inspector(object): + """Performs database schema inspection. + + The Inspector acts as a proxy to the dialects' reflection methods and + provides higher level functions for accessing database schema information. + """ + + def __init__(self, conn): + """Initialize the instance. + + :param conn: a :class:`~sqlalchemy.engine.base.Connectable` + """ + + self.conn = conn + # set the engine + if hasattr(conn, 'engine'): + self.engine = conn.engine + else: + self.engine = conn + self.dialect = self.engine.dialect + self.info_cache = {} + + @classmethod + def from_engine(cls, engine): + if hasattr(engine.dialect, 'inspector'): + return engine.dialect.inspector(engine) + return Inspector(engine) + + @property + def default_schema_name(self): + return self.dialect.get_default_schema_name(self.conn) + + def get_schema_names(self): + """Return all schema names. + """ + + if hasattr(self.dialect, 'get_schema_names'): + return self.dialect.get_schema_names(self.conn, + info_cache=self.info_cache) + return [] + + def get_table_names(self, schema=None, order_by=None): + """Return all table names in `schema`. + + :param schema: Optional, retrieve names from a non-default schema. + :param order_by: Optional, may be the string "foreign_key" to sort + the result on foreign key dependencies. + + This should probably not return view names or maybe it should return + them with an indicator t or v. + """ + + if hasattr(self.dialect, 'get_table_names'): + tnames = self.dialect.get_table_names(self.conn, + schema, + info_cache=self.info_cache) + else: + tnames = self.engine.table_names(schema) + if order_by == 'foreign_key': + ordered_tnames = tnames[:] + # Order based on foreign key dependencies. + for tname in tnames: + table_pos = tnames.index(tname) + fkeys = self.get_foreign_keys(tname, schema) + for fkey in fkeys: + rtable = fkey['referred_table'] + if rtable in ordered_tnames: + ref_pos = ordered_tnames.index(rtable) + # Make sure it's lower in the list than anything it + # references. + if table_pos > ref_pos: + ordered_tnames.pop(table_pos) # rtable moves up 1 + # insert just below rtable + ordered_tnames.index(ref_pos, tname) + tnames = ordered_tnames + return tnames + + def get_table_options(self, table_name, schema=None, **kw): + if hasattr(self.dialect, 'get_table_options'): + return self.dialect.get_table_options(self.conn, table_name, schema, + info_cache=self.info_cache, + **kw) + return {} + + def get_view_names(self, schema=None): + """Return all view names in `schema`. + + :param schema: Optional, retrieve names from a non-default schema. + """ + + return self.dialect.get_view_names(self.conn, schema, + info_cache=self.info_cache) + + def get_view_definition(self, view_name, schema=None): + """Return definition for `view_name`. + + :param schema: Optional, retrieve names from a non-default schema. + """ + + return self.dialect.get_view_definition( + self.conn, view_name, schema, info_cache=self.info_cache) + + def get_columns(self, table_name, schema=None, **kw): + """Return information about columns in `table_name`. + + Given a string `table_name` and an optional string `schema`, return + column information as a list of dicts with these keys: + + name + the column's name + + type + :class:`~sqlalchemy.types.TypeEngine` + + nullable + boolean + + default + the column's default value + + attrs + dict containing optional column attributes + """ + + col_defs = self.dialect.get_columns(self.conn, table_name, schema, + info_cache=self.info_cache, + **kw) + for col_def in col_defs: + # make this easy and only return instances for coltype + coltype = col_def['type'] + if not isinstance(coltype, TypeEngine): + col_def['type'] = coltype() + return col_defs + + def get_primary_keys(self, table_name, schema=None, **kw): + """Return information about primary keys in `table_name`. + + Given a string `table_name`, and an optional string `schema`, return + primary key information as a list of column names. + """ + + pkeys = self.dialect.get_primary_keys(self.conn, table_name, schema, + info_cache=self.info_cache, + **kw) + + return pkeys + + def get_foreign_keys(self, table_name, schema=None, **kw): + """Return information about foreign_keys in `table_name`. + + Given a string `table_name`, and an optional string `schema`, return + foreign key information as a list of dicts with these keys: + + constrained_columns + a list of column names that make up the foreign key + + referred_schema + the name of the referred schema + + referred_table + the name of the referred table + + referred_columns + a list of column names in the referred table that correspond to + constrained_columns + """ + + fk_defs = self.dialect.get_foreign_keys(self.conn, table_name, schema, + info_cache=self.info_cache, + **kw) + return fk_defs + + def get_indexes(self, table_name, schema=None): + """Return information about indexes in `table_name`. + + Given a string `table_name` and an optional string `schema`, return + index information as a list of dicts with these keys: + + name + the index's name + + column_names + list of column names in order + + unique + boolean + """ + + indexes = self.dialect.get_indexes(self.conn, table_name, + schema, + info_cache=self.info_cache) + return indexes + + def reflecttable(self, table, include_columns): + + dialect = self.conn.dialect + + # MySQL dialect does this. Applicable with other dialects? + if hasattr(dialect, '_connection_charset') \ + and hasattr(dialect, '_adjust_casing'): + charset = dialect._connection_charset + dialect._adjust_casing(table) + + # table attributes we might need. + reflection_options = dict( + (k, table.kwargs.get(k)) for k in dialect.reflection_options if k in table.kwargs) + + schema = table.schema + table_name = table.name + + # apply table options + tbl_opts = self.get_table_options(table_name, schema, **table.kwargs) + if tbl_opts: + table.kwargs.update(tbl_opts) + + # table.kwargs will need to be passed to each reflection method. Make + # sure keywords are strings. + tblkw = table.kwargs.copy() + for (k, v) in tblkw.items(): + del tblkw[k] + tblkw[str(k)] = v + + # Py2K + if isinstance(schema, str): + schema = schema.decode(dialect.encoding) + if isinstance(table_name, str): + table_name = table_name.decode(dialect.encoding) + # end Py2K + + # columns + found_table = False + for col_d in self.get_columns(table_name, schema, **tblkw): + found_table = True + name = col_d['name'] + if include_columns and name not in include_columns: + continue + + coltype = col_d['type'] + col_kw = { + 'nullable':col_d['nullable'], + } + if 'autoincrement' in col_d: + col_kw['autoincrement'] = col_d['autoincrement'] + + colargs = [] + if col_d.get('default') is not None: + # the "default" value is assumed to be a literal SQL expression, + # so is wrapped in text() so that no quoting occurs on re-issuance. + colargs.append(sa_schema.DefaultClause(sql.text(col_d['default']))) + + if 'sequence' in col_d: + # TODO: whos using this ? + seq = col_d['sequence'] + sequence = sa_schema.Sequence(seq['name'], 1, 1) + if 'start' in seq: + sequence.start = seq['start'] + if 'increment' in seq: + sequence.increment = seq['increment'] + colargs.append(sequence) + + col = sa_schema.Column(name, coltype, *colargs, **col_kw) + table.append_column(col) + + if not found_table: + raise exc.NoSuchTableError(table.name) + + # Primary keys + primary_key_constraint = sa_schema.PrimaryKeyConstraint(*[ + table.c[pk] for pk in self.get_primary_keys(table_name, schema, **tblkw) + if pk in table.c + ]) + + table.append_constraint(primary_key_constraint) + + # Foreign keys + fkeys = self.get_foreign_keys(table_name, schema, **tblkw) + for fkey_d in fkeys: + conname = fkey_d['name'] + constrained_columns = fkey_d['constrained_columns'] + referred_schema = fkey_d['referred_schema'] + referred_table = fkey_d['referred_table'] + referred_columns = fkey_d['referred_columns'] + refspec = [] + if referred_schema is not None: + sa_schema.Table(referred_table, table.metadata, + autoload=True, schema=referred_schema, + autoload_with=self.conn, + **reflection_options + ) + for column in referred_columns: + refspec.append(".".join( + [referred_schema, referred_table, column])) + else: + sa_schema.Table(referred_table, table.metadata, autoload=True, + autoload_with=self.conn, + **reflection_options + ) + for column in referred_columns: + refspec.append(".".join([referred_table, column])) + table.append_constraint( + sa_schema.ForeignKeyConstraint(constrained_columns, refspec, + conname, link_to_name=True)) + # Indexes + indexes = self.get_indexes(table_name, schema) + for index_d in indexes: + name = index_d['name'] + columns = index_d['column_names'] + unique = index_d['unique'] + flavor = index_d.get('type', 'unknown type') + if include_columns and \ + not set(columns).issubset(include_columns): + util.warn( + "Omitting %s KEY for (%s), key covers omitted columns." % + (flavor, ', '.join(columns))) + continue + sa_schema.Index(name, *[table.columns[c] for c in columns], + **dict(unique=unique)) diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py index fa608df65e..ff62b265ba 100644 --- a/lib/sqlalchemy/engine/strategies.py +++ b/lib/sqlalchemy/engine/strategies.py @@ -6,31 +6,26 @@ underlying behavior for the "strategy" keyword argument available on ``plain``, ``threadlocal``, and ``mock``. New strategies can be added via new ``EngineStrategy`` classes. - """ + from operator import attrgetter from sqlalchemy.engine import base, threadlocal, url from sqlalchemy import util, exc from sqlalchemy import pool as poollib - strategies = {} + class EngineStrategy(object): """An adaptor that processes input arguements and produces an Engine. Provides a ``create`` method that receives input arguments and produces an instance of base.Engine or a subclass. + """ - def __init__(self, name): - """Construct a new EngineStrategy object. - - Sets it in the list of available strategies under this name. - """ - - self.name = name + def __init__(self): strategies[self.name] = self def create(self, *args, **kwargs): @@ -38,9 +33,12 @@ class EngineStrategy(object): raise NotImplementedError() + class DefaultEngineStrategy(EngineStrategy): """Base class for built-in stratgies.""" + pool_threadlocal = False + def create(self, name_or_url, **kwargs): # create url.URL object u = url.make_url(name_or_url) @@ -75,9 +73,15 @@ class DefaultEngineStrategy(EngineStrategy): if pool is None: def connect(): try: - return dbapi.connect(*cargs, **cparams) + return dialect.connect(*cargs, **cparams) except Exception, e: - raise exc.DBAPIError.instance(None, None, e) + # Py3K + #raise exc.DBAPIError.instance(None, None, e) from e + # Py2K + import sys + raise exc.DBAPIError.instance(None, None, e), None, sys.exc_info()[2] + # end Py2K + creator = kwargs.pop('creator', connect) poolclass = (kwargs.pop('poolclass', None) or @@ -94,7 +98,7 @@ class DefaultEngineStrategy(EngineStrategy): tk = translate.get(k, k) if tk in kwargs: pool_args[k] = kwargs.pop(tk) - pool_args.setdefault('use_threadlocal', self.pool_threadlocal()) + pool_args.setdefault('use_threadlocal', self.pool_threadlocal) pool = poolclass(creator, **pool_args) else: if isinstance(pool, poollib._DBProxy): @@ -103,12 +107,14 @@ class DefaultEngineStrategy(EngineStrategy): pool = pool # create engine. - engineclass = self.get_engine_cls() + engineclass = self.engine_cls engine_args = {} for k in util.get_cls_kwargs(engineclass): if k in kwargs: engine_args[k] = kwargs.pop(k) + _initialize = kwargs.pop('_initialize', True) + # all kwargs should be consumed if kwargs: raise TypeError( @@ -119,39 +125,38 @@ class DefaultEngineStrategy(EngineStrategy): dialect.__class__.__name__, pool.__class__.__name__, engineclass.__name__)) - return engineclass(pool, dialect, u, **engine_args) + + engine = engineclass(pool, dialect, u, **engine_args) - def pool_threadlocal(self): - raise NotImplementedError() + if _initialize: + # some unit tests pass through _initialize=False + # to help mock engines work + class OnInit(object): + def first_connect(self, conn, rec): + c = base.Connection(engine, connection=conn) + dialect.initialize(c) + pool._on_first_connect.insert(0, OnInit()) - def get_engine_cls(self): - raise NotImplementedError() + dialect.visit_pool(pool) -class PlainEngineStrategy(DefaultEngineStrategy): - """Strategy for configuring a regular Engine.""" + return engine - def __init__(self): - DefaultEngineStrategy.__init__(self, 'plain') - - def pool_threadlocal(self): - return False - def get_engine_cls(self): - return base.Engine +class PlainEngineStrategy(DefaultEngineStrategy): + """Strategy for configuring a regular Engine.""" + name = 'plain' + engine_cls = base.Engine + PlainEngineStrategy() + class ThreadLocalEngineStrategy(DefaultEngineStrategy): """Strategy for configuring an Engine with thredlocal behavior.""" - - def __init__(self): - DefaultEngineStrategy.__init__(self, 'threadlocal') - - def pool_threadlocal(self): - return True - - def get_engine_cls(self): - return threadlocal.TLEngine + + name = 'threadlocal' + pool_threadlocal = True + engine_cls = threadlocal.TLEngine ThreadLocalEngineStrategy() @@ -161,11 +166,11 @@ class MockEngineStrategy(EngineStrategy): Produces a single mock Connectable object which dispatches statement execution to a passed-in function. + """ - def __init__(self): - EngineStrategy.__init__(self, 'mock') - + name = 'mock' + def create(self, name_or_url, executor, **kwargs): # create url.URL object u = url.make_url(name_or_url) @@ -201,11 +206,14 @@ class MockEngineStrategy(EngineStrategy): def create(self, entity, **kwargs): kwargs['checkfirst'] = False - self.dialect.schemagenerator(self.dialect, self, **kwargs).traverse(entity) + from sqlalchemy.engine import ddl + + ddl.SchemaGenerator(self.dialect, self, **kwargs).traverse(entity) def drop(self, entity, **kwargs): kwargs['checkfirst'] = False - self.dialect.schemadropper(self.dialect, self, **kwargs).traverse(entity) + from sqlalchemy.engine import ddl + ddl.SchemaDropper(self.dialect, self, **kwargs).traverse(entity) def execute(self, object, *multiparams, **params): raise NotImplementedError() diff --git a/lib/sqlalchemy/engine/threadlocal.py b/lib/sqlalchemy/engine/threadlocal.py index 8ad14ad35f..27d857623e 100644 --- a/lib/sqlalchemy/engine/threadlocal.py +++ b/lib/sqlalchemy/engine/threadlocal.py @@ -8,6 +8,7 @@ invoked automatically when the threadlocal engine strategy is used. from sqlalchemy import util from sqlalchemy.engine import base + class TLSession(object): def __init__(self, engine): self.engine = engine @@ -17,7 +18,8 @@ class TLSession(object): try: return self.__transaction._increment_connect() except AttributeError: - return self.engine.TLConnection(self, self.engine.pool.connect(), close_with_result=close_with_result) + return self.engine.TLConnection(self, self.engine.pool.connect(), + close_with_result=close_with_result) def reset(self): try: diff --git a/lib/sqlalchemy/engine/url.py b/lib/sqlalchemy/engine/url.py index 5c8e68ce45..b0e21f5f72 100644 --- a/lib/sqlalchemy/engine/url.py +++ b/lib/sqlalchemy/engine/url.py @@ -20,9 +20,9 @@ class URL(object): format of the URL is an RFC-1738-style string. All initialization parameters are available as public attributes. - - :param drivername: the name of the database backend. - This name will correspond to a module in sqlalchemy/databases + + :param drivername: the name of the database backend. + This name will correspond to a module in sqlalchemy/databases or a third party plug-in. :param username: The user name. @@ -35,12 +35,13 @@ class URL(object): :param database: The database name. - :param query: A dictionary of options to be passed to the + :param query: A dictionary of options to be passed to the dialect and/or the DBAPI upon connect. - + """ - def __init__(self, drivername, username=None, password=None, host=None, port=None, database=None, query=None): + def __init__(self, drivername, username=None, password=None, + host=None, port=None, database=None, query=None): self.drivername = drivername self.username = username self.password = password @@ -70,10 +71,10 @@ class URL(object): keys.sort() s += '?' + "&".join("%s=%s" % (k, self.query[k]) for k in keys) return s - + def __hash__(self): return hash(str(self)) - + def __eq__(self, other): return \ isinstance(other, URL) and \ @@ -83,12 +84,22 @@ class URL(object): self.host == other.host and \ self.database == other.database and \ self.query == other.query - + def get_dialect(self): - """Return the SQLAlchemy database dialect class corresponding to this URL's driver name.""" - + """Return the SQLAlchemy database dialect class corresponding + to this URL's driver name. + """ + try: - module = getattr(__import__('sqlalchemy.databases.%s' % self.drivername).databases, self.drivername) + if '+' in self.drivername: + dialect, driver = self.drivername.split('+') + else: + dialect, driver = self.drivername, 'base' + + module = __import__('sqlalchemy.dialects.%s' % (dialect, )).dialects + module = getattr(module, dialect) + module = getattr(module, driver) + return module.dialect except ImportError: if sys.exc_info()[2].tb_next is None: @@ -97,7 +108,7 @@ class URL(object): if res.name == self.drivername: return res.load() raise - + def translate_connect_args(self, names=[], **kw): """Translate url attributes into a dictionary of connection arguments. @@ -107,10 +118,9 @@ class URL(object): from the final dictionary. :param \**kw: Optional, alternate key names for url attributes. - + :param names: Deprecated. Same purpose as the keyword-based alternate names, but correlates the name to the original positionally. - """ translated = {} @@ -131,8 +141,8 @@ def make_url(name_or_url): The given string is parsed according to the RFC 1738 spec. If an existing URL object is passed, just returns the object. - """ + if isinstance(name_or_url, basestring): return _parse_rfc1738_args(name_or_url) else: @@ -140,7 +150,7 @@ def make_url(name_or_url): def _parse_rfc1738_args(name): pattern = re.compile(r''' - (?P\w+):// + (?P[\w\+]+):// (?: (?P[^:/]*) (?::(?P[^/]*))? @@ -160,8 +170,10 @@ def _parse_rfc1738_args(name): tokens = components['database'].split('?', 2) components['database'] = tokens[0] query = (len(tokens) > 1 and dict(cgi.parse_qsl(tokens[1]))) or None + # Py2K if query is not None: query = dict((k.encode('ascii'), query[k]) for k in query) + # end Py2K else: query = None components['query'] = query diff --git a/lib/sqlalchemy/exc.py b/lib/sqlalchemy/exc.py index ce130ce3c2..f1678743d9 100644 --- a/lib/sqlalchemy/exc.py +++ b/lib/sqlalchemy/exc.py @@ -103,6 +103,7 @@ class DBAPIError(SQLAlchemyError): """ + @classmethod def instance(cls, statement, params, orig, connection_invalidated=False): # Don't ever wrap these, just return them directly as if # DBAPIError didn't exist. @@ -115,7 +116,6 @@ class DBAPIError(SQLAlchemyError): cls = glob[name] return cls(statement, params, orig, connection_invalidated) - instance = classmethod(instance) def __init__(self, statement, params, orig, connection_invalidated=False): try: diff --git a/lib/sqlalchemy/ext/compiler.py b/lib/sqlalchemy/ext/compiler.py index 0e3db00e02..05df8d2be6 100644 --- a/lib/sqlalchemy/ext/compiler.py +++ b/lib/sqlalchemy/ext/compiler.py @@ -33,7 +33,7 @@ Produces:: Compilers can also be made dialect-specific. The appropriate compiler will be invoked for the dialect in use:: - from sqlalchemy.schema import DDLElement # this is a SQLA 0.6 construct + from sqlalchemy.schema import DDLElement class AlterColumn(DDLElement): @@ -45,16 +45,16 @@ for the dialect in use:: def visit_alter_column(element, compiler, **kw): return "ALTER COLUMN %s ..." % element.column.name - @compiles(AlterColumn, 'postgres') + @compiles(AlterColumn, 'postgresql') def visit_alter_column(element, compiler, **kw): return "ALTER TABLE %s ALTER COLUMN %s ..." % (element.table.name, element.column.name) -The second ``visit_alter_table`` will be invoked when any ``postgres`` dialect is used. +The second ``visit_alter_table`` will be invoked when any ``postgresql`` dialect is used. The ``compiler`` argument is the :class:`~sqlalchemy.engine.base.Compiled` object in use. This object can be inspected for any information about the in-progress compilation, including ``compiler.dialect``, ``compiler.statement`` etc. -The :class:`~sqlalchemy.sql.compiler.SQLCompiler` and :class:`~sqlalchemy.sql.compiler.DDLCompiler` (DDLCompiler is 0.6. only) +The :class:`~sqlalchemy.sql.compiler.SQLCompiler` and :class:`~sqlalchemy.sql.compiler.DDLCompiler` both include a ``process()`` method which can be used for compilation of embedded attributes:: class InsertFromSelect(ClauseElement): diff --git a/lib/sqlalchemy/ext/declarative.py b/lib/sqlalchemy/ext/declarative.py index 07974cacce..c37211ac3d 100644 --- a/lib/sqlalchemy/ext/declarative.py +++ b/lib/sqlalchemy/ext/declarative.py @@ -396,7 +396,7 @@ only intended as an optional syntax for the regular usage of mappers and Table objects. A typical application setup using :func:`~sqlalchemy.orm.scoped_session` might look like:: - engine = create_engine('postgres://scott:tiger@localhost/test') + engine = create_engine('postgresql://scott:tiger@localhost/test') Session = scoped_session(sessionmaker(autocommit=False, autoflush=False, bind=engine)) diff --git a/lib/sqlalchemy/ext/orderinglist.py b/lib/sqlalchemy/ext/orderinglist.py index a5d60bf82e..8e63ed1c29 100644 --- a/lib/sqlalchemy/ext/orderinglist.py +++ b/lib/sqlalchemy/ext/orderinglist.py @@ -240,7 +240,8 @@ class OrderingList(list): _raw_append = collection.adds(1)(_raw_append) def insert(self, index, entity): - self[index:index] = [entity] + super(OrderingList, self).insert(index, entity) + self._reorder() def remove(self, entity): super(OrderingList, self).remove(entity) @@ -253,7 +254,15 @@ class OrderingList(list): def __setitem__(self, index, entity): if isinstance(index, slice): - for i in range(index.start or 0, index.stop or 0, index.step or 1): + step = index.step or 1 + start = index.start or 0 + if start < 0: + start += len(self) + stop = index.stop or len(self) + if stop < 0: + stop += len(self) + + for i in xrange(start, stop, step): self.__setitem__(i, entity[i]) else: self._order_entity(index, entity, True) @@ -263,6 +272,7 @@ class OrderingList(list): super(OrderingList, self).__delitem__(index) self._reorder() + # Py2K def __setslice__(self, start, end, values): super(OrderingList, self).__setslice__(start, end, values) self._reorder() @@ -270,7 +280,8 @@ class OrderingList(list): def __delslice__(self, start, end): super(OrderingList, self).__delslice__(start, end) self._reorder() - + # end Py2K + for func_name, func in locals().items(): if (util.callable(func) and func.func_name == func_name and not func.__doc__ and hasattr(list, func_name)): diff --git a/lib/sqlalchemy/ext/serializer.py b/lib/sqlalchemy/ext/serializer.py index b62ee0ce64..fd456e385f 100644 --- a/lib/sqlalchemy/ext/serializer.py +++ b/lib/sqlalchemy/ext/serializer.py @@ -43,10 +43,26 @@ from sqlalchemy.engine import Engine from sqlalchemy.util import pickle import re import base64 -from cStringIO import StringIO +# Py3K +#from io import BytesIO as byte_buffer +# Py2K +from cStringIO import StringIO as byte_buffer +# end Py2K + +# Py3K +#def b64encode(x): +# return base64.b64encode(x).decode('ascii') +#def b64decode(x): +# return base64.b64decode(x.encode('ascii')) +# Py2K +b64encode = base64.b64encode +b64decode = base64.b64decode +# end Py2K __all__ = ['Serializer', 'Deserializer', 'dumps', 'loads'] + + def Serializer(*args, **kw): pickler = pickle.Pickler(*args, **kw) @@ -55,9 +71,9 @@ def Serializer(*args, **kw): if isinstance(obj, QueryableAttribute): cls = obj.impl.class_ key = obj.impl.key - id = "attribute:" + key + ":" + base64.b64encode(pickle.dumps(cls)) + id = "attribute:" + key + ":" + b64encode(pickle.dumps(cls)) elif isinstance(obj, Mapper) and not obj.non_primary: - id = "mapper:" + base64.b64encode(pickle.dumps(obj.class_)) + id = "mapper:" + b64encode(pickle.dumps(obj.class_)) elif isinstance(obj, Table): id = "table:" + str(obj) elif isinstance(obj, Column) and isinstance(obj.table, Table): @@ -96,10 +112,10 @@ def Deserializer(file, metadata=None, scoped_session=None, engine=None): type_, args = m.group(1, 2) if type_ == 'attribute': key, clsarg = args.split(":") - cls = pickle.loads(base64.b64decode(clsarg)) + cls = pickle.loads(b64decode(clsarg)) return getattr(cls, key) elif type_ == "mapper": - cls = pickle.loads(base64.b64decode(args)) + cls = pickle.loads(b64decode(args)) return class_mapper(cls) elif type_ == "table": return metadata.tables[args] @@ -116,13 +132,13 @@ def Deserializer(file, metadata=None, scoped_session=None, engine=None): return unpickler def dumps(obj): - buf = StringIO() + buf = byte_buffer() pickler = Serializer(buf) pickler.dump(obj) return buf.getvalue() def loads(data, metadata=None, scoped_session=None, engine=None): - buf = StringIO(data) + buf = byte_buffer(data) unpickler = Deserializer(buf, metadata, scoped_session, engine) return unpickler.load() diff --git a/lib/sqlalchemy/ext/sqlsoup.py b/lib/sqlalchemy/ext/sqlsoup.py index b3f2de743e..6eef4657c3 100644 --- a/lib/sqlalchemy/ext/sqlsoup.py +++ b/lib/sqlalchemy/ext/sqlsoup.py @@ -86,7 +86,7 @@ Full query documentation Get, filter, filter_by, order_by, limit, and the rest of the query methods are explained in detail in the `SQLAlchemy documentation`__. -__ http://www.sqlalchemy.org/docs/04/ormtutorial.html#datamapping_querying +__ http://www.sqlalchemy.org/docs/05/ormtutorial.html#datamapping_querying Modifying objects @@ -447,9 +447,11 @@ def _selectable_name(selectable): def class_for_table(selectable, **mapper_kwargs): selectable = expression._clause_element_as_expr(selectable) mapname = 'Mapped' + _selectable_name(selectable) + # Py2K if isinstance(mapname, unicode): engine_encoding = selectable.metadata.bind.dialect.encoding mapname = mapname.encode(engine_encoding) + # end Py2K if isinstance(selectable, Table): klass = TableClassType(mapname, (object,), {}) else: @@ -543,10 +545,14 @@ class SqlSoup: def entity(self, attr, schema=None): try: t = self._cache[attr] - except KeyError: + except KeyError, ke: table = Table(attr, self._metadata, autoload=True, schema=schema or self.schema) if not table.primary_key.columns: + # Py3K + #raise PKNotFoundError('table %r does not have a primary key defined [columns: %s]' % (attr, ','.join(table.c.keys()))) from ke + # Py2K raise PKNotFoundError('table %r does not have a primary key defined [columns: %s]' % (attr, ','.join(table.c.keys()))) + # end Py2K if table.columns: t = class_for_table(table) else: diff --git a/lib/sqlalchemy/interfaces.py b/lib/sqlalchemy/interfaces.py index dfceffe445..e4a9adee1f 100644 --- a/lib/sqlalchemy/interfaces.py +++ b/lib/sqlalchemy/interfaces.py @@ -71,6 +71,18 @@ class PoolListener(object): """ + def first_connect(self, dbapi_con, con_record): + """Called exactly once for the first DB-API connection. + + dbapi_con + A newly connected raw DB-API connection (not a SQLAlchemy + ``Connection`` wrapper). + + con_record + The ``_ConnectionRecord`` that persistently manages the connection + + """ + def checkout(self, dbapi_con, con_record, con_proxy): """Called when a connection is retrieved from the Pool. diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index 2a20b05ef8..3c39316da3 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -849,8 +849,13 @@ def clear_mappers(): """ mapperlib._COMPILE_MUTEX.acquire() try: - for mapper in list(_mapper_registry): - mapper.dispose() + while _mapper_registry: + try: + # can't even reliably call list(weakdict) in jython + mapper, b = _mapper_registry.popitem() + mapper.dispose() + except KeyError: + pass finally: mapperlib._COMPILE_MUTEX.release() diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 46e9b00de2..f6947dbc11 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -159,7 +159,7 @@ class InstrumentedAttribute(QueryableAttribute): class _ProxyImpl(object): accepts_scalar_loader = False - dont_expire_missing = False + expire_missing = True def __init__(self, key): self.key = key @@ -231,7 +231,7 @@ class AttributeImpl(object): def __init__(self, class_, key, callable_, trackparent=False, extension=None, compare_function=None, active_history=False, parent_token=None, - dont_expire_missing=False, + expire_missing=True, **kwargs): """Construct an AttributeImpl. @@ -269,8 +269,8 @@ class AttributeImpl(object): Allows multiple AttributeImpls to all match a single owner attribute. - dont_expire_missing - if True, don't add an "expiry" callable to this attribute + expire_missing + if False, don't add an "expiry" callable to this attribute during state.expire_attributes(None), if no value is present for this key. @@ -290,7 +290,7 @@ class AttributeImpl(object): active_history = True break self.active_history = active_history - self.dont_expire_missing = dont_expire_missing + self.expire_missing = expire_missing def hasparent(self, state, optimistic=False): """Return the boolean value of a `hasparent` flag attached to the given item. @@ -991,9 +991,8 @@ class ClassManager(dict): self.local_attrs[key] = inst self.install_descriptor(key, inst) self[key] = inst + for cls in self.class_.__subclasses__(): - if isinstance(cls, types.ClassType): - continue manager = self._subclass_manager(cls) manager.instrument_attribute(key, inst, True) @@ -1013,8 +1012,6 @@ class ClassManager(dict): if key in self.mutable_attributes: self.mutable_attributes.remove(key) for cls in self.class_.__subclasses__(): - if isinstance(cls, types.ClassType): - continue manager = self._subclass_manager(cls) manager.uninstrument_attribute(key, True) @@ -1646,8 +1643,12 @@ def __init__(%(apply_pos)s): func_vars = util.format_argspec_init(original__init__, grouped=False) func_text = func_body % func_vars + # Py3K + #func_defaults = getattr(original__init__, '__defaults__', None) + # Py2K func = getattr(original__init__, 'im_func', original__init__) func_defaults = getattr(func, 'func_defaults', None) + # end Py2K env = locals().copy() exec func_text in env diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index 4ca4c5719e..6a77018468 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -529,7 +529,11 @@ class CollectionAdapter(object): if getattr(obj, '_sa_adapter', None) is not None: return getattr(obj, '_sa_adapter') elif setting_type == dict: + # Py3K + #return obj.values() + # Py2K return getattr(obj, 'itervalues', getattr(obj, 'values'))() + # end Py2K else: return iter(obj) @@ -561,7 +565,9 @@ class CollectionAdapter(object): def __iter__(self): """Iterate over entities in the collection.""" - return getattr(self._data(), '_sa_iterator')() + + # Py3K requires iter() here + return iter(getattr(self._data(), '_sa_iterator')()) def __len__(self): """Count entities in the collection.""" @@ -938,22 +944,23 @@ def _list_decorators(): fn(self, index, value) else: # slice assignment requires __delitem__, insert, __len__ - if index.stop is None: - stop = 0 - elif index.stop < 0: - stop = len(self) + index.stop - else: - stop = index.stop step = index.step or 1 - rng = range(index.start or 0, stop, step) + start = index.start or 0 + if start < 0: + start += len(self) + stop = index.stop or len(self) + if stop < 0: + stop += len(self) + if step == 1: - for i in rng: - del self[index.start] - i = index.start - for item in value: - self.insert(i, item) - i += 1 + for i in xrange(start, stop, step): + if len(self) > start: + del self[start] + + for i, item in enumerate(value): + self.insert(i + start, item) else: + rng = range(start, stop, step) if len(value) != len(rng): raise ValueError( "attempt to assign sequence of size %s to " @@ -980,6 +987,7 @@ def _list_decorators(): _tidy(__delitem__) return __delitem__ + # Py2K def __setslice__(fn): def __setslice__(self, start, end, values): for value in self[start:end]: @@ -996,7 +1004,8 @@ def _list_decorators(): fn(self, start, end) _tidy(__delslice__) return __delslice__ - + # end Py2K + def extend(fn): def extend(self, iterable): for value in iterable: @@ -1319,9 +1328,14 @@ class InstrumentedSet(set): class InstrumentedDict(dict): """An instrumented version of the built-in dict.""" + # Py3K + #__instrumentation__ = { + # 'iterator': 'values', } + # Py2K __instrumentation__ = { 'iterator': 'itervalues', } - + # end Py2K + __canned_instrumentation = { list: InstrumentedList, set: InstrumentedSet, @@ -1338,8 +1352,13 @@ __interfaces = { 'iterator': '__iter__', '_decorators': _set_decorators(), }, # decorators are required for dicts and object collections. + # Py3K + #dict: {'iterator': 'values', + # '_decorators': _dict_decorators(), }, + # Py2K dict: {'iterator': 'itervalues', '_decorators': _dict_decorators(), }, + # end Py2K # < 0.4 compatible naming, deprecated- use decorators instead. None: {} } diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py index f3820eb7cd..407a04ae4d 100644 --- a/lib/sqlalchemy/orm/dependency.py +++ b/lib/sqlalchemy/orm/dependency.py @@ -27,7 +27,7 @@ def create_dependency_processor(prop): return types[prop.direction](prop) class DependencyProcessor(object): - no_dependencies = False + has_dependencies = True def __init__(self, prop): self.prop = prop @@ -291,7 +291,7 @@ class DetectKeySwitch(DependencyProcessor): """a special DP that works for many-to-one relations, fires off for child items who have changed their referenced key.""" - no_dependencies = True + has_dependencies = False def register_dependencies(self, uowcommit): pass diff --git a/lib/sqlalchemy/orm/evaluator.py b/lib/sqlalchemy/orm/evaluator.py index 9076f610d7..05af5d8ca7 100644 --- a/lib/sqlalchemy/orm/evaluator.py +++ b/lib/sqlalchemy/orm/evaluator.py @@ -7,7 +7,11 @@ class UnevaluatableError(Exception): pass _straight_ops = set(getattr(operators, op) - for op in ('add', 'mul', 'sub', 'div', 'mod', 'truediv', + for op in ('add', 'mul', 'sub', + # Py2K + 'div', + # end Py2K + 'mod', 'truediv', 'lt', 'le', 'ne', 'gt', 'ge', 'eq')) diff --git a/lib/sqlalchemy/orm/identity.py b/lib/sqlalchemy/orm/identity.py index 50301a13c3..b7d4234f4c 100644 --- a/lib/sqlalchemy/orm/identity.py +++ b/lib/sqlalchemy/orm/identity.py @@ -45,7 +45,7 @@ class IdentityMap(dict): self._modified.discard(state) def _dirty_states(self): - return self._modified.union(s for s in list(self._mutable_attrs) + return self._modified.union(s for s in self._mutable_attrs.copy() if s.modified) def check_modified(self): @@ -54,7 +54,7 @@ class IdentityMap(dict): if self._modified: return True else: - for state in list(self._mutable_attrs): + for state in self._mutable_attrs.copy(): if state.modified: return True return False @@ -145,34 +145,49 @@ class WeakInstanceDict(IdentityMap): return self[key] except KeyError: return default - + + # Py2K def items(self): return list(self.iteritems()) def iteritems(self): for state in dict.itervalues(self): + # end Py2K + # Py3K + #def items(self): + # for state in dict.values(self): value = state.obj() if value is not None: yield state.key, value + # Py2K + def values(self): + return list(self.itervalues()) + def itervalues(self): for state in dict.itervalues(self): + # end Py2K + # Py3K + #def values(self): + # for state in dict.values(self): instance = state.obj() if instance is not None: yield instance - def values(self): - return list(self.itervalues()) - def all_states(self): + # Py3K + # return list(dict.values(self)) + + # Py2K return dict.values(self) + # end Py2K def prune(self): return 0 class StrongInstanceDict(IdentityMap): def all_states(self): - return [attributes.instance_state(o) for o in self.values()] + return [attributes.instance_state(o) for o in self.itervalues()] def contains_state(self, state): return state.key in self and attributes.instance_state(self[state.key]) is state @@ -212,7 +227,11 @@ class StrongInstanceDict(IdentityMap): ref_count = len(self) dirty = [s.obj() for s in self.all_states() if s.check_modified()] - keepers = weakref.WeakValueDictionary(self) + + # work around http://bugs.python.org/issue6149 + keepers = weakref.WeakValueDictionary() + keepers.update(self) + dict.clear(self) dict.update(self, keepers) self.modified = bool(dirty) diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 5dffa6774a..eaafe5761a 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -455,7 +455,7 @@ class MapperProperty(object): return not self.parent.non_primary - def merge(self, session, source, dest, dont_load, _recursive): + def merge(self, session, source, dest, load, _recursive): """Merge the attribute represented by this ``MapperProperty`` from source to destination object""" diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 078056a01c..c2c57825e3 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -506,13 +506,13 @@ class Mapper(object): if self.polymorphic_on and self.polymorphic_on not in self._columntoproperty: col = self.mapped_table.corresponding_column(self.polymorphic_on) if not col: - dont_instrument = True + instrument = False col = self.polymorphic_on else: - dont_instrument = False + instrument = True if self._should_exclude(col.key, col.key, local=False): raise sa_exc.InvalidRequestError("Cannot exclude or override the discriminator column %r" % col.key) - self._configure_property(col.key, ColumnProperty(col, _no_instrument=dont_instrument), init=False, setparent=True) + self._configure_property(col.key, ColumnProperty(col, _instrument=instrument), init=False, setparent=True) def _adapt_inherited_property(self, key, prop, init): if not self.concrete: @@ -1397,7 +1397,7 @@ class Mapper(object): statement = table.insert() for state, params, mapper, connection, value_params in insert: c = connection.execute(statement.values(value_params), params) - primary_key = c.last_inserted_ids() + primary_key = c.inserted_primary_key if primary_key is not None: # set primary key attributes @@ -1574,6 +1574,12 @@ class Mapper(object): if state.load_options: state.load_path = context.query._current_path + path + if isnew: + if context.options: + state.load_options = context.options + if state.load_options: + state.load_path = context.query._current_path + path + if not new_populators: new_populators[:], existing_populators[:] = self._populators(context, path, row, adapter) diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 0fa32f73f6..3489d81f2e 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -53,7 +53,7 @@ class ColumnProperty(StrategizedProperty): self.columns = [expression._labeled(c) for c in columns] self.group = kwargs.pop('group', None) self.deferred = kwargs.pop('deferred', False) - self.no_instrument = kwargs.pop('_no_instrument', False) + self.instrument = kwargs.pop('_instrument', True) self.comparator_factory = kwargs.pop('comparator_factory', self.__class__.Comparator) self.descriptor = kwargs.pop('descriptor', None) self.extension = kwargs.pop('extension', None) @@ -63,7 +63,7 @@ class ColumnProperty(StrategizedProperty): self.__class__.__name__, ', '.join(sorted(kwargs.keys())))) util.set_creation_order(self) - if self.no_instrument: + if not self.instrument: self.strategy_class = strategies.UninstrumentedColumnLoader elif self.deferred: self.strategy_class = strategies.DeferredColumnLoader @@ -71,7 +71,7 @@ class ColumnProperty(StrategizedProperty): self.strategy_class = strategies.ColumnLoader def instrument_class(self, mapper): - if self.no_instrument: + if not self.instrument: return attributes.register_descriptor( @@ -104,7 +104,7 @@ class ColumnProperty(StrategizedProperty): def setattr(self, state, value, column): state.get_impl(self.key).set(state, state.dict, value, None) - def merge(self, session, source, dest, dont_load, _recursive): + def merge(self, session, source, dest, load, _recursive): value = attributes.instance_state(source).value_as_iterable( self.key, passive=True) if value: @@ -302,7 +302,7 @@ class SynonymProperty(MapperProperty): proxy_property=self.descriptor ) - def merge(self, session, source, dest, dont_load, _recursive): + def merge(self, session, source, dest, load, _recursive): pass log.class_logger(SynonymProperty) @@ -335,7 +335,7 @@ class ComparableProperty(MapperProperty): def create_row_processor(self, selectcontext, path, mapper, row, adapter): return (None, None) - def merge(self, session, source, dest, dont_load, _recursive): + def merge(self, session, source, dest, load, _recursive): pass @@ -627,8 +627,8 @@ class RelationProperty(StrategizedProperty): def __str__(self): return str(self.parent.class_.__name__) + "." + self.key - def merge(self, session, source, dest, dont_load, _recursive): - if not dont_load: + def merge(self, session, source, dest, load, _recursive): + if load: # TODO: no test coverage for recursive check for r in self._reverse_property: if (source, r) in _recursive: @@ -650,10 +650,10 @@ class RelationProperty(StrategizedProperty): dest_list = [] for current in instances: _recursive[(current, self)] = True - obj = session._merge(current, dont_load=dont_load, _recursive=_recursive) + obj = session._merge(current, load=load, _recursive=_recursive) if obj is not None: dest_list.append(obj) - if dont_load: + if not load: coll = attributes.init_collection(dest_state, self.key) for c in dest_list: coll.append_without_event(c) @@ -663,9 +663,9 @@ class RelationProperty(StrategizedProperty): current = instances[0] if current is not None: _recursive[(current, self)] = True - obj = session._merge(current, dont_load=dont_load, _recursive=_recursive) + obj = session._merge(current, load=load, _recursive=_recursive) if obj is not None: - if dont_load: + if not load: dest_state.dict[self.key] = obj else: setattr(dest, self.key, obj) diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index e764856bf2..21137bc28b 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -648,7 +648,11 @@ class Query(object): def value(self, column): """Return a scalar result corresponding to the given column expression.""" try: + # Py3K + #return self.values(column).__next__()[0] + # Py2K return self.values(column).next()[0] + # end Py2K except StopIteration: return None @@ -1433,7 +1437,7 @@ class Query(object): session._finalize_loaded(context.progress) - for ii, (dict_, attrs) in context.partials.items(): + for ii, (dict_, attrs) in context.partials.iteritems(): ii.commit(dict_, attrs) for row in rows: @@ -1582,7 +1586,7 @@ class Query(object): self.session._autoflush() return self.session.scalar(s, params=self._params, mapper=self._mapper_zero()) - def delete(self, synchronize_session='fetch'): + def delete(self, synchronize_session='evaluate'): """Perform a bulk delete query. Deletes rows matched by this query from the database. @@ -1592,15 +1596,15 @@ class Query(object): False don't synchronize the session. This option is the most efficient and is reliable - once the session is expired, which typically occurs after a commit(). Before - the expiration, objects may still remain in the session which were in fact deleted - which can lead to confusing results if they are accessed via get() or already - loaded collections. + once the session is expired, which typically occurs after a commit(), or explicitly + using expire_all(). Before the expiration, objects may still remain in the session + which were in fact deleted which can lead to confusing results if they are accessed + via get() or already loaded collections. 'fetch' performs a select query before the delete to find objects that are matched by the delete query and need to be removed from the session. Matched objects - are removed from the session. 'fetch' is the default strategy. + are removed from the session. 'evaluate' experimental feature. Tries to evaluate the querys criteria in Python @@ -1642,9 +1646,10 @@ class Query(object): if synchronize_session == 'evaluate': try: evaluator_compiler = evaluator.EvaluatorCompiler() - eval_condition = evaluator_compiler.process(self.whereclause) + eval_condition = evaluator_compiler.process(self.whereclause or expression._Null) except evaluator.UnevaluatableError: - synchronize_session = 'fetch' + raise sa_exc.InvalidRequestError("Could not evaluate current criteria in Python. " + "Specify 'fetch' or False for the synchronize_session parameter.") delete_stmt = sql.delete(primary_table, context.whereclause) @@ -1677,7 +1682,7 @@ class Query(object): return result.rowcount - def update(self, values, synchronize_session='expire'): + def update(self, values, synchronize_session='evaluate'): """Perform a bulk update query. Updates rows matched by this query in the database. @@ -1689,18 +1694,19 @@ class Query(object): attributes on objects in the session. Valid values are: False - don't synchronize the session. Use this when you don't need to use the - session after the update or you can be sure that none of the matched objects - are in the session. - - 'expire' + don't synchronize the session. This option is the most efficient and is reliable + once the session is expired, which typically occurs after a commit(), or explicitly + using expire_all(). Before the expiration, updated objects may still remain in the session + with stale values on their attributes, which can lead to confusing results. + + 'fetch' performs a select query before the update to find objects that are matched by the update query. The updated attributes are expired on matched objects. 'evaluate' - experimental feature. Tries to evaluate the querys criteria in Python + Tries to evaluate the Query's criteria in Python straight on the objects in the session. If evaluation of the criteria isn't - implemented, the 'expire' strategy will be used as a fallback. + implemented, an exception is raised. The expression evaluator currently doesn't account for differing string collations between the database and Python. @@ -1709,6 +1715,7 @@ class Query(object): The method does *not* offer in-Python cascading of relations - it is assumed that ON UPDATE CASCADE is configured for any foreign key references which require it. + The Session needs to be expired (occurs automatically after commit(), or call expire_all()) in order for the state of dependent objects subject foreign key cascade to be correctly represented. @@ -1723,9 +1730,9 @@ class Query(object): #TODO: updates of manytoone relations need to be converted to fk assignments #TODO: cascades need handling. + if synchronize_session not in [False, 'evaluate', 'fetch']: + raise sa_exc.ArgumentError("Valid strategies for session synchronization are False, 'evaluate' and 'fetch'") self._no_select_modifiers("update") - if synchronize_session not in [False, 'evaluate', 'expire']: - raise sa_exc.ArgumentError("Valid strategies for session synchronization are False, 'evaluate' and 'expire'") self = self.enable_eagerloads(False) @@ -1739,18 +1746,19 @@ class Query(object): if synchronize_session == 'evaluate': try: evaluator_compiler = evaluator.EvaluatorCompiler() - eval_condition = evaluator_compiler.process(self.whereclause) + eval_condition = evaluator_compiler.process(self.whereclause or expression._Null) value_evaluators = {} - for key,value in values.items(): + for key,value in values.iteritems(): key = expression._column_as_key(key) value_evaluators[key] = evaluator_compiler.process(expression._literal_as_binds(value)) except evaluator.UnevaluatableError: - synchronize_session = 'expire' + raise sa_exc.InvalidRequestError("Could not evaluate current criteria in Python. " + "Specify 'fetch' or False for the synchronize_session parameter.") update_stmt = sql.update(primary_table, context.whereclause, values) - if synchronize_session == 'expire': + if synchronize_session == 'fetch': select_stmt = context.statement.with_only_columns(primary_table.primary_key) matched_rows = session.execute(select_stmt, params=self._params).fetchall() @@ -1777,7 +1785,7 @@ class Query(object): # expire attributes with pending changes (there was no autoflush, so they are overwritten) state.expire_attributes(set(evaluated_keys).difference(to_evaluate)) - elif synchronize_session == 'expire': + elif synchronize_session == 'fetch': target_mapper = self._mapper_zero() for primary_key in matched_rows: @@ -2142,7 +2150,7 @@ class _ColumnEntity(_QueryEntity): return entity is self.entity_zero else: return not _is_aliased_class(self.entity_zero) and entity.base_mapper.common_parent(self.entity_zero) - + def _resolve_expr_against_query_aliases(self, query, expr, context): return query._adapt_clause(expr, False, True) diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py index 4339b68ebc..28eb63819e 100644 --- a/lib/sqlalchemy/orm/scoping.py +++ b/lib/sqlalchemy/orm/scoping.py @@ -171,7 +171,7 @@ class _ScopedExt(MapperExtension): def _default__init__(ext, mapper): def __init__(self, **kwargs): - for key, value in kwargs.items(): + for key, value in kwargs.iteritems(): if ext.validate: if not mapper.get_property(key, resolve_synonyms=False, raiseerr=False): diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index c010a217bd..d3d653de4f 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -109,9 +109,9 @@ def sessionmaker(bind=None, class_=None, autoflush=True, autocommit=False, like:: sess = Session(binds={ - SomeMappedClass: create_engine('postgres://engine1'), - somemapper: create_engine('postgres://engine2'), - some_table: create_engine('postgres://engine3'), + SomeMappedClass: create_engine('postgresql://engine1'), + somemapper: create_engine('postgresql://engine2'), + some_table: create_engine('postgresql://engine3'), }) Also see the ``bind_mapper()`` and ``bind_table()`` methods. @@ -1139,7 +1139,7 @@ class Session(object): for state, m, o in cascade_states: self._delete_impl(state) - def merge(self, instance, dont_load=False): + def merge(self, instance, load=True, **kw): """Copy the state an instance onto the persistent instance with the same identifier. If there is no persistent instance currently associated with the @@ -1152,6 +1152,10 @@ class Session(object): mapped with ``cascade="merge"``. """ + if 'dont_load' in kw: + load = not kw['dont_load'] + util.warn_deprecated("dont_load=True has been renamed to load=False.") + # TODO: this should be an IdentityDict for instances, but will # need a separate dict for PropertyLoader tuples _recursive = {} @@ -1159,11 +1163,11 @@ class Session(object): autoflush = self.autoflush try: self.autoflush = False - return self._merge(instance, dont_load=dont_load, _recursive=_recursive) + return self._merge(instance, load=load, _recursive=_recursive) finally: self.autoflush = autoflush - def _merge(self, instance, dont_load=False, _recursive=None): + def _merge(self, instance, load=True, _recursive=None): mapper = _object_mapper(instance) if instance in _recursive: return _recursive[instance] @@ -1173,24 +1177,24 @@ class Session(object): key = state.key if key is None: - if dont_load: + if not load: raise sa_exc.InvalidRequestError( - "merge() with dont_load=True option does not support " + "merge() with load=False option does not support " "objects transient (i.e. unpersisted) objects. flush() " "all changes on mapped instances before merging with " - "dont_load=True.") + "load=False.") key = mapper._identity_key_from_state(state) merged = None if key: if key in self.identity_map: merged = self.identity_map[key] - elif dont_load: + elif not load: if state.modified: raise sa_exc.InvalidRequestError( - "merge() with dont_load=True option does not support " + "merge() with load=False option does not support " "objects marked as 'dirty'. flush() all changes on " - "mapped instances before merging with dont_load=True.") + "mapped instances before merging with load=False.") merged = mapper.class_manager.new_instance() merged_state = attributes.instance_state(merged) merged_state.key = key @@ -1208,9 +1212,9 @@ class Session(object): _recursive[instance] = merged for prop in mapper.iterate_properties: - prop.merge(self, instance, merged, dont_load, _recursive) + prop.merge(self, instance, merged, load, _recursive) - if dont_load: + if not load: attributes.instance_state(merged).commit_all(attributes.instance_dict(merged), self.identity_map) # remove any history if new_instance: diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py index 10a0f43eeb..4d9fa5ade8 100644 --- a/lib/sqlalchemy/orm/state.py +++ b/lib/sqlalchemy/orm/state.py @@ -31,11 +31,17 @@ class InstanceState(object): def detach(self): if self.session_id: - del self.session_id + try: + del self.session_id + except AttributeError: + pass def dispose(self): if self.session_id: - del self.session_id + try: + del self.session_id + except AttributeError: + pass del self.obj def _cleanup(self, ref): @@ -230,7 +236,7 @@ class InstanceState(object): for key in attribute_names: impl = self.manager[key].impl if not filter_deferred or \ - not impl.dont_expire_missing or \ + impl.expire_missing or \ key in dict_: self.expired_attributes.add(key) if impl.accepts_scalar_loader: diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index f739fb1dd0..e19e8fb31c 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -45,6 +45,7 @@ def _register_attribute(strategy, mapper, useobject, if useobject: attribute_ext.append(sessionlib.UOWEventHandler(prop.key)) + for m in mapper.polymorphic_iterator(): if prop is m._props.get(prop.key): @@ -235,7 +236,7 @@ class DeferredColumnLoader(LoaderStrategy): copy_function=self.columns[0].type.copy_value, mutable_scalars=self.columns[0].type.is_mutable(), callable_=self._class_level_loader, - dont_expire_missing=True + expire_missing=False ) def setup_query(self, context, entity, path, adapter, only_load_props=None, **kwargs): diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index d650f65a54..bca6b4f463 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -531,7 +531,7 @@ class UOWTask(object): if subtask not in deps_by_targettask: continue for dep in deps_by_targettask[subtask]: - if dep.processor.no_dependencies or not dependency_in_cycles(dep): + if not dep.processor.has_dependencies or not dependency_in_cycles(dep): continue (processor, targettask) = (dep.processor, dep.targettask) isdelete = taskelement.isdelete diff --git a/lib/sqlalchemy/pool.py b/lib/sqlalchemy/pool.py index c4e1af20cf..dabdc6e353 100644 --- a/lib/sqlalchemy/pool.py +++ b/lib/sqlalchemy/pool.py @@ -19,7 +19,7 @@ SQLAlchemy connection pool. import weakref, time, threading from sqlalchemy import exc, log -from sqlalchemy import queue as Queue +from sqlalchemy import queue as sqla_queue from sqlalchemy.util import threading, pickle, as_interface proxies = {} @@ -51,7 +51,7 @@ def clear_managers(): All pools and connections are disposed. """ - for manager in proxies.values(): + for manager in proxies.itervalues(): manager.close() proxies.clear() @@ -108,6 +108,7 @@ class Pool(object): self.echo = echo self.listeners = [] self._on_connect = [] + self._on_first_connect = [] self._on_checkout = [] self._on_checkin = [] @@ -178,12 +179,14 @@ class Pool(object): """ - listener = as_interface( - listener, methods=('connect', 'checkout', 'checkin')) + listener = as_interface(listener, + methods=('connect', 'first_connect', 'checkout', 'checkin')) self.listeners.append(listener) if hasattr(listener, 'connect'): self._on_connect.append(listener) + if hasattr(listener, 'first_connect'): + self._on_first_connect.append(listener) if hasattr(listener, 'checkout'): self._on_checkout.append(listener) if hasattr(listener, 'checkin'): @@ -197,6 +200,10 @@ class _ConnectionRecord(object): self.__pool = pool self.connection = self.__connect() self.info = {} + ls = pool.__dict__.pop('_on_first_connect', None) + if ls is not None: + for l in ls: + l.first_connect(self.connection, self) if pool._on_connect: for l in pool._on_connect: l.connect(self.connection, self) @@ -269,8 +276,11 @@ class _ConnectionRecord(object): def _finalize_fairy(connection, connection_record, pool, ref=None): - if ref is not None and connection_record.backref is not ref: + _refs.discard(connection_record) + + if ref is not None and (connection_record.fairy is not ref or isinstance(pool, AssertionPool)): return + if connection is not None: try: if pool._reset_on_return: @@ -284,7 +294,7 @@ def _finalize_fairy(connection, connection_record, pool, ref=None): if isinstance(e, (SystemExit, KeyboardInterrupt)): raise if connection_record is not None: - connection_record.backref = None + connection_record.fairy = None if pool._should_log_info: pool.log("Connection %r being returned to pool" % connection) if pool._on_checkin: @@ -292,6 +302,8 @@ def _finalize_fairy(connection, connection_record, pool, ref=None): l.checkin(connection, connection_record) pool.return_conn(connection_record) +_refs = set() + class _ConnectionFairy(object): """Proxies a DB-API connection and provides return-on-dereference support.""" @@ -303,7 +315,8 @@ class _ConnectionFairy(object): try: rec = self._connection_record = pool.get() conn = self.connection = self._connection_record.get_connection() - self._connection_record.backref = weakref.ref(self, lambda ref:_finalize_fairy(conn, rec, pool, ref)) + rec.fairy = weakref.ref(self, lambda ref:_finalize_fairy(conn, rec, pool, ref)) + _refs.add(rec) except: self.connection = None # helps with endless __getattr__ loops later on self._connection_record = None @@ -402,8 +415,9 @@ class _ConnectionFairy(object): """ if self._connection_record is not None: + _refs.remove(self._connection_record) + self._connection_record.fairy = None self._connection_record.connection = None - self._connection_record.backref = None self._pool.do_return_conn(self._connection_record) self._detached_info = \ self._connection_record.info.copy() @@ -501,10 +515,8 @@ class SingletonThreadPool(Pool): del self._conn.current def cleanup(self): - for conn in list(self._all_conns): - self._all_conns.discard(conn) - if len(self._all_conns) <= self.size: - return + while len(self._all_conns) > self.size: + self._all_conns.pop() def status(self): return "SingletonThreadPool id:%d size: %d" % (id(self), len(self._all_conns)) @@ -593,7 +605,7 @@ class QueuePool(Pool): """ Pool.__init__(self, creator, **params) - self._pool = Queue.Queue(pool_size) + self._pool = sqla_queue.Queue(pool_size) self._overflow = 0 - pool_size self._max_overflow = max_overflow self._timeout = timeout @@ -606,7 +618,7 @@ class QueuePool(Pool): def do_return_conn(self, conn): try: self._pool.put(conn, False) - except Queue.Full: + except sqla_queue.Full: if self._overflow_lock is None: self._overflow -= 1 else: @@ -620,7 +632,7 @@ class QueuePool(Pool): try: wait = self._max_overflow > -1 and self._overflow >= self._max_overflow return self._pool.get(wait, self._timeout) - except Queue.Empty: + except sqla_queue.Empty: if self._max_overflow > -1 and self._overflow >= self._max_overflow: if not wait: return self.do_get() @@ -648,7 +660,7 @@ class QueuePool(Pool): try: conn = self._pool.get(False) conn.close() - except Queue.Empty: + except sqla_queue.Empty: break self._overflow = 0 - self.size() @@ -747,7 +759,8 @@ class StaticPool(Pool): Pool.__init__(self, creator, **params) self._conn = creator() self.connection = _ConnectionRecord(self) - + self.connection = None + def status(self): return "StaticPool" @@ -788,68 +801,41 @@ class AssertionPool(Pool): ## TODO: modify this to handle an arbitrary connection count. - def __init__(self, creator, **params): - """ - Construct an AssertionPool. - - :param creator: a callable function that returns a DB-API - connection object. The function will be called with - parameters. - - :param recycle: If set to non -1, number of seconds between - connection recycling, which means upon checkout, if this - timeout is surpassed the connection will be closed and - replaced with a newly opened connection. Defaults to -1. - - :param echo: If True, connections being pulled and retrieved - from the pool will be logged to the standard output, as well - as pool sizing information. Echoing can also be achieved by - enabling logging for the "sqlalchemy.pool" - namespace. Defaults to False. - - :param use_threadlocal: If set to True, repeated calls to - :meth:`connect` within the same application thread will be - guaranteed to return the same connection object, if one has - already been retrieved from the pool and has not been - returned yet. Offers a slight performance advantage at the - cost of individual transactions by default. The - :meth:`unique_connection` method is provided to bypass the - threadlocal behavior installed into :meth:`connect`. - - :param reset_on_return: If true, reset the database state of - connections returned to the pool. This is typically a - ROLLBACK to release locks and transaction resources. - Disable at your own peril. Defaults to True. - - :param listeners: A list of - :class:`~sqlalchemy.interfaces.PoolListener`-like objects or - dictionaries of callables that receive events when DB-API - connections are created, checked out and checked in to the - pool. - - """ - Pool.__init__(self, creator, **params) - self.connection = _ConnectionRecord(self) - self._conn = self.connection - + def __init__(self, *args, **kw): + self._conn = None + self._checked_out = False + Pool.__init__(self, *args, **kw) + def status(self): return "AssertionPool" - def create_connection(self): - raise AssertionError("Invalid") - def do_return_conn(self, conn): - assert conn is self._conn and self.connection is None - self.connection = conn + if not self._checked_out: + raise AssertionError("connection is not checked out") + self._checked_out = False + assert conn is self._conn def do_return_invalid(self, conn): - raise AssertionError("Invalid") + self._conn = None + self._checked_out = False + + def dispose(self): + self._checked_out = False + self._conn.close() + def recreate(self): + self.log("Pool recreating") + return AssertionPool(self._creator, echo=self._should_log_info, listeners=self.listeners) + def do_get(self): - assert self.connection is not None - c = self.connection - self.connection = None - return c + if self._checked_out: + raise AssertionError("connection is already checked out") + + if not self._conn: + self._conn = self.create_connection() + + self._checked_out = True + return self._conn class _DBProxy(object): """Layers connection pooling behavior on top of a standard DB-API module. diff --git a/lib/sqlalchemy/queue.py b/lib/sqlalchemy/queue.py index c9ab82acf8..2aaeea9d0f 100644 --- a/lib/sqlalchemy/queue.py +++ b/lib/sqlalchemy/queue.py @@ -2,7 +2,7 @@ behavior, using RLock instead of Lock for its mutex object. This is to support the connection pool's usage of weakref callbacks to return -connections to the underlying Queue, which can apparently in extremely +connections to the underlying Queue, which can in extremely rare cases be invoked within the ``get()`` method of the Queue itself, producing a ``put()`` inside the ``get()`` and therefore a reentrant condition.""" diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index e641f119b3..231496676c 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -28,7 +28,7 @@ expressions. """ import re, inspect -from sqlalchemy import types, exc, util, databases +from sqlalchemy import types, exc, util, dialects from sqlalchemy.sql import expression, visitors URL = None @@ -65,13 +65,6 @@ class SchemaItem(visitors.Visitable): def __repr__(self): return "%s()" % self.__class__.__name__ - @property - def bind(self): - """Return the connectable associated with this SchemaItem.""" - - m = self.metadata - return m and m.bind or None - @util.memoized_property def info(self): return {} @@ -82,127 +75,140 @@ def _get_table_key(name, schema): else: return schema + "." + name -class _TableSingleton(visitors.VisitableType): - """A metaclass used by the ``Table`` object to provide singleton behavior.""" +class Table(SchemaItem, expression.TableClause): + """Represent a table in a database. + + e.g.:: + + mytable = Table("mytable", metadata, + Column('mytable_id', Integer, primary_key=True), + Column('value', String(50)) + ) + + The Table object constructs a unique instance of itself based on its + name within the given MetaData object. Constructor + arguments are as follows: + + :param name: The name of this table as represented in the database. + + This property, along with the *schema*, indicates the *singleton + identity* of this table in relation to its parent :class:`MetaData`. + Additional calls to :class:`Table` with the same name, metadata, + and schema name will return the same :class:`Table` object. + + Names which contain no upper case characters + will be treated as case insensitive names, and will not be quoted + unless they are a reserved word. Names with any number of upper + case characters will be quoted and sent exactly. Note that this + behavior applies even for databases which standardize upper + case names as case insensitive such as Oracle. + + :param metadata: a :class:`MetaData` object which will contain this + table. The metadata is used as a point of association of this table + with other tables which are referenced via foreign key. It also + may be used to associate this table with a particular + :class:`~sqlalchemy.engine.base.Connectable`. + + :param \*args: Additional positional arguments are used primarily + to add the list of :class:`Column` objects contained within this table. + Similar to the style of a CREATE TABLE statement, other :class:`SchemaItem` + constructs may be added here, including :class:`PrimaryKeyConstraint`, + and :class:`ForeignKeyConstraint`. + + :param autoload: Defaults to False: the Columns for this table should be reflected + from the database. Usually there will be no Column objects in the + constructor if this property is set. + + :param autoload_with: If autoload==True, this is an optional Engine or Connection + instance to be used for the table reflection. If ``None``, the + underlying MetaData's bound connectable will be used. + + :param implicit_returning: True by default - indicates that + RETURNING can be used by default to fetch newly inserted primary key + values, for backends which support this. Note that + create_engine() also provides an implicit_returning flag. + + :param include_columns: A list of strings indicating a subset of columns to be loaded via + the ``autoload`` operation; table columns who aren't present in + this list will not be represented on the resulting ``Table`` + object. Defaults to ``None`` which indicates all columns should + be reflected. + + :param info: A dictionary which defaults to ``{}``. A space to store application + specific data. This must be a dictionary. + + :param mustexist: When ``True``, indicates that this Table must already + be present in the given :class:`MetaData`` collection. + + :param prefixes: + A list of strings to insert after CREATE in the CREATE TABLE + statement. They will be separated by spaces. + + :param quote: Force quoting of this table's name on or off, corresponding + to ``True`` or ``False``. When left at its default of ``None``, + the column identifier will be quoted according to whether the name is + case sensitive (identifiers with at least one upper case character are + treated as case sensitive), or if it's a reserved word. This flag + is only needed to force quoting of a reserved word which is not known + by the SQLAlchemy dialect. + + :param quote_schema: same as 'quote' but applies to the schema identifier. + + :param schema: The *schema name* for this table, which is required if the table + resides in a schema other than the default selected schema for the + engine's database connection. Defaults to ``None``. + + :param useexisting: When ``True``, indicates that if this Table is already + present in the given :class:`MetaData`, apply further arguments within + the constructor to the existing :class:`Table`. If this flag is not + set, an error is raised when the parameters of an existing :class:`Table` + are overwritten. + + """ + + __visit_name__ = 'table' + + ddl_events = ('before-create', 'after-create', 'before-drop', 'after-drop') - def __call__(self, name, metadata, *args, **kwargs): - schema = kwargs.get('schema', kwargs.get('owner', None)) - useexisting = kwargs.pop('useexisting', False) - mustexist = kwargs.pop('mustexist', False) + def __new__(cls, name, metadata, *args, **kw): + schema = kw.get('schema', None) + useexisting = kw.pop('useexisting', False) + mustexist = kw.pop('mustexist', False) key = _get_table_key(name, schema) - try: - table = metadata.tables[key] - if not useexisting and table._cant_override(*args, **kwargs): + if key in metadata.tables: + if not useexisting and bool(args): raise exc.InvalidRequestError( "Table '%s' is already defined for this MetaData instance. " "Specify 'useexisting=True' to redefine options and " "columns on an existing Table object." % key) - else: - table._init_existing(*args, **kwargs) + table = metadata.tables[key] + table._init_existing(*args, **kw) return table - except KeyError: + else: if mustexist: raise exc.InvalidRequestError( "Table '%s' not defined" % (key)) + metadata.tables[key] = table = object.__new__(cls) try: - return type.__call__(self, name, metadata, *args, **kwargs) + table._init(name, metadata, *args, **kw) + return table except: - if key in metadata.tables: - del metadata.tables[key] + metadata.tables.pop(key) raise - - -class Table(SchemaItem, expression.TableClause): - """Represent a table in a database.""" - - __metaclass__ = _TableSingleton - - __visit_name__ = 'table' - - ddl_events = ('before-create', 'after-create', 'before-drop', 'after-drop') - - def __init__(self, name, metadata, *args, **kwargs): - """ - Construct a Table. - - :param name: The name of this table as represented in the database. - - This property, along with the *schema*, indicates the *singleton - identity* of this table in relation to its parent :class:`MetaData`. - Additional calls to :class:`Table` with the same name, metadata, - and schema name will return the same :class:`Table` object. - - Names which contain no upper case characters - will be treated as case insensitive names, and will not be quoted - unless they are a reserved word. Names with any number of upper - case characters will be quoted and sent exactly. Note that this - behavior applies even for databases which standardize upper - case names as case insensitive such as Oracle. - - :param metadata: a :class:`MetaData` object which will contain this - table. The metadata is used as a point of association of this table - with other tables which are referenced via foreign key. It also - may be used to associate this table with a particular - :class:`~sqlalchemy.engine.base.Connectable`. - - :param \*args: Additional positional arguments are used primarily - to add the list of :class:`Column` objects contained within this table. - Similar to the style of a CREATE TABLE statement, other :class:`SchemaItem` - constructs may be added here, including :class:`PrimaryKeyConstraint`, - and :class:`ForeignKeyConstraint`. - - :param autoload: Defaults to False: the Columns for this table should be reflected - from the database. Usually there will be no Column objects in the - constructor if this property is set. - - :param autoload_with: If autoload==True, this is an optional Engine or Connection - instance to be used for the table reflection. If ``None``, the - underlying MetaData's bound connectable will be used. - - :param include_columns: A list of strings indicating a subset of columns to be loaded via - the ``autoload`` operation; table columns who aren't present in - this list will not be represented on the resulting ``Table`` - object. Defaults to ``None`` which indicates all columns should - be reflected. - - :param info: A dictionary which defaults to ``{}``. A space to store application - specific data. This must be a dictionary. - - :param mustexist: When ``True``, indicates that this Table must already - be present in the given :class:`MetaData`` collection. - - :param prefixes: - A list of strings to insert after CREATE in the CREATE TABLE - statement. They will be separated by spaces. - - :param quote: Force quoting of this table's name on or off, corresponding - to ``True`` or ``False``. When left at its default of ``None``, - the column identifier will be quoted according to whether the name is - case sensitive (identifiers with at least one upper case character are - treated as case sensitive), or if it's a reserved word. This flag - is only needed to force quoting of a reserved word which is not known - by the SQLAlchemy dialect. - - :param quote_schema: same as 'quote' but applies to the schema identifier. - - :param schema: The *schema name* for this table, which is required if the table - resides in a schema other than the default selected schema for the - engine's database connection. Defaults to ``None``. - - :param useexisting: When ``True``, indicates that if this Table is already - present in the given :class:`MetaData`, apply further arguments within - the constructor to the existing :class:`Table`. If this flag is not - set, an error is raised when the parameters of an existing :class:`Table` - are overwritten. - - """ + + def __init__(self, *args, **kw): + # __init__ is overridden to prevent __new__ from + # calling the superclass constructor. + pass + + def _init(self, name, metadata, *args, **kwargs): super(Table, self).__init__(name) self.metadata = metadata - self.schema = kwargs.pop('schema', kwargs.pop('owner', None)) + self.schema = kwargs.pop('schema', None) self.indexes = set() self.constraints = set() self._columns = expression.ColumnCollection() - self.primary_key = PrimaryKeyConstraint() + self._set_primary_key(PrimaryKeyConstraint()) self._foreign_keys = util.OrderedSet() self.ddl_listeners = util.defaultdict(list) self.kwargs = {} @@ -215,8 +221,7 @@ class Table(SchemaItem, expression.TableClause): autoload_with = kwargs.pop('autoload_with', None) include_columns = kwargs.pop('include_columns', None) - self._set_parent(metadata) - + self.implicit_returning = kwargs.pop('implicit_returning', True) self.quote = kwargs.pop('quote', None) self.quote_schema = kwargs.pop('quote_schema', None) if 'info' in kwargs: @@ -224,7 +229,7 @@ class Table(SchemaItem, expression.TableClause): self._prefixes = kwargs.pop('prefixes', []) - self.__extra_kwargs(**kwargs) + self._extra_kwargs(**kwargs) # load column definitions from the database if 'autoload' is defined # we do it after the table is in the singleton dictionary to support @@ -237,7 +242,7 @@ class Table(SchemaItem, expression.TableClause): # initialize all the column, etc. objects. done after reflection to # allow user-overrides - self.__post_init(*args, **kwargs) + self._init_items(*args) def _init_existing(self, *args, **kwargs): autoload = kwargs.pop('autoload', False) @@ -261,43 +266,43 @@ class Table(SchemaItem, expression.TableClause): if 'info' in kwargs: self.info = kwargs.pop('info') - self.__extra_kwargs(**kwargs) - self.__post_init(*args, **kwargs) - - def _cant_override(self, *args, **kwargs): - """Return True if any argument is not supported as an override. - - Takes arguments that would be sent to Table.__init__, and returns - True if any of them would be disallowed if sent to an existing - Table singleton. - """ - return bool(args) or bool(set(kwargs).difference( - ['autoload', 'autoload_with', 'schema', 'owner'])) + self._extra_kwargs(**kwargs) + self._init_items(*args) - def __extra_kwargs(self, **kwargs): + def _extra_kwargs(self, **kwargs): # validate remaining kwargs that they all specify DB prefixes if len([k for k in kwargs - if not re.match(r'^(?:%s)_' % '|'.join(databases.__all__), k)]): + if not re.match(r'^(?:%s)_' % '|'.join(dialects.__all__), k)]): raise TypeError( - "Invalid argument(s) for Table: %s" % repr(kwargs.keys())) + "Invalid argument(s) for Table: %r" % kwargs.keys()) self.kwargs.update(kwargs) - def __post_init(self, *args, **kwargs): - self._init_items(*args) - - @property - def key(self): - return _get_table_key(self.name, self.schema) - def _set_primary_key(self, pk): if getattr(self, '_primary_key', None) in self.constraints: self.constraints.remove(self._primary_key) self._primary_key = pk self.constraints.add(pk) + for c in pk.columns: + c.primary_key = True + + @util.memoized_property + def _autoincrement_column(self): + for col in self.primary_key: + if col.autoincrement and \ + isinstance(col.type, types.Integer) and \ + not col.foreign_keys and \ + isinstance(col.default, (type(None), Sequence)): + + return col + + @property + def key(self): + return _get_table_key(self.name, self.schema) + + @property def primary_key(self): return self._primary_key - primary_key = property(primary_key, _set_primary_key) def __repr__(self): return "Table(%s)" % ', '.join( @@ -308,6 +313,12 @@ class Table(SchemaItem, expression.TableClause): def __str__(self): return _get_table_key(self.description, self.schema) + @property + def bind(self): + """Return the connectable associated with this Table.""" + + return self.metadata and self.metadata.bind or None + def append_column(self, column): """Append a ``Column`` to this ``Table``.""" @@ -359,7 +370,7 @@ class Table(SchemaItem, expression.TableClause): self, column_collections=column_collections, **kwargs) else: if column_collections: - return [c for c in self.columns] + return list(self.columns) else: return [] @@ -407,7 +418,7 @@ class Column(SchemaItem, expression.ColumnClause): """Represents a column in a database table.""" __visit_name__ = 'column' - + def __init__(self, *args, **kwargs): """ Construct a new ``Column`` object. @@ -478,7 +489,7 @@ class Column(SchemaItem, expression.ColumnClause): Contrast this argument to ``server_default`` which creates a default generator on the database side. - + :param key: An optional string identifier which will identify this ``Column`` object on the :class:`Table`. When a key is provided, this is the only identifier referencing the ``Column`` within the application, @@ -568,10 +579,6 @@ class Column(SchemaItem, expression.ColumnClause): if args: coltype = args[0] - # adjust for partials - if util.callable(coltype): - coltype = args[0]() - if (isinstance(coltype, types.AbstractType) or (isinstance(coltype, type) and issubclass(coltype, types.AbstractType))): @@ -581,7 +588,6 @@ class Column(SchemaItem, expression.ColumnClause): type_ = args.pop(0) super(Column, self).__init__(name, None, type_) - self.args = args self.key = kwargs.pop('key', name) self.primary_key = kwargs.pop('primary_key', False) self.nullable = kwargs.pop('nullable', not self.primary_key) @@ -595,6 +601,28 @@ class Column(SchemaItem, expression.ColumnClause): self.autoincrement = kwargs.pop('autoincrement', True) self.constraints = set() self.foreign_keys = util.OrderedSet() + self._table_events = set() + + if self.default is not None: + if isinstance(self.default, ColumnDefault): + args.append(self.default) + else: + args.append(ColumnDefault(self.default)) + if self.server_default is not None: + if isinstance(self.server_default, FetchedValue): + args.append(self.server_default) + else: + args.append(DefaultClause(self.server_default)) + if self.onupdate is not None: + args.append(ColumnDefault(self.onupdate, for_update=True)) + if self.server_onupdate is not None: + if isinstance(self.server_onupdate, FetchedValue): + args.append(self.server_default) + else: + args.append(DefaultClause(self.server_onupdate, + for_update=True)) + self._init_items(*args) + util.set_creation_order(self) if 'info' in kwargs: @@ -615,10 +643,6 @@ class Column(SchemaItem, expression.ColumnClause): else: return self.description - @property - def bind(self): - return self.table.bind - def references(self, column): """Return True if this Column references the given column via foreign key.""" for fk in self.foreign_keys: @@ -658,24 +682,26 @@ class Column(SchemaItem, expression.ColumnClause): "before adding to a Table.") if self.key is None: self.key = self.name - self.metadata = table.metadata + if getattr(self, 'table', None) is not None: raise exc.ArgumentError("this Column already has a table!") if self.key in table._columns: - # note the column being replaced, if any - self._pre_existing_column = table._columns.get(self.key) + col = table._columns.get(self.key) + for fk in col.foreign_keys: + col.foreign_keys.remove(fk) + table.foreign_keys.remove(fk) + table.constraints.remove(fk.constraint) + table._columns.replace(self) if self.primary_key: - table.primary_key.replace(self) + table.primary_key._replace(self) elif self.key in table.primary_key: raise exc.ArgumentError( "Trying to redefine primary-key column '%s' as a " "non-primary-key column on table '%s'" % ( self.key, table.fullname)) - # if we think this should not raise an error, we'd instead do this: - #table.primary_key.remove(self) self.table = table if self.index: @@ -695,48 +721,47 @@ class Column(SchemaItem, expression.ColumnClause): "external to the Table.") table.append_constraint(UniqueConstraint(self.key)) - toinit = list(self.args) - if self.default is not None: - if isinstance(self.default, ColumnDefault): - toinit.append(self.default) - else: - toinit.append(ColumnDefault(self.default)) - if self.server_default is not None: - if isinstance(self.server_default, FetchedValue): - toinit.append(self.server_default) - else: - toinit.append(DefaultClause(self.server_default)) - if self.onupdate is not None: - toinit.append(ColumnDefault(self.onupdate, for_update=True)) - if self.server_onupdate is not None: - if isinstance(self.server_onupdate, FetchedValue): - toinit.append(self.server_default) - else: - toinit.append(DefaultClause(self.server_onupdate, - for_update=True)) - self._init_items(*toinit) - self.args = None - + for fn in self._table_events: + fn(table) + del self._table_events + + def _on_table_attach(self, fn): + if self.table: + fn(self.table) + else: + self._table_events.add(fn) + def copy(self, **kw): """Create a copy of this ``Column``, unitialized. This is used in ``Table.tometadata``. """ - return Column(self.name, self.type, self.default, key = self.key, primary_key = self.primary_key, nullable = self.nullable, quote=self.quote, index=self.index, autoincrement=self.autoincrement, *[c.copy(**kw) for c in self.constraints]) + return Column( + self.name, + self.type, + self.default, + key = self.key, + primary_key = self.primary_key, + nullable = self.nullable, + quote=self.quote, + index=self.index, + autoincrement=self.autoincrement, + *[c.copy(**kw) for c in self.constraints]) def _make_proxy(self, selectable, name=None): """Create a *proxy* for this column. This is a copy of this ``Column`` referenced by a different parent - (such as an alias or select statement). - + (such as an alias or select statement). The column should + be used only in select scenarios, as its full DDL/default + information is not transferred. + """ fk = [ForeignKey(f.column) for f in self.foreign_keys] c = Column( name or self.name, self.type, - self.default, key = name or self.key, primary_key = self.primary_key, nullable = self.nullable, @@ -746,7 +771,9 @@ class Column(SchemaItem, expression.ColumnClause): selectable.columns.add(c) if self.primary_key: selectable.primary_key.add(c) - [c._init_items(f) for f in fk] + for fn in c._table_events: + fn(selectable) + del c._table_events return c def get_children(self, schema_visitor=False, **kwargs): @@ -775,9 +802,13 @@ class ForeignKey(SchemaItem): __visit_name__ = 'foreign_key' - def __init__(self, column, constraint=None, use_alter=False, name=None, onupdate=None, ondelete=None, deferrable=None, initially=None, link_to_name=False): + def __init__(self, column, constraint=None, use_alter=False, name=None, onupdate=None, + ondelete=None, deferrable=None, initially=None, link_to_name=False): """ - Construct a column-level FOREIGN KEY. + Construct a column-level FOREIGN KEY. + + The :class:`ForeignKey` object when constructed generates a :class:`ForeignKeyConstraint` + which is associated with the parent :class:`Table` object's collection of constraints. :param column: A single target column for the key relationship. A :class:`Column` object or a column name as a string: ``tablename.columnkey`` or @@ -809,10 +840,10 @@ class ForeignKey(SchemaItem): :param link_to_name: if True, the string name given in ``column`` is the rendered name of the referenced column, not its locally assigned ``key``. - :param use_alter: If True, do not emit this key as part of the CREATE TABLE - definition. Instead, use ALTER TABLE after table creation to add - the key. Useful for circular dependencies. - + :param use_alter: passed to the underlying :class:`ForeignKeyConstraint` to indicate the + constraint should be generated/dropped externally from the CREATE TABLE/ + DROP TABLE statement. See that classes' constructor for details. + """ self._colspec = column @@ -946,39 +977,35 @@ class ForeignKey(SchemaItem): def _set_parent(self, column): if hasattr(self, 'parent'): + if self.parent is column: + return raise exc.InvalidRequestError("This ForeignKey already has a parent !") self.parent = column - if hasattr(self.parent, '_pre_existing_column'): - # remove existing FK which matches us - for fk in self.parent._pre_existing_column.foreign_keys: - if fk.target_fullname == self.target_fullname: - self.parent.table.foreign_keys.remove(fk) - self.parent.table.constraints.remove(fk.constraint) - - if self.constraint is None and isinstance(self.parent.table, Table): + self.parent.foreign_keys.add(self) + self.parent._on_table_attach(self._set_table) + + def _set_table(self, table): + if self.constraint is None and isinstance(table, Table): self.constraint = ForeignKeyConstraint( [], [], use_alter=self.use_alter, name=self.name, onupdate=self.onupdate, ondelete=self.ondelete, - deferrable=self.deferrable, initially=self.initially) - self.parent.table.append_constraint(self.constraint) - self.constraint._append_fk(self) - - self.parent.foreign_keys.add(self) - self.parent.table.foreign_keys.add(self) - + deferrable=self.deferrable, initially=self.initially, + ) + self.constraint._elements[self.parent] = self + self.constraint._set_parent(table) + table.foreign_keys.add(self) + class DefaultGenerator(SchemaItem): """Base class for column *default* values.""" __visit_name__ = 'default_generator' - def __init__(self, for_update=False, metadata=None): + def __init__(self, for_update=False): self.for_update = for_update - self.metadata = util.assert_arg_type(metadata, (MetaData, type(None)), 'metadata') def _set_parent(self, column): self.column = column - self.metadata = self.column.table.metadata if self.for_update: self.column.onupdate = self else: @@ -989,6 +1016,14 @@ class DefaultGenerator(SchemaItem): bind = _bind_or_error(self) return bind._execute_default(self, **kwargs) + @property + def bind(self): + """Return the connectable associated with this default.""" + if getattr(self, 'column', None): + return self.column.table.bind + else: + return None + def __repr__(self): return "DefaultGenerator()" @@ -1026,7 +1061,9 @@ class ColumnDefault(DefaultGenerator): return lambda ctx: fn() positionals = len(argspec[0]) - if inspect.ismethod(inspectable): + + # Py3K compat - no unbound methods + if inspect.ismethod(inspectable) or inspect.isclass(fn): positionals -= 1 if positionals == 0: @@ -1047,7 +1084,7 @@ class ColumnDefault(DefaultGenerator): __visit_name__ = property(_visit_name) def __repr__(self): - return "ColumnDefault(%s)" % repr(self.arg) + return "ColumnDefault(%r)" % self.arg class Sequence(DefaultGenerator): """Represents a named database sequence.""" @@ -1055,15 +1092,15 @@ class Sequence(DefaultGenerator): __visit_name__ = 'sequence' def __init__(self, name, start=None, increment=None, schema=None, - optional=False, quote=None, **kwargs): - super(Sequence, self).__init__(**kwargs) + optional=False, quote=None, metadata=None, for_update=False): + super(Sequence, self).__init__(for_update=for_update) self.name = name self.start = start self.increment = increment self.optional = optional self.quote = quote self.schema = schema - self.kwargs = kwargs + self.metadata = metadata def __repr__(self): return "Sequence(%s)" % ', '.join( @@ -1074,7 +1111,19 @@ class Sequence(DefaultGenerator): def _set_parent(self, column): super(Sequence, self)._set_parent(column) column.sequence = self - + + column._on_table_attach(self._set_table) + + def _set_table(self, table): + self.metadata = table.metadata + + @property + def bind(self): + if self.metadata: + return self.metadata.bind + else: + return None + def create(self, bind=None, checkfirst=True): """Creates this sequence in the database.""" @@ -1123,17 +1172,12 @@ class DefaultClause(FetchedValue): # alias; deprecated starting 0.5.0 PassiveDefault = DefaultClause - class Constraint(SchemaItem): - """A table-level SQL constraint, such as a KEY. - - Implements a hybrid of dict/setlike behavior with regards to the list of - underying columns. - """ + """A table-level SQL constraint.""" __visit_name__ = 'constraint' - def __init__(self, name=None, deferrable=None, initially=None): + def __init__(self, name=None, deferrable=None, initially=None, inline_ddl=True): """Create a SQL constraint. name @@ -1146,33 +1190,87 @@ class Constraint(SchemaItem): initially Optional string. If set, emit INITIALLY when issuing DDL for this constraint. + + inline_ddl + if True, DDL for this Constraint will be generated within the span of a + CREATE TABLE or DROP TABLE statement, when the associated table's + DDL is generated. if False, no DDL is issued within that process. + Instead, it is expected that an AddConstraint or DropConstraint + construct will be used to issue DDL for this Contraint. + The AddConstraint/DropConstraint constructs set this flag automatically + as well. """ self.name = name - self.columns = expression.ColumnCollection() self.deferrable = deferrable self.initially = initially + self.inline_ddl = inline_ddl + + @property + def table(self): + if isinstance(self.parent, Table): + return self.parent + else: + raise exc.InvalidRequestError("This constraint is not bound to a table.") + + def _set_parent(self, parent): + self.parent = parent + parent.constraints.add(self) + + def copy(self, **kw): + raise NotImplementedError() + +class ColumnCollectionConstraint(Constraint): + """A constraint that proxies a ColumnCollection.""" + + def __init__(self, *columns, **kw): + """ + \*columns + A sequence of column names or Column objects. + + name + Optional, the in-database name of this constraint. + + deferrable + Optional bool. If set, emit DEFERRABLE or NOT DEFERRABLE when + issuing DDL for this constraint. + + initially + Optional string. If set, emit INITIALLY when issuing DDL + for this constraint. + + """ + super(ColumnCollectionConstraint, self).__init__(**kw) + self.columns = expression.ColumnCollection() + self._pending_colargs = [_to_schema_column_or_string(c) for c in columns] + if self._pending_colargs and \ + isinstance(self._pending_colargs[0], Column) and \ + self._pending_colargs[0].table is not None: + self._set_parent(self._pending_colargs[0].table) + + def _set_parent(self, table): + super(ColumnCollectionConstraint, self)._set_parent(table) + for col in self._pending_colargs: + if isinstance(col, basestring): + col = table.c[col] + self.columns.add(col) def __contains__(self, x): return x in self.columns + def copy(self, **kw): + return self.__class__(name=self.name, deferrable=self.deferrable, + initially=self.initially, *self.columns.keys()) + def contains_column(self, col): return self.columns.contains_column(col) - def keys(self): - return self.columns.keys() - - def __add__(self, other): - return self.columns + other - def __iter__(self): return iter(self.columns) def __len__(self): return len(self.columns) - def copy(self, **kw): - raise NotImplementedError() class CheckConstraint(Constraint): """A table- or column-level CHECK constraint. @@ -1180,7 +1278,7 @@ class CheckConstraint(Constraint): Can be included in the definition of a Table or Column. """ - def __init__(self, sqltext, name=None, deferrable=None, initially=None): + def __init__(self, sqltext, name=None, deferrable=None, initially=None, table=None): """Construct a CHECK constraint. sqltext @@ -1197,6 +1295,7 @@ class CheckConstraint(Constraint): initially Optional string. If set, emit INITIALLY when issuing DDL for this constraint. + """ super(CheckConstraint, self).__init__(name, deferrable, initially) @@ -1204,7 +1303,9 @@ class CheckConstraint(Constraint): raise exc.ArgumentError( "sqltext must be a string and will be used verbatim.") self.sqltext = sqltext - + if table: + self._set_parent(table) + def __visit_name__(self): if isinstance(self.parent, Table): return "check_constraint" @@ -1212,10 +1313,6 @@ class CheckConstraint(Constraint): return "column_check_constraint" __visit_name__ = property(__visit_name__) - def _set_parent(self, parent): - self.parent = parent - parent.constraints.add(self) - def copy(self, **kw): return CheckConstraint(self.sqltext, name=self.name) @@ -1232,7 +1329,8 @@ class ForeignKeyConstraint(Constraint): """ __visit_name__ = 'foreign_key_constraint' - def __init__(self, columns, refcolumns, name=None, onupdate=None, ondelete=None, use_alter=False, deferrable=None, initially=None, link_to_name=False): + def __init__(self, columns, refcolumns, name=None, onupdate=None, ondelete=None, + deferrable=None, initially=None, use_alter=False, link_to_name=False, table=None): """Construct a composite-capable FOREIGN KEY. :param columns: A sequence of local column names. The named columns must be defined @@ -1261,42 +1359,72 @@ class ForeignKeyConstraint(Constraint): :param link_to_name: if True, the string name given in ``column`` is the rendered name of the referenced column, not its locally assigned ``key``. - :param use_alter: If True, do not emit this key as part of the CREATE TABLE - definition. Instead, use ALTER TABLE after table creation to add - the key. Useful for circular dependencies. + :param use_alter: If True, do not emit the DDL for this constraint + as part of the CREATE TABLE definition. Instead, generate it via an + ALTER TABLE statement issued after the full collection of tables have been + created, and drop it via an ALTER TABLE statement before the full collection + of tables are dropped. This is shorthand for the usage of + :class:`AddConstraint` and :class:`DropConstraint` applied as "after-create" + and "before-drop" events on the MetaData object. This is normally used to + generate/drop constraints on objects that are mutually dependent on each other. """ super(ForeignKeyConstraint, self).__init__(name, deferrable, initially) - self.__colnames = columns - self.__refcolnames = refcolumns - self.elements = util.OrderedSet() + self.onupdate = onupdate self.ondelete = ondelete self.link_to_name = link_to_name if self.name is None and use_alter: - raise exc.ArgumentError("Alterable ForeignKey/ForeignKeyConstraint requires a name") + raise exc.ArgumentError("Alterable Constraint requires a name") self.use_alter = use_alter + self._elements = util.OrderedDict() + for col, refcol in zip(columns, refcolumns): + self._elements[col] = ForeignKey( + refcol, + constraint=self, + name=self.name, + onupdate=self.onupdate, + ondelete=self.ondelete, + use_alter=self.use_alter, + link_to_name=self.link_to_name + ) + + if table: + self._set_parent(table) + + @property + def columns(self): + return self._elements.keys() + + @property + def elements(self): + return self._elements.values() + def _set_parent(self, table): - self.table = table - if self not in table.constraints: - table.constraints.add(self) - for (c, r) in zip(self.__colnames, self.__refcolnames): - self.append_element(c, r) - - def append_element(self, col, refcol): - fk = ForeignKey(refcol, constraint=self, name=self.name, onupdate=self.onupdate, ondelete=self.ondelete, use_alter=self.use_alter, link_to_name=self.link_to_name) - fk._set_parent(self.table.c[col]) - self._append_fk(fk) - - def _append_fk(self, fk): - self.columns.add(self.table.c[fk.parent.key]) - self.elements.add(fk) - + super(ForeignKeyConstraint, self)._set_parent(table) + for col, fk in self._elements.iteritems(): + if isinstance(col, basestring): + col = table.c[col] + fk._set_parent(col) + + if self.use_alter: + def supports_alter(event, schema_item, bind, **kw): + return table in set(kw['tables']) and bind.dialect.supports_alter + AddConstraint(self, on=supports_alter).execute_at('after-create', table.metadata) + DropConstraint(self, on=supports_alter).execute_at('before-drop', table.metadata) + def copy(self, **kw): - return ForeignKeyConstraint([x.parent.name for x in self.elements], [x._get_colspec(**kw) for x in self.elements], name=self.name, onupdate=self.onupdate, ondelete=self.ondelete, use_alter=self.use_alter) - -class PrimaryKeyConstraint(Constraint): + return ForeignKeyConstraint( + [x.parent.name for x in self._elements.values()], + [x._get_colspec(**kw) for x in self._elements.values()], + name=self.name, + onupdate=self.onupdate, + ondelete=self.ondelete, + use_alter=self.use_alter + ) + +class PrimaryKeyConstraint(ColumnCollectionConstraint): """A table-level PRIMARY KEY constraint. Defines a single column or composite PRIMARY KEY constraint. For a @@ -1307,63 +1435,14 @@ class PrimaryKeyConstraint(Constraint): __visit_name__ = 'primary_key_constraint' - def __init__(self, *columns, **kwargs): - """Construct a composite-capable PRIMARY KEY. - - \*columns - A sequence of column names. All columns named must be defined and - present within the parent Table. - - name - Optional, the in-database name of the key. - - deferrable - Optional bool. If set, emit DEFERRABLE or NOT DEFERRABLE when - issuing DDL for this constraint. - - initially - Optional string. If set, emit INITIALLY when issuing DDL - for this constraint. - """ - - constraint_args = dict(name=kwargs.pop('name', None), - deferrable=kwargs.pop('deferrable', None), - initially=kwargs.pop('initially', None)) - if kwargs: - raise exc.ArgumentError( - 'Unknown PrimaryKeyConstraint argument(s): %s' % - ', '.join(repr(x) for x in kwargs.keys())) - - super(PrimaryKeyConstraint, self).__init__(**constraint_args) - self.__colnames = list(columns) - def _set_parent(self, table): - self.table = table - table.primary_key = self - for name in self.__colnames: - self.add(table.c[name]) - - def add(self, col): - self.columns.add(col) - col.primary_key = True - append_column = add + super(PrimaryKeyConstraint, self)._set_parent(table) + table._set_primary_key(self) - def replace(self, col): + def _replace(self, col): self.columns.replace(col) - def remove(self, col): - col.primary_key = False - del self.columns[col.key] - - def copy(self, **kw): - return PrimaryKeyConstraint(name=self.name, *[c.key for c in self]) - - __hash__ = Constraint.__hash__ - - def __eq__(self, other): - return self.columns == other - -class UniqueConstraint(Constraint): +class UniqueConstraint(ColumnCollectionConstraint): """A table-level UNIQUE constraint. Defines a single column or composite UNIQUE constraint. For a no-frills, @@ -1374,48 +1453,6 @@ class UniqueConstraint(Constraint): __visit_name__ = 'unique_constraint' - def __init__(self, *columns, **kwargs): - """Construct a UNIQUE constraint. - - \*columns - A sequence of column names. All columns named must be defined and - present within the parent Table. - - name - Optional, the in-database name of the key. - - deferrable - Optional bool. If set, emit DEFERRABLE or NOT DEFERRABLE when - issuing DDL for this constraint. - - initially - Optional string. If set, emit INITIALLY when issuing DDL - for this constraint. - """ - - constraint_args = dict(name=kwargs.pop('name', None), - deferrable=kwargs.pop('deferrable', None), - initially=kwargs.pop('initially', None)) - if kwargs: - raise exc.ArgumentError( - 'Unknown UniqueConstraint argument(s): %s' % - ', '.join(repr(x) for x in kwargs.keys())) - - super(UniqueConstraint, self).__init__(**constraint_args) - self.__colnames = list(columns) - - def _set_parent(self, table): - self.table = table - table.constraints.add(self) - for c in self.__colnames: - self.append_column(table.c[c]) - - def append_column(self, col): - self.columns.add(col) - - def copy(self, **kw): - return UniqueConstraint(name=self.name, *self.__colnames) - class Index(SchemaItem): """A table-level INDEX. @@ -1436,7 +1473,7 @@ class Index(SchemaItem): \*columns Columns to include in the index. All columns must belong to the same - table, and no column may appear more than once. + table. \**kwargs Keyword arguments include: @@ -1444,42 +1481,36 @@ class Index(SchemaItem): unique Defaults to False: create a unique index. - postgres_where + postgresql_where Defaults to None: create a partial index when using PostgreSQL """ self.name = name - self.columns = [] + self.columns = expression.ColumnCollection() self.table = None self.unique = kwargs.pop('unique', False) self.kwargs = kwargs - self._init_items(*columns) - - def _init_items(self, *args): - for column in args: - self.append_column(_to_schema_column(column)) + for column in columns: + column = _to_schema_column(column) + if self.table is None: + self._set_parent(column.table) + elif column.table != self.table: + # all columns muse be from same table + raise exc.ArgumentError( + "All index columns must be from same table. " + "%s is from %s not %s" % (column, column.table, self.table)) + self.columns.add(column) def _set_parent(self, table): self.table = table - self.metadata = table.metadata table.indexes.add(self) - def append_column(self, column): - # make sure all columns are from the same table - # and no column is repeated - if self.table is None: - self._set_parent(column.table) - elif column.table != self.table: - # all columns muse be from same table - raise exc.ArgumentError( - "All index columns must be from same table. " - "%s is from %s not %s" % (column, column.table, self.table)) - elif column.name in [ c.name for c in self.columns ]: - raise exc.ArgumentError( - "A column may not appear twice in the " - "same index (%s already has column %s)" % (self.name, column)) - self.columns.append(column) + @property + def bind(self): + """Return the connectable associated with this Index.""" + + return self.table.bind def create(self, bind=None): if bind is None: @@ -1492,9 +1523,6 @@ class Index(SchemaItem): bind = _bind_or_error(self) bind.drop(self) - def __str__(self): - return repr(self) - def __repr__(self): return 'Index("%s", %s%s)' % (self.name, ', '.join(repr(c) for c in self.columns), @@ -1576,27 +1604,6 @@ class MetaData(SchemaItem): return self._bind is not None - @util.deprecated('Deprecated. Use ``metadata.bind = `` or ' - '``metadata.bind = ``.') - def connect(self, bind, **kwargs): - """Bind this MetaData to an Engine. - - bind - A string, ``URL``, ``Engine`` or ``Connection`` instance. If a - string or ``URL``, will be passed to ``create_engine()`` along with - ``\**kwargs`` to produce the engine which to connect to. Otherwise - connects directly to the given ``Engine``. - - """ - global URL - if URL is None: - from sqlalchemy.engine.url import URL - if isinstance(bind, (basestring, URL)): - from sqlalchemy import create_engine - self._bind = create_engine(bind, **kwargs) - else: - self._bind = bind - def bind(self): """An Engine or Connection to which this MetaData is bound. @@ -1633,27 +1640,13 @@ class MetaData(SchemaItem): # TODO: scan all other tables and remove FK _column del self.tables[table.key] - @util.deprecated('Deprecated. Use ``metadata.sorted_tables``') - def table_iterator(self, reverse=True, tables=None): - """Deprecated - use metadata.sorted_tables().""" - - from sqlalchemy.sql.util import sort_tables - if tables is None: - tables = self.tables.values() - else: - tables = set(tables).intersection(self.tables.values()) - ret = sort_tables(tables) - if reverse: - ret = reversed(ret) - return iter(ret) - @property def sorted_tables(self): """Returns a list of ``Table`` objects sorted in order of dependency. """ from sqlalchemy.sql.util import sort_tables - return sort_tables(self.tables.values()) + return sort_tables(self.tables.itervalues()) def reflect(self, bind=None, schema=None, only=None): """Load all available table definitions from the database. @@ -1699,7 +1692,7 @@ class MetaData(SchemaItem): available = util.OrderedSet(bind.engine.table_names(schema, connection=conn)) - current = set(self.tables.keys()) + current = set(self.tables.iterkeys()) if only is None: load = [name for name in available if name not in current] @@ -1777,11 +1770,7 @@ class MetaData(SchemaItem): """ if bind is None: bind = _bind_or_error(self) - for listener in self.ddl_listeners['before-create']: - listener('before-create', self, bind) bind.create(self, checkfirst=checkfirst, tables=tables) - for listener in self.ddl_listeners['after-create']: - listener('after-create', self, bind) def drop_all(self, bind=None, tables=None, checkfirst=True): """Drop all tables stored in this metadata. @@ -1804,11 +1793,7 @@ class MetaData(SchemaItem): """ if bind is None: bind = _bind_or_error(self) - for listener in self.ddl_listeners['before-drop']: - listener('before-drop', self, bind) bind.drop(self, checkfirst=checkfirst, tables=tables) - for listener in self.ddl_listeners['after-drop']: - listener('after-drop', self, bind) class ThreadLocalMetaData(MetaData): """A MetaData variant that presents a different ``bind`` in every thread. @@ -1833,31 +1818,6 @@ class ThreadLocalMetaData(MetaData): self.__engines = {} super(ThreadLocalMetaData, self).__init__() - @util.deprecated('Deprecated. Use ``metadata.bind = `` or ' - '``metadata.bind = ``.') - def connect(self, bind, **kwargs): - """Bind to an Engine in the caller's thread. - - bind - A string, ``URL``, ``Engine`` or ``Connection`` instance. If a - string or ``URL``, will be passed to ``create_engine()`` along with - ``\**kwargs`` to produce the engine which to connect to. Otherwise - connects directly to the given ``Engine``. - """ - - global URL - if URL is None: - from sqlalchemy.engine.url import URL - - if isinstance(bind, (basestring, URL)): - try: - engine = self.__engines[bind] - except KeyError: - from sqlalchemy import create_engine - engine = create_engine(bind, **kwargs) - bind = engine - self._bind_to(bind) - def bind(self): """The bound Engine or Connection for this thread. @@ -1899,7 +1859,7 @@ class ThreadLocalMetaData(MetaData): def dispose(self): """Dispose all bound engines, in all thread contexts.""" - for e in self.__engines.values(): + for e in self.__engines.itervalues(): if hasattr(e, 'dispose'): e.dispose() @@ -1909,87 +1869,15 @@ class SchemaVisitor(visitors.ClauseVisitor): __traverse_options__ = {'schema_visitor':True} -class DDL(object): - """A literal DDL statement. - - Specifies literal SQL DDL to be executed by the database. DDL objects can - be attached to ``Tables`` or ``MetaData`` instances, conditionally - executing SQL as part of the DDL lifecycle of those schema items. Basic - templating support allows a single DDL instance to handle repetitive tasks - for multiple tables. - - Examples:: - - tbl = Table('users', metadata, Column('uid', Integer)) # ... - DDL('DROP TRIGGER users_trigger').execute_at('before-create', tbl) - - spow = DDL('ALTER TABLE %(table)s SET secretpowers TRUE', on='somedb') - spow.execute_at('after-create', tbl) - - drop_spow = DDL('ALTER TABLE users SET secretpowers FALSE') - connection.execute(drop_spow) - """ - - def __init__(self, statement, on=None, context=None, bind=None): - """Create a DDL statement. - - statement - A string or unicode string to be executed. Statements will be - processed with Python's string formatting operator. See the - ``context`` argument and the ``execute_at`` method. - - A literal '%' in a statement must be escaped as '%%'. - - SQL bind parameters are not available in DDL statements. - - on - Optional filtering criteria. May be a string or a callable - predicate. If a string, it will be compared to the name of the - executing database dialect:: - - DDL('something', on='postgres') - - If a callable, it will be invoked with three positional arguments: - - event - The name of the event that has triggered this DDL, such as - 'after-create' Will be None if the DDL is executed explicitly. - - schema_item - A SchemaItem instance, such as ``Table`` or ``MetaData``. May be - None if the DDL is executed explicitly. - - connection - The ``Connection`` being used for DDL execution - - If the callable returns a true value, the DDL statement will be - executed. - - context - Optional dictionary, defaults to None. These values will be - available for use in string substitutions on the DDL statement. - - bind - Optional. A :class:`~sqlalchemy.engine.base.Connectable`, used by default when ``execute()`` - is invoked without a bind argument. - - """ - - if not isinstance(statement, basestring): - raise exc.ArgumentError( - "Expected a string or unicode SQL statement, got '%r'" % - statement) - if (on is not None and - (not isinstance(on, basestring) and not util.callable(on))): - raise exc.ArgumentError( - "Expected the name of a database dialect or a callable for " - "'on' criteria, got type '%s'." % type(on).__name__) - - self.statement = statement - self.on = on - self.context = context or {} - self._bind = bind +class DDLElement(expression.ClauseElement): + """Base class for DDL expression constructs.""" + + supports_execution = True + _autocommit = True + schema_item = None + on = None + def execute(self, bind=None, schema_item=None): """Execute this DDL immediately. @@ -2010,10 +1898,9 @@ class DDL(object): if bind is None: bind = _bind_or_error(self) - # no SQL bind params are supported + if self._should_execute(None, schema_item, bind): - executable = expression.text(self._expand(schema_item, bind)) - return bind.execute(executable) + return bind.execute(self.against(schema_item)) else: bind.engine.logger.info("DDL execution skipped, criteria not met.") @@ -2025,7 +1912,7 @@ class DDL(object): will be executed using the same Connection and transactional context as the Table create/drop itself. The ``.bind`` property of this statement is ignored. - + event One of the events defined in the schema item's ``.ddl_events``; e.g. 'before-create', 'after-create', 'before-drop' or 'after-drop' @@ -2066,65 +1953,143 @@ class DDL(object): schema_item.ddl_listeners[event].append(self) return self - def bind(self): - """An Engine or Connection to which this DDL is bound. - - This property may be assigned an ``Engine`` or ``Connection``, or - assigned a string or URL to automatically create a basic ``Engine`` - for this bind with ``create_engine()``. - """ - return self._bind - - def _bind_to(self, bind): - """Bind this MetaData to an Engine, Connection, string or URL.""" - - global URL - if URL is None: - from sqlalchemy.engine.url import URL + @expression._generative + def against(self, schema_item): + """Return a copy of this DDL against a specific schema item.""" - if isinstance(bind, (basestring, URL)): - from sqlalchemy import create_engine - self._bind = create_engine(bind) - else: - self._bind = bind - bind = property(bind, _bind_to) + self.schema_item = schema_item - def __call__(self, event, schema_item, bind): + def __call__(self, event, schema_item, bind, **kw): """Execute the DDL as a ddl_listener.""" - if self._should_execute(event, schema_item, bind): - statement = expression.text(self._expand(schema_item, bind)) - return bind.execute(statement) + if self._should_execute(event, schema_item, bind, **kw): + return bind.execute(self.against(schema_item)) - def _expand(self, schema_item, bind): - return self.statement % self._prepare_context(schema_item, bind) + def _check_ddl_on(self, on): + if (on is not None and + (not isinstance(on, (basestring, tuple, list, set)) and not util.callable(on))): + raise exc.ArgumentError( + "Expected the name of a database dialect, a tuple of names, or a callable for " + "'on' criteria, got type '%s'." % type(on).__name__) - def _should_execute(self, event, schema_item, bind): + def _should_execute(self, event, schema_item, bind, **kw): if self.on is None: return True elif isinstance(self.on, basestring): return self.on == bind.engine.name + elif isinstance(self.on, (tuple, list, set)): + return bind.engine.name in self.on else: - return self.on(event, schema_item, bind) + return self.on(event, schema_item, bind, **kw) - def _prepare_context(self, schema_item, bind): - # table events can substitute table and schema name - if isinstance(schema_item, Table): - context = self.context.copy() + def bind(self): + if self._bind: + return self._bind + def _set_bind(self, bind): + self._bind = bind + bind = property(bind, _set_bind) - preparer = bind.dialect.identifier_preparer - path = preparer.format_table_seq(schema_item) - if len(path) == 1: - table, schema = path[0], '' - else: - table, schema = path[-1], path[0] + def _generate(self): + s = self.__class__.__new__(self.__class__) + s.__dict__ = self.__dict__.copy() + return s + + def _compiler(self, dialect, **kw): + """Return a compiler appropriate for this ClauseElement, given a Dialect.""" + + return dialect.ddl_compiler(dialect, self, **kw) + +class DDL(DDLElement): + """A literal DDL statement. + + Specifies literal SQL DDL to be executed by the database. DDL objects can + be attached to ``Tables`` or ``MetaData`` instances, conditionally + executing SQL as part of the DDL lifecycle of those schema items. Basic + templating support allows a single DDL instance to handle repetitive tasks + for multiple tables. + + Examples:: + + tbl = Table('users', metadata, Column('uid', Integer)) # ... + DDL('DROP TRIGGER users_trigger').execute_at('before-create', tbl) + + spow = DDL('ALTER TABLE %(table)s SET secretpowers TRUE', on='somedb') + spow.execute_at('after-create', tbl) + + drop_spow = DDL('ALTER TABLE users SET secretpowers FALSE') + connection.execute(drop_spow) + """ + + __visit_name__ = "ddl" + + def __init__(self, statement, on=None, context=None, bind=None): + """Create a DDL statement. + + statement + A string or unicode string to be executed. Statements will be + processed with Python's string formatting operator. See the + ``context`` argument and the ``execute_at`` method. + + A literal '%' in a statement must be escaped as '%%'. + + SQL bind parameters are not available in DDL statements. + + on + Optional filtering criteria. May be a string, tuple or a callable + predicate. If a string, it will be compared to the name of the + executing database dialect:: + + DDL('something', on='postgresql') + + If a tuple, specifies multiple dialect names: + + DDL('something', on=('postgresql', 'mysql')) + + If a callable, it will be invoked with three positional arguments + as well as optional keyword arguments: + + event + The name of the event that has triggered this DDL, such as + 'after-create' Will be None if the DDL is executed explicitly. + + schema_item + A SchemaItem instance, such as ``Table`` or ``MetaData``. May be + None if the DDL is executed explicitly. + + connection + The ``Connection`` being used for DDL execution + + **kw + Keyword arguments which may be sent include: + tables - a list of Table objects which are to be created/ + dropped within a MetaData.create_all() or drop_all() method + call. + + If the callable returns a true value, the DDL statement will be + executed. + + context + Optional dictionary, defaults to None. These values will be + available for use in string substitutions on the DDL statement. + + bind + Optional. A :class:`~sqlalchemy.engine.base.Connectable`, used by default when ``execute()`` + is invoked without a bind argument. + + """ + + if not isinstance(statement, basestring): + raise exc.ArgumentError( + "Expected a string or unicode SQL statement, got '%r'" % + statement) + + self.statement = statement + self.context = context or {} + + self._check_ddl_on(on) + self.on = on + self._bind = bind - context.setdefault('table', table) - context.setdefault('schema', schema) - context.setdefault('fullname', preparer.format_table(schema_item)) - return context - else: - return self.context def __repr__(self): return '<%s@%s; %s>' % ( @@ -2135,12 +2100,81 @@ class DDL(object): if getattr(self, key)])) def _to_schema_column(element): - if hasattr(element, '__clause_element__'): - element = element.__clause_element__() - if not isinstance(element, Column): - raise exc.ArgumentError("schema.Column object expected") - return element + if hasattr(element, '__clause_element__'): + element = element.__clause_element__() + if not isinstance(element, Column): + raise exc.ArgumentError("schema.Column object expected") + return element + +def _to_schema_column_or_string(element): + if hasattr(element, '__clause_element__'): + element = element.__clause_element__() + return element + +class _CreateDropBase(DDLElement): + """Base class for DDL constucts that represent CREATE and DROP or equivalents. + + The common theme of _CreateDropBase is a single + ``element`` attribute which refers to the element + to be created or dropped. + """ + + def __init__(self, element, on=None, bind=None): + self.element = element + self._check_ddl_on(on) + self.on = on + self.bind = bind + element.inline_ddl = False + +class CreateTable(_CreateDropBase): + """Represent a CREATE TABLE statement.""" + + __visit_name__ = "create_table" + +class DropTable(_CreateDropBase): + """Represent a DROP TABLE statement.""" + + __visit_name__ = "drop_table" + + def __init__(self, element, cascade=False, **kw): + self.cascade = cascade + super(DropTable, self).__init__(element, **kw) + +class CreateSequence(_CreateDropBase): + """Represent a CREATE SEQUENCE statement.""" + + __visit_name__ = "create_sequence" + +class DropSequence(_CreateDropBase): + """Represent a DROP SEQUENCE statement.""" + + __visit_name__ = "drop_sequence" + +class CreateIndex(_CreateDropBase): + """Represent a CREATE INDEX statement.""" + + __visit_name__ = "create_index" + +class DropIndex(_CreateDropBase): + """Represent a DROP INDEX statement.""" + + __visit_name__ = "drop_index" + +class AddConstraint(_CreateDropBase): + """Represent an ALTER TABLE ADD CONSTRAINT statement.""" + + __visit_name__ = "add_constraint" + +class DropConstraint(_CreateDropBase): + """Represent an ALTER TABLE DROP CONSTRAINT statement.""" + + __visit_name__ = "drop_constraint" + + def __init__(self, element, cascade=False, **kw): + self.cascade = cascade + super(DropConstraint, self).__init__(element, **kw) + def _bind_or_error(schemaitem): bind = schemaitem.bind if not bind: diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 6af65ec140..6bfad4a76c 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -6,19 +6,23 @@ """Base SQL and DDL compiler implementations. -Provides the :class:`~sqlalchemy.sql.compiler.DefaultCompiler` class, which is -responsible for generating all SQL query strings, as well as -:class:`~sqlalchemy.sql.compiler.SchemaGenerator` and :class:`~sqlalchemy.sql.compiler.SchemaDropper` -which issue CREATE and DROP DDL for tables, sequences, and indexes. - -The elements in this module are used by public-facing constructs like -:class:`~sqlalchemy.sql.expression.ClauseElement` and :class:`~sqlalchemy.engine.Engine`. -While dialect authors will want to be familiar with this module for the purpose of -creating database-specific compilers and schema generators, the module -is otherwise internal to SQLAlchemy. +Classes provided include: + +:class:`~sqlalchemy.sql.compiler.SQLCompiler` - renders SQL +strings + +:class:`~sqlalchemy.sql.compiler.DDLCompiler` - renders DDL +(data definition language) strings + +:class:`~sqlalchemy.sql.compiler.GenericTypeCompiler` - renders +type specification strings. + +To generate user-defined SQL strings, see +:module:`~sqlalchemy.ext.compiler`. + """ -import string, re +import re from sqlalchemy import schema, engine, util, exc from sqlalchemy.sql import operators, functions, util as sql_util, visitors from sqlalchemy.sql import expression as sql @@ -58,40 +62,43 @@ BIND_TEMPLATES = { OPERATORS = { - operators.and_ : 'AND', - operators.or_ : 'OR', - operators.inv : 'NOT', - operators.add : '+', - operators.mul : '*', - operators.sub : '-', - operators.div : '/', - operators.mod : '%', - operators.truediv : '/', - operators.lt : '<', - operators.le : '<=', - operators.ne : '!=', - operators.gt : '>', - operators.ge : '>=', - operators.eq : '=', - operators.distinct_op : 'DISTINCT', - operators.concat_op : '||', - operators.like_op : lambda x, y, escape=None: '%s LIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''), - operators.notlike_op : lambda x, y, escape=None: '%s NOT LIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''), - operators.ilike_op : lambda x, y, escape=None: "lower(%s) LIKE lower(%s)" % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''), - operators.notilike_op : lambda x, y, escape=None: "lower(%s) NOT LIKE lower(%s)" % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''), - operators.between_op : 'BETWEEN', - operators.match_op : 'MATCH', - operators.in_op : 'IN', - operators.notin_op : 'NOT IN', + # binary + operators.and_ : ' AND ', + operators.or_ : ' OR ', + operators.add : ' + ', + operators.mul : ' * ', + operators.sub : ' - ', + # Py2K + operators.div : ' / ', + # end Py2K + operators.mod : ' % ', + operators.truediv : ' / ', + operators.lt : ' < ', + operators.le : ' <= ', + operators.ne : ' != ', + operators.gt : ' > ', + operators.ge : ' >= ', + operators.eq : ' = ', + operators.concat_op : ' || ', + operators.between_op : ' BETWEEN ', + operators.match_op : ' MATCH ', + operators.in_op : ' IN ', + operators.notin_op : ' NOT IN ', operators.comma_op : ', ', - operators.desc_op : 'DESC', - operators.asc_op : 'ASC', - operators.from_ : 'FROM', - operators.as_ : 'AS', - operators.exists : 'EXISTS', - operators.is_ : 'IS', - operators.isnot : 'IS NOT', - operators.collate : 'COLLATE', + operators.from_ : ' FROM ', + operators.as_ : ' AS ', + operators.is_ : ' IS ', + operators.isnot : ' IS NOT ', + operators.collate : ' COLLATE ', + + # unary + operators.exists : 'EXISTS ', + operators.distinct_op : 'DISTINCT ', + operators.inv : 'NOT ', + + # modifiers + operators.desc_op : ' DESC', + operators.asc_op : ' ASC', } FUNCTIONS = { @@ -140,7 +147,7 @@ class _CompileLabel(visitors.Visitable): def quote(self): return self.element.quote -class DefaultCompiler(engine.Compiled): +class SQLCompiler(engine.Compiled): """Default implementation of Compiled. Compiles ClauseElements into SQL strings. Uses a similar visit @@ -148,14 +155,14 @@ class DefaultCompiler(engine.Compiled): """ - operators = OPERATORS - functions = FUNCTIONS extract_map = EXTRACT_MAP - # if we are insert/update/delete. - # set to true when we visit an INSERT, UPDATE or DELETE + # class-level defaults which can be set at the instance + # level to define if this Compiled instance represents + # INSERT/UPDATE/DELETE isdelete = isinsert = isupdate = False - + returning = None + def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs): """Construct a new ``DefaultCompiler`` object. @@ -170,7 +177,9 @@ class DefaultCompiler(engine.Compiled): statement. """ - engine.Compiled.__init__(self, dialect, statement, column_keys, **kwargs) + engine.Compiled.__init__(self, dialect, statement, **kwargs) + + self.column_keys = column_keys # compile INSERT/UPDATE defaults/sequences inlined (no pre-execute) self.inline = inline or getattr(statement, 'inline', False) @@ -210,12 +219,6 @@ class DefaultCompiler(engine.Compiled): # or dialect.max_identifier_length self.truncated_names = {} - def compile(self): - self.string = self.process(self.statement) - - def process(self, obj, **kwargs): - return obj._compiler_dispatch(self, **kwargs) - def is_subquery(self): return len(self.stack) > 1 @@ -223,7 +226,6 @@ class DefaultCompiler(engine.Compiled): """return a dictionary of bind parameter keys and values""" if params: - params = util.column_dict(params) pd = {} for bindparam, name in self.bind_names.iteritems(): for paramname in (bindparam.key, bindparam.shortname, name): @@ -245,7 +247,10 @@ class DefaultCompiler(engine.Compiled): pd[self.bind_names[bindparam]] = bindparam.value return pd - params = property(construct_params) + params = property(construct_params, doc=""" + Return the bind params for this compiled object. + + """) def default_from(self): """Called when a SELECT statement has no froms, and no FROM clause is to be appended. @@ -267,10 +272,11 @@ class DefaultCompiler(engine.Compiled): self._truncated_identifier("colident", label.name) or label.name if result_map is not None: - result_map[labelname.lower()] = (label.name, (label, label.element, labelname), label.element.type) + result_map[labelname.lower()] = \ + (label.name, (label, label.element, labelname), label.element.type) - return self.process(label.element) + " " + \ - self.operator_string(operators.as_) + " " + \ + return self.process(label.element) + \ + OPERATORS[operators.as_] + \ self.preparer.format_label(label, labelname) else: return self.process(label.element) @@ -292,14 +298,17 @@ class DefaultCompiler(engine.Compiled): return name else: if column.table.schema: - schema_prefix = self.preparer.quote_schema(column.table.schema, column.table.quote_schema) + '.' + schema_prefix = self.preparer.quote_schema( + column.table.schema, + column.table.quote_schema) + '.' else: schema_prefix = '' tablename = column.table.name tablename = isinstance(tablename, sql._generated_label) and \ self._truncated_identifier("alias", tablename) or tablename - return schema_prefix + self.preparer.quote(tablename, column.table.quote) + "." + name + return schema_prefix + \ + self.preparer.quote(tablename, column.table.quote) + "." + name def escape_literal_column(self, text): """provide escaping for the literal_column() construct.""" @@ -314,7 +323,7 @@ class DefaultCompiler(engine.Compiled): return index.name def visit_typeclause(self, typeclause, **kwargs): - return typeclause.type.dialect_impl(self.dialect).get_col_spec() + return self.dialect.type_compiler.process(typeclause.type) def post_process_text(self, text): return text @@ -343,10 +352,8 @@ class DefaultCompiler(engine.Compiled): sep = clauselist.operator if sep is None: sep = " " - elif sep is operators.comma_op: - sep = ', ' else: - sep = " " + self.operator_string(clauselist.operator) + " " + sep = OPERATORS[clauselist.operator] return sep.join(s for s in (self.process(c) for c in clauselist.clauses) if s is not None) @@ -362,7 +369,8 @@ class DefaultCompiler(engine.Compiled): return x def visit_cast(self, cast, **kwargs): - return "CAST(%s AS %s)" % (self.process(cast.clause), self.process(cast.typeclause)) + return "CAST(%s AS %s)" % \ + (self.process(cast.clause), self.process(cast.typeclause)) def visit_extract(self, extract, **kwargs): field = self.extract_map.get(extract.field, extract.field) @@ -372,26 +380,26 @@ class DefaultCompiler(engine.Compiled): if result_map is not None: result_map[func.name.lower()] = (func.name, None, func.type) - name = self.function_string(func) - - if util.callable(name): - return name(*[self.process(x) for x in func.clauses]) + disp = getattr(self, "visit_%s_func" % func.name.lower(), None) + if disp: + return disp(func, **kwargs) else: - return ".".join(func.packagenames + [name]) % {'expr':self.function_argspec(func)} + name = FUNCTIONS.get(func.__class__, func.name + "%(expr)s") + return ".".join(func.packagenames + [name]) % \ + {'expr':self.function_argspec(func, **kwargs)} def function_argspec(self, func, **kwargs): return self.process(func.clause_expr, **kwargs) - def function_string(self, func): - return self.functions.get(func.__class__, self.functions.get(func.name, func.name + "%(expr)s")) - def visit_compound_select(self, cs, asfrom=False, parens=True, **kwargs): entry = self.stack and self.stack[-1] or {} self.stack.append({'from':entry.get('from', None), 'iswrapper':True}) - text = string.join((self.process(c, asfrom=asfrom, parens=False, compound_index=i) - for i, c in enumerate(cs.selects)), - " " + cs.keyword + " ") + text = (" " + cs.keyword + " ").join( + (self.process(c, asfrom=asfrom, parens=False, compound_index=i) + for i, c in enumerate(cs.selects)) + ) + group_by = self.process(cs._group_by_clause, asfrom=asfrom) if group_by: text += " GROUP BY " + group_by @@ -408,27 +416,57 @@ class DefaultCompiler(engine.Compiled): def visit_unary(self, unary, **kw): s = self.process(unary.element, **kw) if unary.operator: - s = self.operator_string(unary.operator) + " " + s + s = OPERATORS[unary.operator] + s if unary.modifier: - s = s + " " + self.operator_string(unary.modifier) + s = s + OPERATORS[unary.modifier] return s def visit_binary(self, binary, **kwargs): - op = self.operator_string(binary.operator) - if util.callable(op): - return op(self.process(binary.left), self.process(binary.right), **binary.modifiers) - else: - return self.process(binary.left) + " " + op + " " + self.process(binary.right) + + return self._operator_dispatch(binary.operator, + binary, + lambda opstr: self.process(binary.left) + opstr + self.process(binary.right), + **kwargs + ) - def operator_string(self, operator): - return self.operators.get(operator, str(operator)) + def visit_like_op(self, binary, **kw): + escape = binary.modifiers.get("escape", None) + return '%s LIKE %s' % (self.process(binary.left), self.process(binary.right)) \ + + (escape and ' ESCAPE \'%s\'' % escape or '') + def visit_notlike_op(self, binary, **kw): + escape = binary.modifiers.get("escape", None) + return '%s NOT LIKE %s' % (self.process(binary.left), self.process(binary.right)) \ + + (escape and ' ESCAPE \'%s\'' % escape or '') + + def visit_ilike_op(self, binary, **kw): + escape = binary.modifiers.get("escape", None) + return 'lower(%s) LIKE lower(%s)' % (self.process(binary.left), self.process(binary.right)) \ + + (escape and ' ESCAPE \'%s\'' % escape or '') + + def visit_notilike_op(self, binary, **kw): + escape = binary.modifiers.get("escape", None) + return 'lower(%s) NOT LIKE lower(%s)' % (self.process(binary.left), self.process(binary.right)) \ + + (escape and ' ESCAPE \'%s\'' % escape or '') + + def _operator_dispatch(self, operator, element, fn, **kw): + if util.callable(operator): + disp = getattr(self, "visit_%s" % operator.__name__, None) + if disp: + return disp(element, **kw) + else: + return fn(OPERATORS[operator]) + else: + return fn(" " + operator + " ") + def visit_bindparam(self, bindparam, **kwargs): name = self._truncate_bindparam(bindparam) if name in self.binds: existing = self.binds[name] if existing is not bindparam and (existing.unique or bindparam.unique): - raise exc.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key) + raise exc.CompileError( + "Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key + ) self.binds[bindparam.key] = self.binds[name] = bindparam return self.bindparam_string(name) @@ -491,7 +529,7 @@ class DefaultCompiler(engine.Compiled): if isinstance(column, sql._Label): return column - if select.use_labels and column._label: + if select and select.use_labels and column._label: return _CompileLabel(column, column._label) if \ @@ -501,13 +539,15 @@ class DefaultCompiler(engine.Compiled): column.table is not None and \ not isinstance(column.table, sql.Select): return _CompileLabel(column, sql._generated_label(column.name)) - elif not isinstance(column, (sql._UnaryExpression, sql._TextClause, sql._BindParamClause)) \ + elif not isinstance(column, + (sql._UnaryExpression, sql._TextClause, sql._BindParamClause)) \ and (not hasattr(column, 'name') or isinstance(column, sql.Function)): return _CompileLabel(column, column.anon_label) else: return column - def visit_select(self, select, asfrom=False, parens=True, iswrapper=False, compound_index=1, **kwargs): + def visit_select(self, select, asfrom=False, parens=True, + iswrapper=False, compound_index=1, **kwargs): entry = self.stack and self.stack[-1] or {} @@ -583,8 +623,10 @@ class DefaultCompiler(engine.Compiled): return text def get_select_precolumns(self, select): - """Called when building a ``SELECT`` statement, position is just before column list.""" - + """Called when building a ``SELECT`` statement, position is just before + column list. + + """ return select._distinct and "DISTINCT " or "" def order_by_clause(self, select): @@ -613,14 +655,16 @@ class DefaultCompiler(engine.Compiled): def visit_table(self, table, asfrom=False, **kwargs): if asfrom: if getattr(table, "schema", None): - return self.preparer.quote_schema(table.schema, table.quote_schema) + "." + self.preparer.quote(table.name, table.quote) + return self.preparer.quote_schema(table.schema, table.quote_schema) + \ + "." + self.preparer.quote(table.name, table.quote) else: return self.preparer.quote(table.name, table.quote) else: return "" def visit_join(self, join, asfrom=False, **kwargs): - return (self.process(join.left, asfrom=True) + (join.isouter and " LEFT OUTER JOIN " or " JOIN ") + \ + return (self.process(join.left, asfrom=True) + \ + (join.isouter and " LEFT OUTER JOIN " or " JOIN ") + \ self.process(join.right, asfrom=True) + " ON " + self.process(join.onclause)) def visit_sequence(self, seq): @@ -629,41 +673,75 @@ class DefaultCompiler(engine.Compiled): def visit_insert(self, insert_stmt): self.isinsert = True colparams = self._get_colparams(insert_stmt) - preparer = self.preparer - - insert = ' '.join(["INSERT"] + - [self.process(x) for x in insert_stmt._prefixes]) - if not colparams and not self.dialect.supports_default_values and not self.dialect.supports_empty_insert: + if not colparams and \ + not self.dialect.supports_default_values and \ + not self.dialect.supports_empty_insert: raise exc.CompileError( "The version of %s you are using does not support empty inserts." % self.dialect.name) - elif not colparams and self.dialect.supports_default_values: - return (insert + " INTO %s DEFAULT VALUES" % ( - (preparer.format_table(insert_stmt.table),))) - else: - return (insert + " INTO %s (%s) VALUES (%s)" % - (preparer.format_table(insert_stmt.table), - ', '.join([preparer.format_column(c[0]) - for c in colparams]), - ', '.join([c[1] for c in colparams]))) + preparer = self.preparer + supports_default_values = self.dialect.supports_default_values + + text = "INSERT" + + prefixes = [self.process(x) for x in insert_stmt._prefixes] + if prefixes: + text += " " + " ".join(prefixes) + + text += " INTO " + preparer.format_table(insert_stmt.table) + + if colparams or not supports_default_values: + text += " (%s)" % ', '.join([preparer.format_column(c[0]) + for c in colparams]) + + if self.returning or insert_stmt._returning: + self.returning = self.returning or insert_stmt._returning + returning_clause = self.returning_clause(insert_stmt, self.returning) + + # cheating + if returning_clause.startswith("OUTPUT"): + text += " " + returning_clause + returning_clause = None + + if not colparams and supports_default_values: + text += " DEFAULT VALUES" + else: + text += " VALUES (%s)" % \ + ', '.join([c[1] for c in colparams]) + + if self.returning and returning_clause: + text += " " + returning_clause + + return text + def visit_update(self, update_stmt): self.stack.append({'from': set([update_stmt.table])}) self.isupdate = True colparams = self._get_colparams(update_stmt) - text = ' '.join(( - "UPDATE", - self.preparer.format_table(update_stmt.table), - 'SET', - ', '.join(self.preparer.quote(c[0].name, c[0].quote) + '=' + c[1] - for c in colparams) - )) - + text = "UPDATE " + self.preparer.format_table(update_stmt.table) + + text += ' SET ' + \ + ', '.join( + self.preparer.quote(c[0].name, c[0].quote) + '=' + c[1] + for c in colparams + ) + + if update_stmt._returning: + self.returning = update_stmt._returning + returning_clause = self.returning_clause(update_stmt, update_stmt._returning) + if returning_clause.startswith("OUTPUT"): + text += " " + returning_clause + returning_clause = None + if update_stmt._whereclause: text += " WHERE " + self.process(update_stmt._whereclause) + if self.returning and returning_clause: + text += " " + returning_clause + self.stack.pop(-1) return text @@ -681,7 +759,8 @@ class DefaultCompiler(engine.Compiled): self.postfetch = [] self.prefetch = [] - + self.returning = [] + # no parameters in the statement, no parameters in the # compiled params - return binds for all columns if self.column_keys is None and stmt.parameters is None: @@ -701,6 +780,15 @@ class DefaultCompiler(engine.Compiled): # create a list of column assignment clauses as tuples values = [] + + need_pks = self.isinsert and \ + not self.inline and \ + not self.statement._returning + + implicit_returning = need_pks and \ + self.dialect.implicit_returning and \ + stmt.table.implicit_returning + for c in stmt.table.columns: if c.key in parameters: value = parameters[c.key] @@ -710,19 +798,48 @@ class DefaultCompiler(engine.Compiled): self.postfetch.append(c) value = self.process(value.self_group()) values.append((c, value)) + elif isinstance(c, schema.Column): if self.isinsert: - if (c.primary_key and self.dialect.preexecute_pk_sequences and not self.inline): - if (((isinstance(c.default, schema.Sequence) and - not c.default.optional) or - not self.dialect.supports_pk_autoincrement) or - (c.default is not None and - not isinstance(c.default, schema.Sequence))): - values.append((c, create_bind_param(c, None))) - self.prefetch.append(c) + if c.primary_key and \ + need_pks and \ + ( + c is not stmt.table._autoincrement_column or + not self.dialect.postfetch_lastrowid + ): + + if implicit_returning: + if isinstance(c.default, schema.Sequence): + proc = self.process(c.default) + if proc is not None: + values.append((c, proc)) + self.returning.append(c) + elif isinstance(c.default, schema.ColumnDefault) and \ + isinstance(c.default.arg, sql.ClauseElement): + values.append((c, self.process(c.default.arg.self_group()))) + self.returning.append(c) + elif c.default is not None: + values.append((c, create_bind_param(c, None))) + self.prefetch.append(c) + else: + self.returning.append(c) + else: + if ( + c.default is not None and \ + ( + self.dialect.supports_sequences or + not isinstance(c.default, schema.Sequence) + ) + ) or \ + self.dialect.preexecute_autoincrement_sequences: + + values.append((c, create_bind_param(c, None))) + self.prefetch.append(c) + elif isinstance(c.default, schema.ColumnDefault): if isinstance(c.default.arg, sql.ClauseElement): values.append((c, self.process(c.default.arg.self_group()))) + if not c.primary_key: # dont add primary key column to postfetch self.postfetch.append(c) @@ -759,9 +876,19 @@ class DefaultCompiler(engine.Compiled): text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table) + if delete_stmt._returning: + self.returning = delete_stmt._returning + returning_clause = self.returning_clause(delete_stmt, delete_stmt._returning) + if returning_clause.startswith("OUTPUT"): + text += " " + returning_clause + returning_clause = None + if delete_stmt._whereclause: text += " WHERE " + self.process(delete_stmt._whereclause) + if self.returning and returning_clause: + text += " " + returning_clause + self.stack.pop(-1) return text @@ -775,110 +902,146 @@ class DefaultCompiler(engine.Compiled): def visit_release_savepoint(self, savepoint_stmt): return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) - def __str__(self): - return self.string or '' - -class DDLBase(engine.SchemaIterator): - def find_alterables(self, tables): - alterables = [] - class FindAlterables(schema.SchemaVisitor): - def visit_foreign_key_constraint(self, constraint): - if constraint.use_alter and constraint.table in tables: - alterables.append(constraint) - findalterables = FindAlterables() - for table in tables: - for c in table.constraints: - findalterables.traverse(c) - return alterables - - def _validate_identifier(self, ident, truncate): - if truncate: - if len(ident) > self.dialect.max_identifier_length: - counter = getattr(self, 'counter', 0) - self.counter = counter + 1 - return ident[0:self.dialect.max_identifier_length - 6] + "_" + hex(self.counter)[2:] - else: - return ident - else: - self.dialect.validate_identifier(ident) - return ident - - -class SchemaGenerator(DDLBase): - def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs): - super(SchemaGenerator, self).__init__(connection, **kwargs) - self.checkfirst = checkfirst - self.tables = tables and set(tables) or None - self.preparer = dialect.identifier_preparer - self.dialect = dialect - def get_column_specification(self, column, first_pk=False): - raise NotImplementedError() +class DDLCompiler(engine.Compiled): + @property + def preparer(self): + return self.dialect.identifier_preparer - def _can_create(self, table): - self.dialect.validate_identifier(table.name) - if table.schema: - self.dialect.validate_identifier(table.schema) - return not self.checkfirst or not self.dialect.has_table(self.connection, table.name, schema=table.schema) + def construct_params(self, params=None): + return None + + def visit_ddl(self, ddl, **kwargs): + # table events can substitute table and schema name + context = ddl.context + if isinstance(ddl.schema_item, schema.Table): + context = context.copy() + + preparer = self.dialect.identifier_preparer + path = preparer.format_table_seq(ddl.schema_item) + if len(path) == 1: + table, sch = path[0], '' + else: + table, sch = path[-1], path[0] - def visit_metadata(self, metadata): - if self.tables: - tables = self.tables - else: - tables = metadata.tables.values() - collection = [t for t in sql_util.sort_tables(tables) if self._can_create(t)] - for table in collection: - self.traverse_single(table) - if self.dialect.supports_alter: - for alterable in self.find_alterables(collection): - self.add_foreignkey(alterable) - - def visit_table(self, table): - for listener in table.ddl_listeners['before-create']: - listener('before-create', table, self.connection) + context.setdefault('table', table) + context.setdefault('schema', sch) + context.setdefault('fullname', preparer.format_table(ddl.schema_item)) + + return ddl.statement % context - for column in table.columns: - if column.default is not None: - self.traverse_single(column.default) + def visit_create_table(self, create): + table = create.element + preparer = self.dialect.identifier_preparer - self.append("\n" + " ".join(['CREATE'] + - table._prefixes + + text = "\n" + " ".join(['CREATE'] + \ + table._prefixes + \ ['TABLE', - self.preparer.format_table(table), - "("])) + preparer.format_table(table), + "("]) separator = "\n" # if only one primary key, specify it along with the column first_pk = False for column in table.columns: - self.append(separator) + text += separator separator = ", \n" - self.append("\t" + self.get_column_specification(column, first_pk=column.primary_key and not first_pk)) + text += "\t" + self.get_column_specification(column, first_pk=column.primary_key and not first_pk) if column.primary_key: first_pk = True - for constraint in column.constraints: - self.traverse_single(constraint) + const = " ".join(self.process(constraint) for constraint in column.constraints) + if const: + text += " " + const # On some DB order is significant: visit PK first, then the # other constraints (engine.ReflectionTest.testbasic failed on FB2) if table.primary_key: - self.traverse_single(table.primary_key) - for constraint in [c for c in table.constraints if c is not table.primary_key]: - self.traverse_single(constraint) + text += ", \n\t" + self.process(table.primary_key) + + const = ", \n\t".join( + self.process(constraint) for constraint in table.constraints + if constraint is not table.primary_key + and constraint.inline_ddl + and (not self.dialect.supports_alter or not getattr(constraint, 'use_alter', False)) + ) + if const: + text += ", \n\t" + const + + text += "\n)%s\n\n" % self.post_create_table(table) + return text + + def visit_drop_table(self, drop): + ret = "\nDROP TABLE " + self.preparer.format_table(drop.element) + if drop.cascade: + ret += " CASCADE CONSTRAINTS" + return ret + + def visit_create_index(self, create): + index = create.element + preparer = self.preparer + text = "CREATE " + if index.unique: + text += "UNIQUE " + text += "INDEX %s ON %s (%s)" \ + % (preparer.quote(self._validate_identifier(index.name, True), index.quote), + preparer.format_table(index.table), + ', '.join(preparer.quote(c.name, c.quote) + for c in index.columns)) + return text - self.append("\n)%s\n\n" % self.post_create_table(table)) - self.execute() + def visit_drop_index(self, drop): + index = drop.element + return "\nDROP INDEX " + \ + self.preparer.quote(self._validate_identifier(index.name, False), index.quote) - if hasattr(table, 'indexes'): - for index in table.indexes: - self.traverse_single(index) + def visit_add_constraint(self, create): + preparer = self.preparer + return "ALTER TABLE %s ADD %s" % ( + self.preparer.format_table(create.element.table), + self.process(create.element) + ) + + def visit_drop_constraint(self, drop): + preparer = self.preparer + return "ALTER TABLE %s DROP CONSTRAINT %s%s" % ( + self.preparer.format_table(drop.element.table), + self.preparer.format_constraint(drop.element), + " CASCADE" if drop.cascade else "" + ) + + def get_column_specification(self, column, **kwargs): + colspec = self.preparer.format_column(column) + " " + \ + self.dialect.type_compiler.process(column.type) + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default - for listener in table.ddl_listeners['after-create']: - listener('after-create', table, self.connection) + if not column.nullable: + colspec += " NOT NULL" + return colspec def post_create_table(self, table): return '' + def _compile(self, tocompile, parameters): + """compile the given string/parameters using this SchemaGenerator's dialect.""" + + compiler = self.dialect.statement_compiler(self.dialect, tocompile, parameters) + compiler.compile() + return compiler + + def _validate_identifier(self, ident, truncate): + if truncate: + if len(ident) > self.dialect.max_identifier_length: + counter = getattr(self, 'counter', 0) + self.counter = counter + 1 + return ident[0:self.dialect.max_identifier_length - 6] + "_" + hex(self.counter)[2:] + else: + return ident + else: + self.dialect.validate_identifier(ident) + return ident + def get_column_default_string(self, column): if isinstance(column.server_default, schema.DefaultClause): if isinstance(column.server_default.arg, basestring): @@ -888,149 +1051,190 @@ class SchemaGenerator(DDLBase): else: return None - def _compile(self, tocompile, parameters): - """compile the given string/parameters using this SchemaGenerator's dialect.""" - compiler = self.dialect.statement_compiler(self.dialect, tocompile, parameters) - compiler.compile() - return compiler - def visit_check_constraint(self, constraint): - self.append(", \n\t") + text = "" if constraint.name is not None: - self.append("CONSTRAINT %s " % - self.preparer.format_constraint(constraint)) - self.append(" CHECK (%s)" % constraint.sqltext) - self.define_constraint_deferrability(constraint) + text += "CONSTRAINT %s " % \ + self.preparer.format_constraint(constraint) + text += " CHECK (%s)" % constraint.sqltext + text += self.define_constraint_deferrability(constraint) + return text def visit_column_check_constraint(self, constraint): - self.append(" CHECK (%s)" % constraint.sqltext) - self.define_constraint_deferrability(constraint) + text = " CHECK (%s)" % constraint.sqltext + text += self.define_constraint_deferrability(constraint) + return text def visit_primary_key_constraint(self, constraint): if len(constraint) == 0: - return - self.append(", \n\t") + return '' + text = "" if constraint.name is not None: - self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint)) - self.append("PRIMARY KEY ") - self.append("(%s)" % ', '.join(self.preparer.quote(c.name, c.quote) - for c in constraint)) - self.define_constraint_deferrability(constraint) + text += "CONSTRAINT %s " % self.preparer.format_constraint(constraint) + text += "PRIMARY KEY " + text += "(%s)" % ', '.join(self.preparer.quote(c.name, c.quote) + for c in constraint) + text += self.define_constraint_deferrability(constraint) + return text def visit_foreign_key_constraint(self, constraint): - if constraint.use_alter and self.dialect.supports_alter: - return - self.append(", \n\t ") - self.define_foreign_key(constraint) - - def add_foreignkey(self, constraint): - self.append("ALTER TABLE %s ADD " % self.preparer.format_table(constraint.table)) - self.define_foreign_key(constraint) - self.execute() - - def define_foreign_key(self, constraint): - preparer = self.preparer + preparer = self.dialect.identifier_preparer + text = "" if constraint.name is not None: - self.append("CONSTRAINT %s " % - preparer.format_constraint(constraint)) - table = list(constraint.elements)[0].column.table - self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % ( + text += "CONSTRAINT %s " % \ + preparer.format_constraint(constraint) + remote_table = list(constraint._elements.values())[0].column.table + text += "FOREIGN KEY(%s) REFERENCES %s (%s)" % ( ', '.join(preparer.quote(f.parent.name, f.parent.quote) - for f in constraint.elements), - preparer.format_table(table), + for f in constraint._elements.values()), + preparer.format_table(remote_table), ', '.join(preparer.quote(f.column.name, f.column.quote) - for f in constraint.elements) - )) - if constraint.ondelete is not None: - self.append(" ON DELETE %s" % constraint.ondelete) - if constraint.onupdate is not None: - self.append(" ON UPDATE %s" % constraint.onupdate) - self.define_constraint_deferrability(constraint) + for f in constraint._elements.values()) + ) + text += self.define_constraint_cascades(constraint) + text += self.define_constraint_deferrability(constraint) + return text def visit_unique_constraint(self, constraint): - self.append(", \n\t") + text = "" if constraint.name is not None: - self.append("CONSTRAINT %s " % - self.preparer.format_constraint(constraint)) - self.append(" UNIQUE (%s)" % (', '.join(self.preparer.quote(c.name, c.quote) for c in constraint))) - self.define_constraint_deferrability(constraint) + text += "CONSTRAINT %s " % self.preparer.format_constraint(constraint) + text += " UNIQUE (%s)" % (', '.join(self.preparer.quote(c.name, c.quote) for c in constraint)) + text += self.define_constraint_deferrability(constraint) + return text + def define_constraint_cascades(self, constraint): + text = "" + if constraint.ondelete is not None: + text += " ON DELETE %s" % constraint.ondelete + if constraint.onupdate is not None: + text += " ON UPDATE %s" % constraint.onupdate + return text + def define_constraint_deferrability(self, constraint): + text = "" if constraint.deferrable is not None: if constraint.deferrable: - self.append(" DEFERRABLE") + text += " DEFERRABLE" else: - self.append(" NOT DEFERRABLE") + text += " NOT DEFERRABLE" if constraint.initially is not None: - self.append(" INITIALLY %s" % constraint.initially) + text += " INITIALLY %s" % constraint.initially + return text + + +class GenericTypeCompiler(engine.TypeCompiler): + def visit_CHAR(self, type_): + return "CHAR" + (type_.length and "(%d)" % type_.length or "") - def visit_column(self, column): - pass + def visit_NCHAR(self, type_): + return "NCHAR" + (type_.length and "(%d)" % type_.length or "") + + def visit_FLOAT(self, type_): + return "FLOAT" - def visit_index(self, index): - preparer = self.preparer - self.append("CREATE ") - if index.unique: - self.append("UNIQUE ") - self.append("INDEX %s ON %s (%s)" \ - % (preparer.quote(self._validate_identifier(index.name, True), index.quote), - preparer.format_table(index.table), - ', '.join(preparer.quote(c.name, c.quote) - for c in index.columns))) - self.execute() + def visit_NUMERIC(self, type_): + if type_.precision is None: + return "NUMERIC" + else: + return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': type_.precision, 'scale' : type_.scale} + def visit_DECIMAL(self, type_): + return "DECIMAL" + + def visit_INTEGER(self, type_): + return "INTEGER" -class SchemaDropper(DDLBase): - def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs): - super(SchemaDropper, self).__init__(connection, **kwargs) - self.checkfirst = checkfirst - self.tables = tables - self.preparer = dialect.identifier_preparer - self.dialect = dialect + def visit_SMALLINT(self, type_): + return "SMALLINT" - def visit_metadata(self, metadata): - if self.tables: - tables = self.tables - else: - tables = metadata.tables.values() - collection = [t for t in reversed(sql_util.sort_tables(tables)) if self._can_drop(t)] - if self.dialect.supports_alter: - for alterable in self.find_alterables(collection): - self.drop_foreignkey(alterable) - for table in collection: - self.traverse_single(table) - - def _can_drop(self, table): - self.dialect.validate_identifier(table.name) - if table.schema: - self.dialect.validate_identifier(table.schema) - return not self.checkfirst or self.dialect.has_table(self.connection, table.name, schema=table.schema) - - def visit_index(self, index): - self.append("\nDROP INDEX " + self.preparer.quote(self._validate_identifier(index.name, False), index.quote)) - self.execute() - - def drop_foreignkey(self, constraint): - self.append("ALTER TABLE %s DROP CONSTRAINT %s" % ( - self.preparer.format_table(constraint.table), - self.preparer.format_constraint(constraint))) - self.execute() - - def visit_table(self, table): - for listener in table.ddl_listeners['before-drop']: - listener('before-drop', table, self.connection) + def visit_BIGINT(self, type_): + return "BIGINT" - for column in table.columns: - if column.default is not None: - self.traverse_single(column.default) + def visit_TIMESTAMP(self, type_): + return 'TIMESTAMP' + + def visit_DATETIME(self, type_): + return "DATETIME" + + def visit_DATE(self, type_): + return "DATE" + + def visit_TIME(self, type_): + return "TIME" + + def visit_CLOB(self, type_): + return "CLOB" - self.append("\nDROP TABLE " + self.preparer.format_table(table)) - self.execute() + def visit_NCLOB(self, type_): + return "NCLOB" - for listener in table.ddl_listeners['after-drop']: - listener('after-drop', table, self.connection) + def visit_VARCHAR(self, type_): + return "VARCHAR" + (type_.length and "(%d)" % type_.length or "") + def visit_NVARCHAR(self, type_): + return "NVARCHAR" + (type_.length and "(%d)" % type_.length or "") + def visit_BLOB(self, type_): + return "BLOB" + + def visit_BOOLEAN(self, type_): + return "BOOLEAN" + + def visit_TEXT(self, type_): + return "TEXT" + + def visit_binary(self, type_): + return self.visit_BLOB(type_) + + def visit_boolean(self, type_): + return self.visit_BOOLEAN(type_) + + def visit_time(self, type_): + return self.visit_TIME(type_) + + def visit_datetime(self, type_): + return self.visit_DATETIME(type_) + + def visit_date(self, type_): + return self.visit_DATE(type_) + + def visit_big_integer(self, type_): + return self.visit_BIGINT(type_) + + def visit_small_integer(self, type_): + return self.visit_SMALLINT(type_) + + def visit_integer(self, type_): + return self.visit_INTEGER(type_) + + def visit_float(self, type_): + return self.visit_FLOAT(type_) + + def visit_numeric(self, type_): + return self.visit_NUMERIC(type_) + + def visit_string(self, type_): + return self.visit_VARCHAR(type_) + + def visit_unicode(self, type_): + return self.visit_VARCHAR(type_) + + def visit_text(self, type_): + return self.visit_TEXT(type_) + + def visit_unicode_text(self, type_): + return self.visit_TEXT(type_) + + def visit_null(self, type_): + raise NotImplementedError("Can't generate DDL for the null type") + + def visit_type_decorator(self, type_): + return self.process(type_.type_engine(self.dialect)) + + def visit_user_defined(self, type_): + return type_.get_col_spec() + class IdentifierPreparer(object): """Handle quoting and case-folding of identifiers based on options.""" @@ -1176,24 +1380,24 @@ class IdentifierPreparer(object): else: return (self.format_table(table, use_schema=False), ) + @util.memoized_property + def _r_identifiers(self): + initial, final, escaped_final = \ + [re.escape(s) for s in + (self.initial_quote, self.final_quote, + self._escape_identifier(self.final_quote))] + r = re.compile( + r'(?:' + r'(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s' + r'|([^\.]+))(?=\.|$))+' % + { 'initial': initial, + 'final': final, + 'escaped': escaped_final }) + return r + def unformat_identifiers(self, identifiers): """Unpack 'schema.table.column'-like strings into components.""" - try: - r = self._r_identifiers - except AttributeError: - initial, final, escaped_final = \ - [re.escape(s) for s in - (self.initial_quote, self.final_quote, - self._escape_identifier(self.final_quote))] - r = re.compile( - r'(?:' - r'(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s' - r'|([^\.]+))(?=\.|$))+' % - { 'initial': initial, - 'final': final, - 'escaped': escaped_final }) - self._r_identifiers = r - + r = self._r_identifiers return [self._unescape_identifier(i) for i in [a or b for a, b in r.findall(identifiers)]] diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 83897ef051..91e0e74ae4 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -29,10 +29,9 @@ to stay the same in future releases. import itertools, re from operator import attrgetter -from sqlalchemy import util, exc +from sqlalchemy import util, exc, types as sqltypes from sqlalchemy.sql import operators from sqlalchemy.sql.visitors import Visitable, cloned_traverse -from sqlalchemy import types as sqltypes import operator functions, schema, sql_util = None, None, None @@ -128,7 +127,8 @@ def select(columns=None, whereclause=None, from_obj=[], **kwargs): Similar functionality is also available via the ``select()`` method on any :class:`~sqlalchemy.sql.expression.FromClause`. - The returned object is an instance of :class:`~sqlalchemy.sql.expression.Select`. + The returned object is an instance of + :class:`~sqlalchemy.sql.expression.Select`. All arguments which accept ``ClauseElement`` arguments also accept string arguments, which will be converted as appropriate into @@ -241,7 +241,8 @@ def select(columns=None, whereclause=None, from_obj=[], **kwargs): """ if 'scalar' in kwargs: - util.warn_deprecated('scalar option is deprecated; see docs for details') + util.warn_deprecated( + 'scalar option is deprecated; see docs for details') scalar = kwargs.pop('scalar', False) s = Select(columns, whereclause=whereclause, from_obj=from_obj, **kwargs) if scalar: @@ -250,15 +251,16 @@ def select(columns=None, whereclause=None, from_obj=[], **kwargs): return s def subquery(alias, *args, **kwargs): - """Return an :class:`~sqlalchemy.sql.expression.Alias` object derived from a :class:`~sqlalchemy.sql.expression.Select`. + """Return an :class:`~sqlalchemy.sql.expression.Alias` object derived + from a :class:`~sqlalchemy.sql.expression.Select`. name alias name \*args, \**kwargs - all other arguments are delivered to the :func:`~sqlalchemy.sql.expression.select` - function. + all other arguments are delivered to the + :func:`~sqlalchemy.sql.expression.select` function. """ return Select(*args, **kwargs).alias(alias) @@ -280,12 +282,12 @@ def insert(table, values=None, inline=False, **kwargs): table columns. Note that the :meth:`~Insert.values()` generative method may also be used for this. - :param prefixes: A list of modifier keywords to be inserted between INSERT and INTO. - Alternatively, the :meth:`~Insert.prefix_with` generative method may be used. + :param prefixes: A list of modifier keywords to be inserted between INSERT + and INTO. Alternatively, the :meth:`~Insert.prefix_with` generative method + may be used. - :param inline: - if True, SQL defaults will be compiled 'inline' into the statement - and not pre-executed. + :param inline: if True, SQL defaults will be compiled 'inline' into the + statement and not pre-executed. If both `values` and compile-time bind parameters are present, the compile-time bind parameters override the information specified @@ -313,9 +315,9 @@ def update(table, whereclause=None, values=None, inline=False, **kwargs): :param table: The table to be updated. - :param whereclause: A ``ClauseElement`` describing the ``WHERE`` condition of the - ``UPDATE`` statement. Note that the :meth:`~Update.where()` generative - method may also be used for this. + :param whereclause: A ``ClauseElement`` describing the ``WHERE`` condition + of the ``UPDATE`` statement. Note that the :meth:`~Update.where()` + generative method may also be used for this. :param values: A dictionary which specifies the ``SET`` conditions of the @@ -347,7 +349,12 @@ def update(table, whereclause=None, values=None, inline=False, **kwargs): against the ``UPDATE`` statement. """ - return Update(table, whereclause=whereclause, values=values, inline=inline, **kwargs) + return Update( + table, + whereclause=whereclause, + values=values, + inline=inline, + **kwargs) def delete(table, whereclause = None, **kwargs): """Return a :class:`~sqlalchemy.sql.expression.Delete` clause element. @@ -357,9 +364,9 @@ def delete(table, whereclause = None, **kwargs): :param table: The table to be updated. - :param whereclause: A :class:`ClauseElement` describing the ``WHERE`` condition of the - ``UPDATE`` statement. Note that the :meth:`~Delete.where()` generative method - may be used instead. + :param whereclause: A :class:`ClauseElement` describing the ``WHERE`` + condition of the ``UPDATE`` statement. Note that the :meth:`~Delete.where()` + generative method may be used instead. """ return Delete(table, whereclause, **kwargs) @@ -368,8 +375,8 @@ def and_(*clauses): """Join a list of clauses together using the ``AND`` operator. The ``&`` operator is also overloaded on all - :class:`~sqlalchemy.sql.expression._CompareMixin` subclasses to produce the same - result. + :class:`~sqlalchemy.sql.expression._CompareMixin` subclasses to produce the + same result. """ if len(clauses) == 1: @@ -380,8 +387,8 @@ def or_(*clauses): """Join a list of clauses together using the ``OR`` operator. The ``|`` operator is also overloaded on all - :class:`~sqlalchemy.sql.expression._CompareMixin` subclasses to produce the same - result. + :class:`~sqlalchemy.sql.expression._CompareMixin` subclasses to produce the + same result. """ if len(clauses) == 1: @@ -392,8 +399,8 @@ def not_(clause): """Return a negation of the given clause, i.e. ``NOT(clause)``. The ``~`` operator is also overloaded on all - :class:`~sqlalchemy.sql.expression._CompareMixin` subclasses to produce the same - result. + :class:`~sqlalchemy.sql.expression._CompareMixin` subclasses to produce the + same result. """ return operators.inv(_literal_as_binds(clause)) @@ -408,8 +415,9 @@ def between(ctest, cleft, cright): Equivalent of SQL ``clausetest BETWEEN clauseleft AND clauseright``. - The ``between()`` method on all :class:`~sqlalchemy.sql.expression._CompareMixin` subclasses - provides similar functionality. + The ``between()`` method on all + :class:`~sqlalchemy.sql.expression._CompareMixin` subclasses provides + similar functionality. """ ctest = _literal_as_binds(ctest) @@ -517,7 +525,8 @@ def exists(*args, **kwargs): def union(*selects, **kwargs): """Return a ``UNION`` of multiple selectables. - The returned object is an instance of :class:`~sqlalchemy.sql.expression.CompoundSelect`. + The returned object is an instance of + :class:`~sqlalchemy.sql.expression.CompoundSelect`. A similar ``union()`` method is available on all :class:`~sqlalchemy.sql.expression.FromClause` subclasses. @@ -535,7 +544,8 @@ def union(*selects, **kwargs): def union_all(*selects, **kwargs): """Return a ``UNION ALL`` of multiple selectables. - The returned object is an instance of :class:`~sqlalchemy.sql.expression.CompoundSelect`. + The returned object is an instance of + :class:`~sqlalchemy.sql.expression.CompoundSelect`. A similar ``union_all()`` method is available on all :class:`~sqlalchemy.sql.expression.FromClause` subclasses. @@ -553,7 +563,8 @@ def union_all(*selects, **kwargs): def except_(*selects, **kwargs): """Return an ``EXCEPT`` of multiple selectables. - The returned object is an instance of :class:`~sqlalchemy.sql.expression.CompoundSelect`. + The returned object is an instance of + :class:`~sqlalchemy.sql.expression.CompoundSelect`. \*selects a list of :class:`~sqlalchemy.sql.expression.Select` instances. @@ -568,7 +579,8 @@ def except_(*selects, **kwargs): def except_all(*selects, **kwargs): """Return an ``EXCEPT ALL`` of multiple selectables. - The returned object is an instance of :class:`~sqlalchemy.sql.expression.CompoundSelect`. + The returned object is an instance of + :class:`~sqlalchemy.sql.expression.CompoundSelect`. \*selects a list of :class:`~sqlalchemy.sql.expression.Select` instances. @@ -583,7 +595,8 @@ def except_all(*selects, **kwargs): def intersect(*selects, **kwargs): """Return an ``INTERSECT`` of multiple selectables. - The returned object is an instance of :class:`~sqlalchemy.sql.expression.CompoundSelect`. + The returned object is an instance of + :class:`~sqlalchemy.sql.expression.CompoundSelect`. \*selects a list of :class:`~sqlalchemy.sql.expression.Select` instances. @@ -598,7 +611,8 @@ def intersect(*selects, **kwargs): def intersect_all(*selects, **kwargs): """Return an ``INTERSECT ALL`` of multiple selectables. - The returned object is an instance of :class:`~sqlalchemy.sql.expression.CompoundSelect`. + The returned object is an instance of + :class:`~sqlalchemy.sql.expression.CompoundSelect`. \*selects a list of :class:`~sqlalchemy.sql.expression.Select` instances. @@ -613,8 +627,8 @@ def intersect_all(*selects, **kwargs): def alias(selectable, alias=None): """Return an :class:`~sqlalchemy.sql.expression.Alias` object. - An ``Alias`` represents any :class:`~sqlalchemy.sql.expression.FromClause` with - an alternate name assigned within SQL, typically using the ``AS`` + An ``Alias`` represents any :class:`~sqlalchemy.sql.expression.FromClause` + with an alternate name assigned within SQL, typically using the ``AS`` clause when generated, e.g. ``SELECT * FROM table AS aliasname``. Similar functionality is available via the ``alias()`` method @@ -656,7 +670,8 @@ def literal(value, type_=None): return _BindParamClause(None, value, type_=type_, unique=True) def label(name, obj): - """Return a :class:`~sqlalchemy.sql.expression._Label` object for the given :class:`~sqlalchemy.sql.expression.ColumnElement`. + """Return a :class:`~sqlalchemy.sql.expression._Label` object for the given + :class:`~sqlalchemy.sql.expression.ColumnElement`. A label changes the name of an element in the columns clause of a ``SELECT`` statement, typically via the ``AS`` SQL keyword. @@ -674,11 +689,13 @@ def label(name, obj): return _Label(name, obj) def column(text, type_=None): - """Return a textual column clause, as would be in the columns clause of a ``SELECT`` statement. + """Return a textual column clause, as would be in the columns clause of a + ``SELECT`` statement. - The object returned is an instance of :class:`~sqlalchemy.sql.expression.ColumnClause`, - which represents the "syntactical" portion of the schema-level - :class:`~sqlalchemy.schema.Column` object. + The object returned is an instance of + :class:`~sqlalchemy.sql.expression.ColumnClause`, which represents the + "syntactical" portion of the schema-level :class:`~sqlalchemy.schema.Column` + object. text the name of the column. Quoting rules will be applied to the @@ -710,9 +727,9 @@ def literal_column(text, type_=None): :func:`~sqlalchemy.sql.expression.column` function. type\_ - an optional :class:`~sqlalchemy.types.TypeEngine` object which will provide - result-set translation and additional expression semantics for this - column. If left as None the type will be NullType. + an optional :class:`~sqlalchemy.types.TypeEngine` object which will + provide result-set translation and additional expression semantics for + this column. If left as None the type will be NullType. """ return ColumnClause(text, type_=type_, is_literal=True) @@ -752,7 +769,8 @@ def bindparam(key, value=None, shortname=None, type_=None, unique=False): return _BindParamClause(key, value, type_=type_, unique=unique, shortname=shortname) def outparam(key, type_=None): - """Create an 'OUT' parameter for usage in functions (stored procedures), for databases which support them. + """Create an 'OUT' parameter for usage in functions (stored procedures), for + databases which support them. The ``outparam`` can be used like a regular function parameter. The "output" value will be available from the @@ -760,7 +778,8 @@ def outparam(key, type_=None): attribute, which returns a dictionary containing the values. """ - return _BindParamClause(key, None, type_=type_, unique=False, isoutparam=True) + return _BindParamClause( + key, None, type_=type_, unique=False, isoutparam=True) def text(text, bind=None, *args, **kwargs): """Create literal text to be inserted into a query. @@ -803,8 +822,10 @@ def text(text, bind=None, *args, **kwargs): return _TextClause(text, bind=bind, *args, **kwargs) def null(): - """Return a :class:`_Null` object, which compiles to ``NULL`` in a sql statement.""" - + """Return a :class:`_Null` object, which compiles to ``NULL`` in a sql + statement. + + """ return _Null() class _FunctionGenerator(object): @@ -839,7 +860,8 @@ class _FunctionGenerator(object): if func is not None: return func(*c, **o) - return Function(self.__names[-1], packagenames=self.__names[0:-1], *c, **o) + return Function( + self.__names[-1], packagenames=self.__names[0:-1], *c, **o) # "func" global - i.e. func.count() func = _FunctionGenerator() @@ -861,10 +883,19 @@ def _clone(element): return element._clone() def _expand_cloned(elements): - """expand the given set of ClauseElements to be the set of all 'cloned' predecessors.""" - + """expand the given set of ClauseElements to be the set of all 'cloned' + predecessors. + + """ return itertools.chain(*[x._cloned_set for x in elements]) +def _select_iterables(elements): + """expand tables into individual columns in the + given list of column expressions. + + """ + return itertools.chain(*[c._select_iterable for c in elements]) + def _cloned_intersection(a, b): """return the intersection of sets a and b, counting any overlap between 'cloned' predecessors. @@ -879,7 +910,8 @@ def _compound_select(keyword, *selects, **kwargs): return CompoundSelect(keyword, *selects, **kwargs) def _is_literal(element): - return not isinstance(element, Visitable) and not hasattr(element, '__clause_element__') + return not isinstance(element, Visitable) and \ + not hasattr(element, '__clause_element__') def _from_objects(*elements): return itertools.chain(*[element._from_objects for element in elements]) @@ -940,27 +972,36 @@ def _no_literals(element): return element def _corresponding_column_or_error(fromclause, column, require_embedded=False): - c = fromclause.corresponding_column(column, require_embedded=require_embedded) + c = fromclause.corresponding_column(column, + require_embedded=require_embedded) if not c: - raise exc.InvalidRequestError("Given column '%s', attached to table '%s', " + raise exc.InvalidRequestError( + "Given column '%s', attached to table '%s', " "failed to locate a corresponding column from table '%s'" - % (column, getattr(column, 'table', None), fromclause.description)) + % + (column, + getattr(column, 'table', None),fromclause.description) + ) return c def is_column(col): """True if ``col`` is an instance of ``ColumnElement``.""" + return isinstance(col, ColumnElement) class ClauseElement(Visitable): - """Base class for elements of a programmatically constructed SQL expression.""" - + """Base class for elements of a programmatically constructed SQL + expression. + + """ __visit_name__ = 'clause' _annotations = {} supports_execution = False _from_objects = [] - + _bind = None + def _clone(self): """Create a shallow copy of this ClauseElement. @@ -984,7 +1025,8 @@ class ClauseElement(Visitable): @util.memoized_property def _cloned_set(self): - """Return the set consisting all cloned anscestors of this ClauseElement. + """Return the set consisting all cloned anscestors of this + ClauseElement. Includes this ClauseElement. This accessor tends to be used for FromClause objects to identify 'equivalent' FROM clauses, regardless @@ -1004,15 +1046,20 @@ class ClauseElement(Visitable): return d def _annotate(self, values): - """return a copy of this ClauseElement with the given annotations dictionary.""" - + """return a copy of this ClauseElement with the given annotations + dictionary. + + """ global Annotated if Annotated is None: from sqlalchemy.sql.util import Annotated return Annotated(self, values) def _deannotate(self): - """return a copy of this ClauseElement with an empty annotations dictionary.""" + """return a copy of this ClauseElement with an empty annotations + dictionary. + + """ return self._clone() def unique_params(self, *optionaldict, **kwargs): @@ -1044,7 +1091,8 @@ class ClauseElement(Visitable): if len(optionaldict) == 1: kwargs.update(optionaldict[0]) elif len(optionaldict) > 1: - raise exc.ArgumentError("params() takes zero or one positional dictionary argument") + raise exc.ArgumentError( + "params() takes zero or one positional dictionary argument") def visit_bindparam(bind): if bind.key in kwargs: @@ -1088,15 +1136,20 @@ class ClauseElement(Visitable): def self_group(self, against=None): return self + # TODO: remove .bind as a method from the root ClauseElement. + # we should only be deriving binds from FromClause elements + # and certain SchemaItem subclasses. + # the "search_for_bind" functionality can still be used by + # execute(), however. @property def bind(self): - """Returns the Engine or Connection to which this ClauseElement is bound, or None if none found.""" + """Returns the Engine or Connection to which this ClauseElement is + bound, or None if none found. + + """ + if self._bind is not None: + return self._bind - try: - if self._bind is not None: - return self._bind - except AttributeError: - pass for f in _from_objects(self): if f is self: continue @@ -1121,68 +1174,82 @@ class ClauseElement(Visitable): return e._execute_clauseelement(self, multiparams, params) def scalar(self, *multiparams, **params): - """Compile and execute this ``ClauseElement``, returning the result's scalar representation.""" - + """Compile and execute this ``ClauseElement``, returning the result's + scalar representation. + + """ return self.execute(*multiparams, **params).scalar() - def compile(self, bind=None, column_keys=None, compiler=None, dialect=None, inline=False): + def compile(self, bind=None, dialect=None, **kw): """Compile this SQL expression. The return value is a :class:`~sqlalchemy.engine.Compiled` object. - Calling `str()` or `unicode()` on the returned value will yield - a string representation of the result. The :class:`~sqlalchemy.engine.Compiled` - object also can return a dictionary of bind parameter names and - values using the `params` accessor. + Calling `str()` or `unicode()` on the returned value will yield a string + representation of the result. The :class:`~sqlalchemy.engine.Compiled` + object also can return a dictionary of bind parameter names and values + using the `params` accessor. :param bind: An ``Engine`` or ``Connection`` from which a ``Compiled`` will be acquired. This argument takes precedence over this ``ClauseElement``'s bound engine, if any. - :param column_keys: Used for INSERT and UPDATE statements, a list of - column names which should be present in the VALUES clause - of the compiled statement. If ``None``, all columns - from the target table object are rendered. - - :param compiler: A ``Compiled`` instance which will be used to compile - this expression. This argument takes precedence - over the `bind` and `dialect` arguments as well as - this ``ClauseElement``'s bound engine, if - any. - :param dialect: A ``Dialect`` instance frmo which a ``Compiled`` will be acquired. This argument takes precedence over the `bind` argument as well as this ``ClauseElement``'s bound engine, if any. - :param inline: Used for INSERT statements, for a dialect which does - not support inline retrieval of newly generated - primary key columns, will force the expression used - to create the new primary key value to be rendered - inline within the INSERT statement's VALUES clause. - This typically refers to Sequence execution but - may also refer to any server-side default generation - function associated with a primary key `Column`. + \**kw + + Keyword arguments are passed along to the compiler, + which can affect the string produced. + + Keywords for a statement compiler are: + + column_keys + Used for INSERT and UPDATE statements, a list of + column names which should be present in the VALUES clause + of the compiled statement. If ``None``, all columns + from the target table object are rendered. + + inline + Used for INSERT statements, for a dialect which does + not support inline retrieval of newly generated + primary key columns, will force the expression used + to create the new primary key value to be rendered + inline within the INSERT statement's VALUES clause. + This typically refers to Sequence execution but + may also refer to any server-side default generation + function associated with a primary key `Column`. """ - if compiler is None: - if dialect is not None: - compiler = dialect.statement_compiler(dialect, self, column_keys=column_keys, inline=inline) - elif bind is not None: - compiler = bind.statement_compiler(self, column_keys=column_keys, inline=inline) - elif self.bind is not None: - compiler = self.bind.statement_compiler(self, column_keys=column_keys, inline=inline) + + if not dialect: + if bind: + dialect = bind.dialect + elif self.bind: + dialect = self.bind.dialect + bind = self.bind else: global DefaultDialect if DefaultDialect is None: from sqlalchemy.engine.default import DefaultDialect dialect = DefaultDialect() - compiler = dialect.statement_compiler(dialect, self, column_keys=column_keys, inline=inline) + compiler = self._compiler(dialect, bind=bind, **kw) compiler.compile() return compiler - + + def _compiler(self, dialect, **kw): + """Return a compiler appropriate for this ClauseElement, given a Dialect.""" + + return dialect.statement_compiler(dialect, self, **kw) + def __str__(self): + # Py3K + #return unicode(self.compile()) + # Py2K return unicode(self.compile()).encode('ascii', 'backslashreplace') + # end Py2K def __and__(self, other): return and_(self, other) @@ -1193,11 +1260,25 @@ class ClauseElement(Visitable): def __invert__(self): return self._negate() + if util.jython: + def __hash__(self): + """Return a distinct hash code. + + ClauseElements may have special equality comparisons which + makes us rely on them having unique hash codes for use in + hash-based collections. Stock __hash__ doesn't guarantee + unique values on platforms with moving GCs. + """ + return id(self) + def _negate(self): if hasattr(self, 'negation_clause'): return self.negation_clause else: - return _UnaryExpression(self.self_group(against=operators.inv), operator=operators.inv, negate=None) + return _UnaryExpression( + self.self_group(against=operators.inv), + operator=operators.inv, + negate=None) def __repr__(self): friendly = getattr(self, 'description', None) @@ -1211,6 +1292,12 @@ class ClauseElement(Visitable): class _Immutable(object): """mark a ClauseElement as 'immutable' when expressions are cloned.""" + def unique_params(self, *optionaldict, **kwargs): + raise NotImplementedError("Immutable objects do not support copying") + + def params(self, *optionaldict, **kwargs): + raise NotImplementedError("Immutable objects do not support copying") + def _clone(self): return self @@ -1330,6 +1417,9 @@ class ColumnOperators(Operators): def __truediv__(self, other): return self.operate(operators.truediv, other) + def __rtruediv__(self, other): + return self.reverse_operate(operators.truediv, other) + class _CompareMixin(ColumnOperators): """Defines comparison and math operations for ``ClauseElement`` instances.""" @@ -1365,7 +1455,9 @@ class _CompareMixin(ColumnOperators): operators.add : (__operate,), operators.mul : (__operate,), operators.sub : (__operate,), + # Py2K operators.div : (__operate,), + # end Py2K operators.mod : (__operate,), operators.truediv : (__operate,), operators.lt : (__compare, operators.ge), @@ -1632,7 +1724,7 @@ class ColumnCollection(util.OrderedProperties): def __init__(self, *cols): super(ColumnCollection, self).__init__() - [self.add(c) for c in cols] + self.update((c.key, c) for c in cols) def __str__(self): return repr([str(c) for c in self]) @@ -1734,8 +1826,10 @@ class Selectable(ClauseElement): __visit_name__ = 'selectable' class FromClause(Selectable): - """Represent an element that can be used within the ``FROM`` clause of a ``SELECT`` statement.""" - + """Represent an element that can be used within the ``FROM`` + clause of a ``SELECT`` statement. + + """ __visit_name__ = 'fromclause' named_with_column = False _hide_froms = [] @@ -1749,7 +1843,11 @@ class FromClause(Selectable): col = list(self.primary_key)[0] else: col = list(self.columns)[0] - return select([func.count(col).label('tbl_row_count')], whereclause, from_obj=[self], **params) + return select( + [func.count(col).label('tbl_row_count')], + whereclause, + from_obj=[self], + **params) def select(self, whereclause=None, **params): """return a SELECT of this ``FromClause``.""" @@ -1794,8 +1892,10 @@ class FromClause(Selectable): return fromclause in self._cloned_set def replace_selectable(self, old, alias): - """replace all occurences of FromClause 'old' with the given Alias object, returning a copy of this ``FromClause``.""" - + """replace all occurences of FromClause 'old' with the given Alias + object, returning a copy of this ``FromClause``. + + """ global ClauseAdapter if ClauseAdapter is None: from sqlalchemy.sql.util import ClauseAdapter @@ -1846,24 +1946,30 @@ class FromClause(Selectable): col, intersect = c, i elif len(i) > len(intersect): # 'c' has a larger field of correspondence than 'col'. - # i.e. selectable.c.a1_x->a1.c.x->table.c.x matches a1.c.x->table.c.x better than + # i.e. selectable.c.a1_x->a1.c.x->table.c.x matches + # a1.c.x->table.c.x better than # selectable.c.x->table.c.x does. col, intersect = c, i elif i == intersect: # they have the same field of correspondence. - # see which proxy_set has fewer columns in it, which indicates a - # closer relationship with the root column. Also take into account the - # "weight" attribute which CompoundSelect() uses to give higher precedence to - # columns based on vertical position in the compound statement, and discard columns - # that have no reference to the target column (also occurs with CompoundSelect) + # see which proxy_set has fewer columns in it, which indicates + # a closer relationship with the root column. Also take into + # account the "weight" attribute which CompoundSelect() uses to + # give higher precedence to columns based on vertical position + # in the compound statement, and discard columns that have no + # reference to the target column (also occurs with + # CompoundSelect) col_distance = util.reduce(operator.add, - [sc._annotations.get('weight', 1) for sc in col.proxy_set if sc.shares_lineage(column)] + [sc._annotations.get('weight', 1) + for sc in col.proxy_set + if sc.shares_lineage(column)] ) c_distance = util.reduce(operator.add, - [sc._annotations.get('weight', 1) for sc in c.proxy_set if sc.shares_lineage(column)] + [sc._annotations.get('weight', 1) + for sc in c.proxy_set + if sc.shares_lineage(column)] ) - if \ - c_distance < col_distance: + if c_distance < col_distance: col, intersect = c, i return col @@ -2011,7 +2117,9 @@ class _BindParamClause(ColumnElement): the same type. """ - return isinstance(other, _BindParamClause) and other.type.__class__ == self.type.__class__ and self.value == other.value + return isinstance(other, _BindParamClause) and \ + other.type.__class__ == self.type.__class__ and \ + self.value == other.value def __getstate__(self): """execute a deferred value for serialization purposes.""" @@ -2024,7 +2132,9 @@ class _BindParamClause(ColumnElement): return d def __repr__(self): - return "_BindParamClause(%s, %s, type_=%s)" % (repr(self.key), repr(self.value), repr(self.type)) + return "_BindParamClause(%r, %r, type_=%r)" % ( + self.key, self.value, self.type + ) class _TypeClause(ClauseElement): """Handle a type keyword in a SQL statement. @@ -2057,7 +2167,8 @@ class _TextClause(ClauseElement): _hide_froms = [] - def __init__(self, text = "", bind=None, bindparams=None, typemap=None, autocommit=False): + def __init__(self, text = "", bind=None, + bindparams=None, typemap=None, autocommit=False): self._bind = bind self.bindparams = {} self.typemap = typemap @@ -2157,7 +2268,8 @@ class ClauseList(ClauseElement): return list(itertools.chain(*[c._from_objects for c in self.clauses])) def self_group(self, against=None): - if self.group and self.operator is not against and operators.is_precedent(self.operator, against): + if self.group and self.operator is not against and \ + operators.is_precedent(self.operator, against): return _Grouping(self) else: return self @@ -2200,9 +2312,13 @@ class _Case(ColumnElement): pass if value: - whenlist = [(_literal_as_binds(c).self_group(), _literal_as_binds(r)) for (c, r) in whens] + whenlist = [ + (_literal_as_binds(c).self_group(), _literal_as_binds(r)) for (c, r) in whens + ] else: - whenlist = [(_no_literals(c).self_group(), _literal_as_binds(r)) for (c, r) in whens] + whenlist = [ + (_no_literals(c).self_group(), _literal_as_binds(r)) for (c, r) in whens + ] if whenlist: type_ = list(whenlist[-1])[-1].type @@ -2472,16 +2588,19 @@ class _Exists(_UnaryExpression): return e def select_from(self, clause): - """return a new exists() construct with the given expression set as its FROM clause.""" - + """return a new exists() construct with the given expression set as its FROM + clause. + + """ e = self._clone() e.element = self.element.select_from(clause).self_group() return e def where(self, clause): - """return a new exists() construct with the given expression added to its WHERE clause, joined - to the existing clause via AND, if any.""" - + """return a new exists() construct with the given expression added to its WHERE + clause, joined to the existing clause via AND, if any. + + """ e = self._clone() e.element = self.element.where(clause).self_group() return e @@ -2517,7 +2636,9 @@ class Join(FromClause): id(self.right)) def is_derived_from(self, fromclause): - return fromclause is self or self.left.is_derived_from(fromclause) or self.right.is_derived_from(fromclause) + return fromclause is self or \ + self.left.is_derived_from(fromclause) or\ + self.right.is_derived_from(fromclause) def self_group(self, against=None): return _FromGrouping(self) @@ -2634,7 +2755,11 @@ class Alias(FromClause): @property def description(self): + # Py3K + #return self.name + # Py2K return self.name.encode('ascii', 'backslashreplace') + # end Py2K def as_scalar(self): try: @@ -2762,14 +2887,19 @@ class _Label(ColumnElement): def __init__(self, name, element, type_=None): while isinstance(element, _Label): element = element.element - self.name = self.key = self._label = name or _generated_label("%%(%d %s)s" % (id(self), getattr(element, 'name', 'anon'))) + self.name = self.key = self._label = name or \ + _generated_label("%%(%d %s)s" % ( + id(self), getattr(element, 'name', 'anon')) + ) self._element = element self._type = type_ self.quote = element.quote @util.memoized_property def type(self): - return sqltypes.to_instance(self._type or getattr(self._element, 'type', None)) + return sqltypes.to_instance( + self._type or getattr(self._element, 'type', None) + ) @util.memoized_property def element(self): @@ -2842,7 +2972,11 @@ class ColumnClause(_Immutable, ColumnElement): @util.memoized_property def description(self): + # Py3K + #return self.name + # Py2K return self.name.encode('ascii', 'backslashreplace') + # end Py2K @util.memoized_property def _label(self): @@ -2891,7 +3025,12 @@ class ColumnClause(_Immutable, ColumnElement): # propagate the "is_literal" flag only if we are keeping our name, # otherwise its considered to be a label is_literal = self.is_literal and (name is None or name == self.name) - c = ColumnClause(name or self.name, selectable=selectable, type_=self.type, is_literal=is_literal) + c = ColumnClause( + name or self.name, + selectable=selectable, + type_=self.type, + is_literal=is_literal + ) c.proxies = [self] if attach: selectable.columns[c.name] = c @@ -2927,7 +3066,11 @@ class TableClause(_Immutable, FromClause): @util.memoized_property def description(self): + # Py3K + #return self.name + # Py2K return self.name.encode('ascii', 'backslashreplace') + # end Py2K def append_column(self, c): self._columns[c.name] = c @@ -2944,7 +3087,11 @@ class TableClause(_Immutable, FromClause): col = list(self.primary_key)[0] else: col = list(self.columns)[0] - return select([func.count(col).label('tbl_row_count')], whereclause, from_obj=[self], **params) + return select( + [func.count(col).label('tbl_row_count')], + whereclause, + from_obj=[self], + **params) def insert(self, values=None, inline=False, **kwargs): """Generate an :func:`~sqlalchemy.sql.expression.insert()` construct.""" @@ -2954,7 +3101,8 @@ class TableClause(_Immutable, FromClause): def update(self, whereclause=None, values=None, inline=False, **kwargs): """Generate an :func:`~sqlalchemy.sql.expression.update()` construct.""" - return update(self, whereclause=whereclause, values=values, inline=inline, **kwargs) + return update(self, whereclause=whereclause, + values=values, inline=inline, **kwargs) def delete(self, whereclause=None, **kwargs): """Generate a :func:`~sqlalchemy.sql.expression.delete()` construct.""" @@ -3004,7 +3152,8 @@ class _SelectBaseMixin(object): Typically, a select statement which has only one column in its columns clause is eligible to be used as a scalar expression. - The returned object is an instance of :class:`~sqlalchemy.sql.expression._ScalarSelect`. + The returned object is an instance of + :class:`~sqlalchemy.sql.expression._ScalarSelect`. """ return _ScalarSelect(self) @@ -3013,10 +3162,10 @@ class _SelectBaseMixin(object): def apply_labels(self): """return a new selectable with the 'use_labels' flag set to True. - This will result in column expressions being generated using labels against their table - name, such as "SELECT somecolumn AS tablename_somecolumn". This allows selectables which - contain multiple FROM clauses to produce a unique set of column names regardless of name conflicts - among the individual FROM clauses. + This will result in column expressions being generated using labels against their + table name, such as "SELECT somecolumn AS tablename_somecolumn". This allows + selectables which contain multiple FROM clauses to produce a unique set of column + names regardless of name conflicts among the individual FROM clauses. """ self.use_labels = True @@ -3127,7 +3276,8 @@ class _ScalarSelect(_Grouping): return list(self.inner_columns)[0]._make_proxy(selectable, name) class CompoundSelect(_SelectBaseMixin, FromClause): - """Forms the basis of ``UNION``, ``UNION ALL``, and other SELECT-based set operations.""" + """Forms the basis of ``UNION``, ``UNION ALL``, and other + SELECT-based set operations.""" __visit_name__ = 'compound_select' @@ -3147,7 +3297,8 @@ class CompoundSelect(_SelectBaseMixin, FromClause): elif len(s.c) != numcols: raise exc.ArgumentError( "All selectables passed to CompoundSelect must " - "have identical numbers of columns; select #%d has %d columns, select #%d has %d" % + "have identical numbers of columns; select #%d has %d columns," + " select #%d has %d" % (1, len(self.selects[0].c), n+1, len(s.c)) ) @@ -3222,7 +3373,15 @@ class Select(_SelectBaseMixin, FromClause): __visit_name__ = 'select' - def __init__(self, columns, whereclause=None, from_obj=None, distinct=False, having=None, correlate=True, prefixes=None, **kwargs): + def __init__(self, + columns, + whereclause=None, + from_obj=None, + distinct=False, + having=None, + correlate=True, + prefixes=None, + **kwargs): """Construct a Select object. The public constructor for Select is the @@ -3241,9 +3400,9 @@ class Select(_SelectBaseMixin, FromClause): if columns: self._raw_columns = [ - isinstance(c, _ScalarSelect) and c.self_group(against=operators.comma_op) or c - for c in - [_literal_as_column(c) for c in columns] + isinstance(c, _ScalarSelect) and + c.self_group(against=operators.comma_op) or c + for c in [_literal_as_column(c) for c in columns] ] self._froms.update(_from_objects(*self._raw_columns)) @@ -3331,8 +3490,7 @@ class Select(_SelectBaseMixin, FromClause): be rendered into the columns clause of the resulting SELECT statement. """ - - return itertools.chain(*[c._select_iterable for c in self._raw_columns]) + return _select_iterables(self._raw_columns) def is_derived_from(self, fromclause): if self in fromclause._cloned_set: @@ -3347,7 +3505,7 @@ class Select(_SelectBaseMixin, FromClause): self._reset_exported() from_cloned = dict((f, clone(f)) for f in self._froms.union(self._correlate)) - self._froms = set(from_cloned[f] for f in self._froms) + self._froms = util.OrderedSet(from_cloned[f] for f in self._froms) self._correlate = set(from_cloned[f] for f in self._correlate) self._raw_columns = [clone(c) for c in self._raw_columns] for attr in ('_whereclause', '_having', '_order_by_clause', '_group_by_clause'): @@ -3359,11 +3517,17 @@ class Select(_SelectBaseMixin, FromClause): return (column_collections and list(self.columns) or []) + \ self._raw_columns + list(self._froms) + \ - [x for x in (self._whereclause, self._having, self._order_by_clause, self._group_by_clause) if x is not None] + [x for x in + (self._whereclause, self._having, + self._order_by_clause, self._group_by_clause) + if x is not None] @_generative def column(self, column): - """return a new select() construct with the given column expression added to its columns clause.""" + """return a new select() construct with the given column expression + added to its columns clause. + + """ column = _literal_as_column(column) @@ -3375,63 +3539,73 @@ class Select(_SelectBaseMixin, FromClause): @_generative def with_only_columns(self, columns): - """return a new select() construct with its columns clause replaced with the given columns.""" + """return a new select() construct with its columns clause replaced + with the given columns. + + """ self._raw_columns = [ - isinstance(c, _ScalarSelect) and c.self_group(against=operators.comma_op) or c - for c in - [_literal_as_column(c) for c in columns] + isinstance(c, _ScalarSelect) and + c.self_group(against=operators.comma_op) or c + for c in [_literal_as_column(c) for c in columns] ] @_generative def where(self, whereclause): - """return a new select() construct with the given expression added to its WHERE clause, joined - to the existing clause via AND, if any.""" + """return a new select() construct with the given expression added to its + WHERE clause, joined to the existing clause via AND, if any. + + """ self.append_whereclause(whereclause) @_generative def having(self, having): - """return a new select() construct with the given expression added to its HAVING clause, joined - to the existing clause via AND, if any.""" - + """return a new select() construct with the given expression added to its HAVING + clause, joined to the existing clause via AND, if any. + + """ self.append_having(having) @_generative def distinct(self): - """return a new select() construct which will apply DISTINCT to its columns clause.""" - + """return a new select() construct which will apply DISTINCT to its columns + clause. + + """ self._distinct = True @_generative def prefix_with(self, clause): - """return a new select() construct which will apply the given expression to the start of its - columns clause, not using any commas.""" + """return a new select() construct which will apply the given expression to the + start of its columns clause, not using any commas. + """ clause = _literal_as_text(clause) self._prefixes = self._prefixes + [clause] @_generative def select_from(self, fromclause): - """return a new select() construct with the given FROM expression applied to its list of - FROM objects.""" + """return a new select() construct with the given FROM expression applied to its + list of FROM objects. + """ fromclause = _literal_as_text(fromclause) self._froms = self._froms.union([fromclause]) @_generative def correlate(self, *fromclauses): - """return a new select() construct which will correlate the given FROM clauses to that - of an enclosing select(), if a match is found. - - By "match", the given fromclause must be present in this select's list of FROM objects - and also present in an enclosing select's list of FROM objects. - - Calling this method turns off the select's default behavior of "auto-correlation". Normally, - select() auto-correlates all of its FROM clauses to those of an embedded select when - compiled. - - If the fromclause is None, correlation is disabled for the returned select(). + """return a new select() construct which will correlate the given FROM clauses to + that of an enclosing select(), if a match is found. + + By "match", the given fromclause must be present in this select's list of FROM + objects and also present in an enclosing select's list of FROM objects. + + Calling this method turns off the select's default behavior of + "auto-correlation". Normally, select() auto-correlates all of its FROM clauses to + those of an embedded select when compiled. + + If the fromclause is None, correlation is disabled for the returned select(). """ self._should_correlate = False @@ -3447,8 +3621,10 @@ class Select(_SelectBaseMixin, FromClause): self._correlate = self._correlate.union([fromclause]) def append_column(self, column): - """append the given column expression to the columns clause of this select() construct.""" - + """append the given column expression to the columns clause of this select() + construct. + + """ column = _literal_as_column(column) if isinstance(column, _ScalarSelect): @@ -3459,8 +3635,10 @@ class Select(_SelectBaseMixin, FromClause): self._reset_exported() def append_prefix(self, clause): - """append the given columns clause prefix expression to this select() construct.""" - + """append the given columns clause prefix expression to this select() + construct. + + """ clause = _literal_as_text(clause) self._prefixes = self._prefixes.union([clause]) @@ -3490,7 +3668,8 @@ class Select(_SelectBaseMixin, FromClause): self._having = _literal_as_text(having) def append_from(self, fromclause): - """append the given FromClause expression to this select() construct's FROM clause. + """append the given FromClause expression to this select() construct's FROM + clause. """ if _is_literal(fromclause): @@ -3529,8 +3708,10 @@ class Select(_SelectBaseMixin, FromClause): return union(self, other, **kwargs) def union_all(self, other, **kwargs): - """return a SQL UNION ALL of this select() construct against the given selectable.""" - + """return a SQL UNION ALL of this select() construct against the given + selectable. + + """ return union_all(self, other, **kwargs) def except_(self, other, **kwargs): @@ -3539,18 +3720,24 @@ class Select(_SelectBaseMixin, FromClause): return except_(self, other, **kwargs) def except_all(self, other, **kwargs): - """return a SQL EXCEPT ALL of this select() construct against the given selectable.""" - + """return a SQL EXCEPT ALL of this select() construct against the given + selectable. + + """ return except_all(self, other, **kwargs) def intersect(self, other, **kwargs): - """return a SQL INTERSECT of this select() construct against the given selectable.""" - + """return a SQL INTERSECT of this select() construct against the given + selectable. + + """ return intersect(self, other, **kwargs) def intersect_all(self, other, **kwargs): - """return a SQL INTERSECT ALL of this select() construct against the given selectable.""" - + """return a SQL INTERSECT ALL of this select() construct against the given + selectable. + + """ return intersect_all(self, other, **kwargs) def bind(self): @@ -3581,7 +3768,7 @@ class _UpdateBase(ClauseElement): supports_execution = True _autocommit = True - + def _generate(self): s = self.__class__.__new__(self.__class__) s.__dict__ = self.__dict__.copy() @@ -3597,8 +3784,10 @@ class _UpdateBase(ClauseElement): return parameters def params(self, *arg, **kw): - raise NotImplementedError("params() is not supported for INSERT/UPDATE/DELETE statements." - " To set the values for an INSERT or UPDATE statement, use stmt.values(**parameters).") + raise NotImplementedError( + "params() is not supported for INSERT/UPDATE/DELETE statements." + " To set the values for an INSERT or UPDATE statement, use" + " stmt.values(**parameters).") def bind(self): return self._bind or self.table.bind @@ -3607,6 +3796,51 @@ class _UpdateBase(ClauseElement): self._bind = bind bind = property(bind, _set_bind) + _returning_re = re.compile(r'(?:firebird|postgres(?:ql)?)_returning') + def _process_deprecated_kw(self, kwargs): + for k in list(kwargs): + m = self._returning_re.match(k) + if m: + self._returning = kwargs.pop(k) + util.warn_deprecated( + "The %r argument is deprecated. Please use statement.returning(col1, col2, ...)" % k + ) + return kwargs + + @_generative + def returning(self, *cols): + """Add a RETURNING or equivalent clause to this statement. + + The given list of columns represent columns within the table + that is the target of the INSERT, UPDATE, or DELETE. Each + element can be any column expression. ``Table`` objects + will be expanded into their individual columns. + + Upon compilation, a RETURNING clause, or database equivalent, + will be rendered within the statement. For INSERT and UPDATE, + the values are the newly inserted/updated values. For DELETE, + the values are those of the rows which were deleted. + + Upon execution, the values of the columns to be returned + are made available via the result set and can be iterated + using ``fetchone()`` and similar. For DBAPIs which do not + natively support returning values (i.e. cx_oracle), + SQLAlchemy will approximate this behavior at the result level + so that a reasonable amount of behavioral neutrality is + provided. + + Note that not all databases/DBAPIs + support RETURNING. For those backends with no support, + an exception is raised upon compilation and/or execution. + For those who do support it, the functionality across backends + varies greatly, including restrictions on executemany() + and other statements which return multiple rows. Please + read the documentation notes for the database in use in + order to determine the availability of RETURNING. + + """ + self._returning = cols + class _ValuesBase(_UpdateBase): __visit_name__ = 'values_base' @@ -3617,14 +3851,15 @@ class _ValuesBase(_UpdateBase): @_generative def values(self, *args, **kwargs): - """specify the VALUES clause for an INSERT statement, or the SET clause for an UPDATE. + """specify the VALUES clause for an INSERT statement, or the SET clause for an + UPDATE. \**kwargs key= arguments \*args - A single dictionary can be sent as the first positional argument. This allows - non-string based keys, such as Column objects, to be used. + A single dictionary can be sent as the first positional argument. This + allows non-string based keys, such as Column objects, to be used. """ if args: @@ -3648,16 +3883,25 @@ class Insert(_ValuesBase): """ __visit_name__ = 'insert' - def __init__(self, table, values=None, inline=False, bind=None, prefixes=None, **kwargs): + def __init__(self, + table, + values=None, + inline=False, + bind=None, + prefixes=None, + returning=None, + **kwargs): _ValuesBase.__init__(self, table, values) self._bind = bind self.select = None self.inline = inline + self._returning = returning if prefixes: self._prefixes = [_literal_as_text(p) for p in prefixes] else: self._prefixes = [] - self.kwargs = kwargs + + self.kwargs = self._process_deprecated_kw(kwargs) def get_children(self, **kwargs): if self.select is not None: @@ -3688,15 +3932,24 @@ class Update(_ValuesBase): """ __visit_name__ = 'update' - def __init__(self, table, whereclause, values=None, inline=False, bind=None, **kwargs): + def __init__(self, + table, + whereclause, + values=None, + inline=False, + bind=None, + returning=None, + **kwargs): _ValuesBase.__init__(self, table, values) self._bind = bind + self._returning = returning if whereclause: self._whereclause = _literal_as_text(whereclause) else: self._whereclause = None self.inline = inline - self.kwargs = kwargs + + self.kwargs = self._process_deprecated_kw(kwargs) def get_children(self, **kwargs): if self._whereclause is not None: @@ -3711,9 +3964,10 @@ class Update(_ValuesBase): @_generative def where(self, whereclause): - """return a new update() construct with the given expression added to its WHERE clause, joined - to the existing clause via AND, if any.""" - + """return a new update() construct with the given expression added to its WHERE + clause, joined to the existing clause via AND, if any. + + """ if self._whereclause is not None: self._whereclause = and_(self._whereclause, _literal_as_text(whereclause)) else: @@ -3729,15 +3983,22 @@ class Delete(_UpdateBase): __visit_name__ = 'delete' - def __init__(self, table, whereclause, bind=None, **kwargs): + def __init__(self, + table, + whereclause, + bind=None, + returning =None, + **kwargs): self._bind = bind self.table = table + self._returning = returning + if whereclause: self._whereclause = _literal_as_text(whereclause) else: self._whereclause = None - self.kwargs = kwargs + self.kwargs = self._process_deprecated_kw(kwargs) def get_children(self, **kwargs): if self._whereclause is not None: diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py index 7c21e8233a..879f0f3e51 100644 --- a/lib/sqlalchemy/sql/operators.py +++ b/lib/sqlalchemy/sql/operators.py @@ -4,8 +4,13 @@ """Defines operators used in SQL expressions.""" from operator import ( - and_, or_, inv, add, mul, sub, div, mod, truediv, lt, le, ne, gt, ge, eq + and_, or_, inv, add, mul, sub, mod, truediv, lt, le, ne, gt, ge, eq ) + +# Py2K +from operator import (div,) +# end Py2K + from sqlalchemy.util import symbol @@ -88,7 +93,10 @@ _largest = symbol('_largest') _PRECEDENCE = { from_: 15, mul: 7, + truediv: 7, + # Py2K div: 7, + # end Py2K mod: 7, add: 6, sub: 6, diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index a5bd497aed..4471d4fb0d 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -34,13 +34,10 @@ class VisitableType(type): """ def __init__(cls, clsname, bases, clsdict): - if cls.__name__ == 'Visitable': + if cls.__name__ == 'Visitable' or not hasattr(cls, '__visit_name__'): super(VisitableType, cls).__init__(clsname, bases, clsdict) return - assert hasattr(cls, '__visit_name__'), "`Visitable` descendants " \ - "should define `__visit_name__`" - # set up an optimized visit dispatch function # for use by the compiler visit_name = cls.__visit_name__ diff --git a/lib/sqlalchemy/test/assertsql.py b/lib/sqlalchemy/test/assertsql.py index dc2c6d40f8..1af28794ed 100644 --- a/lib/sqlalchemy/test/assertsql.py +++ b/lib/sqlalchemy/test/assertsql.py @@ -3,7 +3,6 @@ from sqlalchemy.interfaces import ConnectionProxy from sqlalchemy.engine.default import DefaultDialect from sqlalchemy.engine.base import Connection from sqlalchemy import util -import testing import re class AssertRule(object): diff --git a/lib/sqlalchemy/test/config.py b/lib/sqlalchemy/test/config.py index 6ea5667cc3..eec962d807 100644 --- a/lib/sqlalchemy/test/config.py +++ b/lib/sqlalchemy/test/config.py @@ -1,4 +1,8 @@ -import optparse, os, sys, re, ConfigParser, StringIO, time, warnings +import optparse, os, sys, re, ConfigParser, time, warnings + +# 2to3 +import StringIO + logging = None __all__ = 'parser', 'configure', 'options', @@ -13,7 +17,11 @@ base_config = """ [db] sqlite=sqlite:///:memory: sqlite_file=sqlite:///querytest.db -postgres=postgres://scott:tiger@127.0.0.1:5432/test +postgresql=postgresql://scott:tiger@127.0.0.1:5432/test +postgres=postgresql://scott:tiger@127.0.0.1:5432/test +pg8000=postgresql+pg8000://scott:tiger@127.0.0.1:5432/test +postgresql_jython=postgresql+zxjdbc://scott:tiger@127.0.0.1:5432/test +mysql_jython=mysql+zxjdbc://scott:tiger@127.0.0.1:5432/test mysql=mysql://scott:tiger@127.0.0.1:3306/test oracle=oracle://scott:tiger@127.0.0.1:1521 oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0 @@ -125,28 +133,22 @@ def _prep_testing_database(options, file_config): from sqlalchemy.test import engines from sqlalchemy import schema - try: - # also create alt schemas etc. here? - if options.dropfirst: - e = engines.utf8_engine() - existing = e.table_names() - if existing: - print "Dropping existing tables in database: " + db_url - try: - print "Tables: %s" % ', '.join(existing) - except: - pass - print "Abort within 5 seconds..." - time.sleep(5) - md = schema.MetaData(e, reflect=True) - md.drop_all() - e.dispose() - except (KeyboardInterrupt, SystemExit): - raise - except Exception, e: - warnings.warn(RuntimeWarning( - "Error checking for existing tables in testing " - "database: %s" % e)) + # also create alt schemas etc. here? + if options.dropfirst: + e = engines.utf8_engine() + existing = e.table_names() + if existing: + print "Dropping existing tables in database: " + db_url + try: + print "Tables: %s" % ', '.join(existing) + except: + pass + print "Abort within 5 seconds..." + time.sleep(5) + md = schema.MetaData(e, reflect=True) + md.drop_all() + e.dispose() + post_configure['prep_db'] = _prep_testing_database def _set_table_options(options, file_config): diff --git a/lib/sqlalchemy/test/engines.py b/lib/sqlalchemy/test/engines.py index f0001978bf..187ad2ff03 100644 --- a/lib/sqlalchemy/test/engines.py +++ b/lib/sqlalchemy/test/engines.py @@ -2,6 +2,7 @@ import sys, types, weakref from collections import deque import config from sqlalchemy.util import function_named, callable +import re class ConnectionKiller(object): def __init__(self): @@ -11,7 +12,8 @@ class ConnectionKiller(object): self.proxy_refs[con_proxy] = True def _apply_all(self, methods): - for rec in self.proxy_refs: + # must copy keys atomically + for rec in self.proxy_refs.keys(): if rec is not None and rec.is_valid: try: for name in methods: @@ -38,6 +40,10 @@ class ConnectionKiller(object): testing_reaper = ConnectionKiller() +def drop_all_tables(metadata): + testing_reaper.close_all() + metadata.drop_all() + def assert_conns_closed(fn): def decorated(*args, **kw): try: @@ -56,6 +62,14 @@ def rollback_open_connections(fn): testing_reaper.rollback_all() return function_named(decorated, fn.__name__) +def close_first(fn): + """Decorator that closes all connections before fn execution.""" + def decorated(*args, **kw): + testing_reaper.close_all() + fn(*args, **kw) + return function_named(decorated, fn.__name__) + + def close_open_connections(fn): """Decorator that closes all connections after fn execution.""" @@ -69,7 +83,10 @@ def close_open_connections(fn): def all_dialects(): import sqlalchemy.databases as d for name in d.__all__: - mod = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name) + # TEMPORARY + mod = getattr(d, name, None) + if not mod: + mod = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name) yield mod.dialect() class ReconnectFixture(object): @@ -115,7 +132,11 @@ def testing_engine(url=None, options=None): listeners.append(testing_reaper) engine = create_engine(url, **options) - + + # may want to call this, results + # in first-connect initializers + #engine.connect() + return engine def utf8_engine(url=None, options=None): @@ -123,7 +144,7 @@ def utf8_engine(url=None, options=None): from sqlalchemy.engine import url as engine_url - if config.db.name == 'mysql': + if config.db.driver == 'mysqldb': dbapi_ver = config.db.dialect.dbapi.version_info if (dbapi_ver < (1, 2, 1) or dbapi_ver in ((1, 2, 1, 'gamma', 1), (1, 2, 1, 'gamma', 2), @@ -139,19 +160,35 @@ def utf8_engine(url=None, options=None): return testing_engine(url, options) -def mock_engine(db=None): - """Provides a mocking engine based on the current testing.db.""" +def mock_engine(dialect_name=None): + """Provides a mocking engine based on the current testing.db. + + This is normally used to test DDL generation flow as emitted + by an Engine. + + It should not be used in other cases, as assert_compile() and + assert_sql_execution() are much better choices with fewer + moving parts. + + """ from sqlalchemy import create_engine - dbi = db or config.db + if not dialect_name: + dialect_name = config.db.name + buffer = [] def executor(sql, *a, **kw): buffer.append(sql) - engine = create_engine(dbi.name + '://', + def assert_sql(stmts): + recv = [re.sub(r'[\n\t]', '', str(s)) for s in buffer] + assert recv == stmts, recv + + engine = create_engine(dialect_name + '://', strategy='mock', executor=executor) assert not hasattr(engine, 'mock') engine.mock = buffer + engine.assert_sql = assert_sql return engine class ReplayableSession(object): @@ -168,9 +205,16 @@ class ReplayableSession(object): Natives = set([getattr(types, t) for t in dir(types) if not t.startswith('_')]). \ difference([getattr(types, t) + # Py3K + #for t in ('FunctionType', 'BuiltinFunctionType', + # 'MethodType', 'BuiltinMethodType', + # 'LambdaType', )]) + + # Py2K for t in ('FunctionType', 'BuiltinFunctionType', 'MethodType', 'BuiltinMethodType', 'LambdaType', 'UnboundMethodType',)]) + # end Py2K def __init__(self): self.buffer = deque() diff --git a/lib/sqlalchemy/test/noseplugin.py b/lib/sqlalchemy/test/noseplugin.py index 263d2d7831..c4f32a1630 100644 --- a/lib/sqlalchemy/test/noseplugin.py +++ b/lib/sqlalchemy/test/noseplugin.py @@ -14,7 +14,7 @@ from config import db, db_label, db_url, file_config, base_config, \ _set_table_options, _reverse_topological, _log from sqlalchemy.test import testing, config, requires from nose.plugins import Plugin -from nose.util import tolist +from sqlalchemy import util import nose.case log = logging.getLogger('nose.plugins.sqlalchemy') @@ -30,9 +30,6 @@ class NoseSQLAlchemy(Plugin): def options(self, parser, env=os.environ): Plugin.options(self, parser, env) opt = parser.add_option - #opt("--verbose", action="store_true", dest="verbose", - #help="enable stdout echoing/printing") - #opt("--quiet", action="store_true", dest="quiet", help="suppress output") opt("--log-info", action="callback", type="string", callback=_log, help="turn on info logging for (multiple OK)") opt("--log-debug", action="callback", type="string", callback=_log, @@ -77,15 +74,16 @@ class NoseSQLAlchemy(Plugin): def configure(self, options, conf): Plugin.configure(self, options, conf) - - import testing, requires + self.options = options + + def begin(self): testing.db = db testing.requires = requires # Lazy setup of other options (post coverage) for fn in post_configure: - fn(options, file_config) - + fn(self.options, file_config) + def describeTest(self, test): return "" @@ -117,15 +115,20 @@ class NoseSQLAlchemy(Plugin): if check(test_suite)() != 'ok': # The requirement will perform messaging. return True - if (hasattr(cls, '__unsupported_on__') and - testing.db.name in cls.__unsupported_on__): - print "'%s' unsupported on DB implementation '%s'" % ( - cls.__class__.__name__, testing.db.name) - return True - if (getattr(cls, '__only_on__', None) not in (None, testing.db.name)): - print "'%s' unsupported on DB implementation '%s'" % ( - cls.__class__.__name__, testing.db.name) - return True + + if cls.__unsupported_on__: + spec = testing.db_spec(*cls.__unsupported_on__) + if spec(testing.db): + print "'%s' unsupported on DB implementation '%s'" % ( + cls.__class__.__name__, testing.db.name) + return True + if getattr(cls, '__only_on__', None): + spec = testing.db_spec(*util.to_list(cls.__only_on__)) + if not spec(testing.db): + print "'%s' unsupported on DB implementation '%s'" % ( + cls.__class__.__name__, testing.db.name) + return True + if (getattr(cls, '__skip_if__', False)): for c in getattr(cls, '__skip_if__'): if c(): @@ -140,15 +143,15 @@ class NoseSQLAlchemy(Plugin): return True return False - #def begin(self): - #pass - def beforeTest(self, test): testing.resetwarnings() def afterTest(self, test): testing.resetwarnings() + def afterContext(self): + testing.global_cleanup_assertions() + #def handleError(self, test, err): #pass diff --git a/lib/sqlalchemy/test/profiling.py b/lib/sqlalchemy/test/profiling.py index ca4b31cbd8..8cab6ceba1 100644 --- a/lib/sqlalchemy/test/profiling.py +++ b/lib/sqlalchemy/test/profiling.py @@ -6,8 +6,9 @@ in a more fine-grained way than nose's profiling plugin. """ import os, sys -from sqlalchemy.util import function_named -import config +from sqlalchemy.test import config +from sqlalchemy.test.util import function_named, gc_collect +from nose import SkipTest __all__ = 'profiled', 'function_call_count', 'conditional_call_count' @@ -162,15 +163,22 @@ def conditional_call_count(discriminator, categories): def _profile(filename, fn, *args, **kw): global profiler if not profiler: - profiler = 'hotshot' if sys.version_info > (2, 5): try: import cProfile profiler = 'cProfile' except ImportError: pass + if not profiler: + try: + import hotshot + profiler = 'hotshot' + except ImportError: + profiler = 'skip' - if profiler == 'cProfile': + if profiler == 'skip': + raise SkipTest('Profiling not supported on this platform') + elif profiler == 'cProfile': return _profile_cProfile(filename, fn, *args, **kw) else: return _profile_hotshot(filename, fn, *args, **kw) @@ -179,7 +187,7 @@ def _profile_cProfile(filename, fn, *args, **kw): import cProfile, gc, pstats, time load_stats = lambda: pstats.Stats(filename) - gc.collect() + gc_collect() began = time.time() cProfile.runctx('result = fn(*args, **kw)', globals(), locals(), @@ -192,7 +200,7 @@ def _profile_hotshot(filename, fn, *args, **kw): import gc, hotshot, hotshot.stats, time load_stats = lambda: hotshot.stats.load(filename) - gc.collect() + gc_collect() prof = hotshot.Profile(filename) began = time.time() prof.start() diff --git a/lib/sqlalchemy/test/requires.py b/lib/sqlalchemy/test/requires.py index b23b8620da..f3f4ec1911 100644 --- a/lib/sqlalchemy/test/requires.py +++ b/lib/sqlalchemy/test/requires.py @@ -28,6 +28,25 @@ def foreign_keys(fn): no_support('sqlite', 'not supported by database'), ) + +def unbounded_varchar(fn): + """Target database must support VARCHAR with no length""" + return _chain_decorators_on( + fn, + no_support('firebird', 'not supported by database'), + no_support('oracle', 'not supported by database'), + no_support('mysql', 'not supported by database'), + ) + +def boolean_col_expressions(fn): + """Target database must support boolean expressions as columns""" + return _chain_decorators_on( + fn, + no_support('firebird', 'not supported by database'), + no_support('oracle', 'not supported by database'), + no_support('mssql', 'not supported by database'), + ) + def identity(fn): """Target database must support GENERATED AS IDENTITY or a facsimile. @@ -40,7 +59,7 @@ def identity(fn): fn, no_support('firebird', 'not supported by database'), no_support('oracle', 'not supported by database'), - no_support('postgres', 'not supported by database'), + no_support('postgresql', 'not supported by database'), no_support('sybase', 'not supported by database'), ) @@ -61,9 +80,19 @@ def row_triggers(fn): # no access to same table no_support('mysql', 'requires SUPER priv'), exclude('mysql', '<', (5, 0, 10), 'not supported by database'), - no_support('postgres', 'not supported by database: no statements'), + + # huh? TODO: implement triggers for PG tests, remove this + no_support('postgresql', 'PG triggers need to be implemented for tests'), ) +def correlated_outer_joins(fn): + """Target must support an outer join to a subquery which correlates to the parent.""" + + return _chain_decorators_on( + fn, + no_support('oracle', 'Raises "ORA-01799: a column may not be outer-joined to a subquery"') + ) + def savepoints(fn): """Target database must support savepoints.""" return _chain_decorators_on( @@ -75,6 +104,15 @@ def savepoints(fn): exclude('mysql', '<', (5, 0, 3), 'not supported by database'), ) +def schemas(fn): + """Target database must support external schemas, and have one named 'test_schema'.""" + + return _chain_decorators_on( + fn, + no_support('sqlite', 'no schema support'), + no_support('firebird', 'no schema support') + ) + def sequences(fn): """Target database must support SEQUENCEs.""" return _chain_decorators_on( @@ -93,6 +131,17 @@ def subqueries(fn): exclude('mysql', '<', (4, 1, 1), 'no subquery support'), ) +def returning(fn): + return _chain_decorators_on( + fn, + no_support('access', 'not supported by database'), + no_support('sqlite', 'not supported by database'), + no_support('mysql', 'not supported by database'), + no_support('maxdb', 'not supported by database'), + no_support('sybase', 'not supported by database'), + no_support('informix', 'not supported by database'), + ) + def two_phase_transactions(fn): """Target database must support two-phase transactions.""" return _chain_decorators_on( @@ -104,6 +153,8 @@ def two_phase_transactions(fn): no_support('oracle', 'no SA implementation'), no_support('sqlite', 'not supported by database'), no_support('sybase', 'FIXME: guessing, needs confirmation'), + no_support('postgresql+zxjdbc', 'FIXME: JDBC driver confuses the transaction state, may ' + 'need separate XA implementation'), exclude('mysql', '<', (5, 0, 3), 'not supported by database'), ) diff --git a/lib/sqlalchemy/test/schema.py b/lib/sqlalchemy/test/schema.py index f96805fe49..35b4060d2b 100644 --- a/lib/sqlalchemy/test/schema.py +++ b/lib/sqlalchemy/test/schema.py @@ -33,7 +33,7 @@ def Table(*args, **kw): # expand to ForeignKeyConstraint too. fks = [fk for col in args if isinstance(col, schema.Column) - for fk in col.args if isinstance(fk, schema.ForeignKey)] + for fk in col.foreign_keys] for fk in fks: # root around in raw spec @@ -51,13 +51,6 @@ def Table(*args, **kw): if fk.onupdate is None: fk.onupdate = 'CASCADE' - if testing.against('firebird', 'oracle'): - pk_seqs = [col for col in args - if (isinstance(col, schema.Column) - and col.primary_key - and getattr(col, '_needs_autoincrement', False))] - for c in pk_seqs: - c.args.append(schema.Sequence(args[0] + '_' + c.name + '_seq', optional=True)) return schema.Table(*args, **kw) @@ -67,8 +60,20 @@ def Column(*args, **kw): test_opts = dict([(k,kw.pop(k)) for k in kw.keys() if k.startswith('test_')]) - c = schema.Column(*args, **kw) - if testing.against('firebird', 'oracle'): - if 'test_needs_autoincrement' in test_opts: - c._needs_autoincrement = True - return c + col = schema.Column(*args, **kw) + if 'test_needs_autoincrement' in test_opts and \ + kw.get('primary_key', False) and \ + testing.against('firebird', 'oracle'): + def add_seq(tbl): + col._init_items( + schema.Sequence(_truncate_name(testing.db.dialect, tbl.name + '_' + col.name + '_seq'), optional=True) + ) + col._on_table_attach(add_seq) + return col + +def _truncate_name(dialect, name): + if len(name) > dialect.max_identifier_length: + return name[0:max(dialect.max_identifier_length - 6, 0)] + "_" + hex(hash(name) % 64)[2:] + else: + return name + \ No newline at end of file diff --git a/lib/sqlalchemy/test/testing.py b/lib/sqlalchemy/test/testing.py index 36c7d340a3..16a13d9d3b 100644 --- a/lib/sqlalchemy/test/testing.py +++ b/lib/sqlalchemy/test/testing.py @@ -8,10 +8,12 @@ import types import warnings from cStringIO import StringIO -from sqlalchemy.test import config, assertsql +from sqlalchemy.test import config, assertsql, util as testutil from sqlalchemy.util import function_named +from engines import drop_all_tables -from sqlalchemy import exc as sa_exc, util, types as sqltypes, schema +from sqlalchemy import exc as sa_exc, util, types as sqltypes, schema, pool +from nose import SkipTest _ops = { '<': operator.lt, '>': operator.gt, @@ -80,6 +82,19 @@ def future(fn): "Unexpected success for future test '%s'" % fn_name) return function_named(decorated, fn_name) +def db_spec(*dbs): + dialects = set([x for x in dbs if '+' not in x]) + drivers = set([x[1:] for x in dbs if x.startswith('+')]) + specs = set([tuple(x.split('+')) for x in dbs if '+' in x and x not in drivers]) + + def check(engine): + return engine.name in dialects or \ + engine.driver in drivers or \ + (engine.name, engine.driver) in specs + + return check + + def fails_on(dbs, reason): """Mark a test as expected to fail on the specified database implementation. @@ -90,23 +105,25 @@ def fails_on(dbs, reason): succeeds, a failure is reported. """ + spec = db_spec(dbs) + def decorate(fn): fn_name = fn.__name__ def maybe(*args, **kw): - if config.db.name != dbs: + if not spec(config.db): return fn(*args, **kw) else: try: fn(*args, **kw) except Exception, ex: print ("'%s' failed as expected on DB implementation " - "'%s': %s" % ( - fn_name, config.db.name, reason)) + "'%s+%s': %s" % ( + fn_name, config.db.name, config.db.driver, reason)) return True else: raise AssertionError( - "Unexpected success for '%s' on DB implementation '%s'" % - (fn_name, config.db.name)) + "Unexpected success for '%s' on DB implementation '%s+%s'" % + (fn_name, config.db.name, config.db.driver)) return function_named(maybe, fn_name) return decorate @@ -117,23 +134,25 @@ def fails_on_everything_except(*dbs): databases except those listed. """ + spec = db_spec(*dbs) + def decorate(fn): fn_name = fn.__name__ def maybe(*args, **kw): - if config.db.name in dbs: + if spec(config.db): return fn(*args, **kw) else: try: fn(*args, **kw) except Exception, ex: print ("'%s' failed as expected on DB implementation " - "'%s': %s" % ( - fn_name, config.db.name, str(ex))) + "'%s+%s': %s" % ( + fn_name, config.db.name, config.db.driver, str(ex))) return True else: raise AssertionError( - "Unexpected success for '%s' on DB implementation '%s'" % - (fn_name, config.db.name)) + "Unexpected success for '%s' on DB implementation '%s+%s'" % + (fn_name, config.db.name, config.db.driver)) return function_named(maybe, fn_name) return decorate @@ -145,12 +164,13 @@ def crashes(db, reason): """ carp = _should_carp_about_exclusion(reason) + spec = db_spec(db) def decorate(fn): fn_name = fn.__name__ def maybe(*args, **kw): - if config.db.name == db: - msg = "'%s' unsupported on DB implementation '%s': %s" % ( - fn_name, config.db.name, reason) + if spec(config.db): + msg = "'%s' unsupported on DB implementation '%s+%s': %s" % ( + fn_name, config.db.name, config.db.driver, reason) print msg if carp: print >> sys.stderr, msg @@ -169,12 +189,13 @@ def _block_unconditionally(db, reason): """ carp = _should_carp_about_exclusion(reason) + spec = db_spec(db) def decorate(fn): fn_name = fn.__name__ def maybe(*args, **kw): - if config.db.name == db: - msg = "'%s' unsupported on DB implementation '%s': %s" % ( - fn_name, config.db.name, reason) + if spec(config.db): + msg = "'%s' unsupported on DB implementation '%s+%s': %s" % ( + fn_name, config.db.name, config.db.driver, reason) print msg if carp: print >> sys.stderr, msg @@ -198,6 +219,7 @@ def exclude(db, op, spec, reason): """ carp = _should_carp_about_exclusion(reason) + def decorate(fn): fn_name = fn.__name__ def maybe(*args, **kw): @@ -242,7 +264,9 @@ def _is_excluded(db, op, spec): _is_excluded('yikesdb', 'in', ((0, 3, 'alpha2'), (0, 3, 'alpha3'))) """ - if config.db.name != db: + vendor_spec = db_spec(db) + + if not vendor_spec(config.db): return False version = _server_version() @@ -255,7 +279,12 @@ def _server_version(bind=None): if bind is None: bind = config.db - return bind.dialect.server_version_info(bind.contextual_connect()) + + # force metadata to be retrieved + conn = bind.connect() + version = getattr(bind.dialect, 'server_version_info', ()) + conn.close() + return version def skip_if(predicate, reason=None): """Skip a test if predicate is true.""" @@ -266,8 +295,7 @@ def skip_if(predicate, reason=None): if predicate(): msg = "'%s' skipped on DB %s version '%s': %s" % ( fn_name, config.db.name, _server_version(), reason) - print msg - return True + raise SkipTest(msg) else: return fn(*args, **kw) return function_named(maybe, fn_name) @@ -315,10 +343,12 @@ def emits_warning_on(db, *warnings): strings; these will be matched to the root of the warning description by warnings.filterwarnings(). """ + spec = db_spec(db) + def decorate(fn): def maybe(*args, **kw): if isinstance(db, basestring): - if config.db.name != db: + if not spec(config.db): return fn(*args, **kw) else: wrapped = emits_warning(*warnings)(fn) @@ -384,6 +414,19 @@ def resetwarnings(): if sys.version_info < (2, 4): warnings.filterwarnings('ignore', category=FutureWarning) +def global_cleanup_assertions(): + """Check things that have to be finalized at the end of a test suite. + + Hardcoded at the moment, a modular system can be built here + to support things like PG prepared transactions, tables all + dropped, etc. + + """ + + testutil.lazy_gc() + assert not pool._refs + + def against(*queries): """Boolean predicate, compares to testing database configuration. @@ -394,21 +437,20 @@ def against(*queries): Also supports comparison to database version when provided with one or more 3-tuples of dialect name, operator, and version specification:: - testing.against('mysql', 'postgres') + testing.against('mysql', 'postgresql') testing.against(('mysql', '>=', (5, 0, 0)) """ for query in queries: if isinstance(query, basestring): - if config.db.name == query: + if db_spec(query)(config.db): return True else: name, op, spec = query - if config.db.name != name: + if not db_spec(name)(config.db): continue - have = config.db.dialect.server_version_info( - config.db.contextual_connect()) + have = _server_version() oper = hasattr(op, '__call__') and op or _ops[op] if oper(have, spec): @@ -545,16 +587,15 @@ class AssertsCompiledSQL(object): if dialect is None: dialect = getattr(self, '__dialect__', None) - if params is None: - keys = None - else: - keys = params.keys() + kw = {} + if params is not None: + kw['column_keys'] = params.keys() - c = clause.compile(column_keys=keys, dialect=dialect) + c = clause.compile(dialect=dialect, **kw) - print "\nSQL String:\n" + str(c) + repr(c.params) + print "\nSQL String:\n" + str(c) + repr(getattr(c, 'params', {})) - cc = re.sub(r'\n', '', str(c)) + cc = re.sub(r'[\n\t]', '', str(c)) eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect)) @@ -563,18 +604,13 @@ class AssertsCompiledSQL(object): class ComparesTables(object): def assert_tables_equal(self, table, reflected_table): - base_mro = sqltypes.TypeEngine.__mro__ assert len(table.c) == len(reflected_table.c) for c, reflected_c in zip(table.c, reflected_table.c): eq_(c.name, reflected_c.name) assert reflected_c is reflected_table.c[c.name] eq_(c.primary_key, reflected_c.primary_key) eq_(c.nullable, reflected_c.nullable) - assert len( - set(type(reflected_c.type).__mro__).difference(base_mro).intersection( - set(type(c.type).__mro__).difference(base_mro) - ) - ) > 0, "Type '%s' doesn't correspond to type '%s'" % (reflected_c.type, c.type) + self.assert_types_base(reflected_c, c) if isinstance(c.type, sqltypes.String): eq_(c.type.length, reflected_c.type.length) @@ -586,14 +622,21 @@ class ComparesTables(object): elif against(('mysql', '<', (5, 0))): # ignore reflection of bogus db-generated DefaultClause() pass - elif not c.primary_key or not against('postgres'): - print repr(c) + elif not c.primary_key or not against('postgresql', 'mssql'): + #print repr(c) assert reflected_c.default is None, reflected_c.default assert len(table.primary_key) == len(reflected_table.primary_key) for c in table.primary_key: assert reflected_table.primary_key.columns[c.name] - + + def assert_types_base(self, c1, c2): + base_mro = sqltypes.TypeEngine.__mro__ + assert len( + set(type(c1.type).__mro__).difference(base_mro).intersection( + set(type(c2.type).__mro__).difference(base_mro) + ) + ) > 0, "On column %r, type '%s' doesn't correspond to type '%s'" % (c1.name, c1.type, c2.type) class AssertsExecutionResults(object): def assert_result(self, result, class_, *objects): @@ -678,7 +721,7 @@ class AssertsExecutionResults(object): assertsql.asserter.clear_rules() def assert_sql(self, db, callable_, list_, with_sequences=None): - if with_sequences is not None and config.db.name in ('firebird', 'oracle', 'postgres'): + if with_sequences is not None and config.db.name in ('firebird', 'oracle', 'postgresql'): rules = with_sequences else: rules = list_ diff --git a/lib/sqlalchemy/test/util.py b/lib/sqlalchemy/test/util.py new file mode 100644 index 0000000000..60b0a4ef81 --- /dev/null +++ b/lib/sqlalchemy/test/util.py @@ -0,0 +1,24 @@ +from sqlalchemy.util import jython, function_named + +import gc +import time + +if jython: + def gc_collect(*args): + """aggressive gc.collect for tests.""" + gc.collect() + time.sleep(0.1) + gc.collect() + gc.collect() + return 0 + + # "lazy" gc, for VM's that don't GC on refcount == 0 + lazy_gc = gc_collect + +else: + # assume CPython - straight gc.collect, lazy_gc() is a pass + gc_collect = gc.collect + def lazy_gc(): + pass + + diff --git a/lib/sqlalchemy/topological.py b/lib/sqlalchemy/topological.py index f9b9ad7b36..fbdb17963b 100644 --- a/lib/sqlalchemy/topological.py +++ b/lib/sqlalchemy/topological.py @@ -161,20 +161,21 @@ def _sort(tuples, allitems, allow_cycles=False, ignore_self_cycles=False): edges = _EdgeCollection() for item in list(allitems) + [t[0] for t in tuples] + [t[1] for t in tuples]: - if id(item) not in nodes: - node = _Node(item) - nodes[item] = node + item_id = id(item) + if item_id not in nodes: + nodes[item_id] = _Node(item) for t in tuples: + id0, id1 = id(t[0]), id(t[1]) if t[0] is t[1]: if allow_cycles: - n = nodes[t[0]] + n = nodes[id0] n.cycles = set([n]) elif not ignore_self_cycles: raise CircularDependencyError("Self-referential dependency detected " + repr(t)) continue - childnode = nodes[t[1]] - parentnode = nodes[t[0]] + childnode = nodes[id1] + parentnode = nodes[id0] edges.add((parentnode, childnode)) queue = [] @@ -210,7 +211,7 @@ def _sort(tuples, allitems, allow_cycles=False, ignore_self_cycles=False): node = queue.pop() if not hasattr(node, '_cyclical'): output.append(node) - del nodes[node.item] + del nodes[id(node.item)] for childnode in edges.pop_node(node): queue.append(childnode) return output @@ -293,8 +294,8 @@ def _find_cycles(edges): for parent in edges.get_parents(): traverse(parent) - # sets are not hashable, so uniquify with id - unique_cycles = dict((id(s), s) for s in cycles.values()).values() + unique_cycles = set(tuple(s) for s in cycles.values()) + for cycle in unique_cycles: edgecollection = [edge for edge in edges if edge[0] in cycle and edge[1] in cycle] diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index a03d6137df..692e63347b 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -11,11 +11,11 @@ types. For more information see the SQLAlchemy documentation on types. """ -__all__ = [ 'TypeEngine', 'TypeDecorator', 'AbstractType', - 'INT', 'CHAR', 'VARCHAR', 'NCHAR', 'TEXT', 'Text', 'FLOAT', - 'NUMERIC', 'DECIMAL', 'TIMESTAMP', 'DATETIME', 'CLOB', 'BLOB', - 'BOOLEAN', 'SMALLINT', 'DATE', 'TIME', - 'String', 'Integer', 'SmallInteger','Smallinteger', +__all__ = [ 'TypeEngine', 'TypeDecorator', 'AbstractType', 'UserDefinedType', + 'INT', 'CHAR', 'VARCHAR', 'NCHAR', 'NVARCHAR','TEXT', 'Text', 'FLOAT', + 'NUMERIC', 'DECIMAL', 'TIMESTAMP', 'DATETIME', 'CLOB', 'NCLOB', 'BLOB', + 'BOOLEAN', 'SMALLINT', 'INTEGER','DATE', 'TIME', + 'String', 'Integer', 'SmallInteger', 'Numeric', 'Float', 'DateTime', 'Date', 'Time', 'Binary', 'Boolean', 'Unicode', 'MutableType', 'Concatenable', 'UnicodeText', 'PickleType', 'Interval', 'type_map' @@ -24,17 +24,23 @@ __all__ = [ 'TypeEngine', 'TypeDecorator', 'AbstractType', import inspect import datetime as dt from decimal import Decimal as _python_Decimal -import weakref + from sqlalchemy import exc from sqlalchemy.util import pickle +from sqlalchemy.sql.visitors import Visitable import sqlalchemy.util as util NoneType = type(None) +if util.jython: + import array -class AbstractType(object): +class AbstractType(Visitable): def __init__(self, *args, **kwargs): pass + def compile(self, dialect): + return dialect.type_compiler.process(self) + def copy_value(self, value): return value @@ -89,56 +95,23 @@ class AbstractType(object): for k in inspect.getargspec(self.__init__)[0][1:])) class TypeEngine(AbstractType): - """Base for built-in types. - - May be sub-classed to create entirely new types. Example:: - - import sqlalchemy.types as types - - class MyType(types.TypeEngine): - def __init__(self, precision = 8): - self.precision = precision - - def get_col_spec(self): - return "MYTYPE(%s)" % self.precision - - def bind_processor(self, dialect): - def process(value): - return value - return process - - def result_processor(self, dialect): - def process(value): - return value - return process - - Once the type is made, it's immediately usable:: - - table = Table('foo', meta, - Column('id', Integer, primary_key=True), - Column('data', MyType(16)) - ) - - """ + """Base for built-in types.""" + @util.memoized_property + def _impl_dict(self): + return {} + def dialect_impl(self, dialect, **kwargs): try: - return self._impl_dict[dialect] - except AttributeError: - self._impl_dict = weakref.WeakKeyDictionary() # will be optimized in 0.6 - return self._impl_dict.setdefault(dialect, dialect.type_descriptor(self)) + return self._impl_dict[dialect.__class__] except KeyError: - return self._impl_dict.setdefault(dialect, dialect.type_descriptor(self)) + return self._impl_dict.setdefault(dialect.__class__, dialect.__class__.type_descriptor(self)) def __getstate__(self): d = self.__dict__.copy() d.pop('_impl_dict', None) return d - def get_col_spec(self): - """Return the DDL representation for this type.""" - raise NotImplementedError() - def bind_processor(self, dialect): """Return a conversion function for processing bind values. @@ -166,14 +139,42 @@ class TypeEngine(AbstractType): def adapt(self, cls): return cls() - def get_search_list(self): - """return a list of classes to test for a match - when adapting this type to a dialect-specific type. +class UserDefinedType(TypeEngine): + """Base for user defined types. + + This should be the base of new types. Note that + for most cases, :class:`TypeDecorator` is probably + more appropriate. - """ + import sqlalchemy.types as types - return self.__class__.__mro__[0:-1] + class MyType(types.UserDefinedType): + def __init__(self, precision = 8): + self.precision = precision + def get_col_spec(self): + return "MYTYPE(%s)" % self.precision + + def bind_processor(self, dialect): + def process(value): + return value + return process + + def result_processor(self, dialect): + def process(value): + return value + return process + + Once the type is made, it's immediately usable:: + + table = Table('foo', meta, + Column('id', Integer, primary_key=True), + Column('data', MyType(16)) + ) + + """ + __visit_name__ = "user_defined" + class TypeDecorator(AbstractType): """Allows the creation of types which add additional functionality to an existing type. @@ -214,19 +215,33 @@ class TypeDecorator(AbstractType): """ + __visit_name__ = "type_decorator" + def __init__(self, *args, **kwargs): if not hasattr(self.__class__, 'impl'): - raise AssertionError("TypeDecorator implementations require a class-level variable 'impl' which refers to the class of type being decorated") + raise AssertionError("TypeDecorator implementations require a class-level " + "variable 'impl' which refers to the class of type being decorated") self.impl = self.__class__.impl(*args, **kwargs) - def dialect_impl(self, dialect, **kwargs): + def dialect_impl(self, dialect): try: - return self._impl_dict[dialect] + return self._impl_dict[dialect.__class__] except AttributeError: - self._impl_dict = weakref.WeakKeyDictionary() # will be optimized in 0.6 + self._impl_dict = {} except KeyError: pass + # adapt the TypeDecorator first, in + # the case that the dialect maps the TD + # to one of its native types (i.e. PGInterval) + adapted = dialect.__class__.type_descriptor(self) + if adapted is not self: + self._impl_dict[dialect] = adapted + return adapted + + # otherwise adapt the impl type, link + # to a copy of this TypeDecorator and return + # that. typedesc = self.load_dialect_impl(dialect) tt = self.copy() if not isinstance(tt, self.__class__): @@ -236,26 +251,30 @@ class TypeDecorator(AbstractType): self._impl_dict[dialect] = tt return tt + def type_engine(self, dialect): + impl = self.dialect_impl(dialect) + if not isinstance(impl, TypeDecorator): + return impl + else: + return impl.impl + def load_dialect_impl(self, dialect): """Loads the dialect-specific implementation of this type. by default calls dialect.type_descriptor(self.impl), but can be overridden to provide different behavior. + """ - if isinstance(self.impl, TypeDecorator): return self.impl.dialect_impl(dialect) else: - return dialect.type_descriptor(self.impl) + return dialect.__class__.type_descriptor(self.impl) def __getattr__(self, key): """Proxy all other undefined accessors to the underlying implementation.""" return getattr(self.impl, key) - def get_col_spec(self): - return self.impl.get_col_spec() - def process_bind_param(self, value, dialect): raise NotImplementedError() @@ -339,7 +358,7 @@ def to_instance(typeobj): def adapt_type(typeobj, colspecs): if isinstance(typeobj, type): typeobj = typeobj() - for t in typeobj.get_search_list(): + for t in typeobj.__class__.__mro__[0:-1]: try: impltype = colspecs[t] break @@ -370,9 +389,7 @@ class NullType(TypeEngine): encountered during a :meth:`~sqlalchemy.Table.create` operation. """ - - def get_col_spec(self): - raise NotImplementedError() + __visit_name__ = 'null' NullTypeEngine = NullType @@ -400,6 +417,8 @@ class String(Concatenable, TypeEngine): """ + __visit_name__ = 'string' + def __init__(self, length=None, convert_unicode=False, assert_unicode=None): """ Create a string-holding type. @@ -439,7 +458,10 @@ class String(Concatenable, TypeEngine): self.assert_unicode = assert_unicode def adapt(self, impltype): - return impltype(length=self.length, convert_unicode=self.convert_unicode, assert_unicode=self.assert_unicode) + return impltype( + length=self.length, + convert_unicode=self.convert_unicode, + assert_unicode=self.assert_unicode) def bind_processor(self, dialect): if self.convert_unicode or dialect.convert_unicode: @@ -447,18 +469,33 @@ class String(Concatenable, TypeEngine): assert_unicode = dialect.assert_unicode else: assert_unicode = self.assert_unicode - def process(value): - if isinstance(value, unicode): - return value.encode(dialect.encoding) - elif assert_unicode and not isinstance(value, (unicode, NoneType)): - if assert_unicode == 'warn': - util.warn("Unicode type received non-unicode bind " - "param value %r" % value) + + if dialect.supports_unicode_binds and assert_unicode: + def process(value): + if not isinstance(value, (unicode, NoneType)): + if assert_unicode == 'warn': + util.warn("Unicode type received non-unicode bind " + "param value %r" % value) + return value + else: + raise exc.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value) + else: return value + elif dialect.supports_unicode_binds: + return None + else: + def process(value): + if isinstance(value, unicode): + return value.encode(dialect.encoding) + elif assert_unicode and not isinstance(value, (unicode, NoneType)): + if assert_unicode == 'warn': + util.warn("Unicode type received non-unicode bind " + "param value %r" % value) + return value + else: + raise exc.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value) else: - raise exc.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value) - else: - return value + return value return process else: return None @@ -485,6 +522,7 @@ class Text(String): params (and the reverse for result sets.) """ + __visit_name__ = 'text' class Unicode(String): """A variable length Unicode string. @@ -511,10 +549,10 @@ class Unicode(String): :attr:`~sqlalchemy.engine.base.Dialect.encoding`, which defaults to `utf-8`. - A synonym for String(length, convert_unicode=True, assert_unicode='warn'). - """ + __visit_name__ = 'unicode' + def __init__(self, length=None, **kwargs): """ Create a Unicode-converting String type. @@ -532,7 +570,14 @@ class Unicode(String): super(Unicode, self).__init__(length=length, **kwargs) class UnicodeText(Text): - """A synonym for Text(convert_unicode=True, assert_unicode='warn').""" + """An unbounded-length Unicode string. + + See :class:`Unicode` for details on the unicode + behavior of this object. + + """ + + __visit_name__ = 'unicode_text' def __init__(self, length=None, **kwargs): """ @@ -553,7 +598,9 @@ class UnicodeText(Text): class Integer(TypeEngine): """A type for ``int`` integers.""" - + + __visit_name__ = 'integer' + def get_dbapi_type(self, dbapi): return dbapi.NUMBER @@ -566,7 +613,17 @@ class SmallInteger(Integer): """ -Smallinteger = SmallInteger + __visit_name__ = 'small_integer' + +class BigInteger(Integer): + """A type for bigger ``int`` integers. + + Typically generates a ``BIGINT`` in DDL, and otherwise acts like + a normal :class:`Integer` on the Python side. + + """ + + __visit_name__ = 'big_integer' class Numeric(TypeEngine): """A type for fixed precision numbers. @@ -576,7 +633,9 @@ class Numeric(TypeEngine): """ - def __init__(self, precision=10, scale=2, asdecimal=True, length=None): + __visit_name__ = 'numeric' + + def __init__(self, precision=None, scale=None, asdecimal=True): """ Construct a Numeric. @@ -590,9 +649,6 @@ class Numeric(TypeEngine): use. """ - if length: - util.warn_deprecated("'length' is deprecated for Numeric. Use 'scale'.") - scale = length self.precision = precision self.scale = scale self.asdecimal = asdecimal @@ -626,7 +682,9 @@ class Numeric(TypeEngine): class Float(Numeric): """A type for ``float`` numbers.""" - def __init__(self, precision=10, asdecimal=False, **kwargs): + __visit_name__ = 'float' + + def __init__(self, precision=None, asdecimal=False, **kwargs): """ Construct a Float. @@ -650,7 +708,9 @@ class DateTime(TypeEngine): converted back to datetime objects when rows are returned. """ - + + __visit_name__ = 'datetime' + def __init__(self, timezone=False): self.timezone = timezone @@ -664,6 +724,8 @@ class DateTime(TypeEngine): class Date(TypeEngine): """A type for ``datetime.date()`` objects.""" + __visit_name__ = 'date' + def get_dbapi_type(self, dbapi): return dbapi.DATETIME @@ -671,6 +733,8 @@ class Date(TypeEngine): class Time(TypeEngine): """A type for ``datetime.time()`` objects.""" + __visit_name__ = 'time' + def __init__(self, timezone=False): self.timezone = timezone @@ -690,6 +754,8 @@ class Binary(TypeEngine): """ + __visit_name__ = 'binary' + def __init__(self, length=None): """ Construct a Binary type. @@ -775,7 +841,15 @@ class PickleType(MutableType, TypeDecorator): loads = self.pickler.loads if value is None: return None - return loads(str(value)) + # Py3K + #return loads(value) + # Py2K + if util.jython and isinstance(value, array.ArrayType): + value = value.tostring() + else: + value = str(value) + return loads(value) + # end Py2K def copy_value(self, value): if self.mutable: @@ -786,9 +860,6 @@ class PickleType(MutableType, TypeDecorator): def compare_values(self, x, y): if self.comparator: return self.comparator(x, y) - elif self.mutable and not hasattr(x, '__eq__') and x is not None: - util.warn_deprecated("Objects stored with PickleType when mutable=True must implement __eq__() for reliable comparison.") - return self.pickler.dumps(x, self.protocol) == self.pickler.dumps(y, self.protocol) else: return x == y @@ -804,6 +875,7 @@ class Boolean(TypeEngine): """ + __visit_name__ = 'boolean' class Interval(TypeDecorator): """A type for ``datetime.timedelta()`` objects. @@ -815,106 +887,133 @@ class Interval(TypeDecorator): """ - impl = TypeEngine - - def __init__(self): - super(Interval, self).__init__() - import sqlalchemy.databases.postgres as pg - self.__supported = {pg.PGDialect:pg.PGInterval} - del pg - - def load_dialect_impl(self, dialect): - if dialect.__class__ in self.__supported: - return self.__supported[dialect.__class__]() - else: - return dialect.type_descriptor(DateTime) + impl = DateTime def process_bind_param(self, value, dialect): - if dialect.__class__ in self.__supported: - return value - else: - if value is None: - return None - return dt.datetime.utcfromtimestamp(0) + value + if value is None: + return None + return dt.datetime.utcfromtimestamp(0) + value def process_result_value(self, value, dialect): - if dialect.__class__ in self.__supported: - return value - else: - if value is None: - return None - return value - dt.datetime.utcfromtimestamp(0) + if value is None: + return None + return value - dt.datetime.utcfromtimestamp(0) class FLOAT(Float): """The SQL FLOAT type.""" + __visit_name__ = 'FLOAT' class NUMERIC(Numeric): """The SQL NUMERIC type.""" + __visit_name__ = 'NUMERIC' + class DECIMAL(Numeric): """The SQL DECIMAL type.""" + __visit_name__ = 'DECIMAL' -class INT(Integer): + +class INTEGER(Integer): """The SQL INT or INTEGER type.""" + __visit_name__ = 'INTEGER' +INT = INTEGER -INTEGER = INT -class SMALLINT(Smallinteger): +class SMALLINT(SmallInteger): """The SQL SMALLINT type.""" + __visit_name__ = 'SMALLINT' + + +class BIGINT(BigInteger): + """The SQL BIGINT type.""" + + __visit_name__ = 'BIGINT' class TIMESTAMP(DateTime): """The SQL TIMESTAMP type.""" + __visit_name__ = 'TIMESTAMP' + + def get_dbapi_type(self, dbapi): + return dbapi.TIMESTAMP class DATETIME(DateTime): """The SQL DATETIME type.""" + __visit_name__ = 'DATETIME' + class DATE(Date): """The SQL DATE type.""" + __visit_name__ = 'DATE' + class TIME(Time): """The SQL TIME type.""" + __visit_name__ = 'TIME' -TEXT = Text +class TEXT(Text): + """The SQL TEXT type.""" + + __visit_name__ = 'TEXT' class CLOB(Text): - """The SQL CLOB type.""" + """The CLOB type. + + This type is found in Oracle and Informix. + """ + __visit_name__ = 'CLOB' class VARCHAR(String): """The SQL VARCHAR type.""" + __visit_name__ = 'VARCHAR' + +class NVARCHAR(Unicode): + """The SQL NVARCHAR type.""" + + __visit_name__ = 'NVARCHAR' class CHAR(String): """The SQL CHAR type.""" + __visit_name__ = 'CHAR' + class NCHAR(Unicode): """The SQL NCHAR type.""" + __visit_name__ = 'NCHAR' + class BLOB(Binary): """The SQL BLOB type.""" + __visit_name__ = 'BLOB' + class BOOLEAN(Boolean): """The SQL BOOLEAN type.""" + __visit_name__ = 'BOOLEAN' + NULLTYPE = NullType() # using VARCHAR/NCHAR so that we dont get the genericized "String" # type which usually resolves to TEXT/CLOB type_map = { - str : VARCHAR, - unicode : NCHAR, + str: String, + # Py2K + unicode : String, + # end Py2K int : Integer, float : Numeric, bool: Boolean, @@ -923,5 +1022,6 @@ type_map = { dt.datetime : DateTime, dt.time : Time, dt.timedelta : Interval, - type(None): NullType + NoneType: NullType } + diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 8eeeda4555..f970f3737d 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -4,8 +4,10 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import inspect, itertools, operator, sys, warnings, weakref +import inspect, itertools, operator, sys, warnings, weakref, gc +# Py2K import __builtin__ +# end Py2K types = __import__('types') from sqlalchemy import exc @@ -16,6 +18,7 @@ except ImportError: import dummy_threading as threading py3k = getattr(sys, 'py3kwarning', False) or sys.version_info >= (3, 0) +jython = sys.platform.startswith('java') if py3k: set_types = set @@ -38,6 +41,8 @@ else: EMPTY_SET = frozenset() +NoneType = type(None) + if py3k: import pickle else: @@ -46,11 +51,13 @@ else: except ImportError: import pickle +# Py2K # a controversial feature, required by MySQLdb currently def buffer(x): return x buffer = getattr(__builtin__, 'buffer', buffer) +# end Py2K if sys.version_info >= (2, 5): class PopulateDict(dict): @@ -84,12 +91,13 @@ else: if py3k: def callable(fn): return hasattr(fn, '__call__') -else: - callable = __builtin__.callable + def cmp(a, b): + return (a > b) - (a < b) -if py3k: from functools import reduce else: + callable = __builtin__.callable + cmp = __builtin__.cmp reduce = __builtin__.reduce try: @@ -262,6 +270,15 @@ else: def decode_slice(slc): return (slc.start, slc.stop, slc.step) +def update_copy(d, _new=None, **kw): + """Copy the given dict and update with the given values.""" + + d = d.copy() + if _new: + d.update(_new) + d.update(**kw) + return d + def flatten_iterator(x): """Given an iterator of which further sub-elements may also be iterators, flatten the sub-elements into a single iterator. @@ -296,6 +313,7 @@ def get_cls_kwargs(cls): class_ = stack.pop() ctr = class_.__dict__.get('__init__', False) if not ctr or not isinstance(ctr, types.FunctionType): + stack.update(class_.__bases__) continue names, _, has_kw, _ = inspect.getargspec(ctr) args.update(names) @@ -419,20 +437,32 @@ def class_hierarchy(cls): will not be descended. """ + # Py2K if isinstance(cls, types.ClassType): return list() + # end Py2K hier = set([cls]) process = list(cls.__mro__) while process: c = process.pop() + # Py2K if isinstance(c, types.ClassType): continue for b in (_ for _ in c.__bases__ if _ not in hier and not isinstance(_, types.ClassType)): + # end Py2K + # Py3K + #for b in (_ for _ in c.__bases__ + # if _ not in hier): process.append(b) hier.add(b) + # Py3K + #if c.__module__ == 'builtins' or not hasattr(c, '__subclasses__'): + # continue + # Py2K if c.__module__ == '__builtin__' or not hasattr(c, '__subclasses__'): continue + # end Py2K for s in [_ for _ in c.__subclasses__() if _ not in hier]: process.append(s) hier.add(s) @@ -664,12 +694,11 @@ class OrderedProperties(object): return self._data.keys() def has_key(self, key): - return self._data.has_key(key) + return key in self._data def clear(self): self._data.clear() - class OrderedDict(dict): """A dict that returns keys/values/items in the order they were added.""" @@ -735,7 +764,12 @@ class OrderedDict(dict): def __setitem__(self, key, object): if key not in self: - self._list.append(key) + try: + self._list.append(key) + except AttributeError: + # work around Python pickle loads() with + # dict subclass (seems to ignore __setstate__?) + self._list = [key] dict.__setitem__(self, key, object) def __delitem__(self, key): @@ -915,7 +949,7 @@ class IdentitySet(object): if len(self) > len(other): return False - for m in itertools.ifilterfalse(other._members.has_key, + for m in itertools.ifilterfalse(other._members.__contains__, self._members.iterkeys()): return False return True @@ -1416,8 +1450,11 @@ class WeakIdentityMapping(weakref.WeakKeyDictionary): return item def clear(self): + # Py2K + # in 3k, MutableMapping calls popitem() self._weakrefs.clear() self.by_id.clear() + # end Py2K weakref.WeakKeyDictionary.clear(self) def update(self, *a, **kw): diff --git a/sa2to3.py b/sa2to3.py new file mode 100644 index 0000000000..9c06dafabd --- /dev/null +++ b/sa2to3.py @@ -0,0 +1,76 @@ +"""SQLAlchemy 2to3 tool. + +Relax ! This just calls the regular 2to3 tool with a preprocessor bolted onto it. + + +I originally wanted to write a custom fixer to accomplish this +but the Fixer classes seem like they can only understand +the grammar file included with 2to3, and the grammar does not +seem to include Python comments (and of course, huge hacks needed +to get out-of-package fixers in there). While that may be +an option later on this is a pretty simple approach for +what is a pretty simple problem. + +""" + +from lib2to3 import main, refactor + +import re + +py3k_pattern = re.compile(r'\s*# Py3K') +comment_pattern = re.compile(r'(\s*)#(?! ?Py2K)(.*)') +py2k_pattern = re.compile(r'\s*# Py2K') +end_py2k_pattern = re.compile(r'\s*# end Py2K') + +def preprocess(data): + lines = data.split('\n') + def consume_normal(): + while lines: + line = lines.pop(0) + if py3k_pattern.match(line): + for line in consume_py3k(): + yield line + elif py2k_pattern.match(line): + for line in consume_py2k(): + yield line + else: + yield line + + def consume_py3k(): + yield "# start Py3K" + while lines: + line = lines.pop(0) + m = comment_pattern.match(line) + if m: + yield "%s%s" % m.group(1, 2) + else: + # pushback + lines.insert(0, line) + break + yield "# end Py3K" + + def consume_py2k(): + yield "# start Py2K" + while lines: + line = lines.pop(0) + if not end_py2k_pattern.match(line): + yield "#%s" % line + else: + break + yield "# end Py2K" + + return "\n".join(consume_normal()) + +old_refactor_string = main.StdoutRefactoringTool.refactor_string + +def refactor_string(self, data, name): + newdata = preprocess(data) + tree = old_refactor_string(self, newdata, name) + if tree: + if newdata != data: + tree.was_changed = True + return tree + +main.StdoutRefactoringTool.refactor_string = refactor_string + +main.main("lib2to3.fixes") diff --git a/setup.py b/setup.py index 3d65f022e0..12925a1154 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ def find_packages(dir_): if sys.version_info < (2, 4): raise Exception("SQLAlchemy requires Python 2.4 or higher.") -v = file(os.path.join(os.path.dirname(__file__), 'lib', 'sqlalchemy', '__init__.py')) +v = open(os.path.join(os.path.dirname(__file__), 'lib', 'sqlalchemy', '__init__.py')) VERSION = re.compile(r".*__version__ = '(.*?)'", re.S).match(v.read()).group(1) v.close() @@ -31,7 +31,8 @@ setup(name = "SQLAlchemy", packages = find_packages('lib'), package_dir = {'':'lib'}, license = "MIT License", - + tests_require = ['nose >= 0.10'], + test_suite = "nose.collector", entry_points = { 'nose.plugins.0.10': [ 'sqlalchemy = sqlalchemy.test.noseplugin:NoseSQLAlchemy', diff --git a/sqla_nose.py b/sqla_nose.py new file mode 100644 index 0000000000..0542b4e5dd --- /dev/null +++ b/sqla_nose.py @@ -0,0 +1,22 @@ +""" +nose runner script. + +Only use this script if setuptools is not available, i.e. such as +on Python 3K. Otherwise consult README.unittests for the +recommended methods of running tests. + +""" + +import nose +from sqlalchemy.test.noseplugin import NoseSQLAlchemy +from sqlalchemy.util import py3k + +if __name__ == '__main__': + if py3k: + # this version breaks verbose output, + # but is the only API that nose3 currently supports + nose.main(plugins=[NoseSQLAlchemy()]) + else: + # this is the "correct" API + nose.main(addplugins=[NoseSQLAlchemy()]) + diff --git a/test/aaa_profiling/test_compiler.py b/test/aaa_profiling/test_compiler.py index 3e4274d47d..79ae09b054 100644 --- a/test/aaa_profiling/test_compiler.py +++ b/test/aaa_profiling/test_compiler.py @@ -15,15 +15,15 @@ class CompileTest(TestBase, AssertsExecutionResults): Column('c1', Integer, primary_key=True), Column('c2', String(30))) - @profiling.function_call_count(68, {'2.4': 42}) + @profiling.function_call_count(72, {'2.4': 42, '3.0':77}) def test_insert(self): t1.insert().compile() - @profiling.function_call_count(68, {'2.4': 45}) + @profiling.function_call_count(72, {'2.4': 45}) def test_update(self): t1.update().compile() - @profiling.function_call_count(185, versions={'2.4':118}) + @profiling.function_call_count(195, versions={'2.4':118, '3.0':208}) def test_select(self): s = select([t1], t1.c.c2==t2.c.c1) s.compile() diff --git a/test/aaa_profiling/test_memusage.py b/test/aaa_profiling/test_memusage.py index 70a3cf8cd6..fbf0560ca1 100644 --- a/test/aaa_profiling/test_memusage.py +++ b/test/aaa_profiling/test_memusage.py @@ -1,17 +1,22 @@ from sqlalchemy.test.testing import eq_ -import gc from sqlalchemy.orm import mapper, relation, create_session, clear_mappers, sessionmaker from sqlalchemy.orm.mapper import _mapper_registry from sqlalchemy.orm.session import _sessions +from sqlalchemy.util import jython import operator from sqlalchemy.test import testing from sqlalchemy import MetaData, Integer, String, ForeignKey, PickleType -from sqlalchemy.test.schema import Table -from sqlalchemy.test.schema import Column +from sqlalchemy.test.schema import Table, Column import sqlalchemy as sa from sqlalchemy.sql import column +from sqlalchemy.test.util import gc_collect +import gc from test.orm import _base +if jython: + from nose import SkipTest + raise SkipTest("Profiling not supported on this platform") + class A(_base.ComparableEntity): pass @@ -22,11 +27,11 @@ def profile_memory(func): # run the test 50 times. if length of gc.get_objects() # keeps growing, assert false def profile(*args): - gc.collect() + gc_collect() samples = [0 for x in range(0, 50)] for x in range(0, 50): func(*args) - gc.collect() + gc_collect() samples[x] = len(gc.get_objects()) print "sample gc sizes:", samples @@ -50,7 +55,7 @@ def profile_memory(func): def assert_no_mappers(): clear_mappers() - gc.collect() + gc_collect() assert len(_mapper_registry) == 0 class EnsureZeroed(_base.ORMTest): @@ -61,7 +66,7 @@ class EnsureZeroed(_base.ORMTest): class MemUsageTest(EnsureZeroed): # ensure a pure growing test trips the assertion - @testing.fails_if(lambda:True) + @testing.fails_if(lambda: True) def test_fixture(self): class Foo(object): pass @@ -76,11 +81,11 @@ class MemUsageTest(EnsureZeroed): metadata = MetaData(testing.db) table1 = Table("mytable", metadata, - Column('col1', Integer, primary_key=True), + Column('col1', Integer, primary_key=True, test_needs_autoincrement=True), Column('col2', String(30))) table2 = Table("mytable2", metadata, - Column('col1', Integer, primary_key=True), + Column('col1', Integer, primary_key=True, test_needs_autoincrement=True), Column('col2', String(30)), Column('col3', Integer, ForeignKey("mytable.col1"))) @@ -129,11 +134,11 @@ class MemUsageTest(EnsureZeroed): metadata = MetaData(testing.db) table1 = Table("mytable", metadata, - Column('col1', Integer, primary_key=True), + Column('col1', Integer, primary_key=True, test_needs_autoincrement=True), Column('col2', String(30))) table2 = Table("mytable2", metadata, - Column('col1', Integer, primary_key=True), + Column('col1', Integer, primary_key=True, test_needs_autoincrement=True), Column('col2', String(30)), Column('col3', Integer, ForeignKey("mytable.col1"))) @@ -184,13 +189,13 @@ class MemUsageTest(EnsureZeroed): metadata = MetaData(testing.db) table1 = Table("mytable", metadata, - Column('col1', Integer, primary_key=True), + Column('col1', Integer, primary_key=True, test_needs_autoincrement=True), Column('col2', String(30)) ) table2 = Table("mytable2", metadata, Column('col1', Integer, ForeignKey('mytable.col1'), - primary_key=True), + primary_key=True, test_needs_autoincrement=True), Column('col3', String(30)), ) @@ -244,12 +249,12 @@ class MemUsageTest(EnsureZeroed): metadata = MetaData(testing.db) table1 = Table("mytable", metadata, - Column('col1', Integer, primary_key=True), + Column('col1', Integer, primary_key=True, test_needs_autoincrement=True), Column('col2', String(30)) ) table2 = Table("mytable2", metadata, - Column('col1', Integer, primary_key=True), + Column('col1', Integer, primary_key=True, test_needs_autoincrement=True), Column('col2', String(30)), ) @@ -308,12 +313,12 @@ class MemUsageTest(EnsureZeroed): metadata = MetaData(testing.db) table1 = Table("table1", metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30)) ) table2 = Table("table2", metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30)), Column('t1id', Integer, ForeignKey('table1.id')) ) @@ -347,7 +352,7 @@ class MemUsageTest(EnsureZeroed): metadata = MetaData(testing.db) table1 = Table("mytable", metadata, - Column('col1', Integer, primary_key=True), + Column('col1', Integer, primary_key=True, test_needs_autoincrement=True), Column('col2', PickleType(comparator=operator.eq)) ) @@ -382,7 +387,7 @@ class MemUsageTest(EnsureZeroed): testing.eq_(len(session.identity_map._mutable_attrs), 12) testing.eq_(len(session.identity_map), 12) obj = None - gc.collect() + gc_collect() testing.eq_(len(session.identity_map._mutable_attrs), 0) testing.eq_(len(session.identity_map), 0) @@ -392,7 +397,7 @@ class MemUsageTest(EnsureZeroed): metadata.drop_all() def test_type_compile(self): - from sqlalchemy.databases.sqlite import SQLiteDialect + from sqlalchemy.dialects.sqlite.base import dialect as SQLiteDialect cast = sa.cast(column('x'), sa.Integer) @profile_memory def go(): diff --git a/test/aaa_profiling/test_pool.py b/test/aaa_profiling/test_pool.py index 7bb61deb28..6ae3edc989 100644 --- a/test/aaa_profiling/test_pool.py +++ b/test/aaa_profiling/test_pool.py @@ -5,6 +5,9 @@ from sqlalchemy.pool import QueuePool class QueuePoolTest(TestBase, AssertsExecutionResults): class Connection(object): + def rollback(self): + pass + def close(self): pass @@ -15,7 +18,7 @@ class QueuePoolTest(TestBase, AssertsExecutionResults): use_threadlocal=True) - @profiling.function_call_count(54, {'2.4': 38}) + @profiling.function_call_count(54, {'2.4': 38, '3.0':57}) def test_first_connect(self): conn = pool.connect() @@ -23,7 +26,7 @@ class QueuePoolTest(TestBase, AssertsExecutionResults): conn = pool.connect() conn.close() - @profiling.function_call_count(31, {'2.4': 21}) + @profiling.function_call_count(29, {'2.4': 21}) def go(): conn2 = pool.connect() return conn2 @@ -32,7 +35,7 @@ class QueuePoolTest(TestBase, AssertsExecutionResults): def test_second_samethread_connect(self): conn = pool.connect() - @profiling.function_call_count(5, {'2.4': 3}) + @profiling.function_call_count(5, {'2.4': 3, '3.0':6}) def go(): return pool.connect() c2 = go() diff --git a/test/aaa_profiling/test_zoomark.py b/test/aaa_profiling/test_zoomark.py index be29318964..e413031926 100644 --- a/test/aaa_profiling/test_zoomark.py +++ b/test/aaa_profiling/test_zoomark.py @@ -26,7 +26,7 @@ class ZooMarkTest(TestBase): """ - __only_on__ = 'postgres' + __only_on__ = 'postgresql+psycopg2' __skip_if__ = ((lambda: sys.version_info < (2, 4)), ) def test_baseline_0_setup(self): @@ -75,15 +75,15 @@ class ZooMarkTest(TestBase): Opens=datetime.time(8, 15, 59), LastEscape=datetime.datetime(2004, 7, 29, 5, 6, 7), Admission=4.95, - ).last_inserted_ids()[0] + ).inserted_primary_key[0] sdz = Zoo.insert().execute(Name =u'San Diego Zoo', Founded = datetime.date(1935, 9, 13), Opens = datetime.time(9, 0, 0), Admission = 0, - ).last_inserted_ids()[0] + ).inserted_primary_key[0] - Zoo.insert().execute( + Zoo.insert(inline=True).execute( Name = u'Montr\xe9al Biod\xf4me', Founded = datetime.date(1992, 6, 19), Opens = datetime.time(9, 0, 0), @@ -91,48 +91,48 @@ class ZooMarkTest(TestBase): ) seaworld = Zoo.insert().execute( - Name =u'Sea_World', Admission = 60).last_inserted_ids()[0] + Name =u'Sea_World', Admission = 60).inserted_primary_key[0] # Let's add a crazy futuristic Zoo to test large date values. lp = Zoo.insert().execute(Name =u'Luna Park', Founded = datetime.date(2072, 7, 17), Opens = datetime.time(0, 0, 0), Admission = 134.95, - ).last_inserted_ids()[0] + ).inserted_primary_key[0] # Animals leopardid = Animal.insert().execute(Species=u'Leopard', Lifespan=73.5, - ).last_inserted_ids()[0] + ).inserted_primary_key[0] Animal.update(Animal.c.ID==leopardid).execute(ZooID=wap, LastEscape=datetime.datetime(2004, 12, 21, 8, 15, 0, 999907)) - lion = Animal.insert().execute(Species=u'Lion', ZooID=wap).last_inserted_ids()[0] + lion = Animal.insert().execute(Species=u'Lion', ZooID=wap).inserted_primary_key[0] Animal.insert().execute(Species=u'Slug', Legs=1, Lifespan=.75) tiger = Animal.insert().execute(Species=u'Tiger', ZooID=sdz - ).last_inserted_ids()[0] + ).inserted_primary_key[0] # Override Legs.default with itself just to make sure it works. - Animal.insert().execute(Species=u'Bear', Legs=4) - Animal.insert().execute(Species=u'Ostrich', Legs=2, Lifespan=103.2) - Animal.insert().execute(Species=u'Centipede', Legs=100) + Animal.insert(inline=True).execute(Species=u'Bear', Legs=4) + Animal.insert(inline=True).execute(Species=u'Ostrich', Legs=2, Lifespan=103.2) + Animal.insert(inline=True).execute(Species=u'Centipede', Legs=100) emp = Animal.insert().execute(Species=u'Emperor Penguin', Legs=2, - ZooID=seaworld).last_inserted_ids()[0] + ZooID=seaworld).inserted_primary_key[0] adelie = Animal.insert().execute(Species=u'Adelie Penguin', Legs=2, - ZooID=seaworld).last_inserted_ids()[0] + ZooID=seaworld).inserted_primary_key[0] - Animal.insert().execute(Species=u'Millipede', Legs=1000000, ZooID=sdz) + Animal.insert(inline=True).execute(Species=u'Millipede', Legs=1000000, ZooID=sdz) # Add a mother and child to test relationships bai_yun = Animal.insert().execute(Species=u'Ape', Name=u'Bai Yun', - Legs=2).last_inserted_ids()[0] - Animal.insert().execute(Species=u'Ape', Name=u'Hua Mei', Legs=2, + Legs=2).inserted_primary_key[0] + Animal.insert(inline=True).execute(Species=u'Ape', Name=u'Hua Mei', Legs=2, MotherID=bai_yun) def test_baseline_2_insert(self): Animal = metadata.tables['Animal'] - i = Animal.insert() + i = Animal.insert(inline=True) for x in xrange(ITERATIONS): tick = i.execute(Species=u'Tick', Name=u'Tick %d' % x, Legs=8) @@ -142,7 +142,7 @@ class ZooMarkTest(TestBase): def fullobject(select): """Iterate over the full result row.""" - return list(select.execute().fetchone()) + return list(select.execute().first()) for x in xrange(ITERATIONS): # Zoos @@ -254,7 +254,7 @@ class ZooMarkTest(TestBase): for x in xrange(ITERATIONS): # Edit - SDZ = Zoo.select(Zoo.c.Name==u'San Diego Zoo').execute().fetchone() + SDZ = Zoo.select(Zoo.c.Name==u'San Diego Zoo').execute().first() Zoo.update(Zoo.c.ID==SDZ['ID']).execute( Name=u'The San Diego Zoo', Founded = datetime.date(1900, 1, 1), @@ -262,7 +262,7 @@ class ZooMarkTest(TestBase): Admission = "35.00") # Test edits - SDZ = Zoo.select(Zoo.c.Name==u'The San Diego Zoo').execute().fetchone() + SDZ = Zoo.select(Zoo.c.Name==u'The San Diego Zoo').execute().first() assert SDZ['Founded'] == datetime.date(1900, 1, 1), SDZ['Founded'] # Change it back @@ -273,7 +273,7 @@ class ZooMarkTest(TestBase): Admission = "0") # Test re-edits - SDZ = Zoo.select(Zoo.c.Name==u'San Diego Zoo').execute().fetchone() + SDZ = Zoo.select(Zoo.c.Name==u'San Diego Zoo').execute().first() assert SDZ['Founded'] == datetime.date(1935, 9, 13) def test_baseline_7_multiview(self): @@ -316,10 +316,10 @@ class ZooMarkTest(TestBase): global metadata player = lambda: dbapi_session.player() - engine = create_engine('postgres:///', creator=player) + engine = create_engine('postgresql:///', creator=player) metadata = MetaData(engine) - @profiling.function_call_count(3230, {'2.4': 1796}) + @profiling.function_call_count(2991, {'2.4': 1796}) def test_profile_1_create_tables(self): self.test_baseline_1_create_tables() @@ -327,7 +327,7 @@ class ZooMarkTest(TestBase): def test_profile_1a_populate(self): self.test_baseline_1a_populate() - @profiling.function_call_count(322, {'2.4': 202}) + @profiling.function_call_count(305, {'2.4': 202}) def test_profile_2_insert(self): self.test_baseline_2_insert() diff --git a/test/aaa_profiling/test_zoomark_orm.py b/test/aaa_profiling/test_zoomark_orm.py index 57e1e24049..660f478110 100644 --- a/test/aaa_profiling/test_zoomark_orm.py +++ b/test/aaa_profiling/test_zoomark_orm.py @@ -27,7 +27,7 @@ class ZooMarkTest(TestBase): """ - __only_on__ = 'postgres' + __only_on__ = 'postgresql+psycopg2' __skip_if__ = ((lambda: sys.version_info < (2, 5)), ) # TODO: get 2.4 support def test_baseline_0_setup(self): @@ -281,7 +281,7 @@ class ZooMarkTest(TestBase): global metadata, session player = lambda: dbapi_session.player() - engine = create_engine('postgres:///', creator=player) + engine = create_engine('postgresql:///', creator=player) metadata = MetaData(engine) session = sessionmaker()() diff --git a/test/base/test_dependency.py b/test/base/test_dependency.py index 0457d552a4..890dd76078 100644 --- a/test/base/test_dependency.py +++ b/test/base/test_dependency.py @@ -178,9 +178,12 @@ class DependencySortTest(TestBase): self.assert_sort(tuples, head) def testbigsort(self): - tuples = [] - for i in range(0,1500, 2): - tuples.append((i, i+1)) + tuples = [(i, i + 1) for i in range(0, 1500, 2)] head = topological.sort_as_tree(tuples, []) + def testids(self): + # ticket:1380 regression: would raise a KeyError + topological.sort([(id(i), i) for i in range(3)], []) + + diff --git a/test/base/test_except.py b/test/base/test_except.py index efb18a153c..fbe0a05de4 100644 --- a/test/base/test_except.py +++ b/test/base/test_except.py @@ -1,10 +1,14 @@ """Tests exceptions and DB-API exception wrapping.""" -import exceptions as stdlib_exceptions from sqlalchemy import exc as sa_exceptions from sqlalchemy.test import TestBase +# Py3K +#StandardError = BaseException +# Py2K +from exceptions import StandardError, KeyboardInterrupt, SystemExit +# end Py2K -class Error(stdlib_exceptions.StandardError): +class Error(StandardError): """This class will be old-style on <= 2.4 and new-style on >= 2.5.""" class DatabaseError(Error): pass @@ -101,19 +105,19 @@ class WrapTest(TestBase): def test_db_error_keyboard_interrupt(self): try: raise sa_exceptions.DBAPIError.instance( - '', [], stdlib_exceptions.KeyboardInterrupt()) + '', [], KeyboardInterrupt()) except sa_exceptions.DBAPIError: self.assert_(False) - except stdlib_exceptions.KeyboardInterrupt: + except KeyboardInterrupt: self.assert_(True) def test_db_error_system_exit(self): try: raise sa_exceptions.DBAPIError.instance( - '', [], stdlib_exceptions.SystemExit()) + '', [], SystemExit()) except sa_exceptions.DBAPIError: self.assert_(False) - except stdlib_exceptions.SystemExit: + except SystemExit: self.assert_(True) diff --git a/test/base/test_utils.py b/test/base/test_utils.py index 39561e9682..e4c2eaba05 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -3,6 +3,7 @@ import copy, threading from sqlalchemy import util, sql, exc from sqlalchemy.test import TestBase from sqlalchemy.test.testing import eq_, is_, ne_ +from sqlalchemy.test.util import gc_collect class OrderedDictTest(TestBase): def test_odict(self): @@ -260,7 +261,7 @@ class IdentitySetTest(TestBase): except TypeError: assert True - assert_raises(TypeError, cmp, ids) + assert_raises(TypeError, util.cmp, ids) assert_raises(TypeError, hash, ids) def test_difference(self): @@ -325,11 +326,13 @@ class DictlikeIteritemsTest(TestBase): d = subdict(a=1,b=2,c=3) self._ok(d) + # Py2K def test_UserDict(self): import UserDict d = UserDict.UserDict(a=1,b=2,c=3) self._ok(d) - + # end Py2K + def test_object(self): self._notok(object()) @@ -339,12 +342,15 @@ class DictlikeIteritemsTest(TestBase): return iter(self.baseline) self._ok(duck1()) + # Py2K def test_duck_2(self): class duck2(object): def items(duck): return list(self.baseline) self._ok(duck2()) + # end Py2K + # Py2K def test_duck_3(self): class duck3(object): def iterkeys(duck): @@ -352,6 +358,7 @@ class DictlikeIteritemsTest(TestBase): def __getitem__(duck, key): return dict(a=1,b=2,c=3).get(key) self._ok(duck3()) + # end Py2K def test_duck_4(self): class duck4(object): @@ -376,16 +383,20 @@ class DictlikeIteritemsTest(TestBase): class DuckTypeCollectionTest(TestBase): def test_sets(self): + # Py2K import sets + # end Py2K class SetLike(object): def add(self): pass class ForcedSet(list): __emulates__ = set - + for type_ in (set, + # Py2K sets.Set, + # end Py2K SetLike, ForcedSet): eq_(util.duck_type_collection(type_), set) @@ -393,12 +404,14 @@ class DuckTypeCollectionTest(TestBase): eq_(util.duck_type_collection(instance), set) for type_ in (frozenset, - sets.ImmutableSet): + # Py2K + sets.ImmutableSet + # end Py2K + ): is_(util.duck_type_collection(type_), None) instance = type_() is_(util.duck_type_collection(instance), None) - class ArgInspectionTest(TestBase): def test_get_cls_kwargs(self): class A(object): @@ -646,6 +659,8 @@ class WeakIdentityMappingTest(TestBase): assert len(data) == len(wim) == len(wim.by_id) del data[:] + gc_collect() + eq_(wim, {}) eq_(wim.by_id, {}) eq_(wim._weakrefs, {}) @@ -657,6 +672,7 @@ class WeakIdentityMappingTest(TestBase): oid = id(data[0]) del data[0] + gc_collect() assert len(data) == len(wim) == len(wim.by_id) assert oid not in wim.by_id @@ -679,6 +695,7 @@ class WeakIdentityMappingTest(TestBase): th.start() cv.wait() cv.release() + gc_collect() eq_(wim, {}) eq_(wim.by_id, {}) @@ -939,7 +956,8 @@ class TestClassHierarchy(TestBase): eq_(set(util.class_hierarchy(A)), set((A, B, C, object))) eq_(set(util.class_hierarchy(B)), set((A, B, C, object))) - + + # Py2K def test_oldstyle_mixin(self): class A(object): pass @@ -953,5 +971,5 @@ class TestClassHierarchy(TestBase): eq_(set(util.class_hierarchy(B)), set((A, B, object))) eq_(set(util.class_hierarchy(Mixin)), set()) eq_(set(util.class_hierarchy(A)), set((A, B, object))) - + # end Py2K diff --git a/test/dialect/test_firebird.py b/test/dialect/test_firebird.py index fa608c9a18..2dc6af91b7 100644 --- a/test/dialect/test_firebird.py +++ b/test/dialect/test_firebird.py @@ -50,6 +50,7 @@ class DomainReflectionTest(TestBase, AssertsExecutionResults): con.execute('DROP GENERATOR gen_testtable_id') def test_table_is_reflected(self): + from sqlalchemy.types import Integer, Text, Binary, String, Date, Time, DateTime metadata = MetaData(testing.db) table = Table('testtable', metadata, autoload=True) eq_(set(table.columns.keys()), @@ -57,17 +58,17 @@ class DomainReflectionTest(TestBase, AssertsExecutionResults): "Columns of reflected table didn't equal expected columns") eq_(table.c.question.primary_key, True) eq_(table.c.question.sequence.name, 'gen_testtable_id') - eq_(table.c.question.type.__class__, firebird.FBInteger) + assert isinstance(table.c.question.type, Integer) eq_(table.c.question.server_default.arg.text, "42") - eq_(table.c.answer.type.__class__, firebird.FBString) + assert isinstance(table.c.answer.type, String) eq_(table.c.answer.server_default.arg.text, "'no answer'") - eq_(table.c.remark.type.__class__, firebird.FBText) + assert isinstance(table.c.remark.type, Text) eq_(table.c.remark.server_default.arg.text, "''") - eq_(table.c.photo.type.__class__, firebird.FBBinary) + assert isinstance(table.c.photo.type, Binary) # The following assume a Dialect 3 database - eq_(table.c.d.type.__class__, firebird.FBDate) - eq_(table.c.t.type.__class__, firebird.FBTime) - eq_(table.c.dt.type.__class__, firebird.FBDateTime) + assert isinstance(table.c.d.type, Date) + assert isinstance(table.c.t.type, Time) + assert isinstance(table.c.dt.type, DateTime) class CompileTest(TestBase, AssertsCompiledSQL): @@ -76,7 +77,13 @@ class CompileTest(TestBase, AssertsCompiledSQL): def test_alias(self): t = table('sometable', column('col1'), column('col2')) s = select([t.alias()]) - self.assert_compile(s, "SELECT sometable_1.col1, sometable_1.col2 FROM sometable sometable_1") + self.assert_compile(s, "SELECT sometable_1.col1, sometable_1.col2 FROM sometable AS sometable_1") + + dialect = firebird.FBDialect() + dialect._version_two = False + self.assert_compile(s, "SELECT sometable_1.col1, sometable_1.col2 FROM sometable sometable_1", + dialect = dialect + ) def test_function(self): self.assert_compile(func.foo(1, 2), "foo(:foo_1, :foo_2)") @@ -98,15 +105,15 @@ class CompileTest(TestBase, AssertsCompiledSQL): column('description', String(128)), ) - u = update(table1, values=dict(name='foo'), firebird_returning=[table1.c.myid, table1.c.name]) + u = update(table1, values=dict(name='foo')).returning(table1.c.myid, table1.c.name) self.assert_compile(u, "UPDATE mytable SET name=:name RETURNING mytable.myid, mytable.name") - u = update(table1, values=dict(name='foo'), firebird_returning=[table1]) + u = update(table1, values=dict(name='foo')).returning(table1) self.assert_compile(u, "UPDATE mytable SET name=:name "\ "RETURNING mytable.myid, mytable.name, mytable.description") - u = update(table1, values=dict(name='foo'), firebird_returning=[func.length(table1.c.name)]) - self.assert_compile(u, "UPDATE mytable SET name=:name RETURNING char_length(mytable.name)") + u = update(table1, values=dict(name='foo')).returning(func.length(table1.c.name)) + self.assert_compile(u, "UPDATE mytable SET name=:name RETURNING char_length(mytable.name) AS length_1") def test_insert_returning(self): table1 = table('mytable', @@ -115,90 +122,20 @@ class CompileTest(TestBase, AssertsCompiledSQL): column('description', String(128)), ) - i = insert(table1, values=dict(name='foo'), firebird_returning=[table1.c.myid, table1.c.name]) + i = insert(table1, values=dict(name='foo')).returning(table1.c.myid, table1.c.name) self.assert_compile(i, "INSERT INTO mytable (name) VALUES (:name) RETURNING mytable.myid, mytable.name") - i = insert(table1, values=dict(name='foo'), firebird_returning=[table1]) + i = insert(table1, values=dict(name='foo')).returning(table1) self.assert_compile(i, "INSERT INTO mytable (name) VALUES (:name) "\ "RETURNING mytable.myid, mytable.name, mytable.description") - i = insert(table1, values=dict(name='foo'), firebird_returning=[func.length(table1.c.name)]) - self.assert_compile(i, "INSERT INTO mytable (name) VALUES (:name) RETURNING char_length(mytable.name)") + i = insert(table1, values=dict(name='foo')).returning(func.length(table1.c.name)) + self.assert_compile(i, "INSERT INTO mytable (name) VALUES (:name) RETURNING char_length(mytable.name) AS length_1") -class ReturningTest(TestBase, AssertsExecutionResults): - __only_on__ = 'firebird' - - @testing.exclude('firebird', '<', (2, 1), '2.1+ feature') - def test_update_returning(self): - meta = MetaData(testing.db) - table = Table('tables', meta, - Column('id', Integer, Sequence('gen_tables_id'), primary_key=True), - Column('persons', Integer), - Column('full', Boolean) - ) - table.create() - try: - table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}]) - - result = table.update(table.c.persons > 4, dict(full=True), firebird_returning=[table.c.id]).execute() - eq_(result.fetchall(), [(1,)]) - - result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute() - eq_(result2.fetchall(), [(1,True),(2,False)]) - finally: - table.drop() - - @testing.exclude('firebird', '<', (2, 0), '2.0+ feature') - def test_insert_returning(self): - meta = MetaData(testing.db) - table = Table('tables', meta, - Column('id', Integer, Sequence('gen_tables_id'), primary_key=True), - Column('persons', Integer), - Column('full', Boolean) - ) - table.create() - try: - result = table.insert(firebird_returning=[table.c.id]).execute({'persons': 1, 'full': False}) - - eq_(result.fetchall(), [(1,)]) - - # Multiple inserts only return the last row - result2 = table.insert(firebird_returning=[table]).execute( - [{'persons': 2, 'full': False}, {'persons': 3, 'full': True}]) - - eq_(result2.fetchall(), [(3,3,True)]) - - result3 = table.insert(firebird_returning=[table.c.id]).execute({'persons': 4, 'full': False}) - eq_([dict(row) for row in result3], [{'ID':4}]) - - result4 = testing.db.execute('insert into tables (id, persons, "full") values (5, 10, 1) returning persons') - eq_([dict(row) for row in result4], [{'PERSONS': 10}]) - finally: - table.drop() - - @testing.exclude('firebird', '<', (2, 1), '2.1+ feature') - def test_delete_returning(self): - meta = MetaData(testing.db) - table = Table('tables', meta, - Column('id', Integer, Sequence('gen_tables_id'), primary_key=True), - Column('persons', Integer), - Column('full', Boolean) - ) - table.create() - try: - table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}]) - - result = table.delete(table.c.persons > 4, firebird_returning=[table.c.id]).execute() - eq_(result.fetchall(), [(1,)]) - - result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute() - eq_(result2.fetchall(), [(2,False),]) - finally: - table.drop() -class MiscFBTests(TestBase): +class MiscTest(TestBase): __only_on__ = 'firebird' def test_strlen(self): @@ -217,12 +154,20 @@ class MiscFBTests(TestBase): try: t.insert(values=dict(name='dante')).execute() t.insert(values=dict(name='alighieri')).execute() - select([func.count(t.c.id)],func.length(t.c.name)==5).execute().fetchone()[0] == 1 + select([func.count(t.c.id)],func.length(t.c.name)==5).execute().first()[0] == 1 finally: meta.drop_all() def test_server_version_info(self): - version = testing.db.dialect.server_version_info(testing.db.connect()) + version = testing.db.dialect.server_version_info assert len(version) == 3, "Got strange version info: %s" % repr(version) + def test_percents_in_text(self): + for expr, result in ( + (text("select '%' from rdb$database"), '%'), + (text("select '%%' from rdb$database"), '%%'), + (text("select '%%%' from rdb$database"), '%%%'), + (text("select 'hello % world' from rdb$database"), "hello % world") + ): + eq_(testing.db.scalar(expr), result) diff --git a/test/dialect/test_informix.py b/test/dialect/test_informix.py index 86a4e751d4..e647990d31 100644 --- a/test/dialect/test_informix.py +++ b/test/dialect/test_informix.py @@ -4,7 +4,8 @@ from sqlalchemy.test import * class CompileTest(TestBase, AssertsCompiledSQL): - __dialect__ = informix.InfoDialect() + __only_on__ = 'informix' + __dialect__ = informix.InformixDialect() def test_statements(self): meta =MetaData() diff --git a/test/dialect/test_maxdb.py b/test/dialect/test_maxdb.py index 033a05533f..c69a81120f 100644 --- a/test/dialect/test_maxdb.py +++ b/test/dialect/test_maxdb.py @@ -185,7 +185,7 @@ class DBAPITest(TestBase, AssertsExecutionResults): vals = [] for i in xrange(3): cr.execute('SELECT busto.NEXTVAL FROM DUAL') - vals.append(cr.fetchone()[0]) + vals.append(cr.first()[0]) # should be 1,2,3, but no... self.assert_(vals != [1,2,3]) diff --git a/test/dialect/test_mssql.py b/test/dialect/test_mssql.py index dd86ce0de2..423310db62 100644 --- a/test/dialect/test_mssql.py +++ b/test/dialect/test_mssql.py @@ -2,17 +2,18 @@ from sqlalchemy.test.testing import eq_ import datetime, os, re from sqlalchemy import * -from sqlalchemy import types, exc +from sqlalchemy import types, exc, schema from sqlalchemy.orm import * from sqlalchemy.sql import table, column from sqlalchemy.databases import mssql -import sqlalchemy.engine.url as url +from sqlalchemy.dialects.mssql import pyodbc +from sqlalchemy.engine import url from sqlalchemy.test import * from sqlalchemy.test.testing import eq_ class CompileTest(TestBase, AssertsCompiledSQL): - __dialect__ = mssql.MSSQLDialect() + __dialect__ = mssql.dialect() def test_insert(self): t = table('sometable', column('somecolumn')) @@ -157,6 +158,45 @@ class CompileTest(TestBase, AssertsCompiledSQL): select([extract(field, t.c.col1)]), 'SELECT DATEPART("%s", t.col1) AS anon_1 FROM t' % field) + def test_update_returning(self): + table1 = table('mytable', + column('myid', Integer), + column('name', String(128)), + column('description', String(128)), + ) + + u = update(table1, values=dict(name='foo')).returning(table1.c.myid, table1.c.name) + self.assert_compile(u, "UPDATE mytable SET name=:name OUTPUT inserted.myid, inserted.name") + + u = update(table1, values=dict(name='foo')).returning(table1) + self.assert_compile(u, "UPDATE mytable SET name=:name OUTPUT inserted.myid, " + "inserted.name, inserted.description") + + u = update(table1, values=dict(name='foo')).returning(table1).where(table1.c.name=='bar') + self.assert_compile(u, "UPDATE mytable SET name=:name OUTPUT inserted.myid, " + "inserted.name, inserted.description WHERE mytable.name = :name_1") + + u = update(table1, values=dict(name='foo')).returning(func.length(table1.c.name)) + self.assert_compile(u, "UPDATE mytable SET name=:name OUTPUT LEN(inserted.name) AS length_1") + + def test_insert_returning(self): + table1 = table('mytable', + column('myid', Integer), + column('name', String(128)), + column('description', String(128)), + ) + + i = insert(table1, values=dict(name='foo')).returning(table1.c.myid, table1.c.name) + self.assert_compile(i, "INSERT INTO mytable (name) OUTPUT inserted.myid, inserted.name VALUES (:name)") + + i = insert(table1, values=dict(name='foo')).returning(table1) + self.assert_compile(i, "INSERT INTO mytable (name) OUTPUT inserted.myid, " + "inserted.name, inserted.description VALUES (:name)") + + i = insert(table1, values=dict(name='foo')).returning(func.length(table1.c.name)) + self.assert_compile(i, "INSERT INTO mytable (name) OUTPUT LEN(inserted.name) AS length_1 VALUES (:name)") + + class IdentityInsertTest(TestBase, AssertsCompiledSQL): __only_on__ = 'mssql' @@ -189,9 +229,9 @@ class IdentityInsertTest(TestBase, AssertsCompiledSQL): eq_([(9, 'Python')], list(cats)) result = cattable.insert().values(description='PHP').execute() - eq_([10], result.last_inserted_ids()) + eq_([10], result.inserted_primary_key) lastcat = cattable.select().order_by(desc(cattable.c.id)).execute() - eq_((10, 'PHP'), lastcat.fetchone()) + eq_((10, 'PHP'), lastcat.first()) def test_executemany(self): cattable.insert().execute([ @@ -213,10 +253,51 @@ class IdentityInsertTest(TestBase, AssertsCompiledSQL): eq_([(91, 'Smalltalk'), (90, 'PHP')], list(lastcats)) -class ReflectionTest(TestBase): +class ReflectionTest(TestBase, ComparesTables): __only_on__ = 'mssql' - def testidentity(self): + def test_basic_reflection(self): + meta = MetaData(testing.db) + + users = Table('engine_users', meta, + Column('user_id', types.INT, primary_key=True), + Column('user_name', types.VARCHAR(20), nullable=False), + Column('test1', types.CHAR(5), nullable=False), + Column('test2', types.Float(5), nullable=False), + Column('test3', types.Text), + Column('test4', types.Numeric, nullable = False), + Column('test5', types.DateTime), + Column('parent_user_id', types.Integer, + ForeignKey('engine_users.user_id')), + Column('test6', types.DateTime, nullable=False), + Column('test7', types.Text), + Column('test8', types.Binary), + Column('test_passivedefault2', types.Integer, server_default='5'), + Column('test9', types.Binary(100)), + Column('test_numeric', types.Numeric()), + test_needs_fk=True, + ) + + addresses = Table('engine_email_addresses', meta, + Column('address_id', types.Integer, primary_key = True), + Column('remote_user_id', types.Integer, ForeignKey(users.c.user_id)), + Column('email_address', types.String(20)), + test_needs_fk=True, + ) + meta.create_all() + + try: + meta2 = MetaData() + reflected_users = Table('engine_users', meta2, autoload=True, + autoload_with=testing.db) + reflected_addresses = Table('engine_email_addresses', meta2, + autoload=True, autoload_with=testing.db) + self.assert_tables_equal(users, reflected_users) + self.assert_tables_equal(addresses, reflected_addresses) + finally: + meta.drop_all() + + def test_identity(self): meta = MetaData(testing.db) table = Table( 'identity_test', meta, @@ -240,7 +321,7 @@ class QueryUnicodeTest(TestBase): meta = MetaData(testing.db) t1 = Table('unitest_table', meta, Column('id', Integer, primary_key=True), - Column('descr', mssql.MSText(200, convert_unicode=True))) + Column('descr', mssql.MSText(convert_unicode=True))) meta.create_all() con = testing.db.connect() @@ -248,7 +329,7 @@ class QueryUnicodeTest(TestBase): con.execute(u"insert into unitest_table values ('bien mangé')".encode('UTF-8')) try: - r = t1.select().execute().fetchone() + r = t1.select().execute().first() assert isinstance(r[1], unicode), '%s is %s instead of unicode, working on %s' % ( r[1], type(r[1]), meta.bind) @@ -262,7 +343,9 @@ class QueryTest(TestBase): meta = MetaData(testing.db) t1 = Table('t1', meta, Column('id', Integer, Sequence('fred', 100, 1), primary_key=True), - Column('descr', String(200))) + Column('descr', String(200)), + implicit_returning = False + ) t2 = Table('t2', meta, Column('id', Integer, Sequence('fred', 200, 1), primary_key=True), Column('descr', String(200))) @@ -274,9 +357,9 @@ class QueryTest(TestBase): try: tr = con.begin() r = con.execute(t2.insert(), descr='hello') - self.assert_(r.last_inserted_ids() == [200]) + self.assert_(r.inserted_primary_key == [200]) r = con.execute(t1.insert(), descr='hello') - self.assert_(r.last_inserted_ids() == [100]) + self.assert_(r.inserted_primary_key == [100]) finally: tr.commit() @@ -295,6 +378,19 @@ class QueryTest(TestBase): tbl.drop() con.execute('drop schema paj') + def test_returning_no_autoinc(self): + meta = MetaData(testing.db) + + table = Table('t1', meta, Column('id', Integer, primary_key=True), Column('data', String(50))) + table.create() + try: + result = table.insert().values(id=1, data=func.lower("SomeString")).returning(table.c.id, table.c.data).execute() + eq_(result.fetchall(), [(1, 'somestring',)]) + finally: + # this will hang if the "SET IDENTITY_INSERT t1 OFF" occurs before the + # result is fetched + table.drop() + def test_delete_schema(self): meta = MetaData(testing.db) con = testing.db.connect() @@ -371,36 +467,26 @@ class SchemaTest(TestBase): ) self.column = t.c.test_column + dialect = mssql.dialect() + self.ddl_compiler = dialect.ddl_compiler(dialect, schema.CreateTable(t)) + + def _column_spec(self): + return self.ddl_compiler.get_column_specification(self.column) + def test_that_mssql_default_nullability_emits_null(self): - schemagenerator = \ - mssql.MSSQLDialect().schemagenerator(mssql.MSSQLDialect(), None) - column_specification = \ - schemagenerator.get_column_specification(self.column) - eq_("test_column VARCHAR NULL", column_specification) + eq_("test_column VARCHAR NULL", self._column_spec()) def test_that_mssql_none_nullability_does_not_emit_nullability(self): - schemagenerator = \ - mssql.MSSQLDialect().schemagenerator(mssql.MSSQLDialect(), None) self.column.nullable = None - column_specification = \ - schemagenerator.get_column_specification(self.column) - eq_("test_column VARCHAR", column_specification) + eq_("test_column VARCHAR", self._column_spec()) def test_that_mssql_specified_nullable_emits_null(self): - schemagenerator = \ - mssql.MSSQLDialect().schemagenerator(mssql.MSSQLDialect(), None) self.column.nullable = True - column_specification = \ - schemagenerator.get_column_specification(self.column) - eq_("test_column VARCHAR NULL", column_specification) + eq_("test_column VARCHAR NULL", self._column_spec()) def test_that_mssql_specified_not_nullable_emits_not_null(self): - schemagenerator = \ - mssql.MSSQLDialect().schemagenerator(mssql.MSSQLDialect(), None) self.column.nullable = False - column_specification = \ - schemagenerator.get_column_specification(self.column) - eq_("test_column VARCHAR NOT NULL", column_specification) + eq_("test_column VARCHAR NOT NULL", self._column_spec()) def full_text_search_missing(): @@ -515,79 +601,73 @@ class MatchTest(TestBase, AssertsCompiledSQL): class ParseConnectTest(TestBase, AssertsCompiledSQL): __only_on__ = 'mssql' + @classmethod + def setup_class(cls): + global dialect + dialect = pyodbc.MSDialect_pyodbc() + def test_pyodbc_connect_dsn_trusted(self): u = url.make_url('mssql://mydsn') - dialect = mssql.MSSQLDialect_pyodbc() connection = dialect.create_connect_args(u) eq_([['dsn=mydsn;TrustedConnection=Yes'], {}], connection) def test_pyodbc_connect_old_style_dsn_trusted(self): u = url.make_url('mssql:///?dsn=mydsn') - dialect = mssql.MSSQLDialect_pyodbc() connection = dialect.create_connect_args(u) eq_([['dsn=mydsn;TrustedConnection=Yes'], {}], connection) def test_pyodbc_connect_dsn_non_trusted(self): u = url.make_url('mssql://username:password@mydsn') - dialect = mssql.MSSQLDialect_pyodbc() connection = dialect.create_connect_args(u) eq_([['dsn=mydsn;UID=username;PWD=password'], {}], connection) def test_pyodbc_connect_dsn_extra(self): u = url.make_url('mssql://username:password@mydsn/?LANGUAGE=us_english&foo=bar') - dialect = mssql.MSSQLDialect_pyodbc() connection = dialect.create_connect_args(u) eq_([['dsn=mydsn;UID=username;PWD=password;LANGUAGE=us_english;foo=bar'], {}], connection) def test_pyodbc_connect(self): u = url.make_url('mssql://username:password@hostspec/database') - dialect = mssql.MSSQLDialect_pyodbc() connection = dialect.create_connect_args(u) eq_([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password'], {}], connection) def test_pyodbc_connect_comma_port(self): u = url.make_url('mssql://username:password@hostspec:12345/database') - dialect = mssql.MSSQLDialect_pyodbc() connection = dialect.create_connect_args(u) eq_([['DRIVER={SQL Server};Server=hostspec,12345;Database=database;UID=username;PWD=password'], {}], connection) def test_pyodbc_connect_config_port(self): u = url.make_url('mssql://username:password@hostspec/database?port=12345') - dialect = mssql.MSSQLDialect_pyodbc() connection = dialect.create_connect_args(u) eq_([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password;port=12345'], {}], connection) def test_pyodbc_extra_connect(self): u = url.make_url('mssql://username:password@hostspec/database?LANGUAGE=us_english&foo=bar') - dialect = mssql.MSSQLDialect_pyodbc() connection = dialect.create_connect_args(u) eq_([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password;foo=bar;LANGUAGE=us_english'], {}], connection) def test_pyodbc_odbc_connect(self): u = url.make_url('mssql:///?odbc_connect=DRIVER%3D%7BSQL+Server%7D%3BServer%3Dhostspec%3BDatabase%3Ddatabase%3BUID%3Dusername%3BPWD%3Dpassword') - dialect = mssql.MSSQLDialect_pyodbc() connection = dialect.create_connect_args(u) eq_([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password'], {}], connection) def test_pyodbc_odbc_connect_with_dsn(self): u = url.make_url('mssql:///?odbc_connect=dsn%3Dmydsn%3BDatabase%3Ddatabase%3BUID%3Dusername%3BPWD%3Dpassword') - dialect = mssql.MSSQLDialect_pyodbc() connection = dialect.create_connect_args(u) eq_([['dsn=mydsn;Database=database;UID=username;PWD=password'], {}], connection) def test_pyodbc_odbc_connect_ignores_other_values(self): u = url.make_url('mssql://userdiff:passdiff@localhost/dbdiff?odbc_connect=DRIVER%3D%7BSQL+Server%7D%3BServer%3Dhostspec%3BDatabase%3Ddatabase%3BUID%3Dusername%3BPWD%3Dpassword') - dialect = mssql.MSSQLDialect_pyodbc() connection = dialect.create_connect_args(u) eq_([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password'], {}], connection) -class TypesTest(TestBase): +class TypesTest(TestBase, AssertsExecutionResults, ComparesTables): __only_on__ = 'mssql' @classmethod def setup_class(cls): - global numeric_table, metadata + global metadata metadata = MetaData(testing.db) def teardown(self): @@ -601,26 +681,22 @@ class TypesTest(TestBase): ) metadata.create_all() - try: - test_items = [decimal.Decimal(d) for d in '1500000.00000000000000000000', - '-1500000.00000000000000000000', '1500000', - '0.0000000000000000002', '0.2', '-0.0000000000000000002', '-2E-2', - '156666.458923543', '-156666.458923543', '1', '-1', '-1234', '1234', - '2E-12', '4E8', '3E-6', '3E-7', '4.1', '1E-1', '1E-2', '1E-3', - '1E-4', '1E-5', '1E-6', '1E-7', '1E-1', '1E-8', '0.2732E2', '-0.2432E2', '4.35656E2', - '-02452E-2', '45125E-2', - '1234.58965E-2', '1.521E+15', '-1E-25', '1E-25', '1254E-25', '-1203E-25', - '0', '-0.00', '-0', '4585E12', '000000000000000000012', '000000000000.32E12', - '00000000000000.1E+12', '000000000000.2E-32'] + test_items = [decimal.Decimal(d) for d in '1500000.00000000000000000000', + '-1500000.00000000000000000000', '1500000', + '0.0000000000000000002', '0.2', '-0.0000000000000000002', '-2E-2', + '156666.458923543', '-156666.458923543', '1', '-1', '-1234', '1234', + '2E-12', '4E8', '3E-6', '3E-7', '4.1', '1E-1', '1E-2', '1E-3', + '1E-4', '1E-5', '1E-6', '1E-7', '1E-1', '1E-8', '0.2732E2', '-0.2432E2', '4.35656E2', + '-02452E-2', '45125E-2', + '1234.58965E-2', '1.521E+15', '-1E-25', '1E-25', '1254E-25', '-1203E-25', + '0', '-0.00', '-0', '4585E12', '000000000000000000012', '000000000000.32E12', + '00000000000000.1E+12', '000000000000.2E-32'] - for value in test_items: - numeric_table.insert().execute(numericcol=value) + for value in test_items: + numeric_table.insert().execute(numericcol=value) - for value in select([numeric_table.c.numericcol]).execute(): - assert value[0] in test_items, "%s not in test_items" % value[0] - - except Exception, e: - raise e + for value in select([numeric_table.c.numericcol]).execute(): + assert value[0] in test_items, "%s not in test_items" % value[0] def test_float(self): float_table = Table('float_table', metadata, @@ -643,11 +719,6 @@ class TypesTest(TestBase): raise e -class TypesTest2(TestBase, AssertsExecutionResults): - "Test Microsoft SQL Server column types" - - __only_on__ = 'mssql' - def test_money(self): "Exercise type specification for money types." @@ -659,13 +730,14 @@ class TypesTest2(TestBase, AssertsExecutionResults): 'SMALLMONEY'), ] - table_args = ['test_mssql_money', MetaData(testing.db)] + table_args = ['test_mssql_money', metadata] for index, spec in enumerate(columns): type_, args, kw, res = spec table_args.append(Column('c%s' % index, type_(*args, **kw), nullable=None)) money_table = Table(*table_args) - gen = testing.db.dialect.schemagenerator(testing.db.dialect, testing.db, None, None) + dialect = mssql.dialect() + gen = dialect.ddl_compiler(dialect, schema.CreateTable(money_table)) for col in money_table.c: index = int(col.name[1:]) @@ -688,15 +760,27 @@ class TypesTest2(TestBase, AssertsExecutionResults): (mssql.MSDateTime, [], {}, 'DATETIME', []), + (types.DATE, [], {}, + 'DATE', ['>=', (10,)]), + (types.Date, [], {}, + 'DATE', ['>=', (10,)]), + (types.Date, [], {}, + 'DATETIME', ['<', (10,)], mssql.MSDateTime), (mssql.MSDate, [], {}, 'DATE', ['>=', (10,)]), (mssql.MSDate, [], {}, 'DATETIME', ['<', (10,)], mssql.MSDateTime), + (types.TIME, [], {}, + 'TIME', ['>=', (10,)]), + (types.Time, [], {}, + 'TIME', ['>=', (10,)]), (mssql.MSTime, [], {}, 'TIME', ['>=', (10,)]), (mssql.MSTime, [1], {}, 'TIME(1)', ['>=', (10,)]), + (types.Time, [], {}, + 'DATETIME', ['<', (10,)], mssql.MSDateTime), (mssql.MSTime, [], {}, 'DATETIME', ['<', (10,)], mssql.MSDateTime), @@ -715,14 +799,14 @@ class TypesTest2(TestBase, AssertsExecutionResults): ] - table_args = ['test_mssql_dates', MetaData(testing.db)] + table_args = ['test_mssql_dates', metadata] for index, spec in enumerate(columns): type_, args, kw, res, requires = spec[0:5] if (requires and testing._is_excluded('mssql', *requires)) or not requires: table_args.append(Column('c%s' % index, type_(*args, **kw), nullable=None)) dates_table = Table(*table_args) - gen = testing.db.dialect.schemagenerator(testing.db.dialect, testing.db, None, None) + gen = testing.db.dialect.ddl_compiler(testing.db.dialect, schema.CreateTable(dates_table)) for col in dates_table.c: index = int(col.name[1:]) @@ -730,49 +814,37 @@ class TypesTest2(TestBase, AssertsExecutionResults): "%s %s" % (col.name, columns[index][3])) self.assert_(repr(col)) - try: - dates_table.create(checkfirst=True) - assert True - except: - raise + dates_table.create(checkfirst=True) reflected_dates = Table('test_mssql_dates', MetaData(testing.db), autoload=True) for col in reflected_dates.c: - index = int(col.name[1:]) - testing.eq_(testing.db.dialect.type_descriptor(col.type).__class__, - len(columns[index]) > 5 and columns[index][5] or columns[index][0]) - dates_table.drop() - - def test_dates2(self): - meta = MetaData(testing.db) - t = Table('test_dates', meta, - Column('id', Integer, - Sequence('datetest_id_seq', optional=True), - primary_key=True), - Column('adate', Date), - Column('atime', Time), - Column('adatetime', DateTime)) - t.create(checkfirst=True) - try: - d1 = datetime.date(2007, 10, 30) - t1 = datetime.time(11, 2, 32) - d2 = datetime.datetime(2007, 10, 30, 11, 2, 32) - t.insert().execute(adate=d1, adatetime=d2, atime=t1) - t.insert().execute(adate=d2, adatetime=d2, atime=d2) + self.assert_types_base(col, dates_table.c[col.key]) - x = t.select().execute().fetchall()[0] - self.assert_(x.adate.__class__ == datetime.date) - self.assert_(x.atime.__class__ == datetime.time) - self.assert_(x.adatetime.__class__ == datetime.datetime) + def test_date_roundtrip(self): + t = Table('test_dates', metadata, + Column('id', Integer, + Sequence('datetest_id_seq', optional=True), + primary_key=True), + Column('adate', Date), + Column('atime', Time), + Column('adatetime', DateTime)) + metadata.create_all() + d1 = datetime.date(2007, 10, 30) + t1 = datetime.time(11, 2, 32) + d2 = datetime.datetime(2007, 10, 30, 11, 2, 32) + t.insert().execute(adate=d1, adatetime=d2, atime=t1) + t.insert().execute(adate=d2, adatetime=d2, atime=d2) - t.delete().execute() + x = t.select().execute().fetchall()[0] + self.assert_(x.adate.__class__ == datetime.date) + self.assert_(x.atime.__class__ == datetime.time) + self.assert_(x.adatetime.__class__ == datetime.datetime) - t.insert().execute(adate=d1, adatetime=d2, atime=t1) + t.delete().execute() - eq_(select([t.c.adate, t.c.atime, t.c.adatetime], t.c.adate==d1).execute().fetchall(), [(d1, t1, d2)]) + t.insert().execute(adate=d1, adatetime=d2, atime=t1) - finally: - t.drop(checkfirst=True) + eq_(select([t.c.adate, t.c.atime, t.c.adatetime], t.c.adate==d1).execute().fetchall(), [(d1, t1, d2)]) def test_binary(self): "Exercise type specification for binary types." @@ -781,6 +853,9 @@ class TypesTest2(TestBase, AssertsExecutionResults): # column type, args, kwargs, expected ddl (mssql.MSBinary, [], {}, 'BINARY'), + (types.Binary, [10], {}, + 'BINARY(10)'), + (mssql.MSBinary, [10], {}, 'BINARY(10)'), @@ -798,13 +873,14 @@ class TypesTest2(TestBase, AssertsExecutionResults): 'BINARY(10)') ] - table_args = ['test_mssql_binary', MetaData(testing.db)] + table_args = ['test_mssql_binary', metadata] for index, spec in enumerate(columns): type_, args, kw, res = spec table_args.append(Column('c%s' % index, type_(*args, **kw), nullable=None)) binary_table = Table(*table_args) - gen = testing.db.dialect.schemagenerator(testing.db.dialect, testing.db, None, None) + dialect = mssql.dialect() + gen = dialect.ddl_compiler(dialect, schema.CreateTable(binary_table)) for col in binary_table.c: index = int(col.name[1:]) @@ -812,22 +888,15 @@ class TypesTest2(TestBase, AssertsExecutionResults): "%s %s" % (col.name, columns[index][3])) self.assert_(repr(col)) - try: - binary_table.create(checkfirst=True) - assert True - except: - raise + metadata.create_all() reflected_binary = Table('test_mssql_binary', MetaData(testing.db), autoload=True) for col in reflected_binary.c: - # don't test the MSGenericBinary since it's a special case and - # reflected it will map to a MSImage or MSBinary depending - if not testing.db.dialect.type_descriptor(binary_table.c[col.name].type).__class__ == mssql.MSGenericBinary: - testing.eq_(testing.db.dialect.type_descriptor(col.type).__class__, - testing.db.dialect.type_descriptor(binary_table.c[col.name].type).__class__) + c1 =testing.db.dialect.type_descriptor(col.type).__class__ + c2 =testing.db.dialect.type_descriptor(binary_table.c[col.name].type).__class__ + assert issubclass(c1, c2), "%r is not a subclass of %r" % (c1, c2) if binary_table.c[col.name].type.length: testing.eq_(col.type.length, binary_table.c[col.name].type.length) - binary_table.drop() def test_boolean(self): "Exercise type specification for boolean type." @@ -838,13 +907,14 @@ class TypesTest2(TestBase, AssertsExecutionResults): 'BIT'), ] - table_args = ['test_mssql_boolean', MetaData(testing.db)] + table_args = ['test_mssql_boolean', metadata] for index, spec in enumerate(columns): type_, args, kw, res = spec table_args.append(Column('c%s' % index, type_(*args, **kw), nullable=None)) boolean_table = Table(*table_args) - gen = testing.db.dialect.schemagenerator(testing.db.dialect, testing.db, None, None) + dialect = mssql.dialect() + gen = dialect.ddl_compiler(dialect, schema.CreateTable(boolean_table)) for col in boolean_table.c: index = int(col.name[1:]) @@ -852,12 +922,7 @@ class TypesTest2(TestBase, AssertsExecutionResults): "%s %s" % (col.name, columns[index][3])) self.assert_(repr(col)) - try: - boolean_table.create(checkfirst=True) - assert True - except: - raise - boolean_table.drop() + metadata.create_all() def test_numeric(self): "Exercise type specification and options for numeric types." @@ -865,40 +930,39 @@ class TypesTest2(TestBase, AssertsExecutionResults): columns = [ # column type, args, kwargs, expected ddl (mssql.MSNumeric, [], {}, - 'NUMERIC(10, 2)'), + 'NUMERIC'), (mssql.MSNumeric, [None], {}, 'NUMERIC'), - (mssql.MSNumeric, [12], {}, - 'NUMERIC(12, 2)'), (mssql.MSNumeric, [12, 4], {}, 'NUMERIC(12, 4)'), - (mssql.MSFloat, [], {}, - 'FLOAT(10)'), - (mssql.MSFloat, [None], {}, + (types.Float, [], {}, + 'FLOAT'), + (types.Float, [None], {}, 'FLOAT'), - (mssql.MSFloat, [12], {}, + (types.Float, [12], {}, 'FLOAT(12)'), (mssql.MSReal, [], {}, 'REAL'), - (mssql.MSInteger, [], {}, + (types.Integer, [], {}, 'INTEGER'), - (mssql.MSBigInteger, [], {}, + (types.BigInteger, [], {}, 'BIGINT'), (mssql.MSTinyInteger, [], {}, 'TINYINT'), - (mssql.MSSmallInteger, [], {}, + (types.SmallInteger, [], {}, 'SMALLINT'), ] - table_args = ['test_mssql_numeric', MetaData(testing.db)] + table_args = ['test_mssql_numeric', metadata] for index, spec in enumerate(columns): type_, args, kw, res = spec table_args.append(Column('c%s' % index, type_(*args, **kw), nullable=None)) numeric_table = Table(*table_args) - gen = testing.db.dialect.schemagenerator(testing.db.dialect, testing.db, None, None) + dialect = mssql.dialect() + gen = dialect.ddl_compiler(dialect, schema.CreateTable(numeric_table)) for col in numeric_table.c: index = int(col.name[1:]) @@ -906,20 +970,11 @@ class TypesTest2(TestBase, AssertsExecutionResults): "%s %s" % (col.name, columns[index][3])) self.assert_(repr(col)) - try: - numeric_table.create(checkfirst=True) - assert True - except: - raise - numeric_table.drop() + metadata.create_all() def test_char(self): """Exercise COLLATE-ish options on string types.""" - # modify the text_as_varchar setting since we are not testing that behavior here - text_as_varchar = testing.db.dialect.text_as_varchar - testing.db.dialect.text_as_varchar = False - columns = [ (mssql.MSChar, [], {}, 'CHAR'), @@ -960,13 +1015,14 @@ class TypesTest2(TestBase, AssertsExecutionResults): 'NTEXT COLLATE Latin1_General_CI_AS'), ] - table_args = ['test_mssql_charset', MetaData(testing.db)] + table_args = ['test_mssql_charset', metadata] for index, spec in enumerate(columns): type_, args, kw, res = spec table_args.append(Column('c%s' % index, type_(*args, **kw), nullable=None)) charset_table = Table(*table_args) - gen = testing.db.dialect.schemagenerator(testing.db.dialect, testing.db, None, None) + dialect = mssql.dialect() + gen = dialect.ddl_compiler(dialect, schema.CreateTable(charset_table)) for col in charset_table.c: index = int(col.name[1:]) @@ -974,110 +1030,91 @@ class TypesTest2(TestBase, AssertsExecutionResults): "%s %s" % (col.name, columns[index][3])) self.assert_(repr(col)) - try: - charset_table.create(checkfirst=True) - assert True - except: - raise - charset_table.drop() - - testing.db.dialect.text_as_varchar = text_as_varchar + metadata.create_all() def test_timestamp(self): """Exercise TIMESTAMP column.""" - meta = MetaData(testing.db) - - try: - columns = [ - (TIMESTAMP, - 'TIMESTAMP'), - (mssql.MSTimeStamp, - 'TIMESTAMP'), - ] - for idx, (spec, expected) in enumerate(columns): - t = Table('mssql_ts%s' % idx, meta, - Column('id', Integer, primary_key=True), - Column('t', spec, nullable=None)) - testing.eq_(colspec(t.c.t), "t %s" % expected) - self.assert_(repr(t.c.t)) - try: - t.create(checkfirst=True) - assert True - except: - raise - t.drop() - finally: - meta.drop_all() + dialect = mssql.dialect() + spec, expected = (TIMESTAMP,'TIMESTAMP') + t = Table('mssql_ts', metadata, + Column('id', Integer, primary_key=True), + Column('t', spec, nullable=None)) + gen = dialect.ddl_compiler(dialect, schema.CreateTable(t)) + testing.eq_(gen.get_column_specification(t.c.t), "t %s" % expected) + self.assert_(repr(t.c.t)) + t.create(checkfirst=True) + def test_autoincrement(self): - meta = MetaData(testing.db) - try: - Table('ai_1', meta, - Column('int_y', Integer, primary_key=True), - Column('int_n', Integer, DefaultClause('0'), - primary_key=True)) - Table('ai_2', meta, - Column('int_y', Integer, primary_key=True), - Column('int_n', Integer, DefaultClause('0'), - primary_key=True)) - Table('ai_3', meta, - Column('int_n', Integer, DefaultClause('0'), - primary_key=True, autoincrement=False), - Column('int_y', Integer, primary_key=True)) - Table('ai_4', meta, - Column('int_n', Integer, DefaultClause('0'), - primary_key=True, autoincrement=False), - Column('int_n2', Integer, DefaultClause('0'), - primary_key=True, autoincrement=False)) - Table('ai_5', meta, - Column('int_y', Integer, primary_key=True), - Column('int_n', Integer, DefaultClause('0'), - primary_key=True, autoincrement=False)) - Table('ai_6', meta, - Column('o1', String(1), DefaultClause('x'), - primary_key=True), - Column('int_y', Integer, primary_key=True)) - Table('ai_7', meta, - Column('o1', String(1), DefaultClause('x'), - primary_key=True), - Column('o2', String(1), DefaultClause('x'), - primary_key=True), - Column('int_y', Integer, primary_key=True)) - Table('ai_8', meta, - Column('o1', String(1), DefaultClause('x'), - primary_key=True), - Column('o2', String(1), DefaultClause('x'), - primary_key=True)) - meta.create_all() - - table_names = ['ai_1', 'ai_2', 'ai_3', 'ai_4', - 'ai_5', 'ai_6', 'ai_7', 'ai_8'] - mr = MetaData(testing.db) - mr.reflect(only=table_names) - - for tbl in [mr.tables[name] for name in table_names]: - for c in tbl.c: - if c.name.startswith('int_y'): - assert c.autoincrement - elif c.name.startswith('int_n'): - assert not c.autoincrement - tbl.insert().execute() + Table('ai_1', metadata, + Column('int_y', Integer, primary_key=True), + Column('int_n', Integer, DefaultClause('0'), + primary_key=True)) + Table('ai_2', metadata, + Column('int_y', Integer, primary_key=True), + Column('int_n', Integer, DefaultClause('0'), + primary_key=True)) + Table('ai_3', metadata, + Column('int_n', Integer, DefaultClause('0'), + primary_key=True, autoincrement=False), + Column('int_y', Integer, primary_key=True)) + Table('ai_4', metadata, + Column('int_n', Integer, DefaultClause('0'), + primary_key=True, autoincrement=False), + Column('int_n2', Integer, DefaultClause('0'), + primary_key=True, autoincrement=False)) + Table('ai_5', metadata, + Column('int_y', Integer, primary_key=True), + Column('int_n', Integer, DefaultClause('0'), + primary_key=True, autoincrement=False)) + Table('ai_6', metadata, + Column('o1', String(1), DefaultClause('x'), + primary_key=True), + Column('int_y', Integer, primary_key=True)) + Table('ai_7', metadata, + Column('o1', String(1), DefaultClause('x'), + primary_key=True), + Column('o2', String(1), DefaultClause('x'), + primary_key=True), + Column('int_y', Integer, primary_key=True)) + Table('ai_8', metadata, + Column('o1', String(1), DefaultClause('x'), + primary_key=True), + Column('o2', String(1), DefaultClause('x'), + primary_key=True)) + metadata.create_all() + + table_names = ['ai_1', 'ai_2', 'ai_3', 'ai_4', + 'ai_5', 'ai_6', 'ai_7', 'ai_8'] + mr = MetaData(testing.db) + + for name in table_names: + tbl = Table(name, mr, autoload=True) + for c in tbl.c: + if c.name.startswith('int_y'): + assert c.autoincrement + elif c.name.startswith('int_n'): + assert not c.autoincrement + + for counter, engine in enumerate([ + engines.testing_engine(options={'implicit_returning':False}), + engines.testing_engine(options={'implicit_returning':True}), + ] + ): + engine.execute(tbl.insert()) if 'int_y' in tbl.c: - assert select([tbl.c.int_y]).scalar() == 1 - assert list(tbl.select().execute().fetchone()).count(1) == 1 + assert engine.scalar(select([tbl.c.int_y])) == counter + 1 + assert list(engine.execute(tbl.select()).first()).count(counter + 1) == 1 else: - assert 1 not in list(tbl.select().execute().fetchone()) - finally: - meta.drop_all() - -def colspec(c): - return testing.db.dialect.schemagenerator(testing.db.dialect, - testing.db, None, None).get_column_specification(c) - + assert 1 not in list(engine.execute(tbl.select()).first()) + engine.execute(tbl.delete()) class BinaryTest(TestBase, AssertsExecutionResults): """Test the Binary and VarBinary types""" + + __only_on__ = 'mssql' + @classmethod def setup_class(cls): global binary_table, MyPickleType @@ -1125,6 +1162,11 @@ class BinaryTest(TestBase, AssertsExecutionResults): stream2 =self.load_stream('binary_data_two.dat') binary_table.insert().execute(primary_id=1, misc='binary_data_one.dat', data=stream1, data_image=stream1, data_slice=stream1[0:100], pickled=testobj1, mypickle=testobj3) binary_table.insert().execute(primary_id=2, misc='binary_data_two.dat', data=stream2, data_image=stream2, data_slice=stream2[0:99], pickled=testobj2) + + # TODO: pyodbc does not seem to accept "None" for a VARBINARY column (data=None). + # error: [Microsoft][ODBC SQL Server Driver][SQL Server]Implicit conversion from + # data type varchar to varbinary is not allowed. Use the CONVERT function to run this query. (257) + #binary_table.insert().execute(primary_id=3, misc='binary_data_two.dat', data=None, data_image=None, data_slice=stream2[0:99], pickled=None) binary_table.insert().execute(primary_id=3, misc='binary_data_two.dat', data_image=None, data_slice=stream2[0:99], pickled=None) for stmt in ( diff --git a/test/dialect/test_mysql.py b/test/dialect/test_mysql.py index 8adb2d71c5..4052641522 100644 --- a/test/dialect/test_mysql.py +++ b/test/dialect/test_mysql.py @@ -1,8 +1,12 @@ from sqlalchemy.test.testing import eq_ + +# Py2K import sets +# end Py2K + from sqlalchemy import * from sqlalchemy import sql, exc -from sqlalchemy.databases import mysql +from sqlalchemy.dialects.mysql import base as mysql from sqlalchemy.test.testing import eq_ from sqlalchemy.test import * @@ -56,11 +60,11 @@ class TypesTest(TestBase, AssertsExecutionResults): # column type, args, kwargs, expected ddl # e.g. Column(Integer(10, unsigned=True)) == 'INTEGER(10) UNSIGNED' (mysql.MSNumeric, [], {}, - 'NUMERIC(10, 2)'), + 'NUMERIC'), (mysql.MSNumeric, [None], {}, 'NUMERIC'), (mysql.MSNumeric, [12], {}, - 'NUMERIC(12, 2)'), + 'NUMERIC(12)'), (mysql.MSNumeric, [12, 4], {'unsigned':True}, 'NUMERIC(12, 4) UNSIGNED'), (mysql.MSNumeric, [12, 4], {'zerofill':True}, @@ -69,11 +73,11 @@ class TypesTest(TestBase, AssertsExecutionResults): 'NUMERIC(12, 4) UNSIGNED ZEROFILL'), (mysql.MSDecimal, [], {}, - 'DECIMAL(10, 2)'), + 'DECIMAL'), (mysql.MSDecimal, [None], {}, 'DECIMAL'), (mysql.MSDecimal, [12], {}, - 'DECIMAL(12, 2)'), + 'DECIMAL(12)'), (mysql.MSDecimal, [12, None], {}, 'DECIMAL(12)'), (mysql.MSDecimal, [12, 4], {'unsigned':True}, @@ -178,11 +182,11 @@ class TypesTest(TestBase, AssertsExecutionResults): table_args.append(Column('c%s' % index, type_(*args, **kw))) numeric_table = Table(*table_args) - gen = testing.db.dialect.schemagenerator(testing.db.dialect, testing.db, None, None) + gen = testing.db.dialect.ddl_compiler(testing.db.dialect, numeric_table) for col in numeric_table.c: index = int(col.name[1:]) - self.assert_eq(gen.get_column_specification(col), + eq_(gen.get_column_specification(col), "%s %s" % (col.name, columns[index][3])) self.assert_(repr(col)) @@ -262,11 +266,11 @@ class TypesTest(TestBase, AssertsExecutionResults): table_args.append(Column('c%s' % index, type_(*args, **kw))) charset_table = Table(*table_args) - gen = testing.db.dialect.schemagenerator(testing.db.dialect, testing.db, None, None) + gen = testing.db.dialect.ddl_compiler(testing.db.dialect, charset_table) for col in charset_table.c: index = int(col.name[1:]) - self.assert_eq(gen.get_column_specification(col), + eq_(gen.get_column_specification(col), "%s %s" % (col.name, columns[index][3])) self.assert_(repr(col)) @@ -292,14 +296,14 @@ class TypesTest(TestBase, AssertsExecutionResults): Column('b7', mysql.MSBit(63)), Column('b8', mysql.MSBit(64))) - self.assert_eq(colspec(bit_table.c.b1), 'b1 BIT') - self.assert_eq(colspec(bit_table.c.b2), 'b2 BIT') - self.assert_eq(colspec(bit_table.c.b3), 'b3 BIT NOT NULL') - self.assert_eq(colspec(bit_table.c.b4), 'b4 BIT(1)') - self.assert_eq(colspec(bit_table.c.b5), 'b5 BIT(8)') - self.assert_eq(colspec(bit_table.c.b6), 'b6 BIT(32)') - self.assert_eq(colspec(bit_table.c.b7), 'b7 BIT(63)') - self.assert_eq(colspec(bit_table.c.b8), 'b8 BIT(64)') + eq_(colspec(bit_table.c.b1), 'b1 BIT') + eq_(colspec(bit_table.c.b2), 'b2 BIT') + eq_(colspec(bit_table.c.b3), 'b3 BIT NOT NULL') + eq_(colspec(bit_table.c.b4), 'b4 BIT(1)') + eq_(colspec(bit_table.c.b5), 'b5 BIT(8)') + eq_(colspec(bit_table.c.b6), 'b6 BIT(32)') + eq_(colspec(bit_table.c.b7), 'b7 BIT(63)') + eq_(colspec(bit_table.c.b8), 'b8 BIT(64)') for col in bit_table.c: self.assert_(repr(col)) @@ -314,7 +318,7 @@ class TypesTest(TestBase, AssertsExecutionResults): def roundtrip(store, expected=None): expected = expected or store table.insert(store).execute() - row = list(table.select().execute())[0] + row = table.select().execute().first() try: self.assert_(list(row) == expected) except: @@ -322,7 +326,7 @@ class TypesTest(TestBase, AssertsExecutionResults): print "Expected %s" % expected print "Found %s" % list(row) raise - table.delete().execute() + table.delete().execute().close() roundtrip([0] * 8) roundtrip([None, None, 0, None, None, None, None, None]) @@ -350,10 +354,10 @@ class TypesTest(TestBase, AssertsExecutionResults): Column('b3', mysql.MSTinyInteger(1)), Column('b4', mysql.MSTinyInteger)) - self.assert_eq(colspec(bool_table.c.b1), 'b1 BOOL') - self.assert_eq(colspec(bool_table.c.b2), 'b2 BOOL') - self.assert_eq(colspec(bool_table.c.b3), 'b3 TINYINT(1)') - self.assert_eq(colspec(bool_table.c.b4), 'b4 TINYINT') + eq_(colspec(bool_table.c.b1), 'b1 BOOL') + eq_(colspec(bool_table.c.b2), 'b2 BOOL') + eq_(colspec(bool_table.c.b3), 'b3 TINYINT(1)') + eq_(colspec(bool_table.c.b4), 'b4 TINYINT') for col in bool_table.c: self.assert_(repr(col)) @@ -364,7 +368,7 @@ class TypesTest(TestBase, AssertsExecutionResults): def roundtrip(store, expected=None): expected = expected or store table.insert(store).execute() - row = list(table.select().execute())[0] + row = table.select().execute().first() try: self.assert_(list(row) == expected) for i, val in enumerate(expected): @@ -375,7 +379,7 @@ class TypesTest(TestBase, AssertsExecutionResults): print "Expected %s" % expected print "Found %s" % list(row) raise - table.delete().execute() + table.delete().execute().close() roundtrip([None, None, None, None]) @@ -387,7 +391,7 @@ class TypesTest(TestBase, AssertsExecutionResults): meta2 = MetaData(testing.db) # replace with reflected table = Table('mysql_bool', meta2, autoload=True) - self.assert_eq(colspec(table.c.b3), 'b3 BOOL') + eq_(colspec(table.c.b3), 'b3 BOOL') roundtrip([None, None, None, None]) roundtrip([True, True, 1, 1], [True, True, True, 1]) @@ -430,7 +434,7 @@ class TypesTest(TestBase, AssertsExecutionResults): t = Table('mysql_ts%s' % idx, meta, Column('id', Integer, primary_key=True), Column('t', *spec)) - self.assert_eq(colspec(t.c.t), "t %s" % expected) + eq_(colspec(t.c.t), "t %s" % expected) self.assert_(repr(t.c.t)) t.create() r = Table('mysql_ts%s' % idx, MetaData(testing.db), @@ -460,12 +464,12 @@ class TypesTest(TestBase, AssertsExecutionResults): for table in year_table, reflected: table.insert(['1950', '50', None, 50, 1950]).execute() - row = list(table.select().execute())[0] - self.assert_eq(list(row), [1950, 2050, None, 50, 1950]) + row = table.select().execute().first() + eq_(list(row), [1950, 2050, None, 50, 1950]) table.delete().execute() self.assert_(colspec(table.c.y1).startswith('y1 YEAR')) - self.assert_eq(colspec(table.c.y4), 'y4 YEAR(2)') - self.assert_eq(colspec(table.c.y5), 'y5 YEAR(4)') + eq_(colspec(table.c.y4), 'y4 YEAR(2)') + eq_(colspec(table.c.y5), 'y5 YEAR(4)') finally: meta.drop_all() @@ -479,9 +483,9 @@ class TypesTest(TestBase, AssertsExecutionResults): Column('s2', mysql.MSSet("'a'")), Column('s3', mysql.MSSet("'5'", "'7'", "'9'"))) - self.assert_eq(colspec(set_table.c.s1), "s1 SET('dq','sq')") - self.assert_eq(colspec(set_table.c.s2), "s2 SET('a')") - self.assert_eq(colspec(set_table.c.s3), "s3 SET('5','7','9')") + eq_(colspec(set_table.c.s1), "s1 SET('dq','sq')") + eq_(colspec(set_table.c.s2), "s2 SET('a')") + eq_(colspec(set_table.c.s3), "s3 SET('5','7','9')") for col in set_table.c: self.assert_(repr(col)) @@ -494,7 +498,7 @@ class TypesTest(TestBase, AssertsExecutionResults): def roundtrip(store, expected=None): expected = expected or store table.insert(store).execute() - row = list(table.select().execute())[0] + row = table.select().execute().first() try: self.assert_(list(row) == expected) except: @@ -518,12 +522,12 @@ class TypesTest(TestBase, AssertsExecutionResults): {'s3':set(['5', '7'])}, {'s3':set(['5', '7', '9'])}, {'s3':set(['7', '9'])}) - rows = list(select( + rows = select( [set_table.c.s3], - set_table.c.s3.in_([set(['5']), set(['5', '7'])])).execute()) + set_table.c.s3.in_([set(['5']), set(['5', '7']), set(['7', '5'])]) + ).execute().fetchall() found = set([frozenset(row[0]) for row in rows]) - eq_(found, - set([frozenset(['5']), frozenset(['5', '7'])])) + eq_(found, set([frozenset(['5']), frozenset(['5', '7'])])) finally: meta.drop_all() @@ -542,17 +546,17 @@ class TypesTest(TestBase, AssertsExecutionResults): Column('e6', mysql.MSEnum("'a'", "b")), ) - self.assert_eq(colspec(enum_table.c.e1), + eq_(colspec(enum_table.c.e1), "e1 ENUM('a','b')") - self.assert_eq(colspec(enum_table.c.e2), + eq_(colspec(enum_table.c.e2), "e2 ENUM('a','b') NOT NULL") - self.assert_eq(colspec(enum_table.c.e3), + eq_(colspec(enum_table.c.e3), "e3 ENUM('a','b')") - self.assert_eq(colspec(enum_table.c.e4), + eq_(colspec(enum_table.c.e4), "e4 ENUM('a','b') NOT NULL") - self.assert_eq(colspec(enum_table.c.e5), + eq_(colspec(enum_table.c.e5), "e5 ENUM('a','b')") - self.assert_eq(colspec(enum_table.c.e6), + eq_(colspec(enum_table.c.e6), "e6 ENUM('''a''','b')") enum_table.drop(checkfirst=True) enum_table.create() @@ -585,8 +589,9 @@ class TypesTest(TestBase, AssertsExecutionResults): # This is known to fail with MySQLDB 1.2.2 beta versions # which return these as sets.Set(['a']), sets.Set(['b']) # (even on Pythons with __builtin__.set) - if testing.db.dialect.dbapi.version_info < (1, 2, 2, 'beta', 3) and \ - testing.db.dialect.dbapi.version_info >= (1, 2, 2): + if (not testing.against('+zxjdbc') and + testing.db.dialect.dbapi.version_info < (1, 2, 2, 'beta', 3) and + testing.db.dialect.dbapi.version_info >= (1, 2, 2)): # these mysqldb seem to always uses 'sets', even on later pythons import sets def convert(value): @@ -602,7 +607,7 @@ class TypesTest(TestBase, AssertsExecutionResults): e.append(tuple([convert(c) for c in row])) expected = e - self.assert_eq(res, expected) + eq_(res, expected) enum_table.drop() @testing.exclude('mysql', '<', (4,), "3.23 can't handle an ENUM of ''") @@ -637,25 +642,52 @@ class TypesTest(TestBase, AssertsExecutionResults): finally: enum_table.drop() + + +class ReflectionTest(TestBase, AssertsExecutionResults): + + __only_on__ = 'mysql' + def test_default_reflection(self): """Test reflection of column defaults.""" def_table = Table('mysql_def', MetaData(testing.db), Column('c1', String(10), DefaultClause('')), Column('c2', String(10), DefaultClause('0')), - Column('c3', String(10), DefaultClause('abc'))) + Column('c3', String(10), DefaultClause('abc')), + Column('c4', TIMESTAMP, DefaultClause('2009-04-05 12:00:00')), + Column('c5', TIMESTAMP, ), + + ) + def_table.create() try: - def_table.create() reflected = Table('mysql_def', MetaData(testing.db), - autoload=True) - for t in def_table, reflected: - assert t.c.c1.server_default.arg == '' - assert t.c.c2.server_default.arg == '0' - assert t.c.c3.server_default.arg == 'abc' + autoload=True) finally: def_table.drop() + + assert def_table.c.c1.server_default.arg == '' + assert def_table.c.c2.server_default.arg == '0' + assert def_table.c.c3.server_default.arg == 'abc' + assert def_table.c.c4.server_default.arg == '2009-04-05 12:00:00' + + assert str(reflected.c.c1.server_default.arg) == "''" + assert str(reflected.c.c2.server_default.arg) == "'0'" + assert str(reflected.c.c3.server_default.arg) == "'abc'" + assert str(reflected.c.c4.server_default.arg) == "'2009-04-05 12:00:00'" + + reflected.create() + try: + reflected2 = Table('mysql_def', MetaData(testing.db), autoload=True) + finally: + reflected.drop() + assert str(reflected2.c.c1.server_default.arg) == "''" + assert str(reflected2.c.c2.server_default.arg) == "'0'" + assert str(reflected2.c.c3.server_default.arg) == "'abc'" + assert str(reflected2.c.c4.server_default.arg) == "'2009-04-05 12:00:00'" + def test_reflection_on_include_columns(self): """Test reflection of include_columns to be sure they respect case.""" @@ -700,8 +732,8 @@ class TypesTest(TestBase, AssertsExecutionResults): ( mysql.MSSmallInteger(4), mysql.MSSmallInteger(4), ), ( mysql.MSMediumInteger(), mysql.MSMediumInteger(), ), ( mysql.MSMediumInteger(8), mysql.MSMediumInteger(8), ), - ( Binary(3), mysql.MSBlob(3), ), - ( Binary(), mysql.MSBlob() ), + ( Binary(3), mysql.TINYBLOB(), ), + ( Binary(), mysql.BLOB() ), ( mysql.MSBinary(3), mysql.MSBinary(3), ), ( mysql.MSVarBinary(3),), ( mysql.MSVarBinary(), mysql.MSBlob()), @@ -734,14 +766,15 @@ class TypesTest(TestBase, AssertsExecutionResults): # in a view, e.g. char -> varchar, tinyblob -> mediumblob # # Not sure exactly which point version has the fix. - if db.dialect.server_version_info(db.connect()) < (5, 0, 11): + if db.dialect.server_version_info < (5, 0, 11): tables = rt, else: tables = rt, rv for table in tables: for i, reflected in enumerate(table.c): - assert isinstance(reflected.type, type(expected[i])) + assert isinstance(reflected.type, type(expected[i])), \ + "element %d: %r not instance of %r" % (i, reflected.type, type(expected[i])) finally: db.execute('DROP VIEW mysql_types_v') finally: @@ -802,17 +835,12 @@ class TypesTest(TestBase, AssertsExecutionResults): tbl.insert().execute() if 'int_y' in tbl.c: assert select([tbl.c.int_y]).scalar() == 1 - assert list(tbl.select().execute().fetchone()).count(1) == 1 + assert list(tbl.select().execute().first()).count(1) == 1 else: - assert 1 not in list(tbl.select().execute().fetchone()) + assert 1 not in list(tbl.select().execute().first()) finally: meta.drop_all() - def assert_eq(self, got, wanted): - if got != wanted: - print "Expected %s" % wanted - print "Found %s" % got - eq_(got, wanted) class SQLTest(TestBase, AssertsCompiledSQL): @@ -909,11 +937,11 @@ class SQLTest(TestBase, AssertsCompiledSQL): (m.MSBit, "t.col"), # this is kind of sucky. thank you default arguments! - (NUMERIC, "CAST(t.col AS DECIMAL(10, 2))"), - (DECIMAL, "CAST(t.col AS DECIMAL(10, 2))"), - (Numeric, "CAST(t.col AS DECIMAL(10, 2))"), - (m.MSNumeric, "CAST(t.col AS DECIMAL(10, 2))"), - (m.MSDecimal, "CAST(t.col AS DECIMAL(10, 2))"), + (NUMERIC, "CAST(t.col AS DECIMAL)"), + (DECIMAL, "CAST(t.col AS DECIMAL)"), + (Numeric, "CAST(t.col AS DECIMAL)"), + (m.MSNumeric, "CAST(t.col AS DECIMAL)"), + (m.MSDecimal, "CAST(t.col AS DECIMAL)"), (FLOAT, "t.col"), (Float, "t.col"), @@ -928,8 +956,8 @@ class SQLTest(TestBase, AssertsCompiledSQL): (DateTime, "CAST(t.col AS DATETIME)"), (Date, "CAST(t.col AS DATE)"), (Time, "CAST(t.col AS TIME)"), - (m.MSDateTime, "CAST(t.col AS DATETIME)"), - (m.MSDate, "CAST(t.col AS DATE)"), + (DateTime, "CAST(t.col AS DATETIME)"), + (Date, "CAST(t.col AS DATE)"), (m.MSTime, "CAST(t.col AS TIME)"), (m.MSTimeStamp, "CAST(t.col AS DATETIME)"), (m.MSYear, "t.col"), @@ -998,12 +1026,11 @@ class SQLTest(TestBase, AssertsCompiledSQL): class RawReflectionTest(TestBase): def setup(self): - self.dialect = mysql.dialect() - self.reflector = mysql.MySQLSchemaReflector( - self.dialect.identifier_preparer) + dialect = mysql.dialect() + self.parser = mysql.MySQLTableDefinitionParser(dialect, dialect.identifier_preparer) def test_key_reflection(self): - regex = self.reflector._re_key + regex = self.parser._re_key assert regex.match(' PRIMARY KEY (`id`),') assert regex.match(' PRIMARY KEY USING BTREE (`id`),') @@ -1023,37 +1050,11 @@ class ExecutionTest(TestBase): cx = engine.connect() meta = MetaData() - - assert ('mysql', 'charset') not in cx.info - assert ('mysql', 'force_charset') not in cx.info - - cx.execute(text("SELECT 1")).fetchall() - assert ('mysql', 'charset') not in cx.info - - meta.reflect(cx) - assert ('mysql', 'charset') in cx.info - - cx.execute(text("SET @squiznart=123")) - assert ('mysql', 'charset') in cx.info - - # the charset invalidation is very conservative - cx.execute(text("SET TIMESTAMP = DEFAULT")) - assert ('mysql', 'charset') not in cx.info - - cx.info[('mysql', 'force_charset')] = 'latin1' - - assert engine.dialect._detect_charset(cx) == 'latin1' - assert cx.info[('mysql', 'charset')] == 'latin1' - - del cx.info[('mysql', 'force_charset')] - del cx.info[('mysql', 'charset')] + charset = engine.dialect._detect_charset(cx) meta.reflect(cx) - assert ('mysql', 'charset') in cx.info - - # String execution doesn't go through the detector. - cx.execute("SET TIMESTAMP = DEFAULT") - assert ('mysql', 'charset') in cx.info + eq_(cx.dialect._connection_charset, charset) + cx.close() class MatchTest(TestBase, AssertsCompiledSQL): @@ -1102,9 +1103,10 @@ class MatchTest(TestBase, AssertsCompiledSQL): metadata.drop_all() def test_expression(self): + format = testing.db.dialect.paramstyle == 'format' and '%s' or '?' self.assert_compile( matchtable.c.title.match('somstr'), - "MATCH (matchtable.title) AGAINST (%s IN BOOLEAN MODE)") + "MATCH (matchtable.title) AGAINST (%s IN BOOLEAN MODE)" % format) def test_simple_match(self): results = (matchtable.select(). @@ -1162,6 +1164,5 @@ class MatchTest(TestBase, AssertsCompiledSQL): def colspec(c): - return testing.db.dialect.schemagenerator(testing.db.dialect, - testing.db, None, None).get_column_specification(c) + return testing.db.dialect.ddl_compiler(testing.db.dialect, c.table).get_column_specification(c) diff --git a/test/dialect/test_oracle.py b/test/dialect/test_oracle.py index d9d64806e8..53e0f9ec2f 100644 --- a/test/dialect/test_oracle.py +++ b/test/dialect/test_oracle.py @@ -2,12 +2,14 @@ from sqlalchemy.test.testing import eq_ from sqlalchemy import * +from sqlalchemy import types as sqltypes from sqlalchemy.sql import table, column -from sqlalchemy.databases import oracle from sqlalchemy.test import * from sqlalchemy.test.testing import eq_ from sqlalchemy.test.engines import testing_engine +from sqlalchemy.dialects.oracle import cx_oracle, base as oracle from sqlalchemy.engine import default +from sqlalchemy.util import jython import os @@ -43,10 +45,10 @@ class CompileTest(TestBase, AssertsCompiledSQL): meta = MetaData() parent = Table('parent', meta, Column('id', Integer, primary_key=True), Column('name', String(50)), - owner='ed') + schema='ed') child = Table('child', meta, Column('id', Integer, primary_key=True), Column('parent_id', Integer, ForeignKey('ed.parent.id')), - owner = 'ed') + schema = 'ed') self.assert_compile(parent.join(child), "ed.parent JOIN ed.child ON ed.parent.id = ed.child.parent_id") @@ -342,6 +344,25 @@ class TypesTest(TestBase, AssertsCompiledSQL): b = bindparam("foo", u"hello world!") assert b.type.dialect_impl(dialect).get_dbapi_type(dbapi) == 'STRING' + def test_type_adapt(self): + dialect = cx_oracle.dialect() + + for start, test in [ + (DateTime(), cx_oracle._OracleDateTime), + (TIMESTAMP(), cx_oracle._OracleTimestamp), + (oracle.OracleRaw(), cx_oracle._OracleRaw), + (String(), String), + (VARCHAR(), VARCHAR), + (String(50), String), + (Unicode(), Unicode), + (Text(), cx_oracle._OracleText), + (UnicodeText(), cx_oracle._OracleUnicodeText), + (NCHAR(), NCHAR), + (oracle.RAW(50), cx_oracle._OracleRaw), + ]: + assert isinstance(start.dialect_impl(dialect), test), "wanted %r got %r" % (test, start.dialect_impl(dialect)) + + def test_reflect_raw(self): types_table = Table( 'all_types', MetaData(testing.db), @@ -354,16 +375,16 @@ class TypesTest(TestBase, AssertsCompiledSQL): def test_reflect_nvarchar(self): metadata = MetaData(testing.db) t = Table('t', metadata, - Column('data', oracle.OracleNVarchar(255)) + Column('data', sqltypes.NVARCHAR(255)) ) metadata.create_all() try: m2 = MetaData(testing.db) t2 = Table('t', m2, autoload=True) - assert isinstance(t2.c.data.type, oracle.OracleNVarchar) + assert isinstance(t2.c.data.type, sqltypes.NVARCHAR) data = u'm’a réveillé.' t2.insert().execute(data=data) - eq_(t2.select().execute().fetchone()['data'], data) + eq_(t2.select().execute().first()['data'], data) finally: metadata.drop_all() @@ -391,7 +412,7 @@ class TypesTest(TestBase, AssertsCompiledSQL): t.create(engine) try: engine.execute(t.insert(), id=1, data='this is text', bindata='this is binary') - row = engine.execute(t.select()).fetchone() + row = engine.execute(t.select()).first() eq_(row['data'].read(), 'this is text') eq_(row['bindata'].read(), 'this is binary') finally: @@ -408,7 +429,6 @@ class BufferedColumnTest(TestBase, AssertsCompiledSQL): Column('data', Binary) ) meta.create_all() - stream = os.path.join(os.path.dirname(__file__), "..", 'binary_data_one.dat') stream = file(stream).read(12000) @@ -420,17 +440,18 @@ class BufferedColumnTest(TestBase, AssertsCompiledSQL): meta.drop_all() def test_fetch(self): - eq_( - binary_table.select().execute().fetchall() , - [(i, stream) for i in range(1, 11)], - ) + result = binary_table.select().execute().fetchall() + if jython: + result = [(i, value.tostring()) for i, value in result] + eq_(result, [(i, stream) for i in range(1, 11)]) + @testing.fails_on('+zxjdbc', 'FIXME: zxjdbc should support this') def test_fetch_single_arraysize(self): eng = testing_engine(options={'arraysize':1}) - eq_( - eng.execute(binary_table.select()).fetchall(), - [(i, stream) for i in range(1, 11)], - ) + result = eng.execute(binary_table.select()).fetchall(), + if jython: + result = [(i, value.tostring()) for i, value in result] + eq_(result, [(i, stream) for i in range(1, 11)]) class SequenceTest(TestBase, AssertsCompiledSQL): def test_basic(self): diff --git a/test/dialect/test_postgres.py b/test/dialect/test_postgresql.py similarity index 66% rename from test/dialect/test_postgres.py rename to test/dialect/test_postgresql.py index 8ca714badc..e1c351a93e 100644 --- a/test/dialect/test_postgres.py +++ b/test/dialect/test_postgresql.py @@ -1,18 +1,19 @@ from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message +from sqlalchemy.test import engines import datetime from sqlalchemy import * from sqlalchemy.orm import * -from sqlalchemy import exc -from sqlalchemy.databases import postgres +from sqlalchemy import exc, schema +from sqlalchemy.dialects.postgresql import base as postgresql from sqlalchemy.engine.strategies import MockEngineStrategy from sqlalchemy.test import * from sqlalchemy.sql import table, column - +from sqlalchemy.test.testing import eq_ class SequenceTest(TestBase, AssertsCompiledSQL): def test_basic(self): seq = Sequence("my_seq_no_schema") - dialect = postgres.PGDialect() + dialect = postgresql.PGDialect() assert dialect.identifier_preparer.format_sequence(seq) == "my_seq_no_schema" seq = Sequence("my_seq", schema="some_schema") @@ -22,43 +23,77 @@ class SequenceTest(TestBase, AssertsCompiledSQL): assert dialect.identifier_preparer.format_sequence(seq) == '"Some_Schema"."My_Seq"' class CompileTest(TestBase, AssertsCompiledSQL): - __dialect__ = postgres.dialect() + __dialect__ = postgresql.dialect() def test_update_returning(self): - dialect = postgres.dialect() + dialect = postgresql.dialect() table1 = table('mytable', column('myid', Integer), column('name', String(128)), column('description', String(128)), ) - u = update(table1, values=dict(name='foo'), postgres_returning=[table1.c.myid, table1.c.name]) + u = update(table1, values=dict(name='foo')).returning(table1.c.myid, table1.c.name) self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING mytable.myid, mytable.name", dialect=dialect) - u = update(table1, values=dict(name='foo'), postgres_returning=[table1]) + u = update(table1, values=dict(name='foo')).returning(table1) self.assert_compile(u, "UPDATE mytable SET name=%(name)s "\ "RETURNING mytable.myid, mytable.name, mytable.description", dialect=dialect) - u = update(table1, values=dict(name='foo'), postgres_returning=[func.length(table1.c.name)]) - self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING length(mytable.name)", dialect=dialect) + u = update(table1, values=dict(name='foo')).returning(func.length(table1.c.name)) + self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING length(mytable.name) AS length_1", dialect=dialect) + def test_insert_returning(self): - dialect = postgres.dialect() + dialect = postgresql.dialect() table1 = table('mytable', column('myid', Integer), column('name', String(128)), column('description', String(128)), ) - i = insert(table1, values=dict(name='foo'), postgres_returning=[table1.c.myid, table1.c.name]) + i = insert(table1, values=dict(name='foo')).returning(table1.c.myid, table1.c.name) self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING mytable.myid, mytable.name", dialect=dialect) - i = insert(table1, values=dict(name='foo'), postgres_returning=[table1]) + i = insert(table1, values=dict(name='foo')).returning(table1) self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) "\ "RETURNING mytable.myid, mytable.name, mytable.description", dialect=dialect) - i = insert(table1, values=dict(name='foo'), postgres_returning=[func.length(table1.c.name)]) - self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING length(mytable.name)", dialect=dialect) + i = insert(table1, values=dict(name='foo')).returning(func.length(table1.c.name)) + self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING length(mytable.name) AS length_1", dialect=dialect) + + @testing.uses_deprecated(r".*argument is deprecated. Please use statement.returning.*") + def test_old_returning_names(self): + dialect = postgresql.dialect() + table1 = table('mytable', + column('myid', Integer), + column('name', String(128)), + column('description', String(128)), + ) + + u = update(table1, values=dict(name='foo'), postgres_returning=[table1.c.myid, table1.c.name]) + self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING mytable.myid, mytable.name", dialect=dialect) + + u = update(table1, values=dict(name='foo'), postgresql_returning=[table1.c.myid, table1.c.name]) + self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING mytable.myid, mytable.name", dialect=dialect) + + i = insert(table1, values=dict(name='foo'), postgres_returning=[table1.c.myid, table1.c.name]) + self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING mytable.myid, mytable.name", dialect=dialect) + + def test_create_partial_index(self): + tbl = Table('testtbl', MetaData(), Column('data',Integer)) + idx = Index('test_idx1', tbl.c.data, postgresql_where=and_(tbl.c.data > 5, tbl.c.data < 10)) + + self.assert_compile(schema.CreateIndex(idx), + "CREATE INDEX test_idx1 ON testtbl (data) WHERE testtbl.data > 5 AND testtbl.data < 10", dialect=postgresql.dialect()) + + @testing.uses_deprecated(r".*'postgres_where' argument has been renamed.*") + def test_old_create_partial_index(self): + tbl = Table('testtbl', MetaData(), Column('data',Integer)) + idx = Index('test_idx1', tbl.c.data, postgres_where=and_(tbl.c.data > 5, tbl.c.data < 10)) + + self.assert_compile(schema.CreateIndex(idx), + "CREATE INDEX test_idx1 ON testtbl (data) WHERE testtbl.data > 5 AND testtbl.data < 10", dialect=postgresql.dialect()) def test_extract(self): t = table('t', column('col1')) @@ -70,72 +105,20 @@ class CompileTest(TestBase, AssertsCompiledSQL): "FROM t" % field) -class ReturningTest(TestBase, AssertsExecutionResults): - __only_on__ = 'postgres' - - @testing.exclude('postgres', '<', (8, 2), '8.3+ feature') - def test_update_returning(self): - meta = MetaData(testing.db) - table = Table('tables', meta, - Column('id', Integer, primary_key=True), - Column('persons', Integer), - Column('full', Boolean) - ) - table.create() - try: - table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}]) - - result = table.update(table.c.persons > 4, dict(full=True), postgres_returning=[table.c.id]).execute() - eq_(result.fetchall(), [(1,)]) - - result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute() - eq_(result2.fetchall(), [(1,True),(2,False)]) - finally: - table.drop() - - @testing.exclude('postgres', '<', (8, 2), '8.3+ feature') - def test_insert_returning(self): - meta = MetaData(testing.db) - table = Table('tables', meta, - Column('id', Integer, primary_key=True), - Column('persons', Integer), - Column('full', Boolean) - ) - table.create() - try: - result = table.insert(postgres_returning=[table.c.id]).execute({'persons': 1, 'full': False}) - - eq_(result.fetchall(), [(1,)]) - - @testing.fails_on('postgres', 'Known limitation of psycopg2') - def test_executemany(): - # return value is documented as failing with psycopg2/executemany - result2 = table.insert(postgres_returning=[table]).execute( - [{'persons': 2, 'full': False}, {'persons': 3, 'full': True}]) - eq_(result2.fetchall(), [(2, 2, False), (3,3,True)]) - - test_executemany() - - result3 = table.insert(postgres_returning=[(table.c.id*2).label('double_id')]).execute({'persons': 4, 'full': False}) - eq_([dict(row) for row in result3], [{'double_id':8}]) - - result4 = testing.db.execute('insert into tables (id, persons, "full") values (5, 10, true) returning persons') - eq_([dict(row) for row in result4], [{'persons': 10}]) - finally: - table.drop() - - class InsertTest(TestBase, AssertsExecutionResults): - __only_on__ = 'postgres' + __only_on__ = 'postgresql' @classmethod def setup_class(cls): global metadata + cls.engine= testing.db metadata = MetaData(testing.db) def teardown(self): metadata.drop_all() metadata.tables.clear() + if self.engine is not testing.db: + self.engine.dispose() def test_compiled_insert(self): table = Table('testtable', metadata, @@ -144,7 +127,7 @@ class InsertTest(TestBase, AssertsExecutionResults): metadata.create_all() - ins = table.insert(values={'data':bindparam('x')}).compile() + ins = table.insert(inline=True, values={'data':bindparam('x')}).compile() ins.execute({'x':"five"}, {'x':"seven"}) assert table.select().execute().fetchall() == [(1, 'five'), (2, 'seven')] @@ -155,6 +138,13 @@ class InsertTest(TestBase, AssertsExecutionResults): metadata.create_all() self._assert_data_with_sequence(table, "my_seq") + def test_sequence_returning_insert(self): + table = Table('testtable', metadata, + Column('id', Integer, Sequence('my_seq'), primary_key=True), + Column('data', String(30))) + metadata.create_all() + self._assert_data_with_sequence_returning(table, "my_seq") + def test_opt_sequence_insert(self): table = Table('testtable', metadata, Column('id', Integer, Sequence('my_seq', optional=True), primary_key=True), @@ -162,6 +152,13 @@ class InsertTest(TestBase, AssertsExecutionResults): metadata.create_all() self._assert_data_autoincrement(table) + def test_opt_sequence_returning_insert(self): + table = Table('testtable', metadata, + Column('id', Integer, Sequence('my_seq', optional=True), primary_key=True), + Column('data', String(30))) + metadata.create_all() + self._assert_data_autoincrement_returning(table) + def test_autoincrement_insert(self): table = Table('testtable', metadata, Column('id', Integer, primary_key=True), @@ -169,6 +166,13 @@ class InsertTest(TestBase, AssertsExecutionResults): metadata.create_all() self._assert_data_autoincrement(table) + def test_autoincrement_returning_insert(self): + table = Table('testtable', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(30))) + metadata.create_all() + self._assert_data_autoincrement_returning(table) + def test_noautoincrement_insert(self): table = Table('testtable', metadata, Column('id', Integer, primary_key=True, autoincrement=False), @@ -177,14 +181,17 @@ class InsertTest(TestBase, AssertsExecutionResults): self._assert_data_noautoincrement(table) def _assert_data_autoincrement(self, table): + self.engine = engines.testing_engine(options={'implicit_returning':False}) + metadata.bind = self.engine + def go(): # execute with explicit id r = table.insert().execute({'id':30, 'data':'d1'}) - assert r.last_inserted_ids() == [30] + assert r.inserted_primary_key == [30] # execute with prefetch id r = table.insert().execute({'data':'d2'}) - assert r.last_inserted_ids() == [1] + assert r.inserted_primary_key == [1] # executemany with explicit ids table.insert().execute({'id':31, 'data':'d3'}, {'id':32, 'data':'d4'}) @@ -201,7 +208,7 @@ class InsertTest(TestBase, AssertsExecutionResults): # note that the test framework doesnt capture the "preexecute" of a seqeuence # or default. we just see it in the bind params. - self.assert_sql(testing.db, go, [], with_sequences=[ + self.assert_sql(self.engine, go, [], with_sequences=[ ( "INSERT INTO testtable (id, data) VALUES (:id, :data)", {'id':30, 'data':'d1'} @@ -242,19 +249,19 @@ class InsertTest(TestBase, AssertsExecutionResults): # test the same series of events using a reflected # version of the table - m2 = MetaData(testing.db) + m2 = MetaData(self.engine) table = Table(table.name, m2, autoload=True) def go(): table.insert().execute({'id':30, 'data':'d1'}) r = table.insert().execute({'data':'d2'}) - assert r.last_inserted_ids() == [5] + assert r.inserted_primary_key == [5] table.insert().execute({'id':31, 'data':'d3'}, {'id':32, 'data':'d4'}) table.insert().execute({'data':'d5'}, {'data':'d6'}) table.insert(inline=True).execute({'id':33, 'data':'d7'}) table.insert(inline=True).execute({'data':'d8'}) - self.assert_sql(testing.db, go, [], with_sequences=[ + self.assert_sql(self.engine, go, [], with_sequences=[ ( "INSERT INTO testtable (id, data) VALUES (:id, :data)", {'id':30, 'data':'d1'} @@ -293,7 +300,127 @@ class InsertTest(TestBase, AssertsExecutionResults): ] table.delete().execute() + def _assert_data_autoincrement_returning(self, table): + self.engine = engines.testing_engine(options={'implicit_returning':True}) + metadata.bind = self.engine + + def go(): + # execute with explicit id + r = table.insert().execute({'id':30, 'data':'d1'}) + assert r.inserted_primary_key == [30] + + # execute with prefetch id + r = table.insert().execute({'data':'d2'}) + assert r.inserted_primary_key == [1] + + # executemany with explicit ids + table.insert().execute({'id':31, 'data':'d3'}, {'id':32, 'data':'d4'}) + + # executemany, uses SERIAL + table.insert().execute({'data':'d5'}, {'data':'d6'}) + + # single execute, explicit id, inline + table.insert(inline=True).execute({'id':33, 'data':'d7'}) + + # single execute, inline, uses SERIAL + table.insert(inline=True).execute({'data':'d8'}) + + self.assert_sql(self.engine, go, [], with_sequences=[ + ( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + {'id':30, 'data':'d1'} + ), + ( + "INSERT INTO testtable (data) VALUES (:data) RETURNING testtable.id", + {'data': 'd2'} + ), + ( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + [{'id':31, 'data':'d3'}, {'id':32, 'data':'d4'}] + ), + ( + "INSERT INTO testtable (data) VALUES (:data)", + [{'data':'d5'}, {'data':'d6'}] + ), + ( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + [{'id':33, 'data':'d7'}] + ), + ( + "INSERT INTO testtable (data) VALUES (:data)", + [{'data':'d8'}] + ), + ]) + + assert table.select().execute().fetchall() == [ + (30, 'd1'), + (1, 'd2'), + (31, 'd3'), + (32, 'd4'), + (2, 'd5'), + (3, 'd6'), + (33, 'd7'), + (4, 'd8'), + ] + table.delete().execute() + + # test the same series of events using a reflected + # version of the table + m2 = MetaData(self.engine) + table = Table(table.name, m2, autoload=True) + + def go(): + table.insert().execute({'id':30, 'data':'d1'}) + r = table.insert().execute({'data':'d2'}) + assert r.inserted_primary_key == [5] + table.insert().execute({'id':31, 'data':'d3'}, {'id':32, 'data':'d4'}) + table.insert().execute({'data':'d5'}, {'data':'d6'}) + table.insert(inline=True).execute({'id':33, 'data':'d7'}) + table.insert(inline=True).execute({'data':'d8'}) + + self.assert_sql(self.engine, go, [], with_sequences=[ + ( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + {'id':30, 'data':'d1'} + ), + ( + "INSERT INTO testtable (data) VALUES (:data) RETURNING testtable.id", + {'data':'d2'} + ), + ( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + [{'id':31, 'data':'d3'}, {'id':32, 'data':'d4'}] + ), + ( + "INSERT INTO testtable (data) VALUES (:data)", + [{'data':'d5'}, {'data':'d6'}] + ), + ( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + [{'id':33, 'data':'d7'}] + ), + ( + "INSERT INTO testtable (data) VALUES (:data)", + [{'data':'d8'}] + ), + ]) + + assert table.select().execute().fetchall() == [ + (30, 'd1'), + (5, 'd2'), + (31, 'd3'), + (32, 'd4'), + (6, 'd5'), + (7, 'd6'), + (33, 'd7'), + (8, 'd8'), + ] + table.delete().execute() + def _assert_data_with_sequence(self, table, seqname): + self.engine = engines.testing_engine(options={'implicit_returning':False}) + metadata.bind = self.engine + def go(): table.insert().execute({'id':30, 'data':'d1'}) table.insert().execute({'data':'d2'}) @@ -302,7 +429,7 @@ class InsertTest(TestBase, AssertsExecutionResults): table.insert(inline=True).execute({'id':33, 'data':'d7'}) table.insert(inline=True).execute({'data':'d8'}) - self.assert_sql(testing.db, go, [], with_sequences=[ + self.assert_sql(self.engine, go, [], with_sequences=[ ( "INSERT INTO testtable (id, data) VALUES (:id, :data)", {'id':30, 'data':'d1'} @@ -343,18 +470,76 @@ class InsertTest(TestBase, AssertsExecutionResults): # cant test reflection here since the Sequence must be # explicitly specified + def _assert_data_with_sequence_returning(self, table, seqname): + self.engine = engines.testing_engine(options={'implicit_returning':True}) + metadata.bind = self.engine + + def go(): + table.insert().execute({'id':30, 'data':'d1'}) + table.insert().execute({'data':'d2'}) + table.insert().execute({'id':31, 'data':'d3'}, {'id':32, 'data':'d4'}) + table.insert().execute({'data':'d5'}, {'data':'d6'}) + table.insert(inline=True).execute({'id':33, 'data':'d7'}) + table.insert(inline=True).execute({'data':'d8'}) + + self.assert_sql(self.engine, go, [], with_sequences=[ + ( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + {'id':30, 'data':'d1'} + ), + ( + "INSERT INTO testtable (id, data) VALUES (nextval('my_seq'), :data) RETURNING testtable.id", + {'data':'d2'} + ), + ( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + [{'id':31, 'data':'d3'}, {'id':32, 'data':'d4'}] + ), + ( + "INSERT INTO testtable (id, data) VALUES (nextval('%s'), :data)" % seqname, + [{'data':'d5'}, {'data':'d6'}] + ), + ( + "INSERT INTO testtable (id, data) VALUES (:id, :data)", + [{'id':33, 'data':'d7'}] + ), + ( + "INSERT INTO testtable (id, data) VALUES (nextval('%s'), :data)" % seqname, + [{'data':'d8'}] + ), + ]) + + assert table.select().execute().fetchall() == [ + (30, 'd1'), + (1, 'd2'), + (31, 'd3'), + (32, 'd4'), + (2, 'd5'), + (3, 'd6'), + (33, 'd7'), + (4, 'd8'), + ] + + # cant test reflection here since the Sequence must be + # explicitly specified + def _assert_data_noautoincrement(self, table): + self.engine = engines.testing_engine(options={'implicit_returning':False}) + metadata.bind = self.engine + table.insert().execute({'id':30, 'data':'d1'}) - try: - table.insert().execute({'data':'d2'}) - assert False - except exc.IntegrityError, e: - assert "violates not-null constraint" in str(e) - try: - table.insert().execute({'data':'d2'}, {'data':'d3'}) - assert False - except exc.IntegrityError, e: - assert "violates not-null constraint" in str(e) + + if self.engine.driver == 'pg8000': + exception_cls = exc.ProgrammingError + else: + exception_cls = exc.IntegrityError + + assert_raises_message(exception_cls, "violates not-null constraint", table.insert().execute, {'data':'d2'}) + assert_raises_message(exception_cls, "violates not-null constraint", table.insert().execute, {'data':'d2'}, {'data':'d3'}) + + assert_raises_message(exception_cls, "violates not-null constraint", table.insert().execute, {'data':'d2'}) + + assert_raises_message(exception_cls, "violates not-null constraint", table.insert().execute, {'data':'d2'}, {'data':'d3'}) table.insert().execute({'id':31, 'data':'d2'}, {'id':32, 'data':'d3'}) table.insert(inline=True).execute({'id':33, 'data':'d4'}) @@ -369,19 +554,12 @@ class InsertTest(TestBase, AssertsExecutionResults): # test the same series of events using a reflected # version of the table - m2 = MetaData(testing.db) + m2 = MetaData(self.engine) table = Table(table.name, m2, autoload=True) table.insert().execute({'id':30, 'data':'d1'}) - try: - table.insert().execute({'data':'d2'}) - assert False - except exc.IntegrityError, e: - assert "violates not-null constraint" in str(e) - try: - table.insert().execute({'data':'d2'}, {'data':'d3'}) - assert False - except exc.IntegrityError, e: - assert "violates not-null constraint" in str(e) + + assert_raises_message(exception_cls, "violates not-null constraint", table.insert().execute, {'data':'d2'}) + assert_raises_message(exception_cls, "violates not-null constraint", table.insert().execute, {'data':'d2'}, {'data':'d3'}) table.insert().execute({'id':31, 'data':'d2'}, {'id':32, 'data':'d3'}) table.insert(inline=True).execute({'id':33, 'data':'d4'}) @@ -396,36 +574,36 @@ class InsertTest(TestBase, AssertsExecutionResults): class DomainReflectionTest(TestBase, AssertsExecutionResults): "Test PostgreSQL domains" - __only_on__ = 'postgres' + __only_on__ = 'postgresql' @classmethod def setup_class(cls): con = testing.db.connect() for ddl in ('CREATE DOMAIN testdomain INTEGER NOT NULL DEFAULT 42', - 'CREATE DOMAIN alt_schema.testdomain INTEGER DEFAULT 0'): + 'CREATE DOMAIN test_schema.testdomain INTEGER DEFAULT 0'): try: con.execute(ddl) except exc.SQLError, e: if not "already exists" in str(e): raise e con.execute('CREATE TABLE testtable (question integer, answer testdomain)') - con.execute('CREATE TABLE alt_schema.testtable(question integer, answer alt_schema.testdomain, anything integer)') - con.execute('CREATE TABLE crosschema (question integer, answer alt_schema.testdomain)') + con.execute('CREATE TABLE test_schema.testtable(question integer, answer test_schema.testdomain, anything integer)') + con.execute('CREATE TABLE crosschema (question integer, answer test_schema.testdomain)') @classmethod def teardown_class(cls): con = testing.db.connect() con.execute('DROP TABLE testtable') - con.execute('DROP TABLE alt_schema.testtable') + con.execute('DROP TABLE test_schema.testtable') con.execute('DROP TABLE crosschema') con.execute('DROP DOMAIN testdomain') - con.execute('DROP DOMAIN alt_schema.testdomain') + con.execute('DROP DOMAIN test_schema.testdomain') def test_table_is_reflected(self): metadata = MetaData(testing.db) table = Table('testtable', metadata, autoload=True) eq_(set(table.columns.keys()), set(['question', 'answer']), "Columns of reflected table didn't equal expected columns") - eq_(table.c.answer.type.__class__, postgres.PGInteger) + assert isinstance(table.c.answer.type, Integer) def test_domain_is_reflected(self): metadata = MetaData(testing.db) @@ -433,15 +611,15 @@ class DomainReflectionTest(TestBase, AssertsExecutionResults): eq_(str(table.columns.answer.server_default.arg), '42', "Reflected default value didn't equal expected value") assert not table.columns.answer.nullable, "Expected reflected column to not be nullable." - def test_table_is_reflected_alt_schema(self): + def test_table_is_reflected_test_schema(self): metadata = MetaData(testing.db) - table = Table('testtable', metadata, autoload=True, schema='alt_schema') + table = Table('testtable', metadata, autoload=True, schema='test_schema') eq_(set(table.columns.keys()), set(['question', 'answer', 'anything']), "Columns of reflected table didn't equal expected columns") - eq_(table.c.anything.type.__class__, postgres.PGInteger) + assert isinstance(table.c.anything.type, Integer) def test_schema_domain_is_reflected(self): metadata = MetaData(testing.db) - table = Table('testtable', metadata, autoload=True, schema='alt_schema') + table = Table('testtable', metadata, autoload=True, schema='test_schema') eq_(str(table.columns.answer.server_default.arg), '0', "Reflected default value didn't equal expected value") assert table.columns.answer.nullable, "Expected reflected column to be nullable." @@ -452,10 +630,10 @@ class DomainReflectionTest(TestBase, AssertsExecutionResults): assert table.columns.answer.nullable, "Expected reflected column to be nullable." def test_unknown_types(self): - from sqlalchemy.databases import postgres + from sqlalchemy.databases import postgresql - ischema_names = postgres.ischema_names - postgres.ischema_names = {} + ischema_names = postgresql.PGDialect.ischema_names + postgresql.PGDialect.ischema_names = {} try: m2 = MetaData(testing.db) assert_raises(exc.SAWarning, Table, "testtable", m2, autoload=True) @@ -467,11 +645,11 @@ class DomainReflectionTest(TestBase, AssertsExecutionResults): assert t3.c.answer.type.__class__ == sa.types.NullType finally: - postgres.ischema_names = ischema_names + postgresql.PGDialect.ischema_names = ischema_names -class MiscTest(TestBase, AssertsExecutionResults): - __only_on__ = 'postgres' +class MiscTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL): + __only_on__ = 'postgresql' def test_date_reflection(self): m1 = MetaData(testing.db) @@ -536,26 +714,26 @@ class MiscTest(TestBase, AssertsExecutionResults): 'FROM mytable') def test_schema_reflection(self): - """note: this test requires that the 'alt_schema' schema be separate and accessible by the test user""" + """note: this test requires that the 'test_schema' schema be separate and accessible by the test user""" meta1 = MetaData(testing.db) users = Table('users', meta1, Column('user_id', Integer, primary_key = True), Column('user_name', String(30), nullable = False), - schema="alt_schema" + schema="test_schema" ) addresses = Table('email_addresses', meta1, Column('address_id', Integer, primary_key = True), Column('remote_user_id', Integer, ForeignKey(users.c.user_id)), Column('email_address', String(20)), - schema="alt_schema" + schema="test_schema" ) meta1.create_all() try: meta2 = MetaData(testing.db) - addresses = Table('email_addresses', meta2, autoload=True, schema="alt_schema") - users = Table('users', meta2, mustexist=True, schema="alt_schema") + addresses = Table('email_addresses', meta2, autoload=True, schema="test_schema") + users = Table('users', meta2, mustexist=True, schema="test_schema") print users print addresses @@ -574,12 +752,12 @@ class MiscTest(TestBase, AssertsExecutionResults): referer = Table("referer", meta1, Column("id", Integer, primary_key=True), Column("ref", Integer, ForeignKey('subject.id')), - schema="alt_schema") + schema="test_schema") meta1.create_all() try: meta2 = MetaData(testing.db) subject = Table("subject", meta2, autoload=True) - referer = Table("referer", meta2, schema="alt_schema", autoload=True) + referer = Table("referer", meta2, schema="test_schema", autoload=True) print str(subject.join(referer).onclause) self.assert_((subject.c.id==referer.c.ref).compare(subject.join(referer).onclause)) finally: @@ -589,19 +767,19 @@ class MiscTest(TestBase, AssertsExecutionResults): meta1 = MetaData(testing.db) subject = Table("subject", meta1, Column("id", Integer, primary_key=True), - schema='alt_schema_2' + schema='test_schema_2' ) referer = Table("referer", meta1, Column("id", Integer, primary_key=True), - Column("ref", Integer, ForeignKey('alt_schema_2.subject.id')), - schema="alt_schema") + Column("ref", Integer, ForeignKey('test_schema_2.subject.id')), + schema="test_schema") meta1.create_all() try: meta2 = MetaData(testing.db) - subject = Table("subject", meta2, autoload=True, schema="alt_schema_2") - referer = Table("referer", meta2, schema="alt_schema", autoload=True) + subject = Table("subject", meta2, autoload=True, schema="test_schema_2") + referer = Table("referer", meta2, schema="test_schema", autoload=True) print str(subject.join(referer).onclause) self.assert_((subject.c.id==referer.c.ref).compare(subject.join(referer).onclause)) finally: @@ -611,7 +789,7 @@ class MiscTest(TestBase, AssertsExecutionResults): meta = MetaData(testing.db) users = Table('users', meta, Column('id', Integer, primary_key=True), - Column('name', String(50)), schema='alt_schema') + Column('name', String(50)), schema='test_schema') users.create() try: users.insert().execute(id=1, name='name1') @@ -646,15 +824,15 @@ class MiscTest(TestBase, AssertsExecutionResults): user_name VARCHAR NOT NULL, user_password VARCHAR NOT NULL ); - """, None) + """) t = Table("speedy_users", meta, autoload=True) r = t.insert().execute(user_name='user', user_password='lala') - assert r.last_inserted_ids() == [1] + assert r.inserted_primary_key == [1] l = t.select().execute().fetchall() assert l == [(1, 'user', 'lala')] finally: - testing.db.execute("drop table speedy_users", None) + testing.db.execute("drop table speedy_users") @testing.emits_warning() def test_index_reflection(self): @@ -676,10 +854,10 @@ class MiscTest(TestBase, AssertsExecutionResults): testing.db.execute(""" create index idx1 on party ((id || name)) - """, None) + """) testing.db.execute(""" create unique index idx2 on party (id) where name = 'test' - """, None) + """) testing.db.execute(""" create index idx3 on party using btree @@ -713,35 +891,42 @@ class MiscTest(TestBase, AssertsExecutionResults): warnings.warn = capture_warnings._orig_showwarning m1.drop_all() - def test_create_partial_index(self): - tbl = Table('testtbl', MetaData(), Column('data',Integer)) - idx = Index('test_idx1', tbl.c.data, postgres_where=and_(tbl.c.data > 5, tbl.c.data < 10)) - - executed_sql = [] - mock_strategy = MockEngineStrategy() - mock_conn = mock_strategy.create('postgres://', executed_sql.append) + def test_set_isolation_level(self): + """Test setting the isolation level with create_engine""" + eng = create_engine(testing.db.url) + eq_( + eng.execute("show transaction isolation level").scalar(), + 'read committed') + eng = create_engine(testing.db.url, isolation_level="SERIALIZABLE") + eq_( + eng.execute("show transaction isolation level").scalar(), + 'serializable') + eng = create_engine(testing.db.url, isolation_level="FOO") - idx.create(mock_conn) + if testing.db.driver == 'zxjdbc': + exception_cls = eng.dialect.dbapi.Error + else: + exception_cls = eng.dialect.dbapi.ProgrammingError + assert_raises(exception_cls, eng.execute, "show transaction isolation level") - assert executed_sql == ['CREATE INDEX test_idx1 ON testtbl (data) WHERE testtbl.data > 5 AND testtbl.data < 10'] class TimezoneTest(TestBase, AssertsExecutionResults): """Test timezone-aware datetimes. - psycopg will return a datetime with a tzinfo attached to it, if postgres + psycopg will return a datetime with a tzinfo attached to it, if postgresql returns it. python then will not let you compare a datetime with a tzinfo to a datetime that doesnt have one. this test illustrates two ways to have datetime types with and without timezone info. """ - __only_on__ = 'postgres' + __only_on__ = 'postgresql' @classmethod def setup_class(cls): global tztable, notztable, metadata metadata = MetaData(testing.db) - # current_timestamp() in postgres is assumed to return TIMESTAMP WITH TIMEZONE + # current_timestamp() in postgresql is assumed to return TIMESTAMP WITH TIMEZONE tztable = Table('tztable', metadata, Column("id", Integer, primary_key=True), Column("date", DateTime(timezone=True), onupdate=func.current_timestamp()), @@ -762,17 +947,17 @@ class TimezoneTest(TestBase, AssertsExecutionResults): somedate = testing.db.connect().scalar(func.current_timestamp().select()) tztable.insert().execute(id=1, name='row1', date=somedate) c = tztable.update(tztable.c.id==1).execute(name='newname') - print tztable.select(tztable.c.id==1).execute().fetchone() + print tztable.select(tztable.c.id==1).execute().first() def test_without_timezone(self): # get a date without a tzinfo somedate = datetime.datetime(2005, 10,20, 11, 52, 00) notztable.insert().execute(id=1, name='row1', date=somedate) c = notztable.update(notztable.c.id==1).execute(name='newname') - print notztable.select(tztable.c.id==1).execute().fetchone() + print notztable.select(tztable.c.id==1).execute().first() class ArrayTest(TestBase, AssertsExecutionResults): - __only_on__ = 'postgres' + __only_on__ = 'postgresql' @classmethod def setup_class(cls): @@ -781,10 +966,14 @@ class ArrayTest(TestBase, AssertsExecutionResults): arrtable = Table('arrtable', metadata, Column('id', Integer, primary_key=True), - Column('intarr', postgres.PGArray(Integer)), - Column('strarr', postgres.PGArray(String(convert_unicode=True)), nullable=False) + Column('intarr', postgresql.PGArray(Integer)), + Column('strarr', postgresql.PGArray(String(convert_unicode=True)), nullable=False) ) metadata.create_all() + + def teardown(self): + arrtable.delete().execute() + @classmethod def teardown_class(cls): metadata.drop_all() @@ -792,34 +981,38 @@ class ArrayTest(TestBase, AssertsExecutionResults): def test_reflect_array_column(self): metadata2 = MetaData(testing.db) tbl = Table('arrtable', metadata2, autoload=True) - assert isinstance(tbl.c.intarr.type, postgres.PGArray) - assert isinstance(tbl.c.strarr.type, postgres.PGArray) + assert isinstance(tbl.c.intarr.type, postgresql.PGArray) + assert isinstance(tbl.c.strarr.type, postgresql.PGArray) assert isinstance(tbl.c.intarr.type.item_type, Integer) assert isinstance(tbl.c.strarr.type.item_type, String) + @testing.fails_on('postgresql+zxjdbc', 'zxjdbc has no support for PG arrays') def test_insert_array(self): arrtable.insert().execute(intarr=[1,2,3], strarr=['abc', 'def']) results = arrtable.select().execute().fetchall() eq_(len(results), 1) eq_(results[0]['intarr'], [1,2,3]) eq_(results[0]['strarr'], ['abc','def']) - arrtable.delete().execute() + @testing.fails_on('postgresql+pg8000', 'pg8000 has poor support for PG arrays') + @testing.fails_on('postgresql+zxjdbc', 'zxjdbc has no support for PG arrays') def test_array_where(self): arrtable.insert().execute(intarr=[1,2,3], strarr=['abc', 'def']) arrtable.insert().execute(intarr=[4,5,6], strarr='ABC') results = arrtable.select().where(arrtable.c.intarr == [1,2,3]).execute().fetchall() eq_(len(results), 1) eq_(results[0]['intarr'], [1,2,3]) - arrtable.delete().execute() + @testing.fails_on('postgresql+pg8000', 'pg8000 has poor support for PG arrays') + @testing.fails_on('postgresql+zxjdbc', 'zxjdbc has no support for PG arrays') def test_array_concat(self): arrtable.insert().execute(intarr=[1,2,3], strarr=['abc', 'def']) results = select([arrtable.c.intarr + [4,5,6]]).execute().fetchall() eq_(len(results), 1) eq_(results[0][0], [1,2,3,4,5,6]) - arrtable.delete().execute() + @testing.fails_on('postgresql+pg8000', 'pg8000 has poor support for PG arrays') + @testing.fails_on('postgresql+zxjdbc', 'zxjdbc has no support for PG arrays') def test_array_subtype_resultprocessor(self): arrtable.insert().execute(intarr=[4,5,6], strarr=[[u'm\xe4\xe4'], [u'm\xf6\xf6']]) arrtable.insert().execute(intarr=[1,2,3], strarr=[u'm\xe4\xe4', u'm\xf6\xf6']) @@ -827,13 +1020,14 @@ class ArrayTest(TestBase, AssertsExecutionResults): eq_(len(results), 2) eq_(results[0]['strarr'], [u'm\xe4\xe4', u'm\xf6\xf6']) eq_(results[1]['strarr'], [[u'm\xe4\xe4'], [u'm\xf6\xf6']]) - arrtable.delete().execute() + @testing.fails_on('postgresql+pg8000', 'pg8000 has poor support for PG arrays') + @testing.fails_on('postgresql+zxjdbc', 'zxjdbc has no support for PG arrays') def test_array_mutability(self): class Foo(object): pass footable = Table('foo', metadata, Column('id', Integer, primary_key=True), - Column('intarr', postgres.PGArray(Integer), nullable=True) + Column('intarr', postgresql.PGArray(Integer), nullable=True) ) mapper(Foo, footable) metadata.create_all() @@ -870,19 +1064,19 @@ class ArrayTest(TestBase, AssertsExecutionResults): sess.add(foo) sess.flush() -class TimeStampTest(TestBase, AssertsExecutionResults): - __only_on__ = 'postgres' - - @testing.uses_deprecated() +class TimestampTest(TestBase, AssertsExecutionResults): + __only_on__ = 'postgresql' + def test_timestamp(self): engine = testing.db connection = engine.connect() - s = select([func.TIMESTAMP("12/25/07").label("ts")]) - result = connection.execute(s).fetchone() + + s = select(["timestamp '2007-12-25'"]) + result = connection.execute(s).first() eq_(result[0], datetime.datetime(2007, 12, 25, 0, 0)) class ServerSideCursorsTest(TestBase, AssertsExecutionResults): - __only_on__ = 'postgres' + __only_on__ = 'postgresql+psycopg2' @classmethod def setup_class(cls): @@ -927,8 +1121,8 @@ class ServerSideCursorsTest(TestBase, AssertsExecutionResults): class SpecialTypesTest(TestBase, ComparesTables): """test DDL and reflection of PG-specific types """ - __only_on__ = 'postgres' - __excluded_on__ = (('postgres', '<', (8, 3, 0)),) + __only_on__ = 'postgresql' + __excluded_on__ = (('postgresql', '<', (8, 3, 0)),) @classmethod def setup_class(cls): @@ -936,11 +1130,11 @@ class SpecialTypesTest(TestBase, ComparesTables): metadata = MetaData(testing.db) table = Table('sometable', metadata, - Column('id', postgres.PGUuid, primary_key=True), - Column('flag', postgres.PGBit), - Column('addr', postgres.PGInet), - Column('addr2', postgres.PGMacAddr), - Column('addr3', postgres.PGCidr) + Column('id', postgresql.PGUuid, primary_key=True), + Column('flag', postgresql.PGBit), + Column('addr', postgresql.PGInet), + Column('addr2', postgresql.PGMacAddr), + Column('addr3', postgresql.PGCidr) ) metadata.create_all() @@ -957,8 +1151,8 @@ class SpecialTypesTest(TestBase, ComparesTables): class MatchTest(TestBase, AssertsCompiledSQL): - __only_on__ = 'postgres' - __excluded_on__ = (('postgres', '<', (8, 3, 0)),) + __only_on__ = 'postgresql' + __excluded_on__ = (('postgresql', '<', (8, 3, 0)),) @classmethod def setup_class(cls): @@ -992,9 +1186,16 @@ class MatchTest(TestBase, AssertsCompiledSQL): def teardown_class(cls): metadata.drop_all() - def test_expression(self): + @testing.fails_on('postgresql+pg8000', 'uses positional') + @testing.fails_on('postgresql+zxjdbc', 'uses qmark') + def test_expression_pyformat(self): self.assert_compile(matchtable.c.title.match('somstr'), "matchtable.title @@ to_tsquery(%(title_1)s)") + @testing.fails_on('postgresql+psycopg2', 'uses pyformat') + @testing.fails_on('postgresql+zxjdbc', 'uses qmark') + def test_expression_positional(self): + self.assert_compile(matchtable.c.title.match('somstr'), "matchtable.title @@ to_tsquery(%s)") + def test_simple_match(self): results = matchtable.select().where(matchtable.c.title.match('python')).order_by(matchtable.c.id).execute().fetchall() eq_([2, 5], [r.id for r in results]) diff --git a/test/dialect/test_sqlite.py b/test/dialect/test_sqlite.py index eb4581e20f..448ee947c0 100644 --- a/test/dialect/test_sqlite.py +++ b/test/dialect/test_sqlite.py @@ -4,7 +4,7 @@ from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message import datetime from sqlalchemy import * from sqlalchemy import exc, sql -from sqlalchemy.databases import sqlite +from sqlalchemy.dialects.sqlite import base as sqlite, pysqlite as pysqlite_dialect from sqlalchemy.test import * @@ -19,7 +19,7 @@ class TestTypes(TestBase, AssertsExecutionResults): meta = MetaData(testing.db) t = Table('bool_table', meta, Column('id', Integer, primary_key=True), - Column('boo', sqlite.SLBoolean)) + Column('boo', Boolean)) try: meta.create_all() @@ -39,7 +39,7 @@ class TestTypes(TestBase, AssertsExecutionResults): def test_time_microseconds(self): dt = datetime.datetime(2008, 6, 27, 12, 0, 0, 125) # 125 usec eq_(str(dt), '2008-06-27 12:00:00.000125') - sldt = sqlite.SLDateTime() + sldt = sqlite._SLDateTime() bp = sldt.bind_processor(None) eq_(bp(dt), '2008-06-27 12:00:00.000125') @@ -69,59 +69,44 @@ class TestTypes(TestBase, AssertsExecutionResults): bindproc = t.dialect_impl(dialect).bind_processor(dialect) assert not bindproc or isinstance(bindproc(u"some string"), unicode) - @testing.uses_deprecated('Using String type with no length') def test_type_reflection(self): # (ask_for, roundtripped_as_if_different) - specs = [( String(), sqlite.SLString(), ), - ( String(1), sqlite.SLString(1), ), - ( String(3), sqlite.SLString(3), ), - ( Text(), sqlite.SLText(), ), - ( Unicode(), sqlite.SLString(), ), - ( Unicode(1), sqlite.SLString(1), ), - ( Unicode(3), sqlite.SLString(3), ), - ( UnicodeText(), sqlite.SLText(), ), - ( CLOB, sqlite.SLText(), ), - ( sqlite.SLChar(1), ), - ( CHAR(3), sqlite.SLChar(3), ), - ( NCHAR(2), sqlite.SLChar(2), ), - ( SmallInteger(), sqlite.SLSmallInteger(), ), - ( sqlite.SLSmallInteger(), ), - ( Binary(3), sqlite.SLBinary(), ), - ( Binary(), sqlite.SLBinary() ), - ( sqlite.SLBinary(3), sqlite.SLBinary(), ), - ( NUMERIC, sqlite.SLNumeric(), ), - ( NUMERIC(10,2), sqlite.SLNumeric(10,2), ), - ( Numeric, sqlite.SLNumeric(), ), - ( Numeric(10, 2), sqlite.SLNumeric(10, 2), ), - ( DECIMAL, sqlite.SLNumeric(), ), - ( DECIMAL(10, 2), sqlite.SLNumeric(10, 2), ), - ( Float, sqlite.SLFloat(), ), - ( sqlite.SLNumeric(), ), - ( INT, sqlite.SLInteger(), ), - ( Integer, sqlite.SLInteger(), ), - ( sqlite.SLInteger(), ), - ( TIMESTAMP, sqlite.SLDateTime(), ), - ( DATETIME, sqlite.SLDateTime(), ), - ( DateTime, sqlite.SLDateTime(), ), - ( sqlite.SLDateTime(), ), - ( DATE, sqlite.SLDate(), ), - ( Date, sqlite.SLDate(), ), - ( sqlite.SLDate(), ), - ( TIME, sqlite.SLTime(), ), - ( Time, sqlite.SLTime(), ), - ( sqlite.SLTime(), ), - ( BOOLEAN, sqlite.SLBoolean(), ), - ( Boolean, sqlite.SLBoolean(), ), - ( sqlite.SLBoolean(), ), + specs = [( String(), String(), ), + ( String(1), String(1), ), + ( String(3), String(3), ), + ( Text(), Text(), ), + ( Unicode(), String(), ), + ( Unicode(1), String(1), ), + ( Unicode(3), String(3), ), + ( UnicodeText(), Text(), ), + ( CHAR(1), ), + ( CHAR(3), CHAR(3), ), + ( NUMERIC, NUMERIC(), ), + ( NUMERIC(10,2), NUMERIC(10,2), ), + ( Numeric, NUMERIC(), ), + ( Numeric(10, 2), NUMERIC(10, 2), ), + ( DECIMAL, DECIMAL(), ), + ( DECIMAL(10, 2), DECIMAL(10, 2), ), + ( Float, Float(), ), + ( NUMERIC(), ), + ( TIMESTAMP, TIMESTAMP(), ), + ( DATETIME, DATETIME(), ), + ( DateTime, DateTime(), ), + ( DateTime(), ), + ( DATE, DATE(), ), + ( Date, Date(), ), + ( TIME, TIME(), ), + ( Time, Time(), ), + ( BOOLEAN, BOOLEAN(), ), + ( Boolean, Boolean(), ), ] columns = [Column('c%i' % (i + 1), t[0]) for i, t in enumerate(specs)] db = testing.db m = MetaData(db) t_table = Table('types', m, *columns) + m.create_all() try: - m.create_all() - m2 = MetaData(db) rt = Table('types', m2, autoload=True) try: @@ -131,7 +116,7 @@ class TestTypes(TestBase, AssertsExecutionResults): expected = [len(c) > 1 and c[1] or c[0] for c in specs] for table in rt, rv: for i, reflected in enumerate(table.c): - assert isinstance(reflected.type, type(expected[i])), type(expected[i]) + assert isinstance(reflected.type, type(expected[i])), "%d: %r" % (i, type(expected[i])) finally: db.execute('DROP VIEW types_v') finally: @@ -163,7 +148,7 @@ class TestDefaults(TestBase, AssertsExecutionResults): rt = Table('t_defaults', m2, autoload=True) expected = [c[1] for c in specs] for i, reflected in enumerate(rt.c): - eq_(reflected.server_default.arg.text, expected[i]) + eq_(str(reflected.server_default.arg), expected[i]) finally: m.drop_all() @@ -173,7 +158,7 @@ class TestDefaults(TestBase, AssertsExecutionResults): db = testing.db m = MetaData(db) - expected = ["'my_default'", '0'] + expected = ["my_default", '0'] table = """CREATE TABLE r_defaults ( data VARCHAR(40) DEFAULT 'my_default', val INTEGER NOT NULL DEFAULT 0 @@ -184,7 +169,7 @@ class TestDefaults(TestBase, AssertsExecutionResults): rt = Table('r_defaults', m, autoload=True) for i, reflected in enumerate(rt.c): - eq_(reflected.server_default.arg.text, expected[i]) + eq_(str(reflected.server_default.arg), expected[i]) finally: db.execute("DROP TABLE r_defaults") @@ -247,24 +232,24 @@ class DialectTest(TestBase, AssertsExecutionResults): def test_attached_as_schema(self): cx = testing.db.connect() try: - cx.execute('ATTACH DATABASE ":memory:" AS alt_schema') + cx.execute('ATTACH DATABASE ":memory:" AS test_schema') dialect = cx.dialect - assert dialect.table_names(cx, 'alt_schema') == [] + assert dialect.table_names(cx, 'test_schema') == [] meta = MetaData(cx) Table('created', meta, Column('id', Integer), - schema='alt_schema') + schema='test_schema') alt_master = Table('sqlite_master', meta, autoload=True, - schema='alt_schema') + schema='test_schema') meta.create_all(cx) - eq_(dialect.table_names(cx, 'alt_schema'), + eq_(dialect.table_names(cx, 'test_schema'), ['created']) assert len(alt_master.c) > 0 meta.clear() reflected = Table('created', meta, autoload=True, - schema='alt_schema') + schema='test_schema') assert len(reflected.c) == 1 cx.execute(reflected.insert(), dict(id=1)) @@ -282,9 +267,9 @@ class DialectTest(TestBase, AssertsExecutionResults): # note that sqlite_master is cleared, above meta.drop_all() - assert dialect.table_names(cx, 'alt_schema') == [] + assert dialect.table_names(cx, 'test_schema') == [] finally: - cx.execute('DETACH DATABASE alt_schema') + cx.execute('DETACH DATABASE test_schema') @testing.exclude('sqlite', '<', (2, 6), 'no database support') def test_temp_table_reflection(self): @@ -305,6 +290,20 @@ class DialectTest(TestBase, AssertsExecutionResults): pass raise + def test_set_isolation_level(self): + """Test setting the read uncommitted/serializable levels""" + eng = create_engine(testing.db.url) + eq_(eng.execute("PRAGMA read_uncommitted").scalar(), 0) + + eng = create_engine(testing.db.url, isolation_level="READ UNCOMMITTED") + eq_(eng.execute("PRAGMA read_uncommitted").scalar(), 1) + + eng = create_engine(testing.db.url, isolation_level="SERIALIZABLE") + eq_(eng.execute("PRAGMA read_uncommitted").scalar(), 0) + + assert_raises(exc.ArgumentError, create_engine, testing.db.url, + isolation_level="FOO") + class SQLTest(TestBase, AssertsCompiledSQL): """Tests SQLite-dialect specific compilation.""" diff --git a/test/engine/test_bind.py b/test/engine/test_bind.py index 7fd3009bca..1122f1632f 100644 --- a/test/engine/test_bind.py +++ b/test/engine/test_bind.py @@ -121,7 +121,7 @@ class BindTest(testing.TestBase): table = Table('test_table', metadata, Column('foo', Integer)) - metadata.connect(bind) + metadata.bind = bind assert metadata.bind is table.bind is bind metadata.create_all() @@ -199,7 +199,7 @@ class BindTest(testing.TestBase): try: e = elem(bind=bind) assert e.bind is bind - e.execute() + e.execute().close() finally: if isinstance(bind, engine.Connection): bind.close() diff --git a/test/engine/test_ddlevents.py b/test/engine/test_ddlevents.py index 5716006d93..434a5d873c 100644 --- a/test/engine/test_ddlevents.py +++ b/test/engine/test_ddlevents.py @@ -1,12 +1,13 @@ -from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message -from sqlalchemy.schema import DDL +from sqlalchemy.test.testing import assert_raises, assert_raises_message +from sqlalchemy.schema import DDL, CheckConstraint, AddConstraint, DropConstraint from sqlalchemy import create_engine from sqlalchemy import MetaData, Integer, String from sqlalchemy.test.schema import Table from sqlalchemy.test.schema import Column import sqlalchemy as tsa from sqlalchemy.test import TestBase, testing, engines - +from sqlalchemy.test.testing import AssertsCompiledSQL +from nose import SkipTest class DDLEventTest(TestBase): class Canary(object): @@ -15,25 +16,25 @@ class DDLEventTest(TestBase): self.schema_item = schema_item self.bind = bind - def before_create(self, action, schema_item, bind): + def before_create(self, action, schema_item, bind, **kw): assert self.state is None assert schema_item is self.schema_item assert bind is self.bind self.state = action - def after_create(self, action, schema_item, bind): + def after_create(self, action, schema_item, bind, **kw): assert self.state in ('before-create', 'skipped') assert schema_item is self.schema_item assert bind is self.bind self.state = action - def before_drop(self, action, schema_item, bind): + def before_drop(self, action, schema_item, bind, **kw): assert self.state is None assert schema_item is self.schema_item assert bind is self.bind self.state = action - def after_drop(self, action, schema_item, bind): + def after_drop(self, action, schema_item, bind, **kw): assert self.state in ('before-drop', 'skipped') assert schema_item is self.schema_item assert bind is self.bind @@ -232,7 +233,33 @@ class DDLExecutionTest(TestBase): assert 'klptzyxm' not in strings assert 'xyzzy' in strings assert 'fnord' in strings - + + def test_conditional_constraint(self): + metadata, users, engine = self.metadata, self.users, self.engine + nonpg_mock = engines.mock_engine(dialect_name='sqlite') + pg_mock = engines.mock_engine(dialect_name='postgresql') + + constraint = CheckConstraint('a < b',name="my_test_constraint", table=users) + + # by placing the constraint in an Add/Drop construct, + # the 'inline_ddl' flag is set to False + AddConstraint(constraint, on='postgresql').execute_at("after-create", users) + DropConstraint(constraint, on='postgresql').execute_at("before-drop", users) + + metadata.create_all(bind=nonpg_mock) + strings = " ".join(str(x) for x in nonpg_mock.mock) + assert "my_test_constraint" not in strings + metadata.drop_all(bind=nonpg_mock) + strings = " ".join(str(x) for x in nonpg_mock.mock) + assert "my_test_constraint" not in strings + + metadata.create_all(bind=pg_mock) + strings = " ".join(str(x) for x in pg_mock.mock) + assert "my_test_constraint" in strings + metadata.drop_all(bind=pg_mock) + strings = " ".join(str(x) for x in pg_mock.mock) + assert "my_test_constraint" in strings + def test_metadata(self): metadata, engine = self.metadata, self.engine DDL('mxyzptlk').execute_at('before-create', metadata) @@ -255,7 +282,10 @@ class DDLExecutionTest(TestBase): assert 'fnord' in strings def test_ddl_execute(self): - engine = create_engine('sqlite:///') + try: + engine = create_engine('sqlite:///') + except ImportError: + raise SkipTest('Requires sqlite') cx = engine.connect() table = self.users ddl = DDL('SELECT 1') @@ -286,7 +316,7 @@ class DDLExecutionTest(TestBase): r = eval(py) assert list(r) == [(1,)], py -class DDLTest(TestBase): +class DDLTest(TestBase, AssertsCompiledSQL): def mock_engine(self): executor = lambda *a, **kw: None engine = create_engine(testing.db.name + '://', @@ -297,7 +327,6 @@ class DDLTest(TestBase): def test_tokens(self): m = MetaData() - bind = self.mock_engine() sane_alone = Table('t', m, Column('id', Integer)) sane_schema = Table('t', m, Column('id', Integer), schema='s') insane_alone = Table('t t', m, Column('id', Integer)) @@ -305,20 +334,21 @@ class DDLTest(TestBase): ddl = DDL('%(schema)s-%(table)s-%(fullname)s') - eq_(ddl._expand(sane_alone, bind), '-t-t') - eq_(ddl._expand(sane_schema, bind), 's-t-s.t') - eq_(ddl._expand(insane_alone, bind), '-"t t"-"t t"') - eq_(ddl._expand(insane_schema, bind), - '"s s"-"t t"-"s s"."t t"') + dialect = self.mock_engine().dialect + self.assert_compile(ddl.against(sane_alone), '-t-t', dialect=dialect) + self.assert_compile(ddl.against(sane_schema), 's-t-s.t', dialect=dialect) + self.assert_compile(ddl.against(insane_alone), '-"t t"-"t t"', dialect=dialect) + self.assert_compile(ddl.against(insane_schema), '"s s"-"t t"-"s s"."t t"', dialect=dialect) # overrides are used piece-meal and verbatim. ddl = DDL('%(schema)s-%(table)s-%(fullname)s-%(bonus)s', context={'schema':'S S', 'table': 'T T', 'bonus': 'b'}) - eq_(ddl._expand(sane_alone, bind), 'S S-T T-t-b') - eq_(ddl._expand(sane_schema, bind), 'S S-T T-s.t-b') - eq_(ddl._expand(insane_alone, bind), 'S S-T T-"t t"-b') - eq_(ddl._expand(insane_schema, bind), - 'S S-T T-"s s"."t t"-b') + + self.assert_compile(ddl.against(sane_alone), 'S S-T T-t-b', dialect=dialect) + self.assert_compile(ddl.against(sane_schema), 'S S-T T-s.t-b', dialect=dialect) + self.assert_compile(ddl.against(insane_alone), 'S S-T T-"t t"-b', dialect=dialect) + self.assert_compile(ddl.against(insane_schema), 'S S-T T-"s s"."t t"-b', dialect=dialect) + def test_filter(self): cx = self.mock_engine() diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index 08bf80fe2f..4783c55080 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -15,18 +15,20 @@ class ExecuteTest(TestBase): global users, metadata metadata = MetaData(testing.db) users = Table('users', metadata, - Column('user_id', INT, primary_key = True), + Column('user_id', INT, primary_key = True, test_needs_autoincrement=True), Column('user_name', VARCHAR(20)), ) metadata.create_all() + @engines.close_first def teardown(self): testing.db.connect().execute(users.delete()) + @classmethod def teardown_class(cls): metadata.drop_all() - @testing.fails_on_everything_except('firebird', 'maxdb', 'sqlite') + @testing.fails_on_everything_except('firebird', 'maxdb', 'sqlite', 'mysql+pyodbc', '+zxjdbc') def test_raw_qmark(self): for conn in (testing.db, testing.db.connect()): conn.execute("insert into users (user_id, user_name) values (?, ?)", (1,"jack")) @@ -38,7 +40,8 @@ class ExecuteTest(TestBase): assert res.fetchall() == [(1, "jack"), (2, "fred"), (3, "ed"), (4, "horse"), (5, "barney"), (6, "donkey"), (7, 'sally')] conn.execute("delete from users") - @testing.fails_on_everything_except('mysql', 'postgres') + @testing.fails_on_everything_except('mysql+mysqldb', 'postgresql') + @testing.fails_on('postgresql+zxjdbc', 'sprintf not supported') # some psycopg2 versions bomb this. def test_raw_sprintf(self): for conn in (testing.db, testing.db.connect()): @@ -52,8 +55,8 @@ class ExecuteTest(TestBase): # pyformat is supported for mysql, but skipping because a few driver # versions have a bug that bombs out on this test. (1.2.2b3, 1.2.2c1, 1.2.2) - @testing.skip_if(lambda: testing.against('mysql'), 'db-api flaky') - @testing.fails_on_everything_except('postgres') + @testing.skip_if(lambda: testing.against('mysql+mysqldb'), 'db-api flaky') + @testing.fails_on_everything_except('postgresql+psycopg2') def test_raw_python(self): for conn in (testing.db, testing.db.connect()): conn.execute("insert into users (user_id, user_name) values (%(id)s, %(name)s)", {'id':1, 'name':'jack'}) @@ -63,7 +66,7 @@ class ExecuteTest(TestBase): assert res.fetchall() == [(1, "jack"), (2, "ed"), (3, "horse"), (4, 'sally')] conn.execute("delete from users") - @testing.fails_on_everything_except('sqlite', 'oracle') + @testing.fails_on_everything_except('sqlite', 'oracle+cx_oracle') def test_raw_named(self): for conn in (testing.db, testing.db.connect()): conn.execute("insert into users (user_id, user_name) values (:id, :name)", {'id':1, 'name':'jack'}) @@ -81,11 +84,12 @@ class ExecuteTest(TestBase): except tsa.exc.DBAPIError: assert True - @testing.fails_on('mssql', 'rowcount returns -1') def test_empty_insert(self): """test that execute() interprets [] as a list with no params""" result = testing.db.execute(users.insert().values(user_name=bindparam('name')), []) - eq_(result.rowcount, 1) + eq_(testing.db.execute(users.select()).fetchall(), [ + (1, None) + ]) class ProxyConnectionTest(TestBase): @testing.fails_on('firebird', 'Data type unknown') @@ -102,6 +106,7 @@ class ProxyConnectionTest(TestBase): return execute(clauseelement, *multiparams, **params) def cursor_execute(self, execute, cursor, statement, parameters, context, executemany): + print "CE", statement, parameters cursor_stmts.append( (statement, parameters, None) ) @@ -118,8 +123,8 @@ class ProxyConnectionTest(TestBase): break for engine in ( - engines.testing_engine(options=dict(proxy=MyProxy())), - engines.testing_engine(options=dict(proxy=MyProxy(), strategy='threadlocal')) + engines.testing_engine(options=dict(implicit_returning=False, proxy=MyProxy())), + engines.testing_engine(options=dict(implicit_returning=False, proxy=MyProxy(), strategy='threadlocal')) ): m = MetaData(engine) @@ -131,6 +136,7 @@ class ProxyConnectionTest(TestBase): t1.insert().execute(c1=6) assert engine.execute("select * from t1").fetchall() == [(5, 'some data'), (6, 'foo')] finally: + pass m.drop_all() engine.dispose() @@ -143,14 +149,14 @@ class ProxyConnectionTest(TestBase): ("DROP TABLE t1", {}, None) ] - if engine.dialect.preexecute_pk_sequences: + if True: # or engine.dialect.preexecute_pk_sequences: cursor = [ - ("CREATE TABLE t1", {}, None), + ("CREATE TABLE t1", {}, ()), ("INSERT INTO t1 (c1, c2)", {'c2': 'some data', 'c1': 5}, [5, 'some data']), ("SELECT lower", {'lower_2':'Foo'}, ['Foo']), ("INSERT INTO t1 (c1, c2)", {'c2': 'foo', 'c1': 6}, [6, 'foo']), - ("select * from t1", {}, None), - ("DROP TABLE t1", {}, None) + ("select * from t1", {}, ()), + ("DROP TABLE t1", {}, ()) ] else: cursor = [ diff --git a/test/engine/test_metadata.py b/test/engine/test_metadata.py index ca4fbaa48a..784a7b9ce6 100644 --- a/test/engine/test_metadata.py +++ b/test/engine/test_metadata.py @@ -1,11 +1,11 @@ from sqlalchemy.test.testing import assert_raises, assert_raises_message import pickle -from sqlalchemy import MetaData -from sqlalchemy import Integer, String, UniqueConstraint, CheckConstraint, ForeignKey +from sqlalchemy import Integer, String, UniqueConstraint, CheckConstraint, ForeignKey, MetaData from sqlalchemy.test.schema import Table from sqlalchemy.test.schema import Column +from sqlalchemy import schema import sqlalchemy as tsa -from sqlalchemy.test import TestBase, ComparesTables, testing, engines +from sqlalchemy.test import TestBase, ComparesTables, AssertsCompiledSQL, testing, engines from sqlalchemy.test.testing import eq_ class MetaDataTest(TestBase, ComparesTables): @@ -83,7 +83,7 @@ class MetaDataTest(TestBase, ComparesTables): meta.create_all(testing.db) try: - for test, has_constraints in ((test_to_metadata, True), (test_pickle, True), (test_pickle_via_reflect, False)): + for test, has_constraints in ((test_to_metadata, True), (test_pickle, True),(test_pickle_via_reflect, False)): table_c, table2_c = test() self.assert_tables_equal(table, table_c) self.assert_tables_equal(table2, table2_c) @@ -143,29 +143,30 @@ class MetaDataTest(TestBase, ComparesTables): MetaData(testing.db), autoload=True) -class TableOptionsTest(TestBase): - def setup(self): - self.engine = engines.mock_engine() - self.metadata = MetaData(self.engine) - +class TableOptionsTest(TestBase, AssertsCompiledSQL): def test_prefixes(self): - table1 = Table("temporary_table_1", self.metadata, + table1 = Table("temporary_table_1", MetaData(), Column("col1", Integer), prefixes = ["TEMPORARY"]) - table1.create() - assert [str(x) for x in self.engine.mock if 'CREATE TEMPORARY TABLE' in str(x)] - del self.engine.mock[:] - table2 = Table("temporary_table_2", self.metadata, + + self.assert_compile( + schema.CreateTable(table1), + "CREATE TEMPORARY TABLE temporary_table_1 (col1 INTEGER)" + ) + + table2 = Table("temporary_table_2", MetaData(), Column("col1", Integer), prefixes = ["VIRTUAL"]) - table2.create() - assert [str(x) for x in self.engine.mock if 'CREATE VIRTUAL TABLE' in str(x)] + self.assert_compile( + schema.CreateTable(table2), + "CREATE VIRTUAL TABLE temporary_table_2 (col1 INTEGER)" + ) def test_table_info(self): - - t1 = Table('foo', self.metadata, info={'x':'y'}) - t2 = Table('bar', self.metadata, info={}) - t3 = Table('bat', self.metadata) + metadata = MetaData() + t1 = Table('foo', metadata, info={'x':'y'}) + t2 = Table('bar', metadata, info={}) + t3 = Table('bat', metadata) assert t1.info == {'x':'y'} assert t2.info == {} assert t3.info == {} diff --git a/test/engine/test_parseconnect.py b/test/engine/test_parseconnect.py index 6b7ac37b20..90c0969bed 100644 --- a/test/engine/test_parseconnect.py +++ b/test/engine/test_parseconnect.py @@ -1,4 +1,6 @@ -import ConfigParser, StringIO +from sqlalchemy.test.testing import assert_raises, assert_raises_message +import ConfigParser +import StringIO import sqlalchemy.engine.url as url from sqlalchemy import create_engine, engine_from_config import sqlalchemy as tsa @@ -28,8 +30,6 @@ class ParseConnectTest(TestBase): 'dbtype://username:apples%2Foranges@hostspec/mydatabase', ): u = url.make_url(text) - print u, text - print "username=", u.username, "password=", u.password, "database=", u.database, "host=", u.host assert u.drivername == 'dbtype' assert u.username == 'username' or u.username is None assert u.password == 'password' or u.password == 'apples/oranges' or u.password is None @@ -41,21 +41,28 @@ class CreateEngineTest(TestBase): def test_connect_query(self): dbapi = MockDBAPI(foober='12', lala='18', fooz='somevalue') - # start the postgres dialect, but put our mock DBAPI as the module instead of psycopg - e = create_engine('postgres://scott:tiger@somehost/test?foober=12&lala=18&fooz=somevalue', module=dbapi) + e = create_engine( + 'postgresql://scott:tiger@somehost/test?foober=12&lala=18&fooz=somevalue', + module=dbapi, + _initialize=False + ) c = e.connect() def test_kwargs(self): dbapi = MockDBAPI(foober=12, lala=18, hoho={'this':'dict'}, fooz='somevalue') - # start the postgres dialect, but put our mock DBAPI as the module instead of psycopg - e = create_engine('postgres://scott:tiger@somehost/test?fooz=somevalue', connect_args={'foober':12, 'lala':18, 'hoho':{'this':'dict'}}, module=dbapi) + e = create_engine( + 'postgresql://scott:tiger@somehost/test?fooz=somevalue', + connect_args={'foober':12, 'lala':18, 'hoho':{'this':'dict'}}, + module=dbapi, + _initialize=False + ) c = e.connect() def test_coerce_config(self): raw = r""" [prefixed] -sqlalchemy.url=postgres://scott:tiger@somehost/test?fooz=somevalue +sqlalchemy.url=postgresql://scott:tiger@somehost/test?fooz=somevalue sqlalchemy.convert_unicode=0 sqlalchemy.echo=false sqlalchemy.echo_pool=1 @@ -65,7 +72,7 @@ sqlalchemy.pool_size=2 sqlalchemy.pool_threadlocal=1 sqlalchemy.pool_timeout=10 [plain] -url=postgres://scott:tiger@somehost/test?fooz=somevalue +url=postgresql://scott:tiger@somehost/test?fooz=somevalue convert_unicode=0 echo=0 echo_pool=1 @@ -79,7 +86,7 @@ pool_timeout=10 ini.readfp(StringIO.StringIO(raw)) expected = { - 'url': 'postgres://scott:tiger@somehost/test?fooz=somevalue', + 'url': 'postgresql://scott:tiger@somehost/test?fooz=somevalue', 'convert_unicode': 0, 'echo': False, 'echo_pool': True, @@ -97,17 +104,17 @@ pool_timeout=10 self.assert_(tsa.engine._coerce_config(plain, '') == expected) def test_engine_from_config(self): - dbapi = MockDBAPI() + dbapi = mock_dbapi config = { - 'sqlalchemy.url':'postgres://scott:tiger@somehost/test?fooz=somevalue', + 'sqlalchemy.url':'postgresql://scott:tiger@somehost/test?fooz=somevalue', 'sqlalchemy.pool_recycle':'50', 'sqlalchemy.echo':'true' } e = engine_from_config(config, module=dbapi) assert e.pool._recycle == 50 - assert e.url == url.make_url('postgres://scott:tiger@somehost/test?fooz=somevalue') + assert e.url == url.make_url('postgresql://scott:tiger@somehost/test?fooz=somevalue') assert e.echo is True def test_custom(self): @@ -116,109 +123,77 @@ pool_timeout=10 def connect(): return dbapi.connect(foober=12, lala=18, fooz='somevalue', hoho={'this':'dict'}) - # start the postgres dialect, but put our mock DBAPI as the module instead of psycopg - e = create_engine('postgres://', creator=connect, module=dbapi) + # start the postgresql dialect, but put our mock DBAPI as the module instead of psycopg + e = create_engine('postgresql://', creator=connect, module=dbapi, _initialize=False) c = e.connect() def test_recycle(self): dbapi = MockDBAPI(foober=12, lala=18, hoho={'this':'dict'}, fooz='somevalue') - e = create_engine('postgres://', pool_recycle=472, module=dbapi) + e = create_engine('postgresql://', pool_recycle=472, module=dbapi, _initialize=False) assert e.pool._recycle == 472 def test_badargs(self): - # good arg, use MockDBAPI to prevent oracle import errors - e = create_engine('oracle://', use_ansi=True, module=MockDBAPI()) - - try: - e = create_engine("foobar://", module=MockDBAPI()) - assert False - except ImportError: - assert True + assert_raises(ImportError, create_engine, "foobar://", module=mock_dbapi) # bad arg - try: - e = create_engine('postgres://', use_ansi=True, module=MockDBAPI()) - assert False - except TypeError: - assert True + assert_raises(TypeError, create_engine, 'postgresql://', use_ansi=True, module=mock_dbapi) # bad arg - try: - e = create_engine('oracle://', lala=5, use_ansi=True, module=MockDBAPI()) - assert False - except TypeError: - assert True + assert_raises(TypeError, create_engine, 'oracle://', lala=5, use_ansi=True, module=mock_dbapi) - try: - e = create_engine('postgres://', lala=5, module=MockDBAPI()) - assert False - except TypeError: - assert True + assert_raises(TypeError, create_engine, 'postgresql://', lala=5, module=mock_dbapi) - try: - e = create_engine('sqlite://', lala=5) - assert False - except TypeError: - assert True + assert_raises(TypeError, create_engine,'sqlite://', lala=5, module=mock_sqlite_dbapi) - try: - e = create_engine('mysql://', use_unicode=True, module=MockDBAPI()) - assert False - except TypeError: - assert True - - try: - # sqlite uses SingletonThreadPool which doesnt have max_overflow - e = create_engine('sqlite://', max_overflow=5) - assert False - except TypeError: - assert True + assert_raises(TypeError, create_engine, 'mysql+mysqldb://', use_unicode=True, module=mock_dbapi) - e = create_engine('mysql://', module=MockDBAPI(), connect_args={'use_unicode':True}, convert_unicode=True) + # sqlite uses SingletonThreadPool which doesnt have max_overflow + assert_raises(TypeError, create_engine, 'sqlite://', max_overflow=5, + module=mock_sqlite_dbapi) - e = create_engine('sqlite://', connect_args={'use_unicode':True}, convert_unicode=True) try: - c = e.connect() - assert False - except tsa.exc.DBAPIError: - assert True + e = create_engine('sqlite://', connect_args={'use_unicode':True}, convert_unicode=True) + except ImportError: + # no sqlite + pass + else: + # raises DBAPIerror due to use_unicode not a sqlite arg + assert_raises(tsa.exc.DBAPIError, e.connect) def test_urlattr(self): """test the url attribute on ``Engine``.""" - e = create_engine('mysql://scott:tiger@localhost/test', module=MockDBAPI()) + e = create_engine('mysql://scott:tiger@localhost/test', module=mock_dbapi, _initialize=False) u = url.make_url('mysql://scott:tiger@localhost/test') - e2 = create_engine(u, module=MockDBAPI()) + e2 = create_engine(u, module=mock_dbapi, _initialize=False) assert e.url.drivername == e2.url.drivername == 'mysql' assert e.url.username == e2.url.username == 'scott' assert e2.url is u def test_poolargs(self): """test that connection pool args make it thru""" - e = create_engine('postgres://', creator=None, pool_recycle=50, echo_pool=None, module=MockDBAPI()) + e = create_engine('postgresql://', creator=None, pool_recycle=50, echo_pool=None, module=mock_dbapi, _initialize=False) assert e.pool._recycle == 50 # these args work for QueuePool - e = create_engine('postgres://', max_overflow=8, pool_timeout=60, poolclass=tsa.pool.QueuePool, module=MockDBAPI()) + e = create_engine('postgresql://', max_overflow=8, pool_timeout=60, poolclass=tsa.pool.QueuePool, module=mock_dbapi) - try: - # but not SingletonThreadPool - e = create_engine('sqlite://', max_overflow=8, pool_timeout=60, poolclass=tsa.pool.SingletonThreadPool) - assert False - except TypeError: - assert True + # but not SingletonThreadPool + assert_raises(TypeError, create_engine, 'sqlite://', max_overflow=8, pool_timeout=60, + poolclass=tsa.pool.SingletonThreadPool, module=mock_sqlite_dbapi) class MockDBAPI(object): def __init__(self, **kwargs): self.kwargs = kwargs self.paramstyle = 'named' - def connect(self, **kwargs): - print kwargs, self.kwargs + def connect(self, *args, **kwargs): for k in self.kwargs: assert k in kwargs, "key %s not present in dictionary" % k assert kwargs[k]==self.kwargs[k], "value %s does not match %s" % (kwargs[k], self.kwargs[k]) return MockConnection() class MockConnection(object): + def get_server_info(self): + return "5.0" def close(self): pass def cursor(self): @@ -227,4 +202,6 @@ class MockCursor(object): def close(self): pass mock_dbapi = MockDBAPI() - +mock_sqlite_dbapi = msd = MockDBAPI() +msd.version_info = msd.sqlite_version_info = (99, 9, 9) +msd.sqlite_version = '99.9.9' diff --git a/test/engine/test_pool.py b/test/engine/test_pool.py index d135ad337a..68637281e1 100644 --- a/test/engine/test_pool.py +++ b/test/engine/test_pool.py @@ -1,7 +1,8 @@ -import threading, time, gc -from sqlalchemy import pool, interfaces +import threading, time +from sqlalchemy import pool, interfaces, create_engine, select import sqlalchemy as tsa -from sqlalchemy.test import TestBase +from sqlalchemy.test import TestBase, testing +from sqlalchemy.test.util import gc_collect, lazy_gc mcid = 1 @@ -51,7 +52,6 @@ class PoolTest(PoolTestBase): connection2 = manager.connect('foo.db') connection3 = manager.connect('bar.db') - print "connection " + repr(connection) self.assert_(connection.cursor() is not None) self.assert_(connection is connection2) self.assert_(connection2 is not connection3) @@ -70,8 +70,6 @@ class PoolTest(PoolTestBase): connection = manager.connect('foo.db') connection2 = manager.connect('foo.db') - print "connection " + repr(connection) - self.assert_(connection.cursor() is not None) self.assert_(connection is not connection2) @@ -103,7 +101,8 @@ class PoolTest(PoolTestBase): c2.close() else: c2 = None - + lazy_gc() + if useclose: c1 = p.connect() c2 = p.connect() @@ -117,6 +116,8 @@ class PoolTest(PoolTestBase): # extra tests with QueuePool to ensure connections get __del__()ed when dereferenced if isinstance(p, pool.QueuePool): + lazy_gc() + self.assert_(p.checkedout() == 0) c1 = p.connect() c2 = p.connect() @@ -126,6 +127,7 @@ class PoolTest(PoolTestBase): else: c2 = None c1 = None + lazy_gc() self.assert_(p.checkedout() == 0) def test_properties(self): @@ -164,6 +166,8 @@ class PoolTest(PoolTestBase): def __init__(self): if hasattr(self, 'connect'): self.connect = self.inst_connect + if hasattr(self, 'first_connect'): + self.first_connect = self.inst_first_connect if hasattr(self, 'checkout'): self.checkout = self.inst_checkout if hasattr(self, 'checkin'): @@ -171,14 +175,17 @@ class PoolTest(PoolTestBase): self.clear() def clear(self): self.connected = [] + self.first_connected = [] self.checked_out = [] self.checked_in = [] - def assert_total(innerself, conn, cout, cin): + def assert_total(innerself, conn, fconn, cout, cin): self.assert_(len(innerself.connected) == conn) + self.assert_(len(innerself.first_connected) == fconn) self.assert_(len(innerself.checked_out) == cout) self.assert_(len(innerself.checked_in) == cin) - def assert_in(innerself, item, in_conn, in_cout, in_cin): + def assert_in(innerself, item, in_conn, in_fconn, in_cout, in_cin): self.assert_((item in innerself.connected) == in_conn) + self.assert_((item in innerself.first_connected) == in_fconn) self.assert_((item in innerself.checked_out) == in_cout) self.assert_((item in innerself.checked_in) == in_cin) def inst_connect(self, con, record): @@ -186,6 +193,11 @@ class PoolTest(PoolTestBase): assert con is not None assert record is not None self.connected.append(con) + def inst_first_connect(self, con, record): + print "first_connect(%s, %s)" % (con, record) + assert con is not None + assert record is not None + self.first_connected.append(con) def inst_checkout(self, con, record, proxy): print "checkout(%s, %s, %s)" % (con, record, proxy) assert con is not None @@ -203,6 +215,9 @@ class PoolTest(PoolTestBase): class ListenConnect(InstrumentingListener): def connect(self, con, record): pass + class ListenFirstConnect(InstrumentingListener): + def first_connect(self, con, record): + pass class ListenCheckOut(InstrumentingListener): def checkout(self, con, record, proxy, num): pass @@ -214,40 +229,43 @@ class PoolTest(PoolTestBase): return pool.QueuePool(creator=lambda: dbapi.connect('foo.db'), use_threadlocal=False, **kw) - def assert_listeners(p, total, conn, cout, cin): + def assert_listeners(p, total, conn, fconn, cout, cin): for instance in (p, p.recreate()): self.assert_(len(instance.listeners) == total) self.assert_(len(instance._on_connect) == conn) + self.assert_(len(instance._on_first_connect) == fconn) self.assert_(len(instance._on_checkout) == cout) self.assert_(len(instance._on_checkin) == cin) p = _pool() - assert_listeners(p, 0, 0, 0, 0) + assert_listeners(p, 0, 0, 0, 0, 0) p.add_listener(ListenAll()) - assert_listeners(p, 1, 1, 1, 1) + assert_listeners(p, 1, 1, 1, 1, 1) p.add_listener(ListenConnect()) - assert_listeners(p, 2, 2, 1, 1) + assert_listeners(p, 2, 2, 1, 1, 1) + + p.add_listener(ListenFirstConnect()) + assert_listeners(p, 3, 2, 2, 1, 1) p.add_listener(ListenCheckOut()) - assert_listeners(p, 3, 2, 2, 1) + assert_listeners(p, 4, 2, 2, 2, 1) p.add_listener(ListenCheckIn()) - assert_listeners(p, 4, 2, 2, 2) + assert_listeners(p, 5, 2, 2, 2, 2) del p - print "----" snoop = ListenAll() p = _pool(listeners=[snoop]) - assert_listeners(p, 1, 1, 1, 1) + assert_listeners(p, 1, 1, 1, 1, 1) c = p.connect() - snoop.assert_total(1, 1, 0) + snoop.assert_total(1, 1, 1, 0) cc = c.connection - snoop.assert_in(cc, True, True, False) + snoop.assert_in(cc, True, True, True, False) c.close() - snoop.assert_in(cc, True, True, True) + snoop.assert_in(cc, True, True, True, True) del c, cc snoop.clear() @@ -255,10 +273,11 @@ class PoolTest(PoolTestBase): # this one depends on immediate gc c = p.connect() cc = c.connection - snoop.assert_in(cc, False, True, False) - snoop.assert_total(0, 1, 0) + snoop.assert_in(cc, False, False, True, False) + snoop.assert_total(0, 0, 1, 0) del c, cc - snoop.assert_total(0, 1, 1) + lazy_gc() + snoop.assert_total(0, 0, 1, 1) p.dispose() snoop.clear() @@ -266,44 +285,46 @@ class PoolTest(PoolTestBase): c = p.connect() c.close() c = p.connect() - snoop.assert_total(1, 2, 1) + snoop.assert_total(1, 0, 2, 1) c.close() - snoop.assert_total(1, 2, 2) + snoop.assert_total(1, 0, 2, 2) # invalidation p.dispose() snoop.clear() c = p.connect() - snoop.assert_total(1, 1, 0) + snoop.assert_total(1, 0, 1, 0) c.invalidate() - snoop.assert_total(1, 1, 1) + snoop.assert_total(1, 0, 1, 1) c.close() - snoop.assert_total(1, 1, 1) + snoop.assert_total(1, 0, 1, 1) del c - snoop.assert_total(1, 1, 1) + lazy_gc() + snoop.assert_total(1, 0, 1, 1) c = p.connect() - snoop.assert_total(2, 2, 1) + snoop.assert_total(2, 0, 2, 1) c.close() del c - snoop.assert_total(2, 2, 2) + lazy_gc() + snoop.assert_total(2, 0, 2, 2) # detached p.dispose() snoop.clear() c = p.connect() - snoop.assert_total(1, 1, 0) + snoop.assert_total(1, 0, 1, 0) c.detach() - snoop.assert_total(1, 1, 0) + snoop.assert_total(1, 0, 1, 0) c.close() del c - snoop.assert_total(1, 1, 0) + snoop.assert_total(1, 0, 1, 0) c = p.connect() - snoop.assert_total(2, 2, 0) + snoop.assert_total(2, 0, 2, 0) c.close() del c - snoop.assert_total(2, 2, 1) + snoop.assert_total(2, 0, 2, 1) def test_listeners_callables(self): dbapi = MockDBAPI() @@ -362,262 +383,293 @@ class PoolTest(PoolTestBase): c.close() assert counts == [1, 2, 3] + def test_listener_after_oninit(self): + """Test that listeners are called after OnInit is removed""" + called = [] + def listener(*args): + called.append(True) + listener.connect = listener + engine = create_engine(testing.db.url) + engine.pool.add_listener(listener) + engine.execute(select([1])) + assert called, "Listener not called on connect" + + class QueuePoolTest(PoolTestBase): - def testqueuepool_del(self): - self._do_testqueuepool(useclose=False) - - def testqueuepool_close(self): - self._do_testqueuepool(useclose=True) - - def _do_testqueuepool(self, useclose=False): - p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = False) - - def status(pool): - tup = (pool.size(), pool.checkedin(), pool.overflow(), pool.checkedout()) - print "Pool size: %d Connections in pool: %d Current Overflow: %d Current Checked out connections: %d" % tup - return tup - - c1 = p.connect() - self.assert_(status(p) == (3,0,-2,1)) - c2 = p.connect() - self.assert_(status(p) == (3,0,-1,2)) - c3 = p.connect() - self.assert_(status(p) == (3,0,0,3)) - c4 = p.connect() - self.assert_(status(p) == (3,0,1,4)) - c5 = p.connect() - self.assert_(status(p) == (3,0,2,5)) - c6 = p.connect() - self.assert_(status(p) == (3,0,3,6)) - if useclose: - c4.close() - c3.close() - c2.close() - else: - c4 = c3 = c2 = None - self.assert_(status(p) == (3,3,3,3)) - if useclose: - c1.close() - c5.close() - c6.close() - else: - c1 = c5 = c6 = None - self.assert_(status(p) == (3,3,0,0)) - c1 = p.connect() - c2 = p.connect() - self.assert_(status(p) == (3, 1, 0, 2), status(p)) - if useclose: - c2.close() - else: - c2 = None - self.assert_(status(p) == (3, 2, 0, 1)) - - def test_timeout(self): - p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = 0, use_threadlocal = False, timeout=2) - c1 = p.connect() - c2 = p.connect() - c3 = p.connect() - now = time.time() - try: - c4 = p.connect() - assert False - except tsa.exc.TimeoutError, e: - assert int(time.time() - now) == 2 - - def test_timeout_race(self): - # test a race condition where the initial connecting threads all race - # to queue.Empty, then block on the mutex. each thread consumes a - # connection as they go in. when the limit is reached, the remaining - # threads go in, and get TimeoutError; even though they never got to - # wait for the timeout on queue.get(). the fix involves checking the - # timeout again within the mutex, and if so, unlocking and throwing - # them back to the start of do_get() - p = pool.QueuePool(creator = lambda: mock_dbapi.connect(delay=.05), pool_size = 2, max_overflow = 1, use_threadlocal = False, timeout=3) - timeouts = [] - def checkout(): - for x in xrange(1): - now = time.time() - try: - c1 = p.connect() - except tsa.exc.TimeoutError, e: - timeouts.append(int(time.time()) - now) - continue - time.sleep(4) - c1.close() - - threads = [] - for i in xrange(10): - th = threading.Thread(target=checkout) - th.start() - threads.append(th) - for th in threads: - th.join() - - print timeouts - assert len(timeouts) > 0 - for t in timeouts: - assert abs(t - 3) < 1, "Not all timeouts were 3 seconds: " + repr(timeouts) - - def _test_overflow(self, thread_count, max_overflow): - def creator(): - time.sleep(.05) - return mock_dbapi.connect() - - p = pool.QueuePool(creator=creator, - pool_size=3, timeout=2, - max_overflow=max_overflow) - peaks = [] - def whammy(): - for i in range(10): - try: - con = p.connect() - time.sleep(.005) - peaks.append(p.overflow()) - con.close() - del con - except tsa.exc.TimeoutError: - pass - threads = [] - for i in xrange(thread_count): - th = threading.Thread(target=whammy) - th.start() - threads.append(th) - for th in threads: - th.join() - - self.assert_(max(peaks) <= max_overflow) - - def test_no_overflow(self): - self._test_overflow(40, 0) - - def test_max_overflow(self): - self._test_overflow(40, 5) - - def test_mixed_close(self): - p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True) - c1 = p.connect() - c2 = p.connect() - assert c1 is c2 - c1.close() - c2 = None - assert p.checkedout() == 1 - c1 = None - assert p.checkedout() == 0 - - def test_weakref_kaboom(self): - p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True) - c1 = p.connect() - c2 = p.connect() - c1.close() - c2 = None - del c1 - del c2 - gc.collect() - assert p.checkedout() == 0 - c3 = p.connect() - assert c3 is not None - - def test_trick_the_counter(self): - """this is a "flaw" in the connection pool; since threadlocal uses a single ConnectionFairy per thread - with an open/close counter, you can fool the counter into giving you a ConnectionFairy with an - ambiguous counter. i.e. its not true reference counting.""" - p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True) - c1 = p.connect() - c2 = p.connect() - assert c1 is c2 - c1.close() - c2 = p.connect() - c2.close() - self.assert_(p.checkedout() != 0) - - c2.close() - self.assert_(p.checkedout() == 0) - - def test_recycle(self): - p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 1, max_overflow = 0, use_threadlocal = False, recycle=3) - - c1 = p.connect() - c_id = id(c1.connection) - c1.close() - c2 = p.connect() - assert id(c2.connection) == c_id - c2.close() - time.sleep(4) - c3= p.connect() - assert id(c3.connection) != c_id - - def test_invalidate(self): - dbapi = MockDBAPI() - p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False) - c1 = p.connect() - c_id = c1.connection.id - c1.close(); c1=None - c1 = p.connect() - assert c1.connection.id == c_id - c1.invalidate() - c1 = None - - c1 = p.connect() - assert c1.connection.id != c_id - - def test_recreate(self): - dbapi = MockDBAPI() - p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False) - p2 = p.recreate() - assert p2.size() == 1 - assert p2._use_threadlocal is False - assert p2._max_overflow == 0 - - def test_reconnect(self): - """tests reconnect operations at the pool level. SA's engine/dialect includes another - layer of reconnect support for 'database was lost' errors.""" + def testqueuepool_del(self): + self._do_testqueuepool(useclose=False) + + def testqueuepool_close(self): + self._do_testqueuepool(useclose=True) + + def _do_testqueuepool(self, useclose=False): + p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = False) + + def status(pool): + tup = (pool.size(), pool.checkedin(), pool.overflow(), pool.checkedout()) + print "Pool size: %d Connections in pool: %d Current Overflow: %d Current Checked out connections: %d" % tup + return tup + + c1 = p.connect() + self.assert_(status(p) == (3,0,-2,1)) + c2 = p.connect() + self.assert_(status(p) == (3,0,-1,2)) + c3 = p.connect() + self.assert_(status(p) == (3,0,0,3)) + c4 = p.connect() + self.assert_(status(p) == (3,0,1,4)) + c5 = p.connect() + self.assert_(status(p) == (3,0,2,5)) + c6 = p.connect() + self.assert_(status(p) == (3,0,3,6)) + if useclose: + c4.close() + c3.close() + c2.close() + else: + c4 = c3 = c2 = None + lazy_gc() + + self.assert_(status(p) == (3,3,3,3)) + if useclose: + c1.close() + c5.close() + c6.close() + else: + c1 = c5 = c6 = None + lazy_gc() + + self.assert_(status(p) == (3,3,0,0)) + + c1 = p.connect() + c2 = p.connect() + self.assert_(status(p) == (3, 1, 0, 2), status(p)) + if useclose: + c2.close() + else: + c2 = None + lazy_gc() + + self.assert_(status(p) == (3, 2, 0, 1)) + + c1.close() + + lazy_gc() + assert not pool._refs - dbapi = MockDBAPI() - p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False) - c1 = p.connect() - c_id = c1.connection.id - c1.close(); c1=None - - c1 = p.connect() - assert c1.connection.id == c_id - dbapi.raise_error = True - c1.invalidate() - c1 = None - - c1 = p.connect() - assert c1.connection.id != c_id - - def test_detach(self): - dbapi = MockDBAPI() - p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False) - - c1 = p.connect() - c1.detach() - c_id = c1.connection.id - - c2 = p.connect() - assert c2.connection.id != c1.connection.id - dbapi.raise_error = True - - c2.invalidate() - c2 = None - - c2 = p.connect() - assert c2.connection.id != c1.connection.id - - con = c1.connection - - assert not con.closed - c1.close() - assert con.closed - - def test_threadfairy(self): - p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True) - c1 = p.connect() - c1.close() - c2 = p.connect() - assert c2.connection is not None + def test_timeout(self): + p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = 0, use_threadlocal = False, timeout=2) + c1 = p.connect() + c2 = p.connect() + c3 = p.connect() + now = time.time() + try: + c4 = p.connect() + assert False + except tsa.exc.TimeoutError, e: + assert int(time.time() - now) == 2 + + def test_timeout_race(self): + # test a race condition where the initial connecting threads all race + # to queue.Empty, then block on the mutex. each thread consumes a + # connection as they go in. when the limit is reached, the remaining + # threads go in, and get TimeoutError; even though they never got to + # wait for the timeout on queue.get(). the fix involves checking the + # timeout again within the mutex, and if so, unlocking and throwing + # them back to the start of do_get() + p = pool.QueuePool(creator = lambda: mock_dbapi.connect(delay=.05), pool_size = 2, max_overflow = 1, use_threadlocal = False, timeout=3) + timeouts = [] + def checkout(): + for x in xrange(1): + now = time.time() + try: + c1 = p.connect() + except tsa.exc.TimeoutError, e: + timeouts.append(int(time.time()) - now) + continue + time.sleep(4) + c1.close() + + threads = [] + for i in xrange(10): + th = threading.Thread(target=checkout) + th.start() + threads.append(th) + for th in threads: + th.join() + + print timeouts + assert len(timeouts) > 0 + for t in timeouts: + assert abs(t - 3) < 1, "Not all timeouts were 3 seconds: " + repr(timeouts) + + def _test_overflow(self, thread_count, max_overflow): + def creator(): + time.sleep(.05) + return mock_dbapi.connect() + + p = pool.QueuePool(creator=creator, + pool_size=3, timeout=2, + max_overflow=max_overflow) + peaks = [] + def whammy(): + for i in range(10): + try: + con = p.connect() + time.sleep(.005) + peaks.append(p.overflow()) + con.close() + del con + except tsa.exc.TimeoutError: + pass + threads = [] + for i in xrange(thread_count): + th = threading.Thread(target=whammy) + th.start() + threads.append(th) + for th in threads: + th.join() + + self.assert_(max(peaks) <= max_overflow) + + lazy_gc() + assert not pool._refs + + def test_no_overflow(self): + self._test_overflow(40, 0) + + def test_max_overflow(self): + self._test_overflow(40, 5) + + def test_mixed_close(self): + p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True) + c1 = p.connect() + c2 = p.connect() + assert c1 is c2 + c1.close() + c2 = None + assert p.checkedout() == 1 + c1 = None + lazy_gc() + assert p.checkedout() == 0 + + lazy_gc() + assert not pool._refs + + def test_weakref_kaboom(self): + p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True) + c1 = p.connect() + c2 = p.connect() + c1.close() + c2 = None + del c1 + del c2 + gc_collect() + assert p.checkedout() == 0 + c3 = p.connect() + assert c3 is not None + + def test_trick_the_counter(self): + """this is a "flaw" in the connection pool; since threadlocal uses a single ConnectionFairy per thread + with an open/close counter, you can fool the counter into giving you a ConnectionFairy with an + ambiguous counter. i.e. its not true reference counting.""" + p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True) + c1 = p.connect() + c2 = p.connect() + assert c1 is c2 + c1.close() + c2 = p.connect() + c2.close() + self.assert_(p.checkedout() != 0) + + c2.close() + self.assert_(p.checkedout() == 0) + + def test_recycle(self): + p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 1, max_overflow = 0, use_threadlocal = False, recycle=3) + + c1 = p.connect() + c_id = id(c1.connection) + c1.close() + c2 = p.connect() + assert id(c2.connection) == c_id + c2.close() + time.sleep(4) + c3= p.connect() + assert id(c3.connection) != c_id + + def test_invalidate(self): + dbapi = MockDBAPI() + p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False) + c1 = p.connect() + c_id = c1.connection.id + c1.close(); c1=None + c1 = p.connect() + assert c1.connection.id == c_id + c1.invalidate() + c1 = None + + c1 = p.connect() + assert c1.connection.id != c_id + + def test_recreate(self): + dbapi = MockDBAPI() + p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False) + p2 = p.recreate() + assert p2.size() == 1 + assert p2._use_threadlocal is False + assert p2._max_overflow == 0 + + def test_reconnect(self): + """tests reconnect operations at the pool level. SA's engine/dialect includes another + layer of reconnect support for 'database was lost' errors.""" + + dbapi = MockDBAPI() + p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False) + c1 = p.connect() + c_id = c1.connection.id + c1.close(); c1=None + + c1 = p.connect() + assert c1.connection.id == c_id + dbapi.raise_error = True + c1.invalidate() + c1 = None + + c1 = p.connect() + assert c1.connection.id != c_id + + def test_detach(self): + dbapi = MockDBAPI() + p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False) + + c1 = p.connect() + c1.detach() + c_id = c1.connection.id + + c2 = p.connect() + assert c2.connection.id != c1.connection.id + dbapi.raise_error = True + + c2.invalidate() + c2 = None + + c2 = p.connect() + assert c2.connection.id != c1.connection.id + + con = c1.connection + + assert not con.closed + c1.close() + assert con.closed + + def test_threadfairy(self): + p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True) + c1 = p.connect() + c1.close() + c2 = p.connect() + assert c2.connection is not None class SingletonThreadPoolTest(PoolTestBase): def test_cleanup(self): diff --git a/test/engine/test_reconnect.py b/test/engine/test_reconnect.py index 3a525c2a70..6afd715155 100644 --- a/test/engine/test_reconnect.py +++ b/test/engine/test_reconnect.py @@ -1,12 +1,13 @@ from sqlalchemy.test.testing import eq_ +import time import weakref from sqlalchemy import select, MetaData, Integer, String, pool from sqlalchemy.test.schema import Table from sqlalchemy.test.schema import Column import sqlalchemy as tsa from sqlalchemy.test import TestBase, testing, engines -import time -import gc +from sqlalchemy.test.util import gc_collect + class MockDisconnect(Exception): pass @@ -54,7 +55,7 @@ class MockReconnectTest(TestBase): dbapi = MockDBAPI() # create engine using our current dburi - db = tsa.create_engine('postgres://foo:bar@localhost/test', module=dbapi) + db = tsa.create_engine('postgresql://foo:bar@localhost/test', module=dbapi, _initialize=False) # monkeypatch disconnect checker db.dialect.is_disconnect = lambda e: isinstance(e, MockDisconnect) @@ -98,7 +99,7 @@ class MockReconnectTest(TestBase): assert id(db.pool) != pid # ensure all connections closed (pool was recycled) - gc.collect() + gc_collect() assert len(dbapi.connections) == 0 conn =db.connect() @@ -118,7 +119,7 @@ class MockReconnectTest(TestBase): pass # assert was invalidated - gc.collect() + gc_collect() assert len(dbapi.connections) == 0 assert not conn.closed assert conn.invalidated @@ -168,7 +169,7 @@ class MockReconnectTest(TestBase): assert conn.invalidated # ensure all connections closed (pool was recycled) - gc.collect() + gc_collect() assert len(dbapi.connections) == 0 # test reconnects @@ -334,7 +335,8 @@ class InvalidateDuringResultTest(TestBase): meta.drop_all() engine.dispose() - @testing.fails_on('mysql', 'FIXME: unknown') + @testing.fails_on('+mysqldb', "Buffers the result set and doesn't check for connection close") + @testing.fails_on('+pg8000', "Buffers the result set and doesn't check for connection close") def test_invalidate_on_results(self): conn = engine.connect() @@ -344,7 +346,7 @@ class InvalidateDuringResultTest(TestBase): engine.test_shutdown() try: - result.fetchone() + print "ghost result: %r" % result.fetchone() assert False except tsa.exc.DBAPIError, e: if not e.connection_invalidated: diff --git a/test/engine/test_reflection.py b/test/engine/test_reflection.py index ea80776a6a..dff9fa1bb6 100644 --- a/test/engine/test_reflection.py +++ b/test/engine/test_reflection.py @@ -1,17 +1,22 @@ from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message import StringIO, unicodedata import sqlalchemy as sa +from sqlalchemy import types as sql_types +from sqlalchemy import schema +from sqlalchemy.engine.reflection import Inspector from sqlalchemy import MetaData from sqlalchemy.test.schema import Table from sqlalchemy.test.schema import Column import sqlalchemy as tsa from sqlalchemy.test import TestBase, ComparesTables, testing, engines +create_inspector = Inspector.from_engine metadata, users = None, None class ReflectionTest(TestBase, ComparesTables): + @testing.exclude('mssql', '<', (10, 0, 0), 'Date is only supported on MSSQL 2008+') @testing.exclude('mysql', '<', (4, 1, 1), 'early types are squirrely') def test_basic_reflection(self): meta = MetaData(testing.db) @@ -22,16 +27,16 @@ class ReflectionTest(TestBase, ComparesTables): Column('test1', sa.CHAR(5), nullable=False), Column('test2', sa.Float(5), nullable=False), Column('test3', sa.Text), - Column('test4', sa.Numeric, nullable = False), - Column('test5', sa.DateTime), + Column('test4', sa.Numeric(10, 2), nullable = False), + Column('test5', sa.Date), Column('parent_user_id', sa.Integer, sa.ForeignKey('engine_users.user_id')), - Column('test6', sa.DateTime, nullable=False), + Column('test6', sa.Date, nullable=False), Column('test7', sa.Text), Column('test8', sa.Binary), Column('test_passivedefault2', sa.Integer, server_default='5'), Column('test9', sa.Binary(100)), - Column('test_numeric', sa.Numeric()), + Column('test10', sa.Numeric(10, 2)), test_needs_fk=True, ) @@ -52,9 +57,35 @@ class ReflectionTest(TestBase, ComparesTables): self.assert_tables_equal(users, reflected_users) self.assert_tables_equal(addresses, reflected_addresses) finally: - addresses.drop() - users.drop() - + meta.drop_all() + + def test_two_foreign_keys(self): + meta = MetaData(testing.db) + t1 = Table('t1', meta, + Column('id', sa.Integer, primary_key=True), + Column('t2id', sa.Integer, sa.ForeignKey('t2.id')), + Column('t3id', sa.Integer, sa.ForeignKey('t3.id')), + test_needs_fk=True + ) + t2 = Table('t2', meta, + Column('id', sa.Integer, primary_key=True), + test_needs_fk=True + ) + t3 = Table('t3', meta, + Column('id', sa.Integer, primary_key=True), + test_needs_fk=True + ) + meta.create_all() + try: + meta2 = MetaData() + t1r, t2r, t3r = [Table(x, meta2, autoload=True, autoload_with=testing.db) for x in ('t1', 't2', 't3')] + + assert t1r.c.t2id.references(t2r.c.id) + assert t1r.c.t3id.references(t3r.c.id) + + finally: + meta.drop_all() + def test_include_columns(self): meta = MetaData(testing.db) foo = Table('foo', meta, *[Column(n, sa.String(30)) @@ -84,26 +115,68 @@ class ReflectionTest(TestBase, ComparesTables): finally: meta.drop_all() + @testing.emits_warning(r".*omitted columns") + def test_include_columns_indexes(self): + m = MetaData(testing.db) + + t1 = Table('t1', m, Column('a', sa.Integer), Column('b', sa.Integer)) + sa.Index('foobar', t1.c.a, t1.c.b) + sa.Index('bat', t1.c.a) + m.create_all() + try: + m2 = MetaData(testing.db) + t2 = Table('t1', m2, autoload=True) + assert len(t2.indexes) == 2 + m2 = MetaData(testing.db) + t2 = Table('t1', m2, autoload=True, include_columns=['a']) + assert len(t2.indexes) == 1 + + m2 = MetaData(testing.db) + t2 = Table('t1', m2, autoload=True, include_columns=['a', 'b']) + assert len(t2.indexes) == 2 + finally: + m.drop_all() + + def test_autoincrement_col(self): + """test that 'autoincrement' is reflected according to sqla's policy. + + Don't mark this test as unsupported for any backend ! + + (technically it fails with MySQL InnoDB since "id" comes before "id2") + + """ + + meta = MetaData(testing.db) + t1 = Table('test', meta, + Column('id', sa.Integer, primary_key=True), + Column('data', sa.String(50)), + ) + t2 = Table('test2', meta, + Column('id', sa.Integer, sa.ForeignKey('test.id'), primary_key=True), + Column('id2', sa.Integer, primary_key=True), + Column('data', sa.String(50)), + ) + meta.create_all() + try: + m2 = MetaData(testing.db) + t1a = Table('test', m2, autoload=True) + assert t1a._autoincrement_column is t1a.c.id + + t2a = Table('test2', m2, autoload=True) + assert t2a._autoincrement_column is t2a.c.id2 + + finally: + meta.drop_all() + def test_unknown_types(self): meta = MetaData(testing.db) t = Table("test", meta, Column('foo', sa.DateTime)) - import sys - dialect_module = sys.modules[testing.db.dialect.__module__] - - # we're relying on the presence of "ischema_names" in the - # dialect module, else we can't test this. we need to be able - # to get the dialect to not be aware of some type so we temporarily - # monkeypatch. not sure what a better way for this could be, - # except for an established dialect hook or dialect-specific tests - if not hasattr(dialect_module, 'ischema_names'): - return - - ischema_names = dialect_module.ischema_names + ischema_names = testing.db.dialect.ischema_names t.create() - dialect_module.ischema_names = {} + testing.db.dialect.ischema_names = {} try: m2 = MetaData(testing.db) assert_raises(tsa.exc.SAWarning, Table, "test", m2, autoload=True) @@ -115,7 +188,7 @@ class ReflectionTest(TestBase, ComparesTables): assert t3.c.foo.type.__class__ == sa.types.NullType finally: - dialect_module.ischema_names = ischema_names + testing.db.dialect.ischema_names = ischema_names t.drop() def test_basic_override(self): @@ -578,7 +651,6 @@ class ReflectionTest(TestBase, ComparesTables): m9.reflect() self.assert_(not m9.tables) - @testing.fails_on_everything_except('postgres', 'mysql') def test_index_reflection(self): m1 = MetaData(testing.db) t1 = Table('party', m1, @@ -698,7 +770,7 @@ class UnicodeReflectionTest(TestBase): def test_basic(self): try: # the 'convert_unicode' should not get in the way of the reflection - # process. reflecttable for oracle, postgres (others?) expect non-unicode + # process. reflecttable for oracle, postgresql (others?) expect non-unicode # strings in result sets/bind params bind = engines.utf8_engine(options={'convert_unicode':True}) metadata = MetaData(bind) @@ -713,7 +785,8 @@ class UnicodeReflectionTest(TestBase): metadata.create_all() reflected = set(bind.table_names()) - if not names.issubset(reflected): + # Jython 2.5 on Java 5 lacks unicodedata.normalize + if not names.issubset(reflected) and hasattr(unicodedata, 'normalize'): # Python source files in the utf-8 coding seem to normalize # literals as NFC (and the above are explicitly NFC). Maybe # this database normalizes NFD on reflection. @@ -741,23 +814,15 @@ class SchemaTest(TestBase): Column('col1', sa.Integer, primary_key=True), Column('col2', sa.Integer, sa.ForeignKey('someschema.table1.col1')), schema='someschema') - # ensure this doesnt crash - print [t for t in metadata.sorted_tables] - buf = StringIO.StringIO() - def foo(s, p=None): - buf.write(s) - gen = sa.create_engine(testing.db.name + "://", strategy="mock", executor=foo) - gen = gen.dialect.schemagenerator(gen.dialect, gen) - gen.traverse(table1) - gen.traverse(table2) - buf = buf.getvalue() - print buf + + t1 = str(schema.CreateTable(table1).compile(bind=testing.db)) + t2 = str(schema.CreateTable(table2).compile(bind=testing.db)) if testing.db.dialect.preparer(testing.db.dialect).omit_schema: - assert buf.index("CREATE TABLE table1") > -1 - assert buf.index("CREATE TABLE table2") > -1 + assert t1.index("CREATE TABLE table1") > -1 + assert t2.index("CREATE TABLE table2") > -1 else: - assert buf.index("CREATE TABLE someschema.table1") > -1 - assert buf.index("CREATE TABLE someschema.table2") > -1 + assert t1.index("CREATE TABLE someschema.table1") > -1 + assert t2.index("CREATE TABLE someschema.table2") > -1 @testing.crashes('firebird', 'No schema support') @testing.fails_on('sqlite', 'FIXME: unknown') @@ -767,9 +832,9 @@ class SchemaTest(TestBase): def test_explicit_default_schema(self): engine = testing.db - if testing.against('mysql'): + if testing.against('mysql+mysqldb'): schema = testing.db.url.database - elif testing.against('postgres'): + elif testing.against('postgresql'): schema = 'public' elif testing.against('sqlite'): # Works for CREATE TABLE main.foo, SELECT FROM main.foo, etc., @@ -820,4 +885,324 @@ class HasSequenceTest(TestBase): metadata.drop_all(bind=testing.db) eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'), False) +# Tests related to engine.reflection + +def get_schema(): + if testing.against('oracle'): + return 'scott' + return 'test_schema' + +def createTables(meta, schema=None): + if schema: + parent_user_id = Column('parent_user_id', sa.Integer, + sa.ForeignKey('%s.users.user_id' % schema) + ) + else: + parent_user_id = Column('parent_user_id', sa.Integer, + sa.ForeignKey('users.user_id') + ) + + users = Table('users', meta, + Column('user_id', sa.INT, primary_key=True), + Column('user_name', sa.VARCHAR(20), nullable=False), + Column('test1', sa.CHAR(5), nullable=False), + Column('test2', sa.Float(5), nullable=False), + Column('test3', sa.Text), + Column('test4', sa.Numeric(10, 2), nullable = False), + Column('test5', sa.DateTime), + Column('test5-1', sa.TIMESTAMP), + parent_user_id, + Column('test6', sa.DateTime, nullable=False), + Column('test7', sa.Text), + Column('test8', sa.Binary), + Column('test_passivedefault2', sa.Integer, server_default='5'), + Column('test9', sa.Binary(100)), + Column('test10', sa.Numeric(10, 2)), + schema=schema, + test_needs_fk=True, + ) + addresses = Table('email_addresses', meta, + Column('address_id', sa.Integer, primary_key = True), + Column('remote_user_id', sa.Integer, + sa.ForeignKey(users.c.user_id)), + Column('email_address', sa.String(20)), + schema=schema, + test_needs_fk=True, + ) + return (users, addresses) + +def createIndexes(con, schema=None): + fullname = 'users' + if schema: + fullname = "%s.%s" % (schema, 'users') + query = "CREATE INDEX users_t_idx ON %s (test1, test2)" % fullname + con.execute(sa.sql.text(query)) + +def createViews(con, schema=None): + for table_name in ('users', 'email_addresses'): + fullname = table_name + if schema: + fullname = "%s.%s" % (schema, table_name) + view_name = fullname + '_v' + query = "CREATE VIEW %s AS SELECT * FROM %s" % (view_name, + fullname) + con.execute(sa.sql.text(query)) + +def dropViews(con, schema=None): + for table_name in ('email_addresses', 'users'): + fullname = table_name + if schema: + fullname = "%s.%s" % (schema, table_name) + view_name = fullname + '_v' + query = "DROP VIEW %s" % view_name + con.execute(sa.sql.text(query)) + + +class ComponentReflectionTest(TestBase): + + @testing.requires.schemas + def test_get_schema_names(self): + meta = MetaData(testing.db) + insp = Inspector(meta.bind) + + self.assert_(get_schema() in insp.get_schema_names()) + + def _test_get_table_names(self, schema=None, table_type='table', + order_by=None): + meta = MetaData(testing.db) + (users, addresses) = createTables(meta, schema) + meta.create_all() + createViews(meta.bind, schema) + try: + insp = Inspector(meta.bind) + if table_type == 'view': + table_names = insp.get_view_names(schema) + table_names.sort() + answer = ['email_addresses_v', 'users_v'] + else: + table_names = insp.get_table_names(schema, + order_by=order_by) + table_names.sort() + if order_by == 'foreign_key': + answer = ['users', 'email_addresses'] + else: + answer = ['email_addresses', 'users'] + eq_(table_names, answer) + finally: + dropViews(meta.bind, schema) + addresses.drop() + users.drop() + + def test_get_table_names(self): + self._test_get_table_names() + + @testing.requires.schemas + def test_get_table_names_with_schema(self): + self._test_get_table_names(get_schema()) + + def test_get_view_names(self): + self._test_get_table_names(table_type='view') + + @testing.requires.schemas + def test_get_view_names_with_schema(self): + self._test_get_table_names(get_schema(), table_type='view') + + def _test_get_columns(self, schema=None, table_type='table'): + meta = MetaData(testing.db) + (users, addresses) = createTables(meta, schema) + table_names = ['users', 'email_addresses'] + meta.create_all() + if table_type == 'view': + createViews(meta.bind, schema) + table_names = ['users_v', 'email_addresses_v'] + try: + insp = Inspector(meta.bind) + for (table_name, table) in zip(table_names, (users, addresses)): + schema_name = schema + cols = insp.get_columns(table_name, schema=schema_name) + self.assert_(len(cols) > 0, len(cols)) + # should be in order + for (i, col) in enumerate(table.columns): + eq_(col.name, cols[i]['name']) + ctype = cols[i]['type'].__class__ + ctype_def = col.type + if isinstance(ctype_def, sa.types.TypeEngine): + ctype_def = ctype_def.__class__ + + # Oracle returns Date for DateTime. + if testing.against('oracle') \ + and ctype_def in (sql_types.Date, sql_types.DateTime): + ctype_def = sql_types.Date + + # assert that the desired type and return type + # share a base within one of the generic types. + self.assert_( + len( + set( + ctype.__mro__ + ).intersection(ctype_def.__mro__) + .intersection([sql_types.Integer, sql_types.Numeric, + sql_types.DateTime, sql_types.Date, sql_types.Time, + sql_types.String, sql_types.Binary]) + ) > 0 + ,("%s(%s), %s(%s)" % (col.name, col.type, cols[i]['name'], + ctype))) + finally: + if table_type == 'view': + dropViews(meta.bind, schema) + addresses.drop() + users.drop() + + def test_get_columns(self): + self._test_get_columns() + + @testing.requires.schemas + def test_get_columns_with_schema(self): + self._test_get_columns(schema=get_schema()) + + def test_get_view_columns(self): + self._test_get_columns(table_type='view') + + @testing.requires.schemas + def test_get_view_columns_with_schema(self): + self._test_get_columns(schema=get_schema(), table_type='view') + + def _test_get_primary_keys(self, schema=None): + meta = MetaData(testing.db) + (users, addresses) = createTables(meta, schema) + meta.create_all() + insp = Inspector(meta.bind) + try: + users_pkeys = insp.get_primary_keys(users.name, + schema=schema) + eq_(users_pkeys, ['user_id']) + addr_pkeys = insp.get_primary_keys(addresses.name, + schema=schema) + eq_(addr_pkeys, ['address_id']) + + finally: + addresses.drop() + users.drop() + + def test_get_primary_keys(self): + self._test_get_primary_keys() + + @testing.fails_on('sqlite', 'no schemas') + def test_get_primary_keys_with_schema(self): + self._test_get_primary_keys(schema=get_schema()) + + def _test_get_foreign_keys(self, schema=None): + meta = MetaData(testing.db) + (users, addresses) = createTables(meta, schema) + meta.create_all() + insp = Inspector(meta.bind) + try: + expected_schema = schema + # users + users_fkeys = insp.get_foreign_keys(users.name, + schema=schema) + fkey1 = users_fkeys[0] + self.assert_(fkey1['name'] is not None) + eq_(fkey1['referred_schema'], expected_schema) + eq_(fkey1['referred_table'], users.name) + eq_(fkey1['referred_columns'], ['user_id', ]) + eq_(fkey1['constrained_columns'], ['parent_user_id']) + #addresses + addr_fkeys = insp.get_foreign_keys(addresses.name, + schema=schema) + fkey1 = addr_fkeys[0] + self.assert_(fkey1['name'] is not None) + eq_(fkey1['referred_schema'], expected_schema) + eq_(fkey1['referred_table'], users.name) + eq_(fkey1['referred_columns'], ['user_id', ]) + eq_(fkey1['constrained_columns'], ['remote_user_id']) + finally: + addresses.drop() + users.drop() + + def test_get_foreign_keys(self): + self._test_get_foreign_keys() + + @testing.requires.schemas + def test_get_foreign_keys_with_schema(self): + self._test_get_foreign_keys(schema=get_schema()) + + def _test_get_indexes(self, schema=None): + meta = MetaData(testing.db) + (users, addresses) = createTables(meta, schema) + meta.create_all() + createIndexes(meta.bind, schema) + try: + # The database may decide to create indexes for foreign keys, etc. + # so there may be more indexes than expected. + insp = Inspector(meta.bind) + indexes = insp.get_indexes('users', schema=schema) + indexes.sort() + expected_indexes = [ + {'unique': False, + 'column_names': ['test1', 'test2'], + 'name': 'users_t_idx'}] + index_names = [d['name'] for d in indexes] + for e_index in expected_indexes: + assert e_index['name'] in index_names + index = indexes[index_names.index(e_index['name'])] + for key in e_index: + eq_(e_index[key], index[key]) + + finally: + addresses.drop() + users.drop() + + def test_get_indexes(self): + self._test_get_indexes() + + @testing.requires.schemas + def test_get_indexes_with_schema(self): + self._test_get_indexes(schema=get_schema()) + + def _test_get_view_definition(self, schema=None): + meta = MetaData(testing.db) + (users, addresses) = createTables(meta, schema) + meta.create_all() + createViews(meta.bind, schema) + view_name1 = 'users_v' + view_name2 = 'email_addresses_v' + try: + insp = Inspector(meta.bind) + v1 = insp.get_view_definition(view_name1, schema=schema) + self.assert_(v1) + v2 = insp.get_view_definition(view_name2, schema=schema) + self.assert_(v2) + finally: + dropViews(meta.bind, schema) + addresses.drop() + users.drop() + + def test_get_view_definition(self): + self._test_get_view_definition() + + @testing.requires.schemas + def test_get_view_definition_with_schema(self): + self._test_get_view_definition(schema=get_schema()) + + def _test_get_table_oid(self, table_name, schema=None): + if testing.against('postgresql'): + meta = MetaData(testing.db) + (users, addresses) = createTables(meta, schema) + meta.create_all() + try: + insp = create_inspector(meta.bind) + oid = insp.get_table_oid(table_name, schema) + self.assert_(isinstance(oid, (int, long))) + finally: + addresses.drop() + users.drop() + + def test_get_table_oid(self): + self._test_get_table_oid('users') + + @testing.requires.schemas + def test_get_table_oid_with_schema(self): + self._test_get_table_oid('users', schema=get_schema()) + diff --git a/test/engine/test_transaction.py b/test/engine/test_transaction.py index 6698259a45..8e3f3412d6 100644 --- a/test/engine/test_transaction.py +++ b/test/engine/test_transaction.py @@ -20,7 +20,8 @@ class TransactionTest(TestBase): users.create(testing.db) def teardown(self): - testing.db.connect().execute(users.delete()) + testing.db.execute(users.delete()).close() + @classmethod def teardown_class(cls): users.drop(testing.db) @@ -40,6 +41,7 @@ class TransactionTest(TestBase): result = connection.execute("select * from query_users") assert len(result.fetchall()) == 3 transaction.commit() + connection.close() def test_rollback(self): """test a basic rollback""" @@ -176,6 +178,7 @@ class TransactionTest(TestBase): connection.close() @testing.requires.savepoints + @testing.crashes('oracle+zxjdbc', 'Errors out and causes subsequent tests to deadlock') def test_nested_subtransaction_commit(self): connection = testing.db.connect() transaction = connection.begin() @@ -274,6 +277,7 @@ class TransactionTest(TestBase): connection.close() @testing.requires.two_phase_transactions + @testing.crashes('mysql+zxjdbc', 'Deadlocks, causing subsequent tests to fail') @testing.fails_on('mysql', 'FIXME: unknown') def test_two_phase_recover(self): # MySQL recovery doesn't currently seem to work correctly @@ -369,7 +373,7 @@ class ExplicitAutoCommitTest(TestBase): Requires PostgreSQL so that we may define a custom function which modifies the database. """ - __only_on__ = 'postgres' + __only_on__ = 'postgresql' @classmethod def setup_class(cls): @@ -380,7 +384,7 @@ class ExplicitAutoCommitTest(TestBase): testing.db.execute("create function insert_foo(varchar) returns integer as 'insert into foo(data) values ($1);select 1;' language sql") def teardown(self): - foo.delete().execute() + foo.delete().execute().close() @classmethod def teardown_class(cls): @@ -453,8 +457,10 @@ class TLTransactionTest(TestBase): test_needs_acid=True, ) users.create(tlengine) + def teardown(self): - tlengine.execute(users.delete()) + tlengine.execute(users.delete()).close() + @classmethod def teardown_class(cls): users.drop(tlengine) @@ -497,6 +503,7 @@ class TLTransactionTest(TestBase): try: assert len(result.fetchall()) == 0 finally: + c.close() external_connection.close() def test_rollback(self): @@ -530,7 +537,9 @@ class TLTransactionTest(TestBase): external_connection.close() def test_commits(self): - assert tlengine.connect().execute("select count(1) from query_users").scalar() == 0 + connection = tlengine.connect() + assert connection.execute("select count(1) from query_users").scalar() == 0 + connection.close() connection = tlengine.contextual_connect() transaction = connection.begin() @@ -547,6 +556,7 @@ class TLTransactionTest(TestBase): l = result.fetchall() assert len(l) == 3, "expected 3 got %d" % len(l) transaction.commit() + connection.close() def test_rollback_off_conn(self): # test that a TLTransaction opened off a TLConnection allows that @@ -563,6 +573,7 @@ class TLTransactionTest(TestBase): try: assert len(result.fetchall()) == 0 finally: + conn.close() external_connection.close() def test_morerollback_off_conn(self): @@ -581,6 +592,8 @@ class TLTransactionTest(TestBase): try: assert len(result.fetchall()) == 0 finally: + conn.close() + conn2.close() external_connection.close() def test_commit_off_connection(self): @@ -596,6 +609,7 @@ class TLTransactionTest(TestBase): try: assert len(result.fetchall()) == 3 finally: + conn.close() external_connection.close() def test_nesting(self): @@ -712,8 +726,10 @@ class ForUpdateTest(TestBase): test_needs_acid=True, ) counters.create(testing.db) + def teardown(self): - testing.db.connect().execute(counters.delete()) + testing.db.execute(counters.delete()).close() + @classmethod def teardown_class(cls): counters.drop(testing.db) @@ -726,7 +742,7 @@ class ForUpdateTest(TestBase): for i in xrange(count): trans = con.begin() try: - existing = con.execute(sel).fetchone() + existing = con.execute(sel).first() incr = existing['counter_value'] + 1 time.sleep(delay) @@ -734,7 +750,7 @@ class ForUpdateTest(TestBase): values={'counter_value':incr})) time.sleep(delay) - readback = con.execute(sel).fetchone() + readback = con.execute(sel).first() if (readback['counter_value'] != incr): raise AssertionError("Got %s post-update, expected %s" % (readback['counter_value'], incr)) @@ -778,7 +794,7 @@ class ForUpdateTest(TestBase): self.assert_(len(errors) == 0) sel = counters.select(whereclause=counters.c.counter_id==1) - final = db.execute(sel).fetchone() + final = db.execute(sel).first() self.assert_(final['counter_value'] == iterations * thread_count) def overlap(self, ids, errors, update_style): diff --git a/test/ext/test_associationproxy.py b/test/ext/test_associationproxy.py index 8df449718e..4a5775218d 100644 --- a/test/ext/test_associationproxy.py +++ b/test/ext/test_associationproxy.py @@ -1,6 +1,5 @@ from sqlalchemy.test.testing import eq_, assert_raises import copy -import gc import pickle from sqlalchemy import * @@ -8,6 +7,7 @@ from sqlalchemy.orm import * from sqlalchemy.orm.collections import collection from sqlalchemy.ext.associationproxy import * from sqlalchemy.test import * +from sqlalchemy.test.util import gc_collect class DictCollection(dict): @@ -880,7 +880,7 @@ class ReconstitutionTest(TestBase): add_child('p1', 'c1') - gc.collect() + gc_collect() add_child('p1', 'c2') session.flush() @@ -895,7 +895,7 @@ class ReconstitutionTest(TestBase): p.kids.extend(['c1', 'c2']) p_copy = copy.copy(p) del p - gc.collect() + gc_collect() assert set(p_copy.kids) == set(['c1', 'c2']), p.kids diff --git a/test/ext/test_compiler.py b/test/ext/test_compiler.py index ce25490998..3ee94d0271 100644 --- a/test/ext/test_compiler.py +++ b/test/ext/test_compiler.py @@ -1,5 +1,7 @@ from sqlalchemy import * +from sqlalchemy.types import TypeEngine from sqlalchemy.sql.expression import ClauseElement, ColumnClause +from sqlalchemy.schema import DDLElement from sqlalchemy.ext.compiler import compiles from sqlalchemy.sql import table, column from sqlalchemy.test import * @@ -25,7 +27,35 @@ class UserDefinedTest(TestBase, AssertsCompiledSQL): select([MyThingy('x'), MyThingy('y')]).where(MyThingy() == 5), "SELECT >>x<<, >>y<< WHERE >>MYTHINGY!<< = :MYTHINGY!_1" ) + + def test_types(self): + class MyType(TypeEngine): + pass + + @compiles(MyType, 'sqlite') + def visit_type(type, compiler, **kw): + return "SQLITE_FOO" + + @compiles(MyType, 'postgresql') + def visit_type(type, compiler, **kw): + return "POSTGRES_FOO" + + from sqlalchemy.dialects.sqlite import base as sqlite + from sqlalchemy.dialects.postgresql import base as postgresql + self.assert_compile( + MyType(), + "SQLITE_FOO", + dialect=sqlite.dialect() + ) + + self.assert_compile( + MyType(), + "POSTGRES_FOO", + dialect=postgresql.dialect() + ) + + def test_stateful(self): class MyThingy(ColumnClause): def __init__(self): @@ -71,10 +101,10 @@ class UserDefinedTest(TestBase, AssertsCompiledSQL): ) def test_dialect_specific(self): - class AddThingy(ClauseElement): + class AddThingy(DDLElement): __visit_name__ = 'add_thingy' - class DropThingy(ClauseElement): + class DropThingy(DDLElement): __visit_name__ = 'drop_thingy' @compiles(AddThingy, 'sqlite') @@ -97,7 +127,7 @@ class UserDefinedTest(TestBase, AssertsCompiledSQL): "DROP THINGY" ) - from sqlalchemy.databases import sqlite as base + from sqlalchemy.dialects.sqlite import base self.assert_compile(AddThingy(), "ADD SPECIAL SL THINGY", dialect=base.dialect() diff --git a/test/ext/test_declarative.py b/test/ext/test_declarative.py index 224f41731a..745e3b7cf8 100644 --- a/test/ext/test_declarative.py +++ b/test/ext/test_declarative.py @@ -5,8 +5,7 @@ from sqlalchemy import exc import sqlalchemy as sa from sqlalchemy.test import testing from sqlalchemy import MetaData, Integer, String, ForeignKey, ForeignKeyConstraint, asc, Index -from sqlalchemy.test.schema import Table -from sqlalchemy.test.schema import Column +from sqlalchemy.test.schema import Table, Column from sqlalchemy.orm import relation, create_session, class_mapper, eagerload, compile_mappers, backref, clear_mappers, polymorphic_union, deferred from sqlalchemy.test.testing import eq_ @@ -27,14 +26,14 @@ class DeclarativeTest(DeclarativeTestBase): class User(Base, ComparableEntity): __tablename__ = 'users' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) name = Column('name', String(50)) addresses = relation("Address", backref="user") class Address(Base, ComparableEntity): __tablename__ = 'addresses' - id = Column(Integer, primary_key=True) + id = Column(Integer, primary_key=True, test_needs_autoincrement=True) email = Column(String(50), key='_email') user_id = Column('user_id', Integer, ForeignKey('users.id'), key='_user_id') @@ -127,7 +126,7 @@ class DeclarativeTest(DeclarativeTestBase): class User(Base, ComparableEntity): __tablename__ = 'users' - id = Column(Integer, primary_key=True) + id = Column(Integer, primary_key=True, test_needs_autoincrement=True) name = Column(String(50)) addresses = relation("Address", order_by="desc(Address.email)", primaryjoin="User.id==Address.user_id", foreign_keys="[Address.user_id]", @@ -136,7 +135,7 @@ class DeclarativeTest(DeclarativeTestBase): class Address(Base, ComparableEntity): __tablename__ = 'addresses' - id = Column(Integer, primary_key=True) + id = Column(Integer, primary_key=True, test_needs_autoincrement=True) email = Column(String(50)) user_id = Column(Integer) # note no foreign key @@ -180,13 +179,13 @@ class DeclarativeTest(DeclarativeTestBase): def test_uncompiled_attributes_in_relation(self): class Address(Base, ComparableEntity): __tablename__ = 'addresses' - id = Column(Integer, primary_key=True) + id = Column(Integer, primary_key=True, test_needs_autoincrement=True) email = Column(String(50)) user_id = Column(Integer, ForeignKey('users.id')) class User(Base, ComparableEntity): __tablename__ = 'users' - id = Column(Integer, primary_key=True) + id = Column(Integer, primary_key=True, test_needs_autoincrement=True) name = Column(String(50)) addresses = relation("Address", order_by=Address.email, foreign_keys=Address.user_id, @@ -272,14 +271,14 @@ class DeclarativeTest(DeclarativeTestBase): class User(Base, ComparableEntity): __tablename__ = 'users' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) User.name = Column('name', String(50)) User.addresses = relation("Address", backref="user") class Address(Base, ComparableEntity): __tablename__ = 'addresses' - id = Column(Integer, primary_key=True) + id = Column(Integer, primary_key=True, test_needs_autoincrement=True) Address.email = Column(String(50), key='_email') Address.user_id = Column('user_id', Integer, ForeignKey('users.id'), key='_user_id') @@ -312,14 +311,14 @@ class DeclarativeTest(DeclarativeTestBase): class Address(Base, ComparableEntity): __tablename__ = 'addresses' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) email = Column('email', String(50)) user_id = Column('user_id', Integer, ForeignKey('users.id')) class User(Base, ComparableEntity): __tablename__ = 'users' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) name = Column('name', String(50)) addresses = relation("Address", order_by=Address.email) @@ -341,14 +340,14 @@ class DeclarativeTest(DeclarativeTestBase): class Address(Base, ComparableEntity): __tablename__ = 'addresses' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) email = Column('email', String(50)) user_id = Column('user_id', Integer, ForeignKey('users.id')) class User(Base, ComparableEntity): __tablename__ = 'users' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) name = Column('name', String(50)) addresses = relation("Address", order_by=(Address.email, Address.id)) @@ -368,14 +367,14 @@ class DeclarativeTest(DeclarativeTestBase): class User(ComparableEntity): __tablename__ = 'users' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) name = Column('name', String(50)) addresses = relation("Address", backref="user") class Address(ComparableEntity): __tablename__ = 'addresses' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) email = Column('email', String(50)) user_id = Column('user_id', Integer, ForeignKey('users.id')) @@ -478,14 +477,14 @@ class DeclarativeTest(DeclarativeTestBase): class User(Base, ComparableEntity): __tablename__ = 'users' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) name = Column('name', String(50)) addresses = relation("Address", backref="user") class Address(Base, ComparableEntity): __tablename__ = 'addresses' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) email = Column('email', String(50)) user_id = Column('user_id', Integer, ForeignKey('users.id')) @@ -513,7 +512,7 @@ class DeclarativeTest(DeclarativeTestBase): class User(Base, ComparableEntity): __tablename__ = 'users' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) name = Column('name', String(50)) User.a = Column('a', String(10)) @@ -535,14 +534,14 @@ class DeclarativeTest(DeclarativeTestBase): def test_column_properties(self): class Address(Base, ComparableEntity): __tablename__ = 'addresses' - id = Column(Integer, primary_key=True) + id = Column(Integer, primary_key=True, test_needs_autoincrement=True) email = Column(String(50)) user_id = Column(Integer, ForeignKey('users.id')) class User(Base, ComparableEntity): __tablename__ = 'users' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) name = Column('name', String(50)) adr_count = sa.orm.column_property( sa.select([sa.func.count(Address.id)], Address.user_id == id). @@ -588,7 +587,7 @@ class DeclarativeTest(DeclarativeTestBase): class User(Base, ComparableEntity): __tablename__ = 'users' - id = Column(Integer, primary_key=True) + id = Column(Integer, primary_key=True, test_needs_autoincrement=True) name = sa.orm.deferred(Column(String(50))) Base.metadata.create_all() @@ -607,7 +606,7 @@ class DeclarativeTest(DeclarativeTestBase): class User(Base, ComparableEntity): __tablename__ = 'users' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) _name = Column('name', String(50)) def _set_name(self, name): self._name = "SOMENAME " + name @@ -636,7 +635,7 @@ class DeclarativeTest(DeclarativeTestBase): class User(Base, ComparableEntity): __tablename__ = 'users' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) _name = Column('name', String(50)) name = sa.orm.synonym('_name', comparator_factory=CustomCompare) @@ -652,7 +651,7 @@ class DeclarativeTest(DeclarativeTestBase): class User(Base, ComparableEntity): __tablename__ = 'users' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) _name = Column('name', String(50)) def _set_name(self, name): self._name = "SOMENAME " + name @@ -674,14 +673,14 @@ class DeclarativeTest(DeclarativeTestBase): class User(Base, ComparableEntity): __tablename__ = 'users' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) name = Column('name', String(50)) addresses = relation("Address", backref="user") class Address(Base, ComparableEntity): __tablename__ = 'addresses' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) email = Column('email', String(50)) user_id = Column('user_id', Integer, ForeignKey(User.id)) @@ -711,14 +710,14 @@ class DeclarativeTest(DeclarativeTestBase): class Address(Base, ComparableEntity): __tablename__ = 'addresses' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) email = Column('email', String(50)) user_id = Column('user_id', Integer, ForeignKey('users.id')) class User(Base, ComparableEntity): __tablename__ = 'users' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) name = Column('name', String(50)) addresses = relation("Address", backref="user", primaryjoin=id == Address.user_id) @@ -763,7 +762,7 @@ class DeclarativeTest(DeclarativeTestBase): def test_with_explicit_autoloaded(self): meta = MetaData(testing.db) t1 = Table('t1', meta, - Column('id', String(50), primary_key=True), + Column('id', String(50), primary_key=True, test_needs_autoincrement=True), Column('data', String(50))) meta.create_all() try: @@ -779,6 +778,70 @@ class DeclarativeTest(DeclarativeTestBase): finally: meta.drop_all() + def test_synonym_for(self): + class User(Base, ComparableEntity): + __tablename__ = 'users' + + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) + name = Column('name', String(50)) + + @decl.synonym_for('name') + @property + def namesyn(self): + return self.name + + Base.metadata.create_all() + + sess = create_session() + u1 = User(name='someuser') + eq_(u1.name, "someuser") + eq_(u1.namesyn, 'someuser') + sess.add(u1) + sess.flush() + + rt = sess.query(User).filter(User.namesyn == 'someuser').one() + eq_(rt, u1) + + def test_comparable_using(self): + class NameComparator(sa.orm.PropComparator): + @property + def upperself(self): + cls = self.prop.parent.class_ + col = getattr(cls, 'name') + return sa.func.upper(col) + + def operate(self, op, other, **kw): + return op(self.upperself, other, **kw) + + class User(Base, ComparableEntity): + __tablename__ = 'users' + + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) + name = Column('name', String(50)) + + @decl.comparable_using(NameComparator) + @property + def uc_name(self): + return self.name is not None and self.name.upper() or None + + Base.metadata.create_all() + + sess = create_session() + u1 = User(name='someuser') + eq_(u1.name, "someuser", u1.name) + eq_(u1.uc_name, 'SOMEUSER', u1.uc_name) + sess.add(u1) + sess.flush() + sess.expunge_all() + + rt = sess.query(User).filter(User.uc_name == 'SOMEUSER').one() + eq_(rt, u1) + sess.expunge_all() + + rt = sess.query(User).filter(User.uc_name.startswith('SOMEUSE')).one() + eq_(rt, u1) + + class DeclarativeInheritanceTest(DeclarativeTestBase): def test_custom_join_condition(self): class Foo(Base): @@ -797,13 +860,13 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): def test_joined(self): class Company(Base, ComparableEntity): __tablename__ = 'companies' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) name = Column('name', String(50)) employees = relation("Person") class Person(Base, ComparableEntity): __tablename__ = 'people' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) company_id = Column('company_id', Integer, ForeignKey('companies.id')) name = Column('name', String(50)) @@ -911,13 +974,13 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): class Company(Base, ComparableEntity): __tablename__ = 'companies' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) name = Column('name', String(50)) employees = relation("Person") class Person(Base, ComparableEntity): __tablename__ = 'people' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) company_id = Column('company_id', Integer, ForeignKey('companies.id')) name = Column('name', String(50)) @@ -967,13 +1030,13 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): class Company(Base, ComparableEntity): __tablename__ = 'companies' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) name = Column('name', String(50)) employees = relation("Person") class Person(Base, ComparableEntity): __tablename__ = 'people' - id = Column(Integer, primary_key=True) + id = Column(Integer, primary_key=True, test_needs_autoincrement=True) company_id = Column(Integer, ForeignKey('companies.id')) name = Column(String(50)) @@ -1037,13 +1100,13 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): def test_joined_from_single(self): class Company(Base, ComparableEntity): __tablename__ = 'companies' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) name = Column('name', String(50)) employees = relation("Person") class Person(Base, ComparableEntity): __tablename__ = 'people' - id = Column(Integer, primary_key=True) + id = Column(Integer, primary_key=True, test_needs_autoincrement=True) company_id = Column(Integer, ForeignKey('companies.id')) name = Column(String(50)) discriminator = Column('type', String(50)) @@ -1100,7 +1163,7 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): def test_add_deferred(self): class Person(Base, ComparableEntity): __tablename__ = 'people' - id = Column('id', Integer, primary_key=True) + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) Person.name = deferred(Column(String(10))) @@ -1117,6 +1180,8 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): Person(name='ratbert') ] ) + sess.expunge_all() + person = sess.query(Person).filter(Person.name == 'ratbert').one() assert 'name' not in person.__dict__ @@ -1127,7 +1192,7 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): class Person(Base, ComparableEntity): __tablename__ = 'people' - id = Column(Integer, primary_key=True) + id = Column(Integer, primary_key=True, test_needs_autoincrement=True) name = Column(String(50)) discriminator = Column('type', String(50)) __mapper_args__ = {'polymorphic_on':discriminator} @@ -1139,7 +1204,7 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): class Language(Base, ComparableEntity): __tablename__ = 'languages' - id = Column(Integer, primary_key=True) + id = Column(Integer, primary_key=True, test_needs_autoincrement=True) name = Column(String(50)) assert not hasattr(Person, 'primary_language_id') @@ -1236,12 +1301,12 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): def test_concrete(self): engineers = Table('engineers', Base.metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), Column('primary_language', String(50)) ) managers = Table('managers', Base.metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), Column('golf_swing', String(50)) ) @@ -1293,12 +1358,12 @@ def _produce_test(inline, stringbased): class User(Base, ComparableEntity): __tablename__ = 'users' - id = Column(Integer, primary_key=True) + id = Column(Integer, primary_key=True, test_needs_autoincrement=True) name = Column(String(50)) class Address(Base, ComparableEntity): __tablename__ = 'addresses' - id = Column(Integer, primary_key=True) + id = Column(Integer, primary_key=True, test_needs_autoincrement=True) email = Column(String(50)) user_id = Column(Integer, ForeignKey('users.id')) if inline: @@ -1363,16 +1428,16 @@ class DeclarativeReflectionTest(testing.TestBase): reflection_metadata = MetaData(testing.db) Table('users', reflection_metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), test_needs_fk=True) Table('addresses', reflection_metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('email', String(50)), Column('user_id', Integer, ForeignKey('users.id')), test_needs_fk=True) Table('imhandles', reflection_metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('user_id', Integer), Column('network', String(50)), Column('handle', String(50)), @@ -1398,12 +1463,17 @@ class DeclarativeReflectionTest(testing.TestBase): class User(Base, ComparableEntity): __tablename__ = 'users' __autoload__ = True + if testing.against('oracle', 'firebird'): + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) addresses = relation("Address", backref="user") class Address(Base, ComparableEntity): __tablename__ = 'addresses' __autoload__ = True + if testing.against('oracle', 'firebird'): + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) + u1 = User(name='u1', addresses=[ Address(email='one'), Address(email='two'), @@ -1428,12 +1498,16 @@ class DeclarativeReflectionTest(testing.TestBase): class User(Base, ComparableEntity): __tablename__ = 'users' __autoload__ = True + if testing.against('oracle', 'firebird'): + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) nom = Column('name', String(50), key='nom') addresses = relation("Address", backref="user") class Address(Base, ComparableEntity): __tablename__ = 'addresses' __autoload__ = True + if testing.against('oracle', 'firebird'): + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) u1 = User(nom='u1', addresses=[ Address(email='one'), @@ -1461,12 +1535,16 @@ class DeclarativeReflectionTest(testing.TestBase): class IMHandle(Base, ComparableEntity): __tablename__ = 'imhandles' __autoload__ = True + if testing.against('oracle', 'firebird'): + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) user_id = Column('user_id', Integer, ForeignKey('users.id')) class User(Base, ComparableEntity): __tablename__ = 'users' __autoload__ = True + if testing.against('oracle', 'firebird'): + id = Column('id', Integer, primary_key=True, test_needs_autoincrement=True) handles = relation("IMHandle", backref="user") u1 = User(name='u1', handles=[ @@ -1487,69 +1565,3 @@ class DeclarativeReflectionTest(testing.TestBase): eq_(a1, IMHandle(network='lol', handle='zomg')) eq_(a1.user, User(name='u1')) - def test_synonym_for(self): - class User(Base, ComparableEntity): - __tablename__ = 'users' - - id = Column('id', Integer, primary_key=True) - name = Column('name', String(50)) - - @decl.synonym_for('name') - @property - def namesyn(self): - return self.name - - Base.metadata.create_all() - - sess = create_session() - u1 = User(name='someuser') - eq_(u1.name, "someuser") - eq_(u1.namesyn, 'someuser') - sess.add(u1) - sess.flush() - - rt = sess.query(User).filter(User.namesyn == 'someuser').one() - eq_(rt, u1) - - def test_comparable_using(self): - class NameComparator(sa.orm.PropComparator): - @property - def upperself(self): - cls = self.prop.parent.class_ - col = getattr(cls, 'name') - return sa.func.upper(col) - - def operate(self, op, other, **kw): - return op(self.upperself, other, **kw) - - class User(Base, ComparableEntity): - __tablename__ = 'users' - - id = Column('id', Integer, primary_key=True) - name = Column('name', String(50)) - - @decl.comparable_using(NameComparator) - @property - def uc_name(self): - return self.name is not None and self.name.upper() or None - - Base.metadata.create_all() - - sess = create_session() - u1 = User(name='someuser') - eq_(u1.name, "someuser", u1.name) - eq_(u1.uc_name, 'SOMEUSER', u1.uc_name) - sess.add(u1) - sess.flush() - sess.expunge_all() - - rt = sess.query(User).filter(User.uc_name == 'SOMEUSER').one() - eq_(rt, u1) - sess.expunge_all() - - rt = sess.query(User).filter(User.uc_name.startswith('SOMEUSE')).one() - eq_(rt, u1) - - -if __name__ == '__main__': - testing.main() diff --git a/test/ext/test_serializer.py b/test/ext/test_serializer.py index b8a8e3fef9..c400797b0e 100644 --- a/test/ext/test_serializer.py +++ b/test/ext/test_serializer.py @@ -96,8 +96,6 @@ class SerializeTest(MappedTest): [(7, u'jack'), (8, u'ed'), (8, u'ed'), (8, u'ed'), (9, u'fred')] ) - # fails due to pure Python pickle bug: http://bugs.python.org/issue998998 - @testing.fails_if(lambda: util.py3k) def test_query(self): q = Session.query(User).filter(User.name=='ed').options(eagerload(User.addresses)) eq_(q.all(), [User(name='ed', addresses=[Address(id=2), Address(id=3), Address(id=4)])]) diff --git a/test/orm/_base.py b/test/orm/_base.py index 8d695e912b..f08d253d57 100644 --- a/test/orm/_base.py +++ b/test/orm/_base.py @@ -1,10 +1,11 @@ -import gc import inspect import sys import types import sqlalchemy as sa +import sqlalchemy.exceptions as sa_exc from sqlalchemy.test import config, testing from sqlalchemy.test.testing import resolve_artifact_names, adict +from sqlalchemy.test.engines import drop_all_tables from sqlalchemy.util import function_named @@ -74,20 +75,19 @@ class ComparableEntity(BasicEntity): if attr.startswith('_'): continue value = getattr(a, attr) - if (hasattr(value, '__iter__') and - not isinstance(value, basestring)): - try: - # catch AttributeError so that lazy loaders trigger - battr = getattr(b, attr) - except AttributeError: - return False + try: + # handle lazy loader errors + battr = getattr(b, attr) + except (AttributeError, sa_exc.UnboundExecutionError): + return False + + if hasattr(value, '__iter__'): if list(value) != list(battr): return False else: - if value is not None: - if value != getattr(b, attr, None): - return False + if value is not None and value != battr: + return False return True finally: _recursion_stack.remove(id(self)) @@ -173,7 +173,7 @@ class MappedTest(ORMTest): def setup(self): if self.run_define_tables == 'each': self.tables.clear() - self.metadata.drop_all() + drop_all_tables(self.metadata) self.metadata.clear() self.define_tables(self.metadata) self.metadata.create_all() @@ -217,7 +217,7 @@ class MappedTest(ORMTest): for cl in cls.classes.values(): cls.unregister_class(cl) ORMTest.teardown_class() - cls.metadata.drop_all() + drop_all_tables(cls.metadata) cls.metadata.bind = None @classmethod diff --git a/test/orm/_fixtures.py b/test/orm/_fixtures.py index 931d8cadf8..e9d6ac1656 100644 --- a/test/orm/_fixtures.py +++ b/test/orm/_fixtures.py @@ -60,7 +60,7 @@ email_bounces = fixture_table( orders = fixture_table( Table('orders', fixture_metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('user_id', None, ForeignKey('users.id')), Column('address_id', None, ForeignKey('addresses.id')), Column('description', String(30)), @@ -76,7 +76,7 @@ orders = fixture_table( dingalings = fixture_table( Table("dingalings", fixture_metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('address_id', None, ForeignKey('addresses.id')), Column('data', String(30)), test_needs_acid=True, diff --git a/test/orm/inheritance/test_abc_inheritance.py b/test/orm/inheritance/test_abc_inheritance.py index 4e55cf70ea..f6d5111b2c 100644 --- a/test/orm/inheritance/test_abc_inheritance.py +++ b/test/orm/inheritance/test_abc_inheritance.py @@ -3,6 +3,7 @@ from sqlalchemy.orm import * from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE from sqlalchemy.test import testing +from sqlalchemy.test.schema import Table, Column from test.orm import _base @@ -15,7 +16,7 @@ def produce_test(parent, child, direction): def define_tables(cls, metadata): global ta, tb, tc ta = ["a", metadata] - ta.append(Column('id', Integer, primary_key=True)), + ta.append(Column('id', Integer, primary_key=True, test_needs_autoincrement=True)), ta.append(Column('a_data', String(30))) if "a"== parent and direction == MANYTOONE: ta.append(Column('child_id', Integer, ForeignKey("%s.id" % child, use_alter=True, name="foo"))) diff --git a/test/orm/inheritance/test_abc_polymorphic.py b/test/orm/inheritance/test_abc_polymorphic.py index 8cad8ed781..2dab59bb25 100644 --- a/test/orm/inheritance/test_abc_polymorphic.py +++ b/test/orm/inheritance/test_abc_polymorphic.py @@ -4,13 +4,14 @@ from sqlalchemy.orm import * from sqlalchemy.util import function_named from test.orm import _base, _fixtures +from sqlalchemy.test.schema import Table, Column class ABCTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): global a, b, c a = Table('a', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('adata', String(30)), Column('type', String(30)), ) @@ -61,7 +62,7 @@ class ABCTest(_base.MappedTest): C(cdata='c1', bdata='c1', adata='c1'), C(cdata='c2', bdata='c2', adata='c2'), C(cdata='c2', bdata='c2', adata='c2'), - ] == sess.query(A).all() + ] == sess.query(A).order_by(A.id).all() assert [ B(bdata='b1', adata='b1'), diff --git a/test/orm/inheritance/test_basic.py b/test/orm/inheritance/test_basic.py index bad6920de7..b2e00de359 100644 --- a/test/orm/inheritance/test_basic.py +++ b/test/orm/inheritance/test_basic.py @@ -7,6 +7,7 @@ from sqlalchemy.orm import exc as orm_exc from sqlalchemy.test import testing, engines from sqlalchemy.util import function_named from test.orm import _base, _fixtures +from sqlalchemy.test.schema import Table, Column class O2MTest(_base.MappedTest): """deals with inheritance and one-to-many relationships""" @@ -14,8 +15,7 @@ class O2MTest(_base.MappedTest): def define_tables(cls, metadata): global foo, bar, blub foo = Table('foo', metadata, - Column('id', Integer, Sequence('foo_seq', optional=True), - primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(20))) bar = Table('bar', metadata, @@ -73,9 +73,8 @@ class FalseDiscriminatorTest(_base.MappedTest): def define_tables(cls, metadata): global t1 t1 = Table('t1', metadata, - Column('id', Integer, primary_key=True), - Column('type', Boolean, nullable=False) - ) + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('type', Boolean, nullable=False)) def test_false_on_sub(self): class Foo(object):pass @@ -108,7 +107,7 @@ class PolymorphicSynonymTest(_base.MappedTest): def define_tables(cls, metadata): global t1, t2 t1 = Table('t1', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('type', String(10), nullable=False), Column('info', String(255))) t2 = Table('t2', metadata, @@ -149,12 +148,12 @@ class CascadeTest(_base.MappedTest): def define_tables(cls, metadata): global t1, t2, t3, t4 t1= Table('t1', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30)) ) t2 = Table('t2', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('t1id', Integer, ForeignKey('t1.id')), Column('type', String(30)), Column('data', String(30)) @@ -164,7 +163,7 @@ class CascadeTest(_base.MappedTest): Column('moredata', String(30))) t4 = Table('t4', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('t3id', Integer, ForeignKey('t3.id')), Column('data', String(30))) @@ -214,8 +213,7 @@ class GetTest(_base.MappedTest): def define_tables(cls, metadata): global foo, bar, blub foo = Table('foo', metadata, - Column('id', Integer, Sequence('foo_seq', optional=True), - primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('type', String(30)), Column('data', String(20))) @@ -224,7 +222,7 @@ class GetTest(_base.MappedTest): Column('data', String(20))) blub = Table('blub', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('foo_id', Integer, ForeignKey('foo.id')), Column('bar_id', Integer, ForeignKey('bar.id')), Column('data', String(20))) @@ -304,8 +302,7 @@ class EagerLazyTest(_base.MappedTest): def define_tables(cls, metadata): global foo, bar, bar_foo foo = Table('foo', metadata, - Column('id', Integer, Sequence('foo_seq', optional=True), - primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30))) bar = Table('bar', metadata, Column('id', Integer, ForeignKey('foo.id'), primary_key=True), @@ -350,13 +347,13 @@ class FlushTest(_base.MappedTest): def define_tables(cls, metadata): global users, roles, user_roles, admins users = Table('users', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('email', String(128)), Column('password', String(16)), ) roles = Table('role', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('description', String(32)) ) @@ -366,7 +363,7 @@ class FlushTest(_base.MappedTest): ) admins = Table('admin', metadata, - Column('admin_id', Integer, primary_key=True), + Column('admin_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('user_id', Integer, ForeignKey('users.id')) ) @@ -439,7 +436,7 @@ class VersioningTest(_base.MappedTest): def define_tables(cls, metadata): global base, subtable, stuff base = Table('base', metadata, - Column('id', Integer, Sequence('version_test_seq', optional=True), primary_key=True ), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('version_id', Integer, nullable=False), Column('value', String(40)), Column('discriminator', Integer, nullable=False) @@ -449,11 +446,10 @@ class VersioningTest(_base.MappedTest): Column('subdata', String(50)) ) stuff = Table('stuff', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('parent', Integer, ForeignKey('base.id')) ) - @testing.fails_on('mssql', 'FIXME: the flush still happens with the concurrency issue.') @engines.close_open_connections def test_save_update(self): class Base(_fixtures.Base): @@ -493,16 +489,16 @@ class VersioningTest(_base.MappedTest): try: sess2.flush() - assert False + assert not testing.db.dialect.supports_sane_rowcount except orm_exc.ConcurrentModificationError, e: assert True sess2.refresh(s2) - assert s2.subdata == 'sess1 subdata' + if testing.db.dialect.supports_sane_rowcount: + assert s2.subdata == 'sess1 subdata' s2.subdata = 'sess2 subdata' sess2.flush() - @testing.fails_on('mssql', 'FIXME: the flush still happens with the concurrency issue.') def test_delete(self): class Base(_fixtures.Base): pass @@ -534,7 +530,7 @@ class VersioningTest(_base.MappedTest): try: s1.subdata = 'some new subdata' sess.flush() - assert False + assert not testing.db.dialect.supports_sane_rowcount except orm_exc.ConcurrentModificationError, e: assert True @@ -550,12 +546,12 @@ class DistinctPKTest(_base.MappedTest): global person_table, employee_table, Person, Employee person_table = Table("persons", metadata, - Column("id", Integer, primary_key=True), + Column("id", Integer, primary_key=True, test_needs_autoincrement=True), Column("name", String(80)), ) employee_table = Table("employees", metadata, - Column("id", Integer, primary_key=True), + Column("id", Integer, primary_key=True, test_needs_autoincrement=True), Column("salary", Integer), Column("person_id", Integer, ForeignKey("persons.id")), ) @@ -623,7 +619,7 @@ class SyncCompileTest(_base.MappedTest): global _a_table, _b_table, _c_table _a_table = Table('a', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data1', String(128)) ) @@ -691,7 +687,7 @@ class OverrideColKeyTest(_base.MappedTest): global base, subtable base = Table('base', metadata, - Column('base_id', Integer, primary_key=True), + Column('base_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(255)), Column('sqlite_fixer', String(10)) ) @@ -921,7 +917,7 @@ class OptimizedLoadTest(_base.MappedTest): def define_tables(cls, metadata): global base, sub base = Table('base', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(50)), Column('type', String(50)) ) @@ -1008,7 +1004,7 @@ class PKDiscriminatorTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): parents = Table('parents', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(60))) children = Table('children', metadata, @@ -1061,14 +1057,14 @@ class DeleteOrphanTest(_base.MappedTest): def define_tables(cls, metadata): global single, parent single = Table('single', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('type', String(50), nullable=False), Column('data', String(50)), Column('parent_id', Integer, ForeignKey('parent.id'), nullable=False), ) parent = Table('parent', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(50)) ) diff --git a/test/orm/inheritance/test_concrete.py b/test/orm/inheritance/test_concrete.py index 46bd171e44..3a78be9d7b 100644 --- a/test/orm/inheritance/test_concrete.py +++ b/test/orm/inheritance/test_concrete.py @@ -8,6 +8,7 @@ from sqlalchemy.test import testing from test.orm import _base from sqlalchemy.orm import attributes from sqlalchemy.test.testing import eq_ +from sqlalchemy.test.schema import Table, Column class Employee(object): def __init__(self, name): @@ -48,31 +49,31 @@ class ConcreteTest(_base.MappedTest): global managers_table, engineers_table, hackers_table, companies, employees_table companies = Table('companies', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50))) employees_table = Table('employees', metadata, - Column('employee_id', Integer, primary_key=True), + Column('employee_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), Column('company_id', Integer, ForeignKey('companies.id')) ) managers_table = Table('managers', metadata, - Column('employee_id', Integer, primary_key=True), + Column('employee_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), Column('manager_data', String(50)), Column('company_id', Integer, ForeignKey('companies.id')) ) engineers_table = Table('engineers', metadata, - Column('employee_id', Integer, primary_key=True), + Column('employee_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), Column('engineer_info', String(50)), Column('company_id', Integer, ForeignKey('companies.id')) ) hackers_table = Table('hackers', metadata, - Column('employee_id', Integer, primary_key=True), + Column('employee_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), Column('engineer_info', String(50)), Column('company_id', Integer, ForeignKey('companies.id')), @@ -320,17 +321,17 @@ class PropertyInheritanceTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('a_table', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('some_c_id', Integer, ForeignKey('c_table.id')), Column('aname', String(50)), ) Table('b_table', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('some_c_id', Integer, ForeignKey('c_table.id')), Column('bname', String(50)), ) Table('c_table', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('cname', String(50)), ) @@ -525,11 +526,11 @@ class ColKeysTest(_base.MappedTest): def define_tables(cls, metadata): global offices_table, refugees_table refugees_table = Table('refugee', metadata, - Column('refugee_fid', Integer, primary_key=True), + Column('refugee_fid', Integer, primary_key=True, test_needs_autoincrement=True), Column('refugee_name', Unicode(30), key='name')) offices_table = Table('office', metadata, - Column('office_fid', Integer, primary_key=True), + Column('office_fid', Integer, primary_key=True, test_needs_autoincrement=True), Column('office_name', Unicode(30), key='name')) @classmethod diff --git a/test/orm/inheritance/test_magazine.py b/test/orm/inheritance/test_magazine.py index 0673012511..f94781c278 100644 --- a/test/orm/inheritance/test_magazine.py +++ b/test/orm/inheritance/test_magazine.py @@ -4,6 +4,7 @@ from sqlalchemy.orm import * from sqlalchemy.test import testing from sqlalchemy.util import function_named from test.orm import _base +from sqlalchemy.test.schema import Table, Column class BaseObject(object): def __init__(self, *args, **kwargs): @@ -75,49 +76,48 @@ class MagazineTest(_base.MappedTest): global publication_table, issue_table, location_table, location_name_table, magazine_table, \ page_table, magazine_page_table, classified_page_table, page_size_table - zerodefault = {} #{'default':0} publication_table = Table('publication', metadata, - Column('id', Integer, primary_key=True, default=None), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(45), default=''), ) issue_table = Table('issue', metadata, - Column('id', Integer, primary_key=True, default=None), - Column('publication_id', Integer, ForeignKey('publication.id'), **zerodefault), - Column('issue', Integer, **zerodefault), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('publication_id', Integer, ForeignKey('publication.id')), + Column('issue', Integer), ) location_table = Table('location', metadata, - Column('id', Integer, primary_key=True, default=None), - Column('issue_id', Integer, ForeignKey('issue.id'), **zerodefault), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('issue_id', Integer, ForeignKey('issue.id')), Column('ref', CHAR(3), default=''), - Column('location_name_id', Integer, ForeignKey('location_name.id'), **zerodefault), + Column('location_name_id', Integer, ForeignKey('location_name.id')), ) location_name_table = Table('location_name', metadata, - Column('id', Integer, primary_key=True, default=None), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(45), default=''), ) magazine_table = Table('magazine', metadata, - Column('id', Integer, primary_key=True, default=None), - Column('location_id', Integer, ForeignKey('location.id'), **zerodefault), - Column('page_size_id', Integer, ForeignKey('page_size.id'), **zerodefault), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('location_id', Integer, ForeignKey('location.id')), + Column('page_size_id', Integer, ForeignKey('page_size.id')), ) page_table = Table('page', metadata, - Column('id', Integer, primary_key=True, default=None), - Column('page_no', Integer, **zerodefault), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('page_no', Integer), Column('type', CHAR(1), default='p'), ) magazine_page_table = Table('magazine_page', metadata, - Column('page_id', Integer, ForeignKey('page.id'), primary_key=True, **zerodefault), - Column('magazine_id', Integer, ForeignKey('magazine.id'), **zerodefault), - Column('orders', TEXT, default=''), + Column('page_id', Integer, ForeignKey('page.id'), primary_key=True), + Column('magazine_id', Integer, ForeignKey('magazine.id')), + Column('orders', Text, default=''), ) classified_page_table = Table('classified_page', metadata, - Column('magazine_page_id', Integer, ForeignKey('magazine_page.page_id'), primary_key=True, **zerodefault), + Column('magazine_page_id', Integer, ForeignKey('magazine_page.page_id'), primary_key=True), Column('titles', String(45), default=''), ) page_size_table = Table('page_size', metadata, - Column('id', Integer, primary_key=True, default=None), - Column('width', Integer, **zerodefault), - Column('height', Integer, **zerodefault), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('width', Integer), + Column('height', Integer), Column('name', String(45), default=''), ) @@ -176,10 +176,11 @@ def generate_round_trip_test(use_unions=False, use_joins=False): 'magazine': relation(Magazine, backref=backref('pages', order_by=page_table.c.page_no)) }) - classified_page_mapper = mapper(ClassifiedPage, classified_page_table, inherits=magazine_page_mapper, polymorphic_identity='c', primary_key=[page_table.c.id]) - #compile_mappers() - #print [str(s) for s in classified_page_mapper.primary_key] - #print classified_page_mapper.columntoproperty[page_table.c.id] + classified_page_mapper = mapper(ClassifiedPage, + classified_page_table, + inherits=magazine_page_mapper, + polymorphic_identity='c', + primary_key=[page_table.c.id]) session = create_session() diff --git a/test/orm/inheritance/test_manytomany.py b/test/orm/inheritance/test_manytomany.py index f7e676bbbc..7b6ad04eb2 100644 --- a/test/orm/inheritance/test_manytomany.py +++ b/test/orm/inheritance/test_manytomany.py @@ -194,11 +194,11 @@ class InheritTest3(_base.MappedTest): b.foos.append(Foo("foo #1")) b.foos.append(Foo("foo #2")) sess.flush() - compare = repr(b) + repr(sorted([repr(o) for o in b.foos])) + compare = [repr(b)] + sorted([repr(o) for o in b.foos]) sess.expunge_all() l = sess.query(Bar).all() print repr(l[0]) + repr(l[0].foos) - found = repr(l[0]) + repr(sorted([repr(o) for o in l[0].foos])) + found = [repr(l[0])] + sorted([repr(o) for o in l[0].foos]) eq_(found, compare) @testing.fails_on('maxdb', 'FIXME: unknown') diff --git a/test/orm/inheritance/test_poly_linked_list.py b/test/orm/inheritance/test_poly_linked_list.py index 67b543f31c..e434218b9c 100644 --- a/test/orm/inheritance/test_poly_linked_list.py +++ b/test/orm/inheritance/test_poly_linked_list.py @@ -3,6 +3,7 @@ from sqlalchemy.orm import * from test.orm import _base from sqlalchemy.test import testing +from sqlalchemy.test.schema import Table, Column class PolymorphicCircularTest(_base.MappedTest): @@ -12,7 +13,7 @@ class PolymorphicCircularTest(_base.MappedTest): def define_tables(cls, metadata): global Table1, Table1B, Table2, Table3, Data table1 = Table('table1', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('related_id', Integer, ForeignKey('table1.id'), nullable=True), Column('type', String(30)), Column('name', String(30)) @@ -27,7 +28,7 @@ class PolymorphicCircularTest(_base.MappedTest): ) data = Table('data', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('node_id', Integer, ForeignKey('table1.id')), Column('data', String(30)) ) @@ -72,7 +73,7 @@ class PolymorphicCircularTest(_base.MappedTest): polymorphic_on=table1.c.type, polymorphic_identity='table1', properties={ - 'next': relation(Table1, + 'nxt': relation(Table1, backref=backref('prev', foreignkey=join.c.id, uselist=False), uselist=False, primaryjoin=join.c.id==join.c.related_id), 'data':relation(mapper(Data, data)) @@ -86,15 +87,16 @@ class PolymorphicCircularTest(_base.MappedTest): # currently, the "eager" relationships degrade to lazy relationships # due to the polymorphic load. - # the "next" relation used to have a "lazy=False" on it, but the EagerLoader raises the "self-referential" + # the "nxt" relation used to have a "lazy=False" on it, but the EagerLoader raises the "self-referential" # exception now. since eager loading would never work for that relation anyway, its better that the user # gets an exception instead of it silently not eager loading. + # NOTE: using "nxt" instead of "next" to avoid 2to3 turning it into __next__() for some reason. table1_mapper = mapper(Table1, table1, #select_table=join, polymorphic_on=table1.c.type, polymorphic_identity='table1', properties={ - 'next': relation(Table1, + 'nxt': relation(Table1, backref=backref('prev', remote_side=table1.c.id, uselist=False), uselist=False, primaryjoin=table1.c.id==table1.c.related_id), 'data':relation(mapper(Data, data), lazy=False, order_by=data.c.id) @@ -147,7 +149,7 @@ class PolymorphicCircularTest(_base.MappedTest): else: newobj = c if obj is not None: - obj.next = newobj + obj.nxt = newobj else: t = newobj obj = newobj @@ -161,7 +163,7 @@ class PolymorphicCircularTest(_base.MappedTest): node = t while (node): assertlist.append(node) - n = node.next + n = node.nxt if n is not None: assert n.prev is node node = n @@ -174,7 +176,7 @@ class PolymorphicCircularTest(_base.MappedTest): assertlist = [] while (node): assertlist.append(node) - n = node.next + n = node.nxt if n is not None: assert n.prev is node node = n @@ -188,7 +190,7 @@ class PolymorphicCircularTest(_base.MappedTest): assertlist.insert(0, node) n = node.prev if n is not None: - assert n.next is node + assert n.nxt is node node = n backwards = repr(assertlist) diff --git a/test/orm/inheritance/test_polymorph2.py b/test/orm/inheritance/test_polymorph2.py index 51b6d4970a..80c14413a0 100644 --- a/test/orm/inheritance/test_polymorph2.py +++ b/test/orm/inheritance/test_polymorph2.py @@ -11,6 +11,7 @@ from sqlalchemy.test import TestBase, AssertsExecutionResults, testing from sqlalchemy.util import function_named from test.orm import _base, _fixtures from sqlalchemy.test.testing import eq_ +from sqlalchemy.test.schema import Table, Column class AttrSettable(object): def __init__(self, **kwargs): @@ -105,7 +106,7 @@ class RelationTest2(_base.MappedTest): def define_tables(cls, metadata): global people, managers, data people = Table('people', metadata, - Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), + Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), Column('type', String(30))) @@ -201,7 +202,7 @@ class RelationTest3(_base.MappedTest): def define_tables(cls, metadata): global people, managers, data people = Table('people', metadata, - Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), + Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('colleague_id', Integer, ForeignKey('people.person_id')), Column('name', String(50)), Column('type', String(30))) @@ -307,7 +308,7 @@ class RelationTest4(_base.MappedTest): def define_tables(cls, metadata): global people, engineers, managers, cars people = Table('people', metadata, - Column('person_id', Integer, primary_key=True), + Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50))) engineers = Table('engineers', metadata, @@ -319,7 +320,7 @@ class RelationTest4(_base.MappedTest): Column('longer_status', String(70))) cars = Table('cars', metadata, - Column('car_id', Integer, primary_key=True), + Column('car_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('owner', Integer, ForeignKey('people.person_id'))) def testmanytoonepolymorphic(self): @@ -420,7 +421,7 @@ class RelationTest5(_base.MappedTest): def define_tables(cls, metadata): global people, engineers, managers, cars people = Table('people', metadata, - Column('person_id', Integer, primary_key=True), + Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), Column('type', String(50))) @@ -433,7 +434,7 @@ class RelationTest5(_base.MappedTest): Column('longer_status', String(70))) cars = Table('cars', metadata, - Column('car_id', Integer, primary_key=True), + Column('car_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('owner', Integer, ForeignKey('people.person_id'))) def testeagerempty(self): @@ -482,7 +483,7 @@ class RelationTest6(_base.MappedTest): def define_tables(cls, metadata): global people, managers, data people = Table('people', metadata, - Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), + Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), ) @@ -525,14 +526,14 @@ class RelationTest7(_base.MappedTest): def define_tables(cls, metadata): global people, engineers, managers, cars, offroad_cars cars = Table('cars', metadata, - Column('car_id', Integer, primary_key=True), + Column('car_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(30))) offroad_cars = Table('offroad_cars', metadata, Column('car_id',Integer, ForeignKey('cars.car_id'),nullable=False,primary_key=True)) people = Table('people', metadata, - Column('person_id', Integer, primary_key=True), + Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('car_id', Integer, ForeignKey('cars.car_id'), nullable=False), Column('name', String(50))) @@ -625,7 +626,7 @@ class RelationTest8(_base.MappedTest): def define_tables(cls, metadata): global taggable, users taggable = Table('taggable', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('type', String(30)), Column('owner_id', Integer, ForeignKey('taggable.id')), ) @@ -680,11 +681,11 @@ class GenerativeTest(TestBase, AssertsExecutionResults): metadata = MetaData(testing.db) # table definitions status = Table('status', metadata, - Column('status_id', Integer, primary_key=True), + Column('status_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(20))) people = Table('people', metadata, - Column('person_id', Integer, primary_key=True), + Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('status_id', Integer, ForeignKey('status.status_id'), nullable=False), Column('name', String(50))) @@ -697,7 +698,7 @@ class GenerativeTest(TestBase, AssertsExecutionResults): Column('category', String(70))) cars = Table('cars', metadata, - Column('car_id', Integer, primary_key=True), + Column('car_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('status_id', Integer, ForeignKey('status.status_id'), nullable=False), Column('owner', Integer, ForeignKey('people.person_id'), nullable=False)) @@ -786,13 +787,13 @@ class GenerativeTest(TestBase, AssertsExecutionResults): e = exists([Car.owner], Car.owner==employee_join.c.person_id) Query(Person)._adapt_clause(employee_join, False, False) - r = session.query(Person).filter(Person.name.like('%2')).join('status').filter_by(name="active") - assert str(list(r)) == "[Manager M2, category YYYYYYYYY, status Status active, Engineer E2, field X, status Status active]" + r = session.query(Person).filter(Person.name.like('%2')).join('status').filter_by(name="active").order_by(Person.person_id) + eq_(str(list(r)), "[Manager M2, category YYYYYYYYY, status Status active, Engineer E2, field X, status Status active]") r = session.query(Engineer).join('status').filter(Person.name.in_(['E2', 'E3', 'E4', 'M4', 'M2', 'M1']) & (status.c.name=="active")).order_by(Person.name) - assert str(list(r)) == "[Engineer E2, field X, status Status active, Engineer E3, field X, status Status active]" + eq_(str(list(r)), "[Engineer E2, field X, status Status active, Engineer E3, field X, status Status active]") r = session.query(Person).filter(exists([1], Car.owner==Person.person_id)) - assert str(list(r)) == "[Engineer E4, field X, status Status dead]" + eq_(str(list(r)), "[Engineer E4, field X, status Status dead]") class MultiLevelTest(_base.MappedTest): @classmethod @@ -800,7 +801,7 @@ class MultiLevelTest(_base.MappedTest): global table_Employee, table_Engineer, table_Manager table_Employee = Table( 'Employee', metadata, Column( 'name', type_= String(100), ), - Column( 'id', primary_key= True, type_= Integer, ), + Column( 'id', primary_key= True, type_= Integer, test_needs_autoincrement=True), Column( 'atype', type_= String(100), ), ) @@ -878,7 +879,7 @@ class ManyToManyPolyTest(_base.MappedTest): global base_item_table, item_table, base_item_collection_table, collection_table base_item_table = Table( 'base_item', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('child_name', String(255), default=None)) item_table = Table( @@ -893,7 +894,7 @@ class ManyToManyPolyTest(_base.MappedTest): collection_table = Table( 'collection', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', Unicode(255))) def test_pjoin_compile(self): @@ -928,7 +929,7 @@ class CustomPKTest(_base.MappedTest): def define_tables(cls, metadata): global t1, t2 t1 = Table('t1', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('type', String(30), nullable=False), Column('data', String(30))) # note that the primary key column in t2 is named differently @@ -1013,7 +1014,7 @@ class InheritingEagerTest(_base.MappedTest): global people, employees, tags, peopleTags people = Table('people', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('_type', String(30), nullable=False), ) @@ -1023,7 +1024,7 @@ class InheritingEagerTest(_base.MappedTest): ) tags = Table('tags', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('label', String(50), nullable=False), ) @@ -1074,11 +1075,11 @@ class MissingPolymorphicOnTest(_base.MappedTest): def define_tables(cls, metadata): global tablea, tableb, tablec, tabled tablea = Table('tablea', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('adata', String(50)), ) tableb = Table('tableb', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('aid', Integer, ForeignKey('tablea.id')), Column('data', String(50)), ) diff --git a/test/orm/inheritance/test_productspec.py b/test/orm/inheritance/test_productspec.py index b2bcb85d54..4c593e2a38 100644 --- a/test/orm/inheritance/test_productspec.py +++ b/test/orm/inheritance/test_productspec.py @@ -2,10 +2,9 @@ from datetime import datetime from sqlalchemy import * from sqlalchemy.orm import * - from sqlalchemy.test import testing from test.orm import _base - +from sqlalchemy.test.schema import Table, Column class InheritTest(_base.MappedTest): """tests some various inheritance round trips involving a particular set of polymorphic inheritance relationships""" @@ -15,14 +14,14 @@ class InheritTest(_base.MappedTest): global Product, Detail, Assembly, SpecLine, Document, RasterDocument products_table = Table('products', metadata, - Column('product_id', Integer, primary_key=True), + Column('product_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('product_type', String(128)), Column('name', String(128)), Column('mark', String(128)), ) specification_table = Table('specification', metadata, - Column('spec_line_id', Integer, primary_key=True), + Column('spec_line_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('master_id', Integer, ForeignKey("products.product_id"), nullable=True), Column('slave_id', Integer, ForeignKey("products.product_id"), @@ -31,7 +30,7 @@ class InheritTest(_base.MappedTest): ) documents_table = Table('documents', metadata, - Column('document_id', Integer, primary_key=True), + Column('document_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('document_type', String(128)), Column('product_id', Integer, ForeignKey('products.product_id')), Column('create_date', DateTime, default=lambda:datetime.now()), diff --git a/test/orm/inheritance/test_query.py b/test/orm/inheritance/test_query.py index 5b57e8f457..daf8bf3bd0 100644 --- a/test/orm/inheritance/test_query.py +++ b/test/orm/inheritance/test_query.py @@ -8,6 +8,7 @@ from sqlalchemy.engine import default from sqlalchemy.test import AssertsCompiledSQL, testing from test.orm import _base, _fixtures from sqlalchemy.test.testing import eq_ +from sqlalchemy.test.schema import Table, Column class Company(_fixtures.Base): pass @@ -38,11 +39,11 @@ def _produce_test(select_type): global companies, people, engineers, managers, boss, paperwork, machines companies = Table('companies', metadata, - Column('company_id', Integer, Sequence('company_id_seq', optional=True), primary_key=True), + Column('company_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50))) people = Table('people', metadata, - Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), + Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('company_id', Integer, ForeignKey('companies.company_id')), Column('name', String(50)), Column('type', String(30))) @@ -55,7 +56,7 @@ def _produce_test(select_type): ) machines = Table('machines', metadata, - Column('machine_id', Integer, primary_key=True), + Column('machine_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), Column('engineer_id', Integer, ForeignKey('engineers.person_id'))) @@ -71,7 +72,7 @@ def _produce_test(select_type): ) paperwork = Table('paperwork', metadata, - Column('paperwork_id', Integer, primary_key=True), + Column('paperwork_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('description', String(50)), Column('person_id', Integer, ForeignKey('people.person_id'))) @@ -771,7 +772,7 @@ class SelfReferentialTestJoinedToBase(_base.MappedTest): def define_tables(cls, metadata): global people, engineers people = Table('people', metadata, - Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), + Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), Column('type', String(30))) @@ -831,7 +832,7 @@ class SelfReferentialJ2JTest(_base.MappedTest): def define_tables(cls, metadata): global people, engineers, managers people = Table('people', metadata, - Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), + Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), Column('type', String(30))) @@ -947,7 +948,7 @@ class M2MFilterTest(_base.MappedTest): global people, engineers, organizations, engineers_to_org organizations = Table('organizations', metadata, - Column('id', Integer, Sequence('org_id_seq', optional=True), primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), ) engineers_to_org = Table('engineers_org', metadata, @@ -956,7 +957,7 @@ class M2MFilterTest(_base.MappedTest): ) people = Table('people', metadata, - Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), + Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), Column('type', String(30))) @@ -1023,7 +1024,7 @@ class SelfReferentialM2MTest(_base.MappedTest, AssertsCompiledSQL): class Parent(Base): __tablename__ = 'parent' - id = Column(Integer, primary_key=True) + id = Column(Integer, primary_key=True, test_needs_autoincrement=True) cls = Column(String(50)) __mapper_args__ = dict(polymorphic_on = cls ) diff --git a/test/orm/inheritance/test_selects.py b/test/orm/inheritance/test_selects.py index a151af4fa2..7c9920f6f8 100644 --- a/test/orm/inheritance/test_selects.py +++ b/test/orm/inheritance/test_selects.py @@ -46,6 +46,6 @@ class InheritingSelectablesTest(MappedTest): s = sessionmaker(bind=testing.db)() - assert [Baz(), Baz(), Bar(), Bar()] == s.query(Foo).all() + assert [Baz(), Baz(), Bar(), Bar()] == s.query(Foo).order_by(Foo.b.desc()).all() assert [Bar(), Bar()] == s.query(Bar).all() diff --git a/test/orm/inheritance/test_single.py b/test/orm/inheritance/test_single.py index 7058268857..fc30955db8 100644 --- a/test/orm/inheritance/test_single.py +++ b/test/orm/inheritance/test_single.py @@ -5,20 +5,21 @@ from sqlalchemy.orm import * from sqlalchemy.test import testing from test.orm import _fixtures from test.orm._base import MappedTest, ComparableEntity +from sqlalchemy.test.schema import Table, Column class SingleInheritanceTest(MappedTest): @classmethod def define_tables(cls, metadata): Table('employees', metadata, - Column('employee_id', Integer, primary_key=True), + Column('employee_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), Column('manager_data', String(50)), Column('engineer_info', String(50)), Column('type', String(20))) Table('reports', metadata, - Column('report_id', Integer, primary_key=True), + Column('report_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('employee_id', ForeignKey('employees.employee_id')), Column('name', String(50)), ) @@ -186,7 +187,7 @@ class RelationToSingleTest(MappedTest): @classmethod def define_tables(cls, metadata): Table('employees', metadata, - Column('employee_id', Integer, primary_key=True), + Column('employee_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), Column('manager_data', String(50)), Column('engineer_info', String(50)), @@ -195,7 +196,7 @@ class RelationToSingleTest(MappedTest): ) Table('companies', metadata, - Column('company_id', Integer, primary_key=True), + Column('company_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), ) @@ -342,7 +343,7 @@ class SingleOnJoinedTest(MappedTest): global persons_table, employees_table persons_table = Table('persons', metadata, - Column('person_id', Integer, primary_key=True), + Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), Column('type', String(20), nullable=False) ) diff --git a/test/orm/sharding/test_shard.py b/test/orm/sharding/test_shard.py index 89e23fb759..e8ffaa7cad 100644 --- a/test/orm/sharding/test_shard.py +++ b/test/orm/sharding/test_shard.py @@ -6,6 +6,7 @@ from sqlalchemy.orm.shard import ShardedSession from sqlalchemy.sql import operators from sqlalchemy.test import * from sqlalchemy.test.testing import eq_ +from nose import SkipTest # TODO: ShardTest can be turned into a base for further subclasses @@ -14,7 +15,10 @@ class ShardTest(TestBase): def setup_class(cls): global db1, db2, db3, db4, weather_locations, weather_reports - db1 = create_engine('sqlite:///shard1.db') + try: + db1 = create_engine('sqlite:///shard1.db') + except ImportError: + raise SkipTest('Requires sqlite') db2 = create_engine('sqlite:///shard2.db') db3 = create_engine('sqlite:///shard3.db') db4 = create_engine('sqlite:///shard4.db') diff --git a/test/orm/test_association.py b/test/orm/test_association.py index ee7fb7af94..d537430cc6 100644 --- a/test/orm/test_association.py +++ b/test/orm/test_association.py @@ -1,8 +1,7 @@ from sqlalchemy.test import testing from sqlalchemy import Integer, String, ForeignKey -from sqlalchemy.test.schema import Table -from sqlalchemy.test.schema import Column +from sqlalchemy.test.schema import Table, Column from sqlalchemy.orm import mapper, relation, create_session from test.orm import _base from sqlalchemy.test.testing import eq_ @@ -15,14 +14,14 @@ class AssociationTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('items', metadata, - Column('item_id', Integer, primary_key=True), + Column('item_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(40))) Table('item_keywords', metadata, Column('item_id', Integer, ForeignKey('items.item_id')), Column('keyword_id', Integer, ForeignKey('keywords.keyword_id')), Column('data', String(40))) Table('keywords', metadata, - Column('keyword_id', Integer, primary_key=True), + Column('keyword_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(40))) @classmethod diff --git a/test/orm/test_assorted_eager.py b/test/orm/test_assorted_eager.py index 09f0075479..94a98d9aea 100644 --- a/test/orm/test_assorted_eager.py +++ b/test/orm/test_assorted_eager.py @@ -8,8 +8,7 @@ import datetime import sqlalchemy as sa from sqlalchemy.test import testing from sqlalchemy import Integer, String, ForeignKey -from sqlalchemy.test.schema import Table -from sqlalchemy.test.schema import Column +from sqlalchemy.test.schema import Table, Column from sqlalchemy.orm import mapper, relation, backref, create_session from sqlalchemy.test.testing import eq_ from test.orm import _base @@ -37,27 +36,24 @@ class EagerTest(_base.MappedTest): cls.other_artifacts['false'] = false Table('owners', metadata , - Column('id', Integer, primary_key=True, nullable=False), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30))) Table('categories', metadata, - Column('id', Integer, primary_key=True, nullable=False), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(20))) Table('tests', metadata , - Column('id', Integer, primary_key=True, nullable=False ), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('owner_id', Integer, ForeignKey('owners.id'), nullable=False), Column('category_id', Integer, ForeignKey('categories.id'), nullable=False)) Table('options', metadata , - Column('test_id', Integer, ForeignKey('tests.id'), - primary_key=True, nullable=False), - Column('owner_id', Integer, ForeignKey('owners.id'), - primary_key=True, nullable=False), - Column('someoption', sa.Boolean, server_default=false, - nullable=False)) + Column('test_id', Integer, ForeignKey('tests.id'), primary_key=True), + Column('owner_id', Integer, ForeignKey('owners.id'), primary_key=True), + Column('someoption', sa.Boolean, server_default=false, nullable=False)) @classmethod def setup_classes(cls): @@ -219,7 +215,7 @@ class EagerTest2(_base.MappedTest): Column('data', String(50), primary_key=True)) Table('middle', metadata, - Column('id', Integer, primary_key = True), + Column('id', Integer, primary_key = True, test_needs_autoincrement=True), Column('data', String(50))) Table('right', metadata, @@ -280,17 +276,15 @@ class EagerTest3(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('datas', metadata, - Column('id', Integer, primary_key=True, nullable=False), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('a', Integer, nullable=False)) Table('foo', metadata, - Column('data_id', Integer, - ForeignKey('datas.id'), - nullable=False, primary_key=True), + Column('data_id', Integer, ForeignKey('datas.id'),primary_key=True), Column('bar', Integer)) Table('stats', metadata, - Column('id', Integer, primary_key=True, nullable=False ), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data_id', Integer, ForeignKey('datas.id')), Column('somedata', Integer, nullable=False )) @@ -364,11 +358,11 @@ class EagerTest4(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('departments', metadata, - Column('department_id', Integer, primary_key=True), + Column('department_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50))) Table('employees', metadata, - Column('person_id', Integer, primary_key=True), + Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), Column('department_id', Integer, ForeignKey('departments.department_id'))) @@ -422,17 +416,15 @@ class EagerTest5(_base.MappedTest): Column('x', String(30))) Table('derived', metadata, - Column('uid', String(30), ForeignKey('base.uid'), - primary_key=True), + Column('uid', String(30), ForeignKey('base.uid'), primary_key=True), Column('y', String(30))) Table('derivedII', metadata, - Column('uid', String(30), ForeignKey('base.uid'), - primary_key=True), + Column('uid', String(30), ForeignKey('base.uid'), primary_key=True), Column('z', String(30))) Table('comments', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('uid', String(30), ForeignKey('base.uid')), Column('comment', String(30))) @@ -505,21 +497,21 @@ class EagerTest6(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('design_types', metadata, - Column('design_type_id', Integer, primary_key=True)) + Column('design_type_id', Integer, primary_key=True, test_needs_autoincrement=True)) Table('design', metadata, - Column('design_id', Integer, primary_key=True), + Column('design_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('design_type_id', Integer, ForeignKey('design_types.design_type_id'))) Table('parts', metadata, - Column('part_id', Integer, primary_key=True), + Column('part_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('design_id', Integer, ForeignKey('design.design_id')), Column('design_type_id', Integer, ForeignKey('design_types.design_type_id'))) Table('inherited_part', metadata, - Column('ip_id', Integer, primary_key=True), + Column('ip_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('part_id', Integer, ForeignKey('parts.part_id')), Column('design_id', Integer, ForeignKey('design.design_id'))) @@ -573,32 +565,27 @@ class EagerTest7(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('companies', metadata, - Column('company_id', Integer, primary_key=True, - test_needs_autoincrement=True), + Column('company_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('company_name', String(40))) Table('addresses', metadata, - Column('address_id', Integer, primary_key=True, - test_needs_autoincrement=True), + Column('address_id', Integer, primary_key=True,test_needs_autoincrement=True), Column('company_id', Integer, ForeignKey("companies.company_id")), Column('address', String(40))) Table('phone_numbers', metadata, - Column('phone_id', Integer, primary_key=True, - test_needs_autoincrement=True), + Column('phone_id', Integer, primary_key=True,test_needs_autoincrement=True), Column('address_id', Integer, ForeignKey('addresses.address_id')), Column('type', String(20)), Column('number', String(10))) Table('invoices', metadata, - Column('invoice_id', Integer, primary_key=True, - test_needs_autoincrement=True), + Column('invoice_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('company_id', Integer, ForeignKey("companies.company_id")), Column('date', sa.DateTime)) Table('items', metadata, - Column('item_id', Integer, primary_key=True, - test_needs_autoincrement=True), + Column('item_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('invoice_id', Integer, ForeignKey('invoices.invoice_id')), Column('code', String(20)), Column('qty', Integer)) @@ -722,12 +709,12 @@ class EagerTest8(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('prj', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('created', sa.DateTime ), Column('title', sa.Unicode(100))) Table('task', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('status_id', Integer, ForeignKey('task_status.id'), nullable=False), Column('title', sa.Unicode(100)), @@ -736,19 +723,19 @@ class EagerTest8(_base.MappedTest): Column('prj_id', Integer , ForeignKey('prj.id'), nullable=False)) Table('task_status', metadata, - Column('id', Integer, primary_key=True)) + Column('id', Integer, primary_key=True, test_needs_autoincrement=True)) Table('task_type', metadata, - Column('id', Integer, primary_key=True)) + Column('id', Integer, primary_key=True, test_needs_autoincrement=True)) Table('msg', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('posted', sa.DateTime, index=True,), Column('type_id', Integer, ForeignKey('msg_type.id')), Column('task_id', Integer, ForeignKey('task.id'))) Table('msg_type', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', sa.Unicode(20)), Column('display_name', sa.Unicode(20))) @@ -814,15 +801,15 @@ class EagerTest9(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('accounts', metadata, - Column('account_id', Integer, primary_key=True), + Column('account_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(40))) Table('transactions', metadata, - Column('transaction_id', Integer, primary_key=True), + Column('transaction_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(40))) Table('entries', metadata, - Column('entry_id', Integer, primary_key=True), + Column('entry_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(40)), Column('account_id', Integer, ForeignKey('accounts.account_id')), diff --git a/test/orm/test_attributes.py b/test/orm/test_attributes.py index ca8cef3ad8..fa26ec7d7e 100644 --- a/test/orm/test_attributes.py +++ b/test/orm/test_attributes.py @@ -6,7 +6,8 @@ from sqlalchemy import exc as sa_exc from sqlalchemy.test import * from sqlalchemy.test.testing import eq_ from test.orm import _base -import gc +from sqlalchemy.test.util import gc_collect +from sqlalchemy.util import cmp, jython # global for pickling tests MyTest = None @@ -80,7 +81,9 @@ class AttributesTest(_base.ORMTest): del o2.__dict__['mt2'] o2.__dict__[o_mt2_str] = former - self.assert_(pk_o == pk_o2) + # Relies on dict ordering + if not jython: + self.assert_(pk_o == pk_o2) # the above is kind of distrurbing, so let's do it again a little # differently. the string-id in serialization thing is just an @@ -93,7 +96,9 @@ class AttributesTest(_base.ORMTest): o4 = pickle.loads(pk_o3) pk_o4 = pickle.dumps(o4) - self.assert_(pk_o3 == pk_o4) + # Relies on dict ordering + if not jython: + self.assert_(pk_o3 == pk_o4) # and lastly make sure we still have our data after all that. # identical serialzation is great, *if* it's complete :) @@ -117,7 +122,7 @@ class AttributesTest(_base.ORMTest): f.bar = "foo" assert state.dict == {'bar':'foo', state.manager.STATE_ATTR:state} del f - gc.collect() + gc_collect() assert state.obj() is None assert state.dict == {} diff --git a/test/orm/test_cascade.py b/test/orm/test_cascade.py index d0a7b9ded6..c523fb5f01 100644 --- a/test/orm/test_cascade.py +++ b/test/orm/test_cascade.py @@ -1,8 +1,7 @@ from sqlalchemy.test.testing import assert_raises, assert_raises_message from sqlalchemy import Integer, String, ForeignKey, Sequence, exc as sa_exc -from sqlalchemy.test.schema import Table -from sqlalchemy.test.schema import Column +from sqlalchemy.test.schema import Table, Column from sqlalchemy.orm import mapper, relation, create_session, class_mapper, backref from sqlalchemy.orm import attributes, exc as orm_exc from sqlalchemy.test import testing @@ -20,7 +19,7 @@ class O2MCascadeTest(_fixtures.FixtureTest): mapper(User, users, properties = dict( addresses = relation(Address, cascade="all, delete-orphan", backref="user"), orders = relation( - mapper(Order, orders), cascade="all, delete-orphan") + mapper(Order, orders), cascade="all, delete-orphan", order_by=orders.c.id) )) mapper(Dingaling,dingalings, properties={ 'address':relation(Address) @@ -50,16 +49,12 @@ class O2MCascadeTest(_fixtures.FixtureTest): orders=[Order(description="order 3"), Order(description="order 4")])) - eq_(sess.query(Order).all(), + eq_(sess.query(Order).order_by(Order.id).all(), [Order(description="order 3"), Order(description="order 4")]) o5 = Order(description="order 5") sess.add(o5) - try: - sess.flush() - assert False - except orm_exc.FlushError, e: - assert "is an orphan" in str(e) + assert_raises_message(orm_exc.FlushError, "is an orphan", sess.flush) @testing.resolve_artifact_names @@ -351,18 +346,15 @@ class M2OCascadeTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table("extra", metadata, - Column("id", Integer, Sequence("extra_id_seq", optional=True), - primary_key=True), + Column("id", Integer, primary_key=True, test_needs_autoincrement=True), Column("prefs_id", Integer, ForeignKey("prefs.id"))) Table('prefs', metadata, - Column('id', Integer, Sequence('prefs_id_seq', optional=True), - primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(40))) Table('users', metadata, - Column('id', Integer, Sequence('user_id_seq', optional=True), - primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(40)), Column('pref_id', Integer, ForeignKey('prefs.id'))) @@ -453,22 +445,22 @@ class M2OCascadeTest(_base.MappedTest): jack.pref = newpref jack.pref = newpref sess.flush() - eq_(sess.query(Pref).all(), + eq_(sess.query(Pref).order_by(Pref.id).all(), [Pref(data="pref 1"), Pref(data="pref 3"), Pref(data="newpref")]) class M2OCascadeDeleteTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('t1', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(50)), Column('t2id', Integer, ForeignKey('t2.id'))) Table('t2', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(50)), Column('t3id', Integer, ForeignKey('t3.id'))) Table('t3', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(50))) @classmethod @@ -581,15 +573,15 @@ class M2OCascadeDeleteOrphanTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('t1', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(50)), Column('t2id', Integer, ForeignKey('t2.id'))) Table('t2', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(50)), Column('t3id', Integer, ForeignKey('t3.id'))) Table('t3', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(50))) @classmethod @@ -696,12 +688,12 @@ class M2MCascadeTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('a', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30)), test_needs_fk=True ) Table('b', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30)), test_needs_fk=True @@ -713,7 +705,7 @@ class M2MCascadeTest(_base.MappedTest): ) Table('c', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30)), Column('bid', Integer, ForeignKey('b.id')), test_needs_fk=True @@ -838,15 +830,11 @@ class UnsavedOrphansTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('users', metadata, - Column('user_id', Integer, - Sequence('user_id_seq', optional=True), - primary_key=True), + Column('user_id', Integer,primary_key=True, test_needs_autoincrement=True), Column('name', String(40))) Table('addresses', metadata, - Column('address_id', Integer, - Sequence('address_id_seq', optional=True), - primary_key=True), + Column('address_id', Integer,primary_key=True, test_needs_autoincrement=True), Column('user_id', Integer, ForeignKey('users.user_id')), Column('email_address', String(40))) @@ -923,20 +911,17 @@ class UnsavedOrphansTest2(_base.MappedTest): @classmethod def define_tables(cls, meta): Table('orders', meta, - Column('id', Integer, Sequence('order_id_seq'), - primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50))) Table('items', meta, - Column('id', Integer, Sequence('item_id_seq'), - primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('order_id', Integer, ForeignKey('orders.id'), nullable=False), Column('name', String(50))) Table('attributes', meta, - Column('id', Integer, Sequence('attribute_id_seq'), - primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('item_id', Integer, ForeignKey('items.id'), nullable=False), Column('name', String(50))) @@ -982,19 +967,13 @@ class UnsavedOrphansTest3(_base.MappedTest): @classmethod def define_tables(cls, meta): Table('sales_reps', meta, - Column('sales_rep_id', Integer, - Sequence('sales_rep_id_seq'), - primary_key=True), + Column('sales_rep_id', Integer,primary_key=True, test_needs_autoincrement=True), Column('name', String(50))) Table('accounts', meta, - Column('account_id', Integer, - Sequence('account_id_seq'), - primary_key=True), + Column('account_id', Integer,primary_key=True, test_needs_autoincrement=True), Column('balance', Integer)) Table('customers', meta, - Column('customer_id', Integer, - Sequence('customer_id_seq'), - primary_key=True), + Column('customer_id', Integer,primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), Column('sales_rep_id', Integer, ForeignKey('sales_reps.sales_rep_id')), @@ -1087,19 +1066,19 @@ class DoubleParentOrphanTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('addresses', metadata, - Column('address_id', Integer, primary_key=True), + Column('address_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('street', String(30)), ) Table('homes', metadata, - Column('home_id', Integer, primary_key=True, key="id"), + Column('home_id', Integer, primary_key=True, key="id", test_needs_autoincrement=True), Column('description', String(30)), Column('address_id', Integer, ForeignKey('addresses.address_id'), nullable=False), ) Table('businesses', metadata, - Column('business_id', Integer, primary_key=True, key="id"), + Column('business_id', Integer, primary_key=True, key="id", test_needs_autoincrement=True), Column('description', String(30), key="description"), Column('address_id', Integer, ForeignKey('addresses.address_id'), nullable=False), @@ -1159,10 +1138,10 @@ class CollectionAssignmentOrphanTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('table_a', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(30))) Table('table_b', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(30)), Column('a_id', Integer, ForeignKey('table_a.id'))) @@ -1208,12 +1187,12 @@ class PartialFlushTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table("base", metadata, - Column("id", Integer, primary_key=True), + Column("id", Integer, primary_key=True, test_needs_autoincrement=True), Column("descr", String(50)) ) Table("noninh_child", metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('base_id', Integer, ForeignKey('base.id')) ) diff --git a/test/orm/test_collection.py b/test/orm/test_collection.py index 12ff25c460..3d1b30bc9c 100644 --- a/test/orm/test_collection.py +++ b/test/orm/test_collection.py @@ -8,12 +8,11 @@ from sqlalchemy.orm.collections import collection import sqlalchemy as sa from sqlalchemy.test import testing from sqlalchemy import Integer, String, ForeignKey -from sqlalchemy.test.schema import Table -from sqlalchemy.test.schema import Column +from sqlalchemy.test.schema import Table, Column from sqlalchemy import util, exc as sa_exc -from sqlalchemy.orm import create_session, mapper, relation, attributes +from sqlalchemy.orm import create_session, mapper, relation, attributes from test.orm import _base -from sqlalchemy.test.testing import eq_ +from sqlalchemy.test.testing import eq_, assert_raises class Canary(sa.orm.interfaces.AttributeExtension): def __init__(self): @@ -169,6 +168,13 @@ class CollectionsTest(_base.ORMTest): control[slice(0,-1)] = values assert_eq() + values = [creator(),creator(),creator()] + control[:] = values + direct[:] = values + def invalid(): + direct[slice(0, 6, 2)] = [creator()] + assert_raises(ValueError, invalid) + if hasattr(direct, '__delitem__'): e = creator() direct.append(e) @@ -193,7 +199,7 @@ class CollectionsTest(_base.ORMTest): del direct[::2] del control[::2] assert_eq() - + if hasattr(direct, 'remove'): e = creator() direct.append(e) @@ -202,8 +208,21 @@ class CollectionsTest(_base.ORMTest): direct.remove(e) control.remove(e) assert_eq() - - if hasattr(direct, '__setslice__'): + + if hasattr(direct, '__setitem__') or hasattr(direct, '__setslice__'): + + values = [creator(), creator()] + direct[:] = values + control[:] = values + assert_eq() + + # test slice assignment where + # slice size goes over the number of items + values = [creator(), creator()] + direct[1:3] = values + control[1:3] = values + assert_eq() + values = [creator(), creator()] direct[0:1] = values control[0:1] = values @@ -228,8 +247,19 @@ class CollectionsTest(_base.ORMTest): direct[1::2] = values control[1::2] = values assert_eq() + + values = [creator(), creator()] + direct[-1:-3] = values + control[-1:-3] = values + assert_eq() - if hasattr(direct, '__delslice__'): + values = [creator(), creator()] + direct[-2:-1] = values + control[-2:-1] = values + assert_eq() + + + if hasattr(direct, '__delitem__') or hasattr(direct, '__delslice__'): for i in range(1, 4): e = creator() direct.append(e) @@ -246,7 +276,7 @@ class CollectionsTest(_base.ORMTest): del direct[:] del control[:] assert_eq() - + if hasattr(direct, 'extend'): values = [creator(), creator(), creator()] @@ -345,6 +375,45 @@ class CollectionsTest(_base.ORMTest): self._test_list(list) self._test_list_bulk(list) + def test_list_setitem_with_slices(self): + + # this is a "list" that has no __setslice__ + # or __delslice__ methods. The __setitem__ + # and __delitem__ must therefore accept + # slice objects (i.e. as in py3k) + class ListLike(object): + def __init__(self): + self.data = list() + def append(self, item): + self.data.append(item) + def remove(self, item): + self.data.remove(item) + def insert(self, index, item): + self.data.insert(index, item) + def pop(self, index=-1): + return self.data.pop(index) + def extend(self): + assert False + def __len__(self): + return len(self.data) + def __setitem__(self, key, value): + self.data[key] = value + def __getitem__(self, key): + return self.data[key] + def __delitem__(self, key): + del self.data[key] + def __iter__(self): + return iter(self.data) + __hash__ = object.__hash__ + def __eq__(self, other): + return self.data == other + def __repr__(self): + return 'ListLike(%s)' % repr(self.data) + + self._test_adapter(ListLike) + self._test_list(ListLike) + self._test_list_bulk(ListLike) + def test_list_subclass(self): class MyList(list): pass @@ -1343,10 +1412,10 @@ class DictHelpersTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('parents', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('label', String(128))) Table('children', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('parent_id', Integer, ForeignKey('parents.id'), nullable=False), Column('a', String(128)), @@ -1481,12 +1550,12 @@ class DictHelpersTest(_base.MappedTest): class Foo(BaseObject): __tablename__ = "foo" - id = Column(Integer(), primary_key=True) + id = Column(Integer(), primary_key=True, test_needs_autoincrement=True) bar_id = Column(Integer, ForeignKey('bar.id')) class Bar(BaseObject): __tablename__ = "bar" - id = Column(Integer(), primary_key=True) + id = Column(Integer(), primary_key=True, test_needs_autoincrement=True) foos = relation(Foo, collection_class=collections.column_mapped_collection(Foo.id)) foos2 = relation(Foo, collection_class=collections.column_mapped_collection((Foo.id, Foo.bar_id))) @@ -1521,17 +1590,16 @@ class DictHelpersTest(_base.MappedTest): collection_class = lambda: Ordered2(lambda v: (v.a, v.b)) self._test_composite_mapped(collection_class) -# TODO: are these tests redundant vs. the above tests ? -# remove if so class CustomCollectionsTest(_base.MappedTest): + """test the integration of collections with mapped classes.""" @classmethod def define_tables(cls, metadata): Table('sometable', metadata, - Column('col1',Integer, primary_key=True), + Column('col1',Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30))) Table('someothertable', metadata, - Column('col1', Integer, primary_key=True), + Column('col1', Integer, primary_key=True, test_needs_autoincrement=True), Column('scol1', Integer, ForeignKey('sometable.col1')), Column('data', String(20))) @@ -1646,15 +1714,50 @@ class CustomCollectionsTest(_base.MappedTest): replaced = set([id(b) for b in f.bars.values()]) self.assert_(existing != replaced) - @testing.resolve_artifact_names def test_list(self): + self._test_list(list) + + def test_list_no_setslice(self): + class ListLike(object): + def __init__(self): + self.data = list() + def append(self, item): + self.data.append(item) + def remove(self, item): + self.data.remove(item) + def insert(self, index, item): + self.data.insert(index, item) + def pop(self, index=-1): + return self.data.pop(index) + def extend(self): + assert False + def __len__(self): + return len(self.data) + def __setitem__(self, key, value): + self.data[key] = value + def __getitem__(self, key): + return self.data[key] + def __delitem__(self, key): + del self.data[key] + def __iter__(self): + return iter(self.data) + __hash__ = object.__hash__ + def __eq__(self, other): + return self.data == other + def __repr__(self): + return 'ListLike(%s)' % repr(self.data) + + self._test_list(ListLike) + + @testing.resolve_artifact_names + def _test_list(self, listcls): class Parent(object): pass class Child(object): pass mapper(Parent, sometable, properties={ - 'children':relation(Child, collection_class=list) + 'children':relation(Child, collection_class=listcls) }) mapper(Child, someothertable) diff --git a/test/orm/test_cycles.py b/test/orm/test_cycles.py index fe77b36018..6fbfe7fe18 100644 --- a/test/orm/test_cycles.py +++ b/test/orm/test_cycles.py @@ -7,8 +7,7 @@ T1/T2. """ from sqlalchemy.test import testing from sqlalchemy import Integer, String, ForeignKey -from sqlalchemy.test.schema import Table -from sqlalchemy.test.schema import Column +from sqlalchemy.test.schema import Table, Column from sqlalchemy.orm import mapper, relation, backref, create_session from sqlalchemy.test.testing import eq_ from sqlalchemy.test.assertsql import RegexSQL, ExactSQL, CompiledSQL, AllOf @@ -138,7 +137,7 @@ class SelfReferentialNoPKTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('item', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('uuid', String(32), unique=True, nullable=False), Column('parent_uuid', String(32), ForeignKey('item.uuid'), nullable=True)) @@ -190,18 +189,16 @@ class InheritTestOne(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table("parent", metadata, - Column("id", Integer, primary_key=True), + Column("id", Integer, primary_key=True, test_needs_autoincrement=True), Column("parent_data", String(50)), Column("type", String(10))) Table("child1", metadata, - Column("id", Integer, ForeignKey("parent.id"), - primary_key=True), + Column("id", Integer, ForeignKey("parent.id"), primary_key=True), Column("child1_data", String(50))) Table("child2", metadata, - Column("id", Integer, ForeignKey("parent.id"), - primary_key=True), + Column("id", Integer, ForeignKey("parent.id"), primary_key=True), Column("child1_id", Integer, ForeignKey("child1.id"), nullable=False), Column("child2_data", String(50))) @@ -262,7 +259,7 @@ class InheritTestTwo(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('a', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30)), Column('cid', Integer, ForeignKey('c.id'))) @@ -271,7 +268,7 @@ class InheritTestTwo(_base.MappedTest): Column('data', String(30))) Table('c', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30)), Column('aid', Integer, ForeignKey('a.id', use_alter=True, name="foo"))) @@ -311,16 +308,16 @@ class BiDirectionalManyToOneTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('t1', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30)), Column('t2id', Integer, ForeignKey('t2.id'))) Table('t2', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30)), Column('t1id', Integer, ForeignKey('t1.id', use_alter=True, name="foo_fk"))) Table('t3', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30)), Column('t1id', Integer, ForeignKey('t1.id'), nullable=False), Column('t2id', Integer, ForeignKey('t2.id'), nullable=False)) @@ -402,13 +399,11 @@ class BiDirectionalOneToManyTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('t1', metadata, - Column('c1', Integer, primary_key=True, - test_needs_autoincrement=True), + Column('c1', Integer, primary_key=True, test_needs_autoincrement=True), Column('c2', Integer, ForeignKey('t2.c1'))) Table('t2', metadata, - Column('c1', Integer, primary_key=True, - test_needs_autoincrement=True), + Column('c1', Integer, primary_key=True, test_needs_autoincrement=True), Column('c2', Integer, ForeignKey('t1.c1', use_alter=True, name='t1c1_fk'))) @@ -453,18 +448,18 @@ class BiDirectionalOneToManyTest2(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('t1', metadata, - Column('c1', Integer, primary_key=True), + Column('c1', Integer, primary_key=True, test_needs_autoincrement=True), Column('c2', Integer, ForeignKey('t2.c1')), test_needs_autoincrement=True) Table('t2', metadata, - Column('c1', Integer, primary_key=True), + Column('c1', Integer, primary_key=True, test_needs_autoincrement=True), Column('c2', Integer, ForeignKey('t1.c1', use_alter=True, name='t1c1_fq')), test_needs_autoincrement=True) Table('t1_data', metadata, - Column('c1', Integer, primary_key=True), + Column('c1', Integer, primary_key=True, test_needs_autoincrement=True), Column('t1id', Integer, ForeignKey('t1.c1')), Column('data', String(20)), test_needs_autoincrement=True) @@ -530,15 +525,13 @@ class OneToManyManyToOneTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('ball', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('person_id', Integer, ForeignKey('person.id', use_alter=True, name='fk_person_id')), Column('data', String(30))) Table('person', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('favorite_ball_id', Integer, ForeignKey('ball.id')), Column('data', String(30))) @@ -841,7 +834,7 @@ class SelfReferentialPostUpdateTest2(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table("a_table", metadata, - Column("id", Integer(), primary_key=True), + Column("id", Integer(), primary_key=True, test_needs_autoincrement=True), Column("fui", String(128)), Column("b", Integer(), ForeignKey("a_table.id"))) diff --git a/test/orm/test_defaults.py b/test/orm/test_defaults.py index b063780ac7..5379c97149 100644 --- a/test/orm/test_defaults.py +++ b/test/orm/test_defaults.py @@ -2,8 +2,7 @@ import sqlalchemy as sa from sqlalchemy.test import testing from sqlalchemy import Integer, String, ForeignKey -from sqlalchemy.test.schema import Table -from sqlalchemy.test.schema import Column +from sqlalchemy.test.schema import Table, Column from sqlalchemy.orm import mapper, relation, create_session from test.orm import _base from sqlalchemy.test.testing import eq_ @@ -15,7 +14,7 @@ class TriggerDefaultsTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): dt = Table('dt', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('col1', String(20)), Column('col2', String(20), server_default=sa.schema.FetchedValue()), @@ -34,17 +33,23 @@ class TriggerDefaultsTest(_base.MappedTest): "UPDATE dt SET col2='ins', col4='ins' " "WHERE dt.id IN (SELECT id FROM inserted);", on='mssql'), - ): - if testing.against(ins.on): - break - else: - ins = sa.DDL("CREATE TRIGGER dt_ins BEFORE INSERT ON dt " + sa.DDL("CREATE TRIGGER dt_ins BEFORE INSERT " + "ON dt " + "FOR EACH ROW " + "BEGIN " + ":NEW.col2 := 'ins'; :NEW.col4 := 'ins'; END;", + on='oracle'), + sa.DDL("CREATE TRIGGER dt_ins BEFORE INSERT ON dt " "FOR EACH ROW BEGIN " - "SET NEW.col2='ins'; SET NEW.col4='ins'; END") - ins.execute_at('after-create', dt) + "SET NEW.col2='ins'; SET NEW.col4='ins'; END", + on=lambda event, schema_item, bind, **kw: + bind.engine.name not in ('oracle', 'mssql', 'sqlite') + ), + ): + ins.execute_at('after-create', dt) + sa.DDL("DROP TRIGGER dt_ins").execute_at('before-drop', dt) - for up in ( sa.DDL("CREATE TRIGGER dt_up AFTER UPDATE ON dt " "FOR EACH ROW BEGIN " @@ -55,14 +60,19 @@ class TriggerDefaultsTest(_base.MappedTest): "UPDATE dt SET col3='up', col4='up' " "WHERE dt.id IN (SELECT id FROM deleted);", on='mssql'), - ): - if testing.against(up.on): - break - else: - up = sa.DDL("CREATE TRIGGER dt_up BEFORE UPDATE ON dt " + sa.DDL("CREATE TRIGGER dt_up BEFORE UPDATE ON dt " + "FOR EACH ROW BEGIN " + ":NEW.col3 := 'up'; :NEW.col4 := 'up'; END;", + on='oracle'), + sa.DDL("CREATE TRIGGER dt_up BEFORE UPDATE ON dt " "FOR EACH ROW BEGIN " - "SET NEW.col3='up'; SET NEW.col4='up'; END") - up.execute_at('after-create', dt) + "SET NEW.col3='up'; SET NEW.col4='up'; END", + on=lambda event, schema_item, bind, **kw: + bind.engine.name not in ('oracle', 'mssql', 'sqlite') + ), + ): + up.execute_at('after-create', dt) + sa.DDL("DROP TRIGGER dt_up").execute_at('before-drop', dt) @@ -115,7 +125,7 @@ class ExcludedDefaultsTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): dt = Table('dt', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('col1', String(20), default="hello"), ) diff --git a/test/orm/test_dynamic.py b/test/orm/test_dynamic.py index f2089a4351..23a5fc8762 100644 --- a/test/orm/test_dynamic.py +++ b/test/orm/test_dynamic.py @@ -3,8 +3,7 @@ import operator from sqlalchemy.orm import dynamic_loader, backref from sqlalchemy.test import testing from sqlalchemy import Integer, String, ForeignKey, desc, select, func -from sqlalchemy.test.schema import Table -from sqlalchemy.test.schema import Column +from sqlalchemy.test.schema import Table, Column from sqlalchemy.orm import mapper, relation, create_session, Query, attributes from sqlalchemy.orm.dynamic import AppenderMixin from sqlalchemy.test.testing import eq_ @@ -344,7 +343,8 @@ class SessionTest(_fixtures.FixtureTest): sess.flush() sess.commit() u1.addresses.append(Address(email_address='foo@bar.com')) - eq_(u1.addresses.all(), [Address(email_address='lala@hoho.com'), Address(email_address='foo@bar.com')]) + eq_(u1.addresses.order_by(Address.id).all(), + [Address(email_address='lala@hoho.com'), Address(email_address='foo@bar.com')]) sess.rollback() eq_(u1.addresses.all(), [Address(email_address='lala@hoho.com')]) @@ -502,13 +502,13 @@ class DontDereferenceTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('users', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(40)), Column('fullname', String(100)), Column('password', String(15))) Table('addresses', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('email_address', String(100), nullable=False), Column('user_id', Integer, ForeignKey('users.id'))) diff --git a/test/orm/test_eager_relations.py b/test/orm/test_eager_relations.py index 384e0472f6..425c08c610 100644 --- a/test/orm/test_eager_relations.py +++ b/test/orm/test_eager_relations.py @@ -5,8 +5,7 @@ import sqlalchemy as sa from sqlalchemy.test import testing from sqlalchemy.orm import eagerload, deferred, undefer from sqlalchemy import Integer, String, Date, ForeignKey, and_, select, func -from sqlalchemy.test.schema import Table -from sqlalchemy.test.schema import Column +from sqlalchemy.test.schema import Table, Column from sqlalchemy.orm import mapper, relation, create_session, lazyload, aliased from sqlalchemy.test.testing import eq_ from sqlalchemy.test.assertsql import CompiledSQL @@ -459,20 +458,14 @@ class EagerTest(_fixtures.FixtureTest): }) mapper(User, users, properties={ 'addresses':relation(mapper(Address, addresses), lazy=False, order_by=addresses.c.id), - 'orders':relation(Order, lazy=True) + 'orders':relation(Order, lazy=True, order_by=orders.c.id) }) sess = create_session() q = sess.query(User) - if testing.against('mysql'): - l = q.limit(2).all() - assert self.static.user_all_result[:2] == l - else: - l = q.order_by(User.id).limit(2).offset(1).all() - print self.static.user_all_result[1:3] - print l - assert self.static.user_all_result[1:3] == l + l = q.order_by(User.id).limit(2).offset(1).all() + eq_(self.static.user_all_result[1:3], l) @testing.resolve_artifact_names def test_distinct(self): @@ -483,15 +476,15 @@ class EagerTest(_fixtures.FixtureTest): s = sa.union_all(u2.select(use_labels=True), u2.select(use_labels=True), u2.select(use_labels=True)).alias('u') mapper(User, users, properties={ - 'addresses':relation(mapper(Address, addresses), lazy=False), + 'addresses':relation(mapper(Address, addresses), lazy=False, order_by=addresses.c.id), }) sess = create_session() q = sess.query(User) def go(): - l = q.filter(s.c.u2_id==User.id).distinct().all() - assert self.static.user_address_result == l + l = q.filter(s.c.u2_id==User.id).distinct().order_by(User.id).all() + eq_(self.static.user_address_result, l) self.assert_sql_count(testing.db, go, 1) @testing.fails_on('maxdb', 'FIXME: unknown') @@ -656,9 +649,12 @@ class EagerTest(_fixtures.FixtureTest): mapper(Order, orders) mapper(User, users, properties={ - 'orders':relation(Order, backref='user', lazy=False), - 'max_order':relation(mapper(Order, max_orders, non_primary=True), lazy=False, uselist=False) + 'orders':relation(Order, backref='user', lazy=False, order_by=orders.c.id), + 'max_order':relation( + mapper(Order, max_orders, non_primary=True), + lazy=False, uselist=False) }) + q = create_session().query(User) def go(): @@ -675,7 +671,7 @@ class EagerTest(_fixtures.FixtureTest): max_order=Order(id=4) ), User(id=10), - ] == q.all() + ] == q.order_by(User.id).all() self.assert_sql_count(testing.db, go, 1) @testing.resolve_artifact_names @@ -823,15 +819,15 @@ class OrderBySecondaryTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('m2m', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('aid', Integer, ForeignKey('a.id')), Column('bid', Integer, ForeignKey('b.id'))) Table('a', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(50))) Table('b', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(50))) @classmethod @@ -873,8 +869,7 @@ class SelfReferentialEagerTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('nodes', metadata, - Column('id', Integer, sa.Sequence('node_id_seq', optional=True), - primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('parent_id', Integer, ForeignKey('nodes.id')), Column('data', String(30))) @@ -1088,11 +1083,11 @@ class MixedSelfReferentialEagerTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('a_table', metadata, - Column('id', Integer, primary_key=True) + Column('id', Integer, primary_key=True, test_needs_autoincrement=True) ) Table('b_table', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('parent_b1_id', Integer, ForeignKey('b_table.id')), Column('parent_a_id', Integer, ForeignKey('a_table.id')), Column('parent_b2_id', Integer, ForeignKey('b_table.id'))) @@ -1161,7 +1156,7 @@ class SelfReferentialM2MEagerTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('widget', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', sa.Unicode(40), nullable=False, unique=True), ) @@ -1244,7 +1239,7 @@ class MixedEntitiesTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) self.assert_sql_count(testing.db, go, 1) - @testing.exclude('sqlite', '>', (0, 0, 0), "sqlite flat out blows it on the multiple JOINs") + @testing.exclude('sqlite', '>', (0, ), "sqlite flat out blows it on the multiple JOINs") @testing.resolve_artifact_names def test_two_entities_with_joins(self): sess = create_session() @@ -1337,13 +1332,13 @@ class CyclicalInheritingEagerTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('t1', metadata, - Column('c1', Integer, primary_key=True), + Column('c1', Integer, primary_key=True, test_needs_autoincrement=True), Column('c2', String(30)), Column('type', String(30)) ) Table('t2', metadata, - Column('c1', Integer, primary_key=True), + Column('c1', Integer, primary_key=True, test_needs_autoincrement=True), Column('c2', String(30)), Column('type', String(30)), Column('t1.id', Integer, ForeignKey('t1.c1'))) @@ -1376,12 +1371,12 @@ class SubqueryTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('users_table', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(16)) ) Table('tags_table', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('user_id', Integer, ForeignKey("users_table.id")), Column('score1', sa.Float), Column('score2', sa.Float), @@ -1461,16 +1456,20 @@ class CorrelatedSubqueryTest(_base.MappedTest): Exercises a variety of ways to configure this. """ + + # another argument for eagerload learning about inner joins + + __requires__ = ('correlated_outer_joins', ) @classmethod def define_tables(cls, metadata): users = Table('users', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)) ) stuff = Table('stuff', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('date', Date), Column('user_id', Integer, ForeignKey('users.id'))) @@ -1549,11 +1548,13 @@ class CorrelatedSubqueryTest(_base.MappedTest): if ondate: # the more 'relational' way to do this, join on the max date - stuff_view = select([func.max(salias.c.date).label('max_date')]).where(salias.c.user_id==users.c.id).correlate(users) + stuff_view = select([func.max(salias.c.date).label('max_date')]).\ + where(salias.c.user_id==users.c.id).correlate(users) else: # a common method with the MySQL crowd, which actually might perform better in some # cases - subquery does a limit with order by DESC, join on the id - stuff_view = select([salias.c.id]).where(salias.c.user_id==users.c.id).correlate(users).order_by(salias.c.date.desc()).limit(1) + stuff_view = select([salias.c.id]).where(salias.c.user_id==users.c.id).\ + correlate(users).order_by(salias.c.date.desc()).limit(1) if labeled == 'label': stuff_view = stuff_view.label('foo') diff --git a/test/orm/test_expire.py b/test/orm/test_expire.py index 6593498978..c602ac963f 100644 --- a/test/orm/test_expire.py +++ b/test/orm/test_expire.py @@ -1,7 +1,7 @@ """Attribute/instance expiration, deferral of attributes, etc.""" from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message -import gc +from sqlalchemy.test.util import gc_collect import sqlalchemy as sa from sqlalchemy.test import testing from sqlalchemy import Integer, String, ForeignKey, exc as sa_exc @@ -666,7 +666,7 @@ class ExpireTest(_fixtures.FixtureTest): assert self.static.user_address_result == userlist assert len(list(sess)) == 9 sess.expire_all() - gc.collect() + gc_collect() assert len(list(sess)) == 4 # since addresses were gc'ed userlist = sess.query(User).order_by(User.id).all() diff --git a/test/orm/test_generative.py b/test/orm/test_generative.py index 0efc1814ed..8f61d4d148 100644 --- a/test/orm/test_generative.py +++ b/test/orm/test_generative.py @@ -70,12 +70,17 @@ class GenerativeQueryTest(_base.MappedTest): assert sess.query(func.min(foo.c.bar)).filter(foo.c.bar<30).one() == (0,) assert sess.query(func.max(foo.c.bar)).filter(foo.c.bar<30).one() == (29,) + # Py3K + #assert query.filter(foo.c.bar<30).values(sa.func.max(foo.c.bar)).__next__()[0] == 29 + #assert query.filter(foo.c.bar<30).values(sa.func.max(foo.c.bar)).__next__()[0] == 29 + # Py2K assert query.filter(foo.c.bar<30).values(sa.func.max(foo.c.bar)).next()[0] == 29 assert query.filter(foo.c.bar<30).values(sa.func.max(foo.c.bar)).next()[0] == 29 - + # end Py2K + @testing.resolve_artifact_names def test_aggregate_1(self): - if (testing.against('mysql') and + if (testing.against('mysql') and not testing.against('+zxjdbc') and testing.db.dialect.dbapi.version_info[:4] == (1, 2, 1, 'gamma')): return @@ -95,10 +100,18 @@ class GenerativeQueryTest(_base.MappedTest): def test_aggregate_3(self): query = create_session().query(Foo) + # Py3K + #avg_f = query.filter(foo.c.bar<30).values(sa.func.avg(foo.c.bar)).__next__()[0] + # Py2K avg_f = query.filter(foo.c.bar<30).values(sa.func.avg(foo.c.bar)).next()[0] + # end Py2K assert round(avg_f, 1) == 14.5 + # Py3K + #avg_o = query.filter(foo.c.bar<30).values(sa.func.avg(foo.c.bar)).__next__()[0] + # Py2K avg_o = query.filter(foo.c.bar<30).values(sa.func.avg(foo.c.bar)).next()[0] + # end Py2K assert round(avg_o, 1) == 14.5 @testing.resolve_artifact_names diff --git a/test/orm/test_instrumentation.py b/test/orm/test_instrumentation.py index b4c8f8601c..6390e25963 100644 --- a/test/orm/test_instrumentation.py +++ b/test/orm/test_instrumentation.py @@ -488,7 +488,7 @@ class InstrumentationCollisionTest(_base.ORMTest): class B(A): __sa_instrumentation_manager__ = staticmethod(mgr_factory) - assert_raises(TypeError, attributes.register_class, B) + assert_raises_message(TypeError, "multiple instrumentation implementations", attributes.register_class, B) def test_single_up(self): @@ -499,7 +499,8 @@ class InstrumentationCollisionTest(_base.ORMTest): class B(A): __sa_instrumentation_manager__ = staticmethod(mgr_factory) attributes.register_class(B) - assert_raises(TypeError, attributes.register_class, A) + + assert_raises_message(TypeError, "multiple instrumentation implementations", attributes.register_class, A) def test_diamond_b1(self): mgr_factory = lambda cls: attributes.ClassManager(cls) @@ -507,10 +508,10 @@ class InstrumentationCollisionTest(_base.ORMTest): class A(object): pass class B1(A): pass class B2(A): - __sa_instrumentation_manager__ = mgr_factory + __sa_instrumentation_manager__ = staticmethod(mgr_factory) class C(object): pass - assert_raises(TypeError, attributes.register_class, B1) + assert_raises_message(TypeError, "multiple instrumentation implementations", attributes.register_class, B1) def test_diamond_b2(self): mgr_factory = lambda cls: attributes.ClassManager(cls) @@ -518,10 +519,11 @@ class InstrumentationCollisionTest(_base.ORMTest): class A(object): pass class B1(A): pass class B2(A): - __sa_instrumentation_manager__ = mgr_factory + __sa_instrumentation_manager__ = staticmethod(mgr_factory) class C(object): pass - assert_raises(TypeError, attributes.register_class, B2) + attributes.register_class(B2) + assert_raises_message(TypeError, "multiple instrumentation implementations", attributes.register_class, B1) def test_diamond_c_b(self): mgr_factory = lambda cls: attributes.ClassManager(cls) @@ -529,12 +531,12 @@ class InstrumentationCollisionTest(_base.ORMTest): class A(object): pass class B1(A): pass class B2(A): - __sa_instrumentation_manager__ = mgr_factory + __sa_instrumentation_manager__ = staticmethod(mgr_factory) class C(object): pass attributes.register_class(C) - assert_raises(TypeError, attributes.register_class, B1) + assert_raises_message(TypeError, "multiple instrumentation implementations", attributes.register_class, B1) class OnLoadTest(_base.ORMTest): """Check that Events.on_load is not hit in regular attributes operations.""" diff --git a/test/orm/test_lazy_relations.py b/test/orm/test_lazy_relations.py index 819f29911e..8c196cfcfb 100644 --- a/test/orm/test_lazy_relations.py +++ b/test/orm/test_lazy_relations.py @@ -163,9 +163,8 @@ class LazyTest(_fixtures.FixtureTest): # use a union all to get a lot of rows to join against u2 = users.alias('u2') s = sa.union_all(u2.select(use_labels=True), u2.select(use_labels=True), u2.select(use_labels=True)).alias('u') - print [key for key in s.c.keys()] - l = q.filter(s.c.u2_id==User.id).distinct().all() - assert self.static.user_all_result == l + l = q.filter(s.c.u2_id==User.id).order_by(User.id).distinct().all() + eq_(self.static.user_all_result, l) @testing.resolve_artifact_names def test_one_to_many_scalar(self): diff --git a/test/orm/test_mapper.py b/test/orm/test_mapper.py index 13913578a5..c34ccdbab8 100644 --- a/test/orm/test_mapper.py +++ b/test/orm/test_mapper.py @@ -3,9 +3,8 @@ from sqlalchemy.test.testing import assert_raises, assert_raises_message import sqlalchemy as sa from sqlalchemy.test import testing, pickleable -from sqlalchemy import MetaData, Integer, String, ForeignKey, func -from sqlalchemy.test.schema import Table -from sqlalchemy.test.schema import Column +from sqlalchemy import MetaData, Integer, String, ForeignKey, func, util +from sqlalchemy.test.schema import Table, Column from sqlalchemy.engine import default from sqlalchemy.orm import mapper, relation, backref, create_session, class_mapper, compile_mappers, reconstructor, validates, aliased from sqlalchemy.orm import defer, deferred, synonym, attributes, column_property, composite, relation, dynamic_loader, comparable_property @@ -390,7 +389,7 @@ class MapperTest(_fixtures.FixtureTest): @testing.resolve_artifact_names def test_self_ref_synonym(self): t = Table('nodes', MetaData(), - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('parent_id', Integer, ForeignKey('nodes.id'))) class Node(object): @@ -432,7 +431,7 @@ class MapperTest(_fixtures.FixtureTest): @testing.resolve_artifact_names def test_prop_filters(self): t = Table('person', MetaData(), - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('type', String(128)), Column('name', String(128)), Column('employee_number', Integer), @@ -870,6 +869,7 @@ class MapperTest(_fixtures.FixtureTest): @testing.resolve_artifact_names def test_comparable_column(self): class MyComparator(sa.orm.properties.ColumnProperty.Comparator): + __hash__ = None def __eq__(self, other): # lower case comparison return func.lower(self.__clause_element__()) == func.lower(other) @@ -1451,12 +1451,12 @@ class DeferredTest(_fixtures.FixtureTest): @testing.resolve_artifact_names def test_group(self): """Deferred load with a group""" - mapper(Order, orders, properties={ - 'userident': deferred(orders.c.user_id, group='primary'), - 'addrident': deferred(orders.c.address_id, group='primary'), - 'description': deferred(orders.c.description, group='primary'), - 'opened': deferred(orders.c.isopen, group='primary') - }) + mapper(Order, orders, properties=util.OrderedDict([ + ('userident', deferred(orders.c.user_id, group='primary')), + ('addrident', deferred(orders.c.address_id, group='primary')), + ('description', deferred(orders.c.description, group='primary')), + ('opened', deferred(orders.c.isopen, group='primary')) + ])) sess = create_session() q = sess.query(Order).order_by(Order.id) @@ -1562,10 +1562,12 @@ class DeferredTest(_fixtures.FixtureTest): @testing.resolve_artifact_names def test_undefer_group(self): - mapper(Order, orders, properties={ - 'userident':deferred(orders.c.user_id, group='primary'), - 'description':deferred(orders.c.description, group='primary'), - 'opened':deferred(orders.c.isopen, group='primary')}) + mapper(Order, orders, properties=util.OrderedDict([ + ('userident',deferred(orders.c.user_id, group='primary')), + ('description',deferred(orders.c.description, group='primary')), + ('opened',deferred(orders.c.isopen, group='primary')) + ] + )) sess = create_session() q = sess.query(Order).order_by(Order.id) @@ -1796,11 +1798,11 @@ class DeferredPopulationTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table("thing", metadata, - Column("id", Integer, primary_key=True), + Column("id", Integer, primary_key=True, test_needs_autoincrement=True), Column("name", String(20))) Table("human", metadata, - Column("id", Integer, primary_key=True), + Column("id", Integer, primary_key=True, test_needs_autoincrement=True), Column("thing_id", Integer, ForeignKey("thing.id")), Column("name", String(20))) @@ -1884,13 +1886,12 @@ class CompositeTypesTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('graphs', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('version_id', Integer, primary_key=True, nullable=True), Column('name', String(30))) Table('edges', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('graph_id', Integer, nullable=False), Column('graph_version_id', Integer, nullable=False), Column('x1', Integer), @@ -1902,7 +1903,7 @@ class CompositeTypesTest(_base.MappedTest): ['graphs.id', 'graphs.version_id'])) Table('foobars', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('x1', Integer, default=2), Column('x2', Integer), Column('x3', Integer, default=15), @@ -2041,7 +2042,7 @@ class CompositeTypesTest(_base.MappedTest): # test pk with one column NULL # TODO: can't seem to get NULL in for a PK value - # in either mysql or postgres, autoincrement=False etc. + # in either mysql or postgresql, autoincrement=False etc. # notwithstanding @testing.fails_on_everything_except("sqlite") def go(): @@ -2475,33 +2476,26 @@ class RequirementsTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('ht1', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('value', String(10))) Table('ht2', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('ht1_id', Integer, ForeignKey('ht1.id')), Column('value', String(10))) Table('ht3', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('value', String(10))) Table('ht4', metadata, - Column('ht1_id', Integer, ForeignKey('ht1.id'), - primary_key=True), - Column('ht3_id', Integer, ForeignKey('ht3.id'), - primary_key=True)) + Column('ht1_id', Integer, ForeignKey('ht1.id'), primary_key=True), + Column('ht3_id', Integer, ForeignKey('ht3.id'), primary_key=True)) Table('ht5', metadata, - Column('ht1_id', Integer, ForeignKey('ht1.id'), - primary_key=True)) + Column('ht1_id', Integer, ForeignKey('ht1.id'), primary_key=True)) Table('ht6', metadata, - Column('ht1a_id', Integer, ForeignKey('ht1.id'), - primary_key=True), - Column('ht1b_id', Integer, ForeignKey('ht1.id'), - primary_key=True), + Column('ht1a_id', Integer, ForeignKey('ht1.id'), primary_key=True), + Column('ht1b_id', Integer, ForeignKey('ht1.id'), primary_key=True), Column('value', String(10))) + # Py2K @testing.resolve_artifact_names def test_baseclass(self): class OldStyle: @@ -2516,7 +2510,8 @@ class RequirementsTest(_base.MappedTest): # TODO: is weakref support detectable without an instance? #self.assertRaises(sa.exc.ArgumentError, mapper, NoWeakrefSupport, t2) - + # end Py2K + @testing.resolve_artifact_names def test_comparison_overrides(self): """Simple tests to ensure users can supply comparison __methods__. @@ -2618,12 +2613,12 @@ class MagicNamesTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('cartographers', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), Column('alias', String(50)), Column('quip', String(100))) Table('maps', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('cart_id', Integer, ForeignKey('cartographers.id')), Column('state', String(2)), @@ -2665,7 +2660,7 @@ class MagicNamesTest(_base.MappedTest): for reserved in (sa.orm.attributes.ClassManager.STATE_ATTR, sa.orm.attributes.ClassManager.MANAGER_ATTR): t = Table('t', sa.MetaData(), - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column(reserved, Integer)) class T(object): pass diff --git a/test/orm/test_merge.py b/test/orm/test_merge.py index f4e3872b06..5433515caa 100644 --- a/test/orm/test_merge.py +++ b/test/orm/test_merge.py @@ -1,13 +1,13 @@ from sqlalchemy.test.testing import assert_raises, assert_raises_message import sqlalchemy as sa -from sqlalchemy import Table, Column, Integer, PickleType +from sqlalchemy import Integer, PickleType import operator from sqlalchemy.test import testing from sqlalchemy.util import OrderedSet from sqlalchemy.orm import mapper, relation, create_session, PropComparator, synonym, comparable_property, sessionmaker from sqlalchemy.test.testing import eq_, ne_ from test.orm import _base, _fixtures - +from sqlalchemy.test.schema import Table, Column class MergeTest(_fixtures.FixtureTest): """Session.merge() functionality""" @@ -103,6 +103,7 @@ class MergeTest(_fixtures.FixtureTest): 'addresses':relation(Address, backref='user', collection_class=OrderedSet, + order_by=addresses.c.id, cascade="all, delete-orphan") }) mapper(Address, addresses) @@ -154,6 +155,7 @@ class MergeTest(_fixtures.FixtureTest): mapper(User, users, properties={ 'addresses':relation(Address, backref='user', + order_by=addresses.c.id, collection_class=OrderedSet)}) mapper(Address, addresses) on_load = self.on_load_tracker(User) @@ -300,20 +302,20 @@ class MergeTest(_fixtures.FixtureTest): # test with "dontload" merge sess5 = create_session() - u = sess5.merge(u, dont_load=True) + u = sess5.merge(u, load=False) assert len(u.addresses) for a in u.addresses: assert a.user is u def go(): sess5.flush() # no changes; therefore flush should do nothing - # but also, dont_load wipes out any difference in committed state, + # but also, load=False wipes out any difference in committed state, # so no flush at all self.assert_sql_count(testing.db, go, 0) eq_(on_load.called, 15) sess4 = create_session() - u = sess4.merge(u, dont_load=True) + u = sess4.merge(u, load=False) # post merge change u.addresses[1].email_address='afafds' def go(): @@ -445,17 +447,35 @@ class MergeTest(_fixtures.FixtureTest): assert u3 is u @testing.resolve_artifact_names - def test_transient_dontload(self): + def test_transient_no_load(self): mapper(User, users) sess = create_session() u = User() - assert_raises_message(sa.exc.InvalidRequestError, "dont_load=True option does not support", sess.merge, u, dont_load=True) + assert_raises_message(sa.exc.InvalidRequestError, "load=False option does not support", sess.merge, u, load=False) + @testing.resolve_artifact_names + def test_dont_load_deprecated(self): + mapper(User, users) + + sess = create_session() + u = User(name='ed') + sess.add(u) + sess.flush() + u = sess.query(User).first() + sess.expunge(u) + sess.execute(users.update().values(name='jack')) + @testing.uses_deprecated("dont_load=True has been renamed") + def go(): + u1 = sess.merge(u, dont_load=True) + assert u1 in sess + assert u1.name=='ed' + assert u1 not in sess.dirty + go() @testing.resolve_artifact_names - def test_dontload_with_backrefs(self): - """dontload populates relations in both directions without requiring a load""" + def test_no_load_with_backrefs(self): + """load=False populates relations in both directions without requiring a load""" mapper(User, users, properties={ 'addresses':relation(mapper(Address, addresses), backref='user') }) @@ -470,7 +490,7 @@ class MergeTest(_fixtures.FixtureTest): assert 'user' in u.addresses[1].__dict__ sess = create_session() - u2 = sess.merge(u, dont_load=True) + u2 = sess.merge(u, load=False) assert 'user' in u2.addresses[1].__dict__ eq_(u2.addresses[1].user, User(id=7, name='fred')) @@ -479,7 +499,7 @@ class MergeTest(_fixtures.FixtureTest): sess.close() sess = create_session() - u = sess.merge(u2, dont_load=True) + u = sess.merge(u2, load=False) assert 'user' not in u.addresses[1].__dict__ eq_(u.addresses[1].user, User(id=7, name='fred')) @@ -488,12 +508,12 @@ class MergeTest(_fixtures.FixtureTest): def test_dontload_with_eager(self): """ - This test illustrates that with dont_load=True, we can't just copy the + This test illustrates that with load=False, we can't just copy the committed_state of the merged instance over; since it references collection objects which themselves are to be merged. This committed_state would instead need to be piecemeal 'converted' to represent the correct objects. However, at the moment I'd rather not - support this use case; if you are merging with dont_load=True, you're + support this use case; if you are merging with load=False, you're typically dealing with caching and the merged objects shouldnt be 'dirty'. @@ -516,16 +536,16 @@ class MergeTest(_fixtures.FixtureTest): u2 = sess2.query(User).options(sa.orm.eagerload('addresses')).get(7) sess3 = create_session() - u3 = sess3.merge(u2, dont_load=True) + u3 = sess3.merge(u2, load=False) def go(): sess3.flush() self.assert_sql_count(testing.db, go, 0) @testing.resolve_artifact_names - def test_dont_load_disallows_dirty(self): - """dont_load doesnt support 'dirty' objects right now + def test_no_load_disallows_dirty(self): + """load=False doesnt support 'dirty' objects right now - (see test_dont_load_with_eager()). Therefore lets assert it. + (see test_no_load_with_eager()). Therefore lets assert it. """ mapper(User, users) @@ -539,17 +559,17 @@ class MergeTest(_fixtures.FixtureTest): u.name = 'ed' sess2 = create_session() try: - sess2.merge(u, dont_load=True) + sess2.merge(u, load=False) assert False except sa.exc.InvalidRequestError, e: - assert ("merge() with dont_load=True option does not support " + assert ("merge() with load=False option does not support " "objects marked as 'dirty'. flush() all changes on mapped " - "instances before merging with dont_load=True.") in str(e) + "instances before merging with load=False.") in str(e) u2 = sess2.query(User).get(7) sess3 = create_session() - u3 = sess3.merge(u2, dont_load=True) + u3 = sess3.merge(u2, load=False) assert not sess3.dirty def go(): sess3.flush() @@ -557,7 +577,7 @@ class MergeTest(_fixtures.FixtureTest): @testing.resolve_artifact_names - def test_dont_load_sets_backrefs(self): + def test_no_load_sets_backrefs(self): mapper(User, users, properties={ 'addresses':relation(mapper(Address, addresses),backref='user')}) @@ -575,17 +595,17 @@ class MergeTest(_fixtures.FixtureTest): assert u.addresses[0].user is u sess2 = create_session() - u2 = sess2.merge(u, dont_load=True) + u2 = sess2.merge(u, load=False) assert not sess2.dirty def go(): assert u2.addresses[0].user is u2 self.assert_sql_count(testing.db, go, 0) @testing.resolve_artifact_names - def test_dont_load_preserves_parents(self): - """Merge with dont_load does not trigger a 'delete-orphan' operation. + def test_no_load_preserves_parents(self): + """Merge with load=False does not trigger a 'delete-orphan' operation. - merge with dont_load sets attributes without using events. this means + merge with load=False sets attributes without using events. this means the 'hasparent' flag is not propagated to the newly merged instance. in fact this works out OK, because the '_state.parents' collection on the newly merged instance is empty; since the mapper doesn't see an @@ -610,7 +630,7 @@ class MergeTest(_fixtures.FixtureTest): assert u.addresses[0].user is u sess2 = create_session() - u2 = sess2.merge(u, dont_load=True) + u2 = sess2.merge(u, load=False) assert not sess2.dirty a2 = u2.addresses[0] a2.email_address='somenewaddress' @@ -624,19 +644,19 @@ class MergeTest(_fixtures.FixtureTest): # this use case is not supported; this is with a pending Address on # the pre-merged object, and we currently dont support 'dirty' objects - # being merged with dont_load=True. in this case, the empty + # being merged with load=False. in this case, the empty # '_state.parents' collection would be an issue, since the optimistic # flag is False in _is_orphan() for pending instances. so if we start - # supporting 'dirty' with dont_load=True, this test will need to pass + # supporting 'dirty' with load=False, this test will need to pass sess = create_session() u = sess.query(User).get(7) u.addresses.append(Address()) sess2 = create_session() try: - u2 = sess2.merge(u, dont_load=True) + u2 = sess2.merge(u, load=False) assert False - # if dont_load is changed to support dirty objects, this code + # if load=False is changed to support dirty objects, this code # needs to pass a2 = u2.addresses[0] a2.email_address='somenewaddress' @@ -647,7 +667,7 @@ class MergeTest(_fixtures.FixtureTest): eq_(sess2.query(User).get(u2.id).addresses[0].email_address, 'somenewaddress') except sa.exc.InvalidRequestError, e: - assert "dont_load=True option does not support" in str(e) + assert "load=False option does not support" in str(e) @testing.resolve_artifact_names def test_synonym_comparable(self): @@ -737,7 +757,7 @@ class MutableMergeTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table("data", metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', PickleType(comparator=operator.eq)) ) diff --git a/test/orm/test_naturalpks.py b/test/orm/test_naturalpks.py index 1376c402e7..e99bfb794b 100644 --- a/test/orm/test_naturalpks.py +++ b/test/orm/test_naturalpks.py @@ -6,8 +6,7 @@ from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message import sqlalchemy as sa from sqlalchemy.test import testing from sqlalchemy import Integer, String, ForeignKey -from sqlalchemy.test.schema import Table -from sqlalchemy.test.schema import Column +from sqlalchemy.test.schema import Table, Column from sqlalchemy.orm import mapper, relation, create_session from sqlalchemy.test.testing import eq_ from test.orm import _base @@ -16,6 +15,11 @@ class NaturalPKTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): + if testing.against('oracle'): + fk_args = dict(deferrable=True, initially='deferred') + else: + fk_args = dict(onupdate='cascade') + users = Table('users', metadata, Column('username', String(50), primary_key=True), Column('fullname', String(100)), @@ -23,7 +27,7 @@ class NaturalPKTest(_base.MappedTest): addresses = Table('addresses', metadata, Column('email', String(50), primary_key=True), - Column('username', String(50), ForeignKey('users.username', onupdate="cascade")), + Column('username', String(50), ForeignKey('users.username', **fk_args)), test_needs_fk=True) items = Table('items', metadata, @@ -32,8 +36,8 @@ class NaturalPKTest(_base.MappedTest): test_needs_fk=True) users_to_items = Table('users_to_items', metadata, - Column('username', String(50), ForeignKey('users.username', onupdate='cascade'), primary_key=True), - Column('itemname', String(50), ForeignKey('items.itemname', onupdate='cascade'), primary_key=True), + Column('username', String(50), ForeignKey('users.username', **fk_args), primary_key=True), + Column('itemname', String(50), ForeignKey('items.itemname', **fk_args), primary_key=True), test_needs_fk=True) @classmethod @@ -110,6 +114,7 @@ class NaturalPKTest(_base.MappedTest): @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE') + @testing.fails_on('oracle', 'oracle doesnt support ON UPDATE CASCADE') def test_onetomany_passive(self): self._test_onetomany(True) @@ -161,6 +166,7 @@ class NaturalPKTest(_base.MappedTest): @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE') + @testing.fails_on('oracle', 'oracle doesnt support ON UPDATE CASCADE') def test_manytoone_passive(self): self._test_manytoone(True) @@ -203,6 +209,7 @@ class NaturalPKTest(_base.MappedTest): eq_([Address(username='ed'), Address(username='ed')], sess.query(Address).all()) @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE') + @testing.fails_on('oracle', 'oracle doesnt support ON UPDATE CASCADE') def test_onetoone_passive(self): self._test_onetoone(True) @@ -244,6 +251,7 @@ class NaturalPKTest(_base.MappedTest): eq_([Address(username='ed')], sess.query(Address).all()) @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE') + @testing.fails_on('oracle', 'oracle doesnt support ON UPDATE CASCADE') def test_bidirectional_passive(self): self._test_bidirectional(True) @@ -298,10 +306,12 @@ class NaturalPKTest(_base.MappedTest): @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE') + @testing.fails_on('oracle', 'oracle doesnt support ON UPDATE CASCADE') def test_manytomany_passive(self): self._test_manytomany(True) - @testing.fails_on('mysql', 'the executemany() of the association table fails to report the correct row count') + # mysqldb executemany() of the association table fails to report the correct row count + @testing.fails_if(lambda: testing.against('mysql') and not testing.against('+zxjdbc')) def test_manytomany_nonpassive(self): self._test_manytomany(False) @@ -361,10 +371,15 @@ class SelfRefTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): + if testing.against('oracle'): + fk_args = dict(deferrable=True, initially='deferred') + else: + fk_args = dict(onupdate='cascade') + Table('nodes', metadata, Column('name', String(50), primary_key=True), Column('parent', String(50), - ForeignKey('nodes.name', onupdate='cascade'))) + ForeignKey('nodes.name', **fk_args))) @classmethod def setup_classes(cls): @@ -400,17 +415,22 @@ class SelfRefTest(_base.MappedTest): class NonPKCascadeTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): + if testing.against('oracle'): + fk_args = dict(deferrable=True, initially='deferred') + else: + fk_args = dict(onupdate='cascade') + Table('users', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('username', String(50), unique=True), Column('fullname', String(100)), test_needs_fk=True) Table('addresses', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('email', String(50)), Column('username', String(50), - ForeignKey('users.username', onupdate="cascade")), + ForeignKey('users.username', **fk_args)), test_needs_fk=True ) @@ -422,6 +442,7 @@ class NonPKCascadeTest(_base.MappedTest): pass @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE') + @testing.fails_on('oracle', 'oracle doesnt support ON UPDATE CASCADE') def test_onetomany_passive(self): self._test_onetomany(True) diff --git a/test/orm/test_onetoone.py b/test/orm/test_onetoone.py index 0d66915ea5..6880f1f747 100644 --- a/test/orm/test_onetoone.py +++ b/test/orm/test_onetoone.py @@ -1,8 +1,7 @@ import sqlalchemy as sa from sqlalchemy.test import testing from sqlalchemy import Integer, String, ForeignKey -from sqlalchemy.test.schema import Table -from sqlalchemy.test.schema import Column +from sqlalchemy.test.schema import Table, Column from sqlalchemy.orm import mapper, relation, create_session from test.orm import _base @@ -11,13 +10,13 @@ class O2OTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('jack', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('number', String(50)), Column('status', String(20)), Column('subroom', String(5))) Table('port', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(30)), Column('description', String(100)), Column('jack_id', Integer, ForeignKey("jack.id"))) diff --git a/test/orm/test_pickled.py b/test/orm/test_pickled.py index 5343cc15b9..6ac9f24701 100644 --- a/test/orm/test_pickled.py +++ b/test/orm/test_pickled.py @@ -3,8 +3,7 @@ import pickle import sqlalchemy as sa from sqlalchemy.test import testing from sqlalchemy import Integer, String, ForeignKey -from sqlalchemy.test.schema import Table -from sqlalchemy.test.schema import Column +from sqlalchemy.test.schema import Table, Column from sqlalchemy.orm import mapper, relation, create_session, attributes from test.orm import _base, _fixtures @@ -60,7 +59,7 @@ class PickleTest(_fixtures.FixtureTest): u2 = pickle.loads(pickle.dumps(u1)) sess2 = create_session() - u2 = sess2.merge(u2, dont_load=True) + u2 = sess2.merge(u2, load=False) eq_(u2.name, 'ed') eq_(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')])) @@ -94,7 +93,7 @@ class PickleTest(_fixtures.FixtureTest): u2 = pickle.loads(pickle.dumps(u1)) sess2 = create_session() - u2 = sess2.merge(u2, dont_load=True) + u2 = sess2.merge(u2, load=False) eq_(u2.name, 'ed') assert 'addresses' not in u2.__dict__ ad = u2.addresses[0] @@ -136,7 +135,7 @@ class PolymorphicDeferredTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('users', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(30)), Column('type', String(30))) Table('email_users', metadata, diff --git a/test/orm/test_query.py b/test/orm/test_query.py index 88a95bf760..8cb7ef969b 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -109,13 +109,8 @@ class GetTest(QueryTest): pass s = users.select(users.c.id!=12).alias('users') m = mapper(SomeUser, s) - print s.primary_key - print m.primary_key assert s.primary_key == m.primary_key - row = s.select(use_labels=True).execute().fetchone() - print row[s.primary_key[0]] - sess = create_session() assert sess.query(SomeUser).get(7).name == 'jack' @@ -145,15 +140,20 @@ class GetTest(QueryTest): @testing.requires.unicode_connections def test_unicode(self): """test that Query.get properly sets up the type for the bind parameter. using unicode would normally fail - on postgres, mysql and oracle unless it is converted to an encoded string""" + on postgresql, mysql and oracle unless it is converted to an encoded string""" metadata = MetaData(engines.utf8_engine()) table = Table('unicode_data', metadata, - Column('id', Unicode(40), primary_key=True), + Column('id', Unicode(40), primary_key=True, test_needs_autoincrement=True), Column('data', Unicode(40))) try: metadata.create_all() + # Py3K + #ustring = 'petit voix m\xe2\x80\x99a' + # Py2K ustring = 'petit voix m\xe2\x80\x99a'.decode('utf-8') + # end Py2K + table.insert().execute(id=ustring, data=ustring) class LocalFoo(Base): pass @@ -195,7 +195,7 @@ class GetTest(QueryTest): assert u.addresses[0].email_address == 'jack@bean.com' assert u.orders[1].items[2].description == 'item 5' - @testing.fails_on_everything_except('sqlite', 'mssql') + @testing.fails_on_everything_except('sqlite', '+pyodbc', '+zxjdbc') def test_query_str(self): s = create_session() q = s.query(User).filter(User.id==1) @@ -299,7 +299,12 @@ class OperatorTest(QueryTest, AssertsCompiledSQL): def test_arithmetic(self): create_session().query(User) for (py_op, sql_op) in ((operator.add, '+'), (operator.mul, '*'), - (operator.sub, '-'), (operator.div, '/'), + (operator.sub, '-'), + # Py3k + #(operator.truediv, '/'), + # Py2K + (operator.div, '/'), + # end Py2K ): for (lhs, rhs, res) in ( (5, User.id, ':id_1 %s users.id'), @@ -489,10 +494,16 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): sess = create_session() self.assert_compile(sess.query(users).select_from(users.select()).with_labels().statement, - "SELECT users.id AS users_id, users.name AS users_name FROM users, (SELECT users.id AS id, users.name AS name FROM users) AS anon_1") + "SELECT users.id AS users_id, users.name AS users_name FROM users, " + "(SELECT users.id AS id, users.name AS name FROM users) AS anon_1", + dialect=default.DefaultDialect() + ) self.assert_compile(sess.query(users, exists([1], from_obj=addresses)).with_labels().statement, - "SELECT users.id AS users_id, users.name AS users_name, EXISTS (SELECT 1 FROM addresses) AS anon_1 FROM users") + "SELECT users.id AS users_id, users.name AS users_name, EXISTS " + "(SELECT 1 FROM addresses) AS anon_1 FROM users", + dialect=default.DefaultDialect() + ) # a little tedious here, adding labels to work around Query's auto-labelling. # also correlate needed explicitly. hmmm..... @@ -687,15 +698,19 @@ class FilterTest(QueryTest): sess = create_session() assert [Address(id=5)] == sess.query(Address).filter(Address.user.has(name='fred')).all() - assert [Address(id=2), Address(id=3), Address(id=4), Address(id=5)] == sess.query(Address).filter(Address.user.has(User.name.like('%ed%'))).all() + assert [Address(id=2), Address(id=3), Address(id=4), Address(id=5)] == \ + sess.query(Address).filter(Address.user.has(User.name.like('%ed%'))).order_by(Address.id).all() - assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).filter(Address.user.has(User.name.like('%ed%'), id=8)).all() + assert [Address(id=2), Address(id=3), Address(id=4)] == \ + sess.query(Address).filter(Address.user.has(User.name.like('%ed%'), id=8)).order_by(Address.id).all() # test has() doesn't overcorrelate - assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).join("user").filter(Address.user.has(User.name.like('%ed%'), id=8)).all() + assert [Address(id=2), Address(id=3), Address(id=4)] == \ + sess.query(Address).join("user").filter(Address.user.has(User.name.like('%ed%'), id=8)).order_by(Address.id).all() # test has() doesnt' get subquery contents adapted by aliased join - assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).join("user", aliased=True).filter(Address.user.has(User.name.like('%ed%'), id=8)).all() + assert [Address(id=2), Address(id=3), Address(id=4)] == \ + sess.query(Address).join("user", aliased=True).filter(Address.user.has(User.name.like('%ed%'), id=8)).order_by(Address.id).all() dingaling = sess.query(Dingaling).get(2) assert [User(id=9)] == sess.query(User).filter(User.addresses.any(Address.dingaling==dingaling)).all() @@ -730,8 +745,8 @@ class FilterTest(QueryTest): assert [Address(id=5)] == sess.query(Address).filter(Address.dingaling==dingaling).all() # m2m - eq_(sess.query(Item).filter(Item.keywords==None).all(), [Item(id=4), Item(id=5)]) - eq_(sess.query(Item).filter(Item.keywords!=None).all(), [Item(id=1),Item(id=2), Item(id=3)]) + eq_(sess.query(Item).filter(Item.keywords==None).order_by(Item.id).all(), [Item(id=4), Item(id=5)]) + eq_(sess.query(Item).filter(Item.keywords!=None).order_by(Item.id).all(), [Item(id=1),Item(id=2), Item(id=3)]) def test_filter_by(self): sess = create_session() @@ -748,8 +763,9 @@ class FilterTest(QueryTest): sess = create_session() # o2o - eq_([Address(id=1), Address(id=3), Address(id=4)], sess.query(Address).filter(Address.dingaling==None).all()) - eq_([Address(id=2), Address(id=5)], sess.query(Address).filter(Address.dingaling != None).all()) + eq_([Address(id=1), Address(id=3), Address(id=4)], + sess.query(Address).filter(Address.dingaling==None).order_by(Address.id).all()) + eq_([Address(id=2), Address(id=5)], sess.query(Address).filter(Address.dingaling != None).order_by(Address.id).all()) # m2o eq_([Order(id=5)], sess.query(Order).filter(Order.address==None).all()) @@ -806,11 +822,15 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL): s = create_session() + oracle_as = not testing.against('oracle') and "AS " or "" + self.assert_compile( s.query(User).options(eagerload(User.addresses)).from_self().statement, "SELECT anon_1.users_id, anon_1.users_name, addresses_1.id, addresses_1.user_id, "\ - "addresses_1.email_address FROM (SELECT users.id AS users_id, users.name AS users_name FROM users) AS anon_1 "\ - "LEFT OUTER JOIN addresses AS addresses_1 ON anon_1.users_id = addresses_1.user_id ORDER BY addresses_1.id" + "addresses_1.email_address FROM (SELECT users.id AS users_id, users.name AS users_name FROM users) %(oracle_as)sanon_1 "\ + "LEFT OUTER JOIN addresses %(oracle_as)saddresses_1 ON anon_1.users_id = addresses_1.user_id ORDER BY addresses_1.id" % { + 'oracle_as':oracle_as + } ) def test_aliases(self): @@ -987,8 +1007,14 @@ class CountTest(QueryTest): class DistinctTest(QueryTest): def test_basic(self): - assert [User(id=7), User(id=8), User(id=9),User(id=10)] == create_session().query(User).distinct().all() - assert [User(id=7), User(id=9), User(id=8),User(id=10)] == create_session().query(User).distinct().order_by(desc(User.name)).all() + eq_( + [User(id=7), User(id=8), User(id=9),User(id=10)], + create_session().query(User).order_by(User.id).distinct().all() + ) + eq_( + [User(id=7), User(id=9), User(id=8),User(id=10)], + create_session().query(User).distinct().order_by(desc(User.name)).all() + ) def test_joined(self): """test that orderbys from a joined table get placed into the columns clause when DISTINCT is used""" @@ -1017,7 +1043,6 @@ class DistinctTest(QueryTest): class YieldTest(QueryTest): def test_basic(self): - import gc sess = create_session() q = iter(sess.query(User).yield_per(1).from_statement("select * from users")) @@ -1447,11 +1472,11 @@ class MultiplePathTest(_base.MappedTest): def define_tables(cls, metadata): global t1, t2, t1t2_1, t1t2_2 t1 = Table('t1', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30)) ) t2 = Table('t2', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30)) ) @@ -1715,7 +1740,7 @@ class MixedEntitiesTest(QueryTest): eq_(list(q2), [(u'jack',), (u'ed',)]) q = sess.query(User) - q2 = q.order_by(User.id).values(User.name, User.name + " " + cast(User.id, String)) + q2 = q.order_by(User.id).values(User.name, User.name + " " + cast(User.id, String(50))) eq_(list(q2), [(u'jack', u'jack 7'), (u'ed', u'ed 8'), (u'fred', u'fred 9'), (u'chuck', u'chuck 10')]) q2 = q.join('addresses').filter(User.name.like('%e%')).order_by(User.id, Address.id).values(User.name, Address.email_address) @@ -1755,6 +1780,9 @@ class MixedEntitiesTest(QueryTest): eq_(list(q2), [(u'jack', u'jack', u'jack'), (u'jack', u'jack', u'ed'), (u'jack', u'jack', u'fred'), (u'jack', u'jack', u'chuck'), (u'ed', u'ed', u'jack'), (u'ed', u'ed', u'ed'), (u'ed', u'ed', u'fred'), (u'ed', u'ed', u'chuck')]) @testing.fails_on('mssql', 'FIXME: unknown') + @testing.fails_on('oracle', "Oracle doesn't support boolean expressions as columns") + @testing.fails_on('postgresql+pg8000', "pg8000 parses the SQL itself before passing on to PG, doesn't parse this") + @testing.fails_on('postgresql+zxjdbc', "zxjdbc parses the SQL itself before passing on to PG, doesn't parse this") def test_values_with_boolean_selects(self): """Tests a values clause that works with select boolean evaluations""" sess = create_session() @@ -1763,6 +1791,10 @@ class MixedEntitiesTest(QueryTest): q2 = q.group_by([User.name.like('%j%')]).order_by(desc(User.name.like('%j%'))).values(User.name.like('%j%'), func.count(User.name.like('%j%'))) eq_(list(q2), [(True, 1), (False, 3)]) + q2 = q.order_by(desc(User.name.like('%j%'))).values(User.name.like('%j%')) + eq_(list(q2), [(True,), (False,), (False,), (False,)]) + + def test_correlated_subquery(self): """test that a subquery constructed from ORM attributes doesn't leak out those entities to the outermost query. @@ -2514,12 +2546,14 @@ class SelfReferentialTest(_base.MappedTest): sess = create_session() eq_(sess.query(Node).filter(Node.children.any(Node.data=='n1')).all(), []) eq_(sess.query(Node).filter(Node.children.any(Node.data=='n12')).all(), [Node(data='n1')]) - eq_(sess.query(Node).filter(~Node.children.any()).all(), [Node(data='n11'), Node(data='n13'),Node(data='n121'),Node(data='n122'),Node(data='n123'),]) + eq_(sess.query(Node).filter(~Node.children.any()).order_by(Node.id).all(), + [Node(data='n11'), Node(data='n13'),Node(data='n121'),Node(data='n122'),Node(data='n123'),]) def test_has(self): sess = create_session() - eq_(sess.query(Node).filter(Node.parent.has(Node.data=='n12')).all(), [Node(data='n121'),Node(data='n122'),Node(data='n123')]) + eq_(sess.query(Node).filter(Node.parent.has(Node.data=='n12')).order_by(Node.id).all(), + [Node(data='n121'),Node(data='n122'),Node(data='n123')]) eq_(sess.query(Node).filter(Node.parent.has(Node.data=='n122')).all(), []) eq_(sess.query(Node).filter(~Node.parent.has()).all(), [Node(data='n1')]) @@ -2660,7 +2694,7 @@ class ExternalColumnsTest(QueryTest): for x in range(2): sess.expunge_all() def go(): - eq_(sess.query(Address).options(eagerload('user')).all(), address_result) + eq_(sess.query(Address).options(eagerload('user')).order_by(Address.id).all(), address_result) self.assert_sql_count(testing.db, go, 1) ualias = aliased(User) @@ -2691,7 +2725,9 @@ class ExternalColumnsTest(QueryTest): ) ua = aliased(User) - eq_(sess.query(Address, ua.concat, ua.count).select_from(join(Address, ua, 'user')).options(eagerload(Address.user)).all(), + eq_(sess.query(Address, ua.concat, ua.count). + select_from(join(Address, ua, 'user')). + options(eagerload(Address.user)).order_by(Address.id).all(), [ (Address(id=1, user=User(id=7, concat=14, count=1)), 14, 1), (Address(id=2, user=User(id=8, concat=16, count=3)), 16, 3), @@ -2742,7 +2778,7 @@ class TestOverlyEagerEquivalentCols(_base.MappedTest): def define_tables(cls, metadata): global base, sub1, sub2 base = Table('base', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(50)) ) @@ -2800,12 +2836,12 @@ class UpdateDeleteTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('users', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(32)), Column('age', Integer)) Table('documents', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('user_id', None, ForeignKey('users.id')), Column('title', String(32))) @@ -2875,7 +2911,7 @@ class UpdateDeleteTest(_base.MappedTest): sess = create_session(bind=testing.db, autocommit=False) john,jack,jill,jane = sess.query(User).order_by(User.id).all() - sess.query(User).filter('name = :name').params(name='john').delete() + sess.query(User).filter('name = :name').params(name='john').delete('fetch') assert john not in sess eq_(sess.query(User).order_by(User.id).all(), [jack,jill,jane]) @@ -2922,12 +2958,17 @@ class UpdateDeleteTest(_base.MappedTest): @testing.fails_on('mysql', 'FIXME: unknown') @testing.resolve_artifact_names - def test_delete_fallback(self): + def test_delete_invalid_evaluation(self): sess = create_session(bind=testing.db, autocommit=False) john,jack,jill,jane = sess.query(User).order_by(User.id).all() - sess.query(User).filter(User.name == select([func.max(User.name)])).delete(synchronize_session='evaluate') + assert_raises(sa_exc.InvalidRequestError, + sess.query(User).filter(User.name == select([func.max(User.name)])).delete, synchronize_session='evaluate' + ) + + sess.query(User).filter(User.name == select([func.max(User.name)])).delete(synchronize_session='fetch') + assert john not in sess eq_(sess.query(User).order_by(User.id).all(), [jack,jill,jane]) @@ -2957,7 +2998,7 @@ class UpdateDeleteTest(_base.MappedTest): john,jack,jill,jane = sess.query(User).order_by(User.id).all() - sess.query(User).filter('age > :x').params(x=29).update({'age': User.age - 10}, synchronize_session='evaluate') + sess.query(User).filter('age > :x').params(x=29).update({'age': User.age - 10}, synchronize_session='fetch') eq_([john.age, jack.age, jill.age, jane.age], [25,37,29,27]) eq_(sess.query(User.age).order_by(User.id).all(), zip([25,37,29,27])) @@ -3017,11 +3058,12 @@ class UpdateDeleteTest(_base.MappedTest): sess = create_session(bind=testing.db, autocommit=False) john,jack,jill,jane = sess.query(User).order_by(User.id).all() - sess.query(User).filter(User.age > 29).update({'age': User.age - 10}, synchronize_session='expire') + sess.query(User).filter(User.age > 29).update({'age': User.age - 10}, synchronize_session='fetch') eq_([john.age, jack.age, jill.age, jane.age], [25,37,29,27]) eq_(sess.query(User.age).order_by(User.id).all(), zip([25,37,29,27])) + @testing.fails_if(lambda: not testing.db.dialect.supports_sane_rowcount) @testing.resolve_artifact_names def test_update_returns_rowcount(self): sess = create_session(bind=testing.db, autocommit=False) @@ -3032,6 +3074,7 @@ class UpdateDeleteTest(_base.MappedTest): rowcount = sess.query(User).filter(User.age > 29).update({'age': User.age - 10}) eq_(rowcount, 2) + @testing.fails_if(lambda: not testing.db.dialect.supports_sane_rowcount) @testing.resolve_artifact_names def test_delete_returns_rowcount(self): sess = create_session(bind=testing.db, autocommit=False) @@ -3046,7 +3089,7 @@ class UpdateDeleteTest(_base.MappedTest): sess = create_session(bind=testing.db, autocommit=False) foo,bar,baz = sess.query(Document).order_by(Document.id).all() - sess.query(Document).filter(Document.user_id == 1).update({'title': Document.title+Document.title}, synchronize_session='expire') + sess.query(Document).filter(Document.user_id == 1).update({'title': Document.title+Document.title}, synchronize_session='fetch') eq_([foo.title, bar.title, baz.title], ['foofoo','barbar', 'baz']) eq_(sess.query(Document.title).order_by(Document.id).all(), zip(['foofoo','barbar', 'baz'])) @@ -3056,7 +3099,7 @@ class UpdateDeleteTest(_base.MappedTest): sess = create_session(bind=testing.db, autocommit=False) john,jack,jill,jane = sess.query(User).order_by(User.id).all() - sess.query(User).options(eagerload(User.documents)).filter(User.age > 29).update({'age': User.age - 10}, synchronize_session='expire') + sess.query(User).options(eagerload(User.documents)).filter(User.age > 29).update({'age': User.age - 10}, synchronize_session='fetch') eq_([john.age, jack.age, jill.age, jane.age], [25,37,29,27]) eq_(sess.query(User.age).order_by(User.id).all(), zip([25,37,29,27])) diff --git a/test/orm/test_relationships.py b/test/orm/test_relationships.py index fef1577f07..481deb81b1 100644 --- a/test/orm/test_relationships.py +++ b/test/orm/test_relationships.py @@ -3,8 +3,7 @@ import datetime import sqlalchemy as sa from sqlalchemy.test import testing from sqlalchemy import Integer, String, ForeignKey, MetaData, and_ -from sqlalchemy.test.schema import Table -from sqlalchemy.test.schema import Column +from sqlalchemy.test.schema import Table, Column from sqlalchemy.orm import mapper, relation, backref, create_session, compile_mappers, clear_mappers, sessionmaker from sqlalchemy.test.testing import eq_, startswith_ from test.orm import _base, _fixtures @@ -32,17 +31,17 @@ class RelationTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table("tbl_a", metadata, - Column("id", Integer, primary_key=True), + Column("id", Integer, primary_key=True, test_needs_autoincrement=True), Column("name", String(128))) Table("tbl_b", metadata, - Column("id", Integer, primary_key=True), + Column("id", Integer, primary_key=True, test_needs_autoincrement=True), Column("name", String(128))) Table("tbl_c", metadata, - Column("id", Integer, primary_key=True), + Column("id", Integer, primary_key=True, test_needs_autoincrement=True), Column("tbl_a_id", Integer, ForeignKey("tbl_a.id"), nullable=False), Column("name", String(128))) Table("tbl_d", metadata, - Column("id", Integer, primary_key=True), + Column("id", Integer, primary_key=True, test_needs_autoincrement=True), Column("tbl_c_id", Integer, ForeignKey("tbl_c.id"), nullable=False), Column("tbl_b_id", Integer, ForeignKey("tbl_b.id")), Column("name", String(128))) @@ -132,7 +131,7 @@ class RelationTest2(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('company_t', metadata, - Column('company_id', Integer, primary_key=True), + Column('company_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', sa.Unicode(30))) Table('employee_t', metadata, @@ -395,7 +394,7 @@ class RelationTest4(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table("tableA", metadata, - Column("id",Integer,primary_key=True), + Column("id",Integer,primary_key=True, test_needs_autoincrement=True), Column("foo",Integer,), test_needs_fk=True) Table("tableB",metadata, @@ -456,7 +455,7 @@ class RelationTest4(_base.MappedTest): @testing.fails_on_everything_except('sqlite', 'mysql') @testing.resolve_artifact_names def test_nullPKsOK_BtoA(self): - # postgres cant handle a nullable PK column...? + # postgresql cant handle a nullable PK column...? tableC = Table('tablec', tableA.metadata, Column('id', Integer, primary_key=True), Column('a_id', Integer, ForeignKey('tableA.id'), @@ -642,12 +641,12 @@ class RelationTest6(_base.MappedTest): @classmethod def define_tables(cls, metadata): - Table('tags', metadata, Column("id", Integer, primary_key=True), + Table('tags', metadata, Column("id", Integer, primary_key=True, test_needs_autoincrement=True), Column("data", String(50)), ) Table('tag_foo', metadata, - Column("id", Integer, primary_key=True), + Column("id", Integer, primary_key=True, test_needs_autoincrement=True), Column('tagid', Integer), Column("data", String(50)), ) @@ -691,11 +690,11 @@ class BackrefPropagatesForwardsArgs(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('users', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)) ) Table('addresses', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('user_id', Integer), Column('email', String(50)) ) @@ -738,7 +737,7 @@ class AmbiguousJoinInterpretedAsSelfRef(_base.MappedTest): @classmethod def define_tables(cls, metadata): subscriber_table = Table('subscriber', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('dummy', String(10)) # to appease older sqlite version ) @@ -947,18 +946,18 @@ class TypeMatchTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table("a", metadata, - Column('aid', Integer, primary_key=True), + Column('aid', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30))) Table("b", metadata, - Column('bid', Integer, primary_key=True), + Column('bid', Integer, primary_key=True, test_needs_autoincrement=True), Column("a_id", Integer, ForeignKey("a.aid")), Column('data', String(30))) Table("c", metadata, - Column('cid', Integer, primary_key=True), + Column('cid', Integer, primary_key=True, test_needs_autoincrement=True), Column("b_id", Integer, ForeignKey("b.bid")), Column('data', String(30))) Table("d", metadata, - Column('did', Integer, primary_key=True), + Column('did', Integer, primary_key=True, test_needs_autoincrement=True), Column("a_id", Integer, ForeignKey("a.aid")), Column('data', String(30))) @@ -1116,14 +1115,14 @@ class ViewOnlyOverlappingNames(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table("t1", metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(40))) Table("t2", metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(40)), Column('t1id', Integer, ForeignKey('t1.id'))) Table("t3", metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(40)), Column('t2id', Integer, ForeignKey('t2.id'))) @@ -1176,14 +1175,14 @@ class ViewOnlyUniqueNames(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table("t1", metadata, - Column('t1id', Integer, primary_key=True), + Column('t1id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(40))) Table("t2", metadata, - Column('t2id', Integer, primary_key=True), + Column('t2id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(40)), Column('t1id_ref', Integer, ForeignKey('t1.t1id'))) Table("t3", metadata, - Column('t3id', Integer, primary_key=True), + Column('t3id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(40)), Column('t2id_ref', Integer, ForeignKey('t2.t2id'))) @@ -1309,12 +1308,12 @@ class ViewOnlyRepeatedRemoteColumn(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('foos', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('bid1', Integer,ForeignKey('bars.id')), Column('bid2', Integer,ForeignKey('bars.id'))) Table('bars', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(50))) @testing.resolve_artifact_names @@ -1357,10 +1356,10 @@ class ViewOnlyRepeatedLocalColumn(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('foos', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(50))) - Table('bars', metadata, Column('id', Integer, primary_key=True), + Table('bars', metadata, Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('fid1', Integer, ForeignKey('foos.id')), Column('fid2', Integer, ForeignKey('foos.id')), Column('data', String(50))) @@ -1405,14 +1404,14 @@ class ViewOnlyComplexJoin(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('t1', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(50))) Table('t2', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(50)), Column('t1id', Integer, ForeignKey('t1.id'))) Table('t3', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(50))) Table('t2tot3', metadata, Column('t2id', Integer, ForeignKey('t2.id')), @@ -1476,10 +1475,10 @@ class ExplicitLocalRemoteTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('t1', metadata, - Column('id', String(50), primary_key=True), + Column('id', String(50), primary_key=True, test_needs_autoincrement=True), Column('data', String(50))) Table('t2', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(50)), Column('t1id', String(50))) diff --git a/test/orm/test_scoping.py b/test/orm/test_scoping.py index 9f2f59e19b..0d6b3deaec 100644 --- a/test/orm/test_scoping.py +++ b/test/orm/test_scoping.py @@ -3,13 +3,13 @@ import sqlalchemy as sa from sqlalchemy.test import testing from sqlalchemy.orm import scoped_session from sqlalchemy import Integer, String, ForeignKey -from sqlalchemy.test.schema import Table -from sqlalchemy.test.schema import Column +from sqlalchemy.test.schema import Table, Column from sqlalchemy.orm import mapper, relation, query from sqlalchemy.test.testing import eq_ from test.orm import _base + class _ScopedTest(_base.MappedTest): """Adds another lookup bucket to emulate Session globals.""" @@ -34,10 +34,10 @@ class ScopedSessionTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('table1', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30))) Table('table2', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('someid', None, ForeignKey('table1.id'))) @testing.resolve_artifact_names @@ -82,10 +82,10 @@ class ScopedMapperTest(_ScopedTest): @classmethod def define_tables(cls, metadata): Table('table1', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30))) Table('table2', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('someid', None, ForeignKey('table1.id'))) @classmethod @@ -204,11 +204,11 @@ class ScopedMapperTest2(_ScopedTest): @classmethod def define_tables(cls, metadata): Table('table1', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30)), Column('type', String(30))) Table('table2', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('someid', None, ForeignKey('table1.id')), Column('somedata', String(30))) diff --git a/test/orm/test_selectable.py b/test/orm/test_selectable.py index 0a20253607..bfa4008957 100644 --- a/test/orm/test_selectable.py +++ b/test/orm/test_selectable.py @@ -3,8 +3,7 @@ from sqlalchemy.test.testing import assert_raises, assert_raises_message import sqlalchemy as sa from sqlalchemy.test import testing from sqlalchemy import String, Integer, select -from sqlalchemy.test.schema import Table -from sqlalchemy.test.schema import Column +from sqlalchemy.test.schema import Table, Column from sqlalchemy.orm import mapper, create_session from sqlalchemy.test.testing import eq_ from test.orm import _base @@ -16,7 +15,7 @@ class SelectableNoFromsTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('common', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', Integer), Column('extra', String(45))) diff --git a/test/orm/test_session.py b/test/orm/test_session.py index 328cbee8ee..2d99e20630 100644 --- a/test/orm/test_session.py +++ b/test/orm/test_session.py @@ -1,13 +1,12 @@ from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message -import gc +from sqlalchemy.test.util import gc_collect import inspect import pickle from sqlalchemy.orm import create_session, sessionmaker, attributes import sqlalchemy as sa from sqlalchemy.test import engines, testing, config from sqlalchemy import Integer, String, Sequence -from sqlalchemy.test.schema import Table -from sqlalchemy.test.schema import Column +from sqlalchemy.test.schema import Table, Column from sqlalchemy.orm import mapper, relation, backref, eagerload from sqlalchemy.test.testing import eq_ from test.engine import _base as engine_base @@ -229,7 +228,7 @@ class SessionTest(_fixtures.FixtureTest): u = sess.query(User).get(u.id) q = sess.query(Address).filter(Address.user==u) del u - gc.collect() + gc_collect() eq_(q.one(), Address(email_address='foo')) @@ -381,18 +380,18 @@ class SessionTest(_fixtures.FixtureTest): session = create_session(bind=testing.db) session.begin() - session.connection().execute("insert into users (name) values ('user1')") + session.connection().execute(users.insert().values(name='user1')) session.begin(subtransactions=True) session.begin_nested() - session.connection().execute("insert into users (name) values ('user2')") + session.connection().execute(users.insert().values(name='user2')) assert session.connection().execute("select count(1) from users").scalar() == 2 session.rollback() assert session.connection().execute("select count(1) from users").scalar() == 1 - session.connection().execute("insert into users (name) values ('user3')") + session.connection().execute(users.insert().values(name='user3')) session.commit() assert session.connection().execute("select count(1) from users").scalar() == 2 @@ -771,18 +770,18 @@ class SessionTest(_fixtures.FixtureTest): user = s.query(User).one() del user - gc.collect() + gc_collect() assert len(s.identity_map) == 0 user = s.query(User).one() user.name = 'fred' del user - gc.collect() + gc_collect() assert len(s.identity_map) == 1 assert len(s.dirty) == 1 assert None not in s.dirty s.flush() - gc.collect() + gc_collect() assert not s.dirty assert not s.identity_map @@ -809,13 +808,13 @@ class SessionTest(_fixtures.FixtureTest): s.add(u2) del u2 - gc.collect() + gc_collect() assert len(s.identity_map) == 1 assert len(s.dirty) == 1 assert None not in s.dirty s.flush() - gc.collect() + gc_collect() assert not s.dirty assert not s.identity_map @@ -835,14 +834,14 @@ class SessionTest(_fixtures.FixtureTest): eq_(user, User(name="ed", addresses=[Address(email_address="ed1")])) del user - gc.collect() + gc_collect() assert len(s.identity_map) == 0 user = s.query(User).options(eagerload(User.addresses)).one() user.addresses[0].email_address='ed2' user.addresses[0].user # lazyload del user - gc.collect() + gc_collect() assert len(s.identity_map) == 2 s.commit() @@ -864,7 +863,7 @@ class SessionTest(_fixtures.FixtureTest): eq_(user, User(name="ed", address=Address(email_address="ed1"))) del user - gc.collect() + gc_collect() assert len(s.identity_map) == 0 user = s.query(User).options(eagerload(User.address)).one() @@ -872,7 +871,7 @@ class SessionTest(_fixtures.FixtureTest): user.address.user # lazyload del user - gc.collect() + gc_collect() assert len(s.identity_map) == 2 s.commit() @@ -890,8 +889,7 @@ class SessionTest(_fixtures.FixtureTest): user = s.query(User).one() user = None print s.identity_map - import gc - gc.collect() + gc_collect() assert len(s.identity_map) == 1 user = s.query(User).one() @@ -901,7 +899,7 @@ class SessionTest(_fixtures.FixtureTest): s.flush() eq_(users.select().execute().fetchall(), [(user.id, 'u2')]) - + @testing.fails_on('+zxjdbc', 'http://www.sqlalchemy.org/trac/ticket/1473') @testing.resolve_artifact_names def test_prune(self): s = create_session(weak_identity_map=False) @@ -914,8 +912,7 @@ class SessionTest(_fixtures.FixtureTest): self.assert_(len(s.identity_map) == 0) self.assert_(s.prune() == 0) s.flush() - import gc - gc.collect() + gc_collect() self.assert_(s.prune() == 9) self.assert_(len(s.identity_map) == 1) @@ -1228,7 +1225,7 @@ class DisposedStates(_base.MappedTest): def define_tables(cls, metadata): global t1 t1 = Table('t1', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(50)) ) @@ -1327,7 +1324,7 @@ class SessionInterface(testing.TestBase): def _map_it(self, cls): return mapper(cls, Table('t', sa.MetaData(), - Column('id', Integer, primary_key=True))) + Column('id', Integer, primary_key=True, test_needs_autoincrement=True))) @testing.uses_deprecated() def _test_instance_guards(self, user_arg): @@ -1447,7 +1444,7 @@ class TLTransactionTest(engine_base.AltEngineTest, _base.MappedTest): @classmethod def define_tables(cls, metadata): Table('users', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(20)), test_needs_acid=True) diff --git a/test/orm/test_transaction.py b/test/orm/test_transaction.py index 5aa541cdad..51b345cebd 100644 --- a/test/orm/test_transaction.py +++ b/test/orm/test_transaction.py @@ -4,13 +4,11 @@ from sqlalchemy import * from sqlalchemy.orm import attributes from sqlalchemy import exc as sa_exc from sqlalchemy.orm import * - +from sqlalchemy.test.util import gc_collect from sqlalchemy.test import testing from test.orm import _base from test.orm._fixtures import FixtureTest, User, Address, users, addresses -import gc - class TransactionTest(FixtureTest): run_setup_mappers = 'once' run_inserts = None @@ -20,7 +18,7 @@ class TransactionTest(FixtureTest): def setup_mappers(cls): mapper(User, users, properties={ 'addresses':relation(Address, backref='user', - cascade="all, delete-orphan"), + cascade="all, delete-orphan", order_by=addresses.c.id), }) mapper(Address, addresses) @@ -109,7 +107,7 @@ class AutoExpireTest(TransactionTest): assert u1_state not in s.identity_map.all_states() assert u1_state not in s._deleted del u1 - gc.collect() + gc_collect() assert u1_state.obj() is None s.rollback() diff --git a/test/orm/test_unitofwork.py b/test/orm/test_unitofwork.py index f95346902b..4d2056b264 100644 --- a/test/orm/test_unitofwork.py +++ b/test/orm/test_unitofwork.py @@ -379,7 +379,6 @@ class MutableTypesTest(_base.MappedTest): "WHERE mutable_t.id = :mutable_t_id", {'mutable_t_id': f1.id, 'val': u'hi', 'data':f1.data})]) - @testing.resolve_artifact_names def test_resurrect(self): f1 = Foo() @@ -392,42 +391,13 @@ class MutableTypesTest(_base.MappedTest): f1.data.y = 19 del f1 - + gc.collect() assert len(session.identity_map) == 1 - - session.commit() - - assert session.query(Foo).one().data == pickleable.Bar(4, 19) - - - @testing.uses_deprecated() - @testing.resolve_artifact_names - def test_nocomparison(self): - """Changes are detected on MutableTypes lacking an __eq__ method.""" - f1 = Foo() - f1.data = pickleable.BarWithoutCompare(4,5) - session = create_session(autocommit=False) - session.add(f1) session.commit() - self.sql_count_(0, session.commit) - session.close() - - session = create_session(autocommit=False) - f2 = session.query(Foo).filter_by(id=f1.id).one() - self.sql_count_(0, session.commit) - - f2.data.y = 19 - self.sql_count_(1, session.commit) - session.close() - - session = create_session(autocommit=False) - f3 = session.query(Foo).filter_by(id=f1.id).one() - eq_((f3.data.x, f3.data.y), (4,19)) - self.sql_count_(0, session.commit) - session.close() + assert session.query(Foo).one().data == pickleable.Bar(4, 19) @testing.resolve_artifact_names def test_unicode(self): @@ -892,7 +862,7 @@ class DefaultTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): - use_string_defaults = testing.against('postgres', 'oracle', 'sqlite', 'mssql') + use_string_defaults = testing.against('postgresql', 'oracle', 'sqlite', 'mssql') if use_string_defaults: hohotype = String(30) @@ -910,15 +880,14 @@ class DefaultTest(_base.MappedTest): Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('hoho', hohotype, server_default=str(hohoval)), - Column('counter', Integer, default=sa.func.char_length("1234567")), - Column('foober', String(30), default="im foober", - onupdate="im the update")) + Column('counter', Integer, default=sa.func.char_length("1234567", type_=Integer)), + Column('foober', String(30), default="im foober", onupdate="im the update")) st = Table('secondary_table', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(50))) - if testing.against('postgres', 'oracle'): + if testing.against('postgresql', 'oracle'): dt.append_column( Column('secondary_id', Integer, sa.Sequence('sec_id_seq'), unique=True)) @@ -1004,14 +973,14 @@ class DefaultTest(_base.MappedTest): # "post-update" mapper(Hoho, default_t) - h1 = Hoho(hoho="15", counter="15") + h1 = Hoho(hoho="15", counter=15) session = create_session() session.add(h1) session.flush() def go(): eq_(h1.hoho, "15") - eq_(h1.counter, "15") + eq_(h1.counter, 15) eq_(h1.foober, "im foober") self.sql_count_(0, go) @@ -1036,7 +1005,7 @@ class DefaultTest(_base.MappedTest): """A server-side default can be used as the target of a foreign key""" mapper(Hoho, default_t, properties={ - 'secondaries':relation(Secondary)}) + 'secondaries':relation(Secondary, order_by=secondary_table.c.id)}) mapper(Secondary, secondary_table) h1 = Hoho() @@ -1068,7 +1037,7 @@ class ColumnPropertyTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): Table('data', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('a', String(50)), Column('b', String(50)) ) @@ -1681,7 +1650,7 @@ class ManyToOneTest(_fixtures.FixtureTest): l = sa.select([users, addresses], sa.and_(users.c.id==addresses.c.user_id, addresses.c.id==a.id)).execute() - eq_(l.fetchone().values(), + eq_(l.first().values(), [a.user.id, 'asdf8d', a.id, a.user_id, 'theater@foo.com']) @testing.resolve_artifact_names @@ -2201,8 +2170,14 @@ class RowSwitchTest(_base.MappedTest): sess.add(o5) sess.flush() - assert list(sess.execute(t5.select(), mapper=T5)) == [(1, 'some t5')] - assert list(sess.execute(t6.select(), mapper=T5)) == [(1, 'some t6', 1), (2, 'some other t6', 1)] + eq_( + list(sess.execute(t5.select(), mapper=T5)), + [(1, 'some t5')] + ) + eq_( + list(sess.execute(t6.select().order_by(t6.c.id), mapper=T5)), + [(1, 'some t6', 1), (2, 'some other t6', 1)] + ) o6 = T5(data='some other t5', id=o5.id, t6s=[ T6(data='third t6', id=3), @@ -2212,8 +2187,14 @@ class RowSwitchTest(_base.MappedTest): sess.add(o6) sess.flush() - assert list(sess.execute(t5.select(), mapper=T5)) == [(1, 'some other t5')] - assert list(sess.execute(t6.select(), mapper=T5)) == [(3, 'third t6', 1), (4, 'fourth t6', 1)] + eq_( + list(sess.execute(t5.select(), mapper=T5)), + [(1, 'some other t5')] + ) + eq_( + list(sess.execute(t6.select().order_by(t6.c.id), mapper=T5)), + [(3, 'third t6', 1), (4, 'fourth t6', 1)] + ) @testing.resolve_artifact_names def test_manytomany(self): @@ -2369,6 +2350,6 @@ class TransactionTest(_base.MappedTest): # todo: on 8.3 at least, the failed commit seems to close the cursor? # needs investigation. leaving in the DDL above now to help verify # that the new deferrable support on FK isn't involved in this issue. - if testing.against('postgres'): + if testing.against('postgresql'): t1.bind.engine.dispose() diff --git a/test/orm/test_utils.py b/test/orm/test_utils.py index 06533a243b..8635ad2125 100644 --- a/test/orm/test_utils.py +++ b/test/orm/test_utils.py @@ -39,8 +39,13 @@ class ExtensionCarrierTest(TestBase): assert 'populate_instance' not in carrier carrier.append(interfaces.MapperExtension) + + # Py3K + #assert 'populate_instance' not in carrier + # Py2K assert 'populate_instance' in carrier - + # end Py2K + assert carrier.interface for m in carrier.interface: assert getattr(interfaces.MapperExtension, m) @@ -85,7 +90,10 @@ class AliasedClassTest(TestBase): alias = aliased(Point) assert Point.zero + # Py2K + # TODO: what is this testing ?? assert not getattr(alias, 'zero') + # end Py2K def test_classmethods(self): class Point(object): @@ -152,9 +160,17 @@ class AliasedClassTest(TestBase): self.func = func def __get__(self, instance, owner): if instance is None: + # Py3K + #args = (self.func, owner) + # Py2K args = (self.func, owner, owner.__class__) + # end Py2K else: + # Py3K + #args = (self.func, instance) + # Py2K args = (self.func, instance, owner) + # end Py2K return types.MethodType(*args) class PropertyDescriptor(object): diff --git a/test/perf/insertspeed.py b/test/perf/insertspeed.py index 32877560eb..0491e9f959 100644 --- a/test/perf/insertspeed.py +++ b/test/perf/insertspeed.py @@ -2,7 +2,7 @@ import testenv; testenv.simple_setup() import sys, time from sqlalchemy import * from sqlalchemy.orm import * -from testlib import profiling +from sqlalchemy.test import profiling db = create_engine('sqlite://') metadata = MetaData(db) diff --git a/test/perf/masscreate.py b/test/perf/masscreate.py index ae32f83e2c..5b8e0da555 100644 --- a/test/perf/masscreate.py +++ b/test/perf/masscreate.py @@ -3,7 +3,6 @@ import testenv; testenv.simple_setup() from sqlalchemy.orm import attributes import time -import gc manage_attributes = True init_attributes = manage_attributes and True @@ -34,7 +33,6 @@ for i in range(0,130): attributes.manage(a) a.email = 'foo@bar.com' u.addresses.append(a) -# gc.collect() # print len(managed_attributes) # managed_attributes.clear() total = time.time() - now diff --git a/test/perf/masscreate2.py b/test/perf/masscreate2.py index 25d4b49153..e525fcf99d 100644 --- a/test/perf/masscreate2.py +++ b/test/perf/masscreate2.py @@ -1,9 +1,9 @@ import testenv; testenv.simple_setup() -import gc import random, string from sqlalchemy.orm import attributes +from sqlalchemy.test.util import gc_collect # with this test, run top. make sure the Python process doenst grow in size arbitrarily. @@ -33,4 +33,4 @@ for i in xrange(1000): a.user = u print "clearing" #managed_attributes.clear() - gc.collect() + gc_collect() diff --git a/test/perf/masseagerload.py b/test/perf/masseagerload.py index a848b866cc..88a3ade20b 100644 --- a/test/perf/masseagerload.py +++ b/test/perf/masseagerload.py @@ -1,7 +1,6 @@ -import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy.orm import * -from testlib import * +from sqlalchemy.test import * NUM = 500 DIVISOR = 50 diff --git a/test/perf/massload.py b/test/perf/massload.py index 9391ead2a5..f6cde3adfd 100644 --- a/test/perf/massload.py +++ b/test/perf/massload.py @@ -1,10 +1,8 @@ -import testenv; testenv.configure_for_tests() import time -#import gc #import sqlalchemy.orm.attributes as attributes from sqlalchemy import * from sqlalchemy.orm import * -from testlib import * +from sqlalchemy.test import * """ @@ -18,16 +16,18 @@ top while it runs NUM = 2500 class LoadTest(TestBase, AssertsExecutionResults): - def setUpAll(self): + @classmethod + def setup_class(cls): global items, meta meta = MetaData(testing.db) items = Table('items', meta, Column('item_id', Integer, primary_key=True), Column('value', String(100))) items.create() - def tearDownAll(self): + @classmethod + def teardown_class(cls): items.drop() - def setUp(self): + def setup(self): for x in range(1,NUM/500+1): l = [] for y in range(x*500-500 + 1, x*500 + 1): @@ -43,7 +43,7 @@ class LoadTest(TestBase, AssertsExecutionResults): query = sess.query(Item) for x in range (1,NUM/100): # this is not needed with cpython which clears non-circular refs immediately - #gc.collect() + #gc_collect() l = query.filter(items.c.item_id.between(x*100 - 100 + 1, x*100)).all() assert len(l) == 100 print "loaded ", len(l), " items " @@ -61,5 +61,3 @@ class LoadTest(TestBase, AssertsExecutionResults): print "total time ", total -if __name__ == "__main__": - testenv.main() diff --git a/test/perf/masssave.py b/test/perf/masssave.py index bf65c8fdf7..41acd12ccf 100644 --- a/test/perf/masssave.py +++ b/test/perf/masssave.py @@ -1,21 +1,23 @@ -import testenv; testenv.configure_for_tests() +import gc import types from sqlalchemy import * from sqlalchemy.orm import * -from testlib import * +from sqlalchemy.test import * NUM = 2500 class SaveTest(TestBase, AssertsExecutionResults): - def setUpAll(self): + @classmethod + def setup_class(cls): global items, metadata metadata = MetaData(testing.db) items = Table('items', metadata, Column('item_id', Integer, primary_key=True), Column('value', String(100))) items.create() - def tearDownAll(self): + @classmethod + def teardown_class(cls): clear_mappers() metadata.drop_all() @@ -50,5 +52,3 @@ class SaveTest(TestBase, AssertsExecutionResults): print x -if __name__ == "__main__": - testenv.main() diff --git a/test/perf/objselectspeed.py b/test/perf/objselectspeed.py index 896fd4c494..867a396f35 100644 --- a/test/perf/objselectspeed.py +++ b/test/perf/objselectspeed.py @@ -1,7 +1,8 @@ import testenv; testenv.simple_setup() -import time, gc, resource +import time, resource from sqlalchemy import * from sqlalchemy.orm import * +from sqlalchemy.test.util import gc_collect db = create_engine('sqlite://') @@ -68,35 +69,35 @@ def all(): usage.snap = lambda stats=None: setattr( usage, 'last', stats or resource.getrusage(resource.RUSAGE_SELF)) - gc.collect() + gc_collect() usage.snap() t = time.clock() sqlite_select(RawPerson) t2 = time.clock() usage('sqlite select/native') - gc.collect() + gc_collect() usage.snap() t = time.clock() sqlite_select(Person) t2 = time.clock() usage('sqlite select/instrumented') - gc.collect() + gc_collect() usage.snap() t = time.clock() sql_select(RawPerson) t2 = time.clock() usage('sqlalchemy.sql select/native') - gc.collect() + gc_collect() usage.snap() t = time.clock() sql_select(Person) t2 = time.clock() usage('sqlalchemy.sql select/instrumented') - gc.collect() + gc_collect() usage.snap() t = time.clock() orm_select() diff --git a/test/perf/objupdatespeed.py b/test/perf/objupdatespeed.py index a49eb47245..52224211ae 100644 --- a/test/perf/objupdatespeed.py +++ b/test/perf/objupdatespeed.py @@ -1,8 +1,9 @@ -import testenv; testenv.configure_for_tests() -import time, gc, resource +import time, resource from sqlalchemy import * from sqlalchemy.orm import * -from testlib import * +from sqlalchemy.test import * +from sqlalchemy.test.util import gc_collect + NUM = 100 @@ -72,14 +73,14 @@ def all(): session = create_session() - gc.collect() + gc_collect() usage.snap() t = time.clock() people = orm_select(session) t2 = time.clock() usage('load objects') - gc.collect() + gc_collect() usage.snap() t = time.clock() update_and_flush(session, people) diff --git a/test/perf/ormsession.py b/test/perf/ormsession.py index cdffa51a96..f9f9dee8b7 100644 --- a/test/perf/ormsession.py +++ b/test/perf/ormsession.py @@ -1,11 +1,10 @@ -import testenv; testenv.configure_for_tests() import time from datetime import datetime from sqlalchemy import * from sqlalchemy.orm import * -from testlib import * -from testlib.profiling import profiled +from sqlalchemy.test import * +from sqlalchemy.test.profiling import profiled class Item(object): def __repr__(self): diff --git a/test/perf/poolload.py b/test/perf/poolload.py index 8d66da84f4..62c66fbae6 100644 --- a/test/perf/poolload.py +++ b/test/perf/poolload.py @@ -1,9 +1,8 @@ # load test of connection pool -import testenv; testenv.configure_for_tests() import thread, time from sqlalchemy import * import sqlalchemy.pool as pool -from testlib import testing +from sqlalchemy.test import testing db = create_engine(testing.db.url, pool_timeout=30, echo_pool=True) metadata = MetaData(db) diff --git a/test/perf/sessions.py b/test/perf/sessions.py index f4be1ee936..0d4cc1f014 100644 --- a/test/perf/sessions.py +++ b/test/perf/sessions.py @@ -1,17 +1,17 @@ -import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy.orm import * -import gc -from testlib import TestBase, AssertsExecutionResults, profiling, testing -from orm import _fixtures +from sqlalchemy.test.compat import gc_collect +from sqlalchemy.test import TestBase, AssertsExecutionResults, profiling, testing +from test.orm import _fixtures # in this test we are specifically looking for time spent in the attributes.InstanceState.__cleanup() method. ITERATIONS = 100 class SessionTest(TestBase, AssertsExecutionResults): - def setUpAll(self): + @classmethod + def setup_class(cls): global t1, t2, metadata,T1, T2 metadata = MetaData(testing.db) t1 = Table('t1', metadata, @@ -46,7 +46,8 @@ class SessionTest(TestBase, AssertsExecutionResults): }) mapper(T2, t2) - def tearDownAll(self): + @classmethod + def teardown_class(cls): metadata.drop_all() clear_mappers() @@ -60,7 +61,7 @@ class SessionTest(TestBase, AssertsExecutionResults): sess.close() del sess - gc.collect() + gc_collect() @profiling.profiled('dirty', report=True) def test_session_dirty(self): @@ -74,11 +75,11 @@ class SessionTest(TestBase, AssertsExecutionResults): t2.c2 = 'this is some modified text' del t1s - gc.collect() + gc_collect() sess.close() del sess - gc.collect() + gc_collect() @profiling.profiled('noclose', report=True) def test_session_noclose(self): @@ -89,9 +90,6 @@ class SessionTest(TestBase, AssertsExecutionResults): t1s[index].t2s del sess - gc.collect() - + gc_collect() -if __name__ == '__main__': - testenv.main() diff --git a/test/perf/wsgi.py b/test/perf/wsgi.py index 6fc8149bcd..549c92ade8 100644 --- a/test/perf/wsgi.py +++ b/test/perf/wsgi.py @@ -1,11 +1,10 @@ #!/usr/bin/python """Uses ``wsgiref``, standard in Python 2.5 and also in the cheeseshop.""" -import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy.orm import * import thread -from testlib import * +from sqlalchemy.test import * port = 8000 diff --git a/test/sql/test_constraints.py b/test/sql/test_constraints.py index 8abeb35338..4ad52604d3 100644 --- a/test/sql/test_constraints.py +++ b/test/sql/test_constraints.py @@ -1,10 +1,14 @@ -from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message +from sqlalchemy.test.testing import assert_raises, assert_raises_message from sqlalchemy import * -from sqlalchemy import exc +from sqlalchemy import exc, schema from sqlalchemy.test import * from sqlalchemy.test import config, engines +from sqlalchemy.engine import ddl +from sqlalchemy.test.testing import eq_ +from sqlalchemy.test.assertsql import AllOf, RegexSQL, ExactSQL, CompiledSQL +from sqlalchemy.dialects.postgresql import base as postgresql -class ConstraintTest(TestBase, AssertsExecutionResults): +class ConstraintTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL): def setup(self): global metadata @@ -33,11 +37,8 @@ class ConstraintTest(TestBase, AssertsExecutionResults): def test_double_fk_usage_raises(self): f = ForeignKey('b.id') - assert_raises(exc.InvalidRequestError, Table, "a", metadata, - Column('x', Integer, f), - Column('y', Integer, f) - ) - + Column('x', Integer, f) + assert_raises(exc.InvalidRequestError, Column, "y", Integer, f) def test_circular_constraint(self): a = Table("a", metadata, @@ -78,18 +79,9 @@ class ConstraintTest(TestBase, AssertsExecutionResults): metadata.create_all() foo.insert().execute(id=1,x=9,y=5) - try: - foo.insert().execute(id=2,x=5,y=9) - assert False - except exc.SQLError: - assert True - + assert_raises(exc.SQLError, foo.insert().execute, id=2,x=5,y=9) bar.insert().execute(id=1,x=10) - try: - bar.insert().execute(id=2,x=5) - assert False - except exc.SQLError: - assert True + assert_raises(exc.SQLError, bar.insert().execute, id=2,x=5) def test_unique_constraint(self): foo = Table('foo', metadata, @@ -106,16 +98,8 @@ class ConstraintTest(TestBase, AssertsExecutionResults): foo.insert().execute(id=2, value='value2') bar.insert().execute(id=1, value='a', value2='a') bar.insert().execute(id=2, value='a', value2='b') - try: - foo.insert().execute(id=3, value='value1') - assert False - except exc.SQLError: - assert True - try: - bar.insert().execute(id=3, value='a', value2='b') - assert False - except exc.SQLError: - assert True + assert_raises(exc.SQLError, foo.insert().execute, id=3, value='value1') + assert_raises(exc.SQLError, bar.insert().execute, id=3, value='a', value2='b') def test_index_create(self): employees = Table('employees', metadata, @@ -174,35 +158,22 @@ class ConstraintTest(TestBase, AssertsExecutionResults): Index('sport_announcer', events.c.sport, events.c.announcer, unique=True) Index('idx_winners', events.c.winner) - index_names = [ ix.name for ix in events.indexes ] - assert 'ix_events_name' in index_names - assert 'ix_events_location' in index_names - assert 'sport_announcer' in index_names - assert 'idx_winners' in index_names - assert len(index_names) == 4 - - capt = [] - connection = testing.db.connect() - # TODO: hacky, put a real connection proxy in - ex = connection._Connection__execute_context - def proxy(context): - capt.append(context.statement) - capt.append(repr(context.parameters)) - ex(context) - connection._Connection__execute_context = proxy - schemagen = testing.db.dialect.schemagenerator(testing.db.dialect, connection) - schemagen.traverse(events) - - assert capt[0].strip().startswith('CREATE TABLE events') - - s = set([capt[x].strip() for x in [2,4,6,8]]) - - assert s == set([ - 'CREATE UNIQUE INDEX ix_events_name ON events (name)', - 'CREATE INDEX ix_events_location ON events (location)', - 'CREATE UNIQUE INDEX sport_announcer ON events (sport, announcer)', - 'CREATE INDEX idx_winners ON events (winner)' - ]) + eq_( + set([ ix.name for ix in events.indexes ]), + set(['ix_events_name', 'ix_events_location', 'sport_announcer', 'idx_winners']) + ) + + self.assert_sql_execution( + testing.db, + lambda: events.create(testing.db), + RegexSQL("^CREATE TABLE events"), + AllOf( + ExactSQL('CREATE UNIQUE INDEX ix_events_name ON events (name)'), + ExactSQL('CREATE INDEX ix_events_location ON events (location)'), + ExactSQL('CREATE UNIQUE INDEX sport_announcer ON events (sport, announcer)'), + ExactSQL('CREATE INDEX idx_winners ON events (winner)') + ) + ) # verify that the table is functional events.insert().execute(id=1, name='hockey finals', location='rink', @@ -214,84 +185,57 @@ class ConstraintTest(TestBase, AssertsExecutionResults): dialect = testing.db.dialect.__class__() dialect.max_identifier_length = 20 - schemagen = dialect.schemagenerator(dialect, None) - schemagen.execute = lambda : None - t1 = Table("sometable", MetaData(), Column("foo", Integer)) - schemagen.visit_index(Index("this_name_is_too_long_for_what_were_doing", t1.c.foo)) - eq_(schemagen.buffer.getvalue(), "CREATE INDEX this_name_is_t_1 ON sometable (foo)") - schemagen.buffer.truncate(0) - schemagen.visit_index(Index("this_other_name_is_too_long_for_what_were_doing", t1.c.foo)) - eq_(schemagen.buffer.getvalue(), "CREATE INDEX this_other_nam_2 ON sometable (foo)") - - schemadrop = dialect.schemadropper(dialect, None) - schemadrop.execute = lambda: None - assert_raises(exc.IdentifierError, schemadrop.visit_index, Index("this_name_is_too_long_for_what_were_doing", t1.c.foo)) + self.assert_compile( + schema.CreateIndex(Index("this_name_is_too_long_for_what_were_doing", t1.c.foo)), + "CREATE INDEX this_name_is_t_1 ON sometable (foo)", + dialect=dialect + ) + + self.assert_compile( + schema.CreateIndex(Index("this_other_name_is_too_long_for_what_were_doing", t1.c.foo)), + "CREATE INDEX this_other_nam_1 ON sometable (foo)", + dialect=dialect + ) -class ConstraintCompilationTest(TestBase, AssertsExecutionResults): - class accum(object): - def __init__(self): - self.statements = [] - def __call__(self, sql, *a, **kw): - self.statements.append(sql) - def __contains__(self, substring): - for s in self.statements: - if substring in s: - return True - return False - def __str__(self): - return '\n'.join([repr(x) for x in self.statements]) - def clear(self): - del self.statements[:] - - def setup(self): - self.sql = self.accum() - opts = config.db_opts.copy() - opts['strategy'] = 'mock' - opts['executor'] = self.sql - self.engine = engines.testing_engine(options=opts) - +class ConstraintCompilationTest(TestBase, AssertsCompiledSQL): def _test_deferrable(self, constraint_factory): - meta = MetaData(self.engine) - t = Table('tbl', meta, + t = Table('tbl', MetaData(), Column('a', Integer), Column('b', Integer), constraint_factory(deferrable=True)) - t.create() - assert 'DEFERRABLE' in self.sql, self.sql - assert 'NOT DEFERRABLE' not in self.sql, self.sql - self.sql.clear() - meta.clear() - - t = Table('tbl', meta, + + sql = str(schema.CreateTable(t).compile(bind=testing.db)) + assert 'DEFERRABLE' in sql, sql + assert 'NOT DEFERRABLE' not in sql, sql + + t = Table('tbl', MetaData(), Column('a', Integer), Column('b', Integer), constraint_factory(deferrable=False)) - t.create() - assert 'NOT DEFERRABLE' in self.sql - self.sql.clear() - meta.clear() - t = Table('tbl', meta, + sql = str(schema.CreateTable(t).compile(bind=testing.db)) + assert 'NOT DEFERRABLE' in sql + + + t = Table('tbl', MetaData(), Column('a', Integer), Column('b', Integer), constraint_factory(deferrable=True, initially='IMMEDIATE')) - t.create() - assert 'NOT DEFERRABLE' not in self.sql - assert 'INITIALLY IMMEDIATE' in self.sql - self.sql.clear() - meta.clear() + sql = str(schema.CreateTable(t).compile(bind=testing.db)) + assert 'NOT DEFERRABLE' not in sql + assert 'INITIALLY IMMEDIATE' in sql - t = Table('tbl', meta, + t = Table('tbl', MetaData(), Column('a', Integer), Column('b', Integer), constraint_factory(deferrable=True, initially='DEFERRED')) - t.create() + sql = str(schema.CreateTable(t).compile(bind=testing.db)) - assert 'NOT DEFERRABLE' not in self.sql - assert 'INITIALLY DEFERRED' in self.sql, self.sql + assert 'NOT DEFERRABLE' not in sql + assert 'INITIALLY DEFERRED' in sql def test_deferrable_pk(self): factory = lambda **kw: PrimaryKeyConstraint('a', **kw) @@ -302,15 +246,16 @@ class ConstraintCompilationTest(TestBase, AssertsExecutionResults): self._test_deferrable(factory) def test_deferrable_column_fk(self): - meta = MetaData(self.engine) - t = Table('tbl', meta, + t = Table('tbl', MetaData(), Column('a', Integer), Column('b', Integer, ForeignKey('tbl.a', deferrable=True, initially='DEFERRED'))) - t.create() - assert 'DEFERRABLE' in self.sql, self.sql - assert 'INITIALLY DEFERRED' in self.sql, self.sql + + self.assert_compile( + schema.CreateTable(t), + "CREATE TABLE tbl (a INTEGER, b INTEGER, FOREIGN KEY(b) REFERENCES tbl (a) DEFERRABLE INITIALLY DEFERRED)", + ) def test_deferrable_unique(self): factory = lambda **kw: UniqueConstraint('b', **kw) @@ -321,15 +266,105 @@ class ConstraintCompilationTest(TestBase, AssertsExecutionResults): self._test_deferrable(factory) def test_deferrable_column_check(self): - meta = MetaData(self.engine) - t = Table('tbl', meta, + t = Table('tbl', MetaData(), Column('a', Integer), Column('b', Integer, CheckConstraint('a < b', deferrable=True, initially='DEFERRED'))) - t.create() - assert 'DEFERRABLE' in self.sql, self.sql - assert 'INITIALLY DEFERRED' in self.sql, self.sql + + self.assert_compile( + schema.CreateTable(t), + "CREATE TABLE tbl (a INTEGER, b INTEGER CHECK (a < b) DEFERRABLE INITIALLY DEFERRED)" + ) + + def test_use_alter(self): + m = MetaData() + t = Table('t', m, + Column('a', Integer), + ) + + t2 = Table('t2', m, + Column('a', Integer, ForeignKey('t.a', use_alter=True, name='fk_ta')), + Column('b', Integer, ForeignKey('t.a', name='fk_tb')), # to ensure create ordering ... + ) + + e = engines.mock_engine(dialect_name='postgresql') + m.create_all(e) + m.drop_all(e) + + e.assert_sql([ + 'CREATE TABLE t (a INTEGER)', + 'CREATE TABLE t2 (a INTEGER, b INTEGER, CONSTRAINT fk_tb FOREIGN KEY(b) REFERENCES t (a))', + 'ALTER TABLE t2 ADD CONSTRAINT fk_ta FOREIGN KEY(a) REFERENCES t (a)', + 'ALTER TABLE t2 DROP CONSTRAINT fk_ta', + 'DROP TABLE t2', + 'DROP TABLE t' + ]) + + + def test_add_drop_constraint(self): + m = MetaData() + + t = Table('tbl', m, + Column('a', Integer), + Column('b', Integer) + ) + + t2 = Table('t2', m, + Column('a', Integer), + Column('b', Integer) + ) + + constraint = CheckConstraint('a < b',name="my_test_constraint", deferrable=True,initially='DEFERRED', table=t) + self.assert_compile( + schema.AddConstraint(constraint), + "ALTER TABLE tbl ADD CONSTRAINT my_test_constraint CHECK (a < b) DEFERRABLE INITIALLY DEFERRED" + ) + + self.assert_compile( + schema.DropConstraint(constraint), + "ALTER TABLE tbl DROP CONSTRAINT my_test_constraint" + ) + + self.assert_compile( + schema.DropConstraint(constraint, cascade=True), + "ALTER TABLE tbl DROP CONSTRAINT my_test_constraint CASCADE" + ) + constraint = ForeignKeyConstraint(["b"], ["t2.a"]) + t.append_constraint(constraint) + self.assert_compile( + schema.AddConstraint(constraint), + "ALTER TABLE tbl ADD FOREIGN KEY(b) REFERENCES t2 (a)" + ) + constraint = ForeignKeyConstraint([t.c.a], [t2.c.b]) + t.append_constraint(constraint) + self.assert_compile( + schema.AddConstraint(constraint), + "ALTER TABLE tbl ADD FOREIGN KEY(a) REFERENCES t2 (b)" + ) + + constraint = UniqueConstraint("a", "b", name="uq_cst") + t2.append_constraint(constraint) + self.assert_compile( + schema.AddConstraint(constraint), + "ALTER TABLE t2 ADD CONSTRAINT uq_cst UNIQUE (a, b)" + ) + + constraint = UniqueConstraint(t2.c.a, t2.c.b, name="uq_cs2") + self.assert_compile( + schema.AddConstraint(constraint), + "ALTER TABLE t2 ADD CONSTRAINT uq_cs2 UNIQUE (a, b)" + ) + + assert t.c.a.primary_key is False + constraint = PrimaryKeyConstraint(t.c.a) + assert t.c.a.primary_key is True + self.assert_compile( + schema.AddConstraint(constraint), + "ALTER TABLE tbl ADD PRIMARY KEY (a)" + ) + + diff --git a/test/sql/test_defaults.py b/test/sql/test_defaults.py index 9641574665..5638dad77f 100644 --- a/test/sql/test_defaults.py +++ b/test/sql/test_defaults.py @@ -3,7 +3,7 @@ import datetime from sqlalchemy import Sequence, Column, func from sqlalchemy.sql import select, text import sqlalchemy as sa -from sqlalchemy.test import testing +from sqlalchemy.test import testing, engines from sqlalchemy import MetaData, Integer, String, ForeignKey, Boolean from sqlalchemy.test.schema import Table from sqlalchemy.test.testing import eq_ @@ -37,7 +37,7 @@ class DefaultTest(testing.TestBase): # since its a "branched" connection conn.close() - use_function_defaults = testing.against('postgres', 'mssql', 'maxdb') + use_function_defaults = testing.against('postgresql', 'mssql', 'maxdb') is_oracle = testing.against('oracle') # select "count(1)" returns different results on different DBs also @@ -146,7 +146,7 @@ class DefaultTest(testing.TestBase): assert_raises_message(sa.exc.ArgumentError, ex_msg, sa.ColumnDefault, fn) - + def test_arg_signature(self): def fn1(): pass def fn2(): pass @@ -276,7 +276,7 @@ class DefaultTest(testing.TestBase): assert r.lastrow_has_defaults() eq_(set(r.context.postfetch_cols), set([t.c.col3, t.c.col5, t.c.col4, t.c.col6])) - + eq_(t.select(t.c.col1==54).execute().fetchall(), [(54, 'imthedefault', f, ts, ts, ctexec, True, False, 12, today, None)]) @@ -284,7 +284,7 @@ class DefaultTest(testing.TestBase): @testing.fails_on('firebird', 'Data type unknown') def test_insertmany(self): # MySQL-Python 1.2.2 breaks functions in execute_many :( - if (testing.against('mysql') and + if (testing.against('mysql') and not testing.against('+zxjdbc') and testing.db.dialect.dbapi.version_info[:3] == (1, 2, 2)): return @@ -304,12 +304,12 @@ class DefaultTest(testing.TestBase): def test_insert_values(self): t.insert(values={'col3':50}).execute() l = t.select().execute() - eq_(50, l.fetchone()['col3']) + eq_(50, l.first()['col3']) @testing.fails_on('firebird', 'Data type unknown') def test_updatemany(self): # MySQL-Python 1.2.2 breaks functions in execute_many :( - if (testing.against('mysql') and + if (testing.against('mysql') and not testing.against('+zxjdbc') and testing.db.dialect.dbapi.version_info[:3] == (1, 2, 2)): return @@ -337,11 +337,11 @@ class DefaultTest(testing.TestBase): @testing.fails_on('firebird', 'Data type unknown') def test_update(self): r = t.insert().execute() - pk = r.last_inserted_ids()[0] + pk = r.inserted_primary_key[0] t.update(t.c.col1==pk).execute(col4=None, col5=None) ctexec = currenttime.scalar() l = t.select(t.c.col1==pk).execute() - l = l.fetchone() + l = l.first() eq_(l, (pk, 'im the update', f2, None, None, ctexec, True, False, 13, datetime.date.today(), 'py')) @@ -350,43 +350,12 @@ class DefaultTest(testing.TestBase): @testing.fails_on('firebird', 'Data type unknown') def test_update_values(self): r = t.insert().execute() - pk = r.last_inserted_ids()[0] + pk = r.inserted_primary_key[0] t.update(t.c.col1==pk, values={'col3': 55}).execute() l = t.select(t.c.col1==pk).execute() - l = l.fetchone() + l = l.first() eq_(55, l['col3']) - @testing.fails_on_everything_except('postgres') - def test_passive_override(self): - """ - Primarily for postgres, tests that when we get a primary key column - back from reflecting a table which has a default value on it, we - pre-execute that DefaultClause upon insert, even though DefaultClause - says "let the database execute this", because in postgres we must have - all the primary key values in memory before insert; otherwise we can't - locate the just inserted row. - - """ - # TODO: move this to dialect/postgres - try: - meta = MetaData(testing.db) - testing.db.execute(""" - CREATE TABLE speedy_users - ( - speedy_user_id SERIAL PRIMARY KEY, - - user_name VARCHAR NOT NULL, - user_password VARCHAR NOT NULL - ); - """, None) - - t = Table("speedy_users", meta, autoload=True) - t.insert().execute(user_name='user', user_password='lala') - l = t.select().execute().fetchall() - eq_(l, [(1, 'user', 'lala')]) - finally: - testing.db.execute("drop table speedy_users", None) - class PKDefaultTest(_base.TablesTest): __requires__ = ('subqueries',) @@ -400,18 +369,27 @@ class PKDefaultTest(_base.TablesTest): Column('id', Integer, primary_key=True, default=sa.select([func.max(t2.c.nextid)]).as_scalar()), Column('data', String(30))) - - @testing.fails_on('mssql', 'FIXME: unknown') + + @testing.requires.returning + def test_with_implicit_returning(self): + self._test(True) + + def test_regular(self): + self._test(False) + @testing.resolve_artifact_names - def test_basic(self): - t2.insert().execute(nextid=1) - r = t1.insert().execute(data='hi') - eq_([1], r.last_inserted_ids()) - - t2.insert().execute(nextid=2) - r = t1.insert().execute(data='there') - eq_([2], r.last_inserted_ids()) + def _test(self, returning): + if not returning and not testing.db.dialect.implicit_returning: + engine = testing.db + else: + engine = engines.testing_engine(options={'implicit_returning':returning}) + engine.execute(t2.insert(), nextid=1) + r = engine.execute(t1.insert(), data='hi') + eq_([1], r.inserted_primary_key) + engine.execute(t2.insert(), nextid=2) + r = engine.execute(t1.insert(), data='there') + eq_([2], r.inserted_primary_key) class PKIncrementTest(_base.TablesTest): run_define_tables = 'each' @@ -430,29 +408,31 @@ class PKIncrementTest(_base.TablesTest): def _test_autoincrement(self, bind): ids = set() rs = bind.execute(aitable.insert(), int1=1) - last = rs.last_inserted_ids()[0] + last = rs.inserted_primary_key[0] self.assert_(last) self.assert_(last not in ids) ids.add(last) rs = bind.execute(aitable.insert(), str1='row 2') - last = rs.last_inserted_ids()[0] + last = rs.inserted_primary_key[0] self.assert_(last) self.assert_(last not in ids) ids.add(last) rs = bind.execute(aitable.insert(), int1=3, str1='row 3') - last = rs.last_inserted_ids()[0] + last = rs.inserted_primary_key[0] self.assert_(last) self.assert_(last not in ids) ids.add(last) rs = bind.execute(aitable.insert(values={'int1':func.length('four')})) - last = rs.last_inserted_ids()[0] + last = rs.inserted_primary_key[0] self.assert_(last) self.assert_(last not in ids) ids.add(last) + eq_(ids, set([1,2,3,4])) + eq_(list(bind.execute(aitable.select().order_by(aitable.c.id))), [(1, 1, None), (2, None, 'row 2'), (3, 3, 'row 3'), (4, 4, None)]) @@ -510,8 +490,8 @@ class AutoIncrementTest(_base.TablesTest): single.create() r = single.insert().execute() - id_ = r.last_inserted_ids()[0] - assert id_ is not None + id_ = r.inserted_primary_key[0] + eq_(id_, 1) eq_(1, sa.select([func.count(sa.text('*'))], from_obj=single).scalar()) def test_autoincrement_fk(self): @@ -522,7 +502,7 @@ class AutoIncrementTest(_base.TablesTest): nodes.create() r = nodes.insert().execute(data='foo') - id_ = r.last_inserted_ids()[0] + id_ = r.inserted_primary_key[0] nodes.insert().execute(data='bar', parent_id=id_) @testing.fails_on('sqlite', 'FIXME: unknown') @@ -535,7 +515,7 @@ class AutoIncrementTest(_base.TablesTest): try: - # postgres + mysql strict will fail on first row, + # postgresql + mysql strict will fail on first row, # mysql in legacy mode fails on second row nonai.insert().execute(data='row 1') nonai.insert().execute(data='row 2') @@ -570,16 +550,17 @@ class SequenceTest(testing.TestBase): def testseqnonpk(self): """test sequences fire off as defaults on non-pk columns""" - result = sometable.insert().execute(name="somename") + engine = engines.testing_engine(options={'implicit_returning':False}) + result = engine.execute(sometable.insert(), name="somename") assert 'id' in result.postfetch_cols() - result = sometable.insert().execute(name="someother") + result = engine.execute(sometable.insert(), name="someother") assert 'id' in result.postfetch_cols() sometable.insert().execute( {'name':'name3'}, {'name':'name4'}) - eq_(sometable.select().execute().fetchall(), + eq_(sometable.select().order_by(sometable.c.id).execute().fetchall(), [(1, "somename", 1), (2, "someother", 2), (3, "name3", 3), @@ -590,8 +571,8 @@ class SequenceTest(testing.TestBase): cartitems.insert().execute(description='there') r = cartitems.insert().execute(description='lala') - assert r.last_inserted_ids() and r.last_inserted_ids()[0] is not None - id_ = r.last_inserted_ids()[0] + assert r.inserted_primary_key and r.inserted_primary_key[0] is not None + id_ = r.inserted_primary_key[0] eq_(1, sa.select([func.count(cartitems.c.cart_id)], diff --git a/test/sql/test_functions.py b/test/sql/test_functions.py index e9bf49ce30..7a0f12cac3 100644 --- a/test/sql/test_functions.py +++ b/test/sql/test_functions.py @@ -24,7 +24,7 @@ class CompileTest(TestBase, AssertsCompiledSQL): bindtemplate = BIND_TEMPLATES[dialect.paramstyle] self.assert_compile(func.current_timestamp(), "CURRENT_TIMESTAMP", dialect=dialect) self.assert_compile(func.localtime(), "LOCALTIME", dialect=dialect) - if isinstance(dialect, firebird.dialect): + if isinstance(dialect, (firebird.dialect, maxdb.dialect, oracle.dialect)): self.assert_compile(func.nosuchfunction(), "nosuchfunction", dialect=dialect) else: self.assert_compile(func.nosuchfunction(), "nosuchfunction()", dialect=dialect) @@ -50,7 +50,7 @@ class CompileTest(TestBase, AssertsCompiledSQL): for ret, dialect in [ ('CURRENT_TIMESTAMP', sqlite.dialect()), - ('now()', postgres.dialect()), + ('now()', postgresql.dialect()), ('now()', mysql.dialect()), ('CURRENT_TIMESTAMP', oracle.dialect()) ]: @@ -62,9 +62,9 @@ class CompileTest(TestBase, AssertsCompiledSQL): for ret, dialect in [ ('random()', sqlite.dialect()), - ('random()', postgres.dialect()), + ('random()', postgresql.dialect()), ('rand()', mysql.dialect()), - ('random()', oracle.dialect()) + ('random', oracle.dialect()) ]: self.assert_compile(func.random(), ret, dialect=dialect) @@ -180,7 +180,10 @@ class CompileTest(TestBase, AssertsCompiledSQL): class ExecuteTest(TestBase): - + @engines.close_first + def tearDown(self): + pass + def test_standalone_execute(self): x = testing.db.func.current_date().execute().scalar() y = testing.db.func.current_date().select().execute().scalar() @@ -202,6 +205,7 @@ class ExecuteTest(TestBase): conn.close() assert (x == y == z) is True + @engines.close_first def test_update(self): """ Tests sending functions and SQL expressions to the VALUES and SET @@ -222,15 +226,15 @@ class ExecuteTest(TestBase): meta.create_all() try: t.insert(values=dict(value=func.length("one"))).execute() - assert t.select().execute().fetchone()['value'] == 3 + assert t.select().execute().first()['value'] == 3 t.update(values=dict(value=func.length("asfda"))).execute() - assert t.select().execute().fetchone()['value'] == 5 + assert t.select().execute().first()['value'] == 5 r = t.insert(values=dict(value=func.length("sfsaafsda"))).execute() - id = r.last_inserted_ids()[0] - assert t.select(t.c.id==id).execute().fetchone()['value'] == 9 + id = r.inserted_primary_key[0] + assert t.select(t.c.id==id).execute().first()['value'] == 9 t.update(values={t.c.value:func.length("asdf")}).execute() - assert t.select().execute().fetchone()['value'] == 4 + assert t.select().execute().first()['value'] == 4 print "--------------------------" t2.insert().execute() t2.insert(values=dict(value=func.length("one"))).execute() @@ -245,18 +249,18 @@ class ExecuteTest(TestBase): t2.delete().execute() t2.insert(values=dict(value=func.length("one") + 8)).execute() - assert t2.select().execute().fetchone()['value'] == 11 + assert t2.select().execute().first()['value'] == 11 t2.update(values=dict(value=func.length("asfda"))).execute() - assert select([t2.c.value, t2.c.stuff]).execute().fetchone() == (5, "thisisstuff") + assert select([t2.c.value, t2.c.stuff]).execute().first() == (5, "thisisstuff") t2.update(values={t2.c.value:func.length("asfdaasdf"), t2.c.stuff:"foo"}).execute() - print "HI", select([t2.c.value, t2.c.stuff]).execute().fetchone() - assert select([t2.c.value, t2.c.stuff]).execute().fetchone() == (9, "foo") + print "HI", select([t2.c.value, t2.c.stuff]).execute().first() + assert select([t2.c.value, t2.c.stuff]).execute().first() == (9, "foo") finally: meta.drop_all() - @testing.fails_on_everything_except('postgres') + @testing.fails_on_everything_except('postgresql') def test_as_from(self): # TODO: shouldnt this work on oracle too ? x = testing.db.func.current_date().execute().scalar() @@ -266,7 +270,7 @@ class ExecuteTest(TestBase): # construct a column-based FROM object out of a function, like in [ticket:172] s = select([sql.column('date', type_=DateTime)], from_obj=[testing.db.func.current_date()]) - q = s.execute().fetchone()[s.c.date] + q = s.execute().first()[s.c.date] r = s.alias('datequery').select().scalar() assert x == y == z == w == q == r @@ -301,7 +305,7 @@ class ExecuteTest(TestBase): 'd': datetime.date(2010, 5, 1) }) rs = select([extract('year', table.c.dt), extract('month', table.c.d)]).execute() - row = rs.fetchone() + row = rs.first() assert row[0] == 2010 assert row[1] == 5 rs.close() diff --git a/test/sql/test_labels.py b/test/sql/test_labels.py index b946b0ae98..bcac7c01d2 100644 --- a/test/sql/test_labels.py +++ b/test/sql/test_labels.py @@ -35,6 +35,7 @@ class LongLabelsTest(TestBase, AssertsCompiledSQL): maxlen = testing.db.dialect.max_identifier_length testing.db.dialect.max_identifier_length = IDENT_LENGTH + @engines.close_first def teardown(self): table1.delete().execute() @@ -92,10 +93,16 @@ class LongLabelsTest(TestBase, AssertsCompiledSQL): ], repr(result) def test_table_alias_names(self): - self.assert_compile( - table2.alias().select(), - "SELECT table_with_exactly_29_c_1.this_is_the_primarykey_column, table_with_exactly_29_c_1.this_is_the_data_column FROM table_with_exactly_29_characs AS table_with_exactly_29_c_1" - ) + if testing.against('oracle'): + self.assert_compile( + table2.alias().select(), + "SELECT table_with_exactly_29_c_1.this_is_the_primarykey_column, table_with_exactly_29_c_1.this_is_the_data_column FROM table_with_exactly_29_characs table_with_exactly_29_c_1" + ) + else: + self.assert_compile( + table2.alias().select(), + "SELECT table_with_exactly_29_c_1.this_is_the_primarykey_column, table_with_exactly_29_c_1.this_is_the_data_column FROM table_with_exactly_29_characs AS table_with_exactly_29_c_1" + ) ta = table2.alias() dialect = default.DefaultDialect() diff --git a/test/sql/test_query.py b/test/sql/test_query.py index 51b933e458..0e3b9dff20 100644 --- a/test/sql/test_query.py +++ b/test/sql/test_query.py @@ -1,9 +1,11 @@ +from sqlalchemy.test.testing import eq_ import datetime from sqlalchemy import * from sqlalchemy import exc, sql from sqlalchemy.engine import default from sqlalchemy.test import * -from sqlalchemy.test.testing import eq_ +from sqlalchemy.test.testing import eq_, assert_raises_message +from sqlalchemy.test.schema import Table, Column class QueryTest(TestBase): @@ -12,11 +14,11 @@ class QueryTest(TestBase): global users, users2, addresses, metadata metadata = MetaData(testing.db) users = Table('query_users', metadata, - Column('user_id', INT, primary_key = True), + Column('user_id', INT, primary_key=True, test_needs_autoincrement=True), Column('user_name', VARCHAR(20)), ) addresses = Table('query_addresses', metadata, - Column('address_id', Integer, primary_key=True), + Column('address_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('user_id', Integer, ForeignKey('query_users.user_id')), Column('address', String(30))) @@ -26,7 +28,8 @@ class QueryTest(TestBase): ) metadata.create_all() - def tearDown(self): + @engines.close_first + def teardown(self): addresses.delete().execute() users.delete().execute() users2.delete().execute() @@ -52,89 +55,133 @@ class QueryTest(TestBase): assert users.count().scalar() == 1 users.update(users.c.user_id == 7).execute(user_name = 'fred') - assert users.select(users.c.user_id==7).execute().fetchone()['user_name'] == 'fred' + assert users.select(users.c.user_id==7).execute().first()['user_name'] == 'fred' def test_lastrow_accessor(self): - """Tests the last_inserted_ids() and lastrow_has_id() functions.""" + """Tests the inserted_primary_key and lastrow_has_id() functions.""" - def insert_values(table, values): + def insert_values(engine, table, values): """ Inserts a row into a table, returns the full list of values INSERTed including defaults that fired off on the DB side and detects rows that had defaults and post-fetches. """ - result = table.insert().execute(**values) + result = engine.execute(table.insert(), **values) ret = values.copy() - for col, id in zip(table.primary_key, result.last_inserted_ids()): + for col, id in zip(table.primary_key, result.inserted_primary_key): ret[col.key] = id if result.lastrow_has_defaults(): - criterion = and_(*[col==id for col, id in zip(table.primary_key, result.last_inserted_ids())]) - row = table.select(criterion).execute().fetchone() + criterion = and_(*[col==id for col, id in zip(table.primary_key, result.inserted_primary_key)]) + row = engine.execute(table.select(criterion)).first() for c in table.c: ret[c.key] = row[c] return ret - for supported, table, values, assertvalues in [ - ( - {'unsupported':['sqlite']}, - Table("t1", metadata, - Column('id', Integer, Sequence('t1_id_seq', optional=True), primary_key=True), - Column('foo', String(30), primary_key=True)), - {'foo':'hi'}, - {'id':1, 'foo':'hi'} - ), - ( - {'unsupported':['sqlite']}, - Table("t2", metadata, - Column('id', Integer, Sequence('t2_id_seq', optional=True), primary_key=True), - Column('foo', String(30), primary_key=True), - Column('bar', String(30), server_default='hi') + if testing.against('firebird', 'postgresql', 'oracle', 'mssql'): + test_engines = [ + engines.testing_engine(options={'implicit_returning':False}), + engines.testing_engine(options={'implicit_returning':True}), + ] + else: + test_engines = [testing.db] + + for engine in test_engines: + metadata = MetaData() + for supported, table, values, assertvalues in [ + ( + {'unsupported':['sqlite']}, + Table("t1", metadata, + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('foo', String(30), primary_key=True)), + {'foo':'hi'}, + {'id':1, 'foo':'hi'} ), - {'foo':'hi'}, - {'id':1, 'foo':'hi', 'bar':'hi'} - ), - ( - {'unsupported':[]}, - Table("t3", metadata, - Column("id", String(40), primary_key=True), - Column('foo', String(30), primary_key=True), - Column("bar", String(30)) + ( + {'unsupported':['sqlite']}, + Table("t2", metadata, + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('foo', String(30), primary_key=True), + Column('bar', String(30), server_default='hi') ), - {'id':'hi', 'foo':'thisisfoo', 'bar':"thisisbar"}, - {'id':'hi', 'foo':'thisisfoo', 'bar':"thisisbar"} - ), - ( - {'unsupported':[]}, - Table("t4", metadata, - Column('id', Integer, Sequence('t4_id_seq', optional=True), primary_key=True), - Column('foo', String(30), primary_key=True), - Column('bar', String(30), server_default='hi') + {'foo':'hi'}, + {'id':1, 'foo':'hi', 'bar':'hi'} ), - {'foo':'hi', 'id':1}, - {'id':1, 'foo':'hi', 'bar':'hi'} - ), - ( - {'unsupported':[]}, - Table("t5", metadata, - Column('id', String(10), primary_key=True), - Column('bar', String(30), server_default='hi') + ( + {'unsupported':[]}, + Table("t3", metadata, + Column("id", String(40), primary_key=True), + Column('foo', String(30), primary_key=True), + Column("bar", String(30)) + ), + {'id':'hi', 'foo':'thisisfoo', 'bar':"thisisbar"}, + {'id':'hi', 'foo':'thisisfoo', 'bar':"thisisbar"} ), - {'id':'id1'}, - {'id':'id1', 'bar':'hi'}, - ), - ]: - if testing.db.name in supported['unsupported']: - continue - try: - table.create() - i = insert_values(table, values) - assert i == assertvalues, repr(i) + " " + repr(assertvalues) - finally: - table.drop() + ( + {'unsupported':[]}, + Table("t4", metadata, + Column('id', Integer, Sequence('t4_id_seq', optional=True), primary_key=True), + Column('foo', String(30), primary_key=True), + Column('bar', String(30), server_default='hi') + ), + {'foo':'hi', 'id':1}, + {'id':1, 'foo':'hi', 'bar':'hi'} + ), + ( + {'unsupported':[]}, + Table("t5", metadata, + Column('id', String(10), primary_key=True), + Column('bar', String(30), server_default='hi') + ), + {'id':'id1'}, + {'id':'id1', 'bar':'hi'}, + ), + ]: + if testing.db.name in supported['unsupported']: + continue + try: + table.create(bind=engine, checkfirst=True) + i = insert_values(engine, table, values) + assert i == assertvalues, "tablename: %s %r %r" % (table.name, repr(i), repr(assertvalues)) + finally: + table.drop(bind=engine) + + @testing.fails_on('sqlite', "sqlite autoincremnt doesn't work with composite pks") + def test_misordered_lastrow(self): + related = Table('related', metadata, + Column('id', Integer, primary_key=True) + ) + t6 = Table("t6", metadata, + Column('manual_id', Integer, ForeignKey('related.id'), primary_key=True), + Column('auto_id', Integer, primary_key=True, test_needs_autoincrement=True), + ) + metadata.create_all() + r = related.insert().values(id=12).execute() + id = r.inserted_primary_key[0] + assert id==12 + + r = t6.insert().values(manual_id=id).execute() + eq_(r.inserted_primary_key, [12, 1]) + + def test_autoclose_on_insert(self): + if testing.against('firebird', 'postgresql', 'oracle', 'mssql'): + test_engines = [ + engines.testing_engine(options={'implicit_returning':False}), + engines.testing_engine(options={'implicit_returning':True}), + ] + else: + test_engines = [testing.db] + + for engine in test_engines: + + r = engine.execute(users.insert(), + {'user_name':'jack'}, + ) + assert r.closed + def test_row_iteration(self): users.insert().execute( {'user_id':7, 'user_name':'jack'}, @@ -147,7 +194,7 @@ class QueryTest(TestBase): l.append(row) self.assert_(len(l) == 3) - @testing.fails_on('firebird', 'Data type unknown') + @testing.fails_on('firebird', "kinterbasdb doesn't send full type information") @testing.requires.subqueries def test_anonymous_rows(self): users.insert().execute( @@ -161,6 +208,7 @@ class QueryTest(TestBase): assert row['anon_1'] == 8 assert row['anon_2'] == 10 + @testing.fails_on('firebird', "kinterbasdb doesn't send full type information") def test_order_by_label(self): """test that a label within an ORDER BY works on each backend. @@ -179,6 +227,11 @@ class QueryTest(TestBase): select([concat]).order_by(concat).execute().fetchall(), [("test: ed",), ("test: fred",), ("test: jack",)] ) + + eq_( + select([concat]).order_by(concat).execute().fetchall(), + [("test: ed",), ("test: fred",), ("test: jack",)] + ) concat = ("test: " + users.c.user_name).label('thedata') eq_( @@ -195,7 +248,7 @@ class QueryTest(TestBase): def test_row_comparison(self): users.insert().execute(user_id = 7, user_name = 'jack') - rp = users.select().execute().fetchone() + rp = users.select().execute().first() self.assert_(rp == rp) self.assert_(not(rp != rp)) @@ -207,8 +260,7 @@ class QueryTest(TestBase): self.assert_(not (rp != equal)) self.assert_(not (equal != equal)) - @testing.fails_on('mssql', 'No support for boolean logic in column select.') - @testing.fails_on('oracle', 'FIXME: unknown') + @testing.requires.boolean_col_expressions def test_or_and_as_columns(self): true, false = literal(True), literal(False) @@ -218,11 +270,11 @@ class QueryTest(TestBase): eq_(testing.db.execute(select([or_(false, false)])).scalar(), False) eq_(testing.db.execute(select([not_(or_(false, false))])).scalar(), True) - row = testing.db.execute(select([or_(false, false).label("x"), and_(true, false).label("y")])).fetchone() + row = testing.db.execute(select([or_(false, false).label("x"), and_(true, false).label("y")])).first() assert row.x == False assert row.y == False - row = testing.db.execute(select([or_(true, false).label("x"), and_(true, false).label("y")])).fetchone() + row = testing.db.execute(select([or_(true, false).label("x"), and_(true, false).label("y")])).first() assert row.x == True assert row.y == False @@ -253,6 +305,9 @@ class QueryTest(TestBase): eq_(expr.execute().fetchall(), result) + @testing.fails_on("firebird", "see dialect.test_firebird:MiscTest.test_percents_in_text") + @testing.fails_on("oracle", "neither % nor %% are accepted") + @testing.fails_on("+pg8000", "can't interpret result column from '%%'") @testing.emits_warning('.*now automatically escapes.*') def test_percents_in_text(self): for expr, result in ( @@ -277,7 +332,7 @@ class QueryTest(TestBase): eq_(select([users.c.user_id]).where(users.c.user_name.ilike('TWO')).execute().fetchall(), [(2, )]) - if testing.against('postgres'): + if testing.against('postgresql'): eq_(select([users.c.user_id]).where(users.c.user_name.like('one')).execute().fetchall(), [(1, )]) eq_(select([users.c.user_id]).where(users.c.user_name.like('TWO')).execute().fetchall(), []) @@ -373,7 +428,7 @@ class QueryTest(TestBase): s = select([datetable.alias('x').c.today]).as_scalar() s2 = select([datetable.c.id, s.label('somelabel')]) #print s2.c.somelabel.type - assert isinstance(s2.execute().fetchone()['somelabel'], datetime.datetime) + assert isinstance(s2.execute().first()['somelabel'], datetime.datetime) finally: datetable.drop() @@ -444,45 +499,58 @@ class QueryTest(TestBase): users.insert().execute(user_id=2, user_name='jack') addresses.insert().execute(address_id=1, user_id=2, address='foo@bar.com') - r = users.select(users.c.user_id==2).execute().fetchone() + r = users.select(users.c.user_id==2).execute().first() self.assert_(r.user_id == r['user_id'] == r[users.c.user_id] == 2) self.assert_(r.user_name == r['user_name'] == r[users.c.user_name] == 'jack') - - r = text("select * from query_users where user_id=2", bind=testing.db).execute().fetchone() + + r = text("select * from query_users where user_id=2", bind=testing.db).execute().first() self.assert_(r.user_id == r['user_id'] == r[users.c.user_id] == 2) self.assert_(r.user_name == r['user_name'] == r[users.c.user_name] == 'jack') - + # test slices - r = text("select * from query_addresses", bind=testing.db).execute().fetchone() + r = text("select * from query_addresses", bind=testing.db).execute().first() self.assert_(r[0:1] == (1,)) self.assert_(r[1:] == (2, 'foo@bar.com')) self.assert_(r[:-1] == (1, 2)) - + # test a little sqlite weirdness - with the UNION, cols come back as "query_users.user_id" in cursor.description r = text("select query_users.user_id, query_users.user_name from query_users " - "UNION select query_users.user_id, query_users.user_name from query_users", bind=testing.db).execute().fetchone() + "UNION select query_users.user_id, query_users.user_name from query_users", bind=testing.db).execute().first() self.assert_(r['user_id']) == 1 self.assert_(r['user_name']) == "john" # test using literal tablename.colname - r = text('select query_users.user_id AS "query_users.user_id", query_users.user_name AS "query_users.user_name" from query_users', bind=testing.db).execute().fetchone() + r = text('select query_users.user_id AS "query_users.user_id", ' + 'query_users.user_name AS "query_users.user_name" from query_users', + bind=testing.db).execute().first() self.assert_(r['query_users.user_id']) == 1 self.assert_(r['query_users.user_name']) == "john" # unary experssions - r = select([users.c.user_name.distinct()]).order_by(users.c.user_name).execute().fetchone() + r = select([users.c.user_name.distinct()]).order_by(users.c.user_name).execute().first() eq_(r[users.c.user_name], 'jack') eq_(r.user_name, 'jack') - r.close() + + def test_result_case_sensitivity(self): + """test name normalization for result sets.""" + row = testing.db.execute( + select([ + literal_column("1").label("case_insensitive"), + literal_column("2").label("CaseSensitive") + ]) + ).first() + + assert row.keys() == ["case_insensitive", "CaseSensitive"] + def test_row_as_args(self): users.insert().execute(user_id=1, user_name='john') - r = users.select(users.c.user_id==1).execute().fetchone() + r = users.select(users.c.user_id==1).execute().first() users.delete().execute() users.insert().execute(r) - assert users.select().execute().fetchall() == [(1, 'john')] - + eq_(users.select().execute().fetchall(), [(1, 'john')]) + def test_result_as_args(self): users.insert().execute([dict(user_id=1, user_name='john'), dict(user_id=2, user_name='ed')]) r = users.select().execute() @@ -496,13 +564,12 @@ class QueryTest(TestBase): def test_ambiguous_column(self): users.insert().execute(user_id=1, user_name='john') - r = users.outerjoin(addresses).select().execute().fetchone() - try: - print r['user_id'] - assert False - except exc.InvalidRequestError, e: - assert str(e) == "Ambiguous column name 'user_id' in result set! try 'use_labels' option on select statement." or \ - str(e) == "Ambiguous column name 'USER_ID' in result set! try 'use_labels' option on select statement." + r = users.outerjoin(addresses).select().execute().first() + assert_raises_message( + exc.InvalidRequestError, + "Ambiguous column name", + lambda: r['user_id'] + ) @testing.requires.subqueries def test_column_label_targeting(self): @@ -512,31 +579,29 @@ class QueryTest(TestBase): users.select().alias('foo'), users.select().alias(users.name), ): - row = s.select(use_labels=True).execute().fetchone() + row = s.select(use_labels=True).execute().first() assert row[s.c.user_id] == 7 assert row[s.c.user_name] == 'ed' def test_keys(self): users.insert().execute(user_id=1, user_name='foo') - r = users.select().execute().fetchone() + r = users.select().execute().first() eq_([x.lower() for x in r.keys()], ['user_id', 'user_name']) def test_items(self): users.insert().execute(user_id=1, user_name='foo') - r = users.select().execute().fetchone() + r = users.select().execute().first() eq_([(x[0].lower(), x[1]) for x in r.items()], [('user_id', 1), ('user_name', 'foo')]) def test_len(self): users.insert().execute(user_id=1, user_name='foo') - r = users.select().execute().fetchone() + r = users.select().execute().first() eq_(len(r), 2) - r.close() - r = testing.db.execute('select user_name, user_id from query_users').fetchone() + + r = testing.db.execute('select user_name, user_id from query_users').first() eq_(len(r), 2) - r.close() - r = testing.db.execute('select user_name from query_users').fetchone() + r = testing.db.execute('select user_name from query_users').first() eq_(len(r), 1) - r.close() def test_cant_execute_join(self): try: @@ -549,7 +614,7 @@ class QueryTest(TestBase): def test_column_order_with_simple_query(self): # should return values in column definition order users.insert().execute(user_id=1, user_name='foo') - r = users.select(users.c.user_id==1).execute().fetchone() + r = users.select(users.c.user_id==1).execute().first() eq_(r[0], 1) eq_(r[1], 'foo') eq_([x.lower() for x in r.keys()], ['user_id', 'user_name']) @@ -558,7 +623,7 @@ class QueryTest(TestBase): def test_column_order_with_text_query(self): # should return values in query order users.insert().execute(user_id=1, user_name='foo') - r = testing.db.execute('select user_name, user_id from query_users').fetchone() + r = testing.db.execute('select user_name, user_id from query_users').first() eq_(r[0], 'foo') eq_(r[1], 1) eq_([x.lower() for x in r.keys()], ['user_name', 'user_id']) @@ -580,7 +645,7 @@ class QueryTest(TestBase): shadowed.create(checkfirst=True) try: shadowed.insert().execute(shadow_id=1, shadow_name='The Shadow', parent='The Light', row='Without light there is no shadow', __parent='Hidden parent', __row='Hidden row') - r = shadowed.select(shadowed.c.shadow_id==1).execute().fetchone() + r = shadowed.select(shadowed.c.shadow_id==1).execute().first() self.assert_(r.shadow_id == r['shadow_id'] == r[shadowed.c.shadow_id] == 1) self.assert_(r.shadow_name == r['shadow_name'] == r[shadowed.c.shadow_name] == 'The Shadow') self.assert_(r.parent == r['parent'] == r[shadowed.c.parent] == 'The Light') @@ -622,13 +687,13 @@ class QueryTest(TestBase): # Null values are not outside any set assert len(r) == 0 - u = bindparam('search_key') + @testing.fails_on('firebird', "kinterbasdb doesn't send full type information") + def test_bind_in(self): + users.insert().execute(user_id = 7, user_name = 'jack') + users.insert().execute(user_id = 8, user_name = 'fred') + users.insert().execute(user_id = 9, user_name = None) - s = users.select(u.in_([])) - r = s.execute(search_key='john').fetchall() - assert len(r) == 0 - r = s.execute(search_key=None).fetchall() - assert len(r) == 0 + u = bindparam('search_key') s = users.select(not_(u.in_([]))) r = s.execute(search_key='john').fetchall() @@ -660,14 +725,15 @@ class QueryTest(TestBase): class PercentSchemaNamesTest(TestBase): """tests using percent signs, spaces in table and column names. - Doesn't pass for mysql, postgres, but this is really a + Doesn't pass for mysql, postgresql, but this is really a SQLAlchemy bug - we should be escaping out %% signs for this operation the same way we do for text() and column labels. """ + @classmethod @testing.crashes('mysql', 'mysqldb calls name % (params)') - @testing.crashes('postgres', 'postgres calls name % (params)') + @testing.crashes('postgresql', 'postgresql calls name % (params)') def setup_class(cls): global percent_table, metadata metadata = MetaData(testing.db) @@ -680,12 +746,12 @@ class PercentSchemaNamesTest(TestBase): @classmethod @testing.crashes('mysql', 'mysqldb calls name % (params)') - @testing.crashes('postgres', 'postgres calls name % (params)') + @testing.crashes('postgresql', 'postgresql calls name % (params)') def teardown_class(cls): metadata.drop_all() @testing.crashes('mysql', 'mysqldb calls name % (params)') - @testing.crashes('postgres', 'postgres calls name % (params)') + @testing.crashes('postgresql', 'postgresql calls name % (params)') def test_roundtrip(self): percent_table.insert().execute( {'percent%':5, '%(oneofthese)s':7, 'spaces % more spaces':12}, @@ -731,7 +797,7 @@ class PercentSchemaNamesTest(TestBase): percent_table.update().values({percent_table.c['%(oneofthese)s']:9, percent_table.c['spaces % more spaces']:15}).execute() eq_( - percent_table.select().order_by(percent_table.c['%(oneofthese)s']).execute().fetchall(), + percent_table.select().order_by(percent_table.c['percent%']).execute().fetchall(), [ (5, 9, 15), (7, 9, 15), @@ -852,7 +918,11 @@ class CompoundTest(TestBase): dict(col2="t3col2r2", col3="bbb", col4="aaa"), dict(col2="t3col2r3", col3="ccc", col4="bbb"), ]) - + + @engines.close_first + def teardown(self): + pass + @classmethod def teardown_class(cls): metadata.drop_all() @@ -878,6 +948,7 @@ class CompoundTest(TestBase): found2 = self._fetchall_sorted(u.alias('bar').select().execute()) eq_(found2, wanted) + @testing.fails_on('firebird', "doesn't like ORDER BY with UNIONs") def test_union_ordered(self): (s1, s2) = ( select([t1.c.col3.label('col3'), t1.c.col4.label('col4')], @@ -891,6 +962,7 @@ class CompoundTest(TestBase): ('ccc', 'aaa')] eq_(u.execute().fetchall(), wanted) + @testing.fails_on('firebird', "doesn't like ORDER BY with UNIONs") @testing.fails_on('maxdb', 'FIXME: unknown') @testing.requires.subqueries def test_union_ordered_alias(self): @@ -907,6 +979,7 @@ class CompoundTest(TestBase): eq_(u.alias('bar').select().execute().fetchall(), wanted) @testing.crashes('oracle', 'FIXME: unknown, verify not fails_on') + @testing.fails_on('firebird', "has trouble extracting anonymous column from union subquery") @testing.fails_on('mysql', 'FIXME: unknown') @testing.fails_on('sqlite', 'FIXME: unknown') def test_union_all(self): @@ -925,6 +998,29 @@ class CompoundTest(TestBase): found2 = self._fetchall_sorted(e.alias('foo').select().execute()) eq_(found2, wanted) + def test_union_all_lightweight(self): + """like test_union_all, but breaks the sub-union into + a subquery with an explicit column reference on the outside, + more palatable to a wider variety of engines. + + """ + u = union( + select([t1.c.col3]), + select([t1.c.col3]), + ).alias() + + e = union_all( + select([t1.c.col3]), + select([u.c.col3]) + ) + + wanted = [('aaa',),('aaa',),('bbb',), ('bbb',), ('ccc',),('ccc',)] + found1 = self._fetchall_sorted(e.execute()) + eq_(found1, wanted) + + found2 = self._fetchall_sorted(e.alias('foo').select().execute()) + eq_(found2, wanted) + @testing.crashes('firebird', 'Does not support intersect') @testing.crashes('sybase', 'FIXME: unknown, verify not fails_on') @testing.fails_on('mysql', 'FIXME: unknown') @@ -1330,3 +1426,6 @@ class OperatorTest(TestBase): order_by=flds.c.idcol).execute().fetchall(), [(2,),(1,)] ) + + + diff --git a/test/sql/test_quote.py b/test/sql/test_quote.py index 64e097b85f..3198a07af4 100644 --- a/test/sql/test_quote.py +++ b/test/sql/test_quote.py @@ -129,7 +129,7 @@ class QuoteTest(TestBase, AssertsCompiledSQL): def testlabels(self): """test the quoting of labels. - if labels arent quoted, a query in postgres in particular will fail since it produces: + if labels arent quoted, a query in postgresql in particular will fail since it produces: SELECT LaLa.lowercase, LaLa."UPPERCASE", LaLa."MixedCase", LaLa."ASC" FROM (SELECT DISTINCT "WorstCase1".lowercase AS lowercase, "WorstCase1"."UPPERCASE" AS UPPERCASE, "WorstCase1"."MixedCase" AS MixedCase, "WorstCase1"."ASC" AS ASC \nFROM "WorstCase1") AS LaLa diff --git a/test/sql/test_returning.py b/test/sql/test_returning.py new file mode 100644 index 0000000000..e076f3fe7c --- /dev/null +++ b/test/sql/test_returning.py @@ -0,0 +1,159 @@ +from sqlalchemy.test.testing import eq_ +from sqlalchemy import * +from sqlalchemy.test import * +from sqlalchemy.test.schema import Table, Column +from sqlalchemy.types import TypeDecorator + + +class ReturningTest(TestBase, AssertsExecutionResults): + __unsupported_on__ = ('sqlite', 'mysql', 'maxdb', 'sybase', 'access') + + def setup(self): + meta = MetaData(testing.db) + global table, GoofyType + + class GoofyType(TypeDecorator): + impl = String + + def process_bind_param(self, value, dialect): + if value is None: + return None + return "FOO" + value + + def process_result_value(self, value, dialect): + if value is None: + return None + return value + "BAR" + + table = Table('tables', meta, + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('persons', Integer), + Column('full', Boolean), + Column('goofy', GoofyType(50)) + ) + table.create(checkfirst=True) + + def teardown(self): + table.drop() + + @testing.exclude('firebird', '<', (2, 0), '2.0+ feature') + @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature') + def test_column_targeting(self): + result = table.insert().returning(table.c.id, table.c.full).execute({'persons': 1, 'full': False}) + + row = result.first() + assert row[table.c.id] == row['id'] == 1 + assert row[table.c.full] == row['full'] == False + + result = table.insert().values(persons=5, full=True, goofy="somegoofy").\ + returning(table.c.persons, table.c.full, table.c.goofy).execute() + row = result.first() + assert row[table.c.persons] == row['persons'] == 5 + assert row[table.c.full] == row['full'] == True + assert row[table.c.goofy] == row['goofy'] == "FOOsomegoofyBAR" + + @testing.fails_on('firebird', "fb can't handle returning x AS y") + @testing.exclude('firebird', '<', (2, 0), '2.0+ feature') + @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature') + def test_labeling(self): + result = table.insert().values(persons=6).\ + returning(table.c.persons.label('lala')).execute() + row = result.first() + assert row['lala'] == 6 + + @testing.fails_on('firebird', "fb/kintersbasdb can't handle the bind params") + @testing.exclude('firebird', '<', (2, 0), '2.0+ feature') + @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature') + def test_anon_expressions(self): + result = table.insert().values(goofy="someOTHERgoofy").\ + returning(func.lower(table.c.goofy, type_=GoofyType)).execute() + row = result.first() + assert row[0] == "foosomeothergoofyBAR" + + result = table.insert().values(persons=12).\ + returning(table.c.persons + 18).execute() + row = result.first() + assert row[0] == 30 + + @testing.exclude('firebird', '<', (2, 1), '2.1+ feature') + @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature') + def test_update_returning(self): + table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}]) + + result = table.update(table.c.persons > 4, dict(full=True)).returning(table.c.id).execute() + eq_(result.fetchall(), [(1,)]) + + result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute() + eq_(result2.fetchall(), [(1,True),(2,False)]) + + @testing.exclude('firebird', '<', (2, 0), '2.0+ feature') + @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature') + def test_insert_returning(self): + result = table.insert().returning(table.c.id).execute({'persons': 1, 'full': False}) + + eq_(result.fetchall(), [(1,)]) + + @testing.fails_on('postgresql', '') + @testing.fails_on('oracle', '') + def test_executemany(): + # return value is documented as failing with psycopg2/executemany + result2 = table.insert().returning(table).execute( + [{'persons': 2, 'full': False}, {'persons': 3, 'full': True}]) + + if testing.against('firebird', 'mssql'): + # Multiple inserts only return the last row + eq_(result2.fetchall(), [(3,3,True, None)]) + else: + # nobody does this as far as we know (pg8000?) + eq_(result2.fetchall(), [(2, 2, False, None), (3,3,True, None)]) + + test_executemany() + + result3 = table.insert().returning(table.c.id).execute({'persons': 4, 'full': False}) + eq_([dict(row) for row in result3], [{'id': 4}]) + + + @testing.exclude('firebird', '<', (2, 1), '2.1+ feature') + @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature') + @testing.fails_on_everything_except('postgresql', 'firebird') + def test_literal_returning(self): + if testing.against("postgresql"): + literal_true = "true" + else: + literal_true = "1" + + result4 = testing.db.execute('insert into tables (id, persons, "full") ' + 'values (5, 10, %s) returning persons' % literal_true) + eq_([dict(row) for row in result4], [{'persons': 10}]) + + @testing.exclude('firebird', '<', (2, 1), '2.1+ feature') + @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature') + def test_delete_returning(self): + table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}]) + + result = table.delete(table.c.persons > 4).returning(table.c.id).execute() + eq_(result.fetchall(), [(1,)]) + + result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute() + eq_(result2.fetchall(), [(2,False),]) + +class SequenceReturningTest(TestBase): + __unsupported_on__ = ('sqlite', 'mysql', 'maxdb', 'sybase', 'access', 'mssql') + + def setup(self): + meta = MetaData(testing.db) + global table, seq + seq = Sequence('tid_seq') + table = Table('tables', meta, + Column('id', Integer, seq, primary_key=True), + Column('data', String(50)) + ) + table.create(checkfirst=True) + + def teardown(self): + table.drop() + + def test_insert(self): + r = table.insert().values(data='hi').returning(table.c.id).execute() + assert r.first() == (1, ) + assert seq.execute() == 2 diff --git a/test/sql/test_select.py b/test/sql/test_select.py index f70492fb31..9acc94eb28 100644 --- a/test/sql/test_select.py +++ b/test/sql/test_select.py @@ -5,7 +5,7 @@ from sqlalchemy import exc, sql, util from sqlalchemy.sql import table, column, label, compiler from sqlalchemy.sql.expression import ClauseList from sqlalchemy.engine import default -from sqlalchemy.databases import sqlite, postgres, mysql, oracle, firebird, mssql +from sqlalchemy.databases import * from sqlalchemy.test import * table1 = table('mytable', @@ -149,12 +149,10 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A ) self.assert_compile( - select([cast("data", sqlite.SLInteger)], use_labels=True), # this will work with plain Integer in 0.6 + select([cast("data", Integer)], use_labels=True), # this will work with plain Integer in 0.6 "SELECT CAST(:param_1 AS INTEGER) AS anon_1" ) - - def test_nested_uselabels(self): """test nested anonymous label generation. this essentially tests the ANONYMOUS_LABEL regex. @@ -429,7 +427,12 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A def test_operators(self): for (py_op, sql_op) in ((operator.add, '+'), (operator.mul, '*'), - (operator.sub, '-'), (operator.div, '/'), + (operator.sub, '-'), + # Py3K + #(operator.truediv, '/'), + # Py2K + (operator.div, '/'), + # end Py2K ): for (lhs, rhs, res) in ( (5, table1.c.myid, ':myid_1 %s mytable.myid'), @@ -519,22 +522,22 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A (~table1.c.myid.like('somstr', escape='\\'), "mytable.myid NOT LIKE :myid_1 ESCAPE '\\'", None), (table1.c.myid.ilike('somstr', escape='\\'), "lower(mytable.myid) LIKE lower(:myid_1) ESCAPE '\\'", None), (~table1.c.myid.ilike('somstr', escape='\\'), "lower(mytable.myid) NOT LIKE lower(:myid_1) ESCAPE '\\'", None), - (table1.c.myid.ilike('somstr', escape='\\'), "mytable.myid ILIKE %(myid_1)s ESCAPE '\\'", postgres.PGDialect()), - (~table1.c.myid.ilike('somstr', escape='\\'), "mytable.myid NOT ILIKE %(myid_1)s ESCAPE '\\'", postgres.PGDialect()), + (table1.c.myid.ilike('somstr', escape='\\'), "mytable.myid ILIKE %(myid_1)s ESCAPE '\\'", postgresql.PGDialect()), + (~table1.c.myid.ilike('somstr', escape='\\'), "mytable.myid NOT ILIKE %(myid_1)s ESCAPE '\\'", postgresql.PGDialect()), (table1.c.name.ilike('%something%'), "lower(mytable.name) LIKE lower(:name_1)", None), - (table1.c.name.ilike('%something%'), "mytable.name ILIKE %(name_1)s", postgres.PGDialect()), + (table1.c.name.ilike('%something%'), "mytable.name ILIKE %(name_1)s", postgresql.PGDialect()), (~table1.c.name.ilike('%something%'), "lower(mytable.name) NOT LIKE lower(:name_1)", None), - (~table1.c.name.ilike('%something%'), "mytable.name NOT ILIKE %(name_1)s", postgres.PGDialect()), + (~table1.c.name.ilike('%something%'), "mytable.name NOT ILIKE %(name_1)s", postgresql.PGDialect()), ]: self.assert_compile(expr, check, dialect=dialect) def test_match(self): for expr, check, dialect in [ (table1.c.myid.match('somstr'), "mytable.myid MATCH ?", sqlite.SQLiteDialect()), - (table1.c.myid.match('somstr'), "MATCH (mytable.myid) AGAINST (%s IN BOOLEAN MODE)", mysql.MySQLDialect()), - (table1.c.myid.match('somstr'), "CONTAINS (mytable.myid, :myid_1)", mssql.MSSQLDialect()), - (table1.c.myid.match('somstr'), "mytable.myid @@ to_tsquery(%(myid_1)s)", postgres.PGDialect()), - (table1.c.myid.match('somstr'), "CONTAINS (mytable.myid, :myid_1)", oracle.OracleDialect()), + (table1.c.myid.match('somstr'), "MATCH (mytable.myid) AGAINST (%s IN BOOLEAN MODE)", mysql.dialect()), + (table1.c.myid.match('somstr'), "CONTAINS (mytable.myid, :myid_1)", mssql.dialect()), + (table1.c.myid.match('somstr'), "mytable.myid @@ to_tsquery(%(myid_1)s)", postgresql.dialect()), + (table1.c.myid.match('somstr'), "CONTAINS (mytable.myid, :myid_1)", oracle.dialect()), ]: self.assert_compile(expr, check, dialect=dialect) @@ -635,7 +638,7 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A select([table1.alias('foo')]) ,"SELECT foo.myid, foo.name, foo.description FROM mytable AS foo") - for dialect in (firebird.dialect(), oracle.dialect()): + for dialect in (oracle.dialect(),): self.assert_compile( select([table1.alias('foo')]) ,"SELECT foo.myid, foo.name, foo.description FROM mytable foo" @@ -748,7 +751,7 @@ WHERE mytable.myid = myothertable.otherid) AS t2view WHERE t2view.mytable_myid = params={}, ) - dialect = postgres.dialect() + dialect = postgresql.dialect() self.assert_compile( text("select * from foo where lala=:bar and hoho=:whee", bindparams=[bindparam('bar',4), bindparam('whee',7)]), "select * from foo where lala=%(bar)s and hoho=%(whee)s", @@ -1122,10 +1125,10 @@ UNION SELECT mytable.myid FROM mytable" self.assert_compile(stmt, expected_positional_stmt, dialect=sqlite.dialect()) nonpositional = stmt.compile() positional = stmt.compile(dialect=sqlite.dialect()) - pp = positional.get_params() + pp = positional.params assert [pp[k] for k in positional.positiontup] == expected_default_params_list - assert nonpositional.get_params(**test_param_dict) == expected_test_params_dict, "expected :%s got %s" % (str(expected_test_params_dict), str(nonpositional.get_params(**test_param_dict))) - pp = positional.get_params(**test_param_dict) + assert nonpositional.construct_params(test_param_dict) == expected_test_params_dict, "expected :%s got %s" % (str(expected_test_params_dict), str(nonpositional.get_params(**test_param_dict))) + pp = positional.construct_params(test_param_dict) assert [pp[k] for k in positional.positiontup] == expected_test_params_list # check that params() doesnt modify original statement @@ -1144,7 +1147,7 @@ UNION SELECT mytable.myid FROM mytable" ":myid_1) AS anon_1 FROM mytable WHERE mytable.myid = (SELECT mytable.myid FROM mytable WHERE mytable.myid = :myid_1)") positional = s2.compile(dialect=sqlite.dialect()) - pp = positional.get_params() + pp = positional.params assert [pp[k] for k in positional.positiontup] == [12, 12] # check that conflicts with "unique" params are caught @@ -1163,11 +1166,11 @@ UNION SELECT mytable.myid FROM mytable" params = dict(('in%d' % i, i) for i in range(total_params)) sql = 'text clause %s' % ', '.join(in_clause) t = text(sql) - assert len(t.bindparams) == total_params + eq_(len(t.bindparams), total_params) c = t.compile() pp = c.construct_params(params) - assert len(set(pp)) == total_params - assert len(set(pp.values())) == total_params + eq_(len(set(pp)), total_params, '%s %s' % (len(set(pp)), len(pp))) + eq_(len(set(pp.values())), total_params) def test_bind_as_col(self): @@ -1291,28 +1294,28 @@ UNION SELECT mytable.myid FROM mytable WHERE mytable.myid = :myid_2)") eq_(str(cast(tbl.c.v1, Numeric).compile(dialect=dialect)), 'CAST(casttest.v1 AS %s)' %expected_results[0]) eq_(str(cast(tbl.c.v1, Numeric(12, 9)).compile(dialect=dialect)), 'CAST(casttest.v1 AS %s)' %expected_results[1]) eq_(str(cast(tbl.c.ts, Date).compile(dialect=dialect)), 'CAST(casttest.ts AS %s)' %expected_results[2]) - eq_(str(cast(1234, TEXT).compile(dialect=dialect)), 'CAST(%s AS %s)' %(literal, expected_results[3])) + eq_(str(cast(1234, Text).compile(dialect=dialect)), 'CAST(%s AS %s)' %(literal, expected_results[3])) eq_(str(cast('test', String(20)).compile(dialect=dialect)), 'CAST(%s AS %s)' %(literal, expected_results[4])) # fixme: shoving all of this dialect-specific stuff in one test # is now officialy completely ridiculous AND non-obviously omits # coverage on other dialects. sel = select([tbl, cast(tbl.c.v1, Numeric)]).compile(dialect=dialect) if isinstance(dialect, type(mysql.dialect())): - eq_(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS DECIMAL(10, 2)) AS anon_1 \nFROM casttest") + eq_(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS DECIMAL) AS anon_1 \nFROM casttest") else: - eq_(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS NUMERIC(10, 2)) AS anon_1 \nFROM casttest") + eq_(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS NUMERIC) AS anon_1 \nFROM casttest") # first test with PostgreSQL engine - check_results(postgres.dialect(), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '%(param_1)s') + check_results(postgresql.dialect(), ['NUMERIC', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '%(param_1)s') # then the Oracle engine - check_results(oracle.dialect(), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'CLOB', 'VARCHAR(20)'], ':param_1') + check_results(oracle.dialect(), ['NUMERIC', 'NUMERIC(12, 9)', 'DATE', 'CLOB', 'VARCHAR(20)'], ':param_1') # then the sqlite engine - check_results(sqlite.dialect(), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '?') + check_results(sqlite.dialect(), ['NUMERIC', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '?') # then the MySQL engine - check_results(mysql.dialect(), ['DECIMAL(10, 2)', 'DECIMAL(12, 9)', 'DATE', 'CHAR', 'CHAR(20)'], '%s') + check_results(mysql.dialect(), ['DECIMAL', 'DECIMAL(12, 9)', 'DATE', 'CHAR', 'CHAR(20)'], '%s') self.assert_compile(cast(text('NULL'), Integer), "CAST(NULL AS INTEGER)", dialect=sqlite.dialect()) self.assert_compile(cast(null(), Integer), "CAST(NULL AS INTEGER)", dialect=sqlite.dialect()) @@ -1360,7 +1363,6 @@ UNION SELECT mytable.myid FROM mytable WHERE mytable.myid = :myid_2)") s1 = select([table1.c.myid, table1.c.myid.label('foobar'), func.hoho(table1.c.name), func.lala(table1.c.name).label('gg')]) assert s1.c.keys() == ['myid', 'foobar', 'hoho(mytable.name)', 'gg'] - from sqlalchemy.databases.sqlite import SLNumeric meta = MetaData() t1 = Table('mytable', meta, Column('col1', Integer)) @@ -1368,7 +1370,7 @@ UNION SELECT mytable.myid FROM mytable WHERE mytable.myid = :myid_2)") (table1.c.name, 'name', 'mytable.name', None), (table1.c.myid==12, 'mytable.myid = :myid_1', 'mytable.myid = :myid_1', 'anon_1'), (func.hoho(table1.c.myid), 'hoho(mytable.myid)', 'hoho(mytable.myid)', 'hoho_1'), - (cast(table1.c.name, SLNumeric), 'CAST(mytable.name AS NUMERIC(10, 2))', 'CAST(mytable.name AS NUMERIC(10, 2))', 'anon_1'), + (cast(table1.c.name, Numeric), 'CAST(mytable.name AS NUMERIC)', 'CAST(mytable.name AS NUMERIC)', 'anon_1'), (t1.c.col1, 'col1', 'mytable.col1', None), (column('some wacky thing'), 'some wacky thing', '"some wacky thing"', '') ): diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py index b0501c9134..95ca0d17bf 100644 --- a/test/sql/test_selectable.py +++ b/test/sql/test_selectable.py @@ -416,7 +416,7 @@ class ReduceTest(TestBase, AssertsExecutionResults): Column('magazine_page_id', Integer, ForeignKey('magazine_page.page_id'), primary_key=True), ) - # this is essentially the union formed by the ORM's polymorphic_union function. + # this is essentially the union formed by the ORM's polymorphic_union function. # we define two versions with different ordering of selects. # the first selectable has the "real" column classified_page.magazine_page_id @@ -432,7 +432,6 @@ class ReduceTest(TestBase, AssertsExecutionResults): magazine_page_table.c.page_id, cast(null(), Integer).label('magazine_page_id') ]).select_from(page_table.join(magazine_page_table)), - ).alias('pjoin') eq_( diff --git a/test/sql/test_types.py b/test/sql/test_types.py index 15799358a7..9c90549e29 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -1,101 +1,63 @@ +# coding: utf-8 from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message import decimal import datetime, os, re from sqlalchemy import * -from sqlalchemy import exc, types, util +from sqlalchemy import exc, types, util, schema from sqlalchemy.sql import operators from sqlalchemy.test.testing import eq_ import sqlalchemy.engine.url as url -from sqlalchemy.databases import mssql, oracle, mysql, postgres, firebird +from sqlalchemy.databases import * + from sqlalchemy.test import * class AdaptTest(TestBase): - def testadapt(self): - e1 = url.URL('postgres').get_dialect()() - e2 = url.URL('mysql').get_dialect()() - e3 = url.URL('sqlite').get_dialect()() - e4 = url.URL('firebird').get_dialect()() - - type = String(40) - - t1 = type.dialect_impl(e1) - t2 = type.dialect_impl(e2) - t3 = type.dialect_impl(e3) - t4 = type.dialect_impl(e4) - - impls = [t1, t2, t3, t4] - for i,ta in enumerate(impls): - for j,tb in enumerate(impls): - if i == j: - assert ta == tb # call me paranoid... :) + def test_uppercase_rendering(self): + """Test that uppercase types from types.py always render as their type. + + As of SQLA 0.6, using an uppercase type means you want specifically that + type. If the database in use doesn't support that DDL, it (the DB backend) + should raise an error - it means you should be using a lowercased (genericized) type. + + """ + + for dialect in [ + oracle.dialect(), + mysql.dialect(), + postgresql.dialect(), + sqlite.dialect(), + sybase.dialect(), + informix.dialect(), + maxdb.dialect(), + mssql.dialect()]: # TODO when dialects are complete: engines.all_dialects(): + for type_, expected in ( + (FLOAT, "FLOAT"), + (NUMERIC, "NUMERIC"), + (DECIMAL, "DECIMAL"), + (INTEGER, "INTEGER"), + (SMALLINT, "SMALLINT"), + (TIMESTAMP, "TIMESTAMP"), + (DATETIME, "DATETIME"), + (DATE, "DATE"), + (TIME, "TIME"), + (CLOB, "CLOB"), + (VARCHAR, "VARCHAR"), + (NVARCHAR, ("NVARCHAR", "NATIONAL VARCHAR")), + (CHAR, "CHAR"), + (NCHAR, ("NCHAR", "NATIONAL CHAR")), + (BLOB, "BLOB"), + (BOOLEAN, ("BOOLEAN", "BOOL")) + ): + if isinstance(expected, str): + expected = (expected, ) + for exp in expected: + compiled = type_().compile(dialect=dialect) + if exp in compiled: + break else: - assert ta != tb - - def testmsnvarchar(self): - dialect = mssql.MSSQLDialect() - # run the test twice to ensure the caching step works too - for x in range(0, 1): - col = Column('', Unicode(length=10)) - dialect_type = col.type.dialect_impl(dialect) - assert isinstance(dialect_type, mssql.MSNVarchar) - assert dialect_type.get_col_spec() == 'NVARCHAR(10)' - - - def testoracletimestamp(self): - dialect = oracle.OracleDialect() - t1 = oracle.OracleTimestamp - t2 = oracle.OracleTimestamp() - t3 = types.TIMESTAMP - assert isinstance(dialect.type_descriptor(t1), oracle.OracleTimestamp) - assert isinstance(dialect.type_descriptor(t2), oracle.OracleTimestamp) - assert isinstance(dialect.type_descriptor(t3), oracle.OracleTimestamp) - - def testmysqlbinary(self): - dialect = mysql.MySQLDialect() - t1 = mysql.MSVarBinary - t2 = mysql.MSVarBinary() - assert isinstance(dialect.type_descriptor(t1), mysql.MSVarBinary) - assert isinstance(dialect.type_descriptor(t2), mysql.MSVarBinary) - - def teststringadapt(self): - """test that String with no size becomes TEXT, *all* others stay as varchar/String""" - - oracle_dialect = oracle.OracleDialect() - mysql_dialect = mysql.MySQLDialect() - postgres_dialect = postgres.PGDialect() - firebird_dialect = firebird.FBDialect() - - for dialect, start, test in [ - (oracle_dialect, String(), oracle.OracleString), - (oracle_dialect, VARCHAR(), oracle.OracleString), - (oracle_dialect, String(50), oracle.OracleString), - (oracle_dialect, Unicode(), oracle.OracleString), - (oracle_dialect, UnicodeText(), oracle.OracleText), - (oracle_dialect, NCHAR(), oracle.OracleString), - (oracle_dialect, oracle.OracleRaw(50), oracle.OracleRaw), - (mysql_dialect, String(), mysql.MSString), - (mysql_dialect, VARCHAR(), mysql.MSString), - (mysql_dialect, String(50), mysql.MSString), - (mysql_dialect, Unicode(), mysql.MSString), - (mysql_dialect, UnicodeText(), mysql.MSText), - (mysql_dialect, NCHAR(), mysql.MSNChar), - (postgres_dialect, String(), postgres.PGString), - (postgres_dialect, VARCHAR(), postgres.PGString), - (postgres_dialect, String(50), postgres.PGString), - (postgres_dialect, Unicode(), postgres.PGString), - (postgres_dialect, UnicodeText(), postgres.PGText), - (postgres_dialect, NCHAR(), postgres.PGString), - (firebird_dialect, String(), firebird.FBString), - (firebird_dialect, VARCHAR(), firebird.FBString), - (firebird_dialect, String(50), firebird.FBString), - (firebird_dialect, Unicode(), firebird.FBString), - (firebird_dialect, UnicodeText(), firebird.FBText), - (firebird_dialect, NCHAR(), firebird.FBString), - ]: - assert isinstance(start.dialect_impl(dialect), test), "wanted %r got %r" % (test, start.dialect_impl(dialect)) - - + assert False, "%r matches none of %r for dialect %s" % (compiled, expected, dialect.name) + class UserDefinedTest(TestBase): """tests user-defined types.""" @@ -131,7 +93,7 @@ class UserDefinedTest(TestBase): def setup_class(cls): global users, metadata - class MyType(types.TypeEngine): + class MyType(types.UserDefinedType): def get_col_spec(self): return "VARCHAR(100)" def bind_processor(self, dialect): @@ -267,124 +229,105 @@ class ColumnsTest(TestBase, AssertsExecutionResults): for aCol in testTable.c: eq_( expectedResults[aCol.name], - db.dialect.schemagenerator(db.dialect, db, None, None).\ + db.dialect.ddl_compiler(db.dialect, schema.CreateTable(testTable)).\ get_column_specification(aCol)) class UnicodeTest(TestBase, AssertsExecutionResults): """tests the Unicode type. also tests the TypeDecorator with instances in the types package.""" + @classmethod def setup_class(cls): - global unicode_table + global unicode_table, metadata metadata = MetaData(testing.db) unicode_table = Table('unicode_table', metadata, Column('id', Integer, Sequence('uni_id_seq', optional=True), primary_key=True), Column('unicode_varchar', Unicode(250)), Column('unicode_text', UnicodeText), - Column('plain_varchar', String(250)) ) - unicode_table.create() + metadata.create_all() + @classmethod def teardown_class(cls): - unicode_table.drop() + metadata.drop_all() + @engines.close_first def teardown(self): unicode_table.delete().execute() def test_round_trip(self): - assert unicode_table.c.unicode_varchar.type.length == 250 - rawdata = 'Alors vous imaginez ma surprise, au lever du jour, quand une dr\xc3\xb4le de petit voix m\xe2\x80\x99a r\xc3\xa9veill\xc3\xa9. Elle disait: \xc2\xab S\xe2\x80\x99il vous pla\xc3\xaet\xe2\x80\xa6 dessine-moi un mouton! \xc2\xbb\n' - unicodedata = rawdata.decode('utf-8') - if testing.against('sqlite'): - rawdata = "something" - - unicode_table.insert().execute(unicode_varchar=unicodedata, - unicode_text=unicodedata, - plain_varchar=rawdata) - x = unicode_table.select().execute().fetchone() + unicodedata = u"Alors vous imaginez ma surprise, au lever du jour, quand une drôle de petit voix m’a réveillé. Elle disait: « S’il vous plaît… dessine-moi un mouton! »" + + unicode_table.insert().execute(unicode_varchar=unicodedata,unicode_text=unicodedata) + + x = unicode_table.select().execute().first() self.assert_(isinstance(x['unicode_varchar'], unicode) and x['unicode_varchar'] == unicodedata) self.assert_(isinstance(x['unicode_text'], unicode) and x['unicode_text'] == unicodedata) - if isinstance(x['plain_varchar'], unicode): - # SQLLite and MSSQL return non-unicode data as unicode - self.assert_(testing.against('sqlite', 'mssql')) - if not testing.against('sqlite'): - self.assert_(x['plain_varchar'] == unicodedata) - else: - self.assert_(not isinstance(x['plain_varchar'], unicode) and x['plain_varchar'] == rawdata) - def test_union(self): - """ensure compiler processing works for UNIONs""" + def test_round_trip_executemany(self): + # cx_oracle was producing different behavior for cursor.executemany() + # vs. cursor.execute() + + unicodedata = u"Alors vous imaginez ma surprise, au lever du jour, quand une drôle de petit voix m’a réveillé. Elle disait: « S’il vous plaît… dessine-moi un mouton! »" - rawdata = 'Alors vous imaginez ma surprise, au lever du jour, quand une dr\xc3\xb4le de petit voix m\xe2\x80\x99a r\xc3\xa9veill\xc3\xa9. Elle disait: \xc2\xab S\xe2\x80\x99il vous pla\xc3\xaet\xe2\x80\xa6 dessine-moi un mouton! \xc2\xbb\n' - unicodedata = rawdata.decode('utf-8') - if testing.against('sqlite'): - rawdata = "something" - unicode_table.insert().execute(unicode_varchar=unicodedata, - unicode_text=unicodedata, - plain_varchar=rawdata) - - x = union(select([unicode_table.c.unicode_varchar]), select([unicode_table.c.unicode_varchar])).execute().fetchone() + unicode_table.insert().execute( + dict(unicode_varchar=unicodedata,unicode_text=unicodedata), + dict(unicode_varchar=unicodedata,unicode_text=unicodedata) + ) + + x = unicode_table.select().execute().first() self.assert_(isinstance(x['unicode_varchar'], unicode) and x['unicode_varchar'] == unicodedata) + self.assert_(isinstance(x['unicode_text'], unicode) and x['unicode_text'] == unicodedata) - def test_assertions(self): - try: - unicode_table.insert().execute(unicode_varchar='not unicode') - assert False - except exc.SAWarning, e: - assert str(e) == "Unicode type received non-unicode bind param value 'not unicode'", str(e) + def test_union(self): + """ensure compiler processing works for UNIONs""" - unicode_engine = engines.utf8_engine(options={'convert_unicode':True, - 'assert_unicode':True}) - try: - try: - unicode_engine.execute(unicode_table.insert(), plain_varchar='im not unicode') - assert False - except exc.InvalidRequestError, e: - assert str(e) == "Unicode type received non-unicode bind param value 'im not unicode'" - - @testing.emits_warning('.*non-unicode bind') - def warns(): - # test that data still goes in if warning is emitted.... - unicode_table.insert().execute(unicode_varchar='not unicode') - assert (select([unicode_table.c.unicode_varchar]).execute().fetchall() == [('not unicode', )]) - warns() + unicodedata = u"Alors vous imaginez ma surprise, au lever du jour, quand une drôle de petit voix m’a réveillé. Elle disait: « S’il vous plaît… dessine-moi un mouton! »" - finally: - unicode_engine.dispose() + unicode_table.insert().execute(unicode_varchar=unicodedata,unicode_text=unicodedata) + + x = union(select([unicode_table.c.unicode_varchar]), select([unicode_table.c.unicode_varchar])).execute().first() + self.assert_(isinstance(x['unicode_varchar'], unicode) and x['unicode_varchar'] == unicodedata) - @testing.fails_on('oracle', 'FIXME: unknown') + @testing.fails_on('oracle', 'oracle converts empty strings to a blank space') def test_blank_strings(self): unicode_table.insert().execute(unicode_varchar=u'') assert select([unicode_table.c.unicode_varchar]).scalar() == u'' - def test_engine_parameter(self): - """tests engine-wide unicode conversion""" - prev_unicode = testing.db.engine.dialect.convert_unicode - prev_assert = testing.db.engine.dialect.assert_unicode - try: - testing.db.engine.dialect.convert_unicode = True - testing.db.engine.dialect.assert_unicode = False - rawdata = 'Alors vous imaginez ma surprise, au lever du jour, quand une dr\xc3\xb4le de petit voix m\xe2\x80\x99a r\xc3\xa9veill\xc3\xa9. Elle disait: \xc2\xab S\xe2\x80\x99il vous pla\xc3\xaet\xe2\x80\xa6 dessine-moi un mouton! \xc2\xbb\n' - unicodedata = rawdata.decode('utf-8') - if testing.against('sqlite', 'mssql'): - rawdata = "something" - unicode_table.insert().execute(unicode_varchar=unicodedata, - unicode_text=unicodedata, - plain_varchar=rawdata) - x = unicode_table.select().execute().fetchone() - self.assert_(isinstance(x['unicode_varchar'], unicode) and x['unicode_varchar'] == unicodedata) - self.assert_(isinstance(x['unicode_text'], unicode) and x['unicode_text'] == unicodedata) - if not testing.against('sqlite', 'mssql'): - self.assert_(isinstance(x['plain_varchar'], unicode) and x['plain_varchar'] == unicodedata) - finally: - testing.db.engine.dialect.convert_unicode = prev_unicode - testing.db.engine.dialect.convert_unicode = prev_assert - - @testing.crashes('oracle', 'FIXME: unknown, verify not fails_on') - @testing.fails_on('firebird', 'Data type unknown') - def test_length_function(self): - """checks the database correctly understands the length of a unicode string""" - teststr = u'aaa\x1234' - self.assert_(testing.db.func.length(teststr).scalar() == len(teststr)) + def test_parameters(self): + """test the dialect convert_unicode parameters.""" + + unicodedata = u"Alors vous imaginez ma surprise, au lever du jour, quand une drôle de petit voix m’a réveillé. Elle disait: « S’il vous plaît… dessine-moi un mouton! »" + + u = Unicode(assert_unicode=True) + uni = u.dialect_impl(testing.db.dialect).bind_processor(testing.db.dialect) + # Py3K + #assert_raises(exc.InvalidRequestError, uni, b'x') + # Py2K + assert_raises(exc.InvalidRequestError, uni, 'x') + # end Py2K + + u = Unicode() + uni = u.dialect_impl(testing.db.dialect).bind_processor(testing.db.dialect) + # Py3K + #assert_raises(exc.SAWarning, uni, b'x') + # Py2K + assert_raises(exc.SAWarning, uni, 'x') + # end Py2K + + unicode_engine = engines.utf8_engine(options={'convert_unicode':True,'assert_unicode':True}) + unicode_engine.dialect.supports_unicode_binds = False + + s = String() + uni = s.dialect_impl(unicode_engine.dialect).bind_processor(unicode_engine.dialect) + # Py3K + #assert_raises(exc.InvalidRequestError, uni, b'x') + #assert isinstance(uni(unicodedata), bytes) + # Py2K + assert_raises(exc.InvalidRequestError, uni, 'x') + assert isinstance(uni(unicodedata), str) + # end Py2K + + assert uni(unicodedata) == unicodedata.encode('utf-8') class BinaryTest(TestBase, AssertsExecutionResults): __excluded_on__ = ( @@ -409,18 +352,19 @@ class BinaryTest(TestBase, AssertsExecutionResults): return value binary_table = Table('binary_table', MetaData(testing.db), - Column('primary_id', Integer, Sequence('binary_id_seq', optional=True), primary_key=True), - Column('data', Binary), - Column('data_slice', Binary(100)), - Column('misc', String(30)), - # construct PickleType with non-native pickle module, since cPickle uses relative module - # loading and confuses this test's parent package 'sql' with the 'sqlalchemy.sql' package relative - # to the 'types' module - Column('pickled', PickleType), - Column('mypickle', MyPickleType) + Column('primary_id', Integer, Sequence('binary_id_seq', optional=True), primary_key=True), + Column('data', Binary), + Column('data_slice', Binary(100)), + Column('misc', String(30)), + # construct PickleType with non-native pickle module, since cPickle uses relative module + # loading and confuses this test's parent package 'sql' with the 'sqlalchemy.sql' package relative + # to the 'types' module + Column('pickled', PickleType), + Column('mypickle', MyPickleType) ) binary_table.create() + @engines.close_first def teardown(self): binary_table.delete().execute() @@ -428,42 +372,65 @@ class BinaryTest(TestBase, AssertsExecutionResults): def teardown_class(cls): binary_table.drop() - @testing.fails_on('mssql', 'MSSQl BINARY type right pads the fixed length with \x00') - def testbinary(self): + def test_round_trip(self): testobj1 = pickleable.Foo('im foo 1') testobj2 = pickleable.Foo('im foo 2') testobj3 = pickleable.Foo('im foo 3') stream1 =self.load_stream('binary_data_one.dat') stream2 =self.load_stream('binary_data_two.dat') - binary_table.insert().execute(primary_id=1, misc='binary_data_one.dat', data=stream1, data_slice=stream1[0:100], pickled=testobj1, mypickle=testobj3) - binary_table.insert().execute(primary_id=2, misc='binary_data_two.dat', data=stream2, data_slice=stream2[0:99], pickled=testobj2) - binary_table.insert().execute(primary_id=3, misc='binary_data_two.dat', data=None, data_slice=stream2[0:99], pickled=None) + binary_table.insert().execute( + primary_id=1, + misc='binary_data_one.dat', + data=stream1, + data_slice=stream1[0:100], + pickled=testobj1, + mypickle=testobj3) + binary_table.insert().execute( + primary_id=2, + misc='binary_data_two.dat', + data=stream2, + data_slice=stream2[0:99], + pickled=testobj2) + binary_table.insert().execute( + primary_id=3, + misc='binary_data_two.dat', + data=None, + data_slice=stream2[0:99], + pickled=None) for stmt in ( binary_table.select(order_by=binary_table.c.primary_id), text("select * from binary_table order by binary_table.primary_id", typemap={'pickled':PickleType, 'mypickle':MyPickleType}, bind=testing.db) ): + eq_data = lambda x, y: eq_(list(x), list(y)) + if util.jython: + _eq_data = eq_data + def eq_data(x, y): + # Jython currently returns arrays + from array import ArrayType + if isinstance(y, ArrayType): + return eq_(x, y.tostring()) + return _eq_data(x, y) l = stmt.execute().fetchall() - eq_(list(stream1), list(l[0]['data'])) - eq_(list(stream1[0:100]), list(l[0]['data_slice'])) - eq_(list(stream2), list(l[1]['data'])) + eq_data(stream1, l[0]['data']) + eq_data(stream1[0:100], l[0]['data_slice']) + eq_data(stream2, l[1]['data']) eq_(testobj1, l[0]['pickled']) eq_(testobj2, l[1]['pickled']) eq_(testobj3.moredata, l[0]['mypickle'].moredata) eq_(l[0]['mypickle'].stuff, 'this is the right stuff') - def load_stream(self, name, len=12579): + def load_stream(self, name): f = os.path.join(os.path.dirname(__file__), "..", name) - # put a number less than the typical MySQL default BLOB size - return file(f).read(len) + return open(f, mode='rb').read() class ExpressionTest(TestBase, AssertsExecutionResults): @classmethod def setup_class(cls): global test_table, meta - class MyCustomType(types.TypeEngine): + class MyCustomType(types.UserDefinedType): def get_col_spec(self): return "INT" def bind_processor(self, dialect): @@ -547,7 +514,6 @@ class DateTest(TestBase, AssertsExecutionResults): db = testing.db if testing.against('oracle'): - import sqlalchemy.databases.oracle as oracle insert_data = [ (7, 'jack', datetime.datetime(2005, 11, 10, 0, 0), @@ -576,7 +542,7 @@ class DateTest(TestBase, AssertsExecutionResults): time_micro = 999 # Missing or poor microsecond support: - if testing.against('mssql', 'mysql', 'firebird'): + if testing.against('mssql', 'mysql', 'firebird', '+zxjdbc'): datetime_micro, time_micro = 0, 0 # No microseconds for TIME elif testing.against('maxdb'): @@ -608,7 +574,7 @@ class DateTest(TestBase, AssertsExecutionResults): Column('user_date', Date), Column('user_time', Time)] - if testing.against('sqlite', 'postgres'): + if testing.against('sqlite', 'postgresql'): insert_data.append( (11, 'historic', datetime.datetime(1850, 11, 10, 11, 52, 35, datetime_micro), @@ -676,8 +642,8 @@ class DateTest(TestBase, AssertsExecutionResults): t.drop(checkfirst=True) class StringTest(TestBase, AssertsExecutionResults): - @testing.fails_on('mysql', 'FIXME: unknown') - @testing.fails_on('oracle', 'FIXME: unknown') + + @testing.requires.unbounded_varchar def test_nolength_string(self): metadata = MetaData(testing.db) foo = Table('foo', metadata, Column('one', String)) @@ -700,10 +666,10 @@ class NumericTest(TestBase, AssertsExecutionResults): metadata = MetaData(testing.db) numeric_table = Table('numeric_table', metadata, Column('id', Integer, Sequence('numeric_id_seq', optional=True), primary_key=True), - Column('numericcol', Numeric(asdecimal=False)), - Column('floatcol', Float), - Column('ncasdec', Numeric), - Column('fcasdec', Float(asdecimal=True)) + Column('numericcol', Numeric(precision=10, scale=2, asdecimal=False)), + Column('floatcol', Float(precision=10, )), + Column('ncasdec', Numeric(precision=10, scale=2)), + Column('fcasdec', Float(precision=10, asdecimal=True)) ) metadata.create_all() @@ -711,6 +677,7 @@ class NumericTest(TestBase, AssertsExecutionResults): def teardown_class(cls): metadata.drop_all() + @engines.close_first def teardown(self): numeric_table.delete().execute() @@ -719,6 +686,7 @@ class NumericTest(TestBase, AssertsExecutionResults): from decimal import Decimal numeric_table.insert().execute( numericcol=3.5, floatcol=5.6, ncasdec=12.4, fcasdec=15.75) + numeric_table.insert().execute( numericcol=Decimal("3.5"), floatcol=Decimal("5.6"), ncasdec=Decimal("12.4"), fcasdec=Decimal("15.75")) @@ -744,33 +712,6 @@ class NumericTest(TestBase, AssertsExecutionResults): assert isinstance(row['ncasdec'], decimal.Decimal) assert isinstance(row['fcasdec'], decimal.Decimal) - def test_length_deprecation(self): - assert_raises(exc.SADeprecationWarning, Numeric, length=8) - - @testing.uses_deprecated(".*is deprecated for Numeric") - def go(): - n = Numeric(length=12) - assert n.scale == 12 - go() - - n = Numeric(scale=12) - for dialect in engines.all_dialects(): - n2 = dialect.type_descriptor(n) - eq_(n2.scale, 12, dialect.name) - - # test colspec generates successfully using 'scale' - assert n2.get_col_spec() - - # test constructor of the dialect-specific type - n3 = n2.__class__(scale=5) - eq_(n3.scale, 5, dialect.name) - - @testing.uses_deprecated(".*is deprecated for Numeric") - def go(): - n3 = n2.__class__(length=6) - eq_(n3.scale, 6, dialect.name) - go() - class IntervalTest(TestBase, AssertsExecutionResults): @classmethod @@ -783,6 +724,7 @@ class IntervalTest(TestBase, AssertsExecutionResults): ) metadata.create_all() + @engines.close_first def teardown(self): interval_table.delete().execute() @@ -790,14 +732,16 @@ class IntervalTest(TestBase, AssertsExecutionResults): def teardown_class(cls): metadata.drop_all() + @testing.fails_on("+pg8000", "Not yet known how to pass values of the INTERVAL type") + @testing.fails_on("postgresql+zxjdbc", "Not yet known how to pass values of the INTERVAL type") def test_roundtrip(self): delta = datetime.datetime(2006, 10, 5) - datetime.datetime(2005, 8, 17) interval_table.insert().execute(interval=delta) - assert interval_table.select().execute().fetchone()['interval'] == delta + assert interval_table.select().execute().first()['interval'] == delta def test_null(self): interval_table.insert().execute(id=1, inverval=None) - assert interval_table.select().execute().fetchone()['interval'] is None + assert interval_table.select().execute().first()['interval'] is None class BooleanTest(TestBase, AssertsExecutionResults): @classmethod @@ -825,30 +769,6 @@ class BooleanTest(TestBase, AssertsExecutionResults): assert(res2==[(2, False)]) class PickleTest(TestBase): - def test_noeq_deprecation(self): - p1 = PickleType() - - assert_raises(DeprecationWarning, - p1.compare_values, pickleable.BarWithoutCompare(1, 2), pickleable.BarWithoutCompare(1, 2) - ) - - assert_raises(DeprecationWarning, - p1.compare_values, pickleable.OldSchoolWithoutCompare(1, 2), pickleable.OldSchoolWithoutCompare(1, 2) - ) - - @testing.uses_deprecated() - def go(): - # test actual dumps comparison - assert p1.compare_values(pickleable.BarWithoutCompare(1, 2), pickleable.BarWithoutCompare(1, 2)) - assert p1.compare_values(pickleable.OldSchoolWithoutCompare(1, 2), pickleable.OldSchoolWithoutCompare(1, 2)) - go() - - assert p1.compare_values({1:2, 3:4}, {3:4, 1:2}) - - p2 = PickleType(mutable=False) - assert not p2.compare_values(pickleable.BarWithoutCompare(1, 2), pickleable.BarWithoutCompare(1, 2)) - assert not p2.compare_values(pickleable.OldSchoolWithoutCompare(1, 2), pickleable.OldSchoolWithoutCompare(1, 2)) - def test_eq_comparison(self): p1 = PickleType() diff --git a/test/sql/test_unicode.py b/test/sql/test_unicode.py index d759132678..6551594f32 100644 --- a/test/sql/test_unicode.py +++ b/test/sql/test_unicode.py @@ -56,6 +56,7 @@ class UnicodeSchemaTest(TestBase): ) metadata.create_all() + @engines.close_first def teardown(self): if metadata.tables: t3.delete().execute() @@ -125,11 +126,11 @@ class EscapesDefaultsTest(testing.TestBase): # reset the identifier preparer, so that we can force it to cache # a unicode identifier engine.dialect.identifier_preparer = engine.dialect.preparer(engine.dialect) - select([column(u'special_col')]).select_from(t1).execute() + select([column(u'special_col')]).select_from(t1).execute().close() assert isinstance(engine.dialect.identifier_preparer.format_sequence(Sequence('special_col')), unicode) # now execute, run the sequence. it should run in u"Special_col.nextid" or similar as - # a unicode object; cx_oracle asserts that this is None or a String (postgres lets it pass thru). + # a unicode object; cx_oracle asserts that this is None or a String (postgresql lets it pass thru). # ensure that base.DefaultRunner is encoding. t1.insert().execute(data='foo') finally: diff --git a/test/zblog/mappers.py b/test/zblog/mappers.py index 5203bd866a..126d2c5684 100644 --- a/test/zblog/mappers.py +++ b/test/zblog/mappers.py @@ -1,7 +1,7 @@ """mapper.py - defines mappers for domain objects, mapping operations""" -import tables, user -from blog import * +from test.zblog import tables, user +from test.zblog.blog import * from sqlalchemy import * from sqlalchemy.orm import * import sqlalchemy.util as util diff --git a/test/zblog/tables.py b/test/zblog/tables.py index 36c7aeb8b1..4907259e18 100644 --- a/test/zblog/tables.py +++ b/test/zblog/tables.py @@ -1,12 +1,12 @@ """application table metadata objects are described here.""" from sqlalchemy import * - +from sqlalchemy.test.schema import Table, Column metadata = MetaData() users = Table('users', metadata, - Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key=True), + Column('user_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('user_name', String(30), nullable=False), Column('fullname', String(100), nullable=False), Column('password', String(40), nullable=False), @@ -14,14 +14,14 @@ users = Table('users', metadata, ) blogs = Table('blogs', metadata, - Column('blog_id', Integer, Sequence('blog_id_seq', optional=True), primary_key=True), + Column('blog_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('owner_id', Integer, ForeignKey('users.user_id'), nullable=False), Column('name', String(100), nullable=False), Column('description', String(500)) ) posts = Table('posts', metadata, - Column('post_id', Integer, Sequence('post_id_seq', optional=True), primary_key=True), + Column('post_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('blog_id', Integer, ForeignKey('blogs.blog_id'), nullable=False), Column('user_id', Integer, ForeignKey('users.user_id'), nullable=False), Column('datetime', DateTime, nullable=False), @@ -31,7 +31,7 @@ posts = Table('posts', metadata, ) topics = Table('topics', metadata, - Column('topic_id', Integer, primary_key=True), + Column('topic_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('keyword', String(50), nullable=False), Column('description', String(500)) ) @@ -43,7 +43,7 @@ topic_xref = Table('topic_post_xref', metadata, ) comments = Table('comments', metadata, - Column('comment_id', Integer, primary_key=True), + Column('comment_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('user_id', Integer, ForeignKey('users.user_id'), nullable=False), Column('post_id', Integer, ForeignKey('posts.post_id'), nullable=False), Column('datetime', DateTime, nullable=False), diff --git a/test/zblog/test_zblog.py b/test/zblog/test_zblog.py index 8170766cb2..5e46c1cebc 100644 --- a/test/zblog/test_zblog.py +++ b/test/zblog/test_zblog.py @@ -1,9 +1,9 @@ from sqlalchemy import * from sqlalchemy.orm import * from sqlalchemy.test import * -import mappers, tables -from user import * -from blog import * +from test.zblog import mappers, tables +from test.zblog.user import * +from test.zblog.blog import * class ZBlogTest(TestBase, AssertsExecutionResults): diff --git a/test/zblog/user.py b/test/zblog/user.py index 0a13002cd8..30f1e3da16 100644 --- a/test/zblog/user.py +++ b/test/zblog/user.py @@ -14,9 +14,9 @@ groups = [user, administrator] def cryptpw(password, salt=None): if salt is None: - salt = string.join([chr(random.randint(ord('a'), ord('z'))), - chr(random.randint(ord('a'), ord('z')))],'') - return sha(password + salt).hexdigest() + salt = "".join([chr(random.randint(ord('a'), ord('z'))), + chr(random.randint(ord('a'), ord('z')))]) + return sha((password+ salt).encode('ascii')).hexdigest() def checkpw(password, dbpw): return cryptpw(password, dbpw[:2]) == dbpw