{-# OPTIONS_GHC -Wall -fno-warn-orphans #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  ToySolver.BitVector.Base
-- Copyright   :  (c) Masahiro Sakai 2016
-- License     :  BSD-style
-- 
-- Maintainer  :  masahiro.sakai@gmail.com
-- Stability   :  experimental
--
-----------------------------------------------------------------------------
module ToySolver.BitVector.Base
  (
  -- * BitVector values
    BV
  , bv2nat
  , nat2bv
  , fromAscBits
  , fromDescBits
  , toAscBits
  , toDescBits
  , IsBV (..)

  -- * BitVector language
  , Var (..)
  , Expr (..)
  , Op1 (..)
  , Op2 (..)
  , repeat
  , zeroExtend
  , signExtend
  , Atom (..)
  , BVComparison (..)
  , module ToySolver.Data.OrdRel
  , Model
  , evalExpr
  , evalAtom
  ) where

import Prelude hiding (repeat)
import Data.Bits
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Monoid
import Data.Ord
import qualified Data.Semigroup as Semigroup
import qualified Data.Vector as V
import qualified Data.Vector.Generic as VG
import qualified Data.Vector.Unboxed as VU
import ToySolver.Data.Boolean
import ToySolver.Data.OrdRel

-- | Operations common to concrete bit-vectors ('BV') and symbolic
-- bit-vector expressions ('Expr').
--
-- The 'Monoid' superclass provides concatenation ('<>') and the empty,
-- zero-width bit-vector ('mempty').
class Monoid a => IsBV a where
  -- | Number of bits.
  width :: a -> Int
  -- | @extract i j x@ takes bits @i@ down to @j@ (both inclusive) of @x@.
  extract :: Int -> Int -> a -> a
  -- | Injection of a concrete bit-vector.
  fromBV :: BV -> a

  -- | Bitwise complement.
  bvnot :: a -> a
  -- | Bitwise binary operations.  'bvnand', 'bvnor' and 'bvxnor' default
  -- to the complement of 'bvand', 'bvor' and 'bvxor' respectively.
  bvand, bvor, bvxor, bvnand, bvnor, bvxnor :: a -> a -> a

  -- | Two's complement negation.
  bvneg :: a -> a
  -- | Modular arithmetic.  'bvsub' defaults to adding the negation.
  bvadd, bvsub, bvmul :: a -> a -> a
  -- | Unsigned and signed division, remainder and modulus.
  bvudiv, bvurem, bvsdiv, bvsrem, bvsmod :: a -> a -> a
  -- | Shifts; the second operand is the shift amount.
  bvshl, bvlshr, bvashr :: a -> a -> a
  -- | One-bit equality comparison.
  bvcomp :: a -> a -> a

  bvnand x y = bvnot (x `bvand` y)
  bvnor x y = bvnot (x `bvor` y)
  bvxnor x y = bvnot (x `bvxor` y)

  bvsub x y = x `bvadd` bvneg y

-- | @repeat n x@ concatenates @n@ copies of @x@ (the empty value when
-- @n <= 0@).
repeat :: Monoid m => Int -> m -> m
repeat n = mconcat . replicate n

-- | @zeroExtend n x@ prepends @n@ zero bits above the most significant bit
-- of @x@.
zeroExtend :: IsBV a => Int -> a -> a
zeroExtend n x = padding <> x
  where
    padding = fromAscBits (replicate n False)

-- | @signExtend n x@ prepends @n@ copies of the most significant bit of
-- @x@.  A zero-width @x@ has no sign bit, so it is extended with zeros.
signExtend :: IsBV a => Int -> a -> a
signExtend n x =
  if sz == 0
    then fromAscBits (replicate n False)
    else repeat n signBit <> x
  where
    sz = width x
    signBit = extract (sz - 1) (sz - 1) x

-- | Bit-vector comparisons.  The result type is 'Bool' for 'BV' and a
-- symbolic 'Atom' for 'Expr'.
--
-- Only one of 'bvule' / 'bvult' and one of 'bvsle' / 'bvslt' has to be
-- defined; every other comparison follows from the defaults.
class ( IsBV a
      , IsEqRel a (ComparisonResult a)
      , Complement (ComparisonResult a)
      ) => BVComparison a where
  -- | Type of comparison results.
  type ComparisonResult a

  -- | Unsigned comparisons.
  bvule, bvult, bvuge, bvugt :: a -> a -> ComparisonResult a
  -- | Signed comparisons.
  bvsle, bvslt, bvsge, bvsgt :: a -> a -> ComparisonResult a

  bvule x y = notB (y `bvult` x)
  bvult x y = notB (y `bvule` x)
  bvuge x y = y `bvule` x
  bvugt x y = y `bvult` x

  bvsle x y = notB (y `bvslt` x)
  bvslt x y = notB (y `bvsle` x)
  bvsge x y = y `bvsle` x
  bvsgt x y = y `bvslt` x

  {-# MINIMAL (bvule | bvult), (bvsle | bvslt) #-}

-- ------------------------------------------------------------------------

-- | A concrete bit-vector value.
--
-- Bits are stored least-significant first: vector element @i@ is bit @i@
-- (weight @2^i@, see 'bv2nat').  The width is the vector length.
newtype BV = BV (VU.Vector Bool)
  deriving Eq

-- | Orders by width first; equal-width vectors are compared from the most
-- significant bit downwards, i.e. as unsigned numbers.
instance Ord BV where
  compare (BV xs) (BV ys) = comparing key xs ys
    where
      key v = (VG.length v, VG.reverse v)

-- | Concatenation: in @hi <> lo@ the left operand supplies the most
-- significant bits.  Storage is LSB-first, so @lo@'s bits come first.
instance Semigroup.Semigroup BV where
  (<>) (BV upper) (BV lower) = BV (lower VG.++ upper)

-- | The identity for concatenation is the zero-width bit-vector.
instance Monoid BV where
  mempty = BV VU.empty
#if !(MIN_VERSION_base(4,11,0))
  mappend = (Semigroup.<>)
#endif

-- | Renders as a binary literal, most significant bit first (e.g. @0b0101@).
instance Show BV where
  show v = "0b" ++ map digit (toDescBits v)
    where
      digit True = '1'
      digit False = '0'

-- | 'Bits' view of a concrete bit-vector; bit index 0 is the least
-- significant bit.  Shifts and rotations keep the width unchanged.
-- Out-of-range indices make 'testBit' return 'False' and the bit
-- updaters return the input unchanged.  'zeroBits', 'bit' and 'bitSize'
-- are not implemented and raise 'error'; 'bitSizeMaybe' is 'Nothing'.
instance Bits BV where
  (.&.) = bvand
  (.|.) = bvor
  xor = bvxor
  complement = bvnot

  shiftL v n =
    if n < sz
      then extract (sz - 1 - n) 0 v <> nat2bv n 0
      else nat2bv sz 0
    where
      sz = width v
  shiftR v n =
    if n < sz
      then nat2bv n 0 <> extract (sz - 1) n v
      else nat2bv sz 0
    where
      sz = width v

  -- Rotation amounts are reduced modulo the width; a zero-width vector is
  -- returned as is (and avoids a modulo by zero).
  rotateL v n
    | sz == 0 = v
    | otherwise = extract (sz - 1 - k) 0 v <> extract (sz - 1) (sz - k) v
    where
      sz = width v
      k = n `mod` sz
  rotateR v n
    | sz == 0 = v
    | otherwise = extract (k - 1) 0 v <> extract (sz - 1) k v
    where
      sz = width v
      k = n `mod` sz

  zeroBits = error "zeroBits is not implemented"

  bit = error "bit is not implemented"

  setBit v@(BV bits) n
    | n < 0 || n >= width v = v
    | otherwise = BV (bits VG.// [(n, True)])
  clearBit v@(BV bits) n
    | n < 0 || n >= width v = v
    | otherwise = BV (bits VG.// [(n, False)])
  complementBit v@(BV bits) n
    | n < 0 || n >= width v = v
    | otherwise = BV (bits VG.// [(n, not (bits VG.! n))])
  testBit v@(BV bits) n
    | n < 0 || n >= width v = False
    | otherwise = bits VG.! n

  popCount (BV bits) = VG.length (VG.filter id bits)

  bitSizeMaybe _ = Nothing
  bitSize _ = error "bitSize is not implemented"
  isSigned _ = False

-- | Concrete bit-vector operations.
--
-- Binary operations require both operands to have the same width and fail
-- with 'error' otherwise.  Unsigned division / remainder by zero also fail
-- with 'error' (and so do the signed variants built on them).
instance IsBV BV where
  width (BV bs) = VG.length bs
  extract i j (BV bs) = BV $ VG.slice j (i - j + 1) bs
  fromBV = id

  bvnot (BV bs) = BV $ VG.map not bs

  bvand (BV bs1) (BV bs2)
    | VG.length bs1 /= VG.length bs2 = error "width mismatch"
    | otherwise = BV $ VG.zipWith (&&) bs1 bs2
  bvor (BV bs1) (BV bs2)
    | VG.length bs1 /= VG.length bs2 = error "width mismatch"
    | otherwise = BV $ VG.zipWith (||) bs1 bs2
  bvxor (BV bs1) (BV bs2)
    | VG.length bs1 /= VG.length bs2 = error "width mismatch"
    | otherwise = BV $ VG.zipWith (/=) bs1 bs2

  -- 2^w - x, truncated to w bits by nat2bv (so negating 0 yields 0).
  bvneg x = nat2bv (width x) $ 2 ^ width x - bv2nat x

  bvadd x y
    | width x /= width y = error "invalid width"
    | otherwise = nat2bv (width x) (bv2nat x + bv2nat y)

  bvmul x y
    | width x /= width y = error "invalid width"
    | otherwise = nat2bv (width x) (bv2nat x * bv2nat y)

  bvudiv x y
    | width x /= width y = error "invalid width"
    | y' == 0 = error "division by zero"
    | otherwise = nat2bv (width x) (bv2nat x `div` y')
    where
      y' :: Integer
      y' = bv2nat y

  bvurem x y
    | width x /= width y = error "invalid width"
    | y' == 0 = error "division by zero"
    | otherwise = nat2bv (width x) (bv2nat x `mod` y')
    where
      y' :: Integer
      y' = bv2nat y

  -- Signed operations dispatch on the operands' sign bits and reduce to
  -- the unsigned ones on absolute values.
  bvsdiv x y
    | width x < 1 || width y < 1 || width x /= width y = error "invalid width"
    | not msb_x && not msb_y = bvudiv x y
    | msb_x && not msb_y = bvneg $ bvudiv (bvneg x) y
    | not msb_x && msb_y = bvneg $ bvudiv x (bvneg y)
    | otherwise = bvudiv (bvneg x) (bvneg y)
    where
      msb_x = testBit x (width x - 1)
      msb_y = testBit y (width y - 1)

  bvsrem x y
    | width x < 1 || width y < 1 || width x /= width y = error "invalid width"
    | not msb_x && not msb_y = bvurem x y
    | msb_x && not msb_y = bvneg $ bvurem (bvneg x) y
    | not msb_x && msb_y = bvurem x (bvneg y)
    | otherwise = bvneg $ bvurem (bvneg x) (bvneg y)
    where
      msb_x = testBit x (width x - 1)
      msb_y = testBit y (width y - 1)

  bvsmod x y
    | width x < 1 || width y < 1 || width x /= width y = error "invalid width"
    | bv2nat u == (0::Integer) = u
    | not msb_x && not msb_y = u
    | msb_x && not msb_y = bvadd (bvneg u) y
    | not msb_x && msb_y = bvadd u y
    | otherwise = bvneg u
    where
      msb_x = testBit x (width x - 1)
      msb_y = testBit y (width y - 1)
      abs_x = if msb_x then bvneg x else x
      abs_y = if msb_y then bvneg y else y
      u = bvurem abs_x abs_y

  -- The shift amount is an arbitrary w-bit unsigned value.  It is decoded
  -- as an 'Integer': decoding it as 'Int' can overflow to a negative count
  -- for w >= 64, and shifting by a huge amount would build an enormous
  -- 'Integer'.  Any amount >= w shifts every bit out, giving zero.
  bvshl x y
    | width x /= width y = error "invalid width"
    | s >= toInteger w = nat2bv w 0
    | otherwise = nat2bv w (bv2nat x `shiftL` fromInteger s)
    where
      w = width x
      s :: Integer
      s = bv2nat y

  bvlshr x y
    | width x /= width y = error "invalid width"
    | s >= toInteger w = nat2bv w 0
    | otherwise = nat2bv w (bv2nat x `shiftR` fromInteger s)
    where
      w = width x
      s :: Integer
      s = bv2nat y

  -- Arithmetic shift right.  For a negative operand, complement it, shift
  -- logically, and complement back, so that the vacated high bits are
  -- filled with ones.  Negation ('bvneg') would be wrong here: it rounds
  -- toward zero, e.g. it maps ashr 0b1111 by 1 to 0b0000 instead of 0b1111.
  bvashr x y
    | width x /= width y = error "invalid width"
    | not msb_x = bvlshr x y
    | otherwise = bvnot $ bvlshr (bvnot x) y
    where
      msb_x = testBit x (width x - 1)

  bvcomp x y
    | width x /= width y = error "invalid width"
    | otherwise = nat2bv 1 (if x==y then 1 else 0)

-- | Concrete (in)equality of bit-vectors, decided immediately.
instance IsEqRel BV Bool where
  a .==. b = a == b
  a ./=. b = a /= b

-- | Concrete comparisons on 'BV', returning plain 'Bool's.
--
-- Both operands must have the same width; a mismatch is reported with
-- 'error' for unsigned and signed comparisons alike.  The remaining
-- comparisons ('bvult', 'bvuge', 'bvslt', ...) come from the class
-- defaults and therefore inherit these checks.
instance BVComparison BV where
  type ComparisonResult BV = Bool

  -- Unsigned order.  The 'Ord' instance compares widths first, so without
  -- an explicit check, operands of different widths would silently be
  -- ordered by length instead of being rejected (as 'bvsle' does).
  bvule bs1 bs2
    | width bs1 /= width bs2 = error ("length mismatch: " ++ show (width bs1) ++ " and " ++ show (width bs2))
    | otherwise = bs1 <= bs2

  -- Signed order: an operand with its MSB set sorts below one with the
  -- MSB clear; with equal MSBs the unsigned order gives the answer.
  bvsle bs1 bs2
    | width bs1 /= width bs2 = error ("length mismatch: " ++ show (width bs1) ++ " and " ++ show (width bs2))
    | w == 0 = true
    | otherwise = bs1_msb && not bs2_msb || (bs1_msb == bs2_msb) && bs1 <= bs2
    where
      w = width bs1
      bs1_msb = testBit bs1 (w-1)
      bs2_msb = testBit bs2 (w-1)

-- | Interpret a bit-vector as an unsigned number (bit @i@ has weight @2^i@).
bv2nat :: Integral a => BV -> a
bv2nat (BV bits) = VG.ifoldl' step 0 bits
  where
    step acc idx isSet
      | isSet = acc + 2 ^ idx
      | otherwise = acc

-- | @nat2bv w n@ is the low @w@ bits of @n@ (in two's complement when @n@
-- is negative), as a bit-vector constant.
nat2bv :: IsBV a => Int -> Integer -> a
nat2bv w n = fromBV bits
  where
    bits = BV (VG.generate w (\k -> testBit n k))

-- | Build a constant from bits listed least significant first.
fromAscBits :: IsBV a => [Bool] -> a
fromAscBits bits = fromBV (BV (VU.fromList bits))

-- | Build a constant from bits listed most significant first.
fromDescBits :: IsBV a => [Bool] -> a
fromDescBits bits = fromBV (BV (VG.fromList (reverse bits)))

-- | The bits of a constant, least significant first.
toAscBits :: BV -> [Bool]
toAscBits (BV v) = VU.toList v

-- | The bits of a constant, most significant first.
toDescBits :: BV -> [Bool]
toDescBits (BV v) = VG.toList (VG.reverse v)

-- ------------------------------------------------------------------------

-- | A bit-vector variable.
data Var = Var
  { varWidth :: {-# UNPACK #-} !Int
    -- ^ number of bits
  , varId :: {-# UNPACK #-} !Int
    -- ^ identifier; 'evalExpr' uses it as an index into the model's
    -- vector of variable values
  }
  deriving (Eq, Ord, Show)

-- | Symbolic bit-vector expressions.
data Expr
  = EConst BV           -- ^ constant
  | EVar Var            -- ^ variable
  | EOp1 Op1 Expr       -- ^ unary operation
  | EOp2 Op2 Expr Expr  -- ^ binary operation
  deriving (Eq, Ord, Show)

-- | Unary operators of the bit-vector language.
data Op1
  = OpExtract !Int !Int  -- ^ @OpExtract i j@: bits @i@ down to @j@ (see 'extract')
  | OpNot                -- ^ bitwise complement ('bvnot')
  | OpNeg                -- ^ two's complement negation ('bvneg')
  deriving (Eq, Ord, Show)

-- | Binary operators of the bit-vector language.  Each corresponds to the
-- 'IsBV' method of the same name (concatenation corresponds to '<>').
data Op2
  = OpConcat  -- ^ concatenation; the first operand supplies the upper bits
  | OpAnd     -- ^ bitwise and
  | OpOr      -- ^ bitwise or
  | OpXOr     -- ^ bitwise exclusive or
  | OpComp    -- ^ equality as a 1-bit vector (1 iff equal)
  | OpAdd     -- ^ addition
  | OpMul     -- ^ multiplication
  | OpUDiv    -- ^ unsigned division
  | OpURem    -- ^ unsigned remainder
  | OpSDiv    -- ^ signed division
  | OpSRem    -- ^ signed remainder
  | OpSMod    -- ^ signed modulus
  | OpShl     -- ^ shift left
  | OpLShr    -- ^ logical shift right
  | OpAShr    -- ^ arithmetic shift right
  deriving (Eq, Ord, Enum, Bounded, Show)

-- | Symbolic bit-vector operations: every operation just builds the
-- corresponding expression node.  'width' is computed structurally.
instance IsBV Expr where
  width e =
    case e of
      EConst c -> width c
      EVar v -> varWidth v
      EOp1 (OpExtract hi lo) _ -> hi - lo + 1
      EOp1 _ a -> width a
      EOp2 OpConcat a b -> width a + width b
      EOp2 OpComp _ _ -> 1
      EOp2 _ a _ -> width a

  extract hi lo e = EOp1 (OpExtract hi lo) e

  fromBV c = EConst c

  bvnot e = EOp1 OpNot e
  bvand a b = EOp2 OpAnd a b
  bvor a b = EOp2 OpOr a b
  bvxor a b = EOp2 OpXOr a b

  bvneg e = EOp1 OpNeg e
  bvadd a b = EOp2 OpAdd a b
  bvmul a b = EOp2 OpMul a b
  bvudiv a b = EOp2 OpUDiv a b
  bvurem a b = EOp2 OpURem a b
  bvsdiv a b = EOp2 OpSDiv a b
  bvsrem a b = EOp2 OpSRem a b
  bvsmod a b = EOp2 OpSMod a b
  bvshl a b = EOp2 OpShl a b
  bvlshr a b = EOp2 OpLShr a b
  bvashr a b = EOp2 OpAShr a b
  bvcomp a b = EOp2 OpComp a b

-- | Concatenation builds an 'OpConcat' node; the left operand supplies the
-- upper bits.
instance Semigroup.Semigroup Expr where
  e1 <> e2 = EOp2 OpConcat e1 e2

-- | The identity is the zero-width constant.
instance Monoid Expr where
  mempty = fromBV mempty
#if !(MIN_VERSION_base(4,11,0))
  mappend = (Semigroup.<>)
#endif

-- | 'Bits' view of symbolic expressions, built from 'extract' and
-- concatenation; bit index 0 is the least significant bit.  Operations
-- that need a concrete answer ('testBit', 'popCount', 'zeroBits', 'bit',
-- 'bitSize') are not implemented and raise 'error'.
instance Bits Expr where
  (.&.) = bvand
  (.|.) = bvor
  xor = bvxor
  complement = bvnot
  shiftL e n =
    if n < sz
      then extract (sz - 1 - n) 0 e <> nat2bv n 0
      else nat2bv sz 0
    where
      sz = width e
  shiftR e n =
    if n < sz
      then nat2bv n 0 <> extract (sz - 1) n e
      else nat2bv sz 0
    where
      sz = width e

  -- Rotation amounts are reduced modulo the width; a zero-width expression
  -- is returned as is (and avoids a modulo by zero).
  rotateL e n
    | sz == 0 = e
    | otherwise = extract (sz - 1 - k) 0 e <> extract (sz - 1) (sz - k) e
    where
      sz = width e
      k = n `mod` sz
  rotateR e n
    | sz == 0 = e
    | otherwise = extract (k - 1) 0 e <> extract (sz - 1) k e
    where
      sz = width e
      k = n `mod` sz

  zeroBits = error "zeroBits is not implemented"

  bit = error "bit is not implemented"

  -- Bit updates rebuild the expression as upper <> (new bit <> lower);
  -- out-of-range indices leave the expression unchanged.
  setBit e n
    | n < 0 || n >= sz = e
    | otherwise = upper <> fromDescBits [True] <> lower
    where
      sz = width e
      upper = extract (sz - 1) (n + 1) e
      lower = extract (n - 1) 0 e

  clearBit e n
    | n < 0 || n >= sz = e
    | otherwise = upper <> fromDescBits [False] <> lower
    where
      sz = width e
      upper = extract (sz - 1) (n + 1) e
      lower = extract (n - 1) 0 e

  complementBit e n
    | n < 0 || n >= sz = e
    | otherwise = upper <> bvnot (extract n n e) <> lower
    where
      sz = width e
      upper = extract (sz - 1) (n + 1) e
      lower = extract (n - 1) 0 e

  testBit = error "testBit is not implemented"

  popCount = error "popCount is not implemented"

  bitSizeMaybe _ = Nothing
  bitSize _ = error "bitSize is not implemented"
  isSigned _ = False

-- | A comparison between two expressions.  The 'Bool' flag selects signed
-- ('True') or unsigned ('False') interpretation of the operands.
data Atom
  = Rel (OrdRel Expr) Bool
  deriving (Eq, Ord, Show)

-- | Negates the relation and keeps the signedness flag.
instance Complement Atom where
  notB (Rel r isSigned) = Rel (notB r) isSigned

-- | Symbolic (in)equality; signedness is irrelevant, so the flag is 'False'.
instance IsEqRel Expr Atom where
  lhs .==. rhs = Rel (lhs .==. rhs) False
  lhs ./=. rhs = Rel (lhs ./=. rhs) False

-- | Symbolic comparisons produce 'Atom's tagged with their signedness.
instance BVComparison Expr where
  type ComparisonResult Expr = Atom

  bvule s t = unsignedAtom (s .<=. t)
  bvult s t = unsignedAtom (s .<. t)
  bvuge s t = unsignedAtom (s .>=. t)
  bvugt s t = unsignedAtom (s .>. t)
  bvsle s t = signedAtom (s .<=. t)
  bvslt s t = signedAtom (s .<. t)
  bvsge s t = signedAtom (s .>=. t)
  bvsgt s t = signedAtom (s .>. t)

-- | Wrap a relation as an unsigned comparison atom.
unsignedAtom :: OrdRel Expr -> Atom
unsignedAtom rel = Rel rel False

-- | Wrap a relation as a signed comparison atom.
signedAtom :: OrdRel Expr -> Atom
signedAtom rel = Rel rel True

-- ------------------------------------------------------------------------

-- | Interpretation used by 'evalExpr' / 'evalAtom'.
type Model =
  ( V.Vector BV  -- variable values, indexed by 'varId'
  , Map BV BV    -- unsigned division by zero: dividend -> result
  , Map BV BV    -- unsigned remainder by zero: dividend -> result
  )

-- | Evaluate an expression to a concrete 'BV' under a 'Model'.
--
-- Variables are looked up in the model's vector by 'varId'; an
-- out-of-range id is an error.
--
-- Unsigned division and remainder by zero do not fail.  The result is
-- taken from the model's division / remainder table, keyed by the
-- dividend.  If the key is absent, it defaults to an all-zero vector of
-- the dividend's width.
--
-- 'OpSDiv', 'OpSRem' and 'OpSMod' are computed through the unsigned
-- division / remainder cases here, so they share this behaviour.
evalExpr :: Model -> Expr -> BV
evalExpr (env, divTable, remTable) = f
  where
    f (EConst bv) = bv
    f (EVar v) = env VG.! varId v
    f (EOp1 op x) = evalOp1 op (f x)
    f (EOp2 op x y) = evalOp2 op (f x) (f y)

    evalOp1 (OpExtract i j) = extract i j
    evalOp1 OpNot = bvnot
    evalOp1 OpNeg = bvneg

    evalOp2 OpConcat a b = a <> b
    evalOp2 OpAnd x y = bvand x y
    evalOp2 OpOr x y = bvor x y
    evalOp2 OpXOr x y = bvxor x y
    evalOp2 OpAdd x y = bvadd x y
    evalOp2 OpMul x y = bvmul x y
    evalOp2 OpUDiv x y
      | y' /= 0 = bvudiv x y
      | otherwise =
          case Map.lookup x divTable of
            Just d -> d
            Nothing -> nat2bv (width x) 0
      where
        y' :: Integer
        y' = bv2nat y
    evalOp2 OpURem x y
      | y' /= 0 = bvurem x y
      | otherwise =
          case Map.lookup x remTable of
            Just r -> r
            Nothing -> nat2bv (width x) 0
      where
        y' :: Integer
        y' = bv2nat y
    -- The signed cases mirror 'bvsdiv' / 'bvsrem' / 'bvsmod' but route
    -- through the zero-divisor-aware unsigned cases above.
    evalOp2 OpSDiv x y
      | width x < 1 || width y < 1 || width x /= width y = error "invalid width"
      | not msb_x && not msb_y = evalOp2 OpUDiv x y
      | msb_x && not msb_y = bvneg $ evalOp2 OpUDiv (bvneg x) y
      | not msb_x && msb_y = bvneg $ evalOp2 OpUDiv x (bvneg y)
      | otherwise = evalOp2 OpUDiv (bvneg x) (bvneg y)
      where
        msb_x = testBit x (width x - 1)
        msb_y = testBit y (width y - 1)
    evalOp2 OpSRem x y
      | width x < 1 || width y < 1 || width x /= width y = error "invalid width"
      | not msb_x && not msb_y = evalOp2 OpURem x y
      | msb_x && not msb_y = bvneg $ evalOp2 OpURem (bvneg x) y
      | not msb_x && msb_y = evalOp2 OpURem x (bvneg y)
      | otherwise = bvneg $ evalOp2 OpURem (bvneg x) (bvneg y)
      where
        msb_x = testBit x (width x - 1)
        msb_y = testBit y (width y - 1)
    evalOp2 OpSMod x y
      | width x < 1 || width y < 1 || width x /= width y = error "invalid width"
      | bv2nat u == (0::Integer) = u
      | not msb_x && not msb_y = u
      | msb_x && not msb_y = bvadd (bvneg u) y
      | not msb_x && msb_y = bvadd u y
      | otherwise = bvneg u
      where
        msb_x = testBit x (width x - 1)
        msb_y = testBit y (width y - 1)
        abs_x = if msb_x then bvneg x else x
        abs_y = if msb_y then bvneg y else y
        u = evalOp2 OpURem abs_x abs_y
    evalOp2 OpShl x y = bvshl x y
    evalOp2 OpLShr x y = bvlshr x y
    evalOp2 OpAShr x y = bvashr x y
    -- Delegate to 'bvcomp' so that a width mismatch is reported, like for
    -- the other width-checked operators, instead of silently comparing as
    -- "not equal".
    evalOp2 OpComp x y = bvcomp x y

-- | Evaluate an atom under a model.  Unsigned atoms use the generic
-- relation evaluator 'evalOp'.  Signed atoms use 'bvslt' / 'bvsle' on the
-- evaluated operands (swapping them for @>@ and @>=@).
evalAtom :: Model -> Atom -> Bool
evalAtom m (Rel (OrdRel lhs op rhs) isSigned)
  | not isSigned = evalOp op l r
  | otherwise =
      case op of
        Lt -> bvslt l r
        Gt -> bvslt r l
        Le -> bvsle l r
        Ge -> bvsle r l
        Eql -> l == r
        NEq -> l /= r
  where
    l = evalExpr m lhs
    r = evalExpr m rhs