#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <sys/time.h>
#include <omp.h>
//#include <iostream>
//#include <iterator>

/* Larger of a and b. Whenever "a is greater than b" does not hold
   (a <= b, or either argument is NaN) the second argument b is returned. */
double max(double a, double b)
{
    return (b < a) ? a : b;
}
/* Smaller of c and d. Whenever "c is less than d" does not hold
   (c >= d, or either argument is NaN) the second argument d is returned. */
double min(double c, double d)
{
    return (d > c) ? c : d;
}

/*
 * Explicit finite-difference solver for the 2D shallow-water equations on a
 * staggered grid, parallelised with OpenMP.
 *
 * Every 2D array is stored row-major with row stride S = Nx + 1 (entry S*j + i):
 *   E0/E1  surface elevation at cell centres (i < Nx, j < Ny), old/new
 *   h      bed depth below the reference level, so E + h is the water depth
 *   U0/U1  x-velocity on the west face of cell i (faces i = 0..Nx), old/new
 *   V0/V1  y-velocity on the south face of cell j (faces j = 0..Ny), old/new
 *   qu/qv  volume flux on those faces (new velocity times upwind depth)
 *   hh     scratch depth; holds E + h at cell centres between time steps
 *
 * Cells with total depth not above thrs are treated as dry. The outermost
 * faces (i == 0, i == Nx, j == 0, j == Ny) are closed walls with zero flux.
 * Prints the grid size and the wall-clock time of the run.
 */
int main(void) {
    int status = 0;

    const float H0 = 1;     /* reference depth: time-step estimate and bed profile */
    const float g  = 10;    /* gravitational acceleration */
    const float T  = 25;    /* simulated time */

    const float dx = 0.4;
    const float dy = 0.4;
    /* dt = min(dx, dy) / sqrt(2 g H0) / 3 */
    const float dt = min(dx, dy)/sqrt(2*g*H0)/3;

    const int Nx = 800+1;           /* cells in x */
    const int Ny = 800+1;           /* cells in y */
    const int Nt = floor(T/dt)+1;   /* the time loop runs n = 0..Nt, i.e. Nt + 1 steps */

    /* Faces i == Nx and j == Ny need one extra column and row, hence the row
       stride S = Nx + 1 and (Nx + 1) * (Ny + 1) entries per array. */
    const int S = Nx + 1;
    const size_t N = (size_t)(Nx + 1) * (size_t)(Ny + 1);

    const float X = (Nx-1)*dx;      /* x of the last cell centre */
    const float Y = (Ny-1)*dy;      /* y of the last cell centre */

    printf("\nNx: %i\nNy: %i\nNt: %i\n", Nx, Ny, Nt);
    const float thrs = 0.002;       /* wet/dry depth threshold */
    const float A0   = 0.5;         /* initial surface elevation of the raised region */

    /* calloc: fluxes and velocities are read before their first assignment,
       and wall/extra-row entries are never assigned at all; they must read 0. */
    float *x  = malloc((size_t)Nx * sizeof *x);
    float *y  = malloc((size_t)Ny * sizeof *y);
    float *E0 = calloc(N, sizeof *E0);
    float *E1 = calloc(N, sizeof *E1);
    float *h  = calloc(N, sizeof *h);
    float *hh = calloc(N, sizeof *hh);
    float *U0 = calloc(N, sizeof *U0);
    float *U1 = calloc(N, sizeof *U1);
    float *qu = calloc(N, sizeof *qu);
    float *V0 = calloc(N, sizeof *V0);
    float *V1 = calloc(N, sizeof *V1);
    float *qv = calloc(N, sizeof *qv);
    if (!x || !y || !E0 || !E1 || !h || !hh ||
        !U0 || !U1 || !qu || !V0 || !V1 || !qv) {
        fprintf(stderr, "allocation failed\n");
        status = EXIT_FAILURE;
        goto cleanup;
    }

    for (int i = 0; i < Nx; i++) {
        x[i] = i*dx;
    }
    for (int j = 0; j < Ny; j++) {
        y[j] = j*dy;
    }

    /* Default to 4 OpenMP threads unless OMP_NUM_THREADS chooses otherwise. */
    if (getenv("OMP_NUM_THREADS") == NULL) {
        omp_set_num_threads(4);
    }

    const double wtime = omp_get_wtime();

    /* Bed profile: radially symmetric about (xc, yc), linear in the distance
       from that centre with slope alpa. */
    const float r    = 3.6;
    const float alpa = 1/19.85;
    const float xc   = X/2;
    const float yc   = Y/2;

    #pragma omp parallel for collapse(2)
    for (int j = 0; j < Ny; j++) {
        for (int i = 0; i < Nx; i++) {
            h[S * j + i] = - 1 - (r*alpa - H0 - sqrt(pow(alpa,2)*(pow(x[i]-xc,2) + pow(y[j]-yc,2))));
        }
    }

    /* Initial surface: A0 in the region x <= 0.3X, y >= 0.7Y and 0 elsewhere,
       but never below -h + thrs so every cell starts with depth >= thrs. */
    #pragma omp parallel for collapse(2)
    for (int j = 0; j < Ny; j++) {
        for (int i = 0; i < Nx; i++) {
            const int k = S * j + i;
            const double level = (x[i] <= 0.3*X && y[j] >= 0.7*Y) ? A0 : 0;
            E0[k] = max(level, -h[k] + thrs);
            E1[k] = E0[k];
            hh[k] = E0[k] + h[k];
        }
    }

    /* Closed walls. No loop below writes these entries, so once is enough. */
    for (int j = 0; j < Ny; j++) {
        qu[S * j]      = 0;     /* face i == 0 */
        qu[S * j + Nx] = 0;     /* face i == Nx */
    }
    for (int i = 0; i < Nx; i++) {
        qv[i]          = 0;     /* face j == 0 */
        qv[S * Ny + i] = 0;     /* face j == Ny */
    }

    for (int n = 0; n <= Nt; n++) {
        /* x-momentum on interior faces i = 1..Nx-1 with first-order upwind
           advection. Temporaries are declared inside the loop body so each
           thread has its own copy. */
        #pragma omp parallel for collapse(2)
        for (int j = 0; j < Ny; j++) {
            for (int i = 1; i < Nx; i++) {
                const int k = S * j + i;
                if (hh[k] > thrs || hh[k-1] > thrs) {
                    const double hf = max(thrs, hh[k] + hh[k-1]);

                    const float ufd_x = (qu[k-1] + qu[k]) / hf;
                    const float ufu_x = (qu[k+1] + qu[k]) / hf;
                    const float advx = (ufu_x < 0)*ufu_x*(U0[k+1] - U0[k])
                                     + (ufd_x > 0)*ufd_x*(U0[k] - U0[k-1]);

                    /* Rows j == 0 and j == Ny-1 use modified stencils. */
                    float ufd_y, ufu_y, advy;
                    if (j == 0) {
                        ufd_y = qv[k] / max(thrs, hh[k]);
                        ufu_y = (qv[k] + qv[k+S]) / hf;
                        advy = (ufu_y < 0)*ufu_y*(U0[k+S] - U0[k])
                             + (ufd_y > 0)*ufd_y*(U0[k]);
                    } else if (j == Ny-1) {
                        ufd_y = (qv[k] + qv[k-S]) / hf;
                        ufu_y = qv[k] / max(thrs, hh[k]);
                        advy = (ufu_y < 0)*ufu_y*(-U0[k])
                             + (ufd_y > 0)*ufd_y*(U0[k] - U0[k-S]);
                    } else {
                        ufd_y = (qv[k] + qv[k-S]) / hf;
                        ufu_y = (qv[k] + qv[k+S]) / hf;
                        advy = (ufu_y < 0)*ufu_y*(U0[k+S] - U0[k])
                             + (ufd_y > 0)*ufd_y*(U0[k] - U0[k-S]);
                    }
                    U1[k] = U0[k] - dt*(g/dx*(E0[k] - E0[k-1]) + advx/dx + advy/dy);
                } else {
                    U1[k] = 0;  /* both neighbouring cells are dry */
                }
            }
        }

        /* y-momentum on interior faces j = 1..Ny-1, mirroring the x update. */
        #pragma omp parallel for collapse(2)
        for (int j = 1; j < Ny; j++) {
            for (int i = 0; i < Nx; i++) {
                const int k = S * j + i;
                if (hh[k] > thrs || hh[k-S] > thrs) {
                    const double hf = max(thrs, hh[k] + hh[k-S]);

                    const float vfd_y = (qv[k-S] + qv[k]) / hf;
                    const float vfu_y = (qv[k+S] + qv[k]) / hf;
                    const float advy = (vfu_y < 0)*vfu_y*(V0[k+S] - V0[k])
                                     + (vfd_y > 0)*vfd_y*(V0[k] - V0[k-S]);

                    /* Columns i == Nx-1 and i == 0 use modified stencils. */
                    float vfd_x, vfu_x, advx;
                    if (i == Nx-1) {
                        vfu_x = qu[k] / max(thrs, hh[k]);
                        vfd_x = (qu[k-1] + qu[k]) / hf;
                        advx = (vfu_x < 0)*vfu_x*(-V0[k])
                             + (vfd_x > 0)*vfd_x*(V0[k] - V0[k-1]);
                    } else if (i == 0) {
                        vfu_x = (qu[k+1] + qu[k]) / hf;
                        vfd_x = qu[k] / max(thrs, hh[k]);
                        advx = (vfu_x < 0)*vfu_x*(V0[k+1] - V0[k])
                             + (vfd_x > 0)*vfd_x*(V0[k]);
                    } else {
                        vfd_x = (qu[k-1] + qu[k]) / hf;
                        vfu_x = (qu[k+1] + qu[k]) / hf;
                        advx = (vfu_x < 0)*vfu_x*(V0[k+1] - V0[k])
                             + (vfd_x > 0)*vfd_x*(V0[k] - V0[k-1]);
                    }
                    V1[k] = V0[k] - dt*(g/dy*(E0[k] - E0[k-S]) + advy/dy + advx/dx);
                } else {
                    V1[k] = 0;  /* both neighbouring cells are dry */
                }
            }
        }

        /* Face fluxes qu: new velocity times the depth on the upwind side,
           chosen by the sign of the old velocity. A zero velocity uses the
           higher of the two surfaces plus the smaller of the two bed depths.
           hh is used as scratch here and rebuilt below. */
        #pragma omp parallel for collapse(2)
        for (int j = 0; j < Ny; j++) {
            for (int i = 1; i < Nx; i++) {
                const int k = S * j + i;
                if (U0[k] > 0) {
                    hh[k] = E0[k-1] + h[k-1];
                } else if (U0[k] < 0) {
                    hh[k] = E0[k] + h[k];
                } else {
                    hh[k] = max(E0[k-1], E0[k]) + min(h[k-1], h[k]);
                }
                qu[k] = U1[k] * hh[k];
            }
        }

        /* Face fluxes qv, same rule along y. */
        #pragma omp parallel for collapse(2)
        for (int j = 1; j < Ny; j++) {
            for (int i = 0; i < Nx; i++) {
                const int k = S * j + i;
                if (V0[k] > 0) {
                    hh[k] = E0[k-S] + h[k-S];
                } else if (V0[k] < 0) {
                    hh[k] = E0[k] + h[k];
                } else {
                    hh[k] = max(E0[k-S], E0[k]) + min(h[k-S], h[k]);
                }
                qv[k] = V1[k] * hh[k];
            }
        }

        /* Continuity: new surface from the flux divergence; refresh depths. */
        #pragma omp parallel for collapse(2)
        for (int j = 0; j < Ny; j++) {
            for (int i = 0; i < Nx; i++) {
                const int k = S * j + i;
                E1[k] = E0[k] - (dt/dx)*(qu[k+1] - qu[k]) - (dt/dy)*(qv[k+S] - qv[k]);
                hh[k] = E1[k] + h[k];
            }
        }

        /* New fields become the old ones. A swap suffices: E1 is rewritten over
           every cell and U1/V1 over every interior face each step, while wall
           faces are never written and stay 0 in both buffers. */
        float *tmp;
        tmp = E0; E0 = E1; E1 = tmp;
        tmp = U0; U0 = U1; U1 = tmp;
        tmp = V0; V0 = V1; V1 = tmp;
    }

    printf("\nParallel execution time: %f s\n", omp_get_wtime() - wtime);

cleanup:
    free(x);
    free(y);
    free(E0);
    free(E1);
    free(h);
    free(hh);
    free(U0);
    free(U1);
    free(qu);
    free(V0);
    free(V1);
    free(qv);
    return status;
}

