From ddf875dd0c554660d3a4f58e261b03e22a92a625 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 10 Oct 2007 15:19:28 +0000 Subject: [PATCH] - initial sybase support checkin, [ticket:785] --- lib/sqlalchemy/databases/mxODBC.py | 68 +++ lib/sqlalchemy/databases/sybase.py | 875 +++++++++++++++++++++++++++++ test/engine/reflection.py | 1 + test/orm/assorted_eager.py | 1 + test/sql/query.py | 8 +- test/sql/unicode.py | 10 +- 6 files changed, 954 insertions(+), 9 deletions(-) create mode 100644 lib/sqlalchemy/databases/mxODBC.py create mode 100644 lib/sqlalchemy/databases/sybase.py diff --git a/lib/sqlalchemy/databases/mxODBC.py b/lib/sqlalchemy/databases/mxODBC.py new file mode 100644 index 0000000000..61649b9b3c --- /dev/null +++ b/lib/sqlalchemy/databases/mxODBC.py @@ -0,0 +1,68 @@ +# mxODBC.py +# Copyright (C) 2007 Fisch Asset Management AG http://www.fam.ch +# Coding: Alexander Houben alexander.houben@thor-solutions.ch +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +''' +A wrapper for a mx.ODBC.Windows DB-API connection. + +Makes sure the mx module is configured to return datetime objects instead of +mx.DateTime.DateTime objects. +''' + +from mx.ODBC.Windows import * + + + +''' +Override the 'cursor' method. +''' + +class Cursor: + + def __init__(self, cursor): + self.cursor = cursor + + def __getattr__(self, attr): + res = getattr(self.cursor, attr) + return res + + def execute(self, *args, **kwargs): + res = self.cursor.execute(*args, **kwargs) + return res + +class Connection: + + def myErrorHandler(self, connection, cursor, errorclass, errorvalue): + err0, err1, err2, err3 = errorvalue + #print ", ".join(["Err%d: %s"%(x, errorvalue[x]) for x in range(4)]) + if int(err1) == 109: + # Ignore "Null value eliminated in aggregate function", this is not an error + return + raise errorclass, errorvalue + + def __init__(self, conn): + self.conn = conn + # install a mx ODBC error handler + self.conn.errorhandler = self.myErrorHandler + + def __getattr__(self, attr): + res = getattr(self.conn, attr) + return res + + def cursor(self, *args, **kwargs): + res = Cursor(self.conn.cursor(*args, **kwargs)) + return res + +# override 'connect' call +def connect(*args, **kwargs): + import mx.ODBC.Windows + conn = mx.ODBC.Windows.Connect(*args, **kwargs) + conn.datetimeformat = mx.ODBC.Windows.PYDATETIME_DATETIMEFORMAT + return Connection(conn) + +# override 'Connect' call +def Connect(*args, **kwargs): + return self.connect(*args, **kwargs) diff --git a/lib/sqlalchemy/databases/sybase.py b/lib/sqlalchemy/databases/sybase.py new file mode 100644 index 0000000000..ba20d6ad51 --- /dev/null +++ b/lib/sqlalchemy/databases/sybase.py @@ -0,0 +1,875 @@ +# sybase.py +# Copyright (C) 2007 Fisch Asset Management AG http://www.fam.ch +# Coding: Alexander Houben alexander.houben@thor-solutions.ch +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +""" +Sybase database backend, supported through the mxodbc or pyodbc DBAPI2.0 interfaces. + +Known issues / TODO: + + * Uses the mx.ODBC driver from egenix (version 2.1.0) + * The current version of sqlalchemy.databases.sybase only supports mx.ODBC.Windows (other platforms such as mx.ODBC.unixODBC still need some development) + * Support for pyodbc has been built in but is not yet complete (needs further development) + * Results of running tests/alltests.py: +Ran 934 tests in 287.032s + +FAILED (failures=3, errors=1) + * Some tests had to be marked @testing.unsupported('sybase'), see patch for details + * Tested on 'Adaptive Server Anywhere 9' (version 9.0.1.1751) +""" + +import datetime, random, warnings, re, sys, operator + +from sqlalchemy import util, sql, schema, exceptions +from sqlalchemy.sql import compiler, expression +from sqlalchemy.engine import default, base +from sqlalchemy import types as sqltypes +from sqlalchemy.sql import operators as sql_operators +from sqlalchemy import select, MetaData, Table, Column, String, Integer, SMALLINT, CHAR, ForeignKey +from sqlalchemy.schema import PassiveDefault, ForeignKeyConstraint + + +import logging + +RESERVED_WORDS = util.Set([ +"add", "all", "alter", "and", +"any", "as", "asc", "backup", +"begin", "between", "bigint", "binary", +"bit", "bottom", "break", "by", +"call", "capability", "cascade", "case", +"cast", "char", "char_convert", "character", +"check", "checkpoint", "close", "comment", +"commit", "connect", "constraint", "contains", +"continue", "convert", "create", "cross", +"cube", "current", "current_timestamp", "current_user", +"cursor", "date", "dbspace", "deallocate", +"dec", "decimal", "declare", "default", +"delete", "deleting", "desc", "distinct", +"do", "double", "drop", "dynamic", +"else", "elseif", "encrypted", "end", +"endif", "escape", "except", "exception", +"exec", "execute", "existing", "exists", +"externlogin", "fetch", "first", "float", +"for", "force", "foreign", "forward", +"from", "full", "goto", "grant", +"group", "having", "holdlock", "identified", +"if", "in", "index", "index_lparen", +"inner", "inout", "insensitive", "insert", +"inserting", "install", "instead", "int", +"integer", "integrated", "intersect", "into", +"iq", "is", "isolation", "join", +"key", "lateral", "left", "like", +"lock", "login", "long", "match", +"membership", "message", "mode", "modify", +"natural", "new", "no", "noholdlock", +"not", "notify", "null", "numeric", +"of", "off", "on", "open", +"option", "options", "or", "order", +"others", "out", "outer", "over", +"passthrough", "precision", "prepare", "primary", +"print", "privileges", "proc", "procedure", +"publication", "raiserror", "readtext", "real", +"reference", "references", "release", "remote", +"remove", "rename", "reorganize", "resource", +"restore", "restrict", "return", "revoke", +"right", "rollback", "rollup", "save", +"savepoint", "scroll", "select", "sensitive", +"session", "set", "setuser", "share", +"smallint", "some", "sqlcode", "sqlstate", +"start", "stop", "subtrans", "subtransaction", +"synchronize", "syntax_error", "table", "temporary", +"then", "time", "timestamp", "tinyint", +"to", "top", "tran", "trigger", +"truncate", "tsequal", "unbounded", "union", +"unique", "unknown", "unsigned", "update", +"updating", "user", "using", "validate", +"values", "varbinary", "varchar", "variable", +"varying", "view", "wait", "waitfor", +"when", "where", "while", "window", +"with", "with_cube", "with_lparen", "with_rollup", +"within", "work", "writetext", +]) + +ischema = MetaData() + +tables = Table("SYSTABLE", ischema, + Column("table_id", Integer, primary_key=True), + Column("file_id", SMALLINT), + Column("table_name", CHAR(128)), + Column("table_type", CHAR(10)), + Column("creator", Integer), + #schema="information_schema" + ) + +domains = Table("SYSDOMAIN", ischema, + Column("domain_id", Integer, primary_key=True), + Column("domain_name", CHAR(128)), + Column("type_id", SMALLINT), + Column("precision", SMALLINT, quote=True), + #schema="information_schema" + ) + +columns = Table("SYSCOLUMN", ischema, + Column("column_id", Integer, primary_key=True), + Column("table_id", Integer, ForeignKey(tables.c.table_id)), + Column("pkey", CHAR(1)), + Column("column_name", CHAR(128)), + Column("nulls", CHAR(1)), + Column("width", SMALLINT), + Column("domain_id", SMALLINT, ForeignKey(domains.c.domain_id)), + # FIXME: should be mx.BIGINT + Column("max_identity", Integer), + # FIXME: should be mx.ODBC.Windows.LONGVARCHAR + Column("default", String), + Column("scale", Integer), + #schema="information_schema" + ) + +foreignkeys = Table("SYSFOREIGNKEY", ischema, + Column("foreign_table_id", Integer, ForeignKey(tables.c.table_id), primary_key=True), + Column("foreign_key_id", SMALLINT, primary_key=True), + Column("primary_table_id", Integer, ForeignKey(tables.c.table_id)), + #schema="information_schema" + ) +fkcols = Table("SYSFKCOL", ischema, + Column("foreign_table_id", Integer, ForeignKey(columns.c.table_id), primary_key=True), + Column("foreign_key_id", SMALLINT, ForeignKey(foreignkeys.c.foreign_key_id), primary_key=True), + Column("foreign_column_id", Integer, ForeignKey(columns.c.column_id), primary_key=True), + Column("primary_column_id", Integer), + #schema="information_schema" + ) + +class SybaseTypeError(sqltypes.TypeEngine): + def result_processor(self, dialect): + return None + + def bind_processor(self, dialect): + def process(value): + raise exceptions.NotSupportedError("Data type not supported", [value]) + return process + + def get_col_spec(self): + raise exceptions.NotSupportedError("Data type not supported") + +class SybaseNumeric(sqltypes.Numeric): + def get_col_spec(self): + if self.length is None: + if self.precision is None: + return "NUMERIC" + else: + return "NUMERIC(%(precision)s)" % {'precision' : self.precision} + else: + return "NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length} + +class SybaseFloat(sqltypes.FLOAT, SybaseNumeric): + def __init__(self, precision = 10, asdecimal = False, length = 2, **kwargs): + super(sqltypes.FLOAT, self).__init__(precision, asdecimal, **kwargs) + self.length = length + + def get_col_spec(self): + # if asdecimal is True, handle same way as SybaseNumeric + if self.asdecimal: + return SybaseNumeric.get_col_spec(self) + if self.precision is None: + return "FLOAT" + else: + return "FLOAT(%(precision)s)" % {'precision': self.precision} + + def result_processor(self, dialect): + def process(value): + if value is None: + return None + return float(value) + if self.asdecimal: + return SybaseNumeric.result_processor(self, dialect) + return process + +class SybaseInteger(sqltypes.Integer): + def get_col_spec(self): + return "INTEGER" + +class SybaseBigInteger(SybaseInteger): + def get_col_spec(self): + return "BIGINT" + +class SybaseTinyInteger(SybaseInteger): + def get_col_spec(self): + return "TINYINT" + +class SybaseSmallInteger(SybaseInteger): + def get_col_spec(self): + return "SMALLINT" + +class SybaseDateTime_mxodbc(sqltypes.DateTime): + def __init__(self, *a, **kw): + super(SybaseDateTime_mxodbc, self).__init__(False) + + def get_col_spec(self): + return "DATETIME" + +class SybaseDateTime_pyodbc(sqltypes.DateTime): + def __init__(self, *a, **kw): + super(SybaseDateTime_pyodbc, self).__init__(False) + + def get_col_spec(self): + return "DATETIME" + + def result_processor(self, dialect): + def process(value): + if value is None: + return None + # Convert the datetime.datetime back to datetime.time + return value + return process + + def bind_processor(self, dialect): + def process(value): + if value is None: + return None + return value + return process + +class SybaseDate_mxodbc(sqltypes.Date): + def __init__(self, *a, **kw): + super(SybaseDate_mxodbc, self).__init__(False) + + def get_col_spec(self): + return "DATE" + +class SybaseDate_pyodbc(sqltypes.Date): + def __init__(self, *a, **kw): + super(SybaseDate_pyodbc, self).__init__(False) + + def get_col_spec(self): + return "DATE" + +class SybaseTime_mxodbc(sqltypes.Time): + def __init__(self, *a, **kw): + super(SybaseTime_mxodbc, self).__init__(False) + + def get_col_spec(self): + return "DATETIME" + + def result_processor(self, dialect): + def process(value): + if value is None: + return None + # Convert the datetime.datetime back to datetime.time + return datetime.time(value.hour, value.minute, value.second, value.microsecond) + return process + +class SybaseTime_pyodbc(sqltypes.Time): + def __init__(self, *a, **kw): + super(SybaseTime_pyodbc, self).__init__(False) + + def get_col_spec(self): + return "DATETIME" + + def result_processor(self, dialect): + def process(value): + if value is None: + return None + # Convert the datetime.datetime back to datetime.time + return datetime.time(value.hour, value.minute, value.second, value.microsecond) + return process + + def bind_processor(self, dialect): + def process(value): + if value is None: + return None + return datetime.datetime(1970, 1, 1, value.hour, value.minute, value.second, value.microsecond) + return process + +class SybaseText(sqltypes.TEXT): + def get_col_spec(self): + return "TEXT" + +class SybaseString(sqltypes.String): + def get_col_spec(self): + return "VARCHAR(%(length)s)" % {'length' : self.length} + +class SybaseChar(sqltypes.CHAR): + def get_col_spec(self): + return "CHAR(%(length)s)" % {'length' : self.length} + +class SybaseBinary(sqltypes.Binary): + def get_col_spec(self): + return "IMAGE" + +class SybaseBoolean(sqltypes.Boolean): + def get_col_spec(self): + return "BIT" + + def result_processor(self, dialect): + def process(value): + if value is None: + return None + return value and True or False + return process + + def bind_processor(self, dialect): + def process(value): + if value is True: + return 1 + elif value is False: + return 0 + elif value is None: + return None + else: + return value and True or False + return process + +class SybaseTimeStamp(sqltypes.TIMESTAMP): + def get_col_spec(self): + return "TIMESTAMP" + +class SybaseMoney(sqltypes.TypeEngine): + def get_col_spec(self): + return "MONEY" + +class SybaseSmallMoney(SybaseMoney): + def get_col_spec(self): + return "SMALLMONEY" + +class SybaseUniqueIdentifier(sqltypes.TypeEngine): + def get_col_spec(self): + return "UNIQUEIDENTIFIER" + +def descriptor(): + return {'name':'sybase', + 'description':'SybaseSQL', + 'arguments':[ + ('user',"Database Username",None), + ('password',"Database Password",None), + ('db',"Database Name",None), + ('host',"Hostname", None), + ]} + +class SybaseSQLExecutionContext(default.DefaultExecutionContext): + pass + +class SybaseSQLExecutionContext_mxodbc(SybaseSQLExecutionContext): + + def __init__(self, dialect, connection, compiled=None, statement=None, parameters=None): + super(SybaseSQLExecutionContext_mxodbc, self).__init__(dialect, connection, compiled, statement, parameters) + + def pre_exec(self): + super(SybaseSQLExecutionContext_mxodbc, self).pre_exec() + + def post_exec(self): + if self.compiled.isinsert: + table = self.compiled.statement.table + # get the inserted values of the primary key + + # get any sequence IDs first (using @@identity) + self.cursor.execute("SELECT @@identity AS lastrowid") + row = self.cursor.fetchone() + lastrowid = int(row[0]) + if lastrowid > 0: + # an IDENTITY was inserted, fetch it + # FIXME: always insert in front ? This only works if the IDENTITY is the first column, no ?! + if not hasattr(self, '_last_inserted_ids') or self._last_inserted_ids is None: + self._last_inserted_ids = [lastrowid] + else: + self._last_inserted_ids = [lastrowid] + self._last_inserted_ids[1:] + super(SybaseSQLExecutionContext_mxodbc, self).post_exec() + +class SybaseSQLExecutionContext_pyodbc(SybaseSQLExecutionContext): + + def __init__(self, dialect, connection, compiled=None, statement=None, parameters=None): + super(SybaseSQLExecutionContext_pyodbc, self).__init__(dialect, connection, compiled, statement, parameters) + + def pre_exec(self): + super(SybaseSQLExecutionContext_pyodbc, self).pre_exec() + + def post_exec(self): + if self.compiled.isinsert: + table = self.compiled.statement.table + # get the inserted values of the primary key + + # get any sequence IDs first (using @@identity) + self.cursor.execute("SELECT @@identity AS lastrowid") + row = self.cursor.fetchone() + lastrowid = int(row[0]) + if lastrowid > 0: + # an IDENTITY was inserted, fetch it + # FIXME: always insert in front ? This only works if the IDENTITY is the first column, no ?! + if not hasattr(self, '_last_inserted_ids') or self._last_inserted_ids is None: + self._last_inserted_ids = [lastrowid] + else: + self._last_inserted_ids = [lastrowid] + self._last_inserted_ids[1:] + super(SybaseSQLExecutionContext_pyodbc, self).post_exec() + +class SybaseSQLDialect(default.DefaultDialect): + colspecs = { + # FIXME: unicode support + #sqltypes.Unicode : SybaseUnicode, + sqltypes.Integer : SybaseInteger, + sqltypes.SmallInteger : SybaseSmallInteger, + sqltypes.Numeric : SybaseNumeric, + sqltypes.Float : SybaseFloat, + sqltypes.String : SybaseString, + sqltypes.Binary : SybaseBinary, + sqltypes.Boolean : SybaseBoolean, + sqltypes.TEXT : SybaseText, + sqltypes.CHAR : SybaseChar, + sqltypes.TIMESTAMP : SybaseTimeStamp, + sqltypes.FLOAT : SybaseFloat, + } + + ischema_names = { + 'integer' : SybaseInteger, + 'unsigned int' : SybaseInteger, + 'unsigned smallint' : SybaseInteger, + 'unsigned bigint' : SybaseInteger, + 'bigint': SybaseBigInteger, + 'smallint' : SybaseSmallInteger, + 'tinyint' : SybaseTinyInteger, + 'varchar' : SybaseString, + 'long varchar' : SybaseText, + 'char' : SybaseChar, + 'decimal' : SybaseNumeric, + 'numeric' : SybaseNumeric, + 'float' : SybaseFloat, + 'double' : SybaseFloat, + 'binary' : SybaseBinary, + 'long binary' : SybaseBinary, + 'varbinary' : SybaseBinary, + 'bit': SybaseBoolean, + 'image' : SybaseBinary, + 'timestamp': SybaseTimeStamp, + 'money': SybaseMoney, + 'smallmoney': SybaseSmallMoney, + 'uniqueidentifier': SybaseUniqueIdentifier, + + 'java.lang.Object' : SybaseTypeError, + 'java serialization' : SybaseTypeError, + } + + # Sybase backend peculiarities + supports_unicode_statements = False + supports_sane_rowcount = False + supports_sane_multi_rowcount = False + + def __new__(cls, dbapi=None, *args, **kwargs): + if cls != SybaseSQLDialect: + return super(SybaseSQLDialect, cls).__new__(cls, *args, **kwargs) + if dbapi: + print dbapi.__name__ + dialect = dialect_mapping.get(dbapi.__name__) + return dialect(*args, **kwargs) + else: + return object.__new__(cls, *args, **kwargs) + + def __init__(self, **params): + super(SybaseSQLDialect, self).__init__(**params) + self.text_as_varchar = False + # FIXME: what is the default schema for sybase connections (DBA?) ? + self.set_default_schema_name("dba") + + def dbapi(cls, module_name=None): + if module_name: + try: + dialect_cls = dialect_mapping[module_name] + return dialect_cls.import_dbapi() + except KeyError: + raise exceptions.InvalidRequestError("Unsupported SybaseSQL module '%s' requested (must be " + " or ".join([x for x in dialect_mapping.keys()]) + ")" % module_name) + else: + for dialect_cls in dialect_mapping.values(): + try: + return dialect_cls.import_dbapi() + except ImportError, e: + pass + else: + raise ImportError('No DBAPI module detected for SybaseSQL - please install mxodbc') + dbapi = classmethod(dbapi) + + def create_execution_context(self, *args, **kwargs): + return SybaseSQLExecutionContext(self, *args, **kwargs) + + def type_descriptor(self, typeobj): + newobj = sqltypes.adapt_type(typeobj, self.colspecs) + return newobj + + def last_inserted_ids(self): + return self.context.last_inserted_ids + + def get_default_schema_name(self, connection): + return self.schema_name + + def set_default_schema_name(self, schema_name): + self.schema_name = schema_name + + def do_execute(self, cursor, statement, params, **kwargs): + params = tuple(params) + super(SybaseSQLDialect, self).do_execute(cursor, statement, params, **kwargs) + + # FIXME: remove ? + def _execute(self, c, statement, parameters): + try: + if parameters == {}: + parameters = () + c.execute(statement, parameters) + self.context.rowcount = c.rowcount + c.DBPROP_COMMITPRESERVE = "Y" + except Exception, e: + raise exceptions.DBAPIError.instance(statement, parameters, e) + + def table_names(self, connection, schema): + """Ignore the schema and the charset for now.""" + s = sql.select([tables.c.table_name], + sql.not_(tables.c.table_name.like("SYS%")) and + tables.c.creator >= 100 + ) + rp = connection.execute(s) + return [row[0] for row in rp.fetchall()] + + def has_table(self, connection, tablename, schema=None): + # FIXME: ignore schemas for sybase + s = sql.select([tables.c.table_name], tables.c.table_name == tablename) + + c = connection.execute(s) + row = c.fetchone() + print "has_table: " + tablename + ": " + str(bool(row is not None)) + return row is not None + + def reflecttable(self, connection, table, include_columns): + # Get base columns + if table.schema is not None: + current_schema = table.schema + else: + current_schema = self.get_default_schema_name(connection) + + s = sql.select([columns, domains], tables.c.table_name==table.name, from_obj=[columns.join(tables).join(domains)], order_by=[columns.c.column_id]) + + c = connection.execute(s) + found_table = False + # makes sure we append the columns in the correct order + while True: + row = c.fetchone() + if row is None: + break + found_table = True + (name, type, nullable, charlen, numericprec, numericscale, default, primary_key, max_identity, table_id, column_id) = ( + row[columns.c.column_name], + row[domains.c.domain_name], + row[columns.c.nulls] == 'Y', + row[columns.c.width], + row[domains.c.precision], + row[columns.c.scale], + row[columns.c.default], + row[columns.c.pkey] == 'Y', + row[columns.c.max_identity], + row[tables.c.table_id], + row[columns.c.column_id], + ) + if include_columns and name not in include_columns: + continue + + # FIXME: else problems with SybaseBinary(size) + if numericscale == 0: + numericscale = None + + args = [] + for a in (charlen, numericprec, numericscale): + if a is not None: + args.append(a) + coltype = self.ischema_names.get(type, None) + if coltype == SybaseString and charlen == -1: + coltype = SybaseText() + else: + if coltype is None: + warnings.warn(RuntimeWarning("Did not recognize type '%s' of column '%s'" % (type, name))) + coltype = sqltypes.NULLTYPE + coltype = coltype(*args) + colargs= [] + if default is not None: + colargs.append(schema.PassiveDefault(sql.text(default))) + + # any sequences ? + col = schema.Column(name, coltype, nullable=nullable, primary_key=primary_key, *colargs) + if int(max_identity) > 0: + col.sequence = schema.Sequence(name + '_identity') + col.sequence.start = int(max_identity) + col.sequence.increment = 1 + + # append the column + table.append_column(col) + + # any foreign key constraint for this table ? + # note: no multi-column foreign keys are considered + s = "select st1.table_name, sc1.column_name, st2.table_name, sc2.column_name from systable as st1 join sysfkcol on st1.table_id=sysfkcol.foreign_table_id join sysforeignkey join systable as st2 on sysforeignkey.primary_table_id = st2.table_id join syscolumn as sc1 on sysfkcol.foreign_column_id=sc1.column_id and sc1.table_id=st1.table_id join syscolumn as sc2 on sysfkcol.primary_column_id=sc2.column_id and sc2.table_id=st2.table_id where st1.table_name='%(table_name)s';" % { 'table_name' : table.name } + c = connection.execute(s) + foreignKeys = {} + while True: + row = c.fetchone() + if row is None: + break + (foreign_table, foreign_column, primary_table, primary_column) = ( + row[0], row[1], row[2], row[3], + ) + if not primary_table in foreignKeys.keys(): + foreignKeys[primary_table] = [['%s'%(foreign_column)], ['%s.%s'%(primary_table,primary_column)]] + else: + foreignKeys[primary_table][0].append('%s'%(foreign_column)) + foreignKeys[primary_table][1].append('%s.%s'%(primary_table,primary_column)) + for primary_table in foreignKeys.keys(): + #table.append_constraint(schema.ForeignKeyConstraint(['%s.%s'%(foreign_table, foreign_column)], ['%s.%s'%(primary_table,primary_column)])) + table.append_constraint(schema.ForeignKeyConstraint(foreignKeys[primary_table][0], foreignKeys[primary_table][1])) + + if not found_table: + raise exceptions.NoSuchTableError(table.name) + + def _get_ischema(self): + if self._ischema is None: + # ??? didnt see an ISchema class in the 'sybase_information_schema' module + self._ischema = ISchema(self) + return self._ischema + ischema = property(_get_ischema, doc="""returns an ISchema object for this engine, which allows access to information_schema tables (if supported)""") + +class SybaseSQLDialect_mxodbc(SybaseSQLDialect): + def __init__(self, **params): + super(SybaseSQLDialect_mxodbc, self).__init__(**params) + + def dbapi_type_map(self): + return {'getdate' : SybaseDate_mxodbc()} + + def import_dbapi(cls): + #import mx.ODBC.Windows as module + import mxODBC as module + return module + import_dbapi = classmethod(import_dbapi) + + colspecs = SybaseSQLDialect.colspecs.copy() + colspecs[sqltypes.Time] = SybaseTime_mxodbc + colspecs[sqltypes.Date] = SybaseDate_mxodbc + colspecs[sqltypes.DateTime] = SybaseDateTime_mxodbc + + ischema_names = SybaseSQLDialect.ischema_names.copy() + ischema_names['time'] = SybaseTime_mxodbc + ischema_names['date'] = SybaseDate_mxodbc + ischema_names['datetime'] = SybaseDateTime_mxodbc + ischema_names['smalldatetime'] = SybaseDateTime_mxodbc + + def is_disconnect(self, e): + # FIXME: optimize + #return isinstance(e, self.dbapi.Error) and '[08S01]' in str(e) + #return True + return False + + def create_execution_context(self, *args, **kwargs): + return SybaseSQLExecutionContext_mxodbc(self, *args, **kwargs) + + def do_execute(self, cursor, statement, parameters, context=None, **kwargs): + super(SybaseSQLDialect_mxodbc, self).do_execute(cursor, statement, parameters, context=context, **kwargs) + + def create_connect_args(self, url): + '''Return a tuple of *args,**kwargs''' + # FIXME: handle mx.odbc.Windows proprietary args + opts = url.translate_connect_args(username='user') + opts.update(url.query) + argsDict = {} + argsDict['user'] = opts['user'] + argsDict['password'] = opts['password'] + connArgs = [[opts['dsn']], argsDict] + logging.info("Creating connection args: " + repr(connArgs)) + return connArgs + +class SybaseSQLDialect_pyodbc(SybaseSQLDialect): + def __init__(self, **params): + super(SybaseSQLDialect_pyodbc, self).__init__(**params) + + def dbapi_type_map(self): + return {'getdate' : SybaseDate_pyodbc()} + + def import_dbapi(cls): + import mypyodbc as module + return module + import_dbapi = classmethod(import_dbapi) + + colspecs = SybaseSQLDialect.colspecs.copy() + colspecs[sqltypes.Time] = SybaseTime_pyodbc + colspecs[sqltypes.Date] = SybaseDate_pyodbc + colspecs[sqltypes.DateTime] = SybaseDateTime_pyodbc + + ischema_names = SybaseSQLDialect.ischema_names.copy() + ischema_names['time'] = SybaseTime_pyodbc + ischema_names['date'] = SybaseDate_pyodbc + ischema_names['datetime'] = SybaseDateTime_pyodbc + ischema_names['smalldatetime'] = SybaseDateTime_pyodbc + + def is_disconnect(self, e): + # FIXME: optimize + #return isinstance(e, self.dbapi.Error) and '[08S01]' in str(e) + #return True + return False + + def create_execution_context(self, *args, **kwargs): + return SybaseSQLExecutionContext_pyodbc(self, *args, **kwargs) + + def do_execute(self, cursor, statement, parameters, context=None, **kwargs): + super(SybaseSQLDialect_pyodbc, self).do_execute(cursor, statement, parameters, context=context, **kwargs) + + def create_connect_args(self, url): + '''Return a tuple of *args,**kwargs''' + # FIXME: handle pyodbc proprietary args + opts = url.translate_connect_args(username='user') + opts.update(url.query) + + self.autocommit = False + if 'autocommit' in opts: + self.autocommit = bool(int(opts.pop('autocommit'))) + + argsDict = {} + argsDict['UID'] = opts['user'] + argsDict['PWD'] = opts['password'] + argsDict['DSN'] = opts['dsn'] + connArgs = [[';'.join(["%s=%s"%(key, argsDict[key]) for key in argsDict])], {'autocommit' : self.autocommit}] + logging.info("Creating connection args: " + repr(connArgs)) + return connArgs + +dialect_mapping = { + 'sqlalchemy.databases.mxODBC' : SybaseSQLDialect_mxodbc, +# 'pyodbc' : SybaseSQLDialect_pyodbc, + } + +class SybaseSQLCompiler(compiler.DefaultCompiler): + + operators = compiler.DefaultCompiler.operators.copy() + operators.update({ + sql_operators.mod: lambda x, y: "MOD(%s, %s)" % (x, y), + }) + + def bindparam_string(self, name): + res = super(SybaseSQLCompiler, self).bindparam_string(name) + if name.lower().startswith('literal'): + res = 'STRING(%s)'%res + return res + + def get_select_precolumns(self, select): + s = select._distinct and "DISTINCT " or "" + if select._limit: + #if select._limit == 1: + #s += "FIRST " + #else: + #s += "TOP %s " % (select._limit,) + s += "TOP %s " % (select._limit,) + if select._offset: + if not select._limit: + # FIXME: sybase doesn't allow an offset without a limit + # so use a huge value for TOP here + s += "TOP 1000000 " + s += "START AT %s " % (select._offset+1,) + return s + + def limit_clause(self, select): + # Limit in sybase is after the select keyword + return "" + + def visit_binary(self, binary): + """Move bind parameters to the right-hand side of an operator, where possible.""" + if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq: + return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator)) + else: + return super(SybaseSQLCompiler, self).visit_binary(binary) + + def label_select_column(self, select, column): + if isinstance(column, expression._Function): + return column.label(column.name + "_" + hex(random.randint(0, 65535))[2:]) + else: + return super(SybaseSQLCompiler, self).label_select_column(select, column) + + function_rewrites = {'current_date': 'getdate', + } + def visit_function(self, func): + func.name = self.function_rewrites.get(func.name, func.name) + res = super(SybaseSQLCompiler, self).visit_function(func) + if func.name.lower() == 'getdate': + # apply CAST operator + # FIXME: what about _pyodbc ? + cast = expression._Cast(func, SybaseDate_mxodbc) + # infinite recursion + # res = self.visit_cast(cast) + if self.stack and self.stack[-1].get('select'): + # not sure if we want to set the typemap here... + self.typemap.setdefault("CAST", cast.type) +# res = "CAST(%s AS %s)" % (self.process(cast.clause), self.process(cast.typeclause)) + res = "CAST(%s AS %s)" % (res, self.process(cast.typeclause)) +# elif func.name.lower() == 'count': +# res = 'count(*)' + return res + + def for_update_clause(self, select): + # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use + return '' + + def order_by_clause(self, select): + order_by = self.process(select._order_by_clause) + + # SybaseSQL only allows ORDER BY in subqueries if there is a LIMIT + if order_by and (not self.is_subquery(select) or select._limit): + return " ORDER BY " + order_by + else: + return "" + +class SybaseSQLSchemaGenerator(compiler.SchemaGenerator): + def get_column_specification(self, column, **kwargs): + + colspec = self.preparer.format_column(column) + + if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \ + column.autoincrement and isinstance(column.type, sqltypes.Integer): + if column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional): + column.sequence = schema.Sequence(column.name + '_seq') + + if hasattr(column, 'sequence'): + column.table.has_sequence = column + #colspec += " numeric(30,0) IDENTITY" + colspec += " Integer IDENTITY" + else: + colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec() + + if not column.nullable: + colspec += " NOT NULL" + + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + + return colspec + +class SybaseSQLSchemaDropper(compiler.SchemaDropper): + def visit_index(self, index): + self.append("\nDROP INDEX %s.%s" % ( + self.preparer.quote_identifier(index.table.name), + self.preparer.quote_identifier(index.name) + )) + self.execute() + +class SybaseSQLDefaultRunner(base.DefaultRunner): + pass + +class SybaseSQLIdentifierPreparer(compiler.IdentifierPreparer): + + reserved_words = RESERVED_WORDS + + def __init__(self, dialect): + super(SybaseSQLIdentifierPreparer, self).__init__(dialect) + + def _escape_identifier(self, value): + #TODO: determin SybaseSQL's escapeing rules + return value + + def _fold_identifier_case(self, value): + #TODO: determin SybaseSQL's case folding rules + return value + +dialect = SybaseSQLDialect +dialect.statement_compiler = SybaseSQLCompiler +dialect.schemagenerator = SybaseSQLSchemaGenerator +dialect.schemadropper = SybaseSQLSchemaDropper +dialect.preparer = SybaseSQLIdentifierPreparer +dialect.defaultrunner = SybaseSQLDefaultRunner diff --git a/test/engine/reflection.py b/test/engine/reflection.py index 8ce3c51a10..d4ada94e4c 100644 --- a/test/engine/reflection.py +++ b/test/engine/reflection.py @@ -688,6 +688,7 @@ class CreateDropTest(PersistTest): metadata.drop_all(bind=testbase.db) class UnicodeTest(PersistTest): + @testing.unsupported('sybase') def test_basic(self): try: # the 'convert_unicode' should not get in the way of the reflection diff --git a/test/orm/assorted_eager.py b/test/orm/assorted_eager.py index c98ffeb962..353560826e 100644 --- a/test/orm/assorted_eager.py +++ b/test/orm/assorted_eager.py @@ -136,6 +136,7 @@ class EagerTest(AssertMixin): print result assert result == [u'1 Some Category', u'3 Some Category'] + @testing.unsupported('sybase') def test_withoutouterjoin_literal(self): s = create_session() q=s.query(Test).options(eagerload('category')) diff --git a/test/sql/query.py b/test/sql/query.py index 4006d56171..5a5965f3d4 100644 --- a/test/sql/query.py +++ b/test/sql/query.py @@ -674,7 +674,7 @@ class CompoundTest(PersistTest): found2 = self._fetchall_sorted(e.alias('foo').select().execute()) self.assertEquals(found2, wanted) - @testing.unsupported('mysql') + @testing.unsupported('mysql', 'sybase') def test_intersect(self): i = intersect( select([t2.c.col3, t2.c.col4]), @@ -689,7 +689,7 @@ class CompoundTest(PersistTest): found2 = self._fetchall_sorted(i.alias('bar').select().execute()) self.assertEquals(found2, wanted) - @testing.unsupported('mysql', 'oracle') + @testing.unsupported('mysql', 'oracle', 'sybase') def test_except_style1(self): e = except_(union( select([t1.c.col3, t1.c.col4]), @@ -703,7 +703,7 @@ class CompoundTest(PersistTest): found = self._fetchall_sorted(e.alias('bar').select().execute()) self.assertEquals(found, wanted) - @testing.unsupported('mysql', 'oracle') + @testing.unsupported('mysql', 'oracle', 'sybase') def test_except_style2(self): e = except_(union( select([t1.c.col3, t1.c.col4]), @@ -720,7 +720,7 @@ class CompoundTest(PersistTest): found2 = self._fetchall_sorted(e.alias('bar').select().execute()) self.assertEquals(found2, wanted) - @testing.unsupported('sqlite', 'mysql', 'oracle') + @testing.unsupported('sqlite', 'mysql', 'oracle', 'sybase') def test_except_style3(self): # aaa, bbb, ccc - (aaa, bbb, ccc - (ccc)) = ccc e = except_( diff --git a/test/sql/unicode.py b/test/sql/unicode.py index 8174ab8b68..1b76983011 100644 --- a/test/sql/unicode.py +++ b/test/sql/unicode.py @@ -8,7 +8,7 @@ from testlib.engines import utf8_engine class UnicodeSchemaTest(PersistTest): - @testing.unsupported('oracle') + @testing.unsupported('oracle', 'sybase') def setUpAll(self): global unicode_bind, metadata, t1, t2, t3 @@ -46,20 +46,20 @@ class UnicodeSchemaTest(PersistTest): ) metadata.create_all() - @testing.unsupported('oracle') + @testing.unsupported('oracle', 'sybase') def tearDown(self): if metadata.tables: t3.delete().execute() t2.delete().execute() t1.delete().execute() - @testing.unsupported('oracle') + @testing.unsupported('oracle', 'sybase') def tearDownAll(self): global unicode_bind metadata.drop_all() del unicode_bind - @testing.unsupported('oracle') + @testing.unsupported('oracle', 'sybase') def test_insert(self): t1.insert().execute({u'méil':1, u'\u6e2c\u8a66':5}) t2.insert().execute({'a':1, 'b':1}) @@ -72,7 +72,7 @@ class UnicodeSchemaTest(PersistTest): assert t2.select().execute().fetchall() == [(1, 1)] assert t3.select().execute().fetchall() == [(1, 5, 1, 1)] - @testing.unsupported('oracle') + @testing.unsupported('oracle', 'sybase') def test_reflect(self): t1.insert().execute({u'méil':2, u'\u6e2c\u8a66':7}) t2.insert().execute({'a':2, 'b':2}) -- 2.47.3