diff --git a/src/documents/filters.py b/src/documents/filters.py deleted file mode 100644 index a1c9917a8..000000000 --- a/src/documents/filters.py +++ /dev/null @@ -1,950 +0,0 @@ -from __future__ import annotations - -import functools -import inspect -import json -import operator -from contextlib import contextmanager -from typing import TYPE_CHECKING - -from django.contrib.contenttypes.models import ContentType -from django.db.models import Case -from django.db.models import CharField -from django.db.models import Count -from django.db.models import Exists -from django.db.models import IntegerField -from django.db.models import OuterRef -from django.db.models import Q -from django.db.models import Subquery -from django.db.models import Sum -from django.db.models import Value -from django.db.models import When -from django.db.models.functions import Cast -from django.utils.translation import gettext_lazy as _ -from django_filters.rest_framework import BooleanFilter -from django_filters.rest_framework import Filter -from django_filters.rest_framework import FilterSet -from drf_spectacular.utils import extend_schema_field -from guardian.utils import get_group_obj_perms_model -from guardian.utils import get_user_obj_perms_model -from rest_framework import serializers -from rest_framework.filters import OrderingFilter -from rest_framework_guardian.filters import ObjectPermissionsFilter - -from paperless.models import Correspondent -from paperless.models import CustomField -from paperless.models import CustomFieldInstance -from paperless.models import Document -from paperless.models import DocumentType -from paperless.models import PaperlessTask -from paperless.models import ShareLink -from paperless.models import StoragePath -from paperless.models import Tag - -if TYPE_CHECKING: - from collections.abc import Callable - -CHAR_KWARGS = ["istartswith", "iendswith", "icontains", "iexact"] -ID_KWARGS = ["in", "exact"] -INT_KWARGS = ["exact", "gt", "gte", "lt", "lte", "isnull"] -DATE_KWARGS = [ - "year", - "month", - "day", - "date__gt", - "date__gte", - "gt", - "gte", - "date__lt", - "date__lte", - "lt", - "lte", -] - -CUSTOM_FIELD_QUERY_MAX_DEPTH = 10 -CUSTOM_FIELD_QUERY_MAX_ATOMS = 20 - - -class CorrespondentFilterSet(FilterSet): - class Meta: - model = Correspondent - fields = { - "id": ID_KWARGS, - "name": CHAR_KWARGS, - } - - -class TagFilterSet(FilterSet): - class Meta: - model = Tag - fields = { - "id": ID_KWARGS, - "name": CHAR_KWARGS, - } - - -class DocumentTypeFilterSet(FilterSet): - class Meta: - model = DocumentType - fields = { - "id": ID_KWARGS, - "name": CHAR_KWARGS, - } - - -class StoragePathFilterSet(FilterSet): - class Meta: - model = StoragePath - fields = { - "id": ID_KWARGS, - "name": CHAR_KWARGS, - "path": CHAR_KWARGS, - } - - -class ObjectFilter(Filter): - def __init__(self, *, exclude=False, in_list=False, field_name=""): - super().__init__() - self.exclude = exclude - self.in_list = in_list - self.field_name = field_name - - def filter(self, qs, value): - if not value: - return qs - - try: - object_ids = [int(x) for x in value.split(",")] - except ValueError: - return qs - - if self.in_list: - qs = qs.filter(**{f"{self.field_name}__id__in": object_ids}).distinct() - else: - for obj_id in object_ids: - if self.exclude: - qs = qs.exclude(**{f"{self.field_name}__id": obj_id}) - else: - qs = qs.filter(**{f"{self.field_name}__id": obj_id}) - - return qs - - -@extend_schema_field(serializers.BooleanField) -class InboxFilter(Filter): - def filter(self, qs, value): - if value == "true": - return qs.filter(tags__is_inbox_tag=True) - elif value == "false": - return qs.exclude(tags__is_inbox_tag=True) - else: - return qs - - -@extend_schema_field(serializers.CharField) -class TitleContentFilter(Filter): - def filter(self, qs, value): - if value: - return qs.filter(Q(title__icontains=value) | Q(content__icontains=value)) - else: - return qs - - -@extend_schema_field(serializers.BooleanField) -class SharedByUser(Filter): - def filter(self, qs, value): - ctype = ContentType.objects.get_for_model(self.model) - UserObjectPermission = get_user_obj_perms_model() - GroupObjectPermission = get_group_obj_perms_model() - # see https://github.com/paperless-ngx/paperless-ngx/issues/5392, we limit subqueries - # to 1 because Postgres doesn't like returning > 1 row, but all we care about is > 0 - return ( - qs.filter( - owner_id=value, - ) - .annotate( - num_shared_users=Count( - UserObjectPermission.objects.filter( - content_type=ctype, - object_pk=Cast(OuterRef("pk"), CharField()), - ).values("user_id")[:1], - ), - ) - .annotate( - num_shared_groups=Count( - GroupObjectPermission.objects.filter( - content_type=ctype, - object_pk=Cast(OuterRef("pk"), CharField()), - ).values("group_id")[:1], - ), - ) - .filter( - Q(num_shared_users__gt=0) | Q(num_shared_groups__gt=0), - ) - if value is not None - else qs - ) - - -class CustomFieldFilterSet(FilterSet): - class Meta: - model = CustomField - fields = { - "id": ID_KWARGS, - "name": CHAR_KWARGS, - } - - -@extend_schema_field(serializers.CharField) -class CustomFieldsFilter(Filter): - def filter(self, qs, value): - if value: - fields_with_matching_selects = CustomField.objects.filter( - extra_data__icontains=value, - ) - option_ids = [] - if fields_with_matching_selects.count() > 0: - for field in fields_with_matching_selects: - options = field.extra_data.get("select_options", []) - for _, option in enumerate(options): - if option.get("label").lower().find(value.lower()) != -1: - option_ids.extend([option.get("id")]) - return ( - qs.filter(custom_fields__field__name__icontains=value) - | qs.filter(custom_fields__value_text__icontains=value) - | qs.filter(custom_fields__value_bool__icontains=value) - | qs.filter(custom_fields__value_int__icontains=value) - | qs.filter(custom_fields__value_float__icontains=value) - | qs.filter(custom_fields__value_date__icontains=value) - | qs.filter(custom_fields__value_url__icontains=value) - | qs.filter(custom_fields__value_monetary__icontains=value) - | qs.filter(custom_fields__value_document_ids__icontains=value) - | qs.filter(custom_fields__value_select__in=option_ids) - ) - else: - return qs - - -class MimeTypeFilter(Filter): - def filter(self, qs, value): - if value: - return qs.filter(mime_type__icontains=value) - else: - return qs - - -class SelectField(serializers.CharField): - def __init__(self, custom_field: CustomField): - self._options = custom_field.extra_data["select_options"] - super().__init__(max_length=16) - - def to_internal_value(self, data): - # If the supplied value is the option label instead of the ID - try: - data = next( - option.get("id") - for option in self._options - if option.get("label") == data - ) - except StopIteration: - pass - return super().to_internal_value(data) - - -def handle_validation_prefix(func: Callable): - """ - Catch ValidationErrors raised by the wrapped function - and add a prefix to the exception detail to track what causes the exception, - similar to nested serializers. - """ - - def wrapper(*args, validation_prefix=None, **kwargs): - try: - return func(*args, **kwargs) - except serializers.ValidationError as e: - raise serializers.ValidationError({validation_prefix: e.detail}) - - # Update the signature to include the validation_prefix argument - old_sig = inspect.signature(func) - new_param = inspect.Parameter("validation_prefix", inspect.Parameter.KEYWORD_ONLY) - new_sig = old_sig.replace(parameters=[*old_sig.parameters.values(), new_param]) - - # Apply functools.wraps and manually set the new signature - functools.update_wrapper(wrapper, func) - wrapper.__signature__ = new_sig - - return wrapper - - -class CustomFieldQueryParser: - EXPR_BY_CATEGORY = { - "basic": ["exact", "in", "isnull", "exists"], - "string": [ - "icontains", - "istartswith", - "iendswith", - ], - "arithmetic": [ - "gt", - "gte", - "lt", - "lte", - "range", - ], - "containment": ["contains"], - } - - SUPPORTED_EXPR_CATEGORIES = { - CustomField.FieldDataType.STRING: ("basic", "string"), - CustomField.FieldDataType.URL: ("basic", "string"), - CustomField.FieldDataType.DATE: ("basic", "arithmetic"), - CustomField.FieldDataType.BOOL: ("basic",), - CustomField.FieldDataType.INT: ("basic", "arithmetic"), - CustomField.FieldDataType.FLOAT: ("basic", "arithmetic"), - CustomField.FieldDataType.MONETARY: ("basic", "string", "arithmetic"), - CustomField.FieldDataType.DOCUMENTLINK: ("basic", "containment"), - CustomField.FieldDataType.SELECT: ("basic",), - } - - DATE_COMPONENTS = [ - "year", - "iso_year", - "month", - "day", - "week", - "week_day", - "iso_week_day", - "quarter", - ] - - def __init__( - self, - validation_prefix, - max_query_depth=10, - max_atom_count=20, - ) -> None: - """ - A helper class that parses the query string into a `django.db.models.Q` for filtering - documents based on custom field values. - - The syntax of the query expression is illustrated with the below pseudo code rules: - 1. parse([`custom_field`, "exists", true]): - matches documents with Q(custom_fields__field=`custom_field`) - 2. parse([`custom_field`, "exists", false]): - matches documents with ~Q(custom_fields__field=`custom_field`) - 3. parse([`custom_field`, `op`, `value`]): - matches documents with - Q(custom_fields__field=`custom_field`, custom_fields__value_`type`__`op`= `value`) - 4. parse(["AND", [`q0`, `q1`, ..., `qn`]]) - -> parse(`q0`) & parse(`q1`) & ... & parse(`qn`) - 5. parse(["OR", [`q0`, `q1`, ..., `qn`]]) - -> parse(`q0`) | parse(`q1`) | ... | parse(`qn`) - 6. parse(["NOT", `q`]) - -> ~parse(`q`) - - Args: - validation_prefix: Used to generate the ValidationError message. - max_query_depth: Limits the maximum nesting depth of queries. - max_atom_count: Limits the maximum number of atoms (i.e., rule 1, 2, 3) in the query. - - `max_query_depth` and `max_atom_count` can be set to guard against generating arbitrarily - complex SQL queries. - """ - self._custom_fields: dict[int | str, CustomField] = {} - self._validation_prefix = validation_prefix - # Dummy ModelSerializer used to convert a Django models.Field to serializers.Field. - self._model_serializer = serializers.ModelSerializer() - # Used for sanity check - self._max_query_depth = max_query_depth - self._max_atom_count = max_atom_count - self._current_depth = 0 - self._atom_count = 0 - # The set of annotations that we need to apply to the queryset - self._annotations = {} - - def parse(self, query: str) -> tuple[Q, dict[str, Count]]: - """ - Parses the query string into a `django.db.models.Q` - and a set of annotations to be applied to the queryset. - """ - try: - expr = json.loads(query) - except json.JSONDecodeError: - raise serializers.ValidationError( - {self._validation_prefix: [_("Value must be valid JSON.")]}, - ) - return ( - self._parse_expr(expr, validation_prefix=self._validation_prefix), - self._annotations, - ) - - @handle_validation_prefix - def _parse_expr(self, expr) -> Q: - """ - Applies rule (1, 2, 3) or (4, 5, 6) based on the length of the expr. - """ - with self._track_query_depth(): - if isinstance(expr, list | tuple): - if len(expr) == 2: - return self._parse_logical_expr(*expr) - elif len(expr) == 3: - return self._parse_atom(*expr) - raise serializers.ValidationError( - [_("Invalid custom field query expression")], - ) - - @handle_validation_prefix - def _parse_expr_list(self, exprs) -> list[Q]: - """ - Handles [`q0`, `q1`, ..., `qn`] in rule 4 & 5. - """ - if not isinstance(exprs, list | tuple) or not exprs: - raise serializers.ValidationError( - [_("Invalid expression list. Must be nonempty.")], - ) - return [ - self._parse_expr(expr, validation_prefix=i) for i, expr in enumerate(exprs) - ] - - def _parse_logical_expr(self, op, args) -> Q: - """ - Handles rule 4, 5, 6. - """ - op_lower = op.lower() - - if op_lower == "not": - return ~self._parse_expr(args, validation_prefix=1) - - if op_lower == "and": - op_func = operator.and_ - elif op_lower == "or": - op_func = operator.or_ - else: - raise serializers.ValidationError( - {"0": [_("Invalid logical operator {op!r}").format(op=op)]}, - ) - - qs = self._parse_expr_list(args, validation_prefix="1") - return functools.reduce(op_func, qs) - - def _parse_atom(self, id_or_name, op, value) -> Q: - """ - Handles rule 1, 2, 3. - """ - # Guard against queries with too many conditions. - self._atom_count += 1 - if self._atom_count > self._max_atom_count: - raise serializers.ValidationError( - [_("Maximum number of query conditions exceeded.")], - ) - - custom_field = self._get_custom_field(id_or_name, validation_prefix="0") - op = self._validate_atom_op(custom_field, op, validation_prefix="1") - value = self._validate_atom_value( - custom_field, - op, - value, - validation_prefix="2", - ) - - # Needed because not all DB backends support Array __contains - if ( - custom_field.data_type == CustomField.FieldDataType.DOCUMENTLINK - and op == "contains" - ): - return self._parse_atom_doc_link_contains(custom_field, value) - - value_field_name = CustomFieldInstance.get_value_field_name( - custom_field.data_type, - ) - if ( - custom_field.data_type == CustomField.FieldDataType.MONETARY - and op in self.EXPR_BY_CATEGORY["arithmetic"] - ): - value_field_name = "value_monetary_amount" - has_field = Q(custom_fields__field=custom_field) - - # We need to use an annotation here because different atoms - # might be referring to different instances of custom fields. - annotation_name = f"_custom_field_filter_{len(self._annotations)}" - - # Our special exists operator. - if op == "exists": - annotation = Count("custom_fields", filter=has_field) - # A Document should have > 0 match if it has this field, or 0 if doesn't. - query_op = "gt" if value else "exact" - query = Q(**{f"{annotation_name}__{query_op}": 0}) - else: - # Check if 1) custom field name matches, and 2) value satisfies condition - field_filter = has_field & Q( - **{f"custom_fields__{value_field_name}__{op}": value}, - ) - # Annotate how many matching custom fields each document has - annotation = Count("custom_fields", filter=field_filter) - # Filter document by count - query = Q(**{f"{annotation_name}__gt": 0}) - - self._annotations[annotation_name] = annotation - return query - - @handle_validation_prefix - def _get_custom_field(self, id_or_name): - """Get the CustomField instance by id or name.""" - if id_or_name in self._custom_fields: - return self._custom_fields[id_or_name] - - kwargs = ( - {"id": id_or_name} if isinstance(id_or_name, int) else {"name": id_or_name} - ) - try: - custom_field = CustomField.objects.get(**kwargs) - except CustomField.DoesNotExist: - raise serializers.ValidationError( - [_("{name!r} is not a valid custom field.").format(name=id_or_name)], - ) - self._custom_fields[custom_field.id] = custom_field - self._custom_fields[custom_field.name] = custom_field - return custom_field - - @staticmethod - def _split_op(full_op): - *prefix, op = str(full_op).rsplit("__", maxsplit=1) - prefix = prefix[0] if prefix else None - return prefix, op - - @handle_validation_prefix - def _validate_atom_op(self, custom_field, raw_op): - """Check if the `op` is compatible with the type of the custom field.""" - prefix, op = self._split_op(raw_op) - - # Check if the operator is supported for the current data_type. - supported = False - for category in self.SUPPORTED_EXPR_CATEGORIES[custom_field.data_type]: - if op in self.EXPR_BY_CATEGORY[category]: - supported = True - break - - # Check prefix - if prefix is not None: - if ( - prefix in self.DATE_COMPONENTS - and custom_field.data_type == CustomField.FieldDataType.DATE - ): - pass # ok - e.g., "year__exact" for date field - else: - supported = False # anything else is invalid - - if not supported: - raise serializers.ValidationError( - [ - _("{data_type} does not support query expr {expr!r}.").format( - data_type=custom_field.data_type, - expr=raw_op, - ), - ], - ) - - return raw_op - - def _get_serializer_field(self, custom_field, full_op): - """Return a serializers.Field for value validation.""" - prefix, op = self._split_op(full_op) - field = None - - if op in ("isnull", "exists"): - # `isnull` takes either True or False regardless of the data_type. - field = serializers.BooleanField() - elif ( - custom_field.data_type == CustomField.FieldDataType.DATE - and prefix in self.DATE_COMPONENTS - ): - # DateField admits queries in the form of `year__exact`, etc. These take integers. - field = serializers.IntegerField() - elif custom_field.data_type == CustomField.FieldDataType.DOCUMENTLINK: - # We can be more specific here and make sure the value is a list. - field = serializers.ListField(child=serializers.IntegerField()) - elif custom_field.data_type == CustomField.FieldDataType.SELECT: - # We use this custom field to permit SELECT option names. - field = SelectField(custom_field) - elif custom_field.data_type == CustomField.FieldDataType.URL: - # For URL fields we don't need to be strict about validation (e.g., for istartswith). - field = serializers.CharField() - else: - # The general case: inferred from the corresponding field in CustomFieldInstance. - value_field_name = CustomFieldInstance.get_value_field_name( - custom_field.data_type, - ) - model_field = CustomFieldInstance._meta.get_field(value_field_name) - field_name = model_field.deconstruct()[0] - field_class, field_kwargs = self._model_serializer.build_standard_field( - field_name, - model_field, - ) - field = field_class(**field_kwargs) - field.allow_null = False - - # Need to set allow_blank manually because of the inconsistency in CustomFieldInstance validation. - # See https://github.com/paperless-ngx/paperless-ngx/issues/7361. - if isinstance(field, serializers.CharField): - field.allow_blank = True - - if op == "in": - # `in` takes a list of values. - field = serializers.ListField(child=field, allow_empty=False) - elif op == "range": - # `range` takes a list of values, i.e., [start, end]. - field = serializers.ListField( - child=field, - min_length=2, - max_length=2, - ) - - return field - - @handle_validation_prefix - def _validate_atom_value(self, custom_field, op, value): - """Check if `value` is valid for the custom field and `op`. Returns the validated value.""" - serializer_field = self._get_serializer_field(custom_field, op) - return serializer_field.run_validation(value) - - def _parse_atom_doc_link_contains(self, custom_field, value) -> Q: - """ - Handles document link `contains` in a way that is supported by all DB backends. - """ - - # If the value is an empty set, - # this is trivially true for any document with not null document links. - if not value: - return Q( - custom_fields__field=custom_field, - custom_fields__value_document_ids__isnull=False, - ) - - # First we look up reverse links from the requested documents. - links = CustomFieldInstance.objects.filter( - document_id__in=value, - field__data_type=CustomField.FieldDataType.DOCUMENTLINK, - ) - - # Check if any of the requested IDs are missing. - missing_ids = set(value) - set(link.document_id for link in links) - if missing_ids: - # The result should be an empty set in this case. - return Q(id__in=[]) - - # Take the intersection of the reverse links - this should be what we are looking for. - document_ids_we_want = functools.reduce( - operator.and_, - (set(link.value_document_ids) for link in links), - ) - - return Q(id__in=document_ids_we_want) - - @contextmanager - def _track_query_depth(self): - # guard against queries that are too deeply nested - self._current_depth += 1 - if self._current_depth > self._max_query_depth: - raise serializers.ValidationError([_("Maximum nesting depth exceeded.")]) - try: - yield - finally: - self._current_depth -= 1 - - -@extend_schema_field(serializers.CharField) -class CustomFieldQueryFilter(Filter): - def __init__(self, validation_prefix): - """ - A filter that filters documents based on custom field name and value. - - Args: - validation_prefix: Used to generate the ValidationError message. - """ - super().__init__() - self._validation_prefix = validation_prefix - - def filter(self, qs, value): - if not value: - return qs - - parser = CustomFieldQueryParser( - self._validation_prefix, - max_query_depth=CUSTOM_FIELD_QUERY_MAX_DEPTH, - max_atom_count=CUSTOM_FIELD_QUERY_MAX_ATOMS, - ) - q, annotations = parser.parse(value) - - return qs.annotate(**annotations).filter(q) - - -class DocumentFilterSet(FilterSet): - is_tagged = BooleanFilter( - label="Is tagged", - field_name="tags", - lookup_expr="isnull", - exclude=True, - ) - - tags__id__all = ObjectFilter(field_name="tags") - - tags__id__none = ObjectFilter(field_name="tags", exclude=True) - - tags__id__in = ObjectFilter(field_name="tags", in_list=True) - - correspondent__id__none = ObjectFilter(field_name="correspondent", exclude=True) - - document_type__id__none = ObjectFilter(field_name="document_type", exclude=True) - - storage_path__id__none = ObjectFilter(field_name="storage_path", exclude=True) - - is_in_inbox = InboxFilter() - - title_content = TitleContentFilter() - - owner__id__none = ObjectFilter(field_name="owner", exclude=True) - - custom_fields__icontains = CustomFieldsFilter() - - custom_fields__id__all = ObjectFilter(field_name="custom_fields__field") - - custom_fields__id__none = ObjectFilter( - field_name="custom_fields__field", - exclude=True, - ) - - custom_fields__id__in = ObjectFilter( - field_name="custom_fields__field", - in_list=True, - ) - - has_custom_fields = BooleanFilter( - label="Has custom field", - field_name="custom_fields", - lookup_expr="isnull", - exclude=True, - ) - - custom_field_query = CustomFieldQueryFilter("custom_field_query") - - shared_by__id = SharedByUser() - - mime_type = MimeTypeFilter() - - class Meta: - model = Document - fields = { - "id": ID_KWARGS, - "title": CHAR_KWARGS, - "content": CHAR_KWARGS, - "archive_serial_number": INT_KWARGS, - "created": DATE_KWARGS, - "added": DATE_KWARGS, - "modified": DATE_KWARGS, - "original_filename": CHAR_KWARGS, - "checksum": CHAR_KWARGS, - "correspondent": ["isnull"], - "correspondent__id": ID_KWARGS, - "correspondent__name": CHAR_KWARGS, - "tags__id": ID_KWARGS, - "tags__name": CHAR_KWARGS, - "document_type": ["isnull"], - "document_type__id": ID_KWARGS, - "document_type__name": CHAR_KWARGS, - "storage_path": ["isnull"], - "storage_path__id": ID_KWARGS, - "storage_path__name": CHAR_KWARGS, - "owner": ["isnull"], - "owner__id": ID_KWARGS, - "custom_fields": ["icontains"], - } - - -class ShareLinkFilterSet(FilterSet): - class Meta: - model = ShareLink - fields = { - "created": DATE_KWARGS, - "expiration": DATE_KWARGS, - } - - -class PaperlessTaskFilterSet(FilterSet): - acknowledged = BooleanFilter( - label="Acknowledged", - field_name="acknowledged", - ) - - class Meta: - model = PaperlessTask - fields = { - "type": ["exact"], - "task_name": ["exact"], - "status": ["exact"], - } - - -class ObjectOwnedOrGrantedPermissionsFilter(ObjectPermissionsFilter): - """ - A filter backend that limits results to those where the requesting user - has read object level permissions, owns the objects, or objects without - an owner (for backwards compat) - """ - - def filter_queryset(self, request, queryset, view): - objects_with_perms = super().filter_queryset(request, queryset, view) - objects_owned = queryset.filter(owner=request.user) - objects_unowned = queryset.filter(owner__isnull=True) - return objects_with_perms | objects_owned | objects_unowned - - -class ObjectOwnedPermissionsFilter(ObjectPermissionsFilter): - """ - A filter backend that limits results to those where the requesting user - owns the objects or objects without an owner (for backwards compat) - """ - - def filter_queryset(self, request, queryset, view): - if request.user.is_superuser: - return queryset - objects_owned = queryset.filter(owner=request.user) - objects_unowned = queryset.filter(owner__isnull=True) - return objects_owned | objects_unowned - - -class DocumentsOrderingFilter(OrderingFilter): - field_name = "ordering" - prefix = "custom_field_" - - def filter_queryset(self, request, queryset, view): - param = request.query_params.get("ordering") - if param and self.prefix in param: - custom_field_id = int(param.split(self.prefix)[1]) - try: - field = CustomField.objects.get(pk=custom_field_id) - except CustomField.DoesNotExist: - raise serializers.ValidationError( - {self.prefix + str(custom_field_id): [_("Custom field not found")]}, - ) - - annotation = None - match field.data_type: - case CustomField.FieldDataType.STRING: - annotation = Subquery( - CustomFieldInstance.objects.filter( - document_id=OuterRef("id"), - field_id=custom_field_id, - ).values("value_text")[:1], - ) - case CustomField.FieldDataType.INT: - annotation = Subquery( - CustomFieldInstance.objects.filter( - document_id=OuterRef("id"), - field_id=custom_field_id, - ).values("value_int")[:1], - ) - case CustomField.FieldDataType.FLOAT: - annotation = Subquery( - CustomFieldInstance.objects.filter( - document_id=OuterRef("id"), - field_id=custom_field_id, - ).values("value_float")[:1], - ) - case CustomField.FieldDataType.DATE: - annotation = Subquery( - CustomFieldInstance.objects.filter( - document_id=OuterRef("id"), - field_id=custom_field_id, - ).values("value_date")[:1], - ) - case CustomField.FieldDataType.MONETARY: - annotation = Subquery( - CustomFieldInstance.objects.filter( - document_id=OuterRef("id"), - field_id=custom_field_id, - ).values("value_monetary_amount")[:1], - ) - case CustomField.FieldDataType.SELECT: - # Select options are a little more complicated since the value is the id of the option, not - # the label. Additionally, to support sqlite we can't use StringAgg, so we need to create a - # case statement for each option, setting the value to the index of the option in a list - # sorted by label, and then summing the results to give a single value for the annotation - - select_options = sorted( - field.extra_data.get("select_options", []), - key=lambda x: x.get("label"), - ) - whens = [ - When( - custom_fields__field_id=custom_field_id, - custom_fields__value_select=option.get("id"), - then=Value(idx, output_field=IntegerField()), - ) - for idx, option in enumerate(select_options) - ] - whens.append( - When( - custom_fields__field_id=custom_field_id, - custom_fields__value_select__isnull=True, - then=Value( - len(select_options), - output_field=IntegerField(), - ), - ), - ) - annotation = Sum( - Case( - *whens, - default=Value(0), - output_field=IntegerField(), - ), - ) - case CustomField.FieldDataType.DOCUMENTLINK: - annotation = Subquery( - CustomFieldInstance.objects.filter( - document_id=OuterRef("id"), - field_id=custom_field_id, - ).values("value_document_ids")[:1], - ) - case CustomField.FieldDataType.URL: - annotation = Subquery( - CustomFieldInstance.objects.filter( - document_id=OuterRef("id"), - field_id=custom_field_id, - ).values("value_url")[:1], - ) - case CustomField.FieldDataType.BOOL: - annotation = Subquery( - CustomFieldInstance.objects.filter( - document_id=OuterRef("id"), - field_id=custom_field_id, - ).values("value_bool")[:1], - ) - - if not annotation: - # Only happens if a new data type is added and not handled here - raise ValueError("Invalid custom field data type") - - queryset = ( - queryset.annotate( - # We need to annotate the queryset with the custom field value - custom_field_value=annotation, - # We also need to annotate the queryset with a boolean for sorting whether the field exists - has_field=Exists( - CustomFieldInstance.objects.filter( - document_id=OuterRef("id"), - field_id=custom_field_id, - ), - ), - ) - .order_by( - "-has_field", - param.replace( - self.prefix + str(custom_field_id), - "custom_field_value", - ), - ) - .distinct() - ) - - return super().filter_queryset(request, queryset, view) diff --git a/src/paperless/filters.py b/src/paperless/filters.py index a3c09d50f..d5f25d642 100644 --- a/src/paperless/filters.py +++ b/src/paperless/filters.py @@ -1,8 +1,70 @@ +from __future__ import annotations + +import functools +import inspect +import json +import operator +from contextlib import contextmanager +from typing import TYPE_CHECKING + from django.contrib.auth.models import Group from django.contrib.auth.models import User +from django.contrib.contenttypes.models import ContentType +from django.db.models import Case +from django.db.models import CharField +from django.db.models import Count +from django.db.models import Exists +from django.db.models import IntegerField +from django.db.models import OuterRef +from django.db.models import Q +from django.db.models import Subquery +from django.db.models import Sum +from django.db.models import Value +from django.db.models import When +from django.db.models.functions import Cast +from django.utils.translation import gettext_lazy as _ +from django_filters.rest_framework import BooleanFilter +from django_filters.rest_framework import Filter from django_filters.rest_framework import FilterSet +from drf_spectacular.utils import extend_schema_field +from guardian.utils import get_group_obj_perms_model +from guardian.utils import get_user_obj_perms_model +from rest_framework import serializers +from rest_framework.filters import OrderingFilter +from rest_framework_guardian.filters import ObjectPermissionsFilter -from documents.filters import CHAR_KWARGS +from paperless.models import Correspondent +from paperless.models import CustomField +from paperless.models import CustomFieldInstance +from paperless.models import Document +from paperless.models import DocumentType +from paperless.models import PaperlessTask +from paperless.models import ShareLink +from paperless.models import StoragePath +from paperless.models import Tag + +if TYPE_CHECKING: + from collections.abc import Callable + +CHAR_KWARGS = ["istartswith", "iendswith", "icontains", "iexact"] +ID_KWARGS = ["in", "exact"] +INT_KWARGS = ["exact", "gt", "gte", "lt", "lte", "isnull"] +DATE_KWARGS = [ + "year", + "month", + "day", + "date__gt", + "date__gte", + "gt", + "gte", + "date__lt", + "date__lte", + "lt", + "lte", +] + +CUSTOM_FIELD_QUERY_MAX_DEPTH = 10 +CUSTOM_FIELD_QUERY_MAX_ATOMS = 20 class UserFilterSet(FilterSet): @@ -15,3 +77,888 @@ class GroupFilterSet(FilterSet): class Meta: model = Group fields = {"name": CHAR_KWARGS} + + +class CorrespondentFilterSet(FilterSet): + class Meta: + model = Correspondent + fields = { + "id": ID_KWARGS, + "name": CHAR_KWARGS, + } + + +class TagFilterSet(FilterSet): + class Meta: + model = Tag + fields = { + "id": ID_KWARGS, + "name": CHAR_KWARGS, + } + + +class DocumentTypeFilterSet(FilterSet): + class Meta: + model = DocumentType + fields = { + "id": ID_KWARGS, + "name": CHAR_KWARGS, + } + + +class StoragePathFilterSet(FilterSet): + class Meta: + model = StoragePath + fields = { + "id": ID_KWARGS, + "name": CHAR_KWARGS, + "path": CHAR_KWARGS, + } + + +class ObjectFilter(Filter): + def __init__(self, *, exclude=False, in_list=False, field_name=""): + super().__init__() + self.exclude = exclude + self.in_list = in_list + self.field_name = field_name + + def filter(self, qs, value): + if not value: + return qs + + try: + object_ids = [int(x) for x in value.split(",")] + except ValueError: + return qs + + if self.in_list: + qs = qs.filter(**{f"{self.field_name}__id__in": object_ids}).distinct() + else: + for obj_id in object_ids: + if self.exclude: + qs = qs.exclude(**{f"{self.field_name}__id": obj_id}) + else: + qs = qs.filter(**{f"{self.field_name}__id": obj_id}) + + return qs + + +@extend_schema_field(serializers.BooleanField) +class InboxFilter(Filter): + def filter(self, qs, value): + if value == "true": + return qs.filter(tags__is_inbox_tag=True) + elif value == "false": + return qs.exclude(tags__is_inbox_tag=True) + else: + return qs + + +@extend_schema_field(serializers.CharField) +class TitleContentFilter(Filter): + def filter(self, qs, value): + if value: + return qs.filter(Q(title__icontains=value) | Q(content__icontains=value)) + else: + return qs + + +@extend_schema_field(serializers.BooleanField) +class SharedByUser(Filter): + def filter(self, qs, value): + ctype = ContentType.objects.get_for_model(self.model) + UserObjectPermission = get_user_obj_perms_model() + GroupObjectPermission = get_group_obj_perms_model() + # see https://github.com/paperless-ngx/paperless-ngx/issues/5392, we limit subqueries + # to 1 because Postgres doesn't like returning > 1 row, but all we care about is > 0 + return ( + qs.filter( + owner_id=value, + ) + .annotate( + num_shared_users=Count( + UserObjectPermission.objects.filter( + content_type=ctype, + object_pk=Cast(OuterRef("pk"), CharField()), + ).values("user_id")[:1], + ), + ) + .annotate( + num_shared_groups=Count( + GroupObjectPermission.objects.filter( + content_type=ctype, + object_pk=Cast(OuterRef("pk"), CharField()), + ).values("group_id")[:1], + ), + ) + .filter( + Q(num_shared_users__gt=0) | Q(num_shared_groups__gt=0), + ) + if value is not None + else qs + ) + + +class CustomFieldFilterSet(FilterSet): + class Meta: + model = CustomField + fields = { + "id": ID_KWARGS, + "name": CHAR_KWARGS, + } + + +@extend_schema_field(serializers.CharField) +class CustomFieldsFilter(Filter): + def filter(self, qs, value): + if value: + fields_with_matching_selects = CustomField.objects.filter( + extra_data__icontains=value, + ) + option_ids = [] + if fields_with_matching_selects.count() > 0: + for field in fields_with_matching_selects: + options = field.extra_data.get("select_options", []) + for _, option in enumerate(options): + if option.get("label").lower().find(value.lower()) != -1: + option_ids.extend([option.get("id")]) + return ( + qs.filter(custom_fields__field__name__icontains=value) + | qs.filter(custom_fields__value_text__icontains=value) + | qs.filter(custom_fields__value_bool__icontains=value) + | qs.filter(custom_fields__value_int__icontains=value) + | qs.filter(custom_fields__value_float__icontains=value) + | qs.filter(custom_fields__value_date__icontains=value) + | qs.filter(custom_fields__value_url__icontains=value) + | qs.filter(custom_fields__value_monetary__icontains=value) + | qs.filter(custom_fields__value_document_ids__icontains=value) + | qs.filter(custom_fields__value_select__in=option_ids) + ) + else: + return qs + + +class MimeTypeFilter(Filter): + def filter(self, qs, value): + if value: + return qs.filter(mime_type__icontains=value) + else: + return qs + + +class SelectField(serializers.CharField): + def __init__(self, custom_field: CustomField): + self._options = custom_field.extra_data["select_options"] + super().__init__(max_length=16) + + def to_internal_value(self, data): + # If the supplied value is the option label instead of the ID + try: + data = next( + option.get("id") + for option in self._options + if option.get("label") == data + ) + except StopIteration: + pass + return super().to_internal_value(data) + + +def handle_validation_prefix(func: Callable): + """ + Catch ValidationErrors raised by the wrapped function + and add a prefix to the exception detail to track what causes the exception, + similar to nested serializers. + """ + + def wrapper(*args, validation_prefix=None, **kwargs): + try: + return func(*args, **kwargs) + except serializers.ValidationError as e: + raise serializers.ValidationError({validation_prefix: e.detail}) + + # Update the signature to include the validation_prefix argument + old_sig = inspect.signature(func) + new_param = inspect.Parameter("validation_prefix", inspect.Parameter.KEYWORD_ONLY) + new_sig = old_sig.replace(parameters=[*old_sig.parameters.values(), new_param]) + + # Apply functools.wraps and manually set the new signature + functools.update_wrapper(wrapper, func) + wrapper.__signature__ = new_sig + + return wrapper + + +class CustomFieldQueryParser: + EXPR_BY_CATEGORY = { + "basic": ["exact", "in", "isnull", "exists"], + "string": [ + "icontains", + "istartswith", + "iendswith", + ], + "arithmetic": [ + "gt", + "gte", + "lt", + "lte", + "range", + ], + "containment": ["contains"], + } + + SUPPORTED_EXPR_CATEGORIES = { + CustomField.FieldDataType.STRING: ("basic", "string"), + CustomField.FieldDataType.URL: ("basic", "string"), + CustomField.FieldDataType.DATE: ("basic", "arithmetic"), + CustomField.FieldDataType.BOOL: ("basic",), + CustomField.FieldDataType.INT: ("basic", "arithmetic"), + CustomField.FieldDataType.FLOAT: ("basic", "arithmetic"), + CustomField.FieldDataType.MONETARY: ("basic", "string", "arithmetic"), + CustomField.FieldDataType.DOCUMENTLINK: ("basic", "containment"), + CustomField.FieldDataType.SELECT: ("basic",), + } + + DATE_COMPONENTS = [ + "year", + "iso_year", + "month", + "day", + "week", + "week_day", + "iso_week_day", + "quarter", + ] + + def __init__( + self, + validation_prefix, + max_query_depth=10, + max_atom_count=20, + ) -> None: + """ + A helper class that parses the query string into a `django.db.models.Q` for filtering + documents based on custom field values. + + The syntax of the query expression is illustrated with the below pseudo code rules: + 1. parse([`custom_field`, "exists", true]): + matches documents with Q(custom_fields__field=`custom_field`) + 2. parse([`custom_field`, "exists", false]): + matches documents with ~Q(custom_fields__field=`custom_field`) + 3. parse([`custom_field`, `op`, `value`]): + matches documents with + Q(custom_fields__field=`custom_field`, custom_fields__value_`type`__`op`= `value`) + 4. parse(["AND", [`q0`, `q1`, ..., `qn`]]) + -> parse(`q0`) & parse(`q1`) & ... & parse(`qn`) + 5. parse(["OR", [`q0`, `q1`, ..., `qn`]]) + -> parse(`q0`) | parse(`q1`) | ... | parse(`qn`) + 6. parse(["NOT", `q`]) + -> ~parse(`q`) + + Args: + validation_prefix: Used to generate the ValidationError message. + max_query_depth: Limits the maximum nesting depth of queries. + max_atom_count: Limits the maximum number of atoms (i.e., rule 1, 2, 3) in the query. + + `max_query_depth` and `max_atom_count` can be set to guard against generating arbitrarily + complex SQL queries. + """ + self._custom_fields: dict[int | str, CustomField] = {} + self._validation_prefix = validation_prefix + # Dummy ModelSerializer used to convert a Django models.Field to serializers.Field. + self._model_serializer = serializers.ModelSerializer() + # Used for sanity check + self._max_query_depth = max_query_depth + self._max_atom_count = max_atom_count + self._current_depth = 0 + self._atom_count = 0 + # The set of annotations that we need to apply to the queryset + self._annotations = {} + + def parse(self, query: str) -> tuple[Q, dict[str, Count]]: + """ + Parses the query string into a `django.db.models.Q` + and a set of annotations to be applied to the queryset. + """ + try: + expr = json.loads(query) + except json.JSONDecodeError: + raise serializers.ValidationError( + {self._validation_prefix: [_("Value must be valid JSON.")]}, + ) + return ( + self._parse_expr(expr, validation_prefix=self._validation_prefix), + self._annotations, + ) + + @handle_validation_prefix + def _parse_expr(self, expr) -> Q: + """ + Applies rule (1, 2, 3) or (4, 5, 6) based on the length of the expr. + """ + with self._track_query_depth(): + if isinstance(expr, list | tuple): + if len(expr) == 2: + return self._parse_logical_expr(*expr) + elif len(expr) == 3: + return self._parse_atom(*expr) + raise serializers.ValidationError( + [_("Invalid custom field query expression")], + ) + + @handle_validation_prefix + def _parse_expr_list(self, exprs) -> list[Q]: + """ + Handles [`q0`, `q1`, ..., `qn`] in rule 4 & 5. + """ + if not isinstance(exprs, list | tuple) or not exprs: + raise serializers.ValidationError( + [_("Invalid expression list. Must be nonempty.")], + ) + return [ + self._parse_expr(expr, validation_prefix=i) for i, expr in enumerate(exprs) + ] + + def _parse_logical_expr(self, op, args) -> Q: + """ + Handles rule 4, 5, 6. + """ + op_lower = op.lower() + + if op_lower == "not": + return ~self._parse_expr(args, validation_prefix=1) + + if op_lower == "and": + op_func = operator.and_ + elif op_lower == "or": + op_func = operator.or_ + else: + raise serializers.ValidationError( + {"0": [_("Invalid logical operator {op!r}").format(op=op)]}, + ) + + qs = self._parse_expr_list(args, validation_prefix="1") + return functools.reduce(op_func, qs) + + def _parse_atom(self, id_or_name, op, value) -> Q: + """ + Handles rule 1, 2, 3. + """ + # Guard against queries with too many conditions. + self._atom_count += 1 + if self._atom_count > self._max_atom_count: + raise serializers.ValidationError( + [_("Maximum number of query conditions exceeded.")], + ) + + custom_field = self._get_custom_field(id_or_name, validation_prefix="0") + op = self._validate_atom_op(custom_field, op, validation_prefix="1") + value = self._validate_atom_value( + custom_field, + op, + value, + validation_prefix="2", + ) + + # Needed because not all DB backends support Array __contains + if ( + custom_field.data_type == CustomField.FieldDataType.DOCUMENTLINK + and op == "contains" + ): + return self._parse_atom_doc_link_contains(custom_field, value) + + value_field_name = CustomFieldInstance.get_value_field_name( + custom_field.data_type, + ) + if ( + custom_field.data_type == CustomField.FieldDataType.MONETARY + and op in self.EXPR_BY_CATEGORY["arithmetic"] + ): + value_field_name = "value_monetary_amount" + has_field = Q(custom_fields__field=custom_field) + + # We need to use an annotation here because different atoms + # might be referring to different instances of custom fields. + annotation_name = f"_custom_field_filter_{len(self._annotations)}" + + # Our special exists operator. + if op == "exists": + annotation = Count("custom_fields", filter=has_field) + # A Document should have > 0 match if it has this field, or 0 if doesn't. + query_op = "gt" if value else "exact" + query = Q(**{f"{annotation_name}__{query_op}": 0}) + else: + # Check if 1) custom field name matches, and 2) value satisfies condition + field_filter = has_field & Q( + **{f"custom_fields__{value_field_name}__{op}": value}, + ) + # Annotate how many matching custom fields each document has + annotation = Count("custom_fields", filter=field_filter) + # Filter document by count + query = Q(**{f"{annotation_name}__gt": 0}) + + self._annotations[annotation_name] = annotation + return query + + @handle_validation_prefix + def _get_custom_field(self, id_or_name): + """Get the CustomField instance by id or name.""" + if id_or_name in self._custom_fields: + return self._custom_fields[id_or_name] + + kwargs = ( + {"id": id_or_name} if isinstance(id_or_name, int) else {"name": id_or_name} + ) + try: + custom_field = CustomField.objects.get(**kwargs) + except CustomField.DoesNotExist: + raise serializers.ValidationError( + [_("{name!r} is not a valid custom field.").format(name=id_or_name)], + ) + self._custom_fields[custom_field.id] = custom_field + self._custom_fields[custom_field.name] = custom_field + return custom_field + + @staticmethod + def _split_op(full_op): + *prefix, op = str(full_op).rsplit("__", maxsplit=1) + prefix = prefix[0] if prefix else None + return prefix, op + + @handle_validation_prefix + def _validate_atom_op(self, custom_field, raw_op): + """Check if the `op` is compatible with the type of the custom field.""" + prefix, op = self._split_op(raw_op) + + # Check if the operator is supported for the current data_type. + supported = False + for category in self.SUPPORTED_EXPR_CATEGORIES[custom_field.data_type]: + if op in self.EXPR_BY_CATEGORY[category]: + supported = True + break + + # Check prefix + if prefix is not None: + if ( + prefix in self.DATE_COMPONENTS + and custom_field.data_type == CustomField.FieldDataType.DATE + ): + pass # ok - e.g., "year__exact" for date field + else: + supported = False # anything else is invalid + + if not supported: + raise serializers.ValidationError( + [ + _("{data_type} does not support query expr {expr!r}.").format( + data_type=custom_field.data_type, + expr=raw_op, + ), + ], + ) + + return raw_op + + def _get_serializer_field(self, custom_field, full_op): + """Return a serializers.Field for value validation.""" + prefix, op = self._split_op(full_op) + field = None + + if op in ("isnull", "exists"): + # `isnull` takes either True or False regardless of the data_type. + field = serializers.BooleanField() + elif ( + custom_field.data_type == CustomField.FieldDataType.DATE + and prefix in self.DATE_COMPONENTS + ): + # DateField admits queries in the form of `year__exact`, etc. These take integers. + field = serializers.IntegerField() + elif custom_field.data_type == CustomField.FieldDataType.DOCUMENTLINK: + # We can be more specific here and make sure the value is a list. + field = serializers.ListField(child=serializers.IntegerField()) + elif custom_field.data_type == CustomField.FieldDataType.SELECT: + # We use this custom field to permit SELECT option names. + field = SelectField(custom_field) + elif custom_field.data_type == CustomField.FieldDataType.URL: + # For URL fields we don't need to be strict about validation (e.g., for istartswith). + field = serializers.CharField() + else: + # The general case: inferred from the corresponding field in CustomFieldInstance. + value_field_name = CustomFieldInstance.get_value_field_name( + custom_field.data_type, + ) + model_field = CustomFieldInstance._meta.get_field(value_field_name) + field_name = model_field.deconstruct()[0] + field_class, field_kwargs = self._model_serializer.build_standard_field( + field_name, + model_field, + ) + field = field_class(**field_kwargs) + field.allow_null = False + + # Need to set allow_blank manually because of the inconsistency in CustomFieldInstance validation. + # See https://github.com/paperless-ngx/paperless-ngx/issues/7361. + if isinstance(field, serializers.CharField): + field.allow_blank = True + + if op == "in": + # `in` takes a list of values. + field = serializers.ListField(child=field, allow_empty=False) + elif op == "range": + # `range` takes a list of values, i.e., [start, end]. + field = serializers.ListField( + child=field, + min_length=2, + max_length=2, + ) + + return field + + @handle_validation_prefix + def _validate_atom_value(self, custom_field, op, value): + """Check if `value` is valid for the custom field and `op`. Returns the validated value.""" + serializer_field = self._get_serializer_field(custom_field, op) + return serializer_field.run_validation(value) + + def _parse_atom_doc_link_contains(self, custom_field, value) -> Q: + """ + Handles document link `contains` in a way that is supported by all DB backends. + """ + + # If the value is an empty set, + # this is trivially true for any document with not null document links. + if not value: + return Q( + custom_fields__field=custom_field, + custom_fields__value_document_ids__isnull=False, + ) + + # First we look up reverse links from the requested documents. + links = CustomFieldInstance.objects.filter( + document_id__in=value, + field__data_type=CustomField.FieldDataType.DOCUMENTLINK, + ) + + # Check if any of the requested IDs are missing. + missing_ids = set(value) - set(link.document_id for link in links) + if missing_ids: + # The result should be an empty set in this case. + return Q(id__in=[]) + + # Take the intersection of the reverse links - this should be what we are looking for. + document_ids_we_want = functools.reduce( + operator.and_, + (set(link.value_document_ids) for link in links), + ) + + return Q(id__in=document_ids_we_want) + + @contextmanager + def _track_query_depth(self): + # guard against queries that are too deeply nested + self._current_depth += 1 + if self._current_depth > self._max_query_depth: + raise serializers.ValidationError([_("Maximum nesting depth exceeded.")]) + try: + yield + finally: + self._current_depth -= 1 + + +@extend_schema_field(serializers.CharField) +class CustomFieldQueryFilter(Filter): + def __init__(self, validation_prefix): + """ + A filter that filters documents based on custom field name and value. + + Args: + validation_prefix: Used to generate the ValidationError message. + """ + super().__init__() + self._validation_prefix = validation_prefix + + def filter(self, qs, value): + if not value: + return qs + + parser = CustomFieldQueryParser( + self._validation_prefix, + max_query_depth=CUSTOM_FIELD_QUERY_MAX_DEPTH, + max_atom_count=CUSTOM_FIELD_QUERY_MAX_ATOMS, + ) + q, annotations = parser.parse(value) + + return qs.annotate(**annotations).filter(q) + + +class DocumentFilterSet(FilterSet): + is_tagged = BooleanFilter( + label="Is tagged", + field_name="tags", + lookup_expr="isnull", + exclude=True, + ) + + tags__id__all = ObjectFilter(field_name="tags") + + tags__id__none = ObjectFilter(field_name="tags", exclude=True) + + tags__id__in = ObjectFilter(field_name="tags", in_list=True) + + correspondent__id__none = ObjectFilter(field_name="correspondent", exclude=True) + + document_type__id__none = ObjectFilter(field_name="document_type", exclude=True) + + storage_path__id__none = ObjectFilter(field_name="storage_path", exclude=True) + + is_in_inbox = InboxFilter() + + title_content = TitleContentFilter() + + owner__id__none = ObjectFilter(field_name="owner", exclude=True) + + custom_fields__icontains = CustomFieldsFilter() + + custom_fields__id__all = ObjectFilter(field_name="custom_fields__field") + + custom_fields__id__none = ObjectFilter( + field_name="custom_fields__field", + exclude=True, + ) + + custom_fields__id__in = ObjectFilter( + field_name="custom_fields__field", + in_list=True, + ) + + has_custom_fields = BooleanFilter( + label="Has custom field", + field_name="custom_fields", + lookup_expr="isnull", + exclude=True, + ) + + custom_field_query = CustomFieldQueryFilter("custom_field_query") + + shared_by__id = SharedByUser() + + mime_type = MimeTypeFilter() + + class Meta: + model = Document + fields = { + "id": ID_KWARGS, + "title": CHAR_KWARGS, + "content": CHAR_KWARGS, + "archive_serial_number": INT_KWARGS, + "created": DATE_KWARGS, + "added": DATE_KWARGS, + "modified": DATE_KWARGS, + "original_filename": CHAR_KWARGS, + "checksum": CHAR_KWARGS, + "correspondent": ["isnull"], + "correspondent__id": ID_KWARGS, + "correspondent__name": CHAR_KWARGS, + "tags__id": ID_KWARGS, + "tags__name": CHAR_KWARGS, + "document_type": ["isnull"], + "document_type__id": ID_KWARGS, + "document_type__name": CHAR_KWARGS, + "storage_path": ["isnull"], + "storage_path__id": ID_KWARGS, + "storage_path__name": CHAR_KWARGS, + "owner": ["isnull"], + "owner__id": ID_KWARGS, + "custom_fields": ["icontains"], + } + + +class ShareLinkFilterSet(FilterSet): + class Meta: + model = ShareLink + fields = { + "created": DATE_KWARGS, + "expiration": DATE_KWARGS, + } + + +class PaperlessTaskFilterSet(FilterSet): + acknowledged = BooleanFilter( + label="Acknowledged", + field_name="acknowledged", + ) + + class Meta: + model = PaperlessTask + fields = { + "type": ["exact"], + "task_name": ["exact"], + "status": ["exact"], + } + + +class ObjectOwnedOrGrantedPermissionsFilter(ObjectPermissionsFilter): + """ + A filter backend that limits results to those where the requesting user + has read object level permissions, owns the objects, or objects without + an owner (for backwards compat) + """ + + def filter_queryset(self, request, queryset, view): + objects_with_perms = super().filter_queryset(request, queryset, view) + objects_owned = queryset.filter(owner=request.user) + objects_unowned = queryset.filter(owner__isnull=True) + return objects_with_perms | objects_owned | objects_unowned + + +class ObjectOwnedPermissionsFilter(ObjectPermissionsFilter): + """ + A filter backend that limits results to those where the requesting user + owns the objects or objects without an owner (for backwards compat) + """ + + def filter_queryset(self, request, queryset, view): + if request.user.is_superuser: + return queryset + objects_owned = queryset.filter(owner=request.user) + objects_unowned = queryset.filter(owner__isnull=True) + return objects_owned | objects_unowned + + +class DocumentsOrderingFilter(OrderingFilter): + field_name = "ordering" + prefix = "custom_field_" + + def filter_queryset(self, request, queryset, view): + param = request.query_params.get("ordering") + if param and self.prefix in param: + custom_field_id = int(param.split(self.prefix)[1]) + try: + field = CustomField.objects.get(pk=custom_field_id) + except CustomField.DoesNotExist: + raise serializers.ValidationError( + {self.prefix + str(custom_field_id): [_("Custom field not found")]}, + ) + + annotation = None + match field.data_type: + case CustomField.FieldDataType.STRING: + annotation = Subquery( + CustomFieldInstance.objects.filter( + document_id=OuterRef("id"), + field_id=custom_field_id, + ).values("value_text")[:1], + ) + case CustomField.FieldDataType.INT: + annotation = Subquery( + CustomFieldInstance.objects.filter( + document_id=OuterRef("id"), + field_id=custom_field_id, + ).values("value_int")[:1], + ) + case CustomField.FieldDataType.FLOAT: + annotation = Subquery( + CustomFieldInstance.objects.filter( + document_id=OuterRef("id"), + field_id=custom_field_id, + ).values("value_float")[:1], + ) + case CustomField.FieldDataType.DATE: + annotation = Subquery( + CustomFieldInstance.objects.filter( + document_id=OuterRef("id"), + field_id=custom_field_id, + ).values("value_date")[:1], + ) + case CustomField.FieldDataType.MONETARY: + annotation = Subquery( + CustomFieldInstance.objects.filter( + document_id=OuterRef("id"), + field_id=custom_field_id, + ).values("value_monetary_amount")[:1], + ) + case CustomField.FieldDataType.SELECT: + # Select options are a little more complicated since the value is the id of the option, not + # the label. Additionally, to support sqlite we can't use StringAgg, so we need to create a + # case statement for each option, setting the value to the index of the option in a list + # sorted by label, and then summing the results to give a single value for the annotation + + select_options = sorted( + field.extra_data.get("select_options", []), + key=lambda x: x.get("label"), + ) + whens = [ + When( + custom_fields__field_id=custom_field_id, + custom_fields__value_select=option.get("id"), + then=Value(idx, output_field=IntegerField()), + ) + for idx, option in enumerate(select_options) + ] + whens.append( + When( + custom_fields__field_id=custom_field_id, + custom_fields__value_select__isnull=True, + then=Value( + len(select_options), + output_field=IntegerField(), + ), + ), + ) + annotation = Sum( + Case( + *whens, + default=Value(0), + output_field=IntegerField(), + ), + ) + case CustomField.FieldDataType.DOCUMENTLINK: + annotation = Subquery( + CustomFieldInstance.objects.filter( + document_id=OuterRef("id"), + field_id=custom_field_id, + ).values("value_document_ids")[:1], + ) + case CustomField.FieldDataType.URL: + annotation = Subquery( + CustomFieldInstance.objects.filter( + document_id=OuterRef("id"), + field_id=custom_field_id, + ).values("value_url")[:1], + ) + case CustomField.FieldDataType.BOOL: + annotation = Subquery( + CustomFieldInstance.objects.filter( + document_id=OuterRef("id"), + field_id=custom_field_id, + ).values("value_bool")[:1], + ) + + if not annotation: + # Only happens if a new data type is added and not handled here + raise ValueError("Invalid custom field data type") + + queryset = ( + queryset.annotate( + # We need to annotate the queryset with the custom field value + custom_field_value=annotation, + # We also need to annotate the queryset with a boolean for sorting whether the field exists + has_field=Exists( + CustomFieldInstance.objects.filter( + document_id=OuterRef("id"), + field_id=custom_field_id, + ), + ), + ) + .order_by( + "-has_field", + param.replace( + self.prefix + str(custom_field_id), + "custom_field_value", + ), + ) + .distinct() + ) + + return super().filter_queryset(request, queryset, view) diff --git a/src/paperless/views.py b/src/paperless/views.py index 93c26aced..d960225f2 100644 --- a/src/paperless/views.py +++ b/src/paperless/views.py @@ -89,17 +89,6 @@ from rest_framework.viewsets import ModelViewSet from rest_framework.viewsets import ReadOnlyModelViewSet from rest_framework.viewsets import ViewSet -from documents.filters import CorrespondentFilterSet -from documents.filters import CustomFieldFilterSet -from documents.filters import DocumentFilterSet -from documents.filters import DocumentsOrderingFilter -from documents.filters import DocumentTypeFilterSet -from documents.filters import ObjectOwnedOrGrantedPermissionsFilter -from documents.filters import ObjectOwnedPermissionsFilter -from documents.filters import PaperlessTaskFilterSet -from documents.filters import ShareLinkFilterSet -from documents.filters import StoragePathFilterSet -from documents.filters import TagFilterSet from documents.schema import generate_object_with_permissions_schema from documents.signals import document_updated from documents.templating.filepath import validate_filepath_template_and_render @@ -129,7 +118,18 @@ from paperless.data_models import ConsumableDocument from paperless.data_models import DocumentMetadataOverrides from paperless.data_models import DocumentSource from paperless.db import GnuPG +from paperless.filters import CorrespondentFilterSet +from paperless.filters import CustomFieldFilterSet +from paperless.filters import DocumentFilterSet +from paperless.filters import DocumentsOrderingFilter +from paperless.filters import DocumentTypeFilterSet from paperless.filters import GroupFilterSet +from paperless.filters import ObjectOwnedOrGrantedPermissionsFilter +from paperless.filters import ObjectOwnedPermissionsFilter +from paperless.filters import PaperlessTaskFilterSet +from paperless.filters import ShareLinkFilterSet +from paperless.filters import StoragePathFilterSet +from paperless.filters import TagFilterSet from paperless.filters import UserFilterSet from paperless.index import DelayedQuery from paperless.mail import send_email diff --git a/src/paperless_mail/views.py b/src/paperless_mail/views.py index 62a25c60c..202e3b347 100644 --- a/src/paperless_mail/views.py +++ b/src/paperless_mail/views.py @@ -17,7 +17,7 @@ from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response from rest_framework.viewsets import ModelViewSet -from documents.filters import ObjectOwnedOrGrantedPermissionsFilter +from paperless.filters import ObjectOwnedOrGrantedPermissionsFilter from paperless.permissions import PaperlessObjectPermissions from paperless.views import PassUserMixin from paperless.views import StandardPagination