{-# LANGUAGE DataKinds #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}

{-|
Copyright        : (c) Galois, Inc 2021

@'Fin' n@ is a finite type with exactly @n@ elements. Essentially, each value
bundles a 'NatRepr' whose type parameter @i@ is existentially quantified,
together with a proof that @i@ is strictly less than @n@.

They are useful in combination with types of a fixed size. For example 'Fin' is
used as the index in the 'Data.Functor.WithIndex.FunctorWithIndex' instance for
'Data.Parameterized.Vector'. As another example, a @Map ('Fin' n) a@ is a @Map@
that naturally has a fixed size bound of @n@.
-}
module Data.Parameterized.Fin
  ( Fin
  , mkFin
  , viewFin
  , finToNat
  , embed
  , tryEmbed
  , minFin
  , fin0Void
  , fin1Unit
  , fin2Bool
  ) where

import Control.Lens.Iso (Iso', iso)
import GHC.TypeNats (KnownNat)
import Numeric.Natural (Natural)
import Data.Void (Void, absurd)

import Data.Parameterized.NatRepr

-- | The type @'Fin' n@ has exactly @n@ inhabitants.
--
-- A value is a 'NatRepr' for some existentially-quantified @i@, packaged with
-- the constraint @i + 1 <= n@ (i.e. @i < n@). The constructor is not exported:
-- values are built with 'mkFin' and inspected with 'viewFin'.
data Fin n =
  -- GHC 8.6 and 8.4 require parentheses around 'i + 1 <= n'
  forall i. (i + 1 <= n) => Fin { _getFin :: NatRepr i }

-- | Two 'Fin's are equal exactly when they denote the same natural number.
instance Eq (Fin n) where
  i == j = finToNat i == finToNat j

-- | 'Fin's are ordered by the natural number they denote.
instance Ord (Fin n) where
  compare i j = compare (finToNat i) (finToNat j)

-- | The bounds of a non-empty @'Fin' n@ are @0@ and @n - 1@.
instance (1 <= n, KnownNat n) => Bounded (Fin n) where
  minBound = minFin
  maxBound =
    -- Rewrite @(n - 1) + 1@ to @n@, so that the bound @(n - 1) + 1 <= n@
    -- demanded by 'Fin' for the index @n - 1@ becomes @n <= n@.
    case minusPlusCancel (knownNat @n) (knownNat @1) of
      Refl -> Fin (decNat (knownNat @n))

-- | Non-lawful instance, intended only for testing.
--
-- Renders as @"Fin "@ followed by the denoted natural number. No
-- precedence-based parenthesization is applied when nested in other values.
instance Show (Fin n) where
  show i = "Fin " ++ show (finToNat i)

-- | Build a @'Fin' n@ from a 'NatRepr' @i@ that is statically known to satisfy
-- @i + 1 <= n@.
mkFin :: forall i n. (i + 1 <= n) => NatRepr i -> Fin n
mkFin = Fin

-- | Eliminate a 'Fin' by handing its underlying 'NatRepr', together with the
-- evidence @i + 1 <= n@, to a continuation that is polymorphic in @i@.
viewFin :: (forall i. (i + 1 <= n) => NatRepr i -> r) -> Fin n -> r
viewFin f (Fin i) = f i

-- | The natural number denoted by a 'Fin', obtained from its 'NatRepr'.
finToNat :: Fin n -> Natural
finToNat (Fin i) = natValue i

-- | Include a @'Fin' n@ into any @'Fin' m@ with @n <= m@, keeping the same
-- underlying 'NatRepr'.
embed :: forall n m. (n <= m) => Fin n -> Fin m
embed =
  viewFin
    (\(x :: NatRepr o) ->
      -- Chain @o + 1 <= n@ (carried by the 'Fin') with @n <= m@ (from the
      -- context) to obtain the @o + 1 <= m@ needed to rebuild at @m@.
      case leqTrans (LeqProof :: LeqProof (o + 1) n) (LeqProof :: LeqProof n m) of
        LeqProof -> Fin x
    )

-- | Attempt to include a @'Fin' n@ into @'Fin' m@ by checking @n <= m@ at
-- runtime with 'testLeq'.
--
-- Only the bounds @n@ and @m@ are compared; the value of the 'Fin' itself is
-- never inspected. The result is 'Nothing' whenever 'testLeq' yields
-- 'Nothing', even if this particular value would fit below @m@.
tryEmbed :: NatRepr n -> NatRepr m -> Fin n -> Maybe (Fin m)
tryEmbed n m i =
  case testLeq n m of
    Just LeqProof -> Just (embed i)
    Nothing -> Nothing

-- | The smallest element of @'Fin' n@, namely @0@. Requires @1 <= n@, so that
-- @'Fin' n@ is non-empty.
minFin :: (1 <= n) => Fin n
minFin = Fin (knownNat @0)

-- | @'Fin' 0@ is uninhabited, hence isomorphic to 'Void'.
fin0Void :: Iso' (Fin 0) Void
fin0Void =
  iso
    (viewFin
      (\(x :: NatRepr o) ->
        -- Commute @o + 1@ to @1 + o@ so that 'addIsLeqLeft1' can weaken the
        -- hypothesis @1 + o <= 0@ to @1 <= 0@. That proof cannot exist, and
        -- the empty case records that this branch is unreachable.
        case plusComm x (knownNat @1) of
          Refl ->
            case addIsLeqLeft1 @1 @o @0 LeqProof of {}))
    absurd

-- | @'Fin' 1@ has the single inhabitant 'minFin', hence is isomorphic to @()@.
fin1Unit :: Iso' (Fin 1) ()
fin1Unit = iso (const ()) (const minFin)

-- | @'Fin' 2@ is isomorphic to 'Bool': the element @0@ ('minBound')
-- corresponds to 'False' and @1@ ('maxBound') to 'True'.
fin2Bool :: Iso' (Fin 2) Bool
fin2Bool =
  iso
    (viewFin
      (\n ->
         case isZeroNat n of
           ZeroNat -> False
           NonZeroNat -> True))
    (\b -> if b then maxBound else minBound)