{-
Part of the code in this file comes from the parameterized-utils package:

Copyright (c) 2013-2022 Galois Inc.
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:

  * Redistributions of source code must retain the above copyright
    notice, this list of conditions and the following disclaimer.

  * Redistributions in binary form must reproduce the above copyright
    notice, this list of conditions and the following disclaimer in
    the documentation and/or other materials provided with the
    distribution.

  * Neither the name of Galois, Inc. nor the names of its contributors
    may be used to endorse or promote products derived from this
    software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}

-- |
-- Module      :   Grisette.Utils.Parameterized
-- Copyright   :   (c) Sirui Lu 2021-2023
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Utils.Parameterized
  ( -- * Unsafe axiom
    unsafeAxiom,

    -- * Runtime representation of type-level natural numbers
    NatRepr,
    natValue,
    unsafeMkNatRepr,
    natRepr,
    decNat,
    predNat,
    incNat,
    addNat,
    subNat,
    divNat,
    halfNat,

    -- * Proof of KnownNat
    KnownProof (..),
    hasRepr,
    withKnownProof,
    unsafeKnownProof,
    knownAdd,

    -- * Proof of (<=) for type-level natural numbers
    LeqProof (..),
    withLeqProof,
    unsafeLeqProof,
    testLeq,
    leqRefl,
    leqSucc,
    leqTrans,
    leqZero,
    leqAdd2,
    leqAdd,
    leqAddPos,
  )
where

import Data.Typeable (Proxy (Proxy), type (:~:) (Refl))
import GHC.Natural (Natural)
import GHC.TypeNats
  ( Div,
    KnownNat,
    Nat,
    SomeNat (SomeNat),
    natVal,
    someNatVal,
    type (+),
    type (-),
    type (<=),
  )
import Unsafe.Coerce (unsafeCoerce)

-- | Assert a proof of equality between two types.
--
-- The proof is produced by coercing @a :~: a@ into @a :~: b@, so nothing is
-- checked. This is unsafe if used improperly, so use this with caution:
-- pattern matching on the returned 'Refl' when @a@ and @b@ are actually
-- different types can crash the program or generate incorrect results.
unsafeAxiom :: forall a b. a :~: b
unsafeAxiom = unsafeCoerce (Refl @a)

-- | Runtime witness for the type-level natural number @n@.
--
-- It stores a plain 'Natural' that stands for @n@, which makes it possible to
-- perform dynamic checks on type-level natural numbers.
newtype NatRepr (n :: Nat)
  = NatRepr Natural

-- | The underlying runtime natural number value of a type-level natural number.
--
-- This simply unwraps the stored 'Natural'; it matches @n@ only if the
-- 'NatRepr' was built consistently with its type index.
natValue :: NatRepr n -> Natural
natValue (NatRepr n) = n

-- | Construct a runtime representation of a type-level natural number.
--
-- __Note:__ This function is unsafe, as it does not check that the runtime
-- representation is consistent with the type-level representation.
-- You should ensure the consistency yourself or the program can crash or
-- generate incorrect results.
unsafeMkNatRepr :: Natural -> NatRepr n
unsafeMkNatRepr = NatRepr

-- | Construct a runtime representation of a type-level natural number when its
-- runtime value is known.
--
-- The stored value is read from the 'KnownNat' instance with 'natVal', so the
-- result is always consistent with @n@.
natRepr :: forall n. (KnownNat n) => NatRepr n
natRepr = NatRepr (natVal (Proxy @n))

-- | Decrement a 'NatRepr' by 1.
--
-- The @1 <= n@ constraint rules out decrementing zero at the type level. If
-- the runtime value is inconsistent with @n@, the 'Natural' subtraction can
-- still underflow and throw.
decNat :: (1 <= n) => NatRepr n -> NatRepr (n - 1)
decNat (NatRepr n) = NatRepr (n - 1)

-- | Predecessor of a 'NatRepr'.
--
-- Because the argument is indexed by @n + 1@, its runtime value is at least 1
-- whenever it is consistent with its type index, so subtracting 1 is safe.
predNat :: NatRepr (n + 1) -> NatRepr n
predNat (NatRepr n) = NatRepr (n - 1)

-- | Increment a 'NatRepr' by 1.
incNat :: NatRepr n -> NatRepr (n + 1)
incNat (NatRepr n) = NatRepr (n + 1)

-- | Addition of two 'NatRepr's.
--
-- The runtime values are added, mirroring @m + n@ at the type level.
addNat :: NatRepr m -> NatRepr n -> NatRepr (m + n)
addNat (NatRepr m) (NatRepr n) = NatRepr (m + n)

-- | Subtraction of two 'NatRepr's.
--
-- The @n <= m@ constraint guarantees the difference is a natural number at
-- the type level. If the runtime values are inconsistent with their type
-- indices, the 'Natural' subtraction can still underflow and throw.
subNat :: (n <= m) => NatRepr m -> NatRepr n -> NatRepr (m - n)
subNat (NatRepr m) (NatRepr n) = NatRepr (m - n)

-- | Division of two 'NatRepr's, rounding down (like 'div' and 'Div').
--
-- The @1 <= n@ constraint rules out a zero divisor at the type level.
divNat :: (1 <= n) => NatRepr m -> NatRepr n -> NatRepr (Div m n)
divNat (NatRepr m) (NatRepr n) = NatRepr (m `div` n)

-- | Half of a 'NatRepr' indexed by @n + n@.
--
-- Uses truncating division by 2, which is exact whenever the runtime value is
-- consistent with its type index @n + n@.
halfNat :: NatRepr (n + n) -> NatRepr n
halfNat (NatRepr n) = NatRepr (n `div` 2)

-- | @'KnownProof' n@ is a type whose values are only inhabited when @n@ has
-- a known runtime value, i.e. a 'KnownNat' instance.
--
-- Pattern matching on the 'KnownProof' constructor brings that 'KnownNat'
-- instance into scope.
data KnownProof (n :: Nat) where
  KnownProof :: (KnownNat n) => KnownProof n

-- | Introduces the 'KnownNat' constraint when it's proven.
--
-- Matching on the 'KnownProof' constructor unpacks the stored 'KnownNat'
-- instance, which is then available while evaluating the continuation @r@.
withKnownProof :: KnownProof n -> ((KnownNat n) => r) -> r
withKnownProof KnownProof r = r

-- | Construct a 'KnownProof' given the runtime value.
--
-- __Note:__ This function is unsafe, as it does not check that the runtime
-- representation is consistent with the type-level representation.
-- You should ensure the consistency yourself or the program can crash or
-- generate incorrect results.
unsafeKnownProof :: Natural -> KnownProof n
unsafeKnownProof nVal = hasRepr (NatRepr nVal)

-- | Construct a 'KnownProof' given the runtime representation.
--
-- The runtime value is turned into an existentially quantified 'KnownNat'
-- witness with 'someNatVal'. That witness is then identified with @n@ through
-- 'unsafeAxiom', so the result is only sound when the 'NatRepr' is consistent
-- with @n@.
hasRepr :: forall n. NatRepr n -> KnownProof n
hasRepr (NatRepr nVal) =
  case someNatVal nVal of
    SomeNat (Proxy :: Proxy n') ->
      -- Unchecked: assumes the runtime value really denotes @n@.
      case unsafeAxiom :: n :~: n' of
        Refl -> KnownProof

-- | Adding two type-level natural numbers with known runtime values gives a
-- type-level natural number with a known runtime value.
--
-- The sum is computed from the runtime values of both operands (read with
-- 'natVal') and turned into a proof with 'hasRepr'.
knownAdd :: forall m n. KnownProof m -> KnownProof n -> KnownProof (m + n)
knownAdd KnownProof KnownProof =
  hasRepr @(m + n) (NatRepr (natVal (Proxy @m) + natVal (Proxy @n)))

-- | @'LeqProof' m n@ is a type whose values are only inhabited when @m <= n@.
--
-- Pattern matching on the 'LeqProof' constructor brings the @m <= n@
-- constraint into scope.
data LeqProof (m :: Nat) (n :: Nat) where
  LeqProof :: (m <= n) => LeqProof m n

-- | Introduces the @m <= n@ constraint when it's proven.
--
-- Matching on the 'LeqProof' constructor unpacks the stored constraint, which
-- is then available while evaluating the continuation @r@.
withLeqProof :: LeqProof m n -> ((m <= n) => r) -> r
withLeqProof LeqProof r = r

-- | Construct a 'LeqProof'.
--
-- The proof is obtained by coercing a genuine @'LeqProof' 0 0@.
--
-- __Note:__ This function is unsafe, as it does not check that the left-hand
-- side is less than or equal to the right-hand side.
-- You should ensure the consistency yourself or the program can crash or
-- generate incorrect results.
unsafeLeqProof :: forall m n. LeqProof m n
unsafeLeqProof = unsafeCoerce (LeqProof @0 @0)

-- | Checks if a 'NatRepr' is less than or equal to another 'NatRepr'.
--
-- Returns @'Just' proof@ when the runtime value of the first argument is less
-- than or equal to that of the second, and 'Nothing' otherwise. The proof is
-- manufactured with 'unsafeLeqProof', so it is only sound when both
-- 'NatRepr's are consistent with their type indices.
testLeq :: NatRepr m -> NatRepr n -> Maybe (LeqProof m n)
testLeq (NatRepr mVal) (NatRepr nVal)
  | mVal <= nVal = Just unsafeLeqProof
  | otherwise = Nothing

-- | Apply reflexivity to 'LeqProof': @n <= n@.
--
-- GHC discharges @n <= n@ itself, so no unsafe coercion is needed. The
-- argument only fixes @n@ and is never inspected.
leqRefl :: f n -> LeqProof n n
leqRefl _ = LeqProof

-- | A natural number is less than or equal to its successor.
--
-- The argument only fixes @n@ and is never inspected; the proof is produced
-- with 'unsafeLeqProof'.
leqSucc :: f n -> LeqProof n (n + 1)
leqSucc _ = unsafeLeqProof

-- | Apply transitivity to 'LeqProof': from @a <= b@ and @b <= c@, conclude
-- @a <= c@.
--
-- The input proofs are not inspected; the result is produced with
-- 'unsafeLeqProof'.
leqTrans :: LeqProof a b -> LeqProof b c -> LeqProof a c
leqTrans _ _ = unsafeLeqProof

-- | Zero is less than or equal to any natural number.
--
-- The proof is produced with 'unsafeLeqProof'.
leqZero :: LeqProof 0 n
leqZero = unsafeLeqProof

-- | Add both sides of two inequalities: from @xl <= xh@ and @yl <= yh@,
-- conclude @xl + yl <= xh + yh@.
--
-- The input proofs are not inspected; the result is produced with
-- 'unsafeLeqProof'.
leqAdd2 :: LeqProof xl xh -> LeqProof yl yh -> LeqProof (xl + yl) (xh + yh)
leqAdd2 _ _ = unsafeLeqProof

-- | Produce proof that adding a value to the larger element in an 'LeqProof'
-- is larger: from @m <= n@, conclude @m <= n + o@.
--
-- Neither argument is inspected (the second only fixes @o@); the result is
-- produced with 'unsafeLeqProof'.
leqAdd :: LeqProof m n -> f o -> LeqProof m (n + o)
leqAdd _ _ = unsafeLeqProof

-- | Adding two positive natural numbers is positive.
--
-- The @1 <= m@ and @1 <= n@ constraints are required of the caller. The
-- arguments only fix @m@ and @n@ and are never inspected. The result is
-- produced with 'unsafeLeqProof'.
leqAddPos :: (1 <= m, 1 <= n) => p m -> q n -> LeqProof 1 (m + n)
leqAddPos _ _ = unsafeLeqProof