{-# LANGUAGE DeriveDataTypeable #-}
-- | Arithmetic taint tracking.
--
-- Extension of Word32 that allows "U"nknown values to be used. Operations on
-- known values are calculated as usual, while operations on unknowns are
-- recorded symbolically (with some algebraic simplification).
--
-- >>> let e = 1+5 `xor` U `lshift` 2
-- >>> e
-- X Xor (N 6) (X LShift U (N 2))
-- >>> pprint e
-- "(6^(?<<2))"
--
-- >>> 1+5 `xor` 0xBEEF `lshift` 2
-- N 195514
--
-- Intended for analysis of XSHA-1, so only the necessary primitives are
-- supported.
module Numeric.Taint.Word32
    ( N(..), Op(..)
    , pprint, lshift
    ) where

import Data.Bits
import Data.Word
-- expression optimization:
import Data.Data
import Data.Generics.Uniplate.Data
import Data.Generics.Uniplate.Operations

-- | A numeric type extended to hold information about unknown values.
data N = N Word32  -- ^ A known 32-bit constant.
       | U         -- ^ An unknown value (pretty-printed as @?@).
       | X Op N N  -- ^ An operation applied to two operands, kept
                   --   symbolically instead of being evaluated.
       deriving (Eq,Ord,Show,Data,Typeable)

-- | Operators that can be recorded symbolically in an expression node.
-- Constructor order matters: it determines the derived 'Ord' instance.
data Op
    = LShift  -- ^ left shift
    | Xor     -- ^ bitwise exclusive or
    | Or      -- ^ bitwise or
    | And     -- ^ bitwise and
    | Add     -- ^ addition (wrapping, as for Word32)
    | Sub     -- ^ subtraction (wrapping, as for Word32)
    | Rot     -- ^ rotation
    deriving (Eq, Ord, Show, Data, Typeable)

pprint :: N -> String
-- ^ Render an expression in fully parenthesised infix notation.
-- Constants print in decimal, unknowns print as @?@, and every operation
-- node prints as @(lhs OP rhs)@.
pprint expr = case expr of
    N w      -> show w
    U        -> "?"
    X op l r -> "(" ++ pprint l ++ opSymbol op ++ pprint r ++ ")"
  where
    opSymbol o = case o of
        Rot    -> "<>"
        Add    -> "+"
        Sub    -> "-"
        Xor    -> "^"
        Or     -> "|"
        And    -> "&"
        LShift -> "<<"

lshift :: N -> N -> N
-- ^ Left shift whose amount may itself be unknown.
-- 'shift' from the Bits class only accepts an 'Int' amount, which cannot
-- represent an unknown, so a separate function is needed.
-- Visual C++ behaviour is followed: the shift amount of a 32-bit unsigned
-- number is taken modulo 32 (the C standard leaves shifts by 32 or more
-- undefined).
lshift value amount = lift LShift value amount

-- | Partial instance: only the operations needed for XSHA-1 are tracked
-- symbolically. Addition and subtraction work on any operands. The other
-- methods are exact only for known constants.
instance Num N where
    (+) = lift Add
    (-) = lift Sub
    fromInteger = N . fromInteger
    -- 'Op' has no multiplication, so only products of known constants can
    -- be represented.
    N a * N b = N (a * b)
    _ * _ = error "Numeric.Taint.Word32: (*) is not supported on unknown values"
    -- Word32 is unsigned, so abs is the identity.
    abs x = x
    -- Exact for constants (signum 0 == 0). For expressions containing
    -- unknowns the value is assumed to be non-zero.
    -- NOTE(review): confirm this approximation is acceptable.
    signum (N a) = N (signum a)
    signum _ = N 1

-- | Partial instance: only the bit operations needed for XSHA-1 are
-- supported.
instance Bits N where
    (.&.) = lift And
    (.|.) = lift Or
    xor = lift Xor
    -- The complement of a 32-bit value is XOR with all ones.
    complement (N a) = N (complement a)
    complement n = X Xor (N 0xFFFFFFFF) n
    -- An Int shift amount cannot express unknowns; use 'lshift' instead.
    shift _ _ = error "Numeric.Taint.Word32: shift is not supported, use lshift"
    rotate (N a) i = N $ rotate a i
    -- Word32 rotation wraps modulo 32, so the recorded amount is normalised
    -- to 0..31. For example, rotateR by 5 is stored as a rotation by 27
    -- instead of 4294967291.
    rotate a i = X Rot a (N (fromIntegral (i `mod` 32)))
    bitSize _ = 32
    isSigned _ = False

lift :: Op -> N -> N -> N
-- ^ Build the node @X op lhs rhs@ and simplify it straight away with
-- 'optimize'.
lift op lhs rhs =
    let node = X op lhs rhs
    in optimize node

optimize :: N -> N
-- ^ Simplify an algebraic expression.
-- The rules below are applied repeatedly (via 'rewrite') until none of them
-- matches anywhere in the expression.
optimize = rewrite f where
    f :: N -> Maybe N
    -- C standard doesn't define shifts greater than or equal to the number
    -- of bits. Visual C++ seems to take the shift amount modulo 32.
    f (X LShift a (N i)) | i > 31 = Just $ X LShift a (N (i `mod` 32))
    -- trivial integer ops
    f (X op (N a) (N b)) = Just $ N $ eval op a b
    -- Merge constants through nested uses of the same operator. These rules
    -- reassociate *and* swap operands, so they are only sound for operators
    -- that are both associative and commutative. Sub, LShift and Rot are
    -- neither: e.g. 5-(3-x) is 2+x, not 2-x.
    f (X op (N a) (X op' (N b) c)) | op == op', assocComm op =
        Just $ X op (N $ eval op a b) c
    f (X op (X op' (N b) c) (N a)) | op == op', assocComm op =
        Just $ X op (N $ eval op a b) c
    f (X op (N a) (X op' c (N b))) | op == op', assocComm op =
        Just $ X op (N $ eval op a b) c
    f (X op (X op' c (N b)) (N a)) | op == op', assocComm op =
        Just $ X op (N $ eval op a b) c
    -- Rotations by constants compose by adding their amounts. Word32
    -- rotation wraps modulo 32, and masking the wrapped sum keeps that.
    f (X Rot (X Rot c (N b)) (N a)) = Just $ X Rot c (N ((a + b) .&. 31))
    -- xor fixed point
    f (X Xor (N 0) a) = Just a
    f (X Xor a (N 0)) = Just a
    -- 1 << (31 & n) depends only on the low 5 bits of n. For every operator
    -- except Rot, the low 5 bits of a result depend only on the low 5 bits
    -- of its operands (shift amounts are taken modulo 32). So every
    -- constant inside n can be masked to 5 bits. Rot moves high bits into
    -- the low ones, so expressions containing a rotation are left alone.
    f (X LShift (N 1) (X And 31 n))
        | any (> 31) (childrenBi n :: [Word32])
        , not (any isRot (universe n))
        = Just (X LShift (N 1) (X And 31 (transformBi low5 n)))
    -- leave all else unchanged
    f _ = Nothing

    -- operators for which the constant-merging rules above are valid
    assocComm :: Op -> Bool
    assocComm o = o `elem` [Add, And, Or, Xor]

    isRot :: N -> Bool
    isRot (X Rot _ _) = True
    isRot _ = False

    low5 :: Word32 -> Word32
    low5 = (.&. 31)

    -- evaluate an operation on two known operands
    eval :: Op -> Word32 -> Word32 -> Word32
    eval Add = (+)
    eval Sub = (-)
    eval And = (.&.)
    eval Or = (.|.)
    eval Xor = xor
    -- Shift amounts are taken modulo 32, consistent with the first rule.
    -- That rule already reduces constant amounts above 31 before this runs.
    eval LShift = \a b -> a `shiftL` fromIntegral (b .&. 31)
    eval Rot = \a b -> rotate a (fromIntegral b)