{-# OPTIONS_GHC -fno-warn-incomplete-uni-patterns #-}
{-# LANGUAGE BangPatterns         #-}
{-# LANGUAGE ConstraintKinds      #-}
{-# LANGUAGE UndecidableInstances #-}

module Data.Tensor.Tensor where

import           Control.DeepSeq
import           Data.List        (intercalate)
import           Data.Proxy
import           Data.Tensor.Type
import           Data.Type.Bool   hiding (If)
import qualified Data.Vector      as V
import           GHC.Exts         (IsList (..))
import           GHC.TypeLits

-----------------------
-- Tensor
-----------------------

-- | A <https://en.wikipedia.org/wiki/Tensor tensor> whose shape is carried
-- at the type level by the dimension list @s@.
--
-- Elements are produced on demand by 'getValue', which is handed the
-- tensor's concrete shape together with an index into it.
--
-- > identity :: Tensor '[3,3] Int
newtype Tensor (s :: [Nat]) n = Tensor
  { getValue :: Shape -> Index -> n
    -- ^ Element lookup: shape of the tensor, then the index to read.
  }

-- | <https://en.wikipedia.org/wiki/Scalar_(mathematics) Scalar> is a rank 0 tensor.
type Scalar n  = Tensor '[] n

-- | <https://en.wikipedia.org/wiki/Vector_(mathematics_and_physics) Vector> is a rank 1 tensor.
type Vector s n = Tensor '[s] n

-- | <https://en.wikipedia.org/wiki/Matrix_(mathematics) Matrix> is a rank 2 tensor.
type Matrix a b n = Tensor '[a,b] n

-- | A simple tensor has rank @r@ and every dimension equal to @dim@, so it
-- holds @dim^r@ elements in total.
--
-- > SimpleTensor 2 3 Int == Matrix 3 3 Int == Tensor '[3,3] Int
-- > SimpleTensor 0 dim Int == Scalar Int
type SimpleTensor (r :: Nat) (dim :: Nat) n = Tensor (Replicate r dim) n

-- | Mapping post-composes the function onto every element lookup; nothing
-- is evaluated until an element is demanded.
instance Functor (Tensor s) where
  fmap f (Tensor look) = Tensor (\s i -> f (look s i))

-- | 'pure' gives a constant tensor (shape and index are ignored); '<*>'
-- applies functions to arguments at the same shape and index.
instance Applicative (Tensor s) where
  pure = Tensor . const . const
  Tensor ff <*> Tensor fx = Tensor (\s i -> ff s i (fx s i))

-- | Two tensors of the same shape are equal when every pair of
-- corresponding elements compares equal.
instance (HasShape s, Eq n) => Eq (Tensor s n) where
  a == b = foldr (&&) True (fmap (==) a <*> b)

-- | Elements are visited by linear position @0 .. toSize - 1@; each
-- position is turned into an element through 'gx'.
-- (Assumes 'toSize' is the total element count of @s@ — TODO confirm.)
instance HasShape s => Foldable (Tensor s) where
  foldr step z t = foldr visit z [0 .. total - 1]
    where
      dims  = shape t
      total = toSize (Proxy :: Proxy s)
      visit k acc = step (gx t dims k) acc

-- | Fully evaluate every element of the tensor.
--
-- Elements are walked in 'Foldable' order and each one is forced to normal
-- form before moving on. The unit accumulator carries no information, so
-- every element must be forced explicitly; otherwise 'deepseq' / 'force'
-- on a tensor would leave its contents unevaluated.
instance (HasShape s, NFData a) => NFData (Tensor s a) where
  rnf = foldr (\x rest -> rnf x `seq` rest) ()

-- | Render a tensor as nested bracketed lists, one bracket level per
-- dimension. Innermost elements are separated by a comma; outer levels by a
-- comma followed by a newline. A dimension longer than 9 is abbreviated to
-- its first 8 entries, @..@, and its last entry.
instance (HasShape s, Show n) => Show (Tensor s n) where
  show (Tensor f) = render [] dims (f dims)
    where
      dims = unShape (toShape :: SShape s)

      -- acc: coordinates chosen so far, innermost first (hence the reverse);
      -- second argument: dimensions still to expand.
      render :: [Int] -> [Int] -> (Index -> n) -> String
      render acc []       look = show (look (reverse acc))
      render acc [d]      look =
        bracket d "," [show (look (reverse (k : acc))) | k <- [0 .. d - 1]]
      render acc (d : ds) look =
        bracket d ",\n" [render (k : acc) ds look | k <- [0 .. d - 1]]

      {-# INLINE bracket #-}
      bracket :: Int -> String -> [String] -> String
      bracket d sep parts = "[" ++ intercalate sep (abbreviate d parts) ++ "]"

      {-# INLINE abbreviate #-}
      abbreviate :: Int -> [String] -> [String]
      abbreviate d parts
        | d > 9     = take 8 parts ++ ["..", last parts]
        | otherwise = parts

-----------------------
-- Tensor as Num
-----------------------
-- | Element-wise arithmetic: binary operators combine corresponding elements
-- through 'zipWithTensor', unary ones map over every element, and
-- 'fromInteger' yields a constant tensor.
instance (HasShape s, Num n) => Num (Tensor s n) where
  a + b         = zipWithTensor (+) a b
  a * b         = zipWithTensor (*) a b
  abs t         = abs <$> t
  signum t      = signum <$> t
  negate t      = negate <$> t
  fromInteger k = pure (fromInteger k)

-- | Element-wise division; 'fromRational' yields a constant tensor.
instance (HasShape s, Fractional n) => Fractional (Tensor s n) where
  a / b          = zipWithTensor (/) a b
  fromRational q = pure (fromRational q)

-- | Element-wise floating-point functions; 'pi' is a constant tensor and
-- 'logBase' pairs up corresponding elements of its two arguments.
instance (HasShape s, Floating n) => Floating (Tensor s n) where
  pi          = pure pi
  logBase a b = fmap logBase a <*> b
  exp   t = exp   <$> t
  log   t = log   <$> t
  sqrt  t = sqrt  <$> t
  sin   t = sin   <$> t
  cos   t = cos   <$> t
  tan   t = tan   <$> t
  asin  t = asin  <$> t
  acos  t = acos  <$> t
  atan  t = atan  <$> t
  sinh  t = sinh  <$> t
  cosh  t = cosh  <$> t
  tanh  t = tanh  <$> t
  asinh t = asinh <$> t
  acosh t = acosh <$> t
  atanh t = atanh <$> t


{-# INLINE generateTensor #-}
-- | Build a tensor from a function of the element index.
--
-- When 'toSize' reports 0 for @s@ the result is the constant tensor
-- @pure (fn [])@; otherwise @fn@ is used directly and the shape argument
-- passed to 'getValue' is ignored.
generateTensor :: forall s n. HasShape s => (Index -> n) -> Tensor s n
generateTensor fn
  | toSize (Proxy :: Proxy s) == 0 = pure (fn [])
  | otherwise                      = Tensor (const fn)

{-# INLINE transformTensor #-}
-- | Re-index a tensor into a new type-level shape @s'@.
--
-- @go@ receives the source shape (from @s@) and a pair of the target shape
-- and a target index; it must return the matching source index, which is
-- then looked up in the original tensor.
transformTensor
  :: forall s s' n. HasShape s
  => (Shape -> (Shape, Index) -> Index)
  -> Tensor s  n
  -> Tensor s' n
transformTensor go (Tensor fo) = Tensor (\target ix -> lookupSrc (toSrc (target, ix)))
  where
    src       = unShape (toShape :: SShape s)
    lookupSrc = fo src
    toSrc     = go src

-- | Clone tensor to a new `V.Vector` based tensor.
--
-- Elements are laid out in linear order (via 'gx') in a boxed vector, so
-- each one is computed at most once. The returned tensor ignores the shape
-- argument it is given and looks elements up with 'tiTovi' against the
-- original shape.
clone :: HasShape s => Tensor s n -> Tensor s n
clone t = Tensor (\_ ix -> cache V.! tiTovi dims ix)
  where
    dims  = shape t
    cache = V.generate (product dims) (gx t dims)

{-# INLINE zipWithTensor #-}
-- | Combine two tensors of the same shape element-wise with @f@.
zipWithTensor :: HasShape s => (n -> n -> n) -> Tensor s n -> Tensor s n -> Tensor s n
zipWithTensor f t1 t2 = generateTensor pointwise
  where
    lhsDims = shape t1
    rhsDims = shape t2
    pointwise ix = f (getValue t1 lhsDims ix) (getValue t2 rhsDims ix)

-- | Convert between a tensor and its elements in linear order.
-- 'fromList' calls 'error' when the list length differs from the product of
-- the tensor's dimensions.
instance HasShape s => IsList (Tensor s n) where
  type Item (Tensor s n) = n
  fromList xs
    | product dims /= length xs = error "length not match"
    | otherwise                 = Tensor (\s' i -> vec V.! tiTovi s' i)
    where
      dims = unShape (toShape :: SShape s)
      vec  = V.fromList xs
  toList t = map (gx t dims) [0 .. pred (product dims)]
    where
      dims = unShape (toShape :: SShape s)

-----------------------
-- Tensor Shape
-----------------------
-- | Shape of a tensor: its list of dimensions, read from the type @s@.
-- The tensor value itself is never inspected.
shape :: forall s n. HasShape s => Tensor s n -> [Int]
shape = const (unShape (toShape :: SShape s))

-- | Rank of a tensor, read from the type @s@ via 'toRank'; the tensor value
-- itself is never inspected.
rank :: forall s n. HasShape s => Tensor s n -> Int
rank = const (toRank (Proxy :: Proxy s))

-----------------------
-- Tensor Operation
-----------------------
-- | Get the value stored at a 'TensorIndex'.
(!) :: HasShape s => Tensor s n -> TensorIndex s -> n
t ! TensorIndex ix = getValue t (shape t) ix

-- | Look up an element by linear position: the position is expanded to a
-- tensor index with 'viToti' against the given shape.
gx :: HasShape s => Tensor s n -> Shape -> Int -> n
gx (Tensor look) dims = look dims . viToti dims

-- | Reshape a tensor to another shape with the same total number of elements.
--
-- A target index is flattened with 'tiTovi' against the target shape and
-- re-expanded with 'viToti' against the source shape.
reshape :: (TensorSize s ~ TensorSize s', HasShape s) => Tensor s n -> Tensor s' n
reshape = transformTensor (\src (dst, ix) -> viToti src (tiTovi dst ix))

-- | Shape of a fully transposed tensor: the dimensions of @a@ in reverse order.
type Transpose (a :: [Nat]) = Reverse a '[]

-- | <https://en.wikipedia.org/wiki/Transpose Transpose> a tensor completely,
-- reversing the order of all of its axes.
--
-- > λ> a = [1..9] :: Tensor '[3,3] Int
-- > λ> a
-- > [[1,2,3],
-- > [4,5,6],
-- > [7,8,9]]
-- > λ> transpose a
-- > [[1,4,7],
-- > [2,5,8],
-- > [3,6,9]]
transpose :: HasShape a => Tensor a n -> Tensor (Transpose a) n
transpose = transformTensor (\_ (_, ix) -> reverse ix)

-- | Shape after swapping axes @i@ and @j@ (assumes @i < j@, as the index
-- arithmetic in 'swapaxes' does): the dimension at @j@ moves to position
-- @i@, the dimension at @i@ moves to position @j@, and the axes strictly
-- between them (and after @j@) keep their places.
--
-- > Swapaxes 0 1 '[2,3,4] == '[3,2,4]
type Swapaxes i j s = Take i s ++ (Dimension s j : Tail (Drop i (Take j s))) ++ (Dimension s i : Tail (Drop j s))

-- | Swap two axes of a tensor of any rank.
--
-- > λ> a = [1..24] :: Tensor '[2,3,4] Int
-- > λ> a
-- > [[[1,2,3,4],
-- > [5,6,7,8],
-- > [9,10,11,12]],
-- > [[13,14,15,16],
-- > [17,18,19,20],
-- > [21,22,23,24]]]
-- > λ> swapaxes i0 i1 a
-- > [[[1,2,3,4],
-- > [13,14,15,16]],
-- > [[5,6,7,8],
-- > [17,18,19,20]],
-- > [[9,10,11,12],
-- > [21,22,23,24]]]
-- > λ> :t swapaxes i0 i1 a
-- > swapaxes i0 i1 a :: Tensor '[3, 2, 4] Int
-- > λ> :t swapaxes i1 i2 a
-- > swapaxes i1 i2 a :: Tensor '[2, 4, 3] Int
--
-- For a rank 2 tensor, `swapaxes` is just `transpose`:
--
-- > transpose == swapaxes i0 i1
swapaxes
  :: (CheckIndices i j s
    , HasShape s
    , KnownNat i
    , KnownNat j)
  => Proxy i
  -> Proxy j
  -> Tensor s n
  -> Tensor (Swapaxes i j s) n
swapaxes px pj = transformTensor (\_ (_, ix) -> exchange ix)
  where
    i = toNat px
    j = toNat pj
    -- Rebuild the index with coordinates i and j exchanged (relies on i < j).
    exchange ix =
      take i ix
        ++ [ix !! j]
        ++ tail (drop i (take j ix))
        ++ [ix !! i]
        ++ tail (drop j ix)

-- | Unit tensor of shape @s@: an element is 1 when every coordinate of its
-- index is equal, otherwise 0. An empty index yields 0.
identity :: forall s n . (HasShape s, Num n) => Tensor s n
identity = generateTensor diag
  where
    diag []       = 0
    diag (k : ks) = if all (== k) ks then 1 else 0

-- | Outer product with a custom combining function. Each result index is
-- split after @rank t1@ coordinates: the front part indexes @t1@ and the
-- rest indexes @t2@.
dyad'
  :: ( r ~ (s ++ t)
     , HasShape s
     , HasShape t
     , HasShape r)
  => (n -> m -> o)
  -> Tensor s n
  -> Tensor t m
  -> Tensor r o
dyad' f t1 t2 = generateTensor combine
  where
    cut = rank t1
    s1  = shape t1
    s2  = shape t2
    combine ix = f (getValue t1 s1 front) (getValue t2 s2 back)
      where
        (front, back) = splitAt cut ix

-- | <https://en.wikipedia.org/wiki/Dyadics Dyadic Tensor>: the outer product
-- of two tensors, combining elements with 'mult'.
--
-- > λ> a = [1..4] :: Tensor '[2,2] Int
-- > λ> a
-- > [[1,2],
-- > [3,4]]
-- > λ> :t a `dyad` a
-- > a `dyad` a :: Tensor '[2, 2, 2, 2] Int
-- > λ> a `dyad` a
-- > [[[[1,2],
-- > [3,4]],
-- > [[2,4],
-- > [6,8]]],
-- > [[[3,6],
-- > [9,12]],
-- > [[4,8],
-- > [12,16]]]]
dyad
  :: ( r ~ (s ++ t)
     , HasShape s
     , HasShape t
     , HasShape r
     , Num n
     , Eq n)
  => Tensor s n -> Tensor t n -> Tensor r n
dyad a b = dyad' mult a b


-- | Result shape of 'dot': the last axis of @s1@ is contracted against the
-- first axis of @s2@, leaving every other axis of @s1@ followed by every
-- other axis of @s2@ (e.g. @'[2,3]@ and @'[3,4]@ give @'[2,4]@).
type DotTensor s1 s2 = Init s1 ++ Tail s2

-- | Tensor product contracting the last axis of the first tensor with the
-- first axis of the second.
--
-- > λ> a = [1..4] :: Tensor '[2,2] Int
-- > λ> a
-- > [[1,2],
-- > [3,4]]
-- > λ> a `dot` a
-- > [[7,10],
-- > [15,22]]
--
-- Informally, @dot a b@ is @dyad a b@ contracted over axes @rank a - 1@ and
-- @rank a@.
--
-- For rank 2 tensors, it is just the matrix product.
dot
  :: ( Last s ~ Head s'
     , r ~ DotTensor s s'
     , HasShape s
     , HasShape s'
     , HasShape r
     , Num n
     , Eq n)
  => Tensor s n
  -> Tensor s' n
  -> Tensor r n
dot t1 t2 = generateTensor entry
  where
    s1    = shape t1
    s2    = shape t2
    inner = last s1          -- length of the contracted axis
    cut   = length s1 - 1    -- result coordinates that belong to t1
    term (!x, !y) = getValue t1 s1 x `mult` getValue t2 s2 y
    entry ix = sum (map term [ (front ++ [k], k : back) | k <- [0 .. inner - 1] ])
      where
        (front, back) = splitAt cut ix

-- | Shape @s@ with the axis at position @i@ removed.
type DropIndex s i = Take i s ++ Drop (i+1) s

-- | Shape left after contracting axes @x@ and @y@ of @s@: axis @y@ is
-- dropped first, then axis @x@.
type Contraction s x y = DropIndex (DropIndex s y) x

-- | Contract a tensor over two axes of equal length by summing the elements
-- whose coordinates on those axes coincide.
--
-- > λ> a = [1..16] :: Tensor '[4,4] Int
-- > λ> a
-- > [[1,2,3,4],
-- > [5,6,7,8],
-- > [9,10,11,12],
-- > [13,14,15,16]]
-- > λ> contraction (i0,i1) a
-- > 34
--
-- For a rank 2 tensor, contraction is just the <https://en.wikipedia.org/wiki/Trace_(linear_algebra) trace>.
contraction
  :: forall x y s s' n.
     ( CheckIndices x y s
     , s' ~ Contraction s x y
     , Dimension s x ~ Dimension s y
     , KnownNat x
     , KnownNat y
     , HasShape s
     , HasShape s'
     , KnownNat (Dimension s x)
     , Num n)
  => (Proxy x, Proxy y)
  -> Tensor s  n
  -> Tensor s' n
contraction (px, py) t@(Tensor f) = generateTensor (collapse (f dims))
  where
    x    = toNat px
    y    = toNat py
    n    = toNat (Proxy :: Proxy (Dimension s x))
    dims = shape t
    gap  = y - x - 1
    -- Re-insert the summation coordinate j at positions x and y of the
    -- result index and sum over every j in [0, n).
    {-# INLINE collapse #-}
    collapse look ix =
      let (front, rest)  = splitAt x ix
          (middle, back) = splitAt gap rest
      in sum [ look (front ++ (j : middle) ++ (j : back)) | j <- [0 .. n - 1] ]

-- | Constraint for 'select'; by the names of the imported checks, @dim@ must
-- be a valid axis of @s@ and @i@ a valid position along it (confirm in the
-- type module).
type CheckSelect dim i s = (CheckDimension dim s && IsIndex i (Dimension s dim)) ~ 'True

-- | Fix axis @dim@ at position @i@, dropping that axis from the shape.
--
-- > λ> a = identity :: Tensor '[4,4] Int
-- > λ> select (i0,i0) a
-- > [1,0,0,0]
-- > λ> select (i0,i1) a
-- > [0,1,0,0]
select
  :: ( CheckSelect dim i s
     , HasShape s
     , KnownNat dim
     , KnownNat i)
  => (Proxy dim, Proxy i)
  -> Tensor s n
  -> Tensor (DropIndex s dim) n
select (pd, pid) t = transformTensor pick t
  where
    dim = toNat pd
    ind = toNat pid
    -- Re-insert the fixed coordinate at axis dim of the result index.
    pick _ (_, ix) = let (pre, post) = splitAt dim ix in pre ++ (ind : post)

type CheckSlice dim from to s = (CheckDimension dim s && IsIndices from to (Dimension s dim)) ~ 'True
type Slice dim from to s = Take dim s ++ ( to - from : Tail (Drop dim s))

-- | Restrict dimension @dim@ of a tensor to the half-open range
-- @[from, to)@. The result has size @to - from@ along that dimension, and
-- position @k@ there reads position @k + from@ of the source.
--
-- > λ> a = identity :: Tensor '[4,4] Int
-- > λ> a
-- > [[1,0,0,0],
-- > [0,1,0,0],
-- > [0,0,1,0],
-- > [0,0,0,1]]
-- > λ> slice (i0,(i1,i3)) a
-- > [[0,1,0,0],
-- > [0,0,1,0]]
-- > λ> slice (i1,(i1,i3)) a
-- > [[0,0],
-- > [1,0],
-- > [0,1],
-- > [0,0]]
slice
  :: ( CheckSlice dim from to s
     , s' ~ Slice dim from to s
     , KnownNat dim
     , KnownNat from
     , KnownNat (to - from)
     , HasShape s)
  => (Proxy dim, (Proxy from, Proxy to))
  -> Tensor s n
  -> Tensor s' n
slice (dimProxy, (fromProxy, _)) = transformTensor shift
  where
    axis   = toNat dimProxy
    offset = toNat fromProxy
    -- Move the sliced axis's position back to the source coordinates by
    -- adding the slice start; every other axis is unchanged.
    shift _ (_, target) =
      let (before, pos : after) = splitAt axis target
      in before ++ ((pos + offset) : after)

-- | Resize a tensor to another shape of the same rank by wrapping indices.
-- Along every axis, target position @k@ reads source position
-- @k `mod` size@, where @size@ is the source size on that axis. A larger
-- target therefore repeats the source cyclically.
--
-- > λ> a = identity :: Tensor '[2,2] Int
-- > λ> a
-- > [[1,0],
-- > [0,1]]
-- > λ> expand a :: Tensor '[4,4] Int
-- > [[1,0,1,0],
-- > [0,1,0,1],
-- > [1,0,1,0],
-- > [0,1,0,1]]
expand
  :: (TensorRank s ~ TensorRank s'
     , HasShape s)
  => Tensor s n
  -> Tensor s' n
expand = transformTensor wrap
  where
    {-# INLINE wrap #-}
    -- Reduce each target coordinate modulo the matching source size.
    wrap sizes (_, target) = [ k `mod` size | (k, size) <- zip target sizes ]

type CheckConcatenate i a b = (IsIndex i (TensorRank a)) ~ 'True
type Concatenate i a b = Take i a ++ (Dimension a i + Dimension b i : Drop (i+1) a)

-- | Join two tensors along an existing axis @i@.
--
-- Both tensors must have the same rank and the same shape once axis @i@ is
-- dropped. The result's size along axis @i@ is the sum of the two inputs'
-- sizes there. Let @n@ be the first tensor's size on that axis. Positions
-- below @n@ read from the first tensor, and positions from @n@ onward read
-- from the second tensor, shifted back by @n@.
--
-- > λ> a = [1..4] :: Tensor '[2,2] Int
-- > λ> a
-- > [[1,2],
-- > [3,4]]
-- > λ> b = [1,1,1,1] :: Tensor '[2,2] Int
-- > λ> b
-- > [[1,1],
-- > [1,1]]
-- > λ> concatenate i0 a b
-- > [[1,2],
-- > [3,4],
-- > [1,1],
-- > [1,1]]
-- > λ> concatenate i1 a b
-- > [[1,2,1,1],
-- > [3,4,1,1]]
concatenate
  :: ( TensorRank a ~ TensorRank b
     , DropIndex a i ~ DropIndex b i
     , CheckConcatenate i a b
     , Concatenate i a b ~ c
     , HasShape a
     , HasShape b
     , KnownNat i)
  => Proxy i
  -> Tensor a n
  -> Tensor b n
  -> Tensor c n
concatenate p ta@(Tensor a) tb@(Tensor b) =
  let i  = toNat p
      sa = shape ta
      sb = shape tb
      -- Size of the first tensor along the joining axis. CheckConcatenate is
      -- meant to keep @i@ below the rank (IsIndex is defined elsewhere), which
      -- is what makes this lookup safe.
      n  = sa !! i
  in Tensor $ \_ ind ->
       let (ai, x:bi) = splitAt i ind
       in if x >= n
            -- Past the first tensor's extent: read the second, re-based at 0.
            then b sb (ai ++ (x - n) : bi)
            else a sa ind

type CheckInsert dim i b = (CheckDimension dim b && IsIndex i (Dimension b dim))  ~ 'True
type Insert dim b = Take dim b ++ (Dimension b dim + 1 : Drop (dim + 1) b)

-- | Insert a lower-rank tensor as a new slice of a higher-rank tensor.
--
-- The first tensor becomes the slice at position @i@ of dimension @dim@,
-- and that dimension grows by one. Slices of the second tensor below @i@
-- keep their position; those at @i@ or above move up by one.
--
-- > λ> a = [1,2] :: Vector 2 Float
-- > λ> b = a `dyad` a
-- > λ> b
-- > [[1.0,2.0],
-- > [2.0,4.0]]
-- > λ> :t b
-- > b :: Tensor '[2, 2] Float
-- > λ> c = [1..4] :: Tensor '[1,2,2] Float
-- > λ> c
-- > [[[1.0,2.0],
-- > [3.0,4.0]]]
-- > λ> d = insert i0 i0 b c
-- > λ> :t d
-- > d :: Tensor '[2, 2, 2] Float
-- > λ> d
-- > [[[1.0,2.0],
-- > [2.0,4.0]],
-- > [[1.0,2.0],
-- > [3.0,4.0]]]
-- > λ> insert i0 i1 b c
-- > [[[1.0,2.0],
-- > [3.0,4.0]],
-- > [[1.0,2.0],
-- > [2.0,4.0]]]
insert
  :: ( DropIndex b dim ~ a
    , CheckInsert dim i b
    , KnownNat i
    , KnownNat dim
    , HasShape a
    , HasShape b)
  => Proxy dim
  -> Proxy i
  -> Tensor a n
  -> Tensor b n
  -> Tensor (Insert dim b) n
insert pd px small@(Tensor fromSmall) big@(Tensor fromBig) =
  Tensor $ \_ ci ->
    let (lead, pos : rest) = splitAt d ci
    in case compare pos i of
         -- Before the inserted slice: the index is unchanged.
         LT -> fromBig bigShape ci
         -- The inserted slice itself: drop the axis to index the smaller tensor.
         EQ -> fromSmall smallShape (lead ++ rest)
         -- After the inserted slice: undo the shift by one.
         GT -> fromBig bigShape (lead ++ ((pos - 1) : rest))
  where
    d          = toNat pd
    i          = toNat px
    smallShape = shape small
    bigShape   = shape big

-- | Append a tensor as a new last slice along dimension @dim@ of another
-- tensor.
--
-- This is 'insert' at position @Dimension b dim@, the position just past
-- the current last index along that dimension.
--
-- > λ> a = [1,2] :: Vector 2 Float
-- > λ> a
-- > [1.0,2.0]
-- > λ> b = 3 :: Tensor '[] Float
-- > λ> b
-- > 3.0
-- > λ> append i0 b a
-- > [1.0,2.0,3.0]
append
  :: forall dim a b n.
    ( DropIndex b dim ~ a
    , CheckInsert dim (Dimension b dim) b
    , KnownNat (Dimension b dim)
    , KnownNat dim
    , HasShape a
    , HasShape b)
  => Proxy dim
  -> Tensor a n
  -> Tensor b n
  -> Tensor (Insert dim b) n
append pd = insert pd endOfAxis
  where
    -- The insertion position equals the current size of the axis.
    endOfAxis :: Proxy (Dimension b dim)
    endOfAxis = Proxy

-- | Specialise a tensor to its own shape, giving a plain function from an
-- index to the element stored there. Intended for internal use.
runTensor :: HasShape s => Tensor s n -> Index -> n
runTensor t = getValue t (shape t)