Skip to content

Commit

Permalink
Better class of OcMatrix
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicolas Cornu committed Oct 29, 2024
1 parent cd7ac7d commit 29322f2
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 43 deletions.
61 changes: 43 additions & 18 deletions src/ivoc/ocmatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,35 @@ OcMatrix* OcMatrix::instance(int nrow, int ncol, int type) {
}
}

void OcMatrix::unimp() {
hoc_execerror("Matrix method not implemented for this type matrix", 0);
void OcMatrix::unimp() const {
hoc_execerror("Matrix method not implemented for this type matrix", nullptr);
}

void OcMatrix::nonzeros(std::vector<int>& m, std::vector<int>& n) {
void OcMatrix::nonzeros(std::vector<int>& m, std::vector<int>& n) const {
m.clear();
n.clear();
for (int i = 0; i < nrow(); i++) {
for (int j = 0; j < ncol(); j++) {
if (getval(i, j) != 0) {
if (getval(i, j) != 0.) {
m.push_back(i);
n.push_back(j);
}
}
}
}

std::vector<std::pair<int, int>> OcMatrix::nonzeros() const {
std::vector<std::pair<int, int>> nzs;
for (int i = 0; i < nrow(); i++) {
for (int j = 0; j < ncol(); j++) {
if (getval(i, j) != 0.) {
nzs.emplace_back(i, j);
}
}
}
return nzs;
}

OcFullMatrix* OcMatrix::full() {
if (type_ != MFULL) { // could clone one maybe
hoc_execerror("Matrix is not a FULL matrix (type 1)", 0);
Expand All @@ -68,13 +80,13 @@ OcFullMatrix::OcFullMatrix(int nrow, int ncol)
double* OcFullMatrix::mep(int i, int j) {
return &m_(i, j);
}
double OcFullMatrix::getval(int i, int j) {
double OcFullMatrix::getval(int i, int j) const {
return m_(i, j);
}
int OcFullMatrix::nrow() {
int OcFullMatrix::nrow() const {
return m_.rows();
}
int OcFullMatrix::ncol() {
int OcFullMatrix::ncol() const {
return m_.cols();
}

Expand Down Expand Up @@ -138,17 +150,17 @@ void OcFullMatrix::svd1(Matrix* u, Matrix* v, Vect* d) {
}
}

void OcFullMatrix::getrow(int k, Vect* out) {
void OcFullMatrix::getrow(int k, Vect* out) const {
auto v1 = Vect2VEC(out);
v1 = m_.row(k);
}

void OcFullMatrix::getcol(int k, Vect* out) {
void OcFullMatrix::getcol(int k, Vect* out) const {
auto v1 = Vect2VEC(out);
v1 = m_.col(k);
}

void OcFullMatrix::getdiag(int k, Vect* out) {
void OcFullMatrix::getdiag(int k, Vect* out) const {
auto vout = m_.diagonal(k);
if (k >= 0) {
for (int i = 0, j = k; i < nrow() && j < ncol(); ++i, ++j) {
Expand Down Expand Up @@ -226,7 +238,7 @@ void OcFullMatrix::solv(Vect* in, Vect* out, bool use_lu) {
v2 = lu_->solve(v1);
}

double OcFullMatrix::det(int* e) {
double OcFullMatrix::det(int* e) const {
*e = 0;
double m = m_.determinant();
if (m) {
Expand Down Expand Up @@ -260,15 +272,15 @@ void OcSparseMatrix::zero() {
}
}

double OcSparseMatrix::getval(int i, int j) {
double OcSparseMatrix::getval(int i, int j) const {
return m_.coeff(i, j);
}

int OcSparseMatrix::nrow() {
int OcSparseMatrix::nrow() const {
return m_.rows();
}

int OcSparseMatrix::ncol() {
int OcSparseMatrix::ncol() const {
return m_.cols();
}

Expand Down Expand Up @@ -330,7 +342,7 @@ void OcSparseMatrix::setcol(int k, double in) {
}
}

void OcSparseMatrix::ident(void) {
void OcSparseMatrix::ident() {
m_.setIdentity();
}

Expand All @@ -348,15 +360,15 @@ void OcSparseMatrix::setdiag(int k, double in) {
}
}

int OcSparseMatrix::sprowlen(int i) {
int OcSparseMatrix::sprowlen(int i) const {
int acc = 0;
for (decltype(m_)::InnerIterator it(m_, i); it; ++it) {
acc += 1;
}
return acc;
}

double OcSparseMatrix::spgetrowval(int i, int jindx, int* j) {
double OcSparseMatrix::spgetrowval(int i, int jindx, int* j) const {
int acc = 0;
for (decltype(m_)::InnerIterator it(m_, i); it; ++it) {
if (acc == jindx) {
Expand All @@ -368,13 +380,26 @@ double OcSparseMatrix::spgetrowval(int i, int jindx, int* j) {
return 0;
}

void OcSparseMatrix::nonzeros(std::vector<int>& m, std::vector<int>& n) {
void OcSparseMatrix::nonzeros(std::vector<int>& m, std::vector<int>& n) const {
m.clear();
n.clear();
m.reserve(m_.nonZeros());
n.reserve(m_.nonZeros());
for (int k = 0; k < m_.outerSize(); ++k) {
for (decltype(m_)::InnerIterator it(m_, k); it; ++it) {
m.push_back(it.row());
n.push_back(it.col());
}
}
}

std::vector<std::pair<int, int>> OcSparseMatrix::nonzeros() const {
std::vector<std::pair<int, int>> nzs;
nzs.reserve(m_.nonZeros());
for (int k = 0; k < m_.outerSize(); ++k) {
for (decltype(m_)::InnerIterator it(m_, k); it; ++it) {
nzs.emplace_back(it.row(), it.col());
}
}
return nzs;
}
55 changes: 30 additions & 25 deletions src/ivoc/ocmatrix.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include <memory>
#include <utility>
#include <vector>

#include <Eigen/Eigen>
Expand Down Expand Up @@ -28,23 +29,25 @@ class OcMatrix {
return *mep(i, j);
};

virtual double getval(int i, int j) {
virtual double getval(int i, int j) const {
unimp();
return 0.;
}
virtual int nrow() {
virtual int nrow() const {
unimp();
return 0;
}
virtual int ncol() {
virtual int ncol() const {
unimp();
return 0;
}
virtual void resize(int, int) {
unimp();
}

virtual void nonzeros(std::vector<int>& m, std::vector<int>& n);
virtual void nonzeros(std::vector<int>& m, std::vector<int>& n) const;

virtual std::vector<std::pair<int, int>> nonzeros() const;

OcFullMatrix* full();

Expand All @@ -63,13 +66,13 @@ class OcMatrix {
virtual void add(Matrix*, Matrix* out) {
unimp();
}
virtual void getrow(int, Vect* out) {
virtual void getrow(int, Vect* out) const {
unimp();
}
virtual void getcol(int, Vect* out) {
virtual void getcol(int, Vect* out) const {
unimp();
}
virtual void getdiag(int, Vect* out) {
virtual void getdiag(int, Vect* out) const {
unimp();
}
virtual void setrow(int, Vect* in) {
Expand Down Expand Up @@ -123,20 +126,20 @@ class OcMatrix {
virtual void svd1(Matrix* u, Matrix* v, Vect* d) {
unimp();
}
virtual double det(int* e) {
virtual double det(int* e) const {
unimp();
return 0.0;
}
virtual int sprowlen(int) {
virtual int sprowlen(int) const {
unimp();
return 0;
}
virtual double spgetrowval(int i, int jindx, int* j) {
virtual double spgetrowval(int i, int jindx, int* j) const {
unimp();
return 0.;
}

void unimp();
void unimp() const;

protected:
OcMatrix(int type);
Expand All @@ -156,18 +159,18 @@ class OcFullMatrix final: public OcMatrix { // type 1
~OcFullMatrix() override = default;

double* mep(int, int) override;
double getval(int i, int j) override;
int nrow() override;
int ncol() override;
double getval(int i, int j) const override;
int nrow() const override;
int ncol() const override;
void resize(int, int) override;

void mulv(Vect* in, Vect* out) override;
void mulm(Matrix* in, Matrix* out) override;
void muls(double, Matrix* out) override;
void add(Matrix*, Matrix* out) override;
void getrow(int, Vect* out) override;
void getcol(int, Vect* out) override;
void getdiag(int, Vect* out) override;
void getrow(int, Vect* out) const override;
void getcol(int, Vect* out) const override;
void getdiag(int, Vect* out) const override;
void setrow(int, Vect* in) override;
void setcol(int, Vect* in) override;
void setdiag(int, Vect* in) override;
Expand All @@ -185,7 +188,7 @@ class OcFullMatrix final: public OcMatrix { // type 1
void transpose(Matrix* out) override;
void symmeigen(Matrix* mout, Vect* vout) override;
void svd1(Matrix* u, Matrix* v, Vect* d) override;
double det(int* exponent) override;
double det(int* exponent) const override;

private:
Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> m_{};
Expand All @@ -198,10 +201,10 @@ class OcSparseMatrix final: public OcMatrix { // type 2
~OcSparseMatrix() override = default;

double* mep(int, int) override;
int nrow() override;
int ncol() override;
double getval(int, int) override;
void ident(void) override;
int nrow() const override;
int ncol() const override;
double getval(int, int) const override;
void ident() override;
void mulv(Vect* in, Vect* out) override;
void solv(Vect* vin, Vect* vout, bool use_lu) override;

Expand All @@ -212,10 +215,12 @@ class OcSparseMatrix final: public OcMatrix { // type 2
void setcol(int, double in) override;
void setdiag(int, double in) override;

void nonzeros(std::vector<int>& m, std::vector<int>& n) override;
void nonzeros(std::vector<int>& m, std::vector<int>& n) const override;

std::vector<std::pair<int, int>> nonzeros() const override;

int sprowlen(int) override; // how many elements in row
double spgetrowval(int i, int jindx, int* j) override;
int sprowlen(int) const override; // how many elements in row
double spgetrowval(int i, int jindx, int* j) const override;

void zero() override;

Expand Down

0 comments on commit 29322f2

Please sign in to comment.