module Data.Matrix.SmithNormalForm.Internal
    ( smithNF
    , rectifyDiagonal
    , diagonalize
    , isDiagonalMatrix
    , divides
    ) where

import qualified Data.Matrix as M
import qualified Data.Vector as V
import Data.Maybe (fromJust)

-- | Main method that returns the Smith normal form of a given matrix.
--
-- The steps are:
--
-- 1. 'diagonalize' the input.
-- 2. Fix up the diagonal with 'rectifyDiagonal' so that each entry divides
--    the next.
-- 3. Take absolute values and rebuild a square diagonal matrix from those
--    entries.
-- 4. Pad that block with zero rows (when the input has more rows than
--    columns) or zero columns (otherwise), sized from the input's
--    dimensions.
smithNF :: Integral a => M.Matrix a -> M.Matrix a
smithNF m = padWithZeros (M.diagonalList (length entries) 0 entries)
  where
    nr = M.nrows m
    nc = M.ncols m
    entries = map abs (rectifyDiagonal (diagonalize m))
    padWithZeros d
      | nr > nc = d M.<-> M.zero (nr - nc) nc
      | otherwise = d M.<|> M.zero nr (nc - nr)

-- | Given a diagonal matrix, outputs a list \([d_1,..,d_n]\) that satisfies
--  \(d_k \mid d_{k+1}\) and represents the diagonal entries.
-- Assumes input is a diagonal matrix (not checked).
--
-- When some adjacent pair violates divisibility, only the first such pair is
-- repaired:
--
-- * The later entry is written directly below the earlier one. This is the
--   effect of adding the later entry's column to the earlier entry's column.
-- * The result is re-diagonalized with 'diagonalize'.
-- * The check is repeated on the new diagonal.
rectifyDiagonal :: Integral a => M.Matrix a -> [a]
rectifyDiagonal diagonalMatrix =
  case offendingPairs of
    [] -> entries
    (k : _) -> rectifyDiagonal (diagonalize (repairPair k))
  where
    entries = V.toList (M.getDiag diagonalMatrix)
    -- 0-based positions k where entries !! k does not divide entries !! (k+1).
    offendingPairs =
      [k | (k, x, y) <- zip3 [0 ..] entries (drop 1 entries), not (x `divides` y)]
    -- Square diagonal matrix of 'entries' that also has entry k+1 written at
    -- the 1-based position (k+2, k+1), i.e. just below entry k.
    repairPair k =
      M.setElem (entries !! (k + 1)) (k + 2, k + 1)
        (M.diagonalList (length entries) 0 entries)


-- | Given a matrix, returns a diagonal matrix obtained by applying
-- elementary row and column operations. The result does not necessarily
-- satisfy the divisibility property; see 'rectifyDiagonal' for that.
diagonalize :: Integral a => M.Matrix a -> M.Matrix a
diagonalize = diagonalizer firstRow
  where
    -- Matrix indices are 1-based, so the sweep starts at row 1.
    firstRow = 1

-- | Worker for 'diagonalize'. It sweeps the matrix row by row, starting at
-- @rowIndex@ (1-based), applying row and column operations. It stops when
-- the row index passes the last row, the matrix is diagonal, or no pivot
-- column remains.
--
-- Before pivoting, zero columns are bubbled to the right: a zero column is
-- swapped with a non-zero column immediately to its right.
diagonalizer :: Integral a => Int -> M.Matrix a -> M.Matrix a
diagonalizer rowIndex m
  | rowIndex > M.nrows m = m
  | isDiagonalMatrix m = m
  | pivotPosition == (-1, -1) = m -- means there's no more cols
  | hasNonzeroAmongZeros areZeroCols =
      -- Swap a zero column with the NON-ZERO column to its right. Picking
      -- merely the first zero column is not enough: if its neighbour is also
      -- zero the swap is a no-op and the recursion never terminates
      -- (e.g. columns [nz, 0, 0, nz]).
      diagonalizer rowIndex (M.switchCols zeroColIndex (zeroColIndex + 1) m)
  | isZero restOfRow && isZero restOfCol = diagonalizer (rowIndex + 1) m
  | otherwise =
      diagonalizer rowIndex
        (clearPivotRow pivotPosition (improvePivot pivotPosition (i + 1) mat))
  where
    -- True marks an all-zero column (1-based order).
    areZeroCols = map isZero (cols m)
    -- 1-based index of the first zero column whose right neighbour is
    -- non-zero. Only forced when 'hasNonzeroAmongZeros' holds, which
    -- guarantees the list is non-empty.
    zeroColIndex =
      head [k | (k, True, False) <- zip3 [1 ..] areZeroCols (drop 1 areZeroCols)]
    (pivotPosition, mat) = choosePivot rowIndex m
    (i, j) = pivotPosition
    -- Entries of the pivot row strictly to the right of the pivot.
    restOfRow = M.rowVector (V.drop j (M.getRow i m))
    -- Entries of the pivot column strictly below the pivot.
    restOfCol = M.colVector (V.drop i (M.getCol j m))

-- | Given per-column flags in which True marks an all-zero column, returns
-- whether some zero column is immediately followed by a non-zero one.
-- In other words, it returns False exactly when all non-zero columns come
-- before all zero columns.
hasNonzeroAmongZeros :: [Bool] -> Bool
hasNonzeroAmongZeros flags = or (zipWith zeroThenNonzero flags (drop 1 flags))
  where
    zeroThenNonzero here next = here && not next

-- | Returns whether a matrix (not necessarily square) is diagonal, i.e. every
-- entry at a 1-based position @(i, j)@ with @i /= j@ equals zero.
isDiagonalMatrix :: (Num a, Eq a) => M.Matrix a -> Bool
isDiagonalMatrix m = all isZeroAt offDiagonal
  where
    offDiagonal = [(i, j) | i <- [1 .. M.nrows m], j <- [1 .. M.ncols m], i /= j]
    isZeroAt (i, j) = M.getElem i j m == 0

-- | Reduces the entries to the right of the pivot at @(t, jt)@ in its row.
--
-- It runs 'improvePivot' on the transpose, so the row operations performed
-- there act as column operations on @m@. The scan starts at transposed row
-- @jt + 1@, i.e. the column just to the right of the pivot.
clearPivotRow :: Integral a => (Int, Int) -> M.Matrix a -> M.Matrix a
clearPivotRow (t, jt) m = M.transpose reducedT
  where
    reducedT = improvePivot (jt, t) (jt + 1) (M.transpose m)

-- | Reduces the entries of pivot column @jt@ below the pivot at @(t, jt)@.
-- It starts at row @rowIndex@ and moves down to the last row.
--
-- For each row, write @nextEntry = pivot * q + r@ and subtract @q@ times the
-- pivot row from the current row.
--
-- * If the pivot divides @nextEntry@ (@r == 0@), the entry becomes zero and
--   the next row is processed.
-- * Otherwise the reduced row, which now holds @r@, is swapped into row @t@
--   and the same row index is retried. This is one Euclidean-algorithm step.
--
-- INVARIANT: all operations do not change the absolute value of the
-- determinant, i.e. they are among
--     (I) scale a row by a unit (which are just 1 and -1 over the integers)
--    (II) switch two rows
--   (III) add a multiple of one row to another
-- This does most of the real work.
--
-- Calls 'error' with "Zero pivot entry." if the pivot entry is zero.
improvePivot :: Integral a => (Int, Int) -> Int -> M.Matrix a -> M.Matrix a
improvePivot (t, jt) rowIndex m
  | pivot == 0 = error "Zero pivot entry."
  | rowIndex > M.nrows m = m
  | pivot `divides` nextEntry = improvePivot (t, jt) (rowIndex + 1) reduced
  | otherwise = improvePivot (t, jt) rowIndex (M.switchRows t rowIndex reduced)
  where
    pivot = M.getElem t jt m
    nextEntry = M.getElem rowIndex jt m
    q = nextEntry `div` pivot
    -- Row rowIndex minus q times the pivot row t (type III operation).
    reduced = M.mapRow (\j elt -> elt - q * M.getElem t j m) rowIndex m

-- | Returns whether a `divides` b := \(a \mid b\).
-- It handles the special case \(a=0\): zero divides only zero.
divides :: Integral a => a -> a -> Bool
divides a b
  | a == 0 = b == 0
  | otherwise = b `mod` a == 0

---------------------
-- PIVOT-SELECTION --
---------------------
-- | Selects the pivot for row @rowIndex@.
--
-- Returns @((-1, -1), m)@ when 'nextPivotPosition' finds no pivot column.
-- Otherwise it returns the pivot position together with a matrix in which
-- that entry is non-zero. Rows are swapped via 'makePivotNonzero' only when
-- the entry at the pivot position is zero.
choosePivot :: Integral a => Int -> M.Matrix a -> ((Int, Int), M.Matrix a)
choosePivot rowIndex m =
  case nextPivotPosition rowIndex m of
    Nothing -> ((-1, -1), m)
    Just pos@(t, jt)
      | M.getElem t jt m /= 0 -> (pos, m)
      | otherwise -> (pos, makePivotNonzero pos m)

-- | Makes the pivot at @(t, jt)@ non-zero. It swaps row @t@ with the first
-- row below it whose entry in column @jt@ is non-zero.
--
-- ASSUMES: column @jt@ has a non-zero entry in some row @t' > t@. If not,
-- 'head' fails on an empty list.
makePivotNonzero :: Integral a => (Int, Int) -> M.Matrix a -> M.Matrix a
makePivotNonzero (t, jt) m = M.switchRows t nonzeroIndex m
  where
    -- (1-based row index, entry) pairs for column jt.
    columnEntries = zip [1 ..] (V.toList (M.getCol jt m))
    -- Dropping the first t pairs keeps only rows strictly below t.
    nonzeroIndex = head [i | (i, entry) <- drop t columnEntries, entry /= 0]

-- | Finds the pivot position for row @rowIndex@. This is the row itself
-- paired with the first column at index @>= rowIndex@ that is not entirely
-- zero, or 'Nothing' if no such column exists.
--
-- ASSUMES: the rows with index < rowIndex only have a single nonzero entry,
--   which occurs in columnIndex < rowIndex
nextPivotPosition :: Integral a => Int -> M.Matrix a -> Maybe (Int, Int)
nextPivotPosition rowIndex m =
  case filter hasNonzero [max 1 rowIndex .. M.ncols m] of
    [] -> Nothing
    (j : _) -> Just (rowIndex, j)
  where
    hasNonzero j = not (isZero (col j m))

-- | Returns whether every entry of the matrix equals zero.
-- Vacuously True for a matrix with no entries.
isZero :: (Num a, Eq a) => M.Matrix a -> Bool
isZero = all (== 0) . M.toList

---------------------------
-- MATRIX HELPER METHODS --
---------------------------
-- | Returns the rows of a matrix, top to bottom, each as a 1 x n matrix.
rows :: M.Matrix a -> [M.Matrix a]
rows m = [row k m | k <- [1 .. M.nrows m]]

-- | Returns the columns of a matrix, left to right, each as an n x 1 matrix.
cols :: M.Matrix a -> [M.Matrix a]
cols m = [col k m | k <- [1 .. M.ncols m]]

-- | Returns row @k@ (1-based) as a 1 x n matrix.
row :: Int -> M.Matrix a -> M.Matrix a
row k m = M.rowVector (M.getRow k m)

-- | Returns column @k@ (1-based) as an n x 1 matrix.
col :: Int -> M.Matrix a -> M.Matrix a
col k m = M.colVector (M.getCol k m)