module Data.Tensor.Compiler (
contract
, Tensor(..), Sh(..), Dd(..), Sd(..)
, Phoas, eval, var, let_, let2_
, CException (..)
)where
import Data.Typeable
import "exceptions" Control.Monad.Catch (MonadThrow(..), throwM, MonadCatch(..), catch)
import Control.Exception (Exception(..))
import Control.Applicative (liftA2, (<|>))
import Data.Tensor (Tensor(..), Sh(..), Dd(..), Sd(..), tshape, tdata, nnz, rank, dim)
import Data.Tensor.Compiler.PHOAS (Phoas(..), let_, let2_, var, lift1, lift2, eval)
-- | Wrap a tensor with 'var' after checking an index list against its rank.
--
-- Every entry of @ixs0@ is validated by 'mkIxs' against @rank t@; the first
-- invalid index aborts the computation through 'throwM'. The validated
-- indices are not needed afterwards, so only the check's effect is kept.
mkVar :: MonadThrow m => [Int] -> Tensor i a -> m (Phoas (Tensor i a))
mkVar ixs0 t = var t <$ mkIxs ixs0 (rank t)
-- | Validate mode indices against a rank @mm@.
--
-- Every index must lie in the half-open range @[0, mm)@. Indices are checked
-- left to right and the first offending one aborts with an 'IncompatIx'
-- raised via 'throwM'. On success the indices are returned unchanged and in
-- their original order (the previous accumulator loop reversed them).
mkIxs :: MonadThrow m => [Int] -> Int -> m [Int]
mkIxs ixs mm = traverse check ixs
  where
    check i
      | i < 0 =
          throwM $ IncompatIx "Index must be non-negative"
      -- valid indices are 0 .. mm - 1, so anything >= mm is out of range
      | i >= mm =
          throwM $ IncompatIx $ unwords ["Index must be smaller than", show mm]
      | otherwise = pure i
-- | Build a contraction program over two tensors.
--
-- The index list is validated with 'mkIxs' against the rank of @t1@ first and
-- then against the rank of @t2@; a failing check is reported via 'throwM'.
-- If both checks pass, the indices (as returned by the second check) are
-- wrapped with 'var' and bound with 'let_', both tensors are wrapped with
-- 'var' and bound with 'let2_', and the bound values are handed to @f@.
contract :: MonadThrow m =>
     [Int]
  -> Tensor i a
  -> Tensor i b
  -> ([Int] -> Tensor i a -> Tensor i b -> Phoas c)
  -> m (Phoas c)
contract ixs0 t1 t2 f =
  mkIxs ixs0 (rank t1) *> fmap program (mkIxs ixs0 (rank t2))
  where
    -- assemble the PHOAS term once the indices have been validated
    program checked = let_ (var checked) $ \boundIxs ->
      let2_ (var t1) (var t2) (f boundIxs)
-- | Errors reported while validating the arguments of a tensor program.
data CException
  = IncompatShape String -- ^ shape-related error carrying a free-form message
  | IncompatIx String    -- ^ invalid index error carrying a free-form message
  deriving (Eq, Typeable)

-- | Render the error as a category prefix followed by its message.
instance Show CException where
  show (IncompatShape str) = unwords ["Incompatible shape:", str]
  show (IncompatIx str) = unwords ["Incompatible index:", str]
instance Exception CException where