hsim-live/lib/hsim/zls.ml
2026-03-30 13:28:49 +02:00

216 lines
7 KiB
OCaml

(* This code was originally written by Timothy Bourke and Marc Pouzet and is *)
(* part of the Zelus standard library. *)
(* open Ztypes *)
open Bigarray
(* Interface functions used from within Zelus *)
(* Vectors of continuous-state values: C-layout float64 bigarrays. *)
type carray = (float, float64_elt, c_layout) Array1.t
(* Vectors of zero-crossing statuses: C-layout int32 bigarrays. *)
type zarray = (int32, int32_elt, c_layout) Array1.t

(* [cmake n] allocates a continuous-state vector of length [n] with every
   entry set to 0.0. *)
let cmake (n : int) : carray =
  let buf = Array1.create float64 c_layout n in
  let () = Array1.fill buf 0.0 in
  buf

(* [zmake n] allocates a zero-crossing vector of length [n] with every
   entry set to 0l. *)
let zmake (n : int) : zarray =
  let buf = Array1.create int32 c_layout n in
  let () = Array1.fill buf 0l in
  buf

(* Thin wrappers over the one-dimensional Bigarray accessors. *)
let length arr = Array1.dim arr
let get arr idx = Array1.get arr idx
let set arr idx x = Array1.set arr idx x

(* [get_zin v i] is true iff entry [i] of the zero-crossing vector [v]
   is non-zero. *)
let get_zin v i = not (Int32.equal (Array1.get v i) 0l)
(* [zzero zinvec n] sets entries [0 .. n-1] of the int32 vector [zinvec]
   to 0l, in increasing index order. Nothing is written when [n <= 0].
   Indices go through the bounds-checked [Array1.set]. *)
let zzero zinvec n =
  let rec clear i =
    if i < n then begin
      Array1.set zinvec i 0l;
      clear (i + 1)
    end
  in
  clear 0

(* [czero c n] sets entries [0 .. n-1] of the float vector [c] to 0.0,
   in increasing index order. Nothing is written when [n <= 0]. *)
let czero c n =
  let rec clear i =
    if i < n then begin
      Array1.set c i 0.0;
      clear (i + 1)
    end
  in
  clear 0
(* copy functions *)
(* [blit src dst] copies the contents of [src] into [dst]
   (delegates to [Array1.blit]). *)
let blit src dst = Array1.blit src dst

(* [copy src] returns a freshly allocated continuous-state vector holding
   the same values as [src]. *)
let copy src =
  let dst = cmake (length src) in
  let () = blit src dst in
  dst

(* [blit_matrix srcs dsts] blits each vector of [srcs] into the vector at
   the same index in [dsts] (via [Array.iter2]). *)
let blit_matrix srcs dsts = Array.iter2 (fun s d -> blit s d) srcs dsts

(* [copy_matrix rows] deep-copies an array of continuous-state vectors. *)
let copy_matrix rows = Array.map (fun row -> copy row) rows
(* Types of the functions that make up a compiled node, parameterised by
   its state type ['s]. NOTE(review): the roles of the individual
   arguments are not visible in this file; the per-type notes below are
   inferred from names and signatures and should be confirmed against the
   code generator / runtime that uses them. *)
(* Allocates a fresh state. *)
type 's f_alloc = unit -> 's
(* Returns a pair of sizes for a state; presumably the maximum numbers of
   continuous states and zero-crossings — TODO confirm. *)
type 's f_maxsize = 's -> int * int
(* Presumably the current number of continuous states — TODO confirm. *)
type 's f_csize = 's -> int
(* Presumably the current number of zero-crossing expressions — TODO confirm. *)
type 's f_zsize = 's -> int
(* Step function producing an output ['o]. NOTE(review): the two [carray]
   arguments, the [zarray] and the [float] are presumably the continuous
   state, derivatives, zero-crossing inputs and current time — confirm. *)
type ('s, 'o) f_step = 's -> carray -> carray -> zarray -> float -> 'o
(* Derivative function; returns unit, so results are presumably written
   into one of the [carray] arguments — TODO confirm which. *)
type 's f_ders = 's -> carray -> carray -> zarray -> carray -> float -> unit
(* Zero-crossing function; returns unit, so results are presumably written
   into the trailing [carray] — TODO confirm. *)
type 's f_zero = 's -> carray -> zarray -> carray -> float -> unit
(* Resets a state in place. *)
type 's f_reset = 's -> unit
(* Returns a float for a state; presumably the next time horizon — TODO confirm. *)
type 's f_horizon = 's -> float
(* TODO: eliminate this ? *)
(* [time_eq f1 f2] compares two floats for approximate equality, see:
 * http://www.cygnus-software.com/papers/comparingfloats/comparingfloats.htm
 *
 * [f1] and [f2] are considered equal when
 * - they are exactly equal. This also covers [infinity] / [neg_infinity]:
 *   for those, [f1 -. f2] is [nan] and both checks below would otherwise
 *   report inequality, making e.g. [time_leq infinity infinity] false;
 * - their absolute difference is below [min_float] (numbers close to zero);
 * - their relative error, measured against the operand of larger
 *   magnitude, is at most 1e-6 (99.9999% accuracy).
 * [nan] is never equal to anything, including itself. *)
let time_eq (f1 : float) (f2 : float) =
  (* plain [=] (not [Float.equal]) so that [nan] stays unequal to itself *)
  f1 = f2
  || abs_float (f1 -. f2) < min_float
  || (let rel_error =
        if abs_float f1 > abs_float f2
        then abs_float ((f1 -. f2) /. f1)
        else abs_float ((f1 -. f2) /. f2)
      in
      rel_error <= 0.000001)

(* [time_leq t1 t2] is [t1 <= t2], with equality tested by [time_eq]. *)
let time_leq t1 t2 = t1 < t2 || time_eq t1 t2

(* [time_geq t1 t2] is [t1 >= t2], with equality tested by [time_eq]. *)
let time_geq t1 t2 = t1 > t2 || time_eq t1 t2
(* TODO:
- adapt to the new sundials interface, rework, and simplify.
- take advantage of the final field.
*)
(* Interface of a stateful ODE solver *)
module type STATE_ODE_SOLVER =
  sig
    (* A session with the solver. *)
    type t
    (* The type of vectors used internally by the solver. *)
    type nvec
    (* Create a vector of the given size. *)
    val cmake : int -> nvec
    (* Unwrap a vector, returning an array of continuous-state values. *)
    val unvec : nvec -> carray
    (* Wrap an array of continuous-state values into a vector. *)
    val vec : carray -> nvec
    (* A right-hand-side function called by the solver to calculate the
       instantaneous derivatives: [f t cvec dvec].
       - t, the current simulation time (input)
       - cvec, current values for continuous states (input)
       - dvec, the vector of instantaneous derivatives (output) *)
    type rhsfn = float -> carray -> carray -> unit
    (* An interpolation function: [df cvec t k]
       - cvec, a vector for storing the interpolated continuous states (output)
       - t, the time to interpolate at
       - k, the derivative to interpolate *)
    type dkyfn = nvec -> float -> int -> unit
    (* [initialize f c] creates a solver session from a function [f] and
       an initial state vector [c]. *)
    val initialize : rhsfn -> nvec -> t
    (* [reinitialize ?rhsfn s t c] reinitializes the solver with the given
       time [t] and vector of continuous states [c].
       Warning: the size of [c] must be unchanged.
       NOTE(review): the optional [rhsfn] presumably replaces the
       right-hand-side function given at [initialize] — confirm against
       implementations. *)
    val reinitialize : ?rhsfn:rhsfn -> t -> float -> nvec -> unit
    (* [t' = step s tl c] given a state vector [c], takes a step to the next
       'mesh-point', or the given time limit [tl] (whichever is sooner),
       updating [c]. Returns the time [t'] that was reached. *)
    val step : t -> float -> nvec -> float
    (* Returns an interpolation function that can produce results for any
       time [t] since the last mesh-point or the initial instant. *)
    val get_dky : t -> dkyfn
    (* Generic solver parameters.
       NOTE(review): [set_tolerances s a b] takes two floats; which one is
       the relative and which the absolute tolerance is not specified here —
       check the implementation before calling. *)
    val set_stop_time : t -> float -> unit
    val set_min_step : t -> float -> unit
    val set_max_step : t -> float -> unit
    val set_tolerances : t -> float -> float -> unit
    (* NOTE(review): presumably [copy s] duplicates a session and
       [blit src dst] overwrites [dst] with the contents of [src] — the
       exact semantics depend on the implementing module. *)
    val copy : t -> t
    val blit : t -> t -> unit
  end
(* Interface of a stateful zero-crossing solver *)
module type STATE_ZEROC_SOLVER =
  sig
    (* A session with the solver. A zero-crossing solver has two internal
       continuous-state vectors, called 'before' and 'now'. *)
    type t
    (* Zero-crossing function: [g t cvec zout]
       - t, simulation time (input)
       - cvec, vector of continuous states (input)
       - zout, values of zero-crossing expressions (output) *)
    type zcfn = float -> carray -> carray -> unit
    (* Create a session with the zero-crossing solver:
       [initialize nroots g cvec0]
       - nroots, number of zero-crossing expressions
       - g, function to calculate zero-crossing expressions
       - cvec0, initial continuous state
       Sets the 'now' vector to cvec0. *)
    val initialize : int -> zcfn -> carray -> t
    (* The same as [initialize], but does not run [g] at initialization
       time (and takes no initial continuous state). *)
    val initialize_only : int -> zcfn -> t
    (* Reinitialize the zero-crossing solver after a discrete step that
       updates the continuous states directly: [reinitialize s t cvec].
       - s, a session with the zero-crossing solver
       - t, the current simulation time
       - cvec, the current continuous state vector
       Resets the 'now' vector to cvec. *)
    val reinitialize : t -> float -> carray -> unit
    (* Advance the zero-crossing solver after a continuous step:
       [step s t cvec].
       - s, a session with the zero-crossing solver
       - t, the current simulation time
       - cvec, the current continuous state vector
       Moves the current 'now' vector to 'before', then sets 'now' to cvec. *)
    val step : t -> float -> carray -> unit
    (* Returns true if one zero-crossing signal moves from 0 to v > 0. *)
    val takeoff : t -> bool
    (* Compares the 'before' and 'now' vectors and returns true only if
       there exists an i, such that before[i] < 0 and now[i] >= 0. *)
    val has_roots : t -> bool
    (* Locates the time of the zero-crossing closest to the 'before' vector.
       Call after [has_roots] indicates the existence of a zero-crossing:
       [t = find s (f, c) zin].
       - The [get_dky] function [f] is provided by the state solver and is
         expected to update the array [c] with the interpolated state.
       - zin is populated with the status of all zero-crossings.
       - The returned value is the simulation time of the earliest
         zero-crossing that was found. *)
    val find : t -> ((float -> int -> unit) * carray) -> zarray -> float
    (* Locate the fields for which there is a takeoff.
       NOTE(review): presumably fills [zin] like [find] and returns the
       corresponding time — confirm against the implementation. *)
    val find_takeoff : t -> zarray -> float
  end
(*
module type RUNTIME =
sig
val go : unit hsimu -> unit
val check : bool hsimu -> int -> unit
end
module type DISCRETE_RUNTIME =
sig
val go : float -> (unit -> unit) -> unit
end
*)