-- SPDX-FileCopyrightText: 2021 Oxhead Alpha
-- SPDX-License-Identifier: LicenseRef-MIT-OA

{- | Reimplementation of some syntax sugar.

You need the following module pragmas to make it work smoothly:

{-# LANGUAGE NoApplicativeDo, RebindableSyntax #-}
{-# OPTIONS_GHC -Wno-unused-do-bind #-}

-}
module Lorentz.Rebinded
  ( (>>)
  , pure
  , return
  , IsCondition (ifThenElse)
  , Condition (..)
  , (<.)
  , (>.)
  , (<=.)
  , (>=.)
  , (==.)
  , (/=.)
  , keepIfArgs

    -- * Re-exports required for RebindableSyntax
  , fromInteger
  , fromString
  , fromLabel
  , negate
  ) where


import Prelude hiding (drop, not, swap, (>>), (>>=))

import Lorentz.Arith
import Lorentz.Base
import Lorentz.Coercions
import Lorentz.Constraints.Scopes
import Lorentz.Instr
import Lorentz.Macro
import Morley.Michelson.Typed.Arith
import Morley.Util.Label (Label)
import Morley.Util.Named

-- | Alias for '(#)' used by do-blocks: with @RebindableSyntax@ every
-- statement of a @do@-block is glued to the next one with this operator, so
-- the output stack of one instruction becomes the input stack of the next.
(>>) :: (a :-> b) -> (b :-> c) -> (a :-> c)
(>>) = (#)

-- | The most basic predicate for the @if ... then ... else ...@ construction:
-- it determines which kind of check is applied to the top elements of the
-- current stack.
--
-- The type arguments, in order, are:
--
-- 1. the stack on input of @if@;
-- 2. the stack the left (@then@) branch starts with;
-- 3. the stack the right (@else@) branch starts with;
-- 4. the stack both branches end with;
-- 5. the stack on output of @if@.
data Condition arg argl argr outb out where
  -- | Consumes a 'Bool'; both branches start with the remaining stack.
  Holds :: Condition (Bool ': rest) rest rest res res
  -- | Consumes a 'Maybe'; the left branch gets the wrapped value on top.
  IsSome :: Condition (Maybe t ': rest) (t ': rest) rest res res
  -- | Consumes a 'Maybe'; the right branch gets the wrapped value on top.
  IsNone :: Condition (Maybe t ': rest) rest (t ': rest) res res
  -- | Consumes an 'Either'; the left branch gets the 'Left' payload.
  IsLeft :: Condition (Either lt rt ': rest) (lt ': rest) (rt ': rest) res res
  -- | Consumes an 'Either'; the left branch gets the 'Right' payload.
  IsRight :: Condition (Either lt rt ': rest) (rt ': rest) (lt ': rest) res res
  -- | Consumes a list; the left branch gets its head and tail on top.
  IsCons :: Condition ([t] ': rest) (t ': [t] ': rest) rest res res
  -- | Consumes a list; the right branch gets its head and tail on top.
  IsNil :: Condition ([t] ': rest) rest (t ': [t] ': rest) res res

  -- | Negation: the same check with the two branch input stacks exchanged.
  Not :: Condition inp sl sr br res -> Condition inp sr sl br res

  -- | Consumes a value whose 'Eq'' unary operation yields a 'Bool'.
  IsZero :: (UnaryArithOpHs Eq' t, UnaryArithResHs Eq' t ~ Bool)
         => Condition (t ': rest) rest rest res res

  -- Binary comparisons: each consumes two values of the same comparable
  -- type and leaves the remaining stack to both branches.
  IsEq :: NiceComparable t => Condition (t ': t ': rest) rest rest res res
  IsNeq :: NiceComparable t => Condition (t ': t ': rest) rest rest res res
  IsLt :: NiceComparable t => Condition (t ': t ': rest) rest rest res res
  IsGt :: NiceComparable t => Condition (t ': t ': rest) rest rest res res
  IsLe :: NiceComparable t => Condition (t ': t ': rest) rest rest res res
  IsGe :: NiceComparable t => Condition (t ': t ': rest) rest rest res res

  -- | A binary condition whose operands carry explicit names, ensuring the
  -- stack arguments appear in the proper order.
  NamedBinCondition ::
    Condition (t ': t ': rest) rest rest res res ->
    Label n1 -> Label n2 ->
    Condition ((n1 :! t) ': (n2 :! t) ': rest) rest rest res res

  -- | Keeps the compared arguments on the stack for the @if@ branches.
  PreserveArgsBinCondition ::
    (Dupable t, Dupable u) =>
    (forall inner res. Condition (t ': u ': inner) inner inner res res) ->
    Condition (t ': u ': rest) (t ': u ': rest) (t ': u ': rest) (t ': u ': rest) rest

-- | Anything that may follow the @if@ keyword.
--
-- The first parameter is the type of the condition itself. The remaining five
-- describe the stacks around and inside the @if then else@ construction, in
-- the same order and with the same meaning as the type arguments of
-- 'Condition'.
class IsCondition c inp inpL inpR outBr outIf where
  -- | Semantics of @if ... then ... else ...@: combine a condition with the
  -- @then@ branch and the @else@ branch into a single instruction.
  ifThenElse :: c -> (inpL :-> outBr) -> (inpR :-> outBr) -> (inp :-> outIf)

-- | Interprets every 'Condition' constructor via the matching Lorentz branching
-- instruction.
--
-- The instance head fixes only the condition type. The stack types are tied
-- to it through the equality constraints, so the instance is selected from
-- the condition alone and the stacks are then unified with its indices.
instance (arg ~ arg0, argl ~ argl0, argr ~ argr0, outb ~ outb0, out ~ out0) =>
         IsCondition (Condition arg argl argr outb out) arg0 argl0 argr0 outb0 out0 where
  ifThenElse = \case
    -- Direct mappings onto branching macros; 'flip' is used where the macro
    -- expects its branches in the opposite order to the condition's
    -- then/else branches.
    Holds -> if_
    IsSome -> flip ifNone
    IsNone -> ifNone
    IsLeft -> ifLeft
    IsRight -> flip ifLeft
    IsCons -> ifCons
    IsNil -> flip ifCons

    -- Negation simply exchanges the two branches.
    Not cond -> \l r -> ifThenElse cond r l

    -- Turn the top value into a Bool with 'eq0', then branch on it.
    IsZero -> \l r -> eq0 # if_ l r

    IsEq -> ifEq
    IsNeq -> ifNeq
    IsLt -> ifLt
    IsGt -> ifGt
    IsLe -> ifLe
    IsGe -> ifGe

    -- Strip the names from both operands, then run the wrapped comparison.
    NamedBinCondition condition l1 l2 -> \l r ->
      fromNamed l1 # dip (fromNamed l2) #
      ifThenElse condition l r

    -- Copy both operands so that the comparison consumes the copies and the
    -- originals are still present when each branch starts.
    PreserveArgsBinCondition condition -> \l r ->
      dupN @2 # dupN @2 #
      ifThenElse condition
        -- since this pattern is commonly used when one of the branches fails,
        -- it's essential to @drop@ within branches, not after @if@ - @drop@s
        -- appearing to be dead code will be cut off
        (l # drop # drop)
        (r # drop # drop)

-- | Named version of 'IsLt'.
--
-- In this and similar operators you provide the names of the accepted stack
-- operands as a safety measure ensuring that they go in the expected order:
-- the operand named by the first label must be on top of the stack, and the
-- one named by the second label right below it.
infix 4 <.
(<.)
  :: NiceComparable a
  => Label n1 -> Label n2
  -> Condition ((n1 :! a) ': (n2 :! a) ': s) s s o o
(<.) = NamedBinCondition IsLt

-- | Named version of 'IsGt'.
--
-- The operand named by the first label must be on top of the stack.
infix 4 >.
(>.)
  :: NiceComparable a
  => Label n1 -> Label n2
  -> Condition ((n1 :! a) ': (n2 :! a) ': s) s s o o
(>.) = NamedBinCondition IsGt

-- | Named version of 'IsLe'.
--
-- The operand named by the first label must be on top of the stack.
infix 4 <=.
(<=.)
  :: NiceComparable a
  => Label n1 -> Label n2
  -> Condition ((n1 :! a) ': (n2 :! a) ': s) s s o o
(<=.) = NamedBinCondition IsLe

-- | Named version of 'IsGe'.
--
-- The operand named by the first label must be on top of the stack.
infix 4 >=.
(>=.)
  :: NiceComparable a
  => Label n1 -> Label n2
  -> Condition ((n1 :! a) ': (n2 :! a) ': s) s s o o
(>=.) = NamedBinCondition IsGe

-- | Named version of 'IsEq'.
--
-- The operand named by the first label must be on top of the stack.
infix 4 ==.
(==.)
  :: NiceComparable a
  => Label n1 -> Label n2
  -> Condition ((n1 :! a) ': (n2 :! a) ': s) s s o o
(==.) = NamedBinCondition IsEq

-- | Named version of 'IsNeq'.
--
-- The operand named by the first label must be on top of the stack.
infix 4 /=.
(/=.)
  :: NiceComparable a
  => Label n1 -> Label n2
  -> Condition ((n1 :! a) ': (n2 :! a) ': s) s s o o
(/=.) = NamedBinCondition IsNeq

-- | Condition modifier that keeps the stack operands of a binary comparison
-- available within both @if@ branches.
--
-- Each branch starts with both operands still on top of the stack and must
-- leave them there; they are removed once the chosen branch finishes, so the
-- whole @if@ consumes both operands.
keepIfArgs
  :: (Dupable a, Dupable b)
  => (forall st o. Condition (a ': b ': st) st st o o)
  -> Condition (a ': b ': s) (a ': b ': s) (a ': b ': s) (a ': b ': s) s
keepIfArgs = PreserveArgsBinCondition