]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- improved ability to get the "correct" and most minimal set of primary key
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 14 Jul 2007 21:57:51 +0000 (21:57 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 14 Jul 2007 21:57:51 +0000 (21:57 +0000)
  columns from a join, equating foreign keys and otherwise equated columns.
  this is also mostly to help inheritance scenarios formulate the best
  choice of primary key columns.  [ticket:185]
- added 'bind' argument to Sequence.create()/drop(), ColumnDefault.execute()

CHANGES
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql.py
lib/sqlalchemy/util.py
test/orm/inheritance.py
test/sql/defaults.py
test/sql/selectable.py
test/testbase.py

diff --git a/CHANGES b/CHANGES
index 0a1990ea1b1230831ef16631486f6ff7df0ab3f4..699fb9bcafc2cf84eac13f0166d1c8d09125f150 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -13,6 +13,7 @@
       InstrumentedList-like (e.g. over keys instead of values)
     - association proxies no longer bind tightly to source collections
       [ticket:597], and are constructed with a thunk instead
+    - added selectone_by() to assignmapper
 - orm
     - forwards-compatibility with 0.4: added one(), first(), and 
       all() to Query.  almost all Query functionality from 0.4 is
     - composite primary key is represented as a non-keyed set to allow for 
       composite keys consisting of cols with the same name; occurs within a
       Join.  helps inheritance scenarios formulate correct PK.
+    - improved ability to get the "correct" and most minimal set of primary key 
+      columns from a join, equating foreign keys and otherwise equated columns.
+      this is also mostly to help inheritance scenarios formulate the best 
+      choice of primary key columns.  [ticket:185]
+    - added 'bind' argument to Sequence.create()/drop(), ColumnDefault.execute()
     - some enhancements to "column targeting", the ability to match a column
       to a "corresponding" column in another selectable.  this affects mostly
       ORM ability to map to complex joins
     - fix port option handling for pyodbc [ticket:634]
     - now able to reflect start and increment values for identity columns
     - preliminary support for using scope_identity() with pyodbc
-
-- extensions
-    - added selectone_by() to assignmapper
     
 0.3.8
 - engines
index c62ed33734199bc879970ed4d427debdae6737a3..d9df0da90c159c79ec4ad0a9bb50951702d4dbfb 100644 (file)
@@ -59,13 +59,15 @@ class SchemaItem(object):
         """Return the engine or None if no engine."""
 
         if raiseerr:
-            e = self._derived_metadata().bind
+            m = self._derived_metadata()
+            e = m and m.bind or None
             if e is None:
                 raise exceptions.InvalidRequestError("This SchemaItem is not connected to any Engine or Connection.")
             else:
                 return e
         else:
-            return self._derived_metadata().bind
+            m = self._derived_metadata()
+            return m and m.bind or None
 
     def get_engine(self):
         """Return the engine or raise an error if no engine.
@@ -280,7 +282,9 @@ class Table(SchemaItem, sql.TableClause):
         self.schema = kwargs.pop('schema', None)
         self.indexes = util.Set()
         self.constraints = util.Set()
+        self._columns = sql.ColumnCollection()
         self.primary_key = PrimaryKeyConstraint()
+        self._foreign_keys = util.OrderedSet()
         self.quote = kwargs.pop('quote', False)
         self.quote_schema = kwargs.pop('quote_schema', False)
         if self.schema is not None:
@@ -298,6 +302,11 @@ class Table(SchemaItem, sql.TableClause):
         # store extra kwargs, which should only contain db-specific options
         self.kwargs = kwargs
 
+    def _export_columns(self, columns=None):
+        # override FromClause's collection initialization logic; TableClause and Table
+        # implement it differently
+        pass
+
     def _get_case_sensitive_schema(self):
         try:
             return getattr(self, '_case_sensitive_schema')
@@ -545,6 +554,14 @@ class Column(SchemaItem, sql._ColumnClause):
     def _get_engine(self):
         return self.table.bind
 
+    def references(self, column):
+        """return true if this column references the given column via foreign key"""
+        for fk in self.foreign_keys:
+            if fk.column is column:
+                return True
+        else:
+            return False
+            
     def append_foreign_key(self, fk):
         fk._set_parent(self)
 
@@ -763,7 +780,7 @@ class DefaultGenerator(SchemaItem):
 
     def __init__(self, for_update=False, metadata=None):
         self.for_update = for_update
-        self._metadata = metadata
+        self._metadata = util.assert_arg_type(metadata, (MetaData, type(None)), 'metadata')
 
     def _derived_metadata(self):
         try:
@@ -782,8 +799,10 @@ class DefaultGenerator(SchemaItem):
         else:
             self.column.default = self
 
-    def execute(self, **kwargs):
-        return self._get_engine(raiseerr=True).execute_default(self, **kwargs)
+    def execute(self, bind=None, **kwargs):
+        if bind is None:
+            bind = self._get_engine(raiseerr=True)
+        return bind.execute_default(self, **kwargs)
 
     def __repr__(self):
         return "DefaultGenerator()"
@@ -845,12 +864,15 @@ class Sequence(DefaultGenerator):
         super(Sequence, self)._set_parent(column)
         column.sequence = self
 
-    def create(self):
-       self._get_engine(raiseerr=True).create(self)
-       return self
+    def create(self, bind=None):
+        if bind is None:
+            bind = self._get_engine(raiseerr=True)
+        bind.create(self)
 
-    def drop(self):
-       self._get_engine(raiseerr=True).drop(self)
+    def drop(self, bind=None):
+        if bind is None:
+            bind = self._get_engine(raiseerr=True)
+        bind.drop(self)
 
     def accept_visitor(self, visitor):
         """Call the visit_seauence method on the given visitor."""
index d1fc3fef17a8549a03fd258d3ce3d6e18ce28c18..0961cd4ee35c610cbc64cfe7c3a20a5b1d784ee8 100644 (file)
@@ -1764,7 +1764,7 @@ class FromClause(Selectable):
         """)
     oid_column = property(_get_oid_column)
 
-    def _export_columns(self):
+    def _export_columns(self, columns=None):
         """Initialize column collections.
 
         The collections include the primary key, foreign keys, list of
@@ -1777,14 +1777,16 @@ class FromClause(Selectable):
         its parent ``Selectable`` is this ``FromClause``.
         """
 
-        if hasattr(self, '_columns'):
+        if hasattr(self, '_columns') and columns is None:
             # TODO: put a mutex here ?  this is a key place for threading probs
             return
         self._columns = ColumnCollection()
         self._primary_key = ColumnSet()
         self._foreign_keys = util.Set()
         self._orig_cols = {}
-        for co in self._adjusted_exportable_columns():
+        if columns is None:
+            columns = self._adjusted_exportable_columns()
+        for co in columns:
             cp = self._proxy_column(co)
             for ci in cp.orig_set:
                 cx = self._orig_cols.get(ci)
@@ -2250,15 +2252,38 @@ class Join(FromClause):
     encodedname = property(lambda s: s.name.encode('ascii', 'backslashreplace'))
 
     def _init_primary_key(self):
-        pkcol = util.OrderedSet()
-        for col in self._adjusted_exportable_columns():
-            if col.primary_key:
-                pkcol.add(col)
-        for col in list(pkcol):
-            for f in col.foreign_keys:
-                if f.column in pkcol:
-                    pkcol.remove(col)
-        self.primary_key.extend(pkcol)
+        pkcol = util.Set([c for c in self._adjusted_exportable_columns() if c.primary_key])
+    
+        equivs = {}
+        def add_equiv(a, b):
+            for x, y in ((a, b), (b, a)):
+                if x in equivs:
+                    equivs[x].add(y)
+                else:
+                    equivs[x] = util.Set([y])
+                    
+        class BinaryVisitor(ClauseVisitor):
+            def visit_binary(self, binary):
+                if binary.operator == '=':
+                    add_equiv(binary.left, binary.right)
+        BinaryVisitor().traverse(self.onclause)
+        
+        for col in pkcol:
+            for fk in col.foreign_keys:
+                if fk.column in pkcol:
+                    add_equiv(col, fk.column)
+                    
+        omit = util.Set()
+        for col in pkcol:
+            p = col
+            for c in equivs.get(col, util.Set()):
+                if p.references(c) or (c.primary_key and not p.primary_key):
+                    omit.add(p)
+                    p = c
+            
+        self.__primary_key = ColumnSet([c for c in self._adjusted_exportable_columns() if c.primary_key and c not in omit])
+
+    primary_key = property(lambda s:s.__primary_key)
         
     def _locate_oid_column(self):
         return self.left.oid_column
@@ -2333,7 +2358,11 @@ class Join(FromClause):
                 collist.append(c)
         self.__folded_equivalents = collist
         return self.__folded_equivalents
-        
+
+    folded_equivalents = property(_get_folded_equivalents, doc="Returns the column list of this Join with all equivalently-named, "
+                                                            "equated columns folded into one column, where 'equated' means they are "
+                                                            "equated to each other in the ON clause of this join.")    
+    
     def select(self, whereclause = None, fold_equivalents=False, **kwargs):
         """Create a ``Select`` from this ``Join``.
         
@@ -2353,7 +2382,7 @@ class Join(FromClause):
           
         """
         if fold_equivalents:
-            collist = self._get_folded_equivalents()
+            collist = self.folded_equivalents
         else:
             collist = [self.left, self.right]
             
@@ -2632,12 +2661,8 @@ class TableClause(FromClause):
         super(TableClause, self).__init__(name)
         self.name = self.fullname = name
         self.encodedname = self.name.encode('ascii', 'backslashreplace')
-        self._columns = ColumnCollection()
-        self._foreign_keys = util.OrderedSet()
-        self._primary_key = ColumnCollection()
-        for c in columns:
-            self.append_column(c)
         self._oid_column = _ColumnClause('oid', self, _is_oid=True)
+        self._export_columns(columns)
 
     def named_with_column(self):
         return True
@@ -2649,6 +2674,10 @@ class TableClause(FromClause):
     def _locate_oid_column(self):
         return self._oid_column
 
+    def _proxy_column(self, c):
+        self.append_column(c)
+        return c
+
     def _orig_columns(self):
         try:
             return self._orig_cols
@@ -2666,6 +2695,7 @@ class TableClause(FromClause):
             return [c for c in self.c]
         else:
             return []
+
     def accept_visitor(self, visitor):
         visitor.visit_table(self)
 
index c827f1e7d663a0868c7c70125b90ef3c2c73c67e..b47822d613418fefe17f1091423b4c4e0015736b 100644 (file)
@@ -10,6 +10,7 @@ except ImportError:
     import dummy_thread as thread
     import dummy_threading as threading
 
+from sqlalchemy import exceptions
 import md5
 import sys
 import warnings
@@ -128,7 +129,16 @@ def duck_type_collection(col, default=None):
         return dict
     else:
         return default
-    
+
+def assert_arg_type(arg, argtype, name):
+    if isinstance(arg, argtype):
+        return arg
+    else:
+        if isinstance(argtype, tuple):
+            raise exceptions.ArgumentError("Argument '%s' is expected to be one of type %s, got '%s'" % (name, ' or '.join(["'%s'" % str(a) for a in argtype]), str(type(arg))))
+        else:
+            raise exceptions.ArgumentError("Argument '%s' is expected to be of type '%s', got '%s'" % (name, str(argtype), str(type(arg))))
+        
 def warn_exception(func):
     """executes the given function, catches all exceptions and converts to a warning."""
     try:
index 0458716e5aa1d17aba2b8161648246cd25e3f6ce..d608b2387a84812040445f29f07fd470d5e793ad 100644 (file)
@@ -481,6 +481,80 @@ class InheritTest7(testbase.ORMTest):
         a.password = 'sadmin'
         sess.flush()
         assert user_roles.count().scalar() == 1
+
+class InheritTest8(testbase.ORMTest):
+    """test the construction of mapper.primary_key when an inheriting relationship
+    joins on a column other than primary key column."""
+    keep_data = True
+    
+    def define_tables(self, metadata):
+        global person_table, employee_table, Person, Employee
+        
+        person_table = Table("persons", metadata,
+                Column("id", Integer, primary_key=True),
+                Column("name", String(80)),
+                )
+
+        employee_table = Table("employees", metadata,
+                Column("id", Integer, primary_key=True),
+                Column("salary", Integer),
+                Column("person_id", Integer, ForeignKey("persons.id")),
+                )
+
+        class Person(object):
+            def __init__(self, name):
+                self.name = name
+
+        class Employee(Person): pass
+    
+    def insert_data(self):
+        person_insert = person_table.insert()
+        person_insert.execute(id=1, name='alice')
+        person_insert.execute(id=2, name='bob')
+
+        employee_insert = employee_table.insert()
+        employee_insert.execute(id=2, salary=250, person_id=1) # alice
+        employee_insert.execute(id=3, salary=200, person_id=2) # bob
+        
+    def test_implicit(self):
+        person_mapper = mapper(Person, person_table)
+        mapper(Employee, employee_table, inherits=person_mapper)
+        print class_mapper(Employee).primary_key
+        assert list(class_mapper(Employee).primary_key) == [person_table.c.id, employee_table.c.id]
+        self._do_test(True)
+
+    def test_explicit_props(self):
+        person_mapper = mapper(Person, person_table)
+        mapper(Employee, employee_table, inherits=person_mapper, properties={'pid':person_table.c.id, 'eid':employee_table.c.id})
+        self._do_test(True)
+    
+    def test_explicit_composite_pk(self):
+        person_mapper = mapper(Person, person_table)
+        mapper(Employee, employee_table, inherits=person_mapper, primary_key=[person_table.c.id, employee_table.c.id])
+        self._do_test(True)
+
+    def test_explicit_pk(self):
+        person_mapper = mapper(Person, person_table)
+        mapper(Employee, employee_table, inherits=person_mapper, primary_key=[person_table.c.id])
+        self._do_test(False)
+        
+    def _do_test(self, composite):
+        session = create_session()
+        query = session.query(Employee)
+
+        if composite:
+            alice1 = query.get([1,2])
+            bob = query.get([2,3])
+            alice2 = query.get([1,2])
+        else:
+            alice1 = query.get(1)
+            bob = query.get(2)
+            alice2 = query.get(1)
+            
+            assert alice1.name == alice2.name == 'alice'
+            assert bob.name == 'bob'
+        
+
         
 if __name__ == "__main__":    
     testbase.main()
index 09c58d2c2f7fcd445d08dd138606710e772f1ede..10a3610f99ad5da8c96ee83207b327debebf48d9 100644 (file)
@@ -231,6 +231,16 @@ class SequenceTest(PersistTest):
             self.assert_(x == 1)
         finally:
             s.drop()
+
+    @testbase.supported('postgres', 'oracle')
+    def teststandalone_explicit(self):
+        s = Sequence("my_sequence")
+        s.create(bind=testbase.db)
+        try:
+            x = s.execute(testbase.db)
+            self.assert_(x == 1)
+        finally:
+            s.drop(testbase.db)
     
     @testbase.supported('postgres', 'oracle')
     def teststandalone2(self):
index 853821f9af0245d9ba2457d3219de630fb60745e..ecd8253b8ffb60d7415f11956761ea74eb37bc0f 100755 (executable)
@@ -170,6 +170,61 @@ class SelectableTest(testbase.AssertMixin):
         print str(criterion)\r
         print str(j.onclause)\r
         self.assert_(criterion.compare(j.onclause))\r
+\r
+class PrimaryKeyTest(testbase.AssertMixin):\r
+    def test_join_pk_collapse_implicit(self):\r
+        """test that redundant columns in a join get 'collapsed' into a minimal primary key, \r
+        which is the root column along a chain of foreign key relationships."""\r
+        \r
+        meta = MetaData()\r
+        a = Table('a', meta, Column('id', Integer, primary_key=True))\r
+        b = Table('b', meta, Column('id', Integer, ForeignKey('a.id'), primary_key=True))\r
+        c = Table('c', meta, Column('id', Integer, ForeignKey('b.id'), primary_key=True))\r
+        d = Table('d', meta, Column('id', Integer, ForeignKey('c.id'), primary_key=True))\r
+\r
+        assert c.c.id.references(b.c.id)\r
+        assert not d.c.id.references(a.c.id)\r
+        \r
+        assert list(a.join(b).primary_key) == [a.c.id]\r
+        assert list(b.join(c).primary_key) == [b.c.id]\r
+        assert list(a.join(b).join(c).primary_key) == [a.c.id]\r
+        assert list(b.join(c).join(d).primary_key) == [b.c.id]\r
+        assert list(d.join(c).join(b).primary_key) == [b.c.id]\r
+        assert list(a.join(b).join(c).join(d).primary_key) == [a.c.id]\r
+\r
+    def test_join_pk_collapse_explicit(self):\r
+        """test that redundant columns in a join get 'collapsed' into a minimal primary key, \r
+        which is the root column along a chain of explicit join conditions."""\r
+\r
+        meta = MetaData()\r
+        a = Table('a', meta, Column('id', Integer, primary_key=True), Column('x', Integer))\r
+        b = Table('b', meta, Column('id', Integer, ForeignKey('a.id'), primary_key=True), Column('x', Integer))\r
+        c = Table('c', meta, Column('id', Integer, ForeignKey('b.id'), primary_key=True), Column('x', Integer))\r
+        d = Table('d', meta, Column('id', Integer, ForeignKey('c.id'), primary_key=True), Column('x', Integer))\r
+\r
+        print list(a.join(b, a.c.x==b.c.id).primary_key)\r
+        assert list(a.join(b, a.c.x==b.c.id).primary_key) == [b.c.id]\r
+        assert list(b.join(c, b.c.x==c.c.id).primary_key) == [b.c.id]\r
+        assert list(a.join(b).join(c, c.c.id==b.c.x).primary_key) == [a.c.id]\r
+        assert list(b.join(c, c.c.x==b.c.id).join(d).primary_key) == [c.c.id]\r
+        assert list(b.join(c, c.c.id==b.c.x).join(d).primary_key) == [b.c.id]\r
+        assert list(d.join(b, d.c.id==b.c.id).join(c, b.c.id==c.c.x).primary_key) == [c.c.id]\r
+        assert list(a.join(b).join(c, c.c.id==b.c.x).join(d).primary_key) == [a.c.id]\r
+        \r
+        assert list(a.join(b, and_(a.c.id==b.c.id, a.c.x==b.c.id)).primary_key) == [a.c.id]\r
+    \r
+    def test_init_doesnt_blowitaway(self):\r
+        meta = MetaData()\r
+        a = Table('a', meta, Column('id', Integer, primary_key=True), Column('x', Integer))\r
+        b = Table('b', meta, Column('id', Integer, ForeignKey('a.id'), primary_key=True), Column('x', Integer))\r
+\r
+        j = a.join(b)\r
+        assert list(j.primary_key) == [a.c.id]\r
+        \r
+        j.foreign_keys\r
+        assert list(j.primary_key) == [a.c.id]\r
+\r
+        \r
         \r
 if __name__ == "__main__":\r
     testbase.main()\r
index 54fcd9db2361268b8218e2a59c30eeb38b8da30d..7c5095d1a0abf432a18d54cc5c07613e3154c4de 100644 (file)
@@ -267,8 +267,11 @@ class ORMTest(AssertMixin):
         metadata = MetaData(db)
         self.define_tables(metadata)
         metadata.create_all()
+        self.insert_data()
     def define_tables(self, metadata):
         raise NotImplementedError()
+    def insert_data(self):
+        pass
     def get_metadata(self):
         return metadata
     def tearDownAll(self):