// Copyright (C) 2015 Deltares
//
// This program is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License version 2 as
// published by the Free Software Foundation.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software
// Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA

/**
 * @file rtcToolsSNOPT.cpp
 * @brief SNOPT binding for RTC-Tools
 * @author Jorn Baayen
 * @version 0.1
 * @date August 2015
 */

#include "rtcToolsSNOPT.h"
#include "piDiagInterface.h"

#include <cmath>
#include <limits>
#include <stdexcept>
#include <vector>

/**
 * @brief Construct the SNOPT binding with default solver settings.
 *
 * @param tool    RTC-Tools runtime whose problem will be optimized; the pointer is stored
 *                in the member of the same name
 * @param workDir Working directory (accepted for interface compatibility; not used)
 */
rtcToolsSNOPT::rtcToolsSNOPT(rtcToolsRuntime *tool, string /* workDir */)
	: tool(tool)
{
	// no further state to set up
}

/**
 * @brief Construct the SNOPT binding from configuration settings.
 *
 * @param tool    RTC-Tools runtime whose problem will be optimized; the pointer is stored
 *                in the member of the same name
 * @param par     Optimizer settings (accepted for interface compatibility; not used — the
 *                solver options are currently hard-coded in optimize())
 * @param workDir Working directory (accepted for interface compatibility; not used)
 */
rtcToolsSNOPT::rtcToolsSNOPT(rtcToolsRuntime *tool,
                             rtcRuntimeConfigSettings::SNOPT /* par */,
                             string /* workDir */)
	: tool(tool)
{
	// no further state to set up
}

/**
 * @brief Destructor
 * @details Nothing is released here; in particular the stored RTC-Tools pointer is not
 *          deleted by this class.
 */
rtcToolsSNOPT::~rtcToolsSNOPT()
{
	// nothing to release
}

/**
 * File-local pointer to the RTC-Tools runtime being optimized.
 *
 * The SNOPT user-function callback (usrfg_) receives no pointer to our object, so the
 * runtime is published here by rtcToolsSNOPT::optimize() right before the solve.
 * NOTE(review): being a single global, this is not safe for concurrent optimizations.
 */
static rtcToolsRuntime *tool_ = NULL;

/**
 * @brief Objective, constraint, and derivative eveluation function.
 * @details Called from SNOPT.
 */
static void usrfg_(int    *Status, int *n,    double x[],
		           int    *needF,  int *neF,  double F[],
		           int    *needG,  int *neG,  double G[],
		           char      *cu,  int *lencu,
		           int    iu[],    int *leniu,
		           double ru[],    int *lenru )
{
	try {
		double obj = tool_->simulate(tool_->getN(), x);
	
		if (*needF > 0) {
			F[0] = obj;

			tool_->eval_g(tool_->getN(), x, tool_->getM(), F + 1);
		}

		if (*needG > 0) {
			tool_->eval_grad_f(tool_->getN(), G);

			tool_->eval_jac_g(tool_->getM(), tool_->getNNZ(), NULL, NULL, G + tool_->getN());
		}

	} catch (...) {
		*Status = -1;
		
	}
}


/**
 * @brief Address of the first element of a vector, or NULL when the vector is empty.
 * @details Hands std::vector-backed buffers to the pointer-based SNOPT interface
 *          without indexing an empty vector.
 */
template <typename T>
static T *firstElement(std::vector<T> &v)
{
	return v.empty() ? NULL : &v[0];
}

/**
 * @brief Run optimization
 * @details Builds the SNOPT problem from the RTC-Tools model and calls SNOPT. The
 *          problem consists of:
 *          - the variable bounds;
 *          - the constraint-row bounds;
 *          - the sparsity pattern of a dense objective-gradient row followed by the
 *            constraint Jacobian;
 *          - the starting point, with NaNs replaced by zero.
 *          During the solve, function and derivative values are supplied by the
 *          usrfg_ callback.
 *
 *          All work buffers are std::vectors, so they are released even if an
 *          RTC-Tools call or the solve throws.
 *
 * @return 0 once SNOPT has returned (the SNOPT exit status is not inspected)
 */
int rtcToolsSNOPT::optimize()
{
	// Problem dimensions. Row 0 of F is the objective; rows 1..M are the constraints.
	int       n   = tool->getN();
	const int m   = tool->getM();
	const int nnz = tool->getNNZ();
	int       neF = 1 + m;

	// Variables: starting point, bounds, multipliers and states
	std::vector<double> x(n), xlow(n), xupp(n), xmul(n, 0.0);
	std::vector<int>    xstate(n, 0);

	// Problem functions: values, bounds, multipliers and states
	std::vector<double> F(neF, 0.0), Flow(neF), Fupp(neF), Fmul(neF, 0.0);
	std::vector<int>    Fstate(neF, 0);

	int    ObjRow = 0;
	double ObjAdd = 0;

	// No linear part (neA = 0). The arrays still get one element because SNOPT
	// requires lenA >= 1.
	int lenA = 1;
	int neA  = 0;
	std::vector<int>    iAfun(lenA, 0), jAvar(lenA, 0);
	std::vector<double> A(lenA, 0.0);

	// Sparsity pattern of G: a dense objective-gradient row, then the constraint Jacobian.
	int lenG = n + nnz;
	int neG  = lenG;
	std::vector<int> iGfun(lenG), jGvar(lenG);

	// Jacobian structure: objective row 0 depends on every variable
	for (int j = 0; j < n; j++) {
		iGfun[j] = 0;
		jGvar[j] = j;
	}

	// Jacobian structure: constraints, as reported by RTC-Tools
	tool->eval_jac_g(m, nnz, firstElement(iGfun) + n, firstElement(jGvar) + n, NULL);

	// Correct row indices for the fact that the first row of G is the objective
	for (int k = n; k < lenG; k++) {
		iGfun[k] += 1;
	}

	// The objective row is unbounded in both directions. Note that
	// numeric_limits<double>::min() is the smallest *positive* double, so the
	// lower bound must be -max().
	Flow[0] = -std::numeric_limits<double>::max();
	Fupp[0] =  std::numeric_limits<double>::max();
	// "+ 1" on the raw pointer (not &Flow[1]) stays valid when M == 0.
	tool->get_bounds_info(n, firstElement(xlow), firstElement(xupp),
	                      m, firstElement(Flow) + 1, firstElement(Fupp) + 1);

	// Starting point, with NaNs filled with zero
	tool->getInput(n, firstElement(x));
	for (int j = 0; j < n; j++) {
		if (std::isnan(x[j]))
			x[j] = 0.0;
	}

	// Set up SNOPT problem object.
	// NOTE(review): snopt keeps these pointers after this function returns. They
	// dangle once the vectors go out of scope, exactly as the delete[]s did before.
	snopt.setProbName    ("RTC"); // max 8 characters

	snopt.setProblemSize (n, neF);
	snopt.setObjective   (ObjRow, ObjAdd);
	snopt.setX           (firstElement(x), firstElement(xlow), firstElement(xupp),
	                      firstElement(xmul), firstElement(xstate));
	snopt.setF           (firstElement(F), firstElement(Flow), firstElement(Fupp),
	                      firstElement(Fmul), firstElement(Fstate));

	snopt.setA           (lenA, neA, firstElement(iAfun), firstElement(jAvar), firstElement(A));
	snopt.setG           (lenG, neG, firstElement(iGfun), firstElement(jGvar));
	snopt.setUserFun     (usrfg_);

	// TODO read options from XML
	snopt.setPrintFile    ("snopt.out");
	snopt.setIntParameter ("Derivative option", 1);
	snopt.setIntParameter ("Major iterations limit", 5000);
	snopt.setIntParameter ("Minor iterations limit", 5000);
	// Tolerances are real-valued options. Passing them through setIntParameter
	// truncated 1e-3 and 1e-1 to 0.
	snopt.setRealParameter("Major feasibility tolerance", 1e-3);
	snopt.setRealParameter("Minor feasibility tolerance", 1e-3);
	snopt.setRealParameter("Major optimality tolerance", 1e-1);
	snopt.setIntParameter ("Hessian updates", 20);
	snopt.setIntParameter ("Verify level", 3);

	// usrfg_ receives no user-data pointer and reaches the model through tool_.
	tool_ = tool;
	snopt.solve          (0);

	return 0;
}

/**
 * @brief Initialize optimizer
 * @details Nothing to prepare here; the SNOPT problem is set up in optimize().
 */
void rtcToolsSNOPT::initialize()
{
	// intentionally empty
}

/**
 * @brief Write optimization log to file
 * @details Currently a no-op. SNOPT's own print output is directed to "snopt.out" by
 *          optimize().
 *
 * @param workDir Folder in which to create output file (currently unused)
 */
void rtcToolsSNOPT::write(string /* workDir */)
{
	// intentionally empty
}
