From d0a3313edf7a12372cba12efefa53a7fb9b999af Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 21 Aug 2006 00:37:34 +0000 Subject: [PATCH] - postgres reflection moved to use pg_schema tables, can be overridden with use_information_schema=True argument to create_engine [ticket:60], [ticket:71] - added natural_case argument to Table, Column, semi-experimental flag for use with table reflection to help with quoting rules [ticket:155] --- CHANGES | 6 + lib/sqlalchemy/ansisql.py | 65 +++++++--- lib/sqlalchemy/databases/firebird.py | 6 +- lib/sqlalchemy/databases/mssql.py | 6 +- lib/sqlalchemy/databases/mysql.py | 6 +- lib/sqlalchemy/databases/postgres.py | 180 ++++++++++++++++++++++++++- lib/sqlalchemy/databases/sqlite.py | 6 +- lib/sqlalchemy/schema.py | 16 ++- test/engine/reflection.py | 1 + 9 files changed, 258 insertions(+), 34 deletions(-) diff --git a/CHANGES b/CHANGES index 5fb83098e5..5887b631dd 100644 --- a/CHANGES +++ b/CHANGES @@ -19,6 +19,12 @@ unit of work seeks to flush() them as part of a relationship.. - [ticket:280] statement execution supports using the same BindParam object more than once in an expression; simplified handling of positional parameters. nice job by Bill Noon figuring out the basic idea. +- postgres reflection moved to use pg_schema tables, can be overridden +with use_information_schema=True argument to create_engine +[ticket:60], [ticket:71] +- added natural_case argument to Table, Column, semi-experimental +flag for use with table reflection to help with quoting rules +[ticket:155] 0.2.7 - quoting facilities set up so that database-specific quoting can be diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 031c633284..f4b0852e6f 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -9,7 +9,7 @@ in the sql module.""" from sqlalchemy import schema, sql, engine, util import sqlalchemy.engine.default as default -import string, re, sets +import string, re, sets, weakref ANSI_FUNCS = sets.ImmutableSet([ 'CURRENT_TIME', @@ -27,6 +27,10 @@ def create_engine(): return engine.ComposedSQLEngine(None, ANSIDialect()) class ANSIDialect(default.DefaultDialect): + def __init__(self, **kwargs): + super(ANSIDialect,self).__init__(**kwargs) + self._identifier_cache = weakref.WeakKeyDictionary() + def connect_args(self): return ([],{}) @@ -46,7 +50,7 @@ class ANSIDialect(default.DefaultDialect): """return an IdenfifierPreparer. This object is used to format table and column names including proper quoting and case conventions.""" - return ANSIIdentifierPreparer() + return ANSIIdentifierPreparer(self) class ANSICompiler(sql.Compiled): """default implementation of Compiled, which compiles ClauseElements into ANSI-compliant SQL strings.""" @@ -77,6 +81,7 @@ class ANSICompiler(sql.Compiled): self.positiontup = [] self.preparer = dialect.preparer() + def after_compile(self): # this re will search for params like :param # it has a negative lookbehind for an extra ':' so that it doesnt match @@ -704,8 +709,8 @@ class ANSIDefaultRunner(engine.DefaultRunner): pass class ANSIIdentifierPreparer(schema.SchemaVisitor): - """Transforms identifiers of SchemaItems into ANSI-Compliant delimited identifiers where required""" - def __init__(self, initial_quote='"', final_quote=None, omit_schema=False): + """handles quoting and case-folding of identifiers based on options""" + def __init__(self, dialect, initial_quote='"', final_quote=None, omit_schema=False): """Constructs a new ANSIIdentifierPreparer object. initial_quote - Character that begins a delimited identifier @@ -713,12 +718,12 @@ class ANSIIdentifierPreparer(schema.SchemaVisitor): omit_schema - prevent prepending schema name. useful for databases that do not support schemae """ + self.dialect = dialect self.initial_quote = initial_quote self.final_quote = final_quote or self.initial_quote self.omit_schema = omit_schema self.strings = {} self.__visited = util.Set() - def _escape_identifier(self, value): """escape an identifier. @@ -740,31 +745,59 @@ class ANSIIdentifierPreparer(schema.SchemaVisitor): # some tests would need to be rewritten if this is done. #return value.upper() - def _requires_quotes(self, value): + def _requires_quotes(self, value, natural_case): """return true if the given identifier requires quoting.""" return False - + + def __requires_quotes_cached(self, value, natural_case): + try: + return self.dialect._identifier_cache[(value, natural_case)] + except KeyError: + result = self._requires_quotes(value, natural_case) + self.dialect._identifier_cache[(value, natural_case)] = result + return result + def visit_table(self, table): if table in self.__visited: return - if table.quote or self._requires_quotes(table.name): + + # cache the results within the dialect, weakly keyed to the table + try: + (self.strings[table], self.strings[(table, 'schema')]) = self.dialect._identifier_cache[table] + return + except KeyError: + pass + + if table.quote or self._requires_quotes(table.name, table.natural_case): self.strings[table] = self._quote_identifier(table.name) else: - self.strings[table] = table.name # TODO: case folding ? + self.strings[table] = table.name if table.schema: - if table.quote_schema or self._requires_quotes(table.quote_schema): + if table.quote_schema or self._requires_quotes(table.schema, table.natural_case_schema): self.strings[(table, 'schema')] = self._quote_identifier(table.schema) else: - self.strings[(table, 'schema')] = table.schema # TODO: case folding ? - + self.strings[(table, 'schema')] = table.schema + else: + self.strings[(table,'schema')] = None + self.dialect._identifier_cache[table] = (self.strings[table], self.strings[(table, 'schema')]) + def visit_column(self, column): if column in self.__visited: return - if column.quote or self._requires_quotes(column.name): + + # cache the results within the dialect, weakly keyed to the column + try: + self.strings[column] = self.dialect._identifier_cache[column] + return + except KeyError: + pass + + if column.quote or self._requires_quotes(column.name, column.natural_case): self.strings[column] = self._quote_identifier(column.name) else: - self.strings[column] = column.name # TODO: case folding ? - + self.strings[column] = column.name + self.dialect._identifier_cache[column] = self.strings[column] + def __start_visit(self, obj): if obj in self.__visited: return @@ -774,7 +807,7 @@ class ANSIIdentifierPreparer(schema.SchemaVisitor): def __prepare_table(self, table, use_schema=False): self.__start_visit(table) - if not self.omit_schema and use_schema and (table, 'schema') in self.strings: + if not self.omit_schema and use_schema and self.strings.get((table, 'schema'), None) is not None: return self.strings[(table, 'schema')] + "." + self.strings.get(table, table.name) else: return self.strings.get(table, table.name) diff --git a/lib/sqlalchemy/databases/firebird.py b/lib/sqlalchemy/databases/firebird.py index 8c61a7c128..67214313f1 100644 --- a/lib/sqlalchemy/databases/firebird.py +++ b/lib/sqlalchemy/databases/firebird.py @@ -99,7 +99,7 @@ class FireBirdExecutionContext(default.DefaultExecutionContext): return FBDefaultRunner(self, proxy) def preparer(self): - return FBIdentifierPreparer() + return FBIdentifierPreparer(self) class FireBirdDialect(ansisql.ANSIDialect): def __init__(self, module = None, **params): @@ -381,7 +381,7 @@ class FBDefaultRunner(ansisql.ANSIDefaultRunner): return self.proxy("SELECT gen_id(" + seq.name + ", 1) FROM rdb$database").fetchone()[0] class FBIdentifierPreparer(ansisql.ANSIIdentifierPreparer): - def __init__(self): - super(FBIdentifierPreparer,self).__init__(omit_schema=True) + def __init__(self, dialect): + super(FBIdentifierPreparer,self).__init__(dialect, omit_schema=True) dialect = FireBirdDialect diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index 3c52c522ab..08a555611c 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -273,7 +273,7 @@ class MSSQLDialect(ansisql.ANSIDialect): return MSSQLDefaultRunner(engine, proxy) def preparer(self): - return MSSQLIdentifierPreparer() + return MSSQLIdentifierPreparer(self) def get_default_schema_name(self): return "dbo" @@ -546,8 +546,8 @@ class MSSQLDefaultRunner(ansisql.ANSIDefaultRunner): pass class MSSQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer): - def __init__(self): - super(MSSQLIdentifierPreparer, self).__init__(initial_quote='[', final_quote=']') + def __init__(self, dialect): + super(MSSQLIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']') def _escape_identifier(self, value): #TODO: determin MSSQL's escapeing rules return value diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index be704ef2a8..825d779e12 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -295,7 +295,7 @@ class MySQLDialect(ansisql.ANSIDialect): return MySQLSchemaDropper(*args, **kwargs) def preparer(self): - return MySQLIdentifierPreparer() + return MySQLIdentifierPreparer(self) def do_rollback(self, connection): # some versions of MySQL just dont support rollback() at all.... @@ -453,8 +453,8 @@ class MySQLSchemaDropper(ansisql.ANSISchemaDropper): self.execute() class MySQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer): - def __init__(self): - super(MySQLIdentifierPreparer, self).__init__(initial_quote='`') + def __init__(self, dialect): + super(MySQLIdentifierPreparer, self).__init__(dialect, initial_quote='`') def _escape_identifier(self, value): #TODO: determin MySQL's escaping rules return value diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index 9adc42d023..1ce10dd4c6 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -15,6 +15,8 @@ import sqlalchemy.ansisql as ansisql import sqlalchemy.types as sqltypes import sqlalchemy.exceptions as exceptions import information_schema as ischema +from sqlalchemy import * +import re try: import mx.DateTime.DateTime as mxDateTime @@ -151,8 +153,11 @@ pg2_ischema_names = { 'float' : PGFloat, 'real' : PGFloat, 'double precision' : PGFloat, + 'timestamp' : PG2DateTime, 'timestamp with time zone' : PG2DateTime, 'timestamp without time zone' : PG2DateTime, + 'time with time zone' : PG2Time, + 'time without time zone' : PG2Time, 'date' : PG2Date, 'time': PG2Time, 'bytea' : PGBinary, @@ -166,6 +171,11 @@ pg1_ischema_names.update({ 'time' : PG1Time }) +reserved_words = util.Set(['all', 'analyse', 'analyze', 'and', 'any', 'array', 'as', 'asc', 'asymmetric', 'authorization', 'between', 'binary', 'both', 'case', 'cast', 'check', 'collate', 'column', 'constraint', 'create', 'cross', 'current_date', 'current_role', 'current_time', 'current_timestamp', 'current_user', 'default', 'deferrable', 'desc', 'distinct', 'do', 'else', 'end', 'except', 'false', 'for', 'foreign', 'freeze', 'from', 'full', 'grant', 'group', 'having', 'ilike', 'in', 'initially', 'inner', 'intersect', 'into', 'is', 'isnull', 'join', 'leading', 'left', 'like', 'limit', 'localtime', 'localtimestamp', 'natural', 'new', 'not', 'notnull', 'null', 'off', 'offset', 'old', 'on', 'only', 'or', 'order', 'outer', 'overlaps', 'placing', 'primary', 'references', 'right', 'select', 'session_user', 'similar', 'some', 'symmetric', 'table', 'then', 'to', 'trailing', 'true', 'union', 'unique', 'user', 'using', 'verbose', 'when', 'where']) + +legal_characters = util.Set(string.ascii_lowercase + string.digits + '_$') +illegal_initial_characters = util.Set(string.digits + '$') + def engine(opts, **params): return PGSQLEngine(opts, **params) @@ -197,7 +207,7 @@ class PGExecutionContext(default.DefaultExecutionContext): self._last_inserted_ids = [v for v in row] class PGDialect(ansisql.ANSIDialect): - def __init__(self, module=None, use_oids=False, **params): + def __init__(self, module=None, use_oids=False, use_information_schema=False, **params): self.use_oids = use_oids if module is None: #if psycopg is None: @@ -214,6 +224,7 @@ class PGDialect(ansisql.ANSIDialect): except: self.version = 1 ansisql.ANSIDialect.__init__(self, **params) + self.use_information_schema = use_information_schema # produce consistent paramstyle even if psycopg2 module not present if self.module is None: self.paramstyle = 'pyformat' @@ -246,7 +257,7 @@ class PGDialect(ansisql.ANSIDialect): def defaultrunner(self, engine, proxy): return PGDefaultRunner(engine, proxy) def preparer(self): - return PGIdentifierPreparer() + return PGIdentifierPreparer(self) def get_default_schema_name(self, connection): if not hasattr(self, '_default_schema_name'): @@ -293,7 +304,155 @@ class PGDialect(ansisql.ANSIDialect): else: ischema_names = pg1_ischema_names - ischema.reflecttable(connection, table, ischema_names) + if self.use_information_schema: + ischema.reflecttable(connection, table, ischema_names) + else: + preparer = self.preparer() + if table.schema is not None: + current_schema = table.schema + else: + current_schema = connection.default_schema_name() + + ## information schema in pg suffers from too many permissions' restrictions + ## let us find out at the pg way what is needed... + + SQL_COLS = """ + SELECT a.attname, + pg_catalog.format_type(a.atttypid, a.atttypmod), + (SELECT substring(d.adsrc for 128) FROM pg_catalog.pg_attrdef d + WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum AND a.atthasdef) + AS DEFAULT, + a.attnotnull, a.attnum + FROM pg_catalog.pg_attribute a + WHERE a.attrelid = ( + SELECT c.oid + FROM pg_catalog.pg_class c + LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace + WHERE (n.nspname = :schema OR pg_catalog.pg_table_is_visible(c.oid)) + AND c.relname = :table_name AND (c.relkind = 'r' OR c.relkind = 'v') + ) AND a.attnum > 0 AND NOT a.attisdropped + ORDER BY a.attnum + """ + + s = text(SQL_COLS ) + c = connection.execute(s, table_name=table.name, schema=current_schema) + found_table = False + while True: + row = c.fetchone() + if row is None: + break + found_table = True + name = row['attname'] + natural_case = preparer._is_natural_case(name) + ## strip (30) from character varying(30) + attype = re.search('([^\(]+)', row['format_type']).group(1) + + nullable = row['attnotnull'] == False + try: + charlen = re.search('\(([\d,]+)\)',row['format_type']).group(1) + except: + charlen = None + + numericprec = None + numericscale = None + default = row['default'] + if attype == 'numeric': + numericprec, numericscale = charlen.split(',') + charlen = None + if attype == 'double precision': + numericprec, numericscale = (53, None) + charlen = None + if attype == 'integer': + numericprec, numericscale = (32, 0) + charlen = None + + args = [] + for a in (charlen, numericprec, numericscale): + if a is not None: + args.append(int(a)) + + coltype = ischema_names[attype] + coltype = coltype(*args) + colargs= [] + if default is not None: + colargs.append(PassiveDefault(sql.text(default))) + table.append_item(schema.Column(name, coltype, nullable=nullable, natural_case=natural_case, *colargs)) + + + if not found_table: + raise exceptions.NoSuchTableError(table.name) + + # Primary keys + PK_SQL = """ + SELECT attname FROM pg_attribute + WHERE attrelid = ( + SELECT indexrelid FROM pg_index i, pg_class c, pg_namespace n + WHERE n.nspname = :schema AND c.relname = :table_name + AND c.oid = i.indrelid AND n.oid = c.relnamespace + AND i.indisprimary = 't' ) ; + """ + t = text(PK_SQL) + c = connection.execute(t, table_name=table.name, schema=current_schema) + while True: + row = c.fetchone() + if row is None: + break + pk = row[0] + table.c[pk]._set_primary_key() + + # Foreign keys + FK_SQL = """ + SELECT conname, pg_catalog.pg_get_constraintdef(oid, true) as condef + FROM pg_catalog.pg_constraint r + WHERE r.conrelid = ( + SELECT c.oid FROM pg_catalog.pg_class c + LEFT JOIN pg_catalog.pg_namespace n + ON n.oid = c.relnamespace + WHERE c.relname = :table_name + AND pg_catalog.pg_table_is_visible(c.oid)) + AND r.contype = 'f' ORDER BY 1 + + """ + + t = text(FK_SQL) + c = connection.execute(t, table_name=table.name) + while True: + row = c.fetchone() + if row is None: + break + + identifier = '(?:[a-z_][a-z0-9_$]+|"(?:[^"]|"")+")' + identifier_group = '%s(?:, %s)*' % (identifier, identifier) + identifiers = '(%s)(?:, (%s))*' % (identifier, identifier) + f = re.compile(identifiers) + # FOREIGN KEY (mail_user_id,"Mail_User_ID2") REFERENCES "mYschema".euro_user(user_id,"User_ID2") + foreign_key_pattern = 'FOREIGN KEY \((%s)\) REFERENCES (?:(%s)\.)?(%s)\((%s)\)' % (identifier_group, identifier, identifier, identifier_group) + p = re.compile(foreign_key_pattern) + + m = p.search(row['condef']) + (constrained_columns, referred_schema, referred_table, referred_columns) = m.groups() + + constrained_columns = [preparer._unquote_identifier(x) for x in f.search(constrained_columns).groups() if x] + if referred_schema: + referred_schema = preparer._unquote_identifier(referred_schema) + referred_table = preparer._unquote_identifier(referred_table) + referred_columns = [preparer._unquote_identifier(x) for x in f.search(referred_columns).groups() if x] + + natural_case = preparer._is_natural_case(referred_table) + + refspec = [] + if referred_schema is not None: + natural_case_schema = preparer._is_natural_case(referred_schema) + schema.Table(referred_table, table.metadata, autoload=True, schema=referred_schema, + autoload_with=connection, natural_case=natural_case, natural_case_schema = natural_case_schema) + for column in referred_columns: + refspec.append(".".join([referred_schema, referred_table, column])) + else: + schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection, natural_case=natural_case) + for column in referred_columns: + refspec.append(".".join([referred_table, column])) + + table.append_item(ForeignKeyConstraint(constrained_columns, refspec, row['conname'])) class PGCompiler(ansisql.ANSICompiler): @@ -392,5 +551,18 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner): class PGIdentifierPreparer(ansisql.ANSIIdentifierPreparer): def _fold_identifier_case(self, value): return value.lower() - + def _requires_quotes(self, value, natural_case): + if natural_case: + value = self._fold_identifier_case(str(value)) + retval = bool(len([x for x in str(value) if x not in legal_characters])) + if not retval and (value[0] in illegal_initial_characters or value in reserved_words): + retval = True + return retval + def _unquote_identifier(self, value): + if value[0] == self.initial_quote: + value = value[1:-1].replace('""','"') + return value + def _is_natural_case(self, value): + return self._fold_identifier_case(value) == value + dialect = PGDialect diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py index e3bc0b9c98..0dfd83eeba 100644 --- a/lib/sqlalchemy/databases/sqlite.py +++ b/lib/sqlalchemy/databases/sqlite.py @@ -142,7 +142,7 @@ class SQLiteDialect(ansisql.ANSIDialect): def schemagenerator(self, *args, **kwargs): return SQLiteSchemaGenerator(*args, **kwargs) def preparer(self): - return SQLiteIdentifierPreparer() + return SQLiteIdentifierPreparer(self) def create_connect_args(self, url): filename = url.database or ':memory:' return ([filename], url.query) @@ -300,8 +300,8 @@ class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator): # super(SQLiteSchemaGenerator, self).visit_primary_key_constraint(constraint) class SQLiteIdentifierPreparer(ansisql.ANSIIdentifierPreparer): - def __init__(self): - super(SQLiteIdentifierPreparer, self).__init__(omit_schema=True) + def __init__(self, dialect): + super(SQLiteIdentifierPreparer, self).__init__(dialect, omit_schema=True) dialect = SQLiteDialect poolclass = pool.SingletonThreadPool diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index bba73ef88f..2bf1627dd0 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -144,10 +144,16 @@ class Table(SchemaItem, sql.TableClause): reflection. quote=False : indicates that the Table identifier must be properly escaped and quoted before being sent - to the database. + to the database. This flag overrides all other quoting behavior. quote_schema=False : indicates that the Namespace identifier must be properly escaped and quoted before being sent - to the database. + to the database. This flag overrides all other quoting behavior. + + natural_case=True : indicates that the identifier should be interpreted by the database in the natural case for identifiers. + Mixed case is not sufficient to cause this identifier to be quoted; it must contain an illegal character. + + natural_case_schema=True : indicates that the identifier should be interpreted by the database in the natural case for identifiers. + Mixed case is not sufficient to cause this identifier to be quoted; it must contain an illegal character. """ super(Table, self).__init__(name) self._metadata = metadata @@ -163,6 +169,8 @@ class Table(SchemaItem, sql.TableClause): self.owner = kwargs.pop('owner', None) self.quote = kwargs.pop('quote', False) self.quote_schema = kwargs.pop('quote_schema', False) + self.natural_case = kwargs.pop('natural_case', True) + self.natural_case_schema = kwargs.pop('natural_case_schema', True) self.kwargs = kwargs def _set_primary_key(self, pk): @@ -332,6 +340,9 @@ class Column(SchemaItem, sql.ColumnClause): quote=False : indicates that the Column identifier must be properly escaped and quoted before being sent to the database. + + natural_case=True : indicates that the identifier should be interpreted by the database in the natural case for identifiers. + Mixed case is not sufficient to cause this identifier to be quoted; it must contain an illegal character. """ name = str(name) # in case of incoming unicode super(Column, self).__init__(name, None, type) @@ -344,6 +355,7 @@ class Column(SchemaItem, sql.ColumnClause): self.index = kwargs.pop('index', None) self.unique = kwargs.pop('unique', None) self.quote = kwargs.pop('quote', False) + self.natural_case = kwargs.pop('natural_case', True) self.onupdate = kwargs.pop('onupdate', None) if self.index is not None and self.unique is not None: raise exceptions.ArgumentError("Column may not define both index and unique") diff --git a/test/engine/reflection.py b/test/engine/reflection.py index dd8a52a9a2..c0aec04a4d 100644 --- a/test/engine/reflection.py +++ b/test/engine/reflection.py @@ -354,6 +354,7 @@ class SchemaTest(PersistTest): table1.accept_schema_visitor(gen) table2.accept_schema_visitor(gen) buf = buf.getvalue() + print buf assert buf.index("CREATE TABLE someschema.table1") > -1 assert buf.index("CREATE TABLE someschema.table2") > -1 -- 2.47.2