648f7141f4a2c2aad8cf7b46ed374491f16c063f
[blender-addons-contrib.git] / add_mesh_space_tree / kdtree.py
1 # ##### BEGIN GPL LICENSE BLOCK #####
2 #
3 #  SCA Tree Generator, a Blender addon
4 #  (c) 2013 Michel J. Anders (varkenvarken)
5 #
6 #  This module is: kdtree.py
7 #  a pure python implementation of a kdtree
8 #
9 #  This program is free software; you can redistribute it and/or
10 #  modify it under the terms of the GNU General Public License
11 #  as published by the Free Software Foundation; either version 2
12 #  of the License, or (at your option) any later version.
13 #
14 #  This program is distributed in the hope that it will be useful,
15 #  but WITHOUT ANY WARRANTY; without even the implied warranty of
16 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
17 #  GNU General Public License for more details.
18 #
19 #  You should have received a copy of the GNU General Public License
20 #  along with this program; if not, write to the Free Software Foundation,
21 #  Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
22 #
23 # ##### END GPL LICENSE BLOCK #####
24
25 # <pep8 compliant>
26
27 from copy import copy, deepcopy
28
29
30 class Hyperrectangle:
31     '''an axis aligned bounding box of arbitrary dimension'''
32
33     def __init__(self, dim, min, max):
34         self.dim = dim
35         self.min = deepcopy(min)  # min and max should never point to the same instance
36         self.max = deepcopy(max)
37
38     def extend(self, pos):
39         '''adapt the hyperectangle if necessary so it will contain pos.'''
40         for i in range(self.dim):
41             if pos[i] < self.min[i]:
42                 self.min[i] = pos[i]
43             elif pos[i] > self.max[i]:
44                 self.max[i] = pos[i]
45
46     def distance_squared(self, pos):
47         '''return the distance squared to the nearest edge, or zero if pos lies within the hyperrectangle'''
48         result = 0.0
49         for i in range(self.dim):
50             if pos[i] < self.min[i]:
51                 result += (pos[i] - self.min[i]) ** 2
52             elif pos[i] > self.max[i]:
53                 result += (pos[i] - self.max[i]) ** 2
54         return result
55
56     def __str__(self):
57         return "[(%d) %s:%s]" % (int(self.dim), str(self.min), str(self.max))
58
59
60 class Node:
61     """implements a node in a kd-tree"""
62
63     def __init__(self, pos, data=None):
64         self.pos = deepcopy(pos)
65         self.data = data
66         self.left = None
67         self.right = None
68         self.dim = len(pos)
69         self.dir = 0
70         self.count = 0
71         self.level = 0
72         self.rect = Hyperrectangle(self.dim, pos, pos)
73
74     def addleft(self, node):
75         self.left = node
76         self.rect.extend(node.pos)
77         node.level = self.level + 1
78         node.dir = (self.dir + 1) % self.dim
79
80     def addright(self, node):
81         self.right = node
82         self.rect.extend(node.pos)
83         node.level = self.level + 1
84         node.dir = (self.dir + 1) % self.dim
85
86     def distance_squared(self, pos):
87         d = self.pos - pos
88         return d.dot(d)
89
90     def _str(self, level):
91         s = '  ' * level + str(self.dir) + ' ' + str(self.pos) + ' ' + str(self.rect) + '\n'
92         return s + ('' if self.left is None else 'L:' + self.left._str(level + 1)) + ('' if self.right is None else 'R:' + self.right._str(level + 1))
93
94     def __str__(self):
95         return self._str(0)
96
97
98 class Tree:
99     """implements a kd-tree"""
100
101     def __init__(self, dim):
102         self.root = None
103         self.nnearest = 0  # number of nearest neighbor queries
104         self.count = 0  # number of nodes visited
105         self.level = 0  # deepest node level
106
107     def resetcounters(self):
108         self.nnearest = 0  # number of nearest neighbor queries
109         self.count = 0  # number of nodes visited
110
111     def _insert(self, node, pos, data):
112         if pos[node.dir] < node.pos[node.dir]:
113             if node.left is None:
114                 node.addleft(Node(pos, data))
115                 return node.left
116             else:
117                 node.rect.extend(pos)
118                 return self._insert(node.left, pos, data)
119         else:
120             if node.right is None:
121                 node.addright(Node(pos, data))
122                 return node.right
123             else:
124                 node.rect.extend(pos)
125                 return self._insert(node.right, pos, data)
126
127     def insert(self, pos, data):
128         if self.root is None:
129             self.root = Node(pos, data)
130             self.level = self.root.level
131             return self.root
132         else:
133             node = self._insert(self.root, pos, data)
134             if node.level > self.level:
135                 self.level = node.level
136             return node
137
138     def _nearest(self, node, pos, checkempty, level=0):
139
140         self.count += 1
141
142         dir = node.dir
143         d = pos[dir] - node.pos[dir]
144
145         result = node
146         distsq = None
147         if checkempty and (node.data is None):
148             result = None
149         else:
150             distsq = node.distance_squared(pos)
151
152         if d <= 0:
153             neartree = node.left
154             fartree = node.right
155         else:
156             neartree = node.right
157             fartree = node.left
158
159         if neartree is not None:
160             nearnode, neardistsq = self._nearest(neartree, pos, checkempty, level + 1)
161             if (result is None) or (neardistsq is not None and neardistsq < distsq):
162                 result, distsq = nearnode, neardistsq
163
164         if fartree is not None:
165             if (result is None) or (fartree.rect.distance_squared(pos) < distsq):
166                 farnode, fardistsq = self._nearest(fartree, pos, checkempty, level + 1)
167                 if (result is None) or (fardistsq is not None and fardistsq < distsq):
168                     result, distsq = farnode, fardistsq
169
170         return result, distsq
171
172     def nearest(self, pos, checkempty=False):
173         self.nnearest += 1
174         if self.root is None:
175             return None, None
176         self.root.count = 0
177         node, distsq = self._nearest(self.root, pos, checkempty)
178         self.count += self.root.count
179         return node, distsq
180
181     def __str__(self):
182         return str(self.root)
183
184 if __name__ == "__main__":
185
186     class vector(list):
187
188         def __init__(self, *args):
189             super().__init__([float(a) for a in args])
190
191         def __str__(self):
192             return "<%.1f %.1f %.1f>" % tuple(self[0:3])
193
194         def __sub__(self, other):
195             return vector(self[0] - other[0], self[1] - other[1], self[2] - other[2])
196
197         def __add__(self, other):
198             return vector(self[0] + other[0], self[1] + other[1], self[2] + other[2])
199
200         def __mul__(self, other):
201             s = sum(self[i] * other[i] for i in (0, 1, 2))
202             #print("ds",s,self,other,[self[i]*other[i] for i in (0,1,2)])
203             return s
204
205         def dot(self, other):
206             return sum(self[k] * other[k] for k in (0, 1, 2))
207
208     from random import random, seed, shuffle
209     from time import time
210     import unittest
211
212     class TestVector(unittest.TestCase):
213         def test_ops(self):
214             v1 = vector(1, 0, 0)
215             v2 = vector(2, 1, 0)
216             self.assertAlmostEqual(v1 * v2, 2.0)
217             self.assertAlmostEqual(v1.dot(v2), 2.0)
218             v3 = vector(-1, -1, 0)
219             self.assertListEqual(v1 - v2, v3)
220             v4 = vector(3, 1, 0)
221             self.assertListEqual(v1 + v2, v4)
222
223     class TestHyperRectangle(unittest.TestCase):
224
225         def setUp(self):
226             self.left = vector(0, 0, 0)
227             self.right = vector(1, 1, 1)
228             self.left1 = vector(-1, 0, 0)
229             self.left2 = vector(0, -1, 0)
230             self.left3 = vector(0, 0, -1)
231             self.right1 = vector(2, 0, 0)
232             self.right2 = vector(0, 2, 0)
233             self.right3 = vector(0, 0, 2)
234
235         def test_constructor(self):
236             hr = Hyperrectangle(3, self.left, self.right)
237             self.assertListEqual(hr.min, self.left)
238             self.assertListEqual(hr.max, self.right)
239
240         def test_extend(self):
241             hr = Hyperrectangle(3, self.left, self.right)
242             hr.extend(self.left1)
243             self.assertListEqual(hr.min, [-1, 0, 0])
244             self.assertListEqual(hr.max, [1, 1, 1])
245             hr.extend(self.left2)
246             self.assertListEqual(hr.min, [-1, -1, 0])
247             self.assertListEqual(hr.max, [1, 1, 1])
248             hr.extend(self.left3)
249             self.assertListEqual(hr.min, [-1, -1, -1])
250             self.assertListEqual(hr.max, [1, 1, 1])
251             hr.extend(self.right1)
252             self.assertListEqual(hr.min, [-1, -1, -1])
253             self.assertListEqual(hr.max, [2, 1, 1])
254             hr.extend(self.right2)
255             self.assertListEqual(hr.min, [-1, -1, -1])
256             self.assertListEqual(hr.max, [2, 2, 1])
257             hr.extend(self.right3)
258             self.assertListEqual(hr.min, [-1, -1, -1])
259             self.assertListEqual(hr.max, [2, 2, 2])
260
261         def test_distance_squared(self):
262             hr = Hyperrectangle(3, self.left, self.right)
263             self.assertAlmostEqual(hr.distance_squared(vector(0.5, 0.5, 0.5)), 0.0)
264             self.assertAlmostEqual(hr.distance_squared(vector(0, 0, 0)), 0.0)
265             self.assertAlmostEqual(hr.distance_squared(vector(-1, 0, 0)), 1.0)
266             self.assertAlmostEqual(hr.distance_squared(vector(2, 0, 0)), 1.0)
267             self.assertAlmostEqual(hr.distance_squared(vector(2, 2, 2)), 3.0)
268             self.assertAlmostEqual(hr.distance_squared(vector(0.5, 2, 2)), 2.0)
269             self.assertAlmostEqual(hr.distance_squared(vector(0.5, -1, -1)), 2.0)
270             self.assertAlmostEqual(hr.distance_squared(vector(0.5, 0.5, 2)), 1.0)
271
272     class TestTree(unittest.TestCase):
273
274         def setUp(self):
275             seed(42)
276             r = (-1, 0, 1)
277             self.points = [vector(k, l, m) for k in r for l in r for m in r]
278             d = (-0.1, 0, 0.1)
279             self.d = [vector(k, l, m) for k in d for l in d for m in d if (k * l * m) != 0]
280             self.repeats = 4
281
282         def test_simple(self):
283             tree = Tree(3)
284             p1 = vector(0, 0, 0)
285             p2 = vector(-1, 0, 0)
286             p3 = vector(-1, 1, 0)
287             d = vector(0.1, 0.1, 0.1)
288
289             tree.insert(p1, p1)
290             node, distsq = tree.nearest(p1)
291             self.assertListEqual(node.pos, p1)
292             self.assertAlmostEqual(distsq, 0.0)
293             node, distsq = tree.nearest(p1 + d)
294             self.assertListEqual(node.pos, p1)
295             self.assertAlmostEqual(distsq, 0.03)
296
297             tree.insert(p2, p2)
298             node, distsq = tree.nearest(p1)
299             self.assertListEqual(node.pos, p1)
300             self.assertAlmostEqual(distsq, 0.0)
301             node, distsq = tree.nearest(p1 + d)
302             self.assertListEqual(node.pos, p1)
303             self.assertAlmostEqual(distsq, 0.03)
304
305             node, distsq = tree.nearest(p2)
306             self.assertListEqual(node.pos, p2)
307             self.assertAlmostEqual(distsq, 0.0)
308             node, distsq = tree.nearest(p2 + d)
309             self.assertListEqual(node.pos, p2)
310             self.assertAlmostEqual(distsq, 0.03)
311
312             tree.insert(p3, p3)
313             node, distsq = tree.nearest(p1)
314             self.assertListEqual(node.pos, p1)
315             self.assertAlmostEqual(distsq, 0.0)
316             node, distsq = tree.nearest(p1 + d)
317             self.assertListEqual(node.pos, p1)
318             self.assertAlmostEqual(distsq, 0.03)
319
320             node, distsq = tree.nearest(p2)
321             self.assertListEqual(node.pos, p2)
322             self.assertAlmostEqual(distsq, 0.0)
323             node, distsq = tree.nearest(p2 + d)
324             self.assertListEqual(node.pos, p2)
325             self.assertAlmostEqual(distsq, 0.03)
326
327             node, distsq = tree.nearest(p3)
328             self.assertListEqual(node.pos, p3)
329             self.assertAlmostEqual(distsq, 0.0)
330             node, distsq = tree.nearest(p3 + d)
331             self.assertListEqual(node.pos, p3)
332             self.assertAlmostEqual(distsq, 0.03)
333
334         def test_nearest(self):
335             for n in range(self.repeats):
336                 tree = Tree(3)
337                 shuffle(self.points)
338                 for p in self.points:
339                     tree.insert(p, p)  # data equal to position
340
341                 for p in self.points:
342                     for d in self.d:
343                         node, distsq = tree.nearest(p + d)
344                         s = "%s %s %s %s\n%s" % (str(p + d), str(p), str(d), str(node.pos), str(tree.root))
345                         self.assertListEqual(node.pos, p, msg=s)
346                         self.assertListEqual(node.data, p)
347                         self.assertAlmostEqual(distsq, d.dot(d), msg=s)
348
349                 for p in self.points:
350                     node, distsq = tree.nearest(p)
351                     self.assertListEqual(node.pos, p)
352                     self.assertListEqual(node.data, p)
353                     self.assertAlmostEqual(distsq, 0.0)
354
355         def test_nearest_empty(self):
356             for n in range(self.repeats):
357                 tree = Tree(3)
358                 shuffle(self.points)
359                 for p in self.points:
360                     tree.insert(p, p)  # data equal to position
361
362                 for p in self.points:
363                     for d in self.d:
364                         node, distsq = tree.nearest(p + d)
365                         s = "%s %s %s %s\n%s" % (str(p + d), str(p), str(d), str(node.pos), str(tree.root))
366                         self.assertListEqual(node.pos, p, msg=s)
367                         self.assertListEqual(node.data, p)
368                         self.assertAlmostEqual(distsq, d.dot(d), msg=s)
369
370                 for p in self.points:
371                     node, distsq = tree.nearest(p)
372                     self.assertListEqual(node.pos, p)
373                     self.assertListEqual(node.data, p)
374                     self.assertAlmostEqual(distsq, 0.0)
375
376                 # zeroing out a node should not affect retrieving any other node ...
377                 node, _ = tree.nearest(self.points[-1])  # last point
378                 node.data = None
379                 for p in self.points[:-1]:  # all but last
380                     for d in self.d:
381                         node, distsq = tree.nearest(p + d)
382                         s = "%s %s %s %s\n%s" % (str(p + d), str(p), str(d), str(node.pos), str(tree.root))
383                         self.assertListEqual(node.pos, p, msg=s)
384                         self.assertListEqual(node.data, p)
385                         self.assertAlmostEqual(distsq, d.dot(d), msg=s)
386
387                 for p in self.points[:-1]:  # all but last
388                     node, distsq = tree.nearest(p)
389                     self.assertListEqual(node.pos, p)
390                     self.assertListEqual(node.data, p)
391                     self.assertAlmostEqual(distsq, 0.0)
392
393                 # ... also, we should be able to retrieve the node anyway ...
394                 node, distsq = tree.nearest(self.points[-1])
395                 self.assertListEqual(node.pos, self.points[-1])
396                 self.assertIsNone(node.data)
397                 self.assertAlmostEqual(distsq, 0.0)
398
399                 # ... even for points nearby ...
400                 for d in self.d:
401                     node, distsq = tree.nearest(self.points[-1] + d)
402                     self.assertEqual(node.pos, self.points[-1])
403                     self.assertIsNone(node.data)
404                     self.assertAlmostEqual(distsq, d.dot(d))
405
406                 # ... unless we set checkempty
407                 node, distsq = tree.nearest(self.points[-1], checkempty=True)
408                 self.assertNotEqual(node.pos, self.points[-1])
409                 self.assertIsNotNone(node.data)
410                 self.assertAlmostEqual(distsq, 1.0)  # on a perpendicular grid nearest neighbor is at most 1 away
411
412         def test_performance(self):
413             tree = Tree(3)
414             tsize = 1000
415             qsize = 1000
416             emptyq = 3
417
418             print("<performance test, may take several seconds>")
419             qpos = [vector(random(), random(), random()) for p in range(qsize)]
420             for p in range(tsize):
421                 pos = vector(random(), random(), random())
422                 tree.insert(pos, pos)
423             s = time()
424             for p in qpos:
425                 node, distsq = tree.nearest(p)
426             e = time() - s
427             print("queries|tree size|tree height|empties|query load|query time")
428             print("{0:7d}|{2:9d}|{1.level:11d}|      0|{3:10.2f}|{4:10.1f}".format(qsize, tree, tsize, float(tree.count) / qsize, e))
429
430             tree.resetcounters()
431             empty = []
432             for p in range(tsize * 9):
433                 pos = vector(random(), random(), random())
434                 tree.insert(pos, pos)
435                 if (p % emptyq) == 1:
436                     empty.append(pos)
437             s = time()
438             for p in qpos:
439                 node, distsq = tree.nearest(p)
440             e2 = time() - s
441             print("{0:7d}|{2:9d}|{1.level:11d}|      0|{3:10.2f}|{4:10.1f}".format(qsize, tree, tsize * 10, float(tree.count) / qsize, e2))
442
443             self.assertLess(e2, 3 * e, msg="a 10x bigger tree shouldn't take more than 3x the time to query")
444
445             for p in empty:
446                 node, distsq = tree.nearest(p)
447                 node.data = None
448             tree.resetcounters()
449             s = time()
450             for p in qpos:
451                 node, distsq = tree.nearest(p, checkempty=True)
452             e3 = time() - s
453             print("{0:7d}|{2:9d}|{1.level:11d}|{5:7d}|{3:10.2f}|{4:10.1f}".format(qsize, tree, tsize * 10, float(tree.count) / qsize, e3, tsize * 10 // emptyq))
454
455     unittest.main()