-- Copyright (c) David Amos, 2008. All rights reserved.

{-# OPTIONS_GHC -fglasgow-exts #-}

module Math.Algebra.LinearAlgebra where

import qualified Data.List as L
import Math.Algebra.Field.Base -- not actually used in this module


-- Fixities for the operators defined in this module.
-- Scalar multiples (scalar with vector, scalar with matrix) bind tightest.
infixr 8 *>, *>>
-- Matrix acting on a vector written to its right.
infixr 7 <<*>
-- Dot product, tensor product, matrix product, and a vector acting on a matrix written to its right.
infixl 7 <.>, <*>, <<*>>, <*>>
-- Vector and matrix addition / subtraction bind loosest.
infixl 6 <+>, <->, <<+>>, <<->>

-- The mnemonic for these operations is that the number of angle brackets on each side indicates the dimension of the argument on that side

-- vector addition, entry by entry
-- the result is only as long as the shorter of the two arguments
(<+>) :: Num a => [a] -> [a] -> [a]
u <+> v = [x + y | (x, y) <- zip u v]

-- vector subtraction, entry by entry (entries of v are taken away from entries of u)
-- the result is only as long as the shorter of the two arguments
(<->) :: Num a => [a] -> [a] -> [a]
u <-> v = [x - y | (x, y) <- zip u v]

-- scalar multiplication
-- scalar times vector: every entry of v is multiplied on the left by k
(*>) :: Num a => a -> [a] -> [a]
k *> v = [k * x | x <- v]

-- scalar times matrix: every entry of every row of m is multiplied on the left by k
(*>>) :: Num a => a -> [[a]] -> [[a]]
k *>> m = [[k * x | x <- row] | row <- m]

-- dot product of vectors (also called inner or scalar product):
-- the sum of the products of corresponding entries
-- entries of the longer vector that have no partner are ignored
(<.>) :: Num a => [a] -> [a] -> a
u <.> v = sum [x * y | (x, y) <- zip u v]

-- tensor product of vectors (also called outer or matrix product):
-- row i of the result is v scaled by the i-th entry of u
(<*>) :: Num a => [a] -> [a] -> [[a]]
u <*> v = map (\a -> map (a *) v) u


-- matrix operations

-- matrix addition, entry by entry
-- rows (and entries within a row) that have no partner in the other matrix are dropped
(<<+>>) :: Num a => [[a]] -> [[a]] -> [[a]]
a <<+>> b = zipWith addRows a b
    where
    addRows r s = zipWith (+) r s

-- matrix subtraction, entry by entry (entries of b are taken away from entries of a)
-- rows (and entries within a row) that have no partner in the other matrix are dropped
(<<->>) :: Num a => [[a]] -> [[a]] -> [[a]]
a <<->> b = [[x - y | (x, y) <- zip r s] | (r, s) <- zip a b]

-- matrix multiplication: entry (i,j) of the result is the dot product of
-- row i of a with column j of b
(<<*>>) :: Num a => [[a]] -> [[a]] -> [[a]]
a <<*>> b = map timesB a
    where
    -- the columns of b, named once and reused for every row of a
    cols = L.transpose b
    timesB u = [sum (zipWith (*) u v) | v <- cols]
 
-- action on the left: matrix times column vector
-- entry i of the result is the dot product of row i of m with v
(<<*>) :: Num a => [[a]] -> [a] -> [a]
m <<*> v = [sum (zipWith (*) row v) | row <- m]

-- action on the right: row vector times matrix
-- entry j of the result is the dot product of v with column j of m
(<*>>) :: Num a => [a] -> [[a]] -> [a]
v <*>> m = [sum (zipWith (*) v col) | col <- L.transpose m]


-- n by n matrix whose (i,j) entry is f i j, with row index i and column index j
-- both running from 1 to n
fMatrix :: (Enum a, Num a) => a -> (a -> a -> b) -> [[b]]
fMatrix n f = map (\i -> map (f i) [1..n]) [1..n]

-- version with indices from zero: the (i,j) entry is f i j,
-- with i and j both running from 0 to n-1
fMatrix' :: (Enum a, Num a) => a -> (a -> a -> b) -> [[b]]
fMatrix' n f = map (\i -> map (f i) indices) indices
    where
    indices = [0 .. n-1]


-- n by n identity matrix
-- an equivalent definition would be
-- idMx n = fMatrix n (\i j -> if i == j then 1 else 0)
-- Here the matrix is grown one step at a time from the empty matrix: each step
-- puts a zero in front of every existing row and adds a new first row 1,0,..,0.
-- The n-th stage is picked out with (!!), so a negative n is an error.
idMx :: Num a => Int -> [[a]]
idMx n = iterate grow [] !! n
    where
    -- m is square, so its number of rows is also the number of zeros the new row needs
    grow m = (1 : replicate (length m) 0) : map (0 :) m

-- n by n matrix with every entry equal to 1 (the empty matrix if n is not positive)
jMx :: Num a => Int -> [[a]]
jMx n = [[1 | _ <- side] | _ <- side]
    where
    side = [1 .. n]

-- n by n matrix with every entry equal to 0 (the empty matrix if n is not positive)
zMx :: Num a => Int -> [[a]]
zMx n = map (const zeroRow) side
    where
    side = [1 .. n]
    zeroRow = map (const 0) side

{-
-- VECTORS

data Vector d k = V [k] deriving (Eq,Ord,Show) 

instance (IntegerAsType d, Num k) => Num (Vector d k) where
    V a + V b = V $ a <+> b
    V a - V b = V $ a <-> b
    negate (V a) = V $ map negate a
    fromInteger 0 = V $ replicate d' 0 where d' = fromInteger $ value (undefined :: d)

V v <>> M m = V $ v <*>> m

M m <<> V v = V $ m <<*> v

k |> V v = V $ k *> v
-}

-- MATRICES

{-
-- Square matrices of dimension d over field k
data Matrix d k = M [[k]] deriving (Eq,Ord,Show)

instance (IntegerAsType d, Num k) => Num (Matrix d k) where
    M a + M b = M $ a <<+>> b
    M a - M b = M $ a <<->> b
    negate (M a) = M $ (map . map) negate a
    M a * M b = M $ a <<*>> b
    fromInteger 0 = M $ zMx d' where d' = fromInteger $ value (undefined :: d)
    fromInteger 1 = M $ idMx d' where d' = fromInteger $ value (undefined :: d)

instance (IntegerAsType d, Fractional a) => Fractional (Matrix d a) where
	recip (M a) = case inverse a of
		Nothing -> error "Matrix.recip: matrix is singular"
		Just a' -> M a'
-}

-- inverse of a square matrix, or Nothing if elimination shows that there is none
-- The matrix is extended to (M|I), inverse1 brings this to upper triangular form,
-- and inverse2 then clears the entries above the diagonal.
-- inverse1 returns fewer rows than it was given when it cannot find a pivot,
-- which is how the Nothing case is recognised.
inverse :: (Eq a, Fractional a) => [[a]] -> Maybe [[a]]
inverse m
    | length upper == dim = Just (inverse2 upper)
    | otherwise           = Nothing
    where
    dim       = length m                    -- the dimension
    augmented = zipWith (++) m (idMx dim)   -- (M|I)
    upper     = inverse1 augmented

-- given (M|I), use row operations to get to (U|A), where U is upper triangular with 1s on diagonal
-- This is basically row echelon form, except that the column just dealt with is
-- dropped from the remaining rows, so each row of the result is one entry shorter
-- than the row before it.
inverse1 :: (Eq a, Fractional a) => [[a]] -> [[a]]
inverse1 [] = []
inverse1 (row@(pivot:rest) : below)
    -- usable pivot: scale the row to a leading 1, then clear the column below it
    | pivot /= 0 = (1 : scaled) : inverse1 [ys <-> y *> scaled | (y:ys) <- below]
    -- zero pivot: add in the first lower row with a nonzero entry in this column and retry
    | otherwise =
        case dropWhile (\r -> head r == 0) below of
            []      -> [] -- early termination, which will be detected in calling function
            donor:_ -> inverse1 ((row <+> donor) : below)
    where
    scaled = (1/pivot) *> rest

-- given (U|A), use row operations to get to M^-1
-- This is basically reduced row echelon form.
-- The rows are expected in the shape produced by inverse1: each starts with its
-- leading 1 and is one entry shorter than the row above.
-- For each row, the rows below it are used in turn to cancel its next entry;
-- what is left at the end is the corresponding row of the inverse.
inverse2 :: (Eq a, Num a) => [[a]] -> [[a]]
inverse2 [] = []
inverse2 ((1:top) : lower) = L.foldl' clear top lower : inverse2 lower
    where
    -- cancel the first remaining entry c using a lower row that has its leading 1 in that column
    clear (c:cs) (1:pivotTail) = cs <-> c *> pivotTail

-- ie, a 1-based list lookup instead of 0-based: xs ! 1 is the first element of xs
(!) :: [a] -> Int -> a
xs ! i = xs !! zeroBased
    where
    zeroBased = i - 1

-- row echelon form of a matrix: row operations are used to give every nonzero row
-- a leading 1, with zeros below each leading 1
-- The matrix is processed column by column; the result has the same shape as the input.
rowEchelonForm :: (Eq a, Fractional a) => [[a]] -> [[a]]
rowEchelonForm [] = []
-- no columns left: nothing more to do
rowEchelonForm zs@([]:_) = zs
rowEchelonForm (row@(pivot:rest) : below)
    -- usable pivot: scale the row to a leading 1, clear the column below it,
    -- and carry on with the rows and columns that remain
    | pivot /= 0 =
        (1 : scaled) : padZero (rowEchelonForm [ys <-> y *> scaled | (y:ys) <- below])
    | otherwise =
        case dropWhile (\r -> head r == 0) below of
            -- the whole column is zero: skip it
            []      -> padZero (rowEchelonForm (rest : map tail below))
            -- add in the first lower row with a nonzero entry in this column and retry
            donor:_ -> rowEchelonForm ((row <+> donor) : below)
    where
    scaled = (1/pivot) *> rest
    -- put back the column that was dropped before recursing
    padZero rows = map (0 :) rows

-- reduced row echelon form of a matrix: starting from the row echelon form,
-- also clear the entries above each leading 1
-- The rows are processed from the bottom up, which is why the list of rows is
-- reversed on the way in and again on the way out.
reducedRowEchelonForm :: (Eq a, Fractional a) => [[a]] -> [[a]]
reducedRowEchelonForm m = reverse (reduce (reverse (rowEchelonForm m)))
    where
    -- emit one finished row per step and carry on with the updated remaining rows
    -- (structurally this is an unfold)
    reduce [] = []
    reduce rows@(_:_) = done : reduce pending
        where
        done : pending = reduceStep rows
    -- use the leading 1 of the first row to clear its column in all the other rows
    reduceStep rows@((lead:rest) : others)
        | lead == 1 = (1 : rest) : [0 : (ys <-> y *> rest) | y:ys <- others]
        -- not at the leading 1 yet: set this column aside, work on the rest, then put it back
        | lead == 0 = zipWith (:) (map head rows) (reduceStep (map tail rows))
    -- anything else (in particular a first row with no entries left) is returned unchanged
    reduceStep rows = rows

-- kernel of a matrix
-- returns a basis for the space of vectors v such that m <<*> v == 0
-- The matrix is first put into reduced row echelon form, from which kernelRRE
-- reads off the basis.
kernel m = kernelRRE rre
    where
    rre = reducedRowEchelonForm m

-- kernel of a matrix which is already in reduced row echelon form
-- returns a basis for the vectors v such that m <<*> v == 0, one basis vector
-- for each column of m that does not contain a leading 1
-- m must have at least one row, because the number of columns is read off its first row.
-- The basis is assembled as a list of (column index, entries) pairs, one pair per
-- coordinate, and put into coordinate order by sorting on the index alone.
-- The indices are all distinct, so the entries never need to be compared,
-- and no ordering on the matrix entries is required.
kernelRRE :: (Eq a, Num a) => [[a]] -> [[a]]
kernelRRE m = L.transpose $ map snd $ L.sortBy byIndex $ pivotCoords ++ freeCoords
    where
    nc = length (head m) -- the number of columns
    pivots = findLeadingCols 1 (L.transpose m) -- these are the indices of the columns which have a leading 1
    frees = [1..nc] L.\\ pivots -- the indices of the remaining columns
    nonZeroRows = take (length pivots) m -- discard zero rows
    -- coordinate at a pivot column: for each basis vector, minus the entry of that
    -- pivot's row in the basis vector's own column
    pivotCoords = zip pivots $ L.transpose [map (negate . (! j)) nonZeroRows | j <- frees]
    -- coordinates at the other columns: 1 for the basis vector belonging to that column, 0 otherwise
    freeCoords = zip frees (idMx $ length frees)
    byIndex (i, _) (j, _) = compare i j
    -- walk along the columns; each time a leading 1 is found, move down one row
    -- in all the columns still to come
    findLeadingCols i ((1:_):cs) = i : findLeadingCols (i+1) (map tail cs)
    findLeadingCols i ((0:_):cs) = findLeadingCols (i+1) cs
    findLeadingCols _ _ = []

-- negative powers: m ^- n is the n-th power of the multiplicative inverse of m
-- n itself is passed on to (^) unchanged
(^-) :: (Fractional a, Integral b) => a -> b -> a
m ^- n = inv ^ n
    where
    inv = recip m

-- t (M m) = M (L.transpose m)

-- determinant of a square matrix, by elimination on the first column
-- The matrix with no rows has determinant 1 (the empty product); previously this
-- case was not covered by any equation.
-- The row operations used (adding one row to another, subtracting a multiple of
-- one row from another) do not change the determinant, so the result is the
-- product of the pivots met along the way.
det :: (Eq a, Fractional a) => [[a]] -> a
det [] = 1
det [[x]] = x
det (row@(x:xs) : rs)
    | x /= 0 =
        -- scale the rest of the pivot row, clear the first column below it, and recurse
        -- on what is left once the first row and column are removed
        let r' = map ((1/x) *) xs
        in x * det [zipWith (-) ys (map (y *) r') | (y:ys) <- rs]
    | otherwise =
        -- zero pivot: look for a lower row with a nonzero first entry
        case [r | r@(y:_) <- rs, y /= 0] of
            []  -> 0 -- the first column is entirely zero
            r:_ -> det (zipWith (+) row r : rs) -- add that row in and retry


{-
class IntegerAsType a where
    value :: a -> Integer

data Z
instance IntegerAsType Z where
    value _ = 0

data S a
instance IntegerAsType a => IntegerAsType (S a) where
    value _ = value (undefined :: a) + 1
-}