[orx-expression-evaluator-typed] Add list and lambda support

This commit is contained in:
Edwin Jakobs
2024-06-11 10:23:53 +02:00
parent 382e155813
commit 8c2e8a2d73
8 changed files with 311 additions and 10 deletions

View File

@@ -11,9 +11,9 @@ fun <T0, R> compileFunction1(
constants: (String) -> Any? = { null },
functions: TypedFunctionExtensions = TypedFunctionExtensions.EMPTY
): ((T0) -> R) {
require(constants(parameter0) == null) {
"${parameter0} is in constants with value '${constants(parameter0)}"
}
// require(constants(parameter0) == null) {
// "${parameter0} is in constants with value '${constants(parameter0)}"
// }
val root = expressionRoot(expression)
var varP0: T0? = null

View File

@@ -1,6 +1,8 @@
package org.openrndr.extra.expressions.typed
fun String.memberFunctions(n: String): ((Array<Any>) -> Any)? {
import kotlin.math.roundToInt
internal fun String.memberFunctions(n: String): ((Array<Any>) -> Any)? {
return when (n) {
"take" -> { n -> this.take((n[0] as Number).toInt()) }
"drop" -> { n -> this.drop((n[0] as Number).toInt()) }
@@ -8,4 +10,16 @@ fun String.memberFunctions(n: String): ((Array<Any>) -> Any)? {
"dropLast" -> { n -> this.takeLast((n[0] as Number).toInt()) }
else -> null
}
}
internal fun List<*>.memberFunctions(n: String): ((Array<Any>) -> Any)? {
return when (n) {
"take" -> { n -> this.take((n[0] as Number).toInt()) }
"drop" -> { n -> this.drop((n[0] as Number).toInt()) }
"takeLast" -> { n -> this.takeLast((n[0] as Number).toInt()) }
"dropLast" -> { n -> this.takeLast((n[0] as Number).toInt()) }
"map" -> { n -> val lambda = (n[0] as (Any)->Any); this.map { lambda(it!!) } }
"filter" -> { n -> val lambda = (n[0] as (Any)->Any); this.filter { (lambda(it!!) as Double).roundToInt() != 0 } }
else -> null
}
}

View File

@@ -56,7 +56,8 @@ enum class IDType {
FUNCTION2,
FUNCTION3,
FUNCTION4,
FUNCTION5
FUNCTION5,
FUNCTION_ARGUMENT
}
abstract class TypedExpressionListenerBase(
@@ -74,22 +75,71 @@ abstract class TypedExpressionListenerBase(
var lastExpressionResult: Any? = null
val exceptionStack = ArrayDeque<ExpressionException>()
var inFunctionLiteral = 0
fun reset() {
valueStack.clear()
functionStack.clear()
propertyStack.clear()
idTypeStack.clear()
lastExpressionResult = null
exceptionStack.clear()
inFunctionLiteral = 0
}
}
abstract val state: State
override fun enterLine(ctx: KeyLangParser.LineContext) {
val s = state
s.reset()
}
override fun exitListLiteral(ctx: KeyLangParser.ListLiteralContext) {
val s = state
val list = (0 until ctx.getExpression().size).map { s.valueStack.pop() }
s.valueStack.push(list.reversed())
}
override fun enterFunctionLiteral(ctx: KeyLangParser.FunctionLiteralContext) {
val s = state
s.inFunctionLiteral++
}
override fun exitFunctionLiteral(ctx: KeyLangParser.FunctionLiteralContext) {
val s = state
s.inFunctionLiteral--
val functionExpr = ctx.getExpression().text
val ids = ctx.ID()
val f = when (ids.size) {
0 -> compileFunction1<Any, Any>(functionExpr, "it", constants, functions)
1 -> compileFunction1<Any, Any>(functionExpr, ids[0].text, constants, functions)
2 -> compileFunction2<Any, Any, Any>(functionExpr, ids[0].text, ids[1].text, constants, functions)
else -> error("functions with ${ids.size} parameters are not supported")
}
s.valueStack.push(f)
}
override fun exitExpressionStatement(ctx: KeyLangParser.ExpressionStatementContext) {
val s = state
if (s.inFunctionLiteral > 0) {
return
}
ifError {
throw ExpressionException("error in evaluation of '${ctx.text}': ${it.message ?: ""}")
}
val result = state.valueStack.pop()
state.lastExpressionResult = result
s.lastExpressionResult = result
}
override fun exitMinusExpression(ctx: KeyLangParser.MinusExpressionContext) {
val s = state
if (s.inFunctionLiteral > 0) {
return
}
val op = s.valueStack.pop()
s.valueStack.pushChecked(
when (op) {
@@ -105,6 +155,9 @@ abstract class TypedExpressionListenerBase(
override fun exitBinaryOperation1(ctx: KeyLangParser.BinaryOperation1Context) {
val s = state
if (s.inFunctionLiteral > 0) {
return
}
ifError {
pushError(it.message ?: "")
@@ -128,6 +181,7 @@ abstract class TypedExpressionListenerBase(
left is Matrix44 && right is Double -> left * right
left is ColorRGBa && right is Double -> left * right
left is String && right is Double -> left.repeat(right.roundToInt())
left is List<*> && right is Double -> (0 until right.roundToInt()).flatMap { left }
else -> error("unsupported operands for * operator left:${left::class} right:${right::class}")
}
@@ -159,6 +213,11 @@ abstract class TypedExpressionListenerBase(
@Suppress("IMPLICIT_CAST_TO_ANY")
override fun exitBinaryOperation2(ctx: KeyLangParser.BinaryOperation2Context) {
val s = state
if (s.inFunctionLiteral > 0) {
return
}
ifError {
pushError(it.message ?: "")
return
@@ -176,6 +235,7 @@ abstract class TypedExpressionListenerBase(
left is Matrix44 && right is Matrix44 -> left + right
left is ColorRGBa && right is ColorRGBa -> left + right
left is String && right is String -> left + right
left is List<*> && right is List<*> -> left + right
else -> error("unsupported operands for + operator left:${left::class} right:${right::class}")
}
@@ -196,6 +256,10 @@ abstract class TypedExpressionListenerBase(
override fun exitJoinOperation(ctx: KeyLangParser.JoinOperationContext) {
val s = state
if (s.inFunctionLiteral > 0) {
return
}
val right = (s.valueStack.pop() as Double).roundToInt()
val left = (s.valueStack.pop() as Double).roundToInt()
@@ -209,6 +273,9 @@ abstract class TypedExpressionListenerBase(
override fun exitComparisonOperation(ctx: KeyLangParser.ComparisonOperationContext) {
val s = state
if (s.inFunctionLiteral > 0) {
return
}
val right = s.valueStack.pop()
val left = s.valueStack.pop()
@@ -251,12 +318,20 @@ abstract class TypedExpressionListenerBase(
override fun exitNegateExpression(ctx: KeyLangParser.NegateExpressionContext) {
val s = state
if (s.inFunctionLiteral > 0) {
return
}
val operand = (s.valueStack.pop() as Double).roundToInt()
s.valueStack.pushChecked(if (operand == 0) 1.0 else 0.0)
}
override fun exitTernaryExpression(ctx: KeyLangParser.TernaryExpressionContext) {
val s = state
if (s.inFunctionLiteral > 0) {
return
}
val right = s.valueStack.pop()
val left = s.valueStack.pop()
val comp = s.valueStack.pop()
@@ -270,16 +345,28 @@ abstract class TypedExpressionListenerBase(
override fun enterValueReference(ctx: KeyLangParser.ValueReferenceContext) {
val s = state
if (s.inFunctionLiteral > 0) {
return
}
s.idTypeStack.push(IDType.VARIABLE)
}
override fun enterMemberFunctionCall0Expression(ctx: KeyLangParser.MemberFunctionCall0ExpressionContext) {
val s = state
if (s.inFunctionLiteral > 0) {
return
}
s.idTypeStack.push(IDType.MEMBER_FUNCTION1)
}
override fun exitMemberFunctionCall0Expression(ctx: KeyLangParser.MemberFunctionCall0ExpressionContext) {
val s = state
if (s.inFunctionLiteral > 0) {
return
}
ifError {
pushError(it.message ?: "")
return
@@ -289,11 +376,19 @@ abstract class TypedExpressionListenerBase(
override fun enterMemberFunctionCall1Expression(ctx: KeyLangParser.MemberFunctionCall1ExpressionContext) {
val s = state
if (s.inFunctionLiteral > 0) {
return
}
s.idTypeStack.push(IDType.MEMBER_FUNCTION1)
}
override fun exitMemberFunctionCall1Expression(ctx: KeyLangParser.MemberFunctionCall1ExpressionContext) {
val s = state
if (s.inFunctionLiteral > 0) {
return
}
ifError {
pushError(it.message ?: "")
return
@@ -303,11 +398,19 @@ abstract class TypedExpressionListenerBase(
override fun enterMemberFunctionCall2Expression(ctx: KeyLangParser.MemberFunctionCall2ExpressionContext) {
val s = state
if (s.inFunctionLiteral > 0) {
return
}
s.idTypeStack.push(IDType.MEMBER_FUNCTION2)
}
override fun exitMemberFunctionCall2Expression(ctx: KeyLangParser.MemberFunctionCall2ExpressionContext) {
val s = state
if (s.inFunctionLiteral > 0) {
return
}
ifError {
pushError(it.message ?: "")
return
@@ -320,11 +423,19 @@ abstract class TypedExpressionListenerBase(
override fun enterMemberFunctionCall3Expression(ctx: KeyLangParser.MemberFunctionCall3ExpressionContext) {
val s = state
if (s.inFunctionLiteral > 0) {
return
}
s.idTypeStack.push(IDType.MEMBER_FUNCTION3)
}
override fun exitMemberFunctionCall3Expression(ctx: KeyLangParser.MemberFunctionCall3ExpressionContext) {
val s = state
if (s.inFunctionLiteral > 0) {
return
}
ifError {
pushError(it.message ?: "")
return
@@ -336,14 +447,44 @@ abstract class TypedExpressionListenerBase(
s.valueStack.pushChecked(s.functionStack.pop().invoke(arrayOf(argument0, argument1, argument2)))
}
override fun enterMemberFunctionCall0LambdaExpression(ctx: KeyLangParser.MemberFunctionCall0LambdaExpressionContext) {
val s = state
if (s.inFunctionLiteral > 0) {
return
}
s.idTypeStack.push(IDType.MEMBER_FUNCTION1)
}
override fun exitMemberFunctionCall0LambdaExpression(ctx: KeyLangParser.MemberFunctionCall0LambdaExpressionContext) {
val s = state
if (s.inFunctionLiteral > 0) {
return
}
ifError {
pushError(it.message ?: "")
return
}
s.valueStack.pushChecked(s.functionStack.pop().invoke(arrayOf(s.valueStack.pop())))
}
override fun enterFunctionCall0Expression(ctx: KeyLangParser.FunctionCall0ExpressionContext) {
val s = state
if (s.inFunctionLiteral > 0) {
return
}
s.idTypeStack.push(IDType.FUNCTION0)
}
override fun exitFunctionCall0Expression(ctx: KeyLangParser.FunctionCall0ExpressionContext) {
val s = state
if (s.inFunctionLiteral > 0) {
return
}
ifError {
pushError(it.message ?: "")
return
@@ -356,11 +497,19 @@ abstract class TypedExpressionListenerBase(
override fun enterFunctionCall1Expression(ctx: KeyLangParser.FunctionCall1ExpressionContext) {
val s = state
if (s.inFunctionLiteral > 0) {
return
}
s.idTypeStack.push(IDType.FUNCTION1)
}
override fun exitFunctionCall1Expression(ctx: KeyLangParser.FunctionCall1ExpressionContext) {
val s = state
if (s.inFunctionLiteral > 0) {
return
}
ifError {
pushError(it.message ?: "")
return
@@ -375,11 +524,19 @@ abstract class TypedExpressionListenerBase(
override fun enterFunctionCall2Expression(ctx: KeyLangParser.FunctionCall2ExpressionContext) {
val s = state
if (s.inFunctionLiteral > 0) {
return
}
s.idTypeStack.push(IDType.FUNCTION2)
}
override fun exitFunctionCall2Expression(ctx: KeyLangParser.FunctionCall2ExpressionContext) {
val s = state
if (s.inFunctionLiteral > 0) {
return
}
ifError {
pushError(it.message ?: "")
return
@@ -395,11 +552,19 @@ abstract class TypedExpressionListenerBase(
override fun enterFunctionCall3Expression(ctx: KeyLangParser.FunctionCall3ExpressionContext) {
val s = state
if (s.inFunctionLiteral > 0) {
return
}
s.idTypeStack.push(IDType.FUNCTION3)
}
override fun exitFunctionCall3Expression(ctx: KeyLangParser.FunctionCall3ExpressionContext) {
val s = state
if (s.inFunctionLiteral > 0) {
return
}
ifError {
pushError(it.message ?: "")
return
@@ -416,11 +581,19 @@ abstract class TypedExpressionListenerBase(
override fun enterFunctionCall4Expression(ctx: KeyLangParser.FunctionCall4ExpressionContext) {
val s = state
if (s.inFunctionLiteral > 0) {
return
}
s.idTypeStack.push(IDType.FUNCTION4)
}
override fun exitFunctionCall4Expression(ctx: KeyLangParser.FunctionCall4ExpressionContext) {
val s = state
if (s.inFunctionLiteral > 0) {
return
}
ifError {
pushError(it.message ?: "")
return
@@ -439,11 +612,19 @@ abstract class TypedExpressionListenerBase(
override fun enterFunctionCall5Expression(ctx: KeyLangParser.FunctionCall5ExpressionContext) {
val s = state
if (s.inFunctionLiteral > 0) {
return
}
s.idTypeStack.push(IDType.FUNCTION5)
}
override fun exitFunctionCall5Expression(ctx: KeyLangParser.FunctionCall5ExpressionContext) {
val s = state
if (s.inFunctionLiteral > 0) {
return
}
ifError {
pushError(it.message ?: "")
return
@@ -467,6 +648,7 @@ abstract class TypedExpressionListenerBase(
private fun pushError(message: String) {
val s = state
s.exceptionStack.push(ExpressionException(message))
}
@@ -480,11 +662,19 @@ abstract class TypedExpressionListenerBase(
override fun enterPropReference(ctx: KeyLangParser.PropReferenceContext) {
val s = state
if (s.inFunctionLiteral > 0) {
return
}
s.idTypeStack.push(IDType.PROPERTY)
}
override fun exitPropReference(ctx: KeyLangParser.PropReferenceContext) {
val s = state
if (s.inFunctionLiteral > 0) {
return
}
val root = s.valueStack.pop()
var current = root
val property = s.propertyStack.pop()
@@ -507,6 +697,10 @@ abstract class TypedExpressionListenerBase(
override fun visitTerminal(node: TerminalNode) {
val s = state
if (s.inFunctionLiteral > 0) {
return
}
val type = node.symbol.type
if (type == KeyLangParser.Tokens.INTLIT) {
s.valueStack.pushChecked(node.text.toDouble())
@@ -552,6 +746,13 @@ abstract class TypedExpressionListenerBase(
)
}
is List<*> -> {
s.functionStack.push(
receiver.memberFunctions(name)
?: error("no member function '$receiver.$name()'")
)
}
is ColorRGBa -> {
when (idType) {
IDType.MEMBER_FUNCTION1 -> {
@@ -598,6 +799,7 @@ abstract class TypedExpressionListenerBase(
}
}
else -> error("receiver '${receiver}' not supported")
}
}
@@ -639,6 +841,10 @@ abstract class TypedExpressionListenerBase(
s.functionStack.push(function)
}
IDType.FUNCTION_ARGUMENT -> {
}
else -> error("unsupported id-type $idType")
}
}

View File

@@ -0,0 +1,8 @@
package org.openrndr.extra.expressions.typed
actual class TypedExpressionListener actual constructor(
functions: TypedFunctionExtensions,
constants: (String) -> Any?
) : TypedExpressionListenerBase(functions, constants) {
actual override val state: State = State()
}

View File

@@ -0,0 +1,15 @@
package org.openrndr.extra.expressions.typed
import kotlin.concurrent.getOrSet
/*
Thread safe TypeExpressionListener
*/
actual class TypedExpressionListener actual constructor(
functions: TypedFunctionExtensions,
constants: (String) -> Any?
) : TypedExpressionListenerBase(functions, constants) {
private val threadLocalState = ThreadLocal<State>()
actual override val state: State
get() = threadLocalState.getOrSet { State() }
}

View File

@@ -7,6 +7,53 @@ import kotlin.test.Test
class TestTypedExpression {
@Test
fun funTestFunction() {
run {
val r = evaluateTypedExpression("{ x -> 2.0 + x }")
val f = r as (Double) -> Double
println(f(3.0))
}
run {
val r = evaluateTypedExpression("{ { 2.0 + it } }")
val f0 = r as (Any) -> ((Any) -> Any)
val f1 = f0(0.0)
println(f1(3.0))
}
}
@Test
fun funTestLambdaArg() {
run {
val r = evaluateTypedExpression("[0.0, 1.0].map { x -> 2.0 + x }")
assertEquals(listOf(2.0, 3.0), r)
}
run {
val r = evaluateTypedExpression("[0.0, 1.0].map { x -> vec2(2.0 + x, 2.0 + x) }")
assertEquals(listOf(Vector2(2.0, 2.0), Vector2(3.0, 3.0)), r)
}
run {
val r = evaluateTypedExpression("[0.0, 1.0, 2.0].filter { x -> x >= 1.0 }")
assertEquals(listOf(1.0, 2.0), r)
}
}
@Test
fun testList() {
println("result is: ${evaluateTypedExpression("[]")}")
println("result is: ${evaluateTypedExpression("[1.0, 2.0]")}")
println("result is: ${evaluateTypedExpression("[1.0, 2.0].take(1)")}")
println("result is: ${evaluateTypedExpression("[1.0 + 2.0, 2.0 * 3.0].take(1 + 1)")}")
println("result is: ${evaluateTypedExpression("[] + []")}")
println("result is: ${evaluateTypedExpression("([1] * 2 + [2] * 1)*5")}" )
}
@Test
fun testTernary() {
println("result is: ${evaluateTypedExpression("2.0 > 0.5 ? 1.3 : 0.7")}")

View File

@@ -25,18 +25,24 @@ DIVISION : '/' ;
ASSIGN : '=' ;
LPAREN : '(' ;
RPAREN : ')' ;
LBRACKET : '[' ;
RBRACKET : ']' ;
LCURLY : '{' ;
RCURLY : '}' ;
QUESTION_MARK : '?' ;
COLON : ':' ;
ARROW : '->' ;
COMMA : ',' ;
DOT : '.' ;
EQ : '==' ;
LT : '<' ;
LTEQ : '<=' ;
GT : '>=' ;
GTEQ : '>' ;
GT : '>' ;
GTEQ : '>=' ;
AND : '&&' ;
OR : '||' ;

View File

@@ -10,8 +10,13 @@ line : statement (NEWLINE | EOF) ;
statement :
expression # expressionStatement ;
lambda: LCURLY (ID ( COMMA ID )* ARROW )? expression RCURLY # functionLiteral;
expression : INTLIT # intLiteral
| DECLIT # decimalLiteral
| LBRACKET (expression ( COMMA expression )*)? RBRACKET # listLiteral
| expression DOT ID lambda # memberFunctionCall0LambdaExpression
| lambda # lambdaExpression
| expression DOT ID LPAREN RPAREN # memberFunctionCall0Expression
| expression DOT ID LPAREN expression RPAREN # memberFunctionCall1Expression
| expression DOT ID LPAREN expression COMMA expression RPAREN # memberFunctionCall2Expression