cp-library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub weilycoder/cp-library

✔ Matrix implementation using narray
(weilycoder/matrix.hpp)

Depends on

Verified with

Code

#ifndef WEILYCODER_MATRIX_HPP
#define WEILYCODER_MATRIX_HPP

/**
 * @file matrix.hpp
 * @brief Matrix implementation using narray
 */

#include "ds/bitset.hpp"
#include "ds/narray.hpp"
#include <cstddef>
#include <stdexcept>
#include <vector>

namespace weilycoder {
/**
 * @brief Fixed-size R x C matrix stored inline as a 2-D array.
 * @tparam T Element type.
 * @tparam R Number of rows.
 * @tparam C Number of columns.
 */
template <typename T, size_t R, size_t C> struct Matrix {
  T data[R][C];

  /// @brief Construct a matrix whose elements are all value-initialized.
  Matrix() : data{} {}

  /**
   * @brief Construct a matrix from elements listed in row-major order.
   *
   * Elements not covered by @p init are value-initialized.
   * @param init At most R * C elements, row by row.
   * @throws std::invalid_argument if @p init holds more than R * C elements.
   */
  Matrix(std::initializer_list<T> init) : data{} {
    if (init.size() > R * C)
      throw std::invalid_argument("Too many elements for matrix initialization.");
    size_t i = 0, j = 0;
    for (const auto &val : init) {
      data[i][j] = val;
      if (++j == C)
        j = 0, ++i;
    }
  }

  /**
   * @brief Access element at specified row and column.
   * @param row Row index.
   * @param col Column index.
   * @return Reference to the element at the specified row and column.
   */
  constexpr const T &operator()(size_t row, size_t col) const { return data[row][col]; }

  /**
   * @brief Access element at specified row and column.
   * @param row Row index.
   * @param col Column index.
   * @return Reference to the element at the specified row and column.
   */
  T &operator()(size_t row, size_t col) { return data[row][col]; }

  /// @brief Element-wise addition in place.
  Matrix<T, R, C> &operator+=(const Matrix<T, R, C> &other) {
    for (size_t i = 0; i < R; ++i)
      for (size_t j = 0; j < C; ++j)
        data[i][j] += other.data[i][j];
    return *this;
  }

  /// @brief Element-wise subtraction in place.
  Matrix<T, R, C> &operator-=(const Matrix<T, R, C> &other) {
    for (size_t i = 0; i < R; ++i)
      for (size_t j = 0; j < C; ++j)
        data[i][j] -= other.data[i][j];
    return *this;
  }

  /**
   * @brief In-place product: *this = *this * other.
   *
   * The shape of *this is fixed by its type, so the right-hand operand must be
   * square (M == C); any other M is rejected at compile time.
   */
  template <size_t M> Matrix<T, R, M> &operator*=(const Matrix<T, C, M> &other) {
    static_assert(M == C, "operator*= requires a square (C x C) right-hand operand.");
    Matrix<T, R, C> result;
    for (size_t i = 0; i < R; ++i)
      for (size_t k = 0; k < C; ++k)
        for (size_t j = 0; j < C; ++j)
          result.data[i][j] += data[i][k] * other(k, j);
    return *this = result;
  }

  /// @brief Element-wise sum of two matrices.
  friend Matrix<T, R, C> operator+(const Matrix<T, R, C> &a, const Matrix<T, R, C> &b) {
    Matrix<T, R, C> result;
    for (size_t i = 0; i < R; ++i)
      for (size_t j = 0; j < C; ++j)
        result(i, j) = a(i, j) + b(i, j);
    return result;
  }

  /// @brief Element-wise difference of two matrices.
  friend Matrix<T, R, C> operator-(const Matrix<T, R, C> &a, const Matrix<T, R, C> &b) {
    Matrix<T, R, C> result;
    for (size_t i = 0; i < R; ++i)
      for (size_t j = 0; j < C; ++j)
        result(i, j) = a(i, j) - b(i, j);
    return result;
  }
};

/**
 * @brief Matrix product of an R x K matrix and a K x C matrix.
 *
 * Declared at namespace scope (found through ADL) so that products of arbitrary
 * compatible shapes resolve, not only those where one operand has the result type.
 * @return The R x C product.
 */
template <typename T, size_t R, size_t K, size_t C>
Matrix<T, R, C> operator*(const Matrix<T, R, K> &a, const Matrix<T, K, C> &b) {
  Matrix<T, R, C> result;
  for (size_t i = 0; i < R; ++i)
    for (size_t k = 0; k < K; ++k)
      for (size_t j = 0; j < C; ++j)
        result(i, j) += a(i, k) * b(k, j);
  return result;
}

/**
 * @brief Run-time sized dense matrix backed by a 2-D narray.
 * @tparam T Element type.
 */
template <typename T> struct NMatrix {
  narray<T, 2> data;

  /// @brief Construct a rows x cols matrix (elements initialized by narray).
  NMatrix(size_t rows, size_t cols) : data(rows, cols) {}

  /// @brief Copy a fixed-size Matrix into a run-time sized one.
  template <size_t R, size_t C> NMatrix(const Matrix<T, R, C> &matrix) : data(R, C) {
    for (size_t i = 0; i < R; ++i)
      for (size_t j = 0; j < C; ++j)
        data(i, j) = matrix(i, j);
  }

  /**
   * @brief Access element at specified row and column.
   * @param row Row index.
   * @param col Column index.
   * @return Reference to the element at the specified row and column.
   */
  constexpr const T &operator()(size_t row, size_t col) const { return data(row, col); }

  /**
   * @brief Access element at specified row and column.
   * @param row Row index.
   * @param col Column index.
   * @return Reference to the element at the specified row and column.
   */
  T &operator()(size_t row, size_t col) { return data(row, col); }

  /// @brief Element-wise addition in place.
  /// @throws std::invalid_argument if the shapes differ.
  NMatrix<T> &operator+=(const NMatrix<T> &other) {
    if (data.shape != other.data.shape)
      throw std::invalid_argument("Matrix dimensions do not match for addition.");
    for (size_t i = 0; i < data.shape[0]; ++i)
      for (size_t j = 0; j < data.shape[1]; ++j)
        data(i, j) += other(i, j);
    return *this;
  }

  /// @brief Element-wise subtraction in place.
  /// @throws std::invalid_argument if the shapes differ.
  NMatrix<T> &operator-=(const NMatrix<T> &other) {
    if (data.shape != other.data.shape)
      throw std::invalid_argument("Matrix dimensions do not match for subtraction.");
    for (size_t i = 0; i < data.shape[0]; ++i)
      for (size_t j = 0; j < data.shape[1]; ++j)
        data(i, j) -= other(i, j);
    return *this;
  }

  /**
   * @brief In-place product: *this = *this * other.
   *
   * narray's shape is const, so this matrix cannot change shape in place. The
   * product is therefore only representable here when @p other is square.
   * @throws std::invalid_argument if the inner dimensions differ or @p other is not square.
   */
  NMatrix<T> &operator*=(const NMatrix<T> &other) {
    if (data.shape[1] != other.data.shape[0] || other.data.shape[0] != other.data.shape[1])
      throw std::invalid_argument("Matrix dimensions do not match for multiplication.");
    NMatrix<T> result = *this * other;
    // NMatrix copy-assignment is implicitly deleted (narray::shape is const), so adopt the
    // product's element storage directly; both shapes are identical at this point.
    data.data.swap(result.data.data);
    return *this;
  }

  /// @brief Element-wise sum.
  /// @throws std::invalid_argument if the shapes differ.
  friend NMatrix<T> operator+(const NMatrix<T> &a, const NMatrix<T> &b) {
    if (a.data.shape != b.data.shape)
      throw std::invalid_argument("Matrix dimensions do not match for addition.");
    NMatrix<T> result(a.data.shape[0], a.data.shape[1]);
    for (size_t i = 0; i < a.data.shape[0]; ++i)
      for (size_t j = 0; j < a.data.shape[1]; ++j)
        result(i, j) = a(i, j) + b(i, j);
    return result;
  }

  /// @brief Element-wise difference.
  /// @throws std::invalid_argument if the shapes differ.
  friend NMatrix<T> operator-(const NMatrix<T> &a, const NMatrix<T> &b) {
    if (a.data.shape != b.data.shape)
      throw std::invalid_argument("Matrix dimensions do not match for subtraction.");
    NMatrix<T> result(a.data.shape[0], a.data.shape[1]);
    for (size_t i = 0; i < a.data.shape[0]; ++i)
      for (size_t j = 0; j < a.data.shape[1]; ++j)
        result(i, j) = a(i, j) - b(i, j);
    return result;
  }

  /**
   * @brief Matrix product of an R x K and a K x C matrix.
   * @throws std::invalid_argument if the inner dimensions differ.
   */
  friend NMatrix<T> operator*(const NMatrix<T> &a, const NMatrix<T> &b) {
    if (a.data.shape[1] != b.data.shape[0])
      throw std::invalid_argument("Matrix dimensions do not match for multiplication.");
    size_t R = a.data.shape[0];
    size_t C = b.data.shape[1];
    size_t K = a.data.shape[1];
    NMatrix<T> result(R, C);
    // i-k-j order walks rows of b and result contiguously.
    for (size_t i = 0; i < R; ++i)
      for (size_t k = 0; k < K; ++k) {
        const T &aik = a(i, k);
        for (size_t j = 0; j < C; ++j)
          result(i, j) += aik * b(k, j);
      }
    return result;
  }
};

struct BMatrix {
  size_t rows, cols;
  std::vector<dbitset> data;

  BMatrix(size_t rows, size_t cols)
      : rows(rows), cols(cols), data(rows, dbitset(cols)) {}

  /**
   * @brief Access element at specified row and column.
   * @param row Row index.
   * @param col Column index.
   * @return Reference to the element at the specified row and column.
   */
  dbitset::reference operator()(size_t row, size_t col) { return data[row][col]; }

  /**
   * @brief Access element at specified row and column.
   * @param row Row index.
   * @param col Column index.
   * @return Reference to the element at the specified row and column.
   */
  dbitset::const_reference operator()(size_t row, size_t col) const {
    return data[row][col];
  }

  BMatrix &operator&=(const BMatrix &other) {
    for (size_t i = 0; i < rows; ++i)
      data[i] &= other.data[i];
    return *this;
  }

  BMatrix &operator|=(const BMatrix &other) {
    for (size_t i = 0; i < rows; ++i)
      data[i] |= other.data[i];
    return *this;
  }

  BMatrix &operator^=(const BMatrix &other) {
    for (size_t i = 0; i < rows; ++i)
      data[i] ^= other.data[i];
    return *this;
  }

  BMatrix operator*=(const BMatrix &other) {
    if (cols != other.rows)
      throw std::invalid_argument("Matrix dimensions do not match for multiplication.");
    BMatrix result(rows, other.cols);
    for (size_t i = 0; i < rows; ++i)
      for (size_t k = 0; k < cols; ++k)
        if (data[i][k])
          result.data[i] ^= other.data[k];
    return *this = result;
  }

  friend BMatrix operator&(const BMatrix &a, const BMatrix &b) {
    if (a.rows != b.rows || a.cols != b.cols)
      throw std::invalid_argument("Matrix dimensions do not match for AND operation.");
    BMatrix result(a.rows, a.cols);
    for (size_t i = 0; i < a.rows; ++i)
      for (size_t j = 0; j < a.cols; ++j)
        result(i, j) = static_cast<bool>(a(i, j)) && static_cast<bool>(b(i, j));
    return result;
  }

  friend BMatrix operator|(const BMatrix &a, const BMatrix &b) {
    if (a.rows != b.rows || a.cols != b.cols)
      throw std::invalid_argument("Matrix dimensions do not match for OR operation.");
    BMatrix result(a.rows, a.cols);
    for (size_t i = 0; i < a.rows; ++i)
      for (size_t j = 0; j < a.cols; ++j)
        result(i, j) = static_cast<bool>(a(i, j)) || static_cast<bool>(b(i, j));
    return result;
  }

  friend BMatrix operator^(const BMatrix &a, const BMatrix &b) {
    if (a.rows != b.rows || a.cols != b.cols)
      throw std::invalid_argument("Matrix dimensions do not match for XOR operation.");
    BMatrix result(a.rows, a.cols);
    for (size_t i = 0; i < a.rows; ++i)
      for (size_t j = 0; j < a.cols; ++j)
        result(i, j) = static_cast<bool>(a(i, j)) ^ static_cast<bool>(b(i, j));
    return result;
  }

  friend BMatrix operator~(const BMatrix &a) {
    BMatrix result(a.rows, a.cols);
    for (size_t i = 0; i < a.rows; ++i)
      for (size_t j = 0; j < a.cols; ++j)
        result(i, j) = !static_cast<bool>(a(i, j));
    return result;
  }

  friend BMatrix operator*(const BMatrix &a, const BMatrix &b) {
    if (a.cols != b.rows)
      throw std::invalid_argument("Matrix dimensions do not match for multiplication.");
    BMatrix result(a.rows, b.cols);
    for (size_t i = 0; i < a.rows; ++i)
      for (size_t k = 0; k < a.cols; ++k)
        if (a(i, k))
          result.data[i] ^= b.data[k];
    return result;
  }
};
} // namespace weilycoder

#endif
#line 1 "weilycoder/matrix.hpp"



/**
 * @file matrix.hpp
 * @brief Matrix implementation using narray
 */

#line 1 "weilycoder/ds/bitset.hpp"



/**
 * @file bitset.hpp
 * @brief Run-time sized bitset
 */

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

namespace weilycoder {
/**
 * @class dbitset
 * @brief Run-time sized bitset implementation
 *
 * Bit `pos` is stored in word `pos / 64` at bit position `pos % 64`. The bulk
 * operations call trim_extra() to clear the unused high bits of the last word.
 */
class dbitset {
private:
  const size_t bitsize;
  std::vector<uint64_t> bits;

  /// @brief Number of unused high bits in the last storage word.
  size_t extra_size() const { return bits.size() * 64 - bitsize; }

  /// @brief Mask selecting the in-range bits of the last storage word.
  uint64_t extra_mask() const {
    size_t extra = extra_size();
    return extra ? ((1ULL << (64 - extra)) - 1) : ~0ULL;
  }

  /// @brief Clear the out-of-range bits of the last word; no-op for an empty bitset.
  void trim_extra() {
    if (!bits.empty())
      bits.back() &= extra_mask();
  }

public:
  dbitset(size_t n) : bitsize(n), bits((n + 63) / 64, 0) {}
  dbitset(const dbitset &other) : bitsize(other.bitsize), bits(other.bits) {}

  /// @brief Copy the bits of an equally sized bitset.
  /// @throws std::invalid_argument if the sizes differ.
  dbitset &operator=(const dbitset &other) {
    if (this != &other) {
      if (bitsize != other.bitsize)
        throw std::invalid_argument("Bitset sizes do not match for assignment.");
      bits = other.bits;
    }
    return *this;
  }

  /// @brief Writable proxy for a single bit.
  struct reference {
    dbitset &bs;
    size_t pos;

    reference(dbitset &bs, size_t pos) : bs(bs), pos(pos) {}

    operator bool() const {
      size_t idx = pos / 64;
      size_t bit = pos % 64;
      return (bs.bits[idx] >> bit) & 1ULL;
    }

    reference &operator=(bool val) {
      size_t idx = pos / 64;
      size_t bit = pos % 64;
      if (val)
        bs.bits[idx] |= (1ULL << bit);
      else
        bs.bits[idx] &= ~(1ULL << bit);
      return *this;
    }
  };

  /// @brief Read-only proxy for a single bit.
  struct const_reference {
    const dbitset &bs;
    size_t pos;

    const_reference(const dbitset &bs, size_t pos) : bs(bs), pos(pos) {}

    operator bool() const {
      size_t idx = pos / 64;
      size_t bit = pos % 64;
      return (bs.bits[idx] >> bit) & 1ULL;
    }
  };

  struct iterator {
    dbitset &bs;
    size_t pos;

    iterator(dbitset &bs, size_t pos) : bs(bs), pos(pos) {}
    bool operator!=(const iterator &other) const { return pos != other.pos; }
    reference operator*() { return reference(bs, pos); }
    iterator &operator++() { return ++pos, *this; }
    iterator operator++(int) {
      iterator temp = *this;
      ++pos;
      return temp;
    }
  };

  struct const_iterator {
    const dbitset &bs;
    size_t pos;

    const_iterator(const dbitset &bs, size_t pos) : bs(bs), pos(pos) {}
    bool operator!=(const const_iterator &other) const { return pos != other.pos; }
    const_reference operator*() const { return const_reference(bs, pos); }
    const_iterator &operator++() { return ++pos, *this; }
    const_iterator operator++(int) {
      const_iterator temp = *this;
      ++pos;
      return temp;
    }
  };

  iterator begin() { return iterator(*this, 0); }
  iterator end() { return iterator(*this, bitsize); }
  const_iterator begin() const { return const_iterator(*this, 0); }
  const_iterator end() const { return const_iterator(*this, bitsize); }

  reference operator[](size_t pos) { return reference(*this, pos); }
  const_reference operator[](size_t pos) const { return const_reference(*this, pos); }

  size_t size() const { return bitsize; }

  size_t d_size() const { return bits.size(); }
  uint64_t &d_get_word(size_t idx) { return bits[idx]; }
  const uint64_t &d_get_word(size_t idx) const { return bits[idx]; }

  dbitset &operator&=(const dbitset &other) {
    if (bitsize != other.bitsize)
      throw std::invalid_argument("Bitset sizes do not match for AND operation.");
    for (size_t i = 0; i < bits.size(); ++i)
      bits[i] &= other.bits[i];
    return this->trim_extra(), *this;
  }

  dbitset &operator|=(const dbitset &other) {
    if (bitsize != other.bitsize)
      throw std::invalid_argument("Bitset sizes do not match for OR operation.");
    for (size_t i = 0; i < bits.size(); ++i)
      bits[i] |= other.bits[i];
    return this->trim_extra(), *this;
  }

  dbitset &operator^=(const dbitset &other) {
    if (bitsize != other.bitsize)
      throw std::invalid_argument("Bitset sizes do not match for XOR operation.");
    for (size_t i = 0; i < bits.size(); ++i)
      bits[i] ^= other.bits[i];
    return this->trim_extra(), *this;
  }

  /// @brief Shift towards higher positions; bits shifted past size() are discarded.
  dbitset &operator<<=(size_t shift) {
    if (shift >= bitsize) {
      std::fill(bits.begin(), bits.end(), 0);
      return *this;
    }
    size_t word_shift = shift / 64;
    size_t bit_shift = shift % 64;
    if (bit_shift == 0) {
      // Count down with `i-- > word_shift` so the loop terminates when word_shift == 0;
      // `i >= 0` on an unsigned index never becomes false.
      for (size_t i = bits.size(); i-- > word_shift;)
        bits[i] = bits[i - word_shift];
    } else {
      for (size_t i = bits.size() - 1; i > word_shift; --i)
        bits[i] = (bits[i - word_shift] << bit_shift) |
                  (bits[i - word_shift - 1] >> (64 - bit_shift));
      bits[word_shift] = bits[0] << bit_shift;
    }
    std::fill(bits.begin(), bits.begin() + word_shift, 0);
    return this->trim_extra(), *this;
  }

  /// @brief Shift towards lower positions; vacated high bits become zero.
  dbitset &operator>>=(size_t shift) {
    if (shift >= bitsize) {
      std::fill(bits.begin(), bits.end(), 0);
      return *this;
    }
    size_t word_shift = shift / 64;
    size_t bit_shift = shift % 64;
    size_t n = bits.size();
    if (bit_shift == 0) {
      for (size_t i = 0; i < n - word_shift; ++i)
        bits[i] = bits[i + word_shift];
    } else {
      for (size_t i = 0; i < n - word_shift - 1; ++i)
        bits[i] = (bits[i + word_shift] >> bit_shift) |
                  (bits[i + word_shift + 1] << (64 - bit_shift));
      bits[n - word_shift - 1] = bits[n - 1] >> bit_shift;
    }
    std::fill(bits.end() - word_shift, bits.end(), 0);
    return *this;
  }

  friend dbitset operator~(const dbitset &bs) {
    dbitset result(bs.bitsize);
    for (size_t i = 0; i < bs.bits.size(); ++i)
      result.bits[i] = ~bs.bits[i];
    result.trim_extra();
    return result;
  }

  friend dbitset operator&(const dbitset &a, const dbitset &b) {
    if (a.bitsize != b.bitsize)
      throw std::invalid_argument("Bitset sizes do not match for AND operation.");
    dbitset result(a.bitsize);
    for (size_t i = 0; i < a.bits.size(); ++i)
      result.bits[i] = a.bits[i] & b.bits[i];
    result.trim_extra();
    return result;
  }

  friend dbitset operator|(const dbitset &a, const dbitset &b) {
    if (a.bitsize != b.bitsize)
      throw std::invalid_argument("Bitset sizes do not match for OR operation.");
    dbitset result(a.bitsize);
    for (size_t i = 0; i < a.bits.size(); ++i)
      result.bits[i] = a.bits[i] | b.bits[i];
    result.trim_extra();
    return result;
  }

  friend dbitset operator^(const dbitset &a, const dbitset &b) {
    if (a.bitsize != b.bitsize)
      throw std::invalid_argument("Bitset sizes do not match for XOR operation.");
    dbitset result(a.bitsize);
    for (size_t i = 0; i < a.bits.size(); ++i)
      result.bits[i] = a.bits[i] ^ b.bits[i];
    result.trim_extra();
    return result;
  }

  friend dbitset operator<<(const dbitset &bs, size_t shift) {
    dbitset result = bs;
    result <<= shift;
    return result;
  }

  friend dbitset operator>>(const dbitset &bs, size_t shift) {
    dbitset result = bs;
    result >>= shift;
    return result;
  }
};
} // namespace weilycoder


#line 1 "weilycoder/ds/narray.hpp"



/**
 * @file narray.hpp
 * @brief N-dimensional array (narray)
 */

#include <array>
#include <cstddef>
#include <initializer_list>
#line 13 "weilycoder/ds/narray.hpp"

namespace weilycoder {
/**
 * @brief N-dimensional array stored contiguously in row-major order.
 * @tparam T Type of elements stored in the narray.
 * @tparam D Number of dimensions.
 */
template <typename T, size_t D> struct narray {
  const std::array<size_t, D> shape; ///< Extent of each dimension, fixed at construction.
  std::vector<T> data;               ///< Elements, last index varying fastest.

  /**
   * @brief Build an array with the given extents; every element is value-initialized.
   * @tparam Sizes Types of the extents (exactly D of them).
   * @param sizes Extent of each dimension, outermost first.
   */
  template <typename... Sizes>
  narray(Sizes... sizes) : shape{static_cast<size_t>(sizes)...}, data(size()) {
    static_assert(sizeof...(Sizes) == D, "Number of sizes must match dimensions.");
  }

  /**
   * @brief Total element count, i.e. the product of all extents.
   * @return Number of elements held by the array.
   */
  size_t size() const {
    size_t count = 1;
    for (size_t extent : shape)
      count *= extent;
    return count;
  }

  /**
   * @brief Row-major offset of a multi-dimensional position (no bounds checking).
   * @tparam Indices Types of the per-dimension indices (exactly D of them).
   * @param indices Index in each dimension, outermost first.
   * @return Offset into `data`.
   */
  template <typename... Indices> size_t _index(Indices... indices) const {
    static_assert(sizeof...(Indices) == D, "Number of indices must match dimensions.");
    const size_t coord[] = {static_cast<size_t>(indices)...};
    size_t offset = 0;
    size_t axis = 0;
    for (size_t extent : shape)
      offset = offset * extent + coord[axis++];
    return offset;
  }

  /**
   * @brief Read-only element access by multi-dimensional position.
   * @param indices Index in each dimension, outermost first.
   * @return Const reference to the addressed element.
   */
  template <typename... Indices> const T &operator()(Indices... indices) const {
    const size_t offset = _index(indices...);
    return data[offset];
  }

  /**
   * @brief Mutable element access by multi-dimensional position.
   * @param indices Index in each dimension, outermost first.
   * @return Reference to the addressed element.
   */
  template <typename... Indices> T &operator()(Indices... indices) {
    const size_t offset = _index(indices...);
    return data[offset];
  }
};
} // namespace weilycoder


#line 14 "weilycoder/matrix.hpp"

namespace weilycoder {
/**
 * @brief Fixed-size R x C matrix stored inline as a 2-D array.
 * @tparam T Element type.
 * @tparam R Number of rows.
 * @tparam C Number of columns.
 */
template <typename T, size_t R, size_t C> struct Matrix {
  T data[R][C];

  /// @brief Construct a matrix whose elements are all value-initialized.
  Matrix() : data{} {}

  /**
   * @brief Construct a matrix from elements listed in row-major order.
   *
   * Elements not covered by @p init are value-initialized.
   * @param init At most R * C elements, row by row.
   * @throws std::invalid_argument if @p init holds more than R * C elements.
   */
  Matrix(std::initializer_list<T> init) : data{} {
    if (init.size() > R * C)
      throw std::invalid_argument("Too many elements for matrix initialization.");
    size_t i = 0, j = 0;
    for (const auto &val : init) {
      data[i][j] = val;
      if (++j == C)
        j = 0, ++i;
    }
  }

  /**
   * @brief Access element at specified row and column.
   * @param row Row index.
   * @param col Column index.
   * @return Reference to the element at the specified row and column.
   */
  constexpr const T &operator()(size_t row, size_t col) const { return data[row][col]; }

  /**
   * @brief Access element at specified row and column.
   * @param row Row index.
   * @param col Column index.
   * @return Reference to the element at the specified row and column.
   */
  T &operator()(size_t row, size_t col) { return data[row][col]; }

  /// @brief Element-wise addition in place.
  Matrix<T, R, C> &operator+=(const Matrix<T, R, C> &other) {
    for (size_t i = 0; i < R; ++i)
      for (size_t j = 0; j < C; ++j)
        data[i][j] += other.data[i][j];
    return *this;
  }

  /// @brief Element-wise subtraction in place.
  Matrix<T, R, C> &operator-=(const Matrix<T, R, C> &other) {
    for (size_t i = 0; i < R; ++i)
      for (size_t j = 0; j < C; ++j)
        data[i][j] -= other.data[i][j];
    return *this;
  }

  /**
   * @brief In-place product: *this = *this * other.
   *
   * The shape of *this is fixed by its type, so the right-hand operand must be
   * square (M == C); any other M is rejected at compile time.
   */
  template <size_t M> Matrix<T, R, M> &operator*=(const Matrix<T, C, M> &other) {
    static_assert(M == C, "operator*= requires a square (C x C) right-hand operand.");
    Matrix<T, R, C> result;
    for (size_t i = 0; i < R; ++i)
      for (size_t k = 0; k < C; ++k)
        for (size_t j = 0; j < C; ++j)
          result.data[i][j] += data[i][k] * other(k, j);
    return *this = result;
  }

  /// @brief Element-wise sum of two matrices.
  friend Matrix<T, R, C> operator+(const Matrix<T, R, C> &a, const Matrix<T, R, C> &b) {
    Matrix<T, R, C> result;
    for (size_t i = 0; i < R; ++i)
      for (size_t j = 0; j < C; ++j)
        result(i, j) = a(i, j) + b(i, j);
    return result;
  }

  /// @brief Element-wise difference of two matrices.
  friend Matrix<T, R, C> operator-(const Matrix<T, R, C> &a, const Matrix<T, R, C> &b) {
    Matrix<T, R, C> result;
    for (size_t i = 0; i < R; ++i)
      for (size_t j = 0; j < C; ++j)
        result(i, j) = a(i, j) - b(i, j);
    return result;
  }
};

/**
 * @brief Matrix product of an R x K matrix and a K x C matrix.
 *
 * Declared at namespace scope (found through ADL) so that products of arbitrary
 * compatible shapes resolve, not only those where one operand has the result type.
 * @return The R x C product.
 */
template <typename T, size_t R, size_t K, size_t C>
Matrix<T, R, C> operator*(const Matrix<T, R, K> &a, const Matrix<T, K, C> &b) {
  Matrix<T, R, C> result;
  for (size_t i = 0; i < R; ++i)
    for (size_t k = 0; k < K; ++k)
      for (size_t j = 0; j < C; ++j)
        result(i, j) += a(i, k) * b(k, j);
  return result;
}

/**
 * @brief Run-time sized dense matrix backed by a 2-D narray.
 * @tparam T Element type.
 */
template <typename T> struct NMatrix {
  narray<T, 2> data;

  /// @brief Construct a rows x cols matrix (elements initialized by narray).
  NMatrix(size_t rows, size_t cols) : data(rows, cols) {}

  /// @brief Copy a fixed-size Matrix into a run-time sized one.
  template <size_t R, size_t C> NMatrix(const Matrix<T, R, C> &matrix) : data(R, C) {
    for (size_t i = 0; i < R; ++i)
      for (size_t j = 0; j < C; ++j)
        data(i, j) = matrix(i, j);
  }

  /**
   * @brief Access element at specified row and column.
   * @param row Row index.
   * @param col Column index.
   * @return Reference to the element at the specified row and column.
   */
  constexpr const T &operator()(size_t row, size_t col) const { return data(row, col); }

  /**
   * @brief Access element at specified row and column.
   * @param row Row index.
   * @param col Column index.
   * @return Reference to the element at the specified row and column.
   */
  T &operator()(size_t row, size_t col) { return data(row, col); }

  /// @brief Element-wise addition in place.
  /// @throws std::invalid_argument if the shapes differ.
  NMatrix<T> &operator+=(const NMatrix<T> &other) {
    if (data.shape != other.data.shape)
      throw std::invalid_argument("Matrix dimensions do not match for addition.");
    for (size_t i = 0; i < data.shape[0]; ++i)
      for (size_t j = 0; j < data.shape[1]; ++j)
        data(i, j) += other(i, j);
    return *this;
  }

  /// @brief Element-wise subtraction in place.
  /// @throws std::invalid_argument if the shapes differ.
  NMatrix<T> &operator-=(const NMatrix<T> &other) {
    if (data.shape != other.data.shape)
      throw std::invalid_argument("Matrix dimensions do not match for subtraction.");
    for (size_t i = 0; i < data.shape[0]; ++i)
      for (size_t j = 0; j < data.shape[1]; ++j)
        data(i, j) -= other(i, j);
    return *this;
  }

  /**
   * @brief In-place product: *this = *this * other.
   *
   * narray's shape is const, so this matrix cannot change shape in place. The
   * product is therefore only representable here when @p other is square.
   * @throws std::invalid_argument if the inner dimensions differ or @p other is not square.
   */
  NMatrix<T> &operator*=(const NMatrix<T> &other) {
    if (data.shape[1] != other.data.shape[0] || other.data.shape[0] != other.data.shape[1])
      throw std::invalid_argument("Matrix dimensions do not match for multiplication.");
    NMatrix<T> result = *this * other;
    // NMatrix copy-assignment is implicitly deleted (narray::shape is const), so adopt the
    // product's element storage directly; both shapes are identical at this point.
    data.data.swap(result.data.data);
    return *this;
  }

  /// @brief Element-wise sum.
  /// @throws std::invalid_argument if the shapes differ.
  friend NMatrix<T> operator+(const NMatrix<T> &a, const NMatrix<T> &b) {
    if (a.data.shape != b.data.shape)
      throw std::invalid_argument("Matrix dimensions do not match for addition.");
    NMatrix<T> result(a.data.shape[0], a.data.shape[1]);
    for (size_t i = 0; i < a.data.shape[0]; ++i)
      for (size_t j = 0; j < a.data.shape[1]; ++j)
        result(i, j) = a(i, j) + b(i, j);
    return result;
  }

  /// @brief Element-wise difference.
  /// @throws std::invalid_argument if the shapes differ.
  friend NMatrix<T> operator-(const NMatrix<T> &a, const NMatrix<T> &b) {
    if (a.data.shape != b.data.shape)
      throw std::invalid_argument("Matrix dimensions do not match for subtraction.");
    NMatrix<T> result(a.data.shape[0], a.data.shape[1]);
    for (size_t i = 0; i < a.data.shape[0]; ++i)
      for (size_t j = 0; j < a.data.shape[1]; ++j)
        result(i, j) = a(i, j) - b(i, j);
    return result;
  }

  /**
   * @brief Matrix product of an R x K and a K x C matrix.
   * @throws std::invalid_argument if the inner dimensions differ.
   */
  friend NMatrix<T> operator*(const NMatrix<T> &a, const NMatrix<T> &b) {
    if (a.data.shape[1] != b.data.shape[0])
      throw std::invalid_argument("Matrix dimensions do not match for multiplication.");
    size_t R = a.data.shape[0];
    size_t C = b.data.shape[1];
    size_t K = a.data.shape[1];
    NMatrix<T> result(R, C);
    // i-k-j order walks rows of b and result contiguously.
    for (size_t i = 0; i < R; ++i)
      for (size_t k = 0; k < K; ++k) {
        const T &aik = a(i, k);
        for (size_t j = 0; j < C; ++j)
          result(i, j) += aik * b(k, j);
      }
    return result;
  }
};

struct BMatrix {
  size_t rows, cols;
  std::vector<dbitset> data;

  BMatrix(size_t rows, size_t cols)
      : rows(rows), cols(cols), data(rows, dbitset(cols)) {}

  /**
   * @brief Access element at specified row and column.
   * @param row Row index.
   * @param col Column index.
   * @return Reference to the element at the specified row and column.
   */
  dbitset::reference operator()(size_t row, size_t col) { return data[row][col]; }

  /**
   * @brief Access element at specified row and column.
   * @param row Row index.
   * @param col Column index.
   * @return Reference to the element at the specified row and column.
   */
  dbitset::const_reference operator()(size_t row, size_t col) const {
    return data[row][col];
  }

  BMatrix &operator&=(const BMatrix &other) {
    for (size_t i = 0; i < rows; ++i)
      data[i] &= other.data[i];
    return *this;
  }

  BMatrix &operator|=(const BMatrix &other) {
    for (size_t i = 0; i < rows; ++i)
      data[i] |= other.data[i];
    return *this;
  }

  BMatrix &operator^=(const BMatrix &other) {
    for (size_t i = 0; i < rows; ++i)
      data[i] ^= other.data[i];
    return *this;
  }

  BMatrix operator*=(const BMatrix &other) {
    if (cols != other.rows)
      throw std::invalid_argument("Matrix dimensions do not match for multiplication.");
    BMatrix result(rows, other.cols);
    for (size_t i = 0; i < rows; ++i)
      for (size_t k = 0; k < cols; ++k)
        if (data[i][k])
          result.data[i] ^= other.data[k];
    return *this = result;
  }

  friend BMatrix operator&(const BMatrix &a, const BMatrix &b) {
    if (a.rows != b.rows || a.cols != b.cols)
      throw std::invalid_argument("Matrix dimensions do not match for AND operation.");
    BMatrix result(a.rows, a.cols);
    for (size_t i = 0; i < a.rows; ++i)
      for (size_t j = 0; j < a.cols; ++j)
        result(i, j) = static_cast<bool>(a(i, j)) && static_cast<bool>(b(i, j));
    return result;
  }

  friend BMatrix operator|(const BMatrix &a, const BMatrix &b) {
    if (a.rows != b.rows || a.cols != b.cols)
      throw std::invalid_argument("Matrix dimensions do not match for OR operation.");
    BMatrix result(a.rows, a.cols);
    for (size_t i = 0; i < a.rows; ++i)
      for (size_t j = 0; j < a.cols; ++j)
        result(i, j) = static_cast<bool>(a(i, j)) || static_cast<bool>(b(i, j));
    return result;
  }

  friend BMatrix operator^(const BMatrix &a, const BMatrix &b) {
    if (a.rows != b.rows || a.cols != b.cols)
      throw std::invalid_argument("Matrix dimensions do not match for XOR operation.");
    BMatrix result(a.rows, a.cols);
    for (size_t i = 0; i < a.rows; ++i)
      for (size_t j = 0; j < a.cols; ++j)
        result(i, j) = static_cast<bool>(a(i, j)) ^ static_cast<bool>(b(i, j));
    return result;
  }

  friend BMatrix operator~(const BMatrix &a) {
    BMatrix result(a.rows, a.cols);
    for (size_t i = 0; i < a.rows; ++i)
      for (size_t j = 0; j < a.cols; ++j)
        result(i, j) = !static_cast<bool>(a(i, j));
    return result;
  }

  friend BMatrix operator*(const BMatrix &a, const BMatrix &b) {
    if (a.cols != b.rows)
      throw std::invalid_argument("Matrix dimensions do not match for multiplication.");
    BMatrix result(a.rows, b.cols);
    for (size_t i = 0; i < a.rows; ++i)
      for (size_t k = 0; k < a.cols; ++k)
        if (a(i, k))
          result.data[i] ^= b.data[k];
    return result;
  }
};
} // namespace weilycoder
Back to top page