{-# LANGUAGE FlexibleInstances      #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE PatternSynonyms        #-}
{-# LANGUAGE QuasiQuotes            #-}
{-# LANGUAGE ViewPatterns           #-}

{- |
Module      : Language.Egison.Tensor
Licence     : MIT

This module contains functions for tensors.
-}

module Language.Egison.Tensor
    ( TensorComponent (..)
    -- * Tensor
    , tref
    , enumTensorIndices
    , tTranspose
    , tTranspose'
    , tFlipIndices
    , appendDF
    , removeDF
    , tMap
    , tMap2
    , tProduct
    , tContract
    , tContract'
    , tConcat'
    ) where

import           Prelude                    hiding (foldr, mappend, mconcat)

import           Control.Monad.Except       (mzero, throwError, zipWithM)
import           Data.List                  (delete, intersect, partition, (\\))
import qualified Data.Vector                as V

import           Control.Egison
import qualified Control.Egison             as M

import           Language.Egison.Data
import           Language.Egison.Data.Utils
import           Language.Egison.IExpr      (Index (..), extractSupOrSubIndex)
import           Language.Egison.Math
import           Language.Egison.RState


-- | Matcher for 'Index' values. It wraps the matcher used for the content of
-- an index; 'tTranspose'' uses it as @IndexM Eql@.
data IndexM m = IndexM m
-- | No methods: the patterns 'sub', 'sup' and 'supsub' (with their
-- @subM@ / @supM@ / @supsubM@ companions) define how 'IndexM' matches.
instance M.Matcher m a => M.Matcher (IndexM m) (Index a)

-- | Pattern for a subscript index: @Sub a@ succeeds once, passing the
-- content @a@ on to the inner pattern; any other index gives no match.
sub :: M.Matcher m a => M.Pattern (PP a) (IndexM m) (Index a) a
sub _ _ idx = case idx of
  Sub a -> pure a
  _     -> mzero

-- | Matcher for the content of a 'sub' pattern: the one wrapped by 'IndexM'.
subM :: M.Matcher m a => IndexM m -> Index a -> m
subM (IndexM inner) _ = inner

-- | Pattern for a superscript index: @Sup a@ succeeds once, passing the
-- content @a@ on to the inner pattern; any other index gives no match.
sup :: M.Matcher m a => M.Pattern (PP a) (IndexM m) (Index a) a
sup _ _ idx = case idx of
  Sup a -> pure a
  _     -> mzero

-- | Matcher for the content of a 'sup' pattern: the one wrapped by 'IndexM'.
supM :: M.Matcher m a => IndexM m -> Index a -> m
supM (IndexM inner) _ = inner

-- | Pattern for a sup-sub index: @SupSub a@ succeeds once, passing the
-- content @a@ on to the inner pattern; any other index gives no match.
supsub :: M.Matcher m a => M.Pattern (PP a) (IndexM m) (Index a) a
supsub _ _ idx = case idx of
  SupSub a -> pure a
  _        -> mzero

-- | Matcher for the content of a 'supsub' pattern: the one wrapped by 'IndexM'.
supsubM :: M.Matcher m a => IndexM m -> Index a -> m
supsubM (IndexM inner) _ = inner

--
-- Tensors
--

-- | Conversion between a value type @a@ and tensors whose components have
-- type @b@. The functional dependency @a -> b@ fixes the component type for
-- each value type (in this module: 'EgisonValue' ↔ 'Tensor' 'EgisonValue',
-- and 'WHNFData' ↔ 'Tensor' 'ObjectRef').
class TensorComponent a b | a -> b where
  -- | Turn a tensor back into a value. In the instances defined here, a
  -- 'Scalar' yields its single component.
  fromTensor :: Tensor b -> EvalM a
  -- | View a value as a tensor. In the instances defined here, a non-tensor
  -- value becomes a 'Scalar'.
  toTensor :: a -> EvalM (Tensor b)

-- | Egison values are their own tensor components. A 'TensorData' value is
-- unwrapped to its tensor. Any other value is viewed as a 'Scalar'.
instance TensorComponent EgisonValue EgisonValue where
  fromTensor (Scalar x) = return x
  fromTensor t          = return (TensorData t)

  toTensor v = case v of
    TensorData t -> return t
    _            -> return (Scalar v)

-- | WHNF data is tensorised over object references:
--
-- * going out, a 'Scalar' reference is evaluated with 'evalRef';
-- * going in, a non-tensor value is stored in a fresh evaluated reference.
instance TensorComponent WHNFData ObjectRef where
  fromTensor (Scalar ref) = evalRef ref
  fromTensor t            = return (ITensor t)

  toTensor whnf = case whnf of
    ITensor t -> return t
    _         -> fmap Scalar (newEvaluatedObjectRef whnf)

-- | Dimensions of a tensor, one entry per axis; a 'Scalar' has none.
tShape :: Tensor a -> Shape
tShape t = case t of
  Tensor dims _ _ -> dims
  Scalar _        -> []

-- | Flat component vector of a tensor; a 'Scalar' becomes a one-element vector.
tToVector :: Tensor a -> V.Vector a
tToVector t = case t of
  Tensor _ comps _ -> comps
  Scalar x         -> V.singleton x

-- | Index list attached to a tensor; a 'Scalar' carries none.
tIndex :: Tensor a -> [Index EgisonValue]
tIndex t = case t of
  Tensor _ _ idxs -> idxs
  Scalar _        -> []

-- | Select position @i@ (1-based) along the first axis of a tensor.
--
-- * For an order-1 tensor, the result is a 'Scalar' holding the @i@-th
--   component.
-- * For higher orders, the result is the sub-tensor with the first axis
--   fixed. Its index list is @cdr js@.
--
-- Errors:
--
-- * 'TensorIndexOutOfBounds' when @i@ is outside @[1, n]@.
-- * A 'Default' error when there is no axis left to index: a 'Scalar', or a
--   tensor with empty shape.
tIntRef' :: Integer -> Tensor a -> EvalM (Tensor a)
tIntRef' i (Tensor [n] xs _)
  | 0 < i && i <= n = return . Scalar $ xs V.! fromIntegral (i - 1)
  | otherwise       = throwErrorWithTrace (TensorIndexOutOfBounds i n)
tIntRef' i (Tensor (n:ns) xs js)
  | 0 < i && i <= n =
      -- Components are stored row-major, so slice i is the i-th
      -- contiguous block of (product ns) elements.
      let width  = fromIntegral (product ns)
          offset = width * fromIntegral (i - 1)
       in return $ Tensor ns (V.take width (V.drop offset xs)) (cdr js)
  | otherwise = throwErrorWithTrace (TensorIndexOutOfBounds i n)
tIntRef' _ _ = throwError $ Default "More indices than the order of the tensor"

-- | Apply a list of 1-based integer indices with 'tIntRef'', one axis at a
-- time, left to right.
--
-- Once the indices are used up:
--
-- * an order-0 tensor with exactly one component collapses to a 'Scalar';
-- * an order-0 tensor with any other number of components is reported as an
--   internal bug;
-- * any other tensor is returned as is.
tIntRef :: [Integer] -> Tensor a -> EvalM (Tensor a)
tIntRef (m:ms) t = tIntRef' m t >>= tIntRef ms
tIntRef [] t = case t of
  Tensor [] xs _
    | V.length xs == 1 -> return $ Scalar (xs V.! 0)
    | otherwise        -> throwErrorWithTrace (EgisonBug "sevaral elements in scalar tensor")
  _ -> return t

-- | Like 'tIntRef', but the indices must pin down a single component, which
-- is returned directly.
--
-- Anything other than a 'Scalar', or an order-0 tensor with exactly one
-- component, left after the indices are used up is reported as an internal
-- bug.
tIntRef1 :: [Integer] -> Tensor a -> EvalM a
tIntRef1 (m:ms) t = tIntRef' m t >>= tIntRef1 ms
tIntRef1 [] t = case t of
  Scalar x -> return x
  Tensor [] xs _ | V.length xs == 1 -> return (xs V.! 0)
  _ -> throwErrorWithTrace (EgisonBug "sevaral elements in scalar tensor")

-- | Match-only pattern: succeeds on any index for which
-- 'extractSupOrSubIndex' returns @Just@, and binds the content.
--
-- Judging by its name and by 'tref', these are presumably the 'Sub', 'Sup'
-- and 'SupSub' forms; confirm in "Language.Egison.IExpr".
pattern SupOrSubIndex :: a -> Index a
pattern SupOrSubIndex i <- (extractSupOrSubIndex -> Just i)

-- | Index into a tensor with a list of super/subscript indices, consuming one
-- axis per index from the left. The content of each index decides what
-- happens to the current axis:
--
-- * a single symbol keeps the axis: every slice along it is indexed with the
--   remaining indices, and the results are joined with 'tConcat' under the
--   same index;
-- * an integer @m@ selects slice @m@ (1-based, via 'tIntRef'');
-- * zero is rejected as out of bounds;
-- * a pair @(m, n)@ selects slices @m..n@ and joins them under a fresh
--   symbol of the same index kind ('Sub', 'Sup' or 'SupSub'); when @m > n@
--   the result has no components (every dimension 0);
-- * anything else is an error.
--
-- When no indices remain, an order-0 tensor with exactly one component is
-- collapsed to a 'Scalar'; any other tensor is returned as is.
tref :: [Index EgisonValue] -> Tensor a -> EvalM (Tensor a)
tref [] t = case t of
  Tensor [] xs _
    | V.length xs == 1 -> return $ Scalar (xs V.! 0)
    | otherwise        -> throwErrorWithTrace (EgisonBug "sevaral elements in scalar tensor")
  _ -> return t
tref (s@(SupOrSubIndex (ScalarData (SingleSymbol _))) : ms) (Tensor (_ : ns) xs js) = do
  -- 'split' presumably cuts xs into consecutive chunks of (product ns)
  -- components, one per position along the first axis.
  let slices = split (product ns) xs
  ts <- mapM (\ys -> tref ms (Tensor ns ys (cdr js))) slices
  tConcat s ts
tref (SupOrSubIndex (ScalarData (SingleTerm m [])) : ms) t =
  tIntRef' m t >>= tref ms
tref (SupOrSubIndex (ScalarData ZeroExpr) : _) _ =
  throwError $ Default "tensor index out of bounds: 0"
tref (s@(SupOrSubIndex (Tuple [mVal, nVal])) : ms) t@(Tensor is _ _) = do
  m <- fromEgison mVal
  n <- fromEgison nVal
  if m > n
    then return $ Tensor (replicate (length is) 0) V.empty []
    else do
      ts <- mapM (\i -> tIntRef' i t >>= tref ms) [m .. n]
      -- A fresh symbol names the axis formed by the selected range.
      symId <- fresh
      let index = symbolScalarData "" (":::" ++ symId)
      case s of
        Sub{}    -> tConcat (Sub index) ts
        Sup{}    -> tConcat (Sup index) ts
        SupSub{} -> tConcat (SupSub index) ts
tref (_ : _) _ =
  throwError $ Default "Tensor index must be an integer or a single symbol."

-- | Enumerate every 1-based multi-index of a tensor with the given shape, in
-- lexicographic (row-major) order.
--
-- >>> enumTensorIndices [2,2,2]
-- [[1,1,1],[1,1,2],[1,2,1],[1,2,2],[2,1,1],[2,1,2],[2,2,1],[2,2,2]]
--
-- The empty shape has exactly one (empty) multi-index.
enumTensorIndices :: Shape -> [[Integer]]
enumTensorIndices = mapM (\n -> [1 .. n])

-- | Reorder a shape (or a multi-index) according to a permutation of indices.
--
-- @transIndex is js ns@ pairs each index of @is@ with the matching entry of
-- @ns@, then reads those entries back in the order given by @js@:
--
-- > transIndex [a, b, c] [c, a, b] [2, 3, 4] = [4, 2, 3]
--
-- Fails with a 'Default' error if some index of @js@ does not occur in @is@.
transIndex :: [Index EgisonValue] -> [Index EgisonValue] -> Shape -> EvalM Shape
transIndex is js ns = traverse entryFor js
  where
    -- Built once and shared by every lookup.
    table = zip is ns
    entryFor j =
      maybe (throwError $ Default "cannot transpose becuase of the inconsitent symbolic tensor indices")
            return
            (lookup j table)

-- | Permute the axes of a tensor so that its index list becomes @is@.
--
-- The leading @length is@ indices of @t@ are matched against @is@ by
-- equality, via 'transIndex'. Both sides are padded with
-- @complementWithDF ns is@ so that the remaining axes stay in place
-- (presumably DF placeholder indices; confirm in
-- "Language.Egison.Data.Utils"). Each component of the result is fetched
-- from @t@ with 'tIntRef1'.
--
-- If @is@ names more indices than @t@ carries, @t@ is returned unchanged.
-- A 'Scalar' has no axes to permute and is likewise returned unchanged;
-- without this equation it matched no clause and crashed at runtime.
tTranspose :: [Index EgisonValue] -> Tensor a -> EvalM (Tensor a)
tTranspose _ t@(Scalar _) = return t
tTranspose is t@(Tensor ns _ js)
  | length is > length js = return t
  | otherwise = do
      let js' = take (length is) js
          ds  = complementWithDF ns is
      -- Result shape: for each index of (is ++ ds), the dimension of the
      -- axis of t that carries it.
      ns' <- transIndex (js' ++ ds) (is ++ ds) ns
      -- For every position of the result, the matching position in t.
      srcs <- mapM (transIndex (is ++ ds) (js' ++ ds)) (enumTensorIndices ns')
      xs' <- mapM (`tIntRef1` t) (V.fromList srcs)
      return $ Tensor ns' xs' is

-- | Transpose a tensor using bare index contents.
--
-- Each value in @is@ is resolved to the first index of @t@ (sub, sup or
-- sup-sub) whose content equals it, and the result is passed to
-- 'tTranspose'. If any value cannot be resolved, @t@ is returned unchanged.
-- A 'Scalar' carries no indices, so it is returned unchanged too; without
-- this equation it matched no clause and crashed at runtime.
tTranspose' :: [EgisonValue] -> Tensor a -> EvalM (Tensor a)
tTranspose' _ t@(Scalar _) = return t
tTranspose' is t@(Tensor _ _ js) =
  case mapM (`lookupIndex` js) is of
    Nothing  -> return t
    Just is' -> tTranspose is' t
 where
  -- First index in idxs whose content equals i (the name i is used by the
  -- value pattern #i in the quasi-quote below).
  lookupIndex :: EgisonValue -> [Index EgisonValue] -> Maybe (Index EgisonValue)
  lookupIndex i idxs =
    match dfs idxs (List (IndexM Eql))
      [ [mc| _ ++ ($j & (sub #i | sup #i | supsub #i)) : _ -> Just j |]
      , [mc| _ -> Nothing |]
      ]

-- | Apply 'reverseIndex' to every index of a tensor; shape and components
-- are untouched. Presumably this swaps superscripts and subscripts; confirm
-- against the definition of 'reverseIndex'.
--
-- A 'Scalar' carries no indices and is returned unchanged; without this
-- equation it matched no clause and crashed at runtime.
tFlipIndices :: Tensor a -> EvalM (Tensor a)
tFlipIndices (Tensor ns xs js) = return $ Tensor ns xs (map reverseIndex js)
tFlipIndices t@(Scalar _)      = return t

-- | Give every axis of a tensor value an index. The index list is padded
-- with @DF dfId 1, DF dfId 2, ...@, one per axis beyond those already
-- indexed.
--
-- Works on both 'ITensor' and @Value (TensorData ...)@. Any other value,
-- including a scalar, is returned unchanged.
appendDF :: Integer -> WHNFData -> WHNFData
appendDF dfId whnf = case whnf of
  ITensor t@Tensor{}            -> ITensor (padDF t)
  Value (TensorData t@Tensor{}) -> Value (TensorData (padDF t))
  _                             -> whnf
 where
  padDF :: Tensor b -> Tensor b
  padDF (Tensor s xs is) =
    let missing = fromIntegral (length s - length is) :: Integer
     in Tensor s xs (is ++ map (DF dfId) [1 .. missing])
  padDF t = t

-- | Forget the DF indices of a tensor value.
--
-- The axes carrying 'DF' indices are moved after all the other axes (via
-- 'tTranspose'). Only the non-DF indices are kept in the index list; the
-- DF axes themselves remain in the shape.
--
-- Works on both 'ITensor' and @Value (TensorData ...)@. Any other value is
-- returned unchanged.
removeDF :: WHNFData -> EvalM WHNFData
removeDF whnf = case whnf of
  ITensor t@Tensor{}            -> ITensor <$> dropDFs t
  Value (TensorData t@Tensor{}) -> Value . TensorData <$> dropDFs t
  _                             -> return whnf
 where
  dropDFs :: Tensor b -> EvalM (Tensor b)
  dropDFs t@(Tensor _ _ is) = do
    let (ds, js) = partition isDF is
    Tensor s ys _ <- tTranspose (js ++ ds) t
    return (Tensor s ys js)
  dropDFs t = return t

  isDF :: Index c -> Bool
  isDF DF{} = True
  isDF _    = False

-- | Apply a monadic function to every component of a tensor, in vector order.
--
-- A tensor result keeps its shape. Its index list is extended with
-- @complementWithDF ns js@, presumably DF indices for the axes that had
-- none; confirm in "Language.Egison.Data.Utils".
tMap :: (a -> EvalM b) -> Tensor a -> EvalM (Tensor b)
tMap f (Scalar x) = Scalar <$> f x
tMap f (Tensor ns xs js) = do
  ys <- V.mapM f xs
  let js' = js ++ complementWithDF ns js
  return (Tensor ns ys js')

-- | Combine two tensors element-wise with a monadic function.
--
-- When both operands are 'Tensor's:
--
-- * each index list is padded with 'DF' placeholders ('complementWithDF');
-- * indices present in both lists are moved to the front of each tensor
--   with 'tTranspose';
-- * both tensors are sliced at every position of those common dimensions;
-- * each pair of slices is combined with 'tProduct';
-- * the glued result is transposed to the order produced by 'tDiagIndex'
--   on both index lists, after dropping one duplicate of each index.
--
-- A 'Scalar' operand is combined with every element of the other operand
-- via 'tMap'.
tMap2 :: (a -> b -> EvalM c) -> Tensor a -> Tensor b -> EvalM (Tensor c)
tMap2 f (Tensor shape1 elems1 idxs1') (Tensor shape2 elems2 idxs2') = do
  t1 <- tTranspose (common ++ (idxs1 \\ common)) (Tensor shape1 elems1 idxs1)
  t2 <- tTranspose (common ++ (idxs2 \\ common)) (Tensor shape2 elems2 idxs2)
  let commonShape = take (length common) (tShape t1)
      positions   = enumTensorIndices commonShape
  slices1 <- mapM (`tIntRef` t1) positions
  slices2 <- mapM (`tIntRef` t2) positions
  pieces  <- zipWithM (tProduct f) slices1 slices2
  -- NOTE(review): 'head' fails if 'pieces' is empty, e.g. when
  -- 'enumTensorIndices' yields no positions for 'commonShape'.
  let firstPiece = head pieces
      combined   = Tensor (commonShape ++ tShape firstPiece)
                          (V.concat (map tToVector pieces))
                          (common ++ tIndex firstPiece)
  tTranspose (dropOneDuplicate (tDiagIndex (idxs1 ++ idxs2))) combined
 where
  idxs1  = idxs1' ++ complementWithDF shape1 idxs1'
  idxs2  = idxs2' ++ complementWithDF shape2 idxs2'
  common = idxs1 `intersect` idxs2

  -- Keep each index, then remove only the first later occurrence of it
  -- (not all of them, unlike 'nub').
  dropOneDuplicate :: [Index EgisonValue] -> [Index EgisonValue]
  dropOneDuplicate []       = []
  dropOneDuplicate (i : is) = i : dropOneDuplicate (delete i is)
tMap2 f t@Tensor{} (Scalar y) = tMap (`f` y) t
tMap2 f (Scalar x) t@Tensor{} = tMap (f x) t
tMap2 f (Scalar x) (Scalar y) = Scalar <$> f x y

-- | Take the diagonal over every pair of matching indices @Sup i@ / @Sub i@.
--
-- Each superscript index that has a subscript twin in the same index list
-- is collapsed, together with that twin, into one 'SupSub' index. The
-- tensor is first transposed so the superscripts come first, then their
-- twins, then the remaining indices. The entries along the new index are
-- the elements at positions @pos ++ pos@.
-- The tensor is returned unchanged if no such pair exists, or if it is a
-- 'Scalar'.
tDiag :: Tensor a -> EvalM (Tensor a)
tDiag t@(Tensor _ _ idxs) =
  case [ j | j <- idxs, any (hasSubTwin j) idxs ] of
    []   -> return t
    sups -> do
      let subs = map reverseIndex sups
          rest = idxs \\ (sups ++ subs)
          k    = length sups
      t' <- tTranspose (sups ++ subs ++ rest) t
      let (diagShape, afterDiag) = splitAt k (tShape t')
          restShape              = drop k afterDiag
      diagonal <- mapM (\pos -> tIntRef (pos ++ pos) t')
                       (enumTensorIndices diagShape)
      return $ Tensor (diagShape ++ restShape)
                      (V.concat (map tToVector diagonal))
                      (map toSupSub sups ++ rest)
 where
  -- True when the first index is a superscript and the second is the
  -- subscript with the same label.
  hasSubTwin :: Index EgisonValue -> Index EgisonValue -> Bool
  hasSubTwin (Sup i) (Sub j) = i == j
  hasSubTwin _       _       = False
tDiag t = return t

-- | Repeatedly merge a pair @Sup i@ / @Sub i@ (in either order) of an index
-- list into a single @SupSub i@, which is moved to the front. Merging stops
-- when no such pair remains; the list is then returned as-is.
tDiagIndex :: [Index EgisonValue] -> [Index EgisonValue]
tDiagIndex idxs =
  match dfs idxs (List (IndexM Eql))
    [ [mc| $front ++ sup $i : $middle ++ sub #i : $back ->
             tDiagIndex (SupSub i : front ++ middle ++ back) |]
    , [mc| $front ++ sub $i : $middle ++ sup #i : $back ->
             tDiagIndex (SupSub i : front ++ middle ++ back) |]
    , [mc| _ -> idxs |]
    ]

-- | Tensor product of two tensors, combining elements with a monadic function.
--
-- Both index lists are first padded with 'DF' placeholders
-- ('complementWithDF'). An index of the first operand is "shared" when the
-- second operand contains its opposite-variance twin (@Sup i@ vs @Sub i@,
-- in either direction).
--
-- * No shared indices: every element of the first tensor is combined with
--   every element of the second, giving a tensor over the concatenated
--   shapes and index lists. That tensor is then passed to @tContract'@.
-- * Otherwise: both tensors are transposed so that the shared indices come
--   first, and matching slices are multiplied recursively. Each shared pair
--   becomes a single 'SupSub' index of the result.
--
-- A 'Scalar' operand is combined with every element of the other operand
-- via 'tMap'.
tProduct :: (a -> b -> EvalM c) -> Tensor a -> Tensor b -> EvalM (Tensor c)
tProduct f (Tensor shape1 elems1 idxs1') (Tensor shape2 elems2 idxs2')
  | null shared1 = outerProduct
  | otherwise    = slicedProduct
 where
  idxs1 = idxs1' ++ complementWithDF shape1 idxs1'
  idxs2 = idxs2' ++ complementWithDF shape2 idxs2'
  t1    = Tensor shape1 elems1 idxs1
  t2    = Tensor shape2 elems2 idxs2

  -- Shared indices as they appear in each operand, plus the leftovers.
  shared1 = filter (\j -> any (isTwin j) idxs2) idxs1
  shared2 = map reverseIndex shared1
  free1   = idxs1 \\ shared1
  free2   = idxs2 \\ shared2

  -- No shared indices: pair every element of t1 with every element of t2.
  outerProduct = do
    let rank1 = length shape1
        rank2 = length shape2
    elems <- mapM (\pos -> do x1 <- tIntRef1 (take rank1 pos) t1
                              x2 <- tIntRef1 (take rank2 (drop rank1 pos)) t2
                              f x1 x2)
                  (enumTensorIndices (shape1 ++ shape2))
    tContract' (Tensor (shape1 ++ shape2) (V.fromList elems) (idxs1 ++ idxs2))

  -- Shared indices: bring them to the front, multiply matching slices,
  -- and glue the slice products back together.
  slicedProduct = do
    t1' <- tTranspose (shared1 ++ free1) t1
    t2' <- tTranspose (shared2 ++ free2) t2
    let sharedShape = take (length shared1) (tShape t1')
        sharedIdxs  = map toSupSub shared1
    pieces <- mapM (\pos -> do s1 <- tIntRef pos t1'
                               s2 <- tIntRef pos t2'
                               tProduct f s1 s2)
                   (enumTensorIndices sharedShape)
    -- NOTE(review): 'head' fails if 'pieces' is empty, e.g. when
    -- 'enumTensorIndices' yields no positions for 'sharedShape'.
    let firstPiece = head pieces
        combined   = Tensor (sharedShape ++ tShape firstPiece)
                            (V.concat (map tToVector pieces))
                            (sharedIdxs ++ tIndex firstPiece)
    tTranspose (dropOneDuplicate (sharedIdxs ++ free1 ++ free2)) combined

  -- Opposite-variance indices with the same label.
  isTwin :: Index EgisonValue -> Index EgisonValue -> Bool
  isTwin (Sup i) (Sub j) = i == j
  isTwin (Sub i) (Sup j) = i == j
  isTwin _       _       = False

  -- Keep each index, then remove only the first later occurrence of it.
  dropOneDuplicate :: [Index EgisonValue] -> [Index EgisonValue]
  dropOneDuplicate []       = []
  dropOneDuplicate (i : is) = i : dropOneDuplicate (delete i is)
tProduct f (Scalar x) t@Tensor{} = tMap (f x) t
tProduct f t@Tensor{} (Scalar y) = tMap (`f` y) t
tProduct f (Scalar x) (Scalar y) = Scalar <$> f x y

-- | Split a tensor into sub-tensors along its leading 'SupSub' indices.
--
-- The tensor is first diagonalised with 'tDiag'. If the result has a
-- leading 'SupSub' index of extent @n@, the slices @1..n@ along it are taken
-- with 'tIntRef''. Each slice is processed recursively and the results are
-- concatenated. Otherwise the diagonalised tensor is returned as a
-- singleton list.
tContract :: Tensor a -> EvalM [Tensor a]
tContract t = tDiag t >>= go
 where
  go t'@(Tensor (n : _) _ (SupSub _ : _)) = do
    slices <- mapM (`tIntRef'` t') [1 .. n]
    concat <$> mapM tContract slices
  go t' = return [t']

-- | Repeatedly take the diagonal over two occurrences of the same index.
--
-- "Same" means two equal 'Sup' indices, two equal 'Sub' indices, or two
-- identical 'DF' indices. For such a pair, the components with both
-- positions fixed to @Sub i@ (for @i@ in @1..n@, with @n@ the extent of the
-- first position) are gathered with 'tref'. They are stacked under the first
-- index with 'tConcat', transposed back into place, and processed again.
-- Tensors without such a pair, and scalars, are returned unchanged.
tContract' :: Tensor a -> EvalM (Tensor a)
tContract' t@(Tensor shape _ idxs) =
  match dfs idxs (List M.Something)
    [ [mc| $front ++ $a : $middle ++ ?(sameIndex a) : $back -> do
             let axis = fromIntegral (length front)
             diagonal <- mapM (\i -> tref (front ++ (subAt i : middle)
                                                 ++ (subAt i : back)) t)
                              [1 .. (shape !! axis)]
             tConcat a diagonal >>= tTranspose (front ++ a : middle ++ back) >>= tContract' |]
    , [mc| _ -> return t |]
    ]
 where
  -- Concrete subscript index selecting component @i@.
  subAt i = Sub (ScalarData (SingleTerm i []))

  sameIndex :: Index EgisonValue -> Index EgisonValue -> Bool
  sameIndex (Sup i)    (Sup j)    = i == j
  sameIndex (Sub i)    (Sub j)    = i == j
  sameIndex (DF i1 j1) (DF i2 j2) = i1 == i2 && j1 == j2
  sameIndex _          _          = False
tContract' other = return other

-- | Stack a list of tensors under a new leading index @s@.
--
-- * First element is a 'Tensor' whose shape starts with 0: the result is an
--   empty tensor of shape @0 : shape@, indexed by @s@ followed by that
--   tensor's indices.
-- * First element is any other 'Tensor': the element vectors are
--   concatenated. The result has shape @length ts : shape@ and indices
--   @s : idxs@, taken from the first tensor.
-- * Otherwise (scalars, or an empty list): every element must be a
--   'Scalar', and the result is a rank-1 tensor of those values indexed
--   by @[s]@.
--
-- NOTE(review): in the first case the leading extent is 0, not
-- @length ts@ — confirm this is intended.
tConcat :: Index EgisonValue -> [Tensor a] -> EvalM (Tensor a)
tConcat s ts = case ts of
  Tensor shape@(0 : _) _ idxs : _ ->
    return $ Tensor (0 : shape) V.empty (s : idxs)
  Tensor shape _ idxs : _ ->
    return $ Tensor (fromIntegral (length ts) : shape)
                    (V.concat (map tToVector ts))
                    (s : idxs)
  _ -> do
    scalars <- mapM getScalar ts
    return $ Tensor [fromIntegral (length ts)] (V.fromList scalars) [s]

-- | Stack a list of tensors along a new leading dimension, producing a
-- tensor with an empty index list.
--
-- * First element is a 'Tensor' whose shape starts with 0: the result is an
--   empty tensor of shape @0 : shape@.
-- * First element is any other 'Tensor': the element vectors are
--   concatenated, and the result has shape @length ts : shape@.
-- * Otherwise (scalars, or an empty list): every element must be a
--   'Scalar', and the result is a rank-1 tensor of those values.
tConcat' :: [Tensor a] -> EvalM (Tensor a)
tConcat' ts = case ts of
  Tensor shape@(0 : _) _ _ : _ ->
    return $ Tensor (0 : shape) V.empty []
  Tensor shape _ _ : _ ->
    return $ Tensor (fromIntegral (length ts) : shape)
                    (V.concat (map tToVector ts))
                    []
  _ -> do
    scalars <- mapM getScalar ts
    return $ Tensor [fromIntegral (length ts)] (V.fromList scalars) []

-- utility functions for tensors

-- | Total version of 'tail': drops the first element, and maps @[]@ to @[]@.
cdr :: [a] -> [a]
cdr = drop 1

-- | Cut a vector into consecutive chunks of @w@ elements; the last chunk
-- holds whatever remains. An empty vector yields no chunks.
--
-- NOTE: for @w <= 0@ and a non-empty vector the result is an infinite list
-- of empty chunks, so callers must pass a positive width.
split :: Integer -> V.Vector a -> [V.Vector a]
split w = go
 where
  width = fromIntegral w
  go v
    | V.null v  = []
    | otherwise = V.take width v : go (V.drop width v)

-- | Extract the value held by a 'Scalar'.
--
-- Throws a 'Default' error ("Inconsistent Tensor order") when given a
-- 'Tensor' instead of a 'Scalar'.
getScalar :: Tensor a -> EvalM a
getScalar (Scalar x) = return x
getScalar _          = throwError $ Default "Inconsistent Tensor order"

-- | Flip the variance of an index: a superscript becomes a subscript and
-- vice versa. Every other kind of index is returned unchanged.
reverseIndex :: Index a -> Index a
reverseIndex idx = case idx of
  Sup i -> Sub i
  Sub i -> Sup i
  _     -> idx

-- | Convert a superscript or subscript index into the combined 'SupSub'
-- form. This module uses that form for an index produced by merging a
-- @Sup i@ / @Sub i@ pair.
--
-- Any other kind of index (including an existing 'SupSub') is returned
-- unchanged, which matches how 'reverseIndex' treats those cases and keeps
-- the function total.
toSupSub :: Index a -> Index a
toSupSub (Sup i) = SupSub i
toSupSub (Sub i) = SupSub i
toSupSub idx     = idx

-- | Placeholder indices for the dimensions of @ns@ not covered by @js'@.
--
-- Returns @[DF 0 1, DF 0 2, ..., DF 0 k]@ with
-- @k = length ns - length js'@, or @[]@ when @js'@ is at least as long as
-- @ns@.
complementWithDF :: Shape -> [Index a] -> [Index a]
complementWithDF ns js' = [ DF 0 n | n <- [1 .. missing] ]
 where
  missing :: Integer
  missing = fromIntegral (length ns - length js')