[orx-expression-evaluator-typed] Refactor function dispatch handling, improve function resolution and add new tests

This commit is contained in:
Edwin Jakobs
2025-05-13 23:21:27 +02:00
parent 781830ba96
commit b1e860cf39
3 changed files with 166 additions and 57 deletions

View File

@@ -0,0 +1,18 @@
package org.openrndr.extra.expressions.typed
import org.openrndr.extra.noise.uniform
/**
* Dispatches a function without arguments based on its name.
*
* @param name The name of the function to dispatch.
* @param functions A map containing functions of type `TypedFunction0` associated with their names.
* @return A callable lambda that takes an array of `Any` as input and returns a result if the function is found,
* or null if there is no match.
*/
internal fun dispatchFunction0(name: String, functions: Map<String, TypedFunction0>): ((Array<Any>) -> Any)? {
return when (name) {
"random" -> { x -> Double.uniform(0.0, 1.0) }
else -> functions[name]?.let { { x: Array<Any> -> it.invoke() } }
}
}

View File

@@ -780,6 +780,32 @@ abstract class TypedExpressionListenerBase(
return
}
fun <T> handleFunction(
name: String,
dispatchFunction: (String, Map<String, T>) -> ((Array<Any>) -> Any)?,
functionMap: Map<String, T>,
adapter: (T) -> (Array<Any>) -> Any,
errorMessage: String
) {
val function = dispatchFunction(name, functionMap)
if (function != null) {
s.functionStack.push(function)
} else {
val cfunction = constants(name) as? T
if (cfunction != null) {
s.functionStack.push(adapter(cfunction))
} else {
s.functionStack.push(errorValue("unresolved function: '$errorMessage'") { _ ->
error("this is the error function")
})
}
}
}
val type = node.symbol.type
if (type == KeyLangParser.Tokens.INTLIT) {
s.valueStack.pushChecked(node.text.toDouble())
@@ -794,24 +820,13 @@ abstract class TypedExpressionListenerBase(
IDType.VARIABLE -> s.valueStack.pushChecked(
when (name) {
"PI" -> PI
else -> constants(name) ?: errorValue("unresolved value: '${name}'. Available constant: ${constants}", Unit)
else -> constants(name)
?: errorValue("unresolved value: '${name}'. Available constant: ${constants}", Unit)
}
)
IDType.PROPERTY -> s.propertyStack.push(name)
IDType.FUNCTION0 -> {
val function: (Array<Any>) -> Any =
when (name) {
"random" -> { _ -> Double.uniform(0.0, 1.0) }
else -> functions.functions0[name]?.let { { _: Array<Any> -> it.invoke() } }
?: errorValue(
"unresolved function: '${name}()'"
) { _ -> error("this is the error function") }
}
s.functionStack.push(function)
}
IDType.MEMBER_FUNCTION0,
IDType.MEMBER_FUNCTION1,
IDType.MEMBER_FUNCTION2,
@@ -835,7 +850,8 @@ abstract class TypedExpressionListenerBase(
is ColorRGBa -> {
when (idType) {
IDType.MEMBER_FUNCTION1 -> {
s.functionStack.push(when (name) {
s.functionStack.push(
when (name) {
"shade" -> { x -> receiver.shade(x[0] as Double) }
"opacify" -> { x -> receiver.opacify(x[0] as Double) }
else -> error("no member function '$receiver.$name()'")
@@ -846,7 +862,6 @@ abstract class TypedExpressionListenerBase(
}
}
is Function<*> -> {
fun input(): String {
@@ -867,27 +882,28 @@ abstract class TypedExpressionListenerBase(
IDType.MEMBER_FUNCTION1 -> {
@Suppress("UNCHECKED_CAST")
(function as? (Any) -> Any) ?: error("Cannot cast function '$name' ($function) to (Any) -> Any ${input()}")
(function as? (Any) -> Any)
?: error("Cannot cast function '$name' ($function) to (Any) -> Any ${input()}")
s.functionStack.push({ x -> function(x[0]) })
}
IDType.MEMBER_FUNCTION2 -> {
@Suppress("UNCHECKED_CAST")
function as? (Any, Any) -> Any ?: error("Cannot cast function '$name' ($function) to (Any, Any) -> Any ${input()}")
function as? (Any, Any) -> Any
?: error("Cannot cast function '$name' ($function) to (Any, Any) -> Any ${input()}")
s.functionStack.push({ x -> function(x[0], x[1]) })
}
IDType.MEMBER_FUNCTION3 -> {
@Suppress("UNCHECKED_CAST")
function as? (Any, Any, Any) -> Any ?: error("Cannot cast function '$name' ($function) to (Any, Any, Any) -> Any ${input()}")
function as? (Any, Any, Any) -> Any
?: error("Cannot cast function '$name' ($function) to (Any, Any, Any) -> Any ${input()}")
s.functionStack.push({ x -> function(x[0], x[1], x[2]) })
}
else -> error("unreachable")
}
}
else -> error(
"receiver for '$name' '${
receiver.toString().take(30)
@@ -896,41 +912,55 @@ abstract class TypedExpressionListenerBase(
}
}
IDType.FUNCTION1 -> {
val localState = state
val function: (Array<Any>) -> Any =
dispatchFunction1(name, functions.functions1)
?: errorValue(
"unresolved function: '${name}(x0)'"
) { _ -> error("this is the error function") }
localState.functionStack.push(function)
}
IDType.FUNCTION0 -> handleFunction(
name,
::dispatchFunction0,
functions.functions0,
{ f -> { x -> f() } },
"${name}()"
)
IDType.FUNCTION2 -> {
val function: (Array<Any>) -> Any =
dispatchFunction2(name, functions.functions2)
?: errorValue(
"unresolved function: '${name}(x0, x1)'"
) { _ -> error("this is the error function") }
s.functionStack.push(function)
}
IDType.FUNCTION1 -> handleFunction(
name,
::dispatchFunction1,
functions.functions1,
{ f -> { x -> f(x[0]) } },
"${name}(x0)"
)
IDType.FUNCTION3 -> {
val function: (Array<Any>) -> Any =
dispatchFunction3(name, functions.functions3)
?: errorValue(
"unresolved function: '${name}(x0)'"
) { _ -> error("this is the error function") }
s.functionStack.push(function)
}
IDType.FUNCTION2 -> handleFunction(
name,
::dispatchFunction2,
functions.functions2,
{ f -> { x -> f(x[0], x[1]) } },
"${name}(x0, x1)"
)
IDType.FUNCTION4 -> {
val function: (Array<Any>) -> Any =
dispatchFunction4(name, functions.functions4)
?: errorValue(
"unresolved function: '${name}(x0)'"
) { _ -> error("this is the error function") }
s.functionStack.push(function)
IDType.FUNCTION3 -> handleFunction(
name,
::dispatchFunction3,
functions.functions3,
{ f -> { x -> f(x[0], x[1], x[2]) } },
"${name}(x0, x1, x2)"
)
IDType.FUNCTION4 -> handleFunction(
name,
::dispatchFunction4,
functions.functions4,
{ f -> { x -> f(x[0], x[1], x[2], x[3]) } },
"${name}(x0, x1, x2, x3)"
)
IDType.FUNCTION5 -> {
val cfunction = constants(name) as? (Any, Any, Any, Any, Any) -> Any
if (cfunction != null) {
s.functionStack.push({ x -> cfunction(x[0], x[1], x[2], x[3], x[4]) })
} else {
s.functionStack.push(errorValue("unresolved function: '${name}(x0, x1, x2, x3, x4)'") { _ ->
error("this is the error function")
})
}
}
IDType.FUNCTION_ARGUMENT -> {
@@ -1000,6 +1030,15 @@ fun evaluateTypedExpression(
return listener.state.lastExpressionResult
}
/**
* Compiles a typed expression and returns a lambda that can execute the compiled expression.
*
* @param expression The string representation of the expression to compile.
* @param constants A lambda function to resolve constants by their names. Defaults to a resolver that returns null.
* @param functions An instance of `TypedFunctionExtensions` containing the supported custom functions for the expression. Defaults to an empty set of functions.
* @return A lambda function that evaluates the compiled expression and returns its result.
* @throws ExpressionException If there is a syntax error or a parsing issue in the provided expression.
*/
fun compileTypedExpression(
expression: String,
constants: (String) -> Any? = { null },
@@ -1023,7 +1062,6 @@ fun compileTypedExpression(
val root = parser.keyLangFile()
val listener = TypedExpressionListener(functions, constants)
return {
try {
ParseTreeWalker.DEFAULT.walk(listener, root)

View File

@@ -61,4 +61,57 @@ class TestTypedCompiledExpression {
println("that took ${end - start}")
}
}
@Test
fun testFunction2() {
run {
val c = compileFunction1OrNull<Map<String, Any>, Double>("x.x + 3.0", "x")!!
assertEquals(1.0 + 3.0, c(mapOf("x" to 1.0)))
//assertEquals(2.0 + 3.0, c(mapOf("x" to 2.0))
}
}
@Test
fun testDynamicConstants() {
val env = { n: String ->
when (n) {
"a" -> { nn: String ->
when (nn) {
"a" -> { nnn: String ->
when (nnn) {
"b" -> 7.0
"c" -> { x: Double -> x + 1.0 }
else -> null
}
}
"b" -> 5.0
"c" -> { x: Double -> x * 2.0 }
else -> null
}
}
"c" -> { x: Double -> x * 3.0 }
else -> null
}
}
val c0 = compileFunction1OrNull<Map<String, Any>, Double>("a.a.c(2.0)", "x", constants = env)!!
val r0 = c0(emptyMap())
assertEquals(3.0, r0)
val c1 = compileFunction1OrNull<Map<String, Any>, Double>("a.c(2.0)", "x", constants = env)!!
val r1 = c1(emptyMap())
assertEquals(4.0, r1)
val c2 = compileFunction1OrNull<Map<String, Any>, Double>("c(2.0)", "x", constants = env)!!
val r2 = c2(emptyMap())
assertEquals(6.0, r2)
val c3 = compileFunction1OrNull<Map<String, Any>, Double>("cos(2.0)", "x", constants = env)!!
val r3 = c3(emptyMap())
}
}