]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
foreign key relatinoships are defined primarily at the schema level
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 15 Sep 2005 02:55:16 +0000 (02:55 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 15 Sep 2005 02:55:16 +0000 (02:55 +0000)
lib/sqlalchemy/mapper.py
lib/sqlalchemy/objectstore.py

index 11ec39b745b866d984c5bc80cd4cd648e1942851..910823afd7f5ed5d741c278189a235be7cf7a986 100644 (file)
@@ -36,8 +36,8 @@ def relation_loader(mapper, secondary = None, primaryjoin = None, secondaryjoin
     else:
         return EagerLoader(mapper, secondary, primaryjoin, secondaryjoin, **options)
     
-def relation_mapper(class_, selectable, secondary = None, primaryjoin = None, secondaryjoin = None, table = None, properties = None, lazy = True, uselist = True, foreignkey = None, **options):
-    return relation_loader(mapper(class_, selectable, table = table, properties = properties, **options), secondary, primaryjoin, secondaryjoin, lazy = lazy, uselist = uselist, foreignkey = foreignkey, **options)
+def relation_mapper(class_, selectable, secondary = None, primaryjoin = None, secondaryjoin = None, table = None, properties = None, lazy = True, uselist = True, foreignkey = None, primary_keys = None, **options):
+    return relation_loader(mapper(class_, selectable, table=table, properties=properties, primary_keys=primary_keys, **options), secondary, primaryjoin, secondaryjoin, lazy = lazy, uselist = uselist, foreignkey = foreignkey, **options)
 
 _mappers = {}
 def mapper(*args, **params):
@@ -65,7 +65,7 @@ def object_mapper(object):
             raise "Object " + object.__class__.__name__ + "/" + repr(id(object)) + " has no mapper specified"
         
 class Mapper(object):
-    def __init__(self, hashkey, class_, selectable, table = None, scope = "thread", properties = None, echo = None, **kwargs):
+    def __init__(self, hashkey, class_, selectable, table = None, scope = "thread", properties = None, echo = None, primary_keys = None, **kwargs):
         self.hashkey = hashkey
         self.class_ = class_
         self.scope = scope
@@ -73,6 +73,7 @@ class Mapper(object):
         tf = TableFinder()
         self.selectable.accept_visitor(tf)
         self.tables = tf.tables
+        self.primary_keys = {}
 
         if table is None:
             if len(self.tables) > 1:
@@ -81,6 +82,22 @@ class Mapper(object):
         else:
             self.table = table
 
+        if primary_keys is not None:
+            for k in primary_keys:
+                self.primary_keys.setdefault(k.table, []).append(k)
+        else:
+            for t in self.tables + [self.selectable]:
+                try:
+                    list = self.primary_keys[t]
+                except KeyError:
+                    list = self.primary_keys.setdefault(t, util.HashSet())
+                if not len(t.primary_keys):
+                    raise "Table " + t.name + " has no primary keys. Specify primary_keys argument to mapper."
+                for k in t.primary_keys:
+                    list.append(k)
+                
+                    
+
         self.echo = echo
 
         # object attribute names mapped to MapperProperty objects
@@ -152,13 +169,13 @@ class Mapper(object):
         """returns an instance of the object based on the given identifier, or None
         if not found.  The *ident argument is a 
         list of primary keys in the order of the table def's primary keys."""
-        key = objectstore.get_id_key(ident, self.class_, self.table, self.selectable)
+        key = objectstore.get_id_key(ident, self.class_, self.table)
         try:
             return objectstore.get(key)
         except KeyError:
             clause = sql.and_()
             i = 0
-            for primary_key in self.table.primary_keys:
+            for primary_key in self.primary_keys[table]:
                 # appending to the and_'s clause list directly to skip
                 # typechecks etc.
                 clause.clauses.append(primary_key == ident[i])
@@ -169,7 +186,7 @@ class Mapper(object):
                 return None
 
     def put(self, instance):
-        key = objectstore.get_instance_key(instance, self.class_, self.table, self.selectable)
+        key = objectstore.get_instance_key(instance, self.class_, self.table, self.primary_keys[self.selectable])
         objectstore.put(key, instance, self.scope)
         return key
 
@@ -232,7 +249,7 @@ class Mapper(object):
         for table, stuff in work.iteritems():
             if len(stuff['update']):
                 clause = sql.and_()
-                for col in table.primary_keys:
+                for col in self.primary_keys[table]:
                     clause.clauses.append(col == sql.bindparam(col.key))
                 statement = table.update(clause)
                 statement.echo = self.echo
@@ -246,7 +263,7 @@ class Mapper(object):
                     statement.execute(**params)
                     primary_key = table.engine.last_inserted_ids()[0]
                     found = False
-                    for col in table.primary_keys:
+                    for col in self.primary_keys[table]:
                         if self._getattrbycolumn(obj, col) is None:
                             if found:
                                 raise "Only one primary key per inserted row can be set via autoincrement/sequence"
@@ -282,7 +299,7 @@ class Mapper(object):
         return self.instances(statement.execute(**params))
 
     def _identity_key(self, row):
-        return objectstore.get_row_key(row, self.class_, self.table, self.selectable)
+        return objectstore.get_row_key(row, self.class_, self.table, self.primary_keys[self.selectable])
 
     def _instance(self, row, result = None):
         """pulls an object instance from the given row and appends it to the given result list.
@@ -297,7 +314,7 @@ class Mapper(object):
         if not exists:
             instance = self.class_()
             instance._mapper = self.hashkey
-            for column in self.selectable.primary_keys:
+            for column in self.primary_keys[self.selectable]:
                 if row[column.label] is None:
                     return None
             objectstore.put(identitykey, instance, self.scope)
@@ -408,23 +425,20 @@ class PropertyLoader(MapperProperty):
                 self.primaryjoin = self.match_primaries(parent.selectable, self.secondary)
         else:
             if self.primaryjoin is None:
-                if self.foreignkey is not None and self.foreignkey.table == parent.selectable:
-                    self.primaryjoin = self.match_primaries(self.target, parent.selectable)
-                else:
-                    self.primaryjoin = self.match_primaries(parent.selectable, self.target)
+                self.primaryjoin = self.match_primaries(parent.selectable, self.target)
         
         # if the foreign key wasnt specified and theres no assocaition table, try to figure
         # out who is dependent on who. we dont need all the foreign keys represented in the join,
         # just one of them.  
-        if self.foreignkey is None and self.secondaryjoin is None:
+#        if self.foreignkey is None and self.secondaryjoin is None:
             # else we usually will have a one-to-many where the secondary depends on the primary
             # but its possible that its reversed
-            w = PropertyLoader.FindDependent()
-            self.primaryjoin.accept_visitor(w)
-            if w.dependent is None:
-                raise "cant determine primary foreign key in the join relationship....specify foreignkey=<column>"
-            else:
-                self.foreignkey = w.dependent
+#            w = PropertyLoader.FindDependent()
+#            self.primaryjoin.accept_visitor(w)
+#            if w.dependent is None:
+#                raise "cant determine primary foreign key in the join relationship....specify foreignkey=<column>"
+#            else:
+#                self.foreignkey = w.dependent
                 
         if not hasattr(parent.class_, key):
             setattr(parent.class_, key, SmartProperty(key).property(usehistory = True, uselist = self.uselist))
@@ -445,14 +459,23 @@ class PropertyLoader(MapperProperty):
                 
 
     def match_primaries(self, primary, secondary):
-        pk = primary.primary_keys
-        try:
-            if len(pk) == 1:
-                return (pk[0] == secondary.c[pk[0].name])
-            else:
-                return sql.and_([pk == secondary.c[pk.name] for pk in primary.primary_keys])
-        except AttributeError, e:
-            raise e.args[0] + " table: " + secondary.name
+        crit = []
+
+        for fk in secondary.foreign_keys:
+            if fk.column.table is primary:
+                crit.append(fk.column == fk.parent)
+                self.foreignkey = fk.parent
+        for fk in primary.foreign_keys:
+            if fk.column.table is secondary:
+                crit.append(fk.column == fk.parent)
+                self.foreignkey = fk.parent
+
+        if len(crit) == 0:
+            raise "Cant find any foreign key relationships between " + primary.table.name + " and " + secondary.table.name
+        elif len(crit) == 1:
+            return (crit[0])
+        else:
+            return sql.and_(crit)
             
     def register_dependencies(self, objlist, uow):
         if self.secondaryjoin is not None:
index e9365176f3119c0e9a5ecdd9a45e9ff4e4c88d32..aefd9812309fb7b8d49a56b16d00d97874f7842d 100644 (file)
@@ -24,7 +24,7 @@ import thread
 import sqlalchemy.util as util
 import weakref
 
-def get_id_key(ident, class_, table, selectable):
+def get_id_key(ident, class_, table):
     """returns an identity-map key for use in storing/retrieving an item from the identity map, given
     a tuple of the object's primary key values.
     
@@ -37,7 +37,7 @@ def get_id_key(ident, class_, table, selectable):
     return value: a tuple object which is used as an identity key.
     """
     return (class_, table, tuple(ident))
-def get_instance_key(object, class_, table, selectable):
+def get_instance_key(object, class_, table, primary_keys):
     """returns an identity-map key for use in storing/retrieving an item from the identity map, given
     the object instance itself.
     
@@ -49,8 +49,8 @@ def get_instance_key(object, class_, table, selectable):
     may be synonymous with the table argument or can be a larger construct containing that table.
     return value: a tuple object which is used as an identity key.
     """
-    return (class_, table, tuple([getattr(object, column.key, None) for column in selectable.primary_keys]))
-def get_row_key(row, class_, table, selectable):
+    return (class_, table, tuple([getattr(object, column.key, None) for column in primary_keys]))
+def get_row_key(row, class_, table, primary_keys):
     """returns an identity-map key for use in storing/retrieving an item from the identity map, given
     a result set row.
     
@@ -62,7 +62,7 @@ def get_row_key(row, class_, table, selectable):
     may be synonymous with the table argument or can be a larger construct containing that table.
     return value: a tuple object which is used as an identity key.
     """
-    return (class_, table, tuple([row[column.label] for column in selectable.primary_keys]))
+    return (class_, table, tuple([row[column.label] for column in primary_keys]))
 
 identity_map = {}