This documentation is automatically generated by online-judge-tools/verification-helper
#define PROBLEM "https://judge.yosupo.jp/problem/matrix_product"
#include "../weilycoder/matrix.hpp"
#include "../weilycoder/number_theory/modint.hpp"
#include <cstdint>
#include <iostream>
using namespace std;
using namespace weilycoder;
/**
 * Reads an N x M matrix A and an M x K matrix B (entries read as integers and reduced
 * modulo 998244353) and prints A * B, one row per line, entries separated by spaces.
 */
int main() {
  cin.tie(nullptr)->sync_with_stdio(false);
  cin.exceptions(cin.failbit | cin.badbit);
  size_t N, M, K;
  cin >> N >> M >> K;
  NMatrix<modint<998244353>> A(N, M), B(M, K);
  for (size_t i = 0; i < N; ++i)
    for (size_t j = 0; j < M; ++j)
      cin >> A(i, j);
  for (size_t i = 0; i < M; ++i)
    for (size_t j = 0; j < K; ++j)
      cin >> B(i, j);
  auto C = A * B;
  for (size_t i = 0; i < N; ++i)
    for (size_t j = 0; j < K; ++j)
      cout << C(i, j) << " \n"[j + 1 == K];
  return 0;
}
#line 1 "test/matrix_product.test.cpp"
#define PROBLEM "https://judge.yosupo.jp/problem/matrix_product"
#line 1 "weilycoder/matrix.hpp"
/**
* @file matrix.hpp
* @brief Matrix implementation using narray
*/
#line 1 "weilycoder/ds/bitset.hpp"
/**
* @file bitset.hpp
* @brief Run-time sized bitset
*/
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <type_traits>
#include <vector>
namespace weilycoder {
/**
 * @class dbitset
 * @brief Run-time sized bitset implementation
 * @details Bit `pos` is stored in word `pos / 64` at bit position `pos % 64`. The size is
 *          fixed at construction; copy assignment from a bitset of another size throws.
 *          Padding bits of the last word (positions >= size()) are cleared by the bitwise
 *          operators and by the left shift; writes through d_get_word() should keep them zero.
 */
class dbitset {
private:
  const size_t bitsize;
  std::vector<uint64_t> bits;
  /// Number of padding bits in the last word (0 when size() is a multiple of 64).
  size_t extra_size() const { return bits.size() * 64 - bitsize; }
  /// Mask keeping only the in-range bits of the last word.
  uint64_t extra_mask() const {
    size_t extra = extra_size();
    return extra ? ((1ULL << (64 - extra)) - 1) : ~0ULL;
  }
  /// Clear the padding bits of the last word. A zero-sized bitset owns no words, so there
  /// is nothing to trim (bits.back() on an empty vector would be undefined behaviour).
  void trim_extra() {
    if (!bits.empty())
      bits.back() &= extra_mask();
  }

public:
  /// Create a bitset of @p n bits, all cleared.
  dbitset(size_t n) : bitsize(n), bits((n + 63) / 64, 0) {}
  dbitset(const dbitset &other) : bitsize(other.bitsize), bits(other.bits) {}
  /**
   * @brief Copy the bits of @p other.
   * @throws std::invalid_argument if the two bitsets have different sizes.
   */
  dbitset &operator=(const dbitset &other) {
    if (this != &other) {
      if (bitsize != other.bitsize)
        throw std::invalid_argument("Bitset sizes do not match for assignment.");
      bits = other.bits;
    }
    return *this;
  }
  /// Writable proxy for one bit; no bounds checking is performed.
  struct reference {
    dbitset &bs;
    size_t pos;
    reference(dbitset &bs, size_t pos) : bs(bs), pos(pos) {}
    operator bool() const {
      size_t idx = pos / 64;
      size_t bit = pos % 64;
      return (bs.bits[idx] >> bit) & 1ULL;
    }
    reference &operator=(bool val) {
      size_t idx = pos / 64;
      size_t bit = pos % 64;
      if (val)
        bs.bits[idx] |= (1ULL << bit);
      else
        bs.bits[idx] &= ~(1ULL << bit);
      return *this;
    }
  };
  /// Read-only proxy for one bit; no bounds checking is performed.
  struct const_reference {
    const dbitset &bs;
    size_t pos;
    const_reference(const dbitset &bs, size_t pos) : bs(bs), pos(pos) {}
    operator bool() const {
      size_t idx = pos / 64;
      size_t bit = pos % 64;
      return (bs.bits[idx] >> bit) & 1ULL;
    }
  };
  /// Forward iterator over bit positions [0, size()), yielding `reference` proxies.
  struct iterator {
    dbitset &bs;
    size_t pos;
    iterator(dbitset &bs, size_t pos) : bs(bs), pos(pos) {}
    bool operator!=(const iterator &other) const { return pos != other.pos; }
    reference operator*() { return reference(bs, pos); }
    iterator &operator++() { return ++pos, *this; }
    iterator operator++(int) {
      iterator temp = *this;
      ++pos;
      return temp;
    }
  };
  /// Forward iterator over bit positions [0, size()), yielding `const_reference` proxies.
  struct const_iterator {
    const dbitset &bs;
    size_t pos;
    const_iterator(const dbitset &bs, size_t pos) : bs(bs), pos(pos) {}
    bool operator!=(const const_iterator &other) const { return pos != other.pos; }
    const_reference operator*() const { return const_reference(bs, pos); }
    const_iterator &operator++() { return ++pos, *this; }
    const_iterator operator++(int) {
      const_iterator temp = *this;
      ++pos;
      return temp;
    }
  };
  iterator begin() { return iterator(*this, 0); }
  iterator end() { return iterator(*this, bitsize); }
  const_iterator begin() const { return const_iterator(*this, 0); }
  const_iterator end() const { return const_iterator(*this, bitsize); }
  reference operator[](size_t pos) { return reference(*this, pos); }
  const_reference operator[](size_t pos) const { return const_reference(*this, pos); }
  /// Number of bits.
  size_t size() const { return bitsize; }
  /// Number of 64-bit storage words.
  size_t d_size() const { return bits.size(); }
  /// Direct access to storage word @p idx.
  uint64_t &d_get_word(size_t idx) { return bits[idx]; }
  const uint64_t &d_get_word(size_t idx) const { return bits[idx]; }
  /// @throws std::invalid_argument if the sizes differ.
  dbitset &operator&=(const dbitset &other) {
    if (bitsize != other.bitsize)
      throw std::invalid_argument("Bitset sizes do not match for AND operation.");
    for (size_t i = 0; i < bits.size(); ++i)
      bits[i] &= other.bits[i];
    return this->trim_extra(), *this;
  }
  /// @throws std::invalid_argument if the sizes differ.
  dbitset &operator|=(const dbitset &other) {
    if (bitsize != other.bitsize)
      throw std::invalid_argument("Bitset sizes do not match for OR operation.");
    for (size_t i = 0; i < bits.size(); ++i)
      bits[i] |= other.bits[i];
    return this->trim_extra(), *this;
  }
  /// @throws std::invalid_argument if the sizes differ.
  dbitset &operator^=(const dbitset &other) {
    if (bitsize != other.bitsize)
      throw std::invalid_argument("Bitset sizes do not match for XOR operation.");
    for (size_t i = 0; i < bits.size(); ++i)
      bits[i] ^= other.bits[i];
    return this->trim_extra(), *this;
  }
  /// Move every bit towards higher positions by @p shift; bits shifted past size() are lost.
  dbitset &operator<<=(size_t shift) {
    if (shift >= bitsize) {
      std::fill(bits.begin(), bits.end(), 0);
      return *this;
    }
    size_t word_shift = shift / 64;
    size_t bit_shift = shift % 64;
    // Walk from the top word down so each source word is read before it is overwritten.
    if (bit_shift == 0) {
      // `i-- > word_shift` instead of `i >= word_shift; --i`: the latter never terminates
      // for unsigned i when word_shift == 0 (i.e. shift == 0).
      for (size_t i = bits.size(); i-- > word_shift;)
        bits[i] = bits[i - word_shift];
    } else {
      for (size_t i = bits.size() - 1; i > word_shift; --i)
        bits[i] = (bits[i - word_shift] << bit_shift) |
                  (bits[i - word_shift - 1] >> (64 - bit_shift));
      bits[word_shift] = bits[0] << bit_shift;
    }
    std::fill(bits.begin(), bits.begin() + word_shift, 0);
    return this->trim_extra(), *this;
  }
  /// Move every bit towards lower positions by @p shift; vacated high bits become zero.
  dbitset &operator>>=(size_t shift) {
    if (shift >= bitsize) {
      std::fill(bits.begin(), bits.end(), 0);
      return *this;
    }
    size_t word_shift = shift / 64;
    size_t bit_shift = shift % 64;
    size_t n = bits.size();
    // Walk from the bottom word up so each source word is read before it is overwritten.
    if (bit_shift == 0) {
      for (size_t i = 0; i < n - word_shift; ++i)
        bits[i] = bits[i + word_shift];
    } else {
      for (size_t i = 0; i < n - word_shift - 1; ++i)
        bits[i] = (bits[i + word_shift] >> bit_shift) |
                  (bits[i + word_shift + 1] << (64 - bit_shift));
      bits[n - word_shift - 1] = bits[n - 1] >> bit_shift;
    }
    std::fill(bits.end() - word_shift, bits.end(), 0);
    return *this;
  }
  friend dbitset operator~(const dbitset &bs) {
    dbitset result(bs.bitsize);
    for (size_t i = 0; i < bs.bits.size(); ++i)
      result.bits[i] = ~bs.bits[i];
    return result.trim_extra(), result;
  }
  friend dbitset operator&(const dbitset &a, const dbitset &b) {
    if (a.bitsize != b.bitsize)
      throw std::invalid_argument("Bitset sizes do not match for AND operation.");
    dbitset result(a.bitsize);
    for (size_t i = 0; i < a.bits.size(); ++i)
      result.bits[i] = a.bits[i] & b.bits[i];
    return result.trim_extra(), result;
  }
  friend dbitset operator|(const dbitset &a, const dbitset &b) {
    if (a.bitsize != b.bitsize)
      throw std::invalid_argument("Bitset sizes do not match for OR operation.");
    dbitset result(a.bitsize);
    for (size_t i = 0; i < a.bits.size(); ++i)
      result.bits[i] = a.bits[i] | b.bits[i];
    return result.trim_extra(), result;
  }
  friend dbitset operator^(const dbitset &a, const dbitset &b) {
    if (a.bitsize != b.bitsize)
      throw std::invalid_argument("Bitset sizes do not match for XOR operation.");
    dbitset result(a.bitsize);
    for (size_t i = 0; i < a.bits.size(); ++i)
      result.bits[i] = a.bits[i] ^ b.bits[i];
    return result.trim_extra(), result;
  }
  friend dbitset operator<<(const dbitset &bs, size_t shift) {
    dbitset result = bs;
    return result <<= shift, result;
  }
  friend dbitset operator>>(const dbitset &bs, size_t shift) {
    dbitset result = bs;
    return result >>= shift, result;
  }
};
} // namespace weilycoder
#line 1 "weilycoder/ds/narray.hpp"
/**
* @file narray.hpp
* @brief N-dimensional array (narray)
*/
#include <array>
#include <cstddef>
#include <initializer_list>
#line 13 "weilycoder/ds/narray.hpp"
namespace weilycoder {
/**
 * @brief N-dimensional array (narray) implementation.
 * @details Elements live in one contiguous std::vector in row-major order: the last index
 *          varies fastest (see _index). A 0-dimensional narray holds exactly one element.
 * @tparam T Type of elements stored in the narray.
 * @tparam D Number of dimensions.
 */
template <typename T, size_t D> struct narray {
  const std::array<size_t, D> shape;
  std::vector<T> data;
  /**
   * @brief Constructor to initialize narray with given dimensions.
   * @details Every element is value-initialised.
   * @tparam Sizes Parameter pack for sizes of each dimension.
   * @param sizes Sizes of each dimension.
   */
  template <typename... Sizes>
  narray(Sizes... sizes) : shape{static_cast<size_t>(sizes)...} {
    static_assert(sizeof...(Sizes) == D, "Number of sizes must match dimensions.");
    data.resize(size());
  }
  /**
   * @brief Get the total number of elements in the narray.
   * @return Product of all dimension sizes (1 when D == 0).
   */
  size_t size() const {
    size_t res = 1;
    for (size_t i = 0; i < D; ++i)
      res *= shape[i];
    return res;
  }
  /**
   * @brief Compute the linear index from multi-dimensional indices.
   * @details No bounds checking is performed.
   * @tparam Indices Parameter pack for indices in each dimension.
   * @param indices Indices in each dimension.
   * @return Linear (row-major) index corresponding to the multi-dimensional indices.
   */
  template <typename... Indices> size_t _index(Indices... indices) const {
    static_assert(sizeof...(Indices) == D, "Number of indices must match dimensions.");
    // std::array instead of a built-in array: a zero-length built-in array is ill-formed,
    // which made every access to a 0-dimensional narray fail to compile portably.
    const std::array<size_t, D> idxs{static_cast<size_t>(indices)...};
    size_t res = 0;
    for (size_t i = 0; i < D; ++i)
      res = res * shape[i] + idxs[i];
    return res;
  }
  /**
   * @brief Access element at specified multi-dimensional indices.
   * @tparam Indices Parameter pack for indices in each dimension.
   * @param indices Indices in each dimension.
   * @return Reference to the element at the specified indices.
   */
  template <typename... Indices> const T &operator()(Indices... indices) const {
    return data[_index(indices...)];
  }
  /**
   * @brief Access element at specified multi-dimensional indices.
   * @tparam Indices Parameter pack for indices in each dimension.
   * @param indices Indices in each dimension.
   * @return Reference to the element at the specified indices.
   */
  template <typename... Indices> T &operator()(Indices... indices) {
    return data[_index(indices...)];
  }
};
} // namespace weilycoder
#line 14 "weilycoder/matrix.hpp"
namespace weilycoder {
/**
 * @brief Fixed-size R x C matrix with inline storage.
 * @tparam T Element type; value-initialisation (`T{}`) provides the zero element.
 * @tparam R Number of rows.
 * @tparam C Number of columns.
 */
template <typename T, size_t R, size_t C> struct Matrix {
  T data[R][C];
  /// Zero matrix (every element value-initialised).
  Matrix() : data{} {}
  /**
   * @brief Fill the matrix in row-major order from @p init.
   * @details Elements not covered by @p init stay value-initialised.
   * @throws std::invalid_argument if @p init holds more than R * C values.
   */
  Matrix(std::initializer_list<T> init) : data{} {
    // Without this check, extra values would be written past the end of `data`.
    if (init.size() > R * C)
      throw std::invalid_argument("Too many initializers for matrix.");
    size_t i = 0, j = 0;
    for (const auto &val : init) {
      data[i][j] = val;
      if (++j == C)
        j = 0, ++i;
    }
  }
  /**
   * @brief Access element at specified row and column.
   * @param row Row index.
   * @param col Column index.
   * @return Reference to the element at the specified row and column.
   */
  constexpr const T &operator()(size_t row, size_t col) const { return data[row][col]; }
  /**
   * @brief Access element at specified row and column.
   * @param row Row index.
   * @param col Column index.
   * @return Reference to the element at the specified row and column.
   */
  T &operator()(size_t row, size_t col) { return data[row][col]; }
  /// Element-wise addition.
  Matrix<T, R, C> &operator+=(const Matrix<T, R, C> &other) {
    for (size_t i = 0; i < R; ++i)
      for (size_t j = 0; j < C; ++j)
        data[i][j] += other.data[i][j];
    return *this;
  }
  /// Element-wise subtraction.
  Matrix<T, R, C> &operator-=(const Matrix<T, R, C> &other) {
    for (size_t i = 0; i < R; ++i)
      for (size_t j = 0; j < C; ++j)
        data[i][j] -= other.data[i][j];
    return *this;
  }
  /**
   * @brief In-place right multiplication: *this = *this * other.
   * @details Only a C x C right operand keeps this matrix's R x C shape, so any other
   *          shape is rejected at compile time.
   */
  template <size_t M> Matrix<T, R, C> &operator*=(const Matrix<T, C, M> &other) {
    static_assert(M == C, "In-place multiplication needs a square right operand.");
    Matrix<T, R, C> result;
    for (size_t i = 0; i < R; ++i)
      for (size_t k = 0; k < C; ++k)
        for (size_t j = 0; j < C; ++j)
          result.data[i][j] += data[i][k] * other(k, j);
    return *this = result;
  }
  /// Element-wise sum.
  friend Matrix<T, R, C> operator+(const Matrix<T, R, C> &a, const Matrix<T, R, C> &b) {
    Matrix<T, R, C> result;
    for (size_t i = 0; i < R; ++i)
      for (size_t j = 0; j < C; ++j)
        result(i, j) = a(i, j) + b(i, j);
    return result;
  }
  /// Element-wise difference.
  friend Matrix<T, R, C> operator-(const Matrix<T, R, C> &a, const Matrix<T, R, C> &b) {
    Matrix<T, R, C> result;
    for (size_t i = 0; i < R; ++i)
      for (size_t j = 0; j < C; ++j)
        result(i, j) = a(i, j) - b(i, j);
    return result;
  }
};
/**
 * @brief Matrix product of an R x K matrix and a K x C matrix.
 * @details Declared at namespace scope rather than as a hidden friend of the result type:
 *          a friend of Matrix<T, R, C> is only visible to argument-dependent lookup when one
 *          operand has that exact type, so non-square products could not find it.
 */
template <typename T, size_t R, size_t K, size_t C>
Matrix<T, R, C> operator*(const Matrix<T, R, K> &a, const Matrix<T, K, C> &b) {
  Matrix<T, R, C> result;
  for (size_t i = 0; i < R; ++i)
    for (size_t k = 0; k < K; ++k)
      for (size_t j = 0; j < C; ++j)
        result(i, j) += a(i, k) * b(k, j);
  return result;
}
/**
 * @brief Run-time sized dense matrix stored in a 2-dimensional narray (shape {rows, cols}).
 * @tparam T Element type; value-initialised elements are used as the zero of products.
 */
template <typename T> struct NMatrix {
  narray<T, 2> data;
  /// Create a rows x cols matrix of value-initialised elements.
  NMatrix(size_t rows, size_t cols) : data(rows, cols) {}
  /// Copy a fixed-size Matrix into a run-time sized one.
  template <size_t R, size_t C> NMatrix(const Matrix<T, R, C> &matrix) : data(R, C) {
    for (size_t i = 0; i < R; ++i)
      for (size_t j = 0; j < C; ++j)
        data(i, j) = matrix(i, j);
  }
  /**
   * @brief Access element at specified row and column.
   * @param row Row index.
   * @param col Column index.
   * @return Reference to the element at the specified row and column.
   */
  constexpr const T &operator()(size_t row, size_t col) const { return data(row, col); }
  /**
   * @brief Access element at specified row and column.
   * @param row Row index.
   * @param col Column index.
   * @return Reference to the element at the specified row and column.
   */
  T &operator()(size_t row, size_t col) { return data(row, col); }
  /// Element-wise addition. @throws std::invalid_argument if the shapes differ.
  NMatrix<T> &operator+=(const NMatrix<T> &other) {
    if (data.shape != other.data.shape)
      throw std::invalid_argument("Matrix dimensions do not match for addition.");
    for (size_t i = 0; i < data.shape[0]; ++i)
      for (size_t j = 0; j < data.shape[1]; ++j)
        data(i, j) += other(i, j);
    return *this;
  }
  /// Element-wise subtraction. @throws std::invalid_argument if the shapes differ.
  NMatrix<T> &operator-=(const NMatrix<T> &other) {
    if (data.shape != other.data.shape)
      throw std::invalid_argument("Matrix dimensions do not match for subtraction.");
    for (size_t i = 0; i < data.shape[0]; ++i)
      for (size_t j = 0; j < data.shape[1]; ++j)
        data(i, j) -= other(i, j);
    return *this;
  }
  /**
   * @brief In-place right multiplication: *this = *this * other.
   * @details narray's `shape` member is const, so NMatrix has no copy assignment and cannot
   *          change shape in place. The product is therefore computed separately and its
   *          element storage swapped in. That is only valid when the shape is preserved,
   *          i.e. when @p other is square.
   * @throws std::invalid_argument if the inner dimensions differ or @p other is not square.
   */
  NMatrix<T> &operator*=(const NMatrix<T> &other) {
    if (data.shape[1] != other.data.shape[0])
      throw std::invalid_argument("Matrix dimensions do not match for multiplication.");
    if (other.data.shape[0] != other.data.shape[1])
      throw std::invalid_argument("In-place multiplication requires a square right operand.");
    NMatrix<T> result = *this * other;
    data.data.swap(result.data.data);
    return *this;
  }
  /// Element-wise sum. @throws std::invalid_argument if the shapes differ.
  friend NMatrix<T> operator+(const NMatrix<T> &a, const NMatrix<T> &b) {
    if (a.data.shape != b.data.shape)
      throw std::invalid_argument("Matrix dimensions do not match for addition.");
    NMatrix<T> result(a.data.shape[0], a.data.shape[1]);
    for (size_t i = 0; i < a.data.shape[0]; ++i)
      for (size_t j = 0; j < a.data.shape[1]; ++j)
        result(i, j) = a(i, j) + b(i, j);
    return result;
  }
  /// Element-wise difference. @throws std::invalid_argument if the shapes differ.
  friend NMatrix<T> operator-(const NMatrix<T> &a, const NMatrix<T> &b) {
    if (a.data.shape != b.data.shape)
      throw std::invalid_argument("Matrix dimensions do not match for subtraction.");
    NMatrix<T> result(a.data.shape[0], a.data.shape[1]);
    for (size_t i = 0; i < a.data.shape[0]; ++i)
      for (size_t j = 0; j < a.data.shape[1]; ++j)
        result(i, j) = a(i, j) - b(i, j);
    return result;
  }
  /**
   * @brief Matrix product.
   * @details i-k-j loop order so the innermost loop walks rows of @p b and of the result
   *          contiguously. a(i, k) is loop-invariant there and read once.
   * @throws std::invalid_argument if a's column count differs from b's row count.
   */
  friend NMatrix<T> operator*(const NMatrix<T> &a, const NMatrix<T> &b) {
    if (a.data.shape[1] != b.data.shape[0])
      throw std::invalid_argument("Matrix dimensions do not match for multiplication.");
    size_t R = a.data.shape[0];
    size_t C = b.data.shape[1];
    size_t K = a.data.shape[1];
    NMatrix<T> result(R, C);
    for (size_t i = 0; i < R; ++i)
      for (size_t k = 0; k < K; ++k) {
        const T aik = a(i, k);
        for (size_t j = 0; j < C; ++j)
          result(i, j) += aik * b(k, j);
      }
    return result;
  }
};
struct BMatrix {
size_t rows, cols;
std::vector<dbitset> data;
BMatrix(size_t rows, size_t cols)
: rows(rows), cols(cols), data(rows, dbitset(cols)) {}
/**
* @brief Access element at specified row and column.
* @param row Row index.
* @param col Column index.
* @return Reference to the element at the specified row and column.
*/
dbitset::reference operator()(size_t row, size_t col) { return data[row][col]; }
/**
* @brief Access element at specified row and column.
* @param row Row index.
* @param col Column index.
* @return Reference to the element at the specified row and column.
*/
dbitset::const_reference operator()(size_t row, size_t col) const {
return data[row][col];
}
BMatrix &operator&=(const BMatrix &other) {
for (size_t i = 0; i < rows; ++i)
data[i] &= other.data[i];
return *this;
}
BMatrix &operator|=(const BMatrix &other) {
for (size_t i = 0; i < rows; ++i)
data[i] |= other.data[i];
return *this;
}
BMatrix &operator^=(const BMatrix &other) {
for (size_t i = 0; i < rows; ++i)
data[i] ^= other.data[i];
return *this;
}
BMatrix operator*=(const BMatrix &other) {
if (cols != other.rows)
throw std::invalid_argument("Matrix dimensions do not match for multiplication.");
BMatrix result(rows, other.cols);
for (size_t i = 0; i < rows; ++i)
for (size_t k = 0; k < cols; ++k)
if (data[i][k])
result.data[i] ^= other.data[k];
return *this = result;
}
friend BMatrix operator&(const BMatrix &a, const BMatrix &b) {
if (a.rows != b.rows || a.cols != b.cols)
throw std::invalid_argument("Matrix dimensions do not match for AND operation.");
BMatrix result(a.rows, a.cols);
for (size_t i = 0; i < a.rows; ++i)
for (size_t j = 0; j < a.cols; ++j)
result(i, j) = static_cast<bool>(a(i, j)) && static_cast<bool>(b(i, j));
return result;
}
friend BMatrix operator|(const BMatrix &a, const BMatrix &b) {
if (a.rows != b.rows || a.cols != b.cols)
throw std::invalid_argument("Matrix dimensions do not match for OR operation.");
BMatrix result(a.rows, a.cols);
for (size_t i = 0; i < a.rows; ++i)
for (size_t j = 0; j < a.cols; ++j)
result(i, j) = static_cast<bool>(a(i, j)) || static_cast<bool>(b(i, j));
return result;
}
friend BMatrix operator^(const BMatrix &a, const BMatrix &b) {
if (a.rows != b.rows || a.cols != b.cols)
throw std::invalid_argument("Matrix dimensions do not match for XOR operation.");
BMatrix result(a.rows, a.cols);
for (size_t i = 0; i < a.rows; ++i)
for (size_t j = 0; j < a.cols; ++j)
result(i, j) = static_cast<bool>(a(i, j)) ^ static_cast<bool>(b(i, j));
return result;
}
friend BMatrix operator~(const BMatrix &a) {
BMatrix result(a.rows, a.cols);
for (size_t i = 0; i < a.rows; ++i)
for (size_t j = 0; j < a.cols; ++j)
result(i, j) = !static_cast<bool>(a(i, j));
return result;
}
friend BMatrix operator*(const BMatrix &a, const BMatrix &b) {
if (a.cols != b.rows)
throw std::invalid_argument("Matrix dimensions do not match for multiplication.");
BMatrix result(a.rows, b.cols);
for (size_t i = 0; i < a.rows; ++i)
for (size_t k = 0; k < a.cols; ++k)
if (a(i, k))
result.data[i] ^= b.data[k];
return result;
}
};
} // namespace weilycoder
#line 1 "weilycoder/number_theory/modint.hpp"
/**
* @file modint.hpp
* @brief Modular Integer Class
*/
#line 1 "weilycoder/number_theory/mod_utility.hpp"
/**
* @file mod_utility.hpp
* @brief Modular Arithmetic Utilities
*/
#line 10 "weilycoder/number_theory/mod_utility.hpp"
namespace weilycoder {
/// Unsigned 128-bit type for intermediate results that could overflow 64 bits.
using u128 = unsigned __int128;
/**
 * @brief Modular addition with a run-time modulus.
 * @tparam bit32 If true, the sum is formed in 64-bit arithmetic, which is only correct when
 *         a + b cannot overflow (e.g. every value fits in 32 bits). If false, 128-bit
 *         arithmetic is used.
 * @param a First addend, expected to be reduced (< modulus).
 * @param b Second addend, expected to be reduced (< modulus).
 * @param modulus The modulus.
 * @return (a + b) mod modulus; a single conditional subtraction is performed.
 */
template <bool bit32 = false>
constexpr uint64_t mod_add(uint64_t a, uint64_t b, uint64_t modulus) {
  if constexpr (!bit32) {
    const u128 sum = static_cast<u128>(a) + b;
    return static_cast<uint64_t>(sum < modulus ? sum : sum - modulus);
  } else {
    const uint64_t sum = a + b;
    return sum < modulus ? sum : sum - modulus;
  }
}
/**
 * @brief Modular addition with a compile-time modulus.
 * @details 64-bit arithmetic is used when Modulus fits in 32 bits, 128-bit otherwise.
 * @tparam Modulus The modulus.
 * @param a First addend, expected to be reduced (< Modulus).
 * @param b Second addend, expected to be reduced (< Modulus).
 * @return (a + b) mod Modulus
 */
template <uint64_t Modulus> constexpr uint64_t mod_add(uint64_t a, uint64_t b) {
  if constexpr (Modulus > UINT32_MAX) {
    const u128 sum = static_cast<u128>(a) + b;
    return static_cast<uint64_t>(sum < Modulus ? sum : sum - Modulus);
  } else {
    const uint64_t sum = a + b;
    return sum < Modulus ? sum : sum - Modulus;
  }
}
/**
 * @brief Modular subtraction with a run-time modulus.
 * @tparam bit32 If true, the wrap-around sum modulus + a - b is formed in 64-bit arithmetic;
 *         otherwise in 128-bit arithmetic.
 * @param a The minuend, expected to be reduced (< modulus).
 * @param b The subtrahend, expected to be reduced (< modulus).
 * @param modulus The modulus.
 * @return (a - b) mod modulus, in [0, modulus) for reduced inputs.
 */
template <bool bit32 = false>
constexpr uint64_t mod_sub(uint64_t a, uint64_t b, uint64_t modulus) {
  if (a >= b)
    return a - b;
  if constexpr (bit32)
    return modulus + a - b;
  else
    return static_cast<uint64_t>(static_cast<u128>(modulus) + a - b);
}
/**
 * @brief Modular subtraction with a compile-time modulus.
 * @tparam Modulus The modulus.
 * @param a The minuend, expected to be reduced (< Modulus).
 * @param b The subtrahend, expected to be reduced (< Modulus).
 * @return (a - b) mod Modulus
 */
template <uint64_t Modulus> constexpr uint64_t mod_sub(uint64_t a, uint64_t b) {
  if (a >= b)
    return a - b;
  if constexpr (Modulus > UINT32_MAX)
    return static_cast<uint64_t>(static_cast<u128>(Modulus) + a - b);
  else
    return Modulus + a - b;
}
/**
 * @brief Modular multiplication with a run-time modulus.
 * @tparam bit32 If true, the product is formed in 64-bit arithmetic, which is only correct
 *         when a * b cannot overflow; otherwise in 128-bit arithmetic.
 * @param a First factor.
 * @param b Second factor.
 * @param modulus The modulus.
 * @return (a * b) mod modulus
 */
template <bool bit32 = false>
constexpr uint64_t mod_mul(uint64_t a, uint64_t b, uint64_t modulus) {
  if constexpr (!bit32)
    return static_cast<uint64_t>(static_cast<u128>(a) * b % modulus);
  else
    return a * b % modulus;
}
/**
 * @brief Modular multiplication with a compile-time modulus.
 * @tparam Modulus The modulus.
 * @param a First factor, expected to be reduced (< Modulus).
 * @param b Second factor, expected to be reduced (< Modulus).
 * @return (a * b) mod Modulus
 */
template <uint64_t Modulus> constexpr uint64_t mod_mul(uint64_t a, uint64_t b) {
  if constexpr (Modulus > UINT32_MAX)
    return static_cast<uint64_t>(static_cast<u128>(a) * b % Modulus);
  else
    return a * b % Modulus;
}
/**
 * @brief Modular exponentiation (binary / square-and-multiply) with a run-time modulus.
 * @tparam bit32 Forwarded to mod_mul; see its overflow caveat.
 * @param base The base.
 * @param exponent The exponent.
 * @param modulus The modulus.
 * @return base^exponent mod modulus (0^0 yields 1 % modulus).
 */
template <bool bit32 = false>
constexpr uint64_t mod_pow(uint64_t base, uint64_t exponent, uint64_t modulus) {
  uint64_t acc = 1 % modulus;
  base %= modulus;
  for (; exponent != 0; exponent >>= 1) {
    if (exponent & 1)
      acc = mod_mul<bit32>(acc, base, modulus);
    base = mod_mul<bit32>(base, base, modulus);
  }
  return acc;
}
/**
 * @brief Modular exponentiation (binary / square-and-multiply) with a compile-time modulus.
 * @tparam Modulus The modulus.
 * @param base The base.
 * @param exponent The exponent.
 * @return base^exponent mod Modulus
 */
template <uint64_t Modulus>
constexpr uint64_t mod_pow(uint64_t base, uint64_t exponent) {
  uint64_t acc = 1 % Modulus;
  base %= Modulus;
  for (; exponent != 0; exponent >>= 1) {
    if (exponent & 1)
      acc = mod_mul<Modulus>(acc, base);
    base = mod_mul<Modulus>(base, base);
  }
  return acc;
}
/**
 * @brief Modular inverse via Fermat's little theorem: a^(Modulus - 2).
 * @tparam Modulus The modulus; must be prime for the result to be an inverse.
 * @param a Value to invert; must not be a multiple of Modulus.
 * @return a^-1 mod Modulus
 */
template <uint64_t Modulus> constexpr uint64_t mod_inv(uint64_t a) {
  constexpr uint64_t fermat_exponent = Modulus - 2;
  return mod_pow<Modulus>(a, fermat_exponent);
}
/**
 * @brief Modular inverse of a compile-time value via Fermat's little theorem.
 * @tparam Modulus The modulus; must be prime for the result to be an inverse.
 * @tparam a Value to invert; must not be a multiple of Modulus.
 * @return a^-1 mod Modulus
 */
template <uint64_t Modulus, uint64_t a> constexpr uint64_t mod_inv() {
  return mod_inv<Modulus>(a);
}
} // namespace weilycoder
#line 11 "weilycoder/number_theory/modint.hpp"
#include <istream>
#include <ostream>
namespace weilycoder {
/**
 * @brief Modular Integer with compile-time modulus.
 * @details The stored value is kept in [0, Modulus): every constructor reduces its argument,
 *          and arithmetic goes through mod_add / mod_sub / mod_mul, which expect reduced operands.
 * @tparam Modulus The modulus.
 */
template <uint64_t Modulus> struct modint {
private:
  uint64_t value;

public:
  /// Zero.
  constexpr modint() : value(0) {}
  constexpr modint(uint32_t v) : value(v % Modulus) {}
  // `value` is initialised before from_i64 assigns it: a C++17 constexpr constructor
  // must initialise every member.
  constexpr modint(int32_t v) : value(0) { from_i64(v); }
  constexpr modint(uint64_t v) : value(v % Modulus) {}
  constexpr modint(int64_t v) : value(0) { from_i64(v); }
  /**
   * @brief Construct from any other integral type of at most 64 bits.
   * @details Without this overload, a type that is none of the four fixed-width types above
   *          (e.g. `long long` where int64_t is `long`) has equally ranked conversions to
   *          all of them, and the call is ambiguous. Exact matches still pick the
   *          non-template constructors.
   */
  template <typename Int,
            std::enable_if_t<std::is_integral_v<Int> && (sizeof(Int) <= sizeof(uint64_t)),
                             int> = 0>
  constexpr modint(Int v) : value(0) {
    if constexpr (std::is_signed_v<Int>)
      from_i64(static_cast<int64_t>(v));
    else
      value = static_cast<uint64_t>(v) % Modulus;
  }
  /// Set the value to v mod Modulus, mapping negative inputs into [0, Modulus).
  constexpr void from_i64(int64_t v) {
    int64_t x = v % static_cast<int64_t>(Modulus);
    if (x < 0)
      x += Modulus;
    value = static_cast<uint64_t>(x);
  }
  /// The reduced representative in [0, Modulus).
  explicit constexpr operator uint64_t() const { return value; }
  friend constexpr modint<Modulus> operator+(const modint<Modulus> &lhs,
                                             const modint<Modulus> &rhs) {
    return modint<Modulus>(mod_add<Modulus>(lhs.value, rhs.value));
  }
  friend constexpr modint<Modulus> operator-(const modint<Modulus> &lhs,
                                             const modint<Modulus> &rhs) {
    return modint<Modulus>(mod_sub<Modulus>(lhs.value, rhs.value));
  }
  friend constexpr modint<Modulus> operator*(const modint<Modulus> &lhs,
                                             const modint<Modulus> &rhs) {
    return modint<Modulus>(mod_mul<Modulus>(lhs.value, rhs.value));
  }
  /// Additive inverse.
  constexpr modint operator-() const { return modint() - *this; }
  friend constexpr bool operator==(const modint &lhs, const modint &rhs) {
    return lhs.value == rhs.value;
  }
  friend constexpr bool operator!=(const modint &lhs, const modint &rhs) {
    return lhs.value != rhs.value;
  }
  constexpr modint &operator+=(const modint &other) {
    value = mod_add<Modulus>(value, other.value);
    return *this;
  }
  constexpr modint &operator-=(const modint &other) {
    value = mod_sub<Modulus>(value, other.value);
    return *this;
  }
  constexpr modint &operator*=(const modint &other) {
    value = mod_mul<Modulus>(value, other.value);
    return *this;
  }
  friend std::ostream &operator<<(std::ostream &os, const modint &m) {
    return os << m.value;
  }
  /// Read a signed 64-bit integer and reduce it; @p m is left unchanged if the read fails.
  friend std::istream &operator>>(std::istream &is, modint &m) {
    int64_t v;
    if (is >> v)
      m.from_i64(v);
    return is;
  }
};
} // namespace weilycoder
#line 6 "test/matrix_product.test.cpp"
#include <iostream>
using namespace std;
using namespace weilycoder;
int main() {
cin.tie(nullptr)->sync_with_stdio(false);
cin.exceptions(cin.failbit | cin.badbit);
size_t N, M, K;
cin >> N >> M >> K;
NMatrix<modint<998244353>> A(N, M), B(M, K);
for (size_t i = 0; i < N; ++i)
for (size_t j = 0; j < M; ++j)
cin >> A(i, j);
for (size_t i = 0; i < M; ++i)
for (size_t j = 0; j < K; ++j)
cin >> B(i, j);
auto C = A * B;
for (size_t i = 0; i < N; ++i)
for (size_t j = 0; j < K; ++j)
cout << C(i, j) << " \n"[j + 1 == K];
return 0;
}