/********************************************************************
 * This is a simple matrix-vector multiply, to demonstrate basic MPI
 * concepts.
 * Slightly more complicated: the vector is decomposed into blocks as well
 * as the matrix rows.
 * 1/10/05 DCS
 ********************************************************************/

#include "mpi.h"
#include <stdio.h>

/* Dimension of the square matrix and the length of the vector. */
enum { DIM = 800 };

/*
 * Distributed y = A*x. Rank 0 builds A (all ones) and x (x[i] = i), scatters
 * blocks of rows of A and blocks of x, then the x blocks are rotated around a
 * ring so every rank multiplies its rows against every block of x. The
 * per-rank partial results are gathered and printed on rank 0.
 */
int main(int argc, char *argv[] ) {

	int numprocs, rank, chunk_size, i, j, k;
	int sendto, recvfrom;
	int block, col_off;
	/* The two DIM x DIM matrices are ~2.5 MB each; give them static storage
	   so they do not overflow the process stack. */
	static int matrix[DIM][DIM];
	static int local_matrix[DIM][DIM];
	int vector[DIM];
	int local_vector[DIM];
	int result[DIM];
	int global_result[DIM];
	MPI_Status status;

	/* Initialize MPI */
	MPI_Init(&argc, &argv);
	MPI_Comm_rank(MPI_COMM_WORLD, &rank);
	MPI_Comm_size(MPI_COMM_WORLD, &numprocs);

	printf("Hello from process %d of %d \n", rank, numprocs);

	/* Every rank must get the same number of rows and the same size vector
	   block: the scatter/gather calls and the ring exchange below rely on it.
	   Refuse process counts that do not divide DIM evenly instead of silently
	   dropping the remainder. All ranks see the same numprocs, so all of them
	   take this branch together. */
	if (DIM % numprocs != 0) {
		if (rank == 0) {
			fprintf(stderr, "Number of processes (%d) must evenly divide %d\n",
			        numprocs, DIM);
		}
		MPI_Finalize();
		return 1;
	}
	chunk_size = DIM / numprocs;

	/* Partial results are accumulated with +=, so they must start at zero on
	   every rank, not only on the root. */
	for (i = 0; i < DIM; i++) {
		result[i] = 0;
	}

	if (rank == 0) { /* Only on the root task... */
		/* Initialize Matrix and Vector */
		for (i = 0; i < DIM; i++) {
			vector[i] = i;
			for (j = 0; j < DIM; j++) {
				matrix[i][j] = 1;
			}
		}
	}

	/* Distribute Matrix and Vector */
	/* Assume the matrix is too big to broadcast. Send blocks of rows to each
	   task, nrows/nprocs to each one. A "block" is chunk_size*DIM elements:
	   the number of rows to send times the size of a row. */
	MPI_Scatter(matrix, DIM * chunk_size, MPI_INT, local_matrix, DIM * chunk_size, MPI_INT, 0, MPI_COMM_WORLD);
	/* Send and receive counts must match for a collective: each rank gets
	   exactly chunk_size vector elements. */
	MPI_Scatter(vector, chunk_size, MPI_INT, local_vector, chunk_size, MPI_INT, 0, MPI_COMM_WORLD);

	/* Each processor has a chunk of rows and a chunk of the vector.
	   Multiply, then pass the vector chunk to the next rank in the ring. */
	sendto = (rank + 1) % numprocs;
	recvfrom = (rank + numprocs - 1) % numprocs;
	for (k = 0; k < numprocs; k++) {
		/* After k shifts (each rank receives from rank-1), this rank holds
		   the vector block originally scattered to rank (rank - k) mod
		   numprocs, i.e. vector elements [block*chunk_size,
		   (block+1)*chunk_size). Those line up with the same matrix columns. */
		block = (rank - k + numprocs) % numprocs;
		col_off = block * chunk_size;
		for (i = 0; i < chunk_size; i++) {
			for (j = 0; j < chunk_size; j++) {
				result[i] += local_matrix[i][col_off + j] * local_vector[j];
			}
		}
		printf("process %d sending to %d receiving from %d \n", rank, sendto, recvfrom);
		MPI_Sendrecv_replace(local_vector, chunk_size, MPI_INT, sendto, 1, recvfrom, 1, MPI_COMM_WORLD, &status);
	}

	/* Send result back to master */
	MPI_Gather(result, chunk_size, MPI_INT, global_result, chunk_size, MPI_INT, 0, MPI_COMM_WORLD);

	/* Display result */
	if (rank == 0) {
		for (i = 0; i < DIM; i++) {
			printf(" %d \t ", global_result[i]);
		}
		printf("\n");
	}

	MPI_Finalize();
	return 0;
}


