From b4038adbd0bde05de749a3e18cd13fee919d0076 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 31 Oct 2005 02:11:16 +0000 Subject: [PATCH] --- lib/sqlalchemy/databases/postgres.py | 96 +++++++++++++++++----------- lib/sqlalchemy/engine.py | 6 +- lib/sqlalchemy/schema.py | 37 ++++++++--- lib/sqlalchemy/types.py | 2 + lib/sqlalchemy/util.py | 7 +- test/engines.py | 31 +++++---- 6 files changed, 114 insertions(+), 65 deletions(-) diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index ef6cac24ae..67ee1d4fd6 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -104,6 +104,13 @@ gen_column_constraints = schema.Table("constraint_column_usage", generic_engine, Column("constraint_name", String), schema="information_schema") +gen_key_constraints = schema.Table("key_column_usage", generic_engine, + Column("table_schema", String), + Column("table_name", String), + Column("column_name", String), + Column("constraint_name", String), + schema="information_schema") + def engine(opts, **params): return PGSQLEngine(opts, **params) @@ -138,6 +145,11 @@ class PGSQLEngine(ansisql.ANSISQLEngine): def reflecttable(self, table): raise "not implemented" + + def get_default_schema_name(self): + if not hasattr(self, '_default_schema_name'): + self._default_schema_name = text("select current_schema()", self).scalar() + return self._default_schema_name def last_inserted_ids(self): # if we used sequences or already had all values for the last inserted row, @@ -205,65 +217,73 @@ class PGSQLEngine(ansisql.ANSISQLEngine): return self.module def reflecttable(self, table): - columns = gen_columns.toengine(table.engine) - constraints = gen_constraints.toengine(table.engine) - column_constraints = gen_column_constraints.toengine(table.engine) - - s = select([columns, constraints.c.constraint_type], - columns.c.table_name==table.name, - order_by=[columns.c.ordinal_position]) - - s.append_from(sql.outerjoin(columns, column_constraints, - sql.and_( - columns.c.table_name==column_constraints.c.table_name, - columns.c.table_schema==column_constraints.c.table_schema, - columns.c.column_name==column_constraints.c.column_name, - )).outerjoin(constraints, - sql.and_( - column_constraints.c.table_schema==constraints.c.table_schema, - column_constraints.c.constraint_name==constraints.c.constraint_name, - constraints.c.constraint_type=='PRIMARY KEY' - ))) + columns = gen_columns.toengine(self) + constraints = gen_constraints.toengine(self) + column_constraints = gen_column_constraints.toengine(self) + key_constraints = gen_key_constraints.toengine(self) if table.schema is not None: - s.append_whereclause(columns.c.table_schema==table.schema) + current_schema = table.schema else: - current_schema = text("select current_schema()", table.engine).scalar() - s.append_whereclause(columns.c.table_schema==current_schema) - + current_schema = self.get_default_schema_name() + + s = select([columns], + sql.and_(columns.c.table_name==table.name, + columns.c.table_schema==current_schema), + order_by=[columns.c.ordinal_position]) + c = s.execute() while True: row = c.fetchone() if row is None: break #print "row! 
" + repr(row) - (name, type, nullable, primary_key, charlen, numericprec, numericscale) = ( + (name, type, nullable, charlen, numericprec, numericscale) = ( row[columns.c.column_name], row[columns.c.data_type], - not row[columns.c.is_nullable], - row[constraints.c.constraint_type] is not None, + row[columns.c.is_nullable] == 'YES', row[columns.c.character_maximum_length], row[columns.c.numeric_precision], row[columns.c.numeric_scale], ) - #match = re.match(r'(\w+)(\(.*?\))?', type) - #coltype = match.group(1) - #args = match.group(2) - - #print "coltype: " + repr(coltype) + " args: " + repr(args) + args = [] + for a in (charlen, numericprec, numericscale): + if a is not None: + args.append(a) coltype = ischema_names[type] - table.append_item(schema.Column(name, coltype, primary_key = primary_key, nullable = nullable)) - return - c = self.execute("PRAGMA foreign_key_list(" + table.name + ")", {}) + #print "coltype " + repr(coltype) + " args " + repr(args) + coltype = coltype(*args) + table.append_item(schema.Column(name, coltype, nullable = nullable)) + + s = select([ + constraints.c.constraint_type, + column_constraints, + key_constraints + ], + sql.and_( + key_constraints.c.constraint_name==column_constraints.c.constraint_name, + column_constraints.c.constraint_name==constraints.c.constraint_name, + constraints.c.table_name==table.name, constraints.c.table_schema==current_schema) + , use_labels=True) + c = s.execute() while True: row = c.fetchone() if row is None: break - (tablename, localcol, remotecol) = (row[2], row[3], row[4]) - #print "row! " + repr(row) - remotetable = Table(tablename, self, autoload = True) - table.c[localcol].foreign_key = schema.ForeignKey(remotetable.c[remotecol]) + (type, constrained_column, referred_schema, referred_table, referred_column) = ( + row[constraints.c.constraint_type], + row[key_constraints.c.column_name], + row[column_constraints.c.table_schema], + row[column_constraints.c.table_name], + row[column_constraints.c.column_name] + ) + print "type %s on column %s to remote %s.%s.%s" % (type, constrained_column, referred_schema, referred_table, referred_column) + if type=='PRIMARY KEY': + table.c[constrained_column]._set_primary_key() + elif type=='FOREIGN KEY': + remotetable = Table(referred_table, self, autoload = True, schema=referred_schema) + table.c[constrained_column].foreign_key = schema.ForeignKey(remotetable.c[referred_column]) class PGCompiler(ansisql.ANSICompiler): def bindparam_string(self, name): diff --git a/lib/sqlalchemy/engine.py b/lib/sqlalchemy/engine.py index 0175a1c7bb..0dc2e267f8 100644 --- a/lib/sqlalchemy/engine.py +++ b/lib/sqlalchemy/engine.py @@ -120,6 +120,9 @@ class SQLEngine(schema.SchemaEngine): """returns a new sql.ColumnImpl object to correspond to the given Column object.""" return sql.ColumnImpl(column) + def get_default_schema_name(self): + return None + def last_inserted_ids(self): """returns a thread-local list of the primary keys for the last insert statement executed. This does not apply to straight textual clauses; only to sql.Insert objects compiled against a schema.Table object, which are executed via statement.execute(). The order of items in the list is the same as that of the Table's 'primary_keys' attribute.""" @@ -297,7 +300,8 @@ class ResultProxy: rec = (typemap.get(item[0], types.NULLTYPE), i) else: rec = (types.NULLTYPE, i) - self.props[item[0].lower()] = rec + if self.props.setdefault(item[0].lower(), rec) is not rec: + raise "Duplicate column name '%s' in result set! 
use use_labels on select statement" % (item[0].lower()) self.props[i] = rec i+=1 diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 825fbe4a79..6d8ee22592 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -47,22 +47,31 @@ class SchemaItem(object): raise AttributeError(key) return getattr(self._impl, key) - +def _get_table_key(engine, name, schema): + if schema is not None and schema == engine.get_default_schema_name(): + schema = None + if schema is None: + return name + else: + return schema + "." + name + class TableSingleton(type): def __call__(self, name, engine, *args, **kwargs): try: - table = engine.tables[name] + schema = kwargs.get('schema', None) + key = _get_table_key(engine, name, schema) + table = engine.tables[key] if len(args): if kwargs.get('redefine', False): table.reload_values(*args) else: - raise "Table '%s' is already defined. specify 'redefine=True' to remap columns" % name + raise "Table '%s.%s' is already defined. specify 'redefine=True' to remap columns" % (schema, name) return table except KeyError: if kwargs.get('mustexist', False): - raise "Table '%s' not defined" % name + raise "Table '%s.%s' not defined" % (schema, name) table = type.__call__(self, name, engine, *args, **kwargs) - engine.tables[name] = table + engine.tables[key] = table # load column definitions from the database if 'autoload' is defined # we do it after the table is in the singleton dictionary to support # circular foreign keys @@ -86,7 +95,7 @@ class Table(SchemaItem): self._impl = self.engine.tableimpl(self) self._init_items(*args) self.schema = kwargs.get('schema', None) - if self.schema: + if self.schema is not None: self.fullname = "%s.%s" % (self.schema, self.name) else: self.fullname = self.name @@ -112,15 +121,18 @@ class Table(SchemaItem): c.accept_visitor(visitor) return visitor.visit_table(self) - def toengine(self, engine): + def toengine(self, engine, schema=None): """returns a singleton instance of this Table with a different engine""" try: - return engine.tables[self.name] + if schema is None: + schema = self.schema + key = _get_table_key(engine, self.name, schema) + return engine.tables[key] except: args = [] for c in self.columns: args.append(c.copy()) - return Table(self.name, engine, schema=self.schema, *args) + return Table(self.name, engine, schema=schema, *args) class Column(SchemaItem): """represents a column in a database table.""" @@ -138,7 +150,12 @@ class Column(SchemaItem): original = property(lambda s: s._orig or s) engine = property(lambda s: s.table.engine) - + + def _set_primary_key(self): + self.primary_key = True + self.nullable = False + self.table.primary_keys.append(self) + def _set_parent(self, table): if not self.hidden: table.columns[self.key] = self diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index c010ebbb0b..9a9e5423f2 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -46,6 +46,8 @@ def adapt_type(typeobj, colspecs): return typeobj.adapt(typeobj.__class__) class NullTypeEngine(TypeEngine): + def __init__(self, *args, **kwargs): + pass def get_col_spec(self): raise NotImplementedError() def convert_bind_param(self, value): diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index bae5d6bad5..b67eaa9d58 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -359,8 +359,8 @@ class DependencySorter(object): while n.parent is not None: n = n.parent return n - def get_highest_sibling(self, node): - """returns the highest ancestor node of this one which is 
either the root node, or the common parent of this node and the given node""" + def get_sibling_ancestor(self, node): + """returns the node which is an ancestor of this node and is a sibling of the given node, or else returns this node's root node.""" n = self while n.parent is not None and n.parent is not node.parent: n = n.parent @@ -376,6 +376,7 @@ class DependencySorter(object): def __init__(self, tuples, allitems): self.tuples = tuples self.allitems = allitems + def sort(self): (tuples, allitems) = (self.tuples, self.allitems) @@ -413,7 +414,7 @@ class DependencySorter(object): raise "Circular dependency detected" elif not childnode.is_descendant_of(parentnode): # if relationship doesnt exist, connect nodes together - root = childnode.get_highest_sibling(parentnode) + root = childnode.get_sibling_ancestor(parentnode) parentnode.append(root) # now we have a collection of subtrees which represent dependencies. diff --git a/test/engines.py b/test/engines.py index 291cafabd6..0ffa97f954 100644 --- a/test/engines.py +++ b/test/engines.py @@ -2,6 +2,7 @@ import sqlalchemy.ansisql as ansisql import sqlalchemy.databases.postgres as postgres import sqlalchemy.databases.oracle as oracle +import sqlalchemy.databases.sqlite as sqllite db = ansisql.engine() @@ -14,9 +15,11 @@ import unittest, re class EngineTest(PersistTest): - def testsqlitetableops(self): - import sqlalchemy.databases.sqlite as sqllite -# db = sqllite.engine(':memory:', {}, echo = testbase.echo) + def testsqlite(self): + db = sqllite.engine(':memory:', {}, echo = testbase.echo) + self.do_tableops(db) + + def testpostgres(self): db = postgres.engine({'database':'test', 'host':'127.0.0.1', 'user':'scott', 'password':'tiger'}, echo = testbase.echo) self.do_tableops(db) @@ -30,7 +33,7 @@ class EngineTest(PersistTest): Column('test3', TEXT), Column('test4', DECIMAL, nullable = False), Column('test5', TIMESTAMP), - Column('parent_user_id', INT, foreign_key = ForeignKey('users.user_id')), + Column('parent_user_id', INT, ForeignKey('users.user_id')), Column('test6', DATETIME, nullable = False), Column('test7', CLOB), Column('test8', BLOB), @@ -39,13 +42,10 @@ class EngineTest(PersistTest): addresses = Table('email_addresses', db, Column('address_id', Integer, primary_key = True), - Column('remote_user_id', Integer, foreign_key = ForeignKey(users.c.user_id)), + Column('remote_user_id', Integer, ForeignKey(users.c.user_id)), Column('email_address', String(20)), ) - users.drop() - addresses.drop() - # users.c.parent_user_id.set_foreign_key(ForeignKey(users.c.user_id)) users.create() @@ -54,14 +54,19 @@ class EngineTest(PersistTest): # clear out table registry db.tables.clear() - users = Table('users', db, autoload = True) - addresses = Table('email_addresses', db, autoload = True) + try: + users = Table('users', db, autoload = True) + addresses = Table('email_addresses', db, autoload = True) + finally: + addresses.drop() + users.drop() - users.drop() - addresses.drop() - users.create() addresses.create() + + addresses.drop() + users.drop() + if __name__ == "__main__": unittest.main() -- 2.47.2
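
A minimal sketch of the information_schema lookups that the new PGSQLEngine.reflecttable() assembles through the select() layer above: one query against information_schema.columns for column metadata, and a second joining table_constraints, key_column_usage, and constraint_column_usage to recover primary and foreign keys. This is an illustration only, not part of the patch; it goes through plain DB-API in the Python 2 idiom of the era, and psycopg2 plus the test/engines.py connection settings (database 'test', user 'scott') are assumptions made purely for demonstration.

    # reflect_sketch.py -- rough equivalent, in raw SQL, of the two queries
    # built by PGSQLEngine.reflecttable() in the patch above.
    # Assumptions: psycopg2 is installed and the same PostgreSQL test
    # database used by test/engines.py is reachable.
    import psycopg2

    conn = psycopg2.connect(dbname="test", host="127.0.0.1",
                            user="scott", password="tiger")
    cur = conn.cursor()

    table, schema = "users", None
    if schema is None:
        # get_default_schema_name() in the patch caches this same query
        cur.execute("select current_schema()")
        schema = cur.fetchone()[0]

    # column metadata, ordered by ordinal_position as in reflecttable()
    cur.execute("""
        select column_name, data_type, is_nullable,
               character_maximum_length, numeric_precision, numeric_scale
        from information_schema.columns
        where table_name = %s and table_schema = %s
        order by ordinal_position
    """, (table, schema))
    for name, type_, nullable, charlen, prec, scale in cur.fetchall():
        # the patch passes the non-None length/precision/scale values
        # straight to the reflected column's type constructor
        args = [a for a in (charlen, prec, scale) if a is not None]
        print name, type_, nullable == 'YES', args

    # constraint metadata: key_column_usage yields the constrained (local)
    # column, constraint_column_usage the referred column for foreign keys
    cur.execute("""
        select tc.constraint_type, kcu.column_name,
               ccu.table_schema, ccu.table_name, ccu.column_name
        from information_schema.table_constraints tc
        join information_schema.key_column_usage kcu
             on kcu.constraint_name = tc.constraint_name
        join information_schema.constraint_column_usage ccu
             on ccu.constraint_name = tc.constraint_name
        where tc.table_name = %s and tc.table_schema = %s
    """, (table, schema))
    for ctype, local_col, ref_schema, ref_table, ref_col in cur.fetchall():
        print "%s on %s -> %s.%s.%s" % (ctype, local_col,
                                        ref_schema, ref_table, ref_col)

In the patch itself, a PRIMARY KEY row triggers Column._set_primary_key() on the already-reflected column, while a FOREIGN KEY row autoloads the referred table with its schema; that is why Table instances are now registered under the schema-qualified key produced by _get_table_key(), with the engine's default schema normalized away.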