Line data Source code
1 : ! ***********************************************************************
2 : !
3 : ! Copyright (C) 2025 Niall Miller & The MESA Team
4 : !
5 : ! This program is free software: you can redistribute it and/or modify
6 : ! it under the terms of the GNU Lesser General Public License
7 : ! as published by the Free Software Foundation,
8 : ! either version 3 of the License, or (at your option) any later version.
9 : !
10 : ! This program is distributed in the hope that it will be useful,
11 : ! but WITHOUT ANY WARRANTY; without even the implied warranty of
12 : ! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
13 : ! See the GNU Lesser General Public License for more details.
14 : !
15 : ! You should have received a copy of the GNU Lesser General Public License
16 : ! along with this program. If not, see <https://www.gnu.org/licenses/>.
17 : !
18 : ! ***********************************************************************
19 :
20 : ! knn interpolation for SEDs
21 : !
22 : ! data-loading strategy selected by rq%cube_loaded:
23 : ! .true. -> extract neighbor SEDs from the preloaded 4-D cube
24 : ! .false. -> load individual SED files via the lookup table (fallback)
25 :
26 : module knn_interp
27 : use const_def, only: dp
28 : use colors_def, only: Colors_General_Info
29 : use colors_utils, only: dilute_flux, load_sed_cached
30 : use utils_lib, only: mesa_error
31 : implicit none
32 :
33 : private
34 : public :: construct_sed_knn, interpolate_array
35 :
36 : contains
37 :
! Main entry point: build a SED for (teff, log_g, metallicity) by KNN
! interpolation and hand it to dilute_flux together with R and d.
!
! The neighbour source is chosen by rq%cube_loaded (set at init):
!   .true.  -> construct_sed_from_cube  (preloaded 4-D cube)
!   .false. -> construct_sed_from_files (individual SED files found
!              under stellar_model_dir)
!
! On return, wavelengths holds the SED wavelength grid and fluxes the
! output of dilute_flux, both with size(wavelengths) elements.
subroutine construct_sed_knn(rq, teff, log_g, metallicity, R, d, &
                             stellar_model_dir, wavelengths, fluxes)
   type(Colors_General_Info), intent(inout) :: rq
   real(dp), intent(in) :: teff, log_g, metallicity, R, d
   character(len=*), intent(in) :: stellar_model_dir
   real(dp), dimension(:), allocatable, intent(out) :: wavelengths, fluxes

   real(dp), dimension(:), allocatable :: blended_flux, scaled_flux

   if (.not. rq%cube_loaded) then
      ! fallback path: neighbours are read from individual SED files
      call construct_sed_from_files(rq, teff, log_g, metallicity, &
                                    stellar_model_dir, blended_flux, wavelengths)
   else
      ! fast path: neighbours are sliced out of the preloaded cube
      call construct_sed_from_cube(rq, teff, log_g, metallicity, &
                                   blended_flux, wavelengths)
   end if

   ! both paths return the wavelength grid, so the output length is known here
   allocate (scaled_flux(size(wavelengths)))
   call dilute_flux(blended_flux, R, d, scaled_flux)

   ! hand the buffer over to the caller without another copy
   call move_alloc(scaled_flux, fluxes)

end subroutine construct_sed_knn
67 :
! Cube path: locate the 4 grid nodes nearest to the target with
! get_closest_grid_points, then blend their spectra from rq%cube_flux
! using normalised inverse-distance weights.
!
! wavelengths is returned as a copy of rq%cube_wavelengths and
! interp_flux has the same length.
subroutine construct_sed_from_cube(rq, teff, log_g, metallicity, &
                                   interp_flux, wavelengths)
   type(Colors_General_Info), intent(inout) :: rq
   real(dp), intent(in) :: teff, log_g, metallicity
   real(dp), dimension(:), allocatable, intent(out) :: interp_flux, wavelengths

   integer, parameter :: n_nbr = 4
   integer :: n_wave, inbr
   integer, dimension(n_nbr) :: it_nbr, ig_nbr, im_nbr
   real(dp), dimension(n_nbr) :: nbr_dist, nbr_weight

   n_wave = size(rq%cube_wavelengths)
   allocate (wavelengths(n_wave), interp_flux(n_wave))
   wavelengths = rq%cube_wavelengths

   ! indices (teff, logg, meta) and distances of the 4 nearest cube nodes
   call get_closest_grid_points(teff, log_g, metallicity, &
                                rq%cube_teff_grid, rq%cube_logg_grid, &
                                rq%cube_meta_grid, &
                                it_nbr, ig_nbr, im_nbr, nbr_dist)

   ! an exact hit would give 1/0; replace a zero distance by a tiny one so
   ! that node simply dominates the blend
   where (nbr_dist == 0.0_dp) nbr_dist = 1.0e-10_dp
   nbr_weight = 1.0_dp/nbr_dist
   nbr_weight = nbr_weight/sum(nbr_weight)

   ! accumulate the weighted neighbour spectra
   interp_flux = 0.0_dp
   do inbr = 1, n_nbr
      interp_flux = interp_flux + nbr_weight(inbr)* &
                    rq%cube_flux(it_nbr(inbr), ig_nbr(inbr), im_nbr(inbr), :)
   end do

end subroutine construct_sed_from_cube
107 :
! Fallback path: blend the SEDs of the (up to) 4 lookup-table models nearest
! to the target, loading each one through load_sed_cached.
!
! Arguments
!   rq                 colors handle; supplies lu_teff / lu_logg / lu_meta and,
!                      after the first SED load, fallback_wavelengths
!   teff, log_g, metallicity
!                      target stellar parameters
!   stellar_model_dir  model directory, passed through resolve_path
!   interp_flux (out)  inverse-distance-weighted blend of the neighbour SEDs
!   wavelengths (out)  copy of rq%fallback_wavelengths
!
! The weighting distance is range-normalised per axis and skips dummy axes
! (all 0 / 999 / -999) and constant axes, matching the metric used to select
! the neighbours. Raw offsets would let Teff (thousands of K) swamp log g and
! metallicity, and would let dummy-axis sentinel values leak into the weights.
!
! Neighbour slots reported as < 1 (lookup table with fewer than 4 rows) are
! ignored. mesa_error is called when no neighbour is available, when the
! wavelength grid is not set after the first load, or when a loaded SED is
! shorter than that grid.
subroutine construct_sed_from_files(rq, teff, log_g, metallicity, &
                                    stellar_model_dir, interp_flux, wavelengths)
   use colors_utils, only: resolve_path
   type(Colors_General_Info), intent(inout) :: rq
   real(dp), intent(in) :: teff, log_g, metallicity
   character(len=*), intent(in) :: stellar_model_dir
   real(dp), dimension(:), allocatable, intent(out) :: interp_flux, wavelengths

   integer, parameter :: n_nbr = 4
   ! stand-in distance for an exact grid hit (avoids 1/0 in the weights)
   real(dp), parameter :: exact_match_distance = 1.0e-10_dp

   integer, dimension(n_nbr) :: closest_indices
   real(dp), dimension(n_nbr) :: weights
   real(dp), dimension(:), allocatable :: temp_flux
   real(dp) :: inv_teff, inv_logg, inv_meta, distance
   integer :: i, idx, n_points
   character(len=512) :: resolved_dir

   resolved_dir = trim(resolve_path(stellar_model_dir))

   ! get the four closest stellar models from the flat lookup table
   call get_closest_stellar_models(teff, log_g, metallicity, &
                                   rq%lu_teff, rq%lu_logg, rq%lu_meta, &
                                   closest_indices)

   if (all(closest_indices < 1)) then
      print *, 'KNN fallback: no stellar models available in lookup table'
      call mesa_error(__FILE__, __LINE__)
      return
   end if

   ! per-axis scale factors; 0 switches an axis off in the distance
   inv_teff = inverse_axis_range(rq%lu_teff)
   inv_logg = inverse_axis_range(rq%lu_logg)
   inv_meta = inverse_axis_range(rq%lu_meta)

   ! inverse-distance weights; empty neighbour slots keep weight 0
   weights = 0.0_dp
   do i = 1, n_nbr
      idx = closest_indices(i)
      if (idx < 1) cycle
      distance = sqrt(((rq%lu_teff(idx) - teff)*inv_teff)**2 + &
                      ((rq%lu_logg(idx) - log_g)*inv_logg)**2 + &
                      ((rq%lu_meta(idx) - metallicity)*inv_meta)**2)
      if (distance == 0.0_dp) distance = exact_match_distance
      weights(i) = 1.0_dp/distance
   end do
   weights = weights/sum(weights)

   ! load each neighbour SED (via the cache) and accumulate it directly, so
   ! no per-model flux buffer is needed
   n_points = 0
   do i = 1, n_nbr
      idx = closest_indices(i)
      if (idx < 1) cycle

      call load_sed_cached(rq, resolved_dir, idx, temp_flux)

      if (.not. allocated(interp_flux)) then
         ! first SED loaded: the canonical wavelength grid on the handle
         ! defines the output length
         if (.not. rq%fallback_wavelengths_set) then
            ! should not happen -- load_sed_cached sets this on first call
            print *, 'KNN fallback: wavelengths not set after first SED load'
            call mesa_error(__FILE__, __LINE__)
            return
         end if
         n_points = size(rq%fallback_wavelengths)
         allocate (wavelengths(n_points), interp_flux(n_points))
         wavelengths = rq%fallback_wavelengths
         interp_flux = 0.0_dp
      end if

      if (.not. allocated(temp_flux)) then
         print *, 'KNN fallback: SED load returned no flux array'
         call mesa_error(__FILE__, __LINE__)
         return
      end if
      if (size(temp_flux) < n_points) then
         print *, 'KNN fallback: loaded SED is shorter than the wavelength grid'
         call mesa_error(__FILE__, __LINE__)
         return
      end if

      interp_flux = interp_flux + weights(i)*temp_flux(1:n_points)
      deallocate (temp_flux)
   end do

contains

   ! 1/(max-min) of a lookup-table axis, or 0 when the axis is empty, is a
   ! dummy axis (all 0, 999 or -999) or has zero range
   pure function inverse_axis_range(axis) result(inv_range)
      real(dp), intent(in) :: axis(:)
      real(dp) :: inv_range
      real(dp) :: span

      inv_range = 0.0_dp
      if (size(axis) == 0) return
      if (all(axis == 0.0_dp) .or. all(axis == 999.0_dp) .or. &
          all(axis == -999.0_dp)) return
      span = maxval(axis) - minval(axis)
      if (span > 0.0_dp) inv_range = 1.0_dp/span
   end function inverse_axis_range

end subroutine construct_sed_from_files
181 :
! Find the 4 nodes of the structured (teff, logg, meta) grid closest to the
! target, using a Euclidean distance in which every axis is rescaled to its
! own [min, max] range.
!
! An axis is ignored when it is a dummy axis (all values 0, 999 or -999) or
! has zero range.
!
! Outputs
!   nbr_it, nbr_ig, nbr_im  grid indices of the neighbours, nearest first;
!                           ties keep the node visited first in the
!                           teff -> logg -> meta loop order
!   distances               matching distances (square-rooted on exit);
!                           slots never filled keep index 1 and
!                           sqrt(huge(1.0_dp))
subroutine get_closest_grid_points(teff, log_g, metallicity, &
                                   teff_grid, logg_grid, meta_grid, &
                                   nbr_it, nbr_ig, nbr_im, distances)
   real(dp), intent(in) :: teff, log_g, metallicity
   real(dp), intent(in) :: teff_grid(:), logg_grid(:), meta_grid(:)
   integer, dimension(4), intent(out) :: nbr_it, nbr_ig, nbr_im
   real(dp), dimension(4), intent(out) :: distances

   ! squared normalised offset of every node from the target, per axis
   real(dp) :: sq_t(size(teff_grid)), sq_g(size(logg_grid)), sq_m(size(meta_grid))
   real(dp) :: sq_dist
   integer :: it, ig, im, rank, slot

   distances = huge(1.0_dp)
   nbr_it = 1
   nbr_ig = 1
   nbr_im = 1

   sq_t = axis_sq_offsets(teff_grid, teff)
   sq_g = axis_sq_offsets(logg_grid, log_g)
   sq_m = axis_sq_offsets(meta_grid, metallicity)

   do it = 1, size(teff_grid)
      do ig = 1, size(logg_grid)
         do im = 1, size(meta_grid)
            sq_dist = sq_t(it) + sq_g(ig) + sq_m(im)

            ! first rank whose stored distance is strictly larger
            slot = 0
            do rank = 1, 4
               if (sq_dist < distances(rank)) then
                  slot = rank
                  exit
               end if
            end do
            if (slot == 0) cycle

            ! shift the tail down one place (zero-sized sections when
            ! slot == 4), then store the new entry
            distances(slot + 1:4) = distances(slot:3)
            nbr_it(slot + 1:4) = nbr_it(slot:3)
            nbr_ig(slot + 1:4) = nbr_ig(slot:3)
            nbr_im(slot + 1:4) = nbr_im(slot:3)

            distances(slot) = sq_dist
            nbr_it(slot) = it
            nbr_ig(slot) = ig
            nbr_im(slot) = im
         end do
      end do
   end do

   ! the search works on squared distances; return true distances for weighting
   distances = sqrt(distances)

contains

   ! squared difference between each range-normalised grid value and the
   ! range-normalised target; all zeros for a dummy or zero-range axis
   function axis_sq_offsets(grid, target) result(sq)
      real(dp), intent(in) :: grid(:), target
      real(dp) :: sq(size(grid))
      real(dp) :: lo, span
      logical :: is_dummy

      sq = 0.0_dp
      is_dummy = all(grid == 0.0_dp) .or. all(grid == 999.0_dp) .or. &
                 all(grid == -999.0_dp)
      if (is_dummy) return

      lo = minval(grid)
      span = maxval(grid) - lo
      if (span > 0.0_dp) sq = ((grid - lo)/span - (target - lo)/span)**2
   end function axis_sq_offsets

end subroutine get_closest_grid_points
277 :
! Find the four rows of the flat lookup table closest to the target, using a
! Euclidean distance in which every axis is rescaled to its own [min, max]
! range.
!
! Arguments
!   teff, log_g, metallicity   target parameters
!   lu_teff, lu_logg, lu_meta  lookup-table columns; indexed 1..size(lu_teff),
!                              so they are assumed to share one length
!   closest_indices (out)      row indices, nearest first; ties keep the
!                              lower row index. Slots that could not be
!                              filled (fewer than 4 rows) are -1.
!
! An axis is left out of the distance when it is a dummy axis (all values 0,
! 999 or -999) or has zero range. The target is only normalised on axes that
! are actually used, so no division by a zero range is ever evaluated (this
! matters for builds that trap floating-point exceptions).
subroutine get_closest_stellar_models(teff, log_g, metallicity, lu_teff, &
                                      lu_logg, lu_meta, closest_indices)
   real(dp), intent(in) :: teff, log_g, metallicity
   real(dp), intent(in) :: lu_teff(:), lu_logg(:), lu_meta(:)
   integer, dimension(4), intent(out) :: closest_indices

   integer :: i, j, n
   real(dp) :: sq_dist
   real(dp), dimension(4) :: min_sq_dist
   real(dp) :: teff_lo, teff_span, norm_teff
   real(dp) :: logg_lo, logg_span, norm_logg
   real(dp) :: meta_lo, meta_span, norm_meta
   logical :: use_teff_dim, use_logg_dim, use_meta_dim

   closest_indices = -1
   min_sq_dist = huge(1.0_dp)

   n = size(lu_teff)
   ! empty table: nothing to rank, and minval/maxval of an empty array would
   ! overflow when subtracted
   if (n == 0) return

   call setup_axis(lu_teff, teff, teff_lo, teff_span, norm_teff, use_teff_dim)
   call setup_axis(lu_logg, log_g, logg_lo, logg_span, norm_logg, use_logg_dim)
   call setup_axis(lu_meta, metallicity, meta_lo, meta_span, norm_meta, use_meta_dim)

   do i = 1, n
      ! squared normalised distance over the active axes only
      sq_dist = 0.0_dp
      if (use_teff_dim) &
         sq_dist = sq_dist + ((lu_teff(i) - teff_lo)/teff_span - norm_teff)**2
      if (use_logg_dim) &
         sq_dist = sq_dist + ((lu_logg(i) - logg_lo)/logg_span - norm_logg)**2
      if (use_meta_dim) &
         sq_dist = sq_dist + ((lu_meta(i) - meta_lo)/meta_span - norm_meta)**2

      ! insert into the sorted top-4 if closer
      do j = 1, 4
         if (sq_dist < min_sq_dist(j)) then
            ! shift larger distances down
            if (j < 4) then
               min_sq_dist(j + 1:4) = min_sq_dist(j:3)
               closest_indices(j + 1:4) = closest_indices(j:3)
            end if
            min_sq_dist(j) = sq_dist
            closest_indices(j) = i
            exit
         end if
      end do
   end do

contains

   ! Per-axis normalisation data. active is .false. for a dummy axis
   ! (all 0 / 999 / -999) or a zero-range axis; norm_target is only computed
   ! when the axis is active.
   subroutine setup_axis(axis, target, lo, span, norm_target, active)
      real(dp), intent(in) :: axis(:), target
      real(dp), intent(out) :: lo, span, norm_target
      logical, intent(out) :: active

      lo = minval(axis)
      span = maxval(axis) - lo
      norm_target = 0.0_dp
      active = .not. (all(axis == 0.0_dp) .or. all(axis == 999.0_dp) .or. &
                      all(axis == -999.0_dp))
      active = active .and. span > 0.0_dp
      if (active) norm_target = (target - lo)/span
   end subroutine setup_axis

end subroutine get_closest_stellar_models
369 :
! Piecewise-linear interpolation of y(x) at a single abscissa x_val.
!
! x is searched by bisection, so it is expected to be sorted in ascending
! order. Values of x_val at or beyond either end of x are clamped to the
! corresponding end value of y (no extrapolation).
!
! If x has fewer than 2 points, or x and y differ in size, a message is
! printed and y_val is returned as 0.
subroutine linear_interpolate(x, y, x_val, y_val)
   real(dp), intent(in) :: x(:), y(:), x_val
   real(dp), intent(out) :: y_val

   integer :: n, left, right, centre
   real(dp) :: slope

   n = size(x)

   ! result reported by both input-validation failures below
   y_val = 0.0_dp

   if (n < 2) then
      print *, "Error: x array has fewer than 2 points."
      return
   end if

   if (n /= size(y)) then
      print *, "Error: x and y arrays have different sizes."
      return
   end if

   if (x_val <= x(1)) then
      ! at or below the first node: clamp
      y_val = y(1)
   else if (x_val >= x(n)) then
      ! at or above the last node: clamp
      y_val = y(n)
   else
      ! bisect until x(left) <= x_val < x(right) with right = left + 1
      left = 1
      right = n
      do
         if (right - left <= 1) exit
         centre = (left + right)/2
         if (x(centre) <= x_val) then
            left = centre
         else
            right = centre
         end if
      end do

      slope = (y(left + 1) - y(left))/(x(left + 1) - x(left))
      y_val = y(left) + slope*(x_val - x(left))
   end if
end subroutine linear_interpolate
411 :
! Array interpolation for SED/filter alignment: evaluate the tabulated curve
! (x_in, y_in) at every abscissa in x_out by calling linear_interpolate, and
! store the results in y_out(1:size(x_out)).
!
! Arguments
!   x_in, y_in  tabulated curve; same size, at least 2 points. x_in is
!               searched by bisection in linear_interpolate, so it is
!               expected to be sorted ascending.
!   x_out       abscissae to evaluate at; must not be empty
!   y_out       results; must hold at least size(x_out) elements. Elements
!               beyond size(x_out), if any, are not assigned.
!
! Any violated size requirement is reported and passed to mesa_error.
subroutine interpolate_array(x_in, y_in, x_out, y_out)
   real(dp), intent(in) :: x_in(:), y_in(:), x_out(:)
   real(dp), intent(out) :: y_out(:)
   integer :: i

   if (size(x_in) < 2 .or. size(y_in) < 2) then
      print *, "Error: x_in or y_in arrays have fewer than 2 points."
      call mesa_error(__FILE__, __LINE__)
      return
   end if

   if (size(x_in) /= size(y_in)) then
      print *, "Error: x_in and y_in arrays have different sizes."
      call mesa_error(__FILE__, __LINE__)
      return
   end if

   if (size(x_out) <= 0) then
      print *, "Error: x_out array is empty."
      call mesa_error(__FILE__, __LINE__)
      return
   end if

   ! the loop below writes y_out(i) for every x_out(i); refuse an output
   ! buffer that is too short instead of writing past its end
   if (size(y_out) < size(x_out)) then
      print *, "Error: y_out array is smaller than x_out."
      call mesa_error(__FILE__, __LINE__)
      return
   end if

   do i = 1, size(x_out)
      call linear_interpolate(x_in, y_in, x_out(i), y_out(i))
   end do
end subroutine interpolate_array
437 :
438 : end module knn_interp
|