From: Mike Bayer Date: Thu, 8 Jun 2006 16:58:14 +0000 (+0000) Subject: late compilation of mappers. now you can create mappers in any order, and they will... X-Git-Tag: rel_0_2_3~19 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=8005c151593f1b9ffcc69b3b32ac57ef1c052fa0;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git late compilation of mappers. now you can create mappers in any order, and they will compile their internal state when first used in a query or flush operation (or their props or 'c'/'columns' attributes are used). includes various cleanups and fixes in support of the change, including some unit test changes, additional unit tests. --- diff --git a/examples/backref/backref_tree.py b/examples/backref/backref_tree.py index 3f81b1145f..d217b22f1c 100644 --- a/examples/backref/backref_tree.py +++ b/examples/backref/backref_tree.py @@ -19,7 +19,7 @@ table.create() mapper(Tree, table, properties={ - 'childs':relation(Tree, foreignkey=table.c.father_id, primaryjoin=table.c.father_id==table.c.id, backref=backref('father', uselist=False, foreignkey=table.c.id))}, + 'childs':relation(Tree, foreignkey=table.c.father_id, primaryjoin=table.c.father_id==table.c.id, backref=backref('father', foreignkey=table.c.id))}, ) root = Tree('root') @@ -27,6 +27,9 @@ child1 = Tree('child1', root) child2 = Tree('child2', root) child3 = Tree('child3', child1) +child4 = Tree('child4') +child1.childs.append(child4) + session = create_session() session.save(root) session.flush() diff --git a/examples/vertical/vertical.py b/examples/vertical/vertical.py index fbd9021ffa..66224fb5bd 100644 --- a/examples/vertical/vertical.py +++ b/examples/vertical/vertical.py @@ -1,5 +1,6 @@ from sqlalchemy import * import datetime +import sys """this example illustrates a "vertical table". an object is stored with individual attributes represented in distinct database rows. This allows objects to be created with dynamically changing diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 21f19da02e..0fbcdb8e17 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -10,7 +10,7 @@ import util as mapperutil import sync import query as querylib import session as sessionlib -import sys, weakref, sets +import weakref __all__ = ['Mapper', 'MapperExtension', 'class_mapper', 'object_mapper', 'EXT_PASS'] @@ -28,6 +28,12 @@ NO_ATTRIBUTE = object() # returned by a MapperExtension method to indicate a "do nothing" response EXT_PASS = object() +# as mappers are constructed, they place records in this dictionary +# to set up "compile triggers" between mappers related by backref setups, so that when one +# mapper compiles it can trigger the compilation of a second mapper which needs to place +# a backref on the first. +_compile_triggers = {} + class Mapper(object): """Persists object instances to and from schema.Table objects via the sql package. Instances of this class should be constructed through this package's mapper() or @@ -53,6 +59,161 @@ class Mapper(object): concrete=False, select_table=None): + if not issubclass(class_, object): + raise exceptions.ArgumentError("Class '%s' is not a new-style class" % class_.__name__) + + for table in (local_table, select_table): + if table is not None and isinstance(table, sql.SelectBaseMixin): + # some db's, noteably postgres, dont want to select from a select + # without an alias. also if we make our own alias internally, then + # the configured properties on the mapper are not matched against the alias + # we make, theres workarounds but it starts to get really crazy (its crazy enough + # the SQL that gets generated) so just require an alias + raise exceptions.ArgumentError("Mapping against a Select object requires that it has a name. Use an alias to give it a name, i.e. s = select(...).alias('myselect')") + + self.class_ = class_ + self.entity_name = entity_name + self.class_key = ClassKey(class_, entity_name) + self.is_primary = is_primary + self.primary_key = primary_key + self.non_primary = non_primary + self.order_by = order_by + self.always_refresh = always_refresh + self.version_id_col = version_id_col + self.concrete = concrete + self.inherits = inherits + self.select_table = select_table + self.local_table = local_table + self.inherit_condition = inherit_condition + self.extension = extension + self.properties = properties or {} + self.allow_column_override = allow_column_override + + # a Column which is used during a select operation to retrieve the + # "polymorphic identity" of the row, which indicates which Mapper should be used + # to construct a new object instance from that row. + self.polymorphic_on = polymorphic_on + + # our 'polymorphic identity', a string name that when located in a result set row + # indicates this Mapper should be used to construct the object instance for that row. + self.polymorphic_identity = polymorphic_identity + + # a dictionary of 'polymorphic identity' names, associating those names with + # Mappers that will be used to construct object instances upon a select operation. + if polymorphic_map is None: + self.polymorphic_map = {} + else: + self.polymorphic_map = polymorphic_map + + class LOrderedProp(util.OrderedProperties): + """this extends OrderedProperties to trigger a compile() before the + members of the object are accessed.""" + def __getattr__(s, key): + self.compile() + return util.OrderedProperties.__getattr__(s, key) + + self.columns = LOrderedProp() + self.c = self.columns + + # each time the options() method is called, the resulting Mapper is + # stored in this dictionary based on the given options for fast re-access + self._options = {} + + # a set of all mappers which inherit from this one. + self._inheriting_mappers = util.Set() + + # a second mapper that is used for selecting, if the "select_table" argument + # was sent to this mapper. + self.__surrogate_mapper = None + + # whether or not our compile() method has been called already. + self.__is_compiled = False + + # if this mapper is to be a primary mapper (i.e. the non_primary flag is not set), + # associate this Mapper with the given class_ and entity name. subsequent + # calls to class_mapper() for the class_/entity name combination will return this + # mapper. + self._compile_class() + + # for all MapperProperties sent in the properties dictionary (typically this means + # (relation() instances), call the "attach()" method which may be used to set up + # compile triggers for this Mapper. + for prop in self.properties.values(): + if isinstance(prop, MapperProperty): + prop.attach(self) + + # uncomment to compile at construction time (the old way) + # this will break mapper setups that arent declared in the order + # of dependency + #self.compile() + + def _get_props(self): + self.compile() + return self.__props + props = property(_get_props, doc="compiles this mapper if needed, and returns the \ + dictionary of MapperProperty objects associated with this mapper.") + + def compile(self): + """compile. this step assembles the Mapper's constructor arguments into their final internal + format, which includes establishing its relationships with other Mappers either via inheritance + relationships or via attached properties. This step is deferred to when the Mapper is first used + (i.e. queried, used in a flush(), or its columns or properties are accessed) so that Mappers can be + constructed in an arbitrary order, completing their relationships when they have all been established.""" + if self.__is_compiled: + return self + #print "COMPILING!", self.class_key + self.__is_compiled = True + self._compile_extensions() + self._compile_inheritance() + self._compile_tables() + self._compile_selectable() + self._compile_properties() + self._initialize_properties() + + # compile some other mappers which have backrefs to this mapper + triggerset = _compile_triggers.pop(self.class_key, None) + if triggerset is not None: + for rec in triggerset: + (mapper, set) = rec + set.remove(self.class_key) + if len(set) == 0: + mapper.compile() + del _compile_triggers[mapper] + + return self + + def _add_compile_trigger(self, argument): + """given an argument which is either a Class or a Mapper, sets a + "compile trigger" indicating this mapper should be compiled directly + after the given mapper (or class's mapper) is compiled.""" + + if isinstance(argument, Mapper): + classkey = argument.class_key + else: + classkey = ClassKey(argument, None) + + try: + rec = _compile_triggers[self] + except KeyError: + # a tuple of: (mapper to be compiled, Set of classkeys of mappers to be compiled first) + rec = (self, util.Set()) + _compile_triggers[self] = rec + if classkey in rec[1]: + return + + rec[1].add(classkey) + try: + triggers = _compile_triggers[classkey] + except KeyError: + # list of the above tuples corresponding to a particular class key + triggers = [] + _compile_triggers[classkey] = triggers + triggers.append(rec) + _compile_triggers[classkey] = triggers + + def _compile_extensions(self): + """goes through the global_extensions list as well as the list of MapperExtensions + specified for this Mapper and creates a linked list of those extensions.""" # uber-pendantic style of making mapper chain, as various testbase/ # threadlocal/assignmapper combinations keep putting dupes etc. in the list # TODO: do something that isnt 21 lines.... @@ -63,10 +224,11 @@ class Mapper(object): else: extlist.add(ext_class()) + extension = self.extension if extension is not None: for ext_obj in util.to_list(extension): extlist.add(ext_obj) - + self.extension = None previous = None for ext in extlist: @@ -77,117 +239,103 @@ class Mapper(object): previous = ext if self.extension is None: self.extension = MapperExtension() - - self.class_ = class_ - self.entity_name = entity_name - self.class_key = ClassKey(class_, entity_name) - self.is_primary = is_primary - self.non_primary = non_primary - self.order_by = order_by - self._options = {} - self.always_refresh = always_refresh - self.version_id_col = version_id_col - self._inheriting_mappers = util.Set() - self.polymorphic_on = polymorphic_on - if polymorphic_map is None: - self.polymorphic_map = {} - else: - self.polymorphic_map = polymorphic_map - self.__surrogate_mapper = None - self._surrogate_parent = None - - if not issubclass(class_, object): - raise exceptions.ArgumentError("Class '%s' is not a new-style class" % class_.__name__) - # set up various Selectable units: - - # mapped_table - the Selectable that represents a join of the underlying Tables to be saved (or just the Table) - # local_table - the Selectable that was passed to this Mapper's constructor, if any - # select_table - the Selectable that will be used during queries. if this is specified - # as a constructor keyword argument, it takes precendence over mapped_table, otherwise its mapped_table - # unjoined_table - our Selectable, minus any joins constructed against the inherits table. - # this is either select_table if it was given explicitly, or in the case of a mapper that inherits - # its local_table - # tables - a collection of underlying Table objects pulled from mapped_table - - for table in (local_table, select_table): - if table is not None and isinstance(table, sql.SelectBaseMixin): - # some db's, noteably postgres, dont want to select from a select - # without an alias. also if we make our own alias internally, then - # the configured properties on the mapper are not matched against the alias - # we make, theres workarounds but it starts to get really crazy (its crazy enough - # the SQL that gets generated) so just require an alias - raise exceptions.ArgumentError("Mapping against a Select object requires that it has a name. Use an alias to give it a name, i.e. s = select(...).alias('myselect')") - - self.local_table = local_table - - if inherits is not None: - if isinstance(inherits, type): - inherits = class_mapper(inherits) - if self.class_.__mro__[1] != inherits.class_: - raise exceptions.ArgumentError("Class '%s' does not inherit from '%s'" % (self.class_.__name__, inherits.class_.__name__)) + def _compile_inheritance(self): + """determines if this Mapper inherits from another mapper, and if so calculates the mapped_table + for this Mapper taking the inherited mapper into account. for joined table inheritance, creates + a SyncRule that will synchronize column values between the joined tables. also initializes polymorphic variables + used in polymorphic loads.""" + if self.inherits is not None: + if isinstance(self.inherits, type): + self.inherits = class_mapper(self.inherits) + else: + self.inherits = self.inherits.compile() + if self.class_.__mro__[1] != self.inherits.class_: + raise exceptions.ArgumentError("Class '%s' does not inherit from '%s'" % (self.class_.__name__, self.inherits.class_.__name__)) # inherit_condition is optional. - if local_table is None: - self.local_table = local_table = inherits.local_table - if not local_table is inherits.local_table: - if concrete: + if self.local_table is None: + self.local_table = self.inherits.local_table + if not self.local_table is self.inherits.local_table: + if self.concrete: self._synchronizer= None self.mapped_table = self.local_table else: - if inherit_condition is None: + if self.inherit_condition is None: # figure out inherit condition from our table to the immediate table # of the inherited mapper, not its full table which could pull in other # stuff we dont want (allows test/inheritance.InheritTest4 to pass) - inherit_condition = sql.join(inherits.local_table, self.local_table).onclause - self.mapped_table = sql.join(inherits.mapped_table, self.local_table, inherit_condition) - #print "inherit condition", str(self.table.onclause) - + self.inherit_condition = sql.join(self.inherits.local_table, self.local_table).onclause + self.mapped_table = sql.join(self.inherits.mapped_table, self.local_table, self.inherit_condition) # generate sync rules. similarly to creating the on clause, specify a # stricter set of tables to create "sync rules" by,based on the immediate # inherited table, rather than all inherited tables self._synchronizer = sync.ClauseSynchronizer(self, self, sync.ONETOMANY) - self._synchronizer.compile(self.mapped_table.onclause, util.Set([inherits.local_table]), sqlutil.TableFinder(self.local_table)) + self._synchronizer.compile(self.mapped_table.onclause, util.Set([self.inherits.local_table]), sqlutil.TableFinder(self.local_table)) else: self._synchronizer = None self.mapped_table = self.local_table - self.inherits = inherits - if polymorphic_identity is not None: - inherits.add_polymorphic_mapping(polymorphic_identity, self) - self.polymorphic_identity = polymorphic_identity - if self.polymorphic_on is None and inherits.polymorphic_on is not None: - self.polymorphic_on = self.mapped_table.corresponding_column(inherits.polymorphic_on, keys_ok=True, raiseerr=False) + if self.polymorphic_identity is not None: + self.inherits._add_polymorphic_mapping(self.polymorphic_identity, self) + if self.polymorphic_on is None and self.inherits.polymorphic_on is not None: + self.polymorphic_on = self.mapped_table.corresponding_column(self.inherits.polymorphic_on, keys_ok=True, raiseerr=False) if self.order_by is False: - self.order_by = inherits.order_by - self.polymorphic_map = inherits.polymorphic_map + self.order_by = self.inherits.order_by + self.polymorphic_map = self.inherits.polymorphic_map else: self._synchronizer = None - self.inherits = None self.mapped_table = self.local_table - if polymorphic_identity is not None: - self.add_polymorphic_mapping(polymorphic_identity, self) - self.polymorphic_identity = polymorphic_identity + if self.polymorphic_identity is not None: + self._add_polymorphic_mapping(self.polymorphic_identity, self) + + # convert polymorphic class associations to mappers + for key in self.polymorphic_map.keys(): + if isinstance(self.polymorphic_map[key], type): + self.polymorphic_map[key] = class_mapper(self.polymorphic_map[key]) - if select_table is not None: - self.select_table = select_table - else: + def _add_polymorphic_mapping(self, key, class_or_mapper, entity_name=None): + """adds a Mapper to our 'polymorphic map' """ + if isinstance(class_or_mapper, type): + class_or_mapper = class_mapper(class_or_mapper, entity_name=entity_name) + self.polymorphic_map[key] = class_or_mapper + + def _compile_tables(self): + """after the inheritance relationships have been reconciled, sets up some more table-based instance + variables and determines the "primary key" columns for all tables represented by this Mapper.""" + + # summary of the various Selectable units: + # mapped_table - the Selectable that represents a join of the underlying Tables to be saved (or just the Table) + # local_table - the Selectable that was passed to this Mapper's constructor, if any + # select_table - the Selectable that will be used during queries. if this is specified + # as a constructor keyword argument, it takes precendence over mapped_table, otherwise its mapped_table + # unjoined_table - our Selectable, minus any joins constructed against the inherits table. + # this is either select_table if it was given explicitly, or in the case of a mapper that inherits + # its local_table + # tables - a collection of underlying Table objects pulled from mapped_table + + if self.select_table is None: self.select_table = self.mapped_table self.unjoined_table = self.local_table - # locate all tables contained within the "table" passed in, which # may be a join or other construct self.tables = sqlutil.TableFinder(self.mapped_table) # determine primary key columns, either passed in, or get them from our set of tables self.pks_by_table = {} - if primary_key is not None: - for k in primary_key: + if self.primary_key is not None: + # determine primary keys using user-given list of primary key columns as a guide + # + # TODO: this might not work very well for joined-table and/or polymorphic + # inheritance mappers since local_table isnt taken into account nor is select_table + # need to test custom primary key columns used with inheriting mappers + for k in self.primary_key: self.pks_by_table.setdefault(k.table, util.OrderedSet()).add(k) if k.table != self.mapped_table: # associate pk cols from subtables to the "main" table self.pks_by_table.setdefault(self.mapped_table, util.OrderedSet()).add(k) - # TODO: need local_table properly accounted for when custom primary key is sent else: + # no user-defined primary key columns - go through all of our represented tables + # and assemble primary key columns for t in self.tables + [self.mapped_table]: try: l = self.pks_by_table[t] @@ -195,38 +343,39 @@ class Mapper(object): l = self.pks_by_table.setdefault(t, util.OrderedSet()) for k in t.primary_key: l.add(k) - + if len(self.pks_by_table[self.mapped_table]) == 0: raise exceptions.ArgumentError("Could not assemble any primary key columns for mapped table '%s'" % (self.mapped_table.name)) - - # make table columns addressable via the mapper - self.columns = util.OrderedProperties() - self.c = self.columns - + + + def _compile_properties(self): + """inspects the properties dictionary sent to the Mapper's constructor as well as the mapped_table, and creates + MapperProperty objects corresponding to each mapped column and relation. also grabs MapperProperties from the + inherited mapper, if any, and creates copies of them to attach to this Mapper.""" # object attribute names mapped to MapperProperty objects - self.props = {} - + self.__props = {} + # table columns mapped to lists of MapperProperty objects # using a list allows a single column to be defined as # populating multiple object attributes self.columntoproperty = TranslatingDict(self.mapped_table) - + # load custom properties - if properties is not None: - for key, prop in properties.iteritems(): - self.add_property(key, prop, False) + if self.properties is not None: + for key, prop in self.properties.iteritems(): + self._compile_property(key, prop, False) - if inherits is not None: + if self.inherits is not None: # transfer properties from the inherited mapper to here. # this includes column properties as well as relations. # the column properties will attempt to be translated from the selectable unit # of the parent mapper to this mapper's selectable unit. - inherits._inheriting_mappers.add(self) - for key, prop in inherits.props.iteritems(): - if not self.props.has_key(key): + self.inherits._inheriting_mappers.add(self) + for key, prop in self.inherits.props.iteritems(): + if not self.__props.has_key(key): p = prop.copy() if p.adapt(self): - self.add_property(key, p, init=False) + self._compile_property(key, p, init=False) # load properties from the main table object, # not overriding those set up in the 'properties' argument @@ -236,58 +385,100 @@ class Mapper(object): if not self.columns.has_key(column.key): self.columns[column.key] = self.select_table.corresponding_column(column, keys_ok=True, raiseerr=True) - prop = self.props.get(column.key, None) + prop = self.__props.get(column.key, None) if prop is None: prop = ColumnProperty(column) - self.props[column.key] = prop + self.__props[column.key] = prop elif isinstance(prop, ColumnProperty): - # the order which columns are appended to a ColumnProperty is significant, as the - # column at index 0 determines which result column is used to populate the object - # attribute, in the case of mapping against a join with column names repeated - # (and particularly in an inheritance relationship) - # TODO: clarify this comment - prop.columns.insert(0, column) - #prop.columns.append(column) + prop.columns.append(column) else: - if not allow_column_override: + if not self.allow_column_override: raise exceptions.ArgumentError("WARNING: column '%s' not being added due to property '%s'. Specify 'allow_column_override=True' to mapper() to ignore this condition." % (column.key, repr(prop))) else: continue - + # its a ColumnProperty - match the ultimate table columns # back to the property proplist = self.columntoproperty.setdefault(column, []) proplist.append(prop) - - if not non_primary and (not mapper_registry.has_key(self.class_key) or self.is_primary or (inherits is not None and inherits._is_primary_mapper())): - sessionlib.global_attributes.reset_class_managed(self.class_) - self._init_class() - elif not non_primary: - raise exceptions.ArgumentError("Class '%s' already has a primary mapper defined. Use is_primary=True to assign a new primary mapper to the class, or use non_primary=True to create a non primary Mapper" % self.class_) - for key in self.polymorphic_map.keys(): - if isinstance(self.polymorphic_map[key], type): - self.polymorphic_map[key] = class_mapper(self.polymorphic_map[key]) + def _initialize_properties(self): + """calls the init() method on all MapperProperties attached to this mapper. this will incur the + compilation of related mappers.""" + l = [(key, prop) for key, prop in self.__props.iteritems()] + for key, prop in l: + if getattr(prop, 'key', None) is None: + prop.init(key, self) - # select_table specified...set up a surrogate mapper that will be used for selects - # select_table has to encompass all the columns of the mapped_table either directly - # or through proxying relationships + def _compile_selectable(self): + """if the 'select_table' keyword argument was specified, + set up a second "surrogate mapper" that will be used for select operations. + the columns of select_table should encompass all the columns of the mapped_table either directly + or through proxying relationships.""" if self.select_table is not self.mapped_table: props = {} - if properties is not None: - for key, prop in properties.iteritems(): + if self.properties is not None: + for key, prop in self.properties.iteritems(): if sql.is_column(prop): props[key] = self.select_table.corresponding_column(prop) elif (isinstance(prop, list) and sql.is_column(prop[0])): props[key] = [self.select_table.corresponding_column(c) for c in prop] - self.__surrogate_mapper = Mapper(self.class_, self.select_table, non_primary=True, properties=props, polymorphic_map=self.polymorphic_map, polymorphic_on=self.select_table.corresponding_column(self.polymorphic_on)) - - l = [(key, prop) for key, prop in self.props.iteritems()] - for key, prop in l: - if getattr(prop, 'key', None) is None: - prop.init(key, self) + self.__surrogate_mapper = Mapper(self.class_, self.select_table, non_primary=True, properties=props, polymorphic_map=self.polymorphic_map, polymorphic_on=self.select_table.corresponding_column(self.polymorphic_on)).compile() + + def _compile_class(self): + """if this mapper is to be a primary mapper (i.e. the non_primary flag is not set), + associate this Mapper with the given class_ and entity name. subsequent + calls to class_mapper() for the class_/entity name combination will return this + mapper. also decorates the __init__ method on the mapped class to include auto-session attachment logic.""" + if self.non_primary: + return + + if not self.non_primary and (mapper_registry.has_key(self.class_key) and not self.is_primary): + raise exceptions.ArgumentError("Class '%s' already has a primary mapper defined. Use is_primary=True to assign a new primary mapper to the class, or use non_primary=True to create a non primary Mapper" % self.class_) + sessionlib.global_attributes.reset_class_managed(self.class_) + oldinit = self.class_.__init__ + def init(self, *args, **kwargs): + entity_name = kwargs.pop('_sa_entity_name', None) + mapper = mapper_registry.get(ClassKey(self.__class__, entity_name)) + if mapper is not None: + mapper = mapper.compile() + + # this gets the AttributeManager to do some pre-initialization, + # in order to save on KeyErrors later on + sessionlib.global_attributes.init_attr(self) + + if kwargs.has_key('_sa_session'): + session = kwargs.pop('_sa_session') + else: + # works for whatever mapper the class is associated with + if mapper is not None: + session = mapper.extension.get_session() + if session is EXT_PASS: + session = None + else: + session = None + # if a session was found, either via _sa_session or via mapper extension, + # and we have found a mapper, save() this instance to the session, and give it an associated entity_name. + # otherwise, this instance will not have a session or mapper association until it is + # save()d to some session. + if session is not None and mapper is not None: + self._entity_name = entity_name + session._register_new(self) + + if oldinit is not None: + oldinit(self, *args, **kwargs) + # override oldinit, insuring that its not already a Mapper-decorated init method + if oldinit is None or not hasattr(oldinit, '_sa_mapper_init'): + init._sa_mapper_init = True + init.__name__ = oldinit.__name__ + init.__doc__ = oldinit.__doc__ + self.class_.__init__ = init + mapper_registry[self.class_key] = self + if self.entity_name is None: + self.class_.c = self.c + def base_mapper(self): """returns the ultimate base mapper in an inheritance chain""" if self.inherits is not None: @@ -295,20 +486,19 @@ class Mapper(object): else: return self - def _inherits(self, mapper): - """returns True if the given mapper and this mapper are in the same inheritance hierarchy""" - return self.base_mapper() is mapper.base_mapper() - - def add_polymorphic_mapping(self, key, class_or_mapper, entity_name=None): - if isinstance(class_or_mapper, type): - class_or_mapper = class_mapper(class_or_mapper, entity_name=entity_name) - self.polymorphic_map[key] = class_or_mapper - def add_properties(self, dict_of_properties): """adds the given dictionary of properties to this mapper, using add_property.""" for key, value in dict_of_properties.iteritems(): - self.add_property(key, value, True) - + self.add_property(key, value) + + def add_property(self, key, prop): + """adds an indiviual MapperProperty to this mapper. If the mapper has not been compiled yet, + just adds the property to the initial properties dictionary sent to the constructor. if this Mapper + has already been compiled, then the given MapperProperty is compiled immediately.""" + self.properties[key] = prop + if self.__is_compiled: + self._compile_property(key, prop, init=True) + def _create_prop_from_column(self, column, skipmissing=False): if sql.is_column(column): try: @@ -329,8 +519,8 @@ class Mapper(object): return ColumnProperty(*column) else: return None - - def add_property(self, key, prop, init=True, skipmissing=False): + + def _compile_property(self, key, prop, init=True, skipmissing=False): """adds an additional property to this mapper. this is the same as if it were specified within the 'properties' argument to the constructor. if the named property already exists, this will replace it. Useful for @@ -342,7 +532,7 @@ class Mapper(object): if prop is None: raise exceptions.ArgumentError("'%s' is not an instance of MapperProperty or Column" % repr(prop)) - self.props[key] = prop + self.__props[key] = prop if isinstance(prop, ColumnProperty): col = self.select_table.corresponding_column(prop.columns[0], keys_ok=True, raiseerr=False) @@ -359,7 +549,7 @@ class Mapper(object): for mapper in self._inheriting_mappers: p = prop.copy() if p.adapt(mapper): - mapper.add_property(key, p, init=False) + mapper._compile_property(key, p, init=False) def __str__(self): return "Mapper|" + self.class_.__name__ + "|" + (self.entity_name is not None and "/%s" % self.entity_name or "") + self.mapped_table.name @@ -382,48 +572,6 @@ class Mapper(object): instance will return the primary mapper corresponding to this Mapper's class and entity name.""" instance._entity_name = self.entity_name - def _init_class(self): - """decorates the __init__ method on the mapped class to include auto-session attachment logic, - and assocites this Mapper with its class via the mapper_registry.""" - oldinit = self.class_.__init__ - def init(self, *args, **kwargs): - - # this gets the AttributeManager to do some pre-initialization, - # in order to save on KeyErrors later on - sessionlib.global_attributes.init_attr(self) - - entity_name = kwargs.pop('_sa_entity_name', None) - if kwargs.has_key('_sa_session'): - session = kwargs.pop('_sa_session') - else: - # works for whatever mapper the class is associated with - mapper = mapper_registry.get(ClassKey(self.__class__, entity_name)) - if mapper is not None: - session = mapper.extension.get_session() - if session is EXT_PASS: - session = None - else: - session = None - # if a session was found, either via _sa_session or via mapper extension, - # save() this instance to the session, and give it an associated entity_name. - # otherwise, this instance will not have a session or mapper association until it is - # save()d to some session. - if session is not None: - self._entity_name = entity_name - session._register_new(self) - - if oldinit is not None: - oldinit(self, *args, **kwargs) - # override oldinit, insuring that its not already a Mapper-decorated init method - if oldinit is None or not hasattr(oldinit, '_sa_mapper_init'): - init._sa_mapper_init = True - init.__name__ = oldinit.__name__ - init.__doc__ = oldinit.__doc__ - self.class_.__init__ = init - mapper_registry[self.class_key] = self - if self.entity_name is None: - self.class_.c = self.c - def get_session(self): """returns the contextual session provided by the mapper extension chain @@ -439,9 +587,6 @@ class Mapper(object): """returns True if one of the properties attached to this Mapper is eager loading""" return getattr(self, '_has_eager', False) - def set_property(self, key, prop): - self.props[key] = prop - prop.init(key, self) def instances(self, cursor, session, *mappers, **kwargs): """given a cursor (ResultProxy) from an SQLEngine, returns a list of object instances @@ -492,13 +637,16 @@ class Mapper(object): mapper = Mapper.__new__(Mapper) mapper.__dict__.update(self.__dict__) mapper.__dict__.update(kwargs) - mapper.props = self.props.copy() + mapper.__props = self.__props.copy() + mapper._inheriting_mappers = [] + for m in self._inheriting_mappers: + mapper._inheriting_mappers.append(m.copy()) return mapper def options(self, *options, **kwargs): """uses this mapper as a prototype for a new mapper with different behavior. *options is a list of options directives, which include eagerload(), lazyload(), and noload()""" - + self.compile() optkey = repr([hash_key(o) for o in options]) try: return self._options[optkey] @@ -515,7 +663,7 @@ class Mapper(object): prop = self.columntoproperty[column] except KeyError: try: - prop = self.props[column.key] + prop = self.__props[column.key] if not raiseerror: return None raise exceptions.InvalidRequestError("Column '%s.%s' is not available, due to conflicting property '%s':%s" % (column.table.name, column.name, column.key, repr(prop))) @@ -753,20 +901,20 @@ class Mapper(object): """called by an instance of unitofwork.UOWTransaction to register which mappers are dependent on which, as well as DependencyProcessor objects which will process lists of objects in between saves and deletes.""" - for prop in self.props.values(): + for prop in self.__props.values(): prop.register_dependencies(uowcommit, *args, **kwargs) def cascade_iterator(self, type, object, callable_=None, recursive=None): if recursive is None: recursive=util.Set() - for prop in self.props.values(): + for prop in self.__props.values(): for c in prop.cascade_iterator(type, object, recursive): yield c def cascade_callable(self, type, object, callable_, recursive=None): if recursive is None: recursive=util.Set() - for prop in self.props.values(): + for prop in self.__props.values(): prop.cascade_callable(type, object, callable_, recursive) def _row_identity_key(self, row): @@ -800,7 +948,7 @@ class Mapper(object): if populate_existing or session.is_expired(instance, unexpire=True): if not imap.has_key(identitykey): imap[identitykey] = instance - for prop in self.props.values(): + for prop in self.__props.values(): prop.execute(session, instance, row, identitykey, imap, True) if self.extension.append_result(self, session, row, imap, result, instance, isnew, populate_existing=populate_existing) is EXT_PASS: if result is not None: @@ -859,7 +1007,7 @@ class Mapper(object): def populate_instance(self, session, instance, row, identitykey, imap, isnew, frommapper=None): if frommapper is not None: row = frommapper.translate_row(self, row) - for prop in self.props.values(): + for prop in self.__props.values(): prop.execute(session, instance, row, identitykey, imap, isnew) # deprecated query methods. Query is constructed from Session, and the rest @@ -879,15 +1027,6 @@ class Mapper(object): def using(self, session): """deprecated. use Query instead.""" return querylib.Query(self, session=session) - def __getattr__(self, key): - """deprecated. use Query instead.""" - if (key.startswith('select_by_') or key.startswith('get_by_')): - return getattr(self.query(), key) - else: - raise AttributeError(key) - def compile(self, whereclause = None, **options): - """deprecated. use Query instead.""" - return self.query()._compile(whereclause, **options) def get(self, ident, **kwargs): """deprecated. use Query instead.""" return self.query().get(ident, **kwargs) @@ -934,6 +1073,10 @@ class Mapper(object): class MapperProperty(object): """an element attached to a Mapper that describes and assists in the loading and saving of an attribute on an object instance.""" + def attach(self, mapper): + """called during mapper construction for each property present in the "properties" dictionary. + this is before the Mapper has compiled its internal state.""" + pass def execute(self, session, instance, row, identitykey, imap, isnew): """called when the mapper receives a row. instance is the parent instance corresponding to the row. """ @@ -952,16 +1095,11 @@ class MapperProperty(object): this is called by a mappers select_by method to formulate a set of key/value pairs into a WHERE criterion that spans multiple tables if needed.""" return None - def hash_key(self): - """describes this property and its instantiated arguments in such a way - as to uniquely identify the concept this MapperProperty represents,within - a process.""" - raise NotImplementedError() def setup(self, key, statement, **options): """called when a statement is being constructed. """ return self def init(self, key, parent): - """called when the MapperProperty is first attached to a new parent Mapper.""" + """called during Mapper compilation to compile each MapperProperty.""" self.key = key self.parent = parent self.localparent = parent @@ -969,7 +1107,7 @@ class MapperProperty(object): def adapt(self, newparent): """adapts this MapperProperty to a new parent, assuming the new parent is an inheriting descendant of the old parent. Should return True if the adaptation was successful, or - False if this MapperProperty cannot be adapted to the new parent (the case for this is, + False if this MapperProperty cannot be adapted to the new parent (the case for "False" is, the parent mapper has a polymorphic select, and this property represents a column that is not represented in the new mapper's mapped table)""" #self.parent = newparent @@ -982,6 +1120,9 @@ class MapperProperty(object): """called when the instance is being deleted""" pass def register_dependencies(self, *args, **kwargs): + """called by the Mapper in response to the UnitOfWork calling the Mapper's + register_dependencies operation. Should register with the UnitOfWork all + inter-mapper dependencies as well as dependency processors (see UOW docs for more details)""" pass def is_primary(self): """a return value of True indicates we are the primary MapperProperty for this loader's @@ -1006,11 +1147,12 @@ class ExtensionOption(MapperOption): mapper.extension = self.ext class MapperExtension(object): + """base implementation for an object that provides overriding behavior to various + Mapper functions. For each method in MapperExtension, a result of EXT_PASS indicates + the functionality is not overridden.""" def __init__(self): self.next = None def chain(self, ext): - if ext is self: - raise "nu uh " + repr(self) + " " + repr(ext) self.next = ext return self def get_session(self): @@ -1152,15 +1294,18 @@ class TranslatingDict(dict): return super(TranslatingDict, self).setdefault(self.__translate_col(col), value) class ClassKey(object): - """keys a class and an entity name to a mapper, via the mapper_registry""" + """keys a class and an entity name to a mapper, via the mapper_registry.""" + __metaclass__ = util.ArgSingleton def __init__(self, class_, entity_name): self.class_ = class_ self.entity_name = entity_name def __hash__(self): return hash((self.class_, self.entity_name)) def __eq__(self, other): - return self.class_ is other.class_ and self.entity_name == other.entity_name - + return self is other + def __repr__(self): + return "ClassKey(%s, %s)" % (repr(self.class_), repr(self.entity_name)) + def hash_key(obj): if obj is None: return 'None' @@ -1178,16 +1323,20 @@ def has_mapper(object): def object_mapper(object, raiseerror=True): """given an object, returns the primary Mapper associated with the object instance""" try: - return mapper_registry[ClassKey(object.__class__, getattr(object, '_entity_name'))] + mapper = mapper_registry[ClassKey(object.__class__, getattr(object, '_entity_name'))] except (KeyError, AttributeError): if raiseerror: raise exceptions.InvalidRequestError("Class '%s' entity name '%s' has no mapper associated with it" % (object.__class__.__name__, getattr(object, '_entity_name', None))) else: return None - + return mapper.compile() + def class_mapper(class_, entity_name=None): """given a ClassKey, returns the primary Mapper associated with the key.""" try: - return mapper_registry[ClassKey(class_, entity_name)] + mapper = mapper_registry[ClassKey(class_, entity_name)] except (KeyError, AttributeError): raise exceptions.InvalidRequestError("Class '%s' entity name '%s' has no mapper associated with it" % (class_.__name__, entity_name)) + return mapper.compile() + + \ No newline at end of file diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 0b609e3008..23cdc78f11 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -67,8 +67,10 @@ class DeferredColumnProperty(ColumnProperty): return mapper.object_mapper(instance).props[self.key].setup_loader(instance) def lazyload(): session = sessionlib.object_session(instance) - connection = session.connection(self.parent) + if session is None: + return None clause = sql.and_() + connection = session.connection(self.parent) try: pk = self.parent.pks_by_table[self.columns[0].table] except KeyError: @@ -139,9 +141,12 @@ class PropertyLoader(mapper.MapperProperty): else: self.backref = backref self.is_backref = is_backref - + private = property(lambda s:s.cascade.delete_orphan) - + + def attach(self, mapper): + mapper._add_compile_trigger(self.argument) + def cascade_iterator(self, type, object, recursive): if not type in self.cascade: return @@ -180,10 +185,10 @@ class PropertyLoader(mapper.MapperProperty): if isinstance(self.argument, type): self.mapper = mapper.class_mapper(self.argument) else: - self.mapper = self.argument + self.mapper = self.argument.compile() self.mapper = self.mapper.get_select_mapper() - + if self.association is not None: if isinstance(self.association, type): self.association = mapper.class_mapper(self.association) @@ -349,24 +354,25 @@ class LazyLoader(PropertyLoader): break else: allparams = False - if allparams: - # if we have a simple straight-primary key load, use mapper.get() - # to possibly save a DB round trip - if self.use_get: - ident = [] - for primary_key in self.mapper.pks_by_table[self.mapper.mapped_table]: - bind = self.lazyreverse[primary_key] - ident.append(params[bind.key]) - return self.mapper.using(session).get(ident) - elif self.order_by is not False: - order_by = self.order_by - elif self.secondary is not None and self.secondary.default_order_by() is not None: - order_by = self.secondary.default_order_by() - else: - order_by = False - result = self.mapper.using(session).select_whereclause(self.lazywhere, order_by=order_by, params=params) + if not allparams: + return None + + # if we have a simple straight-primary key load, use mapper.get() + # to possibly save a DB round trip + if self.use_get: + ident = [] + for primary_key in self.mapper.pks_by_table[self.mapper.mapped_table]: + bind = self.lazyreverse[primary_key] + ident.append(params[bind.key]) + return self.mapper.using(session).get(ident) + elif self.order_by is not False: + order_by = self.order_by + elif self.secondary is not None and self.secondary.default_order_by() is not None: + order_by = self.secondary.default_order_by() else: - result = [] + order_by = False + result = self.mapper.using(session).select_whereclause(self.lazywhere, order_by=order_by, params=params) + if self.uselist: return result else: @@ -428,7 +434,6 @@ class EagerLoader(LazyLoader): parent._has_eager = True self.eagertarget = self.target.alias() -# print "ALIAS", str(self.eagertarget.select()) #selectable.__class__.__name__ if self.secondary: self.eagersecondary = self.secondary.alias() self.aliasizer = Aliasizer(self.target, self.secondary, aliases={ @@ -451,11 +456,13 @@ class EagerLoader(LazyLoader): else: self.eager_order_by = None + def _create_eager_chain(self, recursion_stack=None): + try: + if self.__eager_chain_init == id(self): + return + except AttributeError: + pass - def _create_eager_chain(self, in_chain=False, recursion_stack=None): - if not in_chain and getattr(self, '_eager_chained', False): - return - if recursion_stack is None: recursion_stack = {} @@ -467,6 +474,7 @@ class EagerLoader(LazyLoader): for key, prop in self.mapper.props.iteritems(): if isinstance(prop, EagerLoader): eagerprops.append(prop) + if len(eagerprops): recursion_stack[self.localparent.mapped_table] = True self.mapper = self.mapper.copy() @@ -481,7 +489,7 @@ class EagerLoader(LazyLoader): # print "we are:", id(self), self.target.name, (self.secondary and self.secondary.name or "None"), self.parent.mapped_table.name # print "prop is",id(prop), prop.target.name, (prop.secondary and prop.secondary.name or "None"), prop.parent.mapped_table.name p.do_init_subclass(prop.key, prop.parent, recursion_stack) - p._create_eager_chain(in_chain=True, recursion_stack=recursion_stack) + p._create_eager_chain(recursion_stack=recursion_stack) p.eagerprimary = p.eagerprimary.copy_container() # aliasizer = Aliasizer(p.parent.mapped_table, aliases={p.parent.mapped_table:self.eagertarget}) p.eagerprimary.accept_visitor(self.aliasizer) @@ -490,8 +498,9 @@ class EagerLoader(LazyLoader): del recursion_stack[self.localparent.mapped_table] self._row_decorator = self._create_decorator_row() + self.__eager_chain_init = id(self) - self._eager_chained = True +# print "ROW DECORATOR", self._row_decorator def _aliasize_orderby(self, orderby, copy=True): if copy: @@ -508,9 +517,10 @@ class EagerLoader(LazyLoader): def setup(self, key, statement, eagertable=None, **options): """add a left outer join to the statement thats being constructed""" - # initialize the eager chains late in the game + # initialize the "eager" chain of EagerLoader objects + # this can't quite be done in the do_init_mapper() step self._create_eager_chain() - + if hasattr(statement, '_outerjoin'): towrap = statement._outerjoin else: @@ -600,6 +610,7 @@ class EagerLoader(LazyLoader): try: return self._row_decorator(row) except AttributeError: + # insure the "eager chain" step occurred self._create_eager_chain() return self._row_decorator(row) @@ -617,14 +628,14 @@ class GenericOption(mapper.MapperOption): oldprop = mapper.props[tokens[0]] newprop = oldprop.copy() newprop.argument = self.process_by_key(oldprop.mapper.copy(), tokens[1]) - mapper.set_property(tokens[0], newprop) + mapper._compile_property(tokens[0], newprop) else: self.create_prop(mapper, tokens[0]) return mapper def create_prop(self, mapper, key): kwargs = util.constructor_args(oldprop) - mapper.set_property(key, class_(**kwargs )) + mapper._compile_property(key, class_(**kwargs )) class BackRef(object): """stores the name of a backreference property as well as options to @@ -661,7 +672,7 @@ class BackRef(object): # the backref property is set on the primary mapper parent = prop.parent.primary_mapper() relation = cls(parent, prop.secondary, pj, sj, backref=prop.key, is_backref=True, **self.kwargs) - mapper.add_property(self.key, relation); + mapper._compile_property(self.key, relation); else: # else set one of us as the "backreference" if not mapper.props[self.key].is_backref: @@ -693,7 +704,7 @@ class EagerLazyOption(GenericOption): newprop = class_.__new__(class_) newprop.__dict__.update(oldprop.__dict__) newprop.do_init_subclass(key, mapper) - mapper.set_property(key, newprop) + mapper._compile_property(key, newprop) class DeferredOption(GenericOption): def __init__(self, key, defer=False, **kwargs): @@ -708,7 +719,7 @@ class DeferredOption(GenericOption): prop = DeferredColumnProperty(*oldprop.columns, **self.kwargs) else: prop = ColumnProperty(*oldprop.columns, **self.kwargs) - mapper.set_property(key, prop) + mapper._compile_property(key, prop) class Aliasizer(sql.ClauseVisitor): """converts a table instance within an expression to be an alias of that table.""" @@ -733,8 +744,6 @@ class Aliasizer(sql.ClauseVisitor): if isinstance(clist.clauses[i], schema.Column) and self.tables.has_key(clist.clauses[i].table): orig = clist.clauses[i] clist.clauses[i] = self.get_alias(clist.clauses[i].table).corresponding_column(clist.clauses[i]) - if clist.clauses[i] is None: - raise "cant get orig for " + str(orig) + " against table " + orig.table.name + " " + self.get_alias(orig.table).name def visit_binary(self, binary): if isinstance(binary.left, schema.Column) and self.tables.has_key(binary.left.table): binary.left = self.get_alias(binary.left.table).corresponding_column(binary.left) diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 3af55b7390..fbdb1bd17b 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -14,9 +14,9 @@ class Query(object): if isinstance(class_or_mapper, type): self.mapper = class_mapper(class_or_mapper, entity_name=entity_name) else: - self.mapper = class_or_mapper + self.mapper = class_or_mapper.compile() self.mapper = self.mapper.get_select_mapper() - + self.always_refresh = kwargs.pop('always_refresh', self.mapper.always_refresh) self.order_by = kwargs.pop('order_by', self.mapper.order_by) self.extension = kwargs.pop('extension', self.mapper.extension) @@ -219,7 +219,7 @@ class Query(object): return self.select_whereclause(whereclause=arg, **kwargs) def select_whereclause(self, whereclause=None, params=None, **kwargs): - statement = self._compile(whereclause, **kwargs) + statement = self.compile(whereclause, **kwargs) return self._select_statement(statement, params=params) def count(self, whereclause=None, params=None, **kwargs): @@ -281,7 +281,7 @@ class Query(object): if len(ident) > i + 1: i += 1 try: - statement = self._compile(self._get_clause) + statement = self.compile(self._get_clause) return self._select_statement(statement, params=params, populate_existing=reload)[0] except IndexError: return None @@ -301,7 +301,7 @@ class Query(object): and (kwargs.has_key('limit') or kwargs.has_key('offset') or kwargs.get('distinct', False)) ) - def _compile(self, whereclause = None, **kwargs): + def compile(self, whereclause = None, **kwargs): order_by = kwargs.pop('order_by', False) from_obj = kwargs.pop('from_obj', []) if order_by is False: diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index c3017e40c2..5f6d1796ca 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -41,6 +41,17 @@ def reversed(seq): raise StopIteration() return rev() +class ArgSingleton(type): + instances = {} + def __call__(self, *args): + hashkey = (self, args) + try: + return ArgSingleton.instances[hashkey] + except KeyError: + instance = type.__call__(self, *args) + ArgSingleton.instances[hashkey] = instance + return instance + class SimpleProperty(object): """a "default" property accessor.""" def __init__(self, key): @@ -92,34 +103,30 @@ class OrderedProperties(object): no append or extend.) """ def __init__(self): - self.__dict__['_list'] = [] + self.__dict__['_OrderedProperties__data'] = OrderedDict() def __len__(self): - return len(self._list) - def keys(self): - return list(self._list) - def get(self, key, default): - return getattr(self, key, default) - def has_key(self, key): - return hasattr(self, key) + return len(self.__data) def __iter__(self): - return iter([self[x] for x in self._list]) + return self.__data.itervalues() def __setitem__(self, key, object): - setattr(self, key, object) + self.__data[key] = object def __getitem__(self, key): - try: - return getattr(self, key) - except AttributeError: - raise KeyError(key) + return self.__data[key] def __delitem__(self, key): - delattr(self, key) - del self._list[self._list.index(key)] + del self.__data[key] def __setattr__(self, key, object): - if not hasattr(self, key): - self._list.append(key) - self.__dict__[key] = object + self.__data[key] = object + def __getattr__(self, key): + try: + return self.__data[key] + except KeyError: + raise AttributeError(key) + def keys(self): + return self.__data.keys() + def has_key(self, key): + return self.__data.has_key(key) def clear(self): - self.__dict__.clear() - self.__dict__['_list'] = [] + self.__data.clear() class OrderedDict(dict): """A Dictionary that keeps its own internal ordering""" diff --git a/test/orm/inheritance2.py b/test/orm/inheritance2.py index c6b9f01985..179beeae8e 100644 --- a/test/orm/inheritance2.py +++ b/test/orm/inheritance2.py @@ -177,7 +177,7 @@ class InheritTest(testbase.AssertMixin): master=relation(Assembly, lazy=False, uselist=False, foreignkey=specification_table.c.master_id, primaryjoin=specification_table.c.master_id==products_table.c.product_id, - backref=backref('specification', primaryjoin=specification_table.c.master_id==products_table.c.product_id), + backref=backref('specification', primaryjoin=specification_table.c.master_id==products_table.c.product_id, cascade="all, delete-orphan"), ), slave=relation(Product, lazy=False, uselist=False, foreignkey=specification_table.c.slave_id, @@ -193,24 +193,12 @@ class InheritTest(testbase.AssertMixin): properties=dict( name=documents_table.c.name, data=deferred(documents_table.c.data), - product=relation(Product, lazy=True, backref='documents'), + product=relation(Product, lazy=True, backref=backref('documents', cascade="all, delete-orphan")), ), ) raster_document_mapper = mapper(RasterDocument, inherits=document_mapper, polymorphic_identity='raster_document') - assembly_mapper.add_property('specification', - relation(SpecLine, lazy=True, - primaryjoin=specification_table.c.master_id==products_table.c.product_id, - backref='master', cascade='all, delete-orphan', - ) - ) - - product_mapper.add_property('documents', - relation(Document, lazy=True, - backref='product', cascade='all, delete-orphan'), - ) - session = create_session() a1 = Assembly(name='a1') @@ -245,17 +233,12 @@ class InheritTest(testbase.AssertMixin): properties=dict( name=documents_table.c.name, data=deferred(documents_table.c.data), - product=relation(Product, lazy=True, backref='documents'), + product=relation(Product, lazy=True, backref=backref('documents', cascade="all, delete-orphan")), ), ) raster_document_mapper = mapper(RasterDocument, inherits=document_mapper, polymorphic_identity='raster_document') - product_mapper.add_property('documents', - relation(Document, lazy=True, - backref='product', cascade='all, delete-orphan'), - ) - session = create_session(echo_uow=False) a1 = Assembly(name='a1') @@ -278,6 +261,67 @@ class InheritTest(testbase.AssertMixin): a1 = session.query(Product).get_by(name='a1') assert len(session.query(Document).select()) == 0 - + + def testfive(self): + """tests the late compilation of mappers""" + + specification_mapper = mapper(SpecLine, specification_table, + properties=dict( + master=relation(Assembly, lazy=False, uselist=False, + foreignkey=specification_table.c.master_id, + primaryjoin=specification_table.c.master_id==products_table.c.product_id, + backref=backref('specification', primaryjoin=specification_table.c.master_id==products_table.c.product_id), + ), + slave=relation(Product, lazy=False, uselist=False, + foreignkey=specification_table.c.slave_id, + primaryjoin=specification_table.c.slave_id==products_table.c.product_id, + ), + quantity=specification_table.c.quantity, + ) + ) + + detail_mapper = mapper(Detail, inherits=Product, + polymorphic_identity='detail') + + raster_document_mapper = mapper(RasterDocument, inherits=Document, + polymorphic_identity='raster_document') + + product_mapper = mapper(Product, products_table, + polymorphic_on=products_table.c.product_type, + polymorphic_identity='product', properties={ + 'documents' : relation(Document, lazy=True, + backref='product', cascade='all, delete-orphan'), + }) + + assembly_mapper = mapper(Assembly, inherits=Product, + polymorphic_identity='assembly') + + document_mapper = mapper(Document, documents_table, + polymorphic_on=documents_table.c.document_type, + polymorphic_identity='document', + properties=dict( + name=documents_table.c.name, + data=deferred(documents_table.c.data), + product=relation(Product, lazy=True, backref='documents'), + ), + ) + + session = create_session() + + a1 = Assembly(name='a1') + a1.specification.append(SpecLine(slave=Detail(name='d1'))) + a1.documents.append(Document('doc1')) + a1.documents.append(RasterDocument('doc2')) + session.save(a1) + orig = repr(a1) + session.flush() + session.clear() + + a1 = session.query(Product).get_by(name='a1') + new = repr(a1) + print orig + print new + assert orig == new == ' specification=[>] documents=[, ]' + if __name__ == "__main__": testbase.main() diff --git a/test/orm/manytomany.py b/test/orm/manytomany.py index 92b7efb260..e5055ef583 100644 --- a/test/orm/manytomany.py +++ b/test/orm/manytomany.py @@ -183,6 +183,11 @@ class M2MTest(testbase.AssertMixin): ) ) + Place.mapper.options() + print Place.mapper.props['inputs'] + print Transition.mapper.props['inputs'] + return + Place.eagermapper = Place.mapper.options( eagerload('inputs', selectalias='ip_alias'), eagerload('outputs', selectalias='op_alias') diff --git a/test/orm/mapper.py b/test/orm/mapper.py index d4225d412e..de38629ece 100644 --- a/test/orm/mapper.py +++ b/test/orm/mapper.py @@ -218,7 +218,7 @@ class MapperTest(MapperSuperTest): """tests the various attributes of the properties attached to classes""" m = mapper(User, users, properties = { 'addresses' : relation(mapper(Address, addresses)) - }) + }).compile() self.assert_(User.addresses.property is m.props['addresses']) def testload(self): @@ -295,7 +295,7 @@ class MapperTest(MapperSuperTest): try: m = mapper(User, users, properties = { 'user_name' : relation(mapper(Address, addresses)), - }) + }).compile() self.assert_(False, "should have raised ArgumentError") except exceptions.ArgumentError, e: self.assert_(True) @@ -330,7 +330,7 @@ class MapperTest(MapperSuperTest): sess = create_session() usermapper = mapper(User, users, properties = dict( addresses = relation(mapper(Address, addresses), lazy = False) - )) + )).compile() # first test straight eager load, 1 statement def go(): @@ -357,6 +357,20 @@ class MapperTest(MapperSuperTest): self.assert_result(l, User, *user_address_result) self.assert_sql_count(db, go, 3) + def testlatecompile(self): + """tests mappers compiling late in the game""" + + mapper(User, users, properties = {'orders': relation(Order)}) + mapper(Item, orderitems, properties={'keywords':relation(Keyword, secondary=itemkeywords)}) + mapper(Keyword, keywords) + mapper(Order, orders, properties={'items':relation(Item)}) + + sess = create_session() + u = sess.query(User).select() + def go(): + print u[0].orders[1].items[0].keywords[1] + self.assert_sql_count(db, go, 3) + def testdeepoptions(self): mapper(User, users, properties = { @@ -441,7 +455,7 @@ class DeferredTest(MapperSuperTest): o = Order() self.assert_(o.description is None) - + q = create_session().query(m) def go(): l = q.select() @@ -777,7 +791,7 @@ class EagerTest(MapperSuperTest): def testbackwardsonetoone(self): m = mapper(Address, addresses, properties = dict( user = relation(mapper(User, users), lazy = False) - )) + )).compile() self.echo(repr(m.props['user'].uselist)) q = create_session().query(m) l = q.select(addresses.c.address_id == 1) @@ -807,7 +821,7 @@ class EagerTest(MapperSuperTest): m = mapper(User, users, properties = dict( addresses = relation(mapper(Address, addresses), lazy = False) )) - s = m.compile(and_(addresses.c.email_address == bindparam('emailad'), addresses.c.user_id==users.c.user_id)) + s = session.query(m).compile(and_(addresses.c.email_address == bindparam('emailad'), addresses.c.user_id==users.c.user_id)) c = s.compile() self.echo("\n" + str(c) + repr(c.get_params())) diff --git a/test/orm/objectstore.py b/test/orm/objectstore.py index 5bd855fa78..83d279a284 100644 --- a/test/orm/objectstore.py +++ b/test/orm/objectstore.py @@ -86,7 +86,34 @@ class HistoryTest(SessionTest): u = s.query(m).select()[0] print u.addresses[0].user - +class CustomAttrTest(SessionTest): + def setUpAll(self): + SessionTest.setUpAll(self) + global sometable, metadata, someothertable + metadata = BoundMetaData(testbase.db) + sometable = Table('sometable', metadata, + Column('col1',Integer, primary_key=True)) + someothertable = Table('someothertable', metadata, + Column('col1', Integer, primary_key=True), + Column('scol1', Integer, ForeignKey(sometable.c.col1)), + Column('data', String(20)) + ) + def testbasic(self): + class MyList(list): + pass + class Foo(object): + bars = MyList + class Bar(object): + pass + mapper(Foo, sometable, properties={ + 'bars':relation(Bar) + }) + mapper(Bar, someothertable) + f = Foo() + assert isinstance(f.bars.data, MyList) + def tearDownAll(self): + SessionTest.tearDownAll(self) + class VersioningTest(SessionTest): def setUpAll(self): SessionTest.setUpAll(self)