{-# LANGUAGE ConstraintKinds      #-}
{-# LANGUAGE UndecidableInstances #-}
module Data.Tensor.Type where

import           Data.List      (foldl')
import           Data.Proxy
import           Data.Type.Bool hiding (If)
import           GHC.Exts
import           GHC.TypeLits
import qualified GHC.TypeLits   as L
import           Unsafe.Coerce

-- | Term-level tensor shape: one 'Int' extent per dimension, as wrapped by
-- 'SShape' and produced by 'toShape'.
type Shape = [Int]
-- | Term-level multi-dimensional position: one 'Int' component per dimension
-- (see 'viToti' and 'TensorIndex').
type Index = [Int]

-- | Term-level witness of a type-level shape.
--
-- The phantom parameter @shape@ carries the dimensions as a type-level list
-- of 'Nat's; the wrapped 'Shape' holds dimensions as ordinary 'Int's and is
-- retrieved with 'unShape'. The 'HasShape' instances build it from the
-- type-level list via 'natVal'.
newtype SShape (shape :: [Nat]) = SShape { unShape :: Shape } deriving Show

-- | Type-level shapes whose dimensions can be demoted to term level.
class HasShape s where
  -- | The dimensions of @s@ as a term-level 'SShape'.
  toShape :: SShape s
  -- | Number of dimensions of @s@: the length of the demoted shape.
  toRank  :: Proxy s -> Int
  toRank _ = length $ unShape (toShape :: SShape s)
  -- | Number of elements described by @s@: the product of the demoted
  -- dimensions (@1@ for the empty shape, since @product [] == 1@).
  toSize  :: Proxy s -> Int
  toSize _ = product $ unShape (toShape :: SShape s)

-- | The empty shape demotes to the empty list of dimensions.
instance HasShape '[] where
  toShape = SShape []

-- | A non-empty shape demotes to the value of its head 'Nat' (via 'natVal')
-- followed by the demoted tail.
instance (KnownNat n, HasShape s) => HasShape (n:s) where
  toShape =
    SShape $
      fromInteger (natVal (Proxy :: Proxy n)) : unShape (toShape :: SShape s)

-- | Demote a type-level natural to an 'Int'.
--
-- The 'Integer' returned by 'natVal' is converted with 'fromInteger' rather
-- than 'unsafeCoerce': 'Integer' and 'Int' are distinct types with different
-- representations, so coercing between them is not sound. Values outside the
-- 'Int' range wrap around, as usual for 'fromInteger'.
toNat :: KnownNat n => Proxy n -> Int
toNat = fromInteger . natVal

-- | Convert a flat (vector) offset into a multi-dimensional index for the
-- given shape, with the last dimension varying fastest.
--
-- The offset is repeatedly 'divMod'ed by the dimensions from last to first:
-- each remainder becomes the index component for that dimension and the
-- quotient is carried on to the next one. Whatever quotient is left after the
-- first dimension is discarded. A dimension of @0@ makes 'divMod' throw.
--
-- > viToti [2,3] 4 == [1,1]
viToti :: Shape -> Int -> Index
viToti s i = snd $ foldl' step (i, []) (reverse s)
  where
    -- Peel one dimension off the running offset, consing its component onto
    -- the front so the result ends up in first-to-last dimension order.
    step (r, xs) si =
      let (r', x) = r `divMod` si
      in  (r', x : xs)

-- | Convert a multi-dimensional index into a flat (vector) offset for the
-- given shape, with the last dimension varying fastest.
--
-- Folds over the (dimension, component) pairs from first to last, computing
-- @offset * dimension + component@ at each step. If the shape and the index
-- differ in length, the surplus elements of the longer list are ignored
-- (the pairs come from 'zip').
--
-- > tiTovi [2,3] [1,1] == 4
tiTovi :: Shape -> Index -> Int
tiTovi s i = foldl' (\b (n, ind) -> b * n + ind) 0 (zip s i)

-- | Tensor index, used to locate a single element of a tensor.
--
-- Wraps a term-level 'Index'; the phantom @shape@ parameter records, at the
-- type level, the shape the index is meant for. Equality and ordering are
-- those of the underlying list.
newtype TensorIndex (shape :: [Nat]) = TensorIndex Index deriving (Eq,Show,Ord)

-- | List syntax for tensor indices, validated against the type-level shape.
--
-- 'fromList' calls 'error' with @"length not match"@ when the list does not
-- have one component per dimension, and with @"index overflow"@ when any
-- component is negative or not smaller than its dimension. The length check
-- runs first. 'toList' returns the wrapped components unchanged.
instance HasShape s => IsList (TensorIndex s) where
  type Item (TensorIndex s) = Int
  fromList v
    | length v /= length dims        = error "length not match"
    | or (zipWith outOfRange v dims) = error "index overflow"
    | otherwise                      = TensorIndex v
    where
      -- Term-level dimensions of the shape this index is typed against.
      dims = unShape (toShape :: SShape s)
      -- A component is invalid if it falls outside [0, dimension).
      outOfRange i n = i < 0 || i >= n
  toList (TensorIndex v) = v

-- | Rank of a shape: the number of dimensions in the type-level list.
type family TensorRank (s :: [Nat]) :: Nat where
  TensorRank '[]         = 0
  TensorRank (_ ': rest) = TensorRank rest + 1

-- | Size of a shape: the product of all its dimensions, i.e. the number of
-- elements a tensor of that shape holds. The empty shape has size @1@.
--
-- The recursive case multiplies the head dimension by the /size/ of the tail
-- (not its rank), so that e.g. @TensorSize '[2,3,4]@ reduces to @24@.
type family TensorSize (s :: [Nat]) :: Nat where
  TensorSize '[] = 1
  TensorSize (n:s) = n L.* TensorSize s

-- | Reverse the first list onto the second: @Reverse xs acc@ is the elements
-- of @xs@ in reverse order followed by @acc@. Use @Reverse xs '[]@ for a
-- plain reversal.
type family Reverse (a :: [k]) (b :: [k]) :: [k] where
  Reverse '[]         acc = acc
  Reverse (x ': rest) acc = Reverse rest (x ': acc)

-- | Type-level conditional: selects the second argument when the condition is
-- @'True@ and the third when it is @'False@. This is the module's own
-- definition; the @If@ from "Data.Type.Bool" is hidden on import.
type family If (b :: Bool) c d where
  If 'True  onTrue _       = onTrue
  If 'False _      onFalse = onFalse

-- | A type-level list containing @dim@ copies of @a@.
--
-- The recursive case decrements the count, so reduction terminates at the
-- @0@ equation.
type family Replicate (a :: k) (dim :: Nat) :: [k] where
  Replicate a 0 = '[]
  Replicate a n = a : Replicate a (n-1)

-- | The @i@-th (zero-based) dimension of a shape. Running off the end of the
-- list produces the compile-time error @"Index overflow"@.
type family Dimension (s :: [Nat]) (i :: Nat) :: Nat where
  Dimension (d ': _)    0 = d
  Dimension (_ ': rest) i = Dimension rest (i - 1)
  Dimension _           _ = TypeError ('Text "Index overflow")

-- | @'True@ when @dim@ is a valid zero-based dimension number for shape @s@,
-- i.e. @dim + 1 <= TensorRank s@.
-- NOTE(review): unlike 'CheckIndices' this has no @~ 'True@, so it is a
-- type-level 'Bool' rather than a constraint — presumably callers combine it
-- with other checks and add the equality themselves; confirm.
type CheckDimension dim s = IsIndex dim (TensorRank s)
-- | Constraint requiring @i + 1 <= j@ and @j + 1 <= TensorRank s@.
type CheckIndices i j s = IsIndices i j (TensorRank s) ~ 'True

-- | @'True@ when @0 <= i@ and @i + 1 <= n@.
type IsIndex i n = (0 <=? i) && (i + 1 <=? n)
-- | @'True@ when @0 <= i@, @i + 1 <= j@ and @j + 1 <= n@.
type IsIndices i j n = (0 <=? i) && (i + 1 <=? j) && (j + 1 <=? n)

-- | The first @n@ elements of a type-level list. There is no equation for
-- taking a positive number of elements from the empty list, so that case
-- does not reduce.
type family Take (n :: Nat) (a :: [k]) :: [k] where
  Take 0 _           = '[]
  Take n (x ': rest) = x ': Take (n - 1) rest

-- | A type-level list with its first @n@ elements removed. There is no
-- equation for dropping a positive number of elements from the empty list,
-- so that case does not reduce.
--
-- The recursive case recurses into 'Drop' itself, so that e.g.
-- @Drop 2 '[1,2,3,4]@ reduces to @'[3,4]@.
type family Drop (n :: Nat) (a :: [k]) :: [k] where
  Drop 0 xs = xs
  Drop n (_:xs) = Drop (n-1) xs

-- | Everything but the first element of a type-level list. The empty list
-- produces the compile-time error @"No tail"@.
type family Tail (a :: [k]) :: [k] where
  Tail '[]         = TypeError ('Text "No tail")
  Tail (_ ': rest) = rest

-- | Everything but the last element of a type-level list. The empty list
-- produces the compile-time error @"No init"@.
type family Init (a :: [k]) :: [k] where
  Init '[]         = TypeError ('Text "No init")
  Init (_ ': '[])  = '[]
  Init (x ': rest) = x ': Init rest

-- | The first element of a type-level list. The empty list produces the
-- compile-time error @"No head"@.
type family Head (a :: [k]) :: k where
  Head '[]          = TypeError ('Text "No head")
  Head (first ': _) = first

-- | The final element of a type-level list. The empty list produces the
-- compile-time error @"No last"@.
type family Last (a :: [k]) :: k where
  Last '[]            = TypeError ('Text "No last")
  Last (final ': '[]) = final
  Last (_ ': rest)    = Last rest

-- | Concatenation of two type-level lists.
type family (a :: [k]) ++ (b :: [k]) :: [k] where
  '[]         ++ ys = ys
  (x ': rest) ++ ys = x ': (rest ++ ys)

-- | Multiplication that short-circuits on zero.
--
-- Returns @0@ as soon as either operand equals @0@ (the first operand is
-- checked before the second) and only otherwise evaluates @a * b@. For
-- floating-point types this means a zero operand yields @0@ even when the
-- other operand is infinite or NaN.
mult :: (Eq a, Num a) => a -> a -> a
mult a b
  | a == 0 || b == 0 = 0
  | otherwise        = a * b
-----------------------
-- Tensor Type Index
-----------------------
-- | Ready-made proxies for the type-level naturals 0 through 9, for passing
-- a type-level index as a value argument.
i0 :: Proxy 0
i0 = Proxy

i1 :: Proxy 1
i1 = Proxy

i2 :: Proxy 2
i2 = Proxy

i3 :: Proxy 3
i3 = Proxy

i4 :: Proxy 4
i4 = Proxy

i5 :: Proxy 5
i5 = Proxy

i6 :: Proxy 6
i6 = Proxy

i7 :: Proxy 7
i7 = Proxy

i8 :: Proxy 8
i8 = Proxy

i9 :: Proxy 9
i9 = Proxy