From abcebfe9714644d4e259e53b10e0e9417b5b864f Mon Sep 17 00:00:00 2001 From: Jerome Flesch Date: Sun, 21 Apr 2024 13:31:03 +0200 Subject: [PATCH] backend/guesswork/labels/sklearn: fix use of scipy.sparse.hstack() + numpy.zeros() Closes #1111 --- .../paperwork_backend/guesswork/label/sklearn/__init__.py | 5 +++-- paperwork-backend/src/paperwork_backend/model/fake.py | 6 ++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/paperwork-backend/src/paperwork_backend/guesswork/label/sklearn/__init__.py b/paperwork-backend/src/paperwork_backend/guesswork/label/sklearn/__init__.py index b2af4350..8633211f 100644 --- a/paperwork-backend/src/paperwork_backend/guesswork/label/sklearn/__init__.py +++ b/paperwork-backend/src/paperwork_backend/guesswork/label/sklearn/__init__.py @@ -191,7 +191,8 @@ class UpdatableVectorizer(object): ) if required_padding > 0: doc_vector = numpy.hstack([ - doc_vector, numpy.zeros((required_padding,)) + doc_vector, + numpy.zeros((required_padding,)) ]) if sum_features is None: sum_features = doc_vector @@ -339,7 +340,7 @@ class Corpus(object): if required_padding > 0: doc_vector = scipy.sparse.hstack([ scipy.sparse.csr_matrix(doc_vector), - numpy.zeros((required_padding,)) + numpy.zeros((1, required_padding)) ]) else: doc_vector = scipy.sparse.csr_matrix(doc_vector) diff --git a/paperwork-backend/src/paperwork_backend/model/fake.py b/paperwork-backend/src/paperwork_backend/model/fake.py index 29beae97..f06fe18e 100644 --- a/paperwork-backend/src/paperwork_backend/model/fake.py +++ b/paperwork-backend/src/paperwork_backend/model/fake.py @@ -125,6 +125,12 @@ class Plugin(openpaperwork_core.PluginBase): if doc['url'] == doc_url: out.update(doc['labels']) + def doc_has_labels_by_url(self, doc_url): + for doc in self.docs: + if doc['url'] == doc_url: + return True if len(doc["labels"]) > 0 else None + return None + def doc_add_label_by_url(self, doc_url, label, color=None): if color is None: all_labels = set() -- GitLab