{-# LANGUAGE StrictData #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Crypto.Hash.MerkleTree (
MerkleTree(..),
MerkleRoot(..),
MerkleNode(..),
mkMerkleTree,
mkRootHash,
mkLeafRootHash,
emptyHash,
MerkleProof(..),
merkleProof,
validateMerkleProof,
mtRoot,
mtSize,
mtHash,
mtHeight,
testMerkleProofN,
) where
import Protolude hiding (hash)
import Crypto.Hash (Digest, SHA3_256(..), hash)
import qualified Data.List as List
import qualified Data.Serialize as S
import qualified Data.ByteArray as B
import qualified Data.ByteArray.Encoding as B
import qualified Data.ByteString as BS
import System.Random (randomRIO)
-- | Hash of a Merkle tree node.
--
-- Every value built inside this module comes from 'merkleHash', so the bytes
-- are hex (Base16) text rather than a raw digest. The constructor is
-- exported, however, so callers can wrap arbitrary bytes.
--
-- The type parameter @a@ is a phantom tag for the leaf value type; it does
-- not appear in the representation.
--
-- 'S.Serialize' is derived with DeriveAnyClass, i.e. from the class's
-- Generic-based defaults.
newtype MerkleRoot a = MerkleRoot
  { getMerkleRoot :: ByteString
    -- ^ The stored hash bytes.
  } deriving (Show, Eq, Ord, Generic, S.Serialize)
-- | View a 'MerkleRoot' as the bytes held in 'getMerkleRoot'.
--
-- This lets a root be passed directly to @memory@ and @cryptonite@
-- functions such as 'B.convert'.
instance B.ByteArrayAccess (MerkleRoot a) where
  length = B.length . getMerkleRoot
  withByteArray = B.withByteArray . getMerkleRoot
-- | A Merkle tree over leaf values of type @a@.
--
-- The non-empty form stores the leaf count next to the root node, so
-- 'mtSize' and 'length' (from the 'Foldable' instance) do not walk the tree.
--
-- 'S.Serialize' is derived from its Generic-based defaults, so reordering
-- constructors or fields changes the encoding.
data MerkleTree a
  = MerkleEmpty
    -- ^ A tree with no leaves; its root is 'emptyHash'.
  | MerkleTree Word32 (MerkleNode a)
    -- ^ Number of leaves, and the root node.
  deriving (Show, Eq, Generic, S.Serialize)
-- | A node of a Merkle tree; every node caches its own hash in 'mRoot'.
--
-- StrictData is on, so all fields are strict. Evaluating a node to WHNF
-- therefore forces its whole subtree, including each leaf's 'mVal' to WHNF.
data MerkleNode a
  = MerkleBranch {
      mRoot :: MerkleRoot a
      -- ^ Within this module, set by 'mkBranch' to 'mkRootHash' of the
      -- children's roots (left first).
    , mLeft :: MerkleNode a
    , mRight :: MerkleNode a
    }
  | MerkleLeaf {
      mRoot :: MerkleRoot a
      -- ^ Within this module, set by 'mkLeaf' to 'mkLeafRootHash' of 'mVal'.
    , mVal :: a
    }
  deriving (Eq, Show, Generic, S.Serialize)
-- | Folds visit the leaf values of the root node.
--
-- 'null' is answered from the constructor and 'length' from the stored
-- leaf count, so neither walks the tree.
instance Foldable MerkleTree where
  foldMap f tree = case tree of
    MerkleEmpty -> mempty
    MerkleTree _ node -> foldMap f node
  null tree = case tree of
    MerkleEmpty -> True
    MerkleTree _ _ -> False
  length tree = case tree of
    MerkleEmpty -> 0
    MerkleTree leafCount _ -> fromIntegral leafCount
-- | Folds visit leaf values left to right: the left subtree, then the right.
instance Foldable MerkleNode where
  foldMap f (MerkleLeaf _ val) = f val
  foldMap f (MerkleBranch _ lhs rhs) = foldMap f lhs <> foldMap f rhs
-- | Root hash of a tree; an empty tree yields 'emptyHash'.
mtRoot :: MerkleTree a -> MerkleRoot a
mtRoot tree = case tree of
  MerkleEmpty -> emptyHash
  MerkleTree _ node -> mRoot node
-- | Root hash of a tree as plain bytes.
--
-- For an empty tree this is 'merkleHash' of the empty string, the same bytes
-- 'emptyHash' wraps.
mtHash :: MerkleTree a -> ByteString
mtHash tree = case tree of
  MerkleEmpty -> merkleHash ""
  MerkleTree _ node -> B.convert (mRoot node)
-- | Number of leaves recorded in the tree (0 for 'MerkleEmpty').
mtSize :: MerkleTree a -> Word32
mtSize tree = case tree of
  MerkleEmpty -> 0
  MerkleTree leafCount _ -> leafCount
-- | Root of the empty tree: 'merkleHash' of the empty byte string.
--
-- Unlike leaf and branch hashes, its input carries no 0x00/0x01 prefix byte.
emptyHash :: MerkleRoot a
emptyHash = MerkleRoot $ merkleHash BS.empty
-- | Ceiling of the base-2 logarithm of @ntx@ for @ntx >= 1@, and 0 for any
-- @ntx < 2@.
--
-- This equals the depth of the tree 'mkMerkleTree' builds for @ntx@ leaves.
-- Odd counts are bumped to the next even count before halving; for odd
-- @ntx >= 3@ that leaves the ceiling unchanged.
mtHeight :: Int -> Int
mtHeight = climb 0
  where
    climb levels count
      | count < 2 = levels
      | odd count = climb levels (count + 1)
      | otherwise = climb (levels + 1) (count `div` 2)
-- | Number of groups of @2^h@ items needed to cover @ntx@ items.
--
-- For non-negative @ntx@ (and no overflow) this is @ceiling (ntx / 2^h)@.
-- It is not exported and not referenced elsewhere in the visible module.
mtWidth
  :: Int -- ^ item count
  -> Int -- ^ height @h@; each group holds @2^h@ items
  -> Int
mtWidth ntx h = (ntx + groupSize - 1) `shiftR` h
  where
    groupSize = 1 `shiftL` h
-- | Largest power of two strictly below @n@, for @n >= 1@; returns 0 for
-- @n@ of 0 or 1.
--
-- When @n@ is itself a power of two the answer is @n / 2@. Otherwise the
-- lowest set bit is cleared repeatedly until only the highest remains.
-- 'mkMerkleTree' uses this as the size of a branch's left subtree.
powerOfTwo :: (Bits a, Num a) => a -> a
powerOfTwo n
  | atMostOneBit n = n `shiftR` 1
  | otherwise = highestBit n
  where
    dropLowestBit w = w .&. (w - 1)
    atMostOneBit w = dropLowestBit w == 0
    highestBit w
      | atMostOneBit w = w
      | otherwise = highestBit (dropLowestBit w)
-- | Wrap a value in a leaf node whose cached root is its 'mkLeafRootHash'.
mkLeaf :: ByteString -> MerkleNode ByteString
mkLeaf val = MerkleLeaf (mkLeafRootHash val) val
-- | Hash of a leaf: 'merkleHash' over a 0x00 byte followed by the value's
-- bytes.
--
-- Branch hashes use a 0x01 prefix instead (see 'mkRootHash'), so a leaf hash
-- input can never coincide with a branch hash input.
mkLeafRootHash :: B.ByteArrayAccess a => a -> MerkleRoot a
mkLeafRootHash val = MerkleRoot (merkleHash (BS.cons 0 (B.convert val)))
-- | Join two subtrees under a new branch.
--
-- The branch's cached root is 'mkRootHash' of the left child's root, then
-- the right child's root.
mkBranch :: MerkleNode a -> MerkleNode a -> MerkleNode a
mkBranch lhs rhs = MerkleBranch (mkRootHash (mRoot lhs) (mRoot rhs)) lhs rhs
-- | Hash of a branch: 'merkleHash' over a 0x01 byte, then the left root's
-- bytes, then the right root's bytes.
--
-- The roots are concatenated as stored. Within this module that is the hex
-- text produced by 'merkleHash', not raw digest bytes.
mkRootHash :: MerkleRoot a -> MerkleRoot a -> MerkleRoot a
mkRootHash (MerkleRoot lhs) (MerkleRoot rhs) =
  MerkleRoot . merkleHash $ BS.cons 1 (lhs <> rhs)
-- | Build a tree whose leaves are the given values, in order.
--
-- An empty list gives 'MerkleEmpty'. Otherwise a run of @n >= 2@ values is
-- split so the left subtree holds 'powerOfTwo' @n@ of them: the largest
-- power of two strictly below @n@. The left subtree is therefore always
-- perfect.
--
-- NOTE: the leaf count is narrowed to 'Word32' with 'fromIntegral', so it
-- wraps for more than @2^32 - 1@ leaves.
mkMerkleTree :: [ByteString] -> MerkleTree ByteString
mkMerkleTree leaves = case leaves of
  [] -> MerkleEmpty
  _ -> MerkleTree (fromIntegral count) (build count leaves)
  where
    count = length leaves
    -- 'n' is the length of 'xs'; it is passed down so list lengths are not
    -- recomputed at every level.
    build n xs = case xs of
      [x] -> mkLeaf x
      _ ->
        let leftSize = powerOfTwo n
            (lhs, rhs) = splitAt leftSize xs
        in mkBranch (build leftSize lhs) (build (n - leftSize) rhs)
-- | An inclusion proof for one leaf.
--
-- As produced by 'merkleProof', the elements run from the leaf's level up to
-- the children of the root. 'validateMerkleProof' consumes them in that
-- order.
newtype MerkleProof a = MerkleProof { getMerkleProof :: [ProofElem a] }
  deriving (Show, Eq, Ord, Generic, S.Serialize)
-- | One step of an inclusion proof: a node on the path and its sibling.
--
-- NOTE(review): 'ProofElem' and 'Side' are not in the module's export list,
-- so callers can obtain proof elements (via 'getMerkleProof') but cannot
-- name their type or fields.
data ProofElem a = ProofElem
  { nodeRoot :: MerkleRoot a
    -- ^ Root of the node on the path from the leaf to the tree root.
  , siblingRoot :: MerkleRoot a
    -- ^ Root of the other child of that node's parent.
  , nodeSide :: Side
    -- ^ Whether the path node is the left ('L') or right ('R') child.
  } deriving (Show, Eq, Ord, Generic, S.Serialize)
-- | Which child of its parent a proof node is: 'L' (left) or 'R' (right).
-- It decides the argument order to 'mkRootHash' during validation.
data Side = L | R
  deriving (Show, Eq, Ord, Generic, S.Serialize)
-- | Build an inclusion proof for the leaf whose hash is @leafRoot@.
--
-- The proof lists, from the leaf's level upward, each node on the path
-- together with its sibling's root and the side the node sits on.
--
-- * If several leaves share that hash (duplicate values), the proof is for
--   the left-most one. Previously the paths of all occurrences were
--   concatenated, which produced a proof 'validateMerkleProof' rejects.
-- * If the leaf is absent, or the tree is empty, the empty proof is returned.
--   An empty proof validates only when @leafRoot@ equals the tree root,
--   i.e. for a single-leaf tree holding that leaf.
merkleProof :: forall a. MerkleTree a -> MerkleRoot a -> MerkleProof a
merkleProof MerkleEmpty _ = MerkleProof []
merkleProof (MerkleTree _ rootNode) leafRoot =
  MerkleProof $ fromMaybe [] (constructPath [] rootNode)
  where
    -- 'pElems' holds the elements collected on the way down. Each one is
    -- consed on, so the head is always the deepest element and the result
    -- is in leaf-to-root order. 'Nothing' means the leaf is not in this
    -- subtree.
    constructPath :: [ProofElem a] -> MerkleNode a -> Maybe [ProofElem a]
    constructPath pElems (MerkleLeaf leafRoot' _)
      | leafRoot == leafRoot' = Just pElems
      | otherwise = Nothing
    constructPath pElems (MerkleBranch _ ln rn) =
      -- '<|>' on Maybe keeps the first match (left-most leaf). The right
      -- subtree is not searched once the left search succeeds.
      lPath <|> rPath
      where
        lProofElem = ProofElem (mRoot ln) (mRoot rn) L
        rProofElem = ProofElem (mRoot rn) (mRoot ln) R
        lPath = constructPath (lProofElem : pElems) ln
        rPath = constructPath (rProofElem : pElems) rn
-- | Check an inclusion proof of @leafRoot@ against @treeRoot@.
--
-- Starting from @leafRoot@, each proof element, in order, must have a
-- 'nodeRoot' equal to the running hash. The running hash then becomes
-- 'mkRootHash' of the node and its sibling, ordered by 'nodeSide'.
--
-- The proof is valid iff every check passes and the final hash equals
-- @treeRoot@. In particular, an empty proof is valid iff
-- @leafRoot == treeRoot@.
validateMerkleProof :: forall a. MerkleProof a -> MerkleRoot a -> MerkleRoot a -> Bool
validateMerkleProof (MerkleProof proofElems) treeRoot leafRoot =
  maybe False (== treeRoot) (foldM climb leafRoot proofElems)
  where
    -- One step up the tree; 'Nothing' aborts the whole fold on a mismatch.
    climb :: MerkleRoot a -> ProofElem a -> Maybe (MerkleRoot a)
    climb current pElem
      | current /= nodeRoot pElem = Nothing
      | otherwise = Just (parentOf pElem)

    parentOf :: ProofElem a -> MerkleRoot a
    parentOf (ProofElem pRoot sibRoot side) = case side of
      L -> mkRootHash pRoot sibRoot
      R -> mkRootHash sibRoot pRoot
-- | Hex (Base16) encoding of the SHA3-256 digest of the input.
--
-- NOTE: despite the name, this is SHA3-256, not SHA-256. It also returns the
-- 64-byte hex text of the 32-byte digest, not the raw digest bytes.
sha256 :: ByteString -> ByteString
sha256 input = B.convertToBase B.Base16 digest
  where
    digest :: Digest SHA3_256
    digest = hash input
-- | The hash behind every node hash in this module: leaves, branches and
-- the empty tree.
--
-- Currently 'sha256', i.e. the hex text of a SHA3-256 digest.
merkleHash :: ByteString -> ByteString
merkleHash input = sha256 input
-- | Smoke test for proofs.
--
-- Builds a tree over the leaves @show 1@ .. @show n@ and picks one of those
-- values uniformly at random. Reports whether the proof 'merkleProof' gives
-- for it passes 'validateMerkleProof' against the tree's root.
-- Calls 'panic' when @n < 2@.
testMerkleProofN :: Int -> IO Bool
testMerkleProofN n
  | n < 2 = panic "Cannot construct a merkle tree with < 2 nodes"
  | otherwise = provesLeaf <$> (randomRIO (1, n) :: IO Int)
  where
    tree = mkMerkleTree $ map show [1 .. n]
    provesLeaf k =
      let leaf = mkLeafRootHash $ show k
      in validateMerkleProof (merkleProof tree leaf) (mtRoot tree) leaf