-- SPDX-FileCopyrightText: 2021 Oxhead Alpha
-- SPDX-License-Identifier: LicenseRef-MIT-OA

{-# OPTIONS_GHC -Wno-unticked-promoted-constructors #-}

-- | Lorentz wrappers over instructions from Morley extension.
module Lorentz.Ext
  ( stackRef
  , printComment
  , justComment
  , comment
  , commentAroundFun
  , commentAroundStmt
  , testAssert
  , stackType
  ) where

import Data.Singletons (SingI)
import GHC.TypeNats (Nat)

import Lorentz.Base
import Morley.Michelson.Typed.Haskell
import Morley.Michelson.Typed.Instr
import Morley.Util.Peano

-- | Include a value at given position on stack into comment produced
-- by 'printComment'.
--
-- The position is supplied as a type-level natural via type application.
--
-- > stackRef @0
-- <includes the top of the stack>
stackRef
  :: forall (gn :: Nat) st n.
      (n ~ ToPeano gn, SingI n, RequireLongerThan st n)
  => PrintComment st
-- The reference becomes the single 'Right' item of the comment's item list.
stackRef = PrintComment . one . Right $ mkStackRef @gn

-- | Print a comment. It will be visible in tests.
--
-- The comment is embedded as a 'PRINT' extension instruction; the stack
-- type is the same before and after.
--
-- > printComment "Hello world!"
-- > printComment $ "On top of the stack I see " <> stackRef @0
printComment :: PrintComment (ToTs s) -> s :-> s
printComment = I . Ext . PRINT

-- | Put a plain text comment into the code.
--
-- Shortcut for 'comment' applied to a 'JustComment'.
justComment :: Text -> s :-> s
justComment = comment . JustComment

-- | Put a comment of the given 'CommentType' into the code.
--
-- The comment is embedded as a 'COMMENT_ITEM' extension instruction; the
-- stack type is the same before and after.
comment :: CommentType -> s :-> s
comment = I . Ext . COMMENT_ITEM

-- | Surround the given code with a 'FunctionStarts' comment before it and
-- a 'FunctionEnds' comment after it, both carrying the given name.
commentAroundFun :: Text -> (i :-> o) -> (i :-> o)
commentAroundFun funName funBody =
  comment (FunctionStarts funName) #
  funBody #
  comment (FunctionEnds funName)

-- | Surround the given code with a 'StatementStarts' comment before it and
-- a 'StatementEnds' comment after it, both carrying the given name.
commentAroundStmt :: Text -> (i :-> o) -> (i :-> o)
commentAroundStmt stmtName stmtBody =
  comment (StatementStarts stmtName) #
  stmtBody #
  comment (StatementEnds stmtName)

-- | Test an invariant, fail if it does not hold.
--
-- This won't be included into production contract and is executed only in tests.
--
-- Arguments are the assertion message, a comment for the assertion, and
-- the code that computes the checked 'Bool' on top of the input stack.
-- The resulting instruction leaves the stack type unchanged.
--
-- Calls 'error' when the checking code is given via the 'FI' constructor.
testAssert
  :: (HasCallStack)
  => Text -> PrintComment (ToTs inp) -> inp :-> Bool : out -> inp :-> inp
testAssert msg comment' = \case
  -- Ordinary code: wrap it into a 'TEST_ASSERT' extension instruction.
  I instr -> I . Ext . TEST_ASSERT $ TestAssert msg comment' instr
  -- 'FI' code cannot be wrapped into 'TestAssert'; reject it outright.
  FI _ -> error "test assert branch always fails"

-- | Fix the current type of the stack to be given one.
--
-- > stackType @'[Natural]
-- > stackType @(Integer : Natural : s)
-- > stackType @'["balance" :! Integer, "toSpend" :! Integer, BigMap Address Integer]
--
-- Note that you can omit arbitrary parts of the type.
--
-- > stackType @'["balance" :! Integer, "toSpend" :! _, BigMap _ _]
--
-- Implemented as a 'Nop' instruction, so it only constrains types.
stackType :: forall s. s :-> s
stackType = I Nop

-- Usage sample (not exported): a comment that refers to the top stack element
-- through 'stackRef'.
_sample1 :: (s ~ (a : s')) => s :-> s
_sample1 = printComment $ "Head is " <> stackRef @0

-- Usage sample (not exported): 'stackType' with a partially specified stack
-- type, where the wildcard stands for the rest of the stack.
_sample2 :: Integer : Natural : s :-> Integer : Natural : s
_sample2 = stackType @(Integer : _)