{-# LANGUAGE LambdaCase #-}

-- | Definitions useful for all state transition modules.
module Stg.Machine.Evaluate.Common (

    -- * Primops
    PrimError(..),
    applyPrimOp,

    -- * Algebraic matching
    AltMatch(..),
    AltError(..),
    lookupAlgebraicAlt,

    -- * Primitive matching
    lookupPrimitiveAlt,

) where



import qualified Data.List as L

import Stg.Language
import Stg.Util



-- | Ways in which evaluating a primop can fail.
data PrimError
    = Div0
      -- ^ The second operand of a division or modulo was zero

-- | Evaluate a binary primop on two integer operands.
--
-- 'Div' and 'Mod' with a zero right-hand operand yield @'Failure' 'Div0'@;
-- every other application yields 'Success'. Comparison primops encode their
-- outcome as an integer: 1 if the relation holds, 0 otherwise.
applyPrimOp :: PrimOp -> Integer -> Integer -> Validate PrimError Integer
applyPrimOp op lhs rhs
    | dividesByRhs op && rhs == 0 = Failure Div0
    | otherwise                   = Success (denote op lhs rhs)
  where
    -- The primops that are rejected when the right operand is zero. Checked
    -- before looking at the operand, so other primops never inspect it here.
    dividesByRhs :: PrimOp -> Bool
    dividesByRhs = \case
        Div -> True
        Mod -> True
        _   -> False

    -- Encode the outcome of a comparison as an integer.
    indicator :: (Integer -> Integer -> Bool) -> Integer -> Integer -> Integer
    indicator rel a b
        | rel a b   = 1
        | otherwise = 0

    -- The integer function each primop stands for.
    denote :: PrimOp -> Integer -> Integer -> Integer
    denote Add = (+)
    denote Sub = (-)
    denote Mul = (*)
    denote Div = div
    denote Mod = mod
    denote Eq  = indicator (==)
    denote Lt  = indicator (<)
    denote Leq = indicator (<=)
    denote Gt  = indicator (>)
    denote Geq = indicator (>=)
    denote Neq = indicator (/=)




-- | Outcome of a successful alternative lookup, i.e. which branch of a @case@
-- expression should be taken.
data AltMatch alt
    = AltMatches alt
      -- ^ One of the non-default alternatives matched
    | DefaultMatches DefaultAlt
      -- ^ No non-default alternative matched, so the default one applies

-- | Ways in which looking up a @case@ alternative can fail.
data AltError
    = BadAlt
      -- ^ The kind of alternatives does not fit the lookup: primitive
      -- alternatives in an algebraic lookup, or algebraic alternatives in a
      -- primitive lookup

-- | Find the alternative to take for an algebraic constructor.
--
-- The first alternative whose constructor equals the given one wins. If there
-- is none, or there are no non-default alternatives at all, the default
-- alternative is returned. Primitive alternatives are reported as 'BadAlt'.
lookupAlgebraicAlt
    :: Alts
    -> Constr
    -> Validate AltError (AltMatch AlgebraicAlt)
lookupAlgebraicAlt (Alts nonDefault fallback) wanted = case nonDefault of
    AlgebraicAlts candidates -> Success (firstMatchOrDefault candidates)
    PrimitiveAlts{}          -> Failure BadAlt
    NoNonDefaultAlts{}       -> Success (DefaultMatches fallback)
  where
    -- Does this alternative scrutinize the constructor we are looking for?
    hasWantedConstr (AlgebraicAlt con _ _) = con == wanted

    -- Take the first fitting alternative, or fall back to the default.
    firstMatchOrDefault candidates
        = maybe (DefaultMatches fallback) AltMatches
                (L.find hasWantedConstr candidates)



-- | Find the alternative to take for a primitive literal; the counterpart of
-- 'lookupAlgebraicAlt'.
--
-- The first alternative whose literal equals the given one wins. If there is
-- none, or there are no non-default alternatives at all, the default
-- alternative is returned. Algebraic alternatives are reported as 'BadAlt'.
lookupPrimitiveAlt
    :: Alts
    -> Literal
    -> Validate AltError (AltMatch PrimitiveAlt)
lookupPrimitiveAlt (Alts nonDefault fallback) wanted = case nonDefault of
    PrimitiveAlts candidates -> Success (firstMatchOrDefault candidates)
    AlgebraicAlts{}          -> Failure BadAlt
    NoNonDefaultAlts{}       -> Success (DefaultMatches fallback)
  where
    -- Does this alternative scrutinize the literal we are looking for?
    hasWantedLiteral (PrimitiveAlt candidate _) = candidate == wanted

    -- Take the first fitting alternative, or fall back to the default.
    firstMatchOrDefault candidates
        = maybe (DefaultMatches fallback) AltMatches
                (L.find hasWantedLiteral candidates)