/* * MatrixUtil.java * * Created on March 7, 2005, 10:38 AM */ package org.das2.math.matrix; import java.io.PrintStream; import java.text.DecimalFormat; import java.util.Arrays; /** * * @author eew */ public final class MatrixUtil { private static final DecimalFormat format = new DecimalFormat(" 0.0##;-"); /** Creates a new instance of MatrixUtil */ private MatrixUtil() { } public static void print(Matrix m, PrintStream out) { int nRow = m.rowCount(); int nCol = m.columnCount(); for (int iRow = 0; iRow < nRow; iRow++) { out.print("["); for (int iCol = 0; iCol < nCol; iCol++) { out.print(format.format(m.get(iRow, iCol))); out.print('\t'); } out.println("]"); } } public static Matrix inverse(Matrix m) { Matrix orig; Matrix inv; Matrix both; int nRow; if (m.columnCount() != m.rowCount()) { throw new IllegalArgumentException("m must be a square matrix"); } nRow = m.rowCount(); orig = new ArrayMatrix(m); inv = identity(nRow); both = new CompositeMatrix(orig, inv); for (int iRow = 0; iRow < nRow; iRow++) { if (both.get(iRow, iRow) == 0.0) { pivot(both, iRow); } both.rowTimes(iRow, 1.0 / both.get(iRow, iRow)); for (int i = 0; i < nRow; i++) { if (i != iRow) { double scale = -both.get(i, iRow); both.rowTimesAddTo(iRow, scale, i); } } } return inv; } public static void pivot(Matrix m, final int row) { int nRow = m.rowCount(); for (int iRow = row + 1; iRow < nRow; iRow++) { if (m.get(iRow, row) != 0.0) { m.swapRows(row, iRow); return; } print(m, System.err); throw new IllegalArgumentException("Can't pivot"); } } public static Matrix identity(final int rows) { Matrix m = new ArrayMatrix(new double[rows * rows], rows, rows); for (int iRow = 0; iRow < rows; iRow++) { m.set(iRow, iRow, 1.0); } return m; } public static Matrix multiply(Matrix m1, Matrix m2) { Matrix res = new ArrayMatrix(m1.rowCount(), m2.columnCount()); multiply(m1, m2, res); return res; } public static void multiply(Matrix m1, Matrix m2, Matrix res) { int nRow, nCol, nInner; if (m1.columnCount() != m2.rowCount()) { throw new IllegalArgumentException(""); } nRow = m1.rowCount(); nCol = m2.columnCount(); nInner = m1.columnCount(); if (nRow != res.rowCount() || nCol != res.columnCount()); for (int iRow = 0; iRow < nRow; iRow++) { for (int iCol = 0; iCol < nCol; iCol++) { double d = 0; for (int iInner = 0; iInner < nInner; iInner++) { d += m1.get(iRow, iInner) * m2.get(iInner, iCol); } res.set(iRow, iCol, d); } } } }