open Hsim
open Types

type nonrec 'a value = 'a value
type nonrec 'a signal = 'a signal
type nonrec 'a signal_t = 'a signal_t
type time = float

(** Continuous solver back-end selected when building a simulation
    (see [build_sim]). *)
type solver = RK45 | Sundials

(** Get a value's horizon [h] (reminder: a value is defined on [[0,h]]). *)
let horizon { h; _ } = h

(** Create a value from a horizon and function. The resulting value is
    tagged [Discontinuous]. *)
let make (h, u) = { h; u; c=Discontinuous }

(** Apply a value at a time [t].
    @raise Invalid_argument if [t] is greater than the value's horizon. *)
let apply ({ u; h; _ }, t) =
  if t > h then
    raise (Invalid_argument
             (Format.sprintf
                "Requested time t=%.10e is greater than the horizon h=%.10e"
                t h));
  u t

(** Build a discrete simulation node from a hybrid model.

    The model is lifted with [Lift.lift]; a solver is assembled from the
    continuous solver selected by [solver] and the [Solvers.StatefulZ]
    solver; the simulation ([run]) is then composed with [track].
    The reset function of the resulting node takes a single parameter [p]
    and forwards [(p, ())] to the composed node. *)
let build_sim (solver : solver)
    (model : Ztypes.cstate -> (time * 'a, 'b) Ztypes.node)
  : (unit * ((Ztypes.cvec, Ztypes.dvec) Solver.ivp
             * (Ztypes.cvec, Ztypes.zoutvec) Solver.zc),
     'a signal, 'b signal_t) dnode =
  let model = Lift.lift model in
  let solver =
    Hsim.Solver.solver
      (match solver with
       | RK45 -> d_of_dc @@ Solvers.StatefulRK45.InPlace.csolve ()
       | Sundials -> Solvers.StatefulSundials.InPlace.csolve ())
      (d_of_dc @@ Solvers.StatefulZ.InPlace.zsolve ()) in
  let open Hsim.Sim.Sim(Hsim.State.InPlaceSimState) in
  let DNode s = Hsim.Utils.(compose (run model solver) track) in
  DNode { s with reset=fun p -> s.reset (p, ())}

(** Same as [build_sim], for models written against [Ztypes.cstate_new]
    (lifted with [Lift.lift_2024]). *)
let build_sim_2024 (solver : solver)
    (model : Ztypes.cstate_new -> (time * 'a, 'b) Ztypes.node)
  : (unit * ((Ztypes.cvec, Ztypes.dvec) Solver.ivp
             * (Ztypes.cvec, Ztypes.zoutvec) Solver.zc),
     'a signal, 'b signal_t) dnode =
  let model = Lift.lift_2024 model in
  let solver =
    Hsim.Solver.solver
      (match solver with
       | RK45 -> d_of_dc @@ Solvers.StatefulRK45.InPlace.csolve ()
       | Sundials -> Solvers.StatefulSundials.InPlace.csolve ())
      (d_of_dc @@ Solvers.StatefulZ.InPlace.zsolve ()) in
  let open Hsim.Sim.Sim(Hsim.State.InPlaceSimState) in
  let DNode s = Hsim.Utils.(compose (run model solver) track) in
  DNode { s with reset=fun p -> s.reset (p, ())}

(** Lift a hybrid node into a discrete node on streams of functions.

    The node's state is a reference holding the current simulation state;
    each step runs the simulation step and stores the new state.

    NOTE(review): [reset] is a no-op here, so resetting the node does not
    restore the initial simulation state; confirm this is intended. *)
let solve (solver : solver)
    (model : Ztypes.cstate -> (time * 'a, 'b) Ztypes.node)
  : ('a signal, 'b signal_t) Ztypes.node =
  let DNode sim = build_sim solver model in
  let alloc () = ref sim.state in
  let step s a =
    let b, s' = sim.step !s a in
    s := s';
    b in
  let reset _ = () in
  Ztypes.Node { alloc; step; reset }

(** Same as [solve], for models written against [Ztypes.cstate_new]
    (built with [build_sim_2024]).

    NOTE(review): [reset] is a no-op here as well; confirm this is
    intended. *)
let solve_2024 (solver : solver)
    (model : Ztypes.cstate_new -> (time * 'a, 'b) Ztypes.node)
  : ('a signal, 'b signal_t) Ztypes.node =
  let DNode sim = build_sim_2024 solver model in
  let alloc () = ref sim.state in
  let step s a =
    let b, s' = sim.step !s a in
    s := s';
    b in
  let reset _ = () in
  Ztypes.Node { alloc; step; reset }

(** [solve] specialised to the [RK45] solver. *)
let solve_ode45 m = solve RK45 m

(** [solve_2024] specialised to the [RK45] solver. *)
let solve_ode45_2024 m = solve_2024 RK45 m

(** [solve] specialised to the [Sundials] solver. *)
let solve_sundials m = solve Sundials m

(** [solve_2024] specialised to the [Sundials] solver. *)
let solve_sundials_2024 m = solve_2024 Sundials m

(** Utility function for [synchr]. During synchronization, step the
    simulation that is lagging behind ([m]) and join it with the stored
    value for the other ([n]).

    Takes as arguments:
    - The step method for [m];
    - The input;
    - The last stop times for [m] and [n];
    - The state of [m];
    - The stored value for [n].

    Returns:
    - The common output value up to the common reached date, dated with
      the start date of [m]'s step;
    - The new reached date of [m];
    - The stored value for [m];
    - The stored value for [n].

    If [m] produces no output, nothing is emitted and [n]'s stored value
    is kept unchanged. *)
let synchr_neq (m_step : 'ms -> 'a signal -> 'b signal_t)
    (input : 'a signal) (m_stop : time) (n_stop : time)
    (m_state : 'ms) (n_value : 'c value)
  : ('b * 'c) signal_t * time * 'b signal * 'c signal =
  match m_step m_state input with
  | None -> None, m_stop, None, Some n_value
  | Some (m_value, m_start) ->
    (* Reached date of [m], computed before [m_value] is possibly cut. *)
    let m_stop = m_start +. m_value.h in
    let m_value, n_value, m_rest, n_rest =
      (* Three possible scenarios: *)
      if m_stop < n_stop then begin
        (* [m] is still behind [n]: cut off [n_value] at [m_stop] *)
        let n_value, n_rest = Utils.cutoff n_value m_value.h in
        m_value, n_value, None, Some n_rest
      end else if n_stop < m_stop then begin
        (* [m] overtakes [n]: cut off [m_value] at [n_stop] *)
        let m_value, m_rest = Utils.cutoff m_value (n_stop -. m_start) in
        m_value, n_value, Some m_rest, None
      end else
        (* [m] reaches [n] exactly: *)
        m_value, n_value, None, None in
    let mn_value = Utils.join m_value n_value in
    Some (mn_value, m_start), m_stop, m_rest, n_rest

(** Utility function for [synchr]. During synchronization, step both
    simulations at the same time.

    Takes as arguments:
    - The step functions for both simulations;
    - The input;
    - The states of both simulations;
    - The last stop times of both simulations.

    Returns:
    - The common output value up to the common reached date, dated with
      the start date of the step;
    - The new stop times for both simulations;
    - The new stored values for both simulations.

    Both simulations are expected to either both produce an output or both
    produce none; any other combination is treated as an internal error
    ([assert false]). *)
let synchr_eq (m_step : 'ms -> 'a signal -> 'b signal_t)
    (n_step : 'ns -> 'a signal -> 'c signal_t)
    (input : 'a signal) (m_state : 'ms) (n_state : 'ns)
    (m_stop : time) (n_stop : time)
  : ('b * 'c) signal_t * time * time * 'b signal * 'c signal =
  match m_step m_state input, n_step n_state input with
  | Some (m_value, m_start), Some (n_value, n_start) ->
    let m_stop, n_stop = m_start +. m_value.h, n_start +. n_value.h in
    let m_value, n_value, m_rest, n_rest =
      if m_stop < n_stop then
        (* [n] went further: keep its tail for later. *)
        let n_value, n_rest = Utils.cutoff n_value m_value.h in
        m_value, n_value, None, Some n_rest
      else if m_stop > n_stop then
        (* [m] went further: keep its tail for later. *)
        let m_value, m_rest = Utils.cutoff m_value n_value.h in
        m_value, n_value, Some m_rest, None
      else
        m_value, n_value, None, None in
    let mn_value = Utils.join m_value n_value in
    (* The time attached to an output is the date at which its value
       starts (as in [synchr_neq]), not the date it reaches. Both
       simulations start from the same date here, which the cutoffs above
       already rely on, so [m_start] is used. *)
    Some (mn_value, m_start), m_stop, n_stop, m_rest, n_rest
  | None, None -> None, m_stop, n_stop, None, None
  | _ -> assert false

(** Synchronize two simulations as much as possible.

    The state of the resulting node holds, for each simulation, its last
    stop time, its stored (not yet emitted) value if any, and its own
    state. At each step, the simulation that is lagging behind is stepped
    alone ([synchr_neq]); if both have reached the same date, both are
    stepped ([synchr_eq]). *)
let synchr (m : ('a signal, 'b signal_t) Ztypes.node)
    (n : ('a signal, 'c signal_t) Ztypes.node)
  : ('a signal, ('b * 'c) signal_t) Ztypes.node =
  let Ztypes.Node { alloc=m_alloc; step=m_step; reset=m_reset } = m in
  let Ztypes.Node { alloc=n_alloc; step=n_step; reset=n_reset } = n in
  let alloc () = ref ((0.0, None, m_alloc ()), (0.0, None, n_alloc ())) in
  let step state input =
    let (m_stop, m_value, m_state), (n_stop, n_value, n_state) = !state in
    let m_stop, m_rest, m_state, n_stop, n_rest, n_state, output =
      if m_stop < n_stop then
        (* [m] is lagging: a value for [n] must have been stored. *)
        let n_value = Option.get n_value in
        let output, m_stop, m_rest, n_rest =
          synchr_neq m_step input m_stop n_stop m_state n_value in
        m_stop, m_rest, m_state, n_stop, n_rest, n_state, output
      else if m_stop > n_stop then
        (* [n] is lagging: same as above with the roles exchanged; the
           joined output is swapped back to the [(m, n)] order. *)
        let m_value = Option.get m_value in
        let output, n_stop, n_rest, m_rest =
          synchr_neq n_step input n_stop m_stop n_state m_value in
        let output = Option.map (fun (u, t) -> Utils.swap u, t) output in
        m_stop, m_rest, m_state, n_stop, n_rest, n_state, output
      else
        let output, m_stop, n_stop, m_rest, n_rest =
          synchr_eq m_step n_step input m_state n_state m_stop n_stop in
        m_stop, m_rest, m_state, n_stop, n_rest, n_state, output in
    state := (m_stop, m_rest, m_state), (n_stop, n_rest, n_state);
    output in
  let reset ({ contents=((_, _, ms), (_, _, ns)) } as s) =
    n_reset ns;
    m_reset ms;
    s := (0.0, None, ms), (0.0, None, ns) in
  Ztypes.Node { alloc; step; reset }

(** Sample a value [n] times and iterate [f] on the samples. Nothing is
    done when there is no value. *)
let iter n f =
  let Ztypes.Node { alloc; step; reset } = f in
  let step s =
    Option.iter (fun (v, _) ->
        List.iter (fun (_, v) -> step s v) (Utils.sample v n)) in
  Ztypes.Node { alloc; step; reset }

(** Sample a value [n] times and iterate [f] on the dated samples. Each
    sample's date is its offset in the value plus the time attached to the
    value. Nothing is done when there is no value. *)
let iter_t n f =
  let Ztypes.Node { alloc; step; reset } = f in
  let step s =
    Option.iter (fun (v, h) ->
        List.iter (fun (t, v) -> step s (t +. h, v)) (Utils.sample v n)) in
  Ztypes.Node { alloc; step; reset }

(** Report a failed check at date [now] on the error formatter and
    terminate the program with exit code 1. *)
let assertion_failed (now : time) =
  Format.eprintf "Assertion failed at time %.10e\n" now;
  exit 1

(** Sample a value [n] times and assert [f] on the samples.

    The program exits with code 1 as soon as [f] returns [false] on a
    sample, or raises [Assert_failure]. The test is explicit rather than
    written with [assert], so that it is still performed when assertions
    are compiled out ([-noassert]). *)
let check (n : int)
    (Ztypes.Node { alloc; step; reset } : ('a, bool) Ztypes.node)
  : ('a signal_t, unit) Ztypes.node =
  let step s (now, v) =
    match step s v with
    | true -> ()
    | false -> assertion_failed now
    | exception Assert_failure _ -> assertion_failed now in
  iter_t n (Ztypes.Node { alloc; reset; step })

(** Sample a value [n] times and assert [f] on the dated samples.

    Same as [check], except that [f] also receives the date of each
    sample. *)
let check_t (n : int)
    (Ztypes.Node { alloc; step; reset } : (time * 'a, bool) Ztypes.node)
  : ('a signal_t, unit) Ztypes.node =
  let step s (now, v) =
    match step s (now, v) with
    | true -> ()
    | false -> assertion_failed now
    | exception Assert_failure _ -> assertion_failed now in
  iter_t n (Ztypes.Node { alloc; reset; step })