{-# LANGUAGE TypeFamilies #-}
module Data.Array.Knead.Simple.ShapeDependent where

import qualified Data.Array.Knead.Simple.Private as Core
import Data.Array.Knead.Simple.Private (Array(Array), )

import qualified Data.Array.Knead.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp, )

import qualified Control.Monad.HT as Monad
import Control.Monad ((<=<), )


{- |
Expose the shape of an array as array content:
the shape is read with 'Core.shape' and wrapped with 'Core.fromScalar',
giving an array of scalar shape @z@ with element type @sh@.
-}
shape :: (Core.C array, Shape.C sh, Shape.Scalar z) => array sh a -> array z sh
shape = Core.lift1 (\arr -> Core.fromScalar (Core.shape arr))

{- |
Re-index an array.
The result shape is computed from the source shape by @createShape@.
An element of the result is obtained by mapping its index
through @projectIndex@ and reading the source array at that index.
-}
backpermute ::
   (Core.C array,
    Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1) =>
   (Exp sh0 -> Exp sh1) ->
   (Exp ix1 -> Exp ix0) ->
   array sh0 a ->
   array sh1 a
backpermute createShape projectIndex =
   Core.lift1 (\(Array srcShape readSrc) ->
      Array (createShape srcShape)
         -- translate the result index first, then read the source element
         (\ix -> readSrc =<< Expr.unliftM1 projectIndex ix))

{- |
This sits between 'backpermute' and 'backpermute2'.
The result shape may depend on the shapes of both arrays,
but elements are only read from the first array;
the content of the second array is ignored.
This is needed when the second array
contributes nothing but a virtual dimension.
-}
backpermuteExtra ::
   (Core.C array,
    Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Shape.C sh,  Shape.Index sh  ~ ix) =>
   (Exp sh0 -> Exp sh1 -> Exp sh) ->
   (Exp ix -> Exp ix0) ->
   array sh0 a -> array sh1 b -> array sh a
backpermuteExtra newShape projectIndex =
   Core.lift2 (\(Array shapeA readA) (Array shapeB _) ->
      Array (newShape shapeA shapeB)
         -- only the first array is read; the second one supplies its shape
         (readA <=< Expr.unliftM1 projectIndex))

{- |
Combine two arrays element-wise after re-indexing both of them.
The result shape is @combineShape@ applied to the two source shapes.
For every result index, @projectIndex0@ and @projectIndex1@
select the positions to read in the first and second array, respectively,
and @f@ merges the two elements that were read.
-}
backpermute2 ::
   (Core.C array,
    Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Shape.C sh,  Shape.Index sh  ~ ix) =>
   (Exp sh0 -> Exp sh1 -> Exp sh) ->
   (Exp ix -> Exp ix0) ->
   (Exp ix -> Exp ix1) ->
   (Exp a -> Exp b -> Exp c) ->
   array sh0 a -> array sh1 b -> array sh c
backpermute2 combineShape projectIndex0 projectIndex1 f =
   Core.lift2 (\(Array shape0 read0) (Array shape1 read1) ->
      Array (combineShape shape0 shape1)
         (\ix -> do
            -- same effect order as before: first array, second array, combine
            x <- read0 =<< Expr.unliftM1 projectIndex0 ix
            y <- read1 =<< Expr.unliftM1 projectIndex1 ix
            Expr.unliftM2 f x y))

{- |
Build an array with 'Core.fill' from the expression @a@.
The result shape is @fsh@ applied to the shape of the given array;
only the shape of that array is inspected, not its elements.
-}
fill ::
   (Core.C array) =>
   (Exp sh0 -> Exp sh1) -> Exp b ->
   array sh0 a -> array sh1 b
fill fsh a =
   Core.lift1 (\src ->
      let targetShape = fsh (Core.shape src)
      in  Core.fill targetShape a)