]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
applied Michael's patch to fix issue with CREATE TABLE parser state
authorRandall Smith <randall@tnr.cc>
Sat, 7 Mar 2009 20:28:40 +0000 (20:28 +0000)
committerRandall Smith <randall@tnr.cc>
Sat, 7 Mar 2009 20:28:40 +0000 (20:28 +0000)
lib/sqlalchemy/dialects/mysql/base.py

index e39d3f7ae3ff9b1b3a737b89511e9a3ab9b38d1d..19cefd4b55c7afcdd5f806fd8d88459f08ea0ed0 100644 (file)
@@ -1837,15 +1837,15 @@ class MySQLDialect(default.DefaultDialect):
 
     @reflection.cache
     def get_columns(self, connection, table_name, schema=None, **kw):
-        
-        parser = kw.get('parser')
-        return parser.columns
+
+        parsed_state = self._parsed_state_or_create(connection, table_name, schema, **kw)
+        return parsed_state.columns
 
     @reflection.cache
     def get_primary_keys(self, connection, table_name, schema=None, **kw):
 
-        parser = kw.get('parser')
-        for key in parser.keys:
+        parsed_state = self._parsed_state_or_create(connection, table_name, schema, **kw)
+        for key in parsed_state.keys:
             if key['type'] == 'PRIMARY':
                 # There can be only one.
                 ##raise Exception, str(key)
@@ -1855,12 +1855,12 @@ class MySQLDialect(default.DefaultDialect):
     @reflection.cache
     def get_foreign_keys(self, connection, table_name, schema=None, **kw):
 
-        parser = kw.get('parser')
+        parsed_state = self._parsed_state_or_create(connection, table_name, schema, **kw)
         default_schema = None
 
         fkeys = []
 
-        for spec in parser.constraints:
+        for spec in parsed_state.constraints:
             # only FOREIGN KEYs
             ref_name = spec['table'][-1]
             ref_schema = len(spec['table']) > 1 and spec['table'][-2] or schema
@@ -1894,9 +1894,10 @@ class MySQLDialect(default.DefaultDialect):
     @reflection.cache
     def get_indexes(self, connection, table_name, schema, **kw):
 
-        parser = kw.get('parser')
+        parsed_state = self._parsed_state_or_create(connection, table_name, schema, **kw)
+        
         indexes = []
-        for spec in parser.keys:
+        for spec in parsed_state.keys:
             unique = False
             flavor = spec['type']
             if flavor == 'PRIMARY':
@@ -1917,7 +1918,13 @@ class MySQLDialect(default.DefaultDialect):
             indexes.append(index_d)
         return indexes
 
-    def _setupParser(self, connection, table_name, schema=None):
+    def _parsed_state_or_create(self, connection, table_name, schema=None, **kw):
+        if 'parsed_state' in kw:
+            return kw['parsed_state']
+        else:
+            return self._setup_parser(connection, table.name, schema)
+        
+    def _setup_parser(self, connection, table_name, schema=None):
 
         charset = self._connection_charset
         try:
@@ -1938,25 +1945,24 @@ class MySQLDialect(default.DefaultDialect):
             columns = self._describe_table(connection, None, charset,
                                            full_name=full_name)
             sql = parser._describe_to_create(table_name, columns)
-        parser.parse(sql, charset)
-        return parser
+        return parser.parse(sql, charset)
   
     def reflecttable(self, connection, table, include_columns):
         """Load column definitions from the server."""
 
         charset = self._connection_charset
         self._adjust_casing(table)
-        parser = self._setupParser(connection, table.name, table.schema)
+        parsed_state = self._setup_parser(connection, table.name, table.schema)
 
         # check the table name
-        if parser.table_name is not None:
-            table.name = parser.table_name
+        if parsed_state.table_name is not None:
+            table.name = parsed_state.table_name
         # apply table options
-        if parser.table_options:
-            table.kwargs.update(parser.table_options)
+        if parsed_state.table_options:
+            table.kwargs.update(parsed_state.table_options)
         # columns
         for col_d in self.get_columns(connection, table.name, table.schema,
-                                      parser=parser):
+                                      parsed_state=parsed_state):
             name = col_d['name']
             coltype = col_d['type']
             nullable = col_d.get('nullable', True)
@@ -1977,14 +1983,14 @@ class MySQLDialect(default.DefaultDialect):
 
         # primary keys
         pkey_cols = self.get_primary_keys(connection, table.name,
-                                          table.schema, parser=parser)
+                                          table.schema, parsed_state=parsed_state)
         pkey = sa_schema.PrimaryKeyConstraint()
         for col in [table.c[name] for name in pkey_cols]:
             pkey.append_column(col)
         table.append_constraint(pkey)
 
         fkeys = self.get_foreign_keys(connection, table.name,
-                                      table.schema, parser=parser)
+                                      table.schema, parsed_state=parsed_state)
         # foreign keys
         for fkey_d in fkeys:
             conname = fkey_d['name']
@@ -2014,7 +2020,7 @@ class MySQLDialect(default.DefaultDialect):
 
         # Indexes 
         indexes = self.get_indexes(connection, table.name, table.schema,
-                                   parser=parser)
+                                   parsed_state=parsed_state)
         for index_d in indexes:
             name = index_d['name']
             col_names = index_d['column_names']
@@ -2154,7 +2160,14 @@ class MySQLDialect(default.DefaultDialect):
                 rp.close()
         return rows
 
-
+class ReflectedState(object):
+    def __init__(self):
+        self.columns = []
+        self.table_options = {}
+        self.table_name = None
+        self.keys = []
+        self.constraints = []
+        
 class MySQLTableDefinitionParser(object):
 
     def __init__(self, dialect, preparer=None):
@@ -2168,30 +2181,21 @@ class MySQLTableDefinitionParser(object):
         self.dialect = dialect
         self.preparer = preparer or dialect.identifier_preparer
         self._prep_regexes()
-        # parsed results
-        self._set_defaults()
-
-    def _set_defaults(self):
-        self.columns = []
-        self.table_options = {}
-        self.table_name = None
-        self.keys = []
-        self.constraints = []
 
     def parse(self, show_create, charset):
-        self._set_defaults()
-        self.charset = charset
+        state = ReflectedState()
+        state.charset = charset
         for line in re.split(r'\r?\n', show_create):
             if line.startswith('  ' + self.preparer.initial_quote):
-                self._parse_column(line)
+                self._parse_column(line, state)
             # a regular table options line
             elif line.startswith(') '):
-                self._parse_table_options(line)
+                self._parse_table_options(line, state)
             # an ANSI-mode table options line
             elif line == ')':
                 pass
             elif line.startswith('CREATE '):
-                self._parse_table_name(line)
+                self._parse_table_name(line, state)
             # Not present in real reflection, but may be if loading from a file.
             elif not line:
                 pass
@@ -2200,12 +2204,14 @@ class MySQLTableDefinitionParser(object):
                 if type_ is None:
                     util.warn("Unknown schema content: %r" % line)
                 elif type_ == 'key':
-                    self.keys.append(spec)
+                    state.keys.append(spec)
                 elif type_ == 'constraint':
-                    self.constraints.append(spec)
+                    state.constraints.append(spec)
                 else:
                     pass
-
+                    
+        return state
+        
     def _parse_constraints(self, line):
         """Parse a KEY or CONSTRAINT line.
 
@@ -2242,7 +2248,7 @@ class MySQLTableDefinitionParser(object):
         # No match.
         return (None, line)
 
-    def _parse_table_name(self, line):
+    def _parse_table_name(self, line, state):
         """Extract the table name.
 
         line
@@ -2252,9 +2258,9 @@ class MySQLTableDefinitionParser(object):
         regex, cleanup = self._pr_name
         m = regex.match(line)
         if m:
-            self.table_name = cleanup(m.group('name'))
+            state.table_name = cleanup(m.group('name'))
 
-    def _parse_table_options(self, line):
+    def _parse_table_options(self, line, state):
         """Build a dictionary of all reflected table-level options.
 
         line
@@ -2283,9 +2289,9 @@ class MySQLTableDefinitionParser(object):
             options.pop(nope, None)
 
         for opt, val in options.items():
-            self.table_options['mysql_%s' % opt] = val
+            state.table_options['mysql_%s' % opt] = val
 
-    def _parse_column(self, line):
+    def _parse_column(self, line, state):
         """Extract column details.
 
         Falls back to a 'minimal support' variant if full parse fails.
@@ -2294,7 +2300,7 @@ class MySQLTableDefinitionParser(object):
           Any column-bearing line from SHOW CREATE TABLE
         """
 
-        charset = self.charset
+        charset = state.charset
         spec = None
         m = self._re_column.match(line)
         if m:
@@ -2374,7 +2380,7 @@ class MySQLTableDefinitionParser(object):
         col_d = dict(name=name, type=type_instance, colargs=col_args,
                      default=default)
         col_d.update(col_kw)
-        self.columns.append(col_d)
+        state.columns.append(col_d)
 
     def _describe_to_create(self, table_name, columns):
         """Re-format DESCRIBE output as a SHOW CREATE TABLE string.