git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
Changes classifier training to hold less data in memory at the same time
author Trenton H <797416+stumpylog@users.noreply.github.com>
Wed, 22 Feb 2023 19:33:11 +0000 (11:33 -0800)
committer Trenton H <797416+stumpylog@users.noreply.github.com>
Tue, 28 Feb 2023 16:13:10 +0000 (08:13 -0800)
src/documents/classifier.py

index 2779fad7bca996f319098736c255d6969008de6e..66958087a1d4d467976d59e5bd653fdf57163e0b 100644 (file)
@@ -5,6 +5,7 @@ import pickle
 import re
 import shutil
 import warnings
+from typing import Iterator
 from typing import List
 from typing import Optional
 
@@ -136,21 +137,22 @@ class DocumentClassifier:
 
     def train(self):
 
-        data = []
         labels_tags = []
         labels_correspondent = []
         labels_document_type = []
         labels_storage_path = []
 
+        docs_queryset = Document.objects.order_by("pk").exclude(tags__is_inbox_tag=True)
+
+        if docs_queryset.count() == 0:
+            raise ValueError("No training data available.")
+
         # Step 1: Extract and preprocess training data from the database.
         logger.debug("Gathering data from database...")
         m = hashlib.sha1()
-        for doc in Document.objects.order_by("pk").exclude(
-            tags__is_inbox_tag=True,
-        ):
+        for doc in docs_queryset:
             preprocessed_content = self.preprocess_content(doc.content)
             m.update(preprocessed_content.encode("utf-8"))
-            data.append(preprocessed_content)
 
             y = -1
             dt = doc.document_type
@@ -183,9 +185,6 @@ class DocumentClassifier:
             m.update(y.to_bytes(4, "little", signed=True))
             labels_storage_path.append(y)
 
-        if not data:
-            raise ValueError("No training data available.")
-
         new_data_hash = m.digest()
 
         if self.data_hash and new_data_hash == self.data_hash:
@@ -207,7 +206,7 @@ class DocumentClassifier:
         logger.debug(
             "{} documents, {} tag(s), {} correspondent(s), "
             "{} document type(s). {} storage path(es)".format(
-                len(data),
+                docs_queryset.count(),
                 num_tags,
                 num_correspondents,
                 num_document_types,
@@ -221,12 +220,18 @@ class DocumentClassifier:
 
         # Step 2: vectorize data
         logger.debug("Vectorizing data...")
+
+        def content_generator() -> Iterator[str]:
+            for doc in docs_queryset:
+                yield self.preprocess_content(doc.content)
+
         self.data_vectorizer = CountVectorizer(
             analyzer="word",
             ngram_range=(1, 2),
             min_df=0.01,
         )
-        data_vectorized = self.data_vectorizer.fit_transform(data)
+
+        data_vectorized = self.data_vectorizer.fit_transform(content_generator())
 
         # See the notes here:
         # https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html  # noqa: 501
@@ -341,7 +346,7 @@ class DocumentClassifier:
 
         return content
 
-    def predict_correspondent(self, content):
+    def predict_correspondent(self, content: str):
         if self.correspondent_classifier:
             X = self.data_vectorizer.transform([self.preprocess_content(content)])
             correspondent_id = self.correspondent_classifier.predict(X)
@@ -352,7 +357,7 @@ class DocumentClassifier:
         else:
             return None
 
-    def predict_document_type(self, content):
+    def predict_document_type(self, content: str):
         if self.document_type_classifier:
             X = self.data_vectorizer.transform([self.preprocess_content(content)])
             document_type_id = self.document_type_classifier.predict(X)
@@ -363,7 +368,7 @@ class DocumentClassifier:
         else:
             return None
 
-    def predict_tags(self, content):
+    def predict_tags(self, content: str):
         from sklearn.utils.multiclass import type_of_target
 
         if self.tags_classifier:
@@ -384,7 +389,7 @@ class DocumentClassifier:
         else:
             return []
 
-    def predict_storage_path(self, content):
+    def predict_storage_path(self, content: str):
         if self.storage_path_classifier:
             X = self.data_vectorizer.transform([self.preprocess_content(content)])
             storage_path_id = self.storage_path_classifier.predict(X)