Line data Source code
1 : ! ***********************************************************************
2 : !
3 : ! Copyright (C) 2012 The MESA Team
4 : !
5 : ! This program is free software: you can redistribute it and/or modify
6 : ! it under the terms of the GNU Lesser General Public License
7 : ! as published by the Free Software Foundation,
8 : ! either version 3 of the License, or (at your option) any later version.
9 : !
10 : ! This program is distributed in the hope that it will be useful,
11 : ! but WITHOUT ANY WARRANTY; without even the implied warranty of
12 : ! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
13 : ! See the GNU Lesser General Public License for more details.
14 : !
15 : ! You should have received a copy of the GNU Lesser General Public License
16 : ! along with this program. If not, see <https://www.gnu.org/licenses/>.
17 : !
18 : ! ***********************************************************************
19 :
20 : ! derived from BCYCLIC written hirshman et. al.
21 : ! S.P.Hirshman, K.S.Perumalla, V.E.Lynch, & R.Sanchez,
22 : ! BCYCLIC: A parallel block tridiagonal matrix cyclic solver,
23 : ! J. Computational Physics, 229 (2010) 6392-6404.
24 :
25 :
26 : module star_bcyclic
27 :
28 : !use caliper_mod ! timing with caliper
29 : use star_private_def
30 : use const_def, only: dp, ln10
31 : use utils_lib, only: set_nan
32 :
33 : implicit none
34 :
35 : private
36 : public :: bcyclic_factor, bcyclic_solve, clear_storage
37 :
38 : logical, parameter :: dbg = .false.
39 : logical, parameter :: do_set_nan = .false.
40 :
41 : contains
42 :
43 33 : subroutine bcyclic_factor ( &
44 : s, nvar, nz, lblk1, dblk1, ublk1, lblkF1, dblkF1, ublkF1, ipivot1, &
45 : B1, row_scale_factors1, col_scale_factors1, &
46 33 : equed1, iter, ierr)
47 : type (star_info), pointer :: s
48 : integer, intent(in) :: nvar ! linear size of each block
49 : integer, intent(in) :: nz ! number of block rows
50 : real(dp), pointer, dimension(:) :: &
51 : lblk1, dblk1, ublk1, lblkF1, dblkF1, ublkF1, &
52 : B1, row_scale_factors1, col_scale_factors1
53 : integer, pointer :: ipivot1(:)
54 : character (len=nz) :: equed1
55 : integer, intent(in) :: iter ! solver iteration number for debugging output
56 : integer, intent(out) :: ierr
57 :
58 33 : integer, pointer :: nslevel(:), ipivot(:)
59 : integer :: neq, ncycle, nstemp, maxlevels, nlevel, i, k
60 : logical :: have_odd_storage
61 33 : real(dp), pointer, dimension(:,:) :: dmat, dmatF
62 33 : real(dp), pointer, dimension(:) :: row_scale_factors, col_scale_factors
63 : character (len=1) :: equed
64 33 : real(dp) :: min_rcond_from_DGESVX, rpgfac
65 : integer :: k_min_rcond_from_DGESVX
66 :
67 : integer, allocatable :: factored(:)
68 :
69 : include 'formats'
70 :
71 0 : if (s% use_DGESVX_in_bcyclic .and. s% report_min_rcond_from_DGESXV) &
72 0 : min_rcond_from_DGESVX = 1d99
73 :
74 33 : allocate(factored(nz))
75 39294 : do k=1,nz
76 39294 : factored(k) = 0
77 : end do
78 :
79 33 : ierr = 0
80 33 : neq = nvar*nz
81 33 : !$OMP PARALLEL DO SIMD
82 : do i = 1,nvar*neq
83 : lblkF1(i) = lblk1(i)
84 : dblkF1(i) = dblk1(i)
85 : ublkF1(i) = ublk1(i)
86 : end do
87 : !$OMP END PARALLEL DO SIMD
88 :
89 : if (dbg) write(*,*) 'start bcyclic_factor'
90 :
91 : ! compute number of cyclic reduction levels
92 33 : ncycle = 1
93 33 : maxlevels = 0
94 375 : do while (ncycle < nz)
95 342 : ncycle = 2*ncycle
96 342 : maxlevels = maxlevels+1
97 : end do
98 33 : maxlevels = max(1, maxlevels)
99 :
100 33 : have_odd_storage = associated(s% bcyclic_odd_storage)
101 33 : if (have_odd_storage) then
102 32 : if (size(s% bcyclic_odd_storage) < maxlevels) then
103 0 : call clear_storage(s)
104 : have_odd_storage = .false.
105 : end if
106 : end if
107 :
108 33 : if (.not. have_odd_storage) then
109 1 : allocate (s% bcyclic_odd_storage(maxlevels+3), stat=ierr)
110 1 : if (ierr /= 0) then
111 0 : write(*,*) 'alloc failed for odd_storage in bcyclic'
112 0 : return
113 : end if
114 16 : do nlevel = 1, size(s% bcyclic_odd_storage)
115 16 : s% bcyclic_odd_storage(nlevel)% ul_size = 0
116 : end do
117 : end if
118 :
119 33 : allocate (nslevel(maxlevels), stat=ierr)
120 33 : if (ierr /= 0) return
121 :
122 33 : ncycle = 1
123 33 : nstemp = nz
124 33 : nlevel = 1
125 :
126 : if (dbg) write(*,*) 'start factor_cycle'
127 :
128 : !call cali_begin_phase('factor_cycle')
129 342 : factor_cycle: do ! perform cyclic-reduction factorization
130 :
131 342 : nslevel(nlevel) = nstemp
132 :
133 : if (dbg) write(*,2) 'call cycle_onestep', nstemp
134 :
135 : call cycle_onestep( &
136 : s, nvar, nz, nstemp, ncycle, nlevel, iter, &
137 : lblk1, dblk1, ublk1, lblkF1, dblkF1, ublkF1, ipivot1, &
138 : B1, row_scale_factors1, col_scale_factors1, equed1, factored, &
139 : min_rcond_from_DGESVX, k_min_rcond_from_DGESVX, rpgfac, &
140 342 : ierr)
141 342 : if (ierr /= 0) then
142 0 : call dealloc
143 0 : return
144 : end if
145 :
146 342 : if (nstemp == 1) exit factor_cycle
147 :
148 342 : nstemp = (nstemp+1)/2
149 342 : nlevel = nlevel+1
150 342 : ncycle = 2*ncycle
151 :
152 342 : if (nlevel > maxlevels) exit factor_cycle
153 :
154 : end do factor_cycle
155 : !call cali_end_phase('factor_cycle')
156 :
157 : if (dbg) write(*,*) 'done factor_cycle'
158 :
159 : ! factor row 1
160 33 : dmat(1:nvar,1:nvar) => dblk1(1:nvar*nvar)
161 33 : dmatF(1:nvar,1:nvar) => dblkF1(1:nvar*nvar)
162 33 : ipivot(1:nvar) => ipivot1(1:nvar)
163 33 : row_scale_factors(1:nvar) => row_scale_factors1(1:nvar)
164 33 : col_scale_factors(1:nvar) => col_scale_factors1(1:nvar)
165 33 : factored(1) = factored(1) + 1
166 : call dense_factor(s, 1, nvar, dmat, dmatF, ipivot, &
167 : row_scale_factors, col_scale_factors, equed, &
168 : min_rcond_from_DGESVX, k_min_rcond_from_DGESVX, rpgfac, &
169 33 : ierr)
170 33 : equed1(1:1) = equed(1:1)
171 33 : if (ierr /= 0) then
172 0 : write(*,*) 'dense_factor failed'
173 0 : call dealloc
174 0 : return
175 : end if
176 :
177 39294 : do k=1,nz ! check that every cell factored exactly once
178 39294 : if (factored(k) /= 1) then
179 0 : write(*,3) 'factored /= 1', k, factored(k)
180 0 : call mesa_error(__FILE__,__LINE__,'bcyclic_factor')
181 : end if
182 : end do
183 :
184 33 : call dealloc
185 :
186 33 : if (s% use_DGESVX_in_bcyclic .and. s% report_min_rcond_from_DGESXV) then
187 0 : write(*,4) 'DGESVX: k_min, iter, model, min rcond, rpgfac', &
188 0 : k_min_rcond_from_DGESVX, iter, s% model_number, min_rcond_from_DGESVX, rpgfac
189 : end if
190 :
191 66 : if (dbg) write(*,*) 'done bcyclic_factor'
192 :
193 : contains
194 :
195 33 : subroutine dealloc
196 33 : deallocate (nslevel)
197 33 : end subroutine dealloc
198 :
199 :
200 : end subroutine bcyclic_factor
201 :
202 :
203 342 : subroutine cycle_onestep( &
204 : s, nvar, nz, nblk, ncycle, nlevel, iter, &
205 : lblk1, dblk1, ublk1, lblkF1, dblkF1, ublkF1, ipivot1, &
206 342 : B1, row_scale_factors1, col_scale_factors1, equed1, factored, &
207 : min_rcond_from_DGESVX, k_min_rcond_from_DGESVX, rpgfac, &
208 : ierr)
209 : type (star_info), pointer :: s
210 : integer, intent(in) :: nvar, nz, nblk, ncycle, nlevel, iter
211 : real(dp), pointer, dimension(:), intent(inout) :: &
212 : lblk1, dblk1, ublk1, lblkF1, dblkF1, ublkF1, &
213 : B1, row_scale_factors1, col_scale_factors1
214 : character (len=nz) :: equed1
215 : integer, pointer, intent(inout) :: ipivot1(:)
216 : integer :: factored(:)
217 : real(dp) :: min_rcond_from_DGESVX, rpgfac
218 : integer :: k_min_rcond_from_DGESVX
219 : integer, intent(out) :: ierr
220 :
221 342 : integer, pointer :: ipivot(:)
222 342 : real(dp), pointer, dimension(:,:) :: dmat, umat, lmat, umat0, lmat0, dmatF
223 342 : real(dp), pointer, dimension(:,:) :: lnext, unext, lprev, uprev
224 342 : real(dp), pointer, dimension(:) :: mat1
225 : integer :: i, j, shift, min_sz, new_sz, shift1, shift2, nvar2, &
226 : ns, op_err, nmin, kcount, k
227 342 : real(dp), pointer, dimension(:) :: row_scale_factors, col_scale_factors
228 : character (len=1) :: equed
229 :
230 : include 'formats'
231 :
232 342 : ierr = 0
233 342 : nvar2 = nvar*nvar
234 342 : nmin = 1
235 342 : kcount = 1+(nblk-nmin)/2
236 342 : min_sz = nvar2*kcount
237 342 : if (s% bcyclic_odd_storage(nlevel)% ul_size < min_sz) then
238 12 : if (s% bcyclic_odd_storage(nlevel)% ul_size > 0) &
239 0 : deallocate( &
240 : s% bcyclic_odd_storage(nlevel)% umat1, &
241 0 : s% bcyclic_odd_storage(nlevel)% lmat1)
242 12 : new_sz = min_sz*1.1d0 + 100
243 12 : s% bcyclic_odd_storage(nlevel)% ul_size = new_sz
244 : allocate (s% bcyclic_odd_storage(nlevel)% umat1(new_sz), &
245 12 : s% bcyclic_odd_storage(nlevel)% lmat1(new_sz), stat=ierr)
246 12 : if (ierr /= 0) then
247 0 : write(*,*) 'allocation error in cycle_onestep'
248 0 : return
249 : end if
250 : end if
251 :
252 : !call cali_begin_phase('co.loop1')
253 342 : !$OMP PARALLEL DO PRIVATE(ns, shift, shift2, i) COLLAPSE(2)
254 : do ns = nmin, nblk, 2 ! copy umat and lmat
255 : do i = 1, nvar2
256 : ! kcount = (ns-nmin)/2 + 1
257 : ! shift = nvar2*(kcount-1)
258 : shift = nvar2*(ns-nmin)/2
259 : shift2 = nvar2*ncycle*(ns-1)
260 : s% bcyclic_odd_storage(nlevel)% umat1(shift+i) = ublkF1(shift2+i)
261 : s% bcyclic_odd_storage(nlevel)% lmat1(shift+i) = lblkF1(shift2+i)
262 : end do
263 : end do
264 : !$OMP END PARALLEL DO
265 : !call cali_end_phase('co.loop1')
266 :
267 342 : if (nvar2*kcount > s% bcyclic_odd_storage(nlevel)% ul_size) then
268 0 : write(*,*) 'nvar2*kcount > ul_size in cycle_onestep'
269 0 : ierr = -1
270 0 : return
271 : end if
272 :
273 : if (dbg) write(*,*) 'start lu factorization'
274 : ! compute lu factorization of even diagonal blocks
275 342 : nmin = 2
276 : !call cali_begin_phase('co.loop2')
277 342 : !$OMP PARALLEL DO PRIVATE(ipivot,dmat,dmatF,ns,op_err,shift1,shift2,k,row_scale_factors,col_scale_factors,equed)
278 : do ns = nmin, nblk, 2
279 :
280 : k = ncycle*(ns-1) + 1
281 : shift1 = nvar*(k-1)
282 : shift2 = nvar*shift1
283 : dmat(1:nvar,1:nvar) => dblk1(shift2+1:shift2+nvar2)
284 : dmatF(1:nvar,1:nvar) => dblkF1(shift2+1:shift2+nvar2)
285 : op_err = 0
286 : ipivot(1:nvar) => ipivot1(shift1+1:shift1+nvar)
287 : row_scale_factors(1:nvar) => row_scale_factors1(shift1+1:shift1+nvar)
288 : col_scale_factors(1:nvar) => col_scale_factors1(shift1+1:shift1+nvar)
289 : factored(k) = factored(k) + 1
290 : call dense_factor(s, k, nvar, dmat, dmatF, ipivot, &
291 : row_scale_factors, col_scale_factors, equed, &
292 : min_rcond_from_DGESVX, k_min_rcond_from_DGESVX, rpgfac, &
293 : op_err)
294 : equed1(k:k) = equed(1:1)
295 : if (op_err /= 0) then
296 : ierr = op_err
297 : end if
298 :
299 : end do
300 : !$OMP END PARALLEL DO
301 : !call cali_end_phase('co.loop2')
302 342 : if (ierr /= 0) then
303 : !write(*,*) 'factorization failed in bcyclic'
304 : return
305 : end if
306 :
307 : if (dbg) write(*,*) 'done lu factorization; start solve'
308 :
309 : !call cali_begin_phase('co.loop3')
310 : !$OMP PARALLEL DO SCHEDULE(static,3) &
311 342 : !$OMP PRIVATE(ns,k,shift1,shift2,ipivot,dmat,dmatF,umat,lmat,mat1,i,j,row_scale_factors,col_scale_factors,equed,op_err)
312 : do ns = nmin, nblk, 2
313 : ! compute new l=-d[-1]l, u=-d[-1]u for even blocks
314 : k = ncycle*(ns-1) + 1
315 : shift1 = nvar*(k-1)
316 : shift2 = nvar*shift1
317 :
318 : lmat(1:nvar,1:nvar) => lblkF1(shift2+1:shift2+nvar2)
319 :
320 : dmat(1:nvar,1:nvar) => dblk1(shift2+1:shift2+nvar2)
321 : dmatF(1:nvar,1:nvar) => dblkF1(shift2+1:shift2+nvar2)
322 : ipivot(1:nvar) => ipivot1(shift1+1:shift1+nvar)
323 : row_scale_factors(1:nvar) => row_scale_factors1(shift1+1:shift1+nvar)
324 : col_scale_factors(1:nvar) => col_scale_factors1(shift1+1:shift1+nvar)
325 : equed(1:1) = equed1(k:k)
326 : call dense_solve(s, k, nvar, dmat, dmatF, ipivot, lmat, &
327 : row_scale_factors, col_scale_factors, equed, op_err)
328 : if (op_err /= 0) then
329 : ierr = op_err
330 : cycle
331 : end if
332 :
333 : do j=1,nvar
334 : !$OMP SIMD
335 : do i=1,nvar
336 : lmat(i,j) = -lmat(i,j)
337 : end do
338 : end do
339 :
340 : umat(1:nvar,1:nvar) => ublkF1(shift2+1:shift2+nvar2)
341 :
342 : call dense_solve(s, k, nvar, dmat, dmatF, ipivot, umat, &
343 : row_scale_factors, col_scale_factors, equed, op_err)
344 : if (op_err /= 0) then
345 : ierr = op_err
346 : cycle
347 : end if
348 :
349 : do j=1,nvar
350 : !$OMP SIMD
351 : do i=1,nvar
352 : umat(i,j) = -umat(i,j)
353 : end do
354 : end do
355 :
356 : end do
357 : !$OMP END PARALLEL DO
358 : !call cali_end_phase('co.loop3')
359 : if (dbg) write(*,*) 'done solve'
360 :
361 342 : if (ierr /= 0) return
362 :
363 : ! compute new odd blocks in terms of even block factors
364 : ! compute odd hatted matrix elements except at boundaries
365 342 : nmin = 1
366 : !call cali_begin_phase('co.loop4')
367 : !$OMP PARALLEL DO SCHEDULE(static,3) &
368 342 : !$OMP PRIVATE(i,ns,shift2,dmat,umat,lmat,lnext,unext,lprev,uprev,shift,umat0,lmat0,k)
369 : do i = 1, 3*(1+(nblk-nmin)/2)
370 :
371 : ns = 2*((i-1)/3) + nmin
372 : k = ncycle*(ns-1) + 1
373 : if (factored(k) > 0) then
374 : write(*,2) 'compute new dmat after already factored', k
375 : call mesa_error(__FILE__,__LINE__,'cycle_onestep')
376 : end if
377 : shift2 = nvar2*(k-1)
378 : dmat(1:nvar,1:nvar) => dblkF1(shift2+1:shift2+nvar2)
379 : umat(1:nvar,1:nvar) => ublkF1(shift2+1:shift2+nvar2)
380 : lmat(1:nvar,1:nvar) => lblkF1(shift2+1:shift2+nvar2)
381 :
382 : if (ns < nblk) then
383 : shift2 = nvar2*ncycle*ns
384 : lnext(1:nvar,1:nvar) => lblkF1(shift2+1:shift2+nvar2)
385 : unext(1:nvar,1:nvar) => ublkF1(shift2+1:shift2+nvar2)
386 : end if
387 :
388 : if (ns > 1) then
389 : shift2 = nvar2*ncycle*(ns-2)
390 : lprev(1:nvar,1:nvar) => lblkF1(shift2+1:shift2+nvar2)
391 : uprev(1:nvar,1:nvar) => ublkF1(shift2+1:shift2+nvar2)
392 : end if
393 :
394 : !kcount = 1+(ns-nmin)/2
395 : !shift = nvar2*(kcount-1)
396 : shift = nvar2*(ns-nmin)/2
397 : lmat0(1:nvar,1:nvar) => &
398 : s% bcyclic_odd_storage(nlevel)% lmat1(shift+1:shift+nvar2)
399 : umat0(1:nvar,1:nvar) => &
400 : s% bcyclic_odd_storage(nlevel)% umat1(shift+1:shift+nvar2)
401 :
402 : select case(mod(i-1,3))
403 : case (0)
404 : if (ns > 1) then
405 : ! lmat = matmul(lmat0, lprev)
406 : call my_gemm0_p1(nvar,nvar,nvar,lmat0,nvar,lprev,nvar,lmat,nvar)
407 : end if
408 : case (1)
409 : if (ns < nblk) then
410 : ! umat = matmul(umat0, unext)
411 : call my_gemm0_p1(nvar,nvar,nvar,umat0,nvar,unext,nvar,umat,nvar)
412 : end if
413 : case (2)
414 : if (ns < nblk) then
415 : if (ns > 1) then
416 : ! dmat = dmat + matmul(umat0, lnext) + matmul(lmat0,uprev)
417 : call my_gemm_plus_mm(nvar,nvar,nvar,umat0,lnext,lmat0,uprev,dmat)
418 : else
419 : ! dmat = dmat + matmul(umat0, lnext)
420 : call my_gemm_p1(nvar,nvar,nvar,umat0,nvar,lnext,nvar,dmat,nvar)
421 : end if
422 : else if (ns > 1) then
423 : ! dmat = dmat + matmul(lmat0,uprev)
424 : call my_gemm_p1(nvar,nvar,nvar,lmat0,nvar,uprev,nvar,dmat,nvar)
425 : end if
426 : end select
427 :
428 : end do
429 : !$OMP END PARALLEL DO
430 : !call cali_end_phase('co.loop4')
431 : if (dbg) write(*,*) 'done cycle_onestep'
432 :
433 342 : end subroutine cycle_onestep
434 :
435 :
436 342 : subroutine cycle_rhs( &
437 : s, nz, nblk, nvar, ncycle, nlevel, &
438 : dblk1, dblkF1, soln1, ipivot1, &
439 342 : row_scale_factors1, col_scale_factors1, equed1, ierr)
440 : !use chem_def, only: chem_isos
441 : type (star_info), pointer :: s
442 : integer, intent(in) :: nz, nblk, nvar, ncycle, nlevel
443 : real(dp), pointer, intent(in), dimension(:) :: &
444 : dblk1, dblkF1, row_scale_factors1, col_scale_factors1
445 : real(dp), pointer, intent(inout) :: soln1(:)
446 : integer, pointer, intent(in) :: ipivot1(:)
447 : character (len=nz) :: equed1
448 : integer, intent(out) :: ierr
449 :
450 : integer :: k, ns, op_err, nmin, kcount, shift, shift1, shift2, nvar2
451 342 : integer, pointer :: ipivot(:)
452 342 : real(dp), pointer, dimension(:,:) :: dmatF, dmat, umat, lmat
453 342 : real(dp), pointer, dimension(:) :: X, Xprev, Xnext
454 342 : real(dp), pointer, dimension(:) :: row_scale_factors, col_scale_factors
455 : character (len=1) :: equed
456 :
457 : include 'formats'
458 :
459 342 : ierr = 0
460 342 : nvar2 = nvar*nvar
461 : ! compute dblk[-1]*brhs for even indices and store in brhs(even)
462 342 : nmin = 2
463 342 : op_err = 0
464 : !call cali_begin_phase('cr.loop1')
465 : !$OMP PARALLEL DO &
466 342 : !$OMP PRIVATE(ns,shift1,ipivot,shift2,k,dmat,dmatF,X,row_scale_factors,col_scale_factors,equed,op_err)
467 : do ns = nmin, nblk, 2
468 : k = ncycle*(ns-1) + 1
469 : shift1 = nvar*(k-1)
470 : shift2 = nvar*shift1
471 : dmat(1:nvar,1:nvar) => dblk1(shift2+1:shift2+nvar2)
472 : dmatF(1:nvar,1:nvar) => dblkF1(shift2+1:shift2+nvar2)
473 : ipivot(1:nvar) => ipivot1(shift1+1:shift1+nvar)
474 : row_scale_factors(1:nvar) => row_scale_factors1(shift1+1:shift1+nvar)
475 : col_scale_factors(1:nvar) => col_scale_factors1(shift1+1:shift1+nvar)
476 : equed(1:1) = equed1(k:k)
477 : X(1:nvar) => soln1(shift1+1:shift1+nvar)
478 : call dense_solve1(s, k, nvar, X, dmat, dmatF, ipivot, .true., &
479 : row_scale_factors, col_scale_factors, equed, op_err)
480 : if (op_err /= 0) then
481 : ierr = op_err
482 : cycle
483 : end if
484 :
485 : end do
486 : !$OMP END PARALLEL DO
487 : !call cali_end_phase('cr.loop1')
488 :
489 342 : if (ierr /= 0) return
490 :
491 : ! compute odd (hatted) sources (b-hats) for interior rows
492 342 : nmin = 1
493 342 : kcount = 0
494 : !call cali_begin_phase('cr.loop2')
495 342 : !$OMP PARALLEL DO PRIVATE(ns,shift1,X,shift,umat,lmat,Xnext,Xprev)
496 : do ns = nmin, nblk, 2
497 : shift1 = nvar*ncycle*(ns-1)
498 : X(1:nvar) => soln1(shift1+1:shift1+nvar)
499 : !kcount = 1+(ns-nmin)/2
500 : !shift = nvar2*(kcount-1)
501 : shift = nvar2*(ns-nmin)/2
502 : umat(1:nvar,1:nvar) => &
503 : s% bcyclic_odd_storage(nlevel)% umat1(shift+1:shift+nvar2)
504 : lmat(1:nvar,1:nvar) => &
505 : s% bcyclic_odd_storage(nlevel)% lmat1(shift+1:shift+nvar2)
506 : if (ns > 1) then
507 : shift1 = nvar*ncycle*(ns-2)
508 : Xprev => soln1(shift1+1:shift1+nvar)
509 : end if
510 : if (ns < nblk) then
511 : shift1 = nvar*ncycle*ns
512 : Xnext => soln1(shift1+1:shift1+nvar)
513 : if (ns > 1) then
514 : ! bptr = bptr - matmul(umat,bnext) - matmul(lmat,bprev)
515 : call my_gemv_mv(nvar,nvar,umat,Xnext,lmat,Xprev,X)
516 : else
517 : ! bptr = bptr - matmul(umat,bnext)
518 : call my_gemv(nvar,nvar,umat,nvar,Xnext,X)
519 : end if
520 : else if (ns > 1) then
521 : ! bptr = bptr - matmul(lmat,bprev)
522 : call my_gemv(nvar,nvar,lmat,nvar,Xprev,X)
523 : end if
524 : end do
525 : !$OMP END PARALLEL DO
526 : !call cali_end_phase('cr.loop2')
527 :
528 342 : if (nvar2*kcount > s% bcyclic_odd_storage(nlevel)% ul_size) then
529 0 : write(*,*) 'nvar2*kcount > ul_size in cycle_rhs'
530 0 : ierr = -1
531 0 : return
532 : end if
533 :
534 342 : end subroutine cycle_rhs
535 :
536 :
537 : ! computes even index solution from the computed (at previous,higher level)
538 : ! odd index solutions at this level.
539 : ! note at this point, the odd brhs values have been replaced (at the highest cycle)
540 : ! with the solution values (x), at subsequent (lower) cycles, the
541 : ! odd values are replaced by the even solutions at the next highest cycle. the even
542 : ! brhs values were multiplied by d[-1] and stored in cycle_rhs
543 : ! solve for even index values in terms of (computed at this point) odd index values
544 342 : subroutine cycle_solve( &
545 : s, nvar, nz, ncycle, nblk, nlevel, lblk1, ublk1, lblkF1, ublkF1, soln1)
546 : type (star_info), pointer :: s
547 : integer, intent(in) :: nvar, nz, ncycle, nblk, nlevel
548 : real(dp), pointer, intent(in), dimension(:) :: lblk1, ublk1, lblkF1, ublkF1
549 : real(dp), pointer, intent(inout) :: soln1(:)
550 :
551 342 : real(dp), pointer :: umat(:,:), lmat(:,:), bprev(:), bnext(:), bptr(:)
552 : integer :: shift1, shift2, nvar2, ns, nmin
553 :
554 : include 'formats'
555 :
556 342 : nvar2 = nvar*nvar
557 342 : nmin = 2
558 : !call cali_begin_phase('cycle_solve')
559 : !$OMP PARALLEL DO SCHEDULE(static,3) &
560 342 : !$OMP PRIVATE(ns,shift1,bptr,shift2,lmat,bprev,umat,bnext)
561 : do ns = nmin, nblk, 2
562 : shift1 = ncycle*nvar*(ns-1)
563 : bptr(1:nvar) => soln1(shift1+1:shift1+nvar)
564 : shift2 = nvar*shift1
565 : if (ns > 1) then
566 : lmat(1:nvar,1:nvar) => lblkF1(shift2+1:shift2+nvar2)
567 : shift1 = ncycle*nvar*(ns-2)
568 : bprev(1:nvar) => soln1(shift1+1:shift1+nvar)
569 : end if
570 : if (ns < nblk) then
571 : umat(1:nvar,1:nvar) => ublkF1(shift2+1:shift2+nvar2)
572 : shift1 = ncycle*nvar*ns
573 : bnext(1:nvar) => soln1(shift1+1:shift1+nvar)
574 : if (ns > 1) then
575 : ! bptr = bptr + matmul(umat,bnext) + matmul(lmat,bprev)
576 : call my_gemv_p_mv(nvar,nvar,umat,bnext,lmat,bprev,bptr)
577 : else
578 : ! bptr = bptr + matmul(umat,bnext)
579 : call my_gemv_p1(nvar,nvar,umat,nvar,bnext,bptr)
580 : end if
581 : else if (ns > 1) then
582 : ! bptr = bptr + matmul(lmat,bprev)
583 : call my_gemv_p1(nvar,nvar,lmat,nvar,bprev,bptr)
584 : end if
585 : end do
586 : !$OMP END PARALLEL DO
587 : !call cali_end_phase('cycle_solve')
588 :
589 342 : end subroutine cycle_solve
590 :
591 :
592 39261 : subroutine dense_factor(s, k, nvar, mtx, mtxF, ipivot, &
593 : row_scale_factors, col_scale_factors, equed, &
594 : min_rcond_from_DGESVX, k_min_rcond_from_DGESVX, rpgfac, &
595 : ierr)
596 : type (star_info), pointer :: s
597 : integer, intent(in) :: k, nvar
598 : real(dp), pointer :: mtx(:,:), mtxF(:,:)
599 : integer, pointer :: ipivot(:)
600 : real(dp), pointer :: row_scale_factors(:), col_scale_factors(:)
601 : character (len=1) :: equed
602 : real(dp) :: min_rcond_from_DGESVX, rpgfac
603 : integer :: k_min_rcond_from_DGESVX
604 : integer, intent(out) :: ierr
605 : include 'formats'
606 39261 : ierr = 0
607 :
608 39261 : if (s% use_DGESVX_in_bcyclic) then
609 0 : call factor_with_DGESVX
610 0 : return
611 : end if
612 :
613 39261 : if (nvar == 4) then
614 0 : call my_getf2_n4(mtxF, ipivot, ierr)
615 39261 : else if (nvar == 5) then
616 0 : call my_getf2_n5(mtxF, ipivot, ierr)
617 : else
618 39261 : call my_getf2(nvar, mtxF, nvar, ipivot, ierr)
619 : end if
620 :
621 : contains
622 :
623 0 : subroutine factor_with_DGESVX
624 : character (len=1) :: fact, trans
625 : integer, parameter :: nrhs = 0
626 0 : real(dp) :: rcond
627 0 : real(dp) :: a(nvar,nvar), af(nvar,nvar), b(nvar,nrhs), x(nvar,nrhs), &
628 0 : r(nvar), c(nvar), ferr(nrhs), berr(nrhs), work(4*nvar)
629 0 : integer :: ipiv(nvar), iwork(nvar)
630 : integer :: i, j
631 : include 'formats'
632 :
633 0 : do i=1,nvar
634 0 : do j=1,nvar
635 0 : a(i,j) = mtxF(i,j)
636 : end do
637 : end do
638 :
639 0 : if (s% use_equilibration_in_DGESVX) then
640 0 : fact = 'E' ! matrix A will be equilibrated, then copied to AF and factored
641 : else
642 0 : fact = 'N' ! matrix A will be copied to AF and factored
643 : end if
644 0 : trans = 'N' ! no transpose
645 :
646 : ! SUBROUTINE DGESVX( FACT, TRANS, N, NRHS, A, LDA, AF, LDAF, IPIV,
647 : ! $ EQUED, R, C, B, LDB, X, LDX, RCOND, FERR, BERR,
648 : ! $ WORK, IWORK, INFO )
649 :
650 : call DGESVX(fact, trans, nvar, nrhs, a, nvar, af, nvar, ipiv, &
651 : equed, r, c, b, nvar, x, nvar, rcond, ferr, berr, &
652 0 : work, iwork, ierr)
653 :
654 0 : if (ierr > 0 .and. ierr <= nvar) then ! singular
655 0 : write(*,3) 'singular matrix for DGESVX', k, ierr
656 0 : call mesa_error(__FILE__,__LINE__,'factor_with_DGESVX')
657 : end if
658 0 : if (ierr == nvar+1) then ! means bad rcond, but may not be fatal
659 0 : write(*,2) 'DGESVX reports bad matrix conditioning: k, rcond', k, rcond
660 0 : ierr = 0
661 : end if
662 :
663 0 : do i=1,nvar
664 0 : do j=1,nvar
665 0 : mtx(i,j) = a(i,j)
666 0 : mtxF(i,j) = af(i,j)
667 : end do
668 0 : row_scale_factors(i) = r(i)
669 0 : col_scale_factors(i) = c(i)
670 0 : ipivot(i) = ipiv(i)
671 : end do
672 :
673 0 : if (s% report_min_rcond_from_DGESXV .and. rcond < min_rcond_from_DGESVX) then
674 0 : !$OMP CRITICAL (bcyclic_dense_factor_crit)
675 0 : min_rcond_from_DGESVX = rcond
676 0 : k_min_rcond_from_DGESVX = k
677 0 : rpgfac = work(1)
678 : !$OMP END CRITICAL (bcyclic_dense_factor_crit)
679 : end if
680 :
681 0 : end subroutine factor_with_DGESVX
682 :
683 : end subroutine dense_factor
684 :
685 :
686 33 : subroutine bcyclic_solve ( &
687 : s, nvar, nz, lblk1, dblk1, ublk1, lblkF1, dblkF1, ublkF1, ipivot1, &
688 33 : B1, soln1, row_scale_factors1, col_scale_factors1, equed1, &
689 : iter, ierr)
690 : type (star_info), pointer :: s
691 : integer, intent(in) :: nvar, nz, iter
692 : real(dp), pointer, dimension(:) :: &
693 : lblk1, dblk1, ublk1, lblkF1, dblkF1, ublkF1, &
694 : B1, soln1, row_scale_factors1, col_scale_factors1
695 : integer, pointer :: ipivot1(:)
696 : character (len=nz) :: equed1
697 : integer, intent(out) :: ierr
698 :
699 33 : integer, pointer :: nslevel(:), ipivot(:)
700 : integer :: ncycle, nstemp, maxlevels, nlevel, nvar2, i
701 33 : real(dp), pointer, dimension(:,:) :: dmat, dmatF
702 33 : real(dp), pointer, dimension(:) :: row_scale_factors, col_scale_factors
703 : character (len=1) :: equed
704 :
705 : include 'formats'
706 :
707 :
708 : if (dbg) write(*,*) 'start bcyclic_solve'
709 :
710 : ! copy B to soln
711 33 : !$OMP PARALLEL DO SIMD
712 : do i=1,nvar*nz
713 : soln1(i) = B1(i)
714 : end do
715 : !$OMP END PARALLEL DO SIMD
716 :
717 33 : ierr = 0
718 :
719 33 : nvar2 = nvar*nvar
720 33 : ncycle = 1
721 33 : maxlevels = 0
722 375 : do while (ncycle < nz)
723 342 : ncycle = 2*ncycle
724 342 : maxlevels = maxlevels+1
725 : end do
726 33 : maxlevels = max(1, maxlevels)
727 :
728 33 : allocate (nslevel(maxlevels), stat=ierr)
729 33 : if (ierr /= 0) return
730 :
731 33 : ncycle = 1
732 33 : nstemp = nz
733 33 : nlevel = 1
734 :
735 : if (dbg) write(*,*) 'start forward_cycle'
736 :
737 342 : forward_cycle: do
738 :
739 342 : nslevel(nlevel) = nstemp
740 : if (dbg) write(*,2) 'call cycle_rhs', nstemp
741 : call cycle_rhs( &
742 : s, nz, nstemp, nvar, ncycle, nlevel, &
743 : dblk1, dblkF1, soln1, ipivot1, &
744 342 : row_scale_factors1, col_scale_factors1, equed1, ierr)
745 342 : if (ierr /= 0) then
746 0 : call dealloc
747 0 : return
748 : end if
749 :
750 342 : if (nstemp == 1) exit forward_cycle
751 :
752 342 : nstemp = (nstemp+1)/2
753 342 : nlevel = nlevel+1
754 342 : ncycle = 2*ncycle
755 :
756 342 : if (nlevel > maxlevels) exit forward_cycle
757 :
758 : end do forward_cycle
759 :
760 : if (dbg) write(*,*) 'done forward_cycle'
761 :
762 33 : dmat(1:nvar,1:nvar) => dblk1(1:nvar2)
763 33 : dmatF(1:nvar,1:nvar) => dblkF1(1:nvar2)
764 33 : ipivot(1:nvar) => ipivot1(1:nvar)
765 33 : row_scale_factors(1:nvar) => row_scale_factors1(1:nvar)
766 33 : col_scale_factors(1:nvar) => col_scale_factors1(1:nvar)
767 33 : equed(1:1) = equed1(1:1)
768 : call dense_solve1(s, 1, nvar, soln1, dmat, dmatF, ipivot, .false., &
769 33 : row_scale_factors, col_scale_factors, equed, ierr)
770 33 : if (ierr /= 0) then
771 0 : write(*,*) 'failed in my_getrs1'
772 0 : call dealloc
773 0 : return
774 : end if
775 :
776 : ! back solve for even x's
777 375 : back_cycle: do while (ncycle > 1)
778 342 : ncycle = ncycle/2
779 342 : nlevel = nlevel-1
780 342 : if (nlevel < 1) then
781 0 : ierr = -1
782 0 : exit back_cycle
783 : end if
784 342 : nstemp = nslevel(nlevel)
785 : call cycle_solve( &
786 : s, nvar, nz, ncycle, nstemp, nlevel, &
787 342 : lblk1, ublk1, lblkF1, ublkF1, soln1)
788 : end do back_cycle
789 :
790 33 : call dealloc
791 :
792 33 : if (dbg) write(*,*) 'done bcyclic_solve'
793 :
794 :
795 : contains
796 :
797 :
798 33 : subroutine dealloc
799 33 : deallocate (nslevel)
800 33 : end subroutine dealloc
801 :
802 :
803 : end subroutine bcyclic_solve
804 :
805 :
806 1 : subroutine clear_storage(s)
807 : type (star_info), pointer :: s
808 : integer :: nlevel
809 1 : nlevel = size(s% bcyclic_odd_storage)
810 16 : do while (nlevel > 0)
811 15 : if (s% bcyclic_odd_storage(nlevel)% ul_size > 0) then
812 12 : deallocate(s% bcyclic_odd_storage(nlevel)% umat1)
813 12 : deallocate(s% bcyclic_odd_storage(nlevel)% lmat1)
814 : end if
815 15 : nlevel = nlevel-1
816 : end do
817 1 : deallocate(s% bcyclic_odd_storage)
818 1 : nullify(s% bcyclic_odd_storage)
819 1 : end subroutine clear_storage
820 :
821 :
822 78456 : subroutine dense_solve(s, k, nvar, mtx, mtxF, ipivot, X_mtx, &
823 : row_scale_factors, col_scale_factors, equed, ierr)
824 : type (star_info), pointer :: s
825 : integer, intent(in) :: k, nvar
826 : real(dp), pointer, dimension(:,:) :: mtx, mtxF, X_mtx
827 : integer, pointer :: ipivot(:)
828 : real(dp), pointer, dimension(:) :: row_scale_factors, col_scale_factors
829 : character (len=1) :: equed
830 : integer, intent(out) :: ierr
831 : integer :: i
832 78456 : real(dp), pointer :: X(:)
833 78456 : ierr = 0
834 :
835 78456 : if (s% use_DGESVX_in_bcyclic) then
836 0 : call solve_with_DGESVX
837 0 : return
838 : end if
839 :
840 1019928 : do i=1,nvar
841 941472 : X(1:nvar) => X_mtx(1:nvar,i)
842 : call dense_solve1(s, k, nvar, X, mtx, mtxF, ipivot, .false., &
843 941472 : row_scale_factors, col_scale_factors, equed, ierr)
844 1019928 : if (ierr /= 0) return
845 : end do
846 :
847 : contains
848 :
849 0 : subroutine solve_with_DGESVX
850 : character (len=1) :: fact, trans
851 0 : real(dp) :: rcond
852 0 : real(dp) :: a(nvar,nvar), af(nvar,nvar), b(nvar,nvar), x(nvar,nvar), &
853 0 : r(nvar), c(nvar), ferr(nvar), berr(nvar), work(4*nvar)
854 0 : integer :: ipiv(nvar), iwork(nvar)
855 : integer :: i, j, nrhs
856 : include 'formats'
857 :
858 0 : nrhs = nvar
859 :
860 0 : do i=1,nvar
861 0 : !$OMP SIMD
862 : do j=1,nvar
863 0 : a(i,j) = mtx(i,j)
864 0 : af(i,j) = mtxF(i,j)
865 0 : b(i,j) = X_mtx(i,j)
866 0 : x(i,j) = 0d0
867 : end do
868 0 : r(i) = row_scale_factors(i)
869 0 : c(i) = col_scale_factors(i)
870 0 : ipiv(i) = ipivot(i)
871 : end do
872 :
873 0 : fact = 'F' ! factored
874 0 : trans = 'N' ! no transpose
875 :
876 : ! SUBROUTINE DGESVX( FACT, TRANS, N, NRHS, A, LDA, AF, LDAF, IPIV,
877 : ! $ EQUED, R, C, B, LDB, X, LDX, RCOND, FERR, BERR,
878 : ! $ WORK, IWORK, INFO )
879 :
880 : call DGESVX(fact, trans, nvar, nrhs, a, nvar, af, nvar, ipiv, &
881 : equed, r, c, b, nvar, x, nvar, rcond, ferr, berr, &
882 0 : work, iwork, ierr)
883 0 : if (ierr /= 0) then
884 0 : write(*,2) 'solve_with_DGESVX failed', k
885 : end if
886 :
887 0 : do i=1,nvar
888 0 : !$OMP SIMD
889 : do j=1,nvar
890 0 : X_mtx(i,j) = x(i,j)
891 : end do
892 : end do
893 :
894 0 : end subroutine solve_with_DGESVX
895 :
896 : end subroutine dense_solve
897 :
898 :
899 980733 : subroutine dense_solve1(s, k, nvar, X_vec, mtx, mtxF, ipivot, dbg, &
900 : row_scale_factors, col_scale_factors, equed, ierr)
901 : type (star_info), pointer :: s
902 : integer, intent(in) :: k, nvar
903 : real(dp), pointer :: X_vec(:), mtx(:,:), mtxF(:,:)
904 : integer, pointer :: ipivot(:)
905 : real(dp), pointer, dimension(:) :: row_scale_factors, col_scale_factors
906 : character (len=1) :: equed
907 : logical, intent(in) :: dbg
908 : integer, intent(out) :: ierr
909 : include 'formats'
910 980733 : ierr = 0
911 :
912 980733 : if (s% use_DGESVX_in_bcyclic) then
913 0 : call solve1_with_DGESVX
914 0 : return
915 : end if
916 :
917 980733 : if (nvar == 4) then
918 0 : call my_getrs1_n4(mtxF, ipivot, X_vec, ierr)
919 980733 : else if (nvar == 5) then
920 0 : call my_getrs1_n5(mtxF, ipivot, X_vec, ierr)
921 : else
922 980733 : call my_getrs1(nvar, mtxF, nvar, ipivot, X_vec, nvar, ierr)
923 : end if
924 :
925 : contains
926 :
927 0 : subroutine solve1_with_DGESVX
928 : character (len=1) :: fact, trans
929 0 : real(dp) :: rcond
930 : integer, parameter :: nrhs = 1
931 0 : real(dp) :: a(nvar,nvar), af(nvar,nvar), b(nvar,nrhs), x(nvar,nrhs), &
932 0 : r(nvar), c(nvar), ferr(nrhs), berr(nrhs), work(4*nvar)
933 0 : integer :: ipiv(nvar), iwork(nvar)
934 : integer :: i, j
935 :
936 : include 'formats'
937 :
938 0 : do i=1,nvar
939 0 : !$OMP SIMD
940 : do j=1,nvar
941 0 : a(i,j) = mtx(i,j)
942 0 : af(i,j) = mtxF(i,j)
943 : end do
944 0 : b(i,1) = X_vec(i)
945 0 : x(i,1) = 0d0
946 0 : r(i) = row_scale_factors(i)
947 0 : c(i) = col_scale_factors(i)
948 0 : ipiv(i) = ipivot(i)
949 : end do
950 :
951 0 : fact = 'F' ! factored
952 0 : trans = 'N' ! no transpose
953 :
954 : ! SUBROUTINE DGESVX( FACT, TRANS, N, NRHS, A, LDA, AF, LDAF, IPIV,
955 : ! $ EQUED, R, C, B, LDB, X, LDX, RCOND, FERR, BERR,
956 : ! $ WORK, IWORK, INFO )
957 :
958 : call DGESVX(fact, trans, nvar, nrhs, a, nvar, af, nvar, ipiv, &
959 : equed, r, c, b, nvar, x, nvar, rcond, ferr, berr, &
960 0 : work, iwork, ierr)
961 :
962 0 : !$OMP SIMD
963 : do i=1,nvar
964 0 : X_vec(i) = x(i,1)
965 : end do
966 :
967 0 : end subroutine solve1_with_DGESVX
968 :
969 : end subroutine dense_solve1
970 :
971 :
972 : subroutine bcyclic_deallocate (s, ierr)
973 : type (star_info), pointer :: s
974 : integer, intent(out) :: ierr
975 : ierr = 0
976 : end subroutine bcyclic_deallocate
977 :
978 :
979 : include 'mtx_solve_routines.inc'
980 :
981 : end module star_bcyclic
|