mirror of
				https://github.com/paperless-ngx/paperless-ngx.git
				synced 2025-10-24 23:39:05 -04:00 
			
		
		
		
	changed classifier
This commit is contained in:
		
							parent
							
								
									04bf5fc094
								
							
						
					
					
						commit
						d2534a73e5
					
				
							
								
								
									
										0
									
								
								models/.keep
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								models/.keep
									
									
									
									
									
										Normal file
									
								
							| @ -2,12 +2,13 @@ import logging | |||||||
| import os | import os | ||||||
| import pickle | import pickle | ||||||
| 
 | 
 | ||||||
|  | from sklearn.neural_network import MLPClassifier | ||||||
|  | 
 | ||||||
| from documents.models import Correspondent, DocumentType, Tag, Document | from documents.models import Correspondent, DocumentType, Tag, Document | ||||||
| from paperless import settings | from paperless import settings | ||||||
| 
 | 
 | ||||||
| from sklearn.feature_extraction.text import CountVectorizer | from sklearn.feature_extraction.text import CountVectorizer | ||||||
| from sklearn.multiclass import OneVsRestClassifier | from sklearn.multiclass import OneVsRestClassifier | ||||||
| from sklearn.naive_bayes import MultinomialNB |  | ||||||
| from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer | from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -87,7 +88,7 @@ class DocumentClassifier(object): | |||||||
| 
 | 
 | ||||||
|         # Step 2: vectorize data |         # Step 2: vectorize data | ||||||
|         logging.getLogger(__name__).info("Vectorizing data...") |         logging.getLogger(__name__).info("Vectorizing data...") | ||||||
|         self.data_vectorizer = CountVectorizer(analyzer='char', ngram_range=(2, 6), min_df=0.1) |         self.data_vectorizer = CountVectorizer(analyzer='char', ngram_range=(3, 5), min_df=0.1) | ||||||
|         data_vectorized = self.data_vectorizer.fit_transform(data) |         data_vectorized = self.data_vectorizer.fit_transform(data) | ||||||
| 
 | 
 | ||||||
|         self.tags_binarizer = MultiLabelBinarizer() |         self.tags_binarizer = MultiLabelBinarizer() | ||||||
| @ -102,7 +103,7 @@ class DocumentClassifier(object): | |||||||
|         # Step 3: train the classifiers |         # Step 3: train the classifiers | ||||||
|         if len(self.tags_binarizer.classes_) > 0: |         if len(self.tags_binarizer.classes_) > 0: | ||||||
|             logging.getLogger(__name__).info("Training tags classifier...") |             logging.getLogger(__name__).info("Training tags classifier...") | ||||||
|             self.tags_classifier = OneVsRestClassifier(MultinomialNB()) |             self.tags_classifier = MLPClassifier(verbose=True) | ||||||
|             self.tags_classifier.fit(data_vectorized, labels_tags_vectorized) |             self.tags_classifier.fit(data_vectorized, labels_tags_vectorized) | ||||||
|         else: |         else: | ||||||
|             self.tags_classifier = None |             self.tags_classifier = None | ||||||
| @ -110,7 +111,7 @@ class DocumentClassifier(object): | |||||||
| 
 | 
 | ||||||
|         if len(self.correspondent_binarizer.classes_) > 0: |         if len(self.correspondent_binarizer.classes_) > 0: | ||||||
|             logging.getLogger(__name__).info("Training correspondent classifier...") |             logging.getLogger(__name__).info("Training correspondent classifier...") | ||||||
|             self.correspondent_classifier = OneVsRestClassifier(MultinomialNB()) |             self.correspondent_classifier = MLPClassifier(verbose=True) | ||||||
|             self.correspondent_classifier.fit(data_vectorized, labels_correspondent_vectorized) |             self.correspondent_classifier.fit(data_vectorized, labels_correspondent_vectorized) | ||||||
|         else: |         else: | ||||||
|             self.correspondent_classifier = None |             self.correspondent_classifier = None | ||||||
| @ -118,7 +119,7 @@ class DocumentClassifier(object): | |||||||
| 
 | 
 | ||||||
|         if len(self.type_binarizer.classes_) > 0: |         if len(self.type_binarizer.classes_) > 0: | ||||||
|             logging.getLogger(__name__).info("Training document type classifier...") |             logging.getLogger(__name__).info("Training document type classifier...") | ||||||
|             self.type_classifier = OneVsRestClassifier(MultinomialNB()) |             self.type_classifier = MLPClassifier(verbose=True) | ||||||
|             self.type_classifier.fit(data_vectorized, labels_type_vectorized) |             self.type_classifier.fit(data_vectorized, labels_type_vectorized) | ||||||
|         else: |         else: | ||||||
|             self.type_classifier = None |             self.type_classifier = None | ||||||
|  | |||||||
| @ -18,7 +18,7 @@ class Command(Renderable, BaseCommand): | |||||||
|         with open("dataset_tags.txt", "w") as f: |         with open("dataset_tags.txt", "w") as f: | ||||||
|             for doc in Document.objects.exclude(tags__is_inbox_tag=True): |             for doc in Document.objects.exclude(tags__is_inbox_tag=True): | ||||||
|                 labels = [] |                 labels = [] | ||||||
|                 for tag in doc.tags.all(): |                 for tag in doc.tags.filter(automatic_classification=True): | ||||||
|                     labels.append(tag.name) |                     labels.append(tag.name) | ||||||
|                 f.write(",".join(labels)) |                 f.write(",".join(labels)) | ||||||
|                 f.write(";") |                 f.write(";") | ||||||
| @ -27,14 +27,14 @@ class Command(Renderable, BaseCommand): | |||||||
| 
 | 
 | ||||||
|         with open("dataset_types.txt", "w") as f: |         with open("dataset_types.txt", "w") as f: | ||||||
|             for doc in Document.objects.exclude(tags__is_inbox_tag=True): |             for doc in Document.objects.exclude(tags__is_inbox_tag=True): | ||||||
|                 f.write(doc.document_type.name if doc.document_type is not None else "None") |                 f.write(doc.document_type.name if doc.document_type is not None and doc.document_type.automatic_classification else "-") | ||||||
|                 f.write(";") |                 f.write(";") | ||||||
|                 f.write(preprocess_content(doc.content)) |                 f.write(preprocess_content(doc.content)) | ||||||
|                 f.write("\n") |                 f.write("\n") | ||||||
| 
 | 
 | ||||||
|         with open("dataset_correspondents.txt", "w") as f: |         with open("dataset_correspondents.txt", "w") as f: | ||||||
|             for doc in Document.objects.exclude(tags__is_inbox_tag=True): |             for doc in Document.objects.exclude(tags__is_inbox_tag=True): | ||||||
|                 f.write(doc.correspondent.name if doc.correspondent is not None else "None") |                 f.write(doc.correspondent.name if doc.correspondent is not None and doc.correspondent.automatic_classification else "-") | ||||||
|                 f.write(";") |                 f.write(";") | ||||||
|                 f.write(preprocess_content(doc.content)) |                 f.write(preprocess_content(doc.content)) | ||||||
|                 f.write("\n") |                 f.write("\n") | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user