package io.gitlab.jfronny.muscript.optimizer;

import io.gitlab.jfronny.muscript.ast.*;
import io.gitlab.jfronny.muscript.ast.bool.*;
import io.gitlab.jfronny.muscript.ast.context.Script;
import io.gitlab.jfronny.muscript.ast.dynamic.*;
import io.gitlab.jfronny.muscript.ast.extensible.*;
import io.gitlab.jfronny.muscript.ast.number.*;
import io.gitlab.jfronny.muscript.ast.string.*;
import io.gitlab.jfronny.muscript.core.CodeLocation;
import io.gitlab.jfronny.muscript.data.additional.DFinal;
import io.gitlab.jfronny.muscript.data.dynamic.Dynamic;

import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;

import static io.gitlab.jfronny.muscript.ast.Expr.literal;
import static io.gitlab.jfronny.muscript.ast.context.ExprUtils.*;

/**
 * Rewrites expression trees into simpler equivalent trees.
 *
 * <p>Every typed {@code optimize} overload works in two phases: the matching
 * {@code optimizeInner} overload rebuilds the node with optimized children, then a
 * pattern-matching {@code switch} applies local rewrite rules (constant folding,
 * removal of double negation, flattening of nested operations, ...) to the result.
 * Nodes that match no rule are returned as produced by {@code optimizeInner}.
 */
public class Optimizer {
    /**
     * Optimizes the content of a script.
     *
     * @param script the script to optimize
     * @return a new {@link Script} wrapping the optimized content
     */
    public static Script optimize(Script script) {
        return new Script(optimize(script.content()));
    }

    /**
     * Optimizes an expression of unknown type by dispatching on the unpacked expression.
     *
     * @param expr the expression to optimize
     * @return the optimized expression
     * @throws NullPointerException if unpacking {@code expr} yields {@code null}
     */
    public static Expr optimize(Expr expr) {
        return switch (unpack(expr)) {
            case null -> throw new NullPointerException();
            case ExtensibleExpr e -> e.optimize();
            case StringExpr stringExpr -> optimize(stringExpr);
            case NumberExpr numberExpr -> optimize(numberExpr);
            case BoolExpr boolExpr -> optimize(boolExpr);
            case DynamicExpr dynamicExpr -> optimize(dynamicExpr);
            case NullLiteral nl -> nl;
        };
    }

    /**
     * Optimizes a string expression: resolves conditionals with a literal condition,
     * removes a negation in a conditional by swapping its branches and merges adjacent
     * string literals in concatenations.
     *
     * @param expr the expression to optimize
     * @return the optimized expression
     * @throws NullPointerException if the child-optimized expression is {@code null}
     */
    public static StringExpr optimize(StringExpr expr) {
        return switch (optimizeInner(expr)) {
            case null -> throw new NullPointerException();
            case StringConditional(var location, BoolLiteral(var location1, var value), var ifTrue, var ifFalse) ->
                    value ? ifTrue : ifFalse;
            case StringConditional(var location, Not(var location1, var condition), var ifTrue, var ifFalse) ->
                    new StringConditional(location, condition, ifFalse, ifTrue);
            // "a" + "b" -> "ab"
            case Concatenate(var location, StringLiteral(var location1, var left), StringLiteral(var location2, var right)) ->
                    literal(location, left + right);
            // (x + "a") + "b" -> x + "ab"
            case Concatenate(var location, Concatenate(var location1, var left1, StringLiteral(var location2, var right1)), StringLiteral(var location3, var right2)) ->
                    new Concatenate(location1, left1, literal(location2, right1 + right2));
            // "a" + ("b" + x) -> "ab" + x
            case Concatenate(var location, StringLiteral(var location1, var left1), Concatenate(var location2, StringLiteral(var location3, var left2), var right)) ->
                    new Concatenate(location2, literal(location3, left1 + left2), right);
            case StringExpr fallback -> fallback; // Used instead of default to still keep applied optimizations
        };
    }

    /** Rebuilds a string expression with optimized children; {@code null} is passed through. */
    private static StringExpr optimizeInner(StringExpr expr) {
        return switch (expr) {
            case null -> null;
            case StringLiteral e -> e;
            case ExtensibleStringExpr e -> e.optimize();
            case StringUnpack(var inner) -> asString(unpack(optimize(inner)));
            case StringCoerce(var inner) -> asString(unpack(optimize(inner)));
            case StringAssign(var location, var variable, var value) ->
                    new StringAssign(location, variable, optimize(value));
            case StringConditional(var location, var condition, var ifTrue, var ifFalse) ->
                    new StringConditional(location, optimize(condition), optimize(ifTrue), optimize(ifFalse));
            case Concatenate(var location, var left, var right) ->
                    new Concatenate(location, optimize(left), optimize(right));
        };
    }

    /**
     * Optimizes a number expression: folds operations on literals, resolves conditionals
     * with a literal condition and normalizes combinations of negation, addition,
     * subtraction, division and power.
     *
     * @param expr the expression to optimize
     * @return the optimized expression
     * @throws NullPointerException if the child-optimized expression is {@code null}
     */
    public static NumberExpr optimize(NumberExpr expr) {
        return switch (optimizeInner(expr)) {
            case null -> throw new NullPointerException();
            case NumberConditional(var location, BoolLiteral(var location1, var value), var ifTrue, var ifFalse) ->
                    value ? ifTrue : ifFalse;
            case NumberConditional(var location, Not(var location1, var condition), var ifTrue, var ifFalse) ->
                    new NumberConditional(location, condition, ifFalse, ifTrue);

            case Add(var location, NumberLiteral(var location1, var augend), NumberLiteral(var location2, var addend)) ->
                    literal(location, augend + addend);
            // (-a) + (-b) -> -(a + b)
            case Add(var location, Negate(var location1, var augend), Negate(var location2, var addend)) ->
                    new Negate(location, optimize(new Add(location, augend, addend)));
            // (-a) + b -> b - a
            case Add(var location, Negate(var location1, var augend), var addend) ->
                    optimize(new Subtract(location, addend, augend));
            // a + (-b) -> a - b
            case Add(var location, var augend, Negate(var location1, var addend)) ->
                    optimize(new Subtract(location, augend, addend));

            case Subtract(var location, NumberLiteral(var location1, var minuend), NumberLiteral(var location2, var subtrahend)) ->
                    literal(location, minuend - subtrahend);
            // (a - b) - c -> a - (b + c)
            case Subtract(var location, Subtract(var location1, var minuend1, var subtrahend1), var subtrahend) ->
                    optimize(new Subtract(location, minuend1, new Add(location1, subtrahend1, subtrahend)));
            // a - (b - c) -> (a + c) - b
            case Subtract(var location, var minuend, Subtract(var location1, var minuend1, var subtrahend1)) ->
                    new Subtract(location, new Add(location1, minuend, subtrahend1), minuend1);
            // (-a) - (-b) -> b - a
            case Subtract(var location, Negate(var location1, var minuend), Negate(var location2, var subtrahend)) ->
                    new Subtract(location, subtrahend, minuend);
            // (-a) - b -> -(a + b)
            case Subtract(var location, Negate(var location1, var minuend), var subtrahend) ->
                    optimize(new Negate(location, new Add(location, minuend, subtrahend)));
            // a - (-b) -> a + b
            case Subtract(var location, var minuend, Negate(var location1, var subtrahend)) ->
                    optimize(new Add(location, minuend, subtrahend));

            case Negate(var location, NumberLiteral(var location1, var value)) -> literal(location, -value);
            case Negate(var location, Negate(var location1, var inner)) -> inner;
            // -(a - b) -> b - a
            case Negate(var location, Subtract(var location1, var minuend, var subtrahend)) ->
                    optimize(new Subtract(location, subtrahend, minuend));

            case Multiply(var location, NumberLiteral(var location1, var multiplier), NumberLiteral(var location2, var multiplicand)) ->
                    literal(location, multiplier * multiplicand);
            //TODO optimize multiplication with or division by 1

            case Divide(var location, NumberLiteral(var location1, var dividend), NumberLiteral(var location2, var divisor)) ->
                    literal(location, dividend / divisor);
            // (a / b) / c -> a / (b * c)
            case Divide(var location, Divide(var location1, var dividend1, var divisor1), var divisor) ->
                    new Divide(location, dividend1, new Multiply(location1, divisor1, divisor));
            // a / (b / c) -> (a * c) / b
            case Divide(var location, var dividend, Divide(var location1, var dividend1, var divisor1)) ->
                    new Divide(location, new Multiply(location1, dividend, divisor1), dividend1);
            // (-a) / (-b) -> a / b
            // Unlike the subtraction rule above the operands must NOT be swapped: the signs cancel out.
            case Divide(var location, Negate(var location1, var dividend), Negate(var location2, var divisor)) ->
                    new Divide(location, dividend, divisor);
            // (-a) / b -> -(a / b)
            case Divide(var location, Negate(var location1, var dividend), var divisor) ->
                    new Negate(location, new Divide(location1, dividend, divisor));
            // a / (-b) -> -(a / b)
            case Divide(var location, var dividend, Negate(var location1, var divisor)) ->
                    new Negate(location, new Divide(location1, dividend, divisor));

            case Modulo(var location, NumberLiteral(var location1, var dividend), NumberLiteral(var location2, var divisor)) ->
                    literal(location, dividend % divisor);

            case Power(var location, NumberLiteral(var location1, var base), NumberLiteral(var location2, var exponent)) ->
                    literal(location, Math.pow(base, exponent));
            // (c * x) ^ e -> c^e * x^e for literal c and e
            case Power(var location, Multiply(var location1, NumberLiteral(var location2, var multiplier), var multiplicand), NumberLiteral(var location3, var exponent)) ->
                    new Multiply(location1, literal(location, Math.pow(multiplier, exponent)), new Power(location, multiplicand, literal(location3, exponent)));
            // (x * c) ^ e -> c^e * x^e for literal c and e
            case Power(var location, Multiply(var location1, var multiplier, NumberLiteral(var location2, var multiplicand)), NumberLiteral(var location3, var exponent)) ->
                    new Multiply(location1, literal(location, Math.pow(multiplicand, exponent)), new Power(location, multiplier, literal(location3, exponent)));
            case NumberExpr fallback -> fallback; // Used instead of default to still keep applied optimizations
        };
    }

    /** Rebuilds a number expression with optimized children; {@code null} is passed through. */
    private static NumberExpr optimizeInner(NumberExpr expr) {
        return switch (expr) {
            case null -> null;
            case NumberLiteral e -> e;
            case ExtensibleNumberExpr e -> e.optimize();
            case NumberUnpack(var inner) -> asNumber(unpack(optimize(inner)));
            case NumberAssign(var location, var variable, var value) ->
                    new NumberAssign(location, variable, optimize(value));
            case NumberConditional(var location, var condition, var ifTrue, var ifFalse) ->
                    new NumberConditional(location, optimize(condition), optimize(ifTrue), optimize(ifFalse));
            case Add(var location, var augend, var addend) ->
                    new Add(location, optimize(augend), optimize(addend));
            case Subtract(var location, var minuend, var subtrahend) ->
                    new Subtract(location, optimize(minuend), optimize(subtrahend));
            case Negate(var location, var inner) -> new Negate(location, optimize(inner));
            case Multiply(var location, var multiplier, var multiplicand) ->
                    new Multiply(location, optimize(multiplier), optimize(multiplicand));
            case Divide(var location, var dividend, var divisor) ->
                    new Divide(location, optimize(dividend), optimize(divisor));
            case Modulo(var location, var dividend, var divisor) ->
                    new Modulo(location, optimize(dividend), optimize(divisor));
            case Power(var location, var base, var exponent) ->
                    new Power(location, optimize(base), optimize(exponent));
        };
    }

    /**
     * Optimizes a boolean expression: resolves conditionals with a literal condition,
     * folds {@code And}/{@code Or}/{@code Not} with literal operands and moves arithmetic
     * from the left side of a {@code GreaterThan} to its right side.
     *
     * @param expr the expression to optimize
     * @return the optimized expression
     * @throws NullPointerException if the child-optimized expression is {@code null}
     */
    public static BoolExpr optimize(BoolExpr expr) {
        return switch (optimizeInner(expr)) {
            case null -> throw new NullPointerException();
            case BoolConditional(var location, BoolLiteral(var location1, var value), var ifTrue, var ifFalse) ->
                    value ? ifTrue : ifFalse;
            case BoolConditional(var location, Not(var location1, var condition), var ifTrue, var ifFalse) ->
                    new BoolConditional(location, condition, ifFalse, ifTrue);

            case And(var location, BoolLiteral(var location1, var left), var right) ->
                    left ? right : literal(location, false);
            // "left & false" may only collapse to "false" if dropping left does not drop side effects
            case And(var location, var left, BoolLiteral(var location1, var right)) when right || !hasSideEffects(left) ->
                    right ? left : literal(location, false);
            case Or(var location, BoolLiteral(var location1, var left), var right) ->
                    left ? literal(location, true) : right;
            // "left | true" may only collapse to "true" if dropping left does not drop side effects
            case Or(var location, var left, BoolLiteral(var location1, var right)) when !right || !hasSideEffects(left) ->
                    right ? literal(location, true) : left;

            case Not(var location, BoolLiteral(var location1, var value)) -> literal(location, !value);
            case Not(var location, Not(var location1, var inner)) -> inner;

            case GreaterThan(var location, NumberLiteral(var location1, var left), NumberLiteral(var location2, var right)) ->
                    literal(location, left > right);
            // a / c > r -> a > r * c
            // Only valid for a strictly positive divisor: a negative one would flip the comparison
            // and zero would not be equivalent at all, so this is limited to positive literals.
            case GreaterThan(var location, Divide(var location1, var dividend, NumberLiteral(var location2, var divisor)), var right) when divisor > 0 ->
                    optimize(new GreaterThan(location, dividend, new Multiply(location1, right, literal(location2, divisor))));
            // -a > r -> -r > a
            case GreaterThan(var location, Negate(var location1, var inner), var right) ->
                    optimize(new GreaterThan(location, new Negate(right.location(), right), inner));
            // a - b > r -> a > b + r
            case GreaterThan(var location, Subtract(var location1, var minuend, var subtrahend), var right) ->
                    optimize(new GreaterThan(location, minuend, new Add(location1, subtrahend, right)));
            // Modulo is left out because it is too complicated for this naive impl
            // Multiply is left out since it would transform into a division and may be 0
            // a + b > r -> a > r - b
            case GreaterThan(var location, Add(var location1, var augend, var addend), var right) ->
                    optimize(new GreaterThan(location, augend, new Subtract(location1, right, addend)));
            // Power is left out because it can't be transformed cleanly either
            case BoolExpr fallback -> fallback; // Used instead of default to still keep applied optimizations
        };
    }

    /** Rebuilds a boolean expression with optimized children; {@code null} is passed through. */
    private static BoolExpr optimizeInner(BoolExpr expr) {
        return switch (expr) {
            case null -> null;
            case BoolLiteral e -> e;
            case ExtensibleBoolExpr e -> e.optimize();
            case BoolUnpack(var inner) -> asBool(unpack(optimize(inner)));
            case BoolAssign(var location, var variable, var value) ->
                    new BoolAssign(location, variable, optimize(value));
            case BoolConditional(var location, var condition, var ifTrue, var ifFalse) ->
                    new BoolConditional(location, optimize(condition), optimize(ifTrue), optimize(ifFalse));
            case And(var location, var left, var right) -> new And(location, optimize(left), optimize(right));
            case Or(var location, var left, var right) -> new Or(location, optimize(left), optimize(right));
            case Not(var location, var inner) -> new Not(location, optimize(inner));
            case Equals(var location, var left, var right) ->
                    new Equals(location, optimize(left), optimize(right));
            case GreaterThan(var location, var left, var right) ->
                    new GreaterThan(location, optimize(left), optimize(right));
        };
    }

    /**
     * Optimizes a dynamic expression: resolves or simplifies conditionals, specializes
     * {@code GetOrAt} into {@code Get} or {@code At} where the access kind is known,
     * unwraps calls on {@code Bind} and {@code Closure} and turns object literals whose
     * values are all literals into a single {@link DynamicLiteral}.
     *
     * @param expr the expression to optimize
     * @return the optimized expression
     * @throws NullPointerException if the child-optimized expression is {@code null}
     * @throws IllegalArgumentException if a literal object entry holds a value that is not a {@link Dynamic}
     */
    public static DynamicExpr optimize(DynamicExpr expr) {
        return switch (optimizeInner(expr)) {
            case null -> throw new NullPointerException();
            case DynamicConditional(var location, BoolLiteral(var location1, var value), var ifTrue, var ifFalse) ->
                    value ? ifTrue : ifFalse;
            case DynamicConditional(var location, Not(var location1, var inner), var ifTrue, var ifFalse) ->
                    optimize(new DynamicConditional(location, inner, ifFalse, ifTrue));
            // Equal branches: the condition only matters for its side effects, which are kept as steps
            case DynamicConditional(var location, var condition, var ifTrue, var ifFalse) -> ifTrue.equals(ifFalse)
                    ? new ExprGroup(location, extractSideEffects(condition).map(Optimizer::optimize).toList(), ifTrue, null, false)
                    : new DynamicConditional(location, condition, ifTrue, ifFalse);

            case GetOrAt(var location, ObjectLiteral l, var nameOrIndex) -> new Get(location, l, asString(nameOrIndex));
            case GetOrAt(var location, ListLiteral l, var nameOrIndex) -> new At(location, l, asNumber(nameOrIndex));
            case GetOrAt(var location, var left, StringExpr name) -> new Get(location, left, name);
            case GetOrAt(var location, var left, NumberExpr index) -> new At(location, left, index);

            // Calling a bound callable: append the bound parameter to the argument list and call directly
            case Call(var location, Bind(var location1, var callable, var parameter), var arguments) ->
                    optimize(new Call(location, callable, concat(arguments, new Call.Argument(parameter, false))));
            // Calling a closure in place: replace the call with a forked group that receives the arguments
            case Call(var location, Closure(var location1, var boundArgs, var variadic, var steps, var finish), var arguments) ->
                    new ExprGroup(location1, concat(steps, finish), new ExprGroup.PackedArgs(arguments, boundArgs, variadic), true);

            case ObjectLiteral(var location, var originalContent) -> {
                var content = new LinkedHashMap<String, DynamicExpr>();
                var literalContent = new LinkedHashMap<String, Dynamic>();
                // Stays true only while every entry seen so far optimized to a literal
                boolean literal = true;
                for (Map.Entry<String, DynamicExpr> entry : originalContent.entrySet()) {
                    DynamicExpr de = optimize(entry.getValue());
                    if (de instanceof DynamicLiteral(var location1, var cnt) && literal) {
                        if (cnt instanceof Dynamic d) literalContent.put(entry.getKey(), d);
                        else throw new IllegalArgumentException("Unsupported implementation of Dynamic");
                    } else literal = false;
                    content.put(entry.getKey(), de);
                }
                if (literal) yield new DynamicLiteral(location, DFinal.of(literalContent));
                yield new ObjectLiteral(location, content);
            }
            case DynamicExpr fallback -> fallback; // Used instead of default to still keep applied optimizations
        };
    }

    /** Rebuilds a dynamic expression with optimized children; {@code null} is passed through. */
    private static DynamicExpr optimizeInner(DynamicExpr expr) {
        return switch (expr) {
            case null -> null;
            case DynamicLiteral e -> e;
            case ExtensibleDynamicExpr e -> e.optimize();
            case DynamicCoerce(var inner) -> asDynamic(unpack(optimize(inner)));
            case DynamicAssign(var location, var variable, var value) ->
                    new DynamicAssign(location, variable, optimize(value));
            case DynamicConditional(var location, var condition, var ifTrue, var ifFalse) ->
                    new DynamicConditional(location, optimize(condition), optimize(ifTrue), optimize(ifFalse));
            case This e -> e;
            case Variable e -> e;
            case Get(var location, var left, var name) -> new Get(location, optimize(left), optimize(name));
            case At(var location, var left, var index) -> new At(location, optimize(left), optimize(index));
            case GetOrAt(var location, var left, var nameOrIndex) ->
                    new GetOrAt(location, optimize(left), optimize(nameOrIndex));
            case Bind(var location, var callable, var parameter) ->
                    new Bind(location, optimize(callable), optimize(parameter));
            case Call(var location, var callable, var arguments) ->
                    new Call(location, optimize(callable), optimize(arguments));
            case Closure(var location, var boundArgs, var variadic, var steps, var finish) ->
                    new Closure(location, boundArgs, variadic, optimize(steps, finish), optimize(finish));
            case ExprGroup(var location, var steps, var finish, ExprGroup.PackedArgs(var pFrom, var pTo, var variadic), var fork) ->
                    new ExprGroup(location, optimize(steps, finish), new ExprGroup.PackedArgs(optimize(pFrom), pTo, variadic), fork);
            case ExprGroup(var location, var steps, var finish, var packedArgs, var fork) ->
                    asDynamic(ExprGroup.of(location, optimize(steps, finish), fork));
            case ListLiteral(var location, var elements) ->
                    new ListLiteral(location, elements.stream().map(Optimizer::optimize).toList());
            case ObjectLiteral e -> e; // Exclusively handled in optimize even though it contains expressions, since it cannot be decomposed
        };
    }

    /** Returns a new list consisting of the elements of {@code list} followed by {@code element}. */
    private static <T> List<T> concat(List<T> list, T element) {
        return Stream.concat(list.stream(), Stream.of(element)).toList();
    }

    /**
     * Optimizes the body of a group or closure: every step is optimized, reduced to its
     * side effects and those are optimized again; the optimized {@code finish} expression
     * is appended as the last element.
     */
    private static List<Expr> optimize(List<Expr> steps, DynamicExpr finish) {
        return Stream.concat(
                steps.stream()
                        .map(Optimizer::optimize)
                        .flatMap(Optimizer::extractSideEffects)
                        .map(Optimizer::optimize),
                Stream.of(optimize(finish))
        ).toList();
    }

    /** Optimizes the value of every call argument, keeping each argument's variadic flag. */
    private static List<Call.Argument> optimize(List<Call.Argument> arguments) {
        return arguments.stream()
                .map(arg -> new Call.Argument(optimize(arg.value()), arg.variadic()))
                .toList();
    }

    /**
     * Checks whether {@link #extractSideEffects(Expr)} yields anything for the expression.
     * This is conservative: a non-forking group always yields an element and therefore
     * always counts as having side effects.
     */
    private static boolean hasSideEffects(Expr expr) {
        return extractSideEffects(expr).findAny().isPresent();
    }

    /**
     * Reduces an expression to the parts that must still be evaluated if its value is discarded.
     *
     * <p>Literals, closures, variables and {@code this} yield nothing; assignments, calls
     * and forking groups are kept whole; operators yield the side effects of their operands;
     * conditionals are handled by {@code extractConditionalSideEffects}.
     *
     * @param expr the expression whose value is not needed
     * @return the expressions that still have to be evaluated
     */
    public static Stream<Expr> extractSideEffects(Expr expr) {
        return switch (expr) {
            case NullLiteral e -> Stream.empty();
            case BoolLiteral e -> Stream.empty();
            case NumberLiteral e -> Stream.empty();
            case StringLiteral e -> Stream.empty();
            case DynamicLiteral e -> Stream.empty();
            case Closure e -> Stream.empty();
            case Variable e -> Stream.empty();
            case This e -> Stream.empty();
            case NumberUnpack(var inner) -> extractSideEffects(inner);
            case StringUnpack(var inner) -> extractSideEffects(inner);
            case BoolUnpack(var inner) -> extractSideEffects(inner);
            case DynamicCoerce(var inner) -> extractSideEffects(inner);
            case StringCoerce(var inner) -> extractSideEffects(inner);
            case BoolAssign e -> Stream.of(e);
            case NumberAssign e -> Stream.of(e);
            case StringAssign e -> Stream.of(e);
            case DynamicAssign e -> Stream.of(e);
            case ExtensibleExpr e -> e.extractSideEffects();
            case And(var location, var left, var right) ->
                    Stream.concat(extractSideEffects(left), extractSideEffects(right));
            case Or(var location, var left, var right) ->
                    Stream.concat(extractSideEffects(left), extractSideEffects(right));
            case Not(var location, var inner) -> extractSideEffects(inner);
            case Equals(var location, var left, var right) ->
                    Stream.concat(extractSideEffects(left), extractSideEffects(right));
            case GreaterThan(var location, var left, var right) ->
                    Stream.concat(extractSideEffects(left), extractSideEffects(right));
            case BoolConditional(var location, var condition, var ifTrue, var ifFalse) ->
                    extractConditionalSideEffects(location, condition, ifTrue, ifFalse);
            case DynamicConditional(var location, var condition, var ifTrue, var ifFalse) ->
                    extractConditionalSideEffects(location, condition, ifTrue, ifFalse);
            case NumberConditional(var location, var condition, var ifTrue, var ifFalse) ->
                    extractConditionalSideEffects(location, condition, ifTrue, ifFalse);
            case StringConditional(var location, var condition, var ifTrue, var ifFalse) ->
                    extractConditionalSideEffects(location, condition, ifTrue, ifFalse);
            case Bind(var location, var callable, var parameter) ->
                    Stream.concat(extractSideEffects(callable), extractSideEffects(parameter));
            case Call e -> Stream.of(e);
            // NOTE(review): for non-forking groups only `steps` are inspected; `finish` and `packedArgs`
            // are ignored here - confirm against ExprGroup.of that this does not drop side effects.
            case ExprGroup(var location, var steps, var finish, var packedArgs, var fork) -> fork
                    ? Stream.of(expr)
                    : Stream.of(ExprGroup.of(location, steps.stream().flatMap(Optimizer::extractSideEffects).toList()));
            case Get(var location, var left, var name) ->
                    Stream.concat(extractSideEffects(left), extractSideEffects(name));
            case At(var location, var left, var index) ->
                    Stream.concat(extractSideEffects(left), extractSideEffects(index));
            case GetOrAt(var location, var left, var nameOrIndex) ->
                    Stream.concat(extractSideEffects(left), extractSideEffects(nameOrIndex));
            case ListLiteral(var location, var elements) ->
                    elements.stream().flatMap(Optimizer::extractSideEffects);
            case ObjectLiteral(var location, var content) ->
                    content.values().stream().flatMap(Optimizer::extractSideEffects);
            case Divide(var location, var dividend, var divisor) ->
                    Stream.concat(extractSideEffects(dividend), extractSideEffects(divisor));
            case Negate(var location, var inner) -> extractSideEffects(inner);
            case Add(var location, var augend, var addend) ->
                    Stream.concat(extractSideEffects(augend), extractSideEffects(addend));
            case Subtract(var location, var minuend, var subtrahend) ->
                    Stream.concat(extractSideEffects(minuend), extractSideEffects(subtrahend));
            case Multiply(var location, var multiplier, var multiplicand) ->
                    Stream.concat(extractSideEffects(multiplier), extractSideEffects(multiplicand));
            case Modulo(var location, var dividend, var divisor) ->
                    Stream.concat(extractSideEffects(dividend), extractSideEffects(divisor));
            case Power(var location, var base, var exponent) ->
                    Stream.concat(extractSideEffects(base), extractSideEffects(exponent));
            case Concatenate(var location, var left, var right) ->
                    Stream.concat(extractSideEffects(left), extractSideEffects(right));
        };
    }

    /**
     * Extracts the side effects of a conditional. If neither branch has side effects only
     * those of the condition remain; otherwise a conditional is returned that keeps the
     * full condition and contains only the side effects of each branch.
     */
    private static Stream<Expr> extractConditionalSideEffects(CodeLocation location, BoolExpr condition, Expr ifTrue, Expr ifFalse) {
        List<Expr> trueSE = extractSideEffects(ifTrue).toList();
        List<Expr> falseSE = extractSideEffects(ifFalse).toList();
        if (trueSE.isEmpty() && falseSE.isEmpty()) return extractSideEffects(condition);
        return Stream.of(new DynamicConditional(
                location,
                condition,
                asDynamic(ExprGroup.of(ifTrue.location(), trueSE, false)),
                asDynamic(ExprGroup.of(ifFalse.location(), falseSE, false))
        ));
    }
}