{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wall #-}
module NumHask.Array.Shape where
import Data.List ((!!))
import Data.Type.Bool
import GHC.TypeLits as L
import NumHask.Prelude as P hiding (Last, minimum)
-- | A value-level list of dimensions, tagged with the phantom type-level
-- dimension list @s@.
newtype Shape (s :: [Nat]) = Shape {shapeVal :: [Int]}
  deriving (Show)

-- | Types whose dimension list can be reflected into a 'Shape' value.
class HasShape s where
  -- | The reflected shape.
  toShape :: Shape s

-- | The empty dimension list reflects to an empty shape.
instance HasShape '[] where
  toShape = Shape []

-- | Reflect the leading dimension @n@ and prepend it to the reflection of
-- the remaining dimensions @s@.
instance (KnownNat n, HasShape s) => HasShape (n : s) where
  toShape = Shape (leading : trailing)
    where
      leading = fromInteger (natVal (Proxy @n))
      trailing = shapeVal (toShape @s)
-- | Number of dimensions of a shape, i.e. the length of the list.
--
-- >>> rank [2, 3, 4]
-- 3
rank :: [a] -> Int
rank dims = length dims
{-# INLINE rank #-}
-- | Type-level 'rank': the length of a type-level list.
type family Rank (s :: [a]) :: Nat where
  Rank '[] = 0
  Rank (_ : rest) = Rank rest + 1
-- | The 'rank' of every shape in a list of shapes.
--
-- >>> ranks [[2, 3], [4]]
-- [2,1]
ranks :: [[a]] -> [Int]
ranks shapes = [rank s | s <- shapes]
{-# INLINE ranks #-}
-- | Type-level 'ranks': the 'Rank' of each list in a type-level list of lists.
type family Ranks (s :: [[a]]) :: [Nat] where
  Ranks '[] = '[]
  Ranks (first : rest) = Rank first : Ranks rest
-- | Total number of elements for a shape, i.e. the product of its
-- dimensions. The empty (scalar) shape has size 1.
--
-- >>> size [2, 3, 4]
-- 24
size :: [Int] -> Int
size dims = case dims of
  [] -> 1
  [d] -> d
  _ -> P.product dims
{-# INLINE size #-}
-- | Type-level 'size': the product of the dimensions, with @Size '[] = 1@.
type family Size (s :: [Nat]) :: Nat where
  Size '[] = 1
  Size (d : ds) = d L.* Size ds
-- | Convert a multi-dimensional index into a flat offset.
--
-- @flatten ns xs@ takes a shape @ns@ and an index @xs@. It sums each index
-- component multiplied by its stride, where a dimension's stride is the
-- product of all dimensions to its right (row-major order).
--
-- >>> flatten [2, 3, 4] [1, 1, 1]
-- 17
flatten :: [Int] -> [Int] -> Int
flatten [] _ = 0
-- A single-component index is returned unchanged (exact for rank-1 shapes).
flatten _ [x'] = x'
flatten ns xs = sum (zipWith (*) xs strides)
  where
    -- scanr yields the suffix products of ns. Dropping the first one (the
    -- total size) leaves one stride per dimension, ending in 1.
    strides = drop 1 (scanr (*) one ns)
{-# INLINE flatten #-}
-- | Inverse of 'flatten': convert a flat offset into a multi-dimensional
-- index for the shape @ns@.
--
-- Components are produced from the innermost (last) dimension outwards.
-- Each one is the remainder of the running offset, and the quotient is
-- carried to the next dimension.
--
-- For an out-of-range offset, the rank-1 and rank-2 shortcuts leave the
-- outermost component unreduced. The general case reduces it modulo its
-- dimension.
--
-- >>> shapen [2, 3, 4] 17
-- [1,1,1]
shapen :: [Int] -> Int -> [Int]
shapen [] _ = []
shapen [_] x' = [x']
shapen [_, y] x' = [q, r]
  where
    (q, r) = x' `divMod` y
shapen ns x = fst (foldr peel ([], x) ns)
  where
    -- Peel one dimension off the remaining offset.
    peel d (idx, rest) =
      let (rest', component) = rest `divMod` d
       in (component : idx, rest')
{-# INLINE shapen #-}
-- | Is @i@ a valid index into a dimension of size @n@, i.e. @0 <= i < n@?
--
-- The upper bound is written as @i < n@ rather than @i + 1 <= n@. At
-- @i = maxBound@ the addition would wrap around to a negative number and
-- wrongly report the index as valid.
checkIndex :: Int -> Int -> Bool
checkIndex i n = 0 <= i && i < n
-- | Type-level 'checkIndex': reduces to 'True when @idx@ is below @bound@,
-- and raises an \"index outside range\" type error otherwise.
type family CheckIndex (i :: Nat) (n :: Nat) :: Bool where
  CheckIndex idx bound =
    If
      ((0 <=? idx) && (idx + 1 <=? bound))
      'True
      (L.TypeError ('Text "index outside range"))
-- | Does every index in the list pass 'checkIndex' for a dimension of
-- size @n@?
checkIndexes :: [Int] -> Int -> Bool
checkIndexes idxs n = all (\idx -> checkIndex idx n) idxs
-- | Type-level 'checkIndexes': the conjunction of 'CheckIndex' over the list.
type family CheckIndexes (i :: [Nat]) (n :: Nat) :: Bool where
  CheckIndexes '[] _ = 'True
  CheckIndexes (idx : rest) bound = CheckIndex idx bound && CheckIndexes rest bound
-- | The size of dimension @i@ (0-based) of a shape.
--
-- Throws @NumHaskException "dimension overflow"@ when @i@ is out of range.
-- That includes a negative @i@, which keeps decrementing until the list
-- runs out.
dimension :: [Int] -> Int -> Int
dimension [] _ = throw (NumHaskException "dimension overflow")
dimension (d : ds) i
  | i == 0 = d
  | otherwise = dimension ds (i - 1)
-- | Type-level 'dimension': the element at position @i@. Raises a
-- \"dimension overflow\" type error when @i@ is out of range.
type family Dimension (s :: [Nat]) (i :: Nat) :: Nat where
  Dimension (d : _) 0 = d
  Dimension (_ : ds) i = Dimension ds (i - 1)
  Dimension _ _ = L.TypeError ('Text "dimension overflow")
-- | The smallest dimension of a shape.
--
-- Throws @NumHaskException "dimension underflow"@ on an empty shape.
minimum :: [Int] -> Int
minimum dims = case dims of
  [] -> throw (NumHaskException "dimension underflow")
  [d] -> d
  (d : ds) -> P.min d (minimum ds)
-- | Type-level 'minimum'. Raises a \"zero dimension\" type error on an
-- empty list.
type family Minimum (s :: [Nat]) :: Nat where
  Minimum '[] = L.TypeError ('Text "zero dimension")
  Minimum '[d] = d
  Minimum (d : ds) = If (d <=? Minimum ds) d (Minimum ds)
-- | The first @n@ elements of a type-level list.
--
-- NOTE(review): no equation covers a non-zero @n@ with an empty list, so
-- taking more elements than exist leaves the family unreduced.
type family Take (n :: Nat) (a :: [k]) :: [k] where
  Take 0 _ = '[]
  Take n (y : ys) = y : Take (n - 1) ys

-- | A type-level list without its first @n@ elements.
--
-- NOTE(review): as with 'Take', a non-zero @n@ on an empty list is left
-- unreduced.
type family Drop (n :: Nat) (a :: [k]) :: [k] where
  Drop 0 ys = ys
  Drop n (_ : ys) = Drop (n - 1) ys

-- | All but the first element. A type error on the empty list.
type family Tail (a :: [k]) :: [k] where
  Tail '[] = L.TypeError ('Text "No tail")
  Tail (_ : ys) = ys

-- | All but the last element. A type error on the empty list.
type family Init (a :: [k]) :: [k] where
  Init '[] = L.TypeError ('Text "No init")
  Init '[_] = '[]
  Init (y : ys) = y : Init ys

-- | The first element. A type error on the empty list.
type family Head (a :: [k]) :: k where
  Head '[] = L.TypeError ('Text "No head")
  Head (y : _) = y

-- | The last element. A type error on the empty list.
type family Last (a :: [k]) :: k where
  Last '[] = L.TypeError ('Text "No last")
  Last '[y] = y
  Last (_ : ys) = Last ys

-- | Append two type-level lists.
type family (a :: [k]) ++ (b :: [k]) :: [k] where
  '[] ++ rest = rest
  (y : ys) ++ rest = y : (ys ++ rest)
-- | Remove the element at position @i@ from @s@.
--
-- >>> dropIndex [2, 3, 4] 1
-- [2,4]
dropIndex :: [Int] -> Int -> [Int]
dropIndex s i = prefix ++ suffix
  where
    prefix = take i s
    suffix = drop (i + 1) s
-- | Type-level 'dropIndex': remove the element at position @i@ from @s@.
type DropIndex s i = (Take i s) ++ (Drop (i + 1) s)
-- | Insert @d@ into @s@ so that it lands at position @i@.
--
-- >>> addIndex [2, 4] 1 3
-- [2,3,4]
addIndex :: [Int] -> Int -> Int -> [Int]
addIndex s i d = prefix ++ (d : suffix)
  where
    prefix = take i s
    suffix = drop i s
-- | Type-level 'addIndex': insert @d@ into @s@ so that it lands at position @i@.
type AddIndex s i d = (Take i s) ++ (d : Drop i s)
-- | Reverse a type-level list.
type Reverse (a :: [k]) = ReverseGo a '[]

-- | Accumulator helper for 'Reverse': moves each element of the first list
-- onto the front of the second.
type family ReverseGo (a :: [k]) (b :: [k]) :: [k] where
  ReverseGo '[] acc = acc
  ReverseGo (y : ys) acc = ReverseGo ys (y : acc)
-- | Rewrite a list of absolute positions so that each one refers to the list
-- left over after removing all earlier positions one at a time.
--
-- >>> posRelative [0, 2]
-- [0,1]
posRelative :: [Int] -> [Int]
posRelative positions = reverse (walk [] positions)
  where
    walk acc [] = acc
    walk acc (p : rest) = walk (p : acc) (shift p <$> rest)
    -- Once position p is removed, every later position at or above p moves
    -- down by one. Positions below p are unaffected.
    shift p q
      | q < p = q
      | otherwise = q - one
-- | Type-level 'posRelative'.
type family PosRelative (s :: [Nat]) where
  PosRelative s = PosRelativeGo s '[]

-- | Worker for 'PosRelative'. The first argument holds the positions still to
-- process and the second accumulates results in reverse.
type family PosRelativeGo (s :: [Nat]) (r :: [Nat]) where
  PosRelativeGo '[] acc = Reverse acc
  PosRelativeGo (p : ps) acc = PosRelativeGo (DecMap p ps) (p : acc)

-- | Decrement every element that is not below @x@. Elements below @x@ are
-- left alone.
type family DecMap (x :: Nat) (ys :: [Nat]) :: [Nat] where
  DecMap _ '[] = '[]
  DecMap x (y : ys) = (If (y + 1 <=? x) y (y - 1)) : DecMap x ys
-- | Remove the elements at all of the absolute positions @i@ from @s@.
--
-- The positions are first made relative with 'posRelative', because each
-- removal shifts the later elements down by one.
--
-- >>> dropIndexes [2, 3, 4, 5] [0, 2]
-- [3,5]
dropIndexes :: [Int] -> [Int] -> [Int]
dropIndexes s = foldl' dropIndex s . posRelative
-- | Type-level 'dropIndexes'.
type family DropIndexes (s :: [Nat]) (i :: [Nat]) where
  DropIndexes shape positions = DropIndexesGo shape (PosRelative positions)

-- | Drop each already-relative position in turn, from left to right.
type family DropIndexesGo (s :: [Nat]) (i :: [Nat]) where
  DropIndexesGo shape '[] = shape
  DropIndexesGo shape (p : ps) = DropIndexesGo (DropIndex shape p) ps
-- | Insert the dimensions @ys@ at the positions @xs@ into @as@.
--
-- For example, @addIndexes [3] [0, 2] [2, 4] == [2, 3, 4]@, which undoes
-- @dropIndexes [2, 3, 4] [0, 2]@.
--
-- Throws @NumHaskException "mismatched ranks"@ if @ys@ runs out before
-- @xs@. Any extra elements of @ys@ are ignored.
addIndexes :: [Int] -> [Int] -> [Int] -> [Int]
addIndexes as xs = insertAll as (reverse (posRelative (reverse xs)))
  where
    insertAll acc [] _ = acc
    insertAll acc (p : ps) (d : ds) = insertAll (addIndex acc p d) ps ds
    insertAll _ _ _ = throw (NumHaskException "mismatched ranks")
-- | Type-level 'addIndexes'.
type family AddIndexes (as :: [Nat]) (xs :: [Nat]) (ys :: [Nat]) where
  AddIndexes shape positions dims =
    AddIndexesGo shape (Reverse (PosRelative (Reverse positions))) dims

-- | Insert each dimension at its position in turn. Raises a
-- \"mismatched ranks\" type error if the dimensions run out first.
type family AddIndexesGo (as :: [Nat]) (xs :: [Nat]) (ys :: [Nat]) where
  AddIndexesGo shape '[] _ = shape
  AddIndexesGo shape (p : ps) (d : ds) = AddIndexesGo (AddIndex shape p d) ps ds
  AddIndexesGo _ _ _ = L.TypeError ('Text "mismatched ranks")
-- | Select the elements of @s@ at each position in @i@, in the order given.
--
-- Partial: an out-of-range position fails inside '!!'.
--
-- >>> takeIndexes [2, 3, 4] [2, 0]
-- [4,2]
takeIndexes :: [Int] -> [Int] -> [Int]
takeIndexes s i = [s !! j | j <- i]
-- | Type-level 'takeIndexes'.
--
-- Unlike the value-level version, an empty @s@ yields @'[]@ for any
-- positions.
type family TakeIndexes (s :: [Nat]) (i :: [Nat]) where
  TakeIndexes '[] _ = '[]
  TakeIndexes _ '[] = '[]
  TakeIndexes s (p : ps) = (s !! p) ': TakeIndexes s ps

-- | The element at a position of a type-level list. Raises an
-- \"Index Underflow\" type error when the list runs out before the position
-- is reached.
type family (a :: [k]) !! (b :: Nat) :: k where
  '[] !! _ = L.TypeError ('Text "Index Underflow")
  (y : _) !! 0 = y
  (_ : ys) !! n = ys !! (n - 1)
-- | The type-level list @[0, 1 .. n - 1]@ (ascending).
type family Enumerate (n :: Nat) where
  Enumerate n = Reverse (EnumerateGo n)

-- | The type-level list @[n - 1, n - 2 .. 0]@ (descending).
type family EnumerateGo (n :: Nat) where
  EnumerateGo 0 = '[]
  EnumerateGo k = (k - 1) : EnumerateGo (k - 1)
-- | The axes @[0 .. r - 1]@ of a rank-@r@ shape, minus those listed in the
-- second argument.
--
-- >>> exclude 3 [1]
-- [0,2]
exclude :: Int -> [Int] -> [Int]
exclude r axes = dropIndexes [0 .. r - 1] axes
-- | Type-level 'exclude': the axes @[0 .. r - 1]@ of a rank-@r@ shape,
-- minus those listed in @i@.
--
-- Built from the ascending 'Enumerate', not the descending 'EnumerateGo'.
-- This keeps the result in the same order as the value-level
-- @exclude r i@, which drops from @[0 .. r - 1]@. For example,
-- @Exclude 3 '[1]@ is @'[0, 2]@, matching @exclude 3 [1] == [0, 2]@.
type family Exclude (r :: Nat) (i :: [Nat]) where
  Exclude r i = DropIndexes (Enumerate r) i
-- | The shape produced by joining shapes @s0@ and @s1@ along axis @i@.
--
-- Dimension @i@ becomes the sum of both shapes' dimension @i@. All other
-- dimensions are taken from @s0@.
--
-- >>> concatenate' 0 [2, 3] [4, 3]
-- [6,3]
concatenate' :: Int -> [Int] -> [Int] -> [Int]
concatenate' i s0 s1 = take i s0 ++ (joined : drop (i + 1) s0)
  where
    joined = dimension s0 i + dimension s1 i
-- | Type-level 'concatenate''. Dimension @i@ is the sum of both shapes'
-- dimension @i@, and all other dimensions come from @s0@.
type Concatenate i s0 s1 =
  Take i s0 ++ ((Dimension s0 i + Dimension s1 i) : Drop (i + 1) s0)

-- | Constraint that @s0@ and @s1@ can be concatenated along axis @i@: @i@ is
-- a valid axis of @s0@, the shapes agree once axis @i@ is removed, and they
-- have the same rank. The final parameter @s@ is not used.
type CheckConcatenate i s0 s1 s =
  ( CheckIndex i (Rank s0)
      && (DropIndex s0 i == DropIndex s1 i)
      && (Rank s0 == Rank s1)
  )
    ~ 'True

-- | Constraint that @d@ is a valid axis of @s@ and @i@ is a valid index
-- along that axis.
type CheckInsert d i s =
  (CheckIndex d (Rank s) && CheckIndex i (Dimension s d)) ~ 'True

-- | The shape @s@ with dimension @d@ grown by one (type-level 'incAt').
type Insert d s = Take d s ++ ((Dimension s d + 1) : Drop (d + 1) s)
-- | Increase dimension @d@ of shape @s@ by one.
incAt :: Int -> [Int] -> [Int]
incAt d s = take d s ++ (grown : drop (d + 1) s)
  where
    grown = dimension s d + 1
-- | Decrease dimension @d@ of shape @s@ by one.
decAt :: Int -> [Int] -> [Int]
decAt d s = take d s ++ (shrunk : drop (d + 1) s)
  where
    shrunk = dimension s d - 1
-- | Pick the dimensions of @s@ at each of the given axes, in the order given.
-- An empty shape always yields @[]@.
--
-- >>> reorder' [2, 3, 4] [2, 0, 1]
-- [4,2,3]
reorder' :: [Int] -> [Int] -> [Int]
reorder' [] _ = []
reorder' s axes = [dimension s d | d <- axes]
-- | Type-level 'reorder'': the dimensions of @s@ at each position in @ds@.
type family Reorder (s :: [Nat]) (ds :: [Nat]) :: [Nat] where
  Reorder '[] _ = '[]
  Reorder _ '[] = '[]
  Reorder s (d : rest) = Dimension s d : Reorder s rest

-- | Constraint that @ds@ has the same rank as @s@ and that every entry is a
-- valid axis below that rank. Duplicate entries are not rejected.
-- A failed rank comparison raises a \"bad dimensions\" type error.
type family CheckReorder (ds :: [Nat]) (s :: [Nat]) where
  CheckReorder ds s =
    ( If
        (Rank ds == Rank s && CheckIndexes ds (Rank s))
        'True
        (L.TypeError ('Text "bad dimensions"))
    )
      ~ 'True
-- | Remove every unit (size-1) dimension from a shape.
--
-- >>> squeeze' [1, 2, 1, 3]
-- [2,3]
squeeze' :: (Eq a, Num a) => [a] -> [a]
squeeze' dims = [d | d <- dims, d /= 1]
-- | Type-level 'squeeze'': remove every dimension equal to 1.
type family Squeeze (a :: [Nat]) where
  Squeeze '[] = '[]
  Squeeze dims = Filter '[] dims 1

-- | Remove every occurrence of @i@ from @xs@. Survivors are accumulated in
-- reverse in @r@, which is reversed once @xs@ is exhausted.
type family Filter (r :: [Nat]) (xs :: [Nat]) (i :: Nat) where
  Filter acc '[] _ = Reverse acc
  Filter acc (y : ys) i = Filter (If (y == i) acc (y : acc)) ys i
-- | Type-level quicksort ordered by 'Cmp'.
--
-- The head is the pivot. Elements comparing 'LT' to it go to its left, and
-- elements comparing 'GT' or 'EQ' go to its right.
type family Sort (xs :: [k]) :: [k] where
  Sort '[] = '[]
  Sort (pivot ': rest) =
    (Sort (SFilter 'FMin pivot rest) ++ '[pivot]) ++ Sort (SFilter 'FMax pivot rest)

-- | Which side of the pivot 'SFilter' keeps.
data Flag = FMin | FMax

-- | Comparison used by 'Sort'. This is an open family with no instances
-- alongside this declaration. Presumably they are supplied elsewhere.
type family Cmp (a :: k) (b :: k) :: Ordering

-- | Keep the elements strictly below the pivot ('FMin') or at or above it
-- ('FMax').
type family SFilter (f :: Flag) (p :: k) (xs :: [k]) :: [k] where
  SFilter _ _ '[] = '[]
  SFilter 'FMin p (y ': ys) =
    If (Cmp y p == 'LT) (y ': SFilter 'FMin p ys) (SFilter 'FMin p ys)
  SFilter 'FMax p (y ': ys) =
    If (Cmp y p == 'GT || Cmp y p == 'EQ) (y ': SFilter 'FMax p ys) (SFilter 'FMax p ys)
-- | Pair up two type-level lists element-wise, truncating to the shorter one.
type family Zip lst lst' where
  Zip xs ys = ZipWith '(,) xs ys

-- | Combine two type-level lists element-wise with @f@, truncating to the
-- shorter one.
type family ZipWith f lst lst' where
  ZipWith _ '[] _ = '[]
  ZipWith _ _ '[] = '[]
  ZipWith f (x ': xs) (y ': ys) = f x y ': ZipWith f xs ys

-- | First component of a promoted pair.
type family Fst a where
  Fst '(x, _) = x

-- | Second component of a promoted pair.
type family Snd a where
  Snd '(_, y) = y

-- | Apply @f@ to every element of a type-level list.
type family FMap f lst where
  FMap _ '[] = '[]
  FMap f (x ': xs) = f x ': FMap f xs
-- | Reflect a type-level list of naturals to a list of 'Int's.
class KnownNats (ns :: [Nat]) where
  natVals :: Proxy ns -> [Int]

instance KnownNats '[] where
  natVals _ = []

instance (KnownNat n, KnownNats ns) => KnownNats (n : ns) where
  natVals _ = headVal : natVals (Proxy @ns)
    where
      headVal = fromInteger (natVal (Proxy @n))

-- | Reflect a type-level list of lists of naturals, applying 'natVals' to
-- each inner list.
class KnownNatss (ns :: [[Nat]]) where
  natValss :: Proxy ns -> [[Int]]

instance KnownNatss '[] where
  natValss _ = []

instance (KnownNats n, KnownNatss ns) => KnownNatss (n : ns) where
  natValss _ = natVals (Proxy @n) : natValss (Proxy @ns)