From: Mike Bayer Date: Sat, 1 Oct 2005 18:42:33 +0000 (+0000) Subject: (no commit message) X-Git-Tag: rel_0_1_0~581 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=595e11015fe6809a56a803bda8f79b784c35e555;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git --- diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index d7cce0ce21..c40383d145 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -107,7 +107,7 @@ class ANSICompiler(sql.Compiled): sep = " " else: sep = " " + compound.operator + " " - + if compound.parens: self.strings[compound] = "(" + string.join([self.get_str(c) for c in compound.clauses], sep) + ")" else: diff --git a/lib/sqlalchemy/engine.py b/lib/sqlalchemy/engine.py index ad88bcc3c3..87216773da 100644 --- a/lib/sqlalchemy/engine.py +++ b/lib/sqlalchemy/engine.py @@ -25,6 +25,10 @@ import sqlalchemy.sql as sql import StringIO import sqlalchemy.types as types +def create_engine(name, *args ,**kwargs): + module = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name) + return module.engine(*args, **kwargs) + class SchemaIterator(schema.SchemaVisitor): """a visitor that can gather text into a buffer and execute the contents of the buffer.""" def __init__(self, sqlproxy, **params): @@ -58,6 +62,7 @@ class SQLEngine(schema.SchemaEngine): self.tables = {} self.notes = {} + def type_descriptor(self, typeobj): if type(typeobj) is type: typeobj = typeobj() diff --git a/lib/sqlalchemy/mapper.py b/lib/sqlalchemy/mapper.py index bdbe84d6ee..e9ec1ee578 100644 --- a/lib/sqlalchemy/mapper.py +++ b/lib/sqlalchemy/mapper.py @@ -25,7 +25,9 @@ import random, copy, types __ALL__ = ['eagermapper', 'eagerloader', 'lazymapper', 'lazyloader', 'eagerload', 'lazyload', 'assignmapper', 'mapper', 'lazyloader', 'lazymapper', 'clear_mappers', 'objectstore', 'sql'] def relation(*args, **params): - if isinstance(args[0], Mapper): + if isinstance(args[0], type) and len(args) == 1: + return relation_loader(*args, **params) + elif isinstance(args[0], Mapper): return relation_loader(*args, **params) else: return relation_mapper(*args, **params) @@ -68,7 +70,9 @@ def mapper(class_, table = None, engine = None, autoload = False, *args, **param return _mappers[hashkey] except KeyError: m = Mapper(hashkey, class_, table, *args, **params) - return _mappers.setdefault(hashkey, m) + _mappers.setdefault(hashkey, m) + m._init_properties() + return _mappers[hashkey] def clear_mappers(): _mappers.clear() @@ -169,6 +173,8 @@ class Mapper(object): # load custom properties if properties is not None: for key, prop in properties.iteritems(): + if isinstance(prop, schema.Column): + prop = ColumnProperty(prop) self.props[key] = prop if isinstance(prop, ColumnProperty): for col in prop.columns: @@ -202,13 +208,13 @@ class Mapper(object): if not self.props.has_key(key): self.props[key] = prop._copy() - if not hasattr(self.class_, '_mapper') or self.is_primary or not _mappers.has_key(self.class_._mapper): self._init_class() - [prop.init(key, self) for key, prop in self.props.iteritems()] engines = property(lambda s: [t.engine for t in s.tables]) + def _init_properties(self): + [prop.init(key, self) for key, prop in self.props.iteritems()] def __str__(self): return "Mapper|" + self.class_.__name__ + "|" + self.primarytable.name def hash_key(self): @@ -217,7 +223,6 @@ class Mapper(object): def _init_class(self): self.class_._mapper = self.hashkey self.class_.c = self.c - def set_property(self, key, prop): self.props[key] = prop prop.init(key, self) @@ -404,6 +409,7 @@ class Mapper(object): case, executes all the property loaders on the instance to also process extra information in the row.""" + # look in main identity map. if its there, we dont do anything to it, # including modifying any of its related items lists, as its already # been exposed to being modified by the application. @@ -430,6 +436,7 @@ class Mapper(object): for col in self.primary_keys[self.table]: if row[col.label] is None: return None + # plugin point instance = self.class_() instance._mapper = self.hashkey instance._instance_key = identitykey @@ -442,7 +449,9 @@ class Mapper(object): if result is not None: result.append_nohistory(instance) - + + # plugin point + # call further mapper properties on the row, to pull further # instances from the row and possibly populate this item. for prop in self.props.values(): @@ -519,17 +528,15 @@ class ColumnProperty(MapperProperty): class PropertyLoader(MapperProperty): """describes an object property that holds a single item or list of items that correspond to a related database table.""" - def __init__(self, mapper, secondary, primaryjoin, secondaryjoin, foreignkey = None, uselist = None, private = False): + def __init__(self, argument, secondary, primaryjoin, secondaryjoin, foreignkey = None, uselist = None, private = False): self.uselist = uselist - self.mapper = mapper - self.target = self.mapper.table + self.argument = argument self.secondary = secondary self.primaryjoin = primaryjoin self.secondaryjoin = secondaryjoin self.foreignkey = foreignkey self.private = private - self._hash_key = "%s(%s, %s, %s, %s, %s, %s)" % (self.__class__.__name__, hash_key(mapper), hash_key(secondary), hash_key(primaryjoin), hash_key(secondaryjoin), hash_key(foreignkey), repr(uselist)) - + self._hash_key = "%s(%s, %s, %s, %s, %s, %s)" % (self.__class__.__name__, hash_key(self.argument), hash_key(secondary), hash_key(primaryjoin), hash_key(secondaryjoin), hash_key(foreignkey), repr(uselist)) def _copy(self): return self.__class__(self.mapper, self.secondary, self.primaryjoin, self.secondaryjoin, self.foreignkey, self.uselist, self.private) @@ -538,8 +545,14 @@ class PropertyLoader(MapperProperty): return self._hash_key def init(self, key, parent): - if isinstance(self.mapper, str): - self.mapper = object_mapper(self.mapper) + if isinstance(self.argument, str): + self.mapper = object_mapper(self.argument) + elif isinstance(self.argument, type): + self.mapper = class_mapper(self.argument) + else: + self.mapper = self.argument + + self.target = self.mapper.table self.key = key self.parent = parent @@ -610,7 +623,7 @@ class PropertyLoader(MapperProperty): elif len(crit) == 1: return (crit[0]) else: - return sql.and_(crit) + return sql.and_(*crit) def register_deleted(self, obj, uow): if not self.private: @@ -647,7 +660,7 @@ class PropertyLoader(MapperProperty): raise " no foreign key ?" def process_dependencies(self, deplist, uowcommit, delete = False): - #print self.mapper.table.name + " process_dep isdelete " + repr(delete) + print self.mapper.table.name + " " + repr(deplist.map.values()) + " process_dep isdelete " + repr(delete) # function to retreive the child list off of an object. "passive" means, if its # a lazy loaded list that is not loaded yet, dont load it. @@ -665,6 +678,8 @@ class PropertyLoader(MapperProperty): associationrow = {} + # plugin point + if self.secondaryjoin is not None: secondary_delete = [] secondary_insert = [] @@ -701,6 +716,7 @@ class PropertyLoader(MapperProperty): statement = self.secondary.insert() statement.execute(*secondary_insert) elif self.foreignkey.table == self.target: + print "HI" if delete and not self.private: updates = [] clearkeys = True @@ -720,9 +736,11 @@ class PropertyLoader(MapperProperty): statement = self.target.update(self.lazywhere, values = values) statement.execute(*updates) else: + print str(self.primaryjoin.compile()) for obj in deplist: childlist = getlist(obj) if childlist is None: return + print "DEP: " +str(obj) + " LIST: " + repr([str(v) for v in childlist.added_items()]) uowcommit.register_saved_list(childlist) clearkeys = False for child in childlist.added_items(): @@ -746,23 +764,35 @@ class PropertyLoader(MapperProperty): self.primaryjoin.accept_visitor(setter) else: raise " no foreign key ?" + + print self.mapper.table.name + " postdep " + repr([str(v) for v in deplist.map.values()]) + " process_dep isdelete " + repr(delete) def _sync_foreign_keys(self, binary, obj, child, associationrow, clearkeys): """given a binary clause with an = operator joining two table columns, synchronizes the values of the corresponding attributes within a parent object and a child object, or the attributes within an an "association row" that represents an association link between the 'parent' and 'child' object.""" if binary.operator == '=': - colmap = {binary.left.table : binary.left, binary.right.table : binary.right} - if colmap.has_key(self.parent.primarytable) and colmap.has_key(self.target): - #print "set " + repr(child) + ":" + colmap[self.target].key + " to " + repr(obj) + ":" + colmap[self.parent.primarytable].key - if clearkeys: - self.mapper._setattrbycolumn(child, colmap[self.target], None) + if binary.left.table == binary.right.table: + if binary.right is self.foreignkey: + source = binary.left + elif binary.left is self.foreignkey: + source = binary.right else: - self.mapper._setattrbycolumn(child, colmap[self.target], self.parent._getattrbycolumn(obj, colmap[self.parent.primarytable])) - elif colmap.has_key(self.parent.primarytable) and colmap.has_key(self.secondary): - associationrow[colmap[self.secondary].key] = self.parent._getattrbycolumn(obj, colmap[self.parent.primarytable]) - elif colmap.has_key(self.target) and colmap.has_key(self.secondary): - associationrow[colmap[self.secondary].key] = self.mapper._getattrbycolumn(child, colmap[self.target]) + raise "Cant determine direction for relationship %s = %s" % (binary.left.fullname, binary.right.fullname) + print "set " + repr(child) + ":" + self.foreignkey.key + " to " + repr(obj) + ":" + source.key + self.mapper._setattrbycolumn(child, self.foreignkey, self.parent._getattrbycolumn(obj, source)) + else: + colmap = {binary.left.table : binary.left, binary.right.table : binary.right} + if colmap.has_key(self.parent.primarytable) and colmap.has_key(self.target): + print "set " + repr(child) + ":" + colmap[self.target].key + " to " + repr(obj) + ":" + colmap[self.parent.primarytable].key + if clearkeys: + self.mapper._setattrbycolumn(child, colmap[self.target], None) + else: + self.mapper._setattrbycolumn(child, colmap[self.target], self.parent._getattrbycolumn(obj, colmap[self.parent.primarytable])) + elif colmap.has_key(self.parent.primarytable) and colmap.has_key(self.secondary): + associationrow[colmap[self.secondary].key] = self.parent._getattrbycolumn(obj, colmap[self.parent.primarytable]) + elif colmap.has_key(self.target) and colmap.has_key(self.secondary): + associationrow[colmap[self.secondary].key] = self.mapper._getattrbycolumn(child, colmap[self.target]) # TODO: break out the lazywhere capability so that the main PropertyLoader can use it @@ -937,9 +967,11 @@ class BinaryVisitor(sql.ClauseVisitor): def hash_key(obj): if obj is None: return 'None' - else: + elif hasattr(obj, 'hash_key'): return obj.hash_key() - + else: + return repr(obj) + def mapper_hash_key(class_, table, primarytable = None, properties = None, scope = "thread", **kwargs): if properties is None: properties = {} diff --git a/lib/sqlalchemy/objectstore.py b/lib/sqlalchemy/objectstore.py index 7d27e78d58..54386c56fb 100644 --- a/lib/sqlalchemy/objectstore.py +++ b/lib/sqlalchemy/objectstore.py @@ -301,6 +301,8 @@ class UOWTransaction(object): for task in self.tasks.values(): task.mapper.register_dependencies(self) + print repr(self.dependencies) + for task in self._sort_dependencies(): obj_list = task.objects if not task.listonly and not task.isdelete: @@ -387,6 +389,8 @@ class UOWTransaction(object): if task is not None: res.append(task) for child in node.children: + if child is node: + continue sort(child, isdel, res) return res diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 3887c63b55..dfa3cdbf47 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -218,13 +218,23 @@ class ClauseElement(object): return self - def compile(self, engine, bindparams = None): + def compile(self, engine = None, bindparams = None): """compiles this SQL expression using its underlying SQLEngine to produce a Compiled object. The actual SQL statement is the Compiled object's string representation. bindparams is an optional dictionary representing the bind parameters to be used with the statement. Currently, only the compilations of INSERT and UPDATE statements use the bind parameters, in order to determine which table columns should be used in the statement.""" + + if engine is None: + for f in self._get_from_objects(): + engine = f.engine + if engine is not None: break + else: + import sqlalchemy.ansisql as ansisql + engine = ansisql.engine() + #raise "no engine supplied, and no engine could be located within the clauses!" + return engine.compile(self, bindparams = bindparams) def execute(self, *multiparams, **params): @@ -317,7 +327,7 @@ class TextClause(ClauseElement): def __init__(self, text = ""): self.text = text self.parens = False - + def accept_visitor(self, visitor): visitor.visit_textclause(self) def hash_key(self): diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 0ee43be01c..bc13f5dbde 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -162,7 +162,10 @@ class HistoryArraySet(UserList.UserList): self.records[item] = None else: self.data = [] - + def __getattr__(self, attr): + """proxies unknown HistoryArraySet methods and attributes to the underlying + data array. this allows custom list classes to be used.""" + return getattr(self.data, attr) def set_data(self, data): # first mark everything current as "deleted" for i in self.data: diff --git a/test/tables.py b/test/tables.py index 0cd6ffb6d6..12eeb6b05b 100644 --- a/test/tables.py +++ b/test/tables.py @@ -2,6 +2,7 @@ from sqlalchemy.sql import * from sqlalchemy.schema import * from sqlalchemy.mapper import * +import sqlalchemy import os import testbase @@ -13,13 +14,12 @@ DATA = True DBTYPE = 'sqlite_memory' if DBTYPE == 'sqlite_memory': - import sqlalchemy.databases.sqlite as sqllite - db = sqllite.engine(':memory:', {}, echo = ECHO) + db = sqlalchemy.engine.create_engine('sqlite', ':memory:', {}, echo = testbase.echo) elif DBTYPE == 'sqlite_file': import sqlalchemy.databases.sqlite as sqllite if os.access('querytest.db', os.F_OK): os.remove('querytest.db') - db = sqllite.engine('querytest.db', opts = {}, echo = ECHO) + db = sqlalchemy.engine.create_engine('sqlite', 'querytest.db', {}, echo = testbase.echo) elif DBTYPE == 'postgres': pass