C wave = wavelet(s, freq, dt, z0, skip)
C s     is the input signal vector
C freq  is a vector of frequencies to analyse
C dt    is the time step (sampling interval) of s
C z0    is the number of wavelengths in each wavelet
C skip  is the evaluation stride: the transform is only calculated
C       for every skip'th sample of the time dimension of wave (the
C       samples in between repeat the last calculated value), which
C       is enough when the result is only used for plotting

      subroutine mexfunction(nlhs, plhs, nrhs, prhs)
C     MATLAB MEX gateway for
C        wave = wavelet(s, freq, dt, z0, skip)
C     Reads the five right-hand-side arrays, allocates the real output
C     matrix wave (numel(s) rows by numel(freq) columns) and passes
C     the raw data pointers on to subroutine wavelet.
C
C     NOTE(review): pointers and sizes are declared integer*4, which
C     only matches the MEX API on 32-bit MATLAB.  64-bit builds need
C     mwPointer/mwSize from fintrf.h -- confirm the target platform.

      integer*4 nlhs, nrhs
      integer*4 plhs(*), prhs(*)

      integer*4 m_in, n_in
      integer*4 wavep, sp, freqp, dtp, z0p, skipp

C     All five inputs are dereferenced below; raise a MATLAB error
C     instead of reading past the end of prhs.
      if (nrhs .ne. 5) then
         call mexErrMsgTxt(
     $        'wavelet requires 5 inputs: s, freq, dt, z0, skip')
      endif

C     Use the element count (rows*cols) so s and freq may be passed
C     as row or column vectors; mxGetN alone returns 1 for a column
C     vector.
      m_in = mxGetM(prhs(1))*mxGetN(prhs(1))
      n_in = mxGetM(prhs(2))*mxGetN(prhs(2))

      sp = mxGetPr(prhs(1))
      freqp = mxGetPr(prhs(2))
      dtp = mxGetPr(prhs(3))
      z0p = mxGetPr(prhs(4))
      skipp = mxGetPr(prhs(5))

C     mxCreateFull is the obsolete MATLAB 4 call; mxCreateDoubleMatrix
C     replaces it (last argument 0 = real, no imaginary part).
      plhs(1) = mxCreateDoubleMatrix(m_in, n_in, 0)
      wavep = mxGetPr(plhs(1))

      call wavelet(%VAL(wavep), %VAL(sp), %VAL(freqp), %VAL(dtp),
     $     %VAL(z0p), %VAL(skipp), m_in, n_in)

      return
      end


      subroutine wavelet(wave, s, freq,  dt, z0, skip, mm, nn)

C     Magnitude of the continuous Morlet wavelet transform of s.
C
C     - wave (out, size (mm,nn)): for each frequency freq(i) and each
C       evaluated sample j,
C          wave(j,i) = sqrt(freq(i)) * |sum over ii of
C                      s(ii)*conjg(psi(freq(i)*(ii*dt - j*dt), z0))|
C       where ii runs over the samples within 1.5*offset of j, and
C       offset is (z0+1)/2 periods of freq(i) expressed in samples.
C       Only samples j that are multiples of skip are evaluated; the
C       samples in between repeat the previous value (sample-and-hold,
C       convenient for plotting).  Samples too close to either end of
C       s for the wavelet to fit, and whole columns for non-positive
C       frequencies, are left NaN.
C     - s (in, size mm): the input "signal" vector.
C     - freq (in, size nn): the frequencies to do the calculation for.
C     - dt (in): sampling interval, i.e. 1/(sampling rate) of s.
C     - z0 (in): width of the wavelet (i.e. number of wavelengths).
C       Larger values improve the frequency definition, but make it
C       more likely that a wavelet straddles the start/end of a note.
C     - skip (in): evaluation stride; only times n*skip*dt with n an
C       integer are calculated.  It is rounded to the nearest
C       integer, and values below 1 are treated as 1.
C     - mm, nn (in): lengths of s and freq.
C
C     Sample times and periods are computed on the fly from dt and
C     freq, so no fixed-size scratch storage (and no size limit on
C     mm or nn) is needed.

      implicit none
      integer*4 mm, nn
      real*8 wave(mm,nn), s(mm), freq(nn), dt, z0, skip

      integer*4 i, j, ii, lmin, lmax, offset, reach, iskip
      real*8 k, t, halfw, nan
      double complex psi, ps, ctot
      real*8 mxGetNaN
      external psi, mxGetNaN

C     The NaN accessor returns a double.  Declaring it as default REAL
C     makes the caller read only part of the return value, which can
C     turn the NaN into 0.0 (e.g. on x86-64).
      nan = mxGetNaN()

C     MOD needs both arguments of the same type, and skip arrives from
C     MATLAB as a double.  A stride below 1 would divide by zero.
      iskip = nint(skip)
      if (iskip .lt. 1) iskip = 1

C     Start from an all-NaN result so unevaluated entries stand out.
      do i = 1, nn
         do j = 1, mm
            wave(j,i) = nan
         enddo
      enddo

      print*,
     $     '       i       kvals       freq         period'

      do i = 1, nn
         k = freq(i)
         print*, i, k, freq(i), 1.0d0/freq(i)

C        offset = (z0+1)/2 periods expressed in samples.  Its range is
C        checked in floating point before int() so that zero, negative
C        or very low frequencies cannot overflow or index out of
C        bounds; offset = -1 means "leave this column NaN".
         offset = -1
         if (k .gt. 0.0d0) then
            halfw = ((z0+1)/2.0d0)*(1.0d0/k)/dt
            if ((halfw .ge. 0.0d0) .and. (halfw .lt. dble(mm)))
     $           offset = int(halfw)
         endif

         if ((offset .ge. 0) .and. (2*offset .lt. mm)) then
C           Only integrate over +-1.5*offset samples: there is no
C           point integrating over the tails of the wavelet.
            reach = offset + offset/2
C           Never start below sample 1: offset is 0 when (z0+1)/2
C           periods are shorter than one sample.
            do j = max(1, offset), mm-(1+offset)
               if (mod(j, iskip) .eq. 0) then
                  t = dble(j)*dt
                  lmin = max(1, j-reach)
                  lmax = min(mm, j+reach)
                  ctot = (0.0d0, 0.0d0)
                  do ii = lmin, lmax
                     ps = psi(k*(dble(ii)*dt - t), z0)
                     ctot = ctot + s(ii)*conjg(ps)
                  enddo
                  wave(j,i) = sqrt(k)*abs(ctot)
               elseif (j .ne. 1) then
C                 Hold the previous value between evaluated samples.
                  wave(j,i) = wave(j-1,i)
               endif
            enddo
         endif
      enddo

      return
      end
      
      double complex function psi(x, z0)
C     Complex Morlet wavelet with z0 wavelengths across its width,
C     evaluated at the dimensionless position x (in wavelengths):
C
C        psi = (exp(2*pi*i*x) - exp(-z0**2/2))
C              * exp(-2*pi**2*x**2/z0**2)
C
C     The exp(-z0**2/2) term is the standard Morlet correction that
C     gives the wavelet zero mean; it is tiny for z0 >= 5.
      implicit none
      real*8 x, z0
      real*8 pi
      parameter (pi = 3.14159265358979323846d0)
      real*8 phase, envel

C     pi must be a double-precision constant: a plain literal such as
C     3.14159265 is default (single precision) REAL, an error of about
C     9e-8 that shifts both the oscillation phase and the envelope.
      phase = 2.0d0*pi*x
      envel = exp(-2.0d0*(pi*x/z0)**2)
      psi = (cmplx(cos(phase), sin(phase), kind(1.0d0))
     $     - exp(-(z0**2)/2.0d0))*envel
      return
      end



      







