]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
Add a simple script to learn fann from rspamd logs
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 22 Dec 2015 00:17:59 +0000 (00:17 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 22 Dec 2015 00:17:59 +0000 (00:17 +0000)
utils/fann_train.pl [new file with mode: 0755]

diff --git a/utils/fann_train.pl b/utils/fann_train.pl
new file mode 100755 (executable)
index 0000000..c6a4bf1
--- /dev/null
@@ -0,0 +1,245 @@
+#!/usr/bin/env perl
+
+# This script is a very simple prototype to learn fann from rspamd logs
+# For now, it is intended for internal use only
+
+use strict;
+use warnings FATAL => 'all';
+use AI::FANN qw(:all);
+use Getopt::Std;
+
+my %sym_idx; # Symbols by index
+my %sym_names; # Symbols by name
+my $num = 1; # Number of symbols
+my @spam;
+my @ham;
+my $max_samples = -1;
+my $split = 1;
+my $preprocessed = 0; # ouptut is in format <score>:<0|1>:<SYM1,...SYMN>
+my $score_spam = 12;
+my $score_ham = -6;
+
+sub process {
+    my ($input, $spam, $ham) = @_;
+    my $samples = 0;
+
+    while(<$input>) {
+        if (!$preprocessed) {
+            if (/^.*rspamd_task_write_log.*: \[(-?\d+\.?\d*)\/(\d+\.?\d*)\]\s*\[(.+)\].*$/) {
+                if ($1 > $score_spam) {
+                    $_ = "$1:1: $3";
+                }
+                elsif ($1 < $score_ham) {
+                    $_ = "$1:0: $3\n";
+                }
+                else {
+                    # Out of boundary
+                    next;
+                }
+            }
+            else {
+                # Not our log message
+                next;
+            }
+        }
+
+        $_ =~ /^(-?\d+\.?\d*):([01]):\s*(\S.*)$/;
+
+        my $is_spam = 0;
+
+        if ($2 == 1) {
+            $is_spam = 1;
+        }
+
+        my @ar = split /,/, $3;
+        my %sample;
+
+        foreach my $sym (@ar) {
+            chomp $sym;
+            if (!$sym_idx{$sym}) {
+                $sym_idx{$sym} = $num;
+                $sym_names{$num} = $sym;
+                $num++;
+            }
+
+            $sample{$sym_idx{$sym}} = 1;
+        }
+
+        if ($is_spam) {
+            push @{$spam}, \%sample;
+        }
+        else {
+            push @{$ham}, \%sample;
+        }
+
+        $samples++;
+        if ($max_samples > 0 && $samples > $max_samples) {
+            return;
+        }
+    }
+}
+
+# Shuffle array
+sub fisher_yates_shuffle
+{
+    my $array = shift;
+    my $i = @$array;
+
+    while ( --$i ) {
+        my $j = int rand( $i + 1 );
+        @$array[$i, $j] = @$array[$j, $i];
+    }
+}
+
+# Train network
+sub train {
+    my ($ann, $sample, $result) = @_;
+
+    my @row;
+
+    for (my $i = 1; $i < $num; $i++) {
+        if ($sample->{$i}) {
+            push @row, 1;
+        }
+        else {
+            push @row, 0;
+        }
+    }
+
+    #print "@row -> @{$result}\n";
+
+    $ann->train(\@row, \@{$result});
+}
+
+sub test {
+    my ($ann, $sample) = @_;
+
+    my @row;
+
+    for (my $i = 1; $i < $num; $i++) {
+        if ($sample->{$i}) {
+            push @row, 1;
+        }
+        else {
+            push @row, 0;
+        }
+    }
+
+    my $ret = $ann->run(\@row);
+
+    return $ret;
+}
+
+my %opts;
+getopts('o:i:s:n:t:hpS:H:', \%opts);
+
+if ($opts{'h'}) {
+    print "$0 [-i input] [-o output] [-s scores] [-n max_samples] [-S spam_score] [-H ham_score] [-ph]\n";
+    exit;
+}
+
+my $input = *STDIN;
+
+if ($opts{'i'}) {
+    open($input, '<', $opts{'i'}) or die "cannot open $opts{i}";
+}
+
+if ($opts{'n'}) {
+    $max_samples = $opts{'n'};
+}
+
+if ($opts{'t'}) {
+    # Test split
+    $split = $opts{'t'};
+}
+if ($opts{'p'}) {
+    $preprocessed = 1;
+}
+
+if ($opts{'H'}) {
+    $score_ham = $opts{'H'};
+}
+
+if ($opts{'S'}) {
+    $score_spam = $opts{'S'};
+}
+
+# ham_prob, spam_prob
+my @spam_out = (1);
+my @ham_out = (0);
+
+process($input, \@spam, \@ham);
+fisher_yates_shuffle(\@spam);
+fisher_yates_shuffle(\@ham);
+
+my $nspam = int(scalar(@spam) / $split);
+my $nham = int(scalar(@ham) / $split);
+
+my $ann = AI::FANN->new_standard($num - 1, ($num + 2) / 2, 1);
+
+my @train_data;
+# Train ANN
+for (my $i = 0; $i < $nham; $i++) {
+    push @train_data, [ $ham[$i], \@ham_out ];
+}
+
+for (my $i = 0; $i < $nspam; $i++) {
+    push @train_data, [ $spam[$i], \@spam_out ];
+}
+
+fisher_yates_shuffle(\@train_data);
+
+foreach my $train_row (@train_data) {
+    train($ann, @{$train_row}[0], @{$train_row}[1]);
+}
+
+print "Trained $nspam SPAM and $nham HAM samples\n";
+
+# Now run fann
+if ($split > 1) {
+    my $sample = 0.0;
+    my $correct = 0.0;
+    for (my $i = $nham; $i < $nham * $split; $i++) {
+        my $ret = test($ann, $ham[$i]);
+        #print "@{$ret}\n";
+        if (@{$ret}[0] < 0.5) {
+            $correct++;
+        }
+        $sample++;
+    }
+
+    print "Tested $sample HAM samples, correct matched: $correct, rate: ".($correct / $sample)."\n";
+
+    $sample = 0.0;
+    $correct = 0.0;
+
+    for (my $i = $nspam; $i < $nspam * $split; $i++) {
+        my $ret = test($ann, $spam[$i]);
+        #print "@{$ret}\n";
+        if (@{$ret}[0] > 0.5) {
+            $correct++;
+        }
+        $sample++;
+    }
+
+    print "Tested $sample SPAM samples, correct matched: $correct, rate: ".($correct / $sample)."\n";
+}
+
+if ($opts{'o'}) {
+    $ann->save($opts{'o'}) or die "cannot save ann into $opts{o}";
+}
+
+if ($opts{'s'}) {
+    open(my $scores, '>',
+        $opts{'s'}) or die "cannot open score file $opts{'s'}";
+    print $scores "{";
+    for (my $i = 1; $i < $num; $i++) {
+        my $n = $i - 1;
+        if ($i != $num - 1) {
+            print $scores "\"$sym_names{$i}\":$n,";
+        }
+        else {
+            print $scores "\"$sym_names{$i}\":$n}\n";
+        }
+    }
+}