PIMathMatrix add trace, assert for non square matrix, assert in operator *

This commit is contained in:
2020-10-23 15:10:57 +03:00
parent fa85f414db
commit 0189f28f43
3 changed files with 34 additions and 62 deletions

View File

@@ -393,6 +393,20 @@ public:
return ret; return ret;
} }
/**
* @brief Trace of the matrix is calculated. Works only with square matrix, nonzero matrices and invertible matrix
*
* @return matrix trace
*/
Type trace() const {
static_assert(Rows == Cols, "Works only with square matrix");
Type ret = Type(0);
for (uint i = 0; i < Cols; ++i) {
ret += m[i][i];
}
return ret;
}
/** /**
* @brief Transforming matrix to upper triangular. Works only with square matrix, nonzero matrices and invertible matrix * @brief Transforming matrix to upper triangular. Works only with square matrix, nonzero matrices and invertible matrix
* *
@@ -400,10 +414,7 @@ public:
* @return copy of transformed upper triangular matrix * @return copy of transformed upper triangular matrix
*/ */
_CMatrix &toUpperTriangular(bool *ok = 0) { _CMatrix &toUpperTriangular(bool *ok = 0) {
if (Cols != Rows) { static_assert(Rows == Cols, "Works only with square matrix");
if (ok != 0) *ok = false;
return *this;
}
_CMatrix smat(*this); _CMatrix smat(*this);
bool ndet; bool ndet;
uint crow; uint crow;
@@ -443,7 +454,7 @@ public:
* @return copy of inverted matrix * @return copy of inverted matrix
*/ */
_CMatrix &invert(bool *ok = 0) { _CMatrix &invert(bool *ok = 0) {
static_assert(Cols == Rows, "Only square matrix invertable"); static_assert(Rows == Cols, "Works only with square matrix");
_CMatrix mtmp = _CMatrix::identity(), smat(*this); _CMatrix mtmp = _CMatrix::identity(), smat(*this);
bool ndet; bool ndet;
uint crow; uint crow;
@@ -972,8 +983,8 @@ public:
Type determinant(bool *ok = 0) const { Type determinant(bool *ok = 0) const {
_CMatrix m(*this); _CMatrix m(*this);
bool k; bool k;
Type ret = Type(0);
m.toUpperTriangular(&k); m.toUpperTriangular(&k);
Type ret = Type(0);
if (ok) *ok = k; if (ok) *ok = k;
if (!k) return ret; if (!k) return ret;
ret = Type(1); ret = Type(1);
@@ -987,19 +998,14 @@ public:
/** /**
* @brief Trace of the matrix is calculated. Works only with square matrix, nonzero matrices and invertible matrix * @brief Trace of the matrix is calculated. Works only with square matrix, nonzero matrices and invertible matrix
* *
* @param ok is a parameter with which we can find out if the method worked correctly
* @return matrix trace * @return matrix trace
*/ */
Type trace(bool *ok = 0) const { Type trace() const {
assert(isSquare());
Type ret = Type(0); Type ret = Type(0);
if (!isSquare()) {
if (ok != 0) *ok = false;
return ret;
}
for (uint i = 0; i < _V2D::cols_; ++i) { for (uint i = 0; i < _V2D::cols_; ++i) {
ret += _V2D::element(i, i); ret += _V2D::element(i, i);
} }
if (ok != 0) *ok = true;
return ret; return ret;
} }
@@ -1010,10 +1016,7 @@ public:
* @return copy of transformed upper triangular matrix * @return copy of transformed upper triangular matrix
*/ */
_CMatrix &toUpperTriangular(bool *ok = 0) { _CMatrix &toUpperTriangular(bool *ok = 0) {
if (!isSquare()) { assert(isSquare());
if (ok != 0) *ok = false;
return *this;
}
_CMatrix smat(*this); _CMatrix smat(*this);
bool ndet; bool ndet;
uint crow; uint crow;
@@ -1054,10 +1057,7 @@ public:
* @return copy of inverted matrix * @return copy of inverted matrix
*/ */
_CMatrix &invert(bool *ok = 0, PIMathVector<Type> *sv = 0) { _CMatrix &invert(bool *ok = 0, PIMathVector<Type> *sv = 0) {
if (!isSquare()) { assert(isSquare());
if (ok != 0) *ok = false;
return *this;
}
_CMatrix mtmp = _CMatrix::identity(_V2D::cols_, _V2D::rows_), smat(*this); _CMatrix mtmp = _CMatrix::identity(_V2D::cols_, _V2D::rows_), smat(*this);
bool ndet; bool ndet;
uint crow; uint crow;
@@ -1196,14 +1196,13 @@ inline PIByteArray &operator>>(PIByteArray &s, PIMathMatrix<Type> &v) {
template<typename Type> template<typename Type>
inline PIMathMatrix<Type> operator*(const PIMathMatrix<Type> &fm, inline PIMathMatrix<Type> operator*(const PIMathMatrix<Type> &fm,
const PIMathMatrix<Type> &sm) { const PIMathMatrix<Type> &sm) {
uint cr = fm.cols(), rows0 = fm.rows(), cols1 = sm.cols(); assert(fm.cols() == sm.rows());
PIMathMatrix<Type> tm(cols1, rows0); PIMathMatrix<Type> tm(sm.cols(), fm.rows());
if (fm.cols() != sm.rows()) return tm;
Type t; Type t;
for (uint j = 0; j < rows0; ++j) { for (uint j = 0; j < fm.rows(); ++j) {
for (uint i = 0; i < cols1; ++i) { for (uint i = 0; i < sm.cols(); ++i) {
t = Type(0); t = Type(0);
for (uint k = 0; k < cr; ++k) for (uint k = 0; k < fm.cols(); ++k)
t += fm.element(j, k) * sm.element(k, i); t += fm.element(j, k) * sm.element(k, i);
tm.element(j, i) = t; tm.element(j, i) = t;
} }
@@ -1221,13 +1220,12 @@ inline PIMathMatrix<Type> operator*(const PIMathMatrix<Type> &fm,
template<typename Type> template<typename Type>
inline PIMathVector<Type> operator*(const PIMathMatrix<Type> &fm, inline PIMathVector<Type> operator*(const PIMathMatrix<Type> &fm,
const PIMathVector<Type> &sv) { const PIMathVector<Type> &sv) {
uint c = fm.cols(), r = fm.rows(); assert(fm.cols() == sv.size());
PIMathVector<Type> tv(r); PIMathVector<Type> tv(fm.rows());
if (c != sv.size()) return tv;
Type t; Type t;
for (uint j = 0; j < r; ++j) { for (uint j = 0; j < fm.rows(); ++j) {
t = Type(0); t = Type(0);
for (uint i = 0; i < c; ++i) for (uint i = 0; i < fm.cols(); ++i)
t += fm.element(j, i) * sv[i]; t += fm.element(j, i) * sv[i];
tv[j] = t; tv[j] = t;
} }
@@ -1244,12 +1242,12 @@ inline PIMathVector<Type> operator*(const PIMathMatrix<Type> &fm,
template<typename Type> template<typename Type>
inline PIMathVector<Type> operator*(const PIMathVector<Type> &sv, inline PIMathVector<Type> operator*(const PIMathVector<Type> &sv,
const PIMathMatrix<Type> &fm) { const PIMathMatrix<Type> &fm) {
uint c = fm.cols(), r = fm.rows(); PIMathVector<Type> tv(fm.cols());
PIMathVector<Type> tv(c); assert(fm.rows() == sv.size());
Type t; Type t;
for (uint j = 0; j < c; ++j) { for (uint j = 0; j < fm.cols(); ++j) {
t = Type(0); t = Type(0);
for (uint i = 0; i < r; ++i) for (uint i = 0; i < fm.rows(); ++i)
t += fm.element(i, j) * sv[i]; t += fm.element(i, j) * sv[i];
tv[j] = t; tv[j] = t;
} }

View File

@@ -355,12 +355,6 @@ TEST(PIMathMatrix_Test, determinantIfSquare) {
ASSERT_DOUBLE_EQ(d, i); ASSERT_DOUBLE_EQ(d, i);
} }
TEST(PIMathMatrix_Test, determinantIfNotSquare) {
PIMathMatrix<double> matrix(3, 5, 1.0);
matrix.element(1,1) = 5.0;
ASSERT_FALSE(matrix.determinant());
}
TEST(PIMathMatrix_Test, trace) { TEST(PIMathMatrix_Test, trace) {
PIMathMatrix<double> matrix(3, 3, 0.0); PIMathMatrix<double> matrix(3, 3, 0.0);
double t; double t;
@@ -383,12 +377,6 @@ TEST(PIMathMatrix_Test, trace) {
ASSERT_DOUBLE_EQ(t, i); ASSERT_DOUBLE_EQ(t, i);
} }
TEST(PIMathMatrix_Test, traceIfNotSquare) {
PIMathMatrix<double> matrix(3, 5, 1.0);
matrix.element(1,1) = 5.0;
ASSERT_FALSE(matrix.trace());
}
TEST(PIMathMatrix_Test, toUpperTriangular) { TEST(PIMathMatrix_Test, toUpperTriangular) {
PIMathMatrix<double> matrix(3, 3, 0.0); PIMathMatrix<double> matrix(3, 3, 0.0);
double d1, d2 = 1; double d1, d2 = 1;

View File

@@ -356,20 +356,6 @@ TEST(PIMathMatrixT_Test, determinantIfSquare) {
ASSERT_DOUBLE_EQ(i, d); ASSERT_DOUBLE_EQ(i, d);
} }
TEST(PIMathMatrixT_Test, determinantIfNotSquare) {
PIMathMatrixT<rows, 5u, double> matr;
matr.element(0,0) = 3;
matr.element(0,1) = 6;
matr.element(0,2) = 8;
matr.element(1,0) = 2;
matr.element(1,1) = 1;
matr.element(1,2) = 4;
matr.element(2,0) = 6;
matr.element(2,1) = 2;
matr.element(2,2) = 5;
ASSERT_FALSE(matr.determinant());
}
TEST(PIMathMatrixT_Test, invert) { TEST(PIMathMatrixT_Test, invert) {
PIMathMatrixT<rows, cols, double> matrix1; PIMathMatrixT<rows, cols, double> matrix1;
PIMathMatrixT<rows, cols, double> matrix2; PIMathMatrixT<rows, cols, double> matrix2;