-import hashlib
import logging
import os
import pickle
import re
import shutil
import warnings
+from datetime import datetime
from typing import Iterator
from typing import List
from typing import Optional
# v7 - Updated scikit-learn package version
# v8 - Added storage path classifier
- FORMAT_VERSION = 8
+ # v9 - Changed from hash to time for training data check
+ FORMAT_VERSION = 9
def __init__(self):
- # hash of the training data. used to prevent re-training when the
+ # last modification time of the training data. used to prevent re-training when the
# training data has not changed.
- self.data_hash: Optional[bytes] = None
+ self.last_data_change: Optional[datetime] = None
self.data_vectorizer = None
self.tags_binarizer = None
)
else:
try:
- self.data_hash = pickle.load(f)
+ self.last_data_change = pickle.load(f)
self.data_vectorizer = pickle.load(f)
self.tags_binarizer = pickle.load(f)
with open(target_file_temp, "wb") as f:
pickle.dump(self.FORMAT_VERSION, f)
- pickle.dump(self.data_hash, f)
+ pickle.dump(self.last_data_change, f)
pickle.dump(self.data_vectorizer, f)
pickle.dump(self.tags_binarizer, f)
def train(self):
+ # Get non-inbox documents
+ docs_queryset = Document.objects.exclude(tags__is_inbox_tag=True)
+
+ # No documents exist to train against
+ if docs_queryset.count() == 0:
+ raise ValueError("No training data available.")
+
+ # No documents have changed since classifier was trained
+ latest_doc_change = docs_queryset.latest("modified").modified
+ if (
+ self.last_data_change is not None
+ and self.last_data_change >= latest_doc_change
+ ):
+ return False
+
labels_tags = []
labels_correspondent = []
labels_document_type = []
labels_storage_path = []
- docs_queryset = Document.objects.order_by("pk").exclude(tags__is_inbox_tag=True)
-
- if docs_queryset.count() == 0:
- raise ValueError("No training data available.")
-
# Step 1: Extract and preprocess training data from the database.
logger.debug("Gathering data from database...")
- m = hashlib.sha1()
for doc in docs_queryset:
- preprocessed_content = self.preprocess_content(doc.content)
- m.update(preprocessed_content.encode("utf-8"))
y = -1
dt = doc.document_type
if dt and dt.matching_algorithm == MatchingModel.MATCH_AUTO:
y = dt.pk
- m.update(y.to_bytes(4, "little", signed=True))
labels_document_type.append(y)
y = -1
cor = doc.correspondent
if cor and cor.matching_algorithm == MatchingModel.MATCH_AUTO:
y = cor.pk
- m.update(y.to_bytes(4, "little", signed=True))
labels_correspondent.append(y)
tags = sorted(
matching_algorithm=MatchingModel.MATCH_AUTO,
)
)
- for tag in tags:
- m.update(tag.to_bytes(4, "little", signed=True))
labels_tags.append(tags)
y = -1
sd = doc.storage_path
if sd and sd.matching_algorithm == MatchingModel.MATCH_AUTO:
y = sd.pk
- m.update(y.to_bytes(4, "little", signed=True))
labels_storage_path.append(y)
- new_data_hash = m.digest()
-
- if self.data_hash and new_data_hash == self.data_hash:
- return False
-
labels_tags_unique = {tag for tags in labels_tags for tag in tags}
num_tags = len(labels_tags_unique)
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.neural_network import MLPClassifier
- from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
+ from sklearn.preprocessing import LabelBinarizer
+ from sklearn.preprocessing import MultiLabelBinarizer
# Step 2: vectorize data
logger.debug("Vectorizing data...")
def content_generator() -> Iterator[str]:
+ """
+ Generates the content for documents, one document at a time
+ """
for doc in docs_queryset:
yield self.preprocess_content(doc.content)
"There are no storage paths. Not training storage path classifier.",
)
- self.data_hash = new_data_hash
+ self.last_data_change = latest_doc_change
return True
import os
import re
-import shutil
-import tempfile
from pathlib import Path
from unittest import mock
def dummy_preprocess(content: str):
+ """
+ Simpler, faster pre-processing for testing purposes
+ """
content = content.lower().strip()
content = re.sub(r"\s+", " ", content)
return content
class TestClassifier(DirectoriesMixin, TestCase):
-
- SAMPLE_MODEL_FILE = os.path.join(os.path.dirname(__file__), "data", "model.pickle")
-
def setUp(self):
super().setUp()
self.classifier = DocumentClassifier()
self.doc1.storage_path = self.sp1
- def testNoTrainingData(self):
- try:
+ def generate_train_and_save(self):
+ """
+ Generates the training data, trains and saves the updated pickle
+ file. This ensures the test is using the same scikit learn version
+ and eliminates a warning from the test suite
+ """
+ self.generate_test_data()
+ self.classifier.train()
+ self.classifier.save()
+
+ def test_no_training_data(self):
+ """
+ GIVEN:
+ - No documents exist to train
+ WHEN:
+ - Classifier training is requested
+ THEN:
+ - Exception is raised
+ """
+ with self.assertRaisesMessage(ValueError, "No training data available."):
+ self.classifier.train()
+
+ def test_no_non_inbox_tags(self):
+ """
+ GIVEN:
+ - No documents without an inbox tag exist
+ WHEN:
+ - Classifier training is requested
+ THEN:
+ - Exception is raised
+ """
+
+ t1 = Tag.objects.create(
+ name="t1",
+ matching_algorithm=Tag.MATCH_ANY,
+ pk=34,
+ is_inbox_tag=True,
+ )
+
+ doc1 = Document.objects.create(
+ title="doc1",
+ content="this is a document from c1",
+ checksum="A",
+ )
+ doc1.tags.add(t1)
+
+ with self.assertRaisesMessage(ValueError, "No training data available."):
self.classifier.train()
- except ValueError as e:
- self.assertEqual(str(e), "No training data available.")
- else:
- self.fail("Should raise exception")
def testEmpty(self):
+ """
+ GIVEN:
+ - A document exists
+ - No tags/not enough data to predict
+ WHEN:
+ - Classifier prediction is requested
+ THEN:
+ - Classifier returns no predictions
+ """
Document.objects.create(title="WOW", checksum="3457", content="ASD")
self.classifier.train()
+
self.assertIsNone(self.classifier.document_type_classifier)
self.assertIsNone(self.classifier.tags_classifier)
self.assertIsNone(self.classifier.correspondent_classifier)
self.assertIsNone(self.classifier.predict_correspondent(""))
def testTrain(self):
+ """
+ GIVEN:
+ - Test data
+ WHEN:
+ - Classifier is trained
+ THEN:
+ - Classifier uses correct values for correspondent learning
+ - Classifier uses correct values for tags learning
+ """
self.generate_test_data()
self.classifier.train()
+
self.assertListEqual(
list(self.classifier.correspondent_classifier.classes_),
[-1, self.c1.pk],
)
def testPredict(self):
+ """
+ GIVEN:
+ - Classifier trained against test data
+ WHEN:
+ - Prediction requested for correspondent, tags, type
+ THEN:
+ - Expected predictions based on training set
+ """
self.generate_test_data()
self.classifier.train()
+
self.assertEqual(
self.classifier.predict_correspondent(self.doc1.content),
self.c1.pk,
)
self.assertEqual(self.classifier.predict_document_type(self.doc2.content), None)
- def testDatasetHashing(self):
+ def test_no_retrain_if_no_change(self):
+ """
+ GIVEN:
+ - Classifier trained with current data
+ WHEN:
+ - Classifier training is requested again
+ THEN:
+ - Classifier does not redo training
+ """
self.generate_test_data()
self.assertTrue(self.classifier.train())
self.assertFalse(self.classifier.train())
- def testVersionIncreased(self):
+ def test_retrain_if_change(self):
+ """
+ GIVEN:
+ - Classifier trained with current data
+ WHEN:
+ - Classifier training is requested again
+ - Documents have changed
+ THEN:
+ - Classifier redoes the training
+ """
self.generate_test_data()
+
self.assertTrue(self.classifier.train())
- self.assertFalse(self.classifier.train())
- self.classifier.save()
+ self.doc1.correspondent = self.c2
+ self.doc1.save()
+
+ self.assertTrue(self.classifier.train())
+
+ def testVersionIncreased(self):
+ """
+ GIVEN:
+ - Existing classifier model saved at a version
+ WHEN:
+ - Attempt to load classifier file from newer version
+ THEN:
+ - Exception is raised
+ """
+ self.generate_train_and_save()
classifier2 = DocumentClassifier()
# assure that we can load the classifier after saving it.
classifier2.load()
- @override_settings(DATA_DIR=tempfile.mkdtemp())
def testSaveClassifier(self):
- self.generate_test_data()
-
- self.classifier.train()
-
- self.classifier.save()
+ self.generate_train_and_save()
new_classifier = DocumentClassifier()
new_classifier.load()
self.assertFalse(new_classifier.train())
- # @override_settings(
- # MODEL_FILE=os.path.join(os.path.dirname(__file__), "data", "model.pickle"),
- # )
- # def test_create_test_load_and_classify(self):
- # self.generate_test_data()
- # self.classifier.train()
- # self.classifier.save()
-
def test_load_and_classify(self):
- # Generate test data, train and save to the model file
- # This ensures the model file sklearn version matches
- # and eliminates a warning
- shutil.copy(
- self.SAMPLE_MODEL_FILE,
- os.path.join(self.dirs.data_dir, "classification_model.pickle"),
- )
- self.generate_test_data()
- self.classifier.train()
- self.classifier.save()
+
+ self.generate_train_and_save()
new_classifier = DocumentClassifier()
new_classifier.load()
THEN:
- The ClassifierModelCorruptError is raised
"""
- shutil.copy(
- self.SAMPLE_MODEL_FILE,
- os.path.join(self.dirs.data_dir, "classification_model.pickle"),
- )
- # First load is the schema version
+ self.generate_train_and_save()
+
+ # First load is the schema version, allow it
patched_pickle_load.side_effect = [DocumentClassifier.FORMAT_VERSION, OSError()]
with self.assertRaises(ClassifierModelCorruptError):