{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE UndecidableInstances #-}
module Data.Tensor.Type where
import Data.List (foldl')
import Data.Proxy
import Data.Type.Bool hiding (If)
import GHC.Exts
import GHC.TypeLits
import qualified GHC.TypeLits as L
import Unsafe.Coerce
-- | Value-level shape: one size per dimension, in the same order as the
-- type-level shape list it was reflected from (see 'HasShape').
type Shape = [Int]
-- | Value-level per-dimension index into a tensor (one entry per dimension).
type Index = [Int]
-- | A runtime 'Shape' tagged with the type-level shape @shape@ it describes.
--
-- The phantom parameter carries no data; the runtime list lives in
-- 'unShape'. Values for a concrete @shape@ are produced by 'toShape'.
newtype SShape (shape :: [Nat]) = SShape
  { unShape :: Shape
  }
  deriving (Show)
-- | Type-level shapes (lists of 'Nat') whose dimensions can be reflected to
-- the value level.
class HasShape s where
  -- | The dimensions of @s@ as a runtime 'Shape'.
  toShape :: SShape s

  -- | Number of dimensions of @s@, i.e. the length of 'toShape'.
  toRank :: Proxy s -> Int
  toRank _ = length dims
    where
      dims = unShape (toShape :: SShape s)

  -- | Total number of elements for shape @s@, i.e. the product of the
  -- dimensions in 'toShape' (@1@ for the empty shape).
  toSize :: Proxy s -> Int
  toSize _ = product dims
    where
      dims = unShape (toShape :: SShape s)
-- | The empty type-level shape reflects to a shape with no dimensions
-- (rank 0).
instance HasShape '[] where
  toShape = SShape []
-- | A non-empty shape: reflect the leading dimension @n@ with 'natVal' and
-- prepend it to the reflected remainder @s@.
instance (KnownNat n, HasShape s) => HasShape (n:s) where
  toShape = SShape (leading : trailing)
    where
      leading  = fromInteger (natVal (Proxy :: Proxy n))
      trailing = unShape (toShape :: SShape s)
-- | Reflect a type-level natural number as an 'Int'.
--
-- The conversion goes through 'fromInteger', the same conversion the
-- 'HasShape' instance uses for dimensions. Values above @maxBound :: Int@
-- wrap, as 'fromInteger' does for 'Int'. The previous implementation used
-- 'unsafeCoerce' from 'Integer' to 'Int'. That reinterprets the heap
-- representation of one type as the other and is not a valid conversion.
toNat :: KnownNat n => Proxy n -> Int
toNat = fromInteger . natVal
-- | Convert a linear offset into a per-dimension index for shape @s@.
--
-- Digits are peeled off with 'divMod', starting from the last dimension
-- (which therefore varies fastest) and moving towards the first. The final
-- quotient is thrown away. For positive dimensions, an out-of-range offset
-- is therefore reduced modulo the product of the dimensions instead of being
-- rejected. Negative offsets get non-negative digits, because 'divMod'
-- floors.
--
-- >>> viToti [2, 3] 4
-- [1,1]
viToti :: Shape -> Int -> Index
viToti s i = snd (foldr peel (i, []) s)
  where
    peel dim (rest, digits) =
      let (rest', digit) = rest `divMod` dim
      in (rest', digit : digits)
-- | Convert a per-dimension index into a linear offset for shape @s@. The
-- last dimension varies fastest, matching 'viToti' for in-range indices.
--
-- No bounds checking is done. Dimensions and indices are paired with 'zip',
-- so surplus entries in the longer list are silently ignored.
--
-- >>> tiTovi [2, 3] [1, 1]
-- 4
tiTovi :: Shape -> Index -> Int
tiTovi s i = foldl' step 0 (zip s i)
  where
    step acc (dim, ix) = acc * dim + ix
-- | A per-dimension index tagged with the type-level shape @shape@.
--
-- The constructor itself does no validation. The 'IsList' instance
-- ('fromList') is where length and bounds are checked. Ordering is the
-- derived, lexicographic order on the underlying list.
newtype TensorIndex (shape :: [Nat]) = TensorIndex Index
  deriving (Eq, Ord, Show)
-- | List-literal support for 'TensorIndex'.
--
-- 'fromList' validates against the reflected shape of @s@:
--
-- * the list length must equal the rank, otherwise @error "length not match"@;
-- * every entry must lie in @[0, dim)@ for its dimension, otherwise
--   @error "index overflow"@.
instance HasShape s => IsList (TensorIndex s) where
  type Item (TensorIndex s) = Int

  fromList v
    | length v /= length dims               = error "length not match"
    | any (uncurry outOfRange) (zip v dims) = error "index overflow"
    | otherwise                             = TensorIndex v
    where
      dims = unShape (toShape :: SShape s)
      outOfRange ix dim = ix < 0 || ix >= dim

  toList (TensorIndex v) = v
-- | Type-level rank: the number of dimensions in shape @s@.
type family TensorRank (s :: [Nat]) :: Nat where
  TensorRank '[] = 0
  -- Written as @... + 1@ (not @1 + ...@); GHC does not treat type-level '+'
  -- as commutative for abstract arguments, so keep this form for callers.
  TensorRank (_:s) = TensorRank s + 1
-- | Type-level element count: the product of all dimensions in @s@
-- (@1@ for the empty shape). This is the type-level counterpart of 'toSize'.
--
-- The recursive equation previously multiplied by @TensorRank s@, so e.g.
-- @TensorSize '[2, 3]@ reduced to @2@ instead of @6@.
type family TensorSize (s :: [Nat]) :: Nat where
  TensorSize '[] = 1
  -- Qualified 'L.*' avoids clashing with '*' meaning 'Type'.
  TensorSize (n:s) = n L.* TensorSize s
-- | Accumulator-style type-level reverse: @Reverse as b@ is @as@ reversed and
-- prepended to @b@. Use @Reverse xs '[]@ for a plain reversal.
type family Reverse (a :: [k]) (b :: [k]) :: [k] where
  Reverse '[] b = b
  Reverse (a:as) b = Reverse as (a:b)
-- | Type-level conditional: @c@ when @b@ is @'True@, @d@ when @'False@.
-- Defined locally; the 'If' from "Data.Type.Bool" is hidden on import.
type family If (b :: Bool) c d where
  If 'True c d = c
  If 'False c d = d
-- | @Replicate a n@ is the type-level list holding @n@ copies of @a@.
--
-- The recursive equation previously recursed on @n@ itself instead of
-- @n - 1@, so reducing any non-zero count never terminated.
type family Replicate (a :: k) (dim :: Nat) :: [k] where
  Replicate a 0 = '[]
  Replicate a n = a : Replicate a (n - 1)
-- | @Dimension s i@ is the size of the @i@-th (0-based) dimension of shape
-- @s@. If @i@ is not less than the rank, the result is a custom type error
-- "Index overflow".
type family Dimension (s :: [Nat]) (i :: Nat) :: Nat where
  Dimension (s:_) 0 = s
  Dimension (_:s) n = Dimension s (n-1)
  -- Only reached once the list is exhausted (for concrete arguments).
  Dimension _ _ = TypeError ('Text "Index overflow")
-- | Type-level 'Bool': is @dim@ a valid 0-based dimension number for shape
-- @s@ (i.e. @dim < TensorRank s@)?
--
-- NOTE(review): unlike 'CheckIndices', this is a 'Bool', not a constraint.
-- Callers presumably write @CheckDimension dim s ~ 'True@ (confirm).
type CheckDimension dim s = IsIndex dim (TensorRank s)
-- | Constraint requiring @i < j < TensorRank s@.
type CheckIndices i j s = IsIndices i j (TensorRank s) ~ 'True
-- | Type-level 'Bool': @0 <= i < n@ (the @0 <=? i@ conjunct always holds
-- for 'Nat').
type IsIndex i n = (0 <=? i) && (i + 1 <=? n)
-- | Type-level 'Bool': @0 <= i < j < n@.
type IsIndices i j n = (0 <=? i) && (i + 1 <=? j) && (j + 1 <=? n)
-- | @Take n xs@ is the first @n@ elements of @xs@.
--
-- Asking for more elements than the list holds is a custom type error. It no
-- longer leaves a stuck, unreduced family application that surfaces later
-- as a confusing mismatch.
type family Take (n :: Nat) (a :: [k]) :: [k] where
  Take 0 _ = '[]
  Take n (x:xs) = x : Take (n-1) xs
  -- Reached only for a non-zero count on an exhausted list.
  Take _ '[] = TypeError ('Text "Take: index overflow")
-- | @Drop n xs@ is @xs@ without its first @n@ elements.
--
-- The recursive equation previously delegated to @Take (n-1) xs@, so e.g.
-- @Drop 1 '[a, b, c]@ reduced to @'[]@ instead of @'[b, c]@. Dropping more
-- elements than the list holds is now a custom type error instead of a
-- stuck application.
type family Drop (n :: Nat) (a :: [k]) :: [k] where
  Drop 0 xs = xs
  Drop n (_:xs) = Drop (n-1) xs
  -- Reached only for a non-zero count on an exhausted list.
  Drop _ '[] = TypeError ('Text "Drop: index overflow")
-- | Type-level tail: every element but the first. A custom type error for
-- the empty list.
type family Tail (a :: [k]) :: [k] where
  Tail '[] = TypeError ('Text "No tail")
  Tail (_:xs) = xs
-- | Type-level init: every element but the last. A custom type error for
-- the empty list.
type family Init (a :: [k]) :: [k] where
  Init '[] = TypeError ('Text "No init")
  Init '[_] = '[]
  Init (x:xs) = x : Init xs
-- | Type-level head: the first element. A custom type error for the empty
-- list.
type family Head (a :: [k]) :: k where
  Head '[] = TypeError ('Text "No head")
  Head (x:_) = x
-- | Type-level last: the final element. A custom type error for the empty
-- list.
type family Last (a :: [k]) :: k where
  Last '[] = TypeError ('Text "No last")
  Last '[x] = x
  Last (_:xs) = Last xs
-- | Type-level list append.
type family (a :: [k]) ++ (b :: [k]) :: [k] where
  '[] ++ b = b
  (a:as) ++ b = a : (as ++ b)
-- | Multiply two numbers, but return exactly @0@ whenever either factor
-- compares equal to @0@, without performing the multiplication.
--
-- @a@ is tested first; @b@ is only compared when @a /= 0@. For IEEE
-- floating-point types this makes @0 * Infinity@ and @0 * NaN@ (in either
-- order) come out as @0@ instead of NaN.
mult :: (Eq a, Num a) => a -> a -> a
mult a b =
  if a == 0 || b == 0
    then 0
    else a * b
-- | Ready-made proxies for the type-level naturals 0 through 9. They save
-- writing @Proxy :: Proxy n@ at call sites that take a @Proxy n@ argument.
i0 :: Proxy 0
i0 = Proxy

i1 :: Proxy 1
i1 = Proxy

i2 :: Proxy 2
i2 = Proxy

i3 :: Proxy 3
i3 = Proxy

i4 :: Proxy 4
i4 = Proxy

i5 :: Proxy 5
i5 = Proxy

i6 :: Proxy 6
i6 = Proxy

i7 :: Proxy 7
i7 = Proxy

i8 :: Proxy 8
i8 = Proxy

i9 :: Proxy 9
i9 = Proxy