--------------------------------------------------------------------------------

{-# LANGUAGE RankNTypes, NamedFieldPuns, ScopedTypeVariables, GADTs,
             LambdaCase #-}
{-# LANGUAGE Safe #-}

module Copilot.Theorem.IL.Translate ( translate, translateWithBounds ) where

import Copilot.Theorem.IL.Spec

import qualified Copilot.Core as C

import qualified Data.Map.Strict as Map

import Control.Monad.State

import Data.Char
import Data.List (find)

import Text.Printf

import GHC.Float (float2Double)

import Data.Typeable (Typeable)

--------------------------------------------------------------------------------

-- 'nc' stands for naming convention.

-- | Name of the IL sequence associated with the stream of the given id:
-- the letter @s@ followed by the id in decimal.
ncSeq :: C.Id -> SeqId
ncSeq sid = printf "s%d" sid

-- | Name of the IL sequence associated with a local variable: the letter
-- @l@ followed by the part of the Copilot name that starts at its first
-- numeric character.
--
-- We assume all local variables have distinct names whatever their scopes.
ncLocal :: C.Name -> SeqId
ncLocal name = 'l' : dropWhile (not . isNumber) name

-- | Name of the IL sequence associated with an external variable: the
-- Copilot name prefixed with @ext_@.
ncExternVar :: C.Name -> SeqId
ncExternVar = ("ext_" ++)

-- | Name of the IL function associated with an external function: the
-- Copilot name prefixed with an underscore.
ncExternFun :: C.Name -> SeqId
ncExternFun = ('_' :)

-- | Naming convention for unhandled operators: the name is returned
-- unchanged.
ncUnhandledOp :: String -> String
ncUnhandledOp opName = opName

-- | Name of the IL sequence standing for the multiplexer with the given
-- number: @mux@ followed by the number in decimal.
ncMux :: Integer -> SeqId
ncMux = ("mux" ++) . show

--------------------------------------------------------------------------------

-- | Translates a Copilot specification to an IL specification, without
-- generating range constraints for bounded integer types.
translate :: C.Spec -> IL
translate spec = translate' False spec

-- | Translates a Copilot specification to an IL specification, additionally
-- generating range constraints for values of bounded integer types.
translateWithBounds :: C.Spec -> IL
translateWithBounds spec = translate' True spec

-- | Worker shared by 'translate' and 'translateWithBounds'.
--
-- The boolean tells whether range constraints are generated (see 'bound').
-- Streams are translated first; the side constraints collected while doing
-- so are appended to the recurrence constraints of the model. Each property
-- is then translated on its own and paired with the side constraints that
-- its translation produced.
translate' :: Bool -> C.Spec -> IL
translate' b (C.Spec {C.specStreams, C.specProperties}) = runTrans b $ do
  mainConstraints <- mapM streamRec specStreams
  -- Side constraints produced by the stream equations themselves.
  streamSideConds <- popLocalConstraints
  propList        <- mapM trProperty specProperties

  return IL
    { modelInit  = concatMap streamInit specStreams
    , modelRec   = mainConstraints ++ streamSideConds
    , properties = Map.fromList propList
    , inductive  = not (null specStreams)
    }
  where
    -- Translates one property and collects the side constraints that were
    -- recorded during its translation.
    trProperty (C.Property {C.propertyName, C.propertyExpr}) = do
      body  <- expr propertyExpr
      conds <- popLocalConstraints
      return (propertyName, (conds, body))

-- | Records a range constraint for the expression @s@ of Copilot type @t@.
--
-- For the fixed-width signed and unsigned integer types, and only when
-- bound generation is enabled in the translation state, a local constraint
-- stating that @s@ lies between the 'minBound' and 'maxBound' of the type is
-- added. Nothing is recorded for the other types.
bound :: Expr -> C.Type a -> Trans ()
bound s t = case t of
  C.Int8    -> withinRange t
  C.Int16   -> withinRange t
  C.Int32   -> withinRange t
  C.Int64   -> withinRange t
  C.Word8   -> withinRange t
  C.Word16  -> withinRange t
  C.Word32  -> withinRange t
  C.Word64  -> withinRange t
  _         -> return ()
  where
    withinRange :: (Bounded b, Integral b) => C.Type b -> Trans ()
    withinRange ty = do
      enabled <- gets addBounds
      when enabled $ localConstraint $
        Op2 Bool And
          (Op2 Bool Le (trConst ty minBound) s)
          (Op2 Bool Ge (trConst ty maxBound) s)

-- | Initial-state constraints of a stream: one equality per element of the
-- stream buffer, fixing the value of the stream's sequence at position @p@
-- to the @p@-th buffer element.
streamInit :: C.Stream -> [Expr]
streamInit (C.Stream { C.streamId       = sid
                     , C.streamBuffer   = buf
                     , C.streamExprType = ty }) =
  [ Op2 Bool Eq (SVal (trType ty) (ncSeq sid) (Fixed p)) (trConst ty v)
  | (p, v) <- zip [(0 :: Integer) ..] buf
  ]

-- | Recurrence constraint of a stream: the value of the stream's sequence
-- at index @n + length buffer@ equals the translation of the stream
-- expression. A range constraint on that value is recorded first (see
-- 'bound').
streamRec :: C.Stream -> Trans Expr
streamRec (C.Stream { C.streamId       = sid
                    , C.streamExpr     = body
                    , C.streamBuffer   = buf
                    , C.streamExprType = ty })
  = do
  let lhs = SVal (trType ty) (ncSeq sid) (_n_plus (length buf))
  bound lhs ty
  Op2 Bool Eq lhs <$> expr body

--------------------------------------------------------------------------------

-- | Translates a Copilot core expression into an IL expression.
--
-- Besides returning the translated expression, the translation may record
-- side constraints in the 'Trans' state: definitions of local variables,
-- range constraints on external values (see 'bound') and definitions of the
-- fresh variables introduced for multiplexers (see 'newMux').
expr :: Typeable a => C.Expr a -> Trans Expr

expr (C.Const t v) = return $ trConst t v

-- Labels are dropped: only the labelled expression is translated.
expr (C.Label _ _ e) = expr e

expr (C.Drop t k id) = return $ SVal (trType t) (ncSeq id) (_n_plus k)

-- A local binding becomes a constraint equating the local's sequence with
-- the bound expression; the result is the translation of the body.
expr (C.Local ta _ name ea eb) = do
  ea' <- expr ea
  localConstraint (Op2 Bool Eq (SVal (trType ta) (ncLocal name) _n_) ea')
  expr eb

expr (C.Var t name) = return $ SVal (trType t) (ncLocal name) _n_

expr (C.ExternVar t name _) = bound s t >> return s
  where s = SVal (trType t) (ncExternVar name) _n_

expr (C.ExternFun t name args _ _) = do
  args' <- mapM trArg args
  let s = FunApp (trType t) (ncExternFun name) args'
  bound s t
  return s
  where trArg (C.UExpr {C.uExprExpr}) = expr uExprExpr

-- The sign function is expanded into multiplexers over comparisons with
-- zero.
expr (C.Op1 (C.Sign ta) e) = case ta of
  C.Int8   -> trSign ta e
  C.Int16  -> trSign ta e
  C.Int32  -> trSign ta e
  C.Int64  -> trSign ta e
  C.Float  -> trSign ta e
  C.Double -> trSign ta e
  -- Unsigned words: the sign is 0 at zero and 1 otherwise. Translating it
  -- as the constant 1 would be wrong when the operand is 0.
  C.Word8  -> trSignUnsigned ta e
  C.Word16 -> trSignUnsigned ta e
  C.Word32 -> trSignUnsigned ta e
  C.Word64 -> trSignUnsigned ta e
  _        -> expr $ C.Const ta 1
  where
    -- Signed and floating-point types: -1 below zero, 1 above zero, 0 at
    -- zero.
    trSign :: (Typeable a, Ord a, Num a) => C.Type a -> C.Expr a -> Trans Expr
    trSign ta e =
      expr (C.Op3 (C.Mux ta)
        (C.Op2 (C.Lt ta) e (C.Const ta 0))
        (C.Const ta (-1))
        (C.Op3 (C.Mux ta)
          (C.Op2 (C.Gt ta) e (C.Const ta 0))
          (C.Const ta 1)
          (C.Const ta 0)))

    -- Unsigned types: 1 above zero, 0 otherwise.
    trSignUnsigned :: (Typeable a, Ord a, Num a)
                   => C.Type a -> C.Expr a -> Trans Expr
    trSignUnsigned ta e =
      expr (C.Op3 (C.Mux ta)
        (C.Op2 (C.Gt ta) e (C.Const ta 0))
        (C.Const ta 1)
        (C.Const ta 0))
-- The square root is expressed as raising to the power 0.5.
expr (C.Op1 (C.Sqrt _) e) = do
  e' <- expr e
  return $ Op2 Real Pow e' (ConstR 0.5)
-- A cast is translated as its operand: no conversion is emitted.
expr (C.Op1 (C.Cast _ _) e) = expr e
-- Remaining unary operators go through 'trOp1', which raises an error for
-- the ones it does not support.
expr (C.Op1 op e) = do
  e' <- expr e
  return $ Op1 t' op' e'
  where (op', t') = trOp1 op

-- Disequality is the negation of an equality.
-- NOTE(review): the inner equality is tagged with the operand type, whereas
-- 'trOp2' tags 'C.Eq' with 'Bool' -- confirm which one is intended.
expr (C.Op2 (C.Ne t) e1 e2) = do
  e1' <- expr e1
  e2' <- expr e2
  return $ Op1 Bool Not (Op2 t' Eq e1' e2')
  where t' = trType t

-- Remaining binary operators go through 'trOp2', which raises an error for
-- the ones it does not support.
expr (C.Op2 op e1 e2) = do
  e1' <- expr e1
  e2' <- expr e2
  return $ Op2 t' op' e1' e2'
  where (op', t') = trOp2 op

-- A multiplexer is replaced by a variable defined through side constraints
-- (see 'newMux').
expr (C.Op3 (C.Mux t) cond e1 e2) = do
  cond' <- expr cond
  e1'   <- expr e1
  e2'   <- expr e2
  newMux cond' (trType t) e1' e2'

-- | Translates a Copilot constant of the given type into an IL expression.
--
-- Booleans map to 'ConstB', floating-point values to 'ConstR' (floats are
-- widened to doubles) and integers to 'ConstI'. A value that does not
-- satisfy @>= 0@ is emitted as the negation ('Neg') of its negated value,
-- so the literal itself carries the opposite sign.
trConst :: C.Type a -> a -> Expr
trConst ty val = case ty of
  C.Bool   -> ConstB val
  C.Float  -> real (float2Double val)
  C.Double -> real val
  C.Int8   -> int val
  C.Int16  -> int val
  C.Int32  -> int val
  C.Int64  -> int val
  C.Word8  -> int val
  C.Word16 -> int val
  C.Word32 -> int val
  C.Word64 -> int val
  where
    -- IL type of the constant.
    ilTy :: Type
    ilTy = trType ty

    real :: Double -> Expr
    real x =
      if x >= 0
        then ConstR x
        else Op1 Real Neg (ConstR (negate x))

    -- The conversion to 'Integer' happens before negating.
    int :: Integral b => b -> Expr
    int x =
      if x >= 0
        then ConstI ilTy (toInteger x)
        else Op1 ilTy Neg (ConstI ilTy (negate (toInteger x)))

-- | Maps a Copilot unary operator to the corresponding IL operator and the
-- type attached to the resulting IL expression.
--
-- 'expr' handles @Sign@, @Sqrt@ and @Cast@ before calling this function.
-- Every operator not listed here raises an error.
trOp1 :: C.Op1 a b -> (Op1, Type)
trOp1 C.Not       = (Not, Bool)
trOp1 (C.Abs t)   = (Abs, trType t)
-- C.Sign t  ->
-- C.Recip t ->
trOp1 (C.Exp t)   = (Exp, trType t)
-- C.Sqrt t  ->
trOp1 (C.Log t)   = (Log, trType t)
trOp1 (C.Sin t)   = (Sin, trType t)
trOp1 (C.Tan t)   = (Tan, trType t)
trOp1 (C.Cos t)   = (Cos, trType t)
trOp1 (C.Asin t)  = (Asin, trType t)
trOp1 (C.Atan t)  = (Atan, trType t)
trOp1 (C.Acos t)  = (Acos, trType t)
trOp1 (C.Sinh t)  = (Sinh, trType t)
trOp1 (C.Tanh t)  = (Tanh, trType t)
trOp1 (C.Cosh t)  = (Cosh, trType t)
trOp1 (C.Asinh t) = (Asinh, trType t)
trOp1 (C.Atanh t) = (Atanh, trType t)
trOp1 (C.Acosh t) = (Acosh, trType t)
-- C.BwNot t ->
-- C.Cast t  ->
trOp1 _ = error "Unsupported unary operator in input." -- TODO(chathhorn)

-- | Maps a Copilot binary operator to the corresponding IL operator and the
-- type attached to the resulting IL expression.
--
-- Equality is tagged with 'Bool', while the ordering comparisons are tagged
-- with the translated operand type. 'expr' handles @Ne@ before calling this
-- function. Every operator not listed here raises an error.
trOp2 :: C.Op2 a b c -> (Op2, Type)
trOp2 C.And      = (And, Bool)
trOp2 C.Or       = (Or, Bool)

trOp2 (C.Add t)  = (Add, trType t)
trOp2 (C.Sub t)  = (Sub, trType t)
trOp2 (C.Mul t)  = (Mul, trType t)

trOp2 (C.Mod t)  = (Mod, trType t)
-- C.Div t        ->

trOp2 (C.Fdiv t) = (Fdiv, trType t)

trOp2 (C.Pow t)  = (Pow, trType t)
-- C.Logb t       ->

trOp2 (C.Eq _)   = (Eq, Bool)
-- C.Ne t         ->

trOp2 (C.Le t)   = (Le, trType t)
trOp2 (C.Ge t)   = (Ge, trType t)
trOp2 (C.Lt t)   = (Lt, trType t)
trOp2 (C.Gt t)   = (Gt, trType t)

-- C.BwAnd t      ->
-- C.BwOr t       ->
-- C.BwXor t      ->
-- C.BwShiftL t _ ->
-- C.BwShiftR t _ ->

trOp2 _ = error "Unsupported binary operator in input." -- TODO(chathhorn)

-- | Maps a Copilot type to an IL type: signed integers to the @SBV@ types,
-- unsigned integers to the @BV@ types of the same width, and both
-- floating-point types to 'Real'.
trType :: C.Type a -> Type
trType C.Bool   = Bool
trType C.Int8   = SBV8
trType C.Int16  = SBV16
trType C.Int32  = SBV32
trType C.Int64  = SBV64
trType C.Word8  = BV8
trType C.Word16 = BV16
trType C.Word32 = BV32
trType C.Word64 = BV64
trType C.Float  = Real
trType C.Double = Real

--------------------------------------------------------------------------------

-- State threaded through the translation.
data TransST = TransST
  -- Side constraints recorded by 'localConstraint' (most recent first) and
  -- not yet collected by 'popLocalConstraints'.
  { localConstraints :: [Expr]
  -- Multiplexers introduced by 'newMux': the variable standing for the
  -- multiplexer, paired with its condition, type and two branches.
  , muxes            :: [(Expr, (Expr, Type, Expr, Expr))]
  -- Counter used by 'fresh' to number new variables.
  , nextFresh        :: Integer
  -- Whether 'bound' records range constraints for bounded integer types.
  , addBounds        :: Bool
  }

-- | Returns the variable standing for the multiplexer with condition @c@,
-- type @t@ and branches @e1@ and @e2@.
--
-- If an identical multiplexer is already registered in the state, its
-- variable is reused; otherwise a fresh variable is created and registered.
newMux :: Expr -> Type -> Expr -> Expr -> Trans Expr
newMux c t e1 e2 = do
  known <- gets muxes
  case [ v | (v, def) <- known, def == key ] of
    (v : _) -> return v
    []      -> do
      n <- fresh
      let v = SVal t (ncMux n) _n_
      modify $ \st -> st { muxes = (v, key) : known }
      return v
  where
    key = (c, t, e1, e2)

-- | Constraints defining the multiplexer variables currently registered in
-- the state. Each multiplexer yields two constraints: either the condition
-- is false or the variable equals the first branch, and either the
-- condition is true or the variable equals the second branch.
getMuxes :: Trans [Expr]
getMuxes = gets (concatMap definition . muxes)
  where
    definition (v, (c, _, e1, e2)) =
      [ Op2 Bool Or (Op1 Bool Not c) (Op2 Bool Eq v e1)
      , Op2 Bool Or c (Op2 Bool Eq v e2)
      ]

-- The translation monad: a 'State' computation over 'TransST'.
type Trans = State TransST

-- | Increments the fresh-name counter and returns its new value.
fresh :: Trans Integer
fresh = do
  next <- gets ((+ 1) . nextFresh)
  modify $ \st -> st { nextFresh = next }
  return next

-- | Records a side constraint by prepending it to the pending constraints
-- of the state.
localConstraint :: Expr -> Trans ()
localConstraint c = modify record
  where
    record st = st { localConstraints = c : localConstraints st }

-- | Returns the pending side constraints followed by the constraints
-- defining the registered multiplexers, then clears both from the state.
popLocalConstraints :: Trans [Expr]
popLocalConstraints = do
  pending <- gets localConstraints
  muxDefs <- getMuxes
  modify $ \st -> st { localConstraints = [], muxes = [] }
  return (pending ++ muxDefs)

-- | Runs a translation from an empty state. The boolean tells whether range
-- constraints are generated (see 'bound').
runTrans :: Bool -> Trans a -> a
runTrans b m = evalState m initial
  where
    initial = TransST
      { localConstraints = []
      , muxes            = []
      , nextFresh        = 0
      , addBounds        = b
      }

--------------------------------------------------------------------------------