{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE GADTs #-}
module Data.Array.Knead.Parameterized.Symbolic (
   Array,
   Exp,
   Sym.extendParameter,
   withExp,
   withExp2,
   withExp3,
   (Sym.!),
   Sym.fill,
   gather,
   backpermute,
   Sym.id,
   Sym.map,
   zipWith,
   Sym.fold1,
   Sym.fold1All,
   ) where

import qualified Data.Array.Knead.Parameterized.Private as Sym
import qualified Data.Array.Knead.Simple.Symbolic as Core
import Data.Array.Knead.Parameterized.Private (Array, gather, )

import qualified Data.Array.Knead.Index.Nested.Shape as Shape
import qualified Data.Array.Knead.Parameter as Param
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp, )

import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory
import qualified LLVM.Extra.Multi.Value as MultiValue

import Foreign.Storable (Storable, )

import Control.Applicative ((<*>), )

import Prelude (uncurry, ($), (.), )


{-
fromScalar ::
   (Storable a, MultiValueMemory.C a, MultiValue.C a) =>
   Param.T p a -> Array p Z a
fromScalar = Sym.fill (return Z)
-}


{- |
Resample an array through an index mapping.

The result has the shape described by the parameter @sh1@.
An index array is built by applying @f@ to every element of
@'Sym.id' sh1@. That index array is passed to 'gather', whose remaining
argument is the source array.

NOTE(review): assuming 'Sym.id' places each index of @sh1@ at its own
position (suggested by its name, not visible in this module), element
@ix@ of the result is element @f ix@ of the source.
-}
backpermute ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Storable sh1, MultiValueMemory.C sh1,
    MultiValue.C a) =>
   Param.T p sh1 ->
   (Exp ix1 -> Exp ix0) ->
   Array p sh0 a ->
   Array p sh1 a
backpermute sh1 f =
   let sourceIndices = Core.map f $ Sym.id sh1
   in  gather sourceIndices


{- |
Combine two arrays of the same shape element by element.

Besides the two elements, the combining function receives the
expression of the parameter @d@.
The arrays are paired with 'Core.zip'. 'Sym.map' then splits each pair
again with 'Expr.unzip' before handing the components to @f@.
-}
zipWith ::
   (Shape.C sh, MultiValueMemory.C d, Storable d) =>
   (Exp d -> Exp a -> Exp b -> Exp c) ->
   Param.T p d -> Array p sh a -> Array p sh b -> Array p sh c
zipWith f d a b =
   Sym.map
      (\param pair ->
         -- lazy let-pattern: same strictness as 'uncurry'
         let (x, y) = Expr.unzip pair
         in  f param x y)
      d
      (Core.zip a b)


{- |
Lift a one-argument transformation of core arrays to parameterized arrays.

The transformation @f@ additionally receives the parameter @x@ as an
expression, obtained via 'Sym.expParam'.
The input array is converted with 'Sym.arrayHull', and
'Sym.mapHullWithExp' applies @f@ to it.
'Sym.runHull' then turns the outcome back into a parameterized array.
-}
withExp ::
   (Storable x, MultiValueMemory.C x) =>
   (Exp x -> Core.Array shb b -> Core.Array sha a) ->
   Param.T p x -> Array p shb b -> Array p sha a
withExp f x arr =
   Sym.runHull (Sym.mapHullWithExp f (Sym.expParam x) (Sym.arrayHull arr))

{- |
Two-array counterpart of 'withExp'.

@f@ receives the parameter @x@ as an expression and the two core arrays.
The first array's hull is mapped with 'Sym.mapHullWithExp'.
The second hull is supplied with '<*>'.
The combined hull is converted back by 'Sym.runHull'.
-}
withExp2 ::
   (Storable x, MultiValueMemory.C x) =>
   (Exp x -> Core.Array sha a -> Core.Array shb b -> Core.Array shc c) ->
   Param.T p x -> Array p sha a -> Array p shb b -> Array p shc c
withExp2 f x a b =
   let withFirst = Sym.mapHullWithExp f (Sym.expParam x) (Sym.arrayHull a)
   in  Sym.runHull (withFirst <*> Sym.arrayHull b)

{- |
Three-array counterpart of 'withExp'.

@f@ receives the parameter @x@ as an expression and the three core arrays.
Each input is converted with 'Sym.arrayHull'.
The first hull is mapped with 'Sym.mapHullWithExp', and the other two are
applied in order with '<*>'.
'Sym.runHull' produces the resulting parameterized array.
-}
withExp3 ::
   (Storable x, MultiValueMemory.C x) =>
   (Exp x -> Core.Array sha a ->
    Core.Array shb b -> Core.Array shc c -> Core.Array shd d) ->
   Param.T p x -> Array p sha a ->
   Array p shb b -> Array p shc c -> Array p shd d
withExp3 f x a b c =
   let hullA = Sym.arrayHull a
       hullB = Sym.arrayHull b
       hullC = Sym.arrayHull c
   in  Sym.runHull
          (Sym.mapHullWithExp f (Sym.expParam x) hullA <*> hullB <*> hullC)