LCOV - code coverage report
Current view: top level - colors/private - hermite_interp.f90 (source / functions) Coverage Total Hit
Test: coverage.info Lines: 0.0 % 173 0
Test Date: 2026-01-06 18:03:11 Functions: 0.0 % 11 0

            Line data    Source code
       1              : ! ***********************************************************************
       2              : !
       3              : !   Copyright (C) 2025  Niall Miller & The MESA Team
       4              : !
       5              : !   This program is free software: you can redistribute it and/or modify
       6              : !   it under the terms of the GNU Lesser General Public License
       7              : !   as published by the Free Software Foundation,
       8              : !   either version 3 of the License, or (at your option) any later version.
       9              : !
      10              : !   This program is distributed in the hope that it will be useful,
      11              : !   but WITHOUT ANY WARRANTY; without even the implied warranty of
      12              : !   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
      13              : !   See the GNU Lesser General Public License for more details.
      14              : !
      15              : !   You should have received a copy of the GNU Lesser General Public License
      16              : !   along with this program. If not, see <https://www.gnu.org/licenses/>.
      17              : !
      18              : ! ***********************************************************************
      19              : 
      20              : ! ***********************************************************************
      21              : ! Hermite interpolation module for spectral energy distributions (SEDs)
      22              : ! ***********************************************************************
      23              : 
      24              : module hermite_interp
      25              :    use const_def, only: dp
      26              :    use colors_utils, only: dilute_flux
      27              :    implicit none
      28              : 
      29              :    private
      30              :    public :: construct_sed_hermite, hermite_tensor_interp3d
      31              : 
      32              : contains
      33              : 
      34              :    !---------------------------------------------------------------------------
      35              :    ! Main entry point: Construct a SED using Hermite tensor interpolation
      36              :    !---------------------------------------------------------------------------
      37            0 :    subroutine construct_sed_hermite(teff, log_g, metallicity, R, d, file_names, &
      38              :                                     lu_teff, lu_logg, lu_meta, stellar_model_dir, &
      39              :                                     wavelengths, fluxes)
      40              :       real(dp), intent(in) :: teff, log_g, metallicity, R, d
      41              :       real(dp), intent(in) :: lu_teff(:), lu_logg(:), lu_meta(:)
      42              :       character(len=*), intent(in) :: stellar_model_dir
      43              :       character(len=100), intent(in) :: file_names(:)
      44              :       real(dp), dimension(:), allocatable, intent(out) :: wavelengths, fluxes
      45              : 
      46              :       integer :: i, n_lambda, status, n_teff, n_logg, n_meta
      47            0 :       real(dp), dimension(:), allocatable :: interp_flux, diluted_flux
      48            0 :       real(dp), dimension(:, :, :, :), allocatable :: precomputed_flux_cube
      49            0 :       real(dp), dimension(:, :, :), allocatable :: flux_cube_lambda
      50              : 
      51              :       ! Parameter grids
      52            0 :       real(dp), allocatable :: teff_grid(:), logg_grid(:), meta_grid(:)
      53              :       character(len=256) :: bin_filename
      54              : 
      55              :       ! Construct the binary filename
      56            0 :       bin_filename = trim(stellar_model_dir)//'/flux_cube.bin'
      57              : 
      58              :       ! Load the data from binary file
      59              :       call load_binary_data(bin_filename, teff_grid, logg_grid, meta_grid, &
      60            0 :                             wavelengths, precomputed_flux_cube, status)
      61              : 
      62            0 :       n_teff = size(teff_grid)
      63            0 :       n_logg = size(logg_grid)
      64            0 :       n_meta = size(meta_grid)
      65            0 :       n_lambda = size(wavelengths)
      66              : 
      67              :       ! Allocate space for interpolated flux
      68            0 :       allocate (interp_flux(n_lambda))
      69              : 
      70              :       ! Process each wavelength point
      71            0 :       do i = 1, n_lambda
      72            0 :          allocate (flux_cube_lambda(n_teff, n_logg, n_meta))
      73            0 :          flux_cube_lambda = precomputed_flux_cube(:, :, :, i)
      74              : 
      75            0 :          interp_flux(i) = hermite_tensor_interp3d(teff, log_g, metallicity, &
      76            0 :                                                   teff_grid, logg_grid, meta_grid, flux_cube_lambda)
      77              : 
      78            0 :          deallocate (flux_cube_lambda)
      79              :       end do
      80              : 
      81              :       ! Apply distance dilution to get observed flux
      82            0 :       allocate (diluted_flux(n_lambda))
      83            0 :       call dilute_flux(interp_flux, R, d, diluted_flux)
      84            0 :       fluxes = diluted_flux
      85              : 
      86            0 :    end subroutine construct_sed_hermite
      87              : 
      88              : !---------------------------------------------------------------------------
      89              : ! Load data from binary file
      90              : !---------------------------------------------------------------------------
      91            0 :    subroutine load_binary_data(filename, teff_grid, logg_grid, meta_grid, &
      92              :                                wavelengths, flux_cube, status)
      93              :       character(len=*), intent(in) :: filename
      94              :       real(dp), allocatable, intent(out) :: teff_grid(:), logg_grid(:), meta_grid(:)
      95              :       real(dp), allocatable, intent(out) :: wavelengths(:)
      96              :       real(dp), allocatable, intent(out) :: flux_cube(:, :, :, :)
      97              :       integer, intent(out) :: status
      98              : 
      99              :       integer :: unit, n_teff, n_logg, n_meta, n_lambda
     100              : 
     101            0 :       unit = 99
     102            0 :       status = 0
     103              : 
     104              :       ! Open the binary file
     105            0 :       open (unit=unit, file=filename, status='OLD', ACCESS='STREAM', FORM='UNFORMATTED', iostat=status)
     106            0 :       if (status /= 0) then
     107            0 :          print *, 'Error opening binary file:', trim(filename)
     108            0 :          return
     109              :       end if
     110              : 
     111              :       ! Read dimensions
     112            0 :       read (unit, iostat=status) n_teff, n_logg, n_meta, n_lambda
     113            0 :       if (status /= 0) then
     114            0 :          print *, 'Error reading dimensions from binary file'
     115            0 :          close (unit)
     116            0 :          return
     117              :       end if
     118              : 
     119              :       ! Allocate arrays based on dimensions
     120            0 :       allocate (teff_grid(n_teff), STAT=status)
     121            0 :       if (status /= 0) then
     122            0 :          print *, 'Error allocating teff_grid array'
     123            0 :          close (unit)
     124            0 :          return
     125              :       end if
     126              : 
     127            0 :       allocate (logg_grid(n_logg), STAT=status)
     128            0 :       if (status /= 0) then
     129            0 :          print *, 'Error allocating logg_grid array'
     130            0 :          close (unit)
     131            0 :          return
     132              :       end if
     133              : 
     134            0 :       allocate (meta_grid(n_meta), STAT=status)
     135            0 :       if (status /= 0) then
     136            0 :          print *, 'Error allocating meta_grid array'
     137            0 :          close (unit)
     138            0 :          return
     139              :       end if
     140              : 
     141            0 :       allocate (wavelengths(n_lambda), STAT=status)
     142            0 :       if (status /= 0) then
     143            0 :          print *, 'Error allocating wavelengths array'
     144            0 :          close (unit)
     145            0 :          return
     146              :       end if
     147              : 
     148            0 :       allocate (flux_cube(n_teff, n_logg, n_meta, n_lambda), STAT=status)
     149            0 :       if (status /= 0) then
     150            0 :          print *, 'Error allocating flux_cube array'
     151            0 :          close (unit)
     152            0 :          return
     153              :       end if
     154              : 
     155              :       ! Read grid arrays
     156            0 :       read (unit, iostat=status) teff_grid
     157            0 :       if (status /= 0) then
     158            0 :          print *, 'Error reading teff_grid'
     159            0 :          close (unit)
     160            0 :          return
     161              :       end if
     162              : 
     163            0 :       read (unit, iostat=status) logg_grid
     164            0 :       if (status /= 0) then
     165            0 :          print *, 'Error reading logg_grid'
     166            0 :          close (unit)
     167            0 :          return
     168              :       end if
     169              : 
     170            0 :       read (unit, iostat=status) meta_grid
     171            0 :       if (status /= 0) then
     172            0 :          print *, 'Error reading meta_grid'
     173            0 :          close (unit)
     174            0 :          return
     175              :       end if
     176              : 
     177            0 :       read (unit, iostat=status) wavelengths
     178            0 :       if (status /= 0) then
     179            0 :          print *, 'Error reading wavelengths'
     180            0 :          close (unit)
     181            0 :          return
     182              :       end if
     183              : 
     184              :       ! Read flux cube
     185            0 :       read (unit, iostat=status) flux_cube
     186            0 :       if (status /= 0) then
     187            0 :          print *, 'Error reading flux_cube'
     188            0 :          close (unit)
     189            0 :          return
     190              :       end if
     191              : 
     192              :       ! Close file and return success
     193            0 :       close (unit)
     194              :    end subroutine load_binary_data
     195              : 
     196              : 
     197              : 
     198            0 :    function hermite_tensor_interp3d(x_val, y_val, z_val, x_grid, y_grid, &
     199            0 :                                     z_grid, f_values) result(f_interp)
     200              :       real(dp), intent(in) :: x_val, y_val, z_val
     201              :       real(dp), intent(in) :: x_grid(:), y_grid(:), z_grid(:)
     202              :       real(dp), intent(in) :: f_values(:, :, :)
     203              :       real(dp) :: f_interp
     204              : 
     205              :       integer :: i_x, i_y, i_z
     206              :       real(dp) :: t_x, t_y, t_z
     207              :       real(dp) :: dx, dy, dz
     208              :       real(dp) :: dx_values(2, 2, 2), dy_values(2, 2, 2), dz_values(2, 2, 2)
     209              :       real(dp) :: values(2, 2, 2)
     210              :       real(dp) :: sum
     211              :       integer :: ix, iy, iz
     212              :       real(dp) :: h_x(2), h_y(2), h_z(2)
     213              :       real(dp) :: hx_d(2), hy_d(2), hz_d(2)
     214              : 
     215              :       ! Find containing cell and parameter values
     216              :       call find_containing_cell(x_val, y_val, z_val, x_grid, y_grid, z_grid, &
     217            0 :                                 i_x, i_y, i_z, t_x, t_y, t_z)
     218              : 
     219              :       ! If outside grid, use nearest point
     220              :       if (i_x < 1 .or. i_x >= size(x_grid) .or. &
     221              :           i_y < 1 .or. i_y >= size(y_grid) .or. &
     222            0 :           i_z < 1 .or. i_z >= size(z_grid)) then
     223              : 
     224              :          call find_nearest_point(x_val, y_val, z_val, x_grid, y_grid, z_grid, &
     225            0 :                                  i_x, i_y, i_z)
     226            0 :          f_interp = f_values(i_x, i_y, i_z)
     227              :          return
     228              :       end if
     229              : 
     230              :       ! Grid cell spacing
     231            0 :       dx = x_grid(i_x + 1) - x_grid(i_x)
     232            0 :       dy = y_grid(i_y + 1) - y_grid(i_y)
     233            0 :       dz = z_grid(i_z + 1) - z_grid(i_z)
     234              : 
     235              :       ! Extract the local 2x2x2 grid cell and compute derivatives
     236            0 :       do iz = 0, 1
     237            0 :          do iy = 0, 1
     238            0 :             do ix = 0, 1
     239            0 :                values(ix + 1, iy + 1, iz + 1) = f_values(i_x + ix, i_y + iy, i_z + iz)
     240              :                call compute_derivatives_at_point(f_values, i_x + ix, i_y + iy, i_z + iz, &
     241              :                                                  size(x_grid), size(y_grid), size(z_grid), &
     242              :                                                  dx, dy, dz, &
     243              :                                                  dx_values(ix + 1, iy + 1, iz + 1), &
     244              :                                                  dy_values(ix + 1, iy + 1, iz + 1), &
     245            0 :                                                  dz_values(ix + 1, iy + 1, iz + 1))
     246              :             end do
     247              :          end do
     248              :       end do
     249              : 
     250              :       ! Precompute Hermite basis functions and derivatives
     251            0 :       h_x  = [h00(t_x), h01(t_x)]
     252            0 :       hx_d = [h10(t_x), h11(t_x)]
     253            0 :       h_y  = [h00(t_y), h01(t_y)]
     254            0 :       hy_d = [h10(t_y), h11(t_y)]
     255            0 :       h_z  = [h00(t_z), h01(t_z)]
     256            0 :       hz_d = [h10(t_z), h11(t_z)]
     257              : 
     258              :       ! Final interpolation sum
     259            0 :       sum = 0.0_dp
     260            0 :       do iz = 1, 2
     261            0 :          do iy = 1, 2
     262            0 :             do ix = 1, 2
     263            0 :                sum = sum + h_x(ix)*h_y(iy)*h_z(iz)     * values(ix, iy, iz)
     264            0 :                sum = sum + hx_d(ix)*h_y(iy)*h_z(iz)    * dx * dx_values(ix, iy, iz)
     265            0 :                sum = sum + h_x(ix)*hy_d(iy)*h_z(iz)    * dy * dy_values(ix, iy, iz)
     266            0 :                sum = sum + h_x(ix)*h_y(iy)*hz_d(iz)    * dz * dz_values(ix, iy, iz)
     267              :             end do
     268              :          end do
     269              :       end do
     270              : 
     271            0 :       f_interp = sum
     272              :    end function hermite_tensor_interp3d
     273              : 
     274              : 
     275              : 
     276              :    !---------------------------------------------------------------------------
     277              :    ! Find the cell containing the interpolation point
     278              :    !---------------------------------------------------------------------------
     279            0 :    subroutine find_containing_cell(x_val, y_val, z_val, x_grid, y_grid, z_grid, &
     280              :                                    i_x, i_y, i_z, t_x, t_y, t_z)
     281              :       real(dp), intent(in) :: x_val, y_val, z_val
     282              :       real(dp), intent(in) :: x_grid(:), y_grid(:), z_grid(:)
     283              :       integer, intent(out) :: i_x, i_y, i_z
     284              :       real(dp), intent(out) :: t_x, t_y, t_z
     285              : 
     286              :       ! Find x interval
     287            0 :       call find_interval(x_grid, x_val, i_x, t_x)
     288              : 
     289              :       ! Find y interval
     290            0 :       call find_interval(y_grid, y_val, i_y, t_y)
     291              : 
     292              :       ! Find z interval
     293            0 :       call find_interval(z_grid, z_val, i_z, t_z)
     294            0 :    end subroutine find_containing_cell
     295              : 
     296              :    !---------------------------------------------------------------------------
     297              :    ! Find the interval in a sorted array containing a value
     298              :    !---------------------------------------------------------------------------
     299              : 
     300            0 :    subroutine find_interval(x, val, i, t)
     301              :       real(dp), intent(in) :: x(:), val
     302              :       integer, intent(out) :: i
     303              :       real(dp), intent(out) :: t
     304              : 
     305              :       integer :: n, lo, hi, mid
     306              :       logical :: dummy_axis
     307              : 
     308            0 :       n = size(x)
     309              : 
     310              :       ! Detect dummy axis: all values == 0, 999, or -999
     311            0 :       dummy_axis = all(x == 0.0_dp) .or. all(x == 999.0_dp) .or. all(x == -999.0_dp)
     312              : 
     313              :       if (dummy_axis) then
     314              :          ! Collapse axis: always use first point, no interpolation
     315            0 :          i = 1
     316            0 :          t = 0.0_dp
     317            0 :          return
     318              :       end if
     319              : 
     320              :       ! ---------- ORIGINAL CODE BELOW ----------------
     321              : 
     322            0 :       if (val <= x(1)) then
     323            0 :          i = 1
     324            0 :          t = 0.0_dp
     325            0 :          return
     326            0 :       else if (val >= x(n)) then
     327            0 :          i = n - 1
     328            0 :          t = 1.0_dp
     329            0 :          return
     330              :       end if
     331              : 
     332              :       lo = 1
     333              :       hi = n
     334            0 :       do while (hi - lo > 1)
     335            0 :          mid = (lo + hi)/2
     336            0 :          if (val >= x(mid)) then
     337              :             lo = mid
     338              :          else
     339            0 :             hi = mid
     340              :          end if
     341              :       end do
     342              : 
     343            0 :       i = lo
     344            0 :       t = (val - x(i))/(x(i + 1) - x(i))
     345              :    end subroutine find_interval
     346              : 
     347              : 
     348              :    !---------------------------------------------------------------------------
     349              :    ! Find the nearest grid point
     350              :    !---------------------------------------------------------------------------
     351            0 :    subroutine find_nearest_point(x_val, y_val, z_val, x_grid, y_grid, z_grid, &
     352              :                                  i_x, i_y, i_z)
     353              :       real(dp), intent(in) :: x_val, y_val, z_val
     354              :       real(dp), intent(in) :: x_grid(:), y_grid(:), z_grid(:)
     355              :       integer, intent(out) :: i_x, i_y, i_z
     356              : 
     357              :       ! Find nearest grid points using intrinsic minloc
     358            0 :       i_x = minloc(abs(x_val - x_grid), 1)
     359            0 :       i_y = minloc(abs(y_val - y_grid), 1)
     360            0 :       i_z = minloc(abs(z_val - z_grid), 1)
     361            0 :    end subroutine find_nearest_point
     362              : 
     363              :    !---------------------------------------------------------------------------
     364              :    ! Compute derivatives at a grid point
     365              :    !---------------------------------------------------------------------------
     366            0 :    subroutine compute_derivatives_at_point(f, i, j, k, nx, ny, nz, dx, dy, dz, &
     367              :                                            df_dx, df_dy, df_dz)
     368              :       real(dp), intent(in) :: f(:, :, :)
     369              :       integer, intent(in) :: i, j, k, nx, ny, nz
     370              :       real(dp), intent(in) :: dx, dy, dz
     371              :       real(dp), intent(out) :: df_dx, df_dy, df_dz
     372              : 
     373              :       ! Compute x derivative using centered differences where possible
     374            0 :       if (i > 1 .and. i < nx) then
     375            0 :          df_dx = (f(i + 1, j, k) - f(i - 1, j, k))/(2.0_dp*dx)
     376            0 :       else if (i == 1) then
     377            0 :          df_dx = (f(i + 1, j, k) - f(i, j, k))/dx
     378              :       else ! i == nx
     379            0 :          df_dx = (f(i, j, k) - f(i - 1, j, k))/dx
     380              :       end if
     381              : 
     382              :       ! Compute y derivative using centered differences where possible
     383            0 :       if (j > 1 .and. j < ny) then
     384            0 :          df_dy = (f(i, j + 1, k) - f(i, j - 1, k))/(2.0_dp*dy)
     385            0 :       else if (j == 1) then
     386            0 :          df_dy = (f(i, j + 1, k) - f(i, j, k))/dy
     387              :       else ! j == ny
     388            0 :          df_dy = (f(i, j, k) - f(i, j - 1, k))/dy
     389              :       end if
     390              : 
     391              :       ! Compute z derivative using centered differences where possible
     392            0 :       if (k > 1 .and. k < nz) then
     393            0 :          df_dz = (f(i, j, k + 1) - f(i, j, k - 1))/(2.0_dp*dz)
     394            0 :       else if (k == 1) then
     395            0 :          df_dz = (f(i, j, k + 1) - f(i, j, k))/dz
     396              :       else ! k == nz
     397            0 :          df_dz = (f(i, j, k) - f(i, j, k - 1))/dz
     398              :       end if
     399            0 :    end subroutine compute_derivatives_at_point
     400              : 
     401              :    !---------------------------------------------------------------------------
     402              :    ! Hermite basis functions
     403              :    !---------------------------------------------------------------------------
     404            0 :    function h00(t) result(h)
     405              :       real(dp), intent(in) :: t
     406              :       real(dp) :: h
     407            0 :       h = 2.0_dp*t**3 - 3.0_dp*t**2 + 1.0_dp
     408            0 :    end function h00
     409              : 
     410            0 :    function h10(t) result(h)
     411              :       real(dp), intent(in) :: t
     412              :       real(dp) :: h
     413            0 :       h = t**3 - 2.0_dp*t**2 + t
     414            0 :    end function h10
     415              : 
     416            0 :    function h01(t) result(h)
     417              :       real(dp), intent(in) :: t
     418              :       real(dp) :: h
     419            0 :       h = -2.0_dp*t**3 + 3.0_dp*t**2
     420            0 :    end function h01
     421              : 
     422            0 :    function h11(t) result(h)
     423              :       real(dp), intent(in) :: t
     424              :       real(dp) :: h
     425            0 :       h = t**3 - t**2
     426            0 :    end function h11
     427              : 
     428              : end module hermite_interp
        

Generated by: LCOV version 2.0-1