Spam filter: Fix training sample size checks

This commit is contained in:
mdecimus
2025-12-26 10:22:48 +01:00
parent 7d29cfbccf
commit 5764a8580b

View File

@@ -262,8 +262,12 @@ impl SpamClassifier for Server {
Spam(SpamEvent::ModelNotReady),
Reason = "Not enough samples for training",
Details = vec![
trc::Value::from(ham_count + trainer.reservoir.ham.total_seen),
trc::Value::from(spam_count + trainer.reservoir.spam.total_seen)
trc::Value::from(trainer.reservoir.ham.total_seen),
trc::Value::from(trainer.reservoir.spam.total_seen)
],
Limit = vec![
trc::Value::from(config.min_ham_samples),
trc::Value::from(config.min_spam_samples)
],
Elapsed = started.elapsed()
);
@@ -273,7 +277,7 @@ impl SpamClassifier for Server {
// Balance classes if needed
if spam_count > ham_count {
// We have too much spam today. We need to replay old HAM.
// We have too much spam this time. We need to replay old HAM.
samples.extend(
trainer
.reservoir
@@ -286,7 +290,7 @@ impl SpamClassifier for Server {
}),
);
} else if ham_count > spam_count {
// We have too much ham today. We need to replay old SPAM.
// We have too much ham this time. We need to replay old SPAM.
samples.extend(
trainer
.reservoir
@@ -491,32 +495,20 @@ impl SpamClassifier for Server {
)
.await
.caused_by(trc::location!())?;
if ham_count >= config.min_ham_samples && spam_count >= config.min_spam_samples {
self.blob_store()
.put_blob(
SPAM_CLASSIFIER_KEY,
&classifier.serialize().caused_by(trc::location!())?,
)
.await
.caused_by(trc::location!())?;
self.blob_store()
.put_blob(
SPAM_CLASSIFIER_KEY,
&classifier.serialize().caused_by(trc::location!())?,
)
.await
.caused_by(trc::location!())?;
self.inner
.data
.spam_classifier
.store(Arc::new(classifier.inner));
self.cluster_broadcast(BroadcastEvent::ReloadSpamFilter)
.await;
} else {
self.blob_store()
.delete_blob(SPAM_CLASSIFIER_KEY)
.await
.caused_by(trc::location!())?;
trc::event!(
Spam(SpamEvent::ModelNotReady),
Details = vec![trc::Value::from(ham_count), trc::Value::from(spam_count)],
);
}
self.inner
.data
.spam_classifier
.store(Arc::new(classifier.inner));
self.cluster_broadcast(BroadcastEvent::ReloadSpamFilter)
.await;
trc::event!(
Spam(SpamEvent::TrainCompleted),