/* Wavelet analysis routine (+ MATLAB gateway function) */

#include <math.h>
#include "mex.h"

/* Input Arguments */

#define	S_IN	prhs[0]		/* signal samples */
#define	FREQ_IN	prhs[1]		/* analysis frequencies (period = 1/freq) */
#define	DT_IN	prhs[2]		/* sample spacing (scalar) */
#define	Z0_IN	prhs[3]		/* wavelet shape parameter z0 (scalar) */
#define	SKIP_IN	prhs[4]		/* compute only every skip-th sample (scalar) */

/* Maximum sizes of the signal (M) and frequency (N) arrays */
#define M 200000
#define N 100

/* Output Arguments */

#define	WAVE_OUT	plhs[0]		/* result matrix, one column per frequency */

/* Function-like macros: arguments may be evaluated twice, so never pass
   expressions with side effects. */
#define	max(A, B)	((A) > (B) ? (A) : (B))
#define	min(A, B)	((A) < (B) ? (A) : (B))

/* pi to full double precision; a 9-digit literal would introduce a relative
   error of ~4e-9 into every phase and Gaussian term. */
#define pi 3.14159265358979323846

static void wavelet(
		   double	wave[], /* although 2D array, we'll index it with one subscript*/
		   double	s[],
		   double	freq[], /* MATLAB only works with arrays, so freq, dt, zo*/
		   double	dt[],   /* and skip are arrays of size=1 */
		   double	z0[],
		   double	skip[],
		   unsigned int	mm,
		   unsigned int	nn
		   )
{
	double psi_i(double x, double z0);
	double psi_r(double x, double z0);
	
	long int i,j,index,ii,lmin,lmax,offset;
	double t,k,nan, t1, t2;
	double period[N], tvals[M], kvals[N];
	double ps, ctotr, ctoti;

	nan=mxGetNaN(); /* what is returned where no calculation has been */
                  /* done (could be set to 0) */

	/*	printf("%d\n",mm);
			printf("%d\n",nn);*/

	if ((mm < M) && (nn < N)) {

		for (j=0; j<nn; j++) {
			period[j]=1.0/freq[j];
			kvals[j]=freq[j];
					
			for (i=0; i<mm; i++) {
				index=i+j*mm;
				/*				printf("%d\n",index);*/
				wave[index]=nan;
			}
		}         

		for (i=0; i<mm; i++) {
			tvals[i]=i * dt[0];
		}

		printf("   i    kvals     freq     period\n");

		for (i=0; i<nn; i++) {
			offset=floor(((z0[0]+1.0)/2.0)*period[i]/dt[0]);
			printf("%4d  %8f  %8f  %8f %8d\n",i, kvals[i], freq[i], period[i],offset);
			if (offset*2 < mm) {
				for (j=offset; j<(mm-offset); j++) {
					index=i*mm+j;
					if (j % (int) skip[0] == 0.0) {
						
						k=kvals[i];
						t=tvals[j];

						lmin=max(j-floor(1.5*offset),0);
						lmax=min(j+ceil(1.5*offset),(mm-1));

						ctotr=0;
						ctoti=0;
						
						for (ii=lmin; ii <=lmax; ii++) {
							/* integrate s*conj(psi) - note s is always real. We can
							 speed things up a bit if we calculate the terms used in
							 the real and imaginary parts of ctot in one go... not
							 very neat though*/

							t1=s[ii]*exp(-2.0*k*(tvals[ii]-t)*k*(tvals[ii]-t)*pi*pi/(z0[0]*z0[0]));
							t2=exp((-z0[0]*z0[0] / 2.0) - 2*k*(tvals[ii]-t)*k*(tvals[ii]-t)*pi*pi/(z0[0]*z0[0]));
							ctotr=ctotr+sin(2.0* pi*k*(tvals[ii]-t))*t1 - t2;
							ctoti=ctoti+cos(2.0* pi*k*(tvals[ii]-t))*t1 - t2;
						}
						/*						printf("%f  %f\n",ctotr,ctoti);*/
						wave[index]=sqrt(k)*sqrt(ctotr*ctotr+ctoti*ctoti);
						/* /(1+lmax-lmin); hmmmm...*/
					}
					else if (j != 0)
						/* if skipping, copy previous calculation...*/
						wave[index]=wave[index-1];
					
				}
			}
		}
	}
	else {
		printf("input array(s) exceed maximum sizes");
	}
	/*	printf("test: %e %e\n",psi_r(z0[0]/10.0,z0[0]),psi_i(z0[0]/10.0,z0[0]));*/
  return;
}

/*
 * MATLAB gateway:  wave = wavelet(s, freq, dt, z0, skip)
 *
 * s and freq are real, full double arrays whose elements are all used (row or
 * column vectors both work). dt, z0 and skip are real double scalars; only
 * their first element is read. The result has one row per signal element and
 * one column per frequency.
 */
void mexFunction(
                 int nlhs,       mxArray *plhs[],
                 int nrhs, const mxArray *prhs[]
		 )
{
  double	*wave;
  double	*s,*freq,*dt,*z0,*skip;
  unsigned int	mm,nn;
  size_t	n_samples,n_freqs;
  int	arg;

  /* Check for proper number of arguments */
  if (nrhs != 5) {
    mexErrMsgTxt("WAVELET requires five input arguments.");
  } else if (nlhs > 1) {
    mexErrMsgTxt("WAVELET requires one output argument.");
  }

  /* mxGetPr reads the data as plain doubles. Integer, single, complex or
     sparse inputs would be misread, or read past the end of the buffer. */
  for (arg = 0; arg < nrhs; arg++) {
    if (!mxIsDouble(prhs[arg]) || mxIsComplex(prhs[arg]) || mxIsSparse(prhs[arg])) {
      mexErrMsgTxt("WAVELET inputs must be real, full double arrays.");
    }
  }

  /* dt, z0 and skip are dereferenced as element [0]. */
  if (mxIsEmpty(DT_IN) || mxIsEmpty(Z0_IN) || mxIsEmpty(SKIP_IN)) {
    mexErrMsgTxt("WAVELET requires dt, z0 and skip to be non-empty scalars.");
  }

  /* Count elements rather than columns, so column vectors are not treated
     as a single sample. */
  n_samples = mxGetNumberOfElements(S_IN);
  n_freqs = mxGetNumberOfElements(FREQ_IN);
  mm = (unsigned int) n_samples;
  nn = (unsigned int) n_freqs;
  if ((size_t) mm != n_samples || (size_t) nn != n_freqs) {
    mexErrMsgTxt("WAVELET input arrays are too large.");
  }

  /* Create a matrix for the return argument */
  WAVE_OUT = mxCreateDoubleMatrix(mm, nn, mxREAL);

  /* Assign pointers to the various parameters */
  wave = mxGetPr(WAVE_OUT);
  s = mxGetPr(S_IN);
  freq = mxGetPr(FREQ_IN);
  dt = mxGetPr(DT_IN);
  z0 = mxGetPr(Z0_IN);
  skip = mxGetPr(SKIP_IN);

  /* Do the actual computations in a subroutine */
  wavelet(wave,s,freq,dt,z0,skip,mm,nn);
}

/* The Morlet wavelet helpers below are not called by wavelet(), which evaluates the wavelet terms directly; they are kept for reference. */

double psi_i(double x, double z0)
{
	double p;

	p=sin(2.0* pi*x)*exp(-2.0*x*x*pi*pi/(z0*z0))-exp((-z0*z0 / 2.0) - 2*x*x*pi*pi/(z0*z0));

	return p;
}
	
double psi_r(double x, double z0)
{
	double p;

	p=cos(2.0* pi*x)*exp(-2.0*x*x*pi*pi/(z0*z0))-exp((-z0*z0 / 2.0) - 2*x*x*pi*pi/(z0*z0));

	return p;
}
	
