Source code for common.models.trackedmodel

from __future__ import annotations

from typing import Any
from typing import Dict
from typing import Iterable
from typing import Optional
from typing import Sequence
from typing import Set
from typing import TypeVar

from django.db import models
from django.db.models import F
from django.db.models import Field
from django.db.models import Q
from django.db.models.expressions import Expression
from django.db.models.options import Options
from django.db.models.query import QuerySet
from django.db.transaction import atomic
from django.urls import NoReverseMatch
from django.urls import reverse
from polymorphic.models import PolymorphicModel

from common import validators
from common.exceptions import IllegalSaveError
from common.fields import NumericSID
from common.fields import SignedIntSID
from common.models import TimestampedMixin
from common.models.managers import CurrentTrackedModelManager
from common.models.managers import TrackedModelManager
from common.models.mixins import TimestampedMixin
from common.models.tracked_qs import TrackedModelQuerySet
from common.models.tracked_utils import get_deferred_set_fields
from common.models.tracked_utils import get_models_linked_to
from common.models.tracked_utils import get_relations
from common.models.tracked_utils import get_subrecord_relations
from common.util import classproperty
from common.util import get_accessor
from common.util import get_field_tuple
from common.validators import UpdateType
from workbaskets.validators import WorkflowStatus


class VersionGroup(TimestampedMixin):
    """
    Collects every version of one logical TrackedModel under a single row.

    ``current_version`` (nullable) references the version currently treated
    as live; the reverse ``versions`` accessor lists all member versions.
    """

    current_version = models.OneToOneField(
        "common.TrackedModel",
        null=True,
        on_delete=models.SET_NULL,
        related_query_name="is_current",
    )

    # Annotation only (no value is assigned here): the reverse accessor comes
    # from TrackedModel.version_group, which declares related_name="versions".
    versions: QuerySet[TrackedModel]


# Type variable bound to TrackedModel; methods annotated ``self: Cls -> Cls``
# (new_version, copy, version_at, current_version) return the caller's subclass.
Cls = TypeVar("Cls", bound="TrackedModel")
class TrackedModel(PolymorphicModel, TimestampedMixin):
    """
    Polymorphic base for models whose history is stored as a chain of versions.

    Each row is one version of a logical object. It records the
    :attr:`transaction` it was authored in, the kind of change it represents
    (:attr:`update_type`) and the :attr:`version_group` that links it to the
    other versions of the same object. Rows in an approved workbasket cannot
    be re-saved (see :meth:`save`); changes are made by authoring new versions
    with :meth:`new_version`.
    """

    # The transaction this version was authored in. PROTECT stops a
    # Transaction from being deleted while versions still reference it.
    transaction = models.ForeignKey(
        "common.Transaction",
        on_delete=models.PROTECT,
        related_name="tracked_models",
        editable=False,
    )
    update_type: validators.UpdateType = models.PositiveSmallIntegerField(
        choices=validators.UpdateType.choices,
        db_index=True,
    )
    """
    The change that was made to the model when this version of the model was
    authored.

    The first version should always have :data:`~validators.UpdateType.CREATE`,
    subsequent versions will have :data:`~validators.UpdateType.UPDATE` and the
    final version will have :data:`~validators.UpdateType.DELETE`. Deleted
    models that reappear for the same :attr:`identifying_fields` will have a new
    :attr:`version_group` created.
    """
    # Reverse accessor on VersionGroup is ``versions``.
    version_group: VersionGroup = models.ForeignKey(
        VersionGroup,
        on_delete=models.PROTECT,
        related_name="versions",
    )
    """
    Each version group contains all of the versions of the same logical model.

    When a new version of a model is authored (e.g. to
    :data:`~validators.UpdateType.DELETE` it) a new model row is created and
    added to the same version group as the existing model being changed.

    Models are identified logically by their :attr:`identifying_fields`, so
    within one version group all of the models should have the same values for
    these fields.
    """
    # Unfiltered manager over every version row.
    objects: TrackedModelQuerySet = TrackedModelManager.from_queryset(
        TrackedModelQuerySet,
    )()
    current_objects: TrackedModelQuerySet = CurrentTrackedModelManager.from_queryset(
        TrackedModelQuerySet,
    )()
    """The `current_objects` model manager provides a default queryset that, by
    default, filters to the 'current' transaction."""
    # Business rules attached to this model type; presumably consumed by the
    # business-rule checking machinery — TODO confirm.
    business_rules: Iterable = ()
    # Presumably rules on other models that must be re-checked when this model
    # changes — NOTE(review): confirm against the rule checker.
    indirect_business_rules: Iterable = ()
    record_code: int
    """
    The type id of this model's type family in the TARIC specification.

    This number groups together a number of different models into 'records'.
    Where two models share a record code, they are conceptually expressing
    different properties of the same logical model.

    In theory each :class:`~common.transactions.Transaction` should only contain
    models with a single :attr:`record_code` (but differing
    :attr:`subrecord_code`.)
    """
    subrecord_code: int
    """
    The type id of this model in the TARIC specification. The
    :attr:`subrecord_code` when combined with the :attr:`record_code` uniquely
    identifies the type within the specification.

    The subrecord code gives the intended order for models in a transaction,
    with comparatively smaller subrecord codes needing to come before larger
    ones.
    """
    identifying_fields: Sequence[str] = ("pk",)
    """
    The fields which together form a composite unique key for each model.

    The system ID (or SID) field, 'sid' is normally the unique identifier of a
    TARIC model, but in places where this does not exist models can declare
    their own. (Note that because multiple versions of each model will exist
    this does not actually equate to a ``UNIQUE`` constraint in the database.)

    TrackedModel itself defaults to ("pk",) as it does not have an SID.
    """
    # Appended verbatim to the URL produced by get_url().
    url_suffix = ""
    """
    This is to add a link within a page for get_url() e.g. for linking to a
    Measure's conditions tab. If url_suffix is set to '#conditions' the output
    detail url will be /measures/12345678/#conditions
    """
[docs] def new_version( self: Cls, workbasket, transaction=None, update_type: UpdateType = UpdateType.UPDATE, **overrides, ) -> Cls: """ Create and return a new version of the object. Callers can override existing data by passing in keyword args. The new version is added to a transaction which is created and added to the passed in workbasket (or may be supplied as a keyword arg). `update_type` must be UPDATE or DELETE, with UPDATE as the default. """ if update_type not in ( validators.UpdateType.UPDATE, validators.UpdateType.DELETE, ): raise ValueError("update_type must be UPDATE or DELETE") cls = self.__class__ new_object_kwargs = { field.name: getattr(self, field.name) for field in self._meta.fields if field is self._meta.get_field("version_group") or field.name not in self.system_set_field_names } new_object_overrides = { name: value for name, value in overrides.items() if name not in [f.name for f in get_deferred_set_fields(self)] } new_object_kwargs["update_type"] = update_type new_object_kwargs.update(new_object_overrides) if transaction is None: transaction = workbasket.new_transaction() new_object_kwargs["transaction"] = transaction new_object = cls(**new_object_kwargs) new_object.save() deferred_kwargs = { field.name: field.value_from_object(self) for field in get_deferred_set_fields(self) } deferred_overrides = { name: value for name, value in overrides.items() if name in [f.name for f in get_deferred_set_fields(self)] } deferred_kwargs.update(deferred_overrides) for field in deferred_kwargs: getattr(new_object, field).set(deferred_kwargs[field]) return new_object
[docs] def get_versions(self): """Find all versions of this model.""" if hasattr(self, "version_group"): query = Q(version_group_id=self.version_group_id) else: query = Q(**self.get_identifying_fields()) return type(self).objects.filter(query)
def _get_version_group(self) -> VersionGroup: if self.update_type == validators.UpdateType.CREATE: return VersionGroup.objects.create() latest_version = self.get_versions().latest_approved().last() if not latest_version: # An object may be created and deleted/updated in the same workbasket. # If the workbasket status is not WorkflowStatus.PUBLISHED, # then latest_approved() in the above line of code will return None. # Trying to get the version group off that will throw an exception. # The extra bit of logic below deals with such cases # It will attempt to find the corresponding CREATE record # in the current workbasket and return that as the latest_version. try: latest_version = [ record for transaction in self.transaction.workbasket.transactions.all() for record in transaction.tracked_models.all() if type(record) == type(self) if record.update_type == UpdateType.CREATE if record.get_identifying_fields() == self.get_identifying_fields() ][0] except IndexError: return return latest_version.version_group def _can_write(self): return not ( self.pk and self.transaction.workbasket.status in WorkflowStatus.approved_statuses() )
[docs] def get_identifying_fields( self, identifying_fields: Optional[Iterable[str]] = None, ) -> Dict[str, Any]: """ Get a name/value mapping of the fields that identify this model. :param identifying_fields Optional[Iterable[str]]: Optionally override the fields to retrieve :rtype dict[str, Any]: A dict of field names to values """ identifying_fields = identifying_fields or self.identifying_fields fields = {} for field in identifying_fields: _, fields[field] = get_field_tuple(self, field) return fields
[docs] def identifying_fields_to_string( self, identifying_fields: Optional[Iterable[str]] = None, ) -> str: """ Constructs a comma separated string of the identifying fields of the model with field name and value pairs delimited by "=", eg: "field1=1, field2=2". :param identifying_fields: Optionally override the fields to use in the string :rtype str: The constructed string """ field_list = [ f"{field}={str(value)}" for field, value in self.get_identifying_fields(identifying_fields).items() ] return ", ".join(field_list)
@property def structure_code(self): """ A string used to describe the model instance. Used as the displayed value in an AutocompleteWidget dropdown, and in the "Your tariff changes" list. """ return str(self) @property def structure_description(self) -> Optional[str]: """ The current description of the model, if it has related description models or a description field. :rtype Optional[str]: The current description """ description = None if hasattr(self, "descriptions"): description = self.get_descriptions().last() if description: # Get the actual description, not just the object description = description.description if hasattr(self, "description"): description = self.description return description or None @property def record_identifier(self) -> str: """Returns the record identifier as defined in TARIC3 records specification.""" return f"{self.record_code}{self.subrecord_code}" @property def update_type_str(self) -> str: return dict(UpdateType.choices)[self.update_type] @property def current_version(self: Cls) -> Cls: """The current version of this model.""" current_version = self.version_group.current_version if current_version is None: raise self.__class__.DoesNotExist("Object has no current version") return current_version
[docs] def version_at(self: Cls, transaction) -> Cls: """ The latest version of this model that was approved as of the given transaction. :param transaction Transaction: Limit versions to this transaction :rtype TrackedModel: """ return self.get_versions().approved_up_to_transaction(transaction).get()
@classproperty def copyable_fields(cls): """ Return the set of fields that can have their values copied from one model to another. This is anything that is: - a native value - a foreign key to some other model """ return { field for field in cls._meta.get_fields() if not any((field.many_to_many, field.one_to_many)) and field.name not in cls.system_set_field_names } _meta: Options @classproperty def auto_value_fields(cls) -> Set[Field]: """Returns the set of fields on this model that should have their value set automatically on save, excluding any primary keys.""" return { f for f in cls._meta.get_fields() if isinstance(f, (SignedIntSID, NumericSID)) } # Fields that we don't want to copy from one object to a new one, either # because they will be set by the system automatically or because we will # always want to override them. system_set_field_names = { "is_current", "version_group", "polymorphic_ctype", "id", "update_type", "trackedmodel_ptr", "transaction", "created_at", "updated_at", }
[docs] def copy( self: Cls, transaction, **overrides: Any, ) -> Cls: """ Create a copy of the model as a new logical domain object – i.e. with a new version group, new SID (if present) and update type of CREATE. Any dependent models that are TARIC subrecords of this model will be copied as well. Any many-to-many relationships will also be duplicated if they do not have an explicit through model. Any overrides passed in as keyword arguments will be applied to the new model. If the model uses SIDs, they will be automatically set to the next highest available SID. Models with other identifying fields should have thier new IDs passed in through overrides. """ # Remove any fields from the basic data that are overriden, because # otherwise when we convert foreign keys to IDs (below) Django will # ignore the object from the overrides and just take the ID from the # basic data. basic_fields = self.copyable_fields subrecord_fields = {} for field_name in overrides: field = None # Check for fields on related model if not "__" in field_name: field = self._meta.get_field(field_name) # Check for non-basic fields e.g. related models if field and field in basic_fields: basic_fields.remove(field) # Add non-basic fields from overrides to subrecord_fields dict else: subrecord_fields.update({field_name: overrides[field_name]}) # Remove related models and related model fields from overrides before generating object_data below overrides = { k: v for (k, v) in overrides.items() if k not in subrecord_fields.keys() } # Remove any SIDs from the copied data. This allows them to either # automatically pick the next highest value or to be passed in. for field in self.auto_value_fields: basic_fields.remove(field) # Build the dictionary of data that the new model will have. Convert any # foreign keys into ids because ``value_from_object`` returns PKs. 
model_data = { f.name + ( "_id" if any((f.many_to_one, f.one_to_one)) else "" ): f.value_from_object(self) for f in basic_fields } new_object_data = { **model_data, "transaction": transaction, "update_type": validators.UpdateType.CREATE, **overrides, } new_object = type(self).objects.create(**new_object_data) # Now copy any many-to-many fields with an auto-created through model. # These must be handled after creation of the new model. We only need to # do this for auto-created models because others will be handled below. for field in get_deferred_set_fields(self): getattr(new_object, field.name).set(field.value_from_object(self)) # Now go and create copies of all of the models that reference this one # with a foreign key that are part of the same record family. Find all # of the related models and then recursively call copy on them, but with # the new model substituted in place of this one. It's done this way to # give these related models a chance to increment SIDs, etc. for field in get_subrecord_relations(self.__class__): ignore = False # Check if user passed related model into overrides argument if field.name in subrecord_fields.keys(): # If user passed a new unsaved model, set the remote field value equal to new_object for each model passed # e.g. 
if a Measure is copied and a MeasureCondition is passed, update `dependent_measure` field to `new_object` if subrecord_fields[field.name]: for subrecord in subrecord_fields[field.name]: remote_field = [ f for f in self._meta.get_fields() if f.name == field.name ][0].remote_field.name if not subrecord.pk: setattr(subrecord, remote_field, new_object) subrecord.save() else: # If user passed a saved object, create a copy of that object with remote_field pointing at the new copied object # set ignore to True, so that duplicate copies are not made below subrecord.copy(transaction, **{remote_field: new_object}) ignore = True # Else, if an empty or None value is passed, set ignore to True, so that related models are not copied # e.g. if an existing Measure with two conditions is copied with conditions=[], the copy will have no conditions else: ignore = True queryset = getattr(self, field.get_accessor_name()) reverse_field_name = field.field.name kwargs = {reverse_field_name: new_object} nested_fields = { k.split("__", 1)[1]: v for (k, v) in subrecord_fields.items() if field.name in k and field.name != k } kwargs.update(nested_fields) if not ignore: try: for model in queryset.approved_up_to_transaction(transaction): model.copy(transaction, **kwargs) except ValueError: # Calling a related manager for an unsaved instance would nevertheless result in an empty queryset pass return new_object
[docs] def in_use_by(self, via_relation: str, transaction=None) -> QuerySet[TrackedModel]: """ Returns all of the models that are referencing this one via the specified relation and exist as of the passed transaction. ``via_relation`` should be the name of a relation, and a ``KeyError`` will be raised if the relation name is not valid for this model. Relations are accessible via get_relations helper method. """ relation = {r.name: r for r in get_relations(self.__class__).keys()}[ via_relation ] remote_model = relation.remote_field.model remote_field_name = get_accessor(relation.remote_field) return remote_model.objects.filter( **{f"{remote_field_name}__version_group": self.version_group}, ).approved_up_to_transaction(transaction)
[docs] def in_use(self, transaction=None, *relations: str) -> bool: """ Returns True if there are any models that are using this one as of the specified transaction. This can be any model this model is related to, but ignoring any subrecords (because e.g. a footnote is not considered "in use by" its own description) and then filtering for only things that link _to_ this model. The list of relations can be filtered by passing in the name of a relation. If a name is passed in that does not refer to a relation on this model, ``ValueError`` will be raised. """ # Get the list of models that use models of this type. class_ = self.__class__ using_models = set( relation.name for relation in ( get_relations(class_).keys() - get_subrecord_relations(class_) - get_models_linked_to(class_).keys() ) ) # If the user has specified names, check that they are sane # and then filter the relations to them, if relations: bad_names = set(relations) - set(using_models) if any(bad_names): raise ValueError( f"{bad_names} are unknown relations; use one of {using_models}", ) using_models = { relation for relation in using_models if relation in relations } # If this model doesn't have any using relations, it cannot be in use. if not any(using_models): return False # If we find any objects for any relation, then the model is in use. for relation_name in using_models: relation_queryset = self.in_use_by(relation_name, transaction) if relation_queryset.exists(): return True return False
[docs] @atomic def save(self, *args, force_write=False, **kwargs): """ Save the model to the database. :param force_write bool: Ignore append-only restrictions and write to the database even if the model already exists """ if not force_write and not self._can_write(): raise IllegalSaveError( "TrackedModels cannot be updated once written and approved. " "If writing a new row, use `.new_draft` instead", ) if not hasattr(self, "version_group"): self.version_group = self._get_version_group() return_value = super().save(*args, **kwargs) if self.transaction.workbasket.status in WorkflowStatus.approved_statuses(): self.version_group.current_version = self self.version_group.save() auto_fields = { field for field in self.auto_value_fields if field.attname in self.__dict__ and isinstance(self.__dict__.get(field.attname), (Expression, F)) } # If the model contains any fields that are built in the database, the # fields will still contain the expression objects. So remove them now # and Django will lazy fetch the real values if they are accessed. for field in auto_fields: delattr(self, field.name) return return_value
def __str__(self): return self.identifying_fields_to_string() def __hash__(self): return hash(f"{__name__}.{self.__class__.__name__}")
[docs] def get_url(self, action: str = "detail") -> Optional[str]: """ Generate a URL to a representation of the model in the webapp. Callers should handle the case where no URL is returned. :param action str: The view type to generate a URL for (default "detail"), eg: "list" or "edit" :rtype Optional[str]: The generated URL """ kwargs = {} if action not in ["list", "create"]: kwargs = self.get_identifying_fields() try: if ( action == "edit" and self.transaction.workbasket.status == WorkflowStatus.EDITING ): # Edits in WorkBaskets that are in EDITING state get real # changes via DB updates, not newly created UPDATE instances. if self.update_type == UpdateType.CREATE: action += "-create" elif self.update_type == UpdateType.UPDATE: action += "-update" url = reverse( f"{self.get_url_pattern_name_prefix()}-ui-{action}", kwargs=kwargs, ) return f"{url}{self.url_suffix}" except NoReverseMatch: return None
[docs] @classmethod def get_url_pattern_name_prefix(cls): """ Get the prefix string for a view name for this model. By default, this is the verbose name of the model with spaces replaced by underscores, but this method allows this to be overridden. :rtype str: The prefix """ prefix = getattr(cls, "url_pattern_name_prefix", None) if not prefix: prefix = cls._meta.verbose_name.replace(" ", "_") return prefix