From cdceb3c3714af707bfe3ede10af6536eaf529ca8 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 2 Apr 2007 21:36:11 +0000 Subject: [PATCH] - merged the "execcontext" branch, refactors engine/dialect codepaths - much more functionality moved into ExecutionContext, which impacted the API used by dialects to some degree - ResultProxy and subclasses now designed sanely - merged patch for #522, Unicode subclasses String directly, MSNVarchar implements for MS-SQL, removed MSUnicode. - String moves its "VARCHAR"/"TEXT" switchy thing into "get_search_list()" function, which VARCHAR and CHAR can override to not return TEXT in any case (didnt do the latter yet) - implements server side cursors for postgres, unit tests, #514 - includes overhaul of dbapi import strategy #480, all dbapi importing happens in dialect method "dbapi()", is only called inside of create_engine() for default and threadlocal strategies. Dialect subclasses have a datamember "dbapi" referencing the loaded module which may be None. - added "mock" engine strategy, doesnt require DBAPI module and gives you a "Connecition" which just sends all executes to a callable. can be used to create string output of create_all()/drop_all(). --- CHANGES | 20 +- lib/sqlalchemy/ansisql.py | 48 ++- lib/sqlalchemy/databases/firebird.py | 73 ++-- lib/sqlalchemy/databases/mssql.py | 159 ++++---- lib/sqlalchemy/databases/mysql.py | 50 ++- lib/sqlalchemy/databases/oracle.py | 75 ++-- lib/sqlalchemy/databases/postgres.py | 138 ++++--- lib/sqlalchemy/databases/sqlite.py | 60 ++-- lib/sqlalchemy/engine/base.py | 518 ++++++++++++++------------- lib/sqlalchemy/engine/default.py | 126 ++++--- lib/sqlalchemy/engine/strategies.py | 64 +++- lib/sqlalchemy/engine/url.py | 4 + lib/sqlalchemy/logging.py | 4 +- lib/sqlalchemy/pool.py | 14 +- lib/sqlalchemy/sql.py | 2 +- lib/sqlalchemy/types.py | 82 ++--- lib/sqlalchemy/util.py | 4 + test/engine/reflection.py | 5 +- test/orm/inheritance5.py | 2 +- test/orm/mapper.py | 12 +- test/sql/constraints.py | 12 +- test/sql/query.py | 4 +- test/sql/testtypes.py | 51 ++- test/testbase.py | 158 ++++---- 24 files changed, 852 insertions(+), 833 deletions(-) diff --git a/CHANGES b/CHANGES index fc8077167f..41a2ac3837 100644 --- a/CHANGES +++ b/CHANGES @@ -1,5 +1,22 @@ 0.3.7 +- engines + - SA default loglevel is now "WARN". we have a few warnings + for things that should be available by default. + - cleanup of DBAPI import strategies across all engines + [ticket:480] + - refactoring of engine internals which reduces complexity, + number of codepaths; places more state inside of ExecutionContext + to allow more dialect control of cursor handling, result sets. + ResultProxy totally refactored and also has two versions of + "buffered" result sets used for different purposes. + - server side cursor support fully functional in postgres + [ticket:514]. - sql: + - the Unicode type is now a direct subclass of String, which now + contains all the "convert_unicode" logic. This helps the variety + of unicode situations that occur in db's such as MS-SQL to be + better handled and allows subclassing of the Unicode datatype. + [ticket:522] - column labels are now generated in the compilation phase, which means their lengths are dialect-dependent. So on oracle a label that gets truncated to 30 chars will go out to 63 characters @@ -11,7 +28,8 @@ full statement being compiled. this means the same statement will produce the same string across application restarts and allowing DB query plan caching to work better. - - preliminary support for unicode table and column names added. + - preliminary support for unicode table names, column names and + SQL statements added, for databases which can support them. - fix for fetchmany() "size" argument being positional in most dbapis [ticket:505] - sending None as an argument to func. will produce diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index a75263d915..03053b998c 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -49,14 +49,11 @@ class ANSIDialect(default.DefaultDialect): def create_connect_args(self): return ([],{}) - def dbapi(self): - return None + def schemagenerator(self, *args, **kwargs): + return ANSISchemaGenerator(self, *args, **kwargs) - def schemagenerator(self, *args, **params): - return ANSISchemaGenerator(*args, **params) - - def schemadropper(self, *args, **params): - return ANSISchemaDropper(*args, **params) + def schemadropper(self, *args, **kwargs): + return ANSISchemaDropper(self, *args, **kwargs) def compiler(self, statement, parameters, **kwargs): return ANSICompiler(self, statement, parameters, **kwargs) @@ -97,6 +94,9 @@ class ANSICompiler(sql.Compiled): sql.Compiled.__init__(self, dialect, statement, parameters, **kwargs) + # if we are insert/update. set to true when we visit an INSERT or UPDATE + self.isinsert = self.isupdate = False + # a dictionary of bind parameter keys to _BindParamClause instances. self.binds = {} @@ -789,13 +789,12 @@ class ANSISchemaBase(engine.SchemaIterator): return alterables class ANSISchemaGenerator(ANSISchemaBase): - def __init__(self, engine, proxy, connection, checkfirst=False, tables=None, **kwargs): - super(ANSISchemaGenerator, self).__init__(engine, proxy, **kwargs) + def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs): + super(ANSISchemaGenerator, self).__init__(connection, **kwargs) self.checkfirst = checkfirst self.tables = tables and util.Set(tables) or None - self.connection = connection - self.preparer = self.engine.dialect.preparer() - self.dialect = self.engine.dialect + self.preparer = dialect.preparer() + self.dialect = dialect def get_column_specification(self, column, first_pk=False): raise NotImplementedError() @@ -804,7 +803,7 @@ class ANSISchemaGenerator(ANSISchemaBase): collection = [t for t in metadata.table_iterator(reverse=False, tables=self.tables) if (not self.checkfirst or not self.dialect.has_table(self.connection, t.name, schema=t.schema))] for table in collection: table.accept_visitor(self) - if self.supports_alter(): + if self.dialect.supports_alter(): for alterable in self.find_alterables(collection): self.add_foreignkey(alterable) @@ -857,7 +856,7 @@ class ANSISchemaGenerator(ANSISchemaBase): def _compile(self, tocompile, parameters): """compile the given string/parameters using this SchemaGenerator's dialect.""" - compiler = self.engine.dialect.compiler(tocompile, parameters) + compiler = self.dialect.compiler(tocompile, parameters) compiler.compile() return compiler @@ -880,11 +879,8 @@ class ANSISchemaGenerator(ANSISchemaBase): self.append("PRIMARY KEY ") self.append("(%s)" % (string.join([self.preparer.format_column(c) for c in constraint],', '))) - def supports_alter(self): - return True - def visit_foreign_key_constraint(self, constraint): - if constraint.use_alter and self.supports_alter(): + if constraint.use_alter and self.dialect.supports_alter(): return self.append(", \n\t ") self.define_foreign_key(constraint) @@ -927,25 +923,21 @@ class ANSISchemaGenerator(ANSISchemaBase): self.execute() class ANSISchemaDropper(ANSISchemaBase): - def __init__(self, engine, proxy, connection, checkfirst=False, tables=None, **kwargs): - super(ANSISchemaDropper, self).__init__(engine, proxy, **kwargs) + def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs): + super(ANSISchemaDropper, self).__init__(connection, **kwargs) self.checkfirst = checkfirst self.tables = tables - self.connection = connection - self.preparer = self.engine.dialect.preparer() - self.dialect = self.engine.dialect + self.preparer = dialect.preparer() + self.dialect = dialect def visit_metadata(self, metadata): collection = [t for t in metadata.table_iterator(reverse=True, tables=self.tables) if (not self.checkfirst or self.dialect.has_table(self.connection, t.name, schema=t.schema))] - if self.supports_alter(): + if self.dialect.supports_alter(): for alterable in self.find_alterables(collection): self.drop_foreignkey(alterable) for table in collection: table.accept_visitor(self) - def supports_alter(self): - return True - def visit_index(self, index): self.append("\nDROP INDEX " + index.name) self.execute() @@ -1099,3 +1091,5 @@ class ANSIIdentifierPreparer(object): """Prepare a quoted column name with table name.""" return self.format_column(column, use_table=True, name=column_name) + +dialect = ANSIDialect diff --git a/lib/sqlalchemy/databases/firebird.py b/lib/sqlalchemy/databases/firebird.py index 91a0869c61..2ab88101a9 100644 --- a/lib/sqlalchemy/databases/firebird.py +++ b/lib/sqlalchemy/databases/firebird.py @@ -15,12 +15,9 @@ import sqlalchemy.ansisql as ansisql import sqlalchemy.types as sqltypes import sqlalchemy.exceptions as exceptions -try: +def dbapi(): import kinterbasdb -except: - kinterbasdb = None - -dbmodule = kinterbasdb + return kinterbasdb _initialized_kb = False @@ -33,7 +30,6 @@ class FBNumeric(sqltypes.Numeric): return "NUMERIC(%(precision)s, %(length)s)" % { 'precision': self.precision, 'length' : self.length } - class FBInteger(sqltypes.Integer): def get_col_spec(self): return "INTEGER" @@ -111,24 +107,11 @@ class FBExecutionContext(default.DefaultExecutionContext): class FBDialect(ansisql.ANSIDialect): - def __init__(self, module = None, **params): - global _initialized_kb - self.module = module or dbmodule - self.opts = {} - - if not _initialized_kb: - _initialized_kb = True - type_conv = params.get('type_conv', 200) or 200 - if isinstance(type_conv, types.StringTypes): - type_conv = int(type_conv) - - concurrency_level = params.get('concurrency_level', 1) or 1 - if isinstance(concurrency_level, types.StringTypes): - concurrency_level = int(concurrency_level) + def __init__(self, type_conv=200, concurrency_level=1, **kwargs): + ansisql.ANSIDialect.__init__(self, **kwargs) - if kinterbasdb is not None: - kinterbasdb.init(type_conv=type_conv, concurrency_level=concurrency_level) - ansisql.ANSIDialect.__init__(self, **params) + self.type_conv = type_conv + self.concurrency_level= concurrency_level def create_connect_args(self, url): opts = url.translate_connect_args(['host', 'database', 'user', 'password', 'port']) @@ -136,15 +119,17 @@ class FBDialect(ansisql.ANSIDialect): opts['host'] = "%s/%s" % (opts['host'], opts['port']) del opts['port'] opts.update(url.query) - # pop arguments that we took at the module level - opts.pop('type_conv', None) - opts.pop('concurrency_level', None) - self.opts = opts - return ([], self.opts) + type_conv = opts.pop('type_conv', self.type_conv) + concurrency_level = opts.pop('concurrency_level', self.concurrency_level) + global _initialized_kb + if not _initialized_kb and self.dbapi is not None: + _initialized_kb = True + self.dbapi.init(type_conv=type_conv, concurrency_level=concurrency_level) + return ([], opts) - def create_execution_context(self): - return FBExecutionContext(self) + def create_execution_context(self, *args, **kwargs): + return FBExecutionContext(self, *args, **kwargs) def type_descriptor(self, typeobj): return sqltypes.adapt_type(typeobj, colspecs) @@ -156,13 +141,13 @@ class FBDialect(ansisql.ANSIDialect): return FBCompiler(self, statement, bindparams, **kwargs) def schemagenerator(self, *args, **kwargs): - return FBSchemaGenerator(*args, **kwargs) + return FBSchemaGenerator(self, *args, **kwargs) def schemadropper(self, *args, **kwargs): - return FBSchemaDropper(*args, **kwargs) + return FBSchemaDropper(self, *args, **kwargs) - def defaultrunner(self, engine, proxy): - return FBDefaultRunner(engine, proxy) + def defaultrunner(self, connection): + return FBDefaultRunner(connection) def preparer(self): return FBIdentifierPreparer(self) @@ -292,9 +277,6 @@ class FBDialect(ansisql.ANSIDialect): for name,value in fks.iteritems(): table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], name=name)) - def last_inserted_ids(self): - return self.context.last_inserted_ids - def do_execute(self, cursor, statement, parameters, **kwargs): cursor.execute(statement, parameters or []) @@ -304,15 +286,6 @@ class FBDialect(ansisql.ANSIDialect): def do_commit(self, connection): connection.commit(True) - def connection(self): - """Returns a managed DBAPI connection from this SQLEngine's connection pool.""" - c = self._pool.connect() - c.supportsTransactions = 0 - return c - - def dbapi(self): - return self.module - class FBCompiler(ansisql.ANSICompiler): """Firebird specific idiosincrasies""" @@ -364,7 +337,7 @@ class FBCompiler(ansisql.ANSICompiler): class FBSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) - colspec += " " + column.type.engine_impl(self.engine).get_col_spec() + colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec() default = self.get_column_default_string(column) if default is not None: @@ -388,11 +361,11 @@ class FBSchemaDropper(ansisql.ANSISchemaDropper): class FBDefaultRunner(ansisql.ANSIDefaultRunner): def exec_default_sql(self, default): - c = sql.select([default.arg], from_obj=["rdb$database"], engine=self.engine).compile() - return self.proxy(str(c), c.get_params()).fetchone()[0] + c = sql.select([default.arg], from_obj=["rdb$database"]).compile(engine=self.engine) + return self.connection.execute_compiled(c).scalar() def visit_sequence(self, seq): - return self.proxy("SELECT gen_id(" + seq.name + ", 1) FROM rdb$database").fetchone()[0] + return self.connection.execute_text("SELECT gen_id(" + seq.name + ", 1) FROM rdb$database").scalar() RESERVED_WORDS = util.Set( diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index 1852edefb8..6d2ff66cd5 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -52,7 +52,22 @@ import sqlalchemy.ansisql as ansisql import sqlalchemy.types as sqltypes import sqlalchemy.exceptions as exceptions - +def dbapi(module_name=None): + if module_name: + try: + dialect_cls = dialect_mapping[module_name] + return dialect_cls.import_dbapi() + except KeyError: + raise exceptions.InvalidRequestError("Unsupported MSSQL module '%s' requested (must be adodbpi, pymssql or pyodbc)" % module_name) + else: + for dialect_cls in [MSSQLDialect_adodbapi, MSSQLDialect_pymssql, MSSQLDialect_pyodbc]: + try: + return dialect_cls.import_dbapi() + except ImportError, e: + pass + else: + raise ImportError('No DBAPI module detected for MSSQL - please install adodbapi, pymssql or pyodbc') + class MSNumeric(sqltypes.Numeric): def convert_result_value(self, value, dialect): return value @@ -142,9 +157,6 @@ class MSString(sqltypes.String): return "VARCHAR(%(length)s)" % {'length' : self.length} class MSNVarchar(MSString): - """NVARCHAR string, does Unicode conversion if `dialect.convert_encoding` is True. """ - impl = sqltypes.Unicode - def get_col_spec(self): if self.length: return "NVARCHAR(%(length)s)" % {'length' : self.length} @@ -154,19 +166,7 @@ class MSNVarchar(MSString): return "NTEXT" class AdoMSNVarchar(MSNVarchar): - def convert_bind_param(self, value, dialect): - return value - - def convert_result_value(self, value, dialect): - return value - -class MSUnicode(sqltypes.Unicode): - """Unicode subclass, does Unicode conversion in all cases, uses NVARCHAR impl.""" - impl = MSNVarchar - -class AdoMSUnicode(MSUnicode): - impl = AdoMSNVarchar - + """overrides bindparam/result processing to not convert any unicode strings""" def convert_bind_param(self, value, dialect): return value @@ -215,9 +215,9 @@ def descriptor(): ]} class MSSQLExecutionContext(default.DefaultExecutionContext): - def __init__(self, dialect): + def __init__(self, *args, **kwargs): self.IINSERT = self.HASIDENT = False - super(MSSQLExecutionContext, self).__init__(dialect) + super(MSSQLExecutionContext, self).__init__(*args, **kwargs) def _has_implicit_sequence(self, column): if column.primary_key and column.autoincrement: @@ -227,14 +227,14 @@ class MSSQLExecutionContext(default.DefaultExecutionContext): return True return False - def pre_exec(self, engine, proxy, compiled, parameters, **kwargs): + def pre_exec(self): """MS-SQL has a special mode for inserting non-NULL values into IDENTITY columns. Activate it if the feature is turned on and needed. """ - if getattr(compiled, "isinsert", False): - tbl = compiled.statement.table + if self.compiled.isinsert: + tbl = self.compiled.statement.table if not hasattr(tbl, 'has_sequence'): tbl.has_sequence = None for column in tbl.c: @@ -243,39 +243,43 @@ class MSSQLExecutionContext(default.DefaultExecutionContext): break self.HASIDENT = bool(tbl.has_sequence) - if engine.dialect.auto_identity_insert and self.HASIDENT: - if isinstance(parameters, list): - self.IINSERT = tbl.has_sequence.key in parameters[0] + if self.dialect.auto_identity_insert and self.HASIDENT: + if isinstance(self.compiled_parameters, list): + self.IINSERT = tbl.has_sequence.key in self.compiled_parameters[0] else: - self.IINSERT = tbl.has_sequence.key in parameters + self.IINSERT = tbl.has_sequence.key in self.compiled_parameters else: self.IINSERT = False if self.IINSERT: - proxy("SET IDENTITY_INSERT %s ON" % compiled.statement.table.name) + # TODO: quoting rules for table name here ? + self.cursor.execute("SET IDENTITY_INSERT %s ON" % self.compiled.statement.table.name) - super(MSSQLExecutionContext, self).pre_exec(engine, proxy, compiled, parameters, **kwargs) + super(MSSQLExecutionContext, self).pre_exec() - def post_exec(self, engine, proxy, compiled, parameters, **kwargs): + def post_exec(self): """Turn off the INDENTITY_INSERT mode if it's been activated, and fetch recently inserted IDENTIFY values (works only for one column). """ - if getattr(compiled, "isinsert", False): + if self.compiled.isinsert: if self.IINSERT: - proxy("SET IDENTITY_INSERT %s OFF" % compiled.statement.table.name) + # TODO: quoting rules for table name here ? + self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.compiled.statement.table.name) self.IINSERT = False elif self.HASIDENT: - cursor = proxy("SELECT @@IDENTITY AS lastrowid") - row = cursor.fetchone() + self.cursor.execute("SELECT @@IDENTITY AS lastrowid") + row = self.cursor.fetchone() self._last_inserted_ids = [int(row[0])] # print "LAST ROW ID", self._last_inserted_ids self.HASIDENT = False + super(MSSQLExecutionContext, self).post_exec() class MSSQLDialect(ansisql.ANSIDialect): colspecs = { + sqltypes.Unicode : MSNVarchar, sqltypes.Integer : MSInteger, sqltypes.Smallinteger: MSSmallInteger, sqltypes.Numeric : MSNumeric, @@ -283,7 +287,6 @@ class MSSQLDialect(ansisql.ANSIDialect): sqltypes.DateTime : MSDateTime, sqltypes.Date : MSDate, sqltypes.String : MSString, - sqltypes.Unicode : MSUnicode, sqltypes.Binary : MSBinary, sqltypes.Boolean : MSBoolean, sqltypes.TEXT : MSText, @@ -296,7 +299,7 @@ class MSSQLDialect(ansisql.ANSIDialect): 'smallint' : MSSmallInteger, 'tinyint' : MSTinyInteger, 'varchar' : MSString, - 'nvarchar' : MSUnicode, + 'nvarchar' : MSNVarchar, 'char' : MSChar, 'nchar' : MSNChar, 'text' : MSText, @@ -312,30 +315,16 @@ class MSSQLDialect(ansisql.ANSIDialect): 'image' : MSBinary } - def __new__(cls, module_name=None, *args, **kwargs): - module = kwargs.get('module', None) + def __new__(cls, dbapi=None, *args, **kwargs): if cls != MSSQLDialect: return super(MSSQLDialect, cls).__new__(cls, *args, **kwargs) - if module_name: - dialect = dialect_mapping.get(module_name) - if not dialect: - raise exceptions.InvalidRequestError('Unsupported MSSQL module requested (must be adodbpi, pymssql or pyodbc): ' + module_name) - if not hasattr(dialect, 'module'): - raise dialect.saved_import_error + if dbapi: + dialect = dialect_mapping.get(dbapi.__name__) return dialect(*args, **kwargs) - elif module: - return object.__new__(cls, *args, **kwargs) else: - for dialect in dialect_preference: - if hasattr(dialect, 'module'): - return dialect(*args, **kwargs) - #raise ImportError('No DBAPI module detected for MSSQL - please install adodbapi, pymssql or pyodbc') - else: - return object.__new__(cls, *args, **kwargs) + return object.__new__(cls, *args, **kwargs) - def __init__(self, module_name=None, module=None, auto_identity_insert=True, **params): - if not hasattr(self, 'module'): - self.module = module + def __init__(self, auto_identity_insert=True, **params): super(MSSQLDialect, self).__init__(**params) self.auto_identity_insert = auto_identity_insert self.text_as_varchar = False @@ -352,8 +341,8 @@ class MSSQLDialect(ansisql.ANSIDialect): self.text_as_varchar = bool(opts.pop('text_as_varchar')) return self.make_connect_string(opts) - def create_execution_context(self): - return MSSQLExecutionContext(self) + def create_execution_context(self, *args, **kwargs): + return MSSQLExecutionContext(self, *args, **kwargs) def type_descriptor(self, typeobj): newobj = sqltypes.adapt_type(typeobj, self.colspecs) @@ -373,13 +362,13 @@ class MSSQLDialect(ansisql.ANSIDialect): return MSSQLCompiler(self, statement, bindparams, **kwargs) def schemagenerator(self, *args, **kwargs): - return MSSQLSchemaGenerator(*args, **kwargs) + return MSSQLSchemaGenerator(self, *args, **kwargs) def schemadropper(self, *args, **kwargs): - return MSSQLSchemaDropper(*args, **kwargs) + return MSSQLSchemaDropper(self, *args, **kwargs) - def defaultrunner(self, engine, proxy): - return MSSQLDefaultRunner(engine, proxy) + def defaultrunner(self, connection, **kwargs): + return MSSQLDefaultRunner(connection, **kwargs) def preparer(self): return MSSQLIdentifierPreparer(self) @@ -411,19 +400,12 @@ class MSSQLDialect(ansisql.ANSIDialect): def raw_connection(self, connection): """Pull the raw pymmsql connection out--sensative to "pool.ConnectionFairy" and pymssql.pymssqlCnx Classes""" try: + # TODO: probably want to move this to individual dialect subclasses to + # save on the exception throw + simplify return connection.connection.__dict__['_pymssqlCnx__cnx'] except: return connection.connection.adoConn - def connection(self): - """returns a managed DBAPI connection from this SQLEngine's connection pool.""" - c = self._pool.connect() - c.supportsTransactions = 0 - return c - - def dbapi(self): - return self.module - def uppercase_table(self, t): # convert all names to uppercase -- fixes refs to INFORMATION_SCHEMA for case-senstive DBs, and won't matter for case-insensitive t.name = t.name.upper() @@ -558,13 +540,14 @@ class MSSQLDialect(ansisql.ANSIDialect): table.append_constraint(schema.ForeignKeyConstraint(scols, ['%s.%s' % (t,c) for (s,t,c) in rcols], fknm)) class MSSQLDialect_pymssql(MSSQLDialect): - try: + def import_dbapi(cls): import pymssql as module # pymmsql doesn't have a Binary method. we use string + # TODO: monkeypatching here is less than ideal module.Binary = lambda st: str(st) - except ImportError, e: - saved_import_error = e - + return module + import_dbapi = classmethod(import_dbapi) + def supports_sane_rowcount(self): return True @@ -578,7 +561,7 @@ class MSSQLDialect_pymssql(MSSQLDialect): def create_connect_args(self, url): r = super(MSSQLDialect_pymssql, self).create_connect_args(url) if hasattr(self, 'query_timeout'): - self.module._mssql.set_query_timeout(self.query_timeout) + self.dbapi._mssql.set_query_timeout(self.query_timeout) return r def make_connect_string(self, keys): @@ -621,15 +604,16 @@ class MSSQLDialect_pymssql(MSSQLDialect): ## r.fetch_array() class MSSQLDialect_pyodbc(MSSQLDialect): - try: + + def import_dbapi(cls): import pyodbc as module - except ImportError, e: - saved_import_error = e - + return module + import_dbapi = classmethod(import_dbapi) + colspecs = MSSQLDialect.colspecs.copy() - colspecs[sqltypes.Unicode] = AdoMSUnicode + colspecs[sqltypes.Unicode] = AdoMSNVarchar ischema_names = MSSQLDialect.ischema_names.copy() - ischema_names['nvarchar'] = AdoMSUnicode + ischema_names['nvarchar'] = AdoMSNVarchar def supports_sane_rowcount(self): return False @@ -648,15 +632,15 @@ class MSSQLDialect_pyodbc(MSSQLDialect): class MSSQLDialect_adodbapi(MSSQLDialect): - try: + def import_dbapi(cls): import adodbapi as module - except ImportError, e: - saved_import_error = e + return module + import_dbapi = classmethod(import_dbapi) colspecs = MSSQLDialect.colspecs.copy() - colspecs[sqltypes.Unicode] = AdoMSUnicode + colspecs[sqltypes.Unicode] = AdoMSNVarchar ischema_names = MSSQLDialect.ischema_names.copy() - ischema_names['nvarchar'] = AdoMSUnicode + ischema_names['nvarchar'] = AdoMSNVarchar def supports_sane_rowcount(self): return True @@ -676,13 +660,11 @@ class MSSQLDialect_adodbapi(MSSQLDialect): connectors.append("Integrated Security=SSPI") return [[";".join (connectors)], {}] - dialect_mapping = { 'pymssql': MSSQLDialect_pymssql, 'pyodbc': MSSQLDialect_pyodbc, 'adodbapi': MSSQLDialect_adodbapi } -dialect_preference = [MSSQLDialect_adodbapi, MSSQLDialect_pymssql, MSSQLDialect_pyodbc] class MSSQLCompiler(ansisql.ANSICompiler): @@ -770,7 +752,7 @@ class MSSQLCompiler(ansisql.ANSICompiler): class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, **kwargs): - colspec = self.preparer.format_column(column) + " " + column.type.engine_impl(self.engine).get_col_spec() + colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec() # install a IDENTITY Sequence if we have an implicit IDENTITY column if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \ @@ -797,6 +779,7 @@ class MSSQLSchemaDropper(ansisql.ANSISchemaDropper): self.execute() class MSSQLDefaultRunner(ansisql.ANSIDefaultRunner): + # TODO: does ms-sql have standalone sequences ? pass class MSSQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer): diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index 5fc63234a0..65ccb6af19 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -12,12 +12,9 @@ import sqlalchemy.types as sqltypes import sqlalchemy.exceptions as exceptions from array import array -try: +def dbapi(): import MySQLdb as mysql - import MySQLdb.constants.CLIENT as CLIENT_FLAGS -except: - mysql = None - CLIENT_FLAGS = None + return mysql def kw_colspec(self, spec): if self.unsigned: @@ -158,8 +155,6 @@ class MSLongText(MSText): return "LONGTEXT" class MSString(sqltypes.String): - def __init__(self, length=None, *extra, **kwargs): - sqltypes.String.__init__(self, length=length) def get_col_spec(self): return "VARCHAR(%(length)s)" % {'length' : self.length} @@ -277,16 +272,12 @@ def descriptor(): ]} class MySQLExecutionContext(default.DefaultExecutionContext): - def post_exec(self, engine, proxy, compiled, parameters, **kwargs): - if getattr(compiled, "isinsert", False): - self._last_inserted_ids = [proxy().lastrowid] + def post_exec(self): + if self.compiled.isinsert: + self._last_inserted_ids = [self.cursor.lastrowid] class MySQLDialect(ansisql.ANSIDialect): - def __init__(self, module = None, **kwargs): - if module is None: - self.module = mysql - else: - self.module = module + def __init__(self, **kwargs): ansisql.ANSIDialect.__init__(self, default_paramstyle='format', **kwargs) def create_connect_args(self, url): @@ -305,14 +296,18 @@ class MySQLDialect(ansisql.ANSIDialect): # TODO: what about options like "ssl", "cursorclass" and "conv" ? client_flag = opts.get('client_flag', 0) - if CLIENT_FLAGS is not None: - client_flag |= CLIENT_FLAGS.FOUND_ROWS + if self.dbapi is not None: + try: + import MySQLdb.constants.CLIENT as CLIENT_FLAGS + client_flag |= CLIENT_FLAGS.FOUND_ROWS + except: + pass opts['client_flag'] = client_flag return [[], opts] - def create_execution_context(self): - return MySQLExecutionContext(self) + def create_execution_context(self, *args, **kwargs): + return MySQLExecutionContext(self, *args, **kwargs) def type_descriptor(self, typeobj): return sqltypes.adapt_type(typeobj, colspecs) @@ -324,10 +319,10 @@ class MySQLDialect(ansisql.ANSIDialect): return MySQLCompiler(self, statement, bindparams, **kwargs) def schemagenerator(self, *args, **kwargs): - return MySQLSchemaGenerator(*args, **kwargs) + return MySQLSchemaGenerator(self, *args, **kwargs) def schemadropper(self, *args, **kwargs): - return MySQLSchemaDropper(*args, **kwargs) + return MySQLSchemaDropper(self, *args, **kwargs) def preparer(self): return MySQLIdentifierPreparer(self) @@ -337,14 +332,14 @@ class MySQLDialect(ansisql.ANSIDialect): rowcount = cursor.executemany(statement, parameters) if context is not None: context._rowcount = rowcount - except mysql.OperationalError, o: + except self.dbapi.OperationalError, o: if o.args[0] == 2006 or o.args[0] == 2014: cursor.invalidate() raise o def do_execute(self, cursor, statement, parameters, **kwargs): try: cursor.execute(statement, parameters) - except mysql.OperationalError, o: + except self.dbapi.OperationalError, o: if o.args[0] == 2006 or o.args[0] == 2014: cursor.invalidate() raise o @@ -361,11 +356,9 @@ class MySQLDialect(ansisql.ANSIDialect): self._default_schema_name = text("select database()", self).scalar() return self._default_schema_name - def dbapi(self): - return self.module - def has_table(self, connection, table_name, schema=None): - cursor = connection.execute("show table status like '" + table_name + "'") + cursor = connection.execute("show table status like %s", [table_name]) + print "CURSOR", cursor, "ROWCOUNT", cursor.rowcount, "REAL RC", cursor.cursor.rowcount return bool( not not cursor.rowcount ) def reflecttable(self, connection, table): @@ -492,8 +485,7 @@ class MySQLCompiler(ansisql.ANSICompiler): class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, override_pk=False, first_pk=False): - t = column.type.engine_impl(self.engine) - colspec = self.preparer.format_column(column) + " " + column.type.engine_impl(self.engine).get_col_spec() + colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec() default = self.get_column_default_string(column) if default is not None: colspec += " DEFAULT " + default diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index adea127bfe..5377759a2a 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -8,15 +8,13 @@ import sys, StringIO, string, re from sqlalchemy import util, sql, engine, schema, ansisql, exceptions, logging -import sqlalchemy.engine.default as default +from sqlalchemy.engine import default, base import sqlalchemy.types as sqltypes -try: +def dbapi(): import cx_Oracle -except: - cx_Oracle = None + return cx_Oracle -ORACLE_BINARY_TYPES = [getattr(cx_Oracle, k) for k in ["BFILE", "CLOB", "NCLOB", "BLOB", "LONG_BINARY", "LONG_STRING"] if hasattr(cx_Oracle, k)] class OracleNumeric(sqltypes.Numeric): def get_col_spec(self): @@ -149,26 +147,32 @@ def descriptor(): ]} class OracleExecutionContext(default.DefaultExecutionContext): - def pre_exec(self, engine, proxy, compiled, parameters): - super(OracleExecutionContext, self).pre_exec(engine, proxy, compiled, parameters) + def pre_exec(self): + super(OracleExecutionContext, self).pre_exec() if self.dialect.auto_setinputsizes: - self.set_input_sizes(proxy(), parameters) + self.set_input_sizes() + + def get_result_proxy(self): + if self.cursor.description is not None: + for column in self.cursor.description: + type_code = column[1] + if type_code in self.dialect.ORACLE_BINARY_TYPES: + return base.BufferedColumnResultProxy(self) + + return base.ResultProxy(self) class OracleDialect(ansisql.ANSIDialect): - def __init__(self, use_ansi=True, auto_setinputsizes=True, module=None, threaded=True, **kwargs): + def __init__(self, use_ansi=True, auto_setinputsizes=True, threaded=True, **kwargs): + ansisql.ANSIDialect.__init__(self, default_paramstyle='named', **kwargs) self.use_ansi = use_ansi self.threaded = threaded - if module is None: - self.module = cx_Oracle - else: - self.module = module - self.supports_timestamp = hasattr(self.module, 'TIMESTAMP' ) + self.supports_timestamp = self.dbapi is None or hasattr(self.dbapi, 'TIMESTAMP' ) self.auto_setinputsizes = auto_setinputsizes - ansisql.ANSIDialect.__init__(self, **kwargs) - - def dbapi(self): - return self.module - + if self.dbapi is not None: + self.ORACLE_BINARY_TYPES = [getattr(self.dbapi, k) for k in ["BFILE", "CLOB", "NCLOB", "BLOB", "LONG_BINARY", "LONG_STRING"] if hasattr(self.dbapi, k)] + else: + self.ORACLE_BINARY_TYPES = [] + def create_connect_args(self, url): if url.database: # if we have a database, then we have a remote host @@ -177,7 +181,7 @@ class OracleDialect(ansisql.ANSIDialect): port = int(port) else: port = 1521 - dsn = self.module.makedsn(url.host,port,url.database) + dsn = self.dbapi.makedsn(url.host,port,url.database) else: # we have a local tnsname dsn = url.host @@ -206,20 +210,20 @@ class OracleDialect(ansisql.ANSIDialect): else: return "rowid" - def create_execution_context(self): - return OracleExecutionContext(self) + def create_execution_context(self, *args, **kwargs): + return OracleExecutionContext(self, *args, **kwargs) def compiler(self, statement, bindparams, **kwargs): return OracleCompiler(self, statement, bindparams, **kwargs) def schemagenerator(self, *args, **kwargs): - return OracleSchemaGenerator(*args, **kwargs) + return OracleSchemaGenerator(self, *args, **kwargs) def schemadropper(self, *args, **kwargs): - return OracleSchemaDropper(*args, **kwargs) + return OracleSchemaDropper(self, *args, **kwargs) - def defaultrunner(self, engine, proxy): - return OracleDefaultRunner(engine, proxy) + def defaultrunner(self, connection, **kwargs): + return OracleDefaultRunner(connection, **kwargs) def has_table(self, connection, table_name, schema=None): cursor = connection.execute("""select table_name from all_tables where table_name=:name""", {'name':table_name.upper()}) @@ -405,15 +409,6 @@ class OracleDialect(ansisql.ANSIDialect): if context is not None: context._rowcount = rowcount - def create_result_proxy_args(self, connection, cursor): - args = super(OracleDialect, self).create_result_proxy_args(connection, cursor) - if cursor and cursor.description: - for column in cursor.description: - type_code = column[1] - if type_code in ORACLE_BINARY_TYPES: - args['should_prefetch'] = True - break - return args OracleDialect.logger = logging.class_logger(OracleDialect) @@ -569,7 +564,7 @@ class OracleCompiler(ansisql.ANSICompiler): class OracleSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) - colspec += " " + column.type.engine_impl(self.engine).get_col_spec() + colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec() default = self.get_column_default_string(column) if default is not None: colspec += " DEFAULT " + default @@ -579,22 +574,22 @@ class OracleSchemaGenerator(ansisql.ANSISchemaGenerator): return colspec def visit_sequence(self, sequence): - if not self.engine.dialect.has_sequence(self.connection, sequence.name): + if not self.dialect.has_sequence(self.connection, sequence.name): self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence)) self.execute() class OracleSchemaDropper(ansisql.ANSISchemaDropper): def visit_sequence(self, sequence): - if self.engine.dialect.has_sequence(self.connection, sequence.name): + if self.dialect.has_sequence(self.connection, sequence.name): self.append("DROP SEQUENCE %s" % sequence.name) self.execute() class OracleDefaultRunner(ansisql.ANSIDefaultRunner): def exec_default_sql(self, default): c = sql.select([default.arg], from_obj=["DUAL"], engine=self.engine).compile() - return self.proxy(str(c), c.get_params()).fetchone()[0] + return self.connection.execute_compiled(c).scalar() def visit_sequence(self, seq): - return self.proxy("SELECT " + seq.name + ".nextval FROM DUAL").fetchone()[0] + return self.connection.execute_text("SELECT " + seq.name + ".nextval FROM DUAL").scalar() dialect = OracleDialect diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index d83607793e..2943d163e5 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -4,33 +4,28 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import datetime, sys, StringIO, string, types, re - -import sqlalchemy.util as util -import sqlalchemy.sql as sql -import sqlalchemy.engine as engine -import sqlalchemy.engine.default as default -import sqlalchemy.schema as schema -import sqlalchemy.ansisql as ansisql +import datetime, string, types, re, random + +from sqlalchemy import util, sql, schema, ansisql, exceptions +from sqlalchemy.engine import base, default import sqlalchemy.types as sqltypes -import sqlalchemy.exceptions as exceptions from sqlalchemy.databases import information_schema as ischema -import re try: import mx.DateTime.DateTime as mxDateTime except: mxDateTime = None -try: - import psycopg2 as psycopg - #import psycopg2.psycopg1 as psycopg -except: +def dbapi(): try: - import psycopg - except: - psycopg = None - + import psycopg2 as psycopg + except ImportError, e: + try: + import psycopg + except ImportError, e2: + raise e + return psycopg + class PGInet(sqltypes.TypeEngine): def get_col_spec(self): return "INET" @@ -74,8 +69,8 @@ class PG1DateTime(sqltypes.DateTime): mx_datetime = mxDateTime(value.year, value.month, value.day, value.hour, value.minute, seconds) - return psycopg.TimestampFromMx(mx_datetime) - return psycopg.TimestampFromMx(value) + return dialect.dbapi.TimestampFromMx(mx_datetime) + return dialect.dbapi.TimestampFromMx(value) else: return None @@ -101,7 +96,7 @@ class PG1Date(sqltypes.Date): # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime # this one doesnt seem to work with the "emulation" mode if value is not None: - return psycopg.DateFromMx(value) + return dialect.dbapi.DateFromMx(value) else: return None @@ -219,44 +214,49 @@ def descriptor(): ]} class PGExecutionContext(default.DefaultExecutionContext): - def post_exec(self, engine, proxy, compiled, parameters, **kwargs): - if getattr(compiled, "isinsert", False) and self.last_inserted_ids is None: - if not engine.dialect.use_oids: + + def is_select(self): + return re.match(r'SELECT', self.statement.lstrip(), re.I) and not re.search(r'FOR UPDATE\s*$', self.statement, re.I) + + def create_cursor(self): + if self.dialect.server_side_cursors and self.is_select(): + # use server-side cursors: + # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html + ident = "c" + hex(random.randint(0, 65535))[2:] + return self.connection.connection.cursor(ident) + else: + return self.connection.connection.cursor() + + def get_result_proxy(self): + if self.dialect.server_side_cursors and self.is_select(): + return base.BufferedRowResultProxy(self) + else: + return base.ResultProxy(self) + + def post_exec(self): + if self.compiled.isinsert and self.last_inserted_ids is None: + if not self.dialect.use_oids: pass # will raise invalid error when they go to get them else: - table = compiled.statement.table - cursor = proxy() - if cursor.lastrowid is not None and table is not None and len(table.primary_key): - s = sql.select(table.primary_key, table.oid_column == cursor.lastrowid) - c = s.compile(engine=engine) - cursor = proxy(str(c), c.get_params()) - row = cursor.fetchone() + table = self.compiled.statement.table + if self.cursor.lastrowid is not None and table is not None and len(table.primary_key): + s = sql.select(table.primary_key, table.oid_column == self.cursor.lastrowid) + row = self.connection.execute(s).fetchone() self._last_inserted_ids = [v for v in row] - + super(PGExecutionContext, self).post_exec() + class PGDialect(ansisql.ANSIDialect): - def __init__(self, module=None, use_oids=False, use_information_schema=False, server_side_cursors=False, **params): + def __init__(self, use_oids=False, use_information_schema=False, server_side_cursors=False, **kwargs): + ansisql.ANSIDialect.__init__(self, default_paramstyle='pyformat', **kwargs) self.use_oids = use_oids self.server_side_cursors = server_side_cursors - if module is None: - #if psycopg is None: - # raise exceptions.ArgumentError("Couldnt locate psycopg1 or psycopg2: specify postgres module argument") - self.module = psycopg + if self.dbapi is None or not hasattr(self.dbapi, '__version__') or self.dbapi.__version__.startswith('2'): + self.version = 2 else: - self.module = module - # figure psycopg version 1 or 2 - try: - if self.module.__version__.startswith('2'): - self.version = 2 - else: - self.version = 1 - except: self.version = 1 - ansisql.ANSIDialect.__init__(self, **params) self.use_information_schema = use_information_schema - # produce consistent paramstyle even if psycopg2 module not present - if self.module is None: - self.paramstyle = 'pyformat' + self.paramstyle = 'pyformat' def create_connect_args(self, url): opts = url.translate_connect_args(['host', 'database', 'user', 'password', 'port']) @@ -268,16 +268,9 @@ class PGDialect(ansisql.ANSIDialect): opts.update(url.query) return ([], opts) - def create_cursor(self, connection): - if self.server_side_cursors: - # use server-side cursors: - # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html - return connection.cursor('x') - else: - return connection.cursor() - def create_execution_context(self): - return PGExecutionContext(self) + def create_execution_context(self, *args, **kwargs): + return PGExecutionContext(self, *args, **kwargs) def max_identifier_length(self): return 68 @@ -292,13 +285,13 @@ class PGDialect(ansisql.ANSIDialect): return PGCompiler(self, statement, bindparams, **kwargs) def schemagenerator(self, *args, **kwargs): - return PGSchemaGenerator(*args, **kwargs) + return PGSchemaGenerator(self, *args, **kwargs) def schemadropper(self, *args, **kwargs): - return PGSchemaDropper(*args, **kwargs) + return PGSchemaDropper(self, *args, **kwargs) - def defaultrunner(self, engine, proxy): - return PGDefaultRunner(engine, proxy) + def defaultrunner(self, connection, **kwargs): + return PGDefaultRunner(connection, **kwargs) def preparer(self): return PGIdentifierPreparer(self) @@ -326,7 +319,6 @@ class PGDialect(ansisql.ANSIDialect): ``psycopg2`` is not nice enough to produce this correctly for an executemany, so we do our own executemany here. """ - rowcount = 0 for param in parameters: c.execute(statement, param) @@ -334,9 +326,6 @@ class PGDialect(ansisql.ANSIDialect): if context is not None: context._rowcount = rowcount - def dbapi(self): - return self.module - def has_table(self, connection, table_name, schema=None): # seems like case gets folded in pg_class... if schema is None: @@ -542,7 +531,7 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator): else: colspec += " SERIAL" else: - colspec += " " + column.type.engine_impl(self.engine).get_col_spec() + colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec() default = self.get_column_default_string(column) if default is not None: colspec += " DEFAULT " + default @@ -567,8 +556,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner): if column.primary_key: # passive defaults on primary keys have to be overridden if isinstance(column.default, schema.PassiveDefault): - c = self.proxy("select %s" % column.default.arg) - return c.fetchone()[0] + return self.connection.execute_text("select %s" % column.default.arg).scalar() elif (isinstance(column.type, sqltypes.Integer) and column.autoincrement) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): sch = column.table.schema # TODO: this has to build into the Sequence object so we can get the quoting @@ -577,17 +565,13 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner): exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name) else: exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name) - c = self.proxy(exc) - return c.fetchone()[0] - else: - return ansisql.ANSIDefaultRunner.get_column_default(self, column) - else: - return ansisql.ANSIDefaultRunner.get_column_default(self, column) + return self.connection.execute_text(exc).scalar() + + return super(ansisql.ANSIDefaultRunner, self).get_column_default(column) def visit_sequence(self, seq): if not seq.optional: - c = self.proxy("select nextval('%s')" % seq.name) #TODO: self.dialect.preparer.format_sequence(seq)) - return c.fetchone()[0] + return self.connection.execute("select nextval('%s')" % self.dialect.identifier_preparer.format_sequence(seq)).scalar() else: return None diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py index b29be9eedd..9270f2a5ff 100644 --- a/lib/sqlalchemy/databases/sqlite.py +++ b/lib/sqlalchemy/databases/sqlite.py @@ -12,19 +12,19 @@ import sqlalchemy.engine.default as default import sqlalchemy.types as sqltypes import datetime,time -pysqlite2_timesupport = False # Change this if the init.d guys ever get around to supporting time cols - -try: - from pysqlite2 import dbapi2 as sqlite -except ImportError: +def dbapi(): try: - from sqlite3 import dbapi2 as sqlite #try the 2.5+ stdlib name. - except ImportError: + from pysqlite2 import dbapi2 as sqlite + except ImportError, e: try: - sqlite = __import__('sqlite') # skip ourselves - except: - sqlite = None - + from sqlite3 import dbapi2 as sqlite #try the 2.5+ stdlib name. + except ImportError: + try: + sqlite = __import__('sqlite') # skip ourselves + except ImportError: + raise e + return sqlite + class SLNumeric(sqltypes.Numeric): def get_col_spec(self): if self.precision is None: @@ -140,10 +140,6 @@ pragma_names = { 'BLOB' : SLBinary, } -if pysqlite2_timesupport: - colspecs.update({sqltypes.Time : SLTime}) - pragma_names.update({'TIME' : SLTime}) - def descriptor(): return {'name':'sqlite', 'description':'SQLite', @@ -152,25 +148,29 @@ def descriptor(): ]} class SQLiteExecutionContext(default.DefaultExecutionContext): - def post_exec(self, engine, proxy, compiled, parameters, **kwargs): - if getattr(compiled, "isinsert", False): - self._last_inserted_ids = [proxy().lastrowid] - + def post_exec(self): + if self.compiled.isinsert: + self._last_inserted_ids = [self.cursor.lastrowid] + super(SQLiteExecutionContext, self).post_exec() + class SQLiteDialect(ansisql.ANSIDialect): def __init__(self, **kwargs): + ansisql.ANSIDialect.__init__(self, default_paramstyle='qmark', **kwargs) def vers(num): return tuple([int(x) for x in num.split('.')]) - self.supports_cast = (sqlite is not None and vers(sqlite.sqlite_version) >= vers("3.2.3")) - ansisql.ANSIDialect.__init__(self, **kwargs) + self.supports_cast = (self.dbapi is None or vers(self.dbapi.sqlite_version) >= vers("3.2.3")) def compiler(self, statement, bindparams, **kwargs): return SQLiteCompiler(self, statement, bindparams, **kwargs) def schemagenerator(self, *args, **kwargs): - return SQLiteSchemaGenerator(*args, **kwargs) + return SQLiteSchemaGenerator(self, *args, **kwargs) def schemadropper(self, *args, **kwargs): - return SQLiteSchemaDropper(*args, **kwargs) + return SQLiteSchemaDropper(self, *args, **kwargs) + + def supports_alter(self): + return False def preparer(self): return SQLiteIdentifierPreparer(self) @@ -182,8 +182,8 @@ class SQLiteDialect(ansisql.ANSIDialect): def type_descriptor(self, typeobj): return sqltypes.adapt_type(typeobj, colspecs) - def create_execution_context(self): - return SQLiteExecutionContext(self) + def create_execution_context(self, **kwargs): + return SQLiteExecutionContext(self, **kwargs) def last_inserted_ids(self): return self.context.last_inserted_ids @@ -191,9 +191,6 @@ class SQLiteDialect(ansisql.ANSIDialect): def oid_column_name(self, column): return "oid" - def dbapi(self): - return sqlite - def has_table(self, connection, table_name, schema=None): cursor = connection.execute("PRAGMA table_info(" + table_name + ")", {}) row = cursor.fetchone() @@ -321,11 +318,9 @@ class SQLiteCompiler(ansisql.ANSICompiler): return ansisql.ANSICompiler.binary_operator_string(self, binary) class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator): - def supports_alter(self): - return False def get_column_specification(self, column, **kwargs): - colspec = self.preparer.format_column(column) + " " + column.type.engine_impl(self.engine).get_col_spec() + colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec() default = self.get_column_default_string(column) if default is not None: colspec += " DEFAULT " + default @@ -345,8 +340,7 @@ class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator): # super(SQLiteSchemaGenerator, self).visit_primary_key_constraint(constraint) class SQLiteSchemaDropper(ansisql.ANSISchemaDropper): - def supports_alter(self): - return False + pass class SQLiteIdentifierPreparer(ansisql.ANSIIdentifierPreparer): def __init__(self, dialect): diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 0baaeb8268..d8a9c52998 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -83,7 +83,7 @@ class Dialect(sql.AbstractDialect): raise NotImplementedError() def type_descriptor(self, typeobj): - """Trasform the type from generic to database-specific. + """Transform the type from generic to database-specific. Provides a database-specific TypeEngine object, given the generic object which comes from the types module. Subclasses @@ -105,6 +105,10 @@ class Dialect(sql.AbstractDialect): raise NotImplementedError() + def supports_alter(self): + """return True if the database supports ALTER TABLE.""" + raise NotImplementedError() + def max_identifier_length(self): """Return the maximum length of identifier names. @@ -118,32 +122,43 @@ class Dialect(sql.AbstractDialect): def supports_sane_rowcount(self): """Indicate whether the dialect properly implements statements rowcount. - Provided to indicate when MySQL is being used, which does not - have standard behavior for the "rowcount" function on a statement handle. + This was needed for MySQL which had non-standard behavior of rowcount, + but this issue has since been resolved. """ raise NotImplementedError() - def schemagenerator(self, engine, proxy, **params): + def schemagenerator(self, connection, **kwargs): """Return a ``schema.SchemaVisitor`` instance that can generate schemas. + connection + a Connection to use for statement execution + `schemagenerator()` is called via the `create()` method on Table, Index, and others. """ raise NotImplementedError() - def schemadropper(self, engine, proxy, **params): + def schemadropper(self, connection, **kwargs): """Return a ``schema.SchemaVisitor`` instance that can drop schemas. + connection + a Connection to use for statement execution + `schemadropper()` is called via the `drop()` method on Table, Index, and others. """ raise NotImplementedError() - def defaultrunner(self, engine, proxy, **params): - """Return a ``schema.SchemaVisitor`` instance that can execute defaults.""" + def defaultrunner(self, connection, **kwargs): + """Return a ``schema.SchemaVisitor`` instance that can execute defaults. + + connection + a Connection to use for statement execution + + """ raise NotImplementedError() @@ -154,7 +169,6 @@ class Dialect(sql.AbstractDialect): ansisql.ANSICompiler, and will produce a string representation of the given ClauseElement and `parameters` dictionary. - `compiler()` is called within the context of the compile() method. """ raise NotImplementedError() @@ -188,23 +202,13 @@ class Dialect(sql.AbstractDialect): raise NotImplementedError() - def dbapi(self): - """Establish a connection to the database. - - Subclasses override this method to provide the DBAPI module - used to establish connections. - """ - - raise NotImplementedError() - def get_default_schema_name(self, connection): """Return the currently selected schema given a connection""" raise NotImplementedError() - def execution_context(self): + def create_execution_context(self, connection, compiled=None, compiled_parameters=None, statement=None, parameters=None): """Return a new ExecutionContext object.""" - raise NotImplementedError() def do_begin(self, connection): @@ -232,15 +236,6 @@ class Dialect(sql.AbstractDialect): raise NotImplementedError() - def create_cursor(self, connection): - """Return a new cursor generated from the given connection.""" - - raise NotImplementedError() - - def create_result_proxy_args(self, connection, cursor): - """Return a dictionary of arguments that should be passed to ResultProxy().""" - - raise NotImplementedError() def compile(self, clauseelement, parameters=None): """Compile the given ClauseElement using this Dialect. @@ -255,42 +250,74 @@ class Dialect(sql.AbstractDialect): class ExecutionContext(object): """A messenger object for a Dialect that corresponds to a single execution. + ExecutionContext should have these datamembers: + + connection + Connection object which initiated the call to the + dialect to create this ExecutionContext. + + dialect + dialect which created this ExecutionContext. + + cursor + DBAPI cursor procured from the connection + + compiled + if passed to constructor, sql.Compiled object being executed + + compiled_parameters + if passed to constructor, sql.ClauseParameters object + + statement + string version of the statement to be executed. Is either + passed to the constructor, or must be created from the + sql.Compiled object by the time pre_exec() has completed. + + parameters + "raw" parameters suitable for direct execution by the + dialect. Either passed to the constructor, or must be + created from the sql.ClauseParameters object by the time + pre_exec() has completed. + + The Dialect should provide an ExecutionContext via the create_execution_context() method. The `pre_exec` and `post_exec` - methods will be called for compiled statements, afterwhich it is - expected that the various methods `last_inserted_ids`, - `last_inserted_params`, etc. will contain appropriate values, if - applicable. + methods will be called for compiled statements. + """ - def pre_exec(self, engine, proxy, compiled, parameters): - """Called before an execution of a compiled statement. + def create_cursor(self): + """Return a new cursor generated this ExecutionContext's connection.""" - `proxy` is a callable that takes a string statement and a bind - parameter list/dictionary. + raise NotImplementedError() + + def pre_exec(self): + """Called before an execution of a compiled statement. + + If compiled and compiled_parameters were passed to this + ExecutionContext, the `statement` and `parameters` datamembers + must be initialized after this statement is complete. """ raise NotImplementedError() - def post_exec(self, engine, proxy, compiled, parameters): + def post_exec(self): """Called after the execution of a compiled statement. - - `proxy` is a callable that takes a string statement and a bind - parameter list/dictionary. + + If compiled was passed to this ExecutionContext, + the `last_insert_ids`, `last_inserted_params`, etc. + datamembers should be available after this method + completes. """ raise NotImplementedError() - - def get_rowcount(self, cursor): - """Return the count of rows updated/deleted for an UPDATE/DELETE statement.""" - + + def get_result_proxy(self): + """return a ResultProxy corresponding to this ExecutionContext.""" raise NotImplementedError() - - def supports_sane_rowcount(self): - """Indicate if the "rowcount" DBAPI cursor function works properly. - - Currently, MySQLDB does not properly implement this function. - """ + + def get_rowcount(self): + """Return the count of rows updated/deleted for an UPDATE/DELETE statement.""" raise NotImplementedError() @@ -299,7 +326,7 @@ class ExecutionContext(object): This does not apply to straight textual clauses; only to ``sql.Insert`` objects compiled against a ``schema.Table`` object, - which are executed via `statement.execute()`. The order of + which are executed via `execute()`. The order of items in the list is the same as that of the Table's 'primary_key' attribute. @@ -337,7 +364,7 @@ class ExecutionContext(object): raise NotImplementedError() -class Connectable(object): +class Connectable(sql.Executor): """Interface for an object that can provide an Engine and a Connection object which correponds to that Engine.""" def contextual_connect(self): @@ -362,6 +389,7 @@ class Connectable(object): raise NotImplementedError() engine = property(_not_impl, doc="The Engine which this Connectable is associated with.") + dialect = property(_not_impl, doc="Dialect which this Connectable is associated with.") class Connection(Connectable): """Represent a single DBAPI connection returned from the underlying connection pool. @@ -385,7 +413,8 @@ class Connection(Connectable): except AttributeError: raise exceptions.InvalidRequestError("This Connection is closed") - engine = property(lambda s:s.__engine, doc="The Engine with which this Connection is associated (read only)") + engine = property(lambda s:s.__engine, doc="The Engine with which this Connection is associated.") + dialect = property(lambda s:s.__engine.dialect, doc="Dialect used by this Connection.") connection = property(_get_connection, doc="The underlying DBAPI connection managed by this Connection.") should_close_with_result = property(lambda s:s.__close_with_result, doc="Indicates if this Connection should be closed when a corresponding ResultProxy is closed; this is essentially an auto-release mode.") @@ -429,7 +458,7 @@ class Connection(Connectable): """When no Transaction is present, this is called after executions to provide "autocommit" behavior.""" # TODO: have the dialect determine if autocommit can be set on the connection directly without this # extra step - if not self.in_transaction() and re.match(r'UPDATE|INSERT|CREATE|DELETE|DROP|ALTER', statement.lstrip().upper()): + if not self.in_transaction() and re.match(r'UPDATE|INSERT|CREATE|DELETE|DROP|ALTER', statement.lstrip(), re.I): self._commit_impl() def _autorollback(self): @@ -448,6 +477,9 @@ class Connection(Connectable): def scalar(self, object, *multiparams, **params): return self.execute(object, *multiparams, **params).scalar() + def compiler(self, statement, parameters, **kwargs): + return self.dialect.compiler(statement, parameters, engine=self.engine, **kwargs) + def execute(self, object, *multiparams, **params): for c in type(object).__mro__: if c in Connection.executors: @@ -456,7 +488,7 @@ class Connection(Connectable): raise exceptions.InvalidRequestError("Unexecuteable object type: " + str(type(object))) def execute_default(self, default, **kwargs): - return default.accept_visitor(self.__engine.dialect.defaultrunner(self.__engine, self.proxy, **kwargs)) + return default.accept_visitor(self.__engine.dialect.defaultrunner(self)) def execute_text(self, statement, *multiparams, **params): if len(multiparams) == 0: @@ -465,9 +497,9 @@ class Connection(Connectable): parameters = multiparams[0] else: parameters = list(multiparams) - cursor = self._execute_raw(statement, parameters) - rpargs = self.__engine.dialect.create_result_proxy_args(self, cursor) - return ResultProxy(self.__engine, self, cursor, **rpargs) + context = self._create_execution_context(statement=statement, parameters=parameters) + self._execute_raw(context) + return context.get_result_proxy() def _params_to_listofdicts(self, *multiparams, **params): if len(multiparams) == 0: @@ -491,29 +523,57 @@ class Connection(Connectable): param = multiparams[0] else: param = params - return self.execute_compiled(elem.compile(engine=self.__engine, parameters=param), *multiparams, **params) + return self.execute_compiled(elem.compile(dialect=self.dialect, parameters=param), *multiparams, **params) def execute_compiled(self, compiled, *multiparams, **params): """Execute a sql.Compiled object.""" if not compiled.can_execute: raise exceptions.ArgumentError("Not an executeable clause: %s" % (str(compiled))) - cursor = self.__engine.dialect.create_cursor(self.connection) parameters = [compiled.construct_params(m) for m in self._params_to_listofdicts(*multiparams, **params)] if len(parameters) == 1: parameters = parameters[0] - def proxy(statement=None, parameters=None): - if statement is None: - return cursor - - parameters = self.__engine.dialect.convert_compiled_params(parameters) - self._execute_raw(statement, parameters, cursor=cursor, context=context) - return cursor - context = self.__engine.dialect.create_execution_context() - context.pre_exec(self.__engine, proxy, compiled, parameters) - proxy(unicode(compiled), parameters) - context.post_exec(self.__engine, proxy, compiled, parameters) - rpargs = self.__engine.dialect.create_result_proxy_args(self, cursor) - return ResultProxy(self.__engine, self, cursor, context, typemap=compiled.typemap, column_labels=compiled.column_labels, **rpargs) + context = self._create_execution_context(compiled=compiled, compiled_parameters=parameters) + context.pre_exec() + self._execute_raw(context) + context.post_exec() + return context.get_result_proxy() + + def _create_execution_context(self, **kwargs): + return self.__engine.dialect.create_execution_context(connection=self, **kwargs) + + def _execute_raw(self, context): + self.__engine.logger.info(context.statement) + self.__engine.logger.info(repr(context.parameters)) + if context.parameters is not None and isinstance(context.parameters, list) and len(context.parameters) > 0 and (isinstance(context.parameters[0], list) or isinstance(context.parameters[0], dict)): + self._executemany(context) + else: + self._execute(context) + self._autocommit(context.statement) + + def _execute(self, context): + if context.parameters is None: + if context.dialect.positional: + context.parameters = () + else: + context.parameters = {} + try: + context.dialect.do_execute(context.cursor, context.statement, context.parameters, context=context) + except Exception, e: + self._autorollback() + #self._rollback_impl() + if self.__close_with_result: + self.close() + raise exceptions.SQLError(context.statement, context.parameters, e) + + def _executemany(self, context): + try: + context.dialect.do_executemany(context.cursor, context.statement, context.parameters, context=context) + except Exception, e: + self._autorollback() + #self._rollback_impl() + if self.__close_with_result: + self.close() + raise exceptions.SQLError(context.statement, context.parameters, e) # poor man's multimethod/generic function thingy executors = { @@ -525,17 +585,17 @@ class Connection(Connectable): } def create(self, entity, **kwargs): - """Create a table or index given an appropriate schema object.""" + """Create a Table or Index given an appropriate Schema object.""" return self.__engine.create(entity, connection=self, **kwargs) def drop(self, entity, **kwargs): - """Drop a table or index given an appropriate schema object.""" + """Drop a Table or Index given an appropriate Schema object.""" return self.__engine.drop(entity, connection=self, **kwargs) def reflecttable(self, table, **kwargs): - """Reflect the columns in the given table from the database.""" + """Reflect the columns in the given string table name from the database.""" return self.__engine.reflecttable(table, connection=self, **kwargs) @@ -545,59 +605,6 @@ class Connection(Connectable): def run_callable(self, callable_): return callable_(self) - def _execute_raw(self, statement, parameters=None, cursor=None, context=None, **kwargs): - if cursor is None: - cursor = self.__engine.dialect.create_cursor(self.connection) - if not self.__engine.dialect.supports_unicode_statements(): - # encode to ascii, with full error handling - statement = statement.encode('ascii') - self.__engine.logger.info(statement) - self.__engine.logger.info(repr(parameters)) - if parameters is not None and isinstance(parameters, list) and len(parameters) > 0 and (isinstance(parameters[0], list) or isinstance(parameters[0], dict)): - self._executemany(cursor, statement, parameters, context=context) - else: - self._execute(cursor, statement, parameters, context=context) - self._autocommit(statement) - return cursor - - def _execute(self, c, statement, parameters, context=None): - if parameters is None: - if self.__engine.dialect.positional: - parameters = () - else: - parameters = {} - try: - self.__engine.dialect.do_execute(c, statement, parameters, context=context) - except Exception, e: - self._autorollback() - #self._rollback_impl() - if self.__close_with_result: - self.close() - raise exceptions.SQLError(statement, parameters, e) - - def _executemany(self, c, statement, parameters, context=None): - try: - self.__engine.dialect.do_executemany(c, statement, parameters, context=context) - except Exception, e: - self._autorollback() - #self._rollback_impl() - if self.__close_with_result: - self.close() - raise exceptions.SQLError(statement, parameters, e) - - def proxy(self, statement=None, parameters=None): - """Execute the given statement string and parameter object. - - The parameter object is expected to be the result of a call to - ``compiled.get_params()``. This callable is a generic version - of a connection/cursor-specific callable that is produced - within the execute_compiled method, and is used for objects - that require this style of proxy when outside of an - execute_compiled method, primarily the DefaultRunner. - """ - parameters = self.__engine.dialect.convert_compiled_params(parameters) - return self._execute_raw(statement, parameters) - class Transaction(object): """Represent a Transaction in progress. @@ -630,7 +637,7 @@ class Transaction(object): self.__connection._commit_impl() self.__is_active = False -class Engine(sql.Executor, Connectable): +class Engine(Connectable): """ Connects a ConnectionProvider, a Dialect and a CompilerFactory together to provide a default implementation of SchemaEngine. @@ -638,12 +645,13 @@ class Engine(sql.Executor, Connectable): def __init__(self, connection_provider, dialect, echo=None): self.connection_provider = connection_provider - self.dialect=dialect + self._dialect=dialect self.echo = echo self.logger = logging.instance_logger(self) name = property(lambda s:sys.modules[s.dialect.__module__].descriptor()['name']) engine = property(lambda s:s) + dialect = property(lambda s:s._dialect) echo = logging.echo_property() def dispose(self): @@ -678,11 +686,11 @@ class Engine(sql.Executor, Connectable): def _run_visitor(self, visitorcallable, element, connection=None, **kwargs): if connection is None: - conn = self.contextual_connect() + conn = self.contextual_connect(close_with_result=False) else: conn = connection try: - element.accept_visitor(visitorcallable(self, conn.proxy, connection=conn, **kwargs)) + element.accept_visitor(visitorcallable(conn, **kwargs)) finally: if connection is None: conn.close() @@ -807,55 +815,39 @@ class ResultProxy(object): def convert_result_value(self, arg, engine): raise exceptions.InvalidRequestError("Ambiguous column name '%s' in result set! try 'use_labels' option on select statement." % (self.key)) - def __new__(cls, *args, **kwargs): - if cls is ResultProxy and kwargs.has_key('should_prefetch') and kwargs['should_prefetch']: - return PrefetchingResultProxy(*args, **kwargs) - else: - return object.__new__(cls, *args, **kwargs) - - def __init__(self, engine, connection, cursor, executioncontext=None, typemap=None, column_labels=None, should_prefetch=None): + def __init__(self, context): """ResultProxy objects are constructed via the execute() method on SQLEngine.""" - - self.connection = connection - self.dialect = engine.dialect - self.cursor = cursor - self.engine = engine + self.context = context self.closed = False - self.column_labels = column_labels - if executioncontext is not None: - self.__executioncontext = executioncontext - self.rowcount = executioncontext.get_rowcount(cursor) - else: - self.rowcount = cursor.rowcount - self.__key_cache = {} - self.__echo = engine.echo == 'debug' - metadata = cursor.description - self.props = {} - self.keys = [] - i = 0 + self.cursor = context.cursor + self.__echo = logging.is_debug_enabled(context.engine.logger) + self._init_metadata() + dialect = property(lambda s:s.context.dialect) + rowcount = property(lambda s:s.context.get_rowcount()) + connection = property(lambda s:s.context.connection) + + def _init_metadata(self): + if hasattr(self, '_ResultProxy__props'): + return + self.__key_cache = {} + self.__props = {} + self.__keys = [] + metadata = self.cursor.description if metadata is not None: - for item in metadata: + for i, item in enumerate(metadata): # sqlite possibly prepending table name to colnames so strip - colname = item[0].split('.')[-1].lower() - if typemap is not None: - rec = (typemap.get(colname, types.NULLTYPE), i) + colname = item[0].split('.')[-1] + if self.context.typemap is not None: + rec = (self.context.typemap.get(colname.lower(), types.NULLTYPE), i) else: rec = (types.NULLTYPE, i) if rec[0] is None: raise DBAPIError("None for metadata " + colname) - if self.props.setdefault(colname, rec) is not rec: - self.props[colname] = (ResultProxy.AmbiguousColumn(colname), 0) - self.keys.append(colname) - self.props[i] = rec - i+=1 - - def _executioncontext(self): - try: - return self.__executioncontext - except AttributeError: - raise exceptions.InvalidRequestError("This ResultProxy does not have an execution context with which to complete this operation. Execution contexts are not generated for literal SQL execution.") - executioncontext = property(_executioncontext) + if self.__props.setdefault(colname.lower(), rec) is not rec: + self.__props[colname.lower()] = (ResultProxy.AmbiguousColumn(colname), 0) + self.__keys.append(colname) + self.__props[i] = rec def close(self): """Close this ResultProxy, and the underlying DBAPI cursor corresponding to the execution. @@ -867,13 +859,12 @@ class ResultProxy(object): This method is also called automatically when all result rows are exhausted. """ - if not self.closed: self.closed = True self.cursor.close() if self.connection.should_close_with_result and self.dialect.supports_autoclose_results: self.connection.close() - + def _convert_key(self, key): """Convert and cache a key. @@ -882,25 +873,26 @@ class ResultProxy(object): metadata; then cache it locally for quick re-access. """ - try: + if key in self.__key_cache: return self.__key_cache[key] - except KeyError: - if isinstance(key, int) and key in self.props: - rec = self.props[key] - elif isinstance(key, basestring) and key.lower() in self.props: - rec = self.props[key.lower()] + else: + if isinstance(key, int) and key in self.__props: + rec = self.__props[key] + elif isinstance(key, basestring) and key.lower() in self.__props: + rec = self.__props[key.lower()] elif isinstance(key, sql.ColumnElement): - label = self.column_labels.get(key._label, key.name).lower() - if label in self.props: - rec = self.props[label] + label = self.context.column_labels.get(key._label, key.name).lower() + if label in self.__props: + rec = self.__props[label] if not "rec" in locals(): raise exceptions.NoSuchColumnError("Could not locate column in row for column '%s'" % (repr(key))) self.__key_cache[key] = rec return rec - - + + keys = property(lambda s:s.__keys) + def _has_key(self, row, key): try: self._convert_key(key) @@ -908,10 +900,6 @@ class ResultProxy(object): except KeyError: return False - def _get_col(self, row, key): - rec = self._convert_key(key) - return rec[0].dialect_impl(self.dialect).convert_result_value(row[rec[1]], self.dialect) - def __iter__(self): while True: row = self.fetchone() @@ -926,7 +914,7 @@ class ResultProxy(object): See ExecutionContext for details. """ - return self.executioncontext.last_inserted_ids() + return self.context.last_inserted_ids() def last_updated_params(self): """Return ``last_updated_params()`` from the underlying ExecutionContext. @@ -934,7 +922,7 @@ class ResultProxy(object): See ExecutionContext for details. """ - return self.executioncontext.last_updated_params() + return self.context.last_updated_params() def last_inserted_params(self): """Return ``last_inserted_params()`` from the underlying ExecutionContext. @@ -942,7 +930,7 @@ class ResultProxy(object): See ExecutionContext for details. """ - return self.executioncontext.last_inserted_params() + return self.context.last_inserted_params() def lastrow_has_defaults(self): """Return ``lastrow_has_defaults()`` from the underlying ExecutionContext. @@ -950,7 +938,7 @@ class ResultProxy(object): See ExecutionContext for details. """ - return self.executioncontext.lastrow_has_defaults() + return self.context.lastrow_has_defaults() def supports_sane_rowcount(self): """Return ``supports_sane_rowcount()`` from the underlying ExecutionContext. @@ -958,71 +946,122 @@ class ResultProxy(object): See ExecutionContext for details. """ - return self.executioncontext.supports_sane_rowcount() + return self.context.supports_sane_rowcount() + def _get_col(self, row, key): + rec = self._convert_key(key) + return rec[0].dialect_impl(self.dialect).convert_result_value(row[rec[1]], self.dialect) + + def _fetchone_impl(self): + return self.cursor.fetchone() + def _fetchmany_impl(self, size=None): + return self.cursor.fetchmany(size) + def _fetchall_impl(self): + return self.cursor.fetchall() + + def _process_row(self, row): + return RowProxy(self, row) + def fetchall(self): """Fetch all rows, just like DBAPI ``cursor.fetchall()``.""" - l = [] - for row in self.cursor.fetchall(): - l.append(RowProxy(self, row)) + l = [self._process_row(row) for row in self._fetchall_impl()] self.close() return l def fetchmany(self, size=None): """Fetch many rows, just like DBAPI ``cursor.fetchmany(size=cursor.arraysize)``.""" - if size is None: - rows = self.cursor.fetchmany() - else: - rows = self.cursor.fetchmany(size) - l = [] - for row in rows: - l.append(RowProxy(self, row)) + l = [self._process_row(row) for row in self._fetchmany_impl(size)] if len(l) == 0: self.close() return l def fetchone(self): """Fetch one row, just like DBAPI ``cursor.fetchone()``.""" - - row = self.cursor.fetchone() + row = self._fetchone_impl() if row is not None: - return RowProxy(self, row) + return self._process_row(row) else: self.close() return None def scalar(self): """Fetch the first column of the first row, and close the result set.""" - - row = self.cursor.fetchone() + row = self._fetchone_impl() try: if row is not None: - return RowProxy(self, row)[0] + return self._process_row(row)[0] else: return None finally: self.close() -class PrefetchingResultProxy(ResultProxy): +class BufferedRowResultProxy(ResultProxy): + def _init_metadata(self): + self.__buffer_rows() + super(BufferedRowResultProxy, self)._init_metadata() + + # this is a "growth chart" for the buffering of rows. + # each successive __buffer_rows call will use the next + # value in the list for the buffer size until the max + # is reached + size_growth = { + 1 : 5, + 5 : 10, + 10 : 20, + 20 : 50, + 50 : 100 + } + + def __buffer_rows(self): + size = getattr(self, '_bufsize', 1) + self.__rowbuffer = self.cursor.fetchmany(size) + #self.context.engine.logger.debug("Buffered %d rows" % size) + self._bufsize = self.size_growth.get(size, size) + + def _fetchone_impl(self): + if self.closed: + return None + if len(self.__rowbuffer) == 0: + self.__buffer_rows() + if len(self.__rowbuffer) == 0: + return None + return self.__rowbuffer.pop(0) + + def _fetchmany_impl(self, size=None): + result = [] + for x in range(0, size): + row = self._fetchone_impl() + if row is None: + break + result.append(row) + return result + + def _fetchall_impl(self): + return self.__rowbuffer + list(self.cursor.fetchall()) + +class BufferedColumnResultProxy(ResultProxy): """ResultProxy that loads all columns into memory each time fetchone() is called. If fetchmany() or fetchall() are called, the full grid of results is fetched. """ - def _get_col(self, row, key): rec = self._convert_key(key) return row[rec[1]] + + def _process_row(self, row): + sup = super(BufferedColumnResultProxy, self) + row = [sup._get_col(row, i) for i in xrange(len(row))] + return RowProxy(self, row) def fetchall(self): l = [] while True: row = self.fetchone() - if row is not None: - l.append(row) - else: + if row is None: break + l.append(row) return l def fetchmany(self, size=None): @@ -1031,24 +1070,13 @@ class PrefetchingResultProxy(ResultProxy): l = [] for i in xrange(size): row = self.fetchone() - if row is not None: - l.append(row) - else: + if row is None: break + l.append(row) return l - def fetchone(self): - sup = super(PrefetchingResultProxy, self) - row = self.cursor.fetchone() - if row is not None: - row = [sup._get_col(row, i) for i in xrange(len(row))] - return RowProxy(self, row) - else: - self.close() - return None - class RowProxy(object): - """Proxie a single cursor row for a parent ResultProxy. + """Proxy a single cursor row for a parent ResultProxy. Mostly follows "ordered dictionary" behavior, mapping result values to the string-based column name, the integer position of @@ -1063,7 +1091,7 @@ class RowProxy(object): self.__parent = parent self.__row = row if self.__parent._ResultProxy__echo: - self.__parent.engine.logger.debug("Row " + repr(row)) + self.__parent.context.engine.logger.debug("Row " + repr(row)) def close(self): """Close the parent ResultProxy.""" @@ -1115,20 +1143,10 @@ class RowProxy(object): class SchemaIterator(schema.SchemaVisitor): """A visitor that can gather text into a buffer and execute the contents of the buffer.""" - def __init__(self, engine, proxy, **params): + def __init__(self, connection): """Construct a new SchemaIterator. - - engine - the Engine used by this SchemaIterator - - proxy - a callable which takes a statement and bind parameters and - executes it, returning the cursor (the actual DBAPI cursor). - The callable should use the same cursor repeatedly. """ - - self.proxy = proxy - self.engine = engine + self.connection = connection self.buffer = StringIO.StringIO() def append(self, s): @@ -1140,7 +1158,7 @@ class SchemaIterator(schema.SchemaVisitor): """Execute the contents of the SchemaIterator's buffer.""" try: - return self.proxy(self.buffer.getvalue(), None) + return self.connection.execute(self.buffer.getvalue()) finally: self.buffer.truncate(0) @@ -1154,10 +1172,10 @@ class DefaultRunner(schema.SchemaVisitor): DefaultRunner to allow database-specific behavior. """ - def __init__(self, engine, proxy): - self.proxy = proxy - self.engine = engine - + def __init__(self, connection): + self.connection = connection + self.dialect = connection.dialect + def get_column_default(self, column): if column.default is not None: return column.default.accept_visitor(self) @@ -1188,8 +1206,8 @@ class DefaultRunner(schema.SchemaVisitor): return None def exec_default_sql(self, default): - c = sql.select([default.arg], engine=self.engine).compile() - return self.proxy(str(c), c.get_params()).fetchone()[0] + c = sql.select([default.arg]).compile(engine=self.connection) + return self.connection.execute_compiled(c).scalar() def visit_column_onupdate(self, onupdate): if isinstance(onupdate.arg, sql.ClauseElement): diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 86563cd7cb..ceecee364f 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -26,16 +26,17 @@ class PoolConnectionProvider(base.ConnectionProvider): class DefaultDialect(base.Dialect): """Default implementation of Dialect""" - def __init__(self, convert_unicode=False, encoding='utf-8', default_paramstyle='named', **kwargs): + def __init__(self, convert_unicode=False, encoding='utf-8', default_paramstyle='named', paramstyle=None, dbapi=None, **kwargs): self.convert_unicode = convert_unicode self.supports_autoclose_results = True self.encoding = encoding self.positional = False self._ischema = None - self._figure_paramstyle(default=default_paramstyle) + self.dbapi = dbapi + self._figure_paramstyle(paramstyle=paramstyle, default=default_paramstyle) - def create_execution_context(self): - return DefaultExecutionContext(self) + def create_execution_context(self, **kwargs): + return DefaultExecutionContext(self, **kwargs) def type_descriptor(self, typeobj): """Provide a database-specific ``TypeEngine`` object, given @@ -56,6 +57,9 @@ class DefaultDialect(base.Dialect): # TODO: probably raise this and fill out # db modules better return 30 + + def supports_alter(self): + return True def oid_column_name(self, column): return None @@ -92,14 +96,8 @@ class DefaultDialect(base.Dialect): def do_execute(self, cursor, statement, parameters, **kwargs): cursor.execute(statement, parameters) - def defaultrunner(self, engine, proxy): - return base.DefaultRunner(engine, proxy) - - def create_cursor(self, connection): - return connection.cursor() - - def create_result_proxy_args(self, connection, cursor): - return dict(should_prefetch=False) + def defaultrunner(self, connection): + return base.DefaultRunner(connection) def _set_paramstyle(self, style): self._paramstyle = style @@ -126,11 +124,10 @@ class DefaultDialect(base.Dialect): return parameters def _figure_paramstyle(self, paramstyle=None, default='named'): - db = self.dbapi() if paramstyle is not None: self._paramstyle = paramstyle - elif db is not None: - self._paramstyle = db.paramstyle + elif self.dbapi is not None: + self._paramstyle = self.dbapi.paramstyle else: self._paramstyle = default @@ -146,10 +143,6 @@ class DefaultDialect(base.Dialect): raise DBAPIError("Unsupported paramstyle '%s'" % self._paramstyle) def _get_ischema(self): - # We use a property for ischema so that the accessor - # creation only happens as needed, since otherwise we - # have a circularity problem with the generic - # ansisql.engine() if self._ischema is None: import sqlalchemy.databases.information_schema as ischema self._ischema = ischema.ISchema(self) @@ -157,20 +150,49 @@ class DefaultDialect(base.Dialect): ischema = property(_get_ischema, doc="""returns an ISchema object for this engine, which allows access to information_schema tables (if supported)""") class DefaultExecutionContext(base.ExecutionContext): - def __init__(self, dialect): + def __init__(self, dialect, connection, compiled=None, compiled_parameters=None, statement=None, parameters=None): self.dialect = dialect + self.connection = connection + self.compiled = compiled + self.compiled_parameters = compiled_parameters + + if compiled is not None: + self.typemap = compiled.typemap + self.column_labels = compiled.column_labels + self.statement = unicode(compiled) + else: + self.typemap = self.column_labels = None + self.parameters = parameters + self.statement = statement - def pre_exec(self, engine, proxy, compiled, parameters): - self._process_defaults(engine, proxy, compiled, parameters) + if not dialect.supports_unicode_statements(): + self.statement = self.statement.encode('ascii') + + self.cursor = self.create_cursor() + + engine = property(lambda s:s.connection.engine) + + def is_select(self): + return re.match(r'SELECT', self.statement.lstrip(), re.I) + + def create_cursor(self): + return self.connection.connection.cursor() + + def pre_exec(self): + self._process_defaults() + self.parameters = self.dialect.convert_compiled_params(self.compiled_parameters) - def post_exec(self, engine, proxy, compiled, parameters): + def post_exec(self): pass - def get_rowcount(self, cursor): + def get_result_proxy(self): + return base.ResultProxy(self) + + def get_rowcount(self): if hasattr(self, '_rowcount'): return self._rowcount else: - return cursor.rowcount + return self.cursor.rowcount def supports_sane_rowcount(self): return self.dialect.supports_sane_rowcount() @@ -187,44 +209,44 @@ class DefaultExecutionContext(base.ExecutionContext): def lastrow_has_defaults(self): return self._lastrow_has_defaults - def set_input_sizes(self, cursor, parameters): + def set_input_sizes(self): """Given a cursor and ClauseParameters, call the appropriate style of ``setinputsizes()`` on the cursor, using DBAPI types from the bind parameter's ``TypeEngine`` objects. """ - if isinstance(parameters, list): - plist = parameters + if isinstance(self.compiled_parameters, list): + plist = self.compiled_parameters else: - plist = [parameters] + plist = [self.compiled_parameters] if self.dialect.positional: inputsizes = [] for params in plist[0:1]: for key in params.positional: typeengine = params.binds[key].type - dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.module) + dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) if dbtype is not None: inputsizes.append(dbtype) - cursor.setinputsizes(*inputsizes) + self.cursor.setinputsizes(*inputsizes) else: inputsizes = {} for params in plist[0:1]: for key in params.keys(): typeengine = params.binds[key].type - dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.module) + dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) if dbtype is not None: inputsizes[key] = dbtype - cursor.setinputsizes(**inputsizes) + self.cursor.setinputsizes(**inputsizes) - def _process_defaults(self, engine, proxy, compiled, parameters): + def _process_defaults(self): """``INSERT`` and ``UPDATE`` statements, when compiled, may have additional columns added to their ``VALUES`` and ``SET`` lists corresponding to column defaults/onupdates that are present on the ``Table`` object (i.e. ``ColumnDefault``, ``Sequence``, ``PassiveDefault``). This method pre-execs those ``DefaultGenerator`` objects that require pre-execution - and sets their values within the parameter list, and flags the - thread-local state about ``PassiveDefault`` objects that may + and sets their values within the parameter list, and flags this + ExecutionContext about ``PassiveDefault`` objects that may require post-fetching the row after it is inserted/updated. This method relies upon logic within the ``ANSISQLCompiler`` @@ -234,30 +256,28 @@ class DefaultExecutionContext(base.ExecutionContext): statement. """ - if compiled is None: return - - if getattr(compiled, "isinsert", False): - if isinstance(parameters, list): - plist = parameters + if self.compiled.isinsert: + if isinstance(self.compiled_parameters, list): + plist = self.compiled_parameters else: - plist = [parameters] - drunner = self.dialect.defaultrunner(engine, proxy) + plist = [self.compiled_parameters] + drunner = self.dialect.defaultrunner(base.Connection(self.engine, self.connection.connection)) self._lastrow_has_defaults = False for param in plist: last_inserted_ids = [] need_lastrowid=False # check the "default" status of each column in the table - for c in compiled.statement.table.c: + for c in self.compiled.statement.table.c: # check if it will be populated by a SQL clause - we'll need that # after execution. - if c in compiled.inline_params: + if c in self.compiled.inline_params: self._lastrow_has_defaults = True if c.primary_key: need_lastrowid = True # check if its not present at all. see if theres a default # and fire it off, and add to bind parameters. if # its a pk, add the value to our last_inserted_ids list, - # or, if its a SQL-side default, dont do any of that, but we'll need + # or, if its a SQL-side default, let it fire off on the DB side, but we'll need # the SQL-generated value after execution. elif not c.key in param or param.get_original(c.key) is None: if isinstance(c.default, schema.PassiveDefault): @@ -278,19 +298,19 @@ class DefaultExecutionContext(base.ExecutionContext): else: self._last_inserted_ids = last_inserted_ids self._last_inserted_params = param - elif getattr(compiled, 'isupdate', False): - if isinstance(parameters, list): - plist = parameters + elif self.compiled.isupdate: + if isinstance(self.compiled_parameters, list): + plist = self.compiled_parameters else: - plist = [parameters] - drunner = self.dialect.defaultrunner(engine, proxy) + plist = [self.compiled_parameters] + drunner = self.dialect.defaultrunner(base.Connection(self.engine, self.connection.connection)) self._lastrow_has_defaults = False for param in plist: # check the "onupdate" status of each column in the table - for c in compiled.statement.table.c: + for c in self.compiled.statement.table.c: # it will be populated by a SQL clause - we'll need that # after execution. - if c in compiled.inline_params: + if c in self.compiled.inline_params: pass # its not in the bind parameters, and theres an "onupdate" defined for the column; # execute it and add to bind params diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py index 8ac721b77c..1b760fca8b 100644 --- a/lib/sqlalchemy/engine/strategies.py +++ b/lib/sqlalchemy/engine/strategies.py @@ -50,6 +50,16 @@ class DefaultEngineStrategy(EngineStrategy): if k in kwargs: dialect_args[k] = kwargs.pop(k) + dbapi = kwargs.pop('module', None) + if dbapi is None: + dbapi_args = {} + for k in util.get_func_kwargs(module.dbapi): + if k in kwargs: + dbapi_args[k] = kwargs.pop(k) + dbapi = module.dbapi(**dbapi_args) + + dialect_args['dbapi'] = dbapi + # create dialect dialect = module.dialect(**dialect_args) @@ -60,10 +70,6 @@ class DefaultEngineStrategy(EngineStrategy): # look for existing pool or create pool = kwargs.pop('pool', None) if pool is None: - dbapi = kwargs.pop('module', dialect.dbapi()) - if dbapi is None: - raise exceptions.InvalidRequestError("Can't get DBAPI module for dialect '%s'" % dialect) - def connect(): try: return dbapi.connect(*cargs, **cparams) @@ -73,6 +79,7 @@ class DefaultEngineStrategy(EngineStrategy): poolclass = kwargs.pop('poolclass', getattr(module, 'poolclass', poollib.QueuePool)) pool_args = {} + # consume pool arguments from kwargs, translating a few of the arguments for k in util.get_cls_kwargs(poolclass): tk = {'echo':'echo_pool', 'timeout':'pool_timeout', 'recycle':'pool_recycle'}.get(k, k) @@ -139,3 +146,52 @@ class ThreadLocalEngineStrategy(DefaultEngineStrategy): return threadlocal.TLEngine ThreadLocalEngineStrategy() + + +class MockEngineStrategy(EngineStrategy): + """Produces a single Connection object which dispatches statement executions + to a passed-in function""" + def __init__(self): + EngineStrategy.__init__(self, 'mock') + + def create(self, name_or_url, executor, **kwargs): + # create url.URL object + u = url.make_url(name_or_url) + + # get module from sqlalchemy.databases + module = u.get_module() + + dialect_args = {} + # consume dialect arguments from kwargs + for k in util.get_cls_kwargs(module.dialect): + if k in kwargs: + dialect_args[k] = kwargs.pop(k) + + # create dialect + dialect = module.dialect(**dialect_args) + + return MockEngineStrategy.MockConnection(dialect, executor) + + class MockConnection(base.Connectable): + def __init__(self, dialect, execute): + self._dialect = dialect + self.execute = execute + + engine = property(lambda s: s) + dialect = property(lambda s:s._dialect) + + def contextual_connect(self): + return self + + def create(self, entity, **kwargs): + kwargs['checkfirst'] = False + entity.accept_visitor(self.dialect.schemagenerator(self, **kwargs)) + + def drop(self, entity, **kwargs): + kwargs['checkfirst'] = False + entity.accept_visitor(self.dialect.schemadropper(self, **kwargs)) + + def execute(self, object, *multiparams, **params): + raise NotImplementedError() + +MockEngineStrategy() \ No newline at end of file diff --git a/lib/sqlalchemy/engine/url.py b/lib/sqlalchemy/engine/url.py index edb8cf32e8..faa0ffc11c 100644 --- a/lib/sqlalchemy/engine/url.py +++ b/lib/sqlalchemy/engine/url.py @@ -71,6 +71,10 @@ class URL(object): def get_module(self): """Return the SQLAlchemy database module corresponding to this URL's driver name.""" + if self.drivername == 'ansi': + import sqlalchemy.ansisql + return sqlalchemy.ansisql + try: return getattr(__import__('sqlalchemy.databases.%s' % self.drivername).databases, self.drivername) except ImportError: diff --git a/lib/sqlalchemy/logging.py b/lib/sqlalchemy/logging.py index 6f43687079..91326233a6 100644 --- a/lib/sqlalchemy/logging.py +++ b/lib/sqlalchemy/logging.py @@ -31,8 +31,8 @@ import sys # py2.5 absolute imports will fix.... logging = __import__('logging') -# turn off logging at the root sqlalchemy level -logging.getLogger('sqlalchemy').setLevel(logging.ERROR) + +logging.getLogger('sqlalchemy').setLevel(logging.WARN) default_enabled = False def default_logging(name): diff --git a/lib/sqlalchemy/pool.py b/lib/sqlalchemy/pool.py index 787fd059f2..8d559aff52 100644 --- a/lib/sqlalchemy/pool.py +++ b/lib/sqlalchemy/pool.py @@ -237,7 +237,9 @@ class _ConnectionFairy(object): raise if self.__pool.echo: self.__pool.log("Connection %s checked out from pool" % repr(self.connection)) - + + _logger = property(lambda self: self.__pool.logger) + def invalidate(self): if self.connection is None: raise exceptions.InvalidRequestError("This connection is closed") @@ -248,7 +250,8 @@ class _ConnectionFairy(object): def cursor(self, *args, **kwargs): try: - return _CursorFairy(self, self.connection.cursor(*args, **kwargs)) + c = self.connection.cursor(*args, **kwargs) + return _CursorFairy(self, c) except Exception, e: self.invalidate() raise @@ -307,11 +310,14 @@ class _CursorFairy(object): def invalidate(self): self.__parent.invalidate() - + def close(self): if self in self.__parent._cursors: del self.__parent._cursors[self] - self.cursor.close() + try: + self.cursor.close() + except Exception, e: + self.__parent._logger.warn("Error closing cursor: " + str(e)) def __getattr__(self, key): return getattr(self.cursor, key) diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 87cbdaf0c3..f6c2315ae9 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -508,7 +508,7 @@ class ClauseParameters(object): return d def __repr__(self): - return repr(self.get_original_dict()) + return self.__class__.__name__ + ":" + repr(self.get_original_dict()) class ClauseVisitor(object): """A class that knows how to traverse and visit diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index 86e323c6ea..7d7dbeeedf 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -53,28 +53,12 @@ class TypeEngine(AbstractType): def __init__(self, *args, **params): pass - def engine_impl(self, engine): - """Deprecated; call dialect_impl with a dialect directly.""" - - return self.dialect_impl(engine.dialect) - def dialect_impl(self, dialect): try: return self.impl_dict[dialect] except KeyError: return self.impl_dict.setdefault(dialect, dialect.type_descriptor(self)) - def _get_impl(self): - if hasattr(self, '_impl'): - return self._impl - else: - return NULLTYPE - - def _set_impl(self, impl): - self._impl = impl - - impl = property(_get_impl, _set_impl) - def get_col_spec(self): raise NotImplementedError() @@ -86,26 +70,25 @@ class TypeEngine(AbstractType): def adapt(self, cls): return cls() - + + def get_search_list(self): + """return a list of classes to test for a match + when adapting this type to a dialect-specific type. + + """ + + return self.__class__.__mro__[0:-1] + class TypeDecorator(AbstractType): def __init__(self, *args, **kwargs): if not hasattr(self.__class__, 'impl'): raise exceptions.AssertionError("TypeDecorator implementations require a class-level variable 'impl' which refers to the class of type being decorated") self.impl = self.__class__.impl(*args, **kwargs) - def engine_impl(self, engine): - return self.dialect_impl(engine.dialect) - def dialect_impl(self, dialect): try: return self.impl_dict[dialect] except: - # see if the dialect has an adaptation of the TypeDecorator itself - adapted_decorator = dialect.type_descriptor(self) - if adapted_decorator is not self: - result = adapted_decorator.dialect_impl(dialect) - self.impl_dict[dialect] = result - return result typedesc = dialect.type_descriptor(self.impl) tt = self.copy() if not isinstance(tt, self.__class__): @@ -168,8 +151,7 @@ def to_instance(typeobj): def adapt_type(typeobj, colspecs): if isinstance(typeobj, type): typeobj = typeobj() - - for t in typeobj.__class__.__mro__[0:-1]: + for t in typeobj.get_search_list(): try: impltype = colspecs[t] break @@ -198,26 +180,28 @@ class NullTypeEngine(TypeEngine): return value class String(TypeEngine): - def __new__(cls, *args, **kwargs): - if cls is not String or len(args) > 0 or kwargs.has_key('length'): - return super(String, cls).__new__(cls, *args, **kwargs) - else: - return super(String, TEXT).__new__(TEXT, *args, **kwargs) - - def __init__(self, length = None): + def __init__(self, length=None, convert_unicode=False): self.length = length + self.convert_unicode = convert_unicode def adapt(self, impltype): - return impltype(length=self.length) + return impltype(length=self.length, convert_unicode=self.convert_unicode) def convert_bind_param(self, value, dialect): - if not dialect.convert_unicode or value is None or not isinstance(value, unicode): + if not (self.convert_unicode or dialect.convert_unicode) or value is None or not isinstance(value, unicode): return value else: return value.encode(dialect.encoding) + def get_search_list(self): + l = super(String, self).get_search_list() + if self.length is None: + return (TEXT,) + l + else: + return l + def convert_result_value(self, value, dialect): - if not dialect.convert_unicode or value is None or isinstance(value, unicode): + if not (self.convert_unicode or dialect.convert_unicode) or value is None or isinstance(value, unicode): return value else: return value.decode(dialect.encoding) @@ -228,21 +212,11 @@ class String(TypeEngine): def compare_values(self, x, y): return x == y -class Unicode(TypeDecorator): - impl = String - - def convert_bind_param(self, value, dialect): - if value is not None and isinstance(value, unicode): - return value.encode(dialect.encoding) - else: - return value - - def convert_result_value(self, value, dialect): - if value is not None and not isinstance(value, unicode): - return value.decode(dialect.encoding) - else: - return value - +class Unicode(String): + def __init__(self, length=None, **kwargs): + kwargs['convert_unicode'] = True + super(Unicode, self).__init__(length=length, **kwargs) + class Integer(TypeEngine): """Integer datatype.""" @@ -310,7 +284,7 @@ class Binary(TypeEngine): def convert_bind_param(self, value, dialect): if value is not None: - return dialect.dbapi().Binary(value) + return dialect.dbapi.Binary(value) else: return None diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index dadcf0ddee..238f12493f 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -94,6 +94,10 @@ def get_cls_kwargs(cls): kw.append(vn) return kw +def get_func_kwargs(func): + """Return the full set of legal kwargs for the given `func`.""" + return [vn for vn in func.func_code.co_varnames] + class SimpleProperty(object): """A *default* property accessor.""" diff --git a/test/engine/reflection.py b/test/engine/reflection.py index 62cd92b6e6..0c6323c10f 100644 --- a/test/engine/reflection.py +++ b/test/engine/reflection.py @@ -498,9 +498,10 @@ class SchemaTest(PersistTest): # insure this doesnt crash print [t for t in metadata.table_iterator()] buf = StringIO.StringIO() - def foo(s, p): + def foo(s, p=None): buf.write(s) - gen = testbase.db.dialect.schemagenerator(testbase.db.engine, foo, None) + gen = create_engine(testbase.db.name + "://", strategy="mock", executor=foo) + gen = gen.dialect.schemagenerator(gen) gen.traverse(table1) gen.traverse(table2) buf = buf.getvalue() diff --git a/test/orm/inheritance5.py b/test/orm/inheritance5.py index f92e70df3a..bdc9e02e12 100644 --- a/test/orm/inheritance5.py +++ b/test/orm/inheritance5.py @@ -42,7 +42,7 @@ class RelationTest1(testbase.ORMTest): try: compile_mappers() except exceptions.ArgumentError, ar: - assert str(ar) == "Cant determine relation direction for relationship 'Person.manager (Manager)' - foreign key columns are present in both the parent and the child's mapped tables. Specify 'foreign_keys' argument." + assert str(ar) == "Can't determine relation direction for relationship 'Person.manager (Manager)' - foreign key columns are present in both the parent and the child's mapped tables. Specify 'foreign_keys' argument.", str(ar) clear_mappers() diff --git a/test/orm/mapper.py b/test/orm/mapper.py index 6f80df38f5..839a5172e6 100644 --- a/test/orm/mapper.py +++ b/test/orm/mapper.py @@ -1332,7 +1332,7 @@ class InstancesTest(MapperSuperTest): 'addresses':relation(Address, lazy=True) }) mapper(Address, addresses) - query = users.select(users.c.user_id==7).union(users.select(users.c.user_id>7)).alias('ulist').outerjoin(addresses).select(use_labels=True) + query = users.select(users.c.user_id==7).union(users.select(users.c.user_id>7)).alias('ulist').outerjoin(addresses).select(use_labels=True,order_by=['ulist.user_id', addresses.c.address_id]) q = create_session().query(User) def go(): @@ -1348,7 +1348,7 @@ class InstancesTest(MapperSuperTest): }) mapper(Address, addresses) - selectquery = users.outerjoin(addresses).select(use_labels=True) + selectquery = users.outerjoin(addresses).select(use_labels=True, order_by=[users.c.user_id, addresses.c.address_id]) q = create_session().query(User) def go(): @@ -1363,7 +1363,7 @@ class InstancesTest(MapperSuperTest): mapper(Address, addresses) adalias = addresses.alias('adalias') - selectquery = users.outerjoin(adalias).select(use_labels=True) + selectquery = users.outerjoin(adalias).select(use_labels=True, order_by=[users.c.user_id, adalias.c.address_id]) q = create_session().query(User) def go(): @@ -1378,7 +1378,7 @@ class InstancesTest(MapperSuperTest): mapper(Address, addresses) adalias = addresses.alias('adalias') - selectquery = users.outerjoin(adalias).select(use_labels=True) + selectquery = users.outerjoin(adalias).select(use_labels=True, order_by=[users.c.user_id, adalias.c.address_id]) q = create_session().query(User) def go(): @@ -1393,7 +1393,7 @@ class InstancesTest(MapperSuperTest): mapper(Address, addresses) adalias = addresses.alias('adalias') - selectquery = users.outerjoin(adalias).select(use_labels=True) + selectquery = users.outerjoin(adalias).select(use_labels=True, order_by=[users.c.user_id, adalias.c.address_id]) def decorate(row): d = {} for c in addresses.columns: @@ -1418,7 +1418,7 @@ class InstancesTest(MapperSuperTest): (user7, user8, user9) = sess.query(User).select() (address1, address2, address3, address4) = sess.query(Address).select() - selectquery = users.outerjoin(addresses).select(use_labels=True) + selectquery = users.outerjoin(addresses).select(use_labels=True, order_by=[users.c.user_id, addresses.c.address_id]) q = sess.query(User) l = q.instances(selectquery.execute(), Address) # note the result is a cartesian product diff --git a/test/sql/constraints.py b/test/sql/constraints.py index 231a491b52..d695e824c7 100644 --- a/test/sql/constraints.py +++ b/test/sql/constraints.py @@ -172,11 +172,13 @@ class ConstraintTest(testbase.AssertMixin): capt = [] connection = testbase.db.connect() - def proxy(statement, parameters): - capt.append(statement) - capt.append(repr(parameters)) - connection.proxy(statement, parameters) - schemagen = testbase.db.dialect.schemagenerator(testbase.db, proxy, connection) + ex = connection._execute + def proxy(context): + capt.append(context.statement) + capt.append(repr(context.parameters)) + ex(context) + connection._execute = proxy + schemagen = testbase.db.dialect.schemagenerator(connection) schemagen.traverse(events) assert capt[0].strip().startswith('CREATE TABLE events') diff --git a/test/sql/query.py b/test/sql/query.py index 3c3e2334c0..08c766a0df 100644 --- a/test/sql/query.py +++ b/test/sql/query.py @@ -357,7 +357,7 @@ class QueryTest(PersistTest): Column('__parent', VARCHAR(20)), Column('__row', VARCHAR(20)), ) - shadowed.create() + shadowed.create(checkfirst=True) try: shadowed.insert().execute(shadow_id=1, shadow_name='The Shadow', parent='The Light', row='Without light there is no shadow', __parent='Hidden parent', __row='Hidden row') r = shadowed.select(shadowed.c.shadow_id==1).execute().fetchone() @@ -374,7 +374,7 @@ class QueryTest(PersistTest): pass # expected r.close() finally: - shadowed.drop() + shadowed.drop(checkfirst=True) class CompoundTest(PersistTest): """test compound statements like UNION, INTERSECT, particularly their ability to nest on diff --git a/test/sql/testtypes.py b/test/sql/testtypes.py index 97e95d3892..d1256b31a5 100644 --- a/test/sql/testtypes.py +++ b/test/sql/testtypes.py @@ -6,7 +6,7 @@ import string,datetime, re, sys, os import sqlalchemy.engine.url as url import sqlalchemy.types - +from sqlalchemy.databases import mssql, oracle db = testbase.db @@ -22,18 +22,19 @@ class MyType(types.TypeEngine): class MyDecoratedType(types.TypeDecorator): impl = String - def convert_bind_param(self, value, engine): - return "BIND_IN"+ value - def convert_result_value(self, value, engine): - return value + "BIND_OUT" + def convert_bind_param(self, value, dialect): + return "BIND_IN"+ super(MyDecoratedType, self).convert_bind_param(value, dialect) + def convert_result_value(self, value, dialect): + return super(MyDecoratedType, self).convert_result_value(value, dialect) + "BIND_OUT" def copy(self): return MyDecoratedType() -class MyUnicodeType(types.Unicode): - def convert_bind_param(self, value, engine): - return "UNI_BIND_IN"+ value - def convert_result_value(self, value, engine): - return value + "UNI_BIND_OUT" +class MyUnicodeType(types.TypeDecorator): + impl = Unicode + def convert_bind_param(self, value, dialect): + return "UNI_BIND_IN"+ super(MyUnicodeType, self).convert_bind_param(value, dialect) + def convert_result_value(self, value, dialect): + return super(MyUnicodeType, self).convert_result_value(value, dialect) + "UNI_BIND_OUT" def copy(self): return MyUnicodeType(self.impl.length) @@ -52,31 +53,29 @@ class AdaptTest(PersistTest): assert t2 != t3 assert t3 != t1 - def testdecorator(self): - t1 = Unicode(20) - t2 = Unicode() - assert isinstance(t1.impl, String) - assert not isinstance(t1.impl, TEXT) - assert (t1.impl.length == 20) - assert isinstance(t2.impl, TEXT) - assert t2.impl.length is None - - - def testdialecttypedecorators(self): - """test that a a Dialect can provide a dialect-specific subclass of a TypeDecorator subclass.""" - import sqlalchemy.databases.mssql as mssql + def testmsnvarchar(self): dialect = mssql.MSSQLDialect() # run the test twice to insure the caching step works too for x in range(0, 1): col = Column('', Unicode(length=10)) dialect_type = col.type.dialect_impl(dialect) - assert isinstance(dialect_type, mssql.MSUnicode) + assert isinstance(dialect_type, mssql.MSNVarchar) assert dialect_type.get_col_spec() == 'NVARCHAR(10)' - assert isinstance(dialect_type.impl, mssql.MSString) - + + def testoracletext(self): + dialect = oracle.OracleDialect() + col = Column('', MyDecoratedType) + dialect_type = col.type.dialect_impl(dialect) + assert isinstance(dialect_type.impl, oracle.OracleText), repr(dialect_type.impl) + class OverrideTest(PersistTest): """tests user-defined types, including a full type as well as a TypeDecorator""" + def testbasic(self): + print users.c.goofy4.type + print users.c.goofy4.type.dialect_impl(testbase.db.dialect) + print users.c.goofy4.type.dialect_impl(testbase.db.dialect).get_col_spec() + def testprocessing(self): global users diff --git a/test/testbase.py b/test/testbase.py index 8a1d9ee59a..aae455673f 100644 --- a/test/testbase.py +++ b/test/testbase.py @@ -1,12 +1,9 @@ import sys sys.path.insert(0, './lib/') -import os -import unittest -import StringIO -import sqlalchemy.ext.proxy as proxy -import re +import os, unittest, StringIO, re import sqlalchemy from sqlalchemy import sql, engine, pool +import sqlalchemy.engine.base as base import optparse from sqlalchemy.schema import BoundMetaData from sqlalchemy.orm import clear_mappers @@ -49,6 +46,7 @@ def parse_argv(): parser.add_option("--enginestrategy", action="store", default=None, dest="enginestrategy", help="engine strategy (plain or threadlocal, defaults to plain)") parser.add_option("--coverage", action="store_true", dest="coverage", help="Dump a full coverage report after running") parser.add_option("--reversetop", action="store_true", dest="topological", help="Reverse the collection ordering for topological sorts (helps reveal dependency issues)") + parser.add_option("--serverside", action="store_true", dest="serverside", help="Turn on server side cursors for PG") (options, args) = parser.parse_args() sys.argv[1:] = args @@ -73,7 +71,7 @@ def parse_argv(): db_uri = 'oracle://scott:tiger@127.0.0.1:1521' elif DBTYPE == 'oracle8': db_uri = 'oracle://scott:tiger@127.0.0.1:1521' - opts = {'use_ansi':False} + opts['use_ansi'] = False elif DBTYPE == 'mssql': db_uri = 'mssql://scott:tiger@SQUAWK\\SQLEXPRESS/test' elif DBTYPE == 'firebird': @@ -94,6 +92,9 @@ def parse_argv(): global with_coverage with_coverage = options.coverage + + if options.serverside: + opts['server_side_cursors'] = True if options.enginestrategy is not None: opts['strategy'] = options.enginestrategy @@ -101,7 +102,16 @@ def parse_argv(): db = engine.create_engine(db_uri, poolclass=pool.AssertionPool, **opts) else: db = engine.create_engine(db_uri, **opts) - db = EngineAssert(db) + + # decorate the dialect's create_execution_context() method + # to produce a wrapper + create_context = db.dialect.create_execution_context + def create_exec_context(*args, **kwargs): + return ExecutionContextWrapper(create_context(*args, **kwargs)) + db.dialect.create_execution_context = create_exec_context + + global testdata + testdata = TestData(db) if options.topological: from sqlalchemy.orm import unitofwork @@ -172,8 +182,6 @@ class PersistTest(unittest.TestCase): """overridden to not return docstrings""" return None - - class AssertMixin(PersistTest): """given a list-based structure of keys/properties which represent information within an object structure, and a list of actual objects, asserts that the list of objects corresponds to the structure.""" @@ -197,20 +205,24 @@ class AssertMixin(PersistTest): else: self.assert_(getattr(rowobj, key) == value, "attribute %s value %s does not match %s" % (key, getattr(rowobj, key), value)) def assert_sql(self, db, callable_, list, with_sequences=None): + global testdata + testdata = TestData(db) if with_sequences is not None and (db.engine.name == 'postgres' or db.engine.name == 'oracle'): - db.set_assert_list(self, with_sequences) + testdata.set_assert_list(self, with_sequences) else: - db.set_assert_list(self, list) + testdata.set_assert_list(self, list) try: callable_() finally: - db.set_assert_list(None, None) + testdata.set_assert_list(None, None) + def assert_sql_count(self, db, callable_, count): - db.sql_count = 0 + global testdata + testdata = TestData(db) try: callable_() finally: - self.assert_(db.sql_count == count, "desired statement count %d does not match %d" % (count, db.sql_count)) + self.assert_(testdata.sql_count == count, "desired statement count %d does not match %d" % (count, testdata.sql_count)) class ORMTest(AssertMixin): keep_mappers = False @@ -233,83 +245,73 @@ class ORMTest(AssertMixin): for t in metadata.table_iterator(reverse=True): t.delete().execute().close() -class EngineAssert(proxy.BaseProxyEngine): - """decorates a SQLEngine object to match the incoming queries against a set of assertions.""" +class TestData(object): def __init__(self, engine): self._engine = engine - - self.real_execution_context = engine.dialect.create_execution_context - engine.dialect.create_execution_context = self.execution_context - self.logger = engine.logger self.set_assert_list(None, None) self.sql_count = 0 - def get_engine(self): - return self._engine - def set_engine(self, e): - self._engine = e + def set_assert_list(self, unittest, list): self.unittest = unittest self.assert_list = list if list is not None: self.assert_list.reverse() - def _set_echo(self, echo): - self.engine.echo = echo - echo = property(lambda s: s.engine.echo, _set_echo) - def execution_context(self): - def post_exec(engine, proxy, compiled, parameters, **kwargs): - ctx = e - self.engine.logger = self.logger - statement = unicode(compiled) - statement = re.sub(r'\n', '', statement) - - if self.assert_list is not None: - item = self.assert_list[-1] - if not isinstance(item, dict): - item = self.assert_list.pop() - else: - # asserting a dictionary of statements->parameters - # this is to specify query assertions where the queries can be in - # multiple orderings - if not item.has_key('_converted'): - for key in item.keys(): - ckey = self.convert_statement(key) - item[ckey] = item[key] - if ckey != key: - del item[key] - item['_converted'] = True - try: - entry = item.pop(statement) - if len(item) == 1: - self.assert_list.pop() - item = (statement, entry) - except KeyError: - self.unittest.assert_(False, "Testing for one of the following queries: %s, received '%s'" % (repr([k for k in item.keys()]), statement)) - - (query, params) = item - if callable(params): - params = params(ctx) - if params is not None and isinstance(params, list) and len(params) == 1: - params = params[0] - - if isinstance(parameters, sql.ClauseParameters): - parameters = parameters.get_original_dict() - elif isinstance(parameters, list): - parameters = [p.get_original_dict() for p in parameters] - - query = self.convert_statement(query) - self.unittest.assert_(statement == query and (params is None or params == parameters), "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters))) - self.sql_count += 1 - return realexec(ctx, proxy, compiled, parameters, **kwargs) - - e = self.real_execution_context() - realexec = e.post_exec - realexec.im_self.post_exec = post_exec - return e +class ExecutionContextWrapper(object): + def __init__(self, ctx): + self.__dict__['ctx'] = ctx + def __getattr__(self, key): + return getattr(self.ctx, key) + def __setattr__(self, key, value): + setattr(self.ctx, key, value) + + def post_exec(self): + ctx = self.ctx + statement = unicode(ctx.compiled) + statement = re.sub(r'\n', '', ctx.statement) + + if testdata.assert_list is not None: + item = testdata.assert_list[-1] + if not isinstance(item, dict): + item = testdata.assert_list.pop() + else: + # asserting a dictionary of statements->parameters + # this is to specify query assertions where the queries can be in + # multiple orderings + if not item.has_key('_converted'): + for key in item.keys(): + ckey = self.convert_statement(key) + item[ckey] = item[key] + if ckey != key: + del item[key] + item['_converted'] = True + try: + entry = item.pop(statement) + if len(item) == 1: + testdata.assert_list.pop() + item = (statement, entry) + except KeyError: + self.unittest.assert_(False, "Testing for one of the following queries: %s, received '%s'" % (repr([k for k in item.keys()]), statement)) + + (query, params) = item + if callable(params): + params = params(ctx) + if params is not None and isinstance(params, list) and len(params) == 1: + params = params[0] + + if isinstance(ctx.compiled_parameters, sql.ClauseParameters): + parameters = ctx.compiled_parameters.get_original_dict() + elif isinstance(ctx.compiled_parameters, list): + parameters = [p.get_original_dict() for p in ctx.compiled_parameters] + + query = self.convert_statement(query) + testdata.unittest.assert_(statement == query and (params is None or params == parameters), "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters))) + testdata.sql_count += 1 + self.ctx.post_exec() def convert_statement(self, query): - paramstyle = self.engine.dialect.paramstyle + paramstyle = self.ctx.dialect.paramstyle if paramstyle == 'named': pass elif paramstyle =='pyformat': -- 2.47.2