From: Mike Bayer Date: Mon, 19 Jan 2009 18:24:40 +0000 (+0000) Subject: - moved all the dialects over to their final positions X-Git-Tag: rel_0_6_6~331 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=811a5cb3274b191872f5776b56c7ec7fa53dca37;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - moved all the dialects over to their final positions - structured maxdb, sybase, informix dialects. obviously no testing has been done. --- diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py index 5ef4ff6d84..35cdaa31ba 100644 --- a/lib/sqlalchemy/__init__.py +++ b/lib/sqlalchemy/__init__.py @@ -25,9 +25,11 @@ from sqlalchemy.types import ( FLOAT, Float, INT, + INTEGER, Integer, Interval, NCHAR, + NVARCHAR, NUMERIC, Numeric, PickleType, diff --git a/lib/sqlalchemy/connectors/mxodbc.py b/lib/sqlalchemy/connectors/mxodbc.py new file mode 100644 index 0000000000..a0f3f02161 --- /dev/null +++ b/lib/sqlalchemy/connectors/mxodbc.py @@ -0,0 +1,24 @@ +from sqlalchemy.connectors import Connector + +class MxODBCConnector(Connector): + driver='mxodbc' + supports_sane_rowcount = False + supports_sane_multi_rowcount = False + supports_unicode_statements = False + supports_unicode_binds = False + + @classmethod + def import_dbapi(cls): + import mxODBC as module + return module + + def create_connect_args(self, url): + '''Return a tuple of *args,**kwargs''' + # FIXME: handle mx.odbc.Windows proprietary args + opts = url.translate_connect_args(username='user') + opts.update(url.query) + argsDict = {} + argsDict['user'] = opts['user'] + argsDict['password'] = opts['password'] + connArgs = [[opts['dsn']], argsDict] + return connArgs diff --git a/lib/sqlalchemy/databases/__init__.py b/lib/sqlalchemy/databases/__init__.py index b45ea73d66..48435770da 100644 --- a/lib/sqlalchemy/databases/__init__.py +++ b/lib/sqlalchemy/databases/__init__.py @@ -7,6 +7,15 @@ from sqlalchemy.dialects.sqlite import base as sqlite from sqlalchemy.dialects.postgres import base as postgres from sqlalchemy.dialects.mysql import base as mysql +from sqlalchemy.dialects.oracle import base as oracle +from sqlalchemy.dialects.firebird import base as firebird +from sqlalchemy.dialects.maxdb import base as maxdb +from sqlalchemy.dialects.informix import base as informix +from sqlalchemy.dialects.mssql import base as mssql +from sqlalchemy.dialects.access import base as access +from sqlalchemy.dialects.sybase import base as sybase + + __all__ = ( diff --git a/lib/sqlalchemy/databases/mxODBC.py b/lib/sqlalchemy/databases/mxODBC.py deleted file mode 100644 index 92f533633c..0000000000 --- a/lib/sqlalchemy/databases/mxODBC.py +++ /dev/null @@ -1,60 +0,0 @@ -# mxODBC.py -# Copyright (C) 2007 Fisch Asset Management AG http://www.fam.ch -# Coding: Alexander Houben alexander.houben@thor-solutions.ch -# -# This module is part of SQLAlchemy and is released under -# the MIT License: http://www.opensource.org/licenses/mit-license.php - -""" -A wrapper for a mx.ODBC.Windows DB-API connection. - -Makes sure the mx module is configured to return datetime objects instead -of mx.DateTime.DateTime objects. -""" - -from mx.ODBC.Windows import * - - -class Cursor: - def __init__(self, cursor): - self.cursor = cursor - - def __getattr__(self, attr): - res = getattr(self.cursor, attr) - return res - - def execute(self, *args, **kwargs): - res = self.cursor.execute(*args, **kwargs) - return res - - -class Connection: - def myErrorHandler(self, connection, cursor, errorclass, errorvalue): - err0, err1, err2, err3 = errorvalue - #print ", ".join(["Err%d: %s"%(x, errorvalue[x]) for x in range(4)]) - if int(err1) == 109: - # Ignore "Null value eliminated in aggregate function", this is not an error - return - raise errorclass, errorvalue - - def __init__(self, conn): - self.conn = conn - # install a mx ODBC error handler - self.conn.errorhandler = self.myErrorHandler - - def __getattr__(self, attr): - res = getattr(self.conn, attr) - return res - - def cursor(self, *args, **kwargs): - res = Cursor(self.conn.cursor(*args, **kwargs)) - return res - - -# override 'connect' call -def connect(*args, **kwargs): - import mx.ODBC.Windows - conn = mx.ODBC.Windows.Connect(*args, **kwargs) - conn.datetimeformat = mx.ODBC.Windows.PYDATETIME_DATETIMEFORMAT - return Connection(conn) -Connect = connect diff --git a/lib/sqlalchemy/databases/sybase.py b/lib/sqlalchemy/databases/sybase.py deleted file mode 100644 index 0cf0eeaf56..0000000000 --- a/lib/sqlalchemy/databases/sybase.py +++ /dev/null @@ -1,863 +0,0 @@ -# sybase.py -# Copyright (C) 2007 Fisch Asset Management AG http://www.fam.ch -# Coding: Alexander Houben alexander.houben@thor-solutions.ch -# -# This module is part of SQLAlchemy and is released under -# the MIT License: http://www.opensource.org/licenses/mit-license.php - -""" -Sybase database backend. - -Known issues / TODO: - - * Uses the mx.ODBC driver from egenix (version 2.1.0) - * The current version of sqlalchemy.databases.sybase only supports - mx.ODBC.Windows (other platforms such as mx.ODBC.unixODBC still need - some development) - * Support for pyodbc has been built in but is not yet complete (needs - further development) - * Results of running tests/alltests.py: - Ran 934 tests in 287.032s - FAILED (failures=3, errors=1) - * Tested on 'Adaptive Server Anywhere 9' (version 9.0.1.1751) -""" - -import datetime, operator - -from sqlalchemy import util, sql, schema, exc -from sqlalchemy.sql import compiler, expression -from sqlalchemy.engine import default, base -from sqlalchemy import types as sqltypes -from sqlalchemy.sql import operators as sql_operators -from sqlalchemy import MetaData, Table, Column -from sqlalchemy import String, Integer, SMALLINT, CHAR, ForeignKey - - -__all__ = [ - 'SybaseTypeError' - 'SybaseNumeric', 'SybaseFloat', 'SybaseInteger', 'SybaseBigInteger', - 'SybaseTinyInteger', 'SybaseSmallInteger', - 'SybaseDateTime_mxodbc', 'SybaseDateTime_pyodbc', - 'SybaseDate_mxodbc', 'SybaseDate_pyodbc', - 'SybaseTime_mxodbc', 'SybaseTime_pyodbc', - 'SybaseText', 'SybaseString', 'SybaseChar', 'SybaseBinary', - 'SybaseBoolean', 'SybaseTimeStamp', 'SybaseMoney', 'SybaseSmallMoney', - 'SybaseUniqueIdentifier', - ] - - -RESERVED_WORDS = set([ - "add", "all", "alter", "and", - "any", "as", "asc", "backup", - "begin", "between", "bigint", "binary", - "bit", "bottom", "break", "by", - "call", "capability", "cascade", "case", - "cast", "char", "char_convert", "character", - "check", "checkpoint", "close", "comment", - "commit", "connect", "constraint", "contains", - "continue", "convert", "create", "cross", - "cube", "current", "current_timestamp", "current_user", - "cursor", "date", "dbspace", "deallocate", - "dec", "decimal", "declare", "default", - "delete", "deleting", "desc", "distinct", - "do", "double", "drop", "dynamic", - "else", "elseif", "encrypted", "end", - "endif", "escape", "except", "exception", - "exec", "execute", "existing", "exists", - "externlogin", "fetch", "first", "float", - "for", "force", "foreign", "forward", - "from", "full", "goto", "grant", - "group", "having", "holdlock", "identified", - "if", "in", "index", "index_lparen", - "inner", "inout", "insensitive", "insert", - "inserting", "install", "instead", "int", - "integer", "integrated", "intersect", "into", - "iq", "is", "isolation", "join", - "key", "lateral", "left", "like", - "lock", "login", "long", "match", - "membership", "message", "mode", "modify", - "natural", "new", "no", "noholdlock", - "not", "notify", "null", "numeric", - "of", "off", "on", "open", - "option", "options", "or", "order", - "others", "out", "outer", "over", - "passthrough", "precision", "prepare", "primary", - "print", "privileges", "proc", "procedure", - "publication", "raiserror", "readtext", "real", - "reference", "references", "release", "remote", - "remove", "rename", "reorganize", "resource", - "restore", "restrict", "return", "revoke", - "right", "rollback", "rollup", "save", - "savepoint", "scroll", "select", "sensitive", - "session", "set", "setuser", "share", - "smallint", "some", "sqlcode", "sqlstate", - "start", "stop", "subtrans", "subtransaction", - "synchronize", "syntax_error", "table", "temporary", - "then", "time", "timestamp", "tinyint", - "to", "top", "tran", "trigger", - "truncate", "tsequal", "unbounded", "union", - "unique", "unknown", "unsigned", "update", - "updating", "user", "using", "validate", - "values", "varbinary", "varchar", "variable", - "varying", "view", "wait", "waitfor", - "when", "where", "while", "window", - "with", "with_cube", "with_lparen", "with_rollup", - "within", "work", "writetext", - ]) - -ischema = MetaData() - -tables = Table("SYSTABLE", ischema, - Column("table_id", Integer, primary_key=True), - Column("file_id", SMALLINT), - Column("table_name", CHAR(128)), - Column("table_type", CHAR(10)), - Column("creator", Integer), - #schema="information_schema" - ) - -domains = Table("SYSDOMAIN", ischema, - Column("domain_id", Integer, primary_key=True), - Column("domain_name", CHAR(128)), - Column("type_id", SMALLINT), - Column("precision", SMALLINT, quote=True), - #schema="information_schema" - ) - -columns = Table("SYSCOLUMN", ischema, - Column("column_id", Integer, primary_key=True), - Column("table_id", Integer, ForeignKey(tables.c.table_id)), - Column("pkey", CHAR(1)), - Column("column_name", CHAR(128)), - Column("nulls", CHAR(1)), - Column("width", SMALLINT), - Column("domain_id", SMALLINT, ForeignKey(domains.c.domain_id)), - # FIXME: should be mx.BIGINT - Column("max_identity", Integer), - # FIXME: should be mx.ODBC.Windows.LONGVARCHAR - Column("default", String), - Column("scale", Integer), - #schema="information_schema" - ) - -foreignkeys = Table("SYSFOREIGNKEY", ischema, - Column("foreign_table_id", Integer, ForeignKey(tables.c.table_id), primary_key=True), - Column("foreign_key_id", SMALLINT, primary_key=True), - Column("primary_table_id", Integer, ForeignKey(tables.c.table_id)), - #schema="information_schema" - ) -fkcols = Table("SYSFKCOL", ischema, - Column("foreign_table_id", Integer, ForeignKey(columns.c.table_id), primary_key=True), - Column("foreign_key_id", SMALLINT, ForeignKey(foreignkeys.c.foreign_key_id), primary_key=True), - Column("foreign_column_id", Integer, ForeignKey(columns.c.column_id), primary_key=True), - Column("primary_column_id", Integer), - #schema="information_schema" - ) - -class SybaseTypeError(sqltypes.TypeEngine): - def result_processor(self, dialect): - return None - - def bind_processor(self, dialect): - def process(value): - raise exc.InvalidRequestError("Data type not supported", [value]) - return process - - def get_col_spec(self): - raise exc.CompileError("Data type not supported") - -class SybaseNumeric(sqltypes.Numeric): - def get_col_spec(self): - if self.scale is None: - if self.precision is None: - return "NUMERIC" - else: - return "NUMERIC(%(precision)s)" % {'precision' : self.precision} - else: - return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale} - -class SybaseFloat(sqltypes.FLOAT, SybaseNumeric): - def __init__(self, precision = 10, asdecimal = False, scale = 2, **kwargs): - super(sqltypes.FLOAT, self).__init__(precision, asdecimal, **kwargs) - self.scale = scale - - def get_col_spec(self): - # if asdecimal is True, handle same way as SybaseNumeric - if self.asdecimal: - return SybaseNumeric.get_col_spec(self) - if self.precision is None: - return "FLOAT" - else: - return "FLOAT(%(precision)s)" % {'precision': self.precision} - - def result_processor(self, dialect): - def process(value): - if value is None: - return None - return float(value) - if self.asdecimal: - return SybaseNumeric.result_processor(self, dialect) - return process - -class SybaseInteger(sqltypes.Integer): - def get_col_spec(self): - return "INTEGER" - -class SybaseBigInteger(SybaseInteger): - def get_col_spec(self): - return "BIGINT" - -class SybaseTinyInteger(SybaseInteger): - def get_col_spec(self): - return "TINYINT" - -class SybaseSmallInteger(SybaseInteger): - def get_col_spec(self): - return "SMALLINT" - -class SybaseDateTime_mxodbc(sqltypes.DateTime): - def __init__(self, *a, **kw): - super(SybaseDateTime_mxodbc, self).__init__(False) - - def get_col_spec(self): - return "DATETIME" - -class SybaseDateTime_pyodbc(sqltypes.DateTime): - def __init__(self, *a, **kw): - super(SybaseDateTime_pyodbc, self).__init__(False) - - def get_col_spec(self): - return "DATETIME" - - def result_processor(self, dialect): - def process(value): - if value is None: - return None - # Convert the datetime.datetime back to datetime.time - return value - return process - - def bind_processor(self, dialect): - def process(value): - if value is None: - return None - return value - return process - -class SybaseDate_mxodbc(sqltypes.Date): - def __init__(self, *a, **kw): - super(SybaseDate_mxodbc, self).__init__(False) - - def get_col_spec(self): - return "DATE" - -class SybaseDate_pyodbc(sqltypes.Date): - def __init__(self, *a, **kw): - super(SybaseDate_pyodbc, self).__init__(False) - - def get_col_spec(self): - return "DATE" - -class SybaseTime_mxodbc(sqltypes.Time): - def __init__(self, *a, **kw): - super(SybaseTime_mxodbc, self).__init__(False) - - def get_col_spec(self): - return "DATETIME" - - def result_processor(self, dialect): - def process(value): - if value is None: - return None - # Convert the datetime.datetime back to datetime.time - return datetime.time(value.hour, value.minute, value.second, value.microsecond) - return process - -class SybaseTime_pyodbc(sqltypes.Time): - def __init__(self, *a, **kw): - super(SybaseTime_pyodbc, self).__init__(False) - - def get_col_spec(self): - return "DATETIME" - - def result_processor(self, dialect): - def process(value): - if value is None: - return None - # Convert the datetime.datetime back to datetime.time - return datetime.time(value.hour, value.minute, value.second, value.microsecond) - return process - - def bind_processor(self, dialect): - def process(value): - if value is None: - return None - return datetime.datetime(1970, 1, 1, value.hour, value.minute, value.second, value.microsecond) - return process - -class SybaseText(sqltypes.Text): - def get_col_spec(self): - return "TEXT" - -class SybaseString(sqltypes.String): - def get_col_spec(self): - return "VARCHAR(%(length)s)" % {'length' : self.length} - -class SybaseChar(sqltypes.CHAR): - def get_col_spec(self): - return "CHAR(%(length)s)" % {'length' : self.length} - -class SybaseBinary(sqltypes.Binary): - def get_col_spec(self): - return "IMAGE" - -class SybaseBoolean(sqltypes.Boolean): - def get_col_spec(self): - return "BIT" - - def result_processor(self, dialect): - def process(value): - if value is None: - return None - return value and True or False - return process - - def bind_processor(self, dialect): - def process(value): - if value is True: - return 1 - elif value is False: - return 0 - elif value is None: - return None - else: - return value and True or False - return process - -class SybaseTimeStamp(sqltypes.TIMESTAMP): - def get_col_spec(self): - return "TIMESTAMP" - -class SybaseMoney(sqltypes.TypeEngine): - def get_col_spec(self): - return "MONEY" - -class SybaseSmallMoney(SybaseMoney): - def get_col_spec(self): - return "SMALLMONEY" - -class SybaseUniqueIdentifier(sqltypes.TypeEngine): - def get_col_spec(self): - return "UNIQUEIDENTIFIER" - -class SybaseSQLExecutionContext(default.DefaultExecutionContext): - pass - -class SybaseSQLExecutionContext_mxodbc(SybaseSQLExecutionContext): - - def __init__(self, dialect, connection, compiled=None, statement=None, parameters=None): - super(SybaseSQLExecutionContext_mxodbc, self).__init__(dialect, connection, compiled, statement, parameters) - - def pre_exec(self): - super(SybaseSQLExecutionContext_mxodbc, self).pre_exec() - - def post_exec(self): - if self.compiled.isinsert: - table = self.compiled.statement.table - # get the inserted values of the primary key - - # get any sequence IDs first (using @@identity) - self.cursor.execute("SELECT @@identity AS lastrowid") - row = self.cursor.fetchone() - lastrowid = int(row[0]) - if lastrowid > 0: - # an IDENTITY was inserted, fetch it - # FIXME: always insert in front ? This only works if the IDENTITY is the first column, no ?! - if not hasattr(self, '_last_inserted_ids') or self._last_inserted_ids is None: - self._last_inserted_ids = [lastrowid] - else: - self._last_inserted_ids = [lastrowid] + self._last_inserted_ids[1:] - super(SybaseSQLExecutionContext_mxodbc, self).post_exec() - -class SybaseSQLExecutionContext_pyodbc(SybaseSQLExecutionContext): - def __init__(self, dialect, connection, compiled=None, statement=None, parameters=None): - super(SybaseSQLExecutionContext_pyodbc, self).__init__(dialect, connection, compiled, statement, parameters) - - def pre_exec(self): - super(SybaseSQLExecutionContext_pyodbc, self).pre_exec() - - def post_exec(self): - if self.compiled.isinsert: - table = self.compiled.statement.table - # get the inserted values of the primary key - - # get any sequence IDs first (using @@identity) - self.cursor.execute("SELECT @@identity AS lastrowid") - row = self.cursor.fetchone() - lastrowid = int(row[0]) - if lastrowid > 0: - # an IDENTITY was inserted, fetch it - # FIXME: always insert in front ? This only works if the IDENTITY is the first column, no ?! - if not hasattr(self, '_last_inserted_ids') or self._last_inserted_ids is None: - self._last_inserted_ids = [lastrowid] - else: - self._last_inserted_ids = [lastrowid] + self._last_inserted_ids[1:] - super(SybaseSQLExecutionContext_pyodbc, self).post_exec() - -class SybaseSQLDialect(default.DefaultDialect): - colspecs = { - # FIXME: unicode support - #sqltypes.Unicode : SybaseUnicode, - sqltypes.Integer : SybaseInteger, - sqltypes.SmallInteger : SybaseSmallInteger, - sqltypes.Numeric : SybaseNumeric, - sqltypes.Float : SybaseFloat, - sqltypes.String : SybaseString, - sqltypes.Binary : SybaseBinary, - sqltypes.Boolean : SybaseBoolean, - sqltypes.Text : SybaseText, - sqltypes.CHAR : SybaseChar, - sqltypes.TIMESTAMP : SybaseTimeStamp, - sqltypes.FLOAT : SybaseFloat, - } - - ischema_names = { - 'integer' : SybaseInteger, - 'unsigned int' : SybaseInteger, - 'unsigned smallint' : SybaseInteger, - 'unsigned bigint' : SybaseInteger, - 'bigint': SybaseBigInteger, - 'smallint' : SybaseSmallInteger, - 'tinyint' : SybaseTinyInteger, - 'varchar' : SybaseString, - 'long varchar' : SybaseText, - 'char' : SybaseChar, - 'decimal' : SybaseNumeric, - 'numeric' : SybaseNumeric, - 'float' : SybaseFloat, - 'double' : SybaseFloat, - 'binary' : SybaseBinary, - 'long binary' : SybaseBinary, - 'varbinary' : SybaseBinary, - 'bit': SybaseBoolean, - 'image' : SybaseBinary, - 'timestamp': SybaseTimeStamp, - 'money': SybaseMoney, - 'smallmoney': SybaseSmallMoney, - 'uniqueidentifier': SybaseUniqueIdentifier, - - 'java.lang.Object' : SybaseTypeError, - 'java serialization' : SybaseTypeError, - } - - name = 'sybase' - # Sybase backend peculiarities - supports_unicode_statements = False - supports_sane_rowcount = False - supports_sane_multi_rowcount = False - execution_ctx_cls = SybaseSQLExecutionContext - - def __new__(cls, dbapi=None, *args, **kwargs): - if cls != SybaseSQLDialect: - return super(SybaseSQLDialect, cls).__new__(cls, *args, **kwargs) - if dbapi: - print dbapi.__name__ - dialect = dialect_mapping.get(dbapi.__name__) - return dialect(*args, **kwargs) - else: - return object.__new__(cls, *args, **kwargs) - - def __init__(self, **params): - super(SybaseSQLDialect, self).__init__(**params) - self.text_as_varchar = False - # FIXME: what is the default schema for sybase connections (DBA?) ? - self.set_default_schema_name("dba") - - def dbapi(cls, module_name=None): - if module_name: - try: - dialect_cls = dialect_mapping[module_name] - return dialect_cls.import_dbapi() - except KeyError: - raise exc.InvalidRequestError("Unsupported SybaseSQL module '%s' requested (must be " + " or ".join([x for x in dialect_mapping.keys()]) + ")" % module_name) - else: - for dialect_cls in dialect_mapping.values(): - try: - return dialect_cls.import_dbapi() - except ImportError, e: - pass - else: - raise ImportError('No DBAPI module detected for SybaseSQL - please install mxodbc') - dbapi = classmethod(dbapi) - - def type_descriptor(self, typeobj): - newobj = sqltypes.adapt_type(typeobj, self.colspecs) - return newobj - - def last_inserted_ids(self): - return self.context.last_inserted_ids - - def get_default_schema_name(self, connection): - return self.schema_name - - def set_default_schema_name(self, schema_name): - self.schema_name = schema_name - - def do_execute(self, cursor, statement, params, **kwargs): - params = tuple(params) - super(SybaseSQLDialect, self).do_execute(cursor, statement, params, **kwargs) - - # FIXME: remove ? - def _execute(self, c, statement, parameters): - try: - if parameters == {}: - parameters = () - c.execute(statement, parameters) - self.context.rowcount = c.rowcount - c.DBPROP_COMMITPRESERVE = "Y" - except Exception, e: - raise exc.DBAPIError.instance(statement, parameters, e) - - def table_names(self, connection, schema): - """Ignore the schema and the charset for now.""" - s = sql.select([tables.c.table_name], - sql.not_(tables.c.table_name.like("SYS%")) and - tables.c.creator >= 100 - ) - rp = connection.execute(s) - return [row[0] for row in rp.fetchall()] - - def has_table(self, connection, tablename, schema=None): - # FIXME: ignore schemas for sybase - s = sql.select([tables.c.table_name], tables.c.table_name == tablename) - - c = connection.execute(s) - row = c.fetchone() - print "has_table: " + tablename + ": " + str(bool(row is not None)) - return row is not None - - def reflecttable(self, connection, table, include_columns): - # Get base columns - if table.schema is not None: - current_schema = table.schema - else: - current_schema = self.get_default_schema_name(connection) - - s = sql.select([columns, domains], tables.c.table_name==table.name, from_obj=[columns.join(tables).join(domains)], order_by=[columns.c.column_id]) - - c = connection.execute(s) - found_table = False - # makes sure we append the columns in the correct order - while True: - row = c.fetchone() - if row is None: - break - found_table = True - (name, type, nullable, charlen, numericprec, numericscale, default, primary_key, max_identity, table_id, column_id) = ( - row[columns.c.column_name], - row[domains.c.domain_name], - row[columns.c.nulls] == 'Y', - row[columns.c.width], - row[domains.c.precision], - row[columns.c.scale], - row[columns.c.default], - row[columns.c.pkey] == 'Y', - row[columns.c.max_identity], - row[tables.c.table_id], - row[columns.c.column_id], - ) - if include_columns and name not in include_columns: - continue - - # FIXME: else problems with SybaseBinary(size) - if numericscale == 0: - numericscale = None - - args = [] - for a in (charlen, numericprec, numericscale): - if a is not None: - args.append(a) - coltype = self.ischema_names.get(type, None) - if coltype == SybaseString and charlen == -1: - coltype = SybaseText() - else: - if coltype is None: - util.warn("Did not recognize type '%s' of column '%s'" % - (type, name)) - coltype = sqltypes.NULLTYPE - coltype = coltype(*args) - colargs = [] - if default is not None: - colargs.append(schema.DefaultClause(sql.text(default))) - - # any sequences ? - col = schema.Column(name, coltype, nullable=nullable, primary_key=primary_key, *colargs) - if int(max_identity) > 0: - col.sequence = schema.Sequence(name + '_identity') - col.sequence.start = int(max_identity) - col.sequence.increment = 1 - - # append the column - table.append_column(col) - - # any foreign key constraint for this table ? - # note: no multi-column foreign keys are considered - s = "select st1.table_name, sc1.column_name, st2.table_name, sc2.column_name from systable as st1 join sysfkcol on st1.table_id=sysfkcol.foreign_table_id join sysforeignkey join systable as st2 on sysforeignkey.primary_table_id = st2.table_id join syscolumn as sc1 on sysfkcol.foreign_column_id=sc1.column_id and sc1.table_id=st1.table_id join syscolumn as sc2 on sysfkcol.primary_column_id=sc2.column_id and sc2.table_id=st2.table_id where st1.table_name='%(table_name)s';" % { 'table_name' : table.name } - c = connection.execute(s) - foreignKeys = {} - while True: - row = c.fetchone() - if row is None: - break - (foreign_table, foreign_column, primary_table, primary_column) = ( - row[0], row[1], row[2], row[3], - ) - if not primary_table in foreignKeys.keys(): - foreignKeys[primary_table] = [['%s' % (foreign_column)], ['%s.%s'%(primary_table, primary_column)]] - else: - foreignKeys[primary_table][0].append('%s'%(foreign_column)) - foreignKeys[primary_table][1].append('%s.%s'%(primary_table, primary_column)) - for primary_table in foreignKeys.keys(): - #table.append_constraint(schema.ForeignKeyConstraint(['%s.%s'%(foreign_table, foreign_column)], ['%s.%s'%(primary_table,primary_column)])) - table.append_constraint(schema.ForeignKeyConstraint(foreignKeys[primary_table][0], foreignKeys[primary_table][1], link_to_name=True)) - - if not found_table: - raise exc.NoSuchTableError(table.name) - - -class SybaseSQLDialect_mxodbc(SybaseSQLDialect): - execution_ctx_cls = SybaseSQLExecutionContext_mxodbc - - def __init__(self, **params): - super(SybaseSQLDialect_mxodbc, self).__init__(**params) - - self.dbapi_type_map = {'getdate' : SybaseDate_mxodbc()} - - def import_dbapi(cls): - #import mx.ODBC.Windows as module - import mxODBC as module - return module - import_dbapi = classmethod(import_dbapi) - - colspecs = SybaseSQLDialect.colspecs.copy() - colspecs[sqltypes.Time] = SybaseTime_mxodbc - colspecs[sqltypes.Date] = SybaseDate_mxodbc - colspecs[sqltypes.DateTime] = SybaseDateTime_mxodbc - - ischema_names = SybaseSQLDialect.ischema_names.copy() - ischema_names['time'] = SybaseTime_mxodbc - ischema_names['date'] = SybaseDate_mxodbc - ischema_names['datetime'] = SybaseDateTime_mxodbc - ischema_names['smalldatetime'] = SybaseDateTime_mxodbc - - def is_disconnect(self, e): - # FIXME: optimize - #return isinstance(e, self.dbapi.Error) and '[08S01]' in str(e) - #return True - return False - - def do_execute(self, cursor, statement, parameters, context=None, **kwargs): - super(SybaseSQLDialect_mxodbc, self).do_execute(cursor, statement, parameters, context=context, **kwargs) - - def create_connect_args(self, url): - '''Return a tuple of *args,**kwargs''' - # FIXME: handle mx.odbc.Windows proprietary args - opts = url.translate_connect_args(username='user') - opts.update(url.query) - argsDict = {} - argsDict['user'] = opts['user'] - argsDict['password'] = opts['password'] - connArgs = [[opts['dsn']], argsDict] - return connArgs - - -class SybaseSQLDialect_pyodbc(SybaseSQLDialect): - execution_ctx_cls = SybaseSQLExecutionContext_pyodbc - - def __init__(self, **params): - super(SybaseSQLDialect_pyodbc, self).__init__(**params) - self.dbapi_type_map = {'getdate' : SybaseDate_pyodbc()} - - def import_dbapi(cls): - import mypyodbc as module - return module - import_dbapi = classmethod(import_dbapi) - - colspecs = SybaseSQLDialect.colspecs.copy() - colspecs[sqltypes.Time] = SybaseTime_pyodbc - colspecs[sqltypes.Date] = SybaseDate_pyodbc - colspecs[sqltypes.DateTime] = SybaseDateTime_pyodbc - - ischema_names = SybaseSQLDialect.ischema_names.copy() - ischema_names['time'] = SybaseTime_pyodbc - ischema_names['date'] = SybaseDate_pyodbc - ischema_names['datetime'] = SybaseDateTime_pyodbc - ischema_names['smalldatetime'] = SybaseDateTime_pyodbc - - def is_disconnect(self, e): - # FIXME: optimize - #return isinstance(e, self.dbapi.Error) and '[08S01]' in str(e) - #return True - return False - - def do_execute(self, cursor, statement, parameters, context=None, **kwargs): - super(SybaseSQLDialect_pyodbc, self).do_execute(cursor, statement, parameters, context=context, **kwargs) - - def create_connect_args(self, url): - '''Return a tuple of *args,**kwargs''' - # FIXME: handle pyodbc proprietary args - opts = url.translate_connect_args(username='user') - opts.update(url.query) - - self.autocommit = False - if 'autocommit' in opts: - self.autocommit = bool(int(opts.pop('autocommit'))) - - argsDict = {} - argsDict['UID'] = opts['user'] - argsDict['PWD'] = opts['password'] - argsDict['DSN'] = opts['dsn'] - connArgs = [[';'.join(["%s=%s"%(key, argsDict[key]) for key in argsDict])], {'autocommit' : self.autocommit}] - return connArgs - - -dialect_mapping = { - 'sqlalchemy.databases.mxODBC' : SybaseSQLDialect_mxodbc, -# 'pyodbc' : SybaseSQLDialect_pyodbc, - } - - -class SybaseSQLCompiler(compiler.SQLCompiler): - operators = compiler.SQLCompiler.operators.copy() - operators.update({ - sql_operators.mod: lambda x, y: "MOD(%s, %s)" % (x, y), - }) - - def bindparam_string(self, name): - res = super(SybaseSQLCompiler, self).bindparam_string(name) - if name.lower().startswith('literal'): - res = 'STRING(%s)' % res - return res - - def get_select_precolumns(self, select): - s = select._distinct and "DISTINCT " or "" - if select._limit: - #if select._limit == 1: - #s += "FIRST " - #else: - #s += "TOP %s " % (select._limit,) - s += "TOP %s " % (select._limit,) - if select._offset: - if not select._limit: - # FIXME: sybase doesn't allow an offset without a limit - # so use a huge value for TOP here - s += "TOP 1000000 " - s += "START AT %s " % (select._offset+1,) - return s - - def limit_clause(self, select): - # Limit in sybase is after the select keyword - return "" - - def visit_binary(self, binary): - """Move bind parameters to the right-hand side of an operator, where possible.""" - if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq: - return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator)) - else: - return super(SybaseSQLCompiler, self).visit_binary(binary) - - def label_select_column(self, select, column, asfrom): - if isinstance(column, expression.Function): - return column.label(None) - else: - return super(SybaseSQLCompiler, self).label_select_column(select, column, asfrom) - - function_rewrites = {'current_date': 'getdate', - } - def visit_function(self, func): - func.name = self.function_rewrites.get(func.name, func.name) - res = super(SybaseSQLCompiler, self).visit_function(func) - if func.name.lower() == 'getdate': - # apply CAST operator - # FIXME: what about _pyodbc ? - cast = expression._Cast(func, SybaseDate_mxodbc) - # infinite recursion - # res = self.visit_cast(cast) - res = "CAST(%s AS %s)" % (res, self.process(cast.typeclause)) - return res - - def for_update_clause(self, select): - # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use - return '' - - def order_by_clause(self, select): - order_by = self.process(select._order_by_clause) - - # SybaseSQL only allows ORDER BY in subqueries if there is a LIMIT - if order_by and (not self.is_subquery() or select._limit): - return " ORDER BY " + order_by - else: - return "" - - -class SybaseSQLSchemaGenerator(compiler.SchemaGenerator): - def get_column_specification(self, column, **kwargs): - - colspec = self.preparer.format_column(column) - - if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \ - column.autoincrement and isinstance(column.type, sqltypes.Integer): - if column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional): - column.sequence = schema.Sequence(column.name + '_seq') - - if hasattr(column, 'sequence'): - column.table.has_sequence = column - #colspec += " numeric(30,0) IDENTITY" - colspec += " Integer IDENTITY" - else: - colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec() - - if not column.nullable: - colspec += " NOT NULL" - - default = self.get_column_default_string(column) - if default is not None: - colspec += " DEFAULT " + default - - return colspec - - -class SybaseSQLSchemaDropper(compiler.SchemaDropper): - def visit_index(self, index): - self.append("\nDROP INDEX %s.%s" % ( - self.preparer.quote_identifier(index.table.name), - self.preparer.quote(self._validate_identifier(index.name, False), index.quote) - )) - self.execute() - - -class SybaseSQLDefaultRunner(base.DefaultRunner): - pass - - -class SybaseSQLIdentifierPreparer(compiler.IdentifierPreparer): - reserved_words = RESERVED_WORDS - - def __init__(self, dialect): - super(SybaseSQLIdentifierPreparer, self).__init__(dialect) - - def _escape_identifier(self, value): - #TODO: determin SybaseSQL's escapeing rules - return value - - def _fold_identifier_case(self, value): - #TODO: determin SybaseSQL's case folding rules - return value - - -dialect = SybaseSQLDialect -dialect.statement_compiler = SybaseSQLCompiler -dialect.schemagenerator = SybaseSQLSchemaGenerator -dialect.schemadropper = SybaseSQLSchemaDropper -dialect.preparer = SybaseSQLIdentifierPreparer -dialect.defaultrunner = SybaseSQLDefaultRunner diff --git a/lib/sqlalchemy/dialects/__init__.pyc b/lib/sqlalchemy/dialects/__init__.pyc new file mode 100644 index 0000000000..6c373872ac Binary files /dev/null and b/lib/sqlalchemy/dialects/__init__.pyc differ diff --git a/lib/sqlalchemy/dialects/access/__init__.py b/lib/sqlalchemy/dialects/access/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lib/sqlalchemy/databases/access.py b/lib/sqlalchemy/dialects/access/base.py similarity index 100% rename from lib/sqlalchemy/databases/access.py rename to lib/sqlalchemy/dialects/access/base.py diff --git a/lib/sqlalchemy/dialects/firebird/__init__.py b/lib/sqlalchemy/dialects/firebird/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lib/sqlalchemy/databases/firebird.py b/lib/sqlalchemy/dialects/firebird/base.py similarity index 100% rename from lib/sqlalchemy/databases/firebird.py rename to lib/sqlalchemy/dialects/firebird/base.py diff --git a/lib/sqlalchemy/databases/information_schema.py b/lib/sqlalchemy/dialects/information_schema.py similarity index 100% rename from lib/sqlalchemy/databases/information_schema.py rename to lib/sqlalchemy/dialects/information_schema.py diff --git a/lib/sqlalchemy/dialects/informix/__init__.py b/lib/sqlalchemy/dialects/informix/__init__.py new file mode 100644 index 0000000000..f2fcc76d4c --- /dev/null +++ b/lib/sqlalchemy/dialects/informix/__init__.py @@ -0,0 +1,3 @@ +from sqlalchemy.dialects.informix import base, informixdb + +base.dialect = informixdb.dialect \ No newline at end of file diff --git a/lib/sqlalchemy/databases/informix.py b/lib/sqlalchemy/dialects/informix/base.py similarity index 61% rename from lib/sqlalchemy/databases/informix.py rename to lib/sqlalchemy/dialects/informix/base.py index ad9dfd9bce..75e8cb54af 100644 --- a/lib/sqlalchemy/databases/informix.py +++ b/lib/sqlalchemy/dialects/informix/base.py @@ -5,6 +5,12 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php +"""Informix support. + +--- THIS DIALECT IS NOT TESTED ON 0.6 --- + +""" + import datetime @@ -14,55 +20,7 @@ from sqlalchemy.engine import default from sqlalchemy import types as sqltypes -# for offset - -class informix_cursor(object): - def __init__( self , con ): - self.__cursor = con.cursor() - self.rowcount = 0 - - def offset( self , n ): - if n > 0: - self.fetchmany( n ) - self.rowcount = self.__cursor.rowcount - n - if self.rowcount < 0: - self.rowcount = 0 - else: - self.rowcount = self.__cursor.rowcount - - def execute( self , sql , params ): - if params is None or len( params ) == 0: - params = [] - - return self.__cursor.execute( sql , params ) - - def __getattr__( self , name ): - if name not in ( 'offset' , '__cursor' , 'rowcount' , '__del__' , 'execute' ): - return getattr( self.__cursor , name ) - -class InfoNumeric(sqltypes.Numeric): - def get_col_spec(self): - if not self.precision: - return 'NUMERIC' - else: - return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale} - -class InfoInteger(sqltypes.Integer): - def get_col_spec(self): - return "INTEGER" - -class InfoSmallInteger(sqltypes.SmallInteger): - def get_col_spec(self): - return "SMALLINT" - -class InfoDate(sqltypes.Date): - def get_col_spec( self ): - return "DATE" - class InfoDateTime(sqltypes.DateTime ): - def get_col_spec(self): - return "DATETIME YEAR TO SECOND" - def bind_processor(self, dialect): def process(value): if value is not None: @@ -72,9 +30,6 @@ class InfoDateTime(sqltypes.DateTime ): return process class InfoTime(sqltypes.Time ): - def get_col_spec(self): - return "DATETIME HOUR TO SECOND" - def bind_processor(self, dialect): def process(value): if value is not None: @@ -91,35 +46,8 @@ class InfoTime(sqltypes.Time ): return value return process -class InfoText(sqltypes.String): - def get_col_spec(self): - return "VARCHAR(255)" - -class InfoString(sqltypes.String): - def get_col_spec(self): - return "VARCHAR(%(length)s)" % {'length' : self.length} - - def bind_processor(self, dialect): - def process(value): - if value == '': - return None - else: - return value - return process - -class InfoChar(sqltypes.CHAR): - def get_col_spec(self): - return "CHAR(%(length)s)" % {'length' : self.length} - -class InfoBinary(sqltypes.Binary): - def get_col_spec(self): - return "BYTE" class InfoBoolean(sqltypes.Boolean): - default_type = 'NUM' - def get_col_spec(self): - return "SMALLINT" - def result_processor(self, dialect): def process(value): if value is None: @@ -140,104 +68,156 @@ class InfoBoolean(sqltypes.Boolean): return process colspecs = { - sqltypes.Integer : InfoInteger, - sqltypes.SmallInteger : InfoSmallInteger, - sqltypes.Numeric : InfoNumeric, - sqltypes.Float : InfoNumeric, sqltypes.DateTime : InfoDateTime, - sqltypes.Date : InfoDate, sqltypes.Time: InfoTime, - sqltypes.String : InfoString, - sqltypes.Binary : InfoBinary, sqltypes.Boolean : InfoBoolean, - sqltypes.Text : InfoText, - sqltypes.CHAR: InfoChar, } ischema_names = { - 0 : InfoString, # CHAR - 1 : InfoSmallInteger, # SMALLINT - 2 : InfoInteger, # INT - 3 : InfoNumeric, # Float - 3 : InfoNumeric, # SmallFloat - 5 : InfoNumeric, # DECIMAL - 6 : InfoInteger, # Serial - 7 : InfoDate, # DATE - 8 : InfoNumeric, # MONEY - 10 : InfoDateTime, # DATETIME - 11 : InfoBinary, # BYTE - 12 : InfoText, # TEXT - 13 : InfoString, # VARCHAR - 15 : InfoString, # NCHAR - 16 : InfoString, # NVARCHAR - 17 : InfoInteger, # INT8 - 18 : InfoInteger, # Serial8 - 43 : InfoString, # LVARCHAR - -1 : InfoBinary, # BLOB - -1 : InfoText, # CLOB + 0 : sqltypes.CHAR, # CHAR + 1 : sqltypes.SMALLINT, # SMALLINT + 2 : sqltypes.INTEGER, # INT + 3 : sqltypes.FLOAT, # Float + 3 : sqltypes.Float, # SmallFloat + 5 : sqltypes.DECIMAL, # DECIMAL + 6 : sqltypes.Integer, # Serial + 7 : sqltypes.DATE, # DATE + 8 : sqltypes.Numeric, # MONEY + 10 : sqltypes.DATETIME, # DATETIME + 11 : sqltypes.Binary, # BYTE + 12 : sqltypes.TEXT, # TEXT + 13 : sqltypes.VARCHAR, # VARCHAR + 15 : sqltypes.NCHAR, # NCHAR + 16 : sqltypes.NVARCHAR, # NVARCHAR + 17 : sqltypes.Integer, # INT8 + 18 : sqltypes.Integer, # Serial8 + 43 : sqltypes.String, # LVARCHAR + -1 : sqltypes.BLOB, # BLOB + -1 : sqltypes.CLOB, # CLOB } -class InfoExecutionContext(default.DefaultExecutionContext): - # cursor.sqlerrd - # 0 - estimated number of rows returned - # 1 - serial value after insert or ISAM error code - # 2 - number of rows processed - # 3 - estimated cost - # 4 - offset of the error into the SQL statement - # 5 - rowid after insert - def post_exec(self): - if getattr(self.compiled, "isinsert", False) and self.last_inserted_ids() is None: - self._last_inserted_ids = [self.cursor.sqlerrd[1]] - elif hasattr( self.compiled , 'offset' ): - self.cursor.offset( self.compiled.offset ) - super(InfoExecutionContext, self).post_exec() - - def create_cursor( self ): - return informix_cursor( self.connection.connection ) - -class InfoDialect(default.DefaultDialect): - name = 'informix' - default_paramstyle = 'qmark' - # for informix 7.31 - max_identifier_length = 18 +class InfoTypeCompiler(compiler.GenericTypeCompiler): + def visit_DATETIME(self, type_): + return "DATETIME YEAR TO SECOND" + + def visit_TIME(self, type_): + return "DATETIME HOUR TO SECOND" + + def visit_binary(self, type_): + return "BYTE" + + def visit_boolean(self, type_): + return "SMALLINT" + +class InfoSQLCompiler(compiler.SQLCompiler): - def __init__(self, use_ansi=True, **kwargs): - self.use_ansi = use_ansi - default.DefaultDialect.__init__(self, **kwargs) + def __init__(self, *args, **kwargs): + self.limit = 0 + self.offset = 0 - def dbapi(cls): - import informixdb - return informixdb - dbapi = classmethod(dbapi) + compiler.SQLCompiler.__init__( self , *args, **kwargs ) - def is_disconnect(self, e): - if isinstance(e, self.dbapi.OperationalError): - return 'closed the connection' in str(e) or 'connection not open' in str(e) + def default_from(self): + return " from systables where tabname = 'systables' " + + def get_select_precolumns( self , select ): + s = select._distinct and "DISTINCT " or "" + # only has limit + if select._limit: + off = select._offset or 0 + s += " FIRST %s " % ( select._limit + off ) else: - return False + s += "" + return s - def do_begin(self , connect ): - cu = connect.cursor() - cu.execute( 'SET LOCK MODE TO WAIT' ) - #cu.execute( 'SET ISOLATION TO REPEATABLE READ' ) + def visit_select(self, select): + if select._offset: + self.offset = select._offset + self.limit = select._limit or 0 + # the column in order by clause must in select too - def type_descriptor(self, typeobj): - return sqltypes.adapt_type(typeobj, colspecs) + def __label( c ): + try: + return c._label.lower() + except: + return '' + + # TODO: dont modify the original select, generate a new one + a = [ __label(c) for c in select._raw_columns ] + for c in select._order_by_clause.clauses: + if ( __label(c) not in a ): + select.append_column( c ) - def create_connect_args(self, url): - if url.host: - dsn = '%s@%s' % ( url.database , url.host ) + return compiler.SQLCompiler.visit_select(self, select) + + def limit_clause(self, select): + return "" + + def visit_function( self , func ): + if func.name.lower() == 'current_date': + return "today" + elif func.name.lower() == 'current_time': + return "CURRENT HOUR TO SECOND" + elif func.name.lower() in ( 'current_timestamp' , 'now' ): + return "CURRENT YEAR TO SECOND" else: - dsn = url.database + return compiler.SQLCompiler.visit_function( self , func ) - if url.username: - opt = { 'user':url.username , 'password': url.password } + def visit_clauselist(self, list, **kwargs): + return ', '.join([s for s in [self.process(c) for c in list.clauses] if s is not None]) + +class InfoDDLCompiler(compiler.DDLCompiler): + def get_column_specification(self, column, first_pk=False): + colspec = self.preparer.format_column(column) + if column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and \ + isinstance(column.type, sqltypes.Integer) and not getattr( self , 'has_serial' , False ) and first_pk: + colspec += " SERIAL" + self.has_serial = True else: - opt = {} + colspec += " " + self.dialect.type_compiler.process(column.type) + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + + if not column.nullable: + colspec += " NOT NULL" + + return colspec + + def post_create_table(self, table): + if hasattr( self , 'has_serial' ): + del self.has_serial + return '' + +class InfoIdentifierPreparer(compiler.IdentifierPreparer): + def __init__(self, dialect): + super(InfoIdentifierPreparer, self).__init__(dialect, initial_quote="'") - return ([dsn], opt) + def format_constraint(self, constraint): + # informix doesnt support names for constraints + return '' + + def _requires_quotes(self, value): + return False + +class InformixDialect(default.DefaultDialect): + name = 'informix' + # for informix 7.31 + max_identifier_length = 18 + type_compiler = InfoTypeCompiler + poolclass = pool.SingletonThreadPool + statement_compiler = InfoSQLCompiler + ddl_compiler = InfoDDLCompiler + preparer = InfoIdentifierPreparer + colspecs = colspecs + ischema_names = ischema_names + + def do_begin(self , connect ): + cu = connect.cursor() + cu.execute( 'SET LOCK MODE TO WAIT' ) + #cu.execute( 'SET ISOLATION TO REPEATABLE READ' ) def table_names(self, connection, schema): s = "select tabname from systables" @@ -352,142 +332,3 @@ class InfoDialect(default.DefaultDialect): for cons_name, cons_type, local_column in rows: table.primary_key.add( table.c[local_column] ) -class InfoCompiler(compiler.SQLCompiler): - """Info compiler modifies the lexical structure of Select statements to work under - non-ANSI configured Oracle databases, if the use_ansi flag is False.""" - - def __init__(self, *args, **kwargs): - self.limit = 0 - self.offset = 0 - - compiler.SQLCompiler.__init__( self , *args, **kwargs ) - - def default_from(self): - return " from systables where tabname = 'systables' " - - def get_select_precolumns( self , select ): - s = select._distinct and "DISTINCT " or "" - # only has limit - if select._limit: - off = select._offset or 0 - s += " FIRST %s " % ( select._limit + off ) - else: - s += "" - return s - - def visit_select(self, select): - if select._offset: - self.offset = select._offset - self.limit = select._limit or 0 - # the column in order by clause must in select too - - def __label( c ): - try: - return c._label.lower() - except: - return '' - - # TODO: dont modify the original select, generate a new one - a = [ __label(c) for c in select._raw_columns ] - for c in select._order_by_clause.clauses: - if ( __label(c) not in a ): - select.append_column( c ) - - return compiler.SQLCompiler.visit_select(self, select) - - def limit_clause(self, select): - return "" - - def visit_function( self , func ): - if func.name.lower() == 'current_date': - return "today" - elif func.name.lower() == 'current_time': - return "CURRENT HOUR TO SECOND" - elif func.name.lower() in ( 'current_timestamp' , 'now' ): - return "CURRENT YEAR TO SECOND" - else: - return compiler.SQLCompiler.visit_function( self , func ) - - def visit_clauselist(self, list, **kwargs): - return ', '.join([s for s in [self.process(c) for c in list.clauses] if s is not None]) - -class InfoSchemaGenerator(compiler.SchemaGenerator): - def get_column_specification(self, column, first_pk=False): - colspec = self.preparer.format_column(column) - if column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and \ - isinstance(column.type, sqltypes.Integer) and not getattr( self , 'has_serial' , False ) and first_pk: - colspec += " SERIAL" - self.has_serial = True - else: - colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec() - default = self.get_column_default_string(column) - if default is not None: - colspec += " DEFAULT " + default - - if not column.nullable: - colspec += " NOT NULL" - - return colspec - - def post_create_table(self, table): - if hasattr( self , 'has_serial' ): - del self.has_serial - return '' - - def visit_primary_key_constraint(self, constraint): - # for informix 7.31 not support constraint name - name = constraint.name - constraint.name = None - super(InfoSchemaGenerator, self).visit_primary_key_constraint(constraint) - constraint.name = name - - def visit_unique_constraint(self, constraint): - # for informix 7.31 not support constraint name - name = constraint.name - constraint.name = None - super(InfoSchemaGenerator, self).visit_unique_constraint(constraint) - constraint.name = name - - def visit_foreign_key_constraint( self , constraint ): - if constraint.name is not None: - constraint.use_alter = True - else: - super( InfoSchemaGenerator , self ).visit_foreign_key_constraint( constraint ) - - def define_foreign_key(self, constraint): - # for informix 7.31 not support constraint name - if constraint.use_alter: - name = constraint.name - constraint.name = None - self.append( "CONSTRAINT " ) - super(InfoSchemaGenerator, self).define_foreign_key(constraint) - constraint.name = name - if name is not None: - self.append( " CONSTRAINT " + name ) - else: - super(InfoSchemaGenerator, self).define_foreign_key(constraint) - - def visit_index(self, index): - if len( index.columns ) == 1 and index.columns[0].foreign_key: - return - super(InfoSchemaGenerator, self).visit_index(index) - -class InfoIdentifierPreparer(compiler.IdentifierPreparer): - def __init__(self, dialect): - super(InfoIdentifierPreparer, self).__init__(dialect, initial_quote="'") - - def _requires_quotes(self, value): - return False - -class InfoSchemaDropper(compiler.SchemaDropper): - def drop_foreignkey(self, constraint): - if constraint.name is not None: - super( InfoSchemaDropper , self ).drop_foreignkey( constraint ) - -dialect = InfoDialect -poolclass = pool.SingletonThreadPool -dialect.statement_compiler = InfoCompiler -dialect.schemagenerator = InfoSchemaGenerator -dialect.schemadropper = InfoSchemaDropper -dialect.preparer = InfoIdentifierPreparer -dialect.execution_ctx_cls = InfoExecutionContext \ No newline at end of file diff --git a/lib/sqlalchemy/dialects/informix/informixdb.py b/lib/sqlalchemy/dialects/informix/informixdb.py new file mode 100644 index 0000000000..ddfd597065 --- /dev/null +++ b/lib/sqlalchemy/dialects/informix/informixdb.py @@ -0,0 +1,80 @@ +from sqlalchemy.dialects.informix.base import InformixDialect +from sqlalchemy.engine import default + +# for offset + +class informix_cursor(object): + def __init__( self , con ): + self.__cursor = con.cursor() + self.rowcount = 0 + + def offset( self , n ): + if n > 0: + self.fetchmany( n ) + self.rowcount = self.__cursor.rowcount - n + if self.rowcount < 0: + self.rowcount = 0 + else: + self.rowcount = self.__cursor.rowcount + + def execute( self , sql , params ): + if params is None or len( params ) == 0: + params = [] + + return self.__cursor.execute( sql , params ) + + def __getattr__( self , name ): + if name not in ( 'offset' , '__cursor' , 'rowcount' , '__del__' , 'execute' ): + return getattr( self.__cursor , name ) + + +class InfoExecutionContext(default.DefaultExecutionContext): + # cursor.sqlerrd + # 0 - estimated number of rows returned + # 1 - serial value after insert or ISAM error code + # 2 - number of rows processed + # 3 - estimated cost + # 4 - offset of the error into the SQL statement + # 5 - rowid after insert + def post_exec(self): + if getattr(self.compiled, "isinsert", False) and self.last_inserted_ids() is None: + self._last_inserted_ids = [self.cursor.sqlerrd[1]] + elif hasattr( self.compiled , 'offset' ): + self.cursor.offset( self.compiled.offset ) + + def create_cursor( self ): + return informix_cursor( self.connection.connection ) + + +class Informix_informixdb(InformixDialect): + driver = 'informixdb' + default_paramstyle = 'qmark' + execution_context_cls = InfoExecutionContext + + @classmethod + def dbapi(cls): + import informixdb + return informixdb + + def create_connect_args(self, url): + if url.host: + dsn = '%s@%s' % ( url.database , url.host ) + else: + dsn = url.database + + if url.username: + opt = { 'user':url.username , 'password': url.password } + else: + opt = {} + + return ([dsn], opt) + + + def is_disconnect(self, e): + if isinstance(e, self.dbapi.OperationalError): + return 'closed the connection' in str(e) or 'connection not open' in str(e) + else: + return False + + +dialect = Informix_informixdb \ No newline at end of file diff --git a/lib/sqlalchemy/dialects/maxdb/__init__.py b/lib/sqlalchemy/dialects/maxdb/__init__.py new file mode 100644 index 0000000000..3f12448b79 --- /dev/null +++ b/lib/sqlalchemy/dialects/maxdb/__init__.py @@ -0,0 +1,3 @@ +from sqlalchemy.dialects.maxdb import base, sapdb + +base.dialect = sapdb.dialect \ No newline at end of file diff --git a/lib/sqlalchemy/databases/maxdb.py b/lib/sqlalchemy/dialects/maxdb/base.py similarity index 92% rename from lib/sqlalchemy/databases/maxdb.py rename to lib/sqlalchemy/dialects/maxdb/base.py index 6e521297fc..4be6f8f639 100644 --- a/lib/sqlalchemy/databases/maxdb.py +++ b/lib/sqlalchemy/dialects/maxdb/base.py @@ -5,6 +5,8 @@ """Support for the MaxDB database. +-- NOT TESTED ON 0.6 -- + TODO: More module docs! MaxDB support is currently experimental. Overview @@ -79,16 +81,6 @@ class _StringType(sqltypes.String): super(_StringType, self).__init__(length=length, **kw) self.encoding = encoding - def get_col_spec(self): - if self.length is None: - spec = 'LONG' - else: - spec = '%s(%s)' % (self._type, self.length) - - if self.encoding is not None: - spec = ' '.join([spec, self.encoding.upper()]) - return spec - def bind_processor(self, dialect): if self.encoding == 'unicode': return None @@ -156,16 +148,6 @@ class MaxText(_StringType): return spec -class MaxInteger(sqltypes.Integer): - def get_col_spec(self): - return 'INTEGER' - - -class MaxSmallInteger(MaxInteger): - def get_col_spec(self): - return 'SMALLINT' - - class MaxNumeric(sqltypes.Numeric): """The FIXED (also NUMERIC, DECIMAL) data type.""" @@ -177,29 +159,7 @@ class MaxNumeric(sqltypes.Numeric): def bind_processor(self, dialect): return None - def get_col_spec(self): - if self.scale and self.precision: - return 'FIXED(%s, %s)' % (self.precision, self.scale) - elif self.precision: - return 'FIXED(%s)' % self.precision - else: - return 'INTEGER' - - -class MaxFloat(sqltypes.Float): - """The FLOAT data type.""" - - def get_col_spec(self): - if self.precision is None: - return 'FLOAT' - else: - return 'FLOAT(%s)' % (self.precision,) - - class MaxTimestamp(sqltypes.DateTime): - def get_col_spec(self): - return 'TIMESTAMP' - def bind_processor(self, dialect): def process(value): if value is None: @@ -242,9 +202,6 @@ class MaxTimestamp(sqltypes.DateTime): class MaxDate(sqltypes.Date): - def get_col_spec(self): - return 'DATE' - def bind_processor(self, dialect): def process(value): if value is None: @@ -279,9 +236,6 @@ class MaxDate(sqltypes.Date): class MaxTime(sqltypes.Time): - def get_col_spec(self): - return 'TIME' - def bind_processor(self, dialect): def process(value): if value is None: @@ -316,15 +270,7 @@ class MaxTime(sqltypes.Time): return process -class MaxBoolean(sqltypes.Boolean): - def get_col_spec(self): - return 'BOOLEAN' - - class MaxBlob(sqltypes.Binary): - def get_col_spec(self): - return 'LONG BYTE' - def bind_processor(self, dialect): def process(value): if value is None: @@ -341,18 +287,54 @@ class MaxBlob(sqltypes.Binary): return value.read(value.remainingLength()) return process +class MaxDBTypeCompiler(compiler.GenericTypeCompiler): + def _string_spec(self, string_spec, type_): + if type_.length is None: + spec = 'LONG' + else: + spec = '%s(%s)' % (string_spec, type_.length) + + if getattr(type_, 'encoding'): + spec = ' '.join([spec, getattr(type_, 'encoding').upper()]) + return spec + + def visit_text(self, type_): + spec = 'LONG' + if getattr(type_, 'encoding', None): + spec = ' '.join((spec, type_.encoding)) + elif type_.convert_unicode: + spec = ' '.join((spec, 'UNICODE')) + + return spec + def visit_char(self, type_): + return self._string_spec("CHAR", type_) + + def visit_string(self, type_): + return self._string_spec("VARCHAR", type_) + + def visit_binary(self, type_): + return "LONG BYTE" + + def visit_numeric(self, type_): + if type_.scale and type_.precision: + return 'FIXED(%s, %s)' % (type_.precision, type_.scale) + elif type_.precision: + return 'FIXED(%s)' % type_.precision + else: + return 'INTEGER' + + def visit_BOOLEAN(self, type_): + return "BOOLEAN" + colspecs = { - sqltypes.Integer: MaxInteger, - sqltypes.SmallInteger: MaxSmallInteger, sqltypes.Numeric: MaxNumeric, - sqltypes.Float: MaxFloat, sqltypes.DateTime: MaxTimestamp, sqltypes.Date: MaxDate, sqltypes.Time: MaxTime, sqltypes.String: MaxString, + sqltypes.Unicode:MaxUnicode, sqltypes.Binary: MaxBlob, - sqltypes.Boolean: MaxBoolean, sqltypes.Text: MaxText, sqltypes.CHAR: MaxChar, sqltypes.TIMESTAMP: MaxTimestamp, @@ -361,25 +343,25 @@ colspecs = { } ischema_names = { - 'boolean': MaxBoolean, - 'char': MaxChar, - 'character': MaxChar, - 'date': MaxDate, - 'fixed': MaxNumeric, - 'float': MaxFloat, - 'int': MaxInteger, - 'integer': MaxInteger, - 'long binary': MaxBlob, - 'long unicode': MaxText, - 'long': MaxText, - 'long': MaxText, - 'smallint': MaxSmallInteger, - 'time': MaxTime, - 'timestamp': MaxTimestamp, - 'varchar': MaxString, + 'boolean': sqltypes.BOOLEAN, + 'char': sqltypes.CHAR, + 'character': sqltypes.CHAR, + 'date': sqltypes.DATE, + 'fixed': sqltypes.Numeric, + 'float': sqltypes.FLOAT, + 'int': sqltypes.INT, + 'integer': sqltypes.INT, + 'long binary': sqltypes.BLOB, + 'long unicode': sqltypes.Text, + 'long': sqltypes.Text, + 'long': sqltypes.Text, + 'smallint': sqltypes.SmallInteger, + 'time': sqltypes.Time, + 'timestamp': sqltypes.TIMESTAMP, + 'varchar': sqltypes.VARCHAR, } - +# TODO: migrate this to sapdb.py class MaxDBExecutionContext(default.DefaultExecutionContext): def post_exec(self): # DB-API bug: if there were any functions as values, @@ -464,380 +446,127 @@ class MaxDBCachedColumnRow(engine_base.RowProxy): class MaxDBResultProxy(engine_base.ResultProxy): _process_row = MaxDBCachedColumnRow +class MaxDBCompiler(compiler.SQLCompiler): + operators = compiler.SQLCompiler.operators.copy() + operators[sql_operators.mod] = lambda x, y: 'mod(%s, %s)' % (x, y) -class MaxDBDialect(default.DefaultDialect): - name = 'maxdb' - supports_alter = True - supports_unicode_statements = True - max_identifier_length = 32 - supports_sane_rowcount = True - supports_sane_multi_rowcount = False - preexecute_pk_sequences = True + function_conversion = { + 'CURRENT_DATE': 'DATE', + 'CURRENT_TIME': 'TIME', + 'CURRENT_TIMESTAMP': 'TIMESTAMP', + } - # MaxDB-specific - datetimeformat = 'internal' + # These functions must be written without parens when called with no + # parameters. e.g. 'SELECT DATE FROM DUAL' not 'SELECT DATE() FROM DUAL' + bare_functions = set([ + 'CURRENT_SCHEMA', 'DATE', 'FALSE', 'SYSDBA', 'TIME', 'TIMESTAMP', + 'TIMEZONE', 'TRANSACTION', 'TRUE', 'USER', 'UID', 'USERGROUP', + 'UTCDATE', 'UTCDIFF']) - def __init__(self, _raise_known_sql_errors=False, **kw): - super(MaxDBDialect, self).__init__(**kw) - self._raise_known = _raise_known_sql_errors + def default_from(self): + return ' FROM DUAL' - if self.dbapi is None: - self.dbapi_type_map = {} + def for_update_clause(self, select): + clause = select.for_update + if clause is True: + return " WITH LOCK EXCLUSIVE" + elif clause is None: + return "" + elif clause == "read": + return " WITH LOCK" + elif clause == "ignore": + return " WITH LOCK (IGNORE) EXCLUSIVE" + elif clause == "nowait": + return " WITH LOCK (NOWAIT) EXCLUSIVE" + elif isinstance(clause, basestring): + return " WITH LOCK %s" % clause.upper() + elif not clause: + return "" else: - self.dbapi_type_map = { - 'Long Binary': MaxBlob(), - 'Long byte_t': MaxBlob(), - 'Long Unicode': MaxText(), - 'Timestamp': MaxTimestamp(), - 'Date': MaxDate(), - 'Time': MaxTime(), - datetime.datetime: MaxTimestamp(), - datetime.date: MaxDate(), - datetime.time: MaxTime(), - } + return " WITH LOCK EXCLUSIVE" - def dbapi(cls): - from sapdb import dbapi as _dbapi - return _dbapi - dbapi = classmethod(dbapi) - - def create_connect_args(self, url): - opts = url.translate_connect_args(username='user') - opts.update(url.query) - return [], opts - - def type_descriptor(self, typeobj): - if isinstance(typeobj, type): - typeobj = typeobj() - if isinstance(typeobj, sqltypes.Unicode): - return typeobj.adapt(MaxUnicode) + def apply_function_parens(self, func): + if func.name.upper() in self.bare_functions: + return len(func.clauses) > 0 else: - return sqltypes.adapt_type(typeobj, colspecs) - - def do_execute(self, cursor, statement, parameters, context=None): - res = cursor.execute(statement, parameters) - if isinstance(res, int) and context is not None: - context._rowcount = res - - def do_release_savepoint(self, connection, name): - # Does MaxDB truly support RELEASE SAVEPOINT ? All my attempts - # produce "SUBTRANS COMMIT/ROLLBACK not allowed without SUBTRANS - # BEGIN SQLSTATE: I7065" - # Note that ROLLBACK TO works fine. In theory, a RELEASE should - # just free up some transactional resources early, before the overall - # COMMIT/ROLLBACK so omitting it should be relatively ok. - pass + return True - def get_default_schema_name(self, connection): - try: - return self._default_schema_name - except AttributeError: - name = self.identifier_preparer._normalize_name( - connection.execute('SELECT CURRENT_SCHEMA FROM DUAL').scalar()) - self._default_schema_name = name - return name + def visit_function(self, fn, **kw): + transform = self.function_conversion.get(fn.name.upper(), None) + if transform: + fn = fn._clone() + fn.name = transform + return super(MaxDBCompiler, self).visit_function(fn, **kw) - def has_table(self, connection, table_name, schema=None): - denormalize = self.identifier_preparer._denormalize_name - bind = [denormalize(table_name)] - if schema is None: - sql = ("SELECT tablename FROM TABLES " - "WHERE TABLES.TABLENAME=? AND" - " TABLES.SCHEMANAME=CURRENT_SCHEMA ") + def visit_cast(self, cast, **kwargs): + # MaxDB only supports casts * to NUMERIC, * to VARCHAR or + # date/time to VARCHAR. Casts of LONGs will fail. + if isinstance(cast.type, (sqltypes.Integer, sqltypes.Numeric)): + return "NUM(%s)" % self.process(cast.clause) + elif isinstance(cast.type, sqltypes.String): + return "CHR(%s)" % self.process(cast.clause) else: - sql = ("SELECT tablename FROM TABLES " - "WHERE TABLES.TABLENAME = ? AND" - " TABLES.SCHEMANAME=? ") - bind.append(denormalize(schema)) - - rp = connection.execute(sql, bind) - found = bool(rp.fetchone()) - rp.close() - return found + return self.process(cast.clause) - def table_names(self, connection, schema): - if schema is None: - sql = (" SELECT TABLENAME FROM TABLES WHERE " - " SCHEMANAME=CURRENT_SCHEMA ") - rs = connection.execute(sql) + def visit_sequence(self, sequence): + if sequence.optional: + return None else: - sql = (" SELECT TABLENAME FROM TABLES WHERE " - " SCHEMANAME=? ") - matchname = self.identifier_preparer._denormalize_name(schema) - rs = connection.execute(sql, matchname) - normalize = self.identifier_preparer._normalize_name - return [normalize(row[0]) for row in rs] + return (self.dialect.identifier_preparer.format_sequence(sequence) + + ".NEXTVAL") - def reflecttable(self, connection, table, include_columns): - denormalize = self.identifier_preparer._denormalize_name - normalize = self.identifier_preparer._normalize_name + class ColumnSnagger(visitors.ClauseVisitor): + def __init__(self): + self.count = 0 + self.column = None + def visit_column(self, column): + self.column = column + self.count += 1 - st = ('SELECT COLUMNNAME, MODE, DATATYPE, CODETYPE, LEN, DEC, ' - ' NULLABLE, "DEFAULT", DEFAULTFUNCTION ' - 'FROM COLUMNS ' - 'WHERE TABLENAME=? AND SCHEMANAME=%s ' - 'ORDER BY POS') + def _find_labeled_columns(self, columns, use_labels=False): + labels = {} + for column in columns: + if isinstance(column, basestring): + continue + snagger = self.ColumnSnagger() + snagger.traverse(column) + if snagger.count == 1: + if isinstance(column, sql_expr._Label): + labels[unicode(snagger.column)] = column.name + elif use_labels: + labels[unicode(snagger.column)] = column._label - fk = ('SELECT COLUMNNAME, FKEYNAME, ' - ' REFSCHEMANAME, REFTABLENAME, REFCOLUMNNAME, RULE, ' - ' (CASE WHEN REFSCHEMANAME = CURRENT_SCHEMA ' - ' THEN 1 ELSE 0 END) AS in_schema ' - 'FROM FOREIGNKEYCOLUMNS ' - 'WHERE TABLENAME=? AND SCHEMANAME=%s ' - 'ORDER BY FKEYNAME ') + return labels - params = [denormalize(table.name)] - if not table.schema: - st = st % 'CURRENT_SCHEMA' - fk = fk % 'CURRENT_SCHEMA' - else: - st = st % '?' - fk = fk % '?' - params.append(denormalize(table.schema)) + def order_by_clause(self, select): + order_by = self.process(select._order_by_clause) - rows = connection.execute(st, params).fetchall() - if not rows: - raise exc.NoSuchTableError(table.fullname) + # ORDER BY clauses in DISTINCT queries must reference aliased + # inner columns by alias name, not true column name. + if order_by and getattr(select, '_distinct', False): + labels = self._find_labeled_columns(select.inner_columns, + select.use_labels) + if labels: + for needs_alias in labels.keys(): + r = re.compile(r'(^| )(%s)(,| |$)' % + re.escape(needs_alias)) + order_by = r.sub((r'\1%s\3' % labels[needs_alias]), + order_by) - include_columns = set(include_columns or []) - - for row in rows: - (name, mode, col_type, encoding, length, scale, - nullable, constant_def, func_def) = row - - name = normalize(name) - - if include_columns and name not in include_columns: - continue - - type_args, type_kw = [], {} - if col_type == 'FIXED': - type_args = length, scale - # Convert FIXED(10) DEFAULT SERIAL to our Integer - if (scale == 0 and - func_def is not None and func_def.startswith('SERIAL')): - col_type = 'INTEGER' - type_args = length, - elif col_type in 'FLOAT': - type_args = length, - elif col_type in ('CHAR', 'VARCHAR'): - type_args = length, - type_kw['encoding'] = encoding - elif col_type == 'LONG': - type_kw['encoding'] = encoding - - try: - type_cls = ischema_names[col_type.lower()] - type_instance = type_cls(*type_args, **type_kw) - except KeyError: - util.warn("Did not recognize type '%s' of column '%s'" % - (col_type, name)) - type_instance = sqltypes.NullType - - col_kw = {'autoincrement': False} - col_kw['nullable'] = (nullable == 'YES') - col_kw['primary_key'] = (mode == 'KEY') - - if func_def is not None: - if func_def.startswith('SERIAL'): - if col_kw['primary_key']: - # No special default- let the standard autoincrement - # support handle SERIAL pk columns. - col_kw['autoincrement'] = True - else: - # strip current numbering - col_kw['server_default'] = schema.DefaultClause( - sql.text('SERIAL')) - col_kw['autoincrement'] = True - else: - col_kw['server_default'] = schema.DefaultClause( - sql.text(func_def)) - elif constant_def is not None: - col_kw['server_default'] = schema.DefaultClause(sql.text( - "'%s'" % constant_def.replace("'", "''"))) - - table.append_column(schema.Column(name, type_instance, **col_kw)) - - fk_sets = itertools.groupby(connection.execute(fk, params), - lambda row: row.FKEYNAME) - for fkeyname, fkey in fk_sets: - fkey = list(fkey) - if include_columns: - key_cols = set([r.COLUMNNAME for r in fkey]) - if key_cols != include_columns: - continue - - columns, referants = [], [] - quote = self.identifier_preparer._maybe_quote_identifier - - for row in fkey: - columns.append(normalize(row.COLUMNNAME)) - if table.schema or not row.in_schema: - referants.append('.'.join( - [quote(normalize(row[c])) - for c in ('REFSCHEMANAME', 'REFTABLENAME', - 'REFCOLUMNNAME')])) - else: - referants.append('.'.join( - [quote(normalize(row[c])) - for c in ('REFTABLENAME', 'REFCOLUMNNAME')])) - - constraint_kw = {'name': fkeyname.lower()} - if fkey[0].RULE is not None: - rule = fkey[0].RULE - if rule.startswith('DELETE '): - rule = rule[7:] - constraint_kw['ondelete'] = rule - - table_kw = {} - if table.schema or not row.in_schema: - table_kw['schema'] = normalize(fkey[0].REFSCHEMANAME) - - ref_key = schema._get_table_key(normalize(fkey[0].REFTABLENAME), - table_kw.get('schema')) - if ref_key not in table.metadata.tables: - schema.Table(normalize(fkey[0].REFTABLENAME), - table.metadata, - autoload=True, autoload_with=connection, - **table_kw) - - constraint = schema.ForeignKeyConstraint(columns, referants, link_to_name=True, - **constraint_kw) - table.append_constraint(constraint) - - def has_sequence(self, connection, name): - # [ticket:726] makes this schema-aware. - denormalize = self.identifier_preparer._denormalize_name - sql = ("SELECT sequence_name FROM SEQUENCES " - "WHERE SEQUENCE_NAME=? ") - - rp = connection.execute(sql, denormalize(name)) - found = bool(rp.fetchone()) - rp.close() - return found - - -class MaxDBCompiler(compiler.SQLCompiler): - operators = compiler.SQLCompiler.operators.copy() - operators[sql_operators.mod] = lambda x, y: 'mod(%s, %s)' % (x, y) - - function_conversion = { - 'CURRENT_DATE': 'DATE', - 'CURRENT_TIME': 'TIME', - 'CURRENT_TIMESTAMP': 'TIMESTAMP', - } - - # These functions must be written without parens when called with no - # parameters. e.g. 'SELECT DATE FROM DUAL' not 'SELECT DATE() FROM DUAL' - bare_functions = set([ - 'CURRENT_SCHEMA', 'DATE', 'FALSE', 'SYSDBA', 'TIME', 'TIMESTAMP', - 'TIMEZONE', 'TRANSACTION', 'TRUE', 'USER', 'UID', 'USERGROUP', - 'UTCDATE', 'UTCDIFF']) - - def default_from(self): - return ' FROM DUAL' - - def for_update_clause(self, select): - clause = select.for_update - if clause is True: - return " WITH LOCK EXCLUSIVE" - elif clause is None: - return "" - elif clause == "read": - return " WITH LOCK" - elif clause == "ignore": - return " WITH LOCK (IGNORE) EXCLUSIVE" - elif clause == "nowait": - return " WITH LOCK (NOWAIT) EXCLUSIVE" - elif isinstance(clause, basestring): - return " WITH LOCK %s" % clause.upper() - elif not clause: - return "" - else: - return " WITH LOCK EXCLUSIVE" - - def apply_function_parens(self, func): - if func.name.upper() in self.bare_functions: - return len(func.clauses) > 0 - else: - return True - - def visit_function(self, fn, **kw): - transform = self.function_conversion.get(fn.name.upper(), None) - if transform: - fn = fn._clone() - fn.name = transform - return super(MaxDBCompiler, self).visit_function(fn, **kw) - - def visit_cast(self, cast, **kwargs): - # MaxDB only supports casts * to NUMERIC, * to VARCHAR or - # date/time to VARCHAR. Casts of LONGs will fail. - if isinstance(cast.type, (sqltypes.Integer, sqltypes.Numeric)): - return "NUM(%s)" % self.process(cast.clause) - elif isinstance(cast.type, sqltypes.String): - return "CHR(%s)" % self.process(cast.clause) - else: - return self.process(cast.clause) - - def visit_sequence(self, sequence): - if sequence.optional: - return None - else: - return (self.dialect.identifier_preparer.format_sequence(sequence) + - ".NEXTVAL") - - class ColumnSnagger(visitors.ClauseVisitor): - def __init__(self): - self.count = 0 - self.column = None - def visit_column(self, column): - self.column = column - self.count += 1 - - def _find_labeled_columns(self, columns, use_labels=False): - labels = {} - for column in columns: - if isinstance(column, basestring): - continue - snagger = self.ColumnSnagger() - snagger.traverse(column) - if snagger.count == 1: - if isinstance(column, sql_expr._Label): - labels[unicode(snagger.column)] = column.name - elif use_labels: - labels[unicode(snagger.column)] = column._label - - return labels - - def order_by_clause(self, select): - order_by = self.process(select._order_by_clause) - - # ORDER BY clauses in DISTINCT queries must reference aliased - # inner columns by alias name, not true column name. - if order_by and getattr(select, '_distinct', False): - labels = self._find_labeled_columns(select.inner_columns, - select.use_labels) - if labels: - for needs_alias in labels.keys(): - r = re.compile(r'(^| )(%s)(,| |$)' % - re.escape(needs_alias)) - order_by = r.sub((r'\1%s\3' % labels[needs_alias]), - order_by) - - # No ORDER BY in subqueries. - if order_by: - if self.is_subquery(): - # It's safe to simply drop the ORDER BY if there is no - # LIMIT. Right? Other dialects seem to get away with - # dropping order. - if select._limit: - raise exc.InvalidRequestError( - "MaxDB does not support ORDER BY in subqueries") - else: - return "" - return " ORDER BY " + order_by - else: - return "" + # No ORDER BY in subqueries. + if order_by: + if self.is_subquery(): + # It's safe to simply drop the ORDER BY if there is no + # LIMIT. Right? Other dialects seem to get away with + # dropping order. + if select._limit: + raise exc.InvalidRequestError( + "MaxDB does not support ORDER BY in subqueries") + else: + return "" + return " ORDER BY " + order_by + else: + return "" def get_select_precolumns(self, select): # Convert a subquery's LIMIT to TOP @@ -947,10 +676,10 @@ class MaxDBIdentifierPreparer(compiler.IdentifierPreparer): return name -class MaxDBSchemaGenerator(compiler.SchemaGenerator): +class MaxDBDDLCompiler(compiler.DDLCompiler): def get_column_specification(self, column, **kw): colspec = [self.preparer.format_column(column), - column.type.dialect_impl(self.dialect).get_col_spec()] + self.dialect.type_compiler.process(column.type)] if not column.nullable: colspec.append('NOT NULL') @@ -996,7 +725,7 @@ class MaxDBSchemaGenerator(compiler.SchemaGenerator): else: return None - def visit_sequence(self, sequence): + def visit_create_sequence(self, create): """Creates a SEQUENCE. TODO: move to module doc? @@ -1024,7 +753,8 @@ class MaxDBSchemaGenerator(compiler.SchemaGenerator): maxdb_no_cache Defaults to False. If true, sets NOCACHE. """ - + sequence = create.element + if (not sequence.optional and (not self.checkfirst or not self.dialect.has_sequence(self.connection, sequence.name))): @@ -1061,18 +791,251 @@ class MaxDBSchemaGenerator(compiler.SchemaGenerator): elif opts.get('no_cache', False): ddl.append('NOCACHE') - self.append(' '.join(ddl)) - self.execute() + return ' '.join(ddl) -class MaxDBSchemaDropper(compiler.SchemaDropper): - def visit_sequence(self, sequence): - if (not sequence.optional and - (not self.checkfirst or - self.dialect.has_sequence(self.connection, sequence.name))): - self.append("DROP SEQUENCE %s" % - self.preparer.format_sequence(sequence)) - self.execute() +class MaxDBDialect(default.DefaultDialect): + name = 'maxdb' + supports_alter = True + supports_unicode_statements = True + max_identifier_length = 32 + supports_sane_rowcount = True + supports_sane_multi_rowcount = False + preexecute_pk_sequences = True + + preparer = MaxDBIdentifierPreparer + statement_compiler = MaxDBCompiler + ddl_compiler = MaxDBDDLCompiler + defaultrunner = MaxDBDefaultRunner + execution_ctx_cls = MaxDBExecutionContext + + colspecs = colspecs + ischema_names = ischema_names + + # MaxDB-specific + datetimeformat = 'internal' + + def __init__(self, _raise_known_sql_errors=False, **kw): + super(MaxDBDialect, self).__init__(**kw) + self._raise_known = _raise_known_sql_errors + + if self.dbapi is None: + self.dbapi_type_map = {} + else: + self.dbapi_type_map = { + 'Long Binary': MaxBlob(), + 'Long byte_t': MaxBlob(), + 'Long Unicode': MaxText(), + 'Timestamp': MaxTimestamp(), + 'Date': MaxDate(), + 'Time': MaxTime(), + datetime.datetime: MaxTimestamp(), + datetime.date: MaxDate(), + datetime.time: MaxTime(), + } + + def do_execute(self, cursor, statement, parameters, context=None): + res = cursor.execute(statement, parameters) + if isinstance(res, int) and context is not None: + context._rowcount = res + + def do_release_savepoint(self, connection, name): + # Does MaxDB truly support RELEASE SAVEPOINT ? All my attempts + # produce "SUBTRANS COMMIT/ROLLBACK not allowed without SUBTRANS + # BEGIN SQLSTATE: I7065" + # Note that ROLLBACK TO works fine. In theory, a RELEASE should + # just free up some transactional resources early, before the overall + # COMMIT/ROLLBACK so omitting it should be relatively ok. + pass + + def get_default_schema_name(self, connection): + try: + return self._default_schema_name + except AttributeError: + name = self.identifier_preparer._normalize_name( + connection.execute('SELECT CURRENT_SCHEMA FROM DUAL').scalar()) + self._default_schema_name = name + return name + + def has_table(self, connection, table_name, schema=None): + denormalize = self.identifier_preparer._denormalize_name + bind = [denormalize(table_name)] + if schema is None: + sql = ("SELECT tablename FROM TABLES " + "WHERE TABLES.TABLENAME=? AND" + " TABLES.SCHEMANAME=CURRENT_SCHEMA ") + else: + sql = ("SELECT tablename FROM TABLES " + "WHERE TABLES.TABLENAME = ? AND" + " TABLES.SCHEMANAME=? ") + bind.append(denormalize(schema)) + + rp = connection.execute(sql, bind) + found = bool(rp.fetchone()) + rp.close() + return found + + def table_names(self, connection, schema): + if schema is None: + sql = (" SELECT TABLENAME FROM TABLES WHERE " + " SCHEMANAME=CURRENT_SCHEMA ") + rs = connection.execute(sql) + else: + sql = (" SELECT TABLENAME FROM TABLES WHERE " + " SCHEMANAME=? ") + matchname = self.identifier_preparer._denormalize_name(schema) + rs = connection.execute(sql, matchname) + normalize = self.identifier_preparer._normalize_name + return [normalize(row[0]) for row in rs] + + def reflecttable(self, connection, table, include_columns): + denormalize = self.identifier_preparer._denormalize_name + normalize = self.identifier_preparer._normalize_name + + st = ('SELECT COLUMNNAME, MODE, DATATYPE, CODETYPE, LEN, DEC, ' + ' NULLABLE, "DEFAULT", DEFAULTFUNCTION ' + 'FROM COLUMNS ' + 'WHERE TABLENAME=? AND SCHEMANAME=%s ' + 'ORDER BY POS') + + fk = ('SELECT COLUMNNAME, FKEYNAME, ' + ' REFSCHEMANAME, REFTABLENAME, REFCOLUMNNAME, RULE, ' + ' (CASE WHEN REFSCHEMANAME = CURRENT_SCHEMA ' + ' THEN 1 ELSE 0 END) AS in_schema ' + 'FROM FOREIGNKEYCOLUMNS ' + 'WHERE TABLENAME=? AND SCHEMANAME=%s ' + 'ORDER BY FKEYNAME ') + + params = [denormalize(table.name)] + if not table.schema: + st = st % 'CURRENT_SCHEMA' + fk = fk % 'CURRENT_SCHEMA' + else: + st = st % '?' + fk = fk % '?' + params.append(denormalize(table.schema)) + + rows = connection.execute(st, params).fetchall() + if not rows: + raise exc.NoSuchTableError(table.fullname) + + include_columns = set(include_columns or []) + + for row in rows: + (name, mode, col_type, encoding, length, scale, + nullable, constant_def, func_def) = row + + name = normalize(name) + + if include_columns and name not in include_columns: + continue + + type_args, type_kw = [], {} + if col_type == 'FIXED': + type_args = length, scale + # Convert FIXED(10) DEFAULT SERIAL to our Integer + if (scale == 0 and + func_def is not None and func_def.startswith('SERIAL')): + col_type = 'INTEGER' + type_args = length, + elif col_type in 'FLOAT': + type_args = length, + elif col_type in ('CHAR', 'VARCHAR'): + type_args = length, + type_kw['encoding'] = encoding + elif col_type == 'LONG': + type_kw['encoding'] = encoding + + try: + type_cls = ischema_names[col_type.lower()] + type_instance = type_cls(*type_args, **type_kw) + except KeyError: + util.warn("Did not recognize type '%s' of column '%s'" % + (col_type, name)) + type_instance = sqltypes.NullType + + col_kw = {'autoincrement': False} + col_kw['nullable'] = (nullable == 'YES') + col_kw['primary_key'] = (mode == 'KEY') + + if func_def is not None: + if func_def.startswith('SERIAL'): + if col_kw['primary_key']: + # No special default- let the standard autoincrement + # support handle SERIAL pk columns. + col_kw['autoincrement'] = True + else: + # strip current numbering + col_kw['server_default'] = schema.DefaultClause( + sql.text('SERIAL')) + col_kw['autoincrement'] = True + else: + col_kw['server_default'] = schema.DefaultClause( + sql.text(func_def)) + elif constant_def is not None: + col_kw['server_default'] = schema.DefaultClause(sql.text( + "'%s'" % constant_def.replace("'", "''"))) + + table.append_column(schema.Column(name, type_instance, **col_kw)) + + fk_sets = itertools.groupby(connection.execute(fk, params), + lambda row: row.FKEYNAME) + for fkeyname, fkey in fk_sets: + fkey = list(fkey) + if include_columns: + key_cols = set([r.COLUMNNAME for r in fkey]) + if key_cols != include_columns: + continue + + columns, referants = [], [] + quote = self.identifier_preparer._maybe_quote_identifier + + for row in fkey: + columns.append(normalize(row.COLUMNNAME)) + if table.schema or not row.in_schema: + referants.append('.'.join( + [quote(normalize(row[c])) + for c in ('REFSCHEMANAME', 'REFTABLENAME', + 'REFCOLUMNNAME')])) + else: + referants.append('.'.join( + [quote(normalize(row[c])) + for c in ('REFTABLENAME', 'REFCOLUMNNAME')])) + + constraint_kw = {'name': fkeyname.lower()} + if fkey[0].RULE is not None: + rule = fkey[0].RULE + if rule.startswith('DELETE '): + rule = rule[7:] + constraint_kw['ondelete'] = rule + + table_kw = {} + if table.schema or not row.in_schema: + table_kw['schema'] = normalize(fkey[0].REFSCHEMANAME) + + ref_key = schema._get_table_key(normalize(fkey[0].REFTABLENAME), + table_kw.get('schema')) + if ref_key not in table.metadata.tables: + schema.Table(normalize(fkey[0].REFTABLENAME), + table.metadata, + autoload=True, autoload_with=connection, + **table_kw) + + constraint = schema.ForeignKeyConstraint(columns, referants, link_to_name=True, + **constraint_kw) + table.append_constraint(constraint) + + def has_sequence(self, connection, name): + # [ticket:726] makes this schema-aware. + denormalize = self.identifier_preparer._denormalize_name + sql = ("SELECT sequence_name FROM SEQUENCES " + "WHERE SEQUENCE_NAME=? ") + + rp = connection.execute(sql, denormalize(name)) + found = bool(rp.fetchone()) + rp.close() + return found + def _autoserial_column(table): @@ -1090,10 +1053,3 @@ def _autoserial_column(table): return None, None -dialect = MaxDBDialect -dialect.preparer = MaxDBIdentifierPreparer -dialect.statement_compiler = MaxDBCompiler -dialect.schemagenerator = MaxDBSchemaGenerator -dialect.schemadropper = MaxDBSchemaDropper -dialect.defaultrunner = MaxDBDefaultRunner -dialect.execution_ctx_cls = MaxDBExecutionContext \ No newline at end of file diff --git a/lib/sqlalchemy/dialects/maxdb/sapdb.py b/lib/sqlalchemy/dialects/maxdb/sapdb.py new file mode 100644 index 0000000000..10e61228e9 --- /dev/null +++ b/lib/sqlalchemy/dialects/maxdb/sapdb.py @@ -0,0 +1,17 @@ +from sqlalchemy.dialects.maxdb.base import MaxDBDialect + +class MaxDB_sapdb(MaxDBDialect): + driver = 'sapdb' + + @classmethod + def dbapi(cls): + from sapdb import dbapi as _dbapi + return _dbapi + + def create_connect_args(self, url): + opts = url.translate_connect_args(username='user') + opts.update(url.query) + return [], opts + + +dialect = MaxDB_sapdb \ No newline at end of file diff --git a/lib/sqlalchemy/dialects/mssql/__init__.py b/lib/sqlalchemy/dialects/mssql/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/dialects/mssql/base.py similarity index 100% rename from lib/sqlalchemy/databases/mssql.py rename to lib/sqlalchemy/dialects/mssql/base.py diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 74938abe0d..ad675839e0 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1479,31 +1479,32 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): spec += ' ZEROFILL' return spec - def _extend_string(self, type_, spec): + def _extend_string(self, type_, defaults, spec): """Extend a string-type declaration with standard SQL CHARACTER SET / COLLATE annotations and MySQL specific extensions. """ - if not self._mysql_type(type_): - return spec + + def attr(name): + return getattr(type_, name, defaults.get(name)) - if type_.charset: - charset = 'CHARACTER SET %s' % type_.charset - elif type_.ascii: + if attr('charset'): + charset = 'CHARACTER SET %s' % attr('charset') + elif attr('ascii'): charset = 'ASCII' - elif type_.unicode: + elif attr('unicode'): charset = 'UNICODE' else: charset = None - if type_.collation: + if attr('collation'): collation = 'COLLATE %s' % type_.collation - elif type_.binary: + elif attr('binary'): collation = 'BINARY' else: collation = None - if type_.national: + if attr('national'): # NATIONAL (aka NCHAR/NVARCHAR) trumps charsets. return ' '.join([c for c in ('NATIONAL', spec, collation) if c is not None]) @@ -1607,36 +1608,36 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): def visit_TEXT(self, type_): if type_.length: - return self._extend_string(type_, "TEXT(%d)" % type_.length) + return self._extend_string(type_, {}, "TEXT(%d)" % type_.length) else: - return self._extend_string(type_, "TEXT") + return self._extend_string(type_, {}, "TEXT") def visit_TINYTEXT(self, type_): - return self._extend_string(type_, "TINYTEXT") + return self._extend_string(type_, {}, "TINYTEXT") def visit_MEDIUMTEXT(self, type_): - return self._extend_string(type_, "MEDIUMTEXT") + return self._extend_string(type_, {}, "MEDIUMTEXT") def visit_LONGTEXT(self, type_): - return self._extend_string(type_, "LONGTEXT") + return self._extend_string(type_, {}, "LONGTEXT") def visit_VARCHAR(self, type_): if type_.length: - return self._extend_string(type_, "VARCHAR(%d)" % type_.length) + return self._extend_string(type_, {}, "VARCHAR(%d)" % type_.length) else: - return self._extend_string(type_, "VARCHAR") + return self._extend_string(type_, {}, "VARCHAR") def visit_CHAR(self, type_): - return self._extend_string(type_, "CHAR(%(length)s)" % {'length' : type_.length}) + return self._extend_string(type_, {'national':True}, "CHAR(%(length)s)" % {'length' : type_.length}) def visit_NVARCHAR(self, type_): # We'll actually generate the equiv. "NATIONAL VARCHAR" instead # of "NVARCHAR". - return self._extend_string(type_, "VARCHAR(%(length)s)" % {'length': type_.length}) + return self._extend_string(type_, {'national':True}, "VARCHAR(%(length)s)" % {'length': type_.length}) def visit_NCHAR(self, type_): # We'll actually generate the equiv. "NATIONAL CHAR" instead of "NCHAR". - return self._extend_string(type_, "CHAR(%(length)s)" % {'length': type_.length}) + return self._extend_string(type_, {'national':True}, "CHAR(%(length)s)" % {'length': type_.length}) def visit_VARBINARY(self, type_): if type_.length: @@ -1672,10 +1673,10 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): quoted_enums = [] for e in type_.enums: quoted_enums.append("'%s'" % e.replace("'", "''")) - return self._extend_string(type_, "ENUM(%s)" % ",".join(quoted_enums)) + return self._extend_string(type_, {}, "ENUM(%s)" % ",".join(quoted_enums)) def visit_SET(self, type_): - return self._extend_string(type_, "SET(%s)" % ",".join(type_._ddl_values)) + return self._extend_string(type_, {}, "SET(%s)" % ",".join(type_._ddl_values)) def visit_BOOLEAN(self, type): return "BOOL" diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 9bf6db23d6..8daf6404b7 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -220,7 +220,7 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler): # Oracle does not allow milliseconds in DATE # Oracle does not support TIME columns - def visit_DATETIME(self, type_): + def visit_datetime(self, type_): return self.visit_DATE(type_) def visit_VARCHAR(self, type_): @@ -229,13 +229,13 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler): def visit_NVARCHAR(self, type_): return "NVARCHAR2(%(length)s)" % {'length' : type_.length} - def visit_TEXT(self, type_): + def visit_text(self, type_): return self.visit_CLOB(type_) - def visit_BINARY(self, type_): + def visit_binary(self, type_): return self.visit_BLOB(type_) - def visit_BOOLEAN(self, type_): + def visit_boolean(self, type_): return self.visit_SMALLINT(type_) def visit_RAW(self, type_): diff --git a/lib/sqlalchemy/dialects/postgres/base.py b/lib/sqlalchemy/dialects/postgres/base.py index 4a85e43e99..2c5cadff77 100644 --- a/lib/sqlalchemy/dialects/postgres/base.py +++ b/lib/sqlalchemy/dialects/postgres/base.py @@ -367,7 +367,10 @@ class PGTypeCompiler(compiler.GenericTypeCompiler): def visit_BIGINT(self, type_): return "BIGINT" - def visit_DATETIME(self, type_): + def visit_datetime(self, type_): + return self.visit_TIMESTAMP(type_) + + def visit_TIMESTAMP(self, type_): return "TIMESTAMP " + (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE" def visit_TIME(self, type_): diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index 319a5bffc6..bb297bc7dc 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -215,12 +215,6 @@ class SQLiteTypeCompiler(compiler.GenericTypeCompiler): def visit_binary(self, type_): return self.visit_BLOB(type_) - def visit_CLOB(self, type_): - return self.visit_TEXT(type_) - - def visit_NCHAR(self, type_): - return self.visit_CHAR(type_) - class SQLiteIdentifierPreparer(compiler.IdentifierPreparer): reserved_words = set([ 'add', 'after', 'all', 'alter', 'analyze', 'and', 'as', 'asc', diff --git a/lib/sqlalchemy/dialects/sybase/__init__.py b/lib/sqlalchemy/dialects/sybase/__init__.py new file mode 100644 index 0000000000..f8baf339e8 --- /dev/null +++ b/lib/sqlalchemy/dialects/sybase/__init__.py @@ -0,0 +1,4 @@ +from sqlalchemy.dialects.sybase import base, pyodbc + +# default dialect +base.dialect = pyodbc.dialect \ No newline at end of file diff --git a/lib/sqlalchemy/dialects/sybase/base.py b/lib/sqlalchemy/dialects/sybase/base.py new file mode 100644 index 0000000000..300edebf90 --- /dev/null +++ b/lib/sqlalchemy/dialects/sybase/base.py @@ -0,0 +1,453 @@ +# sybase.py +# Copyright (C) 2007 Fisch Asset Management AG http://www.fam.ch +# Coding: Alexander Houben alexander.houben@thor-solutions.ch +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" +Sybase database backend. + +--- THIS BACKEND NOT YET TESTED ON SQLALCHEMY 0.6 --- + +Known issues / TODO: + + * Uses the mx.ODBC driver from egenix (version 2.1.0) + * The current version of sqlalchemy.databases.sybase only supports + mx.ODBC.Windows (other platforms such as mx.ODBC.unixODBC still need + some development) + * Support for pyodbc has been built in but is not yet complete (needs + further development) + * Results of running tests/alltests.py: + Ran 934 tests in 287.032s + FAILED (failures=3, errors=1) + * Tested on 'Adaptive Server Anywhere 9' (version 9.0.1.1751) +""" + +import datetime, operator + +from sqlalchemy import util, sql, schema, exc +from sqlalchemy.sql import compiler, expression +from sqlalchemy.engine import default, base +from sqlalchemy import types as sqltypes +from sqlalchemy.sql import operators as sql_operators +from sqlalchemy import MetaData, Table, Column +from sqlalchemy import String, Integer, SMALLINT, CHAR, ForeignKey +from sqlalchemy.dialects.sybase.schema import * + +__all__ = [ + 'SybaseMoney', 'SybaseSmallMoney', + 'SybaseUniqueIdentifier', + ] + + +RESERVED_WORDS = set([ + "add", "all", "alter", "and", + "any", "as", "asc", "backup", + "begin", "between", "bigint", "binary", + "bit", "bottom", "break", "by", + "call", "capability", "cascade", "case", + "cast", "char", "char_convert", "character", + "check", "checkpoint", "close", "comment", + "commit", "connect", "constraint", "contains", + "continue", "convert", "create", "cross", + "cube", "current", "current_timestamp", "current_user", + "cursor", "date", "dbspace", "deallocate", + "dec", "decimal", "declare", "default", + "delete", "deleting", "desc", "distinct", + "do", "double", "drop", "dynamic", + "else", "elseif", "encrypted", "end", + "endif", "escape", "except", "exception", + "exec", "execute", "existing", "exists", + "externlogin", "fetch", "first", "float", + "for", "force", "foreign", "forward", + "from", "full", "goto", "grant", + "group", "having", "holdlock", "identified", + "if", "in", "index", "index_lparen", + "inner", "inout", "insensitive", "insert", + "inserting", "install", "instead", "int", + "integer", "integrated", "intersect", "into", + "iq", "is", "isolation", "join", + "key", "lateral", "left", "like", + "lock", "login", "long", "match", + "membership", "message", "mode", "modify", + "natural", "new", "no", "noholdlock", + "not", "notify", "null", "numeric", + "of", "off", "on", "open", + "option", "options", "or", "order", + "others", "out", "outer", "over", + "passthrough", "precision", "prepare", "primary", + "print", "privileges", "proc", "procedure", + "publication", "raiserror", "readtext", "real", + "reference", "references", "release", "remote", + "remove", "rename", "reorganize", "resource", + "restore", "restrict", "return", "revoke", + "right", "rollback", "rollup", "save", + "savepoint", "scroll", "select", "sensitive", + "session", "set", "setuser", "share", + "smallint", "some", "sqlcode", "sqlstate", + "start", "stop", "subtrans", "subtransaction", + "synchronize", "syntax_error", "table", "temporary", + "then", "time", "timestamp", "tinyint", + "to", "top", "tran", "trigger", + "truncate", "tsequal", "unbounded", "union", + "unique", "unknown", "unsigned", "update", + "updating", "user", "using", "validate", + "values", "varbinary", "varchar", "variable", + "varying", "view", "wait", "waitfor", + "when", "where", "while", "window", + "with", "with_cube", "with_lparen", "with_rollup", + "within", "work", "writetext", + ]) + + +class SybaseImage(sqltypes.Binary): + __visit_name__ = 'IMAGE' + +class SybaseBit(sqltypes.TypeEngine): + __visit_name__ = 'BIT' + +class SybaseMoney(sqltypes.TypeEngine): + __visit_name__ = "MONEY" + +class SybaseSmallMoney(SybaseMoney): + __visit_name__ = "SMALLMONEY" + +class SybaseUniqueIdentifier(sqltypes.TypeEngine): + __visit_name__ = "UNIQUEIDENTIFIER" + +class SybaseBoolean(sqltypes.Boolean): + def result_processor(self, dialect): + def process(value): + if value is None: + return None + return value and True or False + return process + + def bind_processor(self, dialect): + def process(value): + if value is True: + return 1 + elif value is False: + return 0 + elif value is None: + return None + else: + return value and True or False + return process + +class SybaseTypeCompiler(compiler.GenericTypeCompiler): + def visit_binary(self, type_): + return self.visit_IMAGE(type_) + + def visit_boolean(self, type_): + return self.visit_BIT(type_) + + def visit_IMAGE(self, type_): + return "IMAGE" + + def visit_BIT(self, type_): + return "BIT" + + def visit_MONEY(self, type_): + return "MONEY" + + def visit_SMALLMONEY(self, type_): + return "SMALLMONEY" + + def visit_UNIQUEIDENTIFIER(self, type_): + return "UNIQUEIDENTIFIER" + +colspecs = { + sqltypes.Binary : SybaseImage, + sqltypes.Boolean : SybaseBoolean, +} + +ischema_names = { + 'integer' : sqltypes.INTEGER, + 'unsigned int' : sqltypes.Integer, + 'unsigned smallint' : sqltypes.SmallInteger, + 'unsigned bigint' : sqltypes.BigInteger, + 'bigint': sqltypes.BIGINT, + 'smallint' : sqltypes.SMALLINT, + 'tinyint' : sqltypes.SmallInteger, + 'varchar' : sqltypes.VARCHAR, + 'long varchar' : sqltypes.Text, + 'char' : sqltypes.CHAR, + 'decimal' : sqltypes.DECIMAL, + 'numeric' : sqltypes.NUMERIC, + 'float' : sqltypes.FLOAT, + 'double' : sqltypes.Numeric, + 'binary' : sqltypes.Binary, + 'long binary' : sqltypes.Binary, + 'varbinary' : sqltypes.Binary, + 'bit': SybaseBit, + 'image' : SybaseImage, + 'timestamp': sqltypes.TIMESTAMP, + 'money': SybaseMoney, + 'smallmoney': SybaseSmallMoney, + 'uniqueidentifier': SybaseUniqueIdentifier, + +} + + +class SybaseExecutionContext(default.DefaultExecutionContext): + + def post_exec(self): + if self.compiled.isinsert: + table = self.compiled.statement.table + # get the inserted values of the primary key + + # get any sequence IDs first (using @@identity) + self.cursor.execute("SELECT @@identity AS lastrowid") + row = self.cursor.fetchone() + lastrowid = int(row[0]) + if lastrowid > 0: + # an IDENTITY was inserted, fetch it + # FIXME: always insert in front ? This only works if the IDENTITY is the first column, no ?! + if not hasattr(self, '_last_inserted_ids') or self._last_inserted_ids is None: + self._last_inserted_ids = [lastrowid] + else: + self._last_inserted_ids = [lastrowid] + self._last_inserted_ids[1:] + + +class SybaseSQLCompiler(compiler.SQLCompiler): + operators = compiler.SQLCompiler.operators.copy() + operators.update({ + sql_operators.mod: lambda x, y: "MOD(%s, %s)" % (x, y), + }) + + def bindparam_string(self, name): + res = super(SybaseSQLCompiler, self).bindparam_string(name) + if name.lower().startswith('literal'): + res = 'STRING(%s)' % res + return res + + def get_select_precolumns(self, select): + s = select._distinct and "DISTINCT " or "" + if select._limit: + #if select._limit == 1: + #s += "FIRST " + #else: + #s += "TOP %s " % (select._limit,) + s += "TOP %s " % (select._limit,) + if select._offset: + if not select._limit: + # FIXME: sybase doesn't allow an offset without a limit + # so use a huge value for TOP here + s += "TOP 1000000 " + s += "START AT %s " % (select._offset+1,) + return s + + def limit_clause(self, select): + # Limit in sybase is after the select keyword + return "" + + def visit_binary(self, binary): + """Move bind parameters to the right-hand side of an operator, where possible.""" + if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq: + return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator)) + else: + return super(SybaseSQLCompiler, self).visit_binary(binary) + + def label_select_column(self, select, column, asfrom): + if isinstance(column, expression.Function): + return column.label(None) + else: + return super(SybaseSQLCompiler, self).label_select_column(select, column, asfrom) + + function_rewrites = {'current_date': 'getdate', + } + def visit_function(self, func): + func.name = self.function_rewrites.get(func.name, func.name) + res = super(SybaseSQLCompiler, self).visit_function(func) + if func.name.lower() == 'getdate': + # apply CAST operator + # FIXME: what about _pyodbc ? + cast = expression._Cast(func, SybaseDate_mxodbc) + # infinite recursion + # res = self.visit_cast(cast) + res = "CAST(%s AS %s)" % (res, self.process(cast.typeclause)) + return res + + def for_update_clause(self, select): + # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use + return '' + + def order_by_clause(self, select): + order_by = self.process(select._order_by_clause) + + # SybaseSQL only allows ORDER BY in subqueries if there is a LIMIT + if order_by and (not self.is_subquery() or select._limit): + return " ORDER BY " + order_by + else: + return "" + + +class SybaseDDLCompiler(compiler.DDLCompiler): + def get_column_specification(self, column, **kwargs): + + colspec = self.preparer.format_column(column) + + if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \ + column.autoincrement and isinstance(column.type, sqltypes.Integer): + if column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional): + column.sequence = schema.Sequence(column.name + '_seq') + + if hasattr(column, 'sequence'): + column.table.has_sequence = column + #colspec += " numeric(30,0) IDENTITY" + colspec += " Integer IDENTITY" + else: + colspec += " " + self.dialect.type_compiler.process(column.type) + + if not column.nullable: + colspec += " NOT NULL" + + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + + return colspec + + def visit_drop_index(self, drop): + index = drop.element + return "\nDROP INDEX %s.%s" % ( + self.preparer.quote_identifier(index.table.name), + self.preparer.quote(self._validate_identifier(index.name, False), index.quote) + ) + +class SybaseIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = RESERVED_WORDS + +class SybaseDialect(default.DefaultDialect): + name = 'sybase' + supports_unicode_statements = False + supports_sane_rowcount = False + supports_sane_multi_rowcount = False + colspecs = colspecs + ischema_names = ischema_names + + type_compiler = SybaseTypeCompiler + statement_compiler = SybaseSQLCompiler + ddl_compiler = SybaseDDLCompiler + preparer = SybaseIdentifierPreparer + + schema_name = "dba" + + def __init__(self, **params): + super(SybaseDialect, self).__init__(**params) + self.text_as_varchar = False + + def last_inserted_ids(self): + return self.context.last_inserted_ids + + def get_default_schema_name(self, connection): + return self.schema_name + + def table_names(self, connection, schema): + """Ignore the schema and the charset for now.""" + s = sql.select([tables.c.table_name], + sql.not_(tables.c.table_name.like("SYS%")) and + tables.c.creator >= 100 + ) + rp = connection.execute(s) + return [row[0] for row in rp.fetchall()] + + def has_table(self, connection, tablename, schema=None): + # FIXME: ignore schemas for sybase + s = sql.select([tables.c.table_name], tables.c.table_name == tablename) + + c = connection.execute(s) + row = c.fetchone() + print "has_table: " + tablename + ": " + str(bool(row is not None)) + return row is not None + + def reflecttable(self, connection, table, include_columns): + # Get base columns + if table.schema is not None: + current_schema = table.schema + else: + current_schema = self.get_default_schema_name(connection) + + s = sql.select([columns, domains], tables.c.table_name==table.name, from_obj=[columns.join(tables).join(domains)], order_by=[columns.c.column_id]) + + c = connection.execute(s) + found_table = False + # makes sure we append the columns in the correct order + while True: + row = c.fetchone() + if row is None: + break + found_table = True + (name, type, nullable, charlen, numericprec, numericscale, default, primary_key, max_identity, table_id, column_id) = ( + row[columns.c.column_name], + row[domains.c.domain_name], + row[columns.c.nulls] == 'Y', + row[columns.c.width], + row[domains.c.precision], + row[columns.c.scale], + row[columns.c.default], + row[columns.c.pkey] == 'Y', + row[columns.c.max_identity], + row[tables.c.table_id], + row[columns.c.column_id], + ) + if include_columns and name not in include_columns: + continue + + # FIXME: else problems with SybaseBinary(size) + if numericscale == 0: + numericscale = None + + args = [] + for a in (charlen, numericprec, numericscale): + if a is not None: + args.append(a) + coltype = self.ischema_names.get(type, None) + if coltype == SybaseString and charlen == -1: + coltype = SybaseText() + else: + if coltype is None: + util.warn("Did not recognize type '%s' of column '%s'" % + (type, name)) + coltype = sqltypes.NULLTYPE + coltype = coltype(*args) + colargs = [] + if default is not None: + colargs.append(schema.DefaultClause(sql.text(default))) + + # any sequences ? + col = schema.Column(name, coltype, nullable=nullable, primary_key=primary_key, *colargs) + if int(max_identity) > 0: + col.sequence = schema.Sequence(name + '_identity') + col.sequence.start = int(max_identity) + col.sequence.increment = 1 + + # append the column + table.append_column(col) + + # any foreign key constraint for this table ? + # note: no multi-column foreign keys are considered + s = "select st1.table_name, sc1.column_name, st2.table_name, sc2.column_name from systable as st1 join sysfkcol on st1.table_id=sysfkcol.foreign_table_id join sysforeignkey join systable as st2 on sysforeignkey.primary_table_id = st2.table_id join syscolumn as sc1 on sysfkcol.foreign_column_id=sc1.column_id and sc1.table_id=st1.table_id join syscolumn as sc2 on sysfkcol.primary_column_id=sc2.column_id and sc2.table_id=st2.table_id where st1.table_name='%(table_name)s';" % { 'table_name' : table.name } + c = connection.execute(s) + foreignKeys = {} + while True: + row = c.fetchone() + if row is None: + break + (foreign_table, foreign_column, primary_table, primary_column) = ( + row[0], row[1], row[2], row[3], + ) + if not primary_table in foreignKeys.keys(): + foreignKeys[primary_table] = [['%s' % (foreign_column)], ['%s.%s'%(primary_table, primary_column)]] + else: + foreignKeys[primary_table][0].append('%s'%(foreign_column)) + foreignKeys[primary_table][1].append('%s.%s'%(primary_table, primary_column)) + for primary_table in foreignKeys.keys(): + #table.append_constraint(schema.ForeignKeyConstraint(['%s.%s'%(foreign_table, foreign_column)], ['%s.%s'%(primary_table,primary_column)])) + table.append_constraint(schema.ForeignKeyConstraint(foreignKeys[primary_table][0], foreignKeys[primary_table][1], link_to_name=True)) + + if not found_table: + raise exc.NoSuchTableError(table.name) + diff --git a/lib/sqlalchemy/dialects/sybase/mxodbc.py b/lib/sqlalchemy/dialects/sybase/mxodbc.py new file mode 100644 index 0000000000..86a23d5bcd --- /dev/null +++ b/lib/sqlalchemy/dialects/sybase/mxodbc.py @@ -0,0 +1,10 @@ +from sqlalchemy.dialects.sybase.base import SybaseDialect, SybaseExecutionContext +from sqlalchemy.connectors.mxodbc import MxODBCConnector + +class SybaseExecutionContext_mxodbc(SybaseExecutionContext): + pass + +class Sybase_mxodbc(MxODBCConnector, SybaseDialect): + execution_ctx_cls = SybaseExecutionContext_mxodbc + +dialect = Sybase_mxodbc \ No newline at end of file diff --git a/lib/sqlalchemy/dialects/sybase/pyodbc.py b/lib/sqlalchemy/dialects/sybase/pyodbc.py new file mode 100644 index 0000000000..61c6f32928 --- /dev/null +++ b/lib/sqlalchemy/dialects/sybase/pyodbc.py @@ -0,0 +1,11 @@ +from sqlalchemy.dialects.sybase.base import SybaseDialect, SybaseExecutionContext +from sqlalchemy.connectors.pyodbc import PyODBCConnector + +class SybaseExecutionContext_pyodbc(SybaseExecutionContext): + pass + + +class Sybase_pyodbc(PyODBCConnector, SybaseDialect): + execution_ctx_cls = SybaseExecutionContext_pyodbc + +dialect = Sybase_pyodbc \ No newline at end of file diff --git a/lib/sqlalchemy/dialects/sybase/schema.py b/lib/sqlalchemy/dialects/sybase/schema.py new file mode 100644 index 0000000000..15ac6b27bd --- /dev/null +++ b/lib/sqlalchemy/dialects/sybase/schema.py @@ -0,0 +1,51 @@ +from sqlalchemy import * + +ischema = MetaData() + +tables = Table("SYSTABLE", ischema, + Column("table_id", Integer, primary_key=True), + Column("file_id", SMALLINT), + Column("table_name", CHAR(128)), + Column("table_type", CHAR(10)), + Column("creator", Integer), + #schema="information_schema" + ) + +domains = Table("SYSDOMAIN", ischema, + Column("domain_id", Integer, primary_key=True), + Column("domain_name", CHAR(128)), + Column("type_id", SMALLINT), + Column("precision", SMALLINT, quote=True), + #schema="information_schema" + ) + +columns = Table("SYSCOLUMN", ischema, + Column("column_id", Integer, primary_key=True), + Column("table_id", Integer, ForeignKey(tables.c.table_id)), + Column("pkey", CHAR(1)), + Column("column_name", CHAR(128)), + Column("nulls", CHAR(1)), + Column("width", SMALLINT), + Column("domain_id", SMALLINT, ForeignKey(domains.c.domain_id)), + # FIXME: should be mx.BIGINT + Column("max_identity", Integer), + # FIXME: should be mx.ODBC.Windows.LONGVARCHAR + Column("default", String), + Column("scale", Integer), + #schema="information_schema" + ) + +foreignkeys = Table("SYSFOREIGNKEY", ischema, + Column("foreign_table_id", Integer, ForeignKey(tables.c.table_id), primary_key=True), + Column("foreign_key_id", SMALLINT, primary_key=True), + Column("primary_table_id", Integer, ForeignKey(tables.c.table_id)), + #schema="information_schema" + ) +fkcols = Table("SYSFKCOL", ischema, + Column("foreign_table_id", Integer, ForeignKey(columns.c.table_id), primary_key=True), + Column("foreign_key_id", SMALLINT, ForeignKey(foreignkeys.c.foreign_key_id), primary_key=True), + Column("foreign_column_id", Integer, ForeignKey(columns.c.column_id), primary_key=True), + Column("primary_column_id", Integer), + #schema="information_schema" + ) + diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 6ea74395ec..7305f497ef 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -964,6 +964,9 @@ class GenericTypeCompiler(engine.TypeCompiler): def visit_SMALLINT(self, type_): return "SMALLINT" + def visit_BIGINT(self, type_): + return "BIGINT" + def visit_TIMESTAMP(self, type_): return 'TIMESTAMP' @@ -988,9 +991,6 @@ class GenericTypeCompiler(engine.TypeCompiler): def visit_BLOB(self, type_): return "BLOB" - def visit_BINARY(self, type_): - return "BINARY" - def visit_BOOLEAN(self, type_): return "BOOLEAN" @@ -998,25 +998,35 @@ class GenericTypeCompiler(engine.TypeCompiler): return "TEXT" def visit_binary(self, type_): - return self.visit_BINARY(type_) + return self.visit_BLOB(type_) + def visit_boolean(self, type_): return self.visit_BOOLEAN(type_) + def visit_time(self, type_): return self.visit_TIME(type_) + def visit_datetime(self, type_): return self.visit_DATETIME(type_) + def visit_date(self, type_): return self.visit_DATE(type_) + def visit_small_integer(self, type_): return self.visit_SMALLINT(type_) + def visit_integer(self, type_): return self.visit_INTEGER(type_) + def visit_float(self, type_): return self.visit_FLOAT(type_) + def visit_numeric(self, type_): return self.visit_NUMERIC(type_) + def visit_string(self, type_): return self.visit_VARCHAR(type_) + def visit_text(self, type_): return self.visit_TEXT(type_) diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index 5e3946c1f1..2f5548236f 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -12,9 +12,9 @@ For more information see the SQLAlchemy documentation on types. """ __all__ = [ 'TypeEngine', 'TypeDecorator', 'AbstractType', - 'INT', 'CHAR', 'VARCHAR', 'NCHAR', 'TEXT', 'Text', 'FLOAT', + 'INT', 'CHAR', 'VARCHAR', 'NCHAR', 'NVARCHAR','TEXT', 'Text', 'FLOAT', 'NUMERIC', 'DECIMAL', 'TIMESTAMP', 'DATETIME', 'CLOB', 'BLOB', - 'BOOLEAN', 'SMALLINT', 'DATE', 'TIME', + 'BOOLEAN', 'SMALLINT', 'INTEGER','DATE', 'TIME', 'String', 'Integer', 'SmallInteger', 'Numeric', 'Float', 'DateTime', 'Date', 'Time', 'Binary', 'Boolean', 'Unicode', 'MutableType', 'Concatenable', 'UnicodeText', 'PickleType', 'Interval', @@ -39,6 +39,9 @@ class AbstractType(Visitable): def __init__(self, *args, **kwargs): pass + def compile(self, dialect): + return dialect.type_compiler.process(self) + def copy_value(self, value): return value @@ -609,6 +612,16 @@ class SmallInteger(Integer): __visit_name__ = 'small_integer' +class BigInteger(Integer): + """A type for bigger ``int`` integers. + + Typically generates a ``BIGINT`` in DDL, and otherwise acts like + a normal :class:`Integer` on the Python side. + + """ + + __visit_name__ = 'big_integer' + class Numeric(TypeEngine): """A type for fixed precision numbers. @@ -911,6 +924,11 @@ class SMALLINT(SmallInteger): __visit_name__ = 'SMALLINT' +class BIGINT(SmallInteger): + """The SQL BIGINT type.""" + + __visit_name__ = 'BIGINT' + class TIMESTAMP(DateTime): """The SQL TIMESTAMP type.""" diff --git a/test/sql/select.py b/test/sql/select.py index 72c552fffd..1790b3cdea 100644 --- a/test/sql/select.py +++ b/test/sql/select.py @@ -1240,7 +1240,7 @@ UNION SELECT mytable.myid FROM mytable WHERE mytable.myid = :myid_2)") self.assertEqual(str(cast(tbl.c.v1, Numeric).compile(dialect=dialect)), 'CAST(casttest.v1 AS %s)' %expected_results[0]) self.assertEqual(str(cast(tbl.c.v1, Numeric(12, 9)).compile(dialect=dialect)), 'CAST(casttest.v1 AS %s)' %expected_results[1]) self.assertEqual(str(cast(tbl.c.ts, Date).compile(dialect=dialect)), 'CAST(casttest.ts AS %s)' %expected_results[2]) - self.assertEqual(str(cast(1234, TEXT).compile(dialect=dialect)), 'CAST(%s AS %s)' %(literal, expected_results[3])) + self.assertEqual(str(cast(1234, Text).compile(dialect=dialect)), 'CAST(%s AS %s)' %(literal, expected_results[3])) self.assertEqual(str(cast('test', String(20)).compile(dialect=dialect)), 'CAST(%s AS %s)' %(literal, expected_results[4])) # fixme: shoving all of this dialect-specific stuff in one test # is now officialy completely ridiculous AND non-obviously omits diff --git a/test/sql/testtypes.py b/test/sql/testtypes.py index 73fcabb4a0..34acba4c74 100644 --- a/test/sql/testtypes.py +++ b/test/sql/testtypes.py @@ -68,12 +68,12 @@ class AdaptTest(TestBase): firebird_dialect = firebird.FBDialect() for dialect, start, test in [ - (oracle_dialect, String(), oracle.OracleString), - (oracle_dialect, VARCHAR(), oracle.OracleString), - (oracle_dialect, String(50), oracle.OracleString), - (oracle_dialect, Unicode(), oracle.OracleString), + (oracle_dialect, String(), String), + (oracle_dialect, VARCHAR(), VARCHAR), + (oracle_dialect, String(50), String), + (oracle_dialect, Unicode(), Unicode), (oracle_dialect, UnicodeText(), oracle.OracleText), - (oracle_dialect, NCHAR(), oracle.OracleString), + (oracle_dialect, NCHAR(), NCHAR), (oracle_dialect, oracle.OracleRaw(50), oracle.OracleRaw), (mysql_dialect, String(), mysql.MSString), (mysql_dialect, VARCHAR(), mysql.MSString), @@ -96,7 +96,43 @@ class AdaptTest(TestBase): ]: assert isinstance(start.dialect_impl(dialect), test), "wanted %r got %r" % (test, start.dialect_impl(dialect)) - + def test_uppercase_rendering(self): + """Test that uppercase types from types.py always render as their type. + + As of SQLA 0.6, using an uppercase type means you want specifically that + type. If the database in use doesn't support that DDL, it (the DB backend) + should raise an error - it means you should be using a lowercased (genericized) type. + + """ + + for dialect in [oracle.dialect(), mysql.dialect(), postgres.dialect(), sqlite.dialect(), sybase.dialect(), informix.dialect(), maxdb.dialect()]: #engines.all_dialects(): + for type_, expected in ( + (FLOAT, "FLOAT"), + (NUMERIC, "NUMERIC"), + (DECIMAL, "DECIMAL"), + (INTEGER, "INTEGER"), + (SMALLINT, "SMALLINT"), + (TIMESTAMP, "TIMESTAMP"), + (DATETIME, "DATETIME"), + (DATE, "DATE"), + (TIME, "TIME"), + (CLOB, "CLOB"), + (VARCHAR, "VARCHAR"), + (NVARCHAR, ("NVARCHAR", "NATIONAL VARCHAR")), + (CHAR, "CHAR"), + (NCHAR, ("NCHAR", "NATIONAL CHAR")), + (BLOB, "BLOB"), + (BOOLEAN, ("BOOLEAN", "BOOL")) + ): + if isinstance(expected, str): + expected = (expected, ) + for exp in expected: + compiled = type_().compile(dialect=dialect) + if exp in compiled: + break + else: + assert False, "%r matches none of %r for dialect %s" % (compiled, expected, dialect.name) + class UserDefinedTest(TestBase): """tests user-defined types."""