{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE GADTs #-}
module Data.Array.Knead.Simple.Symbolic (
   Core.Array,
   Core.C(..),
   Exp,
   fix,
   shape,
   (Core.!),
   Core.the,
   Core.fromScalar,
   Core.fill,
   gather,
   backpermute,
   Core.backpermute2,
   Core.id,
   Core.map,
   Core.mapWithIndex,
   zipWith,
   zipWith3,
   zipWith4,
   zip,
   zip3,
   zip4,
   Core.fold1,
   Core.fold1All,
   Core.findAll,
   ) where

import qualified Data.Array.Knead.Simple.ShapeDependent as ShapeDep
import qualified Data.Array.Knead.Simple.Private as Core
import Data.Array.Knead.Simple.Private (Array, shape, gather, )

import qualified Data.Array.Knead.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp, )

import qualified LLVM.Extra.Multi.Value as MultiValue

import Data.Function.HT (Id)

import Prelude hiding (zipWith, zipWith3, zip, zip3, replicate, )


-- | Identity function whose only purpose is to restrict its argument to the
-- concrete 'Array' type. It can be applied to a value that is otherwise
-- polymorphic in the 'Core.C' array class.
fix :: Id (Array sh a)
fix arr = arr

-- | Reorder a source array according to an index mapping.
--
-- The index function @f@ is mapped over the index array @Core.id sh1@.
-- The resulting array of source indices is handed to 'gather', which is
-- then applied to the source array. The result is an array over @sh1@.
backpermute ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    MultiValue.C a) =>
   Exp sh1 ->
   (Exp ix1 -> Exp ix0) ->
   Array sh0 a ->
   Array sh1 a
backpermute sh1 f =
   let sourceIndices = Core.map f (Core.id sh1)
   in  gather sourceIndices

-- | Combine two arrays element-wise with @f@.
--
-- This is implemented with 'ShapeDep.backpermute2'. The two input shapes
-- are combined with 'Shape.intersect', and 'id' is passed as the index
-- mapping for both inputs.
zipWith ::
   (Core.C array, Shape.C sh) =>
   (Exp a -> Exp b -> Exp c) ->
   array sh a -> array sh b -> array sh c
zipWith f xs ys =
   ShapeDep.backpermute2 Shape.intersect id id f xs ys

-- | Combine three arrays element-wise with @f@.
--
-- The first two arrays are paired with 'Expr.zip'. The paired array is
-- then combined with the third array by 'zipWith'. Each pair is split
-- again with 'Expr.unzip' before @f@ is applied.
zipWith3 ::
   (Core.C array, Shape.C sh) =>
   (Exp a -> Exp b -> Exp c -> Exp d) ->
   array sh a -> array sh b -> array sh c -> array sh d
zipWith3 f xs ys zs =
   zipWith combine (zipWith Expr.zip xs ys) zs
   where
      combine xy z =
         let (x, y) = Expr.unzip xy
         in  f x y z

-- | Combine four arrays element-wise with @f@.
--
-- The first two arrays are paired with 'Expr.zip'. The paired array is
-- then passed to 'zipWith3' together with the remaining two arrays. Each
-- pair is split with 'Expr.unzip' before @f@ is applied.
zipWith4 ::
   (Core.C array, Shape.C sh) =>
   (Exp a -> Exp b -> Exp c -> Exp d -> Exp e) ->
   array sh a -> array sh b -> array sh c -> array sh d -> array sh e
zipWith4 f ws xs ys zs =
   zipWith3 combine (zipWith Expr.zip ws xs) ys zs
   where
      combine wx y z =
         let (w, x) = Expr.unzip wx
         in  f w x y z


-- | Pair up corresponding elements of two arrays.
-- This is 'zipWith' applied to the lifted 'MultiValue.zip'.
zip ::
   (Core.C array, Shape.C sh) =>
   array sh a -> array sh b -> array sh (a,b)
zip xs ys =
   zipWith (Expr.lift2 MultiValue.zip) xs ys

-- | Collect corresponding elements of three arrays into triples.
-- This is 'zipWith3' applied to the lifted 'MultiValue.zip3'.
zip3 ::
   (Core.C array, Shape.C sh) =>
   array sh a -> array sh b -> array sh c -> array sh (a,b,c)
zip3 xs ys zs =
   zipWith3 (Expr.lift3 MultiValue.zip3) xs ys zs

-- | Collect corresponding elements of four arrays into quadruples.
-- This is 'zipWith4' applied to the lifted 'MultiValue.zip4'.
zip4 ::
   (Core.C array, Shape.C sh) =>
   array sh a -> array sh b -> array sh c -> array sh d ->
   array sh (a,b,c,d)
zip4 ws xs ys zs =
   zipWith4 (Expr.lift4 MultiValue.zip4) ws xs ys zs