/*
 * (C) Copyright 1996- ECMWF.
 *
 * This software is licensed under the terms of the Apache Licence Version 2.0
 * which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
 * In applying this licence, ECMWF does not waive the privileges and immunities
 * granted to it by virtue of its status as an intergovernmental organisation
 * nor does it submit to any jurisdiction.
 */


// TODO:
// -----
// * Fix abort encountered with
//       mpirun -np 4 atlas-conservative-interpolation --source.grid=O20 --target.grid=H8 --order=2
//
// QUESTIONS:
// ----------
// * Why is std::sqrt applied in ConservativeSphericalPolygonInterpolation, in the line
//       remap_stat.errors[Statistics::Errors::REMAP_CONS] = std::sqrt(std::abs(err_remap_cons) / unit_sphere_area());
//   which is used to compute conservation_error?


#include <cmath>
#include <fstream>
#include <map>
#include <unordered_map>

#include "eckit/geometry/Sphere.h"
#include "eckit/log/Bytes.h"
#include "eckit/log/JSON.h"
#include "eckit/types/FloatCompare.h"

#include "atlas/array.h"
#include "atlas/array/MakeView.h"
#include "atlas/field.h"
#include "atlas/grid.h"
#include "atlas/interpolation/Interpolation.h"
#include "atlas/interpolation/method/unstructured/ConservativeSphericalPolygonInterpolation.h"
#include "atlas/mesh.h"
#include "atlas/mesh/Mesh.h"
#include "atlas/mesh/actions/Build2DCellCentres.h"
#include "atlas/meshgenerator.h"
#include "atlas/option.h"
#include "atlas/output/Gmsh.h"
#include "atlas/runtime/AtlasTool.h"
#include "atlas/util/Config.h"
#include "atlas/util/function/SphericalHarmonic.h"
#include "atlas/util/function/VortexRollup.h"

#include "tests/AtlasTestEnvironment.h"


namespace atlas {


/// Command-line tool that interpolates an analytic field from a source grid to a target grid with
/// the "conservative-spherical-polygon" method, optionally reporting statistics and writing
/// gmsh / json output. Options are registered in the constructor; the work happens in execute().
class AtlasParallelInterpolation : public AtlasTool {
    int execute(const AtlasTool::Args& args) override;
    std::string briefDescription() override { return "Demonstration of parallel interpolation"; }
    std::string usage() override {
        return name() +
               " [--source.grid=gridname] "
               "[--target.grid=gridname] [OPTION]... [--help]";
    }

    int numberOfPositionalArguments() override { return -1; }
    int minimumPositionalArguments() override { return 0; }

public:
    AtlasParallelInterpolation(int argc, char* argv[]): AtlasTool(argc, argv) {
        add_option(new eckit::option::Separator("Source/Target options"));

        add_option(new SimpleOption<std::string>("source.grid", "source gridname"));
        add_option(new SimpleOption<std::string>("target.grid", "target gridname"));
        add_option(new SimpleOption<std::string>("source.functionspace",
                                                 "source functionspace, to override source grid default"));
        add_option(new SimpleOption<std::string>("target.functionspace",
                                                 "target functionspace, to override target grid default"));
        add_option(new SimpleOption<long>("source.halo", "default=2"));
        add_option(new SimpleOption<long>("target.halo", "default=0"));

        add_option(new eckit::option::Separator("Interpolation options"));
        add_option(new SimpleOption<long>("order", "Interpolation order. Supported: 1, 2 (default=1)"));
        add_option(new SimpleOption<bool>("normalise_intersections",
                                          "Normalize polygon intersections so that interpolation weights sum to 1."));
        add_option(new SimpleOption<bool>("validate",
                                          "Enable extra validations at cost of performance. For debugging purpose."));
        add_option(new SimpleOption<bool>("matrix_free", "Do not store matrix for consecutive interpolations"));

        add_option(
            new SimpleOption<bool>("statistics.intersection", "Enable extra statistics on polygon intersections"));
        add_option(new SimpleOption<bool>("statistics.accuracy",
                                          "Enable extra statistics, comparing result with initial condition"));
        add_option(
            new SimpleOption<bool>("statistics.conservation", "Enable extra statistics computing mass conservation"));

        add_option(new eckit::option::Separator("Output options"));

        add_option(new SimpleOption<bool>(
            "output-gmsh", "Output gmsh files src_mesh.msh, tgt_mesh.msh, src_field.msh, tgt_field.msh"));
        add_option(new SimpleOption<std::string>("gmsh.coordinates", "Mesh coordinates [xy,lonlat,xyz]"));
        add_option(new SimpleOption<bool>("gmsh.ghost", "output of ghost"));

        add_option(new SimpleOption<bool>("output-json", "Output json file with run information"));
        add_option(new SimpleOption<std::string>("json.file", "File path for json output"));

        add_option(new eckit::option::Separator("Initial condition options"));

        add_option(new SimpleOption<std::string>(
            "init", "Setup initial source field [ constant, spherical_harmonic, vortex_rollup (default) ]"));
        // The default stated here must agree with the fallback used by get_init() (t = 1).
        add_option(new SimpleOption<double>("vortex_rollup.t", "Value that controls vortex rollup (default = 1)"));
        add_option(new SimpleOption<double>("constant.value", "Value that is assigned in case init==constant"));
        add_option(new SimpleOption<long>("spherical_harmonic.n", "total wave number 'n' of a spherical harmonic"));
        add_option(new SimpleOption<long>("spherical_harmonic.m", "zonal wave number 'm' of a spherical harmonic"));
    }

    /// Wall-clock timers for the individual phases of execute(); their elapsed() values are
    /// written to the json output under "timings.*".
    struct Timers {
        using StopWatch = atlas::runtime::trace::StopWatch;
        StopWatch target_setup;           // target mesh, functionspace and field creation
        StopWatch source_setup;           // matching-partitioned source mesh, functionspace and field
        StopWatch initial_condition;      // evaluation of the analytic source field
        StopWatch interpolation_setup;    // construction of the Interpolation object
        StopWatch interpolation_execute;  // Interpolation::execute
    } timers;
};

/// Select the analytic initial condition requested via "--init" and return it as a callable
/// f(lon,lat) -> value.
///
/// Recognised values of "init":
///   - "vortex_rollup" (used when "init" is absent): parameter "vortex_rollup.t", fallback 1.
///   - "spherical_harmonic": parameters "spherical_harmonic.n" / ".m", fallback 2 / 2.
///   - "constant": parameter "constant.value", fallback 1.
/// Any other value is reported on Log::error() and ATLAS_NOTIMPLEMENTED is raised.
std::function<double(const PointLonLat&)> get_init(const AtlasTool::Args& args) {
    std::string choice{"vortex_rollup"};
    args.get("init", choice);

    if (choice == "vortex_rollup") {
        double rollup_time{1.};
        args.get("vortex_rollup.t", rollup_time);
        return [rollup_time](const PointLonLat& point) {
            return util::function::vortex_rollup(point.lon(), point.lat(), rollup_time);
        };
    }

    if (choice == "spherical_harmonic") {
        int total_wavenumber{2};
        int zonal_wavenumber{2};
        args.get("spherical_harmonic.n", total_wavenumber);
        args.get("spherical_harmonic.m", zonal_wavenumber);

        // Caching speeds up repeated evaluation but makes the object not thread-safe.
        const bool use_cache = true;
        util::function::SphericalHarmonic harmonic(total_wavenumber, zonal_wavenumber, use_cache);
        return [harmonic](const PointLonLat& point) { return harmonic(point.lon(), point.lat()); };
    }

    if (choice == "constant") {
        double constant{1.};
        args.get("constant.value", constant);
        return [constant](const PointLonLat&) { return constant; };
    }

    // Unknown choice: only possible when the user passed "--init" explicitly.
    if (args.has("init")) {
        Log::error() << "Bad value for \"init\": \"" << choice << "\" not recognised." << std::endl;
        ATLAS_NOTIMPLEMENTED;
    }
    ATLAS_THROW_EXCEPTION("Should not be here");
}

/// Entry point of the tool.
///
/// Steps:
///   1. Build the target mesh, functionspace and field.
///   2. Build a source mesh partitioned to match the target mesh, plus its functionspace and field.
///   3. Fill the source field from the analytic function selected by "init".
///   4. Set up and execute the "conservative-spherical-polygon" interpolation.
///   5. Depending on the options: compute statistics, print the interpolation metadata, and write
///      gmsh / json output.
///
/// @param args  parsed command-line options (registered in the constructor)
/// @return      success() once all requested steps have completed
int AtlasParallelInterpolation::execute(const AtlasTool::Args& args) {
    // Resolve options that carry a default exactly once. The values used for the run and the
    // values reported in the json output below therefore cannot disagree. (getString() without a
    // default throws when the option was not given on the command line.)
    const std::string src_gridname = args.getString("source.grid", "H16");
    const std::string tgt_gridname = args.getString("target.grid", "H32");
    const long src_halo            = args.getLong("source.halo", 2);
    const long tgt_halo            = args.getLong("target.halo", 0);

    auto src_grid = Grid{src_gridname};
    auto tgt_grid = Grid{tgt_gridname};

    // Create a functionspace on `mesh`. An empty `type` selects CellColumns for healpix and
    // cubedsphere grids and NodeColumns otherwise. CellColumns needs cell-centre "lonlat",
    // which is built on demand.
    auto create_functionspace = [&](Mesh& mesh, int halo, std::string type) -> FunctionSpace {
        if (type.empty()) {
            type = "NodeColumns";
            if (mesh.grid().type() == "healpix" || mesh.grid().type() == "cubedsphere") {
                type = "CellColumns";
            }
        }
        if (type == "CellColumns") {
            if (!mesh.cells().has_field("lonlat")) {
                mesh::actions::Build2DCellCentres{"lonlat"}(mesh);
            }
            return functionspace::CellColumns(mesh, option::halo(halo));
        }
        else if (type == "NodeColumns") {
            return functionspace::NodeColumns(mesh, option::halo(halo));
        }
        ATLAS_THROW_EXCEPTION("FunctionSpace " << type << " is not recognized.");
    };

    timers.target_setup.start();
    auto tgt_mesh          = Mesh{tgt_grid};
    auto tgt_functionspace = create_functionspace(tgt_mesh, tgt_halo, args.getString("target.functionspace", ""));
    auto tgt_field         = tgt_functionspace.createField<double>();
    timers.target_setup.stop();

    timers.source_setup.start();
    // The source mesh is distributed with a MatchingPartitioner built from the target mesh.
    // NOTE(review): the mesh-generator halo is fixed at 2, independent of --source.halo; confirm
    // this is intended.
    auto src_meshgenerator = MeshGenerator{src_grid.meshgenerator() | option::halo(2)};
    auto src_partitioner   = grid::MatchingPartitioner{tgt_mesh};
    auto src_mesh          = src_meshgenerator.generate(src_grid, src_partitioner);
    auto src_functionspace = create_functionspace(src_mesh, src_halo, args.getString("source.functionspace", ""));
    auto src_field         = src_functionspace.createField<double>();
    timers.source_setup.stop();

    {
        ATLAS_TRACE("Initial condition");
        timers.initial_condition.start();
        const auto lonlat = array::make_view<double, 2>(src_functionspace.lonlat());
        auto src_view     = array::make_view<double, 1>(src_field);
        auto f            = get_init(args);
        // Every entry of the functionspace's lonlat field is evaluated directly; the field is then
        // marked clean (presumably because halo points are covered too, so no exchange is needed).
        for (idx_t n = 0; n < lonlat.shape(0); ++n) {
            src_view(n) = f(PointLonLat{lonlat(n, LON), lonlat(n, LAT)});
        }
        src_field.set_dirty(false);
        // Must be stop(): calling start() here meant the elapsed time was never measured correctly.
        timers.initial_condition.stop();
    }


    timers.interpolation_setup.start();
    auto interpolation =
        Interpolation(option::type("conservative-spherical-polygon") | args, src_functionspace, tgt_functionspace);
    timers.interpolation_setup.stop();


    timers.interpolation_execute.start();
    auto metadata = interpolation.execute(src_field, tgt_field);
    timers.interpolation_execute.stop();

    // API not yet acceptable
    Field src_conservation_field;
    {
        using Statistics = interpolation::method::ConservativeSphericalPolygonInterpolation::Statistics;
        Statistics stats(metadata);
        if (args.getBool("statistics.accuracy", false)) {
            // Compares the target field against the same analytic initial condition.
            stats.accuracy(interpolation, tgt_field, get_init(args));
        }
        if (args.getBool("statistics.conservation", false)) {
            // compute difference field
            src_conservation_field = stats.diff(interpolation, src_field, tgt_field);
        }
    }


    Log::info() << "interpolation metadata: \n";
    {
        eckit::JSON json(Log::info(), eckit::JSON::Formatting::indent(2));
        json << metadata;
    }
    Log::info() << std::endl;

    if (args.getBool("output-gmsh", false)) {
        if (args.getBool("gmsh.ghost", false)) {
            // Target halo values are not set by the interpolation; exchange them before output.
            ATLAS_TRACE("halo exchange target");
            tgt_field.haloExchange();
        }
        util::Config config(args.getSubConfiguration("gmsh"));
        output::Gmsh{"src_mesh.msh", config}.write(src_mesh);
        output::Gmsh{"src_field.msh", config}.write(src_field);
        output::Gmsh{"tgt_mesh.msh", config}.write(tgt_mesh);
        output::Gmsh{"tgt_field.msh", config}.write(tgt_field);
        if (src_conservation_field) {
            output::Gmsh{"src_conservation_field.msh", config}.write(src_conservation_field);
        }
    }

    if (args.getBool("output-json", false)) {
        util::Config output;
        output.set("setup.source.grid", src_gridname);
        output.set("setup.target.grid", tgt_gridname);
        output.set("setup.source.functionspace", src_functionspace.type());
        output.set("setup.target.functionspace", tgt_functionspace.type());
        output.set("setup.source.halo", src_halo);
        output.set("setup.target.halo", tgt_halo);
        output.set("setup.interpolation.order", args.getInt("order", 1));
        output.set("setup.interpolation.normalise_intersections", args.getBool("normalise_intersections", false));
        output.set("setup.interpolation.validate", args.getBool("validate", false));
        // Key must match the registered option name "matrix_free" (not "matrix-free").
        output.set("setup.interpolation.matrix_free", args.getBool("matrix_free", false));
        output.set("setup.init", args.getString("init", "vortex_rollup"));

        output.set("runtime.mpi", mpi::size());
        output.set("runtime.omp", atlas_omp_get_max_threads());
        output.set("atlas.build_type", ATLAS_BUILD_TYPE);

        output.set("timings.target.setup", timers.target_setup.elapsed());
        output.set("timings.source.setup", timers.source_setup.elapsed());
        output.set("timings.initial_condition", timers.initial_condition.elapsed());
        output.set("timings.interpolation.setup", timers.interpolation_setup.elapsed());
        output.set("timings.interpolation.execute", timers.interpolation_execute.elapsed());

        output.set("interpolation", metadata);

        eckit::PathName json_filepath(args.getString("json.file", "out.json"));
        std::ostringstream ss;
        eckit::JSON json(ss, eckit::JSON::Formatting::indent(4));
        json << output;

        eckit::FileStream file(json_filepath, "w");
        std::string str = ss.str();
        file.write(str.data(), str.size());
        file.close();
    }


    return success();
}

}  // namespace atlas


int main(int argc, char* argv[]) {
    atlas::AtlasParallelInterpolation tool(argc, argv);
    return tool.start();
}
