From 682769322734dd04ab0b28954b95c7d1fdfec604 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 27 May 2009 00:22:52 +0000 Subject: [PATCH] fixes --- lib/sqlalchemy/dialects/mysql/base.py | 12 ++++----- lib/sqlalchemy/engine/base.py | 8 +++--- lib/sqlalchemy/engine/reflection.py | 39 ++++++++++++++------------- lib/sqlalchemy/sql/compiler.py | 5 ++++ test/dialect/mysql.py | 22 +++++++-------- 5 files changed, 44 insertions(+), 42 deletions(-) diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 75d1442db1..cec9ade928 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -292,7 +292,7 @@ class MSNumeric(sqltypes.Numeric, _NumericType): __visit_name__ = 'NUMERIC' - def __init__(self, precision=10, scale=2, asdecimal=True, **kw): + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): """Construct a NUMERIC. :param precision: Total digits in this number. If scale and precision @@ -331,7 +331,7 @@ class MSDecimal(MSNumeric): __visit_name__ = 'DECIMAL' - def __init__(self, precision=10, scale=2, asdecimal=True, **kw): + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): """Construct a DECIMAL. :param precision: Total digits in this number. If scale and precision @@ -1523,6 +1523,8 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): def visit_NUMERIC(self, type_): if type_.precision is None: return self._extend_numeric(type_, "NUMERIC") + elif type_.scale is None: + return self._extend_numeric(type_, "NUMERIC(%(precision)s)" % {'precision': type_.precision}) else: return self._extend_numeric(type_, "NUMERIC(%(precision)s, %(scale)s)" % {'precision': type_.precision, 'scale' : type_.scale}) @@ -2335,11 +2337,7 @@ class MySQLTableDefinitionParser(object): if default is not None and default != 'NULL': # Defaults should be in the native charset for the moment default = default.encode(charset) - if type_ == 'timestamp': - # can't be NULL for TIMESTAMPs - if (default[0], default[-1]) != ("'", "'"): - default = sql.text(default) - else: + if type_ != 'timestamp': default = default[1:-1] elif default == 'NULL': # eliminates the need to deal with this later. diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 3e81608171..0328433618 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -586,11 +586,9 @@ class Compiled(object): raise NotImplementedError() - params = property(construct_params, doc=""" - Return the bind params for this compiled object. - - """) - + def params(self): + """Return the bind params for this compiled object.""" + return self.construct_params() def execute(self, *multiparams, **params): """Execute this compiled object.""" diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py index cc0d511c95..66062e2e00 100644 --- a/lib/sqlalchemy/engine/reflection.py +++ b/lib/sqlalchemy/engine/reflection.py @@ -245,7 +245,8 @@ class Inspector(object): dialect._adjust_casing(table) # table attributes we might need. - reflection_options = dict((k, table.kwargs.get(k)) for k in dialect.reflection_options if k in table.kwargs) + reflection_options = dict( + (k, table.kwargs.get(k)) for k in dialect.reflection_options if k in table.kwargs) schema = table.schema table_name = table.name @@ -274,30 +275,30 @@ class Inspector(object): for col_d in self.get_columns(table_name, schema, **tblkw): found_table = True name = col_d['name'] - coltype = col_d['type'] - nullable = col_d['nullable'] - default = col_d['default'] - colargs = [] - col_kw = {} - if 'autoincrement' in col_d: - col_kw['autoincrement'] = col_d['autoincrement'] if include_columns and name not in include_columns: continue - if default is not None: - # fixme - # mysql does not use sql.text - if isinstance(dialect, MySQLDialect): - colargs.append(sa_schema.DefaultClause(default)) - else: - colargs.append(sa_schema.DefaultClause(sql.text(default))) - col = sa_schema.Column(name, coltype,nullable=nullable, *colargs, **col_kw) + + coltype = col_d['type'] + col_kw = { + 'nullable':col_d['nullable'], + 'autoincrement':col_d.get('autoincrement', False) + } + + colargs = [] + if col_d.get('default') is not None: + colargs.append(sa_schema.DefaultClause(col_d['default'])) + if 'sequence' in col_d: + # TODO: whos using this ? seq = col_d['sequence'] - col.sequence = sa_schema.Sequence(seq['name'], 1, 1) + sequence = sa_schema.Sequence(seq['name'], 1, 1) if 'start' in seq: - col.sequence.start = seq['start'] + sequence.start = seq['start'] if 'increment' in seq: - col.sequence.increment = seq['increment'] + sequence.increment = seq['increment'] + colargs.append(sequence) + + col = sa_schema.Column(name, coltype, *colargs, **col_kw) table.append_column(col) if not found_table: diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index ab6bf0d4db..db864daa00 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -244,6 +244,11 @@ class SQLCompiler(engine.Compiled): pd[self.bind_names[bindparam]] = bindparam.value return pd + params = property(construct_params, doc=""" + Return the bind params for this compiled object. + + """) + def default_from(self): """Called when a SELECT statement has no froms, and no FROM clause is to be appended. diff --git a/test/dialect/mysql.py b/test/dialect/mysql.py index de4480a4f7..fbb363093b 100644 --- a/test/dialect/mysql.py +++ b/test/dialect/mysql.py @@ -55,11 +55,11 @@ class TypesTest(TestBase, AssertsExecutionResults): # column type, args, kwargs, expected ddl # e.g. Column(Integer(10, unsigned=True)) == 'INTEGER(10) UNSIGNED' (mysql.MSNumeric, [], {}, - 'NUMERIC(10, 2)'), + 'NUMERIC'), (mysql.MSNumeric, [None], {}, 'NUMERIC'), (mysql.MSNumeric, [12], {}, - 'NUMERIC(12, 2)'), + 'NUMERIC(12)'), (mysql.MSNumeric, [12, 4], {'unsigned':True}, 'NUMERIC(12, 4) UNSIGNED'), (mysql.MSNumeric, [12, 4], {'zerofill':True}, @@ -68,11 +68,11 @@ class TypesTest(TestBase, AssertsExecutionResults): 'NUMERIC(12, 4) UNSIGNED ZEROFILL'), (mysql.MSDecimal, [], {}, - 'DECIMAL(10, 2)'), + 'DECIMAL'), (mysql.MSDecimal, [None], {}, 'DECIMAL'), (mysql.MSDecimal, [12], {}, - 'DECIMAL(12, 2)'), + 'DECIMAL(12)'), (mysql.MSDecimal, [12, None], {}, 'DECIMAL(12)'), (mysql.MSDecimal, [12, 4], {'unsigned':True}, @@ -909,11 +909,11 @@ class SQLTest(TestBase, AssertsCompiledSQL): (m.MSBit, "t.col"), # this is kind of sucky. thank you default arguments! - (NUMERIC, "CAST(t.col AS DECIMAL(10, 2))"), - (DECIMAL, "CAST(t.col AS DECIMAL(10, 2))"), - (Numeric, "CAST(t.col AS DECIMAL(10, 2))"), - (m.MSNumeric, "CAST(t.col AS DECIMAL(10, 2))"), - (m.MSDecimal, "CAST(t.col AS DECIMAL(10, 2))"), + (NUMERIC, "CAST(t.col AS DECIMAL)"), + (DECIMAL, "CAST(t.col AS DECIMAL)"), + (Numeric, "CAST(t.col AS DECIMAL)"), + (m.MSNumeric, "CAST(t.col AS DECIMAL)"), + (m.MSDecimal, "CAST(t.col AS DECIMAL)"), (FLOAT, "t.col"), (Float, "t.col"), @@ -998,8 +998,8 @@ class SQLTest(TestBase, AssertsCompiledSQL): class RawReflectionTest(TestBase): def setUp(self): - self.dialect = mysql.dialect() - self.parser = mysql.MySQLTableDefinitionParser(self.dialect) + dialect = mysql.dialect() + self.parser = mysql.MySQLTableDefinitionParser(dialect, dialect.identifier_preparer) def test_key_reflection(self): regex = self.parser._re_key -- 2.47.3