38 public :: matrix_solver_tridiagonal
! Generic interface: a call to matrix_solver_tridiagonal resolves to one of the
! specific module procedures listed here (the 2D and 3D variants are visible;
! other specifics may be declared on lines not shown in this excerpt).
43 interface matrix_solver_tridiagonal
49 module procedure matrix_solver_tridiagonal_2d
51 module procedure matrix_solver_tridiagonal_3d
52 end interface matrix_solver_tridiagonal
! cuSPARSE handle shared by the module: created with cusparsecreate and
! released with cusparsedestroy elsewhere in this file.
69 type(cusparseHandle) :: handle
! Scratch buffer for the cuSPARSE path; the 2D/3D solvers (re)allocate it when
! the size reported by the buffer-size query exceeds the current bufsize.
70 real(RP),
allocatable :: pbuffer(:)
! Setup fragment (log tag "MATRIX_setup"): write a start-up message to the log.
90 log_info(
"MATRIX_setup",*)
'Setup'
! Create the cuSPARSE handle. Any status other than cusparse_status_success
! is reported through log_error together with the numeric status code.
! NOTE(review): what happens after the error message (abort or continue) is
! on lines not visible here.
104 status = cusparsecreate(handle)
105 if ( status /= cusparse_status_success )
then
106 log_error(
"MATRIX_setup",*)
"cusparseCreate failed: ", status
! Teardown fragment: release the cuSPARSE handle.
! NOTE(review): no check of the returned status is visible in this excerpt.
119 status = cusparsedestroy(handle)
! Free the scratch buffer only if it was ever allocated.
120 if (
allocated(pbuffer) )
deallocate(pbuffer)
! 1-D tridiagonal solve for a single column (routine header not visible here).
! Arguments as declared below:
!   ka, ks, ke : array extent and the first/last index of the solved range
!   ud, md, ld : coefficient arrays; from their use in the recurrences below
!                md is the diagonal, ld multiplies the previous unknown and
!                ud the next one
!   iv         : right-hand side
!   ov         : solution (written for indices reached by the back substitution)
!   work       : caller-supplied scratch, used as two columns via the macros
140 integer,
intent(in) :: ka, ks, ke
142 real(rp),
intent(in) :: ud(ka)
143 real(rp),
intent(in) :: md(ka)
144 real(rp),
intent(in) :: ld(ka)
145 real(rp),
intent(in) :: iv(ka)
147 real(rp),
intent(out) :: ov(ka)
! Scratch storage: c_ta / d_ta are either macro views onto work(:,1:2) or
! local automatic arrays. NOTE(review): the preprocessor conditions selecting
! between the two forms are on lines not visible here.
150 real(rp),
intent(out) :: work(ks:ke,2)
151 #define c_ta(k) work(k,1)
152 #define d_ta(k) work(k,2)
154 real(rp) :: c_ta(ks:ke)
155 real(rp) :: d_ta(ks:ke)
! Forward elimination, first row: normalise by the diagonal.
163 c_ta(ks) = ud(ks) / md(ks)
164 d_ta(ks) = iv(ks) / md(ks)
! Forward elimination, following rows (enclosing loop bounds not visible):
! rdenom is the reciprocal of the pivot after removing the sub-diagonal term.
166 rdenom = 1.0_rp / ( md(k) - ld(k) * c_ta(k-1) )
167 c_ta(k) = ud(k) * rdenom
168 d_ta(k) = ( iv(k) - ld(k) * d_ta(k-1) ) * rdenom
! Last row is solved directly; it has no ud term.
! NOTE(review): this reads index ke-1, so ks < ke appears to be required.
172 ov(ke) = ( iv(ke) - ld(ke) * d_ta(ke-1) ) / ( md(ke) - ld(ke) * c_ta(ke-1) )
! Back substitution using the already-known ov(k+1) (loop bounds not visible).
174 ov(k) = d_ta(k) - c_ta(k) * ov(k+1)
! 1-D tridiagonal solve for a single column using a stride-based reduction
! (routine header not visible here; the "_cr" suffix presumably stands for
! cyclic reduction -- confirm against the header).
! Arguments as declared below: same roles as the other 1-D solvers
! (ud/md/ld coefficients, iv right-hand side, ov solution), plus a
! caller-supplied work array of four columns indexed 1..ke-ks+1.
193 integer,
intent(in) :: ka, ks, ke
195 real(rp),
intent(in) :: ud(ka)
196 real(rp),
intent(in) :: md(ka)
197 real(rp),
intent(in) :: ld(ka)
198 real(rp),
intent(in) :: iv(ka)
200 real(rp),
intent(out) :: ov(ka)
! Work storage: a1/b1/c1/x1 are either macro views onto work(:,1:4) or local
! automatic arrays. NOTE(review): the selecting preprocessor conditions are
! on lines not visible here.
203 real(rp),
intent(out) :: work(ke-ks+1,4)
204 #define a1_cr(k) work(k,1)
205 #define b1_cr(k) work(k,2)
206 #define c1_cr(k) work(k,3)
207 #define x1_cr(k) work(k,4)
209 real(rp) :: a1_cr(ke-ks+1)
210 real(rp) :: b1_cr(ke-ks+1)
211 real(rp) :: c1_cr(ke-ks+1)
212 real(rp) :: x1_cr(ke-ks+1)
216 integer :: lmax, kmax
217 integer :: k, k1, k2, l
! Number of reduction levels: floor(log2(kmax)) - 1, evaluated in floating point.
! NOTE(review): the assignment of kmax is not visible here; also confirm that
! rounding in log() cannot change the floor when kmax is an exact power of two.
220 lmax = floor( log(real(kmax,rp)) / log(2.0_rp) ) - 1
! Reduction step: visit every (2*st)-th row and eliminate its coupling to the
! rows k1 and k2 (their assignment, and the update of st, are not visible).
231 do k = st*2, kmax, st*2
234 f1 = a1_cr(k) / b1_cr(k1)
! Row k2 beyond kmax: special-cased (branch body not visible here).
235 if ( k2 > kmax )
then
239 f2 = c1_cr(k) / b1_cr(k2)
! Fold rows k1 and k2 into row k; order matters because b1/x1 use the old
! a1/c1 of the neighbouring rows, not of row k itself.
241 a1_cr(k) = - a1_cr(k1) * f1
242 c1_cr(k) = - c1_cr(k2) * f2
243 b1_cr(k) = b1_cr(k) - c1_cr(k1) * f1 - a1_cr(k2) * f2
244 x1_cr(k) = x1_cr(k) - x1_cr(k1) * f1 - x1_cr(k2) * f2
! Two rows left at this stride (st and 2*st): closed-form solution of the
! remaining 2x2 system, second unknown first, then the first one.
248 if ( kmax / st == 2 )
then
249 ov(ks+st*2-1) = ( a1_cr(st*2) * x1_cr(st) - b1_cr(st) * x1_cr(st*2) ) &
250 / ( a1_cr(st*2) * c1_cr(st) - b1_cr(st) * b1_cr(st*2) )
251 ov(ks+st-1) = ( x1_cr(st) - c1_cr(st) * ov(ks+st*2-1) ) / b1_cr(st)
! Three rows left (k1, k, k2; their assignment is not visible): eliminate the
! middle row k from k1 and k2, solve the resulting 2x2 system, then recover
! the unknown of row k.
252 else if ( kmax / st == 3 )
then
256 f2 = c1_cr(k1) / b1_cr(k)
257 c1_cr(k1) = - c1_cr(k) * f2
258 b1_cr(k1) = b1_cr(k1) - a1_cr(k) * f2
259 x1_cr(k1) = x1_cr(k1) - x1_cr(k) * f2
261 f1 = a1_cr(k2) / b1_cr(k)
262 a1_cr(k2) = - a1_cr(k) * f1
263 b1_cr(k2) = b1_cr(k2) - c1_cr(k) * f1
264 x1_cr(k2) = x1_cr(k2) - x1_cr(k) * f1
266 ov(ks+k2-1) = ( a1_cr(k2) * x1_cr(k1) - b1_cr(k1) * x1_cr(k2) ) &
267 / ( a1_cr(k2) * c1_cr(k1) - b1_cr(k1) * b1_cr(k2) )
268 ov(ks+k1-1) = ( x1_cr(k1) - c1_cr(k1) * ov(ks+k2-1) ) / b1_cr(k1)
269 ov(ks+k-1) = ( x1_cr(k) - a1_cr(k) * ov(ks+k1-1) - c1_cr(k) * ov(ks+k2-1) ) / b1_cr(k)
! Back substitution: rows at odd multiples of st are computed from already
! known neighbours at distance st. Three forms follow: upper neighbour only,
! both neighbours, lower neighbour only (the first condition is not visible).
276 do k = st, kmax, st*2
278 ov(ks+k-1) = ( x1_cr(k) - c1_cr(k) * ov(ks+k+st-1) ) / b1_cr(k)
279 elseif ( k+st <= kmax )
then
280 ov(ks+k-1) = ( x1_cr(k) - a1_cr(k) * ov(ks+k-st-1) - c1_cr(k) * ov(ks+k+st-1) ) / b1_cr(k)
282 ov(ks+k-1) = ( x1_cr(k) - a1_cr(k) * ov(ks+k-st-1) ) / b1_cr(k)
! 1-D tridiagonal solve for a single column that reduces every row at each
! level, alternating between two work slots (routine header not visible here;
! the "_pcr" suffix presumably stands for parallel cyclic reduction -- confirm).
! Arguments as declared below: ud/md/ld coefficients, iv right-hand side,
! ov solution, and a caller-supplied work array with two slots per quantity.
303 integer,
intent(in) :: ka, ks, ke
305 real(rp),
intent(in) :: ud(ka)
306 real(rp),
intent(in) :: md(ka)
307 real(rp),
intent(in) :: ld(ka)
308 real(rp),
intent(in) :: iv(ka)
310 real(rp),
intent(out) :: ov(ka)
! Work storage: a1/b1/c1/x1 are either macro views onto work(:,:,1:4) or local
! automatic arrays; the second index (1 or 2) selects the slot.
! NOTE(review): the selecting preprocessor conditions are not visible here.
313 real(rp),
intent(out) :: work(ke-ks+1,2,4)
314 #define a1_pcr(k,n) work(k,n,1)
315 #define b1_pcr(k,n) work(k,n,2)
316 #define c1_pcr(k,n) work(k,n,3)
317 #define x1_pcr(k,n) work(k,n,4)
319 real(rp) :: a1_pcr(ke-ks+1,2)
320 real(rp) :: b1_pcr(ke-ks+1,2)
321 real(rp) :: c1_pcr(ke-ks+1,2)
322 real(rp) :: x1_pcr(ke-ks+1,2)
326 integer :: lmax, kmax
327 integer :: iw1, iw2, iws
328 integer :: k, k1, k2, l
! Number of reduction levels: ceiling(log2(kmax)), evaluated in floating point
! (the assignment of kmax is not visible here).
331 lmax = ceiling( log(real(kmax,rp)) / log(2.0_rp) )
! Load the system for rows ks:ke into slot 1.
333 a1_pcr(:,1) = ld(ks:ke)
334 b1_pcr(:,1) = md(ks:ke)
335 c1_pcr(:,1) = ud(ks:ke)
336 x1_pcr(:,1) = iv(ks:ke)
! One reduction step for row k: read slot iw1, write slot iw2 (the loops, the
! definitions of k1/k2 and the slot swap are on lines not visible here).
351 f1 = a1_pcr(k,iw1) / b1_pcr(k1,iw1)
! Row k2 beyond kmax: special-cased (branch body not visible here).
353 if ( k2 > kmax )
then
357 f2 = c1_pcr(k,iw1) / b1_pcr(k2,iw1)
359 a1_pcr(k,iw2) = - a1_pcr(k1,iw1) * f1
360 c1_pcr(k,iw2) = - c1_pcr(k2,iw1) * f2
361 b1_pcr(k,iw2) = b1_pcr(k,iw1) - c1_pcr(k1,iw1) * f1 - a1_pcr(k2,iw1) * f2
362 x1_pcr(k,iw2) = x1_pcr(k,iw1) - x1_pcr(k1,iw1) * f1 - x1_pcr(k2,iw1) * f2
! Final step: each unknown is the right-hand side divided by the diagonal of
! the slot last written.
372 ov(ks+k-1) = x1_pcr(k,iw1) / b1_pcr(k,iw1)
! Solve one tridiagonal system per column i of 2-D arrays dimensioned (ka,ia).
! Three implementations are visible, selected by preprocessor branches: a
! cuSPARSE strided-batch path, an OpenACC path that calls a per-column solver,
! and a blocked CPU path processing lsize columns at a time.
! Arguments (declared below):
!   ka, ks, ke : vertical extent and first/last solved index
!   ia, is, ie : column extent and first/last solved column
!   ud, md, ld : coefficient arrays; md is the diagonal
!   iv         : right-hand side
!   ov         : solution
379 subroutine matrix_solver_tridiagonal_2d( &
388 integer,
intent(in) :: ka, ks, ke
389 integer,
intent(in) :: ia, is, ie
391 real(rp),
intent(in) :: ud(ka,ia)
392 real(rp),
intent(in) :: md(ka,ia)
393 real(rp),
intent(in) :: ld(ka,ia)
394 real(rp),
intent(in) :: iv(ka,ia)
396 real(rp),
intent(out) :: ov(ka,ia)
! OpenACC build: scratch for the per-column solver.
400 #elif defined(_OPENACC)
401 real(rp) :: work(ks:ke,4)
! Blocked CPU build: per-lane scratch, lane index first.
403 real(rp) :: c(lsize,ks:ke)
404 real(rp) :: d(lsize,ks:ke)
405 real(rp) :: w(lsize,ks:ke)
409 integer :: k, i, ii, l
! cuSPARSE path: query the scratch size needed by the strided-batch solver.
! Both the S and the D variant appear; presumably one is selected by a
! precision macro on lines not visible here.
417 status = cusparsesgtsv2stridedbatch_buffersizeext( &
419 status = cusparsedgtsv2stridedbatch_buffersizeext( &
423 ld(ks,is), md(ks,is), ud(ks,is), &
! NOTE(review): the message names the D variant regardless of which call ran.
428 if ( status /= cusparse_status_success )
then
429 log_error(
"MATRIX_SOLVER_tridiagonal_2D",*)
"cusparseDgtsv2StridedBatch_bufferSizeExt failed: ", status
! Grow the module-level scratch buffer when the requested size exceeds bufsize.
! NOTE(review): bsize/rp is an integer division; presumably bsize is in bytes
! and rp equals the byte size of the real kind -- confirm that truncation
! cannot under-allocate.
432 if ( bsize > bufsize )
then
433 if (
allocated(pbuffer) )
deallocate( pbuffer )
434 allocate( pbuffer(bsize/rp) )
! Copy the right-hand side into ov before the solver call; presumably the
! solver overwrites it in place (the call arguments are not visible here).
439 ov(:,is:ie) = iv(:,is:ie)
444 status = cusparsesgtsv2stridedbatch( &
446 status = cusparsedgtsv2stridedbatch( &
450 ld(ks,is), md(ks,is), ud(ks,is), &
455 if ( status /= cusparse_status_success )
then
456 log_error(
"MATRIX_SOLVER_tridiagonal_2D",*)
"cusparseDgtsv2StridedBatch failed: ", status
! OpenACC path: one call per column i (the callee name is not visible here).
460 #elif defined(_OPENACC)
467 ud(:,i), md(:,i), ld(:,i), &
! Blocked CPU path: walk the columns in chunks of lsize; l is the lane within
! a chunk and i the corresponding column (their relation is not visible here).
475 do ii = is, ie, lsize
! Forward elimination, first row.
481 c(l,ks) = ud(ks,i) / md(ks,i)
482 d(l,ks) = iv(ks,i) / md(ks,i)
! Forward elimination, following rows.
489 rdenom = 1.0_rp / ( md(k,i) - ld(k,i) * c(l,k-1) )
490 c(l,k) = ud(k,i) * rdenom
491 d(l,k) = ( iv(k,i) - ld(k,i) * d(l,k-1) ) * rdenom
! Last row solved directly into the lane-major buffer w.
500 w(l,ke) = ( iv(ke,i) - ld(ke,i) * d(l,ke-1) ) / ( md(ke,i) - ld(ke,i) * c(l,ke-1) )
! Back substitution in w; the copy from w into ov is on lines not visible here.
507 w(l,k) = d(l,k) - c(l,k) * w(l,k+1)
528 end subroutine matrix_solver_tridiagonal_2d
! Tridiagonal solve for a block of LSIZE columns stored as (KA,LSIZE)
! (routine header not visible here). The sweep is carried out in lane-major
! scratch arrays c, d and work, all dimensioned (LSIZE,KS:KE).
538 integer,
intent(in) :: KA, KS, KE
540 real(RP),
intent(in) :: ud(KA,LSIZE)
541 real(RP),
intent(in) :: md(KA,LSIZE)
542 real(RP),
intent(in) :: ld(KA,LSIZE)
543 real(RP),
intent(in) :: iv(KA,LSIZE)
545 real(RP),
intent(out) :: ov(KA,LSIZE)
547 real(RP) :: c(LSIZE,KS:KE)
548 real(RP) :: d(LSIZE,KS:KE)
549 real(RP) :: work(LSIZE,KS:KE)
! Forward elimination, first row of lane l.
557 c(l,ks) = ud(ks,l) / md(ks,l)
558 d(l,ks) = iv(ks,l) / md(ks,l)
! Forward elimination, following rows (loop bounds not visible).
563 rdenom = 1.0_rp / ( md(k,l) - ld(k,l) * c(l,k-1) )
564 c(l,k) = ud(k,l) * rdenom
565 d(l,k) = ( iv(k,l) - ld(k,l) * d(l,k-1) ) * rdenom
! Last row solved directly, then back substitution; the result is held in
! work(l,k). The copy into ov is on lines not visible here.
571 work(l,ke) = ( iv(ke,l) - ld(ke,l) * d(l,ke-1) ) / ( md(ke,l) - ld(ke,l) * c(l,ke-1) )
576 work(l,k) = d(l,k) - c(l,k) * work(l,k+1)
! Tridiagonal solve for a block of LSIZE systems whose arrays are stored with
! the lane index first, i.e. (LSIZE,KA). Because input and output already use
! that layout, the solution is written straight into ov.
! Arguments (declared below): KA/KS/KE extents, ud/md/ld coefficients
! (md is the diagonal), iv right-hand side, ov solution.
590 subroutine matrix_solver_tridiagonal_2d_trans( &
597 integer,
intent(in) :: KA, KS, KE
599 real(RP),
intent(in) :: ud(LSIZE,KA)
600 real(RP),
intent(in) :: md(LSIZE,KA)
601 real(RP),
intent(in) :: ld(LSIZE,KA)
602 real(RP),
intent(in) :: iv(LSIZE,KA)
604 real(RP),
intent(out) :: ov(LSIZE,KA)
! Per-lane scratch for the forward sweep.
606 real(RP) :: c(LSIZE,KS:KE)
607 real(RP) :: d(LSIZE,KS:KE)
! Forward elimination, first row of lane l.
615 c(l,ks) = ud(l,ks) / md(l,ks)
616 d(l,ks) = iv(l,ks) / md(l,ks)
! Forward elimination, following rows (loop bounds not visible).
621 rdenom = 1.0_rp / ( md(l,k) - ld(l,k) * c(l,k-1) )
622 c(l,k) = ud(l,k) * rdenom
623 d(l,k) = ( iv(l,k) - ld(l,k) * d(l,k-1) ) * rdenom
! Last row solved directly.
629 ov(l,ke) = ( iv(l,ke) - ld(l,ke) * d(l,ke-1) ) / ( md(l,ke) - ld(l,ke) * c(l,ke-1) )
! Back substitution using the already-known ov(l,k+1).
634 ov(l,k) = d(l,k) - c(l,k) * ov(l,k+1)
639 end subroutine matrix_solver_tridiagonal_2d_trans
! Solve one tridiagonal system per (i,j) column of 3-D arrays dimensioned
! (KA,IA,JA). An optional logical mask(IA,JA) restricts which columns receive
! a result. Three implementations are visible, selected by preprocessor
! branches: a cuSPARSE strided-batch path, an OpenACC per-column path, and a
! CPU path that gathers columns into lane-major blocks of LSIZE.
! Arguments (declared below):
!   KA,KS,KE / IA,IS,IE / JA,JS,JE : extents and solved ranges per dimension
!   ud, md, ld : coefficient arrays; md is the diagonal
!   iv         : right-hand side
!   ov         : solution
!   mask       : optional; where present, only columns with mask(i,j) true
!                are written
642 subroutine matrix_solver_tridiagonal_3d( &
656 integer,
intent(in) :: KA, KS, KE
657 integer,
intent(in) :: IA, IS, IE
658 integer,
intent(in) :: JA, JS, JE
659 real(RP),
intent(in) :: ud(KA,IA,JA)
660 real(RP),
intent(in) :: md(KA,IA,JA)
661 real(RP),
intent(in) :: ld(KA,IA,JA)
662 real(RP),
intent(in) :: iv(KA,IA,JA)
663 real(RP),
intent(out),
target :: ov(KA,IA,JA)
665 logical,
intent(in),
optional :: mask(IA,JA)
! cuSPARSE build: ovl is a pointer to the array being solved into and buf is a
! local full-size array (how ovl is associated is on lines not visible here).
668 real(RP),
pointer :: ovl(:,:,:)
669 real(RP),
target :: buf(KA,IA,JA)
! OpenACC build: scratch for the per-column solver.
671 #elif defined(_OPENACC)
672 real(RP) :: work(KS:KE,4)
! CPU build: lane-major gather buffers and the list of gathered i indices.
674 real(RP) :: udl(LSIZE,KA)
675 real(RP) :: mdl(LSIZE,KA)
676 real(RP) :: ldl(LSIZE,KA)
677 real(RP) :: ivl(LSIZE,KA)
678 real(RP) :: ovl(LSIZE,KA)
679 integer :: idx(LSIZE)
683 integer :: i, j, k, l
! cuSPARSE path: query the scratch size needed by the strided-batch solver.
! Both the S and the D variant appear; presumably one is selected by a
! precision macro on lines not visible here.
693 status = cusparsesgtsv2stridedbatch_buffersizeext( &
695 status = cusparsedgtsv2stridedbatch_buffersizeext( &
699 ld(ks,1,1), md(ks,1,1), ud(ks,1,1), &
! NOTE(review): the message names the D variant regardless of which call ran.
704 if ( status /= cusparse_status_success )
then
705 log_error(
"MATRIX_SOLVER_tridiagonal_3D",*)
"cusparseDgtsv2StridedBatch_bufferSizeExt failed: ", status
! Grow the module-level scratch buffer when the requested size exceeds bufsize.
! NOTE(review): the deallocate guard here is bufsize > 0, whereas the 2D
! routine uses allocated(pbuffer); confirm the two stay equivalent. Also
! bsize/rp is an integer division -- confirm it cannot under-allocate.
708 if ( bsize > bufsize )
then
709 if ( bufsize > 0 )
deallocate( pbuffer )
710 allocate( pbuffer(bsize/rp) )
! Choice depending on mask (branch bodies not visible here).
714 if (
present(mask) )
then
! Copy the right-hand side into the array being solved into; presumably the
! solver overwrites it in place (the call arguments are not visible here).
721 ovl(:,:,:) = iv(:,:,:)
725 status = cusparsesgtsv2stridedbatch( &
727 status = cusparsedgtsv2stridedbatch( &
731 ld(ks,1,1), md(ks,1,1), ud(ks,1,1), &
736 if ( status /= cusparse_status_success )
then
737 log_error(
"MATRIX_SOLVER_tridiagonal_3D",*)
"cusparseDgtsv2StridedBatch failed: ", status
! With a mask, copy results into ov only for columns where mask(i,j) is true.
741 if (
present(mask) )
then
747 if ( mask(i,j) )
then
748 ov(k,i,j) = ovl(k,i,j)
! OpenACC path: per-column calls, masked and unmasked variants (the callee
! name is not visible here).
757 #elif defined(_OPENACC)
759 if (
present(mask) )
then
765 if ( mask(i,j) )
then
768 ud(:,i,j), md(:,i,j), ld(:,i,j), &
783 ud(:,i,j), md(:,i,j), ld(:,i,j), &
! CPU path with a mask: collect the i indices of selected columns in idx and,
! once len reaches lsize, gather them into the lane-major buffers, solve the
! block with the transposed solver and scatter the result back to ov.
793 if (
present(mask) )
then
799 if ( mask(i,j) )
then
802 if ( len == lsize )
then
806 udl(l,k) = ud(k,idx(l),j)
807 mdl(l,k) = md(k,idx(l),j)
808 ldl(l,k) = ld(k,idx(l),j)
809 ivl(l,k) = iv(k,idx(l),j)
812 call matrix_solver_tridiagonal_2d_trans( ka, ks, ke, &
813 udl(:,:), mdl(:,:), ldl(:,:), &
819 ov(k,idx(l),j) = ovl(l,k)
! Second gather/solve/scatter sequence, same pattern as above; presumably it
! handles the columns left over when fewer than lsize were collected (the
! enclosing condition is not visible here).
830 udl(l,k) = ud(k,idx(l),j)
831 mdl(l,k) = md(k,idx(l),j)
832 ldl(l,k) = ld(k,idx(l),j)
833 ivl(l,k) = iv(k,idx(l),j)
! Debug-only code (body not visible here).
835 #if defined DEBUG || defined QUICKDEBUG
844 call matrix_solver_tridiagonal_2d_trans( ka, ks, ke, &
845 udl(:,:), mdl(:,:), ldl(:,:), &
851 ov(k,idx(l),j) = ovl(l,k)
! Fallback visible here: hand each j-slice to the 2-D solver (presumably the
! no-mask case -- the enclosing condition is not visible).
861 call matrix_solver_tridiagonal_2d( ka, ks, ke, ia, is, ie, &
862 ud(:,:,j), md(:,:,j), ld(:,:,j), &
873 end subroutine matrix_solver_tridiagonal_3d
! Eigenvalue decomposition of an n-by-n matrix via LAPACK (routine header not
! visible here; the name is taken from the log tags below).
! Arguments (declared below):
!   n       : matrix order
!   a       : input matrix; only its symmetric part 0.5*(a + a^T) is used
!   eival   : eigenvalues (output)
!   eivec   : eigenvectors (output)
!   simdlen : optional block length for the symmetrisation loops
883 integer,
intent(in) :: n
884 real(rp),
intent(in) :: a (n,n)
885 real(rp),
intent(out) :: eival(n)
886 real(rp),
intent(out) :: eivec(n,n)
888 integer,
intent(in),
optional :: simdlen
! Shift applied to the eigenvalues further below.
890 real(rp) :: eival_inc
! Work arrays: b holds the symmetrised matrix passed to LAPACK, w receives
! the eigenvalues, work/iwork are LAPACK workspaces.
892 real(rp),
allocatable :: b (:,:)
893 real(rp),
allocatable :: w (:)
894 real(rp),
allocatable :: work (:)
895 integer,
allocatable :: iwork(:)
903 integer :: iblk, jblk
904 integer :: imax, jmax
905 integer :: ivec, jvec
907 integer :: i, j, ierr
! Choose the block length from simdlen when supplied (the default used
! otherwise is on lines not visible here).
911 if(
present(simdlen) )
then
! Real workspace size passed to SYEVD as lwork.
919 lwork = 2*n*n + 6*n + 1
924 allocate( work(lwork) )
925 allocate( iwork(liwork) )
! Build b = 0.5*(a + a^T) in simdlen_-sized tiles; imax/jmax clip the last
! tile to the matrix edge (the inner loops are not visible here).
927 do jblk = 1, n, simdlen_
928 jmax = min( n-jblk+1, simdlen_ )
929 do iblk = 1, n, simdlen_
930 imax = min( n-iblk+1, simdlen_ )
937 b(i,j) = 0.5_rp * ( a(i,j) + a(j,i) )
! First attempt: SYEVD with eigenvectors ("V") using the lower triangle ("L").
! Both precisions appear; presumably one is selected by a macro not visible here.
945 call ssyevd(
"V",
"L",n,b,lda,w,work,lwork,iwork,liwork,ierr)
947 call dsyevd(
"V",
"L",n,b,lda,w,work,lwork,iwork,liwork,ierr)
! Diagnostics followed by a fallback to SYEV; presumably entered when ierr is
! non-zero (the condition is not visible here). Dumps the input matrix, the
! returned eigenvalues and the returned eigenvectors.
951 log_info(
'MATRIX_SOLVER_eigenvalue_decomposition',*)
'LAPACK/SYEVD error code is ', ierr
952 log_info(
'MATRIX_SOLVER_eigenvalue_decomposition',*)
'input a'
954 log_info(
'MATRIX_SOLVER_eigenvalue_decomposition',*) j ,a(:,j)
956 log_info(
'MATRIX_SOLVER_eigenvalue_decomposition',*)
'output eival'
957 log_info(
'MATRIX_SOLVER_eigenvalue_decomposition',*) w(:)
958 log_info(
'MATRIX_SOLVER_eigenvalue_decomposition',*)
'output eivec'
960 log_info(
'MATRIX_SOLVER_eigenvalue_decomposition',*) j, b(:,j)
962 log_info(
'MATRIX_SOLVER_eigenvalue_decomposition',*)
'Try to use LAPACK/SYEV ...'
! Rebuild the symmetrised matrix before retrying (b was passed to SYEVD above
! and is presumably overwritten by it).
965 do jblk = 1, n, simdlen_
966 jmax = min( n-jblk+1, simdlen_ )
967 do iblk = 1, n, simdlen_
968 imax = min( n-iblk+1, simdlen_ )
975 b(i,j) = 0.5_rp * ( a(i,j) + a(j,i) )
! Second attempt: SYEV with the same options.
983 call ssyev(
"V",
"L",n,b,lda,w,work,lwork,ierr)
985 call dsyev(
"V",
"L",n,b,lda,w,work,lwork,ierr)
! A failure of the fallback is reported through log_error.
988 if ( ierr /= 0 )
then
989 log_error(
'MATRIX_SOLVER_eigenvalue_decomposition',*)
'LAPACK/SYEV error code is ', ierr,
'! STOP.'
! Eigenvalue conditioning (assumes w(1) is the smallest and w(n) the largest
! eigenvalue -- confirm against the LAPACK ordering).
! Case 1: smallest eigenvalue negative -> add w(n)/(1e5 - 1) to the
! eigenvalues (the loop bounds over i are not visible here).
994 if( w(1) < 0.0_rp )
then
995 eival_inc = w(n) / ( 1.e+5_rp - 1.0_rp )
997 w(i) = w(i) + eival_inc
! Case 2: w(n)/w(1) exceeds 1e5 -> the shift is chosen so that
! (w(n)+eival_inc) / (w(1)+eival_inc) equals 1e5.
999 else if( w(n)/1.e+5_rp > w(1) )
then
1000 eival_inc = ( w(n) - w(1)*1.e+5_rp ) / ( 1.e+5_rp - 1.0_rp )
1002 w(i) = w(i) + eival_inc
! Effective-rank count: with a positive largest eigenvalue, every eigenvalue
! below |eival(n)|*sqrt(epsilon) decrements nrank_eff.
1018 if( eival(n) > 0.0_rp )
then
1020 if( eival(i) < abs(eival(n))*sqrt(epsilon(eival)) )
then
1021 nrank_eff = nrank_eff - 1
! Otherwise no usable eigenvalue exists.
1028 log_error(
'MATRIX_SOLVER_eigenvalue_decomposition',*)
'All eigenvalues are below 0! STOP.'
! Reported when the routine is built without the required support (the
! enclosing preprocessor condition is not visible here).
1037 log_error(
'MATRIX_SOLVER_eigenvalue_decomposition',*)
'Binary not compiled for DA! STOP.'