diff --git a/orx-kdtree/src/demo/kotlin/DemoKNearestNeighbour01.kt b/orx-kdtree/src/demo/kotlin/DemoKNearestNeighbour01.kt index 0204a12b..bd5c30ba 100644 --- a/orx-kdtree/src/demo/kotlin/DemoKNearestNeighbour01.kt +++ b/orx-kdtree/src/demo/kotlin/DemoKNearestNeighbour01.kt @@ -2,13 +2,13 @@ import org.openrndr.application import org.openrndr.color.ColorRGBa import org.openrndr.extra.kdtree.buildKDTree import org.openrndr.extra.kdtree.findKNearest +import org.openrndr.extra.kdtree.kdTree import org.openrndr.extra.kdtree.vector2Mapper import org.openrndr.math.Vector2 import org.openrndr.shape.LineSegment fun main() { application { - configure { width = 1080 height = 720 @@ -18,12 +18,12 @@ fun main() { val points = MutableList(1000) { Vector2(Math.random() * width, Math.random() * height) } - val tree = buildKDTree(points, 2, ::vector2Mapper) + val tree = points.kdTree() extend { drawer.circles(points, 5.0) - val kNearest = findKNearest(tree, mouse.position, k=7, dimensions = 2, ::vector2Mapper) + val kNearest = tree.findKNearest(mouse.position, k = 7) drawer.fill = ColorRGBa.RED drawer.stroke = ColorRGBa.RED drawer.strokeWeight = 2.0 diff --git a/orx-kdtree/src/demo/kotlin/DemoNearestNeighbour01.kt b/orx-kdtree/src/demo/kotlin/DemoNearestNeighbour01.kt index fd5200da..418d2a32 100644 --- a/orx-kdtree/src/demo/kotlin/DemoNearestNeighbour01.kt +++ b/orx-kdtree/src/demo/kotlin/DemoNearestNeighbour01.kt @@ -1,8 +1,5 @@ import org.openrndr.application -import org.openrndr.extensions.SingleScreenshot -import org.openrndr.extra.kdtree.buildKDTree -import org.openrndr.extra.kdtree.findNearest -import org.openrndr.extra.kdtree.vector2Mapper +import org.openrndr.extra.kdtree.kdTree import org.openrndr.math.Vector2 fun main() { @@ -15,10 +12,10 @@ fun main() { val points = MutableList(1000) { Vector2(Math.random() * width, Math.random() * height) } - val tree = buildKDTree(points, 2, ::vector2Mapper) + val tree = points.kdTree() extend { drawer.circles(points, 5.0) - val nearest = findNearest(tree, mouse.position, 2, ::vector2Mapper) + val nearest = tree.findNearest(mouse.position) nearest?.let { drawer.circle(it.x, it.y, 20.0) } diff --git a/orx-kdtree/src/demo/kotlin/DemoRangeQuery01.kt b/orx-kdtree/src/demo/kotlin/DemoRangeQuery01.kt index 3a86b965..9a8b5834 100644 --- a/orx-kdtree/src/demo/kotlin/DemoRangeQuery01.kt +++ b/orx-kdtree/src/demo/kotlin/DemoRangeQuery01.kt @@ -1,8 +1,6 @@ import org.openrndr.application import org.openrndr.color.ColorRGBa -import org.openrndr.extra.kdtree.buildKDTree -import org.openrndr.extra.kdtree.findAllInRange -import org.openrndr.extra.kdtree.vector2Mapper +import org.openrndr.extra.kdtree.kdTree import org.openrndr.math.Vector2 @@ -18,13 +16,13 @@ fun main() { val points = MutableList(1000) { Vector2(Math.random() * width, Math.random() * height) } - val tree = buildKDTree(points, 2, ::vector2Mapper) + val tree = points.kdTree() val radius = 50.0 extend { drawer.circles(points, 5.0) - val allInRange = findAllInRange(tree, mouse.position, maxDistance = radius, dimensions = 2, ::vector2Mapper) + val allInRange = tree.findAllInRadius(mouse.position, radius = radius) drawer.fill = ColorRGBa.PINK drawer.stroke = ColorRGBa.PINK drawer.strokeWeight = 2.0 diff --git a/orx-kdtree/src/main/kotlin/KDTree.kt b/orx-kdtree/src/main/kotlin/KDTree.kt index 5a7705ee..2a84a7f8 100644 --- a/orx-kdtree/src/main/kotlin/KDTree.kt +++ b/orx-kdtree/src/main/kotlin/KDTree.kt @@ -4,10 +4,7 @@ import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.GlobalScope import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking -import org.openrndr.math.IntVector2 -import org.openrndr.math.Vector2 -import org.openrndr.math.Vector3 -import org.openrndr.math.Vector4 +import org.openrndr.math.* import java.util.* import kotlin.IllegalStateException import kotlin.math.abs @@ -47,7 +44,7 @@ fun vector4Mapper(v: Vector4, dimension: Int): Double { } } -class KDTreeNode { +class KDTreeNode(val dimensions: Int, val mapper: (T, Int) -> Double) { var parent: KDTreeNode? = null var median: Double = 0.0 var dimension: Int = 0 @@ -57,6 +54,26 @@ class KDTreeNode { internal val isLeaf: Boolean get() = children[0] == null && children[1] == null + + fun insert(item: T): KDTreeNode { + return insert(this, item, dimensions, mapper) + } + + fun remove(node: KDTreeNode): KDTreeNode? { + return org.openrndr.extra.kdtree.remove(node, mapper) + } + + + fun findNearest(query: T, includeQuery: Boolean = false): T? = findNearest(this, query, includeQuery) + + fun findKNearest(query: T, k: Int, includeQuery: Boolean = false): List { + return findKNearest(this, query, k, includeQuery) + } + + fun findAllInRadius(query: T, radius: Double, includeQuery: Boolean = false): List { + return findAllInRadius(this, query, radius, includeQuery) + } + override fun toString(): String { return "KDTreeNode{" + "median=" + median + @@ -68,24 +85,32 @@ class KDTreeNode { } } -fun insertItem(root: KDTreeNode, item: T, mapper: (T, Int) -> Double): KDTreeNode { +private fun insertItem(root: KDTreeNode, item: T): KDTreeNode { return if (root.isLeaf) { root.item = item root } else { - if (mapper(item, root.dimension) < root.median) { - insertItem(root.children[0] ?: throw IllegalStateException("left is null"), item, mapper) + if (root.mapper(item, root.dimension) < root.median) { + insertItem(root.children[0] ?: throw IllegalStateException("left is null"), item) } else { - insertItem(root.children[1] ?: throw IllegalStateException("right is null"), item, mapper) + insertItem(root.children[1] ?: throw IllegalStateException("right is null"), item) } } } + fun buildKDTree(items: MutableList, dimensions: Int, mapper: (T, Int) -> Double): KDTreeNode { - val root = KDTreeNode() + val root = KDTreeNode(dimensions, mapper) val start = System.currentTimeMillis() - fun buildTreeTask(scope: CoroutineScope, node: KDTreeNode, items: MutableList, dimensions: Int, levels: Int, mapper: (T, Int) -> Double): KDTreeNode { + fun buildTreeTask( + scope: CoroutineScope, + node: KDTreeNode, + items: MutableList, + dimensions: Int, + levels: Int, + mapper: (T, Int) -> Double + ): KDTreeNode { if (items.size > 0) { val dimension = levels % dimensions @@ -119,7 +144,7 @@ fun buildKDTree(items: MutableList, dimensions: Int, mapper: (T, Int) -> } if (leftItems.size > 0) { - node.children[0] = KDTreeNode() + node.children[0] = KDTreeNode(dimensions, mapper) node.children[0]?.let { it.parent = node @@ -129,7 +154,7 @@ fun buildKDTree(items: MutableList, dimensions: Int, mapper: (T, Int) -> } } if (rightItems.size > 0) { - node.children[1] = KDTreeNode() + node.children[1] = KDTreeNode(dimensions, mapper) node.children[1]?.let { it.parent = node scope.launch { @@ -147,7 +172,7 @@ fun buildKDTree(items: MutableList, dimensions: Int, mapper: (T, Int) -> runBlocking { job.join() } - println("building took ${System.currentTimeMillis()-start}ms") + println("building took ${System.currentTimeMillis() - start}ms") return root } @@ -185,45 +210,45 @@ fun findAllNodes(root: KDTreeNode): List> { fun findKNearest( root: KDTreeNode, - item: T, + query: T, k: Int, - dimensions: Int, - mapper: (T, Int) -> Double + includeQuery: Boolean = false ): List { // max-heap with size k - val queue = PriorityQueue, Double>>(k + 1) { - nodeA, nodeB -> compareValues(nodeB.second, nodeA.second) + val queue = PriorityQueue, Double>>(k + 1) { nodeA, nodeB -> + compareValues(nodeB.second, nodeA.second) } - fun nearest(node: KDTreeNode?, item: T) { + fun nearest(node: KDTreeNode?) { if (node != null) { - val dimensionValue = mapper(item, node.dimension) + val dimensionValue = node.mapper(query, node.dimension) val route: Int = if (dimensionValue < node.median) { - nearest(node.children[0], item) + nearest(node.children[0]) 0 } else { - nearest(node.children[1], item) + nearest(node.children[1]) 1 } - val distance = sqrDistance(item, node.item - ?: throw IllegalStateException("item is null"), dimensions, mapper) + val distance = sqrDistance(query, node.item ?: error("item is null"), node.dimensions, node.mapper) - if (queue.size < k || distance < queue.peek().second) { - queue.add(Pair(node, distance)) - if (queue.size > k) { - queue.poll() + if (includeQuery || node.item !== query) { + if (queue.size < k || distance < queue.peek().second) { + queue.add(Pair(node, distance)) + if (queue.size > k) { + queue.poll() + } } } val d = abs(node.median - dimensionValue) if (d * d < queue.peek().second || queue.size < k) { - nearest(node.children[1 - route], item) + nearest(node.children[1 - route]) } } } - nearest(root, item) + nearest(root) return generateSequence { queue.poll() } .map { it.first.item } @@ -231,77 +256,79 @@ fun findKNearest( .toList().reversed() } -fun findNearest(root: KDTreeNode, item: T, dimensions: Int, mapper: (T, Int) -> Double): T? { +private fun findNearest(root: KDTreeNode, query: T, includeQuery: Boolean = false): T? { var nearest = java.lang.Double.POSITIVE_INFINITY var nearestArg: KDTreeNode? = null - fun nearest(node: KDTreeNode?, item: T) { + fun nearest(node: KDTreeNode?) { if (node != null) { - - if (node.item == null) { - println(node) - } - - val route: Int = if (mapper(item, node.dimension) < node.median) { - nearest(node.children[0], item) + val route: Int = if (root.mapper(query, node.dimension) < node.median) { + nearest(node.children[0]) 0 } else { - nearest(node.children[1], item) + nearest(node.children[1]) 1 } - val distance = sqrDistance(item, node.item - ?: throw IllegalStateException("item is null"), dimensions, mapper) - if (distance < nearest) { + val distance = sqrDistance( + query, node.item + ?: error("item is null"), root.dimensions, root.mapper + ) + if (distance < nearest && (includeQuery || node.item !== query)) { nearest = distance nearestArg = node } - - val d = abs(node.median - mapper(item, node.dimension)) + val d = abs(node.median - root.mapper(query, node.dimension)) if (d * d < nearest) { - nearest(node.children[1 - route], item) + nearest(node.children[1 - route]) } } } - nearest(root, item) + nearest(root) return nearestArg?.item } -fun findAllInRange( +private fun findAllInRadius( root: KDTreeNode, - item: T, - maxDistance: Double, - dimensions: Int, - mapper: (T, Int) -> Double -) : List { + query: T, + radius: Double, + includeQuery: Boolean = false +): List { - val sqrMaxDist = maxDistance * maxDistance - val queue = kotlin.collections.ArrayDeque?>() + val sqrMaxDist = radius * radius + val queue = ArrayDeque>() queue.add(root) val results = mutableListOf() while (queue.isNotEmpty()) { val node = queue.removeFirst() - if (node != null) { - val dimensionValue = mapper(item, node.dimension) - val distance = sqrDistance(item, node.item - ?: throw IllegalStateException("item is null"), dimensions, mapper) - if (distance <= sqrMaxDist) { - results.add(node.item) - } + val dimensionValue = node.mapper(query, node.dimension) + val distance = sqrDistance( + query, node.item + ?: error("item is null"), node.dimensions, node.mapper + ) + if (distance <= sqrMaxDist && (includeQuery || node.item != query)) { + results.add(node.item) + } - val route: Int = if (dimensionValue < node.median) { - queue.add(node.children[0]) - 0 - } else { - queue.add(node.children[1]) - 1 - } + val route: Int = if (dimensionValue < node.median && node.children[0] != null) { + queue.add(node.children[0]) + 0 + } else if (node.children[1] != null) { + queue.add(node.children[1]) + 1 + } else { + -1 + } + if (route != -1) { val d = abs(node.median - dimensionValue) if (d * d <= sqrMaxDist) { - queue.add(node.children[1 - route]) + val c = node.children[1 - route] + if (c != null) { + queue.add(c) + } } } } @@ -309,7 +336,7 @@ fun findAllInRange( return results.filterNotNull() } -fun insert(root: KDTreeNode, item: T, dimensions: Int, mapper: (T, Int) -> Double): KDTreeNode { +private fun insert(root: KDTreeNode, item: T, dimensions: Int, mapper: (T, Int) -> Double): KDTreeNode { val stack = Stack>() stack.push(root) @@ -324,7 +351,7 @@ fun insert(root: KDTreeNode, item: T, dimensions: Int, mapper: (T, Int) - stack.push(node.children[0]) } else { // sit here - node.children[0] = KDTreeNode() + node.children[0] = KDTreeNode(dimensions, mapper) node.children[0]?.item = item node.children[0]?.dimension = (node.dimension + 1) % dimensions node.children[0]?.median = mapper(item, (node.dimension + 1) % dimensions) @@ -336,7 +363,7 @@ fun insert(root: KDTreeNode, item: T, dimensions: Int, mapper: (T, Int) - stack.push(node.children[1]) } else { // sit here - node.children[1] = KDTreeNode() + node.children[1] = KDTreeNode(dimensions, mapper) node.children[1]?.item = item node.children[1]?.dimension = (node.dimension + 1) % dimensions node.children[1]?.median = mapper(item, (node.dimension + 1) % dimensions) @@ -348,7 +375,7 @@ fun insert(root: KDTreeNode, item: T, dimensions: Int, mapper: (T, Int) - } } -fun remove(toRemove: KDTreeNode, mapper: (T, Int) -> Double): KDTreeNode? { +private fun remove(toRemove: KDTreeNode, mapper: (T, Int) -> Double): KDTreeNode? { // trivial case if (toRemove.isLeaf) { val p = toRemove.parent @@ -439,3 +466,21 @@ fun remove(toRemove: KDTreeNode, mapper: (T, Int) -> Double): KDTreeNode< } return null } + +@JvmName("kdTreeVector2") +fun Iterable.kdTree(): KDTreeNode { + val items = this.toMutableList() + return buildKDTree(items, 2, ::vector2Mapper) +} + +@JvmName("kdTreeVector3") +fun Iterable.kdTree(): KDTreeNode { + val items = this.toMutableList() + return buildKDTree(items, 3, ::vector3Mapper) +} + +@JvmName("kdTreeVector4") +fun Iterable.kdTree(): KDTreeNode { + val items = this.toMutableList() + return buildKDTree(items, 4, ::vector4Mapper) +} \ No newline at end of file