{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE EmptyDataDecls #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
module Data.Array.Knead.Shape.Cubic (
   constant,
   paramWith,
   tunnel,

   T(..),
   Z(Z), z,
   (:.)((:.)),
   Shape,
   Index,
   cons, (#:.),
   head,
   tail,
   switchR,
   ) where

import qualified Data.Array.Knead.Shape as Shape
import qualified Data.Array.Knead.Shape.Cubic.Int as Index

import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp, )

import qualified Data.Array.Comfort.Shape as ComfortShape
import Data.Array.Comfort.Shape (ZeroBased(ZeroBased))

import qualified LLVM.DSL.Parameter as Param

import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Multi.Iterator as IterMV
import qualified LLVM.Extra.Marshal as Marshal
import qualified LLVM.Extra.Iterator as Iter
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Tuple as Tuple
import qualified LLVM.Extra.Control as C
import LLVM.Extra.Multi.Value (Atom)

import qualified LLVM.Core as LLVM

import qualified Foreign.Storable as St
import Foreign.Storable.FixedArray (sizeOfArray, )
import Foreign.Ptr (castPtr, )

import qualified Type.Data.Num.Decimal as Dec
import qualified Type.Data.Num.Unary as Unary
import Type.Base.Proxy (Proxy(Proxy))

import qualified Data.Traversable as Trav
import qualified Data.Foldable as Fold
import qualified Data.FixedLength as FixedLength
import Data.FixedLength ((!:))

import Control.Monad (liftM2, )
import Control.Applicative (pure, (<$>), )

import Prelude hiding (min, head, tail, )


-- | Fixed-length tuple of 'Index.Int' values whose length is the type-level
-- @rank@.  The phantom @tag@ parameter is not used by the representation;
-- it only distinguishes the aliases 'Shape' and 'Index'.
newtype T tag rank = Cons {decons :: FixedLength.T rank Index.Int}

-- | Phantom tag that selects the 'Shape' flavour of 'T'.
data ShapeTag
-- | Phantom tag that selects the 'Index' flavour of 'T'.
data IndexTag

-- | 'T' tagged as a shape (one extent per dimension).
type Shape = T ShapeTag
-- | 'T' tagged as an index (one position per dimension).
type Index = T IndexTag


-- | Open the 'Param.Tunnel' that 'tunnel' builds for the given parameter
-- and hand both of its halves to a continuation: the function that extracts
-- the marshalled parameters from @p@, and a function that turns the
-- marshalled parameters into a value of any 'Expr.Value' type.
paramWith ::
   (Unary.Natural rank,
    Dec.Natural (Dec.FromUnary rank),
    Dec.Natural (Dec.FromUnary rank Dec.:*: LLVM.SizeOf Shape.Size)) =>
   Param.T p (T tag rank) ->
   (forall parameters.
    (Marshal.MV parameters) =>
    (p -> parameters) ->
    (forall val. (Expr.Value val) =>
     MultiValue.T parameters -> val (T tag rank)) ->
    a) ->
   a
paramWith param continue =
   case tunnel param of
      Param.Tunnel extract value ->
         continue extract (\params -> Expr.lift0 (value params))

-- | Build a 'Param.Tunnel' for a cubic shape or index parameter,
-- using 'MultiValue.cons' to embed the Haskell value.
tunnel ::
   (Unary.Natural rank,
    Dec.Natural (Dec.FromUnary rank),
    Dec.Natural (Dec.FromUnary rank Dec.:*: LLVM.SizeOf Shape.Size)) =>
   Param.T p (T tag rank) -> Param.Tunnel p (T tag rank)
tunnel = Param.tunnel MultiValue.cons


-- | Unit type with the single constructor 'Z'.
-- NOTE(review): presumably the rank-zero end marker for tuples built with
-- ':.' — no instance in this part of the file uses it, confirm at call sites.
data Z = Z
   deriving (Eq, Ord, Read, Show)


-- Both the constructor ':.' and the expression-level operator '#:.'
-- associate to the left, so @a :. b :. c@ reads as @(a :. b) :. c@.
infixl 3 :., #:.

-- | Strict pair of a tail (left operand) and a head (right operand).
-- It serves as the tuple type for the 'Expr.Compose' and 'Expr.Decompose'
-- instances of this module.
data tail :. head = !tail :. !head
   deriving (Eq, Ord, Read, Show)


-- | Infix synonym for 'cons'.
(#:.) ::
   (Expr.Value val) =>
   val (T tag rank) -> val Index.Int -> val (T tag (Unary.Succ rank))
sh #:. i = cons sh i

-- | Extend a shape or index by one component.
-- The new component becomes the head of the underlying fixed-length list.
cons ::
   (Expr.Value val) =>
   val (T tag rank) -> val Index.Int -> val (T tag (Unary.Succ rank))
cons sh i =
   Expr.lift2
      (\(MultiValue.Cons rest) (MultiValue.Cons top) ->
         MultiValue.Cons (top !: rest))
      sh i

-- | The rank-zero shape or index: an empty fixed-length list.
z :: (Expr.Value val) => val (T tag Unary.Zero)
z = Expr.lift0 (MultiValue.Cons FixedLength.end)

-- | Extract the component that was added last by 'cons'.
head ::
   (Expr.Value val, Unary.Natural rank) =>
   val (T tag (Unary.Succ rank)) -> val Index.Int
head sh =
   Expr.lift1
      (\(MultiValue.Cons dims) -> MultiValue.Cons (FixedLength.head dims))
      sh

-- | Drop the component that was added last by 'cons',
-- leaving a shape or index of rank one less.
tail ::
   (Expr.Value val, Unary.Natural rank) =>
   val (T tag (Unary.Succ rank)) -> val (T tag rank)
tail sh =
   Expr.lift1
      (\(MultiValue.Cons dims) -> MultiValue.Cons (FixedLength.tail dims))
      sh

-- | Split a non-empty shape or index into its 'tail' and 'head'
-- and pass both parts to the continuation, in that order.
switchR ::
   (Unary.Natural rank) =>
   Expr.Value val =>
   (val (T tag rank) -> val Index.Int -> a) ->
   val (T tag (Unary.Succ rank)) -> a
switchR continue full = continue rest top
   where
      rest = tail full
      top = head full


-- | Recover the type-level rank as a proxy.
-- The argument is never inspected; only its type matters.
rank :: T tag rank -> Proxy rank
rank _ = Proxy


-- | A rank-zero shape is a scalar shape.
-- Both the shape itself and its only index are the empty tuple 'z'.
instance (tag ~ ShapeTag, rank ~ Unary.Zero) => Shape.Scalar (T tag rank) where
   scalar = z
   zeroIndex _ = z


-- | Rank of the cubic shape or index that a decomposition pattern matches.
-- An 'Atom' contributes the rank of the value it stands for,
-- and every ':.' adds one dimension to the rank of its left operand.
--
-- The recursion must descend into the left operand @sh@,
-- exactly as 'AtomTag' and the 'MultiValue.PatternTuple' instance do.
-- The right operand @s@ stands for a single 'Index.Int'
-- and has no 'AtomRank' instance of its own.
type family AtomRank sh
type instance AtomRank (Atom (T tag rank)) = rank
type instance AtomRank (sh:.s) = Unary.Succ (AtomRank sh)

-- | Tag ('ShapeTag' or 'IndexTag') of the value that a decomposition pattern
-- matches.  It is taken from the 'Atom' at the left end of the ':.' chain.
type family AtomTag sh
type instance AtomTag (Atom (T tag rank)) = tag
type instance AtomTag (sh:.s) = AtomTag sh

-- | A pattern @sh:.s@ matches a 'T' with the tag of @sh@
-- and a rank one larger than that of @sh@.
type instance MultiValue.PatternTuple (sh:.s) =
   T (AtomTag sh) (Unary.Succ (AtomRank sh))

-- | Decomposing along @sh:.s@ yields the decomposition of @sh@
-- paired via ':.' with a single @f Index.Int@.
type instance MultiValue.Decomposed f (sh:.s) =
   MultiValue.Decomposed f sh :. f Index.Int

-- | Decompose a non-empty shape or index expression along a ':.' pattern:
-- the left sub-pattern is applied to the 'tail',
-- the right sub-pattern to the 'head'.
instance
   (Expr.Decompose sh, Expr.Decompose s,
    MultiValue.Decomposed Exp s ~ Exp Index.Int,
    MultiValue.PatternTuple s ~ Index.Int,
    MultiValue.PatternTuple sh ~ T (AtomTag sh) (AtomRank sh),
    Unary.Natural (AtomRank sh)) =>
      Expr.Decompose (sh :. s) where
   decompose (shPat :. ixPat) full =
      Expr.decompose shPat rest :. Expr.decompose ixPat top
      where
         rest = tail full
         top = head full


-- | Project the rank parameter out of a 'T' type.
type family Rank sh
type instance Rank (T tag rank) = rank

-- | Project the tag parameter out of a 'T' type.
type family Tag sh
type instance Tag (T tag rank) = tag

-- | Compose a ':.' tuple of expressions into one shape or index expression.
-- The left part must compose to a 'T' and the right part to an 'Index.Int';
-- the result keeps the tag of the left part and has a rank one larger.
instance
   (Expr.Compose sh,
    Expr.Composed sh ~ T (Tag (Expr.Composed sh)) (Rank (Expr.Composed sh)),
    Expr.Compose s,
    Expr.Composed s ~ Index.Int) =>
      Expr.Compose (sh :. s) where
   type Composed (sh :. s) =
            T (Tag (Expr.Composed sh)) (Unary.Succ (Rank (Expr.Composed sh)))
   compose (sh :. s) = Expr.compose sh #:. Expr.compose s


-- | Stored as an array of @rank@ values of type 'Shape.Size'.
-- 'poke' and 'peek' delegate to the 'St.Storable' instance of the underlying
-- fixed-length list after removing or adding the 'Index.Int' wrapper.
instance (Unary.Natural rank) => St.Storable (T tag rank) where
   sizeOf sh = sizeOfArray (Unary.integralFromProxy (rank sh)) (0::Shape.Size)
   alignment _ = St.alignment (0::Shape.Size)
   poke ptr (Cons ixs) = St.poke (castPtr ptr) (indexFromInt <$> ixs)
   peek ptr = Cons . fmap Index.Int <$> St.peek (castPtr ptr)

-- | Marshalled as an LLVM array with one packed element per component.
-- 'unpack' rebuilds the fixed-length list from the array elements
-- with 'toFixedList'.
instance
   (Unary.Natural rank,
    Dec.Natural (Dec.FromUnary rank),
    Dec.Natural (Dec.FromUnary rank Dec.:*: LLVM.SizeOf Shape.Size)) =>
      Marshal.C (T tag rank) where
   pack (Cons ixs) = LLVM.Array [Marshal.pack i | i <- Fold.toList ixs]
   unpack (LLVM.Array xs) = Cons (toFixedList (map Marshal.unpack xs))

-- | Fill a fixed-length container from the front of a list.
--
-- List elements beyond the container length are ignored.
-- A list that is shorter than the container is a caller error;
-- it is reported with an explicit message
-- instead of an anonymous pattern-match failure.
toFixedList :: (Unary.Natural n) => [a] -> FixedLength.T n a
toFixedList xs = snd $ Trav.mapAccumL takeNext xs (pure ())
   where
      -- Consume one list element for every slot of the container.
      takeNext (y:ys) () = (ys, y)
      takeNext [] () =
         error "Data.Array.Knead.Shape.Cubic.toFixedList: list too short"

-- | The instance body is intentionally empty: no 'Marshal.MV' method
-- is defined here.
instance
   (Unary.Natural rank,
    Dec.Natural (Dec.FromUnary rank),
    Dec.Natural (Dec.FromUnary rank Dec.:*: LLVM.SizeOf Shape.Size)) =>
      Marshal.MV (T tag rank) where


-- | The LLVM value of a shape or index is the fixed-length list
-- of the LLVM values of its components.
instance (Unary.Natural rank) => Tuple.Value (T tag rank) where
   type ValueOf (T tag rank) = FixedLength.T rank (Tuple.ValueOf Index.Int)
   valueOf (Cons ixs) = Tuple.valueOf <$> ixs


-- | Component-wise multi-value representation.
-- 'undef' and 'zero' replicate the corresponding 'Index.Int' value
-- across all components via 'constant'; the phi operations are delegated
-- to the generic helpers for foldable and traversable containers.
instance (Unary.Natural rank) => MultiValue.C (T tag rank) where
   cons (Cons ixs) = MultiValue.Cons (fmap (LLVM.valueOf . indexFromInt) ixs)
   undef = constant MultiValue.undef
   zero = constant MultiValue.zero
   addPhi bb (MultiValue.Cons a) (MultiValue.Cons b) =
      Tuple.addPhiFoldable bb a b
   phi bb (MultiValue.Cons a) =
      MultiValue.Cons <$> Tuple.phiTraversable bb a

-- | Shape or index whose components all equal the given value.
constant ::
   (Unary.Natural rank) => MultiValue.T Index.Int -> MultiValue.T (T tag rank)
constant (MultiValue.Cons component) = MultiValue.Cons (pure component)

-- | The size of a cubic shape is the product of the sizes
-- of its zero-based dimensions.
instance
   (tag ~ ShapeTag, Unary.Natural rank) =>
      ComfortShape.C (T tag rank) where
   size (Cons dims) =
      Fold.product (fmap (ComfortShape.size . shapeFromInt) dims)

-- | Indexing of a cubic shape, with every dimension treated as 'ZeroBased'.
--
-- * 'indices' enumerates all combinations of the per-dimension indices.
--
-- * 'inBounds' requires every component to be in bounds for its dimension.
--
-- * 'offset' folds over the dimensions from the head of the fixed-length
--   list, multiplying the running offset by each dimension size
--   before adding the corresponding index component.
instance
   (tag ~ ShapeTag, Unary.Natural rank) =>
      ComfortShape.Indexed (T tag rank) where
   type Index (T tag rank) = Index rank
   indices (Cons sh) =
      [Cons (fmap Index.Int ix) |
         ix <- Trav.mapM (ComfortShape.indices . shapeFromInt) sh]
   inBounds (Cons sh) (Cons ix) =
      Fold.and
         (FixedLength.zipWith
            (\s i -> ComfortShape.inBounds (shapeFromInt s) (indexFromInt i))
            sh ix)
   offset (Cons sh) (Cons ix) =
      Fold.foldl' accumulate 0 (FixedLength.zipWith (,) sh ix)
      where
         accumulate off (s, i) =
            off * ComfortShape.size (shapeFromInt s)
               + fromIntegral (indexFromInt i)

-- | Interpret an extent as a zero-based one-dimensional shape.
shapeFromInt :: Index.Int -> ZeroBased Shape.Size
shapeFromInt = ZeroBased . indexFromInt

-- | Remove the 'Index.Int' wrapper.
indexFromInt :: Index.Int -> Shape.Size
indexFromInt ix =
   case ix of
      Index.Int n -> n


-- | Code generation for cubic shapes.
--
-- * 'size' multiplies all extents, starting from one.
--
-- * 'intersectCode' takes the component-wise minimum of two shapes.
--
-- * 'sizeOffset' pairs the size with the offset computation 'offsetCode'.
--
-- * 'iterator' and 'loop' are the top-level functions of the same name
--   in this module.
instance (tag ~ ShapeTag, Unary.Natural rank) => Shape.C (T tag rank) where
   size (MultiValue.Cons dims) = Fold.foldlM A.mul A.one dims
   intersectCode (MultiValue.Cons shA) (MultiValue.Cons shB) =
      MultiValue.Cons <$> Trav.sequence (FixedLength.zipWith A.min shA shB)
   sizeOffset sh = do
      -- would a joint implementation be more efficient?
      n <- Shape.size sh
      return (n, offsetCode sh)
   iterator = iterator
   loop = loop


-- | Generate code that computes the linear offset of an index
-- within a shape.  Starting from zero, every step multiplies the running
-- offset by the current extent and then adds the current index component.
--
-- NOTE(review): the fold starts at the head of the fixed-length list, so the
-- head component receives the largest stride, whereas 'loop' runs that same
-- component in its innermost loop — confirm the intended element order.
offsetCode ::
   (Unary.Natural rank) =>
   MultiValue.T (Shape rank) -> MultiValue.T (Index rank) ->
   LLVM.CodeGenFunction r (LLVM.Value Shape.Size)
offsetCode (MultiValue.Cons sh) (MultiValue.Cons ix) =
   Fold.foldlM
      (\off (s,i) -> do
         scaled <- A.mul off s
         A.add i scaled)
      A.zero
      (FixedLength.zipWith (,) sh ix)


-- | Wrapper around the type of 'iterator' that makes @rank@ the last type
-- parameter, as needed for the case distinction with 'Unary.switchNat'.
newtype Iterator r rank =
   Iterator {
      runIterator ::
         MultiValue.T (Shape rank) -> Iter.T r (MultiValue.T (Index rank))
   }

-- | Iterator over all indices of a shape, defined by recursion on the rank.
--
-- * Rank zero: a single element, the empty index 'z'.
--
-- * Rank @n+1@: the cartesian product of the iterator over the 'tail' shape
--   and a counter that starts at zero and runs while the 'head' extent
--   is greater than the counter.
iterator ::
   (Unary.Natural rank) =>
   MultiValue.T (Shape rank) -> Iter.T r (MultiValue.T (Index rank))
iterator =
   runIterator $
   Unary.switchNat
      (Iterator $ const $ Iter.singleton z)
      (Iterator $ switchR $ \outer extent ->
         let counter =
                IterMV.takeWhile (MultiValue.cmp LLVM.CmpGT extent) $
                Iter.iterate MultiValue.inc MultiValue.zero
         in  fmap (uncurry (#:.)) $
             Iter.cartesian (iterator outer) counter)


-- | Wrapper around the type of 'loop' that makes @rank@ the last type
-- parameter, as needed for the case distinction with 'Unary.switchNat'.
newtype Loop r state rank =
   Loop {
      runLoop ::
         (MultiValue.T (Index rank) ->
          state ->
          LLVM.CodeGenFunction r state) ->
         MultiValue.T (Shape rank) ->
         state ->
         LLVM.CodeGenFunction r state
   }

-- | Generate a nested loop over all indices of a shape while threading
-- a state through the loop body, defined by recursion on the rank.
--
-- * Rank zero: the body runs once with the empty index 'z'.
--
-- * Rank @n+1@: the loop over the 'tail' shape is the outer loop.
--   Its body is a fixed-length loop over the 'head' extent that counts up
--   from zero and calls the user's body with the extended index.
loop ::
   (Unary.Natural rank, Tuple.Phi state) =>
   (MultiValue.T (Index rank) ->
    state ->
    LLVM.CodeGenFunction r state) ->
   MultiValue.T (Shape rank) ->
   state ->
   LLVM.CodeGenFunction r state
loop =
   runLoop $
   Unary.switchNat
      (Loop $ \body _z -> body z)
      (Loop $ \body -> switchR $ \outer (MultiValue.Cons extent) ->
         loop
            (\ix start ->
               fmap fst $
               C.fixedLengthLoop extent (start, A.zero) $ \(st, k) -> do
                  -- run the body first, then advance the counter
                  st' <- body (ix #:. MultiValue.Cons k) st
                  k' <- A.inc k
                  return (st', k'))
            outer)