# -*- coding: utf-8 -*-
'''
Model field, related manager and South integration for many-to-many
relations that remember the order of the related objects.

The order is stored in an extra integer column on an auto-created
intermediate ("through") model.
'''
from operator import attrgetter
import sys
from django.db import connections
from django.db import router
from django.db.models import signals
from django.db.models.fields.related import add_lazy_relation, create_many_related_manager
from django.db.models.fields.related import ManyToManyField, ReverseManyRelatedObjectsDescriptor
from django.db.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT
from django.conf import settings
from django.utils.functional import curry
from sortedm2m.forms import SortedMultipleChoiceField


# ``basestring`` only exists on Python 2; on Python 3 every text string is
# a ``str``.
if sys.version_info[0] < 3:
    string_types = basestring
else:
    string_types = str


# Default name of the ordering column on the intermediate model.
SORT_VALUE_FIELD_NAME = 'sort_value'


def create_sorted_many_to_many_intermediate_model(field, klass):
    '''
    Build the intermediate model class for a sorted many-to-many ``field``.

    The generated model has a foreign key to ``klass``, a foreign key to
    the relation's target and an integer column named
    ``field.sort_value_field_name`` that the model is ordered by.

    :param field: the field instance declared on ``klass``; ``field.rel.to``
        may be a model class, a model name string or the recursive
        relationship constant.
    :param klass: the model class owning ``field``.
    :returns: the newly created model class.
    '''
    from django.db import models

    managed = True
    if isinstance(field.rel.to, string_types) and field.rel.to != RECURSIVE_RELATIONSHIP_CONSTANT:
        # Target given by name: it may not be loaded yet, so the ``managed``
        # flag is filled in lazily once the relation is resolved.
        to_model = field.rel.to
        to = to_model.split('.')[-1]

        def set_managed(field, model, cls):
            field.rel.through._meta.managed = model._meta.managed or cls._meta.managed

        add_lazy_relation(klass, field, to_model, set_managed)
    elif isinstance(field.rel.to, string_types):
        # Recursive relation: the target is the owning model itself.
        to = klass._meta.object_name
        to_model = klass
        managed = klass._meta.managed
    else:
        to = field.rel.to._meta.object_name
        to_model = field.rel.to
        managed = klass._meta.managed or to_model._meta.managed

    name = '%s_%s' % (klass._meta.object_name, field.name)
    if field.rel.to == RECURSIVE_RELATIONSHIP_CONSTANT or to == klass._meta.object_name:
        # Both foreign keys point at the same model, so the column names
        # need distinguishing prefixes.
        from_ = 'from_%s' % to.lower()
        to = 'to_%s' % to.lower()
    else:
        from_ = klass._meta.object_name.lower()
        to = to.lower()

    meta = type(str('Meta'), (object,), {
        'db_table': field._get_m2m_db_table(klass._meta),
        'managed': managed,
        'auto_created': klass,
        'app_label': klass._meta.app_label,
        'unique_together': (from_, to),
        'ordering': (field.sort_value_field_name,),
        'verbose_name': '%(from)s-%(to)s relationship' % {'from': from_, 'to': to},
        'verbose_name_plural': '%(from)s-%(to)s relationships' % {'from': from_, 'to': to},
    })

    # Construct and return the new class.
    def default_sort_value(name):
        # The default sort value is the current number of rows of the
        # intermediate model, looked up by name at call time.
        model = models.get_model(klass._meta.app_label, name)
        return model._default_manager.count()

    default_sort_value = curry(default_sort_value, name)

    return type(str(name), (models.Model,), {
        'Meta': meta,
        '__module__': klass.__module__,
        from_: models.ForeignKey(klass, related_name='%s+' % name),
        to: models.ForeignKey(to_model, related_name='%s+' % name),
        field.sort_value_field_name: models.IntegerField(default=default_sort_value),
        '_sort_field_name': field.sort_value_field_name,
        '_from_field_name': from_,
        '_to_field_name': to,
    })


def create_sorted_many_related_manager(superclass, rel):
    '''
    Create a related-manager class that returns and stores related objects
    in the order given by the intermediate model's sort column.

    :param superclass: manager class handed on to
        ``create_many_related_manager``.
    :param rel: the relation object; ``rel.through`` must expose
        ``_sort_field_name`` as set up by
        ``create_sorted_many_to_many_intermediate_model``.
    :returns: the ``SortedRelatedManager`` class.
    '''
    RelatedManager = create_many_related_manager(superclass, rel)

    class SortedRelatedManager(RelatedManager):
        def get_query_set(self):
            '''
            Return the related objects ordered by the sort column, or the
            prefetched result when one is cached on the instance.
            '''
            # We use ``extra`` method here because we have no other access to
            # the extra sorting field of the intermediary model. The fields
            # are hidden for joins because we set ``auto_created`` on the
            # intermediary's meta options.
            try:
                return self.instance._prefetched_objects_cache[self.prefetch_cache_name]
            except (AttributeError, KeyError):
                queryset = super(SortedRelatedManager, self).get_query_set()
                return queryset.extra(order_by=['%s.%s' % (
                    rel.through._meta.db_table,
                    rel.through._sort_field_name,
                )])

        if not hasattr(RelatedManager, '_get_fk_val'):
            # Django<1.5 managers only provide ``_pk_val``; expose it under
            # the name the rest of this class uses.
            @property
            def _fk_val(self):
                return self._pk_val

        def get_prefetch_query_set(self, instances):
            '''
            Return the 5-tuple describing how to prefetch the relation for
            ``instances``: the ordered queryset, two attribute getters, a
            ``False`` flag and the prefetch cache name.
            '''
            # mostly a copy of get_prefetch_query_set from ManyRelatedManager
            # but with addition of proper ordering
            db = self._db or router.db_for_read(instances[0].__class__, instance=instances[0])
            query = {'%s__pk__in' % self.query_field_name:
                     set(obj._get_pk_val() for obj in instances)}
            qs = super(RelatedManager, self).get_query_set().using(db)._next_is_sticky().filter(**query)

            # M2M: need to annotate the query in order to get the primary model
            # that the secondary model was actually related to. We know that
            # there will already be a join on the join table, so we can just add
            # the select.

            # For non-autocreated 'through' models, can't assume we are
            # dealing with PK values.
            fk = self.through._meta.get_field(self.source_field_name)
            source_col = fk.column
            join_table = self.through._meta.db_table
            connection = connections[db]
            qn = connection.ops.quote_name
            qs = qs.extra(
                select={'_prefetch_related_val':
                        '%s.%s' % (qn(join_table), qn(source_col))},
                order_by=['%s.%s' % (
                    rel.through._meta.db_table,
                    rel.through._sort_field_name,
                )])
            select_attname = fk.rel.get_related_field().get_attname()
            return (qs,
                    attrgetter('_prefetch_related_val'),
                    attrgetter(select_attname),
                    False,
                    self.prefetch_cache_name)

        def _add_items(self, source_field_name, target_field_name, *objs):
            '''
            Create intermediate rows linking ``self.instance`` to ``objs``,
            preserving the order in which ``objs`` are given.

            Objects that are already related, and repeated objects within
            ``objs``, are skipped so each relation row is created once.

            :param source_field_name: the PK fieldname in join_table for the
                source object.
            :param target_field_name: the PK fieldname in join_table for the
                target object.
            :param objs: objects to add. Either object instances, or primary
                keys of object instances.
            :raises ValueError: if the router disallows the relation or the
                related value of an instance is ``None``.
            :raises TypeError: if a model instance of the wrong class is
                passed.
            '''
            from django.db.models import Model

            # If there aren't any objects, there is nothing to do.
            if not objs:
                return

            new_ids = []
            for obj in objs:
                if isinstance(obj, self.model):
                    if not router.allow_relation(obj, self.instance):
                        raise ValueError(
                            'Cannot add "%r": instance is on database "%s", value is on database "%s"' %
                            (obj, self.instance._state.db, obj._state.db))
                    if hasattr(self, '_get_fk_val'):  # Django>=1.5
                        fk_val = self._get_fk_val(obj, target_field_name)
                        if fk_val is None:
                            raise ValueError(
                                'Cannot add "%r": the value for field "%s" is None' %
                                (obj, target_field_name))
                        new_ids.append(fk_val)
                    else:  # Django<1.5
                        new_ids.append(obj.pk)
                elif isinstance(obj, Model):
                    raise TypeError(
                        "'%s' instance expected, got %r" %
                        (self.model._meta.object_name, obj))
                else:
                    new_ids.append(obj)

            db = router.db_for_write(self.through, instance=self.instance)
            vals = self.through._default_manager.using(db).values_list(target_field_name, flat=True)
            vals = vals.filter(**{
                source_field_name: self._fk_val,
                '%s__in' % target_field_name: new_ids,
            })

            # Keep the first occurrence of every id that is not related yet.
            # Seeding ``seen`` with the existing ids drops *every* occurrence
            # of an already-related id, so passing such an object twice can
            # no longer produce a duplicate row.
            seen = set(vals)
            unique_new_ids = []
            for pk in new_ids:
                if pk not in seen:
                    seen.add(pk)
                    unique_new_ids.append(pk)
            new_ids = unique_new_ids
            new_ids_set = set(new_ids)

            # Don't send the signal when we are inserting the
            # duplicate data row for symmetrical reverse entries.
            send_signals = self.reverse or source_field_name == self.source_field_name

            if send_signals:
                signals.m2m_changed.send(
                    sender=rel.through, action='pre_add',
                    instance=self.instance, reverse=self.reverse,
                    model=self.model, pk_set=new_ids_set, using=db)

            # Add the ones that aren't there already
            sort_field_name = self.through._sort_field_name
            sort_field = self.through._meta.get_field_by_name(sort_field_name)[0]
            for obj_id in new_ids:
                # ``get_default()`` is evaluated per row on purpose: the
                # default counts the existing rows, so it changes after
                # every insert.
                self.through._default_manager.using(db).create(**{
                    '%s_id' % source_field_name: self._fk_val,  # Django 1.5 compatibility
                    '%s_id' % target_field_name: obj_id,
                    sort_field_name: sort_field.get_default(),
                })

            if send_signals:
                signals.m2m_changed.send(
                    sender=rel.through, action='post_add',
                    instance=self.instance, reverse=self.reverse,
                    model=self.model, pk_set=new_ids_set, using=db)

    return SortedRelatedManager


class ReverseSortedManyRelatedObjectsDescriptor(ReverseManyRelatedObjectsDescriptor):
    '''
    Descriptor for the forward side of the relation that hands out the
    sorted related manager instead of the stock one.
    '''

    @property
    def related_manager_cls(self):
        return create_sorted_many_related_manager(
            self.field.rel.to._default_manager.__class__,
            self.field.rel
        )


class SortedManyToManyField(ManyToManyField):
    '''
    Providing a many to many relation that remembers the order of related
    objects.

    Accept a boolean ``sorted`` attribute which specifies if relation is
    ordered or not. Default is set to ``True``. If ``sorted`` is set to
    ``False`` the field will behave exactly like django's ``ManyToManyField``.

    The optional ``sort_value_field_name`` keyword argument sets the name of
    the ordering column on the intermediate model (default:
    ``SORT_VALUE_FIELD_NAME``).
    '''

    def __init__(self, to, sorted=True, **kwargs):
        self.sorted = sorted
        self.sort_value_field_name = kwargs.pop(
            'sort_value_field_name',
            SORT_VALUE_FIELD_NAME)
        super(SortedManyToManyField, self).__init__(to, **kwargs)
        if self.sorted:
            # Reset to exactly what the caller passed (``None`` if nothing).
            # NOTE(review): presumably this discards text the parent
            # ``__init__`` adds to ``help_text`` -- confirm.
            self.help_text = kwargs.get('help_text', None)

    def contribute_to_class(self, cls, name):
        '''
        Attach the field to ``cls``; for sorted fields also create the
        sorted intermediate model (when needed) and install the sorted
        descriptor.
        '''
        if not self.sorted:
            return super(SortedManyToManyField, self).contribute_to_class(cls, name)

        # To support multiple relations to self, it's useful to have a non-None
        # related name on symmetrical relations for internal reasons. The
        # concept doesn't make a lot of sense externally ("you want me to
        # specify *what* on my non-reversible relation?!"), so we set it up
        # automatically. The funky name reduces the chance of an accidental
        # clash.
        if self.rel.symmetrical and (self.rel.to == "self" or self.rel.to == cls._meta.object_name):
            self.rel.related_name = "%s_rel_+" % name

        # Deliberately skips ``ManyToManyField.contribute_to_class``: the
        # steps below replace it with the sorted variants.
        super(ManyToManyField, self).contribute_to_class(cls, name)

        # The intermediate m2m model is not auto created if:
        #  1) There is a manually specified intermediate, or
        #  2) The class owning the m2m field is abstract.
        if not self.rel.through and not cls._meta.abstract:
            self.rel.through = create_sorted_many_to_many_intermediate_model(self, cls)

        # Add the descriptor for the m2m relation
        setattr(cls, self.name, ReverseSortedManyRelatedObjectsDescriptor(self))

        # Set up the accessor for the m2m table name for the relation
        self.m2m_db_table = curry(self._get_m2m_db_table, cls._meta)

        # Populate some necessary rel arguments so that cross-app relations
        # work correctly.
        if isinstance(self.rel.through, string_types):
            def resolve_through_model(field, model, cls):
                field.rel.through = model

            add_lazy_relation(cls, self, self.rel.through, resolve_through_model)

        if hasattr(cls._meta, 'duplicate_targets'):  # Django<1.5
            if isinstance(self.rel.to, string_types):
                target = self.rel.to
            else:
                target = self.rel.to._meta.db_table
            cls._meta.duplicate_targets[self.column] = (target, "m2m")

    def formfield(self, **kwargs):
        '''
        Return the form field for this model field; sorted fields default to
        ``SortedMultipleChoiceField`` unless ``form_class`` is given.
        '''
        defaults = {}
        if self.sorted:
            defaults['form_class'] = SortedMultipleChoiceField
        defaults.update(kwargs)
        return super(SortedManyToManyField, self).formfield(**defaults)


# Add introspection rules for South database migrations
# See http://south.aeracode.org/docs/customfields.html
try:
    import south
except ImportError:
    south = None

if south is not None and 'south' in settings.INSTALLED_APPS:
    from south.modelsinspector import add_introspection_rules
    add_introspection_rules(
        [(
            (SortedManyToManyField,),
            [],
            {"sorted": ["sorted", {"default": True}]},
        )],
        [r'^sortedm2m\.fields\.SortedManyToManyField']
    )

    # Monkeypatch South M2M actions to create the sorted through model.
    # FIXME: This doesn't detect if you changed the sorted argument to the field.
    import south.creator.actions
    from south.creator.freezer import model_key

    class AddM2M(south.creator.actions.AddM2M):
        '''
        South action that emits ``create_table`` code including the sort
        column when the field is a sorted ``SortedManyToManyField``.
        '''

        # NOTE(review): the whitespace inside this template is emitted
        # verbatim into generated migrations; the 8-space indent is assumed
        # to match the body indent of South's migration methods -- confirm.
        SORTED_FORWARDS_TEMPLATE = '''
        # Adding SortedM2M table for field %(field_name)s on '%(model_name)s'
        db.create_table(%(table_name)r, (
            ('id', models.AutoField(verbose_name='ID', primary_key=True, auto_created=True)),
            (%(left_field)r, models.ForeignKey(orm[%(left_model_key)r], null=False)),
            (%(right_field)r, models.ForeignKey(orm[%(right_model_key)r], null=False)),
            (%(sort_field)r, models.IntegerField())
        ))
        db.create_unique(%(table_name)r, [%(left_column)r, %(right_column)r])'''

        def console_line(self):
            if isinstance(self.field, SortedManyToManyField) and self.field.sorted:
                return " + Added SortedM2M table for %s on %s.%s" % (
                    self.field.name,
                    self.model._meta.app_label,
                    self.model._meta.object_name,
                )
            else:
                return super(AddM2M, self).console_line()

        def forwards_code(self):
            if isinstance(self.field, SortedManyToManyField) and self.field.sorted:
                return self.SORTED_FORWARDS_TEMPLATE % {
                    "model_name": self.model._meta.object_name,
                    "field_name": self.field.name,
                    "table_name": self.field.m2m_db_table(),
                    "left_field": self.field.m2m_column_name()[:-3],  # Remove the _id part
                    "left_column": self.field.m2m_column_name(),
                    "left_model_key": model_key(self.model),
                    "right_field": self.field.m2m_reverse_name()[:-3],  # Remove the _id part
                    "right_column": self.field.m2m_reverse_name(),
                    "right_model_key": model_key(self.field.rel.to),
                    "sort_field": self.field.sort_value_field_name,
                }
            else:
                return super(AddM2M, self).forwards_code()

    class DeleteM2M(AddM2M):
        '''
        Inverse of ``AddM2M``: swaps the forwards and backwards code.
        '''

        def console_line(self):
            return " - Deleted M2M table for %s on %s.%s" % (
                self.field.name,
                self.model._meta.app_label,
                self.model._meta.object_name,
            )

        def forwards_code(self):
            return AddM2M.backwards_code(self)

        def backwards_code(self):
            return AddM2M.forwards_code(self)

    south.creator.actions.AddM2M = AddM2M
    south.creator.actions.DeleteM2M = DeleteM2M