nereids_pipeline/
spatial.rs

1//! Spatial mapping: per-pixel fitting with rayon parallelization.
2//!
3//! Applies the single-spectrum fitting pipeline across all pixels in
4//! a hyperspectral neutron imaging dataset to produce 2D composition maps.
5
6use ndarray::{Array2, ArrayView3, s};
7use rayon::prelude::*;
8use std::sync::Arc;
9use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
10
11use nereids_physics::transmission::{
12    InstrumentParams, broadened_cross_sections, unbroadened_cross_sections,
13};
14
15use crate::error::PipelineError;
16use crate::pipeline::SpectrumFitResult;
17
18/// Result of spatial mapping over a 2D image.
19#[derive(Debug)]
20pub struct SpatialResult {
21    /// Fitted areal density maps, one per isotope.
22    /// Each Array2 has shape (height, width).
23    pub density_maps: Vec<Array2<f64>>,
24    /// Uncertainty maps, one per isotope.
25    pub uncertainty_maps: Vec<Array2<f64>>,
26    /// Reduced chi-squared map.
27    pub chi_squared_map: Array2<f64>,
28    /// Convergence map (true = converged).
29    pub converged_map: Array2<bool>,
30    /// Fitted temperature map (K). `Some` when `config.fit_temperature()` is true.
31    pub temperature_map: Option<Array2<f64>>,
32    /// Per-pixel temperature uncertainty map (K, 1-sigma).
33    /// `Some` when `config.fit_temperature()` is true.
34    /// Entries are NaN where uncertainty was unavailable for that pixel.
35    pub temperature_uncertainty_map: Option<Array2<f64>>,
36    /// Isotope labels captured at compute time, one per density map.
37    /// Ensures display labels stay in sync with density data even if the
38    /// user modifies the isotope list after fitting.
39    pub isotope_labels: Vec<String>,
40    /// Per-pixel normalization / signal-scale map (when background fitting is enabled).
41    pub anorm_map: Option<Array2<f64>>,
42    /// Per-pixel background parameter maps.
43    /// Transmission LM uses `[BackA, BackB, BackC]`.
44    /// Counts KL background uses `[b0, b1, alpha_2]`.
45    pub background_maps: Option<[Array2<f64>; 3]>,
46    /// Number of pixels that converged.
47    pub n_converged: usize,
48    /// Total number of pixels fitted.
49    pub n_total: usize,
50    /// Number of pixels where the fitter returned an error (not just
51    /// non-convergence — a hard failure like invalid parameters or NaN
52    /// model output). These pixels have NaN density and false convergence.
53    pub n_failed: usize,
54}
55
56// ── Phase 3: InputData3D + spatial_map_typed ─────────────────────────────
57
58use crate::pipeline::{InputData, SolverConfig, UnifiedFitConfig, fit_spectrum_typed};
59
60/// 3D input data for spatial mapping.
61///
62/// The outer dimension is energy (axis 0), inner dimensions are spatial (y, x).
63/// The two variants correspond to [`InputData`] but carry 3D arrays.
64#[derive(Debug)]
65pub enum InputData3D<'a> {
66    /// Pre-normalized transmission + uncertainty.
67    Transmission {
68        transmission: ArrayView3<'a, f64>,
69        uncertainty: ArrayView3<'a, f64>,
70    },
71    /// Raw detector counts + open beam reference.
72    Counts {
73        sample_counts: ArrayView3<'a, f64>,
74        open_beam_counts: ArrayView3<'a, f64>,
75    },
76    /// Raw detector counts with explicit nuisance spectra.
77    CountsWithNuisance {
78        sample_counts: ArrayView3<'a, f64>,
79        flux: ArrayView3<'a, f64>,
80        background: ArrayView3<'a, f64>,
81    },
82}
83
84impl InputData3D<'_> {
85    /// Shape of the data: (n_energies, height, width).
86    pub(crate) fn shape(&self) -> (usize, usize, usize) {
87        let s = match self {
88            Self::Transmission { transmission, .. } => transmission.shape(),
89            Self::Counts { sample_counts, .. } => sample_counts.shape(),
90            Self::CountsWithNuisance { sample_counts, .. } => sample_counts.shape(),
91        };
92        (s[0], s[1], s[2])
93    }
94}
95
96/// Spatial mapping using the typed input data API.
97///
98/// Dispatches per-pixel fitting based on the `InputData3D` variant:
99/// - **Transmission**: per-pixel LM or KL on transmission values
100/// - **Counts**: per-pixel KL on raw counts (preserves Poisson statistics)
101///
102/// Always returns [`SpatialResult`].
103pub fn spatial_map_typed(
104    input: &InputData3D<'_>,
105    config: &UnifiedFitConfig,
106    dead_pixels: Option<&Array2<bool>>,
107    cancel: Option<&AtomicBool>,
108    progress: Option<&AtomicUsize>,
109) -> Result<SpatialResult, PipelineError> {
110    let (n_energies, height, width) = input.shape();
111    // n_maps = number of density maps to return (one per group or per isotope).
112    let n_maps = config.n_density_params();
113
114    // Validate shapes
115    if n_energies != config.energies().len() {
116        return Err(PipelineError::ShapeMismatch(format!(
117            "input spectral axis ({n_energies}) != config.energies length ({})",
118            config.energies().len(),
119        )));
120    }
121    match input {
122        InputData3D::Transmission {
123            transmission,
124            uncertainty,
125        } => {
126            if uncertainty.shape() != transmission.shape() {
127                return Err(PipelineError::ShapeMismatch(format!(
128                    "uncertainty shape {:?} != transmission shape {:?}",
129                    uncertainty.shape(),
130                    transmission.shape(),
131                )));
132            }
133        }
134        InputData3D::Counts {
135            sample_counts,
136            open_beam_counts,
137        } => {
138            if open_beam_counts.shape() != sample_counts.shape() {
139                return Err(PipelineError::ShapeMismatch(format!(
140                    "open_beam shape {:?} != sample shape {:?}",
141                    open_beam_counts.shape(),
142                    sample_counts.shape(),
143                )));
144            }
145        }
146        InputData3D::CountsWithNuisance {
147            sample_counts,
148            flux,
149            background,
150        } => {
151            if flux.shape() != sample_counts.shape() {
152                return Err(PipelineError::ShapeMismatch(format!(
153                    "flux shape {:?} != sample shape {:?}",
154                    flux.shape(),
155                    sample_counts.shape(),
156                )));
157            }
158            if background.shape() != sample_counts.shape() {
159                return Err(PipelineError::ShapeMismatch(format!(
160                    "background shape {:?} != sample shape {:?}",
161                    background.shape(),
162                    sample_counts.shape(),
163                )));
164            }
165        }
166    }
167    if let Some(dp) = dead_pixels
168        && dp.shape() != [height, width]
169    {
170        return Err(PipelineError::ShapeMismatch(format!(
171            "dead_pixels shape {:?} != spatial dimensions ({height}, {width})",
172            dp.shape(),
173        )));
174    }
175
176    // Collect live pixel coordinates
177    let mut pixel_coords: Vec<(usize, usize)> = Vec::new();
178    for y in 0..height {
179        for x in 0..width {
180            let is_dead = dead_pixels.is_some_and(|m| m[[y, x]]);
181            if !is_dead {
182                pixel_coords.push((y, x));
183            }
184        }
185    }
186
187    let isotope_labels = config.isotope_names().to_vec();
188    let has_background_outputs =
189        config.transmission_background().is_some() || config.counts_background().is_some();
190
191    if cancel.is_some_and(|c| c.load(Ordering::Relaxed)) {
192        return Err(PipelineError::Cancelled);
193    }
194    if pixel_coords.is_empty() {
195        return Ok(SpatialResult {
196            density_maps: (0..n_maps)
197                .map(|_| Array2::zeros((height, width)))
198                .collect(),
199            uncertainty_maps: (0..n_maps)
200                .map(|_| Array2::from_elem((height, width), f64::NAN))
201                .collect(),
202            chi_squared_map: Array2::from_elem((height, width), f64::NAN),
203            converged_map: Array2::from_elem((height, width), false),
204            temperature_map: if config.fit_temperature() {
205                Some(Array2::from_elem((height, width), f64::NAN))
206            } else {
207                None
208            },
209            temperature_uncertainty_map: if config.fit_temperature() {
210                Some(Array2::from_elem((height, width), f64::NAN))
211            } else {
212                None
213            },
214            isotope_labels,
215            anorm_map: if has_background_outputs {
216                Some(Array2::from_elem((height, width), f64::NAN))
217            } else {
218                None
219            },
220            background_maps: if has_background_outputs {
221                Some([
222                    Array2::from_elem((height, width), f64::NAN),
223                    Array2::from_elem((height, width), f64::NAN),
224                    Array2::from_elem((height, width), f64::NAN),
225                ])
226            } else {
227                None
228            },
229            n_converged: 0,
230            n_total: 0,
231            n_failed: 0,
232        });
233    }
234
235    // Transpose data to (height, width, n_energies) for cache locality.
236    let (data_a, data_b, data_c) = match input {
237        InputData3D::Transmission {
238            transmission,
239            uncertainty,
240        } => {
241            let a = transmission
242                .permuted_axes([1, 2, 0])
243                .as_standard_layout()
244                .into_owned();
245            let b = uncertainty
246                .permuted_axes([1, 2, 0])
247                .as_standard_layout()
248                .into_owned();
249            (a, b, None)
250        }
251        InputData3D::Counts {
252            sample_counts,
253            open_beam_counts,
254        } => {
255            let a = sample_counts
256                .permuted_axes([1, 2, 0])
257                .as_standard_layout()
258                .into_owned();
259            let b = open_beam_counts
260                .permuted_axes([1, 2, 0])
261                .as_standard_layout()
262                .into_owned();
263            (a, b, None)
264        }
265        InputData3D::CountsWithNuisance {
266            sample_counts,
267            flux,
268            background,
269        } => {
270            let a = sample_counts
271                .permuted_axes([1, 2, 0])
272                .as_standard_layout()
273                .into_owned();
274            let b = flux
275                .permuted_axes([1, 2, 0])
276                .as_standard_layout()
277                .into_owned();
278            let c = background
279                .permuted_axes([1, 2, 0])
280                .as_standard_layout()
281                .into_owned();
282            (a, b, Some(c))
283        }
284    };
285
286    // Precompute cross-sections once (shared across all pixels)
287    let xs: Arc<Vec<Vec<f64>>> = match config.precomputed_cross_sections().cloned() {
288        Some(cached) => cached,
289        None => {
290            let instrument = config.resolution().map(|r| InstrumentParams {
291                resolution: r.clone(),
292            });
293            let xs_raw = broadened_cross_sections(
294                config.energies(),
295                config.resonance_data(),
296                config.temperature_k(),
297                instrument.as_ref(),
298                cancel,
299            )?;
300            Arc::new(xs_raw)
301        }
302    };
303
304    // When groups are active and temperature is NOT being fitted, collapse
305    // per-member broadened XS into per-group σ_eff once here.  This avoids
306    // redundant O(n_members × n_energies) collapsing inside
307    // build_transmission_model on every per-pixel call.
308    let xs = if !config.fit_temperature()
309        && let (Some(di), Some(dr)) = (&config.density_indices, &config.density_ratios)
310        && xs.len() == di.len()
311        && di.len() == dr.len()
312    {
313        let n_e = xs[0].len();
314        let mut eff = vec![vec![0.0f64; n_e]; n_maps];
315        for ((&idx, &ratio), member_xs) in di.iter().zip(dr.iter()).zip(xs.iter()) {
316            for (j, &sigma) in member_xs.iter().enumerate() {
317                eff[idx][j] += ratio * sigma;
318            }
319        }
320        Arc::new(eff)
321    } else {
322        xs
323    };
324
325    // Precompute unbroadened (base) cross-sections for temperature fitting.
326    // This avoids 74× overhead from redundant Reich-Moore evaluation per
327    // KL iteration (112ms Reich-Moore vs 1.5ms Doppler rebroadening).
328    let fast_config = if config.fit_temperature() {
329        let base_xs: Vec<Vec<f64>> =
330            unbroadened_cross_sections(config.energies(), config.resonance_data(), cancel)
331                .map_err(PipelineError::Transmission)?;
332        config
333            .clone()
334            .with_precomputed_cross_sections(xs)
335            .with_precomputed_base_xs(Arc::new(base_xs))
336            .with_compute_covariance(true)
337    } else {
338        // For non-temperature path: xs is already collapsed to σ_eff when
339        // groups are active, so clear group mapping to prevent double-collapse
340        // inside build_transmission_model.
341        let mut cfg = config.clone();
342        if cfg.density_indices.is_some() {
343            cfg.density_indices = None;
344            cfg.density_ratios = None;
345        }
346        cfg.with_precomputed_cross_sections(xs)
347            .with_compute_covariance(true)
348    };
349
350    // For counts data: spatially average the open beam to get a stable flux
351    // estimate, reducing per-pixel open-beam shot noise.
352    // Without this, per-pixel open-beam shot noise contaminates the flux
353    // estimate and makes KL fits materially noisier.
354    let averaged_flux: Option<Vec<f64>> = if matches!(input, InputData3D::Counts { .. }) {
355        let n_e = data_b.shape()[2]; // data_b is transposed: (h, w, n_e)
356        let mut flux = vec![0.0f64; n_e];
357        let n_live = pixel_coords.len() as f64;
358        if n_live > 0.0 {
359            for &(y, x) in &pixel_coords {
360                let ob_spectrum = data_b.slice(s![y, x, ..]);
361                for (e, &v) in ob_spectrum.iter().enumerate() {
362                    flux[e] += v;
363                }
364            }
365            for v in &mut flux {
366                *v /= n_live;
367            }
368        }
369        Some(flux)
370    } else {
371        None
372    };
373    let background_zeros: Vec<f64> = if matches!(input, InputData3D::Counts { .. }) {
374        vec![0.0f64; data_b.shape()[2]]
375    } else {
376        Vec::new()
377    };
378
379    // Fit all pixels in parallel
380    let failed_count = AtomicUsize::new(0);
381    let results: Vec<((usize, usize), SpectrumFitResult)> = pixel_coords
382        .par_iter()
383        .filter_map(|&(y, x)| {
384            if cancel.is_some_and(|c| c.load(Ordering::Relaxed)) {
385                return None;
386            }
387
388            let spectrum_a: Vec<f64> = data_a.slice(s![y, x, ..]).to_vec();
389
390            // Build per-pixel 1D InputData
391            let pixel_input = match input {
392                InputData3D::Counts { .. } => {
393                    let sample_clamped: Vec<f64> = spectrum_a.iter().map(|&v| v.max(0.0)).collect();
394                    let ob_spectrum: Vec<f64> = data_b.slice(s![y, x, ..]).to_vec();
395
396                    // Check effective solver: KL uses CountsWithNuisance
397                    // (averaged flux), LM uses raw Counts (auto-converts to
398                    // transmission inside fit_spectrum_typed).
399                    let effective = fast_config.effective_solver(&InputData::Counts {
400                        sample_counts: sample_clamped.clone(),
401                        open_beam_counts: ob_spectrum.clone(),
402                    });
403                    match effective {
404                        SolverConfig::PoissonKL(_) => InputData::CountsWithNuisance {
405                            sample_counts: sample_clamped,
406                            flux: averaged_flux.as_ref().unwrap().clone(),
407                            // Raw-count spatial path currently assumes zero
408                            // detector background unless the caller provides
409                            // explicit nuisance spectra.
410                            background: background_zeros.clone(),
411                        },
412                        _ => InputData::Counts {
413                            sample_counts: sample_clamped,
414                            open_beam_counts: ob_spectrum,
415                        },
416                    }
417                }
418                InputData3D::CountsWithNuisance { .. } => InputData::CountsWithNuisance {
419                    sample_counts: spectrum_a.iter().map(|&v| v.max(0.0)).collect(),
420                    flux: data_b.slice(s![y, x, ..]).to_vec(),
421                    background: data_c
422                        .as_ref()
423                        .expect("CountsWithNuisance requires background cube")
424                        .slice(s![y, x, ..])
425                        .to_vec(),
426                },
427                InputData3D::Transmission { .. } => {
428                    let spectrum_b: Vec<f64> = data_b
429                        .slice(s![y, x, ..])
430                        .iter()
431                        .map(|&v| v.max(1e-10))
432                        .collect();
433                    InputData::Transmission {
434                        transmission: spectrum_a,
435                        uncertainty: spectrum_b,
436                    }
437                }
438            };
439
440            let out = match fit_spectrum_typed(&pixel_input, &fast_config) {
441                Ok(result) => Some(((y, x), result)),
442                Err(_) => {
443                    failed_count.fetch_add(1, Ordering::Relaxed);
444                    None
445                }
446            };
447            if let Some(p) = progress {
448                p.fetch_add(1, Ordering::Relaxed);
449            }
450            out
451        })
452        .collect();
453
454    if cancel.is_some_and(|c| c.load(Ordering::Relaxed)) && results.is_empty() {
455        return Err(PipelineError::Cancelled);
456    }
457
458    // Assemble output maps
459    let mut density_maps: Vec<Array2<f64>> = (0..n_maps)
460        .map(|_| Array2::from_elem((height, width), f64::NAN))
461        .collect();
462    let mut uncertainty_maps: Vec<Array2<f64>> = (0..n_maps)
463        .map(|_| Array2::from_elem((height, width), f64::NAN))
464        .collect();
465    let mut chi_squared_map = Array2::from_elem((height, width), f64::NAN);
466    let mut converged_map = Array2::from_elem((height, width), false);
467    let mut anorm_map: Option<Array2<f64>> = if has_background_outputs {
468        Some(Array2::from_elem((height, width), f64::NAN))
469    } else {
470        None
471    };
472    let mut background_maps: Option<[Array2<f64>; 3]> = if has_background_outputs {
473        Some([
474            Array2::from_elem((height, width), f64::NAN),
475            Array2::from_elem((height, width), f64::NAN),
476            Array2::from_elem((height, width), f64::NAN),
477        ])
478    } else {
479        None
480    };
481    let mut n_converged = 0;
482    let mut temperature_map: Option<Array2<f64>> = if config.fit_temperature() {
483        Some(Array2::from_elem((height, width), f64::NAN))
484    } else {
485        None
486    };
487    let mut temperature_uncertainty_map: Option<Array2<f64>> = if config.fit_temperature() {
488        Some(Array2::from_elem((height, width), f64::NAN))
489    } else {
490        None
491    };
492
493    for ((y, x), result) in &results {
494        for i in 0..n_maps {
495            density_maps[i][[*y, *x]] = result.densities[i];
496            if let Some(ref unc) = result.uncertainties {
497                uncertainty_maps[i][[*y, *x]] = unc[i];
498            }
499        }
500        chi_squared_map[[*y, *x]] = result.reduced_chi_squared;
501        converged_map[[*y, *x]] = result.converged;
502        if let (Some(t_map), Some(t)) = (&mut temperature_map, result.temperature_k) {
503            t_map[[*y, *x]] = t;
504        }
505        if let (Some(tu_map), Some(tu)) =
506            (&mut temperature_uncertainty_map, result.temperature_k_unc)
507        {
508            tu_map[[*y, *x]] = tu;
509        }
510        if let Some(ref mut a_map) = anorm_map {
511            a_map[[*y, *x]] = result.anorm;
512        }
513        if let Some(ref mut bg_maps) = background_maps {
514            bg_maps[0][[*y, *x]] = result.background[0];
515            bg_maps[1][[*y, *x]] = result.background[1];
516            bg_maps[2][[*y, *x]] = result.background[2];
517        }
518        if result.converged {
519            n_converged += 1;
520        }
521    }
522
523    Ok(SpatialResult {
524        density_maps,
525        uncertainty_maps,
526        chi_squared_map,
527        converged_map,
528        temperature_map,
529        temperature_uncertainty_map,
530        isotope_labels,
531        anorm_map,
532        background_maps,
533        n_converged,
534        n_total: pixel_coords.len(),
535        n_failed: failed_count.load(Ordering::Relaxed),
536    })
537}
538
539// ── End Phase 3 ──────────────────────────────────────────────────────────
540
541#[cfg(test)]
542mod tests {
543    use super::*;
544    use ndarray::{Array2, Array3};
545    use nereids_fitting::lm::{FitModel, LmConfig};
546    use nereids_fitting::poisson::PoissonConfig;
547    use nereids_fitting::transmission_model::PrecomputedTransmissionModel;
548
549    use crate::pipeline::{SolverConfig, UnifiedFitConfig};
550    use crate::test_helpers::{synthetic_single_resonance, u238_single_resonance};
551
552    /// Build a 4x4 synthetic transmission stack from known density.
553    fn synthetic_4x4_transmission(
554        res_data: &nereids_endf::resonance::ResonanceData,
555        true_density: f64,
556        energies: &[f64],
557    ) -> (Array3<f64>, Array3<f64>) {
558        let n_e = energies.len();
559        let xs = nereids_physics::transmission::broadened_cross_sections(
560            energies,
561            std::slice::from_ref(res_data),
562            0.0,
563            None,
564            None,
565        )
566        .unwrap();
567        let model = PrecomputedTransmissionModel {
568            cross_sections: Arc::new(xs),
569            density_indices: Arc::new(vec![0]),
570            energies: None,
571            instrument: None,
572        };
573        let t_1d = model.evaluate(&[true_density]).unwrap();
574        let sigma_1d: Vec<f64> = t_1d.iter().map(|&v| 0.01 * v.max(0.01)).collect();
575
576        // Fill a 4x4 grid with the same spectrum
577        let mut t_3d = Array3::zeros((n_e, 4, 4));
578        let mut u_3d = Array3::zeros((n_e, 4, 4));
579        for y in 0..4 {
580            for x in 0..4 {
581                for (i, (&t, &s)) in t_1d.iter().zip(sigma_1d.iter()).enumerate() {
582                    t_3d[[i, y, x]] = t;
583                    u_3d[[i, y, x]] = s;
584                }
585            }
586        }
587        (t_3d, u_3d)
588    }
589
590    /// Build a 4x4 synthetic counts stack from known density.
591    fn synthetic_4x4_counts(
592        res_data: &nereids_endf::resonance::ResonanceData,
593        true_density: f64,
594        energies: &[f64],
595        i0: f64,
596    ) -> (Array3<f64>, Array3<f64>) {
597        let (t_3d, _) = synthetic_4x4_transmission(res_data, true_density, energies);
598        let n_e = energies.len();
599        let mut sample = Array3::zeros((n_e, 4, 4));
600        let mut ob = Array3::zeros((n_e, 4, 4));
601        for y in 0..4 {
602            for x in 0..4 {
603                for i in 0..n_e {
604                    ob[[i, y, x]] = i0;
605                    sample[[i, y, x]] = (t_3d[[i, y, x]] * i0).round().max(0.0);
606                }
607            }
608        }
609        (sample, ob)
610    }
611
612    #[test]
613    fn test_spatial_map_typed_transmission_lm() {
614        let data = u238_single_resonance();
615        let true_density = 0.0005;
616        let energies: Vec<f64> = (0..101).map(|i| 1.0 + (i as f64) * 0.1).collect();
617        let (t_3d, u_3d) = synthetic_4x4_transmission(&data, true_density, &energies);
618
619        let config = UnifiedFitConfig::new(
620            energies,
621            vec![data],
622            vec!["U-238".into()],
623            0.0,
624            None,
625            vec![0.001],
626        )
627        .unwrap()
628        .with_solver(SolverConfig::LevenbergMarquardt(LmConfig::default()));
629
630        let input = InputData3D::Transmission {
631            transmission: t_3d.view(),
632            uncertainty: u_3d.view(),
633        };
634
635        let result = spatial_map_typed(&input, &config, None, None, None).unwrap();
636        assert_eq!(result.n_total, 16);
637        assert!(result.n_converged >= 14, "Most pixels should converge");
638
639        // Check mean density of converged pixels
640        let d = &result.density_maps[0];
641        let conv = &result.converged_map;
642        let mean: f64 = d
643            .iter()
644            .zip(conv.iter())
645            .filter(|(_, c)| **c)
646            .map(|(d, _)| *d)
647            .sum::<f64>()
648            / result.n_converged as f64;
649        assert!(
650            (mean - true_density).abs() / true_density < 0.05,
651            "mean density: {mean}, true: {true_density}"
652        );
653    }
654
655    #[test]
656    fn test_spatial_map_typed_counts_kl() {
657        let data = u238_single_resonance();
658        let true_density = 0.0005;
659        let energies: Vec<f64> = (0..101).map(|i| 1.0 + (i as f64) * 0.1).collect();
660        let (sample, ob) = synthetic_4x4_counts(&data, true_density, &energies, 1000.0);
661
662        let config = UnifiedFitConfig::new(
663            energies,
664            vec![data],
665            vec!["U-238".into()],
666            0.0,
667            None,
668            vec![0.001],
669        )
670        .unwrap()
671        .with_solver(SolverConfig::PoissonKL(PoissonConfig::default()));
672
673        let input = InputData3D::Counts {
674            sample_counts: sample.view(),
675            open_beam_counts: ob.view(),
676        };
677
678        let result = spatial_map_typed(&input, &config, None, None, None).unwrap();
679        assert_eq!(result.n_total, 16);
680        assert!(
681            result.n_converged >= 14,
682            "Most pixels should converge with KL"
683        );
684
685        let d = &result.density_maps[0];
686        let conv = &result.converged_map;
687        let mean: f64 = d
688            .iter()
689            .zip(conv.iter())
690            .filter(|(_, c)| **c)
691            .map(|(d, _)| *d)
692            .sum::<f64>()
693            / result.n_converged.max(1) as f64;
694        assert!(
695            (mean - true_density).abs() / true_density < 0.10,
696            "KL mean density: {mean}, true: {true_density}"
697        );
698    }
699
700    #[test]
701    fn test_spatial_map_typed_counts_kl_low_counts() {
702        // I0=10: the regime where KL excels
703        let data = u238_single_resonance();
704        let true_density = 0.0005;
705        let energies: Vec<f64> = (0..101).map(|i| 1.0 + (i as f64) * 0.1).collect();
706        let (sample, ob) = synthetic_4x4_counts(&data, true_density, &energies, 10.0);
707
708        let config = UnifiedFitConfig::new(
709            energies,
710            vec![data],
711            vec!["U-238".into()],
712            0.0,
713            None,
714            vec![0.001],
715        )
716        .unwrap(); // Auto solver → KL for counts
717
718        let input = InputData3D::Counts {
719            sample_counts: sample.view(),
720            open_beam_counts: ob.view(),
721        };
722
723        let result = spatial_map_typed(&input, &config, None, None, None).unwrap();
724        assert_eq!(result.n_total, 16);
725        // At I0=10, KL should still converge for most pixels
726        assert!(
727            result.n_converged >= 10,
728            "KL at I0=10: only {}/{} converged",
729            result.n_converged,
730            result.n_total
731        );
732    }
733
734    #[test]
735    fn test_spatial_map_typed_counts_with_nuisance_surfaces_background_maps() {
736        let data = u238_single_resonance();
737        let true_density = 0.002;
738        let true_alpha1 = 0.92;
739        let true_alpha2 = 1.35;
740        let energies: Vec<f64> = (0..101).map(|i| 1.0 + (i as f64) * 0.1).collect();
741        let (t_3d, _) = synthetic_4x4_transmission(&data, true_density, &energies);
742        let n_e = energies.len();
743
744        let mut sample = Array3::zeros((n_e, 4, 4));
745        let mut flux = Array3::zeros((n_e, 4, 4));
746        let mut background = Array3::zeros((n_e, 4, 4));
747        for y in 0..4 {
748            for x in 0..4 {
749                for (i, &e) in energies.iter().enumerate() {
750                    let bg = 30.0 + 8.0 / e.sqrt();
751                    flux[[i, y, x]] = 120.0;
752                    background[[i, y, x]] = bg;
753                    sample[[i, y, x]] =
754                        true_alpha1 * flux[[i, y, x]] * t_3d[[i, y, x]] + true_alpha2 * bg;
755                }
756            }
757        }
758
759        let config = UnifiedFitConfig::new(
760            energies,
761            vec![data],
762            vec!["U-238".into()],
763            0.0,
764            None,
765            vec![0.001],
766        )
767        .unwrap()
768        .with_solver(SolverConfig::PoissonKL(PoissonConfig::default()))
769        .with_counts_background(crate::pipeline::CountsBackgroundConfig {
770            alpha_1_init: 1.0,
771            alpha_2_init: 1.0,
772            fit_alpha_1: true,
773            fit_alpha_2: true,
774        });
775
776        let input = InputData3D::CountsWithNuisance {
777            sample_counts: sample.view(),
778            flux: flux.view(),
779            background: background.view(),
780        };
781
782        let result = spatial_map_typed(&input, &config, None, None, None).unwrap();
783        assert_eq!(result.n_total, 16);
784        assert_eq!(result.n_converged, 16);
785        assert!(
786            result.anorm_map.is_some(),
787            "counts background runs should surface anorm_map"
788        );
789        assert!(
790            result.background_maps.is_some(),
791            "counts background runs should surface background_maps"
792        );
793        let mean_alpha1 = result
794            .anorm_map
795            .as_ref()
796            .unwrap()
797            .iter()
798            .copied()
799            .sum::<f64>()
800            / 16.0;
801        let mean_alpha2 = result.background_maps.as_ref().unwrap()[2]
802            .iter()
803            .copied()
804            .sum::<f64>()
805            / 16.0;
806        assert!((mean_alpha1 - true_alpha1).abs() < 5e-3);
807        assert!((mean_alpha2 - true_alpha2).abs() < 5e-3);
808    }
809
810    #[test]
811    fn test_spatial_map_typed_dead_pixels() {
812        let data = u238_single_resonance();
813        let energies: Vec<f64> = (0..51).map(|i| 1.0 + (i as f64) * 0.2).collect();
814        let (t_3d, u_3d) = synthetic_4x4_transmission(&data, 0.0005, &energies);
815
816        let config = UnifiedFitConfig::new(
817            energies,
818            vec![data],
819            vec!["U-238".into()],
820            0.0,
821            None,
822            vec![0.001],
823        )
824        .unwrap();
825
826        // Mask half the pixels as dead
827        let mut dead = Array2::from_elem((4, 4), false);
828        for y in 0..2 {
829            for x in 0..4 {
830                dead[[y, x]] = true;
831            }
832        }
833
834        let input = InputData3D::Transmission {
835            transmission: t_3d.view(),
836            uncertainty: u_3d.view(),
837        };
838
839        let result = spatial_map_typed(&input, &config, Some(&dead), None, None).unwrap();
840        assert_eq!(result.n_total, 8, "Only 8 live pixels");
841    }
842
843    #[test]
844    fn test_spatial_map_failed_pixels_remain_nan() {
845        let data = u238_single_resonance();
846        let true_density = 0.0005;
847        let energies: Vec<f64> = (0..51).map(|i| 1.0 + (i as f64) * 0.2).collect();
848        let (sample, ob) = synthetic_4x4_counts(&data, true_density, &energies, 1000.0);
849
850        let config = UnifiedFitConfig::new(
851            energies,
852            vec![data],
853            vec!["U-238".into()],
854            0.0,
855            None,
856            vec![0.001],
857        )
858        .unwrap()
859        .with_solver(SolverConfig::PoissonKL(PoissonConfig::default()))
860        .with_counts_background(crate::pipeline::CountsBackgroundConfig {
861            alpha_1_init: 1.0,
862            alpha_2_init: 1.0,
863            fit_alpha_1: false,
864            fit_alpha_2: true,
865        });
866
867        let input = InputData3D::Counts {
868            sample_counts: sample.view(),
869            open_beam_counts: ob.view(),
870        };
871
872        let result = spatial_map_typed(&input, &config, None, None, None).unwrap();
873        assert_eq!(result.n_converged, 0);
874        assert!(
875            result.density_maps[0].iter().all(|v| v.is_nan()),
876            "failed pixels must remain NaN rather than looking like zero-density fits"
877        );
878        assert!(
879            result.chi_squared_map.iter().all(|v| v.is_nan()),
880            "failed pixels must retain NaN chi-squared"
881        );
882    }
883
884    /// Spatial map with isotope groups: 2 isotopes in 1 group on a 2×2 grid.
885    /// Verifies group-level density recovery and that only 1 density map is returned.
886    #[test]
887    fn test_spatial_map_grouped() {
888        let rd1 = synthetic_single_resonance(92, 235, 233.025, 5.0);
889        let rd2 = synthetic_single_resonance(92, 238, 236.006, 7.0);
890
891        let iso1 = nereids_core::types::Isotope::new(92, 235).unwrap();
892        let iso2 = nereids_core::types::Isotope::new(92, 238).unwrap();
893        let group = nereids_core::types::IsotopeGroup::custom(
894            "U (60/40)".into(),
895            vec![(iso1, 0.6), (iso2, 0.4)],
896        )
897        .unwrap();
898
899        let energies: Vec<f64> = (0..201).map(|i| 1.0 + (i as f64) * 0.05).collect();
900        let n_e = energies.len();
901        let true_density = 0.0005;
902
903        // Generate synthetic transmission for the group
904        let sample = nereids_physics::transmission::SampleParams::new(
905            0.0,
906            vec![
907                (rd1.clone(), true_density * 0.6),
908                (rd2.clone(), true_density * 0.4),
909            ],
910        )
911        .unwrap();
912        let t_1d = nereids_physics::transmission::forward_model(&energies, &sample, None).unwrap();
913        let s_1d: Vec<f64> = t_1d.iter().map(|&v| 0.01 * v.max(0.01)).collect();
914
915        // Fill 2×2 grid
916        let mut t_3d = Array3::zeros((n_e, 2, 2));
917        let mut u_3d = Array3::zeros((n_e, 2, 2));
918        for y in 0..2 {
919            for x in 0..2 {
920                for (i, (&t, &s)) in t_1d.iter().zip(s_1d.iter()).enumerate() {
921                    t_3d[[i, y, x]] = t;
922                    u_3d[[i, y, x]] = s;
923                }
924            }
925        }
926
927        let config = UnifiedFitConfig::new(
928            energies,
929            vec![rd1.clone()],
930            vec!["placeholder".into()],
931            0.0,
932            None,
933            vec![0.001],
934        )
935        .unwrap()
936        .with_groups(&[(&group, &[rd1, rd2])], vec![0.001])
937        .unwrap()
938        .with_solver(SolverConfig::LevenbergMarquardt(LmConfig::default()));
939
940        let input = InputData3D::Transmission {
941            transmission: t_3d.view(),
942            uncertainty: u_3d.view(),
943        };
944
945        let result = spatial_map_typed(&input, &config, None, None, None).unwrap();
946
947        // Should have 1 density map (1 group), not 2
948        assert_eq!(
949            result.density_maps.len(),
950            1,
951            "should have 1 group density map"
952        );
953        assert_eq!(result.isotope_labels, vec!["U (60/40)"]);
954        assert_eq!(result.n_total, 4);
955
956        // All pixels should recover true density within 5%
957        for y in 0..2 {
958            for x in 0..2 {
959                let fitted = result.density_maps[0][[y, x]];
960                let rel_error = (fitted - true_density).abs() / true_density;
961                assert!(
962                    rel_error < 0.05,
963                    "pixel ({y},{x}): fitted={fitted}, true={true_density}, rel_error={rel_error}"
964                );
965            }
966        }
967    }
968
969    // ── Phase 3: Spatial uncertainty propagation tests ──────────────────────
970
971    /// Spatial LM transmission fit populates density uncertainty maps.
972    #[test]
973    fn test_spatial_lm_populates_density_uncertainty() {
974        let rd = u238_single_resonance();
975        let energies: Vec<f64> = (0..101).map(|i| 1.0 + (i as f64) * 0.1).collect();
976        let (mut t_3d, u_3d) = synthetic_4x4_transmission(&rd, 0.001, &energies);
977        // Add deterministic pseudo-noise so reduced chi-squared > 0
978        // (a perfect fit gives chi2r=0, zeroing covariance).
979        for y in 0..4 {
980            for x in 0..4 {
981                for e in 0..energies.len() {
982                    let noise = 0.002 * ((e * 7 + y * 13 + x * 29) % 17) as f64 / 17.0 - 0.001;
983                    t_3d[[e, y, x]] = (t_3d[[e, y, x]] + noise).max(0.001);
984                }
985            }
986        }
987        let data = InputData3D::Transmission {
988            transmission: t_3d.view(),
989            uncertainty: u_3d.view(),
990        };
991        let config = UnifiedFitConfig::new(
992            energies,
993            vec![rd],
994            vec!["U-238".into()],
995            0.0,
996            None,
997            vec![0.0005],
998        )
999        .unwrap()
1000        .with_solver(SolverConfig::LevenbergMarquardt(LmConfig::default()));
1001
1002        let result = spatial_map_typed(&data, &config, None, None, None).unwrap();
1003        assert!(result.n_converged > 0, "some pixels should converge");
1004        // Uncertainty maps should have finite positive values for converged pixels.
1005        let unc_map = &result.uncertainty_maps[0];
1006        let conv_map = &result.converged_map;
1007        let mut n_finite = 0;
1008        for y in 0..4 {
1009            for x in 0..4 {
1010                if conv_map[[y, x]] {
1011                    let u = unc_map[[y, x]];
1012                    assert!(
1013                        u.is_finite() && u > 0.0,
1014                        "LM density unc at ({y},{x}) should be finite+positive, got {u}"
1015                    );
1016                    n_finite += 1;
1017                }
1018            }
1019        }
1020        assert!(
1021            n_finite > 0,
1022            "at least one converged pixel should have finite unc"
1023        );
1024    }
1025
1026    /// Spatial KL counts fit populates density uncertainty maps.
1027    #[test]
1028    fn test_spatial_kl_populates_density_uncertainty() {
1029        let rd = u238_single_resonance();
1030        let energies: Vec<f64> = (0..101).map(|i| 1.0 + (i as f64) * 0.1).collect();
1031        let (t_3d, _) = synthetic_4x4_transmission(&rd, 0.001, &energies);
1032        // Convert to counts: OB=1000, sample = OB * T
1033        let ob_3d = Array3::from_elem(t_3d.raw_dim(), 1000.0);
1034        let sample_3d = &t_3d * &ob_3d;
1035        let data = InputData3D::Counts {
1036            sample_counts: sample_3d.view(),
1037            open_beam_counts: ob_3d.view(),
1038        };
1039        let config = UnifiedFitConfig::new(
1040            energies,
1041            vec![rd],
1042            vec!["U-238".into()],
1043            0.0,
1044            None,
1045            vec![0.0005],
1046        )
1047        .unwrap()
1048        .with_solver(SolverConfig::PoissonKL(PoissonConfig::default()));
1049
1050        let result = spatial_map_typed(&data, &config, None, None, None).unwrap();
1051        assert!(result.n_converged > 0);
1052        let unc_map = &result.uncertainty_maps[0];
1053        let conv_map = &result.converged_map;
1054        let mut n_finite = 0;
1055        for y in 0..4 {
1056            for x in 0..4 {
1057                if conv_map[[y, x]] {
1058                    let u = unc_map[[y, x]];
1059                    assert!(
1060                        u.is_finite() && u > 0.0,
1061                        "KL density unc at ({y},{x}) should be finite+positive, got {u}"
1062                    );
1063                    n_finite += 1;
1064                }
1065            }
1066        }
1067        assert!(n_finite > 0);
1068    }
1069
1070    /// Spatial temperature-fitting populates temperature_uncertainty_map.
1071    #[test]
1072    fn test_spatial_temperature_uncertainty_map() {
1073        let rd = u238_single_resonance();
1074        let energies: Vec<f64> = (0..101).map(|i| 4.0 + (i as f64) * 0.05).collect();
1075        let (mut t_3d, u_3d) = synthetic_4x4_transmission(&rd, 0.001, &energies);
1076        // Add pseudo-noise for nonzero chi2r.
1077        for y in 0..4 {
1078            for x in 0..4 {
1079                for e in 0..energies.len() {
1080                    let noise = 0.002 * ((e * 7 + y * 13 + x * 29) % 17) as f64 / 17.0 - 0.001;
1081                    t_3d[[e, y, x]] = (t_3d[[e, y, x]] + noise).max(0.001);
1082                }
1083            }
1084        }
1085        let data = InputData3D::Transmission {
1086            transmission: t_3d.view(),
1087            uncertainty: u_3d.view(),
1088        };
1089        let config = UnifiedFitConfig::new(
1090            energies,
1091            vec![rd],
1092            vec!["U-238".into()],
1093            300.0,
1094            None,
1095            vec![0.0005],
1096        )
1097        .unwrap()
1098        .with_solver(SolverConfig::LevenbergMarquardt(LmConfig::default()))
1099        .with_fit_temperature(true);
1100
1101        let result = spatial_map_typed(&data, &config, None, None, None).unwrap();
1102        assert!(result.temperature_map.is_some());
1103        let tu_map = result
1104            .temperature_uncertainty_map
1105            .as_ref()
1106            .expect("temperature_uncertainty_map should be Some when fit_temperature=true");
1107        assert_eq!(tu_map.shape(), [4, 4]);
1108        // At least some converged pixels should have finite temperature uncertainty.
1109        let mut n_finite = 0;
1110        for y in 0..4 {
1111            for x in 0..4 {
1112                if result.converged_map[[y, x]] {
1113                    let tu = tu_map[[y, x]];
1114                    if tu.is_finite() && tu > 0.0 {
1115                        n_finite += 1;
1116                    }
1117                }
1118            }
1119        }
1120        assert!(
1121            n_finite > 0,
1122            "at least one converged pixel should have finite temperature uncertainty"
1123        );
1124    }
1125
1126    /// Unconverged pixels remain NaN, not zero-filled.
1127    #[test]
1128    fn test_spatial_unconverged_pixels_are_nan() {
1129        let rd = u238_single_resonance();
1130        let energies: Vec<f64> = (0..101).map(|i| 1.0 + (i as f64) * 0.1).collect();
1131        // Create data where pixel (0,0) is dead (all zeros)
1132        let (mut t_3d, mut u_3d) = synthetic_4x4_transmission(&rd, 0.001, &energies);
1133        for e in 0..energies.len() {
1134            t_3d[[e, 0, 0]] = 0.0;
1135            u_3d[[e, 0, 0]] = f64::INFINITY;
1136        }
1137        let data = InputData3D::Transmission {
1138            transmission: t_3d.view(),
1139            uncertainty: u_3d.view(),
1140        };
1141        let config = UnifiedFitConfig::new(
1142            energies,
1143            vec![rd],
1144            vec!["U-238".into()],
1145            0.0,
1146            None,
1147            vec![0.0005],
1148        )
1149        .unwrap()
1150        .with_solver(SolverConfig::LevenbergMarquardt(LmConfig::default()));
1151
1152        let result = spatial_map_typed(&data, &config, None, None, None).unwrap();
1153        // Pixel (0,0) should not converge and uncertainty should remain NaN.
1154        if !result.converged_map[[0, 0]] {
1155            let u = result.uncertainty_maps[0][[0, 0]];
1156            assert!(
1157                u.is_nan(),
1158                "unconverged pixel uncertainty should be NaN, got {u}"
1159            );
1160        }
1161    }
1162}