// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright Contributors to the Kokkos project

#include <cstdio>

#include <Kokkos_Macros.hpp>
#ifdef KOKKOS_ENABLE_EXPERIMENTAL_CXX20_MODULES
import kokkos.core;
#else
#include <Kokkos_Core.hpp>
#endif

namespace Test {

namespace {

template <class ExecSpace, class ScheduleType>
struct TestRange {
  using value_type = int;  ///< alias required for the parallel_reduce

  using view_type = Kokkos::View<value_type *, ExecSpace>;

  view_type m_flags;
  view_type result_view;

  struct VerifyInitTag {};
  struct ResetTag {};
  struct VerifyResetTag {};
  struct OffsetTag {};
  struct VerifyOffsetTag {};

  int N;
#ifndef KOKKOS_WORKAROUND_OPENMPTARGET_GCC
  static const int offset = 13;
#else
  int offset;
#endif
  TestRange(const size_t N_)
      : m_flags(Kokkos::view_alloc(Kokkos::WithoutInitializing, "flags"), N_),
        result_view(Kokkos::view_alloc(Kokkos::WithoutInitializing, "results"),
                    N_),
        N(N_) {
#ifdef KOKKOS_WORKAROUND_OPENMPTARGET_GCC
    offset = 13;
#endif
  }

  void test_for() {
    typename view_type::host_mirror_type host_flags =
        Kokkos::create_mirror_view(m_flags);

    Kokkos::parallel_for(Kokkos::RangePolicy<ExecSpace, ScheduleType>(0, N),
                         *this);

    Kokkos::parallel_for(
        Kokkos::RangePolicy<ExecSpace, ScheduleType, VerifyInitTag>(0, N),
        *this);

    Kokkos::deep_copy(host_flags, m_flags);

    int error_count = 0;
    for (int i = 0; i < N; ++i) {
      if (int(i) != host_flags(i)) ++error_count;
    }
    ASSERT_EQ(error_count, int(0));

    Kokkos::parallel_for(
        Kokkos::RangePolicy<ExecSpace, ScheduleType, ResetTag>(0, N), *this);
    Kokkos::parallel_for(
        std::string("TestKernelFor"),
        Kokkos::RangePolicy<ExecSpace, ScheduleType, VerifyResetTag>(0, N),
        *this);

    Kokkos::deep_copy(host_flags, m_flags);

    error_count = 0;
    for (int i = 0; i < N; ++i) {
      if (int(2 * i) != host_flags(i)) ++error_count;
    }
    ASSERT_EQ(error_count, int(0));

    Kokkos::parallel_for(
        Kokkos::RangePolicy<ExecSpace, ScheduleType, OffsetTag>(offset,
                                                                N + offset),
        *this);
    Kokkos::parallel_for(
        std::string("TestKernelFor"),
        Kokkos::RangePolicy<ExecSpace, ScheduleType, VerifyOffsetTag>(0, N),
        *this);

    Kokkos::deep_copy(host_flags, m_flags);

    error_count = 0;
    for (int i = 0; i < N; ++i) {
      if (i + offset != host_flags(i)) ++error_count;
    }
    ASSERT_EQ(error_count, int(0));
  }

  KOKKOS_INLINE_FUNCTION
  void operator()(const int i) const { m_flags(i) = i; }

  KOKKOS_INLINE_FUNCTION
  void operator()(const VerifyInitTag &, const int i) const {
    if (i != m_flags(i)) {
      Kokkos::printf("TestRange::test_for_error at %d != %d\n", i, m_flags(i));
    }
  }

  KOKKOS_INLINE_FUNCTION
  void operator()(const ResetTag &, const int i) const {
    m_flags(i) = 2 * m_flags(i);
  }

  KOKKOS_INLINE_FUNCTION
  void operator()(const VerifyResetTag &, const int i) const {
    if (2 * i != m_flags(i)) {
      Kokkos::printf("TestRange::test_for_error at %d != %d\n", i, m_flags(i));
    }
  }

  KOKKOS_INLINE_FUNCTION
  void operator()(const OffsetTag &, const int i) const {
    m_flags(i - offset) = i;
  }

  KOKKOS_INLINE_FUNCTION
  void operator()(const VerifyOffsetTag &, const int i) const {
    if (i + offset != m_flags(i)) {
      Kokkos::printf("TestRange::test_for_error at %d != %d\n", i + offset,
                     m_flags(i));
    }
  }

  //----------------------------------------

  void test_reduce() {
    value_type total = 0;

    Kokkos::parallel_for(Kokkos::RangePolicy<ExecSpace, ScheduleType>(0, N),
                         *this);

    Kokkos::parallel_reduce("TestKernelReduce",
                            Kokkos::RangePolicy<ExecSpace, ScheduleType>(0, N),
                            *this, total);
    // sum( 0 .. N-1 )
    ASSERT_EQ(size_t((N - 1) * (N) / 2), size_t(total));

    Kokkos::parallel_reduce(
        "TestKernelReduce_long",
        Kokkos::RangePolicy<ExecSpace, ScheduleType, long>(0, N), *this, total);
    // sum( 0 .. N-1 )
    ASSERT_EQ(size_t((N - 1) * (N) / 2), size_t(total));

    Kokkos::parallel_reduce(
        Kokkos::RangePolicy<ExecSpace, ScheduleType, OffsetTag>(offset,
                                                                N + offset),
        *this, total);
    // sum( 1 .. N )
    ASSERT_EQ(size_t((N) * (N + 1) / 2), size_t(total));
  }

  KOKKOS_INLINE_FUNCTION
  void operator()(const int i, value_type &update) const {
    update += m_flags(i);
  }

  KOKKOS_INLINE_FUNCTION
  void operator()(const OffsetTag &, const int i, value_type &update) const {
    update += 1 + m_flags(i - offset);
  }

  void test_dynamic_policy() {
    auto const N_no_implicit_capture = N;
    using policy_t =
        Kokkos::RangePolicy<ExecSpace, Kokkos::Schedule<Kokkos::Dynamic>>;
    int const concurrency = ExecSpace().concurrency();

    {
      Kokkos::View<size_t *, ExecSpace, Kokkos::MemoryTraits<Kokkos::Atomic>>
          count("Count", concurrency);
      Kokkos::View<int *, ExecSpace> a("A", N);

      Kokkos::parallel_for(
          policy_t(0, N), KOKKOS_LAMBDA(const int &i) {
            for (int k = 0; k < (i < N_no_implicit_capture / 2 ? 1 : 10000);
                 k++) {
              a(i)++;
            }
            count(ExecSpace::impl_hardware_thread_id())++;
          });

      int error = 0;
      Kokkos::parallel_reduce(
          Kokkos::RangePolicy<ExecSpace>(0, N),
          KOKKOS_LAMBDA(const int &i, value_type &lsum) {
            lsum += (a(i) != (i < N_no_implicit_capture / 2 ? 1 : 10000));
          },
          error);
      ASSERT_EQ(error, 0);

      if ((concurrency > 1) && (N > 4 * concurrency)) {
        size_t min = N;
        size_t max = 0;
        for (int t = 0; t < concurrency; t++) {
          if (count(t) < min) min = count(t);
          if (count(t) > max) max = count(t);
        }
        ASSERT_LT(min, max);

        // if ( concurrency > 2 ) {
        //  ASSERT_LT( 2 * min, max );
        //}
      }
    }

    {
      Kokkos::View<size_t *, ExecSpace, Kokkos::MemoryTraits<Kokkos::Atomic>>
          count("Count", concurrency);
      Kokkos::View<int *, ExecSpace> a("A", N);

      value_type sum = 0;
      Kokkos::parallel_reduce(
          policy_t(0, N),
          KOKKOS_LAMBDA(const int &i, value_type &lsum) {
            for (int k = 0; k < (i < N_no_implicit_capture / 2 ? 1 : 10000);
                 k++) {
              a(i)++;
            }
            count(ExecSpace::impl_hardware_thread_id())++;
            lsum++;
          },
          sum);
      ASSERT_EQ(sum, N);

      int error = 0;
      Kokkos::parallel_reduce(
          Kokkos::RangePolicy<ExecSpace>(0, N),
          KOKKOS_LAMBDA(const int &i, value_type &lsum) {
            lsum += (a(i) != (i < N_no_implicit_capture / 2 ? 1 : 10000));
          },
          error);
      ASSERT_EQ(error, 0);

      if ((concurrency > 1) && (N > 4 * concurrency)) {
        size_t min = N;
        size_t max = 0;
        for (int t = 0; t < concurrency; t++) {
          if (count(t) < min) min = count(t);
          if (count(t) > max) max = count(t);
        }
        ASSERT_LT(min, max);

        // if ( concurrency > 2 ) {
        //  ASSERT_LT( 2 * min, max );
        //}
      }
    }
  }
};

}  // namespace

// Runs TestRange::test_for for an empty range, tiny ranges and ranges of about
// a thousand entries, alternating between Static and Dynamic schedules.
TEST(TEST_CATEGORY, range_for) {
  using static_range =
      TestRange<TEST_EXECSPACE, Kokkos::Schedule<Kokkos::Static>>;
  using dynamic_range =
      TestRange<TEST_EXECSPACE, Kokkos::Schedule<Kokkos::Dynamic>>;

  // Each fixture is a temporary, so it is destroyed before the next one is
  // created.
  static_range(0).test_for();
  dynamic_range(0).test_for();

  static_range(2).test_for();
  dynamic_range(3).test_for();

  static_range(1000).test_for();
  dynamic_range(1001).test_for();
}

// Runs TestRange::test_reduce for an empty range, tiny ranges and ranges of
// about a thousand entries, alternating between Static and Dynamic schedules.
TEST(TEST_CATEGORY, range_reduce) {
  using static_range =
      TestRange<TEST_EXECSPACE, Kokkos::Schedule<Kokkos::Static>>;
  using dynamic_range =
      TestRange<TEST_EXECSPACE, Kokkos::Schedule<Kokkos::Dynamic>>;

  // Each fixture is a temporary, so it is destroyed before the next one is
  // created.
  static_range(0).test_reduce();
  dynamic_range(0).test_reduce();

  static_range(2).test_reduce();
  dynamic_range(3).test_reduce();

  static_range(1000).test_reduce();
  dynamic_range(1001).test_reduce();
}

/// Test fixture (and functor) checking that a RangePolicy carrying a
/// StaticBatchSize argument visits every index of [0, N) exactly once.
///
/// \tparam ExecSpace        execution space used for the views and policies
/// \tparam StaticBatchSize  batch-size type forwarded to the RangePolicy
template <typename ExecSpace, typename StaticBatchSize>
struct TestStaticBatchSize {
  using view_type = Kokkos::View<int *, ExecSpace>;

  view_type m_flags;      ///< per-index visit counter
  view_type result_view;  ///< allocated by the constructor; not accessed below

  // Dispatch tags, one per tagged operator() overload.
  struct AtomicAddTag {};
  struct VerifyAtomicAddTag {};

  size_t N;  ///< extent of both views and length of the tested range

  /// Allocates both views with N_ entries each, without initializing them.
  TestStaticBatchSize(const size_t N_)
      : m_flags(Kokkos::view_alloc(Kokkos::WithoutInitializing, "flags"), N_),
        result_view(Kokkos::view_alloc(Kokkos::WithoutInitializing, "results"),
                    N_),
        N(N_) {}

  /// Zeroes the counters, increments each one from a batched parallel_for and
  /// then reduces (logical AND) over "counter == 1" for every index.
  void test_batch_size() {
    // The views were allocated uninitialized, so reset the counters first.
    Kokkos::deep_copy(m_flags, 0);

    Kokkos::parallel_for(
        Kokkos::RangePolicy<ExecSpace, AtomicAddTag, StaticBatchSize>(0, N),
        *this);

    bool success = true;
    Kokkos::parallel_reduce(
        Kokkos::RangePolicy<ExecSpace, VerifyAtomicAddTag>(0, N), *this,
        Kokkos::LAnd<bool>(success));

    ASSERT_TRUE(success);
  }

  /// Atomically adds one to the counter of index i.
  KOKKOS_INLINE_FUNCTION
  void operator()(const AtomicAddTag, const int i) const {
    Kokkos::atomic_add(&m_flags(i), 1);
  }

  /// Folds "counter of index i equals 1" into `success`, printing a diagnostic
  /// for every index whose counter differs.
  KOKKOS_INLINE_FUNCTION
  void operator()(const VerifyAtomicAddTag, const int i, bool &success) const {
    const int visits = m_flags(i);
    if (visits != 1) {
      // The expected value is the constant 1, not the index, so report the
      // index and the offending counter separately.
      Kokkos::printf(
          "TestStaticBatchSize::test_batch_size_error at %d: %d != 1\n", i,
          visits);
    }
    success = success && (visits == 1);
  }
};

#ifndef KOKKOS_ENABLE_OPENMPTARGET
// Runs TestRange::test_dynamic_policy for several range lengths. The body is
// compiled out whenever the CUDA, HIP or SYCL backend is enabled, leaving an
// empty test.
TEST(TEST_CATEGORY, range_dynamic_policy) {
#if !(defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) || \
      defined(KOKKOS_ENABLE_SYCL))
  using dynamic_range =
      TestRange<TEST_EXECSPACE, Kokkos::Schedule<Kokkos::Dynamic>>;

  const size_t extents[] = {0, 3, 1001};
  for (const size_t extent : extents) {
    dynamic_range fixture(extent);
    fixture.test_dynamic_policy();
  }
#endif
}
#endif

// For 32-bit builds a View can't store enough elements
#ifndef KOKKOS_IMPL_32BIT
void test_large_parallel_for_reduce() {
  using ExecutionSpace              = typename TEST_EXECSPACE::execution_space;
  constexpr long long unsigned size = 1llu << 32;
  Kokkos::View<char *, TEST_EXECSPACE::memory_space> v(
      Kokkos::view_alloc(Kokkos::WithoutInitializing, "v"), size);

  // We want to explicitly test that using a parallel_for for filling the View
  // works to test if our internal block size calculations do not overflow.
  Kokkos::parallel_for(
      Kokkos::RangePolicy<ExecutionSpace,
                          Kokkos::IndexType<long long unsigned>>(0, size),
      KOKKOS_LAMBDA(long long unsigned i) { v(i) = 1; });

  long long unsigned sum;
  Kokkos::parallel_reduce(
      Kokkos::RangePolicy<ExecutionSpace,
                          Kokkos::IndexType<long long unsigned>>(0, size),
      KOKKOS_LAMBDA(long long unsigned, long long unsigned &partial_sum) {
        partial_sum += 1;
      },
      sum);
  ASSERT_EQ(sum, size);
}

// Runs the 2^32-index for/reduce check, except when the tested execution
// space's memory space is Kokkos::HostSpace, in which case the test is skipped.
TEST(TEST_CATEGORY, large_parallel_for_reduce) {
  using tested_memory_space = typename TEST_EXECSPACE::memory_space;
  constexpr bool uses_host_space =
      std::is_same_v<tested_memory_space, Kokkos::HostSpace>;

  if constexpr (uses_host_space) {
    GTEST_SKIP() << "Disabling for host backends";
  }
  test_large_parallel_for_reduce();
}
#endif

// StaticBatchSize<B>::batch_size must report the template argument B.
TEST(TEST_CATEGORY, check_batch_size) {
  namespace KE = Kokkos::Experimental;

  const bool one_reported  = (KE::StaticBatchSize<1>::batch_size == 1);
  const bool four_reported = (KE::StaticBatchSize<4>::batch_size == 4);

  ASSERT_TRUE(one_reported);
  ASSERT_TRUE(four_reported);
}

// Runs TestStaticBatchSize::test_batch_size for batch sizes 1, 2 and 4 over
// range lengths that are a multiple of, not a multiple of, and smaller than
// the batch size.
TEST(TEST_CATEGORY, range_static_batch_size) {
  namespace KE = Kokkos::Experimental;
  using batch_of_1 = TestStaticBatchSize<TEST_EXECSPACE, KE::StaticBatchSize<1>>;
  using batch_of_2 = TestStaticBatchSize<TEST_EXECSPACE, KE::StaticBatchSize<2>>;
  using batch_of_4 = TestStaticBatchSize<TEST_EXECSPACE, KE::StaticBatchSize<4>>;

  // Range length divisible by every tested batch size.
  batch_of_1(1024).test_batch_size();
  batch_of_2(1024).test_batch_size();
  batch_of_4(1024).test_batch_size();

  // Check for loop ranges where the range is not exactly divisible by the
  // static batch size.
  batch_of_4(1025).test_batch_size();

  // Check for loop ranges smaller than the static batch size.
  batch_of_4(3).test_batch_size();
}

}  // namespace Test
