// -*- Mode: C++; tab-width: 2; -*-
// vi: set ts=2:
//

#include <cmath>
#include <vector>

#include <BALL/FORMAT/PDBFile.h>
#include <BALL/FORMAT/molFileFactory.h>
#include <BALL/FORMAT/commandlineParser.h>
#include <BALL/DATATYPE/options.h>
#include <BALL/KERNEL/PTE.h>
#include <BALL/DOCKING/COMMON/flexibleMolecule.h>
#include "version.h"

using namespace BALL;
using namespace std;

int main(int argc, char* argv[])
{
	CommandlineParser parpars("RMSDCalculator", "calculate RMSD between poses", VERSION, String(__DATE__), "Analysis");
	parpars.registerParameter("i", "input molecule file", INFILE, true);
	parpars.registerParameter("org", "molecule file containing the original ('true') poses", INFILE, true);
	parpars.registerParameter("o", "output molecule file", OUTFILE);
	parpars.registerFlag("quiet", "by quiet, i.e. do not print progress information");
	String man = "This tool calculates the RMSD between different conformations of the same molecule.\n\nTherefore this tool can for example be used to evaluate the different between ligands taken from co-crystal structures and their poses generated by a docking.\nMolecules may be sorted differently in the two input files; a topology hashkey will be used to match molecules to each other.\n\nOutput of this tool is a molecule file which will for each molecule contain a property-tag 'RMSD' holding the calculated RMSD value.";
	parpars.setToolManual(man);
	parpars.setSupportedFormats("i",MolFileFactory::getSupportedFormats());
	parpars.setSupportedFormats("org",MolFileFactory::getSupportedFormats());
	parpars.setSupportedFormats("o","mol2,sdf,drf");
	parpars.setOutputFormatSource("o","i");
	parpars.parse(argc, argv);

	// Retrieve coordinates of original poses
	GenericMolFile* original = MolFileFactory::open(parpars.get("org"));
	HashMap<String, list<Vector3> > original_poses;
	for (Molecule* mol = original->read(); mol; delete mol, mol = original->read())
	{
		String topology_hash;
		FlexibleMolecule::generateTopologyHash(mol, topology_hash, true);
		if (original_poses.find(topology_hash) != original_poses.end())
		{
			Log<<"[Warning:] more than one 'original' conformation for a molecule detected. Will use only the first conformation and ignore all other."<<endl;
		}
		else
		{
			list<Vector3> l;
			HashMap<String, list<Vector3> >::iterator map_it = original_poses.insert(make_pair(topology_hash, l)).first;

			for (AtomConstIterator it = mol->beginAtom(); +it; it++)
			{
				if (it->getElement().getSymbol() != "H")
				{
					map_it->second.push_back(it->getPosition());
				}
			}
		}
	}
	delete original;

	// Retrieve coordinates of input poses and calculate RMSDs
	GenericMolFile* input = MolFileFactory::open(parpars.get("i"));
	GenericMolFile* output = 0;
	String filename = parpars.get("o");
	if (filename != CommandlineParser::NOT_FOUND)
	{
		output = MolFileFactory::open(filename, ios::out, input);
	}

	double average_RMSD = 0;
	int no_mols = 0;
	int no_valid_rmsds = 0;
	bool quiet = (parpars.get("quiet")!=CommandlineParser::NOT_FOUND);

	for (Molecule* mol = input->read(); mol; delete mol, mol = input->read())
	{
		no_mols++;
		String topology_hash;
		FlexibleMolecule::generateTopologyHash(mol, topology_hash, true);

		HashMap<String, list<Vector3> >::iterator map_it = original_poses.find(topology_hash);
		if (map_it == original_poses.end())
		{
			Log<<"[Warning:] no original pose for molecule '"<<mol->getName()<<"' found, its RMSD can thus not be computed."<<endl;
			mol->setProperty("RMSD", "N/A");
		}
		else
		{
			double RMSD = 0;
			list<Vector3>::iterator list_it = map_it->second.begin();
			int no_heavy_atoms = 0;
			AtomConstIterator it = mol->beginAtom();
			for (; +it ; it++)
			{
				if (it->getElement().getSymbol() != "H" && list_it != map_it->second.end())
				{
					RMSD += pow(it->getPosition().getDistance(*list_it), 2);
					no_heavy_atoms++;
					list_it++;
				}
			}
			if (it != mol->endAtom() || list_it != map_it->second.end())
			{
				Log.error()<<"[Error:] Number of heavy atoms of input pose do not match number of heavy atoms of original pose!!"<<endl;
				return 1;
			}
			RMSD = sqrt(RMSD/no_heavy_atoms);
			mol->setProperty("RMSD", RMSD);
			average_RMSD += RMSD;
			no_valid_rmsds++;

			if (!quiet) Log << "RMSD for molecule "<<no_mols<<", '"<<mol->getName()<<"' = "<<RMSD<<endl;
		}

		if (output) *output << *mol;
	}

	average_RMSD /= no_valid_rmsds;

	Log <<endl<<"average RMSD = "<<average_RMSD<<endl<<endl;

	delete input;
	delete output;
	return 0;
}
