feat: lift runtime into language, start of zelus 2024 compatibility
This commit is contained in:
parent
dc8d941b84
commit
ffc583985a
37 changed files with 1154 additions and 143 deletions
228
src/lib/std/solve.ml
Normal file
228
src/lib/std/solve.ml
Normal file
|
|
@ -0,0 +1,228 @@
|
|||
|
||||
open Hsim
|
||||
open Types
|
||||
|
||||
type nonrec 'a value = 'a value
|
||||
type nonrec 'a signal = 'a signal
|
||||
type nonrec 'a signal_t = 'a signal_t
|
||||
|
||||
type time = float
|
||||
|
||||
type solver = RK45 | Sundials
|
||||
|
||||
(** Get a value's horizon [h] (reminder: a value is defined on [[0,h]]). *)
|
||||
let horizon { h; _ } = h
|
||||
|
||||
(** Create a value from a horizon and function. *)
|
||||
let make (h, u) = { h; u; c=Discontinuous }
|
||||
|
||||
(** Apply a value at a time t. *)
|
||||
let apply ({ u; h; _ }, t) =
|
||||
if t > h then raise (Invalid_argument (Format.sprintf
|
||||
"Requested time t=%.10e is greater than the horizon h=%.10e" t h));
|
||||
u t
|
||||
|
||||
let build_sim
|
||||
(solver : solver)
|
||||
(model : Ztypes.cstate -> (time * 'a, 'b) Ztypes.node)
|
||||
: (unit *
|
||||
((Ztypes.cvec, Ztypes.dvec) Solver.ivp *
|
||||
(Ztypes.cvec, Ztypes.zoutvec) Solver.zc), 'a signal, 'b signal_t) dnode
|
||||
= let model = Lift.lift model in
|
||||
let solver = Hsim.Solver.solver
|
||||
(match solver with
|
||||
| RK45 -> d_of_dc @@ Solvers.StatefulRK45.InPlace.csolve ()
|
||||
| Sundials -> Solvers.StatefulSundials.InPlace.csolve ())
|
||||
(d_of_dc @@ Solvers.StatefulZ.InPlace.zsolve ()) in
|
||||
let open Hsim.Sim.Sim(Hsim.State.InPlaceSimState) in
|
||||
let DNode s = Hsim.Utils.(compose (run model solver) track) in
|
||||
DNode { s with reset=fun p -> s.reset (p, ())}
|
||||
|
||||
let build_sim_2024
|
||||
(solver : solver)
|
||||
(model : Ztypes.cstate_new -> (time * 'a, 'b) Ztypes.node)
|
||||
: (unit *
|
||||
((Ztypes.cvec, Ztypes.dvec) Solver.ivp *
|
||||
(Ztypes.cvec, Ztypes.zoutvec) Solver.zc), 'a signal, 'b signal_t) dnode
|
||||
= let model = Lift.lift_2024 model in
|
||||
let solver = Hsim.Solver.solver
|
||||
(match solver with
|
||||
| RK45 -> d_of_dc @@ Solvers.StatefulRK45.InPlace.csolve ()
|
||||
| Sundials -> Solvers.StatefulSundials.InPlace.csolve ())
|
||||
(d_of_dc @@ Solvers.StatefulZ.InPlace.zsolve ()) in
|
||||
let open Hsim.Sim.Sim(Hsim.State.InPlaceSimState) in
|
||||
let DNode s = Hsim.Utils.(compose (run model solver) track) in
|
||||
DNode { s with reset=fun p -> s.reset (p, ())}
|
||||
|
||||
(** Lift a hybrid node into a discrete node on streams of functions. *)
|
||||
let solve
|
||||
(solver : solver)
|
||||
(model : Ztypes.cstate -> (time * 'a, 'b) Ztypes.node)
|
||||
: ('a signal, 'b signal_t) Ztypes.node
|
||||
= let DNode sim = build_sim solver model in
|
||||
let alloc () = ref sim.state in
|
||||
let step s a = let b, s' = sim.step !s a in s := s'; b in
|
||||
let reset _ = () in
|
||||
Ztypes.Node { alloc; step; reset }
|
||||
|
||||
let solve_2024
|
||||
(solver : solver)
|
||||
(model : Ztypes.cstate_new -> (time * 'a, 'b) Ztypes.node)
|
||||
: ('a signal, 'b signal_t) Ztypes.node
|
||||
= let DNode sim = build_sim_2024 solver model in
|
||||
let alloc () = ref sim.state in
|
||||
let step s a = let b, s' = sim.step !s a in s := s'; b in
|
||||
let reset _ = () in
|
||||
Ztypes.Node { alloc; step; reset }
|
||||
|
||||
let solve_ode45 m = solve RK45 m
|
||||
let solve_ode45_2024 m = solve_2024 RK45 m
|
||||
let solve_sundials m = solve Sundials m
|
||||
let solve_sundials_2024 m = solve_2024 Sundials m
|
||||
|
||||
(** Utility function for [synchr].
|
||||
|
||||
During synchronization, step the simulation that is lagging behind ([m]) and
|
||||
join it with the stored value for the other ([n]).
|
||||
Takes as arguments:
|
||||
- The step method for [m];
|
||||
- The input;
|
||||
- The last stop times for [m] and [n];
|
||||
- The state of [m];
|
||||
- The stored value for [n].
|
||||
|
||||
Returns:
|
||||
- The common output value up to the common reached date;
|
||||
- The new reached date of [m];
|
||||
- The stored value for [m];
|
||||
- The stored value for [n]. *)
|
||||
let synchr_neq
|
||||
(m_step : 'ms -> 'a signal -> 'b signal_t)
|
||||
(input : 'a signal)
|
||||
(m_stop : time) (n_stop : time) (m_state : 'ms) (n_value : 'c value)
|
||||
: ('b * 'c) signal_t * time * 'b signal * 'c signal
|
||||
= match m_step m_state input with
|
||||
| None -> None, m_stop, None, Some n_value
|
||||
| Some (m_value, m_start) ->
|
||||
let m_stop = m_start +. m_value.h in
|
||||
let m_value, n_value, m_rest, n_rest =
|
||||
(* Three possible scenarios: *)
|
||||
if m_stop < n_stop then begin
|
||||
(* [m] is still behind [n]: cut off [n_value] at [m_stop'] *)
|
||||
let n_value, n_rest = Utils.cutoff n_value m_value.h in
|
||||
m_value, n_value, None, Some n_rest
|
||||
end else if n_stop < m_stop then begin
|
||||
(* [m] overtakes [n]: cut off [m_value] at [n_stop] *)
|
||||
let m_value, m_rest = Utils.cutoff m_value (n_stop -. m_start) in
|
||||
m_value, n_value, Some m_rest, None
|
||||
end else
|
||||
(* [m] reaches [n] exactly: *)
|
||||
m_value, n_value, None, None in
|
||||
let mn_value = Utils.join m_value n_value in
|
||||
Some (mn_value, m_start), m_stop, m_rest, n_rest
|
||||
|
||||
(** Utility function for [synchr].
|
||||
|
||||
During synchronization, step both simulations at the same time.
|
||||
Takes as arguments:
|
||||
- The step functions for both simulations;
|
||||
- The input;
|
||||
- The states of both simulations;
|
||||
- The last stop times of both simulations.
|
||||
|
||||
Returns:
|
||||
- The common output value up to the common reached date;
|
||||
- The new stop times for both simulations;
|
||||
- The new stored values for both simulations. *)
|
||||
let synchr_eq
|
||||
(m_step : 'ms -> 'a signal -> 'b signal_t)
|
||||
(n_step : 'ns -> 'a signal -> 'c signal_t)
|
||||
(input : 'a signal) (m_state : 'ms) (n_state : 'ns)
|
||||
(m_stop : time) (n_stop : time)
|
||||
: ('b * 'c) signal_t * time * time * 'b signal * 'c signal
|
||||
= match m_step m_state input, n_step n_state input with
|
||||
| Some (m_value, m_start), Some (n_value, n_start) ->
|
||||
let m_stop, n_stop = m_start +. m_value.h, n_start +. n_value.h in
|
||||
let m_value, n_value, m_rest, n_rest =
|
||||
if m_stop < n_stop then
|
||||
let n_value, n_rest = Utils.cutoff n_value m_value.h in
|
||||
m_value, n_value, None, Some n_rest
|
||||
else if m_stop > n_stop then
|
||||
let m_value, m_rest = Utils.cutoff m_value n_value.h in
|
||||
m_value, n_value, Some m_rest, None
|
||||
else m_value, n_value, None, None in
|
||||
let mn_value = Utils.join m_value n_value in
|
||||
Some (mn_value, min m_stop n_stop), m_stop, n_stop, m_rest, n_rest
|
||||
| None, None -> None, m_stop, n_stop, None, None
|
||||
| _ -> assert false
|
||||
|
||||
(** Synchronize two simulations as much as possible. *)
|
||||
let synchr
|
||||
(m : ('a signal, 'b signal_t) Ztypes.node)
|
||||
(n : ('a signal, 'c signal_t) Ztypes.node)
|
||||
: ('a signal, ('b * 'c) signal_t) Ztypes.node
|
||||
= let Ztypes.Node { alloc=m_alloc; step=m_step; reset=m_reset } = m in
|
||||
let Ztypes.Node { alloc=n_alloc; step=n_step; reset=n_reset } = n in
|
||||
let alloc () =
|
||||
ref ((0.0, None, m_alloc ()), (0.0, None, n_alloc ())) in
|
||||
let step state input =
|
||||
let (m_stop, m_value, m_state), (n_stop, n_value, n_state) = !state in
|
||||
let m_stop, m_rest, m_state, n_stop, n_rest, n_state, output =
|
||||
if m_stop < n_stop then
|
||||
let n_value = Option.get n_value in
|
||||
let output, m_stop, m_rest, n_rest =
|
||||
synchr_neq m_step input m_stop n_stop m_state n_value in
|
||||
m_stop, m_rest, m_state, n_stop, n_rest, n_state, output
|
||||
else if m_stop > n_stop then
|
||||
let m_value = Option.get m_value in
|
||||
let output, n_stop, n_rest, m_rest =
|
||||
synchr_neq n_step input n_stop m_stop n_state m_value in
|
||||
let output = Option.map (fun (u, t) -> Utils.swap u, t) output in
|
||||
m_stop, m_rest, m_state, n_stop, n_rest, n_state, output
|
||||
else
|
||||
let output, m_stop, n_stop, m_rest, n_rest =
|
||||
synchr_eq m_step n_step input m_state n_state m_stop n_stop in
|
||||
m_stop, m_rest, m_state, n_stop, n_rest, n_state, output in
|
||||
state := (m_stop, m_rest, m_state), (n_stop, n_rest, n_state);
|
||||
output in
|
||||
let reset ({ contents=((_, _, ms), (_, _, ns)) } as s) =
|
||||
n_reset ns; m_reset ms; s := (0.0, None, ms), (0.0, None, ns) in
|
||||
Ztypes.Node { alloc; step; reset }
|
||||
|
||||
(** Sample a value [n] times and iterate [f] on the samples. *)
|
||||
let iter n f =
|
||||
let Ztypes.Node { alloc; step; reset } = f in
|
||||
let step s =
|
||||
Option.iter @@ fun (v, _) ->
|
||||
List.iter (fun (_, v) -> step s v) @@ Utils.sample v n in
|
||||
Ztypes.Node { alloc; step; reset }
|
||||
|
||||
(** Sample a value [n] times and iterate [f] on the dated samples. *)
|
||||
let iter_t n f =
|
||||
let Ztypes.Node { alloc; step; reset } = f in
|
||||
let step s =
|
||||
Option.iter @@ fun (v, h) ->
|
||||
List.iter (fun (t, v) -> step s (t +. h, v)) @@ Utils.sample v n in
|
||||
Ztypes.Node { alloc; step; reset }
|
||||
|
||||
(** Sample a value [n] times and assert [f] on the samples. *)
|
||||
let check
|
||||
(n : int)
|
||||
(Ztypes.Node { alloc; step; reset } : ('a, bool) Ztypes.node)
|
||||
: ('a signal_t, unit) Ztypes.node
|
||||
= let step s (now, v) =
|
||||
try assert (step s v)
|
||||
with Assert_failure _ ->
|
||||
(Format.eprintf "Assertion failed at time %.10e\n" now; exit 1) in
|
||||
iter_t n (Ztypes.Node { alloc; reset; step })
|
||||
|
||||
(** Sample a value [n] times and assert [f] on the dated samples. *)
|
||||
let check_t
|
||||
(n : int)
|
||||
(Ztypes.Node { alloc; step; reset } : (time * 'a, bool) Ztypes.node)
|
||||
: ('a signal_t, unit) Ztypes.node
|
||||
= let step s (now, v) =
|
||||
try assert (step s (now, v))
|
||||
with Assert_failure _ ->
|
||||
(Format.eprintf "Assertion failed at time %.10e\n" now; exit 1) in
|
||||
iter_t n (Ztypes.Node { alloc; reset; step })
|
||||
Loading…
Add table
Add a link
Reference in a new issue