]> git.ipfire.org Git - oddments/cappie.git/commitdiff
Merge branch 'master' of ssh://git.ipfire.org/pub/git/oddments/cappie master
authorMichael Tremer <michael.tremer@ipfire.org>
Tue, 11 May 2010 21:58:54 +0000 (23:58 +0200)
committerMichael Tremer <michael.tremer@ipfire.org>
Tue, 11 May 2010 21:58:54 +0000 (23:58 +0200)
Conflicts:
cappie/__init__.py
cappie/queue.py

cappie/__init__.py
cappie/constants.py
cappie/database.py [new file with mode: 0644]
cappie/events.py
cappie/protocol.py
cappie/queue.py

index e760f4ad6a5259cf8d7643c8edaf41ff4c13bddf..fcabae4564094ed8567be56a5f55b4e0c822e083 100644 (file)
@@ -32,6 +32,7 @@ import queue
 import util
 
 from errors import *
+from events import *
 
 def getAllInterfaces():
        filters = ("lo", "any")
@@ -93,12 +94,12 @@ class Cappie(object):
                        iface.start()
 
                while True:
-                       if not self.queue.is_alive():
+                       if not self.queue.isAlive():
                                self.log.critical("Queue thread died unexpectedly.")
                                return
 
                        for iface in self.__interfaces:
-                               if not iface.is_alive():
+                               if not iface.isAlive():
                                        self.log.critical("Thread died unexpectedly. %s" % iface.dev)
                                        return
                                time.sleep(60)
@@ -126,6 +127,10 @@ class Cappie(object):
 
                self.queue.shutdown()
 
+       @property
+       def db(self):
+               return self.queue.db
+
 
 class Interface(Thread):
        heartbeat = 0.1
@@ -140,8 +145,6 @@ class Interface(Thread):
                self.promisc = promisc
                self.queue = self.cappie.queue
 
-               self.db = Database(self)
-
                self.log.debug("Created new interface %s" % self.dev)
                
                self.__running = True
@@ -161,16 +164,13 @@ class Interface(Thread):
                for key, val in p.items():
                        self.log.debug("  %s: %s" % (key, val))
 
-               if not self.db.has(p["source_address"]):
-                       self.db.put(p["source_address"], "SOURCE_IP_ADDRESS", p["source_ip_address"])
+               self._handlePacket(p)
 
        def run(self):
                self.log.info("Starting interface %s" % self.dev)
 
                util.setprocname("interface %s" % self.dev)
 
-               self.db.open()
-
                p = pcapy.open_live(self.dev, self.mtu, self.promisc, 0)
                p.setfilter(self.filter)
                #p.loop(0, self._callback)
@@ -178,7 +178,6 @@ class Interface(Thread):
                p.setnonblock(1)
                while True:
                        if not self.__running:
-                               self.db.close()
                                return
                        
                        if p.dispatch(1, self._callback):
@@ -197,33 +196,13 @@ class Interface(Thread):
        def filter(self):
                return "arp or rarp"
 
+       def addEvent(self, event):
+               return self.cappie.queue.add(event)
 
-class Database(object):
-       def __init__(self, interface):
-               self.interface = interface
-               self.dev = self.interface.dev
-               self.log = self.interface.log
-
-               self.__data = {}
-
-       def open(self):
-               self.log.debug("Opened database for %s" % self.dev)
-
-       def close(self):
-               self.log.debug("Closing database for %s" % self.dev)
-               print self.__data
-
-       def get(self, mac):
-               if self.has(mac):
-                       return self.__data[mac]
-
-       def has(self, mac):
-               return self.__data.has_key(mac)
-
-       def put(self, mac, key, val):
-               if not self.has(mac):
-                       self.__data[mac] = {}
-
-               # TODO Check key for sanity
+       def _handlePacket(self, packet):
+               if packet.operation == OPERATION_RESPONSE:
+                       self.addEvent(EventResponseTrigger(self, packet))
+                       #self.addEvent(EventCheckDuplicate(self, packet))
 
-               self.__data[mac][key] = val
+               elif packet.operation == OPERATION_REQUEST:
+                       self.addEvent(EventRequestTrigger(self, packet))
index 50fb0162f6c59e6af82f0351fc43072607473090..d68fcc7d2dd05a5b2289bbd46104d11c1e0560fa 100644 (file)
 #                                                                             #
 ###############################################################################
 
+ETHERTYPE_ARP = 0x0806
+
 TYPE_ARP = 0
 
 OPERATION_REQUEST = 0
 OPERATION_RESPONSE = 1
+
+DB_LASTSEEN_MAX = 5*60 # 5 minutes
+DB_GC_INTERVAL = 60
diff --git a/cappie/database.py b/cappie/database.py
new file mode 100644 (file)
index 0000000..88d57fb
--- /dev/null
@@ -0,0 +1,102 @@
+#!/usr/bin/python
+###############################################################################
+#                                                                             #
+# Cappie                                                                      #
+# Copyright (C) 2010 Michael Tremer                                           #
+#                                                                             #
+# This program is free software: you can redistribute it and/or modify        #
+# it under the terms of the GNU General Public License as published by        #
+# the Free Software Foundation, either version 3 of the License, or           #
+# (at your option) any later version.                                         #
+#                                                                             #
+# This program is distributed in the hope that it will be useful,             #
+# but WITHOUT ANY WARRANTY; without even the implied warranty of              #
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the               #
+# GNU General Public License for more details.                                #
+#                                                                             #
+# You should have received a copy of the GNU General Public License           #
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.       #
+#                                                                             #
+###############################################################################
+
+import itertools
+import sqlite3
+
+class Database(object):
+       KEYS = ("EVENTS", "ADDRESSES")
+       _CREATE = ["CREATE TABLE IF NOT EXISTS addresses(mac, address, lastseen);",
+               "CREATE TABLE IF NOT EXISTS changes(address, lastchange);"]
+
+       counter = 0
+
+       def __init__(self, log):
+               self.log = log
+
+               self.__connection = None
+
+       def __del__(self):
+               self.close()
+
+       def open(self):
+               self.log.debug("Opening database")
+               if self.__connection:
+                       self.close()
+               self.__connection = sqlite3.connect("test.db")
+               for statement in self._CREATE:
+                       self.execute(statement)
+
+       def close(self):
+               self.log.debug("Closing database")
+               self.commit()
+               self.__connection.close()
+               self.__connection = None
+
+       def commit(self):
+               self.log.debug("Committing data to database")
+               self.__connection.commit()
+
+       def query(self, query, *parameters):
+               """Returns a row list for the given query and parameters."""
+               cursor = self._cursor()
+               self._execute(cursor, query, parameters)
+               column_names = [d[0] for d in cursor.description]
+               return [Row(itertools.izip(column_names, row)) for row in cursor]
+
+       def get(self, query, *parameters):
+               """Returns the first row returned for the given query."""
+               rows = self.query(query, *parameters)
+               if not rows:
+                       return None
+               elif len(rows) > 1:
+                       raise Exception("Multiple rows returned for Database.get() query")
+               else:
+                       return rows[0]
+
+       def _cursor(self):
+               if not self.__connection:
+                       self.open()
+               return self.__connection.cursor()
+
+       def execute(self, query, *parameters):
+               """Executes the given query, returning the lastrowid from the query."""
+               cursor = self._cursor()
+               self._execute(cursor, query, parameters)
+               return cursor.lastrowid
+
+       def _execute(self, cursor, query, parameters):
+               self.log.debug("Executing query: %s" % query)
+               try:
+                       return cursor.execute(query, parameters)
+               except sqlite3.OperationalError:
+                       self.log.error("Error connecting to database")
+                       self.close()
+                       raise
+
+
+class Row(dict):
+       """A dict that allows for object-like property access syntax."""
+       def __getattr__(self, name):
+               try:
+                       return self[name]
+               except KeyError:
+                       raise AttributeError(name)
index d403b522ece585bc90322c98425866ada42f9995..236e19ffccb34a5d82851fd8fafc360263959395 100644 (file)
@@ -23,6 +23,7 @@ import os
 import subprocess
 import time
 
+from constants import *
 from errors import *
 
 class Event(object):
@@ -30,10 +31,14 @@ class Event(object):
                self.cappie = interface.cappie
                self.interface = interface
                self.log = interface.log
+               self.db = self.cappie.db
 
        def __str__(self):
                return self.__class__.__name__
 
+       def addEvent(self, event):
+               return self.cappie.queue.add(event)
+
        def run(self):
                raise NotImplementedError
 
@@ -77,3 +82,89 @@ class EventShell(Event):
                        p.returncode)
 
                return p.returncode
+
+
+class EventRequestTrigger(Event):
+       def __init__(self, interface, packet):
+               Event.__init__(self, interface)
+
+               self.db = interface.cappie.db
+               self.packet = packet
+
+       def _updateAddress(self, mac, address):
+               where = "WHERE mac = '%s' AND address = '%s'" % (mac, address)
+
+               if self.db.get("SELECT * FROM addresses %s" % where):
+                       self.db.execute("UPDATE addresses SET lastseen='%d' %s" % \
+                               (time.time(), where))
+               else:
+                       self.db.execute("INSERT INTO addresses VALUES('%s', '%s', '%d')" % \
+                               (mac, address, time.time()))
+
+       def _updateChanges(self, *args):
+               for arg in args:
+                       where = "WHERE address = '%s'" % arg
+                       if self.db.get("SELECT * FROM changes %s" % where):
+                               self.db.execute("UPDATE changes SET lastchange = '%d' %s" % \
+                                       (time.time(), where))
+                       else:
+                               self.db.execute("INSERT INTO changes VALUES('%s', '%d')" % \
+                                       (arg, time.time()))
+
+       def run(self):
+               mac = self.packet.source_address
+               address = self.packet.source_ip_address
+
+               self._updateAddress(mac, address)
+               self._updateChanges(mac, address)
+
+
+class EventResponseTrigger(EventRequestTrigger):
+       pass
+
+
+class EventGarbageCollector(Event):
+       def __init__(self, db, log):
+               self.db = db
+               self.log = log
+
+       def run(self):
+               # Remove old addresses
+               self.db.execute("DELETE FROM addresses WHERE lastseen >= '%d'" % \
+                       (time.time() - DB_LASTSEEN_MAX))
+
+               self.db.commit()
+
+
+class EventCheckDuplicate(Event):
+       def __init__(self, interface, packet):
+               Event.__init__(self, interface)
+               self.packet = packet
+
+       def run(self):
+               entries = self.db.query("SELECT * FROM addresses WHERE address = '%s'" % \
+                       self.packet.source_ip_address)
+
+               if not entries:
+                       return
+
+               for entry in entries:
+                       if self.packet.source_address == entry.mac:
+                               entries.remove(entry)
+
+               if len(entries) > 1:
+                       self.addEvent(EventHandleDuplicate(self.interface, self.packet))
+
+
+class EventHandleDuplicate(Event):
+       def __init__(self, interface, packet):
+               Event.__init__(self, interface)
+               self.packet = packet
+
+       def run(self):
+               self.log.warning("We probably have a mac spoofing for %s" % \
+                       self.packet.source_address)
+
+
+class EventCheckFlipFlop(Event):
+       pass
index ed5f4a3f98994bc3349724e0a144f2cf83f9a746..c4252de5075f27a9594ca1f053900ea10c97e42c 100644 (file)
@@ -21,6 +21,8 @@
 
 import struct
 
+import database
+
 from constants import *
 from errors import *
 
@@ -34,15 +36,17 @@ def val2mac(val):
        return ":".join(["%02x" % ord(i) for i in val])
 
 def decode_packet(data):
-       for func in (decode_arp_packet,):
-               try:
-                       p = func(data)
-               except PacketTypeError:
-                       continue
+       try:
+               protocol = val2int(struct.unpack("!2s", data[12:14])[0])
+       except:
+               raise DecodeError
 
-               return p
+       try:
+               d = protocol2function[protocol](data)
+       except KeyErrror:
+               raise PacketTypeError, "Could not determine type of packet"
 
-       raise PacketTypeError, "Could not determine type of packet"
+       return database.Row(d)
 
 def decode_arp_packet(data):
        operationmap = {
@@ -58,15 +62,10 @@ def decode_arp_packet(data):
        }
 
        #"hwtype" : data[:2],
-       protocol = val2int(struct.unpack("!2s", data[12:14])[0])
        hw_addr_size = val2int(struct.unpack("!1s", data[18:19])[0])
        hw_prot_size = val2int(struct.unpack("!1s", data[19:20])[0])
        operation = val2int(struct.unpack("!2s", data[20:22])[0])
 
-       # Sanity checks
-       if not protocol == 0x0806:
-               raise PacketTypeError, "Not an ARP packet"
-
        # TODO Must check hwtype here...
 
        try:
@@ -93,3 +92,7 @@ def decode_arp_packet(data):
 
 def decode_ndp_packet(data):
        raise PacketTypeError
+
+protocol2function = {
+       ETHERTYPE_ARP : decode_arp_packet,
+}
index 0db700635652633577e0ceeb28d9ea0712f415be..e690d136235dab0018114af2bf8b2b9438b1331f 100644 (file)
@@ -25,11 +25,13 @@ from threading import Thread
 
 import util
 
+from database import Database
 from errors import *
+from events import *
 
 class Queue(Thread):
        heartbeat = 1.0
-       maxitems = 100
+       maxitems = 10000
 
        def __init__(self, log):
                Thread.__init__(self)
@@ -39,6 +41,9 @@ class Queue(Thread):
                self.__running = True
                self.__queue = []
 
+               self.db = Database(log)
+               self.lastgc = None
+
        def __len__(self):
                return self.length
 
@@ -57,12 +62,16 @@ class Queue(Thread):
 
                util.setprocname("queue")
 
+               self.db.open()
+
                while self.__running or self.__queue:
                        if not self.__queue:
                                #self.log.debug("Queue sleeping for %s seconds" % self.heartbeat)
                                time.sleep(self.heartbeat)
                                continue
 
+                       self._checkGc()
+
                        event = self.__queue.pop(0)
                        self.log.debug("Processing queue event: %s" % event)
                        try:
@@ -70,6 +79,8 @@ class Queue(Thread):
                        except EventException, e:
                                self.log.error("Catched event exception: %s" % e)
 
+               self.db.close()
+
        def shutdown(self):
                self.__running = False
                self.log.debug("Shutting down queue")
@@ -77,3 +88,8 @@ class Queue(Thread):
 
                # Wait until queue handled all events
                self.join()
+
+       def _checkGc(self):
+               if not self.lastgc or self.lastgc <= (time.time() - DB_GC_INTERVAL):
+                       self.add(EventGarbageCollector(self.db, self.log))
+                       self.lastgc = time.time()