feat: solvers and ball example

This commit is contained in:
Henri Saudubray 2025-04-25 13:57:53 +02:00
parent cc099c02e7
commit e07f165494
Signed by: hms
GPG key ID: 7065F57ED8856128
27 changed files with 1483 additions and 290 deletions

View file

@ -0,0 +1,2 @@
This module is part of the [Zélus](https://zelus.di.ens.fr) standard library,
and was originally written by Timothy Bourke and Marc Pouzet.

4
src/lib/solvers/dune Normal file
View file

@ -0,0 +1,4 @@
(env (dev (flags (:standard -w -9-27-32))))
(library
(name solvers))

361
src/lib/solvers/illinois.ml Normal file
View file

@ -0,0 +1,361 @@
(* This code was originally written by Timothy Bourke and Marc Pouzet and is *)
(* part of the Zelus standard library. *)
(* It is implemented with in-place modification of arrays. *)
let debug = ref false
let printf x = Format.printf x
type root_direction = Up | Down | Either | Ignore
let extra_precision = ref false
let set_precise_logging _ = (extra_precision := true)
let fold_zxzx f acc f0 f1 =
let n = Zls.length f0 in
let rec fold acc i =
if i = n then acc
else
let acc' = f i acc f0.{i} f1.{i} in
fold acc' (i + 1)
in fold acc 0
(* return a function that looks for zero-crossings *)
let get_check_root rdir =
let check_up x0 x1 = if x0 < 0.0 && x1 >= 0.0 then 1l else 0l in
let check_down x0 x1 = if x0 > 0.0 && x1 <= 0.0 then -1l else 0l in
let check_either x0 x1 = if x0 < 0.0 && x1 >= 0.0 then 1l else
if x0 > 0.0 && x1 <= 0.0 then -1l else 0l in
let no_check x0 x1 = 0l in
match rdir with
| Up -> check_up
| Down -> check_down
| Either -> check_either
| Ignore -> no_check
let up = Up
let down = Down
let either = Either
let ign = Ignore
(* returns true if a signal has moved from zero to a stritly positive value *)
let takeoff f0 f1 =
let n = Zls.length f0 in
let rec fold acc i =
if i = n then acc
else if acc then acc else fold ((f0.{i} = 0.0) && (f1.{i} > 0.0)) (i + 1)
in fold false 0
(* return a function that looks for zero-crossings between f0 and f1 *)
(** code inutile
let make_check_root rdir f0 f1 =
let check = get_check_root rdir in
(fun i -> check f0.{i} f1.{i})
**)
(* update roots and returns true if there was at least one root *)
(* between f0 and f1 for one component of index [i in [0..length f0 - 1]] *)
(* update [roots] *)
let update_roots calc_zc f0 f1 roots =
let update i found x0 x1 =
let zc = calc_zc x0 x1 in
roots.{i} <- zc;
found || (zc <> 0l)
in
fold_zxzx update false f0 f1
(* update [roots] *)
let clear_roots roots =
for i = 0 to Zls.length roots - 1 do
roots.{i} <- 0l
done
let log_limits f0 f1 =
let logf i _ = printf "z| g[% 2d]: % .24e --> % .24e@." i in
fold_zxzx logf () f0 f1
let log_limit f0 =
let logf i _ x _ = printf "z| g[% 2d]: % .24e@." i x in
fold_zxzx logf () f0 f0
(* the type signature of the zero-crossing function *)
type zcfn = float -> Zls.carray -> Zls.carray -> unit
(* type of a session with the solver *)
(* zx = g(t, c) yields the values of system zero-crossing expressions
f0/t0 are the zero-crossing expression values at the last mesh point
f1/t1 are the zero-crossing expression values at the next mesh point
bothf_valid is true when both f0/t0 and f1/t1 are valid and thus find
can check for zero-crossings between them.
roots is the array of booleans returned to callers to indicate on which
expressions zero-crossings have been detected.
calc_zc determines the kind of zero-crossings to seek and report.
fta and ftb are temporary arrays used when searching for zero-crossings.
They are kept in the session as an optimisation to avoid having to
continually create and destroy arrays.
*)
type t = {
g : zcfn;
mutable bothf_valid : bool;
mutable f0 : Zls.carray;
mutable t0 : float;
mutable f1 : Zls.carray;
mutable t1 : float;
mutable calc_zc : float -> float -> int32;
mutable fta : Zls.carray;
mutable ftb : Zls.carray;
}
(* Called from find when bothf_valid = false to initialise f1. *)
let reinitialize ({ g; f1 = f1; t1 = t1 } as s) t c =
s.t1 <- t;
g t1 c f1; (* fill f1, because it is immediately copied into f0 by next_mesh *)
if !debug then (printf "z|---------- init(%.24e, ... ----------@." t;
log_limit s.f1);
s.bothf_valid <- false
let initialize_only nroots g =
{
g = g;
bothf_valid = false;
f0 = Zls.cmake nroots;
t0 = 0.0;
f1 = Zls.cmake nroots;
t1 = 0.0;
fta = Zls.cmake nroots;
ftb = Zls.cmake nroots;
calc_zc = get_check_root Up;
}
let initialize nroots g c =
let s = initialize_only nroots g in
reinitialize s 0.0 c;
s
let num_roots { f0 } = Zls.length f0
(* f0/t0 take the previous values of f1/t1, f1/t1 are refreshed by g *)
let step ({ g; f0 = f0; f1 = f1; t1 = t1 } as s) t c =
(* swap f0 and f1; f0 takes the previous value of f1 *)
s.f0 <- f1;
s.t0 <- t1;
s.f1 <- f0;
s.t1 <- t;
(* calculate a new value for f1 *)
g t c s.f1;
s.bothf_valid <- true;
if !debug then
(printf "z|---------- step(%.24e, %.24e)----------@." s.t0 s.t1;
log_limits s.f0 s.f1)
type root_interval = SearchLeft | FoundMid | SearchRight
let resolve_intervals r1 r2 =
match r1, r2 with
| SearchLeft, _ | _, SearchLeft -> SearchLeft
| FoundMid, _ | _, FoundMid -> FoundMid
| SearchRight, _ -> SearchRight
(* Check for zero-crossings between f_left and f_mid, filling roots with the
intermediate results and returning:
SearchLeft zero-crossing in (f_left, f_mid)
FoundMid no zero-crossing in (f_left, f_mid)
zero-crossing in (f_left, f_mid]
SearchRight no zero-crossing in (f_left, f_mid]
(possible) zero-crossing in (f_mid, f_right]
*)
let check_interval calc_zc f_left f_mid =
let check i r x0 x1 =
let rv = calc_zc x0 x1 in
let r' = if rv = 0l then SearchRight
else if x1 = 0.0 then FoundMid
else SearchLeft in
resolve_intervals r r' in
fold_zxzx check SearchRight f_left f_mid
(* locates the zero-crossing *)
(* [find s (dky, c) roots = time] *)
(* stores the zero-crossing into the vector [roots] and returns the *)
(* time [time] right after the instant one zero-crossing has been found between *)
(* time [t0] and [t1] *)
let find ({ g = g; bothf_valid = bothf_valid;
f0 = f0; t0 = t0; f1 = f1; t1 = t1;
fta = fta; ftb = ftb; calc_zc = calc_zc } as s)
(dky, c) roots =
let ttol = 100.0 *. epsilon_float *. max (abs_float t0) (abs_float t1) in
(* A small optimisation to avoid copying or overwriting f1 *)
let get_f_right ofr = match ofr with None -> f1 | Some f -> f in
let f_mid_from_f_right ofr = match ofr with None -> ftb | Some f -> f in
(* update roots and c; return (t, f0_valid, f0, fta, ftb) *)
let interval_too_small t_left t_right f_left f_mid f_right' =
dky t_right 0; (* c = dky_0(t_right); update state *)
ignore (update_roots calc_zc f_left (get_f_right f_right') roots);
if !debug then
(printf
"z|---------- stall(%.24e, %.24e) {interval < %.24e !}--@."
t_left t_right ttol;
log_limits f_left (get_f_right f_right'));
match f_right' with
| None -> (t_right, false, f_left, f_mid, ftb)
| Some f_right -> (t_right, true, f_right, f_mid, f_left) in
(* Searches between (t_left, f_left) and (t_right, f_right) to find the
leftmost (t_mid, f_mid):
|
| f_right
|
| f_mid
+--[t_left---------t_mid---------------t_right]--
|
| f_left
|
t_left and t_right are the times that bound the interval
f_left and f_right are the values at the end points
f_mid is an array to be filled within the function (if necessary)
f_right' is used in the optimisation to avoid copying or overwriting f1
alpha is a parameter of the Illinois method, and
i is used in its calculation
seek() returns either:
(t, false, f0', fta', ftb') - root found at original f_right
(i.e., t = original t_right)
or
(t, true, f0', fta', ftb') - root found at f0' (i.e., t < t_right)
*)
let rec seek (t_left, f_left, f_mid, t_right, f_right', alpha, i) =
let dt = t_right -. t_left in
let f_right = get_f_right f_right' in
let leftmost_midpoint default =
let check _ t_min x_left x_right =
if x_left = 0.0 then t_min (* ignore expressions equal to zero at LHS *)
else
let sn = (x_right /. alpha) /. x_left in
let sn_d = 1.0 -. sn in
(* refer Dahlquist and Bjorck, sec. 6.2.2
stop if sn_d is not "large enough" *)
let t' =
if sn_d <= ttol then t_left +. (dt /. 2.0)
else t_right +. (sn /. sn_d) *. dt in
min t_min t' in
fold_zxzx check default f_left f_right in
if dt <= ttol
then interval_too_small t_left t_right f_left f_mid f_right'
else
let t_mid = leftmost_midpoint t_right in
if t_mid = t_right
then interval_too_small t_left t_right f_left f_mid f_right'
else begin
dky t_mid 0; (* c = dky_0(t_mid); interpolate state *)
g t_mid c f_mid; (* f_mid = g(t_mid, c); compute zc expressions *)
match check_interval calc_zc f_left f_mid with
| SearchLeft ->
if !debug then printf "z| (%.24e -- %.24e] %.24e@."
t_left t_mid t_right;
let alpha = if i >= 1 then alpha *. 0.5 else alpha in
let n_mid = f_mid_from_f_right f_right' in
seek (t_left, f_left, n_mid, t_mid, Some f_mid, alpha, i + 1)
| SearchRight ->
if !debug then printf "z| %.24e (%.24e -- %.24e]@."
t_left t_mid t_right;
let alpha = if i >= 1 then alpha *. 2.0 else alpha in
seek (t_mid, f_mid, f_left, t_right, f_right', alpha, i + 1)
| FoundMid ->
if !debug then printf "z| %.24e [%.24e] %.24e@."
t_left t_mid t_right;
ignore (update_roots calc_zc f_left f_mid roots);
let f_tmp = f_mid_from_f_right f_right' in
(t_mid, true, f_mid, f_left, f_tmp)
end
in
if not bothf_valid then (clear_roots roots; assert false)
else begin
if !debug then
printf "z|\nz|---------- find(%.24e, %.24e)----------@." t0 t1;
match check_interval calc_zc f0 f1 with
| SearchRight -> begin
clear_roots roots;
s.bothf_valid <- false;
assert false
end
| FoundMid -> begin
if !debug then printf "z| zero-crossing at limit (%.24e)@." t1;
ignore (update_roots calc_zc f0 f1 roots);
s.bothf_valid <- false;
t1
end
| SearchLeft -> begin
let (t, v, f0', fta', ftb') =
seek (t0, f0, fta, t1, None, 1.0, 0) in
s.t0 <- t;
s.f0 <- f0';
s.bothf_valid <- v;
s.fta <- fta';
s.ftb <- ftb';
t
end
end
(* the main function of this module *)
(* locate a root *)
let find s (dky, c) roots = find s (dky, c) roots
(* is there a root? [has_root s: bool] is true is there is a change in sign *)
(* for one component [i in [0..length f0 - 1]] beetwen [f0.(i)] and [f1.(i)] *)
let has_roots { bothf_valid = bothf_valid; t0; f0; t1; f1; calc_zc = calc_zc }
= bothf_valid && (check_interval calc_zc f0 f1 <> SearchRight)
let takeoff { bothf_valid = bothf_valid; f0; f1 } =
bothf_valid && (takeoff f0 f1)
(* returns true if a signal has moved from zero to a stritly positive value *)
(* Added by MP. Ask Tim if this code is necessary, that is, what happens *)
(* with function [find] when the signal is taking off from [0.0] to a *)
(* strictly positive value *)
let find_takeoff ({ f0; f1 } as s) roots =
let calc_zc x0 x1 =
if (x0 = 0.0) && (x1 > 0.0) then 1l else 0l in
let b = update_roots calc_zc f0 f1 roots in
if b then begin s.t1 <- s.t0; s.f1 <- s.f0; s.ftb <- s.fta end;
s.t0
let set_root_directions s rd = (s.calc_zc <- get_check_root rd)

View file

@ -0,0 +1,5 @@
(* This code was originally written by Timothy Bourke and Marc Pouzet and is *)
(* part of the Zelus standard library. *)
include Zls.STATE_ZEROC_SOLVER

413
src/lib/solvers/odexx.ml Normal file
View file

@ -0,0 +1,413 @@
(* This code was originally written by Timothy Bourke and Marc Pouzet and is *)
(* part of the Zelus standard library. *)
open Zls
module type BUTCHER_TABLEAU =
sig (* {{{ *)
val order : int (* solver order *)
val initial_reduction_limit_factor : float
(* factor limiting the reduction of h after a failed step *)
(* Butcher Tableau:
a(0) |
a(1) | b(1)
a(2) | b(2) b(3)
a(3) | b(4) b(5) b(6)
... | ...
-------+--------------
a(n) | b(~) b(~) b(~) ...
| b(+) b(+) b(+) ...
The b(~) values must be included in b.
The b(+) values are given indirectly via e.
e/h = y_n+1 - y*_n+1 = b(~)s - b(+)s
*)
val a : float array (* h coefficients; one per stage *)
val b : float array (* previous stage coefficients *)
val e : float array (* error estimation coefficients *)
val bi : float array (* interpolation coefficients *)
(* let ns be the number of stages, then:
size(a) = ns x 1
size(b) = ns x ns
(but only the lower strictly triangular entries)
size(e) = ns
size(bi) = ns x po
(where po is the order of the interpolating polynomial)
*)
end (* }}} *)
module GenericODE (Butcher : BUTCHER_TABLEAU) : STATE_ODE_SOLVER =
struct (* {{{1 *)
open Bigarray
let debug = ref false (* !Debug.debug *)
let pow = 1.0 /. float(Butcher.order)
let mA r = Butcher.a.(r)
let h_matB = Array.copy Butcher.b
let update_mhB h = for i = 0 to Array.length h_matB - 1 do
h_matB.(i) <- Butcher.b.(i) *. h
done
let mhB r c = if c >= r then 0.0 else h_matB.(((r-1)*r)/2 + c)
let mhB_row r = Array.sub h_matB (((r-1)*r)/2) r
let mE c = Butcher.e.(c)
let maxK = Array.length(Butcher.a) - 1
let rowsBI = Array.length(Butcher.a)
let colsBI = Array.length(Butcher.bi) / rowsBI
let maxBI = colsBI - 1
let h_matBI = Array.copy Butcher.bi
let update_mhBI h = for i = 0 to Array.length h_matBI - 1 do
h_matBI.(i) <- Butcher.bi.(i) *. h
done
let mhBI_row r = Array.sub h_matBI (r * colsBI) colsBI
let minmax minimum maximum x = min maximum (max minimum x)
let mapinto r f =
for i = 0 to Array1.dim r - 1 do
r.{i} <- f i
done
let fold2 f a v1 v2 =
let acc = ref a in
for i = 0 to min (length v1) (length v2) - 1 do
acc := f !acc (get v1 i) (get v2 i)
done;
!acc
let maxnorm2 f = fold2 (fun acc v1 v2 -> max acc (abs_float (f v1 v2))) 0.0
type rhsfn = float -> Zls.carray -> Zls.carray -> unit
type dkyfn = Zls.carray -> float -> int -> unit
(* dx = sysf(t, y) describes the system dynamics
y/time is the current mesh point
yold/last_time is the previous mesh point
(and also used for intermediate values during the
calculation of the next mesh point)
(y and yold are mutable because they are swapped after having calculated
the next mesh point yold)
h is the step size to be used for calculating the next mesh point.
k.(0) is the instantaneous derivative at the previous mesh point
k.(maxK) is the instantaneous derivative at the current mesh point
k.(1--maxK-1) track intermediate instantaneous derivatives during the
calculation of the next mesh point.
*)
type t = {
mutable sysf : float -> Zls.carray -> Zls.carray -> unit;
mutable y : Zls.carray;
mutable time : float;
mutable last_time : float;
mutable h : float;
mutable hmax : float;
k : Zls.carray array;
mutable yold : Zls.carray;
(* -- parameters -- *)
mutable stop_time : float;
(* bounds on small step sizes (mesh-points) *)
mutable min_step : float;
mutable max_step : float;
(* initial/fixed step size *)
initial_step_size : float option;
mutable rel_tol : float;
mutable abs_tol : float;
}
type nvec = Zls.carray
let cmake = Array1.create float64 c_layout
let unvec x = x
let vec x = x
let calculate_hmax tfinal min_step max_step =
(* [ensure hmax >= min_step] *)
let hmax =
if tfinal = infinity then max_step
else if max_step = infinity then 0.1 *. tfinal
else min max_step tfinal in
max min_step hmax
(* NB: y must be the initial state vector (y_0)
* k(0) must be the initial deriviatives vector (dy_0) *)
let initial_stepsize { initial_step_size; abs_tol; rel_tol; max_step;
time; y; hmax; k } =
let hmin = 16.0 *. epsilon_float *. abs_float time in
match initial_step_size with
| Some h -> minmax hmin max_step h
| None ->
let threshold = abs_tol /. rel_tol in
let rh =
maxnorm2 (fun y dy -> dy /. (max (abs_float y) threshold)) y k.(0)
/. (0.8 *. rel_tol ** pow)
in
max hmin (if hmax *. rh > 1.0 then 1.0 /. rh else hmax)
let reinitialize ?rhsfn ({ stop_time; min_step; max_step; sysf } as s) t ny =
Bigarray.Array1.blit ny s.y;
s.time <- t;
s.last_time <- t;
s.hmax <- calculate_hmax stop_time min_step max_step;
sysf t s.y s.k.(maxK); (* update initial derivatives;
to be FSAL swapped into k.(0) *)
s.h <- initial_stepsize s;
Option.iter (fun v -> s.sysf <- v) rhsfn
let initialize f ydata =
let y_len = Bigarray.Array1.dim ydata in
let s = {
sysf = f;
y = Zls.cmake y_len;
time = 0.0;
last_time = 0.0;
h = 0.0;
hmax = 0.0;
k = Array.init (maxK + 1) (fun _ -> Zls.cmake y_len);
yold = Zls.cmake y_len;
(* parameters *)
stop_time = infinity;
min_step = 16.0 *. epsilon_float;
max_step = infinity;
initial_step_size = None;
rel_tol = 1.0e-3;
abs_tol = 1.0e-6;
} in
Bigarray.Array1.blit ydata s.k.(0);
reinitialize s 0.0 ydata;
s
let set_stop_time t v =
if (v <= 0.0) then failwith "The stop time must be strictly positive.";
t.stop_time <- v
let set_min_step t v = t.min_step <- v
let set_max_step t v = t.max_step <- v
let set_tolerances t rel abs =
if (rel <= 0.0 || abs <= 0.0)
then failwith "Tolerance values must be strictly positive.";
(t.rel_tol <- rel; t.abs_tol <- abs)
let make_newval y k s =
let hB = mhB_row s in
let newval i =
let acc = ref y.{i} in
for si = 0 to s - 1 do
acc := !acc +. k.(si).{i} *. hB.(si)
done;
!acc in
newval
let calculate_error threshold k y ynew =
let maxerr = ref 0.0 in
for i = 0 to Bigarray.Array1.dim y - 1 do
let kE = ref 0.0 in
for s = 0 to maxK do
kE := !kE +. k.(s).{i} *. mE s
done;
let err = !kE /. (max threshold (max (abs_float y.{i})
(abs_float ynew.{i}))) in
maxerr := max !maxerr (abs_float err)
done;
!maxerr
let log_step t y dy t' y' dy' =
Printf.printf
"s| % .24e % .24e\n" t t';
for i = 0 to Array1.dim y - 1 do
Printf.printf "s| f[% 2d]: % .24e (% .24e) --> % .24e (% .24e)\n"
i (y.{i}) dy.{i} y'.{i} dy'.{i}
done
(* TODO: add stats: nfevals, nfailed, nsteps *)
let step s t_limit user_y =
let { stop_time; min_step; abs_tol; rel_tol;
sysf = f; time = t; h = h; hmax = hmax;
k = k; y = y; yold = ynew; } = s in
(* First Same As Last (FSAL) swap; doing it after the previous
step invalidates the interpolation routine. *)
let tmpK = k.(0) in
k.(0) <- k.(maxK);
k.(maxK) <- tmpK;
let hmin = 16.0 *. epsilon_float *. abs_float t in
let h = minmax hmin hmax h in
let max_t = min t_limit stop_time in
let h, finished =
if 1.1 *. h >= abs_float (max_t -. t)
then (max_t -. t, true)
else (h, false) in
if h < s.min_step then failwith
(Printf.sprintf
"odexx: step size < min step size (\n now=%.24e\n h=%.24e\n< min_step=%.24e)"
t h s.min_step);
if !debug then Printf.printf "s|\ns|----------step(%.24e)----------\n" max_t;
let rec onestep (alreadyfailed: bool) h =
(* approximate next state vector *)
update_mhB h;
for s = 1 to maxK - 1 do
mapinto ynew (make_newval y k s);
f (t +. h *. mA s) ynew k.(s)
done;
let tnew = if finished then max_t else t +. h *. (mA maxK) in
mapinto ynew (make_newval y k maxK);
f tnew ynew k.(maxK);
if !debug then log_step t y k.(0) tnew ynew k.(maxK);
let err = h *. calculate_error (abs_tol /. rel_tol) k y ynew in
if err > rel_tol then begin
if !debug then Printf.printf "s| error exceeds tolerance\n";
if h <= hmin then failwith
(Printf.sprintf "Error (%e) > relative tolerance (%e) at t=%e"
err rel_tol t);
let nexth =
if alreadyfailed then max hmin (0.5 *. h)
else max hmin (h *. max Butcher.initial_reduction_limit_factor
(0.8 *. (rel_tol /. err) ** pow)) in
onestep true nexth
end
else
let h = tnew -. t in
let nexth =
if alreadyfailed then h
else let f = 1.25 *. (err /. rel_tol) ** pow in
if f > 0.2 then h /. f else 5.0 *. h in
(tnew, nexth)
in
let nextt, nexth = onestep false h in
(* advance a step *)
s.y <- ynew;
s.yold <- y;
Bigarray.Array1.blit ynew user_y;
s.last_time <- t;
s.time <- nextt;
s.h <- nexth;
s.time
let get_dky { last_time = t; time = t'; h = h; yold = y; k = k } yi ti kd =
if kd > 0 then
failwith
(Printf.sprintf
"get_dky: requested derivative of order %d \
cannot be interpolated at time %.24e" kd ti);
if ti < t || ti > t' then
failwith
(Printf.sprintf
"get_dky: requested time %.24e is out of range\n\ [%.24e,...,%.24e]"
ti t t');
let h = t' -. t in
let th = (ti -. t) /. h in
update_mhBI h;
for i = 0 to Bigarray.Array1.dim y - 1 do
let ya = ref y.{i} in
for s = 0 to maxK do
let k = k.(s).{i} in
let hbi = mhBI_row s in
let acc = ref 0.0 in
for j = maxBI downto 0 do
acc := (!acc +. k *. hbi.(j)) *. th
done;
ya := !ya +. !acc
done;
yi.{i} <- !ya
done
(* copy functions *)
let copy ({ last_time; time; h; yold; k } as s) =
{ s with last_time; time; h; yold = Zls.copy yold; k = Zls.copy_matrix k }
let blit { last_time = l1; time = t1; h = h1; yold = yhold1; k = k1 }
({ last_time; time; h; yold; k } as s2) =
s2.last_time <- l1; s2.time <- t1;
Zls.blit yhold1 yold; Zls.blit_matrix k1 k
end (* }}} *)
module Ode23 = GenericODE (
struct
let order = 3
let initial_reduction_limit_factor = 0.5
let a = [| 0.0; 1.0/.2.0; 3.0/.4.0; 1.0 |]
let b = [| 1.0/.2.0;
0.0; 3.0/.4.0;
2.0/.9.0; 1.0/.3.0; 4.0/.9.0 |]
let e = [| -5.0/.72.0; 1.0/.12.0; 1.0/.9.0; -1.0/.8.0 |]
let bi = [| 1.0; -4.0/.3.0; 5.0/.9.0;
0.0; 1.0; -2.0/.3.0;
0.0; 4.0/.3.0; -8.0/.9.0;
0.0; -1.0; 1.0 |]
end)
module Ode45 = GenericODE (
struct
let order = 5
let initial_reduction_limit_factor = 0.1
let a = [| 0.0; 1.0/.5.0; 3.0/.10.0; 4.0/.5.0; 8.0/.9.0; 1.0; 1.0 |]
let b = [|
1.0/. 5.0;
3.0/.40.0; 9.0/.40.0;
44.0/.45.0; -56.0/.15.0; 32.0/.9.0;
19372.0/.6561.0; -25360.0/.2187.0; 64448.0/.6561.0; -212.0/.729.0;
9017.0/.3168.0; -355.0/.33.0; 46732.0/.5247.0; 49.0/.176.0; -5103.0/.18656.0;
35.0/.384.0; 0.0; 500.0/.1113.0; 125.0/.192.0; -2187.0/.6784.0; 11.0/.84.0;
|]
let e = [| 71.0/.57600.0; 0.0; -71.0/.16695.0; 71.0/.1920.0;
-17253.0/.339200.0; 22.0/.525.0; -1.0/.40.0 |]
let bi = [| 1.0; -183.0/.64.0; 37.0/.12.0; -145.0/.128.0;
0.0; 0.0; 0.0; 0.0;
0.0; 1500.0/.371.0; -1000.0/.159.0; 1000.0/.371.0;
0.0; -125.0/.32.0; 125.0/.12.0; -375.0/.64.0;
0.0; 9477.0/.3392.0; -729.0/.106.0; 25515.0/.6784.0;
0.0; -11.0/.7.0; 11.0/.3.0; -55.0/.28.0;
0.0; 3.0/.2.0; -4.0; 5.0/.2.0 |]
end)

215
src/lib/solvers/zls.ml Normal file
View file

@ -0,0 +1,215 @@
(* This code was originally written by Timothy Bourke and Marc Pouzet and is *)
(* part of the Zelus standard library. *)
(* open Ztypes *)
open Bigarray
(* Interfaces functions from within Zelus *)
type carray = (float, float64_elt, c_layout) Array1.t
type zarray = (int32, int32_elt, c_layout) Array1.t
let cmake (n: int) : carray =
let r = Array1.create float64 c_layout n in
Array1.fill r 0.0;
r
let zmake (n: int) : zarray =
let r = Array1.create int32 c_layout n in
Array1.fill r 0l;
r
let length = Array1.dim
let get = Array1.get
let set = Array1.set
let get_zin v i = Array1.get v i <> 0l
(* fill zinvec with zeros *)
let zzero zinvec length =
for i = 0 to length - 1 do
Array1.set zinvec i 0l
done
let czero c length =
for i = 0 to length - 1 do
Array1.set c i 0.0
done
(* copy functions *)
(* copy [c1] into [c2] *)
let blit c1 c2 = Array1.blit c1 c2
let copy c1 = let c2 = cmake (length c1) in blit c1 c2; c2
let blit_matrix m1 m2 = Array.iter2 blit m1 m2
let copy_matrix m = Array.map copy m
type 's f_alloc = unit -> 's
type 's f_maxsize = 's -> int * int
type 's f_csize = 's -> int
type 's f_zsize = 's -> int
type ('s, 'o) f_step = 's -> carray -> carray -> zarray -> float -> 'o
type 's f_ders = 's -> carray -> carray -> zarray -> carray -> float -> unit
type 's f_zero = 's -> carray -> zarray -> carray -> float -> unit
type 's f_reset = 's -> unit
type 's f_horizon = 's -> float
(* TODO: eliminate this ? *)
(* Compare two floats for equality, see:
* http://www.cygnus-software.com/papers/comparingfloats/comparingfloats.htm *)
let time_eq f1 f2 =
if abs_float (f1 -. f2) < min_float
then true (* absolute error check for numbers around to zero *)
else
let rel_error =
if abs_float f1 > abs_float f2
then abs_float ((f1 -. f2) /. f1)
else abs_float ((f1 -. f2) /. f2)
in
(rel_error <= 0.000001)
(* Compare times with 99.9999% accuracy. *)
let time_leq t1 t2 = t1 < t2 || time_eq t1 t2
let time_geq t1 t2 = t1 > t2 || time_eq t1 t2
(* TODO:
- adapt to the new sundials interface, rework, and simplify.
- take advantage of the final field.
*)
(* Interface of a stateful ODE solver *)
module type STATE_ODE_SOLVER =
sig
(* A session with the solver. *)
type t
(* The type of vectors used internally by the solver. *)
type nvec
(* Create a vector of the given size. *)
val cmake : int -> nvec
(* Unwrap a vector returning an array of continuous-state values. *)
val unvec : nvec -> carray
(* Wrap a vector of continuous-state values into an vector. *)
val vec : carray -> nvec
(* A right-hand-side function called by the solver to calculate the
instantaneous derivatives: [f t cvec dvec].
- t, the current simulation time (input)
- cvec, current values for continuous states (input)
- dvec, the vector of instantaneous derivatives (output) *)
type rhsfn = float -> carray -> carray -> unit
(* An interpolation function: [df cvec t k]
- cvec, a vector for storing the interpolated continuous states (output)
- t, the time to interpolate at,
- k, the derivative to interpolate *)
type dkyfn = nvec -> float -> int -> unit
(* [initialize f c] creates a solver session from a function [f] and
an initial state vector [c]. *)
val initialize : rhsfn -> nvec -> t
(* [reinitialize s t c] reinitializes the solver with the given time
[t] and vector of continuous states [c]. *)
(* warning. the size of [c] must be unchanged *)
val reinitialize : ?rhsfn:rhsfn -> t -> float -> nvec -> unit
(* [t' = step s tl c] given a state vector [c], takes a step to the next
'mesh-point', or the given time limit [tl] (whichever is sooner),
updating [c]. *)
val step : t -> float -> nvec -> float
(* Returns an interpolation function that can produce results for any
time [t] since the last mesh-point or the initial instant. *)
val get_dky : t -> dkyfn
(* generic solver parameters *)
val set_stop_time : t -> float -> unit
val set_min_step : t -> float -> unit
val set_max_step : t -> float -> unit
val set_tolerances : t -> float -> float -> unit
val copy : t -> t
val blit : t -> t -> unit
end
(* Interface of a stateful zero-crossing solver *)
module type STATE_ZEROC_SOLVER =
sig
(* A session with the solver. A zero-crossing solver has two internal
continuous-state vectors, called 'before' and 'now'. *)
type t
(* Zero-crossing function: [g t cvec zout]
- t, simulation time (input)
- cvec, vector of continuous states (input)
- zout, values of zero-crossing expressions (output) *)
type zcfn = float -> carray -> carray -> unit
(* Create a session with the zero-crossing solver:
[initialize nroots g cvec0]
- nroots, number of zero-crossing expressions
- g, function to calculate zero-crossing expressions
- cvec0, initial continuous state
Sets the 'now' vector to cvec0. *)
val initialize : int -> zcfn -> carray -> t
(* The same but does not run [g] at initialization time *)
val initialize_only : int -> zcfn -> t
(* Reinitialize the zero-crossing solver after a discrete step that
updates the continuous states directly: [reinitialize s t cvec].
- s, a session with the zero-crossing solver
- t, the current simultation time
- cvec, the current continuous state vector
Resets the 'now' vector to cvec. *)
val reinitialize : t -> float -> carray -> unit
(* Advance the zero-crossing solver after a continuous step:
[step s t cvec].
- s, a session with the zero-crossing solver
- t, the current simultation time
- cvec, the current continuous state vector
Moves the current 'now' vector to 'before', then sets 'now' to cvec. *)
val step : t -> float -> carray -> unit
val takeoff : t -> bool
(* Returns true if one zero-crossing signal moves from 0 to v > 0 *)
(* Compares the 'before' and 'now' vectors and returns true only if
there exists an i, such that before[i] < 0 and now[i] >= 0. *)
val has_roots : t -> bool
(* Locates the time of the zero-crossing closest to the 'before' vector.
Call after [has_roots] indicates the existence of a zero-crossing:
[t = find s (f, c) zin].
- The [get_dky] function [f] is provided by the state solver and is
expected to update the array [c] with the interpolated state.
- zin, is populated with the status of all zero-crossings
- the returned values is the simulation time of the earliest
zero-crossing that was found. *)
val find : t -> ((float -> int -> unit) * carray) -> zarray -> float
(* locate the fields for which there is a takeoff *)
val find_takeoff : t -> zarray -> float
end
(*
module type RUNTIME =
sig
val go : unit hsimu -> unit
val check : bool hsimu -> int -> unit
end
module type DISCRETE_RUNTIME =
sig
val go : float -> (unit -> unit) -> unit
end
*)