]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- firebird support. reflection works fully, overall test success in the 75% range...
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 12 Jul 2009 01:23:03 +0000 (01:23 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 12 Jul 2009 01:23:03 +0000 (01:23 +0000)
- oracle and firebird now normalize column names to SQLA "lowercase" for result.keys()
allowing consistent identifier name visibility on the client side.

18 files changed:
06CHANGES
lib/sqlalchemy/dialects/firebird/__init__.py
lib/sqlalchemy/dialects/firebird/base.py
lib/sqlalchemy/dialects/firebird/kinterbasdb.py [new file with mode: 0644]
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/dialects/postgres/base.py
lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/test/requires.py
lib/sqlalchemy/test/schema.py
lib/sqlalchemy/types.py
test/dialect/test_firebird.py
test/engine/test_reflection.py
test/sql/test_query.py
test/sql/test_types.py

index 1742cfc2f711b7461866e3a0674b4807fdb34a4c..413d2972a175815c7d0d94dc56afcabfcc7e0fef 100644 (file)
--- a/06CHANGES
+++ b/06CHANGES
     - func.char_length is a generic function for LENGTH
     - ForeignKey() which includes onupdate=<value> will emit a warning, not 
       emit ON UPDATE CASCADE which is unsupported by oracle
+    - the keys() method of RowProxy() now returns the result column names *normalized*
+      to be SQLAlchemy case insensitive names.  This means they will be lower case 
+      for case insensitive names, whereas the DBAPI would normally return them 
+      as UPPERCASE names.  This allows row keys() to be compatible with further
+      SQLAlchemy operations.
+
+- firebird
+    - the keys() method of RowProxy() now returns the result column names *normalized*
+      to be SQLAlchemy case insensitive names. This means they will be lower case 
+      for case insensitive names, whereas the DBAPI would normally return them 
+      as UPPERCASE names.  This allows row keys() to be compatible with further
+      SQLAlchemy operations.
       
 - new dialects
-    - pg8000
-    - pyodbc+mysql
-    
+    - postgres+pg8000
+    - postgres+pypostgresql (partial)
+    - postgres+zxjdbc
+    - mysql+pyodbc
+    - mysql+zxjdbc
+
 - mssql
     - the "has_window_funcs" flag is removed.  LIMIT/OFFSET usage will use ROW NUMBER as always,
       and if on an older version of SQL Server, the operation fails.  The behavior is exactly
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..6b1b80db21ac11117a1c7f015eb4f0cd9bdd43c2 100644 (file)
@@ -0,0 +1,3 @@
+from sqlalchemy.dialects.firebird import base, kinterbasdb
+
+base.dialect = kinterbasdb.dialect
\ No newline at end of file
index a616ee0f692d5c3f9115d8d5392d7332f1e27af3..41bf4f8aa8a717f81638f2649cd1c4e4d0a9fd86 100644 (file)
@@ -8,7 +8,9 @@
 Firebird backend
 ================
 
-This module implements the Firebird backend, thru the kinterbasdb_
+This module implements the Firebird backend.
+
+Connectivity is usually supplied via the kinterbasdb_
 DBAPI module.
 
 Firebird dialects
@@ -22,61 +24,32 @@ dialect 1
 
 dialect 3
   This is the newer and supported syntax, introduced in Interbase 6.0.
-
-From the user point of view, the biggest change is in date/time
-handling: under dialect 1, there's a single kind of field, ``DATE``
-with a synonim ``DATETIME``, that holds a `timestamp` value, that is a
-date with hour, minute, second. Under dialect 3 there are three kinds,
-a ``DATE`` that holds a date, a ``TIME`` that holds a *time of the
-day* value and a ``TIMESTAMP``, equivalent to the old ``DATE``.
-
-The problem is that the dialect of a Firebird database is a property
-of the database itself [#]_ (that is, any single database has been
-created with one dialect or the other: there is no way to change the
-after creation). SQLAlchemy has a single instance of the class that
-controls all the connections to a particular kind of database, so it
-cannot easily differentiate between the two modes, and in particular
-it **cannot** simultaneously talk with two distinct Firebird databases
-with different dialects.
-
-By default this module is biased toward dialect 3, but you can easily
-tweak it to handle dialect 1 if needed::
-
-  from sqlalchemy import types as sqltypes
-  from sqlalchemy.databases.firebird import FBDate, colspecs, ischema_names
-
-  # Adjust the mapping of the timestamp kind
-  ischema_names['TIMESTAMP'] = FBDate
-  colspecs[sqltypes.DateTime] = FBDate,
-
-Other aspects may be version-specific. You can use the ``server_version_info()`` method
-on the ``FBDialect`` class to do whatever is needed::
-
-  from sqlalchemy.databases.firebird import FBCompiler
-
-  if engine.dialect.server_version_info(connection) < (2,0):
-      # Change the name of the function ``length`` to use the UDF version
-      # instead of ``char_length``
-      FBCompiler.LENGTH_FUNCTION_NAME = 'strlen'
-
-Pooling connections
--------------------
-
-The default strategy used by SQLAlchemy to pool the database connections
-in particular cases may raise an ``OperationalError`` with a message
-`"object XYZ is in use"`. This happens on Firebird when there are two
-connections to the database, one is using, or has used, a particular table
-and the other tries to drop or alter the same table. To garantee DDL
-operations success Firebird recommend doing them as the single connected user.
-
-In case your SA application effectively needs to do DDL operations while other
-connections are active, the following setting may alleviate the problem::
-
-  from sqlalchemy import pool
-  from sqlalchemy.databases.firebird import dialect
-
-  # Force SA to use a single connection per thread
-  dialect.poolclass = pool.SingletonThreadPool
+  
+The SQLAlchemy Firebird dialect detects these versions and 
+adjusts its representation of SQL accordingly.  However, 
+support for dialect 1 is not well tested and probably has 
+incompatibilities.
+
+Firebird Locking Behavior
+-------------------------
+
+Firebird locks tables aggressively.  For this reason, a DROP TABLE
+may hang until other transactions are released.  SQLAlchemy
+does its best to release transactions as quickly as possible.  The
+most common cause of hanging transactions is a non-fully consumed
+result set, i.e.::
+
+    result = engine.execute("select * from table")
+    row = result.fetchone()
+    return
+    
+Where above, the ``ResultProxy`` has not been fully consumed.  The
+connection will be returned to the pool and the transactional state
+rolled back once the Python garbage collector reclaims the 
+objects which hold onto the connection, which often occurs asynchronously.
+The above use case can be alleviated by calling ``first()`` on the 
+``ResultProxy`` which will fetch the first row and immediately close
+all remaining cursor/connection resources.
 
 RETURNING support
 -----------------
@@ -96,6 +69,7 @@ parameter when creating the queries::
 
 .. _dialects: http://mc-computing.com/Databases/Firebird/SQL_Dialect.html
 .. _kinterbasdb: http://sourceforge.net/projects/kinterbasdb
+
 """
 
 
@@ -103,265 +77,293 @@ import datetime, decimal, re
 
 from sqlalchemy import schema as sa_schema
 from sqlalchemy import exc, types as sqltypes, sql, util
+from sqlalchemy.sql import expression
 from sqlalchemy.engine import base, default, reflection
+from sqlalchemy.sql import compiler
 
+from sqlalchemy.types import (BIGINT, BLOB, BOOLEAN, CHAR, DATE,
+                               FLOAT, INTEGER, NUMERIC, SMALLINT,
+                               TEXT, TIME, TIMESTAMP, VARCHAR)
 
-_initialized_kb = False
+RESERVED_WORDS = set(
+   ["action", "active", "add", "admin", "after", "all", "alter", "and", "any",
+    "as", "asc", "ascending", "at", "auto", "autoddl", "avg", "based", "basename",
+    "base_name", "before", "begin", "between", "bigint", "blob", "blobedit", "buffer",
+    "by", "cache", "cascade", "case", "cast", "char", "character", "character_length",
+    "char_length", "check", "check_point_len", "check_point_length", "close", "collate",
+    "collation", "column", "commit", "committed", "compiletime", "computed", "conditional",
+    "connect", "constraint", "containing", "continue", "count", "create", "cstring",
+    "current", "current_connection", "current_date", "current_role", "current_time",
+    "current_timestamp", "current_transaction", "current_user", "cursor", "database",
+    "date", "day", "db_key", "debug", "dec", "decimal", "declare", "default", "delete",
+    "desc", "descending", "describe", "descriptor", "disconnect", "display", "distinct",
+    "do", "domain", "double", "drop", "echo", "edit", "else", "end", "entry_point",
+    "escape", "event", "exception", "execute", "exists", "exit", "extern", "external",
+    "extract", "fetch", "file", "filter", "float", "for", "foreign", "found", "free_it",
+    "from", "full", "function", "gdscode", "generator", "gen_id", "global", "goto",
+    "grant", "group", "group_commit_", "group_commit_wait", "having", "help", "hour",
+    "if", "immediate", "in", "inactive", "index", "indicator", "init", "inner", "input",
+    "input_type", "insert", "int", "integer", "into", "is", "isolation", "isql", "join",
+    "key", "lc_messages", "lc_type", "left", "length", "lev", "level", "like", "logfile",
+    "log_buffer_size", "log_buf_size", "long", "manual", "max", "maximum", "maximum_segment",
+    "max_segment", "merge", "message", "min", "minimum", "minute", "module_name", "month",
+    "names", "national", "natural", "nchar", "no", "noauto", "not", "null", "numeric",
+    "num_log_buffers", "num_log_bufs", "octet_length", "of", "on", "only", "open", "option",
+    "or", "order", "outer", "output", "output_type", "overflow", "page", "pagelength",
+    "pages", "page_size", "parameter", "password", "plan", "position", "post_event",
+    "precision", "prepare", "primary", "privileges", "procedure", "protected", "public",
+    "quit", "raw_partitions", "rdb$db_key", "read", "real", "record_version", "recreate",
+    "references", "release", "release", "reserv", "reserving", "restrict", "retain",
+    "return", "returning_values", "returns", "revoke", "right", "role", "rollback",
+    "row_count", "runtime", "savepoint", "schema", "second", "segment", "select",
+    "set", "shadow", "shared", "shell", "show", "singular", "size", "smallint",
+    "snapshot", "some", "sort", "sqlcode", "sqlerror", "sqlwarning", "stability",
+    "starting", "starts", "statement", "static", "statistics", "sub_type", "sum",
+    "suspend", "table", "terminator", "then", "time", "timestamp", "to", "transaction",
+    "translate", "translation", "trigger", "trim", "type", "uncommitted", "union",
+    "unique", "update", "upper", "user", "using", "value", "values", "varchar",
+    "variable", "varying", "version", "view", "wait", "wait_time", "weekday", "when",
+    "whenever", "where", "while", "with", "work", "write", "year", "yearday" ])
 
+colspecs = {
+}
 
-class FBNumeric(sqltypes.Numeric):
-    """Handle ``NUMERIC(precision,scale)`` datatype."""
+ischema_names = {
+      'SHORT': SMALLINT,
+       'LONG': BIGINT,
+       'QUAD': FLOAT,
+      'FLOAT': FLOAT,
+       'DATE': DATE,
+       'TIME': TIME,
+       'TEXT': TEXT,
+      'INT64': NUMERIC,
+     'DOUBLE': FLOAT,
+  'TIMESTAMP': TIMESTAMP,
+    'VARYING': VARCHAR,
+    'CSTRING': CHAR,
+       'BLOB': BLOB,
+    }
+
+
+# TODO: Boolean type, date conversion types (should be implemented as _FBDateTime, _FBDate, etc.
+# as bind/result functionality is required)
+
+
+class FBTypeCompiler(compiler.GenericTypeCompiler):
+    def visit_datetime(self, type_):
+        return self.visit_TIMESTAMP(type_)
+        
+    def visit_TEXT(self, type_):
+        return "BLOB SUB_TYPE 1"
+        
+    def visit_BLOB(self, type_):
+        return "BLOB SUB_TYPE 0"
 
-    def get_col_spec(self):
-        if self.precision is None:
-            return "NUMERIC"
-        else:
-            return "NUMERIC(%(precision)s, %(scale)s)" % { 'precision': self.precision,
-                                                            'scale' : self.scale }
+class FBCompiler(sql.compiler.SQLCompiler):
+    """Firebird specific idiosyncrasies"""
 
-    def bind_processor(self, dialect):
-        return None
+    def visit_mod(self, binary, **kw):
+        # Firebird lacks a builtin modulo operator, but there is
+        # an equivalent function in the ib_udf library.
+        return "mod(%s, %s)" % (self.process(binary.left), self.process(binary.right))
 
-    def result_processor(self, dialect):
-        if self.asdecimal:
-            return None
+    def visit_alias(self, alias, asfrom=False, **kwargs):
+        if self.dialect._version_two:
+            return super(FBCompiler, self).visit_alias(alias, asfrom=asfrom, **kwargs)
         else:
-            def process(value):
-                if isinstance(value, decimal.Decimal):
-                    return float(value)
-                else:
-                    return value
-            return process
-
-
-class FBFloat(sqltypes.Float):
-    """Handle ``FLOAT(precision)`` datatype."""
+            # Override to not use the AS keyword which FB 1.5 does not like
+            if asfrom:
+                alias_name = isinstance(alias.name, expression._generated_label) and \
+                                self._truncated_identifier("alias", alias.name) or alias.name
+            
+                return self.process(alias.original, asfrom=asfrom, **kwargs) + " " + \
+                            self.preparer.format_alias(alias, alias_name)
+            else:
+                return self.process(alias.original, **kwargs)
 
-    def get_col_spec(self):
-        if not self.precision:
-            return "FLOAT"
+    def visit_substring_func(self, func, **kw):
+        s = self.process(func.clauses.clauses[0])
+        start = self.process(func.clauses.clauses[1])
+        if len(func.clauses.clauses) > 2:
+            length = self.process(func.clauses.clauses[2])
+            return "SUBSTRING(%s FROM %s FOR %s)" % (s, start, length)
         else:
-            return "FLOAT(%(precision)s)" % {'precision': self.precision}
-
-
-class FBInteger(sqltypes.Integer):
-    """Handle ``INTEGER`` datatype."""
-
-    def get_col_spec(self):
-        return "INTEGER"
-
+            return "SUBSTRING(%s FROM %s)" % (s, start)
 
-class FBSmallInteger(sqltypes.SmallInteger):
-    """Handle ``SMALLINT`` datatype."""
+    def visit_length_func(self, function, **kw):
+        if self.dialect._version_two:
+            return "char_length" + self.function_argspec(function)
+        else:
+            return "strlen" + self.function_argspec(function)
 
-    def get_col_spec(self):
-        return "SMALLINT"
+    visit_char_length_func = visit_length_func
 
+    def function_argspec(self, func):
+        if func.clauses:
+            return self.process(func.clause_expr)
+        else:
+            return ""
 
-class FBDateTime(sqltypes.DateTime):
-    """Handle ``TIMESTAMP`` datatype."""
+    def default_from(self):
+        return " FROM rdb$database"
 
-    def get_col_spec(self):
-        return "TIMESTAMP"
+    def visit_sequence(self, seq):
+        return "gen_id(%s, 1)" % self.preparer.format_sequence(seq)
 
-    def bind_processor(self, dialect):
-        def process(value):
-            if value is None or isinstance(value, datetime.datetime):
-                return value
-            else:
-                return datetime.datetime(year=value.year,
-                                         month=value.month,
-                                         day=value.day)
-        return process
+    def get_select_precolumns(self, select):
+        """Called when building a ``SELECT`` statement, position is just
+        before the column list.  Firebird puts the limit and offset right
+        after the ``SELECT``...
+        """
 
+        result = ""
+        if select._limit:
+            result += "FIRST %d "  % select._limit
+        if select._offset:
+            result +="SKIP %d "  %  select._offset
+        if select._distinct:
+            result += "DISTINCT "
+        return result
 
-class FBDate(sqltypes.DateTime):
-    """Handle ``DATE`` datatype."""
+    def limit_clause(self, select):
+        """Already taken care of in the `get_select_precolumns` method."""
 
-    def get_col_spec(self):
-        return "DATE"
+        return ""
 
 
-class FBTime(sqltypes.Time):
-    """Handle ``TIME`` datatype."""
+    def _append_returning(self, text, stmt):
+        returning_cols = stmt.kwargs["firebird_returning"]
+        def flatten_columnlist(collist):
+            for c in collist:
+                if isinstance(c, sql.expression.Selectable):
+                    for co in c.columns:
+                        yield co
+                else:
+                    yield c
+        columns = [self.process(c, within_columns_clause=True)
+                   for c in flatten_columnlist(returning_cols)]
+        text += ' RETURNING ' + ', '.join(columns)
+        return text
 
-    def get_col_spec(self):
-        return "TIME"
+    def visit_update(self, update_stmt):
+        text = super(FBCompiler, self).visit_update(update_stmt)
+        if "firebird_returning" in update_stmt.kwargs:
+            return self._append_returning(text, update_stmt)
+        else:
+            return text
 
+    def visit_insert(self, insert_stmt):
+        text = super(FBCompiler, self).visit_insert(insert_stmt)
+        if "firebird_returning" in insert_stmt.kwargs:
+            return self._append_returning(text, insert_stmt)
+        else:
+            return text
 
-class FBText(sqltypes.Text):
-    """Handle ``BLOB SUB_TYPE 1`` datatype (aka *textual* blob)."""
+    def visit_delete(self, delete_stmt):
+        text = super(FBCompiler, self).visit_delete(delete_stmt)
+        if "firebird_returning" in delete_stmt.kwargs:
+            return self._append_returning(text, delete_stmt)
+        else:
+            return text
 
-    def get_col_spec(self):
-        return "BLOB SUB_TYPE 1"
 
+class FBDDLCompiler(sql.compiler.DDLCompiler):
+    """Firebird syntactic idiosyncrasies"""
 
-class FBString(sqltypes.String):
-    """Handle ``VARCHAR(length)`` datatype."""
+    def visit_create_sequence(self, create):
+        """Generate a ``CREATE GENERATOR`` statement for the sequence."""
 
-    def get_col_spec(self):
-        if self.length:
-            return "VARCHAR(%(length)s)" % {'length' : self.length}
+        if self.dialect._version_two:
+            return "CREATE SEQUENCE %s" % self.preparer.format_sequence(create.element)
         else:
-            return "BLOB SUB_TYPE 1"
-
+            return "CREATE GENERATOR %s" % self.preparer.format_sequence(create.element)
 
-class FBChar(sqltypes.CHAR):
-    """Handle ``CHAR(length)`` datatype."""
+    def visit_drop_sequence(self, drop):
+        """Generate a ``DROP GENERATOR`` statement for the sequence."""
 
-    def get_col_spec(self):
-        if self.length:
-            return "CHAR(%(length)s)" % {'length' : self.length}
+        if self.dialect._version_two:
+            return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element)
         else:
-            return "BLOB SUB_TYPE 1"
+            return "DROP GENERATOR %s" % self.preparer.format_sequence(drop.element)
 
 
-class FBBinary(sqltypes.Binary):
-    """Handle ``BLOB SUB_TYPE 0`` datatype (aka *binary* blob)."""
-
-    def get_col_spec(self):
-        return "BLOB SUB_TYPE 0"
-
-
-class FBBoolean(sqltypes.Boolean):
-    """Handle boolean values as a ``SMALLINT`` datatype."""
-
-    def get_col_spec(self):
-        return "SMALLINT"
+class FBDefaultRunner(base.DefaultRunner):
+    """Firebird specific idiosyncrasies"""
 
+    def visit_sequence(self, seq):
+        """Get the next value from the sequence using ``gen_id()``."""
 
-colspecs = {
-    sqltypes.Integer : FBInteger,
-    sqltypes.SmallInteger : FBSmallInteger,
-    sqltypes.Numeric : FBNumeric,
-    sqltypes.Float : FBFloat,
-    sqltypes.DateTime : FBDateTime,
-    sqltypes.Date : FBDate,
-    sqltypes.Time : FBTime,
-    sqltypes.String : FBString,
-    sqltypes.Binary : FBBinary,
-    sqltypes.Boolean : FBBoolean,
-    sqltypes.Text : FBText,
-    sqltypes.CHAR: FBChar,
-}
+        return self.execute_string("SELECT gen_id(%s, 1) FROM rdb$database" % \
+            self.dialect.identifier_preparer.format_sequence(seq))
 
+class FBIdentifierPreparer(sql.compiler.IdentifierPreparer):
+    """Install Firebird specific reserved words."""
 
-ischema_names = {
-      'SHORT': lambda r: FBSmallInteger(),
-       'LONG': lambda r: FBInteger(),
-       'QUAD': lambda r: FBFloat(),
-      'FLOAT': lambda r: FBFloat(),
-       'DATE': lambda r: FBDate(),
-       'TIME': lambda r: FBTime(),
-       'TEXT': lambda r: FBString(r['flen']),
-      'INT64': lambda r: FBNumeric(precision=r['fprec'], scale=r['fscale'] * -1), # This generically handles NUMERIC()
-     'DOUBLE': lambda r: FBFloat(),
-  'TIMESTAMP': lambda r: FBDateTime(),
-    'VARYING': lambda r: FBString(r['flen']),
-    'CSTRING': lambda r: FBChar(r['flen']),
-       'BLOB': lambda r: r['stype']==1 and FBText() or FBBinary()
-      }
-
-RETURNING_KW_NAME = 'firebird_returning'
-
-class FBExecutionContext(default.DefaultExecutionContext):
-    pass
+    reserved_words = RESERVED_WORDS
 
+    def __init__(self, dialect):
+        super(FBIdentifierPreparer, self).__init__(dialect, omit_schema=True)
 
 class FBDialect(default.DefaultDialect):
     """Firebird dialect"""
+
     name = 'firebird'
-    supports_sane_rowcount = False
-    supports_sane_multi_rowcount = False
+
     max_identifier_length = 31
+    supports_sequences = True
+    sequences_optional = False
+    supports_default_values = True
+    supports_empty_insert = False
     preexecute_pk_sequences = True
     supports_pk_autoincrement = False
-
-    def __init__(self, type_conv=200, concurrency_level=1, **kwargs):
-        default.DefaultDialect.__init__(self, **kwargs)
-
-        self.type_conv = type_conv
-        self.concurrency_level = concurrency_level
-
-    def dbapi(cls):
-        import kinterbasdb
-        return kinterbasdb
-    dbapi = classmethod(dbapi)
-
-    def create_connect_args(self, url):
-        opts = url.translate_connect_args(username='user')
-        if opts.get('port'):
-            opts['host'] = "%s/%s" % (opts['host'], opts['port'])
-            del opts['port']
-        opts.update(url.query)
-
-        type_conv = opts.pop('type_conv', self.type_conv)
-        concurrency_level = opts.pop('concurrency_level', self.concurrency_level)
-        global _initialized_kb
-        if not _initialized_kb and self.dbapi is not None:
-            _initialized_kb = True
-            self.dbapi.init(type_conv=type_conv, concurrency_level=concurrency_level)
-        return ([], opts)
-
-    def type_descriptor(self, typeobj):
-        return sqltypes.adapt_type(typeobj, colspecs)
-
-    def server_version_info(self, connection):
-        """Get the version of the Firebird server used by a connection.
-
-        Returns a tuple of (`major`, `minor`, `build`), three integers
-        representing the version of the attached server.
-        """
-
-        # This is the simpler approach (the other uses the services api),
-        # that for backward compatibility reasons returns a string like
-        #   LI-V6.3.3.12981 Firebird 2.0
-        # where the first version is a fake one resembling the old
-        # Interbase signature. This is more than enough for our purposes,
-        # as this is mainly (only?) used by the testsuite.
-
-        from re import match
-
-        fbconn = connection.connection.connection
-        version = fbconn.server_version
-        m = match('\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+) \w+ (\d+)\.(\d+)', version)
-        if not m:
-            raise AssertionError("Could not determine version from string '%s'" % version)
-        return tuple([int(x) for x in m.group(5, 6, 4)])
-
-    def _normalize_name(self, name):
-        """Convert the name to lowercase if it is possible"""
-
+    requires_name_normalize = True
+
+    statement_compiler = FBCompiler
+    ddl_compiler = FBDDLCompiler
+    defaultrunner = FBDefaultRunner
+    preparer = FBIdentifierPreparer
+    type_compiler = FBTypeCompiler
+
+    colspecs = colspecs
+    ischema_names = ischema_names
+    
+    # defaults to dialect ver. 3,
+    # will be autodetected off upon 
+    # first connect
+    _version_two = True
+    
+    def initialize(self, connection):
+        super(FBDialect, self).initialize(connection)
+        self._version_two = self.server_version_info > (2, )
+        if not self._version_two:
+            # TODO: whatever other pre < 2.0 stuff goes here
+            self.ischema_names = ischema_names.copy()
+            self.ischema_names['TIMESTAMP'] = sqltypes.DATE
+            self.colspecs = {
+                sqltypes.DateTime :sqltypes.DATE
+            }
+        
+    def normalize_name(self, name):
         # Remove trailing spaces: FB uses a CHAR() type,
         # that is padded with spaces
         name = name and name.rstrip()
         if name is None:
             return None
-        elif name.upper() == name and not self.identifier_preparer._requires_quotes(name.lower()):
+        elif name.upper() == name and \
+            not self.identifier_preparer._requires_quotes(name.lower()):
             return name.lower()
         else:
             return name
 
-    def _denormalize_name(self, name):
-        """Revert a *normalized* name to its uppercase equivalent"""
-
+    def denormalize_name(self, name):
         if name is None:
             return None
-        elif name.lower() == name and not self.identifier_preparer._requires_quotes(name.lower()):
+        elif name.lower() == name and \
+            not self.identifier_preparer._requires_quotes(name.lower()):
             return name.upper()
         else:
             return name
 
-    def table_names(self, connection, schema):
-        """Return a list of *normalized* table names omitting system relations."""
-
-        s = """
-        SELECT r.rdb$relation_name
-        FROM rdb$relations r
-        WHERE r.rdb$system_flag=0
-        """
-        return [self._normalize_name(row[0]) for row in connection.execute(s)]
-
     def has_table(self, connection, table_name, schema=None):
         """Return ``True`` if the given table exists, ignoring the `schema`."""
 
@@ -371,12 +373,8 @@ class FBDialect(default.DefaultDialect):
                       FROM rdb$relations
                       WHERE rdb$relation_name=?)
         """
-        c = connection.execute(tblqry, [self._denormalize_name(table_name)])
-        row = c.fetchone()
-        if row is not None:
-            return True
-        else:
-            return False
+        c = connection.execute(tblqry, [self.denormalize_name(table_name)])
+        return c.first() is not None
 
     def has_sequence(self, connection, sequence_name):
         """Return ``True`` if the given sequence (generator) exists."""
@@ -387,32 +385,43 @@ class FBDialect(default.DefaultDialect):
                       FROM rdb$generators
                       WHERE rdb$generator_name=?)
         """
-        c = connection.execute(genqry, [self._denormalize_name(sequence_name)])
-        row = c.fetchone()
-        if row is not None:
-            return True
-        else:
-            return False
-
-    def is_disconnect(self, e):
-        if isinstance(e, self.dbapi.OperationalError):
-            return 'Unable to complete network request to host' in str(e)
-        elif isinstance(e, self.dbapi.ProgrammingError):
-            msg = str(e)
-            return ('Invalid connection state' in msg or
-                    'Invalid cursor state' in msg)
-        else:
-            return False
+        c = connection.execute(genqry, [self.denormalize_name(sequence_name)])
+        return c.first() is not None
 
-    @reflection.cache
-    def get_table_names(self, connection, schema=None, **kw):
+    def table_names(self, connection, schema):
         s = """
         SELECT DISTINCT rdb$relation_name
         FROM rdb$relation_fields WHERE
         rdb$system_flag=0 AND rdb$view_context IS NULL
         """
-        return [self._normalize_name(row[0]) for row in connection.execute(s)]
+        return [self.normalize_name(row[0]) for row in connection.execute(s)]
 
+    @reflection.cache
+    def get_table_names(self, connection, schema=None, **kw):
+        return self.table_names(connection, schema)
+
+    @reflection.cache
+    def get_view_names(self, connection, schema=None, **kw):
+        s = """
+        SELECT distinct rdb$view_name
+        FROM rdb$view_relations
+        """
+        return [self.normalize_name(row[0]) for row in connection.execute(s)]
+
+    @reflection.cache
+    def get_view_definition(self, connection, view_name, schema=None, **kw):
+        qry = """
+            SELECT rdb$view_source AS view_source
+            FROM rdb$relations 
+            WHERE rdb$relation_name=?;
+        """
+        rp = connection.execute(qry, [self.denormalize_name(view_name)])
+        row = rp.first()
+        if row:
+            return row['view_source']
+        else:
+            return None
+    
     @reflection.cache
     def get_primary_keys(self, connection, table_name, schema=None, **kw):
         # Query to extract the PK/FK constrained fields of the given table
@@ -422,17 +431,17 @@ class FBDialect(default.DefaultDialect):
              JOIN rdb$index_segments se ON rc.rdb$index_name=se.rdb$index_name
         WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=?
         """
-        tablename = self._denormalize_name(table.name)
+        tablename = self.denormalize_name(table_name)
         # get primary key fields
         c = connection.execute(keyqry, ["PRIMARY KEY", tablename])
-        pkfields = [self._normalize_name(r['fname']) for r in c.fetchall()]
+        pkfields = [self.normalize_name(r['fname']) for r in c.fetchall()]
         return pkfields
 
     @reflection.cache
     def get_column_sequence(self, connection, table_name, column_name,
                                                         schema=None, **kw):
-        tablename = self._denormalize_name(table_name)
-        colname = self._denormalize_name(column_name)
+        tablename = self.denormalize_name(table_name)
+        colname = self.denormalize_name(column_name)
         # Heuristic-query to determine the generator associated to a PK field
         genqry = """
         SELECT trigdep.rdb$depended_on_name AS fgenerator
@@ -452,7 +461,7 @@ class FBDialect(default.DefaultDialect):
         genc = connection.execute(genqry, [tablename, colname])
         genr = genc.fetchone()
         if genr is not None:
-            return dict(name=self._normalize_name(genr['fgenerator']))
+            return dict(name=self.normalize_name(genr['fgenerator']))
 
     @reflection.cache
     def get_columns(self, connection, table_name, schema=None, **kw):
@@ -472,7 +481,7 @@ class FBDialect(default.DefaultDialect):
         WHERE f.rdb$system_flag=0 AND r.rdb$relation_name=?
         ORDER BY r.rdb$field_position
         """
-        tablename = self._denormalize_name(table_name)
+        tablename = self.denormalize_name(table_name)
         # get all of the fields for this table
         c = connection.execute(tblqry, [tablename])
         cols = []
@@ -480,16 +489,29 @@ class FBDialect(default.DefaultDialect):
             row = c.fetchone()
             if row is None:
                 break
-            name = self._normalize_name(row['fname'])
+            name = self.normalize_name(row['fname'])
             # get the data type
-            coltype = ischema_names.get(row['ftype'].rstrip())
+            
+            colspec = row['ftype'].rstrip()
+            coltype = self.ischema_names.get(colspec)
             if coltype is None:
                 util.warn("Did not recognize type '%s' of column '%s'" %
-                          (str(row['ftype']), name))
+                          (colspec, name))
                 coltype = sqltypes.NULLTYPE
+            elif colspec == 'INT64':
+                coltype = coltype(precision=row['fprec'], scale=row['fscale'] * -1)
+            elif colspec in ('VARYING', 'CSTRING'):
+                coltype = coltype(row['flen'])
+            elif colspec == 'TEXT':
+                coltype = TEXT(row['flen'])
+            elif colspec == 'BLOB':
+                if row['stype'] == 1:
+                    coltype = TEXT()
+                else:
+                    coltype = BLOB()
             else:
                 coltype = coltype(row)
-
+            
             # does it have a default value?
             defvalue = None
             if row['fdefault'] is not None:
@@ -517,96 +539,65 @@ class FBDialect(default.DefaultDialect):
              JOIN rdb$indices ix1 ON ix1.rdb$index_name=rc.rdb$index_name
              JOIN rdb$indices ix2 ON ix2.rdb$index_name=ix1.rdb$foreign_key
              JOIN rdb$index_segments cse ON cse.rdb$index_name=ix1.rdb$index_name
-             JOIN rdb$index_segments se ON se.rdb$index_name=ix2.rdb$index_name AND se.rdb$field_position=cse.rdb$field_position
+             JOIN rdb$index_segments se ON se.rdb$index_name=ix2.rdb$index_name AND 
+             se.rdb$field_position=cse.rdb$field_position
         WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=?
         ORDER BY se.rdb$index_name, se.rdb$field_position
         """
-        tablename = self._denormalize_name(table_name)
-        # get the foreign keys
+        tablename = self.denormalize_name(table_name)
+
         c = connection.execute(fkqry, ["FOREIGN KEY", tablename])
-        fks = {}
-        fkeys = []
-        while True:
-            row = c.fetchone()
-            if not row:
-                break
-            cname = self._normalize_name(row['cname'])
-            if cname in fks:
-                fk = fks[cname]
-            else:
-                fk = {
-                    'name' : cname,
-                    'constrained_columns' : [],
-                    'referred_schema' : None,
-                    'referred_table' : None,
-                    'referred_columns' : []
-                }
-                fks[cname] = fk
-                fkeys.append(fk)
-            fk['referred_table'] = self._normalize_name(row['targetrname'])
-            fk['constrained_columns'].append(self._normalize_name(row['fname']))
+        fks = util.defaultdict(lambda:{
+            'name' : None,
+            'constrained_columns' : [],
+            'referred_schema' : None,
+            'referred_table' : None,
+            'referred_columns' : []
+        })
+        
+        for row in c:
+            cname = self.normalize_name(row['cname'])
+            fk = fks[cname]
+            if not fk['name']:
+                fk['name'] = cname
+                fk['referred_table'] = self.normalize_name(row['targetrname'])
+            fk['constrained_columns'].append(self.normalize_name(row['fname']))
             fk['referred_columns'].append(
-                            self._normalize_name(row['targetfname']))
-            return fkeys
-
-    def reflecttable(self, connection, table, include_columns):
-
-        # get primary key fields
-        pkfields = self.get_primary_keys(connection, table.name)
-
-        found_table = False
-        for col_d in self.get_columns(connection, table.name):
-            found_table = True
-
-            name = col_d.get('name')
-            defvalue = col_d.get('default')
-            nullable = col_d.get('nullable')
-            coltype = col_d.get('type')
-
-            if include_columns and name not in include_columns:
-                continue
-            args = [name]
-
-            kw = {}
-            args.append(coltype)
-
-            # is it a primary key?
-            kw['primary_key'] = name in pkfields
-
-            # is it nullable?
-            kw['nullable'] = nullable
-
-            # does it have a default value?
-            if defvalue:
-                args.append(sa_schema.DefaultClause(sql.text(defvalue)))
-
-            col = sa_schema.Column(*args, **kw)
-            if kw['primary_key']:
-                # if the PK is a single field, try to see if its linked to
-                # a sequence thru a trigger
-                if len(pkfields)==1:
-                    sequence_name = self.get_column_sequence(connection,
-                                            table.name, name)
-                    if sequence_name is not None:
-                        col.sequence = sa_schema.Sequence(sequence_name)
-            table.append_column(col)
-
-        if not found_table:
-            raise exc.NoSuchTableError(table.name)
-
-        # get the foreign keys
-        for fkey_d in self.get_foreign_keys(connection, table.name):
-            cname = fkey_d['name']
-            constrained_columns = fkey_d['constrained_columns']
-            rname = fkey_d['referred_table']
-            referred_columns = fkey_d['referred_columns']
-
-            sa_schema.Table(rname, table.metadata, autoload=True, autoload_with=connection)
-            refspec = ['.'.join(c) for c in \
-                                zip(constrained_columns, referred_columns)]
-            table.append_constraint(sa_schema.ForeignKeyConstraint(
-                constrained_columns, refspec, name=cname, link_to_name=True))
+                            self.normalize_name(row['targetfname']))
+        return fks.values()
 
+    @reflection.cache
+    def get_indexes(self, connection, table_name, schema=None, **kw):
+        qry = """
+            SELECT 
+                ix.rdb$index_name AS index_name,
+                ix.rdb$unique_flag AS unique_flag,
+                ic.rdb$field_name AS field_name
+                
+            FROM rdb$indices ix JOIN rdb$index_segments ic
+            ON ix.rdb$index_name=ic.rdb$index_name
+            
+            LEFT OUTER JOIN RDB$RELATION_CONSTRAINTS 
+            ON RDB$RELATION_CONSTRAINTS.RDB$INDEX_NAME = ic.RDB$INDEX_NAME
+            
+            WHERE ix.rdb$relation_name=? AND ix.rdb$foreign_key IS NULL
+            AND RDB$RELATION_CONSTRAINTS.RDB$CONSTRAINT_TYPE IS NULL
+            ORDER BY index_name, field_name
+        """
+        c = connection.execute(qry, [self.denormalize_name(table_name)])
+        
+        indexes = util.defaultdict(dict)
+        for row in c:
+            indexrec = indexes[row['index_name']]
+            if 'name' not in indexrec:
+                indexrec['name'] = self.normalize_name(row['index_name'])
+                indexrec['column_names'] = []
+                indexrec['unique'] = bool(row['unique_flag'])
+            
+            indexrec['column_names'].append(self.normalize_name(row['field_name']))
+
+        return indexes.values()
+        
     def do_execute(self, cursor, statement, parameters, **kwargs):
         # kinterbasdb does not accept a None, but wants an empty list
         # when there are no arguments.
@@ -621,194 +612,3 @@ class FBDialect(default.DefaultDialect):
         connection.commit(True)
 
 
-class FBCompiler(sql.compiler.SQLCompiler):
-    """Firebird specific idiosincrasies"""
-
-    def visit_mod(self, binary, **kw):
-        # Firebird lacks a builtin modulo operator, but there is
-        # an equivalent function in the ib_udf library.
-        return "mod(%s, %s)" % (self.process(binary.left), self.process(binary.right))
-        
-    def visit_alias(self, alias, asfrom=False, **kwargs):
-        # Override to not use the AS keyword which FB 1.5 does not like
-        if asfrom:
-            return self.process(alias.original, asfrom=True, **kwargs) + " " + self.preparer.format_alias(alias, self._anonymize(alias.name))
-        else:
-            return self.process(alias.original, **kwargs)
-
-    def visit_substring_func(self, func, **kw):
-        s = self.process(func.clauses.clauses[0])
-        start = self.process(func.clauses.clauses[1])
-        if len(func.clauses.clauses) > 2:
-            length = self.process(func.clauses.clauses[2])
-            return "SUBSTRING(%s FROM %s FOR %s)" % (s, start, length)
-        else:
-            return "SUBSTRING(%s FROM %s)" % (s, start)
-
-    # TODO: auto-detect this or something
-    LENGTH_FUNCTION_NAME = 'char_length'
-        
-    def visit_length_func(self, function, **kw):
-        return self.LENGTH_FUNCTION_NAME + self.function_argspec(function)
-
-    def visit_char_length_func(self, function, **kw):
-        return self.LENGTH_FUNCTION_NAME + self.function_argspec(function)
-        
-    def function_argspec(self, func):
-        if func.clauses:
-            return self.process(func.clause_expr)
-        else:
-            return ""
-
-    def default_from(self):
-        return " FROM rdb$database"
-
-    def visit_sequence(self, seq):
-        return "gen_id(%s, 1)" % self.preparer.format_sequence(seq)
-
-    def get_select_precolumns(self, select):
-        """Called when building a ``SELECT`` statement, position is just
-        before column list Firebird puts the limit and offset right
-        after the ``SELECT``...
-        """
-
-        result = ""
-        if select._limit:
-            result += "FIRST %d "  % select._limit
-        if select._offset:
-            result +="SKIP %d "  %  select._offset
-        if select._distinct:
-            result += "DISTINCT "
-        return result
-
-    def limit_clause(self, select):
-        """Already taken care of in the `get_select_precolumns` method."""
-
-        return ""
-
-
-    def _append_returning(self, text, stmt):
-        returning_cols = stmt.kwargs[RETURNING_KW_NAME]
-        def flatten_columnlist(collist):
-            for c in collist:
-                if isinstance(c, sql.expression.Selectable):
-                    for co in c.columns:
-                        yield co
-                else:
-                    yield c
-        columns = [self.process(c, within_columns_clause=True)
-                   for c in flatten_columnlist(returning_cols)]
-        text += ' RETURNING ' + ', '.join(columns)
-        return text
-
-    def visit_update(self, update_stmt):
-        text = super(FBCompiler, self).visit_update(update_stmt)
-        if RETURNING_KW_NAME in update_stmt.kwargs:
-            return self._append_returning(text, update_stmt)
-        else:
-            return text
-
-    def visit_insert(self, insert_stmt):
-        text = super(FBCompiler, self).visit_insert(insert_stmt)
-        if RETURNING_KW_NAME in insert_stmt.kwargs:
-            return self._append_returning(text, insert_stmt)
-        else:
-            return text
-
-    def visit_delete(self, delete_stmt):
-        text = super(FBCompiler, self).visit_delete(delete_stmt)
-        if RETURNING_KW_NAME in delete_stmt.kwargs:
-            return self._append_returning(text, delete_stmt)
-        else:
-            return text
-
-
-class FBSchemaGenerator(sql.compiler.SchemaGenerator):
-    """Firebird syntactic idiosincrasies"""
-
-    def visit_sequence(self, sequence):
-        """Generate a ``CREATE GENERATOR`` statement for the sequence."""
-
-        if not self.checkfirst or not self.dialect.has_sequence(self.connection, sequence.name):
-            self.append("CREATE GENERATOR %s" % self.preparer.format_sequence(sequence))
-            self.execute()
-
-
-class FBSchemaDropper(sql.compiler.SchemaDropper):
-    """Firebird syntactic idiosincrasies"""
-
-    def visit_sequence(self, sequence):
-        """Generate a ``DROP GENERATOR`` statement for the sequence."""
-
-        if not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name):
-            self.append("DROP GENERATOR %s" % self.preparer.format_sequence(sequence))
-            self.execute()
-
-
-class FBDefaultRunner(base.DefaultRunner):
-    """Firebird specific idiosincrasies"""
-
-    def visit_sequence(self, seq):
-        """Get the next value from the sequence using ``gen_id()``."""
-
-        return self.execute_string("SELECT gen_id(%s, 1) FROM rdb$database" % \
-            self.dialect.identifier_preparer.format_sequence(seq))
-
-
-RESERVED_WORDS = set(
-    ["action", "active", "add", "admin", "after", "all", "alter", "and", "any",
-     "as", "asc", "ascending", "at", "auto", "autoddl", "avg", "based", "basename",
-     "base_name", "before", "begin", "between", "bigint", "blob", "blobedit", "buffer",
-     "by", "cache", "cascade", "case", "cast", "char", "character", "character_length",
-     "char_length", "check", "check_point_len", "check_point_length", "close", "collate",
-     "collation", "column", "commit", "committed", "compiletime", "computed", "conditional",
-     "connect", "constraint", "containing", "continue", "count", "create", "cstring",
-     "current", "current_connection", "current_date", "current_role", "current_time",
-     "current_timestamp", "current_transaction", "current_user", "cursor", "database",
-     "date", "day", "db_key", "debug", "dec", "decimal", "declare", "default", "delete",
-     "desc", "descending", "describe", "descriptor", "disconnect", "display", "distinct",
-     "do", "domain", "double", "drop", "echo", "edit", "else", "end", "entry_point",
-     "escape", "event", "exception", "execute", "exists", "exit", "extern", "external",
-     "extract", "fetch", "file", "filter", "float", "for", "foreign", "found", "free_it",
-     "from", "full", "function", "gdscode", "generator", "gen_id", "global", "goto",
-     "grant", "group", "group_commit_", "group_commit_wait", "having", "help", "hour",
-     "if", "immediate", "in", "inactive", "index", "indicator", "init", "inner", "input",
-     "input_type", "insert", "int", "integer", "into", "is", "isolation", "isql", "join",
-     "key", "lc_messages", "lc_type", "left", "length", "lev", "level", "like", "logfile",
-     "log_buffer_size", "log_buf_size", "long", "manual", "max", "maximum", "maximum_segment",
-     "max_segment", "merge", "message", "min", "minimum", "minute", "module_name", "month",
-     "names", "national", "natural", "nchar", "no", "noauto", "not", "null", "numeric",
-     "num_log_buffers", "num_log_bufs", "octet_length", "of", "on", "only", "open", "option",
-     "or", "order", "outer", "output", "output_type", "overflow", "page", "pagelength",
-     "pages", "page_size", "parameter", "password", "plan", "position", "post_event",
-     "precision", "prepare", "primary", "privileges", "procedure", "protected", "public",
-     "quit", "raw_partitions", "rdb$db_key", "read", "real", "record_version", "recreate",
-     "references", "release", "release", "reserv", "reserving", "restrict", "retain",
-     "return", "returning_values", "returns", "revoke", "right", "role", "rollback",
-     "row_count", "runtime", "savepoint", "schema", "second", "segment", "select",
-     "set", "shadow", "shared", "shell", "show", "singular", "size", "smallint",
-     "snapshot", "some", "sort", "sqlcode", "sqlerror", "sqlwarning", "stability",
-     "starting", "starts", "statement", "static", "statistics", "sub_type", "sum",
-     "suspend", "table", "terminator", "then", "time", "timestamp", "to", "transaction",
-     "translate", "translation", "trigger", "trim", "type", "uncommitted", "union",
-     "unique", "update", "upper", "user", "using", "value", "values", "varchar",
-     "variable", "varying", "version", "view", "wait", "wait_time", "weekday", "when",
-     "whenever", "where", "while", "with", "work", "write", "year", "yearday" ])
-
-
-class FBIdentifierPreparer(sql.compiler.IdentifierPreparer):
-    """Install Firebird specific reserved words."""
-
-    reserved_words = RESERVED_WORDS
-
-    def __init__(self, dialect):
-        super(FBIdentifierPreparer, self).__init__(dialect, omit_schema=True)
-
-
-dialect = FBDialect
-dialect.statement_compiler = FBCompiler
-dialect.schemagenerator = FBSchemaGenerator
-dialect.schemadropper = FBSchemaDropper
-dialect.defaultrunner = FBDefaultRunner
-dialect.preparer = FBIdentifierPreparer
-dialect.execution_ctx_cls = FBExecutionContext
diff --git a/lib/sqlalchemy/dialects/firebird/kinterbasdb.py b/lib/sqlalchemy/dialects/firebird/kinterbasdb.py
new file mode 100644 (file)
index 0000000..e463959
--- /dev/null
@@ -0,0 +1,72 @@
+from sqlalchemy.dialects.firebird.base import FBDialect, FBCompiler
+from sqlalchemy.engine.default import DefaultExecutionContext
+
+_initialized_kb  = False
+
+class Firebird_kinterbasdb(FBDialect):
+    driver = 'kinterbasdb'
+    supports_sane_rowcount = False
+    supports_sane_multi_rowcount = False
+    
+    def __init__(self, type_conv=200, concurrency_level=1, **kwargs):
+        super(Firebird_kinterbasdb, self).__init__(**kwargs)
+
+        self.type_conv = type_conv
+        self.concurrency_level = concurrency_level
+    
+    @classmethod
+    def dbapi(cls):
+        k = __import__('kinterbasdb')
+        return k
+
+    def create_connect_args(self, url):
+        opts = url.translate_connect_args(username='user')
+        if opts.get('port'):
+            opts['host'] = "%s/%s" % (opts['host'], opts['port'])
+            del opts['port']
+        opts.update(url.query)
+
+        type_conv = opts.pop('type_conv', self.type_conv)
+        concurrency_level = opts.pop('concurrency_level', self.concurrency_level)
+        global _initialized_kb
+        if not _initialized_kb and self.dbapi is not None:
+            _initialized_kb = True
+            self.dbapi.init(type_conv=type_conv, concurrency_level=concurrency_level)
+        return ([], opts)
+
+    
+    def _get_server_version_info(self, connection):
+        """Get the version of the Firebird server used by a connection.
+
+        Returns a tuple of (`major`, `minor`, `build`), three integers
+        representing the version of the attached server.
+        """
+
+        # This is the simpler approach (the other uses the services api),
+        # that for backward compatibility reasons returns a string like
+        #   LI-V6.3.3.12981 Firebird 2.0
+        # where the first version is a fake one resembling the old
+        # Interbase signature. This is more than enough for our purposes,
+        # as this is mainly (only?) used by the testsuite.
+
+        from re import match
+
+        fbconn = connection.connection
+        version = fbconn.server_version
+        m = match('\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+) \w+ (\d+)\.(\d+)', version)
+        if not m:
+            raise AssertionError("Could not determine version from string '%s'" % version)
+        return tuple([int(x) for x in m.group(5, 6, 4)])
+
+
+    def is_disconnect(self, e):
+        if isinstance(e, self.dbapi.OperationalError):
+            return 'Unable to complete network request to host' in str(e)
+        elif isinstance(e, self.dbapi.ProgrammingError):
+            msg = str(e)
+            return ('Invalid connection state' in msg or
+                    'Invalid cursor state' in msg)
+        else:
+            return False
+
+dialect = Firebird_kinterbasdb
index 2d8b31bcf2dff87ecd4a7ead81bafddc8e32d0f6..849b72b9793888598c44a95bca44e692dc628fd8 100644 (file)
@@ -1439,10 +1439,6 @@ class MSDialect(default.DefaultDialect):
 
         return fkeys.values()
 
-    def reflecttable(self, connection, table, include_columns):
-
-        insp = reflection.Inspector.from_engine(connection)
-        return insp.reflecttable(table, include_columns)
 
 # fixme.  I added this for the tests to run. -Randall
 MSSQLDialect = MSDialect
index 133d7ef57a5c9283dc55293fed325928481cae67..b6b57c796847561cd377388844451e6add7562a1 100644 (file)
@@ -1940,12 +1940,6 @@ class MySQLDialect(default.DefaultDialect):
             sql = parser._describe_to_create(table_name, columns)
         return parser.parse(sql, charset)
   
-    def reflecttable(self, connection, table, include_columns):
-        """Load column definitions from the server."""
-
-        insp = reflection.Inspector.from_engine(connection)
-        return insp.reflecttable(table, include_columns)
-
     def _adjust_casing(self, table, charset=None):
         """Adjust Table name to the server case sensitivity, if needed."""
 
index a512a2b1b17842c39469dcae18a79e755a947462..35c85c2c95a422df50e7508352445bf3a060fa57 100644 (file)
@@ -458,6 +458,7 @@ class OracleDialect(default.DefaultDialect):
     default_paramstyle = 'named'
     colspecs = colspecs
     ischema_names = ischema_names
+    requires_name_normalize = True
     
     supports_default_values = False
     supports_empty_insert = False
@@ -482,16 +483,16 @@ class OracleDialect(default.DefaultDialect):
     def has_table(self, connection, table_name, schema=None):
         if not schema:
             schema = self.get_default_schema_name(connection)
-        cursor = connection.execute("""select table_name from all_tables where table_name=:name and owner=:schema_name""", {'name':self._denormalize_name(table_name), 'schema_name':self._denormalize_name(schema)})
+        cursor = connection.execute("""select table_name from all_tables where table_name=:name and owner=:schema_name""", {'name':self.denormalize_name(table_name), 'schema_name':self.denormalize_name(schema)})
         return cursor.fetchone() is not None
 
     def has_sequence(self, connection, sequence_name, schema=None):
         if not schema:
             schema = self.get_default_schema_name(connection)
-        cursor = connection.execute("""select sequence_name from all_sequences where sequence_name=:name and sequence_owner=:schema_name""", {'name':self._denormalize_name(sequence_name), 'schema_name':self._denormalize_name(schema)})
+        cursor = connection.execute("""select sequence_name from all_sequences where sequence_name=:name and sequence_owner=:schema_name""", {'name':self.denormalize_name(sequence_name), 'schema_name':self.denormalize_name(schema)})
         return cursor.fetchone() is not None
 
-    def _normalize_name(self, name):
+    def normalize_name(self, name):
         if name is None:
             return None
         elif name.upper() == name and not self.identifier_preparer._requires_quotes(name.lower().decode(self.encoding)):
@@ -499,7 +500,7 @@ class OracleDialect(default.DefaultDialect):
         else:
             return name.decode(self.encoding)
 
-    def _denormalize_name(self, name):
+    def denormalize_name(self, name):
         if name is None:
             return None
         elif name.lower() == name and not self.identifier_preparer._requires_quotes(name.lower()):
@@ -508,7 +509,7 @@ class OracleDialect(default.DefaultDialect):
             return name.encode(self.encoding)
 
     def get_default_schema_name(self, connection):
-        return self._normalize_name(connection.execute('SELECT USER FROM DUAL').scalar())
+        return self.normalize_name(connection.execute('SELECT USER FROM DUAL').scalar())
 
     def table_names(self, connection, schema):
         # note that table_names() isnt loading DBLINKed or synonym'ed tables
@@ -517,8 +518,8 @@ class OracleDialect(default.DefaultDialect):
             cursor = connection.execute(s)
         else:
             s = "select table_name from all_tables where nvl(tablespace_name, 'no tablespace') NOT IN ('SYSTEM','SYSAUX') AND OWNER = :owner"
-            cursor = connection.execute(s, {'owner': self._denormalize_name(schema)})
-        return [self._normalize_name(row[0]) for row in cursor]
+            cursor = connection.execute(s, {'owner': self.denormalize_name(schema)})
+        return [self.normalize_name(row[0]) for row in cursor]
 
     def _resolve_synonym(self, connection, desired_owner=None, desired_synonym=None, desired_table=None):
         """search for a local synonym matching the given desired owner/name.
@@ -567,35 +568,35 @@ class OracleDialect(default.DefaultDialect):
                                  resolve_synonyms=False, dblink='', **kw):
 
         if resolve_synonyms:
-            actual_name, owner, dblink, synonym = self._resolve_synonym(connection, desired_owner=self._denormalize_name(schema), desired_synonym=self._denormalize_name(table_name))
+            actual_name, owner, dblink, synonym = self._resolve_synonym(connection, desired_owner=self.denormalize_name(schema), desired_synonym=self.denormalize_name(table_name))
         else:
             actual_name, owner, dblink, synonym = None, None, None, None
         if not actual_name:
-            actual_name = self._denormalize_name(table_name)
+            actual_name = self.denormalize_name(table_name)
         if not dblink:
             dblink = ''
         if not owner:
-            owner = self._denormalize_name(schema or self.get_default_schema_name(connection))
+            owner = self.denormalize_name(schema or self.get_default_schema_name(connection))
         return (actual_name, owner, dblink, synonym)
 
     @reflection.cache
     def get_schema_names(self, connection, **kw):
         s = "SELECT username FROM all_users ORDER BY username"
         cursor = connection.execute(s,)
-        return [self._normalize_name(row[0]) for row in cursor]
+        return [self.normalize_name(row[0]) for row in cursor]
 
     @reflection.cache
     def get_table_names(self, connection, schema=None, **kw):
-        schema = self._denormalize_name(schema or self.get_default_schema_name(connection))
+        schema = self.denormalize_name(schema or self.get_default_schema_name(connection))
         return self.table_names(connection, schema)
 
     @reflection.cache
     def get_view_names(self, connection, schema=None, **kw):
-        schema = self._denormalize_name(schema or self.get_default_schema_name(connection))
+        schema = self.denormalize_name(schema or self.get_default_schema_name(connection))
         s = "select view_name from all_views where OWNER = :owner"
         cursor = connection.execute(s,
-                {'owner':self._denormalize_name(schema)})
-        return [self._normalize_name(row[0]) for row in cursor]
+                {'owner':self.denormalize_name(schema)})
+        return [self.normalize_name(row[0]) for row in cursor]
 
     @reflection.cache
     def get_columns(self, connection, table_name, schema=None, **kw):
@@ -627,7 +628,7 @@ class OracleDialect(default.DefaultDialect):
         for row in c:
 
             (colname, coltype, length, precision, scale, nullable, default) = \
-                (self._normalize_name(row[0]), row[1], row[2], row[3], row[4], row[5]=='Y', row[6])
+                (self.normalize_name(row[0]), row[1], row[2], row[3], row[4], row[5]=='Y', row[6])
 
             # INTEGER if the scale is 0 and precision is null
             # NUMBER if the scale and precision are both null
@@ -684,8 +685,8 @@ class OracleDialect(default.DefaultDialect):
         ORDER BY a.INDEX_NAME, a.COLUMN_POSITION
         """ % dict(dblink=dblink)
         rp = connection.execute(q,
-            dict(table_name=self._denormalize_name(table_name),
-                 schema=self._denormalize_name(schema)))
+            dict(table_name=self.denormalize_name(table_name),
+                 schema=self.denormalize_name(schema)))
         indexes = []
         last_index_name = None
         pkeys = self.get_primary_keys(connection, table_name, schema,
@@ -698,10 +699,10 @@ class OracleDialect(default.DefaultDialect):
             if rset.column_name in [s.upper() for s in pkeys]:
                 continue
             if rset.index_name != last_index_name:
-                index = dict(name=self._normalize_name(rset.index_name), column_names=[])
+                index = dict(name=self.normalize_name(rset.index_name), column_names=[])
                 indexes.append(index)
             index['unique'] = uniqueness.get(rset.uniqueness, False)
-            index['column_names'].append(self._normalize_name(rset.column_name))
+            index['column_names'].append(self.normalize_name(rset.column_name))
             last_index_name = rset.index_name
         return indexes
 
@@ -763,7 +764,7 @@ class OracleDialect(default.DefaultDialect):
         for row in constraint_data:
             #print "ROW:" , row
             (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = \
-                row[0:2] + tuple([self._normalize_name(x) for x in row[2:6]])
+                row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]])
             if cons_type == 'P':
                 pkeys.append(local_column)
         return pkeys
@@ -807,7 +808,7 @@ class OracleDialect(default.DefaultDialect):
         
         for row in constraint_data:
             (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = \
-                    row[0:2] + tuple([self._normalize_name(x) for x in row[2:6]])
+                    row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]])
 
             if cons_type == 'R':
                 if remote_table is None:
@@ -827,16 +828,16 @@ class OracleDialect(default.DefaultDialect):
                         ref_remote_name, ref_remote_owner, ref_dblink, ref_synonym = \
                                 self._resolve_synonym(
                                     connection, 
-                                    desired_owner=self._denormalize_name(remote_owner), 
-                                    desired_table=self._denormalize_name(remote_table)
+                                    desired_owner=self.denormalize_name(remote_owner), 
+                                    desired_table=self.denormalize_name(remote_table)
                                 )
                         if ref_synonym:
-                            remote_table = self._normalize_name(ref_synonym)
-                            remote_owner = self._normalize_name(ref_remote_owner)
+                            remote_table = self.normalize_name(ref_synonym)
+                            remote_owner = self.normalize_name(ref_remote_owner)
                     
                     rec['referred_table'] = remote_table
                     
-                    if requested_schema is not None or self._denormalize_name(remote_owner) != schema:
+                    if requested_schema is not None or self.denormalize_name(remote_owner) != schema:
                         rec['referred_schema'] = remote_owner
                 
                 local_cols.append(local_column)
@@ -864,9 +865,6 @@ class OracleDialect(default.DefaultDialect):
         else:
             return None
 
-    def reflecttable(self, connection, table, include_columns):
-        insp = reflection.Inspector.from_engine(connection)
-        return insp.reflecttable(table, include_columns)
 
 
 class _OuterJoinColumn(sql.ClauseElement):
index 947ebbac5dffc6e2e366bf316402f12d4a717cbf..dd0bd80495badb73fd3ab23ba398d9ab8d9a8ca3 100644 (file)
@@ -859,10 +859,6 @@ class PGDialect(default.DefaultDialect):
             index_d['unique'] = unique
         return indexes
 
-    def reflecttable(self, connection, table, include_columns):
-        insp = reflection.Inspector.from_engine(connection)
-        return insp.reflecttable(table, include_columns)
-
     def _load_domains(self, connection):
         ## Load data types for domains:
         SQL_DOMAINS = """
index 1b260aa6864f4899a4d746c18f6b54ff65b03438..b644c7bf80b35c4617681bc9356770bbc05cafeb 100644 (file)
@@ -516,35 +516,6 @@ class SQLiteDialect(default.DefaultDialect):
                 cols.append(row[2])
         return indexes
 
-    def get_unique_indexes(self, connection, table_name, schema=None, **kw):
-        quote = self.identifier_preparer.quote_identifier
-        if schema is not None:
-            pragma = "PRAGMA %s." % quote(schema)
-        else:
-            pragma = "PRAGMA "
-        qtable = quote(table_name)
-        c = _pragma_cursor(connection.execute("%sindex_list(%s)" % (pragma, qtable)))
-        unique_indexes = []
-        while True:
-            row = c.fetchone()
-            if row is None:
-                break
-            if (row[2] == 1):
-                unique_indexes.append(row[1])
-        # loop thru unique indexes for one that includes the primary key
-        for idx in unique_indexes:
-            c = _pragma_cursor(connection.execute("%sindex_info(%s)" % (pragma, idx)))
-            cols = []
-            while True:
-                row = c.fetchone()
-                if row is None:
-                    break
-                cols.append(row[2])
-        return unique_indexes
-
-    def reflecttable(self, connection, table, include_columns):
-        insp = reflection.Inspector.from_engine(connection)
-        return insp.reflecttable(table, include_columns)
 
 def _pragma_cursor(cursor):
     """work around SQLite issue whereby cursor.description is blank when PRAGMA returns no rows."""
index a57042796909d47299e093c9a3b56aa614708a7b..f752182c3689aecda54e72454871d85563d71bcc 100644 (file)
@@ -289,6 +289,23 @@ class Dialect(object):
 
         raise NotImplementedError()
 
+    def normalize_name(self, name):
+        """convert the given name to lowercase if it is detected as case insensitive.
+    
+        this method is only used if the dialect defines requires_name_normalize=True.
+
+        """
+        raise NotImplementedError()
+
+    def denormalize_name(self, name):
+        """convert the given name to a case insensitive identifier for the backend 
+        if it is an all-lowercase name.
+        
+        this method is only used if the dialect defines requires_name_normalize=True.
+
+        """
+        raise NotImplementedError()
+        
     def has_table(self, connection, table_name, schema=None):
         """Check the existence of a particular table in the database.
 
@@ -1636,7 +1653,10 @@ class ResultProxy(object):
             if origname:
                 if self._props.setdefault(origname.lower(), rec) is not rec:
                     self._props[origname.lower()] = (type_, self.__ambiguous_processor(origname), 0)
-
+            
+            if self.dialect.requires_name_normalize:
+                colname = self.dialect.normalize_name(colname)
+                
             self.keys.append(colname)
             self._props[i] = rec
             if obj:
index 413de171a20a711a756fab47cd4b83121cbc4545..5a86d7c94b0882a905e809d59f625c914c086cc1 100644 (file)
@@ -13,7 +13,7 @@ as the base class for their own corresponding classes.
 """
 
 import re, random
-from sqlalchemy.engine import base
+from sqlalchemy.engine import base, reflection
 from sqlalchemy.sql import compiler, expression
 from sqlalchemy import exc, types as sqltypes
 
@@ -51,6 +51,14 @@ class DefaultDialect(base.Dialect):
     default_paramstyle = 'named'
     supports_default_values = False
     supports_empty_insert = True
+    
+    # indicates symbol names are 
+    # UPPERCASEd if they are case insensitive
+    # within the database.
+    # if this is True, the methods normalize_name()
+    # and denormalize_name() must be provided.
+    requires_name_normalize = False
+    
     reflection_options = ()
 
     def __init__(self, convert_unicode=False, assert_unicode=False,
@@ -104,9 +112,16 @@ class DefaultDialect(base.Dialect):
         """
         return sqltypes.adapt_type(typeobj, cls.colspecs)
 
+    def reflecttable(self, connection, table, include_columns):
+        insp = reflection.Inspector.from_engine(connection)
+        return insp.reflecttable(table, include_columns)
+
     def validate_identifier(self, ident):
         if len(ident) > self.max_identifier_length:
-            raise exc.IdentifierError("Identifier '%s' exceeds maximum length of %d characters" % (ident, self.max_identifier_length))
+            raise exc.IdentifierError(
+                "Identifier '%s' exceeds maximum length of %d characters" % 
+                (ident, self.max_identifier_length)
+            )
 
     def connect(self, *cargs, **cparams):
         return self.dbapi.connect(*cargs, **cparams)
index 98408ace980c003515c4e4b9b02b14ee869c7620..b6fbb93dbac1a26b3b0579c3ec5431be90fc722f 100644 (file)
@@ -28,6 +28,25 @@ def foreign_keys(fn):
         no_support('sqlite', 'not supported by database'),
         )
 
+
+def unbounded_varchar(fn):
+    """Target database must support VARCHAR with no length"""
+    return _chain_decorators_on(
+        fn,
+        no_support('firebird', 'not supported by database'),
+        no_support('oracle', 'not supported by database'),
+        no_support('mysql', 'not supported by database'),
+    )
+
+def boolean_col_expressions(fn):
+    """Target database must support boolean expressions as columns"""
+    return _chain_decorators_on(
+        fn,
+        no_support('firebird', 'not supported by database'),
+        no_support('oracle', 'not supported by database'),
+        no_support('mssql', 'not supported by database'),
+    )
+    
 def identity(fn):
     """Target database must support GENERATED AS IDENTITY or a facsimile.
 
@@ -90,7 +109,8 @@ def schemas(fn):
     
     return _chain_decorators_on(
         fn,
-        no_support('sqlite', 'no schema support')
+        no_support('sqlite', 'no schema support'),
+        no_support('firebird', 'no schema support')
     )
     
 def sequences(fn):
index 555ffffe0627f96eb4c9ed0ef2b8fd9f2796b8cc..35b4060d2bd7c06c9c93f1372dc846299edcbf1a 100644 (file)
@@ -33,7 +33,7 @@ def Table(*args, **kw):
         # expand to ForeignKeyConstraint too.
         fks = [fk
                for col in args if isinstance(col, schema.Column)
-               for fk in col.args if isinstance(fk, schema.ForeignKey)]
+               for fk in col.foreign_keys]
 
         for fk in fks:
             # root around in raw spec
index 965849df02bf007a464392f250d66baed8238284..8d123f50484b89d2896e6128ab487748b83afb9c 100644 (file)
@@ -923,7 +923,7 @@ class SMALLINT(SmallInteger):
     __visit_name__ = 'SMALLINT'
 
 
-class BIGINT(SmallInteger):
+class BIGINT(BigInteger):
     """The SQL BIGINT type."""
 
     __visit_name__ = 'BIGINT'
index fa608c9a18e5c0761cf582b352ab38dc3b780b22..033522902efd7d84918cc34fa16fc2bde813aaf3 100644 (file)
@@ -198,7 +198,7 @@ class ReturningTest(TestBase, AssertsExecutionResults):
             table.drop()
 
 
-class MiscFBTests(TestBase):
+class MiscTest(TestBase):
     __only_on__ = 'firebird'
 
     def test_strlen(self):
@@ -225,4 +225,12 @@ class MiscFBTests(TestBase):
         version = testing.db.dialect.server_version_info(testing.db.connect())
         assert len(version) == 3, "Got strange version info: %s" % repr(version)
 
+    def test_percents_in_text(self):
+        for expr, result in (
+            (text("select '%' from rdb$database"), '%'),
+            (text("select '%%' from rdb$database"), '%%'),
+            (text("select '%%%' from rdb$database"), '%%%'),
+            (text("select 'hello % world' from rdb$database"), "hello % world")
+        ):
+            eq_(testing.db.scalar(expr), result)
 
index 06451bbc4d83331d25807e9cbc3310db6a2a8787..43b427d3323abeadf50ffa842b22de47415538d8 100644 (file)
@@ -598,7 +598,6 @@ class ReflectionTest(TestBase, ComparesTables):
             m9.reflect()
             self.assert_(not m9.tables)
 
-    @testing.fails_on_everything_except('postgres', 'mysql', 'sqlite', 'oracle', 'mssql')
     def test_index_reflection(self):
         m1 = MetaData(testing.db)
         t1 = Table('party', m1,
@@ -907,7 +906,7 @@ def dropViews(con, schema=None):
 
 class ComponentReflectionTest(TestBase):
 
-    @testing.fails_on('sqlite', 'no schemas')
+    @testing.requires.schemas
     def test_get_schema_names(self):
         meta = MetaData(testing.db)
         insp = Inspector(meta.bind)
@@ -971,23 +970,27 @@ class ComponentReflectionTest(TestBase):
                 # should be in order
                 for (i, col) in enumerate(table.columns):
                     eq_(col.name, cols[i]['name'])
-                    # coltype is tricky
-                    # It may not inherit from col.type while they share
-                    # the same base.
                     ctype = cols[i]['type'].__class__
                     ctype_def = col.type
                     if isinstance(ctype_def, sa.types.TypeEngine):
                         ctype_def = ctype_def.__class__
+                        
                     # Oracle returns Date for DateTime.
                     if testing.against('oracle') \
                         and ctype_def in (sql_types.Date, sql_types.DateTime):
                             ctype_def = sql_types.Date
+                    
+                    # assert that the desired type and return type
+                    # share a base within one of the generic types.
                     self.assert_(
-                        issubclass(ctype, ctype_def) or \
                         len(
                             set(
-                                ctype.__bases__
-                            ).intersection(ctype_def.__bases__)) > 0
+                                ctype.__mro__
+                            ).intersection(ctype_def.__mro__)
+                            .intersection([sql_types.Integer, sql_types.Numeric, 
+                                            sql_types.DateTime, sql_types.Date, sql_types.Time, 
+                                            sql_types.String, sql_types.Binary])
+                            ) > 0
                     ,("%s(%s), %s(%s)" % (col.name, col.type, cols[i]['name'],
                                           ctype)))
         finally:
index f22a5c22a3201eb0657207b6360ce617ff674916..f679277049c56a98e870ef800b9166b156be8efc 100644 (file)
@@ -149,7 +149,7 @@ class QueryTest(TestBase):
             l.append(row)
         self.assert_(len(l) == 3)
 
-    @testing.fails_on('firebird', 'Data type unknown')
+    @testing.fails_on('firebird', "kinterbasdb doesn't send full type information")
     @testing.requires.subqueries
     def test_anonymous_rows(self):
         users.insert().execute(
@@ -163,6 +163,7 @@ class QueryTest(TestBase):
             assert row['anon_1'] == 8
             assert row['anon_2'] == 10
 
+    @testing.fails_on('firebird', "kinterbasdb doesn't send full type information")
     def test_order_by_label(self):
         """test that a label within an ORDER BY works on each backend.
         
@@ -181,6 +182,11 @@ class QueryTest(TestBase):
             select([concat]).order_by(concat).execute().fetchall(),
             [("test: ed",), ("test: fred",), ("test: jack",)]
         )
+        
+        eq_(
+            select([concat]).order_by(concat).execute().fetchall(),
+            [("test: ed",), ("test: fred",), ("test: jack",)]
+        )
 
         concat = ("test: " + users.c.user_name).label('thedata')
         eq_(
@@ -209,8 +215,7 @@ class QueryTest(TestBase):
         self.assert_(not (rp != equal))
         self.assert_(not (equal != equal))
 
-    @testing.fails_on('mssql', 'No support for boolean logic in column select.')
-    @testing.fails_on('oracle', 'FIXME: unknown')
+    @testing.requires.boolean_col_expressions
     def test_or_and_as_columns(self):
         true, false = literal(True), literal(False)
         
@@ -255,6 +260,7 @@ class QueryTest(TestBase):
             eq_(expr.execute().fetchall(), result)
     
 
+    @testing.fails_on("firebird", "see dialect.test_firebird:MiscTest.test_percents_in_text")
     @testing.fails_on("oracle", "neither % nor %% are accepted")
     @testing.fails_on("+pg8000", "can't interpret result column from '%%'")
     @testing.emits_warning('.*now automatically escapes.*')
@@ -475,15 +481,25 @@ class QueryTest(TestBase):
         self.assert_(r['query_users.user_id']) == 1
         self.assert_(r['query_users.user_name']) == "john"
 
-    @testing.fails_on('oracle', 'oracle result keys() are all uppercase, not getting into this.')
+    def test_result_case_sensitivity(self):
+        """test name normalization for result sets."""
+        
+        row = testing.db.execute(
+            select([
+                literal_column("1").label("case_insensitive"),
+                literal_column("2").label("CaseSensitive")
+            ])
+        ).first()
+        
+        assert row.keys() == ["case_insensitive", "CaseSensitive"]
+
     def test_row_as_args(self):
         users.insert().execute(user_id=1, user_name='john')
         r = users.select(users.c.user_id==1).execute().first()
         users.delete().execute()
         users.insert().execute(r)
         eq_(users.select().execute().fetchall(), [(1, 'john')])
-    
-    @testing.fails_on('oracle', 'oracle result keys() are all uppercase, not getting into this.')
+
     def test_result_as_args(self):
         users.insert().execute([dict(user_id=1, user_name='john'), dict(user_id=2, user_name='ed')])
         r = users.select().execute()
@@ -620,13 +636,13 @@ class QueryTest(TestBase):
         # Null values are not outside any set
         assert len(r) == 0
 
-        u = bindparam('search_key')
+    @testing.fails_on('firebird', "kinterbasdb doesn't send full type information")
+    def test_bind_in(self):
+        users.insert().execute(user_id = 7, user_name = 'jack')
+        users.insert().execute(user_id = 8, user_name = 'fred')
+        users.insert().execute(user_id = 9, user_name = None)
 
-        s = users.select(u.in_([]))
-        r = s.execute(search_key='john').fetchall()
-        assert len(r) == 0
-        r = s.execute(search_key=None).fetchall()
-        assert len(r) == 0
+        u = bindparam('search_key')
 
         s = users.select(not_(u.in_([])))
         r = s.execute(search_key='john').fetchall()
@@ -881,6 +897,7 @@ class CompoundTest(TestBase):
         found2 = self._fetchall_sorted(u.alias('bar').select().execute())
         eq_(found2, wanted)
 
+    @testing.fails_on('firebird', "doesn't like ORDER BY with UNIONs")
     def test_union_ordered(self):
         (s1, s2) = (
             select([t1.c.col3.label('col3'), t1.c.col4.label('col4')],
@@ -894,6 +911,7 @@ class CompoundTest(TestBase):
                   ('ccc', 'aaa')]
         eq_(u.execute().fetchall(), wanted)
 
+    @testing.fails_on('firebird', "doesn't like ORDER BY with UNIONs")
     @testing.fails_on('maxdb', 'FIXME: unknown')
     @testing.requires.subqueries
     def test_union_ordered_alias(self):
@@ -910,6 +928,7 @@ class CompoundTest(TestBase):
         eq_(u.alias('bar').select().execute().fetchall(), wanted)
 
     @testing.crashes('oracle', 'FIXME: unknown, verify not fails_on')
+    @testing.fails_on('firebird', "has trouble extracting anonymous column from union subquery")
     @testing.fails_on('mysql', 'FIXME: unknown')
     @testing.fails_on('sqlite', 'FIXME: unknown')
     def test_union_all(self):
@@ -928,6 +947,29 @@ class CompoundTest(TestBase):
         found2 = self._fetchall_sorted(e.alias('foo').select().execute())
         eq_(found2, wanted)
 
+    def test_union_all_lightweight(self):
+        """like test_union_all, but breaks the sub-union into 
+        a subquery with an explicit column reference on the outside,
+        more palatable to a wider variety of engines.
+        
+        """
+        u = union(
+            select([t1.c.col3]),
+            select([t1.c.col3]),
+        ).alias()
+        
+        e = union_all(
+            select([t1.c.col3]),
+            select([u.c.col3])
+        )
+
+        wanted = [('aaa',),('aaa',),('bbb',), ('bbb',), ('ccc',),('ccc',)]
+        found1 = self._fetchall_sorted(e.execute())
+        eq_(found1, wanted)
+
+        found2 = self._fetchall_sorted(e.alias('foo').select().execute())
+        eq_(found2, wanted)
+
     @testing.crashes('firebird', 'Does not support intersect')
     @testing.crashes('sybase', 'FIXME: unknown, verify not fails_on')
     @testing.fails_on('mysql', 'FIXME: unknown')
index 821a386ffeccd3c462bff085dfdda73b36248b83..512ef2e7f988c879c6562fe352bffa4029df3e4a 100644 (file)
@@ -623,8 +623,8 @@ class DateTest(TestBase, AssertsExecutionResults):
             t.drop(checkfirst=True)
 
 class StringTest(TestBase, AssertsExecutionResults):
-    @testing.fails_on('mysql', 'FIXME: unknown')
-    @testing.fails_on('oracle', 'FIXME: unknown')
+
+    @testing.requires.unbounded_varchar
     def test_nolength_string(self):
         metadata = MetaData(testing.db)
         foo = Table('foo', metadata, Column('one', String))