LCOV - code coverage report
Current view: top level - num/private - mod_simplex.f90 (source / functions) Coverage Total Hit
Test: coverage.info Lines: 88.1 % 159 140
Test Date: 2025-05-08 18:23:42 Functions: 88.9 % 9 8

            Line data    Source code
       1              : ! ***********************************************************************
       2              : !
       3              : !   Copyright (C) 2012-2019  Bill Paxton & The MESA Team
       4              : !
       5              : !   This program is free software: you can redistribute it and/or modify
       6              : !   it under the terms of the GNU Lesser General Public License
       7              : !   as published by the Free Software Foundation,
       8              : !   either version 3 of the License, or (at your option) any later version.
       9              : !
      10              : !   This program is distributed in the hope that it will be useful,
      11              : !   but WITHOUT ANY WARRANTY; without even the implied warranty of
      12              : !   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
      13              : !   See the GNU Lesser General Public License for more details.
      14              : !
      15              : !   You should have received a copy of the GNU Lesser General Public License
      16              : !   along with this program. If not, see <https://www.gnu.org/licenses/>.
      17              : !
      18              : ! ***********************************************************************
      19              : 
      20              :       module mod_simplex
      21              : 
      22              :       use const_def, only: dp
      23              :       use math_lib
      24              :       use num_def
      25              : 
      26              :       implicit none
      27              : 
      28              :       contains
      29              : 
      30            6 :       subroutine do_simplex( &
      31            6 :             n, x_lower, x_upper, x_first, x_final, f_final, &
      32            6 :             simplex, f, start_from_given_simplex_and_f, &
      33              :             fcn, x_atol, x_rtol, iter_max, fcn_calls_max, &
      34              :             centroid_weight_power, enforce_bounds, &
      35              :             adaptive_random_search, seed, &
      36              :             alpha, beta, gamma, delta, &
      37              :             lrpar, rpar, lipar, ipar, &
      38              :             num_iters, num_fcn_calls, &
      39              :             num_fcn_calls_for_ars, num_accepted_for_ars, ierr)
      40              :          integer, intent(in) :: n  ! number of dimensions
      41              :          real(dp), intent(in) :: x_lower(:), x_upper(:), x_first(:)  ! (n)
      42              :          real(dp), intent(inout) :: x_final(:)  ! (n)
      43              :          real(dp), intent(inout) :: simplex(:,:)  ! (n,n+1)
      44              :          real(dp), intent(inout) :: f(:)  ! (n+1)
      45              :          logical, intent(in) :: start_from_given_simplex_and_f
      46              :          interface
      47              : #include "num_simplex_fcn.dek"
      48              :          end interface
      49              :          real(dp), intent(in) :: x_atol, x_rtol, centroid_weight_power
      50              :          integer, intent(inout) :: seed
      51              :          real(dp), intent(in) :: alpha, beta, gamma, delta
      52              :          integer, intent(in) :: iter_max, fcn_calls_max
      53              :          logical, intent(in) :: enforce_bounds, adaptive_random_search
      54              :          integer, intent(in) :: lrpar, lipar
      55              :          integer, intent(inout), pointer :: ipar(:)  ! (lipar)
      56              :          real(dp), intent(inout), pointer :: rpar(:)  ! (lrpar)
      57              :          real(dp), intent(out) :: f_final
      58              :          integer, intent(out) :: num_iters, num_fcn_calls, &
      59              :             num_fcn_calls_for_ars, num_accepted_for_ars, ierr
      60              : 
      61          138 :          real(dp), dimension(n) :: c, x_reflect, x_expand, x_contract, x_ars
      62            6 :          real(dp) :: f_reflect, f_expand, f_contract, f_ars, &
      63            6 :             term1, weight, sum_weight, term_val_x
      64              :          integer :: h, s, l, i, j
      65              : 
      66              :          logical, parameter :: dbg = .false.
      67              : 
      68              :          include 'formats'
      69              : 
      70            6 :          ierr = 0
      71            6 :          num_fcn_calls = 0
      72            6 :          num_fcn_calls_for_ars = 0
      73            6 :          num_accepted_for_ars = 0
      74            6 :          num_iters = 0
      75              : 
      76            6 :          if (.not. start_from_given_simplex_and_f) then
      77            6 :             call set_initial_simplex(ierr)
      78            6 :             if (ierr /= 0) then
      79            0 :                write(*,*) 'ierr while evaluating initial simplex'
      80            0 :                return
      81              :             end if
      82              :          end if
      83              : 
      84         2634 :          do
      85              : 
      86         2640 :             num_iters = num_iters + 1
      87              : 
      88              :             if (dbg) write(*,*)
      89              :             if (dbg) write(*,2) 'iter', num_iters
      90              : 
      91              :             ! h = index of max f
      92              :             ! s = index of 2nd max f
      93              :             ! l = index of min f
      94         2640 :             h = 1; l = 1
      95        13200 :             do j=2,n+1
      96        10560 :                if (f(j) > f(h)) h = j
      97        13200 :                if (f(j) < f(l)) l = j
      98              :             end do
      99              :             s = 0
     100        15840 :             do j=1,n+1
     101        13200 :                if (j == h) cycle
     102        13200 :                if (s <= 0) then
     103              :                   s = j
     104         7920 :                else if (f(j) > f(s)) then
     105         5456 :                   s = j
     106              :                end if
     107              :             end do
     108              : 
     109              :             if (dbg) write(*,2) 'worst', h, f(h), simplex(1:n,h)
     110              :             if (dbg) write(*,2) '2nd worst', s, f(s), simplex(1:n,s)
     111              :             if (dbg) write(*,2) 'best', l, f(l), simplex(1:n,l)
     112              : 
     113              :             ! check for domain convergence
     114              :             term_val_x = 0
     115        15840 :             do j=1,n+1
     116        13200 :                if (j == l) cycle
     117        55440 :                do i=1,n
     118              :                   term1 = abs(simplex(i,j)-simplex(i,l)) / &
     119        42240 :                      (x_atol + x_rtol*max(abs(simplex(i,j)), abs(simplex(i,l))))
     120        55440 :                   if (term1 > term_val_x) term_val_x = term1
     121              :                end do
     122              :             end do
     123              :             if (dbg) write(*,1) 'term_val_x', term_val_x
     124         2640 :             if (term_val_x <= 1d0) exit
     125              : 
     126              :             ! check for failure to converge in allowed iterations or function calls
     127         2634 :             if (num_iters > iter_max .or. num_fcn_calls > fcn_calls_max) exit
     128              : 
     129              :             ! c = centroid excluding worst point
     130        13170 :             c(1:n) = 0
     131              :             sum_weight = 0d0
     132        15804 :             do j=1,n+1
     133        13170 :                if (j == h) cycle
     134        10536 :                if (centroid_weight_power == 0d0) then
     135            0 :                   weight = 1d0
     136              :                else
     137        10536 :                   weight = 1/f(j)
     138        10536 :                   if (centroid_weight_power /= 1d0) &
     139            0 :                      weight = pow(weight,centroid_weight_power)
     140              :                end if
     141        52680 :                do i=1,n
     142        52680 :                   c(i) = c(i) + simplex(i,j)*weight
     143              :                end do
     144        15804 :                sum_weight = sum_weight + weight
     145              :             end do
     146        13170 :             do i=1,n
     147        13170 :                c(i) = c(i)/sum_weight
     148              :             end do
     149              :             if (dbg) write(*,1) 'c', c(1:n)
     150              : 
     151              :             ! transform the simplex
     152              : 
     153         2634 :             call reflect(ierr)
     154         2634 :             if (ierr /= 0) return
     155              :             if (dbg) write(*,1) 'reflect', f_reflect, x_reflect(1:n)
     156              : 
     157         2640 :             if (f_reflect < f(s)) then  ! accept reflect
     158              :                if (dbg) write(*,2) 'accept reflect', num_iters
     159         6075 :                do i=1,n
     160         6075 :                   simplex(i,h) = x_reflect(i)
     161              :                end do
     162         1215 :                f(h) = f_reflect
     163         1215 :                if (f_reflect < f(l)) then  ! try to expand
     164          608 :                   call expand(ierr)
     165          608 :                   if (ierr /= 0) return
     166              :                   if (dbg) write(*,1) 'expand', f_expand, x_expand(1:n)
     167          608 :                   if (f_expand < f(l)) then  ! accept expand
     168              :                      ! note: to keep the simplex large in a good direction,
     169              :                      ! we take x_expand even if f_expand > f_reflect
     170         1890 :                      do i=1,n
     171         1890 :                         simplex(i,h) = x_expand(i)
     172              :                      end do
     173          378 :                      f(h) = f_expand
     174              :                      if (dbg) write(*,2) 'accept expand', num_iters
     175              :                   end if
     176              :                end if
     177              :             else  ! try to contract
     178         1419 :                call contract(ierr)
     179         1419 :                if (ierr /= 0) return
     180              :                if (dbg) write(*,1) 'contract', f_contract, x_contract(1:n)
     181         1419 :                if (f_contract < min(f(h),f_reflect)) then  ! accept contraction
     182              :                   if (dbg) write(*,2) 'accept contraction', num_iters
     183         6750 :                   do i=1,n
     184         6750 :                      simplex(i,h) = x_contract(i)
     185              :                   end do
     186         1350 :                   f(h) = f_contract
     187           69 :                else if (adaptive_random_search) then
     188           69 :                   call ARS(ierr)
     189           69 :                   if (ierr /= 0) return
     190              :                else
     191            0 :                   call shrink
     192              :                end if
     193              :             end if
     194              : 
     195              :          end do
     196              : 
     197            6 :          f_final = f(l)
     198           30 :          do i=1,n
     199           30 :             x_final(i) = simplex(i,l)
     200              :          end do
     201              : 
     202              :          if (dbg) write(*,*)
     203              :          if (dbg) write(*,1) 'final', f_final, x_final(1:n)
     204              :          if (dbg) write(*,*)
     205              : 
     206              : 
     207              :          contains
     208              : 
     209              : 
     210            6 :          subroutine set_initial_simplex(ierr)
     211              :             integer, intent(out) :: ierr
     212              :             integer :: j, i, k
     213              :             logical :: okay
     214              :             include 'formats'
     215              : 
     216            6 :             ierr = 0
     217              : 
     218           30 :             do i=1,n
     219           30 :                simplex(i,n+1) = x_first(i)
     220              :             end do
     221            6 :             f(n+1) = get_val(simplex(:,n+1), simplex_initial, ierr)
     222            6 :             if (ierr /= 0) then
     223              :                if (dbg) write(*,2) 'failed to get value for first simplex point'
     224              :                return
     225              :             end if
     226              : 
     227           30 :             do j=1,n
     228          120 :                do i=1,n
     229          120 :                   simplex(i,j) = x_first(i)
     230              :                end do
     231           24 :                if (x_upper(j) - x_first(j) < x_first(j) - x_lower(j)) then
     232              :                   ! closer to upper, so displace toward lower
     233           13 :                   simplex(j,j) = x_first(j) - 0.25d0*(x_upper(j) - x_lower(j))
     234              :                else
     235           11 :                   simplex(j,j) = x_first(j) + 0.25d0*(x_upper(j) - x_lower(j))
     236              :                end if
     237           24 :                okay = .false.
     238           24 :                do k=1,20
     239           24 :                   f(j) = get_val(simplex(:,j), simplex_initial, ierr)
     240           24 :                   if (ierr == 0) then
     241              :                      okay = .true.
     242              :                      exit
     243              :                   end if
     244              :                   ! move closer to x_first and retry
     245            0 :                   ierr = 0
     246            0 :                   simplex(j,j) = 0.5d0*(simplex(j,j) + x_first(j))
     247            0 :                   if (abs(simplex(j,j) - x_first(j)) < &
     248            0 :                         1d-12*(1d0 + abs(x_first(j)))) exit
     249              :                end do
     250           30 :                if (.not. okay) then
     251            0 :                   ierr = -1
     252              :                   if (dbg) write(*,2) 'failed to get value for initial simplex point'
     253            0 :                   return
     254              :                end if
     255              :             end do
     256              : 
     257              :          end subroutine set_initial_simplex
     258              : 
     259              : 
     260        11175 :          real(dp) function get_val(x, op_code, ierr) result(f)
     261              :             real(dp), intent(in) :: x(:)
     262              :             integer, intent(in) :: op_code  ! what nelder-mead is doing for this call
     263              :             integer, intent(out) :: ierr
     264              :             integer :: i
     265              :             include 'formats'
     266        11175 :             ierr = 0
     267        11175 :             if (enforce_bounds) then
     268        18661 :                do i=1,n
     269        18661 :                   if (x(i) > x_upper(i) .or. x(i) < x_lower(i)) then
     270              :                      if (dbg) write(*,2) 'out of bounds', &
     271              :                         num_iters, x(i), x_lower(i), x_upper(i)
     272        11175 :                      f = 1d99
     273              :                      return
     274              :                   end if
     275              :                end do
     276              :             end if
     277        11102 :             num_fcn_calls = num_fcn_calls + 1
     278        11102 :             f = fcn(n, x, lrpar, rpar, lipar, ipar, op_code, ierr)
     279        11102 :          end function get_val
     280              : 
     281              : 
     282         2634 :          subroutine reflect(ierr)
     283              :             integer, intent(out) :: ierr
     284              :             integer :: i
     285              :             include 'formats'
     286         2634 :             ierr = 0
     287        13170 :             do i=1,n
     288        13170 :                x_reflect(i) = c(i) + alpha*(c(i) - simplex(i,h))
     289              :             end do
     290         2634 :             f_reflect = get_val(x_reflect, simplex_reflect, ierr)
     291         2634 :          end subroutine reflect
     292              : 
     293              : 
     294          608 :          subroutine expand(ierr)
     295              :             integer, intent(out) :: ierr
     296              :             integer :: i
     297              :             include 'formats'
     298          608 :             ierr = 0
     299         3040 :             do i=1,n
     300         3040 :                x_expand(i) = c(i) + beta*(x_reflect(i) - c(i))
     301              :             end do
     302          608 :             f_expand = get_val(x_expand, simplex_expand, ierr)
     303          608 :          end subroutine expand
     304              : 
     305              : 
     306         1419 :          subroutine contract(ierr)
     307              :             integer, intent(out) :: ierr
     308              :             integer :: i, op_code
     309              :             include 'formats'
     310         1419 :             ierr = 0
     311         1419 :             if (f_reflect < f(h)) then  ! outside contraction
     312              :                if (dbg) write(*,1) 'outside contraction'
     313          164 :                op_code = simplex_outside
     314          820 :                do i=1,n
     315          820 :                   x_contract(i) = c(i) + gamma*(x_reflect(i) - c(i))
     316              :                end do
     317              :             else  ! inside contraction
     318              :                if (dbg) write(*,1) 'inside contraction'
     319         1255 :                op_code = simplex_inside
     320         6275 :                do i=1,n
     321         6275 :                   x_contract(i) = c(i) + gamma*(simplex(i,h) - c(i))
     322              :                end do
     323              :             end if
     324         1419 :             f_contract = get_val(x_contract, op_code, ierr)
     325         1419 :          end subroutine contract
     326              : 
     327              : 
     328           69 :          subroutine ARS(ierr)
     329              :             integer, intent(out) :: ierr
     330              :             integer :: i, k, k_max
     331              :             include 'formats'
     332           69 :             ierr = 0
     333           69 :             k_max = 100
     334         6544 :             do k=1,k_max  ! keep trying until find a better random point
     335         6484 :                if (num_fcn_calls > fcn_calls_max) exit
     336         6484 :                call get_point_for_ars(ierr)
     337         6484 :                if (ierr /= 0) return
     338              :                if (dbg) write(*,2) 'adaptive_random_search', num_iters, f_ars, x_ars(1:n)
     339         6484 :                if (f_ars <= f(h)) then  ! accept adaptive random search
     340              :                   if (dbg) write(*,2) 'accept adaptive random search', num_iters, f_ars, x_ars(1:n)
     341           45 :                   do i=1,n
     342           45 :                      simplex(i,h) = x_ars(i)
     343              :                   end do
     344            9 :                   f(h) = f_ars
     345            9 :                   num_accepted_for_ars = num_accepted_for_ars + 1
     346            9 :                   return
     347              :                end if
     348         6544 :                if (dbg) write(*,2) 'reject adaptive random search', num_iters, f_ars, x_ars(1:n)
     349              :             end do
     350              :          end subroutine ARS
     351              : 
     352              : 
     353         6484 :          subroutine get_point_for_ars(ierr)
     354              :             use mod_random, only: r8_uniform_01
     355              :             integer, intent(out) :: ierr
     356              :             integer :: i
     357         6484 :             real(dp) :: rand01
     358              :             include 'formats'
     359         6484 :             ierr = 0
     360        32420 :             do i=1,n
     361        25936 :                rand01 = r8_uniform_01(seed)
     362        32420 :                x_ars(i) = x_lower(i) + rand01*(x_upper(i) - x_lower(i))
     363              :             end do
     364              :             if (dbg) write(*,1) 'adaptive random search', x_ars(1:n)
     365         6484 :             num_fcn_calls_for_ars = num_fcn_calls_for_ars + 1
     366         6484 :             f_ars = get_val(x_ars, simplex_random, ierr)
     367         6484 :          end subroutine get_point_for_ars
     368              : 
     369              : 
     370            0 :          subroutine shrink  ! shrink the simplex towards the best point
     371              :             integer :: j, i
     372              :             include 'formats'
     373            0 :             do j=1,n+1
     374            0 :                if (j == l) cycle
     375            0 :                do i=1,n
     376            0 :                   simplex(i,j) = simplex(i,l) + delta*(simplex(i,j) - simplex(i,l))
     377              :                end do
     378            0 :                f(j) = get_val(simplex(:,j), simplex_shrink, ierr)
     379            0 :                if (ierr /= 0) return
     380            0 :                if (dbg) write(*,2) 'shrink', j, f(j), simplex(1:n,j)
     381              :             end do
     382              :          end subroutine shrink
     383              : 
     384              :       end subroutine do_simplex
     385              : 
     386              :       end module mod_simplex
        

Generated by: LCOV version 2.0-1