{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE UndecidableInstances #-}
module Data.Tensor.Type where
import Data.List (foldl')
import Data.Proxy
import Data.Type.Bool hiding (If)
import GHC.Exts
import GHC.TypeLits
import qualified GHC.TypeLits as L
import Unsafe.Coerce
-- | Runtime tensor shape: one size per axis, outermost axis first.
type Shape = [Int]

-- | Runtime multi-dimensional index: one position per axis, in the same
-- axis order as 'Shape'.
type Index = [Int]

-- | A runtime 'Shape' tagged with the type-level shape it describes.
--
-- The constructor performs no check that 'unShape' agrees with the phantom
-- @shape@; 'toShape' is the way to obtain a shape derived from the type.
newtype SShape (shape :: [Nat]) = SShape
  { unShape :: Shape
  } deriving Show

-- | Type-level shapes whose runtime shape can be reflected.
--
-- Instances only need 'toShape'; 'toRank' and 'toSize' default to values
-- computed from it.
class HasShape s where
  -- | The runtime shape corresponding to @s@.
  toShape :: SShape s

  -- | Number of axes (length of the shape).
  toRank :: Proxy s -> Int
  toRank _ =
    let SShape dims = toShape :: SShape s
    in length dims

  -- | Number of elements: the product of all axis sizes (1 for @'[]@).
  toSize :: Proxy s -> Int
  toSize _ =
    let SShape dims = toShape :: SShape s
    in product dims

-- | The scalar shape has no axes.
instance HasShape '[] where
  toShape = SShape []

-- | A non-empty shape: reflect the leading size @n@ and prepend it to the
-- reflected shape of the remaining axes.
instance (KnownNat n, HasShape s) => HasShape (n : s) where
  toShape =
    let SShape rest = toShape :: SShape s
        size = fromInteger (natVal (Proxy :: Proxy n))
    in SShape (size : rest)
-- | Reflect a type-level natural as a runtime 'Int'.
--
-- Converts with 'fromInteger' rather than reinterpreting the 'Integer'
-- returned by 'natVal' via 'unsafeCoerce': 'Integer' and 'Int' have
-- different runtime representations, so that coercion has no defined
-- meaning. Naturals larger than 'maxBound' wrap around, following
-- 'fromInteger' for 'Int'.
toNat :: KnownNat n => Proxy n -> Int
toNat = fromInteger . natVal
-- | Convert a flat offset into a multi-dimensional index for the given shape.
-- The last axis varies fastest (row-major), so for offsets in
-- @[0, product shape)@ this inverts 'tiTovi'.
--
-- No range check is made. The quotient left over after the outermost axis is
-- discarded, so out-of-range offsets wrap (e.g. @viToti [2,3] 7 == [0,1]@).
-- Because 'divMod' is used, negative offsets also wrap. A zero-sized axis
-- raises a division-by-zero error when the affected component is forced.
viToti :: Shape -> Int -> Index
viToti s i = snd (foldl' peel (i, []) (reverse s))
  where
    -- Split off the position along one axis (innermost first) and carry the
    -- remaining quotient outward.
    peel (remaining, idx) dim =
      let (outer, here) = remaining `divMod` dim
      in (outer, here : idx)
-- | Convert a multi-dimensional index into a flat offset for the given shape.
-- The last axis varies fastest (row-major), matching 'viToti'.
--
-- Positions are not validated. If the shape and index differ in length, the
-- surplus elements of the longer list are ignored. The size of the first
-- (outermost) axis never affects the result.
tiTovi :: Shape -> Index -> Int
tiTovi s i = foldl' step 0 (zip s i)
  where
    step acc (dim, pos) = acc * dim + pos
-- | A multi-dimensional index into a tensor of type-level shape @shape@.
--
-- The constructor accepts any 'Index'. Within this module, the index is
-- checked against @shape@ only by the 'IsList' instance's 'fromList'.
newtype TensorIndex (shape :: [Nat]) = TensorIndex Index
  deriving (Eq, Show, Ord)
-- | List conversion for 'TensorIndex'.
--
-- 'fromList' checks the list against the shape @s@. It must have one entry
-- per axis, and each entry must lie in @[0, size of that axis)@. A violation
-- produces an 'error' value, raised when the result is forced.
instance HasShape s => IsList (TensorIndex s) where
  type Item (TensorIndex s) = Int

  fromList v
    | length v /= length dims = error "length not match"
    | any outOfRange (zip v dims) = error "index overflow"
    | otherwise = TensorIndex v
    where
      dims = unShape (toShape :: SShape s)
      outOfRange (pos, size) = pos < 0 || pos >= size

  toList (TensorIndex v) = v
-- | Number of axes of a type-level shape.
type family TensorRank (s :: [Nat]) :: Nat where
  TensorRank '[] = 0
  TensorRank (_ : rest) = TensorRank rest + 1
-- | Number of elements in a tensor of the given type-level shape: the product
-- of all axis sizes, 1 for the scalar shape @'[]@.
--
-- The recursive case multiplies by 'TensorSize' of the remaining axes.
-- Multiplying by 'TensorRank' instead would give e.g. 4 rather than 24 for
-- @'[2, 3, 4]@.
type family TensorSize (s :: [Nat]) :: Nat where
  TensorSize '[] = 1
  TensorSize (n : s) = n L.* TensorSize s
-- | @Reverse xs acc@ pushes the elements of @xs@ onto @acc@ one at a time,
-- so they end up in reverse order in front of @acc@. @Reverse xs '[]@
-- reverses @xs@.
type family Reverse (a :: [k]) (b :: [k]) :: [k] where
  Reverse '[] acc = acc
  Reverse (x : xs) acc = Reverse xs (x : acc)
-- | Type-level conditional: @If b t e@ reduces to @t@ when @b@ is @'True@
-- and to @e@ when it is @'False@.
type family If (b :: Bool) c d where
  If 'True t _ = t
  If 'False _ e = e
-- | @Replicate a n@ is the type-level list made of @n@ copies of @a@.
--
-- The recursive case decrements @n@ so that reduction reaches the @0@
-- equation. Without the decrement it loops until GHC's reduction-depth
-- limit is hit.
type family Replicate (a :: k) (dim :: Nat) :: [k] where
  Replicate a 0 = '[]
  Replicate a n = a : Replicate a (n - 1)
-- | Size of axis @i@ (0-based) of a type-level shape. Reports the type error
-- "Index overflow" when @i@ is not below the number of axes.
type family Dimension (s :: [Nat]) (i :: Nat) :: Nat where
  Dimension (d : _) 0 = d
  Dimension (_ : ds) k = Dimension ds (k - 1)
  Dimension _ _ = TypeError ('Text "Index overflow")
-- | Type-level 'Bool': whether @axis@ is a valid axis number for @shape@.
-- This is a 'Bool', not a constraint; @CheckDimension axis shape ~ 'True@
-- turns it into one.
type CheckDimension axis shape = IsIndex axis (TensorRank shape)

-- | Constraint: @lo < hi@, and both are valid axis numbers for @shape@.
type CheckIndices lo hi shape = IsIndices lo hi (TensorRank shape) ~ 'True
-- | @IsIndex pos len@: whether @0 <= pos < len@, as a type-level 'Bool'.
type IsIndex pos len = (0 <=? pos) && (pos + 1 <=? len)

-- | @IsIndices lo hi len@: whether @0 <= lo < hi < len@, as a type-level
-- 'Bool'. The parentheses spell out the right-associativity of '&&'.
type IsIndices lo hi len = (0 <=? lo) && ((lo + 1 <=? hi) && (hi + 1 <=? len))
-- | The first @n@ elements of a type-level list. There is no equation for a
-- list shorter than @n@, so such a use stays unreduced.
type family Take (n :: Nat) (a :: [k]) :: [k] where
  Take 0 _ = '[]
  Take k (y : ys) = y : Take (k - 1) ys
-- | A type-level list with its first @n@ elements removed. There is no
-- equation for a list shorter than @n@, so such a use stays unreduced.
--
-- The recursive case must recurse on 'Drop'. Recursing on 'Take' returns the
-- wrong elements, e.g. @'[]@ instead of @'[2, 3]@ for @Drop 1 '[1, 2, 3]@.
type family Drop (n :: Nat) (a :: [k]) :: [k] where
  Drop 0 xs = xs
  Drop n (_ : xs) = Drop (n - 1) xs
-- | All but the first element of a type-level list. Reports the type error
-- "No tail" on the empty list.
type family Tail (a :: [k]) :: [k] where
  Tail '[] = TypeError ('Text "No tail")
  Tail (_ : rest) = rest
-- | All but the last element of a type-level list. Reports the type error
-- "No init" on the empty list. The equation order matters: the singleton
-- case is matched before the general cons case.
type family Init (a :: [k]) :: [k] where
  Init '[] = TypeError ('Text "No init")
  Init '[_] = '[]
  Init (y : ys) = y : Init ys
-- | First element of a type-level list. Reports the type error "No head" on
-- the empty list.
type family Head (a :: [k]) :: k where
  Head '[] = TypeError ('Text "No head")
  Head (first : _) = first
-- | Last element of a type-level list. Reports the type error "No last" on
-- the empty list. The singleton equation is tried before the general one.
type family Last (a :: [k]) :: k where
  Last '[] = TypeError ('Text "No last")
  Last '[final] = final
  Last (_ : rest) = Last rest
-- | Type-level list concatenation.
type family (a :: [k]) ++ (b :: [k]) :: [k] where
  '[] ++ ys = ys
  (x : xs) ++ ys = x : (xs ++ ys)
-- | Multiplication that returns @0@ as soon as either operand equals 0.
--
-- The first operand is tested before the second, and the second is never
-- evaluated when the first is 0. So @mult 0 x@ is @0@ even when @x@ is
-- undefined or NaN, unlike @0 * x@.
mult :: (Eq a, Num a) => a -> a -> a
mult a b
  | isZero a || isZero b = 0
  | otherwise = a * b
  where
    isZero x = x == 0
-- | Ready-made proxies for the type-level naturals 0..9 (e.g. axis numbers),
-- so call sites can write @i2@ instead of @Proxy :: Proxy 2@.
i0 :: Proxy 0
i0 = Proxy

i1 :: Proxy 1
i1 = Proxy

i2 :: Proxy 2
i2 = Proxy

i3 :: Proxy 3
i3 = Proxy

i4 :: Proxy 4
i4 = Proxy

i5 :: Proxy 5
i5 = Proxy

i6 :: Proxy 6
i6 = Proxy

i7 :: Proxy 7
i7 = Proxy

i8 :: Proxy 8
i8 = Proxy

i9 :: Proxy 9
i9 = Proxy