#include #include #include "matrix.h" using namespace std; template valarray& matmul_v(valarray& a, const valarray& b, size_t n) { if (n == 1) return a = a * b; gslice s11(0, {n / 2, n / 2}, {n, 1}), s12(n / 2, {n / 2, n / 2}, {n, 1}), s21(n * n / 2, {n / 2, n / 2}, {n, 1}), s22(n * n / 2 + n / 2, {n / 2, n / 2}, {n, 1}); valarray h = b[s11]; h += b[s22]; valarray m1 = a[s11]; m1 += a[s22]; m1 = matmul_v(m1, h, n / 2); h = b[s11]; valarray m2 = a[s21]; m2 += a[s22]; m2 = matmul_v(m2, h, n / 2); h = b[s12]; h -= b[s22]; valarray m3 = a[s11]; m3 = matmul_v(m3, h, n / 2); h = b[s21]; h -= b[s11]; valarray m4 = a[s22]; m4 = matmul_v(m4, h, n / 2); h = b[s22]; valarray m5 = a[s11]; m5 += a[s12]; m5 = matmul_v(m5, h, n / 2); h = b[s11]; h += b[s12]; valarray m6 = a[s21]; m6 -= a[s11]; m6 = matmul_v(m6, h, n / 2); h = b[s21]; h += b[s22]; valarray m7 = a[s12]; m7 -= a[s22]; m7 = matmul_v(m7, h, n / 2); a[s11] = m1 + m4 - m5 + m7; a[s12] = m3 + m5; a[s21] = m2 + m4; a[s22] = m1 - m2 + m3 + m6; return a; } bool power_of_two(size_t i) { return i && !(i & (i - 1)); } template matrix matmul(matrix a, matrix b) { if (a.rows() != a.columns() || a.rows() != b.rows() || a.rows() != b.columns()) throw invalid_argument("matrices not square or not of same dimension"); if (!power_of_two(a.columns())) throw invalid_argument("matrix dimension not power of two"); valarray av = a; av = matmul_v(av, b, a.rows()); return matrix{valarray(av), a.rows(), a.columns()}; } template matrix matmul(matrix, matrix);