[[!tag obnam btree]]
I think I need a B-tree implementation for Obnam, in Python. I could not find anything suitable so I wrote my own. However, since it about two decades since my data structures class at university, I probably messed it up. Please tell me how?
I include the code below, and it can also be found via bzr:
bzr get http://code.liw.fi/btree/bzr/trunk/
The code in bzr may get updated; I will keep the code below static. The bzr branch also contains some automatic test cases.
One of the requirements I have for the B-tree code is that it needs to update things via copy-on-write. In Obnam, I will not overwrite data on disk, I will instead write a new file, and then do garbage collection at a later time to reclaim the files that are no longer needed. This will be necessary for implementing backup generations, for example. That's why some of the code might be a bit weird.
Once I have some confidence that my code works, I will extend the tree code to use some external, user-provided mechanism for storing the nodes, and to use the size of the nodes in bytes as the limiting factor, not the number of keys.
In addition to bugs, I welcome any other feedback.
class Node(dict):
'''Abstract base class for index and leaf nodes.
A node may be initialized with a list of (key, value) pairs. For
leaf nodes, the values are the actual values. For index nodes, they
are references to other nodes.
'''
def keys(self):
'''Return keys in the node, sorted.'''
return sorted(dict.keys(self))
def first_key(self):
'''Return smallest key in the node.'''
return self.keys()[0]
def pairs(self, exclude=None):
'''Return (key, value) pairs in the node.
``exclude`` can be set to a list of keys that should be excluded
from the list.
'''
if exclude is None:
exclude = []
return sorted((key, self[key]) for key in self if key not in exclude)
class LeafNode(Node):
'''Leaf node in the tree.
A leaf node contains key/value pairs, and has no children.
'''
pass
class IndexNode(Node):
'''Index node in the tree.
An index node contains pairs of keys and references to other nodes.
The other nodes may be either index nodes or leaf nodes.
'''
def __init__(self, pairs):
for key, child in pairs:
assert type(key) == str
assert isinstance(child, IndexNode) or isinstance(child, LeafNode)
dict.__init__(self, pairs)
def find_key_for_child_containing(self, key):
'''Return key for the child that contains ``key``.'''
for k in reversed(self.keys()):
if key >= k:
return k
return None
class BTree(object):
'''B-tree.
The tree is balanced, and has a fan-out factor given to the initializer
as its only argument. The fan-out factor determines how aggressively
the tree expands at each level.
Three basic operations are available to the tree: lookup, insert, and
remove.
'''
def __init__(self, fanout):
self.root = IndexNode([])
self.fanout = fanout
self.min_index_length = self.fanout
self.max_index_length = 2 * self.fanout + 1
def lookup(self, key):
'''Return value corresponding to ``key``.
If the key is not in the tree, raise ``KeyError``.
'''
return self._lookup(self.root, key)
def _lookup(self, node, key):
if isinstance(node, LeafNode):
return node[key]
else:
k = node.find_key_for_child_containing(key)
if k is None:
raise KeyError(key)
else:
return self._lookup(node[k], key)
def insert(self, key, value):
'''Insert a new key/value pair into the tree.
If the key already existed in the tree, the old value is silently
forgotten.
'''
a, b = self._insert(self.root, key, value)
if b is None:
self.root = a
else:
self.root = IndexNode([(a.first_key(), a),
(b.first_key(), b)])
def _insert(self, node, key, value):
if isinstance(node, LeafNode):
return self._insert_into_leaf(node, key, value)
elif len(node) == 0:
return self._insert_into_empty_root(key, value)
elif len(node) == self.max_index_length:
return self._insert_into_full_index(node, key, value)
else:
return self._insert_into_nonfull_index(node, key, value)
def _insert_into_leaf(self, leaf, key, value):
pairs = sorted(leaf.pairs(exclude=[key]) + [(key, value)])
if len(pairs) <= self.fanout:
return LeafNode(pairs), None
else:
n = len(pairs) / 2
leaf1 = LeafNode(pairs[:n])
leaf2 = LeafNode(pairs[n:])
return leaf1, leaf2
def _insert_into_empty_root(self, key, value):
leaf = LeafNode([(key, value)])
return IndexNode([(leaf.first_key(), leaf)]), None
def _insert_into_full_index(self, node, key, value):
# A full index node needs to be split, then key/value inserted into
# one of the halves.
pairs = node.pairs()
n = len(pairs) / 2
node1 = IndexNode(pairs[:n])
node2 = IndexNode(pairs[n:])
if key < node2.first_key():
a, b = self._insert(node1, key, value)
assert b is None
return a, node2
else:
a, b = self._insert(node2, key, value)
assert b is None
return node1, a
def _insert_into_nonfull_index(self, node, key, value):
# Insert into correct child, get up to two replacements for
# that child.
k = node.find_key_for_child_containing(key)
if k is None:
k = node.first_key()
a, b = self._insert(node[k], key, value)
assert a is not None
pairs = node.pairs(exclude=[k]) + [(a.first_key(), a)]
if b is not None:
pairs += [(b.first_key(), b)]
pairs.sort()
assert len(pairs) <= self.max_index_length
return IndexNode(pairs), None
def remove(self, key):
'''Remove ``key`` and its associated value from tree.
If key is not in the tree, ``KeyValue`` is raised.
'''
self.root = self._remove(self.root, key)
if self.root is None:
self.root = IndexNode([])
def _remove(self, node, key):
if isinstance(node, LeafNode):
return self._remove_from_leaf(node, key)
else:
k = node.find_key_for_child_containing(key)
if k is None:
raise KeyError(key)
elif len(node[k]) <= self.min_index_length:
return self._remove_from_minimal_index(node, key, k)
else:
return self._remove_from_nonminimal_index(node, key, k)
def _remove_from_leaf(self, node, key):
if key in node:
pairs = node.pairs(exclude=[key])
if pairs:
return LeafNode(pairs)
else:
return None
else:
raise KeyError(key)
def _merge(self, n1, n2):
if isinstance(n1, IndexNode):
assert isinstance(n2, IndexNode)
return IndexNode(n1.pairs() + n2.pairs())
else:
assert isinstance(n1, LeafNode)
assert isinstance(n2, LeafNode)
return LeafNode(n1.pairs() + n2.pairs())
def _remove_from_minimal_index(self, node, key, child_key):
exclude = [child_key]
new_ones = []
child = self._remove(node[child_key], key)
if child is not None:
keys = node.keys()
i = keys.index(child_key)
# If possible, merge with left or right sibling.
if i > 0 and len(node[keys[i-1]]) < self.max_index_length:
new_ones.append(self._merge(node[keys[i-1]], child))
exclude.append(keys[i-1])
elif i+1 < len(keys) and len(node[keys[i+1]]) < self.max_index_length:
new_ones.append(self._merge(node[keys[i+1]], child))
exclude.append(keys[i+1])
else:
new_ones.append(child)
others = node.pairs(exclude=exclude)
if others + new_ones:
return IndexNode(others + [(n.first_key(), n) for n in new_ones])
else:
return None
def _remove_from_nonminimal_index(self, node, key, child_key):
child = self._remove(node[child_key], key)
pairs = node.pairs(exclude=[child_key])
if child is not None:
pairs += [(child.first_key(), child)]
pairs.sort()
assert pairs
return IndexNode(pairs)