[orx-expression-evaluator-typed] Refactor function dispatch handling, improve function resolution and add new tests
This commit is contained in:
@@ -0,0 +1,18 @@
|
|||||||
|
package org.openrndr.extra.expressions.typed
|
||||||
|
|
||||||
|
import org.openrndr.extra.noise.uniform
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Dispatches a function without arguments based on its name.
|
||||||
|
*
|
||||||
|
* @param name The name of the function to dispatch.
|
||||||
|
* @param functions A map containing functions of type `TypedFunction0` associated with their names.
|
||||||
|
* @return A callable lambda that takes an array of `Any` as input and returns a result if the function is found,
|
||||||
|
* or null if there is no match.
|
||||||
|
*/
|
||||||
|
internal fun dispatchFunction0(name: String, functions: Map<String, TypedFunction0>): ((Array<Any>) -> Any)? {
|
||||||
|
return when (name) {
|
||||||
|
"random" -> { x -> Double.uniform(0.0, 1.0) }
|
||||||
|
else -> functions[name]?.let { { x: Array<Any> -> it.invoke() } }
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -780,6 +780,32 @@ abstract class TypedExpressionListenerBase(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
fun <T> handleFunction(
|
||||||
|
name: String,
|
||||||
|
dispatchFunction: (String, Map<String, T>) -> ((Array<Any>) -> Any)?,
|
||||||
|
functionMap: Map<String, T>,
|
||||||
|
adapter: (T) -> (Array<Any>) -> Any,
|
||||||
|
errorMessage: String
|
||||||
|
) {
|
||||||
|
val function = dispatchFunction(name, functionMap)
|
||||||
|
|
||||||
|
if (function != null) {
|
||||||
|
s.functionStack.push(function)
|
||||||
|
} else {
|
||||||
|
val cfunction = constants(name) as? T
|
||||||
|
if (cfunction != null) {
|
||||||
|
s.functionStack.push(adapter(cfunction))
|
||||||
|
} else {
|
||||||
|
s.functionStack.push(errorValue("unresolved function: '$errorMessage'") { _ ->
|
||||||
|
error("this is the error function")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
val type = node.symbol.type
|
val type = node.symbol.type
|
||||||
if (type == KeyLangParser.Tokens.INTLIT) {
|
if (type == KeyLangParser.Tokens.INTLIT) {
|
||||||
s.valueStack.pushChecked(node.text.toDouble())
|
s.valueStack.pushChecked(node.text.toDouble())
|
||||||
@@ -794,24 +820,13 @@ abstract class TypedExpressionListenerBase(
|
|||||||
IDType.VARIABLE -> s.valueStack.pushChecked(
|
IDType.VARIABLE -> s.valueStack.pushChecked(
|
||||||
when (name) {
|
when (name) {
|
||||||
"PI" -> PI
|
"PI" -> PI
|
||||||
else -> constants(name) ?: errorValue("unresolved value: '${name}'. Available constant: ${constants}", Unit)
|
else -> constants(name)
|
||||||
|
?: errorValue("unresolved value: '${name}'. Available constant: ${constants}", Unit)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
IDType.PROPERTY -> s.propertyStack.push(name)
|
IDType.PROPERTY -> s.propertyStack.push(name)
|
||||||
|
|
||||||
IDType.FUNCTION0 -> {
|
|
||||||
val function: (Array<Any>) -> Any =
|
|
||||||
when (name) {
|
|
||||||
"random" -> { _ -> Double.uniform(0.0, 1.0) }
|
|
||||||
else -> functions.functions0[name]?.let { { _: Array<Any> -> it.invoke() } }
|
|
||||||
?: errorValue(
|
|
||||||
"unresolved function: '${name}()'"
|
|
||||||
) { _ -> error("this is the error function") }
|
|
||||||
}
|
|
||||||
s.functionStack.push(function)
|
|
||||||
}
|
|
||||||
|
|
||||||
IDType.MEMBER_FUNCTION0,
|
IDType.MEMBER_FUNCTION0,
|
||||||
IDType.MEMBER_FUNCTION1,
|
IDType.MEMBER_FUNCTION1,
|
||||||
IDType.MEMBER_FUNCTION2,
|
IDType.MEMBER_FUNCTION2,
|
||||||
@@ -835,7 +850,8 @@ abstract class TypedExpressionListenerBase(
|
|||||||
is ColorRGBa -> {
|
is ColorRGBa -> {
|
||||||
when (idType) {
|
when (idType) {
|
||||||
IDType.MEMBER_FUNCTION1 -> {
|
IDType.MEMBER_FUNCTION1 -> {
|
||||||
s.functionStack.push(when (name) {
|
s.functionStack.push(
|
||||||
|
when (name) {
|
||||||
"shade" -> { x -> receiver.shade(x[0] as Double) }
|
"shade" -> { x -> receiver.shade(x[0] as Double) }
|
||||||
"opacify" -> { x -> receiver.opacify(x[0] as Double) }
|
"opacify" -> { x -> receiver.opacify(x[0] as Double) }
|
||||||
else -> error("no member function '$receiver.$name()'")
|
else -> error("no member function '$receiver.$name()'")
|
||||||
@@ -846,7 +862,6 @@ abstract class TypedExpressionListenerBase(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
is Function<*> -> {
|
is Function<*> -> {
|
||||||
|
|
||||||
fun input(): String {
|
fun input(): String {
|
||||||
@@ -867,27 +882,28 @@ abstract class TypedExpressionListenerBase(
|
|||||||
|
|
||||||
IDType.MEMBER_FUNCTION1 -> {
|
IDType.MEMBER_FUNCTION1 -> {
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
(function as? (Any) -> Any) ?: error("Cannot cast function '$name' ($function) to (Any) -> Any ${input()}")
|
(function as? (Any) -> Any)
|
||||||
|
?: error("Cannot cast function '$name' ($function) to (Any) -> Any ${input()}")
|
||||||
s.functionStack.push({ x -> function(x[0]) })
|
s.functionStack.push({ x -> function(x[0]) })
|
||||||
}
|
}
|
||||||
|
|
||||||
IDType.MEMBER_FUNCTION2 -> {
|
IDType.MEMBER_FUNCTION2 -> {
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
function as? (Any, Any) -> Any ?: error("Cannot cast function '$name' ($function) to (Any, Any) -> Any ${input()}")
|
function as? (Any, Any) -> Any
|
||||||
|
?: error("Cannot cast function '$name' ($function) to (Any, Any) -> Any ${input()}")
|
||||||
s.functionStack.push({ x -> function(x[0], x[1]) })
|
s.functionStack.push({ x -> function(x[0], x[1]) })
|
||||||
}
|
}
|
||||||
|
|
||||||
IDType.MEMBER_FUNCTION3 -> {
|
IDType.MEMBER_FUNCTION3 -> {
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
function as? (Any, Any, Any) -> Any ?: error("Cannot cast function '$name' ($function) to (Any, Any, Any) -> Any ${input()}")
|
function as? (Any, Any, Any) -> Any
|
||||||
|
?: error("Cannot cast function '$name' ($function) to (Any, Any, Any) -> Any ${input()}")
|
||||||
s.functionStack.push({ x -> function(x[0], x[1], x[2]) })
|
s.functionStack.push({ x -> function(x[0], x[1], x[2]) })
|
||||||
}
|
}
|
||||||
|
|
||||||
else -> error("unreachable")
|
else -> error("unreachable")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
else -> error(
|
else -> error(
|
||||||
"receiver for '$name' '${
|
"receiver for '$name' '${
|
||||||
receiver.toString().take(30)
|
receiver.toString().take(30)
|
||||||
@@ -896,41 +912,55 @@ abstract class TypedExpressionListenerBase(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
IDType.FUNCTION1 -> {
|
IDType.FUNCTION0 -> handleFunction(
|
||||||
val localState = state
|
name,
|
||||||
val function: (Array<Any>) -> Any =
|
::dispatchFunction0,
|
||||||
dispatchFunction1(name, functions.functions1)
|
functions.functions0,
|
||||||
?: errorValue(
|
{ f -> { x -> f() } },
|
||||||
"unresolved function: '${name}(x0)'"
|
"${name}()"
|
||||||
) { _ -> error("this is the error function") }
|
)
|
||||||
localState.functionStack.push(function)
|
|
||||||
}
|
|
||||||
|
|
||||||
IDType.FUNCTION2 -> {
|
IDType.FUNCTION1 -> handleFunction(
|
||||||
val function: (Array<Any>) -> Any =
|
name,
|
||||||
dispatchFunction2(name, functions.functions2)
|
::dispatchFunction1,
|
||||||
?: errorValue(
|
functions.functions1,
|
||||||
"unresolved function: '${name}(x0, x1)'"
|
{ f -> { x -> f(x[0]) } },
|
||||||
) { _ -> error("this is the error function") }
|
"${name}(x0)"
|
||||||
s.functionStack.push(function)
|
)
|
||||||
}
|
|
||||||
|
|
||||||
IDType.FUNCTION3 -> {
|
IDType.FUNCTION2 -> handleFunction(
|
||||||
val function: (Array<Any>) -> Any =
|
name,
|
||||||
dispatchFunction3(name, functions.functions3)
|
::dispatchFunction2,
|
||||||
?: errorValue(
|
functions.functions2,
|
||||||
"unresolved function: '${name}(x0)'"
|
{ f -> { x -> f(x[0], x[1]) } },
|
||||||
) { _ -> error("this is the error function") }
|
"${name}(x0, x1)"
|
||||||
s.functionStack.push(function)
|
)
|
||||||
}
|
|
||||||
|
|
||||||
IDType.FUNCTION4 -> {
|
IDType.FUNCTION3 -> handleFunction(
|
||||||
val function: (Array<Any>) -> Any =
|
name,
|
||||||
dispatchFunction4(name, functions.functions4)
|
::dispatchFunction3,
|
||||||
?: errorValue(
|
functions.functions3,
|
||||||
"unresolved function: '${name}(x0)'"
|
{ f -> { x -> f(x[0], x[1], x[2]) } },
|
||||||
) { _ -> error("this is the error function") }
|
"${name}(x0, x1, x2)"
|
||||||
s.functionStack.push(function)
|
)
|
||||||
|
|
||||||
|
IDType.FUNCTION4 -> handleFunction(
|
||||||
|
name,
|
||||||
|
::dispatchFunction4,
|
||||||
|
functions.functions4,
|
||||||
|
{ f -> { x -> f(x[0], x[1], x[2], x[3]) } },
|
||||||
|
"${name}(x0, x1, x2, x3)"
|
||||||
|
)
|
||||||
|
|
||||||
|
IDType.FUNCTION5 -> {
|
||||||
|
val cfunction = constants(name) as? (Any, Any, Any, Any, Any) -> Any
|
||||||
|
if (cfunction != null) {
|
||||||
|
s.functionStack.push({ x -> cfunction(x[0], x[1], x[2], x[3], x[4]) })
|
||||||
|
} else {
|
||||||
|
s.functionStack.push(errorValue("unresolved function: '${name}(x0, x1, x2, x3, x4)'") { _ ->
|
||||||
|
error("this is the error function")
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
IDType.FUNCTION_ARGUMENT -> {
|
IDType.FUNCTION_ARGUMENT -> {
|
||||||
@@ -1000,6 +1030,15 @@ fun evaluateTypedExpression(
|
|||||||
return listener.state.lastExpressionResult
|
return listener.state.lastExpressionResult
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Compiles a typed expression and returns a lambda that can execute the compiled expression.
|
||||||
|
*
|
||||||
|
* @param expression The string representation of the expression to compile.
|
||||||
|
* @param constants A lambda function to resolve constants by their names. Defaults to a resolver that returns null.
|
||||||
|
* @param functions An instance of `TypedFunctionExtensions` containing the supported custom functions for the expression. Defaults to an empty set of functions.
|
||||||
|
* @return A lambda function that evaluates the compiled expression and returns its result.
|
||||||
|
* @throws ExpressionException If there is a syntax error or a parsing issue in the provided expression.
|
||||||
|
*/
|
||||||
fun compileTypedExpression(
|
fun compileTypedExpression(
|
||||||
expression: String,
|
expression: String,
|
||||||
constants: (String) -> Any? = { null },
|
constants: (String) -> Any? = { null },
|
||||||
@@ -1023,7 +1062,6 @@ fun compileTypedExpression(
|
|||||||
val root = parser.keyLangFile()
|
val root = parser.keyLangFile()
|
||||||
val listener = TypedExpressionListener(functions, constants)
|
val listener = TypedExpressionListener(functions, constants)
|
||||||
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
try {
|
try {
|
||||||
ParseTreeWalker.DEFAULT.walk(listener, root)
|
ParseTreeWalker.DEFAULT.walk(listener, root)
|
||||||
|
|||||||
@@ -61,4 +61,57 @@ class TestTypedCompiledExpression {
|
|||||||
println("that took ${end - start}")
|
println("that took ${end - start}")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testFunction2() {
|
||||||
|
run {
|
||||||
|
val c = compileFunction1OrNull<Map<String, Any>, Double>("x.x + 3.0", "x")!!
|
||||||
|
assertEquals(1.0 + 3.0, c(mapOf("x" to 1.0)))
|
||||||
|
//assertEquals(2.0 + 3.0, c(mapOf("x" to 2.0))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testDynamicConstants() {
|
||||||
|
|
||||||
|
val env = { n: String ->
|
||||||
|
when (n) {
|
||||||
|
"a" -> { nn: String ->
|
||||||
|
when (nn) {
|
||||||
|
"a" -> { nnn: String ->
|
||||||
|
when (nnn) {
|
||||||
|
"b" -> 7.0
|
||||||
|
"c" -> { x: Double -> x + 1.0 }
|
||||||
|
else -> null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
"b" -> 5.0
|
||||||
|
"c" -> { x: Double -> x * 2.0 }
|
||||||
|
else -> null
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
"c" -> { x: Double -> x * 3.0 }
|
||||||
|
else -> null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val c0 = compileFunction1OrNull<Map<String, Any>, Double>("a.a.c(2.0)", "x", constants = env)!!
|
||||||
|
val r0 = c0(emptyMap())
|
||||||
|
assertEquals(3.0, r0)
|
||||||
|
|
||||||
|
val c1 = compileFunction1OrNull<Map<String, Any>, Double>("a.c(2.0)", "x", constants = env)!!
|
||||||
|
val r1 = c1(emptyMap())
|
||||||
|
assertEquals(4.0, r1)
|
||||||
|
|
||||||
|
val c2 = compileFunction1OrNull<Map<String, Any>, Double>("c(2.0)", "x", constants = env)!!
|
||||||
|
val r2 = c2(emptyMap())
|
||||||
|
assertEquals(6.0, r2)
|
||||||
|
|
||||||
|
val c3 = compileFunction1OrNull<Map<String, Any>, Double>("cos(2.0)", "x", constants = env)!!
|
||||||
|
val r3 = c3(emptyMap())
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user