Line data Source code
1 : ! ***********************************************************************
2 : !
3 : ! Copyright (C) 2025 Niall Miller & The MESA Team
4 : !
5 : ! This program is free software: you can redistribute it and/or modify
6 : ! it under the terms of the GNU Lesser General Public License
7 : ! as published by the Free Software Foundation,
8 : ! either version 3 of the License, or (at your option) any later version.
9 : !
10 : ! This program is distributed in the hope that it will be useful,
11 : ! but WITHOUT ANY WARRANTY; without even the implied warranty of
12 : ! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
13 : ! See the GNU Lesser General Public License for more details.
14 : !
15 : ! You should have received a copy of the GNU Lesser General Public License
16 : ! along with this program. If not, see <https://www.gnu.org/licenses/>.
17 : !
18 : ! ***********************************************************************
19 :
20 : ! ***********************************************************************
21 : ! Hermite interpolation module for spectral energy distributions (SEDs)
22 : ! ***********************************************************************
23 :
24 : module hermite_interp
25 : use const_def, only: dp
26 : use colors_utils, only: dilute_flux
27 : implicit none
28 :
29 : private
30 : public :: construct_sed_hermite, hermite_tensor_interp3d
31 :
32 : contains
33 :
!---------------------------------------------------------------------------
! Main entry point: construct a SED using Hermite tensor interpolation.
!
! Loads the precomputed flux cube stored in <stellar_model_dir>/flux_cube.bin,
! interpolates the flux at (teff, log_g, metallicity) independently at every
! wavelength of the cube with hermite_tensor_interp3d, and passes the result
! through dilute_flux together with R and d.
!
! On return, wavelengths holds the cube's wavelength grid and fluxes the
! diluted flux at each of those wavelengths (same length).
!
! file_names, lu_teff, lu_logg and lu_meta are not referenced here
! (NOTE(review): presumably kept for interface parity with the other SED
! constructors -- confirm).
!
! If the flux cube cannot be loaded the run is stopped with an error, since
! no meaningful SED can be produced.
!---------------------------------------------------------------------------
subroutine construct_sed_hermite(teff, log_g, metallicity, R, d, file_names, &
                                 lu_teff, lu_logg, lu_meta, stellar_model_dir, &
                                 wavelengths, fluxes)
   real(dp), intent(in) :: teff, log_g, metallicity, R, d
   real(dp), intent(in) :: lu_teff(:), lu_logg(:), lu_meta(:)
   character(len=*), intent(in) :: stellar_model_dir
   character(len=100), intent(in) :: file_names(:)
   real(dp), dimension(:), allocatable, intent(out) :: wavelengths, fluxes

   integer :: i, n_lambda, status
   real(dp), dimension(:), allocatable :: interp_flux, diluted_flux
   real(dp), dimension(:, :, :, :), allocatable :: precomputed_flux_cube
   real(dp), allocatable :: teff_grid(:), logg_grid(:), meta_grid(:)
   ! Deferred length so that long model directory paths are never truncated.
   character(len=:), allocatable :: bin_filename

   bin_filename = trim(stellar_model_dir)//'/flux_cube.bin'

   call load_binary_data(bin_filename, teff_grid, logg_grid, meta_grid, &
                         wavelengths, precomputed_flux_cube, status)

   ! Without this check the grids may be unallocated and every size() call
   ! below would be undefined behaviour.
   if (status /= 0) then
      print *, 'construct_sed_hermite: failed to load flux cube: ', bin_filename
      error stop 'construct_sed_hermite: could not load flux cube'
   end if

   n_lambda = size(wavelengths)
   allocate (interp_flux(n_lambda))

   ! With the last index fixed, precomputed_flux_cube(:, :, :, i) is a
   ! contiguous section in column-major order, so it is handed to the
   ! interpolator directly instead of being copied into a scratch cube.
   do i = 1, n_lambda
      interp_flux(i) = hermite_tensor_interp3d(teff, log_g, metallicity, &
                                               teff_grid, logg_grid, meta_grid, &
                                               precomputed_flux_cube(:, :, :, i))
   end do

   ! Apply distance dilution, then hand the buffer to the output argument
   ! without an extra copy.
   allocate (diluted_flux(n_lambda))
   call dilute_flux(interp_flux, R, d, diluted_flux)
   call move_alloc(diluted_flux, fluxes)
end subroutine construct_sed_hermite
87 :
!---------------------------------------------------------------------------
! Load a flux cube and its parameter grids from an unformatted stream file.
!
! Items are read in this order (default-kind integers, then real(dp) data):
!   n_teff, n_logg, n_meta, n_lambda
!   teff_grid(n_teff), logg_grid(n_logg), meta_grid(n_meta)
!   wavelengths(n_lambda)
!   flux_cube(n_teff, n_logg, n_meta, n_lambda)   (array element order)
!
! status is 0 on success. On failure a message is printed and status holds
! the non-zero iostat/stat code, or 1 if the header declares a dimension
! smaller than 1. Output arrays may be partially allocated on failure.
!---------------------------------------------------------------------------
subroutine load_binary_data(filename, teff_grid, logg_grid, meta_grid, &
                            wavelengths, flux_cube, status)
   character(len=*), intent(in) :: filename
   real(dp), allocatable, intent(out) :: teff_grid(:), logg_grid(:), meta_grid(:)
   real(dp), allocatable, intent(out) :: wavelengths(:)
   real(dp), allocatable, intent(out) :: flux_cube(:, :, :, :)
   integer, intent(out) :: status

   integer :: unit, n_teff, n_logg, n_meta, n_lambda
   character(len=256) :: msg

   status = 0
   msg = ''

   ! newunit= picks a free unit, so this routine cannot clash with units the
   ! rest of the program has open (a fixed unit number could).
   open (newunit=unit, file=filename, status='OLD', access='STREAM', &
         form='UNFORMATTED', action='READ', iostat=status, iomsg=msg)
   if (status /= 0) then
      print *, 'Error opening binary file:', trim(filename)
      print *, trim(msg)
      return
   end if

   ! Every failure below leaves the block via "exit load", so the file is
   ! closed in exactly one place.
   load: block
      read (unit, iostat=status) n_teff, n_logg, n_meta, n_lambda
      if (status /= 0) then
         print *, 'Error reading dimensions from binary file'
         exit load
      end if

      ! A corrupt header can yield non-positive sizes, which would otherwise
      ! silently produce empty arrays and break interpolation downstream.
      if (n_teff < 1 .or. n_logg < 1 .or. n_meta < 1 .or. n_lambda < 1) then
         print *, 'Invalid dimensions in binary file:', n_teff, n_logg, n_meta, n_lambda
         status = 1
         exit load
      end if

      allocate (teff_grid(n_teff), STAT=status)
      if (status /= 0) then
         print *, 'Error allocating teff_grid array'
         exit load
      end if

      allocate (logg_grid(n_logg), STAT=status)
      if (status /= 0) then
         print *, 'Error allocating logg_grid array'
         exit load
      end if

      allocate (meta_grid(n_meta), STAT=status)
      if (status /= 0) then
         print *, 'Error allocating meta_grid array'
         exit load
      end if

      allocate (wavelengths(n_lambda), STAT=status)
      if (status /= 0) then
         print *, 'Error allocating wavelengths array'
         exit load
      end if

      allocate (flux_cube(n_teff, n_logg, n_meta, n_lambda), STAT=status)
      if (status /= 0) then
         print *, 'Error allocating flux_cube array'
         exit load
      end if

      read (unit, iostat=status) teff_grid
      if (status /= 0) then
         print *, 'Error reading teff_grid'
         exit load
      end if

      read (unit, iostat=status) logg_grid
      if (status /= 0) then
         print *, 'Error reading logg_grid'
         exit load
      end if

      read (unit, iostat=status) meta_grid
      if (status /= 0) then
         print *, 'Error reading meta_grid'
         exit load
      end if

      read (unit, iostat=status) wavelengths
      if (status /= 0) then
         print *, 'Error reading wavelengths'
         exit load
      end if

      read (unit, iostat=status) flux_cube
      if (status /= 0) then
         print *, 'Error reading flux_cube'
         exit load
      end if
   end block load

   close (unit)
end subroutine load_binary_data
195 :
196 :
197 :
!---------------------------------------------------------------------------
! Cubic Hermite tensor-product interpolation of a gridded 3-D function.
!
! x_grid, y_grid and z_grid hold the node coordinates of the three axes and
! are assumed ascending (find_interval bisects them); f_values(i, j, k) is
! the value at (x_grid(i), y_grid(j), z_grid(k)).
!
! For each corner of the enclosing cell the sum uses the node value and the
! three first-derivative terms (no mixed-derivative terms). Node derivatives
! come from compute_derivatives_at_point: centred differences at interior
! nodes, one-sided differences at the first/last node. Query coordinates
! beyond an axis are clamped to its end node by find_interval. If any axis
! has fewer than two nodes there is no cell, and the value at the nearest
! grid node is returned instead.
!---------------------------------------------------------------------------
function hermite_tensor_interp3d(x_val, y_val, z_val, x_grid, y_grid, &
                                 z_grid, f_values) result(f_interp)
   real(dp), intent(in) :: x_val, y_val, z_val
   real(dp), intent(in) :: x_grid(:), y_grid(:), z_grid(:)
   real(dp), intent(in) :: f_values(:, :, :)
   real(dp) :: f_interp

   integer :: i_x, i_y, i_z
   integer :: ix, iy, iz
   real(dp) :: t_x, t_y, t_z
   real(dp) :: dx, dy, dz
   real(dp) :: step_x(2), step_y(2), step_z(2)
   real(dp) :: values(2, 2, 2)
   real(dp) :: dx_values(2, 2, 2), dy_values(2, 2, 2), dz_values(2, 2, 2)
   real(dp) :: h_x(2), h_y(2), h_z(2)
   real(dp) :: hx_d(2), hy_d(2), hz_d(2)
   real(dp) :: acc   ! running total; not called "sum" to avoid shadowing the intrinsic

   call find_containing_cell(x_val, y_val, z_val, x_grid, y_grid, z_grid, &
                             i_x, i_y, i_z, t_x, t_y, t_z)

   ! find_interval clamps queries to the end cells, so an index outside
   ! [1, n-1] only occurs for an axis with fewer than two nodes.
   if (i_x < 1 .or. i_x >= size(x_grid) .or. &
       i_y < 1 .or. i_y >= size(y_grid) .or. &
       i_z < 1 .or. i_z >= size(z_grid)) then
      call find_nearest_point(x_val, y_val, z_val, x_grid, y_grid, z_grid, &
                              i_x, i_y, i_z)
      f_interp = f_values(i_x, i_y, i_z)
      return
   end if

   ! Cell widths: convert physical derivatives to the unit-cell parameter t.
   dx = x_grid(i_x + 1) - x_grid(i_x)
   dy = y_grid(i_y + 1) - y_grid(i_y)
   dz = z_grid(i_z + 1) - z_grid(i_z)

   ! Finite-difference steps for the two corner nodes on each axis.
   ! compute_derivatives_at_point divides centred differences by 2*step and
   ! one-sided ones by step. Passing the true local spacing (not the cell
   ! width) keeps the derivatives correct on non-uniformly spaced grids.
   step_x = [fd_step(x_grid, i_x), fd_step(x_grid, i_x + 1)]
   step_y = [fd_step(y_grid, i_y), fd_step(y_grid, i_y + 1)]
   step_z = [fd_step(z_grid, i_z), fd_step(z_grid, i_z + 1)]

   ! Gather the 2x2x2 corner values and their derivatives.
   do iz = 0, 1
      do iy = 0, 1
         do ix = 0, 1
            values(ix + 1, iy + 1, iz + 1) = f_values(i_x + ix, i_y + iy, i_z + iz)
            call compute_derivatives_at_point(f_values, i_x + ix, i_y + iy, i_z + iz, &
                                              size(x_grid), size(y_grid), size(z_grid), &
                                              step_x(ix + 1), step_y(iy + 1), step_z(iz + 1), &
                                              dx_values(ix + 1, iy + 1, iz + 1), &
                                              dy_values(ix + 1, iy + 1, iz + 1), &
                                              dz_values(ix + 1, iy + 1, iz + 1))
         end do
      end do
   end do

   ! Hermite weights: index 1 -> lower node (t = 0), index 2 -> upper node (t = 1).
   h_x = [h00(t_x), h01(t_x)]
   hx_d = [h10(t_x), h11(t_x)]
   h_y = [h00(t_y), h01(t_y)]
   hy_d = [h10(t_y), h11(t_y)]
   h_z = [h00(t_z), h01(t_z)]
   hz_d = [h10(t_z), h11(t_z)]

   acc = 0.0_dp
   do iz = 1, 2
      do iy = 1, 2
         do ix = 1, 2
            acc = acc + h_x(ix)*h_y(iy)*h_z(iz)*values(ix, iy, iz)
            acc = acc + hx_d(ix)*h_y(iy)*h_z(iz)*dx*dx_values(ix, iy, iz)
            acc = acc + h_x(ix)*hy_d(iy)*h_z(iz)*dy*dy_values(ix, iy, iz)
            acc = acc + h_x(ix)*h_y(iy)*hz_d(iz)*dz*dz_values(ix, iy, iz)
         end do
      end do
   end do

   f_interp = acc

contains

   ! Step to pass to compute_derivatives_at_point for node idx of grid:
   ! half the two-sided span at interior nodes, the adjacent interval at the
   ! first and last nodes. Requires size(grid) >= 2.
   pure function fd_step(grid, idx) result(step)
      real(dp), intent(in) :: grid(:)
      integer, intent(in) :: idx
      real(dp) :: step
      integer :: n

      n = size(grid)
      if (idx > 1 .and. idx < n) then
         step = 0.5_dp*(grid(idx + 1) - grid(idx - 1))
      else if (idx == 1) then
         step = grid(2) - grid(1)
      else
         step = grid(n) - grid(n - 1)
      end if
   end function fd_step

end function hermite_tensor_interp3d
273 :
274 :
275 :
!---------------------------------------------------------------------------
! Locate the grid cell that holds the point (x_val, y_val, z_val).
!
! The three axes are searched independently with find_interval. For each
! axis, i_* receives the lower node index of the bracketing interval and t_*
! the fractional position of the coordinate inside that interval.
!---------------------------------------------------------------------------
subroutine find_containing_cell(x_val, y_val, z_val, x_grid, y_grid, z_grid, &
                                i_x, i_y, i_z, t_x, t_y, t_z)
   real(dp), intent(in) :: x_val, y_val, z_val
   real(dp), intent(in) :: x_grid(:), y_grid(:), z_grid(:)
   integer, intent(out) :: i_x, i_y, i_z
   real(dp), intent(out) :: t_x, t_y, t_z

   call find_interval(x=x_grid, val=x_val, i=i_x, t=t_x)   ! first axis
   call find_interval(x=y_grid, val=y_val, i=i_y, t=t_y)   ! second axis
   call find_interval(x=z_grid, val=z_val, i=i_z, t=t_z)   ! third axis
end subroutine find_containing_cell
295 :
!---------------------------------------------------------------------------
! Bracket val within the ascending array x.
!
! Returns i, the lower index of the interval with x(i) <= val < x(i+1), and
! t = (val - x(i)) / (x(i+1) - x(i)). Values at or below x(1) give i = 1,
! t = 0; values at or above x(size(x)) give i = size(x) - 1, t = 1 (which is
! i = 0 when x has a single element).
!---------------------------------------------------------------------------
subroutine find_interval(x, val, i, t)
   real(dp), intent(in) :: x(:), val
   integer, intent(out) :: i
   real(dp), intent(out) :: t

   integer :: n_pts, left, right, probe

   n_pts = size(x)

   if (val <= x(1)) then
      ! Clamp below the grid.
      i = 1
      t = 0.0_dp
   else if (val >= x(n_pts)) then
      ! Clamp above the grid.
      i = n_pts - 1
      t = 1.0_dp
   else
      ! Bisection; invariant x(left) <= val < x(right).
      left = 1
      right = n_pts
      do
         if (right - left <= 1) exit
         probe = (left + right)/2
         if (val >= x(probe)) then
            left = probe
         else
            right = probe
         end if
      end do
      i = left
      t = (val - x(left))/(x(left + 1) - x(left))
   end if
end subroutine find_interval
334 :
!---------------------------------------------------------------------------
! Index of the grid node closest to each coordinate, axis by axis.
!
! Distances are absolute differences. minloc reports the first minimiser,
! so an exact tie resolves to the lower index.
!---------------------------------------------------------------------------
subroutine find_nearest_point(x_val, y_val, z_val, x_grid, y_grid, z_grid, &
                              i_x, i_y, i_z)
   real(dp), intent(in) :: x_val, y_val, z_val
   real(dp), intent(in) :: x_grid(:), y_grid(:), z_grid(:)
   integer, intent(out) :: i_x, i_y, i_z

   i_x = minloc(abs(x_grid - x_val), dim=1)
   i_y = minloc(abs(y_grid - y_val), dim=1)
   i_z = minloc(abs(z_grid - z_val), dim=1)
end subroutine find_nearest_point
349 :
!---------------------------------------------------------------------------
! Finite-difference estimate of the gradient of f at node (i, j, k).
!
! Along each axis, an interior node uses the centred difference
! (f(+1) - f(-1)) / (2*h); the first node uses the forward difference and the
! last node the backward difference, both divided by h. Here h is dx, dy or
! dz, and nx, ny, nz are the node counts of the three axes.
!
! A single h is used per axis. The centred formula therefore divides by the
! true coordinate span only when both neighbouring intervals have length h;
! on non-uniform grids the caller must pass half the two-sided span for
! interior nodes.
!
! An axis with a single node (n <= 1) has no neighbour to difference
! against, and gets a zero derivative instead of an out-of-bounds read.
!---------------------------------------------------------------------------
subroutine compute_derivatives_at_point(f, i, j, k, nx, ny, nz, dx, dy, dz, &
                                        df_dx, df_dy, df_dz)
   real(dp), intent(in) :: f(:, :, :)
   integer, intent(in) :: i, j, k, nx, ny, nz
   real(dp), intent(in) :: dx, dy, dz
   real(dp), intent(out) :: df_dx, df_dy, df_dz

   df_dx = axis_difference(f(:, j, k), i, nx, dx)
   df_dy = axis_difference(f(i, :, k), j, ny, dy)
   df_dz = axis_difference(f(i, j, :), k, nz, dz)

contains

   ! Difference quotient of a 1-D line of n nodes at position p, with step h.
   pure function axis_difference(line, p, n, h) result(slope)
      real(dp), intent(in) :: line(:)
      integer, intent(in) :: p, n
      real(dp), intent(in) :: h
      real(dp) :: slope

      if (n <= 1) then
         slope = 0.0_dp
      else if (p > 1 .and. p < n) then
         slope = (line(p + 1) - line(p - 1))/(2.0_dp*h)
      else if (p == 1) then
         slope = (line(p + 1) - line(p))/h
      else   ! p == n
         slope = (line(p) - line(p - 1))/h
      end if
   end function axis_difference

end subroutine compute_derivatives_at_point
387 :
!---------------------------------------------------------------------------
! Cubic Hermite basis functions on the unit parameter t in [0, 1].
! h00/h01 weight the interval's start/end values; h10/h11 weight the start/end
! derivatives (which the caller scales by the interval width).
!---------------------------------------------------------------------------

! h00(t) = 2t^3 - 3t^2 + 1: equals 1 at t = 0 and 0 at t = 1, with zero slope
! at both ends. Elemental, so it also maps over arrays of t.
elemental function h00(t) result(h)
   real(dp), intent(in) :: t
   real(dp) :: h
   h = 2.0_dp*t**3 - 3.0_dp*t**2 + 1.0_dp
end function h00
396 :
! h10(t) = t^3 - 2t^2 + t: cubic Hermite weight of the start-point derivative
! (value 0 at both ends, slope 1 at t = 0). Elemental, so it also maps over
! arrays of t.
elemental function h10(t) result(h)
   real(dp), intent(in) :: t
   real(dp) :: h
   h = t**3 - 2.0_dp*t**2 + t
end function h10
402 :
! h01(t) = -2t^3 + 3t^2: cubic Hermite weight of the end-point value (0 at
! t = 0, 1 at t = 1, zero slope at both ends). Elemental, so it also maps over
! arrays of t.
elemental function h01(t) result(h)
   real(dp), intent(in) :: t
   real(dp) :: h
   h = -2.0_dp*t**3 + 3.0_dp*t**2
end function h01
408 :
! h11(t) = t^3 - t^2: cubic Hermite weight of the end-point derivative (value
! 0 at both ends, slope 1 at t = 1). Elemental, so it also maps over arrays of t.
elemental function h11(t) result(h)
   real(dp), intent(in) :: t
   real(dp) :: h
   h = t**3 - t**2
end function h11
414 :
415 : end module hermite_interp
|