{-# OPTIONS_GHC -fno-warn-incomplete-uni-patterns #-}
{-# LANGUAGE UndecidableInstances #-}
module Data.Tensor.Tensor where
import Data.List (intercalate)
import Data.Proxy
import Data.Singletons
import qualified Data.Singletons.Prelude as N
import qualified Data.Singletons.Prelude.List as N
import Data.Tensor.Index
import Data.Tensor.Type
import qualified Data.Vector as V
import GHC.Exts (IsList (..))
import GHC.TypeLits
-- | A tensor whose shape is the type-level list of axis sizes @s@ and whose
-- elements have type @n@.
--
-- The representation is a function: given a shape and an index, it returns
-- the element stored at that index. The lookup operator '!' supplies the
-- tensor's static shape as the first argument.
newtype Tensor (s :: [Nat]) n = Tensor { getValue :: Shape -> Index -> n }
-- | A rank-0 tensor holding a single element.
type Scalar n = Tensor '[] n
-- | A rank-1 tensor with @s@ entries.
type Vector s n = Tensor '[s] n
-- | A rank-2 tensor whose two axes have sizes @a@ and @b@.
type Matrix a b n = Tensor '[a,b] n
-- | A rank-@r@ tensor in which every axis has size @dim@. When @dim@ is 0
-- this collapses to a 'Scalar' instead.
type SimpleTensor (r :: Nat) (dim :: Nat) n = N.If ((N.==) dim 0) (Scalar n) (Tensor (N.Replicate r dim) n)
-- | Two tensors of the same shape are equal when they agree at every index.
-- Indices are visited from 'minBound' to 'maxBound', and comparison stops at
-- the first mismatch.
instance (SingI s, Eq n) => Eq (Tensor s n) where
  t1 == t2 = and [t1 ! ix == t2 ! ix | ix <- [minBound .. maxBound :: TensorIndex s]]
-- | Applies a function to every element. The function is composed onto the
-- underlying lookup, so no elements are computed here.
instance SingI s => Functor (Tensor s) where
  fmap g (Tensor h) = Tensor (\shp -> g . h shp)
-- | 'pure' builds a constant tensor. '<*>' applies, at each index, the
-- function found at that index to the value found at the same index.
instance SingI s => Applicative (Tensor s) where
  pure x = Tensor (const (const x))
  Tensor ff <*> Tensor fx = Tensor $ \shp ix ->
    let g = ff shp ix
    in g (fx shp ix)
-- | Folds over the elements in the order given by enumerating 'TensorIndex'
-- from 'minBound' to 'maxBound'.
instance SingI s => Foldable (Tensor s) where
  foldMap g t = foldMap (\ix -> g (t ! ix)) allIndices
    where
      allIndices = [minBound .. maxBound] :: [TensorIndex s]
-- | Renders a tensor as nested bracketed lists, with one nesting level per
-- axis.
--
-- * Entries of the innermost axis are separated by @","@.
-- * Outer axes are separated by @",\\n"@.
-- * Any axis longer than 9 is abbreviated to its first 8 entries, then
--   @".."@, then its last entry.
instance (SingI s, Show n) => Show (Tensor s n) where
  show (Tensor f) = render [] dims (f dims)
    where
      dims = natsVal (Proxy :: Proxy s)
      -- 'rev' is the index prefix built so far, most recent axis first.
      render :: [Int] -> [Int] -> (Index -> n) -> String
      render rev [] at = show (at (reverse rev))
      render rev [d] at =
        wrap d "," [show (at (reverse (k : rev))) | k <- [0 .. d - 1]]
      render rev (d : ds) at =
        wrap d ",\n" [render (k : rev) ds at | k <- [0 .. d - 1]]
      wrap d sep items = "[" ++ intercalate sep (abbreviate d items) ++ "]"
      abbreviate d items
        | d > 9 = take 8 items ++ ["..", last items]
        | otherwise = items
-- | Elementwise arithmetic.
--
-- * Binary operators combine matching indices.
-- * Unary operators map over every element.
-- * Integer literals become constant tensors.
instance (SingI s, Num n) => Num (Tensor s n) where
  x + y = zipWithTensor (+) x y
  x * y = zipWithTensor (*) x y
  negate t = negate <$> t
  abs t = abs <$> t
  signum t = signum <$> t
  fromInteger k = pure (fromInteger k)
-- | Elementwise division. Rational literals become constant tensors.
instance (SingI s, Fractional n) => Fractional (Tensor s n) where
  x / y = zipWithTensor (/) x y
  fromRational q = pure (fromRational q)
-- | Elementwise floating-point functions.
--
-- * Unary functions are mapped over every element.
-- * Binary ones ('**', 'logBase') combine matching indices.
-- * 'pi' is a constant tensor.
--
-- '**' and 'logBase' are defined explicitly so that each delegates to the
-- element type's own implementation. Previously:
--
-- * 'logBase' was an 'error' stub that crashed on use.
-- * '**' fell back to the class default @exp (log x * y)@, which yields NaN
--   for negative bases even where the element type's '**' does not
--   (e.g. @(-2) ** 2@).
instance (SingI s, Floating n) => Floating (Tensor s n) where
  pi = pure pi
  exp = fmap exp
  log = fmap log
  sqrt = fmap sqrt
  (**) = zipWithTensor (**)
  logBase = zipWithTensor logBase
  sin = fmap sin
  cos = fmap cos
  tan = fmap tan
  asin = fmap asin
  acos = fmap acos
  atan = fmap atan
  sinh = fmap sinh
  cosh = fmap cosh
  tanh = fmap tanh
  asinh = fmap asinh
  acosh = fmap acosh
  atanh = fmap atanh
{-# INLINE generateTensor #-}
-- | Builds a tensor of the proxied shape from an index function.
--
-- If the shape has zero elements (some axis has size 0), the result is a
-- constant tensor filled with @fn [0]@ instead.
generateTensor :: SingI s => (Index -> n) -> Proxy s -> Tensor s n
generateTensor fn p
  | product (natsVal p) == 0 = pure (fn [0])
  | otherwise = Tensor (\_ idx -> fn idx)
{-# INLINE transformTensor #-}
-- | Re-indexes a tensor into a new shape @s'@.
--
-- The mapping function receives @(destinationIndex, destinationShape)@ and
-- the source shape, and must return the source index whose element appears
-- at that destination position.
transformTensor
  :: forall s s' n. SingI s
  => (([Int], [Int]) -> [Int] -> [Int])
  -> Tensor s n
  -> Tensor s' n
transformTensor mapIndex (Tensor f) = Tensor remap
  where
    srcShape = natsVal (Proxy :: Proxy s)
    remap dstShape dstIdx = f srcShape (mapIndex (dstIdx, dstShape) srcShape)
-- | Materialises a tensor into a boxed vector.
--
-- The vector is bound outside the lookup function, so it is shared between
-- lookups. It is filled from linear positions via 'toEnum' and read back
-- via 'tiTovi'.
clone :: SingI s => Tensor s n -> Tensor s n
clone t = Tensor (\_ idx -> cache V.! tiTovi dims idx)
  where
    dims = shape t
    cache = V.generate (product dims) (\k -> t ! toEnum k)
{-# INLINE zipWithTensor #-}
-- | Combines two same-shaped tensors elementwise with a binary function.
zipWithTensor :: SingI s => (n -> n -> n) -> Tensor s n -> Tensor s n -> Tensor s n
zipWithTensor f t1 t2 = generateTensor combine Proxy
  where
    combine ix =
      let k = TensorIndex ix
      in f (t1 ! k) (t2 ! k)
-- | Conversion between a tensor and a flat list of its elements.
--
-- * 'fromList' requires exactly @product s@ elements and calls 'error'
--   otherwise.
-- * 'toList' returns every element, @product s@ of them (one for a
--   scalar), in the linear order given by 'toEnum' on 'TensorIndex'.
--
-- 'toList' previously enumerated only @rank t@ indices. It returned as many
-- elements as the tensor has axes, and nothing at all for a scalar.
instance SingI s => IsList (Tensor s n) where
  type Item (Tensor s n) = n
  fromList v =
    let s = natsVal (Proxy :: Proxy s)
        l = product s
    in if l /= length v
         then error "length not match"
         else let vv = V.fromList v in Tensor $ \s' i -> vv V.! tiTovi s' i
  toList t = fmap (\i -> t ! toEnum i) [0 .. product (shape t) - 1]
-- | The tensor's axis sizes, read from its type. The argument itself is never
-- inspected.
shape :: forall s n. SingI s => Tensor s n -> [Int]
shape = const dims
  where
    dims = natsVal (Proxy :: Proxy s)
-- | Number of axes of the tensor.
rank :: SingI s => Tensor s n -> Int
rank t = length (shape t)
-- | Looks up the element at a type-checked index. The tensor's static shape
-- is passed to the underlying lookup function.
(!) :: SingI s => Tensor s n -> TensorIndex s -> n
t@(Tensor fn) ! (TensorIndex ix) = fn (shape t) ix
-- | Reinterprets a tensor under a new shape with the same element count.
-- Each destination index goes to a linear position (via 'tiTovi'), and that
-- position goes back to a source index (via 'viToti').
reshape :: (N.Product s ~ N.Product s', SingI s) => Tensor s n -> Tensor s' n
reshape = transformTensor toSource
  where
    {-# INLINE toSource #-}
    toSource (dstIdx, dstShape) srcShape = viToti srcShape (tiTovi dstShape dstIdx)
-- | Shape of a fully transposed tensor: the axis order reversed.
type Transpose (a :: [Nat]) = N.Reverse a
-- | Reverses the order of all axes. The element at index @is@ of the result
-- is the element at @reverse is@ of the input.
transpose :: SingI a => Tensor a n -> Tensor (Transpose a) n
transpose = transformTensor (\(dstIdx, _) _ -> reverse dstIdx)
-- | Valid axis pair for 'swapaxes': @0 <= i < j < length s@.
type CheckSwapaxes i j s = N.And '[ (N.>=) i 0, (N.<) i j, (N.<) j (N.Length s)]
-- | Shape @s@ with the sizes at positions @i@ and @j@ exchanged.
type Swapaxes i j s = N.Concat '[N.Take i s, '[(N.!!) s j], N.Tail (N.Drop i (N.Take j s)) , '[(N.!!) s i], N.Tail (N.Drop j s)]
-- | Exchanges axes @i@ and @j@. The element at index @is@ of the result is
-- the element of the input at @is@ with positions @i@ and @j@ swapped.
swapaxes
  :: ( Swapaxes i j s ~ s'
     , CheckSwapaxes i j s ~ 'True
     , SingI s
     , KnownNat i
     , KnownNat j)
  => Proxy i
  -> Proxy j
  -> Tensor s n
  -> Tensor s' n
swapaxes px pj = transformTensor exchange
  where
    a = toNat px
    b = toNat pj
    -- The type-level check guarantees a < b < length idx.
    exchange (idx, _) _ = zipWith pick [0 ..] idx
      where
        pick k v
          | k == a = idx !! b
          | k == b = idx !! a
          | otherwise = v
-- | Generalised identity tensor: 1 where all index components are equal,
-- 0 elsewhere.
--
-- NOTE(review): for a rank-0 shape the empty index maps to 0, not 1;
-- confirm this is intended.
identity :: forall s n . (SingI s, Num n) => Tensor s n
identity = generateTensor diag Proxy
  where
    diag [] = 0
    diag (k : ks) = if all (== k) ks then 1 else 0
-- | Generalised outer product with an arbitrary combining function.
--
-- The result has shape @s ++ t@. Each result index is split after
-- @rank t1@ components: the left part indexes @t1@ and the right part
-- indexes @t2@.
dyad'
  :: ( r ~ (N.++) s t
     , SingI s
     , SingI t
     , SingI r)
  => (n -> m -> o)
  -> Tensor s n
  -> Tensor t m
  -> Tensor r o
dyad' f t1 t2 = generateTensor pairUp Proxy
  where
    cut = rank t1
    pairUp ix =
      let (left, right) = splitAt cut ix
      in f (t1 ! TensorIndex left) (t2 ! TensorIndex right)
-- | Outer product of two tensors, combining elements with the module's
-- 'mult'. 'mult' is defined elsewhere; presumably numeric multiplication.
dyad
  :: ( r ~ (N.++) s t
     , SingI s
     , SingI t
     , SingI r
     , Num n
     , Eq n)
  => Tensor s n -> Tensor t n -> Tensor r n
dyad t1 t2 = dyad' mult t1 t2
-- | Shape of 'dot': every axis of @s1@ but its last, followed by every axis
-- of @s2@ but its first.
type DotTensor s1 s2 = (N.++) (N.Init s1) (N.Tail s2)
-- | Contracts the last axis of the first tensor against the first axis of
-- the second.
--
-- Products are formed with 'mult' (defined elsewhere) and summed with 'sum'
-- over the shared axis.
dot
  :: ( N.Last s ~ N.Head s'
     , SingI (DotTensor s s')
     , SingI s
     , SingI s'
     , Num n
     , Eq n)
  => Tensor s n
  -> Tensor s' n
  -> Tensor (DotTensor s s') n
dot t1 t2 = generateTensor contractAt Proxy
  where
    dims1 = shape t1
    inner = last dims1
    lead = length dims1 - 1
    contractAt ix =
      let (front, back) = splitAt lead ix
          term k = (t1 ! TensorIndex (front ++ [k])) `mult` (t2 ! TensorIndex (k : back))
      in sum (map term [0 .. inner - 1])
-- | Valid axis pair for 'contraction': @0 <= x < y < TensorRank s@.
type CheckContraction s x y = N.And '[(N.<) x y, (N.>=) x 0, (N.<) y (TensorRank s)]
-- | Shape @s@ with axes @y@ and then @x@ removed.
type Contraction s x y = DropIndex (DropIndex s y) x
-- | Size of axis @i@ of shape @s@.
type TensorDim s i = (N.!!) s i
-- | Shape @s@ with the axis at position @i@ removed.
type DropIndex (s :: [Nat]) (i :: Nat) = (N.++) (N.Fst (N.SplitAt i s)) (N.Tail (N.Snd (N.SplitAt i s)))
-- | Sums over the diagonal of axes @x@ and @y@ (which must have equal size),
-- removing both axes from the shape.
contraction
  :: forall x y s s' n.
     ( CheckContraction s x y ~ 'True
     , s' ~ Contraction s x y
     , TensorDim s x ~ TensorDim s y
     , KnownNat x
     , KnownNat y
     , SingI s
     , SingI s'
     , KnownNat (TensorDim s x)
     , Num n)
  => (Proxy x, Proxy y)
  -> Tensor s n
  -> Tensor s' n
contraction (px, py) t@(Tensor f) = generateTensor collapse Proxy
  where
    lo = toNat px
    hi = toNat py
    diagLen = toNat (Proxy :: Proxy (TensorDim s x))
    lookupAt = f (shape t)
    -- Re-insert the summed component j at positions lo and hi of the
    -- result index (hi is counted in the original, longer index).
    collapse ix =
      let (front, rest) = splitAt lo ix
          (middle, back) = splitAt (hi - lo - 1) rest
      in sum [lookupAt (front ++ (j : middle) ++ (j : back)) | j <- [0 .. diagLen - 1]]
-- | @dim@ is a valid axis position of shape @s@.
type CheckDim dim s = N.And '[(N.>=) dim 0, (N.<) dim (N.Length s)]
-- | @dim@ is a valid axis and @i@ is within that axis's size.
type CheckSelect dim i s = N.And '[ CheckDim dim s , (N.>=) i 0, (N.<) i ((N.!!) s dim) ]
-- | Shape @s@ with axis @i@ removed.
type Select i s = (N.++) (N.Take i s) (N.Tail (N.Drop i s))
-- | Fixes axis @dim@ at position @i@, dropping that axis from the shape.
select
  :: ( CheckSelect dim i s ~ 'True
     , s' ~ Select dim s
     , SingI s
     , KnownNat dim
     , KnownNat i)
  => (Proxy dim, Proxy i)
  -> Tensor s n
  -> Tensor s' n
select (pd, pid) t = transformTensor reinsert t
  where
    axis = toNat pd
    fixed = toNat pid
    -- Put the fixed component back at position 'axis' of the source index.
    reinsert (idx, _) _ =
      let (before, after) = splitAt axis idx
      in before ++ fixed : after
-- | Valid half-open range @[from, to)@ on axis @dim@ of shape @s@.
type CheckSlice dim from to s = N.And '[ CheckDim dim s, CheckSelect dim from s, (N.<) from to , (N.<=) to ((N.!!) s dim)]
-- | Shape @s@ with axis @dim@ shrunk to @to - from@.
type Slice dim from to s = N.Concat '[N.Take dim s, '[to - from] , N.Tail (N.Drop dim s)]
-- | Keeps positions @from .. to - 1@ of axis @dim@. Destination position @k@
-- on that axis reads source position @k + from@.
slice
  :: ( CheckSlice dim from to s ~ 'True
     , s' ~ Slice dim from to s
     , KnownNat dim
     , KnownNat from
     , KnownNat (to - from)
     , SingI s)
  => (Proxy dim, (Proxy from, Proxy to))
  -> Tensor s n
  -> Tensor s' n
slice (pd, (pa, _)) t = transformTensor shift t
  where
    axis = toNat pd
    offset = toNat pa
    shift (idx, _) _ =
      let (before, k : after) = splitAt axis idx
      in before ++ (k + offset) : after
-- | Broadcasts a tensor to another shape of the same rank. Each destination
-- index component is reduced modulo the matching source axis size.
expand
  :: (TensorRank s ~ TensorRank s'
     , SingI s)
  => Tensor s n
  -> Tensor s' n
expand = transformTensor (\(idx, _) srcShape -> zipWith mod idx srcShape)
-- | @a@ and @b@ have equal rank, @i@ is a valid axis, and the shapes agree
-- on every other axis.
type CheckConcatenate i a b = N.And '[ (N.==) (N.Length a) (N.Length b), (N.>=) i 0, (N.<) i (N.Length a), (N.==) (Select i a) (Select i b) ]
-- | Shape @a@ with axis @i@ grown by the size of axis @i@ of @b@.
type Concatenate i a b = N.Concat '[N.Take i a, '[(N.+) (TensorDim a i) (TensorDim b i)], N.Tail (N.Drop i a)]
-- | Joins two tensors along axis @i@, with the first tensor's entries first.
concatenate
  :: (CheckConcatenate i a b ~ 'True
     , Concatenate i a b ~ c
     , SingI a
     , SingI b
     , KnownNat i)
  => Proxy i
  -> Tensor a n
  -> Tensor b n
  -> Tensor c n
concatenate p ta@(Tensor fa) tb@(Tensor fb) = Tensor pick
  where
    axis = toNat p
    shapeA = shape ta
    shapeB = shape tb
    boundary = shapeA !! axis
    -- Positions below 'boundary' on the axis come from the first tensor;
    -- the rest are shifted down and read from the second.
    pick _ idx =
      let (before, k : after) = splitAt axis idx
      in if k < boundary
           then fa shapeA idx
           else fb shapeB (before ++ (k - boundary) : after)
-- | @dim@ is a valid axis of @b@, @a@ equals @b@ without that axis, and
-- @0 <= i <= size of axis dim of b@.
type CheckInsert dim i a b = N.And '[ CheckDim dim b, (N.==) a (Select dim b), (N.>=) i 0, (N.<=) i (TensorDim b dim)]
-- | Shape @b@ with axis @dim@ grown by one.
type Insert dim a b = N.Concat '[N.Take dim b, '[ TensorDim b dim + 1 ], N.Tail (N.Drop dim b)]
-- | Inserts the slice @a@ into @b@ at position @i@ along axis @dim@.
--
-- Entries of @b@ before @i@ keep their position; those at or after @i@ move
-- up by one.
insert
  :: (CheckInsert dim i a b ~ 'True
     , KnownNat i
     , KnownNat dim
     , SingI a
     , SingI b)
  => Proxy dim
  -> Proxy i
  -> Tensor a n
  -> Tensor b n
  -> Tensor (Insert dim a b) n
insert pd px a@(Tensor ta) b@(Tensor tb) = Tensor route
  where
    axis = toNat pd
    pos = toNat px
    shapeA = shape a
    shapeB = shape b
    route _ idx =
      let (before, k : after) = splitAt axis idx
      in case compare k pos of
           EQ -> ta shapeA (before ++ after)
           LT -> tb shapeB idx
           GT -> tb shapeB (before ++ (k - 1) : after)
-- | Appends the slice @a@ after the last entry of @b@ along axis @dim@.
-- This is 'insert' at position @TensorDim b dim@.
append
  :: forall dim a b n.
     (CheckInsert dim (TensorDim b dim) a b ~ 'True
     , KnownNat (TensorDim b dim)
     , KnownNat dim
     , SingI a
     , SingI b)
  => Proxy dim
  -> Tensor a n
  -> Tensor b n
  -> Tensor (Insert dim a b) n
append pd ta tb = insert pd endPos ta tb
  where
    endPos = Proxy :: Proxy (TensorDim b dim)
-- | The tensor's lookup function with its static shape already supplied.
-- Takes a raw, unchecked 'Index'.
runTensor :: SingI s => Tensor s n -> Index -> n
runTensor t = getValue t (shape t)