{-# LANGUAGE CPP                  #-}
{-# LANGUAGE DataKinds            #-}
{-# LANGUAGE FlexibleContexts     #-}
{-# LANGUAGE GADTs                #-}
{-# LANGUAGE KindSignatures       #-}
{-# LANGUAGE StandaloneDeriving   #-}
{-# LANGUAGE TypeFamilies         #-}
{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE UndecidableInstances #-}
{-| This module declares all Shape-related functions and data structures, as well as all singleton
instances for the Shape data type. This module was heavily influenced by Grenade, a Haskell
library for deep learning with dependent types. See: https://github.com/HuwCampbell/grenade
-}
module TensorSafe.Shape where

import           Data.Singletons
import           GHC.TypeLits    as N

import           TensorSafe.Core

--
-- Shape definition as in Haskell's Grenade library
--

-- | The shapes currently supported: one-, two- and three-dimensional
--   vectors/matrices.
--
--   'Shape' is intended to be used only as a kind (via DataKinds), so in
--   practice its promoted constructors @'D1@, @'D2@ and @'D3@ appear at the
--   type level rather than as ordinary values.
data Shape
    = -- | One-dimensional vector.
      D1 Nat
    | -- | Two-dimensional matrix: rows, columns.
      D2 Nat Nat
    | -- | Three-dimensional matrix: rows, columns, channels.
      D3 Nat Nat Nat

-- | Concrete, value-level storage for data whose 'Shape' is tracked in its
--   type index.
--
--   Each constructor wraps the container matching its dimensionality. A
--   three-dimensional shape is not given its own container: it is kept in a
--   two-dimensional @L@ whose height is @rows * depth@ and whose width is
--   @columns@.
data S (shape :: Shape) where
    -- | One-dimensional data held in an @R len@.
    S1D :: KnownNat len => R len -> S ('D1 len)

    -- | Two-dimensional data held in an @L rows columns@.
    S2D :: (KnownNat rows, KnownNat columns) => L rows columns -> S ('D2 rows columns)

    -- | Three-dimensional data flattened into an @L (rows * depth) columns@.
    S3D
        :: ( KnownNat rows, KnownNat columns, KnownNat depth
           , KnownNat (rows N.* depth) )
        => L (rows N.* depth) columns
        -> S ('D3 rows columns depth)

deriving instance Show (S shape)

-- Hand-written singleton instances for 'Shape'
-- (see http://hackage.haskell.org/package/singletons).
--
-- Template Haskell could probably generate these, but writing them out by
-- hand keeps the module TH-free and makes it straightforward to attach the
-- KnownNat constraint required for every dimension.
data instance Sing (s :: Shape) where
    D1Sing :: KnownNat x
           => Sing x -> Sing ('D1 x)
    D2Sing :: (KnownNat x, KnownNat y)
           => Sing x -> Sing y -> Sing ('D2 x y)
    D3Sing :: (KnownNat x, KnownNat y, KnownNat z)
           => Sing x -> Sing y -> Sing z -> Sing ('D3 x y z)

instance KnownNat x => SingI ('D1 x) where
    sing = D1Sing sing

instance (KnownNat x, KnownNat y) => SingI ('D2 x y) where
    sing = D2Sing sing sing

instance (KnownNat x, KnownNat y, KnownNat z) => SingI ('D3 x y z) where
    sing = D3Sing sing sing sing

-- | Type-level equality test for two 'Shape's. It reduces to @'True@ when
-- both arguments are the same shape, and to @'False@ once GHC can see that
-- the two shapes differ.
type family ShapeEquals (sIn :: Shape) (sOut :: Shape) :: Bool where
    ShapeEquals shape shape = 'True
    ShapeEquals _     _     = 'False

-- | Like 'ShapeEquals', it compares two 'Shape's at the type level. When the
-- shapes are equal, it reduces to @'True@. Instead of reducing to @'False@
-- for mismatched shapes, it triggers a custom compile-time type error
-- (via 'TypeError') that names both shapes.
type family ShapeEquals' (sIn :: Shape) (sOut :: Shape) :: Bool where
    ShapeEquals' shape shape = 'True
    ShapeEquals' lhs rhs = TypeError
        ( 'Text "Couldn't match the Shape " ':<>: 'ShowType lhs
          ':<>: 'Text " with the Shape " ':<>: 'ShowType rhs )