#include "stdlib.h"
#include "stdio.h"
#include <math.h>
#include <time.h>

/*
 * Forward declarations of the file-local helpers, all of which are defined
 * after main().  NChooseM and Mean are declared and defined in this file but
 * have no caller in it.
 */
static void ReadInfo(int *pNumRow, int *pNumCol, int *pM, int *pN);
static void ReadTable (int NumRow, int NumCol, int **table, int *rowSum, 
		       int *colSum, int **s0, float **mle);
static double ChiSquareStatistic(int **t, int **s0, float **mle, int NumRow,
				 int NumCol);
static void  InitialSetting (double *pSampleMean, double *pSampleVar, 
			     double *pPValue);
static void GenerateATable (int **table, int **s0, int NumRow, int NumCol, 
			   int *rowSum, int *colSum, double *pCurrentWeight);
static void UpdatePValue (double currentWeight, int **table, int **s0,
			  float **mle, int NumRow, int NumCol, 
			  double chiSquare0, double *pSampleMean, 
		          double *pSampleVar, double *pPValue);
static void PrintOutput (FILE *outFile, int N, double sampleMean, 
			 double sampleVar, double pValue);
static double  NChooseM(int n, int m);
static double Mean(double x[], int n);
static int RandomInteger(int low, int high);

/*
 * Driver.  Reads the table dimensions (NumRow x NumCol), the number of
 * estimates M and the number of samples N per estimate, loads the observed
 * table together with the s0 and mle matrices, and computes the observed
 * statistic chiSquare0.  Then, for each of the M estimates, it draws N random
 * tables, accumulates their weights, and writes one result line to
 * "result.txt" (also echoed to stdout).
 *
 * Returns EXIT_SUCCESS, or EXIT_FAILURE when the dimensions are unusable,
 * memory runs out, or the output file cannot be created or closed.
 */
int main(void)
{
	int i, NumRow, NumCol, M, N, numEstimate, numSample;
	int rowsAllocated = 0;
	int status = EXIT_FAILURE;
	int *rowSum = NULL, *colSum = NULL, **table = NULL, **s0 = NULL;
	float **mle = NULL;
	FILE *outFile;
	double sampleMean, sampleVar, chiSquare0, pValue, currentWeight;

	ReadInfo(&NumRow, &NumCol, &M, &N);

	/* Later routines index row NumRow-1 / column NumCol-1 and divide by N. */
	if (NumRow < 1 || NumCol < 1 || N < 1) {
		fprintf(stderr, "Rows, columns and samples per estimate must be positive.\n");
		return EXIT_FAILURE;
	}

	rowSum = malloc(NumRow * sizeof *rowSum);
	colSum = malloc(NumCol * sizeof *colSum);
	mle = malloc(NumRow * sizeof *mle);
	table = malloc(NumRow * sizeof *table);
	s0 = malloc(NumRow * sizeof *s0);
	if (rowSum == NULL || colSum == NULL || mle == NULL || table == NULL ||
	    s0 == NULL) {
		fprintf(stderr, "Memory allocation failed!\n");
		goto cleanup;
	}
	for (i = 0; i < NumRow; i++) {
		mle[i] = malloc(NumCol * sizeof **mle);
		table[i] = malloc(NumCol * sizeof **table);
		s0[i] = malloc(NumCol * sizeof **s0);
		/* Record the row before checking so cleanup frees partial rows too. */
		rowsAllocated = i + 1;
		if (mle[i] == NULL || table[i] == NULL || s0[i] == NULL) {
			fprintf(stderr, "Memory allocation failed!\n");
			goto cleanup;
		}
	}

	srand((unsigned) time(NULL));

	ReadTable(NumRow, NumCol, table, rowSum, colSum, s0, mle);

	/* Statistic of the observed table; GenerateATable overwrites `table` later. */
	chiSquare0 = ChiSquareStatistic(table, s0, mle, NumRow, NumCol);

	outFile = fopen("result.txt", "w");
	if (outFile == NULL) {
		printf("File creating failed!\n");
		goto cleanup;	/* never fclose a NULL stream */
	}

	for (numEstimate = 0; numEstimate < M; numEstimate++) {
		InitialSetting(&sampleMean, &sampleVar, &pValue);

		for (numSample = 0; numSample < N; numSample++) {
			GenerateATable(table, s0, NumRow, NumCol, rowSum, colSum,
				       &currentWeight);
			UpdatePValue(currentWeight, table, s0, mle, NumRow, NumCol,
				     chiSquare0, &sampleMean, &sampleVar, &pValue);
		}
		PrintOutput(outFile, N, sampleMean, sampleVar, pValue);
	}

	/* fclose flushes buffered output; a failure here means lost results. */
	if (fclose(outFile) == 0)
		status = EXIT_SUCCESS;
	else
		fprintf(stderr, "Error closing result.txt\n");

cleanup:
	for (i = 0; i < rowsAllocated; i++) {
		free(mle[i]);
		free(table[i]);
		free(s0[i]);
	}
	free(mle);
	free(table);
	free(s0);
	free(rowSum);
	free(colSum);
	return status;
}

/*
 * ReadPositiveInt: print `prompt` and read one integer from stdin, repeating
 * until a positive value is entered.  Exits the program on end of input,
 * because no further answer can ever arrive.
 */
static int ReadPositiveInt(const char *prompt)
{
	int value, rc, c;

	for (;;) {
		printf("%s", prompt);
		rc = scanf("%d", &value);
		if (rc == EOF) {
			fprintf(stderr, "\nUnexpected end of input.\n");
			exit(EXIT_FAILURE);
		}
		if (rc == 1 && value > 0)
			return value;
		/* Drop the rest of the offending line so the next scanf starts clean. */
		while ((c = getchar()) != '\n' && c != EOF)
			;
		printf("\nPlease enter a positive integer.\n");
	}
}

/*
 * ReadInfo: interactively ask for the number of table rows and columns, the
 * number of estimates (M) and the number of samples per estimate (N), and
 * store them through the four out-pointers.  Every value is guaranteed to be
 * a positive int: unparsable or non-positive answers are re-prompted instead
 * of leaving the outputs uninitialized.
 */
static void ReadInfo(int *pNumRow, int *pNumCol, int *pM, int *pN)
{
	*pNumRow = ReadPositiveInt("\nPlease enter the number of rows in the table?\n");
	*pNumCol = ReadPositiveInt("\nPlease enter the number of columns in the table?\n");
	*pM = ReadPositiveInt("\nHow many estimates do you need?\n");
	*pN = ReadPositiveInt("\nHow many samples should each estimate be based on?\n");
}

/* ReadIntMatrix: read NumRow*NumCol ints row by row; 1 on success, 0 otherwise. */
static int ReadIntMatrix(FILE *inFile, int **m, int NumRow, int NumCol)
{
	int i, j;

	for (i = 0; i < NumRow; i++)
		for (j = 0; j < NumCol; j++)
			if (fscanf(inFile, "%d", &m[i][j]) != 1)
				return 0;
	return 1;
}

/* ReadFloatMatrix: read NumRow*NumCol floats row by row; 1 on success, 0 otherwise. */
static int ReadFloatMatrix(FILE *inFile, float **m, int NumRow, int NumCol)
{
	int i, j;

	for (i = 0; i < NumRow; i++)
		for (j = 0; j < NumCol; j++)
			if (fscanf(inFile, "%f", &m[i][j]) != 1)
				return 0;
	return 1;
}

/*
 * ReadTable: prompt for a data file name (re-prompting while it cannot be
 * opened).  The file must hold three whitespace-separated NumRow x NumCol
 * matrices in row-major order:
 *   the integer table, the integer matrix s0, and the float matrix mle.
 * Also fills rowSum[i] and colSum[j] with the margins of the table.
 *
 * Exits the program if stdin ends before a usable file name is given or if
 * the file is short or malformed.  All outputs must be preallocated by the
 * caller.
 */
static void ReadTable (int NumRow, int NumCol, int **table, int *rowSum, 
		       int *colSum, int **s0, float **mle)
{
	int i, j;
	FILE *inFile;
	char nameInFile[256];

	/* Width 255 = sizeof nameInFile - 1, leaving room for the terminator. */
	printf("\nPlease enter the data file name:\n");
	if (scanf("%255s", nameInFile) != 1) {
		fprintf(stderr, "\nNo data file name given.\n");
		exit(EXIT_FAILURE);
	}
	while ((inFile = fopen(nameInFile, "r")) == NULL) {
		printf("\nCan't read file. Please enter the data file name:\n");
		if (scanf("%255s", nameInFile) != 1) {
			fprintf(stderr, "\nNo data file name given.\n");
			exit(EXIT_FAILURE);
		}
	}

	if (!ReadIntMatrix(inFile, table, NumRow, NumCol) ||
	    !ReadIntMatrix(inFile, s0, NumRow, NumCol) ||
	    !ReadFloatMatrix(inFile, mle, NumRow, NumCol)) {
		fprintf(stderr, "\nData file %s is incomplete or malformed.\n",
			nameInFile);
		fclose(inFile);
		exit(EXIT_FAILURE);
	}
	fclose(inFile);

	/* Row and column margins of the observed table, in a single pass. */
	for (j = 0; j < NumCol; j++)
		colSum[j] = 0;
	for (i = 0; i < NumRow; i++) {
		rowSum[i] = 0;
		for (j = 0; j < NumCol; j++) {
			rowSum[i] += table[i][j];
			colSum[j] += table[i][j];
		}
	}
}

/*
 * ChiSquareStatistic: Pearson chi-square of table t against the fitted values
 * in mle, summed only over cells with s0[i][j] == 0:
 *
 *     sum over those cells of (t[i][j] - mle[i][j])^2 / mle[i][j]
 *
 * Each term is evaluated in double precision.  mle is stored as float, and
 * mixing it directly with the int counts would otherwise perform the
 * subtraction, square and division in float.
 *
 * NOTE(review): a counted cell with mle[i][j] == 0 yields an infinite/NaN
 * term; callers presumably supply strictly positive fitted values there --
 * confirm.
 */
static double ChiSquareStatistic(int **t, int **s0, float **mle, int NumRow,
				 int NumCol)
{
	double sum = 0.0;
	int i, j;

	for (i = 0; i < NumRow; i++) {
		for (j = 0; j < NumCol; j++) {
			if (s0[i][j] == 0) {
				/* Promote first so no intermediate is rounded to float. */
				const double expected = mle[i][j];
				const double diff = t[i][j] - expected;

				sum += diff * diff / expected;
			}
		}
	}
	return sum;
}

/*
 * InitialSetting: clear the three running sums (total weight, total squared
 * weight, weight of tables counted toward the p-value) before starting a new
 * batch of samples.
 */
static void  InitialSetting (double *pSampleMean, double *pSampleVar, 
			     double *pPValue)
{
	*pPValue = *pSampleVar = *pSampleMean = 0.0;
}

/*
 * GenerateATable: overwrite `table` with one randomly drawn NumRow x NumCol
 * table.  Its row sums equal rowSum[], its column sums equal colSum[], and
 * cells with s0[k][j] == 1 (presumably structural zeros) are 0.
 *
 * Columns 0..NumCol-2 are filled one at a time.  For a column with a positive
 * sum, each row with s0 == 0 gets bounds [lower[k], upper[k]]:
 *   upper[k] = min(row k's remaining sum, the column sum);
 *   lower[k] = the part of row k's remaining sum that the later free columns
 *              cannot take even if they were filled completely (floored at 0).
 * Rows 0..NumRow-2 are then drawn uniformly (RandomInteger) from the range
 * that also lets the rows below still meet their bounds.  The last row takes
 * the remainder of the column sum.  The last column receives the remaining
 * row sums.
 *
 * *pCurrentWeight is the product of the number of admissible values at every
 * draw, i.e. the reciprocal of the probability of this sequence of uniform
 * draws.  If the margins cannot be met, a diagnostic goes to stderr and the
 * weight is 0, so the invalid table adds nothing to the caller's running
 * sums.  This happens on an empty range, or when a leftover would have to go
 * into an s0 == 1 cell.  `table` may then be only partially rewritten.
 *
 * Scratch arrays are allocated per call; the program exits if that fails.
 */
static void GenerateATable (int **table, int **s0, int NumRow, int NumCol, 
			   int *rowSum, int *colSum, double *pCurrentWeight)
{
	int j, k, m, temp, leftColSum, lo, hi;
	int *currentRowSum, *upper, *lower, *leftLowerWeight, *leftUpperWeight;
	double currentWeight = 1.0;

	currentRowSum = malloc(NumRow * sizeof *currentRowSum);
	lower = malloc((NumRow + 1) * sizeof *lower);
	upper = malloc((NumRow + 1) * sizeof *upper);
	leftUpperWeight = malloc((NumRow + 1) * sizeof *leftUpperWeight);
	leftLowerWeight = malloc((NumRow + 1) * sizeof *leftLowerWeight);
	if (currentRowSum == NULL || lower == NULL || upper == NULL ||
	    leftUpperWeight == NULL || leftLowerWeight == NULL) {
		fprintf(stderr, "Memory allocation failed!\n");
		exit(EXIT_FAILURE);
	}

	for (k = 0; k < NumRow; k++)
		currentRowSum[k] = rowSum[k];

	for (j = 0; j < NumCol - 1; j++) {
		if (colSum[j] <= 0) {
			/* Nothing to distribute in this column. */
			for (k = 0; k < NumRow; k++)
				table[k][j] = 0;
			continue;
		}

		/* Per-row bounds for column j; rows with s0 != 0 keep [0, 0]. */
		for (k = 0; k < NumRow; k++) {
			upper[k] = 0;
			lower[k] = 0;
			if (s0[k][j] == 0) {
				temp = currentRowSum[k];
				for (m = j + 1; m < NumCol; m++)
					if (s0[k][m] == 0)
						temp -= colSum[m];
				if (temp > 0)
					lower[k] = temp;
				upper[k] = (currentRowSum[k] < colSum[j]) ?
					currentRowSum[k] : colSum[j];
			}
		}

		/* Suffix sums: leftXWeight[k] = sum of the bounds of rows k..NumRow-1. */
		leftUpperWeight[NumRow] = 0;
		leftLowerWeight[NumRow] = 0;
		for (k = NumRow - 1; k >= 0; k--) {
			leftUpperWeight[k] = leftUpperWeight[k + 1] + upper[k];
			leftLowerWeight[k] = leftLowerWeight[k + 1] + lower[k];
		}

		leftColSum = colSum[j];
		for (k = 0; k < NumRow - 1; k++) {
			if (s0[k][j] == 1) {
				table[k][j] = 0;
				continue;
			}
			/* Leave enough (but not too much) for the rows below k. */
			lo = leftColSum - leftUpperWeight[k + 1];
			if (lower[k] > lo)
				lo = lower[k];
			hi = leftColSum - leftLowerWeight[k + 1];
			if (upper[k] < hi)
				hi = upper[k];

			if (lo > hi) {
				/* Empty range: the table cannot be completed. */
				fprintf(stderr, "Cannot sample!\n");
				currentWeight = 0.0;
				goto done;
			}

			table[k][j] = RandomInteger(lo, hi);
			leftColSum -= table[k][j];
			currentWeight *= (hi - lo + 1);
		}

		/* The last row takes the remainder of the column sum. */
		if (leftColSum != 0 && s0[NumRow - 1][j] == 1) {
			fprintf(stderr, "Inconsistent!!!\n");
			currentWeight = 0.0;
			goto done;
		}
		table[NumRow - 1][j] = leftColSum;

		for (k = 0; k < NumRow; k++)
			currentRowSum[k] -= table[k][j];
	}

	/* The last column absorbs whatever each row still has to place. */
	for (k = 0; k < NumRow; k++) {
		if (s0[k][NumCol - 1] == 1 && currentRowSum[k] != 0) {
			fprintf(stderr, "Error!\n");
			currentWeight = 0.0;
			goto done;
		}
		table[k][NumCol - 1] = currentRowSum[k];
	}

done:
	*pCurrentWeight = currentWeight;

	free(currentRowSum);
	free(lower);
	free(upper);
	free(leftUpperWeight);
	free(leftLowerWeight);
}

/*
 * UpdatePValue: fold one sampled table into the running sums of the current
 * estimate:
 *   *pSampleMean += w,  *pSampleVar += w*w,
 *   *pPValue     += w   when the table's chi-square is <= chiSquare0.
 *
 * NOTE(review): this accumulates the weight of tables whose statistic is at
 * most the observed one (the lower tail).  A conventional goodness-of-fit
 * p-value uses ">=" -- confirm which tail is intended.
 */
static void UpdatePValue (double currentWeight, int **table, int **s0,
			  float **mle, int NumRow, int NumCol, 
			  double chiSquare0, double *pSampleMean, 
		          double *pSampleVar, double *pPValue) 
{
	const double weightSquared = currentWeight * currentWeight;
	double statistic;

	*pSampleMean += currentWeight;
	*pSampleVar += weightSquared;

	statistic = ChiSquareStatistic(table, s0, mle, NumRow, NumCol);
	if (!(statistic > chiSquare0) && statistic == statistic)
		*pPValue += currentWeight;
}

/*
 * PrintOutput: turn the raw per-batch sums into one result line.  The inputs,
 * as accumulated by UpdatePValue, are:
 *   sampleMean - sum of the N weights,
 *   sampleVar  - sum of the squared weights,
 *   pValue     - sum of the weights counted toward the p-value.
 * Three %15g columns are written to outFile and the same line is echoed to
 * stdout:
 *   mean weight, pValue / total weight, N * sum(w^2) / (sum w)^2 - 1
 * The last column is the squared coefficient of variation of the weights.
 */
static void PrintOutput (FILE *outFile, int N, double sampleMean, 
			 double sampleVar, double pValue)
{
	const double totalWeight = sampleMean;
	const double tailFraction = pValue / totalWeight;
	const double cvSquared = N * sampleVar / totalWeight / totalWeight - 1;
	const double meanWeight = totalWeight / N;
	FILE *const sinks[] = { outFile, stdout };
	size_t s;

	for (s = 0; s < sizeof sinks / sizeof sinks[0]; s++)
		fprintf(sinks[s], "%15g %15g %15g\n", meanWeight, tailFraction, cvSquared);
}

/*
 * NChooseM: binomial coefficient C(n, m) as a double, built up by the
 * multiplicative formula prod_{k=1..m} (n - k + 1) / k.
 * Returns 0 when m < 0, n < 0 or m > n, and 1 when m == 0.
 * No caller in this file.
 */
static double  NChooseM(int n, int m)
{
	double coeff = 1.0;
	int k;

	if (m < 0 || n < 0 || m > n)
		return 0;

	for (k = 1; k <= m; ++k)
		coeff = coeff * (n - k + 1) / k;

	return coeff;
}

/*
 * Mean: arithmetic mean of the first n elements of x.
 * Returns NAN when n <= 0 instead of dividing by a zero or negative count.
 * No caller in this file.
 */
static double Mean(double x[], int n)
{
	double sum = 0.0;
	int i;

	if (n <= 0)
		return NAN;

	for (i = 0; i < n; i++)
		sum += x[i];
	return sum / n;
}

/*
 * RandomInteger: pseudo-random integer in the closed range [low, high],
 * drawn from rand().
 * rand() / (RAND_MAX + 1.0) maps to [0, 1); scaling by the span and
 * truncating picks one of the span values, roughly uniformly.
 * Assumes low <= high; an empty range gives a meaningless result.
 */
static int RandomInteger(int low, int high)
{
	const int span = high - low + 1;
	const double unit = rand() / ((double) RAND_MAX + 1);

	return low + (int) (unit * span);
}