[orx-expression-evaluator-typed] Refactor function dispatch handling, improve function resolution and add new tests
This commit is contained in:
@@ -0,0 +1,18 @@
|
||||
package org.openrndr.extra.expressions.typed
|
||||
|
||||
import org.openrndr.extra.noise.uniform
|
||||
|
||||
/**
|
||||
* Dispatches a function without arguments based on its name.
|
||||
*
|
||||
* @param name The name of the function to dispatch.
|
||||
* @param functions A map containing functions of type `TypedFunction0` associated with their names.
|
||||
* @return A callable lambda that takes an array of `Any` as input and returns a result if the function is found,
|
||||
* or null if there is no match.
|
||||
*/
|
||||
internal fun dispatchFunction0(name: String, functions: Map<String, TypedFunction0>): ((Array<Any>) -> Any)? {
|
||||
return when (name) {
|
||||
"random" -> { x -> Double.uniform(0.0, 1.0) }
|
||||
else -> functions[name]?.let { { x: Array<Any> -> it.invoke() } }
|
||||
}
|
||||
}
|
||||
@@ -780,6 +780,32 @@ abstract class TypedExpressionListenerBase(
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
fun <T> handleFunction(
|
||||
name: String,
|
||||
dispatchFunction: (String, Map<String, T>) -> ((Array<Any>) -> Any)?,
|
||||
functionMap: Map<String, T>,
|
||||
adapter: (T) -> (Array<Any>) -> Any,
|
||||
errorMessage: String
|
||||
) {
|
||||
val function = dispatchFunction(name, functionMap)
|
||||
|
||||
if (function != null) {
|
||||
s.functionStack.push(function)
|
||||
} else {
|
||||
val cfunction = constants(name) as? T
|
||||
if (cfunction != null) {
|
||||
s.functionStack.push(adapter(cfunction))
|
||||
} else {
|
||||
s.functionStack.push(errorValue("unresolved function: '$errorMessage'") { _ ->
|
||||
error("this is the error function")
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
val type = node.symbol.type
|
||||
if (type == KeyLangParser.Tokens.INTLIT) {
|
||||
s.valueStack.pushChecked(node.text.toDouble())
|
||||
@@ -794,24 +820,13 @@ abstract class TypedExpressionListenerBase(
|
||||
IDType.VARIABLE -> s.valueStack.pushChecked(
|
||||
when (name) {
|
||||
"PI" -> PI
|
||||
else -> constants(name) ?: errorValue("unresolved value: '${name}'. Available constant: ${constants}", Unit)
|
||||
else -> constants(name)
|
||||
?: errorValue("unresolved value: '${name}'. Available constant: ${constants}", Unit)
|
||||
}
|
||||
)
|
||||
|
||||
IDType.PROPERTY -> s.propertyStack.push(name)
|
||||
|
||||
IDType.FUNCTION0 -> {
|
||||
val function: (Array<Any>) -> Any =
|
||||
when (name) {
|
||||
"random" -> { _ -> Double.uniform(0.0, 1.0) }
|
||||
else -> functions.functions0[name]?.let { { _: Array<Any> -> it.invoke() } }
|
||||
?: errorValue(
|
||||
"unresolved function: '${name}()'"
|
||||
) { _ -> error("this is the error function") }
|
||||
}
|
||||
s.functionStack.push(function)
|
||||
}
|
||||
|
||||
IDType.MEMBER_FUNCTION0,
|
||||
IDType.MEMBER_FUNCTION1,
|
||||
IDType.MEMBER_FUNCTION2,
|
||||
@@ -835,18 +850,18 @@ abstract class TypedExpressionListenerBase(
|
||||
is ColorRGBa -> {
|
||||
when (idType) {
|
||||
IDType.MEMBER_FUNCTION1 -> {
|
||||
s.functionStack.push(when (name) {
|
||||
"shade" -> { x -> receiver.shade(x[0] as Double) }
|
||||
"opacify" -> { x -> receiver.opacify(x[0] as Double) }
|
||||
else -> error("no member function '$receiver.$name()'")
|
||||
})
|
||||
s.functionStack.push(
|
||||
when (name) {
|
||||
"shade" -> { x -> receiver.shade(x[0] as Double) }
|
||||
"opacify" -> { x -> receiver.opacify(x[0] as Double) }
|
||||
else -> error("no member function '$receiver.$name()'")
|
||||
})
|
||||
}
|
||||
|
||||
else -> error("no member function $idType '$receiver.$name()")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
is Function<*> -> {
|
||||
|
||||
fun input(): String {
|
||||
@@ -867,27 +882,28 @@ abstract class TypedExpressionListenerBase(
|
||||
|
||||
IDType.MEMBER_FUNCTION1 -> {
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
(function as? (Any) -> Any) ?: error("Cannot cast function '$name' ($function) to (Any) -> Any ${input()}")
|
||||
(function as? (Any) -> Any)
|
||||
?: error("Cannot cast function '$name' ($function) to (Any) -> Any ${input()}")
|
||||
s.functionStack.push({ x -> function(x[0]) })
|
||||
}
|
||||
|
||||
IDType.MEMBER_FUNCTION2 -> {
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
function as? (Any, Any) -> Any ?: error("Cannot cast function '$name' ($function) to (Any, Any) -> Any ${input()}")
|
||||
function as? (Any, Any) -> Any
|
||||
?: error("Cannot cast function '$name' ($function) to (Any, Any) -> Any ${input()}")
|
||||
s.functionStack.push({ x -> function(x[0], x[1]) })
|
||||
}
|
||||
|
||||
IDType.MEMBER_FUNCTION3 -> {
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
function as? (Any, Any, Any) -> Any ?: error("Cannot cast function '$name' ($function) to (Any, Any, Any) -> Any ${input()}")
|
||||
function as? (Any, Any, Any) -> Any
|
||||
?: error("Cannot cast function '$name' ($function) to (Any, Any, Any) -> Any ${input()}")
|
||||
s.functionStack.push({ x -> function(x[0], x[1], x[2]) })
|
||||
}
|
||||
|
||||
else -> error("unreachable")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
else -> error(
|
||||
"receiver for '$name' '${
|
||||
receiver.toString().take(30)
|
||||
@@ -896,41 +912,55 @@ abstract class TypedExpressionListenerBase(
|
||||
}
|
||||
}
|
||||
|
||||
IDType.FUNCTION1 -> {
|
||||
val localState = state
|
||||
val function: (Array<Any>) -> Any =
|
||||
dispatchFunction1(name, functions.functions1)
|
||||
?: errorValue(
|
||||
"unresolved function: '${name}(x0)'"
|
||||
) { _ -> error("this is the error function") }
|
||||
localState.functionStack.push(function)
|
||||
}
|
||||
IDType.FUNCTION0 -> handleFunction(
|
||||
name,
|
||||
::dispatchFunction0,
|
||||
functions.functions0,
|
||||
{ f -> { x -> f() } },
|
||||
"${name}()"
|
||||
)
|
||||
|
||||
IDType.FUNCTION2 -> {
|
||||
val function: (Array<Any>) -> Any =
|
||||
dispatchFunction2(name, functions.functions2)
|
||||
?: errorValue(
|
||||
"unresolved function: '${name}(x0, x1)'"
|
||||
) { _ -> error("this is the error function") }
|
||||
s.functionStack.push(function)
|
||||
}
|
||||
IDType.FUNCTION1 -> handleFunction(
|
||||
name,
|
||||
::dispatchFunction1,
|
||||
functions.functions1,
|
||||
{ f -> { x -> f(x[0]) } },
|
||||
"${name}(x0)"
|
||||
)
|
||||
|
||||
IDType.FUNCTION3 -> {
|
||||
val function: (Array<Any>) -> Any =
|
||||
dispatchFunction3(name, functions.functions3)
|
||||
?: errorValue(
|
||||
"unresolved function: '${name}(x0)'"
|
||||
) { _ -> error("this is the error function") }
|
||||
s.functionStack.push(function)
|
||||
}
|
||||
IDType.FUNCTION2 -> handleFunction(
|
||||
name,
|
||||
::dispatchFunction2,
|
||||
functions.functions2,
|
||||
{ f -> { x -> f(x[0], x[1]) } },
|
||||
"${name}(x0, x1)"
|
||||
)
|
||||
|
||||
IDType.FUNCTION4 -> {
|
||||
val function: (Array<Any>) -> Any =
|
||||
dispatchFunction4(name, functions.functions4)
|
||||
?: errorValue(
|
||||
"unresolved function: '${name}(x0)'"
|
||||
) { _ -> error("this is the error function") }
|
||||
s.functionStack.push(function)
|
||||
IDType.FUNCTION3 -> handleFunction(
|
||||
name,
|
||||
::dispatchFunction3,
|
||||
functions.functions3,
|
||||
{ f -> { x -> f(x[0], x[1], x[2]) } },
|
||||
"${name}(x0, x1, x2)"
|
||||
)
|
||||
|
||||
IDType.FUNCTION4 -> handleFunction(
|
||||
name,
|
||||
::dispatchFunction4,
|
||||
functions.functions4,
|
||||
{ f -> { x -> f(x[0], x[1], x[2], x[3]) } },
|
||||
"${name}(x0, x1, x2, x3)"
|
||||
)
|
||||
|
||||
IDType.FUNCTION5 -> {
|
||||
val cfunction = constants(name) as? (Any, Any, Any, Any, Any) -> Any
|
||||
if (cfunction != null) {
|
||||
s.functionStack.push({ x -> cfunction(x[0], x[1], x[2], x[3], x[4]) })
|
||||
} else {
|
||||
s.functionStack.push(errorValue("unresolved function: '${name}(x0, x1, x2, x3, x4)'") { _ ->
|
||||
error("this is the error function")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
IDType.FUNCTION_ARGUMENT -> {
|
||||
@@ -1000,6 +1030,15 @@ fun evaluateTypedExpression(
|
||||
return listener.state.lastExpressionResult
|
||||
}
|
||||
|
||||
/**
|
||||
* Compiles a typed expression and returns a lambda that can execute the compiled expression.
|
||||
*
|
||||
* @param expression The string representation of the expression to compile.
|
||||
* @param constants A lambda function to resolve constants by their names. Defaults to a resolver that returns null.
|
||||
* @param functions An instance of `TypedFunctionExtensions` containing the supported custom functions for the expression. Defaults to an empty set of functions.
|
||||
* @return A lambda function that evaluates the compiled expression and returns its result.
|
||||
* @throws ExpressionException If there is a syntax error or a parsing issue in the provided expression.
|
||||
*/
|
||||
fun compileTypedExpression(
|
||||
expression: String,
|
||||
constants: (String) -> Any? = { null },
|
||||
@@ -1023,7 +1062,6 @@ fun compileTypedExpression(
|
||||
val root = parser.keyLangFile()
|
||||
val listener = TypedExpressionListener(functions, constants)
|
||||
|
||||
|
||||
return {
|
||||
try {
|
||||
ParseTreeWalker.DEFAULT.walk(listener, root)
|
||||
|
||||
@@ -61,4 +61,57 @@ class TestTypedCompiledExpression {
|
||||
println("that took ${end - start}")
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testFunction2() {
|
||||
run {
|
||||
val c = compileFunction1OrNull<Map<String, Any>, Double>("x.x + 3.0", "x")!!
|
||||
assertEquals(1.0 + 3.0, c(mapOf("x" to 1.0)))
|
||||
//assertEquals(2.0 + 3.0, c(mapOf("x" to 2.0))
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testDynamicConstants() {
|
||||
|
||||
val env = { n: String ->
|
||||
when (n) {
|
||||
"a" -> { nn: String ->
|
||||
when (nn) {
|
||||
"a" -> { nnn: String ->
|
||||
when (nnn) {
|
||||
"b" -> 7.0
|
||||
"c" -> { x: Double -> x + 1.0 }
|
||||
else -> null
|
||||
}
|
||||
}
|
||||
|
||||
"b" -> 5.0
|
||||
"c" -> { x: Double -> x * 2.0 }
|
||||
else -> null
|
||||
}
|
||||
|
||||
}
|
||||
"c" -> { x: Double -> x * 3.0 }
|
||||
else -> null
|
||||
}
|
||||
}
|
||||
|
||||
val c0 = compileFunction1OrNull<Map<String, Any>, Double>("a.a.c(2.0)", "x", constants = env)!!
|
||||
val r0 = c0(emptyMap())
|
||||
assertEquals(3.0, r0)
|
||||
|
||||
val c1 = compileFunction1OrNull<Map<String, Any>, Double>("a.c(2.0)", "x", constants = env)!!
|
||||
val r1 = c1(emptyMap())
|
||||
assertEquals(4.0, r1)
|
||||
|
||||
val c2 = compileFunction1OrNull<Map<String, Any>, Double>("c(2.0)", "x", constants = env)!!
|
||||
val r2 = c2(emptyMap())
|
||||
assertEquals(6.0, r2)
|
||||
|
||||
val c3 = compileFunction1OrNull<Map<String, Any>, Double>("cos(2.0)", "x", constants = env)!!
|
||||
val r3 = c3(emptyMap())
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user