Line data Source code
1 : ! ***********************************************************************
2 : !
3 : ! Copyright (C) 2025 Niall Miller & The MESA Team
4 : !
5 : ! This program is free software: you can redistribute it and/or modify
6 : ! it under the terms of the GNU Lesser General Public License
7 : ! as published by the Free Software Foundation,
8 : ! either version 3 of the License, or (at your option) any later version.
9 : !
10 : ! This program is distributed in the hope that it will be useful,
11 : ! but WITHOUT ANY WARRANTY; without even the implied warranty of
12 : ! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
13 : ! See the GNU Lesser General Public License for more details.
14 : !
15 : ! You should have received a copy of the GNU Lesser General Public License
16 : ! along with this program. If not, see <https://www.gnu.org/licenses/>.
17 : !
18 : ! ***********************************************************************
19 :
20 : ! ***********************************************************************
21 : ! Hermite interpolation module for spectral energy distributions (SEDs)
22 : ! ***********************************************************************
23 :
24 : module hermite_interp
25 : use const_def, only: dp
26 : use colors_utils, only: dilute_flux
27 : implicit none
28 :
29 : private
30 : public :: construct_sed_hermite, hermite_tensor_interp3d
31 :
32 : contains
33 :
!---------------------------------------------------------------------------
! Main entry point: Construct a SED using Hermite tensor interpolation
!
! Loads '<stellar_model_dir>/flux_cube.bin', interpolates the flux cube at
! (teff, log_g, metallicity) one wavelength at a time, and hands the result
! to dilute_flux together with R and d.
!
! Arguments:
!   teff, log_g, metallicity   point at which the flux cube is interpolated
!   R, d                       forwarded unchanged to dilute_flux
!   file_names, lu_teff, lu_logg, lu_meta
!                              not referenced here; kept so the argument
!                              list stays as existing callers expect
!   stellar_model_dir          directory that holds flux_cube.bin
!   wavelengths (out)          wavelength grid read from the binary file
!   fluxes (out)               output of dilute_flux, one value per wavelength
!
! If the binary file cannot be loaded (load_binary_data prints the reason),
! both outputs are returned allocated with size zero.
!
! NOTE(review): the whole cube is re-read from disk on every call.
!---------------------------------------------------------------------------
subroutine construct_sed_hermite(teff, log_g, metallicity, R, d, file_names, &
                                 lu_teff, lu_logg, lu_meta, stellar_model_dir, &
                                 wavelengths, fluxes)
   real(dp), intent(in) :: teff, log_g, metallicity, R, d
   real(dp), intent(in) :: lu_teff(:), lu_logg(:), lu_meta(:)
   character(len=*), intent(in) :: stellar_model_dir
   character(len=100), intent(in) :: file_names(:)
   real(dp), dimension(:), allocatable, intent(out) :: wavelengths, fluxes

   integer :: i, n_lambda, status
   real(dp), dimension(:), allocatable :: interp_flux
   real(dp), dimension(:, :, :, :), allocatable :: precomputed_flux_cube

   ! Parameter grids
   real(dp), allocatable :: teff_grid(:), logg_grid(:), meta_grid(:)

   ! The path is passed as an expression so that it is never truncated by a
   ! fixed-length buffer.
   call load_binary_data(trim(stellar_model_dir)//'/flux_cube.bin', &
                         teff_grid, logg_grid, meta_grid, &
                         wavelengths, precomputed_flux_cube, status)

   if (status /= 0) then
      ! A partially loaded wavelength grid must not leak to the caller.
      if (allocated(wavelengths)) deallocate (wavelengths)
      allocate (wavelengths(0), fluxes(0))
      return
   end if

   n_lambda = size(wavelengths)
   allocate (interp_flux(n_lambda))

   ! The slice (:, :, :, i) is contiguous, so it is passed directly instead
   ! of being copied into a scratch array for every wavelength.
   do i = 1, n_lambda
      interp_flux(i) = hermite_tensor_interp3d(teff, log_g, metallicity, &
                                               teff_grid, logg_grid, meta_grid, &
                                               precomputed_flux_cube(:, :, :, i))
   end do

   ! Apply distance dilution to get observed flux
   allocate (fluxes(n_lambda))
   call dilute_flux(interp_flux, R, d, fluxes)

end subroutine construct_sed_hermite
87 :
!---------------------------------------------------------------------------
! Load data from binary file
!
! The file is read as an unformatted stream in this order:
!   n_teff, n_logg, n_meta, n_lambda      four default integers
!   teff_grid(n_teff)                     real(dp)
!   logg_grid(n_logg)                     real(dp)
!   meta_grid(n_meta)                     real(dp)
!   wavelengths(n_lambda)                 real(dp)
!   flux_cube(n_teff, n_logg, n_meta, n_lambda)   real(dp), array element order
!
! status is 0 on success. On failure it holds the nonzero iostat/stat value
! of the failing statement, or bad_dimensions when the file declares a
! dimension smaller than 1. A message is printed and the file is closed on
! every failure path. Arrays allocated before the failure stay allocated;
! their contents must not be used.
!---------------------------------------------------------------------------
subroutine load_binary_data(filename, teff_grid, logg_grid, meta_grid, &
                            wavelengths, flux_cube, status)
   character(len=*), intent(in) :: filename
   real(dp), allocatable, intent(out) :: teff_grid(:), logg_grid(:), meta_grid(:)
   real(dp), allocatable, intent(out) :: wavelengths(:)
   real(dp), allocatable, intent(out) :: flux_cube(:, :, :, :)
   integer, intent(out) :: status

   integer, parameter :: bad_dimensions = 1
   integer :: unit, n_teff, n_logg, n_meta, n_lambda

   status = 0

   ! newunit lets the runtime pick a free unit instead of a fixed number
   ! that another part of the program might already have open.
   open (newunit=unit, file=filename, status='old', access='stream', &
         form='unformatted', action='read', iostat=status)
   if (status /= 0) then
      print *, 'Error opening binary file:', trim(filename)
      return
   end if

   ! Read dimensions
   read (unit, iostat=status) n_teff, n_logg, n_meta, n_lambda
   if (status /= 0) then
      call abandon('Error reading dimensions from binary file')
      return
   end if

   ! An empty axis cannot be interpolated on; reject it here rather than
   ! letting later code index into zero-sized arrays.
   if (min(n_teff, n_logg, n_meta, n_lambda) < 1) then
      status = bad_dimensions
      call abandon('Error: non-positive dimension in binary file')
      return
   end if

   ! Allocate arrays based on dimensions
   allocate (teff_grid(n_teff), stat=status)
   if (status /= 0) then
      call abandon('Error allocating teff_grid array')
      return
   end if

   allocate (logg_grid(n_logg), stat=status)
   if (status /= 0) then
      call abandon('Error allocating logg_grid array')
      return
   end if

   allocate (meta_grid(n_meta), stat=status)
   if (status /= 0) then
      call abandon('Error allocating meta_grid array')
      return
   end if

   allocate (wavelengths(n_lambda), stat=status)
   if (status /= 0) then
      call abandon('Error allocating wavelengths array')
      return
   end if

   allocate (flux_cube(n_teff, n_logg, n_meta, n_lambda), stat=status)
   if (status /= 0) then
      call abandon('Error allocating flux_cube array')
      return
   end if

   ! Read grid arrays
   read (unit, iostat=status) teff_grid
   if (status /= 0) then
      call abandon('Error reading teff_grid')
      return
   end if

   read (unit, iostat=status) logg_grid
   if (status /= 0) then
      call abandon('Error reading logg_grid')
      return
   end if

   read (unit, iostat=status) meta_grid
   if (status /= 0) then
      call abandon('Error reading meta_grid')
      return
   end if

   read (unit, iostat=status) wavelengths
   if (status /= 0) then
      call abandon('Error reading wavelengths')
      return
   end if

   ! Read flux cube
   read (unit, iostat=status) flux_cube
   if (status /= 0) then
      call abandon('Error reading flux_cube')
      return
   end if

   ! Close file and return success
   close (unit)

contains

   ! Shared failure cleanup: report the problem and release the file.
   ! status is left untouched so the caller sees the original error code.
   subroutine abandon(message)
      character(len=*), intent(in) :: message
      print *, message
      close (unit)
   end subroutine abandon

end subroutine load_binary_data
195 :
!---------------------------------------------------------------------------
! Interpolate f_values at (x_val, y_val, z_val) using cubic Hermite basis
! functions on the grid cell that contains the point.
!
! The result is the sum, over the eight cell corners, of a value term and
! one first-derivative term per axis; derivatives come from finite
! differences of f_values (compute_derivatives_at_point).
!
! Arguments:
!   x_val, y_val, z_val   coordinates of the requested point
!   x_grid, y_grid, z_grid
!                         node coordinates per axis; each axis needs at
!                         least one node (ordering as find_interval expects)
!   f_values              samples, shape (size(x_grid), size(y_grid),
!                         size(z_grid))
!
! Degenerate axes: an axis with a single node, or one whose selected cell
! has zero width (e.g. a dummy axis filled with one repeated value), cannot
! be interpolated along. Such an axis is replaced by a flat two-node axis
! before the cell is evaluated, so it contributes nothing and the other
! axes are still interpolated. This costs a copy of f_values and is only
! done when a degenerate axis is present.
!---------------------------------------------------------------------------
recursive function hermite_tensor_interp3d(x_val, y_val, z_val, x_grid, y_grid, &
                                           z_grid, f_values) result(f_interp)
   real(dp), intent(in) :: x_val, y_val, z_val
   real(dp), intent(in) :: x_grid(:), y_grid(:), z_grid(:)
   real(dp), intent(in) :: f_values(:, :, :)
   real(dp) :: f_interp

   integer :: i_x, i_y, i_z
   integer :: n_x, n_y, n_z
   real(dp) :: t_x, t_y, t_z
   real(dp) :: dx, dy, dz
   real(dp) :: dx_values(2, 2, 2), dy_values(2, 2, 2), dz_values(2, 2, 2)
   real(dp) :: values(2, 2, 2)
   real(dp) :: total
   integer :: ix, iy, iz
   real(dp) :: h_x(2), h_y(2), h_z(2)
   real(dp) :: hx_d(2), hy_d(2), hz_d(2)
   logical :: flat_x, flat_y, flat_z

   n_x = size(x_grid)
   n_y = size(y_grid)
   n_z = size(z_grid)

   ! Find containing cell and parameter values
   call find_containing_cell(x_val, y_val, z_val, x_grid, y_grid, z_grid, &
                             i_x, i_y, i_z, t_x, t_y, t_z)

   ! Pad degenerate axes and evaluate once more. In the nested call every
   ! axis has a cell of nonzero width, so the recursion depth is one.
   flat_x = axis_is_flat(x_grid, i_x)
   flat_y = axis_is_flat(y_grid, i_y)
   flat_z = axis_is_flat(z_grid, i_z)

   if (flat_x .or. flat_y .or. flat_z) then
      f_interp = hermite_tensor_interp3d( &
                 merge(0.0_dp, x_val, flat_x), &
                 merge(0.0_dp, y_val, flat_y), &
                 merge(0.0_dp, z_val, flat_z), &
                 working_grid(x_grid, flat_x), &
                 working_grid(y_grid, flat_y), &
                 working_grid(z_grid, flat_z), &
                 f_values(node_map(n_x, i_x, flat_x), &
                          node_map(n_y, i_y, flat_y), &
                          node_map(n_z, i_z, flat_z)))
      return
   end if

   ! If outside grid, use nearest point. find_interval clamps its index, so
   ! this is a safety net rather than an expected path.
   if (i_x < 1 .or. i_x >= n_x .or. &
       i_y < 1 .or. i_y >= n_y .or. &
       i_z < 1 .or. i_z >= n_z) then

      call find_nearest_point(x_val, y_val, z_val, x_grid, y_grid, z_grid, &
                              i_x, i_y, i_z)
      f_interp = f_values(i_x, i_y, i_z)
      return
   end if

   ! Grid cell spacing
   dx = x_grid(i_x + 1) - x_grid(i_x)
   dy = y_grid(i_y + 1) - y_grid(i_y)
   dz = z_grid(i_z + 1) - z_grid(i_z)

   ! Extract the local 2x2x2 grid cell and compute derivatives
   do iz = 0, 1
      do iy = 0, 1
         do ix = 0, 1
            values(ix + 1, iy + 1, iz + 1) = f_values(i_x + ix, i_y + iy, i_z + iz)
            call compute_derivatives_at_point(f_values, i_x + ix, i_y + iy, i_z + iz, &
                                              n_x, n_y, n_z, &
                                              dx, dy, dz, &
                                              dx_values(ix + 1, iy + 1, iz + 1), &
                                              dy_values(ix + 1, iy + 1, iz + 1), &
                                              dz_values(ix + 1, iy + 1, iz + 1))
         end do
      end do
   end do

   ! Precompute Hermite basis functions: h_* weight corner values, h*_d
   ! weight corner derivatives (lower corner first, upper corner second).
   h_x = [h00(t_x), h01(t_x)]
   hx_d = [h10(t_x), h11(t_x)]
   h_y = [h00(t_y), h01(t_y)]
   hy_d = [h10(t_y), h11(t_y)]
   h_z = [h00(t_z), h01(t_z)]
   hz_d = [h10(t_z), h11(t_z)]

   ! Final interpolation sum. Derivatives are multiplied by the cell width
   ! because the basis functions are expressed in the cell parameter t.
   total = 0.0_dp
   do iz = 1, 2
      do iy = 1, 2
         do ix = 1, 2
            total = total + h_x(ix)*h_y(iy)*h_z(iz)*values(ix, iy, iz)
            total = total + hx_d(ix)*h_y(iy)*h_z(iz)*dx*dx_values(ix, iy, iz)
            total = total + h_x(ix)*hy_d(iy)*h_z(iz)*dy*dy_values(ix, iy, iz)
            total = total + h_x(ix)*h_y(iy)*hz_d(iz)*dz*dz_values(ix, iy, iz)
         end do
      end do
   end do

   f_interp = total

contains

   ! True when no interpolation is possible along the axis: fewer than two
   ! nodes, or the cell [grid(idx), grid(idx + 1)] has zero width.
   pure function axis_is_flat(grid, idx) result(flat)
      real(dp), intent(in) :: grid(:)
      integer, intent(in) :: idx
      logical :: flat

      flat = size(grid) < 2
      if (.not. flat) then
         if (idx >= 1 .and. idx < size(grid)) flat = grid(idx + 1) == grid(idx)
      end if
   end function axis_is_flat

   ! Grid to interpolate on: the original one, or the unit interval for a
   ! degenerate axis (its query coordinate is then set to 0).
   pure function working_grid(grid, flat) result(g)
      real(dp), intent(in) :: grid(:)
      logical, intent(in) :: flat
      real(dp), allocatable :: g(:)

      if (flat) then
         g = [0.0_dp, 1.0_dp]
      else
         g = grid
      end if
   end function working_grid

   ! Node indices to take from f_values along one axis: every node, or the
   ! selected node twice for a degenerate axis, which makes f constant (and
   ! its finite-difference derivative zero) along that axis.
   pure function node_map(n, idx, flat) result(map)
      integer, intent(in) :: n, idx
      logical, intent(in) :: flat
      integer, allocatable :: map(:)
      integer :: m

      if (flat) then
         m = max(1, min(idx, n))
         map = [m, m]
      else
         map = [(m, m=1, n)]
      end if
   end function node_map

end function hermite_tensor_interp3d
271 :
!---------------------------------------------------------------------------
! Locate the grid cell around (x_val, y_val, z_val)
!
! Each axis is handled independently by find_interval, which returns the
! cell index along that axis (i_x, i_y, i_z) together with the matching
! position parameter (t_x, t_y, t_z).
!---------------------------------------------------------------------------
subroutine find_containing_cell(x_val, y_val, z_val, x_grid, y_grid, z_grid, &
                                i_x, i_y, i_z, t_x, t_y, t_z)
   real(dp), intent(in) :: x_val, y_val, z_val
   real(dp), intent(in) :: x_grid(:), y_grid(:), z_grid(:)
   integer, intent(out) :: i_x, i_y, i_z
   real(dp), intent(out) :: t_x, t_y, t_z

   ! One independent 1-D search per axis.
   call find_interval(x=x_grid, val=x_val, i=i_x, t=t_x)
   call find_interval(x=y_grid, val=y_val, i=i_y, t=t_y)
   call find_interval(x=z_grid, val=z_val, i=i_z, t=t_z)
end subroutine find_containing_cell
291 :
!---------------------------------------------------------------------------
! Find the interval in a sorted array containing a value
!
! Arguments:
!   x     node coordinates; the search assumes ascending order
!   val   value to locate
!   i     (out) index of the lower node of the interval [x(i), x(i+1)]
!   t     (out) position inside the interval, (val - x(i))/(x(i+1) - x(i))
!
! Special cases, all of which return a usable index:
!   fewer than two nodes                      -> i = 1,     t = 0
!   dummy axis (every node is 0, 999 or -999) -> i = 1,     t = 0
!   val <= x(1)                               -> i = 1,     t = 0
!   val >= x(n)                               -> i = n - 1, t = 1
!---------------------------------------------------------------------------
subroutine find_interval(x, val, i, t)
   real(dp), intent(in) :: x(:), val
   integer, intent(out) :: i
   real(dp), intent(out) :: t

   ! Values that mark an axis as a placeholder with no physical meaning.
   real(dp), parameter :: dummy_markers(3) = [0.0_dp, 999.0_dp, -999.0_dp]
   integer :: n, lo, hi, mid

   n = size(x)

   ! Default answer shared by every "collapse to the first node" case.
   i = 1
   t = 0.0_dp

   ! With a single node there is no interval; without this guard the upper
   ! clamp below would return i = n - 1 = 0, which is not a valid index.
   if (n < 2) return

   ! Dummy axis: only scan the whole array when the first node is a marker.
   if (any(x(1) == dummy_markers)) then
      if (all(x == x(1))) return
   end if

   ! Clamp below the grid.
   if (val <= x(1)) return

   ! Clamp above the grid.
   if (val >= x(n)) then
      i = n - 1
      t = 1.0_dp
      return
   end if

   ! Bisection; the loop keeps x(lo) <= val < x(hi).
   lo = 1
   hi = n
   do while (hi - lo > 1)
      mid = (lo + hi)/2
      if (val >= x(mid)) then
         lo = mid
      else
         hi = mid
      end if
   end do

   i = lo
   t = (val - x(i))/(x(i + 1) - x(i))
end subroutine find_interval
342 :
!---------------------------------------------------------------------------
! Nearest grid node on each axis
!
! For every axis, returns the index of the node with the smallest absolute
! distance to the requested coordinate. minloc picks the first such node
! when two are equally close, and returns 0 for a zero-sized grid.
!---------------------------------------------------------------------------
subroutine find_nearest_point(x_val, y_val, z_val, x_grid, y_grid, z_grid, &
                              i_x, i_y, i_z)
   real(dp), intent(in) :: x_val, y_val, z_val
   real(dp), intent(in) :: x_grid(:), y_grid(:), z_grid(:)
   integer, intent(out) :: i_x, i_y, i_z

   ! Axes are independent, so each one is a plain 1-D minimum search.
   i_z = minloc(abs(z_grid - z_val), dim=1)
   i_y = minloc(abs(y_grid - y_val), dim=1)
   i_x = minloc(abs(x_grid - x_val), dim=1)
end subroutine find_nearest_point
357 :
!---------------------------------------------------------------------------
! Compute derivatives at a grid point
!
! Finite-difference estimates of the three first derivatives of f at node
! (i, j, k): a centered difference where both neighbours exist, otherwise a
! one-sided difference towards the neighbour that does exist.
!
! Arguments:
!   f            samples on the grid
!   i, j, k      node index; must satisfy 1 <= i <= nx, 1 <= j <= ny,
!                1 <= k <= nz
!   nx, ny, nz   number of nodes per axis
!   dx, dy, dz   spacing used as the step per index along each axis
!   df_dx, df_dy, df_dz (out)   derivative estimates
!
! An axis with a single node, or with zero spacing, yields a derivative of
! zero instead of an out-of-bounds read or a division by zero.
!---------------------------------------------------------------------------
subroutine compute_derivatives_at_point(f, i, j, k, nx, ny, nz, dx, dy, dz, &
                                        df_dx, df_dy, df_dz)
   real(dp), intent(in) :: f(:, :, :)
   integer, intent(in) :: i, j, k, nx, ny, nz
   real(dp), intent(in) :: dx, dy, dz
   real(dp), intent(out) :: df_dx, df_dy, df_dz

   integer :: lo, hi

   ! For each axis, lo/hi are the neighbouring nodes clipped to the grid:
   ! hi - lo is 2 in the interior, 1 on a boundary and 0 on a one-node axis.
   lo = max(i - 1, 1)
   hi = min(i + 1, nx)
   df_dx = difference_quotient(f(hi, j, k), f(lo, j, k), hi - lo, dx)

   lo = max(j - 1, 1)
   hi = min(j + 1, ny)
   df_dy = difference_quotient(f(i, hi, k), f(i, lo, k), hi - lo, dy)

   lo = max(k - 1, 1)
   hi = min(k + 1, nz)
   df_dz = difference_quotient(f(i, j, hi), f(i, j, lo), hi - lo, dz)

contains

   ! (f_hi - f_lo)/(span*h), or zero when the stencil is degenerate.
   pure function difference_quotient(f_hi, f_lo, span, h) result(slope)
      real(dp), intent(in) :: f_hi, f_lo, h
      integer, intent(in) :: span
      real(dp) :: slope

      if (span < 1 .or. h == 0.0_dp) then
         slope = 0.0_dp
      else
         slope = (f_hi - f_lo)/(real(span, dp)*h)
      end if
   end function difference_quotient

end subroutine compute_derivatives_at_point
395 :
396 : !---------------------------------------------------------------------------
397 : ! Hermite basis functions
398 : !---------------------------------------------------------------------------
! Cubic Hermite basis function weighting the value at the lower node:
! h00(t) = 2 t**3 - 3 t**2 + 1, so h00(0) = 1 and h00(1) = 0.
! Elemental: accepts a scalar or an array of parameters t.
elemental function h00(t) result(h)
   real(dp), intent(in) :: t
   real(dp) :: h
   h = 2.0_dp*t**3 - 3.0_dp*t**2 + 1.0_dp
end function h00
404 :
! Cubic Hermite basis function weighting the derivative at the lower node:
! h10(t) = t**3 - 2 t**2 + t, which vanishes at t = 0 and t = 1.
! Elemental: accepts a scalar or an array of parameters t.
elemental function h10(t) result(h)
   real(dp), intent(in) :: t
   real(dp) :: h
   h = t**3 - 2.0_dp*t**2 + t
end function h10
410 :
! Cubic Hermite basis function weighting the value at the upper node:
! h01(t) = -2 t**3 + 3 t**2, so h01(0) = 0 and h01(1) = 1.
! Elemental: accepts a scalar or an array of parameters t.
elemental function h01(t) result(h)
   real(dp), intent(in) :: t
   real(dp) :: h
   h = -2.0_dp*t**3 + 3.0_dp*t**2
end function h01
416 :
! Cubic Hermite basis function weighting the derivative at the upper node:
! h11(t) = t**3 - t**2, which vanishes at t = 0 and t = 1.
! Elemental: accepts a scalar or an array of parameters t.
elemental function h11(t) result(h)
   real(dp), intent(in) :: t
   real(dp) :: h
   h = t**3 - t**2
end function h11
422 :
423 : end module hermite_interp
|