LCOV - code coverage report
Current view: top level - colors/private - knn_interp.f90 (source / functions) Coverage Total Hit
Test: coverage.info Lines: 11.8 % 187 22
Test Date: 2026-05-14 09:58:24 Functions: 28.6 % 7 2

            Line data    Source code
       1              : ! ***********************************************************************
       2              : !
       3              : !   Copyright (C) 2025  Niall Miller & The MESA Team
       4              : !
       5              : !   This program is free software: you can redistribute it and/or modify
       6              : !   it under the terms of the GNU Lesser General Public License
       7              : !   as published by the Free Software Foundation,
       8              : !   either version 3 of the License, or (at your option) any later version.
       9              : !
      10              : !   This program is distributed in the hope that it will be useful,
      11              : !   but WITHOUT ANY WARRANTY; without even the implied warranty of
      12              : !   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
      13              : !   See the GNU Lesser General Public License for more details.
      14              : !
      15              : !   You should have received a copy of the GNU Lesser General Public License
      16              : !   along with this program. If not, see <https://www.gnu.org/licenses/>.
      17              : !
      18              : ! ***********************************************************************
      19              : 
      20              : ! knn interpolation for SEDs
      21              : !
      22              : ! data-loading strategy selected by rq%cube_loaded:
      23              : !   .true.  -> extract neighbor SEDs from the preloaded 4-D cube
      24              : !   .false. -> load individual SED files via the lookup table (fallback)
      25              : 
      26              : module knn_interp
      27              :    use const_def, only: dp
      28              :    use colors_def, only: Colors_General_Info
      29              :    use colors_utils, only: dilute_flux, load_sed_cached
      30              :    use utils_lib, only: mesa_error
      31              :    implicit none
      32              : 
      33              :    private
      34              :    public :: construct_sed_knn, interpolate_array
      35              : 
      36              : contains
      37              : 
      ! main entry point -- build a diluted SED by KNN interpolation
      !
      ! rq%cube_loaded (a field of the handle) picks the data source:
      !   .true.  -> neighbours are blended from the preloaded cube
      !   .false. -> neighbours are blended from SEDs loaded via the lookup table
      ! The blended flux is then handed to dilute_flux together with R and d.
      !
      ! arguments:
      !   rq                         colors handle (passed through to the helpers)
      !   teff, log_g, metallicity   query point
      !   R, d                       forwarded unchanged to dilute_flux
      !   stellar_model_dir          only used by the file-based fallback path
      !   wavelengths (out)          wavelength grid returned by the chosen helper
      !   fluxes (out)               dilute_flux output, same length as wavelengths
      subroutine construct_sed_knn(rq, teff, log_g, metallicity, R, d, &
                                   stellar_model_dir, wavelengths, fluxes)
         type(Colors_General_Info), intent(inout) :: rq
         real(dp), intent(in) :: teff, log_g, metallicity, R, d
         character(len=*), intent(in) :: stellar_model_dir
         real(dp), dimension(:), allocatable, intent(out) :: wavelengths, fluxes

         real(dp), dimension(:), allocatable :: blended, scaled

         if (.not. rq%cube_loaded) then
            ! fallback path: load individual SED files
            call construct_sed_from_files(rq, teff, log_g, metallicity, &
                                          stellar_model_dir, blended, wavelengths)
         else
            ! fast path: extract neighbors from preloaded cube
            call construct_sed_from_cube(rq, teff, log_g, metallicity, &
                                         blended, wavelengths)
         end if

         ! both helpers return the flux on the wavelengths grid, so one size serves both
         allocate (scaled(size(wavelengths)))
         call dilute_flux(blended, R, d, scaled)

         ! hand the buffer to the caller without a second copy
         call move_alloc(scaled, fluxes)

      end subroutine construct_sed_knn
      67              : 
      ! cube path: find 4 nearest grid points, extract their SEDs from cube_flux, blend by IDW
      !
      ! arguments:
      !   rq                         handle; reads cube_wavelengths, cube_teff_grid,
      !                              cube_logg_grid, cube_meta_grid and cube_flux
      !   teff, log_g, metallicity   query point
      !   interp_flux (out)          inverse-distance-weighted blend of the 4 neighbour SEDs
      !   wavelengths (out)          copy of rq%cube_wavelengths
      subroutine construct_sed_from_cube(rq, teff, log_g, metallicity, &
                                         interp_flux, wavelengths)
         type(Colors_General_Info), intent(inout) :: rq
         real(dp), intent(in) :: teff, log_g, metallicity
         real(dp), dimension(:), allocatable, intent(out) :: interp_flux, wavelengths

         integer :: n_wave, nb
         integer, dimension(4) :: idx_t, idx_g, idx_m
         real(dp), dimension(4) :: nbr_dist, nbr_weight

         ! locate the 4 nearest grid points in the structured cube
         call get_closest_grid_points(teff, log_g, metallicity, &
                                      rq%cube_teff_grid, rq%cube_logg_grid, &
                                      rq%cube_meta_grid, &
                                      idx_t, idx_g, idx_m, nbr_dist)

         ! inverse-distance weights; an exact hit is nudged off zero so that
         ! 1/distance stays finite and that neighbour dominates the blend
         where (nbr_dist == 0.0_dp) nbr_dist = 1.0e-10_dp
         nbr_weight = 1.0_dp/nbr_dist
         nbr_weight = nbr_weight/sum(nbr_weight)

         n_wave = size(rq%cube_wavelengths)
         allocate (wavelengths(n_wave))
         wavelengths = rq%cube_wavelengths

         ! accumulate the weighted neighbour SEDs (last cube index is indexed
         ! over its full extent, i.e. one whole SED per neighbour)
         allocate (interp_flux(n_wave))
         interp_flux = 0.0_dp
         do nb = 1, 4
            interp_flux = interp_flux + nbr_weight(nb)* &
                          rq%cube_flux(idx_t(nb), idx_g(nb), idx_m(nb), :)
         end do

      end subroutine construct_sed_from_cube
     107              : 
     ! fallback path: find 4 nearest models in the lookup table, load SEDs, blend by IDW
     !
     ! arguments:
     !   rq                         handle; reads lu_teff, lu_logg, lu_meta,
     !                              fallback_wavelengths_set and fallback_wavelengths,
     !                              and is passed (inout) to load_sed_cached
     !   teff, log_g, metallicity   query point
     !   stellar_model_dir          model directory, run through resolve_path first
     !   interp_flux (out)          inverse-distance-weighted blend of the 4 loaded SEDs
     !   wavelengths (out)          copy of rq%fallback_wavelengths
     !
     ! Every failure is reported with print + mesa_error, like the rest of this
     ! module.  The code after such a call assumes mesa_error does not return
     ! -- TODO confirm against utils_lib.
     !
     ! NOTE(review): the neighbours are *selected* by get_closest_stellar_models,
     ! whose comments describe a normalised distance, but the IDW weights below
     ! use raw, unnormalised parameter differences, so the Teff term presumably
     ! dominates.  The cube path weights by the normalised distances it gets
     ! back from get_closest_grid_points.  Left unchanged here because it
     ! alters numerical results -- confirm which behaviour is intended.
     subroutine construct_sed_from_files(rq, teff, log_g, metallicity, &
                                         stellar_model_dir, interp_flux, wavelengths)
        use colors_utils, only: resolve_path
        type(Colors_General_Info), intent(inout) :: rq
        real(dp), intent(in) :: teff, log_g, metallicity
        character(len=*), intent(in) :: stellar_model_dir
        real(dp), dimension(:), allocatable, intent(out) :: interp_flux, wavelengths

        integer, parameter :: n_neighbors = 4
        real(dp), parameter :: min_distance = 1.0e-10_dp

        integer, dimension(n_neighbors) :: closest_indices
        real(dp), dimension(n_neighbors) :: weights, distances
        real(dp), dimension(:), allocatable :: temp_flux
        real(dp), dimension(:, :), allocatable :: model_fluxes   ! (wavelength, neighbour)
        integer :: i, n_points
        character(len=512) :: resolved_dir

        n_points = 0
        resolved_dir = trim(resolve_path(stellar_model_dir))

        ! get the four closest stellar models from the flat lookup table
        call get_closest_stellar_models(teff, log_g, metallicity, &
                                        rq%lu_teff, rq%lu_logg, rq%lu_meta, &
                                        closest_indices)

        ! slots the search could not fill come back as -1; using one as a
        ! subscript below would read outside the lookup arrays
        if (any(closest_indices < 1)) then
           print *, 'KNN fallback: could not find 4 neighbor models in the lookup table'
           call mesa_error(__FILE__, __LINE__)
        end if

        ! load the neighbour SEDs (using cache); the first load also defines
        ! the wavelength grid
        do i = 1, n_neighbors
           call load_sed_cached(rq, resolved_dir, closest_indices(i), temp_flux)

           if (i == 1) then
              ! get wavelengths from canonical copy on the handle
              if (.not. rq%fallback_wavelengths_set) then
                 ! should not happen -- load_sed_cached is expected to set this on first call
                 print *, 'KNN fallback: wavelengths not set after first SED load'
                 call mesa_error(__FILE__, __LINE__)
              end if
              n_points = size(rq%fallback_wavelengths)
              allocate (model_fluxes(n_points, n_neighbors))
           end if

           ! validate before slicing: the flux must exist and cover the grid
           if (.not. allocated(temp_flux)) then
              print *, 'KNN fallback: SED load returned no flux for model', closest_indices(i)
              call mesa_error(__FILE__, __LINE__)
           else if (size(temp_flux) < n_points) then
              print *, 'KNN fallback: SED shorter than wavelength grid for model', &
                 closest_indices(i)
              call mesa_error(__FILE__, __LINE__)
           end if

           model_fluxes(:, i) = temp_flux(1:n_points)
           deallocate (temp_flux)
        end do

        ! distances (raw parameter units, see NOTE above) and IDW weights
        do i = 1, n_neighbors
           distances(i) = sqrt((rq%lu_teff(closest_indices(i)) - teff)**2 + &
                               (rq%lu_logg(closest_indices(i)) - log_g)**2 + &
                               (rq%lu_meta(closest_indices(i)) - metallicity)**2)
        end do
        ! an exact hit is nudged off zero so that 1/distance stays finite
        where (distances == 0.0_dp) distances = min_distance
        weights = 1.0_dp/distances
        weights = weights/sum(weights)

        allocate (wavelengths(n_points))
        wavelengths = rq%fallback_wavelengths

        ! weighted combination of model fluxes
        allocate (interp_flux(n_points))
        interp_flux = 0.0_dp
        do i = 1, n_neighbors
           interp_flux = interp_flux + weights(i)*model_fluxes(:, i)
        end do

     end subroutine construct_sed_from_files
     181              : 
     ! find the 4 closest grid points in the structured cube (normalised euclidean distance)
     !
     ! Each axis is rescaled by its own (max - min) before distances are formed.
     ! An axis whose grid values are all 0, all 999 or all -999 is treated as a
     ! dummy and contributes nothing, as does an axis with zero range.
     !
     ! arguments:
     !   teff, log_g, metallicity          query point
     !   teff_grid, logg_grid, meta_grid   axis values of the cube
     !   nbr_it, nbr_ig, nbr_im (out)      axis indices of the 4 neighbours, nearest
     !                                     first; slots that never get filled keep
     !                                     index 1
     !   distances (out)                   matching normalised distances; unfilled
     !                                     slots hold sqrt(huge(1.0_dp))
     subroutine get_closest_grid_points(teff, log_g, metallicity, &
                                        teff_grid, logg_grid, meta_grid, &
                                        nbr_it, nbr_ig, nbr_im, distances)
        real(dp), intent(in) :: teff, log_g, metallicity
        real(dp), intent(in) :: teff_grid(:), logg_grid(:), meta_grid(:)
        integer, dimension(4), intent(out) :: nbr_it, nbr_ig, nbr_im
        real(dp), dimension(4), intent(out) :: distances

        real(dp), allocatable :: sq_t(:), sq_g(:), sq_m(:)
        real(dp) :: sq_dist
        integer :: i_t, i_g, i_m, slot

        distances = huge(1.0_dp)
        nbr_it = 1
        nbr_ig = 1
        nbr_im = 1

        ! squared normalised offset of every grid value from the query, per axis
        sq_t = axis_sq_offsets(teff_grid, teff)
        sq_g = axis_sq_offsets(logg_grid, log_g)
        sq_m = axis_sq_offsets(meta_grid, metallicity)

        do i_t = 1, size(teff_grid)
           do i_g = 1, size(logg_grid)
              do i_m = 1, size(meta_grid)
                 sq_dist = sq_t(i_t) + sq_g(i_g) + sq_m(i_m)

                 ! first slot of the sorted top-4 that this point beats, if any;
                 ! the strict "<" keeps earlier points ahead of later ties
                 slot = 1
                 do while (slot <= 4)
                    if (sq_dist < distances(slot)) exit
                    slot = slot + 1
                 end do
                 if (slot > 4) cycle

                 ! shift the tail down by one, then drop the new point in
                 if (slot < 4) then
                    distances(slot + 1:4) = distances(slot:3)
                    nbr_it(slot + 1:4) = nbr_it(slot:3)
                    nbr_ig(slot + 1:4) = nbr_ig(slot:3)
                    nbr_im(slot + 1:4) = nbr_im(slot:3)
                 end if
                 distances(slot) = sq_dist
                 nbr_it(slot) = i_t
                 nbr_ig(slot) = i_g
                 nbr_im(slot) = i_m
              end do
           end do
        end do

        ! the search ran on squared distances; callers weight by the distance itself
        distances = sqrt(distances)

     contains

        ! squared normalised offsets of grid from query along one axis;
        ! all zero for a dummy axis (all 0 / 999 / -999) or a zero-range axis
        pure function axis_sq_offsets(grid, query) result(sq)
           real(dp), intent(in) :: grid(:), query
           real(dp) :: sq(size(grid))
           real(dp) :: lo, span

           sq = 0.0_dp
           if (all(grid == 0.0_dp) .or. all(grid == 999.0_dp) .or. &
               all(grid == -999.0_dp)) return

           lo = minval(grid)
           span = maxval(grid) - lo
           if (span > 0.0_dp) sq = ((grid - lo)/span - (query - lo)/span)**2
        end function axis_sq_offsets

     end subroutine get_closest_grid_points
     277              : 
     ! find the four closest stellar models in the flat lookup table
     !
     ! Distances are euclidean in parameters rescaled per axis by (max - min).
     ! An axis contributes nothing when its lookup values are all 0, all 999 or
     ! all -999 (dummy axis) or when its range is zero.  The query is only
     ! normalised on axes that pass both tests, so a degenerate axis can no
     ! longer trigger a division by zero (which aborts the run when
     ! floating-point trapping is enabled).
     !
     ! arguments:
     !   teff, log_g, metallicity     query point
     !   lu_teff, lu_logg, lu_meta    lookup-table columns; all three are indexed
     !                                up to size(lu_teff), so they must be at least
     !                                that long
     !   closest_indices (out)        table indices of the 4 nearest models, nearest
     !                                first; ties keep table order.  Slots that
     !                                cannot be filled (fewer than 4 usable models)
     !                                are returned as -1 and must be checked by the
     !                                caller before being used as subscripts.
     subroutine get_closest_stellar_models(teff, log_g, metallicity, lu_teff, &
                                           lu_logg, lu_meta, closest_indices)
        real(dp), intent(in) :: teff, log_g, metallicity
        real(dp), intent(in) :: lu_teff(:), lu_logg(:), lu_meta(:)
        integer, dimension(4), intent(out) :: closest_indices

        integer :: i, j
        real(dp) :: distance
        real(dp), dimension(:), allocatable :: teff_sq, logg_sq, meta_sq
        real(dp), dimension(4) :: min_distances

        min_distances = huge(1.0_dp)
        closest_indices = -1

        ! squared normalised offset of every model from the query, per axis
        teff_sq = axis_sq_offsets(lu_teff, teff)
        logg_sq = axis_sq_offsets(lu_logg, log_g)
        meta_sq = axis_sq_offsets(lu_meta, metallicity)

        do i = 1, size(lu_teff)
           distance = teff_sq(i) + logg_sq(i) + meta_sq(i)

           ! insert into the sorted top-4 if closer
           do j = 1, 4
              if (distance < min_distances(j)) then
                 ! shift larger distances down
                 if (j < 4) then
                    min_distances(j + 1:4) = min_distances(j:3)
                    closest_indices(j + 1:4) = closest_indices(j:3)
                 end if
                 min_distances(j) = distance
                 closest_indices(j) = i
                 exit
              end if
           end do
        end do

     contains

        ! squared normalised offsets of column from query along one axis;
        ! all zero for a dummy axis (all 0 / 999 / -999) or a zero-range axis,
        ! and no division is performed in either of those cases
        pure function axis_sq_offsets(column, query) result(sq)
           real(dp), intent(in) :: column(:), query
           real(dp) :: sq(size(column))
           real(dp) :: lo, span

           sq = 0.0_dp
           if (all(column == 0.0_dp) .or. all(column == 999.0_dp) .or. &
               all(column == -999.0_dp)) return

           lo = minval(column)
           span = maxval(column) - lo
           if (span > 0.0_dp) sq = ((column - lo)/span - (query - lo)/span)**2
        end function axis_sq_offsets

     end subroutine get_closest_stellar_models
     369              : 
     ! linear interpolation -- binary search
     !
     ! Evaluates at x_val the piecewise-linear curve through the points (x, y).
     ! The bisection assumes x is sorted in ascending order.
     !
     ! arguments:
     !   x, y          sample abscissae and ordinates, same length, at least 2 points
     !   x_val         where to evaluate
     !   y_val (out)   interpolated value; clamped to y(1) / y(size(y)) when x_val
     !                 lies at or beyond the ends of x; 0 after either of the two
     !                 size errors, which are reported with print only
     subroutine linear_interpolate(x, y, x_val, y_val)
        real(dp), intent(in) :: x(:), y(:), x_val
        real(dp), intent(out) :: y_val
        integer :: n, lo, hi, mid
        real(dp) :: slope

        n = size(x)
        y_val = 0.0_dp   ! result of the two error exits below

        if (n < 2) then
           print *, "Error: x array has fewer than 2 points."
           return
        end if
        if (size(y) /= n) then
           print *, "Error: x and y arrays have different sizes."
           return
        end if

        ! no extrapolation: clamp to the end values
        if (x_val <= x(1)) then
           y_val = y(1)
           return
        end if
        if (x_val >= x(n)) then
           y_val = y(n)
           return
        end if

        ! bisect until x(lo) <= x_val < x(lo + 1)
        lo = 1
        hi = n
        do
           if (hi - lo <= 1) exit
           mid = lo + (hi - lo)/2
           if (x(mid) <= x_val) then
              lo = mid
           else
              hi = mid
           end if
        end do

        slope = (y(lo + 1) - y(lo))/(x(lo + 1) - x(lo))
        y_val = y(lo) + slope*(x_val - x(lo))
     end subroutine linear_interpolate
     411              : 
     ! array interpolation for SED/filter alignment
     !
     ! Fills y_out(i) with the linear interpolant of (x_in, y_in) evaluated at
     ! x_out(i), one linear_interpolate call per output point.
     !
     ! arguments:
     !   x_in, y_in    input samples, same length, at least 2 points
     !   x_out         abscissae to evaluate at; must not be empty
     !   y_out (out)   results; must hold at least size(x_out) elements.  Only
     !                 the first size(x_out) elements are assigned.
     !
     ! Any violated precondition is reported with print + mesa_error.
     subroutine interpolate_array(x_in, y_in, x_out, y_out)
        real(dp), intent(in) :: x_in(:), y_in(:), x_out(:)
        real(dp), intent(out) :: y_out(:)
        integer :: i

        if (size(x_in) < 2 .or. size(y_in) < 2) then
           print *, "Error: x_in or y_in arrays have fewer than 2 points."
           call mesa_error(__FILE__, __LINE__)
        end if

        if (size(x_in) /= size(y_in)) then
           print *, "Error: x_in and y_in arrays have different sizes."
           call mesa_error(__FILE__, __LINE__)
        end if

        if (size(x_out) <= 0) then
           print *, "Error: x_out array is empty."
           call mesa_error(__FILE__, __LINE__)
        end if

        ! the loop below writes y_out(1:size(x_out)); a shorter y_out would be
        ! written past its end
        if (size(y_out) < size(x_out)) then
           print *, "Error: y_out array is smaller than x_out."
           call mesa_error(__FILE__, __LINE__)
        end if

        do i = 1, size(x_out)
           call linear_interpolate(x_in, y_in, x_out(i), y_out(i))
        end do
     end subroutine interpolate_array
     437              : 
     438              : end module knn_interp
        

Generated by: LCOV version 2.0-1