module Feldspar.Core.Constructs where

import Data.List
import Data.Typeable

import Feldspar.DSL.Expression
import Feldspar.DSL.Lambda
import Feldspar.DSL.Network
import Feldspar.Set
import Feldspar.Range
import Feldspar.Core.Types
import Feldspar.Core.Representation



-- | Like 'value', but with an explicit size. The size of the resulting literal
-- is @sz \\\/ sizeOf a@, i.e. the given size combined with the size of the
-- value itself.
value' :: Type a => Size a -> a -> Data a
value' sz a = nodeData (sz \/ sizeOf a) (Inject $ Node $ Literal a)

-- | A program that computes a constant value
value :: Type a => a -> Data a
value = value' empty

-- | The literal @()@.
unit :: Data ()
unit = value ()

-- | The literal 'True'.
true :: Data Bool
true = value True

-- | The literal 'False'.
false :: Data Bool
false = value False

-- | Like 'value' but with an extra 'Size' argument that can be used to increase
-- the size beyond the given data.
--
-- Example 1:
--
-- > array (10 :> 20 :> universal) [] :: Data [[DefaultInt]]
--
-- gives an uninitialized 10x20 array of 'DefaultInt' elements.
--
-- Example 2:
--
-- > array (10 :> 20 :> universal) [[1,2,3]] :: Data [[DefaultInt]]
--
-- gives a 10x20 array whose first row is initialized to @[1,2,3]@.
array :: Type a => Size a -> a -> Data a
array = value'

-- | Cap the size information of an expression. The expression is returned
-- unchanged except that its size is replaced ('resizeData') by
-- @sz \/\\ dataSize a@, i.e. the given size combined with its current size.
cap :: Type a => Size a -> Data a -> Data a
cap sz a = resizeData (sz /\ dataSize a) a

-- | Apply a named primitive function to an argument.
--
-- If constant propagation is enabled and the argument is a literal (as
-- reported by 'viewLiteral'), the function is evaluated immediately and the
-- result is returned as a literal ('value'). Otherwise a 'Function' node
-- carrying the name and the evaluation function is built, with its size
-- computed by the size-propagation function from the argument's info.
function :: (Syntactic a, Type b)
    => Bool                -- ^ Enable constant propagation for literal arguments
    -> String              -- ^ Name of the function (stored in the node)
    -> (Info a -> Size b)  -- ^ Result size from the argument's info
    -> (Internal a -> b)   -- ^ Evaluation function
    -> (a -> Data b)
function doConstProp fun sizeProp f a = case viewLiteral a of
    Just a' | doConstProp -> value (f a')
    _ -> func
  where
    sz   = sizeProp (edgeInfo a)
    func = nodeData sz $ Inject (Node (Function fun f)) :$: toEdge a

-- | One-argument version of 'function', with constant propagation enabled and
-- a size-propagation function over the argument's 'Size'.
function1 :: (Type a, Type b)
    => String              -- ^ Name of the function
    -> (Size a -> Size b)  -- ^ Size propagation
    -> (a -> b)            -- ^ Evaluation function
    -> (Data a -> Data b)
function1 fun sizeProp = function True fun (sizeProp . edgeSize)

-- | Two-argument version of 'function', with constant propagation enabled.
-- The arguments are passed to 'function' as a pair.
function2 :: (Type a, Type b, Type c)
    => String                        -- ^ Name of the function
    -> (Size a -> Size b -> Size c)  -- ^ Size propagation
    -> (a -> b -> c)                 -- ^ Evaluation function
    -> (Data a -> Data b -> Data c)
function2 fun sizeProp f = curry $ function True fun sizeProp' (uncurry f)
  where
    sizeProp' (i1,i2) = sizeProp (edgeSize i1) (edgeSize i2)

-- | Conditional expression.
--
-- Simplified at construction time: if both branches are the same expression
-- that branch is returned, and if the condition is a literal the matching
-- branch is selected directly. Otherwise a 'Condition' node is built whose
-- info is the combination of the two branches' info.
condition :: Syntactic a
    => Data Bool  -- ^ Condition
    -> a          -- ^ \"Then\" branch
    -> a          -- ^ \"Else\" branch
    -> a
condition cond t e
    | toEdge t == toEdge e = t  -- TODO This check might be expensive
    | Just True  <- viewLiteral cond = t
    | Just False <- viewLiteral cond = e
    | otherwise = fromOutEdge info
        $ Inject (Node Condition) :$: toEdge cond :$: toEdge t :$: toEdge e
  where
    info = edgeInfo t \/ edgeInfo e

-- | Operator form of 'condition': @cond ? (t,e)@ is @condition cond t e@.
(?) :: Syntactic a
    => Data Bool  -- ^ Condition
    -> (a,a)      -- ^ Alternatives
    -> a
cond ? (t,e) = condition cond t e

infix 1 ?

-- | Identical to 'condition'. Provided for backwards-compatibility, but will be
-- removed in the future.
ifThenElse :: Syntactic a
    => Data Bool  -- ^ Condition
    -> a          -- ^ \"Then\" branch
    -> a          -- ^ \"Else\" branch
    -> a
ifThenElse = condition
{-# DEPRECATED ifThenElse "Please use `condition` or `(?)` instead." #-}

-- | Recognize an indexing expression.
--
-- If the second argument is a @\"(!)\"@ function node applied to a pair
-- ('Group2') whose second component equals the given index (checked with
-- 'exprEq'), the first component is returned, cast with 'exprCast'.
-- Otherwise the result is 'Nothing'.
viewGetIx :: Typeable a => Data Index -> Data a -> Maybe (Data [a])
viewGetIx (Data i) (Data a) = case undoEdge a of
    Inject (Node (Function "(!)" _)) :$: (Inject Group2 :$: as :$: i')
      | exprEq i i' -> Data `fmap` exprCast as
    _ -> Nothing

-- | Parallel array with continuation
--
-- A literal length of 0 yields the continuation unchanged. When the
-- optimization flag is set, the continuation is the empty literal and the
-- element function just indexes an array at the given index (see
-- 'viewGetIx'), the result is reduced to 'setLength' applied to that array.
parallel'' :: Type a
    => Bool                    -- ^ Enable the 'setLength' optimization
    -> Data Length             -- ^ Length of this segment
    -> (Data Index -> Data a)  -- ^ Element function
    -> Data [a]                -- ^ Continuation
    -> Data [a]
parallel'' optimize l ixf cont
    | l == value 0 = cont
    | otherwise    = case viewGetIx ix body of
        Just arr | optimize, cont == value [] -> setLength l arr
        _ -> nodeData szPar
            $ Inject (Node Parallel)
            :$: toEdge l
            :$: lambda (EdgeSize szi) ixf
            :$: toEdge cont
  where
    szl1  = dataSize l
    szi   = rangeByRange 0 (szl1-1)
    ix    = variable (EdgeSize szi) "TODO"
    body  = ixf ix
    sza   = dataSize body
    szl2 :> sza' = dataSize cont
    szPar = (szl1+szl2) :> (sza \/ sza')

-- TODO The optimize argument is a hack to work around a problem with having
--      literals (and other things) as continuations. This is only a problem
--      if the parallel is a continuation of another parallel or sequential.
--      If the parallel is the first segment, enabling optimization should be
--      fine.

-- | Parallel array with continuation
parallel' :: Type a
    => Data Length
    -> (Data Index -> Data a)
    -> Data [a]
    -> Data [a]
parallel' = parallel'' True

-- | Parallel array
--
-- Since there are no dependencies between the elements, the compiler is free to
-- compute the elements in any order, or even in parallel.
parallel :: Type a
    => Data Length              -- ^ Length of resulting array (outermost level)
    -> (Data Index -> Data a)   -- ^ Function that maps each index in the range
                                --   @[0 .. l-1]@ to its element
    -> Data [a]
parallel l ixf = parallel' l ixf (value [])

-- | For loop
--
-- A literal iteration count of 0 returns the initial state, and a literal
-- count of 1 applies the body once at index 0. Otherwise a 'ForLoop' node is
-- built; the size of the state is approximated with 'indexedFixedPoint' on the
-- body's size behaviour, using 'cutOffAt' with the constant 3.
forLoop :: Syntactic st
    => Data Length               -- ^ Number of iterations
    -> st                        -- ^ Initial state
    -> (Data Index -> st -> st)  -- ^ Loop body (current index and state to next state)
    -> st                        -- ^ Final state
forLoop l init body
    | l == value 0 = init
    | l == value 1 = body (value 0) init
    | otherwise    = fromOutEdge szst
        $ Inject (Node ForLoop)
        :$: toEdge l
        :$: toEdge init
        :$: Lambda (\i -> lambda szst $ body $ nodeData szi i)
  where
    szl      = dataSize l
    -- The index runs over [0 .. l-1], so its upper bound is one less than the
    -- length's upper bound (the same bound used by parallel'' and sequential).
    szi      = rangeByRange 0 (szl-1)
    szinit   = edgeInfo init
    fn _ sz  = edgeInfo $ body (variable (EdgeSize szi) "ix") (variable sz "st")
    (szst,_) = indexedFixedPoint (cutOffAt 3 fn) szinit

-- | Sequential array with continuation.
--
-- Each step receives the current index and state and produces the current
-- element and the next state; the final state is passed to the continuation.
-- The size of the continuation and of the elements is currently taken to be
-- 'universal' (see the TODOs below).
sequential :: (Type a, Syntactic st)
    => Data Length                        -- ^ Length of the sequential segment
    -> st                                 -- ^ Initial state
    -> (Data Index -> st -> (Data a,st))  -- ^ Current loop index and current state to current element
                                          --   and next state
    -> (st -> Data [a])                   -- ^ Continuation
    -> Data [a]
sequential l init step cont = nodeData szSeq
    $ Inject (Node Sequential)
    :$: toEdge l
    :$: toEdge init
    :$: Lambda (\i -> lambda universal $ step $ nodeData szi i)
    :$: lambda universal cont
  where
    szl1  = dataSize l
--  szl2 :> _ = dataSize cont  -- TODO cont needs an argument
    szl2  = universal
    szi   = rangeByRange 0 (szl1-1)
    szSeq = (szl1+szl2) :> universal  -- TODO Improve

-- | Prevent a function from being inlined
--
-- Builds a 'NoInline' node carrying the given name, the function as a lambda
-- and the argument. The sizes are taken from the argument and from the
-- function applied to the argument.
noinline :: (Syntactic a, Syntactic b) => String -> (a -> b) -> (a -> b)
noinline name body a = fromOutEdge szb
    $ Inject (Node (NoInline name))
    :$: lambda sza body
    :$: toEdge a
  where
    sza = getInfo a
    szb = getInfo $ body a

-- | Two-argument version of 'noinline'; the arguments are passed as a pair.
noinline2 :: (Syntactic a, Syntactic b, Syntactic c)
    => String -> (a -> b -> c) -> (a -> b -> c)
noinline2 name = curry . noinline name . uncurry

-- | Set the length of an array.
--
-- The length in the result's size is @rangeMin@ of the size of the new length
-- and the length size of the array. Some cases are simplified at
-- construction time:
--
-- * if the length is a @\"length\"@ function node applied to the array
--   itself, the array is returned unchanged;
--
-- * if both the length and the array are literals, the result is a literal
--   holding the first elements of the array ('genericTake');
--
-- * if the array is a 'Parallel' node whose continuation is @value []@, the
--   node is rebuilt with the new length.
--
-- Otherwise a 'SetLength' node is built. This function is used by the
-- optimization in @parallel''@.
setLength :: Type a => Data Length -> Data [a] -> Data [a]
setLength l arr = case (undoEdge (unData l), undoEdge (unData arr)) of
    (Inject (Node (Function "length" _)) :$: a, _)
      | Just b <- exprCast a, b == unData arr -> Data b
    (Inject (Node (Literal n)), Inject (Node (Literal as)))
      -> nodeData (szLen :> szArrElem) $ Inject $ Node $ Literal $ genericTake n as
    (_, Inject (Node Parallel) :$: _ :$: ixf :$: cont)
      | cont == unData (value []) -> nodeData (szLen :> szArrElem)
          $ Inject (Node Parallel) :$: unData l :$: ixf :$: cont
    _ -> nodeData (szLen :> szArrElem)
        $ Inject (Node SetLength) :$: toEdge l :$: toEdge arr
  where
    szl = dataSize l
    szArrLen :> szArrElem = dataSize arr
    szLen = rangeMin szl szArrLen

-- The only purpose of this function is to enable optimization of 'parallel'.