hsim/src/lib/std/solve.ml

228 lines
8.7 KiB
OCaml

open Hsim
open Types
(** A value: a function of time defined on [[0, h]], where [h] is its horizon.
    Re-exported under the same name from the modules opened above. *)
type nonrec 'a value = 'a value

(** A signal: [Some v] carries a value [v]; [None] means no value. *)
type nonrec 'a signal = 'a signal

(** A dated signal: [Some (v, t)] pairs a value [v] with a time [t]. *)
type nonrec 'a signal_t = 'a signal_t

(** Simulation time. *)
type time = float

(** Continuous-time solver selected when building a simulation:
    [RK45] uses [Solvers.StatefulRK45], [Sundials] uses
    [Solvers.StatefulSundials]. *)
type solver = RK45 | Sundials
(** [horizon v] is the horizon [h] of [v]; the value [v] is defined on
    [[0, h]]. *)
let horizon v = v.h
(** [make (h, u)] builds the value with horizon [h] and function [u]; its [c]
    field is set to [Discontinuous]. *)
let make (len, f) = { c = Discontinuous; h = len; u = f }
(** [apply (v, t)] evaluates the value [v] at time [t].
    @raise Invalid_argument if [t] is greater than the horizon of [v].
    NOTE(review): only the upper bound is checked; a negative [t] is passed
    to the underlying function as is. *)
let apply ({ h; u; _ }, t) =
  if t > h then
    invalid_arg
      (Format.sprintf
         "Requested time t=%.10e is greater than the horizon h=%.10e" t h)
  else u t
(** [build_sim_lifted solver model] builds the simulation node shared by
    [build_sim] and [build_sim_2024] from an already-lifted [model].

    The continuous solver is chosen by [solver] ([Solvers.StatefulRK45] for
    [RK45], [Solvers.StatefulSundials] for [Sundials]). It is combined with
    [Solvers.StatefulZ] (presumably the zero-crossing solver) by
    [Hsim.Solver.solver]. The model is then run with the in-place simulation
    state and composed with [track].

    The returned node's [reset] takes only the solver parameter [p]. It calls
    the composed node's reset with [(p, ())]. *)
let build_sim_lifted (solver : solver) model
  : (unit *
     ((Ztypes.cvec, Ztypes.dvec) Solver.ivp *
      (Ztypes.cvec, Ztypes.zoutvec) Solver.zc), _, _) dnode
  =
  let solver =
    Hsim.Solver.solver
      (match solver with
       | RK45 -> d_of_dc @@ Solvers.StatefulRK45.InPlace.csolve ()
       | Sundials -> Solvers.StatefulSundials.InPlace.csolve ())
      (d_of_dc @@ Solvers.StatefulZ.InPlace.zsolve ())
  in
  let open Hsim.Sim.Sim(Hsim.State.InPlaceSimState) in
  let DNode s = Hsim.Utils.(compose (run model solver) track) in
  DNode { s with reset = (fun p -> s.reset (p, ())) }

(** [build_sim solver model] lifts [model] with [Lift.lift] and builds its
    simulation node with [build_sim_lifted]. *)
let build_sim
  (solver : solver)
  (model : Ztypes.cstate -> (time * 'a, 'b) Ztypes.node)
  : (unit *
     ((Ztypes.cvec, Ztypes.dvec) Solver.ivp *
      (Ztypes.cvec, Ztypes.zoutvec) Solver.zc), 'a signal, 'b signal_t) dnode
  =
  (* The lift is evaluated before the solvers are created, as before. *)
  build_sim_lifted solver (Lift.lift model)

(** Same as [build_sim] for models over [Ztypes.cstate_new], lifted with
    [Lift.lift_2024]. *)
let build_sim_2024
  (solver : solver)
  (model : Ztypes.cstate_new -> (time * 'a, 'b) Ztypes.node)
  : (unit *
     ((Ztypes.cvec, Ztypes.dvec) Solver.ivp *
      (Ztypes.cvec, Ztypes.zoutvec) Solver.zc), 'a signal, 'b signal_t) dnode
  =
  build_sim_lifted solver (Lift.lift_2024 model)
(** [solve solver model] turns the hybrid node [model] into a discrete node
    that maps input signals to dated output signals, using the simulation
    built by [build_sim].

    Each allocated instance is a reference initialised with the simulation's
    initial state. [step] advances the simulation on one input and stores the
    new state.

    NOTE(review): [reset] does nothing, so resetting the node does not restart
    the simulation. Every [alloc] also starts from the same [sim.state]; if
    that state is mutated in place, instances are not independent. Confirm. *)
let solve
  (solver : solver)
  (model : Ztypes.cstate -> (time * 'a, 'b) Ztypes.node)
  : ('a signal, 'b signal_t) Ztypes.node
  =
  let DNode sim = build_sim solver model in
  Ztypes.Node {
    alloc = (fun () -> ref sim.state);
    step = (fun cell input ->
      let output, next = sim.step !cell input in
      cell := next;
      output);
    reset = (fun _ -> ());
  }
(** Same as [solve] for models over [Ztypes.cstate_new], using the simulation
    built by [build_sim_2024].
    NOTE(review): as in [solve], [reset] does nothing. *)
let solve_2024
  (solver : solver)
  (model : Ztypes.cstate_new -> (time * 'a, 'b) Ztypes.node)
  : ('a signal, 'b signal_t) Ztypes.node
  =
  let DNode sim = build_sim_2024 solver model in
  Ztypes.Node {
    alloc = (fun () -> ref sim.state);
    step = (fun cell input ->
      let output, next = sim.step !cell input in
      cell := next;
      output);
    reset = (fun _ -> ());
  }
(** [solve] with the [RK45] solver. *)
let solve_ode45 = fun model -> solve RK45 model

(** [solve_2024] with the [RK45] solver. *)
let solve_ode45_2024 = fun model -> solve_2024 RK45 model

(** [solve] with the [Sundials] solver. *)
let solve_sundials = fun model -> solve Sundials model

(** [solve_2024] with the [Sundials] solver. *)
let solve_sundials_2024 = fun model -> solve_2024 Sundials model
(** Helper for [synchr]: step the simulation [m], which lags behind [n], once
    and pair its output with the value stored for [n].

    Arguments, in order:
    - [m_step]: the step function of [m];
    - [input]: the input signal;
    - [m_stop], [n_stop]: the dates last reached by [m] and [n];
    - [m_state]: the state of [m];
    - [n_value]: the value stored for [n], not yet emitted.

    Result, in order:
    - the paired output, dated with the start time of [m]'s new value, up to
      the date both simulations have now reached ([None] if [m] produced
      nothing);
    - the date now reached by [m] ([m_stop] if [m] produced nothing);
    - the part of [m]'s new value beyond [n_stop], if any;
    - the part of [n_value] still to be emitted, if any. *)
let synchr_neq
  (m_step : 'ms -> 'a signal -> 'b signal_t)
  (input : 'a signal)
  (m_stop : time) (n_stop : time) (m_state : 'ms) (n_value : 'c value)
  : ('b * 'c) signal_t * time * 'b signal * 'c signal
  =
  match m_step m_state input with
  | None ->
    (* [m] produced nothing: its date is unchanged and [n_value] stays
       pending. *)
    None, m_stop, None, Some n_value
  | Some (m_value, m_start) ->
    let m_reached = m_start +. m_value.h in
    let m_head, n_head, m_rest, n_rest =
      if m_reached < n_stop then
        (* [m] is still behind [n]: split [n_value] at [m]'s new date. *)
        (let n_head, n_tail = Utils.cutoff n_value m_value.h in
         (m_value, n_head, None, Some n_tail))
      else if n_stop < m_reached then
        (* [m] overtakes [n]: split [m_value] at [n_stop]. *)
        (let m_head, m_tail = Utils.cutoff m_value (n_stop -. m_start) in
         (m_head, n_value, Some m_tail, None))
      else
        (* [m] reaches [n] exactly: nothing left over on either side. *)
        (m_value, n_value, None, None)
    in
    let paired = Utils.join m_head n_head in
    Some (paired, m_start), m_reached, m_rest, n_rest
(** Helper for [synchr]: step both simulations once from the same reached
    date and pair their outputs.

    Arguments, in order: the step functions of [m] and [n]; the input; the
    states of [m] and [n]; the dates last reached by [m] and [n].

    Result, in order:
    - the paired output up to the earlier of the two new dates, dated with the
      start time of the step (the same convention as [synchr_neq]), or [None]
      if neither simulation produced anything;
    - the new dates reached by [m] and [n];
    - the parts of [m]'s and [n]'s new values beyond that common date, if any.

    @raise Assert_failure if exactly one of the two simulations produces an
    output. *)
let synchr_eq
  (m_step : 'ms -> 'a signal -> 'b signal_t)
  (n_step : 'ns -> 'a signal -> 'c signal_t)
  (input : 'a signal) (m_state : 'ms) (n_state : 'ns)
  (m_stop : time) (n_stop : time)
  : ('b * 'c) signal_t * time * time * 'b signal * 'c signal
  =
  match m_step m_state input, n_step n_state input with
  | Some (m_value, m_start), Some (n_value, n_start) ->
    let m_stop, n_stop = m_start +. m_value.h, n_start +. n_value.h in
    let m_value, n_value, m_rest, n_rest =
      if m_stop < n_stop then begin
        let n_value, n_rest = Utils.cutoff n_value m_value.h in
        m_value, n_value, None, Some n_rest
      end else if m_stop > n_stop then begin
        let m_value, m_rest = Utils.cutoff m_value n_value.h in
        m_value, n_value, Some m_rest, None
      end else m_value, n_value, None, None
    in
    let mn_value = Utils.join m_value n_value in
    (* The time component of a dated signal is the start date of its value:
       it is what [synchr_neq] returns and what [iter_t] adds to sample
       offsets. It used to be [min m_stop n_stop], the date reached, which
       shifted every output produced here by one horizon. Both steps start
       from the same date, so [m_start] is expected to equal [n_start]. *)
    Some (mn_value, m_start), m_stop, n_stop, m_rest, n_rest
  | None, None -> None, m_stop, n_stop, None, None
  | Some _, None | None, Some _ ->
    (* Both simulations start from the same date with the same input, so
       they are expected to produce outputs together. *)
    assert false
(** Synchronize two simulations as much as possible.

    For each of [m] and [n], the node's state stores three things: the last
    reached date, the part of its last output not yet emitted ([None] if
    nothing is pending), and its own state.

    On each step, the behaviour depends on the two dates:
    - if one simulation is behind, it alone is stepped with [synchr_neq];
    - if both are at the same date, both are stepped with [synchr_eq].

    The output pairs the values of [m] and [n] up to the date both have
    reached. *)
let synchr
  (m : ('a signal, 'b signal_t) Ztypes.node)
  (n : ('a signal, 'c signal_t) Ztypes.node)
  : ('a signal, ('b * 'c) signal_t) Ztypes.node
  =
  let Ztypes.Node { alloc = m_alloc; step = m_step; reset = m_reset } = m in
  let Ztypes.Node { alloc = n_alloc; step = n_step; reset = n_reset } = n in
  let alloc () =
    ref ((0.0, None, m_alloc ()), (0.0, None, n_alloc ()))
  in
  let step state input =
    let (m_stop, m_pending, m_state), (n_stop, n_pending, n_state) = !state in
    if m_stop < n_stop then begin
      (* [m] is behind: step it alone against [n]'s pending value. *)
      let n_pending = Option.get n_pending in
      let output, m_stop, m_rest, n_rest =
        synchr_neq m_step input m_stop n_stop m_state n_pending in
      state := (m_stop, m_rest, m_state), (n_stop, n_rest, n_state);
      output
    end else if m_stop > n_stop then begin
      (* [n] is behind: step it alone, then put the pair back in (m, n)
         order. *)
      let m_pending = Option.get m_pending in
      let swapped, n_stop, n_rest, m_rest =
        synchr_neq n_step input n_stop m_stop n_state m_pending in
      let output = Option.map (fun (u, t) -> Utils.swap u, t) swapped in
      state := (m_stop, m_rest, m_state), (n_stop, n_rest, n_state);
      output
    end else begin
      (* Both are at the same date: step them together. *)
      let output, m_stop, n_stop, m_rest, n_rest =
        synchr_eq m_step n_step input m_state n_state m_stop n_stop in
      state := (m_stop, m_rest, m_state), (n_stop, n_rest, n_state);
      output
    end
  in
  let reset s =
    let (_, _, ms), (_, _, ns) = !s in
    n_reset ns;
    m_reset ms;
    s := (0.0, None, ms), (0.0, None, ns)
  in
  Ztypes.Node { alloc; step; reset }
(** [iter n f] lifts the node [f] to dated signals. On [Some (v, _)], the
    value [v] is sampled with [Utils.sample v n] and each sampled value is fed
    to [f] in order, without its date. [None] inputs are ignored. *)
let iter n f =
  let Ztypes.Node { alloc; step = step_sample; reset } = f in
  let step state = function
    | None -> ()
    | Some (value, _) ->
      Utils.sample value n
      |> List.iter (fun (_, sample) -> step_sample state sample)
  in
  Ztypes.Node { alloc; step; reset }
(** [iter_t n f] lifts the node [f] to dated signals. On [Some (v, t0)], the
    value [v] is sampled with [Utils.sample v n]. Each sample [(t, x)] is fed
    to [f] as [(t +. t0, x)], in order. [None] inputs are ignored. *)
let iter_t n f =
  let Ztypes.Node { alloc; step = step_dated; reset } = f in
  let step state = function
    | None -> ()
    | Some (value, t0) ->
      Utils.sample value n
      |> List.iter (fun (t, sample) -> step_dated state (t +. t0, sample))
  in
  Ztypes.Node { alloc; step; reset }
(** [check n f] samples each input value with [iter_t n] and checks the
    predicate node [f] on the sampled values.

    If [f] returns [false] on a sample, or its step raises [Assert_failure],
    the message "Assertion failed at time ..." is printed on stderr with the
    sample's date and the program exits with status 1.

    The predicate is tested explicitly rather than through [assert]. Under
    [-noassert], [assert (step s v)] is not evaluated at all, so the monitored
    node would neither be stepped nor checked. *)
let check
  (n : int)
  (Ztypes.Node { alloc; step; reset } : ('a, bool) Ztypes.node)
  : ('a signal_t, unit) Ztypes.node
  =
  let step s (now, v) =
    (* An [Assert_failure] escaping the monitored node still counts as a
       failed check, as before. *)
    let ok = try step s v with Assert_failure _ -> false in
    if not ok then begin
      Format.eprintf "Assertion failed at time %.10e\n" now;
      exit 1
    end
  in
  iter_t n (Ztypes.Node { alloc; reset; step })
(** [check_t n f] samples each input value with [iter_t n] and checks the
    predicate node [f] on the dated samples [(now, v)].

    If [f] returns [false] on a sample, or its step raises [Assert_failure],
    the message "Assertion failed at time ..." is printed on stderr with the
    sample's date and the program exits with status 1.

    The predicate is tested explicitly rather than through [assert]. Under
    [-noassert], [assert (step s (now, v))] is not evaluated at all, so the
    monitored node would neither be stepped nor checked. *)
let check_t
  (n : int)
  (Ztypes.Node { alloc; step; reset } : (time * 'a, bool) Ztypes.node)
  : ('a signal_t, unit) Ztypes.node
  =
  let step s (now, v) =
    (* An [Assert_failure] escaping the monitored node still counts as a
       failed check, as before. *)
    let ok = try step s (now, v) with Assert_failure _ -> false in
    if not ok then begin
      Format.eprintf "Assertion failed at time %.10e\n" now;
      exit 1
    end
  in
  iter_t n (Ztypes.Node { alloc; reset; step })