{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE Rank2Types #-}
module Data.Array.Knead.Index.Nested.Shape (
   C(..),
   value,
   paramWith,
   load,
   intersect,
   flattenIndex,

   Range(..),
   Shifted(..),

   Scalar(..),
   ) where

import qualified Data.Array.Knead.Expression as Expr
import qualified Data.Array.Knead.Parameter as Param
import Data.Array.Knead.Expression (Exp, )

import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Control as C
import LLVM.Extra.Multi.Value (atom)
import LLVM.Extra.Monad (liftR2)

import qualified LLVM.Util.Loop as Loop
import qualified LLVM.Core as LLVM

import Foreign.Storable (Storable, )
import Foreign.Ptr (Ptr, )

import Data.Word (Word32, Word64)
import Data.Int (Int32, Int64)

import qualified Control.Monad.HT as Monad
import Control.Applicative ((<$>))


{- |
Embed a Haskell-side shape as a constant
into any container with an 'Expr.Value' instance.
-}
value :: (C sh, Expr.Value val) => sh -> val sh
value sh = Expr.lift0 (MultiValue.cons sh)

{- |
Pass a parameter to a continuation in its memory-storable form.
The continuation receives a function from @p@ to the stored parameter value
and a function that turns a @MultiValue.T parameters@ into a @val b@.
-}
paramWith ::
   (Storable b, MultiValueMemory.C b, Expr.Value val) =>
   Param.T p b ->
   (forall parameters.
    (Storable parameters,
     MultiValueMemory.C parameters) =>
    (p -> parameters) ->
    (MultiValue.T parameters -> val b) ->
    a) ->
   a
paramWith p f =
   Param.withMulti p
      (\extract toValue ->
         f extract (\params -> Expr.lift0 (toValue params)))

{- |
Load a shape from memory.
The first argument only fixes the type @sh@; its value is never inspected.
-}
load ::
   (MultiValueMemory.C sh) =>
   f sh -> LLVM.Value (Ptr (MultiValueMemory.Struct sh)) ->
   LLVM.CodeGenFunction r (MultiValue.T sh)
load _proxy ptr = MultiValueMemory.load ptr

{- |
Intersection of two shape expressions,
using the shape type's 'intersectCode'.
-}
intersect :: (C sh) => Exp sh -> Exp sh -> Exp sh
intersect a b = Expr.liftM2 intersectCode a b

{- |
Linear position of an index within a shape, as a 'Word32'.
This is the second component of 'flattenIndexRec';
the size component is discarded.
-}
flattenIndex ::
   (C sh) =>
   MultiValue.T sh -> MultiValue.T (Index sh) ->
   LLVM.CodeGenFunction r (LLVM.Value Word32)
flattenIndex sh ix = snd <$> flattenIndexRec sh ix

{- |
Shape types that can be handled in generated LLVM code.
A shape describes the extent of an array,
and its associated 'Index' type describes a position within that extent.
-}
class (MultiValue.C sh) => C sh where
   -- | Type of a single position within a shape of type @sh@.
   type Index sh :: *
   {- |
   Intersection of two shapes, computed in generated code.

   It would be better to restrict zipWith to matching shapes
   and turn shape intersection into a bound check.
   -}
   intersectCode ::
      MultiValue.T sh -> MultiValue.T sh ->
      LLVM.CodeGenFunction r (MultiValue.T sh)
   -- | Number of elements of a shape, computed in generated code.
   sizeCode ::
      MultiValue.T sh ->
      LLVM.CodeGenFunction r (LLVM.Value Word32)
   -- | Number of elements of a shape, computed on the Haskell side.
   size :: sh -> Int
   {- |
   Result is @(size, flattenedIndex)@.
   @size@ must equal the result of 'sizeCode'.
   We use this for sharing intermediate results.
   -}
   flattenIndexRec ::
      MultiValue.T sh -> MultiValue.T (Index sh) ->
      LLVM.CodeGenFunction r (LLVM.Value Word32, LLVM.Value Word32)
   {- |
   Generate code that runs the body once for every index of the shape,
   threading @state@ through the iterations and returning the final state.
   -}
   loop ::
      (Index sh ~ ix, Loop.Phi state) =>
      (MultiValue.T ix -> state -> LLVM.CodeGenFunction r state) ->
      MultiValue.T sh -> state -> LLVM.CodeGenFunction r state


{- |
The unit shape has exactly one element, whose index is @()@.
-}
instance C () where
   type Index () = ()
   -- There is only one unit shape, so the intersection is that shape.
   intersectCode _ _ = return (MultiValue.cons ())
   size = const 1
   sizeCode = const (return A.one)
   -- One element at flat position zero.
   flattenIndexRec _ _ = return (A.one, A.zero)
   -- A single iteration with the index @()@.
   loop body unit state = body unit state


{- |
Shapes that provide a constant shape value ('scalar')
and a constant index into it ('zeroIndex').
In this module only @()@ is an instance.
-}
class C sh => Scalar sh where
   -- | The constant shape.
   scalar :: (Expr.Value val) => val sh
   -- | The zero index of the shape type; the argument only fixes @sh@.
   zeroIndex :: (Expr.Value val) => f sh -> val (Index sh)

-- Both the unit shape and its only index are the unit value.
instance Scalar () where
   zeroIndex _proxy = Expr.lift0 (MultiValue.Cons ())
   scalar = Expr.lift0 (MultiValue.Cons ())


{- |
Loop over a primitive integer shape @n@:
the index starts at zero and is incremented by one
for @n@ iterations (see 'loopStart').
-}
loopPrimitive ::
   (MultiValue.Repr LLVM.Value j ~ LLVM.Value j,
    Num j, LLVM.IsConst j, LLVM.IsInteger j,
    LLVM.CmpRet j, LLVM.CmpResult j ~ Bool,
    MultiValue.Additive i, MultiValue.IntegerConstant i,
    Loop.Phi state) =>
   (MultiValue.T i -> state -> LLVM.CodeGenFunction r state) ->
   MultiValue.T j -> state -> LLVM.CodeGenFunction r state
loopPrimitive body shape initial =
   case shape of
      MultiValue.Cons count -> loopStart body count MultiValue.zero initial

{- |
Run the body @count@ times using 'C.fixedLengthLoop'.
The index begins at @firstIndex@ and is incremented by one after each
iteration. The state is threaded through, and the final state is returned.
-}
loopStart ::
   (Num j, LLVM.IsConst j, LLVM.IsInteger j,
    LLVM.CmpRet j, LLVM.CmpResult j ~ Bool,
    MultiValue.Additive i, MultiValue.IntegerConstant i,
    Loop.Phi state) =>
   (MultiValue.T i -> state -> LLVM.CodeGenFunction r state) ->
   LLVM.Value j ->
   MultiValue.T i -> state -> LLVM.CodeGenFunction r state
loopStart body count firstIndex initial =
   fmap fst (C.fixedLengthLoop count (initial, firstIndex) step)
   where
      -- Run the body first, then advance the index.
      step (state, index) = do
         state' <- body index state
         index' <- MultiValue.add index (MultiValue.fromInteger' 1)
         return (state', index')

{- |
A 'Word32' shape @n@ is looped over with the indices @0 .. n-1@.
Size and flat index are already 'Word32', so no conversion is needed.
-}
instance C Word32 where
   type Index Word32 = Word32
   intersectCode a b = MultiValue.min a b
   size n = fromIntegral n
   sizeCode shape =
      case shape of
         MultiValue.Cons len -> return len
   flattenIndexRec (MultiValue.Cons len) (MultiValue.Cons ix) =
      return (len, ix)
   loop body = loopPrimitive body

{- |
A 'Word64' shape. Size and flat index are truncated to 'Word32'.
-}
instance C Word64 where
   type Index Word64 = Word64
   intersectCode a b = MultiValue.min a b
   size n = fromIntegral n
   sizeCode (MultiValue.Cons len) = LLVM.trunc len
   flattenIndexRec (MultiValue.Cons len) (MultiValue.Cons ix) = do
      len32 <- LLVM.trunc len
      ix32 <- LLVM.trunc ix
      return (len32, ix32)
   loop body = loopPrimitive body


{- |
Array dimensions and indices cannot be negative,
but computations on indices may temporarily yield negative values,
or we may want to add negative offsets to indices.
Signed shape types like this one allow for that.

Open question: should we rather have @type Index Word64 = Int64@?
-}
instance C Int32 where
   type Index Int32 = Int32
   intersectCode a b = MultiValue.min a b
   size n = fromIntegral n
   -- Reinterpret the bits as 'Word32'; valid dimensions are non-negative.
   sizeCode (MultiValue.Cons len) = LLVM.bitcast len
   flattenIndexRec (MultiValue.Cons len) (MultiValue.Cons ix) = do
      len32 <- LLVM.bitcast len
      ix32 <- LLVM.bitcast ix
      return (len32, ix32)
   loop body = loopPrimitive body

{- |
An 'Int64' shape. Size and flat index are truncated to 'Word32'.
-}
instance C Int64 where
   type Index Int64 = Int64
   intersectCode a b = MultiValue.min a b
   size n = fromIntegral n
   sizeCode (MultiValue.Cons len) = LLVM.trunc len
   flattenIndexRec (MultiValue.Cons len) (MultiValue.Cons ix) = do
      len32 <- LLVM.trunc len
      ix32 <- LLVM.trunc ix
      return (len32, ix32)
   loop body = loopPrimitive body


{- |
'Range' denotes an inclusive range like
those of the Haskell 98 standard @Array@ type from the @array@ package.
E.g. the shape type @(Range Int32, Range Int64)@
is equivalent to the ix type @(Int32, Int64)@ for @Array@s.

The first field is the lower bound @from@ and the second the upper bound @to@.
Both bounds belong to the range, so it has @to - from + 1@ elements.
-}
data Range n = Range n n

{- |
A range whose lower and upper bound are the same value.
In this module it builds constant ranges (with 'MultiValue.undef' or
'MultiValue.zero') and decomposition patterns (with 'atom').
-}
singletonRange :: n -> Range n
singletonRange = \bound -> Range bound bound


{- |
Integer types whose values can be converted
to a 'Word32' size or offset in generated code.
-}
class
   (MultiValue.Additive n, MultiValue.Real n, MultiValue.IntegerConstant n) =>
      ToSize n where
   toSize :: MultiValue.T n -> LLVM.CodeGenFunction r (LLVM.Value Word32)

instance ToSize Word32 where
   toSize (MultiValue.Cons x) = LLVM.adapt x

instance ToSize Word64 where
   toSize (MultiValue.Cons x) = LLVM.adapt x

-- The bits are reinterpreted, so a negative value becomes a large 'Word32'.
instance ToSize Int32 where
   toSize (MultiValue.Cons x) = LLVM.bitcast x

instance ToSize Int64 where
   toSize (MultiValue.Cons x) = LLVM.trunc x

{- |
Number of elements of an inclusive range, @1 + (to - from)@,
converted to 'Word32' with 'toSize'.
-}
rangeSize ::
   (ToSize n) =>
   Range (MultiValue.T n) -> LLVM.CodeGenFunction r (LLVM.Value Word32)
rangeSize (Range from to) = do
   distance <- MultiValue.sub to from
   count <- MultiValue.add (MultiValue.fromInteger' 1) distance
   toSize count

{-
A 'Range' is represented field-wise:
each bound has its own representation and its own phi nodes.
-}
instance (MultiValue.C n) => MultiValue.C (Range n) where
   type Repr f (Range n) = Range (MultiValue.Repr f n)
   cons (Range lower upper) =
      MultiValue.compose
         (Range (MultiValue.cons lower) (MultiValue.cons upper))
   undef = MultiValue.compose (singletonRange MultiValue.undef)
   zero = MultiValue.compose (singletonRange MultiValue.zero)
   phis bb rng =
      case MultiValue.decompose (singletonRange atom) rng of
         Range lower upper -> do
            lower' <- MultiValue.phis bb lower
            upper' <- MultiValue.phis bb upper
            return (MultiValue.compose (Range lower' upper'))
   addPhis bb x y =
      case (MultiValue.decompose (singletonRange atom) x,
            MultiValue.decompose (singletonRange atom) y) of
         (Range xLower xUpper, Range yLower yUpper) -> do
            MultiValue.addPhis bb xLower yLower
            MultiValue.addPhis bb xUpper yUpper

-- A decomposition pattern for 'Range' is a 'Range' of patterns, one per bound,
-- as in @singletonRange atom@.
type instance
   MultiValue.Decomposed f (Range pn) =
      Range (MultiValue.Decomposed f pn)
type instance
   MultiValue.PatternTuple (Range pn) =
      Range (MultiValue.PatternTuple pn)

-- Compose each bound separately, then reassemble the 'Range'.
instance (MultiValue.Compose n) => MultiValue.Compose (Range n) where
   type Composed (Range n) = Range (MultiValue.Composed n)
   compose (Range lower upper) =
      case MultiValue.compose lower of
         MultiValue.Cons l ->
            case MultiValue.compose upper of
               MultiValue.Cons u -> MultiValue.Cons (Range l u)

-- Decompose each bound with its own sub-pattern.
instance (MultiValue.Decompose pn) => MultiValue.Decompose (Range pn) where
   decompose (Range patLower patUpper) (MultiValue.Cons (Range lower upper)) =
      let lowerPart = MultiValue.decompose patLower (MultiValue.Cons lower)
          upperPart = MultiValue.decompose patUpper (MultiValue.Cons upper)
      in  Range lowerPart upperPart

{- |
A 'Range' shape with bounds @from@ and @to@ has the indices @from .. to@.
Flat positions are measured relative to @from@.
-}
instance (Integral n, ToSize n) => C (Range n) where
   type Index (Range n) = n
   {-
   Intersection keeps the larger lower bound and the smaller upper bound.
   If the ranges do not overlap, the upper bound ends up below the lower bound.
   This is not clamped.
   -}
   intersectCode x y =
      MultiValue.modifyF2 (singletonRange atom) (singletonRange atom)
         (\(Range lowerX upperX) (Range lowerY upperY) -> do
            lower <- MultiValue.max lowerX lowerY
            upper <- MultiValue.min upperX upperY
            return (Range lower upper))
         x y
   sizeCode shape =
      rangeSize (MultiValue.decompose (singletonRange atom) shape)
   size (Range lower upper) = fromIntegral (upper - lower + 1)
   flattenIndexRec shape ix =
      case MultiValue.decompose (singletonRange atom) shape of
         rng@(Range lower _upper) -> do
            count <- rangeSize rng
            offset <- toSize =<< MultiValue.sub ix lower
            return (count, offset)
   loop body shape initial =
      case MultiValue.decompose (singletonRange atom) shape of
         rng@(Range lower _upper) -> do
            -- FIXME: rangeSize converts to Word32 which is overly restrictive here.
            count <- rangeSize rng
            loopStart body count lower initial


{- |
'Shifted' denotes a range defined by the start index and the length.
Its indices are @shiftedOffset .. shiftedOffset + shiftedSize - 1@.
-}
data Shifted n = Shifted {shiftedOffset, shiftedSize :: n}

{- |
A 'Shifted' whose offset and length are the same value.
Used to build constant values and decomposition patterns (with 'atom').
-}
singletonShifted :: n -> Shifted n
singletonShifted x = Shifted {shiftedOffset = x, shiftedSize = x}


{-
A 'Shifted' is represented field-wise:
offset and length each have their own representation and their own phi nodes.
-}
instance (MultiValue.C n) => MultiValue.C (Shifted n) where
   type Repr f (Shifted n) = Shifted (MultiValue.Repr f n)
   cons (Shifted offset len) =
      MultiValue.compose
         (Shifted (MultiValue.cons offset) (MultiValue.cons len))
   undef = MultiValue.compose (singletonShifted MultiValue.undef)
   zero = MultiValue.compose (singletonShifted MultiValue.zero)
   phis bb shape =
      case MultiValue.decompose (singletonShifted atom) shape of
         Shifted offset len -> do
            offset' <- MultiValue.phis bb offset
            len' <- MultiValue.phis bb len
            return (MultiValue.compose (Shifted offset' len'))
   addPhis bb x y =
      case (MultiValue.decompose (singletonShifted atom) x,
            MultiValue.decompose (singletonShifted atom) y) of
         (Shifted xOffset xLen, Shifted yOffset yLen) -> do
            MultiValue.addPhis bb xOffset yOffset
            MultiValue.addPhis bb xLen yLen

-- A decomposition pattern for 'Shifted' is a 'Shifted' of patterns,
-- one for the offset and one for the length, as in @singletonShifted atom@.
type instance
   MultiValue.Decomposed f (Shifted pn) =
      Shifted (MultiValue.Decomposed f pn)
type instance
   MultiValue.PatternTuple (Shifted pn) =
      Shifted (MultiValue.PatternTuple pn)

-- Compose offset and length separately, then reassemble the 'Shifted'.
instance (MultiValue.Compose n) => MultiValue.Compose (Shifted n) where
   type Composed (Shifted n) = Shifted (MultiValue.Composed n)
   compose (Shifted offset len) =
      case MultiValue.compose offset of
         MultiValue.Cons o ->
            case MultiValue.compose len of
               MultiValue.Cons l -> MultiValue.Cons (Shifted o l)

-- Decompose offset and length, each with its own sub-pattern.
instance (MultiValue.Decompose pn) => MultiValue.Decompose (Shifted pn) where
   decompose (Shifted patOffset patLen) (MultiValue.Cons (Shifted offset len)) =
      let offsetPart = MultiValue.decompose patOffset (MultiValue.Cons offset)
          lenPart = MultiValue.decompose patLen (MultiValue.Cons len)
      in  Shifted offsetPart lenPart

{- |
A 'Shifted' shape has the indices @offset .. offset + size - 1@.
Flat positions are measured relative to the offset.
-}
instance (Integral n, ToSize n) => C (Shifted n) where
   type Index (Shifted n) = n
   {-
   Intersection keeps the larger start and the smaller end.
   If the shapes do not overlap, the end lies before the start.
   The resulting length is then negative, or wraps around for unsigned @n@.
   It is not clamped to zero.
   -}
   intersectCode x y =
      MultiValue.modifyF2 (singletonShifted atom) (singletonShifted atom)
         (\(Shifted startX lenX) (Shifted startY lenY) -> do
            start <- MultiValue.max startX startY
            stopX <- MultiValue.add startX lenX
            stopY <- MultiValue.add startY lenY
            stop <- MultiValue.min stopX stopY
            len <- MultiValue.sub stop start
            return (Shifted start len))
         x y
   sizeCode shape =
      toSize (shiftedSize (MultiValue.decompose (singletonShifted atom) shape))
   size shape = fromIntegral (shiftedSize shape)
   flattenIndexRec shape ix =
      case MultiValue.decompose (singletonShifted atom) shape of
         Shifted start len -> do
            count <- toSize len
            offset <- toSize =<< MultiValue.sub ix start
            return (count, offset)
   loop body shape initial =
      case MultiValue.decompose (singletonShifted atom) shape of
         Shifted start len -> do
            count <- toSize len
            loopStart body count start initial



{- |
A pair of shapes forms a two-dimensional shape.
The second component varies fastest:
the flat position is @j + sizeOf m * i@.
-}
instance (C n, C m) => C (n,m) where
   type Index (n,m) = (Index n, Index m)
   intersectCode x y =
      case (MultiValue.unzip x, MultiValue.unzip y) of
         ((xn, xm), (yn, ym)) -> do
            n <- intersectCode xn yn
            m <- intersectCode xm ym
            return (MultiValue.zip n m)
   sizeCode shape =
      case MultiValue.unzip shape of
         (outer, inner) -> liftR2 A.mul (sizeCode outer) (sizeCode inner)
   size (outer, inner) = size outer * size inner
   flattenIndexRec shape ix =
      case (MultiValue.unzip shape, MultiValue.unzip ix) of
         ((outer, inner), (i, j)) -> do
            (outerSize, outerPos) <- flattenIndexRec outer i
            (innerSize, innerPos) <- flattenIndexRec inner j
            total <- A.mul outerSize innerSize
            rowStart <- A.mul innerSize outerPos
            pos <- A.add innerPos rowStart
            return (total, pos)
   loop body shape =
      case MultiValue.unzip shape of
         (outer, inner) ->
            -- The inner loop runs completely for each outer index.
            let loopInner i = loop (body . MultiValue.zip i) inner
            in  loop loopInner outer

{- |
A triple of shapes forms a three-dimensional shape.
The last component varies fastest:
the flat position is @k + sizeOf l * (j + sizeOf m * i)@.
-}
instance (C n, C m, C l) => C (n,m,l) where
   type Index (n,m,l) = (Index n, Index m, Index l)
   intersectCode x y =
      case (MultiValue.unzip3 x, MultiValue.unzip3 y) of
         ((xn, xm, xl), (yn, ym, yl)) -> do
            n <- intersectCode xn yn
            m <- intersectCode xm ym
            l <- intersectCode xl yl
            return (MultiValue.zip3 n m l)
   sizeCode shape =
      case MultiValue.unzip3 shape of
         (n, m, l) ->
            let sizeML = liftR2 A.mul (sizeCode m) (sizeCode l)
            in  liftR2 A.mul (sizeCode n) sizeML
   size (n, m, l) = size n * size m * size l
   flattenIndexRec shape ix =
      case (MultiValue.unzip3 shape, MultiValue.unzip3 ix) of
         ((n, m, l), (i, j, k)) -> do
            (sizeN, posI) <- flattenIndexRec n i
            (sizeM, posJ) <- flattenIndexRec m j
            scaledI <- A.mul sizeM posI
            posIJ <- A.add posJ scaledI
            (sizeL, posK) <- flattenIndexRec l k
            scaledIJ <- A.mul sizeL posIJ
            posIJK <- A.add posK scaledIJ
            sizeML <- A.mul sizeM sizeL
            total <- A.mul sizeN sizeML
            return (total, posIJK)
   loop body shape =
      case MultiValue.unzip3 shape of
         (n, m, l) ->
            -- Nest the loops so that the last component varies fastest.
            let loopK i j = loop (body . MultiValue.zip3 i j) l
                loopJ i = loop (loopK i) m
            in  loop loopJ n