This is the MIT license: `<http://www.opensource.org/licenses/mit-license.php>`_
-Copyright (c) 2005, 2006, 2007, 2008 Michael Bayer and contributors. SQLAlchemy is a trademark of Michael
+Copyright (c) 2005, 2006, 2007, 2008, 2009 Michael Bayer and contributors. SQLAlchemy is a trademark of Michael
Bayer.
Permission is hereby granted, free of charge, to any person obtaining a copy of this
'mssql',
'mysql',
'oracle',
- 'postgres',
- 'sqlite',
'sybase',
)
def get_col_spec(self):
return "TINYINT"
-class AcSmallInteger(types.Smallinteger):
+class AcSmallInteger(types.SmallInteger):
def get_col_spec(self):
return "SMALLINT"
colspecs = {
types.Unicode : AcUnicode,
types.Integer : AcInteger,
- types.Smallinteger: AcSmallInteger,
+ types.SmallInteger: AcSmallInteger,
types.Numeric : AcNumeric,
types.Float : AcFloat,
types.DateTime : AcDateTime,
return names
-class AccessCompiler(compiler.DefaultCompiler):
+class AccessCompiler(compiler.SQLCompiler):
def visit_select_precolumns(self, select):
"""Access puts TOP, it's version of LIMIT here """
s = select.distinct and "DISTINCT " or ""
return "INTEGER"
-class FBSmallInteger(sqltypes.Smallinteger):
+class FBSmallInteger(sqltypes.SmallInteger):
"""Handle ``SMALLINT`` datatype."""
def get_col_spec(self):
colspecs = {
sqltypes.Integer : FBInteger,
- sqltypes.Smallinteger : FBSmallInteger,
+ sqltypes.SmallInteger : FBSmallInteger,
sqltypes.Numeric : FBNumeric,
sqltypes.Float : FBFloat,
sqltypes.DateTime : FBDateTime,
return "SUBSTRING(%s FROM %s FOR %s)" % (s, start, length)
-class FBCompiler(sql.compiler.DefaultCompiler):
+class FBCompiler(sql.compiler.SQLCompiler):
"""Firebird specific idiosincrasies"""
# Firebird lacks a builtin modulo operator, but there is
# an equivalent function in the ib_udf library.
- operators = sql.compiler.DefaultCompiler.operators.copy()
+ operators = sql.compiler.SQLCompiler.operators.copy()
operators.update({
sql.operators.mod : lambda x, y:"mod(%s, %s)" % (x, y)
})
else:
return self.process(alias.original, **kwargs)
- functions = sql.compiler.DefaultCompiler.functions.copy()
+ functions = sql.compiler.SQLCompiler.functions.copy()
functions['substring'] = _substring
def function_argspec(self, func):
def get_col_spec(self):
return "INTEGER"
-class InfoSmallInteger(sqltypes.Smallinteger):
+class InfoSmallInteger(sqltypes.SmallInteger):
def get_col_spec(self):
return "SMALLINT"
colspecs = {
sqltypes.Integer : InfoInteger,
- sqltypes.Smallinteger : InfoSmallInteger,
+ sqltypes.SmallInteger : InfoSmallInteger,
sqltypes.Numeric : InfoNumeric,
sqltypes.Float : InfoNumeric,
sqltypes.DateTime : InfoDateTime,
for cons_name, cons_type, local_column in rows:
table.primary_key.add( table.c[local_column] )
-class InfoCompiler(compiler.DefaultCompiler):
+class InfoCompiler(compiler.SQLCompiler):
"""Info compiler modifies the lexical structure of Select statements to work under
non-ANSI configured Oracle databases, if the use_ansi flag is False."""
self.limit = 0
self.offset = 0
- compiler.DefaultCompiler.__init__( self , *args, **kwargs )
+ compiler.SQLCompiler.__init__( self , *args, **kwargs )
def default_from(self):
return " from systables where tabname = 'systables' "
if ( __label(c) not in a ):
select.append_column( c )
- return compiler.DefaultCompiler.visit_select(self, select)
+ return compiler.SQLCompiler.visit_select(self, select)
def limit_clause(self, select):
return ""
elif func.name.lower() in ( 'current_timestamp' , 'now' ):
return "CURRENT YEAR TO SECOND"
else:
- return compiler.DefaultCompiler.visit_function( self , func )
+ return compiler.SQLCompiler.visit_function( self , func )
def visit_clauselist(self, list, **kwargs):
return ', '.join([s for s in [self.process(c) for c in list.clauses] if s is not None])
colspecs = {
sqltypes.Integer: MaxInteger,
- sqltypes.Smallinteger: MaxSmallInteger,
+ sqltypes.SmallInteger: MaxSmallInteger,
sqltypes.Numeric: MaxNumeric,
sqltypes.Float: MaxFloat,
sqltypes.DateTime: MaxTimestamp,
return found
-class MaxDBCompiler(compiler.DefaultCompiler):
- operators = compiler.DefaultCompiler.operators.copy()
+class MaxDBCompiler(compiler.SQLCompiler):
+ operators = compiler.SQLCompiler.operators.copy()
operators[sql_operators.mod] = lambda x, y: 'mod(%s, %s)' % (x, y)
function_conversion = {
colspecs = {
sqltypes.Unicode : MSNVarchar,
sqltypes.Integer : MSInteger,
- sqltypes.Smallinteger: MSSmallInteger,
+ sqltypes.SmallInteger: MSSmallInteger,
sqltypes.Numeric : MSNumeric,
sqltypes.Float : MSFloat,
sqltypes.DateTime : MSDateTime,
}
-class MSSQLCompiler(compiler.DefaultCompiler):
+class MSSQLCompiler(compiler.SQLCompiler):
operators = compiler.OPERATORS.copy()
operators.update({
sql_operators.concat_op: '+',
sql_operators.match_op: lambda x, y: "CONTAINS (%s, %s)" % (x, y)
})
- functions = compiler.DefaultCompiler.functions.copy()
+ functions = compiler.SQLCompiler.functions.copy()
functions.update (
{
sql_functions.now: 'CURRENT_TIMESTAMP',
if not self.dialect.has_window_funcs:
raise exc.InvalidRequestError('MSSQL does not support LIMIT with an offset')
return s
- return compiler.DefaultCompiler.get_select_precolumns(self, select)
+ return compiler.SQLCompiler.get_select_precolumns(self, select)
def limit_clause(self, select):
# Limit in mssql is after the select keyword
limitselect.append_whereclause("mssql_rn<=%d" % (_limit + _offset))
return self.process(limitselect, iswrapper=True, **kwargs)
else:
- return compiler.DefaultCompiler.visit_select(self, select, **kwargs)
+ return compiler.SQLCompiler.visit_select(self, select, **kwargs)
def _schema_aliased_table(self, table):
if getattr(table, 'schema', None) is not None:
return self._extend("TINYINT")
-class MSSmallInteger(sqltypes.Smallinteger, MSInteger):
+class MSSmallInteger(sqltypes.SmallInteger, MSInteger):
"""MySQL SMALLINTEGER type."""
def __init__(self, display_width=None, **kw):
colspecs = {
sqltypes.Integer: MSInteger,
- sqltypes.Smallinteger: MSSmallInteger,
+ sqltypes.SmallInteger: MSSmallInteger,
sqltypes.Numeric: MSNumeric,
sqltypes.Float: MSFloat,
sqltypes.DateTime: MSDateTime,
return item
-class MySQLCompiler(compiler.DefaultCompiler):
- operators = compiler.DefaultCompiler.operators.copy()
+class MySQLCompiler(compiler.SQLCompiler):
+ operators = compiler.SQLCompiler.operators.copy()
operators.update({
sql_operators.concat_op: lambda x, y: "concat(%s, %s)" % (x, y),
sql_operators.mod: '%%',
sql_operators.match_op: lambda x, y: "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % (x, y)
})
- functions = compiler.DefaultCompiler.functions.copy()
+ functions = compiler.SQLCompiler.functions.copy()
functions.update ({
sql_functions.random: 'rand%(expr)s',
"utc_timestamp":"UTC_TIMESTAMP"
self.isupdate = True
colparams = self._get_colparams(update_stmt)
- text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + ', '.join(["%s=%s" % (self.preparer.format_column(c[0]), c[1]) for c in colparams])
+ text = "UPDATE " + self.preparer.format_table(update_stmt.table) + \
+ " SET " + ', '.join(["%s=%s" % (self.preparer.format_column(c[0]), c[1]) for c in colparams])
if update_stmt._whereclause:
text += " WHERE " + self.process(update_stmt._whereclause)
def get_col_spec(self):
return "INTEGER"
-class OracleSmallInteger(sqltypes.Smallinteger):
+class OracleSmallInteger(sqltypes.SmallInteger):
def get_col_spec(self):
return "SMALLINT"
colspecs = {
sqltypes.Integer : OracleInteger,
- sqltypes.Smallinteger : OracleSmallInteger,
+ sqltypes.SmallInteger : OracleSmallInteger,
sqltypes.Numeric : OracleNumeric,
sqltypes.Float : OracleNumeric,
sqltypes.DateTime : OracleDateTime,
def __init__(self, column):
self.column = column
-class OracleCompiler(compiler.DefaultCompiler):
+class OracleCompiler(compiler.SQLCompiler):
"""Oracle compiler modifies the lexical structure of Select
statements to work under non-ANSI configured Oracle databases, if
the use_ansi flag is False.
"""
- operators = compiler.DefaultCompiler.operators.copy()
+ operators = compiler.SQLCompiler.operators.copy()
operators.update(
{
sql_operators.mod : lambda x, y:"mod(%s, %s)" % (x, y),
}
)
- functions = compiler.DefaultCompiler.functions.copy()
+ functions = compiler.SQLCompiler.functions.copy()
functions.update (
{
sql_functions.now : 'CURRENT_TIMESTAMP'
def visit_join(self, join, **kwargs):
if self.dialect.use_ansi:
- return compiler.DefaultCompiler.visit_join(self, join, **kwargs)
+ return compiler.SQLCompiler.visit_join(self, join, **kwargs)
else:
return self.process(join.left, asfrom=True) + ", " + self.process(join.right, asfrom=True)
select = offsetselect
kwargs['iswrapper'] = getattr(select, '_is_wrapper', False)
- return compiler.DefaultCompiler.visit_select(self, select, **kwargs)
+ return compiler.SQLCompiler.visit_select(self, select, **kwargs)
def limit_clause(self, select):
return ""
+++ /dev/null
-# sqlite.py
-# Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com
-#
-# This module is part of SQLAlchemy and is released under
-# the MIT License: http://www.opensource.org/licenses/mit-license.php
-
-"""Support for the SQLite database.
-
-Driver
-------
-
-When using Python 2.5 and above, the built in ``sqlite3`` driver is
-already installed and no additional installation is needed. Otherwise,
-the ``pysqlite2`` driver needs to be present. This is the same driver as
-``sqlite3``, just with a different name.
-
-The ``pysqlite2`` driver will be loaded first, and if not found, ``sqlite3``
-is loaded. This allows an explicitly installed pysqlite driver to take
-precedence over the built in one. As with all dialects, a specific
-DBAPI module may be provided to :func:`~sqlalchemy.create_engine()` to control
-this explicitly::
-
- from sqlite3 import dbapi2 as sqlite
- e = create_engine('sqlite:///file.db', module=sqlite)
-
-Full documentation on pysqlite is available at:
-`<http://www.initd.org/pub/software/pysqlite/doc/usage-guide.html>`_
-
-Connect Strings
----------------
-
-The file specification for the SQLite database is taken as the "database" portion of
-the URL. Note that the format of a url is::
-
- driver://user:pass@host/database
-
-This means that the actual filename to be used starts with the characters to the
-**right** of the third slash. So connecting to a relative filepath looks like::
-
- # relative path
- e = create_engine('sqlite:///path/to/database.db')
-
-An absolute path, which is denoted by starting with a slash, means you need **four**
-slashes::
-
- # absolute path
- e = create_engine('sqlite:////path/to/database.db')
-
-To use a Windows path, regular drive specifications and backslashes can be used.
-Double backslashes are probably needed::
-
- # absolute path on Windows
- e = create_engine('sqlite:///C:\\\\path\\\\to\\\\database.db')
-
-The sqlite ``:memory:`` identifier is the default if no filepath is present. Specify
-``sqlite://`` and nothing else::
-
- # in-memory database
- e = create_engine('sqlite://')
-
-Threading Behavior
-------------------
-
-Pysqlite connections do not support being moved between threads, unless
-the ``check_same_thread`` Pysqlite flag is set to ``False``. In addition,
-when using an in-memory SQLite database, the full database exists only within
-the scope of a single connection. It is reported that an in-memory
-database does not support being shared between threads regardless of the
-``check_same_thread`` flag - which means that a multithreaded
-application **cannot** share data from a ``:memory:`` database across threads
-unless access to the connection is limited to a single worker thread which communicates
-through a queueing mechanism to concurrent threads.
-
-To provide a default which accomodates SQLite's default threading capabilities
-somewhat reasonably, the SQLite dialect will specify that the :class:`~sqlalchemy.pool.SingletonThreadPool`
-be used by default. This pool maintains a single SQLite connection per thread
-that is held open up to a count of five concurrent threads. When more than five threads
-are used, a cleanup mechanism will dispose of excess unused connections.
-
-Two optional pool implementations that may be appropriate for particular SQLite usage scenarios:
-
- * the :class:`sqlalchemy.pool.StaticPool` might be appropriate for a multithreaded
- application using an in-memory database, assuming the threading issues inherent in
- pysqlite are somehow accomodated for. This pool holds persistently onto a single connection
- which is never closed, and is returned for all requests.
-
- * the :class:`sqlalchemy.pool.NullPool` might be appropriate for an application that
- makes use of a file-based sqlite database. This pool disables any actual "pooling"
- behavior, and simply opens and closes real connections corresonding to the :func:`connect()`
- and :func:`close()` methods. SQLite can "connect" to a particular file with very high
- efficiency, so this option may actually perform better without the extra overhead
- of :class:`SingletonThreadPool`. NullPool will of course render a ``:memory:`` connection
- useless since the database would be lost as soon as the connection is "returned" to the pool.
-
-Date and Time Types
--------------------
-
-SQLite does not have built-in DATE, TIME, or DATETIME types, and pysqlite does not provide
-out of the box functionality for translating values between Python `datetime` objects
-and a SQLite-supported format. SQLAlchemy's own :class:`~sqlalchemy.types.DateTime`
-and related types provide date formatting and parsing functionality when SQlite is used.
-The implementation classes are :class:`SLDateTime`, :class:`SLDate` and :class:`SLTime`.
-These types represent dates and times as ISO formatted strings, which also nicely
-support ordering. There's no reliance on typical "libc" internals for these functions
-so historical dates are fully supported.
-
-Unicode
--------
-
-In contrast to SQLAlchemy's active handling of date and time types for pysqlite, pysqlite's
-default behavior regarding Unicode is that all strings are returned as Python unicode objects
-in all cases. So even if the :class:`~sqlalchemy.types.Unicode` type is
-*not* used, you will still always receive unicode data back from a result set. It is
-**strongly** recommended that you do use the :class:`~sqlalchemy.types.Unicode` type
-to represent strings, since it will raise a warning if a non-unicode Python string is
-passed from the user application. Mixing the usage of non-unicode objects with returned unicode objects can
-quickly create confusion, particularly when using the ORM as internal data is not
-always represented by an actual database result string.
-
-"""
-
-
-import datetime, re, time
-
-from sqlalchemy import sql, schema, exc, pool, DefaultClause
-from sqlalchemy.engine import default
-import sqlalchemy.types as sqltypes
-import sqlalchemy.util as util
-from sqlalchemy.sql import compiler, functions as sql_functions
-from types import NoneType
-
-class SLNumeric(sqltypes.Numeric):
- def bind_processor(self, dialect):
- type_ = self.asdecimal and str or float
- def process(value):
- if value is not None:
- return type_(value)
- else:
- return value
- return process
-
- def get_col_spec(self):
- if self.precision is None:
- return "NUMERIC"
- else:
- return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale}
-
-class SLFloat(sqltypes.Float):
- def bind_processor(self, dialect):
- type_ = self.asdecimal and str or float
- def process(value):
- if value is not None:
- return type_(value)
- else:
- return value
- return process
-
- def get_col_spec(self):
- return "FLOAT"
-
-class SLInteger(sqltypes.Integer):
- def get_col_spec(self):
- return "INTEGER"
-
-class SLSmallInteger(sqltypes.Smallinteger):
- def get_col_spec(self):
- return "SMALLINT"
-
-class DateTimeMixin(object):
- def _bind_processor(self, format, elements):
- def process(value):
- if not isinstance(value, (NoneType, datetime.date, datetime.datetime, datetime.time)):
- raise TypeError("SQLite Date, Time, and DateTime types only accept Python datetime objects as input.")
- elif value is not None:
- return format % tuple([getattr(value, attr, 0) for attr in elements])
- else:
- return None
- return process
-
- def _result_processor(self, fn, regexp):
- def process(value):
- if value is not None:
- return fn(*[int(x or 0) for x in regexp.match(value).groups()])
- else:
- return None
- return process
-
-class SLDateTime(DateTimeMixin, sqltypes.DateTime):
- __legacy_microseconds__ = False
-
- def get_col_spec(self):
- return "TIMESTAMP"
-
- def bind_processor(self, dialect):
- if self.__legacy_microseconds__:
- return self._bind_processor(
- "%4.4d-%2.2d-%2.2d %2.2d:%2.2d:%2.2d.%s",
- ("year", "month", "day", "hour", "minute", "second", "microsecond")
- )
- else:
- return self._bind_processor(
- "%4.4d-%2.2d-%2.2d %2.2d:%2.2d:%2.2d.%06d",
- ("year", "month", "day", "hour", "minute", "second", "microsecond")
- )
-
- _reg = re.compile(r"(\d+)-(\d+)-(\d+)(?: (\d+):(\d+):(\d+)(?:\.(\d+))?)?")
- def result_processor(self, dialect):
- return self._result_processor(datetime.datetime, self._reg)
-
-class SLDate(DateTimeMixin, sqltypes.Date):
- def get_col_spec(self):
- return "DATE"
-
- def bind_processor(self, dialect):
- return self._bind_processor(
- "%4.4d-%2.2d-%2.2d",
- ("year", "month", "day")
- )
-
- _reg = re.compile(r"(\d+)-(\d+)-(\d+)")
- def result_processor(self, dialect):
- return self._result_processor(datetime.date, self._reg)
-
-class SLTime(DateTimeMixin, sqltypes.Time):
- __legacy_microseconds__ = False
-
- def get_col_spec(self):
- return "TIME"
-
- def bind_processor(self, dialect):
- if self.__legacy_microseconds__:
- return self._bind_processor(
- "%2.2d:%2.2d:%2.2d.%s",
- ("hour", "minute", "second", "microsecond")
- )
- else:
- return self._bind_processor(
- "%2.2d:%2.2d:%2.2d.%06d",
- ("hour", "minute", "second", "microsecond")
- )
-
- _reg = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d+))?")
- def result_processor(self, dialect):
- return self._result_processor(datetime.time, self._reg)
-
-class SLUnicodeMixin(object):
- def bind_processor(self, dialect):
- if self.convert_unicode or dialect.convert_unicode:
- if self.assert_unicode is None:
- assert_unicode = dialect.assert_unicode
- else:
- assert_unicode = self.assert_unicode
-
- if not assert_unicode:
- return None
-
- def process(value):
- if not isinstance(value, (unicode, NoneType)):
- if assert_unicode == 'warn':
- util.warn("Unicode type received non-unicode bind "
- "param value %r" % value)
- return value
- else:
- raise exc.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value)
- else:
- return value
- return process
- else:
- return None
-
- def result_processor(self, dialect):
- return None
-
-class SLText(SLUnicodeMixin, sqltypes.Text):
- def get_col_spec(self):
- return "TEXT"
-
-class SLString(SLUnicodeMixin, sqltypes.String):
- def get_col_spec(self):
- return "VARCHAR" + (self.length and "(%d)" % self.length or "")
-
-class SLChar(SLUnicodeMixin, sqltypes.CHAR):
- def get_col_spec(self):
- return "CHAR" + (self.length and "(%d)" % self.length or "")
-
-class SLBinary(sqltypes.Binary):
- def get_col_spec(self):
- return "BLOB"
-
-class SLBoolean(sqltypes.Boolean):
- def get_col_spec(self):
- return "BOOLEAN"
-
- def bind_processor(self, dialect):
- def process(value):
- if value is None:
- return None
- return value and 1 or 0
- return process
-
- def result_processor(self, dialect):
- def process(value):
- if value is None:
- return None
- return value and True or False
- return process
-
-colspecs = {
- sqltypes.Binary: SLBinary,
- sqltypes.Boolean: SLBoolean,
- sqltypes.CHAR: SLChar,
- sqltypes.Date: SLDate,
- sqltypes.DateTime: SLDateTime,
- sqltypes.Float: SLFloat,
- sqltypes.Integer: SLInteger,
- sqltypes.NCHAR: SLChar,
- sqltypes.Numeric: SLNumeric,
- sqltypes.Smallinteger: SLSmallInteger,
- sqltypes.String: SLString,
- sqltypes.Text: SLText,
- sqltypes.Time: SLTime,
-}
-
-ischema_names = {
- 'BLOB': SLBinary,
- 'BOOL': SLBoolean,
- 'BOOLEAN': SLBoolean,
- 'CHAR': SLChar,
- 'DATE': SLDate,
- 'DATETIME': SLDateTime,
- 'DECIMAL': SLNumeric,
- 'FLOAT': SLNumeric,
- 'INT': SLInteger,
- 'INTEGER': SLInteger,
- 'NUMERIC': SLNumeric,
- 'REAL': SLNumeric,
- 'SMALLINT': SLSmallInteger,
- 'TEXT': SLText,
- 'TIME': SLTime,
- 'TIMESTAMP': SLDateTime,
- 'VARCHAR': SLString,
-}
-
-class SQLiteExecutionContext(default.DefaultExecutionContext):
- def post_exec(self):
- if self.compiled.isinsert and not self.executemany:
- if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None:
- self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:]
-
-class SQLiteDialect(default.DefaultDialect):
- name = 'sqlite'
- supports_alter = False
- supports_unicode_statements = True
- default_paramstyle = 'qmark'
- supports_default_values = True
- supports_empty_insert = False
-
- def __init__(self, **kwargs):
- default.DefaultDialect.__init__(self, **kwargs)
- def vers(num):
- return tuple([int(x) for x in num.split('.')])
- if self.dbapi is not None:
- sqlite_ver = self.dbapi.version_info
- if sqlite_ver < (2, 1, '3'):
- util.warn(
- ("The installed version of pysqlite2 (%s) is out-dated "
- "and will cause errors in some cases. Version 2.1.3 "
- "or greater is recommended.") %
- '.'.join([str(subver) for subver in sqlite_ver]))
- if self.dbapi.sqlite_version_info < (3, 3, 8):
- self.supports_default_values = False
- self.supports_cast = (self.dbapi is None or vers(self.dbapi.sqlite_version) >= vers("3.2.3"))
-
- def dbapi(cls):
- try:
- from pysqlite2 import dbapi2 as sqlite
- except ImportError, e:
- try:
- from sqlite3 import dbapi2 as sqlite #try the 2.5+ stdlib name.
- except ImportError:
- raise e
- return sqlite
- dbapi = classmethod(dbapi)
-
- def server_version_info(self, connection):
- return self.dbapi.sqlite_version_info
-
- def create_connect_args(self, url):
- if url.username or url.password or url.host or url.port:
- raise exc.ArgumentError(
- "Invalid SQLite URL: %s\n"
- "Valid SQLite URL forms are:\n"
- " sqlite:///:memory: (or, sqlite://)\n"
- " sqlite:///relative/path/to/file.db\n"
- " sqlite:////absolute/path/to/file.db" % (url,))
- filename = url.database or ':memory:'
-
- opts = url.query.copy()
- util.coerce_kw_type(opts, 'timeout', float)
- util.coerce_kw_type(opts, 'isolation_level', str)
- util.coerce_kw_type(opts, 'detect_types', int)
- util.coerce_kw_type(opts, 'check_same_thread', bool)
- util.coerce_kw_type(opts, 'cached_statements', int)
-
- return ([filename], opts)
-
- def type_descriptor(self, typeobj):
- return sqltypes.adapt_type(typeobj, colspecs)
-
- def is_disconnect(self, e):
- return isinstance(e, self.dbapi.ProgrammingError) and "Cannot operate on a closed database." in str(e)
-
- def table_names(self, connection, schema):
- if schema is not None:
- qschema = self.identifier_preparer.quote_identifier(schema)
- master = '%s.sqlite_master' % qschema
- s = ("SELECT name FROM %s "
- "WHERE type='table' ORDER BY name") % (master,)
- rs = connection.execute(s)
- else:
- try:
- s = ("SELECT name FROM "
- " (SELECT * FROM sqlite_master UNION ALL "
- " SELECT * FROM sqlite_temp_master) "
- "WHERE type='table' ORDER BY name")
- rs = connection.execute(s)
- except exc.DBAPIError:
- raise
- s = ("SELECT name FROM sqlite_master "
- "WHERE type='table' ORDER BY name")
- rs = connection.execute(s)
-
- return [row[0] for row in rs]
-
- def has_table(self, connection, table_name, schema=None):
- quote = self.identifier_preparer.quote_identifier
- if schema is not None:
- pragma = "PRAGMA %s." % quote(schema)
- else:
- pragma = "PRAGMA "
- qtable = quote(table_name)
- cursor = connection.execute("%stable_info(%s)" % (pragma, qtable))
- row = cursor.fetchone()
-
- # consume remaining rows, to work around
- # http://www.sqlite.org/cvstrac/tktview?tn=1884
- while cursor.fetchone() is not None:
- pass
-
- return (row is not None)
-
- def reflecttable(self, connection, table, include_columns):
- preparer = self.identifier_preparer
- if table.schema is None:
- pragma = "PRAGMA "
- else:
- pragma = "PRAGMA %s." % preparer.quote_identifier(table.schema)
- qtable = preparer.format_table(table, False)
-
- c = connection.execute("%stable_info(%s)" % (pragma, qtable))
- found_table = False
- while True:
- row = c.fetchone()
- if row is None:
- break
-
- found_table = True
- (name, type_, nullable, default, has_default, primary_key) = (row[1], row[2].upper(), not row[3], row[4], row[4] is not None, row[5])
- name = re.sub(r'^\"|\"$', '', name)
- if include_columns and name not in include_columns:
- continue
- match = re.match(r'(\w+)(\(.*?\))?', type_)
- if match:
- coltype = match.group(1)
- args = match.group(2)
- else:
- coltype = "VARCHAR"
- args = ''
-
- try:
- coltype = ischema_names[coltype]
- except KeyError:
- util.warn("Did not recognize type '%s' of column '%s'" %
- (coltype, name))
- coltype = sqltypes.NullType
-
- if args is not None:
- args = re.findall(r'(\d+)', args)
- coltype = coltype(*[int(a) for a in args])
-
- colargs = []
- if has_default:
- colargs.append(DefaultClause(sql.text(default)))
- table.append_column(schema.Column(name, coltype, primary_key = primary_key, nullable = nullable, *colargs))
-
- if not found_table:
- raise exc.NoSuchTableError(table.name)
-
- c = connection.execute("%sforeign_key_list(%s)" % (pragma, qtable))
- fks = {}
- while True:
- row = c.fetchone()
- if row is None:
- break
- (constraint_name, tablename, localcol, remotecol) = (row[0], row[2], row[3], row[4])
- tablename = re.sub(r'^\"|\"$', '', tablename)
- localcol = re.sub(r'^\"|\"$', '', localcol)
- remotecol = re.sub(r'^\"|\"$', '', remotecol)
- try:
- fk = fks[constraint_name]
- except KeyError:
- fk = ([], [])
- fks[constraint_name] = fk
-
- # look up the table based on the given table's engine, not 'self',
- # since it could be a ProxyEngine
- remotetable = schema.Table(tablename, table.metadata, autoload=True, autoload_with=connection)
- constrained_column = table.c[localcol].name
- refspec = ".".join([tablename, remotecol])
- if constrained_column not in fk[0]:
- fk[0].append(constrained_column)
- if refspec not in fk[1]:
- fk[1].append(refspec)
- for name, value in fks.iteritems():
- table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], link_to_name=True))
- # check for UNIQUE indexes
- c = connection.execute("%sindex_list(%s)" % (pragma, qtable))
- unique_indexes = []
- while True:
- row = c.fetchone()
- if row is None:
- break
- if (row[2] == 1):
- unique_indexes.append(row[1])
- # loop thru unique indexes for one that includes the primary key
- for idx in unique_indexes:
- c = connection.execute("%sindex_info(%s)" % (pragma, idx))
- cols = []
- while True:
- row = c.fetchone()
- if row is None:
- break
- cols.append(row[2])
-
-
-class SQLiteCompiler(compiler.DefaultCompiler):
- functions = compiler.DefaultCompiler.functions.copy()
- functions.update (
- {
- sql_functions.now: 'CURRENT_TIMESTAMP',
- sql_functions.char_length: 'length%(expr)s'
- }
- )
-
- def visit_cast(self, cast, **kwargs):
- if self.dialect.supports_cast:
- return super(SQLiteCompiler, self).visit_cast(cast)
- else:
- return self.process(cast.clause)
-
- def limit_clause(self, select):
- text = ""
- if select._limit is not None:
- text += " \n LIMIT " + str(select._limit)
- if select._offset is not None:
- if select._limit is None:
- text += " \n LIMIT -1"
- text += " OFFSET " + str(select._offset)
- else:
- text += " OFFSET 0"
- return text
-
- def for_update_clause(self, select):
- # sqlite has no "FOR UPDATE" AFAICT
- return ''
-
-
-class SQLiteSchemaGenerator(compiler.SchemaGenerator):
-
- def get_column_specification(self, column, **kwargs):
- colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()
- default = self.get_column_default_string(column)
- if default is not None:
- colspec += " DEFAULT " + default
-
- if not column.nullable:
- colspec += " NOT NULL"
- return colspec
-
-class SQLiteIdentifierPreparer(compiler.IdentifierPreparer):
- reserved_words = set([
- 'add', 'after', 'all', 'alter', 'analyze', 'and', 'as', 'asc',
- 'attach', 'autoincrement', 'before', 'begin', 'between', 'by',
- 'cascade', 'case', 'cast', 'check', 'collate', 'column', 'commit',
- 'conflict', 'constraint', 'create', 'cross', 'current_date',
- 'current_time', 'current_timestamp', 'database', 'default',
- 'deferrable', 'deferred', 'delete', 'desc', 'detach', 'distinct',
- 'drop', 'each', 'else', 'end', 'escape', 'except', 'exclusive',
- 'explain', 'false', 'fail', 'for', 'foreign', 'from', 'full', 'glob',
- 'group', 'having', 'if', 'ignore', 'immediate', 'in', 'index',
- 'initially', 'inner', 'insert', 'instead', 'intersect', 'into', 'is',
- 'isnull', 'join', 'key', 'left', 'like', 'limit', 'match', 'natural',
- 'not', 'notnull', 'null', 'of', 'offset', 'on', 'or', 'order', 'outer',
- 'plan', 'pragma', 'primary', 'query', 'raise', 'references',
- 'reindex', 'rename', 'replace', 'restrict', 'right', 'rollback',
- 'row', 'select', 'set', 'table', 'temp', 'temporary', 'then', 'to',
- 'transaction', 'trigger', 'true', 'union', 'unique', 'update', 'using',
- 'vacuum', 'values', 'view', 'virtual', 'when', 'where',
- ])
-
- def __init__(self, dialect):
- super(SQLiteIdentifierPreparer, self).__init__(dialect)
-
-dialect = SQLiteDialect
-dialect.poolclass = pool.SingletonThreadPool
-dialect.statement_compiler = SQLiteCompiler
-dialect.schemagenerator = SQLiteSchemaGenerator
-dialect.preparer = SQLiteIdentifierPreparer
-dialect.execution_ctx_cls = SQLiteExecutionContext
}
-class SybaseSQLCompiler(compiler.DefaultCompiler):
- operators = compiler.DefaultCompiler.operators.copy()
+class SybaseSQLCompiler(compiler.SQLCompiler):
+ operators = compiler.SQLCompiler.operators.copy()
operators.update({
sql_operators.mod: lambda x, y: "MOD(%s, %s)" % (x, y),
})
--- /dev/null
+__all__ = (
+# 'access',
+# 'firebird',
+# 'informix',
+# 'maxdb',
+# 'mssql',
+# 'mysql',
+# 'oracle',
+ 'postgres',
+ 'sqlite',
+# 'sybase',
+ )
--- /dev/null
+from sqlalchemy.dialects.postgres import base, psycopg2
+
+base.dialect = psycopg2.dialect
\ No newline at end of file
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-"""Support for the PostgreSQL database.
-
-Driver
-------
-
-The psycopg2 driver is supported, available at http://pypi.python.org/pypi/psycopg2/ .
-The dialect has several behaviors which are specifically tailored towards compatibility
-with this module.
-
-Note that psycopg1 is **not** supported.
-
-Connecting
-----------
-
-URLs are of the form `postgres://user:password@host:port/dbname[?key=value&key=value...]`.
-
-Postgres-specific keyword arguments which are accepted by :func:`~sqlalchemy.create_engine()` are:
-
-* *server_side_cursors* - Enable the usage of "server side cursors" for SQL statements which support
- this feature. What this essentially means from a psycopg2 point of view is that the cursor is
- created using a name, e.g. `connection.cursor('some name')`, which has the effect that result rows
- are not immediately pre-fetched and buffered after statement execution, but are instead left
- on the server and only retrieved as needed. SQLAlchemy's :class:`~sqlalchemy.engine.base.ResultProxy`
- uses special row-buffering behavior when this feature is enabled, such that groups of 100 rows
- at a time are fetched over the wire to reduce conversational overhead.
-
-Sequences/SERIAL
-----------------
-
-Postgres supports sequences, and SQLAlchemy uses these as the default means of creating
-new primary key values for integer-based primary key columns. When creating tables,
-SQLAlchemy will issue the ``SERIAL`` datatype for integer-based primary key columns,
-which generates a sequence corresponding to the column and associated with it based on
-a naming convention.
-
-To specify a specific named sequence to be used for primary key generation, use the
-:func:`~sqlalchemy.schema.Sequence` construct::
-
- Table('sometable', metadata,
- Column('id', Integer, Sequence('some_id_seq'), primary_key=True)
- )
-
-Currently, when SQLAlchemy issues a single insert statement, to fulfill the contract of
-having the "last insert identifier" available, the sequence is executed independently
-beforehand and the new value is retrieved, to be used in the subsequent insert. Note
-that when an :func:`~sqlalchemy.sql.expression.insert()` construct is executed using
-"executemany" semantics, the sequence is not pre-executed and normal PG SERIAL behavior
-is used.
-
-Postgres 8.3 supports an ``INSERT...RETURNING`` syntax which SQLAlchemy supports
-as well. A future release of SQLA will use this feature by default in lieu of
-sequence pre-execution in order to retrieve new primary key values, when available.
-
-INSERT/UPDATE...RETURNING
--------------------------
-
-The dialect supports PG 8.3's ``INSERT..RETURNING`` and ``UPDATE..RETURNING`` syntaxes,
-but must be explicitly enabled on a per-statement basis::
-
- # INSERT..RETURNING
- result = table.insert(postgres_returning=[table.c.col1, table.c.col2]).\\
- values(name='foo')
- print result.fetchall()
-
- # UPDATE..RETURNING
- result = table.update(postgres_returning=[table.c.col1, table.c.col2]).\\
- where(table.c.name=='foo').values(name='bar')
- print result.fetchall()
-
-Indexes
--------
-
-PostgreSQL supports partial indexes. To create them pass a postgres_where
-option to the Index constructor::
-
- Index('my_index', my_table.c.id, postgres_where=tbl.c.value > 10)
-
-Transactions
-------------
-
-The Postgres dialect fully supports SAVEPOINT and two-phase commit operations.
-
-
-"""
import decimal, random, re, string
class PGInet(sqltypes.TypeEngine):
- def get_col_spec(self):
- return "INET"
+ __visit_name__ = "INET"
class PGCidr(sqltypes.TypeEngine):
- def get_col_spec(self):
- return "CIDR"
+ __visit_name__ = "CIDR"
class PGMacAddr(sqltypes.TypeEngine):
- def get_col_spec(self):
- return "MACADDR"
-
-class PGNumeric(sqltypes.Numeric):
- def get_col_spec(self):
- if not self.precision:
- return "NUMERIC"
- else:
- return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale}
-
- def bind_processor(self, dialect):
- return None
-
- def result_processor(self, dialect):
- if self.asdecimal:
- return None
- else:
- def process(value):
- if isinstance(value, decimal.Decimal):
- return float(value)
- else:
- return value
- return process
-
-class PGFloat(sqltypes.Float):
- def get_col_spec(self):
- if not self.precision:
- return "FLOAT"
- else:
- return "FLOAT(%(precision)s)" % {'precision': self.precision}
-
-
-class PGInteger(sqltypes.Integer):
- def get_col_spec(self):
- return "INTEGER"
-
-class PGSmallInteger(sqltypes.Smallinteger):
- def get_col_spec(self):
- return "SMALLINT"
-
-class PGBigInteger(PGInteger):
- def get_col_spec(self):
- return "BIGINT"
-
-class PGDateTime(sqltypes.DateTime):
- def get_col_spec(self):
- return "TIMESTAMP " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
-
-class PGDate(sqltypes.Date):
- def get_col_spec(self):
- return "DATE"
+ __visit_name__ = "MACADDR"
-class PGTime(sqltypes.Time):
- def get_col_spec(self):
- return "TIME " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
+class PGBigInteger(sqltypes.Integer):
+ __visit_name__ = "BIGINT"
class PGInterval(sqltypes.TypeEngine):
- def get_col_spec(self):
- return "INTERVAL"
-
-class PGText(sqltypes.Text):
- def get_col_spec(self):
- return "TEXT"
-
-class PGString(sqltypes.String):
- def get_col_spec(self):
- if self.length:
- return "VARCHAR(%(length)d)" % {'length' : self.length}
- else:
- return "VARCHAR"
-
-class PGChar(sqltypes.CHAR):
- def get_col_spec(self):
- if self.length:
- return "CHAR(%(length)d)" % {'length' : self.length}
- else:
- return "CHAR"
-
-class PGBinary(sqltypes.Binary):
- def get_col_spec(self):
- return "BYTEA"
-
-class PGBoolean(sqltypes.Boolean):
- def get_col_spec(self):
- return "BOOLEAN"
+ __visit_name__ = 'INTERVAL'
class PGArray(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine):
+ __visit_name__ = 'ARRAY'
+
def __init__(self, item_type, mutable=True):
if isinstance(item_type, type):
item_type = item_type()
return item
return [convert_item(item) for item in value]
return process
- def get_col_spec(self):
- return self.item_type.get_col_spec() + '[]'
-
-colspecs = {
- sqltypes.Integer : PGInteger,
- sqltypes.Smallinteger : PGSmallInteger,
- sqltypes.Numeric : PGNumeric,
- sqltypes.Float : PGFloat,
- sqltypes.DateTime : PGDateTime,
- sqltypes.Date : PGDate,
- sqltypes.Time : PGTime,
- sqltypes.String : PGString,
- sqltypes.Binary : PGBinary,
- sqltypes.Boolean : PGBoolean,
- sqltypes.Text : PGText,
- sqltypes.CHAR: PGChar,
-}
-
-ischema_names = {
- 'integer' : PGInteger,
- 'bigint' : PGBigInteger,
- 'smallint' : PGSmallInteger,
- 'character varying' : PGString,
- 'character' : PGChar,
- 'text' : PGText,
- 'numeric' : PGNumeric,
- 'float' : PGFloat,
- 'real' : PGFloat,
- 'inet': PGInet,
- 'cidr': PGCidr,
- 'macaddr': PGMacAddr,
- 'double precision' : PGFloat,
- 'timestamp' : PGDateTime,
- 'timestamp with time zone' : PGDateTime,
- 'timestamp without time zone' : PGDateTime,
- 'time with time zone' : PGTime,
- 'time without time zone' : PGTime,
- 'date' : PGDate,
- 'time': PGTime,
- 'bytea' : PGBinary,
- 'boolean' : PGBoolean,
- 'interval':PGInterval,
-}
-
-# TODO: filter out 'FOR UPDATE' statements
-SERVER_SIDE_CURSOR_RE = re.compile(
- r'\s*SELECT',
- re.I | re.UNICODE)
-
-class PGExecutionContext(default.DefaultExecutionContext):
- def create_cursor(self):
- # TODO: coverage for server side cursors + select.for_update()
- is_server_side = \
- self.dialect.server_side_cursors and \
- ((self.compiled and isinstance(self.compiled.statement, expression.Selectable)
- and not getattr(self.compiled.statement, 'for_update', False)) \
- or \
- (
- (not self.compiled or isinstance(self.compiled.statement, expression._TextClause))
- and self.statement and SERVER_SIDE_CURSOR_RE.match(self.statement))
- )
- self.__is_server_side = is_server_side
- if is_server_side:
- # use server-side cursors:
- # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html
- ident = "c_%s_%s" % (hex(id(self))[2:], hex(random.randint(0, 65535))[2:])
- return self._connection.connection.cursor(ident)
+
+
+
+
+class PGCompiler(compiler.SQLCompiler):
+ operators = compiler.SQLCompiler.operators.copy()
+ operators.update(
+ {
+ sql_operators.mod : '%%',
+ sql_operators.ilike_op: lambda x, y, escape=None: '%s ILIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
+ sql_operators.notilike_op: lambda x, y, escape=None: '%s NOT ILIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
+ sql_operators.match_op: lambda x, y: '%s @@ to_tsquery(%s)' % (x, y),
+ }
+ )
+
+ functions = compiler.SQLCompiler.functions.copy()
+ functions.update (
+ {
+ 'TIMESTAMP':lambda x:'TIMESTAMP %s' % x,
+ }
+ )
+
+ def visit_sequence(self, seq):
+ if seq.optional:
+ return None
+ else:
+ return "nextval('%s')" % self.preparer.format_sequence(seq)
+
+ def post_process_text(self, text):
+ if '%%' in text:
+ util.warn("The SQLAlchemy postgres dialect now automatically escapes '%' in text() expressions to '%%'.")
+ return text.replace('%', '%%')
+
+ def limit_clause(self, select):
+ text = ""
+ if select._limit is not None:
+ text += " \n LIMIT " + str(select._limit)
+ if select._offset is not None:
+ if select._limit is None:
+ text += " \n LIMIT ALL"
+ text += " OFFSET " + str(select._offset)
+ return text
+
+ def get_select_precolumns(self, select):
+ if select._distinct:
+ if isinstance(select._distinct, bool):
+ return "DISTINCT "
+ elif isinstance(select._distinct, (list, tuple)):
+ return "DISTINCT ON (" + ', '.join(
+ [(isinstance(col, basestring) and col or self.process(col)) for col in select._distinct]
+ )+ ") "
+ else:
+ return "DISTINCT ON (" + unicode(select._distinct) + ") "
+ else:
+ return ""
+
+ def for_update_clause(self, select):
+ if select.for_update == 'nowait':
+ return " FOR UPDATE NOWAIT"
+ else:
+ return super(PGCompiler, self).for_update_clause(select)
+
+ def _append_returning(self, text, stmt):
+ returning_cols = stmt.kwargs['postgres_returning']
+ def flatten_columnlist(collist):
+ for c in collist:
+ if isinstance(c, expression.Selectable):
+ for co in c.columns:
+ yield co
+ else:
+ yield c
+ columns = [self.process(c, within_columns_clause=True) for c in flatten_columnlist(returning_cols)]
+ text += ' RETURNING ' + string.join(columns, ', ')
+ return text
+
+ def visit_update(self, update_stmt):
+ text = super(PGCompiler, self).visit_update(update_stmt)
+ if 'postgres_returning' in update_stmt.kwargs:
+ return self._append_returning(text, update_stmt)
+ else:
+ return text
+
+ def visit_insert(self, insert_stmt):
+ text = super(PGCompiler, self).visit_insert(insert_stmt)
+ if 'postgres_returning' in insert_stmt.kwargs:
+ return self._append_returning(text, insert_stmt)
+ else:
+ return text
+
+class PGDDLCompiler(compiler.DDLCompiler):
+ def get_column_specification(self, column, **kwargs):
+ colspec = self.preparer.format_column(column)
+ if column.primary_key and \
+ len(column.foreign_keys)==0 and \
+ column.autoincrement and \
+ isinstance(column.type, sqltypes.Integer) and \
+ not isinstance(column.type, sqltypes.SmallInteger) and \
+ (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
+ if isinstance(column.type, PGBigInteger):
+ colspec += " BIGSERIAL"
+ else:
+ colspec += " SERIAL"
else:
- return self._connection.connection.cursor()
+ colspec += " " + self.dialect.type_compiler.process(column.type)
+ default = self.get_column_default_string(column)
+ if default is not None:
+ colspec += " DEFAULT " + default
+
+ if not column.nullable:
+ colspec += " NOT NULL"
+ return colspec
+
+ def visit_create_sequence(self, create):
+ return "CREATE SEQUENCE %s" % self.preparer.format_sequence(create.element)
+
+ def visit_drop_sequence(self, drop):
+ return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element)
+
+ def visit_create_index(self, create):
+ preparer = self.preparer
+ index = create.element
+ text = "CREATE "
+ if index.unique:
+ text += "UNIQUE "
+ text += "INDEX %s ON %s (%s)" \
+ % (preparer.quote(self._validate_identifier(index.name, True), index.quote),
+ preparer.format_table(index.table),
+ string.join([preparer.format_column(c) for c in index.columns], ', '))
+
+ whereclause = index.kwargs.get('postgres_where', None)
+ if whereclause is not None:
+ compiler = self._compile(whereclause, None)
+ # this might belong to the compiler class
+ inlined_clause = str(compiler) % dict(
+ [(key,bind.value) for key,bind in compiler.binds.iteritems()])
+ text += " WHERE " + inlined_clause
+ return text
+
+class PGDefaultRunner(base.DefaultRunner):
+ def __init__(self, context):
+ base.DefaultRunner.__init__(self, context)
+ # craete cursor which won't conflict with a server-side cursor
+ self.cursor = context._connection.connection.cursor()
- def get_result_proxy(self):
- if self.__is_server_side:
- return base.BufferedRowResultProxy(self)
+ def get_column_default(self, column, isinsert=True):
+ if column.primary_key:
+ # pre-execute passive defaults on primary keys
+ if (isinstance(column.server_default, schema.DefaultClause) and
+ column.server_default.arg is not None):
+ return self.execute_string("select %s" % column.server_default.arg)
+ elif (isinstance(column.type, sqltypes.Integer) and column.autoincrement) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
+ sch = column.table.schema
+ # TODO: this has to build into the Sequence object so we can get the quoting
+ # logic from it
+ if sch is not None:
+ exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name)
+ else:
+ exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name)
+ return self.execute_string(exc.encode(self.dialect.encoding))
+
+ return super(PGDefaultRunner, self).get_column_default(column)
+
+ def visit_sequence(self, seq):
+ if not seq.optional:
+ return self.execute_string(("select nextval('%s')" % self.dialect.identifier_preparer.format_sequence(seq)))
+ else:
+ return None
+
+class PGTypeCompiler(compiler.GenericTypeCompiler):
+ def visit_INET(self, type_):
+ return "INET"
+
+ def visit_CIDR(self, type_):
+ return "CIDR"
+
+ def visit_MACADDR(self, type_):
+ return "MACADDR"
+
+ def visit_FLOAT(self, type_):
+ if not type_.precision:
+ return "FLOAT"
else:
- return base.ResultProxy(self)
+ return "FLOAT(%(precision)s)" % {'precision': type_.precision}
+
+ def visit_BIGINT(self, type_):
+ return "BIGINT"
+
+ def visit_DATETIME(self, type_):
+ return "TIMESTAMP " + (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
+
+ def visit_TIME(self, type_):
+ return "TIME " + (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
+
+ def visit_INTERVAL(self, type_):
+ return "INTERVAL"
+
+ def visit_BINARY(self, type_):
+ return "BYTEA"
+
+ def visit_ARRAY(self, type_):
+ return self.process(type_.item_type) + '[]'
+
+class PGIdentifierPreparer(compiler.IdentifierPreparer):
+ def _unquote_identifier(self, value):
+ if value[0] == self.initial_quote:
+ value = value[1:-1].replace('""','"')
+ return value
class PGDialect(default.DefaultDialect):
name = 'postgres'
supports_alter = True
- supports_unicode_statements = False
max_identifier_length = 63
supports_sane_rowcount = True
- supports_sane_multi_rowcount = False
+ supports_sequences = True
+ sequences_optional = True
preexecute_pk_sequences = True
supports_pk_autoincrement = False
- default_paramstyle = 'pyformat'
supports_default_values = True
supports_empty_insert = False
-
- def __init__(self, server_side_cursors=False, **kwargs):
- default.DefaultDialect.__init__(self, **kwargs)
- self.server_side_cursors = server_side_cursors
- def dbapi(cls):
- import psycopg2 as psycopg
- return psycopg
- dbapi = classmethod(dbapi)
+ statement_compiler = PGCompiler
+ ddl_compiler = PGDDLCompiler
+ type_compiler = PGTypeCompiler
+ preparer = PGIdentifierPreparer
+ defaultrunner = PGDefaultRunner
- def create_connect_args(self, url):
- opts = url.translate_connect_args(username='user')
- if 'port' in opts:
- opts['port'] = int(opts['port'])
- opts.update(url.query)
- return ([], opts)
-
- def type_descriptor(self, typeobj):
- return sqltypes.adapt_type(typeobj, colspecs)
def do_begin_twophase(self, connection, xid):
self.do_begin(connection.connection)
resultset = connection.execute(sql.text("SELECT gid FROM pg_prepared_xacts"))
return [row[0] for row in resultset]
+ @base.connection_memoize(('dialect', 'default_schema_name'))
def get_default_schema_name(self, connection):
return connection.scalar("select current_schema()", None)
- get_default_schema_name = base.connection_memoize(
- ('dialect', 'default_schema_name'))(get_default_schema_name)
-
- def last_inserted_ids(self):
- if self.context.last_inserted_ids is None:
- raise exc.InvalidRequestError("no INSERT executed, or can't use cursor.lastrowid without Postgres OIDs enabled")
- else:
- return self.context.last_inserted_ids
def has_table(self, connection, table_name, schema=None):
# seems like case gets folded in pg_class...
if schema is None:
- cursor = connection.execute("""select relname from pg_class c join pg_namespace n on n.oid=c.relnamespace where n.nspname=current_schema() and lower(relname)=%(name)s""", {'name':table_name.lower().encode(self.encoding)});
+ cursor = connection.execute(
+ sql.text("""select relname from pg_class c join pg_namespace n on n.oid=c.relnamespace where n.nspname=current_schema() and lower(relname)=:name""",
+ bindparams=[sql.bindparam('name', unicode(table_name.lower()), type_=sqltypes.Unicode)]
+ )
+ )
else:
- cursor = connection.execute("""select relname from pg_class c join pg_namespace n on n.oid=c.relnamespace where n.nspname=%(schema)s and lower(relname)=%(name)s""", {'name':table_name.lower().encode(self.encoding), 'schema':schema});
- return bool( not not cursor.rowcount )
+ cursor = connection.execute(
+ sql.text("""select relname from pg_class c join pg_namespace n on n.oid=c.relnamespace where n.nspname=:schema and lower(relname)=:name""",
+ bindparams=[sql.bindparam('name', unicode(table_name.lower()), type_=sqltypes.Unicode),
+ sql.bindparam('schema', unicode(schema), type_=sqltypes.Unicode)]
+ )
+ )
+ return bool( cursor.rowcount )
def has_sequence(self, connection, sequence_name):
- cursor = connection.execute('''SELECT relname FROM pg_class WHERE relkind = 'S' AND relnamespace IN ( SELECT oid FROM pg_namespace WHERE nspname NOT LIKE 'pg_%%' AND nspname != 'information_schema' AND relname = %(seqname)s);''', {'seqname': sequence_name.encode(self.encoding)})
- return bool(not not cursor.rowcount)
-
- def is_disconnect(self, e):
- if isinstance(e, self.dbapi.OperationalError):
- return 'closed the connection' in str(e) or 'connection not open' in str(e)
- elif isinstance(e, self.dbapi.InterfaceError):
- return 'connection already closed' in str(e) or 'cursor already closed' in str(e)
- elif isinstance(e, self.dbapi.ProgrammingError):
- # yes, it really says "losed", not "closed"
- return "losed the connection unexpectedly" in str(e)
- else:
- return False
+ cursor = connection.execute(
+ sql.text("SELECT relname FROM pg_class WHERE relkind = 'S' AND "
+ "relnamespace IN ( SELECT oid FROM pg_namespace WHERE nspname NOT LIKE 'pg_%' "
+ "AND nspname != 'information_schema' AND relname = :seqname)",
+ bindparams=[sql.bindparam('seqname', unicode(sequence_name), type_=sqltypes.Unicode)]
+ ))
+ return bool(cursor.rowcount)
def table_names(self, connection, schema):
- s = """
- SELECT relname
- FROM pg_class c
- WHERE relkind = 'r'
- AND '%(schema)s' = (select nspname from pg_namespace n where n.oid = c.relnamespace)
- """ % locals()
- return [row[0].decode(self.encoding) for row in connection.execute(s)]
+ result = connection.execute(
+ sql.text(u"""SELECT relname
+ FROM pg_class c
+ WHERE relkind = 'r'
+ AND '%s' = (select nspname from pg_namespace n where n.oid = c.relnamespace)""" % schema,
+ typemap = {'relname':sqltypes.Unicode}
+ )
+ )
+ return [row[0] for row in result]
def server_version_info(self, connection):
v = connection.execute("select version()").scalar()
elif attype == 'timestamp without time zone':
kwargs['timezone'] = False
- if attype in ischema_names:
- coltype = ischema_names[attype]
+ if attype in self.ischema_names:
+ coltype = self.ischema_names[attype]
else:
if attype in domains:
domain = domains[attype]
- if domain['attype'] in ischema_names:
+ if domain['attype'] in self.ischema_names:
# A table can't override whether the domain is nullable.
nullable = domain['nullable']
if domain['default'] and not default:
# It can, however, override the default value, but can't set it to null.
default = domain['default']
- coltype = ischema_names[domain['attype']]
+ coltype = self.ischema_names[domain['attype']]
else:
coltype = None
return domains
-
-
-class PGCompiler(compiler.DefaultCompiler):
- operators = compiler.DefaultCompiler.operators.copy()
- operators.update(
- {
- sql_operators.mod : '%%',
- sql_operators.ilike_op: lambda x, y, escape=None: '%s ILIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
- sql_operators.notilike_op: lambda x, y, escape=None: '%s NOT ILIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
- sql_operators.match_op: lambda x, y: '%s @@ to_tsquery(%s)' % (x, y),
- }
- )
-
- functions = compiler.DefaultCompiler.functions.copy()
- functions.update (
- {
- 'TIMESTAMP':lambda x:'TIMESTAMP %s' % x,
- }
- )
-
- def visit_sequence(self, seq):
- if seq.optional:
- return None
- else:
- return "nextval('%s')" % self.preparer.format_sequence(seq)
-
- def post_process_text(self, text):
- if '%%' in text:
- util.warn("The SQLAlchemy psycopg2 dialect now automatically escapes '%' in text() expressions to '%%'.")
- return text.replace('%', '%%')
-
- def limit_clause(self, select):
- text = ""
- if select._limit is not None:
- text += " \n LIMIT " + str(select._limit)
- if select._offset is not None:
- if select._limit is None:
- text += " \n LIMIT ALL"
- text += " OFFSET " + str(select._offset)
- return text
-
- def get_select_precolumns(self, select):
- if select._distinct:
- if isinstance(select._distinct, bool):
- return "DISTINCT "
- elif isinstance(select._distinct, (list, tuple)):
- return "DISTINCT ON (" + ', '.join(
- [(isinstance(col, basestring) and col or self.process(col)) for col in select._distinct]
- )+ ") "
- else:
- return "DISTINCT ON (" + unicode(select._distinct) + ") "
- else:
- return ""
-
- def for_update_clause(self, select):
- if select.for_update == 'nowait':
- return " FOR UPDATE NOWAIT"
- else:
- return super(PGCompiler, self).for_update_clause(select)
-
- def _append_returning(self, text, stmt):
- returning_cols = stmt.kwargs['postgres_returning']
- def flatten_columnlist(collist):
- for c in collist:
- if isinstance(c, expression.Selectable):
- for co in c.columns:
- yield co
- else:
- yield c
- columns = [self.process(c, within_columns_clause=True) for c in flatten_columnlist(returning_cols)]
- text += ' RETURNING ' + string.join(columns, ', ')
- return text
-
- def visit_update(self, update_stmt):
- text = super(PGCompiler, self).visit_update(update_stmt)
- if 'postgres_returning' in update_stmt.kwargs:
- return self._append_returning(text, update_stmt)
- else:
- return text
-
- def visit_insert(self, insert_stmt):
- text = super(PGCompiler, self).visit_insert(insert_stmt)
- if 'postgres_returning' in insert_stmt.kwargs:
- return self._append_returning(text, insert_stmt)
- else:
- return text
-
-class PGSchemaGenerator(compiler.SchemaGenerator):
- def get_column_specification(self, column, **kwargs):
- colspec = self.preparer.format_column(column)
- if column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and isinstance(column.type, sqltypes.Integer) and not isinstance(column.type, sqltypes.SmallInteger) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
- if isinstance(column.type, PGBigInteger):
- colspec += " BIGSERIAL"
- else:
- colspec += " SERIAL"
- else:
- colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
- default = self.get_column_default_string(column)
- if default is not None:
- colspec += " DEFAULT " + default
-
- if not column.nullable:
- colspec += " NOT NULL"
- return colspec
-
- def visit_sequence(self, sequence):
- if not sequence.optional and (not self.checkfirst or not self.dialect.has_sequence(self.connection, sequence.name)):
- self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence))
- self.execute()
-
- def visit_index(self, index):
- preparer = self.preparer
- self.append("CREATE ")
- if index.unique:
- self.append("UNIQUE ")
- self.append("INDEX %s ON %s (%s)" \
- % (preparer.quote(self._validate_identifier(index.name, True), index.quote),
- preparer.format_table(index.table),
- string.join([preparer.format_column(c) for c in index.columns], ', ')))
- whereclause = index.kwargs.get('postgres_where', None)
- if whereclause is not None:
- compiler = self._compile(whereclause, None)
- # this might belong to the compiler class
- inlined_clause = str(compiler) % dict(
- [(key,bind.value) for key,bind in compiler.binds.iteritems()])
- self.append(" WHERE " + inlined_clause)
- self.execute()
-
-class PGSchemaDropper(compiler.SchemaDropper):
- def visit_sequence(self, sequence):
- if not sequence.optional and (not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name)):
- self.append("DROP SEQUENCE %s" % self.preparer.format_sequence(sequence))
- self.execute()
-
-class PGDefaultRunner(base.DefaultRunner):
- def __init__(self, context):
- base.DefaultRunner.__init__(self, context)
- # craete cursor which won't conflict with a server-side cursor
- self.cursor = context._connection.connection.cursor()
-
- def get_column_default(self, column, isinsert=True):
- if column.primary_key:
- # pre-execute passive defaults on primary keys
- if (isinstance(column.server_default, schema.DefaultClause) and
- column.server_default.arg is not None):
- return self.execute_string("select %s" % column.server_default.arg)
- elif (isinstance(column.type, sqltypes.Integer) and column.autoincrement) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
- sch = column.table.schema
- # TODO: this has to build into the Sequence object so we can get the quoting
- # logic from it
- if sch is not None:
- exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name)
- else:
- exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name)
- return self.execute_string(exc.encode(self.dialect.encoding))
-
- return super(PGDefaultRunner, self).get_column_default(column)
-
- def visit_sequence(self, seq):
- if not seq.optional:
- return self.execute_string(("select nextval('%s')" % self.dialect.identifier_preparer.format_sequence(seq)))
- else:
- return None
-
-class PGIdentifierPreparer(compiler.IdentifierPreparer):
- def _unquote_identifier(self, value):
- if value[0] == self.initial_quote:
- value = value[1:-1].replace('""','"')
- return value
-
-dialect = PGDialect
-dialect.statement_compiler = PGCompiler
-dialect.schemagenerator = PGSchemaGenerator
-dialect.schemadropper = PGSchemaDropper
-dialect.preparer = PGIdentifierPreparer
-dialect.defaultrunner = PGDefaultRunner
-dialect.execution_ctx_cls = PGExecutionContext
--- /dev/null
+"""Support for the PostgreSQL database via the psycopg2 driver.
+
+Driver
+------
+
+The psycopg2 driver is supported, available at http://pypi.python.org/pypi/psycopg2/ .
+The dialect has several behaviors which are specifically tailored towards compatibility
+with this module.
+
+Note that psycopg1 is **not** supported.
+
+Connecting
+----------
+
+URLs are of the form `postgres+psycopg2://user@password@host:port/dbname[?key=value&key=value...]`.
+
+psycopg2-specific keyword arguments which are accepted by :func:`~sqlalchemy.create_engine()` are:
+
+* *server_side_cursors* - Enable the usage of "server side cursors" for SQL statements which support
+ this feature. What this essentially means from a psycopg2 point of view is that the cursor is
+ created using a name, e.g. `connection.cursor('some name')`, which has the effect that result rows
+ are not immediately pre-fetched and buffered after statement execution, but are instead left
+ on the server and only retrieved as needed. SQLAlchemy's :class:`~sqlalchemy.engine.base.ResultProxy`
+ uses special row-buffering behavior when this feature is enabled, such that groups of 100 rows
+ at a time are fetched over the wire to reduce conversational overhead.
+
+Sequences/SERIAL
+----------------
+
+Postgres supports sequences, and SQLAlchemy uses these as the default means of creating
+new primary key values for integer-based primary key columns. When creating tables,
+SQLAlchemy will issue the ``SERIAL`` datatype for integer-based primary key columns,
+which generates a sequence corresponding to the column and associated with it based on
+a naming convention.
+
+To specify a specific named sequence to be used for primary key generation, use the
+:func:`~sqlalchemy.schema.Sequence` construct::
+
+ Table('sometable', metadata,
+ Column('id', Integer, Sequence('some_id_seq'), primary_key=True)
+ )
+
+Currently, when SQLAlchemy issues a single insert statement, to fulfill the contract of
+having the "last insert identifier" available, the sequence is executed independently
+beforehand and the new value is retrieved, to be used in the subsequent insert. Note
+that when an :func:`~sqlalchemy.sql.expression.insert()` construct is executed using
+"executemany" semantics, the sequence is not pre-executed and normal PG SERIAL behavior
+is used.
+
+Postgres 8.3 supports an ``INSERT...RETURNING`` syntax which SQLAlchemy supports
+as well. A future release of SQLA will use this feature by default in lieu of
+sequence pre-execution in order to retrieve new primary key values, when available.
+
+INSERT/UPDATE...RETURNING
+-------------------------
+
+The dialect supports PG 8.3's ``INSERT..RETURNING`` and ``UPDATE..RETURNING`` syntaxes,
+but must be explicitly enabled on a per-statement basis::
+
+ # INSERT..RETURNING
+ result = table.insert(postgres_returning=[table.c.col1, table.c.col2]).\\
+ values(name='foo')
+ print result.fetchall()
+
+ # UPDATE..RETURNING
+ result = table.update(postgres_returning=[table.c.col1, table.c.col2]).\\
+ where(table.c.name=='foo').values(name='bar')
+ print result.fetchall()
+
+Indexes
+-------
+
+PostgreSQL supports partial indexes. To create them pass a postgres_where
+option to the Index constructor::
+
+ Index('my_index', my_table.c.id, postgres_where=tbl.c.value > 10)
+
+Transactions
+------------
+
+The psycopg2 dialect fully supports SAVEPOINT and two-phase commit operations.
+
+
+"""
+
+import decimal, random, re, string
+
+from sqlalchemy import sql, schema, exc, util
+from sqlalchemy.engine import base, default
+from sqlalchemy.sql import compiler, expression
+from sqlalchemy.sql import operators as sql_operators
+from sqlalchemy import types as sqltypes
+from sqlalchemy.dialects.postgres.base import PGDialect, PGInet, PGCidr, PGMacAddr, PGArray, \
+ PGBigInteger, PGInterval
+
+class PGNumeric(sqltypes.Numeric):
+ def bind_processor(self, dialect):
+ return None
+
+ def result_processor(self, dialect):
+ if self.asdecimal:
+ return None
+ else:
+ def process(value):
+ if isinstance(value, decimal.Decimal):
+ return float(value)
+ else:
+ return value
+ return process
+
+
+colspecs = {
+ sqltypes.Numeric : PGNumeric,
+ sqltypes.Float: sqltypes.Float, # prevents PGNumeric from being used
+}
+
+ischema_names = {
+ 'integer' : sqltypes.Integer,
+ 'bigint' : PGBigInteger,
+ 'smallint' : sqltypes.SmallInteger,
+ 'character varying' : sqltypes.String,
+ 'character' : sqltypes.CHAR,
+ 'text' : sqltypes.Text,
+ 'numeric' : PGNumeric,
+ 'float' : sqltypes.Float,
+ 'real' : sqltypes.Float,
+ 'inet': PGInet,
+ 'cidr': PGCidr,
+ 'macaddr': PGMacAddr,
+ 'double precision' : sqltypes.Float,
+ 'timestamp' : sqltypes.DateTime,
+ 'timestamp with time zone' : sqltypes.DateTime,
+ 'timestamp without time zone' : sqltypes.DateTime,
+ 'time with time zone' : sqltypes.Time,
+ 'time without time zone' : sqltypes.Time,
+ 'date' : sqltypes.Date,
+ 'time': sqltypes.Time,
+ 'bytea' : sqltypes.Binary,
+ 'boolean' : sqltypes.Boolean,
+ 'interval':PGInterval,
+}
+
+# TODO: filter out 'FOR UPDATE' statements
+SERVER_SIDE_CURSOR_RE = re.compile(
+ r'\s*SELECT',
+ re.I | re.UNICODE)
+
+class Postgres_psycopg2ExecutionContext(default.DefaultExecutionContext):
+ def create_cursor(self):
+ # TODO: coverage for server side cursors + select.for_update()
+ is_server_side = \
+ self.dialect.server_side_cursors and \
+ ((self.compiled and isinstance(self.compiled.statement, expression.Selectable)
+ and not getattr(self.compiled.statement, 'for_update', False)) \
+ or \
+ (
+ (not self.compiled or isinstance(self.compiled.statement, expression._TextClause))
+ and self.statement and SERVER_SIDE_CURSOR_RE.match(self.statement))
+ )
+
+ self.__is_server_side = is_server_side
+ if is_server_side:
+ # use server-side cursors:
+ # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html
+ ident = "c_%s_%s" % (hex(id(self))[2:], hex(random.randint(0, 65535))[2:])
+ return self._connection.connection.cursor(ident)
+ else:
+ return self._connection.connection.cursor()
+
+ def get_result_proxy(self):
+ if self.__is_server_side:
+ return base.BufferedRowResultProxy(self)
+ else:
+ return base.ResultProxy(self)
+
+class Postgres_psycopg2(PGDialect):
+ driver = 'psycopg2'
+ supports_unicode_statements = False
+ default_paramstyle = 'pyformat'
+ supports_sane_multi_rowcount = False
+ execution_ctx_cls = Postgres_psycopg2ExecutionContext
+ ischema_names = ischema_names
+
+ def __init__(self, server_side_cursors=False, **kwargs):
+ PGDialect.__init__(self, **kwargs)
+ self.server_side_cursors = server_side_cursors
+
+ @classmethod
+ def dbapi(cls):
+ psycopg = __import__('psycopg2')
+ return psycopg
+
+ def create_connect_args(self, url):
+ opts = url.translate_connect_args(username='user')
+ if 'port' in opts:
+ opts['port'] = int(opts['port'])
+ opts.update(url.query)
+ return ([], opts)
+
+ def type_descriptor(self, typeobj):
+ return sqltypes.adapt_type(typeobj, colspecs)
+
+ def is_disconnect(self, e):
+ if isinstance(e, self.dbapi.OperationalError):
+ return 'closed the connection' in str(e) or 'connection not open' in str(e)
+ elif isinstance(e, self.dbapi.InterfaceError):
+ return 'connection already closed' in str(e) or 'cursor already closed' in str(e)
+ elif isinstance(e, self.dbapi.ProgrammingError):
+ # yes, it really says "losed", not "closed"
+ return "losed the connection unexpectedly" in str(e)
+ else:
+ return False
+
+dialect = Postgres_psycopg2
+
\ No newline at end of file
--- /dev/null
+from sqlalchemy.dialects.sqlite import base, pysqlite
+
+# default dialect
+base.dialect = pysqlite.dialect
\ No newline at end of file
--- /dev/null
+# sqlite.py
+# Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+import datetime, re, time
+
+from sqlalchemy import sql, schema, exc, pool, DefaultClause
+from sqlalchemy.engine import default
+from sqlalchemy.engine import default
+from sqlalchemy import types as sqltypes
+from sqlalchemy import util
+from sqlalchemy.sql import compiler, functions as sql_functions
+from types import NoneType
+
+class NumericMixin(object):
+ def bind_processor(self, dialect):
+ type_ = self.asdecimal and str or float
+ def process(value):
+ if value is not None:
+ return type_(value)
+ else:
+ return value
+ return process
+
+class SLNumeric(NumericMixin, sqltypes.Numeric):
+ pass
+
+class SLFloat(NumericMixin, sqltypes.Float):
+ pass
+
+# since SQLite has no date types, we're assuming that SQLite via ODBC
+# or JDBC would similarly have no built in date support, so the "string" based logic
+# would apply to all implementing dialects.
+class DateTimeMixin(object):
+ def _bind_processor(self, format, elements):
+ def process(value):
+ if not isinstance(value, (NoneType, datetime.date, datetime.datetime, datetime.time)):
+ raise TypeError("SQLite Date, Time, and DateTime types only accept Python datetime objects as input.")
+ elif value is not None:
+ return format % tuple([getattr(value, attr, 0) for attr in elements])
+ else:
+ return None
+ return process
+
+ def _result_processor(self, fn, regexp):
+ def process(value):
+ if value is not None:
+ return fn(*[int(x or 0) for x in regexp.match(value).groups()])
+ else:
+ return None
+ return process
+
+class SLDateTime(DateTimeMixin, sqltypes.DateTime):
+ __legacy_microseconds__ = False
+
+ def bind_processor(self, dialect):
+ if self.__legacy_microseconds__:
+ return self._bind_processor(
+ "%4.4d-%2.2d-%2.2d %2.2d:%2.2d:%2.2d.%s",
+ ("year", "month", "day", "hour", "minute", "second", "microsecond")
+ )
+ else:
+ return self._bind_processor(
+ "%4.4d-%2.2d-%2.2d %2.2d:%2.2d:%2.2d.%06d",
+ ("year", "month", "day", "hour", "minute", "second", "microsecond")
+ )
+
+ _reg = re.compile(r"(\d+)-(\d+)-(\d+)(?: (\d+):(\d+):(\d+)(?:\.(\d+))?)?")
+ def result_processor(self, dialect):
+ return self._result_processor(datetime.datetime, self._reg)
+
+class SLDate(DateTimeMixin, sqltypes.Date):
+ def bind_processor(self, dialect):
+ return self._bind_processor(
+ "%4.4d-%2.2d-%2.2d",
+ ("year", "month", "day")
+ )
+
+ _reg = re.compile(r"(\d+)-(\d+)-(\d+)")
+ def result_processor(self, dialect):
+ return self._result_processor(datetime.date, self._reg)
+
+class SLTime(DateTimeMixin, sqltypes.Time):
+ __legacy_microseconds__ = False
+
+ def bind_processor(self, dialect):
+ if self.__legacy_microseconds__:
+ return self._bind_processor(
+ "%2.2d:%2.2d:%2.2d.%s",
+ ("hour", "minute", "second", "microsecond")
+ )
+ else:
+ return self._bind_processor(
+ "%2.2d:%2.2d:%2.2d.%06d",
+ ("hour", "minute", "second", "microsecond")
+ )
+
+ _reg = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d+))?")
+ def result_processor(self, dialect):
+ return self._result_processor(datetime.time, self._reg)
+
+
+class SLBoolean(sqltypes.Boolean):
+ def bind_processor(self, dialect):
+ def process(value):
+ if value is None:
+ return None
+ return value and 1 or 0
+ return process
+
+ def result_processor(self, dialect):
+ def process(value):
+ if value is None:
+ return None
+ return value and True or False
+ return process
+
+
+class SQLiteCompiler(compiler.SQLCompiler):
+ functions = compiler.SQLCompiler.functions.copy()
+ functions.update (
+ {
+ sql_functions.now: 'CURRENT_TIMESTAMP',
+ sql_functions.char_length: 'length%(expr)s'
+ }
+ )
+
+ def visit_cast(self, cast, **kwargs):
+ if self.dialect.supports_cast:
+ return super(SQLiteCompiler, self).visit_cast(cast)
+ else:
+ return self.process(cast.clause)
+
+ def limit_clause(self, select):
+ text = ""
+ if select._limit is not None:
+ text += " \n LIMIT " + str(select._limit)
+ if select._offset is not None:
+ if select._limit is None:
+ text += " \n LIMIT -1"
+ text += " OFFSET " + str(select._offset)
+ else:
+ text += " OFFSET 0"
+ return text
+
+ def for_update_clause(self, select):
+ # sqlite has no "FOR UPDATE" AFAICT
+ return ''
+
+
+class SQLiteDDLCompiler(compiler.DDLCompiler):
+
+ def get_column_specification(self, column, **kwargs):
+ colspec = self.preparer.format_column(column) + " " + self.dialect.type_compiler.process(column.type)
+ default = self.get_column_default_string(column)
+ if default is not None:
+ colspec += " DEFAULT " + default
+
+ if not column.nullable:
+ colspec += " NOT NULL"
+ return colspec
+
+class SQLiteTypeCompiler(compiler.GenericTypeCompiler):
+ def visit_binary(self, type_):
+ return self.visit_BLOB(type_)
+
+ def visit_CLOB(self, type_):
+ return self.visit_TEXT(type_)
+
+ def visit_NCHAR(self, type_):
+ return self.visit_CHAR(type_)
+
+class SQLiteIdentifierPreparer(compiler.IdentifierPreparer):
+ reserved_words = set([
+ 'add', 'after', 'all', 'alter', 'analyze', 'and', 'as', 'asc',
+ 'attach', 'autoincrement', 'before', 'begin', 'between', 'by',
+ 'cascade', 'case', 'cast', 'check', 'collate', 'column', 'commit',
+ 'conflict', 'constraint', 'create', 'cross', 'current_date',
+ 'current_time', 'current_timestamp', 'database', 'default',
+ 'deferrable', 'deferred', 'delete', 'desc', 'detach', 'distinct',
+ 'drop', 'each', 'else', 'end', 'escape', 'except', 'exclusive',
+ 'explain', 'false', 'fail', 'for', 'foreign', 'from', 'full', 'glob',
+ 'group', 'having', 'if', 'ignore', 'immediate', 'in', 'index',
+ 'initially', 'inner', 'insert', 'instead', 'intersect', 'into', 'is',
+ 'isnull', 'join', 'key', 'left', 'like', 'limit', 'match', 'natural',
+ 'not', 'notnull', 'null', 'of', 'offset', 'on', 'or', 'order', 'outer',
+ 'plan', 'pragma', 'primary', 'query', 'raise', 'references',
+ 'reindex', 'rename', 'replace', 'restrict', 'right', 'rollback',
+ 'row', 'select', 'set', 'table', 'temp', 'temporary', 'then', 'to',
+ 'transaction', 'trigger', 'true', 'union', 'unique', 'update', 'using',
+ 'vacuum', 'values', 'view', 'virtual', 'when', 'where',
+ ])
+
+class SQLiteDialect(default.DefaultDialect):
+ name = 'sqlite'
+ supports_alter = False
+ supports_unicode_statements = True
+ supports_default_values = True
+ supports_empty_insert = False
+ supports_cast = True
+ statement_compiler = SQLiteCompiler
+ ddl_compiler = SQLiteDDLCompiler
+ type_compiler = SQLiteTypeCompiler
+ preparer = SQLiteIdentifierPreparer
+
+ def table_names(self, connection, schema):
+ if schema is not None:
+ qschema = self.identifier_preparer.quote_identifier(schema)
+ master = '%s.sqlite_master' % qschema
+ s = ("SELECT name FROM %s "
+ "WHERE type='table' ORDER BY name") % (master,)
+ rs = connection.execute(s)
+ else:
+ try:
+ s = ("SELECT name FROM "
+ " (SELECT * FROM sqlite_master UNION ALL "
+ " SELECT * FROM sqlite_temp_master) "
+ "WHERE type='table' ORDER BY name")
+ rs = connection.execute(s)
+ except exc.DBAPIError:
+ raise
+ s = ("SELECT name FROM sqlite_master "
+ "WHERE type='table' ORDER BY name")
+ rs = connection.execute(s)
+
+ return [row[0] for row in rs]
+
+ def has_table(self, connection, table_name, schema=None):
+ quote = self.identifier_preparer.quote_identifier
+ if schema is not None:
+ pragma = "PRAGMA %s." % quote(schema)
+ else:
+ pragma = "PRAGMA "
+ qtable = quote(table_name)
+ cursor = connection.execute("%stable_info(%s)" % (pragma, qtable))
+ row = cursor.fetchone()
+
+ # consume remaining rows, to work around
+ # http://www.sqlite.org/cvstrac/tktview?tn=1884
+ while cursor.fetchone() is not None:
+ pass
+
+ return (row is not None)
+
+ def reflecttable(self, connection, table, include_columns):
+ preparer = self.identifier_preparer
+ if table.schema is None:
+ pragma = "PRAGMA "
+ else:
+ pragma = "PRAGMA %s." % preparer.quote_identifier(table.schema)
+ qtable = preparer.format_table(table, False)
+
+ c = connection.execute("%stable_info(%s)" % (pragma, qtable))
+ found_table = False
+ while True:
+ row = c.fetchone()
+ if row is None:
+ break
+
+ found_table = True
+ (name, type_, nullable, default, has_default, primary_key) = (row[1], row[2].upper(), not row[3], row[4], row[4] is not None, row[5])
+ name = re.sub(r'^\"|\"$', '', name)
+ if include_columns and name not in include_columns:
+ continue
+ match = re.match(r'(\w+)(\(.*?\))?', type_)
+ if match:
+ coltype = match.group(1)
+ args = match.group(2)
+ else:
+ coltype = "VARCHAR"
+ args = ''
+
+ try:
+ coltype = self.ischema_names[coltype]
+ except KeyError:
+ util.warn("Did not recognize type '%s' of column '%s'" %
+ (coltype, name))
+ coltype = sqltypes.NullType
+
+ if args is not None:
+ args = re.findall(r'(\d+)', args)
+ coltype = coltype(*[int(a) for a in args])
+
+ colargs = []
+ if has_default:
+ colargs.append(DefaultClause(sql.text(default)))
+ table.append_column(schema.Column(name, coltype, primary_key = primary_key, nullable = nullable, *colargs))
+
+ if not found_table:
+ raise exc.NoSuchTableError(table.name)
+
+ c = connection.execute("%sforeign_key_list(%s)" % (pragma, qtable))
+ fks = {}
+ while True:
+ row = c.fetchone()
+ if row is None:
+ break
+ (constraint_name, tablename, localcol, remotecol) = (row[0], row[2], row[3], row[4])
+ tablename = re.sub(r'^\"|\"$', '', tablename)
+ localcol = re.sub(r'^\"|\"$', '', localcol)
+ remotecol = re.sub(r'^\"|\"$', '', remotecol)
+ try:
+ fk = fks[constraint_name]
+ except KeyError:
+ fk = ([], [])
+ fks[constraint_name] = fk
+
+ # look up the table based on the given table's engine, not 'self',
+ # since it could be a ProxyEngine
+ remotetable = schema.Table(tablename, table.metadata, autoload=True, autoload_with=connection)
+ constrained_column = table.c[localcol].name
+ refspec = ".".join([tablename, remotecol])
+ if constrained_column not in fk[0]:
+ fk[0].append(constrained_column)
+ if refspec not in fk[1]:
+ fk[1].append(refspec)
+ for name, value in fks.iteritems():
+ table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], link_to_name=True))
+ # check for UNIQUE indexes
+ c = connection.execute("%sindex_list(%s)" % (pragma, qtable))
+ unique_indexes = []
+ while True:
+ row = c.fetchone()
+ if row is None:
+ break
+ if (row[2] == 1):
+ unique_indexes.append(row[1])
+ # loop thru unique indexes for one that includes the primary key
+ for idx in unique_indexes:
+ c = connection.execute("%sindex_info(%s)" % (pragma, idx))
+ cols = []
+ while True:
+ row = c.fetchone()
+ if row is None:
+ break
+ cols.append(row[2])
+
--- /dev/null
+"""Support for the SQLite database via pysqlite.
+
+Driver
+------
+
+When using Python 2.5 and above, the built in ``sqlite3`` driver is
+already installed and no additional installation is needed. Otherwise,
+the ``pysqlite2`` driver needs to be present. This is the same driver as
+``sqlite3``, just with a different name.
+
+The ``pysqlite2`` driver will be loaded first, and if not found, ``sqlite3``
+is loaded. This allows an explicitly installed pysqlite driver to take
+precedence over the built in one. As with all dialects, a specific
+DBAPI module may be provided to :func:`~sqlalchemy.create_engine()` to control
+this explicitly::
+
+ from sqlite3 import dbapi2 as sqlite
+ e = create_engine('sqlite+pysqlite:///file.db', module=sqlite)
+
+Full documentation on pysqlite is available at:
+`<http://www.initd.org/pub/software/pysqlite/doc/usage-guide.html>`_
+
+Connect Strings
+---------------
+
+The file specification for the SQLite database is taken as the "database" portion of
+the URL. Note that the format of a url is::
+
+ driver://user:pass@host/database
+
+This means that the actual filename to be used starts with the characters to the
+**right** of the third slash. So connecting to a relative filepath looks like::
+
+ # relative path
+ e = create_engine('sqlite:///path/to/database.db')
+
+An absolute path, which is denoted by starting with a slash, means you need **four**
+slashes::
+
+ # absolute path
+ e = create_engine('sqlite:////path/to/database.db')
+
+To use a Windows path, regular drive specifications and backslashes can be used.
+Double backslashes are probably needed::
+
+ # absolute path on Windows
+ e = create_engine('sqlite:///C:\\\\path\\\\to\\\\database.db')
+
+The sqlite ``:memory:`` identifier is the default if no filepath is present. Specify
+``sqlite://`` and nothing else::
+
+ # in-memory database
+ e = create_engine('sqlite://')
+
+Threading Behavior
+------------------
+
+Pysqlite connections do not support being moved between threads, unless
+the ``check_same_thread`` Pysqlite flag is set to ``False``. In addition,
+when using an in-memory SQLite database, the full database exists only within
+the scope of a single connection. It is reported that an in-memory
+database does not support being shared between threads regardless of the
+``check_same_thread`` flag - which means that a multithreaded
+application **cannot** share data from a ``:memory:`` database across threads
+unless access to the connection is limited to a single worker thread which communicates
+through a queueing mechanism to concurrent threads.
+
+To provide a default which accomodates SQLite's default threading capabilities
+somewhat reasonably, the SQLite dialect will specify that the :class:`~sqlalchemy.pool.SingletonThreadPool`
+be used by default. This pool maintains a single SQLite connection per thread
+that is held open up to a count of five concurrent threads. When more than five threads
+are used, a cleanup mechanism will dispose of excess unused connections.
+
+Two optional pool implementations that may be appropriate for particular SQLite usage scenarios:
+
+ * the :class:`sqlalchemy.pool.StaticPool` might be appropriate for a multithreaded
+ application using an in-memory database, assuming the threading issues inherent in
+ pysqlite are somehow accomodated for. This pool holds persistently onto a single connection
+ which is never closed, and is returned for all requests.
+
+ * the :class:`sqlalchemy.pool.NullPool` might be appropriate for an application that
+ makes use of a file-based sqlite database. This pool disables any actual "pooling"
+ behavior, and simply opens and closes real connections corresonding to the :func:`connect()`
+ and :func:`close()` methods. SQLite can "connect" to a particular file with very high
+ efficiency, so this option may actually perform better without the extra overhead
+ of :class:`SingletonThreadPool`. NullPool will of course render a ``:memory:`` connection
+ useless since the database would be lost as soon as the connection is "returned" to the pool.
+
+Date and Time Types
+-------------------
+
+SQLite does not have built-in DATE, TIME, or DATETIME types, and pysqlite does not provide
+out of the box functionality for translating values between Python `datetime` objects
+and a SQLite-supported format. SQLAlchemy's own :class:`~sqlalchemy.types.DateTime`
+and related types provide date formatting and parsing functionality when SQlite is used.
+The implementation classes are :class:`SLDateTime`, :class:`SLDate` and :class:`SLTime`.
+These types represent dates and times as ISO formatted strings, which also nicely
+support ordering. There's no reliance on typical "libc" internals for these functions
+so historical dates are fully supported.
+
+Unicode
+-------
+
+In contrast to SQLAlchemy's active handling of date and time types for pysqlite, pysqlite's
+default behavior regarding Unicode is that all strings are returned as Python unicode objects
+in all cases. So even if the :class:`~sqlalchemy.types.Unicode` type is
+*not* used, you will still always receive unicode data back from a result set. It is
+**strongly** recommended that you do use the :class:`~sqlalchemy.types.Unicode` type
+to represent strings, since it will raise a warning if a non-unicode Python string is
+passed from the user application. Mixing the usage of non-unicode objects with returned unicode objects can
+quickly create confusion, particularly when using the ORM as internal data is not
+always represented by an actual database result string.
+
+"""
+
+from sqlalchemy.dialects.sqlite.base import SLNumeric, SLFloat, SQLiteDialect, SLBoolean, SLDate, SLDateTime, SLTime
+from sqlalchemy import schema, exc, pool
+from sqlalchemy.engine import default
+from sqlalchemy import types as sqltypes
+from sqlalchemy import util
+from types import NoneType
+
+class SLUnicodeMixin(object):
+ def bind_processor(self, dialect):
+ if self.convert_unicode or dialect.convert_unicode:
+ if self.assert_unicode is None:
+ assert_unicode = dialect.assert_unicode
+ else:
+ assert_unicode = self.assert_unicode
+
+ if not assert_unicode:
+ return None
+
+ def process(value):
+ if not isinstance(value, (unicode, NoneType)):
+ if assert_unicode == 'warn':
+ util.warn("Unicode type received non-unicode bind "
+ "param value %r" % value)
+ return value
+ else:
+ raise exc.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value)
+ else:
+ return value
+ return process
+ else:
+ return None
+
+ def result_processor(self, dialect):
+ return None
+
+class SLText(SLUnicodeMixin, sqltypes.Text):
+ pass
+
+class SLString(SLUnicodeMixin, sqltypes.String):
+ pass
+
+class SLChar(SLUnicodeMixin, sqltypes.CHAR):
+ pass
+
+
+colspecs = {
+ sqltypes.Boolean: SLBoolean,
+ sqltypes.CHAR: SLChar,
+ sqltypes.Date: SLDate,
+ sqltypes.DateTime: SLDateTime,
+ sqltypes.Float: SLFloat,
+ sqltypes.NCHAR: SLChar,
+ sqltypes.Numeric: SLNumeric,
+ sqltypes.String: SLString,
+ sqltypes.Text: SLText,
+ sqltypes.Time: SLTime,
+}
+
+ischema_names = {
+ 'BLOB': sqltypes.Binary,
+ 'BOOL': SLBoolean,
+ 'BOOLEAN': SLBoolean,
+ 'CHAR': SLChar,
+ 'DATE': SLDate,
+ 'DATETIME': SLDateTime,
+ 'DECIMAL': SLNumeric,
+ 'FLOAT': SLNumeric,
+ 'INT': sqltypes.Integer,
+ 'INTEGER': sqltypes.Integer,
+ 'NUMERIC': SLNumeric,
+ 'REAL': SLNumeric,
+ 'SMALLINT': sqltypes.SmallInteger,
+ 'TEXT': SLText,
+ 'TIME': SLTime,
+ 'TIMESTAMP': SLDateTime,
+ 'VARCHAR': SLString,
+}
+
+
+class SQLite_pysqliteExecutionContext(default.DefaultExecutionContext):
+ def post_exec(self):
+ if self.isinsert and not self.executemany:
+ if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None:
+ self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:]
+
+
+class SQLite_pysqlite(SQLiteDialect):
+ default_paramstyle = 'qmark'
+ poolclass = pool.SingletonThreadPool
+ execution_ctx_cls = SQLite_pysqliteExecutionContext
+ driver = 'pysqlite'
+ ischema_names = ischema_names
+
+ def __init__(self, **kwargs):
+ SQLiteDialect.__init__(self, **kwargs)
+ def vers(num):
+ return tuple([int(x) for x in num.split('.')])
+ if self.dbapi is not None:
+ sqlite_ver = self.dbapi.version_info
+ if sqlite_ver < (2, 1, '3'):
+ util.warn(
+ ("The installed version of pysqlite2 (%s) is out-dated "
+ "and will cause errors in some cases. Version 2.1.3 "
+ "or greater is recommended.") %
+ '.'.join([str(subver) for subver in sqlite_ver]))
+ if self.dbapi.sqlite_version_info < (3, 3, 8):
+ self.supports_default_values = False
+ self.supports_cast = (self.dbapi is None or vers(self.dbapi.sqlite_version) >= vers("3.2.3"))
+
+ @classmethod
+ def dbapi(cls):
+ try:
+ from pysqlite2 import dbapi2 as sqlite
+ except ImportError, e:
+ try:
+ from sqlite3 import dbapi2 as sqlite #try the 2.5+ stdlib name.
+ except ImportError:
+ raise e
+ return sqlite
+
+ def server_version_info(self, connection):
+ return self.dbapi.sqlite_version_info
+
+ def create_connect_args(self, url):
+ if url.username or url.password or url.host or url.port:
+ raise exc.ArgumentError(
+ "Invalid SQLite URL: %s\n"
+ "Valid SQLite URL forms are:\n"
+ " sqlite:///:memory: (or, sqlite://)\n"
+ " sqlite:///relative/path/to/file.db\n"
+ " sqlite:////absolute/path/to/file.db" % (url,))
+ filename = url.database or ':memory:'
+
+ opts = url.query.copy()
+ util.coerce_kw_type(opts, 'timeout', float)
+ util.coerce_kw_type(opts, 'isolation_level', str)
+ util.coerce_kw_type(opts, 'detect_types', int)
+ util.coerce_kw_type(opts, 'check_same_thread', bool)
+ util.coerce_kw_type(opts, 'cached_statements', int)
+
+ return ([filename], opts)
+
+ def type_descriptor(self, typeobj):
+ return sqltypes.adapt_type(typeobj, colspecs)
+
+ def is_disconnect(self, e):
+ return isinstance(e, self.dbapi.ProgrammingError) and "Cannot operate on a closed database." in str(e)
+
+dialect = SQLite_pysqlite
ResultProxy,
RootTransaction,
RowProxy,
- SchemaIterator,
Transaction,
- TwoPhaseTransaction
+ TwoPhaseTransaction,
+ TypeCompiler
)
from sqlalchemy.engine import strategies
from sqlalchemy import util
'ResultProxy',
'RootTransaction',
'RowProxy',
- 'SchemaIterator',
'Transaction',
'TwoPhaseTransaction',
+ 'TypeCompiler',
'create_engine',
'engine_from_config',
)
All Dialects implement the following attributes:
name
- identifying name for the dialect (i.e. 'sqlite')
+ identifying name for the dialect from a DBAPI-neutral point of view
+ (i.e. 'sqlite')
+
+ driver
+ identitfying name for the dialect's DBAPI
positional
True if the paramstyle for this Dialect is positional.
type of encoding to use for unicode, usually defaults to
'utf-8'.
- schemagenerator
- a :class:`~sqlalchemy.schema.SchemaVisitor` class which generates
- schemas.
-
- schemadropper
- a :class:`~sqlalchemy.schema.SchemaVisitor` class which drops schemas.
-
defaultrunner
a :class:`~sqlalchemy.schema.SchemaVisitor` class which executes
defaults.
statement_compiler
- a :class:`~sqlalchemy.engine.base.Compiled` class used to compile SQL
+ a :class:`~Compiled` class used to compile SQL
statements
+ ddl_compiler
+ a :class:`~Compiled` class used to compile DDL
+ statements
+
+ execution_ctx_cls
+ a :class:`ExecutionContext` class used to handle statement execution
+
preparer
a :class:`~sqlalchemy.sql.compiler.IdentifierPreparer` class used to
quote identifiers.
supports_default_values
Indicates if the construct ``INSERT INTO tablename DEFAULT VALUES`` is supported
-
- description_encoding
- type of encoding to use for unicode when working with metadata
- descriptions. If set to ``None`` no encoding will be done.
- This usually defaults to 'utf-8'.
"""
def create_connect_args(self, url):
class Compiled(object):
- """Represent a compiled SQL expression.
+ """Represent a compiled SQL or DDL expression.
The ``__str__`` method of the ``Compiled`` object should produce
the actual text of the statement. ``Compiled`` objects are
``Compiled`` object be dependent on the actual values of those
bind parameters, even though it may reference those values as
defaults.
+
"""
- def __init__(self, dialect, statement, column_keys=None, bind=None):
+ def __init__(self, dialect, statement, bind=None):
"""Construct a new ``Compiled`` object.
dialect
statement
``ClauseElement`` to be compiled.
- column_keys
- a list of column names to be compiled into an INSERT or UPDATE
- statement.
-
bind
Optional Engine or Connection to compile this statement against.
"""
self.dialect = dialect
self.statement = statement
- self.column_keys = column_keys
self.bind = bind
self.can_execute = statement.supports_execution
def compile(self):
"""Produce the internal string representation of this element."""
- raise NotImplementedError()
+ self.string = self.process(self.statement)
+
+ def process(self, obj, **kwargs):
+ return obj._compiler_dispatch(self, **kwargs)
def __str__(self):
- """Return the string text of the generated SQL statement."""
+ """Return the string text of the generated SQL or DDL."""
- raise NotImplementedError()
+ return self.string or ''
@util.deprecated('Deprecated. Use construct_params(). '
'(supports Unicode key names.)')
def get_params(self, **params):
return self.construct_params(params)
- def construct_params(self, params):
+ def construct_params(self, params=None):
"""Return the bind params for this compiled object.
`params` is a dict of string/object pairs whos
values will override bind values compiled in
to the statement.
+
"""
raise NotImplementedError()
return self.execute(*multiparams, **params).scalar()
+class TypeCompiler(object):
+ """Produces DDL specification for TypeEngine objects."""
+
+ def __init__(self, dialect):
+ self.dialect = dialect
+
+ def process(self, type_):
+ return type_._compiler_dispatch(self)
+
class Connectable(object):
"""Interface for an object which supports execution of SQL constructs.
The two implementations of ``Connectable`` are :class:`Connection` and
:class:`Engine`.
+ Connectable must also implement the 'dialect' member which references a
+ :class:`Dialect` instance.
+
"""
def contextual_connect(self):
return self.execute(object, *multiparams, **params).scalar()
- def statement_compiler(self, statement, **kwargs):
- return self.dialect.statement_compiler(self.dialect, statement, bind=self, **kwargs)
-
def execute(self, object, *multiparams, **params):
"""Executes and returns a ResultProxy."""
def _execute_default(self, default, multiparams, params):
return self.engine.dialect.defaultrunner(self.__create_execution_context()).traverse_single(default)
+ def _execute_ddl(self, ddl, params, multiparams):
+ context = self.__create_execution_context(
+ compiled_ddl=ddl.compile(dialect=self.dialect),
+ parameters=None
+ )
+ return self.__execute_context(context)
+
def _execute_clauseelement(self, elem, multiparams, params):
params = self.__distill_params(multiparams, params)
if params:
keys = []
context = self.__create_execution_context(
- compiled=elem.compile(dialect=self.dialect, column_keys=keys, inline=len(params) > 1),
+ compiled_sql=elem.compile(dialect=self.dialect, column_keys=keys, inline=len(params) > 1),
parameters=params
)
return self.__execute_context(context)
"""Execute a sql.Compiled object."""
context = self.__create_execution_context(
- compiled=compiled,
+ compiled_sql=compiled,
parameters=self.__distill_params(multiparams, params)
)
return self.__execute_context(context)
self._commit_impl()
return context.get_result_proxy()
- def _execute_ddl(self, ddl, params, multiparams):
- if params:
- schema_item, params = params[0], params[1:]
- else:
- schema_item = None
- return ddl(None, schema_item, self, *params, **multiparams)
-
def _handle_dbapi_exception(self, e, statement, parameters, cursor, context):
if getattr(self, '_reentrant_error', False):
raise exc.DBAPIError.instance(None, None, e)
expression.ClauseElement: _execute_clauseelement,
Compiled: _execute_compiled,
schema.SchemaItem: _execute_default,
- schema.DDL: _execute_ddl,
+ schema.DDLElement: _execute_ddl,
basestring: _execute_text
}
def create(self, entity, connection=None, **kwargs):
"""Create a table or index within this engine's database connection given a schema.Table object."""
- self._run_visitor(self.dialect.schemagenerator, entity, connection=connection, **kwargs)
+ from sqlalchemy.engine import ddl
+
+ self._run_visitor(ddl.SchemaGenerator, entity, connection=connection, **kwargs)
def drop(self, entity, connection=None, **kwargs):
"""Drop a table or index within this engine's database connection given a schema.Table object."""
- self._run_visitor(self.dialect.schemadropper, entity, connection=connection, **kwargs)
+ from sqlalchemy.engine import ddl
+
+ self._run_visitor(ddl.SchemaDropper, entity, connection=connection, **kwargs)
def _execute_default(self, default):
connection = self.contextual_connect()
connection = self.contextual_connect(close_with_result=True)
return connection._execute_compiled(compiled, multiparams, params)
- def statement_compiler(self, statement, **kwargs):
- return self.dialect.statement_compiler(self.dialect, statement, bind=self, **kwargs)
-
def connect(self, **kwargs):
"""Return a newly allocated Connection object."""
l.append(row)
return l
-
-class SchemaIterator(schema.SchemaVisitor):
- """A visitor that can gather text into a buffer and execute the contents of the buffer."""
-
- def __init__(self, connection):
- """Construct a new SchemaIterator."""
-
- self.connection = connection
- self.buffer = StringIO.StringIO()
-
- def append(self, s):
- """Append content to the SchemaIterator's query buffer."""
-
- self.buffer.write(s)
-
- def execute(self):
- """Execute the contents of the SchemaIterator's buffer."""
-
- try:
- return self.connection.execute(self.buffer.getvalue())
- finally:
- self.buffer.truncate(0)
-
class DefaultRunner(schema.SchemaVisitor):
"""A visitor which accepts ColumnDefault objects, produces the
dialect-specific SQL corresponding to their execution, and
--- /dev/null
+"""routines to handle CREATE/DROP workflow."""
+
+### TOOD: CREATE TABLE and DROP TABLE have been moved out so far.
+### Index, ForeignKey, etc. still need to move.
+
+from sqlalchemy import engine, schema
+from sqlalchemy.sql import util as sql_util
+
+class DDLBase(schema.SchemaVisitor):
+ def __init__(self, connection):
+ self.connection = connection
+
+ def find_alterables(self, tables):
+ alterables = []
+ class FindAlterables(schema.SchemaVisitor):
+ def visit_foreign_key_constraint(self, constraint):
+ if constraint.use_alter and constraint.table in tables:
+ alterables.append(constraint)
+ findalterables = FindAlterables()
+ for table in tables:
+ for c in table.constraints:
+ findalterables.traverse(c)
+ return alterables
+
+
+class SchemaGenerator(DDLBase):
+ def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
+ super(SchemaGenerator, self).__init__(connection, **kwargs)
+ self.checkfirst = checkfirst
+ self.tables = tables and set(tables) or None
+ self.preparer = dialect.identifier_preparer
+ self.dialect = dialect
+
+ def _can_create(self, table):
+ self.dialect.validate_identifier(table.name)
+ if table.schema:
+ self.dialect.validate_identifier(table.schema)
+ return not self.checkfirst or not self.dialect.has_table(self.connection, table.name, schema=table.schema)
+
+ def visit_metadata(self, metadata):
+ if self.tables:
+ tables = self.tables
+ else:
+ tables = metadata.tables.values()
+ collection = [t for t in sql_util.sort_tables(tables) if self._can_create(t)]
+ for table in collection:
+ self.traverse_single(table)
+ if self.dialect.supports_alter:
+ for alterable in self.find_alterables(collection):
+ self.connection.execute(schema.AddForeignKey(alterable))
+
+ def visit_table(self, table):
+ for listener in table.ddl_listeners['before-create']:
+ listener('before-create', table, self.connection)
+
+ for column in table.columns:
+ if column.default is not None:
+ self.traverse_single(column.default)
+
+ self.connection.execute(schema.CreateTable(table))
+
+ if hasattr(table, 'indexes'):
+ for index in table.indexes:
+ self.traverse_single(index)
+
+ for listener in table.ddl_listeners['after-create']:
+ listener('after-create', table, self.connection)
+
+ def visit_sequence(self, sequence):
+ if self.dialect.supports_sequences:
+ if \
+ (not self.dialect.sequences_optional or not sequence.optional) and \
+ (not self.checkfirst or not self.dialect.has_sequence(self.connection, sequence.name)):
+ self.connection.execute(schema.CreateSequence(sequence))
+
+ def visit_index(self, index):
+ self.connection.execute(schema.CreateIndex(index))
+
+class SchemaDropper(DDLBase):
+ def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
+ super(SchemaDropper, self).__init__(connection, **kwargs)
+ self.checkfirst = checkfirst
+ self.tables = tables
+ self.preparer = dialect.identifier_preparer
+ self.dialect = dialect
+
+ def visit_metadata(self, metadata):
+ if self.tables:
+ tables = self.tables
+ else:
+ tables = metadata.tables.values()
+ collection = [t for t in reversed(sql_util.sort_tables(tables)) if self._can_drop(t)]
+ if self.dialect.supports_alter:
+ for alterable in self.find_alterables(collection):
+ self.connection.execute(schema.DropForeignKey(alterable))
+ for table in collection:
+ self.traverse_single(table)
+
+ def _can_drop(self, table):
+ self.dialect.validate_identifier(table.name)
+ if table.schema:
+ self.dialect.validate_identifier(table.schema)
+ return not self.checkfirst or self.dialect.has_table(self.connection, table.name, schema=table.schema)
+
+ def visit_index(self, index):
+ self.connection.execute(schema.DropIndex(index))
+
+ def visit_table(self, table):
+ for listener in table.ddl_listeners['before-drop']:
+ listener('before-drop', table, self.connection)
+
+ for column in table.columns:
+ if column.default is not None:
+ self.traverse_single(column.default)
+
+ self.connection.execute(schema.DropTable(table))
+
+ for listener in table.ddl_listeners['after-drop']:
+ listener('after-drop', table, self.connection)
+
+ def visit_sequence(self, sequence):
+ if self.dialect.supports_sequences:
+ if \
+ (not self.dialect.sequences_optional or not sequence.optional) and \
+ (not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name)):
+ self.connection.execute(schema.DropSequence(sequence))
class DefaultDialect(base.Dialect):
"""Default implementation of Dialect"""
- schemagenerator = compiler.SchemaGenerator
- schemadropper = compiler.SchemaDropper
- statement_compiler = compiler.DefaultCompiler
+ statement_compiler = compiler.SQLCompiler
+ ddl_compiler = compiler.DDLCompiler
+ type_compiler = compiler.GenericTypeCompiler
preparer = compiler.IdentifierPreparer
defaultrunner = base.DefaultRunner
supports_alter = True
+ supports_sequences = False
+ sequences_optional = False
supports_unicode_statements = False
max_identifier_length = 9999
supports_sane_rowcount = True
self.paramstyle = self.default_paramstyle
self.positional = self.paramstyle in ('qmark', 'format', 'numeric')
self.identifier_preparer = self.preparer(self)
+ self.type_compiler = self.type_compiler(self)
+
if label_length and label_length > self.max_identifier_length:
raise exc.ArgumentError("Label length of %d is greater than this dialect's maximum identifier length of %d" % (label_length, self.max_identifier_length))
self.label_length = label_length
the generic object which comes from the types module.
Subclasses will usually use the ``adapt_type()`` method in the
- types module to make this job easy."""
-
+ types module to make this job easy.
+
+ """
if type(typeobj) is type:
typeobj = typeobj()
return typeobj
class DefaultExecutionContext(base.ExecutionContext):
- def __init__(self, dialect, connection, compiled=None, statement=None, parameters=None):
+ def __init__(self, dialect, connection, compiled_sql=None, compiled_ddl=None, statement=None, parameters=None):
self.dialect = dialect
self._connection = self.root_connection = connection
- self.compiled = compiled
self.engine = connection.engine
- if compiled is not None:
+ if compiled_ddl is not None:
+ self.compiled = compiled = compiled_ddl
+ if not dialect.supports_unicode_statements:
+ self.statement = unicode(compiled).encode(self.dialect.encoding)
+ else:
+ self.statement = unicode(compiled)
+ self.isinsert = self.isupdate = self.executemany = False
+ self.should_autocommit = True
+ self.result_map = None
+ self.cursor = self.create_cursor()
+ self.compiled_parameters = []
+ if self.dialect.positional:
+ self.parameters = [()]
+ else:
+ self.parameters = [{}]
+ elif compiled_sql is not None:
+ self.compiled = compiled = compiled_sql
+
# compiled clauseelement. process bind params, process table defaults,
# track collections used by ResultProxy to target and process results
self.parameters = self.__convert_compiled_params(self.compiled_parameters)
elif statement is not None:
- # plain text statement.
- self.result_map = None
+ # plain text statement
+ self.result_map = self.compiled = None
self.parameters = self.__encode_param_keys(parameters)
self.executemany = len(parameters) > 1
if isinstance(statement, unicode) and not dialect.supports_unicode_statements:
self.should_autocommit = self.should_autocommit_text(statement)
else:
# no statement. used for standalone ColumnDefault execution.
- self.statement = None
+ self.statement = self.compiled = None
self.isinsert = self.isupdate = self.executemany = self.should_autocommit = False
self.cursor = self.create_cursor()
def create(self, entity, **kwargs):
kwargs['checkfirst'] = False
- self.dialect.schemagenerator(self.dialect, self, **kwargs).traverse(entity)
+ from sqlalchemy.engine import ddl
+
+ ddl.SchemaGenerator(self.dialect, self, **kwargs).traverse(entity)
def drop(self, entity, **kwargs):
kwargs['checkfirst'] = False
- self.dialect.schemadropper(self.dialect, self, **kwargs).traverse(entity)
+ from sqlalchemy.engine import ddl
+ ddl.SchemaDropper(self.dialect, self, **kwargs).traverse(entity)
def execute(self, object, *multiparams, **params):
raise NotImplementedError()
"""Return the SQLAlchemy database dialect class corresponding to this URL's driver name."""
try:
- module = getattr(__import__('sqlalchemy.databases.%s' % self.drivername).databases, self.drivername)
+ if '+' in self.drivername:
+ dialect, driver = self.drivername.split('+')
+ else:
+ dialect, driver = self.drivername, 'base'
+
+ module = __import__('sqlalchemy.dialects.%s.%s' % (dialect, driver)).dialects
+ module = getattr(module, dialect)
+ module = getattr(module, driver)
+
return module.dialect
except ImportError:
if sys.exc_info()[2].tb_next is None:
def _parse_rfc1738_args(name):
pattern = re.compile(r'''
- (?P<name>\w+)://
+ (?P<name>[\w\+]+)://
(?:
(?P<username>[^:/]*)
(?::(?P<password>[^/]*))?
*(tuple(cols) + tuple(args)), **table_kw)
else:
table = cls.__table__
- if cols:
- raise exceptions.ArgumentError("Can't add additional columns when specifying __table__")
-
+
mapper_args = getattr(cls, '__mapper_args__', {})
if 'inherits' not in mapper_args:
inherits = cls.__mro__[1]
mapper_args['exclude_properties'] = exclude_properties = \
set([c.key for c in inherited_table.c if c not in inherited_mapper._columntoproperty])
exclude_properties.difference_update([c.key for c in cols])
-
+
cls.__mapper__ = mapper_cls(cls, table, properties=our_stuff, **mapper_args)
class DeclarativeMeta(type):
if col in pks:
if history.deleted:
params[col._label] = prop.get_col_value(col, history.deleted[0])
- hasdata = True
else:
# row switch logic can reach us here
# remove the pk from the update params so the update doesn't
# attempt to include the pk in the update statement
del params[col.key]
params[col._label] = prop.get_col_value(col, history.added[0])
- else:
- hasdata = True
+ hasdata = True
elif col in pks:
params[col._label] = mapper._get_state_attr_by_column(state, col)
if hasdata:
"""Describes an object attribute that corresponds to a table column."""
def __init__(self, *columns, **kwargs):
- """Construct a ColumnProperty.
-
- :param \*columns: The list of `columns` describes a single
- object property. If there are multiple tables joined
- together for the mapper, this list represents the equivalent
- column as it appears across each table.
-
- :param group:
-
- :param deferred:
-
- :param comparator_factory:
-
- :param descriptor:
-
- :param extension:
-
+ """The list of `columns` describes a single object
+ property. If there are multiple tables joined together for the
+ mapper, this list represents the equivalent column as it
+ appears across each table.
"""
+
self.columns = [expression._labeled(c) for c in columns]
self.group = kwargs.pop('group', None)
self.deferred = kwargs.pop('deferred', False)
self.comparator_factory = kwargs.pop('comparator_factory', self.__class__.Comparator)
self.descriptor = kwargs.pop('descriptor', None)
self.extension = kwargs.pop('extension', None)
- if kwargs:
- raise TypeError(
- "%s received unexpected keyword argument(s): %s" % (
- self.__class__.__name__, ', '.join(sorted(kwargs.keys()))))
-
util.set_creation_order(self)
if self.no_instrument:
self.strategy_class = strategies.UninstrumentedColumnLoader
mapper.SynonymProperty = SynonymProperty
mapper.ComparableProperty = ComparableProperty
mapper.RelationProperty = RelationProperty
-mapper.ConcreteInheritedProperty = ConcreteInheritedProperty
+mapper.ConcreteInheritedProperty = ConcreteInheritedProperty
\ No newline at end of file
"Unknown arguments passed to Column: " + repr(kwargs.keys()))
def __str__(self):
- if self.name is None:
- return "(no name)"
- elif self.table is not None:
+ if self.table is not None:
if self.table.named_with_column:
return (self.table.description + "." + self.description)
else:
else:
return self.description
- @property
def bind(self):
return self.table.bind
+ bind = property(bind)
def references(self, column):
"""Return True if this Column references the given column via foreign key."""
__traverse_options__ = {'schema_visitor':True}
-class DDL(object):
+class DDLElement(expression.ClauseElement):
+ """Base class for DDL expression constructs."""
+
+ supports_execution = True
+ _autocommit = True
+
+ def bind(self):
+ if self._bind:
+ return self._bind
+ def _set_bind(self, bind):
+ self._bind = bind
+ bind = property(bind, _set_bind)
+
+ def _generate(self):
+ s = self.__class__.__new__(self.__class__)
+ s.__dict__ = self.__dict__.copy()
+ return s
+
+ def _compiler(self, dialect, **kw):
+ """Return a compiler appropriate for this ClauseElement, given a Dialect."""
+
+ return dialect.ddl_compiler(dialect, self, **kw)
+
+class DDL(DDLElement):
"""A literal DDL statement.
Specifies literal SQL DDL to be executed by the database. DDL objects can
connection.execute(drop_spow)
"""
+ __visit_name__ = "ddl"
+
def __init__(self, statement, on=None, context=None, bind=None):
"""Create a DDL statement.
self.on = on
self.context = context or {}
self._bind = bind
+ self.schema_item = None
def execute(self, bind=None, schema_item=None):
"""Execute this DDL immediately.
if bind is None:
bind = _bind_or_error(self)
- # no SQL bind params are supported
+
if self._should_execute(None, schema_item, bind):
- executable = expression.text(self._expand(schema_item, bind))
- return bind.execute(executable)
+ return bind.execute(self.against(schema_item))
else:
bind.engine.logger.info("DDL execution skipped, criteria not met.")
(', '.join(schema_item.ddl_events), event))
schema_item.ddl_listeners[event].append(self)
return self
-
- def bind(self):
- """An Engine or Connection to which this DDL is bound.
-
- This property may be assigned an ``Engine`` or ``Connection``, or
- assigned a string or URL to automatically create a basic ``Engine``
- for this bind with ``create_engine()``.
- """
- return self._bind
-
- def _bind_to(self, bind):
- """Bind this MetaData to an Engine, Connection, string or URL."""
-
- global URL
- if URL is None:
- from sqlalchemy.engine.url import URL
-
- if isinstance(bind, (basestring, URL)):
- from sqlalchemy import create_engine
- self._bind = create_engine(bind)
- else:
- self._bind = bind
- bind = property(bind, _bind_to)
-
+
+ @expression._generative
+ def against(self, schema_item):
+ """Return a copy of this DDL against a specific schema item."""
+
+ self.schema_item = schema_item
+
def __call__(self, event, schema_item, bind):
"""Execute the DDL as a ddl_listener."""
if self._should_execute(event, schema_item, bind):
- statement = expression.text(self._expand(schema_item, bind))
- return bind.execute(statement)
-
- def _expand(self, schema_item, bind):
- return self.statement % self._prepare_context(schema_item, bind)
+ return bind.execute(self.against(schema_item))
def _should_execute(self, event, schema_item, bind):
if self.on is None:
else:
return self.on(event, schema_item, bind)
- def _prepare_context(self, schema_item, bind):
- # table events can substitute table and schema name
- if isinstance(schema_item, Table):
- context = self.context.copy()
-
- preparer = bind.dialect.identifier_preparer
- path = preparer.format_table_seq(schema_item)
- if len(path) == 1:
- table, schema = path[0], ''
- else:
- table, schema = path[-1], path[0]
-
- context.setdefault('table', table)
- context.setdefault('schema', schema)
- context.setdefault('fullname', preparer.format_table(schema_item))
- return context
- else:
- return self.context
-
def __repr__(self):
return '<%s@%s; %s>' % (
type(self).__name__, id(self),
if getattr(self, key)]))
def _to_schema_column(element):
- if hasattr(element, '__clause_element__'):
- element = element.__clause_element__()
- if not isinstance(element, Column):
- raise exc.ArgumentError("schema.Column object expected")
- return element
+ if hasattr(element, '__clause_element__'):
+ element = element.__clause_element__()
+ if not isinstance(element, Column):
+ raise exc.ArgumentError("schema.Column object expected")
+ return element
+
+class _CreateDropBase(DDLElement):
+ """Base class for DDL constucts that represent CREATE and DROP or equivalents.
+
+ The common theme of _CreateDropBase is a single
+ ``element`` attribute which refers to the element
+ to be created or dropped.
+
+ """
+
+ def __init__(self, element):
+ self.element = element
+
+ def bind(self):
+ if self._bind:
+ return self._bind
+ if self.element:
+ e = self.element.bind
+ if e:
+ return e
+ return None
+
+ def _set_bind(self, bind):
+ self._bind = bind
+ bind = property(bind, _set_bind)
+
+class CreateTable(_CreateDropBase):
+ """Represent a CREATE TABLE statement."""
+
+ __visit_name__ = "create_table"
+
+class DropTable(_CreateDropBase):
+ """Represent a DROP TABLE statement."""
+
+ __visit_name__ = "drop_table"
+
+class AddForeignKey(_CreateDropBase):
+ """Represent an ALTER TABLE ADD FOREIGN KEY statement."""
+
+ __visit_name__ = "add_foreignkey"
+
+class DropForeignKey(_CreateDropBase):
+ """Represent an ALTER TABLE DROP FOREIGN KEY statement."""
+
+ __visit_name__ = "drop_foreignkey"
+
+class CreateSequence(_CreateDropBase):
+ """Represent a CREATE SEQUENCE statement."""
+
+ __visit_name__ = "create_sequence"
+
+class DropSequence(_CreateDropBase):
+ """Represent a DROP SEQUENCE statement."""
+
+ __visit_name__ = "drop_sequence"
+
+class CreateIndex(_CreateDropBase):
+ """Represent a CREATE INDEX statement."""
+
+ __visit_name__ = "create_index"
+
+class DropIndex(_CreateDropBase):
+ """Represent a DROP INDEX statement."""
+
+ __visit_name__ = "drop_index"
def _bind_or_error(schemaitem):
bind = schemaitem.bind
def quote(self):
return self.element.quote
-class DefaultCompiler(engine.Compiled):
+class SQLCompiler(engine.Compiled):
"""Default implementation of Compiled.
Compiles ClauseElements into SQL strings. Uses a similar visit
operators = OPERATORS
functions = FUNCTIONS
- # if we are insert/update/delete.
- # set to true when we visit an INSERT, UPDATE or DELETE
+ # class-level defaults which can be set at the instance
+ # level to define if this Compiled instance represents
+ # INSERT/UPDATE/DELETE
isdelete = isinsert = isupdate = False
def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs):
statement.
"""
- engine.Compiled.__init__(self, dialect, statement, column_keys, **kwargs)
+ engine.Compiled.__init__(self, dialect, statement, **kwargs)
+
+ self.column_keys = column_keys
# compile INSERT/UPDATE defaults/sequences inlined (no pre-execute)
self.inline = inline or getattr(statement, 'inline', False)
# or dialect.max_identifier_length
self.truncated_names = {}
- def compile(self):
- self.string = self.process(self.statement)
-
- def process(self, obj, **kwargs):
- return obj._compiler_dispatch(self, **kwargs)
-
def is_subquery(self):
return len(self.stack) > 1
return index.name
def visit_typeclause(self, typeclause, **kwargs):
- return typeclause.type.dialect_impl(self.dialect).get_col_spec()
+ return self.dialect.type_compiler.process(typeclause.type)
def post_process_text(self, text):
return text
def visit_release_savepoint(self, savepoint_stmt):
return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
- def __str__(self):
- return self.string or ''
-
-class DDLBase(engine.SchemaIterator):
- def find_alterables(self, tables):
- alterables = []
- class FindAlterables(schema.SchemaVisitor):
- def visit_foreign_key_constraint(self, constraint):
- if constraint.use_alter and constraint.table in tables:
- alterables.append(constraint)
- findalterables = FindAlterables()
- for table in tables:
- for c in table.constraints:
- findalterables.traverse(c)
- return alterables
- def _validate_identifier(self, ident, truncate):
- if truncate:
- if len(ident) > self.dialect.max_identifier_length:
- counter = getattr(self, 'counter', 0)
- self.counter = counter + 1
- return ident[0:self.dialect.max_identifier_length - 6] + "_" + hex(self.counter)[2:]
+class DDLCompiler(engine.Compiled):
+ @property
+ def preparer(self):
+ return self.dialect.identifier_preparer
+
+ def visit_ddl(self, ddl, **kwargs):
+ # table events can substitute table and schema name
+ context = ddl.context
+ if isinstance(ddl.schema_item, schema.Table):
+ context = context.copy()
+
+ preparer = self.dialect.identifier_preparer
+ path = preparer.format_table_seq(ddl.schema_item)
+ if len(path) == 1:
+ table, sch = path[0], ''
else:
- return ident
- else:
- self.dialect.validate_identifier(ident)
- return ident
-
-
-class SchemaGenerator(DDLBase):
- def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
- super(SchemaGenerator, self).__init__(connection, **kwargs)
- self.checkfirst = checkfirst
- self.tables = tables and set(tables) or None
- self.preparer = dialect.identifier_preparer
- self.dialect = dialect
-
- def get_column_specification(self, column, first_pk=False):
- raise NotImplementedError()
-
- def _can_create(self, table):
- self.dialect.validate_identifier(table.name)
- if table.schema:
- self.dialect.validate_identifier(table.schema)
- return not self.checkfirst or not self.dialect.has_table(self.connection, table.name, schema=table.schema)
+ table, sch = path[-1], path[0]
- def visit_metadata(self, metadata):
- if self.tables:
- tables = self.tables
- else:
- tables = metadata.tables.values()
- collection = [t for t in sql_util.sort_tables(tables) if self._can_create(t)]
- for table in collection:
- self.traverse_single(table)
- if self.dialect.supports_alter:
- for alterable in self.find_alterables(collection):
- self.add_foreignkey(alterable)
-
- def visit_table(self, table):
- for listener in table.ddl_listeners['before-create']:
- listener('before-create', table, self.connection)
+ context.setdefault('table', table)
+ context.setdefault('schema', sch)
+ context.setdefault('fullname', preparer.format_table(ddl.schema_item))
+
+ return ddl.statement % context
- for column in table.columns:
- if column.default is not None:
- self.traverse_single(column.default)
+ def visit_create_table(self, create):
+ table = create.element
+ preparer = self.dialect.identifier_preparer
- self.append("\n" + " ".join(['CREATE'] +
- table._prefixes +
+ text = "\n" + " ".join(['CREATE'] + \
+ table._prefixes + \
['TABLE',
- self.preparer.format_table(table),
- "("]))
+ preparer.format_table(table),
+ "("])
separator = "\n"
# if only one primary key, specify it along with the column
first_pk = False
for column in table.columns:
- self.append(separator)
+ text += separator
separator = ", \n"
- self.append("\t" + self.get_column_specification(column, first_pk=column.primary_key and not first_pk))
+ text += "\t" + self.get_column_specification(column, first_pk=column.primary_key and not first_pk)
if column.primary_key:
first_pk = True
for constraint in column.constraints:
- self.traverse_single(constraint)
+ text += self.process(constraint)
# On some DB order is significant: visit PK first, then the
# other constraints (engine.ReflectionTest.testbasic failed on FB2)
if table.primary_key:
- self.traverse_single(table.primary_key)
+ text += self.process(table.primary_key)
+
for constraint in [c for c in table.constraints if c is not table.primary_key]:
- self.traverse_single(constraint)
+ text += self.process(constraint)
- self.append("\n)%s\n\n" % self.post_create_table(table))
- self.execute()
+ text += "\n)%s\n\n" % self.post_create_table(table)
+ return text
+
+ def visit_drop_table(self, drop):
+ return "\nDROP TABLE " + self.preparer.format_table(drop.element)
+
+ def visit_add_foreignkey(self, add):
+ return "ALTER TABLE %s ADD " % self.preparer.format_table(add.element.table) + \
+ self.define_foreign_key(add.element)
- if hasattr(table, 'indexes'):
- for index in table.indexes:
- self.traverse_single(index)
+ def visit_drop_foreignkey(self, drop):
+ return "ALTER TABLE %s DROP CONSTRAINT %s" % (
+ self.preparer.format_table(drop.element.table),
+ self.preparer.format_constraint(drop.element))
- for listener in table.ddl_listeners['after-create']:
- listener('after-create', table, self.connection)
+ def visit_create_index(self, create):
+ index = create.element
+ preparer = self.preparer
+ text = "CREATE "
+ if index.unique:
+ text += "UNIQUE "
+ text += "INDEX %s ON %s (%s)" \
+ % (preparer.quote(self._validate_identifier(index.name, True), index.quote),
+ preparer.format_table(index.table),
+ ', '.join(preparer.quote(c.name, c.quote)
+ for c in index.columns))
+ return text
+
+ def visit_drop_index(self, drop):
+ index = drop.element
+ return "\nDROP INDEX " + self.preparer.quote(self._validate_identifier(index.name, False), index.quote)
+
+ def get_column_specification(self, column, first_pk=False):
+ raise NotImplementedError()
def post_create_table(self, table):
return ''
+ def _compile(self, tocompile, parameters):
+ """compile the given string/parameters using this SchemaGenerator's dialect."""
+ compiler = self.dialect.statement_compiler(self.dialect, tocompile, parameters)
+ compiler.compile()
+ return compiler
+
+ def _validate_identifier(self, ident, truncate):
+ if truncate:
+ if len(ident) > self.dialect.max_identifier_length:
+ counter = getattr(self, 'counter', 0)
+ self.counter = counter + 1
+ return ident[0:self.dialect.max_identifier_length - 6] + "_" + hex(self.counter)[2:]
+ else:
+ return ident
+ else:
+ self.dialect.validate_identifier(ident)
+ return ident
+
def get_column_default_string(self, column):
if isinstance(column.server_default, schema.DefaultClause):
if isinstance(column.server_default.arg, basestring):
else:
return None
- def _compile(self, tocompile, parameters):
- """compile the given string/parameters using this SchemaGenerator's dialect."""
- compiler = self.dialect.statement_compiler(self.dialect, tocompile, parameters)
- compiler.compile()
- return compiler
-
def visit_check_constraint(self, constraint):
- self.append(", \n\t")
+ text = ", \n\t"
if constraint.name is not None:
- self.append("CONSTRAINT %s " %
- self.preparer.format_constraint(constraint))
- self.append(" CHECK (%s)" % constraint.sqltext)
- self.define_constraint_deferrability(constraint)
+ text += "CONSTRAINT %s " % \
+ self.preparer.format_constraint(constraint)
+ text += " CHECK (%s)" % constraint.sqltext
+ text += self.define_constraint_deferrability(constraint)
+ return text
def visit_column_check_constraint(self, constraint):
- self.append(" CHECK (%s)" % constraint.sqltext)
- self.define_constraint_deferrability(constraint)
+ text = " CHECK (%s)" % constraint.sqltext
+ text += self.define_constraint_deferrability(constraint)
+ return text
def visit_primary_key_constraint(self, constraint):
if len(constraint) == 0:
- return
- self.append(", \n\t")
+ return ''
+ text = ", \n\t"
if constraint.name is not None:
- self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint))
- self.append("PRIMARY KEY ")
- self.append("(%s)" % ', '.join(self.preparer.quote(c.name, c.quote)
- for c in constraint))
- self.define_constraint_deferrability(constraint)
+ text += "CONSTRAINT %s " % self.preparer.format_constraint(constraint)
+ text += "PRIMARY KEY "
+ text += "(%s)" % ', '.join(self.preparer.quote(c.name, c.quote)
+ for c in constraint)
+ text += self.define_constraint_deferrability(constraint)
+ return text
def visit_foreign_key_constraint(self, constraint):
if constraint.use_alter and self.dialect.supports_alter:
- return
- self.append(", \n\t ")
- self.define_foreign_key(constraint)
-
- def add_foreignkey(self, constraint):
- self.append("ALTER TABLE %s ADD " % self.preparer.format_table(constraint.table))
- self.define_foreign_key(constraint)
- self.execute()
+ return ''
+
+ return ", \n\t " + self.define_foreign_key(constraint)
def define_foreign_key(self, constraint):
- preparer = self.preparer
+ preparer = self.dialect.identifier_preparer
+ text = ""
if constraint.name is not None:
- self.append("CONSTRAINT %s " %
- preparer.format_constraint(constraint))
+ text += "CONSTRAINT %s " % \
+ preparer.format_constraint(constraint)
table = list(constraint.elements)[0].column.table
- self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % (
+ text += "FOREIGN KEY(%s) REFERENCES %s (%s)" % (
', '.join(preparer.quote(f.parent.name, f.parent.quote)
for f in constraint.elements),
preparer.format_table(table),
', '.join(preparer.quote(f.column.name, f.column.quote)
for f in constraint.elements)
- ))
+ )
if constraint.ondelete is not None:
- self.append(" ON DELETE %s" % constraint.ondelete)
+ text += " ON DELETE %s" % constraint.ondelete
if constraint.onupdate is not None:
- self.append(" ON UPDATE %s" % constraint.onupdate)
- self.define_constraint_deferrability(constraint)
+ text += " ON UPDATE %s" % constraint.onupdate
+ text += self.define_constraint_deferrability(constraint)
+ return text
def visit_unique_constraint(self, constraint):
- self.append(", \n\t")
+ text = ", \n\t"
if constraint.name is not None:
- self.append("CONSTRAINT %s " %
- self.preparer.format_constraint(constraint))
- self.append(" UNIQUE (%s)" % (', '.join(self.preparer.quote(c.name, c.quote) for c in constraint)))
- self.define_constraint_deferrability(constraint)
+ text += "CONSTRAINT %s " % self.preparer.format_constraint(constraint)
+ text += " UNIQUE (%s)" % (', '.join(self.preparer.quote(c.name, c.quote) for c in constraint))
+ text += self.define_constraint_deferrability(constraint)
+ return text
def define_constraint_deferrability(self, constraint):
+ text = ""
if constraint.deferrable is not None:
if constraint.deferrable:
- self.append(" DEFERRABLE")
+ text += " DEFERRABLE"
else:
- self.append(" NOT DEFERRABLE")
+ text += " NOT DEFERRABLE"
if constraint.initially is not None:
- self.append(" INITIALLY %s" % constraint.initially)
+ text += " INITIALLY %s" % constraint.initially
+ return text
+
+
+# PLACEHOLDERS to get non-converted dialects to compile
+class SchemaGenerator(object):
+ pass
+
+class SchemaDropper(object):
+ pass
+
+
+class GenericTypeCompiler(engine.TypeCompiler):
+ def visit_CHAR(self, type_):
+ return "CHAR" + (type_.length and "(%d)" % type_.length or "")
- def visit_column(self, column):
- pass
+ def visit_NCHAR(self, type_):
+ return "NCHAR" + (type_.length and "(%d)" % type_.length or "")
+
+ def visit_FLOAT(self, type_):
+ return "FLOAT"
- def visit_index(self, index):
- preparer = self.preparer
- self.append("CREATE ")
- if index.unique:
- self.append("UNIQUE ")
- self.append("INDEX %s ON %s (%s)" \
- % (preparer.quote(self._validate_identifier(index.name, True), index.quote),
- preparer.format_table(index.table),
- ', '.join(preparer.quote(c.name, c.quote)
- for c in index.columns)))
- self.execute()
+ def visit_NUMERIC(self, type_):
+ if type_.precision is None:
+ return "NUMERIC"
+ else:
+ return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': type_.precision, 'scale' : type_.scale}
+ def visit_DECIMAL(self, type_):
+ return "DECIMAL"
+
+ def visit_INTEGER(self, type_):
+ return "INTEGER"
-class SchemaDropper(DDLBase):
- def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
- super(SchemaDropper, self).__init__(connection, **kwargs)
- self.checkfirst = checkfirst
- self.tables = tables
- self.preparer = dialect.identifier_preparer
- self.dialect = dialect
+ def visit_SMALLINT(self, type_):
+ return "SMALLINT"
- def visit_metadata(self, metadata):
- if self.tables:
- tables = self.tables
- else:
- tables = metadata.tables.values()
- collection = [t for t in reversed(sql_util.sort_tables(tables)) if self._can_drop(t)]
- if self.dialect.supports_alter:
- for alterable in self.find_alterables(collection):
- self.drop_foreignkey(alterable)
- for table in collection:
- self.traverse_single(table)
-
- def _can_drop(self, table):
- self.dialect.validate_identifier(table.name)
- if table.schema:
- self.dialect.validate_identifier(table.schema)
- return not self.checkfirst or self.dialect.has_table(self.connection, table.name, schema=table.schema)
-
- def visit_index(self, index):
- self.append("\nDROP INDEX " + self.preparer.quote(self._validate_identifier(index.name, False), index.quote))
- self.execute()
-
- def drop_foreignkey(self, constraint):
- self.append("ALTER TABLE %s DROP CONSTRAINT %s" % (
- self.preparer.format_table(constraint.table),
- self.preparer.format_constraint(constraint)))
- self.execute()
-
- def visit_table(self, table):
- for listener in table.ddl_listeners['before-drop']:
- listener('before-drop', table, self.connection)
+ def visit_TIMESTAMP(self, type_):
+ return 'TIMESTAMP'
- for column in table.columns:
- if column.default is not None:
- self.traverse_single(column.default)
+ def visit_DATETIME(self, type_):
+ return "DATETIME"
- self.append("\nDROP TABLE " + self.preparer.format_table(table))
- self.execute()
+ def visit_DATE(self, type_):
+ return "DATE"
- for listener in table.ddl_listeners['after-drop']:
- listener('after-drop', table, self.connection)
+ def visit_TIME(self, type_):
+ return "TIME"
+ def visit_CLOB(self, type_):
+ return "CLOB"
+ def visit_VARCHAR(self, type_):
+ return "VARCHAR" + (type_.length and "(%d)" % type_.length or "")
+
+ def visit_BLOB(self, type_):
+ return "BLOB"
+
+ def visit_BINARY(self, type_):
+ return "BINARY"
+
+ def visit_BOOLEAN(self, type_):
+ return "BOOLEAN"
+
+ def visit_TEXT(self, type_):
+ return "TEXT"
+
+ def visit_binary(self, type_):
+ return self.visit_BINARY(type_)
+ def visit_boolean(self, type_):
+ return self.visit_BOOLEAN(type_)
+ def visit_time(self, type_):
+ return self.visit_TIME(type_)
+ def visit_datetime(self, type_):
+ return self.visit_DATETIME(type_)
+ def visit_date(self, type_):
+ return self.visit_DATE(type_)
+ def visit_small_integer(self, type_):
+ return self.visit_SMALLINT(type_)
+ def visit_integer(self, type_):
+ return self.visit_INTEGER(type_)
+ def visit_float(self, type_):
+ return self.visit_FLOAT(type_)
+ def visit_numeric(self, type_):
+ return self.visit_NUMERIC(type_)
+ def visit_string(self, type_):
+ return self.visit_VARCHAR(type_)
+ def visit_text(self, type_):
+ return self.visit_TEXT(type_)
+
+ def visit_null(self, type_):
+ raise NotImplementedError("Can't generate DDL for the null type")
+
+ def visit_type_decorator(self, type_):
+ return self.process(type_.dialect_impl(self.dialect).impl)
+
+ def visit_user_defined(self, type_):
+ return type_.get_col_spec()
+
class IdentifierPreparer(object):
"""Handle quoting and case-folding of identifiers based on options."""
from sqlalchemy import util, exc
from sqlalchemy.sql import operators
from sqlalchemy.sql.visitors import Visitable, cloned_traverse
-from sqlalchemy import types as sqltypes
import operator
-functions, schema, sql_util = None, None, None
+functions, schema, sql_util, sqltypes = None, None, None, None
DefaultDialect, ClauseAdapter, Annotated = None, None, None
__all__ = [
_annotations = {}
supports_execution = False
_from_objects = []
-
+ _bind = None
+
def _clone(self):
"""Create a shallow copy of this ClauseElement.
def bind(self):
"""Returns the Engine or Connection to which this ClauseElement is bound, or None if none found."""
- try:
- if self._bind is not None:
- return self._bind
- except AttributeError:
- pass
+ if self._bind is not None:
+ return self._bind
+
for f in _from_objects(self):
if f is self:
continue
return self.execute(*multiparams, **params).scalar()
- def compile(self, bind=None, column_keys=None, compiler=None, dialect=None, inline=False):
+ def compile(self, bind=None, dialect=None, **kw):
"""Compile this SQL expression.
The return value is a :class:`~sqlalchemy.engine.Compiled` object.
takes precedence over this ``ClauseElement``'s
bound engine, if any.
- column_keys
- Used for INSERT and UPDATE statements, a list of
- column names which should be present in the VALUES clause
- of the compiled statement. If ``None``, all columns
- from the target table object are rendered.
-
- compiler
- A ``Compiled`` instance which will be used to compile
- this expression. This argument takes precedence
- over the `bind` and `dialect` arguments as well as
- this ``ClauseElement``'s bound engine, if
- any.
-
dialect
A ``Dialect`` instance frmo which a ``Compiled``
will be acquired. This argument takes precedence
over the `bind` argument as well as this
``ClauseElement``'s bound engine, if any.
- inline
- Used for INSERT statements, for a dialect which does
- not support inline retrieval of newly generated
- primary key columns, will force the expression used
- to create the new primary key value to be rendered
- inline within the INSERT statement's VALUES clause.
- This typically refers to Sequence execution but
- may also refer to any server-side default generation
- function associated with a primary key `Column`.
+ \**kw
+
+ Keyword arguments are passed along to the compiler,
+ which can affect the string produced.
+
+ Keywords for a statement compiler are:
+
+ column_keys
+ Used for INSERT and UPDATE statements, a list of
+ column names which should be present in the VALUES clause
+ of the compiled statement. If ``None``, all columns
+ from the target table object are rendered.
+
+ inline
+ Used for INSERT statements, for a dialect which does
+ not support inline retrieval of newly generated
+ primary key columns, will force the expression used
+ to create the new primary key value to be rendered
+ inline within the INSERT statement's VALUES clause.
+ This typically refers to Sequence execution but
+ may also refer to any server-side default generation
+ function associated with a primary key `Column`.
"""
- if compiler is None:
- if dialect is not None:
- compiler = dialect.statement_compiler(dialect, self, column_keys=column_keys, inline=inline)
- elif bind is not None:
- compiler = bind.statement_compiler(self, column_keys=column_keys, inline=inline)
- elif self.bind is not None:
- compiler = self.bind.statement_compiler(self, column_keys=column_keys, inline=inline)
+
+ if not dialect:
+ if bind:
+ dialect = bind.dialect
+ elif self.bind:
+ dialect = self.bind.dialect
+ bind = self.bind
else:
global DefaultDialect
if DefaultDialect is None:
from sqlalchemy.engine.default import DefaultDialect
dialect = DefaultDialect()
- compiler = dialect.statement_compiler(dialect, self, column_keys=column_keys, inline=inline)
+ compiler = self._compiler(dialect, bind=bind, **kw)
compiler.compile()
return compiler
-
+
+ def _compiler(self, dialect, **kw):
+ """Return a compiler appropriate for this ClauseElement, given a Dialect."""
+
+ return dialect.statement_compiler(dialect, self, **kw)
+
def __str__(self):
return unicode(self.compile()).encode('ascii', 'backslashreplace')
class _Immutable(object):
"""mark a ClauseElement as 'immutable' when expressions are cloned."""
+ def unique_params(self, *optionaldict, **kwargs):
+ raise NotImplementedError("Immutable objects do not support copying")
+
+ def params(self, *optionaldict, **kwargs):
+ raise NotImplementedError("Immutable objects do not support copying")
+
def _clone(self):
return self
"""
def __init__(cls, clsname, bases, clsdict):
- if cls.__name__ == 'Visitable':
+ if cls.__name__ == 'Visitable' or not hasattr(cls, '__visit_name__'):
super(VisitableType, cls).__init__(clsname, bases, clsdict)
return
- assert hasattr(cls, '__visit_name__'), "`Visitable` descendants " \
- "should define `__visit_name__`"
-
# set up an optimized visit dispatch function
# for use by the compiler
visit_name = cls.__visit_name__
'INT', 'CHAR', 'VARCHAR', 'NCHAR', 'TEXT', 'Text', 'FLOAT',
'NUMERIC', 'DECIMAL', 'TIMESTAMP', 'DATETIME', 'CLOB', 'BLOB',
'BOOLEAN', 'SMALLINT', 'DATE', 'TIME',
- 'String', 'Integer', 'SmallInteger','Smallinteger',
+ 'String', 'Integer', 'SmallInteger',
'Numeric', 'Float', 'DateTime', 'Date', 'Time', 'Binary',
'Boolean', 'Unicode', 'MutableType', 'Concatenable', 'UnicodeText', 'PickleType', 'Interval',
'type_map'
from sqlalchemy import exc
from sqlalchemy.util import pickle
+from sqlalchemy.sql.visitors import Visitable
+from sqlalchemy.sql import expression
+import sys
+expression.sqltypes = sys.modules[__name__]
import sqlalchemy.util as util
NoneType = type(None)
-class AbstractType(object):
+class AbstractType(Visitable):
def __init__(self, *args, **kwargs):
pass
for k in inspect.getargspec(self.__init__)[0][1:]))
class TypeEngine(AbstractType):
- """Base for built-in types.
-
- May be sub-classed to create entirely new types. Example::
-
- import sqlalchemy.types as types
-
- class MyType(types.TypeEngine):
- def __init__(self, precision = 8):
- self.precision = precision
-
- def get_col_spec(self):
- return "MYTYPE(%s)" % self.precision
-
- def bind_processor(self, dialect):
- def process(value):
- return value
- return process
-
- def result_processor(self, dialect):
- def process(value):
- return value
- return process
-
- Once the type is made, it's immediately usable::
-
- table = Table('foo', meta,
- Column('id', Integer, primary_key=True),
- Column('data', MyType(16))
- )
-
- """
+ """Base for built-in types."""
def dialect_impl(self, dialect, **kwargs):
try:
d['_impl_dict'] = {}
return d
- def get_col_spec(self):
- """Return the DDL representation for this type."""
- raise NotImplementedError()
-
def bind_processor(self, dialect):
"""Return a conversion function for processing bind values.
return self.__class__.__mro__[0:-1]
+class UserDefinedType(TypeEngine):
+ """Base for user defined types.
+
+ This should be the base of new types. Note that
+ for most cases, :class:`TypeDecorator` is probably
+ more appropriate.
+
+ import sqlalchemy.types as types
+
+ class MyType(types.UserDefinedType):
+ def __init__(self, precision = 8):
+ self.precision = precision
+
+ def get_col_spec(self):
+ return "MYTYPE(%s)" % self.precision
+
+ def bind_processor(self, dialect):
+ def process(value):
+ return value
+ return process
+
+ def result_processor(self, dialect):
+ def process(value):
+ return value
+ return process
+
+ Once the type is made, it's immediately usable::
+
+ table = Table('foo', meta,
+ Column('id', Integer, primary_key=True),
+ Column('data', MyType(16))
+ )
+
+ """
+ __visit_name__ = "user_defined"
+
class TypeDecorator(AbstractType):
"""Allows the creation of types which add additional functionality
to an existing type.
"""
+ __visit_name__ = "type_decorator"
+
def __init__(self, *args, **kwargs):
if not hasattr(self.__class__, 'impl'):
raise AssertionError("TypeDecorator implementations require a class-level variable 'impl' which refers to the class of type being decorated")
return getattr(self.impl, key)
- def get_col_spec(self):
- return self.impl.get_col_spec()
-
def process_bind_param(self, value, dialect):
raise NotImplementedError()
encountered during a :meth:`~sqlalchemy.Table.create` operation.
"""
-
- def get_col_spec(self):
- raise NotImplementedError()
+ __visit_name__ = 'null'
NullTypeEngine = NullType
"""
+ __visit_name__ = 'string'
+
def __init__(self, length=None, convert_unicode=False, assert_unicode=None):
"""
Create a string-holding type.
params (and the reverse for result sets.)
"""
+
+ __visit_name__ = 'text'
+
def dialect_impl(self, dialect, **kwargs):
return TypeEngine.dialect_impl(self, dialect, **kwargs)
class Integer(TypeEngine):
"""A type for ``int`` integers."""
-
+
+ __visit_name__ = 'integer'
+
def get_dbapi_type(self, dbapi):
return dbapi.NUMBER
"""
-Smallinteger = SmallInteger
+ __visit_name__ = 'small_integer'
class Numeric(TypeEngine):
"""A type for fixed precision numbers.
"""
+ __visit_name__ = 'numeric'
+
def __init__(self, precision=10, scale=2, asdecimal=True, length=None):
"""
Construct a Numeric.
class Float(Numeric):
"""A type for ``float`` numbers."""
+ __visit_name__ = 'float'
+
def __init__(self, precision=10, asdecimal=False, **kwargs):
"""
Construct a Float.
converted back to datetime objects when rows are returned.
"""
-
+
+ __visit_name__ = 'datetime'
+
def __init__(self, timezone=False):
self.timezone = timezone
class Date(TypeEngine):
"""A type for ``datetime.date()`` objects."""
+ __visit_name__ = 'date'
+
def get_dbapi_type(self, dbapi):
return dbapi.DATETIME
class Time(TypeEngine):
"""A type for ``datetime.time()`` objects."""
+ __visit_name__ = 'time'
+
def __init__(self, timezone=False):
self.timezone = timezone
"""
+ __visit_name__ = 'binary'
+
def __init__(self, length=None):
"""
Construct a Binary type.
"""
+ __visit_name__ = 'boolean'
class Interval(TypeDecorator):
"""A type for ``datetime.timedelta()`` objects.
def __init__(self):
super(Interval, self).__init__()
- import sqlalchemy.databases.postgres as pg
+ import sqlalchemy.dialects.postgres.base as pg
self.__supported = {pg.PGDialect:pg.PGInterval}
del pg
class FLOAT(Float):
"""The SQL FLOAT type."""
+ __visit_name__ = 'FLOAT'
class NUMERIC(Numeric):
"""The SQL NUMERIC type."""
+ __visit_name__ = 'NUMERIC'
+
class DECIMAL(Numeric):
"""The SQL DECIMAL type."""
+ __visit_name__ = 'DECIMAL'
-class INT(Integer):
+
+class INTEGER(Integer):
"""The SQL INT or INTEGER type."""
+ __visit_name__ = 'INTEGER'
+INT = INTEGER
-INTEGER = INT
-class SMALLINT(Smallinteger):
+class SMALLINT(SmallInteger):
"""The SQL SMALLINT type."""
+ __visit_name__ = 'SMALLINT'
+
class TIMESTAMP(DateTime):
"""The SQL TIMESTAMP type."""
+ __visit_name__ = 'TIMESTAMP'
+
class DATETIME(DateTime):
"""The SQL DATETIME type."""
+ __visit_name__ = 'DATETIME'
+
class DATE(Date):
"""The SQL DATE type."""
+ __visit_name__ = 'DATE'
+
class TIME(Time):
"""The SQL TIME type."""
+ __visit_name__ = 'TIME'
-TEXT = Text
+class TEXT(Text):
+ """The SQL TEXT type."""
+
+ __visit_name__ = 'TEXT'
class CLOB(Text):
"""The SQL CLOB type."""
+ __visit_name__ = 'CLOB'
+
class VARCHAR(String):
"""The SQL VARCHAR type."""
+ __visit_name__ = 'VARCHAR'
+
class CHAR(String):
"""The SQL CHAR type."""
+ __visit_name__ = 'CHAR'
+
class NCHAR(Unicode):
"""The SQL NCHAR type."""
+ __visit_name__ = 'NCHAR'
+
class BLOB(Binary):
"""The SQL BLOB type."""
+ __visit_name__ = 'BLOB'
+
class BOOLEAN(Boolean):
"""The SQL BOOLEAN type."""
+ __visit_name__ = 'BOOLEAN'
+
NULLTYPE = NullType()
# using VARCHAR/NCHAR so that we dont get the genericized "String"
dt.timedelta : Interval,
type(None): NullType
}
+
class_ = stack.pop()
ctr = class_.__dict__.get('__init__', False)
if not ctr or not isinstance(ctr, types.FunctionType):
+ stack.update(class_.__bases__)
continue
names, _, has_kw, _ = inspect.getargspec(ctr)
args.update(names)
import datetime
from sqlalchemy import *
from sqlalchemy.orm import *
-from sqlalchemy import exc
-from sqlalchemy.databases import postgres
+from sqlalchemy import exc, schema
+from sqlalchemy.dialects.postgres import base as postgres
from sqlalchemy.engine.strategies import MockEngineStrategy
from testlib import *
from sqlalchemy.sql import table, column
-
+from testlib.testing import eq_
class SequenceTest(TestBase, AssertsCompiledSQL):
def test_basic(self):
i = insert(table1, values=dict(name='foo'), postgres_returning=[func.length(table1.c.name)])
self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING length(mytable.name)", dialect=dialect)
+ def test_create_partial_index(self):
+ tbl = Table('testtbl', MetaData(), Column('data',Integer))
+ idx = Index('test_idx1', tbl.c.data, postgres_where=and_(tbl.c.data > 5, tbl.c.data < 10))
+
+ self.assert_compile(schema.CreateIndex(idx),
+ "CREATE INDEX test_idx1 ON testtbl (data) WHERE testtbl.data > 5 AND testtbl.data < 10", dialect=postgres.dialect())
+
+
class ReturningTest(TestBase, AssertsExecutionResults):
__only_on__ = 'postgres'
metadata = MetaData(testing.db)
table = Table('testtable', metadata, autoload=True)
self.assertEquals(set(table.columns.keys()), set(['question', 'answer']), "Columns of reflected table didn't equal expected columns")
- self.assertEquals(table.c.answer.type.__class__, postgres.PGInteger)
+ assert isinstance(table.c.answer.type, Integer)
def test_domain_is_reflected(self):
metadata = MetaData(testing.db)
metadata = MetaData(testing.db)
table = Table('testtable', metadata, autoload=True, schema='alt_schema')
self.assertEquals(set(table.columns.keys()), set(['question', 'answer', 'anything']), "Columns of reflected table didn't equal expected columns")
- self.assertEquals(table.c.anything.type.__class__, postgres.PGInteger)
+ assert isinstance(table.c.anything.type, Integer)
def test_schema_domain_is_reflected(self):
metadata = MetaData(testing.db)
self.assertEquals(str(table.columns.answer.server_default.arg), '0', "Reflected default value didn't equal expected value")
self.assertTrue(table.columns.answer.nullable, "Expected reflected column to be nullable.")
-class MiscTest(TestBase, AssertsExecutionResults):
+class MiscTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL):
__only_on__ = 'postgres'
def test_date_reflection(self):
warnings.warn = capture_warnings._orig_showwarning
m1.drop_all()
- def test_create_partial_index(self):
- tbl = Table('testtbl', MetaData(), Column('data',Integer))
- idx = Index('test_idx1', tbl.c.data, postgres_where=and_(tbl.c.data > 5, tbl.c.data < 10))
-
- executed_sql = []
- mock_strategy = MockEngineStrategy()
- mock_conn = mock_strategy.create('postgres://', executed_sql.append)
-
- idx.create(mock_conn)
-
- assert executed_sql == ['CREATE INDEX test_idx1 ON testtbl (data) WHERE testtbl.data > 5 AND testtbl.data < 10']
class TimezoneTest(TestBase, AssertsExecutionResults):
"""Test timezone-aware datetimes.
import datetime
from sqlalchemy import *
from sqlalchemy import exc
-from sqlalchemy.databases import sqlite
+from sqlalchemy.dialects.sqlite import base as sqlite, pysqlite as pysqlite_dialect
from testlib import *
@testing.uses_deprecated('Using String type with no length')
def test_type_reflection(self):
# (ask_for, roundtripped_as_if_different)
- specs = [( String(), sqlite.SLString(), ),
- ( String(1), sqlite.SLString(1), ),
- ( String(3), sqlite.SLString(3), ),
- ( Text(), sqlite.SLText(), ),
- ( Unicode(), sqlite.SLString(), ),
- ( Unicode(1), sqlite.SLString(1), ),
- ( Unicode(3), sqlite.SLString(3), ),
- ( UnicodeText(), sqlite.SLText(), ),
- ( CLOB, sqlite.SLText(), ),
- ( sqlite.SLChar(1), ),
- ( CHAR(3), sqlite.SLChar(3), ),
- ( NCHAR(2), sqlite.SLChar(2), ),
- ( SmallInteger(), sqlite.SLSmallInteger(), ),
- ( sqlite.SLSmallInteger(), ),
- ( Binary(3), sqlite.SLBinary(), ),
- ( Binary(), sqlite.SLBinary() ),
- ( sqlite.SLBinary(3), sqlite.SLBinary(), ),
+ specs = [( String(), pysqlite_dialect.SLString(), ),
+ ( String(1), pysqlite_dialect.SLString(1), ),
+ ( String(3), pysqlite_dialect.SLString(3), ),
+ ( Text(), pysqlite_dialect.SLText(), ),
+ ( Unicode(), pysqlite_dialect.SLString(), ),
+ ( Unicode(1), pysqlite_dialect.SLString(1), ),
+ ( Unicode(3), pysqlite_dialect.SLString(3), ),
+ ( UnicodeText(), pysqlite_dialect.SLText(), ),
+ ( CLOB, pysqlite_dialect.SLText(), ),
+ ( pysqlite_dialect.SLChar(1), ),
+ ( CHAR(3), pysqlite_dialect.SLChar(3), ),
+ ( NCHAR(2), pysqlite_dialect.SLChar(2), ),
( NUMERIC, sqlite.SLNumeric(), ),
( NUMERIC(10,2), sqlite.SLNumeric(10,2), ),
( Numeric, sqlite.SLNumeric(), ),
( DECIMAL(10, 2), sqlite.SLNumeric(10, 2), ),
( Float, sqlite.SLNumeric(), ),
( sqlite.SLNumeric(), ),
- ( INT, sqlite.SLInteger(), ),
- ( Integer, sqlite.SLInteger(), ),
- ( sqlite.SLInteger(), ),
( TIMESTAMP, sqlite.SLDateTime(), ),
( DATETIME, sqlite.SLDateTime(), ),
( DateTime, sqlite.SLDateTime(), ),
finally:
db.execute('DROP VIEW types_v')
finally:
- m.drop_all()
+ pass
+ #m.drop_all()
class TestDefaults(TestBase, AssertsExecutionResults):
from testlib.sa import MetaData, Table, Column, Integer, String
import testlib.sa as tsa
from testlib import TestBase, testing, engines
-
+from testlib.testing import AssertsCompiledSQL
class DDLEventTest(TestBase):
class Canary(object):
r = eval(py)
assert list(r) == [(1,)], py
-class DDLTest(TestBase):
+class DDLTest(TestBase, AssertsCompiledSQL):
def mock_engine(self):
executor = lambda *a, **kw: None
engine = create_engine(testing.db.name + '://',
ddl = DDL('%(schema)s-%(table)s-%(fullname)s')
- self.assertEquals(ddl._expand(sane_alone, bind), '-t-t')
- self.assertEquals(ddl._expand(sane_schema, bind), 's-t-s.t')
- self.assertEquals(ddl._expand(insane_alone, bind), '-"t t"-"t t"')
- self.assertEquals(ddl._expand(insane_schema, bind),
- '"s s"-"t t"-"s s"."t t"')
+ dialect = bind.dialect
+ self.assert_compile(ddl.against(sane_alone), '-t-t', dialect=dialect)
+ self.assert_compile(ddl.against(sane_schema), 's-t-s.t', dialect=dialect)
+ self.assert_compile(ddl.against(insane_alone), '-"t t"-"t t"', dialect=dialect)
+ self.assert_compile(ddl.against(insane_schema), '"s s"-"t t"-"s s"."t t"', dialect=dialect)
# overrides are used piece-meal and verbatim.
ddl = DDL('%(schema)s-%(table)s-%(fullname)s-%(bonus)s',
context={'schema':'S S', 'table': 'T T', 'bonus': 'b'})
- self.assertEquals(ddl._expand(sane_alone, bind), 'S S-T T-t-b')
- self.assertEquals(ddl._expand(sane_schema, bind), 'S S-T T-s.t-b')
- self.assertEquals(ddl._expand(insane_alone, bind), 'S S-T T-"t t"-b')
- self.assertEquals(ddl._expand(insane_schema, bind),
- 'S S-T T-"s s"."t t"-b')
+
+ self.assert_compile(ddl.against(sane_alone), 'S S-T T-t-b', dialect=dialect)
+ self.assert_compile(ddl.against(sane_schema), 'S S-T T-s.t-b', dialect=dialect)
+ self.assert_compile(ddl.against(insane_alone), 'S S-T T-"t t"-b', dialect=dialect)
+ self.assert_compile(ddl.against(insane_schema), 'S S-T T-"s s"."t t"-b', dialect=dialect)
+
def test_filter(self):
cx = self.mock_engine()
import testenv; testenv.configure_for_tests()
import StringIO, unicodedata
import sqlalchemy as sa
+from sqlalchemy import schema
from testlib.sa import MetaData, Table, Column
from testlib import TestBase, ComparesTables, testing, engines, sa as tsa
self.assert_tables_equal(users, reflected_users)
self.assert_tables_equal(addresses, reflected_addresses)
finally:
- addresses.drop()
- users.drop()
+ meta.drop_all()
def test_include_columns(self):
meta = MetaData(testing.db)
t = Table("test", meta,
Column('foo', sa.DateTime))
- import sys
- dialect_module = sys.modules[testing.db.dialect.__module__]
-
- # we're relying on the presence of "ischema_names" in the
- # dialect module, else we can't test this. we need to be able
- # to get the dialect to not be aware of some type so we temporarily
- # monkeypatch. not sure what a better way for this could be,
- # except for an established dialect hook or dialect-specific tests
- if not hasattr(dialect_module, 'ischema_names'):
- return
-
- ischema_names = dialect_module.ischema_names
+ ischema_names = testing.db.dialect.ischema_names
t.create()
- dialect_module.ischema_names = {}
+ testing.db.dialect.ischema_names = {}
try:
m2 = MetaData(testing.db)
self.assertRaises(tsa.exc.SAWarning, Table, "test", m2, autoload=True)
assert t3.c.foo.type.__class__ == sa.types.NullType
finally:
- dialect_module.ischema_names = ischema_names
+ testing.db.dialect.ischema_names = ischema_names
t.drop()
def test_basic_override(self):
r.drop_all()
r.create_all()
finally:
- metadata.drop_all()
- bind.dispose()
+ pass
+# metadata.drop_all()
+# bind.dispose()
class SchemaTest(TestBase):
Column('col1', sa.Integer, primary_key=True),
Column('col2', sa.Integer, sa.ForeignKey('someschema.table1.col1')),
schema='someschema')
- # ensure this doesnt crash
- print [t for t in metadata.sorted_tables]
- buf = StringIO.StringIO()
- def foo(s, p=None):
- buf.write(s)
- gen = sa.create_engine(testing.db.name + "://", strategy="mock", executor=foo)
- gen = gen.dialect.schemagenerator(gen.dialect, gen)
- gen.traverse(table1)
- gen.traverse(table2)
- buf = buf.getvalue()
- print buf
+
+ t1 = str(schema.CreateTable(table1).compile(bind=testing.db))
+ t2 = str(schema.CreateTable(table2).compile(bind=testing.db))
if testing.db.dialect.preparer(testing.db.dialect).omit_schema:
- assert buf.index("CREATE TABLE table1") > -1
- assert buf.index("CREATE TABLE table2") > -1
+ assert t1.index("CREATE TABLE table1") > -1
+ assert t2.index("CREATE TABLE table2") > -1
else:
- assert buf.index("CREATE TABLE someschema.table1") > -1
- assert buf.index("CREATE TABLE someschema.table2") > -1
+ assert t1.index("CREATE TABLE someschema.table1") > -1
+ assert t2.index("CREATE TABLE someschema.table2") > -1
@testing.crashes('firebird', 'No schema support')
@testing.fails_on('sqlite', 'FIXME: unknown')
import testenv; testenv.configure_for_tests()
import doctest, sys
-
from testlib import sa_unittest as unittest
class User(Base):
id = Column('id', Integer, primary_key=True)
self.assertRaisesMessage(sa.exc.InvalidRequestError, "does not have a __table__", go)
-
- def test_cant_add_columns(self):
- t = Table('t', Base.metadata, Column('id', Integer, primary_key=True))
- def go():
- class User(Base):
- __table__ = t
- foo = Column(Integer, primary_key=True)
- self.assertRaisesMessage(sa.exc.ArgumentError, "add additional columns", go)
-
- def test_undefer_column_name(self):
- # TODO: not sure if there was an explicit
- # test for this elsewhere
- foo = Column(Integer)
- eq_(str(foo), '(no name)')
- eq_(foo.key, None)
- eq_(foo.name, None)
- decl._undefer_column_name('foo', foo)
- eq_(str(foo), 'foo')
- eq_(foo.key, 'foo')
- eq_(foo.name, 'foo')
def test_recompile_on_othermapper(self):
"""declarative version of the same test in mappers.py"""
assert list(sess.execute(t5.select(), mapper=T5)) == [(2, 'some other t5')]
assert list(sess.execute(t6.select(), mapper=T5)) == [(1, 'some other t6', 2)]
-class InheritingRowSwitchTest(_base.MappedTest):
- def define_tables(self, metadata):
- Table('parent', metadata,
- Column('id', Integer, primary_key=True),
- Column('pdata', String(30))
- )
- Table('child', metadata,
- Column('id', Integer, primary_key=True),
- Column('pid', Integer, ForeignKey('parent.id')),
- Column('cdata', String(30))
- )
-
- def setup_classes(self):
- class P(_base.ComparableEntity):
- pass
-
- class C(P):
- pass
-
- @testing.resolve_artifact_names
- def test_row_switch_no_child_table(self):
- mapper(P, parent)
- mapper(C, child, inherits=P)
-
- sess = create_session()
- c1 = C(id=1, pdata='c1', cdata='c1')
- sess.add(c1)
- sess.flush()
-
- # establish a row switch between c1 and c2.
- # c2 has no value for the "child" table
- c2 = C(id=1, pdata='c2')
- sess.add(c2)
- sess.delete(c1)
-
- self.assert_sql_execution(testing.db, sess.flush,
- CompiledSQL("UPDATE parent SET pdata=:pdata WHERE parent.id = :parent_id",
- {'pdata':'c2', 'parent_id':1}
- )
- )
-
-
-
class TransactionTest(_base.MappedTest):
__requires__ = ('deferrable_constraints',)
import testenv; testenv.configure_for_tests()
from sqlalchemy import *
-from sqlalchemy import exc
+from sqlalchemy import exc, schema
from testlib import *
from testlib import config, engines
+from sqlalchemy.engine import ddl
+from testlib.testing import eq_
+from testlib.assertsql import AllOf, RegexSQL, ExactSQL, CompiledSQL
-class ConstraintTest(TestBase, AssertsExecutionResults):
+class ConstraintTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL):
def setUp(self):
global metadata
Column('y', Integer, f)
)
-
def test_circular_constraint(self):
a = Table("a", metadata,
Column('id', Integer, primary_key=True),
metadata.create_all()
foo.insert().execute(id=1,x=9,y=5)
- try:
- foo.insert().execute(id=2,x=5,y=9)
- assert False
- except exc.SQLError:
- assert True
-
+ self.assertRaises(exc.SQLError, foo.insert().execute, id=2,x=5,y=9)
bar.insert().execute(id=1,x=10)
- try:
- bar.insert().execute(id=2,x=5)
- assert False
- except exc.SQLError:
- assert True
+ self.assertRaises(exc.SQLError, bar.insert().execute, id=2,x=5)
def test_unique_constraint(self):
foo = Table('foo', metadata,
foo.insert().execute(id=2, value='value2')
bar.insert().execute(id=1, value='a', value2='a')
bar.insert().execute(id=2, value='a', value2='b')
- try:
- foo.insert().execute(id=3, value='value1')
- assert False
- except exc.SQLError:
- assert True
- try:
- bar.insert().execute(id=3, value='a', value2='b')
- assert False
- except exc.SQLError:
- assert True
+ self.assertRaises(exc.SQLError, foo.insert().execute, id=3, value='value1')
+ self.assertRaises(exc.SQLError, bar.insert().execute, id=3, value='a', value2='b')
def test_index_create(self):
employees = Table('employees', metadata,
Index('sport_announcer', events.c.sport, events.c.announcer, unique=True)
Index('idx_winners', events.c.winner)
- index_names = [ ix.name for ix in events.indexes ]
- assert 'ix_events_name' in index_names
- assert 'ix_events_location' in index_names
- assert 'sport_announcer' in index_names
- assert 'idx_winners' in index_names
- assert len(index_names) == 4
-
- capt = []
- connection = testing.db.connect()
- # TODO: hacky, put a real connection proxy in
- ex = connection._Connection__execute_context
- def proxy(context):
- capt.append(context.statement)
- capt.append(repr(context.parameters))
- ex(context)
- connection._Connection__execute_context = proxy
- schemagen = testing.db.dialect.schemagenerator(testing.db.dialect, connection)
- schemagen.traverse(events)
-
- assert capt[0].strip().startswith('CREATE TABLE events')
-
- s = set([capt[x].strip() for x in [2,4,6,8]])
-
- assert s == set([
- 'CREATE UNIQUE INDEX ix_events_name ON events (name)',
- 'CREATE INDEX ix_events_location ON events (location)',
- 'CREATE UNIQUE INDEX sport_announcer ON events (sport, announcer)',
- 'CREATE INDEX idx_winners ON events (winner)'
- ])
+ eq_(
+ set([ ix.name for ix in events.indexes ]),
+ set(['ix_events_name', 'ix_events_location', 'sport_announcer', 'idx_winners'])
+ )
+
+ self.assert_sql_execution(
+ testing.db,
+ lambda: events.create(testing.db),
+ RegexSQL("^CREATE TABLE events"),
+ AllOf(
+ ExactSQL('CREATE UNIQUE INDEX ix_events_name ON events (name)'),
+ ExactSQL('CREATE INDEX ix_events_location ON events (location)'),
+ ExactSQL('CREATE UNIQUE INDEX sport_announcer ON events (sport, announcer)'),
+ ExactSQL('CREATE INDEX idx_winners ON events (winner)')
+ )
+ )
# verify that the table is functional
events.insert().execute(id=1, name='hockey finals', location='rink',
dialect = testing.db.dialect.__class__()
dialect.max_identifier_length = 20
- schemagen = dialect.schemagenerator(dialect, None)
- schemagen.execute = lambda : None
-
t1 = Table("sometable", MetaData(), Column("foo", Integer))
- schemagen.visit_index(Index("this_name_is_too_long_for_what_were_doing", t1.c.foo))
- self.assertEquals(schemagen.buffer.getvalue(), "CREATE INDEX this_name_is_t_1 ON sometable (foo)")
- schemagen.buffer.truncate(0)
- schemagen.visit_index(Index("this_other_name_is_too_long_for_what_were_doing", t1.c.foo))
- self.assertEquals(schemagen.buffer.getvalue(), "CREATE INDEX this_other_nam_2 ON sometable (foo)")
-
- schemadrop = dialect.schemadropper(dialect, None)
- schemadrop.execute = lambda: None
- self.assertRaises(exc.IdentifierError, schemadrop.visit_index, Index("this_name_is_too_long_for_what_were_doing", t1.c.foo))
+ self.assert_compile(
+ schema.CreateIndex(Index("this_name_is_too_long_for_what_were_doing", t1.c.foo)),
+ "CREATE INDEX this_name_is_t_1 ON sometable (foo)",
+ dialect=dialect
+ )
+
+ self.assert_compile(
+ schema.CreateIndex(Index("this_other_name_is_too_long_for_what_were_doing", t1.c.foo)),
+ "CREATE INDEX this_other_nam_1 ON sometable (foo)",
+ dialect=dialect
+ )
-class ConstraintCompilationTest(TestBase, AssertsExecutionResults):
- class accum(object):
- def __init__(self):
- self.statements = []
- def __call__(self, sql, *a, **kw):
- self.statements.append(sql)
- def __contains__(self, substring):
- for s in self.statements:
- if substring in s:
- return True
- return False
- def __str__(self):
- return '\n'.join([repr(x) for x in self.statements])
- def clear(self):
- del self.statements[:]
-
- def setUp(self):
- self.sql = self.accum()
- opts = config.db_opts.copy()
- opts['strategy'] = 'mock'
- opts['executor'] = self.sql
- self.engine = engines.testing_engine(options=opts)
-
+class ConstraintCompilationTest(TestBase):
def _test_deferrable(self, constraint_factory):
- meta = MetaData(self.engine)
- t = Table('tbl', meta,
+ t = Table('tbl', MetaData(),
Column('a', Integer),
Column('b', Integer),
constraint_factory(deferrable=True))
- t.create()
- assert 'DEFERRABLE' in self.sql, self.sql
- assert 'NOT DEFERRABLE' not in self.sql, self.sql
- self.sql.clear()
- meta.clear()
-
- t = Table('tbl', meta,
+
+ sql = str(schema.CreateTable(t).compile(bind=testing.db))
+ assert 'DEFERRABLE' in sql, sql
+ assert 'NOT DEFERRABLE' not in sql, sql
+
+ t = Table('tbl', MetaData(),
Column('a', Integer),
Column('b', Integer),
constraint_factory(deferrable=False))
- t.create()
- assert 'NOT DEFERRABLE' in self.sql
- self.sql.clear()
- meta.clear()
- t = Table('tbl', meta,
+ sql = str(schema.CreateTable(t).compile(bind=testing.db))
+ assert 'NOT DEFERRABLE' in sql
+
+
+ t = Table('tbl', MetaData(),
Column('a', Integer),
Column('b', Integer),
constraint_factory(deferrable=True, initially='IMMEDIATE'))
- t.create()
- assert 'NOT DEFERRABLE' not in self.sql
- assert 'INITIALLY IMMEDIATE' in self.sql
- self.sql.clear()
- meta.clear()
+ sql = str(schema.CreateTable(t).compile(bind=testing.db))
+ assert 'NOT DEFERRABLE' not in sql
+ assert 'INITIALLY IMMEDIATE' in sql
- t = Table('tbl', meta,
+ t = Table('tbl', MetaData(),
Column('a', Integer),
Column('b', Integer),
constraint_factory(deferrable=True, initially='DEFERRED'))
- t.create()
+ sql = str(schema.CreateTable(t).compile(bind=testing.db))
- assert 'NOT DEFERRABLE' not in self.sql
- assert 'INITIALLY DEFERRED' in self.sql, self.sql
+ assert 'NOT DEFERRABLE' not in sql
+ assert 'INITIALLY DEFERRED' in sql
def test_deferrable_pk(self):
factory = lambda **kw: PrimaryKeyConstraint('a', **kw)
self._test_deferrable(factory)
def test_deferrable_column_fk(self):
- meta = MetaData(self.engine)
- t = Table('tbl', meta,
+ t = Table('tbl', MetaData(),
Column('a', Integer),
Column('b', Integer,
ForeignKey('tbl.a', deferrable=True,
initially='DEFERRED')))
- t.create()
- assert 'DEFERRABLE' in self.sql, self.sql
- assert 'INITIALLY DEFERRED' in self.sql, self.sql
+
+ sql = str(schema.CreateTable(t).compile(bind=testing.db))
+ assert 'DEFERRABLE' in sql
+ assert 'INITIALLY DEFERRED' in sql
def test_deferrable_unique(self):
factory = lambda **kw: UniqueConstraint('b', **kw)
self._test_deferrable(factory)
def test_deferrable_column_check(self):
- meta = MetaData(self.engine)
- t = Table('tbl', meta,
+ t = Table('tbl', MetaData(),
Column('a', Integer),
Column('b', Integer,
CheckConstraint('a < b',
deferrable=True,
initially='DEFERRED')))
- t.create()
- assert 'DEFERRABLE' in self.sql, self.sql
- assert 'INITIALLY DEFERRED' in self.sql, self.sql
+ sql = str(schema.CreateTable(t).compile(bind=testing.db))
+ assert 'DEFERRABLE' in sql
+ assert 'INITIALLY DEFERRED' in sql
if __name__ == "__main__":
from sqlalchemy.sql import table, column, label, compiler
from sqlalchemy.sql.expression import ClauseList
from sqlalchemy.engine import default
-from sqlalchemy.databases import sqlite, postgres, mysql, oracle, firebird, mssql
+from sqlalchemy.databases import mysql, oracle, firebird, mssql
+from sqlalchemy.dialects.sqlite import pysqlite as sqlite
+from sqlalchemy.dialects.postgres import psycopg2 as postgres
from testlib import *
table1 = table('mytable',
import testenv; testenv.configure_for_tests()
import datetime, os, pickleable, re
from sqlalchemy import *
-from sqlalchemy import exc, types, util
+from sqlalchemy import exc, types, util, schema
from sqlalchemy.sql import operators
from testlib.testing import eq_
import sqlalchemy.engine.url as url
-from sqlalchemy.databases import mssql, oracle, mysql, postgres, firebird
+from sqlalchemy.databases import mssql, oracle, mysql, firebird
+from sqlalchemy.dialects.sqlite import pysqlite as sqlite
+from sqlalchemy.dialects.postgres import psycopg2 as postgres
+
from testlib import *
(mysql_dialect, Unicode(), mysql.MSString),
(mysql_dialect, UnicodeText(), mysql.MSText),
(mysql_dialect, NCHAR(), mysql.MSNChar),
- (postgres_dialect, String(), postgres.PGString),
- (postgres_dialect, VARCHAR(), postgres.PGString),
- (postgres_dialect, String(50), postgres.PGString),
- (postgres_dialect, Unicode(), postgres.PGString),
- (postgres_dialect, UnicodeText(), postgres.PGText),
- (postgres_dialect, NCHAR(), postgres.PGString),
+ (postgres_dialect, String(), String),
+ (postgres_dialect, VARCHAR(), String),
+ (postgres_dialect, String(50), String),
+ (postgres_dialect, Unicode(), String),
+ (postgres_dialect, UnicodeText(), Text),
+ (postgres_dialect, NCHAR(), String),
(firebird_dialect, String(), firebird.FBString),
(firebird_dialect, VARCHAR(), firebird.FBString),
(firebird_dialect, String(50), firebird.FBString),
class UserDefinedTest(TestBase):
"""tests user-defined types."""
- def testbasic(self):
- print users.c.goofy4.type
- print users.c.goofy4.type.dialect_impl(testing.db.dialect)
- print users.c.goofy4.type.dialect_impl(testing.db.dialect).get_col_spec()
-
def testprocessing(self):
global users
def setUpAll(self):
global users, metadata
- class MyType(types.TypeEngine):
+ class MyType(types.UserDefinedType):
def get_col_spec(self):
return "VARCHAR(100)"
def bind_processor(self, dialect):
for key, value in expectedResults.items():
expectedResults[key] = '%s NULL' % value
- print db.engine.__module__
testTable = Table('testColumns', MetaData(db),
Column('int_column', Integer),
Column('smallint_column', SmallInteger),
for aCol in testTable.c:
self.assertEquals(
expectedResults[aCol.name],
- db.dialect.schemagenerator(db.dialect, db, None, None).\
+ db.dialect.ddl_compiler(db.dialect, schema.CreateTable(testTable)).\
get_column_specification(aCol))
class UnicodeTest(TestBase, AssertsExecutionResults):
def setUpAll(self):
global test_table, meta
- class MyCustomType(types.TypeEngine):
+ class MyCustomType(types.UserDefinedType):
def get_col_spec(self):
return "INT"
def bind_processor(self, dialect):
from decimal import Decimal
numeric_table.insert().execute(
numericcol=3.5, floatcol=5.6, ncasdec=12.4, fcasdec=15.75)
+
numeric_table.insert().execute(
numericcol=Decimal("3.5"), floatcol=Decimal("5.6"),
ncasdec=Decimal("12.4"), fcasdec=Decimal("15.75"))
eq_(n2.scale, 12, dialect.name)
# test colspec generates successfully using 'scale'
- assert n2.get_col_spec()
+ assert dialect.type_compiler.process(n2)
# test constructor of the dialect-specific type
n3 = n2.__class__(scale=5)
for name in d.__all__:
mod = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name)
yield mod.dialect()
+ import sqlalchemy.dialects as d
+ for name in d.__all__:
+ mod = getattr(__import__('sqlalchemy.dialects.%s.base' % name).dialects, name).base
+ yield mod.dialect()
class ReconnectFixture(object):
def __init__(self, dbapi):
if dialect is None:
dialect = getattr(self, '__dialect__', None)
- if params is None:
- keys = None
- else:
- keys = params.keys()
+ kw = {}
+ if params is not None:
+ kw['column_keys'] = params.keys()
- c = clause.compile(column_keys=keys, dialect=dialect)
+ c = clause.compile(dialect=dialect, **kw)
- print "\nSQL String:\n" + str(c) + repr(c.params)
+ print "\nSQL String:\n" + str(c) + repr(getattr(c, 'params', {}))
cc = re.sub(r'\n', '', str(c))