From 87bbba32bc54fa0253e9c81663df669dc355f5da Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 24 Apr 2012 16:03:00 -0400 Subject: [PATCH] - [feature] The behavior of column targeting in result sets is now case sensitive by default. SQLAlchemy for many years would run a case-insensitive conversion on these values, probably to alleviate early case sensitivity issues with dialects like Oracle and Firebird. These issues have been more cleanly solved in more modern versions so the performance hit of calling lower() on identifiers is removed. The case insensitive comparisons can be re-enabled by setting "case_insensitive=False" on create_engine(). [ticket:2423] --- CHANGES | 13 ++++++++ lib/sqlalchemy/dialects/mssql/base.py | 5 ++- lib/sqlalchemy/engine/__init__.py | 6 ++++ lib/sqlalchemy/engine/base.py | 47 ++++++++++++++++++++------- lib/sqlalchemy/engine/default.py | 3 ++ lib/sqlalchemy/sql/compiler.py | 18 +++++++--- test/aaa_profiling/test_resultset.py | 8 ++--- test/aaa_profiling/test_zoomark.py | 4 +-- test/sql/test_query.py | 32 +++++++++++++++--- 9 files changed, 109 insertions(+), 27 deletions(-) diff --git a/CHANGES b/CHANGES index 38d51bc8ac..5e5790554f 100644 --- a/CHANGES +++ b/CHANGES @@ -160,6 +160,19 @@ CHANGES "inspector" object as the first argument. [ticket:2418] + - [feature] The behavior of column targeting + in result sets is now case sensitive by + default. SQLAlchemy for many years would + run a case-insensitive conversion on these values, + probably to alleviate early case sensitivity + issues with dialects like Oracle and + Firebird. These issues have been more cleanly + solved in more modern versions so the performance + hit of calling lower() on identifiers is removed. + The case insensitive comparisons can be re-enabled + by setting "case_insensitive=False" on + create_engine(). [ticket:2423] + - [bug] column.label(None) now produces an anonymous label, instead of returning the column object itself, consistent with the behavior diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index a63f10251e..e5eb447449 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -856,7 +856,9 @@ class MSSQLCompiler(compiler.SQLCompiler): t, column) if result_map is not None: - result_map[column.name.lower()] = \ + result_map[column.name + if self.dialect.case_sensitive + else column.name.lower()] = \ (column.name, (column, ), column.type) @@ -1300,6 +1302,7 @@ class MSDialect(default.DefaultDialect): whereclause = columns.c.table_name==tablename s = sql.select([columns], whereclause, order_by=[columns.c.ordinal_position]) + c = connection.execute(s) cols = [] while True: diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py index 23b4b0b3b8..c3667dd335 100644 --- a/lib/sqlalchemy/engine/__init__.py +++ b/lib/sqlalchemy/engine/__init__.py @@ -143,6 +143,12 @@ def create_engine(*args, **kwargs): :class:`.String` type - see that type for further details. + :param case_sensitive=True: if False, result column names + will match in a case-insensitive fashion, that is, + ``row['SomeColumn']``. By default, result row names + match case-sensitively as of version 0.8. In version + 0.7 and prior, all matches were case-insensitive. + :param connect_args: a dictionary of options which will be passed directly to the DBAPI's ``connect()`` method as additional keyword arguments. See the example diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 1d25113337..93d2b19f10 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -2724,6 +2724,7 @@ class ResultMetaData(object): dialect = context.dialect typemap = dialect.dbapi_type_map translate_colname = dialect._translate_colname + self.case_sensitive = dialect.case_sensitive # high precedence key values. primary_keymap = {} @@ -2738,9 +2739,14 @@ class ResultMetaData(object): if translate_colname: colname, untranslated = translate_colname(colname) + if dialect.requires_name_normalize: + colname = dialect.normalize_name(colname) + if context.result_map: try: - name, obj, type_ = context.result_map[colname.lower()] + name, obj, type_ = context.result_map[colname + if self.case_sensitive + else colname.lower()] except KeyError: name, obj, type_ = \ colname, None, typemap.get(coltype, types.NULLTYPE) @@ -2758,17 +2764,20 @@ class ResultMetaData(object): primary_keymap[i] = rec # populate primary keymap, looking for conflicts. - if primary_keymap.setdefault(name.lower(), rec) is not rec: + if primary_keymap.setdefault( + name if self.case_sensitive + else name.lower(), + rec) is not rec: # place a record that doesn't have the "index" - this # is interpreted later as an AmbiguousColumnError, # but only when actually accessed. Columns # colliding by name is not a problem if those names # aren't used; integer and ColumnElement access is always # unambiguous. - primary_keymap[name.lower()] = (processor, obj, None) + primary_keymap[name + if self.case_sensitive + else name.lower()] = (processor, obj, None) - if dialect.requires_name_normalize: - colname = dialect.normalize_name(colname) self.keys.append(colname) if obj: @@ -2797,7 +2806,9 @@ class ResultMetaData(object): row. """ - rec = (processor, obj, i) = self._keymap[origname.lower()] + rec = (processor, obj, i) = self._keymap[origname if + self.case_sensitive + else origname.lower()] if self._keymap.setdefault(name, rec) is not rec: self._keymap[name] = (processor, obj, None) @@ -2805,17 +2816,27 @@ class ResultMetaData(object): map = self._keymap result = None if isinstance(key, basestring): - result = map.get(key.lower()) + result = map.get(key if self.case_sensitive else key.lower()) # fallback for targeting a ColumnElement to a textual expression # this is a rare use case which only occurs when matching text() # or colummn('name') constructs to ColumnElements, or after a # pickle/unpickle roundtrip elif isinstance(key, expression.ColumnElement): - if key._label and key._label.lower() in map: - result = map[key._label.lower()] - elif hasattr(key, 'name') and key.name.lower() in map: + if key._label and ( + key._label + if self.case_sensitive + else key._label.lower()) in map: + result = map[key._label + if self.case_sensitive + else key._label.lower()] + elif hasattr(key, 'name') and ( + key.name + if self.case_sensitive + else key.name.lower()) in map: # match is only on name. - result = map[key.name.lower()] + result = map[key.name + if self.case_sensitive + else key.name.lower()] # search extra hard to make sure this # isn't a column/label name overlap. # this check isn't currently available if the row @@ -2851,7 +2872,8 @@ class ResultMetaData(object): for key, (processor, obj, index) in self._keymap.iteritems() if isinstance(key, (basestring, int)) ), - 'keys': self.keys + 'keys': self.keys, + "case_sensitive":self.case_sensitive, } def __setstate__(self, state): @@ -2864,6 +2886,7 @@ class ResultMetaData(object): # proxy comparison fails with the unpickle keymap[key] = (None, None, index) self.keys = state['keys'] + self.case_sensitive = state['case_sensitive'] self._echo = False diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index d0cbe871ff..1f72d005de 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -105,6 +105,7 @@ class DefaultDialect(base.Dialect): def __init__(self, convert_unicode=False, assert_unicode=False, encoding='utf-8', paramstyle=None, dbapi=None, implicit_returning=None, + case_sensitive=True, label_length=None, **kwargs): if not getattr(self, 'ported_sqla_06', True): @@ -139,6 +140,8 @@ class DefaultDialect(base.Dialect): self.identifier_preparer = self.preparer(self) self.type_compiler = self.type_compiler(self) + self.case_sensitive = case_sensitive + if label_length and label_length > self.max_identifier_length: raise exc.ArgumentError( "Label length of %d is greater than this dialect's" diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index bf234fe5cc..218e48bcae 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -365,7 +365,9 @@ class SQLCompiler(engine.Compiled): labelname = label.name if result_map is not None: - result_map[labelname.lower()] = ( + result_map[labelname + if self.dialect.case_sensitive + else labelname.lower()] = ( label.name, (label, label.element, labelname, ) + label._alt_names, @@ -393,7 +395,9 @@ class SQLCompiler(engine.Compiled): name = self._truncated_identifier("colident", name) if result_map is not None: - result_map[name.lower()] = (orig_name, + result_map[name + if self.dialect.case_sensitive + else name.lower()] = (orig_name, (column, name, column.key), column.type) @@ -441,7 +445,10 @@ class SQLCompiler(engine.Compiled): def visit_textclause(self, textclause, **kwargs): if textclause.typemap is not None: for colname, type_ in textclause.typemap.iteritems(): - self.result_map[colname.lower()] = (colname, None, type_) + self.result_map[colname + if self.dialect.case_sensitive + else colname.lower()] = \ + (colname, None, type_) def do_bindparam(m): name = m.group(1) @@ -518,7 +525,10 @@ class SQLCompiler(engine.Compiled): def visit_function(self, func, result_map=None, **kwargs): if result_map is not None: - result_map[func.name.lower()] = (func.name, None, func.type) + result_map[func.name + if self.dialect.case_sensitive + else func.name.lower()] = \ + (func.name, None, func.type) disp = getattr(self, "visit_%s_func" % func.name.lower(), None) if disp: diff --git a/test/aaa_profiling/test_resultset.py b/test/aaa_profiling/test_resultset.py index 632f67c6a2..0fc85ca035 100644 --- a/test/aaa_profiling/test_resultset.py +++ b/test/aaa_profiling/test_resultset.py @@ -37,8 +37,8 @@ class ResultSetTest(fixtures.TestBase, AssertsExecutionResults): '2.4': 13214, '2.6':14416, '2.7':14416, - '2.6+cextension': 365, - '2.7+cextension':365}) + '2.6+cextension': 336, + '2.7+cextension':336}) def test_string(self): [tuple(row) for row in t.select().execute().fetchall()] @@ -47,8 +47,8 @@ class ResultSetTest(fixtures.TestBase, AssertsExecutionResults): @profiling.function_call_count(versions={ '2.7':14396, '2.6':14396, - '2.6+cextension': 365, - '2.7+cextension':365}) + '2.6+cextension': 336, + '2.7+cextension':336}) def test_unicode(self): [tuple(row) for row in t2.select().execute().fetchall()] diff --git a/test/aaa_profiling/test_zoomark.py b/test/aaa_profiling/test_zoomark.py index a0336662d7..d4c66336c0 100644 --- a/test/aaa_profiling/test_zoomark.py +++ b/test/aaa_profiling/test_zoomark.py @@ -377,8 +377,8 @@ class ZooMarkTest(fixtures.TestBase): def test_profile_2_insert(self): self.test_baseline_2_insert() - @profiling.function_call_count(3340, {'2.4': 2158, '2.7':3541, - '2.7+cextension':3317, '2.6':3564}) + @profiling.function_call_count(3340, {'2.7':3333, + '2.7+cextension':3317, '2.6':3333}) def test_profile_3_properties(self): self.test_baseline_3_properties() diff --git a/test/sql/test_query.py b/test/sql/test_query.py index f315d6621c..29a6ed3552 100644 --- a/test/sql/test_query.py +++ b/test/sql/test_query.py @@ -852,9 +852,7 @@ class QueryTest(fixtures.TestBase): result.fetchone ) - def test_result_case_sensitivity(self): - """test name normalization for result sets.""" - + def test_row_case_sensitive(self): row = testing.db.execute( select([ literal_column("1").label("case_insensitive"), @@ -862,7 +860,33 @@ class QueryTest(fixtures.TestBase): ]) ).first() - assert row.keys() == ["case_insensitive", "CaseSensitive"] + eq_(row.keys(), ["case_insensitive", "CaseSensitive"]) + eq_(row["case_insensitive"], 1) + eq_(row["CaseSensitive"], 2) + + assert_raises( + KeyError, + lambda: row["Case_insensitive"] + ) + assert_raises( + KeyError, + lambda: row["casesensitive"] + ) + + def test_row_case_insensitive(self): + ins_db = engines.testing_engine(options={"case_sensitive":False}) + row = ins_db.execute( + select([ + literal_column("1").label("case_insensitive"), + literal_column("2").label("CaseSensitive") + ]) + ).first() + + eq_(row.keys(), ["case_insensitive", "CaseSensitive"]) + eq_(row["case_insensitive"], 1) + eq_(row["CaseSensitive"], 2) + eq_(row["Case_insensitive"],1) + eq_(row["casesensitive"],2) def test_row_as_args(self): -- 2.47.3