#include <stdio.h>

/*** Calculate the energy of a 2-D polymer on a lattice ****/

/*
 * energy - total nearest-neighbour interaction energy of a square lattice.
 *
 * p_grid : row-major array of dim*dim site types; 0 is solvent, 1 is polymer.
 *          Must point to at least dim*dim readable ints when dim > 0.
 * dim    : number of rows (and columns) of the lattice.
 * e_11   : energy of a polymer-polymer contact.
 * e_22   : energy of a solvent-solvent contact.
 * e_12   : energy of a polymer-solvent contact (either orientation).
 *
 * Every site is paired with its four neighbours (up, down, left, right),
 * wrapping around at the lattice edges.  Because each contact is visited
 * once from each of its two sites, the accumulated sum is halved before
 * it is returned.
 *
 * A pair in which either site holds a value other than 0 or 1 contributes
 * nothing to the energy.  Returns 0.0 when dim <= 0.
 */
double energy(int *p_grid, int dim, double e_11, double e_22, double e_12)
{
    double ene, econtrib;
    int irow, icol, ineigh;
    int site, neigh;
    int *p_point;
    int inds[4];

    ene = 0.0;

    for (irow = 0; irow < dim; irow++)
    {
        p_point = p_grid + irow*dim;
        for (icol = 0; icol < dim; icol++)
        {
            /* neighbour in the row above, wrapping to the last row */
            inds[0] = (irow-1)*dim;
            if (inds[0] < 0) inds[0] = (dim-1)*dim;
            inds[0] += icol;

            /* neighbour in the row below, wrapping to the first row */
            inds[1] = (irow+1)*dim;
            if (inds[1] >= dim*dim) inds[1] = 0;
            inds[1] += icol;

            /* neighbour in the column to the left, wrapping to the last column */
            inds[2] = icol-1;
            if (inds[2] < 0) inds[2] = dim-1;
            inds[2] += irow*dim;

            /* neighbour in the column to the right, wrapping to the first column */
            inds[3] = icol+1;
            if (inds[3] >= dim) inds[3] = 0;
            inds[3] += irow*dim;

            site = *p_point;
            for (ineigh = 0; ineigh < 4; ineigh++)
            {
                neigh = *(p_grid + inds[ineigh]);

                /*
                 * Start from zero for every bond so that a pair with an
                 * unrecognised site type adds nothing, instead of re-using
                 * the previous bond's value (or an uninitialised one).
                 */
                econtrib = 0.0;

                /*** remember 0 is solvent, 1 is polymer ***/
                if (site == 0 && neigh == 0)
                    econtrib = e_22;
                else if (site == 1 && neigh == 1)
                    econtrib = e_11;
                else if ((site == 0 && neigh == 1) || (site == 1 && neigh == 0))
                    econtrib = e_12;

                ene += econtrib;
            }

            p_point++;
        }
    }

    /*** remove double counting ***/
    return ene/2.0;
}
