From: Mike Bayer Date: Sat, 17 Jan 2009 01:16:51 +0000 (+0000) Subject: - oracle support, includes fix for #994 X-Git-Tag: rel_0_6_6~335 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=419f212c3346173efb4a89ab80a4fe1ba2b4d7e0;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - oracle support, includes fix for #994 --- diff --git a/doc/build/reference/dialects/oracle.rst b/doc/build/reference/dialects/oracle.rst index 188f6f4383..68be06b64c 100644 --- a/doc/build/reference/dialects/oracle.rst +++ b/doc/build/reference/dialects/oracle.rst @@ -1,4 +1,10 @@ Oracle ====== -.. automodule:: sqlalchemy.databases.oracle +.. automodule:: sqlalchemy.dialects.oracle.base + +cx_Oracle Notes +=============== + +.. automodule:: sqlalchemy.dialects.oracle.cx_oracle + diff --git a/lib/sqlalchemy/dialects/__init__.py b/lib/sqlalchemy/dialects/__init__.py index 33e481d25c..8526b4d8fb 100644 --- a/lib/sqlalchemy/dialects/__init__.py +++ b/lib/sqlalchemy/dialects/__init__.py @@ -5,7 +5,7 @@ __all__ = ( # 'maxdb', # 'mssql', 'mysql', -# 'oracle', + 'oracle', 'postgres', 'sqlite', # 'sybase', diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 3c66945e80..74938abe0d 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1689,7 +1689,8 @@ class MySQLDialect(default.DefaultDialect): max_identifier_length = 255 supports_sane_rowcount = True default_paramstyle = 'format' - + colspecs = colspecs + statement_compiler = MySQLCompiler ddl_compiler = MySQLDDLCompiler type_compiler = MySQLTypeCompiler @@ -1699,9 +1700,6 @@ class MySQLDialect(default.DefaultDialect): self.use_ansiquotes = use_ansiquotes default.DefaultDialect.__init__(self, **kwargs) - def type_descriptor(self, typeobj): - return sqltypes.adapt_type(typeobj, colspecs) - def do_executemany(self, cursor, statement, parameters, context=None): rowcount = cursor.executemany(statement, parameters) if context is not None: diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py index 61f9d3f671..b077774ea5 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqldb.py +++ b/lib/sqlalchemy/dialects/mysql/mysqldb.py @@ -40,8 +40,6 @@ class MySQL_mysqldbCompiler(MySQLCompiler): ) def post_process_text(self, text): - if '%%' in text: - util.warn("The SQLAlchemy mysql+mysqldb dialect now automatically escapes '%' in text() expressions to '%%'.") return text.replace('%', '%%') class MySQL_mysqldb(MySQLDialect): diff --git a/lib/sqlalchemy/dialects/oracle/__init__.py b/lib/sqlalchemy/dialects/oracle/__init__.py new file mode 100644 index 0000000000..7038fb3ec0 --- /dev/null +++ b/lib/sqlalchemy/dialects/oracle/__init__.py @@ -0,0 +1,3 @@ +from sqlalchemy.dialects.oracle import base, cx_oracle + +base.dialect = cx_oracle.dialect \ No newline at end of file diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/dialects/oracle/base.py similarity index 65% rename from lib/sqlalchemy/databases/oracle.py rename to lib/sqlalchemy/dialects/oracle/base.py index b0ec6115b2..9bf6db23d6 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -1,4 +1,4 @@ -# oracle.py +# oracle/base.py # Copyright (C) 2005, 2006, 2007, 2008, 2009 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under @@ -7,37 +7,14 @@ Oracle version 8 through current (11g at the time of this writing) are supported. -Driver ------- +For information on connecting via specific drivers, see the documentation +for that driver. -The Oracle dialect uses the cx_oracle driver, available at -http://cx-oracle.sourceforge.net/ . The dialect has several behaviors -which are specifically tailored towards compatibility with this module. +Connect Arguments +----------------- -Connecting ----------- - -Connecting with create_engine() uses the standard URL approach of -``oracle://user:pass@host:port/dbname[?key=value&key=value...]``. If dbname is present, the -host, port, and dbname tokens are converted to a TNS name using the cx_oracle -:func:`makedsn()` function. Otherwise, the host token is taken directly as a TNS name. - -Additional arguments which may be specified either as query string arguments on the -URL, or as keyword arguments to :func:`~sqlalchemy.create_engine()` are: - -* *allow_twophase* - enable two-phase transactions. Defaults to ``True``. - -* *auto_convert_lobs* - defaults to True, see the section on LOB objects. - -* *auto_setinputsizes* - the cx_oracle.setinputsizes() call is issued for all bind parameters. - This is required for LOB datatypes but can be disabled to reduce overhead. Defaults - to ``True``. - -* *mode* - This is given the string value of SYSDBA or SYSOPER, or alternatively an - integer value. This value is only available as a URL query string argument. - -* *threaded* - enable multithreaded access to cx_oracle connections. Defaults - to ``True``. Note that this is the opposite default of cx_oracle itself. +The dialect supports several :func:`~sqlalchemy.create_engine()` arguments which +affect the behavior of the dialect regardless of driver in use. * *use_ansi* - Use ANSI JOIN constructs (see the section on Oracle 8). Defaults to ``True``. If ``False``, Oracle-8 compatible constructs are used for joins. @@ -67,28 +44,6 @@ This step is also required when using table reflection, i.e. autoload=True:: autoload=True ) -LOB Objects ------------ - -cx_oracle presents some challenges when fetching LOB objects. A LOB object in a result set -is presented by cx_oracle as a cx_oracle.LOB object which has a read() method. By default, -SQLAlchemy converts these LOB objects into Python strings. This is for two reasons. First, -the LOB object requires an active cursor association, meaning if you were to fetch many rows -at once such that cx_oracle had to go back to the database and fetch a new batch of rows, -the LOB objects in the already-fetched rows are now unreadable and will raise an error. -SQLA "pre-reads" all LOBs so that their data is fetched before further rows are read. -The size of a "batch of rows" is controlled by the cursor.arraysize value, which SQLAlchemy -defaults to 50 (cx_oracle normally defaults this to one). - -Secondly, the LOB object is not a standard DBAPI return value so SQLAlchemy seeks to -"normalize" the results to look more like other DBAPIs. - -The conversion of LOB objects by this dialect is unique in SQLAlchemy in that it takes place -for all statement executions, even plain string-based statements for which SQLA has no awareness -of result typing. This is so that calls like fetchmany() and fetchall() can work in all cases -without raising cursor errors. The conversion of LOB in all cases, as well as the "prefetch" -of LOB objects, can be disabled using auto_convert_lobs=False. - LIMIT/OFFSET Support -------------------- @@ -100,12 +55,6 @@ http://www.oracle.com/technology/oramag/oracle/06-sep/o56asktom.html . Note tha this was stepping into the bounds of optimization that is better left on the DBA side, but this prefix can be added by enabling the optimize_limits=True flag on create_engine(). -Two Phase Transaction Support ------------------------------ - -Two Phase transactions are implemented using XA transactions. Success has been reported of them -working successfully but this should be regarded as an experimental feature. - Oracle 8 Compatibility ---------------------- @@ -127,29 +76,13 @@ import datetime, random, re from sqlalchemy import util, sql, schema, log from sqlalchemy.engine import default, base -from sqlalchemy.sql import compiler, visitors +from sqlalchemy.sql import compiler, visitors, expression from sqlalchemy.sql import operators as sql_operators, functions as sql_functions from sqlalchemy import types as sqltypes - -class OracleNumeric(sqltypes.Numeric): - def get_col_spec(self): - if self.precision is None: - return "NUMERIC" - else: - return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale} - -class OracleInteger(sqltypes.Integer): - def get_col_spec(self): - return "INTEGER" - -class OracleSmallInteger(sqltypes.SmallInteger): - def get_col_spec(self): - return "SMALLINT" +RESERVED_WORDS = set('''SHARE RAW DROP BETWEEN FROM DESC OPTION PRIOR LONG THEN DEFAULT ALTER IS INTO MINUS INTEGER NUMBER GRANT IDENTIFIED ALL TO ORDER ON FLOAT DATE HAVING CLUSTER NOWAIT RESOURCE ANY TABLE INDEX FOR UPDATE WHERE CHECK SMALLINT WITH DELETE BY ASC REVOKE LIKE SIZE RENAME NOCOMPRESS NULL GROUP VALUES AS IN VIEW EXCLUSIVE COMPRESS SYNONYM SELECT INSERT EXISTS NOT TRIGGER ELSE CREATE INTERSECT PCTFREE DISTINCT CONNECT SET MODE OF UNIQUE VARCHAR2 VARCHAR LOCK OR CHAR DECIMAL UNION PUBLIC AND START'''.split()) class OracleDate(sqltypes.Date): - def get_col_spec(self): - return "DATE" def bind_processor(self, dialect): return None @@ -162,9 +95,6 @@ class OracleDate(sqltypes.Date): return process class OracleDateTime(sqltypes.DateTime): - def get_col_spec(self): - return "DATE" - def result_processor(self, dialect): def process(value): if value is None or isinstance(value, datetime.datetime): @@ -182,12 +112,6 @@ class OracleDateTime(sqltypes.DateTime): # only if cx_oracle contains TIMESTAMP class OracleTimestamp(sqltypes.TIMESTAMP): - def get_col_spec(self): - return "TIMESTAMP" - - def get_dbapi_type(self, dialect): - return dialect.TIMESTAMP - def result_processor(self, dialect): def process(value): if value is None or isinstance(value, datetime.datetime): @@ -198,21 +122,10 @@ class OracleTimestamp(sqltypes.TIMESTAMP): value.day,value.hour, value.minute, value.second) return process -class OracleString(sqltypes.String): - def get_col_spec(self): - return "VARCHAR(%(length)s)" % {'length' : self.length} - -class OracleNVarchar(sqltypes.Unicode, OracleString): - def get_col_spec(self): - return "NVARCHAR2(%(length)s)" % {'length' : self.length} - class OracleText(sqltypes.Text): def get_dbapi_type(self, dbapi): return dbapi.CLOB - def get_col_spec(self): - return "CLOB" - def result_processor(self, dialect): super_process = super(OracleText, self).result_processor(dialect) if not dialect.auto_convert_lobs: @@ -232,17 +145,10 @@ class OracleText(sqltypes.Text): return process -class OracleChar(sqltypes.CHAR): - def get_col_spec(self): - return "CHAR(%(length)s)" % {'length' : self.length} - class OracleBinary(sqltypes.Binary): def get_dbapi_type(self, dbapi): return dbapi.BLOB - def get_col_spec(self): - return "BLOB" - def bind_processor(self, dialect): return None @@ -262,9 +168,6 @@ class OracleRaw(OracleBinary): return "RAW(%(length)s)" % {'length' : self.length} class OracleBoolean(sqltypes.Boolean): - def get_col_spec(self): - return "SMALLINT" - def result_processor(self, dialect): def process(value): if value is None: @@ -285,200 +188,297 @@ class OracleBoolean(sqltypes.Boolean): return process colspecs = { - sqltypes.Integer : OracleInteger, - sqltypes.SmallInteger : OracleSmallInteger, - sqltypes.Numeric : OracleNumeric, - sqltypes.Float : OracleNumeric, sqltypes.DateTime : OracleDateTime, sqltypes.Date : OracleDate, - sqltypes.String : OracleString, sqltypes.Binary : OracleBinary, sqltypes.Boolean : OracleBoolean, sqltypes.Text : OracleText, sqltypes.TIMESTAMP : OracleTimestamp, - sqltypes.CHAR: OracleChar, } ischema_names = { - 'VARCHAR2' : OracleString, - 'NVARCHAR2' : OracleNVarchar, - 'CHAR' : OracleString, - 'DATE' : OracleDateTime, - 'DATETIME' : OracleDateTime, - 'NUMBER' : OracleNumeric, - 'BLOB' : OracleBinary, - 'BFILE' : OracleBinary, - 'CLOB' : OracleText, - 'TIMESTAMP' : OracleTimestamp, + 'VARCHAR2' : sqltypes.VARCHAR, + 'NVARCHAR2' : sqltypes.NVARCHAR, + 'CHAR' : sqltypes.CHAR, + 'DATE' : sqltypes.DATE, + 'DATETIME' : sqltypes.DATETIME, + 'NUMBER' : sqltypes.Numeric, + 'BLOB' : sqltypes.BLOB, + 'BFILE' : sqltypes.Binary, + 'CLOB' : sqltypes.CLOB, + 'TIMESTAMP' : sqltypes.TIMESTAMP, 'RAW' : OracleRaw, - 'FLOAT' : OracleNumeric, - 'DOUBLE PRECISION' : OracleNumeric, - 'LONG' : OracleText, + 'FLOAT' : sqltypes.Float, + 'DOUBLE PRECISION' : sqltypes.Numeric, + 'LONG' : sqltypes.Text, } -class OracleExecutionContext(default.DefaultExecutionContext): - def pre_exec(self): - super(OracleExecutionContext, self).pre_exec() - if self.dialect.auto_setinputsizes: - self.set_input_sizes() - if self.compiled_parameters is not None and len(self.compiled_parameters) == 1: - for key in self.compiled.binds: - bindparam = self.compiled.binds[key] - name = self.compiled.bind_names[bindparam] - value = self.compiled_parameters[0][name] - if bindparam.isoutparam: - dbtype = bindparam.type.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) - if not hasattr(self, 'out_parameters'): - self.out_parameters = {} - self.out_parameters[name] = self.cursor.var(dbtype) - self.parameters[0][name] = self.out_parameters[name] - - def create_cursor(self): - c = self._connection.connection.cursor() - if self.dialect.arraysize: - c.cursor.arraysize = self.dialect.arraysize - return c - - def get_result_proxy(self): - if hasattr(self, 'out_parameters'): - if self.compiled_parameters is not None and len(self.compiled_parameters) == 1: - for bind, name in self.compiled.bind_names.iteritems(): - if name in self.out_parameters: - type = bind.type - result_processor = type.dialect_impl(self.dialect).result_processor(self.dialect) - if result_processor is not None: - self.out_parameters[name] = result_processor(self.out_parameters[name].getvalue()) - else: - self.out_parameters[name] = self.out_parameters[name].getvalue() - else: - for k in self.out_parameters: - self.out_parameters[k] = self.out_parameters[k].getvalue() - if self.cursor.description is not None: - for column in self.cursor.description: - type_code = column[1] - if type_code in self.dialect.ORACLE_BINARY_TYPES: - return base.BufferedColumnResultProxy(self) +class OracleTypeCompiler(compiler.GenericTypeCompiler): + # Note: + # Oracle DATE == DATETIME + # Oracle does not allow milliseconds in DATE + # Oracle does not support TIME columns + + def visit_DATETIME(self, type_): + return self.visit_DATE(type_) + + def visit_VARCHAR(self, type_): + return "VARCHAR(%(length)s)" % {'length' : type_.length} - return base.ResultProxy(self) + def visit_NVARCHAR(self, type_): + return "NVARCHAR2(%(length)s)" % {'length' : type_.length} + + def visit_TEXT(self, type_): + return self.visit_CLOB(type_) -class OracleDialect(default.DefaultDialect): - name = 'oracle' - supports_alter = True - supports_unicode_statements = False - max_identifier_length = 30 - supports_sane_rowcount = True - supports_sane_multi_rowcount = False - preexecute_pk_sequences = True - supports_pk_autoincrement = False - default_paramstyle = 'named' + def visit_BINARY(self, type_): + return self.visit_BLOB(type_) + + def visit_BOOLEAN(self, type_): + return self.visit_SMALLINT(type_) + + def visit_RAW(self, type_): + return "RAW(%(length)s)" % {'length' : type_.length} - def __init__(self, use_ansi=True, auto_setinputsizes=True, auto_convert_lobs=True, threaded=True, allow_twophase=True, optimize_limits=False, arraysize=50, **kwargs): - default.DefaultDialect.__init__(self, **kwargs) - self.use_ansi = use_ansi - self.threaded = threaded - self.arraysize = arraysize - self.allow_twophase = allow_twophase - self.optimize_limits = optimize_limits - self.supports_timestamp = self.dbapi is None or hasattr(self.dbapi, 'TIMESTAMP' ) - self.auto_setinputsizes = auto_setinputsizes - self.auto_convert_lobs = auto_convert_lobs - if self.dbapi is None or not self.auto_convert_lobs or not 'CLOB' in self.dbapi.__dict__: - self.dbapi_type_map = {} - self.ORACLE_BINARY_TYPES = [] +class OracleCompiler(compiler.SQLCompiler): + """Oracle compiler modifies the lexical structure of Select + statements to work under non-ANSI configured Oracle databases, if + the use_ansi flag is False. + """ + + operators = util.update_copy( + compiler.SQLCompiler.operators, + { + sql_operators.mod : lambda x, y:"mod(%s, %s)" % (x, y), + sql_operators.match_op: lambda x, y: "CONTAINS (%s, %s)" % (x, y) + } + ) + + functions = util.update_copy( + compiler.SQLCompiler.functions, + { + sql_functions.now : 'CURRENT_TIMESTAMP' + } + ) + + def __init__(self, *args, **kwargs): + super(OracleCompiler, self).__init__(*args, **kwargs) + self.__wheres = {} + self._quoted_bind_names = {} + + def bindparam_string(self, name): + if self.preparer._bindparam_requires_quotes(name): + quoted_name = '"%s"' % name + self._quoted_bind_names[name] = quoted_name + return compiler.SQLCompiler.bindparam_string(self, quoted_name) else: - # only use this for LOB objects. using it for strings, dates - # etc. leads to a little too much magic, reflection doesn't know if it should - # expect encoded strings or unicodes, etc. - self.dbapi_type_map = { - self.dbapi.CLOB: OracleText(), - self.dbapi.BLOB: OracleBinary(), - self.dbapi.BINARY: OracleRaw(), - } - self.ORACLE_BINARY_TYPES = [getattr(self.dbapi, k) for k in ["BFILE", "CLOB", "NCLOB", "BLOB"] if hasattr(self.dbapi, k)] - - def dbapi(cls): - import cx_Oracle - return cx_Oracle - dbapi = classmethod(dbapi) - - def create_connect_args(self, url): - dialect_opts = dict(url.query) - for opt in ('use_ansi', 'auto_setinputsizes', 'auto_convert_lobs', - 'threaded', 'allow_twophase'): - if opt in dialect_opts: - util.coerce_kw_type(dialect_opts, opt, bool) - setattr(self, opt, dialect_opts[opt]) - - if url.database: - # if we have a database, then we have a remote host - port = url.port - if port: - port = int(port) + return compiler.SQLCompiler.bindparam_string(self, name) + + def default_from(self): + """Called when a ``SELECT`` statement has no froms, and no ``FROM`` clause is to be appended. + + The Oracle compiler tacks a "FROM DUAL" to the statement. + """ + + return " FROM DUAL" + + def apply_function_parens(self, func): + return len(func.clauses) > 0 + + def visit_join(self, join, **kwargs): + if self.dialect.use_ansi: + return compiler.SQLCompiler.visit_join(self, join, **kwargs) + else: + return self.process(join.left, asfrom=True) + ", " + self.process(join.right, asfrom=True) + + def _get_nonansi_join_whereclause(self, froms): + clauses = [] + + def visit_join(join): + if join.isouter: + def visit_binary(binary): + if binary.operator == sql_operators.eq: + if binary.left.table is join.right: + binary.left = _OuterJoinColumn(binary.left) + elif binary.right.table is join.right: + binary.right = _OuterJoinColumn(binary.right) + clauses.append(visitors.cloned_traverse(join.onclause, {}, {'binary':visit_binary})) else: - port = 1521 - dsn = self.dbapi.makedsn(url.host, port, url.database) + clauses.append(join.onclause) + + for f in froms: + visitors.traverse(f, {}, {'join':visit_join}) + return sql.and_(*clauses) + + def visit_outer_join_column(self, vc): + return self.process(vc.column) + "(+)" + + def visit_sequence(self, seq): + return self.dialect.identifier_preparer.format_sequence(seq) + ".nextval" + + def visit_alias(self, alias, asfrom=False, **kwargs): + """Oracle doesn't like ``FROM table AS alias``. Is the AS standard SQL??""" + + if asfrom: + return self.process(alias.original, asfrom=asfrom, **kwargs) + " " + self.preparer.format_alias(alias, self._anonymize(alias.name)) else: - # we have a local tnsname - dsn = url.host - - opts = dict( - user=url.username, - password=url.password, - dsn=dsn, - threaded=self.threaded, - twophase=self.allow_twophase, - ) - if 'mode' in url.query: - opts['mode'] = url.query['mode'] - if isinstance(opts['mode'], basestring): - mode = opts['mode'].upper() - if mode == 'SYSDBA': - opts['mode'] = self.dbapi.SYSDBA - elif mode == 'SYSOPER': - opts['mode'] = self.dbapi.SYSOPER + return self.process(alias.original, **kwargs) + + def _TODO_visit_compound_select(self, select): + """Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle.""" + pass + + def visit_select(self, select, **kwargs): + """Look for ``LIMIT`` and OFFSET in a select statement, and if + so tries to wrap it in a subquery with ``rownum`` criterion. + """ + + if not getattr(select, '_oracle_visit', None): + if not self.dialect.use_ansi: + if self.stack and 'from' in self.stack[-1]: + existingfroms = self.stack[-1]['from'] else: - util.coerce_kw_type(opts, 'mode', int) - # Can't set 'handle' or 'pool' via URL query args, use connect_args + existingfroms = None - return ([], opts) + froms = select._get_display_froms(existingfroms) + whereclause = self._get_nonansi_join_whereclause(froms) + if whereclause: + select = select.where(whereclause) + select._oracle_visit = True - def is_disconnect(self, e): - if isinstance(e, self.dbapi.InterfaceError): - return "not connected" in str(e) - else: - return "ORA-03114" in str(e) or "ORA-03113" in str(e) + if select._limit is not None or select._offset is not None: + # See http://www.oracle.com/technology/oramag/oracle/06-sep/o56asktom.html + # + # Generalized form of an Oracle pagination query: + # select ... from ( + # select /*+ FIRST_ROWS(N) */ ...., rownum as ora_rn from ( + # select distinct ... where ... order by ... + # ) where ROWNUM <= :limit+:offset + # ) where ora_rn > :offset + # Outer select and "ROWNUM as ora_rn" can be dropped if limit=0 - def type_descriptor(self, typeobj): - return sqltypes.adapt_type(typeobj, colspecs) + # TODO: use annotations instead of clone + attr set ? + select = select._generate() + select._oracle_visit = True - def create_xid(self): - """create a two-phase transaction ID. + # Wrap the middle select and add the hint + limitselect = sql.select([c for c in select.c]) + if select._limit and self.dialect.optimize_limits: + limitselect = limitselect.prefix_with("/*+ FIRST_ROWS(%d) */" % select._limit) - this id will be passed to do_begin_twophase(), do_rollback_twophase(), - do_commit_twophase(). its format is unspecified.""" + limitselect._oracle_visit = True + limitselect._is_wrapper = True - id = random.randint(0, 2 ** 128) - return (0x1234, "%032x" % id, "%032x" % 9) - - def do_release_savepoint(self, connection, name): - # Oracle does not support RELEASE SAVEPOINT - pass + # If needed, add the limiting clause + if select._limit is not None: + max_row = select._limit + if select._offset is not None: + max_row += select._offset + limitselect.append_whereclause( + sql.literal_column("ROWNUM")<=max_row) - def do_begin_twophase(self, connection, xid): - connection.connection.begin(*xid) + # If needed, add the ora_rn, and wrap again with offset. + if select._offset is None: + select = limitselect + else: + limitselect = limitselect.column( + sql.literal_column("ROWNUM").label("ora_rn")) + limitselect._oracle_visit = True + limitselect._is_wrapper = True - def do_prepare_twophase(self, connection, xid): - connection.connection.prepare() + offsetselect = sql.select( + [c for c in limitselect.c if c.key!='ora_rn']) + offsetselect._oracle_visit = True + offsetselect._is_wrapper = True - def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False): - self.do_rollback(connection.connection) + offsetselect.append_whereclause( + sql.literal_column("ora_rn")>select._offset) - def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False): - self.do_commit(connection.connection) + select = offsetselect - def do_recover_twophase(self, connection): - pass + kwargs['iswrapper'] = getattr(select, '_is_wrapper', False) + return compiler.SQLCompiler.visit_select(self, select, **kwargs) + + def limit_clause(self, select): + return "" + + def for_update_clause(self, select): + if select.for_update == "nowait": + return " FOR UPDATE NOWAIT" + else: + return super(OracleCompiler, self).for_update_clause(select) + +class OracleDDLCompiler(compiler.DDLCompiler): + def get_column_specification(self, column, **kwargs): + colspec = self.preparer.format_column(column) + colspec += " " + self.dialect.type_compiler.process(column.type) + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + + if not column.nullable: + colspec += " NOT NULL" + return colspec + + def visit_create_sequence(self, create): + return "CREATE SEQUENCE %s" % self.preparer.format_sequence(create.element) + + def visit_drop_sequence(self, drop): + return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element) + +class OracleDefaultRunner(base.DefaultRunner): + def visit_sequence(self, seq): + return self.execute_string("SELECT " + self.dialect.identifier_preparer.format_sequence(seq) + ".nextval FROM DUAL", {}) + +class OracleIdentifierPreparer(compiler.IdentifierPreparer): + + reserved_words = set([x.lower() for x in RESERVED_WORDS]) + + def _bindparam_requires_quotes(self, value): + """Return True if the given identifier requires quoting.""" + lc_value = value.lower() + return (lc_value in self.reserved_words + or self.illegal_initial_characters.match(value[0]) + or not self.legal_characters.match(unicode(value)) + ) + + def format_savepoint(self, savepoint): + name = re.sub(r'^_+', '', savepoint.ident) + return super(OracleIdentifierPreparer, self).format_savepoint(savepoint, name) + +class OracleDialect(default.DefaultDialect): + name = 'oracle' + supports_alter = True + supports_unicode_statements = False + supports_unicode_binds = False + max_identifier_length = 30 + supports_sane_rowcount = True + supports_sane_multi_rowcount = False + supports_sequences = True + sequences_optional = False + preexecute_pk_sequences = True + supports_pk_autoincrement = False + default_paramstyle = 'named' + colspecs = colspecs + ischema_names = ischema_names + + supports_default_values = False + supports_empty_insert = False + + statement_compiler = OracleCompiler + ddl_compiler = OracleDDLCompiler + type_compiler = OracleTypeCompiler + preparer = OracleIdentifierPreparer + defaultrunner = OracleDefaultRunner + + def __init__(self, + use_ansi=True, + optimize_limits=False, + **kwargs): + default.DefaultDialect.__init__(self, **kwargs) + self.use_ansi = use_ansi + self.optimize_limits = optimize_limits def has_table(self, connection, table_name, schema=None): if not schema: @@ -508,10 +508,9 @@ class OracleDialect(default.DefaultDialect): else: return name.encode(self.encoding) + @base.connection_memoize(('dialect', 'default_schema_name')) def get_default_schema_name(self, connection): return self._normalize_name(connection.execute('SELECT USER FROM DUAL').scalar()) - get_default_schema_name = base.connection_memoize( - ('dialect', 'default_schema_name'))(get_default_schema_name) def table_names(self, connection, schema): # note that table_names() isnt loading DBLINKed or synonym'ed tables @@ -601,17 +600,17 @@ class OracleDialect(default.DefaultDialect): #length is ignored except for CHAR and VARCHAR2 if coltype == 'NUMBER' : if precision is None and scale is None: - coltype = OracleNumeric + coltype = sqltypes.NUMERIC elif precision is None and scale == 0 : - coltype = OracleInteger + coltype = sqltypes.INTEGER else : - coltype = OracleNumeric(precision, scale) + coltype = sqltypes.NUMERIC(precision, scale) elif coltype=='CHAR' or coltype=='VARCHAR2': - coltype = ischema_names.get(coltype, OracleString)(length) + coltype = self.ischema_names.get(coltype)(length) else: coltype = re.sub(r'\(\d+\)', '', coltype) try: - coltype = ischema_names[coltype] + coltype = self.ischema_names[coltype] except KeyError: util.warn("Did not recognize type '%s' of column '%s'" % (coltype, colname)) @@ -653,8 +652,9 @@ class OracleDialect(default.DefaultDialect): if row is None: break #print "ROW:" , row - (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = row[0:2] + tuple([self._normalize_name(x) for x in row[2:]]) - if cons_type == 'P': + (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = \ + row[0:2] + tuple([self._normalize_name(x) for x in row[2:]]) + if cons_type == 'P' and local_column in table.c: table.primary_key.add(table.c[local_column]) elif cons_type == 'R': try: @@ -698,203 +698,5 @@ class _OuterJoinColumn(sql.ClauseElement): def __init__(self, column): self.column = column -class OracleCompiler(compiler.SQLCompiler): - """Oracle compiler modifies the lexical structure of Select - statements to work under non-ANSI configured Oracle databases, if - the use_ansi flag is False. - """ - - operators = compiler.SQLCompiler.operators.copy() - operators.update( - { - sql_operators.mod : lambda x, y:"mod(%s, %s)" % (x, y), - sql_operators.match_op: lambda x, y: "CONTAINS (%s, %s)" % (x, y) - } - ) - - functions = compiler.SQLCompiler.functions.copy() - functions.update ( - { - sql_functions.now : 'CURRENT_TIMESTAMP' - } - ) - - def __init__(self, *args, **kwargs): - super(OracleCompiler, self).__init__(*args, **kwargs) - self.__wheres = {} - - def default_from(self): - """Called when a ``SELECT`` statement has no froms, and no ``FROM`` clause is to be appended. - - The Oracle compiler tacks a "FROM DUAL" to the statement. - """ - - return " FROM DUAL" - - def apply_function_parens(self, func): - return len(func.clauses) > 0 - - def visit_join(self, join, **kwargs): - if self.dialect.use_ansi: - return compiler.SQLCompiler.visit_join(self, join, **kwargs) - else: - return self.process(join.left, asfrom=True) + ", " + self.process(join.right, asfrom=True) - - def _get_nonansi_join_whereclause(self, froms): - clauses = [] - - def visit_join(join): - if join.isouter: - def visit_binary(binary): - if binary.operator == sql_operators.eq: - if binary.left.table is join.right: - binary.left = _OuterJoinColumn(binary.left) - elif binary.right.table is join.right: - binary.right = _OuterJoinColumn(binary.right) - clauses.append(visitors.cloned_traverse(join.onclause, {}, {'binary':visit_binary})) - else: - clauses.append(join.onclause) - - for f in froms: - visitors.traverse(f, {}, {'join':visit_join}) - return sql.and_(*clauses) - - def visit_outer_join_column(self, vc): - return self.process(vc.column) + "(+)" - - def visit_sequence(self, seq): - return self.dialect.identifier_preparer.format_sequence(seq) + ".nextval" - - def visit_alias(self, alias, asfrom=False, **kwargs): - """Oracle doesn't like ``FROM table AS alias``. Is the AS standard SQL??""" - - if asfrom: - return self.process(alias.original, asfrom=asfrom, **kwargs) + " " + self.preparer.format_alias(alias, self._anonymize(alias.name)) - else: - return self.process(alias.original, **kwargs) - - def _TODO_visit_compound_select(self, select): - """Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle.""" - pass - - def visit_select(self, select, **kwargs): - """Look for ``LIMIT`` and OFFSET in a select statement, and if - so tries to wrap it in a subquery with ``rownum`` criterion. - """ - - if not getattr(select, '_oracle_visit', None): - if not self.dialect.use_ansi: - if self.stack and 'from' in self.stack[-1]: - existingfroms = self.stack[-1]['from'] - else: - existingfroms = None - - froms = select._get_display_froms(existingfroms) - whereclause = self._get_nonansi_join_whereclause(froms) - if whereclause: - select = select.where(whereclause) - select._oracle_visit = True - - if select._limit is not None or select._offset is not None: - # See http://www.oracle.com/technology/oramag/oracle/06-sep/o56asktom.html - # - # Generalized form of an Oracle pagination query: - # select ... from ( - # select /*+ FIRST_ROWS(N) */ ...., rownum as ora_rn from ( - # select distinct ... where ... order by ... - # ) where ROWNUM <= :limit+:offset - # ) where ora_rn > :offset - # Outer select and "ROWNUM as ora_rn" can be dropped if limit=0 - - # TODO: use annotations instead of clone + attr set ? - select = select._generate() - select._oracle_visit = True - - # Wrap the middle select and add the hint - limitselect = sql.select([c for c in select.c]) - if select._limit and self.dialect.optimize_limits: - limitselect = limitselect.prefix_with("/*+ FIRST_ROWS(%d) */" % select._limit) - - limitselect._oracle_visit = True - limitselect._is_wrapper = True - - # If needed, add the limiting clause - if select._limit is not None: - max_row = select._limit - if select._offset is not None: - max_row += select._offset - limitselect.append_whereclause( - sql.literal_column("ROWNUM")<=max_row) - - # If needed, add the ora_rn, and wrap again with offset. - if select._offset is None: - select = limitselect - else: - limitselect = limitselect.column( - sql.literal_column("ROWNUM").label("ora_rn")) - limitselect._oracle_visit = True - limitselect._is_wrapper = True - - offsetselect = sql.select( - [c for c in limitselect.c if c.key!='ora_rn']) - offsetselect._oracle_visit = True - offsetselect._is_wrapper = True - - offsetselect.append_whereclause( - sql.literal_column("ora_rn")>select._offset) - - select = offsetselect - - kwargs['iswrapper'] = getattr(select, '_is_wrapper', False) - return compiler.SQLCompiler.visit_select(self, select, **kwargs) - - def limit_clause(self, select): - return "" - - def for_update_clause(self, select): - if select.for_update == "nowait": - return " FOR UPDATE NOWAIT" - else: - return super(OracleCompiler, self).for_update_clause(select) - - -class OracleSchemaGenerator(compiler.SchemaGenerator): - def get_column_specification(self, column, **kwargs): - colspec = self.preparer.format_column(column) - colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec() - default = self.get_column_default_string(column) - if default is not None: - colspec += " DEFAULT " + default - - if not column.nullable: - colspec += " NOT NULL" - return colspec - - def visit_sequence(self, sequence): - if not self.checkfirst or not self.dialect.has_sequence(self.connection, sequence.name, sequence.schema): - self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence)) - self.execute() - -class OracleSchemaDropper(compiler.SchemaDropper): - def visit_sequence(self, sequence): - if not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name, sequence.schema): - self.append("DROP SEQUENCE %s" % self.preparer.format_sequence(sequence)) - self.execute() - -class OracleDefaultRunner(base.DefaultRunner): - def visit_sequence(self, seq): - return self.execute_string("SELECT " + self.dialect.identifier_preparer.format_sequence(seq) + ".nextval FROM DUAL", {}) - -class OracleIdentifierPreparer(compiler.IdentifierPreparer): - def format_savepoint(self, savepoint): - name = re.sub(r'^_+', '', savepoint.ident) - return super(OracleIdentifierPreparer, self).format_savepoint(savepoint, name) -dialect = OracleDialect -dialect.statement_compiler = OracleCompiler -dialect.schemagenerator = OracleSchemaGenerator -dialect.schemadropper = OracleSchemaDropper -dialect.preparer = OracleIdentifierPreparer -dialect.defaultrunner = OracleDefaultRunner -dialect.execution_ctx_cls = OracleExecutionContext diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py new file mode 100644 index 0000000000..b899d4438d --- /dev/null +++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py @@ -0,0 +1,252 @@ +"""Support for the Oracle database via the cx_oracle driver. + +Driver +------ + +The Oracle dialect uses the cx_oracle driver, available at +http://cx-oracle.sourceforge.net/ . The dialect has several behaviors +which are specifically tailored towards compatibility with this module. + +Connecting +---------- + +Connecting with create_engine() uses the standard URL approach of +``oracle://user:pass@host:port/dbname[?key=value&key=value...]``. If dbname is present, the +host, port, and dbname tokens are converted to a TNS name using the cx_oracle +:func:`makedsn()` function. Otherwise, the host token is taken directly as a TNS name. + +Additional arguments which may be specified either as query string arguments on the +URL, or as keyword arguments to :func:`~sqlalchemy.create_engine()` are: + +* *allow_twophase* - enable two-phase transactions. Defaults to ``True``. + +* *auto_convert_lobs* - defaults to True, see the section on LOB objects. + +* *auto_setinputsizes* - the cx_oracle.setinputsizes() call is issued for all bind parameters. + This is required for LOB datatypes but can be disabled to reduce overhead. Defaults + to ``True``. + +* *mode* - This is given the string value of SYSDBA or SYSOPER, or alternatively an + integer value. This value is only available as a URL query string argument. + +* *threaded* - enable multithreaded access to cx_oracle connections. Defaults + to ``True``. Note that this is the opposite default of cx_oracle itself. + + +LOB Objects +----------- + +cx_oracle presents some challenges when fetching LOB objects. A LOB object in a result set +is presented by cx_oracle as a cx_oracle.LOB object which has a read() method. By default, +SQLAlchemy converts these LOB objects into Python strings. This is for two reasons. First, +the LOB object requires an active cursor association, meaning if you were to fetch many rows +at once such that cx_oracle had to go back to the database and fetch a new batch of rows, +the LOB objects in the already-fetched rows are now unreadable and will raise an error. +SQLA "pre-reads" all LOBs so that their data is fetched before further rows are read. +The size of a "batch of rows" is controlled by the cursor.arraysize value, which SQLAlchemy +defaults to 50 (cx_oracle normally defaults this to one). + +Secondly, the LOB object is not a standard DBAPI return value so SQLAlchemy seeks to +"normalize" the results to look more like other DBAPIs. + +The conversion of LOB objects by this dialect is unique in SQLAlchemy in that it takes place +for all statement executions, even plain string-based statements for which SQLA has no awareness +of result typing. This is so that calls like fetchmany() and fetchall() can work in all cases +without raising cursor errors. The conversion of LOB in all cases, as well as the "prefetch" +of LOB objects, can be disabled using auto_convert_lobs=False. + +Two Phase Transaction Support +----------------------------- + +Two Phase transactions are implemented using XA transactions. Success has been reported of them +working successfully but this should be regarded as an experimental feature. + +""" + +from sqlalchemy.dialects.oracle.base import OracleDialect, OracleText, OracleBinary, OracleRaw, RESERVED_WORDS +from sqlalchemy.engine.default import DefaultExecutionContext +from sqlalchemy.engine import base +from sqlalchemy import types as sqltypes, util + +class OracleNVarchar(sqltypes.NVARCHAR): + """The SQL NVARCHAR type.""" + + def __init__(self, **kw): + kw['convert_unicode'] = False # cx_oracle does this for us, for NVARCHAR2 + sqltypes.NVARCHAR.__init__(self, **kw) + +class Oracle_cx_oracleExecutionContext(DefaultExecutionContext): + def pre_exec(self): + + quoted_bind_names = getattr(self.compiled, '_quoted_bind_names', {}) + if quoted_bind_names: + for param in self.parameters: + for fromname, toname in self.compiled._quoted_bind_names.iteritems(): + param[toname.encode(self.dialect.encoding)] = param[fromname] + del param[fromname] + + if self.dialect.auto_setinputsizes: + self.set_input_sizes(quoted_bind_names) + + if len(self.compiled_parameters) == 1: + for key in self.compiled.binds: + bindparam = self.compiled.binds[key] + name = self.compiled.bind_names[bindparam] + value = self.compiled_parameters[0][name] + if bindparam.isoutparam: + dbtype = bindparam.type.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) + if not hasattr(self, 'out_parameters'): + self.out_parameters = {} + self.out_parameters[name] = self.cursor.var(dbtype) + self.parameters[0][quoted_bind_names.get(name, name)] = self.out_parameters[name] + + + def create_cursor(self): + c = self._connection.connection.cursor() + if self.dialect.arraysize: + c.cursor.arraysize = self.dialect.arraysize + return c + + def get_result_proxy(self): + if hasattr(self, 'out_parameters'): + if self.compiled_parameters is not None and len(self.compiled_parameters) == 1: + for bind, name in self.compiled.bind_names.iteritems(): + if name in self.out_parameters: + type = bind.type + result_processor = type.dialect_impl(self.dialect).result_processor(self.dialect) + if result_processor is not None: + self.out_parameters[name] = result_processor(self.out_parameters[name].getvalue()) + else: + self.out_parameters[name] = self.out_parameters[name].getvalue() + else: + for k in self.out_parameters: + self.out_parameters[k] = self.out_parameters[k].getvalue() + + if self.cursor.description is not None: + for column in self.cursor.description: + type_code = column[1] + if type_code in self.dialect.ORACLE_BINARY_TYPES: + return base.BufferedColumnResultProxy(self) + + return base.ResultProxy(self) + + +class Oracle_cx_oracle(OracleDialect): + execution_ctx_cls = Oracle_cx_oracleExecutionContext + driver = "cx_oracle" + + colspecs = util.update_copy( + OracleDialect.colspecs, + { + sqltypes.NVARCHAR:OracleNVarchar + } + ) + + def __init__(self, + auto_setinputsizes=True, + auto_convert_lobs=True, + threaded=True, + allow_twophase=True, + arraysize=50, **kwargs): + OracleDialect.__init__(self, **kwargs) + self.threaded = threaded + self.arraysize = arraysize + self.allow_twophase = allow_twophase + self.supports_timestamp = self.dbapi is None or hasattr(self.dbapi, 'TIMESTAMP' ) + self.auto_setinputsizes = auto_setinputsizes + self.auto_convert_lobs = auto_convert_lobs + if self.dbapi is None or not self.auto_convert_lobs or not 'CLOB' in self.dbapi.__dict__: + self.dbapi_type_map = {} + self.ORACLE_BINARY_TYPES = [] + else: + # only use this for LOB objects. using it for strings, dates + # etc. leads to a little too much magic, reflection doesn't know if it should + # expect encoded strings or unicodes, etc. + self.dbapi_type_map = { + self.dbapi.CLOB: OracleText(), + self.dbapi.BLOB: OracleBinary(), + self.dbapi.BINARY: OracleRaw(), + } + self.ORACLE_BINARY_TYPES = [getattr(self.dbapi, k) for k in ["BFILE", "CLOB", "NCLOB", "BLOB"] if hasattr(self.dbapi, k)] + + @classmethod + def dbapi(cls): + import cx_Oracle + return cx_Oracle + + def create_connect_args(self, url): + dialect_opts = dict(url.query) + for opt in ('use_ansi', 'auto_setinputsizes', 'auto_convert_lobs', + 'threaded', 'allow_twophase'): + if opt in dialect_opts: + util.coerce_kw_type(dialect_opts, opt, bool) + setattr(self, opt, dialect_opts[opt]) + + if url.database: + # if we have a database, then we have a remote host + port = url.port + if port: + port = int(port) + else: + port = 1521 + dsn = self.dbapi.makedsn(url.host, port, url.database) + else: + # we have a local tnsname + dsn = url.host + + opts = dict( + user=url.username, + password=url.password, + dsn=dsn, + threaded=self.threaded, + twophase=self.allow_twophase, + ) + if 'mode' in url.query: + opts['mode'] = url.query['mode'] + if isinstance(opts['mode'], basestring): + mode = opts['mode'].upper() + if mode == 'SYSDBA': + opts['mode'] = self.dbapi.SYSDBA + elif mode == 'SYSOPER': + opts['mode'] = self.dbapi.SYSOPER + else: + util.coerce_kw_type(opts, 'mode', int) + # Can't set 'handle' or 'pool' via URL query args, use connect_args + + return ([], opts) + + def is_disconnect(self, e): + if isinstance(e, self.dbapi.InterfaceError): + return "not connected" in str(e) + else: + return "ORA-03114" in str(e) or "ORA-03113" in str(e) + + def create_xid(self): + """create a two-phase transaction ID. + + this id will be passed to do_begin_twophase(), do_rollback_twophase(), + do_commit_twophase(). its format is unspecified.""" + + id = random.randint(0, 2 ** 128) + return (0x1234, "%032x" % id, "%032x" % 9) + + def do_release_savepoint(self, connection, name): + # Oracle does not support RELEASE SAVEPOINT + pass + + def do_begin_twophase(self, connection, xid): + connection.connection.begin(*xid) + + def do_prepare_twophase(self, connection, xid): + connection.connection.prepare() + + def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False): + self.do_rollback(connection.connection) + + def do_commit_twophase(self, connection, xid, is_prepared=True, recover=False): + self.do_commit(connection.connection) + + def do_recover_twophase(self, connection): + pass + +dialect = Oracle_cx_oracle diff --git a/lib/sqlalchemy/dialects/postgres/base.py b/lib/sqlalchemy/dialects/postgres/base.py index 8fd4ef5ef2..ce1fa1507d 100644 --- a/lib/sqlalchemy/dialects/postgres/base.py +++ b/lib/sqlalchemy/dialects/postgres/base.py @@ -155,28 +155,28 @@ colspecs = { } ischema_names = { - 'integer' : sqltypes.Integer, + 'integer' : sqltypes.INTEGER, 'bigint' : PGBigInteger, - 'smallint' : sqltypes.SmallInteger, - 'character varying' : sqltypes.String, + 'smallint' : sqltypes.SMALLINT, + 'character varying' : sqltypes.VARCHAR, 'character' : sqltypes.CHAR, - 'text' : sqltypes.Text, - 'numeric' : sqltypes.Numeric, - 'float' : sqltypes.Float, + 'text' : sqltypes.TEXT, + 'numeric' : sqltypes.NUMERIC, + 'float' : sqltypes.FLOAT, 'real' : sqltypes.Float, 'inet': PGInet, 'cidr': PGCidr, 'macaddr': PGMacAddr, 'double precision' : sqltypes.Float, - 'timestamp' : sqltypes.DateTime, - 'timestamp with time zone' : sqltypes.DateTime, - 'timestamp without time zone' : sqltypes.DateTime, - 'time with time zone' : sqltypes.Time, - 'time without time zone' : sqltypes.Time, - 'date' : sqltypes.Date, - 'time': sqltypes.Time, + 'timestamp' : sqltypes.TIMESTAMP, + 'timestamp with time zone' : sqltypes.TIMESTAMP, + 'timestamp without time zone' : sqltypes.TIMESTAMP, + 'time with time zone' : sqltypes.TIME, + 'time without time zone' : sqltypes.TIME, + 'date' : sqltypes.DATE, + 'time': sqltypes.TIME, 'bytea' : sqltypes.Binary, - 'boolean' : sqltypes.Boolean, + 'boolean' : sqltypes.BOOLEAN, 'interval':PGInterval, } @@ -490,9 +490,6 @@ class PGDialect(default.DefaultDialect): raise AssertionError("Could not determine version from string '%s'" % v) return tuple([int(x) for x in m.group(1, 2, 3)]) - def type_descriptor(self, typeobj): - return sqltypes.adapt_type(typeobj, self.colspecs) - def reflecttable(self, connection, table, include_columns): preparer = self.identifier_preparer if table.schema is not None: diff --git a/lib/sqlalchemy/dialects/postgres/psycopg2.py b/lib/sqlalchemy/dialects/postgres/psycopg2.py index b90ac8d9c6..364c13236d 100644 --- a/lib/sqlalchemy/dialects/postgres/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgres/psycopg2.py @@ -100,8 +100,6 @@ class Postgres_psycopg2Compiler(PGCompiler): ) def post_process_text(self, text): - if '%%' in text: - util.warn("The SQLAlchemy postgres dialect now automatically escapes '%' in text() expressions to '%%'.") return text.replace('%', '%%') class Postgres_psycopg2(PGDialect): diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index 773501d64c..319a5bffc6 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -146,23 +146,23 @@ colspecs = { } ischema_names = { - 'BLOB': sqltypes.Binary, - 'BOOL': sqltypes.Boolean, - 'BOOLEAN': sqltypes.Boolean, + 'BLOB': sqltypes.BLOB, + 'BOOL': sqltypes.BOOLEAN, + 'BOOLEAN': sqltypes.BOOLEAN, 'CHAR': sqltypes.CHAR, - 'DATE': sqltypes.Date, - 'DATETIME': sqltypes.DateTime, - 'DECIMAL': sqltypes.Numeric, - 'FLOAT': sqltypes.Numeric, - 'INT': sqltypes.Integer, - 'INTEGER': sqltypes.Integer, - 'NUMERIC': sqltypes.Numeric, + 'DATE': sqltypes.DATE, + 'DATETIME': sqltypes.DATETIME, + 'DECIMAL': sqltypes.DECIMAL, + 'FLOAT': sqltypes.FLOAT, + 'INT': sqltypes.INTEGER, + 'INTEGER': sqltypes.INTEGER, + 'NUMERIC': sqltypes.NUMERIC, 'REAL': sqltypes.Numeric, - 'SMALLINT': sqltypes.SmallInteger, - 'TEXT': sqltypes.Text, - 'TIME': sqltypes.Time, - 'TIMESTAMP': sqltypes.DateTime, - 'VARCHAR': sqltypes.String, + 'SMALLINT': sqltypes.SMALLINT, + 'TEXT': sqltypes.TEXT, + 'TIME': sqltypes.TIME, + 'TIMESTAMP': sqltypes.TIMESTAMP, + 'VARCHAR': sqltypes.VARCHAR, } @@ -256,10 +256,8 @@ class SQLiteDialect(default.DefaultDialect): type_compiler = SQLiteTypeCompiler preparer = SQLiteIdentifierPreparer ischema_names = ischema_names - - def type_descriptor(self, typeobj): - return sqltypes.adapt_type(typeobj, colspecs) - + colspecs = colspecs + def table_names(self, connection, schema): if schema is not None: qschema = self.identifier_preparer.quote_identifier(schema) diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 1dc3d720ef..8be0a2d85f 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -15,7 +15,7 @@ as the base class for their own corresponding classes. import re, random from sqlalchemy.engine import base from sqlalchemy.sql import compiler, expression -from sqlalchemy import exc +from sqlalchemy import exc, types as sqltypes AUTOCOMMIT_REGEXP = re.compile(r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)', re.I | re.UNICODE) @@ -72,13 +72,12 @@ class DefaultDialect(base.Dialect): """Provide a database-specific ``TypeEngine`` object, given the generic object which comes from the types module. - Subclasses will usually use the ``adapt_type()`` method in the - types module to make this job easy. + This method looks for a dictionary called + ``colspecs`` as a class or instance-level variable, + and passes on to ``types.adapt_type()``. """ - if type(typeobj) is type: - typeobj = typeobj() - return typeobj + return sqltypes.adapt_type(typeobj, self.colspecs) def validate_identifier(self, ident): if len(ident) > self.max_identifier_length: @@ -315,12 +314,16 @@ class DefaultExecutionContext(base.ExecutionContext): def lastrow_has_defaults(self): return hasattr(self, 'postfetch_cols') and len(self.postfetch_cols) - def set_input_sizes(self): + def set_input_sizes(self, translate=None): """Given a cursor and ClauseParameters, call the appropriate style of ``setinputsizes()`` on the cursor, using DB-API types from the bind parameter's ``TypeEngine`` objects. + """ + if not hasattr(self.compiled, 'bind_names'): + return + types = dict( (self.compiled.bind_names[bindparam], bindparam.type) for bindparam in self.compiled.bind_names) @@ -343,6 +346,8 @@ class DefaultExecutionContext(base.ExecutionContext): typeengine = types[key] dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) if dbtype is not None: + if translate: + key = translate.get(key, key) inputsizes[key.encode(self.dialect.encoding)] = dbtype try: self.cursor.setinputsizes(**inputsizes) diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 6838319987..d506efcacb 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -982,6 +982,9 @@ class GenericTypeCompiler(engine.TypeCompiler): def visit_VARCHAR(self, type_): return "VARCHAR" + (type_.length and "(%d)" % type_.length or "") + def visit_NVARCHAR(self, type_): + return "NVARCHAR" + (type_.length and "(%d)" % type_.length or "") + def visit_BLOB(self, type_): return "BLOB" diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index 92ee125b63..ea8a8ceb3a 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -917,6 +917,8 @@ class TIMESTAMP(DateTime): __visit_name__ = 'TIMESTAMP' + def get_dbapi_type(self, dbapi): + return dbapi.TIMESTAMP class DATETIME(DateTime): """The SQL DATETIME type.""" @@ -951,6 +953,10 @@ class VARCHAR(String): __visit_name__ = 'VARCHAR' +class NVARCHAR(Unicode): + """The SQL NVARCHAR type.""" + + __visit_name__ = 'NVARCHAR' class CHAR(String): """The SQL CHAR type.""" diff --git a/test/dialect/oracle.py b/test/dialect/oracle.py index 2186f22595..c55e778a24 100644 --- a/test/dialect/oracle.py +++ b/test/dialect/oracle.py @@ -2,6 +2,7 @@ import testenv; testenv.configure_for_tests() from sqlalchemy import * +from sqlalchemy import types as sqltypes from sqlalchemy.sql import table, column from sqlalchemy.databases import oracle from testlib import * @@ -301,13 +302,13 @@ class TypesTest(TestBase, AssertsCompiledSQL): def test_reflect_nvarchar(self): metadata = MetaData(testing.db) t = Table('t', metadata, - Column('data', oracle.OracleNVarchar(255)) + Column('data', sqltypes.NVARCHAR(255)) ) metadata.create_all() try: m2 = MetaData(testing.db) t2 = Table('t', m2, autoload=True) - assert isinstance(t2.c.data.type, oracle.OracleNVarchar) + assert isinstance(t2.c.data.type, sqltypes.NVARCHAR) data = u'm’a réveillé.' t2.insert().execute(data=data) eq_(t2.select().execute().fetchone()['data'], data) diff --git a/test/engine/reflection.py b/test/engine/reflection.py index 4e6601951f..cd037f6ca3 100644 --- a/test/engine/reflection.py +++ b/test/engine/reflection.py @@ -21,10 +21,10 @@ class ReflectionTest(TestBase, ComparesTables): Column('test2', sa.Float(5), nullable=False), Column('test3', sa.Text), Column('test4', sa.Numeric, nullable = False), - Column('test5', sa.DateTime), + Column('test5', sa.Date), Column('parent_user_id', sa.Integer, sa.ForeignKey('engine_users.user_id')), - Column('test6', sa.DateTime, nullable=False), + Column('test6', sa.Date, nullable=False), Column('test7', sa.Text), Column('test8', sa.Binary), Column('test_passivedefault2', sa.Integer, server_default='5'), diff --git a/test/sql/query.py b/test/sql/query.py index 660529c25c..0e45aff107 100644 --- a/test/sql/query.py +++ b/test/sql/query.py @@ -12,11 +12,11 @@ class QueryTest(TestBase): global users, users2, addresses, metadata metadata = MetaData(testing.db) users = Table('query_users', metadata, - Column('user_id', INT, primary_key = True), + Column('user_id', INT, Sequence('user_id_seq', optional=True), primary_key = True), Column('user_name', VARCHAR(20)), ) addresses = Table('query_addresses', metadata, - Column('address_id', Integer, primary_key=True), + Column('address_id', Integer, Sequence('address_id_seq', optional=True), primary_key=True), Column('user_id', Integer, ForeignKey('query_users.user_id')), Column('address', String(30))) @@ -252,6 +252,7 @@ class QueryTest(TestBase): eq_(expr.execute().fetchall(), result) + @testing.fails_on("oracle", "neither % nor %% are accepted") @testing.fails_on("+pg8000", "can't interpret result column from '%%'") @testing.emits_warning('.*now automatically escapes.*') def test_percents_in_text(self): @@ -484,13 +485,15 @@ class QueryTest(TestBase): self.assert_(r['query_users.user_id']) == 1 self.assert_(r['query_users.user_name']) == "john" + @testing.fails_on('oracle', 'oracle result keys() are all uppercase, not getting into this.') def test_row_as_args(self): users.insert().execute(user_id=1, user_name='john') r = users.select(users.c.user_id==1).execute().fetchone() users.delete().execute() users.insert().execute(r) - assert users.select().execute().fetchall() == [(1, 'john')] + eq_(users.select().execute().fetchall(), [(1, 'john')]) + @testing.fails_on('oracle', 'oracle result keys() are all uppercase, not getting into this.') def test_result_as_args(self): users.insert().execute([dict(user_id=1, user_name='john'), dict(user_id=2, user_name='ed')]) r = users.select().execute() @@ -720,7 +723,7 @@ class PercentSchemaNamesTest(TestBase): result.close() percent_table.update().values({percent_table.c['%(oneofthese)s']:9, percent_table.c['spaces % more spaces']:15}).execute() eq_( - percent_table.select().order_by(percent_table.c['%(oneofthese)s']).execute().fetchall(), + percent_table.select().order_by(percent_table.c['percent%']).execute().fetchall(), [ (5, 9, 15), (7, 9, 15),