Line data Source code
1 : ! ***********************************************************************
2 : ! Copyright (C) 2026 Niall Miller & The MESA Team
3 : ! ***********************************************************************
4 :
5 : ! ***********************************************************************
6 : ! Hermite interpolation module for spectral energy distributions (SEDs)
7 : ! ***********************************************************************
8 :
9 : module hermite_interp
10 : use const_def, only: dp
11 : use colors_def, only: Colors_General_Info
12 : use colors_utils, only: dilute_flux, find_containing_cell, &
13 : find_nearest_point, find_bracket_index, load_stencil
14 : implicit none
15 :
16 : private
17 : public :: construct_sed_hermite, hermite_tensor_interp3d
18 :
19 : contains
20 :
!---------------------------------------------------------------------------
! Main entry point: build an observed SED by Hermite tensor interpolation in
! (teff, log_g, metallicity), then apply distance dilution via dilute_flux.
!
! The data source is selected by rq%cube_loaded (set at init):
!   .true.  -> interpolate directly in the preloaded 4-D cube on the handle
!   .false. -> assemble a local stencil from individual SED files through
!              the lookup table (construct_sed_from_files)
!
! On return, wavelengths is the wavelength grid of the chosen source and
! fluxes holds the diluted, interpolated flux on that same grid.
!---------------------------------------------------------------------------
subroutine construct_sed_hermite(rq, teff, log_g, metallicity, R, d, &
                                 stellar_model_dir, wavelengths, fluxes)
   type(Colors_General_Info), intent(inout) :: rq
   real(dp), intent(in) :: teff, log_g, metallicity, R, d
   character(len=*), intent(in) :: stellar_model_dir
   real(dp), dimension(:), allocatable, intent(out) :: wavelengths, fluxes

   real(dp), allocatable :: sed_flux(:)   ! interpolated flux before dilution
   integer :: nlam

   if (.not. rq%cube_loaded) then
      ! Slow path: the file-based builder allocates both arrays it returns.
      call construct_sed_from_files(rq, teff, log_g, metallicity, &
                                    stellar_model_dir, sed_flux, wavelengths)
      nlam = size(wavelengths)
   else
      ! Fast path: the whole cube is resident, so a single vectorised call
      ! locates the cell once and fills every wavelength.
      nlam = size(rq%cube_wavelengths)
      allocate (wavelengths(nlam))
      wavelengths = rq%cube_wavelengths
      allocate (sed_flux(nlam))
      call hermite_interp_vector(teff, log_g, metallicity, &
                                 rq%cube_teff_grid, rq%cube_logg_grid, &
                                 rq%cube_meta_grid, &
                                 rq%cube_flux, nlam, sed_flux)
   end if

   ! Scale the surface flux to the observed flux at distance d.
   allocate (fluxes(nlam))
   call dilute_flux(sed_flux, R, d, fluxes)
end subroutine construct_sed_hermite
65 :
!---------------------------------------------------------------------------
! Fallback: build a local sub-cube ("stencil") from individual SED files,
! then interpolate all wavelengths in a single pass.
!
! The stencil spans the bracketing cell plus one extra node on each side
! where the grid allows. As a result, the finite-difference slopes at the
! cell corners use the same neighbours they would use in the full cube.
! The stencil is cached on rq and rebuilt only when the bracketing cell
! (i_t, i_g, i_m) changes.
!
! Outputs:
!   interp_flux  - interpolated (undiluted) flux, one value per wavelength
!   wavelengths  - copy of rq%stencil_wavelengths
!---------------------------------------------------------------------------
subroutine construct_sed_from_files(rq, teff, log_g, metallicity, &
                                    stellar_model_dir, interp_flux, wavelengths)
   use colors_utils, only: resolve_path, build_grid_to_lu_map
   type(Colors_General_Info), intent(inout) :: rq
   real(dp), intent(in) :: teff, log_g, metallicity
   character(len=*), intent(in) :: stellar_model_dir
   real(dp), dimension(:), allocatable, intent(out) :: interp_flux, wavelengths

   integer :: i_t, i_g, i_m                        ! bracketing indices in unique grids
   integer :: lo_t, hi_t, lo_g, hi_g, lo_m, hi_m   ! stencil bounds
   integer :: n_lambda
   ! Deferred length: a fixed-length buffer silently truncates long paths.
   character(len=:), allocatable :: resolved_dir
   logical :: need_reload

   ! Ensure the grid-to-lu mapping exists (built once, then reused).
   if (.not. rq%grid_map_built) call build_grid_to_lu_map(rq)

   ! Find the bracketing cell in the unique grids.
   call find_bracket_index(rq%u_teff, teff, i_t)
   call find_bracket_index(rq%u_logg, log_g, i_g)
   call find_bracket_index(rq%u_meta, metallicity, i_m)

   ! Fortran's .and. does not short-circuit, so test the validity flag first
   ! rather than comparing against cache indices that may never have been set.
   need_reload = .true.
   if (rq%stencil_valid) then
      need_reload = i_t /= rq%stencil_i_t .or. &
                    i_g /= rq%stencil_i_g .or. &
                    i_m /= rq%stencil_i_m
   end if

   if (need_reload) then
      ! The directory is only needed when SED files are actually (re)read.
      resolved_dir = trim(resolve_path(stellar_model_dir))

      call stencil_bounds(i_t, size(rq%u_teff), lo_t, hi_t)
      call stencil_bounds(i_g, size(rq%u_logg), lo_g, hi_g)
      call stencil_bounds(i_m, size(rq%u_meta), lo_m, hi_m)

      ! Load SEDs for every stencil point (using memory cache).
      call load_stencil(rq, resolved_dir, lo_t, hi_t, lo_g, hi_g, lo_m, hi_m)

      ! Store the sub-grid axes on the handle.
      if (allocated(rq%stencil_teff)) deallocate (rq%stencil_teff)
      if (allocated(rq%stencil_logg)) deallocate (rq%stencil_logg)
      if (allocated(rq%stencil_meta)) deallocate (rq%stencil_meta)

      allocate (rq%stencil_teff(hi_t - lo_t + 1))
      allocate (rq%stencil_logg(hi_g - lo_g + 1))
      allocate (rq%stencil_meta(hi_m - lo_m + 1))
      rq%stencil_teff = rq%u_teff(lo_t:hi_t)
      rq%stencil_logg = rq%u_logg(lo_g:hi_g)
      rq%stencil_meta = rq%u_meta(lo_m:hi_m)

      rq%stencil_i_t = i_t
      rq%stencil_i_g = i_g
      rq%stencil_i_m = i_m
      rq%stencil_valid = .true.
   end if

   ! Copy wavelengths to output.
   n_lambda = size(rq%stencil_wavelengths)
   allocate (wavelengths(n_lambda))
   wavelengths = rq%stencil_wavelengths

   ! Interpolate all wavelengths using the cached stencil.
   allocate (interp_flux(n_lambda))
   call hermite_interp_vector(teff, log_g, metallicity, &
                              rq%stencil_teff, rq%stencil_logg, rq%stencil_meta, &
                              rq%stencil_fluxes, n_lambda, interp_flux)

contains

   ! Stencil extent along one axis. The bracketing index widened to
   ! [i_cell-1, i_cell+2] is clipped to [1, n]; an axis with fewer than two
   ! nodes gives [1, 1].
   pure subroutine stencil_bounds(i_cell, n, lo, hi)
      integer, intent(in) :: i_cell, n
      integer, intent(out) :: lo, hi

      if (n < 2) then
         lo = 1
         hi = 1
      else
         lo = max(1, i_cell - 1)
         hi = min(n, i_cell + 2)
      end if
   end subroutine stencil_bounds

end subroutine construct_sed_from_files
166 :
!---------------------------------------------------------------------------
! Vectorised Hermite interpolation over all wavelengths.
!
! The cell location (i_x, i_y, i_z, t_x, t_y, t_z) depends only on
! (x_val, y_val, z_val) and the grids, not on wavelength. It is therefore
! computed once, together with the basis weights, and reused for all
! n_lambda samples.
!
! Single-node axes (n == 1) are treated as degenerate: the query is pinned
! to that node (t = 0) and the remaining axes are still interpolated.
! Without this, the "i >= n" range test always fails for such an axis and
! forces a nearest-point lookup on every axis.
!
! If the query lies outside a non-degenerate axis, the spectrum at the
! nearest grid point is returned unchanged.
!
! f_values_4d is indexed (x, y, z, lambda); planes 1..n_lambda are read.
!---------------------------------------------------------------------------
subroutine hermite_interp_vector(x_val, y_val, z_val, &
                                 x_grid, y_grid, z_grid, &
                                 f_values_4d, n_lambda, result_flux)
   real(dp), intent(in) :: x_val, y_val, z_val
   real(dp), intent(in) :: x_grid(:), y_grid(:), z_grid(:)
   real(dp), intent(in) :: f_values_4d(:,:,:,:)   ! (nx, ny, nz, n_lambda)
   integer, intent(in) :: n_lambda
   real(dp), intent(out) :: result_flux(n_lambda)

   integer :: nx, ny, nz
   integer :: i_x, i_y, i_z
   integer :: ix, iy, iz, lam
   integer :: ix_hi, iy_hi, iz_hi      ! last corner offset per axis (0 or 1)
   real(dp) :: t_x, t_y, t_z
   real(dp) :: dx, dy, dz
   real(dp) :: val, df_dx, df_dy, df_dz
   real(dp) :: h_x(2), h_y(2), h_z(2)
   real(dp) :: hx_d(2), hy_d(2), hz_d(2)
   real(dp) :: wx, wy, wz, wxd, wyd, wzd
   real(dp) :: c_val, c_dx, c_dy, c_dz  ! per-corner, lambda-invariant weights

   nx = size(x_grid)
   ny = size(y_grid)
   nz = size(z_grid)

   ! Find the containing cell (done once for all wavelengths).
   call find_containing_cell(x_val, y_val, z_val, x_grid, y_grid, z_grid, &
                             i_x, i_y, i_z, t_x, t_y, t_z)

   ! Degenerate axes: pin to the only node with t = 0, so the upper-corner
   ! weights vanish and no out-of-range neighbour is ever touched.
   if (nx == 1) then
      i_x = 1
      t_x = 0.0_dp
   end if
   if (ny == 1) then
      i_y = 1
      t_y = 0.0_dp
   end if
   if (nz == 1) then
      i_z = 1
      t_z = 0.0_dp
   end if

   ! Outside a non-degenerate axis: use the nearest point for all wavelengths.
   if ((nx > 1 .and. (i_x < 1 .or. i_x >= nx)) .or. &
       (ny > 1 .and. (i_y < 1 .or. i_y >= ny)) .or. &
       (nz > 1 .and. (i_z < 1 .or. i_z >= nz))) then
      call find_nearest_point(x_val, y_val, z_val, x_grid, y_grid, z_grid, &
                              i_x, i_y, i_z)
      result_flux = f_values_4d(i_x, i_y, i_z, 1:n_lambda)
      return
   end if

   ! A degenerate axis has a single corner (offset 0).
   ix_hi = min(1, nx - 1)
   iy_hi = min(1, ny - 1)
   iz_hi = min(1, nz - 1)

   ! Cell spacing. It stays zero along degenerate axes, where the slope
   ! terms vanish anyway.
   dx = 0.0_dp
   dy = 0.0_dp
   dz = 0.0_dp
   if (nx > 1) dx = x_grid(i_x + 1) - x_grid(i_x)
   if (ny > 1) dy = y_grid(i_y + 1) - y_grid(i_y)
   if (nz > 1) dz = z_grid(i_z + 1) - z_grid(i_z)

   ! Hermite basis functions (same for all wavelengths).
   h_x = [h00(t_x), h01(t_x)]
   hx_d = [h10(t_x), h11(t_x)]
   h_y = [h00(t_y), h01(t_y)]
   hy_d = [h10(t_y), h11(t_y)]
   h_z = [h00(t_z), h01(t_z)]
   hz_d = [h10(t_z), h11(t_z)]

   ! Corner loop outside, lambda innermost: the corner weights are formed
   ! once per corner, with the same multiplication order as the full
   ! per-sample expression.
   result_flux = 0.0_dp
   do iz = 0, iz_hi
      wz = h_z(iz + 1)
      wzd = hz_d(iz + 1)
      do iy = 0, iy_hi
         wy = h_y(iy + 1)
         wyd = hy_d(iy + 1)
         do ix = 0, ix_hi
            wx = h_x(ix + 1)
            wxd = hx_d(ix + 1)
            c_val = wx*wy*wz
            c_dx = wxd*wy*wz*dx
            c_dy = wx*wyd*wz*dy
            c_dz = wx*wy*wzd*dz
            do lam = 1, n_lambda
               val = f_values_4d(i_x + ix, i_y + iy, i_z + iz, lam)

               call compute_derivatives_at_point_4d( &
                  f_values_4d, i_x + ix, i_y + iy, i_z + iz, lam, &
                  nx, ny, nz, x_grid, y_grid, z_grid, df_dx, df_dy, df_dz)

               result_flux(lam) = result_flux(lam) &
                                  + c_val*val &
                                  + c_dx*df_dx &
                                  + c_dy*df_dy &
                                  + c_dz*df_dz
            end do
         end do
      end do
   end do

end subroutine hermite_interp_vector
258 :
!---------------------------------------------------------------------------
! Finite-difference slopes of the wavelength plane f4d(:, :, :, lam) at grid
! node (i, j, k). Values are read in place, so no 3-D slice is extracted.
!
! Along each axis the difference runs between nodes max(idx-1, 1) and
! min(idx+1, n):
!   interior node  -> centred difference over two spacings
!   first node     -> forward difference
!   last node      -> backward difference
!   single node    -> slope taken as zero
!---------------------------------------------------------------------------
subroutine compute_derivatives_at_point_4d(f4d, i, j, k, lam, nx, ny, nz, &
                                           x_grid, y_grid, z_grid, df_dx, df_dy, df_dz)
   real(dp), intent(in) :: f4d(:,:,:,:)
   integer, intent(in) :: i, j, k, lam, nx, ny, nz
   real(dp), intent(in) :: x_grid(:), y_grid(:), z_grid(:)
   real(dp), intent(out) :: df_dx, df_dy, df_dz

   integer :: lo, hi   ! end nodes of the difference stencil on one axis

   ! d/dx
   if (nx == 1) then
      df_dx = 0.0_dp
   else
      lo = max(i - 1, 1)
      hi = min(i + 1, nx)
      df_dx = (f4d(hi, j, k, lam) - f4d(lo, j, k, lam)) / (x_grid(hi) - x_grid(lo))
   end if

   ! d/dy
   if (ny == 1) then
      df_dy = 0.0_dp
   else
      lo = max(j - 1, 1)
      hi = min(j + 1, ny)
      df_dy = (f4d(i, hi, k, lam) - f4d(i, lo, k, lam)) / (y_grid(hi) - y_grid(lo))
   end if

   ! d/dz
   if (nz == 1) then
      df_dz = 0.0_dp
   else
      lo = max(k - 1, 1)
      hi = min(k + 1, nz)
      df_dz = (f4d(i, j, hi, lam) - f4d(i, j, lo, lam)) / (z_grid(hi) - z_grid(lo))
   end if

end subroutine compute_derivatives_at_point_4d
304 :
305 :
!---------------------------------------------------------------------------
! Hermite interpolation of a single 3-D table f_values(x, y, z) at
! (x_val, y_val, z_val).
!
! Each corner of the containing cell contributes its value plus its x/y/z
! finite-difference slopes (scaled by the cell widths). The weights are
! products of cubic Hermite basis functions, with no mixed-derivative terms.
!
! Single-node axes are degenerate: the query is pinned to that node (t = 0)
! and interpolation proceeds along the remaining axes. Without this, the
! "i >= n" range test always fails for such an axis and forces a
! nearest-point lookup on every axis.
! A query outside a non-degenerate axis returns the nearest grid value.
!---------------------------------------------------------------------------
function hermite_tensor_interp3d(x_val, y_val, z_val, x_grid, y_grid, &
                                 z_grid, f_values) result(f_interp)
   real(dp), intent(in) :: x_val, y_val, z_val
   real(dp), intent(in) :: x_grid(:), y_grid(:), z_grid(:)
   real(dp), intent(in) :: f_values(:, :, :)
   real(dp) :: f_interp

   integer :: i_x, i_y, i_z, ix, iy, iz
   integer :: nx, ny, nz
   integer :: ix_hi, iy_hi, iz_hi   ! last corner offset per axis (0 or 1)
   real(dp) :: t_x, t_y, t_z, dx, dy, dz, f_sum
   real(dp) :: dx_values(2, 2, 2), dy_values(2, 2, 2), dz_values(2, 2, 2)
   real(dp) :: values(2, 2, 2)
   real(dp) :: h_x(2), h_y(2), h_z(2)
   real(dp) :: hx_d(2), hy_d(2), hz_d(2)

   nx = size(x_grid)
   ny = size(y_grid)
   nz = size(z_grid)

   call find_containing_cell(x_val, y_val, z_val, x_grid, y_grid, z_grid, &
                             i_x, i_y, i_z, t_x, t_y, t_z)

   ! Degenerate axes: pin to the only node with t = 0.
   if (nx == 1) then
      i_x = 1
      t_x = 0.0_dp
   end if
   if (ny == 1) then
      i_y = 1
      t_y = 0.0_dp
   end if
   if (nz == 1) then
      i_z = 1
      t_z = 0.0_dp
   end if

   ! Outside a non-degenerate axis: fall back to the nearest grid value.
   if ((nx > 1 .and. (i_x < 1 .or. i_x >= nx)) .or. &
       (ny > 1 .and. (i_y < 1 .or. i_y >= ny)) .or. &
       (nz > 1 .and. (i_z < 1 .or. i_z >= nz))) then
      call find_nearest_point(x_val, y_val, z_val, x_grid, y_grid, z_grid, &
                              i_x, i_y, i_z)
      f_interp = f_values(i_x, i_y, i_z)
      return
   end if

   ! A degenerate axis has a single corner (offset 0).
   ix_hi = min(1, nx - 1)
   iy_hi = min(1, ny - 1)
   iz_hi = min(1, nz - 1)

   ! Cell spacing; it stays zero along degenerate axes.
   dx = 0.0_dp
   dy = 0.0_dp
   dz = 0.0_dp
   if (nx > 1) dx = x_grid(i_x + 1) - x_grid(i_x)
   if (ny > 1) dy = y_grid(i_y + 1) - y_grid(i_y)
   if (nz > 1) dz = z_grid(i_z + 1) - z_grid(i_z)

   ! Gather corner values and slopes.
   do iz = 0, iz_hi
      do iy = 0, iy_hi
         do ix = 0, ix_hi
            values(ix+1, iy+1, iz+1) = f_values(i_x+ix, i_y+iy, i_z+iz)
            call compute_derivatives_at_point(f_values, i_x+ix, i_y+iy, i_z+iz, &
                                              nx, ny, nz, x_grid, y_grid, z_grid, &
                                              dx_values(ix+1, iy+1, iz+1), &
                                              dy_values(ix+1, iy+1, iz+1), &
                                              dz_values(ix+1, iy+1, iz+1))
         end do
      end do
   end do

   h_x = [h00(t_x), h01(t_x)]
   hx_d = [h10(t_x), h11(t_x)]
   h_y = [h00(t_y), h01(t_y)]
   hy_d = [h10(t_y), h11(t_y)]
   h_z = [h00(t_z), h01(t_z)]
   hz_d = [h10(t_z), h11(t_z)]

   ! Only the corners gathered above are summed.
   f_sum = 0.0_dp
   do iz = 1, iz_hi + 1
      do iy = 1, iy_hi + 1
         do ix = 1, ix_hi + 1
            f_sum = f_sum + h_x(ix)*h_y(iy)*h_z(iz) * values(ix, iy, iz)
            f_sum = f_sum + hx_d(ix)*h_y(iy)*h_z(iz) * dx * dx_values(ix, iy, iz)
            f_sum = f_sum + h_x(ix)*hy_d(iy)*h_z(iz) * dy * dy_values(ix, iy, iz)
            f_sum = f_sum + h_x(ix)*h_y(iy)*hz_d(iz) * dz * dz_values(ix, iy, iz)
         end do
      end do
   end do

   f_interp = f_sum
end function hermite_tensor_interp3d
375 :
376 :
!---------------------------------------------------------------------------
! Finite-difference slopes of the 3-D table f(x, y, z) at node (i, j, k).
!
! On each axis the difference spans nodes max(idx-1, 1) .. min(idx+1, n).
! This gives a centred difference at interior nodes, a one-sided difference
! at the first and last nodes, and zero when the axis has a single node.
!---------------------------------------------------------------------------
subroutine compute_derivatives_at_point(f, i, j, k, nx, ny, nz, &
                                        x_grid, y_grid, z_grid, &
                                        df_dx, df_dy, df_dz)
   real(dp), intent(in) :: f(:, :, :)
   integer, intent(in) :: i, j, k, nx, ny, nz
   real(dp), intent(in) :: x_grid(:), y_grid(:), z_grid(:)
   real(dp), intent(out) :: df_dx, df_dy, df_dz

   integer :: a, b   ! lower / upper node of the stencil on the current axis

   if (nx == 1) then
      df_dx = 0.0_dp
   else
      a = max(i - 1, 1)
      b = min(i + 1, nx)
      df_dx = (f(b, j, k) - f(a, j, k)) / (x_grid(b) - x_grid(a))
   end if

   if (ny == 1) then
      df_dy = 0.0_dp
   else
      a = max(j - 1, 1)
      b = min(j + 1, ny)
      df_dy = (f(i, b, k) - f(i, a, k)) / (y_grid(b) - y_grid(a))
   end if

   if (nz == 1) then
      df_dz = 0.0_dp
   else
      a = max(k - 1, 1)
      b = min(k + 1, nz)
      df_dz = (f(i, j, b) - f(i, j, a)) / (z_grid(b) - z_grid(a))
   end if
end subroutine compute_derivatives_at_point
415 :
!---------------------------------------------------------------------------
! Hermite basis functions
!---------------------------------------------------------------------------
! Cubic Hermite basis on the cell-local coordinate t (0 at the lower node,
! 1 at the upper node). h00/h01 weight the node values; h10/h11 weight the
! node slopes, which the callers multiply by the cell width.
! These functions are declared elemental (hence pure), so they accept
! scalars or arrays and may be called from pure code.

! h00(t) = 2t^3 - 3t^2 + 1: value weight of the lower node.
pure elemental function h00(t) result(h)
   real(dp), intent(in) :: t
   real(dp) :: h
   h = 2.0_dp*t**3 - 3.0_dp*t**2 + 1.0_dp
end function h00
424 :
! h10(t) = t^3 - 2t^2 + t: slope weight of the lower node
! (zero at both ends, unit derivative at t = 0).
pure elemental function h10(t) result(h)
   real(dp), intent(in) :: t
   real(dp) :: h
   h = t**3 - 2.0_dp*t**2 + t
end function h10
430 :
! h01(t) = -2t^3 + 3t^2: value weight of the upper node
! (0 at t = 0, 1 at t = 1).
pure elemental function h01(t) result(h)
   real(dp), intent(in) :: t
   real(dp) :: h
   h = -2.0_dp*t**3 + 3.0_dp*t**2
end function h01
436 :
! h11(t) = t^3 - t^2: slope weight of the upper node
! (zero at both ends, unit derivative at t = 1).
pure elemental function h11(t) result(h)
   real(dp), intent(in) :: t
   real(dp) :: h
   h = t**3 - t**2
end function h11
442 :
443 : end module hermite_interp
|