From: Mike Bayer Date: Tue, 13 Jan 2009 17:33:53 +0000 (+0000) Subject: first merge from the hg repo. may need cleanup/refreshing X-Git-Tag: rel_0_6_6~346 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=2c3984eca11c4464da2f3955769b0967ca5bbc0e;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git first merge from the hg repo. may need cleanup/refreshing --- diff --git a/doc/build/copyright.rst b/doc/build/copyright.rst index 227a54c9c8..501b4ee757 100644 --- a/doc/build/copyright.rst +++ b/doc/build/copyright.rst @@ -4,7 +4,7 @@ Appendix: Copyright This is the MIT license: ``_ -Copyright (c) 2005, 2006, 2007, 2008 Michael Bayer and contributors. SQLAlchemy is a trademark of Michael +Copyright (c) 2005, 2006, 2007, 2008, 2009 Michael Bayer and contributors. SQLAlchemy is a trademark of Michael Bayer. Permission is hereby granted, free of charge, to any person obtaining a copy of this diff --git a/lib/sqlalchemy/connectors/__init__.py b/lib/sqlalchemy/connectors/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lib/sqlalchemy/connectors/pyodbc.py b/lib/sqlalchemy/connectors/pyodbc.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lib/sqlalchemy/connectors/zxJDBC.py b/lib/sqlalchemy/connectors/zxJDBC.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lib/sqlalchemy/databases/__init__.py b/lib/sqlalchemy/databases/__init__.py index 6588be0ae7..7f124d7dbd 100644 --- a/lib/sqlalchemy/databases/__init__.py +++ b/lib/sqlalchemy/databases/__init__.py @@ -13,7 +13,5 @@ __all__ = ( 'mssql', 'mysql', 'oracle', - 'postgres', - 'sqlite', 'sybase', ) diff --git a/lib/sqlalchemy/databases/access.py b/lib/sqlalchemy/databases/access.py index 67af4a7a4a..de4af6bcb7 100644 --- a/lib/sqlalchemy/databases/access.py +++ b/lib/sqlalchemy/databases/access.py @@ -46,7 +46,7 @@ class AcTinyInteger(types.Integer): def get_col_spec(self): return "TINYINT" -class AcSmallInteger(types.Smallinteger): +class AcSmallInteger(types.SmallInteger): def get_col_spec(self): return "SMALLINT" @@ -155,7 +155,7 @@ class AccessDialect(default.DefaultDialect): colspecs = { types.Unicode : AcUnicode, types.Integer : AcInteger, - types.Smallinteger: AcSmallInteger, + types.SmallInteger: AcSmallInteger, types.Numeric : AcNumeric, types.Float : AcFloat, types.DateTime : AcDateTime, @@ -327,7 +327,7 @@ class AccessDialect(default.DefaultDialect): return names -class AccessCompiler(compiler.DefaultCompiler): +class AccessCompiler(compiler.SQLCompiler): def visit_select_precolumns(self, select): """Access puts TOP, it's version of LIMIT here """ s = select.distinct and "DISTINCT " or "" diff --git a/lib/sqlalchemy/databases/firebird.py b/lib/sqlalchemy/databases/firebird.py index 6b1af9fab0..f00aa963ee 100644 --- a/lib/sqlalchemy/databases/firebird.py +++ b/lib/sqlalchemy/databases/firebird.py @@ -150,7 +150,7 @@ class FBInteger(sqltypes.Integer): return "INTEGER" -class FBSmallInteger(sqltypes.Smallinteger): +class FBSmallInteger(sqltypes.SmallInteger): """Handle ``SMALLINT`` datatype.""" def get_col_spec(self): @@ -231,7 +231,7 @@ class FBBoolean(sqltypes.Boolean): colspecs = { sqltypes.Integer : FBInteger, - sqltypes.Smallinteger : FBSmallInteger, + sqltypes.SmallInteger : FBSmallInteger, sqltypes.Numeric : FBNumeric, sqltypes.Float : FBFloat, sqltypes.DateTime : FBDateTime, @@ -564,12 +564,12 @@ def _substring(s, start, length=None): return "SUBSTRING(%s FROM %s FOR %s)" % (s, start, length) -class FBCompiler(sql.compiler.DefaultCompiler): +class FBCompiler(sql.compiler.SQLCompiler): """Firebird specific idiosincrasies""" # Firebird lacks a builtin modulo operator, but there is # an equivalent function in the ib_udf library. - operators = sql.compiler.DefaultCompiler.operators.copy() + operators = sql.compiler.SQLCompiler.operators.copy() operators.update({ sql.operators.mod : lambda x, y:"mod(%s, %s)" % (x, y) }) @@ -581,7 +581,7 @@ class FBCompiler(sql.compiler.DefaultCompiler): else: return self.process(alias.original, **kwargs) - functions = sql.compiler.DefaultCompiler.functions.copy() + functions = sql.compiler.SQLCompiler.functions.copy() functions['substring'] = _substring def function_argspec(self, func): diff --git a/lib/sqlalchemy/databases/informix.py b/lib/sqlalchemy/databases/informix.py index 4476af3b9c..ad9dfd9bce 100644 --- a/lib/sqlalchemy/databases/informix.py +++ b/lib/sqlalchemy/databases/informix.py @@ -51,7 +51,7 @@ class InfoInteger(sqltypes.Integer): def get_col_spec(self): return "INTEGER" -class InfoSmallInteger(sqltypes.Smallinteger): +class InfoSmallInteger(sqltypes.SmallInteger): def get_col_spec(self): return "SMALLINT" @@ -141,7 +141,7 @@ class InfoBoolean(sqltypes.Boolean): colspecs = { sqltypes.Integer : InfoInteger, - sqltypes.Smallinteger : InfoSmallInteger, + sqltypes.SmallInteger : InfoSmallInteger, sqltypes.Numeric : InfoNumeric, sqltypes.Float : InfoNumeric, sqltypes.DateTime : InfoDateTime, @@ -352,7 +352,7 @@ class InfoDialect(default.DefaultDialect): for cons_name, cons_type, local_column in rows: table.primary_key.add( table.c[local_column] ) -class InfoCompiler(compiler.DefaultCompiler): +class InfoCompiler(compiler.SQLCompiler): """Info compiler modifies the lexical structure of Select statements to work under non-ANSI configured Oracle databases, if the use_ansi flag is False.""" @@ -360,7 +360,7 @@ class InfoCompiler(compiler.DefaultCompiler): self.limit = 0 self.offset = 0 - compiler.DefaultCompiler.__init__( self , *args, **kwargs ) + compiler.SQLCompiler.__init__( self , *args, **kwargs ) def default_from(self): return " from systables where tabname = 'systables' " @@ -393,7 +393,7 @@ class InfoCompiler(compiler.DefaultCompiler): if ( __label(c) not in a ): select.append_column( c ) - return compiler.DefaultCompiler.visit_select(self, select) + return compiler.SQLCompiler.visit_select(self, select) def limit_clause(self, select): return "" @@ -406,7 +406,7 @@ class InfoCompiler(compiler.DefaultCompiler): elif func.name.lower() in ( 'current_timestamp' , 'now' ): return "CURRENT YEAR TO SECOND" else: - return compiler.DefaultCompiler.visit_function( self , func ) + return compiler.SQLCompiler.visit_function( self , func ) def visit_clauselist(self, list, **kwargs): return ', '.join([s for s in [self.process(c) for c in list.clauses] if s is not None]) diff --git a/lib/sqlalchemy/databases/maxdb.py b/lib/sqlalchemy/databases/maxdb.py index 693295054e..6e521297fc 100644 --- a/lib/sqlalchemy/databases/maxdb.py +++ b/lib/sqlalchemy/databases/maxdb.py @@ -344,7 +344,7 @@ class MaxBlob(sqltypes.Binary): colspecs = { sqltypes.Integer: MaxInteger, - sqltypes.Smallinteger: MaxSmallInteger, + sqltypes.SmallInteger: MaxSmallInteger, sqltypes.Numeric: MaxNumeric, sqltypes.Float: MaxFloat, sqltypes.DateTime: MaxTimestamp, @@ -717,8 +717,8 @@ class MaxDBDialect(default.DefaultDialect): return found -class MaxDBCompiler(compiler.DefaultCompiler): - operators = compiler.DefaultCompiler.operators.copy() +class MaxDBCompiler(compiler.SQLCompiler): + operators = compiler.SQLCompiler.operators.copy() operators[sql_operators.mod] = lambda x, y: 'mod(%s, %s)' % (x, y) function_conversion = { diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index 7d23c5b273..dda0fddd24 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -923,7 +923,7 @@ class MSSQLDialect(default.DefaultDialect): colspecs = { sqltypes.Unicode : MSNVarchar, sqltypes.Integer : MSInteger, - sqltypes.Smallinteger: MSSmallInteger, + sqltypes.SmallInteger: MSSmallInteger, sqltypes.Numeric : MSNumeric, sqltypes.Float : MSFloat, sqltypes.DateTime : MSDateTime, @@ -1445,14 +1445,14 @@ dialect_mapping = { } -class MSSQLCompiler(compiler.DefaultCompiler): +class MSSQLCompiler(compiler.SQLCompiler): operators = compiler.OPERATORS.copy() operators.update({ sql_operators.concat_op: '+', sql_operators.match_op: lambda x, y: "CONTAINS (%s, %s)" % (x, y) }) - functions = compiler.DefaultCompiler.functions.copy() + functions = compiler.SQLCompiler.functions.copy() functions.update ( { sql_functions.now: 'CURRENT_TIMESTAMP', @@ -1478,7 +1478,7 @@ class MSSQLCompiler(compiler.DefaultCompiler): if not self.dialect.has_window_funcs: raise exc.InvalidRequestError('MSSQL does not support LIMIT with an offset') return s - return compiler.DefaultCompiler.get_select_precolumns(self, select) + return compiler.SQLCompiler.get_select_precolumns(self, select) def limit_clause(self, select): # Limit in mssql is after the select keyword @@ -1506,7 +1506,7 @@ class MSSQLCompiler(compiler.DefaultCompiler): limitselect.append_whereclause("mssql_rn<=%d" % (_limit + _offset)) return self.process(limitselect, iswrapper=True, **kwargs) else: - return compiler.DefaultCompiler.visit_select(self, select, **kwargs) + return compiler.SQLCompiler.visit_select(self, select, **kwargs) def _schema_aliased_table(self, table): if getattr(table, 'schema', None) is not None: diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index 3d71bb7232..ac4e64b597 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -630,7 +630,7 @@ class MSTinyInteger(MSInteger): return self._extend("TINYINT") -class MSSmallInteger(sqltypes.Smallinteger, MSInteger): +class MSSmallInteger(sqltypes.SmallInteger, MSInteger): """MySQL SMALLINTEGER type.""" def __init__(self, display_width=None, **kw): @@ -1363,7 +1363,7 @@ class MSBoolean(sqltypes.Boolean): colspecs = { sqltypes.Integer: MSInteger, - sqltypes.Smallinteger: MSSmallInteger, + sqltypes.SmallInteger: MSSmallInteger, sqltypes.Numeric: MSNumeric, sqltypes.Float: MSFloat, sqltypes.DateTime: MSDateTime, @@ -1901,14 +1901,14 @@ class _MySQLPythonRowProxy(object): return item -class MySQLCompiler(compiler.DefaultCompiler): - operators = compiler.DefaultCompiler.operators.copy() +class MySQLCompiler(compiler.SQLCompiler): + operators = compiler.SQLCompiler.operators.copy() operators.update({ sql_operators.concat_op: lambda x, y: "concat(%s, %s)" % (x, y), sql_operators.mod: '%%', sql_operators.match_op: lambda x, y: "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % (x, y) }) - functions = compiler.DefaultCompiler.functions.copy() + functions = compiler.SQLCompiler.functions.copy() functions.update ({ sql_functions.random: 'rand%(expr)s', "utc_timestamp":"UTC_TIMESTAMP" @@ -2013,7 +2013,8 @@ class MySQLCompiler(compiler.DefaultCompiler): self.isupdate = True colparams = self._get_colparams(update_stmt) - text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + ', '.join(["%s=%s" % (self.preparer.format_column(c[0]), c[1]) for c in colparams]) + text = "UPDATE " + self.preparer.format_table(update_stmt.table) + \ + " SET " + ', '.join(["%s=%s" % (self.preparer.format_column(c[0]), c[1]) for c in colparams]) if update_stmt._whereclause: text += " WHERE " + self.process(update_stmt._whereclause) diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index 6749d8e407..b0ec6115b2 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -143,7 +143,7 @@ class OracleInteger(sqltypes.Integer): def get_col_spec(self): return "INTEGER" -class OracleSmallInteger(sqltypes.Smallinteger): +class OracleSmallInteger(sqltypes.SmallInteger): def get_col_spec(self): return "SMALLINT" @@ -286,7 +286,7 @@ class OracleBoolean(sqltypes.Boolean): colspecs = { sqltypes.Integer : OracleInteger, - sqltypes.Smallinteger : OracleSmallInteger, + sqltypes.SmallInteger : OracleSmallInteger, sqltypes.Numeric : OracleNumeric, sqltypes.Float : OracleNumeric, sqltypes.DateTime : OracleDateTime, @@ -698,13 +698,13 @@ class _OuterJoinColumn(sql.ClauseElement): def __init__(self, column): self.column = column -class OracleCompiler(compiler.DefaultCompiler): +class OracleCompiler(compiler.SQLCompiler): """Oracle compiler modifies the lexical structure of Select statements to work under non-ANSI configured Oracle databases, if the use_ansi flag is False. """ - operators = compiler.DefaultCompiler.operators.copy() + operators = compiler.SQLCompiler.operators.copy() operators.update( { sql_operators.mod : lambda x, y:"mod(%s, %s)" % (x, y), @@ -712,7 +712,7 @@ class OracleCompiler(compiler.DefaultCompiler): } ) - functions = compiler.DefaultCompiler.functions.copy() + functions = compiler.SQLCompiler.functions.copy() functions.update ( { sql_functions.now : 'CURRENT_TIMESTAMP' @@ -736,7 +736,7 @@ class OracleCompiler(compiler.DefaultCompiler): def visit_join(self, join, **kwargs): if self.dialect.use_ansi: - return compiler.DefaultCompiler.visit_join(self, join, **kwargs) + return compiler.SQLCompiler.visit_join(self, join, **kwargs) else: return self.process(join.left, asfrom=True) + ", " + self.process(join.right, asfrom=True) @@ -846,7 +846,7 @@ class OracleCompiler(compiler.DefaultCompiler): select = offsetselect kwargs['iswrapper'] = getattr(select, '_is_wrapper', False) - return compiler.DefaultCompiler.visit_select(self, select, **kwargs) + return compiler.SQLCompiler.visit_select(self, select, **kwargs) def limit_clause(self, select): return "" diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py deleted file mode 100644 index 8b46132f3e..0000000000 --- a/lib/sqlalchemy/databases/sqlite.py +++ /dev/null @@ -1,619 +0,0 @@ -# sqlite.py -# Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com -# -# This module is part of SQLAlchemy and is released under -# the MIT License: http://www.opensource.org/licenses/mit-license.php - -"""Support for the SQLite database. - -Driver ------- - -When using Python 2.5 and above, the built in ``sqlite3`` driver is -already installed and no additional installation is needed. Otherwise, -the ``pysqlite2`` driver needs to be present. This is the same driver as -``sqlite3``, just with a different name. - -The ``pysqlite2`` driver will be loaded first, and if not found, ``sqlite3`` -is loaded. This allows an explicitly installed pysqlite driver to take -precedence over the built in one. As with all dialects, a specific -DBAPI module may be provided to :func:`~sqlalchemy.create_engine()` to control -this explicitly:: - - from sqlite3 import dbapi2 as sqlite - e = create_engine('sqlite:///file.db', module=sqlite) - -Full documentation on pysqlite is available at: -``_ - -Connect Strings ---------------- - -The file specification for the SQLite database is taken as the "database" portion of -the URL. Note that the format of a url is:: - - driver://user:pass@host/database - -This means that the actual filename to be used starts with the characters to the -**right** of the third slash. So connecting to a relative filepath looks like:: - - # relative path - e = create_engine('sqlite:///path/to/database.db') - -An absolute path, which is denoted by starting with a slash, means you need **four** -slashes:: - - # absolute path - e = create_engine('sqlite:////path/to/database.db') - -To use a Windows path, regular drive specifications and backslashes can be used. -Double backslashes are probably needed:: - - # absolute path on Windows - e = create_engine('sqlite:///C:\\\\path\\\\to\\\\database.db') - -The sqlite ``:memory:`` identifier is the default if no filepath is present. Specify -``sqlite://`` and nothing else:: - - # in-memory database - e = create_engine('sqlite://') - -Threading Behavior ------------------- - -Pysqlite connections do not support being moved between threads, unless -the ``check_same_thread`` Pysqlite flag is set to ``False``. In addition, -when using an in-memory SQLite database, the full database exists only within -the scope of a single connection. It is reported that an in-memory -database does not support being shared between threads regardless of the -``check_same_thread`` flag - which means that a multithreaded -application **cannot** share data from a ``:memory:`` database across threads -unless access to the connection is limited to a single worker thread which communicates -through a queueing mechanism to concurrent threads. - -To provide a default which accomodates SQLite's default threading capabilities -somewhat reasonably, the SQLite dialect will specify that the :class:`~sqlalchemy.pool.SingletonThreadPool` -be used by default. This pool maintains a single SQLite connection per thread -that is held open up to a count of five concurrent threads. When more than five threads -are used, a cleanup mechanism will dispose of excess unused connections. - -Two optional pool implementations that may be appropriate for particular SQLite usage scenarios: - - * the :class:`sqlalchemy.pool.StaticPool` might be appropriate for a multithreaded - application using an in-memory database, assuming the threading issues inherent in - pysqlite are somehow accomodated for. This pool holds persistently onto a single connection - which is never closed, and is returned for all requests. - - * the :class:`sqlalchemy.pool.NullPool` might be appropriate for an application that - makes use of a file-based sqlite database. This pool disables any actual "pooling" - behavior, and simply opens and closes real connections corresonding to the :func:`connect()` - and :func:`close()` methods. SQLite can "connect" to a particular file with very high - efficiency, so this option may actually perform better without the extra overhead - of :class:`SingletonThreadPool`. NullPool will of course render a ``:memory:`` connection - useless since the database would be lost as soon as the connection is "returned" to the pool. - -Date and Time Types -------------------- - -SQLite does not have built-in DATE, TIME, or DATETIME types, and pysqlite does not provide -out of the box functionality for translating values between Python `datetime` objects -and a SQLite-supported format. SQLAlchemy's own :class:`~sqlalchemy.types.DateTime` -and related types provide date formatting and parsing functionality when SQlite is used. -The implementation classes are :class:`SLDateTime`, :class:`SLDate` and :class:`SLTime`. -These types represent dates and times as ISO formatted strings, which also nicely -support ordering. There's no reliance on typical "libc" internals for these functions -so historical dates are fully supported. - -Unicode -------- - -In contrast to SQLAlchemy's active handling of date and time types for pysqlite, pysqlite's -default behavior regarding Unicode is that all strings are returned as Python unicode objects -in all cases. So even if the :class:`~sqlalchemy.types.Unicode` type is -*not* used, you will still always receive unicode data back from a result set. It is -**strongly** recommended that you do use the :class:`~sqlalchemy.types.Unicode` type -to represent strings, since it will raise a warning if a non-unicode Python string is -passed from the user application. Mixing the usage of non-unicode objects with returned unicode objects can -quickly create confusion, particularly when using the ORM as internal data is not -always represented by an actual database result string. - -""" - - -import datetime, re, time - -from sqlalchemy import sql, schema, exc, pool, DefaultClause -from sqlalchemy.engine import default -import sqlalchemy.types as sqltypes -import sqlalchemy.util as util -from sqlalchemy.sql import compiler, functions as sql_functions -from types import NoneType - -class SLNumeric(sqltypes.Numeric): - def bind_processor(self, dialect): - type_ = self.asdecimal and str or float - def process(value): - if value is not None: - return type_(value) - else: - return value - return process - - def get_col_spec(self): - if self.precision is None: - return "NUMERIC" - else: - return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale} - -class SLFloat(sqltypes.Float): - def bind_processor(self, dialect): - type_ = self.asdecimal and str or float - def process(value): - if value is not None: - return type_(value) - else: - return value - return process - - def get_col_spec(self): - return "FLOAT" - -class SLInteger(sqltypes.Integer): - def get_col_spec(self): - return "INTEGER" - -class SLSmallInteger(sqltypes.Smallinteger): - def get_col_spec(self): - return "SMALLINT" - -class DateTimeMixin(object): - def _bind_processor(self, format, elements): - def process(value): - if not isinstance(value, (NoneType, datetime.date, datetime.datetime, datetime.time)): - raise TypeError("SQLite Date, Time, and DateTime types only accept Python datetime objects as input.") - elif value is not None: - return format % tuple([getattr(value, attr, 0) for attr in elements]) - else: - return None - return process - - def _result_processor(self, fn, regexp): - def process(value): - if value is not None: - return fn(*[int(x or 0) for x in regexp.match(value).groups()]) - else: - return None - return process - -class SLDateTime(DateTimeMixin, sqltypes.DateTime): - __legacy_microseconds__ = False - - def get_col_spec(self): - return "TIMESTAMP" - - def bind_processor(self, dialect): - if self.__legacy_microseconds__: - return self._bind_processor( - "%4.4d-%2.2d-%2.2d %2.2d:%2.2d:%2.2d.%s", - ("year", "month", "day", "hour", "minute", "second", "microsecond") - ) - else: - return self._bind_processor( - "%4.4d-%2.2d-%2.2d %2.2d:%2.2d:%2.2d.%06d", - ("year", "month", "day", "hour", "minute", "second", "microsecond") - ) - - _reg = re.compile(r"(\d+)-(\d+)-(\d+)(?: (\d+):(\d+):(\d+)(?:\.(\d+))?)?") - def result_processor(self, dialect): - return self._result_processor(datetime.datetime, self._reg) - -class SLDate(DateTimeMixin, sqltypes.Date): - def get_col_spec(self): - return "DATE" - - def bind_processor(self, dialect): - return self._bind_processor( - "%4.4d-%2.2d-%2.2d", - ("year", "month", "day") - ) - - _reg = re.compile(r"(\d+)-(\d+)-(\d+)") - def result_processor(self, dialect): - return self._result_processor(datetime.date, self._reg) - -class SLTime(DateTimeMixin, sqltypes.Time): - __legacy_microseconds__ = False - - def get_col_spec(self): - return "TIME" - - def bind_processor(self, dialect): - if self.__legacy_microseconds__: - return self._bind_processor( - "%2.2d:%2.2d:%2.2d.%s", - ("hour", "minute", "second", "microsecond") - ) - else: - return self._bind_processor( - "%2.2d:%2.2d:%2.2d.%06d", - ("hour", "minute", "second", "microsecond") - ) - - _reg = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d+))?") - def result_processor(self, dialect): - return self._result_processor(datetime.time, self._reg) - -class SLUnicodeMixin(object): - def bind_processor(self, dialect): - if self.convert_unicode or dialect.convert_unicode: - if self.assert_unicode is None: - assert_unicode = dialect.assert_unicode - else: - assert_unicode = self.assert_unicode - - if not assert_unicode: - return None - - def process(value): - if not isinstance(value, (unicode, NoneType)): - if assert_unicode == 'warn': - util.warn("Unicode type received non-unicode bind " - "param value %r" % value) - return value - else: - raise exc.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value) - else: - return value - return process - else: - return None - - def result_processor(self, dialect): - return None - -class SLText(SLUnicodeMixin, sqltypes.Text): - def get_col_spec(self): - return "TEXT" - -class SLString(SLUnicodeMixin, sqltypes.String): - def get_col_spec(self): - return "VARCHAR" + (self.length and "(%d)" % self.length or "") - -class SLChar(SLUnicodeMixin, sqltypes.CHAR): - def get_col_spec(self): - return "CHAR" + (self.length and "(%d)" % self.length or "") - -class SLBinary(sqltypes.Binary): - def get_col_spec(self): - return "BLOB" - -class SLBoolean(sqltypes.Boolean): - def get_col_spec(self): - return "BOOLEAN" - - def bind_processor(self, dialect): - def process(value): - if value is None: - return None - return value and 1 or 0 - return process - - def result_processor(self, dialect): - def process(value): - if value is None: - return None - return value and True or False - return process - -colspecs = { - sqltypes.Binary: SLBinary, - sqltypes.Boolean: SLBoolean, - sqltypes.CHAR: SLChar, - sqltypes.Date: SLDate, - sqltypes.DateTime: SLDateTime, - sqltypes.Float: SLFloat, - sqltypes.Integer: SLInteger, - sqltypes.NCHAR: SLChar, - sqltypes.Numeric: SLNumeric, - sqltypes.Smallinteger: SLSmallInteger, - sqltypes.String: SLString, - sqltypes.Text: SLText, - sqltypes.Time: SLTime, -} - -ischema_names = { - 'BLOB': SLBinary, - 'BOOL': SLBoolean, - 'BOOLEAN': SLBoolean, - 'CHAR': SLChar, - 'DATE': SLDate, - 'DATETIME': SLDateTime, - 'DECIMAL': SLNumeric, - 'FLOAT': SLNumeric, - 'INT': SLInteger, - 'INTEGER': SLInteger, - 'NUMERIC': SLNumeric, - 'REAL': SLNumeric, - 'SMALLINT': SLSmallInteger, - 'TEXT': SLText, - 'TIME': SLTime, - 'TIMESTAMP': SLDateTime, - 'VARCHAR': SLString, -} - -class SQLiteExecutionContext(default.DefaultExecutionContext): - def post_exec(self): - if self.compiled.isinsert and not self.executemany: - if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None: - self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:] - -class SQLiteDialect(default.DefaultDialect): - name = 'sqlite' - supports_alter = False - supports_unicode_statements = True - default_paramstyle = 'qmark' - supports_default_values = True - supports_empty_insert = False - - def __init__(self, **kwargs): - default.DefaultDialect.__init__(self, **kwargs) - def vers(num): - return tuple([int(x) for x in num.split('.')]) - if self.dbapi is not None: - sqlite_ver = self.dbapi.version_info - if sqlite_ver < (2, 1, '3'): - util.warn( - ("The installed version of pysqlite2 (%s) is out-dated " - "and will cause errors in some cases. Version 2.1.3 " - "or greater is recommended.") % - '.'.join([str(subver) for subver in sqlite_ver])) - if self.dbapi.sqlite_version_info < (3, 3, 8): - self.supports_default_values = False - self.supports_cast = (self.dbapi is None or vers(self.dbapi.sqlite_version) >= vers("3.2.3")) - - def dbapi(cls): - try: - from pysqlite2 import dbapi2 as sqlite - except ImportError, e: - try: - from sqlite3 import dbapi2 as sqlite #try the 2.5+ stdlib name. - except ImportError: - raise e - return sqlite - dbapi = classmethod(dbapi) - - def server_version_info(self, connection): - return self.dbapi.sqlite_version_info - - def create_connect_args(self, url): - if url.username or url.password or url.host or url.port: - raise exc.ArgumentError( - "Invalid SQLite URL: %s\n" - "Valid SQLite URL forms are:\n" - " sqlite:///:memory: (or, sqlite://)\n" - " sqlite:///relative/path/to/file.db\n" - " sqlite:////absolute/path/to/file.db" % (url,)) - filename = url.database or ':memory:' - - opts = url.query.copy() - util.coerce_kw_type(opts, 'timeout', float) - util.coerce_kw_type(opts, 'isolation_level', str) - util.coerce_kw_type(opts, 'detect_types', int) - util.coerce_kw_type(opts, 'check_same_thread', bool) - util.coerce_kw_type(opts, 'cached_statements', int) - - return ([filename], opts) - - def type_descriptor(self, typeobj): - return sqltypes.adapt_type(typeobj, colspecs) - - def is_disconnect(self, e): - return isinstance(e, self.dbapi.ProgrammingError) and "Cannot operate on a closed database." in str(e) - - def table_names(self, connection, schema): - if schema is not None: - qschema = self.identifier_preparer.quote_identifier(schema) - master = '%s.sqlite_master' % qschema - s = ("SELECT name FROM %s " - "WHERE type='table' ORDER BY name") % (master,) - rs = connection.execute(s) - else: - try: - s = ("SELECT name FROM " - " (SELECT * FROM sqlite_master UNION ALL " - " SELECT * FROM sqlite_temp_master) " - "WHERE type='table' ORDER BY name") - rs = connection.execute(s) - except exc.DBAPIError: - raise - s = ("SELECT name FROM sqlite_master " - "WHERE type='table' ORDER BY name") - rs = connection.execute(s) - - return [row[0] for row in rs] - - def has_table(self, connection, table_name, schema=None): - quote = self.identifier_preparer.quote_identifier - if schema is not None: - pragma = "PRAGMA %s." % quote(schema) - else: - pragma = "PRAGMA " - qtable = quote(table_name) - cursor = connection.execute("%stable_info(%s)" % (pragma, qtable)) - row = cursor.fetchone() - - # consume remaining rows, to work around - # http://www.sqlite.org/cvstrac/tktview?tn=1884 - while cursor.fetchone() is not None: - pass - - return (row is not None) - - def reflecttable(self, connection, table, include_columns): - preparer = self.identifier_preparer - if table.schema is None: - pragma = "PRAGMA " - else: - pragma = "PRAGMA %s." % preparer.quote_identifier(table.schema) - qtable = preparer.format_table(table, False) - - c = connection.execute("%stable_info(%s)" % (pragma, qtable)) - found_table = False - while True: - row = c.fetchone() - if row is None: - break - - found_table = True - (name, type_, nullable, default, has_default, primary_key) = (row[1], row[2].upper(), not row[3], row[4], row[4] is not None, row[5]) - name = re.sub(r'^\"|\"$', '', name) - if include_columns and name not in include_columns: - continue - match = re.match(r'(\w+)(\(.*?\))?', type_) - if match: - coltype = match.group(1) - args = match.group(2) - else: - coltype = "VARCHAR" - args = '' - - try: - coltype = ischema_names[coltype] - except KeyError: - util.warn("Did not recognize type '%s' of column '%s'" % - (coltype, name)) - coltype = sqltypes.NullType - - if args is not None: - args = re.findall(r'(\d+)', args) - coltype = coltype(*[int(a) for a in args]) - - colargs = [] - if has_default: - colargs.append(DefaultClause(sql.text(default))) - table.append_column(schema.Column(name, coltype, primary_key = primary_key, nullable = nullable, *colargs)) - - if not found_table: - raise exc.NoSuchTableError(table.name) - - c = connection.execute("%sforeign_key_list(%s)" % (pragma, qtable)) - fks = {} - while True: - row = c.fetchone() - if row is None: - break - (constraint_name, tablename, localcol, remotecol) = (row[0], row[2], row[3], row[4]) - tablename = re.sub(r'^\"|\"$', '', tablename) - localcol = re.sub(r'^\"|\"$', '', localcol) - remotecol = re.sub(r'^\"|\"$', '', remotecol) - try: - fk = fks[constraint_name] - except KeyError: - fk = ([], []) - fks[constraint_name] = fk - - # look up the table based on the given table's engine, not 'self', - # since it could be a ProxyEngine - remotetable = schema.Table(tablename, table.metadata, autoload=True, autoload_with=connection) - constrained_column = table.c[localcol].name - refspec = ".".join([tablename, remotecol]) - if constrained_column not in fk[0]: - fk[0].append(constrained_column) - if refspec not in fk[1]: - fk[1].append(refspec) - for name, value in fks.iteritems(): - table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], link_to_name=True)) - # check for UNIQUE indexes - c = connection.execute("%sindex_list(%s)" % (pragma, qtable)) - unique_indexes = [] - while True: - row = c.fetchone() - if row is None: - break - if (row[2] == 1): - unique_indexes.append(row[1]) - # loop thru unique indexes for one that includes the primary key - for idx in unique_indexes: - c = connection.execute("%sindex_info(%s)" % (pragma, idx)) - cols = [] - while True: - row = c.fetchone() - if row is None: - break - cols.append(row[2]) - - -class SQLiteCompiler(compiler.DefaultCompiler): - functions = compiler.DefaultCompiler.functions.copy() - functions.update ( - { - sql_functions.now: 'CURRENT_TIMESTAMP', - sql_functions.char_length: 'length%(expr)s' - } - ) - - def visit_cast(self, cast, **kwargs): - if self.dialect.supports_cast: - return super(SQLiteCompiler, self).visit_cast(cast) - else: - return self.process(cast.clause) - - def limit_clause(self, select): - text = "" - if select._limit is not None: - text += " \n LIMIT " + str(select._limit) - if select._offset is not None: - if select._limit is None: - text += " \n LIMIT -1" - text += " OFFSET " + str(select._offset) - else: - text += " OFFSET 0" - return text - - def for_update_clause(self, select): - # sqlite has no "FOR UPDATE" AFAICT - return '' - - -class SQLiteSchemaGenerator(compiler.SchemaGenerator): - - def get_column_specification(self, column, **kwargs): - colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec() - default = self.get_column_default_string(column) - if default is not None: - colspec += " DEFAULT " + default - - if not column.nullable: - colspec += " NOT NULL" - return colspec - -class SQLiteIdentifierPreparer(compiler.IdentifierPreparer): - reserved_words = set([ - 'add', 'after', 'all', 'alter', 'analyze', 'and', 'as', 'asc', - 'attach', 'autoincrement', 'before', 'begin', 'between', 'by', - 'cascade', 'case', 'cast', 'check', 'collate', 'column', 'commit', - 'conflict', 'constraint', 'create', 'cross', 'current_date', - 'current_time', 'current_timestamp', 'database', 'default', - 'deferrable', 'deferred', 'delete', 'desc', 'detach', 'distinct', - 'drop', 'each', 'else', 'end', 'escape', 'except', 'exclusive', - 'explain', 'false', 'fail', 'for', 'foreign', 'from', 'full', 'glob', - 'group', 'having', 'if', 'ignore', 'immediate', 'in', 'index', - 'initially', 'inner', 'insert', 'instead', 'intersect', 'into', 'is', - 'isnull', 'join', 'key', 'left', 'like', 'limit', 'match', 'natural', - 'not', 'notnull', 'null', 'of', 'offset', 'on', 'or', 'order', 'outer', - 'plan', 'pragma', 'primary', 'query', 'raise', 'references', - 'reindex', 'rename', 'replace', 'restrict', 'right', 'rollback', - 'row', 'select', 'set', 'table', 'temp', 'temporary', 'then', 'to', - 'transaction', 'trigger', 'true', 'union', 'unique', 'update', 'using', - 'vacuum', 'values', 'view', 'virtual', 'when', 'where', - ]) - - def __init__(self, dialect): - super(SQLiteIdentifierPreparer, self).__init__(dialect) - -dialect = SQLiteDialect -dialect.poolclass = pool.SingletonThreadPool -dialect.statement_compiler = SQLiteCompiler -dialect.schemagenerator = SQLiteSchemaGenerator -dialect.preparer = SQLiteIdentifierPreparer -dialect.execution_ctx_cls = SQLiteExecutionContext diff --git a/lib/sqlalchemy/databases/sybase.py b/lib/sqlalchemy/databases/sybase.py index 6007315f26..0cf0eeaf56 100644 --- a/lib/sqlalchemy/databases/sybase.py +++ b/lib/sqlalchemy/databases/sybase.py @@ -727,8 +727,8 @@ dialect_mapping = { } -class SybaseSQLCompiler(compiler.DefaultCompiler): - operators = compiler.DefaultCompiler.operators.copy() +class SybaseSQLCompiler(compiler.SQLCompiler): + operators = compiler.SQLCompiler.operators.copy() operators.update({ sql_operators.mod: lambda x, y: "MOD(%s, %s)" % (x, y), }) diff --git a/lib/sqlalchemy/dialects/__init__.py b/lib/sqlalchemy/dialects/__init__.py new file mode 100644 index 0000000000..075e897fa8 --- /dev/null +++ b/lib/sqlalchemy/dialects/__init__.py @@ -0,0 +1,12 @@ +__all__ = ( +# 'access', +# 'firebird', +# 'informix', +# 'maxdb', +# 'mssql', +# 'mysql', +# 'oracle', + 'postgres', + 'sqlite', +# 'sybase', + ) diff --git a/lib/sqlalchemy/dialects/postgres/__init__.py b/lib/sqlalchemy/dialects/postgres/__init__.py new file mode 100644 index 0000000000..c9ac0e1e5a --- /dev/null +++ b/lib/sqlalchemy/dialects/postgres/__init__.py @@ -0,0 +1,3 @@ +from sqlalchemy.dialects.postgres import base, psycopg2 + +base.dialect = psycopg2.dialect \ No newline at end of file diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/dialects/postgres/base.py similarity index 63% rename from lib/sqlalchemy/databases/postgres.py rename to lib/sqlalchemy/dialects/postgres/base.py index fe5ffe24a0..d33a6db935 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/dialects/postgres/base.py @@ -4,90 +4,6 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -"""Support for the PostgreSQL database. - -Driver ------- - -The psycopg2 driver is supported, available at http://pypi.python.org/pypi/psycopg2/ . -The dialect has several behaviors which are specifically tailored towards compatibility -with this module. - -Note that psycopg1 is **not** supported. - -Connecting ----------- - -URLs are of the form `postgres://user:password@host:port/dbname[?key=value&key=value...]`. - -Postgres-specific keyword arguments which are accepted by :func:`~sqlalchemy.create_engine()` are: - -* *server_side_cursors* - Enable the usage of "server side cursors" for SQL statements which support - this feature. What this essentially means from a psycopg2 point of view is that the cursor is - created using a name, e.g. `connection.cursor('some name')`, which has the effect that result rows - are not immediately pre-fetched and buffered after statement execution, but are instead left - on the server and only retrieved as needed. SQLAlchemy's :class:`~sqlalchemy.engine.base.ResultProxy` - uses special row-buffering behavior when this feature is enabled, such that groups of 100 rows - at a time are fetched over the wire to reduce conversational overhead. - -Sequences/SERIAL ----------------- - -Postgres supports sequences, and SQLAlchemy uses these as the default means of creating -new primary key values for integer-based primary key columns. When creating tables, -SQLAlchemy will issue the ``SERIAL`` datatype for integer-based primary key columns, -which generates a sequence corresponding to the column and associated with it based on -a naming convention. - -To specify a specific named sequence to be used for primary key generation, use the -:func:`~sqlalchemy.schema.Sequence` construct:: - - Table('sometable', metadata, - Column('id', Integer, Sequence('some_id_seq'), primary_key=True) - ) - -Currently, when SQLAlchemy issues a single insert statement, to fulfill the contract of -having the "last insert identifier" available, the sequence is executed independently -beforehand and the new value is retrieved, to be used in the subsequent insert. Note -that when an :func:`~sqlalchemy.sql.expression.insert()` construct is executed using -"executemany" semantics, the sequence is not pre-executed and normal PG SERIAL behavior -is used. - -Postgres 8.3 supports an ``INSERT...RETURNING`` syntax which SQLAlchemy supports -as well. A future release of SQLA will use this feature by default in lieu of -sequence pre-execution in order to retrieve new primary key values, when available. - -INSERT/UPDATE...RETURNING -------------------------- - -The dialect supports PG 8.3's ``INSERT..RETURNING`` and ``UPDATE..RETURNING`` syntaxes, -but must be explicitly enabled on a per-statement basis:: - - # INSERT..RETURNING - result = table.insert(postgres_returning=[table.c.col1, table.c.col2]).\\ - values(name='foo') - print result.fetchall() - - # UPDATE..RETURNING - result = table.update(postgres_returning=[table.c.col1, table.c.col2]).\\ - where(table.c.name=='foo').values(name='bar') - print result.fetchall() - -Indexes -------- - -PostgreSQL supports partial indexes. To create them pass a postgres_where -option to the Index constructor:: - - Index('my_index', my_table.c.id, postgres_where=tbl.c.value > 10) - -Transactions ------------- - -The Postgres dialect fully supports SAVEPOINT and two-phase commit operations. - - -""" import decimal, random, re, string @@ -99,101 +15,23 @@ from sqlalchemy import types as sqltypes class PGInet(sqltypes.TypeEngine): - def get_col_spec(self): - return "INET" + __visit_name__ = "INET" class PGCidr(sqltypes.TypeEngine): - def get_col_spec(self): - return "CIDR" + __visit_name__ = "CIDR" class PGMacAddr(sqltypes.TypeEngine): - def get_col_spec(self): - return "MACADDR" - -class PGNumeric(sqltypes.Numeric): - def get_col_spec(self): - if not self.precision: - return "NUMERIC" - else: - return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale} - - def bind_processor(self, dialect): - return None - - def result_processor(self, dialect): - if self.asdecimal: - return None - else: - def process(value): - if isinstance(value, decimal.Decimal): - return float(value) - else: - return value - return process - -class PGFloat(sqltypes.Float): - def get_col_spec(self): - if not self.precision: - return "FLOAT" - else: - return "FLOAT(%(precision)s)" % {'precision': self.precision} - - -class PGInteger(sqltypes.Integer): - def get_col_spec(self): - return "INTEGER" - -class PGSmallInteger(sqltypes.Smallinteger): - def get_col_spec(self): - return "SMALLINT" - -class PGBigInteger(PGInteger): - def get_col_spec(self): - return "BIGINT" - -class PGDateTime(sqltypes.DateTime): - def get_col_spec(self): - return "TIMESTAMP " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE" - -class PGDate(sqltypes.Date): - def get_col_spec(self): - return "DATE" + __visit_name__ = "MACADDR" -class PGTime(sqltypes.Time): - def get_col_spec(self): - return "TIME " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE" +class PGBigInteger(sqltypes.Integer): + __visit_name__ = "BIGINT" class PGInterval(sqltypes.TypeEngine): - def get_col_spec(self): - return "INTERVAL" - -class PGText(sqltypes.Text): - def get_col_spec(self): - return "TEXT" - -class PGString(sqltypes.String): - def get_col_spec(self): - if self.length: - return "VARCHAR(%(length)d)" % {'length' : self.length} - else: - return "VARCHAR" - -class PGChar(sqltypes.CHAR): - def get_col_spec(self): - if self.length: - return "CHAR(%(length)d)" % {'length' : self.length} - else: - return "CHAR" - -class PGBinary(sqltypes.Binary): - def get_col_spec(self): - return "BYTEA" - -class PGBoolean(sqltypes.Boolean): - def get_col_spec(self): - return "BOOLEAN" + __visit_name__ = 'INTERVAL' class PGArray(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine): + __visit_name__ = 'ARRAY' + def __init__(self, item_type, mutable=True): if isinstance(item_type, type): item_type = item_type() @@ -251,114 +89,233 @@ class PGArray(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine): return item return [convert_item(item) for item in value] return process - def get_col_spec(self): - return self.item_type.get_col_spec() + '[]' - -colspecs = { - sqltypes.Integer : PGInteger, - sqltypes.Smallinteger : PGSmallInteger, - sqltypes.Numeric : PGNumeric, - sqltypes.Float : PGFloat, - sqltypes.DateTime : PGDateTime, - sqltypes.Date : PGDate, - sqltypes.Time : PGTime, - sqltypes.String : PGString, - sqltypes.Binary : PGBinary, - sqltypes.Boolean : PGBoolean, - sqltypes.Text : PGText, - sqltypes.CHAR: PGChar, -} - -ischema_names = { - 'integer' : PGInteger, - 'bigint' : PGBigInteger, - 'smallint' : PGSmallInteger, - 'character varying' : PGString, - 'character' : PGChar, - 'text' : PGText, - 'numeric' : PGNumeric, - 'float' : PGFloat, - 'real' : PGFloat, - 'inet': PGInet, - 'cidr': PGCidr, - 'macaddr': PGMacAddr, - 'double precision' : PGFloat, - 'timestamp' : PGDateTime, - 'timestamp with time zone' : PGDateTime, - 'timestamp without time zone' : PGDateTime, - 'time with time zone' : PGTime, - 'time without time zone' : PGTime, - 'date' : PGDate, - 'time': PGTime, - 'bytea' : PGBinary, - 'boolean' : PGBoolean, - 'interval':PGInterval, -} - -# TODO: filter out 'FOR UPDATE' statements -SERVER_SIDE_CURSOR_RE = re.compile( - r'\s*SELECT', - re.I | re.UNICODE) - -class PGExecutionContext(default.DefaultExecutionContext): - def create_cursor(self): - # TODO: coverage for server side cursors + select.for_update() - is_server_side = \ - self.dialect.server_side_cursors and \ - ((self.compiled and isinstance(self.compiled.statement, expression.Selectable) - and not getattr(self.compiled.statement, 'for_update', False)) \ - or \ - ( - (not self.compiled or isinstance(self.compiled.statement, expression._TextClause)) - and self.statement and SERVER_SIDE_CURSOR_RE.match(self.statement)) - ) - self.__is_server_side = is_server_side - if is_server_side: - # use server-side cursors: - # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html - ident = "c_%s_%s" % (hex(id(self))[2:], hex(random.randint(0, 65535))[2:]) - return self._connection.connection.cursor(ident) + + + + +class PGCompiler(compiler.SQLCompiler): + operators = compiler.SQLCompiler.operators.copy() + operators.update( + { + sql_operators.mod : '%%', + sql_operators.ilike_op: lambda x, y, escape=None: '%s ILIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''), + sql_operators.notilike_op: lambda x, y, escape=None: '%s NOT ILIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''), + sql_operators.match_op: lambda x, y: '%s @@ to_tsquery(%s)' % (x, y), + } + ) + + functions = compiler.SQLCompiler.functions.copy() + functions.update ( + { + 'TIMESTAMP':lambda x:'TIMESTAMP %s' % x, + } + ) + + def visit_sequence(self, seq): + if seq.optional: + return None + else: + return "nextval('%s')" % self.preparer.format_sequence(seq) + + def post_process_text(self, text): + if '%%' in text: + util.warn("The SQLAlchemy postgres dialect now automatically escapes '%' in text() expressions to '%%'.") + return text.replace('%', '%%') + + def limit_clause(self, select): + text = "" + if select._limit is not None: + text += " \n LIMIT " + str(select._limit) + if select._offset is not None: + if select._limit is None: + text += " \n LIMIT ALL" + text += " OFFSET " + str(select._offset) + return text + + def get_select_precolumns(self, select): + if select._distinct: + if isinstance(select._distinct, bool): + return "DISTINCT " + elif isinstance(select._distinct, (list, tuple)): + return "DISTINCT ON (" + ', '.join( + [(isinstance(col, basestring) and col or self.process(col)) for col in select._distinct] + )+ ") " + else: + return "DISTINCT ON (" + unicode(select._distinct) + ") " + else: + return "" + + def for_update_clause(self, select): + if select.for_update == 'nowait': + return " FOR UPDATE NOWAIT" + else: + return super(PGCompiler, self).for_update_clause(select) + + def _append_returning(self, text, stmt): + returning_cols = stmt.kwargs['postgres_returning'] + def flatten_columnlist(collist): + for c in collist: + if isinstance(c, expression.Selectable): + for co in c.columns: + yield co + else: + yield c + columns = [self.process(c, within_columns_clause=True) for c in flatten_columnlist(returning_cols)] + text += ' RETURNING ' + string.join(columns, ', ') + return text + + def visit_update(self, update_stmt): + text = super(PGCompiler, self).visit_update(update_stmt) + if 'postgres_returning' in update_stmt.kwargs: + return self._append_returning(text, update_stmt) + else: + return text + + def visit_insert(self, insert_stmt): + text = super(PGCompiler, self).visit_insert(insert_stmt) + if 'postgres_returning' in insert_stmt.kwargs: + return self._append_returning(text, insert_stmt) + else: + return text + +class PGDDLCompiler(compiler.DDLCompiler): + def get_column_specification(self, column, **kwargs): + colspec = self.preparer.format_column(column) + if column.primary_key and \ + len(column.foreign_keys)==0 and \ + column.autoincrement and \ + isinstance(column.type, sqltypes.Integer) and \ + not isinstance(column.type, sqltypes.SmallInteger) and \ + (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): + if isinstance(column.type, PGBigInteger): + colspec += " BIGSERIAL" + else: + colspec += " SERIAL" else: - return self._connection.connection.cursor() + colspec += " " + self.dialect.type_compiler.process(column.type) + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + + if not column.nullable: + colspec += " NOT NULL" + return colspec + + def visit_create_sequence(self, create): + return "CREATE SEQUENCE %s" % self.preparer.format_sequence(create.element) + + def visit_drop_sequence(self, drop): + return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element) + + def visit_create_index(self, create): + preparer = self.preparer + index = create.element + text = "CREATE " + if index.unique: + text += "UNIQUE " + text += "INDEX %s ON %s (%s)" \ + % (preparer.quote(self._validate_identifier(index.name, True), index.quote), + preparer.format_table(index.table), + string.join([preparer.format_column(c) for c in index.columns], ', ')) + + whereclause = index.kwargs.get('postgres_where', None) + if whereclause is not None: + compiler = self._compile(whereclause, None) + # this might belong to the compiler class + inlined_clause = str(compiler) % dict( + [(key,bind.value) for key,bind in compiler.binds.iteritems()]) + text += " WHERE " + inlined_clause + return text + +class PGDefaultRunner(base.DefaultRunner): + def __init__(self, context): + base.DefaultRunner.__init__(self, context) + # craete cursor which won't conflict with a server-side cursor + self.cursor = context._connection.connection.cursor() - def get_result_proxy(self): - if self.__is_server_side: - return base.BufferedRowResultProxy(self) + def get_column_default(self, column, isinsert=True): + if column.primary_key: + # pre-execute passive defaults on primary keys + if (isinstance(column.server_default, schema.DefaultClause) and + column.server_default.arg is not None): + return self.execute_string("select %s" % column.server_default.arg) + elif (isinstance(column.type, sqltypes.Integer) and column.autoincrement) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): + sch = column.table.schema + # TODO: this has to build into the Sequence object so we can get the quoting + # logic from it + if sch is not None: + exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name) + else: + exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name) + return self.execute_string(exc.encode(self.dialect.encoding)) + + return super(PGDefaultRunner, self).get_column_default(column) + + def visit_sequence(self, seq): + if not seq.optional: + return self.execute_string(("select nextval('%s')" % self.dialect.identifier_preparer.format_sequence(seq))) + else: + return None + +class PGTypeCompiler(compiler.GenericTypeCompiler): + def visit_INET(self, type_): + return "INET" + + def visit_CIDR(self, type_): + return "CIDR" + + def visit_MACADDR(self, type_): + return "MACADDR" + + def visit_FLOAT(self, type_): + if not type_.precision: + return "FLOAT" else: - return base.ResultProxy(self) + return "FLOAT(%(precision)s)" % {'precision': type_.precision} + + def visit_BIGINT(self, type_): + return "BIGINT" + + def visit_DATETIME(self, type_): + return "TIMESTAMP " + (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE" + + def visit_TIME(self, type_): + return "TIME " + (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE" + + def visit_INTERVAL(self, type_): + return "INTERVAL" + + def visit_BINARY(self, type_): + return "BYTEA" + + def visit_ARRAY(self, type_): + return self.process(type_.item_type) + '[]' + +class PGIdentifierPreparer(compiler.IdentifierPreparer): + def _unquote_identifier(self, value): + if value[0] == self.initial_quote: + value = value[1:-1].replace('""','"') + return value class PGDialect(default.DefaultDialect): name = 'postgres' supports_alter = True - supports_unicode_statements = False max_identifier_length = 63 supports_sane_rowcount = True - supports_sane_multi_rowcount = False + supports_sequences = True + sequences_optional = True preexecute_pk_sequences = True supports_pk_autoincrement = False - default_paramstyle = 'pyformat' supports_default_values = True supports_empty_insert = False - - def __init__(self, server_side_cursors=False, **kwargs): - default.DefaultDialect.__init__(self, **kwargs) - self.server_side_cursors = server_side_cursors - def dbapi(cls): - import psycopg2 as psycopg - return psycopg - dbapi = classmethod(dbapi) + statement_compiler = PGCompiler + ddl_compiler = PGDDLCompiler + type_compiler = PGTypeCompiler + preparer = PGIdentifierPreparer + defaultrunner = PGDefaultRunner - def create_connect_args(self, url): - opts = url.translate_connect_args(username='user') - if 'port' in opts: - opts['port'] = int(opts['port']) - opts.update(url.query) - return ([], opts) - - def type_descriptor(self, typeobj): - return sqltypes.adapt_type(typeobj, colspecs) def do_begin_twophase(self, connection, xid): self.do_begin(connection.connection) @@ -392,48 +349,46 @@ class PGDialect(default.DefaultDialect): resultset = connection.execute(sql.text("SELECT gid FROM pg_prepared_xacts")) return [row[0] for row in resultset] + @base.connection_memoize(('dialect', 'default_schema_name')) def get_default_schema_name(self, connection): return connection.scalar("select current_schema()", None) - get_default_schema_name = base.connection_memoize( - ('dialect', 'default_schema_name'))(get_default_schema_name) - - def last_inserted_ids(self): - if self.context.last_inserted_ids is None: - raise exc.InvalidRequestError("no INSERT executed, or can't use cursor.lastrowid without Postgres OIDs enabled") - else: - return self.context.last_inserted_ids def has_table(self, connection, table_name, schema=None): # seems like case gets folded in pg_class... if schema is None: - cursor = connection.execute("""select relname from pg_class c join pg_namespace n on n.oid=c.relnamespace where n.nspname=current_schema() and lower(relname)=%(name)s""", {'name':table_name.lower().encode(self.encoding)}); + cursor = connection.execute( + sql.text("""select relname from pg_class c join pg_namespace n on n.oid=c.relnamespace where n.nspname=current_schema() and lower(relname)=:name""", + bindparams=[sql.bindparam('name', unicode(table_name.lower()), type_=sqltypes.Unicode)] + ) + ) else: - cursor = connection.execute("""select relname from pg_class c join pg_namespace n on n.oid=c.relnamespace where n.nspname=%(schema)s and lower(relname)=%(name)s""", {'name':table_name.lower().encode(self.encoding), 'schema':schema}); - return bool( not not cursor.rowcount ) + cursor = connection.execute( + sql.text("""select relname from pg_class c join pg_namespace n on n.oid=c.relnamespace where n.nspname=:schema and lower(relname)=:name""", + bindparams=[sql.bindparam('name', unicode(table_name.lower()), type_=sqltypes.Unicode), + sql.bindparam('schema', unicode(schema), type_=sqltypes.Unicode)] + ) + ) + return bool( cursor.rowcount ) def has_sequence(self, connection, sequence_name): - cursor = connection.execute('''SELECT relname FROM pg_class WHERE relkind = 'S' AND relnamespace IN ( SELECT oid FROM pg_namespace WHERE nspname NOT LIKE 'pg_%%' AND nspname != 'information_schema' AND relname = %(seqname)s);''', {'seqname': sequence_name.encode(self.encoding)}) - return bool(not not cursor.rowcount) - - def is_disconnect(self, e): - if isinstance(e, self.dbapi.OperationalError): - return 'closed the connection' in str(e) or 'connection not open' in str(e) - elif isinstance(e, self.dbapi.InterfaceError): - return 'connection already closed' in str(e) or 'cursor already closed' in str(e) - elif isinstance(e, self.dbapi.ProgrammingError): - # yes, it really says "losed", not "closed" - return "losed the connection unexpectedly" in str(e) - else: - return False + cursor = connection.execute( + sql.text("SELECT relname FROM pg_class WHERE relkind = 'S' AND " + "relnamespace IN ( SELECT oid FROM pg_namespace WHERE nspname NOT LIKE 'pg_%' " + "AND nspname != 'information_schema' AND relname = :seqname)", + bindparams=[sql.bindparam('seqname', unicode(sequence_name), type_=sqltypes.Unicode)] + )) + return bool(cursor.rowcount) def table_names(self, connection, schema): - s = """ - SELECT relname - FROM pg_class c - WHERE relkind = 'r' - AND '%(schema)s' = (select nspname from pg_namespace n where n.oid = c.relnamespace) - """ % locals() - return [row[0].decode(self.encoding) for row in connection.execute(s)] + result = connection.execute( + sql.text(u"""SELECT relname + FROM pg_class c + WHERE relkind = 'r' + AND '%s' = (select nspname from pg_namespace n where n.oid = c.relnamespace)""" % schema, + typemap = {'relname':sqltypes.Unicode} + ) + ) + return [row[0] for row in result] def server_version_info(self, connection): v = connection.execute("select version()").scalar() @@ -525,19 +480,19 @@ class PGDialect(default.DefaultDialect): elif attype == 'timestamp without time zone': kwargs['timezone'] = False - if attype in ischema_names: - coltype = ischema_names[attype] + if attype in self.ischema_names: + coltype = self.ischema_names[attype] else: if attype in domains: domain = domains[attype] - if domain['attype'] in ischema_names: + if domain['attype'] in self.ischema_names: # A table can't override whether the domain is nullable. nullable = domain['nullable'] if domain['default'] and not default: # It can, however, override the default value, but can't set it to null. default = domain['default'] - coltype = ischema_names[domain['attype']] + coltype = self.ischema_names[domain['attype']] else: coltype = None @@ -693,180 +648,3 @@ class PGDialect(default.DefaultDialect): return domains - - -class PGCompiler(compiler.DefaultCompiler): - operators = compiler.DefaultCompiler.operators.copy() - operators.update( - { - sql_operators.mod : '%%', - sql_operators.ilike_op: lambda x, y, escape=None: '%s ILIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''), - sql_operators.notilike_op: lambda x, y, escape=None: '%s NOT ILIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''), - sql_operators.match_op: lambda x, y: '%s @@ to_tsquery(%s)' % (x, y), - } - ) - - functions = compiler.DefaultCompiler.functions.copy() - functions.update ( - { - 'TIMESTAMP':lambda x:'TIMESTAMP %s' % x, - } - ) - - def visit_sequence(self, seq): - if seq.optional: - return None - else: - return "nextval('%s')" % self.preparer.format_sequence(seq) - - def post_process_text(self, text): - if '%%' in text: - util.warn("The SQLAlchemy psycopg2 dialect now automatically escapes '%' in text() expressions to '%%'.") - return text.replace('%', '%%') - - def limit_clause(self, select): - text = "" - if select._limit is not None: - text += " \n LIMIT " + str(select._limit) - if select._offset is not None: - if select._limit is None: - text += " \n LIMIT ALL" - text += " OFFSET " + str(select._offset) - return text - - def get_select_precolumns(self, select): - if select._distinct: - if isinstance(select._distinct, bool): - return "DISTINCT " - elif isinstance(select._distinct, (list, tuple)): - return "DISTINCT ON (" + ', '.join( - [(isinstance(col, basestring) and col or self.process(col)) for col in select._distinct] - )+ ") " - else: - return "DISTINCT ON (" + unicode(select._distinct) + ") " - else: - return "" - - def for_update_clause(self, select): - if select.for_update == 'nowait': - return " FOR UPDATE NOWAIT" - else: - return super(PGCompiler, self).for_update_clause(select) - - def _append_returning(self, text, stmt): - returning_cols = stmt.kwargs['postgres_returning'] - def flatten_columnlist(collist): - for c in collist: - if isinstance(c, expression.Selectable): - for co in c.columns: - yield co - else: - yield c - columns = [self.process(c, within_columns_clause=True) for c in flatten_columnlist(returning_cols)] - text += ' RETURNING ' + string.join(columns, ', ') - return text - - def visit_update(self, update_stmt): - text = super(PGCompiler, self).visit_update(update_stmt) - if 'postgres_returning' in update_stmt.kwargs: - return self._append_returning(text, update_stmt) - else: - return text - - def visit_insert(self, insert_stmt): - text = super(PGCompiler, self).visit_insert(insert_stmt) - if 'postgres_returning' in insert_stmt.kwargs: - return self._append_returning(text, insert_stmt) - else: - return text - -class PGSchemaGenerator(compiler.SchemaGenerator): - def get_column_specification(self, column, **kwargs): - colspec = self.preparer.format_column(column) - if column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and isinstance(column.type, sqltypes.Integer) and not isinstance(column.type, sqltypes.SmallInteger) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): - if isinstance(column.type, PGBigInteger): - colspec += " BIGSERIAL" - else: - colspec += " SERIAL" - else: - colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec() - default = self.get_column_default_string(column) - if default is not None: - colspec += " DEFAULT " + default - - if not column.nullable: - colspec += " NOT NULL" - return colspec - - def visit_sequence(self, sequence): - if not sequence.optional and (not self.checkfirst or not self.dialect.has_sequence(self.connection, sequence.name)): - self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence)) - self.execute() - - def visit_index(self, index): - preparer = self.preparer - self.append("CREATE ") - if index.unique: - self.append("UNIQUE ") - self.append("INDEX %s ON %s (%s)" \ - % (preparer.quote(self._validate_identifier(index.name, True), index.quote), - preparer.format_table(index.table), - string.join([preparer.format_column(c) for c in index.columns], ', '))) - whereclause = index.kwargs.get('postgres_where', None) - if whereclause is not None: - compiler = self._compile(whereclause, None) - # this might belong to the compiler class - inlined_clause = str(compiler) % dict( - [(key,bind.value) for key,bind in compiler.binds.iteritems()]) - self.append(" WHERE " + inlined_clause) - self.execute() - -class PGSchemaDropper(compiler.SchemaDropper): - def visit_sequence(self, sequence): - if not sequence.optional and (not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name)): - self.append("DROP SEQUENCE %s" % self.preparer.format_sequence(sequence)) - self.execute() - -class PGDefaultRunner(base.DefaultRunner): - def __init__(self, context): - base.DefaultRunner.__init__(self, context) - # craete cursor which won't conflict with a server-side cursor - self.cursor = context._connection.connection.cursor() - - def get_column_default(self, column, isinsert=True): - if column.primary_key: - # pre-execute passive defaults on primary keys - if (isinstance(column.server_default, schema.DefaultClause) and - column.server_default.arg is not None): - return self.execute_string("select %s" % column.server_default.arg) - elif (isinstance(column.type, sqltypes.Integer) and column.autoincrement) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): - sch = column.table.schema - # TODO: this has to build into the Sequence object so we can get the quoting - # logic from it - if sch is not None: - exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name) - else: - exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name) - return self.execute_string(exc.encode(self.dialect.encoding)) - - return super(PGDefaultRunner, self).get_column_default(column) - - def visit_sequence(self, seq): - if not seq.optional: - return self.execute_string(("select nextval('%s')" % self.dialect.identifier_preparer.format_sequence(seq))) - else: - return None - -class PGIdentifierPreparer(compiler.IdentifierPreparer): - def _unquote_identifier(self, value): - if value[0] == self.initial_quote: - value = value[1:-1].replace('""','"') - return value - -dialect = PGDialect -dialect.statement_compiler = PGCompiler -dialect.schemagenerator = PGSchemaGenerator -dialect.schemadropper = PGSchemaDropper -dialect.preparer = PGIdentifierPreparer -dialect.defaultrunner = PGDefaultRunner -dialect.execution_ctx_cls = PGExecutionContext diff --git a/lib/sqlalchemy/dialects/postgres/psycopg2.py b/lib/sqlalchemy/dialects/postgres/psycopg2.py new file mode 100644 index 0000000000..5cda71bb55 --- /dev/null +++ b/lib/sqlalchemy/dialects/postgres/psycopg2.py @@ -0,0 +1,215 @@ +"""Support for the PostgreSQL database via the psycopg2 driver. + +Driver +------ + +The psycopg2 driver is supported, available at http://pypi.python.org/pypi/psycopg2/ . +The dialect has several behaviors which are specifically tailored towards compatibility +with this module. + +Note that psycopg1 is **not** supported. + +Connecting +---------- + +URLs are of the form `postgres+psycopg2://user@password@host:port/dbname[?key=value&key=value...]`. + +psycopg2-specific keyword arguments which are accepted by :func:`~sqlalchemy.create_engine()` are: + +* *server_side_cursors* - Enable the usage of "server side cursors" for SQL statements which support + this feature. What this essentially means from a psycopg2 point of view is that the cursor is + created using a name, e.g. `connection.cursor('some name')`, which has the effect that result rows + are not immediately pre-fetched and buffered after statement execution, but are instead left + on the server and only retrieved as needed. SQLAlchemy's :class:`~sqlalchemy.engine.base.ResultProxy` + uses special row-buffering behavior when this feature is enabled, such that groups of 100 rows + at a time are fetched over the wire to reduce conversational overhead. + +Sequences/SERIAL +---------------- + +Postgres supports sequences, and SQLAlchemy uses these as the default means of creating +new primary key values for integer-based primary key columns. When creating tables, +SQLAlchemy will issue the ``SERIAL`` datatype for integer-based primary key columns, +which generates a sequence corresponding to the column and associated with it based on +a naming convention. + +To specify a specific named sequence to be used for primary key generation, use the +:func:`~sqlalchemy.schema.Sequence` construct:: + + Table('sometable', metadata, + Column('id', Integer, Sequence('some_id_seq'), primary_key=True) + ) + +Currently, when SQLAlchemy issues a single insert statement, to fulfill the contract of +having the "last insert identifier" available, the sequence is executed independently +beforehand and the new value is retrieved, to be used in the subsequent insert. Note +that when an :func:`~sqlalchemy.sql.expression.insert()` construct is executed using +"executemany" semantics, the sequence is not pre-executed and normal PG SERIAL behavior +is used. + +Postgres 8.3 supports an ``INSERT...RETURNING`` syntax which SQLAlchemy supports +as well. A future release of SQLA will use this feature by default in lieu of +sequence pre-execution in order to retrieve new primary key values, when available. + +INSERT/UPDATE...RETURNING +------------------------- + +The dialect supports PG 8.3's ``INSERT..RETURNING`` and ``UPDATE..RETURNING`` syntaxes, +but must be explicitly enabled on a per-statement basis:: + + # INSERT..RETURNING + result = table.insert(postgres_returning=[table.c.col1, table.c.col2]).\\ + values(name='foo') + print result.fetchall() + + # UPDATE..RETURNING + result = table.update(postgres_returning=[table.c.col1, table.c.col2]).\\ + where(table.c.name=='foo').values(name='bar') + print result.fetchall() + +Indexes +------- + +PostgreSQL supports partial indexes. To create them pass a postgres_where +option to the Index constructor:: + + Index('my_index', my_table.c.id, postgres_where=tbl.c.value > 10) + +Transactions +------------ + +The psycopg2 dialect fully supports SAVEPOINT and two-phase commit operations. + + +""" + +import decimal, random, re, string + +from sqlalchemy import sql, schema, exc, util +from sqlalchemy.engine import base, default +from sqlalchemy.sql import compiler, expression +from sqlalchemy.sql import operators as sql_operators +from sqlalchemy import types as sqltypes +from sqlalchemy.dialects.postgres.base import PGDialect, PGInet, PGCidr, PGMacAddr, PGArray, \ + PGBigInteger, PGInterval + +class PGNumeric(sqltypes.Numeric): + def bind_processor(self, dialect): + return None + + def result_processor(self, dialect): + if self.asdecimal: + return None + else: + def process(value): + if isinstance(value, decimal.Decimal): + return float(value) + else: + return value + return process + + +colspecs = { + sqltypes.Numeric : PGNumeric, + sqltypes.Float: sqltypes.Float, # prevents PGNumeric from being used +} + +ischema_names = { + 'integer' : sqltypes.Integer, + 'bigint' : PGBigInteger, + 'smallint' : sqltypes.SmallInteger, + 'character varying' : sqltypes.String, + 'character' : sqltypes.CHAR, + 'text' : sqltypes.Text, + 'numeric' : PGNumeric, + 'float' : sqltypes.Float, + 'real' : sqltypes.Float, + 'inet': PGInet, + 'cidr': PGCidr, + 'macaddr': PGMacAddr, + 'double precision' : sqltypes.Float, + 'timestamp' : sqltypes.DateTime, + 'timestamp with time zone' : sqltypes.DateTime, + 'timestamp without time zone' : sqltypes.DateTime, + 'time with time zone' : sqltypes.Time, + 'time without time zone' : sqltypes.Time, + 'date' : sqltypes.Date, + 'time': sqltypes.Time, + 'bytea' : sqltypes.Binary, + 'boolean' : sqltypes.Boolean, + 'interval':PGInterval, +} + +# TODO: filter out 'FOR UPDATE' statements +SERVER_SIDE_CURSOR_RE = re.compile( + r'\s*SELECT', + re.I | re.UNICODE) + +class Postgres_psycopg2ExecutionContext(default.DefaultExecutionContext): + def create_cursor(self): + # TODO: coverage for server side cursors + select.for_update() + is_server_side = \ + self.dialect.server_side_cursors and \ + ((self.compiled and isinstance(self.compiled.statement, expression.Selectable) + and not getattr(self.compiled.statement, 'for_update', False)) \ + or \ + ( + (not self.compiled or isinstance(self.compiled.statement, expression._TextClause)) + and self.statement and SERVER_SIDE_CURSOR_RE.match(self.statement)) + ) + + self.__is_server_side = is_server_side + if is_server_side: + # use server-side cursors: + # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html + ident = "c_%s_%s" % (hex(id(self))[2:], hex(random.randint(0, 65535))[2:]) + return self._connection.connection.cursor(ident) + else: + return self._connection.connection.cursor() + + def get_result_proxy(self): + if self.__is_server_side: + return base.BufferedRowResultProxy(self) + else: + return base.ResultProxy(self) + +class Postgres_psycopg2(PGDialect): + driver = 'psycopg2' + supports_unicode_statements = False + default_paramstyle = 'pyformat' + supports_sane_multi_rowcount = False + execution_ctx_cls = Postgres_psycopg2ExecutionContext + ischema_names = ischema_names + + def __init__(self, server_side_cursors=False, **kwargs): + PGDialect.__init__(self, **kwargs) + self.server_side_cursors = server_side_cursors + + @classmethod + def dbapi(cls): + psycopg = __import__('psycopg2') + return psycopg + + def create_connect_args(self, url): + opts = url.translate_connect_args(username='user') + if 'port' in opts: + opts['port'] = int(opts['port']) + opts.update(url.query) + return ([], opts) + + def type_descriptor(self, typeobj): + return sqltypes.adapt_type(typeobj, colspecs) + + def is_disconnect(self, e): + if isinstance(e, self.dbapi.OperationalError): + return 'closed the connection' in str(e) or 'connection not open' in str(e) + elif isinstance(e, self.dbapi.InterfaceError): + return 'connection already closed' in str(e) or 'cursor already closed' in str(e) + elif isinstance(e, self.dbapi.ProgrammingError): + # yes, it really says "losed", not "closed" + return "losed the connection unexpectedly" in str(e) + else: + return False + +dialect = Postgres_psycopg2 + \ No newline at end of file diff --git a/lib/sqlalchemy/dialects/sqlite/__init__.py b/lib/sqlalchemy/dialects/sqlite/__init__.py new file mode 100644 index 0000000000..3cc08870f2 --- /dev/null +++ b/lib/sqlalchemy/dialects/sqlite/__init__.py @@ -0,0 +1,4 @@ +from sqlalchemy.dialects.sqlite import base, pysqlite + +# default dialect +base.dialect = pysqlite.dialect \ No newline at end of file diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py new file mode 100644 index 0000000000..a080b94ec0 --- /dev/null +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -0,0 +1,339 @@ +# sqlite.py +# Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +import datetime, re, time + +from sqlalchemy import sql, schema, exc, pool, DefaultClause +from sqlalchemy.engine import default +from sqlalchemy.engine import default +from sqlalchemy import types as sqltypes +from sqlalchemy import util +from sqlalchemy.sql import compiler, functions as sql_functions +from types import NoneType + +class NumericMixin(object): + def bind_processor(self, dialect): + type_ = self.asdecimal and str or float + def process(value): + if value is not None: + return type_(value) + else: + return value + return process + +class SLNumeric(NumericMixin, sqltypes.Numeric): + pass + +class SLFloat(NumericMixin, sqltypes.Float): + pass + +# since SQLite has no date types, we're assuming that SQLite via ODBC +# or JDBC would similarly have no built in date support, so the "string" based logic +# would apply to all implementing dialects. +class DateTimeMixin(object): + def _bind_processor(self, format, elements): + def process(value): + if not isinstance(value, (NoneType, datetime.date, datetime.datetime, datetime.time)): + raise TypeError("SQLite Date, Time, and DateTime types only accept Python datetime objects as input.") + elif value is not None: + return format % tuple([getattr(value, attr, 0) for attr in elements]) + else: + return None + return process + + def _result_processor(self, fn, regexp): + def process(value): + if value is not None: + return fn(*[int(x or 0) for x in regexp.match(value).groups()]) + else: + return None + return process + +class SLDateTime(DateTimeMixin, sqltypes.DateTime): + __legacy_microseconds__ = False + + def bind_processor(self, dialect): + if self.__legacy_microseconds__: + return self._bind_processor( + "%4.4d-%2.2d-%2.2d %2.2d:%2.2d:%2.2d.%s", + ("year", "month", "day", "hour", "minute", "second", "microsecond") + ) + else: + return self._bind_processor( + "%4.4d-%2.2d-%2.2d %2.2d:%2.2d:%2.2d.%06d", + ("year", "month", "day", "hour", "minute", "second", "microsecond") + ) + + _reg = re.compile(r"(\d+)-(\d+)-(\d+)(?: (\d+):(\d+):(\d+)(?:\.(\d+))?)?") + def result_processor(self, dialect): + return self._result_processor(datetime.datetime, self._reg) + +class SLDate(DateTimeMixin, sqltypes.Date): + def bind_processor(self, dialect): + return self._bind_processor( + "%4.4d-%2.2d-%2.2d", + ("year", "month", "day") + ) + + _reg = re.compile(r"(\d+)-(\d+)-(\d+)") + def result_processor(self, dialect): + return self._result_processor(datetime.date, self._reg) + +class SLTime(DateTimeMixin, sqltypes.Time): + __legacy_microseconds__ = False + + def bind_processor(self, dialect): + if self.__legacy_microseconds__: + return self._bind_processor( + "%2.2d:%2.2d:%2.2d.%s", + ("hour", "minute", "second", "microsecond") + ) + else: + return self._bind_processor( + "%2.2d:%2.2d:%2.2d.%06d", + ("hour", "minute", "second", "microsecond") + ) + + _reg = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d+))?") + def result_processor(self, dialect): + return self._result_processor(datetime.time, self._reg) + + +class SLBoolean(sqltypes.Boolean): + def bind_processor(self, dialect): + def process(value): + if value is None: + return None + return value and 1 or 0 + return process + + def result_processor(self, dialect): + def process(value): + if value is None: + return None + return value and True or False + return process + + +class SQLiteCompiler(compiler.SQLCompiler): + functions = compiler.SQLCompiler.functions.copy() + functions.update ( + { + sql_functions.now: 'CURRENT_TIMESTAMP', + sql_functions.char_length: 'length%(expr)s' + } + ) + + def visit_cast(self, cast, **kwargs): + if self.dialect.supports_cast: + return super(SQLiteCompiler, self).visit_cast(cast) + else: + return self.process(cast.clause) + + def limit_clause(self, select): + text = "" + if select._limit is not None: + text += " \n LIMIT " + str(select._limit) + if select._offset is not None: + if select._limit is None: + text += " \n LIMIT -1" + text += " OFFSET " + str(select._offset) + else: + text += " OFFSET 0" + return text + + def for_update_clause(self, select): + # sqlite has no "FOR UPDATE" AFAICT + return '' + + +class SQLiteDDLCompiler(compiler.DDLCompiler): + + def get_column_specification(self, column, **kwargs): + colspec = self.preparer.format_column(column) + " " + self.dialect.type_compiler.process(column.type) + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + + if not column.nullable: + colspec += " NOT NULL" + return colspec + +class SQLiteTypeCompiler(compiler.GenericTypeCompiler): + def visit_binary(self, type_): + return self.visit_BLOB(type_) + + def visit_CLOB(self, type_): + return self.visit_TEXT(type_) + + def visit_NCHAR(self, type_): + return self.visit_CHAR(type_) + +class SQLiteIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = set([ + 'add', 'after', 'all', 'alter', 'analyze', 'and', 'as', 'asc', + 'attach', 'autoincrement', 'before', 'begin', 'between', 'by', + 'cascade', 'case', 'cast', 'check', 'collate', 'column', 'commit', + 'conflict', 'constraint', 'create', 'cross', 'current_date', + 'current_time', 'current_timestamp', 'database', 'default', + 'deferrable', 'deferred', 'delete', 'desc', 'detach', 'distinct', + 'drop', 'each', 'else', 'end', 'escape', 'except', 'exclusive', + 'explain', 'false', 'fail', 'for', 'foreign', 'from', 'full', 'glob', + 'group', 'having', 'if', 'ignore', 'immediate', 'in', 'index', + 'initially', 'inner', 'insert', 'instead', 'intersect', 'into', 'is', + 'isnull', 'join', 'key', 'left', 'like', 'limit', 'match', 'natural', + 'not', 'notnull', 'null', 'of', 'offset', 'on', 'or', 'order', 'outer', + 'plan', 'pragma', 'primary', 'query', 'raise', 'references', + 'reindex', 'rename', 'replace', 'restrict', 'right', 'rollback', + 'row', 'select', 'set', 'table', 'temp', 'temporary', 'then', 'to', + 'transaction', 'trigger', 'true', 'union', 'unique', 'update', 'using', + 'vacuum', 'values', 'view', 'virtual', 'when', 'where', + ]) + +class SQLiteDialect(default.DefaultDialect): + name = 'sqlite' + supports_alter = False + supports_unicode_statements = True + supports_default_values = True + supports_empty_insert = False + supports_cast = True + statement_compiler = SQLiteCompiler + ddl_compiler = SQLiteDDLCompiler + type_compiler = SQLiteTypeCompiler + preparer = SQLiteIdentifierPreparer + + def table_names(self, connection, schema): + if schema is not None: + qschema = self.identifier_preparer.quote_identifier(schema) + master = '%s.sqlite_master' % qschema + s = ("SELECT name FROM %s " + "WHERE type='table' ORDER BY name") % (master,) + rs = connection.execute(s) + else: + try: + s = ("SELECT name FROM " + " (SELECT * FROM sqlite_master UNION ALL " + " SELECT * FROM sqlite_temp_master) " + "WHERE type='table' ORDER BY name") + rs = connection.execute(s) + except exc.DBAPIError: + raise + s = ("SELECT name FROM sqlite_master " + "WHERE type='table' ORDER BY name") + rs = connection.execute(s) + + return [row[0] for row in rs] + + def has_table(self, connection, table_name, schema=None): + quote = self.identifier_preparer.quote_identifier + if schema is not None: + pragma = "PRAGMA %s." % quote(schema) + else: + pragma = "PRAGMA " + qtable = quote(table_name) + cursor = connection.execute("%stable_info(%s)" % (pragma, qtable)) + row = cursor.fetchone() + + # consume remaining rows, to work around + # http://www.sqlite.org/cvstrac/tktview?tn=1884 + while cursor.fetchone() is not None: + pass + + return (row is not None) + + def reflecttable(self, connection, table, include_columns): + preparer = self.identifier_preparer + if table.schema is None: + pragma = "PRAGMA " + else: + pragma = "PRAGMA %s." % preparer.quote_identifier(table.schema) + qtable = preparer.format_table(table, False) + + c = connection.execute("%stable_info(%s)" % (pragma, qtable)) + found_table = False + while True: + row = c.fetchone() + if row is None: + break + + found_table = True + (name, type_, nullable, default, has_default, primary_key) = (row[1], row[2].upper(), not row[3], row[4], row[4] is not None, row[5]) + name = re.sub(r'^\"|\"$', '', name) + if include_columns and name not in include_columns: + continue + match = re.match(r'(\w+)(\(.*?\))?', type_) + if match: + coltype = match.group(1) + args = match.group(2) + else: + coltype = "VARCHAR" + args = '' + + try: + coltype = self.ischema_names[coltype] + except KeyError: + util.warn("Did not recognize type '%s' of column '%s'" % + (coltype, name)) + coltype = sqltypes.NullType + + if args is not None: + args = re.findall(r'(\d+)', args) + coltype = coltype(*[int(a) for a in args]) + + colargs = [] + if has_default: + colargs.append(DefaultClause(sql.text(default))) + table.append_column(schema.Column(name, coltype, primary_key = primary_key, nullable = nullable, *colargs)) + + if not found_table: + raise exc.NoSuchTableError(table.name) + + c = connection.execute("%sforeign_key_list(%s)" % (pragma, qtable)) + fks = {} + while True: + row = c.fetchone() + if row is None: + break + (constraint_name, tablename, localcol, remotecol) = (row[0], row[2], row[3], row[4]) + tablename = re.sub(r'^\"|\"$', '', tablename) + localcol = re.sub(r'^\"|\"$', '', localcol) + remotecol = re.sub(r'^\"|\"$', '', remotecol) + try: + fk = fks[constraint_name] + except KeyError: + fk = ([], []) + fks[constraint_name] = fk + + # look up the table based on the given table's engine, not 'self', + # since it could be a ProxyEngine + remotetable = schema.Table(tablename, table.metadata, autoload=True, autoload_with=connection) + constrained_column = table.c[localcol].name + refspec = ".".join([tablename, remotecol]) + if constrained_column not in fk[0]: + fk[0].append(constrained_column) + if refspec not in fk[1]: + fk[1].append(refspec) + for name, value in fks.iteritems(): + table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], link_to_name=True)) + # check for UNIQUE indexes + c = connection.execute("%sindex_list(%s)" % (pragma, qtable)) + unique_indexes = [] + while True: + row = c.fetchone() + if row is None: + break + if (row[2] == 1): + unique_indexes.append(row[1]) + # loop thru unique indexes for one that includes the primary key + for idx in unique_indexes: + c = connection.execute("%sindex_info(%s)" % (pragma, idx)) + cols = [] + while True: + row = c.fetchone() + if row is None: + break + cols.append(row[2]) + diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlite.py b/lib/sqlalchemy/dialects/sqlite/pysqlite.py new file mode 100644 index 0000000000..55ac8bd27b --- /dev/null +++ b/lib/sqlalchemy/dialects/sqlite/pysqlite.py @@ -0,0 +1,264 @@ +"""Support for the SQLite database via pysqlite. + +Driver +------ + +When using Python 2.5 and above, the built in ``sqlite3`` driver is +already installed and no additional installation is needed. Otherwise, +the ``pysqlite2`` driver needs to be present. This is the same driver as +``sqlite3``, just with a different name. + +The ``pysqlite2`` driver will be loaded first, and if not found, ``sqlite3`` +is loaded. This allows an explicitly installed pysqlite driver to take +precedence over the built in one. As with all dialects, a specific +DBAPI module may be provided to :func:`~sqlalchemy.create_engine()` to control +this explicitly:: + + from sqlite3 import dbapi2 as sqlite + e = create_engine('sqlite+pysqlite:///file.db', module=sqlite) + +Full documentation on pysqlite is available at: +``_ + +Connect Strings +--------------- + +The file specification for the SQLite database is taken as the "database" portion of +the URL. Note that the format of a url is:: + + driver://user:pass@host/database + +This means that the actual filename to be used starts with the characters to the +**right** of the third slash. So connecting to a relative filepath looks like:: + + # relative path + e = create_engine('sqlite:///path/to/database.db') + +An absolute path, which is denoted by starting with a slash, means you need **four** +slashes:: + + # absolute path + e = create_engine('sqlite:////path/to/database.db') + +To use a Windows path, regular drive specifications and backslashes can be used. +Double backslashes are probably needed:: + + # absolute path on Windows + e = create_engine('sqlite:///C:\\\\path\\\\to\\\\database.db') + +The sqlite ``:memory:`` identifier is the default if no filepath is present. Specify +``sqlite://`` and nothing else:: + + # in-memory database + e = create_engine('sqlite://') + +Threading Behavior +------------------ + +Pysqlite connections do not support being moved between threads, unless +the ``check_same_thread`` Pysqlite flag is set to ``False``. In addition, +when using an in-memory SQLite database, the full database exists only within +the scope of a single connection. It is reported that an in-memory +database does not support being shared between threads regardless of the +``check_same_thread`` flag - which means that a multithreaded +application **cannot** share data from a ``:memory:`` database across threads +unless access to the connection is limited to a single worker thread which communicates +through a queueing mechanism to concurrent threads. + +To provide a default which accomodates SQLite's default threading capabilities +somewhat reasonably, the SQLite dialect will specify that the :class:`~sqlalchemy.pool.SingletonThreadPool` +be used by default. This pool maintains a single SQLite connection per thread +that is held open up to a count of five concurrent threads. When more than five threads +are used, a cleanup mechanism will dispose of excess unused connections. + +Two optional pool implementations that may be appropriate for particular SQLite usage scenarios: + + * the :class:`sqlalchemy.pool.StaticPool` might be appropriate for a multithreaded + application using an in-memory database, assuming the threading issues inherent in + pysqlite are somehow accomodated for. This pool holds persistently onto a single connection + which is never closed, and is returned for all requests. + + * the :class:`sqlalchemy.pool.NullPool` might be appropriate for an application that + makes use of a file-based sqlite database. This pool disables any actual "pooling" + behavior, and simply opens and closes real connections corresonding to the :func:`connect()` + and :func:`close()` methods. SQLite can "connect" to a particular file with very high + efficiency, so this option may actually perform better without the extra overhead + of :class:`SingletonThreadPool`. NullPool will of course render a ``:memory:`` connection + useless since the database would be lost as soon as the connection is "returned" to the pool. + +Date and Time Types +------------------- + +SQLite does not have built-in DATE, TIME, or DATETIME types, and pysqlite does not provide +out of the box functionality for translating values between Python `datetime` objects +and a SQLite-supported format. SQLAlchemy's own :class:`~sqlalchemy.types.DateTime` +and related types provide date formatting and parsing functionality when SQlite is used. +The implementation classes are :class:`SLDateTime`, :class:`SLDate` and :class:`SLTime`. +These types represent dates and times as ISO formatted strings, which also nicely +support ordering. There's no reliance on typical "libc" internals for these functions +so historical dates are fully supported. + +Unicode +------- + +In contrast to SQLAlchemy's active handling of date and time types for pysqlite, pysqlite's +default behavior regarding Unicode is that all strings are returned as Python unicode objects +in all cases. So even if the :class:`~sqlalchemy.types.Unicode` type is +*not* used, you will still always receive unicode data back from a result set. It is +**strongly** recommended that you do use the :class:`~sqlalchemy.types.Unicode` type +to represent strings, since it will raise a warning if a non-unicode Python string is +passed from the user application. Mixing the usage of non-unicode objects with returned unicode objects can +quickly create confusion, particularly when using the ORM as internal data is not +always represented by an actual database result string. + +""" + +from sqlalchemy.dialects.sqlite.base import SLNumeric, SLFloat, SQLiteDialect, SLBoolean, SLDate, SLDateTime, SLTime +from sqlalchemy import schema, exc, pool +from sqlalchemy.engine import default +from sqlalchemy import types as sqltypes +from sqlalchemy import util +from types import NoneType + +class SLUnicodeMixin(object): + def bind_processor(self, dialect): + if self.convert_unicode or dialect.convert_unicode: + if self.assert_unicode is None: + assert_unicode = dialect.assert_unicode + else: + assert_unicode = self.assert_unicode + + if not assert_unicode: + return None + + def process(value): + if not isinstance(value, (unicode, NoneType)): + if assert_unicode == 'warn': + util.warn("Unicode type received non-unicode bind " + "param value %r" % value) + return value + else: + raise exc.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value) + else: + return value + return process + else: + return None + + def result_processor(self, dialect): + return None + +class SLText(SLUnicodeMixin, sqltypes.Text): + pass + +class SLString(SLUnicodeMixin, sqltypes.String): + pass + +class SLChar(SLUnicodeMixin, sqltypes.CHAR): + pass + + +colspecs = { + sqltypes.Boolean: SLBoolean, + sqltypes.CHAR: SLChar, + sqltypes.Date: SLDate, + sqltypes.DateTime: SLDateTime, + sqltypes.Float: SLFloat, + sqltypes.NCHAR: SLChar, + sqltypes.Numeric: SLNumeric, + sqltypes.String: SLString, + sqltypes.Text: SLText, + sqltypes.Time: SLTime, +} + +ischema_names = { + 'BLOB': sqltypes.Binary, + 'BOOL': SLBoolean, + 'BOOLEAN': SLBoolean, + 'CHAR': SLChar, + 'DATE': SLDate, + 'DATETIME': SLDateTime, + 'DECIMAL': SLNumeric, + 'FLOAT': SLNumeric, + 'INT': sqltypes.Integer, + 'INTEGER': sqltypes.Integer, + 'NUMERIC': SLNumeric, + 'REAL': SLNumeric, + 'SMALLINT': sqltypes.SmallInteger, + 'TEXT': SLText, + 'TIME': SLTime, + 'TIMESTAMP': SLDateTime, + 'VARCHAR': SLString, +} + + +class SQLite_pysqliteExecutionContext(default.DefaultExecutionContext): + def post_exec(self): + if self.isinsert and not self.executemany: + if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None: + self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:] + + +class SQLite_pysqlite(SQLiteDialect): + default_paramstyle = 'qmark' + poolclass = pool.SingletonThreadPool + execution_ctx_cls = SQLite_pysqliteExecutionContext + driver = 'pysqlite' + ischema_names = ischema_names + + def __init__(self, **kwargs): + SQLiteDialect.__init__(self, **kwargs) + def vers(num): + return tuple([int(x) for x in num.split('.')]) + if self.dbapi is not None: + sqlite_ver = self.dbapi.version_info + if sqlite_ver < (2, 1, '3'): + util.warn( + ("The installed version of pysqlite2 (%s) is out-dated " + "and will cause errors in some cases. Version 2.1.3 " + "or greater is recommended.") % + '.'.join([str(subver) for subver in sqlite_ver])) + if self.dbapi.sqlite_version_info < (3, 3, 8): + self.supports_default_values = False + self.supports_cast = (self.dbapi is None or vers(self.dbapi.sqlite_version) >= vers("3.2.3")) + + @classmethod + def dbapi(cls): + try: + from pysqlite2 import dbapi2 as sqlite + except ImportError, e: + try: + from sqlite3 import dbapi2 as sqlite #try the 2.5+ stdlib name. + except ImportError: + raise e + return sqlite + + def server_version_info(self, connection): + return self.dbapi.sqlite_version_info + + def create_connect_args(self, url): + if url.username or url.password or url.host or url.port: + raise exc.ArgumentError( + "Invalid SQLite URL: %s\n" + "Valid SQLite URL forms are:\n" + " sqlite:///:memory: (or, sqlite://)\n" + " sqlite:///relative/path/to/file.db\n" + " sqlite:////absolute/path/to/file.db" % (url,)) + filename = url.database or ':memory:' + + opts = url.query.copy() + util.coerce_kw_type(opts, 'timeout', float) + util.coerce_kw_type(opts, 'isolation_level', str) + util.coerce_kw_type(opts, 'detect_types', int) + util.coerce_kw_type(opts, 'check_same_thread', bool) + util.coerce_kw_type(opts, 'cached_statements', int) + + return ([filename], opts) + + def type_descriptor(self, typeobj): + return sqltypes.adapt_type(typeobj, colspecs) + + def is_disconnect(self, e): + return isinstance(e, self.dbapi.ProgrammingError) and "Cannot operate on a closed database." in str(e) + +dialect = SQLite_pysqlite diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py index b0f4465985..6def864e89 100644 --- a/lib/sqlalchemy/engine/__init__.py +++ b/lib/sqlalchemy/engine/__init__.py @@ -66,9 +66,9 @@ from sqlalchemy.engine.base import ( ResultProxy, RootTransaction, RowProxy, - SchemaIterator, Transaction, - TwoPhaseTransaction + TwoPhaseTransaction, + TypeCompiler ) from sqlalchemy.engine import strategies from sqlalchemy import util @@ -89,9 +89,9 @@ __all__ = ( 'ResultProxy', 'RootTransaction', 'RowProxy', - 'SchemaIterator', 'Transaction', 'TwoPhaseTransaction', + 'TypeCompiler', 'create_engine', 'engine_from_config', ) diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 39085c3596..f95da22731 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -34,7 +34,11 @@ class Dialect(object): All Dialects implement the following attributes: name - identifying name for the dialect (i.e. 'sqlite') + identifying name for the dialect from a DBAPI-neutral point of view + (i.e. 'sqlite') + + driver + identitfying name for the dialect's DBAPI positional True if the paramstyle for this Dialect is positional. @@ -51,21 +55,21 @@ class Dialect(object): type of encoding to use for unicode, usually defaults to 'utf-8'. - schemagenerator - a :class:`~sqlalchemy.schema.SchemaVisitor` class which generates - schemas. - - schemadropper - a :class:`~sqlalchemy.schema.SchemaVisitor` class which drops schemas. - defaultrunner a :class:`~sqlalchemy.schema.SchemaVisitor` class which executes defaults. statement_compiler - a :class:`~sqlalchemy.engine.base.Compiled` class used to compile SQL + a :class:`~Compiled` class used to compile SQL statements + ddl_compiler + a :class:`~Compiled` class used to compile DDL + statements + + execution_ctx_cls + a :class:`ExecutionContext` class used to handle statement execution + preparer a :class:`~sqlalchemy.sql.compiler.IdentifierPreparer` class used to quote identifiers. @@ -107,11 +111,6 @@ class Dialect(object): supports_default_values Indicates if the construct ``INSERT INTO tablename DEFAULT VALUES`` is supported - - description_encoding - type of encoding to use for unicode when working with metadata - descriptions. If set to ``None`` no encoding will be done. - This usually defaults to 'utf-8'. """ def create_connect_args(self, url): @@ -401,7 +400,7 @@ class ExecutionContext(object): class Compiled(object): - """Represent a compiled SQL expression. + """Represent a compiled SQL or DDL expression. The ``__str__`` method of the ``Compiled`` object should produce the actual text of the statement. ``Compiled`` objects are @@ -411,9 +410,10 @@ class Compiled(object): ``Compiled`` object be dependent on the actual values of those bind parameters, even though it may reference those values as defaults. + """ - def __init__(self, dialect, statement, column_keys=None, bind=None): + def __init__(self, dialect, statement, bind=None): """Construct a new ``Compiled`` object. dialect @@ -422,41 +422,40 @@ class Compiled(object): statement ``ClauseElement`` to be compiled. - column_keys - a list of column names to be compiled into an INSERT or UPDATE - statement. - bind Optional Engine or Connection to compile this statement against. """ self.dialect = dialect self.statement = statement - self.column_keys = column_keys self.bind = bind self.can_execute = statement.supports_execution def compile(self): """Produce the internal string representation of this element.""" - raise NotImplementedError() + self.string = self.process(self.statement) + + def process(self, obj, **kwargs): + return obj._compiler_dispatch(self, **kwargs) def __str__(self): - """Return the string text of the generated SQL statement.""" + """Return the string text of the generated SQL or DDL.""" - raise NotImplementedError() + return self.string or '' @util.deprecated('Deprecated. Use construct_params(). ' '(supports Unicode key names.)') def get_params(self, **params): return self.construct_params(params) - def construct_params(self, params): + def construct_params(self, params=None): """Return the bind params for this compiled object. `params` is a dict of string/object pairs whos values will override bind values compiled in to the statement. + """ raise NotImplementedError() @@ -473,6 +472,15 @@ class Compiled(object): return self.execute(*multiparams, **params).scalar() +class TypeCompiler(object): + """Produces DDL specification for TypeEngine objects.""" + + def __init__(self, dialect): + self.dialect = dialect + + def process(self, type_): + return type_._compiler_dispatch(self) + class Connectable(object): """Interface for an object which supports execution of SQL constructs. @@ -480,6 +488,9 @@ class Connectable(object): The two implementations of ``Connectable`` are :class:`Connection` and :class:`Engine`. + Connectable must also implement the 'dialect' member which references a + :class:`Dialect` instance. + """ def contextual_connect(self): @@ -813,9 +824,6 @@ class Connection(Connectable): return self.execute(object, *multiparams, **params).scalar() - def statement_compiler(self, statement, **kwargs): - return self.dialect.statement_compiler(self.dialect, statement, bind=self, **kwargs) - def execute(self, object, *multiparams, **params): """Executes and returns a ResultProxy.""" @@ -860,6 +868,13 @@ class Connection(Connectable): def _execute_default(self, default, multiparams, params): return self.engine.dialect.defaultrunner(self.__create_execution_context()).traverse_single(default) + def _execute_ddl(self, ddl, params, multiparams): + context = self.__create_execution_context( + compiled_ddl=ddl.compile(dialect=self.dialect), + parameters=None + ) + return self.__execute_context(context) + def _execute_clauseelement(self, elem, multiparams, params): params = self.__distill_params(multiparams, params) if params: @@ -868,7 +883,7 @@ class Connection(Connectable): keys = [] context = self.__create_execution_context( - compiled=elem.compile(dialect=self.dialect, column_keys=keys, inline=len(params) > 1), + compiled_sql=elem.compile(dialect=self.dialect, column_keys=keys, inline=len(params) > 1), parameters=params ) return self.__execute_context(context) @@ -877,7 +892,7 @@ class Connection(Connectable): """Execute a sql.Compiled object.""" context = self.__create_execution_context( - compiled=compiled, + compiled_sql=compiled, parameters=self.__distill_params(multiparams, params) ) return self.__execute_context(context) @@ -900,13 +915,6 @@ class Connection(Connectable): self._commit_impl() return context.get_result_proxy() - def _execute_ddl(self, ddl, params, multiparams): - if params: - schema_item, params = params[0], params[1:] - else: - schema_item = None - return ddl(None, schema_item, self, *params, **multiparams) - def _handle_dbapi_exception(self, e, statement, parameters, cursor, context): if getattr(self, '_reentrant_error', False): raise exc.DBAPIError.instance(None, None, e) @@ -966,7 +974,7 @@ class Connection(Connectable): expression.ClauseElement: _execute_clauseelement, Compiled: _execute_compiled, schema.SchemaItem: _execute_default, - schema.DDL: _execute_ddl, + schema.DDLElement: _execute_ddl, basestring: _execute_text } @@ -1126,12 +1134,16 @@ class Engine(Connectable): def create(self, entity, connection=None, **kwargs): """Create a table or index within this engine's database connection given a schema.Table object.""" - self._run_visitor(self.dialect.schemagenerator, entity, connection=connection, **kwargs) + from sqlalchemy.engine import ddl + + self._run_visitor(ddl.SchemaGenerator, entity, connection=connection, **kwargs) def drop(self, entity, connection=None, **kwargs): """Drop a table or index within this engine's database connection given a schema.Table object.""" - self._run_visitor(self.dialect.schemadropper, entity, connection=connection, **kwargs) + from sqlalchemy.engine import ddl + + self._run_visitor(ddl.SchemaDropper, entity, connection=connection, **kwargs) def _execute_default(self, default): connection = self.contextual_connect() @@ -1212,9 +1224,6 @@ class Engine(Connectable): connection = self.contextual_connect(close_with_result=True) return connection._execute_compiled(compiled, multiparams, params) - def statement_compiler(self, statement, **kwargs): - return self.dialect.statement_compiler(self.dialect, statement, bind=self, **kwargs) - def connect(self, **kwargs): """Return a newly allocated Connection object.""" @@ -1790,29 +1799,6 @@ class BufferedColumnResultProxy(ResultProxy): l.append(row) return l - -class SchemaIterator(schema.SchemaVisitor): - """A visitor that can gather text into a buffer and execute the contents of the buffer.""" - - def __init__(self, connection): - """Construct a new SchemaIterator.""" - - self.connection = connection - self.buffer = StringIO.StringIO() - - def append(self, s): - """Append content to the SchemaIterator's query buffer.""" - - self.buffer.write(s) - - def execute(self): - """Execute the contents of the SchemaIterator's buffer.""" - - try: - return self.connection.execute(self.buffer.getvalue()) - finally: - self.buffer.truncate(0) - class DefaultRunner(schema.SchemaVisitor): """A visitor which accepts ColumnDefault objects, produces the dialect-specific SQL corresponding to their execution, and diff --git a/lib/sqlalchemy/engine/ddl.py b/lib/sqlalchemy/engine/ddl.py new file mode 100644 index 0000000000..2fc09a20b8 --- /dev/null +++ b/lib/sqlalchemy/engine/ddl.py @@ -0,0 +1,126 @@ +"""routines to handle CREATE/DROP workflow.""" + +### TOOD: CREATE TABLE and DROP TABLE have been moved out so far. +### Index, ForeignKey, etc. still need to move. + +from sqlalchemy import engine, schema +from sqlalchemy.sql import util as sql_util + +class DDLBase(schema.SchemaVisitor): + def __init__(self, connection): + self.connection = connection + + def find_alterables(self, tables): + alterables = [] + class FindAlterables(schema.SchemaVisitor): + def visit_foreign_key_constraint(self, constraint): + if constraint.use_alter and constraint.table in tables: + alterables.append(constraint) + findalterables = FindAlterables() + for table in tables: + for c in table.constraints: + findalterables.traverse(c) + return alterables + + +class SchemaGenerator(DDLBase): + def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs): + super(SchemaGenerator, self).__init__(connection, **kwargs) + self.checkfirst = checkfirst + self.tables = tables and set(tables) or None + self.preparer = dialect.identifier_preparer + self.dialect = dialect + + def _can_create(self, table): + self.dialect.validate_identifier(table.name) + if table.schema: + self.dialect.validate_identifier(table.schema) + return not self.checkfirst or not self.dialect.has_table(self.connection, table.name, schema=table.schema) + + def visit_metadata(self, metadata): + if self.tables: + tables = self.tables + else: + tables = metadata.tables.values() + collection = [t for t in sql_util.sort_tables(tables) if self._can_create(t)] + for table in collection: + self.traverse_single(table) + if self.dialect.supports_alter: + for alterable in self.find_alterables(collection): + self.connection.execute(schema.AddForeignKey(alterable)) + + def visit_table(self, table): + for listener in table.ddl_listeners['before-create']: + listener('before-create', table, self.connection) + + for column in table.columns: + if column.default is not None: + self.traverse_single(column.default) + + self.connection.execute(schema.CreateTable(table)) + + if hasattr(table, 'indexes'): + for index in table.indexes: + self.traverse_single(index) + + for listener in table.ddl_listeners['after-create']: + listener('after-create', table, self.connection) + + def visit_sequence(self, sequence): + if self.dialect.supports_sequences: + if \ + (not self.dialect.sequences_optional or not sequence.optional) and \ + (not self.checkfirst or not self.dialect.has_sequence(self.connection, sequence.name)): + self.connection.execute(schema.CreateSequence(sequence)) + + def visit_index(self, index): + self.connection.execute(schema.CreateIndex(index)) + +class SchemaDropper(DDLBase): + def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs): + super(SchemaDropper, self).__init__(connection, **kwargs) + self.checkfirst = checkfirst + self.tables = tables + self.preparer = dialect.identifier_preparer + self.dialect = dialect + + def visit_metadata(self, metadata): + if self.tables: + tables = self.tables + else: + tables = metadata.tables.values() + collection = [t for t in reversed(sql_util.sort_tables(tables)) if self._can_drop(t)] + if self.dialect.supports_alter: + for alterable in self.find_alterables(collection): + self.connection.execute(schema.DropForeignKey(alterable)) + for table in collection: + self.traverse_single(table) + + def _can_drop(self, table): + self.dialect.validate_identifier(table.name) + if table.schema: + self.dialect.validate_identifier(table.schema) + return not self.checkfirst or self.dialect.has_table(self.connection, table.name, schema=table.schema) + + def visit_index(self, index): + self.connection.execute(schema.DropIndex(index)) + + def visit_table(self, table): + for listener in table.ddl_listeners['before-drop']: + listener('before-drop', table, self.connection) + + for column in table.columns: + if column.default is not None: + self.traverse_single(column.default) + + self.connection.execute(schema.DropTable(table)) + + for listener in table.ddl_listeners['after-drop']: + listener('after-drop', table, self.connection) + + def visit_sequence(self, sequence): + if self.dialect.supports_sequences: + if \ + (not self.dialect.sequences_optional or not sequence.optional) and \ + (not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name)): + self.connection.execute(schema.DropSequence(sequence)) diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 1ffc7bb04c..12b1661925 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -23,12 +23,14 @@ AUTOCOMMIT_REGEXP = re.compile(r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)', class DefaultDialect(base.Dialect): """Default implementation of Dialect""" - schemagenerator = compiler.SchemaGenerator - schemadropper = compiler.SchemaDropper - statement_compiler = compiler.DefaultCompiler + statement_compiler = compiler.SQLCompiler + ddl_compiler = compiler.DDLCompiler + type_compiler = compiler.GenericTypeCompiler preparer = compiler.IdentifierPreparer defaultrunner = base.DefaultRunner supports_alter = True + supports_sequences = False + sequences_optional = False supports_unicode_statements = False max_identifier_length = 9999 supports_sane_rowcount = True @@ -57,6 +59,8 @@ class DefaultDialect(base.Dialect): self.paramstyle = self.default_paramstyle self.positional = self.paramstyle in ('qmark', 'format', 'numeric') self.identifier_preparer = self.preparer(self) + self.type_compiler = self.type_compiler(self) + if label_length and label_length > self.max_identifier_length: raise exc.ArgumentError("Label length of %d is greater than this dialect's maximum identifier length of %d" % (label_length, self.max_identifier_length)) self.label_length = label_length @@ -67,8 +71,9 @@ class DefaultDialect(base.Dialect): the generic object which comes from the types module. Subclasses will usually use the ``adapt_type()`` method in the - types module to make this job easy.""" - + types module to make this job easy. + + """ if type(typeobj) is type: typeobj = typeobj() return typeobj @@ -126,13 +131,29 @@ class DefaultDialect(base.Dialect): class DefaultExecutionContext(base.ExecutionContext): - def __init__(self, dialect, connection, compiled=None, statement=None, parameters=None): + def __init__(self, dialect, connection, compiled_sql=None, compiled_ddl=None, statement=None, parameters=None): self.dialect = dialect self._connection = self.root_connection = connection - self.compiled = compiled self.engine = connection.engine - if compiled is not None: + if compiled_ddl is not None: + self.compiled = compiled = compiled_ddl + if not dialect.supports_unicode_statements: + self.statement = unicode(compiled).encode(self.dialect.encoding) + else: + self.statement = unicode(compiled) + self.isinsert = self.isupdate = self.executemany = False + self.should_autocommit = True + self.result_map = None + self.cursor = self.create_cursor() + self.compiled_parameters = [] + if self.dialect.positional: + self.parameters = [()] + else: + self.parameters = [{}] + elif compiled_sql is not None: + self.compiled = compiled = compiled_sql + # compiled clauseelement. process bind params, process table defaults, # track collections used by ResultProxy to target and process results @@ -172,8 +193,8 @@ class DefaultExecutionContext(base.ExecutionContext): self.parameters = self.__convert_compiled_params(self.compiled_parameters) elif statement is not None: - # plain text statement. - self.result_map = None + # plain text statement + self.result_map = self.compiled = None self.parameters = self.__encode_param_keys(parameters) self.executemany = len(parameters) > 1 if isinstance(statement, unicode) and not dialect.supports_unicode_statements: @@ -185,7 +206,7 @@ class DefaultExecutionContext(base.ExecutionContext): self.should_autocommit = self.should_autocommit_text(statement) else: # no statement. used for standalone ColumnDefault execution. - self.statement = None + self.statement = self.compiled = None self.isinsert = self.isupdate = self.executemany = self.should_autocommit = False self.cursor = self.create_cursor() diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py index fa608df65e..b1261da0a8 100644 --- a/lib/sqlalchemy/engine/strategies.py +++ b/lib/sqlalchemy/engine/strategies.py @@ -201,11 +201,14 @@ class MockEngineStrategy(EngineStrategy): def create(self, entity, **kwargs): kwargs['checkfirst'] = False - self.dialect.schemagenerator(self.dialect, self, **kwargs).traverse(entity) + from sqlalchemy.engine import ddl + + ddl.SchemaGenerator(self.dialect, self, **kwargs).traverse(entity) def drop(self, entity, **kwargs): kwargs['checkfirst'] = False - self.dialect.schemadropper(self.dialect, self, **kwargs).traverse(entity) + from sqlalchemy.engine import ddl + ddl.SchemaDropper(self.dialect, self, **kwargs).traverse(entity) def execute(self, object, *multiparams, **params): raise NotImplementedError() diff --git a/lib/sqlalchemy/engine/url.py b/lib/sqlalchemy/engine/url.py index 5c8e68ce45..8000cbc6c3 100644 --- a/lib/sqlalchemy/engine/url.py +++ b/lib/sqlalchemy/engine/url.py @@ -88,7 +88,15 @@ class URL(object): """Return the SQLAlchemy database dialect class corresponding to this URL's driver name.""" try: - module = getattr(__import__('sqlalchemy.databases.%s' % self.drivername).databases, self.drivername) + if '+' in self.drivername: + dialect, driver = self.drivername.split('+') + else: + dialect, driver = self.drivername, 'base' + + module = __import__('sqlalchemy.dialects.%s.%s' % (dialect, driver)).dialects + module = getattr(module, dialect) + module = getattr(module, driver) + return module.dialect except ImportError: if sys.exc_info()[2].tb_next is None: @@ -140,7 +148,7 @@ def make_url(name_or_url): def _parse_rfc1738_args(name): pattern = re.compile(r''' - (?P\w+):// + (?P[\w\+]+):// (?: (?P[^:/]*) (?::(?P[^/]*))? diff --git a/lib/sqlalchemy/ext/declarative.py b/lib/sqlalchemy/ext/declarative.py index 3b4880403a..a5cb6e9d2d 100644 --- a/lib/sqlalchemy/ext/declarative.py +++ b/lib/sqlalchemy/ext/declarative.py @@ -478,9 +478,7 @@ def _as_declarative(cls, classname, dict_): *(tuple(cols) + tuple(args)), **table_kw) else: table = cls.__table__ - if cols: - raise exceptions.ArgumentError("Can't add additional columns when specifying __table__") - + mapper_args = getattr(cls, '__mapper_args__', {}) if 'inherits' not in mapper_args: inherits = cls.__mro__[1] @@ -532,7 +530,7 @@ def _as_declarative(cls, classname, dict_): mapper_args['exclude_properties'] = exclude_properties = \ set([c.key for c in inherited_table.c if c not in inherited_mapper._columntoproperty]) exclude_properties.difference_update([c.key for c in cols]) - + cls.__mapper__ = mapper_cls(cls, table, properties=our_stuff, **mapper_args) class DeclarativeMeta(type): diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 6bcc89b3c2..04fc9d0ef1 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1305,15 +1305,13 @@ class Mapper(object): if col in pks: if history.deleted: params[col._label] = prop.get_col_value(col, history.deleted[0]) - hasdata = True else: # row switch logic can reach us here # remove the pk from the update params so the update doesn't # attempt to include the pk in the update statement del params[col.key] params[col._label] = prop.get_col_value(col, history.added[0]) - else: - hasdata = True + hasdata = True elif col in pks: params[col._label] = mapper._get_state_attr_by_column(state, col) if hasdata: diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index a4561d443d..0211b9707a 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -32,24 +32,12 @@ class ColumnProperty(StrategizedProperty): """Describes an object attribute that corresponds to a table column.""" def __init__(self, *columns, **kwargs): - """Construct a ColumnProperty. - - :param \*columns: The list of `columns` describes a single - object property. If there are multiple tables joined - together for the mapper, this list represents the equivalent - column as it appears across each table. - - :param group: - - :param deferred: - - :param comparator_factory: - - :param descriptor: - - :param extension: - + """The list of `columns` describes a single object + property. If there are multiple tables joined together for the + mapper, this list represents the equivalent column as it + appears across each table. """ + self.columns = [expression._labeled(c) for c in columns] self.group = kwargs.pop('group', None) self.deferred = kwargs.pop('deferred', False) @@ -57,11 +45,6 @@ class ColumnProperty(StrategizedProperty): self.comparator_factory = kwargs.pop('comparator_factory', self.__class__.Comparator) self.descriptor = kwargs.pop('descriptor', None) self.extension = kwargs.pop('extension', None) - if kwargs: - raise TypeError( - "%s received unexpected keyword argument(s): %s" % ( - self.__class__.__name__, ', '.join(sorted(kwargs.keys())))) - util.set_creation_order(self) if self.no_instrument: self.strategy_class = strategies.UninstrumentedColumnLoader @@ -1153,4 +1136,4 @@ mapper.ColumnProperty = ColumnProperty mapper.SynonymProperty = SynonymProperty mapper.ComparableProperty = ComparableProperty mapper.RelationProperty = RelationProperty -mapper.ConcreteInheritedProperty = ConcreteInheritedProperty +mapper.ConcreteInheritedProperty = ConcreteInheritedProperty \ No newline at end of file diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index d454bc7cff..c9dc152b98 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -609,9 +609,7 @@ class Column(SchemaItem, expression.ColumnClause): "Unknown arguments passed to Column: " + repr(kwargs.keys())) def __str__(self): - if self.name is None: - return "(no name)" - elif self.table is not None: + if self.table is not None: if self.table.named_with_column: return (self.table.description + "." + self.description) else: @@ -619,9 +617,9 @@ class Column(SchemaItem, expression.ColumnClause): else: return self.description - @property def bind(self): return self.table.bind + bind = property(bind) def references(self, column): """Return True if this Column references the given column via foreign key.""" @@ -1884,7 +1882,30 @@ class SchemaVisitor(visitors.ClauseVisitor): __traverse_options__ = {'schema_visitor':True} -class DDL(object): +class DDLElement(expression.ClauseElement): + """Base class for DDL expression constructs.""" + + supports_execution = True + _autocommit = True + + def bind(self): + if self._bind: + return self._bind + def _set_bind(self, bind): + self._bind = bind + bind = property(bind, _set_bind) + + def _generate(self): + s = self.__class__.__new__(self.__class__) + s.__dict__ = self.__dict__.copy() + return s + + def _compiler(self, dialect, **kw): + """Return a compiler appropriate for this ClauseElement, given a Dialect.""" + + return dialect.ddl_compiler(dialect, self, **kw) + +class DDL(DDLElement): """A literal DDL statement. Specifies literal SQL DDL to be executed by the database. DDL objects can @@ -1905,6 +1926,8 @@ class DDL(object): connection.execute(drop_spow) """ + __visit_name__ = "ddl" + def __init__(self, statement, on=None, context=None, bind=None): """Create a DDL statement. @@ -1964,6 +1987,7 @@ class DDL(object): self.on = on self.context = context or {} self._bind = bind + self.schema_item = None def execute(self, bind=None, schema_item=None): """Execute this DDL immediately. @@ -1985,10 +2009,9 @@ class DDL(object): if bind is None: bind = _bind_or_error(self) - # no SQL bind params are supported + if self._should_execute(None, schema_item, bind): - executable = expression.text(self._expand(schema_item, bind)) - return bind.execute(executable) + return bind.execute(self.against(schema_item)) else: bind.engine.logger.info("DDL execution skipped, criteria not met.") @@ -2040,39 +2063,18 @@ class DDL(object): (', '.join(schema_item.ddl_events), event)) schema_item.ddl_listeners[event].append(self) return self - - def bind(self): - """An Engine or Connection to which this DDL is bound. - - This property may be assigned an ``Engine`` or ``Connection``, or - assigned a string or URL to automatically create a basic ``Engine`` - for this bind with ``create_engine()``. - """ - return self._bind - - def _bind_to(self, bind): - """Bind this MetaData to an Engine, Connection, string or URL.""" - - global URL - if URL is None: - from sqlalchemy.engine.url import URL - - if isinstance(bind, (basestring, URL)): - from sqlalchemy import create_engine - self._bind = create_engine(bind) - else: - self._bind = bind - bind = property(bind, _bind_to) - + + @expression._generative + def against(self, schema_item): + """Return a copy of this DDL against a specific schema item.""" + + self.schema_item = schema_item + def __call__(self, event, schema_item, bind): """Execute the DDL as a ddl_listener.""" if self._should_execute(event, schema_item, bind): - statement = expression.text(self._expand(schema_item, bind)) - return bind.execute(statement) - - def _expand(self, schema_item, bind): - return self.statement % self._prepare_context(schema_item, bind) + return bind.execute(self.against(schema_item)) def _should_execute(self, event, schema_item, bind): if self.on is None: @@ -2082,25 +2084,6 @@ class DDL(object): else: return self.on(event, schema_item, bind) - def _prepare_context(self, schema_item, bind): - # table events can substitute table and schema name - if isinstance(schema_item, Table): - context = self.context.copy() - - preparer = bind.dialect.identifier_preparer - path = preparer.format_table_seq(schema_item) - if len(path) == 1: - table, schema = path[0], '' - else: - table, schema = path[-1], path[0] - - context.setdefault('table', table) - context.setdefault('schema', schema) - context.setdefault('fullname', preparer.format_table(schema_item)) - return context - else: - return self.context - def __repr__(self): return '<%s@%s; %s>' % ( type(self).__name__, id(self), @@ -2110,11 +2093,76 @@ class DDL(object): if getattr(self, key)])) def _to_schema_column(element): - if hasattr(element, '__clause_element__'): - element = element.__clause_element__() - if not isinstance(element, Column): - raise exc.ArgumentError("schema.Column object expected") - return element + if hasattr(element, '__clause_element__'): + element = element.__clause_element__() + if not isinstance(element, Column): + raise exc.ArgumentError("schema.Column object expected") + return element + +class _CreateDropBase(DDLElement): + """Base class for DDL constucts that represent CREATE and DROP or equivalents. + + The common theme of _CreateDropBase is a single + ``element`` attribute which refers to the element + to be created or dropped. + + """ + + def __init__(self, element): + self.element = element + + def bind(self): + if self._bind: + return self._bind + if self.element: + e = self.element.bind + if e: + return e + return None + + def _set_bind(self, bind): + self._bind = bind + bind = property(bind, _set_bind) + +class CreateTable(_CreateDropBase): + """Represent a CREATE TABLE statement.""" + + __visit_name__ = "create_table" + +class DropTable(_CreateDropBase): + """Represent a DROP TABLE statement.""" + + __visit_name__ = "drop_table" + +class AddForeignKey(_CreateDropBase): + """Represent an ALTER TABLE ADD FOREIGN KEY statement.""" + + __visit_name__ = "add_foreignkey" + +class DropForeignKey(_CreateDropBase): + """Represent an ALTER TABLE DROP FOREIGN KEY statement.""" + + __visit_name__ = "drop_foreignkey" + +class CreateSequence(_CreateDropBase): + """Represent a CREATE SEQUENCE statement.""" + + __visit_name__ = "create_sequence" + +class DropSequence(_CreateDropBase): + """Represent a DROP SEQUENCE statement.""" + + __visit_name__ = "drop_sequence" + +class CreateIndex(_CreateDropBase): + """Represent a CREATE INDEX statement.""" + + __visit_name__ = "create_index" + +class DropIndex(_CreateDropBase): + """Represent a DROP INDEX statement.""" + + __visit_name__ = "drop_index" def _bind_or_error(schemaitem): bind = schemaitem.bind diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index d5c85d71d6..3e61b459b4 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -123,7 +123,7 @@ class _CompileLabel(visitors.Visitable): def quote(self): return self.element.quote -class DefaultCompiler(engine.Compiled): +class SQLCompiler(engine.Compiled): """Default implementation of Compiled. Compiles ClauseElements into SQL strings. Uses a similar visit @@ -134,8 +134,9 @@ class DefaultCompiler(engine.Compiled): operators = OPERATORS functions = FUNCTIONS - # if we are insert/update/delete. - # set to true when we visit an INSERT, UPDATE or DELETE + # class-level defaults which can be set at the instance + # level to define if this Compiled instance represents + # INSERT/UPDATE/DELETE isdelete = isinsert = isupdate = False def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs): @@ -152,7 +153,9 @@ class DefaultCompiler(engine.Compiled): statement. """ - engine.Compiled.__init__(self, dialect, statement, column_keys, **kwargs) + engine.Compiled.__init__(self, dialect, statement, **kwargs) + + self.column_keys = column_keys # compile INSERT/UPDATE defaults/sequences inlined (no pre-execute) self.inline = inline or getattr(statement, 'inline', False) @@ -192,12 +195,6 @@ class DefaultCompiler(engine.Compiled): # or dialect.max_identifier_length self.truncated_names = {} - def compile(self): - self.string = self.process(self.statement) - - def process(self, obj, **kwargs): - return obj._compiler_dispatch(self, **kwargs) - def is_subquery(self): return len(self.stack) > 1 @@ -292,7 +289,7 @@ class DefaultCompiler(engine.Compiled): return index.name def visit_typeclause(self, typeclause, **kwargs): - return typeclause.type.dialect_impl(self.dialect).get_col_spec() + return self.dialect.type_compiler.process(typeclause.type) def post_process_text(self, text): return text @@ -739,110 +736,117 @@ class DefaultCompiler(engine.Compiled): def visit_release_savepoint(self, savepoint_stmt): return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) - def __str__(self): - return self.string or '' - -class DDLBase(engine.SchemaIterator): - def find_alterables(self, tables): - alterables = [] - class FindAlterables(schema.SchemaVisitor): - def visit_foreign_key_constraint(self, constraint): - if constraint.use_alter and constraint.table in tables: - alterables.append(constraint) - findalterables = FindAlterables() - for table in tables: - for c in table.constraints: - findalterables.traverse(c) - return alterables - def _validate_identifier(self, ident, truncate): - if truncate: - if len(ident) > self.dialect.max_identifier_length: - counter = getattr(self, 'counter', 0) - self.counter = counter + 1 - return ident[0:self.dialect.max_identifier_length - 6] + "_" + hex(self.counter)[2:] +class DDLCompiler(engine.Compiled): + @property + def preparer(self): + return self.dialect.identifier_preparer + + def visit_ddl(self, ddl, **kwargs): + # table events can substitute table and schema name + context = ddl.context + if isinstance(ddl.schema_item, schema.Table): + context = context.copy() + + preparer = self.dialect.identifier_preparer + path = preparer.format_table_seq(ddl.schema_item) + if len(path) == 1: + table, sch = path[0], '' else: - return ident - else: - self.dialect.validate_identifier(ident) - return ident - - -class SchemaGenerator(DDLBase): - def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs): - super(SchemaGenerator, self).__init__(connection, **kwargs) - self.checkfirst = checkfirst - self.tables = tables and set(tables) or None - self.preparer = dialect.identifier_preparer - self.dialect = dialect - - def get_column_specification(self, column, first_pk=False): - raise NotImplementedError() - - def _can_create(self, table): - self.dialect.validate_identifier(table.name) - if table.schema: - self.dialect.validate_identifier(table.schema) - return not self.checkfirst or not self.dialect.has_table(self.connection, table.name, schema=table.schema) + table, sch = path[-1], path[0] - def visit_metadata(self, metadata): - if self.tables: - tables = self.tables - else: - tables = metadata.tables.values() - collection = [t for t in sql_util.sort_tables(tables) if self._can_create(t)] - for table in collection: - self.traverse_single(table) - if self.dialect.supports_alter: - for alterable in self.find_alterables(collection): - self.add_foreignkey(alterable) - - def visit_table(self, table): - for listener in table.ddl_listeners['before-create']: - listener('before-create', table, self.connection) + context.setdefault('table', table) + context.setdefault('schema', sch) + context.setdefault('fullname', preparer.format_table(ddl.schema_item)) + + return ddl.statement % context - for column in table.columns: - if column.default is not None: - self.traverse_single(column.default) + def visit_create_table(self, create): + table = create.element + preparer = self.dialect.identifier_preparer - self.append("\n" + " ".join(['CREATE'] + - table._prefixes + + text = "\n" + " ".join(['CREATE'] + \ + table._prefixes + \ ['TABLE', - self.preparer.format_table(table), - "("])) + preparer.format_table(table), + "("]) separator = "\n" # if only one primary key, specify it along with the column first_pk = False for column in table.columns: - self.append(separator) + text += separator separator = ", \n" - self.append("\t" + self.get_column_specification(column, first_pk=column.primary_key and not first_pk)) + text += "\t" + self.get_column_specification(column, first_pk=column.primary_key and not first_pk) if column.primary_key: first_pk = True for constraint in column.constraints: - self.traverse_single(constraint) + text += self.process(constraint) # On some DB order is significant: visit PK first, then the # other constraints (engine.ReflectionTest.testbasic failed on FB2) if table.primary_key: - self.traverse_single(table.primary_key) + text += self.process(table.primary_key) + for constraint in [c for c in table.constraints if c is not table.primary_key]: - self.traverse_single(constraint) + text += self.process(constraint) - self.append("\n)%s\n\n" % self.post_create_table(table)) - self.execute() + text += "\n)%s\n\n" % self.post_create_table(table) + return text + + def visit_drop_table(self, drop): + return "\nDROP TABLE " + self.preparer.format_table(drop.element) + + def visit_add_foreignkey(self, add): + return "ALTER TABLE %s ADD " % self.preparer.format_table(add.element.table) + \ + self.define_foreign_key(add.element) - if hasattr(table, 'indexes'): - for index in table.indexes: - self.traverse_single(index) + def visit_drop_foreignkey(self, drop): + return "ALTER TABLE %s DROP CONSTRAINT %s" % ( + self.preparer.format_table(drop.element.table), + self.preparer.format_constraint(drop.element)) - for listener in table.ddl_listeners['after-create']: - listener('after-create', table, self.connection) + def visit_create_index(self, create): + index = create.element + preparer = self.preparer + text = "CREATE " + if index.unique: + text += "UNIQUE " + text += "INDEX %s ON %s (%s)" \ + % (preparer.quote(self._validate_identifier(index.name, True), index.quote), + preparer.format_table(index.table), + ', '.join(preparer.quote(c.name, c.quote) + for c in index.columns)) + return text + + def visit_drop_index(self, drop): + index = drop.element + return "\nDROP INDEX " + self.preparer.quote(self._validate_identifier(index.name, False), index.quote) + + def get_column_specification(self, column, first_pk=False): + raise NotImplementedError() def post_create_table(self, table): return '' + def _compile(self, tocompile, parameters): + """compile the given string/parameters using this SchemaGenerator's dialect.""" + compiler = self.dialect.statement_compiler(self.dialect, tocompile, parameters) + compiler.compile() + return compiler + + def _validate_identifier(self, ident, truncate): + if truncate: + if len(ident) > self.dialect.max_identifier_length: + counter = getattr(self, 'counter', 0) + self.counter = counter + 1 + return ident[0:self.dialect.max_identifier_length - 6] + "_" + hex(self.counter)[2:] + else: + return ident + else: + self.dialect.validate_identifier(ident) + return ident + def get_column_default_string(self, column): if isinstance(column.server_default, schema.DefaultClause): if isinstance(column.server_default.arg, basestring): @@ -852,149 +856,174 @@ class SchemaGenerator(DDLBase): else: return None - def _compile(self, tocompile, parameters): - """compile the given string/parameters using this SchemaGenerator's dialect.""" - compiler = self.dialect.statement_compiler(self.dialect, tocompile, parameters) - compiler.compile() - return compiler - def visit_check_constraint(self, constraint): - self.append(", \n\t") + text = ", \n\t" if constraint.name is not None: - self.append("CONSTRAINT %s " % - self.preparer.format_constraint(constraint)) - self.append(" CHECK (%s)" % constraint.sqltext) - self.define_constraint_deferrability(constraint) + text += "CONSTRAINT %s " % \ + self.preparer.format_constraint(constraint) + text += " CHECK (%s)" % constraint.sqltext + text += self.define_constraint_deferrability(constraint) + return text def visit_column_check_constraint(self, constraint): - self.append(" CHECK (%s)" % constraint.sqltext) - self.define_constraint_deferrability(constraint) + text = " CHECK (%s)" % constraint.sqltext + text += self.define_constraint_deferrability(constraint) + return text def visit_primary_key_constraint(self, constraint): if len(constraint) == 0: - return - self.append(", \n\t") + return '' + text = ", \n\t" if constraint.name is not None: - self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint)) - self.append("PRIMARY KEY ") - self.append("(%s)" % ', '.join(self.preparer.quote(c.name, c.quote) - for c in constraint)) - self.define_constraint_deferrability(constraint) + text += "CONSTRAINT %s " % self.preparer.format_constraint(constraint) + text += "PRIMARY KEY " + text += "(%s)" % ', '.join(self.preparer.quote(c.name, c.quote) + for c in constraint) + text += self.define_constraint_deferrability(constraint) + return text def visit_foreign_key_constraint(self, constraint): if constraint.use_alter and self.dialect.supports_alter: - return - self.append(", \n\t ") - self.define_foreign_key(constraint) - - def add_foreignkey(self, constraint): - self.append("ALTER TABLE %s ADD " % self.preparer.format_table(constraint.table)) - self.define_foreign_key(constraint) - self.execute() + return '' + + return ", \n\t " + self.define_foreign_key(constraint) def define_foreign_key(self, constraint): - preparer = self.preparer + preparer = self.dialect.identifier_preparer + text = "" if constraint.name is not None: - self.append("CONSTRAINT %s " % - preparer.format_constraint(constraint)) + text += "CONSTRAINT %s " % \ + preparer.format_constraint(constraint) table = list(constraint.elements)[0].column.table - self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % ( + text += "FOREIGN KEY(%s) REFERENCES %s (%s)" % ( ', '.join(preparer.quote(f.parent.name, f.parent.quote) for f in constraint.elements), preparer.format_table(table), ', '.join(preparer.quote(f.column.name, f.column.quote) for f in constraint.elements) - )) + ) if constraint.ondelete is not None: - self.append(" ON DELETE %s" % constraint.ondelete) + text += " ON DELETE %s" % constraint.ondelete if constraint.onupdate is not None: - self.append(" ON UPDATE %s" % constraint.onupdate) - self.define_constraint_deferrability(constraint) + text += " ON UPDATE %s" % constraint.onupdate + text += self.define_constraint_deferrability(constraint) + return text def visit_unique_constraint(self, constraint): - self.append(", \n\t") + text = ", \n\t" if constraint.name is not None: - self.append("CONSTRAINT %s " % - self.preparer.format_constraint(constraint)) - self.append(" UNIQUE (%s)" % (', '.join(self.preparer.quote(c.name, c.quote) for c in constraint))) - self.define_constraint_deferrability(constraint) + text += "CONSTRAINT %s " % self.preparer.format_constraint(constraint) + text += " UNIQUE (%s)" % (', '.join(self.preparer.quote(c.name, c.quote) for c in constraint)) + text += self.define_constraint_deferrability(constraint) + return text def define_constraint_deferrability(self, constraint): + text = "" if constraint.deferrable is not None: if constraint.deferrable: - self.append(" DEFERRABLE") + text += " DEFERRABLE" else: - self.append(" NOT DEFERRABLE") + text += " NOT DEFERRABLE" if constraint.initially is not None: - self.append(" INITIALLY %s" % constraint.initially) + text += " INITIALLY %s" % constraint.initially + return text + + +# PLACEHOLDERS to get non-converted dialects to compile +class SchemaGenerator(object): + pass + +class SchemaDropper(object): + pass + + +class GenericTypeCompiler(engine.TypeCompiler): + def visit_CHAR(self, type_): + return "CHAR" + (type_.length and "(%d)" % type_.length or "") - def visit_column(self, column): - pass + def visit_NCHAR(self, type_): + return "NCHAR" + (type_.length and "(%d)" % type_.length or "") + + def visit_FLOAT(self, type_): + return "FLOAT" - def visit_index(self, index): - preparer = self.preparer - self.append("CREATE ") - if index.unique: - self.append("UNIQUE ") - self.append("INDEX %s ON %s (%s)" \ - % (preparer.quote(self._validate_identifier(index.name, True), index.quote), - preparer.format_table(index.table), - ', '.join(preparer.quote(c.name, c.quote) - for c in index.columns))) - self.execute() + def visit_NUMERIC(self, type_): + if type_.precision is None: + return "NUMERIC" + else: + return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': type_.precision, 'scale' : type_.scale} + def visit_DECIMAL(self, type_): + return "DECIMAL" + + def visit_INTEGER(self, type_): + return "INTEGER" -class SchemaDropper(DDLBase): - def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs): - super(SchemaDropper, self).__init__(connection, **kwargs) - self.checkfirst = checkfirst - self.tables = tables - self.preparer = dialect.identifier_preparer - self.dialect = dialect + def visit_SMALLINT(self, type_): + return "SMALLINT" - def visit_metadata(self, metadata): - if self.tables: - tables = self.tables - else: - tables = metadata.tables.values() - collection = [t for t in reversed(sql_util.sort_tables(tables)) if self._can_drop(t)] - if self.dialect.supports_alter: - for alterable in self.find_alterables(collection): - self.drop_foreignkey(alterable) - for table in collection: - self.traverse_single(table) - - def _can_drop(self, table): - self.dialect.validate_identifier(table.name) - if table.schema: - self.dialect.validate_identifier(table.schema) - return not self.checkfirst or self.dialect.has_table(self.connection, table.name, schema=table.schema) - - def visit_index(self, index): - self.append("\nDROP INDEX " + self.preparer.quote(self._validate_identifier(index.name, False), index.quote)) - self.execute() - - def drop_foreignkey(self, constraint): - self.append("ALTER TABLE %s DROP CONSTRAINT %s" % ( - self.preparer.format_table(constraint.table), - self.preparer.format_constraint(constraint))) - self.execute() - - def visit_table(self, table): - for listener in table.ddl_listeners['before-drop']: - listener('before-drop', table, self.connection) + def visit_TIMESTAMP(self, type_): + return 'TIMESTAMP' - for column in table.columns: - if column.default is not None: - self.traverse_single(column.default) + def visit_DATETIME(self, type_): + return "DATETIME" - self.append("\nDROP TABLE " + self.preparer.format_table(table)) - self.execute() + def visit_DATE(self, type_): + return "DATE" - for listener in table.ddl_listeners['after-drop']: - listener('after-drop', table, self.connection) + def visit_TIME(self, type_): + return "TIME" + def visit_CLOB(self, type_): + return "CLOB" + def visit_VARCHAR(self, type_): + return "VARCHAR" + (type_.length and "(%d)" % type_.length or "") + + def visit_BLOB(self, type_): + return "BLOB" + + def visit_BINARY(self, type_): + return "BINARY" + + def visit_BOOLEAN(self, type_): + return "BOOLEAN" + + def visit_TEXT(self, type_): + return "TEXT" + + def visit_binary(self, type_): + return self.visit_BINARY(type_) + def visit_boolean(self, type_): + return self.visit_BOOLEAN(type_) + def visit_time(self, type_): + return self.visit_TIME(type_) + def visit_datetime(self, type_): + return self.visit_DATETIME(type_) + def visit_date(self, type_): + return self.visit_DATE(type_) + def visit_small_integer(self, type_): + return self.visit_SMALLINT(type_) + def visit_integer(self, type_): + return self.visit_INTEGER(type_) + def visit_float(self, type_): + return self.visit_FLOAT(type_) + def visit_numeric(self, type_): + return self.visit_NUMERIC(type_) + def visit_string(self, type_): + return self.visit_VARCHAR(type_) + def visit_text(self, type_): + return self.visit_TEXT(type_) + + def visit_null(self, type_): + raise NotImplementedError("Can't generate DDL for the null type") + + def visit_type_decorator(self, type_): + return self.process(type_.dialect_impl(self.dialect).impl) + + def visit_user_defined(self, type_): + return type_.get_col_spec() + class IdentifierPreparer(object): """Handle quoting and case-folding of identifiers based on options.""" diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index f527c6351f..6be867dbf5 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -32,10 +32,9 @@ from operator import attrgetter from sqlalchemy import util, exc from sqlalchemy.sql import operators from sqlalchemy.sql.visitors import Visitable, cloned_traverse -from sqlalchemy import types as sqltypes import operator -functions, schema, sql_util = None, None, None +functions, schema, sql_util, sqltypes = None, None, None, None DefaultDialect, ClauseAdapter, Annotated = None, None, None __all__ = [ @@ -974,7 +973,8 @@ class ClauseElement(Visitable): _annotations = {} supports_execution = False _from_objects = [] - + _bind = None + def _clone(self): """Create a shallow copy of this ClauseElement. @@ -1106,11 +1106,9 @@ class ClauseElement(Visitable): def bind(self): """Returns the Engine or Connection to which this ClauseElement is bound, or None if none found.""" - try: - if self._bind is not None: - return self._bind - except AttributeError: - pass + if self._bind is not None: + return self._bind + for f in _from_objects(self): if f is self: continue @@ -1139,7 +1137,7 @@ class ClauseElement(Visitable): return self.execute(*multiparams, **params).scalar() - def compile(self, bind=None, column_keys=None, compiler=None, dialect=None, inline=False): + def compile(self, bind=None, dialect=None, **kw): """Compile this SQL expression. The return value is a :class:`~sqlalchemy.engine.Compiled` object. @@ -1154,52 +1152,57 @@ class ClauseElement(Visitable): takes precedence over this ``ClauseElement``'s bound engine, if any. - column_keys - Used for INSERT and UPDATE statements, a list of - column names which should be present in the VALUES clause - of the compiled statement. If ``None``, all columns - from the target table object are rendered. - - compiler - A ``Compiled`` instance which will be used to compile - this expression. This argument takes precedence - over the `bind` and `dialect` arguments as well as - this ``ClauseElement``'s bound engine, if - any. - dialect A ``Dialect`` instance frmo which a ``Compiled`` will be acquired. This argument takes precedence over the `bind` argument as well as this ``ClauseElement``'s bound engine, if any. - inline - Used for INSERT statements, for a dialect which does - not support inline retrieval of newly generated - primary key columns, will force the expression used - to create the new primary key value to be rendered - inline within the INSERT statement's VALUES clause. - This typically refers to Sequence execution but - may also refer to any server-side default generation - function associated with a primary key `Column`. + \**kw + + Keyword arguments are passed along to the compiler, + which can affect the string produced. + + Keywords for a statement compiler are: + + column_keys + Used for INSERT and UPDATE statements, a list of + column names which should be present in the VALUES clause + of the compiled statement. If ``None``, all columns + from the target table object are rendered. + + inline + Used for INSERT statements, for a dialect which does + not support inline retrieval of newly generated + primary key columns, will force the expression used + to create the new primary key value to be rendered + inline within the INSERT statement's VALUES clause. + This typically refers to Sequence execution but + may also refer to any server-side default generation + function associated with a primary key `Column`. """ - if compiler is None: - if dialect is not None: - compiler = dialect.statement_compiler(dialect, self, column_keys=column_keys, inline=inline) - elif bind is not None: - compiler = bind.statement_compiler(self, column_keys=column_keys, inline=inline) - elif self.bind is not None: - compiler = self.bind.statement_compiler(self, column_keys=column_keys, inline=inline) + + if not dialect: + if bind: + dialect = bind.dialect + elif self.bind: + dialect = self.bind.dialect + bind = self.bind else: global DefaultDialect if DefaultDialect is None: from sqlalchemy.engine.default import DefaultDialect dialect = DefaultDialect() - compiler = dialect.statement_compiler(dialect, self, column_keys=column_keys, inline=inline) + compiler = self._compiler(dialect, bind=bind, **kw) compiler.compile() return compiler - + + def _compiler(self, dialect, **kw): + """Return a compiler appropriate for this ClauseElement, given a Dialect.""" + + return dialect.statement_compiler(dialect, self, **kw) + def __str__(self): return unicode(self.compile()).encode('ascii', 'backslashreplace') @@ -1230,6 +1233,12 @@ class ClauseElement(Visitable): class _Immutable(object): """mark a ClauseElement as 'immutable' when expressions are cloned.""" + def unique_params(self, *optionaldict, **kwargs): + raise NotImplementedError("Immutable objects do not support copying") + + def params(self, *optionaldict, **kwargs): + raise NotImplementedError("Immutable objects do not support copying") + def _clone(self): return self diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index a5bd497aed..4471d4fb0d 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -34,13 +34,10 @@ class VisitableType(type): """ def __init__(cls, clsname, bases, clsdict): - if cls.__name__ == 'Visitable': + if cls.__name__ == 'Visitable' or not hasattr(cls, '__visit_name__'): super(VisitableType, cls).__init__(clsname, bases, clsdict) return - assert hasattr(cls, '__visit_name__'), "`Visitable` descendants " \ - "should define `__visit_name__`" - # set up an optimized visit dispatch function # for use by the compiler visit_name = cls.__visit_name__ diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index 38aba026c4..986d3d1332 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -15,7 +15,7 @@ __all__ = [ 'TypeEngine', 'TypeDecorator', 'AbstractType', 'INT', 'CHAR', 'VARCHAR', 'NCHAR', 'TEXT', 'Text', 'FLOAT', 'NUMERIC', 'DECIMAL', 'TIMESTAMP', 'DATETIME', 'CLOB', 'BLOB', 'BOOLEAN', 'SMALLINT', 'DATE', 'TIME', - 'String', 'Integer', 'SmallInteger','Smallinteger', + 'String', 'Integer', 'SmallInteger', 'Numeric', 'Float', 'DateTime', 'Date', 'Time', 'Binary', 'Boolean', 'Unicode', 'MutableType', 'Concatenable', 'UnicodeText', 'PickleType', 'Interval', 'type_map' @@ -27,10 +27,14 @@ from decimal import Decimal as _python_Decimal from sqlalchemy import exc from sqlalchemy.util import pickle +from sqlalchemy.sql.visitors import Visitable +from sqlalchemy.sql import expression +import sys +expression.sqltypes = sys.modules[__name__] import sqlalchemy.util as util NoneType = type(None) -class AbstractType(object): +class AbstractType(Visitable): def __init__(self, *args, **kwargs): pass @@ -89,37 +93,7 @@ class AbstractType(object): for k in inspect.getargspec(self.__init__)[0][1:])) class TypeEngine(AbstractType): - """Base for built-in types. - - May be sub-classed to create entirely new types. Example:: - - import sqlalchemy.types as types - - class MyType(types.TypeEngine): - def __init__(self, precision = 8): - self.precision = precision - - def get_col_spec(self): - return "MYTYPE(%s)" % self.precision - - def bind_processor(self, dialect): - def process(value): - return value - return process - - def result_processor(self, dialect): - def process(value): - return value - return process - - Once the type is made, it's immediately usable:: - - table = Table('foo', meta, - Column('id', Integer, primary_key=True), - Column('data', MyType(16)) - ) - - """ + """Base for built-in types.""" def dialect_impl(self, dialect, **kwargs): try: @@ -135,10 +109,6 @@ class TypeEngine(AbstractType): d['_impl_dict'] = {} return d - def get_col_spec(self): - """Return the DDL representation for this type.""" - raise NotImplementedError() - def bind_processor(self, dialect): """Return a conversion function for processing bind values. @@ -174,6 +144,42 @@ class TypeEngine(AbstractType): return self.__class__.__mro__[0:-1] +class UserDefinedType(TypeEngine): + """Base for user defined types. + + This should be the base of new types. Note that + for most cases, :class:`TypeDecorator` is probably + more appropriate. + + import sqlalchemy.types as types + + class MyType(types.UserDefinedType): + def __init__(self, precision = 8): + self.precision = precision + + def get_col_spec(self): + return "MYTYPE(%s)" % self.precision + + def bind_processor(self, dialect): + def process(value): + return value + return process + + def result_processor(self, dialect): + def process(value): + return value + return process + + Once the type is made, it's immediately usable:: + + table = Table('foo', meta, + Column('id', Integer, primary_key=True), + Column('data', MyType(16)) + ) + + """ + __visit_name__ = "user_defined" + class TypeDecorator(AbstractType): """Allows the creation of types which add additional functionality to an existing type. @@ -214,6 +220,8 @@ class TypeDecorator(AbstractType): """ + __visit_name__ = "type_decorator" + def __init__(self, *args, **kwargs): if not hasattr(self.__class__, 'impl'): raise AssertionError("TypeDecorator implementations require a class-level variable 'impl' which refers to the class of type being decorated") @@ -253,9 +261,6 @@ class TypeDecorator(AbstractType): return getattr(self.impl, key) - def get_col_spec(self): - return self.impl.get_col_spec() - def process_bind_param(self, value, dialect): raise NotImplementedError() @@ -370,9 +375,7 @@ class NullType(TypeEngine): encountered during a :meth:`~sqlalchemy.Table.create` operation. """ - - def get_col_spec(self): - raise NotImplementedError() + __visit_name__ = 'null' NullTypeEngine = NullType @@ -400,6 +403,8 @@ class String(Concatenable, TypeEngine): """ + __visit_name__ = 'string' + def __init__(self, length=None, convert_unicode=False, assert_unicode=None): """ Create a string-holding type. @@ -485,6 +490,9 @@ class Text(String): params (and the reverse for result sets.) """ + + __visit_name__ = 'text' + def dialect_impl(self, dialect, **kwargs): return TypeEngine.dialect_impl(self, dialect, **kwargs) @@ -555,7 +563,9 @@ class UnicodeText(Text): class Integer(TypeEngine): """A type for ``int`` integers.""" - + + __visit_name__ = 'integer' + def get_dbapi_type(self, dbapi): return dbapi.NUMBER @@ -568,7 +578,7 @@ class SmallInteger(Integer): """ -Smallinteger = SmallInteger + __visit_name__ = 'small_integer' class Numeric(TypeEngine): """A type for fixed precision numbers. @@ -578,6 +588,8 @@ class Numeric(TypeEngine): """ + __visit_name__ = 'numeric' + def __init__(self, precision=10, scale=2, asdecimal=True, length=None): """ Construct a Numeric. @@ -628,6 +640,8 @@ class Numeric(TypeEngine): class Float(Numeric): """A type for ``float`` numbers.""" + __visit_name__ = 'float' + def __init__(self, precision=10, asdecimal=False, **kwargs): """ Construct a Float. @@ -652,7 +666,9 @@ class DateTime(TypeEngine): converted back to datetime objects when rows are returned. """ - + + __visit_name__ = 'datetime' + def __init__(self, timezone=False): self.timezone = timezone @@ -666,6 +682,8 @@ class DateTime(TypeEngine): class Date(TypeEngine): """A type for ``datetime.date()`` objects.""" + __visit_name__ = 'date' + def get_dbapi_type(self, dbapi): return dbapi.DATETIME @@ -673,6 +691,8 @@ class Date(TypeEngine): class Time(TypeEngine): """A type for ``datetime.time()`` objects.""" + __visit_name__ = 'time' + def __init__(self, timezone=False): self.timezone = timezone @@ -692,6 +712,8 @@ class Binary(TypeEngine): """ + __visit_name__ = 'binary' + def __init__(self, length=None): """ Construct a Binary type. @@ -806,6 +828,7 @@ class Boolean(TypeEngine): """ + __visit_name__ = 'boolean' class Interval(TypeDecorator): """A type for ``datetime.timedelta()`` objects. @@ -821,7 +844,7 @@ class Interval(TypeDecorator): def __init__(self): super(Interval, self).__init__() - import sqlalchemy.databases.postgres as pg + import sqlalchemy.dialects.postgres.base as pg self.__supported = {pg.PGDialect:pg.PGInterval} del pg @@ -850,66 +873,96 @@ class Interval(TypeDecorator): class FLOAT(Float): """The SQL FLOAT type.""" + __visit_name__ = 'FLOAT' class NUMERIC(Numeric): """The SQL NUMERIC type.""" + __visit_name__ = 'NUMERIC' + class DECIMAL(Numeric): """The SQL DECIMAL type.""" + __visit_name__ = 'DECIMAL' -class INT(Integer): + +class INTEGER(Integer): """The SQL INT or INTEGER type.""" + __visit_name__ = 'INTEGER' +INT = INTEGER -INTEGER = INT -class SMALLINT(Smallinteger): +class SMALLINT(SmallInteger): """The SQL SMALLINT type.""" + __visit_name__ = 'SMALLINT' + class TIMESTAMP(DateTime): """The SQL TIMESTAMP type.""" + __visit_name__ = 'TIMESTAMP' + class DATETIME(DateTime): """The SQL DATETIME type.""" + __visit_name__ = 'DATETIME' + class DATE(Date): """The SQL DATE type.""" + __visit_name__ = 'DATE' + class TIME(Time): """The SQL TIME type.""" + __visit_name__ = 'TIME' -TEXT = Text +class TEXT(Text): + """The SQL TEXT type.""" + + __visit_name__ = 'TEXT' class CLOB(Text): """The SQL CLOB type.""" + __visit_name__ = 'CLOB' + class VARCHAR(String): """The SQL VARCHAR type.""" + __visit_name__ = 'VARCHAR' + class CHAR(String): """The SQL CHAR type.""" + __visit_name__ = 'CHAR' + class NCHAR(Unicode): """The SQL NCHAR type.""" + __visit_name__ = 'NCHAR' + class BLOB(Binary): """The SQL BLOB type.""" + __visit_name__ = 'BLOB' + class BOOLEAN(Boolean): """The SQL BOOLEAN type.""" + __visit_name__ = 'BOOLEAN' + NULLTYPE = NullType() # using VARCHAR/NCHAR so that we dont get the genericized "String" @@ -927,3 +980,4 @@ type_map = { dt.timedelta : Interval, type(None): NullType } + diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 619888135d..12f155d606 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -299,6 +299,7 @@ def get_cls_kwargs(cls): class_ = stack.pop() ctr = class_.__dict__.get('__init__', False) if not ctr or not isinstance(ctr, types.FunctionType): + stack.update(class_.__bases__) continue names, _, has_kw, _ = inspect.getargspec(ctr) args.update(names) diff --git a/test/dialect/postgres.py b/test/dialect/postgres.py index 75c0918b81..84b2a16748 100644 --- a/test/dialect/postgres.py +++ b/test/dialect/postgres.py @@ -2,12 +2,12 @@ import testenv; testenv.configure_for_tests() import datetime from sqlalchemy import * from sqlalchemy.orm import * -from sqlalchemy import exc -from sqlalchemy.databases import postgres +from sqlalchemy import exc, schema +from sqlalchemy.dialects.postgres import base as postgres from sqlalchemy.engine.strategies import MockEngineStrategy from testlib import * from sqlalchemy.sql import table, column - +from testlib.testing import eq_ class SequenceTest(TestBase, AssertsCompiledSQL): def test_basic(self): @@ -58,6 +58,14 @@ class CompileTest(TestBase, AssertsCompiledSQL): i = insert(table1, values=dict(name='foo'), postgres_returning=[func.length(table1.c.name)]) self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING length(mytable.name)", dialect=dialect) + def test_create_partial_index(self): + tbl = Table('testtbl', MetaData(), Column('data',Integer)) + idx = Index('test_idx1', tbl.c.data, postgres_where=and_(tbl.c.data > 5, tbl.c.data < 10)) + + self.assert_compile(schema.CreateIndex(idx), + "CREATE INDEX test_idx1 ON testtbl (data) WHERE testtbl.data > 5 AND testtbl.data < 10", dialect=postgres.dialect()) + + class ReturningTest(TestBase, AssertsExecutionResults): __only_on__ = 'postgres' @@ -406,7 +414,7 @@ class DomainReflectionTest(TestBase, AssertsExecutionResults): metadata = MetaData(testing.db) table = Table('testtable', metadata, autoload=True) self.assertEquals(set(table.columns.keys()), set(['question', 'answer']), "Columns of reflected table didn't equal expected columns") - self.assertEquals(table.c.answer.type.__class__, postgres.PGInteger) + assert isinstance(table.c.answer.type, Integer) def test_domain_is_reflected(self): metadata = MetaData(testing.db) @@ -418,7 +426,7 @@ class DomainReflectionTest(TestBase, AssertsExecutionResults): metadata = MetaData(testing.db) table = Table('testtable', metadata, autoload=True, schema='alt_schema') self.assertEquals(set(table.columns.keys()), set(['question', 'answer', 'anything']), "Columns of reflected table didn't equal expected columns") - self.assertEquals(table.c.anything.type.__class__, postgres.PGInteger) + assert isinstance(table.c.anything.type, Integer) def test_schema_domain_is_reflected(self): metadata = MetaData(testing.db) @@ -432,7 +440,7 @@ class DomainReflectionTest(TestBase, AssertsExecutionResults): self.assertEquals(str(table.columns.answer.server_default.arg), '0', "Reflected default value didn't equal expected value") self.assertTrue(table.columns.answer.nullable, "Expected reflected column to be nullable.") -class MiscTest(TestBase, AssertsExecutionResults): +class MiscTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL): __only_on__ = 'postgres' def test_date_reflection(self): @@ -666,17 +674,6 @@ class MiscTest(TestBase, AssertsExecutionResults): warnings.warn = capture_warnings._orig_showwarning m1.drop_all() - def test_create_partial_index(self): - tbl = Table('testtbl', MetaData(), Column('data',Integer)) - idx = Index('test_idx1', tbl.c.data, postgres_where=and_(tbl.c.data > 5, tbl.c.data < 10)) - - executed_sql = [] - mock_strategy = MockEngineStrategy() - mock_conn = mock_strategy.create('postgres://', executed_sql.append) - - idx.create(mock_conn) - - assert executed_sql == ['CREATE INDEX test_idx1 ON testtbl (data) WHERE testtbl.data > 5 AND testtbl.data < 10'] class TimezoneTest(TestBase, AssertsExecutionResults): """Test timezone-aware datetimes. diff --git a/test/dialect/sqlite.py b/test/dialect/sqlite.py index 97d12bf603..29beec8d35 100644 --- a/test/dialect/sqlite.py +++ b/test/dialect/sqlite.py @@ -4,7 +4,7 @@ import testenv; testenv.configure_for_tests() import datetime from sqlalchemy import * from sqlalchemy import exc -from sqlalchemy.databases import sqlite +from sqlalchemy.dialects.sqlite import base as sqlite, pysqlite as pysqlite_dialect from testlib import * @@ -50,23 +50,18 @@ class TestTypes(TestBase, AssertsExecutionResults): @testing.uses_deprecated('Using String type with no length') def test_type_reflection(self): # (ask_for, roundtripped_as_if_different) - specs = [( String(), sqlite.SLString(), ), - ( String(1), sqlite.SLString(1), ), - ( String(3), sqlite.SLString(3), ), - ( Text(), sqlite.SLText(), ), - ( Unicode(), sqlite.SLString(), ), - ( Unicode(1), sqlite.SLString(1), ), - ( Unicode(3), sqlite.SLString(3), ), - ( UnicodeText(), sqlite.SLText(), ), - ( CLOB, sqlite.SLText(), ), - ( sqlite.SLChar(1), ), - ( CHAR(3), sqlite.SLChar(3), ), - ( NCHAR(2), sqlite.SLChar(2), ), - ( SmallInteger(), sqlite.SLSmallInteger(), ), - ( sqlite.SLSmallInteger(), ), - ( Binary(3), sqlite.SLBinary(), ), - ( Binary(), sqlite.SLBinary() ), - ( sqlite.SLBinary(3), sqlite.SLBinary(), ), + specs = [( String(), pysqlite_dialect.SLString(), ), + ( String(1), pysqlite_dialect.SLString(1), ), + ( String(3), pysqlite_dialect.SLString(3), ), + ( Text(), pysqlite_dialect.SLText(), ), + ( Unicode(), pysqlite_dialect.SLString(), ), + ( Unicode(1), pysqlite_dialect.SLString(1), ), + ( Unicode(3), pysqlite_dialect.SLString(3), ), + ( UnicodeText(), pysqlite_dialect.SLText(), ), + ( CLOB, pysqlite_dialect.SLText(), ), + ( pysqlite_dialect.SLChar(1), ), + ( CHAR(3), pysqlite_dialect.SLChar(3), ), + ( NCHAR(2), pysqlite_dialect.SLChar(2), ), ( NUMERIC, sqlite.SLNumeric(), ), ( NUMERIC(10,2), sqlite.SLNumeric(10,2), ), ( Numeric, sqlite.SLNumeric(), ), @@ -75,9 +70,6 @@ class TestTypes(TestBase, AssertsExecutionResults): ( DECIMAL(10, 2), sqlite.SLNumeric(10, 2), ), ( Float, sqlite.SLNumeric(), ), ( sqlite.SLNumeric(), ), - ( INT, sqlite.SLInteger(), ), - ( Integer, sqlite.SLInteger(), ), - ( sqlite.SLInteger(), ), ( TIMESTAMP, sqlite.SLDateTime(), ), ( DATETIME, sqlite.SLDateTime(), ), ( DateTime, sqlite.SLDateTime(), ), @@ -113,7 +105,8 @@ class TestTypes(TestBase, AssertsExecutionResults): finally: db.execute('DROP VIEW types_v') finally: - m.drop_all() + pass + #m.drop_all() class TestDefaults(TestBase, AssertsExecutionResults): diff --git a/test/engine/ddlevents.py b/test/engine/ddlevents.py index 8274c63476..4c929b766c 100644 --- a/test/engine/ddlevents.py +++ b/test/engine/ddlevents.py @@ -4,7 +4,7 @@ from sqlalchemy import create_engine from testlib.sa import MetaData, Table, Column, Integer, String import testlib.sa as tsa from testlib import TestBase, testing, engines - +from testlib.testing import AssertsCompiledSQL class DDLEventTest(TestBase): class Canary(object): @@ -284,7 +284,7 @@ class DDLExecutionTest(TestBase): r = eval(py) assert list(r) == [(1,)], py -class DDLTest(TestBase): +class DDLTest(TestBase, AssertsCompiledSQL): def mock_engine(self): executor = lambda *a, **kw: None engine = create_engine(testing.db.name + '://', @@ -303,20 +303,21 @@ class DDLTest(TestBase): ddl = DDL('%(schema)s-%(table)s-%(fullname)s') - self.assertEquals(ddl._expand(sane_alone, bind), '-t-t') - self.assertEquals(ddl._expand(sane_schema, bind), 's-t-s.t') - self.assertEquals(ddl._expand(insane_alone, bind), '-"t t"-"t t"') - self.assertEquals(ddl._expand(insane_schema, bind), - '"s s"-"t t"-"s s"."t t"') + dialect = bind.dialect + self.assert_compile(ddl.against(sane_alone), '-t-t', dialect=dialect) + self.assert_compile(ddl.against(sane_schema), 's-t-s.t', dialect=dialect) + self.assert_compile(ddl.against(insane_alone), '-"t t"-"t t"', dialect=dialect) + self.assert_compile(ddl.against(insane_schema), '"s s"-"t t"-"s s"."t t"', dialect=dialect) # overrides are used piece-meal and verbatim. ddl = DDL('%(schema)s-%(table)s-%(fullname)s-%(bonus)s', context={'schema':'S S', 'table': 'T T', 'bonus': 'b'}) - self.assertEquals(ddl._expand(sane_alone, bind), 'S S-T T-t-b') - self.assertEquals(ddl._expand(sane_schema, bind), 'S S-T T-s.t-b') - self.assertEquals(ddl._expand(insane_alone, bind), 'S S-T T-"t t"-b') - self.assertEquals(ddl._expand(insane_schema, bind), - 'S S-T T-"s s"."t t"-b') + + self.assert_compile(ddl.against(sane_alone), 'S S-T T-t-b', dialect=dialect) + self.assert_compile(ddl.against(sane_schema), 'S S-T T-s.t-b', dialect=dialect) + self.assert_compile(ddl.against(insane_alone), 'S S-T T-"t t"-b', dialect=dialect) + self.assert_compile(ddl.against(insane_schema), 'S S-T T-"s s"."t t"-b', dialect=dialect) + def test_filter(self): cx = self.mock_engine() diff --git a/test/engine/reflection.py b/test/engine/reflection.py index 8e6a3df987..ac245981e8 100644 --- a/test/engine/reflection.py +++ b/test/engine/reflection.py @@ -1,6 +1,7 @@ import testenv; testenv.configure_for_tests() import StringIO, unicodedata import sqlalchemy as sa +from sqlalchemy import schema from testlib.sa import MetaData, Table, Column from testlib import TestBase, ComparesTables, testing, engines, sa as tsa @@ -49,8 +50,7 @@ class ReflectionTest(TestBase, ComparesTables): self.assert_tables_equal(users, reflected_users) self.assert_tables_equal(addresses, reflected_addresses) finally: - addresses.drop() - users.drop() + meta.drop_all() def test_include_columns(self): meta = MetaData(testing.db) @@ -87,20 +87,9 @@ class ReflectionTest(TestBase, ComparesTables): t = Table("test", meta, Column('foo', sa.DateTime)) - import sys - dialect_module = sys.modules[testing.db.dialect.__module__] - - # we're relying on the presence of "ischema_names" in the - # dialect module, else we can't test this. we need to be able - # to get the dialect to not be aware of some type so we temporarily - # monkeypatch. not sure what a better way for this could be, - # except for an established dialect hook or dialect-specific tests - if not hasattr(dialect_module, 'ischema_names'): - return - - ischema_names = dialect_module.ischema_names + ischema_names = testing.db.dialect.ischema_names t.create() - dialect_module.ischema_names = {} + testing.db.dialect.ischema_names = {} try: m2 = MetaData(testing.db) self.assertRaises(tsa.exc.SAWarning, Table, "test", m2, autoload=True) @@ -112,7 +101,7 @@ class ReflectionTest(TestBase, ComparesTables): assert t3.c.foo.type.__class__ == sa.types.NullType finally: - dialect_module.ischema_names = ischema_names + testing.db.dialect.ischema_names = ischema_names t.drop() def test_basic_override(self): @@ -718,8 +707,9 @@ class UnicodeReflectionTest(TestBase): r.drop_all() r.create_all() finally: - metadata.drop_all() - bind.dispose() + pass +# metadata.drop_all() +# bind.dispose() class SchemaTest(TestBase): @@ -733,23 +723,15 @@ class SchemaTest(TestBase): Column('col1', sa.Integer, primary_key=True), Column('col2', sa.Integer, sa.ForeignKey('someschema.table1.col1')), schema='someschema') - # ensure this doesnt crash - print [t for t in metadata.sorted_tables] - buf = StringIO.StringIO() - def foo(s, p=None): - buf.write(s) - gen = sa.create_engine(testing.db.name + "://", strategy="mock", executor=foo) - gen = gen.dialect.schemagenerator(gen.dialect, gen) - gen.traverse(table1) - gen.traverse(table2) - buf = buf.getvalue() - print buf + + t1 = str(schema.CreateTable(table1).compile(bind=testing.db)) + t2 = str(schema.CreateTable(table2).compile(bind=testing.db)) if testing.db.dialect.preparer(testing.db.dialect).omit_schema: - assert buf.index("CREATE TABLE table1") > -1 - assert buf.index("CREATE TABLE table2") > -1 + assert t1.index("CREATE TABLE table1") > -1 + assert t2.index("CREATE TABLE table2") > -1 else: - assert buf.index("CREATE TABLE someschema.table1") > -1 - assert buf.index("CREATE TABLE someschema.table2") > -1 + assert t1.index("CREATE TABLE someschema.table1") > -1 + assert t2.index("CREATE TABLE someschema.table2") > -1 @testing.crashes('firebird', 'No schema support') @testing.fails_on('sqlite', 'FIXME: unknown') diff --git a/test/ext/alltests.py b/test/ext/alltests.py index 3f0360e85e..4733292483 100644 --- a/test/ext/alltests.py +++ b/test/ext/alltests.py @@ -1,6 +1,5 @@ import testenv; testenv.configure_for_tests() import doctest, sys - from testlib import sa_unittest as unittest diff --git a/test/ext/declarative.py b/test/ext/declarative.py index c9477b5d85..3176832f30 100644 --- a/test/ext/declarative.py +++ b/test/ext/declarative.py @@ -63,26 +63,6 @@ class DeclarativeTest(DeclarativeTestBase): class User(Base): id = Column('id', Integer, primary_key=True) self.assertRaisesMessage(sa.exc.InvalidRequestError, "does not have a __table__", go) - - def test_cant_add_columns(self): - t = Table('t', Base.metadata, Column('id', Integer, primary_key=True)) - def go(): - class User(Base): - __table__ = t - foo = Column(Integer, primary_key=True) - self.assertRaisesMessage(sa.exc.ArgumentError, "add additional columns", go) - - def test_undefer_column_name(self): - # TODO: not sure if there was an explicit - # test for this elsewhere - foo = Column(Integer) - eq_(str(foo), '(no name)') - eq_(foo.key, None) - eq_(foo.name, None) - decl._undefer_column_name('foo', foo) - eq_(str(foo), 'foo') - eq_(foo.key, 'foo') - eq_(foo.name, 'foo') def test_recompile_on_othermapper(self): """declarative version of the same test in mappers.py""" diff --git a/test/orm/unitofwork.py b/test/orm/unitofwork.py index 553713da53..a4363b5e5f 100644 --- a/test/orm/unitofwork.py +++ b/test/orm/unitofwork.py @@ -2216,49 +2216,6 @@ class RowSwitchTest(_base.MappedTest): assert list(sess.execute(t5.select(), mapper=T5)) == [(2, 'some other t5')] assert list(sess.execute(t6.select(), mapper=T5)) == [(1, 'some other t6', 2)] -class InheritingRowSwitchTest(_base.MappedTest): - def define_tables(self, metadata): - Table('parent', metadata, - Column('id', Integer, primary_key=True), - Column('pdata', String(30)) - ) - Table('child', metadata, - Column('id', Integer, primary_key=True), - Column('pid', Integer, ForeignKey('parent.id')), - Column('cdata', String(30)) - ) - - def setup_classes(self): - class P(_base.ComparableEntity): - pass - - class C(P): - pass - - @testing.resolve_artifact_names - def test_row_switch_no_child_table(self): - mapper(P, parent) - mapper(C, child, inherits=P) - - sess = create_session() - c1 = C(id=1, pdata='c1', cdata='c1') - sess.add(c1) - sess.flush() - - # establish a row switch between c1 and c2. - # c2 has no value for the "child" table - c2 = C(id=1, pdata='c2') - sess.add(c2) - sess.delete(c1) - - self.assert_sql_execution(testing.db, sess.flush, - CompiledSQL("UPDATE parent SET pdata=:pdata WHERE parent.id = :parent_id", - {'pdata':'c2', 'parent_id':1} - ) - ) - - - class TransactionTest(_base.MappedTest): __requires__ = ('deferrable_constraints',) diff --git a/test/sql/constraints.py b/test/sql/constraints.py index d019aa0378..b03005c00e 100644 --- a/test/sql/constraints.py +++ b/test/sql/constraints.py @@ -1,10 +1,13 @@ import testenv; testenv.configure_for_tests() from sqlalchemy import * -from sqlalchemy import exc +from sqlalchemy import exc, schema from testlib import * from testlib import config, engines +from sqlalchemy.engine import ddl +from testlib.testing import eq_ +from testlib.assertsql import AllOf, RegexSQL, ExactSQL, CompiledSQL -class ConstraintTest(TestBase, AssertsExecutionResults): +class ConstraintTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL): def setUp(self): global metadata @@ -38,7 +41,6 @@ class ConstraintTest(TestBase, AssertsExecutionResults): Column('y', Integer, f) ) - def test_circular_constraint(self): a = Table("a", metadata, Column('id', Integer, primary_key=True), @@ -78,18 +80,9 @@ class ConstraintTest(TestBase, AssertsExecutionResults): metadata.create_all() foo.insert().execute(id=1,x=9,y=5) - try: - foo.insert().execute(id=2,x=5,y=9) - assert False - except exc.SQLError: - assert True - + self.assertRaises(exc.SQLError, foo.insert().execute, id=2,x=5,y=9) bar.insert().execute(id=1,x=10) - try: - bar.insert().execute(id=2,x=5) - assert False - except exc.SQLError: - assert True + self.assertRaises(exc.SQLError, bar.insert().execute, id=2,x=5) def test_unique_constraint(self): foo = Table('foo', metadata, @@ -106,16 +99,8 @@ class ConstraintTest(TestBase, AssertsExecutionResults): foo.insert().execute(id=2, value='value2') bar.insert().execute(id=1, value='a', value2='a') bar.insert().execute(id=2, value='a', value2='b') - try: - foo.insert().execute(id=3, value='value1') - assert False - except exc.SQLError: - assert True - try: - bar.insert().execute(id=3, value='a', value2='b') - assert False - except exc.SQLError: - assert True + self.assertRaises(exc.SQLError, foo.insert().execute, id=3, value='value1') + self.assertRaises(exc.SQLError, bar.insert().execute, id=3, value='a', value2='b') def test_index_create(self): employees = Table('employees', metadata, @@ -174,35 +159,22 @@ class ConstraintTest(TestBase, AssertsExecutionResults): Index('sport_announcer', events.c.sport, events.c.announcer, unique=True) Index('idx_winners', events.c.winner) - index_names = [ ix.name for ix in events.indexes ] - assert 'ix_events_name' in index_names - assert 'ix_events_location' in index_names - assert 'sport_announcer' in index_names - assert 'idx_winners' in index_names - assert len(index_names) == 4 - - capt = [] - connection = testing.db.connect() - # TODO: hacky, put a real connection proxy in - ex = connection._Connection__execute_context - def proxy(context): - capt.append(context.statement) - capt.append(repr(context.parameters)) - ex(context) - connection._Connection__execute_context = proxy - schemagen = testing.db.dialect.schemagenerator(testing.db.dialect, connection) - schemagen.traverse(events) - - assert capt[0].strip().startswith('CREATE TABLE events') - - s = set([capt[x].strip() for x in [2,4,6,8]]) - - assert s == set([ - 'CREATE UNIQUE INDEX ix_events_name ON events (name)', - 'CREATE INDEX ix_events_location ON events (location)', - 'CREATE UNIQUE INDEX sport_announcer ON events (sport, announcer)', - 'CREATE INDEX idx_winners ON events (winner)' - ]) + eq_( + set([ ix.name for ix in events.indexes ]), + set(['ix_events_name', 'ix_events_location', 'sport_announcer', 'idx_winners']) + ) + + self.assert_sql_execution( + testing.db, + lambda: events.create(testing.db), + RegexSQL("^CREATE TABLE events"), + AllOf( + ExactSQL('CREATE UNIQUE INDEX ix_events_name ON events (name)'), + ExactSQL('CREATE INDEX ix_events_location ON events (location)'), + ExactSQL('CREATE UNIQUE INDEX sport_announcer ON events (sport, announcer)'), + ExactSQL('CREATE INDEX idx_winners ON events (winner)') + ) + ) # verify that the table is functional events.insert().execute(id=1, name='hockey finals', location='rink', @@ -214,84 +186,57 @@ class ConstraintTest(TestBase, AssertsExecutionResults): dialect = testing.db.dialect.__class__() dialect.max_identifier_length = 20 - schemagen = dialect.schemagenerator(dialect, None) - schemagen.execute = lambda : None - t1 = Table("sometable", MetaData(), Column("foo", Integer)) - schemagen.visit_index(Index("this_name_is_too_long_for_what_were_doing", t1.c.foo)) - self.assertEquals(schemagen.buffer.getvalue(), "CREATE INDEX this_name_is_t_1 ON sometable (foo)") - schemagen.buffer.truncate(0) - schemagen.visit_index(Index("this_other_name_is_too_long_for_what_were_doing", t1.c.foo)) - self.assertEquals(schemagen.buffer.getvalue(), "CREATE INDEX this_other_nam_2 ON sometable (foo)") - - schemadrop = dialect.schemadropper(dialect, None) - schemadrop.execute = lambda: None - self.assertRaises(exc.IdentifierError, schemadrop.visit_index, Index("this_name_is_too_long_for_what_were_doing", t1.c.foo)) + self.assert_compile( + schema.CreateIndex(Index("this_name_is_too_long_for_what_were_doing", t1.c.foo)), + "CREATE INDEX this_name_is_t_1 ON sometable (foo)", + dialect=dialect + ) + + self.assert_compile( + schema.CreateIndex(Index("this_other_name_is_too_long_for_what_were_doing", t1.c.foo)), + "CREATE INDEX this_other_nam_1 ON sometable (foo)", + dialect=dialect + ) -class ConstraintCompilationTest(TestBase, AssertsExecutionResults): - class accum(object): - def __init__(self): - self.statements = [] - def __call__(self, sql, *a, **kw): - self.statements.append(sql) - def __contains__(self, substring): - for s in self.statements: - if substring in s: - return True - return False - def __str__(self): - return '\n'.join([repr(x) for x in self.statements]) - def clear(self): - del self.statements[:] - - def setUp(self): - self.sql = self.accum() - opts = config.db_opts.copy() - opts['strategy'] = 'mock' - opts['executor'] = self.sql - self.engine = engines.testing_engine(options=opts) - +class ConstraintCompilationTest(TestBase): def _test_deferrable(self, constraint_factory): - meta = MetaData(self.engine) - t = Table('tbl', meta, + t = Table('tbl', MetaData(), Column('a', Integer), Column('b', Integer), constraint_factory(deferrable=True)) - t.create() - assert 'DEFERRABLE' in self.sql, self.sql - assert 'NOT DEFERRABLE' not in self.sql, self.sql - self.sql.clear() - meta.clear() - - t = Table('tbl', meta, + + sql = str(schema.CreateTable(t).compile(bind=testing.db)) + assert 'DEFERRABLE' in sql, sql + assert 'NOT DEFERRABLE' not in sql, sql + + t = Table('tbl', MetaData(), Column('a', Integer), Column('b', Integer), constraint_factory(deferrable=False)) - t.create() - assert 'NOT DEFERRABLE' in self.sql - self.sql.clear() - meta.clear() - t = Table('tbl', meta, + sql = str(schema.CreateTable(t).compile(bind=testing.db)) + assert 'NOT DEFERRABLE' in sql + + + t = Table('tbl', MetaData(), Column('a', Integer), Column('b', Integer), constraint_factory(deferrable=True, initially='IMMEDIATE')) - t.create() - assert 'NOT DEFERRABLE' not in self.sql - assert 'INITIALLY IMMEDIATE' in self.sql - self.sql.clear() - meta.clear() + sql = str(schema.CreateTable(t).compile(bind=testing.db)) + assert 'NOT DEFERRABLE' not in sql + assert 'INITIALLY IMMEDIATE' in sql - t = Table('tbl', meta, + t = Table('tbl', MetaData(), Column('a', Integer), Column('b', Integer), constraint_factory(deferrable=True, initially='DEFERRED')) - t.create() + sql = str(schema.CreateTable(t).compile(bind=testing.db)) - assert 'NOT DEFERRABLE' not in self.sql - assert 'INITIALLY DEFERRED' in self.sql, self.sql + assert 'NOT DEFERRABLE' not in sql + assert 'INITIALLY DEFERRED' in sql def test_deferrable_pk(self): factory = lambda **kw: PrimaryKeyConstraint('a', **kw) @@ -302,15 +247,15 @@ class ConstraintCompilationTest(TestBase, AssertsExecutionResults): self._test_deferrable(factory) def test_deferrable_column_fk(self): - meta = MetaData(self.engine) - t = Table('tbl', meta, + t = Table('tbl', MetaData(), Column('a', Integer), Column('b', Integer, ForeignKey('tbl.a', deferrable=True, initially='DEFERRED'))) - t.create() - assert 'DEFERRABLE' in self.sql, self.sql - assert 'INITIALLY DEFERRED' in self.sql, self.sql + + sql = str(schema.CreateTable(t).compile(bind=testing.db)) + assert 'DEFERRABLE' in sql + assert 'INITIALLY DEFERRED' in sql def test_deferrable_unique(self): factory = lambda **kw: UniqueConstraint('b', **kw) @@ -321,16 +266,15 @@ class ConstraintCompilationTest(TestBase, AssertsExecutionResults): self._test_deferrable(factory) def test_deferrable_column_check(self): - meta = MetaData(self.engine) - t = Table('tbl', meta, + t = Table('tbl', MetaData(), Column('a', Integer), Column('b', Integer, CheckConstraint('a < b', deferrable=True, initially='DEFERRED'))) - t.create() - assert 'DEFERRABLE' in self.sql, self.sql - assert 'INITIALLY DEFERRED' in self.sql, self.sql + sql = str(schema.CreateTable(t).compile(bind=testing.db)) + assert 'DEFERRABLE' in sql + assert 'INITIALLY DEFERRED' in sql if __name__ == "__main__": diff --git a/test/sql/select.py b/test/sql/select.py index ea9f27cdf2..671ccab1a0 100644 --- a/test/sql/select.py +++ b/test/sql/select.py @@ -5,7 +5,9 @@ from sqlalchemy import exc, sql, util from sqlalchemy.sql import table, column, label, compiler from sqlalchemy.sql.expression import ClauseList from sqlalchemy.engine import default -from sqlalchemy.databases import sqlite, postgres, mysql, oracle, firebird, mssql +from sqlalchemy.databases import mysql, oracle, firebird, mssql +from sqlalchemy.dialects.sqlite import pysqlite as sqlite +from sqlalchemy.dialects.postgres import psycopg2 as postgres from testlib import * table1 = table('mytable', diff --git a/test/sql/testtypes.py b/test/sql/testtypes.py index 44b83defd8..da649d0970 100644 --- a/test/sql/testtypes.py +++ b/test/sql/testtypes.py @@ -2,11 +2,14 @@ import decimal import testenv; testenv.configure_for_tests() import datetime, os, pickleable, re from sqlalchemy import * -from sqlalchemy import exc, types, util +from sqlalchemy import exc, types, util, schema from sqlalchemy.sql import operators from testlib.testing import eq_ import sqlalchemy.engine.url as url -from sqlalchemy.databases import mssql, oracle, mysql, postgres, firebird +from sqlalchemy.databases import mssql, oracle, mysql, firebird +from sqlalchemy.dialects.sqlite import pysqlite as sqlite +from sqlalchemy.dialects.postgres import psycopg2 as postgres + from testlib import * @@ -80,12 +83,12 @@ class AdaptTest(TestBase): (mysql_dialect, Unicode(), mysql.MSString), (mysql_dialect, UnicodeText(), mysql.MSText), (mysql_dialect, NCHAR(), mysql.MSNChar), - (postgres_dialect, String(), postgres.PGString), - (postgres_dialect, VARCHAR(), postgres.PGString), - (postgres_dialect, String(50), postgres.PGString), - (postgres_dialect, Unicode(), postgres.PGString), - (postgres_dialect, UnicodeText(), postgres.PGText), - (postgres_dialect, NCHAR(), postgres.PGString), + (postgres_dialect, String(), String), + (postgres_dialect, VARCHAR(), String), + (postgres_dialect, String(50), String), + (postgres_dialect, Unicode(), String), + (postgres_dialect, UnicodeText(), Text), + (postgres_dialect, NCHAR(), String), (firebird_dialect, String(), firebird.FBString), (firebird_dialect, VARCHAR(), firebird.FBString), (firebird_dialect, String(50), firebird.FBString), @@ -100,11 +103,6 @@ class AdaptTest(TestBase): class UserDefinedTest(TestBase): """tests user-defined types.""" - def testbasic(self): - print users.c.goofy4.type - print users.c.goofy4.type.dialect_impl(testing.db.dialect) - print users.c.goofy4.type.dialect_impl(testing.db.dialect).get_col_spec() - def testprocessing(self): global users @@ -135,7 +133,7 @@ class UserDefinedTest(TestBase): def setUpAll(self): global users, metadata - class MyType(types.TypeEngine): + class MyType(types.UserDefinedType): def get_col_spec(self): return "VARCHAR(100)" def bind_processor(self, dialect): @@ -259,7 +257,6 @@ class ColumnsTest(TestBase, AssertsExecutionResults): for key, value in expectedResults.items(): expectedResults[key] = '%s NULL' % value - print db.engine.__module__ testTable = Table('testColumns', MetaData(db), Column('int_column', Integer), Column('smallint_column', SmallInteger), @@ -271,7 +268,7 @@ class ColumnsTest(TestBase, AssertsExecutionResults): for aCol in testTable.c: self.assertEquals( expectedResults[aCol.name], - db.dialect.schemagenerator(db.dialect, db, None, None).\ + db.dialect.ddl_compiler(db.dialect, schema.CreateTable(testTable)).\ get_column_specification(aCol)) class UnicodeTest(TestBase, AssertsExecutionResults): @@ -469,7 +466,7 @@ class ExpressionTest(TestBase, AssertsExecutionResults): def setUpAll(self): global test_table, meta - class MyCustomType(types.TypeEngine): + class MyCustomType(types.UserDefinedType): def get_col_spec(self): return "INT" def bind_processor(self, dialect): @@ -712,6 +709,7 @@ class NumericTest(TestBase, AssertsExecutionResults): from decimal import Decimal numeric_table.insert().execute( numericcol=3.5, floatcol=5.6, ncasdec=12.4, fcasdec=15.75) + numeric_table.insert().execute( numericcol=Decimal("3.5"), floatcol=Decimal("5.6"), ncasdec=Decimal("12.4"), fcasdec=Decimal("15.75")) @@ -753,7 +751,7 @@ class NumericTest(TestBase, AssertsExecutionResults): eq_(n2.scale, 12, dialect.name) # test colspec generates successfully using 'scale' - assert n2.get_col_spec() + assert dialect.type_compiler.process(n2) # test constructor of the dialect-specific type n3 = n2.__class__(scale=5) diff --git a/test/testlib/engines.py b/test/testlib/engines.py index 4068f43d0a..df1d37d3cd 100644 --- a/test/testlib/engines.py +++ b/test/testlib/engines.py @@ -71,6 +71,10 @@ def all_dialects(): for name in d.__all__: mod = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name) yield mod.dialect() + import sqlalchemy.dialects as d + for name in d.__all__: + mod = getattr(__import__('sqlalchemy.dialects.%s.base' % name).dialects, name).base + yield mod.dialect() class ReconnectFixture(object): def __init__(self, dbapi): diff --git a/test/testlib/testing.py b/test/testlib/testing.py index fffb301f2f..fb77b07bb1 100644 --- a/test/testlib/testing.py +++ b/test/testlib/testing.py @@ -615,14 +615,13 @@ class AssertsCompiledSQL(object): if dialect is None: dialect = getattr(self, '__dialect__', None) - if params is None: - keys = None - else: - keys = params.keys() + kw = {} + if params is not None: + kw['column_keys'] = params.keys() - c = clause.compile(column_keys=keys, dialect=dialect) + c = clause.compile(dialect=dialect, **kw) - print "\nSQL String:\n" + str(c) + repr(c.params) + print "\nSQL String:\n" + str(c) + repr(getattr(c, 'params', {})) cc = re.sub(r'\n', '', str(c))