]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- oracle support, includes fix for #994
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 17 Jan 2009 01:16:51 +0000 (01:16 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 17 Jan 2009 01:16:51 +0000 (01:16 +0000)
16 files changed:
doc/build/reference/dialects/oracle.rst
lib/sqlalchemy/dialects/__init__.py
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/mysql/mysqldb.py
lib/sqlalchemy/dialects/oracle/__init__.py [new file with mode: 0644]
lib/sqlalchemy/dialects/oracle/base.py [moved from lib/sqlalchemy/databases/oracle.py with 65% similarity]
lib/sqlalchemy/dialects/oracle/cx_oracle.py [new file with mode: 0644]
lib/sqlalchemy/dialects/postgres/base.py
lib/sqlalchemy/dialects/postgres/psycopg2.py
lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/types.py
test/dialect/oracle.py
test/engine/reflection.py
test/sql/query.py

index 188f6f438314a4e0e43e57ac3665c4057a554f19..68be06b64c74df70cf0be69e85eb99f096c8d929 100644 (file)
@@ -1,4 +1,10 @@
 Oracle
 ======
 
-.. automodule:: sqlalchemy.databases.oracle
+.. automodule:: sqlalchemy.dialects.oracle.base
+
+cx_Oracle Notes
+===============
+
+.. automodule:: sqlalchemy.dialects.oracle.cx_oracle
+
index 33e481d25cc8f17cbb6dab595bc761342a1c76d5..8526b4d8fbd100902c12c0080c50e73724b5e1f1 100644 (file)
@@ -5,7 +5,7 @@ __all__ = (
 #    'maxdb',
 #    'mssql',
     'mysql',
-#    'oracle',
+    'oracle',
     'postgres',
     'sqlite',
 #    'sybase',
index 3c66945e80afa37c68219c144096f06c270f09dd..74938abe0d5e28e7387abfff2f8042c9e5dd225e 100644 (file)
@@ -1689,7 +1689,8 @@ class MySQLDialect(default.DefaultDialect):
     max_identifier_length = 255
     supports_sane_rowcount = True
     default_paramstyle = 'format'
-
+    colspecs = colspecs
+    
     statement_compiler = MySQLCompiler
     ddl_compiler = MySQLDDLCompiler
     type_compiler = MySQLTypeCompiler
@@ -1699,9 +1700,6 @@ class MySQLDialect(default.DefaultDialect):
         self.use_ansiquotes = use_ansiquotes
         default.DefaultDialect.__init__(self, **kwargs)
 
-    def type_descriptor(self, typeobj):
-        return sqltypes.adapt_type(typeobj, colspecs)
-
     def do_executemany(self, cursor, statement, parameters, context=None):
         rowcount = cursor.executemany(statement, parameters)
         if context is not None:
index 61f9d3f6719d3015eea897ae1ce362fbf20168d4..b077774ea51d2c6d8987a1e876d2b292bc9c756a 100644 (file)
@@ -40,8 +40,6 @@ class MySQL_mysqldbCompiler(MySQLCompiler):
     )
     
     def post_process_text(self, text):
-        if '%%' in text:
-            util.warn("The SQLAlchemy mysql+mysqldb dialect now automatically escapes '%' in text() expressions to '%%'.")
         return text.replace('%', '%%')
     
 class MySQL_mysqldb(MySQLDialect):
diff --git a/lib/sqlalchemy/dialects/oracle/__init__.py b/lib/sqlalchemy/dialects/oracle/__init__.py
new file mode 100644 (file)
index 0000000..7038fb3
--- /dev/null
@@ -0,0 +1,3 @@
+from sqlalchemy.dialects.oracle import base, cx_oracle
+
+base.dialect = cx_oracle.dialect
\ No newline at end of file
similarity index 65%
rename from lib/sqlalchemy/databases/oracle.py
rename to lib/sqlalchemy/dialects/oracle/base.py
index b0ec6115b2b6a768811c5038f8f53f37dea20b5d..9bf6db23d60007a5dfd6e1f44769398592beaa83 100644 (file)
@@ -1,4 +1,4 @@
-# oracle.py
+# oracle/base.py
 # Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com
 #
 # This module is part of SQLAlchemy and is released under
@@ -7,37 +7,14 @@
 
 Oracle version 8 through current (11g at the time of this writing) are supported.
 
-Driver
-------
+For information on connecting via specific drivers, see the documentation
+for that driver.
 
-The Oracle dialect uses the cx_oracle driver, available at 
-http://cx-oracle.sourceforge.net/ .   The dialect has several behaviors 
-which are specifically tailored towards compatibility with this module.
+Connect Arguments
+-----------------
 
-Connecting
-----------
-
-Connecting with create_engine() uses the standard URL approach of 
-``oracle://user:pass@host:port/dbname[?key=value&key=value...]``.  If dbname is present, the 
-host, port, and dbname tokens are converted to a TNS name using the cx_oracle 
-:func:`makedsn()` function.  Otherwise, the host token is taken directly as a TNS name.
-
-Additional arguments which may be specified either as query string arguments on the
-URL, or as keyword arguments to :func:`~sqlalchemy.create_engine()` are:
-
-* *allow_twophase* - enable two-phase transactions.  Defaults to ``True``.
-
-* *auto_convert_lobs* - defaults to True, see the section on LOB objects.
-
-* *auto_setinputsizes* - the cx_oracle.setinputsizes() call is issued for all bind parameters.
-  This is required for LOB datatypes but can be disabled to reduce overhead.  Defaults
-  to ``True``.
-
-* *mode* - This is given the string value of SYSDBA or SYSOPER, or alternatively an
-  integer value.  This value is only available as a URL query string argument.
-
-* *threaded* - enable multithreaded access to cx_oracle connections.  Defaults
-  to ``True``.  Note that this is the opposite default of cx_oracle itself.
+The dialect supports several :func:`~sqlalchemy.create_engine()` arguments which 
+affect the behavior of the dialect regardless of driver in use.
 
 * *use_ansi* - Use ANSI JOIN constructs (see the section on Oracle 8).  Defaults
   to ``True``.  If ``False``, Oracle-8 compatible constructs are used for joins.
@@ -67,28 +44,6 @@ This step is also required when using table reflection, i.e. autoload=True::
         autoload=True
   ) 
 
-LOB Objects
------------
-
-cx_oracle presents some challenges when fetching LOB objects.  A LOB object in a result set
-is presented by cx_oracle as a cx_oracle.LOB object which has a read() method.  By default, 
-SQLAlchemy converts these LOB objects into Python strings.  This is for two reasons.  First,
-the LOB object requires an active cursor association, meaning if you were to fetch many rows
-at once such that cx_oracle had to go back to the database and fetch a new batch of rows,
-the LOB objects in the already-fetched rows are now unreadable and will raise an error. 
-SQLA "pre-reads" all LOBs so that their data is fetched before further rows are read.  
-The size of a "batch of rows" is controlled by the cursor.arraysize value, which SQLAlchemy
-defaults to 50 (cx_oracle normally defaults this to one).  
-
-Secondly, the LOB object is not a standard DBAPI return value so SQLAlchemy seeks to 
-"normalize" the results to look more like other DBAPIs.
-
-The conversion of LOB objects by this dialect is unique in SQLAlchemy in that it takes place
-for all statement executions, even plain string-based statements for which SQLA has no awareness
-of result typing.  This is so that calls like fetchmany() and fetchall() can work in all cases
-without raising cursor errors.  The conversion of LOB in all cases, as well as the "prefetch"
-of LOB objects, can be disabled using auto_convert_lobs=False.  
-
 LIMIT/OFFSET Support
 --------------------
 
@@ -100,12 +55,6 @@ http://www.oracle.com/technology/oramag/oracle/06-sep/o56asktom.html .  Note tha
 this was stepping into the bounds of optimization that is better left on the DBA side, but this
 prefix can be added by enabling the optimize_limits=True flag on create_engine().
 
-Two Phase Transaction Support
------------------------------
-
-Two Phase transactions are implemented using XA transactions.  Success has been reported of them
-working successfully but this should be regarded as an experimental feature.
-
 Oracle 8 Compatibility
 ----------------------
 
@@ -127,29 +76,13 @@ import datetime, random, re
 
 from sqlalchemy import util, sql, schema, log
 from sqlalchemy.engine import default, base
-from sqlalchemy.sql import compiler, visitors
+from sqlalchemy.sql import compiler, visitors, expression
 from sqlalchemy.sql import operators as sql_operators, functions as sql_functions
 from sqlalchemy import types as sqltypes
 
-
-class OracleNumeric(sqltypes.Numeric):
-    def get_col_spec(self):
-        if self.precision is None:
-            return "NUMERIC"
-        else:
-            return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale}
-
-class OracleInteger(sqltypes.Integer):
-    def get_col_spec(self):
-        return "INTEGER"
-
-class OracleSmallInteger(sqltypes.SmallInteger):
-    def get_col_spec(self):
-        return "SMALLINT"
+RESERVED_WORDS = set('''SHARE RAW DROP BETWEEN FROM DESC OPTION PRIOR LONG THEN DEFAULT ALTER IS INTO MINUS INTEGER NUMBER GRANT IDENTIFIED ALL TO ORDER ON FLOAT DATE HAVING CLUSTER NOWAIT RESOURCE ANY TABLE INDEX FOR UPDATE WHERE CHECK SMALLINT WITH DELETE BY ASC REVOKE LIKE SIZE RENAME NOCOMPRESS NULL GROUP VALUES AS IN VIEW EXCLUSIVE COMPRESS SYNONYM SELECT INSERT EXISTS NOT TRIGGER ELSE CREATE INTERSECT PCTFREE DISTINCT CONNECT SET MODE OF UNIQUE VARCHAR2 VARCHAR LOCK OR CHAR DECIMAL UNION PUBLIC AND START'''.split()) 
 
 class OracleDate(sqltypes.Date):
-    def get_col_spec(self):
-        return "DATE"
     def bind_processor(self, dialect):
         return None
 
@@ -162,9 +95,6 @@ class OracleDate(sqltypes.Date):
         return process
 
 class OracleDateTime(sqltypes.DateTime):
-    def get_col_spec(self):
-        return "DATE"
-
     def result_processor(self, dialect):
         def process(value):
             if value is None or isinstance(value, datetime.datetime):
@@ -182,12 +112,6 @@ class OracleDateTime(sqltypes.DateTime):
 
 # only if cx_oracle contains TIMESTAMP
 class OracleTimestamp(sqltypes.TIMESTAMP):
-    def get_col_spec(self):
-        return "TIMESTAMP"
-
-    def get_dbapi_type(self, dialect):
-        return dialect.TIMESTAMP
-
     def result_processor(self, dialect):
         def process(value):
             if value is None or isinstance(value, datetime.datetime):
@@ -198,21 +122,10 @@ class OracleTimestamp(sqltypes.TIMESTAMP):
                     value.day,value.hour, value.minute, value.second)
         return process
 
-class OracleString(sqltypes.String):
-    def get_col_spec(self):
-        return "VARCHAR(%(length)s)" % {'length' : self.length}
-
-class OracleNVarchar(sqltypes.Unicode, OracleString):
-    def get_col_spec(self):
-        return "NVARCHAR2(%(length)s)" % {'length' : self.length}
-
 class OracleText(sqltypes.Text):
     def get_dbapi_type(self, dbapi):
         return dbapi.CLOB
 
-    def get_col_spec(self):
-        return "CLOB"
-
     def result_processor(self, dialect):
         super_process = super(OracleText, self).result_processor(dialect)
         if not dialect.auto_convert_lobs:
@@ -232,17 +145,10 @@ class OracleText(sqltypes.Text):
         return process
 
 
-class OracleChar(sqltypes.CHAR):
-    def get_col_spec(self):
-        return "CHAR(%(length)s)" % {'length' : self.length}
-
 class OracleBinary(sqltypes.Binary):
     def get_dbapi_type(self, dbapi):
         return dbapi.BLOB
 
-    def get_col_spec(self):
-        return "BLOB"
-
     def bind_processor(self, dialect):
         return None
 
@@ -262,9 +168,6 @@ class OracleRaw(OracleBinary):
         return "RAW(%(length)s)" % {'length' : self.length}
 
 class OracleBoolean(sqltypes.Boolean):
-    def get_col_spec(self):
-        return "SMALLINT"
-
     def result_processor(self, dialect):
         def process(value):
             if value is None:
@@ -285,200 +188,297 @@ class OracleBoolean(sqltypes.Boolean):
         return process
 
 colspecs = {
-    sqltypes.Integer : OracleInteger,
-    sqltypes.SmallInteger : OracleSmallInteger,
-    sqltypes.Numeric : OracleNumeric,
-    sqltypes.Float : OracleNumeric,
     sqltypes.DateTime : OracleDateTime,
     sqltypes.Date : OracleDate,
-    sqltypes.String : OracleString,
     sqltypes.Binary : OracleBinary,
     sqltypes.Boolean : OracleBoolean,
     sqltypes.Text : OracleText,
     sqltypes.TIMESTAMP : OracleTimestamp,
-    sqltypes.CHAR: OracleChar,
 }
 
 ischema_names = {
-    'VARCHAR2' : OracleString,
-    'NVARCHAR2' : OracleNVarchar,
-    'CHAR' : OracleString,
-    'DATE' : OracleDateTime,
-    'DATETIME' : OracleDateTime,
-    'NUMBER' : OracleNumeric,
-    'BLOB' : OracleBinary,
-    'BFILE' : OracleBinary,
-    'CLOB' : OracleText,
-    'TIMESTAMP' : OracleTimestamp,
+    'VARCHAR2' : sqltypes.VARCHAR,
+    'NVARCHAR2' : sqltypes.NVARCHAR,
+    'CHAR' : sqltypes.CHAR,
+    'DATE' : sqltypes.DATE,
+    'DATETIME' : sqltypes.DATETIME,
+    'NUMBER' : sqltypes.Numeric,
+    'BLOB' : sqltypes.BLOB,
+    'BFILE' : sqltypes.Binary,
+    'CLOB' : sqltypes.CLOB,
+    'TIMESTAMP' : sqltypes.TIMESTAMP,
     'RAW' : OracleRaw,
-    'FLOAT' : OracleNumeric,
-    'DOUBLE PRECISION' : OracleNumeric,
-    'LONG' : OracleText,
+    'FLOAT' : sqltypes.Float,
+    'DOUBLE PRECISION' : sqltypes.Numeric,
+    'LONG' : sqltypes.Text,
 }
 
-class OracleExecutionContext(default.DefaultExecutionContext):
-    def pre_exec(self):
-        super(OracleExecutionContext, self).pre_exec()
-        if self.dialect.auto_setinputsizes:
-            self.set_input_sizes()
-        if self.compiled_parameters is not None and len(self.compiled_parameters) == 1:
-            for key in self.compiled.binds:
-                bindparam = self.compiled.binds[key]
-                name = self.compiled.bind_names[bindparam]
-                value = self.compiled_parameters[0][name]
-                if bindparam.isoutparam:
-                    dbtype = bindparam.type.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
-                    if not hasattr(self, 'out_parameters'):
-                        self.out_parameters = {}
-                    self.out_parameters[name] = self.cursor.var(dbtype)
-                    self.parameters[0][name] = self.out_parameters[name]
-
-    def create_cursor(self):
-        c = self._connection.connection.cursor()
-        if self.dialect.arraysize:
-            c.cursor.arraysize = self.dialect.arraysize
-        return c
-
-    def get_result_proxy(self):
-        if hasattr(self, 'out_parameters'):
-            if self.compiled_parameters is not None and len(self.compiled_parameters) == 1:
-                for bind, name in self.compiled.bind_names.iteritems():
-                    if name in self.out_parameters:
-                        type = bind.type
-                        result_processor = type.dialect_impl(self.dialect).result_processor(self.dialect)
-                        if result_processor is not None:
-                            self.out_parameters[name] = result_processor(self.out_parameters[name].getvalue())
-                        else:
-                            self.out_parameters[name] = self.out_parameters[name].getvalue()
-            else:
-                for k in self.out_parameters:
-                    self.out_parameters[k] = self.out_parameters[k].getvalue()
 
-        if self.cursor.description is not None:
-            for column in self.cursor.description:
-                type_code = column[1]
-                if type_code in self.dialect.ORACLE_BINARY_TYPES:
-                    return base.BufferedColumnResultProxy(self)
+class OracleTypeCompiler(compiler.GenericTypeCompiler):
+    # Note:
+    # Oracle DATE == DATETIME
+    # Oracle does not allow milliseconds in DATE
+    # Oracle does not support TIME columns
+    
+    def visit_DATETIME(self, type_):
+        return self.visit_DATE(type_)
+        
+    def visit_VARCHAR(self, type_):
+        return "VARCHAR(%(length)s)" % {'length' : type_.length}
 
-        return base.ResultProxy(self)
+    def visit_NVARCHAR(self, type_):
+        return "NVARCHAR2(%(length)s)" % {'length' : type_.length}
+    
+    def visit_TEXT(self, type_):
+        return self.visit_CLOB(type_)
 
-class OracleDialect(default.DefaultDialect):
-    name = 'oracle'
-    supports_alter = True
-    supports_unicode_statements = False
-    max_identifier_length = 30
-    supports_sane_rowcount = True
-    supports_sane_multi_rowcount = False
-    preexecute_pk_sequences = True
-    supports_pk_autoincrement = False
-    default_paramstyle = 'named'
+    def visit_BINARY(self, type_):
+        return self.visit_BLOB(type_)
+    
+    def visit_BOOLEAN(self, type_):
+        return self.visit_SMALLINT(type_)
+    
+    def visit_RAW(self, type_):
+        return "RAW(%(length)s)" % {'length' : type_.length}
 
-    def __init__(self, use_ansi=True, auto_setinputsizes=True, auto_convert_lobs=True, threaded=True, allow_twophase=True, optimize_limits=False, arraysize=50, **kwargs):
-        default.DefaultDialect.__init__(self, **kwargs)
-        self.use_ansi = use_ansi
-        self.threaded = threaded
-        self.arraysize = arraysize
-        self.allow_twophase = allow_twophase
-        self.optimize_limits = optimize_limits
-        self.supports_timestamp = self.dbapi is None or hasattr(self.dbapi, 'TIMESTAMP' )
-        self.auto_setinputsizes = auto_setinputsizes
-        self.auto_convert_lobs = auto_convert_lobs
-        if self.dbapi is None or not self.auto_convert_lobs or not 'CLOB' in self.dbapi.__dict__:
-            self.dbapi_type_map = {}
-            self.ORACLE_BINARY_TYPES = []
+class OracleCompiler(compiler.SQLCompiler):
+    """Oracle compiler modifies the lexical structure of Select
+    statements to work under non-ANSI configured Oracle databases, if
+    the use_ansi flag is False.
+    """
+
+    operators = util.update_copy(
+        compiler.SQLCompiler.operators,
+        {
+            sql_operators.mod : lambda x, y:"mod(%s, %s)" % (x, y),
+            sql_operators.match_op: lambda x, y: "CONTAINS (%s, %s)" % (x, y)
+        }
+    )
+
+    functions = util.update_copy(
+        compiler.SQLCompiler.functions,
+        {
+            sql_functions.now : 'CURRENT_TIMESTAMP'
+        }
+    )
+
+    def __init__(self, *args, **kwargs):
+        super(OracleCompiler, self).__init__(*args, **kwargs)
+        self.__wheres = {}
+        self._quoted_bind_names = {}
+
+    def bindparam_string(self, name):
+        if self.preparer._bindparam_requires_quotes(name):
+            quoted_name = '"%s"' % name
+            self._quoted_bind_names[name] = quoted_name
+            return compiler.SQLCompiler.bindparam_string(self, quoted_name)
         else:
-            # only use this for LOB objects.  using it for strings, dates
-            # etc. leads to a little too much magic, reflection doesn't know if it should
-            # expect encoded strings or unicodes, etc.
-            self.dbapi_type_map = {
-                self.dbapi.CLOB: OracleText(),
-                self.dbapi.BLOB: OracleBinary(),
-                self.dbapi.BINARY: OracleRaw(),
-            }
-            self.ORACLE_BINARY_TYPES = [getattr(self.dbapi, k) for k in ["BFILE", "CLOB", "NCLOB", "BLOB"] if hasattr(self.dbapi, k)]
-
-    def dbapi(cls):
-        import cx_Oracle
-        return cx_Oracle
-    dbapi = classmethod(dbapi)
-
-    def create_connect_args(self, url):
-        dialect_opts = dict(url.query)
-        for opt in ('use_ansi', 'auto_setinputsizes', 'auto_convert_lobs',
-                    'threaded', 'allow_twophase'):
-            if opt in dialect_opts:
-                util.coerce_kw_type(dialect_opts, opt, bool)
-                setattr(self, opt, dialect_opts[opt])
-
-        if url.database:
-            # if we have a database, then we have a remote host
-            port = url.port
-            if port:
-                port = int(port)
+            return compiler.SQLCompiler.bindparam_string(self, name)
+
+    def default_from(self):
+        """Called when a ``SELECT`` statement has no froms, and no ``FROM`` clause is to be appended.
+
+        The Oracle compiler tacks a "FROM DUAL" to the statement.
+        """
+
+        return " FROM DUAL"
+
+    def apply_function_parens(self, func):
+        return len(func.clauses) > 0
+
+    def visit_join(self, join, **kwargs):
+        if self.dialect.use_ansi:
+            return compiler.SQLCompiler.visit_join(self, join, **kwargs)
+        else:
+            return self.process(join.left, asfrom=True) + ", " + self.process(join.right, asfrom=True)
+
+    def _get_nonansi_join_whereclause(self, froms):
+        clauses = []
+
+        def visit_join(join):
+            if join.isouter:
+                def visit_binary(binary):
+                    if binary.operator == sql_operators.eq:
+                        if binary.left.table is join.right:
+                            binary.left = _OuterJoinColumn(binary.left)
+                        elif binary.right.table is join.right:
+                            binary.right = _OuterJoinColumn(binary.right)
+                clauses.append(visitors.cloned_traverse(join.onclause, {}, {'binary':visit_binary}))
             else:
-                port = 1521
-            dsn = self.dbapi.makedsn(url.host, port, url.database)
+                clauses.append(join.onclause)
+
+        for f in froms:
+            visitors.traverse(f, {}, {'join':visit_join})
+        return sql.and_(*clauses)
+
+    def visit_outer_join_column(self, vc):
+        return self.process(vc.column) + "(+)"
+
+    def visit_sequence(self, seq):
+        return self.dialect.identifier_preparer.format_sequence(seq) + ".nextval"
+
+    def visit_alias(self, alias, asfrom=False, **kwargs):
+        """Oracle doesn't like ``FROM table AS alias``.  Is the AS standard SQL??"""
+
+        if asfrom:
+            return self.process(alias.original, asfrom=asfrom, **kwargs) + " " + self.preparer.format_alias(alias, self._anonymize(alias.name))
         else:
-            # we have a local tnsname
-            dsn = url.host
-
-        opts = dict(
-            user=url.username,
-            password=url.password,
-            dsn=dsn,
-            threaded=self.threaded,
-            twophase=self.allow_twophase,
-            )
-        if 'mode' in url.query:
-            opts['mode'] = url.query['mode']
-            if isinstance(opts['mode'], basestring):
-                mode = opts['mode'].upper()
-                if mode == 'SYSDBA':
-                    opts['mode'] = self.dbapi.SYSDBA
-                elif mode == 'SYSOPER':
-                    opts['mode'] = self.dbapi.SYSOPER
+            return self.process(alias.original, **kwargs)
+
+    def _TODO_visit_compound_select(self, select):
+        """Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle."""
+        pass
+
+    def visit_select(self, select, **kwargs):
+        """Look for ``LIMIT`` and OFFSET in a select statement, and if
+        so tries to wrap it in a subquery with ``rownum`` criterion.
+        """
+
+        if not getattr(select, '_oracle_visit', None):
+            if not self.dialect.use_ansi:
+                if self.stack and 'from' in self.stack[-1]:
+                    existingfroms = self.stack[-1]['from']
                 else:
-                    util.coerce_kw_type(opts, 'mode', int)
-        # Can't set 'handle' or 'pool' via URL query args, use connect_args
+                    existingfroms = None
 
-        return ([], opts)
+                froms = select._get_display_froms(existingfroms)
+                whereclause = self._get_nonansi_join_whereclause(froms)
+                if whereclause:
+                    select = select.where(whereclause)
+                    select._oracle_visit = True
 
-    def is_disconnect(self, e):
-        if isinstance(e, self.dbapi.InterfaceError):
-            return "not connected" in str(e)
-        else:
-            return "ORA-03114" in str(e) or "ORA-03113" in str(e)
+            if select._limit is not None or select._offset is not None:
+                # See http://www.oracle.com/technology/oramag/oracle/06-sep/o56asktom.html
+                #
+                # Generalized form of an Oracle pagination query:
+                #   select ... from (
+                #     select /*+ FIRST_ROWS(N) */ ...., rownum as ora_rn from (
+                #         select distinct ... where ... order by ...
+                #     ) where ROWNUM <= :limit+:offset
+                #   ) where ora_rn > :offset
+                # Outer select and "ROWNUM as ora_rn" can be dropped if limit=0
 
-    def type_descriptor(self, typeobj):
-        return sqltypes.adapt_type(typeobj, colspecs)
+                # TODO: use annotations instead of clone + attr set ?
+                select = select._generate()
+                select._oracle_visit = True
 
-    def create_xid(self):
-        """create a two-phase transaction ID.
+                # Wrap the middle select and add the hint
+                limitselect = sql.select([c for c in select.c])
+                if select._limit and self.dialect.optimize_limits:
+                    limitselect = limitselect.prefix_with("/*+ FIRST_ROWS(%d) */" % select._limit)
 
-        this id will be passed to do_begin_twophase(), do_rollback_twophase(),
-        do_commit_twophase().  its format is unspecified."""
+                limitselect._oracle_visit = True
+                limitselect._is_wrapper = True
 
-        id = random.randint(0, 2 ** 128)
-        return (0x1234, "%032x" % id, "%032x" % 9)
-        
-    def do_release_savepoint(self, connection, name):
-        # Oracle does not support RELEASE SAVEPOINT
-        pass
+                # If needed, add the limiting clause
+                if select._limit is not None:
+                    max_row = select._limit
+                    if select._offset is not None:
+                        max_row += select._offset
+                    limitselect.append_whereclause(
+                            sql.literal_column("ROWNUM")<=max_row)
 
-    def do_begin_twophase(self, connection, xid):
-        connection.connection.begin(*xid)
+                # If needed, add the ora_rn, and wrap again with offset.
+                if select._offset is None:
+                    select = limitselect
+                else:
+                     limitselect = limitselect.column(
+                             sql.literal_column("ROWNUM").label("ora_rn"))
+                     limitselect._oracle_visit = True
+                     limitselect._is_wrapper = True
 
-    def do_prepare_twophase(self, connection, xid):
-        connection.connection.prepare()
+                     offsetselect = sql.select(
+                             [c for c in limitselect.c if c.key!='ora_rn'])
+                     offsetselect._oracle_visit = True
+                     offsetselect._is_wrapper = True
 
-    def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False):
-        self.do_rollback(connection.connection)
+                     offsetselect.append_whereclause(
+                             sql.literal_column("ora_rn")>select._offset)
 
-    def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False):
-        self.do_commit(connection.connection)
+                     select = offsetselect
 
-    def do_recover_twophase(self, connection):
-        pass
+        kwargs['iswrapper'] = getattr(select, '_is_wrapper', False)
+        return compiler.SQLCompiler.visit_select(self, select, **kwargs)
+
+    def limit_clause(self, select):
+        return ""
+
+    def for_update_clause(self, select):
+        if select.for_update == "nowait":
+            return " FOR UPDATE NOWAIT"
+        else:
+            return super(OracleCompiler, self).for_update_clause(select)
+
+class OracleDDLCompiler(compiler.DDLCompiler):
+    def get_column_specification(self, column, **kwargs):
+        colspec = self.preparer.format_column(column)
+        colspec += " " + self.dialect.type_compiler.process(column.type)
+        default = self.get_column_default_string(column)
+        if default is not None:
+            colspec += " DEFAULT " + default
+
+        if not column.nullable:
+            colspec += " NOT NULL"
+        return colspec
+
+    def visit_create_sequence(self, create):
+        return "CREATE SEQUENCE %s" % self.preparer.format_sequence(create.element)
+
+    def visit_drop_sequence(self, drop):
+        return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element)
+
+class OracleDefaultRunner(base.DefaultRunner):
+    def visit_sequence(self, seq):
+        return self.execute_string("SELECT " + self.dialect.identifier_preparer.format_sequence(seq) + ".nextval FROM DUAL", {})
+
+class OracleIdentifierPreparer(compiler.IdentifierPreparer):
+    
+    reserved_words = set([x.lower() for x in RESERVED_WORDS])
+
+    def _bindparam_requires_quotes(self, value):
+        """Return True if the given identifier requires quoting."""
+        lc_value = value.lower()
+        return (lc_value in self.reserved_words
+                or self.illegal_initial_characters.match(value[0])
+                or not self.legal_characters.match(unicode(value))
+                )
+    
+    def format_savepoint(self, savepoint):
+        name = re.sub(r'^_+', '', savepoint.ident)
+        return super(OracleIdentifierPreparer, self).format_savepoint(savepoint, name)
+        
+class OracleDialect(default.DefaultDialect):
+    name = 'oracle'
+    supports_alter = True
+    supports_unicode_statements = False
+    supports_unicode_binds = False
+    max_identifier_length = 30
+    supports_sane_rowcount = True
+    supports_sane_multi_rowcount = False
+    supports_sequences = True
+    sequences_optional = False
+    preexecute_pk_sequences = True
+    supports_pk_autoincrement = False
+    default_paramstyle = 'named'
+    colspecs = colspecs
+    ischema_names = ischema_names
+    
+    supports_default_values = False
+    supports_empty_insert = False
+    
+    statement_compiler = OracleCompiler
+    ddl_compiler = OracleDDLCompiler
+    type_compiler = OracleTypeCompiler
+    preparer = OracleIdentifierPreparer
+    defaultrunner = OracleDefaultRunner
+    
+    def __init__(self, 
+                use_ansi=True, 
+                optimize_limits=False, 
+                **kwargs):
+        default.DefaultDialect.__init__(self, **kwargs)
+        self.use_ansi = use_ansi
+        self.optimize_limits = optimize_limits
 
     def has_table(self, connection, table_name, schema=None):
         if not schema:
@@ -508,10 +508,9 @@ class OracleDialect(default.DefaultDialect):
         else:
             return name.encode(self.encoding)
 
+    @base.connection_memoize(('dialect', 'default_schema_name'))
     def get_default_schema_name(self, connection):
         return self._normalize_name(connection.execute('SELECT USER FROM DUAL').scalar())
-    get_default_schema_name = base.connection_memoize(
-        ('dialect', 'default_schema_name'))(get_default_schema_name)
 
     def table_names(self, connection, schema):
         # note that table_names() isnt loading DBLINKed or synonym'ed tables
@@ -601,17 +600,17 @@ class OracleDialect(default.DefaultDialect):
             #length is ignored except for CHAR and VARCHAR2
             if coltype == 'NUMBER' :
                 if precision is None and scale is None:
-                    coltype = OracleNumeric
+                    coltype = sqltypes.NUMERIC
                 elif precision is None and scale == 0  :
-                    coltype = OracleInteger
+                    coltype = sqltypes.INTEGER
                 else :
-                    coltype = OracleNumeric(precision, scale)
+                    coltype = sqltypes.NUMERIC(precision, scale)
             elif coltype=='CHAR' or coltype=='VARCHAR2':
-                coltype = ischema_names.get(coltype, OracleString)(length)
+                coltype = self.ischema_names.get(coltype)(length)
             else:
                 coltype = re.sub(r'\(\d+\)', '', coltype)
                 try:
-                    coltype = ischema_names[coltype]
+                    coltype = self.ischema_names[coltype]
                 except KeyError:
                     util.warn("Did not recognize type '%s' of column '%s'" %
                               (coltype, colname))
@@ -653,8 +652,9 @@ class OracleDialect(default.DefaultDialect):
             if row is None:
                 break
             #print "ROW:" , row
-            (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = row[0:2] + tuple([self._normalize_name(x) for x in row[2:]])
-            if cons_type == 'P':
+            (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = \
+                    row[0:2] + tuple([self._normalize_name(x) for x in row[2:]])
+            if cons_type == 'P' and local_column in table.c:
                 table.primary_key.add(table.c[local_column])
             elif cons_type == 'R':
                 try:
@@ -698,203 +698,5 @@ class _OuterJoinColumn(sql.ClauseElement):
     def __init__(self, column):
         self.column = column
 
-class OracleCompiler(compiler.SQLCompiler):
-    """Oracle compiler modifies the lexical structure of Select
-    statements to work under non-ANSI configured Oracle databases, if
-    the use_ansi flag is False.
-    """
-
-    operators = compiler.SQLCompiler.operators.copy()
-    operators.update(
-        {
-            sql_operators.mod : lambda x, y:"mod(%s, %s)" % (x, y),
-            sql_operators.match_op: lambda x, y: "CONTAINS (%s, %s)" % (x, y)
-        }
-    )
-
-    functions = compiler.SQLCompiler.functions.copy()
-    functions.update (
-        {
-            sql_functions.now : 'CURRENT_TIMESTAMP'
-        }
-    )
-
-    def __init__(self, *args, **kwargs):
-        super(OracleCompiler, self).__init__(*args, **kwargs)
-        self.__wheres = {}
-
-    def default_from(self):
-        """Called when a ``SELECT`` statement has no froms, and no ``FROM`` clause is to be appended.
-
-        The Oracle compiler tacks a "FROM DUAL" to the statement.
-        """
-
-        return " FROM DUAL"
-
-    def apply_function_parens(self, func):
-        return len(func.clauses) > 0
-
-    def visit_join(self, join, **kwargs):
-        if self.dialect.use_ansi:
-            return compiler.SQLCompiler.visit_join(self, join, **kwargs)
-        else:
-            return self.process(join.left, asfrom=True) + ", " + self.process(join.right, asfrom=True)
-
-    def _get_nonansi_join_whereclause(self, froms):
-        clauses = []
-
-        def visit_join(join):
-            if join.isouter:
-                def visit_binary(binary):
-                    if binary.operator == sql_operators.eq:
-                        if binary.left.table is join.right:
-                            binary.left = _OuterJoinColumn(binary.left)
-                        elif binary.right.table is join.right:
-                            binary.right = _OuterJoinColumn(binary.right)
-                clauses.append(visitors.cloned_traverse(join.onclause, {}, {'binary':visit_binary}))
-            else:
-                clauses.append(join.onclause)
-
-        for f in froms:
-            visitors.traverse(f, {}, {'join':visit_join})
-        return sql.and_(*clauses)
-
-    def visit_outer_join_column(self, vc):
-        return self.process(vc.column) + "(+)"
-
-    def visit_sequence(self, seq):
-        return self.dialect.identifier_preparer.format_sequence(seq) + ".nextval"
-
-    def visit_alias(self, alias, asfrom=False, **kwargs):
-        """Oracle doesn't like ``FROM table AS alias``.  Is the AS standard SQL??"""
-
-        if asfrom:
-            return self.process(alias.original, asfrom=asfrom, **kwargs) + " " + self.preparer.format_alias(alias, self._anonymize(alias.name))
-        else:
-            return self.process(alias.original, **kwargs)
-
-    def _TODO_visit_compound_select(self, select):
-        """Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle."""
-        pass
-
-    def visit_select(self, select, **kwargs):
-        """Look for ``LIMIT`` and OFFSET in a select statement, and if
-        so tries to wrap it in a subquery with ``rownum`` criterion.
-        """
-
-        if not getattr(select, '_oracle_visit', None):
-            if not self.dialect.use_ansi:
-                if self.stack and 'from' in self.stack[-1]:
-                    existingfroms = self.stack[-1]['from']
-                else:
-                    existingfroms = None
-
-                froms = select._get_display_froms(existingfroms)
-                whereclause = self._get_nonansi_join_whereclause(froms)
-                if whereclause:
-                    select = select.where(whereclause)
-                    select._oracle_visit = True
-
-            if select._limit is not None or select._offset is not None:
-                # See http://www.oracle.com/technology/oramag/oracle/06-sep/o56asktom.html
-                #
-                # Generalized form of an Oracle pagination query:
-                #   select ... from (
-                #     select /*+ FIRST_ROWS(N) */ ...., rownum as ora_rn from (
-                #         select distinct ... where ... order by ...
-                #     ) where ROWNUM <= :limit+:offset
-                #   ) where ora_rn > :offset
-                # Outer select and "ROWNUM as ora_rn" can be dropped if limit=0
-
-                # TODO: use annotations instead of clone + attr set ?
-                select = select._generate()
-                select._oracle_visit = True
-
-                # Wrap the middle select and add the hint
-                limitselect = sql.select([c for c in select.c])
-                if select._limit and self.dialect.optimize_limits:
-                    limitselect = limitselect.prefix_with("/*+ FIRST_ROWS(%d) */" % select._limit)
-
-                limitselect._oracle_visit = True
-                limitselect._is_wrapper = True
-
-                # If needed, add the limiting clause
-                if select._limit is not None:
-                    max_row = select._limit
-                    if select._offset is not None:
-                        max_row += select._offset
-                    limitselect.append_whereclause(
-                            sql.literal_column("ROWNUM")<=max_row)
-                # If needed, add the ora_rn, and wrap again with offset.
-                if select._offset is None:
-                    select = limitselect
-                else:
-                     limitselect = limitselect.column(
-                             sql.literal_column("ROWNUM").label("ora_rn"))
-                     limitselect._oracle_visit = True
-                     limitselect._is_wrapper = True
-                     offsetselect = sql.select(
-                             [c for c in limitselect.c if c.key!='ora_rn'])
-                     offsetselect._oracle_visit = True
-                     offsetselect._is_wrapper = True
-                     offsetselect.append_whereclause(
-                             sql.literal_column("ora_rn")>select._offset)
-                     select = offsetselect
-
-        kwargs['iswrapper'] = getattr(select, '_is_wrapper', False)
-        return compiler.SQLCompiler.visit_select(self, select, **kwargs)
-
-    def limit_clause(self, select):
-        return ""
-
-    def for_update_clause(self, select):
-        if select.for_update == "nowait":
-            return " FOR UPDATE NOWAIT"
-        else:
-            return super(OracleCompiler, self).for_update_clause(select)
-
-
-class OracleSchemaGenerator(compiler.SchemaGenerator):
-    def get_column_specification(self, column, **kwargs):
-        colspec = self.preparer.format_column(column)
-        colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
-        default = self.get_column_default_string(column)
-        if default is not None:
-            colspec += " DEFAULT " + default
-
-        if not column.nullable:
-            colspec += " NOT NULL"
-        return colspec
-
-    def visit_sequence(self, sequence):
-        if not self.checkfirst  or not self.dialect.has_sequence(self.connection, sequence.name, sequence.schema):
-            self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence))
-            self.execute()
-
-class OracleSchemaDropper(compiler.SchemaDropper):
-    def visit_sequence(self, sequence):
-        if not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name, sequence.schema):
-            self.append("DROP SEQUENCE %s" % self.preparer.format_sequence(sequence))
-            self.execute()
-
-class OracleDefaultRunner(base.DefaultRunner):
-    def visit_sequence(self, seq):
-        return self.execute_string("SELECT " + self.dialect.identifier_preparer.format_sequence(seq) + ".nextval FROM DUAL", {})
-
-class OracleIdentifierPreparer(compiler.IdentifierPreparer):
-    def format_savepoint(self, savepoint):
-        name = re.sub(r'^_+', '', savepoint.ident)
-        return super(OracleIdentifierPreparer, self).format_savepoint(savepoint, name)
 
 
-dialect = OracleDialect
-dialect.statement_compiler = OracleCompiler
-dialect.schemagenerator = OracleSchemaGenerator
-dialect.schemadropper = OracleSchemaDropper
-dialect.preparer = OracleIdentifierPreparer
-dialect.defaultrunner = OracleDefaultRunner
-dialect.execution_ctx_cls = OracleExecutionContext
diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py
new file mode 100644 (file)
index 0000000..b899d44
--- /dev/null
@@ -0,0 +1,252 @@
+"""Support for the Oracle database via the cx_oracle driver.
+
+Driver
+------
+
+The Oracle dialect uses the cx_oracle driver, available at 
+http://cx-oracle.sourceforge.net/ .   The dialect has several behaviors 
+which are specifically tailored towards compatibility with this module.
+
+Connecting
+----------
+
+Connecting with create_engine() uses the standard URL approach of 
+``oracle://user:pass@host:port/dbname[?key=value&key=value...]``.  If dbname is present, the 
+host, port, and dbname tokens are converted to a TNS name using the cx_oracle 
+:func:`makedsn()` function.  Otherwise, the host token is taken directly as a TNS name.
+
+Additional arguments which may be specified either as query string arguments on the
+URL, or as keyword arguments to :func:`~sqlalchemy.create_engine()` are:
+
+* *allow_twophase* - enable two-phase transactions.  Defaults to ``True``.
+
+* *auto_convert_lobs* - defaults to True, see the section on LOB objects.
+
+* *auto_setinputsizes* - the cx_oracle.setinputsizes() call is issued for all bind parameters.
+  This is required for LOB datatypes but can be disabled to reduce overhead.  Defaults
+  to ``True``.
+
+* *mode* - This is given the string value of SYSDBA or SYSOPER, or alternatively an
+  integer value.  This value is only available as a URL query string argument.
+
+* *threaded* - enable multithreaded access to cx_oracle connections.  Defaults
+  to ``True``.  Note that this is the opposite default of cx_oracle itself.
+
+
+LOB Objects
+-----------
+
+cx_oracle presents some challenges when fetching LOB objects.  A LOB object in a result set
+is presented by cx_oracle as a cx_oracle.LOB object which has a read() method.  By default, 
+SQLAlchemy converts these LOB objects into Python strings.  This is for two reasons.  First,
+the LOB object requires an active cursor association, meaning if you were to fetch many rows
+at once such that cx_oracle had to go back to the database and fetch a new batch of rows,
+the LOB objects in the already-fetched rows are now unreadable and will raise an error. 
+SQLA "pre-reads" all LOBs so that their data is fetched before further rows are read.  
+The size of a "batch of rows" is controlled by the cursor.arraysize value, which SQLAlchemy
+defaults to 50 (cx_oracle normally defaults this to one).  
+
+Secondly, the LOB object is not a standard DBAPI return value so SQLAlchemy seeks to 
+"normalize" the results to look more like other DBAPIs.
+
+The conversion of LOB objects by this dialect is unique in SQLAlchemy in that it takes place
+for all statement executions, even plain string-based statements for which SQLA has no awareness
+of result typing.  This is so that calls like fetchmany() and fetchall() can work in all cases
+without raising cursor errors.  The conversion of LOB in all cases, as well as the "prefetch"
+of LOB objects, can be disabled using auto_convert_lobs=False.  
+
+Two Phase Transaction Support
+-----------------------------
+
+Two Phase transactions are implemented using XA transactions.  Success has been reported of them
+working successfully but this should be regarded as an experimental feature.
+
+"""
+
+from sqlalchemy.dialects.oracle.base import OracleDialect, OracleText, OracleBinary, OracleRaw, RESERVED_WORDS
+from sqlalchemy.engine.default import DefaultExecutionContext
+from sqlalchemy.engine import base
+from sqlalchemy import types as sqltypes, util
+
+class OracleNVarchar(sqltypes.NVARCHAR):
+    """The SQL NVARCHAR type."""
+
+    def __init__(self, **kw):
+        kw['convert_unicode'] = False  # cx_oracle does this for us, for NVARCHAR2
+        sqltypes.NVARCHAR.__init__(self, **kw)
+
+class Oracle_cx_oracleExecutionContext(DefaultExecutionContext):
+    def pre_exec(self):
+        
+        quoted_bind_names = getattr(self.compiled, '_quoted_bind_names', {})
+        if quoted_bind_names:
+            for param in self.parameters:
+                for fromname, toname in self.compiled._quoted_bind_names.iteritems():
+                    param[toname.encode(self.dialect.encoding)] = param[fromname]
+                    del param[fromname]
+
+        if self.dialect.auto_setinputsizes:
+            self.set_input_sizes(quoted_bind_names)
+            
+        if len(self.compiled_parameters) == 1:
+            for key in self.compiled.binds:
+                bindparam = self.compiled.binds[key]
+                name = self.compiled.bind_names[bindparam]
+                value = self.compiled_parameters[0][name]
+                if bindparam.isoutparam:
+                    dbtype = bindparam.type.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
+                    if not hasattr(self, 'out_parameters'):
+                        self.out_parameters = {}
+                    self.out_parameters[name] = self.cursor.var(dbtype)
+                    self.parameters[0][quoted_bind_names.get(name, name)] = self.out_parameters[name]
+        
+        
+    def create_cursor(self):
+        c = self._connection.connection.cursor()
+        if self.dialect.arraysize:
+            c.cursor.arraysize = self.dialect.arraysize
+        return c
+
+    def get_result_proxy(self):
+        if hasattr(self, 'out_parameters'):
+            if self.compiled_parameters is not None and len(self.compiled_parameters) == 1:
+                for bind, name in self.compiled.bind_names.iteritems():
+                    if name in self.out_parameters:
+                        type = bind.type
+                        result_processor = type.dialect_impl(self.dialect).result_processor(self.dialect)
+                        if result_processor is not None:
+                            self.out_parameters[name] = result_processor(self.out_parameters[name].getvalue())
+                        else:
+                            self.out_parameters[name] = self.out_parameters[name].getvalue()
+            else:
+                for k in self.out_parameters:
+                    self.out_parameters[k] = self.out_parameters[k].getvalue()
+
+        if self.cursor.description is not None:
+            for column in self.cursor.description:
+                type_code = column[1]
+                if type_code in self.dialect.ORACLE_BINARY_TYPES:
+                    return base.BufferedColumnResultProxy(self)
+
+        return base.ResultProxy(self)
+
+
+class Oracle_cx_oracle(OracleDialect):
+    execution_ctx_cls = Oracle_cx_oracleExecutionContext
+    driver = "cx_oracle"
+    
+    colspecs = util.update_copy(
+        OracleDialect.colspecs,
+        {
+            sqltypes.NVARCHAR:OracleNVarchar
+        }
+    )
+    
+    def __init__(self, 
+                auto_setinputsizes=True, 
+                auto_convert_lobs=True, 
+                threaded=True, 
+                allow_twophase=True, 
+                arraysize=50, **kwargs):
+        OracleDialect.__init__(self, **kwargs)
+        self.threaded = threaded
+        self.arraysize = arraysize
+        self.allow_twophase = allow_twophase
+        self.supports_timestamp = self.dbapi is None or hasattr(self.dbapi, 'TIMESTAMP' )
+        self.auto_setinputsizes = auto_setinputsizes
+        self.auto_convert_lobs = auto_convert_lobs
+        if self.dbapi is None or not self.auto_convert_lobs or not 'CLOB' in self.dbapi.__dict__:
+            self.dbapi_type_map = {}
+            self.ORACLE_BINARY_TYPES = []
+        else:
+            # only use this for LOB objects.  using it for strings, dates
+            # etc. leads to a little too much magic, reflection doesn't know if it should
+            # expect encoded strings or unicodes, etc.
+            self.dbapi_type_map = {
+                self.dbapi.CLOB: OracleText(),
+                self.dbapi.BLOB: OracleBinary(),
+                self.dbapi.BINARY: OracleRaw(),
+            }
+            self.ORACLE_BINARY_TYPES = [getattr(self.dbapi, k) for k in ["BFILE", "CLOB", "NCLOB", "BLOB"] if hasattr(self.dbapi, k)]
+    
+    @classmethod
+    def dbapi(cls):
+        import cx_Oracle
+        return cx_Oracle
+
+    def create_connect_args(self, url):
+        dialect_opts = dict(url.query)
+        for opt in ('use_ansi', 'auto_setinputsizes', 'auto_convert_lobs',
+                    'threaded', 'allow_twophase'):
+            if opt in dialect_opts:
+                util.coerce_kw_type(dialect_opts, opt, bool)
+                setattr(self, opt, dialect_opts[opt])
+
+        if url.database:
+            # if we have a database, then we have a remote host
+            port = url.port
+            if port:
+                port = int(port)
+            else:
+                port = 1521
+            dsn = self.dbapi.makedsn(url.host, port, url.database)
+        else:
+            # we have a local tnsname
+            dsn = url.host
+
+        opts = dict(
+            user=url.username,
+            password=url.password,
+            dsn=dsn,
+            threaded=self.threaded,
+            twophase=self.allow_twophase,
+            )
+        if 'mode' in url.query:
+            opts['mode'] = url.query['mode']
+            if isinstance(opts['mode'], basestring):
+                mode = opts['mode'].upper()
+                if mode == 'SYSDBA':
+                    opts['mode'] = self.dbapi.SYSDBA
+                elif mode == 'SYSOPER':
+                    opts['mode'] = self.dbapi.SYSOPER
+                else:
+                    util.coerce_kw_type(opts, 'mode', int)
+        # Can't set 'handle' or 'pool' via URL query args, use connect_args
+
+        return ([], opts)
+
+    def is_disconnect(self, e):
+        if isinstance(e, self.dbapi.InterfaceError):
+            return "not connected" in str(e)
+        else:
+            return "ORA-03114" in str(e) or "ORA-03113" in str(e)
+
+    def create_xid(self):
+        """create a two-phase transaction ID.
+
+        this id will be passed to do_begin_twophase(), do_rollback_twophase(),
+        do_commit_twophase().  its format is unspecified."""
+
+        id = random.randint(0, 2 ** 128)
+        return (0x1234, "%032x" % id, "%032x" % 9)
+
+    def do_release_savepoint(self, connection, name):
+        # Oracle does not support RELEASE SAVEPOINT
+        pass
+
+    def do_begin_twophase(self, connection, xid):
+        connection.connection.begin(*xid)
+
+    def do_prepare_twophase(self, connection, xid):
+        connection.connection.prepare()
+
+    def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False):
+        self.do_rollback(connection.connection)
+
+    def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False):
+        self.do_commit(connection.connection)
+
+    def do_recover_twophase(self, connection):
+        pass
+
+dialect = Oracle_cx_oracle
index 8fd4ef5ef284fd39ff0962fdd7a72d7e591ecf4d..ce1fa1507de90e9ee1ae097aa72c91861110cadc 100644 (file)
@@ -155,28 +155,28 @@ colspecs = {
 }
 
 ischema_names = {
-    'integer' : sqltypes.Integer,
+    'integer' : sqltypes.INTEGER,
     'bigint' : PGBigInteger,
-    'smallint' : sqltypes.SmallInteger,
-    'character varying' : sqltypes.String,
+    'smallint' : sqltypes.SMALLINT,
+    'character varying' : sqltypes.VARCHAR,
     'character' : sqltypes.CHAR,
-    'text' : sqltypes.Text,
-    'numeric' : sqltypes.Numeric,
-    'float' : sqltypes.Float,
+    'text' : sqltypes.TEXT,
+    'numeric' : sqltypes.NUMERIC,
+    'float' : sqltypes.FLOAT,
     'real' : sqltypes.Float,
     'inet': PGInet,
     'cidr': PGCidr,
     'macaddr': PGMacAddr,
     'double precision' : sqltypes.Float,
-    'timestamp' : sqltypes.DateTime,
-    'timestamp with time zone' : sqltypes.DateTime,
-    'timestamp without time zone' : sqltypes.DateTime,
-    'time with time zone' : sqltypes.Time,
-    'time without time zone' : sqltypes.Time,
-    'date' : sqltypes.Date,
-    'time': sqltypes.Time,
+    'timestamp' : sqltypes.TIMESTAMP,
+    'timestamp with time zone' : sqltypes.TIMESTAMP,
+    'timestamp without time zone' : sqltypes.TIMESTAMP,
+    'time with time zone' : sqltypes.TIME,
+    'time without time zone' : sqltypes.TIME,
+    'date' : sqltypes.DATE,
+    'time': sqltypes.TIME,
     'bytea' : sqltypes.Binary,
-    'boolean' : sqltypes.Boolean,
+    'boolean' : sqltypes.BOOLEAN,
     'interval':PGInterval,
 }
 
@@ -490,9 +490,6 @@ class PGDialect(default.DefaultDialect):
             raise AssertionError("Could not determine version from string '%s'" % v)
         return tuple([int(x) for x in m.group(1, 2, 3)])
 
-    def type_descriptor(self, typeobj):
-        return sqltypes.adapt_type(typeobj, self.colspecs)
-
     def reflecttable(self, connection, table, include_columns):
         preparer = self.identifier_preparer
         if table.schema is not None:
index b90ac8d9c651070a64b9ad803ad99e2213c5f0e9..364c13236d5e0afaef86a551a537ecf71aedda00 100644 (file)
@@ -100,8 +100,6 @@ class Postgres_psycopg2Compiler(PGCompiler):
     )
     
     def post_process_text(self, text):
-        if '%%' in text:
-            util.warn("The SQLAlchemy postgres dialect now automatically escapes '%' in text() expressions to '%%'.")
         return text.replace('%', '%%')
 
 class Postgres_psycopg2(PGDialect):
index 773501d64c5164e8a00e91793acbbf82aa9c347a..319a5bffc6d253d7258e199c2091d954e173f6bb 100644 (file)
@@ -146,23 +146,23 @@ colspecs = {
 }
 
 ischema_names = {
-    'BLOB': sqltypes.Binary,
-    'BOOL': sqltypes.Boolean,
-    'BOOLEAN': sqltypes.Boolean,
+    'BLOB': sqltypes.BLOB,
+    'BOOL': sqltypes.BOOLEAN,
+    'BOOLEAN': sqltypes.BOOLEAN,
     'CHAR': sqltypes.CHAR,
-    'DATE': sqltypes.Date,
-    'DATETIME': sqltypes.DateTime,
-    'DECIMAL': sqltypes.Numeric,
-    'FLOAT': sqltypes.Numeric,
-    'INT': sqltypes.Integer,
-    'INTEGER': sqltypes.Integer,
-    'NUMERIC': sqltypes.Numeric,
+    'DATE': sqltypes.DATE,
+    'DATETIME': sqltypes.DATETIME,
+    'DECIMAL': sqltypes.DECIMAL,
+    'FLOAT': sqltypes.FLOAT,
+    'INT': sqltypes.INTEGER,
+    'INTEGER': sqltypes.INTEGER,
+    'NUMERIC': sqltypes.NUMERIC,
     'REAL': sqltypes.Numeric,
-    'SMALLINT': sqltypes.SmallInteger,
-    'TEXT': sqltypes.Text,
-    'TIME': sqltypes.Time,
-    'TIMESTAMP': sqltypes.DateTime,
-    'VARCHAR': sqltypes.String,
+    'SMALLINT': sqltypes.SMALLINT,
+    'TEXT': sqltypes.TEXT,
+    'TIME': sqltypes.TIME,
+    'TIMESTAMP': sqltypes.TIMESTAMP,
+    'VARCHAR': sqltypes.VARCHAR,
 }
 
 
@@ -256,10 +256,8 @@ class SQLiteDialect(default.DefaultDialect):
     type_compiler = SQLiteTypeCompiler
     preparer = SQLiteIdentifierPreparer
     ischema_names = ischema_names
-
-    def type_descriptor(self, typeobj):
-        return sqltypes.adapt_type(typeobj, colspecs)
-
+    colspecs = colspecs
+    
     def table_names(self, connection, schema):
         if schema is not None:
             qschema = self.identifier_preparer.quote_identifier(schema)
index 1dc3d720ef8aae6c5ba2d140034d0d78079c14c8..8be0a2d85fd9b80400438411afeb1e6c6cde747a 100644 (file)
@@ -15,7 +15,7 @@ as the base class for their own corresponding classes.
 import re, random
 from sqlalchemy.engine import base
 from sqlalchemy.sql import compiler, expression
-from sqlalchemy import exc
+from sqlalchemy import exc, types as sqltypes
 
 AUTOCOMMIT_REGEXP = re.compile(r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)',
                                re.I | re.UNICODE)
@@ -72,13 +72,12 @@ class DefaultDialect(base.Dialect):
         """Provide a database-specific ``TypeEngine`` object, given
         the generic object which comes from the types module.
 
-        Subclasses will usually use the ``adapt_type()`` method in the
-        types module to make this job easy.
+        This method looks for a dictionary called 
+        ``colspecs`` as a class or instance-level variable,
+        and passes on to ``types.adapt_type()``.
         
         """
-        if type(typeobj) is type:
-            typeobj = typeobj()
-        return typeobj
+        return sqltypes.adapt_type(typeobj, self.colspecs)
 
     def validate_identifier(self, ident):
         if len(ident) > self.max_identifier_length:
@@ -315,12 +314,16 @@ class DefaultExecutionContext(base.ExecutionContext):
     def lastrow_has_defaults(self):
         return hasattr(self, 'postfetch_cols') and len(self.postfetch_cols)
 
-    def set_input_sizes(self):
+    def set_input_sizes(self, translate=None):
         """Given a cursor and ClauseParameters, call the appropriate
         style of ``setinputsizes()`` on the cursor, using DB-API types
         from the bind parameter's ``TypeEngine`` objects.
+        
         """
 
+        if not hasattr(self.compiled, 'bind_names'):
+            return
+            
         types = dict(
                 (self.compiled.bind_names[bindparam], bindparam.type)
                  for bindparam in self.compiled.bind_names)
@@ -343,6 +346,8 @@ class DefaultExecutionContext(base.ExecutionContext):
                 typeengine = types[key]
                 dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
                 if dbtype is not None:
+                    if translate:
+                        key = translate.get(key, key)
                     inputsizes[key.encode(self.dialect.encoding)] = dbtype
             try:
                 self.cursor.setinputsizes(**inputsizes)
index 683831998737177105e427d56fcf2ae25196ca4b..d506efcacbefa97628b872acf5e6ff0a8ff33387 100644 (file)
@@ -982,6 +982,9 @@ class GenericTypeCompiler(engine.TypeCompiler):
     def visit_VARCHAR(self, type_):
         return "VARCHAR" + (type_.length and "(%d)" % type_.length or "")
 
+    def visit_NVARCHAR(self, type_):
+        return "NVARCHAR" + (type_.length and "(%d)" % type_.length or "")
+
     def visit_BLOB(self, type_):
         return "BLOB"
     
index 92ee125b631915df91a154537f976bde871f5ea0..ea8a8ceb3a198b884eeeecea5c5730e60dbcae42 100644 (file)
@@ -917,6 +917,8 @@ class TIMESTAMP(DateTime):
 
     __visit_name__ = 'TIMESTAMP'
 
+    def get_dbapi_type(self, dbapi):
+        return dbapi.TIMESTAMP
 
 class DATETIME(DateTime):
     """The SQL DATETIME type."""
@@ -951,6 +953,10 @@ class VARCHAR(String):
 
     __visit_name__ = 'VARCHAR'
 
+class NVARCHAR(Unicode):
+    """The SQL NVARCHAR type."""
+
+    __visit_name__ = 'NVARCHAR'
 
 class CHAR(String):
     """The SQL CHAR type."""
index 2186f22595b30739e0f5b0429dd9cd30b0671d50..c55e778a2431eb7dbfaa67ac297b3ef297b09708 100644 (file)
@@ -2,6 +2,7 @@
 
 import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
+from sqlalchemy import types as sqltypes
 from sqlalchemy.sql import table, column
 from sqlalchemy.databases import oracle
 from testlib import *
@@ -301,13 +302,13 @@ class TypesTest(TestBase, AssertsCompiledSQL):
     def test_reflect_nvarchar(self):
         metadata = MetaData(testing.db)
         t = Table('t', metadata,
-            Column('data', oracle.OracleNVarchar(255))
+            Column('data', sqltypes.NVARCHAR(255))
         )
         metadata.create_all()
         try:
             m2 = MetaData(testing.db)
             t2 = Table('t', m2, autoload=True)
-            assert isinstance(t2.c.data.type, oracle.OracleNVarchar)
+            assert isinstance(t2.c.data.type, sqltypes.NVARCHAR)
             data = u'm’a réveillé.'
             t2.insert().execute(data=data)
             eq_(t2.select().execute().fetchone()['data'], data)
index 4e6601951f2ccdb343dc5023564bd4b3a15aae6c..cd037f6ca3671e37ede8aaebcb38175070cfe3e8 100644 (file)
@@ -21,10 +21,10 @@ class ReflectionTest(TestBase, ComparesTables):
             Column('test2', sa.Float(5), nullable=False),
             Column('test3', sa.Text),
             Column('test4', sa.Numeric, nullable = False),
-            Column('test5', sa.DateTime),
+            Column('test5', sa.Date),
             Column('parent_user_id', sa.Integer,
                    sa.ForeignKey('engine_users.user_id')),
-            Column('test6', sa.DateTime, nullable=False),
+            Column('test6', sa.Date, nullable=False),
             Column('test7', sa.Text),
             Column('test8', sa.Binary),
             Column('test_passivedefault2', sa.Integer, server_default='5'),
index 660529c25c194ea9bf0e0d736d625cdb312cea6a..0e45aff1071ad578bf6636ca388e0c5451f6314c 100644 (file)
@@ -12,11 +12,11 @@ class QueryTest(TestBase):
         global users, users2, addresses, metadata
         metadata = MetaData(testing.db)
         users = Table('query_users', metadata,
-            Column('user_id', INT, primary_key = True),
+            Column('user_id', INT, Sequence('user_id_seq', optional=True), primary_key = True),
             Column('user_name', VARCHAR(20)),
         )
         addresses = Table('query_addresses', metadata,
-            Column('address_id', Integer, primary_key=True),
+            Column('address_id', Integer, Sequence('address_id_seq', optional=True), primary_key=True),
             Column('user_id', Integer, ForeignKey('query_users.user_id')),
             Column('address', String(30)))
             
@@ -252,6 +252,7 @@ class QueryTest(TestBase):
             eq_(expr.execute().fetchall(), result)
     
 
+    @testing.fails_on("oracle", "neither % nor %% are accepted")
     @testing.fails_on("+pg8000", "can't interpret result column from '%%'")
     @testing.emits_warning('.*now automatically escapes.*')
     def test_percents_in_text(self):
@@ -484,13 +485,15 @@ class QueryTest(TestBase):
         self.assert_(r['query_users.user_id']) == 1
         self.assert_(r['query_users.user_name']) == "john"
 
+    @testing.fails_on('oracle', 'oracle result keys() are all uppercase, not getting into this.')
     def test_row_as_args(self):
         users.insert().execute(user_id=1, user_name='john')
         r = users.select(users.c.user_id==1).execute().fetchone()
         users.delete().execute()
         users.insert().execute(r)
-        assert users.select().execute().fetchall() == [(1, 'john')]
+        eq_(users.select().execute().fetchall(), [(1, 'john')])
     
+    @testing.fails_on('oracle', 'oracle result keys() are all uppercase, not getting into this.')
     def test_result_as_args(self):
         users.insert().execute([dict(user_id=1, user_name='john'), dict(user_id=2, user_name='ed')])
         r = users.select().execute()
@@ -720,7 +723,7 @@ class PercentSchemaNamesTest(TestBase):
         result.close()
         percent_table.update().values({percent_table.c['%(oneofthese)s']:9, percent_table.c['spaces % more spaces']:15}).execute()
         eq_(
-            percent_table.select().order_by(percent_table.c['%(oneofthese)s']).execute().fetchall(),
+            percent_table.select().order_by(percent_table.c['percent%']).execute().fetchall(),
             [
                 (5, 9, 15),
                 (7, 9, 15),