{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE PartialTypeSignatures #-}
module Data.Array.Accelerate.TypeLits.Internal where

import           GHC.TypeLits ( Nat, KnownNat, natVal)

import           Control.Monad (replicateM)

import qualified Data.Array.Accelerate as A
import qualified Data.Array.Accelerate.Interpreter as I
import           Data.Proxy (Proxy(..))
import           Data.Array.Accelerate ( (:.)((:.)), Array
                                       , Exp
                                       , DIM0, DIM1, DIM2, Z(Z)
                                       , Elt, Acc
                                       )

import           Test.SmallCheck.Series
import           Test.QuickCheck.Arbitrary

-- | A single value wrapped in an Accelerate computation, stored as a
-- zero-dimensional ('DIM0') array.
newtype AccScalar a = AccScalar
  { unScalar :: Acc (Array DIM0 a)
  } deriving Show

-- | Two scalar computations are equal when running both with the Accelerate
-- interpreter backend ('I.run') yields equal element lists.
instance forall a. (Eq a, Elt a) => Eq (AccScalar a) where
    AccScalar lhs == AccScalar rhs = runList lhs == runList rhs
      where runList acc = A.toList (I.run acc)

-- | A one-dimensional Accelerate array tagged with its length as the
-- type-level 'Nat' @dim@.
--
-- The raw constructor does not check the tag against the array's real
-- extent; 'mkVector' does.
newtype AccVector (dim :: Nat) a = AccVector
  { unVector :: Acc (Array DIM1 a)
  } deriving Show

-- | Vectors are compared by running both computations with the Accelerate
-- interpreter backend ('I.run') and comparing their element lists.
-- Only the elements are compared, not the array extents.
instance forall n a. (KnownNat n, Eq a, Elt a) => Eq (AccVector n a) where
    AccVector lhs == AccVector rhs = runList lhs == runList rhs
      where runList acc = A.toList (I.run acc)

-- | SmallCheck enumeration of vectors of length @n@.
--
-- Only constant vectors are produced: each element drawn from the element
-- series is repeated @n@ times. 'cons1' charges one level of depth for this.
instance forall mm n a. (Serial mm a, KnownNat n, Eq a, Elt a)
  => Serial mm (AccVector n a) where
      series = fmap fill (cons1 (replicate len))
        where len  = fromIntegral (natVal (Proxy :: Proxy n))
              fill = AccVector . A.use . A.fromList (Z :. len)

-- | QuickCheck generator for vectors of length @n@: draws @n@ elements
-- independently with 'arbitrary'.
--
-- No 'shrink' is defined, so QuickCheck's default (no shrinking) applies.
instance forall n a. (KnownNat n, Arbitrary a, Eq a, Elt a)
  => Arbitrary (AccVector n a) where
      arbitrary = do
          entries <- replicateM len arbitrary
          pure (AccVector (A.use (A.fromList (Z :. len) entries)))
        where
          len = fromIntegral (natVal (Proxy :: Proxy n))

-- | A type-safe way to represent an Accelerate matrix together with its
-- number of rows and columns.
--
-- @rows@ and @cols@ are type-level 'Nat's. The raw constructor does not
-- check them against the wrapped 'DIM2' array; 'mkMatrix' does.
newtype AccMatrix (rows :: Nat) (cols :: Nat) a = AccMatrix {unMatrix :: Acc (Array DIM2 a)}
                                                deriving (Show)

-- | Matrices are compared by running both computations with the Accelerate
-- interpreter backend ('I.run') and comparing their row-major element lists.
-- Only the elements are compared, not the array extents.
instance forall m n a. (KnownNat m, KnownNat n, Eq a, Elt a) => Eq (AccMatrix m n a) where
    AccMatrix lhs == AccMatrix rhs = runList lhs == runList rhs
      where runList acc = A.toList (I.run acc)

-- | SmallCheck enumeration of @m@-by-@n@ matrices.
--
-- Only constant matrices are produced: each element drawn from the element
-- series fills all @m * n@ cells. 'cons1' charges one level of depth for this.
instance forall mm m n a. (Serial mm a, KnownNat m, KnownNat n, Eq a, Elt a)
  => Serial mm (AccMatrix m n a) where
      series = fmap fill (cons1 (replicate (rows * cols)))
        where rows = fromIntegral (natVal (Proxy :: Proxy m))
              cols = fromIntegral (natVal (Proxy :: Proxy n))
              fill = AccMatrix . A.use . A.fromList (Z :. rows :. cols)

-- | QuickCheck generator for @m@-by-@n@ matrices: draws @m * n@ elements
-- independently with 'arbitrary' and lays them out in row-major order.
--
-- No 'shrink' is defined, so QuickCheck's default (no shrinking) applies.
instance forall m n a. (KnownNat m, KnownNat n, Arbitrary a, Eq a, Elt a)
  => Arbitrary (AccMatrix m n a) where
      arbitrary = do
          entries <- replicateM (rows * cols) arbitrary
          pure (AccMatrix (A.use (A.fromList (Z :. rows :. cols) entries)))
        where
          rows = fromIntegral (natVal (Proxy :: Proxy m))
          cols = fromIntegral (natVal (Proxy :: Proxy n))

-- | A 'Functor'-like class for containers backed by Accelerate computations.
--
-- Instead of an ordinary function @(a -> b)@, 'afmap' maps an embedded
-- function @(Exp a -> Exp b)@ over every element. The 'Elt' constraints on
-- the element types are the reason this cannot be a plain 'Functor'.
class AccFunctor f where
  afmap :: forall a b. (Elt a, Elt b) => (Exp a -> Exp b) -> f a -> f b

-- | Applies the function to the scalar's single element via 'A.map'.
instance AccFunctor AccScalar where
    afmap g = AccScalar . A.map g . unScalar

-- | Element-wise mapping over a vector via 'A.map'. The length @n@ is unchanged.
instance forall n. (KnownNat n) => AccFunctor (AccVector n) where
    afmap g = AccVector . A.map g . unVector

-- | Element-wise mapping over a matrix via 'A.map'. The dimensions @m@ and @n@
-- are unchanged.
instance forall m n. (KnownNat m, KnownNat n) => AccFunctor (AccMatrix m n) where
    afmap g = AccMatrix . A.map g . unMatrix

-- | A smart constructor for vectors. It returns 'Nothing' unless the input
-- list has exactly @n@ elements, where @n@ is the type-level length of the
-- resulting 'AccVector'.
--
-- At most @n + 1@ list cells are inspected, so the check also terminates for
-- infinite lists; those yield 'Nothing'.
mkVector :: forall n a. (KnownNat n, Elt a) => [a] -> Maybe (AccVector n a)
mkVector as
  -- 'take' bounds the traversal. The result has length n' exactly when the
  -- input has length n', because any longer input yields n' + 1.
  | length (take (n' + 1) as) == n' = Just (unsafeMkVector as)
  | otherwise                       = Nothing
  where n' = fromIntegral $ natVal (Proxy :: Proxy n)

-- | Unchecked smart constructor for vectors. It builds an 'AccVector' of
-- type-level length @n@ from the list without checking the list's length.
--
-- Prefer 'mkVector' unless the list is already known to have @n@ elements.
-- On a mismatch the outcome is whatever 'A.fromList' does.
-- NOTE(review): presumably a runtime error when the list is too short;
-- confirm for the Accelerate version in use.
unsafeMkVector :: forall n a. (KnownNat n, Elt a) => [a] -> AccVector n a
unsafeMkVector as = AccVector (A.use $ A.fromList (Z:.n') as)
  where n' = fromIntegral $ natVal (Proxy :: Proxy n)

-- | A smart constructor for matrices. It returns 'Nothing' unless the input
-- list has exactly @m * n@ elements (rows times columns, row-major order).
--
-- At most @m * n + 1@ list cells are inspected, so the check also terminates
-- for infinite lists; those yield 'Nothing'.
mkMatrix :: forall m n a. (KnownNat m, KnownNat n, Elt a)
         => [a] -> Maybe (AccMatrix m n a)
mkMatrix as
  -- 'take' bounds the traversal. The result has length 'expected' exactly
  -- when the input does, because any longer input yields expected + 1.
  | length (take (expected + 1) as) == expected = Just (unsafeMkMatrix as)
  | otherwise                                   = Nothing
  where m'       = fromIntegral $ natVal (Proxy :: Proxy m)
        n'       = fromIntegral $ natVal (Proxy :: Proxy n)
        expected = m' * n'

-- | Unchecked smart constructor for matrices. It builds an @m@-by-@n@
-- 'AccMatrix' from the list, in row-major order, without checking that the
-- list has @m * n@ elements.
--
-- Prefer 'mkMatrix' unless the length is already known to be right.
-- On a mismatch the outcome is whatever 'A.fromList' does.
-- NOTE(review): presumably a runtime error when the list is too short;
-- confirm for the Accelerate version in use.
unsafeMkMatrix :: forall m n a. (KnownNat m, KnownNat n, Elt a)
         => [a] -> AccMatrix m n a
unsafeMkMatrix as = AccMatrix (A.use $ A.fromList (Z:. m':.n') as)
  where m' = fromIntegral $ natVal (Proxy :: Proxy m)
        n' = fromIntegral $ natVal (Proxy :: Proxy n)

-- | A smart constructor for scalars. It wraps an embedded expression in a
-- zero-dimensional Accelerate array using 'A.unit'.
mkScalar :: forall a. Elt a => Exp a -> AccScalar a
mkScalar = AccScalar . A.unit