/*=========================================================================
 *
 *  Copyright NumFOCUS
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *         https://www.apache.org/licenses/LICENSE-2.0.txt
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 *
 *=========================================================================*/

#include "itkVersorRigid3DTransform.h"
#include "itkCenteredTransformInitializer.h"
#include "itkImageRegionIterator.h"

#include <cassert>
#include <cstdlib>
#include <iostream>


namespace
{
constexpr unsigned int Dimension = 3;

// This function assumes that the center of mass of both images is the
// geometrical center.
template <typename TFixedImage, typename TMovingImage>
bool
RunTest(itk::SmartPointer<TFixedImage> fixedImage, itk::SmartPointer<TMovingImage> movingImage)
{
  using FixedImageType = TFixedImage;
  using MovingImageType = TMovingImage;

  bool pass = true;

  // Transform Type
  using TransformType = itk::VersorRigid3DTransform<double>;

  // calculate image centers
  TransformType::InputPointType fixedCenter;
  TransformType::InputPointType movingCenter;

  using ContinuousIndexType = itk::ContinuousIndex<double, Dimension>;

  const typename FixedImageType::RegionType & fixedRegion = fixedImage->GetLargestPossibleRegion();
  const typename FixedImageType::SizeType &   fixedSize = fixedRegion.GetSize();
  const typename FixedImageType::IndexType &  fixedIndex = fixedRegion.GetIndex();
  ContinuousIndexType                         fixedCenterIndex;
  for (unsigned int i = 0; i < Dimension; ++i)
  {
    assert(0 < fixedSize[i]);
    fixedCenterIndex[i] = static_cast<double>(fixedIndex[i]) + static_cast<double>(fixedSize[i] - 1) / 2.0;
  }
  fixedImage->TransformContinuousIndexToPhysicalPoint(fixedCenterIndex, fixedCenter);

  const typename MovingImageType::RegionType & movingRegion = movingImage->GetLargestPossibleRegion();
  const typename MovingImageType::SizeType &   movingSize = movingRegion.GetSize();
  const typename MovingImageType::IndexType &  movingIndex = movingRegion.GetIndex();
  ContinuousIndexType                          movingCenterIndex;
  for (unsigned int i = 0; i < Dimension; ++i)
  {
    assert(0 < movingSize[i]);
    movingCenterIndex[i] = static_cast<double>(movingIndex[i]) + static_cast<double>(movingSize[i] - 1) / 2.0;
  }
  movingImage->TransformContinuousIndexToPhysicalPoint(movingCenterIndex, movingCenter);

  TransformType::InputVectorType relativeCenter = movingCenter - fixedCenter;


  auto transform = TransformType::New();

  using InitializerType = itk::CenteredTransformInitializer<TransformType, FixedImageType, MovingImageType>;

  auto initializer = InitializerType::New();

  initializer->SetFixedImage(fixedImage);
  initializer->SetMovingImage(movingImage);
  initializer->SetTransform(transform);

  transform->SetIdentity();
  initializer->GeometryOn();
  initializer->InitializeTransform();

  std::cout << std::endl << std::endl;
  std::cout << "Testing Geometric Mode " << std::endl;
  // transform->Print( std::cout );

  const TransformType::InputPointType &   center1 = transform->GetCenter();
  const TransformType::OutputVectorType & translation1 = transform->GetTranslation();
  const TransformType::OffsetType &       offset1 = transform->GetOffset();
  const double                            tolerance = 1e-3;

  // Verfications for the Geometry Mode
  for (unsigned int k = 0; k < Dimension; ++k)
  {
    if (itk::Math::abs(center1[k] - fixedCenter[k]) > tolerance)
    {
      std::cerr << "Center differs from expected value" << std::endl;
      std::cerr << "It should be " << fixedCenter << std::endl;
      std::cerr << "but it is    " << center1 << std::endl;
      pass = false;
      break;
    }
    if (itk::Math::abs(translation1[k] - relativeCenter[k]) > tolerance)
    {
      std::cerr << "Translation differs from expected value" << std::endl;
      std::cerr << "It should be " << relativeCenter << std::endl;
      std::cerr << "but it is    " << translation1 << std::endl;
      pass = false;
      break;
    }
    if (itk::Math::abs(offset1[k] - relativeCenter[k]) > tolerance)
    {
      std::cerr << "Offset differs from expected value" << std::endl;
      std::cerr << "It should be " << relativeCenter << std::endl;
      std::cerr << "but it is    " << offset1 << std::endl;
      pass = false;
      break;
    }
  }

  transform->SetIdentity();
  initializer->MomentsOn();
  initializer->InitializeTransform();

  std::cout << std::endl << std::endl;
  std::cout << "Testing Moments Mode " << std::endl;
  // transform->Print( std::cout );

  const TransformType::InputPointType &   center2 = transform->GetCenter();
  const TransformType::OutputVectorType & translation2 = transform->GetTranslation();
  const TransformType::OffsetType &       offset2 = transform->GetOffset();

  // Verfications for the Moments Mode
  for (unsigned int k = 0; k < Dimension; ++k)
  {
    if (itk::Math::abs(center2[k] - fixedCenter[k]) > tolerance)
    {
      std::cerr << "Center differs from expected value" << std::endl;
      std::cerr << "It should be " << fixedCenter << std::endl;
      std::cerr << "but it is    " << center2 << std::endl;
      pass = false;
      break;
    }
    if (itk::Math::abs(translation2[k] - relativeCenter[k]) > tolerance)
    {
      std::cerr << "Translation differs from expected value" << std::endl;
      std::cerr << "It should be " << relativeCenter << std::endl;
      std::cerr << "but it is    " << translation2 << std::endl;
      pass = false;
      break;
    }
    if (itk::Math::abs(offset2[k] - relativeCenter[k]) > tolerance)
    {
      std::cerr << "Offset differs from expected value" << std::endl;
      std::cerr << "It should be " << relativeCenter << std::endl;
      std::cerr << "but it is    " << offset2 << std::endl;
      pass = false;
      break;
    }
  }

  return pass;
}


template <typename TImage>
void
PopulateImage(itk::SmartPointer<TImage> image)
{
  image->Allocate();
  image->FillBuffer(0);

  using ImageType = TImage;
  using RegionType = typename ImageType::RegionType;
  using SizeType = typename ImageType::SizeType;
  using IndexType = typename ImageType::IndexType;

  const RegionType & region = image->GetLargestPossibleRegion();
  const SizeType &   size = region.GetSize();
  const IndexType &  index = region.GetIndex();

  RegionType internalRegion;
  SizeType   internalSize;
  IndexType  internalIndex;

  constexpr unsigned int border = 20;

  assert(2 * border < size[0]);
  assert(2 * border < size[1]);
  assert(2 * border < size[2]);

  internalIndex[0] = index[0] + border;
  internalIndex[1] = index[1] + border;
  internalIndex[2] = index[2] + border;

  internalSize[0] = size[0] - 2 * border;
  internalSize[1] = size[1] - 2 * border;
  internalSize[2] = size[2] - 2 * border;


  internalRegion.SetSize(internalSize);
  internalRegion.SetIndex(internalIndex);

  using Iterator = itk::ImageRegionIterator<ImageType>;
  Iterator it(image, internalRegion);

  it.GoToBegin();
  while (!it.IsAtEnd())
  {
    it.Set(200);
    ++it;
  }
}

} // namespace


/**
 *  This program tests the use of the CenteredTransformInitializer class
 *
 *
 */

int
itkCenteredTransformInitializerTest(int, char *[])
{

  bool pass = true;

  std::cout << std::endl << std::endl;
  std::cout << "Running tests with itk::Image" << std::endl;
  {
    // Create Images

    using FixedImageType = itk::Image<unsigned char, Dimension>;
    using MovingImageType = itk::Image<unsigned char, Dimension>;

    using SizeType = FixedImageType::SizeType;
    using SpacingType = FixedImageType::SpacingType;
    using PointType = FixedImageType::PointType;
    using IndexType = FixedImageType::IndexType;
    using RegionType = FixedImageType::RegionType;

    SizeType size;
    size[0] = 100;
    size[1] = 100;
    size[2] = 60;

    PointType fixedOrigin;
    fixedOrigin[0] = 0.0;
    fixedOrigin[1] = 0.0;
    fixedOrigin[2] = 0.0;

    PointType movingOrigin;
    movingOrigin[0] = 29.0;
    movingOrigin[1] = 17.0;
    movingOrigin[2] = 13.0;

    SpacingType spacing;
    spacing[0] = 1.5;
    spacing[1] = 1.5;
    spacing[2] = 2.5;

    IndexType index;
    index[0] = 0;
    index[1] = 0;
    index[2] = 0;

    RegionType region;
    region.SetSize(size);
    region.SetIndex(index);


    auto fixedImage = FixedImageType::New();
    auto movingImage = MovingImageType::New();

    fixedImage->SetRegions(region);
    fixedImage->SetSpacing(spacing);
    fixedImage->SetOrigin(fixedOrigin);

    movingImage->SetRegions(region);
    movingImage->SetSpacing(spacing);
    movingImage->SetOrigin(movingOrigin);

    PopulateImage(fixedImage);
    PopulateImage(movingImage);

    pass &= RunTest(fixedImage, movingImage);
  }

  std::cout << std::endl << std::endl;
  std::cout << "Running tests with itk::Image" << std::endl;
  {
    // Create Images

    using FixedImageType = itk::Image<unsigned char, Dimension>;
    using MovingImageType = itk::Image<unsigned char, Dimension>;

    using SizeType = FixedImageType::SizeType;
    using SpacingType = FixedImageType::SpacingType;
    using PointType = FixedImageType::PointType;
    using IndexType = FixedImageType::IndexType;
    using RegionType = FixedImageType::RegionType;
    using DirectionType = FixedImageType::DirectionType;

    SizeType size;
    size[0] = 100;
    size[1] = 100;
    size[2] = 60;

    PointType fixedOrigin;
    fixedOrigin[0] = 0.0;
    fixedOrigin[1] = 0.0;
    fixedOrigin[2] = 0.0;

    PointType movingOrigin;
    movingOrigin[0] = 29.0;
    movingOrigin[1] = 17.0;
    movingOrigin[2] = 13.0;

    SpacingType spacing;
    spacing[0] = 1.5;
    spacing[1] = 1.5;
    spacing[2] = 2.5;

    IndexType fixedIndex;
    fixedIndex[0] = 0;
    fixedIndex[1] = 0;
    fixedIndex[2] = 0;

    IndexType movingIndex;
    movingIndex[0] = 10;
    movingIndex[1] = 20;
    movingIndex[2] = 30;

    RegionType fixedRegion;
    fixedRegion.SetSize(size);
    fixedRegion.SetIndex(fixedIndex);

    RegionType movingRegion;
    movingRegion.SetSize(size);
    movingRegion.SetIndex(movingIndex);

    using VersorType = itk::Versor<itk::SpacePrecisionType>;
    VersorType x;
    x.SetRotationAroundX(0.5);
    VersorType y;
    y.SetRotationAroundY(1.0);
    VersorType z;
    z.SetRotationAroundZ(1.5);

    DirectionType fixedDirection = (x * y * z).GetMatrix();
    DirectionType movingDirection = (z * y * x).GetMatrix();


    auto fixedImage = FixedImageType::New();
    auto movingImage = MovingImageType::New();

    fixedImage->SetRegions(fixedRegion);
    fixedImage->SetSpacing(spacing);
    fixedImage->SetOrigin(fixedOrigin);
    fixedImage->SetDirection(fixedDirection);

    movingImage->SetRegions(movingRegion);
    movingImage->SetSpacing(spacing);
    movingImage->SetOrigin(movingOrigin);
    movingImage->SetDirection(movingDirection);

    PopulateImage(fixedImage);
    PopulateImage(movingImage);

    pass &= RunTest(fixedImage, movingImage);
  }

  if (!pass)
  {
    std::cout << "Test FAILED." << std::endl;
    return EXIT_FAILURE;
  }

  std::cout << "Test PASSED." << std::endl;
  return EXIT_SUCCESS;
}
