{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE FlexibleContexts #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Data.Array.Knead.Shape (
   C(..), Index,
   Size,
   value,
   paramWith,
   load,
   intersect,
   offset,

   ZeroBased(ZeroBased), zeroBased, zeroBasedSize,

   Range(Range), range, rangeFrom, rangeTo,
   Shifted(Shifted), shifted, shiftedOffset, shiftedSize,

   Scalar(..),
   Sequence(..),
   ) where

import qualified Data.Array.Knead.Expression as Expr
import qualified Data.Array.Knead.Parameter as Param
import Data.Array.Knead.Expression (Exp, )

import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Shape
         (Index, ZeroBased(ZeroBased), Range(Range), Shifted(Shifted))
import Data.Ix (Ix)

import qualified LLVM.Extra.Multi.Value.Memory as MultiMem
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Multi.Iterator as IterMV
import qualified LLVM.Extra.Iterator as Iter
import qualified LLVM.Extra.Arithmetic as A
import LLVM.Extra.Multi.Value (atom)

import qualified LLVM.Util.Loop as Loop
import qualified LLVM.Core as LLVM

import qualified Type.Data.Num.Decimal as TypeNum

import Foreign.Storable (Storable)
import Foreign.Ptr (Ptr)

import Data.Word (Word8, Word16, Word32, Word64)
import Data.Int (Int8, Int16, Int32, Int64)

import qualified Control.Monad.HT as Monad
import Control.Applicative ((<$>))


-- | Integer type of element counts and linear element offsets
-- in generated code (see 'size', 'sizeOffset' and 'offset').
type Size = Word64

-- | Embed a shape known in Haskell into an expression via 'MultiValue.cons'.
value :: (C sh, Expr.Value val) => sh -> val sh
value sh = Expr.lift0 (MultiValue.cons sh)

{- |
Like 'Param.withMulti', but the value function handed to the continuation
is post-composed with 'Expr.lift0',
so that it produces an expression of type @val b@.
-}
paramWith ::
   (Storable b, MultiMem.C b, Expr.Value val) =>
   Param.T p b ->
   (forall parameters.
    (Storable parameters, MultiMem.C parameters) =>
    (p -> parameters) ->
    (MultiValue.T parameters -> val b) ->
    a) ->
   a
paramWith p f =
   Param.withMulti p
      (\getParameters toValue ->
         f getParameters (\params -> Expr.lift0 (toValue params)))

-- | Load a shape from memory using 'MultiMem.load'.
-- The first argument only fixes the shape type @sh@; its value is ignored.
load ::
   (MultiMem.C sh) =>
   f sh -> LLVM.Value (Ptr (MultiMem.Struct sh)) ->
   LLVM.CodeGenFunction r (MultiValue.T sh)
load _proxy ptr = MultiMem.load ptr

-- | Intersection of two shape expressions, generated by 'intersectCode'.
intersect :: (C sh) => Exp sh -> Exp sh -> Exp sh
intersect a b = Expr.liftM2 intersectCode a b

-- | Linear offset of an index within a shape.
-- This is the offset function returned by 'sizeOffset', applied to the index.
offset ::
   (C sh) =>
   MultiValue.T sh -> MultiValue.T (Index sh) ->
   LLVM.CodeGenFunction r (LLVM.Value Size)
offset sh ix = do
   (_size, offsetOf) <- sizeOffset sh
   offsetOf ix

{- |
Shapes for which LLVM code can be generated:
counting elements, mapping indices to linear offsets,
intersecting two shapes and enumerating all indices.
-}
class (MultiValue.C sh, MultiValue.C (Index sh), Shape.Indexed sh) => C sh where
   {-
   It would be better to restrict zipWith to matching shapes
   and turn shape intersection into a bound check.
   -}
   -- | Code for the intersection of two shapes, see 'intersect'.
   intersectCode ::
      MultiValue.T sh -> MultiValue.T sh ->
      LLVM.CodeGenFunction r (MultiValue.T sh)
   -- | Number of elements of the shape.
   size :: MultiValue.T sh -> LLVM.CodeGenFunction r (LLVM.Value Size)
   {- |
   Result is @(size, offset)@.
   @size@ must equal the result of 'size'.
   @offset@ maps an index to its linear position, cf. the function 'offset'.
   We use this for sharing intermediate results.
   -}
   sizeOffset ::
      MultiValue.T sh ->
      LLVM.CodeGenFunction r
         (LLVM.Value Size,
          MultiValue.T (Index sh) -> LLVM.CodeGenFunction r (LLVM.Value Size))
   -- | Enumerate all indices of the shape.
   iterator :: (Index sh ~ ix) => MultiValue.T sh -> Iter.T r (MultiValue.T ix)
   {- |
   Thread a state through a body that is run for every index of the shape.
   The default implementation traverses 'iterator'.
   -}
   loop ::
      (Index sh ~ ix, MultiValue.C ix, Loop.Phi state) =>
      (MultiValue.T ix -> state -> LLVM.CodeGenFunction r state) ->
      MultiValue.T sh -> state -> LLVM.CodeGenFunction r state
   loop f sh = Iter.mapState_ f (iterator sh)


-- | The unit shape has exactly one (unit) index, located at offset zero.
instance C () where
   intersectCode _ _ = return (MultiValue.cons ())
   size _unit = return A.one
   sizeOffset _unit = return (A.one, const (return A.zero))
   iterator unit = Iter.singleton unit
   -- A single iteration with the unit index.
   loop body unit = body unit


{- |
Shapes with a distinguished shape value 'scalar'
and a distinguished index 'zeroIndex'.
For @()@ both are the unit value.
-}
class C sh => Scalar sh where
   -- | The distinguished shape value.
   scalar :: (Expr.Value val) => val sh
   -- | The distinguished index.
   -- The argument only fixes the shape type; the @()@ instance ignores it.
   zeroIndex :: (Expr.Value val) => f sh -> val (Index sh)

-- | Both the shape and its only index are the unit value.
instance Scalar () where
   zeroIndex _shape = Expr.lift0 (MultiValue.Cons ())
   scalar = Expr.lift0 (MultiValue.Cons ())


{- |
Shapes that can be constructed from a single index value
and whose index type supports integer constants and addition.
For @'ZeroBased' n@ the index value becomes the size of the shape.
-}
class
   (C sh,
    MultiValue.IntegerConstant (Index sh),
    MultiValue.Additive (Index sh)) =>
      Sequence sh where
   -- | Build a shape from an index value.
   sequenceShapeFromIndex ::
      MultiValue.T (Index sh) -> LLVM.CodeGenFunction r (MultiValue.T sh)


{- |
Integer types whose values can be converted to 'Size' in generated code.
Used for shape sizes and index offsets.
The superclasses supply the arithmetic ('MultiValue.Additive'),
'MultiValue.min' / 'MultiValue.max' ('MultiValue.Real')
and constants ('MultiValue.IntegerConstant') needed by the shape instances.
-}
class
   (MultiValue.Additive n, MultiValue.Real n, MultiValue.IntegerConstant n) =>
      ToSize n where
   -- | Widen or reinterpret a value as 'Size'.
   toSize :: MultiValue.T n -> LLVM.CodeGenFunction r (LLVM.Value Size)

-- Narrower unsigned types are widened with 'LLVM.ext'.
instance ToSize Word8 where
   toSize (MultiValue.Cons w) = LLVM.ext w
instance ToSize Word16 where
   toSize (MultiValue.Cons w) = LLVM.ext w
instance ToSize Word32 where
   toSize (MultiValue.Cons w) = LLVM.ext w
-- 'Word64' is 'Size' already.
instance ToSize Word64 where
   toSize (MultiValue.Cons w) = return w
-- Signed types are zero-extended or reinterpreted bit by bit,
-- which gives the intended value only for non-negative inputs.
instance ToSize Int8 where
   toSize (MultiValue.Cons i) = LLVM.zext i
instance ToSize Int16 where
   toSize (MultiValue.Cons i) = LLVM.zext i
instance ToSize Int32 where
   toSize (MultiValue.Cons i) = LLVM.zext i
instance ToSize Int64 where
   toSize (MultiValue.Cons i) = LLVM.bitcast i


-- | Move the 'ZeroBased' wrapper outside of the 'MultiValue.T' wrapper.
unzipZeroBased :: MultiValue.T (ZeroBased n) -> ZeroBased (MultiValue.T n)
unzipZeroBased v =
   case v of
      MultiValue.Cons (ZeroBased n) -> ZeroBased (MultiValue.Cons n)

-- | The size of a zero-based shape.
zeroBasedSize :: (Expr.Value val) => val (ZeroBased n) -> val n
zeroBasedSize = Expr.lift1 (\sh -> Shape.zeroBasedSize (unzipZeroBased sh))

-- | The zero-based shape with the given size.
zeroBased :: (Expr.Value val) => val n -> val (ZeroBased n)
zeroBased =
   Expr.lift1 $ \v ->
      case v of
         MultiValue.Cons n -> MultiValue.Cons (ZeroBased n)

-- | A zero-based shape value wraps the representation of its size;
-- all operations are delegated to the size.
instance (MultiValue.C n) => MultiValue.C (ZeroBased n) where
   type Repr f (ZeroBased n) = ZeroBased (MultiValue.Repr f n)
   cons (ZeroBased n) = zeroBased $ MultiValue.cons n
   undef = zeroBased MultiValue.undef
   zero = zeroBased MultiValue.zero
   phis bb sh = zeroBased <$> MultiValue.phis bb (zeroBasedSize sh)
   addPhis bb a b = MultiValue.addPhis bb (zeroBasedSize a) (zeroBasedSize b)

-- Patterns for 'ZeroBased' mirror the pattern for its size.
type instance
   MultiValue.Decomposed f (ZeroBased pn) =
      ZeroBased (MultiValue.Decomposed f pn)
type instance
   MultiValue.PatternTuple (ZeroBased pn) =
      ZeroBased (MultiValue.PatternTuple pn)

instance (MultiValue.Compose n) => MultiValue.Compose (ZeroBased n) where
   type Composed (ZeroBased n) = ZeroBased (MultiValue.Composed n)
   compose (ZeroBased n) = zeroBased $ MultiValue.compose n

instance (MultiValue.Decompose pn) => MultiValue.Decompose (ZeroBased pn) where
   decompose (ZeroBased p) sh =
      fmap (MultiValue.decompose p) (unzipZeroBased sh)

instance (Expr.Compose n) => Expr.Compose (ZeroBased n) where
   type Composed (ZeroBased n) = ZeroBased (Expr.Composed n)
   compose (ZeroBased n) = Expr.lift1 zeroBased $ Expr.compose n

instance (Expr.Decompose pn) => Expr.Decompose (ZeroBased pn) where
   decompose (ZeroBased p) sh =
      ZeroBased (Expr.decompose p (zeroBasedSize sh))

-- | In memory a zero-based shape is stored exactly like its size.
instance (MultiMem.C n) => MultiMem.C (ZeroBased n) where
   type Struct (ZeroBased n) = MultiMem.Struct n
   decompose s = zeroBased <$> MultiMem.decompose s
   compose sh = MultiMem.compose (zeroBasedSize sh)

{- |
Array dimensions and indices cannot be negative,
but index computations may temporarily yield negative values,
or we may want to add negative offsets to indices.

So it would probably be better to have @Index (ZeroBased Word64) = Int64@.
However, this is not possible.
Maybe we need an additional zero-based shape type for unsigned array sizes.
-}
instance
      (Integral n, ToSize n, MultiValue.Comparison n) => C (ZeroBased n) where
   -- The common part of two zero-based shapes is the shorter one.
   intersectCode sha shb =
      fmap zeroBased $ MultiValue.min (zeroBasedSize sha) (zeroBasedSize shb)
   size sh = toSize (zeroBasedSize sh)
   -- An index is its own offset, so 'toSize' serves as the offset function.
   sizeOffset sh = do
      n <- toSize (zeroBasedSize sh)
      return (n, toSize)
   -- Count up from zero, taking as many indices as the size.
   iterator sh =
      IterMV.take (zeroBasedSize sh) $
         Iter.iterate MultiValue.inc MultiValue.zero

instance
   (Integral n, ToSize n, MultiValue.Comparison n) =>
      Sequence (ZeroBased n) where
   -- The index value becomes the size of the new shape.
   sequenceShapeFromIndex ix = return (zeroBased ix)


-- | A range whose lower and upper bound are the same value.
singletonRange :: n -> Range n
singletonRange x = Range x x

{- |
Element count of an inclusive range, computed as @to - from + 1@
and converted with 'toSize'.

NOTE(review): the result is not clamped at zero.
For @to < from - 1@ (e.g. after intersecting disjoint ranges),
the difference is negative or wraps around,
and the resulting size is meaningless.
-}
rangeSize ::
   (ToSize n) =>
   Range (MultiValue.T n) -> LLVM.CodeGenFunction r (LLVM.Value Size)
rangeSize (Range from to) = do
   diff <- MultiValue.sub to from
   count <- MultiValue.inc diff
   toSize count

-- | Move the 'Range' constructor outside of the 'MultiValue.T' wrapper.
unzipRange :: MultiValue.T (Range n) -> Range (MultiValue.T n)
unzipRange v =
   case v of
      MultiValue.Cons (Range from to) ->
         Range (MultiValue.Cons from) (MultiValue.Cons to)

-- | Combine lower and upper bound into a range value.
zipRange :: MultiValue.T n -> MultiValue.T n -> MultiValue.T (Range n)
zipRange (MultiValue.Cons lower) (MultiValue.Cons upper) =
   MultiValue.Cons $ Range lower upper


-- | Lower bound of a range.
rangeFrom :: (Expr.Value val) => val (Range n) -> val n
rangeFrom = Expr.lift1 (\rng -> Shape.rangeFrom (unzipRange rng))

-- | Upper bound (inclusive) of a range.
rangeTo :: (Expr.Value val) => val (Range n) -> val n
rangeTo = Expr.lift1 (\rng -> Shape.rangeTo (unzipRange rng))

-- | Range from the given lower bound to the given upper bound.
range :: (Expr.Value val) => val n -> val n -> val (Range n)
range = Expr.lift2 zipRange


-- | A range value is a pair of bound values; operations act on both bounds.
instance (MultiValue.C n) => MultiValue.C (Range n) where
   type Repr f (Range n) = Range (MultiValue.Repr f n)
   cons (Range from to) = zipRange (MultiValue.cons from) (MultiValue.cons to)
   undef = MultiValue.compose (singletonRange MultiValue.undef)
   zero = MultiValue.compose (singletonRange MultiValue.zero)
   phis bb rng =
      case unzipRange rng of
         Range lower upper -> do
            lower' <- MultiValue.phis bb lower
            upper' <- MultiValue.phis bb upper
            return (zipRange lower' upper')
   addPhis bb a b =
      case (unzipRange a, unzipRange b) of
         (Range lowerA upperA, Range lowerB upperB) -> do
            MultiValue.addPhis bb lowerA lowerB
            MultiValue.addPhis bb upperA upperB

-- Patterns for 'Range' consist of one pattern per bound.
type instance
   MultiValue.Decomposed f (Range pn) =
      Range (MultiValue.Decomposed f pn)
type instance
   MultiValue.PatternTuple (Range pn) =
      Range (MultiValue.PatternTuple pn)

instance (MultiValue.Compose n) => MultiValue.Compose (Range n) where
   type Composed (Range n) = Range (MultiValue.Composed n)
   compose (Range lower upper) =
      let lowerValue = MultiValue.compose lower
          upperValue = MultiValue.compose upper
      in  zipRange lowerValue upperValue

instance (MultiValue.Decompose pn) => MultiValue.Decompose (Range pn) where
   decompose (Range pLower pUpper) rng =
      case unzipRange rng of
         Range lower upper ->
            Range
               (MultiValue.decompose pLower lower)
               (MultiValue.decompose pUpper upper)

-- | In memory a range is a two-field struct of lower and upper bound.
instance (MultiMem.C n) => MultiMem.C (Range n) where
   type Struct (Range n) = PairStruct n
   decompose s = uncurry zipRange <$> decomposeGen s
   compose rng =
      case unzipRange rng of
         Range lower upper -> composeGen lower upper

instance (Ix n, ToSize n, MultiValue.Comparison n) => C (Range n) where
   -- Keep the larger lower bound and the smaller upper bound.
   intersectCode =
      MultiValue.modifyF2 (singletonRange atom) (singletonRange atom) $
         \(Range fromN toN) (Range fromM toM) -> do
            lower <- MultiValue.max fromN fromM
            upper <- MultiValue.min toN toM
            return (Range lower upper)
   size rng = rangeSize (unzipRange rng)
   -- Offsets are counted from the lower bound.
   sizeOffset rngValue =
      case unzipRange rngValue of
         rng@(Range from _to) -> do
            n <- rangeSize rng
            return (n, \i -> MultiValue.sub i from >>= toSize)
   -- Count up from the lower bound as long as the upper bound is not exceeded.
   iterator rngValue =
      case MultiValue.decompose (singletonRange atom) rngValue of
         Range from to ->
            IterMV.takeWhile (\i -> MultiValue.cmp LLVM.CmpGE to i) $
               Iter.iterate MultiValue.inc from


-- | A 'Shifted' with both fields set to the same value,
-- e.g. a decomposition pattern or 'MultiValue.undef'.
singletonShifted :: n -> Shifted n
singletonShifted x = Shifted x x


-- | Move the 'Shifted' constructor outside of the 'MultiValue.T' wrapper.
unzipShifted :: MultiValue.T (Shifted n) -> Shifted (MultiValue.T n)
unzipShifted v =
   case v of
      MultiValue.Cons (Shifted start len) ->
         Shifted (MultiValue.Cons start) (MultiValue.Cons len)

-- | Combine offset and size into a shifted shape value.
zipShifted :: MultiValue.T n -> MultiValue.T n -> MultiValue.T (Shifted n)
zipShifted (MultiValue.Cons start) (MultiValue.Cons len) =
   MultiValue.Cons $ Shifted start len


-- | Offset (first index) of a shifted shape.
shiftedOffset :: (Expr.Value val) => val (Shifted n) -> val n
shiftedOffset = Expr.lift1 (\sh -> Shape.shiftedOffset (unzipShifted sh))

-- | Number of indices of a shifted shape.
shiftedSize :: (Expr.Value val) => val (Shifted n) -> val n
shiftedSize = Expr.lift1 (\sh -> Shape.shiftedSize (unzipShifted sh))

-- | Shifted shape from offset and size.
shifted :: (Expr.Value val) => val n -> val n -> val (Shifted n)
shifted = Expr.lift2 zipShifted


-- | A shifted shape value is a pair of offset and size;
-- operations act on both components.
instance (MultiValue.C n) => MultiValue.C (Shifted n) where
   type Repr f (Shifted n) = Shifted (MultiValue.Repr f n)
   cons (Shifted start len) =
      zipShifted (MultiValue.cons start) (MultiValue.cons len)
   undef = MultiValue.compose (singletonShifted MultiValue.undef)
   zero = MultiValue.compose (singletonShifted MultiValue.zero)
   phis bb sh =
      case unzipShifted sh of
         Shifted start len -> do
            start' <- MultiValue.phis bb start
            len' <- MultiValue.phis bb len
            return (zipShifted start' len')
   addPhis bb a b =
      case (unzipShifted a, unzipShifted b) of
         (Shifted startA lenA, Shifted startB lenB) -> do
            MultiValue.addPhis bb startA startB
            MultiValue.addPhis bb lenA lenB

-- Patterns for 'Shifted' consist of one pattern for offset and one for size.
type instance
   MultiValue.Decomposed f (Shifted pn) =
      Shifted (MultiValue.Decomposed f pn)
type instance
   MultiValue.PatternTuple (Shifted pn) =
      Shifted (MultiValue.PatternTuple pn)

instance (MultiValue.Compose n) => MultiValue.Compose (Shifted n) where
   type Composed (Shifted n) = Shifted (MultiValue.Composed n)
   compose (Shifted start len) =
      let startValue = MultiValue.compose start
          lenValue = MultiValue.compose len
      in  zipShifted startValue lenValue

instance (MultiValue.Decompose pn) => MultiValue.Decompose (Shifted pn) where
   decompose (Shifted pStart pLen) sh =
      case unzipShifted sh of
         Shifted start len ->
            Shifted
               (MultiValue.decompose pStart start)
               (MultiValue.decompose pLen len)

-- | In memory a shifted shape is a two-field struct of offset and size.
instance (MultiMem.C n) => MultiMem.C (Shifted n) where
   type Struct (Shifted n) = PairStruct n
   decompose s = uncurry zipShifted <$> decomposeGen s
   compose sh =
      case unzipShifted sh of
         Shifted start len -> composeGen start len

{- |
In-memory layout of shapes made of two values of type @n@
('Range', 'Shifted'): an LLVM struct with two fields.
-}
type PairStruct n = LLVM.Struct (MultiMem.Struct n, (MultiMem.Struct n, ()))

-- | Extract both fields of a 'PairStruct' and convert them to values.
decomposeGen ::
   (MultiMem.C n) =>
   LLVM.Value (PairStruct n) ->
   LLVM.CodeGenFunction r (MultiValue.T n, MultiValue.T n)
decomposeGen struct = do
   field0 <- LLVM.extractvalue struct TypeNum.d0 >>= MultiMem.decompose
   field1 <- LLVM.extractvalue struct TypeNum.d1 >>= MultiMem.decompose
   return (field0, field1)

-- | Convert both values to their memory representation
-- and insert them into an otherwise undefined 'PairStruct'.
composeGen ::
   (MultiMem.C n) =>
   MultiValue.T n -> MultiValue.T n ->
   LLVM.CodeGenFunction r (LLVM.Value (PairStruct n))
composeGen n m = do
   field0 <- MultiMem.compose n
   field1 <- MultiMem.compose m
   LLVM.insertvalue (LLVM.value LLVM.undef) field0 TypeNum.d0
      >>= \partial -> LLVM.insertvalue partial field1 TypeNum.d1


{- |
Shapes with an explicit start:
the indices run from @start@ up to, but excluding, @start + len@.
-}
instance (Integral n, ToSize n, MultiValue.Comparison n) => C (Shifted n) where
   {-
   The intersection starts at the larger start and ends at the smaller end.
   If the shapes are disjoint, the smaller end lies before the larger start.
   We clamp the end to the start, so that the result is an empty shape.
   Otherwise the length would be negative, or would wrap around for unsigned
   types, and 'toSize' would turn it into a huge 'size'.
   -}
   intersectCode =
      MultiValue.modifyF2 (singletonShifted atom) (singletonShifted atom) $
         \(Shifted startN lenN) (Shifted startM lenM) -> do
            start <- MultiValue.max startN startM
            endN <- MultiValue.add startN lenN
            endM <- MultiValue.add startM lenM
            end <- MultiValue.max start =<< MultiValue.min endN endM
            Shifted start <$> MultiValue.sub end start
   size = toSize . shiftedSize
   -- Offsets are relative to the start of the shape.
   sizeOffset shapeValue =
      case unzipShifted shapeValue of
         Shifted start len ->
            Monad.lift2 (,) (toSize len)
               (return $ \i -> toSize =<< MultiValue.sub i start)
   -- Take @len@ indices, counting up from @start@.
   iterator rngValue =
      case MultiValue.decompose (singletonShifted atom) rngValue of
         Shifted from len ->
            IterMV.take len $ Iter.iterate MultiValue.inc from



{- |
Two-dimensional shapes.
Offsets are row-major: the offset of @(i,j)@ is
@offset(i) * size(m) + offset(j)@.
-}
instance (C n, C m) => C (n,m) where
   intersectCode a b =
      case (MultiValue.unzip a, MultiValue.unzip b) of
         ((an,am), (bn,bm)) -> do
            cn <- intersectCode an bn
            cm <- intersectCode am bm
            return (MultiValue.zip cn cm)
   size nm =
      case MultiValue.unzip nm of
         (n,m) -> do
            sizeN <- size n
            sizeM <- size m
            A.mul sizeN sizeM
   sizeOffset nm =
      case MultiValue.unzip nm of
         (n,m) -> do
            (ns, iOffset) <- sizeOffset n
            (ms, jOffset) <- sizeOffset m
            total <- A.mul ns ms
            let offsetOf ij =
                   case MultiValue.unzip ij of
                      (i,j) -> do
                         il <- iOffset i
                         jl <- jOffset j
                         rowStart <- A.mul ms il
                         A.add jl rowStart
            return (total, offsetOf)
   iterator nm =
      case MultiValue.unzip nm of
         (n,m) ->
            fmap (uncurry MultiValue.zip)
               (Iter.cartesian (iterator n) (iterator m))
   loop code nm =
      case MultiValue.unzip nm of
         (n,m) ->
            let outer i = loop (\j -> code (MultiValue.zip i j)) m
            in  loop outer n

{- |
Three-dimensional shapes.
Offsets are row-major: the offset of @(i,j,k)@ is
@(offset(i) * size(m) + offset(j)) * size(l) + offset(k)@.
-}
instance (C n, C m, C l) => C (n,m,l) where
   intersectCode a b =
      case (MultiValue.unzip3 a, MultiValue.unzip3 b) of
         ((ai,aj,ak), (bi,bj,bk)) -> do
            ci <- intersectCode ai bi
            cj <- intersectCode aj bj
            ck <- intersectCode ak bk
            return (MultiValue.zip3 ci cj ck)
   size nml =
      case MultiValue.unzip3 nml of
         (n,m,l) -> do
            sizeN <- size n
            sizeM <- size m
            sizeL <- size l
            planeSize <- A.mul sizeM sizeL
            A.mul sizeN planeSize
   sizeOffset nml =
      case MultiValue.unzip3 nml of
         (n,m,l) -> do
            (ns, iOffset) <- sizeOffset n
            (ms, jOffset) <- sizeOffset m
            (ls, kOffset) <- sizeOffset l
            planeSize <- A.mul ms ls
            total <- A.mul ns planeSize
            let offsetOf ijk =
                   case MultiValue.unzip3 ijk of
                      (i,j,k) -> do
                         il <- iOffset i
                         jl <- jOffset j
                         kl <- kOffset k
                         rowIndex <- A.add jl =<< A.mul ms il
                         rowStart <- A.mul ls rowIndex
                         A.add kl rowStart
            return (total, offsetOf)
   iterator nml =
      case MultiValue.unzip3 nml of
         (n,m,l) ->
            let inner = Iter.cartesian (iterator m) (iterator l)
            in  (\(i,(j,k)) -> MultiValue.zip3 i j k) <$>
                   Iter.cartesian (iterator n) inner
   loop code nml =
      case MultiValue.unzip3 nml of
         (n,m,l) ->
            let innermost i j k = code (MultiValue.zip3 i j k)
                middle i j = loop (innermost i j) l
                outer i = loop (middle i) m
            in  loop outer n