{-# LANGUAGE CPP #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wall #-}
{-# OPTIONS_GHC -fno-warn-deprecations #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module NumHask.Array.Constraints
( IsValidConcat
, Squeeze
, Concatenate
, IsValidTranspose
, Fold
, FoldAlong
, TailModule
, HeadModule
, Transpose
) where
import Data.Singletons.Prelude hiding (Max)
import Data.Singletons.Prelude.List hiding (Transpose)
import Data.Singletons.Prelude.Tuple (Fst, Snd)
import Data.Singletons.TypeLits (Nat)
import qualified Protolude as P
type family DropDim d a :: [b] where
DropDim 0 xs = Drop 1 xs
DropDim d xs = Take (d - 1) (Fst (SplitAt d xs)) ++ Snd (SplitAt d xs)
type family IsValidConcat i (a :: [Nat]) (b :: [Nat]) :: P.Bool where
IsValidConcat _ '[] _ = 'P.False
IsValidConcat _ _ '[] = 'P.False
IsValidConcat i a b =
And (ZipWith (==@#@$) (DropDim i a) (DropDim i b))
type family Squeeze (a :: [Nat]) where
Squeeze '[] = '[]
Squeeze a = Filter ((/=@#@$$) 1) a
type family IsValidTranspose (p :: [Nat]) (a :: [Nat]) :: P.Bool where
IsValidTranspose p a =
(Minimum p >= 0) && (Minimum a >= 0) && (Sum a == Sum p) && Length p == Length a
type family Transpose a where
Transpose a = Reverse a
type family AddDimension (d :: Nat) t :: [Nat] where
AddDimension d t = Insert d t
type family Concatenate i (a :: [Nat]) (b :: [Nat]) :: [Nat] where
Concatenate i a b =
Take i (Fst (SplitAt (i + 1) a)) ++
('[ Head (Drop i a) + Head (Drop i b)]) ++
Snd (SplitAt (i + 1) b)
type family FoldAlong i (s :: [Nat]) where
FoldAlong _ '[] = '[]
FoldAlong d xs = Take d (Fst (SplitAt (d + 1) xs)) ++ '[ 1] ++ Snd (SplitAt (d + 1) xs)
type family Fold i (s :: [Nat]) where
Fold _ '[] = '[]
Fold d xs = Take d (Fst (SplitAt (d + 1) xs)) ++ Snd (SplitAt (d + 1) xs)
type family TailModule i (s :: [Nat]) where
TailModule _ '[] = '[]
TailModule d xs = (Snd (SplitAt d xs))
type family HeadModule i (s :: [Nat]) where
HeadModule _ '[] = '[]
HeadModule d xs = (Fst (SplitAt d xs))