//                                               -*- C++ -*-
/**
 * @file  NumericalMathEvaluationImplementation.cxx
 * @brief Abstract top-level class for all numerical math evaluation implementations
 *
 *  (C) Copyright 2005-2012 EDF-EADS-Phimeca
 *
 *  This library is free software; you can redistribute it and/or
 *  modify it under the terms of the GNU Lesser General Public
 *  License as published by the Free Software Foundation; either
 *  version 2.1 of the License.
 *
 *  This library is distributed in the hope that it will be useful
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 *  Lesser General Public License for more details.
 *
 *  You should have received a copy of the GNU Lesser General Public
 *  License along with this library; if not, write to the Free Software
 *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307 USA
 *
 * \author $LastChangedBy: schueller $
 * \date   $LastChangedDate: 2012-03-09 19:08:31 +0100 (Fri, 09 Mar 2012) $
 */

#include "NumericalMathEvaluationImplementation.hxx"
#include "ComposedNumericalMathEvaluationImplementation.hxx"
#include "AnalyticalNumericalMathEvaluationImplementation.hxx"
#include "Exception.hxx"
#include "PersistentObjectFactory.hxx"
#include "Full.hxx"

BEGIN_NAMESPACE_OPENTURNS

// Alias for the cache type declared inside NumericalMathEvaluationImplementation
typedef NumericalMathEvaluationImplementation::CacheType NumericalMathEvaluationImplementationCache;
TEMPLATE_CLASSNAMEINIT(NumericalMathEvaluationImplementationCache);

// File-local factory object for the cache type, created with the string "NumericalMathEvaluationImplementationCache"
// NOTE(review): presumably this is what lets the persistence layer rebuild a cache by name -- confirm against Factory
static Factory<NumericalMathEvaluationImplementationCache> RegisteredFactory_Cache("NumericalMathEvaluationImplementationCache");

/* The class-name initializations and factories below exist only for the needs of the cache. */
/* Be careful: they may interfere with definitions of the same collection types placed elsewhere. */
TEMPLATE_CLASSNAMEINIT(PersistentCollection<UnsignedLong>);
static Factory<PersistentCollection<UnsignedLong> > RegisteredFactory_alt1("PersistentCollection<UnsignedLong>");
TEMPLATE_CLASSNAMEINIT(PersistentCollection<PersistentCollection<NumericalScalar> >);
static Factory<PersistentCollection<PersistentCollection<NumericalScalar> > > RegisteredFactory_alt2("PersistentCollection<PersistentCollection<NumericalScalar> >");


// Class-name initialization for the evaluation class itself
CLASSNAMEINIT(NumericalMathEvaluationImplementation);

// File-local factory object for the evaluation class, created with the string "NumericalMathEvaluationImplementation"
static Factory<NumericalMathEvaluationImplementation> RegisteredFactory("NumericalMathEvaluationImplementation");


/* Default constructor
 * Starts with a calls counter of 0, a newly allocated cache, Full() history
 * strategies for both the input and the output, history recording switched off,
 * and description_ / parameters_ both constructed from 0 (presumably empty). */
NumericalMathEvaluationImplementation::NumericalMathEvaluationImplementation()
  : PersistentObject(),
    callsNumber_(0),
    p_cache_(new CacheType),
    inputStrategy_(Full()),
    outputStrategy_(Full()),
    isHistoryEnabled_(false),
    description_(0),
    parameters_(0)
{
  // We disable the cache by default
  p_cache_->disable();
}

/* Virtual constructor
 * Returns a new heap-allocated copy of this object, built with the copy constructor.
 * NOTE(review): a raw pointer is returned; the caller is presumably responsible for its lifetime. */
NumericalMathEvaluationImplementation * NumericalMathEvaluationImplementation::clone() const
{
  return new NumericalMathEvaluationImplementation(*this);
}


/* Comparison operator
 * NOTE(review): the argument is never inspected; this implementation always answers true,
 * whatever the state of the two objects. */
Bool NumericalMathEvaluationImplementation::operator ==(const NumericalMathEvaluationImplementation & other) const
{
  return true;
}

/* String converter
 * Detailed representation: class name, object name, description and parameters,
 * written as space-separated key=value fields. */
String NumericalMathEvaluationImplementation::__repr__() const
{
  OSS representation;
  representation << "class=" << NumericalMathEvaluationImplementation::GetClassName();
  representation << " name=" << getName();
  representation << " description=" << description_;
  representation << " parameters=" << parameters_;
  return representation;
}

/* String converter
 * Short representation: the given offset followed by the class name. */
String NumericalMathEvaluationImplementation::__str__(const String & offset) const
{
  OSS text(false);
  text << offset << "NumericalMathEvaluationImplementation";
  return text;
}

/* Description Accessor */
void NumericalMathEvaluationImplementation::setDescription(const Description & description)
{
  if (description.getSize() != getInputDimension() + getOutputDimension()) throw InvalidArgumentException(HERE) << "Error: the description must have a size of input dimension + output dimension, here size=" << description.getSize() << ", input dimension=" << getInputDimension() << ", output dimension=" << getOutputDimension();
  description_ = description;
}


/* Description Accessor
 * Returns the stored description by value. It is whatever setDescription() last
 * accepted, or the value given at construction if it was never set. */
Description NumericalMathEvaluationImplementation::getDescription() const
{
  return description_;
}

/* Input description Accessor
 * Returns one label per input component: the leading entries of the stored
 * description when it has been set, otherwise the default names x0, x1, ... */
Description NumericalMathEvaluationImplementation::getInputDescription() const
{
  Description inputDescription(0);
  if (description_.getSize() > 0)
  {
    // The description has been set: its first entries are the input labels
    const UnsignedLong dimension(getInputDimension());
    for (UnsignedLong k = 0; k < dimension; ++k) inputDescription.add(description_[k]);
  }
  else
  {
    // No description available: generate default names
    const UnsignedLong dimension(getInputDimension());
    for (UnsignedLong k = 0; k < dimension; ++k) inputDescription.add(OSS() << "x" << k);
  }
  return inputDescription;
}

/* Output description Accessor
 * Returns the entries of the stored description located after the input labels
 * when it has been set, otherwise the default names y0, y1, ... */
Description NumericalMathEvaluationImplementation::getOutputDescription() const
{
  Description outputDescription(0);
  if (description_.getSize() > 0)
  {
    // The description has been set: everything from index inputDimension onwards is an output label
    const UnsignedLong first(getInputDimension());
    const UnsignedLong last(description_.getSize());
    for (UnsignedLong k = first; k < last; ++k) outputDescription.add(description_[k]);
  }
  else
  {
    // No description available: generate default names
    const UnsignedLong dimension(getOutputDimension());
    for (UnsignedLong k = 0; k < dimension; ++k) outputDescription.add(OSS() << "y" << k);
  }
  return outputDescription;
}

/* Test for actual implementation
 * This class always answers true. */
Bool NumericalMathEvaluationImplementation::isActualImplementation() const
{
  return true;
}

/* Here is the interface that all derived classes must implement */

/* Operator ()
 * Evaluates the function on every point of a sample by delegating to the
 * point-based operator. Throws InvalidArgumentException when the sample
 * dimension differs from the input dimension. */
NumericalSample NumericalMathEvaluationImplementation::operator() (const NumericalSample & inSample) const
{
  const UnsignedLong expectedDimension(getInputDimension());
  if (inSample.getDimension() != expectedDimension)
    throw InvalidArgumentException(HERE) << "Error: the given sample has an invalid dimension. Expect a dimension " << expectedDimension << ", got " << inSample.getDimension();

  const UnsignedLong pointsNumber(inSample.getSize());
  NumericalSample result(pointsNumber, NumericalPoint(getOutputDimension()));
  // Point-by-point evaluation; the calls number is updated by the point-based operator
  for (UnsignedLong k = 0; k < pointsNumber; ++k)
  {
    result[k] = this->operator()(inSample[k]);
  }
  return result;
}


/* Operator ()
 * Evaluates the function on every value of a time series by delegating to the
 * point-based operator. Throws InvalidArgumentException when the time series
 * dimension differs from the input dimension. */
TimeSeries NumericalMathEvaluationImplementation::operator() (const TimeSeries & inTimeSeries) const
{
  const UnsignedLong expectedDimension(getInputDimension());
  if (inTimeSeries.getDimension() != expectedDimension)
    throw InvalidArgumentException(HERE) << "Error: the given time series has an invalid dimension. Expect a dimension " << expectedDimension << ", got " << inTimeSeries.getDimension();

  const UnsignedLong stepsNumber(inTimeSeries.getSize());
  TimeSeries result(stepsNumber, getOutputDimension());
  // Step-by-step evaluation; the calls number is updated by the point-based operator
  for (UnsignedLong k = 0; k < stepsNumber; ++k)
  {
    // Component 0 of each entry is copied unchanged from the input (presumably the time stamp)
    result[k][0] = inTimeSeries[k][0];
    result.getValueAtIndex(k) = this->operator()( inTimeSeries.getValueAtIndex(k) );
  }
  return result;
}


/* Switch the internal cache on
 * Declared const: only the cache object reached through p_cache_ is modified. */
void NumericalMathEvaluationImplementation::enableCache() const
{
  p_cache_->enable();
}

/* Switch the internal cache off
 * Declared const: only the cache object reached through p_cache_ is modified. */
void NumericalMathEvaluationImplementation::disableCache() const
{
  p_cache_->disable();
}

/* Tell whether the internal cache is currently enabled, as reported by the cache itself */
Bool NumericalMathEvaluationImplementation::isCacheEnabled() const
{
  return p_cache_->isEnabled();
}

/* Number of cache hits, as reported by the cache's getHits() */
UnsignedLong NumericalMathEvaluationImplementation::getCacheHits() const
{
  return p_cache_->getHits();
}

void NumericalMathEvaluationImplementation::addCacheContent(const NumericalSample& inSample, const NumericalSample& outSample)
{
  const UnsignedLong size = inSample.getSize();
  for ( UnsignedLong i = 0; i < size; ++ i ) {
    p_cache_->add( inSample[i], outSample[i] );
  }
}

/* Input points currently stored in the cache
 * Copies every key reported by the cache into a sample whose dimension is the input dimension. */
NumericalSample NumericalMathEvaluationImplementation::getCacheInput() const
{
  const PersistentCollection<CacheKeyType> keys( p_cache_->getKeys() );
  NumericalSample cachedInput(0, getInputDimension());
  const UnsignedLong keysNumber(keys.getSize());
  for (UnsignedLong k = 0; k < keysNumber; ++k)
  {
    cachedInput.add(keys[k]);
  }
  return cachedInput;
}

/* Enable the input/output history
 * Only the isHistoryEnabled_ flag is raised here; the stored history strategies are left untouched. */
void NumericalMathEvaluationImplementation::enableHistory() const
{
  isHistoryEnabled_ = true;
}

/* Disable the input/output history
 * Only the isHistoryEnabled_ flag is lowered here; the stored history strategies are left untouched. */
void NumericalMathEvaluationImplementation::disableHistory() const
{
  isHistoryEnabled_ = false;
}

/* Tell whether the input/output history flag is currently set */
Bool NumericalMathEvaluationImplementation::isHistoryEnabled() const
{
  return isHistoryEnabled_;
}

/* Reset the input/output history
 * Both strategies are replaced by fresh Full() instances, as in the default constructor.
 * The isHistoryEnabled_ flag is not modified. */
void NumericalMathEvaluationImplementation::resetHistory() const
{
  inputStrategy_ = Full();
  outputStrategy_ = Full();
}

/* Input history accessor: returns the input history strategy by value */
HistoryStrategy NumericalMathEvaluationImplementation::getInputHistory() const
{
  return inputStrategy_;
}

/* Output history accessor: returns the output history strategy by value */
HistoryStrategy NumericalMathEvaluationImplementation::getOutputHistory() const
{
  return outputStrategy_;
}

/* Gradient according to the marginal parameters
 * This implementation does not use inP: it returns a Matrix constructed from
 * parameters_.getDimension() and getOutputDimension()
 * (presumably rows x columns, zero-filled -- TODO confirm against Matrix). */
Matrix NumericalMathEvaluationImplementation::parametersGradient(const NumericalPoint & inP) const
{
  return Matrix(parameters_.getDimension(), getOutputDimension());
}

/* Parameters value and description accessor: returns the stored parameters by value */
NumericalPointWithDescription NumericalMathEvaluationImplementation::getParameters() const
{
  return parameters_;
}

/* Parameters value and description accessor: replaces the stored parameters; no validation is performed here */
void NumericalMathEvaluationImplementation::setParameters(const NumericalPointWithDescription & parameters)
{
  parameters_ = parameters;
}

/* Operator ()
 * Point-based evaluation. This version always throws NotYetImplementedException.
 * NOTE(review): derived classes are presumably expected to provide the actual evaluation. */
NumericalPoint NumericalMathEvaluationImplementation::operator() (const NumericalPoint & inP) const
{
  throw NotYetImplementedException(HERE);
}

/* Accessor for input point dimension
 * This version always throws NotYetImplementedException.
 * NOTE(review): derived classes are presumably expected to provide the actual value. */
UnsignedLong NumericalMathEvaluationImplementation::getInputDimension() const
{
  throw NotYetImplementedException(HERE);
}

/* Accessor for output point dimension
 * This version always throws NotYetImplementedException.
 * NOTE(review): derived classes are presumably expected to provide the actual value. */
UnsignedLong NumericalMathEvaluationImplementation::getOutputDimension() const
{
  throw NotYetImplementedException(HERE);
}

/* Get the i-th marginal function
 * Validates the index, then delegates to the Indices-based overload with a one-element index set.
 * Throws InvalidArgumentException when i is not smaller than the output dimension. */
NumericalMathEvaluationImplementation * NumericalMathEvaluationImplementation::getMarginal(const UnsignedLong i) const
{
  if (i >= getOutputDimension())
    throw InvalidArgumentException(HERE) << "Error: the index of a marginal function must be in the range [0, outputDimension-1]";
  const Indices singleton(1, i);
  return getMarginal(singleton);
}

/* Get the function corresponding to indices components
 * The marginal is the composition of a component-extraction function with a copy of this evaluation.
 * Throws InvalidArgumentException when indices.check() rejects the indices for the bound outputDimension-1. */
NumericalMathEvaluationImplementation * NumericalMathEvaluationImplementation::getMarginal(const Indices & indices) const
{
  const UnsignedLong fullDimension(getOutputDimension());
  if (!indices.check(fullDimension - 1))
    throw InvalidArgumentException(HERE) << "Error: the indices of a marginal function must be in the range [0, outputDimension-1] and  must be different";
  // The extraction is an analytical function from R^n to R^p, where n is the output dimension of
  // this function and p the number of selected components, with formula Yk = Xindices[k].
  // Its inputs get generated, non-ambiguous names: the output description cannot be reused
  // directly, as the names must be valid muParser identifiers.
  Description extractorInput(fullDimension);
  for (UnsignedLong k = 0; k < fullDimension; ++k)
  {
    extractorInput[k] = OSS() << "x" << k;
  }
  // Each selected component keeps its current output name and is defined by the matching input variable
  const UnsignedLong marginalDimension(indices.getSize());
  Description extractorOutput(marginalDimension);
  Description extractorFormulas(marginalDimension);
  const Description ownOutputDescription(getOutputDescription());
  for (UnsignedLong k = 0; k < marginalDimension; ++k)
  {
    const UnsignedLong selected(indices[k]);
    extractorOutput[k] = ownOutputDescription[selected];
    extractorFormulas[k] = extractorInput[selected];
  }
  const AnalyticalNumericalMathEvaluationImplementation extractor(extractorInput, extractorOutput, extractorFormulas);
  return new ComposedNumericalMathEvaluationImplementation(extractor.clone(), clone());
}

/* Get the number of calls to operator()
 * Returns the callsNumber_ counter. In the code visible here it is set to 0 by the default
 * constructor and saved/loaded with the object, but never incremented. */
UnsignedLong NumericalMathEvaluationImplementation::getCallsNumber() const
{
  return callsNumber_;
}

/* Method save() stores the object through the StorageManager
 * Saves the PersistentObject part first, then the calls number, the cache content,
 * the description and the parameters under the attribute names given below.
 * The history flag and the history strategies are not written here. */
void NumericalMathEvaluationImplementation::save(Advocate & adv) const
{
  PersistentObject::save(adv);
  adv.saveAttribute( "callsNumber_", callsNumber_ );
  // The cache object itself is saved, not the pointer to it
  adv.saveAttribute( "cache_", *p_cache_ );
  adv.saveAttribute( "description_", description_ );
  adv.saveAttribute( "parameters_", parameters_ );
}

/* Method load() reloads the object from the StorageManager
 * Reads back, in this order, the PersistentObject part, the calls number, the cache,
 * the description and the parameters, using the same attribute names as save(). */
void NumericalMathEvaluationImplementation::load(Advocate & adv)
{
  PersistentObject::load(adv);
  adv.loadAttribute( "callsNumber_", callsNumber_ );

  // The cache is read into an interface object, whose implementation then replaces the current cache
  TypedInterfaceObject<CacheType> storedCache;
  adv.loadAttribute( "cache_", storedCache );
  p_cache_ = storedCache.getImplementation();

  adv.loadAttribute( "description_", description_ );
  adv.loadAttribute( "parameters_", parameters_ );
}

END_NAMESPACE_OPENTURNS
