Skip to main content

nereids_fitting/
nelder_mead.rs

1//! Bounded Nelder-Mead simplex minimizer.
2//!
3//! Derivative-free polish optimizer used after a gradient-based stage to
4//! escape stall points.  Memo 35 §P2.1 and EG5 establish that, for
5//! backgrounded counts-path fits, a single L-BFGS start frequently stalls
6//! at the initial guess (1/20 self-flagged convergence on the EG2 S1 C_full
7//! regime), while a Nelder-Mead polish from that stall point resolves the
8//! failure cleanly (10/20 convergence, density bias from −5.94% to +0.013%,
9//! D/DOF from 905 to 1.001).
10//!
11//! ## Algorithm
12//!
13//! Standard Nelder-Mead simplex with reflection / expansion / contraction /
14//! shrink (Nelder & Mead 1965), using the classical coefficients
15//! (α=1, γ=2, ρ=0.5, σ=0.5).
16//!
17//! Box bounds are enforced via **reflection at the wall**: when a proposed
18//! vertex would leave the feasible box, each coordinate is reflected back
19//! inside (`x_i ← 2·bound − x_i` once, then clamped).  This preserves the
20//! simplex volume in bulk while keeping all vertices feasible.
21//!
22//! ## Convergence
23//!
24//! Terminates when both
25//! - the maximum coordinate distance from any simplex vertex to the current
26//!   best vertex (`simplex[0]`) is below `xatol`, AND
27//! - the range of objective values across the simplex is below `fatol`.
28//!
29//! This matches scipy's `optimize.minimize(method='Nelder-Mead')` simplex-
30//! spread check (`max(|sim[i] - sim[0]|)` over coordinates) behaviour.
31
32use crate::error::FittingError;
33
/// Tunable settings for the Nelder-Mead minimizer.
///
/// The two tolerances drive the convergence test, `max_iter` caps the
/// amount of work, and the two step fields control how the initial
/// simplex is spread around the starting point.
#[derive(Clone, Debug)]
pub struct NelderMeadConfig {
    /// Absolute tolerance on the simplex spread: the largest coordinate
    /// distance from any vertex to the best vertex.
    pub xatol: f64,
    /// Absolute tolerance on the spread of objective values over the
    /// simplex (worst minus best).
    pub fatol: f64,
    /// Upper limit on the number of simplex iterations.  Each iteration
    /// costs a bounded number of objective evaluations.
    pub max_iter: usize,
    /// Relative size of the initial simplex edges.  The edge along
    /// coordinate `i` is the signed product `initial_step_frac * x0_i`,
    /// so `0.05` moves the coordinate 5 % further in the direction of its
    /// own sign.  Coordinates with `|x0_i| < 1e-8` use
    /// `initial_step_abs` instead.  This is deliberately NOT
    /// `initial_step_frac * max(|x0|, 1)`: when `|x0| < 1` the resulting
    /// perturbation is smaller than `initial_step_frac` itself.
    pub initial_step_frac: f64,
    /// Absolute initial step used for coordinates whose `|x_0| < 1e-8`,
    /// where a relative step would be (nearly) zero.
    pub initial_step_abs: f64,
}
55
56impl Default for NelderMeadConfig {
57    fn default() -> Self {
58        // Defaults match scipy.optimize.minimize(method='Nelder-Mead'):
59        // xatol = 1e-4, fatol = 1e-4.  For the polish regime described in
60        // EG5 we use tighter tolerances (1e-9 / 1e-10) on the caller side.
61        Self {
62            xatol: 1e-4,
63            fatol: 1e-4,
64            max_iter: 5000,
65            initial_step_frac: 0.05,
66            initial_step_abs: 0.00025,
67        }
68    }
69}
70
/// Outcome of a Nelder-Mead run.
#[derive(Clone, Debug)]
pub struct NelderMeadResult {
    /// Parameter vector of the best vertex found.
    pub x: Vec<f64>,
    /// Value of the objective at `x`.
    pub fun: f64,
    /// How many simplex iterations were carried out.
    pub iterations: usize,
    /// Number of objective evaluations in total, counting those spent on
    /// the initial simplex.
    pub n_evals: usize,
    /// `true` when the `xatol` and `fatol` criteria were both met before
    /// `max_iter` was reached.  Per memo 35 §P2.3, acceptance should be
    /// judged from the deviance value, not this flag.
    pub self_converged: bool,
}
87
88/// Minimize a scalar objective with optional per-coordinate box bounds.
89///
90/// - `f` must be non-panicking; it may return `Err` to signal an infeasible
91///   point (the NM logic treats the vertex as +∞ and contracts away from it).
92/// - `x0` is the initial point.  An initial simplex of `n+1` vertices is
93///   built by perturbing each coordinate in turn.
94/// - `bounds`, if present, must have the same length as `x0`.  Each pair is
95///   `(lower, upper)`; use `f64::NEG_INFINITY` / `f64::INFINITY` to disable.
96///
97/// ## Panics
98///
99/// Does not panic on infeasible objective values.  Panics only if `x0` is
100/// empty or `bounds.len() != x0.len()`.
101pub fn nelder_mead_minimize<F>(
102    mut f: F,
103    x0: &[f64],
104    bounds: Option<&[(f64, f64)]>,
105    config: &NelderMeadConfig,
106) -> Result<NelderMeadResult, FittingError>
107where
108    F: FnMut(&[f64]) -> Result<f64, FittingError>,
109{
110    let n = x0.len();
111    assert!(n > 0, "nelder_mead_minimize: x0 must not be empty");
112    if let Some(b) = bounds {
113        assert_eq!(
114            b.len(),
115            n,
116            "nelder_mead_minimize: bounds length {} != x0 length {}",
117            b.len(),
118            n
119        );
120        for (i, &(lo, hi)) in b.iter().enumerate() {
121            assert!(
122                lo <= hi,
123                "nelder_mead_minimize: bound {i} has lo {lo} > hi {hi}"
124            );
125        }
126    }
127    // Classical Nelder-Mead coefficients.
128    const ALPHA: f64 = 1.0; // reflection
129    const GAMMA: f64 = 2.0; // expansion
130    const RHO: f64 = 0.5; // contraction
131    const SIGMA: f64 = 0.5; // shrink
132
133    // Project a point onto the bounding box.
134    let project = |x: &mut [f64]| {
135        if let Some(b) = bounds {
136            for (xi, &(lo, hi)) in x.iter_mut().zip(b.iter()) {
137                if *xi < lo {
138                    *xi = 2.0 * lo - *xi; // reflect
139                    if *xi > hi {
140                        *xi = hi;
141                    }
142                    if *xi < lo {
143                        *xi = lo;
144                    }
145                } else if *xi > hi {
146                    *xi = 2.0 * hi - *xi;
147                    if *xi < lo {
148                        *xi = lo;
149                    }
150                    if *xi > hi {
151                        *xi = hi;
152                    }
153                }
154            }
155        }
156    };
157
158    // Objective evaluator that turns Err into +∞ (infeasible → avoid).
159    let mut n_evals = 0usize;
160    let mut eval = |x: &[f64], f: &mut F| -> f64 {
161        n_evals += 1;
162        match f(x) {
163            Ok(v) if v.is_finite() => v,
164            _ => f64::INFINITY,
165        }
166    };
167
168    // Build initial simplex.  Vertex 0 is x0; vertex i>0 perturbs coord i-1.
169    let mut simplex: Vec<Vec<f64>> = Vec::with_capacity(n + 1);
170    let mut fvals: Vec<f64> = Vec::with_capacity(n + 1);
171    let mut v0 = x0.to_vec();
172    project(&mut v0);
173    fvals.push(eval(&v0, &mut f));
174    simplex.push(v0.clone());
175    for i in 0..n {
176        let mut v = v0.clone();
177        let base = v[i];
178        let step = if base.abs() > 1e-8 {
179            config.initial_step_frac * base
180        } else {
181            config.initial_step_abs
182        };
183        v[i] = base + step;
184        project(&mut v);
185        // If projection collapsed the perturbation (e.g. vertex hit a wall
186        // and the reflection / clamp put it back on the original coord),
187        // try the opposite direction so the simplex remains non-degenerate.
188        if (v[i] - base).abs() < 1e-14 {
189            v[i] = base - step;
190            project(&mut v);
191            if (v[i] - base).abs() < 1e-14 {
192                // Give up and use the tiny default step — the simplex is
193                // near a corner but still has to start somewhere.
194                v[i] = base
195                    + config
196                        .initial_step_abs
197                        .copysign(if base >= 0.0 { 1.0 } else { -1.0 });
198                project(&mut v);
199            }
200        }
201        fvals.push(eval(&v, &mut f));
202        simplex.push(v);
203    }
204
205    // Sort simplex by ascending f-value.
206    let mut order: Vec<usize> = (0..=n).collect();
207    order.sort_by(|&a, &b| {
208        fvals[a]
209            .partial_cmp(&fvals[b])
210            .unwrap_or(std::cmp::Ordering::Equal)
211    });
212    simplex = order.iter().map(|&i| simplex[i].clone()).collect();
213    fvals = order.iter().map(|&i| fvals[i]).collect();
214
215    let mut centroid = vec![0.0; n];
216    let mut xr = vec![0.0; n];
217    let mut xe = vec![0.0; n];
218    let mut xc = vec![0.0; n];
219
220    let mut iter = 0usize;
221    let mut self_converged = false;
222    while iter < config.max_iter {
223        iter += 1;
224
225        // Convergence check.
226        let fmin = fvals[0];
227        let fmax = fvals[n];
228        let frange = fmax - fmin;
229        // Max coordinate distance from any vertex to the best vertex
230        // (`simplex[0]`).  Matches the scipy Nelder-Mead spread check.
231        let mut xrange = 0.0f64;
232        for v in simplex.iter() {
233            for (j, &xj) in v.iter().enumerate() {
234                let d = (xj - simplex[0][j]).abs();
235                if d > xrange {
236                    xrange = d;
237                }
238            }
239        }
240        if xrange <= config.xatol && frange <= config.fatol {
241            self_converged = true;
242            break;
243        }
244
245        // Centroid of all vertices except the worst.
246        for (j, c) in centroid.iter_mut().enumerate() {
247            let mut s = 0.0;
248            for v in simplex.iter().take(n) {
249                s += v[j];
250            }
251            *c = s / (n as f64);
252        }
253
254        // Reflection.
255        for j in 0..n {
256            xr[j] = centroid[j] + ALPHA * (centroid[j] - simplex[n][j]);
257        }
258        project(&mut xr);
259        let fxr = eval(&xr, &mut f);
260
261        if fvals[0] <= fxr && fxr < fvals[n - 1] {
262            simplex[n] = xr.clone();
263            fvals[n] = fxr;
264        } else if fxr < fvals[0] {
265            // Expansion.
266            for j in 0..n {
267                xe[j] = centroid[j] + GAMMA * (xr[j] - centroid[j]);
268            }
269            project(&mut xe);
270            let fxe = eval(&xe, &mut f);
271            if fxe < fxr {
272                simplex[n] = xe.clone();
273                fvals[n] = fxe;
274            } else {
275                simplex[n] = xr.clone();
276                fvals[n] = fxr;
277            }
278        } else {
279            // Contraction.  Outside contraction (fxr ≥ f[n-1]) chooses the
280            // reflected side; inside contraction chooses the worst side.
281            let (x_src, f_src) = if fxr < fvals[n] {
282                (&xr, fxr)
283            } else {
284                (&simplex[n], fvals[n])
285            };
286            for j in 0..n {
287                xc[j] = centroid[j] + RHO * (x_src[j] - centroid[j]);
288            }
289            project(&mut xc);
290            let fxc = eval(&xc, &mut f);
291            if fxc < f_src {
292                simplex[n] = xc.clone();
293                fvals[n] = fxc;
294            } else {
295                // Shrink toward the best vertex.  Snapshot the best vertex
296                // first to avoid aliasing borrows when mutating
297                // `simplex[i]`.
298                let best = simplex[0].clone();
299                for i in 1..=n {
300                    for (j, xj) in simplex[i].iter_mut().enumerate() {
301                        *xj = best[j] + SIGMA * (*xj - best[j]);
302                    }
303                    project(&mut simplex[i]);
304                    fvals[i] = eval(&simplex[i], &mut f);
305                }
306            }
307        }
308
309        // Re-sort simplex (O(n log n) — n is small for our use).
310        let mut order: Vec<usize> = (0..=n).collect();
311        order.sort_by(|&a, &b| {
312            fvals[a]
313                .partial_cmp(&fvals[b])
314                .unwrap_or(std::cmp::Ordering::Equal)
315        });
316        simplex = order.iter().map(|&i| simplex[i].clone()).collect();
317        fvals = order.iter().map(|&i| fvals[i]).collect();
318    }
319
320    Ok(NelderMeadResult {
321        x: simplex[0].clone(),
322        fun: fvals[0],
323        iterations: iter,
324        n_evals,
325        self_converged,
326    })
327}
328
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fully specified [`NelderMeadConfig`] so each test can state
    /// its settings on one line.
    fn make_config(
        xatol: f64,
        fatol: f64,
        max_iter: usize,
        initial_step_frac: f64,
        initial_step_abs: f64,
    ) -> NelderMeadConfig {
        NelderMeadConfig {
            xatol,
            fatol,
            max_iter,
            initial_step_frac,
            initial_step_abs,
        }
    }

    #[test]
    fn test_nm_quadratic_1d_converges() {
        // Unbounded 1-D parabola f(x) = (x − 3)², started from the origin.
        fn parabola(x: &[f64]) -> Result<f64, FittingError> {
            Ok((x[0] - 3.0).powi(2))
        }

        let settings = make_config(1e-10, 1e-12, 5000, 0.1, 0.01);
        let outcome = nelder_mead_minimize(parabola, &[0.0], None, &settings).unwrap();

        assert!((outcome.x[0] - 3.0).abs() < 1e-6, "x = {:?}", outcome.x);
        assert!(outcome.fun < 1e-12);
        assert!(outcome.self_converged);
    }

    #[test]
    fn test_nm_rosenbrock_2d() {
        // Rosenbrock banana: f(x,y) = (1-x)² + 100(y-x²)², whose minimum
        // f = 0 sits at (1,1).
        fn rosenbrock(p: &[f64]) -> Result<f64, FittingError> {
            let offset_term = (1.0 - p[0]).powi(2);
            let valley_term = 100.0 * (p[1] - p[0].powi(2)).powi(2);
            Ok(offset_term + valley_term)
        }

        let settings = make_config(1e-6, 1e-8, 10_000, 0.1, 0.01);
        let outcome = nelder_mead_minimize(rosenbrock, &[-1.2, 1.0], None, &settings).unwrap();

        assert!(
            (outcome.x[0] - 1.0).abs() < 1e-3 && (outcome.x[1] - 1.0).abs() < 1e-3,
            "Rosenbrock minimizer off: x = {:?} fun = {}",
            outcome.x,
            outcome.fun
        );
        assert!(outcome.fun < 1e-6);
    }

    #[test]
    fn test_nm_respects_bounds_reflection() {
        // The unconstrained minimum of f(x) = (x − 5)² lies outside the box
        // [0, 2], so the constrained minimum is on the wall at x = 2.  The
        // objective itself asserts that it is never handed a point outside
        // the box while the search is running.
        const LO: f64 = 0.0;
        const HI: f64 = 2.0;

        fn guarded(x: &[f64]) -> Result<f64, FittingError> {
            assert!(
                x[0] >= LO - 1e-12 && x[0] <= HI + 1e-12,
                "NM passed out-of-bounds x = {}",
                x[0]
            );
            Ok((x[0] - 5.0).powi(2))
        }

        let settings = NelderMeadConfig::default();
        let bounds = [(LO, HI)];
        let outcome = nelder_mead_minimize(guarded, &[1.0], Some(&bounds), &settings).unwrap();

        assert!(
            (outcome.x[0] - 2.0).abs() < 1e-2,
            "expected x ≈ 2, got {}",
            outcome.x[0]
        );
        assert!(outcome.x[0] >= LO - 1e-12 && outcome.x[0] <= HI + 1e-12);
    }

    #[test]
    fn test_nm_handles_infeasible_objective() {
        // The objective fails with `Err` whenever x[0] < 0.1 and is
        // (x − 0.5)² elsewhere.  The search must settle on x ≈ 0.5 rather
        // than on a point in the failing region.
        fn partially_defined(x: &[f64]) -> Result<f64, FittingError> {
            if x[0] < 0.1 {
                return Err(FittingError::EvaluationFailed("x too small".into()));
            }
            Ok((x[0] - 0.5).powi(2))
        }

        let settings = make_config(1e-8, 1e-10, 5000, 0.2, 0.05);
        let outcome = nelder_mead_minimize(partially_defined, &[1.0], None, &settings).unwrap();

        assert!(
            (outcome.x[0] - 0.5).abs() < 1e-3,
            "expected x ≈ 0.5, got {} (fun = {})",
            outcome.x[0],
            outcome.fun
        );
    }
}