{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE NoStarIsType #-}
{-# OPTIONS_GHC -Wall #-}
module NumHask.Array.Shape
( Shape (..),
HasShape (..),
type (++),
type (!!),
Take,
Drop,
Reverse,
ReverseGo,
Filter,
rank,
Rank,
ranks,
Ranks,
size,
Size,
dimension,
Dimension,
flatten,
shapen,
minimum,
Minimum,
checkIndex,
CheckIndex,
checkIndexes,
CheckIndexes,
addIndex,
AddIndex,
dropIndex,
DropIndex,
posRelative,
PosRelative,
PosRelativeGo,
addIndexes,
AddIndexes,
AddIndexesGo,
dropIndexes,
DropIndexes,
takeIndexes,
TakeIndexes,
exclude,
Exclude,
concatenate',
Concatenate,
CheckConcatenate,
Insert,
CheckInsert,
reorder',
Reorder,
CheckReorder,
squeeze',
Squeeze,
incAt,
decAt,
KnownNats (..),
KnownNatss (..),
)
where
import Data.List ((!!))
import Data.Type.Bool
import GHC.TypeLits as L
import NumHask.Prelude as P hiding (Last, minimum)
-- | A shape: a value-level list of dimension sizes tagged with a phantom
-- type-level list of 'Nat's.
newtype Shape (s :: [Nat]) = Shape
  { -- | The dimension sizes as a plain list.
    shapeVal :: [Int]
  }
  deriving (Show)
-- | Type-level shapes that can be reflected down to a value-level 'Shape'.
class HasShape s where
  -- | The value-level shape corresponding to the type-level list @s@.
  toShape :: Shape s

-- | The empty type-level list reflects to the empty shape.
instance HasShape '[] where
  toShape = Shape []

-- | Reflect the head 'Nat' with 'natVal' and cons it onto the reflected tail.
instance (KnownNat n, HasShape s) => HasShape (n ': s) where
  toShape =
    Shape
      ( fromInteger (natVal (Proxy @n))
          : shapeVal (toShape :: Shape s)
      )
-- | Number of dimensions, i.e. the length of the list.
rank :: [a] -> Int
rank dims = length dims
{-# INLINE rank #-}
-- | Type-level counterpart of 'rank': the length of a type-level list.
type family Rank (s :: [a]) :: Nat where
  Rank '[] = 0
  Rank (_ ': rest) = Rank rest + 1
-- | The 'rank' of every list in a list of lists.
ranks :: [[a]] -> [Int]
ranks shapes = [rank s | s <- shapes]
{-# INLINE ranks #-}
-- | Type-level counterpart of 'ranks': applies 'Rank' to every element.
type family Ranks (s :: [[a]]) :: [Nat] where
  Ranks '[] = '[]
  Ranks (l ': ls) = Rank l ': Ranks ls
-- | Total number of elements described by a shape: the product of its
-- dimensions, with the empty shape having size 1.
size :: [Int] -> Int
size dims = case dims of
  [] -> 1
  [d] -> d
  _ -> P.product dims
{-# INLINE size #-}
-- | Type-level counterpart of 'size': the product of the dimensions.
type family Size (s :: [Nat]) :: Nat where
  Size '[] = 1
  Size (d ': ds) = d L.* Size ds
-- | Convert a multi-dimensional index @xs@ into a flat position for shape @ns@.
--
-- Each index component is weighted by the product of the dimensions to its
-- right, so the last dimension varies fastest.
--
-- An empty shape gives 0. A single-element index is returned unchanged,
-- whatever the shape is.
flatten :: [Int] -> [Int] -> Int
flatten ns xs = case (ns, xs) of
  ([], _) -> 0
  (_, [i]) -> i
  _ -> sum (zipWith (*) xs strides)
  where
    -- strides: product of all dimensions to the right of each position
    strides = drop 1 (scanr (*) one ns)
{-# INLINE flatten #-}
-- | Convert a flat position into a multi-dimensional index for shape @ns@.
--
-- Rank 0 gives an empty index and rank 1 returns the position unchanged.
-- Rank 2 is a single 'divMod' by the second dimension. Otherwise the
-- dimensions are consumed right to left, taking the remainder as the index
-- component and carrying the quotient on.
shapen :: [Int] -> Int -> [Int]
shapen ns x = case ns of
  [] -> []
  [_] -> [x]
  [_, cols] -> let (q, r) = divMod x cols in [q, r]
  _ -> fst (foldr step ([], x) ns)
  where
    step n (ixs, remaining) =
      let (quotient, index) = divMod remaining n
       in (index : ixs, quotient)
{-# INLINE shapen #-}
-- | Whether index @i@ lies in the range @0 <= i < n@.
checkIndex :: Int -> Int -> Bool
checkIndex i n = nonNegative && belowBound
  where
    nonNegative = zero <= i
    belowBound = i + one <= n
-- | Type-level counterpart of 'checkIndex'. Reduces to @'True@ when
-- @0 <= i@ and @i + 1 <= n@; otherwise raises the type error
-- \"index outside range\".
type family CheckIndex (i :: Nat) (n :: Nat) :: Bool where
  CheckIndex ix bound =
    If
      ((0 <=? ix) && (ix + 1 <=? bound))
      'True
      (L.TypeError ('Text "index outside range"))
-- | Whether every index in the list passes 'checkIndex' against @n@.
checkIndexes :: [Int] -> Int -> Bool
checkIndexes is n = all inRange is
  where
    inRange i = checkIndex i n
-- | Type-level counterpart of 'checkIndexes': the conjunction of
-- 'CheckIndex' over every element.
type family CheckIndexes (i :: [Nat]) (n :: Nat) :: Bool where
  CheckIndexes '[] _ = 'True
  CheckIndexes (j ': js) n = CheckIndex j n && CheckIndexes js n
-- | The size of the @i@-th (zero-based) dimension of a shape.
--
-- Throws @NumHaskException "dimension overflow"@ when the shape runs out
-- before position @i@ is reached.
dimension :: [Int] -> Int -> Int
dimension shape i = case shape of
  (d : rest)
    | i == 0 -> d
    | otherwise -> dimension rest (i - 1)
  [] -> throw (NumHaskException "dimension overflow")
-- | Type-level counterpart of 'dimension': the @i@-th (zero-based) element
-- of the shape, or the type error \"dimension overflow\" when the list is
-- exhausted first.
type family Dimension (s :: [Nat]) (i :: Nat) :: Nat where
  Dimension (d ': _) 0 = d
  Dimension (_ ': ds) k = Dimension ds (k - 1)
  Dimension _ _ = L.TypeError ('Text "dimension overflow")
-- | The smallest element of a non-empty list.
--
-- Throws @NumHaskException "dimension underflow"@ on the empty list.
minimum :: [Int] -> Int
minimum dims = case dims of
  [] -> throw (NumHaskException "dimension underflow")
  [d] -> d
  (d : rest) -> P.min d (minimum rest)
-- | Type-level counterpart of 'minimum'. The empty list raises the type
-- error \"zero dimension\".
type family Minimum (s :: [Nat]) :: Nat where
  Minimum '[] = L.TypeError ('Text "zero dimension")
  Minimum '[d] = d
  Minimum (d ': ds) = If (d <=? Minimum ds) d (Minimum ds)
-- | Take the first @n@ elements of a type-level list.
--
-- There is no equation for a non-zero count applied to the empty list, so
-- such an application does not reduce.
type family Take (n :: Nat) (a :: [k]) :: [k] where
  Take 0 _ = '[]
  Take count (x ': xs) = x ': Take (count - 1) xs
-- | Drop the first @n@ elements of a type-level list.
--
-- There is no equation for a non-zero count applied to the empty list, so
-- such an application does not reduce.
type family Drop (n :: Nat) (a :: [k]) :: [k] where
  Drop 0 rest = rest
  Drop count (_ ': rest) = Drop (count - 1) rest
-- | All but the first element of a type-level list; the empty list raises
-- the type error \"No tail\".
type family Tail (a :: [k]) :: [k] where
  Tail '[] = L.TypeError ('Text "No tail")
  Tail (_ ': rest) = rest
-- | All but the last element of a type-level list; the empty list raises
-- the type error \"No init\".
type family Init (a :: [k]) :: [k] where
  Init '[] = L.TypeError ('Text "No init")
  Init '[_] = '[]
  Init (y ': ys) = y ': Init ys
-- | First element of a type-level list; the empty list raises the type
-- error \"No head\".
type family Head (a :: [k]) :: k where
  Head '[] = L.TypeError ('Text "No head")
  Head (first ': _) = first
-- | Last element of a type-level list; the empty list raises the type
-- error \"No last\".
type family Last (a :: [k]) :: k where
  Last '[] = L.TypeError ('Text "No last")
  Last '[final] = final
  Last (_ ': rest) = Last rest
-- | Append two type-level lists.
type family (a :: [k]) ++ (b :: [k]) :: [k] where
  '[] ++ ys = ys
  (x ': xs) ++ ys = x ': (xs ++ ys)
-- | Remove the element at (zero-based) position @i@ from a shape.
dropIndex :: [Int] -> Int -> [Int]
dropIndex s i = kept ++ rest
  where
    kept = take i s
    rest = drop (i + 1) s
-- | Type-level counterpart of 'dropIndex': the first @i@ elements of @s@
-- followed by everything after position @i@.
type DropIndex s i = Take i s ++ Drop (i + 1) s
-- | Insert dimension size @d@ at (zero-based) position @i@ of a shape.
addIndex :: [Int] -> Int -> Int -> [Int]
addIndex s i d = before ++ (d : after)
  where
    before = take i s
    after = drop i s
-- | Type-level counterpart of 'addIndex': the first @i@ elements of @s@,
-- then @d@, then the remainder of @s@ from position @i@ on.
type AddIndex s i d = Take i s ++ (d : Drop i s)
-- | Reverse a type-level list.
type Reverse (a :: [k]) = ReverseGo a '[]

-- | Worker for 'Reverse': moves elements of the first list onto the front
-- of the accumulator one at a time.
type family ReverseGo (a :: [k]) (b :: [k]) :: [k] where
  ReverseGo '[] acc = acc
  ReverseGo (x ': xs) acc = ReverseGo xs (x ': acc)
-- | Rewrite a list of positions so that each refers to the list as it
-- stands after the earlier positions have been processed.
--
-- Walking left to right, once a position @p@ is emitted every remaining
-- position that is not smaller than @p@ is decremented by one.
posRelative :: [Int] -> [Int]
posRelative as = reverse (relativise [] as)
  where
    relativise done todo = case todo of
      [] -> done
      (p : rest) -> relativise (p : done) (fmap (shift p) rest)
    shift p q
      | q < p = q
      | otherwise = q - one
-- | Type-level counterpart of 'posRelative'.
type family PosRelative (s :: [Nat]) where
  PosRelative positions = PosRelativeGo positions '[]

-- | Worker for 'PosRelative'. The first argument is the list of positions
-- still to process; the second accumulates the emitted positions in
-- reverse order.
type family PosRelativeGo (s :: [Nat]) (acc :: [Nat]) where
  PosRelativeGo '[] done = Reverse done
  PosRelativeGo (p ': ps) done = PosRelativeGo (DecMap p ps) (p ': done)
-- | Decrement every element of the list that is not strictly smaller than
-- @x@; elements smaller than @x@ are left unchanged.
type family DecMap (x :: Nat) (ys :: [Nat]) :: [Nat] where
  DecMap _ '[] = '[]
  DecMap p (q ': qs) = If (q + 1 <=? p) q (q - 1) ': DecMap p qs
-- | Remove several positions from a shape. The positions are first made
-- relative with 'posRelative' and then removed one at a time, left to
-- right, with 'dropIndex'.
dropIndexes :: [Int] -> [Int] -> [Int]
dropIndexes s i = foldl' step s relative
  where
    relative = posRelative i
    step acc p = dropIndex acc p
-- | Type-level counterpart of 'dropIndexes'.
type family DropIndexes (s :: [Nat]) (i :: [Nat]) where
  DropIndexes shape positions = DropIndexesGo shape (PosRelative positions)

-- | Worker for 'DropIndexes': applies 'DropIndex' for each (already
-- relative) position, left to right.
type family DropIndexesGo (s :: [Nat]) (i :: [Nat]) where
  DropIndexesGo shape '[] = shape
  DropIndexesGo shape (p ': ps) = DropIndexesGo (DropIndex shape p) ps
-- | Insert several dimensions into a shape.
--
-- The arguments are the shape @as@, the positions @xs@ and then the
-- dimension sizes to insert at those positions.
--
-- The positions are reversed, made relative with 'posRelative' and reversed
-- back. Each is then paired in order with a size and inserted with
-- 'addIndex'. Processing stops once the positions are exhausted, so any
-- surplus sizes are ignored.
--
-- Throws @NumHaskException "mismatched ranks"@ when positions remain but the
-- sizes have run out.
addIndexes :: [Int] -> [Int] -> [Int] -> [Int]
addIndexes as xs = go as relative
  where
    relative = reverse (posRelative (reverse xs))
    go acc [] _ = acc
    go acc (p : ps) (d : ds) = go (addIndex acc p d) ps ds
    go _ _ _ = throw (NumHaskException "mismatched ranks")
-- | Type-level counterpart of 'addIndexes'.
type family AddIndexes (as :: [Nat]) (xs :: [Nat]) (ys :: [Nat]) where
  AddIndexes shape positions sizes =
    AddIndexesGo shape (Reverse (PosRelative (Reverse positions))) sizes

-- | Worker for 'AddIndexes': pairs each position with a size and applies
-- 'AddIndex'. Positions left over with no sizes raise the type error
-- \"mismatched ranks\".
type family AddIndexesGo (as :: [Nat]) (xs :: [Nat]) (ys :: [Nat]) where
  AddIndexesGo acc '[] _ = acc
  AddIndexesGo acc (p ': ps) (d ': ds) = AddIndexesGo (AddIndex acc p d) ps ds
  AddIndexesGo _ _ _ = L.TypeError ('Text "mismatched ranks")
-- | Pick the elements of a shape at the given (zero-based) positions, in
-- the order the positions are listed. Uses the partial list index '!!'.
takeIndexes :: [Int] -> [Int] -> [Int]
takeIndexes s i = [s !! p | p <- i]
-- | Type-level counterpart of 'takeIndexes'. An empty shape or an empty
-- position list both give the empty list.
type family TakeIndexes (s :: [Nat]) (i :: [Nat]) where
  TakeIndexes '[] _ = '[]
  TakeIndexes _ '[] = '[]
  TakeIndexes shape (p ': ps) = (shape !! p) ': TakeIndexes shape ps
-- | Zero-based indexing into a type-level list. Running off the end of the
-- list raises the type error \"Index Underflow\".
type family (a :: [k]) !! (b :: Nat) :: k where
  '[] !! _ = L.TypeError ('Text "Index Underflow")
  (x ': _) !! 0 = x
  (_ ': xs) !! n = xs !! (n - 1)
-- | The ascending list @[0 .. n - 1]@ at the type level.
type family Enumerate (n :: Nat) where
  Enumerate count = Reverse (EnumerateGo count)

-- | The descending list @[n - 1, n - 2 .. 0]@ at the type level.
type family EnumerateGo (n :: Nat) where
  EnumerateGo 0 = '[]
  EnumerateGo count = (count - 1) ': EnumerateGo (count - 1)
-- | Given a rank @r@ and a list of positions, return the positions in
-- @[0 .. r - 1]@ that remain after removing the listed ones with
-- 'dropIndexes'.
exclude :: Int -> [Int] -> [Int]
exclude r listed = dropIndexes allPositions listed
  where
    allPositions = [0 .. (r - 1)]
-- | Type-level counterpart of 'exclude': the positions in @[0 .. r - 1]@
-- that remain after removing the listed ones with 'DropIndexes'.
--
-- The positions are enumerated with 'Enumerate' (ascending), which matches
-- the @[0 .. r - 1]@ used by the value-level 'exclude'. 'EnumerateGo' on
-- its own yields them in descending order, so the type-level result would
-- disagree with the value-level one.
type family Exclude (r :: Nat) (i :: [Nat]) where
  Exclude r i = DropIndexes (Enumerate r) i
-- | Shape produced by concatenating two shapes along dimension @i@: the
-- shape of @s0@ with its @i@-th dimension replaced by the sum of the
-- @i@-th dimensions of @s0@ and @s1@.
concatenate' :: Int -> [Int] -> [Int] -> [Int]
concatenate' i s0 s1 = take i s0 ++ (joined : drop (i + 1) s0)
  where
    joined = dimension s0 i + dimension s1 i
-- | Type-level counterpart of @concatenate'@: @s0@ with its @i@-th
-- dimension replaced by the sum of the @i@-th dimensions of @s0@ and @s1@.
type Concatenate i s0 s1 = Take i s0 ++ (Dimension s0 i + Dimension s1 i : Drop (i + 1) s0)
-- | Constraint checking that two shapes can be concatenated along
-- dimension @i@. The index must be valid for the rank of the first shape,
-- both shapes must agree once dimension @i@ is dropped, and both must have
-- the same rank. The fourth parameter is not used in the body.
type CheckConcatenate i left right s =
  ( CheckIndex i (Rank left)
      && DropIndex left i == DropIndex right i
      && Rank left == Rank right
  )
    ~ 'True
-- | Constraint checking an insertion: @d@ must be a valid dimension of
-- shape @s@ and @i@ must be a valid index within that dimension.
type CheckInsert d i s =
  ( CheckIndex d (Rank s)
      && CheckIndex i (Dimension s d)
  )
    ~ 'True
-- | Shape @s@ with its @d@-th dimension increased by one.
type Insert d s = Take d s ++ (Dimension s d + 1 : Drop (d + 1) s)
-- | Shape @s@ with its @d@-th dimension increased by one.
incAt :: Int -> [Int] -> [Int]
incAt d s = prefix ++ (bumped : suffix)
  where
    prefix = take d s
    bumped = dimension s d + 1
    suffix = drop (d + 1) s
-- | Shape @s@ with its @d@-th dimension decreased by one.
decAt :: Int -> [Int] -> [Int]
decAt d s = prefix ++ (lowered : suffix)
  where
    prefix = take d s
    lowered = dimension s d - 1
    suffix = drop (d + 1) s
-- | Rearrange a shape: the result lists, for each position in @ds@, the
-- corresponding dimension of @s@ (looked up with 'dimension'). An empty
-- shape always gives the empty list.
reorder' :: [Int] -> [Int] -> [Int]
reorder' s ds = case s of
  [] -> []
  _ -> fmap (dimension s) ds
-- | Type-level counterpart of @reorder'@: for each position in @ds@, the
-- corresponding 'Dimension' of @s@.
type family Reorder (s :: [Nat]) (ds :: [Nat]) :: [Nat] where
  Reorder '[] _ = '[]
  Reorder _ '[] = '[]
  Reorder shape (p ': ps) = Dimension shape p ': Reorder shape ps
-- | Constraint checking a reordering: @ds@ must have the same rank as @s@
-- and every entry of @ds@ must be a valid index for that rank. Otherwise
-- the type error \"bad dimensions\" is raised.
type family CheckReorder (ds :: [Nat]) (s :: [Nat]) where
  CheckReorder perm shape =
    If
      ( Rank perm == Rank shape
          && CheckIndexes perm (Rank shape)
      )
      'True
      (L.TypeError ('Text "bad dimensions"))
      ~ 'True
-- | Remove every dimension equal to 1.
squeeze' :: (Eq a, Num a) => [a] -> [a]
squeeze' dims = [d | d <- dims, d /= 1]
-- | Type-level counterpart of @squeeze'@: removes every dimension equal
-- to 1 by delegating to 'Filter'.
type family Squeeze (a :: [Nat]) where
  Squeeze '[] = '[]
  Squeeze dims = Filter '[] dims 1
-- | Remove every occurrence of @i@ from @xs@. The first argument is an
-- accumulator, built in reverse and reversed at the end.
type family Filter (r :: [Nat]) (xs :: [Nat]) (i :: Nat) where
  Filter acc '[] _ = Reverse acc
  Filter acc (y ': ys) target =
    Filter (If (y == target) acc (y ': acc)) ys target
-- | Sort a type-level list. The head is used as a pivot: the sorted
-- 'FMin' partition comes first, then the pivot, then the sorted 'FMax'
-- partition. Ordering is decided by 'Cmp' through 'SFilter'.
type family Sort (xs :: [k]) :: [k] where
  Sort '[] = '[]
  Sort (pivot ': rest) =
    (Sort (SFilter 'FMin pivot rest) ++ '[pivot])
      ++ Sort (SFilter 'FMax pivot rest)
-- | Selects which partition 'SFilter' keeps: 'FMin' keeps elements that
-- compare 'LT' to the pivot, 'FMax' keeps those that compare 'GT' or 'EQ'.
data Flag = FMin | FMax
-- | Open type family giving the 'Ordering' of two types of the same kind.
-- It is consulted by 'SFilter'. No instances are defined in this part of
-- the file.
type family Cmp (a :: k) (b :: k) :: Ordering
-- | Partition helper for 'Sort'. With 'FMin' it keeps the elements that
-- compare 'LT' to the pivot @p@; with 'FMax' it keeps those that compare
-- 'GT' or 'EQ'.
type family SFilter (f :: Flag) (p :: k) (xs :: [k]) :: [k] where
  SFilter _ _ '[] = '[]
  SFilter 'FMin pivot (y ': ys) =
    If
      (Cmp y pivot == 'LT)
      (y ': SFilter 'FMin pivot ys)
      (SFilter 'FMin pivot ys)
  SFilter 'FMax pivot (y ': ys) =
    If
      (Cmp y pivot == 'GT || Cmp y pivot == 'EQ)
      (y ': SFilter 'FMax pivot ys)
      (SFilter 'FMax pivot ys)
-- | Pair up the elements of two type-level lists, via 'ZipWith' and the
-- promoted tuple constructor.
type family Zip lst lst' where
  Zip as bs = ZipWith '(,) as bs
-- | Combine two type-level lists element-wise with @f@, stopping at the
-- end of the shorter list.
type family ZipWith f lst lst' where
  ZipWith _ '[] _ = '[]
  ZipWith _ _ '[] = '[]
  ZipWith f (a ': as) (b ': bs) = f a b ': ZipWith f as bs
-- | First component of a type-level pair.
type family Fst a where
  Fst '(x, _) = x
-- | Second component of a type-level pair.
type family Snd a where
  Snd '(_, y) = y
-- | Apply @f@ to every element of a type-level list.
type family FMap f lst where
  FMap _ '[] = '[]
  FMap f (x ': xs) = f x ': FMap f xs
-- | Type-level lists of 'Nat's whose elements can all be reflected to
-- value-level 'Int's.
class KnownNats (ns :: [Nat]) where
  -- | The value-level list corresponding to @ns@.
  natVals :: Proxy ns -> [Int]

-- | The empty list reflects to the empty list.
instance KnownNats '[] where
  natVals _ = []

-- | Reflect the head with 'natVal' and recurse on the tail.
instance (KnownNat n, KnownNats ns) => KnownNats (n ': ns) where
  natVals _ =
    fromInteger (natVal (Proxy :: Proxy n))
      : natVals (Proxy :: Proxy ns)
-- | Type-level lists of lists of 'Nat's that can be reflected to
-- value-level lists of lists of 'Int's.
class KnownNatss (ns :: [[Nat]]) where
  -- | The value-level list of lists corresponding to @ns@.
  natValss :: Proxy ns -> [[Int]]

-- | The empty list reflects to the empty list.
instance KnownNatss '[] where
  natValss _ = []

-- | Reflect the head list with 'natVals' and recurse on the tail.
instance (KnownNats n, KnownNatss ns) => KnownNatss (n ': ns) where
  natValss _ =
    natVals (Proxy :: Proxy n)
      : natValss (Proxy :: Proxy ns)