cp-library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub weilycoder/cp-library

✔ test/matrix_product.test.cpp

Depends on

Code

#define PROBLEM "https://judge.yosupo.jp/problem/matrix_product"

#include "../weilycoder/matrix.hpp"
#include "../weilycoder/number_theory/modint.hpp"
#include <cstdint>
#include <iostream>
using namespace std;
using namespace weilycoder;

/**
 * @brief Read two matrices from stdin and print their product.
 *
 * Input: the dimensions N, M, K, then an N x M matrix followed by an M x K
 * matrix. Output: the N x K product, one row per line, entries separated by
 * single spaces. Arithmetic is done with modint<998244353>.
 */
int main() {
  // Untie cin from cout and disable C stdio synchronisation.
  cin.tie(nullptr)->sync_with_stdio(false);
  // Turn extraction failures into exceptions instead of silent bad state.
  cin.exceptions(cin.failbit | cin.badbit);

  size_t N, M, K;
  cin >> N >> M >> K;

  NMatrix<modint<998244353>> A(N, M), B(M, K);
  for (size_t r = 0; r < N; ++r) {
    for (size_t c = 0; c < M; ++c) {
      cin >> A(r, c);
    }
  }
  for (size_t r = 0; r < M; ++r) {
    for (size_t c = 0; c < K; ++c) {
      cin >> B(r, c);
    }
  }

  auto C = A * B;
  for (size_t r = 0; r < N; ++r) {
    for (size_t c = 0; c < K; ++c) {
      // The last entry of a row is followed by a newline, the others by a space.
      const bool last_in_row = (c + 1 == K);
      cout << C(r, c) << (last_in_row ? '\n' : ' ');
    }
  }
  return 0;
}
#line 1 "test/matrix_product.test.cpp"
#define PROBLEM "https://judge.yosupo.jp/problem/matrix_product"

#line 1 "weilycoder/matrix.hpp"



/**
 * @file matrix.hpp
 * @brief Matrix implementation using narray
 */

#line 1 "weilycoder/ds/bitset.hpp"



/**
 * @file bitset.hpp
 * @brief Run-time sized bitset
 */

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

namespace weilycoder {
/**
 * @class dbitset
 * @brief Bitset whose length is chosen at run time and fixed afterwards.
 *
 * Bits are packed into 64-bit words: bit `p` is stored in word `p / 64` at
 * position `p % 64`. The bitwise and shift operators clear the unused high
 * bits of the last word after they run.
 */
class dbitset {
private:
  const size_t bitsize;
  std::vector<uint64_t> bits;

  /// Number of unused (padding) bits in the last storage word.
  size_t extra_size() const { return bits.size() * 64 - bitsize; }

  /// Mask selecting the used bits of the last storage word.
  uint64_t extra_mask() const {
    const size_t extra = extra_size();
    return extra ? ((1ULL << (64 - extra)) - 1) : ~0ULL;
  }

  /// Clear the padding bits of the last word; a zero-length bitset has none.
  void trim_extra() {
    if (!bits.empty())
      bits.back() &= extra_mask();
  }

public:
  /**
   * @brief Create a bitset of `n` bits, all cleared.
   * @param n Number of bits.
   */
  dbitset(size_t n) : bitsize(n), bits((n + 63) / 64, 0) {}
  dbitset(const dbitset &other) : bitsize(other.bitsize), bits(other.bits) {}

  /**
   * @brief Copy the bits of another bitset of the same length.
   * @param other Source bitset.
   * @return Reference to this bitset.
   * @throws std::invalid_argument if the lengths differ.
   */
  dbitset &operator=(const dbitset &other) {
    if (this != &other) {
      if (bitsize != other.bitsize)
        throw std::invalid_argument("Bitset sizes do not match for assignment.");
      bits = other.bits;
    }
    return *this;
  }

  /// Proxy that reads and writes a single bit of a mutable bitset.
  struct reference {
    dbitset &bs;
    size_t pos;

    reference(dbitset &bs, size_t pos) : bs(bs), pos(pos) {}
    reference(const reference &) = default;

    operator bool() const {
      size_t idx = pos / 64;
      size_t bit = pos % 64;
      return (bs.bits[idx] >> bit) & 1ULL;
    }

    reference &operator=(bool val) {
      size_t idx = pos / 64;
      size_t bit = pos % 64;
      if (val)
        bs.bits[idx] |= (1ULL << bit);
      else
        bs.bits[idx] &= ~(1ULL << bit);
      return *this;
    }

    // Without this overload `a[i] = b[j]` would select the implicitly deleted
    // copy assignment (the proxy holds a reference member) and fail to compile.
    reference &operator=(const reference &other) {
      return *this = static_cast<bool>(other);
    }
  };

  /// Proxy that reads a single bit of a const bitset.
  struct const_reference {
    const dbitset &bs;
    size_t pos;

    const_reference(const dbitset &bs, size_t pos) : bs(bs), pos(pos) {}

    operator bool() const {
      size_t idx = pos / 64;
      size_t bit = pos % 64;
      return (bs.bits[idx] >> bit) & 1ULL;
    }
  };

  /// Forward iterator over the bits; compares by position only.
  struct iterator {
    dbitset &bs;
    size_t pos;

    iterator(dbitset &bs, size_t pos) : bs(bs), pos(pos) {}
    bool operator!=(const iterator &other) const { return pos != other.pos; }
    reference operator*() { return reference(bs, pos); }
    iterator &operator++() { return ++pos, *this; }
    iterator operator++(int) {
      iterator temp = *this;
      ++pos;
      return temp;
    }
  };

  /// Read-only forward iterator over the bits; compares by position only.
  struct const_iterator {
    const dbitset &bs;
    size_t pos;

    const_iterator(const dbitset &bs, size_t pos) : bs(bs), pos(pos) {}
    bool operator!=(const const_iterator &other) const { return pos != other.pos; }
    const_reference operator*() const { return const_reference(bs, pos); }
    const_iterator &operator++() { return ++pos, *this; }
    const_iterator operator++(int) {
      const_iterator temp = *this;
      ++pos;
      return temp;
    }
  };

  iterator begin() { return iterator(*this, 0); }
  iterator end() { return iterator(*this, bitsize); }
  const_iterator begin() const { return const_iterator(*this, 0); }
  const_iterator end() const { return const_iterator(*this, bitsize); }

  /// Unchecked access to bit `pos`.
  reference operator[](size_t pos) { return reference(*this, pos); }
  const_reference operator[](size_t pos) const { return const_reference(*this, pos); }

  /// Number of bits.
  size_t size() const { return bitsize; }

  /// Number of 64-bit storage words.
  size_t d_size() const { return bits.size(); }
  /// Unchecked access to storage word `idx`.
  uint64_t &d_get_word(size_t idx) { return bits[idx]; }
  const uint64_t &d_get_word(size_t idx) const { return bits[idx]; }

  /**
   * @brief In-place bitwise AND.
   * @throws std::invalid_argument if the lengths differ.
   */
  dbitset &operator&=(const dbitset &other) {
    if (bitsize != other.bitsize)
      throw std::invalid_argument("Bitset sizes do not match for AND operation.");
    for (size_t i = 0; i < bits.size(); ++i)
      bits[i] &= other.bits[i];
    return this->trim_extra(), *this;
  }

  /**
   * @brief In-place bitwise OR.
   * @throws std::invalid_argument if the lengths differ.
   */
  dbitset &operator|=(const dbitset &other) {
    if (bitsize != other.bitsize)
      throw std::invalid_argument("Bitset sizes do not match for OR operation.");
    for (size_t i = 0; i < bits.size(); ++i)
      bits[i] |= other.bits[i];
    return this->trim_extra(), *this;
  }

  /**
   * @brief In-place bitwise XOR.
   * @throws std::invalid_argument if the lengths differ.
   */
  dbitset &operator^=(const dbitset &other) {
    if (bitsize != other.bitsize)
      throw std::invalid_argument("Bitset sizes do not match for XOR operation.");
    for (size_t i = 0; i < bits.size(); ++i)
      bits[i] ^= other.bits[i];
    return this->trim_extra(), *this;
  }

  /**
   * @brief Shift every bit towards higher positions by `shift`.
   *
   * Bits shifted past the end are dropped; vacated low bits become zero.
   * A shift of at least size() clears the bitset.
   */
  dbitset &operator<<=(size_t shift) {
    if (shift >= bitsize) {
      std::fill(bits.begin(), bits.end(), 0);
      return *this;
    }
    const size_t word_shift = shift / 64;
    const size_t bit_shift = shift % 64;
    // Walk from the top word down so each source word is read before it is
    // overwritten. `i-- > word_shift` stops correctly even when word_shift is
    // 0, where a `size_t i >= 0` test would wrap around.
    for (size_t i = bits.size(); i-- > word_shift;) {
      uint64_t word = bits[i - word_shift] << bit_shift;
      // Guard bit_shift != 0: shifting a 64-bit value by 64 is undefined.
      if (bit_shift != 0 && i > word_shift)
        word |= bits[i - word_shift - 1] >> (64 - bit_shift);
      bits[i] = word;
    }
    std::fill(bits.begin(), bits.begin() + word_shift, 0);
    return this->trim_extra(), *this;
  }

  /**
   * @brief Shift every bit towards lower positions by `shift`.
   *
   * Bits shifted below position 0 are dropped; vacated high bits become zero.
   * A shift of at least size() clears the bitset.
   */
  dbitset &operator>>=(size_t shift) {
    if (shift >= bitsize) {
      std::fill(bits.begin(), bits.end(), 0);
      return *this;
    }
    const size_t word_shift = shift / 64;
    const size_t bit_shift = shift % 64;
    const size_t n = bits.size();
    if (bit_shift == 0) {
      for (size_t i = 0; i < n - word_shift; ++i)
        bits[i] = bits[i + word_shift];
    } else {
      for (size_t i = 0; i < n - word_shift - 1; ++i)
        bits[i] = (bits[i + word_shift] >> bit_shift) |
                  (bits[i + word_shift + 1] << (64 - bit_shift));
      bits[n - word_shift - 1] = bits[n - 1] >> bit_shift;
    }
    std::fill(bits.end() - word_shift, bits.end(), 0);
    return *this;
  }

  /// Bitwise complement of `bs`.
  friend dbitset operator~(const dbitset &bs) {
    dbitset result(bs.bitsize);
    for (size_t i = 0; i < bs.bits.size(); ++i)
      result.bits[i] = ~bs.bits[i];
    return result.trim_extra(), result;
  }

  /// Bitwise AND; throws std::invalid_argument if the lengths differ.
  friend dbitset operator&(const dbitset &a, const dbitset &b) {
    if (a.bitsize != b.bitsize)
      throw std::invalid_argument("Bitset sizes do not match for AND operation.");
    dbitset result(a.bitsize);
    for (size_t i = 0; i < a.bits.size(); ++i)
      result.bits[i] = a.bits[i] & b.bits[i];
    return result.trim_extra(), result;
  }

  /// Bitwise OR; throws std::invalid_argument if the lengths differ.
  friend dbitset operator|(const dbitset &a, const dbitset &b) {
    if (a.bitsize != b.bitsize)
      throw std::invalid_argument("Bitset sizes do not match for OR operation.");
    dbitset result(a.bitsize);
    for (size_t i = 0; i < a.bits.size(); ++i)
      result.bits[i] = a.bits[i] | b.bits[i];
    return result.trim_extra(), result;
  }

  /// Bitwise XOR; throws std::invalid_argument if the lengths differ.
  friend dbitset operator^(const dbitset &a, const dbitset &b) {
    if (a.bitsize != b.bitsize)
      throw std::invalid_argument("Bitset sizes do not match for XOR operation.");
    dbitset result(a.bitsize);
    for (size_t i = 0; i < a.bits.size(); ++i)
      result.bits[i] = a.bits[i] ^ b.bits[i];
    return result.trim_extra(), result;
  }

  /// Copy of `bs` shifted towards higher positions by `shift`.
  friend dbitset operator<<(const dbitset &bs, size_t shift) {
    dbitset result = bs;
    return result <<= shift, result;
  }

  /// Copy of `bs` shifted towards lower positions by `shift`.
  friend dbitset operator>>(const dbitset &bs, size_t shift) {
    dbitset result = bs;
    return result >>= shift, result;
  }
};
} // namespace weilycoder


#line 1 "weilycoder/ds/narray.hpp"



/**
 * @file narray.hpp
 * @brief N-dimensional array (narray)
 */

#include <array>
#include <cstddef>
#include <initializer_list>
#line 13 "weilycoder/ds/narray.hpp"

namespace weilycoder {
/**
 * @brief N-dimensional array (narray) stored as one flat, row-major vector.
 *
 * `shape` is deliberately not const: a const data member would delete the
 * implicit copy and move assignment operators, so narrays (and types that
 * hold one) could never be assigned. Changing `shape` by hand after
 * construction is not supported, since `data` is sized only in the
 * constructor.
 *
 * @tparam T Type of elements stored in the narray.
 * @tparam D Number of dimensions.
 */
template <typename T, size_t D> struct narray {
  std::array<size_t, D> shape;
  std::vector<T> data;

  /**
   * @brief Constructor to initialize narray with given dimensions.
   *
   * Elements are value-initialized by `std::vector::resize`.
   *
   * @tparam Sizes Parameter pack for sizes of each dimension.
   * @param sizes Sizes of each dimension; exactly D values are required.
   */
  template <typename... Sizes>
  narray(Sizes... sizes) : shape{static_cast<size_t>(sizes)...} {
    static_assert(sizeof...(Sizes) == D, "Number of sizes must match dimensions.");
    data.resize(size());
  }

  /**
   * @brief Get the total number of elements in the narray.
   * @return Product of all entries of `shape`.
   */
  size_t size() const {
    size_t res = 1;
    for (size_t i = 0; i < D; ++i)
      res *= shape[i];
    return res;
  }

  /**
   * @brief Compute the linear index from multi-dimensional indices.
   *
   * The last index varies fastest (row-major order). No bounds checking is
   * performed.
   *
   * @tparam Indices Parameter pack for indices in each dimension.
   * @param indices Indices in each dimension; exactly D values are required.
   * @return Linear index corresponding to the multi-dimensional indices.
   */
  template <typename... Indices> size_t _index(Indices... indices) const {
    static_assert(sizeof...(Indices) == D, "Number of indices must match dimensions.");
    const size_t idxs[] = {static_cast<size_t>(indices)...};
    size_t res = 0;
    for (size_t i = 0; i < D; ++i)
      res = res * shape[i] + idxs[i];
    return res;
  }

  /**
   * @brief Access element at specified multi-dimensional indices (read-only).
   * @tparam Indices Parameter pack for indices in each dimension.
   * @param indices Indices in each dimension.
   * @return Const reference to the element at the specified indices.
   */
  template <typename... Indices> const T &operator()(Indices... indices) const {
    return data[_index(indices...)];
  }

  /**
   * @brief Access element at specified multi-dimensional indices.
   * @tparam Indices Parameter pack for indices in each dimension.
   * @param indices Indices in each dimension.
   * @return Reference to the element at the specified indices.
   */
  template <typename... Indices> T &operator()(Indices... indices) {
    return data[_index(indices...)];
  }
};
} // namespace weilycoder


#line 14 "weilycoder/matrix.hpp"

namespace weilycoder {
/**
 * @brief Fixed-size R x C matrix stored in a plain two-dimensional array.
 * @tparam T Element type.
 * @tparam R Number of rows.
 * @tparam C Number of columns.
 */
template <typename T, size_t R, size_t C> struct Matrix {
  T data[R][C];

  /// Create a matrix with every element value-initialized.
  Matrix() : data{} {}

  /**
   * @brief Fill the matrix in row-major order from a list of values.
   *
   * Elements not covered by the list are value-initialized.
   *
   * @param init Up to R * C values.
   * @throws std::invalid_argument if more than R * C values are supplied.
   */
  Matrix(std::initializer_list<T> init) : data{} {
    // Writing past R * C entries would run off the end of the array.
    if (init.size() > R * C)
      throw std::invalid_argument("Too many initializers for Matrix.");
    size_t i = 0, j = 0;
    for (const auto &val : init) {
      data[i][j] = val;
      if (++j == C)
        j = 0, ++i;
    }
  }

  /**
   * @brief Access element at specified row and column (read-only).
   * @param row Row index.
   * @param col Column index.
   * @return Const reference to the element at the specified row and column.
   */
  constexpr const T &operator()(size_t row, size_t col) const { return data[row][col]; }

  /**
   * @brief Access element at specified row and column.
   * @param row Row index.
   * @param col Column index.
   * @return Reference to the element at the specified row and column.
   */
  T &operator()(size_t row, size_t col) { return data[row][col]; }

  /// Element-wise in-place addition.
  Matrix<T, R, C> &operator+=(const Matrix<T, R, C> &other) {
    for (size_t i = 0; i < R; ++i)
      for (size_t j = 0; j < C; ++j)
        data[i][j] += other.data[i][j];
    return *this;
  }

  /// Element-wise in-place subtraction.
  Matrix<T, R, C> &operator-=(const Matrix<T, R, C> &other) {
    for (size_t i = 0; i < R; ++i)
      for (size_t j = 0; j < C; ++j)
        data[i][j] -= other.data[i][j];
    return *this;
  }

  /**
   * @brief In-place matrix product `*this = *this * other`.
   *
   * The result must have the shape of `*this`, so `other` has to be C x C.
   *
   * @tparam M Number of columns of `other`; must equal C.
   * @param other Right-hand factor.
   * @return Reference to this matrix.
   */
  template <size_t M> Matrix<T, R, M> &operator*=(const Matrix<T, C, M> &other) {
    static_assert(M == C, "operator*= requires a square (C x C) right-hand side.");
    Matrix<T, R, M> result;
    for (size_t i = 0; i < R; ++i)
      for (size_t j = 0; j < M; ++j)
        for (size_t k = 0; k < C; ++k)
          result.data[i][j] += data[i][k] * other.data[k][j];
    return *this = result;
  }

  /// Element-wise sum of two matrices of the same shape.
  friend Matrix<T, R, C> operator+(const Matrix<T, R, C> &a, const Matrix<T, R, C> &b) {
    Matrix<T, R, C> result;
    for (size_t i = 0; i < R; ++i)
      for (size_t j = 0; j < C; ++j)
        result(i, j) = a(i, j) + b(i, j);
    return result;
  }

  /// Element-wise difference of two matrices of the same shape.
  friend Matrix<T, R, C> operator-(const Matrix<T, R, C> &a, const Matrix<T, R, C> &b) {
    Matrix<T, R, C> result;
    for (size_t i = 0; i < R; ++i)
      for (size_t j = 0; j < C; ++j)
        result(i, j) = a(i, j) - b(i, j);
    return result;
  }

  /**
   * @brief Matrix product of an R x C matrix with a C x M matrix.
   *
   * Implemented as a member template: a friend defined inside the class is
   * only visible through the result type's instantiation, so `a * b` could
   * not be resolved for general rectangular operands.
   *
   * @tparam M Number of columns of `other`.
   * @param other Right-hand factor.
   * @return The R x M product.
   */
  template <size_t M> Matrix<T, R, M> operator*(const Matrix<T, C, M> &other) const {
    Matrix<T, R, M> result;
    for (size_t i = 0; i < R; ++i)
      for (size_t j = 0; j < M; ++j)
        for (size_t k = 0; k < C; ++k)
          result.data[i][j] += data[i][k] * other.data[k][j];
    return result;
  }
};

/**
 * @brief Matrix with run-time dimensions, backed by a two-dimensional narray.
 * @tparam T Element type.
 */
template <typename T> struct NMatrix {
  narray<T, 2> data;

  /// Create a `rows` x `cols` matrix.
  NMatrix(size_t rows, size_t cols) : data(rows, cols) {}

  /// Copy the contents of a fixed-size R x C Matrix.
  template <size_t R, size_t C> NMatrix(const Matrix<T, R, C> &matrix) : data(R, C) {
    for (size_t r = 0; r < R; ++r) {
      for (size_t c = 0; c < C; ++c) {
        data(r, c) = matrix(r, c);
      }
    }
  }

  /**
   * @brief Read-only element access.
   * @param row Row index.
   * @param col Column index.
   * @return Const reference to the element at (row, col).
   */
  constexpr const T &operator()(size_t row, size_t col) const { return data(row, col); }

  /**
   * @brief Mutable element access.
   * @param row Row index.
   * @param col Column index.
   * @return Reference to the element at (row, col).
   */
  T &operator()(size_t row, size_t col) { return data(row, col); }

  /**
   * @brief Element-wise in-place addition.
   * @throws std::invalid_argument if the shapes differ.
   */
  NMatrix<T> &operator+=(const NMatrix<T> &other) {
    if (!(data.shape == other.data.shape))
      throw std::invalid_argument("Matrix dimensions do not match for addition.");
    for (size_t r = 0; r < data.shape[0]; ++r) {
      for (size_t c = 0; c < data.shape[1]; ++c) {
        data(r, c) += other(r, c);
      }
    }
    return *this;
  }

  /**
   * @brief Element-wise in-place subtraction.
   * @throws std::invalid_argument if the shapes differ.
   */
  NMatrix<T> &operator-=(const NMatrix<T> &other) {
    if (!(data.shape == other.data.shape))
      throw std::invalid_argument("Matrix dimensions do not match for subtraction.");
    for (size_t r = 0; r < data.shape[0]; ++r) {
      for (size_t c = 0; c < data.shape[1]; ++c) {
        data(r, c) -= other(r, c);
      }
    }
    return *this;
  }

  /**
   * @brief In-place matrix product `*this = *this * other`.
   * @throws std::invalid_argument if this matrix's column count differs from
   *         the row count of `other`.
   */
  NMatrix<T> &operator*=(const NMatrix<T> &other) {
    const size_t inner = data.shape[1];
    if (inner != other.data.shape[0])
      throw std::invalid_argument("Matrix dimensions do not match for multiplication.");
    const size_t out_rows = data.shape[0];
    const size_t out_cols = other.data.shape[1];
    NMatrix<T> product(out_rows, out_cols);
    // The column index is innermost so both `product` and `other` are walked
    // along their rows.
    for (size_t r = 0; r < out_rows; ++r) {
      for (size_t m = 0; m < inner; ++m) {
        for (size_t c = 0; c < out_cols; ++c) {
          product(r, c) += data(r, m) * other(m, c);
        }
      }
    }
    return *this = product;
  }

  /**
   * @brief Element-wise sum.
   * @throws std::invalid_argument if the shapes differ.
   */
  friend NMatrix<T> operator+(const NMatrix<T> &a, const NMatrix<T> &b) {
    if (!(a.data.shape == b.data.shape))
      throw std::invalid_argument("Matrix dimensions do not match for addition.");
    const size_t n_rows = a.data.shape[0];
    const size_t n_cols = a.data.shape[1];
    NMatrix<T> sum(n_rows, n_cols);
    for (size_t r = 0; r < n_rows; ++r) {
      for (size_t c = 0; c < n_cols; ++c) {
        sum(r, c) = a(r, c) + b(r, c);
      }
    }
    return sum;
  }

  /**
   * @brief Element-wise difference.
   * @throws std::invalid_argument if the shapes differ.
   */
  friend NMatrix<T> operator-(const NMatrix<T> &a, const NMatrix<T> &b) {
    if (!(a.data.shape == b.data.shape))
      throw std::invalid_argument("Matrix dimensions do not match for subtraction.");
    const size_t n_rows = a.data.shape[0];
    const size_t n_cols = a.data.shape[1];
    NMatrix<T> diff(n_rows, n_cols);
    for (size_t r = 0; r < n_rows; ++r) {
      for (size_t c = 0; c < n_cols; ++c) {
        diff(r, c) = a(r, c) - b(r, c);
      }
    }
    return diff;
  }

  /**
   * @brief Matrix product of `a` and `b`.
   * @throws std::invalid_argument if the column count of `a` differs from the
   *         row count of `b`.
   */
  friend NMatrix<T> operator*(const NMatrix<T> &a, const NMatrix<T> &b) {
    const size_t inner = a.data.shape[1];
    if (inner != b.data.shape[0])
      throw std::invalid_argument("Matrix dimensions do not match for multiplication.");
    const size_t out_rows = a.data.shape[0];
    const size_t out_cols = b.data.shape[1];
    NMatrix<T> product(out_rows, out_cols);
    for (size_t r = 0; r < out_rows; ++r) {
      for (size_t m = 0; m < inner; ++m) {
        for (size_t c = 0; c < out_cols; ++c) {
          product(r, c) += a(r, m) * b(m, c);
        }
      }
    }
    return product;
  }
};

struct BMatrix {
  size_t rows, cols;
  std::vector<dbitset> data;

  BMatrix(size_t rows, size_t cols)
      : rows(rows), cols(cols), data(rows, dbitset(cols)) {}

  /**
   * @brief Access element at specified row and column.
   * @param row Row index.
   * @param col Column index.
   * @return Reference to the element at the specified row and column.
   */
  dbitset::reference operator()(size_t row, size_t col) { return data[row][col]; }

  /**
   * @brief Access element at specified row and column.
   * @param row Row index.
   * @param col Column index.
   * @return Reference to the element at the specified row and column.
   */
  dbitset::const_reference operator()(size_t row, size_t col) const {
    return data[row][col];
  }

  BMatrix &operator&=(const BMatrix &other) {
    for (size_t i = 0; i < rows; ++i)
      data[i] &= other.data[i];
    return *this;
  }

  BMatrix &operator|=(const BMatrix &other) {
    for (size_t i = 0; i < rows; ++i)
      data[i] |= other.data[i];
    return *this;
  }

  BMatrix &operator^=(const BMatrix &other) {
    for (size_t i = 0; i < rows; ++i)
      data[i] ^= other.data[i];
    return *this;
  }

  BMatrix operator*=(const BMatrix &other) {
    if (cols != other.rows)
      throw std::invalid_argument("Matrix dimensions do not match for multiplication.");
    BMatrix result(rows, other.cols);
    for (size_t i = 0; i < rows; ++i)
      for (size_t k = 0; k < cols; ++k)
        if (data[i][k])
          result.data[i] ^= other.data[k];
    return *this = result;
  }

  friend BMatrix operator&(const BMatrix &a, const BMatrix &b) {
    if (a.rows != b.rows || a.cols != b.cols)
      throw std::invalid_argument("Matrix dimensions do not match for AND operation.");
    BMatrix result(a.rows, a.cols);
    for (size_t i = 0; i < a.rows; ++i)
      for (size_t j = 0; j < a.cols; ++j)
        result(i, j) = static_cast<bool>(a(i, j)) && static_cast<bool>(b(i, j));
    return result;
  }

  friend BMatrix operator|(const BMatrix &a, const BMatrix &b) {
    if (a.rows != b.rows || a.cols != b.cols)
      throw std::invalid_argument("Matrix dimensions do not match for OR operation.");
    BMatrix result(a.rows, a.cols);
    for (size_t i = 0; i < a.rows; ++i)
      for (size_t j = 0; j < a.cols; ++j)
        result(i, j) = static_cast<bool>(a(i, j)) || static_cast<bool>(b(i, j));
    return result;
  }

  friend BMatrix operator^(const BMatrix &a, const BMatrix &b) {
    if (a.rows != b.rows || a.cols != b.cols)
      throw std::invalid_argument("Matrix dimensions do not match for XOR operation.");
    BMatrix result(a.rows, a.cols);
    for (size_t i = 0; i < a.rows; ++i)
      for (size_t j = 0; j < a.cols; ++j)
        result(i, j) = static_cast<bool>(a(i, j)) ^ static_cast<bool>(b(i, j));
    return result;
  }

  friend BMatrix operator~(const BMatrix &a) {
    BMatrix result(a.rows, a.cols);
    for (size_t i = 0; i < a.rows; ++i)
      for (size_t j = 0; j < a.cols; ++j)
        result(i, j) = !static_cast<bool>(a(i, j));
    return result;
  }

  friend BMatrix operator*(const BMatrix &a, const BMatrix &b) {
    if (a.cols != b.rows)
      throw std::invalid_argument("Matrix dimensions do not match for multiplication.");
    BMatrix result(a.rows, b.cols);
    for (size_t i = 0; i < a.rows; ++i)
      for (size_t k = 0; k < a.cols; ++k)
        if (a(i, k))
          result.data[i] ^= b.data[k];
    return result;
  }
};
} // namespace weilycoder


#line 1 "weilycoder/number_theory/modint.hpp"



/**
 * @file modint.hpp
 * @brief Modular Integer Class
 */

#line 1 "weilycoder/number_theory/mod_utility.hpp"



/**
 * @file mod_utility.hpp
 * @brief Modular Arithmetic Utilities
 */

#line 10 "weilycoder/number_theory/mod_utility.hpp"

namespace weilycoder {
using u128 = unsigned __int128;

/**
 * @brief Modular addition with a run-time modulus.
 * @tparam bit32 When true, plain 64-bit arithmetic is used; the caller must
 *         keep the operands small enough that `a + b` does not overflow.
 * @param a The first addend.
 * @param b The second addend.
 * @param modulus The modulus.
 * @return `a + b`, reduced by one subtraction of `modulus` if it is not below
 *         `modulus`.
 */
template <bool bit32 = false>
constexpr uint64_t mod_add(uint64_t a, uint64_t b, uint64_t modulus) {
  if constexpr (bit32) {
    const uint64_t sum = a + b;
    return sum >= modulus ? sum - modulus : sum;
  } else {
    const u128 sum = static_cast<u128>(a) + b;
    return static_cast<uint64_t>(sum >= modulus ? sum - modulus : sum);
  }
}

/**
 * @brief Modular addition with a compile-time modulus.
 *
 * 128-bit arithmetic is used only when `Modulus` exceeds `UINT32_MAX`.
 *
 * @tparam Modulus The modulus.
 * @param a The first addend.
 * @param b The second addend.
 * @return `a + b`, reduced by one subtraction of `Modulus` if it is not below
 *         `Modulus`.
 */
template <uint64_t Modulus> constexpr uint64_t mod_add(uint64_t a, uint64_t b) {
  if constexpr (Modulus <= UINT32_MAX) {
    const uint64_t sum = a + b;
    return sum >= Modulus ? sum - Modulus : sum;
  } else {
    const u128 sum = static_cast<u128>(a) + b;
    return static_cast<uint64_t>(sum >= Modulus ? sum - Modulus : sum);
  }
}

/**
 * @brief Modular subtraction with a run-time modulus.
 * @tparam bit32 When true, plain 64-bit arithmetic is used; the caller must
 *         keep the operands small enough to avoid overflow.
 * @param a The minuend.
 * @param b The subtrahend.
 * @param modulus The modulus.
 * @return `a - b` when `a >= b`, otherwise `modulus + a - b`.
 */
template <bool bit32 = false>
constexpr uint64_t mod_sub(uint64_t a, uint64_t b, uint64_t modulus) {
  if (a >= b)
    return a - b;
  if constexpr (bit32)
    return modulus + a - b;
  else
    return static_cast<uint64_t>(static_cast<u128>(modulus) + a - b);
}

/**
 * @brief Modular subtraction with a compile-time modulus.
 *
 * 128-bit arithmetic is used only when `Modulus` exceeds `UINT32_MAX`.
 *
 * @tparam Modulus The modulus.
 * @param a The minuend.
 * @param b The subtrahend.
 * @return `a - b` when `a >= b`, otherwise `Modulus + a - b`.
 */
template <uint64_t Modulus> constexpr uint64_t mod_sub(uint64_t a, uint64_t b) {
  if (a >= b)
    return a - b;
  if constexpr (Modulus <= UINT32_MAX)
    return Modulus + a - b;
  else
    return static_cast<uint64_t>(static_cast<u128>(Modulus) + a - b);
}

/**
 * @brief Modular multiplication with a run-time modulus.
 * @tparam bit32 When true, the product is formed in 64 bits; the caller must
 *         keep the operands small enough that `a * b` does not overflow.
 * @param a The first factor.
 * @param b The second factor.
 * @param modulus The modulus.
 * @return (a * b) % modulus
 */
template <bool bit32 = false>
constexpr uint64_t mod_mul(uint64_t a, uint64_t b, uint64_t modulus) {
  if constexpr (bit32) {
    const uint64_t product = a * b;
    return product % modulus;
  } else {
    const u128 product = static_cast<u128>(a) * b;
    return static_cast<uint64_t>(product % modulus);
  }
}

/**
 * @brief Modular multiplication with a compile-time modulus.
 *
 * The product is formed in 128 bits only when `Modulus` exceeds `UINT32_MAX`.
 *
 * @tparam Modulus The modulus.
 * @param a The first factor.
 * @param b The second factor.
 * @return (a * b) % Modulus
 */
template <uint64_t Modulus> constexpr uint64_t mod_mul(uint64_t a, uint64_t b) {
  if constexpr (Modulus <= UINT32_MAX) {
    const uint64_t product = a * b;
    return product % Modulus;
  } else {
    const u128 product = static_cast<u128>(a) * b;
    return static_cast<uint64_t>(product % Modulus);
  }
}

/**
 * @brief Modular exponentiation by repeated squaring, run-time modulus.
 * @tparam bit32 Forwarded to mod_mul; see its overflow requirement.
 * @param base The base number.
 * @param exponent The exponent.
 * @param modulus The modulus.
 * @return (base^exponent) % modulus
 */
template <bool bit32 = false>
constexpr uint64_t mod_pow(uint64_t base, uint64_t exponent, uint64_t modulus) {
  // `1 % modulus` makes the empty product 0 when modulus == 1.
  uint64_t acc = 1 % modulus;
  base %= modulus;
  for (; exponent != 0; exponent >>= 1) {
    if ((exponent & 1) != 0)
      acc = mod_mul<bit32>(acc, base, modulus);
    base = mod_mul<bit32>(base, base, modulus);
  }
  return acc;
}

/**
 * @brief Modular exponentiation by repeated squaring, compile-time modulus.
 * @tparam Modulus The modulus.
 * @param base The base number.
 * @param exponent The exponent.
 * @return (base^exponent) % Modulus
 */
template <uint64_t Modulus>
constexpr uint64_t mod_pow(uint64_t base, uint64_t exponent) {
  // `1 % Modulus` makes the empty product 0 when Modulus == 1.
  uint64_t acc = 1 % Modulus;
  base %= Modulus;
  for (; exponent != 0; exponent >>= 1) {
    if ((exponent & 1) != 0)
      acc = mod_mul<Modulus>(acc, base);
    base = mod_mul<Modulus>(base, base);
  }
  return acc;
}

/**
 * @brief Modular inverse via Fermat's Little Theorem.
 * @tparam Modulus The modulus (must be prime).
 * @param a The number to invert.
 * @return a^(Modulus - 2) % Modulus
 */
template <uint64_t Modulus> constexpr uint64_t mod_inv(uint64_t a) {
  constexpr uint64_t fermat_exponent = Modulus - 2;
  return mod_pow<Modulus>(a, fermat_exponent);
}

/**
 * @brief Modular inverse of a compile-time value via Fermat's Little Theorem.
 * @tparam Modulus The modulus (must be prime).
 * @tparam a The number to invert.
 * @return a^(Modulus - 2) % Modulus
 */
template <uint64_t Modulus, uint64_t a> constexpr uint64_t mod_inv() {
  constexpr uint64_t fermat_exponent = Modulus - 2;
  return mod_pow<Modulus>(a, fermat_exponent);
}
} // namespace weilycoder


#line 11 "weilycoder/number_theory/modint.hpp"
#include <istream>
#include <ostream>

namespace weilycoder {
/**
 * @brief Modular Integer with compile-time modulus.
 *
 * The stored value is kept in the range [0, Modulus).
 *
 * @tparam Modulus The modulus.
 */
template <uint64_t Modulus> struct modint {
private:
  uint64_t value;

  /**
   * @brief Reduce a signed value into [0, Modulus).
   *
   * Works on the unsigned magnitude so that `Modulus` never has to be cast to
   * a signed type (which would misbehave for moduli above INT64_MAX) and so
   * that INT64_MIN is handled without signed overflow.
   */
  static constexpr uint64_t reduce_signed(int64_t v) {
    if (v >= 0)
      return static_cast<uint64_t>(v) % Modulus;
    // Unsigned negation yields |v| even for INT64_MIN.
    const uint64_t magnitude = 0 - static_cast<uint64_t>(v);
    const uint64_t rem = magnitude % Modulus;
    return rem == 0 ? 0 : Modulus - rem;
  }

public:
  /// Zero.
  constexpr modint() : value(0) {}
  constexpr modint(uint32_t v) : value(v % Modulus) {}
  // The signed constructors initialise `value` in the member-initializer list:
  // C++17 requires a constexpr constructor to initialise every data member.
  constexpr modint(int32_t v) : value(reduce_signed(v)) {}
  constexpr modint(uint64_t v) : value(v % Modulus) {}
  constexpr modint(int64_t v) : value(reduce_signed(v)) {}

  /**
   * @brief Replace the stored value with `v` reduced into [0, Modulus).
   * @param v Any signed 64-bit value; negative values wrap around.
   */
  constexpr void from_i64(int64_t v) { value = reduce_signed(v); }

  /// The stored representative in [0, Modulus).
  constexpr explicit operator uint64_t() const { return value; }

  friend constexpr modint<Modulus> operator+(const modint<Modulus> &lhs,
                                             const modint<Modulus> &rhs) {
    return modint<Modulus>(mod_add<Modulus>(lhs.value, rhs.value));
  }

  friend constexpr modint<Modulus> operator-(const modint<Modulus> &lhs,
                                             const modint<Modulus> &rhs) {
    return modint<Modulus>(mod_sub<Modulus>(lhs.value, rhs.value));
  }

  friend constexpr modint<Modulus> operator*(const modint<Modulus> &lhs,
                                             const modint<Modulus> &rhs) {
    return modint<Modulus>(mod_mul<Modulus>(lhs.value, rhs.value));
  }

  constexpr modint &operator+=(const modint &other) {
    value = mod_add<Modulus>(value, other.value);
    return *this;
  }

  constexpr modint &operator-=(const modint &other) {
    value = mod_sub<Modulus>(value, other.value);
    return *this;
  }

  constexpr modint &operator*=(const modint &other) {
    value = mod_mul<Modulus>(value, other.value);
    return *this;
  }

  /// Write the stored representative as a decimal number.
  friend std::ostream &operator<<(std::ostream &os, const modint &m) {
    return os << m.value;
  }

  /// Read a signed 64-bit integer and reduce it into [0, Modulus).
  friend std::istream &operator>>(std::istream &is, modint &m) {
    int64_t v;
    is >> v;
    m.from_i64(v);
    return is;
  }
};
} // namespace weilycoder


#line 6 "test/matrix_product.test.cpp"
#include <iostream>
using namespace std;
using namespace weilycoder;

/**
 * @brief Read two matrices from stdin and print their product.
 *
 * Input: the dimensions N, M, K, then an N x M matrix followed by an M x K
 * matrix. Output: the N x K product, one row per line, entries separated by
 * single spaces. Arithmetic is done with modint<998244353>.
 */
int main() {
  // Untie cin from cout and disable C stdio synchronisation.
  cin.tie(nullptr)->sync_with_stdio(false);
  // Turn extraction failures into exceptions instead of silent bad state.
  cin.exceptions(cin.failbit | cin.badbit);

  size_t N, M, K;
  cin >> N >> M >> K;

  NMatrix<modint<998244353>> A(N, M), B(M, K);
  for (size_t r = 0; r < N; ++r) {
    for (size_t c = 0; c < M; ++c) {
      cin >> A(r, c);
    }
  }
  for (size_t r = 0; r < M; ++r) {
    for (size_t c = 0; c < K; ++c) {
      cin >> B(r, c);
    }
  }

  auto C = A * B;
  for (size_t r = 0; r < N; ++r) {
    for (size_t c = 0; c < K; ++c) {
      // The last entry of a row is followed by a newline, the others by a space.
      const bool last_in_row = (c + 1 == K);
      cout << C(r, c) << (last_in_row ? '\n' : ' ');
    }
  }
  return 0;
}
Back to top page