{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE EmptyDataDecls #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
module Data.Array.Knead.Shape.Cubic (
   C(switch),
   switchInt,
   intersect,
   value,
   constant,
   paramWith,
   tunnel,
   flattenIndex,
   peek,
   poke,
   computeSize,

   Struct,
   T(..),
   Z(Z), z,
   (:.)((:.)),
   Shape, shape,
   Index, index,
   cons, (#:.),
   head,
   tail,
   switchR,
   loadMultiValue,
   storeMultiValue,
   ) where

import qualified Data.Array.Knead.Shape.Nested as Shape
import qualified Data.Array.Knead.Shape.Cubic.Int as Index

import qualified Data.Array.Knead.Parameter as Param
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp, )

import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Multi.Iterator as IterMV
import qualified LLVM.Extra.Iterator as Iter
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Control as C
import LLVM.Extra.Multi.Value (Atom, )

import qualified LLVM.Util.Loop as Loop
import qualified LLVM.Core as LLVM

import qualified Foreign.Storable as St
import Foreign.Storable.FixedArray (sizeOfArray, )
import Foreign.Marshal.Array (advancePtr, )
import Foreign.Ptr (Ptr, castPtr, )

import Control.Monad (liftM2, )

import Prelude hiding (min, head, tail, )


-- | Shapes and indices built from 'Z' and ':.'.
--
-- 'switch' performs case analysis on the shape type: the first argument is
-- chosen for 'Z', the second one for any @ix0 :. i@ whose last component
-- type is an instance of 'Index.Single'.
class C ix where
   switch ::
      f Z ->
      (forall ix0 i. (C ix0, Index.Single i) => f (ix0 :. i)) ->
      f ix

instance C Z where
   switch zeroCase _succCase = zeroCase

instance (C ix0, Index.Single i) => C (ix0 :. i) where
   switch _zeroCase succCase = succCase


-- | Helper for 'switchInt': wraps @f (ix :. i)@ so that
-- 'Index.switchSingle' can choose the case for the last component type @i@.
newtype SwitchInt f ix i = SwitchInt {runSwitchInt :: f (ix :. i)}

-- | Variant of 'switch' whose non-empty case only has to be given at
-- @ix0 :. Index.Int@. 'Index.switchSingle' adapts it to the last component
-- type required by 'switch'. NOTE(review): presumably 'Index.Int' is the
-- only 'Index.Single' type; confirm in "Data.Array.Knead.Shape.Cubic.Int".
switchInt ::
   (C ix) =>
   f Z ->
   (forall ix0. (C ix0) => f (ix0 :. Index.Int)) ->
   f ix
switchInt zeroCase succCase =
   switch zeroCase
      (runSwitchInt (Index.switchSingle (SwitchInt succCase)))


-- | Helper for 'intersect': a binary operation on expressions of one shape.
newtype Op2 tag sh = Op2 {runOp2 :: Exp (T tag sh) -> Exp (T tag sh) -> Exp (T tag sh)}

-- | Component-wise minimum of two shapes. For rank zero the first argument
-- is returned.
intersect :: C sh => Exp (Shape sh) -> Exp (Shape sh) -> Exp (Shape sh)
intersect =
   runOp2 $
   switchInt
      (Op2 const)
      (Op2 $
       switchR $ \restX nX ->
       switchR $ \restY nY ->
          intersect restX restY #:. Expr.min nX nY)


-- | Lift a plain shape into an expression through its 'MultiValue.C'
-- instance. Not exported; the leading underscore silences the
-- unused-binding warning.
_value :: (C sh, MultiValue.C sh) => sh -> Exp sh
_value sh = Expr.lift0 (MultiValue.cons sh)


-- | Helper for 'value'.
newtype MakeValue val tag sh = MakeValue {runMakeValue :: T tag sh -> val (T tag sh)}

-- | Lift a concrete shape or index into an expression. The expression is
-- built from 'z' by appending one lifted 'Index.Int' component at a time.
value :: (C sh, Expr.Value val) => T tag sh -> val (T tag sh)
value =
   runMakeValue $
   switchInt
      (MakeValue $ \(Cons Z) -> z)
      (MakeValue $ \(Cons (rest :. n)) ->
         cons (value (Cons rest)) (Expr.lift0 (MultiValue.cons n)))

-- | Open the 'Param.Tunnel' of a shape parameter (see 'tunnel') and pass
-- its two parts to the continuation:
-- * the function extracting the storable parameter value from @p@;
-- * a function turning the loaded 'MultiValue.T' into a shape value.
paramWith ::
   (C sh, Expr.Value val) =>
   Param.T p (T tag sh) ->
   (forall parameters.
    (St.Storable parameters,
     MultiValueMemory.C parameters) =>
    (p -> parameters) ->
    (MultiValue.T parameters -> val (T tag sh)) ->
    a) ->
   a
paramWith param k =
   case tunnel param of
      Param.Tunnel getParams toValue ->
         k getParams (\mv -> Expr.lift0 (toValue mv))

-- | Build a 'Param.Tunnel' for a shape parameter. Matching on
-- 'StructFieldsProp' brings the 'LLVM.StructFields' constraint for the
-- shape's 'Struct' into scope, as the memory instance requires.
tunnel :: (C sh) => Param.T p (T tag sh) -> Param.Tunnel p (T tag sh)
tunnel param =
   case structFieldsPropF param of
      StructFieldsProp -> Param.tunnel value param


-- | Evidence that @'Struct' sh@ has an 'LLVM.StructFields' instance.
-- Pattern matching on 'StructFieldsProp' brings that constraint into scope.
data StructFieldsProp sh = LLVM.StructFields (Struct sh) => StructFieldsProp

-- | Evidence for the shape type of a proxy argument (the argument itself is
-- ignored).
_structFieldsProp :: (C sh) => f sh -> StructFieldsProp sh
_structFieldsProp _proxy = structFieldsRec

-- | Like '_structFieldsProp', with the shape type nested one level deeper,
-- e.g. @MultiValue.T (T tag sh)@.
structFieldsPropF :: (C sh) => f (g sh) -> StructFieldsProp sh
structFieldsPropF _proxy = structFieldsRec

-- | Hand the evidence to a continuation whose result type fixes @sh@.
withStructFieldsPropFF ::
   (C sh) => (StructFieldsProp sh -> f (g (h sh))) -> f (g (h sh))
withStructFieldsPropFF k = k structFieldsRec

-- | Build the evidence by induction on the shape with 'switchInt'.
structFieldsRec :: (C sh) => StructFieldsProp sh
structFieldsRec =
   switchInt StructFieldsProp (succStructFieldsProp structFieldsRec)

-- | Extend the evidence by one 'Index.Int' dimension. With the constraint
-- for @Struct sh@ in scope, the constructor can be rebuilt at
-- @Struct (sh:.Index.Int) = (Shape.Size, Struct sh)@.
succStructFieldsProp ::
   StructFieldsProp sh -> StructFieldsProp (sh:.Index.Int)
succStructFieldsProp prop =
   case prop of
      StructFieldsProp -> StructFieldsProp


-- | The rank-zero shape (and index); the base case of every shape.
data Z = Z
   deriving (Eq, Ord, Read, Show)


-- Both operators associate to the left, so @Z :. a :. b@ parses as
-- @(Z :. a) :. b@.
infixl 3 :., #:.

-- | A shape (or index) @tail@ extended by one more component @head@ at the
-- end. Both fields are strict.
data tail :. head = !tail :. !head
   deriving (Eq, Ord, Read, Show)


-- | A plain shape or index @sh@ tagged with a phantom @tag@ ('ShapeTag' or
-- 'IndexTag'), so that shapes and indices are distinct types.
newtype T tag sh = Cons {decons :: sh}

-- | Phantom tag used by 'Shape'.
data ShapeTag
-- | Phantom tag used by 'Index'.
data IndexTag

type Shape = T ShapeTag
type Index = T IndexTag

-- | Tag a plain shape value as a 'Shape'.
shape :: sh -> Shape sh
shape sh = Cons sh

-- | Tag a plain index value as an 'Index'.
index :: ix -> Index ix
index ix = Cons ix


-- | Infix synonym of 'cons'.
(#:.) :: (Expr.Value val) => val (T tag sh) -> val i -> val (T tag (sh:.i))
rest #:. n = cons rest n

-- | Append a component at the end of a shape or index expression.
-- The representation is the pair @(rest, component)@.
cons :: (Expr.Value val) => val (T tag sh) -> val i -> val (T tag (sh:.i))
cons =
   Expr.lift2 $ \(MultiValue.Cons rest) (MultiValue.Cons n) ->
      MultiValue.Cons (rest, n)

-- | The rank-zero shape or index, represented by @()@.
z :: (Expr.Value val) => val (T tag Z)
z = Expr.lift0 (MultiValue.Cons ())

-- | The last component of a shape or index.
head :: (Expr.Value val) => val (T tag (sh:.i)) -> val i
head = Expr.lift1 (\(MultiValue.Cons (_rest, n)) -> MultiValue.Cons n)

-- | Everything except the last component of a shape or index.
tail :: (Expr.Value val) => val (T tag (sh:.i)) -> val (T tag sh)
tail = Expr.lift1 (\(MultiValue.Cons (rest, _n)) -> MultiValue.Cons rest)

-- | Split a non-empty shape or index and pass its 'tail' and 'head' (in
-- that order) to the continuation.
switchR ::
   Expr.Value val =>
   (val (T tag sh) -> val i -> a) -> val (T tag (sh :. i)) -> a
switchR k sh = k (tail sh) (head sh)


-- | 'Shape.Scalar' for the rank-zero shape. Both methods produce the @()@
-- representation, the same value as 'z'. The instance head matches any
-- @T tag sh@; the equality constraints then force @tag ~ ShapeTag@ and
-- @sh ~ Z@, which helps type inference.
instance (tag ~ ShapeTag, sh ~ Z) => Shape.Scalar (T tag sh) where
   scalar = Expr.lift0 $ MultiValue.Cons ()
   zeroIndex _ = Expr.lift0 $ MultiValue.Cons ()


-- | Shape type matched by a decomposition pattern. A pattern mirrors a shape
-- built with ':.'; an 'Atom' leaf stands for a whole sub-shape.
type family PatternTuple pattern
-- | Result of decomposing a shape expression with a pattern: a ':.'-chain of
-- decomposed components, where an 'Atom' leaf yields the sub-shape unchanged
-- as @f (T tag sh)@.
type family Decomposed (f :: * -> *) tag pattern

-- The last component is handled by the generic 'MultiValue' families, the
-- remaining shape recursively.
type instance PatternTuple (sh:.s) =
   PatternTuple sh :. MultiValue.PatternTuple s

type instance Decomposed f tag (sh:.s) =
   Decomposed f tag sh :. MultiValue.Decomposed f s

-- An 'Atom' matches an arbitrary shape @sh@ and keeps it as one value.
type instance PatternTuple (Atom sh) = sh

type instance Decomposed f tag (Atom sh) = f (T tag sh)


-- | Decompose a shape expression according to a pattern of type
-- @T tag pattern@. The superclass equation requires that
-- 'Expr.Composed' of the decomposed structure equals the shape type
-- @T tag (PatternTuple pattern)@.
class
   (Expr.Composed (Decomposed Exp tag pattern) ~ T tag (PatternTuple pattern)) =>
      Decompose tag pattern where
   decompose ::
      T tag pattern -> Exp (T tag (PatternTuple pattern)) ->
      Decomposed Exp tag pattern

-- | An 'Atom' pattern returns the expression unchanged.
instance Decompose tag (Atom sh) where
   decompose (Cons _atom) x = x

-- | A ':.' pattern splits off the last component with 'head' and decomposes
-- it with 'Expr.decompose'; the remaining shape ('tail') is decomposed
-- recursively.
instance (Decompose tag sh, Expr.Decompose s) => Decompose tag (sh :. s) where
   decompose (Cons (psh:.ps)) x =
      decompose (Cons psh) (tail x) :. Expr.decompose ps (head x)


-- Let shapes appear inside 'MultiValue' patterns by delegating to the
-- shape-specific 'PatternTuple' and 'Decomposed' families of this module.
type instance MultiValue.PatternTuple (T tag sh) = T tag (PatternTuple sh)

type instance MultiValue.Decomposed f (T tag sh) = Decomposed f tag sh


-- | Extract the @sh@ parameter of @T tag sh@.
type family Unwrap sh
type instance Unwrap (T tag sh) = sh

-- | Extract the @tag@ parameter of @T tag sh@.
type family Tag sh
type instance Tag (T tag sh) = tag

-- | Compose a ':.'-chain of expressions into one shape expression with
-- 'cons'. The context requires the composed prefix to be some @T tag sh@;
-- 'Tag' and 'Unwrap' carry its tag and plain shape over to the result type.
instance
   (Expr.Compose sh,
    Expr.Composed sh ~ T (Tag (Expr.Composed sh)) (Unwrap (Expr.Composed sh)),
    Expr.Compose s) =>
      Expr.Compose (sh :. s) where
   type Composed (sh :. s) =
           T (Tag (Expr.Composed sh))
             (Unwrap (Expr.Composed sh) :. Expr.Composed s)
   compose (sh :. s) = cons (Expr.compose sh) (Expr.compose s)

-- | Shape expressions are decomposed with this module's 'Decompose' class.
instance (Decompose tag sh) => Expr.Decompose (T tag sh) where
   decompose = decompose



-- | A shape is stored as 'rank' consecutive 'Shape.Size' values, in the
-- layout of the module-level 'peek' and 'poke' (last component first).
-- The alignment is that of a single 'Shape.Size'.
instance (C sh) => St.Storable (T tag sh) where
   sizeOf (Cons sh) = sizeOfArray (rank sh) (0::Shape.Size)
   alignment _ = St.alignment (0::Shape.Size)
   poke ptr (Cons sh) = poke (castPtr ptr) sh
   peek ptr = fmap Cons (peek (castPtr ptr))


-- | Representation of a shape: @()@ for 'Z', and for @tail :. head@ a pair
-- whose first component represents @tail@ and second component @head@.
type family Repr (f :: * -> *) sh
type instance Repr f Z = ()
type instance Repr f (tail :. head) = (Repr f tail, MultiValue.Repr f head)

-- | 'MultiValue' support for shapes and indices. The right-hand sides
-- 'value', 'addPhis' and 'phis' are the module-level functions. 'undef' and
-- 'zero' fill every component with the scalar 'MultiValue.undef' /
-- 'MultiValue.zero' via 'constant'.
instance (C sh) => MultiValue.C (T tag sh) where
   type Repr f (T tag sh) = Repr f sh
   cons = value
   undef = constant $ MultiValue.undef
   zero = constant $ MultiValue.zero
   addPhis = addPhis
   phis = phis

-- | Shape operations for cubic shapes. The equality constraint lets the
-- instance match any tag and then fixes it to 'ShapeTag'. Every right-hand
-- side delegates to the module-level function of the same purpose.
instance (tag ~ ShapeTag, C sh) => Shape.C (T tag sh) where
   type Index (T tag sh) = Index sh
   size = fromIntegral . size . decons
   sizeCode = computeSize
   intersectCode = Expr.unliftM2 intersect
   -- Returns the pair (total size of sh, linear offset of ix within sh).
   flattenIndexRec sh ix =
      -- a joint implementation would not be more efficient
      liftM2 (,)
         (computeSize sh)
         (flattenIndex sh ix)
   iterator = iterator
   loop = loop


-- | In-memory LLVM struct layout of a shape: the last component first,
-- followed by the struct of the remaining shape. Instances exist only for
-- 'Z' and 'Index.Int' components.
type family Struct sh
type instance Struct Z = ()
type instance Struct (sh :. Index.Int) = (Shape.Size, Struct sh)

-- | Shapes are loaded from and stored to structs of 'Shape.Size' fields via
-- 'loadMultiValue' and 'storeMultiValue'.
instance
   (C sh, LLVM.StructFields (Struct sh)) =>
      MultiValueMemory.C (T tag sh) where
   type Struct (T tag sh) = LLVM.Struct (Struct sh)
   load = loadMultiValue
   store = storeMultiValue

-- | Load a shape from a pointer to its 'Struct'. The struct pointer is
-- reinterpreted as a pointer to 'Shape.Size' and read with 'load'.
-- The 'StructFieldsProp' match provides the 'LLVM.StructFields' constraint
-- needed by 'castPtrValue'; the shape type is fixed by the result type.
loadMultiValue ::
   (C sh) =>
   LLVM.Value (Ptr (LLVM.Struct (Struct sh))) ->
   LLVM.CodeGenFunction r (MultiValue.T (T tag sh))
loadMultiValue ptr =
   withStructFieldsPropFF $ \StructFieldsProp -> do
      sizePtr <- castPtrValue ptr
      load sizePtr

-- | Store a shape to a pointer to its 'Struct'. This is the counterpart of
-- 'loadMultiValue'; here the shape type is taken from the stored value.
storeMultiValue ::
   (C sh) =>
   MultiValue.T (T tag sh) ->
   LLVM.Value (Ptr (LLVM.Struct (Struct sh))) -> LLVM.CodeGenFunction r ()
storeMultiValue x ptr =
   case structFieldsPropF x of
      StructFieldsProp -> do
         sizePtr <- castPtrValue ptr
         store x sizePtr


-- | Helper for 'flattenIndex'.
newtype FlattenIndex r sh =
   FlattenIndex {
      runFlattenIndex ::
         MultiValue.T (Shape sh) -> MultiValue.T (Index sh) ->
         LLVM.CodeGenFunction r (LLVM.Value Shape.Size)
   }

-- | Generate code for the linear offset of an index within a shape.
-- For shape @sh :. s@ and index @ix :. i@ the offset is
-- @i + s * flattenIndex sh ix@, so the last component varies fastest.
-- The rank-zero offset is 0. No bounds check is emitted.
flattenIndex ::
   (C sh) =>
   MultiValue.T (Shape sh) -> MultiValue.T (Index sh) ->
   LLVM.CodeGenFunction r (LLVM.Value Shape.Size)
flattenIndex =
   runFlattenIndex $
   switchInt
      (FlattenIndex $ \_ _ -> return A.zero)
      (FlattenIndex $
         switchR $ \outerShape (MultiValue.Cons extent) ->
         switchR $ \outerIndex (MultiValue.Cons k) -> do
            outerOffset <- flattenIndex outerShape outerIndex
            scaled <- A.mul extent outerOffset
            A.add k scaled)


-- | Helper for 'rank'.
newtype Rank sh = Rank {runRank :: sh -> Int}

-- | Number of dimensions of a shape or index type.
--
-- The result depends only on the type. The argument is matched with an
-- irrefutable pattern and never evaluated, so @rank undefined@ works at
-- every shape type. The 'St.Storable' instance relies on this, because
-- 'St.sizeOf' is conventionally applied to 'undefined'. A strict pattern
-- would force the (strict) ':.' constructor there and crash.
rank :: (C sh) => sh -> Int
rank =
   runRank $
   switch
      (Rank $ const 0)
      (Rank $ \ ~(sh :. _s) -> succ (rank sh))


-- | Helper for 'peek'.
newtype Peek sh = Peek {runPeek :: Ptr Shape.Size -> IO sh}

-- | Read a shape from consecutive 'Shape.Size' cells. The last component is
-- at the pointer itself; the remaining shape follows in the next cells in
-- the same layout.
peek :: (C sh) => Ptr Shape.Size -> IO sh
peek =
   runPeek $
   switchInt
      (Peek $ \_ptr -> return Z)
      (Peek $ \ptr -> do
         n <- St.peek ptr
         rest <- peek (advancePtr ptr 1)
         return (rest :. Index.Int n))


-- | Helper for 'poke'.
newtype Poke sh = Poke {runPoke :: Ptr Shape.Size -> sh -> IO ()}

-- | Write a shape to consecutive 'Shape.Size' cells in the layout read by
-- 'peek': the last component first, then the remaining shape.
poke :: (C sh) => Ptr Shape.Size -> sh -> IO ()
poke =
   runPoke $
   switchInt
      (Poke $ \_ptr _z -> return ())
      (Poke $ \ptr (rest :. Index.Int n) -> do
         St.poke ptr n
         poke (advancePtr ptr 1) rest)


-- | Reinterpret a pointer to a shape struct as a pointer to 'Shape.Size'
-- (LLVM @bitcast@), so that 'load' and 'store' can walk the fields
-- element by element.
castPtrValue ::
   (LLVM.StructFields sh) =>
   LLVM.Value (Ptr (LLVM.Struct sh)) ->
   LLVM.CodeGenFunction r (LLVM.Value (Ptr Shape.Size))
castPtrValue structPtr = LLVM.bitcast structPtr

-- | Helper for 'load'.
newtype Load r tag sh =
   Load {
      runLoad ::
         LLVM.Value (Ptr Shape.Size) ->
         LLVM.CodeGenFunction r (MultiValue.T (T tag sh))
   }

-- | Generate code that loads a shape from consecutive 'Shape.Size' cells.
-- The last component is read from the pointer itself, and the remaining
-- shape from the following cells (the same layout as 'peek').
load ::
   (C sh) =>
   LLVM.Value (Ptr Shape.Size) ->
   LLVM.CodeGenFunction r (MultiValue.T (T tag sh))
load =
   runLoad $
   switchInt
      (Load $ \_ptr -> return z)
      (Load $ \ptr -> do
         n <- LLVM.load ptr
         nextPtr <- A.advanceArrayElementPtr ptr
         rest <- load nextPtr
         return (rest #:. MultiValue.Cons n))


-- | Helper for 'store'.
newtype Store r tag sh =
   Store {
      runStore ::
         MultiValue.T (T tag sh) ->
         LLVM.Value (Ptr Shape.Size) ->
         LLVM.CodeGenFunction r ()
   }

-- | Generate code that stores a shape to consecutive 'Shape.Size' cells in
-- the layout read by 'load': the last component first, then the rest.
store ::
   (C sh) =>
   MultiValue.T (T tag sh) ->
   LLVM.Value (Ptr Shape.Size) ->
   LLVM.CodeGenFunction r ()
store =
   runStore $
   switchInt
      (Store $ \_ _ -> return ())
      (Store $ switchR $ \rest (MultiValue.Cons n) ptr -> do
         LLVM.store n ptr
         nextPtr <- A.advanceArrayElementPtr ptr
         store rest nextPtr)


-- | Helper for 'size'.
newtype Size sh =
   Size {
      runSize :: sh -> Shape.Size
   }

-- | Number of elements of a plain shape: the product of all components.
-- The rank-zero shape has size 1.
size :: (C sh) => sh -> Shape.Size
size =
   runSize $
   switchInt
      (Size $ const 1)
      (Size $ \(rest :. Index.Int n) -> n * size rest)


-- | Helper for 'computeSize'.
newtype ComputeSize r sh =
   ComputeSize {
      runComputeSize ::
         MultiValue.T (Shape sh) ->
         LLVM.CodeGenFunction r (LLVM.Value Shape.Size)
   }

-- | Generate code for the number of elements of a shape: the product of all
-- components, 1 for the rank-zero shape. This is the code-level counterpart
-- of 'size'.
computeSize ::
   (C sh) =>
   MultiValue.T (Shape sh) ->
   LLVM.CodeGenFunction r (LLVM.Value Shape.Size)
computeSize =
   runComputeSize $
   switchInt
      (ComputeSize $ \_z -> return A.one)
      (ComputeSize $ switchR $ \rest (MultiValue.Cons n) -> do
         restSize <- computeSize rest
         A.mul n restSize)


-- | Helper for 'constant'.
newtype Constant val tag sh =
   Constant {getConstant :: val Index.Int -> val (T tag sh)}

-- | A shape or index whose components all equal the given value.
constant :: (C sh, Expr.Value val) => val Index.Int -> val (T tag sh)
constant =
   getConstant $
   switchInt
      (Constant $ \_ -> z)
      (Constant $ \n -> cons (constant n) n)


-- | Helper for 'addPhis'.
newtype AddPhis r tag sh =
   AddPhis {
      runAddPhis ::
         LLVM.BasicBlock ->
         MultiValue.T (T tag sh) ->
         MultiValue.T (T tag sh) ->
         LLVM.CodeGenFunction r ()
   }

-- | Component-wise 'MultiValue.addPhis': it is applied to the last
-- components first, then recursively to the remaining shapes. Nothing
-- happens for rank zero.
addPhis ::
   (C sh) =>
   LLVM.BasicBlock ->
   MultiValue.T (T tag sh) ->
   MultiValue.T (T tag sh) ->
   LLVM.CodeGenFunction r ()
addPhis =
   runAddPhis $
   switchInt
      (AddPhis $ \_bb _x _y -> return ())
      (AddPhis $ \bb ->
       switchR $ \restX nX ->
       switchR $ \restY nY -> do
          MultiValue.addPhis bb nX nY
          addPhis bb restX restY)


-- | Helper for 'phis'.
newtype Phis r tag sh =
   Phis {
      runPhis ::
         LLVM.BasicBlock ->
         MultiValue.T (T tag sh) ->
         LLVM.CodeGenFunction r (MultiValue.T (T tag sh))
   }

-- | Component-wise 'MultiValue.phis'. Code for the remaining shape is
-- generated before code for the last component. The rank-zero value is
-- returned unchanged.
phis ::
   (C sh) =>
   LLVM.BasicBlock ->
   MultiValue.T (T tag sh) ->
   LLVM.CodeGenFunction r (MultiValue.T (T tag sh))
phis =
   runPhis $
   switchInt
      (Phis $ const return)
      (Phis $ \bb ->
       switchR $ \rest n -> do
          restPhis <- phis bb rest
          nPhi <- MultiValue.phis bb n
          return (restPhis #:. nPhi))


-- | Helper for 'iterator'.
newtype Iterator r sh =
   Iterator {
      runIterator ::
         MultiValue.T (Shape sh) -> Iter.T r (MultiValue.T (Index sh))
   }

-- | Iterate over all indices of a shape. For @sh :. n@ the last component
-- counts up from zero while it is below @n@, combined via 'Iter.cartesian'
-- with the indices of @sh@.
--
-- The rank-zero shape has exactly one index, the empty index 'z'. This is
-- consistent with 'size' returning 1 and 'loop' running its body once for
-- 'Z'. It must therefore yield a singleton rather than an empty iterator;
-- otherwise the Cartesian product in the recursive case would be empty for
-- every rank.
iterator ::
   (C sh) =>
   MultiValue.T (Shape sh) -> Iter.T r (MultiValue.T (Index sh))
iterator =
   runIterator $
   switchInt
      (Iterator $ \ _z -> Iter.singleton z)
      (Iterator $ switchR $ \sh n ->
       fmap (\(ix,i) -> ix#:.i) $
       Iter.cartesian
         (iterator sh)
         (IterMV.takeWhile (MultiValue.cmp LLVM.CmpGT n) $
          Iter.iterate MultiValue.inc MultiValue.zero))


-- | Helper for 'loop'.
newtype Loop r state sh =
   Loop {
      runLoop ::
         (MultiValue.T (Index sh) ->
          state ->
          LLVM.CodeGenFunction r state) ->
         MultiValue.T (Shape sh) ->
         state ->
         LLVM.CodeGenFunction r state
   }

-- | Generate nested loops that run @body@ once for every index of the shape
-- and thread a state through the iterations. The recursive call handles the
-- outer dimensions; the innermost 'C.fixedLengthLoop' handles the last
-- component, which therefore varies fastest. For the rank-zero shape the
-- body runs once with the empty index 'z'.
loop ::
   (C sh, Loop.Phi state) =>
   (MultiValue.T (Index sh) ->
    state ->
    LLVM.CodeGenFunction r state) ->
   MultiValue.T (Shape sh) ->
   state ->
   LLVM.CodeGenFunction r state
loop =
   runLoop $
   switchInt
      (Loop $ \body _z -> body z)
      (Loop $ \body -> switchR $ \outerShape (MultiValue.Cons n) ->
         loop
            (\ix state0 ->
               fmap fst $
               C.fixedLengthLoop n (state0, A.zero) $ \(state, k) -> do
                  state1 <- body (ix #:. MultiValue.Cons k) state
                  k1 <- A.inc k
                  return (state1, k1))
            outerShape)