From f5454c89ea82966075e58458b44fe2279d70a361 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 4 Sep 2006 01:56:31 +0000 Subject: [PATCH] simplification to quoting to just cache strings per-dialect, added quoting for alias and label names fixes [ticket:294] --- lib/sqlalchemy/ansisql.py | 126 ++++++++++++++++---------------------- lib/sqlalchemy/sql.py | 4 +- test/orm/selectresults.py | 38 ++++++++++++ test/sql/quote.py | 12 ++++ 4 files changed, 105 insertions(+), 75 deletions(-) diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index d65e8ad338..d053f73898 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -32,10 +32,11 @@ def create_engine(): return engine.ComposedSQLEngine(None, ANSIDialect()) class ANSIDialect(default.DefaultDialect): - def __init__(self, **kwargs): + def __init__(self, cache_identifiers=True, **kwargs): super(ANSIDialect,self).__init__(**kwargs) self.identifier_preparer = self.preparer() - + self.cache_identifiers = cache_identifiers + def connect_args(self): return ([],{}) @@ -158,7 +159,7 @@ class ANSICompiler(sql.Compiled): def visit_label(self, label): if len(self.select_stack): self.typemap.setdefault(label.name.lower(), label.obj.type) - self.strings[label] = self.strings[label.obj] + " AS " + label.name + self.strings[label] = self.strings[label.obj] + " AS " + self.preparer.format_label(label) def visit_column(self, column): if len(self.select_stack): @@ -289,7 +290,7 @@ class ANSICompiler(sql.Compiled): return self.bindtemplate % name def visit_alias(self, alias): - self.froms[alias] = self.get_from_text(alias.original) + " AS " + alias.name + self.froms[alias] = self.get_from_text(alias.original) + " AS " + self.preparer.format_alias(alias) self.strings[alias] = self.get_str(alias.original) def visit_select(self, select): @@ -717,7 +718,7 @@ class ANSISchemaDropper(engine.SchemaIterator): class ANSIDefaultRunner(engine.DefaultRunner): pass -class 
ANSIIdentifierPreparer(schema.SchemaVisitor): +class ANSIIdentifierPreparer(object): """handles quoting and case-folding of identifiers based on options""" def __init__(self, dialect, initial_quote='"', final_quote=None, omit_schema=False): """Constructs a new ANSIIdentifierPreparer object. @@ -731,8 +732,7 @@ class ANSIIdentifierPreparer(schema.SchemaVisitor): self.initial_quote = initial_quote self.final_quote = final_quote or self.initial_quote self.omit_schema = omit_schema - self.__strings = weakref.WeakKeyDictionary() - self.__visited = weakref.WeakKeyDictionary() + self.__strings = {} def _escape_identifier(self, value): """escape an identifier. @@ -771,68 +771,24 @@ class ANSIIdentifierPreparer(schema.SchemaVisitor): or bool(len([x for x in str(value) if x not in self._legal_characters()])) \ or (case_sensitive and value.lower() != value) - def visit_table(self, table): - if table in self.__visited: - return - - if table.quote or self._requires_quotes(table.name, table.case_sensitive): - tablestring = self._quote_identifier(table.name) - else: - tablestring = table.name - - if table.schema: - if table.quote_schema or self._requires_quotes(table.schema, table.case_sensitive_schema): - schemastring = self._quote_identifier(table.schema) - else: - schemastring = table.schema - else: - schemastring = None - - self.__strings[table] = (tablestring, schemastring) - - def visit_column(self, column): - if column in self.__visited: - return - if column.quote or self._requires_quotes(column.name, column.case_sensitive): - self.__strings[column] = self._quote_identifier(column.name) - else: - self.__strings[column] = column.name - - def visit_sequence(self, sequence): - if sequence in self.__visited: - return - if sequence.quote or self._requires_quotes(sequence.name, sequence.case_sensitive): - self.__strings[sequence] = self._quote_identifier(sequence.name) - else: - self.__strings[sequence] = sequence.name - - def __analyze_identifiers(self, obj): - """insure that 
each object we encounter is analyzed only once for its lifetime.""" - if obj in self.__visited: - return - if isinstance(obj, schema.SchemaItem): - obj.accept_schema_visitor(self) - self.__visited[obj] = True - - def __prepare_sequence(self, sequence): - self.__analyze_identifiers(sequence) - return self.__strings.get(sequence, sequence.name) - - def __prepare_table(self, table, use_schema=False): - self.__analyze_identifiers(table) - tablename = self.__strings.get(table, (table.name, None))[0] - if not self.omit_schema and use_schema and self.__strings.get(table, (None,None))[1] is not None: - return self.__strings[table][1] + "." + tablename - else: - return tablename - - def __prepare_column(self, column, use_table=True, **kwargs): - self.__analyze_identifiers(column) - if use_table: - return self.__prepare_table(column.table, **kwargs) + "." + self.__strings.get(column, column.name) + def __generic_obj_format(self, obj, ident): + if getattr(obj, 'quote', False): + return self._quote_identifier(ident) + if self.dialect.cache_identifiers: + try: + return self.__strings[ident] + except KeyError: + if self._requires_quotes(ident, getattr(obj, 'case_sensitive', ident == ident.lower())): + self.__strings[ident] = self._quote_identifier(ident) + else: + self.__strings[ident] = ident + return self.__strings[ident] else: - return self.__strings.get(column, column.name) - + if self._requires_quotes(ident, getattr(obj, 'case_sensitive', ident == ident.lower())): + return self._quote_identifier(ident) + else: + return ident + def should_quote(self, object): return object.quote or self._requires_quotes(object.name, object.case_sensitive) @@ -840,16 +796,38 @@ class ANSIIdentifierPreparer(schema.SchemaVisitor): return object.quote or self._requires_quotes(object.name, object.case_sensitive) def format_sequence(self, sequence): - return self.__prepare_sequence(sequence) + return self.__generic_obj_format(sequence, sequence.name) + + def format_label(self, label): + return 
self.__generic_obj_format(label, label.name) + + def format_alias(self, alias): + return self.__generic_obj_format(alias, alias.name) def format_table(self, table, use_schema=True): """Prepare a quoted table and schema name""" - return self.__prepare_table(table, use_schema=use_schema) + result = self.__generic_obj_format(table, table.name) + if use_schema and getattr(table, "schema", None): + result = self.__generic_obj_format(table, table.schema) + "." + result + return result - def format_column(self, column): + def format_column(self, column, use_table=False): """Prepare a quoted column name """ - return self.__prepare_column(column, use_table=False) - + # TODO: isinstance alert ! get ColumnClause and Column to better + # differentiate themselves + if isinstance(column, schema.SchemaItem): + if use_table: + return self.format_table(column.table, use_schema=False) + "." + self.__generic_obj_format(column, column.name) + else: + return self.__generic_obj_format(column, column.name) + else: + # literal textual elements get stuck into ColumnClause alot, which shouldnt get quoted + if use_table: + return column.table.name + "." 
+ column.name + else: + return column.name + def format_column_with_table(self, column): """Prepare a quoted column name with table name""" - return self.__prepare_column(column) + return self.format_column(column, use_table=True) + diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 596e0e8eef..d2e270c32c 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -1148,6 +1148,7 @@ class Alias(FromClause): alias = alias[0:15] alias = alias + "_" + hex(random.randint(0, 65535))[2:] self.name = alias + self.case_sensitive = getattr(baseselectable, "case_sensitive", alias.lower() != alias) def _locate_oid_column(self): if self.selectable.oid_column is not None: @@ -1180,6 +1181,7 @@ class Label(ColumnElement): while isinstance(obj, Label): obj = obj.obj self.obj = obj + self.case_sensitive = getattr(obj, "case_sensitive", name.lower() != name) self.type = sqltypes.to_instance(type) obj.parens=True key = property(lambda s: s.name) @@ -1206,7 +1208,7 @@ class ColumnClause(ColumnElement): def _get_label(self): if self.__label is None: if self.table is not None and self.table.named_with_column(): - self.__label = self.table.name + "_" + self.name + self.__label = self.table.name + "_" + self.name if self.table.c.has_key(self.__label) or len(self.__label) >= 30: self.__label = self.__label[0:24] + "_" + hex(random.randint(0, 65535))[2:] else: diff --git a/test/orm/selectresults.py b/test/orm/selectresults.py index c4b1d6a56e..6997dfe6bb 100644 --- a/test/orm/selectresults.py +++ b/test/orm/selectresults.py @@ -32,6 +32,7 @@ class SelectResultsTest(PersistTest): global foo foo.drop() self.uninstall_threadlocal() + clear_mappers() def test_selectby(self): res = self.query.select_by(range=5) @@ -111,6 +112,7 @@ class SelectResultsTest2(PersistTest): def tearDownAll(self): metadata.drop_all() self.uninstall_threadlocal() + clear_mappers() def test_distinctcount(self): res = self.query.select() @@ -120,6 +122,42 @@ class SelectResultsTest2(PersistTest): res 
= self.query.select(and_(table1.c.id==table2.c.t1id,table2.c.t1id==1), distinct=True) self.assertEqual(res.count(), 1) +class SelectResultsTest3(PersistTest): + def setUpAll(self): + self.install_threadlocal() + global metadata, table1, table2 + metadata = BoundMetaData(testbase.db) + table1 = Table('Table1', metadata, + Column('ID', Integer, primary_key=True), + ) + table2 = Table('Table2', metadata, + Column('T1ID', Integer, ForeignKey("Table1.ID"), primary_key=True), + Column('NUM', Integer, primary_key=True), + ) + assign_mapper(Obj1, table1, extension=SelectResultsExt()) + assign_mapper(Obj2, table2, extension=SelectResultsExt()) + metadata.create_all() + table1.insert().execute({'ID':1},{'ID':2},{'ID':3},{'ID':4}) + table2.insert().execute({'NUM':1,'T1ID':1},{'NUM':2,'T1ID':1},{'NUM':3,'T1ID':1},\ +{'NUM':4,'T1ID':2},{'NUM':5,'T1ID':2},{'NUM':6,'T1ID':3}) + + def setUp(self): + self.query = Obj1.mapper.query() + #self.orig = self.query.select_whereclause() + #self.res = self.query.select() + + def tearDownAll(self): + metadata.drop_all() + self.uninstall_threadlocal() + clear_mappers() + + def test_distinctcount(self): + res = self.query.select() + assert res.count() == 4 + res = self.query.select(and_(table1.c.ID==table2.c.T1ID,table2.c.T1ID==1)) + assert res.count() == 3 + res = self.query.select(and_(table1.c.ID==table2.c.T1ID,table2.c.T1ID==1), distinct=True) + self.assertEqual(res.count(), 1) if __name__ == "__main__": diff --git a/test/sql/quote.py b/test/sql/quote.py index 6b38accbd9..3e1e95a266 100644 --- a/test/sql/quote.py +++ b/test/sql/quote.py @@ -77,6 +77,18 @@ class QuoteTest(PersistTest): assert lcmetadata.case_sensitive is False assert t1.c.UcCol.case_sensitive is False assert t2.c.normalcol.case_sensitive is False + + def testlabels(self): + """test the quoting of labels. 
+ + if labels aren't quoted, a query in postgres in particular will fail since it produces: + + SELECT LaLa.lowercase, LaLa."UPPERCASE", LaLa."MixedCase", LaLa."ASC" + FROM (SELECT DISTINCT "WorstCase1".lowercase AS lowercase, "WorstCase1"."UPPERCASE" AS UPPERCASE, "WorstCase1"."MixedCase" AS MixedCase, "WorstCase1"."ASC" AS ASC \nFROM "WorstCase1") AS LaLa + + where the "UPPERCASE" column of "LaLa" doesn't exist. + """ + x = table1.select(distinct=True).alias("LaLa").select().scalar() if __name__ == "__main__": -- 2.47.2