From b1e860cf395277cbc949c3fe14752f241c58543a Mon Sep 17 00:00:00 2001 From: Edwin Jakobs Date: Tue, 13 May 2025 23:21:27 +0200 Subject: [PATCH] [orx-expression-evaluator-typed] Refactor function dispatch handling, improve function resolution and add new tests --- .../src/commonMain/kotlin/typed/Function0.kt | 18 +++ .../kotlin/typed/TypedExpressions.kt | 152 +++++++++++------- .../typed/TestTypedCompiledExpression.kt | 53 ++++++ 3 files changed, 166 insertions(+), 57 deletions(-) create mode 100644 orx-expression-evaluator-typed/src/commonMain/kotlin/typed/Function0.kt diff --git a/orx-expression-evaluator-typed/src/commonMain/kotlin/typed/Function0.kt b/orx-expression-evaluator-typed/src/commonMain/kotlin/typed/Function0.kt new file mode 100644 index 00000000..283379a3 --- /dev/null +++ b/orx-expression-evaluator-typed/src/commonMain/kotlin/typed/Function0.kt @@ -0,0 +1,18 @@ +package org.openrndr.extra.expressions.typed + +import org.openrndr.extra.noise.uniform + +/** + * Dispatches a function without arguments based on its name. + * + * @param name The name of the function to dispatch. + * @param functions A map containing functions of type `TypedFunction0` associated with their names. + * @return A callable lambda that takes an array of `Any` as input and returns a result if the function is found, + * or null if there is no match. + */ +internal fun dispatchFunction0(name: String, functions: Map): ((Array) -> Any)? { + return when (name) { + "random" -> { x -> Double.uniform(0.0, 1.0) } + else -> functions[name]?.let { { x: Array -> it.invoke() } } + } +} \ No newline at end of file diff --git a/orx-expression-evaluator-typed/src/commonMain/kotlin/typed/TypedExpressions.kt b/orx-expression-evaluator-typed/src/commonMain/kotlin/typed/TypedExpressions.kt index 3b0c0fd7..7a561dbe 100644 --- a/orx-expression-evaluator-typed/src/commonMain/kotlin/typed/TypedExpressions.kt +++ b/orx-expression-evaluator-typed/src/commonMain/kotlin/typed/TypedExpressions.kt @@ -780,6 +780,32 @@ abstract class TypedExpressionListenerBase( return } + + fun handleFunction( + name: String, + dispatchFunction: (String, Map) -> ((Array) -> Any)?, + functionMap: Map, + adapter: (T) -> (Array) -> Any, + errorMessage: String + ) { + val function = dispatchFunction(name, functionMap) + + if (function != null) { + s.functionStack.push(function) + } else { + val cfunction = constants(name) as? T + if (cfunction != null) { + s.functionStack.push(adapter(cfunction)) + } else { + s.functionStack.push(errorValue("unresolved function: '$errorMessage'") { _ -> + error("this is the error function") + }) + } + } + } + + + val type = node.symbol.type if (type == KeyLangParser.Tokens.INTLIT) { s.valueStack.pushChecked(node.text.toDouble()) @@ -794,24 +820,13 @@ abstract class TypedExpressionListenerBase( IDType.VARIABLE -> s.valueStack.pushChecked( when (name) { "PI" -> PI - else -> constants(name) ?: errorValue("unresolved value: '${name}'. Available constant: ${constants}", Unit) + else -> constants(name) + ?: errorValue("unresolved value: '${name}'. Available constant: ${constants}", Unit) } ) IDType.PROPERTY -> s.propertyStack.push(name) - IDType.FUNCTION0 -> { - val function: (Array) -> Any = - when (name) { - "random" -> { _ -> Double.uniform(0.0, 1.0) } - else -> functions.functions0[name]?.let { { _: Array -> it.invoke() } } - ?: errorValue( - "unresolved function: '${name}()'" - ) { _ -> error("this is the error function") } - } - s.functionStack.push(function) - } - IDType.MEMBER_FUNCTION0, IDType.MEMBER_FUNCTION1, IDType.MEMBER_FUNCTION2, @@ -835,18 +850,18 @@ abstract class TypedExpressionListenerBase( is ColorRGBa -> { when (idType) { IDType.MEMBER_FUNCTION1 -> { - s.functionStack.push(when (name) { - "shade" -> { x -> receiver.shade(x[0] as Double) } - "opacify" -> { x -> receiver.opacify(x[0] as Double) } - else -> error("no member function '$receiver.$name()'") - }) + s.functionStack.push( + when (name) { + "shade" -> { x -> receiver.shade(x[0] as Double) } + "opacify" -> { x -> receiver.opacify(x[0] as Double) } + else -> error("no member function '$receiver.$name()'") + }) } else -> error("no member function $idType '$receiver.$name()") } } - is Function<*> -> { fun input(): String { @@ -867,27 +882,28 @@ abstract class TypedExpressionListenerBase( IDType.MEMBER_FUNCTION1 -> { @Suppress("UNCHECKED_CAST") - (function as? (Any) -> Any) ?: error("Cannot cast function '$name' ($function) to (Any) -> Any ${input()}") + (function as? (Any) -> Any) + ?: error("Cannot cast function '$name' ($function) to (Any) -> Any ${input()}") s.functionStack.push({ x -> function(x[0]) }) } IDType.MEMBER_FUNCTION2 -> { @Suppress("UNCHECKED_CAST") - function as? (Any, Any) -> Any ?: error("Cannot cast function '$name' ($function) to (Any, Any) -> Any ${input()}") + function as? (Any, Any) -> Any + ?: error("Cannot cast function '$name' ($function) to (Any, Any) -> Any ${input()}") s.functionStack.push({ x -> function(x[0], x[1]) }) } IDType.MEMBER_FUNCTION3 -> { @Suppress("UNCHECKED_CAST") - function as? (Any, Any, Any) -> Any ?: error("Cannot cast function '$name' ($function) to (Any, Any, Any) -> Any ${input()}") + function as? (Any, Any, Any) -> Any + ?: error("Cannot cast function '$name' ($function) to (Any, Any, Any) -> Any ${input()}") s.functionStack.push({ x -> function(x[0], x[1], x[2]) }) } else -> error("unreachable") } } - - else -> error( "receiver for '$name' '${ receiver.toString().take(30) @@ -896,41 +912,55 @@ abstract class TypedExpressionListenerBase( } } - IDType.FUNCTION1 -> { - val localState = state - val function: (Array) -> Any = - dispatchFunction1(name, functions.functions1) - ?: errorValue( - "unresolved function: '${name}(x0)'" - ) { _ -> error("this is the error function") } - localState.functionStack.push(function) - } + IDType.FUNCTION0 -> handleFunction( + name, + ::dispatchFunction0, + functions.functions0, + { f -> { x -> f() } }, + "${name}()" + ) - IDType.FUNCTION2 -> { - val function: (Array) -> Any = - dispatchFunction2(name, functions.functions2) - ?: errorValue( - "unresolved function: '${name}(x0, x1)'" - ) { _ -> error("this is the error function") } - s.functionStack.push(function) - } + IDType.FUNCTION1 -> handleFunction( + name, + ::dispatchFunction1, + functions.functions1, + { f -> { x -> f(x[0]) } }, + "${name}(x0)" + ) - IDType.FUNCTION3 -> { - val function: (Array) -> Any = - dispatchFunction3(name, functions.functions3) - ?: errorValue( - "unresolved function: '${name}(x0)'" - ) { _ -> error("this is the error function") } - s.functionStack.push(function) - } + IDType.FUNCTION2 -> handleFunction( + name, + ::dispatchFunction2, + functions.functions2, + { f -> { x -> f(x[0], x[1]) } }, + "${name}(x0, x1)" + ) - IDType.FUNCTION4 -> { - val function: (Array) -> Any = - dispatchFunction4(name, functions.functions4) - ?: errorValue( - "unresolved function: '${name}(x0)'" - ) { _ -> error("this is the error function") } - s.functionStack.push(function) + IDType.FUNCTION3 -> handleFunction( + name, + ::dispatchFunction3, + functions.functions3, + { f -> { x -> f(x[0], x[1], x[2]) } }, + "${name}(x0, x1, x2)" + ) + + IDType.FUNCTION4 -> handleFunction( + name, + ::dispatchFunction4, + functions.functions4, + { f -> { x -> f(x[0], x[1], x[2], x[3]) } }, + "${name}(x0, x1, x2, x3)" + ) + + IDType.FUNCTION5 -> { + val cfunction = constants(name) as? (Any, Any, Any, Any, Any) -> Any + if (cfunction != null) { + s.functionStack.push({ x -> cfunction(x[0], x[1], x[2], x[3], x[4]) }) + } else { + s.functionStack.push(errorValue("unresolved function: '${name}(x0, x1, x2, x3, x4)'") { _ -> + error("this is the error function") + }) + } } IDType.FUNCTION_ARGUMENT -> { @@ -1000,6 +1030,15 @@ fun evaluateTypedExpression( return listener.state.lastExpressionResult } +/** + * Compiles a typed expression and returns a lambda that can execute the compiled expression. + * + * @param expression The string representation of the expression to compile. + * @param constants A lambda function to resolve constants by their names. Defaults to a resolver that returns null. + * @param functions An instance of `TypedFunctionExtensions` containing the supported custom functions for the expression. Defaults to an empty set of functions. + * @return A lambda function that evaluates the compiled expression and returns its result. + * @throws ExpressionException If there is a syntax error or a parsing issue in the provided expression. + */ fun compileTypedExpression( expression: String, constants: (String) -> Any? = { null }, @@ -1023,7 +1062,6 @@ fun compileTypedExpression( val root = parser.keyLangFile() val listener = TypedExpressionListener(functions, constants) - return { try { ParseTreeWalker.DEFAULT.walk(listener, root) diff --git a/orx-expression-evaluator-typed/src/jvmTest/kotlin/typed/TestTypedCompiledExpression.kt b/orx-expression-evaluator-typed/src/jvmTest/kotlin/typed/TestTypedCompiledExpression.kt index 91904842..ab0f7a61 100644 --- a/orx-expression-evaluator-typed/src/jvmTest/kotlin/typed/TestTypedCompiledExpression.kt +++ b/orx-expression-evaluator-typed/src/jvmTest/kotlin/typed/TestTypedCompiledExpression.kt @@ -61,4 +61,57 @@ class TestTypedCompiledExpression { println("that took ${end - start}") } } + + @Test + fun testFunction2() { + run { + val c = compileFunction1OrNull, Double>("x.x + 3.0", "x")!! + assertEquals(1.0 + 3.0, c(mapOf("x" to 1.0))) + //assertEquals(2.0 + 3.0, c(mapOf("x" to 2.0)) + } + } + + @Test + fun testDynamicConstants() { + + val env = { n: String -> + when (n) { + "a" -> { nn: String -> + when (nn) { + "a" -> { nnn: String -> + when (nnn) { + "b" -> 7.0 + "c" -> { x: Double -> x + 1.0 } + else -> null + } + } + + "b" -> 5.0 + "c" -> { x: Double -> x * 2.0 } + else -> null + } + + } + "c" -> { x: Double -> x * 3.0 } + else -> null + } + } + + val c0 = compileFunction1OrNull, Double>("a.a.c(2.0)", "x", constants = env)!! + val r0 = c0(emptyMap()) + assertEquals(3.0, r0) + + val c1 = compileFunction1OrNull, Double>("a.c(2.0)", "x", constants = env)!! + val r1 = c1(emptyMap()) + assertEquals(4.0, r1) + + val c2 = compileFunction1OrNull, Double>("c(2.0)", "x", constants = env)!! + val r2 = c2(emptyMap()) + assertEquals(6.0, r2) + + val c3 = compileFunction1OrNull, Double>("cos(2.0)", "x", constants = env)!! + val r3 = c3(emptyMap()) + + + } } \ No newline at end of file