import re
import shutil
import warnings
+from typing import Iterator
from typing import List
from typing import Optional
def train(self):
- data = []
labels_tags = []
labels_correspondent = []
labels_document_type = []
labels_storage_path = []
+ docs_queryset = Document.objects.order_by("pk").exclude(tags__is_inbox_tag=True)
+
+ if docs_queryset.count() == 0:
+ raise ValueError("No training data available.")
+
# Step 1: Extract and preprocess training data from the database.
logger.debug("Gathering data from database...")
m = hashlib.sha1()
- for doc in Document.objects.order_by("pk").exclude(
- tags__is_inbox_tag=True,
- ):
+ for doc in docs_queryset:
preprocessed_content = self.preprocess_content(doc.content)
m.update(preprocessed_content.encode("utf-8"))
- data.append(preprocessed_content)
y = -1
dt = doc.document_type
m.update(y.to_bytes(4, "little", signed=True))
labels_storage_path.append(y)
- if not data:
- raise ValueError("No training data available.")
-
new_data_hash = m.digest()
if self.data_hash and new_data_hash == self.data_hash:
logger.debug(
"{} documents, {} tag(s), {} correspondent(s), "
"{} document type(s). {} storage path(es)".format(
- len(data),
+ docs_queryset.count(),
num_tags,
num_correspondents,
num_document_types,
# Step 2: vectorize data
logger.debug("Vectorizing data...")
+
+ def content_generator() -> Iterator[str]:
+ for doc in docs_queryset:
+ yield self.preprocess_content(doc.content)
+
self.data_vectorizer = CountVectorizer(
analyzer="word",
ngram_range=(1, 2),
min_df=0.01,
)
- data_vectorized = self.data_vectorizer.fit_transform(data)
+
+ data_vectorized = self.data_vectorizer.fit_transform(content_generator())
# See the notes here:
# https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html # noqa: 501
return content
- def predict_correspondent(self, content):
+ def predict_correspondent(self, content: str):
if self.correspondent_classifier:
X = self.data_vectorizer.transform([self.preprocess_content(content)])
correspondent_id = self.correspondent_classifier.predict(X)
else:
return None
- def predict_document_type(self, content):
+ def predict_document_type(self, content: str):
if self.document_type_classifier:
X = self.data_vectorizer.transform([self.preprocess_content(content)])
document_type_id = self.document_type_classifier.predict(X)
else:
return None
- def predict_tags(self, content):
+ def predict_tags(self, content: str):
from sklearn.utils.multiclass import type_of_target
if self.tags_classifier:
else:
return []
- def predict_storage_path(self, content):
+ def predict_storage_path(self, content: str):
if self.storage_path_classifier:
X = self.data_vectorizer.transform([self.preprocess_content(content)])
storage_path_id = self.storage_path_classifier.predict(X)