package org.openrndr.extra.kdtree

import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import org.openrndr.math.*
import java.util.*
import kotlin.IllegalStateException
import kotlin.math.abs

/** built-in mapper for [Vector2] */
fun vector2Mapper(v: Vector2, dimension: Int): Double {
    return when (dimension) {
        0 -> v.x
        else -> v.y
    }
}

/** built-in mapper for [IntVector2]; coordinates are converted to [Double] */
fun intVector2Mapper(v: IntVector2, dimension: Int): Double {
    return when (dimension) {
        0 -> v.x.toDouble()
        else -> v.y.toDouble()
    }
}

/** built-in mapper for [Vector3] */
fun vector3Mapper(v: Vector3, dimension: Int): Double {
    return when (dimension) {
        0 -> v.x
        1 -> v.y
        else -> v.z
    }
}

/** built-in mapper for [Vector4] */
fun vector4Mapper(v: Vector4, dimension: Int): Double {
    return when (dimension) {
        0 -> v.x
        1 -> v.y
        2 -> v.z
        else -> v.w
    }
}

/**
 * A node of a k-d tree over items of type [T].
 *
 * Each node holds one [item] and splits its subtree on axis [dimension] at [median], which is the
 * coordinate of [item] on that axis. Items whose coordinate is `< median` are stored under
 * `children[0]`, all other items under `children[1]`.
 *
 * An empty tree is represented by a leaf root whose [item] is `null`.
 *
 * @param dimensions number of axes the tree cycles through
 * @param mapper returns the coordinate of an item along the given axis index
 */
class KDTreeNode<T>(val dimensions: Int, val mapper: (T, Int) -> Double) {
    var parent: KDTreeNode<T>? = null
    var median: Double = 0.0
    var dimension: Int = 0
    var children: Array<KDTreeNode<T>?> = arrayOfNulls(2)
    var item: T? = null

    internal val isLeaf: Boolean
        get() = children[0] == null && children[1] == null

    /** Inserts [item] below this node and returns the node that now holds it. */
    fun insert(item: T): KDTreeNode<T> {
        return insert(this, item, dimensions, mapper)
    }

    /** Removes the item held by [node] from the tree; always returns `null`. */
    fun remove(node: KDTreeNode<T>): KDTreeNode<T>? {
        return org.openrndr.extra.kdtree.remove(node, mapper)
    }

    /**
     * Returns the item closest to [query], or `null` if the tree is empty.
     * Unless [includeQuery] is set, the very same instance as [query] is skipped.
     */
    fun findNearest(query: T, includeQuery: Boolean = false): T? =
        findNearest(this, query, includeQuery)

    /** Returns up to [k] items closest to [query], nearest first. */
    fun findKNearest(query: T, k: Int, includeQuery: Boolean = false): List<T> {
        return findKNearest(this, query, k, includeQuery)
    }

    /** Returns all items whose distance to [query] is at most [radius]. */
    fun findAllInRadius(query: T, radius: Double, includeQuery: Boolean = false): List<T> {
        return findAllInRadius(this, query, radius, includeQuery)
    }

    override fun toString(): String {
        return "KDTreeNode{median=$median, item=$item, dimension=$dimension, " +
            "children=${children.contentToString()}} " + super.toString()
    }
}

/**
 * Descends from [root] following the split rule and stores [item] in the first leaf reached,
 * replacing whatever item that leaf held.
 *
 * NOTE(review): not referenced anywhere in this file; the public insert path is [insert].
 */
private fun <T> insertItem(root: KDTreeNode<T>, item: T): KDTreeNode<T> {
    return if (root.isLeaf) {
        root.item = item
        root
    } else {
        if (root.mapper(item, root.dimension) < root.median) {
            insertItem(root.children[0] ?: throw IllegalStateException("left is null"), item)
        } else {
            insertItem(root.children[1] ?: throw IllegalStateException("right is null"), item)
        }
    }
}

/**
 * Builds a balanced k-d tree from [items].
 *
 * Subtrees are built concurrently on [Dispatchers.Default]. The call blocks until the whole tree
 * is built and rethrows any failure of a subtask.
 *
 * [items] is passed to `selectNth`, which may reorder it.
 *
 * @param items the items to index; an empty list yields an empty tree (root with a `null` item)
 * @param dimensions number of axes to split on
 * @param mapper returns the coordinate of an item along an axis index in `0 until dimensions`
 * @return the root node
 */
fun <T> buildKDTree(items: MutableList<T>, dimensions: Int, mapper: (T, Int) -> Double): KDTreeNode<T> {
    val root = KDTreeNode(dimensions, mapper)

    // Fills [node] from [items] and launches child builds in [scope].
    fun buildTreeTask(
        scope: CoroutineScope,
        node: KDTreeNode<T>,
        items: MutableList<T>,
        dimensions: Int,
        levels: Int,
        mapper: (T, Int) -> Double
    ): KDTreeNode<T> {
        if (items.isEmpty()) {
            return node
        }
        val dimension = levels % dimensions
        node.dimension = dimension
        val median = selectNth(items, items.size / 2) { mapper(it, dimension) }
        node.median = mapper(median, dimension)
        node.item = median

        val leftItems = mutableListOf<T>()
        val rightItems = mutableListOf<T>()
        var medianSkipped = false
        for (item in items) {
            // Skip only the first occurrence of the median instance: the same object may appear
            // more than once in [items], and every other copy must still be placed in a subtree.
            if (!medianSkipped && item === median) {
                medianSkipped = true
                continue
            }
            if (mapper(item, dimension) < node.median) {
                leftItems.add(item)
            } else {
                rightItems.add(item)
            }
        }

        // validate split
        if (leftItems.size + rightItems.size + 1 != items.size) {
            throw IllegalStateException("left: ${leftItems.size}, right: ${rightItems.size}, items: ${items.size}")
        }

        if (leftItems.isNotEmpty()) {
            val left = KDTreeNode(dimensions, mapper)
            left.parent = node
            node.children[0] = left
            scope.launch { buildTreeTask(scope, left, leftItems, dimensions, levels + 1, mapper) }
        }
        if (rightItems.isNotEmpty()) {
            val right = KDTreeNode(dimensions, mapper)
            right.parent = node
            node.children[1] = right
            scope.launch { buildTreeTask(scope, right, rightItems, dimensions, levels + 1, mapper) }
        }
        return node
    }

    // Structured concurrency: runBlocking waits for every launched child and rethrows failures.
    runBlocking(Dispatchers.Default) {
        buildTreeTask(this, root, items, dimensions, 0, mapper)
    }
    return root
}

/** Squared Euclidean distance between [left] and [right] over the first [dimensions] axes. */
private fun <T> sqrDistance(left: T, right: T, dimensions: Int, mapper: (T, Int) -> Double): Double {
    var distance = 0.0
    for (i in 0 until dimensions) {
        val d = mapper(left, i) - mapper(right, i)
        distance += d * d
    }
    return distance
}

/** Returns every node reachable from [root] (including [root]), in depth-first order. */
fun <T> findAllNodes(root: KDTreeNode<T>): List<KDTreeNode<T>> {
    val stack = Stack<KDTreeNode<T>>()
    val all = ArrayList<KDTreeNode<T>>()
    stack.push(root)
    while (!stack.isEmpty()) {
        val node = stack.pop()
        all.add(node)
        node.children[0]?.let { stack.push(it) }
        node.children[1]?.let { stack.push(it) }
    }
    return all
}

/**
 * Returns up to [k] items closest to [query], ordered nearest first.
 *
 * Unless [includeQuery] is set, the very same instance as [query] is skipped.
 * Returns an empty list when `k <= 0` or the tree is empty.
 */
fun <T> findKNearest(
    root: KDTreeNode<T>,
    query: T,
    k: Int,
    includeQuery: Boolean = false
): List<T> {
    if (k <= 0 || root.item == null) {
        return emptyList()
    }

    // max-heap with size k: the head is the farthest of the current candidates
    val queue = PriorityQueue<Pair<KDTreeNode<T>, Double>>(k + 1) { nodeA, nodeB ->
        compareValues(nodeB.second, nodeA.second)
    }

    fun nearest(node: KDTreeNode<T>?) {
        if (node == null) {
            return
        }
        val dimensionValue = node.mapper(query, node.dimension)
        val route = if (dimensionValue < node.median) 0 else 1
        nearest(node.children[route])

        val item = node.item ?: error("item is null")
        val distance = sqrDistance(query, item, node.dimensions, node.mapper)
        if (includeQuery || item !== query) {
            if (queue.size < k || distance < queue.peek().second) {
                queue.add(Pair(node, distance))
                if (queue.size > k) {
                    queue.poll()
                }
            }
        }

        // Visit the far side only if the splitting plane is closer than the current k-th candidate.
        val d = abs(node.median - dimensionValue)
        if (queue.size < k || d * d < queue.peek().second) {
            nearest(node.children[1 - route])
        }
    }

    nearest(root)
    // Polling a max-heap yields farthest first; reverse to get nearest first.
    return generateSequence { queue.poll() }
        .mapNotNull { it.first.item }
        .toList()
        .reversed()
}

/**
 * Returns the item closest to [query], or `null` if the tree is empty.
 * Unless [includeQuery] is set, the very same instance as [query] is skipped.
 */
private fun <T> findNearest(root: KDTreeNode<T>, query: T, includeQuery: Boolean = false): T? {
    if (root.item == null) {
        return null
    }
    var bestDistance = Double.POSITIVE_INFINITY
    var bestNode: KDTreeNode<T>? = null

    fun search(node: KDTreeNode<T>?) {
        if (node == null) {
            return
        }
        val dimensionValue = root.mapper(query, node.dimension)
        val route = if (dimensionValue < node.median) 0 else 1
        search(node.children[route])

        val item = node.item ?: error("item is null")
        val distance = sqrDistance(query, item, root.dimensions, root.mapper)
        if (distance < bestDistance && (includeQuery || item !== query)) {
            bestDistance = distance
            bestNode = node
        }

        // Visit the far side only if the splitting plane is closer than the best match so far.
        val d = abs(node.median - dimensionValue)
        if (d * d < bestDistance) {
            search(node.children[1 - route])
        }
    }

    search(root)
    return bestNode?.item
}

/**
 * Returns all items within distance [radius] of [query] (boundary inclusive).
 *
 * Unless [includeQuery] is set, the very same instance as [query] is skipped, matching
 * [findNearest] and [findKNearest]. Returns an empty list for an empty tree.
 */
private fun <T> findAllInRadius(
    root: KDTreeNode<T>,
    query: T,
    radius: Double,
    includeQuery: Boolean = false
): List<T> {
    if (root.item == null) {
        return emptyList()
    }
    val sqrMaxDist = radius * radius
    val queue = ArrayDeque<KDTreeNode<T>>()
    queue.add(root)
    val results = mutableListOf<T>()

    while (queue.isNotEmpty()) {
        val node = queue.removeFirst()
        val item = node.item ?: error("item is null")
        val distance = sqrDistance(query, item, node.dimensions, node.mapper)
        if (distance <= sqrMaxDist && (includeQuery || item !== query)) {
            results.add(item)
        }

        // Always explore the side of the split the query lies on. Explore the other side too
        // whenever the splitting plane is within the radius, whether or not the near child exists.
        val dimensionValue = node.mapper(query, node.dimension)
        val near = if (dimensionValue < node.median) 0 else 1
        node.children[near]?.let { queue.add(it) }
        val d = abs(node.median - dimensionValue)
        if (d * d <= sqrMaxDist) {
            node.children[1 - near]?.let { queue.add(it) }
        }
    }
    return results
}

/**
 * Inserts [item] into the tree rooted at [root] and returns the node holding it.
 *
 * An empty tree (leaf root with a `null` item) stores the item in the root itself. Otherwise a new
 * leaf is appended following the split rule (`< median` goes to `children[0]`).
 */
private fun <T> insert(root: KDTreeNode<T>, item: T, dimensions: Int, mapper: (T, Int) -> Double): KDTreeNode<T> {
    if (root.item == null && root.isLeaf) {
        root.item = item
        root.median = mapper(item, root.dimension)
        return root
    }

    var node = root
    while (true) {
        val side = if (mapper(item, node.dimension) < node.median) 0 else 1
        val next = node.children[side]
        if (next != null) {
            node = next
        } else {
            // sit here
            val childDimension = (node.dimension + 1) % dimensions
            val child = KDTreeNode(dimensions, mapper)
            child.item = item
            child.dimension = childDimension
            child.median = mapper(item, childDimension)
            child.parent = node
            node.children[side] = child
            return child
        }
    }
}

/**
 * Removes the item held by [toRemove]; always returns `null`.
 *
 * A leaf is detached from its parent; a parentless leaf (the last item) becomes an empty root.
 * An inner node takes over a replacement item from one subtree, and the replacement's node is
 * removed recursively:
 * - the minimum on [toRemove]'s axis from the right subtree when there is no left subtree;
 * - otherwise the maximum on that axis from the left subtree.
 */
private fun <T> remove(toRemove: KDTreeNode<T>, mapper: (T, Int) -> Double): KDTreeNode<T>? {
    // trivial case
    if (toRemove.isLeaf) {
        val parent = toRemove.parent
        if (parent != null) {
            when {
                parent.children[0] === toRemove -> parent.children[0] = null
                parent.children[1] === toRemove -> parent.children[1] = null
                else -> {
                    // Not attached to its parent (e.g. already detached by an earlier remove): nothing to do.
                }
            }
        } else {
            toRemove.item = null
        }
        return null
    }

    val left = toRemove.children[0]
    // branch 0: take the maximum from the left subtree; branch 1: take the minimum from the right subtree
    val branch = if (left != null) 0 else 1
    val stack = Stack<KDTreeNode<T>>()
    stack.push(left ?: toRemove.children[1] ?: error("inner node without children"))

    var minValue = Double.POSITIVE_INFINITY
    var maxValue = Double.NEGATIVE_INFINITY
    var minArg: KDTreeNode<T>? = null
    var maxArg: KDTreeNode<T>? = null

    while (!stack.isEmpty()) {
        val node = stack.pop() ?: throw RuntimeException("null on stack")
        val value = mapper(node.item ?: throw IllegalStateException("item is null"), toRemove.dimension)
        if (value < minValue) {
            minValue = value
            minArg = node
        }
        if (value > maxValue) {
            maxValue = value
            maxArg = node
        }

        val c0 = node.children[0]
        val c1 = node.children[1]
        when {
            // Nodes split on another axis give no ordering on ours: search both children.
            node.dimension != toRemove.dimension -> {
                c0?.let { stack.push(it) }
                c1?.let { stack.push(it) }
            }
            // Same axis, looking for the minimum: it lies on the left side when there is one.
            branch == 1 -> (c0 ?: c1)?.let { stack.push(it) }
            // Same axis, looking for the maximum: it lies on the right side when there is one.
            else -> (c1 ?: c0)?.let { stack.push(it) }
        }
    }

    val replacement = if (branch == 1) {
        checkNotNull(minArg) { "minArg is null" }
    } else {
        checkNotNull(maxArg) { "maxArg is null" }
    }
    val replacementItem = replacement.item ?: throw IllegalStateException("item is null")
    toRemove.item = replacementItem
    toRemove.median = mapper(replacementItem, toRemove.dimension)
    remove(replacement, mapper)
    return null
}

/** Builds a 2-dimensional k-d tree of these vectors. */
@JvmName("kdTreeVector2")
fun Iterable<Vector2>.kdTree(): KDTreeNode<Vector2> {
    val items = this.toMutableList()
    return buildKDTree(items, 2, ::vector2Mapper)
}

/** Builds a 3-dimensional k-d tree of these vectors. */
@JvmName("kdTreeVector3")
fun Iterable<Vector3>.kdTree(): KDTreeNode<Vector3> {
    val items = this.toMutableList()
    return buildKDTree(items, 3, ::vector3Mapper)
}

/** Builds a 4-dimensional k-d tree of these vectors. */
@JvmName("kdTreeVector4")
fun Iterable<Vector4>.kdTree(): KDTreeNode<Vector4> {
    val items = this.toMutableList()
    return buildKDTree(items, 4, ::vector4Mapper)
}