From a8cdead32632045c29260b9bd7c2bcd5f2c8f221 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 4 Feb 2007 03:12:27 +0000 Subject: [PATCH] - more quoting fixes for [ticket:450]...quoting more aggressive (but still skips already-quoted literals) - got mysql to have "format" as default paramstyle even if mysql module not available, allows unit tests to pass in non-mysql system for [ticket:457]. all the dialects should be changed to pass in their usual paramstyle. --- lib/sqlalchemy/ansisql.py | 29 ++++++++++++++--------------- lib/sqlalchemy/databases/mysql.py | 6 ++++-- lib/sqlalchemy/engine/default.py | 9 ++++----- lib/sqlalchemy/schema.py | 2 +- lib/sqlalchemy/sql.py | 11 ++++++----- test/sql/defaults.py | 1 + test/sql/quote.py | 6 ++++++ 7 files changed, 36 insertions(+), 28 deletions(-) diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 40ea5b00ac..794c484394 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -850,12 +850,13 @@ class ANSIIdentifierPreparer(object): self.initial_quote = initial_quote self.final_quote = final_quote or self.initial_quote self.omit_schema = omit_schema + self._is_quoted_regexp = re.compile(r"""^['"%s].+['"%s]$""" % (self.initial_quote, self.final_quote)) self.__strings = {} def _escape_identifier(self, value): """escape an identifier. subclasses should override this to provide database-dependent escaping behavior.""" - return value.replace('"', '""') + return value.replace("'", "''") def _quote_identifier(self, value): """quote an identifier. @@ -872,6 +873,10 @@ class ANSIIdentifierPreparer(object): # some tests would need to be rewritten if this is done. #return value.upper() + def _is_quoted(self, ident): + """return true if the given identifier is already quoted""" + return self._is_quoted_regexp.match(ident) + def _reserved_words(self): return RESERVED_WORDS @@ -884,10 +889,11 @@ class ANSIIdentifierPreparer(object): def _requires_quotes(self, value, case_sensitive): """return true if the given identifier requires quoting.""" return \ - value in self._reserved_words() \ + not self._is_quoted(value) and \ + (value in self._reserved_words() \ or (value[0] in self._illegal_initial_characters()) \ or bool(len([x for x in str(value) if x not in self._legal_characters()])) \ - or (case_sensitive and value.lower() != value) + or (case_sensitive and not value.islower())) def __generic_obj_format(self, obj, ident): if getattr(obj, 'quote', False): @@ -897,13 +903,13 @@ class ANSIIdentifierPreparer(object): try: return self.__strings[(ident, case_sens)] except KeyError: - if self._requires_quotes(ident, getattr(obj, 'case_sensitive', ident == ident.lower())): + if self._requires_quotes(ident, getattr(obj, 'case_sensitive', ident.islower())): self.__strings[(ident, case_sens)] = self._quote_identifier(ident) else: self.__strings[(ident, case_sens)] = ident return self.__strings[(ident, case_sens)] else: - if self._requires_quotes(ident, getattr(obj, 'case_sensitive', ident == ident.lower())): + if self._requires_quotes(ident, getattr(obj, 'case_sensitive', ident.islower())): return self._quote_identifier(ident) else: return ident @@ -934,17 +940,10 @@ class ANSIIdentifierPreparer(object): """Prepare a quoted column name """ # TODO: isinstance alert ! get ColumnClause and Column to better # differentiate themselves - if isinstance(column, schema.SchemaItem): - if use_table: - return self.format_table(column.table, use_schema=False) + "." + self.__generic_obj_format(column, column.name) - else: - return self.__generic_obj_format(column, column.name) + if use_table: + return self.format_table(column.table, use_schema=False) + "." + self.__generic_obj_format(column, column.name) else: - # literal textual elements get stuck into ColumnClause alot, which shouldnt get quoted - if use_table: - return self.format_table(column.table, use_schema=False) + "." + column.name - else: - return column.name + return self.__generic_obj_format(column, column.name) def format_column_with_table(self, column): """Prepare a quoted column name with table name""" diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index d30751fb4d..c09d02756c 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -17,6 +17,7 @@ try: import MySQLdb.constants.CLIENT as CLIENT_FLAGS except: mysql = None + CLIENT_FLAGS = None def kw_colspec(self, spec): if self.unsigned: @@ -256,7 +257,7 @@ class MySQLDialect(ansisql.ANSIDialect): self.module = mysql else: self.module = module - ansisql.ANSIDialect.__init__(self, **kwargs) + ansisql.ANSIDialect.__init__(self, default_paramstyle='format', **kwargs) def create_connect_args(self, url): opts = url.translate_connect_args(['host', 'db', 'user', 'passwd', 'port']) @@ -274,7 +275,8 @@ class MySQLDialect(ansisql.ANSIDialect): # TODO: what about options like "ssl", "cursorclass" and "conv" ? client_flag = opts.get('client_flag', 0) - client_flag |= CLIENT_FLAGS.FOUND_ROWS + if CLIENT_FLAGS is not None: + client_flag |= CLIENT_FLAGS.FOUND_ROWS opts['client_flag'] = client_flag return [[], opts] diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 06409377cd..2f6283b5d3 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -24,14 +24,13 @@ class PoolConnectionProvider(base.ConnectionProvider): class DefaultDialect(base.Dialect): """default implementation of Dialect""" - def __init__(self, convert_unicode=False, encoding='utf-8', **kwargs): + def __init__(self, convert_unicode=False, encoding='utf-8', default_paramstyle='named', **kwargs): self.convert_unicode = convert_unicode self.supports_autoclose_results = True self.encoding = encoding self.positional = False - self.paramstyle = 'named' self._ischema = None - self._figure_paramstyle() + self._figure_paramstyle(default=default_paramstyle) def create_execution_context(self): return DefaultExecutionContext(self) def type_descriptor(self, typeobj): @@ -90,14 +89,14 @@ class DefaultDialect(base.Dialect): parameters = parameters.get_raw_dict() return parameters - def _figure_paramstyle(self, paramstyle=None): + def _figure_paramstyle(self, paramstyle=None, default='named'): db = self.dbapi() if paramstyle is not None: self._paramstyle = paramstyle elif db is not None: self._paramstyle = db.paramstyle else: - self._paramstyle = 'named' + self._paramstyle = default if self._paramstyle == 'named': self.positional=False diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 2be808ba61..44ba85453a 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -482,7 +482,7 @@ class Column(SchemaItem, sql._ColumnClause): """redirect the 'case_sensitive' accessor to use the ultimate parent column which created this one.""" return self.__originating_column._get_case_sensitive() - case_sensitive = property(_case_sens) + case_sensitive = property(_case_sens, lambda s,v:None) def accept_schema_visitor(self, visitor, traverse=True): """traverses the given visitor to this Column's default and foreign key object, diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 53cb6b9771..823c8afc84 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -219,10 +219,10 @@ def label(name, obj): """returns a _Label object for the given selectable, used in the column list for a select statement.""" return _Label(name, obj) -def column(text, table=None, type=None): +def column(text, table=None, type=None, **kwargs): """returns a textual column clause, relative to a table. this is also the primitive version of a schema.Column which is a subclass. """ - return _ColumnClause(text, table, type) + return _ColumnClause(text, table, type, **kwargs) def table(name, *columns): """returns a table clause. this is a primitive version of the schema.Table object, which is a subclass @@ -1199,7 +1199,7 @@ class Alias(FromClause): alias = alias[0:15] alias = alias + "_" + hex(random.randint(0, 65535))[2:] self.name = alias - self.case_sensitive = getattr(baseselectable, "case_sensitive", alias.lower() != alias) + self.case_sensitive = getattr(baseselectable, "case_sensitive", True) def supports_execution(self): return self.original.supports_execution() def _locate_oid_column(self): @@ -1233,7 +1233,7 @@ class _Label(ColumnElement): while isinstance(obj, _Label): obj = obj.obj self.obj = obj - self.case_sensitive = getattr(obj, "case_sensitive", name.lower() != name) + self.case_sensitive = getattr(obj, "case_sensitive", True) self.type = sqltypes.to_instance(type) obj.parens=True key = property(lambda s: s.name) @@ -1251,12 +1251,13 @@ legal_characters = util.Set(string.ascii_letters + string.digits + '_') class _ColumnClause(ColumnElement): """represents a textual column clause in a SQL statement. May or may not be bound to an underlying Selectable.""" - def __init__(self, text, selectable=None, type=None, _is_oid=False): + def __init__(self, text, selectable=None, type=None, _is_oid=False, case_sensitive=True): self.key = self.name = text self.table = selectable self.type = sqltypes.to_instance(type) self._is_oid = _is_oid self.__label = None + self.case_sensitive = case_sensitive def _get_label(self): if self.__label is None: if self.table is not None and self.table.named_with_column(): diff --git a/test/sql/defaults.py b/test/sql/defaults.py index 9d426ab0d2..d018c5efbf 100644 --- a/test/sql/defaults.py +++ b/test/sql/defaults.py @@ -176,6 +176,7 @@ class SequenceTest(PersistTest): metadata.create_all() + @testbase.supported('postgres', 'oracle') def testseqnonpk(self): """test sequences fire off as defaults on non-pk columns""" sometable.insert().execute(name="somename") diff --git a/test/sql/quote.py b/test/sql/quote.py index 607a595d3d..403ae2d426 100644 --- a/test/sql/quote.py +++ b/test/sql/quote.py @@ -97,10 +97,16 @@ class QuoteTest(PersistTest): x = select([table.c.col1.label("ImATable_col1")]).alias("SomeAlias") assert str(select([x.c.ImATable_col1])) == '''SELECT "SomeAlias"."ImATable_col1" \nFROM (SELECT "ImATable".col1 AS "ImATable_col1" \nFROM "ImATable") AS "SomeAlias"''' + # note that 'foo' and 'FooCol' are literals already quoted x = select([sql.column("'foo'").label("somelabel")], from_obj=[table]).alias("AnAlias") x = x.select() + #print x assert str(x) == '''SELECT "AnAlias".somelabel \nFROM (SELECT 'foo' AS somelabel \nFROM "ImATable") AS "AnAlias"''' + x = select([sql.column("'FooCol'").label("SomeLabel")], from_obj=[table]) + x = x.select() + assert str(x) == '''SELECT "SomeLabel" \nFROM (SELECT 'FooCol' AS "SomeLabel" \nFROM "ImATable")''' + def testlabelsnocase(self): metadata = MetaData() table1 = Table('SomeCase1', metadata, -- 2.47.2