From e60012140bff9908ddd6cfa4d4c645c648b89d06 Mon Sep 17 00:00:00 2001 From: Randall Smith Date: Fri, 6 Mar 2009 20:23:52 +0000 Subject: [PATCH] refactored mysql to separtate parsing from reflecting --- lib/sqlalchemy/dialects/mysql/base.py | 662 +++++++++++++------------- test/dialect/mysql.py | 4 +- 2 files changed, 343 insertions(+), 323 deletions(-) diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index a2a9261d81..e39d3f7ae3 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -179,12 +179,14 @@ timely information affecting MySQL in SQLAlchemy. import datetime, decimal, inspect, re, sys -from sqlalchemy import exc, log, schema, sql, util +from sqlalchemy import schema as sa_schema +from sqlalchemy import exc, log, sql, util from sqlalchemy.sql import operators as sql_operators from sqlalchemy.sql import functions as sql_functions from sqlalchemy.sql import compiler from array import array as _array +from sqlalchemy.engine import reflection from sqlalchemy.engine import base as engine_base, default from sqlalchemy import types as sqltypes @@ -1832,33 +1834,201 @@ class MySQLDialect(default.DefaultDialect): else: self.preparer = MySQLIdentifierPreparer self.identifier_preparer = self.preparer(self) + + @reflection.cache + def get_columns(self, connection, table_name, schema=None, **kw): - def reflecttable(self, connection, table, include_columns): - """Load column definitions from the server.""" + parser = kw.get('parser') + return parser.columns - charset = self._connection_charset + @reflection.cache + def get_primary_keys(self, connection, table_name, schema=None, **kw): + + parser = kw.get('parser') + for key in parser.keys: + if key['type'] == 'PRIMARY': + # There can be only one. + ##raise Exception, str(key) + return [s[0] for s in key['columns']] + return [] + + @reflection.cache + def get_foreign_keys(self, connection, table_name, schema=None, **kw): + + parser = kw.get('parser') + default_schema = None + + fkeys = [] + + for spec in parser.constraints: + # only FOREIGN KEYs + ref_name = spec['table'][-1] + ref_schema = len(spec['table']) > 1 and spec['table'][-2] or schema + + if not ref_schema: + if default_schema is None: + default_schema = \ + connection.dialect.get_default_schema_name(connection) + if schema == default_schema: + ref_schema = schema + + loc_names = spec['local'] + ref_names = spec['foreign'] + + con_kw = {} + for opt in ('name', 'onupdate', 'ondelete'): + if spec.get(opt, False): + con_kw[opt] = spec[opt] + fkey_d = { + 'name' : None, + 'constrained_columns' : loc_names, + 'referred_schema' : ref_schema, + 'referred_table' : ref_name, + 'referred_columns' : ref_names, + 'options' : con_kw + } + fkeys.append(fkey_d) + return fkeys + + @reflection.cache + def get_indexes(self, connection, table_name, schema, **kw): + + parser = kw.get('parser') + indexes = [] + for spec in parser.keys: + unique = False + flavor = spec['type'] + if flavor == 'PRIMARY': + continue + if flavor == 'UNIQUE': + unique = True + elif flavor in (None, 'FULLTEXT', 'SPATIAL'): + pass + else: + self.logger.info( + "Converting unknown KEY type %s to a plain KEY" % flavor) + pass + index_d = {} + index_d['name'] = spec['name'] + index_d['column_names'] = [s[0] for s in spec['columns']] + index_d['unique'] = unique + index_d['type'] = flavor + indexes.append(index_d) + return indexes + + def _setupParser(self, connection, table_name, schema=None): + + charset = self._connection_charset try: - reflector = self.reflector + parser = self.parser except AttributeError: preparer = self.identifier_preparer if (self.server_version_info < (4, 1) and self._server_use_ansiquotes): # ANSI_QUOTES doesn't affect SHOW CREATE TABLE on < 4.1 preparer = MySQLIdentifierPreparer(self) - - self.reflector = reflector = MySQLSchemaReflector(self, preparer) - - sql = self._show_create_table(connection, table, charset) + self.parser = parser = MySQLTableDefinitionParser(self, preparer) + full_name = '.'.join(self.identifier_preparer._quote_free_identifiers( + schema, table_name)) + sql = self._show_create_table(connection, None, charset, + full_name=full_name) if sql.startswith('CREATE ALGORITHM'): # Adapt views to something table-like. - columns = self._describe_table(connection, table, charset) - sql = reflector._describe_to_create(table, columns) + columns = self._describe_table(connection, None, charset, + full_name=full_name) + sql = parser._describe_to_create(table_name, columns) + parser.parse(sql, charset) + return parser + + def reflecttable(self, connection, table, include_columns): + """Load column definitions from the server.""" + charset = self._connection_charset self._adjust_casing(table) + parser = self._setupParser(connection, table.name, table.schema) + + # check the table name + if parser.table_name is not None: + table.name = parser.table_name + # apply table options + if parser.table_options: + table.kwargs.update(parser.table_options) + # columns + for col_d in self.get_columns(connection, table.name, table.schema, + parser=parser): + name = col_d['name'] + coltype = col_d['type'] + nullable = col_d.get('nullable', True) + default = col_d['default'] + colargs = col_d['colargs'] + if include_columns and name not in include_columns: + continue + if default is not None and default != 'NULL': + colargs.append(sa_schema.DefaultClause(default)) + # Can I not specify nullable=True? + col_kw = {} + if nullable is False: + col_kw['nullable'] = False + if 'autoincrement' in col_d: + col_kw['autoincrement'] = col_d['autoincrement'] + table.append_column(sa_schema.Column(name, coltype, + *colargs, **col_kw)) + + # primary keys + pkey_cols = self.get_primary_keys(connection, table.name, + table.schema, parser=parser) + pkey = sa_schema.PrimaryKeyConstraint() + for col in [table.c[name] for name in pkey_cols]: + pkey.append_column(col) + table.append_constraint(pkey) + + fkeys = self.get_foreign_keys(connection, table.name, + table.schema, parser=parser) + # foreign keys + for fkey_d in fkeys: + conname = fkey_d['name'] + loc_names = fkey_d['constrained_columns'] + ref_schema = fkey_d['referred_schema'] + ref_name = fkey_d['referred_table'] + ref_names = fkey_d['referred_columns'] + options = fkey_d['options'] + refspec = [] + + # load related table + ref_key = sa_schema._get_table_key(ref_name, ref_schema) + if ref_key in table.metadata.tables: + ref_table = table.metadata.tables[ref_key] + else: + ref_table = sa_schema.Table( + ref_name, table.metadata, schema=ref_schema, + autoload=True, autoload_with=connection) - return reflector.reflect(connection, table, sql, charset, - only=include_columns) + if ref_schema: + refspec = [".".join([ref_schema, ref_name, column]) for column in ref_names] + else: + refspec = [".".join([ref_name, column]) for column in ref_names] + key = sa_schema.ForeignKeyConstraint(loc_names, refspec, + link_to_name=True, **con_kw) + table.append_constraint(key) + + # Indexes + indexes = self.get_indexes(connection, table.name, table.schema, + parser=parser) + for index_d in indexes: + name = index_d['name'] + col_names = index_d['column_names'] + unique = index_d['unique'] + flavor = index_d['type'] + if include_columns and \ + not set(col_names).issubset(include_columns): + self.logger.info( + "Omitting %s KEY for (%s), key covers ommitted columns." % + (flavor, ', '.join(col_names))) + continue + key = sa_schema.Index(name, unique=unique) + for col in [table.c[name] for name in col_names]: + key.append_column(col) def _adjust_casing(self, table, charset=None): """Adjust Table name to the server case sensitivity, if needed.""" @@ -1985,8 +2155,7 @@ class MySQLDialect(default.DefaultDialect): return rows -class MySQLSchemaReflector(object): - """Parses SHOW CREATE TABLE output.""" +class MySQLTableDefinitionParser(object): def __init__(self, dialect, preparer=None): """Construct a MySQLSchemaReflector. @@ -1999,79 +2168,143 @@ class MySQLSchemaReflector(object): self.dialect = dialect self.preparer = preparer or dialect.identifier_preparer self._prep_regexes() - - def reflect(self, connection, table, show_create, charset, only=None): - """Parse MySQL SHOW CREATE TABLE and fill in a ''Table''. - - show_create - Unicode output of SHOW CREATE TABLE - - table - A ''Table'', to be loaded with Columns, Indexes, etc. - table.name will be set if not already - - charset - FIXME, some constructed values (like column defaults) - currently can't be Unicode. ''charset'' will convert them - into the connection character set. - - only - An optional sequence of column names. If provided, only - these columns will be reflected, and any keys or constraints - that include columns outside this set will also be omitted. - That means that if ``only`` includes only one column in a - 2 part primary key, the entire primary key will be omitted. - """ - - keys, constraints = [], [] - - if only: - only = set(only) - + # parsed results + self._set_defaults() + + def _set_defaults(self): + self.columns = [] + self.table_options = {} + self.table_name = None + self.keys = [] + self.constraints = [] + + def parse(self, show_create, charset): + self._set_defaults() + self.charset = charset for line in re.split(r'\r?\n', show_create): if line.startswith(' ' + self.preparer.initial_quote): - self._add_column(table, line, charset, only) + self._parse_column(line) # a regular table options line elif line.startswith(') '): - self._set_options(table, line) + self._parse_table_options(line) # an ANSI-mode table options line elif line == ')': pass elif line.startswith('CREATE '): - self._set_name(table, line) + self._parse_table_name(line) # Not present in real reflection, but may be if loading from a file. elif not line: pass else: - type_, spec = self.parse_constraints(line) + type_, spec = self._parse_constraints(line) if type_ is None: util.warn("Unknown schema content: %r" % line) elif type_ == 'key': - keys.append(spec) + self.keys.append(spec) elif type_ == 'constraint': - constraints.append(spec) + self.constraints.append(spec) else: pass - self._set_keys(table, keys, only) - self._set_constraints(table, constraints, connection, only) + def _parse_constraints(self, line): + """Parse a KEY or CONSTRAINT line. + + line + A line of SHOW CREATE TABLE output + """ + + # KEY + m = self._re_key.match(line) + if m: + spec = m.groupdict() + # convert columns into name, length pairs + spec['columns'] = self._parse_keyexprs(spec['columns']) + return 'key', spec + + # CONSTRAINT + m = self._re_constraint.match(line) + if m: + spec = m.groupdict() + spec['table'] = \ + self.preparer.unformat_identifiers(spec['table']) + spec['local'] = [c[0] + for c in self._parse_keyexprs(spec['local'])] + spec['foreign'] = [c[0] + for c in self._parse_keyexprs(spec['foreign'])] + return 'constraint', spec + + # PARTITION and SUBPARTITION + m = self._re_partition.match(line) + if m: + # Punt! + return 'partition', line - def _set_name(self, table, line): - """Override a Table name with the reflected name. + # No match. + return (None, line) - table - A ``Table`` + def _parse_table_name(self, line): + """Extract the table name. + + line + The first line of SHOW CREATE TABLE + """ + + regex, cleanup = self._pr_name + m = regex.match(line) + if m: + self.table_name = cleanup(m.group('name')) + + def _parse_table_options(self, line): + """Build a dictionary of all reflected table-level options. line - The first line of SHOW CREATE TABLE output. + The final line of SHOW CREATE TABLE output. """ - # Don't override by default. - if table.name is None: - table.name = self.parse_name(line) + options = {} + + if not line or line == ')': + pass + + else: + r_eq_trim = self._re_options_util['='] + + for regex, cleanup in self._pr_options: + m = regex.search(line) + if not m: + continue + directive, value = m.group('directive'), m.group('val') + directive = r_eq_trim.sub('', directive).lower() + if cleanup: + value = cleanup(value) + options[directive] = value + + for nope in ('auto_increment', 'data_directory', 'index_directory'): + options.pop(nope, None) + + for opt, val in options.items(): + self.table_options['mysql_%s' % opt] = val - def _add_column(self, table, line, charset, only=None): - spec = self.parse_column(line) + def _parse_column(self, line): + """Extract column details. + + Falls back to a 'minimal support' variant if full parse fails. + + line + Any column-bearing line from SHOW CREATE TABLE + """ + + charset = self.charset + spec = None + m = self._re_column.match(line) + if m: + spec = m.groupdict() + spec['full'] = True + else: + m = self._re_column_loose.match(line) + if m: + spec = m.groupdict() + spec['full'] = False if not spec: util.warn("Unknown column definition %r" % line) return @@ -2081,11 +2314,6 @@ class MySQLSchemaReflector(object): name, type_, args, notnull = \ spec['name'], spec['coltype'], spec['arg'], spec['notnull'] - if only and name not in only: - self.logger.info("Omitting reflected column %s.%s" % - (table.name, name)) - return - # Convention says that TINYINT(1) columns == BOOLEAN if type_ == 'tinyint' and args == '1': type_ = 'boolean' @@ -2143,122 +2371,60 @@ class MySQLSchemaReflector(object): default = sql.text(default) else: default = default[1:-1] - col_args.append(schema.DefaultClause(default)) - - table.append_column(schema.Column(name, type_instance, - *col_args, **col_kw)) - - def _set_keys(self, table, keys, only): - """Add ``Index`` and ``PrimaryKeyConstraint`` items to a ``Table``. + col_d = dict(name=name, type=type_instance, colargs=col_args, + default=default) + col_d.update(col_kw) + self.columns.append(col_d) - Most of the information gets dropped here- more is reflected than - the schema objects can currently represent. - - table - A ``Table`` + def _describe_to_create(self, table_name, columns): + """Re-format DESCRIBE output as a SHOW CREATE TABLE string. - keys - A sequence of key specifications produced by `constraints` + DESCRIBE is a much simpler reflection and is sufficient for + reflecting views for runtime use. This method formats DDL + for columns only- keys are omitted. - only - Optional `set` of column names. If provided, keys covering - columns not in this set will be omitted. + `columns` is a sequence of DESCRIBE or SHOW COLUMNS 6-tuples. + SHOW FULL COLUMNS FROM rows must be rearranged for use with + this function. """ - for spec in keys: - flavor = spec['type'] - col_names = [s[0] for s in spec['columns']] - - if only and not set(col_names).issubset(only): - if flavor is None: - flavor = 'index' - self.logger.info( - "Omitting %s KEY for (%s), key covers ommitted columns." % - (flavor, ', '.join(col_names))) - continue - - constraint = False - if flavor == 'PRIMARY': - key = schema.PrimaryKeyConstraint() - constraint = True - elif flavor == 'UNIQUE': - key = schema.Index(spec['name'], unique=True) - elif flavor in (None, 'FULLTEXT', 'SPATIAL'): - key = schema.Index(spec['name']) - else: - self.logger.info( - "Converting unknown KEY type %s to a plain KEY" % flavor) - key = schema.Index(spec['name']) - - for col in [table.c[name] for name in col_names]: - key.append_column(col) - - if constraint: - table.append_constraint(key) - - def _set_constraints(self, table, constraints, connection, only): - """Apply constraints to a ``Table``.""" - - default_schema = None - - for spec in constraints: - # only FOREIGN KEYs - ref_name = spec['table'][-1] - ref_schema = len(spec['table']) > 1 and spec['table'][-2] or table.schema - - if not ref_schema: - if default_schema is None: - default_schema = connection.dialect.get_default_schema_name( - connection) - if table.schema == default_schema: - ref_schema = table.schema - - loc_names = spec['local'] - if only and not set(loc_names).issubset(only): - self.logger.info( - "Omitting FOREIGN KEY for (%s), key covers ommitted " - "columns." % (', '.join(loc_names))) - continue - - ref_key = schema._get_table_key(ref_name, ref_schema) - if ref_key in table.metadata.tables: - ref_table = table.metadata.tables[ref_key] - else: - ref_table = schema.Table( - ref_name, table.metadata, schema=ref_schema, - autoload=True, autoload_with=connection) - - ref_names = spec['foreign'] - - if ref_schema: - refspec = [".".join([ref_schema, ref_name, column]) for column in ref_names] - else: - refspec = [".".join([ref_name, column]) for column in ref_names] - - con_kw = {} - for opt in ('name', 'onupdate', 'ondelete'): - if spec.get(opt, False): - con_kw[opt] = spec[opt] - - key = schema.ForeignKeyConstraint(loc_names, refspec, link_to_name=True, **con_kw) - table.append_constraint(key) + buffer = [] + for row in columns: + (name, col_type, nullable, default, extra) = \ + [row[i] for i in (0, 1, 2, 4, 5)] - def _set_options(self, table, line): - """Apply safe reflected table options to a ``Table``. + line = [' '] + line.append(self.preparer.quote_identifier(name)) + line.append(col_type) + if not nullable: + line.append('NOT NULL') + if default: + if 'auto_increment' in default: + pass + elif (col_type.startswith('timestamp') and + default.startswith('C')): + line.append('DEFAULT') + line.append(default) + elif default == 'NULL': + line.append('DEFAULT') + line.append(default) + else: + line.append('DEFAULT') + line.append("'%s'" % default.replace("'", "''")) + if extra: + line.append(extra) - table - A ``Table`` + buffer.append(' '.join(line)) - line - The final line of SHOW CREATE TABLE output. - """ + return ''.join([('CREATE TABLE %s (\n' % + self.preparer.quote_identifier(table_name)), + ',\n'.join(buffer), + '\n) ']) - options = self.parse_table_options(line) - for nope in ('auto_increment', 'data_directory', 'index_directory'): - options.pop(nope, None) + def _parse_keyexprs(self, identifiers): + """Unpack '"col"(2),"col" ASC'-ish strings into components.""" - for opt, val in options.items(): - table.kwargs['mysql_%s' % opt] = val + return self._re_keyexprs.findall(identifiers) def _prep_regexes(self): """Pre-compile regular expressions.""" @@ -2416,154 +2582,8 @@ class MySQLSchemaReflector(object): r'(?P%s)' % (re.escape(directive), regex)) self._pr_options.append(_pr_compile(regex)) - - def parse_name(self, line): - """Extract the table name. - - line - The first line of SHOW CREATE TABLE - """ - - regex, cleanup = self._pr_name - m = regex.match(line) - if not m: - return None - return cleanup(m.group('name')) - - def parse_column(self, line): - """Extract column details. - - Falls back to a 'minimal support' variant if full parse fails. - - line - Any column-bearing line from SHOW CREATE TABLE - """ - - m = self._re_column.match(line) - if m: - spec = m.groupdict() - spec['full'] = True - return spec - m = self._re_column_loose.match(line) - if m: - spec = m.groupdict() - spec['full'] = False - return spec - return None - - def parse_constraints(self, line): - """Parse a KEY or CONSTRAINT line. - - line - A line of SHOW CREATE TABLE output - """ - - # KEY - m = self._re_key.match(line) - if m: - spec = m.groupdict() - # convert columns into name, length pairs - spec['columns'] = self._parse_keyexprs(spec['columns']) - return 'key', spec - - # CONSTRAINT - m = self._re_constraint.match(line) - if m: - spec = m.groupdict() - spec['table'] = \ - self.preparer.unformat_identifiers(spec['table']) - spec['local'] = [c[0] - for c in self._parse_keyexprs(spec['local'])] - spec['foreign'] = [c[0] - for c in self._parse_keyexprs(spec['foreign'])] - return 'constraint', spec - - # PARTITION and SUBPARTITION - m = self._re_partition.match(line) - if m: - # Punt! - return 'partition', line - - # No match. - return (None, line) - - def parse_table_options(self, line): - """Build a dictionary of all reflected table-level options. - - line - The final line of SHOW CREATE TABLE output. - """ - - options = {} - - if not line or line == ')': - return options - - r_eq_trim = self._re_options_util['='] - - for regex, cleanup in self._pr_options: - m = regex.search(line) - if not m: - continue - directive, value = m.group('directive'), m.group('val') - directive = r_eq_trim.sub('', directive).lower() - if cleanup: - value = cleanup(value) - options[directive] = value - - return options - - def _describe_to_create(self, table, columns): - """Re-format DESCRIBE output as a SHOW CREATE TABLE string. - - DESCRIBE is a much simpler reflection and is sufficient for - reflecting views for runtime use. This method formats DDL - for columns only- keys are omitted. - - `columns` is a sequence of DESCRIBE or SHOW COLUMNS 6-tuples. - SHOW FULL COLUMNS FROM rows must be rearranged for use with - this function. - """ - - buffer = [] - for row in columns: - (name, col_type, nullable, default, extra) = \ - [row[i] for i in (0, 1, 2, 4, 5)] - - line = [' '] - line.append(self.preparer.quote_identifier(name)) - line.append(col_type) - if not nullable: - line.append('NOT NULL') - if default: - if 'auto_increment' in default: - pass - elif (col_type.startswith('timestamp') and - default.startswith('C')): - line.append('DEFAULT') - line.append(default) - elif default == 'NULL': - line.append('DEFAULT') - line.append(default) - else: - line.append('DEFAULT') - line.append("'%s'" % default.replace("'", "''")) - if extra: - line.append(extra) - - buffer.append(' '.join(line)) - - return ''.join([('CREATE TABLE %s (\n' % - self.preparer.quote_identifier(table.name)), - ',\n'.join(buffer), - '\n) ']) - - def _parse_keyexprs(self, identifiers): - """Unpack '"col"(2),"col" ASC'-ish strings into components.""" - - return self._re_keyexprs.findall(identifiers) - -log.class_logger(MySQLSchemaReflector) +log.class_logger(MySQLTableDefinitionParser) +log.class_logger(MySQLDialect) class _DecodingRowProxy(object): diff --git a/test/dialect/mysql.py b/test/dialect/mysql.py index 3d6964606c..f93b5f31a2 100644 --- a/test/dialect/mysql.py +++ b/test/dialect/mysql.py @@ -986,10 +986,10 @@ class SQLTest(TestBase, AssertsCompiledSQL): class RawReflectionTest(TestBase): def setUp(self): self.dialect = mysql.dialect() - self.reflector = mysql.MySQLSchemaReflector(self.dialect) + self.parser = mysql.MySQLTableDefinitionParser(self.dialect) def test_key_reflection(self): - regex = self.reflector._re_key + regex = self.parser._re_key assert regex.match(' PRIMARY KEY (`id`),') assert regex.match(' PRIMARY KEY USING BTREE (`id`),') -- 2.47.3