module Feldspar.Core.Expr where
import Data.Monoid
import Data.Unique
import Feldspar.Range
import Feldspar.Core.Types
import Feldspar.Core.Ref
-- | Nodes of the core expression language. Sub-terms are referenced through
-- 'Data', which pairs a size annotation with a reference to an 'Expr'.
data Expr a where
  -- Leaves: a symbolic input of a given size, or a literal value.
  Input  :: Size a -> Expr a
  Value  :: Storable a => Size a -> a -> Expr a

  -- Tuple construction.
  Tuple2 :: Data a -> Data b -> Expr (a, b)
  Tuple3 :: Data a -> Data b -> Data c -> Expr (a, b, c)
  Tuple4 :: Data a -> Data b -> Data c -> Data d -> Expr (a, b, c, d)

  -- Tuple projection: GetNk selects component k of an N-tuple.
  Get21  :: Data (a, b) -> Expr a
  Get22  :: Data (a, b) -> Expr b
  Get31  :: Data (a, b, c) -> Expr a
  Get32  :: Data (a, b, c) -> Expr b
  Get33  :: Data (a, b, c) -> Expr c
  Get41  :: Data (a, b, c, d) -> Expr a
  Get42  :: Data (a, b, c, d) -> Expr b
  Get43  :: Data (a, b, c, d) -> Expr c
  Get44  :: Data (a, b, c, d) -> Expr d

  -- Application of a named primitive; the Size is the size of the result.
  Function   :: String -> Size b -> (a -> b) -> Data a -> Expr b
  -- Application of a named sub-function that is held by reference.
  NoInline   :: String -> Ref (a :-> b) -> Data a -> Expr b

  -- Control structures.
  -- Condition, then-branch, else-branch, argument.
  IfThenElse :: Data Bool -> (a :-> b) -> (a :-> b) -> Data a -> Expr b
  -- Continue-condition, loop body, initial state.
  While      :: (a :-> Bool) -> (a :-> a) -> Data a -> Expr a
  -- Length and index function of an array built element by element.
  Parallel   :: Storable a => Data Length -> (Int :-> a) -> Expr [a]
-- | A sub-function. It holds the Haskell function together with an argument
-- and the result of applying the function to that argument. 'mkSubFun'
-- builds the argument as an 'Input' node.
data a :-> b
  = SubFunction
      (Data a -> Data b) -- the function
      (Data a)           -- the argument it was applied to
      (Data b)           -- the function applied to that argument
-- | A reference to an expression node, annotated with the node's size.
-- Matching on the constructor brings @Typeable a@ into scope.
data Data a where
  Data :: Typeable a => Size a -> Ref (Expr a) -> Data a
-- | Equality compares only the underlying references; the size annotation
-- is ignored.
instance Eq (Data a) where
  Data _ lhs == Data _ rhs = lhs == rhs
-- | Ordering compares only the underlying references; the size annotation
-- is ignored.
instance Ord (Data a) where
  compare (Data _ lhs) (Data _ rhs) = compare lhs rhs
-- | The size annotation stored in a 'Data' node.
dataSize :: Data a -> Size a
dataSize d = case d of
  Data sz _ -> sz
-- | The storable type of a 'Data' node, computed from its size annotation.
-- The pattern match also brings the @Typeable a@ context into scope.
dataType :: forall a . Data a -> Tuple StorableType
dataType (Data sz _) = typeOf sz (T :: T a)
-- | The unique identifier of the reference held by a 'Data' node.
dataId :: Data a -> Unique
dataId d = case d of
  Data _ node -> refId node
-- | Dereference a 'Data' node to the expression it points to.
dataToExpr :: Data a -> Expr a
dataToExpr (Data _ node) = deref node
-- | The size of a sub-function's result, taken from its stored output.
subFunSize :: (a :-> b) -> Size b
subFunSize (SubFunction _ _ result) = dataSize result
-- | Extract the Haskell function wrapped by a sub-function.
subAp :: (a :-> b) -> (Data a -> Data b)
subAp sub = case sub of
  SubFunction fun _ _ -> fun
-- | Wrap an expression in a fresh reference, annotated with the size
-- computed by 'exprSize'.
exprToData :: Typeable a => Expr a -> Data a
exprToData expr = Data sz (ref expr)
  where
    sz = exprSize expr
-- | Compute the size annotation of an expression from the sizes of its
-- sub-terms. Control structures combine the sizes of their alternatives
-- and states with 'mappend'.
exprSize :: forall a . Typeable a => Expr a -> Size a
exprSize expr = case expr of
  Input sz   -> sz
  Value sz _ -> sz

  Tuple2 a b     -> (dataSize a, dataSize b)
  Tuple3 a b c   -> (dataSize a, dataSize b, dataSize c)
  Tuple4 a b c d -> (dataSize a, dataSize b, dataSize c, dataSize d)

  Get21 ab   -> fst (dataSize ab)
  Get22 ab   -> snd (dataSize ab)
  Get31 abc  -> case dataSize abc of (s, _, _) -> s
  Get32 abc  -> case dataSize abc of (_, s, _) -> s
  Get33 abc  -> case dataSize abc of (_, _, s) -> s
  Get41 abcd -> case dataSize abcd of (s, _, _, _) -> s
  Get42 abcd -> case dataSize abcd of (_, s, _, _) -> s
  Get43 abcd -> case dataSize abcd of (_, _, s, _) -> s
  Get44 abcd -> case dataSize abcd of (_, _, _, s) -> s

  Function _ sz _ _ -> sz
  NoInline _ sub _  -> subFunSize (deref sub)

  -- Either branch may be taken, so both result sizes are combined.
  IfThenElse _ thenSub elseSub _ ->
    subFunSize thenSub `mappend` subFunSize elseSub
  -- The state is either the initial value or a body output.
  While _ bodySub start -> dataSize start `mappend` subFunSize bodySub
  -- Length range converted element-wise, followed by the element size.
  Parallel len ixf ->
    mapMonotonic fromIntegral (dataSize len) :> subFunSize ixf
-- | Values that can be converted to and from a single 'Data' node of an
-- internal representation type.
class Typeable (Internal a) => Computable a where
  -- The internal representation type.
  type Internal a

  -- Convert to the internal representation.
  internalize :: a -> Data (Internal a)
  -- Convert from the internal representation.
  externalize :: Data (Internal a) -> a
-- | A 'Data' node is already in internal form; both conversions are the
-- identity.
instance Storable a => Computable (Data a) where
  type Internal (Data a) = a

  internalize d = d
  externalize d = d
-- | Pairs are represented internally as a 'Tuple2' node and taken apart with
-- 'Get21' / 'Get22'.
instance (Computable a, Computable b) => Computable (a, b) where
  type Internal (a, b) = (Internal a, Internal b)

  internalize (x, y) = exprToData (Tuple2 (internalize x) (internalize y))

  externalize pair = (externalizeE (Get21 pair), externalizeE (Get22 pair))
-- | Triples are represented internally as a 'Tuple3' node and taken apart
-- with 'Get31' / 'Get32' / 'Get33'.
instance (Computable a, Computable b, Computable c) => Computable (a, b, c) where
  type Internal (a, b, c) = (Internal a, Internal b, Internal c)

  internalize (x, y, z) =
    exprToData (Tuple3 (internalize x) (internalize y) (internalize z))

  externalize triple =
    ( externalizeE (Get31 triple)
    , externalizeE (Get32 triple)
    , externalizeE (Get33 triple)
    )
-- | Quadruples are represented internally as a 'Tuple4' node and taken apart
-- with 'Get41' .. 'Get44'.
instance (Computable a, Computable b, Computable c, Computable d)
    => Computable (a, b, c, d) where
  type Internal (a, b, c, d) =
    (Internal a, Internal b, Internal c, Internal d)

  internalize (w, x, y, z) =
    exprToData $
      Tuple4 (internalize w) (internalize x) (internalize y) (internalize z)

  externalize quad =
    ( externalizeE (Get41 quad)
    , externalizeE (Get42 quad)
    , externalizeE (Get43 quad)
    , externalizeE (Get44 quad)
    )
-- | Externalize an expression by first wrapping it in a fresh 'Data' node.
externalizeE :: Computable a => Expr (Internal a) -> a
externalizeE expr = externalize (exprToData expr)
-- | Turn a function on external values into one on internal 'Data' nodes.
lowerFun :: (Computable a, Computable b) =>
  (a -> b) -> (Data (Internal a) -> Data (Internal b))
lowerFun f d = internalize (f (externalize d))
-- | Turn a function on internal 'Data' nodes into one on external values.
liftFun :: (Computable a, Computable b) =>
  (Data (Internal a) -> Data (Internal b)) -> (a -> b)
liftFun f x = externalize (f (internalize x))
-- | Evaluate an expression to its Haskell value.
--
-- 'Input' nodes have no value, so evaluating one is a runtime error.
evalE :: Expr a -> a
evalE (Input _)        = error "evaluating Input"
evalE (Value _ a)      = a
evalE (Tuple2 a b)     = (evalD a, evalD b)
evalE (Tuple3 a b c)   = (evalD a, evalD b, evalD c)
evalE (Tuple4 a b c d) = (evalD a, evalD b, evalD c, evalD d)
evalE (Get21 ab)       = fst (evalD ab)
evalE (Get22 ab)       = snd (evalD ab)
evalE (Get31 abc)      = case evalD abc of (x, _, _) -> x
evalE (Get32 abc)      = case evalD abc of (_, x, _) -> x
evalE (Get33 abc)      = case evalD abc of (_, _, x) -> x
evalE (Get41 abcd)     = case evalD abcd of (x, _, _, _) -> x
evalE (Get42 abcd)     = case evalD abcd of (_, x, _, _) -> x
evalE (Get43 abcd)     = case evalD abcd of (_, _, x, _) -> x
evalE (Get44 abcd)     = case evalD abcd of (_, _, _, x) -> x
evalE (Function _ _ f a) = f (evalD a)
evalE (NoInline _ f a)   = evalD $ subAp (deref f) a
evalE (IfThenElse c t e a)
  | evalD c   = evalD (subAp t a)
  | otherwise = evalD (subAp e a)
evalE (While cont body start) = loop start
  where
    -- Apply the body while the condition holds, then return the state.
    loop s
      | evalD (subAp cont s) = loop (subAp body s)
      | otherwise            = evalD s
evalE (Parallel l ixf) = map (evalD . subAp ixf . value) [0 .. n - 1]
  where
    -- Element i is the index function applied to i, for i in [0, n).
    n = evalD l
-- | Evaluate the expression referenced by a 'Data' node.
evalD :: Data a -> a
evalD d = evalE (dataToExpr d)
-- | Evaluate any 'Computable' value to its internal Haskell representation.
eval :: Computable a => a -> Internal a
eval x = evalD (internalize x)
-- | A literal constant, sized with 'storableSize'.
value :: Storable a => a -> Data a
value x = exprToData $ Value sz x
  where
    sz = storableSize x
-- | A literal constant whose size is the given size combined (via
-- 'mappend') with the size derived from the value itself.
array :: Storable a => Size a -> a -> Data a
array sz x = exprToData (Value combined x)
  where
    combined = mappend sz (storableSize x)
-- | A literal array whose length range is taken from the size of a 'Data'
-- 'Length'. The element size is left as 'universal'.
arrayLen :: Storable a => Data Length -> [a] -> Data [a]
arrayLen len = array sz
  where
    -- fromIntegral (not fromInteger): the source range is a Range Length,
    -- not a Range Integer. This matches the Parallel case of exprSize.
    sz = mapMonotonic fromIntegral (dataSize len) :> universal
-- | The literal unit constant.
unit :: Data ()
unit = value ()

-- | The literal boolean constants.
false, true :: Data Bool
false = value False
true  = value True
-- | The per-dimension length ranges of an array, read from its size.
size :: forall a . Storable a => Data [a] -> [Range Length]
size arr = listSize (T :: T [a]) (dataSize arr)
-- | Restrict the size annotation of a node by intersecting (@/\\@) it with
-- the given range. The referenced expression is reused unchanged.
cap :: (Storable a, Size a ~ Range b, Ord b) => Range b -> Data a -> Data a
cap bound (Data sz node) = Data (sz /\ bound) node
-- | Lift a unary Haskell function to 'Data'.
--
-- @fun@ names the function, @sizeProp@ maps the argument size to the result
-- size, and @f@ computes the value. If the argument is a literal 'Value', @f@
-- is applied immediately (constant folding). Otherwise a 'Function' node is
-- built.
function
  :: (Storable a, Storable b)
  => String -> (Size a -> Size b) -> (a -> b) -> (Data a -> Data b)
function fun sizeProp f a =
  case dataToExpr a of
    Value _ x -> Data outSize (ref (Value outSize (f x)))
    _         -> exprToData (Function fun outSize f a)
  where
    outSize = sizeProp (dataSize a)
-- | Lift a binary Haskell function to 'Data'.
--
-- If both arguments are literal values, the result is computed immediately.
-- Otherwise a 'Function' node is applied to a 'Tuple2' of the arguments.
function2
  :: (Storable a, Storable b, Storable c)
  => String
  -> (Size a -> Size b -> Size c)
  -> (a -> b -> c)
  -> (Data a -> Data b -> Data c)
function2 fun sizeProp f a b =
  case (dataToExpr a, dataToExpr b) of
    (Value _ x, Value _ y) -> Data outSize (ref (Value outSize (f x y)))
    _ -> exprToData (Function fun outSize onPair (exprToData (Tuple2 a b)))
  where
    outSize = sizeProp (dataSize a) (dataSize b)
    onPair (x, y) = f x y
-- | Lift a ternary Haskell function to 'Data'.
--
-- If all arguments are literal values, the result is computed immediately.
-- Otherwise a 'Function' node is applied to a 'Tuple3' of the arguments.
function3
  :: (Storable a, Storable b, Storable c, Storable d)
  => String
  -> (Size a -> Size b -> Size c -> Size d)
  -> (a -> b -> c -> d)
  -> (Data a -> Data b -> Data c -> Data d)
function3 fun sizeProp f a b c =
  case (dataToExpr a, dataToExpr b, dataToExpr c) of
    (Value _ x, Value _ y, Value _ z) ->
      Data outSize (ref (Value outSize (f x y z)))
    _ ->
      exprToData (Function fun outSize onTriple (exprToData (Tuple3 a b c)))
  where
    outSize = sizeProp (dataSize a) (dataSize b) (dataSize c)
    onTriple (x, y, z) = f x y z
-- | Lift a quaternary Haskell function to 'Data'.
--
-- If all arguments are literal values, the result is computed immediately.
-- Otherwise a 'Function' node is applied to a 'Tuple4' of the arguments.
function4
  :: (Storable a, Storable b, Storable c, Storable d, Storable e)
  => String
  -> (Size a -> Size b -> Size c -> Size d -> Size e)
  -> (a -> b -> c -> d -> e)
  -> (Data a -> Data b -> Data c -> Data d -> Data e)
function4 fun sizeProp f a b c d =
  case (dataToExpr a, dataToExpr b, dataToExpr c, dataToExpr d) of
    (Value _ w, Value _ x, Value _ y, Value _ z) ->
      Data outSize (ref (Value outSize (f w x y z)))
    _ ->
      exprToData (Function fun outSize onQuad (exprToData (Tuple4 a b c d)))
  where
    outSize = sizeProp (dataSize a) (dataSize b) (dataSize c) (dataSize d)
    onQuad (w, x, y, z) = f w x y z
-- | A placeholder rendering; the contents of a node are not shown.
instance Show (Data a) where
  showsPrec _ _ = showString "... :: Data a"
-- | Arithmetic on 'Data'. Each operation uses the same Num method twice:
-- first to propagate sizes, then to compute values.
instance Numeric a => Num (Data a) where
  fromInteger = value . fromInteger
  abs    = function  "abs"    abs    abs
  signum = function  "signum" signum signum
  (+)    = function2 "(+)"    (+)    (+)
  -- Subtraction: the original line here had lost its operator and did not
  -- parse.
  (-)    = function2 "(-)"    (-)    (-)
  (*)    = function2 "(*)"    (*)    (*)
-- | Division on 'Data' 'Float'. No size information is derived: the
-- result size is always 'fullRange'.
instance Fractional (Data Float) where
  fromRational r = value (fromRational r)
  (/) = function2 "(/)" (const (const fullRange)) (/)
-- | Read element @i@ of an array.
--
-- The result size is the array's element size. When evaluated, it is an
-- error if the index is outside @0 .. l-1@ (where @l@ is the array's length
-- range) or at/after the end of the actual list.
getIx :: Storable a => Data [a] -> Data Int -> Data a
getIx arr = function2 "(!)" sizeProp f arr
  where
    sizeProp (_ :> aSize) _ = aSize
    f xs i
      | not (i `inRange` r) = error "getIx: index out of bounds"
      | i >= la             = error "getIx: reading garbage"
      | otherwise           = xs !! i
      where
        la = length xs
    l :> _ = dataSize arr
    -- Valid indices: 0 .. l-1. The original referred to an undefined `l1`.
    r = rangeByRange 0 (l - 1)
-- | Write element @i@ of an array.
--
-- The length range is kept. The element size becomes the old element size
-- combined (via 'mappend') with the size of the new element. When evaluated,
-- it is an error if the index is outside @0 .. l-1@ or beyond the current
-- list length. Writing exactly at the current length appends one element.
setIx :: Storable a => Data [a] -> Data Int -> Data a -> Data [a]
setIx arr = function3 "setIx" sizeProp f arr
  where
    sizeProp (len :> aSize) _ aSize' = len :> (aSize `mappend` aSize')
    f xs i a
      | not (i `inRange` r) = error "setIx: index out of bounds"
      | i > la              = error "setIx: writing past initialized area"
      | otherwise           = take i xs ++ [a] ++ drop (i + 1) xs
      where
        la = length xs
    l :> _ = dataSize arr
    -- Valid indices: 0 .. l-1. The original referred to an undefined `l1`.
    r = rangeByRange 0 (l - 1)
-- | Containers that can be indexed by a 'Data' 'Int'.
class RandomAccess a where
  -- The type of a single element.
  type Element a
  infixl 9 !
  (!) :: a -> Data Int -> Element a
-- | Indexing into a 'Data' array is 'getIx'.
instance Storable a => RandomAccess (Data [a]) where
  type Element (Data [a]) = Data a

  (!) = getIx
-- | Build a sub-function by applying @f@ to a symbolic 'Input' of the given
-- size. The same input node is stored and used as the argument.
mkSubFun :: Typeable a => Size a -> (Data a -> Data b) -> (a :-> b)
mkSubFun sz f =
  let arg = exprToData (Input sz)
  in  SubFunction f arg (f arg)
-- | Apply @f@ to @a@ through a 'NoInline' node. The node holds a reference to
-- @f@ as a sub-function built for the size of @a@, under the name @fun@.
noInline :: (Computable a, Computable b) => String -> (a -> b) -> (a -> b)
noInline fun f a = liftFun wrap a
  where
    wrap    = exprToData . NoInline fun (ref subFun)
    subFun  = mkSubFun argSize (lowerFun f)
    argSize = dataSize $ internalize a
-- | Conditional application. A literal condition selects a branch directly.
-- Otherwise both branches become sub-functions of an 'IfThenElse' node.
ifThenElse
  :: (Computable a, Computable b)
  => Data Bool -> (a -> b) -> (a -> b) -> (a -> b)
ifThenElse cond t e a =
  case dataToExpr cond of
    Value _ b -> if b then t a else e a
    _         -> liftFun (exprToData . IfThenElse cond thenSub elseSub) a
  where
    argSize = dataSize $ internalize a
    thenSub = mkSubFun argSize $ lowerFun t
    elseSub = mkSubFun argSize $ lowerFun e
-- | A 'While' loop whose state has the given size. The condition and body
-- become sub-functions over a state input of that size.
whileSized
  :: Computable state
  => Size (Internal state)
  -> (state -> Data Bool)
  -> (state -> state)
  -> (state -> state)
whileSized sz cont body start = liftFun loopNode start
  where
    loopNode = exprToData . While condSub stepSub
    condSub  = mkSubFun sz $ lowerFun cont
    stepSub  = mkSubFun sz $ lowerFun body
-- | 'whileSized' with a 'universal' state size.
while
  :: Computable state
  => (state -> Data Bool)
  -> (state -> state)
  -> (state -> state)
while cont body = whileSized universal cont body
-- | Build an array of length @l@ whose element @i@ is @ixf i@.
parallel :: Storable a => Data Length -> (Data Int -> Data a) -> Data [a]
parallel l ixf = exprToData $ Parallel l ixfSub
  where
    szl = dataSize l
    -- The index ranges over 0 .. l-1. The original referred to an undefined
    -- `szl1`.
    ixfSub = mkSubFun (rangeByRange 0 (szl - 1)) ixf