From: Mike Bayer Date: Sun, 12 Jul 2009 01:23:03 +0000 (+0000) Subject: - firebird support. reflection works fully, overall test success in the 75% range... X-Git-Tag: rel_0_6_6~141 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=faa6c79614d5f3a3e0fe074d8c1999357b6499f5;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - firebird support. reflection works fully, overall test success in the 75% range approx - oracle and firebird now normalize column names to SQLA "lowercase" for result.keys() allowing consistent identifier name visibility on the client side. --- diff --git a/06CHANGES b/06CHANGES index 1742cfc2f7..413d2972a1 100644 --- a/06CHANGES +++ b/06CHANGES @@ -90,11 +90,26 @@ - func.char_length is a generic function for LENGTH - ForeignKey() which includes onupdate= will emit a warning, not emit ON UPDATE CASCADE which is unsupported by oracle + - the keys() method of RowProxy() now returns the result column names *normalized* + to be SQLAlchemy case insensitive names. This means they will be lower case + for case insensitive names, whereas the DBAPI would normally return them + as UPPERCASE names. This allows row keys() to be compatible with further + SQLAlchemy operations. + +- firebird + - the keys() method of RowProxy() now returns the result column names *normalized* + to be SQLAlchemy case insensitive names. This means they will be lower case + for case insensitive names, whereas the DBAPI would normally return them + as UPPERCASE names. This allows row keys() to be compatible with further + SQLAlchemy operations. - new dialects - - pg8000 - - pyodbc+mysql - + - postgres+pg8000 + - postgres+pypostgresql (partial) + - postgres+zxjdbc + - mysql+pyodbc + - mysql+zxjdbc + - mssql - the "has_window_funcs" flag is removed. LIMIT/OFFSET usage will use ROW NUMBER as always, and if on an older version of SQL Server, the operation fails. The behavior is exactly diff --git a/lib/sqlalchemy/dialects/firebird/__init__.py b/lib/sqlalchemy/dialects/firebird/__init__.py index e69de29bb2..6b1b80db21 100644 --- a/lib/sqlalchemy/dialects/firebird/__init__.py +++ b/lib/sqlalchemy/dialects/firebird/__init__.py @@ -0,0 +1,3 @@ +from sqlalchemy.dialects.firebird import base, kinterbasdb + +base.dialect = kinterbasdb.dialect \ No newline at end of file diff --git a/lib/sqlalchemy/dialects/firebird/base.py b/lib/sqlalchemy/dialects/firebird/base.py index a616ee0f69..41bf4f8aa8 100644 --- a/lib/sqlalchemy/dialects/firebird/base.py +++ b/lib/sqlalchemy/dialects/firebird/base.py @@ -8,7 +8,9 @@ Firebird backend ================ -This module implements the Firebird backend, thru the kinterbasdb_ +This module implements the Firebird backend. + +Connectivity is usually supplied via the kinterbasdb_ DBAPI module. Firebird dialects @@ -22,61 +24,32 @@ dialect 1 dialect 3 This is the newer and supported syntax, introduced in Interbase 6.0. - -From the user point of view, the biggest change is in date/time -handling: under dialect 1, there's a single kind of field, ``DATE`` -with a synonim ``DATETIME``, that holds a `timestamp` value, that is a -date with hour, minute, second. Under dialect 3 there are three kinds, -a ``DATE`` that holds a date, a ``TIME`` that holds a *time of the -day* value and a ``TIMESTAMP``, equivalent to the old ``DATE``. - -The problem is that the dialect of a Firebird database is a property -of the database itself [#]_ (that is, any single database has been -created with one dialect or the other: there is no way to change the -after creation). SQLAlchemy has a single instance of the class that -controls all the connections to a particular kind of database, so it -cannot easily differentiate between the two modes, and in particular -it **cannot** simultaneously talk with two distinct Firebird databases -with different dialects. - -By default this module is biased toward dialect 3, but you can easily -tweak it to handle dialect 1 if needed:: - - from sqlalchemy import types as sqltypes - from sqlalchemy.databases.firebird import FBDate, colspecs, ischema_names - - # Adjust the mapping of the timestamp kind - ischema_names['TIMESTAMP'] = FBDate - colspecs[sqltypes.DateTime] = FBDate, - -Other aspects may be version-specific. You can use the ``server_version_info()`` method -on the ``FBDialect`` class to do whatever is needed:: - - from sqlalchemy.databases.firebird import FBCompiler - - if engine.dialect.server_version_info(connection) < (2,0): - # Change the name of the function ``length`` to use the UDF version - # instead of ``char_length`` - FBCompiler.LENGTH_FUNCTION_NAME = 'strlen' - -Pooling connections -------------------- - -The default strategy used by SQLAlchemy to pool the database connections -in particular cases may raise an ``OperationalError`` with a message -`"object XYZ is in use"`. This happens on Firebird when there are two -connections to the database, one is using, or has used, a particular table -and the other tries to drop or alter the same table. To garantee DDL -operations success Firebird recommend doing them as the single connected user. - -In case your SA application effectively needs to do DDL operations while other -connections are active, the following setting may alleviate the problem:: - - from sqlalchemy import pool - from sqlalchemy.databases.firebird import dialect - - # Force SA to use a single connection per thread - dialect.poolclass = pool.SingletonThreadPool + +The SQLAlchemy Firebird dialect detects these versions and +adjusts its representation of SQL accordingly. However, +support for dialect 1 is not well tested and probably has +incompatibilities. + +Firebird Locking Behavior +------------------------- + +Firebird locks tables aggressively. For this reason, a DROP TABLE +may hang until other transactions are released. SQLAlchemy +does its best to release transactions as quickly as possible. The +most common cause of hanging transactions is a non-fully consumed +result set, i.e.:: + + result = engine.execute("select * from table") + row = result.fetchone() + return + +Where above, the ``ResultProxy`` has not been fully consumed. The +connection will be returned to the pool and the transactional state +rolled back once the Python garbage collector reclaims the +objects which hold onto the connection, which often occurs asynchronously. +The above use case can be alleviated by calling ``first()`` on the +``ResultProxy`` which will fetch the first row and immediately close +all remaining cursor/connection resources. RETURNING support ----------------- @@ -96,6 +69,7 @@ parameter when creating the queries:: .. _dialects: http://mc-computing.com/Databases/Firebird/SQL_Dialect.html .. _kinterbasdb: http://sourceforge.net/projects/kinterbasdb + """ @@ -103,265 +77,293 @@ import datetime, decimal, re from sqlalchemy import schema as sa_schema from sqlalchemy import exc, types as sqltypes, sql, util +from sqlalchemy.sql import expression from sqlalchemy.engine import base, default, reflection +from sqlalchemy.sql import compiler +from sqlalchemy.types import (BIGINT, BLOB, BOOLEAN, CHAR, DATE, + FLOAT, INTEGER, NUMERIC, SMALLINT, + TEXT, TIME, TIMESTAMP, VARCHAR) -_initialized_kb = False +RESERVED_WORDS = set( + ["action", "active", "add", "admin", "after", "all", "alter", "and", "any", + "as", "asc", "ascending", "at", "auto", "autoddl", "avg", "based", "basename", + "base_name", "before", "begin", "between", "bigint", "blob", "blobedit", "buffer", + "by", "cache", "cascade", "case", "cast", "char", "character", "character_length", + "char_length", "check", "check_point_len", "check_point_length", "close", "collate", + "collation", "column", "commit", "committed", "compiletime", "computed", "conditional", + "connect", "constraint", "containing", "continue", "count", "create", "cstring", + "current", "current_connection", "current_date", "current_role", "current_time", + "current_timestamp", "current_transaction", "current_user", "cursor", "database", + "date", "day", "db_key", "debug", "dec", "decimal", "declare", "default", "delete", + "desc", "descending", "describe", "descriptor", "disconnect", "display", "distinct", + "do", "domain", "double", "drop", "echo", "edit", "else", "end", "entry_point", + "escape", "event", "exception", "execute", "exists", "exit", "extern", "external", + "extract", "fetch", "file", "filter", "float", "for", "foreign", "found", "free_it", + "from", "full", "function", "gdscode", "generator", "gen_id", "global", "goto", + "grant", "group", "group_commit_", "group_commit_wait", "having", "help", "hour", + "if", "immediate", "in", "inactive", "index", "indicator", "init", "inner", "input", + "input_type", "insert", "int", "integer", "into", "is", "isolation", "isql", "join", + "key", "lc_messages", "lc_type", "left", "length", "lev", "level", "like", "logfile", + "log_buffer_size", "log_buf_size", "long", "manual", "max", "maximum", "maximum_segment", + "max_segment", "merge", "message", "min", "minimum", "minute", "module_name", "month", + "names", "national", "natural", "nchar", "no", "noauto", "not", "null", "numeric", + "num_log_buffers", "num_log_bufs", "octet_length", "of", "on", "only", "open", "option", + "or", "order", "outer", "output", "output_type", "overflow", "page", "pagelength", + "pages", "page_size", "parameter", "password", "plan", "position", "post_event", + "precision", "prepare", "primary", "privileges", "procedure", "protected", "public", + "quit", "raw_partitions", "rdb$db_key", "read", "real", "record_version", "recreate", + "references", "release", "release", "reserv", "reserving", "restrict", "retain", + "return", "returning_values", "returns", "revoke", "right", "role", "rollback", + "row_count", "runtime", "savepoint", "schema", "second", "segment", "select", + "set", "shadow", "shared", "shell", "show", "singular", "size", "smallint", + "snapshot", "some", "sort", "sqlcode", "sqlerror", "sqlwarning", "stability", + "starting", "starts", "statement", "static", "statistics", "sub_type", "sum", + "suspend", "table", "terminator", "then", "time", "timestamp", "to", "transaction", + "translate", "translation", "trigger", "trim", "type", "uncommitted", "union", + "unique", "update", "upper", "user", "using", "value", "values", "varchar", + "variable", "varying", "version", "view", "wait", "wait_time", "weekday", "when", + "whenever", "where", "while", "with", "work", "write", "year", "yearday" ]) +colspecs = { +} -class FBNumeric(sqltypes.Numeric): - """Handle ``NUMERIC(precision,scale)`` datatype.""" +ischema_names = { + 'SHORT': SMALLINT, + 'LONG': BIGINT, + 'QUAD': FLOAT, + 'FLOAT': FLOAT, + 'DATE': DATE, + 'TIME': TIME, + 'TEXT': TEXT, + 'INT64': NUMERIC, + 'DOUBLE': FLOAT, + 'TIMESTAMP': TIMESTAMP, + 'VARYING': VARCHAR, + 'CSTRING': CHAR, + 'BLOB': BLOB, + } + + +# TODO: Boolean type, date conversion types (should be implemented as _FBDateTime, _FBDate, etc. +# as bind/result functionality is required) + + +class FBTypeCompiler(compiler.GenericTypeCompiler): + def visit_datetime(self, type_): + return self.visit_TIMESTAMP(type_) + + def visit_TEXT(self, type_): + return "BLOB SUB_TYPE 1" + + def visit_BLOB(self, type_): + return "BLOB SUB_TYPE 0" - def get_col_spec(self): - if self.precision is None: - return "NUMERIC" - else: - return "NUMERIC(%(precision)s, %(scale)s)" % { 'precision': self.precision, - 'scale' : self.scale } +class FBCompiler(sql.compiler.SQLCompiler): + """Firebird specific idiosincrasies""" - def bind_processor(self, dialect): - return None + def visit_mod(self, binary, **kw): + # Firebird lacks a builtin modulo operator, but there is + # an equivalent function in the ib_udf library. + return "mod(%s, %s)" % (self.process(binary.left), self.process(binary.right)) - def result_processor(self, dialect): - if self.asdecimal: - return None + def visit_alias(self, alias, asfrom=False, **kwargs): + if self.dialect._version_two: + return super(FBCompiler, self).visit_alias(alias, asfrom=asfrom, **kwargs) else: - def process(value): - if isinstance(value, decimal.Decimal): - return float(value) - else: - return value - return process - - -class FBFloat(sqltypes.Float): - """Handle ``FLOAT(precision)`` datatype.""" + # Override to not use the AS keyword which FB 1.5 does not like + if asfrom: + alias_name = isinstance(alias.name, expression._generated_label) and \ + self._truncated_identifier("alias", alias.name) or alias.name + + return self.process(alias.original, asfrom=asfrom, **kwargs) + " " + \ + self.preparer.format_alias(alias, alias_name) + else: + return self.process(alias.original, **kwargs) - def get_col_spec(self): - if not self.precision: - return "FLOAT" + def visit_substring_func(self, func, **kw): + s = self.process(func.clauses.clauses[0]) + start = self.process(func.clauses.clauses[1]) + if len(func.clauses.clauses) > 2: + length = self.process(func.clauses.clauses[2]) + return "SUBSTRING(%s FROM %s FOR %s)" % (s, start, length) else: - return "FLOAT(%(precision)s)" % {'precision': self.precision} - - -class FBInteger(sqltypes.Integer): - """Handle ``INTEGER`` datatype.""" - - def get_col_spec(self): - return "INTEGER" - + return "SUBSTRING(%s FROM %s)" % (s, start) -class FBSmallInteger(sqltypes.SmallInteger): - """Handle ``SMALLINT`` datatype.""" + def visit_length_func(self, function, **kw): + if self.dialect._version_two: + return "char_length" + self.function_argspec(function) + else: + return "strlen" + self.function_argspec(function) - def get_col_spec(self): - return "SMALLINT" + visit_char_length_func = visit_length_func + def function_argspec(self, func): + if func.clauses: + return self.process(func.clause_expr) + else: + return "" -class FBDateTime(sqltypes.DateTime): - """Handle ``TIMESTAMP`` datatype.""" + def default_from(self): + return " FROM rdb$database" - def get_col_spec(self): - return "TIMESTAMP" + def visit_sequence(self, seq): + return "gen_id(%s, 1)" % self.preparer.format_sequence(seq) - def bind_processor(self, dialect): - def process(value): - if value is None or isinstance(value, datetime.datetime): - return value - else: - return datetime.datetime(year=value.year, - month=value.month, - day=value.day) - return process + def get_select_precolumns(self, select): + """Called when building a ``SELECT`` statement, position is just + before column list Firebird puts the limit and offset right + after the ``SELECT``... + """ + result = "" + if select._limit: + result += "FIRST %d " % select._limit + if select._offset: + result +="SKIP %d " % select._offset + if select._distinct: + result += "DISTINCT " + return result -class FBDate(sqltypes.DateTime): - """Handle ``DATE`` datatype.""" + def limit_clause(self, select): + """Already taken care of in the `get_select_precolumns` method.""" - def get_col_spec(self): - return "DATE" + return "" -class FBTime(sqltypes.Time): - """Handle ``TIME`` datatype.""" + def _append_returning(self, text, stmt): + returning_cols = stmt.kwargs["firebird_returning"] + def flatten_columnlist(collist): + for c in collist: + if isinstance(c, sql.expression.Selectable): + for co in c.columns: + yield co + else: + yield c + columns = [self.process(c, within_columns_clause=True) + for c in flatten_columnlist(returning_cols)] + text += ' RETURNING ' + ', '.join(columns) + return text - def get_col_spec(self): - return "TIME" + def visit_update(self, update_stmt): + text = super(FBCompiler, self).visit_update(update_stmt) + if "firebird_returning" in update_stmt.kwargs: + return self._append_returning(text, update_stmt) + else: + return text + def visit_insert(self, insert_stmt): + text = super(FBCompiler, self).visit_insert(insert_stmt) + if "firebird_returning" in insert_stmt.kwargs: + return self._append_returning(text, insert_stmt) + else: + return text -class FBText(sqltypes.Text): - """Handle ``BLOB SUB_TYPE 1`` datatype (aka *textual* blob).""" + def visit_delete(self, delete_stmt): + text = super(FBCompiler, self).visit_delete(delete_stmt) + if "firebird_returning" in delete_stmt.kwargs: + return self._append_returning(text, delete_stmt) + else: + return text - def get_col_spec(self): - return "BLOB SUB_TYPE 1" +class FBDDLCompiler(sql.compiler.DDLCompiler): + """Firebird syntactic idiosincrasies""" -class FBString(sqltypes.String): - """Handle ``VARCHAR(length)`` datatype.""" + def visit_create_sequence(self, create): + """Generate a ``CREATE GENERATOR`` statement for the sequence.""" - def get_col_spec(self): - if self.length: - return "VARCHAR(%(length)s)" % {'length' : self.length} + if self.dialect._version_two: + return "CREATE SEQUENCE %s" % self.preparer.format_sequence(create.element) else: - return "BLOB SUB_TYPE 1" - + return "CREATE GENERATOR %s" % self.preparer.format_sequence(create.element) -class FBChar(sqltypes.CHAR): - """Handle ``CHAR(length)`` datatype.""" + def visit_drop_sequence(self, drop): + """Generate a ``DROP GENERATOR`` statement for the sequence.""" - def get_col_spec(self): - if self.length: - return "CHAR(%(length)s)" % {'length' : self.length} + if self.dialect._version_two: + return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element) else: - return "BLOB SUB_TYPE 1" + return "DROP GENERATOR %s" % self.preparer.format_sequence(drop.element) -class FBBinary(sqltypes.Binary): - """Handle ``BLOB SUB_TYPE 0`` datatype (aka *binary* blob).""" - - def get_col_spec(self): - return "BLOB SUB_TYPE 0" - - -class FBBoolean(sqltypes.Boolean): - """Handle boolean values as a ``SMALLINT`` datatype.""" - - def get_col_spec(self): - return "SMALLINT" +class FBDefaultRunner(base.DefaultRunner): + """Firebird specific idiosincrasies""" + def visit_sequence(self, seq): + """Get the next value from the sequence using ``gen_id()``.""" -colspecs = { - sqltypes.Integer : FBInteger, - sqltypes.SmallInteger : FBSmallInteger, - sqltypes.Numeric : FBNumeric, - sqltypes.Float : FBFloat, - sqltypes.DateTime : FBDateTime, - sqltypes.Date : FBDate, - sqltypes.Time : FBTime, - sqltypes.String : FBString, - sqltypes.Binary : FBBinary, - sqltypes.Boolean : FBBoolean, - sqltypes.Text : FBText, - sqltypes.CHAR: FBChar, -} + return self.execute_string("SELECT gen_id(%s, 1) FROM rdb$database" % \ + self.dialect.identifier_preparer.format_sequence(seq)) +class FBIdentifierPreparer(sql.compiler.IdentifierPreparer): + """Install Firebird specific reserved words.""" -ischema_names = { - 'SHORT': lambda r: FBSmallInteger(), - 'LONG': lambda r: FBInteger(), - 'QUAD': lambda r: FBFloat(), - 'FLOAT': lambda r: FBFloat(), - 'DATE': lambda r: FBDate(), - 'TIME': lambda r: FBTime(), - 'TEXT': lambda r: FBString(r['flen']), - 'INT64': lambda r: FBNumeric(precision=r['fprec'], scale=r['fscale'] * -1), # This generically handles NUMERIC() - 'DOUBLE': lambda r: FBFloat(), - 'TIMESTAMP': lambda r: FBDateTime(), - 'VARYING': lambda r: FBString(r['flen']), - 'CSTRING': lambda r: FBChar(r['flen']), - 'BLOB': lambda r: r['stype']==1 and FBText() or FBBinary() - } - -RETURNING_KW_NAME = 'firebird_returning' - -class FBExecutionContext(default.DefaultExecutionContext): - pass + reserved_words = RESERVED_WORDS + def __init__(self, dialect): + super(FBIdentifierPreparer, self).__init__(dialect, omit_schema=True) class FBDialect(default.DefaultDialect): """Firebird dialect""" + name = 'firebird' - supports_sane_rowcount = False - supports_sane_multi_rowcount = False + max_identifier_length = 31 + supports_sequences = True + sequences_optional = False + supports_default_values = True + supports_empty_insert = False preexecute_pk_sequences = True supports_pk_autoincrement = False - - def __init__(self, type_conv=200, concurrency_level=1, **kwargs): - default.DefaultDialect.__init__(self, **kwargs) - - self.type_conv = type_conv - self.concurrency_level = concurrency_level - - def dbapi(cls): - import kinterbasdb - return kinterbasdb - dbapi = classmethod(dbapi) - - def create_connect_args(self, url): - opts = url.translate_connect_args(username='user') - if opts.get('port'): - opts['host'] = "%s/%s" % (opts['host'], opts['port']) - del opts['port'] - opts.update(url.query) - - type_conv = opts.pop('type_conv', self.type_conv) - concurrency_level = opts.pop('concurrency_level', self.concurrency_level) - global _initialized_kb - if not _initialized_kb and self.dbapi is not None: - _initialized_kb = True - self.dbapi.init(type_conv=type_conv, concurrency_level=concurrency_level) - return ([], opts) - - def type_descriptor(self, typeobj): - return sqltypes.adapt_type(typeobj, colspecs) - - def server_version_info(self, connection): - """Get the version of the Firebird server used by a connection. - - Returns a tuple of (`major`, `minor`, `build`), three integers - representing the version of the attached server. - """ - - # This is the simpler approach (the other uses the services api), - # that for backward compatibility reasons returns a string like - # LI-V6.3.3.12981 Firebird 2.0 - # where the first version is a fake one resembling the old - # Interbase signature. This is more than enough for our purposes, - # as this is mainly (only?) used by the testsuite. - - from re import match - - fbconn = connection.connection.connection - version = fbconn.server_version - m = match('\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+) \w+ (\d+)\.(\d+)', version) - if not m: - raise AssertionError("Could not determine version from string '%s'" % version) - return tuple([int(x) for x in m.group(5, 6, 4)]) - - def _normalize_name(self, name): - """Convert the name to lowercase if it is possible""" - + requires_name_normalize = True + + statement_compiler = FBCompiler + ddl_compiler = FBDDLCompiler + defaultrunner = FBDefaultRunner + preparer = FBIdentifierPreparer + type_compiler = FBTypeCompiler + + colspecs = colspecs + ischema_names = ischema_names + + # defaults to dialect ver. 3, + # will be autodetected off upon + # first connect + _version_two = True + + def initialize(self, connection): + super(FBDialect, self).initialize(connection) + self._version_two = self.server_version_info > (2, ) + if not self._version_two: + # TODO: whatever other pre < 2.0 stuff goes here + self.ischema_names = ischema_names.copy() + self.ischema_names['TIMESTAMP'] = sqltypes.DATE + self.colspecs = { + sqltypes.DateTime :sqltypes.DATE + } + + def normalize_name(self, name): # Remove trailing spaces: FB uses a CHAR() type, # that is padded with spaces name = name and name.rstrip() if name is None: return None - elif name.upper() == name and not self.identifier_preparer._requires_quotes(name.lower()): + elif name.upper() == name and \ + not self.identifier_preparer._requires_quotes(name.lower()): return name.lower() else: return name - def _denormalize_name(self, name): - """Revert a *normalized* name to its uppercase equivalent""" - + def denormalize_name(self, name): if name is None: return None - elif name.lower() == name and not self.identifier_preparer._requires_quotes(name.lower()): + elif name.lower() == name and \ + not self.identifier_preparer._requires_quotes(name.lower()): return name.upper() else: return name - def table_names(self, connection, schema): - """Return a list of *normalized* table names omitting system relations.""" - - s = """ - SELECT r.rdb$relation_name - FROM rdb$relations r - WHERE r.rdb$system_flag=0 - """ - return [self._normalize_name(row[0]) for row in connection.execute(s)] - def has_table(self, connection, table_name, schema=None): """Return ``True`` if the given table exists, ignoring the `schema`.""" @@ -371,12 +373,8 @@ class FBDialect(default.DefaultDialect): FROM rdb$relations WHERE rdb$relation_name=?) """ - c = connection.execute(tblqry, [self._denormalize_name(table_name)]) - row = c.fetchone() - if row is not None: - return True - else: - return False + c = connection.execute(tblqry, [self.denormalize_name(table_name)]) + return c.first() is not None def has_sequence(self, connection, sequence_name): """Return ``True`` if the given sequence (generator) exists.""" @@ -387,32 +385,43 @@ class FBDialect(default.DefaultDialect): FROM rdb$generators WHERE rdb$generator_name=?) """ - c = connection.execute(genqry, [self._denormalize_name(sequence_name)]) - row = c.fetchone() - if row is not None: - return True - else: - return False - - def is_disconnect(self, e): - if isinstance(e, self.dbapi.OperationalError): - return 'Unable to complete network request to host' in str(e) - elif isinstance(e, self.dbapi.ProgrammingError): - msg = str(e) - return ('Invalid connection state' in msg or - 'Invalid cursor state' in msg) - else: - return False + c = connection.execute(genqry, [self.denormalize_name(sequence_name)]) + return c.first() is not None - @reflection.cache - def get_table_names(self, connection, schema=None, **kw): + def table_names(self, connection, schema): s = """ SELECT DISTINCT rdb$relation_name FROM rdb$relation_fields WHERE rdb$system_flag=0 AND rdb$view_context IS NULL """ - return [self._normalize_name(row[0]) for row in connection.execute(s)] + return [self.normalize_name(row[0]) for row in connection.execute(s)] + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + return self.table_names(connection, schema) + + @reflection.cache + def get_view_names(self, connection, schema=None, **kw): + s = """ + SELECT distinct rdb$view_name + FROM rdb$view_relations + """ + return [self.normalize_name(row[0]) for row in connection.execute(s)] + + @reflection.cache + def get_view_definition(self, connection, view_name, schema=None, **kw): + qry = """ + SELECT rdb$view_source AS view_source + FROM rdb$relations + WHERE rdb$relation_name=?; + """ + rp = connection.execute(qry, [self.denormalize_name(view_name)]) + row = rp.first() + if row: + return row['view_source'] + else: + return None + @reflection.cache def get_primary_keys(self, connection, table_name, schema=None, **kw): # Query to extract the PK/FK constrained fields of the given table @@ -422,17 +431,17 @@ class FBDialect(default.DefaultDialect): JOIN rdb$index_segments se ON rc.rdb$index_name=se.rdb$index_name WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=? """ - tablename = self._denormalize_name(table.name) + tablename = self.denormalize_name(table_name) # get primary key fields c = connection.execute(keyqry, ["PRIMARY KEY", tablename]) - pkfields = [self._normalize_name(r['fname']) for r in c.fetchall()] + pkfields = [self.normalize_name(r['fname']) for r in c.fetchall()] return pkfields @reflection.cache def get_column_sequence(self, connection, table_name, column_name, schema=None, **kw): - tablename = self._denormalize_name(table_name) - colname = self._denormalize_name(column_name) + tablename = self.denormalize_name(table_name) + colname = self.denormalize_name(column_name) # Heuristic-query to determine the generator associated to a PK field genqry = """ SELECT trigdep.rdb$depended_on_name AS fgenerator @@ -452,7 +461,7 @@ class FBDialect(default.DefaultDialect): genc = connection.execute(genqry, [tablename, colname]) genr = genc.fetchone() if genr is not None: - return dict(name=self._normalize_name(genr['fgenerator'])) + return dict(name=self.normalize_name(genr['fgenerator'])) @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): @@ -472,7 +481,7 @@ class FBDialect(default.DefaultDialect): WHERE f.rdb$system_flag=0 AND r.rdb$relation_name=? ORDER BY r.rdb$field_position """ - tablename = self._denormalize_name(table_name) + tablename = self.denormalize_name(table_name) # get all of the fields for this table c = connection.execute(tblqry, [tablename]) cols = [] @@ -480,16 +489,29 @@ class FBDialect(default.DefaultDialect): row = c.fetchone() if row is None: break - name = self._normalize_name(row['fname']) + name = self.normalize_name(row['fname']) # get the data type - coltype = ischema_names.get(row['ftype'].rstrip()) + + colspec = row['ftype'].rstrip() + coltype = self.ischema_names.get(colspec) if coltype is None: util.warn("Did not recognize type '%s' of column '%s'" % - (str(row['ftype']), name)) + (colspec, name)) coltype = sqltypes.NULLTYPE + elif colspec == 'INT64': + coltype = coltype(precision=row['fprec'], scale=row['fscale'] * -1) + elif colspec in ('VARYING', 'CSTRING'): + coltype = coltype(row['flen']) + elif colspec == 'TEXT': + coltype = TEXT(row['flen']) + elif colspec == 'BLOB': + if row['stype'] == 1: + coltype = TEXT() + else: + coltype = BLOB() else: coltype = coltype(row) - + # does it have a default value? defvalue = None if row['fdefault'] is not None: @@ -517,96 +539,65 @@ class FBDialect(default.DefaultDialect): JOIN rdb$indices ix1 ON ix1.rdb$index_name=rc.rdb$index_name JOIN rdb$indices ix2 ON ix2.rdb$index_name=ix1.rdb$foreign_key JOIN rdb$index_segments cse ON cse.rdb$index_name=ix1.rdb$index_name - JOIN rdb$index_segments se ON se.rdb$index_name=ix2.rdb$index_name AND se.rdb$field_position=cse.rdb$field_position + JOIN rdb$index_segments se ON se.rdb$index_name=ix2.rdb$index_name AND + se.rdb$field_position=cse.rdb$field_position WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=? ORDER BY se.rdb$index_name, se.rdb$field_position """ - tablename = self._denormalize_name(table_name) - # get the foreign keys + tablename = self.denormalize_name(table_name) + c = connection.execute(fkqry, ["FOREIGN KEY", tablename]) - fks = {} - fkeys = [] - while True: - row = c.fetchone() - if not row: - break - cname = self._normalize_name(row['cname']) - if cname in fks: - fk = fks[cname] - else: - fk = { - 'name' : cname, - 'constrained_columns' : [], - 'referred_schema' : None, - 'referred_table' : None, - 'referred_columns' : [] - } - fks[cname] = fk - fkeys.append(fk) - fk['referred_table'] = self._normalize_name(row['targetrname']) - fk['constrained_columns'].append(self._normalize_name(row['fname'])) + fks = util.defaultdict(lambda:{ + 'name' : None, + 'constrained_columns' : [], + 'referred_schema' : None, + 'referred_table' : None, + 'referred_columns' : [] + }) + + for row in c: + cname = self.normalize_name(row['cname']) + fk = fks[cname] + if not fk['name']: + fk['name'] = cname + fk['referred_table'] = self.normalize_name(row['targetrname']) + fk['constrained_columns'].append(self.normalize_name(row['fname'])) fk['referred_columns'].append( - self._normalize_name(row['targetfname'])) - return fkeys - - def reflecttable(self, connection, table, include_columns): - - # get primary key fields - pkfields = self.get_primary_keys(connection, table.name) - - found_table = False - for col_d in self.get_columns(connection, table.name): - found_table = True - - name = col_d.get('name') - defvalue = col_d.get('default') - nullable = col_d.get('nullable') - coltype = col_d.get('type') - - if include_columns and name not in include_columns: - continue - args = [name] - - kw = {} - args.append(coltype) - - # is it a primary key? - kw['primary_key'] = name in pkfields - - # is it nullable? - kw['nullable'] = nullable - - # does it have a default value? - if defvalue: - args.append(sa_schema.DefaultClause(sql.text(defvalue))) - - col = sa_schema.Column(*args, **kw) - if kw['primary_key']: - # if the PK is a single field, try to see if its linked to - # a sequence thru a trigger - if len(pkfields)==1: - sequence_name = self.get_column_sequence(connection, - table.name, name) - if sequence_name is not None: - col.sequence = sa_schema.Sequence(sequence_name) - table.append_column(col) - - if not found_table: - raise exc.NoSuchTableError(table.name) - - # get the foreign keys - for fkey_d in self.get_foreign_keys(connection, table.name): - cname = fkey_d['name'] - constrained_columns = fkey_d['constrained_columns'] - rname = fkey_d['referred_table'] - referred_columns = fkey_d['referred_columns'] - - sa_schema.Table(rname, table.metadata, autoload=True, autoload_with=connection) - refspec = ['.'.join(c) for c in \ - zip(constrained_columns, referred_columns)] - table.append_constraint(sa_schema.ForeignKeyConstraint( - constrained_columns, refspec, name=cname, link_to_name=True)) + self.normalize_name(row['targetfname'])) + return fks.values() + @reflection.cache + def get_indexes(self, connection, table_name, schema=None, **kw): + qry = """ + SELECT + ix.rdb$index_name AS index_name, + ix.rdb$unique_flag AS unique_flag, + ic.rdb$field_name AS field_name + + FROM rdb$indices ix JOIN rdb$index_segments ic + ON ix.rdb$index_name=ic.rdb$index_name + + LEFT OUTER JOIN RDB$RELATION_CONSTRAINTS + ON RDB$RELATION_CONSTRAINTS.RDB$INDEX_NAME = ic.RDB$INDEX_NAME + + WHERE ix.rdb$relation_name=? AND ix.rdb$foreign_key IS NULL + AND RDB$RELATION_CONSTRAINTS.RDB$CONSTRAINT_TYPE IS NULL + ORDER BY index_name, field_name + """ + c = connection.execute(qry, [self.denormalize_name(table_name)]) + + indexes = util.defaultdict(dict) + for row in c: + indexrec = indexes[row['index_name']] + if 'name' not in indexrec: + indexrec['name'] = self.normalize_name(row['index_name']) + indexrec['column_names'] = [] + indexrec['unique'] = bool(row['unique_flag']) + + indexrec['column_names'].append(self.normalize_name(row['field_name'])) + + return indexes.values() + def do_execute(self, cursor, statement, parameters, **kwargs): # kinterbase does not accept a None, but wants an empty list # when there are no arguments. @@ -621,194 +612,3 @@ class FBDialect(default.DefaultDialect): connection.commit(True) -class FBCompiler(sql.compiler.SQLCompiler): - """Firebird specific idiosincrasies""" - - def visit_mod(self, binary, **kw): - # Firebird lacks a builtin modulo operator, but there is - # an equivalent function in the ib_udf library. - return "mod(%s, %s)" % (self.process(binary.left), self.process(binary.right)) - - def visit_alias(self, alias, asfrom=False, **kwargs): - # Override to not use the AS keyword which FB 1.5 does not like - if asfrom: - return self.process(alias.original, asfrom=True, **kwargs) + " " + self.preparer.format_alias(alias, self._anonymize(alias.name)) - else: - return self.process(alias.original, **kwargs) - - def visit_substring_func(self, func, **kw): - s = self.process(func.clauses.clauses[0]) - start = self.process(func.clauses.clauses[1]) - if len(func.clauses.clauses) > 2: - length = self.process(func.clauses.clauses[2]) - return "SUBSTRING(%s FROM %s FOR %s)" % (s, start, length) - else: - return "SUBSTRING(%s FROM %s)" % (s, start) - - # TODO: auto-detect this or something - LENGTH_FUNCTION_NAME = 'char_length' - - def visit_length_func(self, function, **kw): - return self.LENGTH_FUNCTION_NAME + self.function_argspec(function) - - def visit_char_length_func(self, function, **kw): - return self.LENGTH_FUNCTION_NAME + self.function_argspec(function) - - def function_argspec(self, func): - if func.clauses: - return self.process(func.clause_expr) - else: - return "" - - def default_from(self): - return " FROM rdb$database" - - def visit_sequence(self, seq): - return "gen_id(%s, 1)" % self.preparer.format_sequence(seq) - - def get_select_precolumns(self, select): - """Called when building a ``SELECT`` statement, position is just - before column list Firebird puts the limit and offset right - after the ``SELECT``... - """ - - result = "" - if select._limit: - result += "FIRST %d " % select._limit - if select._offset: - result +="SKIP %d " % select._offset - if select._distinct: - result += "DISTINCT " - return result - - def limit_clause(self, select): - """Already taken care of in the `get_select_precolumns` method.""" - - return "" - - - def _append_returning(self, text, stmt): - returning_cols = stmt.kwargs[RETURNING_KW_NAME] - def flatten_columnlist(collist): - for c in collist: - if isinstance(c, sql.expression.Selectable): - for co in c.columns: - yield co - else: - yield c - columns = [self.process(c, within_columns_clause=True) - for c in flatten_columnlist(returning_cols)] - text += ' RETURNING ' + ', '.join(columns) - return text - - def visit_update(self, update_stmt): - text = super(FBCompiler, self).visit_update(update_stmt) - if RETURNING_KW_NAME in update_stmt.kwargs: - return self._append_returning(text, update_stmt) - else: - return text - - def visit_insert(self, insert_stmt): - text = super(FBCompiler, self).visit_insert(insert_stmt) - if RETURNING_KW_NAME in insert_stmt.kwargs: - return self._append_returning(text, insert_stmt) - else: - return text - - def visit_delete(self, delete_stmt): - text = super(FBCompiler, self).visit_delete(delete_stmt) - if RETURNING_KW_NAME in delete_stmt.kwargs: - return self._append_returning(text, delete_stmt) - else: - return text - - -class FBSchemaGenerator(sql.compiler.SchemaGenerator): - """Firebird syntactic idiosincrasies""" - - def visit_sequence(self, sequence): - """Generate a ``CREATE GENERATOR`` statement for the sequence.""" - - if not self.checkfirst or not self.dialect.has_sequence(self.connection, sequence.name): - self.append("CREATE GENERATOR %s" % self.preparer.format_sequence(sequence)) - self.execute() - - -class FBSchemaDropper(sql.compiler.SchemaDropper): - """Firebird syntactic idiosincrasies""" - - def visit_sequence(self, sequence): - """Generate a ``DROP GENERATOR`` statement for the sequence.""" - - if not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name): - self.append("DROP GENERATOR %s" % self.preparer.format_sequence(sequence)) - self.execute() - - -class FBDefaultRunner(base.DefaultRunner): - """Firebird specific idiosincrasies""" - - def visit_sequence(self, seq): - """Get the next value from the sequence using ``gen_id()``.""" - - return self.execute_string("SELECT gen_id(%s, 1) FROM rdb$database" % \ - self.dialect.identifier_preparer.format_sequence(seq)) - - -RESERVED_WORDS = set( - ["action", "active", "add", "admin", "after", "all", "alter", "and", "any", - "as", "asc", "ascending", "at", "auto", "autoddl", "avg", "based", "basename", - "base_name", "before", "begin", "between", "bigint", "blob", "blobedit", "buffer", - "by", "cache", "cascade", "case", "cast", "char", "character", "character_length", - "char_length", "check", "check_point_len", "check_point_length", "close", "collate", - "collation", "column", "commit", "committed", "compiletime", "computed", "conditional", - "connect", "constraint", "containing", "continue", "count", "create", "cstring", - "current", "current_connection", "current_date", "current_role", "current_time", - "current_timestamp", "current_transaction", "current_user", "cursor", "database", - "date", "day", "db_key", "debug", "dec", "decimal", "declare", "default", "delete", - "desc", "descending", "describe", "descriptor", "disconnect", "display", "distinct", - "do", "domain", "double", "drop", "echo", "edit", "else", "end", "entry_point", - "escape", "event", "exception", "execute", "exists", "exit", "extern", "external", - "extract", "fetch", "file", "filter", "float", "for", "foreign", "found", "free_it", - "from", "full", "function", "gdscode", "generator", "gen_id", "global", "goto", - "grant", "group", "group_commit_", "group_commit_wait", "having", "help", "hour", - "if", "immediate", "in", "inactive", "index", "indicator", "init", "inner", "input", - "input_type", "insert", "int", "integer", "into", "is", "isolation", "isql", "join", - "key", "lc_messages", "lc_type", "left", "length", "lev", "level", "like", "logfile", - "log_buffer_size", "log_buf_size", "long", "manual", "max", "maximum", "maximum_segment", - "max_segment", "merge", "message", "min", "minimum", "minute", "module_name", "month", - "names", "national", "natural", "nchar", "no", "noauto", "not", "null", "numeric", - "num_log_buffers", "num_log_bufs", "octet_length", "of", "on", "only", "open", "option", - "or", "order", "outer", "output", "output_type", "overflow", "page", "pagelength", - "pages", "page_size", "parameter", "password", "plan", "position", "post_event", - "precision", "prepare", "primary", "privileges", "procedure", "protected", "public", - "quit", "raw_partitions", "rdb$db_key", "read", "real", "record_version", "recreate", - "references", "release", "release", "reserv", "reserving", "restrict", "retain", - "return", "returning_values", "returns", "revoke", "right", "role", "rollback", - "row_count", "runtime", "savepoint", "schema", "second", "segment", "select", - "set", "shadow", "shared", "shell", "show", "singular", "size", "smallint", - "snapshot", "some", "sort", "sqlcode", "sqlerror", "sqlwarning", "stability", - "starting", "starts", "statement", "static", "statistics", "sub_type", "sum", - "suspend", "table", "terminator", "then", "time", "timestamp", "to", "transaction", - "translate", "translation", "trigger", "trim", "type", "uncommitted", "union", - "unique", "update", "upper", "user", "using", "value", "values", "varchar", - "variable", "varying", "version", "view", "wait", "wait_time", "weekday", "when", - "whenever", "where", "while", "with", "work", "write", "year", "yearday" ]) - - -class FBIdentifierPreparer(sql.compiler.IdentifierPreparer): - """Install Firebird specific reserved words.""" - - reserved_words = RESERVED_WORDS - - def __init__(self, dialect): - super(FBIdentifierPreparer, self).__init__(dialect, omit_schema=True) - - -dialect = FBDialect -dialect.statement_compiler = FBCompiler -dialect.schemagenerator = FBSchemaGenerator -dialect.schemadropper = FBSchemaDropper -dialect.defaultrunner = FBDefaultRunner -dialect.preparer = FBIdentifierPreparer -dialect.execution_ctx_cls = FBExecutionContext diff --git a/lib/sqlalchemy/dialects/firebird/kinterbasdb.py b/lib/sqlalchemy/dialects/firebird/kinterbasdb.py new file mode 100644 index 0000000000..e463959027 --- /dev/null +++ b/lib/sqlalchemy/dialects/firebird/kinterbasdb.py @@ -0,0 +1,72 @@ +from sqlalchemy.dialects.firebird.base import FBDialect, FBCompiler +from sqlalchemy.engine.default import DefaultExecutionContext + +_initialized_kb = False + +class Firebird_kinterbasdb(FBDialect): + driver = 'kinterbasdb' + supports_sane_rowcount = False + supports_sane_multi_rowcount = False + + def __init__(self, type_conv=200, concurrency_level=1, **kwargs): + super(Firebird_kinterbasdb, self).__init__(**kwargs) + + self.type_conv = type_conv + self.concurrency_level = concurrency_level + + @classmethod + def dbapi(cls): + k = __import__('kinterbasdb') + return k + + def create_connect_args(self, url): + opts = url.translate_connect_args(username='user') + if opts.get('port'): + opts['host'] = "%s/%s" % (opts['host'], opts['port']) + del opts['port'] + opts.update(url.query) + + type_conv = opts.pop('type_conv', self.type_conv) + concurrency_level = opts.pop('concurrency_level', self.concurrency_level) + global _initialized_kb + if not _initialized_kb and self.dbapi is not None: + _initialized_kb = True + self.dbapi.init(type_conv=type_conv, concurrency_level=concurrency_level) + return ([], opts) + + + def _get_server_version_info(self, connection): + """Get the version of the Firebird server used by a connection. + + Returns a tuple of (`major`, `minor`, `build`), three integers + representing the version of the attached server. + """ + + # This is the simpler approach (the other uses the services api), + # that for backward compatibility reasons returns a string like + # LI-V6.3.3.12981 Firebird 2.0 + # where the first version is a fake one resembling the old + # Interbase signature. This is more than enough for our purposes, + # as this is mainly (only?) used by the testsuite. + + from re import match + + fbconn = connection.connection + version = fbconn.server_version + m = match('\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+) \w+ (\d+)\.(\d+)', version) + if not m: + raise AssertionError("Could not determine version from string '%s'" % version) + return tuple([int(x) for x in m.group(5, 6, 4)]) + + + def is_disconnect(self, e): + if isinstance(e, self.dbapi.OperationalError): + return 'Unable to complete network request to host' in str(e) + elif isinstance(e, self.dbapi.ProgrammingError): + msg = str(e) + return ('Invalid connection state' in msg or + 'Invalid cursor state' in msg) + else: + return False + +dialect = Firebird_kinterbasdb diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 2d8b31bcf2..849b72b979 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -1439,10 +1439,6 @@ class MSDialect(default.DefaultDialect): return fkeys.values() - def reflecttable(self, connection, table, include_columns): - - insp = reflection.Inspector.from_engine(connection) - return insp.reflecttable(table, include_columns) # fixme. I added this for the tests to run. -Randall MSSQLDialect = MSDialect diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 133d7ef57a..b6b57c7968 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1940,12 +1940,6 @@ class MySQLDialect(default.DefaultDialect): sql = parser._describe_to_create(table_name, columns) return parser.parse(sql, charset) - def reflecttable(self, connection, table, include_columns): - """Load column definitions from the server.""" - - insp = reflection.Inspector.from_engine(connection) - return insp.reflecttable(table, include_columns) - def _adjust_casing(self, table, charset=None): """Adjust Table name to the server case sensitivity, if needed.""" diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index a512a2b1b1..35c85c2c95 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -458,6 +458,7 @@ class OracleDialect(default.DefaultDialect): default_paramstyle = 'named' colspecs = colspecs ischema_names = ischema_names + requires_name_normalize = True supports_default_values = False supports_empty_insert = False @@ -482,16 +483,16 @@ class OracleDialect(default.DefaultDialect): def has_table(self, connection, table_name, schema=None): if not schema: schema = self.get_default_schema_name(connection) - cursor = connection.execute("""select table_name from all_tables where table_name=:name and owner=:schema_name""", {'name':self._denormalize_name(table_name), 'schema_name':self._denormalize_name(schema)}) + cursor = connection.execute("""select table_name from all_tables where table_name=:name and owner=:schema_name""", {'name':self.denormalize_name(table_name), 'schema_name':self.denormalize_name(schema)}) return cursor.fetchone() is not None def has_sequence(self, connection, sequence_name, schema=None): if not schema: schema = self.get_default_schema_name(connection) - cursor = connection.execute("""select sequence_name from all_sequences where sequence_name=:name and sequence_owner=:schema_name""", {'name':self._denormalize_name(sequence_name), 'schema_name':self._denormalize_name(schema)}) + cursor = connection.execute("""select sequence_name from all_sequences where sequence_name=:name and sequence_owner=:schema_name""", {'name':self.denormalize_name(sequence_name), 'schema_name':self.denormalize_name(schema)}) return cursor.fetchone() is not None - def _normalize_name(self, name): + def normalize_name(self, name): if name is None: return None elif name.upper() == name and not self.identifier_preparer._requires_quotes(name.lower().decode(self.encoding)): @@ -499,7 +500,7 @@ class OracleDialect(default.DefaultDialect): else: return name.decode(self.encoding) - def _denormalize_name(self, name): + def denormalize_name(self, name): if name is None: return None elif name.lower() == name and not self.identifier_preparer._requires_quotes(name.lower()): @@ -508,7 +509,7 @@ class OracleDialect(default.DefaultDialect): return name.encode(self.encoding) def get_default_schema_name(self, connection): - return self._normalize_name(connection.execute('SELECT USER FROM DUAL').scalar()) + return self.normalize_name(connection.execute('SELECT USER FROM DUAL').scalar()) def table_names(self, connection, schema): # note that table_names() isnt loading DBLINKed or synonym'ed tables @@ -517,8 +518,8 @@ class OracleDialect(default.DefaultDialect): cursor = connection.execute(s) else: s = "select table_name from all_tables where nvl(tablespace_name, 'no tablespace') NOT IN ('SYSTEM','SYSAUX') AND OWNER = :owner" - cursor = connection.execute(s, {'owner': self._denormalize_name(schema)}) - return [self._normalize_name(row[0]) for row in cursor] + cursor = connection.execute(s, {'owner': self.denormalize_name(schema)}) + return [self.normalize_name(row[0]) for row in cursor] def _resolve_synonym(self, connection, desired_owner=None, desired_synonym=None, desired_table=None): """search for a local synonym matching the given desired owner/name. @@ -567,35 +568,35 @@ class OracleDialect(default.DefaultDialect): resolve_synonyms=False, dblink='', **kw): if resolve_synonyms: - actual_name, owner, dblink, synonym = self._resolve_synonym(connection, desired_owner=self._denormalize_name(schema), desired_synonym=self._denormalize_name(table_name)) + actual_name, owner, dblink, synonym = self._resolve_synonym(connection, desired_owner=self.denormalize_name(schema), desired_synonym=self.denormalize_name(table_name)) else: actual_name, owner, dblink, synonym = None, None, None, None if not actual_name: - actual_name = self._denormalize_name(table_name) + actual_name = self.denormalize_name(table_name) if not dblink: dblink = '' if not owner: - owner = self._denormalize_name(schema or self.get_default_schema_name(connection)) + owner = self.denormalize_name(schema or self.get_default_schema_name(connection)) return (actual_name, owner, dblink, synonym) @reflection.cache def get_schema_names(self, connection, **kw): s = "SELECT username FROM all_users ORDER BY username" cursor = connection.execute(s,) - return [self._normalize_name(row[0]) for row in cursor] + return [self.normalize_name(row[0]) for row in cursor] @reflection.cache def get_table_names(self, connection, schema=None, **kw): - schema = self._denormalize_name(schema or self.get_default_schema_name(connection)) + schema = self.denormalize_name(schema or self.get_default_schema_name(connection)) return self.table_names(connection, schema) @reflection.cache def get_view_names(self, connection, schema=None, **kw): - schema = self._denormalize_name(schema or self.get_default_schema_name(connection)) + schema = self.denormalize_name(schema or self.get_default_schema_name(connection)) s = "select view_name from all_views where OWNER = :owner" cursor = connection.execute(s, - {'owner':self._denormalize_name(schema)}) - return [self._normalize_name(row[0]) for row in cursor] + {'owner':self.denormalize_name(schema)}) + return [self.normalize_name(row[0]) for row in cursor] @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): @@ -627,7 +628,7 @@ class OracleDialect(default.DefaultDialect): for row in c: (colname, coltype, length, precision, scale, nullable, default) = \ - (self._normalize_name(row[0]), row[1], row[2], row[3], row[4], row[5]=='Y', row[6]) + (self.normalize_name(row[0]), row[1], row[2], row[3], row[4], row[5]=='Y', row[6]) # INTEGER if the scale is 0 and precision is null # NUMBER if the scale and precision are both null @@ -684,8 +685,8 @@ class OracleDialect(default.DefaultDialect): ORDER BY a.INDEX_NAME, a.COLUMN_POSITION """ % dict(dblink=dblink) rp = connection.execute(q, - dict(table_name=self._denormalize_name(table_name), - schema=self._denormalize_name(schema))) + dict(table_name=self.denormalize_name(table_name), + schema=self.denormalize_name(schema))) indexes = [] last_index_name = None pkeys = self.get_primary_keys(connection, table_name, schema, @@ -698,10 +699,10 @@ class OracleDialect(default.DefaultDialect): if rset.column_name in [s.upper() for s in pkeys]: continue if rset.index_name != last_index_name: - index = dict(name=self._normalize_name(rset.index_name), column_names=[]) + index = dict(name=self.normalize_name(rset.index_name), column_names=[]) indexes.append(index) index['unique'] = uniqueness.get(rset.uniqueness, False) - index['column_names'].append(self._normalize_name(rset.column_name)) + index['column_names'].append(self.normalize_name(rset.column_name)) last_index_name = rset.index_name return indexes @@ -763,7 +764,7 @@ class OracleDialect(default.DefaultDialect): for row in constraint_data: #print "ROW:" , row (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = \ - row[0:2] + tuple([self._normalize_name(x) for x in row[2:6]]) + row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]]) if cons_type == 'P': pkeys.append(local_column) return pkeys @@ -807,7 +808,7 @@ class OracleDialect(default.DefaultDialect): for row in constraint_data: (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = \ - row[0:2] + tuple([self._normalize_name(x) for x in row[2:6]]) + row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]]) if cons_type == 'R': if remote_table is None: @@ -827,16 +828,16 @@ class OracleDialect(default.DefaultDialect): ref_remote_name, ref_remote_owner, ref_dblink, ref_synonym = \ self._resolve_synonym( connection, - desired_owner=self._denormalize_name(remote_owner), - desired_table=self._denormalize_name(remote_table) + desired_owner=self.denormalize_name(remote_owner), + desired_table=self.denormalize_name(remote_table) ) if ref_synonym: - remote_table = self._normalize_name(ref_synonym) - remote_owner = self._normalize_name(ref_remote_owner) + remote_table = self.normalize_name(ref_synonym) + remote_owner = self.normalize_name(ref_remote_owner) rec['referred_table'] = remote_table - if requested_schema is not None or self._denormalize_name(remote_owner) != schema: + if requested_schema is not None or self.denormalize_name(remote_owner) != schema: rec['referred_schema'] = remote_owner local_cols.append(local_column) @@ -864,9 +865,6 @@ class OracleDialect(default.DefaultDialect): else: return None - def reflecttable(self, connection, table, include_columns): - insp = reflection.Inspector.from_engine(connection) - return insp.reflecttable(table, include_columns) class _OuterJoinColumn(sql.ClauseElement): diff --git a/lib/sqlalchemy/dialects/postgres/base.py b/lib/sqlalchemy/dialects/postgres/base.py index 947ebbac5d..dd0bd80495 100644 --- a/lib/sqlalchemy/dialects/postgres/base.py +++ b/lib/sqlalchemy/dialects/postgres/base.py @@ -859,10 +859,6 @@ class PGDialect(default.DefaultDialect): index_d['unique'] = unique return indexes - def reflecttable(self, connection, table, include_columns): - insp = reflection.Inspector.from_engine(connection) - return insp.reflecttable(table, include_columns) - def _load_domains(self, connection): ## Load data types for domains: SQL_DOMAINS = """ diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index 1b260aa686..b644c7bf80 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -516,35 +516,6 @@ class SQLiteDialect(default.DefaultDialect): cols.append(row[2]) return indexes - def get_unique_indexes(self, connection, table_name, schema=None, **kw): - quote = self.identifier_preparer.quote_identifier - if schema is not None: - pragma = "PRAGMA %s." % quote(schema) - else: - pragma = "PRAGMA " - qtable = quote(table_name) - c = _pragma_cursor(connection.execute("%sindex_list(%s)" % (pragma, qtable))) - unique_indexes = [] - while True: - row = c.fetchone() - if row is None: - break - if (row[2] == 1): - unique_indexes.append(row[1]) - # loop thru unique indexes for one that includes the primary key - for idx in unique_indexes: - c = _pragma_cursor(connection.execute("%sindex_info(%s)" % (pragma, idx))) - cols = [] - while True: - row = c.fetchone() - if row is None: - break - cols.append(row[2]) - return unique_indexes - - def reflecttable(self, connection, table, include_columns): - insp = reflection.Inspector.from_engine(connection) - return insp.reflecttable(table, include_columns) def _pragma_cursor(cursor): """work around SQLite issue whereby cursor.description is blank when PRAGMA returns no rows.""" diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index a570427969..f752182c36 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -289,6 +289,23 @@ class Dialect(object): raise NotImplementedError() + def normalize_name(self, name): + """convert the given name to lowercase if it is detected as case insensitive. + + this method is only used if the dialect defines requires_name_normalize=True. + + """ + raise NotImplementedError() + + def denormalize_name(self, name): + """convert the given name to a case insensitive identifier for the backend + if it is an all-lowercase name. + + this method is only used if the dialect defines requires_name_normalize=True. + + """ + raise NotImplementedError() + def has_table(self, connection, table_name, schema=None): """Check the existence of a particular table in the database. @@ -1636,7 +1653,10 @@ class ResultProxy(object): if origname: if self._props.setdefault(origname.lower(), rec) is not rec: self._props[origname.lower()] = (type_, self.__ambiguous_processor(origname), 0) - + + if self.dialect.requires_name_normalize: + colname = self.dialect.normalize_name(colname) + self.keys.append(colname) self._props[i] = rec if obj: diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 413de171a2..5a86d7c94b 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -13,7 +13,7 @@ as the base class for their own corresponding classes. """ import re, random -from sqlalchemy.engine import base +from sqlalchemy.engine import base, reflection from sqlalchemy.sql import compiler, expression from sqlalchemy import exc, types as sqltypes @@ -51,6 +51,14 @@ class DefaultDialect(base.Dialect): default_paramstyle = 'named' supports_default_values = False supports_empty_insert = True + + # indicates symbol names are + # UPPERCASEd if they are case insensitive + # within the database. + # if this is True, the methods normalize_name() + # and denormalize_name() must be provided. + requires_name_normalize = False + reflection_options = () def __init__(self, convert_unicode=False, assert_unicode=False, @@ -104,9 +112,16 @@ class DefaultDialect(base.Dialect): """ return sqltypes.adapt_type(typeobj, cls.colspecs) + def reflecttable(self, connection, table, include_columns): + insp = reflection.Inspector.from_engine(connection) + return insp.reflecttable(table, include_columns) + def validate_identifier(self, ident): if len(ident) > self.max_identifier_length: - raise exc.IdentifierError("Identifier '%s' exceeds maximum length of %d characters" % (ident, self.max_identifier_length)) + raise exc.IdentifierError( + "Identifier '%s' exceeds maximum length of %d characters" % + (ident, self.max_identifier_length) + ) def connect(self, *cargs, **cparams): return self.dbapi.connect(*cargs, **cparams) diff --git a/lib/sqlalchemy/test/requires.py b/lib/sqlalchemy/test/requires.py index 98408ace98..b6fbb93dba 100644 --- a/lib/sqlalchemy/test/requires.py +++ b/lib/sqlalchemy/test/requires.py @@ -28,6 +28,25 @@ def foreign_keys(fn): no_support('sqlite', 'not supported by database'), ) + +def unbounded_varchar(fn): + """Target database must support VARCHAR with no length""" + return _chain_decorators_on( + fn, + no_support('firebird', 'not supported by database'), + no_support('oracle', 'not supported by database'), + no_support('mysql', 'not supported by database'), + ) + +def boolean_col_expressions(fn): + """Target database must support boolean expressions as columns""" + return _chain_decorators_on( + fn, + no_support('firebird', 'not supported by database'), + no_support('oracle', 'not supported by database'), + no_support('mssql', 'not supported by database'), + ) + def identity(fn): """Target database must support GENERATED AS IDENTITY or a facsimile. @@ -90,7 +109,8 @@ def schemas(fn): return _chain_decorators_on( fn, - no_support('sqlite', 'no schema support') + no_support('sqlite', 'no schema support'), + no_support('firebird', 'no schema support') ) def sequences(fn): diff --git a/lib/sqlalchemy/test/schema.py b/lib/sqlalchemy/test/schema.py index 555ffffe06..35b4060d2b 100644 --- a/lib/sqlalchemy/test/schema.py +++ b/lib/sqlalchemy/test/schema.py @@ -33,7 +33,7 @@ def Table(*args, **kw): # expand to ForeignKeyConstraint too. fks = [fk for col in args if isinstance(col, schema.Column) - for fk in col.args if isinstance(fk, schema.ForeignKey)] + for fk in col.foreign_keys] for fk in fks: # root around in raw spec diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index 965849df02..8d123f5048 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -923,7 +923,7 @@ class SMALLINT(SmallInteger): __visit_name__ = 'SMALLINT' -class BIGINT(SmallInteger): +class BIGINT(BigInteger): """The SQL BIGINT type.""" __visit_name__ = 'BIGINT' diff --git a/test/dialect/test_firebird.py b/test/dialect/test_firebird.py index fa608c9a18..033522902e 100644 --- a/test/dialect/test_firebird.py +++ b/test/dialect/test_firebird.py @@ -198,7 +198,7 @@ class ReturningTest(TestBase, AssertsExecutionResults): table.drop() -class MiscFBTests(TestBase): +class MiscTest(TestBase): __only_on__ = 'firebird' def test_strlen(self): @@ -225,4 +225,12 @@ class MiscFBTests(TestBase): version = testing.db.dialect.server_version_info(testing.db.connect()) assert len(version) == 3, "Got strange version info: %s" % repr(version) + def test_percents_in_text(self): + for expr, result in ( + (text("select '%' from rdb$database"), '%'), + (text("select '%%' from rdb$database"), '%%'), + (text("select '%%%' from rdb$database"), '%%%'), + (text("select 'hello % world' from rdb$database"), "hello % world") + ): + eq_(testing.db.scalar(expr), result) diff --git a/test/engine/test_reflection.py b/test/engine/test_reflection.py index 06451bbc4d..43b427d332 100644 --- a/test/engine/test_reflection.py +++ b/test/engine/test_reflection.py @@ -598,7 +598,6 @@ class ReflectionTest(TestBase, ComparesTables): m9.reflect() self.assert_(not m9.tables) - @testing.fails_on_everything_except('postgres', 'mysql', 'sqlite', 'oracle', 'mssql') def test_index_reflection(self): m1 = MetaData(testing.db) t1 = Table('party', m1, @@ -907,7 +906,7 @@ def dropViews(con, schema=None): class ComponentReflectionTest(TestBase): - @testing.fails_on('sqlite', 'no schemas') + @testing.requires.schemas def test_get_schema_names(self): meta = MetaData(testing.db) insp = Inspector(meta.bind) @@ -971,23 +970,27 @@ class ComponentReflectionTest(TestBase): # should be in order for (i, col) in enumerate(table.columns): eq_(col.name, cols[i]['name']) - # coltype is tricky - # It may not inherit from col.type while they share - # the same base. ctype = cols[i]['type'].__class__ ctype_def = col.type if isinstance(ctype_def, sa.types.TypeEngine): ctype_def = ctype_def.__class__ + # Oracle returns Date for DateTime. if testing.against('oracle') \ and ctype_def in (sql_types.Date, sql_types.DateTime): ctype_def = sql_types.Date + + # assert that the desired type and return type + # share a base within one of the generic types. self.assert_( - issubclass(ctype, ctype_def) or \ len( set( - ctype.__bases__ - ).intersection(ctype_def.__bases__)) > 0 + ctype.__mro__ + ).intersection(ctype_def.__mro__) + .intersection([sql_types.Integer, sql_types.Numeric, + sql_types.DateTime, sql_types.Date, sql_types.Time, + sql_types.String, sql_types.Binary]) + ) > 0 ,("%s(%s), %s(%s)" % (col.name, col.type, cols[i]['name'], ctype))) finally: diff --git a/test/sql/test_query.py b/test/sql/test_query.py index f22a5c22a3..f679277049 100644 --- a/test/sql/test_query.py +++ b/test/sql/test_query.py @@ -149,7 +149,7 @@ class QueryTest(TestBase): l.append(row) self.assert_(len(l) == 3) - @testing.fails_on('firebird', 'Data type unknown') + @testing.fails_on('firebird', "kinterbasdb doesn't send full type information") @testing.requires.subqueries def test_anonymous_rows(self): users.insert().execute( @@ -163,6 +163,7 @@ class QueryTest(TestBase): assert row['anon_1'] == 8 assert row['anon_2'] == 10 + @testing.fails_on('firebird', "kinterbasdb doesn't send full type information") def test_order_by_label(self): """test that a label within an ORDER BY works on each backend. @@ -181,6 +182,11 @@ class QueryTest(TestBase): select([concat]).order_by(concat).execute().fetchall(), [("test: ed",), ("test: fred",), ("test: jack",)] ) + + eq_( + select([concat]).order_by(concat).execute().fetchall(), + [("test: ed",), ("test: fred",), ("test: jack",)] + ) concat = ("test: " + users.c.user_name).label('thedata') eq_( @@ -209,8 +215,7 @@ class QueryTest(TestBase): self.assert_(not (rp != equal)) self.assert_(not (equal != equal)) - @testing.fails_on('mssql', 'No support for boolean logic in column select.') - @testing.fails_on('oracle', 'FIXME: unknown') + @testing.requires.boolean_col_expressions def test_or_and_as_columns(self): true, false = literal(True), literal(False) @@ -255,6 +260,7 @@ class QueryTest(TestBase): eq_(expr.execute().fetchall(), result) + @testing.fails_on("firebird", "see dialect.test_firebird:MiscTest.test_percents_in_text") @testing.fails_on("oracle", "neither % nor %% are accepted") @testing.fails_on("+pg8000", "can't interpret result column from '%%'") @testing.emits_warning('.*now automatically escapes.*') @@ -475,15 +481,25 @@ class QueryTest(TestBase): self.assert_(r['query_users.user_id']) == 1 self.assert_(r['query_users.user_name']) == "john" - @testing.fails_on('oracle', 'oracle result keys() are all uppercase, not getting into this.') + def test_result_case_sensitivity(self): + """test name normalization for result sets.""" + + row = testing.db.execute( + select([ + literal_column("1").label("case_insensitive"), + literal_column("2").label("CaseSensitive") + ]) + ).first() + + assert row.keys() == ["case_insensitive", "CaseSensitive"] + def test_row_as_args(self): users.insert().execute(user_id=1, user_name='john') r = users.select(users.c.user_id==1).execute().first() users.delete().execute() users.insert().execute(r) eq_(users.select().execute().fetchall(), [(1, 'john')]) - - @testing.fails_on('oracle', 'oracle result keys() are all uppercase, not getting into this.') + def test_result_as_args(self): users.insert().execute([dict(user_id=1, user_name='john'), dict(user_id=2, user_name='ed')]) r = users.select().execute() @@ -620,13 +636,13 @@ class QueryTest(TestBase): # Null values are not outside any set assert len(r) == 0 - u = bindparam('search_key') + @testing.fails_on('firebird', "kinterbasdb doesn't send full type information") + def test_bind_in(self): + users.insert().execute(user_id = 7, user_name = 'jack') + users.insert().execute(user_id = 8, user_name = 'fred') + users.insert().execute(user_id = 9, user_name = None) - s = users.select(u.in_([])) - r = s.execute(search_key='john').fetchall() - assert len(r) == 0 - r = s.execute(search_key=None).fetchall() - assert len(r) == 0 + u = bindparam('search_key') s = users.select(not_(u.in_([]))) r = s.execute(search_key='john').fetchall() @@ -881,6 +897,7 @@ class CompoundTest(TestBase): found2 = self._fetchall_sorted(u.alias('bar').select().execute()) eq_(found2, wanted) + @testing.fails_on('firebird', "doesn't like ORDER BY with UNIONs") def test_union_ordered(self): (s1, s2) = ( select([t1.c.col3.label('col3'), t1.c.col4.label('col4')], @@ -894,6 +911,7 @@ class CompoundTest(TestBase): ('ccc', 'aaa')] eq_(u.execute().fetchall(), wanted) + @testing.fails_on('firebird', "doesn't like ORDER BY with UNIONs") @testing.fails_on('maxdb', 'FIXME: unknown') @testing.requires.subqueries def test_union_ordered_alias(self): @@ -910,6 +928,7 @@ class CompoundTest(TestBase): eq_(u.alias('bar').select().execute().fetchall(), wanted) @testing.crashes('oracle', 'FIXME: unknown, verify not fails_on') + @testing.fails_on('firebird', "has trouble extracting anonymous column from union subquery") @testing.fails_on('mysql', 'FIXME: unknown') @testing.fails_on('sqlite', 'FIXME: unknown') def test_union_all(self): @@ -928,6 +947,29 @@ class CompoundTest(TestBase): found2 = self._fetchall_sorted(e.alias('foo').select().execute()) eq_(found2, wanted) + def test_union_all_lightweight(self): + """like test_union_all, but breaks the sub-union into + a subquery with an explicit column reference on the outside, + more palatable to a wider variety of engines. + + """ + u = union( + select([t1.c.col3]), + select([t1.c.col3]), + ).alias() + + e = union_all( + select([t1.c.col3]), + select([u.c.col3]) + ) + + wanted = [('aaa',),('aaa',),('bbb',), ('bbb',), ('ccc',),('ccc',)] + found1 = self._fetchall_sorted(e.execute()) + eq_(found1, wanted) + + found2 = self._fetchall_sorted(e.alias('foo').select().execute()) + eq_(found2, wanted) + @testing.crashes('firebird', 'Does not support intersect') @testing.crashes('sybase', 'FIXME: unknown, verify not fails_on') @testing.fails_on('mysql', 'FIXME: unknown') diff --git a/test/sql/test_types.py b/test/sql/test_types.py index 821a386ffe..512ef2e7f9 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -623,8 +623,8 @@ class DateTest(TestBase, AssertsExecutionResults): t.drop(checkfirst=True) class StringTest(TestBase, AssertsExecutionResults): - @testing.fails_on('mysql', 'FIXME: unknown') - @testing.fails_on('oracle', 'FIXME: unknown') + + @testing.requires.unbounded_varchar def test_nolength_string(self): metadata = MetaData(testing.db) foo = Table('foo', metadata, Column('one', String))