git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
(no commit message)
author Mike Bayer <mike_mp@zzzcomputing.com>
Mon, 31 Oct 2005 02:11:16 +0000 (02:11 +0000)
committer Mike Bayer <mike_mp@zzzcomputing.com>
Mon, 31 Oct 2005 02:11:16 +0000 (02:11 +0000)
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/engine.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/types.py
lib/sqlalchemy/util.py
test/engines.py

diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py
index ef6cac24ae34cd757c40fafc873f85d4c8dd73c8..67ee1d4fd606668b3e9557ba3cf65b54dd59b412 100644
--- a/lib/sqlalchemy/databases/postgres.py
+++ b/lib/sqlalchemy/databases/postgres.py
@@ -104,6 +104,13 @@ gen_column_constraints = schema.Table("constraint_column_usage", generic_engine,
     Column("constraint_name", String),
     schema="information_schema")
 
+gen_key_constraints = schema.Table("key_column_usage", generic_engine,
+    Column("table_schema", String),
+    Column("table_name", String),
+    Column("column_name", String),
+    Column("constraint_name", String),
+    schema="information_schema")
+
 def engine(opts, **params):
     return PGSQLEngine(opts, **params)
 
@@ -138,6 +145,11 @@ class PGSQLEngine(ansisql.ANSISQLEngine):
         
     def reflecttable(self, table):
         raise "not implemented"
+
+    def get_default_schema_name(self):
+        if not hasattr(self, '_default_schema_name'):
+            self._default_schema_name = text("select current_schema()", self).scalar()
+        return self._default_schema_name
         
     def last_inserted_ids(self):
         # if we used sequences or already had all values for the last inserted row,
@@ -205,65 +217,73 @@ class PGSQLEngine(ansisql.ANSISQLEngine):
         return self.module
 
     def reflecttable(self, table):
-        columns = gen_columns.toengine(table.engine)
-        constraints = gen_constraints.toengine(table.engine)
-        column_constraints = gen_column_constraints.toengine(table.engine)
-        
-        s = select([columns, constraints.c.constraint_type], 
-            columns.c.table_name==table.name, 
-            order_by=[columns.c.ordinal_position])
-            
-        s.append_from(sql.outerjoin(columns, column_constraints, 
-                              sql.and_(
-                                      columns.c.table_name==column_constraints.c.table_name,
-                                      columns.c.table_schema==column_constraints.c.table_schema,
-                                      columns.c.column_name==column_constraints.c.column_name,
-                                  )).outerjoin(constraints, 
-                                  sql.and_(
-                                      column_constraints.c.table_schema==constraints.c.table_schema,
-                                      column_constraints.c.constraint_name==constraints.c.constraint_name,
-                                      constraints.c.constraint_type=='PRIMARY KEY'
-                                  )))
+        columns = gen_columns.toengine(self)
+        constraints = gen_constraints.toengine(self)
+        column_constraints = gen_column_constraints.toengine(self)
+        key_constraints = gen_key_constraints.toengine(self)
 
         if table.schema is not None:
-            s.append_whereclause(columns.c.table_schema==table.schema)
+            current_schema = table.schema
         else:
-            current_schema = text("select current_schema()", table.engine).scalar()
-            s.append_whereclause(columns.c.table_schema==current_schema)
-
+            current_schema = self.get_default_schema_name()
+        
+        s = select([columns], 
+            sql.and_(columns.c.table_name==table.name,
+            columns.c.table_schema==current_schema),
+            order_by=[columns.c.ordinal_position])
+            
         c = s.execute()
         while True:
             row = c.fetchone()
             if row is None:
                 break
             #print "row! " + repr(row)
-            (name, type, nullable, primary_key, charlen, numericprec, numericscale) = (
+            (name, type, nullable, charlen, numericprec, numericscale) = (
                 row[columns.c.column_name], 
                 row[columns.c.data_type], 
-                not row[columns.c.is_nullable], 
-                row[constraints.c.constraint_type] is not None,
+                row[columns.c.is_nullable] == 'YES', 
                 row[columns.c.character_maximum_length],
                 row[columns.c.numeric_precision],
                 row[columns.c.numeric_scale],
                 )
 
-            #match = re.match(r'(\w+)(\(.*?\))?', type)
-            #coltype = match.group(1)
-            #args = match.group(2)
-
-            #print "coltype: " + repr(coltype) + " args: " + repr(args)
+            args = []
+            for a in (charlen, numericprec, numericscale):
+                if a is not None:
+                    args.append(a)
             coltype = ischema_names[type]
-            table.append_item(schema.Column(name, coltype, primary_key = primary_key, nullable = nullable))
-        return
-        c = self.execute("PRAGMA foreign_key_list(" + table.name + ")", {})
+            #print "coltype " + repr(coltype) + " args " +  repr(args)
+            coltype = coltype(*args)
+            table.append_item(schema.Column(name, coltype, nullable = nullable))
+
+        s = select([
+            constraints.c.constraint_type,
+            column_constraints,
+            key_constraints
+            ], 
+            sql.and_(
+                key_constraints.c.constraint_name==column_constraints.c.constraint_name,
+                column_constraints.c.constraint_name==constraints.c.constraint_name,
+                constraints.c.table_name==table.name, constraints.c.table_schema==current_schema)
+        , use_labels=True)
+        c = s.execute()
         while True:
             row = c.fetchone()
             if row is None:
                 break
-            (tablename, localcol, remotecol) = (row[2], row[3], row[4])
-            #print "row! " + repr(row)
-            remotetable = Table(tablename, self, autoload = True)
-            table.c[localcol].foreign_key = schema.ForeignKey(remotetable.c[remotecol])
+            (type, constrained_column, referred_schema, referred_table, referred_column) = (
+                row[constraints.c.constraint_type],
+                row[key_constraints.c.column_name],
+                row[column_constraints.c.table_schema],
+                row[column_constraints.c.table_name],
+                row[column_constraints.c.column_name]
+            )
+            print "type %s on column %s to remote %s.%s.%s" % (type, constrained_column, referred_schema, referred_table, referred_column) 
+            if type=='PRIMARY KEY':
+                table.c[constrained_column]._set_primary_key()
+            elif type=='FOREIGN KEY':
+                remotetable = Table(referred_table, self, autoload = True, schema=referred_schema)
+                table.c[constrained_column].foreign_key = schema.ForeignKey(remotetable.c[referred_column])
 
 class PGCompiler(ansisql.ANSICompiler):
     def bindparam_string(self, name):
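
For reference, the reflection logic above boils down to two information_schema queries: one against the columns view for names, types, nullability and length/precision/scale, and one joining the constraint views to discover primary and foreign keys. Below is a rough standalone sketch of roughly the same queries over a plain DB-API connection -- not part of the commit; psycopg2 and the connection parameters are assumptions for illustration, and gen_constraints is assumed to correspond to information_schema.table_constraints.

# Illustrative sketch only -- mirrors the two queries issued by reflecttable().
import psycopg2

conn = psycopg2.connect("dbname=test user=scott password=tiger host=127.0.0.1")
cur = conn.cursor()
table, schema = "users", "public"

# 1. column list, in ordinal order
cur.execute("""
    SELECT column_name, data_type, is_nullable,
           character_maximum_length, numeric_precision, numeric_scale
    FROM information_schema.columns
    WHERE table_name = %s AND table_schema = %s
    ORDER BY ordinal_position
""", (table, schema))
for name, dtype, nullable, charlen, prec, scale in cur.fetchall():
    print(name, dtype, nullable == 'YES', charlen, prec, scale)

# 2. primary/foreign key constraints, joining the three constraint views
cur.execute("""
    SELECT tc.constraint_type, kcu.column_name,
           ccu.table_schema, ccu.table_name, ccu.column_name
    FROM information_schema.table_constraints tc
    JOIN information_schema.constraint_column_usage ccu
         ON ccu.constraint_name = tc.constraint_name
    JOIN information_schema.key_column_usage kcu
         ON kcu.constraint_name = ccu.constraint_name
    WHERE tc.table_name = %s AND tc.table_schema = %s
""", (table, schema))
for ctype, constrained, ref_schema, ref_table, ref_column in cur.fetchall():
    print(ctype, constrained, "->", ref_schema, ref_table, ref_column)
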
diff --git a/lib/sqlalchemy/engine.py b/lib/sqlalchemy/engine.py
index 0175a1c7bb864ddac349f8edd2fe27263eb5c9d7..0dc2e267f81b45204436cafa0be24f55256c7021 100644
--- a/lib/sqlalchemy/engine.py
+++ b/lib/sqlalchemy/engine.py
@@ -120,6 +120,9 @@ class SQLEngine(schema.SchemaEngine):
         """returns a new sql.ColumnImpl object to correspond to the given Column object."""
         return sql.ColumnImpl(column)
 
+    def get_default_schema_name(self):
+        return None
+        
     def last_inserted_ids(self):
         """returns a thread-local list of the primary keys for the last insert statement executed.
         This does not apply to straight textual clauses; only to sql.Insert objects compiled against a schema.Table object, which are executed via statement.execute().  The order of items in the list is the same as that of the Table's 'primary_keys' attribute."""
@@ -297,7 +300,8 @@ class ResultProxy:
                     rec = (typemap.get(item[0], types.NULLTYPE), i)
                 else:
                     rec = (types.NULLTYPE, i)
-                self.props[item[0].lower()] = rec
+                if self.props.setdefault(item[0].lower(), rec) is not rec:
+                    raise "Duplicate column name '%s' in result set! use use_labels on select statement" % (item[0].lower())
                 self.props[i] = rec
                 i+=1
 
diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py
index 825fbe4a79508ed664fa5638c759a89ef0009948..6d8ee2259240dc7b4ecc03deed9c0e93eab84ce0 100644
--- a/lib/sqlalchemy/schema.py
+++ b/lib/sqlalchemy/schema.py
@@ -47,22 +47,31 @@ class SchemaItem(object):
             raise AttributeError(key)
         return getattr(self._impl, key)
 
-
+def _get_table_key(engine, name, schema):
+    if schema is not None and schema == engine.get_default_schema_name():
+        schema = None
+    if schema is None:
+        return name
+    else:
+        return schema + "." + name
+        
 class TableSingleton(type):
     def __call__(self, name, engine, *args, **kwargs):
         try:
-            table = engine.tables[name]
+            schema = kwargs.get('schema', None)
+            key = _get_table_key(engine, name, schema)
+            table = engine.tables[key]
             if len(args):
                 if kwargs.get('redefine', False):
                     table.reload_values(*args)
                 else:
-                    raise "Table '%s' is already defined. specify 'redefine=True' to remap columns" % name
+                    raise "Table '%s.%s' is already defined. specify 'redefine=True' to remap columns" % (schema, name)
             return table
         except KeyError:
             if kwargs.get('mustexist', False):
-                raise "Table '%s' not defined" % name
+                raise "Table '%s.%s' not defined" % (schema, name)
             table = type.__call__(self, name, engine, *args, **kwargs)
-            engine.tables[name] = table
+            engine.tables[key] = table
             # load column definitions from the database if 'autoload' is defined
             # we do it after the table is in the singleton dictionary to support
             # circular foreign keys
@@ -86,7 +95,7 @@ class Table(SchemaItem):
         self._impl = self.engine.tableimpl(self)
         self._init_items(*args)
         self.schema = kwargs.get('schema', None)
-        if self.schema:
+        if self.schema is not None:
             self.fullname = "%s.%s" % (self.schema, self.name)
         else:
             self.fullname = self.name
@@ -112,15 +121,18 @@ class Table(SchemaItem):
             c.accept_visitor(visitor)
         return visitor.visit_table(self)
 
-    def toengine(self, engine):
+    def toengine(self, engine, schema=None):
         """returns a singleton instance of this Table with a different engine"""
         try:
-            return engine.tables[self.name]
+            if schema is None:
+                schema = self.schema
+            key = _get_table_key(engine, self.name, schema)
+            return engine.tables[key]
         except:
             args = []
             for c in self.columns:
                 args.append(c.copy())
-            return Table(self.name, engine, schema=self.schema, *args)
+            return Table(self.name, engine, schema=schema, *args)
 
 class Column(SchemaItem):
     """represents a column in a database table."""
@@ -138,7 +150,12 @@ class Column(SchemaItem):
         
     original = property(lambda s: s._orig or s)
     engine = property(lambda s: s.table.engine)
-        
+    
+    def _set_primary_key(self):
+        self.primary_key = True
+        self.nullable = False
+        self.table.primary_keys.append(self)
+            
     def _set_parent(self, table):
         if not self.hidden:
             table.columns[self.key] = self
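
The table registry key added here normalizes the engine's default schema away, so a table declared without a schema and the same table declared with the default schema resolve to one singleton. A small standalone sketch of that keying rule (the stub engine and its "public" default are hypothetical):

# Illustrative sketch of the registry-key rule -- not part of the commit.
class StubEngine(object):
    def get_default_schema_name(self):
        return "public"          # hypothetical default schema

def _get_table_key(engine, name, schema):
    # a schema equal to the engine default is treated as "no schema"
    if schema is not None and schema == engine.get_default_schema_name():
        schema = None
    return name if schema is None else schema + "." + name

e = StubEngine()
assert _get_table_key(e, "users", None) == "users"
assert _get_table_key(e, "users", "public") == "users"       # collapses to bare name
assert _get_table_key(e, "users", "archive") == "archive.users"
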
diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py
index c010ebbb0b4b36f9a17122b2cc548b808e8aaa7c..9a9e5423f25c496bf2138e3063262f7acf7a77ef 100644
--- a/lib/sqlalchemy/types.py
+++ b/lib/sqlalchemy/types.py
@@ -46,6 +46,8 @@ def adapt_type(typeobj, colspecs):
     return typeobj.adapt(typeobj.__class__)
     
 class NullTypeEngine(TypeEngine):
+    def __init__(self, *args, **kwargs):
+        pass
     def get_col_spec(self):
         raise NotImplementedError()
     def convert_bind_param(self, value):
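
The no-op constructor added to NullTypeEngine lets reflected types be instantiated uniformly with whatever length/precision/scale arguments were read from information_schema (the coltype(*args) call in the postgres reflection above), even when a concrete type ignores them. A minimal sketch of the pattern, with hypothetical type classes rather than the real hierarchy:

# Illustrative sketch -- hypothetical type classes, not the SQLAlchemy hierarchy.
class BaseType(object):
    def __init__(self, *args, **kwargs):
        pass                      # accept and ignore reflected args by default

class VarcharType(BaseType):
    def __init__(self, length=None):
        self.length = length      # this type actually uses its argument

for cls, args in ((VarcharType, (30,)), (BaseType, (10, 2))):
    coltype = cls(*args)          # works uniformly regardless of signature
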
diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py
index bae5d6bad52b7031b1d2b480ef9dcb3a2ae7e8c0..b67eaa9d582b8ee339e23e51835eb406ed8f7011 100644
--- a/lib/sqlalchemy/util.py
+++ b/lib/sqlalchemy/util.py
@@ -359,8 +359,8 @@ class DependencySorter(object):
             while n.parent is not None:
                 n = n.parent
             return n
-        def get_highest_sibling(self, node):
-            """returns the highest ancestor node of this one which is either the root node, or the common parent of this node and the given node"""
+        def get_sibling_ancestor(self, node):
+            """returns the node which is an ancestor of this node and is a sibling of the given node, or else returns this node's root node."""
             n = self
             while n.parent is not None and n.parent is not node.parent:
                 n = n.parent
@@ -376,6 +376,7 @@ class DependencySorter(object):
     def __init__(self, tuples, allitems):
         self.tuples = tuples
         self.allitems = allitems
+        
     def sort(self):
         (tuples, allitems) = (self.tuples, self.allitems)
         
@@ -413,7 +414,7 @@ class DependencySorter(object):
                 raise "Circular dependency detected"
             elif not childnode.is_descendant_of(parentnode):
                 # if relationship doesnt exist, connect nodes together
-                root = childnode.get_highest_sibling(parentnode)
+                root = childnode.get_sibling_ancestor(parentnode)
                 parentnode.append(root)
 
         # now we have a collection of subtrees which represent dependencies.
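
The renamed helper climbs from a node until the next step up would pass the given node's parent level, returning either that sibling-level ancestor or the chain's root. A self-contained sketch of the traversal with a minimal node class (the class and variable names are illustrative only):

# Illustrative sketch of the sibling-ancestor walk -- not part of the commit.
class Node(object):
    def __init__(self, parent=None):
        self.parent = parent

    def get_sibling_ancestor(self, node):
        # climb until this chain's parent matches the other node's parent,
        # or until the root is reached
        n = self
        while n.parent is not None and n.parent is not node.parent:
            n = n.parent
        return n

root = Node()
a = Node(root); b = Node(root)             # siblings under root
leaf = Node(Node(a))                       # grandchild of a
assert leaf.get_sibling_ancestor(b) is a   # ancestor of leaf that is b's sibling
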
diff --git a/test/engines.py b/test/engines.py
index 291cafabd6625fd4ee9bb3632a824350061bf811..0ffa97f954a6c5a9de13d9ec6fc45157e8d3c8b7 100644
--- a/test/engines.py
+++ b/test/engines.py
@@ -2,6 +2,7 @@
 import sqlalchemy.ansisql as ansisql
 import sqlalchemy.databases.postgres as postgres
 import sqlalchemy.databases.oracle as oracle
+import sqlalchemy.databases.sqlite as sqllite
 
 db = ansisql.engine()
 
@@ -14,9 +15,11 @@ import unittest, re
 
 
 class EngineTest(PersistTest):
-    def testsqlitetableops(self):
-        import sqlalchemy.databases.sqlite as sqllite
-#        db = sqllite.engine(':memory:', {}, echo = testbase.echo)
+    def testsqlite(self):
+        db = sqllite.engine(':memory:', {}, echo = testbase.echo)
+        self.do_tableops(db)
+
+    def testpostgres(self):
         db = postgres.engine({'database':'test', 'host':'127.0.0.1', 'user':'scott', 'password':'tiger'}, echo = testbase.echo)
         self.do_tableops(db)
         
@@ -30,7 +33,7 @@ class EngineTest(PersistTest):
             Column('test3', TEXT),
             Column('test4', DECIMAL, nullable = False),
             Column('test5', TIMESTAMP),
-            Column('parent_user_id', INT, foreign_key = ForeignKey('users.user_id')),
+            Column('parent_user_id', INT, ForeignKey('users.user_id')),
             Column('test6', DATETIME, nullable = False),
             Column('test7', CLOB),
             Column('test8', BLOB),
@@ -39,13 +42,10 @@ class EngineTest(PersistTest):
 
         addresses = Table('email_addresses', db,
             Column('address_id', Integer, primary_key = True),
-            Column('remote_user_id', Integer, foreign_key = ForeignKey(users.c.user_id)),
+            Column('remote_user_id', Integer, ForeignKey(users.c.user_id)),
             Column('email_address', String(20)),
         )
 
-        users.drop()
-        addresses.drop()
-        
 #        users.c.parent_user_id.set_foreign_key(ForeignKey(users.c.user_id))
 
         users.create()
@@ -54,14 +54,19 @@ class EngineTest(PersistTest):
         # clear out table registry
         db.tables.clear()
 
-        users = Table('users', db, autoload = True)
-        addresses = Table('email_addresses', db, autoload = True)
+        try:
+            users = Table('users', db, autoload = True)
+            addresses = Table('email_addresses', db, autoload = True)
+        finally:
+            addresses.drop()
+            users.drop()
 
-        users.drop()
-        addresses.drop()
-        
         users.create()
         addresses.create()
+
+        addresses.drop()
+        users.drop()
+        
         
 if __name__ == "__main__":
     unittest.main()