{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE PartialTypeSignatures #-}
module Data.Array.Accelerate.TypeLits
              (
              -- * Types
              AccScalar,
              AccVector,
              AccMatrix,
              -- * Classes
              AccFunctor(..),
              -- * Constructors
              mkMatrix,
              mkVector,
              mkScalar,
              unsafeMkMatrix,
              unsafeMkVector,
              unMatrix,
              unVector,
              unScalar,
              identityMatrix,
              zeroV,
              zeroM,
              -- * Functions
              -- ** Scalar & X
              (.*^),
              (./^),
              (.*#),
              (./#),
              -- ** AccMatrix & Vector
              (#*^),
              (^*#),
              -- ** AccVector & Vector
              (^+^),
              (^-^),
              (^*^),
              -- ** AccMatrix & Matrix
              (#+#),
              (#-#),
              (#*#),
              (#**.),
              -- ** Utility functions
              transpose,
              zipWithV,
              zipWithM,
              )
              where

import qualified Data.Array.Accelerate as A

import           Data.Proxy (Proxy(..))
import           GHC.TypeLits (KnownNat, natVal)
import           Data.Array.Accelerate.TypeLits.Internal
import           Data.Array.Accelerate ( (:.)((:.))
                                       , Exp
                                       , DIM2, DIM3, Z(Z)
                                       , IsFloating, IsNum, Elt
                                       , All(All), Any(Any))

-- | The @n×n@ identity matrix: ones on the main diagonal, zeros everywhere
-- else. The array is built on the host with 'A.fromFunction' and embedded
-- into the computation with 'A.use'.
--
-- > ⎛  1  0  …  0  0  ⎞
-- > ⎜  0  1  …  0  0  ⎟
-- > ⎜  .    .      .  ⎟
-- > ⎜  .     .     .  ⎟
-- > ⎜  .      .    .  ⎟
-- > ⎜  0  0  …  1  0  ⎟
-- > ⎝  0  0  …  0  1  ⎠
identityMatrix :: forall n a. (KnownNat n, IsNum a, Elt a) => AccMatrix n n a
identityMatrix = AccMatrix (A.use (A.fromFunction (Z :. side :. side) entry))
  where
    -- side length of the square matrix, read from the type-level n
    side :: Int
    side = fromIntegral (natVal (Proxy :: Proxy n))

    entry :: DIM2 -> a
    entry (Z :. row :. col)
      | row == col = 1
      | otherwise  = 0

-- | The @n@-dimensional zero vector. It is built from a host list of @n@
-- zeros handed to 'unsafeMkVector'.
--
-- > ⎛ 0 ⎞
-- > ⎜ . ⎟
-- > ⎜ . ⎟
-- > ⎜ . ⎟
-- > ⎜ . ⎟
-- > ⎜ . ⎟
-- > ⎝ 0 ⎠
zeroV :: forall n a. (KnownNat n, IsNum a, Elt a) => AccVector n a
zeroV = unsafeMkVector (replicate len 0)
  where
    -- number of entries, taken from the type-level n
    len :: Int
    len = fromIntegral (natVal (Proxy :: Proxy n))

-- | The @m×n@ zero matrix. It is built from a host list of @m*n@ zeros
-- handed to 'unsafeMkMatrix'.
--
-- > ⎛  0  0  …  0  0  ⎞
-- > ⎜  0  0  …  0  0  ⎟
-- > ⎜  .  .     .  .  ⎟
-- > ⎜  0  0  …  0  0  ⎟
-- > ⎝  0  0  …  0  0  ⎠
zeroM :: forall m n a. (KnownNat m, KnownNat n, IsNum a, Elt a) => AccMatrix m n a
zeroM = unsafeMkMatrix (replicate (rows * cols) 0)
  where
    -- row and column counts, taken from the type-level m and n
    rows, cols :: Int
    rows = fromIntegral (natVal (Proxy :: Proxy m))
    cols = fromIntegral (natVal (Proxy :: Proxy n))


-- | The usual matrix-vector product. An @m×n@ matrix applied to an
-- @n@-vector yields an @m@-vector.
--
-- > ⎛ w₁₁ w₁₂ … w₁ₙ ⎞   ⎛x₁⎞   ⎛ w₁₁*x₁ + w₁₂*x₂ + … + w₁ₙ*xₙ ⎞
-- > ⎜ w₂₁ w₂₂ … w₂ₙ ⎟   ⎜x₂⎟   ⎜ w₂₁*x₁ + w₂₂*x₂ + … + w₂ₙ*xₙ ⎟
-- > ⎜  .   .     .  ⎟   ⎜. ⎟   ⎜  .          .            .   ⎟
-- > ⎜  .   .     .  ⎟ ✕ ⎜. ⎟ = ⎜  .          .            .   ⎟
-- > ⎜  .   .     .  ⎟   ⎜. ⎟   ⎜  .          .            .   ⎟
-- > ⎜  .   .     .  ⎟   ⎜. ⎟   ⎜  .          .            .   ⎟
-- > ⎝ wₘ₁ wₘ₂ … wₘₙ ⎠   ⎝xₙ⎠   ⎝ wₘ₁*x₁ + wₘ₂*x₂ + … + wₘₙ*xₙ ⎠
--
-- The result has one entry per matrix row, so its length is @m@. (The
-- previous signature claimed @n@, which was wrong for non-square matrices.)
(#*^) :: forall m n a. (KnownNat m, KnownNat n, IsNum a, Elt a)
      => AccMatrix m n a -> AccVector n a -> AccVector m a
ma #*^ va = let ma' = unMatrix ma
                va' = unVector va
            -- stack m copies of the vector (an m×n array), multiply
            -- element-wise with the matrix, then fold1 sums each row
            in AccVector $ A.fold1 (+)
                         $ A.zipWith (*)
                                    ma'
                                    (A.replicate (A.lift $ Z :. m' :. All) va')
  where m'  = fromIntegral $ natVal (Proxy :: Proxy m) :: Int

infixl 7 #*^

-- | The usual vector-matrix product. An @m@-vector (read as a row vector)
-- times an @m×n@ matrix yields an @n@-vector whose @j@-th entry is the
-- dot product of the vector with the @j@-th column of the matrix.
--
-- > ⎛x₁⎞T  ⎛w₁₁ w₁₂ … w₁ₙ ⎞   ⎛ x₁*w₁₁ + x₂*w₂₁ + … + xₘ*wₘ₁ ⎞
-- > ⎜x₂⎟   ⎜w₂₁ w₂₂ … w₂ₙ ⎟   ⎜ x₁*w₁₂ + x₂*w₂₂ + … + xₘ*wₘ₂ ⎟
-- > ⎜. ⎟   ⎜ .   .     .  ⎟   ⎜  .         .             .   ⎟
-- > ⎜. ⎟ ✕ ⎜ .   .     .  ⎟ = ⎜  .         .             .   ⎟
-- > ⎜. ⎟   ⎜ .   .     .  ⎟   ⎜  .         .             .   ⎟
-- > ⎜. ⎟   ⎜ .   .     .  ⎟   ⎜  .         .             .   ⎟
-- > ⎝xₘ⎠   ⎝wₘ₁ wₘ₂ … wₘₙ ⎠   ⎝ x₁*w₁ₙ + x₂*w₂ₙ + … + xₘ*wₘₙ ⎠
(^*#) :: forall m n a. (KnownNat m, KnownNat n, IsNum a, Elt a)
      => AccVector m a -> AccMatrix m n a -> AccVector n a
va ^*# ma = let va' = unVector va
                ma' = unMatrix ma
            -- n copies of the vector (n×m) are multiplied element-wise with
            -- the transposed matrix (n×m), so row j holds x[i] * w[i][j];
            -- fold1 then sums each row. Without the transpose this would
            -- compute the matrix-vector product M·x instead.
            in AccVector $ A.fold1 (+)
                         $ A.zipWith (*)
                                    (A.replicate (A.lift $ Z :. n' :. All) va')
                                    (A.transpose ma')
  where n'  = fromIntegral $ natVal (Proxy :: Proxy n) :: Int

infixr 7 ^*#

-- | The usual (element-wise) vector addition.
--
-- > ⎛v₁⎞   ⎛w₁⎞   ⎛ v₁+w₁ ⎞
-- > ⎜v₂⎟   ⎜w₂⎟   ⎜ v₂+w₂ ⎟
-- > ⎜. ⎟   ⎜. ⎟   ⎜   .   ⎟
-- > ⎜. ⎟ + ⎜. ⎟ = ⎜   .   ⎟
-- > ⎜. ⎟   ⎜. ⎟   ⎜   .   ⎟
-- > ⎜. ⎟   ⎜. ⎟   ⎜   .   ⎟
-- > ⎝vₙ⎠   ⎝wₙ⎠   ⎝ vₙ+wₙ ⎠
(^+^) :: forall n a. (KnownNat n, IsNum a, Elt a)
      => AccVector n a -> AccVector n a -> AccVector n a
v ^+^ w = AccVector $ A.zipWith (+) (unVector v) (unVector w)
-- | The usual (element-wise) vector subtraction.
--
-- > ⎛v₁⎞   ⎛w₁⎞   ⎛ v₁-w₁ ⎞
-- > ⎜v₂⎟   ⎜w₂⎟   ⎜ v₂-w₂ ⎟
-- > ⎜. ⎟   ⎜. ⎟   ⎜   .   ⎟
-- > ⎜. ⎟ - ⎜. ⎟ = ⎜   .   ⎟
-- > ⎜. ⎟   ⎜. ⎟   ⎜   .   ⎟
-- > ⎜. ⎟   ⎜. ⎟   ⎜   .   ⎟
-- > ⎝vₙ⎠   ⎝wₙ⎠   ⎝ vₙ-wₙ ⎠
(^-^) :: forall n a. (KnownNat n, IsNum a, Elt a)
      => AccVector n a -> AccVector n a -> AccVector n a
v ^-^ w = AccVector $ A.zipWith (-) (unVector v) (unVector w)

infixl 6 ^+^
infixl 6 ^-^

-- | The usual inner (dot) product of two vectors. It is computed as the
-- 'A.sum' of the element-wise products and returned as an 'AccScalar'.
--
-- > ⎛v₁⎞   ⎛w₁⎞
-- > ⎜v₂⎟   ⎜w₂⎟
-- > ⎜. ⎟   ⎜. ⎟
-- > ⎜. ⎟ * ⎜. ⎟ = v₁*w₁ + v₂*w₂ + … + vₙ*wₙ
-- > ⎜. ⎟   ⎜. ⎟
-- > ⎜. ⎟   ⎜. ⎟
-- > ⎝vₙ⎠   ⎝wₙ⎠
(^*^) :: forall n a. (KnownNat n, IsNum a, Elt a)
      => AccVector n a -> AccVector n a -> AccScalar a
v ^*^ w = AccScalar $ A.sum $ A.zipWith (*) (unVector v) (unVector w)

infixl 7 ^*^

-- | The usual (element-wise) matrix addition.
--
-- > ⎛ v₁₁ v₁₂ … v₁ₙ ⎞     ⎛ w₁₁ w₁₂ … w₁ₙ ⎞     ⎛ v₁₁+w₁₁ v₁₂+w₁₂ … v₁ₙ+w₁ₙ ⎞
-- > ⎜ v₂₁ v₂₂ … v₂ₙ ⎟     ⎜ w₂₁ w₂₂ … w₂ₙ ⎟     ⎜ v₂₁+w₂₁ v₂₂+w₂₂ … v₂ₙ+w₂ₙ ⎟
-- > ⎜  .   .     .  ⎟     ⎜  .   .     .  ⎟     ⎜    .       .         .    ⎟
-- > ⎜  .   .     .  ⎟  +  ⎜  .   .     .  ⎟  =  ⎜    .       .         .    ⎟
-- > ⎜  .   .     .  ⎟     ⎜  .   .     .  ⎟     ⎜    .       .         .    ⎟
-- > ⎜  .   .     .  ⎟     ⎜  .   .     .  ⎟     ⎜    .       .         .    ⎟
-- > ⎝ vₘ₁ vₘ₂ … vₘₙ ⎠     ⎝ wₘ₁ wₘ₂ … wₘₙ ⎠     ⎝ vₘ₁+wₘ₁ vₘ₂+wₘ₂ … vₘₙ+wₘₙ ⎠
(#+#) :: forall m n a. (KnownNat m, KnownNat n, IsNum a, Elt a)
      => AccMatrix m n a -> AccMatrix m n a -> AccMatrix m n a
v #+# w = AccMatrix $ A.zipWith (+) (unMatrix v) (unMatrix w)

-- | The usual (element-wise) matrix subtraction.
--
-- > ⎛ v₁₁ v₁₂ … v₁ₙ ⎞     ⎛ w₁₁ w₁₂ … w₁ₙ ⎞     ⎛ v₁₁-w₁₁ v₁₂-w₁₂ … v₁ₙ-w₁ₙ ⎞
-- > ⎜ v₂₁ v₂₂ … v₂ₙ ⎟     ⎜ w₂₁ w₂₂ … w₂ₙ ⎟     ⎜ v₂₁-w₂₁ v₂₂-w₂₂ … v₂ₙ-w₂ₙ ⎟
-- > ⎜  .   .     .  ⎟     ⎜  .   .     .  ⎟     ⎜    .       .         .    ⎟
-- > ⎜  .   .     .  ⎟  -  ⎜  .   .     .  ⎟  =  ⎜    .       .         .    ⎟
-- > ⎜  .   .     .  ⎟     ⎜  .   .     .  ⎟     ⎜    .       .         .    ⎟
-- > ⎜  .   .     .  ⎟     ⎜  .   .     .  ⎟     ⎜    .       .         .    ⎟
-- > ⎝ vₘ₁ vₘ₂ … vₘₙ ⎠     ⎝ wₘ₁ wₘ₂ … wₘₙ ⎠     ⎝ vₘ₁-wₘ₁ vₘ₂-wₘ₂ … vₘₙ-wₘₙ ⎠
(#-#) :: forall m n a. (KnownNat m, KnownNat n, IsNum a, Elt a)
      => AccMatrix m n a -> AccMatrix m n a -> AccMatrix m n a
v #-# w = AccMatrix $ A.zipWith (-) (unMatrix v) (unMatrix w)

infixl 6 #+#
infixl 6 #-#

-- | The usual matrix multiplication. A @k×m@ matrix times an @m×n@ matrix
-- yields a @k×n@ matrix.
--
-- > ⎛ v₁₁ v₁₂ … v₁ₘ ⎞     ⎛ w₁₁ w₁₂ … w₁ₙ ⎞     ⎛ (v₁₁*w₁₁+v₁₂*w₂₁+…+v₁ₘ*wₘ₁) . . . (v₁₁*w₁ₙ+v₁₂*w₂ₙ+…+v₁ₘ*wₘₙ) ⎞
-- > ⎜ v₂₁ v₂₂ … v₂ₘ ⎟     ⎜ w₂₁ w₂₂ … w₂ₙ ⎟     ⎜            .                                  .               ⎟
-- > ⎜  .   .     .  ⎟     ⎜  .   .     .  ⎟     ⎜            .                                  .               ⎟
-- > ⎜  .   .     .  ⎟  *  ⎜  .   .     .  ⎟  =  ⎜            .                                  .               ⎟
-- > ⎜  .   .     .  ⎟     ⎜  .   .     .  ⎟     ⎜            .                                  .               ⎟
-- > ⎜  .   .     .  ⎟     ⎜  .   .     .  ⎟     ⎜            .                                  .               ⎟
-- > ⎝ vₖ₁ vₖ₂ … vₖₘ ⎠     ⎝ wₘ₁ wₘ₂ … wₘₙ ⎠     ⎝ (vₖ₁*w₁₁+vₖ₂*w₂₁+…+vₖₘ*wₘ₁) . . . (vₖ₁*w₁ₙ+vₖ₂*w₂ₙ+…+vₖₘ*wₘₙ) ⎠
--
-- Both operands are expanded to a common @k×m×n@ array and multiplied
-- element-wise. The summation axis @m@ is then permuted to the innermost
-- position, and 'A.fold1' sums it away.
(#*#) :: forall k m n a. (KnownNat k, KnownNat m, KnownNat n, IsNum a, Elt a)
      => AccMatrix k m a -> AccMatrix m n a -> AccMatrix k n a
v #*# w = AccMatrix $ A.fold1 (+)
                    $ A.backpermute (A.lift $ Z:.ek:.en:.em) reindex
                    $ A.zipWith (*) v' w'
  where k', m', n' :: Int
        k' = fromIntegral $ natVal (Proxy :: Proxy k)
        m' = fromIntegral $ natVal (Proxy :: Proxy m)
        n' = fromIntegral $ natVal (Proxy :: Proxy n)
        ek, em, en :: Exp Int
        ek = fromIntegral k'
        em = fromIntegral m'
        en = fromIntegral n'
        -- v'[i][t][j] = v[i][t], shape k×m×n: copy v along a new innermost
        -- axis whose length is the column count n of the result
        v' = A.replicate (A.lift $ Any:.All:.All:.n') (unMatrix v)
        -- w'[i][t][j] = w[t][j], shape k×m×n: copy w along a new outermost
        -- axis whose length is the row count k of the result
        w' = A.replicate (A.lift $ Any:.k':.All:.All) (unMatrix w)
        -- output index (i, j, t) of the k×n×m array reads input (i, t, j),
        -- moving the summation index t to the innermost position
        reindex :: Exp DIM3 -> Exp DIM3
        reindex ix = let (Z:.i:.j:.t) = A.unlift ix
                      in  A.lift (Z:.i:.t:.j :: Z :. Exp Int :. Exp Int :. Exp Int)

infixl 7 #*#

-- | Multiplication of a vector by a scalar. Every element @xᵢ@ is replaced
-- by @xᵢ * a@.
--
-- >     ⎛x₁⎞   ⎛ a*x₁ ⎞
-- >     ⎜x₂⎟   ⎜ a*x₂ ⎟
-- >     ⎜. ⎟   ⎜  .   ⎟
-- > a • ⎜. ⎟ = ⎜  .   ⎟
-- >     ⎜. ⎟   ⎜  .   ⎟
-- >     ⎜. ⎟   ⎜  .   ⎟
-- >     ⎝xₙ⎠   ⎝ a*xₙ ⎠
(.*^) :: forall n a. (KnownNat n, IsNum a, Elt a)
      => Exp a -> AccVector n a -> AccVector n a
(.*^) scalar = AccVector . A.map (* scalar) . unVector

-- | A convenience helper that divides every element of a vector by a
-- scalar. @a ./^ v@ replaces each @xᵢ@ with @xᵢ / a@: the scalar is the
-- divisor, not the dividend.
--
-- > ⎛x₁⎞       ⎛ x₁/a ⎞
-- > ⎜x₂⎟       ⎜ x₂/a ⎟
-- > ⎜. ⎟       ⎜  .   ⎟
-- > ⎜. ⎟ / a = ⎜  .   ⎟
-- > ⎜. ⎟       ⎜  .   ⎟
-- > ⎜. ⎟       ⎜  .   ⎟
-- > ⎝xₙ⎠       ⎝ xₙ/a ⎠
(./^) :: forall n a. (KnownNat n, IsFloating a, Elt a)
      => Exp a -> AccVector n a -> AccVector n a
a ./^ v = let v' = unVector v
          in AccVector $ A.map (/ a) v'

infixl 7 .*^
infixl 7 ./^

-- | Multiplication of a matrix by a scalar. Every entry @wᵢⱼ@ is replaced
-- by @wᵢⱼ * a@.
--
-- >     ⎛ w₁₁ w₁₂ … w₁ₙ ⎞    ⎛ a*w₁₁ a*w₁₂ … a*w₁ₙ ⎞
-- >     ⎜ w₂₁ w₂₂ … w₂ₙ ⎟    ⎜ a*w₂₁ a*w₂₂ … a*w₂ₙ ⎟
-- >     ⎜  .   .     .  ⎟    ⎜  .      .      .    ⎟
-- > a • ⎜  .   .     .  ⎟ =  ⎜  .      .      .    ⎟
-- >     ⎜  .   .     .  ⎟    ⎜  .      .      .    ⎟
-- >     ⎜  .   .     .  ⎟    ⎜  .      .      .    ⎟
-- >     ⎝ wₘ₁ wₘ₂ … wₘₙ ⎠    ⎝ a*wₘ₁ a*wₘ₂ … a*wₘₙ ⎠
(.*#) :: forall m n a. (KnownNat m, KnownNat n, IsNum a, Elt a)
      => Exp a -> AccMatrix m n a -> AccMatrix m n a
(.*#) scalar = AccMatrix . A.map (* scalar) . unMatrix

-- | A convenience helper that divides every entry of a matrix by a scalar.
-- @a ./# w@ replaces each @wᵢⱼ@ with @wᵢⱼ / a@: the scalar is the divisor,
-- not the dividend.
--
-- > ⎛ w₁₁ w₁₂ … w₁ₙ ⎞        ⎛ w₁₁/a w₁₂/a … w₁ₙ/a ⎞
-- > ⎜ w₂₁ w₂₂ … w₂ₙ ⎟        ⎜ w₂₁/a w₂₂/a … w₂ₙ/a ⎟
-- > ⎜  .   .     .  ⎟        ⎜  .      .      .    ⎟
-- > ⎜  .   .     .  ⎟ / a =  ⎜  .      .      .    ⎟
-- > ⎜  .   .     .  ⎟        ⎜  .      .      .    ⎟
-- > ⎜  .   .     .  ⎟        ⎜  .      .      .    ⎟
-- > ⎝ wₘ₁ wₘ₂ … wₘₙ ⎠        ⎝ wₘ₁/a wₘ₂/a … wₘₙ/a ⎠
(./#) :: forall m n a. (KnownNat m, KnownNat n, IsFloating a, Elt a)
      => Exp a -> AccMatrix m n a -> AccMatrix m n a
a ./# v = let v' = unMatrix v
          in AccMatrix $ A.map (/ a) v'

infixl 7 .*#
infixl 7 ./#

-- | Raises a square matrix to a non-negative 'Int' power by repeated
-- multiplication with '#*#'. @v #**. 0@ is 'identityMatrix' and
-- @v #**. 1@ is @v@ itself. Negative exponents raise an 'error', because
-- matrix inversion is not yet implemented.
--
-- > ⎛ v₁₁ v₁₂ … v₁ₙ ⎞ k
-- > ⎜ v₂₁ v₂₂ … v₂ₙ ⎟
-- > ⎜  .   .     .  ⎟
-- > ⎜  .   .     .  ⎟
-- > ⎜  .   .     .  ⎟
-- > ⎜  .   .     .  ⎟
-- > ⎝ vₙ₁ vₙ₂ … vₙₙ ⎠
(#**.) :: forall n a. (KnownNat n, IsNum a, Elt a)
       => AccMatrix n n a -> Int -> AccMatrix n n a
_ #**. 0 = identityMatrix
v #**. 1 = v
v #**. i | i < 0 = error $ "no negative exponents allowed in matrix exponentiation, "
                        ++ "inverse function not yet implemented"
         | otherwise = (v #**. (i-1)) #*# v

infixr 8 #**.

-- | Matrix transpose. The dimensions swap in the type, so an @m×n@ matrix
-- becomes an @n×m@ matrix.
transpose :: forall m n a. (KnownNat m, KnownNat n, Elt a)
          => AccMatrix m n a -> AccMatrix n m a
transpose mat = AccMatrix (A.transpose (unMatrix mat))


-- | The matrix counterpart of the list function 'zipWith'. It combines two
-- matrices entry by entry with the given function. The type forces both
-- inputs (and the result) to have the same @m×n@ dimensions.
zipWithM :: forall m n a b c. (KnownNat m, KnownNat n, Elt a, Elt b, Elt c)
         => (Exp a -> Exp b -> Exp c) -> AccMatrix m n a -> AccMatrix m n b -> AccMatrix m n c
zipWithM f left right = AccMatrix (A.zipWith f (unMatrix left) (unMatrix right))

-- | The vector counterpart of the list function 'zipWith'. It combines two
-- vectors element by element with the given function. The type forces both
-- inputs (and the result) to have the same length @n@.
zipWithV :: forall n a b c. (KnownNat n, Elt a, Elt b, Elt c)
         => (Exp a -> Exp b -> Exp c) -> AccVector n a -> AccVector n b -> AccVector n c
zipWithV f left right = AccVector (A.zipWith f (unVector left) (unVector right))