/*****************************************************************************
 This code is available for academic use
 under the LESSER GENERAL PUBLIC LICENSE 

 Weight Perturbation: enhancing local search strategies
 by perturbing the weights of training instances.
 Copyright (C) 2002  Gal Elidan, Matan Ninio, Nir Friedman and Dale Schuurmans

 This library is free software; you can redistribute it and/or
 modify it under the terms of the GNU Lesser General Public
 License as published by the Free Software Foundation; either
 version 2.1 of the License, or (at your option) any later version.
 
 This library is distributed in the hope that it will be useful,
 but WITHOUT ANY WARRANTY; without even the implied warranty of
 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 Lesser General Public License for more details.
 
 You should have received a copy of the GNU Lesser General Public
 License along with this library; if not, write to the Free Software
 Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
 Please cite using this reference:
 
@incollection{Elidan+al:2002,
   author = "Gal Elidan and Matan Ninio and Nir Friedman and Dale Schuurmans",
   booktitle = "Proc. National Conference on Artificial Intelligence (AAAI-02)",
   pages = "132-139",
   year = "2002",
   title = "Data Perturbation for Escaping Local Maxima in Learning",
 }
 
 You can contact the authors at annealing@cs.huji.ac.il
 
*****************************************************************************/

#include "WeightUpdate.h"
#include "RandomProb.h"
#include <math.h>

// Lower bound applied to every freshly computed weight before the
// weights are re-normalized (see the UpdateWeights implementations below).
// Numerically this is one millionth.
const double TINY = 1.0e-6;

// Gives read-only access to the current weight vector.
//
// Side effect: if the temperature is exactly 0.0 at the time of the call
// (the value CoolDown() assigns once the schedule drops below the stop
// temperature), the done flag is raised before returning.
vector<double> const& tWeightUpdate::GetWeights()
{
  const bool scheduleExhausted = ( _temp == 0.0 );
  if ( scheduleExhausted ) {
    _done = true;
  }
  return _currWeights;
}

// Sets up a cooling schedule for a set of instance weights.
//
//   origWeights  weights to perturb; both the stored original weights and
//                the current weights are initialized from this vector
//   startTemp    temperature assigned on the first CoolDown() call (> 0)
//   stopTemp     once the temperature falls below this value, CoolDown()
//                restores the original weights (> 0)
//   factor       multiplier applied to the temperature on every CoolDown()
//                call after the first; must lie strictly between 0 and 1
//
// The temperature starts at -1.0, which CoolDown() treats as
// "schedule not started yet"; the done flag starts cleared.
tWeightUpdate::tWeightUpdate(const vector<double>& origWeights,
                             double startTemp,
                             double stopTemp,
                             double factor)
  : _origWeights(origWeights),
    _currWeights(origWeights),
    _startTemp(startTemp),
    _stopTemp(stopTemp),
    _factor(factor),
    _temp(-1.0),
    _done(false)
{
  // sanity-check the schedule parameters (debug builds only)
  assert(_startTemp>0.0 && _stopTemp>0.0 && _factor>0.0 && _factor<1.0);
}

// Advances the cooling schedule by one step.
//
//   gradient  forwarded unchanged to UpdateWeights(); may be NULL
//
// Behaviour, in order:
//   * nothing happens once IsDone() reports true;
//   * on the first call (temperature still negative) the temperature is
//     set to the start temperature and the weights are updated;
//   * on later calls the temperature is multiplied by the cooling factor;
//     if it then lies below the stop temperature it is set to 0.0 and the
//     current weights are reset to the original ones WITHOUT calling
//     UpdateWeights(); otherwise the weights are updated.
//
// NOTE(review): the default argument is spelled out on this out-of-class
// definition. If the declaration in WeightUpdate.h specifies it as well,
// conforming compilers reject the repetition - confirm against the header.
void tWeightUpdate::CoolDown(vector<double> const* gradient = NULL)
{
  if ( IsDone() )
    return;

  if ( _temp < 0.0 ) {
    // schedule has not started yet
    _temp = _startTemp;
    UpdateWeights(gradient);
    return;
  }

  _temp *= _factor;
  if ( _temp < _stopTemp ) {
    // schedule finished: drop to zero and put the original weights back
    _temp = 0.0;
    for ( vector<double>::size_type k = 0 ; k < _origWeights.size() ; ++k ) {
      _currWeights[k] = _origWeights[k];
    }
    return;
  }

  UpdateWeights(gradient);
}

// Random-perturbation variant of the weight update.
// All four arguments are handed straight to the tWeightUpdate constructor;
// this class adds no initialization of its own here.
tWeightUpdateRandom::tWeightUpdateRandom(const vector<double>& origWeights,
                                         double startTemp,
                                         double stopTemp,
                                         double factor)
  : tWeightUpdate(origWeights, startTemp, stopTemp, factor)
{
}

void tWeightUpdateRandom::UpdateWeights(const vector<double>* gradient = NULL)
{
  vector<double>& currWeights = CurrentWeights();
  vector<double> const& origWeights = OrigWeights();
  double scale = sqrt(GetTemp());
  double var = 1.0/GetTemp();
  double wsum = 0.0;
  double origsum = 0.0;
  int i;
  // go over the instances and update new weights
  for ( i=0 ; i<currWeights.size() ; i++ ) {
    double noise = 0.0;
    double weight = 0.0;
    // sampling is according to original weight
    // This assumes original weights are whole numbers!!!
    for ( int w=1 ; w <= origWeights[i] ; w++ )
	weight += _RandomProbGenerator.DblRanGamma(var) * GetTemp();
    if ( weight < TINY )
      weight = TINY;
    currWeights[i] = weight;
    origsum += origWeights[i];
    wsum += weight;
  }
  // normalize
  double fact = origsum / wsum;
  for ( i=0 ; i<currWeights.size() ; i++ ) {
    currWeights[i] *= fact;
    //cerr << "Weight " << i << "=" << currWeights[i] << endl;
  }
}

// Gradient-driven variant of the weight update.
// The schedule arguments are handed straight to the tWeightUpdate
// constructor. The three tuning members used by UpdateWeights() get fixed
// values: _lRate (multiplies the gradient term) is 2.0, while _origDamp and
// _prevDamp (numerators of the original-weight and current-weight mixing
// coefficients) are both 1.0.
tWeightUpdateGradient::tWeightUpdateGradient(const vector<double>& origWeights,
                                             double startTemp,
                                             double stopTemp,
                                             double factor)
  : tWeightUpdate(origWeights, startTemp, stopTemp, factor),
    _lRate(2.0),
    _origDamp(1.0),
    _prevDamp(1.0)
{
}

void tWeightUpdateGradient::UpdateWeights(const vector<double>* gradient = NULL)
{
  // cerr << "Updating weight for gradient\n";
  // make sure we got all the gradients
  assert(gradient->size()==OrigWeights().size());

  vector<double>& currWeights = CurrentWeights();
  vector<double> const& origWeights = OrigWeights();
  double beta = _origDamp/GetTemp();
  double gamma = _prevDamp/GetTemp();
  double sfact = 1.0/(beta+gamma);
  double bfact = beta*sfact;
  double gfact = gamma*sfact;
  double origsum = 0.0;
  double maxw = -HUGE_VAL;
  // go over the instances and update new weights
  int i;
  for ( i=0 ; i<currWeights.size() ; i++ ) {
    origsum += origWeights[i];
    double weight = bfact*log(origWeights[i]) + gfact*log(currWeights[i]) - sfact*_lRate*(*gradient)[i];
    currWeights[i] = weight;
    if ( weight > maxw )
	  maxw = weight;
  }
  // normalize
  double wsum = 0.0;
  for ( i=0 ; i<currWeights.size() ; i++ ) {
    currWeights[i] -= maxw;
    currWeights[i] = exp(currWeights[i]);
    if ( currWeights[i] < TINY )
      currWeights[i] = TINY;
    wsum += currWeights[i];
  }
  double f = origsum / wsum;
  for ( i=0 ; i<currWeights.size() ; i++ ) {
    currWeights[i] *= f;
    //cerr << "Weight " << i << "=" << prevWeights[i] << endl;
  }
}












