]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
pybsddb 4.3.1, adds support for DB.set_bt_compare database btree comparison
authorGregory P. Smith <greg@mad-scientist.com>
Fri, 3 Jun 2005 07:03:07 +0000 (07:03 +0000)
committerGregory P. Smith <greg@mad-scientist.com>
Fri, 3 Jun 2005 07:03:07 +0000 (07:03 +0000)
functions written in python.

contributed by <frederic.gobry@epfl.ch>

Lib/bsddb/dbobj.py
Lib/bsddb/test/test_all.py
Lib/bsddb/test/test_compare.py [new file with mode: 0644]
Modules/_bsddb.c

index 3bafafa5d9df9dc8848417ebd555512ba8368a40..667ec314d3fa646234a349785a0bdd51c9fa73a1 100644 (file)
@@ -164,6 +164,8 @@ class DB(DictMixin):
         return apply(self._cobj.rename, args, kwargs)
     def set_bt_minkey(self, *args, **kwargs):
         return apply(self._cobj.set_bt_minkey, args, kwargs)
+    def set_bt_compare(self, *args, **kwargs):
+        return apply(self._cobj.set_bt_compare, args, kwargs)
     def set_cachesize(self, *args, **kwargs):
         return apply(self._cobj.set_cachesize, args, kwargs)
     def set_flags(self, *args, **kwargs):
index 7b1bf3d29428ac7c03e5f345af60bdf22d3a603b..701bdfe3f7442a7ee9dfb215096161ebbef95e0f 100644 (file)
@@ -50,6 +50,7 @@ def suite():
         'test_associate',
         'test_basics',
         'test_compat',
+        'test_compare',
         'test_dbobj',
         'test_dbshelve',
         'test_dbtables',
diff --git a/Lib/bsddb/test/test_compare.py b/Lib/bsddb/test/test_compare.py
new file mode 100644 (file)
index 0000000..f8b2312
--- /dev/null
@@ -0,0 +1,215 @@
+"""
+TestCases for python DB Btree key comparison function.
+"""
+
+import sys, os
+import test_all
+
+import unittest
+from bsddb3 import db
+
+def lexical_cmp (db, left, right):
+  return cmp (left, right)
+
+def lowercase_cmp(db, left, right):
+  return cmp (left.lower(), right.lower())
+
+def make_reverse_comparator (cmp):
+  def reverse (db, left, right, delegate=cmp):
+    return - delegate (db, left, right)
+  return reverse
+
+_expected_lexical_test_data = ['', 'CCCP', 'a', 'aaa', 'b', 'c', 'cccce', 'ccccf']
+_expected_lowercase_test_data = ['', 'a', 'aaa', 'b', 'c', 'CC', 'cccce', 'ccccf', 'CCCP']
+
+class ComparatorTests (unittest.TestCase):
+  def comparator_test_helper (self, comparator, expected_data):
+    data = expected_data[:]
+    data.sort (lambda l, r, cmp=comparator: cmp (None, l, r))
+    self.failUnless (data == expected_data,
+                     "comparator `%s' is not right: %s vs. %s"
+                     % (comparator, expected_data, data))
+  def test_lexical_comparator (self):
+    self.comparator_test_helper (lexical_cmp, _expected_lexical_test_data)
+  def test_reverse_lexical_comparator (self):
+    rev = _expected_lexical_test_data[:]
+    rev.reverse ()
+    self.comparator_test_helper (make_reverse_comparator (lexical_cmp),
+                                 rev)
+  def test_lowercase_comparator (self):
+    self.comparator_test_helper (lowercase_cmp,
+                                 _expected_lowercase_test_data)
+
+class AbstractBtreeKeyCompareTestCase (unittest.TestCase):
+  env = None
+  db = None
+  
+  def setUp (self):
+    self.filename = self.__class__.__name__ + '.db'
+    homeDir = os.path.join (os.path.dirname (sys.argv[0]), 'db_home')
+    self.homeDir = homeDir
+    try:
+      os.mkdir (homeDir)
+    except os.error:
+      pass
+
+    env = db.DBEnv ()
+    env.open (homeDir,
+              db.DB_CREATE | db.DB_INIT_MPOOL
+              | db.DB_INIT_LOCK | db.DB_THREAD)
+    self.env = env
+
+  def tearDown (self):
+    self.closeDB ()
+    if self.env is not None:
+      self.env.close ()
+      self.env = None
+    import glob
+    map (os.remove, glob.glob (os.path.join (self.homeDir, '*')))
+
+  def addDataToDB (self, data):
+    i = 0
+    for item in data:
+      self.db.put (item, str (i))
+      i = i + 1
+
+  def createDB (self, key_comparator):
+    self.db = db.DB (self.env)
+    self.setupDB (key_comparator)
+    self.db.open (self.filename, "test", db.DB_BTREE, db.DB_CREATE)
+
+  def setupDB (self, key_comparator):
+    self.db.set_bt_compare (key_comparator)
+
+  def closeDB (self):
+    if self.db is not None:
+      self.db.close ()
+      self.db = None
+
+  def startTest (self):
+    pass
+  
+  def finishTest (self, expected = None):
+    if expected is not None:
+      self.check_results (expected)
+    self.closeDB ()
+
+  def check_results (self, expected):
+    curs = self.db.cursor ()
+    try:
+      index = 0
+      rec = curs.first ()
+      while rec:
+        key, ignore = rec
+        self.failUnless (index < len (expected),
+                         "to many values returned from cursor")
+        self.failUnless (expected[index] == key,
+                         "expected value `%s' at %d but got `%s'"
+                         % (expected[index], index, key))
+        index = index + 1
+        rec = curs.next ()
+      self.failUnless (index == len (expected),
+                       "not enough values returned from cursor")
+    finally:
+      curs.close ()
+
+class BtreeKeyCompareTestCase (AbstractBtreeKeyCompareTestCase):
+  def runCompareTest (self, comparator, data):
+    self.startTest ()
+    self.createDB (comparator)
+    self.addDataToDB (data)
+    self.finishTest (data)
+
+  def test_lexical_ordering (self):
+    self.runCompareTest (lexical_cmp, _expected_lexical_test_data)
+
+  def test_reverse_lexical_ordering (self):
+    expected_rev_data = _expected_lexical_test_data[:]
+    expected_rev_data.reverse ()
+    self.runCompareTest (make_reverse_comparator (lexical_cmp),
+                         expected_rev_data)
+
+  def test_compare_function_useless (self):
+    self.startTest ()
+    def socialist_comparator (db, l, r):
+      return 0
+    self.createDB (socialist_comparator)
+    self.addDataToDB (['b', 'a', 'd'])
+    # all things being equal the first key will be the only key
+    # in the database...  (with the last key's value fwiw)
+    self.finishTest (['b'])
+
+    
+class BtreeExceptionsTestCase (AbstractBtreeKeyCompareTestCase):
+  def test_raises_non_callable (self):
+    self.startTest ()
+    self.assertRaises (TypeError, self.createDB, 'abc')
+    self.assertRaises (TypeError, self.createDB, None)
+    self.finishTest ()
+
+  def test_set_bt_compare_with_function (self):
+    self.startTest ()
+    self.createDB (lexical_cmp)
+    self.finishTest ()
+
+  def check_results (self, results):
+    pass
+
+  def test_compare_function_incorrect (self):
+    self.startTest ()
+    def bad_comparator (db, l, r):
+      return 1
+    # verify that set_bt_compare checks that comparator(db, '', '') == 0
+    self.assertRaises (TypeError, self.createDB, bad_comparator)
+    self.finishTest ()
+
+  def test_compare_function_exception (self):
+    self.startTest ()
+    def bad_comparator (db, l, r):
+      if l == r:
+       # pass the set_bt_compare test
+       return 0
+      raise RuntimeError, "i'm a naughty comparison function"
+    self.createDB (bad_comparator)
+    print "\n*** this test should print 2 uncatchable tracebacks ***"
+    self.addDataToDB (['a', 'b', 'c'])  # this should raise, but...
+    self.finishTest ()
+
+  def test_compare_function_bad_return (self):
+    self.startTest ()
+    def bad_comparator (db, l, r):
+      if l == r:
+       # pass the set_bt_compare test
+       return 0
+      return l
+    self.createDB (bad_comparator)
+    print "\n*** this test should print 2 errors about returning an int ***"
+    self.addDataToDB (['a', 'b', 'c'])  # this should raise, but...
+    self.finishTest ()
+
+
+  def test_cannot_assign_twice (self):
+
+    def my_compare (db, a, b):
+      return 0
+    
+    self.startTest ()
+    self.createDB (my_compare)
+    try:
+      self.db.set_bt_compare (my_compare)
+      assert False, "this set should fail"
+
+    except RuntimeError, msg:
+      pass
+    
+def test_suite ():
+  res = unittest.TestSuite ()
+
+  res.addTest (unittest.makeSuite (ComparatorTests))
+  if db.version () >= (3, 3, 11):
+    res.addTest (unittest.makeSuite (BtreeExceptionsTestCase))
+    res.addTest (unittest.makeSuite (BtreeKeyCompareTestCase))
+  return res
+
+if __name__ == '__main__':
+  unittest.main (defaultTest = 'suite')
index 2712e3d15a2b453af3cf648ff9d0915f191edaf6..6a39d859349087a0ab6b77143e860b9a369d94fd 100644 (file)
@@ -97,7 +97,7 @@
 #error "eek! DBVER can't handle minor versions > 9"
 #endif
 
-#define PY_BSDDB_VERSION "4.3.0"
+#define PY_BSDDB_VERSION "4.3.1"
 static char *rcs_id = "$Id$";
 
 
@@ -244,6 +244,7 @@ typedef struct {
     struct behaviourFlags moduleFlags;
 #if (DBVER >= 33)
     PyObject*       associateCallback;
+    PyObject*       btCompareCallback;
     int             primaryDBType;
 #endif
 #ifdef HAVE_WEAKREF
@@ -741,6 +742,7 @@ newDBObject(DBEnvObject* arg, int flags)
     self->myenvobj = NULL;
 #if (DBVER >= 33)
     self->associateCallback = NULL;
+    self->btCompareCallback = NULL;
     self->primaryDBType = 0;
 #endif
 #ifdef HAVE_WEAKREF
@@ -815,6 +817,10 @@ DB_dealloc(DBObject* self)
         Py_DECREF(self->associateCallback);
         self->associateCallback = NULL;
     }
+    if (self->btCompareCallback != NULL) {
+        Py_DECREF(self->btCompareCallback);
+        self->btCompareCallback = NULL;
+    }
 #endif
     PyObject_Del(self);
 }
@@ -1959,6 +1965,161 @@ DB_set_bt_minkey(DBObject* self, PyObject* args)
     RETURN_NONE();
 }
 
+static int 
+_default_cmp (const DBT *leftKey,
+             const DBT *rightKey)
+{
+  int res;
+  int lsize = leftKey->size, rsize = rightKey->size;
+
+  res = memcmp (leftKey->data, rightKey->data, 
+               lsize < rsize ? lsize : rsize);
+  
+  if (res == 0) {
+      if (lsize < rsize) {
+         res = -1;
+      }
+      else if (lsize > rsize) {
+         res = 1;
+      }
+  }
+  return res;
+}
+
+static int
+_db_compareCallback (DB* db, 
+                    const DBT *leftKey,
+                    const DBT *rightKey)
+{
+    int res = 0;
+    PyObject *args;
+    PyObject *result;
+    PyObject *leftObject;
+    PyObject *rightObject;
+    DBObject *self = (DBObject *) db->app_private;
+
+    if (self == NULL || self->btCompareCallback == NULL) {
+       MYDB_BEGIN_BLOCK_THREADS;
+       PyErr_SetString (PyExc_TypeError,
+                        (self == 0
+                         ? "DB_bt_compare db is NULL."
+                         : "DB_bt_compare callback is NULL."));
+       /* we're in a callback within the DB code, we can't raise */
+       PyErr_Print ();
+       res = _default_cmp (leftKey, rightKey);
+       MYDB_END_BLOCK_THREADS;
+    }
+    else {
+       MYDB_BEGIN_BLOCK_THREADS;
+
+       leftObject  = PyString_FromStringAndSize (leftKey->data, leftKey->size);
+       rightObject = PyString_FromStringAndSize (rightKey->data, rightKey->size);
+
+       args = PyTuple_New (3);
+       Py_INCREF (self);
+       PyTuple_SET_ITEM (args, 0, (PyObject *) self);
+       PyTuple_SET_ITEM (args, 1, leftObject);  /* steals reference */
+       PyTuple_SET_ITEM (args, 2, rightObject); /* steals reference */
+    
+       result = PyEval_CallObject (self->btCompareCallback, args);
+       if (result == 0) {
+           /* we're in a callback within the DB code, we can't raise */
+           PyErr_Print (); // XXX-gps or can we?  either way the DB is screwed
+           res = _default_cmp (leftKey, rightKey);
+       }
+       else if (PyInt_Check (result)) {
+           res = PyInt_AsLong (result);
+       }
+       else {
+           PyErr_SetString (PyExc_TypeError,
+                            "DB_bt_compare callback MUST return an int.");
+           /* we're in a callback within the DB code, we can't raise */
+           PyErr_Print ();
+           res = _default_cmp (leftKey, rightKey);
+       }
+    
+       Py_DECREF (args);
+       Py_XDECREF (result);
+
+       MYDB_END_BLOCK_THREADS;
+    }
+    return res;
+}
+
+static PyObject*
+DB_set_bt_compare (DBObject* self, PyObject* args)
+{
+    int err;
+    PyObject *comparator;
+    PyObject *tuple, *emptyStr, *result;
+
+    if (!PyArg_ParseTuple(args,"O:set_bt_compare", &comparator ))
+       return NULL;
+
+    CHECK_DB_NOT_CLOSED (self);
+
+    if (! PyCallable_Check (comparator)) {
+       makeTypeError ("Callable", comparator);
+       return NULL;
+    }
+
+    /* 
+     * Perform a test call of the comparator function with two empty
+     * string objects here.  verify that it returns an int (0).
+     * err if not.
+     */
+    tuple = PyTuple_New (3);
+    Py_INCREF (self);
+    PyTuple_SET_ITEM (tuple, 0, (PyObject *) self);
+
+    emptyStr = PyString_FromStringAndSize (NULL, 0);
+    Py_INCREF(emptyStr);
+    PyTuple_SET_ITEM (tuple, 1, emptyStr);
+    PyTuple_SET_ITEM (tuple, 2, emptyStr); /* steals reference */
+    result = PyEval_CallObject (comparator, tuple);
+    Py_DECREF (tuple);
+    if (result == 0 || !PyInt_Check(result)) {
+       PyErr_SetString (PyExc_TypeError,
+                        "callback MUST return an int");
+       return NULL;
+    }
+    else if (PyInt_AsLong(result) != 0) {
+       PyErr_SetString (PyExc_TypeError,
+                        "callback failed to return 0 on two empty strings");
+       return NULL;
+    }
+
+    /* We don't accept multiple set_bt_compare operations, in order to
+     * simplify the code. This would have no real use, as one cannot
+     * change the function once the db is opened anyway */
+    if (self->btCompareCallback != NULL) {
+       PyErr_SetString (PyExc_RuntimeError, "set_bt_compare () cannot be called more than once");
+       return NULL;
+    }
+
+    Py_INCREF (comparator);
+    self->btCompareCallback = comparator;
+
+    /* This is to workaround a problem with un-initialized threads (see
+       comment in DB_associate) */
+#ifdef WITH_THREAD
+    PyEval_InitThreads();
+#endif
+
+    err = self->db->set_bt_compare (self->db, 
+                                   (comparator != NULL ? 
+                                    _db_compareCallback : NULL));
+
+    if (err) {
+       /* restore the old state in case of error */
+       Py_DECREF (comparator);
+       self->btCompareCallback = NULL;
+    }
+
+    RETURN_IF_ERR ();
+    RETURN_NONE ();
+}
+
 
 static PyObject*
 DB_set_cachesize(DBObject* self, PyObject* args)
@@ -4400,6 +4561,7 @@ static PyMethodDef DB_methods[] = {
     {"remove",          (PyCFunction)DB_remove,         METH_VARARGS|METH_KEYWORDS},
     {"rename",          (PyCFunction)DB_rename,         METH_VARARGS},
     {"set_bt_minkey",   (PyCFunction)DB_set_bt_minkey,  METH_VARARGS},
+    {"set_bt_compare",  (PyCFunction)DB_set_bt_compare, METH_VARARGS},
     {"set_cachesize",   (PyCFunction)DB_set_cachesize,  METH_VARARGS},
 #if (DBVER >= 41)
     {"set_encrypt",     (PyCFunction)DB_set_encrypt,    METH_VARARGS|METH_KEYWORDS},