diff --git a/src/SymbolicUtils.jl b/src/SymbolicUtils.jl index 89486e8c8..c4656bcb3 100644 --- a/src/SymbolicUtils.jl +++ b/src/SymbolicUtils.jl @@ -136,6 +136,8 @@ include("rule.jl") include("matchers.jl") include("rewriters.jl") +include("rule2.jl") + # Convert to an efficient multi-variate polynomial representation import DynamicPolynomials export expand diff --git a/src/rule2.jl b/src/rule2.jl new file mode 100644 index 000000000..eb76c9015 --- /dev/null +++ b/src/rule2.jl @@ -0,0 +1,143 @@ +# empty Base.ImmutableDict of the correct type +const SymsType = BasicSymbolic{SymReal} +const MatchDict = ImmutableDict{Symbol, SymsType} +const NO_MATCHES = MatchDict() # or {Symbol, Union{Symbol, Real}} ? +const FAIL_DICT = MatchDict(:_fail,0) +const op_map = Dict(:+ => 0, :* => 1, :^ => 1) + +""" +data is a symbolic expression, we need to check if respects the rule +rule is a quoted expression, representing part of the rule +matches is the dictionary of the matches found so far + +return value is a ImmutableDict +1) if a mismatch is found, FAIL_DICT is returned. +2) if no mismatch is found but no new matches either (for example in mathcing ^2), the original matches is returned +3) otherwise the dictionary of old + new ones is returned that could look like: +Base.ImmutableDict{Symbol, SymbolicUtils.BasicSymbolicImpl.var"typeof(BasicSymbolicImpl)"{SymReal}}(:x => a, :y => b) + +TODO matches does assigment or mutation? which is faster? +""" +function check_expr_r(data::SymsType, rule::Expr, matches::MatchDict) + # println("Checking ",data," against ",rule,", with matches: ",[m for m in matches]...) + rule.head != :call && error("It happened, rule head is not a call") #it should never happen + # rule is a slot + if rule.head == :call && rule.args[1] == :(~) + if rule.args[2] in keys(matches) # if the slot has already been matched + # check if it mached the same symbolic expression + !isequal(matches[rule.args[2]],data) && return FAIL_DICT::MatchDict + return matches::MatchDict + else # if never been matched + # if there is a predicate rule.args[2] is a expression with :: + if isa(rule.args[2], Expr) + # check it + pred = rule.args[2].args[2] + !eval(pred)(SymbolicUtils.unwrap_const(data)) && return FAIL_DICT + return Base.ImmutableDict(matches, rule.args[2].args[1], data)::MatchDict + end + # if no predicate add match + return Base.ImmutableDict(matches, rule.args[2], data)::MatchDict + end + end + # if there is a deflsot in the arguments + p=findfirst(a->isa(a, Expr) && a.args[1] == :~ && isa(a.args[2], Expr) && a.args[2].args[1] == :!,rule.args[2:end]) + if p!==nothing + # build rule expr without defslot and check it + if p==1 + newr = Expr(:call, rule.args[1], :(~$(rule.args[2].args[2].args[2])), rule.args[3]) + elseif p==2 + newr = Expr(:call, rule.args[1], rule.args[2], :(~$(rule.args[3].args[2].args[2]))) + else + error("defslot error")# it should never happen + end + rv = check_expr_r(data, newr, matches) + rv!==FAIL_DICT && return rv::MatchDict + # if no normal match, check only the non-defslot part of the rule + rv = check_expr_r(data, rule.args[p==1 ? 3 : 2], matches) + # if yes match + rv!==FAIL_DICT && return Base.ImmutableDict(rv, rule.args[p+1].args[2].args[2], get(op_map, rule.args[1], -1))::MatchDict + return FAIL_DICT::MatchDict + else + # rule is a call, check operation and arguments + # - check operation + !iscall(data) && return FAIL_DICT::MatchDict + (Symbol(operation(data)) !== rule.args[1]) && return FAIL_DICT::MatchDict + # - check arguments + arg_data = arguments(data); arg_rule = rule.args[2:end]; + (length(arg_data) != length(arg_rule)) && return FAIL_DICT::MatchDict + if (rule.args[1]===:+) || (rule.args[1]===:*) + # commutative checks + for perm_arg_data in permutations(arg_data) # is the same if done on arg_rule right? + matches_this_perm = ceoaa(perm_arg_data, arg_rule, matches) + matches_this_perm!==FAIL_DICT && return matches_this_perm::MatchDict + # else try with next perm + end + # if all perm failed + return FAIL_DICT::MatchDict + else + # normal checks + return ceoaa(arg_data, arg_rule, matches)::MatchDict + end + end +end + +# check expression of all arguments +function ceoaa(arg_data, arg_rule, matches::MatchDict) + println(typeof(arg_data), typeof(arg_rule)) + for (a, b) in zip(arg_data, arg_rule) + matches = check_expr_r(a, b, matches) + matches===FAIL_DICT && return FAIL_DICT::MatchDict + # else the match has been added (or not added but confirmed) + end + return matches::MatchDict +end + +# for when the rule contains a constant, a literal number +function check_expr_r(data::SymsType, rule::Real, matches::MatchDict) + # println("Checking ",data," against the real ",rule,", with matches: ",[m for m in matches]...) + unw = unwrap_const(data) + if isa(unw, Real) + unw!==rule && return FAIL_DICT::MatchDict + return matches::MatchDict + end + # else always fail + return FAIL_DICT::MatchDict +end + +""" +matches is the dictionary +rhs is the expression to be rewritten into + +TODO investigate foo in rhs not working +""" +function rewrite(matches::MatchDict, rhs::Expr)::SymsType + if rhs.head != :call + error("It happened") #it should never happen + end + # rhs is a slot or defslot + if rhs.head == :call && rhs.args[1] == :(~) + var_name = rhs.args[2] + if haskey(matches, var_name) + return matches[var_name] + else + error("No match found for variable $(var_name)") #it should never happen + end + end + # rhs is a call, reconstruct it + op = eval(rhs.args[1]) + args = SymsType[] + for a in rhs.args[2:end] + push!(args, rewrite(matches, a)) + end + return op(args...) +end + +function rewrite(matches::MatchDict, rhs::Real)::SymsType + return rhs +end + +function rule2(rule::Pair{Expr, Expr}, exp::SymsType)::Union{SymsType, Nothing} + m = check_expr_r(exp, rule.first, NO_MATCHES) + m===FAIL_DICT && return nothing + return rewrite(m, rule.second) +end diff --git a/test/rule2.jl b/test/rule2.jl new file mode 100644 index 000000000..5ee006a19 --- /dev/null +++ b/test/rule2.jl @@ -0,0 +1,159 @@ + + +function int_and_subst(expr::SymsType, var::SymsType, old::SymsType, new::SymsType, tag::String)::SymsType + print("int_and_subst called with expr: "); show(expr); println() + print(" var: "); show(var); println() + print(" old: "); show(old); println() + print(" new: "); show(new); println() + print(" tag: "); show(tag); println() + return expr #dummy +end + + +function generate_random_expression_r(depth::Int, syms::Vector{SymsType}, operations::Vector{Symbol})::SymsType + if depth == 0 + return syms[rand(1:length(syms))] + end + op = operations[rand(1:length(operations))] + if op in [:+, :-, :*, :/, :^] + left = generate_random_expression_r(depth - 1, syms, operations) + right = generate_random_expression_r(depth - 1, syms, operations) + return eval(op)(left, right) + elseif op in [:log, :sin, :cos, :exp] + arg = generate_random_expression_r(depth - 1, syms, operations) + return eval(op)(arg) + elseif op == :∫ + var = syms[rand(1:length(syms))] + integrand = generate_random_expression_r(depth - 1, syms, operations) + return ∫(integrand, var) + else + error("Unknown operation: $op") + end +end + +function generate_random_expression() + operations = [:+, :-, :*, :/, :^, :log, :sin, :cos, :exp, :∫] + syms = [a, b, c, d, e, f, g, h] + return generate_random_expression_r(3, syms, operations) +end + +# function generate_random_rule_r2(depth::Int, syms::Vector{SymsType}, operations::Vector{Symbol})::Expr +# if depth == 0 +# choosen = syms[rand(1:length(syms))] +# return :(~($choosen)) +# end +# op = operations[rand(1:length(operations))] +# if op in [:+, :-, :*, :/, :^] +# left = generate_random_rule_r2(depth - 1, syms, operations) +# right = generate_random_rule_r2(depth - 1, syms, operations) +# return Expr(:call, op, left, right) +# elseif op in [:log, :sin, :cos, :exp] +# arg = generate_random_rule_r(depth - 1, syms, operations) +# return Expr(:call, op, arg) +# elseif op == :∫ +# var = syms[rand(1:length(syms))] +# integrand = generate_random_rule_r2(depth - 1, syms, operations) +# return Expr(:call, :∫, integrand, Expr(:call, :~, var)) +# else +# error("Unknown operation: $op") +# end +# end +# +# function generate_random_rule2() +# operations = [:+, :-, :*, :/, :^, :log, :sin, :cos, :exp, :∫] +# syms = [a, b, c, d, e, f, g, h] +# lhs = generate_random_rule_r2(3, syms, operations) +# rhs = generate_random_rule_r2(3, syms, operations) +# return (lhs, rhs) +# end +# +# function generate_random_rule1() +# operations = [:+, :-, :*, :/, :^, :log, :sin, :cos, :exp, :∫] +# syms = [a, b, c, d, e, f, g, h] +# lhs = generate_random_rule_r2(3, syms, operations) +# rhs = generate_random_rule_r2(3, syms, operations) +# r = @rule lhs => rhs +# println("Generated random rule: "); show(r); println(typeof(r)) +# return r +# end + +function testrule1(n::Int, verbose::Bool=false) + @syms x ∫(var1, var2) a b c d e f g h + rules = SymbolicUtils.Rule[] + for i in 1:n + r = @rule ∫(((~f) + (~!g)*(~x))^(~!q)*((~!a) + (~!b)*log((~!c)*((~d) + (~!e)*(~x))^(~!n)))^(~!p),(~x)) => + 1⨸(~e)*int_and_subst(((~f)*(~x)⨸(~d))^(~q)*((~a) + (~b)*log((~c)*(~x)^(~n)))^(~p), (~x), (~x), (~d) + (~e)*(~x), "3_3_2") + push!(rules, r) + end + + # set random seed for reproducibility + Random.seed!(1234) + for i in 1:n + rex = generate_random_expression() + verbose && print("$i) checking against expression: ", rex) + result = rules[i](rex) + if result === nothing + verbose && println(" NO MATCH") + else + verbose && println(" YES MATCH: ", result) + end + end +end + + +function testrule2(n::Int, verbose::Bool=false) + @syms x ∫(var1, var2) a b c d e f g h + rules = Pair{Expr, Expr}[] + for i in 1:n + r = (:(∫(((~f) + (~!g)*(~x))^(~!q)*((~!a) + (~!b)*log((~!c)*((~d) + (~!e)*(~x))^(~!n)))^(~!p),(~x))) => + :(1⨸(~e)*int_and_subst(((~f)*(~x)⨸(~d))^(~q)*((~a) + (~b)*log((~c)*(~x)^(~n)))^(~p), (~x), (~x), (~d) + (~e)*(~x), "3_3_2"))) + push!(rules, r) + end + + # set random seed for reproducibility + Random.seed!(1234) + for i in 1:n + rex = generate_random_expression() + verbose && print("$i) checking against expression: ", rex) + result = SymbolicUtils.rule2(rules[i], rex) + if result === nothing + verbose && println(" NO MATCH") + else + verbose && println(" YES MATCH: ", result) + end + end +end + + + +""" Results on macbook air m1: +julia> @benchmark testrule2(\$1000) +BenchmarkTools.Trial: 244 samples with 1 evaluation per sample. + Range (min … max): 18.481 ms … 29.089 ms ┊ GC (min … max): 0.00% … 30.02% + Time (median): 19.456 ms ┊ GC (median): 0.00% + Time (mean ± σ): 20.564 ms ± 2.652 ms ┊ GC (mean ± σ): 6.09% ± 10.54% + + ▄▆█▇▄▁ + ▇█▇██████▁▄▆▆▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▆▆▄▁▆▄▁▄▁▁▄▄▁▄▆▆▄▄▁▇▆▇▆▄▇▆▁▆▄ ▆ + 18.5 ms Histogram: log(frequency) by time 28.2 ms < + + Memory estimate: 13.07 MiB, allocs estimate: 356839. + +julia> @benchmark testrule1(\$1000) +BenchmarkTools.Trial: 11 samples with 1 evaluation per sample. + Range (min … max): 446.396 ms … 472.119 ms ┊ GC (min … max): 0.00% … 5.67% + Time (median): 461.125 ms ┊ GC (median): 3.27% + Time (mean ± σ): 460.506 ms ± 7.303 ms ┊ GC (mean ± σ): 3.12% ± 1.73% + + ▁ ▁ █ █ ▁ ▁ ▁ ▁ ▁ + █▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁█▁█▁▁█▁▁▁█▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁█ ▁ + 446 ms Histogram: frequency by time 472 ms < + + Memory estimate: 110.40 MiB, allocs estimate: 3393493. +""" + +function testpredicates() + @syms ∫ a + SymbolicUtils.rule2(:(∫(1 / (~x)^(~m::iseven), ~x)) => :(log(~x)*~m), ∫(1/a^3,a))===nothing + SymbolicUtils.rule2(:(∫(1 / (~x)^(~m::iseven), ~x)) => :(log(~x)*~m), ∫(1/a^2,a))!==nothing +end \ No newline at end of file