From 3c9bdae971cc533bda620daa4561d30d7ed2dd9c Mon Sep 17 00:00:00 2001 From: Richard Steinmetz Date: Mon, 21 Oct 2024 12:15:29 +0200 Subject: [PATCH] Persist classifiers in memory cache only --- appinfo/info.xml | 2 +- lib/Command/RunMetaEstimator.php | 15 +- lib/Command/TrainAccount.php | 26 +- lib/Db/Classifier.php | 106 ---- lib/Db/ClassifierMapper.php | 57 --- .../Version4100Date20241021091352.php | 29 ++ lib/Model/Classifier.php | 134 +++++ lib/Model/ClassifierPipeline.php | 19 +- .../FeatureExtraction/CompositeExtractor.php | 40 +- .../NewCompositeExtractor.php | 24 - .../FeatureExtraction/SubjectExtractor.php | 2 +- .../VanillaCompositeExtractor.php | 24 - .../Classification/ImportanceClassifier.php | 74 ++- .../Classification/PersistenceService.php | 463 ++++-------------- .../Classification/RubixMemoryPersister.php | 41 ++ lib/Service/CleanupService.php | 6 - 16 files changed, 396 insertions(+), 666 deletions(-) delete mode 100644 lib/Db/Classifier.php delete mode 100644 lib/Db/ClassifierMapper.php create mode 100644 lib/Migration/Version4100Date20241021091352.php create mode 100644 lib/Model/Classifier.php delete mode 100644 lib/Service/Classification/FeatureExtraction/NewCompositeExtractor.php delete mode 100644 lib/Service/Classification/FeatureExtraction/VanillaCompositeExtractor.php create mode 100644 lib/Service/Classification/RubixMemoryPersister.php diff --git a/appinfo/info.xml b/appinfo/info.xml index b0e0c1e72a..7008982213 100644 --- a/appinfo/info.xml +++ b/appinfo/info.xml @@ -34,7 +34,7 @@ The rating depends on the installed text processing backend. See [the rating ove Learn more about the Nextcloud Ethical AI Rating [in our blog](https://nextcloud.com/blog/nextcloud-ethical-ai-rating/). ]]> - 4.1.0-alpha.2 + 4.1.0-alpha.3 agpl Christoph Wurst GretaD diff --git a/lib/Command/RunMetaEstimator.php b/lib/Command/RunMetaEstimator.php index 8a00c1bafe..32d1187d37 100644 --- a/lib/Command/RunMetaEstimator.php +++ b/lib/Command/RunMetaEstimator.php @@ -10,7 +10,7 @@ namespace OCA\Mail\Command; use OCA\Mail\Service\AccountService; -use OCA\Mail\Service\Classification\FeatureExtraction\NewCompositeExtractor; +use OCA\Mail\Service\Classification\FeatureExtraction\CompositeExtractor; use OCA\Mail\Service\Classification\ImportanceClassifier; use OCA\Mail\Support\ConsoleLoggerDecorator; use OCP\AppFramework\Db\DoesNotExistException; @@ -86,8 +86,8 @@ protected function execute(InputInterface $input, OutputInterface $output): int return 1; } - /** @var NewCompositeExtractor $extractor */ - $extractor = $this->container->get(NewCompositeExtractor::class); + /** @var CompositeExtractor $extractor */ + $extractor = $this->container->get(CompositeExtractor::class); $consoleLogger = new ConsoleLoggerDecorator( $this->logger, $output @@ -124,8 +124,9 @@ protected function execute(InputInterface $input, OutputInterface $output): int return $estimator; }; + /** @var GridSearch $metaEstimator */ if ($dataSet) { - $this->classifier->trainWithCustomDataSet( + $metaEstimator = $this->classifier->trainWithCustomDataSet( $account, $consoleLogger, $dataSet, @@ -135,7 +136,7 @@ protected function execute(InputInterface $input, OutputInterface $output): int false, ); } else { - $this->classifier->train( + $metaEstimator = $this->classifier->train( $account, $consoleLogger, $extractor, @@ -145,6 +146,10 @@ protected function execute(InputInterface $input, OutputInterface $output): int ); } + if ($metaEstimator) { + $output->writeln("Best estimator: {$metaEstimator->base()}"); + } + $mbs = (int)(memory_get_peak_usage() / 1024 / 1024); $output->writeln('' . $mbs . 'MB of memory used'); return 0; diff --git a/lib/Command/TrainAccount.php b/lib/Command/TrainAccount.php index 33ce44371d..12866de7f3 100644 --- a/lib/Command/TrainAccount.php +++ b/lib/Command/TrainAccount.php @@ -3,7 +3,7 @@ declare(strict_types=1); /** - * SPDX-FileCopyrightText: 2019 Nextcloud GmbH and Nextcloud contributors + * SPDX-FileCopyrightText: 2019-2024 Nextcloud GmbH and Nextcloud contributors * SPDX-License-Identifier: AGPL-3.0-or-later */ @@ -11,15 +11,13 @@ use OCA\Mail\Service\AccountService; use OCA\Mail\Service\Classification\ClassificationSettingsService; +use OCA\Mail\Service\Classification\FeatureExtraction\CompositeExtractor; use OCA\Mail\Service\Classification\FeatureExtraction\IExtractor; -use OCA\Mail\Service\Classification\FeatureExtraction\NewCompositeExtractor; -use OCA\Mail\Service\Classification\FeatureExtraction\VanillaCompositeExtractor; use OCA\Mail\Service\Classification\ImportanceClassifier; use OCA\Mail\Support\ConsoleLoggerDecorator; use OCP\AppFramework\Db\DoesNotExistException; use Psr\Container\ContainerInterface; use Psr\Log\LoggerInterface; -use Rubix\ML\Classifiers\GaussianNB; use Symfony\Component\Console\Command\Command; use Symfony\Component\Console\Input\InputArgument; use Symfony\Component\Console\Input\InputInterface; @@ -110,9 +108,6 @@ protected function execute(InputInterface $input, OutputInterface $output): int $shuffle = (bool)$input->getOption(self::ARGUMENT_SHUFFLE); $dryRun = (bool)$input->getOption(self::ARGUMENT_DRY_RUN); $force = (bool)$input->getOption(self::ARGUMENT_FORCE); - $old = (bool)$input->getOption(self::ARGUMENT_OLD); - $oldEstimator = $old || $input->getOption(self::ARGUMENT_OLD_ESTIMATOR); - $oldExtractor = $old || $input->getOption(self::ARGUMENT_OLD_EXTRACTOR); try { $account = $this->accountService->findById($accountId); @@ -127,18 +122,7 @@ protected function execute(InputInterface $input, OutputInterface $output): int } /** @var IExtractor $extractor */ - if ($oldExtractor) { - $extractor = $this->container->get(VanillaCompositeExtractor::class); - } else { - $extractor = $this->container->get(NewCompositeExtractor::class); - } - - $estimator = null; - if ($oldEstimator) { - $estimator = static function () { - return new GaussianNB(); - }; - } + $extractor = $this->container->get(CompositeExtractor::class); $consoleLogger = new ConsoleLoggerDecorator( $this->logger, @@ -167,7 +151,7 @@ protected function execute(InputInterface $input, OutputInterface $output): int $consoleLogger, $dataSet, $extractor, - $estimator, + null, null, !$dryRun ); @@ -176,7 +160,7 @@ protected function execute(InputInterface $input, OutputInterface $output): int $account, $consoleLogger, $extractor, - $estimator, + null, $shuffle, !$dryRun ); diff --git a/lib/Db/Classifier.php b/lib/Db/Classifier.php deleted file mode 100644 index 126f986730..0000000000 --- a/lib/Db/Classifier.php +++ /dev/null @@ -1,106 +0,0 @@ -addType('accountId', 'int'); - $this->addType('type', 'string'); - $this->addType('appVersion', 'string'); - $this->addType('trainingSetSize', 'int'); - $this->addType('validationSetSize', 'int'); - $this->addType('recallImportant', 'float'); - $this->addType('precisionImportant', 'float'); - $this->addType('f1ScoreImportant', 'float'); - $this->addType('duration', 'int'); - $this->addType('active', 'boolean'); - $this->addType('createdAt', 'int'); - } -} diff --git a/lib/Db/ClassifierMapper.php b/lib/Db/ClassifierMapper.php deleted file mode 100644 index 946b70f8b0..0000000000 --- a/lib/Db/ClassifierMapper.php +++ /dev/null @@ -1,57 +0,0 @@ - - */ -class ClassifierMapper extends QBMapper { - public function __construct(IDBConnection $db) { - parent::__construct($db, 'mail_classifiers'); - } - - /** - * @param int $id - * - * @return Classifier - * @throws DoesNotExistException - */ - public function findLatest(int $id): Classifier { - $qb = $this->db->getQueryBuilder(); - - $select = $qb->select('*') - ->from($this->getTableName()) - ->where( - $qb->expr()->eq('account_id', $qb->createNamedParameter($id, IQueryBuilder::PARAM_INT), IQueryBuilder::PARAM_INT), - $qb->expr()->eq('active', $qb->createNamedParameter(true, IQueryBuilder::PARAM_BOOL), IQueryBuilder::PARAM_BOOL) - ) - ->orderBy('created_at', 'desc') - ->setMaxResults(1); - - return $this->findEntity($select); - } - - public function findHistoric(int $threshold, int $limit) { - $qb = $this->db->getQueryBuilder(); - $select = $qb->select('*') - ->from($this->getTableName()) - ->where( - $qb->expr()->lte('created_at', $qb->createNamedParameter($threshold, IQueryBuilder::PARAM_INT), IQueryBuilder::PARAM_INT), - ) - ->orderBy('created_at', 'asc') - ->setMaxResults($limit); - return $this->findEntities($select); - } -} diff --git a/lib/Migration/Version4100Date20241021091352.php b/lib/Migration/Version4100Date20241021091352.php new file mode 100644 index 0000000000..ccdf8be264 --- /dev/null +++ b/lib/Migration/Version4100Date20241021091352.php @@ -0,0 +1,29 @@ +dropTable('mail_classifiers'); + return $schema; + } +} diff --git a/lib/Model/Classifier.php b/lib/Model/Classifier.php new file mode 100644 index 0000000000..df4d21eeb1 --- /dev/null +++ b/lib/Model/Classifier.php @@ -0,0 +1,134 @@ +accountId; + } + + public function setAccountId(int $accountId): void { + $this->accountId = $accountId; + } + + public function getType(): string { + return $this->type; + } + + public function setType(string $type): void { + $this->type = $type; + } + + public function getEstimator(): string { + return $this->estimator; + } + + public function setEstimator(string $estimator): void { + $this->estimator = $estimator; + } + + public function getPersistenceVersion(): int { + return $this->persistenceVersion; + } + + public function setPersistenceVersion(int $persistenceVersion): void { + $this->persistenceVersion = $persistenceVersion; + } + + public function getTrainingSetSize(): int { + return $this->trainingSetSize; + } + + public function setTrainingSetSize(int $trainingSetSize): void { + $this->trainingSetSize = $trainingSetSize; + } + + public function getValidationSetSize(): int { + return $this->validationSetSize; + } + + public function setValidationSetSize(int $validationSetSize): void { + $this->validationSetSize = $validationSetSize; + } + + public function getRecallImportant(): float { + return $this->recallImportant; + } + + public function setRecallImportant(float $recallImportant): void { + $this->recallImportant = $recallImportant; + } + + public function getPrecisionImportant(): float { + return $this->precisionImportant; + } + + public function setPrecisionImportant(float $precisionImportant): void { + $this->precisionImportant = $precisionImportant; + } + + public function getF1ScoreImportant(): float { + return $this->f1ScoreImportant; + } + + public function setF1ScoreImportant(float $f1ScoreImportant): void { + $this->f1ScoreImportant = $f1ScoreImportant; + } + + public function getDuration(): int { + return $this->duration; + } + + public function setDuration(int $duration): void { + $this->duration = $duration; + } + + public function getCreatedAt(): int { + return $this->createdAt; + } + + public function setCreatedAt(int $createdAt): void { + $this->createdAt = $createdAt; + } + + #[ReturnTypeWillChange] + public function jsonSerialize() { + return [ + 'accountId' => $this->accountId, + 'type' => $this->type, + 'estimator' => $this->estimator, + 'persistenceVersion' => $this->persistenceVersion, + 'trainingSetSize' => $this->trainingSetSize, + 'validationSetSize' => $this->validationSetSize, + 'recallImportant' => $this->recallImportant, + 'precisionImportant' => $this->precisionImportant, + 'f1ScoreImportant' => $this->f1ScoreImportant, + 'duration' => $this->duration, + 'createdAt' => $this->createdAt, + ]; + } +} diff --git a/lib/Model/ClassifierPipeline.php b/lib/Model/ClassifierPipeline.php index eca0098195..bcc1be893e 100644 --- a/lib/Model/ClassifierPipeline.php +++ b/lib/Model/ClassifierPipeline.php @@ -9,28 +9,29 @@ namespace OCA\Mail\Model; +use OCA\Mail\Service\Classification\FeatureExtraction\IExtractor; use Rubix\ML\Estimator; use Rubix\ML\Transformers\Transformer; class ClassifierPipeline { - private Estimator $estimator; - - /** @var Transformer[] */ - private array $transformers; - /** - * @param Estimator $estimator * @param Transformer[] $transformers */ - public function __construct(Estimator $estimator, array $transformers) { - $this->estimator = $estimator; - $this->transformers = $transformers; + public function __construct( + private Estimator $estimator, + private IExtractor $extractor, + private array $transformers, + ) { } public function getEstimator(): Estimator { return $this->estimator; } + public function getExtractor(): IExtractor { + return $this->extractor; + } + /** * @return Transformer[] */ diff --git a/lib/Service/Classification/FeatureExtraction/CompositeExtractor.php b/lib/Service/Classification/FeatureExtraction/CompositeExtractor.php index 197d0f2eb4..aaa30dcbfe 100644 --- a/lib/Service/Classification/FeatureExtraction/CompositeExtractor.php +++ b/lib/Service/Classification/FeatureExtraction/CompositeExtractor.php @@ -3,7 +3,7 @@ declare(strict_types=1); /** - * SPDX-FileCopyrightText: 2020 Nextcloud GmbH and Nextcloud contributors + * SPDX-FileCopyrightText: 2020-2024 Nextcloud GmbH and Nextcloud contributors * SPDX-License-Identifier: AGPL-3.0-or-later */ @@ -11,20 +11,34 @@ use OCA\Mail\Account; use OCA\Mail\Db\Message; +use Rubix\ML\Transformers\TfIdfTransformer; +use Rubix\ML\Transformers\WordCountVectorizer; use function OCA\Mail\array_flat_map; /** * Combines a set of DI'ed extractors so they can be used as one class */ -abstract class CompositeExtractor implements IExtractor { - /** @var IExtractor[] */ - protected array $extractors; +class CompositeExtractor implements IExtractor { + private readonly SubjectExtractor $subjectExtractor; - /** - * @param IExtractor[] $extractors - */ - public function __construct(array $extractors) { - $this->extractors = $extractors; + /** @var IExtractor[] */ + private readonly array $extractors; + + public function __construct( + ImportantMessagesExtractor $ex1, + ReadMessagesExtractor $ex2, + RepliedMessagesExtractor $ex3, + SentMessagesExtractor $ex4, + SubjectExtractor $ex5, + ) { + $this->subjectExtractor = $ex5; + $this->extractors = [ + $ex1, + $ex2, + $ex3, + $ex4, + $ex5, + ]; } public function prepare(Account $account, @@ -36,12 +50,14 @@ public function prepare(Account $account, } } - /** - * @inheritDoc - */ public function extract(Message $message): array { return array_flat_map(static function (IExtractor $extractor) use ($message) { return $extractor->extract($message); }, $this->extractors); } + + public function getSubjectExtractor(): SubjectExtractor { + return $this->subjectExtractor; + } } + diff --git a/lib/Service/Classification/FeatureExtraction/NewCompositeExtractor.php b/lib/Service/Classification/FeatureExtraction/NewCompositeExtractor.php deleted file mode 100644 index b7d86622ca..0000000000 --- a/lib/Service/Classification/FeatureExtraction/NewCompositeExtractor.php +++ /dev/null @@ -1,24 +0,0 @@ -subjectExtractor = $ex2; - } - - public function getSubjectExtractor(): SubjectExtractor { - return $this->subjectExtractor; - } -} diff --git a/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php b/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php index b768938139..c5ecf6681f 100644 --- a/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php +++ b/lib/Service/Classification/FeatureExtraction/SubjectExtractor.php @@ -49,7 +49,7 @@ public function setWordCountVectorizer(WordCountVectorizer $wordCountVectorizer) $this->limitFeatureSize(); } - public function getTfidf(): Transformer { + public function getTfidf(): TfIdfTransformer { return $this->tfidf; } diff --git a/lib/Service/Classification/FeatureExtraction/VanillaCompositeExtractor.php b/lib/Service/Classification/FeatureExtraction/VanillaCompositeExtractor.php deleted file mode 100644 index 09351f33f6..0000000000 --- a/lib/Service/Classification/FeatureExtraction/VanillaCompositeExtractor.php +++ /dev/null @@ -1,24 +0,0 @@ -mailboxMapper = $mailboxMapper; $this->messageMapper = $messageMapper; $this->persistenceService = $persistenceService; $this->performanceLogger = $performanceLogger; $this->rulesClassifier = $rulesClassifier; - $this->vanillaExtractor = $vanillaExtractor; $this->container = $container; } @@ -209,6 +204,8 @@ public function buildDataSet( * @param bool $shuffleDataSet Shuffle the data set before training * @param bool $persist Persist the trained classifier to use it for message classification * + * @return Estimator|null The validation estimator, persisted estimator (if `$persist` === true) or null in case none was trained + * * @throws ServiceException */ public function train( @@ -218,12 +215,12 @@ public function train( ?Closure $estimator = null, bool $shuffleDataSet = false, bool $persist = true, - ): void { + ): ?Estimator { $perf = $this->performanceLogger->start('importance classifier training'); if ($extractor === null) { try { - $extractor = $this->container->get(NewCompositeExtractor::class); + $extractor = $this->container->get(CompositeExtractor::class); } catch (ContainerExceptionInterface $e) { throw new ServiceException('Default extractor is not available', 0, $e); } @@ -231,10 +228,10 @@ public function train( $dataSet = $this->buildDataSet($account, $extractor, $logger, $perf, $shuffleDataSet); if ($dataSet === null) { - return; + return null; } - $this->trainWithCustomDataSet( + return $this->trainWithCustomDataSet( $account, $logger, $dataSet, @@ -256,24 +253,21 @@ public function train( * @param PerformanceLoggerTask|null $perf Optionally reuse a performance logger task * @param bool $persist Persist the trained classifier to use it for message classification * + * @return Estimator|null The validation estimator, persisted estimator (if `$persist` === true) or null in case none was trained + * * @throws ServiceException */ public function trainWithCustomDataSet( Account $account, LoggerInterface $logger, array $dataSet, - IExtractor $extractor, + CompositeExtractor $extractor, ?Closure $estimator, ?PerformanceLoggerTask $perf = null, bool $persist = true, - ): void { + ): ?Estimator { $perf ??= $this->performanceLogger->start('importance classifier training'); - - if ($estimator === null) { - $estimator = static function () { - return self::createDefaultEstimator(); - }; - } + $estimator ??= self::createDefaultEstimator(...); /** * How many of the most recent messages are excluded from training? @@ -303,7 +297,7 @@ public function trainWithCustomDataSet( if ($validationSet === [] || $trainingSet === []) { $logger->info('not enough messages to train a classifier'); $perf->end(); - return; + return null; } /** @var Learner&Estimator&Persistable $validationEstimator */ @@ -321,30 +315,28 @@ public function trainWithCustomDataSet( 'exception' => $e, ]); $perf->end(); - return; + return null; } $perf->step('train and validate classifier with training and validation sets'); - if ($persist) { - /** @var Learner&Estimator&Persistable $persistedEstimator */ - $persistedEstimator = $estimator(); - $this->trainClassifier($persistedEstimator, $dataSet); - $perf->step('train classifier with full data set'); - - // Extract persisted transformers of the subject extractor. - // Is a bit hacky but a full abstraction would be overkill. - /** @var (Transformer&Persistable)[] $transformers */ - $transformers = []; - if ($extractor instanceof NewCompositeExtractor) { - $transformers[] = $extractor->getSubjectExtractor()->getWordCountVectorizer(); - $transformers[] = $extractor->getSubjectExtractor()->getTfidf(); - } - - $classifier->setAccountId($account->getId()); - $classifier->setDuration($perf->end()); - $this->persistenceService->persist($classifier, $persistedEstimator, $transformers); - $logger->debug("classifier {$classifier->getId()} persisted"); + if (!$persist) { + return $validationEstimator; } + + /** @var Learner&Estimator&Persistable $persistedEstimator */ + $persistedEstimator = $estimator(); + $this->trainClassifier($persistedEstimator, $dataSet); + $perf->step('train classifier with full data set'); + $classifier->setDuration($perf->end()); + $classifier->setAccountId($account->getId()); + $classifier->setEstimator(get_class($persistedEstimator)); + $classifier->setPersistenceVersion(PersistenceService::VERSION); + + $this->persistenceService->persist($account, $persistedEstimator, $extractor); + $logger->debug("Classifier for account {$account->getId()} persisted", [ + 'classifier' => $classifier, + ]); + return $persistedEstimator; } diff --git a/lib/Service/Classification/PersistenceService.php b/lib/Service/Classification/PersistenceService.php index 4b4eea5345..120136b27e 100644 --- a/lib/Service/Classification/PersistenceService.php +++ b/lib/Service/Classification/PersistenceService.php @@ -3,162 +3,78 @@ declare(strict_types=1); /** - * SPDX-FileCopyrightText: 2020 Nextcloud GmbH and Nextcloud contributors + * SPDX-FileCopyrightText: 2020-2024 Nextcloud GmbH and Nextcloud contributors * SPDX-License-Identifier: AGPL-3.0-or-later */ namespace OCA\Mail\Service\Classification; -use OCA\DAV\Connector\Sabre\File; use OCA\Mail\Account; -use OCA\Mail\AppInfo\Application; -use OCA\Mail\Db\Classifier; -use OCA\Mail\Db\ClassifierMapper; -use OCA\Mail\Db\MailAccountMapper; use OCA\Mail\Exception\ServiceException; use OCA\Mail\Model\ClassifierPipeline; +use OCA\Mail\Service\Classification\FeatureExtraction\CompositeExtractor; use OCA\Mail\Service\Classification\FeatureExtraction\IExtractor; -use OCA\Mail\Service\Classification\FeatureExtraction\NewCompositeExtractor; -use OCA\Mail\Service\Classification\FeatureExtraction\SubjectExtractor; -use OCA\Mail\Service\Classification\FeatureExtraction\VanillaCompositeExtractor; -use OCP\App\IAppManager; -use OCP\AppFramework\Db\DoesNotExistException; -use OCP\AppFramework\Utility\ITimeFactory; -use OCP\Files; -use OCP\Files\IAppData; -use OCP\Files\NotFoundException; -use OCP\Files\NotPermittedException; +use OCP\ICache; use OCP\ICacheFactory; -use OCP\ITempManager; use Psr\Container\ContainerExceptionInterface; use Psr\Container\ContainerInterface; -use Psr\Log\LoggerInterface; use Rubix\ML\Learner; use Rubix\ML\Persistable; use Rubix\ML\PersistentModel; -use Rubix\ML\Persisters\Filesystem; use Rubix\ML\Serializers\RBX; use Rubix\ML\Transformers\TfIdfTransformer; use Rubix\ML\Transformers\Transformer; use Rubix\ML\Transformers\WordCountVectorizer; use RuntimeException; -use function file_get_contents; -use function file_put_contents; use function get_class; -use function strlen; class PersistenceService { - private const ADD_DATA_FOLDER = 'classifiers'; + // Increment the version when changing the classifier or transformer pipeline + public const VERSION = 1; - /** @var ClassifierMapper */ - private $mapper; - - /** @var IAppData */ - private $appData; - - /** @var ITempManager */ - private $tempManager; - - /** @var ITimeFactory */ - private $timeFactory; - - /** @var IAppManager */ - private $appManager; - - /** @var ICacheFactory */ - private $cacheFactory; - - /** @var LoggerInterface */ - private $logger; - - /** @var MailAccountMapper */ - private $accountMapper; - - private ContainerInterface $container; - - public function __construct(ClassifierMapper $mapper, - IAppData $appData, - ITempManager $tempManager, - ITimeFactory $timeFactory, - IAppManager $appManager, - ICacheFactory $cacheFactory, - LoggerInterface $logger, - MailAccountMapper $accountMapper, - ContainerInterface $container) { - $this->mapper = $mapper; - $this->appData = $appData; - $this->tempManager = $tempManager; - $this->timeFactory = $timeFactory; - $this->appManager = $appManager; - $this->cacheFactory = $cacheFactory; - $this->logger = $logger; - $this->accountMapper = $accountMapper; - $this->container = $container; + public function __construct( + private readonly ICacheFactory $cacheFactory, + private readonly ContainerInterface $container, + ) { } /** - * Persist the classifier data to the database, the estimator and its transformers to storage + * Persist classifier, estimator and its transformers to the memory cache. * - * @param Classifier $classifier * @param Learner&Persistable $estimator - * @param (Transformer&Persistable)[] $transformers * - * @throws ServiceException + * @throws ServiceException If any serialization fails */ - public function persist(Classifier $classifier, + public function persist( + Account $account, Learner $estimator, - array $transformers): void { - /* - * First we have to insert the row to get the unique ID, but disable - * it until the model is persisted as well. Otherwise another process - * might try to load the model in the meantime and run into an error - * due to the missing data in app data. - */ - $classifier->setAppVersion($this->appManager->getAppVersion(Application::APP_ID)); - $classifier->setEstimator(get_class($estimator)); - $classifier->setActive(false); - $classifier->setCreatedAt($this->timeFactory->getTime()); - $this->mapper->insert($classifier); + CompositeExtractor $extractor, + ): void { + $serializedData = []; /* - * Then we serialize the estimator into a temporary file + * First we serialize the estimator */ - $tmpPath = $this->tempManager->getTemporaryFile(); try { - $model = new PersistentModel($estimator, new Filesystem($tmpPath)); + $persister = new RubixMemoryPersister(); + $model = new PersistentModel($estimator, $persister); $model->save(); - $serializedClassifier = file_get_contents($tmpPath); - $this->logger->debug('Serialized classifier written to tmp file (' . strlen($serializedClassifier) . 'B'); + $serializedData[] = $persister->getData(); } catch (RuntimeException $e) { throw new ServiceException('Could not serialize classifier: ' . $e->getMessage(), 0, $e); } /* - * Then we store the serialized model to app data - */ - try { - try { - $folder = $this->appData->getFolder(self::ADD_DATA_FOLDER); - $this->logger->debug('Using existing folder for the serialized classifier'); - } catch (NotFoundException $e) { - $folder = $this->appData->newFolder(self::ADD_DATA_FOLDER); - $this->logger->debug('New folder created for serialized classifiers'); - } - $file = $folder->newFile((string)$classifier->getId()); - $file->putContent($serializedClassifier); - $this->logger->debug('Serialized classifier written to app data'); - } catch (NotPermittedException|NotFoundException $e) { - throw new ServiceException('Could not create classifiers directory: ' . $e->getMessage(), 0, $e); - } - - /* - * Then we serialize the transformer pipeline to temporary files + * Then we serialize the transformer pipeline */ - $transformerIndex = 0; + $transfomers = [ + $extractor->getSubjectExtractor()->getWordCountVectorizer(), + $extractor->getSubjectExtractor()->getTfIdf(), + ]; $serializer = new RBX(); - foreach ($transformers as $transformer) { - $tmpPath = $this->tempManager->getTemporaryFile(); + foreach ($transfomers as $transformer) { try { + $persister = new RubixMemoryPersister(); /** * This is how to serialize a transformer according to the official docs. * PersistentModel can only be used on Learners which transformers don't implement. @@ -167,316 +83,145 @@ public function persist(Classifier $classifier, * * @psalm-suppress InternalMethod */ - $serializer->serialize($transformer)->saveTo(new Filesystem($tmpPath)); - $serializedTransformer = file_get_contents($tmpPath); - $this->logger->debug('Serialized transformer written to tmp file (' . strlen($serializedTransformer) . 'B'); + $serializer->serialize($transformer)->saveTo($persister); + $serializedData[] = $persister->getData(); } catch (RuntimeException $e) { throw new ServiceException('Could not serialize transformer: ' . $e->getMessage(), 0, $e); } - - try { - $file = $folder->newFile("{$classifier->getId()}_t$transformerIndex"); - $file->putContent($serializedTransformer); - $this->logger->debug("Serialized transformer $transformerIndex written to app data"); - } catch (NotPermittedException|NotFoundException $e) { - throw new ServiceException( - "Failed to persist transformer $transformerIndex: " . $e->getMessage(), - 0, - $e - ); - } - - $transformerIndex++; } - /* - * Now we set the model active so it can be used by the next request - */ - $classifier->setActive(true); - $this->mapper->update($classifier); + $this->setCached((string)$account->getId(), $serializedData); } /** - * @param Account $account - * - * @return ?array [Estimator, IExtractor] + * Load the latest estimator and its transformers. * - * @throws ServiceException + * @throws ServiceException If any deserialization fails */ - public function loadLatest(Account $account): ?array { - try { - $latestModel = $this->mapper->findLatest($account->getId()); - } catch (DoesNotExistException $e) { + public function loadLatest(Account $account): ?ClassifierPipeline { + $cached = $this->getCached((string)$account->getId()); + if ($cached == null) { return null; } - $pipeline = $this->load($latestModel); + $serializedModel = $cached[0]; + $serializedTransformers = array_slice($cached, 1); try { - $extractor = $this->loadExtractor($latestModel, $pipeline); - } catch (ContainerExceptionInterface $e) { + $estimator = PersistentModel::load(new RubixMemoryPersister($serializedModel)); + } catch (RuntimeException $e) { throw new ServiceException( - "Failed to load extractor: {$e->getMessage()}", + 'Could not deserialize persisted classifier: ' . $e->getMessage(), 0, $e, ); } - return [$pipeline->getEstimator(), $extractor]; - } - - /** - * Load an estimator and its transformers of a classifier from storage - * - * @param Classifier $classifier - * @return ClassifierPipeline - * - * @throws ServiceException - */ - public function load(Classifier $classifier): ClassifierPipeline { - $transformerCount = 0; - $appVersion = $this->parseAppVersion($classifier->getAppVersion()); - if ($appVersion[0] >= 3 && $appVersion[1] >= 2) { - $transformerCount = 2; - } - - $id = $classifier->getId(); - $cached = $this->getCached($classifier->getId(), $transformerCount); - if ($cached !== null) { - $this->logger->debug("Using cached serialized classifier $id"); - $serialized = $cached[0]; - $serializedTransformers = array_slice($cached, 1); - } else { - $this->logger->debug('Loading serialized classifier from app data'); - try { - $modelsFolder = $this->appData->getFolder(self::ADD_DATA_FOLDER); - $modelFile = $modelsFolder->getFile((string)$id); - } catch (NotFoundException $e) { - $this->logger->debug("Could not load classifier $id: " . $e->getMessage()); - throw new ServiceException("Could not load classifier $id: " . $e->getMessage(), 0, $e); - } - - try { - $serialized = $modelFile->getContent(); - } catch (NotFoundException|NotPermittedException $e) { - $this->logger->debug("Could not load content for model file with classifier id $id: " . $e->getMessage()); - throw new ServiceException("Could not load content for model file with classifier id $id: " . $e->getMessage(), 0, $e); - } - $size = strlen($serialized); - $this->logger->debug("Serialized classifier loaded (size=$size)"); - - $serializedTransformers = []; - for ($i = 0; $i < $transformerCount; $i++) { - try { - $transformerFile = $modelsFolder->getFile("{$id}_t$i"); - } catch (NotFoundException $e) { - $this->logger->debug("Could not load transformer $i of classifier $id: " . $e->getMessage()); - throw new ServiceException("Could not load transformer $i of classifier $id: " . $e->getMessage(), 0, $e); - } - - try { - $serializedTransformer = $transformerFile->getContent(); - } catch (NotFoundException|NotPermittedException $e) { - $this->logger->debug("Could not load content for transformer file $i with classifier id $id: " . $e->getMessage()); - throw new ServiceException("Could not load content for transformer file $i with classifier id $id: " . $e->getMessage(), 0, $e); - } - $size = strlen($serializedTransformer); - $this->logger->debug("Serialized transformer $i loaded (size=$size)"); - $serializedTransformers[] = $serializedTransformer; - } - - $this->cache($id, $serialized, $serializedTransformers); - } - - $tmpPath = $this->tempManager->getTemporaryFile(); - file_put_contents($tmpPath, $serialized); - try { - $estimator = PersistentModel::load(new Filesystem($tmpPath)); - } catch (RuntimeException $e) { - throw new ServiceException("Could not deserialize persisted classifier $id: " . $e->getMessage(), 0, $e); - } - - $transformers = array_map(function (string $serializedTransformer) use ($id) { - $serializer = new RBX(); - $tmpPath = $this->tempManager->getTemporaryFile(); - file_put_contents($tmpPath, $serializedTransformer); + $serializer = new RBX(); + $transformers = array_map(function (string $serializedTransformer) use ($serializer) { try { - $persister = new Filesystem($tmpPath); + $persister = new RubixMemoryPersister($serializedTransformer); $transformer = $persister->load()->deserializeWith($serializer); } catch (RuntimeException $e) { - throw new ServiceException("Could not deserialize persisted transformer of classifier $id: " . $e->getMessage(), 0, $e); + throw new ServiceException( + 'Could not deserialize persisted transformer of classifier: ' . $e->getMessage(), + 0, + $e, + ); } if (!($transformer instanceof Transformer)) { - throw new ServiceException("Transformer of classifier $id is not a transformer: Got " . $transformer::class); + throw new ServiceException(sprintf( + 'Transformer is not an instance of %s: Got %s', + Transformer::class, + get_class($transformer), + )); } return $transformer; }, $serializedTransformers); - return new ClassifierPipeline($estimator, $transformers); - } - - public function cleanUp(): void { - $threshold = $this->timeFactory->getTime() - 30 * 24 * 60 * 60; - $totalAccounts = $this->accountMapper->getTotal(); - $classifiers = $this->mapper->findHistoric($threshold, $totalAccounts * 10); - foreach ($classifiers as $classifier) { - try { - $this->deleteModel($classifier->getId()); - $this->mapper->delete($classifier); - } catch (NotPermittedException $e) { - // Log and continue. This is not critical - $this->logger->warning('Could not clean-up old classifier', [ - 'id' => $classifier->getId(), - 'exception' => $e, - ]); - } - } - } - - /** - * @throws NotPermittedException - */ - private function deleteModel(int $id): void { - $this->logger->debug('Deleting serialized classifier from app data', [ - 'id' => $id, - ]); - try { - $modelsFolder = $this->appData->getFolder(self::ADD_DATA_FOLDER); - $modelFile = $modelsFolder->getFile((string)$id); - $modelFile->delete(); - } catch (NotFoundException $e) { - $this->logger->debug("Classifier model $id does not exist", [ - 'exception' => $e, - ]); - } - } - - /** - * Load and instantiate extractor based on a classifier's app version. - * - * @param Classifier $classifier - * @param ClassifierPipeline $pipeline - * @return IExtractor - * - * @throws ContainerExceptionInterface - * @throws ServiceException - */ - private function loadExtractor(Classifier $classifier, - ClassifierPipeline $pipeline): IExtractor { - $appVersion = $this->parseAppVersion($classifier->getAppVersion()); - if ($appVersion[0] >= 3 && $appVersion[1] >= 2) { - return $this->loadExtractorV2($pipeline->getTransformers()); - } - - return $this->loadExtractorV1($pipeline->getTransformers()); - } + $extractor = $this->loadExtractor($transformers); - /** - * @return VanillaCompositeExtractor - * - * @throws ContainerExceptionInterface - */ - private function loadExtractorV1(): VanillaCompositeExtractor { - return $this->container->get(VanillaCompositeExtractor::class); + return new ClassifierPipeline($estimator, $extractor, $transformers); } /** - * @param Transformer[] $transformers - * @return NewCompositeExtractor + * Load and instantiate extractor based on the given transformers. * - * @throws ContainerExceptionInterface - * @throws ServiceException + * @throws ServiceException If the transformers array contains unexpected instances or the composite extractor can't be instantiated */ - private function loadExtractorV2(array $transformers): NewCompositeExtractor { + private function loadExtractor(array $transformers): IExtractor { $wordCountVectorizer = $transformers[0]; if (!($wordCountVectorizer instanceof WordCountVectorizer)) { - throw new ServiceException('Failed to load persisted transformer: Expected ' . WordCountVectorizer::class . ', got' . $wordCountVectorizer::class); + throw new ServiceException(sprintf( + 'Failed to load persisted transformer: Expected %s, got %s', + WordCountVectorizer::class, + get_class($wordCountVectorizer), + )); } + $tfidfTransformer = $transformers[1]; if (!($tfidfTransformer instanceof TfIdfTransformer)) { - throw new ServiceException('Failed to load persisted transformer: Expected ' . TfIdfTransformer::class . ', got' . $tfidfTransformer::class); + throw new ServiceException(sprintf( + 'Failed to load persisted transformer: Expected %s, got %s', + TfIdfTransformer::class, + get_class($tfidfTransformer), + )); } - $subjectExtractor = new SubjectExtractor(); - $subjectExtractor->setWordCountVectorizer($wordCountVectorizer); - $subjectExtractor->setTfidf($tfidfTransformer); - return new NewCompositeExtractor( - $this->container->get(VanillaCompositeExtractor::class), - $subjectExtractor, - ); - } + try { + /** @var CompositeExtractor $extractor */ + $extractor = $this->container->get(CompositeExtractor::class); + } catch (ContainerExceptionInterface $e) { + throw new ServiceException('Failed to instantiate the composite extractor', 0, $e); + } - private function getCacheKey(int $id): string { - return "mail_classifier_$id"; + $extractor->getSubjectExtractor()->setWordCountVectorizer($wordCountVectorizer); + $extractor->getSubjectExtractor()->setTfidf($tfidfTransformer); + return $extractor; } - private function getTransformerCacheKey(int $id, int $index): string { - return $this->getCacheKey($id) . "_transformer_$index"; + private function getCacheInstance(): ?ICache { + if (!$this->cacheFactory->isAvailable()) { + return null; + } + + $version = self::VERSION; + return $this->cacheFactory->createDistributed("mail-classifier/v$version/"); } /** - * @param int $id - * @param int $transformerCount - * - * @return (?string)[]|null Array of serialized classifier and transformers + * @return string[]|null Array of serialized classifier and transformers */ - private function getCached(int $id, int $transformerCount): ?array { - // FIXME: Will always return null as the cached, serialized data is always an empty string. - // See my note in self::cache() for further elaboration. - - if (!$this->cacheFactory->isLocalCacheAvailable()) { + private function getCached(string $id): ?array { + $cache = $this->getCacheInstance(); + if ($cache === null) { return null; } - $cache = $this->cacheFactory->createLocal(); - $values = []; - $values[] = $cache->get($this->getCacheKey($id)); - for ($i = 0; $i < $transformerCount; $i++) { - $values[] = $cache->get($this->getTransformerCacheKey($id, $i)); - } - - // Only return cached values if estimator and all transformers are available - if (in_array(null, $values, true)) { + $json = $cache->get($id); + if (!is_string($json)) { return null; } - return $values; - } - - private function cache(int $id, string $serialized, array $serializedTransformers): void { - // FIXME: This is broken as some cache implementations will run the provided value through - // json_encode which drops non-utf8 strings. The serialized string contains binary - // data so an empty string will be saved instead (tested on Redis). - // Note: JSON requires strings to be valid utf8 (as per its spec). - - // IDEA: Implement a method ICache::setRaw() that forwards a raw/binary string as is to the - // underlying cache backend. - - if (!$this->cacheFactory->isLocalCacheAvailable()) { - return; - } - $cache = $this->cacheFactory->createLocal(); - $cache->set($this->getCacheKey($id), $serialized); - - $transformerIndex = 0; - foreach ($serializedTransformers as $transformer) { - $cache->set($this->getTransformerCacheKey($id, $transformerIndex), $transformer); - $transformerIndex++; - } + $serializedData = json_decode($json); + return array_map(base64_decode(...), $serializedData); } /** - * Parse minor and major part of the given semver string. - * - * @return int[] + * @param string[] $serializedData Array of serialized classifier and transformers */ - private function parseAppVersion(string $version): array { - $parts = explode('.', $version); - if (count($parts) < 2) { - return [0, 0]; + private function setCached(string $id, array $serializedData): void { + $cache = $this->getCacheInstance(); + if ($cache === null) { + return; } - return [(int)$parts[0], (int)$parts[1]]; + // Serialized data contains binary, non-utf8 data so we encode it as base64 first + $encodedData = array_map(base64_encode(...), $serializedData); + $json = json_encode($encodedData, JSON_THROW_ON_ERROR); + + // Set a ttl of a week because a new model will be generated daily + $cache->set($id, $json, 3600 * 24 * 7); } } diff --git a/lib/Service/Classification/RubixMemoryPersister.php b/lib/Service/Classification/RubixMemoryPersister.php new file mode 100644 index 0000000000..2b170b38b5 --- /dev/null +++ b/lib/Service/Classification/RubixMemoryPersister.php @@ -0,0 +1,41 @@ +data; + } + + public function save(Encoding $encoding): void { + $this->data = $encoding->data(); + } + + public function load(): Encoding { + if ($this->data === null) { + throw new ValueError('Trying to load encoding when no data is available'); + } + + return new Encoding($this->data); + } + + public function __toString() { + return (string)self::class; + } +} diff --git a/lib/Service/CleanupService.php b/lib/Service/CleanupService.php index 65fa421200..74f5c0070a 100644 --- a/lib/Service/CleanupService.php +++ b/lib/Service/CleanupService.php @@ -17,7 +17,6 @@ use OCA\Mail\Db\MessageRetentionMapper; use OCA\Mail\Db\MessageSnoozeMapper; use OCA\Mail\Db\TagMapper; -use OCA\Mail\Service\Classification\PersistenceService; use OCA\Mail\Support\PerformanceLogger; use OCP\AppFramework\Utility\ITimeFactory; use Psr\Log\LoggerInterface; @@ -44,7 +43,6 @@ class CleanupService { private MessageSnoozeMapper $messageSnoozeMapper; - private PersistenceService $classifierPersistenceService; private ITimeFactory $timeFactory; public function __construct(MailAccountMapper $mailAccountMapper, @@ -55,7 +53,6 @@ public function __construct(MailAccountMapper $mailAccountMapper, TagMapper $tagMapper, MessageRetentionMapper $messageRetentionMapper, MessageSnoozeMapper $messageSnoozeMapper, - PersistenceService $classifierPersistenceService, ITimeFactory $timeFactory) { $this->aliasMapper = $aliasMapper; $this->mailboxMapper = $mailboxMapper; @@ -64,7 +61,6 @@ public function __construct(MailAccountMapper $mailAccountMapper, $this->tagMapper = $tagMapper; $this->messageRetentionMapper = $messageRetentionMapper; $this->messageSnoozeMapper = $messageSnoozeMapper; - $this->classifierPersistenceService = $classifierPersistenceService; $this->mailAccountMapper = $mailAccountMapper; $this->timeFactory = $timeFactory; } @@ -92,8 +88,6 @@ public function cleanUp(LoggerInterface $logger): void { $task->step('delete expired messages'); $this->messageSnoozeMapper->deleteOrphans(); $task->step('delete orphan snoozes'); - $this->classifierPersistenceService->cleanUp(); - $task->step('delete orphan classifiers'); $task->end(); } }