{-# LANGUAGE CPP #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wall #-}
{-# OPTIONS_GHC -fno-warn-deprecations #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

module NumHask.Array.Constraints
  ( IsValidConcat
  , Squeeze
  , Concatenate
  , IsValidTranspose
  , Fold
  , FoldAlong
  , TailModule
  , HeadModule
  , Transpose
  ) where

import Data.Singletons.Prelude hiding (Max)
import Data.Singletons.Prelude.List hiding (Transpose)
import Data.Singletons.Prelude.Tuple (Fst, Snd)
import Data.Singletons.TypeLits (Nat)
import qualified Protolude as P

-- | Remove the element at the 0-based index @d@ from a type-level list.
--
-- Index 0 is the first axis. This is the same convention used by
-- 'Concatenate', 'Fold' and 'FoldAlong', which is what 'IsValidConcat'
-- relies on when it drops the concatenation axis. When @d@ is not a valid
-- index the list comes back unchanged.
type family DropDim d a :: [b] where
  -- Everything before index d, followed by everything after it.
  DropDim d xs = Take d xs ++ Drop 1 (Drop d xs)

-- | ''P.True' when shapes @a@ and @b@ may be concatenated along the 0-based
-- axis @i@ (see 'Concatenate').
--
-- Requires:
--
-- * both shapes are non-empty;
-- * they have the same rank;
-- * @i@ is an axis of that rank;
-- * the extents left after removing axis @i@ with 'DropDim' agree pairwise.
type family IsValidConcat i (a :: [Nat]) (b :: [Nat]) :: P.Bool where
  IsValidConcat _ '[] _ = 'P.False
  IsValidConcat _ _ '[] = 'P.False
  IsValidConcat i a b =
    -- The rank check is needed because 'ZipWith' truncates to the shorter
    -- list. Without it, shapes of different rank would be compared only on
    -- their common prefix.
    (Length a == Length b)
      -- Keeps 'Concatenate' from taking 'Head' of an empty suffix.
      && (i < Length a)
      && And (ZipWith (==@#@$) (DropDim i a) (DropDim i b))

-- | Remove every extent equal to 1 from a shape.
type family Squeeze (s :: [Nat]) where
  Squeeze '[] = '[]
  -- (/=@#@$$) 1 is the defunctionalised section (1 /=): keep extents /= 1.
  Squeeze dims = Filter ((/=@#@$$) 1) dims

-- | Type-level check relating the list @p@ to the shape @a@.
--
-- It holds when @p@ and @a@ have the same length and the same element sum.
-- The two @Minimum ... >= 0@ conjuncts are always satisfied for non-empty
-- 'Nat' lists.
--
-- NOTE(review): equal sum and length do not show that @p@ is a permutation
-- (either of @a@ or of its axis indices). 'Minimum' of an empty list
-- presumably does not reduce, so '[] inputs likely leave this stuck.
-- Confirm the intended contract.
type family IsValidTranspose (p :: [Nat]) (a :: [Nat]) :: P.Bool where
  IsValidTranspose xs ys =
    (Minimum xs >= 0)
      && (Minimum ys >= 0)
      && (Sum ys == Sum xs)
      && Length xs == Length ys

-- | Reverse the order of the axes of a shape. For example,
-- @'[2, 3, 4]@ becomes @'[4, 3, 2]@.
type family Transpose a where
  Transpose s = Reverse s

-- | Insert @d@ into the list @t@ with 'Insert', the promoted
-- @Data.List.insert@. @d@ is placed just before the first element that is
-- not smaller than it, so a sorted @t@ stays sorted.
--
-- NOTE(review): this family is not exported and is not used in the visible
-- code. Placement is by value rather than by axis position; confirm that is
-- the intended meaning of "add dimension".
type family AddDimension (d :: Nat) t :: [Nat] where
  AddDimension n dims = Insert n dims

-- | The shape produced by concatenating @a@ and @b@ along the 0-based
-- axis @i@:
--
-- * extents before @i@ are taken from @a@;
-- * extent @i@ is the sum of both inputs' extents at @i@;
-- * extents after @i@ are taken from @b@.
--
-- If @i@ is not an axis of both shapes, 'Head' is applied to an empty list
-- and the result does not reduce to a proper shape.
type family Concatenate i (a :: [Nat]) (b :: [Nat]) :: [Nat] where
  Concatenate axis xs ys =
    Take axis xs
      ++ '[ Head (Drop axis xs) + Head (Drop axis ys) ]
      ++ Drop (axis + 1) ys

-- | The shape after reducing the 0-based axis @i@ of shape @s@. The reduced
-- axis is kept as a singleton dimension of extent 1.
--
-- An empty shape stays empty. If @i@ is past the last axis, a trailing 1 is
-- appended.
type family FoldAlong i (s :: [Nat]) where
  FoldAlong _ '[] = '[]
  FoldAlong axis dims =
    Take axis dims ++ '[ 1] ++ Snd (SplitAt (axis + 1) dims)

-- | The shape after reducing the 0-based axis @i@ of shape @s@. The reduced
-- axis is removed entirely; no singleton dimension is kept.
type family Fold i (s :: [Nat]) where
  Fold _ '[] = '[]
  Fold axis dims = Take axis dims ++ Snd (SplitAt (axis + 1) dims)

-- | The extents of @s@ from index @i@ onwards; the first @i@ are discarded.
type family TailModule i (s :: [Nat]) where
  TailModule _ '[] = '[]
  TailModule n dims = Snd (SplitAt n dims)

-- | The first @i@ extents of @s@.
type family HeadModule i (s :: [Nat]) where
  HeadModule _ '[] = '[]
  HeadModule n dims = Fst (SplitAt n dims)