module Feldspar.Core.Expr where
import Control.Monad.State
import Control.Monad.Writer
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe
import Data.Unique
import Types.Data.Num
import Feldspar.Core.Ref (Ref)
import qualified Feldspar.Core.Ref as Ref
import Feldspar.Core.Types
import Feldspar.Core.Graph hiding (function, Function (..))
import qualified Feldspar.Core.Graph as Graph
import Feldspar.Core.Show
-- | A reference to an expression node. The 'Typeable' dictionary is stored
-- in the constructor, so pattern matching on 'Data' brings 'Typeable a' back
-- into scope.
data Data a
  where
    Data :: Typeable a => Ref (Expr a) -> Data a
-- | Two 'Data' values are equal exactly when their underlying 'Ref's are.
instance Eq (Data a)
  where
    (==) (Data r1) (Data r2) = r1 == r2
-- | 'Data' values are ordered by their underlying 'Ref's.
instance Ord (Data a)
  where
    compare (Data r1) (Data r2) = compare r1 r2
-- | Turn an expression into a 'Data' node by wrapping it with 'Ref.ref'.
ref :: Typeable a => Expr a -> Data a
ref expr = Data (Ref.ref expr)
-- | Identity of a node, taken from its underlying 'Ref'. It is the key of
-- the 'visited' map used while building graphs.
refId :: Data a -> Unique
refId dat = case dat of
    Data r -> Ref.refId r
-- | The expression stored in a 'Data' node.
deref :: Data a -> Expr a
deref dat = case dat of
    Data r -> Ref.deref r
-- | The storable type of a 'Data' node. It is computed from the type alone;
-- the argument itself is never inspected.
typeOfData :: forall a . Typeable a => Data a -> Tuple StorableType
typeOfData = const (typeOf (T :: T a))
-- | Core expression nodes. Sub-expressions are referenced through 'Data', so
-- every node has an identity ('refId') and can be shared during graph
-- building.
data Expr a
  where
    -- Argument placeholder of a sub-function ('buildSubFun' creates one).
    -- It has no value: 'evalE' raises an error on it.
    Input :: Expr a
    -- A literal value.
    Value :: Storable a => a -> Expr a
    -- Tuple construction from two, three or four components.
    Tuple2 :: Data a -> Data b -> Expr (a,b)
    Tuple3 :: Data a -> Data b -> Data c -> Expr (a,b,c)
    Tuple4 :: Data a -> Data b -> Data c -> Data d -> Expr (a,b,c,d)
    -- Projection of the tuple component selected by the type-level index n.
    GetTuple :: GetTuple n a => T n -> Data a -> Expr (Part n a)
    -- A named primitive applied to an argument. 'evalE' uses the Haskell
    -- function; 'buildGraph' uses the name (Graph.Function).
    Function :: String -> (a -> b) -> (Data a -> Expr b)
    -- A named sub-function kept behind a 'Ref'. 'buildGraph' builds its body
    -- as a separate interface for a Graph.NoInline node.
    NoInline
        :: (Typeable a, Typeable b)
        => String -> Ref.Ref (Data a -> Data b) -> (Data a -> Expr b)
    -- Conditional: condition, then-branch, else-branch, argument. 'evalE'
    -- applies the branch selected by the condition to the argument.
    IfThenElse
        :: (Typeable a, Typeable b)
        => Data Bool
        -> (Data a -> Data b)
        -> (Data a -> Data b)
        -> (Data a -> Expr b)
    -- Loop: continue-condition, body, initial state. 'evalE' checks the
    -- condition before each iteration.
    While
        :: Typeable a
        => (Data a -> Data Bool)
        -> (Data a -> Data a)
        -> Data a
        -> Expr a
    -- Array built from an index function. 'evalE' evaluates the indices
    -- 0 .. sz-1; 'buildGraph' takes the length from the type-level n.
    Parallel
        :: (NaturalT n, Storable a)
        => Data Int
        -> (Data Int -> Data a)
        -> Expr (n :> a)
-- | Types that can be packed into, and unpacked from, a single 'Data' value.
-- For 'Data' itself this is the identity. For tuples the internal value is a
-- tuple of the components' internal types.
class Typeable (Internal a) => Computable a
  where
    -- | Type of the single 'Data' value that represents @a@.
    type Internal a
    -- | Pack a value into one 'Data' node.
    internalize :: a -> Data (Internal a)
    -- | Unpack a 'Data' node. For tuples this projects each component.
    externalize :: Data (Internal a) -> a
-- | A single 'Data' value is already in internal form.
instance Storable a => Computable (Data a)
  where
    type Internal (Data a) = a
    internalize d = d
    externalize d = d
-- | Pairs are packed into a 'Tuple2' node and unpacked by projecting each
-- component with 'GetTuple'.
instance (Computable a, Computable b) => Computable (a,b)
  where
    type Internal (a,b) = (Internal a, Internal b)

    internalize (x, y) = ref (Tuple2 (internalize x) (internalize y))

    externalize pair =
        ( (externalize . ref . GetTuple (T :: T D0)) pair
        , (externalize . ref . GetTuple (T :: T D1)) pair
        )
-- | Triples are packed into a 'Tuple3' node and unpacked by projecting each
-- component with 'GetTuple'.
instance (Computable a, Computable b, Computable c) => Computable (a,b,c)
  where
    type Internal (a,b,c) = (Internal a, Internal b, Internal c)

    internalize (x, y, z) =
        ref (Tuple3 (internalize x) (internalize y) (internalize z))

    externalize triple =
        ( (externalize . ref . GetTuple (T :: T D0)) triple
        , (externalize . ref . GetTuple (T :: T D1)) triple
        , (externalize . ref . GetTuple (T :: T D2)) triple
        )
-- | Quadruples are packed into a 'Tuple4' node and unpacked by projecting
-- each component with 'GetTuple'.
instance
    ( Computable a
    , Computable b
    , Computable c
    , Computable d
    ) =>
      Computable (a,b,c,d)
  where
    type Internal (a,b,c,d) = (Internal a, Internal b, Internal c, Internal d)

    internalize (w, x, y, z) = ref $
        Tuple4 (internalize w) (internalize x) (internalize y) (internalize z)

    externalize quad =
        ( (externalize . ref . GetTuple (T :: T D0)) quad
        , (externalize . ref . GetTuple (T :: T D1)) quad
        , (externalize . ref . GetTuple (T :: T D2)) quad
        , (externalize . ref . GetTuple (T :: T D3)) quad
        )
-- | Lift a function on 'Computable' values to a function on their internal
-- 'Data' representations.
wrap :: (Computable a, Computable b) =>
    (a -> b) -> (Data (Internal a) -> Data (Internal b))
wrap f x = internalize (f (externalize x))
-- | Convert a function on internal 'Data' representations back into a
-- function on 'Computable' values (the converse of 'wrap').
unwrap :: (Computable a, Computable b) =>
    (Data (Internal a) -> Data (Internal b)) -> (a -> b)
unwrap f x = externalize (f (internalize x))
-- | Evaluate an expression node directly in Haskell. Sub-expressions are
-- evaluated with 'evalD'.
--
-- Evaluating 'Input' is an error: an input placeholder only exists while a
-- sub-function is being turned into a graph.
evalE :: Expr a -> a
evalE Input            = error "evaluating Input"
evalE (Value a)        = a
evalE (Tuple2 a b)     = (evalD a, evalD b)
evalE (Tuple3 a b c)   = (evalD a, evalD b, evalD c)
evalE (Tuple4 a b c d) = (evalD a, evalD b, evalD c, evalD d)
evalE (GetTuple n a)   = getTup n (evalD a)
evalE (Function _ f a) = f (evalD a)
evalE (NoInline _ f a) = evalD (Ref.deref f a)
evalE (IfThenElse c t e a)
    | evalD c   = evalD (t a)
    | otherwise = evalD (e a)
evalE (While cont body start) = loop start
  where
    -- The condition is checked before every iteration, so the body may run
    -- zero times.
    loop s
        | evalD (cont s) = loop (body s)
        | otherwise      = evalD s
evalE (Parallel sz ixf) =
    -- One element per index 0 .. sz-1 (the upper bound is exclusive).
    mapArray (evalD . ixf . value) $ fromList [(0::Int) .. n - 1]
  where
    n = evalD sz
-- | Evaluate the expression stored in a 'Data' node.
evalD :: Data a -> a
evalD d = evalE (deref d)
-- | Evaluate a 'Computable' value to its internal Haskell representation.
eval :: Computable a => a -> Internal a
eval x = evalD (internalize x)
-- | Placeholder 'Show' instance. It neither evaluates nor inspects the value
-- and always produces the same string.
instance Primitive a => Show (Data a)
  where
    show _ = "... :: Data a"
-- | Arithmetic on 'Data'. Each operation is constant-folded when all of its
-- arguments are literal 'Value' nodes ('functionFold' / 'functionFold2').
-- Otherwise it builds a named 'Function' node.
--
-- 'negate' is left to the class default (@0 - x@), which relies on '(-)'
-- being defined here.
instance (Num n, Primitive n) => Num (Data n)
  where
    fromInteger = value . fromInteger
    abs         = functionFold  "abs"    abs
    signum      = functionFold  "signum" signum
    (+)         = functionFold2 "(+)"    (+)
    (-)         = functionFold2 "(-)"    (-)
    (*)         = functionFold2 "(*)"    (*)
-- | Division on 'Data Float', constant-folded when both operands are
-- literals.
instance Fractional (Data Float)
  where
    fromRational r = value (fromRational r)
    x / y          = functionFold2 "(/)" (/) x y
-- | A literal node for any storable value, including arrays.
value_ :: Storable a => a -> Data a
value_ x = ref (Value x)
-- | A literal node restricted to primitive types.
value :: Primitive a => a -> Data a
value x = value_ x
-- | The unit literal.
unit :: Data ()
unit = value ()

-- | Boolean literals.
true, false :: Data Bool
true  = value True
false = value False
-- | A literal array node built from its list-based representation.
array :: (NaturalT n, Storable a) => ListBased (n :> a) -> Data (n :> a)
array xs = value_ (fromList xs)
-- | The size list stored in the array's 'StorableType', taken from the type
-- of the array. The array's value is never evaluated.
size :: (NaturalT n, Storable a) => Data (n :> a) -> [Int]
size arr =
    let One (StorableType dims _) = typeOfData arr
    in  dims
-- | Apply a named primitive to a node, always building a 'Function' node.
-- 'functionFold' is the constant-folding variant.
function :: (Storable a, Storable b) => String -> (a -> b) -> (Data a -> Data b)
function fun f a = ref (Function fun f a)
-- | Two-argument named primitive: the arguments are packed into a 'Tuple2'
-- node and passed to a single 'Function' node.
function2
    :: ( Storable a
       , Storable b
       , Storable c
       )
    => String -> (a -> b -> c) -> (Data a -> Data b -> Data c)
function2 fun f a b = ref (Function fun g (ref (Tuple2 a b)))
  where
    g (x, y) = f x y
-- | Three-argument named primitive: the arguments are packed into a 'Tuple3'
-- node and passed to a single 'Function' node.
function3
    :: ( Storable a
       , Storable b
       , Storable c
       , Storable d
       )
    => String -> (a -> b -> c -> d) -> (Data a -> Data b -> Data c -> Data d)
function3 fun f a b c = ref (Function fun g (ref (Tuple3 a b c)))
  where
    g (x, y, z) = f x y z
-- | Four-argument named primitive: the arguments are packed into a 'Tuple4'
-- node and passed to a single 'Function' node.
function4
    :: ( Storable a
       , Storable b
       , Storable c
       , Storable d
       , Storable e
       )
    => String
    -> (a -> b -> c -> d -> e)
    -> (Data a -> Data b -> Data c -> Data d -> Data e)
function4 fun f a b c d = ref (Function fun g (ref (Tuple4 a b c d)))
  where
    g (w, x, y, z) = f w x y z
-- | Like 'function', but when the argument is a literal 'Value' the result
-- is computed immediately and returned as a new literal.
functionFold
    :: (Storable a, Storable b) => String -> (a -> b) -> (Data a -> Data b)
functionFold fun f a
    | Value x <- deref a = value_ (f x)
    | otherwise          = function fun f a
-- | Like 'function2', but folds to a literal when both arguments are
-- literals. Arguments are inspected left to right, stopping at the first
-- non-literal.
functionFold2
    :: ( Storable a
       , Storable b
       , Storable c
       )
    => String -> (a -> b -> c) -> (Data a -> Data b -> Data c)
functionFold2 fun f a b
    | Value x <- deref a
    , Value y <- deref b
    = value_ (f x y)
    | otherwise = function2 fun f a b
-- | Like 'function3', but folds to a literal when all arguments are
-- literals. Arguments are inspected left to right, stopping at the first
-- non-literal.
functionFold3
    :: ( Storable a
       , Storable b
       , Storable c
       , Storable d
       )
    => String -> (a -> b -> c -> d) -> (Data a -> Data b -> Data c -> Data d)
functionFold3 fun f a b c
    | Value x <- deref a
    , Value y <- deref b
    , Value z <- deref c
    = value_ (f x y z)
    | otherwise = function3 fun f a b c
-- | Like 'function4', but folds to a literal when all arguments are
-- literals. Arguments are inspected left to right, stopping at the first
-- non-literal.
functionFold4
    :: ( Storable a
       , Storable b
       , Storable c
       , Storable d
       , Storable e
       )
    => String -> (a -> b -> c -> d -> e)
    -> (Data a -> Data b -> Data c -> Data d -> Data e)
functionFold4 fun f a b c d
    | Value w <- deref a
    , Value x <- deref b
    , Value y <- deref c
    , Value z <- deref d
    = value_ (f w x y z)
    | otherwise = function4 fun f a b c d
-- | Array indexing, constant-folded on literal arguments (node name "(!)").
--
-- Indices outside @[0, n)@, where @n@ is the type-level size, are rejected.
-- So are indices beyond the elements actually present in the list.
getIx :: forall n a . (NaturalT n, Storable a) =>
    Data (n :> a) -> Data Int -> Data a
getIx = functionFold2 "(!)" lookupIx
  where
    lookupIx :: (n :> a) -> Int -> a
    lookupIx (ArrayList xs) i
        | i >= bound || i < 0 = error "getIx: index out of bounds"
        | i >= len            = error "getIx: reading garbage"
        | otherwise           = xs !! i
      where
        bound = fromIntegerT (undefined :: n)
        len   = length xs
-- | Functional array update, constant-folded on literal arguments (node name
-- "setIx").
--
-- Indices outside @[0, n)@ are rejected. Writing at exactly the current
-- length appends one element; writing further past the end is an error.
setIx :: forall n a . (NaturalT n, Storable a) =>
    Data (n :> a) -> Data Int -> Data a -> Data (n :> a)
setIx = functionFold3 "setIx" update
  where
    update :: (n :> a) -> Int -> a -> (n :> a)
    update (ArrayList xs) i x
        | i >= bound || i < 0 = error "setIx: index out of bounds"
        | i > len             = error "setIx: writing past initialized area"
        | otherwise           = ArrayList (before ++ x : drop 1 after)
      where
        bound           = fromIntegerT (undefined :: n)
        len             = length xs
        (before, after) = splitAt i xs
-- | Containers that can be indexed by a @Data Int@.
class RandomAccess a
  where
    -- | Type of the value produced by indexing.
    type Elem a
    -- | Index into the container.
    (!) :: a -> Data Int -> Elem a
-- | Arrays are indexed with 'getIx'.
instance (NaturalT n, Storable a) => RandomAccess (Data (n :> a))
  where
    type Elem (Data (n :> a)) = Data a
    arr ! ix = getIx arr ix
-- | Mark a function to be kept as a named sub-function ('NoInline' node)
-- instead of having its body expanded at each call site.
noInline :: (Computable a, Computable b) => String -> (a -> b) -> (a -> b)
noInline fun f = unwrap (ref . NoInline fun (Ref.ref (wrap f)))
-- | Conditional choice between two functions. A literal condition selects
-- the branch immediately. Otherwise an 'IfThenElse' node is built.
ifThenElse
    :: (Computable a, Computable b)
    => Data Bool -> (a -> b) -> (a -> b) -> (a -> b)
ifThenElse cond thenBranch elseBranch
    | Value b <- deref cond = if b then thenBranch else elseBranch
    | otherwise             =
        unwrap (ref . IfThenElse cond (wrap thenBranch) (wrap elseBranch))
-- | Loop: apply @body@ to the state while @cont@ holds (builds a 'While'
-- node).
while
    :: Computable a
    => (a -> Data Bool)
    -> (a -> a)
    -> (a -> a)
while cont body = unwrap (ref . While (cont . externalize) (wrap body))
-- | Build an array whose elements are given by an index function ('Parallel'
-- node).
parallel :: (NaturalT n, Storable a) =>
    Data Int -> (Data Int -> Data a) -> Data (n :> a)
parallel sz ixf = ref (Parallel sz ixf)
-- | State threaded through graph building.
data Info = Info
    { index :: NodeId
      -- ^ Next node id to hand out ('newIndex' returns it and increments it).
    , visited :: Map Unique NodeId
      -- ^ Node id of every already-built 'Data', keyed by its 'refId'.
    }
-- | Graph-building monad: emits 'Node's through the writer while threading
-- 'Info' through the state.
type GraphBuilder = WriterT [Node] (State Info)
-- | Initial builder state: ids start at 0 and nothing has been visited.
startInfo :: Info
startInfo = Info { index = 0, visited = Map.empty }
-- | Run a builder from the given state. Returns its result together with the
-- emitted nodes and the final state.
runGraph :: GraphBuilder a -> Info -> (a, ([Node], Info))
runGraph graph = reassoc . runState (runWriterT graph)
  where
    -- Lazy pattern, matching the lazy binding of the results.
    reassoc ~((result, nodes), final) = (result, (nodes, final))
-- | Return the current node id and advance the counter.
newIndex :: GraphBuilder NodeId
newIndex = do
    i <- gets index
    modify (\info -> info {index = succ i})
    return i
-- | Record that @dat@ has been given node id @i@.
remember :: Data a -> NodeId -> GraphBuilder ()
remember dat i = modify record
  where
    key = refId dat
    record info = info {visited = Map.insert key i (visited info)}
-- | Node id already assigned to @dat@, if any.
checkNode :: Data a -> GraphBuilder (Maybe NodeId)
checkNode dat = do
    seen <- gets visited
    return (Map.lookup (refId dat) seen)
-- | Pair node id @i@ with every tuple path of the type @a@.
tupleBind :: Typeable a => NodeId -> T a -> Tuple Variable
tupleBind i t = fmap ((,) i) (tuplePath (typeOf t))
-- | Allocate a node id for @dat@, record it, and emit a 'Node' with the given
-- function, inputs and input type. The output type comes from @dat@'s type.
node
    :: forall a . Typeable a
    => Data a
    -> Graph.Function
    -> Tuple Source
    -> Tuple StorableType
    -> GraphBuilder ()
node dat fun inTup inType = do
    nid <- newIndex
    remember dat nid
    tell [Node nid fun inTup inType (typeOf (T :: T a))]
-- | Emit a node that has no inputs. Matching on 'Data' brings the 'Typeable'
-- dictionary that 'node' needs into scope.
sourceNode :: Data a -> Graph.Function -> GraphBuilder ()
sourceNode dat fun = case dat of
    Data _ -> node dat fun (Tup []) (Tup [])
-- | Describe how a value is referred to as a node input.
--
-- 'GetTuple' projections are followed down to the tuple they select from,
-- prepending each component index to @path@ along the way. Primitive
-- literals become inline 'Constant's. Anything else must already have a node
-- in the graph and becomes a 'Variable' for that node and path.
source :: forall a . [Int] -> Data a -> GraphBuilder Source
source path a = case deref a of
    GetTuple n tup -> source (numberT n : path) tup
    Value a | isPrimitive (T::T a) ->
        let PrimitiveData a' = toData a
        in return $ Constant a'
    _ -> do
        mi <- checkNode a
        case mi of
            Just i  -> return $ Variable (i,path)
            -- Callers build a value before tracing it, so a missing node is
            -- an internal error. Matching explicitly, instead of using a
            -- failable 'Just i <-' bind, avoids requiring a 'MonadFail'
            -- instance that 'GraphBuilder' does not have.
            Nothing -> error "source: no node has been built for this value"
-- | Trace the tuple structure of a value. Explicit 'Tuple2' / 'Tuple3' /
-- 'Tuple4' nodes become a 'Tup' of their traced components, visited left to
-- right. Any other value is a single 'source'.
traceTuple :: Data a -> GraphBuilder (Tuple Source)
traceTuple a = case deref a of
    Tuple2 b c     -> components [traceTuple b, traceTuple c]
    Tuple3 b c d   -> components [traceTuple b, traceTuple c, traceTuple d]
    Tuple4 b c d e ->
        components [traceTuple b, traceTuple c, traceTuple d, traceTuple e]
    _              -> liftM One (source [] a)
  where
    components :: [GraphBuilder (Tuple Source)] -> GraphBuilder (Tuple Source)
    components = liftM Tup . sequence
-- | Add the node for a 'Data' value to the graph being built, together with
-- the nodes of everything it depends on. Values already present in the
-- 'visited' map are skipped, so a shared sub-expression gets a single node.
buildGraph :: forall a . Data a -> GraphBuilder ()
buildGraph dat@(Data _) = do
    idat <- checkNode dat
    unless (isJust idat) $ list (deref dat)
  where
    -- Build the argument first, then emit one node for 'dat' whose inputs
    -- are traced through the argument's tuple structure.
    funcNode fun inp@(Data _) = do
        buildGraph inp
        inTup <- traceTuple inp
        node dat fun inTup (typeOfData inp)

    -- Dispatch on the expression stored in 'dat'.
    list :: Expr a -> GraphBuilder ()
    list Input = sourceNode dat Graph.Input
    list (Value a)
        -- Primitive literals get no node: 'source' inlines them as constants.
        | isPrimitive (T::T a) = return ()
        | otherwise = sourceNode dat $ Graph.Array $ toData a
    -- Tuples and projections get no node of their own; only their
    -- components are built ('traceTuple' / 'source' resolve them later).
    list (Tuple2 a b) = buildGraph a >> buildGraph b
    list (Tuple3 a b c) = buildGraph a >> buildGraph b >> buildGraph c
    list (Tuple4 a b c d) =
        buildGraph a >> buildGraph b >> buildGraph c >> buildGraph d
    list (GetTuple _ a) = buildGraph a
    list (Function fun _ a) = funcNode (Graph.Function fun) a
    list (NoInline fun f a) = do
        iface <- buildSubFun (Ref.deref f)
        funcNode (Graph.NoInline fun iface) a
    -- The condition and the argument are passed as one pair input.
    list (IfThenElse cond t e a) = do
        ifaceThen <- buildSubFun t
        ifaceElse <- buildSubFun e
        funcNode (Graph.IfThenElse ifaceThen ifaceElse) (ref $ Tuple2 cond a)
    list (While cont body a) = do
        ifaceCont <- buildSubFun cont
        ifaceBody <- buildSubFun body
        funcNode (Graph.While ifaceCont ifaceBody) a
    list (Parallel sz ixf) = do
        iface <- buildSubFun ixf
        funcNode (Graph.Parallel n iface) sz
      where
        -- Array length taken from the first size of the result type.
        One (StorableType (n:_) _) = typeOfData dat
-- | Build the graph of a sub-function by applying it to an 'Input'
-- placeholder. Returns its interface: the input node id, the traced output,
-- and the input and output types.
buildSubFun :: forall a b . (Typeable a, Typeable b) =>
    (Data a -> Data b) -> GraphBuilder Interface
buildSubFun f = do
    -- 'inp' is bound once so that every use below refers to the same node.
    let inp  = ref Input :: Data a
        outp = f inp
    buildGraph inp
    buildGraph outp
    outTup <- traceTuple outp
    inId   <- gets ((Map.! refId inp) . visited)
    return (Interface inId outTup (typeOf (T :: T a)) (typeOf (T :: T b)))
-- | Build the complete graph of a function on 'Data', starting from an empty
-- builder state.
toGraphD :: (Typeable a, Typeable b) => (Data a -> Data b) -> Graph
toGraphD f =
    let (iface, (nodes, _)) = runGraph (buildSubFun f) startInfo
    in  Graph nodes iface
-- | Things that can be turned into a core 'Graph'. The instances below cover
-- a 'Computable' value (no argument) and functions of up to four
-- 'Computable' arguments.
class Program a
  where
    -- | Build the graph of the program.
    toGraph :: a -> Graph
    -- | Whether the program takes an argument (decided from the type only).
    hasArg :: T a -> Bool
-- | A plain value is treated as a function of a unit argument that it
-- ignores.
instance Computable a => Program a
  where
    toGraph a = toGraphD constFun
      where
        constFun :: Data () -> Data (Internal a)
        constFun = const (internalize a)
    hasArg _ = False
-- | One-argument functions are lifted to 'Data' with 'wrap'.
instance (Computable a, Computable b) => Program (a -> b)
  where
    toGraph f = toGraphD (wrap f)
    hasArg _  = True
-- | Two-argument functions are uncurried into a function of a pair.
instance (Computable a, Computable b, Computable c)
    => Program (a -> b -> c)
  where
    toGraph f = toGraph (uncurry f)
    hasArg _  = True
-- | Three-argument functions are uncurried into a function of a triple.
instance (Computable a, Computable b, Computable c, Computable d)
    => Program (a -> b -> c -> d)
  where
    toGraph f = toGraph g
      where
        g (x, y, z) = f x y z
    hasArg _ = True
-- | Four-argument functions are uncurried into a function of a quadruple.
instance
    ( Computable a
    , Computable b
    , Computable c
    , Computable d
    , Computable e
    ) =>
      Program (a -> b -> c -> d -> e)
  where
    toGraph f = toGraph g
      where
        g (w, x, y, z) = f w x y z
    hasArg _ = True
-- | Render the core graph of a program under the name "program".
showCore :: forall a . Program a => a -> String
showCore prog = showGraph "program" withArg (toGraph prog)
  where
    withArg = hasArg (T :: T a)
-- | Print the rendering produced by 'showCore' to standard output.
printCore :: Program a => a -> IO ()
printCore prog = putStrLn (showCore prog)