module Feldspar.Core.Constructs where



import Data.List
import Data.Typeable

import Feldspar.DSL.Expression
import Feldspar.DSL.Lambda
import Feldspar.DSL.Network
import Feldspar.Set
import Feldspar.Range
import Feldspar.Core.Types
import Feldspar.Core.Representation



-- | Like 'value', but the size of the result is the join of the given 'Size'
-- and the size of the literal itself. This makes it possible to widen the
-- size beyond what the data alone would give.
value' :: Type a => Size a -> a -> Data a
value' extra lit = nodeData sz (Inject (Node (Literal lit)))
  where
    sz = extra \/ sizeOf lit

-- | A program that computes a constant value.
--
-- Equivalent to @value' empty@: no size information is added beyond that of
-- the literal itself.
value :: Type a => a -> Data a
value x = value' empty x

-- | The unit value as a constant program.
unit :: Data ()
unit = value ()

-- | The boolean constant 'True' as a program.
true :: Data Bool
true = value True

-- | The boolean constant 'False' as a program.
false :: Data Bool
false = value False

-- | Like 'value', but with an extra 'Size' argument that can be used to
-- increase the size beyond that of the given data.
--
-- Example 1:
--
-- > array (10 :> 20 :> universal) [] :: Data [[DefaultInt]]
--
-- gives an uninitialized 10x20 array of 'DefaultInt' elements.
--
-- Example 2:
--
-- > array (10 :> 20 :> universal) [[1,2,3]] :: Data [[DefaultInt]]
--
-- gives a 10x20 array whose first row is initialized to @[1,2,3]@.
--
-- This is a synonym for 'value''.
array :: Type a => Size a -> a -> Data a
array sz initial = value' sz initial

-- | Restrict the size information of a program: the program's current size
-- is intersected with the given 'Size', and the program is resized to the
-- result.
cap :: Type a => Size a -> Data a -> Data a
cap sz a = resizeData narrowed a
  where
    narrowed = sz /\ dataSize a

-- | Lift a Haskell function to a named primitive function in the program.
--
-- Constant propagation: when @doConstProp@ is 'True' and the argument is a
-- literal, the Haskell function is applied directly and the result is
-- embedded with 'value'. Otherwise a 'Function' node named @fun@ is built.
-- Its size is computed by applying @sizeProp@ to the argument's 'Info'.
function :: (Syntactic a, Type b)
    => Bool                    -- ^ Enable constant propagation on literals
    -> String                  -- ^ Name of the function
    -> (Info a     -> Size b)  -- ^ Size propagation
    -> (Internal a -> b)       -- ^ Evaluation function
    -> (a          -> Data b)
function doConstProp fun sizeProp f a
    | doConstProp, Just x <- viewLiteral a = value (f x)
    | otherwise = nodeData sz $ Inject (Node (Function fun f)) :$: toEdge a
  where
    sz = sizeProp (edgeInfo a)

-- | One-argument version of 'function' with constant propagation enabled.
--
-- The size propagation function works directly on the argument's 'Size'.
function1 :: (Type a, Type b)
    => String                -- ^ Name of the function
    -> (Size a -> Size b)    -- ^ Size propagation
    -> (a      -> b)         -- ^ Evaluation function
    -> (Data a -> Data b)
function1 fun sizeProp = function True fun propagate
  where
    propagate info = sizeProp (edgeSize info)

-- | Two-argument version of 'function' with constant propagation enabled.
--
-- The two arguments are passed to 'function' as a pair. The size propagation
-- function receives the 'Size' of each argument separately.
function2 :: (Type a, Type b, Type c)
    => String                          -- ^ Name of the function
    -> (Size a -> Size b -> Size c)    -- ^ Size propagation
    -> (a      -> b      -> c)         -- ^ Evaluation function
    -> (Data a -> Data b -> Data c)
function2 fun sizeProp f x y = function True fun propagate (uncurry f) (x, y)
  where
    propagate (ix, iy) = sizeProp (edgeSize ix) (edgeSize iy)

-- | Conditional expression: selects the \"then\" branch when the condition
-- holds and the \"else\" branch otherwise.
--
-- Simplifications performed while building the expression:
--
-- * a literal condition selects the corresponding branch directly;
--
-- * if both branches are equal as edges, the \"then\" branch is returned
--   without building a 'Condition' node.
--
-- The literal checks come first because they are cheap. Comparing the two
-- branches may need to traverse both expressions.
condition :: Syntactic a
    => Data Bool  -- ^ Condition
    -> a          -- ^ \"Then\" branch
    -> a          -- ^ \"Else\" branch
    -> a
condition cond t e = case viewLiteral cond of
    Just True  -> t
    Just False -> e
    _ | toEdge t == toEdge e -> t  -- Potentially expensive comparison
      | otherwise
         ->  fromOutEdge info
          $  Inject (Node Condition)
         :$: toEdge cond
         :$: toEdge t
         :$: toEdge e
  where
    info = edgeInfo t \/ edgeInfo e

-- | Infix form of 'condition': @c ? (t, e)@ is the same as @condition c t e@.
(?) :: Syntactic a
    => Data Bool  -- ^ Condition
    -> (a,a)      -- ^ Alternatives
    -> a
(?) c (t, e) = condition c t e

infix 1 ?

-- | Identical to 'condition'. Kept only for backwards compatibility and
-- scheduled for removal.
ifThenElse :: Syntactic a
    => Data Bool  -- ^ Condition
    -> a          -- ^ \"Then\" branch
    -> a          -- ^ \"Else\" branch
    -> a
ifThenElse c t e = condition c t e
{-# DEPRECATED ifThenElse "Please use `condition` or `(?)` instead." #-}

-- | Recognize an array-indexing expression.
--
-- Succeeds when @a@ is an application of the function named @\"(!)\"@ to a
-- 'Group2' of an array and an index, and that index is equal (by 'exprEq')
-- to @i@. The array is then returned, provided 'exprCast' can give it the
-- type @[a]@. In all other cases the result is 'Nothing'.
viewGetIx :: Typeable a => Data Index -> Data a -> Maybe (Data [a])
viewGetIx (Data i) (Data a) = case undoEdge a of
    Inject (Node (Function "(!)" _)) :$: (Inject Group2 :$: as :$: i')
        | exprEq i i' -> Data `fmap` exprCast as
    _ -> Nothing

-- | Parallel array with continuation.
--
-- Builds a 'Parallel' node of length @l@ whose elements are given by @ixf@,
-- with @cont@ as continuation. The size computation adds the length of @l@ to
-- the length of @cont@, so @cont@'s elements presumably follow the generated
-- ones.
--
-- Simplifications:
--
-- * if @l@ equals @value 0@, the result is just @cont@;
--
-- * if @optimize@ is set, @cont@ equals the empty array literal, and the body
--   only indexes some array @arr@ with the loop index (as recognized by
--   'viewGetIx'), the result is @setLength l arr@ instead of a new node.
parallel'' :: Type a =>
    Bool -> Data Length -> (Data Index -> Data a) -> Data [a] -> Data [a]
parallel'' optimize l ixf cont | l == value 0 = cont
parallel'' optimize l ixf cont = case viewGetIx ix body of
    Just arr | optimize, cont == value [] -> setLength l arr
    _
        ->  nodeData szPar
         $  Inject (Node Parallel)
        :$: toEdge l
        :$: lambda (EdgeSize szi) ixf
        :$: toEdge cont
  where
    szl1         = dataSize l
    -- Size of the loop index: from 0 up to (size of l) - 1.
    szi          = rangeByRange 0 (szl1-1)
    -- Symbolic index used to inspect the body.
    -- NOTE(review): the variable name "TODO" looks like a placeholder.
    ix           = variable (EdgeSize szi) "TODO"
    body         = ixf ix
    sza          = dataSize body
    szl2 :> sza' = dataSize cont
    -- Length: sum of both lengths; elements: join of both element sizes.
    szPar        = (szl1+szl2) :> (sza \/ sza')
  -- TODO The optimize argument is a hack to work around a problem with having
  --      literals (and other things) as continuations. This is only a problem
  --      if the parallel is a continuation of another parallel or sequential.
  --      If the parallel is the first segment, enabling optimization should be
  --      fine.

-- | Parallel array with continuation, with the @parallel''@ optimization
-- enabled (see the TODO there for when this is safe).
parallel' :: Type a =>
    Data Length -> (Data Index -> Data a) -> Data [a] -> Data [a]
parallel' l ixf cont = parallel'' True l ixf cont

-- | Parallel array
--
-- There are no dependencies between the elements, so the compiler may
-- compute them in any order, or even in parallel.
--
-- Implemented as 'parallel'' with the empty array as continuation.
parallel :: Type a
    => Data Length  -- ^ Length of resulting array (outermost level)
    -> (Data Index -> Data a)
                    -- ^ Function mapping each index in @[0 .. len-1]@ to the
                    -- element at that index
    -> Data [a]
parallel len ixf = parallel' len ixf (value [])

-- | For loop
--
-- Threads a state through @l@ iterations of @body@. Each iteration receives
-- the current index and state and yields the next state. When @l@ equals
-- @value 0@, the initial state is returned unchanged. When it equals
-- @value 1@, the body is applied once, at index 0, without building a
-- 'ForLoop' node.
forLoop :: Syntactic st
    => Data Length  -- ^ Number of iterations
    -> st           -- ^ Initial state
    -> (Data Index -> st -> st)
                    -- ^ Loop body (current index and state to next state)
    -> st           -- ^ Final state
forLoop l initial body
    | l == value 0 = initial
    | l == value 1 = body (value 0) initial
    | otherwise
         =  fromOutEdge szState
         $  Inject (Node ForLoop)
        :$: toEdge l
        :$: toEdge initial
        :$: Lambda (\i -> lambda szState $ body $ nodeData szIx i)
  where
    -- NOTE(review): this index size reaches up to the size of l itself,
    -- whereas parallel'' and sequential use l-1 as the upper bound.
    -- Presumably a sound over-approximation of the visited indices.
    szIx             = rangeByRange 0 (dataSize l)
    szInit           = edgeInfo initial
    -- State info after one iteration, given the previous state info.
    -- The first (iteration) argument is ignored.
    iterSize _ szPrev = edgeInfo $ body (variable (EdgeSize szIx) "ix")
                                        (variable szPrev "st")
    -- The loop-state info is found by fixed-point iteration starting from
    -- the initial state (cutOffAt 3 presumably bounds the refinement steps).
    (szState, _)     = indexedFixedPoint (cutOffAt 3 iterSize) szInit

-- | Sequential array with state and continuation.
--
-- Elements are produced in order by @step@, which maps the current loop index
-- and current state to the current element and the next state. @cont@ maps a
-- state to an array that continues the result. The result's length size is
-- the size of @l@ plus an unknown continuation length. Element sizes are not
-- tracked ('universal').
sequential :: (Type a, Syntactic st)
    => Data Length       -- ^ Length of the part computed by @step@
    -> st                -- ^ Initial state
    -> (Data Index -> st -> (Data a,st))
                         -- ^ Current loop index and current state to current element
                         -- and next state
    -> (st -> Data [a])  -- ^ Continuation
    -> Data [a]
sequential l initial step cont
     =  nodeData szSeq
     $  Inject (Node Sequential)
    :$: toEdge l
    :$: toEdge initial
    :$: Lambda (\i -> lambda universal $ step $ nodeData szIx i)
    :$: lambda universal cont
  where
    szLen = dataSize l
    szIx  = rangeByRange 0 (szLen - 1)
    -- TODO The continuation's length is unknown here: cont needs a state
    --      argument before its size can be inspected, so universal is used.
    szSeq = (szLen + universal) :> universal  -- TODO Improve

-- | Prevent a function from being inlined.
--
-- Builds a 'NoInline' node tagged with @name@ that holds @body@ (as a lambda)
-- and its argument @a@. The argument's 'Info' is passed to 'lambda'. The
-- result's 'Info' is obtained by applying @body@ to @a@.
noinline :: (Syntactic a, Syntactic b) => String -> (a -> b) -> (a -> b)
noinline name body a =
    fromOutEdge szOut $
        Inject (Node (NoInline name)) :$: lambda szIn body :$: toEdge a
  where
    szIn  = getInfo a
    szOut = getInfo (body a)

-- | Two-argument version of 'noinline'. The arguments are paired, passed
-- through 'noinline' as a single argument, and the result is curried again.
noinline2 :: (Syntactic a, Syntactic b, Syntactic c) =>
    String -> (a -> b -> c) -> (a -> b -> c)
noinline2 name f = curry (noinline name (uncurry f))

-- | Set the length of an array to @l@.
--
-- The only purpose of this function is to enable optimization of 'parallel'.
--
-- Simplifications, tried in order:
--
-- * if @l@ is a 'Function' node named @\"length\"@ applied to an expression
--   equal to @arr@, then @arr@ is returned unchanged;
--
-- * if both @l@ and @arr@ are literals, the array literal is truncated with
--   'genericTake';
--
-- * if @arr@ is a 'Parallel' node whose continuation equals the empty array
--   literal, its length argument is replaced by @l@;
--
-- * otherwise a 'SetLength' node is built.
--
-- In the built cases the length size is the 'rangeMin' of the size of @l@ and
-- the array's length size. The element size is taken from @arr@.
setLength :: Type a => Data Length -> Data [a] -> Data [a]
setLength l arr = case (undoEdge (unData l), undoEdge (unData arr)) of
    (Inject (Node (Function "length" _)) :$: a, _)
        | Just b <- exprCast a, b == unData arr -> Data b
    (Inject (Node (Literal n)), Inject (Node (Literal as))) ->
        nodeData (szLen :> szArrElem) $
            Inject $ Node $ Literal $ genericTake n as
    (_, Inject (Node Parallel) :$: _ :$: ixf :$: cont)
        | cont == unData (value []) -> nodeData (szLen :> szArrElem) $
            Inject (Node Parallel) :$: unData l :$: ixf :$: cont
    _ -> nodeData (szLen :> szArrElem) $
        Inject (Node SetLength) :$: toEdge l :$: toEdge arr
  where
    szl                   = dataSize l
    szArrLen :> szArrElem = dataSize arr
    szLen                 = rangeMin szl szArrLen