// Tests for solver::make_solver and serialize(solver::SolverSettings).

#include "catch2/catch.hpp"

#include <inputs.hpp>
#include <solver.hpp>

// Fills a solver::InputInfo with distinct geometry, mesh, and time values,
// passes it to solver::make_solver, and checks that each value shows up in the
// matching field of the returned settings. Every input uses a different number
// so that a value landing in the wrong field makes a CHECK fail.
TEST_CASE("from_origin", "[]") {
  solver::InputInfo ii;

  // Geometry
  // PROB_LO is built with solver::NumberArray like every other numeric entry,
  // instead of naming std::vector directly.
  ii[PROB_LO] = solver::NumberArray({-4, -5, -6});
  ii[PROB_HI] = solver::NumberArray({0.004, 0.001, 0.001});
  ii[PERIODIC] = solver::NumberArray({0, 0, 0});
  // NOTE: CSG_FILENAME is supplied but no section below asserts on it.
  ii[CSG_FILENAME] = solver::StringArray({"geometry.csg"});

  // Mesh
  ii[N_CELL] = solver::NumberArray({11, 22, 33});
  ii[BLOCKING_FACTOR] = solver::NumberArray({222});
  ii[SMALL_VOLFRAC] = solver::NumberArray({0.000999});
  ii[FABARRAY_TILE_SZ] = solver::NumberArray({1234, 4567, 7890});
  ii[PARTICLE_TILE_SZ] = solver::NumberArray({333, 444, 555});
  ii[GRID_SIZE_X] = solver::NumberArray({24});
  ii[GRID_SIZE_Y] = solver::NumberArray({25});
  ii[GRID_SIZE_Z] = solver::NumberArray({26});
  ii[PARTICLE_GRID_SIZE_X] = solver::NumberArray({94});
  ii[PARTICLE_GRID_SIZE_Y] = solver::NumberArray({95});
  ii[PARTICLE_GRID_SIZE_Z] = solver::NumberArray({96});

  // Time
  ii[DT_MAX] = solver::NumberArray({99.99});
  ii[DT_MIN] = solver::NumberArray({0.001});
  ii[MAXSTEP] = solver::NumberArray({39});
  ii[TSTOP] = solver::NumberArray({3.14});
  ii[FIXED_DT] = solver::NumberArray({1});
  ii[CFL] = solver::NumberArray({0.77});
  ii[TCOLL_RATIO] = solver::NumberArray({49.2});

  // `messages` is returned alongside the settings but is not inspected here.
  auto [sv, messages] = solver::make_solver(ii);

  SECTION(" Geometry fields ") {
    // Bind by const reference: the axes are only read, so no copy is needed.
    const auto& [xx, yy, zz] = sv.geometry.axes;
    CHECK(xx.high == 0.004);
    CHECK(xx.low == -4);
    CHECK(yy.high == 0.001);
    CHECK(yy.low == -5);
    CHECK(zz.high == 0.001);
    CHECK(zz.low == -6);
    CHECK_FALSE(xx.periodic);
    CHECK_FALSE(yy.periodic);
    CHECK_FALSE(zz.periodic);
  }
  SECTION(" Mesh fields ") {
    CHECK(sv.mesh.blocking_factor == 222);
    const auto& [mx, my, mz] = sv.mesh.axes;
    CHECK(mx.n_cell == 11);
    CHECK(my.n_cell == 22);
    CHECK(mz.n_cell == 33);
    CHECK(mx.max_grid_size == 24);
    CHECK(my.max_grid_size == 25);
    CHECK(mz.max_grid_size == 26);
    CHECK(mx.particle_max_grid_size == 94);
    CHECK(my.particle_max_grid_size == 95);
    CHECK(mz.particle_max_grid_size == 96);
    CHECK(mx.particle_max_tile_size == 333);
    CHECK(my.particle_max_tile_size == 444);
    CHECK(mz.particle_max_tile_size == 555);
    CHECK(mx.fluid_max_tile_size == 1234);
    CHECK(my.fluid_max_tile_size == 4567);
    CHECK(mz.fluid_max_tile_size == 7890);
  }

  SECTION(" Time fields ") {
    CHECK(sv.time.dt_max == 99.99);
    CHECK(sv.time.dt_min == 0.001);
    CHECK(sv.time.max_step == 39);
    CHECK(sv.time.tstop == 3.14);
    // FIXED_DT was supplied as 1, so the resulting flag must be truthy.
    CHECK(sv.time.fixed_dt);
    CHECK(sv.time.cfl == 0.77);
    CHECK(sv.time.tcoll_ratio == 49.2);
  }
}

// Sets per-axis geometry bounds, periodicity, and cell counts on a local
// SolverSettings, then compares the text produced by serialize() against the
// expected inputs-file fragment.
TEST_CASE("serialize", "[]") {
  solver::SolverSettings settings;

  // Bind each axis by reference so the assignments below write into settings.
  auto& [geom_x, geom_y, geom_z] = settings.geometry.axes;
  auto& [mesh_x, mesh_y, mesh_z] = settings.mesh.axes;

  // X axis
  geom_x.periodic = 0;
  geom_x.low = 0;
  geom_x.high = 0.004;
  mesh_x.n_cell = 11;

  // Y axis (the only periodic one)
  geom_y.periodic = 1;
  geom_y.low = 0;
  geom_y.high = 0.001;
  mesh_y.n_cell = 22;

  // Z axis
  geom_z.periodic = 0;
  geom_z.low = 0;
  geom_z.high = 0.001;
  mesh_z.n_cell = 33;

  const auto serialized = serialize(settings);

  // The expected text is compared exactly, including the leading newline and
  // the column alignment inside the raw string.
  CHECK(serialized == R"(
  geometry.is_periodic = 0 1 0
  geometry.prob_lo     = 0 0 0
  geometry.prob_hi     = 0.004 0.001 0.001
  amr.n_cell      = 11 22 33)");
}
