"""DNS Wire Data Helper"""
+import sys
import dns.exception
from ._compat import binary_type, string_types
# out what constant Python will use.
-class _SliceUnspecifiedBound(str):
+class _SliceUnspecifiedBound(binary_type):
- def __getslice__(self, i, j):
- return j
+ def __getitem__(self, key):
+ return key.stop
+
+ if sys.version_info < (3,):
+ def __getslice__(self, i, j): # pylint: disable=getslice-method
+ return self.__getitem__(slice(i, j))
-_unspecified_bound = _SliceUnspecifiedBound('')[1:]
+_unspecified_bound = _SliceUnspecifiedBound()[1:]
class WireData(binary_type):
def __getitem__(self, key):
try:
if isinstance(key, slice):
- return WireData(super(WireData, self).__getitem__(key))
+ # make sure we are not going outside of valid ranges,
+ # do stricter control of boundaries than python does
+ # by default
+ start = key.start
+ stop = key.stop
+
+ if sys.version_info < (3,):
+ if stop == _unspecified_bound:
+ # handle the case where the right bound is unspecified
+ stop = len(self)
+
+ if start < 0 or stop < 0:
+ raise dns.exception.FormError
+ # If it's not an empty slice, access left and right bounds
+ # to make sure they're valid
+ if start != stop:
+ super(WireData, self).__getitem__(start)
+ super(WireData, self).__getitem__(stop - 1)
+ else:
+ for index in (start, stop):
+ if index is None:
+ continue
+ elif abs(index) > len(self):
+ raise dns.exception.FormError
+
+ return WireData(super(WireData, self).__getitem__(
+ slice(start, stop)))
return bytearray(self.unwrap())[key]
except IndexError:
raise dns.exception.FormError
- def __getslice__(self, i, j):
- try:
- if j == _unspecified_bound:
- # handle the case where the right bound is unspecified
- j = len(self)
- if i < 0 or j < 0:
- raise dns.exception.FormError
- # If it's not an empty slice, access left and right bounds
- # to make sure they're valid
- if i != j:
- super(WireData, self).__getitem__(i)
- super(WireData, self).__getitem__(j - 1)
- return WireData(super(WireData, self).__getslice__(i, j))
- except IndexError:
- raise dns.exception.FormError
+ if sys.version_info < (3,):
+ def __getslice__(self, i, j): # pylint: disable=getslice-method
+ return self.__getitem__(slice(i, j))
def __iter__(self):
i = 0
--- /dev/null
+# Copyright (C) 2016
+# Author: Martin Basti <martin.basti@gmail.com>
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+
+try:
+ import unittest2 as unittest
+except ImportError:
+ import unittest
+
+from dns.exception import FormError
+from dns.wiredata import WireData
+
+
+class WireDataSlicingTestCase(unittest.TestCase):
+
+ def testSliceAll(self):
+ """Get all data"""
+ inst = WireData(b'0123456789')
+ self.assertEqual(inst[:], WireData(b'0123456789'))
+
+ def testSliceAllExplicitlyDefined(self):
+ """Get all data"""
+ inst = WireData(b'0123456789')
+ self.assertEqual(inst[0:10], WireData(b'0123456789'))
+
+ def testSliceLowerHalf(self):
+ """Get lower half of data"""
+ inst = WireData(b'0123456789')
+ self.assertEqual(inst[:5], WireData(b'01234'))
+
+ def testSliceLowerHalfWithNegativeIndex(self):
+ """Get lower half of data"""
+ inst = WireData(b'0123456789')
+ self.assertEqual(inst[:-5], WireData(b'01234'))
+
+ def testSliceUpperHalf(self):
+ """Get upper half of data"""
+ inst = WireData(b'0123456789')
+ self.assertEqual(inst[5:], WireData(b'56789'))
+
+ def testSliceMiddle(self):
+ """Get data from middle"""
+ inst = WireData(b'0123456789')
+ self.assertEqual(inst[3:6], WireData(b'345'))
+
+ def testSliceMiddleWithNegativeIndex(self):
+ """Get data from middle"""
+ inst = WireData(b'0123456789')
+ self.assertEqual(inst[-6:-3], WireData(b'456'))
+
+ def testSliceMiddleWithMixedIndex(self):
+ """Get data from middle"""
+ inst = WireData(b'0123456789')
+ self.assertEqual(inst[-8:3], WireData(b'2'))
+ self.assertEqual(inst[5:-3], WireData(b'56'))
+
+ def testGetOne(self):
+ """Get data one by one item"""
+ data = b'0123456789'
+ inst = WireData(data)
+ for i, byte in enumerate(bytearray(data)):
+ self.assertEqual(inst[i], byte)
+ for i in range(-1, len(data) * -1, -1):
+ self.assertEqual(inst[i], bytearray(data)[i])
+
+ def testEmptySlice(self):
+ """Test empty slice"""
+ data = b'0123456789'
+ inst = WireData(data)
+ for i, byte in enumerate(data):
+ self.assertEqual(inst[i:i], b'')
+ for i in range(-1, len(data) * -1, -1):
+ self.assertEqual(inst[i:i], b'')
+ self.assertEqual(inst[-3:-6], b'')
+
+ def testSliceStartOutOfLowerBorder(self):
+ """Get data from out of lower border"""
+ inst = WireData(b'0123456789')
+ with self.assertRaises(FormError):
+ inst[-11:] # pylint: disable=pointless-statement
+
+ def testSliceStopOutOfLowerBorder(self):
+ """Get data from out of lower border"""
+ inst = WireData(b'0123456789')
+ with self.assertRaises(FormError):
+ inst[:-11] # pylint: disable=pointless-statement
+
+ def testSliceBothOutOfLowerBorder(self):
+ """Get data from out of lower border"""
+ inst = WireData(b'0123456789')
+ with self.assertRaises(FormError):
+ inst[-12:-11] # pylint: disable=pointless-statement
+
+ def testSliceStartOutOfUpperBorder(self):
+ """Get data from out of upper border"""
+ inst = WireData(b'0123456789')
+ with self.assertRaises(FormError):
+ inst[11:] # pylint: disable=pointless-statement
+
+ def testSliceStopOutOfUpperBorder(self):
+ """Get data from out of upper border"""
+ inst = WireData(b'0123456789')
+ with self.assertRaises(FormError):
+ inst[:11] # pylint: disable=pointless-statement
+
+ def testSliceBothOutOfUpperBorder(self):
+ """Get data from out of lower border"""
+ inst = WireData(b'0123456789')
+ with self.assertRaises(FormError):
+ inst[10:20] # pylint: disable=pointless-statement
+
+ def testGetOneOutOfLowerBorder(self):
+ """Get item outside of range"""
+ inst = WireData(b'0123456789')
+ with self.assertRaises(FormError):
+ inst[-11] # pylint: disable=pointless-statement
+
+ def testGetOneOutOfUpperBorder(self):
+ """Get item outside of range"""
+ inst = WireData(b'0123456789')
+ with self.assertRaises(FormError):
+ inst[10] # pylint: disable=pointless-statement