- structured maxdb, sybase, informix dialects. obviously no testing has been done.
FLOAT,
Float,
INT,
+ INTEGER,
Integer,
Interval,
NCHAR,
+ NVARCHAR,
NUMERIC,
Numeric,
PickleType,
--- /dev/null
+from sqlalchemy.connectors import Connector
+
+class MxODBCConnector(Connector):
+ driver='mxodbc'
+ supports_sane_rowcount = False
+ supports_sane_multi_rowcount = False
+ supports_unicode_statements = False
+ supports_unicode_binds = False
+
+ @classmethod
+ def import_dbapi(cls):
+ import mxODBC as module
+ return module
+
+ def create_connect_args(self, url):
+ '''Return a tuple of *args,**kwargs'''
+ # FIXME: handle mx.odbc.Windows proprietary args
+ opts = url.translate_connect_args(username='user')
+ opts.update(url.query)
+ argsDict = {}
+ argsDict['user'] = opts['user']
+ argsDict['password'] = opts['password']
+ connArgs = [[opts['dsn']], argsDict]
+ return connArgs
from sqlalchemy.dialects.sqlite import base as sqlite
from sqlalchemy.dialects.postgres import base as postgres
from sqlalchemy.dialects.mysql import base as mysql
+from sqlalchemy.dialects.oracle import base as oracle
+from sqlalchemy.dialects.firebird import base as firebird
+from sqlalchemy.dialects.maxdb import base as maxdb
+from sqlalchemy.dialects.informix import base as informix
+from sqlalchemy.dialects.mssql import base as mssql
+from sqlalchemy.dialects.access import base as access
+from sqlalchemy.dialects.sybase import base as sybase
+
+
__all__ = (
+++ /dev/null
-# mxODBC.py
-# Copyright (C) 2007 Fisch Asset Management AG http://www.fam.ch
-# Coding: Alexander Houben alexander.houben@thor-solutions.ch
-#
-# This module is part of SQLAlchemy and is released under
-# the MIT License: http://www.opensource.org/licenses/mit-license.php
-
-"""
-A wrapper for a mx.ODBC.Windows DB-API connection.
-
-Makes sure the mx module is configured to return datetime objects instead
-of mx.DateTime.DateTime objects.
-"""
-
-from mx.ODBC.Windows import *
-
-
-class Cursor:
- def __init__(self, cursor):
- self.cursor = cursor
-
- def __getattr__(self, attr):
- res = getattr(self.cursor, attr)
- return res
-
- def execute(self, *args, **kwargs):
- res = self.cursor.execute(*args, **kwargs)
- return res
-
-
-class Connection:
- def myErrorHandler(self, connection, cursor, errorclass, errorvalue):
- err0, err1, err2, err3 = errorvalue
- #print ", ".join(["Err%d: %s"%(x, errorvalue[x]) for x in range(4)])
- if int(err1) == 109:
- # Ignore "Null value eliminated in aggregate function", this is not an error
- return
- raise errorclass, errorvalue
-
- def __init__(self, conn):
- self.conn = conn
- # install a mx ODBC error handler
- self.conn.errorhandler = self.myErrorHandler
-
- def __getattr__(self, attr):
- res = getattr(self.conn, attr)
- return res
-
- def cursor(self, *args, **kwargs):
- res = Cursor(self.conn.cursor(*args, **kwargs))
- return res
-
-
-# override 'connect' call
-def connect(*args, **kwargs):
- import mx.ODBC.Windows
- conn = mx.ODBC.Windows.Connect(*args, **kwargs)
- conn.datetimeformat = mx.ODBC.Windows.PYDATETIME_DATETIMEFORMAT
- return Connection(conn)
-Connect = connect
+++ /dev/null
-# sybase.py
-# Copyright (C) 2007 Fisch Asset Management AG http://www.fam.ch
-# Coding: Alexander Houben alexander.houben@thor-solutions.ch
-#
-# This module is part of SQLAlchemy and is released under
-# the MIT License: http://www.opensource.org/licenses/mit-license.php
-
-"""
-Sybase database backend.
-
-Known issues / TODO:
-
- * Uses the mx.ODBC driver from egenix (version 2.1.0)
- * The current version of sqlalchemy.databases.sybase only supports
- mx.ODBC.Windows (other platforms such as mx.ODBC.unixODBC still need
- some development)
- * Support for pyodbc has been built in but is not yet complete (needs
- further development)
- * Results of running tests/alltests.py:
- Ran 934 tests in 287.032s
- FAILED (failures=3, errors=1)
- * Tested on 'Adaptive Server Anywhere 9' (version 9.0.1.1751)
-"""
-
-import datetime, operator
-
-from sqlalchemy import util, sql, schema, exc
-from sqlalchemy.sql import compiler, expression
-from sqlalchemy.engine import default, base
-from sqlalchemy import types as sqltypes
-from sqlalchemy.sql import operators as sql_operators
-from sqlalchemy import MetaData, Table, Column
-from sqlalchemy import String, Integer, SMALLINT, CHAR, ForeignKey
-
-
-__all__ = [
- 'SybaseTypeError'
- 'SybaseNumeric', 'SybaseFloat', 'SybaseInteger', 'SybaseBigInteger',
- 'SybaseTinyInteger', 'SybaseSmallInteger',
- 'SybaseDateTime_mxodbc', 'SybaseDateTime_pyodbc',
- 'SybaseDate_mxodbc', 'SybaseDate_pyodbc',
- 'SybaseTime_mxodbc', 'SybaseTime_pyodbc',
- 'SybaseText', 'SybaseString', 'SybaseChar', 'SybaseBinary',
- 'SybaseBoolean', 'SybaseTimeStamp', 'SybaseMoney', 'SybaseSmallMoney',
- 'SybaseUniqueIdentifier',
- ]
-
-
-RESERVED_WORDS = set([
- "add", "all", "alter", "and",
- "any", "as", "asc", "backup",
- "begin", "between", "bigint", "binary",
- "bit", "bottom", "break", "by",
- "call", "capability", "cascade", "case",
- "cast", "char", "char_convert", "character",
- "check", "checkpoint", "close", "comment",
- "commit", "connect", "constraint", "contains",
- "continue", "convert", "create", "cross",
- "cube", "current", "current_timestamp", "current_user",
- "cursor", "date", "dbspace", "deallocate",
- "dec", "decimal", "declare", "default",
- "delete", "deleting", "desc", "distinct",
- "do", "double", "drop", "dynamic",
- "else", "elseif", "encrypted", "end",
- "endif", "escape", "except", "exception",
- "exec", "execute", "existing", "exists",
- "externlogin", "fetch", "first", "float",
- "for", "force", "foreign", "forward",
- "from", "full", "goto", "grant",
- "group", "having", "holdlock", "identified",
- "if", "in", "index", "index_lparen",
- "inner", "inout", "insensitive", "insert",
- "inserting", "install", "instead", "int",
- "integer", "integrated", "intersect", "into",
- "iq", "is", "isolation", "join",
- "key", "lateral", "left", "like",
- "lock", "login", "long", "match",
- "membership", "message", "mode", "modify",
- "natural", "new", "no", "noholdlock",
- "not", "notify", "null", "numeric",
- "of", "off", "on", "open",
- "option", "options", "or", "order",
- "others", "out", "outer", "over",
- "passthrough", "precision", "prepare", "primary",
- "print", "privileges", "proc", "procedure",
- "publication", "raiserror", "readtext", "real",
- "reference", "references", "release", "remote",
- "remove", "rename", "reorganize", "resource",
- "restore", "restrict", "return", "revoke",
- "right", "rollback", "rollup", "save",
- "savepoint", "scroll", "select", "sensitive",
- "session", "set", "setuser", "share",
- "smallint", "some", "sqlcode", "sqlstate",
- "start", "stop", "subtrans", "subtransaction",
- "synchronize", "syntax_error", "table", "temporary",
- "then", "time", "timestamp", "tinyint",
- "to", "top", "tran", "trigger",
- "truncate", "tsequal", "unbounded", "union",
- "unique", "unknown", "unsigned", "update",
- "updating", "user", "using", "validate",
- "values", "varbinary", "varchar", "variable",
- "varying", "view", "wait", "waitfor",
- "when", "where", "while", "window",
- "with", "with_cube", "with_lparen", "with_rollup",
- "within", "work", "writetext",
- ])
-
-ischema = MetaData()
-
-tables = Table("SYSTABLE", ischema,
- Column("table_id", Integer, primary_key=True),
- Column("file_id", SMALLINT),
- Column("table_name", CHAR(128)),
- Column("table_type", CHAR(10)),
- Column("creator", Integer),
- #schema="information_schema"
- )
-
-domains = Table("SYSDOMAIN", ischema,
- Column("domain_id", Integer, primary_key=True),
- Column("domain_name", CHAR(128)),
- Column("type_id", SMALLINT),
- Column("precision", SMALLINT, quote=True),
- #schema="information_schema"
- )
-
-columns = Table("SYSCOLUMN", ischema,
- Column("column_id", Integer, primary_key=True),
- Column("table_id", Integer, ForeignKey(tables.c.table_id)),
- Column("pkey", CHAR(1)),
- Column("column_name", CHAR(128)),
- Column("nulls", CHAR(1)),
- Column("width", SMALLINT),
- Column("domain_id", SMALLINT, ForeignKey(domains.c.domain_id)),
- # FIXME: should be mx.BIGINT
- Column("max_identity", Integer),
- # FIXME: should be mx.ODBC.Windows.LONGVARCHAR
- Column("default", String),
- Column("scale", Integer),
- #schema="information_schema"
- )
-
-foreignkeys = Table("SYSFOREIGNKEY", ischema,
- Column("foreign_table_id", Integer, ForeignKey(tables.c.table_id), primary_key=True),
- Column("foreign_key_id", SMALLINT, primary_key=True),
- Column("primary_table_id", Integer, ForeignKey(tables.c.table_id)),
- #schema="information_schema"
- )
-fkcols = Table("SYSFKCOL", ischema,
- Column("foreign_table_id", Integer, ForeignKey(columns.c.table_id), primary_key=True),
- Column("foreign_key_id", SMALLINT, ForeignKey(foreignkeys.c.foreign_key_id), primary_key=True),
- Column("foreign_column_id", Integer, ForeignKey(columns.c.column_id), primary_key=True),
- Column("primary_column_id", Integer),
- #schema="information_schema"
- )
-
-class SybaseTypeError(sqltypes.TypeEngine):
- def result_processor(self, dialect):
- return None
-
- def bind_processor(self, dialect):
- def process(value):
- raise exc.InvalidRequestError("Data type not supported", [value])
- return process
-
- def get_col_spec(self):
- raise exc.CompileError("Data type not supported")
-
-class SybaseNumeric(sqltypes.Numeric):
- def get_col_spec(self):
- if self.scale is None:
- if self.precision is None:
- return "NUMERIC"
- else:
- return "NUMERIC(%(precision)s)" % {'precision' : self.precision}
- else:
- return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale}
-
-class SybaseFloat(sqltypes.FLOAT, SybaseNumeric):
- def __init__(self, precision = 10, asdecimal = False, scale = 2, **kwargs):
- super(sqltypes.FLOAT, self).__init__(precision, asdecimal, **kwargs)
- self.scale = scale
-
- def get_col_spec(self):
- # if asdecimal is True, handle same way as SybaseNumeric
- if self.asdecimal:
- return SybaseNumeric.get_col_spec(self)
- if self.precision is None:
- return "FLOAT"
- else:
- return "FLOAT(%(precision)s)" % {'precision': self.precision}
-
- def result_processor(self, dialect):
- def process(value):
- if value is None:
- return None
- return float(value)
- if self.asdecimal:
- return SybaseNumeric.result_processor(self, dialect)
- return process
-
-class SybaseInteger(sqltypes.Integer):
- def get_col_spec(self):
- return "INTEGER"
-
-class SybaseBigInteger(SybaseInteger):
- def get_col_spec(self):
- return "BIGINT"
-
-class SybaseTinyInteger(SybaseInteger):
- def get_col_spec(self):
- return "TINYINT"
-
-class SybaseSmallInteger(SybaseInteger):
- def get_col_spec(self):
- return "SMALLINT"
-
-class SybaseDateTime_mxodbc(sqltypes.DateTime):
- def __init__(self, *a, **kw):
- super(SybaseDateTime_mxodbc, self).__init__(False)
-
- def get_col_spec(self):
- return "DATETIME"
-
-class SybaseDateTime_pyodbc(sqltypes.DateTime):
- def __init__(self, *a, **kw):
- super(SybaseDateTime_pyodbc, self).__init__(False)
-
- def get_col_spec(self):
- return "DATETIME"
-
- def result_processor(self, dialect):
- def process(value):
- if value is None:
- return None
- # Convert the datetime.datetime back to datetime.time
- return value
- return process
-
- def bind_processor(self, dialect):
- def process(value):
- if value is None:
- return None
- return value
- return process
-
-class SybaseDate_mxodbc(sqltypes.Date):
- def __init__(self, *a, **kw):
- super(SybaseDate_mxodbc, self).__init__(False)
-
- def get_col_spec(self):
- return "DATE"
-
-class SybaseDate_pyodbc(sqltypes.Date):
- def __init__(self, *a, **kw):
- super(SybaseDate_pyodbc, self).__init__(False)
-
- def get_col_spec(self):
- return "DATE"
-
-class SybaseTime_mxodbc(sqltypes.Time):
- def __init__(self, *a, **kw):
- super(SybaseTime_mxodbc, self).__init__(False)
-
- def get_col_spec(self):
- return "DATETIME"
-
- def result_processor(self, dialect):
- def process(value):
- if value is None:
- return None
- # Convert the datetime.datetime back to datetime.time
- return datetime.time(value.hour, value.minute, value.second, value.microsecond)
- return process
-
-class SybaseTime_pyodbc(sqltypes.Time):
- def __init__(self, *a, **kw):
- super(SybaseTime_pyodbc, self).__init__(False)
-
- def get_col_spec(self):
- return "DATETIME"
-
- def result_processor(self, dialect):
- def process(value):
- if value is None:
- return None
- # Convert the datetime.datetime back to datetime.time
- return datetime.time(value.hour, value.minute, value.second, value.microsecond)
- return process
-
- def bind_processor(self, dialect):
- def process(value):
- if value is None:
- return None
- return datetime.datetime(1970, 1, 1, value.hour, value.minute, value.second, value.microsecond)
- return process
-
-class SybaseText(sqltypes.Text):
- def get_col_spec(self):
- return "TEXT"
-
-class SybaseString(sqltypes.String):
- def get_col_spec(self):
- return "VARCHAR(%(length)s)" % {'length' : self.length}
-
-class SybaseChar(sqltypes.CHAR):
- def get_col_spec(self):
- return "CHAR(%(length)s)" % {'length' : self.length}
-
-class SybaseBinary(sqltypes.Binary):
- def get_col_spec(self):
- return "IMAGE"
-
-class SybaseBoolean(sqltypes.Boolean):
- def get_col_spec(self):
- return "BIT"
-
- def result_processor(self, dialect):
- def process(value):
- if value is None:
- return None
- return value and True or False
- return process
-
- def bind_processor(self, dialect):
- def process(value):
- if value is True:
- return 1
- elif value is False:
- return 0
- elif value is None:
- return None
- else:
- return value and True or False
- return process
-
-class SybaseTimeStamp(sqltypes.TIMESTAMP):
- def get_col_spec(self):
- return "TIMESTAMP"
-
-class SybaseMoney(sqltypes.TypeEngine):
- def get_col_spec(self):
- return "MONEY"
-
-class SybaseSmallMoney(SybaseMoney):
- def get_col_spec(self):
- return "SMALLMONEY"
-
-class SybaseUniqueIdentifier(sqltypes.TypeEngine):
- def get_col_spec(self):
- return "UNIQUEIDENTIFIER"
-
-class SybaseSQLExecutionContext(default.DefaultExecutionContext):
- pass
-
-class SybaseSQLExecutionContext_mxodbc(SybaseSQLExecutionContext):
-
- def __init__(self, dialect, connection, compiled=None, statement=None, parameters=None):
- super(SybaseSQLExecutionContext_mxodbc, self).__init__(dialect, connection, compiled, statement, parameters)
-
- def pre_exec(self):
- super(SybaseSQLExecutionContext_mxodbc, self).pre_exec()
-
- def post_exec(self):
- if self.compiled.isinsert:
- table = self.compiled.statement.table
- # get the inserted values of the primary key
-
- # get any sequence IDs first (using @@identity)
- self.cursor.execute("SELECT @@identity AS lastrowid")
- row = self.cursor.fetchone()
- lastrowid = int(row[0])
- if lastrowid > 0:
- # an IDENTITY was inserted, fetch it
- # FIXME: always insert in front ? This only works if the IDENTITY is the first column, no ?!
- if not hasattr(self, '_last_inserted_ids') or self._last_inserted_ids is None:
- self._last_inserted_ids = [lastrowid]
- else:
- self._last_inserted_ids = [lastrowid] + self._last_inserted_ids[1:]
- super(SybaseSQLExecutionContext_mxodbc, self).post_exec()
-
-class SybaseSQLExecutionContext_pyodbc(SybaseSQLExecutionContext):
- def __init__(self, dialect, connection, compiled=None, statement=None, parameters=None):
- super(SybaseSQLExecutionContext_pyodbc, self).__init__(dialect, connection, compiled, statement, parameters)
-
- def pre_exec(self):
- super(SybaseSQLExecutionContext_pyodbc, self).pre_exec()
-
- def post_exec(self):
- if self.compiled.isinsert:
- table = self.compiled.statement.table
- # get the inserted values of the primary key
-
- # get any sequence IDs first (using @@identity)
- self.cursor.execute("SELECT @@identity AS lastrowid")
- row = self.cursor.fetchone()
- lastrowid = int(row[0])
- if lastrowid > 0:
- # an IDENTITY was inserted, fetch it
- # FIXME: always insert in front ? This only works if the IDENTITY is the first column, no ?!
- if not hasattr(self, '_last_inserted_ids') or self._last_inserted_ids is None:
- self._last_inserted_ids = [lastrowid]
- else:
- self._last_inserted_ids = [lastrowid] + self._last_inserted_ids[1:]
- super(SybaseSQLExecutionContext_pyodbc, self).post_exec()
-
-class SybaseSQLDialect(default.DefaultDialect):
- colspecs = {
- # FIXME: unicode support
- #sqltypes.Unicode : SybaseUnicode,
- sqltypes.Integer : SybaseInteger,
- sqltypes.SmallInteger : SybaseSmallInteger,
- sqltypes.Numeric : SybaseNumeric,
- sqltypes.Float : SybaseFloat,
- sqltypes.String : SybaseString,
- sqltypes.Binary : SybaseBinary,
- sqltypes.Boolean : SybaseBoolean,
- sqltypes.Text : SybaseText,
- sqltypes.CHAR : SybaseChar,
- sqltypes.TIMESTAMP : SybaseTimeStamp,
- sqltypes.FLOAT : SybaseFloat,
- }
-
- ischema_names = {
- 'integer' : SybaseInteger,
- 'unsigned int' : SybaseInteger,
- 'unsigned smallint' : SybaseInteger,
- 'unsigned bigint' : SybaseInteger,
- 'bigint': SybaseBigInteger,
- 'smallint' : SybaseSmallInteger,
- 'tinyint' : SybaseTinyInteger,
- 'varchar' : SybaseString,
- 'long varchar' : SybaseText,
- 'char' : SybaseChar,
- 'decimal' : SybaseNumeric,
- 'numeric' : SybaseNumeric,
- 'float' : SybaseFloat,
- 'double' : SybaseFloat,
- 'binary' : SybaseBinary,
- 'long binary' : SybaseBinary,
- 'varbinary' : SybaseBinary,
- 'bit': SybaseBoolean,
- 'image' : SybaseBinary,
- 'timestamp': SybaseTimeStamp,
- 'money': SybaseMoney,
- 'smallmoney': SybaseSmallMoney,
- 'uniqueidentifier': SybaseUniqueIdentifier,
-
- 'java.lang.Object' : SybaseTypeError,
- 'java serialization' : SybaseTypeError,
- }
-
- name = 'sybase'
- # Sybase backend peculiarities
- supports_unicode_statements = False
- supports_sane_rowcount = False
- supports_sane_multi_rowcount = False
- execution_ctx_cls = SybaseSQLExecutionContext
-
- def __new__(cls, dbapi=None, *args, **kwargs):
- if cls != SybaseSQLDialect:
- return super(SybaseSQLDialect, cls).__new__(cls, *args, **kwargs)
- if dbapi:
- print dbapi.__name__
- dialect = dialect_mapping.get(dbapi.__name__)
- return dialect(*args, **kwargs)
- else:
- return object.__new__(cls, *args, **kwargs)
-
- def __init__(self, **params):
- super(SybaseSQLDialect, self).__init__(**params)
- self.text_as_varchar = False
- # FIXME: what is the default schema for sybase connections (DBA?) ?
- self.set_default_schema_name("dba")
-
- def dbapi(cls, module_name=None):
- if module_name:
- try:
- dialect_cls = dialect_mapping[module_name]
- return dialect_cls.import_dbapi()
- except KeyError:
- raise exc.InvalidRequestError("Unsupported SybaseSQL module '%s' requested (must be " + " or ".join([x for x in dialect_mapping.keys()]) + ")" % module_name)
- else:
- for dialect_cls in dialect_mapping.values():
- try:
- return dialect_cls.import_dbapi()
- except ImportError, e:
- pass
- else:
- raise ImportError('No DBAPI module detected for SybaseSQL - please install mxodbc')
- dbapi = classmethod(dbapi)
-
- def type_descriptor(self, typeobj):
- newobj = sqltypes.adapt_type(typeobj, self.colspecs)
- return newobj
-
- def last_inserted_ids(self):
- return self.context.last_inserted_ids
-
- def get_default_schema_name(self, connection):
- return self.schema_name
-
- def set_default_schema_name(self, schema_name):
- self.schema_name = schema_name
-
- def do_execute(self, cursor, statement, params, **kwargs):
- params = tuple(params)
- super(SybaseSQLDialect, self).do_execute(cursor, statement, params, **kwargs)
-
- # FIXME: remove ?
- def _execute(self, c, statement, parameters):
- try:
- if parameters == {}:
- parameters = ()
- c.execute(statement, parameters)
- self.context.rowcount = c.rowcount
- c.DBPROP_COMMITPRESERVE = "Y"
- except Exception, e:
- raise exc.DBAPIError.instance(statement, parameters, e)
-
- def table_names(self, connection, schema):
- """Ignore the schema and the charset for now."""
- s = sql.select([tables.c.table_name],
- sql.not_(tables.c.table_name.like("SYS%")) and
- tables.c.creator >= 100
- )
- rp = connection.execute(s)
- return [row[0] for row in rp.fetchall()]
-
- def has_table(self, connection, tablename, schema=None):
- # FIXME: ignore schemas for sybase
- s = sql.select([tables.c.table_name], tables.c.table_name == tablename)
-
- c = connection.execute(s)
- row = c.fetchone()
- print "has_table: " + tablename + ": " + str(bool(row is not None))
- return row is not None
-
- def reflecttable(self, connection, table, include_columns):
- # Get base columns
- if table.schema is not None:
- current_schema = table.schema
- else:
- current_schema = self.get_default_schema_name(connection)
-
- s = sql.select([columns, domains], tables.c.table_name==table.name, from_obj=[columns.join(tables).join(domains)], order_by=[columns.c.column_id])
-
- c = connection.execute(s)
- found_table = False
- # makes sure we append the columns in the correct order
- while True:
- row = c.fetchone()
- if row is None:
- break
- found_table = True
- (name, type, nullable, charlen, numericprec, numericscale, default, primary_key, max_identity, table_id, column_id) = (
- row[columns.c.column_name],
- row[domains.c.domain_name],
- row[columns.c.nulls] == 'Y',
- row[columns.c.width],
- row[domains.c.precision],
- row[columns.c.scale],
- row[columns.c.default],
- row[columns.c.pkey] == 'Y',
- row[columns.c.max_identity],
- row[tables.c.table_id],
- row[columns.c.column_id],
- )
- if include_columns and name not in include_columns:
- continue
-
- # FIXME: else problems with SybaseBinary(size)
- if numericscale == 0:
- numericscale = None
-
- args = []
- for a in (charlen, numericprec, numericscale):
- if a is not None:
- args.append(a)
- coltype = self.ischema_names.get(type, None)
- if coltype == SybaseString and charlen == -1:
- coltype = SybaseText()
- else:
- if coltype is None:
- util.warn("Did not recognize type '%s' of column '%s'" %
- (type, name))
- coltype = sqltypes.NULLTYPE
- coltype = coltype(*args)
- colargs = []
- if default is not None:
- colargs.append(schema.DefaultClause(sql.text(default)))
-
- # any sequences ?
- col = schema.Column(name, coltype, nullable=nullable, primary_key=primary_key, *colargs)
- if int(max_identity) > 0:
- col.sequence = schema.Sequence(name + '_identity')
- col.sequence.start = int(max_identity)
- col.sequence.increment = 1
-
- # append the column
- table.append_column(col)
-
- # any foreign key constraint for this table ?
- # note: no multi-column foreign keys are considered
- s = "select st1.table_name, sc1.column_name, st2.table_name, sc2.column_name from systable as st1 join sysfkcol on st1.table_id=sysfkcol.foreign_table_id join sysforeignkey join systable as st2 on sysforeignkey.primary_table_id = st2.table_id join syscolumn as sc1 on sysfkcol.foreign_column_id=sc1.column_id and sc1.table_id=st1.table_id join syscolumn as sc2 on sysfkcol.primary_column_id=sc2.column_id and sc2.table_id=st2.table_id where st1.table_name='%(table_name)s';" % { 'table_name' : table.name }
- c = connection.execute(s)
- foreignKeys = {}
- while True:
- row = c.fetchone()
- if row is None:
- break
- (foreign_table, foreign_column, primary_table, primary_column) = (
- row[0], row[1], row[2], row[3],
- )
- if not primary_table in foreignKeys.keys():
- foreignKeys[primary_table] = [['%s' % (foreign_column)], ['%s.%s'%(primary_table, primary_column)]]
- else:
- foreignKeys[primary_table][0].append('%s'%(foreign_column))
- foreignKeys[primary_table][1].append('%s.%s'%(primary_table, primary_column))
- for primary_table in foreignKeys.keys():
- #table.append_constraint(schema.ForeignKeyConstraint(['%s.%s'%(foreign_table, foreign_column)], ['%s.%s'%(primary_table,primary_column)]))
- table.append_constraint(schema.ForeignKeyConstraint(foreignKeys[primary_table][0], foreignKeys[primary_table][1], link_to_name=True))
-
- if not found_table:
- raise exc.NoSuchTableError(table.name)
-
-
-class SybaseSQLDialect_mxodbc(SybaseSQLDialect):
- execution_ctx_cls = SybaseSQLExecutionContext_mxodbc
-
- def __init__(self, **params):
- super(SybaseSQLDialect_mxodbc, self).__init__(**params)
-
- self.dbapi_type_map = {'getdate' : SybaseDate_mxodbc()}
-
- def import_dbapi(cls):
- #import mx.ODBC.Windows as module
- import mxODBC as module
- return module
- import_dbapi = classmethod(import_dbapi)
-
- colspecs = SybaseSQLDialect.colspecs.copy()
- colspecs[sqltypes.Time] = SybaseTime_mxodbc
- colspecs[sqltypes.Date] = SybaseDate_mxodbc
- colspecs[sqltypes.DateTime] = SybaseDateTime_mxodbc
-
- ischema_names = SybaseSQLDialect.ischema_names.copy()
- ischema_names['time'] = SybaseTime_mxodbc
- ischema_names['date'] = SybaseDate_mxodbc
- ischema_names['datetime'] = SybaseDateTime_mxodbc
- ischema_names['smalldatetime'] = SybaseDateTime_mxodbc
-
- def is_disconnect(self, e):
- # FIXME: optimize
- #return isinstance(e, self.dbapi.Error) and '[08S01]' in str(e)
- #return True
- return False
-
- def do_execute(self, cursor, statement, parameters, context=None, **kwargs):
- super(SybaseSQLDialect_mxodbc, self).do_execute(cursor, statement, parameters, context=context, **kwargs)
-
- def create_connect_args(self, url):
- '''Return a tuple of *args,**kwargs'''
- # FIXME: handle mx.odbc.Windows proprietary args
- opts = url.translate_connect_args(username='user')
- opts.update(url.query)
- argsDict = {}
- argsDict['user'] = opts['user']
- argsDict['password'] = opts['password']
- connArgs = [[opts['dsn']], argsDict]
- return connArgs
-
-
-class SybaseSQLDialect_pyodbc(SybaseSQLDialect):
- execution_ctx_cls = SybaseSQLExecutionContext_pyodbc
-
- def __init__(self, **params):
- super(SybaseSQLDialect_pyodbc, self).__init__(**params)
- self.dbapi_type_map = {'getdate' : SybaseDate_pyodbc()}
-
- def import_dbapi(cls):
- import mypyodbc as module
- return module
- import_dbapi = classmethod(import_dbapi)
-
- colspecs = SybaseSQLDialect.colspecs.copy()
- colspecs[sqltypes.Time] = SybaseTime_pyodbc
- colspecs[sqltypes.Date] = SybaseDate_pyodbc
- colspecs[sqltypes.DateTime] = SybaseDateTime_pyodbc
-
- ischema_names = SybaseSQLDialect.ischema_names.copy()
- ischema_names['time'] = SybaseTime_pyodbc
- ischema_names['date'] = SybaseDate_pyodbc
- ischema_names['datetime'] = SybaseDateTime_pyodbc
- ischema_names['smalldatetime'] = SybaseDateTime_pyodbc
-
- def is_disconnect(self, e):
- # FIXME: optimize
- #return isinstance(e, self.dbapi.Error) and '[08S01]' in str(e)
- #return True
- return False
-
- def do_execute(self, cursor, statement, parameters, context=None, **kwargs):
- super(SybaseSQLDialect_pyodbc, self).do_execute(cursor, statement, parameters, context=context, **kwargs)
-
- def create_connect_args(self, url):
- '''Return a tuple of *args,**kwargs'''
- # FIXME: handle pyodbc proprietary args
- opts = url.translate_connect_args(username='user')
- opts.update(url.query)
-
- self.autocommit = False
- if 'autocommit' in opts:
- self.autocommit = bool(int(opts.pop('autocommit')))
-
- argsDict = {}
- argsDict['UID'] = opts['user']
- argsDict['PWD'] = opts['password']
- argsDict['DSN'] = opts['dsn']
- connArgs = [[';'.join(["%s=%s"%(key, argsDict[key]) for key in argsDict])], {'autocommit' : self.autocommit}]
- return connArgs
-
-
-dialect_mapping = {
- 'sqlalchemy.databases.mxODBC' : SybaseSQLDialect_mxodbc,
-# 'pyodbc' : SybaseSQLDialect_pyodbc,
- }
-
-
-class SybaseSQLCompiler(compiler.SQLCompiler):
- operators = compiler.SQLCompiler.operators.copy()
- operators.update({
- sql_operators.mod: lambda x, y: "MOD(%s, %s)" % (x, y),
- })
-
- def bindparam_string(self, name):
- res = super(SybaseSQLCompiler, self).bindparam_string(name)
- if name.lower().startswith('literal'):
- res = 'STRING(%s)' % res
- return res
-
- def get_select_precolumns(self, select):
- s = select._distinct and "DISTINCT " or ""
- if select._limit:
- #if select._limit == 1:
- #s += "FIRST "
- #else:
- #s += "TOP %s " % (select._limit,)
- s += "TOP %s " % (select._limit,)
- if select._offset:
- if not select._limit:
- # FIXME: sybase doesn't allow an offset without a limit
- # so use a huge value for TOP here
- s += "TOP 1000000 "
- s += "START AT %s " % (select._offset+1,)
- return s
-
- def limit_clause(self, select):
- # Limit in sybase is after the select keyword
- return ""
-
- def visit_binary(self, binary):
- """Move bind parameters to the right-hand side of an operator, where possible."""
- if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq:
- return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator))
- else:
- return super(SybaseSQLCompiler, self).visit_binary(binary)
-
- def label_select_column(self, select, column, asfrom):
- if isinstance(column, expression.Function):
- return column.label(None)
- else:
- return super(SybaseSQLCompiler, self).label_select_column(select, column, asfrom)
-
- function_rewrites = {'current_date': 'getdate',
- }
- def visit_function(self, func):
- func.name = self.function_rewrites.get(func.name, func.name)
- res = super(SybaseSQLCompiler, self).visit_function(func)
- if func.name.lower() == 'getdate':
- # apply CAST operator
- # FIXME: what about _pyodbc ?
- cast = expression._Cast(func, SybaseDate_mxodbc)
- # infinite recursion
- # res = self.visit_cast(cast)
- res = "CAST(%s AS %s)" % (res, self.process(cast.typeclause))
- return res
-
- def for_update_clause(self, select):
- # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use
- return ''
-
- def order_by_clause(self, select):
- order_by = self.process(select._order_by_clause)
-
- # SybaseSQL only allows ORDER BY in subqueries if there is a LIMIT
- if order_by and (not self.is_subquery() or select._limit):
- return " ORDER BY " + order_by
- else:
- return ""
-
-
-class SybaseSQLSchemaGenerator(compiler.SchemaGenerator):
- def get_column_specification(self, column, **kwargs):
-
- colspec = self.preparer.format_column(column)
-
- if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \
- column.autoincrement and isinstance(column.type, sqltypes.Integer):
- if column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional):
- column.sequence = schema.Sequence(column.name + '_seq')
-
- if hasattr(column, 'sequence'):
- column.table.has_sequence = column
- #colspec += " numeric(30,0) IDENTITY"
- colspec += " Integer IDENTITY"
- else:
- colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
-
- if not column.nullable:
- colspec += " NOT NULL"
-
- default = self.get_column_default_string(column)
- if default is not None:
- colspec += " DEFAULT " + default
-
- return colspec
-
-
-class SybaseSQLSchemaDropper(compiler.SchemaDropper):
- def visit_index(self, index):
- self.append("\nDROP INDEX %s.%s" % (
- self.preparer.quote_identifier(index.table.name),
- self.preparer.quote(self._validate_identifier(index.name, False), index.quote)
- ))
- self.execute()
-
-
-class SybaseSQLDefaultRunner(base.DefaultRunner):
- pass
-
-
-class SybaseSQLIdentifierPreparer(compiler.IdentifierPreparer):
- reserved_words = RESERVED_WORDS
-
- def __init__(self, dialect):
- super(SybaseSQLIdentifierPreparer, self).__init__(dialect)
-
- def _escape_identifier(self, value):
- #TODO: determin SybaseSQL's escapeing rules
- return value
-
- def _fold_identifier_case(self, value):
- #TODO: determin SybaseSQL's case folding rules
- return value
-
-
-dialect = SybaseSQLDialect
-dialect.statement_compiler = SybaseSQLCompiler
-dialect.schemagenerator = SybaseSQLSchemaGenerator
-dialect.schemadropper = SybaseSQLSchemaDropper
-dialect.preparer = SybaseSQLIdentifierPreparer
-dialect.defaultrunner = SybaseSQLDefaultRunner
--- /dev/null
+from sqlalchemy.dialects.informix import base, informixdb
+
+base.dialect = informixdb.dialect
\ No newline at end of file
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
+"""Informix support.
+
+--- THIS DIALECT IS NOT TESTED ON 0.6 ---
+
+"""
+
import datetime
from sqlalchemy import types as sqltypes
-# for offset
-
-class informix_cursor(object):
- def __init__( self , con ):
- self.__cursor = con.cursor()
- self.rowcount = 0
-
- def offset( self , n ):
- if n > 0:
- self.fetchmany( n )
- self.rowcount = self.__cursor.rowcount - n
- if self.rowcount < 0:
- self.rowcount = 0
- else:
- self.rowcount = self.__cursor.rowcount
-
- def execute( self , sql , params ):
- if params is None or len( params ) == 0:
- params = []
-
- return self.__cursor.execute( sql , params )
-
- def __getattr__( self , name ):
- if name not in ( 'offset' , '__cursor' , 'rowcount' , '__del__' , 'execute' ):
- return getattr( self.__cursor , name )
-
-class InfoNumeric(sqltypes.Numeric):
- def get_col_spec(self):
- if not self.precision:
- return 'NUMERIC'
- else:
- return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale}
-
-class InfoInteger(sqltypes.Integer):
- def get_col_spec(self):
- return "INTEGER"
-
-class InfoSmallInteger(sqltypes.SmallInteger):
- def get_col_spec(self):
- return "SMALLINT"
-
-class InfoDate(sqltypes.Date):
- def get_col_spec( self ):
- return "DATE"
-
class InfoDateTime(sqltypes.DateTime ):
- def get_col_spec(self):
- return "DATETIME YEAR TO SECOND"
-
def bind_processor(self, dialect):
def process(value):
if value is not None:
return process
class InfoTime(sqltypes.Time ):
- def get_col_spec(self):
- return "DATETIME HOUR TO SECOND"
-
def bind_processor(self, dialect):
def process(value):
if value is not None:
return value
return process
-class InfoText(sqltypes.String):
- def get_col_spec(self):
- return "VARCHAR(255)"
-
-class InfoString(sqltypes.String):
- def get_col_spec(self):
- return "VARCHAR(%(length)s)" % {'length' : self.length}
-
- def bind_processor(self, dialect):
- def process(value):
- if value == '':
- return None
- else:
- return value
- return process
-
-class InfoChar(sqltypes.CHAR):
- def get_col_spec(self):
- return "CHAR(%(length)s)" % {'length' : self.length}
-
-class InfoBinary(sqltypes.Binary):
- def get_col_spec(self):
- return "BYTE"
class InfoBoolean(sqltypes.Boolean):
- default_type = 'NUM'
- def get_col_spec(self):
- return "SMALLINT"
-
def result_processor(self, dialect):
def process(value):
if value is None:
return process
colspecs = {
- sqltypes.Integer : InfoInteger,
- sqltypes.SmallInteger : InfoSmallInteger,
- sqltypes.Numeric : InfoNumeric,
- sqltypes.Float : InfoNumeric,
sqltypes.DateTime : InfoDateTime,
- sqltypes.Date : InfoDate,
sqltypes.Time: InfoTime,
- sqltypes.String : InfoString,
- sqltypes.Binary : InfoBinary,
sqltypes.Boolean : InfoBoolean,
- sqltypes.Text : InfoText,
- sqltypes.CHAR: InfoChar,
}
ischema_names = {
- 0 : InfoString, # CHAR
- 1 : InfoSmallInteger, # SMALLINT
- 2 : InfoInteger, # INT
- 3 : InfoNumeric, # Float
- 3 : InfoNumeric, # SmallFloat
- 5 : InfoNumeric, # DECIMAL
- 6 : InfoInteger, # Serial
- 7 : InfoDate, # DATE
- 8 : InfoNumeric, # MONEY
- 10 : InfoDateTime, # DATETIME
- 11 : InfoBinary, # BYTE
- 12 : InfoText, # TEXT
- 13 : InfoString, # VARCHAR
- 15 : InfoString, # NCHAR
- 16 : InfoString, # NVARCHAR
- 17 : InfoInteger, # INT8
- 18 : InfoInteger, # Serial8
- 43 : InfoString, # LVARCHAR
- -1 : InfoBinary, # BLOB
- -1 : InfoText, # CLOB
+ 0 : sqltypes.CHAR, # CHAR
+ 1 : sqltypes.SMALLINT, # SMALLINT
+ 2 : sqltypes.INTEGER, # INT
+ 3 : sqltypes.FLOAT, # Float
+ 3 : sqltypes.Float, # SmallFloat
+ 5 : sqltypes.DECIMAL, # DECIMAL
+ 6 : sqltypes.Integer, # Serial
+ 7 : sqltypes.DATE, # DATE
+ 8 : sqltypes.Numeric, # MONEY
+ 10 : sqltypes.DATETIME, # DATETIME
+ 11 : sqltypes.Binary, # BYTE
+ 12 : sqltypes.TEXT, # TEXT
+ 13 : sqltypes.VARCHAR, # VARCHAR
+ 15 : sqltypes.NCHAR, # NCHAR
+ 16 : sqltypes.NVARCHAR, # NVARCHAR
+ 17 : sqltypes.Integer, # INT8
+ 18 : sqltypes.Integer, # Serial8
+ 43 : sqltypes.String, # LVARCHAR
+ -1 : sqltypes.BLOB, # BLOB
+ -1 : sqltypes.CLOB, # CLOB
}
-class InfoExecutionContext(default.DefaultExecutionContext):
- # cursor.sqlerrd
- # 0 - estimated number of rows returned
- # 1 - serial value after insert or ISAM error code
- # 2 - number of rows processed
- # 3 - estimated cost
- # 4 - offset of the error into the SQL statement
- # 5 - rowid after insert
- def post_exec(self):
- if getattr(self.compiled, "isinsert", False) and self.last_inserted_ids() is None:
- self._last_inserted_ids = [self.cursor.sqlerrd[1]]
- elif hasattr( self.compiled , 'offset' ):
- self.cursor.offset( self.compiled.offset )
- super(InfoExecutionContext, self).post_exec()
-
- def create_cursor( self ):
- return informix_cursor( self.connection.connection )
-
-class InfoDialect(default.DefaultDialect):
- name = 'informix'
- default_paramstyle = 'qmark'
- # for informix 7.31
- max_identifier_length = 18
+class InfoTypeCompiler(compiler.GenericTypeCompiler):
+ def visit_DATETIME(self, type_):
+ return "DATETIME YEAR TO SECOND"
+
+ def visit_TIME(self, type_):
+ return "DATETIME HOUR TO SECOND"
+
+ def visit_binary(self, type_):
+ return "BYTE"
+
+ def visit_boolean(self, type_):
+ return "SMALLINT"
+
+class InfoSQLCompiler(compiler.SQLCompiler):
- def __init__(self, use_ansi=True, **kwargs):
- self.use_ansi = use_ansi
- default.DefaultDialect.__init__(self, **kwargs)
+ def __init__(self, *args, **kwargs):
+ self.limit = 0
+ self.offset = 0
- def dbapi(cls):
- import informixdb
- return informixdb
- dbapi = classmethod(dbapi)
+ compiler.SQLCompiler.__init__( self , *args, **kwargs )
- def is_disconnect(self, e):
- if isinstance(e, self.dbapi.OperationalError):
- return 'closed the connection' in str(e) or 'connection not open' in str(e)
+ def default_from(self):
+ return " from systables where tabname = 'systables' "
+
+ def get_select_precolumns( self , select ):
+ s = select._distinct and "DISTINCT " or ""
+ # only has limit
+ if select._limit:
+ off = select._offset or 0
+ s += " FIRST %s " % ( select._limit + off )
else:
- return False
+ s += ""
+ return s
- def do_begin(self , connect ):
- cu = connect.cursor()
- cu.execute( 'SET LOCK MODE TO WAIT' )
- #cu.execute( 'SET ISOLATION TO REPEATABLE READ' )
+ def visit_select(self, select):
+ if select._offset:
+ self.offset = select._offset
+ self.limit = select._limit or 0
+ # the column in order by clause must in select too
- def type_descriptor(self, typeobj):
- return sqltypes.adapt_type(typeobj, colspecs)
+ def __label( c ):
+ try:
+ return c._label.lower()
+ except:
+ return ''
+
+ # TODO: dont modify the original select, generate a new one
+ a = [ __label(c) for c in select._raw_columns ]
+ for c in select._order_by_clause.clauses:
+ if ( __label(c) not in a ):
+ select.append_column( c )
- def create_connect_args(self, url):
- if url.host:
- dsn = '%s@%s' % ( url.database , url.host )
+ return compiler.SQLCompiler.visit_select(self, select)
+
+ def limit_clause(self, select):
+ return ""
+
+ def visit_function( self , func ):
+ if func.name.lower() == 'current_date':
+ return "today"
+ elif func.name.lower() == 'current_time':
+ return "CURRENT HOUR TO SECOND"
+ elif func.name.lower() in ( 'current_timestamp' , 'now' ):
+ return "CURRENT YEAR TO SECOND"
else:
- dsn = url.database
+ return compiler.SQLCompiler.visit_function( self , func )
- if url.username:
- opt = { 'user':url.username , 'password': url.password }
+ def visit_clauselist(self, list, **kwargs):
+ return ', '.join([s for s in [self.process(c) for c in list.clauses] if s is not None])
+
+class InfoDDLCompiler(compiler.DDLCompiler):
+ def get_column_specification(self, column, first_pk=False):
+ colspec = self.preparer.format_column(column)
+ if column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and \
+ isinstance(column.type, sqltypes.Integer) and not getattr( self , 'has_serial' , False ) and first_pk:
+ colspec += " SERIAL"
+ self.has_serial = True
else:
- opt = {}
+ colspec += " " + self.dialect.type_compiler.process(column.type)
+ default = self.get_column_default_string(column)
+ if default is not None:
+ colspec += " DEFAULT " + default
+
+ if not column.nullable:
+ colspec += " NOT NULL"
+
+ return colspec
+
+ def post_create_table(self, table):
+ if hasattr( self , 'has_serial' ):
+ del self.has_serial
+ return ''
+
+class InfoIdentifierPreparer(compiler.IdentifierPreparer):
+ def __init__(self, dialect):
+ super(InfoIdentifierPreparer, self).__init__(dialect, initial_quote="'")
- return ([dsn], opt)
+ def format_constraint(self, constraint):
+ # informix doesnt support names for constraints
+ return ''
+
+ def _requires_quotes(self, value):
+ return False
+
+class InformixDialect(default.DefaultDialect):
+ name = 'informix'
+ # for informix 7.31
+ max_identifier_length = 18
+ type_compiler = InfoTypeCompiler
+ poolclass = pool.SingletonThreadPool
+ statement_compiler = InfoSQLCompiler
+ ddl_compiler = InfoDDLCompiler
+ preparer = InfoIdentifierPreparer
+ colspecs = colspecs
+ ischema_names = ischema_names
+
+ def do_begin(self , connect ):
+ cu = connect.cursor()
+ cu.execute( 'SET LOCK MODE TO WAIT' )
+ #cu.execute( 'SET ISOLATION TO REPEATABLE READ' )
def table_names(self, connection, schema):
s = "select tabname from systables"
for cons_name, cons_type, local_column in rows:
table.primary_key.add( table.c[local_column] )
-class InfoCompiler(compiler.SQLCompiler):
- """Info compiler modifies the lexical structure of Select statements to work under
- non-ANSI configured Oracle databases, if the use_ansi flag is False."""
-
- def __init__(self, *args, **kwargs):
- self.limit = 0
- self.offset = 0
-
- compiler.SQLCompiler.__init__( self , *args, **kwargs )
-
- def default_from(self):
- return " from systables where tabname = 'systables' "
-
- def get_select_precolumns( self , select ):
- s = select._distinct and "DISTINCT " or ""
- # only has limit
- if select._limit:
- off = select._offset or 0
- s += " FIRST %s " % ( select._limit + off )
- else:
- s += ""
- return s
-
- def visit_select(self, select):
- if select._offset:
- self.offset = select._offset
- self.limit = select._limit or 0
- # the column in order by clause must in select too
-
- def __label( c ):
- try:
- return c._label.lower()
- except:
- return ''
-
- # TODO: dont modify the original select, generate a new one
- a = [ __label(c) for c in select._raw_columns ]
- for c in select._order_by_clause.clauses:
- if ( __label(c) not in a ):
- select.append_column( c )
-
- return compiler.SQLCompiler.visit_select(self, select)
-
- def limit_clause(self, select):
- return ""
-
- def visit_function( self , func ):
- if func.name.lower() == 'current_date':
- return "today"
- elif func.name.lower() == 'current_time':
- return "CURRENT HOUR TO SECOND"
- elif func.name.lower() in ( 'current_timestamp' , 'now' ):
- return "CURRENT YEAR TO SECOND"
- else:
- return compiler.SQLCompiler.visit_function( self , func )
-
- def visit_clauselist(self, list, **kwargs):
- return ', '.join([s for s in [self.process(c) for c in list.clauses] if s is not None])
-
-class InfoSchemaGenerator(compiler.SchemaGenerator):
- def get_column_specification(self, column, first_pk=False):
- colspec = self.preparer.format_column(column)
- if column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and \
- isinstance(column.type, sqltypes.Integer) and not getattr( self , 'has_serial' , False ) and first_pk:
- colspec += " SERIAL"
- self.has_serial = True
- else:
- colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
- default = self.get_column_default_string(column)
- if default is not None:
- colspec += " DEFAULT " + default
-
- if not column.nullable:
- colspec += " NOT NULL"
-
- return colspec
-
- def post_create_table(self, table):
- if hasattr( self , 'has_serial' ):
- del self.has_serial
- return ''
-
- def visit_primary_key_constraint(self, constraint):
- # for informix 7.31 not support constraint name
- name = constraint.name
- constraint.name = None
- super(InfoSchemaGenerator, self).visit_primary_key_constraint(constraint)
- constraint.name = name
-
- def visit_unique_constraint(self, constraint):
- # for informix 7.31 not support constraint name
- name = constraint.name
- constraint.name = None
- super(InfoSchemaGenerator, self).visit_unique_constraint(constraint)
- constraint.name = name
-
- def visit_foreign_key_constraint( self , constraint ):
- if constraint.name is not None:
- constraint.use_alter = True
- else:
- super( InfoSchemaGenerator , self ).visit_foreign_key_constraint( constraint )
-
- def define_foreign_key(self, constraint):
- # for informix 7.31 not support constraint name
- if constraint.use_alter:
- name = constraint.name
- constraint.name = None
- self.append( "CONSTRAINT " )
- super(InfoSchemaGenerator, self).define_foreign_key(constraint)
- constraint.name = name
- if name is not None:
- self.append( " CONSTRAINT " + name )
- else:
- super(InfoSchemaGenerator, self).define_foreign_key(constraint)
-
- def visit_index(self, index):
- if len( index.columns ) == 1 and index.columns[0].foreign_key:
- return
- super(InfoSchemaGenerator, self).visit_index(index)
-
-class InfoIdentifierPreparer(compiler.IdentifierPreparer):
- def __init__(self, dialect):
- super(InfoIdentifierPreparer, self).__init__(dialect, initial_quote="'")
-
- def _requires_quotes(self, value):
- return False
-
-class InfoSchemaDropper(compiler.SchemaDropper):
- def drop_foreignkey(self, constraint):
- if constraint.name is not None:
- super( InfoSchemaDropper , self ).drop_foreignkey( constraint )
-
-dialect = InfoDialect
-poolclass = pool.SingletonThreadPool
-dialect.statement_compiler = InfoCompiler
-dialect.schemagenerator = InfoSchemaGenerator
-dialect.schemadropper = InfoSchemaDropper
-dialect.preparer = InfoIdentifierPreparer
-dialect.execution_ctx_cls = InfoExecutionContext
\ No newline at end of file
--- /dev/null
+from sqlalchemy.dialects.informix.base import InformixDialect
+from sqlalchemy.engine import default
+
+# for offset
+
+class informix_cursor(object):
+ def __init__( self , con ):
+ self.__cursor = con.cursor()
+ self.rowcount = 0
+
+ def offset( self , n ):
+ if n > 0:
+ self.fetchmany( n )
+ self.rowcount = self.__cursor.rowcount - n
+ if self.rowcount < 0:
+ self.rowcount = 0
+ else:
+ self.rowcount = self.__cursor.rowcount
+
+ def execute( self , sql , params ):
+ if params is None or len( params ) == 0:
+ params = []
+
+ return self.__cursor.execute( sql , params )
+
+ def __getattr__( self , name ):
+ if name not in ( 'offset' , '__cursor' , 'rowcount' , '__del__' , 'execute' ):
+ return getattr( self.__cursor , name )
+
+
+class InfoExecutionContext(default.DefaultExecutionContext):
+ # cursor.sqlerrd
+ # 0 - estimated number of rows returned
+ # 1 - serial value after insert or ISAM error code
+ # 2 - number of rows processed
+ # 3 - estimated cost
+ # 4 - offset of the error into the SQL statement
+ # 5 - rowid after insert
+ def post_exec(self):
+ if getattr(self.compiled, "isinsert", False) and self.last_inserted_ids() is None:
+ self._last_inserted_ids = [self.cursor.sqlerrd[1]]
+ elif hasattr( self.compiled , 'offset' ):
+ self.cursor.offset( self.compiled.offset )
+
+ def create_cursor( self ):
+ return informix_cursor( self.connection.connection )
+
+
+class Informix_informixdb(InformixDialect):
+ driver = 'informixdb'
+ default_paramstyle = 'qmark'
+ execution_context_cls = InfoExecutionContext
+
+ @classmethod
+ def dbapi(cls):
+ import informixdb
+ return informixdb
+
+ def create_connect_args(self, url):
+ if url.host:
+ dsn = '%s@%s' % ( url.database , url.host )
+ else:
+ dsn = url.database
+
+ if url.username:
+ opt = { 'user':url.username , 'password': url.password }
+ else:
+ opt = {}
+
+ return ([dsn], opt)
+
+
+ def is_disconnect(self, e):
+ if isinstance(e, self.dbapi.OperationalError):
+ return 'closed the connection' in str(e) or 'connection not open' in str(e)
+ else:
+ return False
+
+
+dialect = Informix_informixdb
\ No newline at end of file
--- /dev/null
+from sqlalchemy.dialects.maxdb import base, sapdb
+
+base.dialect = sapdb.dialect
\ No newline at end of file
"""Support for the MaxDB database.
+-- NOT TESTED ON 0.6 --
+
TODO: More module docs! MaxDB support is currently experimental.
Overview
super(_StringType, self).__init__(length=length, **kw)
self.encoding = encoding
- def get_col_spec(self):
- if self.length is None:
- spec = 'LONG'
- else:
- spec = '%s(%s)' % (self._type, self.length)
-
- if self.encoding is not None:
- spec = ' '.join([spec, self.encoding.upper()])
- return spec
-
def bind_processor(self, dialect):
if self.encoding == 'unicode':
return None
return spec
-class MaxInteger(sqltypes.Integer):
- def get_col_spec(self):
- return 'INTEGER'
-
-
-class MaxSmallInteger(MaxInteger):
- def get_col_spec(self):
- return 'SMALLINT'
-
-
class MaxNumeric(sqltypes.Numeric):
"""The FIXED (also NUMERIC, DECIMAL) data type."""
def bind_processor(self, dialect):
return None
- def get_col_spec(self):
- if self.scale and self.precision:
- return 'FIXED(%s, %s)' % (self.precision, self.scale)
- elif self.precision:
- return 'FIXED(%s)' % self.precision
- else:
- return 'INTEGER'
-
-
-class MaxFloat(sqltypes.Float):
- """The FLOAT data type."""
-
- def get_col_spec(self):
- if self.precision is None:
- return 'FLOAT'
- else:
- return 'FLOAT(%s)' % (self.precision,)
-
-
class MaxTimestamp(sqltypes.DateTime):
- def get_col_spec(self):
- return 'TIMESTAMP'
-
def bind_processor(self, dialect):
def process(value):
if value is None:
class MaxDate(sqltypes.Date):
- def get_col_spec(self):
- return 'DATE'
-
def bind_processor(self, dialect):
def process(value):
if value is None:
class MaxTime(sqltypes.Time):
- def get_col_spec(self):
- return 'TIME'
-
def bind_processor(self, dialect):
def process(value):
if value is None:
return process
-class MaxBoolean(sqltypes.Boolean):
- def get_col_spec(self):
- return 'BOOLEAN'
-
-
class MaxBlob(sqltypes.Binary):
- def get_col_spec(self):
- return 'LONG BYTE'
-
def bind_processor(self, dialect):
def process(value):
if value is None:
return value.read(value.remainingLength())
return process
+class MaxDBTypeCompiler(compiler.GenericTypeCompiler):
+ def _string_spec(self, string_spec, type_):
+ if type_.length is None:
+ spec = 'LONG'
+ else:
+ spec = '%s(%s)' % (string_spec, type_.length)
+
+ if getattr(type_, 'encoding'):
+ spec = ' '.join([spec, getattr(type_, 'encoding').upper()])
+ return spec
+
+ def visit_text(self, type_):
+ spec = 'LONG'
+ if getattr(type_, 'encoding', None):
+ spec = ' '.join((spec, type_.encoding))
+ elif type_.convert_unicode:
+ spec = ' '.join((spec, 'UNICODE'))
+
+ return spec
+ def visit_char(self, type_):
+ return self._string_spec("CHAR", type_)
+
+ def visit_string(self, type_):
+ return self._string_spec("VARCHAR", type_)
+
+ def visit_binary(self, type_):
+ return "LONG BYTE"
+
+ def visit_numeric(self, type_):
+ if type_.scale and type_.precision:
+ return 'FIXED(%s, %s)' % (type_.precision, type_.scale)
+ elif type_.precision:
+ return 'FIXED(%s)' % type_.precision
+ else:
+ return 'INTEGER'
+
+ def visit_BOOLEAN(self, type_):
+ return "BOOLEAN"
+
colspecs = {
- sqltypes.Integer: MaxInteger,
- sqltypes.SmallInteger: MaxSmallInteger,
sqltypes.Numeric: MaxNumeric,
- sqltypes.Float: MaxFloat,
sqltypes.DateTime: MaxTimestamp,
sqltypes.Date: MaxDate,
sqltypes.Time: MaxTime,
sqltypes.String: MaxString,
+ sqltypes.Unicode:MaxUnicode,
sqltypes.Binary: MaxBlob,
- sqltypes.Boolean: MaxBoolean,
sqltypes.Text: MaxText,
sqltypes.CHAR: MaxChar,
sqltypes.TIMESTAMP: MaxTimestamp,
}
ischema_names = {
- 'boolean': MaxBoolean,
- 'char': MaxChar,
- 'character': MaxChar,
- 'date': MaxDate,
- 'fixed': MaxNumeric,
- 'float': MaxFloat,
- 'int': MaxInteger,
- 'integer': MaxInteger,
- 'long binary': MaxBlob,
- 'long unicode': MaxText,
- 'long': MaxText,
- 'long': MaxText,
- 'smallint': MaxSmallInteger,
- 'time': MaxTime,
- 'timestamp': MaxTimestamp,
- 'varchar': MaxString,
+ 'boolean': sqltypes.BOOLEAN,
+ 'char': sqltypes.CHAR,
+ 'character': sqltypes.CHAR,
+ 'date': sqltypes.DATE,
+ 'fixed': sqltypes.Numeric,
+ 'float': sqltypes.FLOAT,
+ 'int': sqltypes.INT,
+ 'integer': sqltypes.INT,
+ 'long binary': sqltypes.BLOB,
+ 'long unicode': sqltypes.Text,
+ 'long': sqltypes.Text,
+ 'long': sqltypes.Text,
+ 'smallint': sqltypes.SmallInteger,
+ 'time': sqltypes.Time,
+ 'timestamp': sqltypes.TIMESTAMP,
+ 'varchar': sqltypes.VARCHAR,
}
-
+# TODO: migrate this to sapdb.py
class MaxDBExecutionContext(default.DefaultExecutionContext):
def post_exec(self):
# DB-API bug: if there were any functions as values,
class MaxDBResultProxy(engine_base.ResultProxy):
_process_row = MaxDBCachedColumnRow
+class MaxDBCompiler(compiler.SQLCompiler):
+ operators = compiler.SQLCompiler.operators.copy()
+ operators[sql_operators.mod] = lambda x, y: 'mod(%s, %s)' % (x, y)
-class MaxDBDialect(default.DefaultDialect):
- name = 'maxdb'
- supports_alter = True
- supports_unicode_statements = True
- max_identifier_length = 32
- supports_sane_rowcount = True
- supports_sane_multi_rowcount = False
- preexecute_pk_sequences = True
+ function_conversion = {
+ 'CURRENT_DATE': 'DATE',
+ 'CURRENT_TIME': 'TIME',
+ 'CURRENT_TIMESTAMP': 'TIMESTAMP',
+ }
- # MaxDB-specific
- datetimeformat = 'internal'
+ # These functions must be written without parens when called with no
+ # parameters. e.g. 'SELECT DATE FROM DUAL' not 'SELECT DATE() FROM DUAL'
+ bare_functions = set([
+ 'CURRENT_SCHEMA', 'DATE', 'FALSE', 'SYSDBA', 'TIME', 'TIMESTAMP',
+ 'TIMEZONE', 'TRANSACTION', 'TRUE', 'USER', 'UID', 'USERGROUP',
+ 'UTCDATE', 'UTCDIFF'])
- def __init__(self, _raise_known_sql_errors=False, **kw):
- super(MaxDBDialect, self).__init__(**kw)
- self._raise_known = _raise_known_sql_errors
+ def default_from(self):
+ return ' FROM DUAL'
- if self.dbapi is None:
- self.dbapi_type_map = {}
+ def for_update_clause(self, select):
+ clause = select.for_update
+ if clause is True:
+ return " WITH LOCK EXCLUSIVE"
+ elif clause is None:
+ return ""
+ elif clause == "read":
+ return " WITH LOCK"
+ elif clause == "ignore":
+ return " WITH LOCK (IGNORE) EXCLUSIVE"
+ elif clause == "nowait":
+ return " WITH LOCK (NOWAIT) EXCLUSIVE"
+ elif isinstance(clause, basestring):
+ return " WITH LOCK %s" % clause.upper()
+ elif not clause:
+ return ""
else:
- self.dbapi_type_map = {
- 'Long Binary': MaxBlob(),
- 'Long byte_t': MaxBlob(),
- 'Long Unicode': MaxText(),
- 'Timestamp': MaxTimestamp(),
- 'Date': MaxDate(),
- 'Time': MaxTime(),
- datetime.datetime: MaxTimestamp(),
- datetime.date: MaxDate(),
- datetime.time: MaxTime(),
- }
+ return " WITH LOCK EXCLUSIVE"
- def dbapi(cls):
- from sapdb import dbapi as _dbapi
- return _dbapi
- dbapi = classmethod(dbapi)
-
- def create_connect_args(self, url):
- opts = url.translate_connect_args(username='user')
- opts.update(url.query)
- return [], opts
-
- def type_descriptor(self, typeobj):
- if isinstance(typeobj, type):
- typeobj = typeobj()
- if isinstance(typeobj, sqltypes.Unicode):
- return typeobj.adapt(MaxUnicode)
+ def apply_function_parens(self, func):
+ if func.name.upper() in self.bare_functions:
+ return len(func.clauses) > 0
else:
- return sqltypes.adapt_type(typeobj, colspecs)
-
- def do_execute(self, cursor, statement, parameters, context=None):
- res = cursor.execute(statement, parameters)
- if isinstance(res, int) and context is not None:
- context._rowcount = res
-
- def do_release_savepoint(self, connection, name):
- # Does MaxDB truly support RELEASE SAVEPOINT <id>? All my attempts
- # produce "SUBTRANS COMMIT/ROLLBACK not allowed without SUBTRANS
- # BEGIN SQLSTATE: I7065"
- # Note that ROLLBACK TO works fine. In theory, a RELEASE should
- # just free up some transactional resources early, before the overall
- # COMMIT/ROLLBACK so omitting it should be relatively ok.
- pass
+ return True
- def get_default_schema_name(self, connection):
- try:
- return self._default_schema_name
- except AttributeError:
- name = self.identifier_preparer._normalize_name(
- connection.execute('SELECT CURRENT_SCHEMA FROM DUAL').scalar())
- self._default_schema_name = name
- return name
+ def visit_function(self, fn, **kw):
+ transform = self.function_conversion.get(fn.name.upper(), None)
+ if transform:
+ fn = fn._clone()
+ fn.name = transform
+ return super(MaxDBCompiler, self).visit_function(fn, **kw)
- def has_table(self, connection, table_name, schema=None):
- denormalize = self.identifier_preparer._denormalize_name
- bind = [denormalize(table_name)]
- if schema is None:
- sql = ("SELECT tablename FROM TABLES "
- "WHERE TABLES.TABLENAME=? AND"
- " TABLES.SCHEMANAME=CURRENT_SCHEMA ")
+ def visit_cast(self, cast, **kwargs):
+ # MaxDB only supports casts * to NUMERIC, * to VARCHAR or
+ # date/time to VARCHAR. Casts of LONGs will fail.
+ if isinstance(cast.type, (sqltypes.Integer, sqltypes.Numeric)):
+ return "NUM(%s)" % self.process(cast.clause)
+ elif isinstance(cast.type, sqltypes.String):
+ return "CHR(%s)" % self.process(cast.clause)
else:
- sql = ("SELECT tablename FROM TABLES "
- "WHERE TABLES.TABLENAME = ? AND"
- " TABLES.SCHEMANAME=? ")
- bind.append(denormalize(schema))
-
- rp = connection.execute(sql, bind)
- found = bool(rp.fetchone())
- rp.close()
- return found
+ return self.process(cast.clause)
- def table_names(self, connection, schema):
- if schema is None:
- sql = (" SELECT TABLENAME FROM TABLES WHERE "
- " SCHEMANAME=CURRENT_SCHEMA ")
- rs = connection.execute(sql)
+ def visit_sequence(self, sequence):
+ if sequence.optional:
+ return None
else:
- sql = (" SELECT TABLENAME FROM TABLES WHERE "
- " SCHEMANAME=? ")
- matchname = self.identifier_preparer._denormalize_name(schema)
- rs = connection.execute(sql, matchname)
- normalize = self.identifier_preparer._normalize_name
- return [normalize(row[0]) for row in rs]
+ return (self.dialect.identifier_preparer.format_sequence(sequence) +
+ ".NEXTVAL")
- def reflecttable(self, connection, table, include_columns):
- denormalize = self.identifier_preparer._denormalize_name
- normalize = self.identifier_preparer._normalize_name
+ class ColumnSnagger(visitors.ClauseVisitor):
+ def __init__(self):
+ self.count = 0
+ self.column = None
+ def visit_column(self, column):
+ self.column = column
+ self.count += 1
- st = ('SELECT COLUMNNAME, MODE, DATATYPE, CODETYPE, LEN, DEC, '
- ' NULLABLE, "DEFAULT", DEFAULTFUNCTION '
- 'FROM COLUMNS '
- 'WHERE TABLENAME=? AND SCHEMANAME=%s '
- 'ORDER BY POS')
+ def _find_labeled_columns(self, columns, use_labels=False):
+ labels = {}
+ for column in columns:
+ if isinstance(column, basestring):
+ continue
+ snagger = self.ColumnSnagger()
+ snagger.traverse(column)
+ if snagger.count == 1:
+ if isinstance(column, sql_expr._Label):
+ labels[unicode(snagger.column)] = column.name
+ elif use_labels:
+ labels[unicode(snagger.column)] = column._label
- fk = ('SELECT COLUMNNAME, FKEYNAME, '
- ' REFSCHEMANAME, REFTABLENAME, REFCOLUMNNAME, RULE, '
- ' (CASE WHEN REFSCHEMANAME = CURRENT_SCHEMA '
- ' THEN 1 ELSE 0 END) AS in_schema '
- 'FROM FOREIGNKEYCOLUMNS '
- 'WHERE TABLENAME=? AND SCHEMANAME=%s '
- 'ORDER BY FKEYNAME ')
+ return labels
- params = [denormalize(table.name)]
- if not table.schema:
- st = st % 'CURRENT_SCHEMA'
- fk = fk % 'CURRENT_SCHEMA'
- else:
- st = st % '?'
- fk = fk % '?'
- params.append(denormalize(table.schema))
+ def order_by_clause(self, select):
+ order_by = self.process(select._order_by_clause)
- rows = connection.execute(st, params).fetchall()
- if not rows:
- raise exc.NoSuchTableError(table.fullname)
+ # ORDER BY clauses in DISTINCT queries must reference aliased
+ # inner columns by alias name, not true column name.
+ if order_by and getattr(select, '_distinct', False):
+ labels = self._find_labeled_columns(select.inner_columns,
+ select.use_labels)
+ if labels:
+ for needs_alias in labels.keys():
+ r = re.compile(r'(^| )(%s)(,| |$)' %
+ re.escape(needs_alias))
+ order_by = r.sub((r'\1%s\3' % labels[needs_alias]),
+ order_by)
- include_columns = set(include_columns or [])
-
- for row in rows:
- (name, mode, col_type, encoding, length, scale,
- nullable, constant_def, func_def) = row
-
- name = normalize(name)
-
- if include_columns and name not in include_columns:
- continue
-
- type_args, type_kw = [], {}
- if col_type == 'FIXED':
- type_args = length, scale
- # Convert FIXED(10) DEFAULT SERIAL to our Integer
- if (scale == 0 and
- func_def is not None and func_def.startswith('SERIAL')):
- col_type = 'INTEGER'
- type_args = length,
- elif col_type in 'FLOAT':
- type_args = length,
- elif col_type in ('CHAR', 'VARCHAR'):
- type_args = length,
- type_kw['encoding'] = encoding
- elif col_type == 'LONG':
- type_kw['encoding'] = encoding
-
- try:
- type_cls = ischema_names[col_type.lower()]
- type_instance = type_cls(*type_args, **type_kw)
- except KeyError:
- util.warn("Did not recognize type '%s' of column '%s'" %
- (col_type, name))
- type_instance = sqltypes.NullType
-
- col_kw = {'autoincrement': False}
- col_kw['nullable'] = (nullable == 'YES')
- col_kw['primary_key'] = (mode == 'KEY')
-
- if func_def is not None:
- if func_def.startswith('SERIAL'):
- if col_kw['primary_key']:
- # No special default- let the standard autoincrement
- # support handle SERIAL pk columns.
- col_kw['autoincrement'] = True
- else:
- # strip current numbering
- col_kw['server_default'] = schema.DefaultClause(
- sql.text('SERIAL'))
- col_kw['autoincrement'] = True
- else:
- col_kw['server_default'] = schema.DefaultClause(
- sql.text(func_def))
- elif constant_def is not None:
- col_kw['server_default'] = schema.DefaultClause(sql.text(
- "'%s'" % constant_def.replace("'", "''")))
-
- table.append_column(schema.Column(name, type_instance, **col_kw))
-
- fk_sets = itertools.groupby(connection.execute(fk, params),
- lambda row: row.FKEYNAME)
- for fkeyname, fkey in fk_sets:
- fkey = list(fkey)
- if include_columns:
- key_cols = set([r.COLUMNNAME for r in fkey])
- if key_cols != include_columns:
- continue
-
- columns, referants = [], []
- quote = self.identifier_preparer._maybe_quote_identifier
-
- for row in fkey:
- columns.append(normalize(row.COLUMNNAME))
- if table.schema or not row.in_schema:
- referants.append('.'.join(
- [quote(normalize(row[c]))
- for c in ('REFSCHEMANAME', 'REFTABLENAME',
- 'REFCOLUMNNAME')]))
- else:
- referants.append('.'.join(
- [quote(normalize(row[c]))
- for c in ('REFTABLENAME', 'REFCOLUMNNAME')]))
-
- constraint_kw = {'name': fkeyname.lower()}
- if fkey[0].RULE is not None:
- rule = fkey[0].RULE
- if rule.startswith('DELETE '):
- rule = rule[7:]
- constraint_kw['ondelete'] = rule
-
- table_kw = {}
- if table.schema or not row.in_schema:
- table_kw['schema'] = normalize(fkey[0].REFSCHEMANAME)
-
- ref_key = schema._get_table_key(normalize(fkey[0].REFTABLENAME),
- table_kw.get('schema'))
- if ref_key not in table.metadata.tables:
- schema.Table(normalize(fkey[0].REFTABLENAME),
- table.metadata,
- autoload=True, autoload_with=connection,
- **table_kw)
-
- constraint = schema.ForeignKeyConstraint(columns, referants, link_to_name=True,
- **constraint_kw)
- table.append_constraint(constraint)
-
- def has_sequence(self, connection, name):
- # [ticket:726] makes this schema-aware.
- denormalize = self.identifier_preparer._denormalize_name
- sql = ("SELECT sequence_name FROM SEQUENCES "
- "WHERE SEQUENCE_NAME=? ")
-
- rp = connection.execute(sql, denormalize(name))
- found = bool(rp.fetchone())
- rp.close()
- return found
-
-
-class MaxDBCompiler(compiler.SQLCompiler):
- operators = compiler.SQLCompiler.operators.copy()
- operators[sql_operators.mod] = lambda x, y: 'mod(%s, %s)' % (x, y)
-
- function_conversion = {
- 'CURRENT_DATE': 'DATE',
- 'CURRENT_TIME': 'TIME',
- 'CURRENT_TIMESTAMP': 'TIMESTAMP',
- }
-
- # These functions must be written without parens when called with no
- # parameters. e.g. 'SELECT DATE FROM DUAL' not 'SELECT DATE() FROM DUAL'
- bare_functions = set([
- 'CURRENT_SCHEMA', 'DATE', 'FALSE', 'SYSDBA', 'TIME', 'TIMESTAMP',
- 'TIMEZONE', 'TRANSACTION', 'TRUE', 'USER', 'UID', 'USERGROUP',
- 'UTCDATE', 'UTCDIFF'])
-
- def default_from(self):
- return ' FROM DUAL'
-
- def for_update_clause(self, select):
- clause = select.for_update
- if clause is True:
- return " WITH LOCK EXCLUSIVE"
- elif clause is None:
- return ""
- elif clause == "read":
- return " WITH LOCK"
- elif clause == "ignore":
- return " WITH LOCK (IGNORE) EXCLUSIVE"
- elif clause == "nowait":
- return " WITH LOCK (NOWAIT) EXCLUSIVE"
- elif isinstance(clause, basestring):
- return " WITH LOCK %s" % clause.upper()
- elif not clause:
- return ""
- else:
- return " WITH LOCK EXCLUSIVE"
-
- def apply_function_parens(self, func):
- if func.name.upper() in self.bare_functions:
- return len(func.clauses) > 0
- else:
- return True
-
- def visit_function(self, fn, **kw):
- transform = self.function_conversion.get(fn.name.upper(), None)
- if transform:
- fn = fn._clone()
- fn.name = transform
- return super(MaxDBCompiler, self).visit_function(fn, **kw)
-
- def visit_cast(self, cast, **kwargs):
- # MaxDB only supports casts * to NUMERIC, * to VARCHAR or
- # date/time to VARCHAR. Casts of LONGs will fail.
- if isinstance(cast.type, (sqltypes.Integer, sqltypes.Numeric)):
- return "NUM(%s)" % self.process(cast.clause)
- elif isinstance(cast.type, sqltypes.String):
- return "CHR(%s)" % self.process(cast.clause)
- else:
- return self.process(cast.clause)
-
- def visit_sequence(self, sequence):
- if sequence.optional:
- return None
- else:
- return (self.dialect.identifier_preparer.format_sequence(sequence) +
- ".NEXTVAL")
-
- class ColumnSnagger(visitors.ClauseVisitor):
- def __init__(self):
- self.count = 0
- self.column = None
- def visit_column(self, column):
- self.column = column
- self.count += 1
-
- def _find_labeled_columns(self, columns, use_labels=False):
- labels = {}
- for column in columns:
- if isinstance(column, basestring):
- continue
- snagger = self.ColumnSnagger()
- snagger.traverse(column)
- if snagger.count == 1:
- if isinstance(column, sql_expr._Label):
- labels[unicode(snagger.column)] = column.name
- elif use_labels:
- labels[unicode(snagger.column)] = column._label
-
- return labels
-
- def order_by_clause(self, select):
- order_by = self.process(select._order_by_clause)
-
- # ORDER BY clauses in DISTINCT queries must reference aliased
- # inner columns by alias name, not true column name.
- if order_by and getattr(select, '_distinct', False):
- labels = self._find_labeled_columns(select.inner_columns,
- select.use_labels)
- if labels:
- for needs_alias in labels.keys():
- r = re.compile(r'(^| )(%s)(,| |$)' %
- re.escape(needs_alias))
- order_by = r.sub((r'\1%s\3' % labels[needs_alias]),
- order_by)
-
- # No ORDER BY in subqueries.
- if order_by:
- if self.is_subquery():
- # It's safe to simply drop the ORDER BY if there is no
- # LIMIT. Right? Other dialects seem to get away with
- # dropping order.
- if select._limit:
- raise exc.InvalidRequestError(
- "MaxDB does not support ORDER BY in subqueries")
- else:
- return ""
- return " ORDER BY " + order_by
- else:
- return ""
+ # No ORDER BY in subqueries.
+ if order_by:
+ if self.is_subquery():
+ # It's safe to simply drop the ORDER BY if there is no
+ # LIMIT. Right? Other dialects seem to get away with
+ # dropping order.
+ if select._limit:
+ raise exc.InvalidRequestError(
+ "MaxDB does not support ORDER BY in subqueries")
+ else:
+ return ""
+ return " ORDER BY " + order_by
+ else:
+ return ""
def get_select_precolumns(self, select):
# Convert a subquery's LIMIT to TOP
return name
-class MaxDBSchemaGenerator(compiler.SchemaGenerator):
+class MaxDBDDLCompiler(compiler.DDLCompiler):
def get_column_specification(self, column, **kw):
colspec = [self.preparer.format_column(column),
- column.type.dialect_impl(self.dialect).get_col_spec()]
+ self.dialect.type_compiler.process(column.type)]
if not column.nullable:
colspec.append('NOT NULL')
else:
return None
- def visit_sequence(self, sequence):
+ def visit_create_sequence(self, create):
"""Creates a SEQUENCE.
TODO: move to module doc?
maxdb_no_cache
Defaults to False. If true, sets NOCACHE.
"""
-
+ sequence = create.element
+
if (not sequence.optional and
(not self.checkfirst or
not self.dialect.has_sequence(self.connection, sequence.name))):
elif opts.get('no_cache', False):
ddl.append('NOCACHE')
- self.append(' '.join(ddl))
- self.execute()
+ return ' '.join(ddl)
-class MaxDBSchemaDropper(compiler.SchemaDropper):
- def visit_sequence(self, sequence):
- if (not sequence.optional and
- (not self.checkfirst or
- self.dialect.has_sequence(self.connection, sequence.name))):
- self.append("DROP SEQUENCE %s" %
- self.preparer.format_sequence(sequence))
- self.execute()
+class MaxDBDialect(default.DefaultDialect):
+ name = 'maxdb'
+ supports_alter = True
+ supports_unicode_statements = True
+ max_identifier_length = 32
+ supports_sane_rowcount = True
+ supports_sane_multi_rowcount = False
+ preexecute_pk_sequences = True
+
+ preparer = MaxDBIdentifierPreparer
+ statement_compiler = MaxDBCompiler
+ ddl_compiler = MaxDBDDLCompiler
+ defaultrunner = MaxDBDefaultRunner
+ execution_ctx_cls = MaxDBExecutionContext
+
+ colspecs = colspecs
+ ischema_names = ischema_names
+
+ # MaxDB-specific
+ datetimeformat = 'internal'
+
+ def __init__(self, _raise_known_sql_errors=False, **kw):
+ super(MaxDBDialect, self).__init__(**kw)
+ self._raise_known = _raise_known_sql_errors
+
+ if self.dbapi is None:
+ self.dbapi_type_map = {}
+ else:
+ self.dbapi_type_map = {
+ 'Long Binary': MaxBlob(),
+ 'Long byte_t': MaxBlob(),
+ 'Long Unicode': MaxText(),
+ 'Timestamp': MaxTimestamp(),
+ 'Date': MaxDate(),
+ 'Time': MaxTime(),
+ datetime.datetime: MaxTimestamp(),
+ datetime.date: MaxDate(),
+ datetime.time: MaxTime(),
+ }
+
+ def do_execute(self, cursor, statement, parameters, context=None):
+ res = cursor.execute(statement, parameters)
+ if isinstance(res, int) and context is not None:
+ context._rowcount = res
+
+ def do_release_savepoint(self, connection, name):
+ # Does MaxDB truly support RELEASE SAVEPOINT <id>? All my attempts
+ # produce "SUBTRANS COMMIT/ROLLBACK not allowed without SUBTRANS
+ # BEGIN SQLSTATE: I7065"
+ # Note that ROLLBACK TO works fine. In theory, a RELEASE should
+ # just free up some transactional resources early, before the overall
+ # COMMIT/ROLLBACK so omitting it should be relatively ok.
+ pass
+
+ def get_default_schema_name(self, connection):
+ try:
+ return self._default_schema_name
+ except AttributeError:
+ name = self.identifier_preparer._normalize_name(
+ connection.execute('SELECT CURRENT_SCHEMA FROM DUAL').scalar())
+ self._default_schema_name = name
+ return name
+
+ def has_table(self, connection, table_name, schema=None):
+ denormalize = self.identifier_preparer._denormalize_name
+ bind = [denormalize(table_name)]
+ if schema is None:
+ sql = ("SELECT tablename FROM TABLES "
+ "WHERE TABLES.TABLENAME=? AND"
+ " TABLES.SCHEMANAME=CURRENT_SCHEMA ")
+ else:
+ sql = ("SELECT tablename FROM TABLES "
+ "WHERE TABLES.TABLENAME = ? AND"
+ " TABLES.SCHEMANAME=? ")
+ bind.append(denormalize(schema))
+
+ rp = connection.execute(sql, bind)
+ found = bool(rp.fetchone())
+ rp.close()
+ return found
+
+ def table_names(self, connection, schema):
+ if schema is None:
+ sql = (" SELECT TABLENAME FROM TABLES WHERE "
+ " SCHEMANAME=CURRENT_SCHEMA ")
+ rs = connection.execute(sql)
+ else:
+ sql = (" SELECT TABLENAME FROM TABLES WHERE "
+ " SCHEMANAME=? ")
+ matchname = self.identifier_preparer._denormalize_name(schema)
+ rs = connection.execute(sql, matchname)
+ normalize = self.identifier_preparer._normalize_name
+ return [normalize(row[0]) for row in rs]
+
+ def reflecttable(self, connection, table, include_columns):
+ denormalize = self.identifier_preparer._denormalize_name
+ normalize = self.identifier_preparer._normalize_name
+
+ st = ('SELECT COLUMNNAME, MODE, DATATYPE, CODETYPE, LEN, DEC, '
+ ' NULLABLE, "DEFAULT", DEFAULTFUNCTION '
+ 'FROM COLUMNS '
+ 'WHERE TABLENAME=? AND SCHEMANAME=%s '
+ 'ORDER BY POS')
+
+ fk = ('SELECT COLUMNNAME, FKEYNAME, '
+ ' REFSCHEMANAME, REFTABLENAME, REFCOLUMNNAME, RULE, '
+ ' (CASE WHEN REFSCHEMANAME = CURRENT_SCHEMA '
+ ' THEN 1 ELSE 0 END) AS in_schema '
+ 'FROM FOREIGNKEYCOLUMNS '
+ 'WHERE TABLENAME=? AND SCHEMANAME=%s '
+ 'ORDER BY FKEYNAME ')
+
+ params = [denormalize(table.name)]
+ if not table.schema:
+ st = st % 'CURRENT_SCHEMA'
+ fk = fk % 'CURRENT_SCHEMA'
+ else:
+ st = st % '?'
+ fk = fk % '?'
+ params.append(denormalize(table.schema))
+
+ rows = connection.execute(st, params).fetchall()
+ if not rows:
+ raise exc.NoSuchTableError(table.fullname)
+
+ include_columns = set(include_columns or [])
+
+ for row in rows:
+ (name, mode, col_type, encoding, length, scale,
+ nullable, constant_def, func_def) = row
+
+ name = normalize(name)
+
+ if include_columns and name not in include_columns:
+ continue
+
+ type_args, type_kw = [], {}
+ if col_type == 'FIXED':
+ type_args = length, scale
+ # Convert FIXED(10) DEFAULT SERIAL to our Integer
+ if (scale == 0 and
+ func_def is not None and func_def.startswith('SERIAL')):
+ col_type = 'INTEGER'
+ type_args = length,
+ elif col_type in 'FLOAT':
+ type_args = length,
+ elif col_type in ('CHAR', 'VARCHAR'):
+ type_args = length,
+ type_kw['encoding'] = encoding
+ elif col_type == 'LONG':
+ type_kw['encoding'] = encoding
+
+ try:
+ type_cls = ischema_names[col_type.lower()]
+ type_instance = type_cls(*type_args, **type_kw)
+ except KeyError:
+ util.warn("Did not recognize type '%s' of column '%s'" %
+ (col_type, name))
+ type_instance = sqltypes.NullType
+
+ col_kw = {'autoincrement': False}
+ col_kw['nullable'] = (nullable == 'YES')
+ col_kw['primary_key'] = (mode == 'KEY')
+
+ if func_def is not None:
+ if func_def.startswith('SERIAL'):
+ if col_kw['primary_key']:
+ # No special default- let the standard autoincrement
+ # support handle SERIAL pk columns.
+ col_kw['autoincrement'] = True
+ else:
+ # strip current numbering
+ col_kw['server_default'] = schema.DefaultClause(
+ sql.text('SERIAL'))
+ col_kw['autoincrement'] = True
+ else:
+ col_kw['server_default'] = schema.DefaultClause(
+ sql.text(func_def))
+ elif constant_def is not None:
+ col_kw['server_default'] = schema.DefaultClause(sql.text(
+ "'%s'" % constant_def.replace("'", "''")))
+
+ table.append_column(schema.Column(name, type_instance, **col_kw))
+
+ fk_sets = itertools.groupby(connection.execute(fk, params),
+ lambda row: row.FKEYNAME)
+ for fkeyname, fkey in fk_sets:
+ fkey = list(fkey)
+ if include_columns:
+ key_cols = set([r.COLUMNNAME for r in fkey])
+ if key_cols != include_columns:
+ continue
+
+ columns, referants = [], []
+ quote = self.identifier_preparer._maybe_quote_identifier
+
+ for row in fkey:
+ columns.append(normalize(row.COLUMNNAME))
+ if table.schema or not row.in_schema:
+ referants.append('.'.join(
+ [quote(normalize(row[c]))
+ for c in ('REFSCHEMANAME', 'REFTABLENAME',
+ 'REFCOLUMNNAME')]))
+ else:
+ referants.append('.'.join(
+ [quote(normalize(row[c]))
+ for c in ('REFTABLENAME', 'REFCOLUMNNAME')]))
+
+ constraint_kw = {'name': fkeyname.lower()}
+ if fkey[0].RULE is not None:
+ rule = fkey[0].RULE
+ if rule.startswith('DELETE '):
+ rule = rule[7:]
+ constraint_kw['ondelete'] = rule
+
+ table_kw = {}
+ if table.schema or not row.in_schema:
+ table_kw['schema'] = normalize(fkey[0].REFSCHEMANAME)
+
+ ref_key = schema._get_table_key(normalize(fkey[0].REFTABLENAME),
+ table_kw.get('schema'))
+ if ref_key not in table.metadata.tables:
+ schema.Table(normalize(fkey[0].REFTABLENAME),
+ table.metadata,
+ autoload=True, autoload_with=connection,
+ **table_kw)
+
+ constraint = schema.ForeignKeyConstraint(columns, referants, link_to_name=True,
+ **constraint_kw)
+ table.append_constraint(constraint)
+
+ def has_sequence(self, connection, name):
+ # [ticket:726] makes this schema-aware.
+ denormalize = self.identifier_preparer._denormalize_name
+ sql = ("SELECT sequence_name FROM SEQUENCES "
+ "WHERE SEQUENCE_NAME=? ")
+
+ rp = connection.execute(sql, denormalize(name))
+ found = bool(rp.fetchone())
+ rp.close()
+ return found
+
def _autoserial_column(table):
return None, None
-dialect = MaxDBDialect
-dialect.preparer = MaxDBIdentifierPreparer
-dialect.statement_compiler = MaxDBCompiler
-dialect.schemagenerator = MaxDBSchemaGenerator
-dialect.schemadropper = MaxDBSchemaDropper
-dialect.defaultrunner = MaxDBDefaultRunner
-dialect.execution_ctx_cls = MaxDBExecutionContext
\ No newline at end of file
--- /dev/null
+from sqlalchemy.dialects.maxdb.base import MaxDBDialect
+
+class MaxDB_sapdb(MaxDBDialect):
+ driver = 'sapdb'
+
+ @classmethod
+ def dbapi(cls):
+ from sapdb import dbapi as _dbapi
+ return _dbapi
+
+ def create_connect_args(self, url):
+ opts = url.translate_connect_args(username='user')
+ opts.update(url.query)
+ return [], opts
+
+
+dialect = MaxDB_sapdb
\ No newline at end of file
spec += ' ZEROFILL'
return spec
- def _extend_string(self, type_, spec):
+ def _extend_string(self, type_, defaults, spec):
"""Extend a string-type declaration with standard SQL CHARACTER SET /
COLLATE annotations and MySQL specific extensions.
"""
- if not self._mysql_type(type_):
- return spec
+
+ def attr(name):
+ return getattr(type_, name, defaults.get(name))
- if type_.charset:
- charset = 'CHARACTER SET %s' % type_.charset
- elif type_.ascii:
+ if attr('charset'):
+ charset = 'CHARACTER SET %s' % attr('charset')
+ elif attr('ascii'):
charset = 'ASCII'
- elif type_.unicode:
+ elif attr('unicode'):
charset = 'UNICODE'
else:
charset = None
- if type_.collation:
+ if attr('collation'):
collation = 'COLLATE %s' % type_.collation
- elif type_.binary:
+ elif attr('binary'):
collation = 'BINARY'
else:
collation = None
- if type_.national:
+ if attr('national'):
# NATIONAL (aka NCHAR/NVARCHAR) trumps charsets.
return ' '.join([c for c in ('NATIONAL', spec, collation)
if c is not None])
def visit_TEXT(self, type_):
if type_.length:
- return self._extend_string(type_, "TEXT(%d)" % type_.length)
+ return self._extend_string(type_, {}, "TEXT(%d)" % type_.length)
else:
- return self._extend_string(type_, "TEXT")
+ return self._extend_string(type_, {}, "TEXT")
def visit_TINYTEXT(self, type_):
- return self._extend_string(type_, "TINYTEXT")
+ return self._extend_string(type_, {}, "TINYTEXT")
def visit_MEDIUMTEXT(self, type_):
- return self._extend_string(type_, "MEDIUMTEXT")
+ return self._extend_string(type_, {}, "MEDIUMTEXT")
def visit_LONGTEXT(self, type_):
- return self._extend_string(type_, "LONGTEXT")
+ return self._extend_string(type_, {}, "LONGTEXT")
def visit_VARCHAR(self, type_):
if type_.length:
- return self._extend_string(type_, "VARCHAR(%d)" % type_.length)
+ return self._extend_string(type_, {}, "VARCHAR(%d)" % type_.length)
else:
- return self._extend_string(type_, "VARCHAR")
+ return self._extend_string(type_, {}, "VARCHAR")
def visit_CHAR(self, type_):
- return self._extend_string(type_, "CHAR(%(length)s)" % {'length' : type_.length})
+ return self._extend_string(type_, {'national':True}, "CHAR(%(length)s)" % {'length' : type_.length})
def visit_NVARCHAR(self, type_):
# We'll actually generate the equiv. "NATIONAL VARCHAR" instead
# of "NVARCHAR".
- return self._extend_string(type_, "VARCHAR(%(length)s)" % {'length': type_.length})
+ return self._extend_string(type_, {'national':True}, "VARCHAR(%(length)s)" % {'length': type_.length})
def visit_NCHAR(self, type_):
# We'll actually generate the equiv. "NATIONAL CHAR" instead of "NCHAR".
- return self._extend_string(type_, "CHAR(%(length)s)" % {'length': type_.length})
+ return self._extend_string(type_, {'national':True}, "CHAR(%(length)s)" % {'length': type_.length})
def visit_VARBINARY(self, type_):
if type_.length:
quoted_enums = []
for e in type_.enums:
quoted_enums.append("'%s'" % e.replace("'", "''"))
- return self._extend_string(type_, "ENUM(%s)" % ",".join(quoted_enums))
+ return self._extend_string(type_, {}, "ENUM(%s)" % ",".join(quoted_enums))
def visit_SET(self, type_):
- return self._extend_string(type_, "SET(%s)" % ",".join(type_._ddl_values))
+ return self._extend_string(type_, {}, "SET(%s)" % ",".join(type_._ddl_values))
def visit_BOOLEAN(self, type):
return "BOOL"
# Oracle does not allow milliseconds in DATE
# Oracle does not support TIME columns
- def visit_DATETIME(self, type_):
+ def visit_datetime(self, type_):
return self.visit_DATE(type_)
def visit_VARCHAR(self, type_):
def visit_NVARCHAR(self, type_):
return "NVARCHAR2(%(length)s)" % {'length' : type_.length}
- def visit_TEXT(self, type_):
+ def visit_text(self, type_):
return self.visit_CLOB(type_)
- def visit_BINARY(self, type_):
+ def visit_binary(self, type_):
return self.visit_BLOB(type_)
- def visit_BOOLEAN(self, type_):
+ def visit_boolean(self, type_):
return self.visit_SMALLINT(type_)
def visit_RAW(self, type_):
def visit_BIGINT(self, type_):
return "BIGINT"
- def visit_DATETIME(self, type_):
+ def visit_datetime(self, type_):
+ return self.visit_TIMESTAMP(type_)
+
+ def visit_TIMESTAMP(self, type_):
return "TIMESTAMP " + (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
def visit_TIME(self, type_):
def visit_binary(self, type_):
return self.visit_BLOB(type_)
- def visit_CLOB(self, type_):
- return self.visit_TEXT(type_)
-
- def visit_NCHAR(self, type_):
- return self.visit_CHAR(type_)
-
class SQLiteIdentifierPreparer(compiler.IdentifierPreparer):
reserved_words = set([
'add', 'after', 'all', 'alter', 'analyze', 'and', 'as', 'asc',
--- /dev/null
+from sqlalchemy.dialects.sybase import base, pyodbc
+
+# default dialect
+base.dialect = pyodbc.dialect
\ No newline at end of file
--- /dev/null
+# sybase.py
+# Copyright (C) 2007 Fisch Asset Management AG http://www.fam.ch
+# Coding: Alexander Houben alexander.houben@thor-solutions.ch
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+"""
+Sybase database backend.
+
+--- THIS BACKEND NOT YET TESTED ON SQLALCHEMY 0.6 ---
+
+Known issues / TODO:
+
+ * Uses the mx.ODBC driver from egenix (version 2.1.0)
+ * The current version of sqlalchemy.databases.sybase only supports
+ mx.ODBC.Windows (other platforms such as mx.ODBC.unixODBC still need
+ some development)
+ * Support for pyodbc has been built in but is not yet complete (needs
+ further development)
+ * Results of running tests/alltests.py:
+ Ran 934 tests in 287.032s
+ FAILED (failures=3, errors=1)
+ * Tested on 'Adaptive Server Anywhere 9' (version 9.0.1.1751)
+"""
+
+import datetime, operator
+
+from sqlalchemy import util, sql, schema, exc
+from sqlalchemy.sql import compiler, expression
+from sqlalchemy.engine import default, base
+from sqlalchemy import types as sqltypes
+from sqlalchemy.sql import operators as sql_operators
+from sqlalchemy import MetaData, Table, Column
+from sqlalchemy import String, Integer, SMALLINT, CHAR, ForeignKey
+from sqlalchemy.dialects.sybase.schema import *
+
+__all__ = [
+ 'SybaseMoney', 'SybaseSmallMoney',
+ 'SybaseUniqueIdentifier',
+ ]
+
+
+RESERVED_WORDS = set([
+ "add", "all", "alter", "and",
+ "any", "as", "asc", "backup",
+ "begin", "between", "bigint", "binary",
+ "bit", "bottom", "break", "by",
+ "call", "capability", "cascade", "case",
+ "cast", "char", "char_convert", "character",
+ "check", "checkpoint", "close", "comment",
+ "commit", "connect", "constraint", "contains",
+ "continue", "convert", "create", "cross",
+ "cube", "current", "current_timestamp", "current_user",
+ "cursor", "date", "dbspace", "deallocate",
+ "dec", "decimal", "declare", "default",
+ "delete", "deleting", "desc", "distinct",
+ "do", "double", "drop", "dynamic",
+ "else", "elseif", "encrypted", "end",
+ "endif", "escape", "except", "exception",
+ "exec", "execute", "existing", "exists",
+ "externlogin", "fetch", "first", "float",
+ "for", "force", "foreign", "forward",
+ "from", "full", "goto", "grant",
+ "group", "having", "holdlock", "identified",
+ "if", "in", "index", "index_lparen",
+ "inner", "inout", "insensitive", "insert",
+ "inserting", "install", "instead", "int",
+ "integer", "integrated", "intersect", "into",
+ "iq", "is", "isolation", "join",
+ "key", "lateral", "left", "like",
+ "lock", "login", "long", "match",
+ "membership", "message", "mode", "modify",
+ "natural", "new", "no", "noholdlock",
+ "not", "notify", "null", "numeric",
+ "of", "off", "on", "open",
+ "option", "options", "or", "order",
+ "others", "out", "outer", "over",
+ "passthrough", "precision", "prepare", "primary",
+ "print", "privileges", "proc", "procedure",
+ "publication", "raiserror", "readtext", "real",
+ "reference", "references", "release", "remote",
+ "remove", "rename", "reorganize", "resource",
+ "restore", "restrict", "return", "revoke",
+ "right", "rollback", "rollup", "save",
+ "savepoint", "scroll", "select", "sensitive",
+ "session", "set", "setuser", "share",
+ "smallint", "some", "sqlcode", "sqlstate",
+ "start", "stop", "subtrans", "subtransaction",
+ "synchronize", "syntax_error", "table", "temporary",
+ "then", "time", "timestamp", "tinyint",
+ "to", "top", "tran", "trigger",
+ "truncate", "tsequal", "unbounded", "union",
+ "unique", "unknown", "unsigned", "update",
+ "updating", "user", "using", "validate",
+ "values", "varbinary", "varchar", "variable",
+ "varying", "view", "wait", "waitfor",
+ "when", "where", "while", "window",
+ "with", "with_cube", "with_lparen", "with_rollup",
+ "within", "work", "writetext",
+ ])
+
+
+class SybaseImage(sqltypes.Binary):
+ __visit_name__ = 'IMAGE'
+
+class SybaseBit(sqltypes.TypeEngine):
+ __visit_name__ = 'BIT'
+
+class SybaseMoney(sqltypes.TypeEngine):
+ __visit_name__ = "MONEY"
+
+class SybaseSmallMoney(SybaseMoney):
+ __visit_name__ = "SMALLMONEY"
+
+class SybaseUniqueIdentifier(sqltypes.TypeEngine):
+ __visit_name__ = "UNIQUEIDENTIFIER"
+
+class SybaseBoolean(sqltypes.Boolean):
+ def result_processor(self, dialect):
+ def process(value):
+ if value is None:
+ return None
+ return value and True or False
+ return process
+
+ def bind_processor(self, dialect):
+ def process(value):
+ if value is True:
+ return 1
+ elif value is False:
+ return 0
+ elif value is None:
+ return None
+ else:
+ return value and True or False
+ return process
+
+class SybaseTypeCompiler(compiler.GenericTypeCompiler):
+ def visit_binary(self, type_):
+ return self.visit_IMAGE(type_)
+
+ def visit_boolean(self, type_):
+ return self.visit_BIT(type_)
+
+ def visit_IMAGE(self, type_):
+ return "IMAGE"
+
+ def visit_BIT(self, type_):
+ return "BIT"
+
+ def visit_MONEY(self, type_):
+ return "MONEY"
+
+ def visit_SMALLMONEY(self, type_):
+ return "SMALLMONEY"
+
+ def visit_UNIQUEIDENTIFIER(self, type_):
+ return "UNIQUEIDENTIFIER"
+
+colspecs = {
+ sqltypes.Binary : SybaseImage,
+ sqltypes.Boolean : SybaseBoolean,
+}
+
+ischema_names = {
+ 'integer' : sqltypes.INTEGER,
+ 'unsigned int' : sqltypes.Integer,
+ 'unsigned smallint' : sqltypes.SmallInteger,
+ 'unsigned bigint' : sqltypes.BigInteger,
+ 'bigint': sqltypes.BIGINT,
+ 'smallint' : sqltypes.SMALLINT,
+ 'tinyint' : sqltypes.SmallInteger,
+ 'varchar' : sqltypes.VARCHAR,
+ 'long varchar' : sqltypes.Text,
+ 'char' : sqltypes.CHAR,
+ 'decimal' : sqltypes.DECIMAL,
+ 'numeric' : sqltypes.NUMERIC,
+ 'float' : sqltypes.FLOAT,
+ 'double' : sqltypes.Numeric,
+ 'binary' : sqltypes.Binary,
+ 'long binary' : sqltypes.Binary,
+ 'varbinary' : sqltypes.Binary,
+ 'bit': SybaseBit,
+ 'image' : SybaseImage,
+ 'timestamp': sqltypes.TIMESTAMP,
+ 'money': SybaseMoney,
+ 'smallmoney': SybaseSmallMoney,
+ 'uniqueidentifier': SybaseUniqueIdentifier,
+
+}
+
+
+class SybaseExecutionContext(default.DefaultExecutionContext):
+
+ def post_exec(self):
+ if self.compiled.isinsert:
+ table = self.compiled.statement.table
+ # get the inserted values of the primary key
+
+ # get any sequence IDs first (using @@identity)
+ self.cursor.execute("SELECT @@identity AS lastrowid")
+ row = self.cursor.fetchone()
+ lastrowid = int(row[0])
+ if lastrowid > 0:
+ # an IDENTITY was inserted, fetch it
+ # FIXME: always insert in front ? This only works if the IDENTITY is the first column, no ?!
+ if not hasattr(self, '_last_inserted_ids') or self._last_inserted_ids is None:
+ self._last_inserted_ids = [lastrowid]
+ else:
+ self._last_inserted_ids = [lastrowid] + self._last_inserted_ids[1:]
+
+
+class SybaseSQLCompiler(compiler.SQLCompiler):
+ operators = compiler.SQLCompiler.operators.copy()
+ operators.update({
+ sql_operators.mod: lambda x, y: "MOD(%s, %s)" % (x, y),
+ })
+
+ def bindparam_string(self, name):
+ res = super(SybaseSQLCompiler, self).bindparam_string(name)
+ if name.lower().startswith('literal'):
+ res = 'STRING(%s)' % res
+ return res
+
+ def get_select_precolumns(self, select):
+ s = select._distinct and "DISTINCT " or ""
+ if select._limit:
+ #if select._limit == 1:
+ #s += "FIRST "
+ #else:
+ #s += "TOP %s " % (select._limit,)
+ s += "TOP %s " % (select._limit,)
+ if select._offset:
+ if not select._limit:
+ # FIXME: sybase doesn't allow an offset without a limit
+ # so use a huge value for TOP here
+ s += "TOP 1000000 "
+ s += "START AT %s " % (select._offset+1,)
+ return s
+
+ def limit_clause(self, select):
+ # Limit in sybase is after the select keyword
+ return ""
+
+ def visit_binary(self, binary):
+ """Move bind parameters to the right-hand side of an operator, where possible."""
+ if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq:
+ return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator))
+ else:
+ return super(SybaseSQLCompiler, self).visit_binary(binary)
+
+ def label_select_column(self, select, column, asfrom):
+ if isinstance(column, expression.Function):
+ return column.label(None)
+ else:
+ return super(SybaseSQLCompiler, self).label_select_column(select, column, asfrom)
+
+ function_rewrites = {'current_date': 'getdate',
+ }
+ def visit_function(self, func):
+ func.name = self.function_rewrites.get(func.name, func.name)
+ res = super(SybaseSQLCompiler, self).visit_function(func)
+ if func.name.lower() == 'getdate':
+ # apply CAST operator
+ # FIXME: what about _pyodbc ?
+ cast = expression._Cast(func, SybaseDate_mxodbc)
+ # infinite recursion
+ # res = self.visit_cast(cast)
+ res = "CAST(%s AS %s)" % (res, self.process(cast.typeclause))
+ return res
+
+ def for_update_clause(self, select):
+ # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use
+ return ''
+
+ def order_by_clause(self, select):
+ order_by = self.process(select._order_by_clause)
+
+ # SybaseSQL only allows ORDER BY in subqueries if there is a LIMIT
+ if order_by and (not self.is_subquery() or select._limit):
+ return " ORDER BY " + order_by
+ else:
+ return ""
+
+
+class SybaseDDLCompiler(compiler.DDLCompiler):
+ def get_column_specification(self, column, **kwargs):
+
+ colspec = self.preparer.format_column(column)
+
+ if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \
+ column.autoincrement and isinstance(column.type, sqltypes.Integer):
+ if column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional):
+ column.sequence = schema.Sequence(column.name + '_seq')
+
+ if hasattr(column, 'sequence'):
+ column.table.has_sequence = column
+ #colspec += " numeric(30,0) IDENTITY"
+ colspec += " Integer IDENTITY"
+ else:
+ colspec += " " + self.dialect.type_compiler.process(column.type)
+
+ if not column.nullable:
+ colspec += " NOT NULL"
+
+ default = self.get_column_default_string(column)
+ if default is not None:
+ colspec += " DEFAULT " + default
+
+ return colspec
+
+ def visit_drop_index(self, drop):
+ index = drop.element
+ return "\nDROP INDEX %s.%s" % (
+ self.preparer.quote_identifier(index.table.name),
+ self.preparer.quote(self._validate_identifier(index.name, False), index.quote)
+ )
+
+class SybaseIdentifierPreparer(compiler.IdentifierPreparer):
+ reserved_words = RESERVED_WORDS
+
+class SybaseDialect(default.DefaultDialect):
+ name = 'sybase'
+ supports_unicode_statements = False
+ supports_sane_rowcount = False
+ supports_sane_multi_rowcount = False
+ colspecs = colspecs
+ ischema_names = ischema_names
+
+ type_compiler = SybaseTypeCompiler
+ statement_compiler = SybaseSQLCompiler
+ ddl_compiler = SybaseDDLCompiler
+ preparer = SybaseIdentifierPreparer
+
+ schema_name = "dba"
+
+ def __init__(self, **params):
+ super(SybaseDialect, self).__init__(**params)
+ self.text_as_varchar = False
+
+ def last_inserted_ids(self):
+ return self.context.last_inserted_ids
+
+ def get_default_schema_name(self, connection):
+ return self.schema_name
+
+ def table_names(self, connection, schema):
+ """Ignore the schema and the charset for now."""
+ s = sql.select([tables.c.table_name],
+ sql.not_(tables.c.table_name.like("SYS%")) and
+ tables.c.creator >= 100
+ )
+ rp = connection.execute(s)
+ return [row[0] for row in rp.fetchall()]
+
+ def has_table(self, connection, tablename, schema=None):
+ # FIXME: ignore schemas for sybase
+ s = sql.select([tables.c.table_name], tables.c.table_name == tablename)
+
+ c = connection.execute(s)
+ row = c.fetchone()
+ print "has_table: " + tablename + ": " + str(bool(row is not None))
+ return row is not None
+
+ def reflecttable(self, connection, table, include_columns):
+ # Get base columns
+ if table.schema is not None:
+ current_schema = table.schema
+ else:
+ current_schema = self.get_default_schema_name(connection)
+
+ s = sql.select([columns, domains], tables.c.table_name==table.name, from_obj=[columns.join(tables).join(domains)], order_by=[columns.c.column_id])
+
+ c = connection.execute(s)
+ found_table = False
+ # makes sure we append the columns in the correct order
+ while True:
+ row = c.fetchone()
+ if row is None:
+ break
+ found_table = True
+ (name, type, nullable, charlen, numericprec, numericscale, default, primary_key, max_identity, table_id, column_id) = (
+ row[columns.c.column_name],
+ row[domains.c.domain_name],
+ row[columns.c.nulls] == 'Y',
+ row[columns.c.width],
+ row[domains.c.precision],
+ row[columns.c.scale],
+ row[columns.c.default],
+ row[columns.c.pkey] == 'Y',
+ row[columns.c.max_identity],
+ row[tables.c.table_id],
+ row[columns.c.column_id],
+ )
+ if include_columns and name not in include_columns:
+ continue
+
+ # FIXME: else problems with SybaseBinary(size)
+ if numericscale == 0:
+ numericscale = None
+
+ args = []
+ for a in (charlen, numericprec, numericscale):
+ if a is not None:
+ args.append(a)
+ coltype = self.ischema_names.get(type, None)
+ if coltype == SybaseString and charlen == -1:
+ coltype = SybaseText()
+ else:
+ if coltype is None:
+ util.warn("Did not recognize type '%s' of column '%s'" %
+ (type, name))
+ coltype = sqltypes.NULLTYPE
+ coltype = coltype(*args)
+ colargs = []
+ if default is not None:
+ colargs.append(schema.DefaultClause(sql.text(default)))
+
+ # any sequences ?
+ col = schema.Column(name, coltype, nullable=nullable, primary_key=primary_key, *colargs)
+ if int(max_identity) > 0:
+ col.sequence = schema.Sequence(name + '_identity')
+ col.sequence.start = int(max_identity)
+ col.sequence.increment = 1
+
+ # append the column
+ table.append_column(col)
+
+ # any foreign key constraint for this table ?
+ # note: no multi-column foreign keys are considered
+ s = "select st1.table_name, sc1.column_name, st2.table_name, sc2.column_name from systable as st1 join sysfkcol on st1.table_id=sysfkcol.foreign_table_id join sysforeignkey join systable as st2 on sysforeignkey.primary_table_id = st2.table_id join syscolumn as sc1 on sysfkcol.foreign_column_id=sc1.column_id and sc1.table_id=st1.table_id join syscolumn as sc2 on sysfkcol.primary_column_id=sc2.column_id and sc2.table_id=st2.table_id where st1.table_name='%(table_name)s';" % { 'table_name' : table.name }
+ c = connection.execute(s)
+ foreignKeys = {}
+ while True:
+ row = c.fetchone()
+ if row is None:
+ break
+ (foreign_table, foreign_column, primary_table, primary_column) = (
+ row[0], row[1], row[2], row[3],
+ )
+ if not primary_table in foreignKeys.keys():
+ foreignKeys[primary_table] = [['%s' % (foreign_column)], ['%s.%s'%(primary_table, primary_column)]]
+ else:
+ foreignKeys[primary_table][0].append('%s'%(foreign_column))
+ foreignKeys[primary_table][1].append('%s.%s'%(primary_table, primary_column))
+ for primary_table in foreignKeys.keys():
+ #table.append_constraint(schema.ForeignKeyConstraint(['%s.%s'%(foreign_table, foreign_column)], ['%s.%s'%(primary_table,primary_column)]))
+ table.append_constraint(schema.ForeignKeyConstraint(foreignKeys[primary_table][0], foreignKeys[primary_table][1], link_to_name=True))
+
+ if not found_table:
+ raise exc.NoSuchTableError(table.name)
+
--- /dev/null
+from sqlalchemy.dialects.sybase.base import SybaseDialect, SybaseExecutionContext
+from sqlalchemy.connectors.mxodbc import MxODBCConnector
+
+class SybaseExecutionContext_mxodbc(SybaseExecutionContext):
+ pass
+
+class Sybase_mxodbc(MxODBCConnector, SybaseDialect):
+ execution_ctx_cls = SybaseExecutionContext_mxodbc
+
+dialect = Sybase_mxodbc
\ No newline at end of file
--- /dev/null
+from sqlalchemy.dialects.sybase.base import SybaseDialect, SybaseExecutionContext
+from sqlalchemy.connectors.pyodbc import PyODBCConnector
+
+class SybaseExecutionContext_pyodbc(SybaseExecutionContext):
+ pass
+
+
+class Sybase_pyodbc(PyODBCConnector, SybaseDialect):
+ execution_ctx_cls = SybaseExecutionContext_pyodbc
+
+dialect = Sybase_pyodbc
\ No newline at end of file
--- /dev/null
+from sqlalchemy import *
+
+ischema = MetaData()
+
+tables = Table("SYSTABLE", ischema,
+ Column("table_id", Integer, primary_key=True),
+ Column("file_id", SMALLINT),
+ Column("table_name", CHAR(128)),
+ Column("table_type", CHAR(10)),
+ Column("creator", Integer),
+ #schema="information_schema"
+ )
+
+domains = Table("SYSDOMAIN", ischema,
+ Column("domain_id", Integer, primary_key=True),
+ Column("domain_name", CHAR(128)),
+ Column("type_id", SMALLINT),
+ Column("precision", SMALLINT, quote=True),
+ #schema="information_schema"
+ )
+
+columns = Table("SYSCOLUMN", ischema,
+ Column("column_id", Integer, primary_key=True),
+ Column("table_id", Integer, ForeignKey(tables.c.table_id)),
+ Column("pkey", CHAR(1)),
+ Column("column_name", CHAR(128)),
+ Column("nulls", CHAR(1)),
+ Column("width", SMALLINT),
+ Column("domain_id", SMALLINT, ForeignKey(domains.c.domain_id)),
+ # FIXME: should be mx.BIGINT
+ Column("max_identity", Integer),
+ # FIXME: should be mx.ODBC.Windows.LONGVARCHAR
+ Column("default", String),
+ Column("scale", Integer),
+ #schema="information_schema"
+ )
+
+foreignkeys = Table("SYSFOREIGNKEY", ischema,
+ Column("foreign_table_id", Integer, ForeignKey(tables.c.table_id), primary_key=True),
+ Column("foreign_key_id", SMALLINT, primary_key=True),
+ Column("primary_table_id", Integer, ForeignKey(tables.c.table_id)),
+ #schema="information_schema"
+ )
+fkcols = Table("SYSFKCOL", ischema,
+ Column("foreign_table_id", Integer, ForeignKey(columns.c.table_id), primary_key=True),
+ Column("foreign_key_id", SMALLINT, ForeignKey(foreignkeys.c.foreign_key_id), primary_key=True),
+ Column("foreign_column_id", Integer, ForeignKey(columns.c.column_id), primary_key=True),
+ Column("primary_column_id", Integer),
+ #schema="information_schema"
+ )
+
def visit_SMALLINT(self, type_):
return "SMALLINT"
+ def visit_BIGINT(self, type_):
+ return "BIGINT"
+
def visit_TIMESTAMP(self, type_):
return 'TIMESTAMP'
def visit_BLOB(self, type_):
return "BLOB"
- def visit_BINARY(self, type_):
- return "BINARY"
-
def visit_BOOLEAN(self, type_):
return "BOOLEAN"
return "TEXT"
def visit_binary(self, type_):
- return self.visit_BINARY(type_)
+ return self.visit_BLOB(type_)
+
def visit_boolean(self, type_):
return self.visit_BOOLEAN(type_)
+
def visit_time(self, type_):
return self.visit_TIME(type_)
+
def visit_datetime(self, type_):
return self.visit_DATETIME(type_)
+
def visit_date(self, type_):
return self.visit_DATE(type_)
+
def visit_small_integer(self, type_):
return self.visit_SMALLINT(type_)
+
def visit_integer(self, type_):
return self.visit_INTEGER(type_)
+
def visit_float(self, type_):
return self.visit_FLOAT(type_)
+
def visit_numeric(self, type_):
return self.visit_NUMERIC(type_)
+
def visit_string(self, type_):
return self.visit_VARCHAR(type_)
+
def visit_text(self, type_):
return self.visit_TEXT(type_)
"""
__all__ = [ 'TypeEngine', 'TypeDecorator', 'AbstractType',
- 'INT', 'CHAR', 'VARCHAR', 'NCHAR', 'TEXT', 'Text', 'FLOAT',
+ 'INT', 'CHAR', 'VARCHAR', 'NCHAR', 'NVARCHAR','TEXT', 'Text', 'FLOAT',
'NUMERIC', 'DECIMAL', 'TIMESTAMP', 'DATETIME', 'CLOB', 'BLOB',
- 'BOOLEAN', 'SMALLINT', 'DATE', 'TIME',
+ 'BOOLEAN', 'SMALLINT', 'INTEGER','DATE', 'TIME',
'String', 'Integer', 'SmallInteger',
'Numeric', 'Float', 'DateTime', 'Date', 'Time', 'Binary',
'Boolean', 'Unicode', 'MutableType', 'Concatenable', 'UnicodeText', 'PickleType', 'Interval',
def __init__(self, *args, **kwargs):
pass
+ def compile(self, dialect):
+ return dialect.type_compiler.process(self)
+
def copy_value(self, value):
return value
__visit_name__ = 'small_integer'
+class BigInteger(Integer):
+ """A type for bigger ``int`` integers.
+
+ Typically generates a ``BIGINT`` in DDL, and otherwise acts like
+ a normal :class:`Integer` on the Python side.
+
+ """
+
+ __visit_name__ = 'big_integer'
+
class Numeric(TypeEngine):
"""A type for fixed precision numbers.
__visit_name__ = 'SMALLINT'
+class BIGINT(SmallInteger):
+ """The SQL BIGINT type."""
+
+ __visit_name__ = 'BIGINT'
+
class TIMESTAMP(DateTime):
"""The SQL TIMESTAMP type."""
self.assertEqual(str(cast(tbl.c.v1, Numeric).compile(dialect=dialect)), 'CAST(casttest.v1 AS %s)' %expected_results[0])
self.assertEqual(str(cast(tbl.c.v1, Numeric(12, 9)).compile(dialect=dialect)), 'CAST(casttest.v1 AS %s)' %expected_results[1])
self.assertEqual(str(cast(tbl.c.ts, Date).compile(dialect=dialect)), 'CAST(casttest.ts AS %s)' %expected_results[2])
- self.assertEqual(str(cast(1234, TEXT).compile(dialect=dialect)), 'CAST(%s AS %s)' %(literal, expected_results[3]))
+ self.assertEqual(str(cast(1234, Text).compile(dialect=dialect)), 'CAST(%s AS %s)' %(literal, expected_results[3]))
self.assertEqual(str(cast('test', String(20)).compile(dialect=dialect)), 'CAST(%s AS %s)' %(literal, expected_results[4]))
# fixme: shoving all of this dialect-specific stuff in one test
# is now officialy completely ridiculous AND non-obviously omits
firebird_dialect = firebird.FBDialect()
for dialect, start, test in [
- (oracle_dialect, String(), oracle.OracleString),
- (oracle_dialect, VARCHAR(), oracle.OracleString),
- (oracle_dialect, String(50), oracle.OracleString),
- (oracle_dialect, Unicode(), oracle.OracleString),
+ (oracle_dialect, String(), String),
+ (oracle_dialect, VARCHAR(), VARCHAR),
+ (oracle_dialect, String(50), String),
+ (oracle_dialect, Unicode(), Unicode),
(oracle_dialect, UnicodeText(), oracle.OracleText),
- (oracle_dialect, NCHAR(), oracle.OracleString),
+ (oracle_dialect, NCHAR(), NCHAR),
(oracle_dialect, oracle.OracleRaw(50), oracle.OracleRaw),
(mysql_dialect, String(), mysql.MSString),
(mysql_dialect, VARCHAR(), mysql.MSString),
]:
assert isinstance(start.dialect_impl(dialect), test), "wanted %r got %r" % (test, start.dialect_impl(dialect))
-
+ def test_uppercase_rendering(self):
+ """Test that uppercase types from types.py always render as their type.
+
+ As of SQLA 0.6, using an uppercase type means you want specifically that
+ type. If the database in use doesn't support that DDL, it (the DB backend)
+ should raise an error - it means you should be using a lowercased (genericized) type.
+
+ """
+
+ for dialect in [oracle.dialect(), mysql.dialect(), postgres.dialect(), sqlite.dialect(), sybase.dialect(), informix.dialect(), maxdb.dialect()]: #engines.all_dialects():
+ for type_, expected in (
+ (FLOAT, "FLOAT"),
+ (NUMERIC, "NUMERIC"),
+ (DECIMAL, "DECIMAL"),
+ (INTEGER, "INTEGER"),
+ (SMALLINT, "SMALLINT"),
+ (TIMESTAMP, "TIMESTAMP"),
+ (DATETIME, "DATETIME"),
+ (DATE, "DATE"),
+ (TIME, "TIME"),
+ (CLOB, "CLOB"),
+ (VARCHAR, "VARCHAR"),
+ (NVARCHAR, ("NVARCHAR", "NATIONAL VARCHAR")),
+ (CHAR, "CHAR"),
+ (NCHAR, ("NCHAR", "NATIONAL CHAR")),
+ (BLOB, "BLOB"),
+ (BOOLEAN, ("BOOLEAN", "BOOL"))
+ ):
+ if isinstance(expected, str):
+ expected = (expected, )
+ for exp in expected:
+ compiled = type_().compile(dialect=dialect)
+ if exp in compiled:
+ break
+ else:
+ assert False, "%r matches none of %r for dialect %s" % (compiled, expected, dialect.name)
+
class UserDefinedTest(TestBase):
"""tests user-defined types."""