mirror of
				https://github.com/paperless-ngx/paperless-ngx.git
				synced 2025-11-03 19:17:13 -05:00 
			
		
		
		
	sorting for full text queries
This commit is contained in:
		
							parent
							
								
									814d90745b
								
							
						
					
					
						commit
						8ee2e8b23d
					
				@ -7,7 +7,7 @@ from dateutil.parser import isoparse
 | 
				
			|||||||
from django.conf import settings
 | 
					from django.conf import settings
 | 
				
			||||||
from whoosh import highlight, classify, query
 | 
					from whoosh import highlight, classify, query
 | 
				
			||||||
from whoosh.fields import Schema, TEXT, NUMERIC, KEYWORD, DATETIME, BOOLEAN
 | 
					from whoosh.fields import Schema, TEXT, NUMERIC, KEYWORD, DATETIME, BOOLEAN
 | 
				
			||||||
from whoosh.highlight import Formatter, get_text, HtmlFormatter
 | 
					from whoosh.highlight import HtmlFormatter
 | 
				
			||||||
from whoosh.index import create_in, exists_in, open_dir
 | 
					from whoosh.index import create_in, exists_in, open_dir
 | 
				
			||||||
from whoosh.qparser import MultifieldParser
 | 
					from whoosh.qparser import MultifieldParser
 | 
				
			||||||
from whoosh.qparser.dateparse import DateParserPlugin
 | 
					from whoosh.qparser.dateparse import DateParserPlugin
 | 
				
			||||||
@ -147,12 +147,10 @@ def remove_document_from_index(document):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
class DelayedQuery:
 | 
					class DelayedQuery:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    def _get_query(self):
 | 
				
			||||||
    def _query(self):
 | 
					 | 
				
			||||||
        raise NotImplementedError()
 | 
					        raise NotImplementedError()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    def _get_query_filter(self):
 | 
				
			||||||
    def _query_filter(self):
 | 
					 | 
				
			||||||
        criterias = []
 | 
					        criterias = []
 | 
				
			||||||
        for k, v in self.query_params.items():
 | 
					        for k, v in self.query_params.items():
 | 
				
			||||||
            if k == 'correspondent__id':
 | 
					            if k == 'correspondent__id':
 | 
				
			||||||
@ -185,16 +183,33 @@ class DelayedQuery:
 | 
				
			|||||||
        else:
 | 
					        else:
 | 
				
			||||||
            return None
 | 
					            return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    def _get_query_sortedby(self):
 | 
				
			||||||
    def _query_sortedby(self):
 | 
					        if not 'ordering' in self.query_params:
 | 
				
			||||||
        # if not 'ordering' in self.query_params:
 | 
					            return None, False
 | 
				
			||||||
        return None, False
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # o: str = self.query_params['ordering']
 | 
					        field: str = self.query_params['ordering']
 | 
				
			||||||
        # if o.startswith('-'):
 | 
					
 | 
				
			||||||
        #     return o[1:], True
 | 
					        sort_fields_map = {
 | 
				
			||||||
        # else:
 | 
					            "created": "created",
 | 
				
			||||||
        #     return o, False
 | 
					            "modified": "modified",
 | 
				
			||||||
 | 
					            "added": "added",
 | 
				
			||||||
 | 
					            "title": "title",
 | 
				
			||||||
 | 
					            "correspondent__name": "correspondent",
 | 
				
			||||||
 | 
					            "document_type__name": "type",
 | 
				
			||||||
 | 
					            "archive_serial_number": "asn",
 | 
				
			||||||
 | 
					            "score": None,
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if field.startswith('-'):
 | 
				
			||||||
 | 
					            field = field[1:]
 | 
				
			||||||
 | 
					            reverse = True
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            reverse = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if field not in sort_fields_map:
 | 
				
			||||||
 | 
					            return None, False
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            return sort_fields_map[field], reverse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self, searcher: Searcher, query_params, page_size):
 | 
					    def __init__(self, searcher: Searcher, query_params, page_size):
 | 
				
			||||||
        self.searcher = searcher
 | 
					        self.searcher = searcher
 | 
				
			||||||
@ -211,13 +226,13 @@ class DelayedQuery:
 | 
				
			|||||||
        if item.start in self.saved_results:
 | 
					        if item.start in self.saved_results:
 | 
				
			||||||
            return self.saved_results[item.start]
 | 
					            return self.saved_results[item.start]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        q, mask = self._query
 | 
					        q, mask = self._get_query()
 | 
				
			||||||
        sortedby, reverse = self._query_sortedby
 | 
					        sortedby, reverse = self._get_query_sortedby()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        page: ResultsPage = self.searcher.search_page(
 | 
					        page: ResultsPage = self.searcher.search_page(
 | 
				
			||||||
            q,
 | 
					            q,
 | 
				
			||||||
            mask=mask,
 | 
					            mask=mask,
 | 
				
			||||||
            filter=self._query_filter,
 | 
					            filter=self._get_query_filter(),
 | 
				
			||||||
            pagenum=math.floor(item.start / self.page_size) + 1,
 | 
					            pagenum=math.floor(item.start / self.page_size) + 1,
 | 
				
			||||||
            pagelen=self.page_size,
 | 
					            pagelen=self.page_size,
 | 
				
			||||||
            sortedby=sortedby,
 | 
					            sortedby=sortedby,
 | 
				
			||||||
@ -227,7 +242,9 @@ class DelayedQuery:
 | 
				
			|||||||
            surround=50)
 | 
					            surround=50)
 | 
				
			||||||
        page.results.formatter = HtmlFormatter(tagname="span", between=" ... ")
 | 
					        page.results.formatter = HtmlFormatter(tagname="span", between=" ... ")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if not self.first_score and len(page.results) > 0:
 | 
					        if (not self.first_score and
 | 
				
			||||||
 | 
					                len(page.results) > 0 and
 | 
				
			||||||
 | 
					                sortedby is None):
 | 
				
			||||||
            self.first_score = page.results[0].score
 | 
					            self.first_score = page.results[0].score
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if self.first_score:
 | 
					        if self.first_score:
 | 
				
			||||||
@ -243,8 +260,7 @@ class DelayedQuery:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
class DelayedFullTextQuery(DelayedQuery):
 | 
					class DelayedFullTextQuery(DelayedQuery):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    def _get_query(self):
 | 
				
			||||||
    def _query(self):
 | 
					 | 
				
			||||||
        q_str = self.query_params['query']
 | 
					        q_str = self.query_params['query']
 | 
				
			||||||
        qp = MultifieldParser(
 | 
					        qp = MultifieldParser(
 | 
				
			||||||
            ["content", "title", "correspondent", "tag", "type"],
 | 
					            ["content", "title", "correspondent", "tag", "type"],
 | 
				
			||||||
@ -261,8 +277,7 @@ class DelayedFullTextQuery(DelayedQuery):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
class DelayedMoreLikeThisQuery(DelayedQuery):
 | 
					class DelayedMoreLikeThisQuery(DelayedQuery):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    def _get_query(self):
 | 
				
			||||||
    def _query(self):
 | 
					 | 
				
			||||||
        more_like_doc_id = int(self.query_params['more_like_id'])
 | 
					        more_like_doc_id = int(self.query_params['more_like_id'])
 | 
				
			||||||
        content = Document.objects.get(id=more_like_doc_id).content
 | 
					        content = Document.objects.get(id=more_like_doc_id).content
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -471,6 +471,31 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
 | 
				
			|||||||
        self.assertNotIn(d5.id, search_query("&added__date__lt=" + datetime.datetime(2020, 1, 2).strftime("%Y-%m-%d")))
 | 
					        self.assertNotIn(d5.id, search_query("&added__date__lt=" + datetime.datetime(2020, 1, 2).strftime("%Y-%m-%d")))
 | 
				
			||||||
        self.assertIn(d5.id, search_query("&added__date__gt=" + datetime.datetime(2020, 1, 2).strftime("%Y-%m-%d")))
 | 
					        self.assertIn(d5.id, search_query("&added__date__gt=" + datetime.datetime(2020, 1, 2).strftime("%Y-%m-%d")))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_search_sorting(self):
 | 
				
			||||||
 | 
					        c1 = Correspondent.objects.create(name="corres Ax")
 | 
				
			||||||
 | 
					        c2 = Correspondent.objects.create(name="corres Cx")
 | 
				
			||||||
 | 
					        c3 = Correspondent.objects.create(name="corres Bx")
 | 
				
			||||||
 | 
					        d1 = Document.objects.create(checksum="1", correspondent=c1, content="test", archive_serial_number=2, title="3")
 | 
				
			||||||
 | 
					        d2 = Document.objects.create(checksum="2", correspondent=c2, content="test", archive_serial_number=3, title="2")
 | 
				
			||||||
 | 
					        d3 = Document.objects.create(checksum="3", correspondent=c3, content="test", archive_serial_number=1, title="1")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        with AsyncWriter(index.open_index()) as writer:
 | 
				
			||||||
 | 
					            for doc in Document.objects.all():
 | 
				
			||||||
 | 
					                index.update_document(writer, doc)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        def search_query(q):
 | 
				
			||||||
 | 
					            r = self.client.get("/api/documents/?query=test" + q)
 | 
				
			||||||
 | 
					            self.assertEqual(r.status_code, 200)
 | 
				
			||||||
 | 
					            return [hit['id'] for hit in r.data['results']]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.assertListEqual(search_query("&ordering=archive_serial_number"), [d3.id, d1.id, d2.id])
 | 
				
			||||||
 | 
					        self.assertListEqual(search_query("&ordering=-archive_serial_number"), [d2.id, d1.id, d3.id])
 | 
				
			||||||
 | 
					        self.assertListEqual(search_query("&ordering=title"), [d3.id, d2.id, d1.id])
 | 
				
			||||||
 | 
					        self.assertListEqual(search_query("&ordering=-title"), [d1.id, d2.id, d3.id])
 | 
				
			||||||
 | 
					        self.assertListEqual(search_query("&ordering=correspondent__name"), [d1.id, d3.id, d2.id])
 | 
				
			||||||
 | 
					        self.assertListEqual(search_query("&ordering=-correspondent__name"), [d2.id, d3.id, d1.id])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_statistics(self):
 | 
					    def test_statistics(self):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        doc1 = Document.objects.create(title="none1", checksum="A")
 | 
					        doc1 = Document.objects.create(title="none1", checksum="A")
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user