#define PROPAGATION_SOURCE_CODE_C_C
#include "propagation.h"

#include <stdio.h>
#include <stdlib.h>

/*
 * Call counter used by backpropagation() to decide when to print a trace of
 * the hidden-layer deltas.  It is incremented once per backpropagation()
 * call and reset to 0 right after a trace has been printed.
 */
int output_flag;
/*
 * Value of output_flag at which backpropagation() prints its trace.
 * NOTE(review): output_flag has no initializer (so it starts at 0) and this
 * file only increments it or resets it to 0, so -1 looks like a way of
 * keeping the trace switched off -- confirm that no other translation unit
 * assigns output_flag.
 */
#define OUT_PUT_ITERAION  -1
/*
 * Activation function of a unit.
 *
 * x           - net input of the unit
 * controlInfo - supplies activation_choice and the steepness BETA
 *
 * Returns tanh(BETA*x) when activation_choice is 1; for every other choice
 * returns the logistic function 1/(1+exp(-2*BETA*x)).
 */
double g(double x, CONTROL_INFO *controlInfo)
{
  if (controlInfo->activation_choice == 1) {
    return tanh((double)(controlInfo->BETA*x));
  }

  /* Any other choice falls back to the logistic function. */
  return 1.0/(1.0+exp(0.-2*controlInfo->BETA*x));
}

/*
 * Derivative of the activation function, written in terms of the unit's
 * OUTPUT value rather than its net input (backpropagation() passes the
 * stored activations v2[]/v3[]).
 *
 * x           - activation value, i.e. a value previously returned by g()
 * controlInfo - supplies activation_choice and the steepness BETA
 *
 * Returns BETA*(1-x*x) when activation_choice is 1, otherwise
 * 2*BETA*(1-x)*x.
 */
double g_prime(double x, CONTROL_INFO *controlInfo)
{
  if (controlInfo->activation_choice == 1) {
    return controlInfo->BETA * (1-x*x);
  }

  /* Any other choice uses the logistic derivative. */
  return 2.*controlInfo->BETA * (1.0 - x)*x;
}

/*
 * Decide the layer sizes of the network, interactively.
 *
 * fname    - network file; when it can be opened and starts with three
 *            integers, those replace the built-in defaults
 * anetwork - network whose INPUT_UNIT/HIDDEN_UNIT/OUTPUT_UNIT are set
 *
 * Nothing is done when anetwork->INPUT_UNIT is already positive.  Otherwise
 * the sizes default to 3/4/1, may be overridden from the file, and are then
 * shown to the user, who can replace them repeatedly until answering
 * 'n'/'N'.  The dialogue also ends when standard input is exhausted, so a
 * closed or redirected stdin cannot make the loop spin forever.
 */
void init_mlp_archit(char *fname, MLP_NETWORK *anetwork)
{
  int in_units, hid_units, out_units;
  int parsed;
  FILE *IN;
  char option[256];

  if (anetwork->INPUT_UNIT > 0) return;

  anetwork->INPUT_UNIT = 3;
  anetwork->HIDDEN_UNIT = 4;
  anetwork->OUTPUT_UNIT = 1;

  IN = fopen(fname,"rb");
  if (IN != NULL) {
    if (fscanf(IN,"%d%d%d", &in_units,&hid_units,&out_units) == 3) {
      anetwork->INPUT_UNIT = in_units;
      anetwork->HIDDEN_UNIT = hid_units;
      anetwork->OUTPUT_UNIT = out_units;
    }
    fclose(IN);
  }

  for (;;) {
    printf("Now the network has:\n\t%d units in input layer\n",
           anetwork->INPUT_UNIT);
    printf("\t%d in hidden layer\n",
           anetwork->HIDDEN_UNIT);
    printf("\t%d in output layer.\n",anetwork->OUTPUT_UNIT);
    printf("Do you want to change the architecture (Y/n): ");
    fflush(stdout);

    /* The field width keeps the answer inside option[]; anything other
       than one converted token means stdin is finished. */
    if (scanf("%255s", option) != 1) break;

    if (option[0]=='N' || option[0] == 'n') break;
    if (option[0]=='Y' || option[0] == 'y') {
      printf("Please specify the units in input, hidden and output layers: ");
      fflush(stdout);

      /* Read into temporaries so that a partial or invalid answer cannot
         leave the architecture half-updated. */
      parsed = scanf("%d%d%d", &in_units, &hid_units, &out_units);
      if (parsed == EOF) break;
      if (parsed == 3 && in_units > 0 && hid_units > 0 && out_units > 0) {
        anetwork->INPUT_UNIT = in_units;
        anetwork->HIDDEN_UNIT = hid_units;
        anetwork->OUTPUT_UNIT = out_units;
      } else {
        /* An unparsable token is still in stdin; the next %255s read
           consumes it, so the dialogue simply asks again. */
        fprintf(stderr,
                "Invalid layer sizes; keeping the current architecture.\n");
      }
    }
  }
  return;
}

/*
 * Release every buffer owned by the network.
 *
 * The weight and weight-change matrices are released through Free_Array()
 * with the same row counts that allocate_mlp_network() uses; the per-layer
 * vectors are released with free().  Each pointer is reset to NULL after it
 * is released so that the structure does not keep dangling pointers.
 */
void free_mlp_network(MLP_NETWORK *anetwork)
{
  Free_Array((void **)anetwork->w12, anetwork->HIDDEN_UNIT);
  anetwork->w12 = NULL;
  Free_Array((void **)anetwork->w23, anetwork->OUTPUT_UNIT);
  anetwork->w23 = NULL;
  Free_Array((void **)anetwork->deltaw12, anetwork->HIDDEN_UNIT);
  anetwork->deltaw12 = NULL;
  Free_Array((void **)anetwork->deltaw23, anetwork->OUTPUT_UNIT);
  anetwork->deltaw23 = NULL;

  free(anetwork->delta2);
  anetwork->delta2 = NULL;
  free(anetwork->delta3);
  anetwork->delta3 = NULL;
  free(anetwork->v1);
  anetwork->v1 = NULL;
  free(anetwork->v2);
  anetwork->v2 = NULL;
  free(anetwork->v3);
  anetwork->v3 = NULL;
  free(anetwork->v_out);
  anetwork->v_out = NULL;
  return;
}

/* Report that the buffer called `what` could not be allocated and stop. */
static void mlp_alloc_failed(const char *what)
{
  fprintf(stderr, "Out of memory while allocating %s.\n", what);
  exit(EXIT_FAILURE);
}

/* Allocate `count` doubles; a NULL result is only accepted for count 0. */
static double *mlp_alloc_vector(int count, const char *what)
{
  double *vec = (double *)malloc(sizeof(double) * count);

  if (vec == NULL && count != 0) mlp_alloc_failed(what);
  return vec;
}

/*
 * Allocate a rows x cols matrix of doubles through Allocate_Array().
 * NOTE(review): assumes Allocate_Array() reports failure by returning
 * NULL -- confirm against its definition.
 */
static double **mlp_alloc_matrix(int rows, int cols, const char *what)
{
  double **mat = (double **)Allocate_Array(rows, cols * sizeof(double));

  if (mat == NULL && rows != 0) mlp_alloc_failed(what);
  return mat;
}

/*
 * Allocate all buffers of the network for the layer sizes currently stored
 * in anetwork (INPUT_UNIT, HIDDEN_UNIT, OUTPUT_UNIT).
 *
 * The input and hidden layers get one extra slot for the constant bias
 * unit, which is why w12/deltaw12 have INPUT_UNIT+1 columns, w23/deltaw23
 * have HIDDEN_UNIT+1 columns, and v1/v2/delta2 have one extra element.
 *
 * The function returns nothing, so a failed allocation cannot be reported
 * to the caller; instead a message is printed to stderr and the program
 * exits, rather than continuing with NULL buffers.
 */
void allocate_mlp_network(MLP_NETWORK *anetwork)
{
  anetwork->w12 = mlp_alloc_matrix(anetwork->HIDDEN_UNIT,
                                   anetwork->INPUT_UNIT+1, "w12");
  anetwork->w23 = mlp_alloc_matrix(anetwork->OUTPUT_UNIT,
                                   anetwork->HIDDEN_UNIT+1, "w23");
  anetwork->deltaw12 = mlp_alloc_matrix(anetwork->HIDDEN_UNIT,
                                        anetwork->INPUT_UNIT+1, "deltaw12");
  anetwork->deltaw23 = mlp_alloc_matrix(anetwork->OUTPUT_UNIT,
                                        anetwork->HIDDEN_UNIT+1, "deltaw23");

  anetwork->delta2 = mlp_alloc_vector(anetwork->HIDDEN_UNIT+1, "delta2");
  anetwork->delta3 = mlp_alloc_vector(anetwork->OUTPUT_UNIT, "delta3");
  anetwork->v1 = mlp_alloc_vector(anetwork->INPUT_UNIT+1, "v1");
  anetwork->v2 = mlp_alloc_vector(anetwork->HIDDEN_UNIT+1, "v2");
  anetwork->v3 = mlp_alloc_vector(anetwork->OUTPUT_UNIT, "v3");
  anetwork->v_out = mlp_alloc_vector(anetwork->OUTPUT_UNIT, "v_out");
  return;
}

/*
 * Give the network its starting weights.
 *
 * Seeds the drand48() generator with controlInfo->seed, then fills w12 and
 * afterwards w23 (row by row, including the bias column) with values drawn
 * as (drand48()-0.5)*init_weight_scale.  The matching weight-change
 * entries in deltaw12/deltaw23 are cleared to 0.
 */
void init_mlp_network(MLP_NETWORK *anetwork, CONTROL_INFO *controlInfo)
{
  int row, col;
  double *weights;
  double *changes;

  srand48(controlInfo->seed);

  /* input -> hidden weights; column INPUT_UNIT is the bias weight */
  for (row = 0; row < anetwork->HIDDEN_UNIT; row++) {
    weights = anetwork->w12[row];
    changes = anetwork->deltaw12[row];
    for (col = 0; col < (anetwork->INPUT_UNIT+1); col++) {
      weights[col] = (drand48()-0.5)*controlInfo->init_weight_scale;
      changes[col] = 0.0;
    }
  }

  /* hidden -> output weights; column HIDDEN_UNIT is the bias weight */
  for (row = 0; row < anetwork->OUTPUT_UNIT; row++) {
    weights = anetwork->w23[row];
    changes = anetwork->deltaw23[row];
    for (col = 0; col < (anetwork->HIDDEN_UNIT+1); col++) {
      weights[col] = (drand48()-0.5)*controlInfo->init_weight_scale;
      changes[col] = 0.0;
    }
  }
  return;
}

/*
 * Write the network to a text file.
 *
 * fname    - path of the file to create or overwrite
 * anetwork - network whose layer sizes and weights are written
 *
 * The file starts with the three layer sizes, followed by the rows of w12
 * and then the rows of w23, one matrix row per line.
 *
 * Returns 0 on success.  Returns -1, after printing a message to stderr,
 * when the file cannot be opened or when writing/closing it fails.
 *
 * NOTE(review): "%lf" prints six digits after the decimal point, so
 * weights do not round-trip exactly through save/load.  The format is left
 * unchanged here in case other tools read these files.
 */
int save_mlp_network(char *fname, MLP_NETWORK *anetwork)
{
  int i,j;
  int failed;
  FILE *OUT;

  OUT = fopen(fname,"wb");
  if (OUT == NULL) {
    fprintf(stderr,"Cannot open network file \"%s\" for writing.\n", fname);
    return -1;
  }

  fprintf(OUT,"%d %d %d\n",
          anetwork->INPUT_UNIT,
          anetwork->HIDDEN_UNIT,
          anetwork->OUTPUT_UNIT);

  for (i=0; i < anetwork->HIDDEN_UNIT; i++) {
    for (j=0; j < (anetwork->INPUT_UNIT+1); j++) {
      fprintf(OUT,"%lf ", anetwork->w12[i][j]);
    }
    fprintf(OUT,"\n");
  }
  fprintf(OUT,"\n\n");
  for (i =0; i < anetwork->OUTPUT_UNIT; i++) {
    for (j=0; j < (anetwork->HIDDEN_UNIT+1); j++) {
      fprintf(OUT,"%lf ", anetwork->w23[i][j]);
    }
    fprintf(OUT,"\n");
  }

  fprintf(OUT,"\n\n");

  /* The stream error flag records any failed write above; fclose() can
     still fail while flushing the remaining buffered data. */
  failed = ferror(OUT);
  if (fclose(OUT) != 0) failed = 1;
  if (failed) {
    fprintf(stderr,"Error while writing network file \"%s\".\n", fname);
    return -1;
  }

  printf("Network weights are saved in \"%s\"\n",
         fname);
  return 0;
}


/*
 * Read network weights from a file written by save_mlp_network().
 *
 * fname    - path of the network file
 * anetwork - already allocated network; its layer sizes must match the
 *            three sizes stored at the start of the file
 *
 * Returns 0 on success.  Returns -1, after printing a message to stderr,
 * when the file cannot be opened, its header cannot be parsed, the stored
 * sizes differ from the current architecture, or a weight is missing or
 * malformed.  In the last case the weights read before the failure have
 * already been stored in the network.
 */
int load_mlp_network(char *fname, MLP_NETWORK *anetwork)
{
  int i,j, k;
  FILE *IN;

  IN = fopen(fname,"rb");
  if (IN == NULL) {
    fprintf(stderr,"Cannot open network file \"%s\".\n", fname);
    return -1;
  }

  /* Without this check i, j and k would be compared while uninitialized
     whenever the header is missing or malformed. */
  if (fscanf(IN,"%d%d%d", &i,&j,&k) != 3) {
    fprintf(stderr,"Cannot read the network architecture from \"%s\".\n",
            fname);
    fclose(IN);
    return -1;
  }

  if (i != anetwork->INPUT_UNIT ||
      j != anetwork->HIDDEN_UNIT ||
      k != anetwork->OUTPUT_UNIT) {
    fprintf(stderr,"The network in \"%s\" is not consistent ", fname);
    fprintf(stderr,"network architecture.\n");
    fprintf(stderr,"In the file, the network is (%d,%d,%d) ",
            i,j,k);
    fprintf(stderr,"while current one is (%d,%d,%d).\n",
            anetwork->INPUT_UNIT,
            anetwork->HIDDEN_UNIT,
            anetwork->OUTPUT_UNIT);
    fclose(IN);
    return -1;
  }

  for (i=0; i < anetwork->HIDDEN_UNIT; i++) {
    for (j=0; j < (anetwork->INPUT_UNIT+1); j++) {
      if (fscanf(IN,"%lf ", &(anetwork->w12[i][j])) != 1) goto bad_weights;
    }
  }

  for (i =0; i < anetwork->OUTPUT_UNIT; i++) {
    for (j=0; j < (anetwork->HIDDEN_UNIT+1); j++) {
      if (fscanf(IN,"%lf ", &(anetwork->w23[i][j])) != 1) goto bad_weights;
    }
  }
  fclose(IN);
  printf("Network weights are loaded from \"%s\" successfully.\n",
         fname);
  return 0;

bad_weights:
  /* A short or corrupt file must not be reported as a successful load. */
  fprintf(stderr,
          "Network file \"%s\" is truncated or contains invalid weights.\n",
          fname);
  fclose(IN);
  return -1;
}

/*
 * Run one training step on the pattern currently stored in the network.
 *
 * The caller provides the input in v1[0..INPUT_UNIT-1] and the target in
 * v_out[0..OUTPUT_UNIT-1].  The function
 *   1. propagates the input forward through both layers (the last slot of
 *      v1 and of v2 is set to 1.0 and acts as the bias unit),
 *   2. computes the output deltas and then the hidden deltas, the latter
 *      using the hidden->output weights as they were BEFORE this update,
 *   3. updates both weight matrices with learning rate ETA and momentum
 *      ALPHA applied to the previous weight change.
 *
 * Returns the sum over the output units of |target - output| as measured
 * before the weights were changed.
 */
double backpropagation(MLP_NETWORK *anetwork, CONTROL_INFO *controlInfo)
{
  const int n_in = anetwork->INPUT_UNIT;
  const int n_hid = anetwork->HIDDEN_UNIT;
  const int n_out = anetwork->OUTPUT_UNIT;
  double *in_act = anetwork->v1;
  double *hid_act = anetwork->v2;
  double *out_act = anetwork->v3;
  double abs_error = 0.0;
  double acc;
  double diff;
  double *w_row;
  double *dw_row;
  int trace;
  int h, o, k;

  output_flag++;
  trace = (output_flag == OUT_PUT_ITERAION);

  /* forward pass: input -> hidden, bias input stored in the last slot */
  in_act[n_in] = 1.0;
  for (h = 0; h < n_hid; h++) {
    w_row = anetwork->w12[h];
    acc = 0.0;
    for (k = 0; k <= n_in; k++) {
      acc += w_row[k] * in_act[k];
    }
    hid_act[h] = g(acc, controlInfo);
  }

  /* forward pass: hidden -> output, bias unit stored in the last slot */
  hid_act[n_hid] = 1.0;
  for (o = 0; o < n_out; o++) {
    w_row = anetwork->w23[o];
    acc = 0.0;
    for (k = 0; k <= n_hid; k++) {
      acc += w_row[k] * hid_act[k];
    }
    out_act[o] = g(acc, controlInfo);
  }

  /* output deltas and the absolute error of this pattern */
  for (o = 0; o < n_out; o++) {
    diff = anetwork->v_out[o] - out_act[o];
    abs_error += fabs(diff);
    anetwork->delta3[o] = g_prime(out_act[o], controlInfo) * diff;
  }

  if (trace) {
    printf("Hidden layer: ");
  }

  /* hidden deltas: must be computed before w23 is modified below */
  for (h = 0; h < n_hid; h++) {
    acc = 0.0;
    for (o = 0; o < n_out; o++) {
      acc += anetwork->delta3[o] * anetwork->w23[o][h];
    }
    acc *= g_prime(hid_act[h], controlInfo);
    anetwork->delta2[h] = acc;
    if (trace) {
      printf("%12.10f(%6.4f) ", anetwork->delta2[h],
             anetwork->w23[0][h]);
    }
  }

  if (trace) {
    printf("\n");
    output_flag = 0;
  }

  /* weight update: hidden -> output (gradient step plus momentum) */
  for (o = 0; o < n_out; o++) {
    w_row = anetwork->w23[o];
    dw_row = anetwork->deltaw23[o];
    for (k = 0; k <= n_hid; k++) {
      dw_row[k] = controlInfo->ETA * anetwork->delta3[o] * hid_act[k] +
        controlInfo->ALPHA * dw_row[k];
      w_row[k] += dw_row[k];
    }
  }

  /* weight update: input -> hidden (gradient step plus momentum) */
  for (h = 0; h < n_hid; h++) {
    w_row = anetwork->w12[h];
    dw_row = anetwork->deltaw12[h];
    for (k = 0; k <= n_in; k++) {
      dw_row[k] = controlInfo->ETA * anetwork->delta2[h] * in_act[k] +
        controlInfo->ALPHA * dw_row[k];
      w_row[k] += dw_row[k];
    }
  }

  return abs_error;
}


/*
 * Evaluate the network on the pattern stored in v1[0..INPUT_UNIT-1]
 * without changing any weight.
 *
 * The last slot of v1 and of v2 is set to 1.0 and acts as the bias unit.
 * On return v2[0..HIDDEN_UNIT-1] holds the hidden activations and
 * v3[0..OUTPUT_UNIT-1] the network outputs.
 */
void forward_prop_only(MLP_NETWORK *anetwork, CONTROL_INFO *controlInfo)
{
  const int n_in = anetwork->INPUT_UNIT;
  const int n_hid = anetwork->HIDDEN_UNIT;
  const int n_out = anetwork->OUTPUT_UNIT;
  double *in_act = anetwork->v1;
  double *hid_act = anetwork->v2;
  double *w_row;
  double acc;
  int unit, k;

  /* input -> hidden */
  in_act[n_in] = 1.0;
  for (unit = 0; unit < n_hid; unit++) {
    w_row = anetwork->w12[unit];
    acc = 0.0;
    for (k = 0; k <= n_in; k++) {
      acc += w_row[k] * in_act[k];
    }
    hid_act[unit] = g(acc, controlInfo);
  }

  /* hidden -> output */
  hid_act[n_hid] = 1.0;
  for (unit = 0; unit < n_out; unit++) {
    w_row = anetwork->w23[unit];
    acc = 0.0;
    for (k = 0; k <= n_hid; k++) {
      acc += w_row[k] * hid_act[k];
    }
    anetwork->v3[unit] = g(acc, controlInfo);
  }

  return;
}