module Dvda.Algorithm.Eval
( runAlgorithm
, runAlgorithm'
) where
import Control.Monad.ST ( ST, runST )
import Data.Vector.Generic ( (!) )
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Generic.Mutable as GM
import Dvda.Expr
import Dvda.Algorithm.Construct ( Algorithm(..), AlgOp(..), InputIdx(..), OutputIdx(..) )
import Dvda.Algorithm.FunGraph ( Node(..) )
-- | A single compiled step of an 'Algorithm', runnable in 'ST'.
--
-- The wrapped action receives, in this order: the mutable work buffer, the
-- immutable input vector, and the mutable output vector. It is polymorphic
-- in the state thread @s@, so one step value can be run inside any 'ST'
-- computation.
newtype RtOp v a =
  RtOp (forall s . G.Mutable v s a
                -> v a
                -> G.Mutable v s a
                -> ST s ())
-- | Run an 'Algorithm' on an input vector and return a freshly allocated
-- output vector.
--
-- Returns 'Left' with a message if the input length differs from
-- 'algInDims'. Otherwise it allocates a work buffer of 'algWorkSize'
-- elements and an output buffer of 'algOutDims' elements, executes every op
-- of the algorithm in order, and returns the output buffer in 'Right'.
--
-- NOTE(review): 'GM.new' does not initialise its contents, so this assumes
-- the algorithm's output ops write every output index — confirm.
runAlgorithm :: G.Vector v a => Algorithm a -> v a -> Either String (v a)
runAlgorithm alg =
  runAlg'' (algInDims alg) (algOutDims alg) (algWorkSize alg) (map toRtOp (algOps alg))
  where
    runAlg'' :: G.Vector v a => Int -> Int -> Int -> [RtOp v a] -> v a -> Either String (v a)
    runAlg'' inSize outSize workSize ops inputVec
      | G.length inputVec /= inSize =
          Left $ "runAlg: input dimension mismatch, given: " ++ show (G.length inputVec) ++
            ", expected: " ++ show inSize
      | otherwise = Right $ runST $ do
          workVec <- GM.new workSize
          outputVec <- GM.new outSize
          mapM_ (\(RtOp op) -> op workVec inputVec outputVec) ops
          -- outputVec is local to this ST block and is never written after
          -- this point, so freezing it in place is safe and avoids the
          -- full copy that 'G.freeze' would make.
          G.unsafeFreeze outputVec
-- | Run an 'Algorithm' inside 'ST' and write its results into a mutable
-- output vector supplied by the caller.
--
-- Yields @Just msg@ without running any op in two cases:
--
-- * the input length differs from 'algInDims' (checked first);
-- * the output length differs from 'algOutDims'.
--
-- Otherwise it allocates a work buffer of 'algWorkSize' elements, runs every
-- op of the algorithm in order, and yields 'Nothing'.
runAlgorithm' :: G.Vector v a => Algorithm a -> v a -> G.Mutable v s a -> ST s (Maybe String)
runAlgorithm' alg =
  execInto (algInDims alg) (algOutDims alg) (algWorkSize alg) (map toRtOp (algOps alg))
  where
    execInto :: G.Vector v a
             => Int -> Int -> Int -> [RtOp v a] -> v a -> G.Mutable v s a -> ST s (Maybe String)
    execInto nIn nOut nWork steps xs ys =
      case problem of
        Just msg -> return (Just msg)
        Nothing -> do
          scratch <- GM.new nWork
          mapM_ (\(RtOp step) -> step scratch xs ys) steps
          return Nothing
      where
        -- The first dimension mismatch found, if any (input checked before output).
        problem
          | G.length xs /= nIn =
              Just $ "runAlg': input dimension mismatch, given: " ++ show (G.length xs) ++
                ", expected: " ++ show nIn
          | GM.length ys /= nOut =
              Just $ "runAlg': output dimension mismatch, given: " ++ show (GM.length ys) ++
                ", expected: " ++ show nOut
          | otherwise = Nothing
-- | Build a step for a binary operation.
--
-- It reads the work-buffer slot of the first operand node, then that of the
-- second, combines the two values with the given function, and writes the
-- result to the destination node's slot. The input and output vectors are
-- ignored.
bin :: GM.MVector (G.Mutable v) a => Node -> Node -> Node -> (a -> a -> a) -> RtOp v a
bin (Node dst) (Node lhs) (Node rhs) combine =
  RtOp $ \scratch _ _ ->
    GM.read scratch lhs >>= \l ->
    GM.read scratch rhs >>= \r ->
    GM.write scratch dst (combine l r)
-- | Build a step for a unary operation.
--
-- It reads the work-buffer slot of the operand node, applies the given
-- function, and writes the result to the destination node's slot. The input
-- and output vectors are ignored.
un :: GM.MVector (G.Mutable v) a => Node -> Node -> (a -> a) -> RtOp v a
un (Node dst) (Node src) g = RtOp $ \scratch _ _ -> do
  val <- GM.read scratch src
  GM.write scratch dst (g val)
-- | Translate one 'AlgOp' into an executable 'RtOp'.
--
-- * 'InputOp' copies element @i@ of the input vector into a work slot.
-- * 'OutputOp' copies a work slot into element @i@ of the output vector.
-- * 'NormalOp' evaluates one expression node. Constants are written straight
--   into the node's work slot. Arithmetic and transcendental operations read
--   their operand slots through 'bin' or 'un'.
--
-- A symbolic ('GSym') node yields an 'error', since there is no value to
-- compute for it.
toRtOp :: G.Vector v a => AlgOp a -> RtOp v a
toRtOp (InputOp (Node k) (InputIdx i)) = RtOp $ \work input _ -> GM.write work k (input ! i)
toRtOp (OutputOp (Node k) (OutputIdx i)) = RtOp $ \work _ output ->
  GM.read work k >>= GM.write output i
toRtOp (NormalOp (Node k) (GConst c)) =
  RtOp $ \work _ _ -> GM.write work k c
toRtOp (NormalOp (Node k) (GNum (FromInteger x))) =
  RtOp $ \work _ _ -> GM.write work k (fromIntegral x)
toRtOp (NormalOp (Node k) (GFractional (FromRational x))) =
  RtOp $ \work _ _ -> GM.write work k (fromRational x)
toRtOp (NormalOp k (GNum (Mul x y))) = bin k x y (*)
toRtOp (NormalOp k (GNum (Add x y))) = bin k x y (+)
-- Subtraction: first operand minus second operand.
toRtOp (NormalOp k (GNum (Sub x y))) = bin k x y (-)
toRtOp (NormalOp k (GNum (Negate x))) = un k x negate
toRtOp (NormalOp k (GFractional (Div x y))) = bin k x y (/)
toRtOp (NormalOp k (GNum (Abs x))) = un k x abs
toRtOp (NormalOp k (GNum (Signum x))) = un k x signum
toRtOp (NormalOp k (GFloating (Pow x y))) = bin k x y (**)
toRtOp (NormalOp k (GFloating (LogBase x y))) = bin k x y logBase
toRtOp (NormalOp k (GFloating (Exp x))) = un k x exp
toRtOp (NormalOp k (GFloating (Log x))) = un k x log
toRtOp (NormalOp k (GFloating (Sin x))) = un k x sin
toRtOp (NormalOp k (GFloating (Cos x))) = un k x cos
toRtOp (NormalOp k (GFloating (Tan x))) = un k x tan
toRtOp (NormalOp k (GFloating (ASin x))) = un k x asin
toRtOp (NormalOp k (GFloating (ATan x))) = un k x atan
toRtOp (NormalOp k (GFloating (ACos x))) = un k x acos
toRtOp (NormalOp k (GFloating (Sinh x))) = un k x sinh
toRtOp (NormalOp k (GFloating (Cosh x))) = un k x cosh
toRtOp (NormalOp k (GFloating (Tanh x))) = un k x tanh
toRtOp (NormalOp k (GFloating (ASinh x))) = un k x asinh
toRtOp (NormalOp k (GFloating (ATanh x))) = un k x atanh
toRtOp (NormalOp k (GFloating (ACosh x))) = un k x acosh
toRtOp (NormalOp _ (GSym _)) = error "runAlg: there's symbol in my algorithm"