Add data binding for orx-gradient-descent
Add 0 check in minimizer
This commit is contained in:
73
orx-gradient-descent/src/main/kotlin/DataBinding.kt
Normal file
73
orx-gradient-descent/src/main/kotlin/DataBinding.kt
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
package org.openrndr.extra.gradientdescent
|
||||||
|
|
||||||
|
import org.openrndr.math.Vector2
|
||||||
|
import org.openrndr.math.Vector3
|
||||||
|
import org.openrndr.math.Vector4
|
||||||
|
|
||||||
|
/**
|
||||||
|
* converts a model to an array of doubles
|
||||||
|
*/
|
||||||
|
fun <T : Any> modelToArray(model: T): DoubleArray {
|
||||||
|
val doubles = mutableListOf<Double>()
|
||||||
|
model::class.java.declaredFields.forEach {
|
||||||
|
when {
|
||||||
|
it.type == Double::class.java -> {
|
||||||
|
it.trySetAccessible()
|
||||||
|
doubles.add(it.getDouble(model))
|
||||||
|
}
|
||||||
|
it.type == Vector2::class.java -> {
|
||||||
|
it.trySetAccessible()
|
||||||
|
val v2 = it.get(model) as Vector2
|
||||||
|
doubles.add(v2.x)
|
||||||
|
doubles.add(v2.y)
|
||||||
|
}
|
||||||
|
it.type == Vector3::class.java -> {
|
||||||
|
it.trySetAccessible()
|
||||||
|
val v3 = it.get(model) as Vector3
|
||||||
|
doubles.add(v3.x)
|
||||||
|
doubles.add(v3.y)
|
||||||
|
doubles.add(v3.z)
|
||||||
|
}
|
||||||
|
it.type == Vector4::class.java -> {
|
||||||
|
it.trySetAccessible()
|
||||||
|
val v4 = it.get(model) as Vector4
|
||||||
|
doubles.add(v4.x)
|
||||||
|
doubles.add(v4.y)
|
||||||
|
doubles.add(v4.z)
|
||||||
|
doubles.add(v4.w)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return doubles.toDoubleArray()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* converts array of doubles to model values
|
||||||
|
*/
|
||||||
|
fun <T : Any> arrayToModel(data: DoubleArray, model: T) {
|
||||||
|
var index = 0
|
||||||
|
model::class.java.declaredFields.forEach {
|
||||||
|
when {
|
||||||
|
it.type == Double::class.java -> {
|
||||||
|
it.trySetAccessible()
|
||||||
|
it.setDouble(model, data[index])
|
||||||
|
index++
|
||||||
|
}
|
||||||
|
it.type == Vector2::class.java -> {
|
||||||
|
it.trySetAccessible()
|
||||||
|
it.set(model, Vector2(data[index], data[index+1]))
|
||||||
|
index+=2
|
||||||
|
}
|
||||||
|
it.type == Vector3::class.java -> {
|
||||||
|
it.trySetAccessible()
|
||||||
|
it.set(model, Vector3(data[index], data[index+1],data[index+2]))
|
||||||
|
index+=3
|
||||||
|
}
|
||||||
|
it.type == Vector4::class.java -> {
|
||||||
|
it.trySetAccessible()
|
||||||
|
it.set(model, Vector4(data[index], data[index+1],data[index+2],data[index+3]))
|
||||||
|
index+=3
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -21,6 +21,8 @@
|
|||||||
// THE SOFTWARE.
|
// THE SOFTWARE.
|
||||||
//
|
//
|
||||||
|
|
||||||
|
package org.openrndr.extra.gradientdescent
|
||||||
|
|
||||||
import kotlin.math.abs
|
import kotlin.math.abs
|
||||||
import kotlin.math.max
|
import kotlin.math.max
|
||||||
import kotlin.math.min
|
import kotlin.math.min
|
||||||
@@ -66,6 +68,7 @@ fun gradient(x: DoubleArray, objective: (parameters: DoubleArray) -> Double): Do
|
|||||||
k++
|
k++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
//println("gradient at (${x.contentToString()}) -> (${grad.contentToString()}) ")
|
||||||
return grad
|
return grad
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -105,8 +108,10 @@ fun minimize(_x0: DoubleArray, endOnLineSearch: Boolean = true, tol: Double = 1e
|
|||||||
while (iteration < maxIterations) {
|
while (iteration < maxIterations) {
|
||||||
require(g0.all { it == it && it != Double.POSITIVE_INFINITY && it != Double.NEGATIVE_INFINITY })
|
require(g0.all { it == it && it != Double.POSITIVE_INFINITY && it != Double.NEGATIVE_INFINITY })
|
||||||
val pstep = dot(H1, g0)
|
val pstep = dot(H1, g0)
|
||||||
|
require(pstep.all { it == it }) { "pstep contains NaNs"}
|
||||||
|
require(pstep.all { it != Double.POSITIVE_INFINITY && it != Double.NEGATIVE_INFINITY }) { "pstep contains infs" }
|
||||||
val step = neg(pstep)
|
val step = neg(pstep)
|
||||||
require(step.all { it == it && it != Double.POSITIVE_INFINITY && it != Double.NEGATIVE_INFINITY })
|
|
||||||
val nstep = norm2(step)
|
val nstep = norm2(step)
|
||||||
require(nstep == nstep)
|
require(nstep == nstep)
|
||||||
if (nstep < tol) {
|
if (nstep < tol) {
|
||||||
@@ -121,6 +126,8 @@ fun minimize(_x0: DoubleArray, endOnLineSearch: Boolean = true, tol: Double = 1e
|
|||||||
s = mul(step, t)
|
s = mul(step, t)
|
||||||
x1 = add(x0, s)
|
x1 = add(x0, s)
|
||||||
f1 = f(x1)
|
f1 = f(x1)
|
||||||
|
|
||||||
|
require(f1 == f1) { "f1 is NaN"}
|
||||||
if (!(f1 - f0 >= 0.1 * t * df0)) {
|
if (!(f1 - f0 >= 0.1 * t * df0)) {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -134,8 +141,12 @@ fun minimize(_x0: DoubleArray, endOnLineSearch: Boolean = true, tol: Double = 1e
|
|||||||
}
|
}
|
||||||
if (iteration >= maxIterations) break
|
if (iteration >= maxIterations) break
|
||||||
val g1 = grad(x1)
|
val g1 = grad(x1)
|
||||||
|
require(g1.all { it == it })
|
||||||
val y = sub(g1, g0)
|
val y = sub(g1, g0)
|
||||||
val ys = dot(y, s)
|
val ys = dot(y, s)
|
||||||
|
if (ys==0.0) {
|
||||||
|
break
|
||||||
|
}
|
||||||
val Hy = dot(H1, y)
|
val Hy = dot(H1, y)
|
||||||
H1 = sub(
|
H1 = sub(
|
||||||
add(
|
add(
|
||||||
@@ -158,5 +169,6 @@ fun minimize(_x0: DoubleArray, endOnLineSearch: Boolean = true, tol: Double = 1e
|
|||||||
g0 = g1
|
g0 = g1
|
||||||
iteration++
|
iteration++
|
||||||
}
|
}
|
||||||
|
|
||||||
return MinimizationResult(x0, f0, g0, H1, iteration)
|
return MinimizationResult(x0, f0, g0, H1, iteration)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user