From: Mike Bayer Date: Tue, 13 Jan 2009 22:15:29 +0000 (+0000) Subject: the most epic dialect of all. the MYSQL DIALECT. didn't port the dialect test over... X-Git-Tag: rel_0_6_6~343 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=5a9ff074cc00f3e1c02a1f5e82a2f6193ebc4996;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git the most epic dialect of all. the MYSQL DIALECT. didn't port the dialect test over yet. --- diff --git a/lib/sqlalchemy/databases/__init__.py b/lib/sqlalchemy/databases/__init__.py index 7f124d7dbd..a824cd87b2 100644 --- a/lib/sqlalchemy/databases/__init__.py +++ b/lib/sqlalchemy/databases/__init__.py @@ -11,7 +11,6 @@ __all__ = ( 'informix', 'maxdb', 'mssql', - 'mysql', 'oracle', 'sybase', ) diff --git a/lib/sqlalchemy/dialects/__init__.py b/lib/sqlalchemy/dialects/__init__.py index 075e897fa8..33e481d25c 100644 --- a/lib/sqlalchemy/dialects/__init__.py +++ b/lib/sqlalchemy/dialects/__init__.py @@ -4,7 +4,7 @@ __all__ = ( # 'informix', # 'maxdb', # 'mssql', -# 'mysql', + 'mysql', # 'oracle', 'postgres', 'sqlite', diff --git a/lib/sqlalchemy/dialects/mysql/__init__.py b/lib/sqlalchemy/dialects/mysql/__init__.py new file mode 100644 index 0000000000..e94acc64e3 --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/__init__.py @@ -0,0 +1,4 @@ +from sqlalchemy.dialects.mysql import base, mysqldb + +# default dialect +base.dialect = mysqldb.dialect \ No newline at end of file diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/dialects/mysql/base.py similarity index 87% rename from lib/sqlalchemy/databases/mysql.py rename to lib/sqlalchemy/dialects/mysql/base.py index ac4e64b597..9c2cf0352a 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -19,7 +19,7 @@ But if you would like to use one of the MySQL-specific or enhanced column types when creating tables with your :class:`~sqlalchemy.Table` definitions, then you will need to import them from this module:: - from sqlalchemy.databases import mysql + from sqlalchemy.dialect.mysql import base as mysql Table('mytable', metadata, Column('id', Integer, primary_key=True), @@ -64,25 +64,6 @@ Nested Transactions 5.0.3 See the official MySQL documentation for detailed information about features supported in any given server release. -Character Sets --------------- - -Many MySQL server installations default to a ``latin1`` encoding for client -connections. All data sent through the connection will be converted into -``latin1``, even if you have ``utf8`` or another character set on your tables -and columns. With versions 4.1 and higher, you can change the connection -character set either through server configuration or by including the -``charset`` parameter in the URL used for ``create_engine``. The ``charset`` -option is passed through to MySQL-Python and has the side-effect of also -enabling ``use_unicode`` in the driver by default. For regular encoded -strings, also pass ``use_unicode=0`` in the connection arguments:: - - # set client encoding to utf8; all strings come back as unicode - create_engine('mysql:///mydb?charset=utf8') - - # set client encoding to utf8; all strings come back as utf8 str - create_engine('mysql:///mydb?charset=utf8&use_unicode=0') - Storage Engines --------------- @@ -197,7 +178,6 @@ timely information affecting MySQL in SQLAlchemy. """ import datetime, decimal, inspect, re, sys -from array import array as _array from sqlalchemy import exc, log, schema, sql, util from sqlalchemy.sql import operators as sql_operators @@ -275,15 +255,6 @@ class _NumericType(object): self.unsigned = kw.pop('unsigned', False) self.zerofill = kw.pop('zerofill', False) - def _extend(self, spec): - "Extend a numeric-type declaration with MySQL specific extensions." - - if self.unsigned: - spec += ' UNSIGNED' - if self.zerofill: - spec += ' ZEROFILL' - return spec - class _StringType(object): """Base for MySQL string types.""" @@ -299,34 +270,6 @@ class _StringType(object): self.binary = binary self.national = national - def _extend(self, spec): - """Extend a string-type declaration with standard SQL CHARACTER SET / - COLLATE annotations and MySQL specific extensions. - """ - - if self.charset: - charset = 'CHARACTER SET %s' % self.charset - elif self.ascii: - charset = 'ASCII' - elif self.unicode: - charset = 'UNICODE' - else: - charset = None - - if self.collation: - collation = 'COLLATE %s' % self.collation - elif self.binary: - collation = 'BINARY' - else: - collation = None - - if self.national: - # NATIONAL (aka NCHAR/NVARCHAR) trumps charsets. - return ' '.join([c for c in ('NATIONAL', spec, collation) - if c is not None]) - return ' '.join([c for c in (spec, charset, collation) - if c is not None]) - def __repr__(self): attributes = inspect.getargspec(self.__init__)[0][1:] attributes.extend(inspect.getargspec(_StringType.__init__)[0][1:]) @@ -343,7 +286,9 @@ class _StringType(object): class MSNumeric(sqltypes.Numeric, _NumericType): """MySQL NUMERIC type.""" - + + __visit_name__ = 'NUMERIC' + def __init__(self, precision=10, scale=2, asdecimal=True, **kw): """Construct a NUMERIC. @@ -363,12 +308,6 @@ class MSNumeric(sqltypes.Numeric, _NumericType): _NumericType.__init__(self, kw) sqltypes.Numeric.__init__(self, precision, scale, asdecimal=asdecimal, **kw) - def get_col_spec(self): - if self.precision is None: - return self._extend("NUMERIC") - else: - return self._extend("NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale}) - def bind_processor(self, dialect): return None @@ -386,7 +325,9 @@ class MSNumeric(sqltypes.Numeric, _NumericType): class MSDecimal(MSNumeric): """MySQL DECIMAL type.""" - + + __visit_name__ = 'DECIMAL' + def __init__(self, precision=10, scale=2, asdecimal=True, **kw): """Construct a DECIMAL. @@ -405,18 +346,12 @@ class MSDecimal(MSNumeric): """ super(MSDecimal, self).__init__(precision, scale, asdecimal=asdecimal, **kw) - def get_col_spec(self): - if self.precision is None: - return self._extend("DECIMAL") - elif self.scale is None: - return self._extend("DECIMAL(%(precision)s)" % {'precision': self.precision}) - else: - return self._extend("DECIMAL(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale}) - class MSDouble(sqltypes.Float, _NumericType): """MySQL DOUBLE type.""" + __visit_name__ = 'DOUBLE' + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): """Construct a DOUBLE. @@ -444,18 +379,12 @@ class MSDouble(sqltypes.Float, _NumericType): self.scale = scale self.precision = precision - def get_col_spec(self): - if self.precision is not None and self.scale is not None: - return self._extend("DOUBLE(%(precision)s, %(scale)s)" % - {'precision': self.precision, - 'scale' : self.scale}) - else: - return self._extend('DOUBLE') - class MSReal(MSDouble): """MySQL REAL type.""" + __visit_name__ = 'REAL' + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): """Construct a REAL. @@ -474,18 +403,12 @@ class MSReal(MSDouble): """ MSDouble.__init__(self, precision, scale, asdecimal, **kw) - def get_col_spec(self): - if self.precision is not None and self.scale is not None: - return self._extend("REAL(%(precision)s, %(scale)s)" % - {'precision': self.precision, - 'scale' : self.scale}) - else: - return self._extend('REAL') - class MSFloat(sqltypes.Float, _NumericType): """MySQL FLOAT type.""" + __visit_name__ = 'FLOAT' + def __init__(self, precision=None, scale=None, asdecimal=False, **kw): """Construct a FLOAT. @@ -507,14 +430,6 @@ class MSFloat(sqltypes.Float, _NumericType): self.scale = scale self.precision = precision - def get_col_spec(self): - if self.scale is not None and self.precision is not None: - return self._extend("FLOAT(%s, %s)" % (self.precision, self.scale)) - elif self.precision is not None: - return self._extend("FLOAT(%s)" % (self.precision,)) - else: - return self._extend("FLOAT") - def bind_processor(self, dialect): return None @@ -522,6 +437,8 @@ class MSFloat(sqltypes.Float, _NumericType): class MSInteger(sqltypes.Integer, _NumericType): """MySQL INTEGER type.""" + __visit_name__ = 'INTEGER' + def __init__(self, display_width=None, **kw): """Construct an INTEGER. @@ -543,16 +460,12 @@ class MSInteger(sqltypes.Integer, _NumericType): _NumericType.__init__(self, kw) sqltypes.Integer.__init__(self, **kw) - def get_col_spec(self): - if self.display_width is not None: - return self._extend("INTEGER(%(display_width)s)" % {'display_width': self.display_width}) - else: - return self._extend("INTEGER") - class MSBigInteger(MSInteger): """MySQL BIGINTEGER type.""" + __visit_name__ = 'BIGINT' + def __init__(self, display_width=None, **kw): """Construct a BIGINTEGER. @@ -568,16 +481,12 @@ class MSBigInteger(MSInteger): """ super(MSBigInteger, self).__init__(display_width, **kw) - def get_col_spec(self): - if self.display_width is not None: - return self._extend("BIGINT(%(display_width)s)" % {'display_width': self.display_width}) - else: - return self._extend("BIGINT") - class MSMediumInteger(MSInteger): """MySQL MEDIUMINTEGER type.""" + __visit_name__ = 'MEDIUMINT' + def __init__(self, display_width=None, **kw): """Construct a MEDIUMINTEGER @@ -593,17 +502,12 @@ class MSMediumInteger(MSInteger): """ super(MSMediumInteger, self).__init__(display_width, **kw) - def get_col_spec(self): - if self.display_width is not None: - return self._extend("MEDIUMINT(%(display_width)s)" % {'display_width': self.display_width}) - else: - return self._extend("MEDIUMINT") - - class MSTinyInteger(MSInteger): """MySQL TINYINT type.""" + __visit_name__ = 'TINYINT' + def __init__(self, display_width=None, **kw): """Construct a TINYINT. @@ -623,16 +527,12 @@ class MSTinyInteger(MSInteger): """ super(MSTinyInteger, self).__init__(display_width, **kw) - def get_col_spec(self): - if self.display_width is not None: - return self._extend("TINYINT(%s)" % self.display_width) - else: - return self._extend("TINYINT") - class MSSmallInteger(sqltypes.SmallInteger, MSInteger): """MySQL SMALLINTEGER type.""" + __visit_name__ = 'SMALLINT' + def __init__(self, display_width=None, **kw): """Construct a SMALLINTEGER. @@ -650,12 +550,6 @@ class MSSmallInteger(sqltypes.SmallInteger, MSInteger): _NumericType.__init__(self, kw) sqltypes.SmallInteger.__init__(self, **kw) - def get_col_spec(self): - if self.display_width is not None: - return self._extend("SMALLINT(%(display_width)s)" % {'display_width': self.display_width}) - else: - return self._extend("SMALLINT") - class MSBit(sqltypes.TypeEngine): """MySQL BIT type. @@ -666,6 +560,8 @@ class MSBit(sqltypes.TypeEngine): """ + __visit_name__ = 'BIT' + def __init__(self, length=None): """Construct a BIT. @@ -685,32 +581,24 @@ class MSBit(sqltypes.TypeEngine): return value return process - def get_col_spec(self): - if self.length is not None: - return "BIT(%s)" % self.length - else: - return "BIT" - +# TODO: probably don't need datetime/date types since no behavior changes class MSDateTime(sqltypes.DateTime): """MySQL DATETIME type.""" - - def get_col_spec(self): - return "DATETIME" + + __visit_name__ = 'DATETIME' class MSDate(sqltypes.Date): """MySQL DATE type.""" + __visit_name__ = 'DATE' - def get_col_spec(self): - return "DATE" class MSTime(sqltypes.Time): """MySQL TIME type.""" - def get_col_spec(self): - return "TIME" + __visit_name__ = 'TIME' def result_processor(self, dialect): def process(value): @@ -739,26 +627,23 @@ class MSTimeStamp(sqltypes.TIMESTAMP): server_default=sql.text('CURRENT TIMESTAMP ON UPDATE CURRENT_TIMESTAMP') """ - - def get_col_spec(self): - return "TIMESTAMP" + __visit_name__ = 'TIMESTAMP' class MSYear(sqltypes.TypeEngine): """MySQL YEAR type, for single byte storage of years 1901-2155.""" + __visit_name__ = 'YEAR' + def __init__(self, display_width=None): self.display_width = display_width - def get_col_spec(self): - if self.display_width is None: - return "YEAR" - else: - return "YEAR(%s)" % self.display_width class MSText(_StringType, sqltypes.Text): """MySQL TEXT type, for text up to 2^16 characters.""" + __visit_name__ = 'TEXT' + def __init__(self, length=None, **kwargs): """Construct a TEXT. @@ -791,16 +676,12 @@ class MSText(_StringType, sqltypes.Text): sqltypes.Text.__init__(self, length, kwargs.get('convert_unicode', False), kwargs.get('assert_unicode', None)) - def get_col_spec(self): - if self.length: - return self._extend("TEXT(%d)" % self.length) - else: - return self._extend("TEXT") - class MSTinyText(MSText): """MySQL TINYTEXT type, for text up to 2^8 characters.""" + __visit_name__ = 'TINYTEXT' + def __init__(self, **kwargs): """Construct a TINYTEXT. @@ -828,13 +709,12 @@ class MSTinyText(MSText): super(MSTinyText, self).__init__(**kwargs) - def get_col_spec(self): - return self._extend("TINYTEXT") - class MSMediumText(MSText): """MySQL MEDIUMTEXT type, for text up to 2^24 characters.""" + __visit_name__ = 'MEDIUMTEXT' + def __init__(self, **kwargs): """Construct a MEDIUMTEXT. @@ -861,13 +741,11 @@ class MSMediumText(MSText): """ super(MSMediumText, self).__init__(**kwargs) - def get_col_spec(self): - return self._extend("MEDIUMTEXT") - - class MSLongText(MSText): """MySQL LONGTEXT type, for text up to 2^32 characters.""" + __visit_name__ = 'LONGTEXT' + def __init__(self, **kwargs): """Construct a LONGTEXT. @@ -894,13 +772,13 @@ class MSLongText(MSText): """ super(MSLongText, self).__init__(**kwargs) - def get_col_spec(self): - return self._extend("LONGTEXT") class MSString(_StringType, sqltypes.String): """MySQL VARCHAR type, for variable-length character data.""" + __visit_name__ = 'VARCHAR' + def __init__(self, length=None, **kwargs): """Construct a VARCHAR. @@ -929,18 +807,14 @@ class MSString(_StringType, sqltypes.String): sqltypes.String.__init__(self, length, kwargs.get('convert_unicode', False), kwargs.get('assert_unicode', None)) - def get_col_spec(self): - if self.length: - return self._extend("VARCHAR(%d)" % self.length) - else: - return self._extend("VARCHAR") - class MSChar(_StringType, sqltypes.CHAR): """MySQL CHAR type, for fixed-length character data.""" + __visit_name__ = 'CHAR' + def __init__(self, length, **kwargs): - """Construct an NCHAR. + """Construct a CHAR. :param length: Maximum data length, in characters. @@ -956,8 +830,6 @@ class MSChar(_StringType, sqltypes.CHAR): sqltypes.CHAR.__init__(self, length, kwargs.get('convert_unicode', False)) - def get_col_spec(self): - return self._extend("CHAR(%(length)s)" % {'length' : self.length}) class MSNVarChar(_StringType, sqltypes.String): @@ -967,6 +839,8 @@ class MSNVarChar(_StringType, sqltypes.String): character set. """ + __visit_name__ = 'NVARCHAR' + def __init__(self, length=None, **kwargs): """Construct an NVARCHAR. @@ -985,10 +859,6 @@ class MSNVarChar(_StringType, sqltypes.String): sqltypes.String.__init__(self, length, kwargs.get('convert_unicode', False)) - def get_col_spec(self): - # We'll actually generate the equiv. "NATIONAL VARCHAR" instead - # of "NVARCHAR". - return self._extend("VARCHAR(%(length)s)" % {'length': self.length}) class MSNChar(_StringType, sqltypes.CHAR): @@ -998,6 +868,8 @@ class MSNChar(_StringType, sqltypes.CHAR): character set. """ + __visit_name__ = 'NCHAR' + def __init__(self, length=None, **kwargs): """Construct an NCHAR. Arguments are: @@ -1015,20 +887,11 @@ class MSNChar(_StringType, sqltypes.CHAR): _StringType.__init__(self, **kwargs) sqltypes.CHAR.__init__(self, length, kwargs.get('convert_unicode', False)) - def get_col_spec(self): - # We'll actually generate the equiv. "NATIONAL CHAR" instead of "NCHAR". - return self._extend("CHAR(%(length)s)" % {'length': self.length}) class _BinaryType(sqltypes.Binary): """Base for MySQL binary types.""" - def get_col_spec(self): - if self.length: - return "BLOB(%d)" % self.length - else: - return "BLOB" - def result_processor(self, dialect): def process(value): if value is None: @@ -1040,6 +903,8 @@ class _BinaryType(sqltypes.Binary): class MSVarBinary(_BinaryType): """MySQL VARBINARY type, for variable length binary data.""" + __visit_name__ = 'VARBINARY' + def __init__(self, length=None, **kw): """Construct a VARBINARY. Arguments are: @@ -1048,16 +913,12 @@ class MSVarBinary(_BinaryType): """ super(MSVarBinary, self).__init__(length, **kw) - def get_col_spec(self): - if self.length: - return "VARBINARY(%d)" % self.length - else: - return "BLOB" - class MSBinary(_BinaryType): """MySQL BINARY type, for fixed length binary data""" + __visit_name__ = 'BINARY' + def __init__(self, length=None, **kw): """Construct a BINARY. @@ -1070,12 +931,6 @@ class MSBinary(_BinaryType): """ super(MSBinary, self).__init__(length, **kw) - def get_col_spec(self): - if self.length: - return "BINARY(%d)" % self.length - else: - return "BLOB" - def result_processor(self, dialect): def process(value): if value is None: @@ -1087,6 +942,8 @@ class MSBinary(_BinaryType): class MSBlob(_BinaryType): """MySQL BLOB type, for binary data up to 2^16 bytes""" + __visit_name__ = 'BLOB' + def __init__(self, length=None, **kw): """Construct a BLOB. Arguments are: @@ -1097,12 +954,6 @@ class MSBlob(_BinaryType): """ super(MSBlob, self).__init__(length, **kw) - def get_col_spec(self): - if self.length: - return "BLOB(%d)" % self.length - else: - return "BLOB" - def result_processor(self, dialect): def process(value): if value is None: @@ -1117,28 +968,27 @@ class MSBlob(_BinaryType): class MSTinyBlob(MSBlob): """MySQL TINYBLOB type, for binary data up to 2^8 bytes.""" - - def get_col_spec(self): - return "TINYBLOB" + + __visit_name__ = 'TINYBLOB' class MSMediumBlob(MSBlob): """MySQL MEDIUMBLOB type, for binary data up to 2^24 bytes.""" - def get_col_spec(self): - return "MEDIUMBLOB" + __visit_name__ = 'MEDIUMBLOB' class MSLongBlob(MSBlob): """MySQL LONGBLOB type, for binary data up to 2^32 bytes.""" - def get_col_spec(self): - return "LONGBLOB" + __visit_name__ = 'LONGBLOB' class MSEnum(MSString): """MySQL ENUM type.""" + __visit_name__ = 'ENUM' + def __init__(self, *enums, **kw): """Construct an ENUM. @@ -1239,15 +1089,11 @@ class MSEnum(MSString): return value return process - def get_col_spec(self): - quoted_enums = [] - for e in self.enums: - quoted_enums.append("'%s'" % e.replace("'", "''")) - return self._extend("ENUM(%s)" % ",".join(quoted_enums)) - class MSSet(MSString): """MySQL SET type.""" + __visit_name__ = 'SET' + def __init__(self, *values, **kw): """Construct a SET. @@ -1332,15 +1178,11 @@ class MSSet(MSString): return value return process - def get_col_spec(self): - return self._extend("SET(%s)" % ",".join(self.__ddl_values)) - class MSBoolean(sqltypes.Boolean): """MySQL BOOLEAN type.""" - def get_col_spec(self): - return "BOOL" + __visit_name__ = 'BOOLEAN' def result_processor(self, dialect): def process(value): @@ -1420,13 +1262,12 @@ ischema_names = { 'year': MSYear, } - class MySQLExecutionContext(default.DefaultExecutionContext): def post_exec(self): - if self.compiled.isinsert and not self.executemany: + if self.isinsert and not self.executemany: if (not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None): - self._last_inserted_ids = ([self.cursor.lastrowid] + + self._last_inserted_ids = ([self._lastrowid(self.cursor)] + self._last_inserted_ids[1:]) elif (not self.isupdate and not self.should_autocommit and self.statement and SET_RE.match(self.statement)): @@ -1434,114 +1275,473 @@ class MySQLExecutionContext(default.DefaultExecutionContext): # which is probably a programming error anyhow. self.connection.info.pop(('mysql', 'charset'), None) + def _lastrowid(self, cursor): + raise NotImplementedError() + def should_autocommit_text(self, statement): return AUTOCOMMIT_RE.match(statement) +class MySQLCompiler(compiler.SQLCompiler): + operators = compiler.SQLCompiler.operators.copy() + operators.update({ + sql_operators.concat_op: lambda x, y: "concat(%s, %s)" % (x, y), + sql_operators.mod: '%%', + sql_operators.match_op: lambda x, y: "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % (x, y) + }) + functions = compiler.SQLCompiler.functions.copy() + functions.update ({ + sql_functions.random: 'rand%(expr)s', + "utc_timestamp":"UTC_TIMESTAMP" + }) -class MySQLDialect(default.DefaultDialect): - """Details of the MySQL dialect. Not used directly in application code.""" - name = 'mysql' - supports_alter = True - supports_unicode_statements = False - # identifiers are 64, however aliases can be 255... - max_identifier_length = 255 - supports_sane_rowcount = True - default_paramstyle = 'format' + def visit_typeclause(self, typeclause): + type_ = typeclause.type.dialect_impl(self.dialect) + if isinstance(type_, MSInteger): + if getattr(type_, 'unsigned', False): + return 'UNSIGNED INTEGER' + else: + return 'SIGNED INTEGER' + elif isinstance(type_, (MSDecimal, MSDateTime, MSDate, MSTime)): + return self.dialect.type_compiler.process(type_) + elif isinstance(type_, MSText): + return 'CHAR' + elif (isinstance(type_, _StringType) and not + isinstance(type_, (MSEnum, MSSet))): + if getattr(type_, 'length'): + return 'CHAR(%s)' % type_.length + else: + return 'CHAR' + elif isinstance(type_, _BinaryType): + return 'BINARY' + elif isinstance(type_, MSNumeric): + return self.dialect.type_compiler.process(type_).replace('NUMERIC', 'DECIMAL') + elif isinstance(type_, MSTimeStamp): + return 'DATETIME' + elif isinstance(type_, (MSDateTime, MSDate, MSTime)): + return self.dialect.type_compiler.process(type_) + else: + return None - def __init__(self, use_ansiquotes=None, **kwargs): - self.use_ansiquotes = use_ansiquotes - default.DefaultDialect.__init__(self, **kwargs) + def visit_cast(self, cast, **kwargs): + # No cast until 4, no decimals until 5. + type_ = self.process(cast.typeclause) + if type_ is None: + return self.process(cast.clause) - def dbapi(cls): - import MySQLdb as mysql - return mysql - dbapi = classmethod(dbapi) - - def create_connect_args(self, url): - opts = url.translate_connect_args(database='db', username='user', - password='passwd') - opts.update(url.query) - - util.coerce_kw_type(opts, 'compress', bool) - util.coerce_kw_type(opts, 'connect_timeout', int) - util.coerce_kw_type(opts, 'client_flag', int) - util.coerce_kw_type(opts, 'local_infile', int) - # Note: using either of the below will cause all strings to be returned - # as Unicode, both in raw SQL operations and with column types like - # String and MSString. - util.coerce_kw_type(opts, 'use_unicode', bool) - util.coerce_kw_type(opts, 'charset', str) - - # Rich values 'cursorclass' and 'conv' are not supported via - # query string. - - ssl = {} - for key in ['ssl_ca', 'ssl_key', 'ssl_cert', 'ssl_capath', 'ssl_cipher']: - if key in opts: - ssl[key[4:]] = opts[key] - util.coerce_kw_type(ssl, key[4:], str) - del opts[key] - if ssl: - opts['ssl'] = ssl - - # FOUND_ROWS must be set in CLIENT_FLAGS to enable - # supports_sane_rowcount. - client_flag = opts.get('client_flag', 0) - if self.dbapi is not None: - try: - import MySQLdb.constants.CLIENT as CLIENT_FLAGS - client_flag |= CLIENT_FLAGS.FOUND_ROWS - except: - pass - opts['client_flag'] = client_flag - return [[], opts] + return 'CAST(%s AS %s)' % (self.process(cast.clause), type_) - def type_descriptor(self, typeobj): - return sqltypes.adapt_type(typeobj, colspecs) - def do_executemany(self, cursor, statement, parameters, context=None): - rowcount = cursor.executemany(statement, parameters) - if context is not None: - context._rowcount = rowcount + def post_process_text(self, text): + if '%%' in text: + util.warn("The SQLAlchemy MySQLDB dialect now automatically escapes '%' in text() expressions to '%%'.") + return text.replace('%', '%%') - def supports_unicode_statements(self): - return True + def get_select_precolumns(self, select): + if isinstance(select._distinct, basestring): + return select._distinct.upper() + " " + elif select._distinct: + return "DISTINCT " + else: + return "" - def do_commit(self, connection): - """Execute a COMMIT.""" + def visit_join(self, join, asfrom=False, **kwargs): + # 'JOIN ... ON ...' for inner joins isn't available until 4.0. + # Apparently < 3.23.17 requires theta joins for inner joins + # (but not outer). Not generating these currently, but + # support can be added, preferably after dialects are + # refactored to be version-sensitive. + return ''.join( + (self.process(join.left, asfrom=True), + (join.isouter and " LEFT OUTER JOIN " or " INNER JOIN "), + self.process(join.right, asfrom=True), + " ON ", + self.process(join.onclause))) - # COMMIT/ROLLBACK were introduced in 3.23.15. - # Yes, we have at least one user who has to talk to these old versions! - # - # Ignore commit/rollback if support isn't present, otherwise even basic - # operations via autocommit fail. - try: - connection.commit() - except: - if self._server_version_info(connection) < (3, 23, 15): - args = sys.exc_info()[1].args - if args and args[0] == 1064: - return - raise + def for_update_clause(self, select): + if select.for_update == 'read': + return ' LOCK IN SHARE MODE' + else: + return super(MySQLCompiler, self).for_update_clause(select) - def do_rollback(self, connection): - """Execute a ROLLBACK.""" + def limit_clause(self, select): + # MySQL supports: + # LIMIT + # LIMIT , + # and in server versions > 3.3: + # LIMIT OFFSET + # The latter is more readable for offsets but we're stuck with the + # former until we can refine dialects by server revision. - try: - connection.rollback() - except: - if self._server_version_info(connection) < (3, 23, 15): - args = sys.exc_info()[1].args - if args and args[0] == 1064: - return - raise + limit, offset = select._limit, select._offset - def do_begin_twophase(self, connection, xid): - connection.execute("XA BEGIN %s", xid) + if (limit, offset) == (None, None): + return '' + elif offset is not None: + # As suggested by the MySQL docs, need to apply an + # artificial limit if one wasn't provided + if limit is None: + limit = 18446744073709551615 + return ' \n LIMIT %s, %s' % (offset, limit) + else: + # No offset provided, so just use the limit + return ' \n LIMIT %s' % (limit,) - def do_prepare_twophase(self, connection, xid): - connection.execute("XA END %s", xid) - connection.execute("XA PREPARE %s", xid) + def visit_update(self, update_stmt): + self.stack.append({'from': set([update_stmt.table])}) + + self.isupdate = True + colparams = self._get_colparams(update_stmt) + + text = "UPDATE " + self.preparer.format_table(update_stmt.table) + \ + " SET " + ', '.join(["%s=%s" % (self.preparer.format_column(c[0]), c[1]) for c in colparams]) + + if update_stmt._whereclause: + text += " WHERE " + self.process(update_stmt._whereclause) + + limit = update_stmt.kwargs.get('mysql_limit', None) + if limit: + text += " LIMIT %s" % limit + + self.stack.pop(-1) + + return text + +# ug. "InnoDB needs indexes on foreign keys and referenced keys [...]. +# Starting with MySQL 4.1.2, these indexes are created automatically. +# In older versions, the indexes must be created explicitly or the +# creation of foreign key constraints fails." + +class MySQLDDLCompiler(compiler.DDLCompiler): + def get_column_specification(self, column, **kw): + """Builds column DDL.""" + + colspec = [self.preparer.format_column(column), + #self.dialect.type_compiler.process(column.type.dialect_impl(self.dialect)) + self.dialect.type_compiler.process(column.type) + ] + + default = self.get_column_default_string(column) + if default is not None: + colspec.append('DEFAULT ' + default) + + if not column.nullable: + colspec.append('NOT NULL') + + if column.primary_key and column.autoincrement: + try: + first = [c for c in column.table.primary_key.columns + if (c.autoincrement and + isinstance(c.type, sqltypes.Integer) and + not c.foreign_keys)].pop(0) + if column is first: + colspec.append('AUTO_INCREMENT') + except IndexError: + pass + + return ' '.join(colspec) + + def post_create_table(self, table): + """Build table-level CREATE options like ENGINE and COLLATE.""" + + table_opts = [] + for k in table.kwargs: + if k.startswith('mysql_'): + opt = k[6:].upper() + joiner = '=' + if opt in ('TABLESPACE', 'DEFAULT CHARACTER SET', + 'CHARACTER SET', 'COLLATE'): + joiner = ' ' + + table_opts.append(joiner.join((opt, table.kwargs[k]))) + return ' '.join(table_opts) + + def visit_drop_index(self, drop): + index = drop.element + + return "\nDROP INDEX %s ON %s" % \ + (self.preparer.quote(self._validate_identifier(index.name, False), index.quote), + self.preparer.format_table(index.table)) + + def visit_drop_foreignkey(self, drop): + constraint = drop.element + return "ALTER TABLE %s DROP FOREIGN KEY %s" % \ + (self.preparer.format_table(constraint.table), + self.preparer.format_constraint(constraint)) + +class MySQLTypeCompiler(compiler.GenericTypeCompiler): + def _extend_numeric(self, type_, spec): + "Extend a numeric-type declaration with MySQL specific extensions." + + if not self._mysql_type(type_): + return spec + + if type_.unsigned: + spec += ' UNSIGNED' + if type_.zerofill: + spec += ' ZEROFILL' + return spec + + def _extend_string(self, type_, spec): + """Extend a string-type declaration with standard SQL CHARACTER SET / + COLLATE annotations and MySQL specific extensions. + + """ + if not self._mysql_type(type_): + return spec + + if type_.charset: + charset = 'CHARACTER SET %s' % type_.charset + elif type_.ascii: + charset = 'ASCII' + elif type_.unicode: + charset = 'UNICODE' + else: + charset = None + + if type_.collation: + collation = 'COLLATE %s' % type_.collation + elif type_.binary: + collation = 'BINARY' + else: + collation = None + + if type_.national: + # NATIONAL (aka NCHAR/NVARCHAR) trumps charsets. + return ' '.join([c for c in ('NATIONAL', spec, collation) + if c is not None]) + return ' '.join([c for c in (spec, charset, collation) + if c is not None]) + + def _mysql_type(self, type_): + return isinstance(type_, (_StringType, _NumericType, _BinaryType)) + + def visit_NUMERIC(self, type_): + if type_.precision is None: + return self._extend_numeric(type_, "NUMERIC") + else: + return self._extend_numeric(type_, "NUMERIC(%(precision)s, %(scale)s)" % {'precision': type_.precision, 'scale' : type_.scale}) + + def visit_DECIMAL(self, type_): + if type_.precision is None: + return self._extend_numeric(type_, "DECIMAL") + elif type_.scale is None: + return self._extend_numeric(type_, "DECIMAL(%(precision)s)" % {'precision': type_.precision}) + else: + return self._extend_numeric(type_, "DECIMAL(%(precision)s, %(scale)s)" % {'precision': type_.precision, 'scale' : type_.scale}) + + def visit_DOUBLE(self, type_): + if type_.precision is not None and type_.scale is not None: + return self._extend_numeric(type_, "DOUBLE(%(precision)s, %(scale)s)" % + {'precision': type_.precision, + 'scale' : type_.scale}) + else: + return self._extend_numeric(type_, 'DOUBLE') + + def visit_REAL(self, type_): + if type_.precision is not None and type_.scale is not None: + return self._extend_numeric(type_, "REAL(%(precision)s, %(scale)s)" % + {'precision': type_.precision, + 'scale' : type_.scale}) + else: + return self._extend_numeric(type_, 'REAL') + + + def visit_FLOAT(self, type_): + if self._mysql_type(type_) and type_.scale is not None and type_.precision is not None: + return self._extend_numeric(type_, "FLOAT(%s, %s)" % (type_.precision, type_.scale)) + elif type_.precision is not None: + return self._extend_numeric(type_, "FLOAT(%s)" % (type_.precision,)) + else: + return self._extend_numeric(type_, "FLOAT") + + def visit_INTEGER(self, type_): + if self._mysql_type(type_) and type_.display_width is not None: + return self._extend_numeric(type_, "INTEGER(%(display_width)s)" % {'display_width': type_.display_width}) + else: + return self._extend_numeric(type_, "INTEGER") + + def visit_BIGINT(self, type_): + if self._mysql_type(type_) and type_.display_width is not None: + return self._extend_numeric(type_, "BIGINT(%(display_width)s)" % {'display_width': type_.display_width}) + else: + return self._extend_numeric(type_, "BIGINT") + + def visit_MEDIUMINT(self, type_): + if self._mysql_type(type_) and type_.display_width is not None: + return self._extend_numeric(type_, "MEDIUMINT(%(display_width)s)" % {'display_width': type_.display_width}) + else: + return self._extend_numeric(type_, "MEDIUMINT") + + def visit_TINYINT(self, type_): + if self._mysql_type(type_) and type_.display_width is not None: + return self._extend_numeric(type_, "TINYINT(%s)" % type_.display_width) + else: + return self._extend_numeric(type_, "TINYINT") + + def visit_SMALLINT(self, type_): + if self._mysql_type(type_) and type_.display_width is not None: + return self._extend_numeric(type_, "SMALLINT(%(display_width)s)" % {'display_width': type_.display_width}) + else: + return self._extend_numeric(type_, "SMALLINT") + + def visit_BIT(self, type_): + if type_.length is not None: + return "BIT(%s)" % type_.length + else: + return "BIT" + + def visit_DATETIME(self, type_): + return "DATETIME" + + def visit_DATE(self, type_): + return "DATE" + + def visit_TIME(self, type_): + return "TIME" + + def visit_TIMESTAMP(self, type_): + return 'TIMESTAMP' + + def visit_YEAR(self, type_): + if type_.display_width is None: + return "YEAR" + else: + return "YEAR(%s)" % type_.display_width + + def visit_TEXT(self, type_): + if type_.length: + return self._extend_string(type_, "TEXT(%d)" % type_.length) + else: + return self._extend_string(type_, "TEXT") + + def visit_TINYTEXT(self, type_): + return self._extend_string(type_, "TINYTEXT") + + def visit_MEDIUMTEXT(self, type_): + return self._extend_string(type_, "MEDIUMTEXT") + + def visit_LONGTEXT(self, type_): + return self._extend_string(type_, "LONGTEXT") + + def visit_VARCHAR(self, type_): + if type_.length: + return self._extend_string(type_, "VARCHAR(%d)" % type_.length) + else: + return self._extend_string(type_, "VARCHAR") + + def visit_CHAR(self, type_): + return self._extend_string(type_, "CHAR(%(length)s)" % {'length' : type_.length}) + + def visit_NVARCHAR(self, type_): + # We'll actually generate the equiv. "NATIONAL VARCHAR" instead + # of "NVARCHAR". + return self._extend_string(type_, "VARCHAR(%(length)s)" % {'length': type_.length}) + + def visit_NCHAR(self, type_): + # We'll actually generate the equiv. "NATIONAL CHAR" instead of "NCHAR". + return self._extend_string(type_, "CHAR(%(length)s)" % {'length': type_.length}) + + def visit_VARBINARY(self, type_): + if type_.length: + return "VARBINARY(%d)" % type_.length + else: + return self.visit_BLOB(type_) + + def visit_BINARY(self, type_): + if type_.length: + return "BINARY(%d)" % type_.length + else: + return self.visit_BLOB(type_) + + def visit_BLOB(self, type_): + if type_.length: + return "BLOB(%d)" % type_.length + else: + return "BLOB" + + def visit_TINYBLOB(self, type_): + return "TINYBLOB" + + def visit_MEDIUMBLOB(self, type_): + return "MEDIUMBLOB" + + def visit_LONGBLOB(self, type_): + return "LONGBLOB" + + def visit_ENUM(self, type_): + quoted_enums = [] + for e in type_.enums: + quoted_enums.append("'%s'" % e.replace("'", "''")) + return self._extend_string(type_, "ENUM(%s)" % ",".join(quoted_enums)) + + def visit_SET(self, type_): + return self._extend_string("SET(%s)" % ",".join(type_._ddl_values)) + + def visit_BOOL(self, type): + return "BOOL" + + +class MySQLDialect(default.DefaultDialect): + """Details of the MySQL dialect. Not used directly in application code.""" + name = 'mysql' + supports_alter = True + # identifiers are 64, however aliases can be 255... + max_identifier_length = 255 + supports_sane_rowcount = True + default_paramstyle = 'format' + + statement_compiler = MySQLCompiler + ddl_compiler = MySQLDDLCompiler + type_compiler = MySQLTypeCompiler + ischema_names = ischema_names + + def __init__(self, use_ansiquotes=None, **kwargs): + self.use_ansiquotes = use_ansiquotes + default.DefaultDialect.__init__(self, **kwargs) + + def type_descriptor(self, typeobj): + return sqltypes.adapt_type(typeobj, colspecs) + + def do_executemany(self, cursor, statement, parameters, context=None): + rowcount = cursor.executemany(statement, parameters) + if context is not None: + context._rowcount = rowcount + + def do_commit(self, connection): + """Execute a COMMIT.""" + + # COMMIT/ROLLBACK were introduced in 3.23.15. + # Yes, we have at least one user who has to talk to these old versions! + # + # Ignore commit/rollback if support isn't present, otherwise even basic + # operations via autocommit fail. + try: + connection.commit() + except: + if self._server_version_info(connection) < (3, 23, 15): + args = sys.exc_info()[1].args + if args and args[0] == 1064: + return + raise + + def do_rollback(self, connection): + """Execute a ROLLBACK.""" + + try: + connection.rollback() + except: + if self._server_version_info(connection) < (3, 23, 15): + args = sys.exc_info()[1].args + if args and args[0] == 1064: + return + raise + + def do_begin_twophase(self, connection, xid): + connection.execute("XA BEGIN %s", xid) + + def do_prepare_twophase(self, connection, xid): + connection.execute("XA END %s", xid) + connection.execute("XA PREPARE %s", xid) def do_rollback_twophase(self, connection, xid, is_prepared=True, recover=False): @@ -1559,9 +1759,6 @@ class MySQLDialect(default.DefaultDialect): resultset = connection.execute("XA RECOVER") return [row['data'][0:row['gtrid_length']] for row in resultset] - def do_ping(self, connection): - connection.ping() - def is_disconnect(self, e): if isinstance(e, self.dbapi.OperationalError): return e.args[0] in (2006, 2013, 2014, 2045, 2055) @@ -1570,6 +1767,12 @@ class MySQLDialect(default.DefaultDialect): else: return False + def _compat_fetchall(self, rp, charset=None): + return rp.fetchall() + + def _compat_fetchone(self, rp, charset=None): + return rp.fetchone() + def get_default_schema_name(self, connection): return connection.execute('SELECT DATABASE()').scalar() get_default_schema_name = engine_base.connection_memoize( @@ -1582,7 +1785,7 @@ class MySQLDialect(default.DefaultDialect): self._autoset_identifier_style(connection) rp = connection.execute("SHOW TABLES FROM %s" % self.identifier_preparer.quote_identifier(schema)) - return [row[0] for row in _compat_fetchall(rp, charset=charset)] + return [row[0] for row in self._compat_fetchall(rp, charset=charset)] def has_table(self, connection, table_name, schema=None): # SHOW TABLE STATUS LIKE and SHOW TABLES LIKE do not function properly @@ -1632,18 +1835,6 @@ class MySQLDialect(default.DefaultDialect): server_version_info = engine_base.connection_memoize( ('mysql', 'server_version_info'))(server_version_info) - def _server_version_info(self, dbapi_con): - """Convert a MySQL-python server_info string into a tuple.""" - - version = [] - r = re.compile('[.\-]') - for n in r.split(dbapi_con.get_server_info()): - try: - version.append(int(n)) - except ValueError: - version.append(n) - return tuple(version) - def reflecttable(self, connection, table, include_columns): """Load column definitions from the server.""" @@ -1659,7 +1850,7 @@ class MySQLDialect(default.DefaultDialect): # ANSI_QUOTES doesn't affect SHOW CREATE TABLE on < 4.1 preparer = MySQLIdentifierPreparer(self) - self.reflector = reflector = MySQLSchemaReflector(preparer) + self.reflector = reflector = MySQLSchemaReflector(self) sql = self._show_create_table(connection, table, charset) if sql.startswith('CREATE ALGORITHM'): @@ -1683,7 +1874,6 @@ class MySQLDialect(default.DefaultDialect): lc_alias = schema._get_table_key(table.name, table.schema) table.metadata.tables[lc_alias] = table - def _detect_charset(self, connection): """Sniff out the character set in use for connection results.""" @@ -1691,16 +1881,6 @@ class MySQLDialect(default.DefaultDialect): if ('mysql', 'force_charset') in connection.info: return connection.info[('mysql', 'force_charset')] - # Note: MySQL-python 1.2.1c7 seems to ignore changes made - # on a connection via set_character_set() - if self.server_version_info(connection) < (4, 1, 0): - try: - return connection.connection.character_set_name() - except AttributeError: - # < 1.2.1 final MySQL-python drivers have no charset support. - # a query is needed. - pass - # Prefer 'character_set_results' for the current connection over the # value in the driver. SET NAMES or individual variable SETs will # change the charset without updating the driver's view of the world. @@ -1708,22 +1888,17 @@ class MySQLDialect(default.DefaultDialect): # If it's decided that issuing that sort of SQL leaves you SOL, then # this can prefer the driver value. rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'") - opts = dict([(row[0], row[1]) for row in _compat_fetchall(rs)]) + opts = dict([(row[0], row[1]) for row in self._compat_fetchall(rs)]) if 'character_set_results' in opts: return opts['character_set_results'] - try: - return connection.connection.character_set_name() - except AttributeError: - # Still no charset on < 1.2.1 final... - if 'character_set' in opts: - return opts['character_set'] - else: - util.warn( - "Could not detect the connection character set with this " - "combination of MySQL server and MySQL-python. " - "MySQL-python >= 1.2.2 is recommended. Assuming latin1.") - return 'latin1' + # Still no charset on < 1.2.1 final... + if 'character_set' in opts: + return opts['character_set'] + else: + util.warn( + "Could not detect the connection character set. Assuming latin1.") + return 'latin1' _detect_charset = engine_base.connection_memoize( ('mysql', 'charset'))(_detect_charset) @@ -1738,7 +1913,7 @@ class MySQLDialect(default.DefaultDialect): # http://dev.mysql.com/doc/refman/5.0/en/name-case-sensitivity.html charset = self._detect_charset(connection) - row = _compat_fetchone(connection.execute( + row = self._compat_fetchone(connection.execute( "SHOW VARIABLES LIKE 'lower_case_table_names'"), charset=charset) if not row: @@ -1769,7 +1944,7 @@ class MySQLDialect(default.DefaultDialect): else: charset = self._detect_charset(connection) rs = connection.execute('SHOW COLLATION') - for row in _compat_fetchall(rs, charset): + for row in self._compat_fetchall(rs, charset): collations[row[0]] = row[1] return collations _detect_collations = engine_base.connection_memoize( @@ -1804,7 +1979,7 @@ class MySQLDialect(default.DefaultDialect): if self.use_ansiquotes is not None: return - row = _compat_fetchone( + row = self._compat_fetchone( connection.execute("SHOW VARIABLES LIKE 'sql_mode'"), charset=charset) if not row: @@ -1835,7 +2010,7 @@ class MySQLDialect(default.DefaultDialect): raise exc.NoSuchTableError(full_name) else: raise - row = _compat_fetchone(rp, charset=charset) + row = self._compat_fetchone(rp, charset=charset) if not row: raise exc.NoSuchTableError(full_name) return row[1].strip() @@ -1862,237 +2037,17 @@ class MySQLDialect(default.DefaultDialect): raise exc.NoSuchTableError(full_name) else: raise - rows = _compat_fetchall(rp, charset=charset) + rows = self._compat_fetchall(rp, charset=charset) finally: if rp: rp.close() return rows -class _MySQLPythonRowProxy(object): - """Return consistent column values for all versions of MySQL-python. - - Smooth over data type issues (esp. with alpha driver versions) and - normalize strings as Unicode regardless of user-configured driver - encoding settings. - """ - - # Some MySQL-python versions can return some columns as - # sets.Set(['value']) (seriously) but thankfully that doesn't - # seem to come up in DDL queries. - - def __init__(self, rowproxy, charset): - self.rowproxy = rowproxy - self.charset = charset - def __getitem__(self, index): - item = self.rowproxy[index] - if isinstance(item, _array): - item = item.tostring() - if self.charset and isinstance(item, str): - return item.decode(self.charset) - else: - return item - def __getattr__(self, attr): - item = getattr(self.rowproxy, attr) - if isinstance(item, _array): - item = item.tostring() - if self.charset and isinstance(item, str): - return item.decode(self.charset) - else: - return item - - -class MySQLCompiler(compiler.SQLCompiler): - operators = compiler.SQLCompiler.operators.copy() - operators.update({ - sql_operators.concat_op: lambda x, y: "concat(%s, %s)" % (x, y), - sql_operators.mod: '%%', - sql_operators.match_op: lambda x, y: "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % (x, y) - }) - functions = compiler.SQLCompiler.functions.copy() - functions.update ({ - sql_functions.random: 'rand%(expr)s', - "utc_timestamp":"UTC_TIMESTAMP" - }) - - - def visit_typeclause(self, typeclause): - type_ = typeclause.type.dialect_impl(self.dialect) - if isinstance(type_, MSInteger): - if getattr(type_, 'unsigned', False): - return 'UNSIGNED INTEGER' - else: - return 'SIGNED INTEGER' - elif isinstance(type_, (MSDecimal, MSDateTime, MSDate, MSTime)): - return type_.get_col_spec() - elif isinstance(type_, MSText): - return 'CHAR' - elif (isinstance(type_, _StringType) and not - isinstance(type_, (MSEnum, MSSet))): - if getattr(type_, 'length'): - return 'CHAR(%s)' % type_.length - else: - return 'CHAR' - elif isinstance(type_, _BinaryType): - return 'BINARY' - elif isinstance(type_, MSNumeric): - return type_.get_col_spec().replace('NUMERIC', 'DECIMAL') - elif isinstance(type_, MSTimeStamp): - return 'DATETIME' - elif isinstance(type_, (MSDateTime, MSDate, MSTime)): - return type_.get_col_spec() - else: - return None - - def visit_cast(self, cast, **kwargs): - # No cast until 4, no decimals until 5. - type_ = self.process(cast.typeclause) - if type_ is None: - return self.process(cast.clause) - - return 'CAST(%s AS %s)' % (self.process(cast.clause), type_) - - - def post_process_text(self, text): - if '%%' in text: - util.warn("The SQLAlchemy MySQLDB dialect now automatically escapes '%' in text() expressions to '%%'.") - return text.replace('%', '%%') - - def get_select_precolumns(self, select): - if isinstance(select._distinct, basestring): - return select._distinct.upper() + " " - elif select._distinct: - return "DISTINCT " - else: - return "" - - def visit_join(self, join, asfrom=False, **kwargs): - # 'JOIN ... ON ...' for inner joins isn't available until 4.0. - # Apparently < 3.23.17 requires theta joins for inner joins - # (but not outer). Not generating these currently, but - # support can be added, preferably after dialects are - # refactored to be version-sensitive. - return ''.join( - (self.process(join.left, asfrom=True), - (join.isouter and " LEFT OUTER JOIN " or " INNER JOIN "), - self.process(join.right, asfrom=True), - " ON ", - self.process(join.onclause))) - - def for_update_clause(self, select): - if select.for_update == 'read': - return ' LOCK IN SHARE MODE' - else: - return super(MySQLCompiler, self).for_update_clause(select) - - def limit_clause(self, select): - # MySQL supports: - # LIMIT - # LIMIT , - # and in server versions > 3.3: - # LIMIT OFFSET - # The latter is more readable for offsets but we're stuck with the - # former until we can refine dialects by server revision. - - limit, offset = select._limit, select._offset - - if (limit, offset) == (None, None): - return '' - elif offset is not None: - # As suggested by the MySQL docs, need to apply an - # artificial limit if one wasn't provided - if limit is None: - limit = 18446744073709551615 - return ' \n LIMIT %s, %s' % (offset, limit) - else: - # No offset provided, so just use the limit - return ' \n LIMIT %s' % (limit,) - - def visit_update(self, update_stmt): - self.stack.append({'from': set([update_stmt.table])}) - - self.isupdate = True - colparams = self._get_colparams(update_stmt) - - text = "UPDATE " + self.preparer.format_table(update_stmt.table) + \ - " SET " + ', '.join(["%s=%s" % (self.preparer.format_column(c[0]), c[1]) for c in colparams]) - - if update_stmt._whereclause: - text += " WHERE " + self.process(update_stmt._whereclause) - - limit = update_stmt.kwargs.get('mysql_limit', None) - if limit: - text += " LIMIT %s" % limit - - self.stack.pop(-1) - - return text - -# ug. "InnoDB needs indexes on foreign keys and referenced keys [...]. -# Starting with MySQL 4.1.2, these indexes are created automatically. -# In older versions, the indexes must be created explicitly or the -# creation of foreign key constraints fails." - -class MySQLSchemaGenerator(compiler.SchemaGenerator): - def get_column_specification(self, column, first_pk=False): - """Builds column DDL.""" - - colspec = [self.preparer.format_column(column), - column.type.dialect_impl(self.dialect).get_col_spec()] - - default = self.get_column_default_string(column) - if default is not None: - colspec.append('DEFAULT ' + default) - - if not column.nullable: - colspec.append('NOT NULL') - - if column.primary_key and column.autoincrement: - try: - first = [c for c in column.table.primary_key.columns - if (c.autoincrement and - isinstance(c.type, sqltypes.Integer) and - not c.foreign_keys)].pop(0) - if column is first: - colspec.append('AUTO_INCREMENT') - except IndexError: - pass - - return ' '.join(colspec) - - def post_create_table(self, table): - """Build table-level CREATE options like ENGINE and COLLATE.""" - - table_opts = [] - for k in table.kwargs: - if k.startswith('mysql_'): - opt = k[6:].upper() - joiner = '=' - if opt in ('TABLESPACE', 'DEFAULT CHARACTER SET', - 'CHARACTER SET', 'COLLATE'): - joiner = ' ' - - table_opts.append(joiner.join((opt, table.kwargs[k]))) - return ' '.join(table_opts) - - -class MySQLSchemaDropper(compiler.SchemaDropper): - def visit_index(self, index): - self.append("\nDROP INDEX %s ON %s" % - (self.preparer.quote(self._validate_identifier(index.name, False), index.quote), - self.preparer.format_table(index.table))) - self.execute() - - def drop_foreignkey(self, constraint): - self.append("ALTER TABLE %s DROP FOREIGN KEY %s" % - (self.preparer.format_table(constraint.table), - self.preparer.format_constraint(constraint))) - self.execute() - class MySQLSchemaReflector(object): """Parses SHOW CREATE TABLE output.""" - def __init__(self, identifier_preparer): + def __init__(self, dialect): """Construct a MySQLSchemaReflector. identifier_preparer @@ -2100,7 +2055,8 @@ class MySQLSchemaReflector(object): quoting style in effect. """ - self.preparer = identifier_preparer + self.dialect = dialect + self.preparer = dialect.identifier_preparer self._prep_regexes() def reflect(self, connection, table, show_create, charset, only=None): @@ -2195,7 +2151,7 @@ class MySQLSchemaReflector(object): args = None try: - col_type = ischema_names[type_] + col_type = self.dialect.ischema_names[type_] except KeyError: util.warn("Did not recognize type '%s' of column '%s'" % (type_, name)) @@ -2701,17 +2657,6 @@ class MySQLANSIIdentifierPreparer(_MySQLIdentifierPreparer): pass - -def _compat_fetchall(rp, charset=None): - """Proxy result rows to smooth over MySQL-Python driver inconsistencies.""" - - return [_MySQLPythonRowProxy(row, charset) for row in rp.fetchall()] - -def _compat_fetchone(rp, charset=None): - """Proxy a result row to smooth over MySQL-Python driver inconsistencies.""" - - return _MySQLPythonRowProxy(rp.fetchone(), charset) - def _pr_compile(regex, cleanup=None): """Prepare a 2-tuple of compiled regex and callable.""" @@ -2722,8 +2667,3 @@ def _re_compile(regex): return re.compile(regex, re.I | re.UNICODE) -dialect = MySQLDialect -dialect.statement_compiler = MySQLCompiler -dialect.schemagenerator = MySQLSchemaGenerator -dialect.schemadropper = MySQLSchemaDropper -dialect.execution_ctx_cls = MySQLExecutionContext diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py new file mode 100644 index 0000000000..b644374a8a --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/mysqldb.py @@ -0,0 +1,185 @@ +"""Support for the MySQL database via the MySQL-python adapter. + +Character Sets +-------------- + +Many MySQL server installations default to a ``latin1`` encoding for client +connections. All data sent through the connection will be converted into +``latin1``, even if you have ``utf8`` or another character set on your tables +and columns. With versions 4.1 and higher, you can change the connection +character set either through server configuration or by including the +``charset`` parameter in the URL used for ``create_engine``. The ``charset`` +option is passed through to MySQL-Python and has the side-effect of also +enabling ``use_unicode`` in the driver by default. For regular encoded +strings, also pass ``use_unicode=0`` in the connection arguments:: + + # set client encoding to utf8; all strings come back as unicode + create_engine('mysql:///mydb?charset=utf8') + + # set client encoding to utf8; all strings come back as utf8 str + create_engine('mysql:///mydb?charset=utf8&use_unicode=0') +""" + +from sqlalchemy.dialects.mysql.base import MySQLDialect, MySQLExecutionContext +from sqlalchemy.engine import base as engine_base, default +from sqlalchemy import exc, log, schema, sql, util +import re +from array import array as _array + +class MySQL_mysqldbExecutionContext(MySQLExecutionContext): + def _lastrowid(self, cursor): + return cursor.lastrowid + +class MySQL_mysqldb(MySQLDialect): + driver = 'mysqldb' + supports_unicode_statements = False + default_paramstyle = 'format' + execution_ctx_cls = MySQL_mysqldbExecutionContext + + @classmethod + def dbapi(cls): + import MySQLdb as mysql + return mysql + + def create_connect_args(self, url): + opts = url.translate_connect_args(database='db', username='user', + password='passwd') + opts.update(url.query) + + util.coerce_kw_type(opts, 'compress', bool) + util.coerce_kw_type(opts, 'connect_timeout', int) + util.coerce_kw_type(opts, 'client_flag', int) + util.coerce_kw_type(opts, 'local_infile', int) + # Note: using either of the below will cause all strings to be returned + # as Unicode, both in raw SQL operations and with column types like + # String and MSString. + util.coerce_kw_type(opts, 'use_unicode', bool) + util.coerce_kw_type(opts, 'charset', str) + + # Rich values 'cursorclass' and 'conv' are not supported via + # query string. + + ssl = {} + for key in ['ssl_ca', 'ssl_key', 'ssl_cert', 'ssl_capath', 'ssl_cipher']: + if key in opts: + ssl[key[4:]] = opts[key] + util.coerce_kw_type(ssl, key[4:], str) + del opts[key] + if ssl: + opts['ssl'] = ssl + + # FOUND_ROWS must be set in CLIENT_FLAGS to enable + # supports_sane_rowcount. + client_flag = opts.get('client_flag', 0) + if self.dbapi is not None: + try: + import MySQLdb.constants.CLIENT as CLIENT_FLAGS + client_flag |= CLIENT_FLAGS.FOUND_ROWS + except: + pass + opts['client_flag'] = client_flag + return [[], opts] + + def do_ping(self, connection): + connection.ping() + + def _server_version_info(self, dbapi_con): + """Convert a MySQL-python server_info string into a tuple.""" + + version = [] + r = re.compile('[.\-]') + for n in r.split(dbapi_con.get_server_info()): + try: + version.append(int(n)) + except ValueError: + version.append(n) + return tuple(version) + + def _detect_charset(self, connection): + """Sniff out the character set in use for connection results.""" + + # Allow user override, won't sniff if force_charset is set. + if ('mysql', 'force_charset') in connection.info: + return connection.info[('mysql', 'force_charset')] + + # Note: MySQL-python 1.2.1c7 seems to ignore changes made + # on a connection via set_character_set() + if self.server_version_info(connection) < (4, 1, 0): + try: + return connection.connection.character_set_name() + except AttributeError: + # < 1.2.1 final MySQL-python drivers have no charset support. + # a query is needed. + pass + + # Prefer 'character_set_results' for the current connection over the + # value in the driver. SET NAMES or individual variable SETs will + # change the charset without updating the driver's view of the world. + # + # If it's decided that issuing that sort of SQL leaves you SOL, then + # this can prefer the driver value. + rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'") + opts = dict([(row[0], row[1]) for row in self._compat_fetchall(rs)]) + + if 'character_set_results' in opts: + return opts['character_set_results'] + try: + return connection.connection.character_set_name() + except AttributeError: + # Still no charset on < 1.2.1 final... + if 'character_set' in opts: + return opts['character_set'] + else: + util.warn( + "Could not detect the connection character set with this " + "combination of MySQL server and MySQL-python. " + "MySQL-python >= 1.2.2 is recommended. Assuming latin1.") + return 'latin1' + _detect_charset = engine_base.connection_memoize( + ('mysql', 'charset'))(_detect_charset) + + + def _compat_fetchall(self, rp, charset=None): + """Proxy result rows to smooth over MySQL-Python driver inconsistencies.""" + + return [_MySQLPythonRowProxy(row, charset) for row in rp.fetchall()] + + def _compat_fetchone(self, rp, charset=None): + """Proxy a result row to smooth over MySQL-Python driver inconsistencies.""" + + return _MySQLPythonRowProxy(rp.fetchone(), charset) + +class _MySQLPythonRowProxy(object): + """Return consistent column values for all versions of MySQL-python. + + Smooth over data type issues (esp. with alpha driver versions) and + normalize strings as Unicode regardless of user-configured driver + encoding settings. + """ + + # Some MySQL-python versions can return some columns as + # sets.Set(['value']) (seriously) but thankfully that doesn't + # seem to come up in DDL queries. + + def __init__(self, rowproxy, charset): + self.rowproxy = rowproxy + self.charset = charset + def __getitem__(self, index): + item = self.rowproxy[index] + if isinstance(item, _array): + item = item.tostring() + if self.charset and isinstance(item, str): + return item.decode(self.charset) + else: + return item + def __getattr__(self, attr): + item = getattr(self.rowproxy, attr) + if isinstance(item, _array): + item = item.tostring() + if self.charset and isinstance(item, str): + return item.decode(self.charset) + else: + return item + + +dialect = MySQL_mysqldb \ No newline at end of file diff --git a/lib/sqlalchemy/dialects/mysql/pyodbc.py b/lib/sqlalchemy/dialects/mysql/pyodbc.py new file mode 100644 index 0000000000..b2698b16d3 --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/pyodbc.py @@ -0,0 +1,12 @@ +from sqlalchemy.dialects.mysql.base import MySQLDialect, MySQLExecutionContext + +class MySQL_pyodbcExecutionContext(MySQLExecutionContext): + def _lastrowid(self, cursor): + cursor.execute("SELECT LAST_INSERT_ID()") + return cursor.fetchone()[0] + +class MySQL_pyodbc(MySQLDialect): + pass + + +dialect = MySQL_pyodbc \ No newline at end of file diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 6d7a8f2693..02055b7092 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -28,7 +28,7 @@ expressions. """ import re, inspect -from sqlalchemy import types, exc, util, databases +from sqlalchemy import types, exc, util, dialects from sqlalchemy.sql import expression, visitors URL = None @@ -282,7 +282,7 @@ class Table(SchemaItem, expression.TableClause): def __extra_kwargs(self, **kwargs): # validate remaining kwargs that they all specify DB prefixes if len([k for k in kwargs - if not re.match(r'^(?:%s)_' % '|'.join(databases.__all__), k)]): + if not re.match(r'^(?:%s)_' % '|'.join(dialects.__all__), k)]): raise TypeError( "Invalid argument(s) for Table: %s" % repr(kwargs.keys())) self.kwargs.update(kwargs) diff --git a/test/dialect/mysql.py b/test/dialect/mysql.py index a233c25f54..0ca9240110 100644 --- a/test/dialect/mysql.py +++ b/test/dialect/mysql.py @@ -2,7 +2,7 @@ import testenv; testenv.configure_for_tests() import sets from sqlalchemy import * from sqlalchemy import sql, exc -from sqlalchemy.databases import mysql +from sqlalchemy.dialects.mysql import base as mysql from testlib.testing import eq_ from testlib import * diff --git a/test/engine/reflection.py b/test/engine/reflection.py index ac245981e8..64ae468e67 100644 --- a/test/engine/reflection.py +++ b/test/engine/reflection.py @@ -707,9 +707,8 @@ class UnicodeReflectionTest(TestBase): r.drop_all() r.create_all() finally: - pass -# metadata.drop_all() -# bind.dispose() + metadata.drop_all() + bind.dispose() class SchemaTest(TestBase):