228 lines
8.7 KiB
OCaml
228 lines
8.7 KiB
OCaml
|
|
open Hsim
|
|
open Types
|
|
|
|
(** Values, re-exported from the opened modules. [nonrec] makes the
    right-hand side refer to the opened definition instead of to this alias. *)
type nonrec 'a value = 'a value

(** Signals, re-exported from the opened modules. *)
type nonrec 'a signal = 'a signal

(** Time-stamped signals, re-exported from the opened modules. *)
type nonrec 'a signal_t = 'a signal_t

(** Simulation time. *)
type time = float

(** Numerical solver used for the continuous part of a simulation. *)
type solver =
  | RK45  (** Selects [Solvers.StatefulRK45] in [build_sim]. *)
  | Sundials  (** Selects [Solvers.StatefulSundials] in [build_sim]. *)
|
|
|
|
(** Return the horizon [h] of a value; a value is defined on [[0,h]]. *)
let horizon v = v.h
|
|
|
|
(** Build a value from a pair [(horizon, function)]. The resulting value is
    tagged [Discontinuous]. *)
let make (hz, fn) = { u = fn; h = hz; c = Discontinuous }
|
|
|
|
(** Evaluate a value at time [t].

    @raise Invalid_argument if [t] is greater than the value's horizon [h].
    Times below [0] are not checked here. *)
let apply ({ u; h; _ }, t) =
  if t > h then
    invalid_arg
      (Format.sprintf
         "Requested time t=%.10e is greater than the horizon h=%.10e" t h)
  else u t
|
|
|
|
(** Build the in-place simulation of a hybrid [model] written against the
    [Ztypes.cstate] interface.

    The model is lifted with [Lift.lift]. It is then run with the continuous
    solver selected by [solver], combined with the
    [Solvers.StatefulZ.InPlace.zsolve] solver (presumably zero-crossing
    detection). Finally it is composed with [track]. The returned node's
    reset takes the parameter shown in the signature; the second component
    of the composed reset parameter is fixed to [()]. *)
let build_sim
    (solver : solver)
    (model : Ztypes.cstate -> (time * 'a, 'b) Ztypes.node)
  : (unit *
     ((Ztypes.cvec, Ztypes.dvec) Solver.ivp *
      (Ztypes.cvec, Ztypes.zoutvec) Solver.zc), 'a signal, 'b signal_t) dnode
  =
  let model = Lift.lift model in
  (* Built in the same order as the arguments of the original single
     application (right-to-left): zero-crossing part first. *)
  let zsolver = d_of_dc (Solvers.StatefulZ.InPlace.zsolve ()) in
  let csolver =
    match solver with
    | RK45 -> d_of_dc (Solvers.StatefulRK45.InPlace.csolve ())
    | Sundials -> Solvers.StatefulSundials.InPlace.csolve ()
  in
  let solver = Hsim.Solver.solver csolver zsolver in
  let open Hsim.Sim.Sim(Hsim.State.InPlaceSimState) in
  let DNode sim = Hsim.Utils.(compose (run model solver) track) in
  DNode { sim with reset = (fun params -> sim.reset (params, ())) }
|
|
|
|
(** Same as [build_sim], for models written against the
    [Ztypes.cstate_new] interface.

    The model is lifted with [Lift.lift_2024]. It is then run with the
    continuous solver selected by [solver], combined with the
    [Solvers.StatefulZ.InPlace.zsolve] solver (presumably zero-crossing
    detection). Finally it is composed with [track]. The second component of
    the composed reset parameter is fixed to [()]. *)
let build_sim_2024
    (solver : solver)
    (model : Ztypes.cstate_new -> (time * 'a, 'b) Ztypes.node)
  : (unit *
     ((Ztypes.cvec, Ztypes.dvec) Solver.ivp *
      (Ztypes.cvec, Ztypes.zoutvec) Solver.zc), 'a signal, 'b signal_t) dnode
  =
  let model = Lift.lift_2024 model in
  (* Built in the same order as the arguments of the original single
     application (right-to-left): zero-crossing part first. *)
  let zsolver = d_of_dc (Solvers.StatefulZ.InPlace.zsolve ()) in
  let csolver =
    match solver with
    | RK45 -> d_of_dc (Solvers.StatefulRK45.InPlace.csolve ())
    | Sundials -> Solvers.StatefulSundials.InPlace.csolve ()
  in
  let solver = Hsim.Solver.solver csolver zsolver in
  let open Hsim.Sim.Sim(Hsim.State.InPlaceSimState) in
  let DNode sim = Hsim.Utils.(compose (run model solver) track) in
  DNode { sim with reset = (fun params -> sim.reset (params, ())) }
|
|
|
|
(** Lift a hybrid node into a discrete node on streams of functions.

    Wraps the simulation built by [build_sim] into a Zelus node whose state
    is a reference cell holding the current simulation state.

    NOTE(review): [reset] does nothing, so resetting this node does not
    restart the simulation. Also, every [alloc] starts from the same
    [sim.state]; if that state is mutated in place, separate instances may
    share it. Both points need confirming. *)
let solve
    (solver : solver)
    (model : Ztypes.cstate -> (time * 'a, 'b) Ztypes.node)
  : ('a signal, 'b signal_t) Ztypes.node
  =
  let DNode sim = build_sim solver model in
  Ztypes.Node
    { alloc = (fun () -> ref sim.state);
      step =
        (fun cell input ->
          let output, next = sim.step !cell input in
          cell := next;
          output);
      reset = (fun _ -> ()) }
|
|
|
|
(** Same as [solve], for models written against the [Ztypes.cstate_new]
    interface (the simulation comes from [build_sim_2024]).

    NOTE(review): [reset] does nothing, so resetting this node does not
    restart the simulation. Also, every [alloc] starts from the same
    [sim.state]; if that state is mutated in place, separate instances may
    share it. Both points need confirming. *)
let solve_2024
    (solver : solver)
    (model : Ztypes.cstate_new -> (time * 'a, 'b) Ztypes.node)
  : ('a signal, 'b signal_t) Ztypes.node
  =
  let DNode sim = build_sim_2024 solver model in
  Ztypes.Node
    { alloc = (fun () -> ref sim.state);
      step =
        (fun cell input ->
          let output, next = sim.step !cell input in
          cell := next;
          output);
      reset = (fun _ -> ()) }
|
|
|
|
(* Shorthands fixing the solver: [RK45] for the [ode45] variants, [Sundials]
   for the [sundials] variants. Each is an explicit [fun] (not a partial
   application) so that it stays polymorphic under the value restriction. *)
let solve_ode45 = fun model -> solve RK45 model
let solve_ode45_2024 = fun model -> solve_2024 RK45 model
let solve_sundials = fun model -> solve Sundials model
let solve_sundials_2024 = fun model -> solve_2024 Sundials model
|
|
|
|
(** Utility function for [synchr].

    Used while one simulation ([m]) lags behind the other ([n]). It steps
    [m] and joins the result with the value stored for [n].

    Takes as arguments:
    - the step function of [m];
    - the input;
    - the last stop times of [m] and [n];
    - the state of [m];
    - the value stored for [n].

    Returns:
    - the common output value up to the common reached date, paired with the
      time [m]'s step returned alongside its value;
    - the new reached date of [m];
    - the value now stored for [m] (set when [m] overtakes [n]);
    - the value now stored for [n] (set when [m] is still behind). *)
let synchr_neq
    (m_step : 'ms -> 'a signal -> 'b signal_t)
    (input : 'a signal)
    (m_stop : time) (n_stop : time) (m_state : 'ms) (n_value : 'c value)
  : ('b * 'c) signal_t * time * 'b signal * 'c signal
  =
  match m_step m_state input with
  | None ->
    (* [m] produced nothing: emit nothing and keep [n]'s value stored. *)
    None, m_stop, None, Some n_value
  | Some (m_value, m_start) ->
    let reached = m_start +. m_value.h in
    let joined, m_rest, n_rest =
      if reached < n_stop then begin
        (* [m] is still behind [n]: split [n_value] where [m]'s value ends
           and keep the tail of [n] for later. *)
        let n_head, n_tail = Utils.cutoff n_value m_value.h in
        Utils.join m_value n_head, None, Some n_tail
      end else if reached > n_stop then begin
        (* [m] overtakes [n]: split [m_value] at [n_stop] and keep the tail
           of [m] for later. *)
        let m_head, m_tail = Utils.cutoff m_value (n_stop -. m_start) in
        Utils.join m_head n_value, Some m_tail, None
      end else
        (* [m] reaches [n] exactly: nothing is left over. *)
        Utils.join m_value n_value, None, None
    in
    Some (joined, m_start), reached, m_rest, n_rest
|
|
|
|
(** Utility function for [synchr].

    During synchronization, step both simulations at the same time.

    Takes as arguments:
    - the step functions for both simulations;
    - the input;
    - the states of both simulations;
    - the last stop times of both simulations.

    Returns:
    - the common output value up to the common reached date, paired with the
      start time of [m]'s new segment (the same convention as
      [synchr_neq]);
    - the new stop times for both simulations;
    - the new stored values for both simulations. *)
let synchr_eq
    (m_step : 'ms -> 'a signal -> 'b signal_t)
    (n_step : 'ns -> 'a signal -> 'c signal_t)
    (input : 'a signal) (m_state : 'ms) (n_state : 'ns)
    (m_stop : time) (n_stop : time)
  : ('b * 'c) signal_t * time * time * 'b signal * 'c signal
  =
  match m_step m_state input, n_step n_state input with
  | Some (m_value, m_start), Some (n_value, n_start) ->
    let m_stop, n_stop = m_start +. m_value.h, n_start +. n_value.h in
    (* [synchr] only calls this when both stop times are equal, so the two
       new segments presumably start at the same date. The relative cut
       points below are measured from that common start. *)
    let m_value, n_value, m_rest, n_rest =
      if m_stop < n_stop then
        (* [m] stops first: keep the prefix of [n] that [m] covers. *)
        let n_value, n_rest = Utils.cutoff n_value m_value.h in
        m_value, n_value, None, Some n_rest
      else if m_stop > n_stop then
        (* [n] stops first: keep the prefix of [m] that [n] covers. *)
        let m_value, m_rest = Utils.cutoff m_value n_value.h in
        m_value, n_value, Some m_rest, None
      else m_value, n_value, None, None
    in
    let mn_value = Utils.join m_value n_value in
    (* The time in a [signal_t] is the start of its value: [synchr_neq]
       computes [stop = start +. h] from it, and [iter_t] adds sample offsets
       to it. The joined value therefore starts at [m_start], not at the
       common end date [min m_stop n_stop]. *)
    Some (mn_value, m_start), m_stop, n_stop, m_rest, n_rest
  | None, None -> None, m_stop, n_stop, None, None
  | Some _, None | None, Some _ ->
    (* Both nodes receive the same input; one producing a value while the
       other does not is treated as a broken invariant. *)
    assert false
|
|
|
|
(** Synchronize two simulations as much as possible.

    Both nodes receive the same input. For each simulation, the internal
    state records:
    - its reached (stop) date;
    - the part of its last output that has not been emitted yet;
    - its own node state.

    Each step advances the simulation that lags behind via [synchr_neq], or
    both via [synchr_eq] when their dates are equal. It emits both outputs
    joined over their common interval.

    [reset] resets both nodes ([n] first) and clears the dates and pending
    values. *)
let synchr
    (m : ('a signal, 'b signal_t) Ztypes.node)
    (n : ('a signal, 'c signal_t) Ztypes.node)
  : ('a signal, ('b * 'c) signal_t) Ztypes.node
  =
  let Ztypes.Node { alloc = m_alloc; step = m_step; reset = m_reset } = m in
  let Ztypes.Node { alloc = n_alloc; step = n_step; reset = n_reset } = n in
  let alloc () =
    let n_state = n_alloc () in
    let m_state = m_alloc () in
    ref ((0.0, None, m_state), (0.0, None, n_state))
  in
  let step cell input =
    let (m_stop, m_pending, m_state), (n_stop, n_pending, n_state) = !cell in
    let output, (m_stop, m_pending), (n_stop, n_pending) =
      if m_stop < n_stop then begin
        (* [m] lags behind: advance it against [n]'s pending value. *)
        let n_pending = Option.get n_pending in
        let output, m_stop, m_pending, n_pending =
          synchr_neq m_step input m_stop n_stop m_state n_pending in
        output, (m_stop, m_pending), (n_stop, n_pending)
      end else if m_stop > n_stop then begin
        (* [n] lags behind: advance it against [m]'s pending value, then put
           the joined pair back in (m, n) order. *)
        let m_pending = Option.get m_pending in
        let output, n_stop, n_pending, m_pending =
          synchr_neq n_step input n_stop m_stop n_state m_pending in
        Option.map (fun (pair, date) -> Utils.swap pair, date) output,
        (m_stop, m_pending), (n_stop, n_pending)
      end else begin
        (* Both at the same date: advance them together. *)
        let output, m_stop, n_stop, m_pending, n_pending =
          synchr_eq m_step n_step input m_state n_state m_stop n_stop in
        output, (m_stop, m_pending), (n_stop, n_pending)
      end
    in
    cell := ((m_stop, m_pending, m_state), (n_stop, n_pending, n_state));
    output
  in
  let reset cell =
    let (_, _, m_state), (_, _, n_state) = !cell in
    n_reset n_state;
    m_reset m_state;
    cell := ((0.0, None, m_state), (0.0, None, n_state))
  in
  Ztypes.Node { alloc; step; reset }
|
|
|
|
(** Sample a value [n] times and iterate [f] on the samples.

    For each [Some (v, _)] input, [f]'s step is applied, in order, to every
    sampled value returned by [Utils.sample v n]. The sample times and the
    signal's time component are discarded. [None] inputs are ignored. *)
let iter n f =
  let Ztypes.Node { alloc; step = f_step; reset } = f in
  let step state = function
    | None -> ()
    | Some (value, _) ->
      List.iter (fun (_, sample) -> f_step state sample) (Utils.sample value n)
  in
  Ztypes.Node { alloc; step; reset }
|
|
|
|
(** Sample a value [n] times and iterate [f] on the dated samples.

    For each [Some (v, start)] input, [f]'s step is applied, in order, to
    [(offset +. start, sample)] for every [(offset, sample)] returned by
    [Utils.sample v n]. [None] inputs are ignored. *)
let iter_t n f =
  let Ztypes.Node { alloc; step = f_step; reset } = f in
  let step state = function
    | None -> ()
    | Some (value, start) ->
      List.iter
        (fun (offset, sample) -> f_step state (offset +. start, sample))
        (Utils.sample value n)
  in
  Ztypes.Node { alloc; step; reset }
|
|
|
|
(** Sample a value [n] times and check [f] on the samples.

    Samples are produced by [iter_t]. [f] receives each sampled value; its
    date is used only for error reporting. If [f] returns [false], a message
    with that date is printed on stderr and the program exits with code 1.

    This uses an explicit test rather than [assert] for two reasons. First,
    the check (and [f]'s step, which advances its state) still runs when
    compiled with [-noassert]. Second, an [Assert_failure] raised inside [f]
    itself propagates instead of being misreported as a failed check. *)
let check
    (n : int)
    (Ztypes.Node { alloc; step; reset } : ('a, bool) Ztypes.node)
  : ('a signal_t, unit) Ztypes.node
  =
  let step s (now, v) =
    if not (step s v) then begin
      Format.eprintf "Assertion failed at time %.10e\n" now;
      exit 1
    end
  in
  iter_t n (Ztypes.Node { alloc; reset; step })
|
|
|
|
(** Sample a value [n] times and check [f] on the dated samples.

    Samples are produced by [iter_t]. [f] receives each [(date, sample)]
    pair. If it returns [false], a message with that date is printed on
    stderr and the program exits with code 1.

    This uses an explicit test rather than [assert] for two reasons. First,
    the check (and [f]'s step, which advances its state) still runs when
    compiled with [-noassert]. Second, an [Assert_failure] raised inside [f]
    itself propagates instead of being misreported as a failed check. *)
let check_t
    (n : int)
    (Ztypes.Node { alloc; step; reset } : (time * 'a, bool) Ztypes.node)
  : ('a signal_t, unit) Ztypes.node
  =
  let step s (now, v) =
    if not (step s (now, v)) then begin
      Format.eprintf "Assertion failed at time %.10e\n" now;
      exit 1
    end
  in
  iter_t n (Ztypes.Node { alloc; reset; step })
|