module Feldspar.Core.Expr where
import Data.Function
import Data.Monoid
import Data.Unique
import Feldspar.Range
import Feldspar.Core.Types
import Feldspar.Core.Ref
-- | Expression nodes of the core language. Nodes are normally reached
-- through a 'Data' wrapper (see 'dataRef'), which pairs the expression with
-- its size information.
data Expr a
  where
    -- | A plain Haskell value with no 'Storable' constraint; 'evalF' uses it
    -- to inject a lambda's actual argument during evaluation.
    Val :: a -> Expr a
    -- | A free variable; used as the placeholder argument of a lambda (see
    -- 'freshVar'). Evaluating it directly is an error (see 'evalE').
    Variable :: Expr a
    -- | A storable constant; 'function' and its siblings constant-fold over
    -- arguments of this form.
    Value :: Storable a => a -> Expr a
    -- | A named primitive together with its Haskell meaning.
    Function :: String -> (a -> b) -> Expr (a -> b)
    -- | Application of a (possibly partially applied) function expression to
    -- an argument node.
    Application :: Expr (a -> b) -> Data a -> Expr b
    -- | A named call of a lambda held behind a 'Ref'. Evaluation simply
    -- applies the lambda; presumably the name/reference lets a back end keep
    -- it as a separate, non-inlined unit — confirm.
    NoInline :: String -> Ref (a :-> b) -> (Data a -> Expr b)
    -- | Choose between two lambdas, both applied to the same argument.
    IfThenElse
      :: Data Bool
      -> (a :-> b)
      -> (a :-> b)
      -> (Data a -> Expr b)
    -- | Loop: starting from the initial state, apply the body lambda while
    -- the condition lambda holds.
    While
      :: (a :-> Bool)
      -> (a :-> a)
      -> Data a
      -> Expr a
    -- | A list with one element per index below the given length; element
    -- @i@ is the index lambda applied to @i@.
    Parallel
      :: Storable a
      => Data Length
      -> (Int :-> a)
      -> Expr [a]
-- | A node of the expression graph. It holds the size information of a value
-- and a reference to the expression that defines it. The 'Typeable' context
-- is stored in the constructor, so pattern matching on 'Data' brings it back
-- into scope (as 'dataType' and 'cap' do).
data Data a = Typeable a => Data
  { dataSize :: Size a       -- ^ Size information of the value.
  , dataRef :: Ref (Expr a)  -- ^ Reference to the defining expression.
  }
-- | Node equality is reference equality: two nodes are equal exactly when
-- their 'dataRef's are equal. The expressions themselves are not compared.
instance Eq (Data a)
  where
    x == y = dataRef x == dataRef y
-- | Nodes are ordered by their references, consistently with the 'Eq'
-- instance.
instance Ord (Data a)
  where
    compare x y = compare (dataRef x) (dataRef y)
-- | A lambda of the core language. It holds three things:
--
-- * the Haskell function on nodes;
-- * the placeholder argument node it was applied to;
-- * the body obtained from that application.
--
-- 'lambda' builds it; 'resultSize' reads the size of the stored body.
data a :-> b = Typeable a =>
  Lambda (Data a -> Data b) (Data a) (Data b)
-- | The 'StorableType' description of a node, computed by 'typeOf' from the
-- node's size and its element type. Matching on the 'Data' constructor
-- brings the stored 'Typeable' context into scope for that call.
dataType :: forall a . Data a -> Tuple StorableType
dataType node = case node of
  Data sz _ -> typeOf sz (T :: T a)
-- | The unique identity of a node, taken from its reference.
dataId :: Data a -> Unique
dataId node = refId (dataRef node)
-- | Follow a node's reference to the expression that defines it.
dataToExpr :: Data a -> Expr a
dataToExpr node = deref (dataRef node)
-- | Wrap an expression in a reference (created with 'ref') and tag it with
-- the given size.
exprToData :: Typeable a => Size a -> Expr a -> Data a
exprToData sz e = Data { dataSize = sz, dataRef = ref e }
-- | A new 'Variable' node of the given size; 'lambda' uses it as the
-- placeholder argument of a lambda.
freshVar :: Typeable a => Size a -> Data a
freshVar varSize = exprToData varSize Variable
-- | Build a core lambda from a function on nodes. The function is applied
-- once to a placeholder variable of the given size (from 'freshVar'). The
-- placeholder and the resulting body are stored next to the function. The
-- same placeholder node is used for both, so they share one reference.
lambda :: Typeable a => Size a -> (Data a -> Data b) -> (a :-> b)
lambda sz f =
  let placeholder = freshVar sz
  in  Lambda f placeholder (f placeholder)
-- | Apply a core lambda to an argument node by calling its stored function.
apply :: (a :-> b) -> Data a -> Data b
apply (Lambda fun _ _) arg = fun arg
-- | The size of a lambda's body, as computed when the lambda was built.
resultSize :: (a :-> b) -> Size b
resultSize lam = case lam of
  Lambda _ _ body -> dataSize body
-- | Infix synonym for 'Application'. It has no fixity declaration, so it gets
-- the default (left-associative, precedence 9): @f |$| a |$| b@ means
-- @(f |$| a) |$| b@.
(|$|) :: Expr (a -> b) -> Data a -> Expr b
(|$|) = Application
-- | Build a node that applies a named unary primitive to an argument. No
-- constant folding is done. The result size is the size-propagation
-- function applied to the argument's size.
_function
  :: Typeable b
  => String -> (Size a -> Size b) -> (a -> b) -> (Data a -> Data b)
_function fun sizeProp f a =
  exprToData (sizeProp (dataSize a)) (Function fun f |$| a)
-- | Binary version of '_function': the primitive is applied to both
-- arguments in order. The result size comes from the argument sizes.
_function2
  :: Typeable c
  => String
  -> (Size a -> Size b -> Size c)
  -> (a -> b -> c)
  -> (Data a -> Data b -> Data c)
_function2 fun sizeProp f a b = exprToData resSize body
  where
    body    = Function fun f |$| a |$| b
    resSize = sizeProp (dataSize a) (dataSize b)
-- | Ternary version of '_function'.
_function3
  :: Typeable d
  => String -> (Size a -> Size b -> Size c -> Size d)
  -> (a -> b -> c -> d)
  -> (Data a -> Data b -> Data c -> Data d)
_function3 fun sizeProp f a b c = exprToData resSize body
  where
    body    = Function fun f |$| a |$| b |$| c
    resSize = sizeProp (dataSize a) (dataSize b) (dataSize c)
-- | Quaternary version of '_function'.
_function4
  :: Typeable e
  => String
  -> (Size a -> Size b -> Size c -> Size d -> Size e)
  -> (a -> b -> c -> d -> e)
  -> (Data a -> Data b -> Data c -> Data d -> Data e)
_function4 fun sizeProp f a b c d = exprToData resSize body
  where
    body    = Function fun f |$| a |$| b |$| c |$| d
    resSize = sizeProp (dataSize a) (dataSize b) (dataSize c) (dataSize d)
-- | Combine two nodes into one node of pair type. The size of the pair is the
-- pair of the component sizes.
tup2 :: (Typeable a, Typeable b) => Data a -> Data b -> Data (a,b)
tup2 x y = _function2 "tup2" (,) (,) x y

-- | Triple version of 'tup2'.
tup3 :: (Typeable a, Typeable b, Typeable c) =>
  Data a -> Data b -> Data c -> Data (a,b,c)
tup3 x y z = _function3 "tup3" (,,) (,,) x y z

-- | Quadruple version of 'tup2'.
tup4 :: (Typeable a, Typeable b, Typeable c, Typeable d) =>
  Data a -> Data b -> Data c -> Data d -> Data (a,b,c,d)
tup4 w x y z = _function4 "tup4" (,,,) (,,,) w x y z
-- | First component of a pair node. The size is projected the same way.
get21 :: Typeable a => Data (a,b) -> Data a
get21 = _function "getTup21" fst fst

-- | Second component of a pair node. The size is projected the same way.
get22 :: Typeable b => Data (a,b) -> Data b
get22 = _function "getTup22" snd snd
-- | First component of a triple node (value and size alike).
get31 :: Typeable a => Data (a,b,c) -> Data a
get31 = _function "getTup31" sel sel
  where
    sel (x,_,_) = x

-- | Second component of a triple node (value and size alike).
get32 :: Typeable b => Data (a,b,c) -> Data b
get32 = _function "getTup32" sel sel
  where
    sel (_,y,_) = y

-- | Third component of a triple node (value and size alike).
get33 :: Typeable c => Data (a,b,c) -> Data c
get33 = _function "getTup33" sel sel
  where
    sel (_,_,z) = z
-- | First component of a quadruple node (value and size alike).
get41 :: Typeable a => Data (a,b,c,d) -> Data a
get41 = _function "getTup41" sel sel
  where
    sel (w,_,_,_) = w

-- | Second component of a quadruple node (value and size alike).
get42 :: Typeable b => Data (a,b,c,d) -> Data b
get42 = _function "getTup42" sel sel
  where
    sel (_,x,_,_) = x

-- | Third component of a quadruple node (value and size alike).
get43 :: Typeable c => Data (a,b,c,d) -> Data c
get43 = _function "getTup43" sel sel
  where
    sel (_,_,y,_) = y

-- | Fourth component of a quadruple node (value and size alike).
get44 :: Typeable d => Data (a,b,c,d) -> Data d
get44 = _function "getTup44" sel sel
  where
    sel (_,_,_,z) = z
-- | Types that can be converted to and from a single 'Data' node. In this
-- module, tuples of computable values are represented as one tuple-typed
-- node built with 'tup2' / 'tup3' / 'tup4'.
class Typeable (Internal a) => Computable a
  where
    -- | Element type of the node that represents a value of type @a@.
    type Internal a
    -- | Convert a value into its single-node representation.
    internalize :: a -> Data (Internal a)
    -- | Convert a single node back into the structured value.
    externalize :: Data (Internal a) -> a
-- | A 'Data' node already is its own representation: both conversions
-- return their argument unchanged.
instance Storable a => Computable (Data a)
  where
    type Internal (Data a) = a
    internalize node = node
    externalize node = node
-- | A pair is internalized as a pair node built with 'tup2'. It is
-- externalized by projecting each component and externalizing it.
instance (Computable a, Computable b) => Computable (a,b)
  where
    type Internal (a,b) = (Internal a, Internal b)
    internalize (x, y) = internalize x `tup2` internalize y
    externalize pair = (externalize (get21 pair), externalize (get22 pair))
-- | A triple is internalized as a triple node built with 'tup3'. It is
-- externalized component by component.
instance (Computable a, Computable b, Computable c) => Computable (a,b,c)
  where
    type Internal (a,b,c) = (Internal a, Internal b, Internal c)
    internalize (x, y, z) =
      tup3 (internalize x) (internalize y) (internalize z)
    externalize triple =
      (externalize (get31 triple), externalize (get32 triple), externalize (get33 triple))
-- | A quadruple is internalized as a quadruple node built with 'tup4'. It is
-- externalized component by component.
instance (Computable a, Computable b, Computable c, Computable d)
    => Computable (a,b,c,d)
  where
    type Internal (a,b,c,d) = (Internal a, Internal b, Internal c, Internal d)
    internalize (w, x, y, z) =
      tup4 (internalize w) (internalize x) (internalize y) (internalize z)
    externalize quad =
      ( externalize (get41 quad), externalize (get42 quad)
      , externalize (get43 quad), externalize (get44 quad) )
-- | Turn a function on structured values into a function on single nodes.
-- The argument is externalized and the result internalized.
lowerFun :: (Computable a, Computable b) =>
  (a -> b) -> (Data (Internal a) -> Data (Internal b))
lowerFun f node = internalize (f (externalize node))
-- | Turn a function on single nodes into a function on structured values.
-- This is the inverse direction of 'lowerFun'.
liftFun :: (Computable a, Computable b) =>
  (Data (Internal a) -> Data (Internal b)) -> (a -> b)
liftFun f x = externalize (f (internalize x))
-- | Evaluate an expression directly in Haskell.
--
-- Evaluating a free 'Variable' is an error. Variables act as lambda
-- placeholders; 'evalF' evaluates a lambda on an argument wrapped in 'Val'
-- instead.
evalE :: Expr a -> a
evalE (Val a)              = a
evalE Variable             = error "evaluating free variable"
evalE (Value a)            = a
evalE (Function _ f)       = f
evalE (Application f a)    = evalE f (evalD a)
evalE (NoInline _ f a)     = evalD (apply (deref f) a)
evalE (IfThenElse c t e a)
  | evalD c                = evalD (apply t a)
  | otherwise              = evalD (apply e a)
-- Step the body from the initial state until the condition no longer holds.
-- 'until' is the total equivalent of @head . dropWhile p . iterate f@.
evalE (While cont body start) =
  until (not . evalF cont) (evalF body) (evalD start)
-- One element per index 0 .. length-1. Enumeration bounds are inclusive,
-- hence the -1.
evalE (Parallel l ixf)     = map (evalF ixf) [0 .. evalD l - 1]
-- | Evaluate a node by evaluating the expression it refers to.
evalD :: Data a -> a
evalD node = evalE (dataToExpr node)
-- | Evaluate a lambda as an ordinary Haskell function. The argument is
-- wrapped as a 'Val' node carrying the size of the lambda's placeholder.
-- The lambda's function is applied to that node, and the result node is
-- evaluated.
evalF :: (a :-> b) -> (a -> b)
evalF (Lambda f placeholder _) x =
  evalD (f (exprToData (dataSize placeholder) (Val x)))
-- | Evaluate any computable value by internalizing it to a single node and
-- evaluating that node.
eval :: Computable a => a -> Internal a
eval x = evalD (internalize x)
-- | A constant node. Its size is derived from the value itself with
-- 'storableSize'.
value :: Storable a => a -> Data a
value x = exprToData sz (Value x)
  where
    sz = storableSize x
-- | A constant node whose size is the given size combined ('mappend') with
-- the value's own 'storableSize'.
array :: Storable a => Size a -> a -> Data a
array sz x = exprToData combined (Value x)
  where
    combined = sz `mappend` storableSize x
-- | A constant list node. Its length range is taken from the size of @len@,
-- and the element size is 'universal' (presumably unconstrained — confirm).
-- Only the size of @len@ is used; its value is ignored.
arrayLen :: Storable a => Data Length -> [a] -> Data [a]
arrayLen len = array sz
  where
    -- 'fromIntegral' is used, as in 'parallel', so the range converts
    -- whichever integral type it is stored in.
    sz = mapMonotonic fromIntegral (dataSize len) :> universal
-- | The unit constant. Being a top-level value, it is one shared node.
unit :: Data ()
unit = value ()
-- | The constant 'True' (a single shared node).
true :: Data Bool
true = value True
-- | The constant 'False' (a single shared node).
false :: Data Bool
false = value False
-- | The length ranges that 'listSize' extracts from a list node's size.
-- Presumably one range per nesting level of the list — confirm.
size :: forall a . Storable a => Data [a] -> [Range Length]
size arr = listSize (T :: T [a]) (dataSize arr)
-- | Restrict a node's size by combining it with the given range using '/\'
-- (presumably range intersection — confirm). The referenced expression is
-- kept unchanged.
cap :: (Storable a, Size a ~ Range b, Ord b) => Range b -> Data a -> Data a
cap bound node = case node of
  Data sz r -> Data (sz /\ bound) r
-- | Apply a named unary primitive.
--
-- * If the argument is a 'Value' constant, the result is folded into a new
--   constant.
-- * Otherwise an application node is built with '_function'.
--
-- In both cases the result size is @sizeProp@ of the argument size.
function
  :: (Storable a, Storable b)
  => String -> (Size a -> Size b) -> (a -> b) -> (Data a -> Data b)
function fun sizeProp f a
  | Value x <- dataToExpr a = exprToData (sizeProp (dataSize a)) (Value (f x))
  | otherwise               = _function fun sizeProp f a
-- | Apply a named binary primitive.
--
-- * If both arguments are 'Value' constants, the result is folded into a new
--   constant.
-- * Otherwise the arguments are paired with 'tup2', and a unary application
--   of the uncurried primitive is built.
function2
  :: ( Storable a
     , Storable b
     , Storable c
     )
  => String
  -> (Size a -> Size b -> Size c)
  -> (a -> b -> c)
  -> (Data a -> Data b -> Data c)
function2 fun sizeProp f a b
  | (Value x, Value y) <- (dataToExpr a, dataToExpr b)
  = exprToData resSize (Value (f x y))
  | otherwise
  = _function fun (uncurry sizeProp) (uncurry f) (tup2 a b)
  where
    resSize = sizeProp (dataSize a) (dataSize b)
-- | Ternary version of 'function2'. It folds when all three arguments are
-- constants; otherwise it applies the uncurried primitive to a 'tup3' node.
function3
  :: ( Storable a
     , Storable b
     , Storable c
     , Storable d
     )
  => String
  -> (Size a -> Size b -> Size c -> Size d)
  -> (a -> b -> c -> d)
  -> (Data a -> Data b -> Data c -> Data d)
function3 fun sizeProp f a b c
  | (Value x, Value y, Value z) <- (dataToExpr a, dataToExpr b, dataToExpr c)
  = exprToData resSize (Value (f x y z))
  | otherwise
  = _function fun (uncurry3 sizeProp) (uncurry3 f) (tup3 a b c)
  where
    resSize = sizeProp (dataSize a) (dataSize b) (dataSize c)
    -- Closed helper, so it stays polymorphic: it is used on sizes and values.
    uncurry3 g (p, q, r) = g p q r
-- | Quaternary version of 'function2'. It folds when all four arguments are
-- constants; otherwise it applies the uncurried primitive to a 'tup4' node.
function4
  :: ( Storable a
     , Storable b
     , Storable c
     , Storable d
     , Storable e
     )
  => String
  -> (Size a -> Size b -> Size c -> Size d -> Size e)
  -> (a -> b -> c -> d -> e)
  -> (Data a -> Data b -> Data c -> Data d -> Data e)
function4 fun sizeProp f a b c d
  | (Value w, Value x, Value y, Value z) <-
      (dataToExpr a, dataToExpr b, dataToExpr c, dataToExpr d)
  = exprToData resSize (Value (f w x y z))
  | otherwise
  = _function fun (uncurry4 sizeProp) (uncurry4 f) (tup4 a b c d)
  where
    resSize = sizeProp (dataSize a) (dataSize b) (dataSize c) (dataSize d)
    -- Closed helper, so it stays polymorphic: it is used on sizes and values.
    uncurry4 g (p, q, r, s) = g p q r s
-- | Read element @i@ of a list node. The element size is the list's element
-- size.
--
-- The Haskell meaning @f@ runs during constant folding in 'function2' or
-- during interpretation via 'evalE'. It checks that the index lies in
-- @[0 .. l-1]@, where @l@ is the static length range of @arr@. It also
-- checks that the index is below the actual length of the list value.
getIx :: Storable a => Data [a] -> Data Int -> Data a
getIx arr = function2 "(!)" sizeProp f arr
  where
    sizeProp (_:>aSize) _ = aSize
    f as i
      | not (i `inRange` r) = error "getIx: index out of bounds"
      | i >= la             = error "getIx: reading garbage"
      | otherwise           = as !! i
      where
        l:>_ = dataSize arr
        -- Valid indices run from 0 up to one less than the length range.
        r  = rangeByRange 0 (l - 1)
        la = length as
-- | Replace element @i@ of a list node with @a@. The result's element size is
-- the old element size combined ('mappend') with the new element's size.
--
-- The Haskell meaning @f@ checks that the index lies in @[0 .. l-1]@, where
-- @l@ is the static length range of @arr@. It also checks that the index
-- does not exceed the actual length of the list value. Writing at exactly
-- that length appends an element.
setIx :: Storable a => Data [a] -> Data Int -> Data a -> Data [a]
setIx arr = function3 "setIx" sizeProp f arr
  where
    sizeProp (l:>aSize) _ aSize' = l :> (aSize `mappend` aSize')
    f as i a
      | not (i `inRange` r) = error "setIx: index out of bounds"
      | i > la              = error "setIx: writing past initialized area"
      | otherwise           = take i as ++ [a] ++ drop (i+1) as
      where
        l:>_ = dataSize arr
        -- Valid indices run from 0 up to one less than the length range.
        r  = rangeByRange 0 (l - 1)
        la = length as
-- | '!' has the same fixity as list indexing ('!!'): left-associative,
-- precedence 9.
infixl 9 !
-- | Containers that can be indexed by a 'Data' 'Int'.
class RandomAccess a
  where
    -- | The type of a single element of the container.
    type Element a
    -- | Index into the container.
    (!) :: a -> Data Int -> Element a
-- | Indexing a list node reads one element with 'getIx'.
instance Storable a => RandomAccess (Data [a])
  where
    type Element (Data [a]) = Data a
    arr ! i = getIx arr i
-- | Wrap a function so that a call becomes a single named 'NoInline' node.
-- The node refers (through 'ref') to a lambda built from the function. The
-- lambda's placeholder gets the size of the actual argument, and the node's
-- size is the lambda's result size.
noInline :: (Computable a, Computable b) => String -> (a -> b) -> (a -> b)
noInline fun f a = liftFun mkNode a
  where
    fLam   = lambda (dataSize (internalize a)) (lowerFun f)
    mkNode = exprToData (resultSize fLam) . NoInline fun (ref fLam)
-- | Conditional on a boolean node.
--
-- * When the condition is a 'Value' constant, the chosen branch is applied
--   directly.
-- * Otherwise both branches become lambdas over the argument's size, and an
--   'IfThenElse' node is built. Its size combines ('mappend') the result
--   sizes of the two branches.
ifThenElse
  :: (Computable a, Computable b)
  => Data Bool -> (a -> b) -> (a -> b) -> (a -> b)
ifThenElse cond t e a
  | Value b <- dataToExpr cond = if b then t a else e a
  | otherwise = liftFun (exprToData szb . IfThenElse cond thenLam elseLam) a
  where
    sza     = dataSize (internalize a)
    thenLam = lambda sza (lowerFun t)
    elseLam = lambda sza (lowerFun e)
    szb     = resultSize thenLam `mappend` resultSize elseLam
-- | Build a 'While' loop over a computable state.
--
-- The condition and body become lambdas whose placeholder sizes are given
-- explicitly: @szInitCont@ for the condition and @szInitBody@ for the body.
-- The size of the final state is 'universal'.
whileSized
  :: Computable state
  => Size (Internal state)
  -> Size (Internal state)
  -> (state -> Data Bool)
  -> (state -> state)
  -> (state -> state)
whileSized szInitCont szInitBody cont body = liftFun loopNode
  where
    loopNode = exprToData universal . While contLam bodyLam
    contLam  = lambda szInitCont (lowerFun cont)
    bodyLam  = lambda szInitBody (lowerFun body)
-- | 'whileSized' with 'universal' placeholder sizes for both the condition
-- and the body.
while
  :: Computable state
  => (state -> Data Bool)
  -> (state -> state)
  -> (state -> state)
while cont body = whileSized universal universal cont body
-- | Build a list node of length @l@ whose element at index @i@ is @ixf i@.
--
-- The index function becomes a lambda whose argument range is
-- @[0 .. szl-1]@, where @szl@ is the static size range of @l@. The list size
-- pairs that length range with the size of the lambda's body.
parallel :: Storable a => Data Length -> (Data Int -> Data a) -> Data [a]
parallel l ixf = exprToData szPar $ Parallel l ixfLam
  where
    szl    = dataSize l
    -- Indices run from 0 up to one less than the length.
    ixfLam = lambda (rangeByRange 0 (szl - 1)) ixf
    szPar  = mapMonotonic fromIntegral szl :> resultSize ixfLam