[orx-kdtree] Refactor KDTree interface to improve ease of use

This commit is contained in:
Edwin Jakobs
2022-01-01 13:46:52 +01:00
parent 7cba88143d
commit e7b143493c
4 changed files with 129 additions and 89 deletions

View File

@@ -2,13 +2,13 @@ import org.openrndr.application
import org.openrndr.color.ColorRGBa import org.openrndr.color.ColorRGBa
import org.openrndr.extra.kdtree.buildKDTree import org.openrndr.extra.kdtree.buildKDTree
import org.openrndr.extra.kdtree.findKNearest import org.openrndr.extra.kdtree.findKNearest
import org.openrndr.extra.kdtree.kdTree
import org.openrndr.extra.kdtree.vector2Mapper import org.openrndr.extra.kdtree.vector2Mapper
import org.openrndr.math.Vector2 import org.openrndr.math.Vector2
import org.openrndr.shape.LineSegment import org.openrndr.shape.LineSegment
fun main() { fun main() {
application { application {
configure { configure {
width = 1080 width = 1080
height = 720 height = 720
@@ -18,12 +18,12 @@ fun main() {
val points = MutableList(1000) { val points = MutableList(1000) {
Vector2(Math.random() * width, Math.random() * height) Vector2(Math.random() * width, Math.random() * height)
} }
val tree = buildKDTree(points, 2, ::vector2Mapper) val tree = points.kdTree()
extend { extend {
drawer.circles(points, 5.0) drawer.circles(points, 5.0)
val kNearest = findKNearest(tree, mouse.position, k=7, dimensions = 2, ::vector2Mapper) val kNearest = tree.findKNearest(mouse.position, k = 7)
drawer.fill = ColorRGBa.RED drawer.fill = ColorRGBa.RED
drawer.stroke = ColorRGBa.RED drawer.stroke = ColorRGBa.RED
drawer.strokeWeight = 2.0 drawer.strokeWeight = 2.0

View File

@@ -1,8 +1,5 @@
import org.openrndr.application import org.openrndr.application
import org.openrndr.extensions.SingleScreenshot import org.openrndr.extra.kdtree.kdTree
import org.openrndr.extra.kdtree.buildKDTree
import org.openrndr.extra.kdtree.findNearest
import org.openrndr.extra.kdtree.vector2Mapper
import org.openrndr.math.Vector2 import org.openrndr.math.Vector2
fun main() { fun main() {
@@ -15,10 +12,10 @@ fun main() {
val points = MutableList(1000) { val points = MutableList(1000) {
Vector2(Math.random() * width, Math.random() * height) Vector2(Math.random() * width, Math.random() * height)
} }
val tree = buildKDTree(points, 2, ::vector2Mapper) val tree = points.kdTree()
extend { extend {
drawer.circles(points, 5.0) drawer.circles(points, 5.0)
val nearest = findNearest(tree, mouse.position, 2, ::vector2Mapper) val nearest = tree.findNearest(mouse.position)
nearest?.let { nearest?.let {
drawer.circle(it.x, it.y, 20.0) drawer.circle(it.x, it.y, 20.0)
} }

View File

@@ -1,8 +1,6 @@
import org.openrndr.application import org.openrndr.application
import org.openrndr.color.ColorRGBa import org.openrndr.color.ColorRGBa
import org.openrndr.extra.kdtree.buildKDTree import org.openrndr.extra.kdtree.kdTree
import org.openrndr.extra.kdtree.findAllInRange
import org.openrndr.extra.kdtree.vector2Mapper
import org.openrndr.math.Vector2 import org.openrndr.math.Vector2
@@ -18,13 +16,13 @@ fun main() {
val points = MutableList(1000) { val points = MutableList(1000) {
Vector2(Math.random() * width, Math.random() * height) Vector2(Math.random() * width, Math.random() * height)
} }
val tree = buildKDTree(points, 2, ::vector2Mapper) val tree = points.kdTree()
val radius = 50.0 val radius = 50.0
extend { extend {
drawer.circles(points, 5.0) drawer.circles(points, 5.0)
val allInRange = findAllInRange(tree, mouse.position, maxDistance = radius, dimensions = 2, ::vector2Mapper) val allInRange = tree.findAllInRadius(mouse.position, radius = radius)
drawer.fill = ColorRGBa.PINK drawer.fill = ColorRGBa.PINK
drawer.stroke = ColorRGBa.PINK drawer.stroke = ColorRGBa.PINK
drawer.strokeWeight = 2.0 drawer.strokeWeight = 2.0

View File

@@ -4,10 +4,7 @@ import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.GlobalScope import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import org.openrndr.math.IntVector2 import org.openrndr.math.*
import org.openrndr.math.Vector2
import org.openrndr.math.Vector3
import org.openrndr.math.Vector4
import java.util.* import java.util.*
import kotlin.IllegalStateException import kotlin.IllegalStateException
import kotlin.math.abs import kotlin.math.abs
@@ -47,7 +44,7 @@ fun vector4Mapper(v: Vector4, dimension: Int): Double {
} }
} }
class KDTreeNode<T> { class KDTreeNode<T>(val dimensions: Int, val mapper: (T, Int) -> Double) {
var parent: KDTreeNode<T>? = null var parent: KDTreeNode<T>? = null
var median: Double = 0.0 var median: Double = 0.0
var dimension: Int = 0 var dimension: Int = 0
@@ -57,6 +54,26 @@ class KDTreeNode<T> {
internal val isLeaf: Boolean internal val isLeaf: Boolean
get() = children[0] == null && children[1] == null get() = children[0] == null && children[1] == null
fun insert(item: T): KDTreeNode<T> {
return insert(this, item, dimensions, mapper)
}
fun remove(node: KDTreeNode<T>): KDTreeNode<T>? {
return org.openrndr.extra.kdtree.remove(node, mapper)
}
fun findNearest(query: T, includeQuery: Boolean = false): T? = findNearest(this, query, includeQuery)
fun findKNearest(query: T, k: Int, includeQuery: Boolean = false): List<T> {
return findKNearest(this, query, k, includeQuery)
}
fun findAllInRadius(query: T, radius: Double, includeQuery: Boolean = false): List<T> {
return findAllInRadius(this, query, radius, includeQuery)
}
override fun toString(): String { override fun toString(): String {
return "KDTreeNode{" + return "KDTreeNode{" +
"median=" + median + "median=" + median +
@@ -68,24 +85,32 @@ class KDTreeNode<T> {
} }
} }
fun <T> insertItem(root: KDTreeNode<T>, item: T, mapper: (T, Int) -> Double): KDTreeNode<T> { private fun <T> insertItem(root: KDTreeNode<T>, item: T): KDTreeNode<T> {
return if (root.isLeaf) { return if (root.isLeaf) {
root.item = item root.item = item
root root
} else { } else {
if (mapper(item, root.dimension) < root.median) { if (root.mapper(item, root.dimension) < root.median) {
insertItem(root.children[0] ?: throw IllegalStateException("left is null"), item, mapper) insertItem(root.children[0] ?: throw IllegalStateException("left is null"), item)
} else { } else {
insertItem(root.children[1] ?: throw IllegalStateException("right is null"), item, mapper) insertItem(root.children[1] ?: throw IllegalStateException("right is null"), item)
} }
} }
} }
fun <T> buildKDTree(items: MutableList<T>, dimensions: Int, mapper: (T, Int) -> Double): KDTreeNode<T> { fun <T> buildKDTree(items: MutableList<T>, dimensions: Int, mapper: (T, Int) -> Double): KDTreeNode<T> {
val root = KDTreeNode<T>() val root = KDTreeNode<T>(dimensions, mapper)
val start = System.currentTimeMillis() val start = System.currentTimeMillis()
fun <T> buildTreeTask(scope: CoroutineScope, node: KDTreeNode<T>, items: MutableList<T>, dimensions: Int, levels: Int, mapper: (T, Int) -> Double): KDTreeNode<T> { fun <T> buildTreeTask(
scope: CoroutineScope,
node: KDTreeNode<T>,
items: MutableList<T>,
dimensions: Int,
levels: Int,
mapper: (T, Int) -> Double
): KDTreeNode<T> {
if (items.size > 0) { if (items.size > 0) {
val dimension = levels % dimensions val dimension = levels % dimensions
@@ -119,7 +144,7 @@ fun <T> buildKDTree(items: MutableList<T>, dimensions: Int, mapper: (T, Int) ->
} }
if (leftItems.size > 0) { if (leftItems.size > 0) {
node.children[0] = KDTreeNode() node.children[0] = KDTreeNode(dimensions, mapper)
node.children[0]?.let { node.children[0]?.let {
it.parent = node it.parent = node
@@ -129,7 +154,7 @@ fun <T> buildKDTree(items: MutableList<T>, dimensions: Int, mapper: (T, Int) ->
} }
} }
if (rightItems.size > 0) { if (rightItems.size > 0) {
node.children[1] = KDTreeNode() node.children[1] = KDTreeNode(dimensions, mapper)
node.children[1]?.let { node.children[1]?.let {
it.parent = node it.parent = node
scope.launch { scope.launch {
@@ -147,7 +172,7 @@ fun <T> buildKDTree(items: MutableList<T>, dimensions: Int, mapper: (T, Int) ->
runBlocking { runBlocking {
job.join() job.join()
} }
println("building took ${System.currentTimeMillis()-start}ms") println("building took ${System.currentTimeMillis() - start}ms")
return root return root
} }
@@ -185,45 +210,45 @@ fun <T> findAllNodes(root: KDTreeNode<T>): List<KDTreeNode<T>> {
fun <T> findKNearest( fun <T> findKNearest(
root: KDTreeNode<T>, root: KDTreeNode<T>,
item: T, query: T,
k: Int, k: Int,
dimensions: Int, includeQuery: Boolean = false
mapper: (T, Int) -> Double
): List<T> { ): List<T> {
// max-heap with size k // max-heap with size k
val queue = PriorityQueue<Pair<KDTreeNode<T>, Double>>(k + 1) { val queue = PriorityQueue<Pair<KDTreeNode<T>, Double>>(k + 1) { nodeA, nodeB ->
nodeA, nodeB -> compareValues(nodeB.second, nodeA.second) compareValues(nodeB.second, nodeA.second)
} }
fun nearest(node: KDTreeNode<T>?, item: T) { fun nearest(node: KDTreeNode<T>?) {
if (node != null) { if (node != null) {
val dimensionValue = mapper(item, node.dimension) val dimensionValue = node.mapper(query, node.dimension)
val route: Int = if (dimensionValue < node.median) { val route: Int = if (dimensionValue < node.median) {
nearest(node.children[0], item) nearest(node.children[0])
0 0
} else { } else {
nearest(node.children[1], item) nearest(node.children[1])
1 1
} }
val distance = sqrDistance(item, node.item val distance = sqrDistance(query, node.item ?: error("item is null"), node.dimensions, node.mapper)
?: throw IllegalStateException("item is null"), dimensions, mapper)
if (includeQuery || node.item !== query) {
if (queue.size < k || distance < queue.peek().second) { if (queue.size < k || distance < queue.peek().second) {
queue.add(Pair(node, distance)) queue.add(Pair(node, distance))
if (queue.size > k) { if (queue.size > k) {
queue.poll() queue.poll()
} }
} }
}
val d = abs(node.median - dimensionValue) val d = abs(node.median - dimensionValue)
if (d * d < queue.peek().second || queue.size < k) { if (d * d < queue.peek().second || queue.size < k) {
nearest(node.children[1 - route], item) nearest(node.children[1 - route])
} }
} }
} }
nearest(root, item) nearest(root)
return generateSequence { queue.poll() } return generateSequence { queue.poll() }
.map { it.first.item } .map { it.first.item }
@@ -231,77 +256,79 @@ fun <T> findKNearest(
.toList().reversed() .toList().reversed()
} }
fun <T> findNearest(root: KDTreeNode<T>, item: T, dimensions: Int, mapper: (T, Int) -> Double): T? { private fun <T> findNearest(root: KDTreeNode<T>, query: T, includeQuery: Boolean = false): T? {
var nearest = java.lang.Double.POSITIVE_INFINITY var nearest = java.lang.Double.POSITIVE_INFINITY
var nearestArg: KDTreeNode<T>? = null var nearestArg: KDTreeNode<T>? = null
fun nearest(node: KDTreeNode<T>?, item: T) { fun nearest(node: KDTreeNode<T>?) {
if (node != null) { if (node != null) {
val route: Int = if (root.mapper(query, node.dimension) < node.median) {
if (node.item == null) { nearest(node.children[0])
println(node)
}
val route: Int = if (mapper(item, node.dimension) < node.median) {
nearest(node.children[0], item)
0 0
} else { } else {
nearest(node.children[1], item) nearest(node.children[1])
1 1
} }
val distance = sqrDistance(item, node.item val distance = sqrDistance(
?: throw IllegalStateException("item is null"), dimensions, mapper) query, node.item
if (distance < nearest) { ?: error("item is null"), root.dimensions, root.mapper
)
if (distance < nearest && (includeQuery || node.item !== query)) {
nearest = distance nearest = distance
nearestArg = node nearestArg = node
} }
val d = abs(node.median - root.mapper(query, node.dimension))
val d = abs(node.median - mapper(item, node.dimension))
if (d * d < nearest) { if (d * d < nearest) {
nearest(node.children[1 - route], item) nearest(node.children[1 - route])
} }
} }
} }
nearest(root, item) nearest(root)
return nearestArg?.item return nearestArg?.item
} }
fun <T> findAllInRange( private fun <T> findAllInRadius(
root: KDTreeNode<T>, root: KDTreeNode<T>,
item: T, query: T,
maxDistance: Double, radius: Double,
dimensions: Int, includeQuery: Boolean = false
mapper: (T, Int) -> Double ): List<T> {
) : List<T> {
val sqrMaxDist = maxDistance * maxDistance val sqrMaxDist = radius * radius
val queue = kotlin.collections.ArrayDeque<KDTreeNode<T>?>() val queue = ArrayDeque<KDTreeNode<T>>()
queue.add(root) queue.add(root)
val results = mutableListOf<T?>() val results = mutableListOf<T?>()
while (queue.isNotEmpty()) { while (queue.isNotEmpty()) {
val node = queue.removeFirst() val node = queue.removeFirst()
if (node != null) { val dimensionValue = node.mapper(query, node.dimension)
val dimensionValue = mapper(item, node.dimension) val distance = sqrDistance(
val distance = sqrDistance(item, node.item query, node.item
?: throw IllegalStateException("item is null"), dimensions, mapper) ?: error("item is null"), node.dimensions, node.mapper
if (distance <= sqrMaxDist) { )
if (distance <= sqrMaxDist && (includeQuery || node.item != query)) {
results.add(node.item) results.add(node.item)
} }
val route: Int = if (dimensionValue < node.median) { val route: Int = if (dimensionValue < node.median && node.children[0] != null) {
queue.add(node.children[0]) queue.add(node.children[0])
0 0
} else { } else if (node.children[1] != null) {
queue.add(node.children[1]) queue.add(node.children[1])
1 1
} else {
-1
} }
if (route != -1) {
val d = abs(node.median - dimensionValue) val d = abs(node.median - dimensionValue)
if (d * d <= sqrMaxDist) { if (d * d <= sqrMaxDist) {
queue.add(node.children[1 - route]) val c = node.children[1 - route]
if (c != null) {
queue.add(c)
}
} }
} }
} }
@@ -309,7 +336,7 @@ fun <T> findAllInRange(
return results.filterNotNull() return results.filterNotNull()
} }
fun <T> insert(root: KDTreeNode<T>, item: T, dimensions: Int, mapper: (T, Int) -> Double): KDTreeNode<T> { private fun <T> insert(root: KDTreeNode<T>, item: T, dimensions: Int, mapper: (T, Int) -> Double): KDTreeNode<T> {
val stack = Stack<KDTreeNode<T>>() val stack = Stack<KDTreeNode<T>>()
stack.push(root) stack.push(root)
@@ -324,7 +351,7 @@ fun <T> insert(root: KDTreeNode<T>, item: T, dimensions: Int, mapper: (T, Int) -
stack.push(node.children[0]) stack.push(node.children[0])
} else { } else {
// sit here // sit here
node.children[0] = KDTreeNode() node.children[0] = KDTreeNode(dimensions, mapper)
node.children[0]?.item = item node.children[0]?.item = item
node.children[0]?.dimension = (node.dimension + 1) % dimensions node.children[0]?.dimension = (node.dimension + 1) % dimensions
node.children[0]?.median = mapper(item, (node.dimension + 1) % dimensions) node.children[0]?.median = mapper(item, (node.dimension + 1) % dimensions)
@@ -336,7 +363,7 @@ fun <T> insert(root: KDTreeNode<T>, item: T, dimensions: Int, mapper: (T, Int) -
stack.push(node.children[1]) stack.push(node.children[1])
} else { } else {
// sit here // sit here
node.children[1] = KDTreeNode() node.children[1] = KDTreeNode(dimensions, mapper)
node.children[1]?.item = item node.children[1]?.item = item
node.children[1]?.dimension = (node.dimension + 1) % dimensions node.children[1]?.dimension = (node.dimension + 1) % dimensions
node.children[1]?.median = mapper(item, (node.dimension + 1) % dimensions) node.children[1]?.median = mapper(item, (node.dimension + 1) % dimensions)
@@ -348,7 +375,7 @@ fun <T> insert(root: KDTreeNode<T>, item: T, dimensions: Int, mapper: (T, Int) -
} }
} }
fun <T> remove(toRemove: KDTreeNode<T>, mapper: (T, Int) -> Double): KDTreeNode<T>? { private fun <T> remove(toRemove: KDTreeNode<T>, mapper: (T, Int) -> Double): KDTreeNode<T>? {
// trivial case // trivial case
if (toRemove.isLeaf) { if (toRemove.isLeaf) {
val p = toRemove.parent val p = toRemove.parent
@@ -439,3 +466,21 @@ fun <T> remove(toRemove: KDTreeNode<T>, mapper: (T, Int) -> Double): KDTreeNode<
} }
return null return null
} }
@JvmName("kdTreeVector2")
fun Iterable<Vector2>.kdTree(): KDTreeNode<Vector2> {
val items = this.toMutableList()
return buildKDTree(items, 2, ::vector2Mapper)
}
@JvmName("kdTreeVector3")
fun Iterable<Vector3>.kdTree(): KDTreeNode<Vector3> {
val items = this.toMutableList()
return buildKDTree(items, 3, ::vector3Mapper)
}
@JvmName("kdTreeVector4")
fun Iterable<Vector4>.kdTree(): KDTreeNode<Vector4> {
val items = this.toMutableList()
return buildKDTree(items, 4, ::vector4Mapper)
}