diff --git a/appinfo/info.xml b/appinfo/info.xml
index b0e0c1e72a..7008982213 100644
--- a/appinfo/info.xml
+++ b/appinfo/info.xml
@@ -34,7 +34,7 @@ The rating depends on the installed text processing backend. See [the rating ove
Learn more about the Nextcloud Ethical AI Rating [in our blog](https://nextcloud.com/blog/nextcloud-ethical-ai-rating/).
]]>
- 4.1.0-alpha.2
+ 4.1.0-alpha.3
agpl
Christoph Wurst
GretaD
diff --git a/lib/Command/RunMetaEstimator.php b/lib/Command/RunMetaEstimator.php
index 8a00c1bafe..32d1187d37 100644
--- a/lib/Command/RunMetaEstimator.php
+++ b/lib/Command/RunMetaEstimator.php
@@ -10,7 +10,7 @@
namespace OCA\Mail\Command;
use OCA\Mail\Service\AccountService;
-use OCA\Mail\Service\Classification\FeatureExtraction\NewCompositeExtractor;
+use OCA\Mail\Service\Classification\FeatureExtraction\CompositeExtractor;
use OCA\Mail\Service\Classification\ImportanceClassifier;
use OCA\Mail\Support\ConsoleLoggerDecorator;
use OCP\AppFramework\Db\DoesNotExistException;
@@ -86,8 +86,8 @@ protected function execute(InputInterface $input, OutputInterface $output): int
return 1;
}
- /** @var NewCompositeExtractor $extractor */
- $extractor = $this->container->get(NewCompositeExtractor::class);
+ /** @var CompositeExtractor $extractor */
+ $extractor = $this->container->get(CompositeExtractor::class);
$consoleLogger = new ConsoleLoggerDecorator(
$this->logger,
$output
@@ -124,8 +124,9 @@ protected function execute(InputInterface $input, OutputInterface $output): int
return $estimator;
};
+ /** @var GridSearch $metaEstimator */
if ($dataSet) {
- $this->classifier->trainWithCustomDataSet(
+ $metaEstimator = $this->classifier->trainWithCustomDataSet(
$account,
$consoleLogger,
$dataSet,
@@ -135,7 +136,7 @@ protected function execute(InputInterface $input, OutputInterface $output): int
false,
);
} else {
- $this->classifier->train(
+ $metaEstimator = $this->classifier->train(
$account,
$consoleLogger,
$extractor,
@@ -145,6 +146,10 @@ protected function execute(InputInterface $input, OutputInterface $output): int
);
}
+ if ($metaEstimator) {
+ $output->writeln("Best estimator: {$metaEstimator->base()}");
+ }
+
$mbs = (int)(memory_get_peak_usage() / 1024 / 1024);
$output->writeln('' . $mbs . 'MB of memory used');
return 0;
diff --git a/lib/Command/TrainAccount.php b/lib/Command/TrainAccount.php
index 33ce44371d..12866de7f3 100644
--- a/lib/Command/TrainAccount.php
+++ b/lib/Command/TrainAccount.php
@@ -3,7 +3,7 @@
declare(strict_types=1);
/**
- * SPDX-FileCopyrightText: 2019 Nextcloud GmbH and Nextcloud contributors
+ * SPDX-FileCopyrightText: 2019-2024 Nextcloud GmbH and Nextcloud contributors
* SPDX-License-Identifier: AGPL-3.0-or-later
*/
@@ -11,15 +11,13 @@
use OCA\Mail\Service\AccountService;
use OCA\Mail\Service\Classification\ClassificationSettingsService;
+use OCA\Mail\Service\Classification\FeatureExtraction\CompositeExtractor;
use OCA\Mail\Service\Classification\FeatureExtraction\IExtractor;
-use OCA\Mail\Service\Classification\FeatureExtraction\NewCompositeExtractor;
-use OCA\Mail\Service\Classification\FeatureExtraction\VanillaCompositeExtractor;
use OCA\Mail\Service\Classification\ImportanceClassifier;
use OCA\Mail\Support\ConsoleLoggerDecorator;
use OCP\AppFramework\Db\DoesNotExistException;
use Psr\Container\ContainerInterface;
use Psr\Log\LoggerInterface;
-use Rubix\ML\Classifiers\GaussianNB;
use Symfony\Component\Console\Command\Command;
use Symfony\Component\Console\Input\InputArgument;
use Symfony\Component\Console\Input\InputInterface;
@@ -110,9 +108,6 @@ protected function execute(InputInterface $input, OutputInterface $output): int
$shuffle = (bool)$input->getOption(self::ARGUMENT_SHUFFLE);
$dryRun = (bool)$input->getOption(self::ARGUMENT_DRY_RUN);
$force = (bool)$input->getOption(self::ARGUMENT_FORCE);
- $old = (bool)$input->getOption(self::ARGUMENT_OLD);
- $oldEstimator = $old || $input->getOption(self::ARGUMENT_OLD_ESTIMATOR);
- $oldExtractor = $old || $input->getOption(self::ARGUMENT_OLD_EXTRACTOR);
try {
$account = $this->accountService->findById($accountId);
@@ -127,18 +122,7 @@ protected function execute(InputInterface $input, OutputInterface $output): int
}
/** @var IExtractor $extractor */
- if ($oldExtractor) {
- $extractor = $this->container->get(VanillaCompositeExtractor::class);
- } else {
- $extractor = $this->container->get(NewCompositeExtractor::class);
- }
-
- $estimator = null;
- if ($oldEstimator) {
- $estimator = static function () {
- return new GaussianNB();
- };
- }
+ $extractor = $this->container->get(CompositeExtractor::class);
$consoleLogger = new ConsoleLoggerDecorator(
$this->logger,
@@ -167,7 +151,7 @@ protected function execute(InputInterface $input, OutputInterface $output): int
$consoleLogger,
$dataSet,
$extractor,
- $estimator,
+ null,
null,
!$dryRun
);
@@ -176,7 +160,7 @@ protected function execute(InputInterface $input, OutputInterface $output): int
$account,
$consoleLogger,
$extractor,
- $estimator,
+ null,
$shuffle,
!$dryRun
);
diff --git a/lib/Db/Classifier.php b/lib/Db/Classifier.php
deleted file mode 100644
index 126f986730..0000000000
--- a/lib/Db/Classifier.php
+++ /dev/null
@@ -1,106 +0,0 @@
-addType('accountId', 'int');
- $this->addType('type', 'string');
- $this->addType('appVersion', 'string');
- $this->addType('trainingSetSize', 'int');
- $this->addType('validationSetSize', 'int');
- $this->addType('recallImportant', 'float');
- $this->addType('precisionImportant', 'float');
- $this->addType('f1ScoreImportant', 'float');
- $this->addType('duration', 'int');
- $this->addType('active', 'boolean');
- $this->addType('createdAt', 'int');
- }
-}
diff --git a/lib/Db/ClassifierMapper.php b/lib/Db/ClassifierMapper.php
deleted file mode 100644
index 946b70f8b0..0000000000
--- a/lib/Db/ClassifierMapper.php
+++ /dev/null
@@ -1,57 +0,0 @@
-
- */
-class ClassifierMapper extends QBMapper {
- public function __construct(IDBConnection $db) {
- parent::__construct($db, 'mail_classifiers');
- }
-
- /**
- * @param int $id
- *
- * @return Classifier
- * @throws DoesNotExistException
- */
- public function findLatest(int $id): Classifier {
- $qb = $this->db->getQueryBuilder();
-
- $select = $qb->select('*')
- ->from($this->getTableName())
- ->where(
- $qb->expr()->eq('account_id', $qb->createNamedParameter($id, IQueryBuilder::PARAM_INT), IQueryBuilder::PARAM_INT),
- $qb->expr()->eq('active', $qb->createNamedParameter(true, IQueryBuilder::PARAM_BOOL), IQueryBuilder::PARAM_BOOL)
- )
- ->orderBy('created_at', 'desc')
- ->setMaxResults(1);
-
- return $this->findEntity($select);
- }
-
- public function findHistoric(int $threshold, int $limit) {
- $qb = $this->db->getQueryBuilder();
- $select = $qb->select('*')
- ->from($this->getTableName())
- ->where(
- $qb->expr()->lte('created_at', $qb->createNamedParameter($threshold, IQueryBuilder::PARAM_INT), IQueryBuilder::PARAM_INT),
- )
- ->orderBy('created_at', 'asc')
- ->setMaxResults($limit);
- return $this->findEntities($select);
- }
-}
diff --git a/lib/Migration/Version4100Date20241021091352.php b/lib/Migration/Version4100Date20241021091352.php
new file mode 100644
index 0000000000..ccdf8be264
--- /dev/null
+++ b/lib/Migration/Version4100Date20241021091352.php
@@ -0,0 +1,29 @@
+dropTable('mail_classifiers');
+ return $schema;
+ }
+}
diff --git a/lib/Model/Classifier.php b/lib/Model/Classifier.php
new file mode 100644
index 0000000000..df4d21eeb1
--- /dev/null
+++ b/lib/Model/Classifier.php
@@ -0,0 +1,134 @@
+accountId;
+ }
+
+ public function setAccountId(int $accountId): void {
+ $this->accountId = $accountId;
+ }
+
+ public function getType(): string {
+ return $this->type;
+ }
+
+ public function setType(string $type): void {
+ $this->type = $type;
+ }
+
+ public function getEstimator(): string {
+ return $this->estimator;
+ }
+
+ public function setEstimator(string $estimator): void {
+ $this->estimator = $estimator;
+ }
+
+ public function getPersistenceVersion(): int {
+ return $this->persistenceVersion;
+ }
+
+ public function setPersistenceVersion(int $persistenceVersion): void {
+ $this->persistenceVersion = $persistenceVersion;
+ }
+
+ public function getTrainingSetSize(): int {
+ return $this->trainingSetSize;
+ }
+
+ public function setTrainingSetSize(int $trainingSetSize): void {
+ $this->trainingSetSize = $trainingSetSize;
+ }
+
+ public function getValidationSetSize(): int {
+ return $this->validationSetSize;
+ }
+
+ public function setValidationSetSize(int $validationSetSize): void {
+ $this->validationSetSize = $validationSetSize;
+ }
+
+ public function getRecallImportant(): float {
+ return $this->recallImportant;
+ }
+
+ public function setRecallImportant(float $recallImportant): void {
+ $this->recallImportant = $recallImportant;
+ }
+
+ public function getPrecisionImportant(): float {
+ return $this->precisionImportant;
+ }
+
+ public function setPrecisionImportant(float $precisionImportant): void {
+ $this->precisionImportant = $precisionImportant;
+ }
+
+ public function getF1ScoreImportant(): float {
+ return $this->f1ScoreImportant;
+ }
+
+ public function setF1ScoreImportant(float $f1ScoreImportant): void {
+ $this->f1ScoreImportant = $f1ScoreImportant;
+ }
+
+ public function getDuration(): int {
+ return $this->duration;
+ }
+
+ public function setDuration(int $duration): void {
+ $this->duration = $duration;
+ }
+
+ public function getCreatedAt(): int {
+ return $this->createdAt;
+ }
+
+ public function setCreatedAt(int $createdAt): void {
+ $this->createdAt = $createdAt;
+ }
+
+ /** @return array<string, int|string|float> Every classifier field, keyed by its property name. */
+ public function jsonSerialize(): array {
+ return [
+ 'accountId' => $this->accountId,
+ 'type' => $this->type,
+ 'estimator' => $this->estimator,
+ 'persistenceVersion' => $this->persistenceVersion,
+ 'trainingSetSize' => $this->trainingSetSize,
+ 'validationSetSize' => $this->validationSetSize,
+ 'recallImportant' => $this->recallImportant,
+ 'precisionImportant' => $this->precisionImportant,
+ 'f1ScoreImportant' => $this->f1ScoreImportant,
+ 'duration' => $this->duration,
+ 'createdAt' => $this->createdAt,
+ ];
+ }
+}
diff --git a/lib/Model/ClassifierPipeline.php b/lib/Model/ClassifierPipeline.php
index eca0098195..bcc1be893e 100644
--- a/lib/Model/ClassifierPipeline.php
+++ b/lib/Model/ClassifierPipeline.php
@@ -9,28 +9,29 @@
namespace OCA\Mail\Model;
+use OCA\Mail\Service\Classification\FeatureExtraction\IExtractor;
use Rubix\ML\Estimator;
use Rubix\ML\Transformers\Transformer;
class ClassifierPipeline {
- private Estimator $estimator;
-
- /** @var Transformer[] */
- private array $transformers;
-
/**
- * @param Estimator $estimator
* @param Transformer[] $transformers
*/
- public function __construct(Estimator $estimator, array $transformers) {
- $this->estimator = $estimator;
- $this->transformers = $transformers;
+ public function __construct(
+ private Estimator $estimator,
+ private IExtractor $extractor,
+ private array $transformers,
+ ) {
}
public function getEstimator(): Estimator {
return $this->estimator;
}
+ public function getExtractor(): IExtractor {
+ return $this->extractor;
+ }
+
/**
* @return Transformer[]
*/
diff --git a/lib/Service/Classification/FeatureExtraction/CompositeExtractor.php b/lib/Service/Classification/FeatureExtraction/CompositeExtractor.php
index 197d0f2eb4..aaa30dcbfe 100644
--- a/lib/Service/Classification/FeatureExtraction/CompositeExtractor.php
+++ b/lib/Service/Classification/FeatureExtraction/CompositeExtractor.php
@@ -3,7 +3,7 @@
declare(strict_types=1);
/**
- * SPDX-FileCopyrightText: 2020 Nextcloud GmbH and Nextcloud contributors
+ * SPDX-FileCopyrightText: 2020-2024 Nextcloud GmbH and Nextcloud contributors
* SPDX-License-Identifier: AGPL-3.0-or-later
*/
@@ -11,20 +11,34 @@
use OCA\Mail\Account;
use OCA\Mail\Db\Message;
+use Rubix\ML\Transformers\TfIdfTransformer;
+use Rubix\ML\Transformers\WordCountVectorizer;
use function OCA\Mail\array_flat_map;
/**
* Combines a set of DI'ed extractors so they can be used as one class
*/
-abstract class CompositeExtractor implements IExtractor {
- /** @var IExtractor[] */
- protected array $extractors;
+class CompositeExtractor implements IExtractor {
+ private readonly SubjectExtractor $subjectExtractor;
- /**
- * @param IExtractor[] $extractors
- */
- public function __construct(array $extractors) {
- $this->extractors = $extractors;
+ /** @var IExtractor[] */
+ private readonly array $extractors;
+
+ public function __construct(
+ ImportantMessagesExtractor $ex1,
+ ReadMessagesExtractor $ex2,
+ RepliedMessagesExtractor $ex3,
+ SentMessagesExtractor $ex4,
+ SubjectExtractor $ex5,
+ ) {
+ $this->subjectExtractor = $ex5;
+ $this->extractors = [
+ $ex1,
+ $ex2,
+ $ex3,
+ $ex4,
+ $ex5,
+ ];
}
public function prepare(Account $account,
@@ -36,12 +50,14 @@ public function prepare(Account $account,
}
}
- /**
- * @inheritDoc
- */
public function extract(Message $message): array {
return array_flat_map(static function (IExtractor $extractor) use ($message) {
return $extractor->extract($message);
}, $this->extractors);
}
+
+ public function getSubjectExtractor(): SubjectExtractor {
+ return $this->subjectExtractor;
+ }
}
+
diff --git a/lib/Service/Classification/FeatureExtraction/NewCompositeExtractor.php b/lib/Service/Classification/FeatureExtraction/NewCompositeExtractor.php
deleted file mode 100644
index b7d86622ca..0000000000
--- a/lib/Service/Classification/FeatureExtraction/NewCompositeExtractor.php
+++ /dev/null
@@ -1,24 +0,0 @@
-subjectExtractor = $ex2;
- }
-
- public function getSubjectExtractor(): SubjectExtractor {
- return $this->subjectExtractor;
- }
-}
diff --git a/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php b/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php
index b768938139..c5ecf6681f 100644
--- a/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php
+++ b/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php
@@ -49,7 +49,7 @@ public function setWordCountVectorizer(WordCountVectorizer $wordCountVectorizer)
$this->limitFeatureSize();
}
- public function getTfidf(): Transformer {
+ public function getTfidf(): TfIdfTransformer {
return $this->tfidf;
}
diff --git a/lib/Service/Classification/FeatureExtraction/VanillaCompositeExtractor.php b/lib/Service/Classification/FeatureExtraction/VanillaCompositeExtractor.php
deleted file mode 100644
index 09351f33f6..0000000000
--- a/lib/Service/Classification/FeatureExtraction/VanillaCompositeExtractor.php
+++ /dev/null
@@ -1,24 +0,0 @@
-mailboxMapper = $mailboxMapper;
$this->messageMapper = $messageMapper;
$this->persistenceService = $persistenceService;
$this->performanceLogger = $performanceLogger;
$this->rulesClassifier = $rulesClassifier;
- $this->vanillaExtractor = $vanillaExtractor;
$this->container = $container;
}
@@ -209,6 +204,8 @@ public function buildDataSet(
* @param bool $shuffleDataSet Shuffle the data set before training
* @param bool $persist Persist the trained classifier to use it for message classification
*
+ * @return Estimator|null The validation estimator, persisted estimator (if `$persist` === true) or null in case none was trained
+ *
* @throws ServiceException
*/
public function train(
@@ -218,12 +215,12 @@ public function train(
?Closure $estimator = null,
bool $shuffleDataSet = false,
bool $persist = true,
- ): void {
+ ): ?Estimator {
$perf = $this->performanceLogger->start('importance classifier training');
if ($extractor === null) {
try {
- $extractor = $this->container->get(NewCompositeExtractor::class);
+ $extractor = $this->container->get(CompositeExtractor::class);
} catch (ContainerExceptionInterface $e) {
throw new ServiceException('Default extractor is not available', 0, $e);
}
@@ -231,10 +228,10 @@ public function train(
$dataSet = $this->buildDataSet($account, $extractor, $logger, $perf, $shuffleDataSet);
if ($dataSet === null) {
- return;
+ return null;
}
- $this->trainWithCustomDataSet(
+ return $this->trainWithCustomDataSet(
$account,
$logger,
$dataSet,
@@ -256,24 +253,21 @@ public function train(
* @param PerformanceLoggerTask|null $perf Optionally reuse a performance logger task
* @param bool $persist Persist the trained classifier to use it for message classification
*
+ * @return Estimator|null The validation estimator, persisted estimator (if `$persist` === true) or null in case none was trained
+ *
* @throws ServiceException
*/
public function trainWithCustomDataSet(
Account $account,
LoggerInterface $logger,
array $dataSet,
- IExtractor $extractor,
+ CompositeExtractor $extractor,
?Closure $estimator,
?PerformanceLoggerTask $perf = null,
bool $persist = true,
- ): void {
+ ): ?Estimator {
$perf ??= $this->performanceLogger->start('importance classifier training');
-
- if ($estimator === null) {
- $estimator = static function () {
- return self::createDefaultEstimator();
- };
- }
+ $estimator ??= self::createDefaultEstimator(...);
/**
* How many of the most recent messages are excluded from training?
@@ -303,7 +297,7 @@ public function trainWithCustomDataSet(
if ($validationSet === [] || $trainingSet === []) {
$logger->info('not enough messages to train a classifier');
$perf->end();
- return;
+ return null;
}
/** @var Learner&Estimator&Persistable $validationEstimator */
@@ -321,30 +315,28 @@ public function trainWithCustomDataSet(
'exception' => $e,
]);
$perf->end();
- return;
+ return null;
}
$perf->step('train and validate classifier with training and validation sets');
- if ($persist) {
- /** @var Learner&Estimator&Persistable $persistedEstimator */
- $persistedEstimator = $estimator();
- $this->trainClassifier($persistedEstimator, $dataSet);
- $perf->step('train classifier with full data set');
-
- // Extract persisted transformers of the subject extractor.
- // Is a bit hacky but a full abstraction would be overkill.
- /** @var (Transformer&Persistable)[] $transformers */
- $transformers = [];
- if ($extractor instanceof NewCompositeExtractor) {
- $transformers[] = $extractor->getSubjectExtractor()->getWordCountVectorizer();
- $transformers[] = $extractor->getSubjectExtractor()->getTfidf();
- }
-
- $classifier->setAccountId($account->getId());
- $classifier->setDuration($perf->end());
- $this->persistenceService->persist($classifier, $persistedEstimator, $transformers);
- $logger->debug("classifier {$classifier->getId()} persisted");
+ if (!$persist) {
+ return $validationEstimator;
}
+
+ /** @var Learner&Estimator&Persistable $persistedEstimator */
+ $persistedEstimator = $estimator();
+ $this->trainClassifier($persistedEstimator, $dataSet);
+ $perf->step('train classifier with full data set');
+ $classifier->setDuration($perf->end());
+ $classifier->setAccountId($account->getId());
+ $classifier->setEstimator(get_class($persistedEstimator));
+ $classifier->setPersistenceVersion(PersistenceService::VERSION);
+
+ $this->persistenceService->persist($account, $persistedEstimator, $extractor);
+ $logger->debug("Classifier for account {$account->getId()} persisted", [
+ 'classifier' => $classifier,
+ ]);
+ return $persistedEstimator;
}
diff --git a/lib/Service/Classification/PersistenceService.php b/lib/Service/Classification/PersistenceService.php
index 4b4eea5345..120136b27e 100644
--- a/lib/Service/Classification/PersistenceService.php
+++ b/lib/Service/Classification/PersistenceService.php
@@ -3,162 +3,78 @@
declare(strict_types=1);
/**
- * SPDX-FileCopyrightText: 2020 Nextcloud GmbH and Nextcloud contributors
+ * SPDX-FileCopyrightText: 2020-2024 Nextcloud GmbH and Nextcloud contributors
* SPDX-License-Identifier: AGPL-3.0-or-later
*/
namespace OCA\Mail\Service\Classification;
-use OCA\DAV\Connector\Sabre\File;
use OCA\Mail\Account;
-use OCA\Mail\AppInfo\Application;
-use OCA\Mail\Db\Classifier;
-use OCA\Mail\Db\ClassifierMapper;
-use OCA\Mail\Db\MailAccountMapper;
use OCA\Mail\Exception\ServiceException;
use OCA\Mail\Model\ClassifierPipeline;
+use OCA\Mail\Service\Classification\FeatureExtraction\CompositeExtractor;
use OCA\Mail\Service\Classification\FeatureExtraction\IExtractor;
-use OCA\Mail\Service\Classification\FeatureExtraction\NewCompositeExtractor;
-use OCA\Mail\Service\Classification\FeatureExtraction\SubjectExtractor;
-use OCA\Mail\Service\Classification\FeatureExtraction\VanillaCompositeExtractor;
-use OCP\App\IAppManager;
-use OCP\AppFramework\Db\DoesNotExistException;
-use OCP\AppFramework\Utility\ITimeFactory;
-use OCP\Files;
-use OCP\Files\IAppData;
-use OCP\Files\NotFoundException;
-use OCP\Files\NotPermittedException;
+use OCP\ICache;
use OCP\ICacheFactory;
-use OCP\ITempManager;
use Psr\Container\ContainerExceptionInterface;
use Psr\Container\ContainerInterface;
-use Psr\Log\LoggerInterface;
use Rubix\ML\Learner;
use Rubix\ML\Persistable;
use Rubix\ML\PersistentModel;
-use Rubix\ML\Persisters\Filesystem;
use Rubix\ML\Serializers\RBX;
use Rubix\ML\Transformers\TfIdfTransformer;
use Rubix\ML\Transformers\Transformer;
use Rubix\ML\Transformers\WordCountVectorizer;
use RuntimeException;
-use function file_get_contents;
-use function file_put_contents;
use function get_class;
-use function strlen;
class PersistenceService {
- private const ADD_DATA_FOLDER = 'classifiers';
+ // Increment the version when changing the classifier or transformer pipeline
+ public const VERSION = 1;
- /** @var ClassifierMapper */
- private $mapper;
-
- /** @var IAppData */
- private $appData;
-
- /** @var ITempManager */
- private $tempManager;
-
- /** @var ITimeFactory */
- private $timeFactory;
-
- /** @var IAppManager */
- private $appManager;
-
- /** @var ICacheFactory */
- private $cacheFactory;
-
- /** @var LoggerInterface */
- private $logger;
-
- /** @var MailAccountMapper */
- private $accountMapper;
-
- private ContainerInterface $container;
-
- public function __construct(ClassifierMapper $mapper,
- IAppData $appData,
- ITempManager $tempManager,
- ITimeFactory $timeFactory,
- IAppManager $appManager,
- ICacheFactory $cacheFactory,
- LoggerInterface $logger,
- MailAccountMapper $accountMapper,
- ContainerInterface $container) {
- $this->mapper = $mapper;
- $this->appData = $appData;
- $this->tempManager = $tempManager;
- $this->timeFactory = $timeFactory;
- $this->appManager = $appManager;
- $this->cacheFactory = $cacheFactory;
- $this->logger = $logger;
- $this->accountMapper = $accountMapper;
- $this->container = $container;
+ public function __construct(
+ private readonly ICacheFactory $cacheFactory,
+ private readonly ContainerInterface $container,
+ ) {
}
/**
- * Persist the classifier data to the database, the estimator and its transformers to storage
+ * Persist classifier, estimator and its transformers to the memory cache.
*
- * @param Classifier $classifier
* @param Learner&Persistable $estimator
- * @param (Transformer&Persistable)[] $transformers
*
- * @throws ServiceException
+ * @throws ServiceException If any serialization fails
*/
- public function persist(Classifier $classifier,
+ public function persist(
+ Account $account,
Learner $estimator,
- array $transformers): void {
- /*
- * First we have to insert the row to get the unique ID, but disable
- * it until the model is persisted as well. Otherwise another process
- * might try to load the model in the meantime and run into an error
- * due to the missing data in app data.
- */
- $classifier->setAppVersion($this->appManager->getAppVersion(Application::APP_ID));
- $classifier->setEstimator(get_class($estimator));
- $classifier->setActive(false);
- $classifier->setCreatedAt($this->timeFactory->getTime());
- $this->mapper->insert($classifier);
+ CompositeExtractor $extractor,
+ ): void {
+ $serializedData = [];
/*
- * Then we serialize the estimator into a temporary file
+ * First we serialize the estimator
*/
- $tmpPath = $this->tempManager->getTemporaryFile();
try {
- $model = new PersistentModel($estimator, new Filesystem($tmpPath));
+ $persister = new RubixMemoryPersister();
+ $model = new PersistentModel($estimator, $persister);
$model->save();
- $serializedClassifier = file_get_contents($tmpPath);
- $this->logger->debug('Serialized classifier written to tmp file (' . strlen($serializedClassifier) . 'B');
+ $serializedData[] = $persister->getData();
} catch (RuntimeException $e) {
throw new ServiceException('Could not serialize classifier: ' . $e->getMessage(), 0, $e);
}
/*
- * Then we store the serialized model to app data
- */
- try {
- try {
- $folder = $this->appData->getFolder(self::ADD_DATA_FOLDER);
- $this->logger->debug('Using existing folder for the serialized classifier');
- } catch (NotFoundException $e) {
- $folder = $this->appData->newFolder(self::ADD_DATA_FOLDER);
- $this->logger->debug('New folder created for serialized classifiers');
- }
- $file = $folder->newFile((string)$classifier->getId());
- $file->putContent($serializedClassifier);
- $this->logger->debug('Serialized classifier written to app data');
- } catch (NotPermittedException|NotFoundException $e) {
- throw new ServiceException('Could not create classifiers directory: ' . $e->getMessage(), 0, $e);
- }
-
- /*
- * Then we serialize the transformer pipeline to temporary files
+ * Then we serialize the transformer pipeline
*/
- $transformerIndex = 0;
+ $transfomers = [
+ $extractor->getSubjectExtractor()->getWordCountVectorizer(),
+ $extractor->getSubjectExtractor()->getTfIdf(),
+ ];
$serializer = new RBX();
- foreach ($transformers as $transformer) {
- $tmpPath = $this->tempManager->getTemporaryFile();
+ foreach ($transfomers as $transformer) {
try {
+ $persister = new RubixMemoryPersister();
/**
* This is how to serialize a transformer according to the official docs.
* PersistentModel can only be used on Learners which transformers don't implement.
@@ -167,316 +83,145 @@ public function persist(Classifier $classifier,
*
* @psalm-suppress InternalMethod
*/
- $serializer->serialize($transformer)->saveTo(new Filesystem($tmpPath));
- $serializedTransformer = file_get_contents($tmpPath);
- $this->logger->debug('Serialized transformer written to tmp file (' . strlen($serializedTransformer) . 'B');
+ $serializer->serialize($transformer)->saveTo($persister);
+ $serializedData[] = $persister->getData();
} catch (RuntimeException $e) {
throw new ServiceException('Could not serialize transformer: ' . $e->getMessage(), 0, $e);
}
-
- try {
- $file = $folder->newFile("{$classifier->getId()}_t$transformerIndex");
- $file->putContent($serializedTransformer);
- $this->logger->debug("Serialized transformer $transformerIndex written to app data");
- } catch (NotPermittedException|NotFoundException $e) {
- throw new ServiceException(
- "Failed to persist transformer $transformerIndex: " . $e->getMessage(),
- 0,
- $e
- );
- }
-
- $transformerIndex++;
}
- /*
- * Now we set the model active so it can be used by the next request
- */
- $classifier->setActive(true);
- $this->mapper->update($classifier);
+ $this->setCached((string)$account->getId(), $serializedData);
}
/**
- * @param Account $account
- *
- * @return ?array [Estimator, IExtractor]
+ * Load the latest estimator and its transformers.
*
- * @throws ServiceException
+ * @throws ServiceException If any deserialization fails
*/
- public function loadLatest(Account $account): ?array {
- try {
- $latestModel = $this->mapper->findLatest($account->getId());
- } catch (DoesNotExistException $e) {
+ public function loadLatest(Account $account): ?ClassifierPipeline {
+ $cached = $this->getCached((string)$account->getId());
+ if ($cached == null) {
return null;
}
- $pipeline = $this->load($latestModel);
+ $serializedModel = $cached[0];
+ $serializedTransformers = array_slice($cached, 1);
try {
- $extractor = $this->loadExtractor($latestModel, $pipeline);
- } catch (ContainerExceptionInterface $e) {
+ $estimator = PersistentModel::load(new RubixMemoryPersister($serializedModel));
+ } catch (RuntimeException $e) {
throw new ServiceException(
- "Failed to load extractor: {$e->getMessage()}",
+ 'Could not deserialize persisted classifier: ' . $e->getMessage(),
0,
$e,
);
}
- return [$pipeline->getEstimator(), $extractor];
- }
-
- /**
- * Load an estimator and its transformers of a classifier from storage
- *
- * @param Classifier $classifier
- * @return ClassifierPipeline
- *
- * @throws ServiceException
- */
- public function load(Classifier $classifier): ClassifierPipeline {
- $transformerCount = 0;
- $appVersion = $this->parseAppVersion($classifier->getAppVersion());
- if ($appVersion[0] >= 3 && $appVersion[1] >= 2) {
- $transformerCount = 2;
- }
-
- $id = $classifier->getId();
- $cached = $this->getCached($classifier->getId(), $transformerCount);
- if ($cached !== null) {
- $this->logger->debug("Using cached serialized classifier $id");
- $serialized = $cached[0];
- $serializedTransformers = array_slice($cached, 1);
- } else {
- $this->logger->debug('Loading serialized classifier from app data');
- try {
- $modelsFolder = $this->appData->getFolder(self::ADD_DATA_FOLDER);
- $modelFile = $modelsFolder->getFile((string)$id);
- } catch (NotFoundException $e) {
- $this->logger->debug("Could not load classifier $id: " . $e->getMessage());
- throw new ServiceException("Could not load classifier $id: " . $e->getMessage(), 0, $e);
- }
-
- try {
- $serialized = $modelFile->getContent();
- } catch (NotFoundException|NotPermittedException $e) {
- $this->logger->debug("Could not load content for model file with classifier id $id: " . $e->getMessage());
- throw new ServiceException("Could not load content for model file with classifier id $id: " . $e->getMessage(), 0, $e);
- }
- $size = strlen($serialized);
- $this->logger->debug("Serialized classifier loaded (size=$size)");
-
- $serializedTransformers = [];
- for ($i = 0; $i < $transformerCount; $i++) {
- try {
- $transformerFile = $modelsFolder->getFile("{$id}_t$i");
- } catch (NotFoundException $e) {
- $this->logger->debug("Could not load transformer $i of classifier $id: " . $e->getMessage());
- throw new ServiceException("Could not load transformer $i of classifier $id: " . $e->getMessage(), 0, $e);
- }
-
- try {
- $serializedTransformer = $transformerFile->getContent();
- } catch (NotFoundException|NotPermittedException $e) {
- $this->logger->debug("Could not load content for transformer file $i with classifier id $id: " . $e->getMessage());
- throw new ServiceException("Could not load content for transformer file $i with classifier id $id: " . $e->getMessage(), 0, $e);
- }
- $size = strlen($serializedTransformer);
- $this->logger->debug("Serialized transformer $i loaded (size=$size)");
- $serializedTransformers[] = $serializedTransformer;
- }
-
- $this->cache($id, $serialized, $serializedTransformers);
- }
-
- $tmpPath = $this->tempManager->getTemporaryFile();
- file_put_contents($tmpPath, $serialized);
- try {
- $estimator = PersistentModel::load(new Filesystem($tmpPath));
- } catch (RuntimeException $e) {
- throw new ServiceException("Could not deserialize persisted classifier $id: " . $e->getMessage(), 0, $e);
- }
-
- $transformers = array_map(function (string $serializedTransformer) use ($id) {
- $serializer = new RBX();
- $tmpPath = $this->tempManager->getTemporaryFile();
- file_put_contents($tmpPath, $serializedTransformer);
+ $serializer = new RBX();
+ $transformers = array_map(function (string $serializedTransformer) use ($serializer) {
try {
- $persister = new Filesystem($tmpPath);
+ $persister = new RubixMemoryPersister($serializedTransformer);
$transformer = $persister->load()->deserializeWith($serializer);
} catch (RuntimeException $e) {
- throw new ServiceException("Could not deserialize persisted transformer of classifier $id: " . $e->getMessage(), 0, $e);
+ throw new ServiceException(
+ 'Could not deserialize persisted transformer of classifier: ' . $e->getMessage(),
+ 0,
+ $e,
+ );
}
if (!($transformer instanceof Transformer)) {
- throw new ServiceException("Transformer of classifier $id is not a transformer: Got " . $transformer::class);
+ throw new ServiceException(sprintf(
+ 'Transformer is not an instance of %s: Got %s',
+ Transformer::class,
+ get_class($transformer),
+ ));
}
return $transformer;
}, $serializedTransformers);
- return new ClassifierPipeline($estimator, $transformers);
- }
-
- public function cleanUp(): void {
- $threshold = $this->timeFactory->getTime() - 30 * 24 * 60 * 60;
- $totalAccounts = $this->accountMapper->getTotal();
- $classifiers = $this->mapper->findHistoric($threshold, $totalAccounts * 10);
- foreach ($classifiers as $classifier) {
- try {
- $this->deleteModel($classifier->getId());
- $this->mapper->delete($classifier);
- } catch (NotPermittedException $e) {
- // Log and continue. This is not critical
- $this->logger->warning('Could not clean-up old classifier', [
- 'id' => $classifier->getId(),
- 'exception' => $e,
- ]);
- }
- }
- }
-
- /**
- * @throws NotPermittedException
- */
- private function deleteModel(int $id): void {
- $this->logger->debug('Deleting serialized classifier from app data', [
- 'id' => $id,
- ]);
- try {
- $modelsFolder = $this->appData->getFolder(self::ADD_DATA_FOLDER);
- $modelFile = $modelsFolder->getFile((string)$id);
- $modelFile->delete();
- } catch (NotFoundException $e) {
- $this->logger->debug("Classifier model $id does not exist", [
- 'exception' => $e,
- ]);
- }
- }
-
- /**
- * Load and instantiate extractor based on a classifier's app version.
- *
- * @param Classifier $classifier
- * @param ClassifierPipeline $pipeline
- * @return IExtractor
- *
- * @throws ContainerExceptionInterface
- * @throws ServiceException
- */
- private function loadExtractor(Classifier $classifier,
- ClassifierPipeline $pipeline): IExtractor {
- $appVersion = $this->parseAppVersion($classifier->getAppVersion());
- if ($appVersion[0] >= 3 && $appVersion[1] >= 2) {
- return $this->loadExtractorV2($pipeline->getTransformers());
- }
-
- return $this->loadExtractorV1($pipeline->getTransformers());
- }
+ $extractor = $this->loadExtractor($transformers);
- /**
- * @return VanillaCompositeExtractor
- *
- * @throws ContainerExceptionInterface
- */
- private function loadExtractorV1(): VanillaCompositeExtractor {
- return $this->container->get(VanillaCompositeExtractor::class);
+ return new ClassifierPipeline($estimator, $extractor, $transformers);
}
/**
- * @param Transformer[] $transformers
- * @return NewCompositeExtractor
+ * Load and instantiate extractor based on the given transformers.
*
- * @throws ContainerExceptionInterface
- * @throws ServiceException
+ * @throws ServiceException If the transformers array contains unexpected instances or the composite extractor can't be instantiated
*/
- private function loadExtractorV2(array $transformers): NewCompositeExtractor {
+ private function loadExtractor(array $transformers): IExtractor {
$wordCountVectorizer = $transformers[0];
if (!($wordCountVectorizer instanceof WordCountVectorizer)) {
- throw new ServiceException('Failed to load persisted transformer: Expected ' . WordCountVectorizer::class . ', got' . $wordCountVectorizer::class);
+ throw new ServiceException(sprintf(
+ 'Failed to load persisted transformer: Expected %s, got %s',
+ WordCountVectorizer::class,
+ get_class($wordCountVectorizer),
+ ));
}
+
$tfidfTransformer = $transformers[1];
if (!($tfidfTransformer instanceof TfIdfTransformer)) {
- throw new ServiceException('Failed to load persisted transformer: Expected ' . TfIdfTransformer::class . ', got' . $tfidfTransformer::class);
+ throw new ServiceException(sprintf(
+ 'Failed to load persisted transformer: Expected %s, got %s',
+ TfIdfTransformer::class,
+ get_class($tfidfTransformer),
+ ));
}
- $subjectExtractor = new SubjectExtractor();
- $subjectExtractor->setWordCountVectorizer($wordCountVectorizer);
- $subjectExtractor->setTfidf($tfidfTransformer);
- return new NewCompositeExtractor(
- $this->container->get(VanillaCompositeExtractor::class),
- $subjectExtractor,
- );
- }
+ try {
+ /** @var CompositeExtractor $extractor */
+ $extractor = $this->container->get(CompositeExtractor::class);
+ } catch (ContainerExceptionInterface $e) {
+ throw new ServiceException('Failed to instantiate the composite extractor', 0, $e);
+ }
- private function getCacheKey(int $id): string {
- return "mail_classifier_$id";
+ $extractor->getSubjectExtractor()->setWordCountVectorizer($wordCountVectorizer);
+ $extractor->getSubjectExtractor()->setTfidf($tfidfTransformer);
+ return $extractor;
}
- private function getTransformerCacheKey(int $id, int $index): string {
- return $this->getCacheKey($id) . "_transformer_$index";
+ private function getCacheInstance(): ?ICache {
+ if (!$this->cacheFactory->isAvailable()) {
+ return null;
+ }
+
+ $version = self::VERSION;
+ return $this->cacheFactory->createDistributed("mail-classifier/v$version/");
}
/**
- * @param int $id
- * @param int $transformerCount
- *
- * @return (?string)[]|null Array of serialized classifier and transformers
+ * @return string[]|null Array of serialized classifier and transformers
*/
- private function getCached(int $id, int $transformerCount): ?array {
- // FIXME: Will always return null as the cached, serialized data is always an empty string.
- // See my note in self::cache() for further elaboration.
-
- if (!$this->cacheFactory->isLocalCacheAvailable()) {
+ private function getCached(string $id): ?array {
+ $cache = $this->getCacheInstance();
+ if ($cache === null) {
return null;
}
- $cache = $this->cacheFactory->createLocal();
- $values = [];
- $values[] = $cache->get($this->getCacheKey($id));
- for ($i = 0; $i < $transformerCount; $i++) {
- $values[] = $cache->get($this->getTransformerCacheKey($id, $i));
- }
-
- // Only return cached values if estimator and all transformers are available
- if (in_array(null, $values, true)) {
+ $json = $cache->get($id);
+ if (!is_string($json)) {
return null;
}
- return $values;
- }
-
- private function cache(int $id, string $serialized, array $serializedTransformers): void {
- // FIXME: This is broken as some cache implementations will run the provided value through
- // json_encode which drops non-utf8 strings. The serialized string contains binary
- // data so an empty string will be saved instead (tested on Redis).
- // Note: JSON requires strings to be valid utf8 (as per its spec).
-
- // IDEA: Implement a method ICache::setRaw() that forwards a raw/binary string as is to the
- // underlying cache backend.
-
- if (!$this->cacheFactory->isLocalCacheAvailable()) {
- return;
- }
- $cache = $this->cacheFactory->createLocal();
- $cache->set($this->getCacheKey($id), $serialized);
-
- $transformerIndex = 0;
- foreach ($serializedTransformers as $transformer) {
- $cache->set($this->getTransformerCacheKey($id, $transformerIndex), $transformer);
- $transformerIndex++;
- }
+ $serializedData = json_decode($json);
+ return array_map(base64_decode(...), $serializedData);
}
/**
- * Parse minor and major part of the given semver string.
- *
- * @return int[]
+ * @param string[] $serializedData Array of serialized classifier and transformers
*/
- private function parseAppVersion(string $version): array {
- $parts = explode('.', $version);
- if (count($parts) < 2) {
- return [0, 0];
+ private function setCached(string $id, array $serializedData): void {
+ $cache = $this->getCacheInstance();
+ if ($cache === null) {
+ return;
}
- return [(int)$parts[0], (int)$parts[1]];
+ // Serialized data contains binary, non-utf8 data so we encode it as base64 first
+ $encodedData = array_map(base64_encode(...), $serializedData);
+ $json = json_encode($encodedData, JSON_THROW_ON_ERROR);
+
+ // Set a ttl of a week because a new model will be generated daily
+ $cache->set($id, $json, 3600 * 24 * 7);
}
}
diff --git a/lib/Service/Classification/RubixMemoryPersister.php b/lib/Service/Classification/RubixMemoryPersister.php
new file mode 100644
index 0000000000..2b170b38b5
--- /dev/null
+++ b/lib/Service/Classification/RubixMemoryPersister.php
@@ -0,0 +1,41 @@
+data;
+ }
+
+ public function save(Encoding $encoding): void {
+ $this->data = $encoding->data();
+ }
+
+ public function load(): Encoding {
+ if ($this->data === null) {
+ throw new ValueError('Trying to load encoding when no data is available');
+ }
+
+ return new Encoding($this->data);
+ }
+
+ public function __toString() {
+ return (string)self::class;
+ }
+}
diff --git a/lib/Service/CleanupService.php b/lib/Service/CleanupService.php
index 65fa421200..74f5c0070a 100644
--- a/lib/Service/CleanupService.php
+++ b/lib/Service/CleanupService.php
@@ -17,7 +17,6 @@
use OCA\Mail\Db\MessageRetentionMapper;
use OCA\Mail\Db\MessageSnoozeMapper;
use OCA\Mail\Db\TagMapper;
-use OCA\Mail\Service\Classification\PersistenceService;
use OCA\Mail\Support\PerformanceLogger;
use OCP\AppFramework\Utility\ITimeFactory;
use Psr\Log\LoggerInterface;
@@ -44,7 +43,6 @@ class CleanupService {
private MessageSnoozeMapper $messageSnoozeMapper;
- private PersistenceService $classifierPersistenceService;
private ITimeFactory $timeFactory;
public function __construct(MailAccountMapper $mailAccountMapper,
@@ -55,7 +53,6 @@ public function __construct(MailAccountMapper $mailAccountMapper,
TagMapper $tagMapper,
MessageRetentionMapper $messageRetentionMapper,
MessageSnoozeMapper $messageSnoozeMapper,
- PersistenceService $classifierPersistenceService,
ITimeFactory $timeFactory) {
$this->aliasMapper = $aliasMapper;
$this->mailboxMapper = $mailboxMapper;
@@ -64,7 +61,6 @@ public function __construct(MailAccountMapper $mailAccountMapper,
$this->tagMapper = $tagMapper;
$this->messageRetentionMapper = $messageRetentionMapper;
$this->messageSnoozeMapper = $messageSnoozeMapper;
- $this->classifierPersistenceService = $classifierPersistenceService;
$this->mailAccountMapper = $mailAccountMapper;
$this->timeFactory = $timeFactory;
}
@@ -92,8 +88,6 @@ public function cleanUp(LoggerInterface $logger): void {
$task->step('delete expired messages');
$this->messageSnoozeMapper->deleteOrphans();
$task->step('delete orphan snoozes');
- $this->classifierPersistenceService->cleanUp();
- $task->step('delete orphan classifiers');
$task->end();
}
}