!**********************************************************************
! MBAR   Copyright (C) 2013  Simone Marsili and Piero Procacci
! Multistate Bennett acceptance ratio (MBAR) algorithm, initialized with
! BAR estimates for contiguous replicas.
!**********************************************************************
!    This program is free software: you can redistribute it and/or modify
!    it under the terms of the GNU General Public License as published by
!    the Free Software Foundation, either version 3 of the License, or
!    (at your option) any later version.

!    This program is distributed in the hope that it will be useful,
!    but WITHOUT ANY WARRANTY; without even the implied warranty of
!    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!    GNU General Public License for more details.

!    This program comes with ABSOLUTELY NO WARRANTY.
!    This is free software, and you are welcome to redistribute it
!    under certain conditions.
!    For details on terms and conditions see <http://www.gnu.org/licenses/>.

program mbar
!----------------------------------------------------------------------
! Solve the MBAR self-consistent equations for the partition functions
! z(1:ns) of ns states, normalised so that z(1) = 1.
!
! Command line:  mbar PARAMETER_FILE DATA_FILE [guess [debug]]
!   PARAMETER_FILE  one line per state, np numbers per line (parms)
!   DATA_FILE       one line per configuration: the index (1..ns) of the
!                   state it was sampled in, followed by np numbers (pots)
!   guess           "F" replaces the BAR initial guess by a naive one
!   debug           "T" enables extra diagnostic output
!
! The reduced energy of configuration c in state s is
! sum_k pots(c,k)*parms(s,k).  Progress and diagnostics are written to
! unit 0; after convergence the weight of every configuration in state 1
! is written to standard output.
!----------------------------------------------------------------------

  use precision
  use strings

!**********************************************************************

  implicit none
  integer(ki) :: NPA
  PARAMETER(NPA=1000000)   ! capacity (samples per state) of the COMMON buffers
  integer(ki) :: i,j,c,s,iter,jj,istate,NF,NR
  integer(ki) :: err
  character(len=1280) :: line
  integer(ki) :: nargs
  character(len=128),allocatable :: args(:)
  integer(ki),parameter :: upar=11
  integer(ki),parameter :: udata=12
  character(len=128) :: fpar
  character(len=128) :: fdata
  real(kr) :: x,x1,x2,tol
  real(kr) :: dfbar        ! BAR free-energy difference of two contiguous states
  integer(ki) :: np ! Number of Parameters per state
  integer(ki) :: ns ! Number of States
  integer(ki) :: nc ! Number of Configurations
  real(kr),allocatable :: parms(:,:),dbeta(:,:)
  real(kr),allocatable :: pots(:,:)
  integer(ki),allocatable :: state(:)
  integer(ki),allocatable :: ncs(:)
  real(kr),allocatable :: z(:),zr(:),z0(:)
  real(kr),allocatable :: q(:,:)
  real(kr),allocatable :: wrk(:,:,:)
  real(kr),allocatable :: d(:), dum(:)
  real(kr),parameter :: eps=0.001_kr
  real(kr) :: diff
  real(kr) :: w
  logical :: conv = .false.
  logical :: usebar=.true.,debug=.false.
  real(kr) :: minw=10000000.0_kr   ! smallest weight seen during the iterations
  REAL(kr) WRK_F_Ab(NPA), WRK_R_Ba(NPA)

  ! work buffers shared with FUNC (the BAR residual) through COMMON
  COMMON /COMM1/ WRK_F_Ab,WRK_R_Ba,NF,NR

  ! print command line on standard error
  call get_command(line)
  write(0,*) trim(line)

  !-------------- read arguments
  call readargs(args,nargs,err)
  select case(nargs)
  case (0:1,5: )
     write(*,*) "Error reading arguments"
     write(*,*) "syntax:  mbar PARAMETER_FILE DATA_FILE [guess [debug]] "
     stop
  case (3)
     if(args(3)=="F") usebar=.false.
  case (4)
     if(args(3)=="F") usebar=.false.
     if(args(4)=="T") debug=.true.
  end select
  fpar = args(1)
  fdata = args(2)

  write(0,*) "parameter file: ", trim(fpar)
  write(0,*) "data file: ", trim(fdata)
  write(0,*) "initial guess BAR: ", usebar
  write(0,*) "Debug printouts: ", debug

  !------------- open units
  ! Each failure reports and stops on its own (a missing parameter file
  ! must not also produce the data-file message).
  open(unit=upar,file=trim(fpar),status='OLD',iostat=err)
  if(err /= 0) then
     write(*,*) "Parameter file '", trim(fpar),"' not found"
     stop
  end if
  open(unit=udata,file=trim(fdata),status='OLD',iostat=err)
  if(err /= 0) then
     write(*,*) "Data File '", trim(fdata),"' not found"
     stop
  end if

  !------------- get np,ns,nc
  ns = 0
  np = 0
  do
     i = np
     call nargsline(upar,line,' ',np,err)
     ! test for end of file first, so whatever nargsline leaves in np at
     ! EOF cannot trigger a spurious mismatch
     if(err < 0) exit
     if(i /= 0 .and. np /= i) stop "check the parameters file"
     ns = ns + 1
  end do
  rewind(upar)
  if(ns == 0) stop "empty parameters file"
  nc = 0
  do
     read(udata,*,iostat=err) x
     if(err /= 0) then
        if(err < 0) then
           exit
        else
           stop "error in data file"
        end if
     end if
     nc = nc + 1
  end do
  rewind(udata)
  if(nc/ns > NPA) then
     stop "Too many configurations per replica: increase NPA and recompile."
  end if

  !------------- print headers
  write(0,*) "number of states: ", ns
  write(0,*) "number of parameters per state: ", np
  write(0,*) "number of config. : ", nc

  !------------- allocate memory
  allocate(parms(ns,np),stat=err)
  if(err /= 0) stop "error allocating parms"
  allocate(dbeta(ns,np),stat=err)
  if(err /= 0) stop "error allocating dbeta"
  allocate(state(nc),pots(nc,np),stat=err)
  if(err /= 0) stop "error allocating state,pots"
  allocate(ncs(ns),wrk(ns,nc,np),stat=err)
  if(err /= 0) stop "error allocating ncs,wrk"
  allocate(z(ns),zr(ns),z0(ns),stat=err)
  if(err /= 0) stop "error allocating z,zr,z0"
  allocate(q(nc,ns),d(nc),dum(np),stat=err)
  if(err /= 0) stop "error allocating q,d,dum"

  !------------- read
  ! dbeta(i,:) = parms(i+1,:) - parms(i,:), for i = 1..ns-1
  do i = 1,ns
     read(upar,*,iostat=err) parms(i,:)
     if(err /= 0) stop "error in parameters file"
     if(i > 1) then
        dbeta(i-1,:) = parms(i,:) - parms(i-1,:)
        if(debug) write(0,'("DBETA",i3,3F15.3)') i-1,dbeta(i-1,:)
     end if
  end do
  do i = 1,nc
     read(udata,*,iostat=err) state(i),pots(i,:)
     if(err /= 0) stop "error in data file"
     if(state(i) < 1 .or. state(i) > ns) stop "wrong state number"
  end do

  ! count the samples of each state and group their pots by state:
  ! wrk(s,k,:) is the k-th configuration sampled in state s
  ncs = 0
  do i = 1,nc
     istate = state(i)
     ncs(istate) = ncs(istate) + 1
     jj = ncs(istate)
     dum(:) = pots(i,:)
     wrk(istate,jj,:) = dum(:)
  end do

  ! reduced energies: POTS nc x np, parms ns x np, q nc x ns
  q = matmul(pots,transpose(parms))

  ! Shift every row so that its smallest reduced energy is zero.  The MBAR
  ! weights q(c,s)/d(c) are invariant under a per-configuration shift, and
  ! with this choice exp(-q) lies in (0,1], so it cannot overflow.
  do i = 1,nc
     q(i,:) = q(i,:) - minval(q(i,:))
  end do

  ! take exp once for all
  q = exp(-q)

  ! compute initial guess
  z = 0.0_kr
  X1 = -360.0_kr
  X2 = 360.0_kr
  TOL = 1.0e-9_kr
  if(usebar) then
     ! BAR guess of the Z_k+1/Z_k ratios between contiguous replicas.
     ! ZBRENT yields Delta F = F_{k+1}-F_{k} = ln Z_k/Z_{k+1};
     ! FUNC is the BAR residual.
     if(maxval(ncs) > NPA) then
        stop "Too many configurations per replica: increase NPA and recompile."
     end if
     ZR(1) = 1.0_kr
     z(1) = 1.0_kr
     do i = 1,ns-1
        NF = NCS(i)
        NR = NCS(i+1)
        ! forward work: the NF samples of state i
        do j = 1,NF
           WRK_F_Ab(j) = sum(wrk(i,j,:)*dbeta(i,:))
           ! print energies in state files if debug.
           if(debug) write(100+i,'(2i8,3ES15.5)') i,j,wrk(i,j,:)
        end do
        ! reverse work: the NR samples of state i+1 (NR may differ from NF)
        do j = 1,NR
           WRK_R_Ba(j) = -sum(wrk(i+1,j,:)*dbeta(i,:))
        end do
        dfbar = ZBRENT(FUNC,X1,X2,TOL)
        ZR(i+1) = exp(-dfbar)
        Z(i+1) = ZR(i+1)*Z(i)
        write(0,*) Z(i+1),i,dfbar
     end do
  else
     ! naive first estimate: per-state mean of the summed pots
     do c = 1,nc
        z(state(c)) = z(state(c)) + sum(pots(c,:))
     end do
     z = z / ncs
     z = z/z(1)
  end if

  ! start mbar iterations
  iter = 0
  inner_loop: do
     iter = iter + 1
     do j = 1,ns
        if(abs(z(j)) < tiny(x)) then
           write(0,*) "zero value of z in state ",j
           stop
        end if
     end do

     z0 = z
     z = 0.0_kr
     c_loop: do c = 1,nc
        ! MBAR denominator of configuration c
        d(c) = sum(ncs*q(c,:)/z0)
        s_loop: do s = 1,ns
           ! weight of configuration c in state s
           w = q(c,s) / d(c)
           z(s) = z(s) + w
           if(w < minw) minw = w
           if(conv .and. s == 1) write(*,*) w
        end do s_loop
     end do c_loop
     z = z/z(1)
     if(iter > 1) then
        if(conv) exit inner_loop
        ! check convergence
        diff = maxval(abs(log(z)-log(z0)))
        write(0,'(a,i5,3g15.5)') 'diff', iter, diff
        ! converged: print the weights during the next sweep, then exit
        if(diff < eps) conv = .true.
        if(debug) then
           write(0,'(a,i5,a,g15.5)') 'Iteration = ', iter, ' f_tol = ', diff
           write(0,'("Rep",8x,"Z_k/Z_1",9x,"z_k/z_k+1")')
           do j = 2,ns
              write(0,fmt='(I5,2ES15.5)') j-1,Z(j),Z(j)/Z(j-1)
           end do
        end if
     end if
  end do inner_loop

  write(0,*) "goodbye!", z(ns), minw
contains 

!**********************************************************************
  
  subroutine readargs(args,nargs,err)
    ! Collect the command-line arguments into a freshly allocated array.
    !   args  : on return, args(i) holds argument i (truncated to len(args))
    !   nargs : number of command-line arguments
    !   err   : 0 on success; the allocate stat if the allocation failed
    !           (args is then not filled); otherwise the first nonzero
    !           status from get_command_argument (-1 = argument truncated)
    use precision
    implicit none
    character(len=*),allocatable,intent(out) :: args(:)
    integer(ki),intent(out) :: nargs
    integer(ki),intent(out) :: err
    ! local variables
    integer(ki) :: i
    integer :: status   ! default kind, as expected by get_command_argument

    err = 0

    ! get arguments number
    nargs = command_argument_count()

    allocate(args(nargs),stat=err)
    ! never write into args if the allocation failed
    if(err /= 0) return

    ! read arguments, remembering the first problem reported
    do i = 1,nargs
       call get_command_argument(i, args(i), status=status)
       if(status /= 0 .and. err == 0) err = status
    end do

  end subroutine readargs

FUNCTION FUNC (DEF)
! Bennett acceptance ratio (BAR) residual for a trial free-energy
! difference DEF between two states A and B, using the reduced work
! values shared through COMMON /COMM1/ (filled by the caller):
!   WRK_F_Ab(1:NF) : forward work of the NF samples of state A
!   WRK_R_Ba(1:NR) : reverse work of the NR samples of state B
! Returns
!   sum_F 1/(1+exp(M + W_F - DEF)) - sum_R 1/(1+exp(-M + W_R + DEF))
! with M = ln(NF/NR).  Its root in DEF is the BAR estimate; for NF == NR
! M = 0 and the equal-sample-size form is recovered.
  use precision

  IMPLICIT NONE

  INTEGER(ki) :: NPA
  PARAMETER(NPA=1000000)
  REAL(kr) :: FUNC, DEF
  REAL(kr) :: F1, F2, M
  REAL(kr) WRK_F_Ab(NPA),  WRK_R_Ba(NPA)
  Integer(ki) ::  N,NF,NR

  COMMON /COMM1/ WRK_F_Ab,WRK_R_Ba,NF,NR

  ! sample-size correction; left at zero if either side has no samples
  M = 0.0_kr
  IF (NF > 0 .AND. NR > 0) M = LOG(REAL(NF,kr)/REAL(NR,kr))

  F1 = 0.0_kr
  DO N = 1, NF
     F1 = F1 + 1.0_kr / (1.0_kr + EXP(M + WRK_F_Ab(N) - DEF))
  ENDDO

  F2 = 0.0_kr
  DO N = 1, NR
     F2 = F2 - 1.0_kr / (1.0_kr + EXP(-M + WRK_R_Ba(N) + DEF))
  ENDDO

  FUNC = F1 + F2
  RETURN
END FUNCTION FUNC

!**********************************************************************

!*******************************************************************
!
FUNCTION ZBRENT(func,x1,x2,tol)
!
! Brent's method root finder (Press et al., Numerical Recipes, ZBRENT).
! Given a function FUNC and two abscissae X1, X2 that bracket a root
! (FUNC(X1) and FUNC(X2) of opposite sign), return that root to an
! accuracy of about TOL, combining bisection with secant / inverse
! quadratic interpolation steps.
!
! If the root is not bracketed a diagnostic is printed and the iteration
! is still carried out, so the returned value need not be a root.  If
! ITMAX iterations pass without convergence a message is printed and the
! last iterate is returned.
!
!*******************************************************************
      use precision
      real(kr),intent(in)   :: tol,x1,x2
      real(kr) :: func !OLI: added to satisfy the implicit none compiler option.
      external          :: func
      real(kr)              :: zbrent

      ! internal variables...
      integer(ki)           :: iter
      ! Iteration cap as in Numerical Recipes.  Pure bisection of a bracket
      ! of width W down to TOL needs log2(W/TOL) steps (about 40 for
      ! W = 720, TOL = 1e-9), so a cap of 30 could stop short.
      integer(ki),parameter :: ITMAX=100
      real(kr)              :: a,b,c,d,e,fa,fb,fc,p,q,r,s,tol1,xm
      ! Machine precision of the working kind.  The former constant 3e-8 is
      ! a single-precision value; in double precision 2*EPS*|b| dominated
      ! tol1, so the requested TOL could never actually be reached.
      real(kr),parameter    :: EPS=epsilon(1.0_kr)

      ! calculations: evaluate the end points once
      a=x1
      b=x2
      fa=func(a)
      fb=func(b)
      if ((fa.gt.0..and.fb.gt.0.) .or. (fa.lt.0..and.fb.lt.0.)) then
        write(*,*)'             fa            fb              x1                x2'
        write(*,*)fa,fb,x1,x2
        write(*,*)"ZBRENT.F90: root must be bracketed"
      endif
      c=b
      fc=fb
      do iter=1,ITMAX
        ! rename a, b, c and adjust the step so that b and c bracket the root
        if ((fb.gt.0..and.fc.gt.0.) .or. (fb.lt.0..and.fc.lt.0.)) then
          c=a
          fc=fa
          d=b-a
          e=d
        endif
        ! keep b as the best estimate
        if (abs(fc).lt.abs(fb)) then
          a=b
          b=c
          c=a
          fa=fb
          fb=fc
          fc=fa
        endif
        ! convergence check
        tol1=2.*EPS*abs(b)+0.5*tol
        xm=.5*(c-b)
        if (abs(xm).le.tol1 .or. fb.eq.0.) then
          zbrent=b
          return
        endif
        if (abs(e).ge.tol1 .and. abs(fa).gt.abs(fb)) then
          ! attempt inverse quadratic interpolation (secant if a == c)
          s=fb/fa
          if (a.eq.c) then
            p=2.*xm*s
            q=1.-s
          else
            q=fa/fc
            r=fb/fc
            p=s*(2.*xm*q*(q-r)-(b-a)*(r-1.))
            q=(q-1.)*(r-1.)*(s-1.)
          endif
          if (p.gt.0.) q=-q
          p=abs(p)
          if (2.*p .lt. min(3.*xm*q-abs(tol1*q),abs(e*q))) then
            ! accept interpolation
            e=d
            d=p/q
          else
            ! interpolation failed: bisect
            d=xm
            e=d
          endif
        else
          ! bounds decreasing too slowly: bisect
          d=xm
          e=d
        endif
        a=b
        fa=fb
        if (abs(d) .gt. tol1) then
          b=b+d
        else
          b=b+sign(tol1,xm)
        endif
        fb=func(b)
      enddo
      write(*,*)"ZBRENT: exceeding maximum iterations"
      zbrent=b
END FUNCTION
  
end program mbar


