diff --git a/src/SciCompDSL.jl b/src/SciCompDSL.jl index 6d5d886224..fe0c581072 100644 --- a/src/SciCompDSL.jl +++ b/src/SciCompDSL.jl @@ -15,6 +15,8 @@ Base.promote_rule(::Type{T},::Type{T2}) where {T<:Number,T2<:Expression} = Expre Base.one(::Type{T}) where T<:Expression = Constant(1) Base.zero(::Type{T}) where T<:Expression = Constant(0) +function caclulate_jacobian end + include("operations.jl") include("operators.jl") include("systems/diffeqs/diffeqsystem.jl") diff --git a/src/systems/nonlinear/nonlinear_system.jl b/src/systems/nonlinear/nonlinear_system.jl index 25f8b0de14..94f2fb710b 100644 --- a/src/systems/nonlinear/nonlinear_system.jl +++ b/src/systems/nonlinear/nonlinear_system.jl @@ -37,14 +37,10 @@ function generate_nlsys_function(sys::NonlinearSystem) :((du,u,p)->$(block)) end -function generate_nlsys_jacobian(sys::NonlinearSystem,simplify=true) - var_exprs = [:($(sys.vs[i].name) = u[$i]) for i in 1:length(sys.vs)] - param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in 1:length(sys.ps)] - +function calculate_jacobian(sys::NonlinearSystem,simplify=true) sys_idxs = map(eq->isequal(eq.args[1],Constant(0)),sys.eqs) sys_eqs = sys.eqs[sys_idxs] calc_eqs = sys.eqs[.!(sys_idxs)] - sys_exprs = [:($(Symbol("resid[$i]")) = $(sys_eqs[i].args[2])) for i in eachindex(sys_eqs)] rhs = [eq.args[2] for eq in sys_eqs] for i in 1:length(calc_eqs) @@ -59,5 +55,15 @@ function generate_nlsys_jacobian(sys::NonlinearSystem,simplify=true) sys_exprs end +function generate_nlsys_jacobian(sys::NonlinearSystem,simplify=true) + var_exprs = [:($(sys.vs[i].name) = u[$i]) for i in 1:length(sys.vs)] + param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in 1:length(sys.ps)] + jac = calculate_jacobian(sys,simplify) + jac_exprs = [:(J[$i,$j] = $(Expr(jac[i,j]))) for i in 1:size(jac,1), j in 1:size(jac,2)] + exprs = vcat(var_exprs,param_exprs,vec(jac_exprs)) + block = expr_arr_to_block(exprs) + :((J,u,p,t)->$(block)) +end + export NonlinearSystem export generate_nlsys_function diff --git a/test/derivatives.jl b/test/derivatives.jl index 23734b9866..dfb96dc02d 100644 --- a/test/derivatives.jl +++ b/test/derivatives.jl @@ -30,7 +30,7 @@ eqs = [0 ~ σ*(y-x), 0 ~ x*(ρ-z)-y, 0 ~ x*y - β*z] sys = NonlinearSystem(eqs,[x,y,z],[σ,ρ,β]) -jac = SciCompDSL.generate_nlsys_jacobian(sys) +jac = SciCompDSL.calculate_jacobian(sys) @test jac[1,1] == σ*-1 @test jac[1,2] == σ @test jac[1,3] == 0 diff --git a/test/system_construction.jl b/test/system_construction.jl index 6a489e9b13..2a91ed166a 100644 --- a/test/system_construction.jl +++ b/test/system_construction.jl @@ -104,4 +104,5 @@ eqs = [a ~ y-x, 0 ~ x*y - β*z] ns = NonlinearSystem(eqs,[x,y,z],[σ,ρ,β]) nlsys_func = SciCompDSL.generate_nlsys_function(ns) +jac = SciCompDSL.calculate_jacobian(ns) jac = SciCompDSL.generate_nlsys_jacobian(ns)