From: Rick Morrison Date: Sun, 18 Mar 2007 17:14:10 +0000 (+0000) Subject: mssql: cleanup of module importing code; specifiable DB-API module; more explicit... X-Git-Tag: rel_0_3_6~14 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=1c56142219188636ff57a1e4dbb8ee0d57e3dc88;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git mssql: cleanup of module importing code; specifiable DB-API module; more explicit ordering of module preferences. [ticket:480] --- diff --git a/CHANGES b/CHANGES index 3eac53cfcb..6b39299e19 100644 --- a/CHANGES +++ b/CHANGES @@ -163,6 +163,9 @@ for large unsized string fields. Use the new "text_as_varchar" to turn it on. [ticket:509] + - cleanup of module importing code; specifiable DB-API module; more + explicit ordering of module preferences. [ticket:480] + diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index f7d64761af..9ae6572507 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -30,7 +30,7 @@ Known issues / TODO: -* No support for more than one ``IDENTITY`` column per table no +* No support for more than one ``IDENTITY`` column per table * No support for table reflection of ``IDENTITY`` columns with (seed,increment) values other than (1,1) @@ -38,8 +38,7 @@ Known issues / TODO: * No support for ``GUID`` type columns (yet) * pymssql has problems with binary and unicode data that this module - does **not** work around adodbapi fails testtypes.py unit test on - unicode data too -- issue with the test? + does **not** work around """ import sys, StringIO, string, types, re, datetime @@ -52,96 +51,6 @@ import sqlalchemy.ansisql as ansisql import sqlalchemy.types as sqltypes import sqlalchemy.exceptions as exceptions -dbmodule = None -dialect = None - -def use_adodbapi(): - global dbmodule, connect, make_connect_string, do_commit, sane_rowcount, dialect, colspecs, ischema_names - import adodbapi as dbmodule - # ADODBAPI has a non-standard Connection method - connect = dbmodule.Connection - def make_connect_string(keys): - connectors = ["Provider=SQLOLEDB"] - connectors.append ("Data Source=%s" % keys.get("host")) - connectors.append ("Initial Catalog=%s" % keys.get("database")) - user = keys.get("user") - if user: - connectors.append("User Id=%s" % user) - connectors.append("Password=%s" % keys.get("password", "")) - else: - connectors.append("Integrated Security=SSPI") - return [[";".join (connectors)], {}] - sane_rowcount = True - dialect = MSSQLDialect - colspecs[sqltypes.Unicode] = AdoMSUnicode - ischema_names['nvarchar'] = AdoMSUnicode - -def use_pymssql(): - global dbmodule, connect, make_connect_string, do_commit, sane_rowcount, dialect, colspecs, ischema_names - import pymssql as dbmodule - connect = dbmodule.connect - # pymmsql doesn't have a Binary method. we use string - dbmodule.Binary = lambda st: str(st) - def make_connect_string(keys): - if keys.get('port'): - # pymssql expects port as host:port, not a separate arg - keys['host'] = ''.join([keys.get('host', ''), ':', str(keys['port'])]) - del keys['port'] - return [[], keys] - do_commit = True - sane_rowcount = False - dialect = PyMSSQLDialect - colspecs[sqltypes.Unicode] = MSUnicode - ischema_names['nvarchar'] = MSUnicode - -def use_pyodbc(): - global dbmodule, connect, make_connect_string, do_commit, sane_rowcount, dialect, colspecs, ischema_names - import pyodbc as dbmodule - connect = dbmodule.connect - def make_connect_string(keys): - connectors = ["Driver={SQL Server}"] - connectors.append("Server=%s" % keys.get("host")) - connectors.append("Database=%s" % keys.get("database")) - user = keys.get("user") - if user: - connectors.append("UID=%s" % user) - connectors.append("PWD=%s" % keys.get("password", "")) - else: - connectors.append ("TrustedConnection=Yes") - return [[";".join (connectors)], {}] - do_commit = True - sane_rowcount = False - dialect = MSSQLDialect - import warnings - warnings.warn('pyodbc support in sqlalchemy.databases.mssql is experimental - use at your own risk.') - colspecs[sqltypes.Unicode] = AdoMSUnicode - ischema_names['nvarchar'] = AdoMSUnicode - -def use_default(): - import_errors = [] - def try_use(f): - try: - f() - except ImportError, e: - import_errors.append(e) - return False - else: - return True - for f in [ - # XXX - is this the best default ordering? For now, it retains the current (2007-Jan-11) - # default - that is, adodbapi first, pymssql second - and adds pyodbc as a third option. - # However, my tests suggest that the exact opposite order may be the best! - use_adodbapi, - use_pymssql, - use_pyodbc, - ]: - if try_use(f): - return dbmodule # informational return, so the user knows what he's using. - else: - return None - # cant raise this right now since all dialects need to be importable/loadable - #raise ImportError(import_errors) - class MSNumeric(sqltypes.Numeric): def convert_result_value(self, value, dialect): @@ -232,19 +141,16 @@ class MSString(sqltypes.String): return "VARCHAR(%(length)s)" % {'length' : self.length} class MSNVarchar(MSString): - """NVARCHAR string, does Unicode conversion if `dialect.convert_encoding` is True. - """ - + """NVARCHAR string, does Unicode conversion if `dialect.convert_encoding` is True. """ impl = sqltypes.Unicode def get_col_spec(self): if self.length: return "NVARCHAR(%(length)s)" % {'length' : self.length} + elif self.dialect.text_as_varchar: + return "NVARCHAR(max)" else: - if self.dialect.text_as_varchar: - return "NVARCHAR(max)" - else: - return "NTEXT" + return "NTEXT" class AdoMSNVarchar(MSNVarchar): def convert_bind_param(self, value, dialect): @@ -255,7 +161,6 @@ class AdoMSNVarchar(MSNVarchar): class MSUnicode(sqltypes.Unicode): """Unicode subclass, does Unicode conversion in all cases, uses NVARCHAR impl.""" - impl = MSNVarchar class AdoMSUnicode(MSUnicode): @@ -298,43 +203,6 @@ class MSBoolean(sqltypes.Boolean): else: return value and True or False -colspecs = { - sqltypes.Integer : MSInteger, - sqltypes.Smallinteger: MSSmallInteger, - sqltypes.Numeric : MSNumeric, - sqltypes.Float : MSFloat, - sqltypes.DateTime : MSDateTime, - sqltypes.Date : MSDate, - sqltypes.String : MSString, - sqltypes.Unicode : MSUnicode, - sqltypes.Binary : MSBinary, - sqltypes.Boolean : MSBoolean, - sqltypes.TEXT : MSText, - sqltypes.CHAR: MSChar, - sqltypes.NCHAR: MSNChar, -} - -ischema_names = { - 'int' : MSInteger, - 'smallint' : MSSmallInteger, - 'tinyint' : MSTinyInteger, - 'varchar' : MSString, - 'nvarchar' : MSUnicode, - 'char' : MSChar, - 'nchar' : MSNChar, - 'text' : MSText, - 'ntext' : MSText, - 'decimal' : MSNumeric, - 'numeric' : MSNumeric, - 'float' : MSFloat, - 'datetime' : MSDateTime, - 'smalldatetime' : MSDate, - 'binary' : MSBinary, - 'bit': MSBoolean, - 'real' : MSFloat, - 'image' : MSBinary -} - def descriptor(): return {'name':'mssql', 'description':'MSSQL', @@ -406,11 +274,63 @@ class MSSQLExecutionContext(default.DefaultExecutionContext): class MSSQLDialect(ansisql.ANSIDialect): - def __init__(self, module=None, auto_identity_insert=True, **params): - self.module = module or dbmodule or use_default() + colspecs = { + sqltypes.Integer : MSInteger, + sqltypes.Smallinteger: MSSmallInteger, + sqltypes.Numeric : MSNumeric, + sqltypes.Float : MSFloat, + sqltypes.DateTime : MSDateTime, + sqltypes.Date : MSDate, + sqltypes.String : MSString, + sqltypes.Unicode : MSUnicode, + sqltypes.Binary : MSBinary, + sqltypes.Boolean : MSBoolean, + sqltypes.TEXT : MSText, + sqltypes.CHAR: MSChar, + sqltypes.NCHAR: MSNChar, + } + + ischema_names = { + 'int' : MSInteger, + 'smallint' : MSSmallInteger, + 'tinyint' : MSTinyInteger, + 'varchar' : MSString, + 'nvarchar' : MSUnicode, + 'char' : MSChar, + 'nchar' : MSNChar, + 'text' : MSText, + 'ntext' : MSText, + 'decimal' : MSNumeric, + 'numeric' : MSNumeric, + 'float' : MSFloat, + 'datetime' : MSDateTime, + 'smalldatetime' : MSDate, + 'binary' : MSBinary, + 'bit': MSBoolean, + 'real' : MSFloat, + 'image' : MSBinary + } + + def __new__(cls, module_name=None, *args, **kwargs): + if cls != MSSQLDialect: + return super(MSSQLDialect, cls).__new__(cls, *args, **kwargs) + if module_name: + dialect = dialect_mapping.get(module_name) + if not dialect: + raise exceptions.InvalidRequestError('Unsupported MSSQL module requested (must be adodbpi, pymssql or pyodbc): ' + module_name) + if not hasattr(dialect, 'module'): + raise dialect.saved_import_error + return dialect(*args, **kwargs) + else: + for dialect in dialect_preference: + if hasattr(dialect, 'module'): + return dialect(*args, **kwargs) + raise ImportError('No DBAPI module detected for MSSQL - please install adodbapi, pymssql or pyodbc') + + def __init__(self, module_name=None, auto_identity_insert=True, **params): + super(MSSQLDialect, self).__init__(**params) self.auto_identity_insert = auto_identity_insert self.text_as_varchar = False - ansisql.ANSIDialect.__init__(self, **params) self.set_default_schema_name("dbo") def create_connect_args(self, url): @@ -422,13 +342,13 @@ class MSSQLDialect(ansisql.ANSIDialect): self.query_timeout = int(opts.pop('query_timeout')) if opts.has_key('text_as_varchar'): self.text_as_varchar = bool(opts.pop('text_as_varchar')) - return make_connect_string(opts) + return self.make_connect_string(opts) def create_execution_context(self): return MSSQLExecutionContext(self) def type_descriptor(self, typeobj): - newobj = sqltypes.adapt_type(typeobj, colspecs) + newobj = sqltypes.adapt_type(typeobj, self.colspecs) # Some types need to know about the dialect if isinstance(newobj, (MSText, MSNVarchar)): newobj.dialect = self @@ -437,8 +357,9 @@ class MSSQLDialect(ansisql.ANSIDialect): def last_inserted_ids(self): return self.context.last_inserted_ids + # this is only implemented in the dbapi-specific subclasses def supports_sane_rowcount(self): - return sane_rowcount + raise NotImplementedError() def compiler(self, statement, bindparams, **kwargs): return MSSQLCompiler(self, statement, bindparams, **kwargs) @@ -556,7 +477,7 @@ class MSSQLDialect(ansisql.ANSIDialect): for a in (charlen, numericprec, numericscale): if a is not None: args.append(a) - coltype = ischema_names[type] + coltype = self.ischema_names[type] if coltype == MSString and charlen == -1: coltype = MSText() else: @@ -628,7 +549,17 @@ class MSSQLDialect(ansisql.ANSIDialect): if fknm and scols: table.append_constraint(schema.ForeignKeyConstraint(scols, ['%s.%s' % (t,c) for (s,t,c) in rcols], fknm)) -class PyMSSQLDialect(MSSQLDialect): +class MSSQLDialect_pymssql(MSSQLDialect): + try: + import pymssql as module + # pymmsql doesn't have a Binary method. we use string + module.Binary = lambda st: str(st) + except ImportError, e: + saved_import_error = e + + def supports_sane_rowcount(self): + return True + def do_rollback(self, connection): # pymssql throws an error on repeated rollbacks. Ignore it. try: @@ -637,11 +568,19 @@ class PyMSSQLDialect(MSSQLDialect): pass def create_connect_args(self, url): - r = super(PyMSSQLDialect, self).create_connect_args(url) + r = super(MSSQLDialect_pymssql, self).create_connect_args(url) if hasattr(self, 'query_timeout'): - dbmodule._mssql.set_query_timeout(self.query_timeout) + self.module._mssql.set_query_timeout(self.query_timeout) return r + def make_connect_string(self, keys): + if keys.get('port'): + # pymssql expects port as host:port, not a separate arg + keys['host'] = ''.join([keys.get('host', ''), ':', str(keys['port'])]) + del keys['port'] + return [[], keys] + + ## This code is leftover from the initial implementation, for reference ## def do_begin(self, connection): ## """implementations might want to put logic here for turning autocommit on/off, etc.""" @@ -673,6 +612,68 @@ class PyMSSQLDialect(MSSQLDialect): ## r.query("begin tran") ## r.fetch_array() +class MSSQLDialect_pyodbc(MSSQLDialect): + try: + import pyodbc as module + except ImportError, e: + saved_import_error = e + + colspecs = MSSQLDialect.colspecs.copy() + colspecs[sqltypes.Unicode] = AdoMSUnicode + ischema_names = MSSQLDialect.ischema_names.copy() + ischema_names['nvarchar'] = AdoMSUnicode + + def supports_sane_rowcount(self): + return False + + def make_connect_string(self, keys): + connectors = ["Driver={SQL Server}"] + connectors.append("Server=%s" % keys.get("host")) + connectors.append("Database=%s" % keys.get("database")) + user = keys.get("user") + if user: + connectors.append("UID=%s" % user) + connectors.append("PWD=%s" % keys.get("password", "")) + else: + connectors.append ("TrustedConnection=Yes") + return [[";".join (connectors)], {}] + + +class MSSQLDialect_adodbapi(MSSQLDialect): + try: + import adodbapi as module + except ImportError, e: + saved_import_error = e + + colspecs = MSSQLDialect.colspecs.copy() + colspecs[sqltypes.Unicode] = AdoMSUnicode + ischema_names = MSSQLDialect.ischema_names.copy() + ischema_names['nvarchar'] = AdoMSUnicode + + def supports_sane_rowcount(self): + return True + + def make_connect_string(self, keys): + connectors = ["Provider=SQLOLEDB"] + connectors.append ("Data Source=%s" % keys.get("host")) + connectors.append ("Initial Catalog=%s" % keys.get("database")) + user = keys.get("user") + if user: + connectors.append("User Id=%s" % user) + connectors.append("Password=%s" % keys.get("password", "")) + else: + connectors.append("Integrated Security=SSPI") + return [[";".join (connectors)], {}] + + +dialect_mapping = { + 'pymssql': MSSQLDialect_pymssql, + 'pyodbc': MSSQLDialect_pyodbc, + 'adodbapi': MSSQLDialect_adodbapi + } +dialect_preference = [MSSQLDialect_adodbapi, MSSQLDialect_pymssql, MSSQLDialect_pyodbc] + + class MSSQLCompiler(ansisql.ANSICompiler): def __init__(self, dialect, statement, parameters, **kwargs): super(MSSQLCompiler, self).__init__(dialect, statement, parameters, **kwargs) @@ -781,7 +782,5 @@ class MSSQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer): #TODO: determin MSSQL's case folding rules return value -use_default() - - +dialect = MSSQLDialect