mirror of
				https://github.com/paperless-ngx/paperless-ngx.git
				synced 2025-10-25 15:52:35 -04:00 
			
		
		
		
	tests for the classifier and fixes for edge cases with minimal data.
This commit is contained in:
		
							parent
							
								
									2a4fe4dceb
								
							
						
					
					
						commit
						30acfdd3f1
					
				| @ -6,7 +6,8 @@ import re | ||||
| 
 | ||||
| from sklearn.feature_extraction.text import CountVectorizer | ||||
| from sklearn.neural_network import MLPClassifier | ||||
| from sklearn.preprocessing import MultiLabelBinarizer | ||||
| from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer | ||||
| from sklearn.utils.multiclass import type_of_target | ||||
| 
 | ||||
| from documents.models import Document, MatchingModel | ||||
| from paperless import settings | ||||
| @ -27,7 +28,7 @@ def preprocess_content(content): | ||||
| 
 | ||||
| class DocumentClassifier(object): | ||||
| 
 | ||||
|     FORMAT_VERSION = 5 | ||||
|     FORMAT_VERSION = 6 | ||||
| 
 | ||||
|     def __init__(self): | ||||
|         # mtime of the model file on disk. used to prevent reloading when | ||||
| @ -54,6 +55,8 @@ class DocumentClassifier(object): | ||||
|                         "Cannor load classifier, incompatible versions.") | ||||
|                 else: | ||||
|                     if self.classifier_version > 0: | ||||
|                         # Don't be confused by this check. It's simply here | ||||
|                         # so that we wont log anything on initial reload. | ||||
|                         logger.info("Classifier updated on disk, " | ||||
|                                     "reloading classifier models") | ||||
|                     self.data_hash = pickle.load(f) | ||||
| @ -122,9 +125,14 @@ class DocumentClassifier(object): | ||||
|         labels_tags_unique = set([tag for tags in labels_tags for tag in tags]) | ||||
| 
 | ||||
|         num_tags = len(labels_tags_unique) | ||||
| 
 | ||||
|         # substract 1 since -1 (null) is also part of the classes. | ||||
|         num_correspondents = len(set(labels_correspondent)) - 1 | ||||
|         num_document_types = len(set(labels_document_type)) - 1 | ||||
| 
 | ||||
|         # union with {-1} accounts for cases where all documents have | ||||
|         # correspondents and types assigned, so -1 isnt part of labels_x, which | ||||
|         # it usually is. | ||||
|         num_correspondents = len(set(labels_correspondent) | {-1}) - 1 | ||||
|         num_document_types = len(set(labels_document_type) | {-1}) - 1 | ||||
| 
 | ||||
|         logging.getLogger(__name__).debug( | ||||
|             "{} documents, {} tag(s), {} correspondent(s), " | ||||
| @ -145,12 +153,23 @@ class DocumentClassifier(object): | ||||
|         ) | ||||
|         data_vectorized = self.data_vectorizer.fit_transform(data) | ||||
| 
 | ||||
|         self.tags_binarizer = MultiLabelBinarizer() | ||||
|         labels_tags_vectorized = self.tags_binarizer.fit_transform(labels_tags) | ||||
| 
 | ||||
|         # Step 3: train the classifiers | ||||
|         if num_tags > 0: | ||||
|             logging.getLogger(__name__).debug("Training tags classifier...") | ||||
| 
 | ||||
|             if num_tags == 1: | ||||
|                 # Special case where only one tag has auto: | ||||
|                 # Fallback to binary classification. | ||||
|                 labels_tags = [label[0] if len(label) == 1 else -1 | ||||
|                                for label in labels_tags] | ||||
|                 self.tags_binarizer = LabelBinarizer() | ||||
|                 labels_tags_vectorized = self.tags_binarizer.fit_transform( | ||||
|                     labels_tags).ravel() | ||||
|             else: | ||||
|                 self.tags_binarizer = MultiLabelBinarizer() | ||||
|                 labels_tags_vectorized = self.tags_binarizer.fit_transform( | ||||
|                     labels_tags) | ||||
| 
 | ||||
|             self.tags_classifier = MLPClassifier(tol=0.01) | ||||
|             self.tags_classifier.fit(data_vectorized, labels_tags_vectorized) | ||||
|         else: | ||||
| @ -222,6 +241,16 @@ class DocumentClassifier(object): | ||||
|             X = self.data_vectorizer.transform([preprocess_content(content)]) | ||||
|             y = self.tags_classifier.predict(X) | ||||
|             tags_ids = self.tags_binarizer.inverse_transform(y)[0] | ||||
|             return tags_ids | ||||
|             if type_of_target(y).startswith('multilabel'): | ||||
|                 # the usual case when there are multiple tags. | ||||
|                 return list(tags_ids) | ||||
|             elif type_of_target(y) == 'binary' and tags_ids != -1: | ||||
|                 # This is for when we have binary classification with only one | ||||
|                 # tag and the result is to assign this tag. | ||||
|                 return [tags_ids] | ||||
|             else: | ||||
|                 # Usually binary as well with -1 as the result, but we're | ||||
|                 # going to catch everything else here as well. | ||||
|                 return [] | ||||
|         else: | ||||
|             return [] | ||||
|  | ||||
| @ -1,8 +1,10 @@ | ||||
| import tempfile | ||||
| from time import sleep | ||||
| from unittest import mock | ||||
| 
 | ||||
| from django.test import TestCase, override_settings | ||||
| 
 | ||||
| from documents.classifier import DocumentClassifier | ||||
| from documents.classifier import DocumentClassifier, IncompatibleClassifierVersionError | ||||
| from documents.models import Correspondent, Document, Tag, DocumentType | ||||
| 
 | ||||
| 
 | ||||
| @ -15,10 +17,12 @@ class TestClassifier(TestCase): | ||||
|     def generate_test_data(self): | ||||
|         self.c1 = Correspondent.objects.create(name="c1", matching_algorithm=Correspondent.MATCH_AUTO) | ||||
|         self.c2 = Correspondent.objects.create(name="c2") | ||||
|         self.c3 = Correspondent.objects.create(name="c3", matching_algorithm=Correspondent.MATCH_AUTO) | ||||
|         self.t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12) | ||||
|         self.t2 = Tag.objects.create(name="t2", matching_algorithm=Tag.MATCH_ANY, pk=34, is_inbox_tag=True) | ||||
|         self.t3 = Tag.objects.create(name="t3", matching_algorithm=Tag.MATCH_AUTO, pk=45) | ||||
|         self.dt = DocumentType.objects.create(name="dt", matching_algorithm=DocumentType.MATCH_AUTO) | ||||
|         self.dt2 = DocumentType.objects.create(name="dt2", matching_algorithm=DocumentType.MATCH_AUTO) | ||||
| 
 | ||||
|         self.doc1 = Document.objects.create(title="doc1", content="this is a document from c1", correspondent=self.c1, checksum="A", document_type=self.dt) | ||||
|         self.doc2 = Document.objects.create(title="doc1", content="this is another document, but from c2", correspondent=self.c2, checksum="B") | ||||
| @ -59,8 +63,8 @@ class TestClassifier(TestCase): | ||||
|         self.classifier.train() | ||||
|         self.assertEqual(self.classifier.predict_correspondent(self.doc1.content), self.c1.pk) | ||||
|         self.assertEqual(self.classifier.predict_correspondent(self.doc2.content), None) | ||||
|         self.assertTupleEqual(self.classifier.predict_tags(self.doc1.content), (self.t1.pk,)) | ||||
|         self.assertTupleEqual(self.classifier.predict_tags(self.doc2.content), (self.t1.pk, self.t3.pk)) | ||||
|         self.assertListEqual(self.classifier.predict_tags(self.doc1.content), [self.t1.pk]) | ||||
|         self.assertListEqual(self.classifier.predict_tags(self.doc2.content), [self.t1.pk, self.t3.pk]) | ||||
|         self.assertEqual(self.classifier.predict_document_type(self.doc1.content), self.dt.pk) | ||||
|         self.assertEqual(self.classifier.predict_document_type(self.doc2.content), None) | ||||
| 
 | ||||
| @ -71,6 +75,42 @@ class TestClassifier(TestCase): | ||||
|         self.assertTrue(self.classifier.train()) | ||||
|         self.assertFalse(self.classifier.train()) | ||||
| 
 | ||||
|     def testVersionIncreased(self): | ||||
| 
 | ||||
|         self.generate_test_data() | ||||
|         self.assertTrue(self.classifier.train()) | ||||
|         self.assertFalse(self.classifier.train()) | ||||
| 
 | ||||
|         classifier2 = DocumentClassifier() | ||||
| 
 | ||||
|         current_ver = DocumentClassifier.FORMAT_VERSION | ||||
|         with mock.patch("documents.classifier.DocumentClassifier.FORMAT_VERSION", current_ver+1): | ||||
|             # assure that we won't load old classifiers. | ||||
|             self.assertRaises(IncompatibleClassifierVersionError, self.classifier.reload) | ||||
| 
 | ||||
|             self.classifier.save_classifier() | ||||
| 
 | ||||
|             # assure that we can load the classifier after saving it. | ||||
|             classifier2.reload() | ||||
| 
 | ||||
|     def testReload(self): | ||||
| 
 | ||||
|         self.generate_test_data() | ||||
|         self.assertTrue(self.classifier.train()) | ||||
|         self.classifier.save_classifier() | ||||
| 
 | ||||
|         classifier2 = DocumentClassifier() | ||||
|         classifier2.reload() | ||||
|         v1 = classifier2.classifier_version | ||||
| 
 | ||||
|         # change the classifier after some time. | ||||
|         sleep(1) | ||||
|         self.classifier.save_classifier() | ||||
| 
 | ||||
|         classifier2.reload() | ||||
|         v2 = classifier2.classifier_version | ||||
|         self.assertNotEqual(v1, v2) | ||||
| 
 | ||||
|     @override_settings(DATA_DIR=tempfile.mkdtemp()) | ||||
|     def testSaveClassifier(self): | ||||
| 
 | ||||
| @ -83,3 +123,112 @@ class TestClassifier(TestCase): | ||||
|         new_classifier = DocumentClassifier() | ||||
|         new_classifier.reload() | ||||
|         self.assertFalse(new_classifier.train()) | ||||
| 
 | ||||
|     def test_one_correspondent_predict(self): | ||||
|         c1 = Correspondent.objects.create(name="c1", matching_algorithm=Correspondent.MATCH_AUTO) | ||||
|         doc1 = Document.objects.create(title="doc1", content="this is a document from c1", correspondent=c1, checksum="A") | ||||
| 
 | ||||
|         self.classifier.train() | ||||
|         self.assertEqual(self.classifier.predict_correspondent(doc1.content), c1.pk) | ||||
| 
 | ||||
|     def test_one_correspondent_predict_manydocs(self): | ||||
|         c1 = Correspondent.objects.create(name="c1", matching_algorithm=Correspondent.MATCH_AUTO) | ||||
|         doc1 = Document.objects.create(title="doc1", content="this is a document from c1", correspondent=c1, checksum="A") | ||||
|         doc2 = Document.objects.create(title="doc2", content="this is a document from noone", checksum="B") | ||||
| 
 | ||||
|         self.classifier.train() | ||||
|         self.assertEqual(self.classifier.predict_correspondent(doc1.content), c1.pk) | ||||
|         self.assertIsNone(self.classifier.predict_correspondent(doc2.content)) | ||||
| 
 | ||||
|     def test_one_type_predict(self): | ||||
|         dt = DocumentType.objects.create(name="dt", matching_algorithm=DocumentType.MATCH_AUTO) | ||||
| 
 | ||||
|         doc1 = Document.objects.create(title="doc1", content="this is a document from c1", | ||||
|                                             checksum="A", document_type=dt) | ||||
| 
 | ||||
|         self.classifier.train() | ||||
|         self.assertEqual(self.classifier.predict_document_type(doc1.content), dt.pk) | ||||
| 
 | ||||
|     def test_one_type_predict_manydocs(self): | ||||
|         dt = DocumentType.objects.create(name="dt", matching_algorithm=DocumentType.MATCH_AUTO) | ||||
| 
 | ||||
|         doc1 = Document.objects.create(title="doc1", content="this is a document from c1", | ||||
|                                             checksum="A", document_type=dt) | ||||
| 
 | ||||
|         doc2 = Document.objects.create(title="doc1", content="this is a document from c2", | ||||
|                                             checksum="B") | ||||
| 
 | ||||
|         self.classifier.train() | ||||
|         self.assertEqual(self.classifier.predict_document_type(doc1.content), dt.pk) | ||||
|         self.assertIsNone(self.classifier.predict_document_type(doc2.content)) | ||||
| 
 | ||||
|     def test_one_tag_predict(self): | ||||
|         t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12) | ||||
| 
 | ||||
|         doc1 = Document.objects.create(title="doc1", content="this is a document from c1", checksum="A") | ||||
| 
 | ||||
|         doc1.tags.add(t1) | ||||
|         self.classifier.train() | ||||
|         self.assertListEqual(self.classifier.predict_tags(doc1.content), [t1.pk]) | ||||
| 
 | ||||
|     def test_one_tag_predict_unassigned(self): | ||||
|         t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12) | ||||
| 
 | ||||
|         doc1 = Document.objects.create(title="doc1", content="this is a document from c1", checksum="A") | ||||
| 
 | ||||
|         self.classifier.train() | ||||
|         self.assertListEqual(self.classifier.predict_tags(doc1.content), []) | ||||
| 
 | ||||
|     def test_two_tags_predict_singledoc(self): | ||||
|         t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12) | ||||
|         t2 = Tag.objects.create(name="t2", matching_algorithm=Tag.MATCH_AUTO, pk=121) | ||||
| 
 | ||||
|         doc4 = Document.objects.create(title="doc1", content="this is a document from c4", checksum="D") | ||||
| 
 | ||||
|         doc4.tags.add(t1) | ||||
|         doc4.tags.add(t2) | ||||
|         self.classifier.train() | ||||
|         self.assertListEqual(self.classifier.predict_tags(doc4.content), [t1.pk, t2.pk]) | ||||
| 
 | ||||
|     def test_two_tags_predict(self): | ||||
|         t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12) | ||||
|         t2 = Tag.objects.create(name="t2", matching_algorithm=Tag.MATCH_AUTO, pk=121) | ||||
| 
 | ||||
|         doc1 = Document.objects.create(title="doc1", content="this is a document from c1", checksum="A") | ||||
|         doc2 = Document.objects.create(title="doc1", content="this is a document from c2", checksum="B") | ||||
|         doc3 = Document.objects.create(title="doc1", content="this is a document from c3", checksum="C") | ||||
|         doc4 = Document.objects.create(title="doc1", content="this is a document from c4", checksum="D") | ||||
| 
 | ||||
|         doc1.tags.add(t1) | ||||
|         doc2.tags.add(t2) | ||||
| 
 | ||||
|         doc4.tags.add(t1) | ||||
|         doc4.tags.add(t2) | ||||
|         self.classifier.train() | ||||
|         self.assertListEqual(self.classifier.predict_tags(doc1.content), [t1.pk]) | ||||
|         self.assertListEqual(self.classifier.predict_tags(doc2.content), [t2.pk]) | ||||
|         self.assertListEqual(self.classifier.predict_tags(doc3.content), []) | ||||
|         self.assertListEqual(self.classifier.predict_tags(doc4.content), [t1.pk, t2.pk]) | ||||
| 
 | ||||
|     def test_one_tag_predict_multi(self): | ||||
|         t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12) | ||||
| 
 | ||||
|         doc1 = Document.objects.create(title="doc1", content="this is a document from c1", checksum="A") | ||||
|         doc2 = Document.objects.create(title="doc2", content="this is a document from c2", checksum="B") | ||||
| 
 | ||||
|         doc1.tags.add(t1) | ||||
|         doc2.tags.add(t1) | ||||
|         self.classifier.train() | ||||
|         self.assertListEqual(self.classifier.predict_tags(doc1.content), [t1.pk]) | ||||
|         self.assertListEqual(self.classifier.predict_tags(doc2.content), [t1.pk]) | ||||
| 
 | ||||
|     def test_one_tag_predict_multi_2(self): | ||||
|         t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12) | ||||
| 
 | ||||
|         doc1 = Document.objects.create(title="doc1", content="this is a document from c1", checksum="A") | ||||
|         doc2 = Document.objects.create(title="doc2", content="this is a document from c2", checksum="B") | ||||
| 
 | ||||
|         doc1.tags.add(t1) | ||||
|         self.classifier.train() | ||||
|         self.assertListEqual(self.classifier.predict_tags(doc1.content), [t1.pk]) | ||||
|         self.assertListEqual(self.classifier.predict_tags(doc2.content), []) | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user