From b7c0c84be9f0ccc6d7e7aeea6d95b111d208cbc5 Mon Sep 17 00:00:00 2001 From: Randall Smith Date: Sat, 7 Mar 2009 20:28:40 +0000 Subject: [PATCH] applied Michael's patch to fix issue with CREATE TABLE parser state --- lib/sqlalchemy/dialects/mysql/base.py | 98 ++++++++++++++------------- 1 file changed, 52 insertions(+), 46 deletions(-) diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index e39d3f7ae3..19cefd4b55 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1837,15 +1837,15 @@ class MySQLDialect(default.DefaultDialect): @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): - - parser = kw.get('parser') - return parser.columns + + parsed_state = self._parsed_state_or_create(connection, table_name, schema, **kw) + return parsed_state.columns @reflection.cache def get_primary_keys(self, connection, table_name, schema=None, **kw): - parser = kw.get('parser') - for key in parser.keys: + parsed_state = self._parsed_state_or_create(connection, table_name, schema, **kw) + for key in parsed_state.keys: if key['type'] == 'PRIMARY': # There can be only one. ##raise Exception, str(key) @@ -1855,12 +1855,12 @@ class MySQLDialect(default.DefaultDialect): @reflection.cache def get_foreign_keys(self, connection, table_name, schema=None, **kw): - parser = kw.get('parser') + parsed_state = self._parsed_state_or_create(connection, table_name, schema, **kw) default_schema = None fkeys = [] - for spec in parser.constraints: + for spec in parsed_state.constraints: # only FOREIGN KEYs ref_name = spec['table'][-1] ref_schema = len(spec['table']) > 1 and spec['table'][-2] or schema @@ -1894,9 +1894,10 @@ class MySQLDialect(default.DefaultDialect): @reflection.cache def get_indexes(self, connection, table_name, schema, **kw): - parser = kw.get('parser') + parsed_state = self._parsed_state_or_create(connection, table_name, schema, **kw) + indexes = [] - for spec in parser.keys: + for spec in parsed_state.keys: unique = False flavor = spec['type'] if flavor == 'PRIMARY': @@ -1917,7 +1918,13 @@ class MySQLDialect(default.DefaultDialect): indexes.append(index_d) return indexes - def _setupParser(self, connection, table_name, schema=None): + def _parsed_state_or_create(self, connection, table_name, schema=None, **kw): + if 'parsed_state' in kw: + return kw['parsed_state'] + else: + return self._setup_parser(connection, table.name, schema) + + def _setup_parser(self, connection, table_name, schema=None): charset = self._connection_charset try: @@ -1938,25 +1945,24 @@ class MySQLDialect(default.DefaultDialect): columns = self._describe_table(connection, None, charset, full_name=full_name) sql = parser._describe_to_create(table_name, columns) - parser.parse(sql, charset) - return parser + return parser.parse(sql, charset) def reflecttable(self, connection, table, include_columns): """Load column definitions from the server.""" charset = self._connection_charset self._adjust_casing(table) - parser = self._setupParser(connection, table.name, table.schema) + parsed_state = self._setup_parser(connection, table.name, table.schema) # check the table name - if parser.table_name is not None: - table.name = parser.table_name + if parsed_state.table_name is not None: + table.name = parsed_state.table_name # apply table options - if parser.table_options: - table.kwargs.update(parser.table_options) + if parsed_state.table_options: + table.kwargs.update(parsed_state.table_options) # columns for col_d in self.get_columns(connection, table.name, table.schema, - parser=parser): + parsed_state=parsed_state): name = col_d['name'] coltype = col_d['type'] nullable = col_d.get('nullable', True) @@ -1977,14 +1983,14 @@ class MySQLDialect(default.DefaultDialect): # primary keys pkey_cols = self.get_primary_keys(connection, table.name, - table.schema, parser=parser) + table.schema, parsed_state=parsed_state) pkey = sa_schema.PrimaryKeyConstraint() for col in [table.c[name] for name in pkey_cols]: pkey.append_column(col) table.append_constraint(pkey) fkeys = self.get_foreign_keys(connection, table.name, - table.schema, parser=parser) + table.schema, parsed_state=parsed_state) # foreign keys for fkey_d in fkeys: conname = fkey_d['name'] @@ -2014,7 +2020,7 @@ class MySQLDialect(default.DefaultDialect): # Indexes indexes = self.get_indexes(connection, table.name, table.schema, - parser=parser) + parsed_state=parsed_state) for index_d in indexes: name = index_d['name'] col_names = index_d['column_names'] @@ -2154,7 +2160,14 @@ class MySQLDialect(default.DefaultDialect): rp.close() return rows - +class ReflectedState(object): + def __init__(self): + self.columns = [] + self.table_options = {} + self.table_name = None + self.keys = [] + self.constraints = [] + class MySQLTableDefinitionParser(object): def __init__(self, dialect, preparer=None): @@ -2168,30 +2181,21 @@ class MySQLTableDefinitionParser(object): self.dialect = dialect self.preparer = preparer or dialect.identifier_preparer self._prep_regexes() - # parsed results - self._set_defaults() - - def _set_defaults(self): - self.columns = [] - self.table_options = {} - self.table_name = None - self.keys = [] - self.constraints = [] def parse(self, show_create, charset): - self._set_defaults() - self.charset = charset + state = ReflectedState() + state.charset = charset for line in re.split(r'\r?\n', show_create): if line.startswith(' ' + self.preparer.initial_quote): - self._parse_column(line) + self._parse_column(line, state) # a regular table options line elif line.startswith(') '): - self._parse_table_options(line) + self._parse_table_options(line, state) # an ANSI-mode table options line elif line == ')': pass elif line.startswith('CREATE '): - self._parse_table_name(line) + self._parse_table_name(line, state) # Not present in real reflection, but may be if loading from a file. elif not line: pass @@ -2200,12 +2204,14 @@ class MySQLTableDefinitionParser(object): if type_ is None: util.warn("Unknown schema content: %r" % line) elif type_ == 'key': - self.keys.append(spec) + state.keys.append(spec) elif type_ == 'constraint': - self.constraints.append(spec) + state.constraints.append(spec) else: pass - + + return state + def _parse_constraints(self, line): """Parse a KEY or CONSTRAINT line. @@ -2242,7 +2248,7 @@ class MySQLTableDefinitionParser(object): # No match. return (None, line) - def _parse_table_name(self, line): + def _parse_table_name(self, line, state): """Extract the table name. line @@ -2252,9 +2258,9 @@ class MySQLTableDefinitionParser(object): regex, cleanup = self._pr_name m = regex.match(line) if m: - self.table_name = cleanup(m.group('name')) + state.table_name = cleanup(m.group('name')) - def _parse_table_options(self, line): + def _parse_table_options(self, line, state): """Build a dictionary of all reflected table-level options. line @@ -2283,9 +2289,9 @@ class MySQLTableDefinitionParser(object): options.pop(nope, None) for opt, val in options.items(): - self.table_options['mysql_%s' % opt] = val + state.table_options['mysql_%s' % opt] = val - def _parse_column(self, line): + def _parse_column(self, line, state): """Extract column details. Falls back to a 'minimal support' variant if full parse fails. @@ -2294,7 +2300,7 @@ class MySQLTableDefinitionParser(object): Any column-bearing line from SHOW CREATE TABLE """ - charset = self.charset + charset = state.charset spec = None m = self._re_column.match(line) if m: @@ -2374,7 +2380,7 @@ class MySQLTableDefinitionParser(object): col_d = dict(name=name, type=type_instance, colargs=col_args, default=default) col_d.update(col_kw) - self.columns.append(col_d) + state.columns.append(col_d) def _describe_to_create(self, table_name, columns): """Re-format DESCRIBE output as a SHOW CREATE TABLE string. -- 2.47.3