]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
fixes
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 27 May 2009 00:22:52 +0000 (00:22 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 27 May 2009 00:22:52 +0000 (00:22 +0000)
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/reflection.py
lib/sqlalchemy/sql/compiler.py
test/dialect/mysql.py

index 75d1442db11f3917e114e55c807daee1e80f4b26..cec9ade928fe61feeae6891a499fd193f2a6df6a 100644 (file)
@@ -292,7 +292,7 @@ class MSNumeric(sqltypes.Numeric, _NumericType):
     
     __visit_name__ = 'NUMERIC'
     
-    def __init__(self, precision=10, scale=2, asdecimal=True, **kw):
+    def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
         """Construct a NUMERIC.
 
         :param precision: Total digits in this number.  If scale and precision
@@ -331,7 +331,7 @@ class MSDecimal(MSNumeric):
     
     __visit_name__ = 'DECIMAL'
     
-    def __init__(self, precision=10, scale=2, asdecimal=True, **kw):
+    def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
         """Construct a DECIMAL.
 
         :param precision: Total digits in this number.  If scale and precision
@@ -1523,6 +1523,8 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
     def visit_NUMERIC(self, type_):
         if type_.precision is None:
             return self._extend_numeric(type_, "NUMERIC")
+        elif type_.scale is None:
+            return self._extend_numeric(type_, "NUMERIC(%(precision)s)" % {'precision': type_.precision})
         else:
             return self._extend_numeric(type_, "NUMERIC(%(precision)s, %(scale)s)" % {'precision': type_.precision, 'scale' : type_.scale})
 
@@ -2335,11 +2337,7 @@ class MySQLTableDefinitionParser(object):
         if default is not None and default != 'NULL':
             # Defaults should be in the native charset for the moment
             default = default.encode(charset)
-            if type_ == 'timestamp':
-                # can't be NULL for TIMESTAMPs
-                if (default[0], default[-1]) != ("'", "'"):
-                    default = sql.text(default)
-            else:
+            if type_ != 'timestamp':
                 default = default[1:-1]
         elif default == 'NULL':
             # eliminates the need to deal with this later.
index 3e81608171ec538d456bf679e63efaffea928614..032843361819b5ddf65147badf80729cd69e681f 100644 (file)
@@ -586,11 +586,9 @@ class Compiled(object):
 
         raise NotImplementedError()
 
-    params = property(construct_params, doc="""
-        Return the bind params for this compiled object.
-    
-    """)
-
+    def params(self):
+        """Return the bind params for this compiled object."""
+        return self.construct_params()
 
     def execute(self, *multiparams, **params):
         """Execute this compiled object."""
index cc0d511c9588edbd3f915977d2e5cc36bc837df8..66062e2e007ec4267b6090d12b51cec3f3bbfa9a 100644 (file)
@@ -245,7 +245,8 @@ class Inspector(object):
             dialect._adjust_casing(table)
 
         # table attributes we might need.
-        reflection_options = dict((k, table.kwargs.get(k)) for k in dialect.reflection_options if k in table.kwargs)
+        reflection_options = dict(
+            (k, table.kwargs.get(k)) for k in dialect.reflection_options if k in table.kwargs)
 
         schema = table.schema
         table_name = table.name
@@ -274,30 +275,30 @@ class Inspector(object):
         for col_d in self.get_columns(table_name, schema, **tblkw):
             found_table = True
             name = col_d['name']
-            coltype = col_d['type']
-            nullable = col_d['nullable']
-            default = col_d['default']
-            colargs = []
-            col_kw = {}
-            if 'autoincrement' in col_d:
-                col_kw['autoincrement'] = col_d['autoincrement']
             if include_columns and name not in include_columns:
                 continue
-            if default is not None:
-                # fixme
-                # mysql does not use sql.text
-                if isinstance(dialect, MySQLDialect):
-                    colargs.append(sa_schema.DefaultClause(default))
-                else:
-                    colargs.append(sa_schema.DefaultClause(sql.text(default)))
-            col = sa_schema.Column(name, coltype,nullable=nullable, *colargs, **col_kw)
+
+            coltype = col_d['type']
+            col_kw = {
+                'nullable':col_d['nullable'],
+                'autoincrement':col_d.get('autoincrement', False)
+            }
+            
+            colargs = []
+            if col_d.get('default') is not None:
+                colargs.append(sa_schema.DefaultClause(col_d['default']))
+                
             if 'sequence' in col_d:
+                # TODO: whos using this ?
                 seq = col_d['sequence']
-                col.sequence = sa_schema.Sequence(seq['name'], 1, 1)
+                sequence = sa_schema.Sequence(seq['name'], 1, 1)
                 if 'start' in seq:
-                    col.sequence.start = seq['start']
+                    sequence.start = seq['start']
                 if 'increment' in seq:
-                    col.sequence.increment = seq['increment']
+                    sequence.increment = seq['increment']
+                colargs.append(sequence)
+                
+            col = sa_schema.Column(name, coltype, *colargs, **col_kw)
             table.append_column(col)
 
         if not found_table:
index ab6bf0d4db3f13b409aa04241bbdc2818708a1db..db864daa002a2e445a6ed7e3041ca088b55d3523 100644 (file)
@@ -244,6 +244,11 @@ class SQLCompiler(engine.Compiled):
                     pd[self.bind_names[bindparam]] = bindparam.value
             return pd
 
+    params = property(construct_params, doc="""
+        Return the bind params for this compiled object.
+
+    """)
+
     def default_from(self):
         """Called when a SELECT statement has no froms, and no FROM clause is to be appended.
 
index de4480a4f701b1d51eaf88d3376d03c75a989ccb..fbb363093b14e9bd04967d1fe7dd50131f7d3e3b 100644 (file)
@@ -55,11 +55,11 @@ class TypesTest(TestBase, AssertsExecutionResults):
             # column type, args, kwargs, expected ddl
             # e.g. Column(Integer(10, unsigned=True)) == 'INTEGER(10) UNSIGNED'
             (mysql.MSNumeric, [], {},
-             'NUMERIC(10, 2)'),
+             'NUMERIC'),
             (mysql.MSNumeric, [None], {},
              'NUMERIC'),
             (mysql.MSNumeric, [12], {},
-             'NUMERIC(12, 2)'),
+             'NUMERIC(12)'),
             (mysql.MSNumeric, [12, 4], {'unsigned':True},
              'NUMERIC(12, 4) UNSIGNED'),
             (mysql.MSNumeric, [12, 4], {'zerofill':True},
@@ -68,11 +68,11 @@ class TypesTest(TestBase, AssertsExecutionResults):
              'NUMERIC(12, 4) UNSIGNED ZEROFILL'),
 
             (mysql.MSDecimal, [], {},
-             'DECIMAL(10, 2)'),
+             'DECIMAL'),
             (mysql.MSDecimal, [None], {},
              'DECIMAL'),
             (mysql.MSDecimal, [12], {},
-             'DECIMAL(12, 2)'),
+             'DECIMAL(12)'),
             (mysql.MSDecimal, [12, None], {},
              'DECIMAL(12)'),
             (mysql.MSDecimal, [12, 4], {'unsigned':True},
@@ -909,11 +909,11 @@ class SQLTest(TestBase, AssertsCompiledSQL):
             (m.MSBit, "t.col"),
 
             # this is kind of sucky.  thank you default arguments!
-            (NUMERIC, "CAST(t.col AS DECIMAL(10, 2))"),
-            (DECIMAL, "CAST(t.col AS DECIMAL(10, 2))"),
-            (Numeric, "CAST(t.col AS DECIMAL(10, 2))"),
-            (m.MSNumeric, "CAST(t.col AS DECIMAL(10, 2))"),
-            (m.MSDecimal, "CAST(t.col AS DECIMAL(10, 2))"),
+            (NUMERIC, "CAST(t.col AS DECIMAL)"),
+            (DECIMAL, "CAST(t.col AS DECIMAL)"),
+            (Numeric, "CAST(t.col AS DECIMAL)"),
+            (m.MSNumeric, "CAST(t.col AS DECIMAL)"),
+            (m.MSDecimal, "CAST(t.col AS DECIMAL)"),
 
             (FLOAT, "t.col"),
             (Float, "t.col"),
@@ -998,8 +998,8 @@ class SQLTest(TestBase, AssertsCompiledSQL):
 
 class RawReflectionTest(TestBase):
     def setUp(self):
-        self.dialect = mysql.dialect()
-        self.parser = mysql.MySQLTableDefinitionParser(self.dialect)
+        dialect = mysql.dialect()
+        self.parser = mysql.MySQLTableDefinitionParser(dialect, dialect.identifier_preparer)
 
     def test_key_reflection(self):
         regex = self.parser._re_key