LCOV - code coverage report
Current view: top level - colors/private - knn_interp.f90 (source / functions) Coverage Total Hit
Test: coverage.info Lines: 19.1 % 115 22
Test Date: 2026-01-29 18:28:55 Functions: 50.0 % 4 2

            Line data    Source code
       1              : ! ***********************************************************************
       2              : !
       3              : !   Copyright (C) 2025  Niall Miller & The MESA Team
       4              : !
       5              : !   This program is free software: you can redistribute it and/or modify
       6              : !   it under the terms of the GNU Lesser General Public License
       7              : !   as published by the Free Software Foundation,
       8              : !   either version 3 of the License, or (at your option) any later version.
       9              : !
      10              : !   This program is distributed in the hope that it will be useful,
      11              : !   but WITHOUT ANY WARRANTY; without even the implied warranty of
      12              : !   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
      13              : !   See the GNU Lesser General Public License for more details.
      14              : !
      15              : !   You should have received a copy of the GNU Lesser General Public License
      16              : !   along with this program. If not, see <https://www.gnu.org/licenses/>.
      17              : !
      18              : ! ***********************************************************************
      19              : 
      20              : ! ***********************************************************************
      21              : ! K-Nearest Neighbors interpolation module for spectral energy distributions (SEDs)
      22              : ! ***********************************************************************
      23              : 
      24              : module knn_interp
      25              :    use const_def, only: dp
      26              :    use colors_utils, only: dilute_flux, load_sed
      27              :    use utils_lib, only: mesa_error
      28              :    implicit none
      29              : 
      30              :    private
      31              :    public :: construct_sed_knn, load_sed, interpolate_array, dilute_flux
      32              : 
      33              : contains
      34              : 
      35              :    !---------------------------------------------------------------------------
      36              :    ! Main entry point: Construct a SED using KNN interpolation
      37              :    !---------------------------------------------------------------------------
      38            0 :    subroutine construct_sed_knn(teff, log_g, metallicity, R, d, file_names, &
      39            0 :                                 lu_teff, lu_logg, lu_meta, stellar_model_dir, &
      40              :                                 wavelengths, fluxes)
      41              :       real(dp), intent(in) :: teff, log_g, metallicity, R, d
      42              :       real(dp), intent(in) :: lu_teff(:), lu_logg(:), lu_meta(:)
      43              :       character(len=*), intent(in) :: stellar_model_dir
      44              :       character(len=100), intent(in) :: file_names(:)
      45              :       real(dp), dimension(:), allocatable, intent(out) :: wavelengths, fluxes
      46              : 
      47              :       integer, dimension(4) :: closest_indices
      48            0 :       real(dp), dimension(:), allocatable :: temp_wavelengths, temp_flux, common_wavelengths
      49            0 :       real(dp), dimension(:, :), allocatable :: model_fluxes
      50              :       real(dp), dimension(4) :: weights, distances
      51              :       integer :: i, n_points
      52              :       real(dp) :: sum_weights
      53            0 :       real(dp), dimension(:), allocatable :: diluted_flux
      54              : 
      55              :       ! Get the four closest stellar models
      56              :       call get_closest_stellar_models(teff, log_g, metallicity, lu_teff, &
      57            0 :                                       lu_logg, lu_meta, closest_indices)
      58              : 
      59              :       ! Load the first SED to define the wavelength grid
      60            0 :       call load_sed(trim(stellar_model_dir)//trim(file_names(closest_indices(1))), &
      61            0 :                     closest_indices(1), temp_wavelengths, temp_flux)
      62              : 
      63            0 :       n_points = size(temp_wavelengths)
      64            0 :       allocate (common_wavelengths(n_points))
      65            0 :       common_wavelengths = temp_wavelengths
      66              : 
      67              :       ! Allocate flux array for the models (4 models, n_points each)
      68            0 :       allocate (model_fluxes(4, n_points))
      69            0 :       call interpolate_array(temp_wavelengths, temp_flux, common_wavelengths, model_fluxes(1, :))
      70              : 
      71              :       ! Load and interpolate remaining SEDs
      72            0 :       do i = 2, 4
      73            0 :          call load_sed(trim(stellar_model_dir)//trim(file_names(closest_indices(i))), &
      74            0 :                        closest_indices(i), temp_wavelengths, temp_flux)
      75              : 
      76            0 :          call interpolate_array(temp_wavelengths, temp_flux, common_wavelengths, model_fluxes(i, :))
      77              :       end do
      78              : 
      79              :       ! Compute distances and weights for the four models
      80            0 :       do i = 1, 4
      81            0 :          distances(i) = sqrt((lu_teff(closest_indices(i)) - teff)**2 + &
      82            0 :                              (lu_logg(closest_indices(i)) - log_g)**2 + &
      83            0 :                              (lu_meta(closest_indices(i)) - metallicity)**2)
      84            0 :          if (distances(i) == 0.0_dp) distances(i) = 1.0d-10  ! Prevent division by zero
      85            0 :          weights(i) = 1.0_dp/distances(i)
      86              :       end do
      87              : 
      88              :       ! Normalize weights
      89            0 :       sum_weights = sum(weights)
      90            0 :       weights = weights/sum_weights
      91              : 
      92              :       ! Allocate output arrays
      93            0 :       allocate (wavelengths(n_points), fluxes(n_points))
      94            0 :       wavelengths = common_wavelengths
      95            0 :       fluxes = 0.0_dp
      96              : 
      97              :       ! Perform weighted combination of the model fluxes (still at the stellar surface)
      98            0 :       do i = 1, 4
      99            0 :          fluxes = fluxes + weights(i)*model_fluxes(i, :)
     100              :       end do
     101              : 
     102              :       ! Now, apply the dilution factor (R/d)^2 to convert the surface flux density
     103              :       ! into the observed flux density at Earth.
     104            0 :       allocate (diluted_flux(n_points))
     105            0 :       call dilute_flux(fluxes, R, d, diluted_flux)
     106            0 :       fluxes = diluted_flux
     107              : 
     108            0 :    end subroutine construct_sed_knn
     109              : 
     110              :    !---------------------------------------------------------------------------
     111              :    ! Identify the four closest stellar models
     112              :    !---------------------------------------------------------------------------
     113            0 :    subroutine get_closest_stellar_models(teff, log_g, metallicity, lu_teff, &
     114            0 :                                          lu_logg, lu_meta, closest_indices)
     115              :       real(dp), intent(in) :: teff, log_g, metallicity
     116              :       real(dp), intent(in) :: lu_teff(:), lu_logg(:), lu_meta(:)
     117              :       integer, dimension(4), intent(out) :: closest_indices
     118              :       logical :: use_teff_dim, use_logg_dim, use_meta_dim
     119              : 
     120              :       integer :: i, n, j
     121              :       real(dp) :: distance, norm_teff, norm_logg, norm_meta
     122            0 :       real(dp), dimension(:), allocatable :: scaled_lu_teff, scaled_lu_logg, scaled_lu_meta
     123              :       real(dp), dimension(4) :: min_distances
     124              :       integer, dimension(4) :: indices
     125              :       real(dp) :: teff_min, teff_max, logg_min, logg_max, meta_min, meta_max
     126              :       real(dp) :: teff_dist, logg_dist, meta_dist
     127              : 
     128            0 :       n = size(lu_teff)
     129            0 :       min_distances = huge(1.0)
     130            0 :       indices = -1
     131              : 
     132              :       ! Find min and max for normalization
     133            0 :       teff_min = minval(lu_teff)
     134            0 :       teff_max = maxval(lu_teff)
     135            0 :       logg_min = minval(lu_logg)
     136            0 :       logg_max = maxval(lu_logg)
     137            0 :       meta_min = minval(lu_meta)
     138            0 :       meta_max = maxval(lu_meta)
     139              : 
     140              :       ! Allocate and scale lookup table values
     141            0 :       allocate (scaled_lu_teff(n), scaled_lu_logg(n), scaled_lu_meta(n))
     142              : 
     143            0 :       if (teff_max - teff_min > 0.0_dp) then
     144            0 :          scaled_lu_teff = (lu_teff - teff_min)/(teff_max - teff_min)
     145              :       end if
     146              : 
     147            0 :       if (logg_max - logg_min > 0.0_dp) then
     148            0 :          scaled_lu_logg = (lu_logg - logg_min)/(logg_max - logg_min)
     149              :       end if
     150              : 
     151            0 :       if (meta_max - meta_min > 0.0_dp) then
     152            0 :          scaled_lu_meta = (lu_meta - meta_min)/(meta_max - meta_min)
     153              :       end if
     154              : 
     155              :       ! Normalize input parameters
     156            0 :       norm_teff = (teff - teff_min)/(teff_max - teff_min)
     157            0 :       norm_logg = (log_g - logg_min)/(logg_max - logg_min)
     158            0 :       norm_meta = (metallicity - meta_min)/(meta_max - meta_min)
     159              : 
     160              :       ! Detect dummy axes once (outside the loop)
     161            0 :       use_teff_dim = .not. (all(lu_teff == 0.0_dp) .or. all(lu_teff == 999.0_dp) .or. all(lu_teff == -999.0_dp))
     162            0 :       use_logg_dim = .not. (all(lu_logg == 0.0_dp) .or. all(lu_logg == 999.0_dp) .or. all(lu_logg == -999.0_dp))
     163            0 :       use_meta_dim = .not. (all(lu_meta == 0.0_dp) .or. all(lu_meta == 999.0_dp) .or. all(lu_meta == -999.0_dp))
     164              : 
     165              :       ! Find closest models
     166            0 :       do i = 1, n
     167            0 :          teff_dist = 0.0_dp
     168            0 :          logg_dist = 0.0_dp
     169            0 :          meta_dist = 0.0_dp
     170              : 
     171            0 :          if (teff_max - teff_min > 0.0_dp) then
     172            0 :             teff_dist = scaled_lu_teff(i) - norm_teff
     173              :          end if
     174              : 
     175            0 :          if (logg_max - logg_min > 0.0_dp) then
     176            0 :             logg_dist = scaled_lu_logg(i) - norm_logg
     177              :          end if
     178              : 
     179            0 :          if (meta_max - meta_min > 0.0_dp) then
     180            0 :             meta_dist = scaled_lu_meta(i) - norm_meta
     181              :          end if
     182              : 
     183              :          ! Compute distance using only valid dimensions
     184            0 :          distance = 0.0_dp
     185            0 :          if (use_teff_dim) distance = distance + teff_dist**2
     186            0 :          if (use_logg_dim) distance = distance + logg_dist**2
     187            0 :          if (use_meta_dim) distance = distance + meta_dist**2
     188              : 
     189            0 :          do j = 1, 4
     190            0 :             if (distance < min_distances(j)) then
     191              :                ! Shift larger distances down
     192            0 :                if (j < 4) then
     193            0 :                   min_distances(j + 1:4) = min_distances(j:3)
     194            0 :                   indices(j + 1:4) = indices(j:3)
     195              :                end if
     196            0 :                min_distances(j) = distance
     197            0 :                indices(j) = i
     198            0 :                exit
     199              :             end if
     200              :          end do
     201              :       end do
     202              : 
     203            0 :       closest_indices = indices
     204            0 :    end subroutine get_closest_stellar_models
     205              : 
     206              :    !---------------------------------------------------------------------------
     207              :    ! Linear interpolation (binary search version for efficiency)
     208              :    !---------------------------------------------------------------------------
     209        56658 :    subroutine linear_interpolate(x, y, x_val, y_val)
     210              :       real(dp), intent(in) :: x(:), y(:), x_val
     211              :       real(dp), intent(out) :: y_val
     212              :       integer :: low, high, mid
     213              : 
     214              :       ! Validate input sizes
     215        56658 :       if (size(x) < 2) then
     216            0 :          print *, "Error: x array has fewer than 2 points."
     217            0 :          y_val = 0.0_dp
     218              :          return
     219              :       end if
     220              : 
     221        56658 :       if (size(x) /= size(y)) then
     222            0 :          print *, "Error: x and y arrays have different sizes."
     223            0 :          y_val = 0.0_dp
     224              :          return
     225              :       end if
     226              : 
     227              :       ! Handle out-of-bounds cases
     228        56658 :       if (x_val <= x(1)) then
     229        14183 :          y_val = y(1)
     230        14183 :          return
     231        42475 :       else if (x_val >= x(size(x))) then
     232        39159 :          y_val = y(size(y))
     233              :          return
     234              :       end if
     235              : 
     236              :       ! Binary search to find the proper interval [x(low), x(low+1)]
     237         3316 :       low = 1
     238         3316 :       high = size(x)
     239        17535 :       do while (high - low > 1)
     240        14219 :          mid = (low + high)/2
     241        17535 :          if (x(mid) <= x_val) then
     242              :             low = mid
     243              :          else
     244        10118 :             high = mid
     245              :          end if
     246              :       end do
     247              : 
     248              :       ! Linear interpolation between x(low) and x(low+1)
     249         3316 :       y_val = y(low) + (y(low + 1) - y(low))/(x(low + 1) - x(low))*(x_val - x(low))
     250              :    end subroutine linear_interpolate
     251              : 
     252              :    !---------------------------------------------------------------------------
     253              :    ! Array interpolation for SED construction
     254              :    !---------------------------------------------------------------------------
     255            7 :    subroutine interpolate_array(x_in, y_in, x_out, y_out)
     256              :       real(dp), intent(in) :: x_in(:), y_in(:), x_out(:)
     257              :       real(dp), intent(out) :: y_out(:)
     258              :       integer :: i
     259              : 
     260              :       ! Validate input sizes
     261            7 :       if (size(x_in) < 2 .or. size(y_in) < 2) then
     262            0 :          print *, "Error: x_in or y_in arrays have fewer than 2 points."
     263            0 :          call mesa_error(__FILE__, __LINE__)
     264              :       end if
     265              : 
     266            7 :       if (size(x_in) /= size(y_in)) then
     267            0 :          print *, "Error: x_in and y_in arrays have different sizes."
     268            0 :          call mesa_error(__FILE__, __LINE__)
     269              :       end if
     270              : 
     271            7 :       if (size(x_out) <= 0) then
     272            0 :          print *, "Error: x_out array is empty."
     273            0 :          call mesa_error(__FILE__, __LINE__)
     274              :       end if
     275              : 
     276        56665 :       do i = 1, size(x_out)
     277        56665 :          call linear_interpolate(x_in, y_in, x_out(i), y_out(i))
     278              :       end do
     279            7 :    end subroutine interpolate_array
     280              : 
     281              : end module knn_interp
        

Generated by: LCOV version 2.0-1