{-# OPTIONS_GHC -Wall #-}
{-# Language Rank2Types #-}
{-# Language FlexibleContexts #-}

-- | Numeric evaluation of an 'Algorithm'.
--
-- Each 'AlgOp' is translated into an 'RtOp': an 'ST' action that reads from
-- and writes to a scratch work vector, the immutable input vector, and the
-- output vector. The translated ops are then executed in list order.
module Dvda.Algorithm.Eval
       ( runAlgorithm
       , runAlgorithm'
       ) where

import Control.Monad.ST ( ST, runST )
import Data.Vector.Generic ( (!) )
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Generic.Mutable as GM

import Dvda.Expr
import Dvda.Algorithm.Construct ( Algorithm(..), AlgOp(..), InputIdx(..), OutputIdx(..) )
import Dvda.Algorithm.FunGraph ( Node(..) )

-- | One runtime instruction. Its arguments are, in order: the mutable work
-- vector, the immutable input vector, and the mutable output vector.
newtype RtOp v a =
  RtOp (forall s. (G.Mutable v) s a -> v a -> (G.Mutable v) s a -> ST s ())

-- | Purely run an algorithm on an input vector.
--
-- Returns 'Left' when the input length differs from 'algInDims', or when the
-- algorithm contains a symbolic ('GSym') node, which cannot be evaluated
-- numerically. Output slots that no 'OutputOp' writes are left as allocated
-- by 'GM.new'.
runAlgorithm :: G.Vector v a => Algorithm a -> v a -> Either String (v a)
runAlgorithm alg =
  runAlg'' (algInDims alg) (algOutDims alg) (algWorkSize alg) (compileOps alg)
  where
    -- The compiled ops are an argument (not recomputed per input) so that a
    -- partial application @runAlgorithm alg@ can share them across calls.
    runAlg'' :: G.Vector v a
             => Int -> Int -> Int -> Either String [RtOp v a] -> v a
             -> Either String (v a)
    runAlg'' inSize outSize workSize compiled inputVec
      | G.length inputVec /= inSize =
          Left $ "runAlg: input dimension mismatch, given: "
                 ++ show (G.length inputVec) ++ ", expected: " ++ show inSize
      | otherwise = case compiled of
          Left err -> Left ("runAlg: " ++ err)
          Right ops -> Right $ runST $ do
            workVec <- GM.new workSize
            outputVec <- GM.new outSize
            runOps ops workVec inputVec outputVec
            G.freeze outputVec

-- | Run an algorithm in the 'ST' monad, writing results into a
-- caller-provided mutable output vector.
--
-- Returns @Just msg@ without touching the output vector when the input or
-- output length is wrong, or when the algorithm contains a symbolic ('GSym')
-- node. Otherwise every op is executed and 'Nothing' is returned; only the
-- output slots targeted by an 'OutputOp' are written.
runAlgorithm' :: G.Vector v a => Algorithm a -> v a -> G.Mutable v s a -> ST s (Maybe String)
runAlgorithm' alg =
  runAlg'' (algInDims alg) (algOutDims alg) (algWorkSize alg) (compileOps alg)
  where
    runAlg'' :: G.Vector v a
             => Int -> Int -> Int -> Either String [RtOp v a] -> v a
             -> G.Mutable v s a -> ST s (Maybe String)
    runAlg'' inSize outSize workSize compiled inputVec outputVec
      | G.length inputVec /= inSize =
          return $ Just $ "runAlg': input dimension mismatch, given: "
                          ++ show (G.length inputVec) ++ ", expected: " ++ show inSize
      | GM.length outputVec /= outSize =
          return $ Just $ "runAlg': output dimension mismatch, given: "
                          ++ show (GM.length outputVec) ++ ", expected: " ++ show outSize
      | otherwise = case compiled of
          Left err -> return $ Just ("runAlg': " ++ err)
          Right ops -> do
            workVec <- GM.new workSize
            runOps ops workVec inputVec outputVec
            return Nothing

-- | Translate every op of an algorithm, stopping at the first op that cannot
-- be evaluated numerically.
compileOps :: G.Vector v a => Algorithm a -> Either String [RtOp v a]
compileOps = mapM toRtOp . algOps

-- | Execute ops in list order against the given work, input and output vectors.
runOps :: [RtOp v a] -> G.Mutable v s a -> v a -> G.Mutable v s a -> ST s ()
runOps ops work input output = mapM_ (\(RtOp op) -> op work input output) ops

-- | Binary op: read work slots @kx@ and @ky@, write @f x y@ into work slot @k@.
bin :: GM.MVector (G.Mutable v) a => Node -> Node -> Node -> (a -> a -> a) -> RtOp v a
bin (Node k) (Node kx) (Node ky) f = RtOp $ \work _ _ -> do
  x <- GM.read work kx
  y <- GM.read work ky
  GM.write work k (f x y)

-- | Unary op: read work slot @kx@, write @f x@ into work slot @k@.
un :: GM.MVector (G.Mutable v) a => Node -> Node -> (a -> a) -> RtOp v a
un (Node k) (Node kx) f = RtOp $ \work _ _ -> GM.read work kx >>= GM.write work k . f

-- | Translate one 'AlgOp' into a runtime instruction.
--
-- * 'InputOp' copies input element @i@ into work slot @k@.
-- * 'OutputOp' copies work slot @k@ into output element @i@.
-- * 'NormalOp' evaluates a constant or a unary/binary numeric primitive
--   into work slot @k@.
--
-- A symbolic ('GSym') node is reported as 'Left' rather than crashing at
-- evaluation time.
toRtOp :: G.Vector v a => AlgOp a -> Either String (RtOp v a)
toRtOp (InputOp (Node k) (InputIdx i)) =
  Right $ RtOp $ \work input _ -> GM.write work k (input ! i)
toRtOp (OutputOp (Node k) (OutputIdx i)) =
  Right $ RtOp $ \work _ output -> GM.read work k >>= GM.write output i
toRtOp (NormalOp (Node k) (GConst c)) =
  Right $ RtOp $ \work _ _ -> GM.write work k c
toRtOp (NormalOp (Node k) (GNum (FromInteger x))) =
  Right $ RtOp $ \work _ _ -> GM.write work k (fromIntegral x)
toRtOp (NormalOp (Node k) (GFractional (FromRational x))) =
  Right $ RtOp $ \work _ _ -> GM.write work k (fromRational x)
toRtOp (NormalOp k (GNum (Mul x y))) = Right $ bin k x y (*)
toRtOp (NormalOp k (GNum (Add x y))) = Right $ bin k x y (+)
toRtOp (NormalOp k (GNum (Sub x y))) = Right $ bin k x y (-)
toRtOp (NormalOp k (GNum (Negate x))) = Right $ un k x negate
toRtOp (NormalOp k (GFractional (Div x y))) = Right $ bin k x y (/)
toRtOp (NormalOp k (GNum (Abs x))) = Right $ un k x abs
toRtOp (NormalOp k (GNum (Signum x))) = Right $ un k x signum
toRtOp (NormalOp k (GFloating (Pow x y))) = Right $ bin k x y (**)
toRtOp (NormalOp k (GFloating (LogBase x y))) = Right $ bin k x y logBase
toRtOp (NormalOp k (GFloating (Exp x))) = Right $ un k x exp
toRtOp (NormalOp k (GFloating (Log x))) = Right $ un k x log
toRtOp (NormalOp k (GFloating (Sin x))) = Right $ un k x sin
toRtOp (NormalOp k (GFloating (Cos x))) = Right $ un k x cos
toRtOp (NormalOp k (GFloating (Tan x))) = Right $ un k x tan
toRtOp (NormalOp k (GFloating (ASin x))) = Right $ un k x asin
toRtOp (NormalOp k (GFloating (ATan x))) = Right $ un k x atan
toRtOp (NormalOp k (GFloating (ACos x))) = Right $ un k x acos
toRtOp (NormalOp k (GFloating (Sinh x))) = Right $ un k x sinh
toRtOp (NormalOp k (GFloating (Cosh x))) = Right $ un k x cosh
toRtOp (NormalOp k (GFloating (Tanh x))) = Right $ un k x tanh
toRtOp (NormalOp k (GFloating (ASinh x))) = Right $ un k x asinh
toRtOp (NormalOp k (GFloating (ATanh x))) = Right $ un k x atanh
toRtOp (NormalOp k (GFloating (ACosh x))) = Right $ un k x acosh
toRtOp (NormalOp _ (GSym _)) = Left "there's symbol in my algorithm"