module Data.Tensor.Compiler.MultiStage where
-- | First-stage (surface) expression language.
--
-- 'evalE1' lowers these expressions into the primitive 'E2' language.
data E1 a
  = K1 a     -- ^ A literal value.
  | Dot a a  -- ^ Dot product of two operands; lowered to a multiply-then-sum.
  deriving (Show, Eq)
-- | Lower a first-stage expression into the second-stage language.
--
-- Literals carry over unchanged. A dot product becomes an element-wise
-- multiplication of the two operands, reduced with an additive fold.
evalE1 :: E1 a -> E2 a
evalE1 (K1 v)      = K2 v
evalE1 (Dot xs ys) = Fold Add (ZipWith Mul (K2 xs) (K2 ys))
-- | Second-stage expression language of element-wise and reduction
-- primitives, interpreted by 'evalE2'.
data E2 a
  = K2 a                         -- ^ A literal value.
  | ZipWith BinOp (E2 a) (E2 a)  -- ^ Combine two sub-expressions element-wise.
  | Fold BinOp (E2 a)            -- ^ Reduce a sub-expression with an operator.
  deriving (Show, Eq)
-- | Interpret a second-stage expression over lists of numbers.
--
-- The first argument is the seed for every 'Fold', which is a right fold.
-- The same seed is used whatever the folded operator, so callers
-- presumably pass the identity of the operator they fold with (e.g. @0@
-- for 'Add').
--
-- 'ZipWith' truncates to the shorter operand, as 'zipWith' does. A 'Fold'
-- always yields a singleton list.
evalE2 :: Fractional b => b -> E2 [b] -> [b]
evalE2 z = go
  where
    go (K2 xs)          = xs
    go (ZipWith op l r) = zipWith (evalBinOp op) (go l) (go r)
    go (Fold op e)      = [foldr (evalBinOp op) z (go e)]
-- | Lift a value into a literal 'E2' expression.
k2 :: a -> E2 a
k2 x = K2 x
-- | Build an element-wise 'Add' of two literal operands.
addE2 :: a -> a -> E2 a
addE2 x = ZipWith Add (k2 x) . k2
-- | Build an element-wise 'Mul' of two literal operands.
mulE2 :: a -> a -> E2 a
mulE2 x = ZipWith Mul (k2 x) . k2
-- | Evaluate a first-stage expression end to end.
--
-- The expression is lowered with 'evalE1' and interpreted with 'evalE2'.
-- The seed @z@ is forwarded to 'evalE2' for its folds.
evalE :: Fractional b => b -> E1 [b] -> [b]
evalE z expr = evalE2 z (evalE1 expr)
-- | Binary arithmetic operators shared by both expression stages.
--
-- 'evalBinOp' interprets them as arithmetic functions.
data BinOp
  = Add  -- ^ Addition.
  | Sub  -- ^ Subtraction.
  | Mul  -- ^ Multiplication.
  | Div  -- ^ Division.
  deriving (Show, Eq)
-- | Map a 'BinOp' to the arithmetic function it denotes.
--
-- The 'Fractional' constraint is needed because 'Div' is interpreted
-- as '(/)'.
evalBinOp :: Fractional a => BinOp -> a -> a -> a
evalBinOp op = case op of
  Add -> (+)
  Sub -> (-)  -- was @()@: the unit value, which is not a binary function
  Mul -> (*)
  Div -> (/)