Line data Source code
1 : ! ***********************************************************************
2 : !
3 : ! Copyright (C) 2012-2019 Bill Paxton & The MESA Team
4 : !
5 : ! This program is free software: you can redistribute it and/or modify
6 : ! it under the terms of the GNU Lesser General Public License
7 : ! as published by the Free Software Foundation,
8 : ! either version 3 of the License, or (at your option) any later version.
9 : !
10 : ! This program is distributed in the hope that it will be useful,
11 : ! but WITHOUT ANY WARRANTY; without even the implied warranty of
12 : ! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
13 : ! See the GNU Lesser General Public License for more details.
14 : !
15 : ! You should have received a copy of the GNU Lesser General Public License
16 : ! along with this program. If not, see <https://www.gnu.org/licenses/>.
17 : !
18 : ! ***********************************************************************
19 :
module mod_simplex

   use const_def, only: dp
   use math_lib
   use num_def

   implicit none

contains

   ! Minimize fcn over n variables with the Nelder-Mead downhill simplex method.
   !
   ! Each iteration ranks the n+1 vertices (the columns of simplex). It stops if
   ! the simplex has collapsed onto its best vertex or a work limit is exceeded.
   ! Otherwise it replaces the worst vertex by a reflected, expanded or contracted
   ! point. When the contraction is rejected, the simplex is shrunk towards its
   ! best vertex. If adaptive_random_search is set, uniformly random trial points
   ! in [x_lower, x_upper] are tried instead of shrinking.
   !
   ! Arguments:
   !   n                       number of dimensions
   !   x_lower, x_upper        (n) bounds: they size the initial simplex, bound the
   !                           random search and, with enforce_bounds, reject points
   !   x_first                 (n) starting point; vertex n+1 of the initial simplex
   !   x_final, f_final        best vertex and its value (set on normal exit only)
   !   simplex, f              (n,n+1) vertices and their (n+1) values; read on entry
   !                           only if start_from_given_simplex_and_f, and always
   !                           updated in place
   !   fcn                     objective (interface in num_simplex_fcn.dek); each
   !                           call receives an op_code saying what the search is doing
   !   x_atol, x_rtol          converged when, in every coordinate, every vertex lies
   !                           within x_atol + x_rtol*max(|x|,|x_best|) of the best one
   !   iter_max, fcn_calls_max stop (with ierr = 0) once either count is exceeded
   !   centroid_weight_power   0: plain centroid of the non-worst vertices;
   !                           otherwise vertices are weighted by (1/f)**power.
   !                           The plain centroid is used whenever some f <= 0,
   !                           where inverse-f weights are undefined.
   !   enforce_bounds          out-of-bounds points get f = 1d99 without calling fcn
   !   adaptive_random_search  on a rejected contraction, try random points instead
   !                           of shrinking
   !   seed                    random number seed, updated by the random search
   !   alpha, beta, gamma, delta   reflection, expansion, contraction, shrink factors
   !   lrpar, rpar, lipar, ipar    passed through unchanged to fcn
   !   num_iters               iterations performed
   !   num_fcn_calls           actual calls of fcn (bound rejections not counted)
   !   num_fcn_calls_for_ars   random-search trial points evaluated
   !   num_accepted_for_ars    random-search points accepted into the simplex
   !   ierr                    0 on success; nonzero if fcn reported an error or the
   !                           initial simplex could not be evaluated
   subroutine do_simplex( &
         n, x_lower, x_upper, x_first, x_final, f_final, &
         simplex, f, start_from_given_simplex_and_f, &
         fcn, x_atol, x_rtol, iter_max, fcn_calls_max, &
         centroid_weight_power, enforce_bounds, &
         adaptive_random_search, seed, &
         alpha, beta, gamma, delta, &
         lrpar, rpar, lipar, ipar, &
         num_iters, num_fcn_calls, &
         num_fcn_calls_for_ars, num_accepted_for_ars, ierr)
      integer, intent(in) :: n ! number of dimensions
      real(dp), intent(in) :: x_lower(:), x_upper(:), x_first(:) ! (n)
      real(dp), intent(inout) :: x_final(:) ! (n)
      real(dp), intent(inout) :: simplex(:,:) ! (n,n+1)
      real(dp), intent(inout) :: f(:) ! (n+1)
      logical, intent(in) :: start_from_given_simplex_and_f
      interface
#include "num_simplex_fcn.dek"
      end interface
      real(dp), intent(in) :: x_atol, x_rtol, centroid_weight_power
      integer, intent(inout) :: seed
      real(dp), intent(in) :: alpha, beta, gamma, delta
      integer, intent(in) :: iter_max, fcn_calls_max
      logical, intent(in) :: enforce_bounds, adaptive_random_search
      integer, intent(in) :: lrpar, lipar
      integer, intent(inout), pointer :: ipar(:) ! (lipar)
      real(dp), intent(inout), pointer :: rpar(:) ! (lrpar)
      real(dp), intent(out) :: f_final
      integer, intent(out) :: num_iters, num_fcn_calls, &
         num_fcn_calls_for_ars, num_accepted_for_ars, ierr

      ! work points shared with the internal procedures below
      real(dp), dimension(n) :: c, x_reflect, x_expand, x_contract, x_ars
      real(dp) :: f_reflect, f_expand, f_contract, f_ars
      ! h = index of max f (worst), s = index of 2nd max f, l = index of min f (best)
      integer :: h, s, l

      logical, parameter :: dbg = .false.

      include 'formats'

      ierr = 0
      num_fcn_calls = 0
      num_fcn_calls_for_ars = 0
      num_accepted_for_ars = 0
      num_iters = 0

      if (.not. start_from_given_simplex_and_f) then
         call set_initial_simplex(ierr)
         if (ierr /= 0) then
            write(*,*) 'ierr while evaluating initial simplex'
            return
         end if
      end if

      do

         num_iters = num_iters + 1

         if (dbg) write(*,*)
         if (dbg) write(*,2) 'iter', num_iters

         call rank_vertices

         if (dbg) write(*,2) 'worst', h, f(h), simplex(1:n,h)
         if (dbg) write(*,2) '2nd worst', s, f(s), simplex(1:n,s)
         if (dbg) write(*,2) 'best', l, f(l), simplex(1:n,l)

         if (domain_converged()) exit

         ! check for failure to converge in allowed iterations or function calls
         if (num_iters > iter_max .or. num_fcn_calls > fcn_calls_max) exit

         call set_centroid

         ! transform the simplex

         call reflect(ierr)
         if (ierr /= 0) return
         if (dbg) write(*,1) 'reflect', f_reflect, x_reflect(1:n)

         if (f_reflect < f(s)) then ! accept reflect
            if (dbg) write(*,2) 'accept reflect', num_iters
            simplex(1:n,h) = x_reflect(1:n)
            f(h) = f_reflect
            if (f_reflect < f(l)) then ! try to expand
               call expand(ierr)
               if (ierr /= 0) return
               if (dbg) write(*,1) 'expand', f_expand, x_expand(1:n)
               if (f_expand < f(l)) then ! accept expand
                  ! note: to keep the simplex large in a good direction,
                  ! we take x_expand even if f_expand > f_reflect
                  simplex(1:n,h) = x_expand(1:n)
                  f(h) = f_expand
                  if (dbg) write(*,2) 'accept expand', num_iters
               end if
            end if
         else ! try to contract
            call contract(ierr)
            if (ierr /= 0) return
            if (dbg) write(*,1) 'contract', f_contract, x_contract(1:n)
            ! min() selects f_reflect for an outside and f(h) for an inside contraction
            if (f_contract < min(f(h),f_reflect)) then ! accept contraction
               if (dbg) write(*,2) 'accept contraction', num_iters
               simplex(1:n,h) = x_contract(1:n)
               f(h) = f_contract
            else if (adaptive_random_search) then
               call ARS(ierr)
               if (ierr /= 0) return
            else
               ! a failed evaluation leaves the simplex partly shrunk, so stop here
               ! just as for the other transformations
               call shrink(ierr)
               if (ierr /= 0) return
            end if
         end if

      end do

      f_final = f(l)
      x_final(1:n) = simplex(1:n,l)

      if (dbg) write(*,*)
      if (dbg) write(*,1) 'final', f_final, x_final(1:n)
      if (dbg) write(*,*)


   contains


      ! Build the initial simplex around x_first and evaluate all n+1 vertices.
      ! Vertex n+1 is x_first itself. Vertex j displaces coordinate j by a quarter
      ! of the bounds range, towards the farther bound. If fcn fails there, the
      ! displacement is halved (up to 20 tries) before giving up with ierr = -1.
      subroutine set_initial_simplex(ierr)
         integer, intent(out) :: ierr
         integer, parameter :: max_tries = 20
         integer :: j, k
         logical :: okay
         include 'formats'

         ierr = 0

         simplex(1:n,n+1) = x_first(1:n)
         f(n+1) = get_val(simplex(:,n+1), simplex_initial, ierr)
         if (ierr /= 0) then
            if (dbg) write(*,2) 'failed to get value for first simplex point'
            return
         end if

         do j = 1, n
            simplex(1:n,j) = x_first(1:n)
            if (x_upper(j) - x_first(j) < x_first(j) - x_lower(j)) then
               ! closer to upper, so displace toward lower
               simplex(j,j) = x_first(j) - 0.25d0*(x_upper(j) - x_lower(j))
            else
               simplex(j,j) = x_first(j) + 0.25d0*(x_upper(j) - x_lower(j))
            end if
            okay = .false.
            do k = 1, max_tries
               f(j) = get_val(simplex(:,j), simplex_initial, ierr)
               if (ierr == 0) then
                  okay = .true.
                  exit
               end if
               ! move closer to x_first and retry
               ierr = 0
               simplex(j,j) = 0.5d0*(simplex(j,j) + x_first(j))
               ! give up once the vertex is numerically indistinguishable from x_first
               if (abs(simplex(j,j) - x_first(j)) < &
                     1d-12*(1d0 + abs(x_first(j)))) exit
            end do
            if (.not. okay) then
               ierr = -1
               if (dbg) write(*,2) 'failed to get value for initial simplex point'
               return
            end if
         end do

      end subroutine set_initial_simplex


      ! Set h (worst), s (second worst) and l (best) from the current values f.
      ! Ties go to the lowest index.
      subroutine rank_vertices
         integer :: j
         h = 1
         l = 1
         do j = 2, n+1
            if (f(j) > f(h)) h = j
            if (f(j) < f(l)) l = j
         end do
         s = 0
         do j = 1, n+1
            if (j == h) cycle
            if (s <= 0) then
               s = j
            else if (f(j) > f(s)) then
               s = j
            end if
         end do
      end subroutine rank_vertices


      ! True when every vertex is within x_atol + x_rtol*max(|x|,|x_best|) of the
      ! best vertex l in every coordinate, i.e. the largest scaled separation <= 1.
      logical function domain_converged() result(converged)
         integer :: i, j
         real(dp) :: diff, term1, term_val_x
         include 'formats'
         term_val_x = 0d0
         do j = 1, n+1
            if (j == l) cycle
            do i = 1, n
               diff = abs(simplex(i,j) - simplex(i,l))
               ! identical coordinates can never block convergence; skipping them
               ! also avoids evaluating 0/0 when x_atol == 0 and both values are 0
               if (diff == 0d0) cycle
               term1 = diff / &
                  (x_atol + x_rtol*max(abs(simplex(i,j)), abs(simplex(i,l))))
               if (term1 > term_val_x) term_val_x = term1
            end do
         end do
         if (dbg) write(*,1) 'term_val_x', term_val_x
         converged = (term_val_x <= 1d0)
      end function domain_converged


      ! Set c to the centroid of all vertices except the worst one, h.
      ! With centroid_weight_power /= 0, vertex j is weighted by (1/f(j))**power.
      ! Such weights are only defined for f > 0 (f = 0 would divide by zero), so
      ! the plain centroid is used if any of those values is <= 0, or if the
      ! weights underflow to a zero sum.
      subroutine set_centroid
         integer :: j
         real(dp) :: weight, sum_weight
         logical :: use_weights
         include 'formats'

         use_weights = (centroid_weight_power /= 0d0)
         do j = 1, n+1
            if (j /= h .and. f(j) <= 0d0) use_weights = .false.
         end do

         c(1:n) = 0d0
         sum_weight = 0d0
         if (use_weights) then
            do j = 1, n+1
               if (j == h) cycle
               weight = 1d0/f(j)
               if (centroid_weight_power /= 1d0) &
                  weight = pow(weight,centroid_weight_power)
               c(1:n) = c(1:n) + simplex(1:n,j)*weight
               sum_weight = sum_weight + weight
            end do
         end if

         if (sum_weight <= 0d0) then ! unweighted centroid
            c(1:n) = 0d0
            sum_weight = 0d0
            do j = 1, n+1
               if (j == h) cycle
               c(1:n) = c(1:n) + simplex(1:n,j)
               sum_weight = sum_weight + 1d0
            end do
         end if

         c(1:n) = c(1:n)/sum_weight
         if (dbg) write(*,1) 'c', c(1:n)
      end subroutine set_centroid


      ! Return fcn(x), or a huge value without calling fcn when enforce_bounds
      ! is set and x lies outside [x_lower, x_upper]. Only real calls are counted.
      real(dp) function get_val(x, op_code, ierr) result(val)
         real(dp), intent(in) :: x(:)
         integer, intent(in) :: op_code ! what nelder-mead is doing for this call
         integer, intent(out) :: ierr
         real(dp), parameter :: f_out_of_bounds = 1d99
         integer :: i
         include 'formats'
         ierr = 0
         if (enforce_bounds) then
            do i = 1, n
               if (x(i) > x_upper(i) .or. x(i) < x_lower(i)) then
                  if (dbg) write(*,2) 'out of bounds', &
                     num_iters, x(i), x_lower(i), x_upper(i)
                  val = f_out_of_bounds
                  return
               end if
            end do
         end if
         num_fcn_calls = num_fcn_calls + 1
         val = fcn(n, x, lrpar, rpar, lipar, ipar, op_code, ierr)
      end function get_val


      ! Reflect the worst vertex through the centroid: x_r = c + alpha*(c - x_h).
      subroutine reflect(ierr)
         integer, intent(out) :: ierr
         ierr = 0
         x_reflect(1:n) = c(1:n) + alpha*(c(1:n) - simplex(1:n,h))
         f_reflect = get_val(x_reflect, simplex_reflect, ierr)
      end subroutine reflect


      ! Push further along the reflection direction: x_e = c + beta*(x_r - c).
      subroutine expand(ierr)
         integer, intent(out) :: ierr
         ierr = 0
         x_expand(1:n) = c(1:n) + beta*(x_reflect(1:n) - c(1:n))
         f_expand = get_val(x_expand, simplex_expand, ierr)
      end subroutine expand


      ! Contract towards the centroid, from the reflected point if it improved
      ! on the worst vertex (outside), otherwise from the worst vertex (inside).
      subroutine contract(ierr)
         integer, intent(out) :: ierr
         integer :: op_code
         include 'formats'
         ierr = 0
         if (f_reflect < f(h)) then ! outside contraction
            if (dbg) write(*,1) 'outside contraction'
            op_code = simplex_outside
            x_contract(1:n) = c(1:n) + gamma*(x_reflect(1:n) - c(1:n))
         else ! inside contraction
            if (dbg) write(*,1) 'inside contraction'
            op_code = simplex_inside
            x_contract(1:n) = c(1:n) + gamma*(simplex(1:n,h) - c(1:n))
         end if
         f_contract = get_val(x_contract, op_code, ierr)
      end subroutine contract


      ! Adaptive random search: try up to k_max random points in the bounds and
      ! replace the worst vertex with the first whose value is <= f(h). If none
      ! qualifies, or the call budget runs out, the simplex is left unchanged.
      subroutine ARS(ierr)
         integer, intent(out) :: ierr
         integer, parameter :: k_max = 100
         integer :: k
         include 'formats'
         ierr = 0
         do k = 1, k_max ! keep trying until find a better random point
            if (num_fcn_calls > fcn_calls_max) exit
            call get_point_for_ars(ierr)
            if (ierr /= 0) return
            if (dbg) write(*,2) 'adaptive_random_search', num_iters, f_ars, x_ars(1:n)
            if (f_ars <= f(h)) then ! accept adaptive random search
               if (dbg) write(*,2) 'accept adaptive random search', num_iters, f_ars, x_ars(1:n)
               simplex(1:n,h) = x_ars(1:n)
               f(h) = f_ars
               num_accepted_for_ars = num_accepted_for_ars + 1
               return
            end if
            if (dbg) write(*,2) 'reject adaptive random search', num_iters, f_ars, x_ars(1:n)
         end do
      end subroutine ARS


      ! Draw x_ars uniformly in [x_lower, x_upper] (one draw per coordinate, in
      ! order, advancing seed) and evaluate it.
      subroutine get_point_for_ars(ierr)
         use mod_random, only: r8_uniform_01
         integer, intent(out) :: ierr
         integer :: i
         real(dp) :: rand01
         include 'formats'
         ierr = 0
         do i = 1, n
            rand01 = r8_uniform_01(seed)
            x_ars(i) = x_lower(i) + rand01*(x_upper(i) - x_lower(i))
         end do
         if (dbg) write(*,1) 'adaptive random search', x_ars(1:n)
         num_fcn_calls_for_ars = num_fcn_calls_for_ars + 1
         f_ars = get_val(x_ars, simplex_random, ierr)
      end subroutine get_point_for_ars


      ! Shrink every vertex towards the best one: x_j = x_l + delta*(x_j - x_l).
      ! Returns immediately with the fcn error code if an evaluation fails.
      subroutine shrink(ierr)
         integer, intent(out) :: ierr
         integer :: j
         include 'formats'
         ierr = 0
         do j = 1, n+1
            if (j == l) cycle
            simplex(1:n,j) = simplex(1:n,l) + delta*(simplex(1:n,j) - simplex(1:n,l))
            f(j) = get_val(simplex(:,j), simplex_shrink, ierr)
            if (ierr /= 0) return
            if (dbg) write(*,2) 'shrink', j, f(j), simplex(1:n,j)
         end do
      end subroutine shrink

   end subroutine do_simplex

end module mod_simplex
|