{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedRecordDot #-}

{- | Core timestamp functionality for OpenTimestamps.

This module provides the main Timestamp data type and functions for
creating, manipulating, serializing, and verifying timestamps,
including operations, attestations, and Merkle tree calculations.
-}
module OpenTimestamps.Timestamp
  ( Timestamp (..)
  , serialize
  , deserialize
  , printHex
  , putTimestamp
  , merge
  , isTimestampComplete
  , getPendingAttestationsWithMsgs
  , getAttestations
  , getMerkleRoot
  , printTimestamp
  ) where

import Control.Monad (unless, when)
import Crypto.Hash (Digest, SHA256 (SHA256), hashWith)
import Data.Binary.Get (Get, getWord8, isEmpty)
import Data.Binary.Put (Put, putWord8, runPut)
import qualified Data.ByteArray as BA
import qualified Data.ByteString as BS
import qualified Data.Map.Strict as Map
import Data.Serialize (decode)
import qualified Data.Set as Set
import Haskoin.Transaction (Tx)
import OpenTimestamps.Attestation
  ( Attestation (..)
  , getAttestation
  , putAttestation
  )
import OpenTimestamps.Config as Config
import OpenTimestamps.Op
  ( Op (..)
  , execute
  , getOp
  , putOp
  )

import Data.ByteString.Builder (byteStringHex, toLazyByteString)
import Data.Text.Lazy (unpack)
import Data.Text.Lazy.Builder (Builder, fromLazyText, fromString, toLazyText)
import Data.Text.Lazy.Builder.Int (decimal)
import Data.Text.Lazy.Encoding (decodeUtf8)
import OpenTimestamps.Types (OTsByteStream, OTsBytes)

{- | A node in a timestamp proof tree.

Each node carries the message it commits to, the attestations attached
directly to that message, and a map from operations to child nodes. When
built by 'deserialize', the child stored under an operation holds the
result of executing that operation on this node's message.
-}
data Timestamp = Timestamp
  { timestampMsg :: BS.ByteString
  -- ^ The message this node commits to.
  , attestations :: Set.Set Attestation
  -- ^ Attestations attached directly to 'timestampMsg'.
  , ops :: Map.Map Op Timestamp
  -- ^ Child timestamps, keyed by the operation that leads to them.
  }
  deriving (Eq, Show, Ord)

{- | Render the transaction id of raw transaction bytes: SHA-256 applied
twice, the resulting digest byte-reversed, then hex-encoded.
-}
getTxId :: BS.ByteString -> Builder
getTxId rawTx =
  fromLazyText (decodeUtf8 (toLazyByteString (byteStringHex txIdBytes)))
  where
    innerDigest, outerDigest :: Digest SHA256
    innerDigest = hashWith SHA256 rawTx
    outerDigest = hashWith SHA256 innerDigest

    -- Reversed so the id is displayed in the opposite byte order to the
    -- raw digest.
    txIdBytes :: BS.ByteString
    txIdBytes = BA.reverse (BA.convert outerDigest)

{- | Report whether a timestamp is complete, i.e. whether it or any of its
sub-timestamps carries a 'Bitcoin' attestation.
-}
isTimestampComplete :: Timestamp -> Bool
isTimestampComplete ts = confirmedHere || confirmedBelow
  where
    confirmedHere = any confirmsOnBitcoin (attestations ts)
    confirmedBelow = any isTimestampComplete (ops ts)

    confirmsOnBitcoin att = case att of
      Bitcoin _ -> True
      _ -> False

{- | Collect every 'Pending' attestation in the tree, each paired with the
message of the timestamp node it is attached to.

Attestations on the node itself come first, followed by those found in
the sub-timestamps in ascending operation order.
-}
getPendingAttestationsWithMsgs :: Timestamp -> [(BS.ByteString, Attestation)]
getPendingAttestationsWithMsgs ts = ownPending ++ nestedPending
  where
    digest = timestampMsg ts

    ownPending =
      map ((,) digest) (filter isPending (Set.toList (attestations ts)))

    nestedPending =
      concatMap getPendingAttestationsWithMsgs (Map.elems (ops ts))

    isPending (Pending _) = True
    isPending _ = False

{- | List every attestation in the tree: those on the node itself in
ascending order, then those of each sub-timestamp in ascending operation
order.
-}
getAttestations :: Timestamp -> [Attestation]
getAttestations ts = localAtts ++ nestedAtts
  where
    localAtts = Set.toAscList (attestations ts)
    nestedAtts = foldMap getAttestations (ops ts)

{- | Fold the attestations and operations of the second timestamp into the
first.

Returns the combined timestamp, which keeps the first argument's message,
together with a flag that is 'True' when the second timestamp contributed
at least one attestation or operation that the first did not already
have, at any depth.
-}
merge :: Timestamp -> Timestamp -> (Timestamp, Bool)
merge ts1 ts2 =
  ( ts1 {attestations = unionAtts, ops = unionOps}
  , gainedAtts || gainedOps
  )
  where
    unionAtts = Set.union (attestations ts1) (attestations ts2)

    -- Something new arrived exactly when ts2's attestations are not all
    -- already present in ts1.
    gainedAtts = not (attestations ts2 `Set.isSubsetOf` attestations ts1)

    (unionOps, gainedOps) =
      Map.foldlWithKey' absorb (ops ts1, False) (ops ts2)

    -- Fold one operation of ts2 into the accumulated operation map.
    absorb (soFar, dirty) op incoming =
      case Map.lookup op soFar of
        -- An operation ts1 never had always counts as a change.
        Nothing -> (Map.insert op incoming soFar, True)
        -- A shared operation: merge the two sub-timestamps recursively.
        Just existing ->
          let (joined, childDirty) = merge existing incoming
           in (Map.insert op joined soFar, dirty || childDirty)

{- | Deserialize the timestamp tree attached to a starting message.

The input is read as a sequence of items. Every item except the last is
introduced by a @0xff@ fork marker. An item is either an attestation
(tag @0x00@) or an operation (any other tag) followed by the serialized
timestamp of the result of executing that operation on @startMsg@. All
items are merged into a single 'Timestamp' whose message is @startMsg@.

Fails with @"Message too long"@ when @startMsg@ is longer than
'Config.maxTimestampMessageLength'.
-}
deserialize :: OTsBytes -> Get Timestamp
deserialize startMsg = do
  -- The limit is inclusive: a message of exactly the maximum length is
  -- still valid, only longer ones are rejected.
  unless (BS.length startMsg <= Config.maxTimestampMessageLength) $
    fail "Message too long"

  -- Only the merged timestamp matters here, not the "changed" flag.
  let mergeTimestamps ts1 ts2 = fst (merge ts1 ts2)

  -- Parse one item whose tag byte has already been consumed.
  let doTagOrAttestation tag =
        if tag == 0x00
          then do
            attestation <- getAttestation
            pure $ Timestamp startMsg (Set.singleton attestation) Map.empty
          else do
            op <- getOp tag
            -- The sub-timestamp is rooted at the operation's result.
            let result = execute op startMsg
            stamp <- deserialize result
            pure $ Timestamp startMsg Set.empty (Map.singleton op stamp)

  -- Parse the items following the first forked one, merging as we go.
  let parseRemainingParts currentAcc = do
        isEnd <- isEmpty
        if isEnd
          then -- Input ended right after a forked item: return what we have.
            pure currentAcc
          else do
            tag <- getWord8
            if tag == 0xff
              then do
                -- Another forked item; more items follow it.
                innerTag <- getWord8
                part <- doTagOrAttestation innerTag
                parseRemainingParts (mergeTimestamps currentAcc part)
              else do
                -- This is the last part (already read its tag).
                part <- doTagOrAttestation tag
                pure (mergeTimestamps currentAcc part)

  -- A timestamp serialization is a sequence of forked items followed
  -- by a final item. Or it can be just a single item.
  tag <- getWord8
  if tag /= 0xff
    then do
      -- It's a single, non-forked timestamp
      doTagOrAttestation tag
    else do
      -- It's a forked timestamp. The first real tag follows the 0xff.
      firstTag <- getWord8
      firstPart <- doTagOrAttestation firstTag
      parseRemainingParts firstPart

-- | Serialize a timestamp to bytes by running 'putTimestamp'.
serialize :: Timestamp -> OTsByteStream
serialize ts = runPut (putTimestamp ts)

{- | Write a timestamp in binary format.

Attestations are written first in ascending order, then operations in
ascending order, each operation followed by its sub-timestamp. Every item
except the very last one written is preceded by a @0xff@ fork marker;
attestations are tagged with @0x00@.

Calls 'error' when the timestamp has neither attestations nor operations.
-}
putTimestamp :: Timestamp -> Put
putTimestamp ts = do
  when (null attList && null opList) $ error "An empty timestamp can't be serialized"

  -- Both lists are reversed so the final item can be split off by pattern
  -- matching; the earlier items are reversed back before being written.
  case (reverse attList, reverse opList) of
    ([], []) -> pure () -- Rejected by the check above
    (finalAtt : earlierAtts, []) -> do
      -- Attestations only: the last attestation closes the sequence.
      mapM_ putForkedAtt (reverse earlierAtts)
      putWord8 0x00
      putAttestation finalAtt
    (revAtts, finalBranch : earlierBranches) -> do
      -- At least one operation: every attestation is a forked item and the
      -- last operation closes the sequence.
      mapM_ putForkedAtt (reverse revAtts)
      mapM_ putForkedBranch (reverse earlierBranches)
      putBranch finalBranch
  where
    attList = Set.toAscList (attestations ts)
    opList = Map.toAscList (ops ts)

    -- An attestation that is followed by further items.
    putForkedAtt att = do
      putWord8 0xff
      putWord8 0x00
      putAttestation att

    -- An operation together with the timestamp of its result.
    putBranch (op, subTs) = do
      putOp op
      putTimestamp subTs

    -- An operation branch that is followed by further items.
    putForkedBranch branch = do
      putWord8 0xff
      putBranch branch

{- | Follow the smallest-keyed operation from each node down to a node
without operations and return that node's message.

No operation is executed here; the messages already stored in the
sub-timestamps are used.
-}
getMerkleRoot :: Timestamp -> BS.ByteString
getMerkleRoot ts =
  maybe (timestampMsg ts) (getMerkleRoot . snd) (Map.lookupMin (ops ts))

-- | Render bytes as lowercase hexadecimal text, in their original order.
formatHexDirect :: BS.ByteString -> Builder
formatHexDirect bytes =
  fromLazyText (decodeUtf8 (toLazyByteString (byteStringHex bytes)))

-- | Render a digest as hexadecimal text with its bytes in reversed order.
formatDigest :: BS.ByteString -> Builder
formatDigest digest =
  fromLazyText (decodeUtf8 (toLazyByteString (byteStringHex flipped)))
  where
    flipped = BA.reverse digest

-- | Render bytes as a hexadecimal 'String'.
printHex :: BS.ByteString -> String
printHex = unpack . toLazyText . formatHexDirect

{- | Render an operation as text: its lowercase name, followed by the
hex-encoded argument for 'Append' and 'Prepend'.
-}
printOp :: Op -> Builder
printOp (Append bs) = fromString "append " <> formatHexDirect bs
printOp (Prepend bs) = fromString "prepend " <> formatHexDirect bs
printOp Sha1 = fromString "sha1"
printOp Sha256 = fromString "sha256"
printOp Ripemd160 = fromString "ripemd160"
printOp Keccak256 = fromString "keccak256"
printOp Hexlify = fromString "hexlify"
printOp Reverse = fromString "reverse"

{- | Render one branch of a forked timestamp.

Emits an optional transaction-id line (only when the parent's message
decodes as a transaction), then the operation prefixed with @" -> "@ at
the given indentation, then the sub-timestamp indented four columns
further.
-}
printOpWithArrow ::
  Int ->
  Timestamp ->
  (Op, Timestamp) ->
  Builder
printOpWithArrow indent parentTs (op, subTs) =
  mconcat
    [ txIdLine
    , pad
    , fromString " -> "
    , printOp op
    , fromString "\n"
    , printTimestampBuilder (indent + 4) subTs
    ]
  where
    pad = fromString (replicate indent ' ')
    parentMsg = timestampMsg parentTs

    -- Only the success of decoding matters; the decoded value is unused.
    txIdLine =
      either
        (const mempty)
        (const txIdText)
        (decode parentMsg :: Either String Tx)

    txIdText =
      pad
        <> fromString "* Transaction id "
        <> getTxId parentMsg
        <> fromString "\n"

{- | Render one attestation of a timestamp node.

The first argument is the indentation in spaces, the second is the
timestamp node the attestation is attached to, the third is the
attestation itself.

For a 'Bitcoin' attestation a second line shows the block merkle root.
That root is the message of the node carrying the attestation, so it is
taken from 'timestampMsg' directly rather than by descending into the
node's operations.
-}
printAttestation ::
  Int ->
  Timestamp ->
  Attestation ->
  Builder
printAttestation indent ts att =
  let prefix = fromString (replicate indent ' ')
   in case att of
        Bitcoin blockHeight ->
          prefix
            <> fromString "verify BitcoinBlockHeaderAttestation("
            <> decimal blockHeight
            <> fromString ")\n"
            <> prefix
            <> fromString "# Bitcoin block merkle root "
            -- The attested digest is this node's own message, even when
            -- the node also has further operations.
            <> formatDigest (timestampMsg ts)
            <> fromString "\n"
        Pending uri ->
          prefix
            <> fromString "verify PendingAttestation("
            <> fromString (show uri)
            <> fromString ")\n"
        Unknown w bs ->
          -- Both fields are shown hex-encoded, separated by a space.
          prefix
            <> fromString "unknown_attestation "
            <> formatHexDirect w
            <> fromString " "
            <> formatHexDirect bs
            <> fromString "\n"

{- | Build the textual rendering of a timestamp tree.

Attestations are printed first. A single operation is then printed at the
same indentation and followed by its sub-timestamp, also at the same
indentation; several operations are each printed as an arrow branch via
'printOpWithArrow'.
-}
printTimestampBuilder :: Int -> Timestamp -> Builder
printTimestampBuilder indent ts = attLines <> opLines
  where
    pad = fromString (replicate indent ' ')
    msg = timestampMsg ts

    attLines =
      foldMap (printAttestation indent ts) (Set.toAscList (attestations ts))

    opLines = case Map.toAscList (ops ts) of
      [] -> mempty
      [(op, subTs)] ->
        mconcat
          [ txIdLine
          , pad
          , printOp op
          , fromString "\n"
          , printTimestampBuilder indent subTs
          ]
      branches -> foldMap (printOpWithArrow indent ts) branches

    -- Shown only when this node's message decodes as a transaction.
    txIdLine = case decode msg :: Either String Tx of
      Left _ -> mempty
      Right _ ->
        pad
          <> fromString "# Transaction id "
          <> getTxId msg
          <> fromString "\n"

{- | Render a timestamp tree as a 'String', starting at the given
indentation (in spaces).
-}
printTimestamp :: Int -> Timestamp -> String
printTimestamp indent = unpack . toLazyText . printTimestampBuilder indent
