Line data Source code
1 : ! ***********************************************************************
2 : !
3 : ! Copyright (C) 2025 Niall Miller & The MESA Team
4 : !
5 : ! This program is free software: you can redistribute it and/or modify
6 : ! it under the terms of the GNU Lesser General Public License
7 : ! as published by the Free Software Foundation,
8 : ! either version 3 of the License, or (at your option) any later version.
9 : !
10 : ! This program is distributed in the hope that it will be useful,
11 : ! but WITHOUT ANY WARRANTY; without even the implied warranty of
12 : ! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
13 : ! See the GNU Lesser General Public License for more details.
14 : !
15 : ! You should have received a copy of the GNU Lesser General Public License
16 : ! along with this program. If not, see <https://www.gnu.org/licenses/>.
17 : !
18 : ! ***********************************************************************
19 :
20 : ! ***********************************************************************
21 : ! K-Nearest Neighbors interpolation module for spectral energy distributions (SEDs)
22 : ! ***********************************************************************
23 :
module knn_interp
   use const_def, only: dp
   use colors_utils, only: dilute_flux, load_sed
   use utils_lib, only: mesa_error
   implicit none

   private
   public :: construct_sed_knn, load_sed, interpolate_array, dilute_flux

   !> Number of nearest grid models blended into one SED.
   integer, parameter :: n_neighbors = 4

   !> Floor applied to a neighbour distance before it is inverted, so an exact
   !> grid match receives an overwhelming but finite weight (no division by 0).
   real(dp), parameter :: min_neighbor_distance = 1.0e-10_dp

contains

   !---------------------------------------------------------------------------
   ! Main entry point: construct a SED by inverse-distance weighting of the
   ! (up to) n_neighbors grid models closest to (teff, log_g, metallicity).
   !
   ! Inputs:
   !   teff, log_g, metallicity   target stellar parameters
   !   R, d                       forwarded to dilute_flux (surface -> observed)
   !   file_names                 SED file name of every grid entry
   !   lu_teff, lu_logg, lu_meta  grid parameters, one entry per file
   !   stellar_model_dir          prefix prepended to every file name
   ! Outputs:
   !   wavelengths                wavelength grid of the nearest model
   !   fluxes                     blended, diluted flux on that grid
   !
   ! Neighbours are selected AND weighted in the same range-normalised
   ! parameter space, so Teff (K) does not swamp log g / metallicity, and
   ! constant placeholder axes (e.g. all 999) contribute nothing instead of
   ! adding a large common offset that would flatten the weights.
   !---------------------------------------------------------------------------
   subroutine construct_sed_knn(teff, log_g, metallicity, R, d, file_names, &
                                lu_teff, lu_logg, lu_meta, stellar_model_dir, &
                                wavelengths, fluxes)
      real(dp), intent(in) :: teff, log_g, metallicity, R, d
      real(dp), intent(in) :: lu_teff(:), lu_logg(:), lu_meta(:)
      character(len=*), intent(in) :: stellar_model_dir
      character(len=100), intent(in) :: file_names(:)
      real(dp), dimension(:), allocatable, intent(out) :: wavelengths, fluxes

      integer :: closest_indices(n_neighbors)
      real(dp) :: closest_distances(n_neighbors)
      integer, allocatable :: models(:)
      real(dp), allocatable :: weights(:)
      real(dp), allocatable :: temp_wavelengths(:), temp_flux(:)
      real(dp), allocatable :: model_fluxes(:, :)   ! (n_points, n_models)
      real(dp), allocatable :: diluted_flux(:)
      integer :: i, n_models, n_points

      call get_closest_stellar_models(teff, log_g, metallicity, lu_teff, &
                                      lu_logg, lu_meta, closest_indices, &
                                      closest_distances)

      ! Unfilled slots stay at -1 when the grid has fewer than n_neighbors
      ! entries (or no finite distance exists, e.g. a NaN target); blend only
      ! the models that were actually found instead of indexing file_names(-1).
      allocate (models, source=pack(closest_indices, closest_indices > 0))
      allocate (weights, source=pack(closest_distances, closest_indices > 0))
      n_models = size(models)
      if (n_models == 0) then
         print *, "Error: no stellar model available for KNN SED interpolation."
         call mesa_error(__FILE__, __LINE__)
         return
      end if

      ! Inverse-distance weights, normalised to sum to one.
      weights = 1.0_dp/max(weights, min_neighbor_distance)
      weights = weights/sum(weights)

      ! Load every neighbour and resample it onto the nearest model's grid.
      do i = 1, n_models
         call load_sed(trim(stellar_model_dir)//trim(file_names(models(i))), &
                       models(i), temp_wavelengths, temp_flux)
         if (i == 1) then
            ! The nearest model's wavelength grid is the common output grid.
            n_points = size(temp_wavelengths)
            allocate (wavelengths, source=temp_wavelengths)
            allocate (model_fluxes(n_points, n_models))
         end if
         ! Column-major storage keeps each model's flux contiguous.
         call interpolate_array(temp_wavelengths, temp_flux, wavelengths, model_fluxes(:, i))
      end do

      ! Weighted combination of the model fluxes (still at the stellar surface).
      allocate (fluxes(n_points))
      fluxes = matmul(model_fluxes, weights)

      ! Now, apply the dilution factor (R/d)^2 to convert the surface flux density
      ! into the observed flux density at Earth.
      allocate (diluted_flux(n_points))
      call dilute_flux(fluxes, R, d, diluted_flux)
      fluxes = diluted_flux
   end subroutine construct_sed_knn

   !---------------------------------------------------------------------------
   ! Identify the (up to) n_neighbors grid entries closest to the target in a
   ! parameter space where each axis is scaled by its grid range.
   !
   ! closest_indices    grid indices sorted by increasing distance; slots that
   !                    could not be filled (grid smaller than n_neighbors, or
   !                    a non-finite target) are set to -1.
   ! closest_distances  (optional) the matching normalised Euclidean distances;
   !                    huge(1.0_dp) for unfilled slots.
   !---------------------------------------------------------------------------
   subroutine get_closest_stellar_models(teff, log_g, metallicity, lu_teff, &
                                         lu_logg, lu_meta, closest_indices, &
                                         closest_distances)
      real(dp), intent(in) :: teff, log_g, metallicity
      real(dp), intent(in) :: lu_teff(:), lu_logg(:), lu_meta(:)
      integer, dimension(n_neighbors), intent(out) :: closest_indices
      real(dp), dimension(n_neighbors), intent(out), optional :: closest_distances

      real(dp) :: best_sq(n_neighbors)
      real(dp) :: teff_min, teff_max, logg_min, logg_max, meta_min, meta_max
      real(dp) :: dist_sq
      integer :: i, j, n

      n = size(lu_teff)
      closest_indices = -1
      best_sq = huge(1.0_dp)
      if (present(closest_distances)) closest_distances = huge(1.0_dp)

      if (size(lu_logg) /= n .or. size(lu_meta) /= n) then
         print *, "Error: lookup table columns have different sizes."
         call mesa_error(__FILE__, __LINE__)
         return
      end if

      ! Grid ranges used to put all axes on a comparable [0, 1] scale.
      teff_min = minval(lu_teff)
      teff_max = maxval(lu_teff)
      logg_min = minval(lu_logg)
      logg_max = maxval(lu_logg)
      meta_min = minval(lu_meta)
      meta_max = maxval(lu_meta)

      do i = 1, n
         dist_sq = axis_sq_offset(lu_teff(i), teff, teff_min, teff_max) &
                   + axis_sq_offset(lu_logg(i), log_g, logg_min, logg_max) &
                   + axis_sq_offset(lu_meta(i), metallicity, meta_min, meta_max)

         ! Insert into the ascending list of best candidates; entries from
         ! slot j onward shift down by one (zero-size sections when j is last).
         do j = 1, n_neighbors
            if (dist_sq < best_sq(j)) then
               best_sq(j + 1:n_neighbors) = best_sq(j:n_neighbors - 1)
               closest_indices(j + 1:n_neighbors) = closest_indices(j:n_neighbors - 1)
               best_sq(j) = dist_sq
               closest_indices(j) = i
               exit
            end if
         end do
      end do

      if (present(closest_distances)) then
         where (closest_indices > 0) closest_distances = sqrt(best_sq)
      end if
   end subroutine get_closest_stellar_models

   !---------------------------------------------------------------------------
   ! Squared offset between a grid value and the target along one parameter
   ! axis, after scaling by that axis' grid range [axis_min, axis_max].
   ! A degenerate axis (axis_max == axis_min) -- a grid with a single value,
   ! or a constant placeholder column such as all 0 / 999 / -999 -- returns 0,
   ! so it never contributes to a distance and never divides by zero.
   !---------------------------------------------------------------------------
   pure function axis_sq_offset(grid_val, target_val, axis_min, axis_max) result(sq)
      real(dp), intent(in) :: grid_val, target_val, axis_min, axis_max
      real(dp) :: sq
      real(dp) :: span

      span = axis_max - axis_min
      if (span > 0.0_dp) then
         sq = ((grid_val - target_val)/span)**2
      else
         sq = 0.0_dp
      end if
   end function axis_sq_offset

   !---------------------------------------------------------------------------
   ! Linear interpolation of y(x) at x_val (binary search for the interval).
   ! Assumes x is sorted ascending. Values outside [x(1), x(size(x))] are
   ! clamped to the end values (no extrapolation).
   !---------------------------------------------------------------------------
   subroutine linear_interpolate(x, y, x_val, y_val)
      real(dp), intent(in) :: x(:), y(:), x_val
      real(dp), intent(out) :: y_val
      integer :: low, high, mid

      ! Validate input sizes
      if (size(x) < 2) then
         print *, "Error: x array has fewer than 2 points."
         y_val = 0.0_dp
         return
      end if

      if (size(x) /= size(y)) then
         print *, "Error: x and y arrays have different sizes."
         y_val = 0.0_dp
         return
      end if

      ! Clamp out-of-range requests to the nearest end value.
      if (x_val <= x(1)) then
         y_val = y(1)
         return
      else if (x_val >= x(size(x))) then
         y_val = y(size(y))
         return
      end if

      ! Binary search for low such that x(low) <= x_val < x(low+1); the strict
      ! upper inequality guarantees a non-zero denominator below.
      low = 1
      high = size(x)
      do while (high - low > 1)
         mid = (low + high)/2
         if (x(mid) <= x_val) then
            low = mid
         else
            high = mid
         end if
      end do

      y_val = y(low) + (y(low + 1) - y(low))/(x(low + 1) - x(low))*(x_val - x(low))
   end subroutine linear_interpolate

   !---------------------------------------------------------------------------
   ! Resample y_in(x_in) onto the points x_out, writing the result to y_out.
   ! x_in is assumed to be sorted ascending; points outside its range take the
   ! nearest end value. y_out must have the same size as x_out.
   !---------------------------------------------------------------------------
   subroutine interpolate_array(x_in, y_in, x_out, y_out)
      real(dp), intent(in) :: x_in(:), y_in(:), x_out(:)
      real(dp), intent(out) :: y_out(:)
      integer :: i

      ! Validate input sizes
      if (size(x_in) < 2 .or. size(y_in) < 2) then
         print *, "Error: x_in or y_in arrays have fewer than 2 points."
         call mesa_error(__FILE__, __LINE__)
         return
      end if

      if (size(x_in) /= size(y_in)) then
         print *, "Error: x_in and y_in arrays have different sizes."
         call mesa_error(__FILE__, __LINE__)
         return
      end if

      if (size(x_out) <= 0) then
         print *, "Error: x_out array is empty."
         call mesa_error(__FILE__, __LINE__)
         return
      end if

      ! Without this check a short y_out would be written out of bounds.
      if (size(y_out) /= size(x_out)) then
         print *, "Error: y_out and x_out arrays have different sizes."
         call mesa_error(__FILE__, __LINE__)
         return
      end if

      do i = 1, size(x_out)
         call linear_interpolate(x_in, y_in, x_out(i), y_out(i))
      end do
   end subroutine interpolate_array

end module knn_interp
|