{- Copyright (c) 2008, Scott E. Dillard. All rights reserved. -}

{-# LANGUAGE EmptyDataDecls #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TypeFamilies #-}

{-# OPTIONS_HADDOCK ignore-exports,prune #-}

-- NOTE(review): the export list was truncated in this copy of the file; the
-- list below is reconstructed from the documented top-level API and the
-- classes whose methods carry Haddock comments. Confirm against upstream.
module Data.Vec.LinAlg
  ( dot, normSq, norm, normalize, cross
  , homPoint, homVec, project
  , multvm, multmv, multmm
  , translate, column, row
  , Transpose(..), SetDiagonal(..), GetDiagonal(..)
  , scale, diagonal, identity
  , det, cramer'sRule
  , NearZero(..), GaussElim(..), BackSubstitute(..), BackSubstitute'(..)
  , invert, invertAndDet, solve
  ) where

import Prelude hiding (map,zipWith,foldl,foldr,reverse,take,drop,
                       head,tail,sum,product,length,last,minimum,maximum)
import qualified Prelude as P
import Data.Vec.Base
import Data.Vec.Nat

import Control.Monad
import Data.Maybe

-- | Inner (dot, scalar) product of two vectors: the 'sum' of the elements of
-- @u * v@, where @*@ comes from the vector type's 'Num' instance.
dot ::  (Num a, Num v, Fold v a) => v -> v -> a
dot lhs rhs = sum $ lhs * rhs
{-# INLINE dot #-}

-- | Squared vector norm, i.e. the 'dot' product of a vector with itself.
normSq ::  (Num a, Num v, Fold v a) => v -> a
normSq x = x `dot` x
{-# INLINE normSq #-}

-- | Euclidean (L2) norm: the square root of the vector's 'dot' product with
-- itself.
norm ::  (Num v, Floating a, Fold v a) => v -> a
norm x = sqrt $ x `dot` x
{-# INLINE norm #-}

-- | @normalize v@ divides every element of @v@ by @norm v@, giving a unit
-- vector in the direction of @v@. @v@ is assumed to be non-null; no check
-- is made for a zero norm.
normalize :: (Floating a, Num v, Fold v a, Map a a v v) => v -> v
normalize u = map (/ norm u) u
{-# INLINE normalize #-}

-- | 3d cross product.
--
-- The result is perpendicular to both arguments; its components are the
-- usual cofactor expressions of the two input vectors.
cross :: Num a => Vec3 a -> Vec3 a -> Vec3 a
cross (ux:.uy:.uz:._) (vx:.vy:.vz:._) =
  (uy*vz - uz*vy) :. (uz*vx - ux*vz) :. (ux*vy - uy*vx) :. ()
{-# INLINE cross #-}

-- | Lift a point into homogeneous coordinates by appending a trailing @1@.
homPoint ::  (Snoc v a v', Num a) => v -> v'
homPoint p = p `snoc` 1
{-# INLINE homPoint #-}

-- | Point-at-infinity (direction) in homogeneous coordinates: appends a
-- trailing @0@.
homVec ::  (Snoc v a v', Num a) => v -> v'
homVec d = d `snoc` 0
{-# INLINE homVec #-}

-- | Project a vector out of homogeneous coordinates: the last element is
-- removed and the remaining elements are divided by it. The last element is
-- assumed to be non-zero.
project :: 
  ( Reverse' () t1 v'
  , Fractional t1
  , Vec a t t1
  , Reverse' () v (t :. t1)
  ) => v -> v'
project v =
  -- Reversing brings the last element to the head, where it can be matched.
  case reverse v of
    (w :. rest) -> reverse (rest / vec w)
{-# INLINE project  #-}

-- | Row vector times matrix: each element of the result is the 'dot'
-- product of @v@ with one row of the transposed matrix (a column of @m@).
multvm :: 
  ( Transpose m mt
  , Map v a mt v'
  , Fold v a
  , Num a
  , Num v
  ) => v -> m -> v'
multvm v m = map (v `dot`) $ transpose m
{-# INLINE multvm #-}

-- | Matrix times column vector: each element of the result is the 'dot'
-- product of @v@ with one row of @m@.
multmv :: 
  ( Map v a m v'
  , Num v
  , Fold v a
  , Num a
  ) => m -> v -> v'
multmv m v = map (v `dot`) m
{-# INLINE multmv #-}

-- | Matrix times matrix: every row of @a@ is dotted with every row of the
-- transpose of @b@ (i.e. with every column of @b@).
multmm :: 
  (Map v v' m1 m3
  ,Map v a b v'
  ,Transpose m2 b
  ,Fold v a
  ,Num v
  ,Num a
  ) => m1 -> m2 -> m3
multmm a b = map (\r -> map (r `dot`) (transpose b)) a
{-# INLINE multmm #-}

-- | Apply a translation to a projective transformation matrix: @homVec v@
-- is added to the last column of @m@.
translate :: 
  (Transpose m mt
  ,Reverse' () mt (v' :. t)
  ,Reverse' (v' :. ()) t v'1
  ,Transpose v'1 m
  ,Num v'
  ,Num a
  ,Snoc v a v'
  ) => v -> m -> m
translate v m = 
  -- Transposing and reversing brings the last column to the head so it can
  -- be pattern matched; the same two steps are undone afterwards.
  case reverse (transpose m) of
    (lastCol :. others) ->
      transpose (reverse ((homVec v + lastCol) :. others))
{-# INLINE translate #-}

-- | get the @n@-th column as a vector. @n@ is a type-level natural.
-- Implemented by transposing the matrix and taking the @n@-th row of the
-- result.
column ::  (Transpose m mt, Access n v mt) => n -> m -> v
column n = get n . transpose 
{-# INLINE column #-}

-- | get the @n@-th row as a vector. @n@ is a type-level natural.
row ::  (Access n a v) => n -> v -> a
row n = get n
{-# INLINE row #-}

-- Matrix transpose wrapper class: infers the type of one argument from the
-- other, because Transpose' can't do it; the fundeps there can't be bijective.

-- | matrix transposition
--
-- The functional dependencies run both ways, so the type of the transposed
-- matrix can be inferred from the input and vice versa.
class Transpose a b | a -> b, b -> a where 
  transpose :: a -> b

-- The empty matrix is its own transpose.
instance Transpose () () where
  transpose = id

-- Non-empty matrices: the Vec constraints pin down the dimensions on both
-- sides; the element shuffling is delegated to transpose'.
instance
    (Vec (Succ n) s (s:.ra)  --(s:ra) is an n-vector of s'es (row of a)
    ,Vec (Succ m) (s:.ra) ((s:.ra):.a)  --a is an m-vector of ra's
    ,Vec (Succ m) s (s:.rb)  --rb is an m-vector of s'es (row of b)
    ,Vec (Succ n) (s:.rb) ((s:.rb):.b)  --b is an n-vector of rb's
    ,Transpose' ((s:.ra):.a) ((s:.rb):.b)
    )
    => Transpose ((s:.ra):.a) ((s:.rb):.b)
  where
    transpose = transpose'
    {-# INLINE transpose #-}

-- Worker class behind Transpose. Its fundep runs in one direction only
-- (input type determines output type).
class Transpose' a b | a->b
  where transpose' :: a -> b

-- The empty matrix transposes to itself.
instance Transpose' () () where 
  transpose' = id
  {-# INLINE transpose' #-}

-- A leading empty row is discarded and the remaining rows are transposed.
instance
    (Transpose' vs vs') => Transpose' ( () :. vs ) vs'
  where
    transpose' (():.vs) = transpose' vs
    {-# INLINE transpose' #-}

-- A 1x1 matrix is its own transpose.
instance Transpose' ((x:.()):.()) ((x:.()):.()) where
  transpose' = id

-- General step: the first row of the result is the first column of the
-- input (x followed by the heads of the remaining rows); the rest of the
-- result is the transpose of what is left after removing that column.
instance
    (Head xss_h xss_hh
    ,Map xss_h xss_hh (xss_h:.xss_t) xs'
    ,Tail xss_h xss_ht
    ,Map xss_h xss_ht (xss_h:.xss_t) xss_
    ,Transpose' (xs :. xss_) xss'
    )
    => Transpose' ((x:.xs):.(xss_h:.xss_t)) ((x:.xs'):.xss') 
  where
    transpose' ((x:.xs):.xss) =
      (x :. (map head xss)) :. (transpose' (xs :. (map tail xss) :: (xs:.xss_)))
    {-# INLINE transpose' #-}

class SetDiagonal v m | m -> v, v -> m where
  -- |set the diagonal of an n-by-n matrix to a given n-vector
  setDiagonal :: v -> m -> m

-- Delegates to setDiagonal', starting the walk at index N0.
instance (Vec n a v, Vec n r m, SetDiagonal' N0 v m) => SetDiagonal v m where
  setDiagonal diag mat = setDiagonal' start diag mat
    where start = undefined :: N0 -- only the type of the index is used
  {-# INLINE setDiagonal #-}

-- Worker for SetDiagonal. @n@ is the type-level index of the diagonal
-- element to be written in the current row.
class SetDiagonal' n v m  where
  setDiagonal' :: n -> v -> m -> m

-- No diagonal values left: the remaining rows are returned unchanged.
instance SetDiagonal' n () m where
  setDiagonal' _ _ m = m
  {-# INLINE setDiagonal' #-}

-- Write element n of the current row, then continue with the next row at
-- index Succ n.
instance
    ( SetDiagonal' (Succ n) v m
    , Access n a r
    ) => SetDiagonal' n (a:.v) (r:.m) 
  where
    setDiagonal' _ (a:.v) (r:.m) = 
       (set (undefined::n) a r) :. (setDiagonal' (undefined::Succ n) v m)
    {-# INLINE setDiagonal' #-}

class GetDiagonal m v | m -> v, v -> m where
  -- |get the diagonal of an n-by-n matrix as a vector
  getDiagonal :: m -> v

-- Delegates to getDiagonal', starting at index N0 with an empty accumulator.
instance (Vec n a v, Vec n v m, GetDiagonal' N0 () m v) => GetDiagonal m v where
  getDiagonal mat = getDiagonal' start () mat
    where start = undefined :: N0 -- only the type of the index is used
  {-# INLINE getDiagonal #-}

-- Worker for GetDiagonal. @n@ is the type-level index of the diagonal
-- element in the current row; @p@ accumulates the elements collected so far.
class GetDiagonal' n p m v where
  getDiagonal' :: n -> p -> m -> v

-- Last row: append its n-th element to the accumulator and stop.
instance
    (Access n a r
    ,Append p (a:.()) (a:.p)
    ) => GetDiagonal' n p (r:.()) (a:.p) 
  where
    getDiagonal' _ p (r:.()) = append p ((get (undefined::n) r) :. ())
    {-# INLINE getDiagonal' #-}

-- Two or more rows left: append the n-th element of the current row, then
-- recurse on the remaining rows at index Succ n.
instance
    (Access n a r
    ,Append p (a:.()) p'
    ,GetDiagonal' (Succ n) p' (r:.m) v
    )
    => GetDiagonal' n p (r:.r:.m) v
  where
    getDiagonal' _ p (r:.m) = 
      getDiagonal' (undefined::Succ n) (append p ((get (undefined::n) r):.())) m
    {-# INLINE getDiagonal' #-}

-- | @scale s m@ multiplies the diagonal of matrix @m@ by the vector @s@,
-- component-wise. So @scale 5 m@ multiplies the diagonal by 5, whereas
-- @scale 2:.1 m@ only scales the x component.
scale :: 
  ( GetDiagonal' N0 () m r
  , Num r
  , Vec n a r
  , Vec n r m
  , SetDiagonal' N0 r m
  ) => r -> m -> m
scale s m = setDiagonal scaled m
  where scaled = s * getDiagonal m -- new diagonal: old diagonal times s
{-# INLINE scale #-}

-- | @diagonal v@ is a square matrix with the vector @v@ on the diagonal and
-- 0 everywhere else. It is built by writing @v@ onto the diagonal of the
-- matrix literal @0@.
diagonal :: (Vec n a v, Vec n v m, SetDiagonal v m, Num m) => v -> m
diagonal d = d `setDiagonal` 0
{-# INLINE diagonal #-}

-- | Square identity matrix: the vector literal @1@ written onto the
-- diagonal of the matrix literal @0@.
identity :: (Vec n a v, Vec n v m, Num v, Num m, SetDiagonal v m) => m
identity = setDiagonal 1 0
{-# INLINE identity #-}

-- Det' needs help inferring that all of the matrix elements are the same type.
-- The two Vec constraints on 'det' supply that: rows are n-vectors of @a@ and
-- the matrix is an n-vector of such rows.

-- | Determinant by minor expansion, i.e. Laplace's formula. Unfolds into a
-- closed form expression.  This should be the fastest way for 4x4 and smaller,
-- but @snd . gaussElim@ works too.

det :: forall n a r m. (Vec n a r, Vec n r m, Det' m a) => m -> a
-- A plain alias for det'; only the signature above adds anything.
det = det'
{-# INLINE det #-}

-- The Determinant of a square matrix, by minor expansion. 
class Det' m a | m -> a where
  det' :: m -> a

-- Base case: the determinant of a 1x1 matrix is its single element.
instance Det' ((a:.()):.()) a where
  det' ((a:._):._) = a

-- Recursive case (n >= 2): expand along the first column. The first column
-- (with alternating signs) is multiplied element-wise by the determinants of
-- the minors obtained by deleting the first column and one row at a time,
-- and the products are summed.
instance
  ( (a:.a:.v) ~ r                  -- a row of the matrix, an n-vector
  , ((a:.a:.v):.(a:.a:.v):.vs) ~ m -- an n*n matrix, n >= 2
  , ((a:.v):.(a:.v):.vs_) ~ m_     -- an n*(n-1) matrix
  , (((a:.v):.vs_):.(x:.y)) ~ mm   -- an n-vector of (n-1)*(n-1) matrices to recurse upon
  , Map (a:.a:.v) (a:.v) m m_      -- drop the first column of m to get m_
  , DropConsec m_ mm               -- an n-vector of (n-1)*(n-1) matrices
  , Det' ((a:.v):.vs_) a           -- determinant of (n-1)*(n-1) matrix
  , Map ((a:.v):.vs_) a mm r       -- dets of all n of the (n-1)*(n-1) matrices, the result is same type as a row
  , Map r a m r                    -- grab the first column using "map head" the result is same type as a row
  , NegateOdds r                   -- flip sign of odd elements of first column
  , Fold r a                       -- add everything up...
  , Num r
  , Num a
  ) => Det' ((a:.a:.v):.(a:.a:.v):.vs) a                    -- et voila
  where
  det' m =
    sum $ (negateOdds $ map head m) * map det' (dropConsec $ map tail m)

-- DropConsec: Drop consecutive elements, collecting the results. Given an
-- n-vector v, drop each element from v, one at a time in sequence, and collect
-- the resulting (n-1)-vectors into an n-vector (ie an n-by-(n-1) matrix).
-- This is used for determinants.
-- dropConsec [1,2,3,4] = [[2,3,4],[1,3,4],[1,2,4],[1,2,3]]
class DropConsec v vv | v -> vv where
  dropConsec :: v -> vv

-- The Vec/Pred constraints fix the result shape (n vectors of length n-1);
-- the work is done by dropConsec', started with an empty prefix.
instance
  (Vec n a v
  ,Pred n n_
  ,Vec n_ a v_
  ,Vec n v_ vv
  ,DropConsec' () v vv
  ) => DropConsec v vv
  where
    dropConsec v = dropConsec' () v 
    {-# INLINE dropConsec #-}

-- Worker for DropConsec. @p@ is the prefix of elements already passed over;
-- each result vector is that prefix followed by the elements after the one
-- being dropped.
class DropConsec' p v vv  where
  dropConsec' :: p -> v -> vv

-- Last element: dropping it leaves just the prefix.
instance DropConsec' p (a:.()) (p:.()) where
  dropConsec' p (a:.()) = (p:.())
  {-# INLINE dropConsec' #-}

-- Drop the current element (prefix ++ rest), then move it onto the end of
-- the prefix and continue with the rest.
instance
    (Append p (a:.v) x
    ,Append p (a:.()) y
    ,DropConsec' y (a:.v) z
    )
    => DropConsec' p (a:.a:.v) (x:.z)
  where
    dropConsec' p (a:.v) = 
      (append p v) :. (dropConsec' (append p (a:.())) v)
    {-# INLINE dropConsec' #-}

-- Negate the elements of a vector at odd positions (NegateOdds) or at even
-- positions (NegateEvens), counting the first element as position 0.
-- The two classes are mutually recursive: each handles one element and hands
-- the tail to the other. Used for determinants.

class NegateOdds v where
  negateOdds :: v -> v 

-- Nothing to negate in the empty vector.
instance NegateOdds  () where 
  negateOdds  () = () 
  {-# INLINE negateOdds #-}

-- Head stays as it is; in the tail the roles of odd and even swap.
instance (Num a, NegateEvens v) => NegateOdds (a:.v) where
  negateOdds (x:.rest) = x :. negateEvens rest
  {-# INLINE negateOdds #-}

class NegateEvens v where
  negateEvens :: v -> v 

-- Nothing to negate in the empty vector.
instance NegateEvens () where 
  negateEvens () = () 
  {-# INLINE negateEvens #-}

-- Head is negated; in the tail the roles of odd and even swap.
instance (Num a, NegateOdds v) => NegateEvens (a:.v) where
  negateEvens (x:.rest) = negate x :. negateOdds rest
  {-# INLINE negateEvens #-}

--ReplConsec : this is a helper for implementing Cramer's rule.  Given an
--n-vector v and a value r, replace each consecutive element from v with r,
--and collect the resulting n-vectors into an n-vector (ie an n-by-n matrix)

class ReplConsec a v vv | v->a, v->vv, vv->v, vv->a where
  replConsec :: a -> v -> vv

-- The Vec constraints fix the result shape (n vectors of length n); the work
-- is done by replConsec', started with an empty prefix.
instance
  (Vec n a v
  ,Vec n v vv
  ,ReplConsec' a () v vv
  ) => ReplConsec a v vv
  where
    replConsec a v = replConsec' a () v :: vv
    {-# INLINE replConsec #-}

-- Worker for ReplConsec. @p@ is the prefix of original elements already
-- passed over.
class ReplConsec' a p v vv where
  replConsec' :: a -> p -> v -> vv

-- No elements left: no more result vectors.
instance ReplConsec' a p () () where
  replConsec' _ _ () = ()
  {-# INLINE replConsec' #-}

-- Emit prefix ++ (r : rest), i.e. the vector with the current element
-- replaced by r, then move the original element onto the prefix and continue.
instance
    (Append p (a:.v) x
    ,Append p (a:.()) y
    ,ReplConsec' a y v z
    )
    => ReplConsec' a p (a:.v) (x:.z)
  where
    replConsec' r p (a:.v) = 
      (append p (r:.v)) :. (replConsec' r (append p (a :. ())) v)
    {-# INLINE replConsec' #-}

-- | @cramer'sRule m v@ computes the solution to @m\`multmv\`x=v@ using the
-- eponymous method: each unknown is the determinant of @m@ with one column
-- replaced by @v@, divided by the determinant of @m@. For larger than 3x3
-- you will want to use 'solve', which uses 'gaussElim'. Cramer's rule,
-- however, unfolds into a closed-form expression, with no branches or
-- allocations (other than the result). You may need to increase the
-- unfolding threshold to see this.

cramer'sRule :: 
  (Map a a1 b1 v
  ,Transpose w b1
  ,ZipWith a2 b vv v m w
  ,ReplConsec' a2 () b vv
  ,Vec n b vv
  ,Vec n a2 b
  ,Fractional a1
  ,Det' m a1
  ,Det' a a1
  ) => m -> v -> v
cramer'sRule m b =
  -- asTypeOf pins the result to the type of the right-hand side vector.
  map (\mi -> det' mi / det' m) (transpose (zipWith replConsec b m))
    `asTypeOf` b
{-# INLINE cramer'sRule #-}

-- Apply a function to the first component of a pair, leaving the second
-- component untouched.
mapFst :: (a -> c) -> (a, b) -> (c, b)
mapFst f (a,b) = (f a,b)
{-# INLINE mapFst #-}

-- The 'Eq' superclass is required by the default method, which matches on the
-- literal @0@; 'Num' alone does not provide equality on current GHC versions.
class (Eq a, Num a) => NearZero a where
  -- | @nearZero x@ should be true when x is close enough to 0 to cause
  -- significant error in division. The default implementation is an exact
  -- comparison with 0.
  nearZero :: a -> Bool
  nearZero 0 = True
  nearZero _ = False
  {-# INLINE nearZero #-}

-- Single precision: magnitudes below 1e-6 count as zero.
instance NearZero Float where
  nearZero x = abs x < 1e-6
  {-# INLINE nearZero #-}

-- Double precision: magnitudes below 1e-14 count as zero.
instance NearZero Double where
  nearZero x = abs x < 1e-14
  {-# INLINE nearZero #-}

-- Exact arithmetic: uses the default, so only exactly 0 is near zero.
instance NearZero Rational

-- Pivot1 : find a row with a non-zero entry in the first column, move it to
-- the top, scale it so that entry becomes 1, and eliminate the first column
-- from the rows below it. Second return argument tracks value of determinant:
-- it is the pivot value, negated once for every row the pivot row had to be
-- moved past (each such move is a row swap, which flips the sign).
-- Returns Nothing if there is no usable pivot in the first column. Does not
-- try to find the 'best' pivot, only an acceptable one: matrices are assumed
-- small, roundoff error should be negligible. 

class Pivot1 a m where 
  pivot1 :: m -> Maybe (m,a)

--this instance prevents a fundep inferring type of a from m. 
instance Pivot1 a () where
  pivot1 _ = Nothing

-- A single row with a single element: that element is the only candidate.
instance
    ( Fractional a, NearZero a
    ) => Pivot1 a ((a:.()):.()) 
  where
    pivot1 ((p:._):._) 
      | nearZero p = Nothing
      | otherwise  = Just (1,p)
    {-# INLINE pivot1 #-}

-- A single row with two or more elements: scale the row by its head.
instance
    ( Fractional a, NearZero a 
    , Map a a (a:.r) (a:.r)
    ) => Pivot1 a ((a:.(a:.r)):.()) 
  where
    pivot1 ((p:.r):._) 
      | nearZero p = Nothing
      | otherwise  = Just ((1 :. (map (/p) r)):.(), p)
    {-# INLINE pivot1 #-}

-- Two or more rows.
instance
    ( Fractional a, NearZero a
    , Map a a (a:.r) (a:.r)
    , ZipWith a a a (a:.r) (a:.r) (a:.r) 
    , Map (a:.r) (a:.r) ((a:.r):.rs) ((a:.r):.rs)
    , Pivot1 a ((a:.r):.rs) 
    ) => Pivot1 a ((a:.r):.(a:.r):.rs) 
  where
    pivot1 (row@(p:._):.rows) 
      -- The head row cannot serve as pivot: look further down, then put the
      -- row that was found above this one. That is one row swap, so the sign
      -- of the tracked determinant factor is flipped.
      | nearZero p = pivot1 rows >>= \(r:.rs,q) -> Just (r:.row:.rs, negate q)
      | otherwise  = Just ( first:.(map add rows) , p)
          where first        = map (/p) row
                -- subtract x times the scaled pivot row, zeroing the head
                add r@(x:._) = zipWith (-) r . map (*x) $ first 
    {-# INLINE pivot1 #-}

-- Pivot : find a pivot. Second return argument tracks determinant.
-- Returns Nothing if no pivot anywhere.

class Pivot a m | m -> a where
  pivot :: m -> Maybe (m,a)

-- The rows have no columns left: nothing to pivot on.
instance Pivot a (():.v) where
  pivot _ = Nothing
  {-# INLINE pivot #-}

-- Try the first column with pivot1; if that fails, drop the first column,
-- look for a pivot in the remaining columns, and put a column of zeros back
-- in front of the result.
instance
    ( Fractional a
    , NearZero a
    , Pivot1 a rs 
    , Tail (a:.r) r
    , Map (a:.r) r ((a:.r):.rs) (r:.rs') 
    , Map r (a:.r) (r:.rs') ((a:.r):.rs)
    , Pivot1 a ((a:.r):.rs)
    , Pivot a (r:.rs')
    ) => Pivot a ((a:.r):.rs) 
  where
    pivot m = 
      mplus (pivot1 m) 
            (pivot (map tail m) >>= return . mapFst (map (0:.)) )
    {-# INLINE pivot #-}

-- | Gaussian elimination, adapted from Mirko Rahn:
-- <http://www.haskell.org/pipermail/glasgow-haskell-users/2007-May/012648.html>
-- This is more of a proof of concept. Using a foreign C function will run
-- slightly faster, and compile much faster. But where is the fun in that?
-- Set your unfolding threshold as high as possible.

class GaussElim a m | m -> a where
  -- | @gaussElim m@ returns a pair @(m',d)@ where @m'@ is @m@ in row echelon
  -- form and @d@ is the determinant of @m@. The determinant of @m'@ is 1 or 0,
  -- i.e., the leading coefficient of each non-zero row is 1.  
  gaussElim :: m -> (m,a)

-- A single row. If no pivot can be found the row is (near) zero, so the row
-- is returned unchanged and the determinant factor is 0.
instance (Num a, Pivot a (r:.())) => GaussElim a (r:.())
  where
    gaussElim m = fromMaybe (m,0) (pivot m) 
    {-# INLINE gaussElim #-}

-- Two or more rows: pivot, then recurse on the rows below the pivot row with
-- their first column removed, and multiply the determinant factors.
instance
    ( Fractional a
    , Map (a:.r) r ((a:.r):.rs) rs_
    , Map r (a:.r) rs_ ((a:.r):.rs) 
    , Pivot a ((a:.r):.(a:.r):.rs)
    , GaussElim a rs_
    ) => GaussElim a ((a:.r):.(a:.r):.rs)
  where
    gaussElim m =
      case pivot m of
        -- No pivot anywhere: the matrix is (near) zero, determinant 0.
        Nothing -> (m,0)
        Just (row:.rows,p) ->
          case gaussElim (map tail rows)
            of (rows',p') -> ( row:.(map (0:.) rows') , p*p')
    {-# INLINE gaussElim #-}

class BackSubstitute m where
  -- | backSubstitute takes a full rank matrix from row echelon form to reduced
  -- row echelon form. Returns @Nothing@ if the matrix is rank deficient. 
  backSubstitute :: m -> Maybe m 

-- A single row: accepted only if its leading element is (near) 1.
instance NearZero a => BackSubstitute ((a:.r):.()) where
  backSubstitute r@((a:._):._) 
    | nearZero (1-a) = Just r
    | otherwise = Nothing
  {-# INLINE backSubstitute #-}

-- Two or more rows: reduce the rows below first (with their leading column
-- removed and then restored as zeros), then subtract multiples of those
-- rows from the first row.
instance
    ( Map (a:.r) r ((a:.r):.rs) rs_ --map tail
    , Map r (a:.r) rs_ ((a:.r):.rs) --map cons
    , Fold aas (a,a:.r) 
    , ZipWith a a a (a:.r) (a:.r) (a:.r)
    , Map a a (a:.r) (a:.r)
    , ZipWith a (a:.r) (a,a:.r) r ((a:.r):.rs) aas
    , Num a, NearZero a
    , BackSubstitute rs_
    ) => BackSubstitute ((a:.r):.(a:.r):.rs)
  where
    backSubstitute (r@(rh:.rt):.rs) 
      | nearZero (1-rh) = 
          liftM (map (0:.)) (backSubstitute . map tail $ rs) >>= \rs' -> 
            -- pair each trailing coefficient of the first row with the
            -- reduced row it should be eliminated by
            return . (:.rs') . foldl (\v (a,w) -> sub v a w) r $ 
              zipWith (,) rt rs'
      | otherwise = Nothing -- rank deficient
          where sub v a = zipWith (-) v . map (*a)
    {-# INLINE backSubstitute #-}

class BackSubstitute' m where
  -- | backSubstitute' takes a full rank matrix from row echelon form to reduced
  -- row echelon form. Returns garbage if the matrix is rank deficient.
  backSubstitute' :: m -> m 

-- A single row is returned as it is; no check is made.
instance BackSubstitute' ((a:.r):.()) where
  backSubstitute' = id
  {-# INLINE backSubstitute' #-}

-- Two or more rows: same procedure as backSubstitute, but without checking
-- that the leading element is 1.
instance
    ( Map (a:.r) r ((a:.r):.rs) rs_ --map tail
    , Map r (a:.r) rs_ ((a:.r):.rs) --map cons
    , Fold aas (a,a:.r) 
    , ZipWith a a a (a:.r) (a:.r) (a:.r)
    , Map a a (a:.r) (a:.r)
    , ZipWith a (a:.r) (a,a:.r) r ((a:.r):.rs) aas
    , Num a
    , BackSubstitute' rs_
    ) => BackSubstitute' ((a:.r):.(a:.r):.rs)
  where
    backSubstitute' (r@(_:.rt):.rs) = 
      case map (0:.) (backSubstitute' . map tail $ rs) 
        of rs' -> (:.rs') $ foldl (\ v (a,w) -> sub v a w) r 
                              (zipWith (,) rt rs')
      where sub v a = zipWith (-) v . map (*a)
    {-# INLINE backSubstitute' #-}

-- | @invert m@ returns @Just@ the inverse of @m@ or @Nothing@ if @m@ is
-- singular. The matrix is augmented with the identity on the right, reduced
-- with 'gaussElim' and 'backSubstitute', and the right half is returned.
invert :: forall n a r m r' m'. 
  ( Num r, Num m
  , Vec n a r     -- r is row type
  , Vec n r m     -- m is matrix type
  , Append r r r' -- r' is a row of augmented matrix
  , ZipWith r r r' m m m' -- m' is the augmented matrix
  , Drop n r' r -- get the right half of an augmented matrix row
  , Map r' r m' m -- get the right half of the augmented matrix
  , SetDiagonal r m -- needed to make identity matrix
  , GaussElim a m'
  , BackSubstitute m'
  ) => m -> Maybe m
invert m = 
  fmap (map dropn) . backSubstitute . fst . gaussElim $ zipWith append m i
  where dropn = drop (undefined::n) -- discard the left n columns of a row
        i = identity :: m
{-# INLINE invert #-}

-- | inverse and determinant. If det = 0, inverted matrix is garbage: when
-- back substitution fails, the input matrix itself is returned together with
-- a determinant of 0.
invertAndDet :: forall n a r m r' m'. 
  ( Num a, Num r, Num m
  , Vec n a r     -- r is row type
  , Vec n r m     -- m is matrix type
  , Append r r r' -- r' is a row of augmented matrix
  , ZipWith r r r' m m m' -- m' is the augmented matrix
  , Drop n r' r -- get the right half of an augmented matrix row
  , Map r' r m' m -- get the right half of the augmented matrix
  , SetDiagonal r m -- needed to make identity matrix
  , GaussElim a m'
  , BackSubstitute m'
  ) => m -> (m,a)
invertAndDet m = 
  case backSubstitute rref of
    Nothing -> (m,0)
    Just m' -> ( map dropn m' , d )
  where
    -- eliminate on m augmented with the identity; d is gaussElim's
    -- determinant result
    (rref,d) = gaussElim . zipWith append m $ i
    dropn = drop (undefined::n)
    i = identity :: m
{-# INLINE invertAndDet #-}

-- | Solution of linear system by Gaussian elimination. Returns @Nothing@
-- if no solution. The right-hand side @v@ is appended to @m@ as an extra
-- column, the augmented matrix is reduced with 'gaussElim' and
-- 'backSubstitute', and that extra column is read back as the result.
solve :: forall n a v r m r' m'. 
  ( Num r, Num m
  , Vec n a r     -- r is row type
  , Vec n r m     -- m is matrix type
  , Snoc r a r'   -- a row of the extended matrix is one longer
  , ZipWith r a r' m r m' -- m' is the augmented matrix
  , Drop n r' (a:.()) -- get the right part of an augmented matrix row
  , Map r' a m' r -- get the right part of the augmented matrix
  , GaussElim a m'
  , BackSubstitute m'
  ) => m -> r -> Maybe r
solve m v = 
  fmap (map (head . drop (undefined::n)))
    . backSubstitute . fst . gaussElim $ zipWith snoc m v
{-# INLINE solve #-}