]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
(no commit message)
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 1 Oct 2005 18:42:33 +0000 (18:42 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 1 Oct 2005 18:42:33 +0000 (18:42 +0000)
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/engine.py
lib/sqlalchemy/mapper.py
lib/sqlalchemy/objectstore.py
lib/sqlalchemy/sql.py
lib/sqlalchemy/util.py
test/tables.py

index d7cce0ce21887d323b3b3b6eaf5a6c84118868d0..c40383d145cb1c47d6fb5b4e5e85f27ce1cf9599 100644 (file)
@@ -107,7 +107,7 @@ class ANSICompiler(sql.Compiled):
             sep = " "
         else:
             sep = " " + compound.operator + " "
-            
+        
         if compound.parens:
             self.strings[compound] = "(" + string.join([self.get_str(c) for c in compound.clauses], sep) + ")"
         else:
index ad88bcc3c3f7ddc61653bf3cedda764342a86a1d..87216773dacedd100484323a3b8c63b4164822ff 100644 (file)
@@ -25,6 +25,10 @@ import sqlalchemy.sql as sql
 import StringIO
 import sqlalchemy.types as types
 
+def create_engine(name, *args ,**kwargs):
+    module = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name)
+    return module.engine(*args, **kwargs)
+
 class SchemaIterator(schema.SchemaVisitor):
     """a visitor that can gather text into a buffer and execute the contents of the buffer."""
     def __init__(self, sqlproxy, **params):
@@ -58,6 +62,7 @@ class SQLEngine(schema.SchemaEngine):
         self.tables = {}
         self.notes = {}
 
+        
     def type_descriptor(self, typeobj):
         if type(typeobj) is type:
             typeobj = typeobj()
index bdbe84d6ee45b08aca1d4cf4f460acff0fbdb01f..e9ec1ee5786c7266308d7175dd1529f2e43c9dd8 100644 (file)
@@ -25,7 +25,9 @@ import random, copy, types
 __ALL__ = ['eagermapper', 'eagerloader', 'lazymapper', 'lazyloader', 'eagerload', 'lazyload', 'assignmapper', 'mapper', 'lazyloader', 'lazymapper', 'clear_mappers', 'objectstore', 'sql']
 
 def relation(*args, **params):
-    if isinstance(args[0], Mapper):
+    if isinstance(args[0], type) and len(args) == 1:
+        return relation_loader(*args, **params)
+    elif isinstance(args[0], Mapper):
         return relation_loader(*args, **params)
     else:
         return relation_mapper(*args, **params)
@@ -68,7 +70,9 @@ def mapper(class_, table = None, engine = None, autoload = False, *args, **param
         return _mappers[hashkey]
     except KeyError:
         m = Mapper(hashkey, class_, table, *args, **params)
-        return _mappers.setdefault(hashkey, m)
+        _mappers.setdefault(hashkey, m)
+        m._init_properties()
+        return _mappers[hashkey]
 
 def clear_mappers():
     _mappers.clear()
@@ -169,6 +173,8 @@ class Mapper(object):
         # load custom properties 
         if properties is not None:
             for key, prop in properties.iteritems():
+                if isinstance(prop, schema.Column):
+                    prop = ColumnProperty(prop)
                 self.props[key] = prop
                 if isinstance(prop, ColumnProperty):
                     for col in prop.columns:
@@ -202,13 +208,13 @@ class Mapper(object):
                 if not self.props.has_key(key):
                     self.props[key] = prop._copy()
                 
-                
         if not hasattr(self.class_, '_mapper') or self.is_primary or not _mappers.has_key(self.class_._mapper):
             self._init_class()
-        [prop.init(key, self) for key, prop in self.props.iteritems()]
         
     engines = property(lambda s: [t.engine for t in s.tables])
 
+    def _init_properties(self):
+        [prop.init(key, self) for key, prop in self.props.iteritems()]
     def __str__(self):
         return "Mapper|" + self.class_.__name__ + "|" + self.primarytable.name
     def hash_key(self):
@@ -217,7 +223,6 @@ class Mapper(object):
     def _init_class(self):
         self.class_._mapper = self.hashkey
         self.class_.c = self.c
-    
     def set_property(self, key, prop):
         self.props[key] = prop
         prop.init(key, self)
@@ -404,6 +409,7 @@ class Mapper(object):
         case, executes all the property loaders on the instance to also process extra information
         in the row."""
 
+            
         # look in main identity map.  if its there, we dont do anything to it,
         # including modifying any of its related items lists, as its already
         # been exposed to being modified by the application.
@@ -430,6 +436,7 @@ class Mapper(object):
             for col in self.primary_keys[self.table]:
                 if row[col.label] is None:
                     return None
+            # plugin point
             instance = self.class_()
             instance._mapper = self.hashkey
             instance._instance_key = identitykey
@@ -442,7 +449,9 @@ class Mapper(object):
 
         if result is not None:
             result.append_nohistory(instance)
-            
+
+        # plugin point
+        
         # call further mapper properties on the row, to pull further 
         # instances from the row and possibly populate this item.
         for prop in self.props.values():
@@ -519,17 +528,15 @@ class ColumnProperty(MapperProperty):
 class PropertyLoader(MapperProperty):
     """describes an object property that holds a single item or list of items that correspond to a related
     database table."""
-    def __init__(self, mapper, secondary, primaryjoin, secondaryjoin, foreignkey = None, uselist = None, private = False):
+    def __init__(self, argument, secondary, primaryjoin, secondaryjoin, foreignkey = None, uselist = None, private = False):
         self.uselist = uselist
-        self.mapper = mapper
-        self.target = self.mapper.table
+        self.argument = argument
         self.secondary = secondary
         self.primaryjoin = primaryjoin
         self.secondaryjoin = secondaryjoin
         self.foreignkey = foreignkey
         self.private = private
-        self._hash_key = "%s(%s, %s, %s, %s, %s, %s)" % (self.__class__.__name__, hash_key(mapper), hash_key(secondary), hash_key(primaryjoin), hash_key(secondaryjoin), hash_key(foreignkey), repr(uselist))
-
+        self._hash_key = "%s(%s, %s, %s, %s, %s, %s)" % (self.__class__.__name__, hash_key(self.argument), hash_key(secondary), hash_key(primaryjoin), hash_key(secondaryjoin), hash_key(foreignkey), repr(uselist))
 
     def _copy(self):
         return self.__class__(self.mapper, self.secondary, self.primaryjoin, self.secondaryjoin, self.foreignkey, self.uselist, self.private)
@@ -538,8 +545,14 @@ class PropertyLoader(MapperProperty):
         return self._hash_key
 
     def init(self, key, parent):
-        if isinstance(self.mapper, str):
-            self.mapper = object_mapper(self.mapper)
+        if isinstance(self.argument, str):
+            self.mapper = object_mapper(self.argument)
+        elif isinstance(self.argument, type):
+            self.mapper = class_mapper(self.argument)
+        else:
+            self.mapper = self.argument
+            
+        self.target = self.mapper.table
             
         self.key = key
         self.parent = parent
@@ -610,7 +623,7 @@ class PropertyLoader(MapperProperty):
         elif len(crit) == 1:
             return (crit[0])
         else:
-            return sql.and_(crit)
+            return sql.and_(*crit)
 
     def register_deleted(self, obj, uow):
         if not self.private:
@@ -647,7 +660,7 @@ class PropertyLoader(MapperProperty):
             raise " no foreign key ?"
                 
     def process_dependencies(self, deplist, uowcommit, delete = False):
-        #print self.mapper.table.name + " process_dep isdelete " + repr(delete)
+        print self.mapper.table.name + " " + repr(deplist.map.values()) + " process_dep isdelete " + repr(delete)
         
         # function to retreive the child list off of an object.  "passive" means, if its
         # a lazy loaded list that is not loaded yet, dont load it.
@@ -665,6 +678,8 @@ class PropertyLoader(MapperProperty):
 
         associationrow = {}
         
+        # plugin point
+        
         if self.secondaryjoin is not None:
             secondary_delete = []
             secondary_insert = []
@@ -701,6 +716,7 @@ class PropertyLoader(MapperProperty):
                     statement = self.secondary.insert()
                     statement.execute(*secondary_insert)
         elif self.foreignkey.table == self.target:
+            print "HI"
             if delete and not self.private:
                 updates = []
                 clearkeys = True
@@ -720,9 +736,11 @@ class PropertyLoader(MapperProperty):
                     statement = self.target.update(self.lazywhere, values = values)
                     statement.execute(*updates)
             else:
+                print str(self.primaryjoin.compile())
                 for obj in deplist:
                     childlist = getlist(obj)
                     if childlist is None: return
+                    print "DEP: " +str(obj) + " LIST: " + repr([str(v) for v in childlist.added_items()])
                     uowcommit.register_saved_list(childlist)
                     clearkeys = False
                     for child in childlist.added_items():
@@ -746,23 +764,35 @@ class PropertyLoader(MapperProperty):
                         self.primaryjoin.accept_visitor(setter)
         else:
             raise " no foreign key ?"
+    
+        print self.mapper.table.name + " postdep " + repr([str(v) for v in deplist.map.values()]) + " process_dep isdelete " + repr(delete)
 
     def _sync_foreign_keys(self, binary, obj, child, associationrow, clearkeys):
         """given a binary clause with an = operator joining two table columns, synchronizes the values 
         of the corresponding attributes within a parent object and a child object, or the attributes within an 
         an "association row" that represents an association link between the 'parent' and 'child' object."""
         if binary.operator == '=':
-            colmap = {binary.left.table : binary.left, binary.right.table : binary.right}
-            if colmap.has_key(self.parent.primarytable) and colmap.has_key(self.target):
-                #print "set " + repr(child) + ":" + colmap[self.target].key + " to " + repr(obj) + ":" + colmap[self.parent.primarytable].key
-                if clearkeys:
-                    self.mapper._setattrbycolumn(child, colmap[self.target], None)
+            if binary.left.table == binary.right.table:
+                if binary.right is self.foreignkey:
+                    source = binary.left
+                elif binary.left is self.foreignkey:
+                    source = binary.right
                 else:
-                    self.mapper._setattrbycolumn(child, colmap[self.target], self.parent._getattrbycolumn(obj, colmap[self.parent.primarytable]))
-            elif colmap.has_key(self.parent.primarytable) and colmap.has_key(self.secondary):
-                associationrow[colmap[self.secondary].key] = self.parent._getattrbycolumn(obj, colmap[self.parent.primarytable])
-            elif colmap.has_key(self.target) and colmap.has_key(self.secondary):
-                associationrow[colmap[self.secondary].key] = self.mapper._getattrbycolumn(child, colmap[self.target])
+                    raise "Cant determine direction for relationship %s = %s" % (binary.left.fullname, binary.right.fullname)
+                print "set " + repr(child) + ":" + self.foreignkey.key + " to " + repr(obj) + ":" + source.key
+                self.mapper._setattrbycolumn(child, self.foreignkey, self.parent._getattrbycolumn(obj, source))
+            else:
+                colmap = {binary.left.table : binary.left, binary.right.table : binary.right}
+                if colmap.has_key(self.parent.primarytable) and colmap.has_key(self.target):
+                    print "set " + repr(child) + ":" + colmap[self.target].key + " to " + repr(obj) + ":" + colmap[self.parent.primarytable].key
+                    if clearkeys:
+                        self.mapper._setattrbycolumn(child, colmap[self.target], None)
+                    else:
+                        self.mapper._setattrbycolumn(child, colmap[self.target], self.parent._getattrbycolumn(obj, colmap[self.parent.primarytable]))
+                elif colmap.has_key(self.parent.primarytable) and colmap.has_key(self.secondary):
+                    associationrow[colmap[self.secondary].key] = self.parent._getattrbycolumn(obj, colmap[self.parent.primarytable])
+                elif colmap.has_key(self.target) and colmap.has_key(self.secondary):
+                    associationrow[colmap[self.secondary].key] = self.mapper._getattrbycolumn(child, colmap[self.target])
             
 
 # TODO: break out the lazywhere capability so that the main PropertyLoader can use it
@@ -937,9 +967,11 @@ class BinaryVisitor(sql.ClauseVisitor):
 def hash_key(obj):
     if obj is None:
         return 'None'
-    else:
+    elif hasattr(obj, 'hash_key'):
         return obj.hash_key()
-
+    else:
+        return repr(obj)
+        
 def mapper_hash_key(class_, table, primarytable = None, properties = None, scope = "thread", **kwargs):
     if properties is None:
         properties = {}
index 7d27e78d586cde8d88b53a0c4dd9febf559c583c..54386c56fb49d3a521ea68780b092169d0165bb2 100644 (file)
@@ -301,6 +301,8 @@ class UOWTransaction(object):
         for task in self.tasks.values():
             task.mapper.register_dependencies(self)
         
+        print repr(self.dependencies)
+        
         for task in self._sort_dependencies():
             obj_list = task.objects
             if not task.listonly and not task.isdelete:
@@ -387,6 +389,8 @@ class UOWTransaction(object):
             if task is not None:
                 res.append(task)
             for child in node.children:
+                if child is node:
+                    continue
                 sort(child, isdel, res)
             return res
             
index 3887c63b5565cb2c3ccaa115bb32cbf1307b95a3..dfa3cdbf47ca0d13422034c150ea484a141e4433 100644 (file)
@@ -218,13 +218,23 @@ class ClauseElement(object):
         return self
 
 
-    def compile(self, engine, bindparams = None):
+    def compile(self, engine = None, bindparams = None):
         """compiles this SQL expression using its underlying SQLEngine to produce
         a Compiled object.  The actual SQL statement is the Compiled object's string representation.   
         bindparams is an optional dictionary representing the bind parameters to be used with 
         the statement.  Currently, only the compilations of INSERT and UPDATE statements
         use the bind parameters, in order to determine which
         table columns should be used in the statement."""
+
+        if engine is None:
+            for f in self._get_from_objects():
+                engine = f.engine
+                if engine is not None: break
+            else:
+                import sqlalchemy.ansisql as ansisql
+                engine = ansisql.engine()
+                #raise "no engine supplied, and no engine could be located within the clauses!"
+
         return engine.compile(self, bindparams = bindparams)
 
     def execute(self, *multiparams, **params):
@@ -317,7 +327,7 @@ class TextClause(ClauseElement):
     def __init__(self, text = ""):
         self.text = text
         self.parens = False
-
+        
     def accept_visitor(self, visitor): visitor.visit_textclause(self)
 
     def hash_key(self):
index 0ee43be01c2e1739d4f439e3dc3215295f1f9cb0..bc13f5dbde6da1b75e50529bfc5230e189be92f4 100644 (file)
@@ -162,7 +162,10 @@ class HistoryArraySet(UserList.UserList):
                 self.records[item] = None
         else:
             self.data = []
-
+    def __getattr__(self, attr):
+        """proxies unknown HistoryArraySet methods and attributes to the underlying
+        data array.  this allows custom list classes to be used."""
+        return getattr(self.data, attr)
     def set_data(self, data):
         # first mark everything current as "deleted"
         for i in self.data:
index 0cd6ffb6d6c2b1ada62fbff2fab4c2d3366a6917..12eeb6b05b7eda41062e6efed33ea76c6c28fceb 100644 (file)
@@ -2,6 +2,7 @@
 from sqlalchemy.sql import *
 from sqlalchemy.schema import *
 from sqlalchemy.mapper import *
+import sqlalchemy
 import os
 import testbase
 
@@ -13,13 +14,12 @@ DATA = True
 DBTYPE = 'sqlite_memory'
 
 if DBTYPE == 'sqlite_memory':
-    import sqlalchemy.databases.sqlite as sqllite
-    db = sqllite.engine(':memory:', {}, echo = ECHO)
+    db = sqlalchemy.engine.create_engine('sqlite', ':memory:', {}, echo = testbase.echo)
 elif DBTYPE == 'sqlite_file':
     import sqlalchemy.databases.sqlite as sqllite
     if os.access('querytest.db', os.F_OK):
         os.remove('querytest.db')
-    db = sqllite.engine('querytest.db', opts = {}, echo = ECHO)
+    db = sqlalchemy.engine.create_engine('sqlite', 'querytest.db', {}, echo = testbase.echo)
 elif DBTYPE == 'postgres':
     pass