mirror of
				https://github.com/paperless-ngx/paperless-ngx.git
				synced 2025-11-03 19:17:13 -05:00 
			
		
		
		
	The classifier works with ids now, not names. Minor changes.
This commit is contained in:
		
							parent
							
								
									d2534a73e5
								
							
						
					
					
						commit
						d46ee11143
					
				@ -239,9 +239,9 @@ def run_document_classifier_on_selected(modeladmin, request, queryset):
 | 
				
			|||||||
    n = queryset.count()
 | 
					    n = queryset.count()
 | 
				
			||||||
    if n:
 | 
					    if n:
 | 
				
			||||||
        for obj in queryset:
 | 
					        for obj in queryset:
 | 
				
			||||||
            clf.classify_document(obj, classify_correspondent=True, classify_tags=True, classify_type=True, replace_tags=True)
 | 
					            clf.classify_document(obj, classify_correspondent=True, classify_tags=True, classify_document_type=True, replace_tags=True)
 | 
				
			||||||
            modeladmin.log_change(request, obj, str(obj))
 | 
					            modeladmin.log_change(request, obj, str(obj))
 | 
				
			||||||
        modeladmin.message_user(request, "Successfully applied tags, correspondent and type to %(count)d %(items)s." % {
 | 
					        modeladmin.message_user(request, "Successfully applied tags, correspondent and document type to %(count)d %(items)s." % {
 | 
				
			||||||
            "count": n, "items": model_ngettext(modeladmin.opts, n)
 | 
					            "count": n, "items": model_ngettext(modeladmin.opts, n)
 | 
				
			||||||
        }, messages.SUCCESS)
 | 
					        }, messages.SUCCESS)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -12,6 +12,7 @@ class DocumentsConfig(AppConfig):
 | 
				
			|||||||
        from .signals import document_consumption_finished
 | 
					        from .signals import document_consumption_finished
 | 
				
			||||||
        from .signals.handlers import (
 | 
					        from .signals.handlers import (
 | 
				
			||||||
            classify_document,
 | 
					            classify_document,
 | 
				
			||||||
 | 
					            add_inbox_tags,
 | 
				
			||||||
            run_pre_consume_script,
 | 
					            run_pre_consume_script,
 | 
				
			||||||
            run_post_consume_script,
 | 
					            run_post_consume_script,
 | 
				
			||||||
            cleanup_document_deletion,
 | 
					            cleanup_document_deletion,
 | 
				
			||||||
@ -21,6 +22,7 @@ class DocumentsConfig(AppConfig):
 | 
				
			|||||||
        document_consumption_started.connect(run_pre_consume_script)
 | 
					        document_consumption_started.connect(run_pre_consume_script)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        document_consumption_finished.connect(classify_document)
 | 
					        document_consumption_finished.connect(classify_document)
 | 
				
			||||||
 | 
					        document_consumption_finished.connect(add_inbox_tags)
 | 
				
			||||||
        document_consumption_finished.connect(set_log_entry)
 | 
					        document_consumption_finished.connect(set_log_entry)
 | 
				
			||||||
        document_consumption_finished.connect(run_post_consume_script)
 | 
					        document_consumption_finished.connect(run_post_consume_script)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -8,7 +8,6 @@ from documents.models import Correspondent, DocumentType, Tag, Document
 | 
				
			|||||||
from paperless import settings
 | 
					from paperless import settings
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from sklearn.feature_extraction.text import CountVectorizer
 | 
					from sklearn.feature_extraction.text import CountVectorizer
 | 
				
			||||||
from sklearn.multiclass import OneVsRestClassifier
 | 
					 | 
				
			||||||
from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
 | 
					from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -30,11 +29,11 @@ class DocumentClassifier(object):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    tags_binarizer = None
 | 
					    tags_binarizer = None
 | 
				
			||||||
    correspondent_binarizer = None
 | 
					    correspondent_binarizer = None
 | 
				
			||||||
    type_binarizer = None
 | 
					    document_type_binarizer = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    tags_classifier = None
 | 
					    tags_classifier = None
 | 
				
			||||||
    correspondent_classifier = None
 | 
					    correspondent_classifier = None
 | 
				
			||||||
    type_classifier = None
 | 
					    document_type_classifier = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    def load_classifier():
 | 
					    def load_classifier():
 | 
				
			||||||
@ -49,11 +48,11 @@ class DocumentClassifier(object):
 | 
				
			|||||||
                self.data_vectorizer = pickle.load(f)
 | 
					                self.data_vectorizer = pickle.load(f)
 | 
				
			||||||
                self.tags_binarizer = pickle.load(f)
 | 
					                self.tags_binarizer = pickle.load(f)
 | 
				
			||||||
                self.correspondent_binarizer = pickle.load(f)
 | 
					                self.correspondent_binarizer = pickle.load(f)
 | 
				
			||||||
                self.type_binarizer = pickle.load(f)
 | 
					                self.document_type_binarizer = pickle.load(f)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                self.tags_classifier = pickle.load(f)
 | 
					                self.tags_classifier = pickle.load(f)
 | 
				
			||||||
                self.correspondent_classifier = pickle.load(f)
 | 
					                self.correspondent_classifier = pickle.load(f)
 | 
				
			||||||
                self.type_classifier = pickle.load(f)
 | 
					                self.document_type_classifier = pickle.load(f)
 | 
				
			||||||
            self.classifier_version = os.path.getmtime(settings.MODEL_FILE)
 | 
					            self.classifier_version = os.path.getmtime(settings.MODEL_FILE)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def save_classifier(self):
 | 
					    def save_classifier(self):
 | 
				
			||||||
@ -62,29 +61,29 @@ class DocumentClassifier(object):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
            pickle.dump(self.tags_binarizer, f)
 | 
					            pickle.dump(self.tags_binarizer, f)
 | 
				
			||||||
            pickle.dump(self.correspondent_binarizer, f)
 | 
					            pickle.dump(self.correspondent_binarizer, f)
 | 
				
			||||||
            pickle.dump(self.type_binarizer, f)
 | 
					            pickle.dump(self.document_type_binarizer, f)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            pickle.dump(self.tags_classifier, f)
 | 
					            pickle.dump(self.tags_classifier, f)
 | 
				
			||||||
            pickle.dump(self.correspondent_classifier, f)
 | 
					            pickle.dump(self.correspondent_classifier, f)
 | 
				
			||||||
            pickle.dump(self.type_classifier, f)
 | 
					            pickle.dump(self.document_type_classifier, f)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def train(self):
 | 
					    def train(self):
 | 
				
			||||||
        data = list()
 | 
					        data = list()
 | 
				
			||||||
        labels_tags = list()
 | 
					        labels_tags = list()
 | 
				
			||||||
        labels_correspondent = list()
 | 
					        labels_correspondent = list()
 | 
				
			||||||
        labels_type = list()
 | 
					        labels_document_type = list()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Step 1: Extract and preprocess training data from the database.
 | 
					        # Step 1: Extract and preprocess training data from the database.
 | 
				
			||||||
        logging.getLogger(__name__).info("Gathering data from database...")
 | 
					        logging.getLogger(__name__).info("Gathering data from database...")
 | 
				
			||||||
        for doc in Document.objects.exclude(tags__is_inbox_tag=True):
 | 
					        for doc in Document.objects.exclude(tags__is_inbox_tag=True):
 | 
				
			||||||
            data.append(preprocess_content(doc.content))
 | 
					            data.append(preprocess_content(doc.content))
 | 
				
			||||||
            labels_type.append(doc.document_type.name if doc.document_type is not None and doc.document_type.automatic_classification else "-")
 | 
					            labels_document_type.append(doc.document_type.id if doc.document_type is not None and doc.document_type.automatic_classification else -1)
 | 
				
			||||||
            labels_correspondent.append(doc.correspondent.name if doc.correspondent is not None and doc.correspondent.automatic_classification else "-")
 | 
					            labels_correspondent.append(doc.correspondent.id if doc.correspondent is not None and doc.correspondent.automatic_classification else -1)
 | 
				
			||||||
            tags = [tag.name for tag in doc.tags.filter(automatic_classification=True)]
 | 
					            tags = [tag.id for tag in doc.tags.filter(automatic_classification=True)]
 | 
				
			||||||
            labels_tags.append(tags)
 | 
					            labels_tags.append(tags)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        labels_tags_unique = set([tag for tags in labels_tags for tag in tags])
 | 
					        labels_tags_unique = set([tag for tags in labels_tags for tag in tags])
 | 
				
			||||||
        logging.getLogger(__name__).info("{} documents, {} tag(s) {}, {} correspondent(s) {}, {} type(s) {}.".format(len(data), len(labels_tags_unique), labels_tags_unique, len(set(labels_correspondent)), set(labels_correspondent), len(set(labels_type)), set(labels_type)))
 | 
					        logging.getLogger(__name__).info("{} documents, {} tag(s), {} correspondent(s), {} document type(s).".format(len(data), len(labels_tags_unique), len(set(labels_correspondent)), len(set(labels_document_type))))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Step 2: vectorize data
 | 
					        # Step 2: vectorize data
 | 
				
			||||||
        logging.getLogger(__name__).info("Vectorizing data...")
 | 
					        logging.getLogger(__name__).info("Vectorizing data...")
 | 
				
			||||||
@ -97,8 +96,8 @@ class DocumentClassifier(object):
 | 
				
			|||||||
        self.correspondent_binarizer = LabelBinarizer()
 | 
					        self.correspondent_binarizer = LabelBinarizer()
 | 
				
			||||||
        labels_correspondent_vectorized = self.correspondent_binarizer.fit_transform(labels_correspondent)
 | 
					        labels_correspondent_vectorized = self.correspondent_binarizer.fit_transform(labels_correspondent)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.type_binarizer = LabelBinarizer()
 | 
					        self.document_type_binarizer = LabelBinarizer()
 | 
				
			||||||
        labels_type_vectorized = self.type_binarizer.fit_transform(labels_type)
 | 
					        labels_document_type_vectorized = self.document_type_binarizer.fit_transform(labels_document_type)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Step 3: train the classifiers
 | 
					        # Step 3: train the classifiers
 | 
				
			||||||
        if len(self.tags_binarizer.classes_) > 0:
 | 
					        if len(self.tags_binarizer.classes_) > 0:
 | 
				
			||||||
@ -117,39 +116,52 @@ class DocumentClassifier(object):
 | 
				
			|||||||
            self.correspondent_classifier = None
 | 
					            self.correspondent_classifier = None
 | 
				
			||||||
            logging.getLogger(__name__).info("There are no correspondents. Not training correspondent classifier.")
 | 
					            logging.getLogger(__name__).info("There are no correspondents. Not training correspondent classifier.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if len(self.type_binarizer.classes_) > 0:
 | 
					        if len(self.document_type_binarizer.classes_) > 0:
 | 
				
			||||||
            logging.getLogger(__name__).info("Training document type classifier...")
 | 
					            logging.getLogger(__name__).info("Training document type classifier...")
 | 
				
			||||||
            self.type_classifier = MLPClassifier(verbose=True)
 | 
					            self.document_type_classifier = MLPClassifier(verbose=True)
 | 
				
			||||||
            self.type_classifier.fit(data_vectorized, labels_type_vectorized)
 | 
					            self.document_type_classifier.fit(data_vectorized, labels_document_type_vectorized)
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            self.type_classifier = None
 | 
					            self.document_type_classifier = None
 | 
				
			||||||
            logging.getLogger(__name__).info("There are no document types. Not training document type classifier.")
 | 
					            logging.getLogger(__name__).info("There are no document types. Not training document type classifier.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def classify_document(self, document, classify_correspondent=False, classify_type=False, classify_tags=False, replace_tags=False):
 | 
					    def classify_document(self, document, classify_correspondent=False, classify_document_type=False, classify_tags=False, replace_tags=False):
 | 
				
			||||||
        X = self.data_vectorizer.transform([preprocess_content(document.content)])
 | 
					        X = self.data_vectorizer.transform([preprocess_content(document.content)])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        update_fields=()
 | 
					        update_fields=()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if classify_correspondent and self.correspondent_classifier is not None:
 | 
					        if classify_correspondent and self.correspondent_classifier is not None:
 | 
				
			||||||
            y_correspondent = self.correspondent_classifier.predict(X)
 | 
					            y_correspondent = self.correspondent_classifier.predict(X)
 | 
				
			||||||
            correspondent = self.correspondent_binarizer.inverse_transform(y_correspondent)[0]
 | 
					            correspondent_id = self.correspondent_binarizer.inverse_transform(y_correspondent)[0]
 | 
				
			||||||
            print("Detected correspondent:", correspondent)
 | 
					            try:
 | 
				
			||||||
            document.correspondent = Correspondent.objects.filter(name=correspondent).first()
 | 
					                correspondent = Correspondent.objects.get(id=correspondent_id) if correspondent_id != -1 else None
 | 
				
			||||||
            update_fields = update_fields + ("correspondent",)
 | 
					                logging.getLogger(__name__).info("Detected correspondent: {}".format(correspondent.name if correspondent else "-"))
 | 
				
			||||||
 | 
					                document.correspondent = correspondent
 | 
				
			||||||
 | 
					                update_fields = update_fields + ("correspondent",)
 | 
				
			||||||
 | 
					            except Correspondent.DoesNotExist:
 | 
				
			||||||
 | 
					                logging.getLogger(__name__).warning("Detected correspondent with id {} does not exist anymore! Did you delete it?".format(correspondent_id))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if classify_type and self.type_classifier is not None:
 | 
					        if classify_document_type and self.document_type_classifier is not None:
 | 
				
			||||||
            y_type = self.type_classifier.predict(X)
 | 
					            y_type = self.document_type_classifier.predict(X)
 | 
				
			||||||
            type = self.type_binarizer.inverse_transform(y_type)[0]
 | 
					            type_id = self.document_type_binarizer.inverse_transform(y_type)[0]
 | 
				
			||||||
            print("Detected document type:", type)
 | 
					            try:
 | 
				
			||||||
            document.document_type = DocumentType.objects.filter(name=type).first()
 | 
					                document_type = DocumentType.objects.get(id=type_id) if type_id != -1 else None
 | 
				
			||||||
            update_fields = update_fields + ("document_type",)
 | 
					                logging.getLogger(__name__).info("Detected document type: {}".format(document_type.name if document_type else "-"))
 | 
				
			||||||
 | 
					                document.document_type = document_type
 | 
				
			||||||
 | 
					                update_fields = update_fields + ("document_type",)
 | 
				
			||||||
 | 
					            except DocumentType.DoesNotExist:
 | 
				
			||||||
 | 
					                logging.getLogger(__name__).warning("Detected document type with id {} does not exist anymore! Did you delete it?".format(type_id))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if classify_tags and self.tags_classifier is not None:
 | 
					        if classify_tags and self.tags_classifier is not None:
 | 
				
			||||||
            y_tags = self.tags_classifier.predict(X)
 | 
					            y_tags = self.tags_classifier.predict(X)
 | 
				
			||||||
            tags = self.tags_binarizer.inverse_transform(y_tags)[0]
 | 
					            tags_ids = self.tags_binarizer.inverse_transform(y_tags)[0]
 | 
				
			||||||
            print("Detected tags:", tags)
 | 
					 | 
				
			||||||
            if replace_tags:
 | 
					            if replace_tags:
 | 
				
			||||||
                document.tags.clear()
 | 
					                document.tags.clear()
 | 
				
			||||||
            document.tags.add(*[Tag.objects.filter(name=t).first() for t in tags])
 | 
					            for tag_id in tags_ids:
 | 
				
			||||||
 | 
					                try:
 | 
				
			||||||
 | 
					                    tag = Tag.objects.get(id=tag_id)
 | 
				
			||||||
 | 
					                    document.tags.add(tag)
 | 
				
			||||||
 | 
					                    logging.getLogger(__name__).info("Detected tag: {}".format(tag.name))
 | 
				
			||||||
 | 
					                except Tag.DoesNotExist:
 | 
				
			||||||
 | 
					                    logging.getLogger(__name__).warning("Detected tag with id {} does not exist anymore! Did you delete it?".format(tag_id))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        document.save(update_fields=update_fields)
 | 
					        document.save(update_fields=update_fields)
 | 
				
			||||||
 | 
				
			|||||||
@ -35,6 +35,10 @@ class Command(Renderable, BaseCommand):
 | 
				
			|||||||
            "-i", "--inbox-only",
 | 
					            "-i", "--inbox-only",
 | 
				
			||||||
            action="store_true"
 | 
					            action="store_true"
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					        parser.add_argument(
 | 
				
			||||||
 | 
					            "-r", "--replace-tags",
 | 
				
			||||||
 | 
					            action="store_true"
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def handle(self, *args, **options):
 | 
					    def handle(self, *args, **options):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -52,7 +56,6 @@ class Command(Renderable, BaseCommand):
 | 
				
			|||||||
            logging.getLogger(__name__).fatal("Cannot classify documents, classifier model file was not found.")
 | 
					            logging.getLogger(__name__).fatal("Cannot classify documents, classifier model file was not found.")
 | 
				
			||||||
            return
 | 
					            return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
        for document in documents:
 | 
					        for document in documents:
 | 
				
			||||||
            logging.getLogger(__name__).info("Processing document {}".format(document.title))
 | 
					            logging.getLogger(__name__).info("Processing document {}".format(document.title))
 | 
				
			||||||
            clf.classify_document(document, classify_type=options['type'], classify_tags=options['tags'], classify_correspondent=options['correspondent'])
 | 
					            clf.classify_document(document, classify_document_type=options['type'], classify_tags=options['tags'], classify_correspondent=options['correspondent'], replace_tags=options['replace_tags'])
 | 
				
			||||||
 | 
				
			|||||||
@ -9,7 +9,7 @@ from django.contrib.contenttypes.models import ContentType
 | 
				
			|||||||
from django.utils import timezone
 | 
					from django.utils import timezone
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from documents.classifier import DocumentClassifier
 | 
					from documents.classifier import DocumentClassifier
 | 
				
			||||||
from ..models import Correspondent, Document, Tag, DocumentType
 | 
					from ..models import Document, Tag
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def logger(message, group):
 | 
					def logger(message, group):
 | 
				
			||||||
@ -23,11 +23,14 @@ def classify_document(sender, document=None, logging_group=None, **kwargs):
 | 
				
			|||||||
    global classifier
 | 
					    global classifier
 | 
				
			||||||
    try:
 | 
					    try:
 | 
				
			||||||
        classifier.reload()
 | 
					        classifier.reload()
 | 
				
			||||||
        classifier.classify_document(document, classify_correspondent=True, classify_tags=True, classify_type=True)
 | 
					        classifier.classify_document(document, classify_correspondent=True, classify_tags=True, classify_document_type=True)
 | 
				
			||||||
    except FileNotFoundError:
 | 
					    except FileNotFoundError:
 | 
				
			||||||
        logging.getLogger(__name__).fatal("Cannot classify document, classifier model file was not found.")
 | 
					        logging.getLogger(__name__).fatal("Cannot classify document, classifier model file was not found.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def add_inbox_tags(sender, document=None, logging_group=None, **kwargs):
 | 
				
			||||||
 | 
					    inbox_tags = Tag.objects.filter(is_inbox_tag=True)
 | 
				
			||||||
 | 
					    document.tags.add(*inbox_tags)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def run_pre_consume_script(sender, filename, **kwargs):
 | 
					def run_pre_consume_script(sender, filename, **kwargs):
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user