]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- simplify the OurSQL dialect regarding py3k, this version gives it a fairly
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 28 Feb 2010 20:12:43 +0000 (20:12 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 28 Feb 2010 20:12:43 +0000 (20:12 +0000)
fighting chance on python 3.  there's an oursql bug where it can't raise
an exception on executemany() correctly.
- needed to add "plain_query" wrappers for all the reflection methods.  not sure
why this was not needed earlier.

lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/mysql/oursql.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/test/assertsql.py

index 1a9e57ee77205b89c5c6771c8ea5679b33fce4cd..b9e3080b282a08b878703e83adc5203a4a4175e1 100644 (file)
@@ -1132,15 +1132,6 @@ ischema_names = {
 }
 
 class MySQLExecutionContext(default.DefaultExecutionContext):
-    def post_exec(self):
-        # TODO: i think this 'charset' in the info thing 
-        # is out
-        
-        if (not self.isupdate and not self.should_autocommit and
-              self.statement and SET_RE.match(self.statement)):
-            # This misses if a user forces autocommit on text('SET NAMES'),
-            # which is probably a programming error anyhow.
-            self.connection.info.pop(('mysql', 'charset'), None)
 
     def should_autocommit_text(self, statement):
         return AUTOCOMMIT_RE.match(statement)
@@ -1725,13 +1716,6 @@ class MySQLDialect(default.DefaultDialect):
     def _get_default_schema_name(self, connection):
         return connection.execute('SELECT DATABASE()').scalar()
 
-    def table_names(self, connection, schema):
-        """Return a Unicode SHOW TABLES from a given schema."""
-
-        charset = self._connection_charset
-        rp = connection.execute("SHOW TABLES FROM %s" %
-            self.identifier_preparer.quote_identifier(schema))
-        return [row[0] for row in self._compat_fetchall(rp, charset=charset)]
 
     def has_table(self, connection, table_name, schema=None):
         # SHOW TABLE STATUS LIKE and SHOW TABLES LIKE do not function properly
@@ -1782,17 +1766,28 @@ class MySQLDialect(default.DefaultDialect):
 
     @reflection.cache
     def get_table_names(self, connection, schema=None, **kw):
-        if schema is None:
-            schema = self.default_schema_name
-        if self.server_version_info < (5, 0, 2):
-            return self.table_names(connection, schema)
+        if schema is not None:
+            current_schema = schema
+        else:
+            current_schema = self.default_schema_name
+        table_names = self.table_names(connection, current_schema)
+        return table_names
+
+    def table_names(self, connection, schema):
+        """Return a Unicode SHOW TABLES from a given schema."""
+
         charset = self._connection_charset
-        rp = connection.execute("SHOW FULL TABLES FROM %s" %
+        if self.server_version_info < (5, 0, 2):
+            rp = connection.execute("SHOW TABLES FROM %s" %
                 self.identifier_preparer.quote_identifier(schema))
-        
-        return [row[0] for row in self._compat_fetchall(rp, charset=charset)\
-                                                    if row[1] == 'BASE TABLE']
+            return [row[0] for row in self._compat_fetchall(rp, charset=charset)]
+        else:
+            rp = connection.execute("SHOW FULL TABLES FROM %s" %
+                    self.identifier_preparer.quote_identifier(schema))
 
+            return [row[0] for row in self._compat_fetchall(rp, charset=charset)\
+                                                        if row[1] == 'BASE TABLE']
+            
     @reflection.cache
     def get_view_names(self, connection, schema=None, **kw):
         charset = self._connection_charset
index e75d1e0bd58bd65530809a2db7e9e9a53c41d3d9..83db1bc724074b85f39643b99aa71ae786424571 100644 (file)
@@ -51,14 +51,18 @@ class MySQL_oursqlExecutionContext(MySQLExecutionContext):
     @property
     def plain_query(self):
         return self.execution_options.get('_oursql_plain_query', False)
-
-
+    
 class MySQL_oursql(MySQLDialect):
     driver = 'oursql'
 # Py3K
 #    description_encoding = None
-    supports_unicode_statements = True
+#    supports_unicode_binds = False
+#    supports_unicode_statements = False
+# Py2K
     supports_unicode_binds = True
+    supports_unicode_statements = True
+# end Py2K
+
     supports_sane_rowcount = True
     supports_sane_multi_rowcount = True
     execution_ctx_cls = MySQL_oursqlExecutionContext
@@ -77,24 +81,7 @@ class MySQL_oursql(MySQLDialect):
         return __import__('oursql')
 
     def do_execute(self, cursor, statement, parameters, context=None):
-        """Provide an implementation of *cursor.execute(statement, parameters)*.
-#       TODO: this isn't right.   the supports_unicode_binds
-#       and supports_unicode_statements flags should be used for this one.
-#       also, don't call _detect_charset - use self._connection_charset
-#       which is already configured (uses _detect_charset just once)."""
-
-# Py3K
-#        if context is not None:
-#            charset = self._detect_charset(context.connection)
-#            if charset is not None:
-#                statement = statement.encode(charset)
-#                encoded_parameters = []
-#                for p in parameters:
-#                    if isinstance(p, str):
-#                        encoded_parameters.append(p.encode(charset))
-#                    else:
-#                        encoded_parameters.append(p)
-#                parameters = encoded_parameters
+        """Provide an implementation of *cursor.execute(statement, parameters)*."""
 
         if context and context.plain_query:
             cursor.execute(statement, plain_query=True)
@@ -109,12 +96,14 @@ class MySQL_oursql(MySQLDialect):
         arg = connection.connection._escape_string(xid)
 # end Py2K
 # Py3K
-#        charset = connection.connection.charset
+#        charset = self._connection_charset
 #        arg = connection.connection._escape_string(xid.encode(charset)).decode(charset)
         connection.execution_options(_oursql_plain_query=True).execute(query % arg)
 
-    # Because mysql is bad, these methods have to be reimplemented to use _PlainQuery. Basically, some queries
-    # refuse to return any data if they're run through the parameterized query API, or refuse to be parameterized
+    # Because mysql is bad, these methods have to be 
+    # reimplemented to use _PlainQuery. Basically, some queries
+    # refuse to return any data if they're run through 
+    # the parameterized query API, or refuse to be parameterized
     # in the first place.
     def do_begin_twophase(self, connection, xid):
         self._xa_query(connection, 'XA BEGIN "%s"', xid)
@@ -134,26 +123,72 @@ class MySQL_oursql(MySQLDialect):
         if not is_prepared:
             self.do_prepare_twophase(connection, xid)
         self._xa_query(connection, 'XA COMMIT "%s"', xid)
-
+    
+    # Q: why didn't we need all these "plain_query" overrides earlier ?
+    # am i on a newer/older version of OurSQL ?
     def has_table(self, connection, table_name, schema=None):
-        return MySQLDialect.has_table(self, connection.execution_options(_oursql_plain_query=True), table_name, schema)
-
-    # TODO: don't do this.   just have base _show_create_table return
-    # unicode.  don't reuse _detect_charset(), use _connection_charset.
+        return MySQLDialect.has_table(self, 
+                                        connection.connect().\
+                                            execution_options(_oursql_plain_query=True),
+                                        table_name, schema)
+    
+    def get_table_options(self, connection, table_name, schema=None, **kw):
+        return MySQLDialect.get_table_options(self,
+                                            connection.connect().\
+                                                execution_options(_oursql_plain_query=True),
+                                            table_name,
+                                            schema = schema,
+                                            **kw
+        )
+
+
+    def get_columns(self, connection, table_name, schema=None, **kw):
+        return MySQLDialect.get_columns(self,
+                                        connection.connect().\
+                                                    execution_options(_oursql_plain_query=True),
+                                        table_name,
+                                        schema=schema,
+                                        **kw
+        )
+        
+    def get_view_names(self, connection, schema=None, **kw):
+        return MySQLDialect.get_view_names(self,
+                                            connection.connect().\
+                                                    execution_options(_oursql_plain_query=True),
+                                            schema=schema,
+                                            **kw
+        )
+        
+    def table_names(self, connection, schema):
+        return MySQLDialect.table_names(self,
+                            connection.connect().\
+                                        execution_options(_oursql_plain_query=True),
+                            schema
+        )
+        
+    def get_schema_names(self, connection, **kw):
+        return MySQLDialect.get_schema_names(self,
+                                    connection.connect().\
+                                                execution_options(_oursql_plain_query=True),
+                                    **kw
+        )
+        
+    def initialize(self, connection):
+        return MySQLDialect.initialize(
+                            self, 
+                            connection.execution_options(_oursql_plain_query=True)
+                            )
+        
     def _show_create_table(self, connection, table, charset=None,
                            full_name=None):
-        sql = MySQLDialect._show_create_table(self,
-            connection.contextual_connect(close_with_result=True).execution_options(_oursql_plain_query=True),
-            table, charset, full_name)
-# Py3K
-#        charset = self._detect_charset(connection)
-#        if charset is not None:
-#            sql = sql.decode(charset)
-        return sql
+        return MySQLDialect._show_create_table(self,
+                                connection.contextual_connect(close_with_result=True).
+                                execution_options(_oursql_plain_query=True),
+                                table, charset, full_name)
 
     def is_disconnect(self, e):
-        if isinstance(e, self.dbapi.ProgrammingError):  # if underlying connection is closed, this is the error you get
-            return e.errno is None and e.args[1].endswith('closed')
+        if isinstance(e, self.dbapi.ProgrammingError):  
+            return e.errno is None and 'cursor' not in e.args[1] and e.args[1].endswith('closed')
         else:
             return e.errno in (2006, 2013, 2014, 2045, 2055)
 
@@ -199,13 +234,8 @@ class MySQL_oursql(MySQLDialect):
 
     def _detect_charset(self, connection):
         """Sniff out the character set in use for connection results."""
-        if hasattr(connection, 'connection'):
-            if hasattr(connection.connection, 'use_unicode') and connection.connection.use_unicode:
-                return None
-            else:
-                return connection.connection.charset
-        else:
-            return None
+    
+        return connection.connection.charset
 
     def _compat_fetchall(self, rp, charset=None):
         """oursql isn't super-broken like MySQLdb, yaaay."""
@@ -215,5 +245,8 @@ class MySQL_oursql(MySQLDialect):
         """oursql isn't super-broken like MySQLdb, yaaay."""
         return rp.fetchone()
 
+    def _compat_first(self, rp, charset=None):
+        return rp.first()
+
 
 dialect = MySQL_oursql
index ac933bdf423e3cb14491f99cde088fc86ed0040c..4d4fd7c71969657f979bba3c6fdea480f7d896f2 100644 (file)
@@ -115,12 +115,6 @@ class DefaultDialect(base.Dialect):
 
         if not hasattr(self, 'description_encoding'):
             self.description_encoding = getattr(self, 'description_encoding', encoding)
-
-        # Py3K
-        ## work around dialects that might change these values
-        #self.supports_unicode_statements = True
-        #self.supports_unicode_binds = True
-        #self.returns_unicode_strings = True
     
     @property
     def dialect_description(self):
@@ -136,9 +130,7 @@ class DefaultDialect(base.Dialect):
         except NotImplementedError:
             self.default_schema_name = None
 
-        # Py2K
         self.returns_unicode_strings = self._check_unicode_returns(connection)
-        # end Py2K
     
     def _check_unicode_returns(self, connection):
         cursor = connection.connection.cursor()
@@ -268,9 +260,10 @@ class DefaultExecutionContext(base.ExecutionContext):
                                                     )
 
             if not dialect.supports_unicode_statements:
-                self.statement = unicode(compiled).encode(self.dialect.encoding)
+                self.unicode_statement = unicode(compiled)
+                self.statement = self.unicode_statement.encode(self.dialect.encoding)
             else:
-                self.statement = unicode(compiled)
+                self.statement = self.unicode_statement = unicode(compiled)
                 
             self.cursor = self.create_cursor()
             self.compiled_parameters = []
@@ -302,9 +295,10 @@ class DefaultExecutionContext(base.ExecutionContext):
             self.result_map = compiled.result_map
 
             if not dialect.supports_unicode_statements:
-                self.statement = unicode(compiled).encode(self.dialect.encoding)
+                self.unicode_statement = unicode(compiled)
+                self.statement = self.unicode_statement.encode(self.dialect.encoding)
             else:
-                self.statement = unicode(compiled)
+                self.statement = self.unicode_statement = unicode(compiled)
 
             self.isinsert = compiled.isinsert
             self.isupdate = compiled.isupdate
@@ -322,16 +316,20 @@ class DefaultExecutionContext(base.ExecutionContext):
             if self.isinsert or self.isupdate:
                 self.__process_defaults()
             self.parameters = self.__convert_compiled_params(self.compiled_parameters)
+            
         elif statement is not None:
             # plain text statement
             if connection._execution_options:
                 self.execution_options = self.execution_options.union(connection._execution_options)
             self.parameters = self.__encode_param_keys(parameters)
             self.executemany = len(parameters) > 1
+            
             if isinstance(statement, unicode) and not dialect.supports_unicode_statements:
+                self.unicode_statement = statement
                 self.statement = statement.encode(self.dialect.encoding)
             else:
-                self.statement = statement
+                self.statement = self.unicode_statement = statement
+                
             self.cursor = self.create_cursor()
         else:
             # no statement. used for standalone ColumnDefault execution.
@@ -349,7 +347,7 @@ class DefaultExecutionContext(base.ExecutionContext):
                                                 or False)
                                                 
         if autocommit is expression.PARSE_AUTOCOMMIT:
-            return self.should_autocommit_text(self.statement)
+            return self.should_autocommit_text(self.unicode_statement)
         else:
             return autocommit
             
@@ -374,7 +372,8 @@ class DefaultExecutionContext(base.ExecutionContext):
     def _execute_scalar(self, stmt):
         """Execute a string statement on the current cursor, returning a scalar result.
         
-        Used to fire off sequences, default phrases, and "select lastrowid" types of statements individually
+        Used to fire off sequences, default phrases, and "select lastrowid" 
+        types of statements individually
         or in the context of a parent INSERT or UPDATE statement.
         
         """
index 6dbc95b784fafeb8f428e76a65cf06d23b38f8f5..1417c2e4355fe95a6504bb41a7d02231380a3723 100644 (file)
@@ -63,7 +63,7 @@ class ExactSQL(SQLMatchRule):
         if not context:
             return
             
-        _received_statement = _process_engine_statement(statement, context)
+        _received_statement = _process_engine_statement(context.unicode_statement, context)
         _received_parameters = context.compiled_parameters
         
         # TODO: remove this step once all unit tests
@@ -101,7 +101,7 @@ class RegexSQL(SQLMatchRule):
         if not context:
             return
 
-        _received_statement = _process_engine_statement(statement, context)
+        _received_statement = _process_engine_statement(context.unicode_statement, context)
         _received_parameters = context.compiled_parameters
 
         equivalent = bool(self.regex.match(_received_statement))