/*****************************************************************************
 This code is available for academic use
 under the LESSER GENERAL PUBLIC LICENSE 

 Weight Perturbation: enhancing local search strategies
 by perturbing the weights of training instances.
 Copyright (C) 2002  Gal Elidan, Matan Ninio, Nir Friedman and Dale Schuurmans

 This library is free software; you can redistribute it and/or
 modify it under the terms of the GNU Lesser General Public
 License as published by the Free Software Foundation; either
 version 2.1 of the License, or (at your option) any later version.
 
 This library is distributed in the hope that it will be useful,
 but WITHOUT ANY WARRANTY; without even the implied warranty of
 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 Lesser General Public License for more details.
 
 You should have received a copy of the GNU Lesser General Public
 License along with this library; if not, write to the Free Software
 Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
 Please cite using this refrence:
 
@incollection{Elidan+al:2002,
   author = "Gal Elidan and Matan Ninio and Nir Friedman and Dale Schuurmans",
   booktitle = "Proc. National Conference on Artificial Intelligence (AAAI-02)",
   pages = "132-139",
   year = "2002",
   title = "Data Perturbation for Escaping Local Maxima in Learning",
 }
 
 You can contact the authors at annealing@cs.huji.ac.il
 
*****************************************************************************/

#ifndef __WeightUpdate_h
#define __WeightUpdate_h

#include <cstddef>
#include <vector>

/** Base interface for weight annealing.

    An instance is constructed with the original weights, the initial and
    stopping temperatures, and the cooling factor, and it handles the
    annealing process. The class should be used as follows:

    // initialize the weight update object (e.g. WU)
    //
    // annealing loop
    while ( !WU.IsDone() ) {
       WU.CoolDown();
       NewWeights = WU.GetWeights();
       // do learning
       //
    }

    CoolDown() must be called BEFORE GetWeights() to work properly.
    After the temperature goes below the stopping temperature, one more
    "frozen" (temp=0.0) iteration is performed. IsDone() returns true
    after the temperature has been set to 0.0 and the weights were
    retrieved.

    Concrete subclasses provide the perturbation rule by implementing
    UpdateWeights(). The destructor is virtual so that a derived object
    may be safely destroyed through a tWeightUpdate pointer.

    Written by: Gal Elidan
*/
class tWeightUpdate {
 public:
  /** Constructor.
      @param origWeights  original weight vector (returned by OrigWeights())
      @param startTemp    initial temperature
      @param stopTemp     stopping temperature (see GetStopTemperature())
      @param factor       cooling factor used by CoolDown()
  */
  tWeightUpdate(const std::vector<double>& origWeights,
                double startTemp, double stopTemp, double factor);

  /// Virtual destructor: this is a polymorphic base class.
  virtual ~tWeightUpdate() {}

  /** Cools the temperature by the cooling factor.
      @param gradient optional gradient vector (NULL when not used);
                      presumably forwarded to UpdateWeights() -- see the .cpp
  */
  void CoolDown(std::vector<double> const* gradient = NULL);

  /** Returns the last weight vector.
      Non-const: per the class contract, retrieving the weights is part of
      what makes IsDone() become true after the frozen iteration.
  */
  std::vector<double> const& GetWeights();

  /// Returns the original weight vector.
  std::vector<double> const& OrigWeights() const { return _origWeights; }

  /** Returns true when annealing is done. This happens after the
      temperature has reached 0 and GetWeights() was called.
  */
  bool IsDone() const { return _done; }

  /// Returns the current temperature.
  double GetTemp() const { return _temp; }
  /// Returns the cooling factor.
  double GetFactor() const { return _factor; }
  /// Returns the stopping temperature.
  double GetStopTemperature() const { return _stopTemp; }

 protected:
  /// Returns a non-const reference to the current weights, for subclasses.
  std::vector<double>& CurrentWeights() { return _currWeights; }
  /// Actual updating of the weights; implemented by each subclass.
  virtual void UpdateWeights(const std::vector<double>* gradient = NULL) = 0;

 private:
  std::vector<double> _origWeights;
  std::vector<double> _currWeights;
  double _startTemp;
  double _stopTemp;
  double _factor;
  double _temp;
  bool _done;
};

/** This class implements the interface of tWeightUpdate by randomly
    updating the weights, centered around the original weights, in
    proportion to the temperature.

    Written by: Gal Elidan
 */
class tWeightUpdateRandom : public tWeightUpdate {
 public:
  /// Constructor; takes the same arguments as tWeightUpdate's constructor.
  tWeightUpdateRandom(const std::vector<double>& origWeights,
                      double startTemp, double stopTemp, double factor);
  /// Destructor.
  ~tWeightUpdateRandom() {}
 protected:
  /// Updates the weights; the gradient argument is ignored by this rule.
  virtual void UpdateWeights(const std::vector<double>* gradient);
};

/** This class implements the interface of tWeightUpdate by doing a
    gradient update that is governed by the temperature and by damping
    factors with respect to the original and previous weights.
    The learning rate determines the magnitude of the update.

    Written by: Gal Elidan
 */
class tWeightUpdateGradient : public tWeightUpdate {
 public:
  /// Constructor; takes the same arguments as tWeightUpdate's constructor.
  tWeightUpdateGradient(const std::vector<double>& origWeights,
                        double startTemp, double stopTemp, double factor);
  /// Destructor.
  ~tWeightUpdateGradient() {}

  /// Returns the learning rate.
  double GetLearningRate() const { return _lRate; }
  /// Sets the learning rate for the multiplicative update.
  void SetLearningRate(double lRate) { _lRate = lRate; }
  /// Returns the damping factor with respect to the original weights.
  double GetOrigDamp() const { return _origDamp; }
  /// Sets the damping factor with respect to the original weights.
  void SetOrigDamp(double origDamp) { _origDamp = origDamp; }
  /// Returns the damping factor with respect to the previous weights.
  double GetPrevDamp() const { return _prevDamp; }
  /// Sets the damping factor with respect to the previous weights.
  void SetPrevDamp(double prevDamp) { _prevDamp = prevDamp; }

 protected:
  /// Updates the weights using the supplied gradient.
  virtual void UpdateWeights(const std::vector<double>* gradient);

 private:
  double _lRate;     // learning rate
  double _origDamp;  // damping factor w.r.t. original weights
  double _prevDamp;  // damping factor w.r.t. previous weights
};

#endif












