feat: start of lift, debugging, cleanup

This commit is contained in:
Henri Saudubray 2025-06-23 10:06:01 +02:00
parent 883e5fff01
commit 589f89c768
Signed by: hms
GPG key ID: 7065F57ED8856128
31 changed files with 1297 additions and 51 deletions

View file

@ -2,3 +2,10 @@
let debug = ref false
let print s = if !debug then Format.printf "%s\n" s else ()
let print_entry y =
let n = Bigarray.Array1.dim y in
let rec loop i =
if i = n then ()
else (Format.printf "\t% .10e" y.{i}; loop (i+1)) in
if !debug then (loop 0; Format.printf "\n"; flush stdout)

151
src/lib/common/ztypes.ml Normal file
View file

@ -0,0 +1,151 @@
(**************************************************************************)
(* *)
(* Zelus *)
(* A synchronous language for hybrid systems *)
(* http://zelus.di.ens.fr *)
(* *)
(* Marc Pouzet and Timothy Bourke *)
(* *)
(* Copyright 2012 - 2019. All rights reserved. *)
(* *)
(* This file is distributed under the terms of the CeCILL-C licence *)
(* *)
(* Zelus is developed in the INRIA PARKAS team. *)
(* *)
(**************************************************************************)
(* Type declarations and values that must be linked with *)
(* the generated code *)
type 'a continuous = { mutable pos: 'a; mutable der: 'a }
type ('a, 'b) zerocrossing = { mutable zin: 'a; mutable zout: 'b }
type 'a signal = 'a * bool
type zero = bool
(* a synchronous stream function with type 'a -D-> 'b *)
(* is represented by an OCaml value of type ('a, 'b) node *)
type ('a, 'b) node =
Node:
{ alloc : unit -> 's; (* allocate the state *)
step : 's -> 'a -> 'b; (* compute a step *)
reset : 's -> unit; (* reset/inialize the state *)
} -> ('a, 'b) node
(* the same with a method copy *)
type ('a, 'b) cnode =
Cnode:
{ alloc : unit -> 's; (* allocate the state *)
copy : 's -> 's -> unit; (* copy the source into the destination *)
step : 's -> 'a -> 'b; (* compute a step *)
reset : 's -> unit; (* reset/inialize the state *)
} -> ('a, 'b) cnode
open Bigarray
type time = float
type cvec = (float, float64_elt, c_layout) Array1.t
type dvec = (float, float64_elt, c_layout) Array1.t
type zinvec = (int32, int32_elt, c_layout) Array1.t
type zoutvec = (float, float64_elt, c_layout) Array1.t
(* The interface with the ODE solver *)
type cstate =
{ mutable dvec : dvec; (* the vector of derivatives *)
mutable cvec : cvec; (* the vector of positions *)
mutable zinvec : zinvec; (* the vector of boolean; true when the
solver has detected a zero-crossing *)
mutable zoutvec : zoutvec; (* the corresponding vector that define
zero-crossings *)
mutable cindex : int; (* the position in the vector of positions *)
mutable zindex : int; (* the position in the vector of zero-crossings *)
mutable cend : int; (* the end of the vector of positions *)
mutable zend : int; (* the end of the zero-crossing vector *)
mutable cmax : int; (* the maximum size of the vector of positions *)
mutable zmax : int; (* the maximum number of zero-crossings *)
mutable horizon : float; (* the next horizon *)
mutable major : bool; (* integration iff [major = false] *)
}
(* A hybrid node is a node that is parameterised by a continuous state *)
(* all instances points to this global parameter and read/write on it *)
type ('a, 'b) hnode = cstate -> (time * 'a, 'b) node
type 'b hsimu =
Hsim:
{ alloc : unit -> 's;
(* allocate the initial state *)
maxsize : 's -> int * int;
(* returns the max length of the *)
(* cvector and zvector *)
csize : 's -> int;
(* returns the current length of the continuous state vector *)
zsize : 's -> int;
(* returns the current length of the zero-crossing vector *)
step : 's -> cvec -> dvec -> zinvec -> time -> 'b;
(* computes a step *)
derivative : 's -> cvec -> dvec -> zinvec -> zoutvec -> time -> unit;
(* computes the derivative *)
crossings : 's -> cvec -> zinvec -> zoutvec -> time -> unit;
(* computes the zero-crossings *)
reset : 's -> unit;
(* resets the state *)
horizon : 's -> time;
(* gives the next time horizon *)
} -> 'b hsimu
(* a function with type 'a -C-> 'b, when given to a solver, is *)
(* represented by an OCaml value of type ('a, 'b) hsnode *)
type ('a, 'b) hsnode =
Hnode:
{ state : 's;
(* the discrete state *)
zsize : int;
(* the maximum size of the zero-crossing vector *)
csize : int;
(* the maximum size of the continuous state vector (positions) *)
derivative : 's -> 'a -> time -> cvec -> dvec -> unit;
(* computes the derivative *)
crossing : 's -> 'a -> time -> cvec -> zoutvec -> unit;
(* computes the derivative *)
output : 's -> 'a -> cvec -> 'b;
(* computes the zero-crossings *)
setroots : 's -> 'a -> cvec -> zinvec -> unit;
(* returns the zero-crossings *)
majorstep : 's -> time -> cvec -> 'a -> 'b;
(* computes a step *)
reset : 's -> unit;
(* resets the state *)
horizon : 's -> time;
(* gives the next time horizon *)
} -> ('a, 'b) hsnode
(* An idea suggested by Adrien Guatto, 26/04/2021 *)
(* provide a means to the type for input/outputs of nodes *)
(* express them with GADT to ensure type safety *)
(* type ('a, 'b) node =
| Fun : { step : 'a -> 'b;
typ_arg: 'a typ;
typ_return: 'b typ
} -> ('a, 'b) node
| Node :
{ state : 's; step : 's -> 'a -> 'b * 's;
typ_arg: 'a typ;
typ_state : 's typ;
typ_return: 'b typ } -> ('a, 'b) node
and 'a typ =
| Tunit : unit typ
| Tarrow : 'a typ * 'b typ -> ('a * 'b) typ
| Tint : int -> int typ
| Ttuple : 'a typlist -> 'a typ
| Tnode : 'a typ * 'b typ -> ('a,'b) node typ
and 'a typlist =
| Tnil : unit typlist
| Tpair : 'a typ * 'b typlist -> ('a * 'b) typlist
Q1: do it for records? sum types ? How?
Q2: provide a "type_of" function for every introduced type?
*)

View file

@ -8,6 +8,7 @@ module Sim (S : SimState) =
include S
let step_discrete s step hor fder fzer cget csize zsize jump reset =
Common.Debug.print "SIMU :: DISCRETE :: start";
let ms, ss = get_mstate s, get_sstate s in
let i, now, stop = get_input s, get_now s, get_stop s in
let o, ms = step ms (i.u now) in
@ -26,9 +27,11 @@ module Sim (S : SimState) =
let mode, stop, now = Continuous, i.h, 0.0 in
update ms ss (set_running ~mode ~input:i ~stop ~now s)
end else set_running ~mode:Continuous s in
Common.Debug.print "SIMU :: DISCRETE :: end";
Utils.dot o, s
let step_continuous s step cset fout zset =
Common.Debug.print "SIMU :: CONTINUOUS :: start";
let ms, ss = get_mstate s, get_sstate s in
let i, now, stop = get_input s, get_now s, get_stop s in
let (h, f, z), ss = step ss stop in
@ -46,6 +49,7 @@ module Sim (S : SimState) =
let s = set_running ~mode:Discrete ~now:h s in
update (zset ms z) ss s, Discontinuous in
let h = h -. now in
Common.Debug.print "SIMU :: CONTINUOUS :: end";
{ h; u=fout; c }, s, { h; c; u=fms }
(** Simulation of a model with any solver. *)
@ -55,7 +59,7 @@ module Sim (S : SimState) =
: ('p * (('y, 'yder) ivp * ('y, 'zout) zc), 'a, 'b) sim
= let state = get_init m.state s.state in
let step_discrete st =
let o, s = step_discrete st m.step m.horizon m.fder m.fzer m.cget
let o, s = step_discrete st m.step m.horizon m.fder m.fzer m.cget
m.csize m.zsize m.jump s.reset in
Some o, s in
let step_continuous st =
@ -97,7 +101,7 @@ module Sim (S : SimState) =
let _, state = a.step a.state @@ Some (Utils.dot @@ get_mstate st) in
DNode { a with state }) al in
Some o, (st, al) in
let step_continuous (st, al) =
let ({ h; _ } as o), st, u =
step_continuous st s.step m.body.cset m.body.fout m.body.zset in

View file

@ -69,7 +69,7 @@ type ('p, 'a, 'b) sim = ('p, 'a signal, 'b signal) dnode
(** Consider a node with state copying as a node without state copying. *)
let d_of_dc (DNodeC { state; step; reset; _ }) = DNode { state; step; reset }
(** Consider a model without assertions as a model with an empty list of
(** Consider a model without assertions as a model with an empty list of
assertions. *)
let a_of_h (HNode body) = HNodeA { body; assertions=[] }

View file

@ -3,7 +3,9 @@
(* part of the Zelus standard library. *)
(* It is implemented with in-place modification of arrays. *)
let debug = ref false
let debug () =
(* false *)
!Common.Debug.debug
let printf x = Format.printf x
@ -121,7 +123,7 @@ type t = {
let reinitialize ({ g; f1 = f1; t1 = t1; _ } as s) t c =
s.t1 <- t;
g t1 c f1; (* fill f1, because it is immediately copied into f0 by next_mesh *)
if !debug then (printf "z|---------- init(%.24e, ... ----------@." t;
if debug () then (printf "z|---------- init(%.24e, ... ----------@." t;
log_limit s.f1);
s.bothf_valid <- false
@ -152,6 +154,7 @@ let num_roots { f0; _ } = Zls.length f0
(* f0/t0 take the previous values of f1/t1, f1/t1 are refreshed by g *)
let step ({ g; f0 = f0; f1 = f1; t1 = t1; _ } as s) t c =
Common.Debug.print "ZSOL :: Calling [step]";
(* swap f0 and f1; f0 takes the previous value of f1 *)
s.f0 <- f1;
s.t0 <- t1;
@ -162,7 +165,7 @@ let step ({ g; f0 = f0; f1 = f1; t1 = t1; _ } as s) t c =
g t c s.f1;
s.bothf_valid <- true;
if !debug then
if debug () then
(printf "z|---------- step(%.24e, %.24e)----------@." s.t0 s.t1;
log_limits s.f0 s.f1)
@ -212,7 +215,7 @@ let find ({ g = g; bothf_valid = bothf_valid;
dky t_right 0; (* c = dky_0(t_right); update state *)
ignore (update_roots calc_zc f_left (get_f_right f_right') roots);
if !debug then
if debug () then
(printf
"z|---------- stall(%.24e, %.24e) {interval < %.24e !}--@."
t_left t_right ttol;
@ -280,20 +283,20 @@ let find ({ g = g; bothf_valid = bothf_valid;
match check_interval calc_zc f_left f_mid with
| SearchLeft ->
if !debug then printf "z| (%.24e -- %.24e] %.24e@."
if debug () then printf "z| (%.24e -- %.24e] %.24e@."
t_left t_mid t_right;
let alpha = if i >= 1 then alpha *. 0.5 else alpha in
let n_mid = f_mid_from_f_right f_right' in
seek (t_left, f_left, n_mid, t_mid, Some f_mid, alpha, i + 1)
| SearchRight ->
if !debug then printf "z| %.24e (%.24e -- %.24e]@."
if debug () then printf "z| %.24e (%.24e -- %.24e]@."
t_left t_mid t_right;
let alpha = if i >= 1 then alpha *. 2.0 else alpha in
seek (t_mid, f_mid, f_left, t_right, f_right', alpha, i + 1)
| FoundMid ->
if !debug then printf "z| %.24e [%.24e] %.24e@."
if debug () then printf "z| %.24e [%.24e] %.24e@."
t_left t_mid t_right;
ignore (update_roots calc_zc f_left f_mid roots);
let f_tmp = f_mid_from_f_right f_right' in
@ -303,7 +306,7 @@ let find ({ g = g; bothf_valid = bothf_valid;
if not bothf_valid then (clear_roots roots; assert false)
else begin
if !debug then
if debug () then
printf "z|\nz|---------- find(%.24e, %.24e)----------@." t0 t1;
match check_interval calc_zc f0 f1 with
@ -314,7 +317,7 @@ let find ({ g = g; bothf_valid = bothf_valid;
end
| FoundMid -> begin
if !debug then printf "z| zero-crossing at limit (%.24e)@." t1;
if debug () then printf "z| zero-crossing at limit (%.24e)@." t1;
ignore (update_roots calc_zc f0 f1 roots);
s.bothf_valid <- false;
t1

View file

@ -51,7 +51,9 @@ module GenericODE (Butcher : BUTCHER_TABLEAU) : STATE_ODE_SOLVER =
struct (* {{{1 *)
open Bigarray
let debug = ref false (* !Debug.debug *)
let debug () =
false
(* !Common.Debug.debug *)
let pow = 1.0 /. float(Butcher.order)
@ -274,7 +276,7 @@ struct (* {{{1 *)
"odexx: step size < min step size (\n now=%.24e\n h=%.24e\n< min_step=%.24e)"
t h s.min_step);
if !debug then Printf.printf "s|\ns|----------step(%.24e)----------\n" max_t;
if debug () then Printf.printf "s|\ns|----------step(%.24e)----------\n" max_t;
let rec onestep (alreadyfailed: bool) h =
@ -288,11 +290,11 @@ struct (* {{{1 *)
let tnew = if finished then max_t else t +. h *. (mA maxK) in
mapinto ynew (make_newval y k maxK);
f tnew ynew k.(maxK);
if !debug then log_step t y k.(0) tnew ynew k.(maxK);
if debug () then log_step t y k.(0) tnew ynew k.(maxK);
let err = h *. calculate_error (abs_tol /. rel_tol) k y ynew in
if err > rel_tol then begin
if !debug then Printf.printf "s| error exceeds tolerance\n";
if debug () then Printf.printf "s| error exceeds tolerance\n";
if h <= hmin then failwith
(Printf.sprintf "Error (%e) > relative tolerance (%e) at t=%e"

View file

@ -22,6 +22,8 @@ module Functional =
{ state; vec = init } in
let step ({ state ; vec=v } as s) h =
Common.Debug.print "SOLVER STEP";
Common.Debug.print_entry v;
let y_nv = vec v in
let h = step state h y_nv in
let state = copy state in

View file

@ -15,7 +15,10 @@ module Functional =
vec = zmake 0 } in
let reset { fzer; init; size } { vec; _ } =
let fzer t cvec zout = let zout' = fzer t cvec in blit zout' zout in
let fzer t cvec zout =
let zout' = fzer t cvec in blit zout' zout in
Common.Debug.print "ZSolver Reset";
Common.Debug.print_entry init;
{ state = initialize size fzer init;
vec = if length vec = size then vec else zmake size } in