0.2.7
+- quoting facilities set up so that database-specific quoting can be
+turned on for individual table, schema, and column identifiers when
+used in all queries/creates/drops. Enabled via "quote=True" in
+Table or Column, as well as "quote_schema=True" in Table (see the
+usage sketch below). Thanks to Aaron Spike for his excellent efforts.
- assignmapper was setting is_primary=True, causing all sorts of mayhem
by not raising an error when redundant mappers were set up, fixed
- added allow_null_pks option to Mapper, allows rows where some
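A minimal usage sketch of the new quote flags, mirroring the new quote unit test included further below; the in-memory SQLite engine is only for illustration:

    from sqlalchemy import *

    engine = create_engine('sqlite://')              # illustrative in-memory database
    metadata = BoundMetaData(engine)
    table = Table('WorstCase1', metadata,
        Column('lowercase', Integer, primary_key=True),
        Column('MixedCase', Integer, quote=True),    # delimit to preserve mixed case
        Column('ASC', Integer, quote=True),          # delimit a reserved word
        quote=True)                                  # delimit the table name itself
    table.create()   # emits roughly: CREATE TABLE "WorstCase1" (lowercase INTEGER NOT NULL, "MixedCase" INTEGER, "ASC" INTEGER, PRIMARY KEY (lowercase))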
def compiler(self, statement, parameters, **kwargs):
return ANSICompiler(self, statement, parameters, **kwargs)
+ def preparer(self):
+ """return an IdenfifierPreparer.
+
+ This object is used to format table and column names including proper quoting and case conventions."""
+ return ANSIIdentifierPreparer()
class ANSICompiler(sql.Compiled):
"""default implementation of Compiled, which compiles ClauseElements into ANSI-compliant SQL strings."""
self.bindtemplate = ":%s"
self.paramstyle = dialect.paramstyle
self.positional = dialect.positional
+ self.preparer = dialect.preparer()
def after_compile(self):
# this re will search for params like :param
# for this column which is used to translate result set values
self.typemap.setdefault(column.name.lower(), column.type)
if column.table is None or not column.table.named_with_column():
- self.strings[column] = column.name
+ self.strings[column] = self.preparer.format_column(column)
else:
if column.table.oid_column is column:
n = self.dialect.oid_column_name()
if n is not None:
- self.strings[column] = "%s.%s" % (column.table.name, n)
+ self.strings[column] = "%s.%s" % (self.preparer.format_table(column.table, use_schema=False), n)
elif len(column.table.primary_key) != 0:
- self.strings[column] = "%s.%s" % (column.table.name, column.table.primary_key[0].name)
+ self.strings[column] = self.preparer.format_column_with_table(column.table.primary_key[0])
else:
self.strings[column] = None
else:
- self.strings[column] = "%s.%s" % (column.table.name, column.name)
-
+ self.strings[column] = self.preparer.format_column_with_table(column)
def visit_fromclause(self, fromclause):
self.froms[fromclause] = fromclause.from_name
return " OFFSET " + str(select.offset)
def visit_table(self, table):
- self.froms[table] = table.fullname
+ self.froms[table] = self.preparer.format_table(table)
self.strings[table] = ""
def visit_join(self, join):
else:
return self.get_str(p)
- text = ("INSERT INTO " + insert_stmt.table.fullname + " (" + string.join([c[0].name for c in colparams], ', ') + ")" +
+ text = ("INSERT INTO " + self.preparer.format_table(insert_stmt.table) + " (" + string.join([self.preparer.format_column(c[0]) for c in colparams], ', ') + ")" +
" VALUES (" + string.join([create_param(c[1]) for c in colparams], ', ') + ")")
self.strings[insert_stmt] = text
else:
return self.get_str(p)
- text = "UPDATE " + update_stmt.table.fullname + " SET " + string.join(["%s=%s" % (c[0].name, create_param(c[1])) for c in colparams], ', ')
+ text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.format_column(c[0]), create_param(c[1])) for c in colparams], ', ')
if update_stmt.whereclause:
text += " WHERE " + self.get_str(update_stmt.whereclause)
return values
def visit_delete(self, delete_stmt):
- text = "DELETE FROM " + delete_stmt.table.fullname
+ text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table)
if delete_stmt.whereclause:
text += " WHERE " + self.get_str(delete_stmt.whereclause)
super(ANSISchemaGenerator, self).__init__(engine, proxy, **params)
self.checkfirst = checkfirst
self.connection = connection
+ self.preparer = self.engine.dialect.preparer()
+
def get_column_specification(self, column, first_pk=False):
raise NotImplementedError()
if self.checkfirst and self.engine.dialect.has_table(self.connection, table.name):
return
- self.append("\nCREATE TABLE " + table.fullname + " (")
+ self.append("\nCREATE TABLE " + self.preparer.format_table(table) + " (")
separator = "\n"
if len(constraint) == 0:
return
self.append(", \n")
- self.append("\tPRIMARY KEY (%s)" % string.join([c.name for c in constraint],', '))
+ self.append("\tPRIMARY KEY (%s)" % string.join([self.preparer.format_column(c) for c in constraint],', '))
def visit_foreign_key_constraint(self, constraint):
self.append(", \n\t ")
if constraint.name is not None:
self.append("CONSTRAINT %s " % constraint.name)
self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % (
- string.join([f.parent.name for f in constraint.elements], ', '),
- list(constraint.elements)[0].column.table.fullname,
- string.join([f.column.name for f in constraint.elements], ', ')
+ string.join([self.preparer.format_column(f.parent) for f in constraint.elements], ', '),
+ self.preparer.format_table(list(constraint.elements)[0].column.table),
+ string.join([self.preparer.format_column(f.column) for f in constraint.elements], ', ')
))
if constraint.ondelete is not None:
self.append(" ON DELETE %s" % constraint.ondelete)
if index.unique:
self.append('UNIQUE ')
self.append('INDEX %s ON %s (%s)' \
- % (index.name, index.table.fullname,
- string.join([c.name for c in index.columns], ', ')))
+ % (index.name, self.preparer.format_table(index.table),
+ string.join([self.preparer.format_column(c) for c in index.columns], ', ')))
self.execute()
-
class ANSISchemaDropper(engine.SchemaIterator):
def __init__(self, engine, proxy, connection=None, checkfirst=False, **params):
super(ANSISchemaDropper, self).__init__(engine, proxy, **params)
self.checkfirst = checkfirst
self.connection = connection
+ self.preparer = self.engine.dialect.preparer()
def visit_index(self, index):
self.append("\nDROP INDEX " + index.name)
# no need to drop them individually
if self.checkfirst and not self.engine.dialect.has_table(self.connection, table.name):
return
- self.append("\nDROP TABLE " + table.fullname)
+ self.append("\nDROP TABLE " + self.preparer.format_table(table))
self.execute()
-
class ANSIDefaultRunner(engine.DefaultRunner):
pass
+
+class ANSIIdentifierPreparer(object):
+ """Transforms identifiers into ANSI-Compliant delimited identifiers where required"""
+ def __init__(self, initial_quote='"', final_quote=None, omit_schema=False):
+ """Constructs a new ANSIIdentifierPreparer object.
+
+ initial_quote - Character that begins a delimited identifier
+ final_quote - Character that ends a delimited identifier. Defaults to initial_quote.
+
+ omit_schema - prevent prepending the schema name. Useful for databases that do not support schemas.
+ """
+ self.initial_quote = initial_quote
+ self.final_quote = final_quote or self.initial_quote
+ self.omit_schema = omit_schema
+
+ def _escape_identifier(self, value):
+ return value.replace('"', '""')
+
+ def _quote_identifier(self, value):
+ return self.initial_quote + self._escape_identifier(value) + self.final_quote
+
+ def _fold_identifier_case(self, value):
+ return value
+ # ANSI SQL calls for the case of all unquoted identifiers to be folded to UPPER.
+ # some tests would need to be rewritten if this is done.
+ #return value.upper()
+
+ def _prepare_table(self, table, use_schema=False):
+ names = []
+ if table.quote:
+ names.append(self._quote_identifier(table.name))
+ else:
+ names.append(self._fold_identifier_case(table.name))
+
+ if not self.omit_schema and use_schema and table.schema:
+ if table.quote_schema:
+ names.insert(0, self._quote_identifier(table.schema))
+ else:
+ names.insert(0, self._fold_identifier_case(table.schema))
+
+ return ".".join(names)
+
+ def _prepare_column(self, column, use_table=True, **kwargs):
+ names = []
+ if column.quote:
+ names.append(self._quote_identifier(column.name))
+ else:
+ names.append(self._fold_identifier_case(column.name))
+
+ if use_table:
+ names.insert(0, self._prepare_table(column.table, **kwargs))
+
+ return ".".join(names)
+
+ def format_table(self, table, use_schema=True):
+ """Prepare a quoted table and schema name"""
+ return self._prepare_table(table, use_schema=use_schema)
+
+ def format_column(self, column):
+ """Prepare a quoted column name"""
+ return self._prepare_column(column, use_table=False)
+
+ def format_column_with_table(self, column):
+ """Prepare a quoted column name with table name"""
+ return self._prepare_column(column)
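A quick illustration of the formatting rules above; the lightweight stand-in objects are hypothetical and simply expose the attributes the preparer reads (name, schema, quote, quote_schema, table):

    class _Ident(object):
        """hypothetical stand-in carrying the attributes the preparer inspects"""
        def __init__(self, **kw):
            self.__dict__.update(kw)

    preparer = ANSIIdentifierPreparer()
    t = _Ident(name='WorstCase1', schema=None, quote=True, quote_schema=False)
    c = _Ident(name='MixedCase', quote=True, table=t)
    print preparer.format_table(t)               # "WorstCase1"
    print preparer.format_column(c)              # "MixedCase"
    print preparer.format_column_with_table(c)   # "WorstCase1"."MixedCase"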
def defaultrunner(self, proxy):
return FBDefaultRunner(self, proxy)
-
+
+ def preparer(self):
+ return FBIdentifierPreparer()
+
class FireBirdDialect(ansisql.ANSIDialect):
def __init__(self, module = None, **params):
global _initialized_kb
class FBSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, **kwargs):
- colspec = column.name
+ colspec = self.preparer.format_column(column)
colspec += " " + column.type.engine_impl(self.engine).get_col_spec()
default = self.get_column_default_string(column)
if default is not None:
def visit_sequence(self, seq):
return self.proxy("SELECT gen_id(" + seq.name + ", 1) FROM rdb$database").fetchone()[0]
+class FBIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
+ def __init__(self):
+ super(FBIdentifierPreparer,self).__init__(omit_schema=True)
dialect = FireBirdDialect
def defaultrunner(self, engine, proxy):
return MSSQLDefaultRunner(engine, proxy)
+ def preparer(self):
+ return MSSQLIdentifierPreparer()
+
def get_default_schema_name(self):
return "dbo"
class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, **kwargs):
- colspec = column.name + " " + column.type.engine_impl(self.engine).get_col_spec()
+ colspec = self.preparer.format_column(column) + " " + column.type.engine_impl(self.engine).get_col_spec()
# install a IDENTITY Sequence if we have an implicit IDENTITY column
if column.primary_key and isinstance(column.type, sqltypes.Integer):
class MSSQLDefaultRunner(ansisql.ANSIDefaultRunner):
pass
+class MSSQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
+ def __init__(self):
+ super(MSSQLIdentifierPreparer, self).__init__(initial_quote='[', final_quote=']')
+ def _escape_identifier(self, value):
+ #TODO: determine MSSQL's escaping rules
+ return value
+ def _fold_identifier_case(self, value):
+ #TODO: determine MSSQL's case folding rules
+ return value
+
dialect = MSSQLDialect
def schemadropper(self, *args, **kwargs):
return MySQLSchemaDropper(*args, **kwargs)
+ def preparer(self):
+ return MySQLIdentifierPreparer()
+
def do_rollback(self, connection):
# some versions of MySQL just dont support rollback() at all....
try:
class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, override_pk=False, first_pk=False):
- colspec = column.name + " " + column.type.engine_impl(self.engine).get_col_spec()
+ colspec = self.preparer.format_column(column) + " " + column.type.engine_impl(self.engine).get_col_spec()
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
self.append("\nDROP INDEX " + index.name + " ON " + index.table.name)
self.execute()
+class MySQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
+ def __init__(self):
+ super(MySQLIdentifierPreparer, self).__init__(initial_quote='`')
+ def _escape_identifier(self, value):
+ #TODO: determine MySQL's escaping rules
+ return value
+ def _fold_identifier_case(self, value):
+ #TODO: determine MySQL's case folding rules
+ return value
+
dialect = MySQLDialect
return PGSchemaDropper(*args, **kwargs)
def defaultrunner(self, engine, proxy):
return PGDefaultRunner(engine, proxy)
+ def preparer(self):
+ return PGIdentifierPreparer()
def get_default_schema_name(self, connection):
if not hasattr(self, '_default_schema_name'):
class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, **kwargs):
- colspec = column.name
+ colspec = self.preparer.format_column(column)
if column.primary_key and not column.foreign_key and isinstance(column.type, sqltypes.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
colspec += " SERIAL"
else:
else:
return None
+class PGIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
+ def _fold_identifier_case(self, value):
+ return value.lower()
+
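Since PGIdentifierPreparer folds unquoted names to lowercase (matching PostgreSQL's own folding), quoting is what preserves mixed case. A small sketch, using the same kind of hypothetical stand-in object as above:

    class _Ident(object):
        def __init__(self, **kw): self.__dict__.update(kw)

    preparer = PGIdentifierPreparer()
    print preparer.format_column(_Ident(name='UPPERCASE', quote=False))   # uppercase
    print preparer.format_column(_Ident(name='MixedCase', quote=True))    # "MixedCase"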
dialect = PGDialect
return SQLiteCompiler(self, statement, bindparams, **kwargs)
def schemagenerator(self, *args, **kwargs):
return SQLiteSchemaGenerator(*args, **kwargs)
+ def preparer(self):
+ return SQLiteIdentifierPreparer()
def create_connect_args(self, url):
filename = url.database or ':memory:'
return ([filename], url.query)
return SQLiteExecutionContext(self)
def last_inserted_ids(self):
return self.context.last_inserted_ids
-
+
def oid_column_name(self):
return "oid"
class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, **kwargs):
- colspec = column.name + " " + column.type.engine_impl(self.engine).get_col_spec()
+ colspec = self.preparer.format_column(column) + " " + column.type.engine_impl(self.engine).get_col_spec()
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
# self.append("\tUNIQUE (%s)" % string.join([c.name for c in constraint],', '))
# else:
# super(SQLiteSchemaGenerator, self).visit_primary_key_constraint(constraint)
-
+
+class SQLiteIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
+ def __init__(self):
+ super(SQLiteIdentifierPreparer, self).__init__(omit_schema=True)
+
dialect = SQLiteDialect
poolclass = pool.SingletonThreadPool
owner=None : optional owning user of this table. useful for databases such as Oracle to aid in table
reflection.
+
+ quote=False : indicates that the Table identifier must be properly escaped and quoted before being sent
+ to the database.
+
+ quote_schema=False : indicates that the schema identifier must be properly escaped and quoted before being sent
+ to the database.
"""
super(Table, self).__init__(name)
self._metadata = metadata
else:
self.fullname = self.name
self.owner = kwargs.pop('owner', None)
+ self.quote = kwargs.pop('quote', False)
+ self.quote_schema = kwargs.pop('quote_schema', False)
self.kwargs = kwargs
def _set_primary_key(self, pk):
specify the same index name will all be included in the index, in the
order of their creation.
+ quote=False : indicates that the Column identifier must be properly escaped and quoted before being sent
+ to the database.
"""
name = str(name) # in case of incoming unicode
super(Column, self).__init__(name, None, type)
self.default = kwargs.pop('default', None)
self.index = kwargs.pop('index', None)
self.unique = kwargs.pop('unique', None)
+ self.quote = kwargs.pop('quote', False)
self.onupdate = kwargs.pop('onupdate', None)
if self.index is not None and self.unique is not None:
raise exceptions.ArgumentError("Column may not define both index and unique")
return self.obj._get_from_objects()
def _make_proxy(self, selectable, name = None):
return self.obj._make_proxy(selectable, name=self.name)
-
+
+legal_characters = util.Set(string.ascii_letters + string.digits + '_')
class ColumnClause(ColumnElement):
"""represents a textual column clause in a SQL statement. May or may not
be bound to an underlying Selectable."""
self.__label = self.__label[0:24] + "_" + hex(random.randint(0, 65535))[2:]
else:
self.__label = self.name
+ self.__label = "".join([x for x in self.__label if x in legal_characters])
return self.__label
_label = property(_get_label)
def accept_visitor(self, visitor):
# assorted round-trip tests
'sql.query',
+ 'sql.quote',
# defaults, sequences (postgres/oracle)
'sql.defaults',
--- /dev/null
+from testbase import PersistTest
+import testbase
+from sqlalchemy import *
+
+class QuoteTest(PersistTest):
+ def setUpAll(self):
+ # TODO: figure out which databases/which identifiers allow special characters to be used,
+ # such as: spaces, quote characters, punctuation characters, set up tests for those as
+ # well.
+ global table1, table2, table3
+ metadata = BoundMetaData(testbase.db)
+ table1 = Table('WorstCase1', metadata,
+ Column('lowercase', Integer, primary_key=True),
+ Column('UPPERCASE', Integer),
+ Column('MixedCase', Integer, quote=True),
+ Column('ASC', Integer, quote=True),
+ quote=True)
+ table2 = Table('WorstCase2', metadata,
+ Column('desc', Integer, quote=True, primary_key=True),
+ Column('Union', Integer, quote=True),
+ Column('MixedCase', Integer, quote=True),
+ quote=True)
+ table1.create()
+ table2.create()
+
+ def tearDown(self):
+ table1.delete().execute()
+ table2.delete().execute()
+
+ def tearDownAll(self):
+ table1.drop()
+ table2.drop()
+
+ def testbasic(self):
+ table1.insert().execute({'lowercase':1,'UPPERCASE':2,'MixedCase':3,'ASC':4},
+ {'lowercase':2,'UPPERCASE':2,'MixedCase':3,'ASC':4},
+ {'lowercase':4,'UPPERCASE':3,'MixedCase':2,'ASC':1})
+ table2.insert().execute({'desc':1,'Union':2,'MixedCase':3},
+ {'desc':2,'Union':2,'MixedCase':3},
+ {'desc':4,'Union':3,'MixedCase':2})
+
+ res1 = select([table1.c.lowercase, table1.c.UPPERCASE, table1.c.MixedCase, table1.c.ASC]).execute().fetchall()
+ print res1
+ assert(res1==[(1,2,3,4),(2,2,3,4),(4,3,2,1)])
+
+ res2 = select([table2.c.desc, table2.c.Union, table2.c.MixedCase]).execute().fetchall()
+ print res2
+ assert(res2==[(1,2,3),(2,2,3),(4,3,2)])
+
+ def testreflect(self):
+ meta2 = BoundMetaData(testbase.db)
+ t2 = Table('WorstCase2', meta2, autoload=True, quote=True)
+ assert t2.c.has_key('MixedCase')
+
+ def testlabels(self):
+ table1.insert().execute({'lowercase':1,'UPPERCASE':2,'MixedCase':3,'ASC':4},
+ {'lowercase':2,'UPPERCASE':2,'MixedCase':3,'ASC':4},
+ {'lowercase':4,'UPPERCASE':3,'MixedCase':2,'ASC':1})
+ table2.insert().execute({'desc':1,'Union':2,'MixedCase':3},
+ {'desc':2,'Union':2,'MixedCase':3},
+ {'desc':4,'Union':3,'MixedCase':2})
+
+ res1 = select([table1.c.lowercase, table1.c.UPPERCASE, table1.c.MixedCase, table1.c.ASC], use_labels=True).execute().fetchall()
+ print res1
+ assert(res1==[(1,2,3,4),(2,2,3,4),(4,3,2,1)])
+
+ res2 = select([table2.c.desc, table2.c.Union, table2.c.MixedCase], use_labels=True).execute().fetchall()
+ print res2
+ assert(res2==[(1,2,3),(2,2,3),(4,3,2)])
+
+if __name__ == "__main__":
+ testbase.main()