{-# LANGUAGE StrictData #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Crypto.Hash.MerkleTree (
  MerkleTree(..),
  MerkleRoot(..),
  MerkleNode(..),

  -- ** Constructors
  mkMerkleTree,
  mkRootHash,
  mkLeafRootHash,
  emptyHash,

  -- ** Merkle Proof
  MerkleProof(..),
  merkleProof,
  validateMerkleProof,

  -- ** Size
  mtRoot,
  mtSize,
  mtHash,
  mtHeight,

  -- ** Testing
  testMerkleProofN,
) where

import Protolude hiding (hash)

import Crypto.Hash (Digest, SHA3_256(..), hash)

import qualified Data.List as List
import qualified Data.Serialize as S
import qualified Data.ByteArray as B
import qualified Data.ByteArray.Encoding as B
import qualified Data.ByteString as BS

import System.Random (randomRIO)

-------------------------------------------------------------------------------
-- Types
-------------------------------------------------------------------------------

-- | A merkle tree root.
newtype MerkleRoot a = MerkleRoot
  { getMerkleRoot :: ByteString
  } deriving (Show, Eq, Ord,  Generic, S.Serialize)

instance B.ByteArrayAccess (MerkleRoot a) where
  length (MerkleRoot bs) = B.length bs
  withByteArray (MerkleRoot bs) f = B.withByteArray bs f

-- | A merkle tree.
data MerkleTree a
  = MerkleEmpty
  | MerkleTree Word32 (MerkleNode a)
  deriving (Show, Eq, Generic, S.Serialize)

data MerkleNode a
  = MerkleBranch {
      mRoot  :: MerkleRoot a
    , mLeft  :: MerkleNode a
    , mRight :: MerkleNode a
  }
  | MerkleLeaf {
      mRoot :: MerkleRoot a
    , mVal  :: a
  }
  deriving (Eq, Show, Generic, S.Serialize)

instance Foldable MerkleTree where
  foldMap _ MerkleEmpty      = mempty
  foldMap f (MerkleTree _ n) = foldMap f n

  null MerkleEmpty = True
  null _           = False

  length MerkleEmpty      = 0
  length (MerkleTree s _) = fromIntegral s

instance Foldable MerkleNode where
  foldMap f x = case x of
    MerkleLeaf{mVal}            -> f mVal
    MerkleBranch{mLeft, mRight} ->
      foldMap f mLeft `mappend` foldMap f mRight

-- | Returns root of merkle tree.
mtRoot :: MerkleTree a -> MerkleRoot a
mtRoot MerkleEmpty      = emptyHash
mtRoot (MerkleTree _ x) = mRoot x

-- | Returns root of merkle tree root hashed.
mtHash :: MerkleTree a -> ByteString
mtHash MerkleEmpty      = merkleHash ""
mtHash (MerkleTree _ x) = B.convert (mRoot x)

mtSize :: MerkleTree a -> Word32
mtSize MerkleEmpty      = 0
mtSize (MerkleTree s _) = s

emptyHash :: MerkleRoot a
emptyHash = MerkleRoot (merkleHash mempty)

-- | Merkle tree height
mtHeight :: Int -> Int
mtHeight ntx
  | ntx < 2 = 0
  | even ntx  = 1 + mtHeight (ntx `div` 2)
  | otherwise = mtHeight $ ntx + 1

-- | Merkle tree width
mtWidth
  :: Int -- ^ Number of transactions (leaf nodes).
  -> Int -- ^ Height at which we want to compute the width.
  -> Int -- ^ Width of the merkle tree.
mtWidth ntx h = (ntx + (1 `shiftL` h) - 1) `shiftR` h

-- | Return the largest power of two such that it's smaller than n.
powerOfTwo :: (Bits a, Num a) => a -> a
powerOfTwo n
   | n .&. (n - 1) == 0 = n `shiftR` 1
   | otherwise = go n
 where
    go w = if w .&. (w - 1) == 0 then w else go (w .&. (w - 1))

-------------------------------------------------------------------------------
-- Constructors
-------------------------------------------------------------------------------

mkLeaf :: ByteString -> MerkleNode ByteString
mkLeaf a =
  MerkleLeaf
  { mVal  = a
  , mRoot = mkLeafRootHash a
  }

mkLeafRootHash :: B.ByteArrayAccess a => a -> MerkleRoot a
mkLeafRootHash a = MerkleRoot $ merkleHash (BS.singleton 0 <> B.convert a)

mkBranch :: MerkleNode a -> MerkleNode a -> MerkleNode a
mkBranch a b =
  MerkleBranch
  { mLeft  = a
  , mRight = b
  , mRoot  = mkRootHash (mRoot a) (mRoot b)
  }

mkRootHash :: MerkleRoot a -> MerkleRoot a -> MerkleRoot a
mkRootHash (MerkleRoot l) (MerkleRoot r) = MerkleRoot $ merkleHash $ mconcat
  [ BS.singleton 1, B.convert l, B.convert r ]

-- | Smart constructor for 'MerkleTree'.
mkMerkleTree :: [ByteString] -> MerkleTree ByteString
mkMerkleTree [] = MerkleEmpty
mkMerkleTree ls = MerkleTree (fromIntegral lsLen) (go lsLen ls)
  where
    lsLen = length ls
    go _  [x] = mkLeaf x
    go len xs = mkBranch (go i l) (go (len - i) r)
      where
        i = powerOfTwo len
        (l, r) = splitAt i xs

-------------------------------------------------------------------------------
-- Merkle Proofs
-------------------------------------------------------------------------------

newtype MerkleProof a = MerkleProof { getMerkleProof :: [ProofElem a] }
  deriving (Show, Eq, Ord, Generic, S.Serialize)

data ProofElem a = ProofElem
  { nodeRoot    :: MerkleRoot a
  , siblingRoot :: MerkleRoot a
  , nodeSide    :: Side
  } deriving (Show, Eq, Ord, Generic, S.Serialize)

data Side = L | R
  deriving (Show, Eq, Ord, Generic, S.Serialize)

-- | Construct a merkle tree proof of inclusion
-- Walks the entire tree recursively, building a list of "proof elements"
-- that are comprised of the current node's root and it's sibling's root,
-- and whether it is the left or right sibling (this is necessary to determine
-- the order in which to hash each proof element root and it's sibling root).
-- The list is ordered such that the for each element, the next element in
-- the list is the proof element corresponding to the node's parent node.
merkleProof :: forall a. MerkleTree a -> MerkleRoot a -> MerkleProof a
merkleProof MerkleEmpty _ = MerkleProof []
merkleProof (MerkleTree _ rootNode) leafRoot = MerkleProof $ constructPath [] rootNode
  where
    constructPath :: [ProofElem a] -> MerkleNode a -> [ProofElem a]
    constructPath pElems (MerkleLeaf leafRoot' _)
      | leafRoot == leafRoot' = pElems
      | otherwise             = []
    constructPath pElems (MerkleBranch bRoot ln rn) = lPath ++ rPath
      where
        lProofElem = ProofElem (mRoot ln) (mRoot rn) L
        rProofElem = ProofElem (mRoot rn) (mRoot ln) R

        lPath = constructPath (lProofElem:pElems) ln
        rPath = constructPath (rProofElem:pElems) rn

-- | Validate a merkle tree proof of inclusion
validateMerkleProof :: forall a. MerkleProof a ->  MerkleRoot a -> MerkleRoot a -> Bool
validateMerkleProof (MerkleProof proofElems) treeRoot leafRoot =
    validate proofElems leafRoot
  where
    validate :: [ProofElem a] -> MerkleRoot a -> Bool
    validate [] proofRoot = proofRoot == treeRoot
    validate (pElem:pElems) proofRoot
      | proofRoot /= nodeRoot pElem = False
      | otherwise = validate pElems $ hashProofElem pElem

    hashProofElem :: ProofElem a -> MerkleRoot a
    hashProofElem (ProofElem pRoot sibRoot side) =
      case side of
        L -> mkRootHash pRoot sibRoot
        R -> mkRootHash sibRoot pRoot

-------------------------------------------------------------------------------
-- Hashing
-------------------------------------------------------------------------------

-- | Compute SHA-256 hash of a bytestring.
-- Maximum input size is (2^{64}-1)/8 bytes.
--
-- > Output size         : 256
-- > Internal state size : 1600
-- > Block size          : 1088
-- > Length size         : n/a
-- > Word size           : 64
-- > Rounds              : 24
sha256 :: ByteString -> ByteString
sha256 x = B.convertToBase B.Base16 (hash x :: Digest SHA3_256)

-- | Hash function to use for merkle tree
merkleHash :: ByteString -> ByteString
merkleHash = sha256

-------------------------------------------------------------------------------
-- Testing
-------------------------------------------------------------------------------

-- | Constructs a merkle tree and random leaf root to test inclusion of
testMerkleProofN :: Int -> IO Bool
testMerkleProofN n
  | n < 2 = panic "Cannot construct a merkle tree with < 2 nodes"
  | otherwise = do
      randN <- randomRIO (1,n) :: IO Int
      let mtree = mkMerkleTree $ map show [1..n]
          randLeaf = mkLeafRootHash $ show randN
          proof = merkleProof mtree randLeaf
      return $ validateMerkleProof proof (mtRoot mtree) randLeaf