{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE EmptyDataDecls #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{- |
Shapes and indices with a statically known number of dimensions.
A tuple is built from 'Z' and ':.' and is tagged by 'T'
as either a 'Shape' or an 'Index'.
The module contains plain Haskell operations on such tuples
as well as LLVM code generation for sizes, offsets, loops
and memory access.
-}
module Data.Array.Knead.Shape.Cubic (
   C(switch),
   switchInt,
   intersect,
   value,
   constant,
   paramWith,
   tunnel,
   offsetCode,
   peek,
   poke,
   computeSize,

   Struct,
   T(..),
   Z(Z), z,
   (:.)((:.)),
   Shape, shape,
   Index, index,
   cons, (#:.),
   head,
   tail,
   switchR,
   loadMultiValue,
   storeMultiValue,
   ) where

import qualified Data.Array.Knead.Shape as Shape
import qualified Data.Array.Knead.Shape.Cubic.Int as Index

import qualified Data.Array.Knead.Parameter as Param
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp, )

import qualified Data.Array.Comfort.Shape as ComfortShape
import Data.Array.Comfort.Shape (ZeroBased(ZeroBased))

import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Multi.Iterator as IterMV
import qualified LLVM.Extra.Iterator as Iter
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Control as C
import LLVM.Extra.Multi.Value (Atom, )

import qualified LLVM.Util.Loop as Loop
import qualified LLVM.Core as LLVM

import qualified Foreign.Storable as St
import Foreign.Storable.FixedArray (sizeOfArray, )
import Foreign.Marshal.Array (advancePtr, )
import Foreign.Ptr (Ptr, castPtr, )

import Control.Monad (liftM2, )

import Prelude hiding (min, head, tail, )


-- | Class of tuples built from 'Z' and ':.'.
-- 'switch' is a case analysis on that structure:
-- the caller supplies one result for 'Z'
-- and one result for a tuple extended by a single component,
-- and the instance for @ix@ selects the matching one.
class C ix where
   switch ::
      f Z ->
      (forall ix0 i. (C ix0, Index.Single i) => f (ix0 :. i)) ->
      f ix

-- The empty tuple selects the first alternative.
instance C Z where
   switch x _ = x

-- An extended tuple selects the second alternative.
instance (C ix0, Index.Single i) => C (ix0 :. i) where
   switch _ x = x


-- Wrapper that moves the last tuple component into the final type argument,
-- which is the position 'Index.switchSingle' operates on.
newtype SwitchInt f ix i = SwitchInt {runSwitchInt :: f (ix :. i)}

-- | Variant of 'switch' where the alternative for an extended tuple
-- only has to be given for a last component of type 'Index.Int'.
-- 'Index.switchSingle' turns it into the alternative
-- that 'switch' expects for every 'Index.Single' type.
switchInt ::
   (C ix) =>
   f Z ->
   (forall ix0. (C ix0) => f (ix0 :. Index.Int)) ->
   f ix
switchInt nil step =
   switch nil (runSwitchInt (Index.switchSingle (SwitchInt step)))


-- Binary operation on tagged tuple expressions,
-- wrapped such that it can be defined by 'switchInt'.
newtype Op2 tag sh =
         Op2 {runOp2 :: Exp (T tag sh) -> Exp (T tag sh) -> Exp (T tag sh)}

-- | Intersection of two shapes:
-- every dimension of the result is the 'Expr.min'
-- of the corresponding dimensions of the arguments.
intersect :: C sh => Exp (Shape sh) -> Exp (Shape sh) -> Exp (Shape sh)
intersect =
   runOp2
      (switchInt
         (Op2 (\unit _ -> unit))
         (Op2 (\lhs rhs ->
            cons
               (intersect (tail lhs) (tail rhs))
               (Expr.min (head lhs) (head rhs)))))


-- Lift a Haskell value to a constant expression using 'MultiValue.cons'.
-- Not referenced elsewhere in this module.
_value :: (C sh, MultiValue.C sh) => sh -> Exp sh
_value sh = Expr.lift0 (MultiValue.cons sh)


-- Conversion from a Haskell tuple to a constant,
-- wrapped such that it can be defined by 'switchInt'.
newtype MakeValue val tag sh =
         MakeValue {runMakeValue :: T tag sh -> val (T tag sh)}

-- | Turn a shape or index known on the Haskell side into a constant.
-- The tuple is rebuilt component by component with 'z' and 'cons'.
value :: (C sh, Expr.Value val) => T tag sh -> val (T tag sh)
value =
   runMakeValue
      (switchInt
         (MakeValue (\(Cons Z) -> z))
         (MakeValue (\(Cons (rest :. lastDim)) ->
            cons
               (value (Cons rest))
               (Expr.lift0 (MultiValue.cons lastDim)))))

-- | Open the 'Param.Tunnel' of a shape or index parameter
-- and pass its two halves to the continuation:
-- the function that extracts the stored parameters on the Haskell side
-- and the function that rebuilds the tuple from them in generated code.
paramWith ::
   (C sh, Expr.Value val) =>
   Param.T p (T tag sh) ->
   (forall parameters.
    (St.Storable parameters,
     MultiValueMemory.C parameters) =>
    (p -> parameters) ->
    (MultiValue.T parameters -> val (T tag sh)) ->
    a) ->
   a
paramWith p f =
   case tunnel p of
      Param.Tunnel toParams fromParams ->
         f toParams (\params -> Expr.lift0 (fromParams params))

-- | Build the 'Param.Tunnel' of a shape or index parameter
-- from 'Param.tunnel' and 'value'.
-- Matching on 'StructFieldsProp' brings the
-- @LLVM.StructFields (Struct sh)@ instance into scope;
-- presumably 'Param.tunnel' needs it through the
-- 'MultiValueMemory.C' instance of @T tag sh@.
tunnel :: (C sh) => Param.T p (T tag sh) -> Param.Tunnel p (T tag sh)
tunnel p =
   case structFieldsPropF p of
      StructFieldsProp -> Param.tunnel value p


-- | Evidence that @Struct sh@ is an instance of 'LLVM.StructFields'.
-- Pattern matching on the constructor makes that instance available.
data StructFieldsProp sh = LLVM.StructFields (Struct sh) => StructFieldsProp

-- Evidence for the @sh@ found in a value of type @f sh@.
-- The argument only fixes the type and is never inspected.
-- Not referenced elsewhere in this module.
_structFieldsProp :: (C sh) => f sh -> StructFieldsProp sh
_structFieldsProp _p = structFieldsRec

-- Like '_structFieldsProp' but for an @sh@ nested below two type constructors.
-- The argument only fixes the type and is never inspected.
structFieldsPropF :: (C sh) => f (g sh) -> StructFieldsProp sh
structFieldsPropF _p = structFieldsRec

-- Pass the evidence to a function
-- whose result type mentions @sh@ below three type constructors.
-- Here the type of the result determines @sh@.
withStructFieldsPropFF ::
   (C sh) => (StructFieldsProp sh -> f (g (h sh))) -> f (g (h sh))
withStructFieldsPropFF f = f structFieldsRec

-- Construct the evidence by recursion over the tuple structure:
-- directly for 'Z' and by 'succStructFieldsProp' for every further dimension.
structFieldsRec :: (C sh) => StructFieldsProp sh
structFieldsRec =
   switchInt
      StructFieldsProp
      (succStructFieldsProp structFieldsRec)

-- Extend the evidence by one 'Index.Int' dimension.
-- The match on the left provides the instance for @Struct sh@
-- that the constructor on the right requires
-- for @Struct (sh:.Index.Int)@.
succStructFieldsProp ::
   StructFieldsProp sh -> StructFieldsProp (sh:.Index.Int)
succStructFieldsProp StructFieldsProp = StructFieldsProp


-- | The tuple without components.
data Z = Z
   deriving (Eq, Ord, Read, Show)


-- Both the data constructor ':.' and the expression level '#:.'
-- associate to the left, so further components are appended on the right.
infixl 3 :., #:.

-- | A tuple @tail@ extended by one more component @head@.
-- Both fields are strict.
data tail :. head = !tail :. !head
   deriving (Eq, Ord, Read, Show)


-- | A tuple @sh@ labelled with a phantom @tag@.
-- The tag tells whether the tuple is meant as a shape or as an index.
newtype T tag sh = Cons {decons :: sh}

-- Types without constructors
-- that serve only as arguments for the @tag@ parameter of 'T'.
data ShapeTag
data IndexTag

type Shape = T ShapeTag
type Index = T IndexTag

-- | Label a tuple as a shape.
shape :: sh -> Shape sh
shape = Cons

-- | Label a tuple as an index.
index :: ix -> Index ix
index = Cons


-- | Infix synonym of 'cons'.
(#:.) :: (Expr.Value val) => val (T tag sh) -> val i -> val (T tag (sh:.i))
sh #:. i = cons sh i

-- | Append one component to a tuple value.
-- In the LLVM representation this forms the pair
-- of the representations of both arguments.
cons :: (Expr.Value val) => val (T tag sh) -> val i -> val (T tag (sh:.i))
cons =
   Expr.lift2
      (\(MultiValue.Cons t) (MultiValue.Cons h) -> MultiValue.Cons (t,h))

-- | The tuple value without components, represented by the unit value.
z :: (Expr.Value val) => val (T tag Z)
z = Expr.lift0 (MultiValue.Cons ())

-- | The last (rightmost) component of a tuple value.
head :: (Expr.Value val) => val (T tag (sh:.i)) -> val i
head =
   Expr.lift1 (\(MultiValue.Cons (_, lastDim)) -> MultiValue.Cons lastDim)

-- | The tuple value without its last (rightmost) component.
tail :: (Expr.Value val) => val (T tag (sh:.i)) -> val (T tag sh)
tail =
   Expr.lift1 (\(MultiValue.Cons (rest, _)) -> MultiValue.Cons rest)

-- | Split a tuple value into all but the last component ('tail')
-- and the last component ('head')
-- and pass both, in this order, to the continuation.
switchR ::
   Expr.Value val =>
   (val (T tag sh) -> val i -> a) -> val (T tag (sh :. i)) -> a
switchR k ix = k rest lastDim
   where
      rest = tail ix
      lastDim = head ix


-- The equality constraints restrict the instance to @Shape Z@
-- although the instance head matches every @T tag sh@.
-- Both methods yield the unit value,
-- which is the representation of a tuple without components.
instance (tag ~ ShapeTag, sh ~ Z) => Shape.Scalar (T tag sh) where
   scalar = Expr.lift0 $ MultiValue.Cons ()
   zeroIndex _ = Expr.lift0 $ MultiValue.Cons ()


-- Type functions for pattern matching on tuple expressions,
-- counterparts of 'MultiValue.PatternTuple' and 'MultiValue.Decomposed'.
-- 'PatternTuple' maps a pattern to the tuple type it matches,
-- 'Decomposed' maps a pattern to the type of the decomposition result.
type family PatternTuple pattern
type family Decomposed (f :: * -> *) tag pattern

-- A pattern with a further component:
-- recurse into the leading part
-- and use the 'MultiValue' families for the last component.
type instance PatternTuple (sh:.s) =
   PatternTuple sh :. MultiValue.PatternTuple s

type instance Decomposed f tag (sh:.s) =
   Decomposed f tag sh :. MultiValue.Decomposed f s

-- An 'Atom' stands for a whole tuple @sh@ that is not split any further.
type instance PatternTuple (Atom sh) = sh

type instance Decomposed f tag (Atom sh) = f (T tag sh)


-- | Split a tuple expression according to a pattern.
-- The superclass constraint states that composing the decomposed parts
-- yields the tagged tuple type again.
class
   (Expr.Composed (Decomposed Exp tag pattern) ~ T tag (PatternTuple pattern)) =>
      Decompose tag pattern where
   decompose ::
      T tag pattern -> Exp (T tag (PatternTuple pattern)) ->
      Decomposed Exp tag pattern

-- An atom returns the expression as it is.
instance Decompose tag (Atom sh) where
   decompose (Cons _atom) x = x

-- Decompose the leading part recursively
-- and the last component with 'Expr.decompose'.
instance (Decompose tag sh, Expr.Decompose s) => Decompose tag (sh :. s) where
   decompose (Cons (psh:.ps)) x =
      decompose (Cons psh) (tail x) :. Expr.decompose ps (head x)


-- Connect the pattern families of this module
-- with the ones from 'MultiValue' for tagged tuples.
type instance MultiValue.PatternTuple (T tag sh) = T tag (PatternTuple sh)

type instance MultiValue.Decomposed f (T tag sh) = Decomposed f tag sh


-- Projection from a tagged tuple type to the plain tuple type.
type family Unwrap sh
type instance Unwrap (T tag sh) = sh

-- Projection from a tagged tuple type to its tag.
type family Tag sh
type instance Tag (T tag sh) = tag

-- Compose a Haskell level ':.' of expressions to an expression of a tuple.
-- The equality constraint demands
-- that the leading part composes to a tagged tuple @T tag sh'@;
-- 'Tag' and 'Unwrap' extract its two type arguments
-- in order to form the result type.
instance
   (Expr.Compose sh,
    Expr.Composed sh ~ T (Tag (Expr.Composed sh)) (Unwrap (Expr.Composed sh)),
    Expr.Compose s) =>
      Expr.Compose (sh :. s) where
   type Composed (sh :. s) =
           T (Tag (Expr.Composed sh))
             (Unwrap (Expr.Composed sh) :. Expr.Composed s)
   compose (sh :. s) = cons (Expr.compose sh) (Expr.compose s)

-- The method is implemented by the 'decompose' of this module's
-- 'Decompose' class; the 'Expr' module is imported qualified only.
instance (Decompose tag sh) => Expr.Decompose (T tag sh) where
   decompose = decompose



-- A tagged tuple is stored as an array of 'Shape.Size' values,
-- one element per dimension.
-- 'poke' and 'peek' on the right-hand sides
-- are the top-level functions of this module, not the class methods.
instance (C sh) => St.Storable (T tag sh) where
   sizeOf (Cons sh) = sizeOfArray (rank sh) (0::Shape.Size)
   alignment _ = St.alignment (0::Shape.Size)
   poke ptr (Cons sh) = poke (castPtr ptr) sh
   peek ptr = fmap Cons (peek (castPtr ptr))


-- Representation of a plain tuple:
-- nested pairs that mirror the nesting of ':.', ended by the unit type.
type family Repr (f :: * -> *) sh
type instance Repr f Z = ()
type instance Repr f (tail :. head) = (Repr f tail, MultiValue.Repr f head)

-- A tagged tuple has the representation of the plain tuple.
-- The right-hand sides ('value', 'constant', 'addPhis', 'phis')
-- are the top-level functions of this module, not the class methods.
-- 'undef' and 'zero' put the respective 'MultiValue' value
-- into every dimension.
instance (C sh) => MultiValue.C (T tag sh) where
   type Repr f (T tag sh) = Repr f sh
   cons = value
   undef = constant $ MultiValue.undef
   zero = constant $ MultiValue.zero
   addPhis = addPhis
   phis = phis

-- Shapes of this module as shapes of the comfort-array framework.
-- The right-hand sides use the top-level functions of this module
-- that work on plain tuples.
instance (tag ~ ShapeTag, C sh) => ComfortShape.C (T tag sh) where
   size (Cons sh) = fromIntegral (size sh)

-- The index type of a shape is the same tuple with the index tag.
instance (tag ~ ShapeTag, C sh) => ComfortShape.Indexed (T tag sh) where
   type Index (T tag sh) = Index sh
   indices (Cons sh) = [index ix | ix <- indices sh]
   inBounds (Cons sh) (Cons ix) = inBounds sh ix
   offset (Cons sh) (Cons ix) = offset sh ix


newtype Indices sh = Indices {runIndices :: sh -> [sh]}

-- Enumerate all indices of a plain shape tuple.
-- The last dimension varies fastest;
-- every dimension is enumerated like a 'ZeroBased' shape.
indices :: (C sh) => sh -> [sh]
indices =
   runIndices
      (switchInt
         (Indices (\Z -> [Z]))
         (Indices (\(t :. Index.Int h) ->
            let lastIxs = map Index.Int (ComfortShape.indices (ZeroBased h))
            in  [ix :. k | ix <- indices t, k <- lastIxs])))

newtype InBounds sh = InBounds {runInBounds :: sh -> sh -> Bool}

-- Check whether an index (second argument)
-- lies within a shape (first argument).
-- Every dimension is checked like a 'ZeroBased' shape,
-- the leading dimensions first.
inBounds :: (C sh) => sh -> sh -> Bool
inBounds =
   runInBounds
      (switchInt
         (InBounds (\Z Z -> True))
         (InBounds (\(shRest :. Index.Int n) (ixRest :. Index.Int k) ->
            inBounds shRest ixRest
            &&
            ComfortShape.inBounds (ZeroBased n) k)))

newtype Offset sh = Offset {runOffset :: sh -> sh -> Int}

-- Linear position of an index (second argument)
-- within a shape (first argument).
-- The offset of the leading dimensions is scaled
-- by the extent of the last dimension, then the last index is added.
offset :: (C sh) => sh -> sh -> Int
offset =
   runOffset
      (switchInt
         (Offset (\Z Z -> 0))
         (Offset (\(shRest :. Index.Int n) (ixRest :. Index.Int k) ->
            offset shRest ixRest * fromIntegral n + fromIntegral k)))



-- Code generation interface for shapes.
-- All methods are implemented by top-level functions of this module.
instance (tag ~ ShapeTag, C sh) => Shape.C (T tag sh) where
   size = computeSize
   intersectCode = Expr.unliftM2 intersect
   sizeOffset sh = do
      -- would a joint implementation be more efficient?
      n <- computeSize sh
      return (n, offsetCode sh)
   iterator = iterator
   loop = loop


-- Field list of the LLVM struct that stores a tuple:
-- one 'Shape.Size' field per dimension.
-- The field of the last dimension comes first,
-- which matches the order used by 'load' and 'store'.
type family Struct sh
type instance Struct Z = ()
type instance Struct (sh :. Index.Int) = (Shape.Size, Struct sh)

-- Memory access for tagged tuples.
-- The 'LLVM.StructFields' constraint is only needed for the instance head;
-- 'loadMultiValue' and 'storeMultiValue' obtain it on their own.
instance
   (C sh, LLVM.StructFields (Struct sh)) =>
      MultiValueMemory.C (T tag sh) where
   type Struct (T tag sh) = LLVM.Struct (Struct sh)
   load = loadMultiValue
   store = storeMultiValue

-- | Load a tuple from a pointer to its struct representation.
-- The struct pointer is cast to a pointer to 'Shape.Size'
-- and the dimensions are read one after another by 'load'.
-- The match on 'StructFieldsProp' provides the 'LLVM.StructFields' instance
-- that 'castPtrValue' requires.
loadMultiValue ::
   (C sh) =>
   LLVM.Value (Ptr (LLVM.Struct (Struct sh))) ->
   LLVM.CodeGenFunction r (MultiValue.T (T tag sh))
loadMultiValue ptr =
   withStructFieldsPropFF
      (\StructFieldsProp -> castPtrValue ptr >>= load)

-- | Store a tuple through a pointer to its struct representation.
-- The struct pointer is cast to a pointer to 'Shape.Size'
-- and the dimensions are written one after another by 'store'.
-- The match on 'StructFieldsProp' provides the 'LLVM.StructFields' instance
-- that 'castPtrValue' requires.
storeMultiValue ::
   (C sh) =>
   MultiValue.T (T tag sh) ->
   LLVM.Value (Ptr (LLVM.Struct (Struct sh))) -> LLVM.CodeGenFunction r ()
storeMultiValue x ptr =
   case structFieldsPropF x of
      StructFieldsProp -> castPtrValue ptr >>= store x


newtype OffsetCode r sh =
   OffsetCode {
      runOffsetCode ::
         MultiValue.T (Shape sh) -> MultiValue.T (Index sh) ->
         LLVM.CodeGenFunction r (LLVM.Value Shape.Size)
   }

-- | Generate code for the linear position of an index within a shape.
-- This is the code generating counterpart of 'offset':
-- the offset of the leading dimensions is multiplied
-- by the extent of the last dimension, then the last index is added.
offsetCode ::
   (C sh) =>
   MultiValue.T (Shape sh) -> MultiValue.T (Index sh) ->
   LLVM.CodeGenFunction r (LLVM.Value Shape.Size)
offsetCode =
   runOffsetCode
      (switchInt
         (OffsetCode (\_ _ -> return A.zero))
         (OffsetCode
            (switchR (\sh (MultiValue.Cons s) ->
             switchR (\ix (MultiValue.Cons i) -> do
                outer <- offsetCode sh ix
                scaled <- A.mul s outer
                A.add i scaled)))))


newtype Rank sh = Rank {runRank :: sh -> Int}

-- Number of components of a tuple.
-- The result depends on the type only:
-- the argument is taken apart lazily and is never evaluated.
rank :: (C sh) => sh -> Int
rank =
   runRank
      (switch
         (Rank (\_ -> 0))
         (Rank (\ ~(sh :. _) -> succ (rank sh))))


newtype Peek sh = Peek {runPeek :: Ptr Shape.Size -> IO sh}

-- | Read a plain tuple from an array of 'Shape.Size' values.
-- The first array element is the last dimension;
-- the leading dimensions follow in the subsequent elements.
peek :: (C sh) => Ptr Shape.Size -> IO sh
peek =
   runPeek
      (switchInt
         (Peek (\_ -> return Z))
         (Peek (\ptr ->
            liftM2 (\h t -> t :. Index.Int h)
               (St.peek ptr)
               (peek (advancePtr ptr 1)))))


newtype Poke sh = Poke {runPoke :: Ptr Shape.Size -> sh -> IO ()}

-- | Write a plain tuple to an array of 'Shape.Size' values.
-- The last dimension goes to the first array element;
-- the leading dimensions follow in the subsequent elements.
poke :: (C sh) => Ptr Shape.Size -> sh -> IO ()
poke =
   runPoke
      (switchInt
         (Poke (\_ _ -> return ()))
         (Poke (\ptr (sh :. Index.Int i) ->
            St.poke ptr i >> poke (advancePtr ptr 1) sh)))


-- Reinterpret a pointer to a struct as a pointer to 'Shape.Size'
-- by means of 'LLVM.bitcast'.
castPtrValue ::
   (LLVM.StructFields sh) =>
   LLVM.Value (Ptr (LLVM.Struct sh)) ->
   LLVM.CodeGenFunction r (LLVM.Value (Ptr Shape.Size))
castPtrValue structPtr = LLVM.bitcast structPtr

newtype Load r tag sh =
   Load {
      runLoad ::
         LLVM.Value (Ptr Shape.Size) ->
         LLVM.CodeGenFunction r (MultiValue.T (T tag sh))
   }

-- Generate code that reads a tuple from consecutive 'Shape.Size' elements.
-- The element at the pointer is the last dimension;
-- the leading dimensions are read from the advanced pointer.
load ::
   (C sh) =>
   LLVM.Value (Ptr Shape.Size) ->
   LLVM.CodeGenFunction r (MultiValue.T (T tag sh))
load =
   runLoad
      (switchInt
         (Load (\_ -> return z))
         (Load (\ptr -> do
            h <- LLVM.load ptr
            next <- A.advanceArrayElementPtr ptr
            t <- load next
            return (cons t (MultiValue.Cons h)))))


newtype Store r tag sh =
   Store {
      runStore ::
         MultiValue.T (T tag sh) ->
         LLVM.Value (Ptr Shape.Size) ->
         LLVM.CodeGenFunction r ()
   }

-- Generate code that writes a tuple to consecutive 'Shape.Size' elements.
-- The last dimension is written at the pointer;
-- the leading dimensions are written at the advanced pointer.
store ::
   (C sh) =>
   MultiValue.T (T tag sh) ->
   LLVM.Value (Ptr Shape.Size) ->
   LLVM.CodeGenFunction r ()
store =
   runStore
      (switchInt
         (Store (\_ _ -> return ()))
         (Store (switchR (\sh (MultiValue.Cons k) ptr -> do
            LLVM.store k ptr
            next <- A.advanceArrayElementPtr ptr
            store sh next))))


newtype Size sh = Size {runSize :: sh -> Shape.Size}

-- Number of elements of a plain shape tuple:
-- the product of all dimensions, which is 1 for 'Z'.
size :: (C sh) => sh -> Shape.Size
size =
   runSize
      (switchInt
         (Size (\_ -> 1))
         (Size (\(rest :. Index.Int n) -> n * size rest)))


newtype ComputeSize r sh =
   ComputeSize {
      runComputeSize ::
         MultiValue.T (Shape sh) ->
         LLVM.CodeGenFunction r (LLVM.Value Shape.Size)
   }

-- | Generate code for the number of elements of a shape:
-- the product of all dimensions, which is one for 'Z'.
-- This is the code generating counterpart of 'size'.
computeSize ::
   (C sh) =>
   MultiValue.T (Shape sh) ->
   LLVM.CodeGenFunction r (LLVM.Value Shape.Size)
computeSize =
   runComputeSize
      (switchInt
         (ComputeSize (\_ -> return A.one))
         (ComputeSize (switchR (\sh (MultiValue.Cons k) -> do
            inner <- computeSize sh
            A.mul k inner))))


newtype
   Constant val tag sh =
      Constant {getConstant :: val Index.Int -> val (T tag sh)}

-- | A tuple that contains the same value in every dimension.
constant :: (C sh, Expr.Value val) => val Index.Int -> val (T tag sh)
constant =
   getConstant
      (switchInt
         (Constant (\_ -> z))
         (Constant (\x -> cons (constant x) x)))


newtype AddPhis r tag sh =
   AddPhis {
      runAddPhis ::
         LLVM.BasicBlock ->
         MultiValue.T (T tag sh) ->
         MultiValue.T (T tag sh) ->
         LLVM.CodeGenFunction r ()
   }

-- Implementation of 'MultiValue.addPhis' for tuples:
-- apply 'MultiValue.addPhis' to the last dimension of both tuples,
-- then proceed with the leading dimensions.
addPhis ::
   (C sh) =>
   LLVM.BasicBlock ->
   MultiValue.T (T tag sh) ->
   MultiValue.T (T tag sh) ->
   LLVM.CodeGenFunction r ()
addPhis =
   runAddPhis
      (switchInt
         (AddPhis (\_ _ _ -> return ()))
         (AddPhis (\bb ->
            switchR (\restX lastX ->
            switchR (\restY lastY -> do
               MultiValue.addPhis bb lastX lastY
               addPhis bb restX restY)))))


newtype Phis r tag sh =
   Phis {
      runPhis ::
         LLVM.BasicBlock ->
         MultiValue.T (T tag sh) ->
         LLVM.CodeGenFunction r (MultiValue.T (T tag sh))
   }

-- Implementation of 'MultiValue.phis' for tuples:
-- process the leading dimensions first,
-- then apply 'MultiValue.phis' to the last dimension
-- and reassemble the tuple.
phis ::
   (C sh) =>
   LLVM.BasicBlock ->
   MultiValue.T (T tag sh) ->
   LLVM.CodeGenFunction r (MultiValue.T (T tag sh))
phis =
   runPhis
      (switchInt
         (Phis (\_ sh -> return sh))
         (Phis (\bb ->
            switchR (\rest lastDim -> do
               rest' <- phis bb rest
               lastDim' <- MultiValue.phis bb lastDim
               return (cons rest' lastDim')))))


newtype Iterator r sh =
   Iterator {
      runIterator ::
         MultiValue.T (Shape sh) -> Iter.T r (MultiValue.T (Index sh))
   }

-- Iterator over all indices of a shape.
--
-- For every further dimension the iterator of the leading dimensions
-- is combined by 'Iter.cartesian' with a counter
-- that starts at zero and runs as long as the extent @n@ of that dimension
-- is greater than the counter.
--
-- The shape 'Z' has exactly one index, namely 'z'.
-- This agrees with 'size' and 'computeSize' (one element),
-- with 'indices' (@[Z]@) and with 'loop' (body run once).
-- The base case must therefore be a one-element iterator:
-- an empty iterator would make the Cartesian product empty for every rank.
iterator ::
   (C sh) =>
   MultiValue.T (Shape sh) -> Iter.T r (MultiValue.T (Index sh))
iterator =
   runIterator $
   switchInt
      (Iterator $ \ _z -> Iter.singleton z)
      (Iterator $ switchR $ \sh n ->
       fmap (\(ix,i) -> ix#:.i) $
       Iter.cartesian
         (iterator sh)
         (IterMV.takeWhile (MultiValue.cmp LLVM.CmpGT n) $
          Iter.iterate MultiValue.inc MultiValue.zero))


newtype Loop r state sh =
   Loop {
      runLoop ::
         (MultiValue.T (Index sh) ->
          state ->
          LLVM.CodeGenFunction r state) ->
         MultiValue.T (Shape sh) ->
         state ->
         LLVM.CodeGenFunction r state
   }

-- Generate nested loops over all indices of a shape
-- and thread a state through the loop body.
--
-- For 'Z' the body is run once with the index 'z'.
-- For every further dimension a 'C.fixedLengthLoop' over its extent
-- is placed inside the loops of the leading dimensions.
-- That loop carries the caller's state together with a counter,
-- which starts at zero, serves as index of this dimension
-- and is incremented after every run of the body.
-- The counter is dropped from the result.
loop ::
   (C sh, Loop.Phi state) =>
   (MultiValue.T (Index sh) ->
    state ->
    LLVM.CodeGenFunction r state) ->
   MultiValue.T (Shape sh) ->
   state ->
   LLVM.CodeGenFunction r state
loop =
   runLoop
      (switchInt
         (Loop (\body _ -> body z))
         (Loop (\body ->
            switchR (\sh (MultiValue.Cons n) ->
               loop
                  (\ix state0 ->
                     fmap fst
                        (C.fixedLengthLoop n (state0, A.zero)
                           (\(state, k) -> do
                              state' <- body (cons ix (MultiValue.Cons k)) state
                              k' <- A.inc k
                              return (state', k'))))
                  sh))))