nereids_fitting/nelder_mead.rs
1//! Bounded Nelder-Mead simplex minimizer.
2//!
3//! Derivative-free polish optimizer used after a gradient-based stage to
4//! escape stall points. Memo 35 §P2.1 and EG5 establish that, for
5//! backgrounded counts-path fits, a single L-BFGS start frequently stalls
6//! at the initial guess (1/20 self-flagged convergence on the EG2 S1 C_full
7//! regime), while a Nelder-Mead polish from that stall point resolves the
8//! failure cleanly (10/20 convergence, density bias from −5.94% to +0.013%,
9//! D/DOF from 905 to 1.001).
10//!
11//! ## Algorithm
12//!
13//! Standard Nelder-Mead simplex with reflection / expansion / contraction /
14//! shrink (Nelder & Mead 1965), using the classical coefficients
15//! (α=1, γ=2, ρ=0.5, σ=0.5).
16//!
17//! Box bounds are enforced via **reflection at the wall**: when a proposed
18//! vertex would leave the feasible box, each coordinate is reflected back
19//! inside (`x_i ← 2·bound − x_i` once, then clamped). This preserves the
20//! simplex volume in bulk while keeping all vertices feasible.
21//!
22//! ## Convergence
23//!
24//! Terminates when both
25//! - the maximum coordinate distance from any simplex vertex to the current
26//! best vertex (`simplex[0]`) is below `xatol`, AND
27//! - the range of objective values across the simplex is below `fatol`.
28//!
29//! This matches scipy's `optimize.minimize(method='Nelder-Mead')` simplex-
30//! spread check (`max(|sim[i] - sim[0]|)` over coordinates) behaviour.
31
32use crate::error::FittingError;
33
/// Tuning knobs for the bounded Nelder-Mead minimizer.
#[derive(Debug, Clone, PartialEq)]
pub struct NelderMeadConfig {
    /// Absolute tolerance on simplex size: the largest per-coordinate
    /// distance between any vertex and the best vertex must be `<= xatol`
    /// for the search to report convergence.
    pub xatol: f64,
    /// Absolute tolerance on the objective spread: the difference between
    /// the worst and the best objective value in the simplex must be
    /// `<= fatol` for the search to report convergence.
    pub fatol: f64,
    /// Maximum number of simplex iterations. An iteration costs one
    /// reflection evaluation plus at most one expansion or contraction
    /// evaluation; a shrink step additionally re-evaluates `n` vertices.
    pub max_iter: usize,
    /// Relative initial simplex edge, applied as a signed multiplier on
    /// each starting coordinate (after it has been moved into the bounds):
    /// `step_i = initial_step_frac * x0_i`, so 0.05 gives a 5 %
    /// perturbation in the direction of the coordinate's sign.
    ///
    /// When `|x0_i| <= 1e-8` the absolute fallback `initial_step_abs` is
    /// used instead. Note: this is NOT `initial_step_frac * max(|x0|, 1)`
    /// — for `|x0| < 1` the perturbation is therefore smaller than
    /// `initial_step_frac` itself.
    pub initial_step_frac: f64,
    /// Absolute initial step for coordinates with `|x0_i| <= 1e-8`, and the
    /// last-resort step when the bounds fold a proportional step back onto
    /// the starting coordinate.
    pub initial_step_abs: f64,
}
55
56impl Default for NelderMeadConfig {
57 fn default() -> Self {
58 // Defaults match scipy.optimize.minimize(method='Nelder-Mead'):
59 // xatol = 1e-4, fatol = 1e-4. For the polish regime described in
60 // EG5 we use tighter tolerances (1e-9 / 1e-10) on the caller side.
61 Self {
62 xatol: 1e-4,
63 fatol: 1e-4,
64 max_iter: 5000,
65 initial_step_frac: 0.05,
66 initial_step_abs: 0.00025,
67 }
68 }
69}
70
/// Outcome of a Nelder-Mead run.
#[derive(Debug, Clone, PartialEq)]
pub struct NelderMeadResult {
    /// Best parameter vector found (the lowest-objective simplex vertex).
    pub x: Vec<f64>,
    /// Objective value at `x`.
    pub fun: f64,
    /// Number of simplex iterations performed.
    pub iterations: usize,
    /// Total objective evaluations, including those spent on the initial
    /// simplex.
    pub n_evals: usize,
    /// `true` if both `xatol` and `fatol` were satisfied before hitting
    /// `max_iter`. Per memo 35 §P2.3, acceptance should be judged from
    /// the deviance value, not this flag.
    pub self_converged: bool,
}
87
88/// Minimize a scalar objective with optional per-coordinate box bounds.
89///
90/// - `f` must be non-panicking; it may return `Err` to signal an infeasible
91/// point (the NM logic treats the vertex as +∞ and contracts away from it).
92/// - `x0` is the initial point. An initial simplex of `n+1` vertices is
93/// built by perturbing each coordinate in turn.
94/// - `bounds`, if present, must have the same length as `x0`. Each pair is
95/// `(lower, upper)`; use `f64::NEG_INFINITY` / `f64::INFINITY` to disable.
96///
97/// ## Panics
98///
99/// Does not panic on infeasible objective values. Panics only if `x0` is
100/// empty or `bounds.len() != x0.len()`.
101pub fn nelder_mead_minimize<F>(
102 mut f: F,
103 x0: &[f64],
104 bounds: Option<&[(f64, f64)]>,
105 config: &NelderMeadConfig,
106) -> Result<NelderMeadResult, FittingError>
107where
108 F: FnMut(&[f64]) -> Result<f64, FittingError>,
109{
110 let n = x0.len();
111 assert!(n > 0, "nelder_mead_minimize: x0 must not be empty");
112 if let Some(b) = bounds {
113 assert_eq!(
114 b.len(),
115 n,
116 "nelder_mead_minimize: bounds length {} != x0 length {}",
117 b.len(),
118 n
119 );
120 for (i, &(lo, hi)) in b.iter().enumerate() {
121 assert!(
122 lo <= hi,
123 "nelder_mead_minimize: bound {i} has lo {lo} > hi {hi}"
124 );
125 }
126 }
127 // Classical Nelder-Mead coefficients.
128 const ALPHA: f64 = 1.0; // reflection
129 const GAMMA: f64 = 2.0; // expansion
130 const RHO: f64 = 0.5; // contraction
131 const SIGMA: f64 = 0.5; // shrink
132
133 // Project a point onto the bounding box.
134 let project = |x: &mut [f64]| {
135 if let Some(b) = bounds {
136 for (xi, &(lo, hi)) in x.iter_mut().zip(b.iter()) {
137 if *xi < lo {
138 *xi = 2.0 * lo - *xi; // reflect
139 if *xi > hi {
140 *xi = hi;
141 }
142 if *xi < lo {
143 *xi = lo;
144 }
145 } else if *xi > hi {
146 *xi = 2.0 * hi - *xi;
147 if *xi < lo {
148 *xi = lo;
149 }
150 if *xi > hi {
151 *xi = hi;
152 }
153 }
154 }
155 }
156 };
157
158 // Objective evaluator that turns Err into +∞ (infeasible → avoid).
159 let mut n_evals = 0usize;
160 let mut eval = |x: &[f64], f: &mut F| -> f64 {
161 n_evals += 1;
162 match f(x) {
163 Ok(v) if v.is_finite() => v,
164 _ => f64::INFINITY,
165 }
166 };
167
168 // Build initial simplex. Vertex 0 is x0; vertex i>0 perturbs coord i-1.
169 let mut simplex: Vec<Vec<f64>> = Vec::with_capacity(n + 1);
170 let mut fvals: Vec<f64> = Vec::with_capacity(n + 1);
171 let mut v0 = x0.to_vec();
172 project(&mut v0);
173 fvals.push(eval(&v0, &mut f));
174 simplex.push(v0.clone());
175 for i in 0..n {
176 let mut v = v0.clone();
177 let base = v[i];
178 let step = if base.abs() > 1e-8 {
179 config.initial_step_frac * base
180 } else {
181 config.initial_step_abs
182 };
183 v[i] = base + step;
184 project(&mut v);
185 // If projection collapsed the perturbation (e.g. vertex hit a wall
186 // and the reflection / clamp put it back on the original coord),
187 // try the opposite direction so the simplex remains non-degenerate.
188 if (v[i] - base).abs() < 1e-14 {
189 v[i] = base - step;
190 project(&mut v);
191 if (v[i] - base).abs() < 1e-14 {
192 // Give up and use the tiny default step — the simplex is
193 // near a corner but still has to start somewhere.
194 v[i] = base
195 + config
196 .initial_step_abs
197 .copysign(if base >= 0.0 { 1.0 } else { -1.0 });
198 project(&mut v);
199 }
200 }
201 fvals.push(eval(&v, &mut f));
202 simplex.push(v);
203 }
204
205 // Sort simplex by ascending f-value.
206 let mut order: Vec<usize> = (0..=n).collect();
207 order.sort_by(|&a, &b| {
208 fvals[a]
209 .partial_cmp(&fvals[b])
210 .unwrap_or(std::cmp::Ordering::Equal)
211 });
212 simplex = order.iter().map(|&i| simplex[i].clone()).collect();
213 fvals = order.iter().map(|&i| fvals[i]).collect();
214
215 let mut centroid = vec![0.0; n];
216 let mut xr = vec![0.0; n];
217 let mut xe = vec![0.0; n];
218 let mut xc = vec![0.0; n];
219
220 let mut iter = 0usize;
221 let mut self_converged = false;
222 while iter < config.max_iter {
223 iter += 1;
224
225 // Convergence check.
226 let fmin = fvals[0];
227 let fmax = fvals[n];
228 let frange = fmax - fmin;
229 // Max coordinate distance from any vertex to the best vertex
230 // (`simplex[0]`). Matches the scipy Nelder-Mead spread check.
231 let mut xrange = 0.0f64;
232 for v in simplex.iter() {
233 for (j, &xj) in v.iter().enumerate() {
234 let d = (xj - simplex[0][j]).abs();
235 if d > xrange {
236 xrange = d;
237 }
238 }
239 }
240 if xrange <= config.xatol && frange <= config.fatol {
241 self_converged = true;
242 break;
243 }
244
245 // Centroid of all vertices except the worst.
246 for (j, c) in centroid.iter_mut().enumerate() {
247 let mut s = 0.0;
248 for v in simplex.iter().take(n) {
249 s += v[j];
250 }
251 *c = s / (n as f64);
252 }
253
254 // Reflection.
255 for j in 0..n {
256 xr[j] = centroid[j] + ALPHA * (centroid[j] - simplex[n][j]);
257 }
258 project(&mut xr);
259 let fxr = eval(&xr, &mut f);
260
261 if fvals[0] <= fxr && fxr < fvals[n - 1] {
262 simplex[n] = xr.clone();
263 fvals[n] = fxr;
264 } else if fxr < fvals[0] {
265 // Expansion.
266 for j in 0..n {
267 xe[j] = centroid[j] + GAMMA * (xr[j] - centroid[j]);
268 }
269 project(&mut xe);
270 let fxe = eval(&xe, &mut f);
271 if fxe < fxr {
272 simplex[n] = xe.clone();
273 fvals[n] = fxe;
274 } else {
275 simplex[n] = xr.clone();
276 fvals[n] = fxr;
277 }
278 } else {
279 // Contraction. Outside contraction (fxr ≥ f[n-1]) chooses the
280 // reflected side; inside contraction chooses the worst side.
281 let (x_src, f_src) = if fxr < fvals[n] {
282 (&xr, fxr)
283 } else {
284 (&simplex[n], fvals[n])
285 };
286 for j in 0..n {
287 xc[j] = centroid[j] + RHO * (x_src[j] - centroid[j]);
288 }
289 project(&mut xc);
290 let fxc = eval(&xc, &mut f);
291 if fxc < f_src {
292 simplex[n] = xc.clone();
293 fvals[n] = fxc;
294 } else {
295 // Shrink toward the best vertex. Snapshot the best vertex
296 // first to avoid aliasing borrows when mutating
297 // `simplex[i]`.
298 let best = simplex[0].clone();
299 for i in 1..=n {
300 for (j, xj) in simplex[i].iter_mut().enumerate() {
301 *xj = best[j] + SIGMA * (*xj - best[j]);
302 }
303 project(&mut simplex[i]);
304 fvals[i] = eval(&simplex[i], &mut f);
305 }
306 }
307 }
308
309 // Re-sort simplex (O(n log n) — n is small for our use).
310 let mut order: Vec<usize> = (0..=n).collect();
311 order.sort_by(|&a, &b| {
312 fvals[a]
313 .partial_cmp(&fvals[b])
314 .unwrap_or(std::cmp::Ordering::Equal)
315 });
316 simplex = order.iter().map(|&i| simplex[i].clone()).collect();
317 fvals = order.iter().map(|&i| fvals[i]).collect();
318 }
319
320 Ok(NelderMeadResult {
321 x: simplex[0].clone(),
322 fun: fvals[0],
323 iterations: iter,
324 n_evals,
325 self_converged,
326 })
327}
328
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fully specified configuration in one call.
    fn config(
        xatol: f64,
        fatol: f64,
        max_iter: usize,
        step_frac: f64,
        step_abs: f64,
    ) -> NelderMeadConfig {
        NelderMeadConfig {
            xatol,
            fatol,
            max_iter,
            initial_step_frac: step_frac,
            initial_step_abs: step_abs,
        }
    }

    /// Unbounded 1-D parabola f(x) = (x − 3)² must be solved to high accuracy
    /// and flagged as converged.
    #[test]
    fn test_nm_quadratic_1d_converges() {
        fn parabola(p: &[f64]) -> Result<f64, FittingError> {
            Ok((p[0] - 3.0).powi(2))
        }

        let settings = config(1e-10, 1e-12, 5000, 0.1, 0.01);
        let outcome = nelder_mead_minimize(parabola, &[0.0], None, &settings).unwrap();

        assert!((outcome.x[0] - 3.0).abs() < 1e-6, "x = {:?}", outcome.x);
        assert!(outcome.fun < 1e-12);
        assert!(outcome.self_converged);
    }

    /// Classic Rosenbrock valley: f(x,y) = (1-x)² + 100(y-x²)², minimum at
    /// (1,1) with f = 0, started from the traditional (-1.2, 1).
    #[test]
    fn test_nm_rosenbrock_2d() {
        fn rosenbrock(p: &[f64]) -> Result<f64, FittingError> {
            Ok((1.0 - p[0]).powi(2) + 100.0 * (p[1] - p[0].powi(2)).powi(2))
        }

        let settings = config(1e-6, 1e-8, 10_000, 0.1, 0.01);
        let outcome = nelder_mead_minimize(rosenbrock, &[-1.2, 1.0], None, &settings).unwrap();

        let x_ok = (outcome.x[0] - 1.0).abs() < 1e-3;
        let y_ok = (outcome.x[1] - 1.0).abs() < 1e-3;
        assert!(
            x_ok && y_ok,
            "Rosenbrock minimizer off: x = {:?} fun = {}",
            outcome.x,
            outcome.fun
        );
        assert!(outcome.fun < 1e-6);
    }

    /// f(x) = (x − 5)² restricted to [0, 2]: the constrained minimum sits on
    /// the upper wall at x = 2, and the objective must never be called with
    /// a point outside the box.
    #[test]
    fn test_nm_respects_bounds_reflection() {
        const LO: f64 = 0.0;
        const HI: f64 = 2.0;

        fn inside_box(value: f64) -> bool {
            value >= LO - 1e-12 && value <= HI + 1e-12
        }

        fn guarded(p: &[f64]) -> Result<f64, FittingError> {
            assert!(inside_box(p[0]), "NM passed out-of-bounds x = {}", p[0]);
            Ok((p[0] - 5.0).powi(2))
        }

        let settings = NelderMeadConfig::default();
        let limits = [(LO, HI)];
        let outcome = nelder_mead_minimize(guarded, &[1.0], Some(&limits), &settings).unwrap();

        assert!(
            (outcome.x[0] - 2.0).abs() < 1e-2,
            "expected x ≈ 2, got {}",
            outcome.x[0]
        );
        assert!(inside_box(outcome.x[0]));
    }

    /// The objective fails for x < 0.1 and is (x − 0.5)² elsewhere; the
    /// search must settle at x ≈ 0.5 and never return the failing region.
    #[test]
    fn test_nm_handles_infeasible_objective() {
        fn partially_defined(p: &[f64]) -> Result<f64, FittingError> {
            if p[0] < 0.1 {
                return Err(FittingError::EvaluationFailed("x too small".into()));
            }
            Ok((p[0] - 0.5).powi(2))
        }

        let settings = config(1e-8, 1e-10, 5000, 0.2, 0.05);
        let outcome = nelder_mead_minimize(partially_defined, &[1.0], None, &settings).unwrap();

        assert!(
            (outcome.x[0] - 0.5).abs() < 1e-3,
            "expected x ≈ 0.5, got {} (fun = {})",
            outcome.x[0],
            outcome.fun
        );
    }
}
425}