{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# OPTIONS_GHC -Wno-unused-top-binds #-}

module AggPartials (testAggPartials) where

import Control.Exception (SomeException, evaluate, try)
import Crypto.Curve.Secp256k1 (Pub)
import Crypto.Curve.Secp256k1.MuSig2 (PartialSignature, PubNonce (..), Tweak (..), aggPartials, mkSessionContext)
import Data.ByteString (ByteString)
import Test.Tasty
import Test.Tasty.HUnit
import Util (decodeHex, parsePoint, parsePubNonce, parseScalar)

-- | One signature-aggregation test vector from BIP-0327
-- @sig_agg_vectors.json@.
--
-- The @*Indices@ fields are positions into the module-level test data lists
-- ('pubkeys', 'tweakValues', 'psigs'), resolved with '!!' by the test cases.
data SigAggTestVector = SigAggTestVector
  { aggNonce :: PubNonce
  -- ^ Aggregate public nonce, passed as-is to 'mkSessionContext'.
  , nonceIndices :: [Int]
  -- ^ Presumably indices into 'pnonces' of the signers' public nonces
  -- (mirrors the JSON vector); not read by the test cases in this module.
  , keyIndices :: [Int]
  -- ^ Indices into 'pubkeys' selecting the signers' public keys, in order.
  , tweakIndices :: [Int]
  -- ^ Indices into 'tweakValues' of the tweaks to apply, in order.
  , isXOnly :: [Bool]
  -- ^ Parallel to 'tweakIndices': 'True' applies the tweak as 'XOnlyTweak',
  -- 'False' as 'PlainTweak'.
  , psigIndices :: [Int]
  -- ^ Indices into 'psigs' of the partial signatures to aggregate.
  , expected :: Maybe ByteString -- ^ Expected aggregate signature; 'Nothing' for error cases.
  , errorCase :: Bool
  -- ^ 'True' when aggregating this vector is expected to fail.
  }
  deriving (Show)

-- | Signer public keys from BIP-0327 @sig_agg_vectors.json@, parsed from
-- their hex encodings. Vectors select from this list via 'keyIndices'.
pubkeys :: [Pub]
pubkeys =
  map
    parsePoint
    [ "03935F972DA013F80AE011890FA89B67A27B7BE6CCB24D3274D18B2D4067F261A9"
    , "02D2DC6F5DF7C56ACF38C7FA0AE7A759AE30E19B37359DFDE015872324C7EF6E05"
    , "03C7FB101D97FF930ACD0C6760852EF64E69083DE0B06AC6335724754BB4B0522C"
    , "02352433B21E7E05D3B452B81CAE566E06D2E003ECE16D1074AABA4289E0E3D581"
    ]

-- | Signer public nonces from BIP-0327 @sig_agg_vectors.json@.
--
-- Not read by the test cases in this module: each vector supplies its
-- aggregate nonce directly in 'aggNonce'. Kept for parity with the JSON data.
pnonces :: [PubNonce]
pnonces =
  map
    parsePubNonce
    [ "036E5EE6E28824029FEA3E8A9DDD2C8483F5AF98F7177C3AF3CB6F47CAF8D94AE902DBA67E4A1F3680826172DA15AFB1A8CA85C7C5CC88900905C8DC8C328511B53E"
    , "03E4F798DA48A76EEC1C9CC5AB7A880FFBA201A5F064E627EC9CB0031D1D58FC5103E06180315C5A522B7EC7C08B69DCD721C313C940819296D0A7AB8E8795AC1F00"
    , "02C0068FD25523A31578B8077F24F78F5BD5F2422AFF47C1FADA0F36B3CEB6C7D202098A55D1736AA5FCC21CF0729CCE852575C06C081125144763C2C4C4A05C09B6"
    , "031F5C87DCFBFCF330DEE4311D85E8F1DEA01D87A6F1C14CDFC7E4F1D8C441CFA40277BF176E9F747C34F81B0D9F072B1B404A86F402C2D86CF9EA9E9C69876EA3B9"
    , "023F7042046E0397822C4144A17F8B63D78748696A46C3B9F0A901D296EC3406C302022B0B464292CF9751D699F10980AC764E6F671EFCA15069BBE62B0D1C62522A"
    , "02D97DDA5988461DF58C5897444F116A7C74E5711BF77A9446E27806563F3B6C47020CBAD9C363A7737F99FA06B6BE093CEAFF5397316C5AC46915C43767AE867C00"
    ]

-- | Tweak scalars from BIP-0327 @sig_agg_vectors.json@. Each scalar is
-- wrapped as an x-only or plain 'Tweak' by 'buildTweaks'.
tweakValues :: [Integer]
tweakValues =
  map
    parseScalar
    [ "B511DA492182A91B0FFB9A98020D55F260AE86D7ECBD0399C7383D59A5F2AF7C"
    , "A815FE049EE3C5AAB66310477FBC8BCCCAC2F3395F59F921C364ACD78A2F48DC"
    , "75448A87274B056468B977BE06EB1E9F657577B7320B0A3376EA51FD420D18A8"
    ]

-- | Partial signatures from BIP-0327 @sig_agg_vectors.json@.
--
-- The final entry equals the secp256k1 group order n. It is therefore not a
-- valid partial signature (those must be < n), and only the error vector
-- selects it.
psigs :: [PartialSignature]
psigs =
  map
    parseScalar
    [ "B15D2CD3C3D22B04DAE438CE653F6B4ECF042F42CFDED7C41B64AAF9B4AF53FB"
    , "6193D6AC61B354E9105BBDC8937A3454A6D705B6D57322A5A472A02CE99FCB64"
    , "9A87D3B79EC67228CB97878B76049B15DBD05B8158D17B5B9114D3C226887505"
    , "66F82EA90923689B855D36C6B7E032FB9970301481B99E01CDB4D6AC7C347A15"
    , "4F5AEE41510848A6447DCD1BBC78457EF69024944C87F40250D3EF2C25D33EFE"
    , "DDEF427BBB847CC027BEFF4EDB01038148917832253EBC355FC33F4A8E2FCCE4"
    , "97B890A26C981DA8102D3BC294159D171D72810FDF7C6A691DEF02F0F7AF3FDC"
    , "53FA9E08BA5243CBCB0D797C5EE83BC6728E539EB76C2D0BF0F971EE4E909971"
    , "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141"
    ]

-- | The 32-byte message signed in every aggregation vector, decoded from hex.
msg :: ByteString
msg = decodeHex messageHex
 where
  messageHex = "599C67EA410D005B9DA90817CF03ED3B1C868E4DA4EDF00A5880B0082C237869"

-- | Valid cases from BIP-0327 @sig_agg_vectors.json@.
--
-- For each vector, aggregating the selected partial signatures under the
-- session built from 'aggNonce', the selected keys, the tweaks and 'msg'
-- must reproduce 'expected' byte-for-byte.
validTestVectors :: [SigAggTestVector]
validTestVectors =
  -- Vector 1: keys 0 and 1, no tweaks.
  [ SigAggTestVector
      { aggNonce = parsePubNonce "0341432722C5CD0268D829C702CF0D1CBCE57033EED201FD335191385227C3210C03D377F2D258B64AADC0E16F26462323D701D286046A2EA93365656AFD9875982B"
      , nonceIndices = [0, 1]
      , keyIndices = [0, 1]
      , tweakIndices = []
      , isXOnly = []
      , psigIndices = [0, 1]
      , expected = Just $ decodeHex "041DA22223CE65C92C9A0D6C2CAC828AAF1EEE56304FEC371DDF91EBB2B9EF0912F1038025857FEDEB3FF696F8B99FA4BB2C5812F6095A2E0004EC99CE18DE1E"
      , errorCase = False
      }
  -- Vector 2: keys 0 and 2, no tweaks.
  , SigAggTestVector
      { aggNonce = parsePubNonce "0224AFD36C902084058B51B5D36676BBA4DC97C775873768E58822F87FE437D792028CB15929099EEE2F5DAE404CD39357591BA32E9AF4E162B8D3E7CB5EFE31CB20"
      , nonceIndices = [0, 2]
      , keyIndices = [0, 2]
      , tweakIndices = []
      , isXOnly = []
      , psigIndices = [2, 3]
      , expected = Just $ decodeHex "1069B67EC3D2F3C7C08291ACCB17A9C9B8F2819A52EB5DF8726E17E7D6B52E9F01800260A7E9DAC450F4BE522DE4CE12BA91AEAF2B4279219EF74BE1D286ADD9"
      , errorCase = False
      }
  -- Vector 3: keys 0 and 2, one plain (non-x-only) tweak.
  , SigAggTestVector
      { aggNonce = parsePubNonce "0208C5C438C710F4F96A61E9FF3C37758814B8C3AE12BFEA0ED2C87FF6954FF186020B1816EA104B4FCA2D304D733E0E19CEAD51303FF6420BFD222335CAA402916D"
      , nonceIndices = [0, 3]
      , keyIndices = [0, 2]
      , tweakIndices = [0]
      , isXOnly = [False]
      , psigIndices = [4, 5]
      , expected = Just $ decodeHex "5C558E1DCADE86DA0B2F02626A512E30A22CF5255CAEA7EE32C38E9A71A0E9148BA6C0E6EC7683B64220F0298696F1B878CD47B107B81F7188812D593971E0CC"
      , errorCase = False
      }
  -- Vector 4: keys 0 and 3, three tweaks applied as x-only, plain, x-only.
  , SigAggTestVector
      { aggNonce = parsePubNonce "02B5AD07AFCD99B6D92CB433FBD2A28FDEB98EAE2EB09B6014EF0F8197CD58403302E8616910F9293CF692C49F351DB86B25E352901F0E237BAFDA11F1C1CEF29FFD"
      , nonceIndices = [0, 4]
      , keyIndices = [0, 3]
      , tweakIndices = [0, 1, 2]
      , isXOnly = [True, False, True]
      , psigIndices = [6, 7]
      , expected = Just $ decodeHex "839B08820B681DBA8DAF4CC7B104E8F2638F9388F8D7A555DC17B6E6971D7426CE07BF6AB01F1DB50E4E33719295F4094572B79868E440FB3DEFD3FAC1DB589E"
      , errorCase = False
      }
  ]

-- | Error cases from BIP-0327 @sig_agg_vectors.json@. Aggregating the
-- selected partial signatures for each of these vectors must be rejected.
errorTestVectors :: [SigAggTestVector]
errorTestVectors =
  -- Same session as valid vector 4 (nonce, keys and tweaks), but aggregating
  -- psigs 7 and 8 instead of 6 and 7.
  [ SigAggTestVector
      { aggNonce = parsePubNonce "02B5AD07AFCD99B6D92CB433FBD2A28FDEB98EAE2EB09B6014EF0F8197CD58403302E8616910F9293CF692C49F351DB86B25E352901F0E237BAFDA11F1C1CEF29FFD"
      , nonceIndices = [0, 4]
      , keyIndices = [0, 3]
      , tweakIndices = [0, 1, 2]
      , isXOnly = [True, False, True]
      , psigIndices = [7, 8] -- psigs !! 8 equals the group order n; a partial signature must be < n
      , expected = Nothing
      , errorCase = True
      }
  ]

-- | Builds the tweak list for a test vector.
--
-- Each element of the first list indexes into 'tweakValues'. The element at
-- the same position in the second list selects the wrapper: 'True' gives
-- 'XOnlyTweak', 'False' gives 'PlainTweak'.
--
-- Both lists must have the same length. A bare 'zipWith' would silently drop
-- the surplus tweaks, and the malformed vector would surface only as an
-- unexplained signature mismatch. Mismatched lengths and out-of-range tweak
-- indices are therefore reported with descriptive errors instead.
buildTweaks :: [Int] -> [Bool] -> [Tweak]
buildTweaks indices flags
  | length indices /= length flags =
      error
        ( "buildTweaks: "
            ++ show (length indices)
            ++ " tweak indices but "
            ++ show (length flags)
            ++ " is_xonly flags"
        )
  | otherwise = zipWith buildTweak indices flags
 where
  buildTweak i isX = (if isX then XOnlyTweak else PlainTweak) (lookupTweak i)
  lookupTweak i
    | i < 0 || i >= length tweakValues =
        error ("buildTweaks: tweak index " ++ show i ++ " out of range")
    | otherwise = tweakValues !! i

-- | Builds the HUnit case for one valid vector.
--
-- The vector's indices are resolved against the shared test data, the
-- partial signatures are aggregated with 'aggPartials', and the result is
-- compared with the vector's expected signature. The 'Int' argument is the
-- vector's zero-based position and is shown one-based in the test name.
makeValidTestCase :: Int -> SigAggTestVector -> TestTree
makeValidTestCase i SigAggTestVector{..} = testCase label check
 where
  label = "BIP-0327 SigAgg Valid Vector " ++ show (i + 1)
  session =
    mkSessionContext
      aggNonce
      (map (pubkeys !!) keyIndices)
      (buildTweaks tweakIndices isXOnly)
      msg
  aggregated = aggPartials (map (psigs !!) psigIndices) session
  check =
    maybe
      (assertFailure "Expected signature but got nothing")
      (\want -> assertEqual "signature mismatch" want aggregated)
      expected

-- | Builds the HUnit case for one vector whose aggregation must fail.
--
-- Keys, tweaks and partial signatures are resolved exactly as for the valid
-- cases and passed to 'aggPartials'. The result is forced inside 'try'. The
-- test passes only if that raises an exception; a returned signature is
-- reported as a failure.
--
-- For the vector in 'errorTestVectors', @psigs !! 8@ equals the secp256k1
-- group order n. BIP-0327 partial signature aggregation must reject any
-- partial signature that is not less than n.
makeErrorTestCase :: Int -> SigAggTestVector -> TestTree
makeErrorTestCase i SigAggTestVector{..} =
  testCase ("BIP-0327 SigAgg Error Vector " ++ show (i + 1)) $ do
    assertBool "Expected error case" errorCase
    let selectedKeys = map (pubkeys !!) keyIndices
        selectedTweaks = buildTweaks tweakIndices isXOnly
        selectedPsigs = map (psigs !!) psigIndices
        ctx = mkSessionContext aggNonce selectedKeys selectedTweaks msg
    -- 'evaluate' forces the strict ByteString result, which requires its
    -- bytes to be computed, so a validation error raised while computing the
    -- signature surfaces here. Any exception counts as the expected rejection
    -- because the library's concrete error type is not visible from this module.
    outcome <- try (evaluate (aggPartials selectedPsigs ctx))
    case (outcome :: Either SomeException ByteString) of
      Left _ -> pure ()
      Right sig ->
        assertFailure ("Expected aggregation to fail, but got signature " ++ show sig)

-- | Root tasty tree for the BIP-0327 signature-aggregation vectors.
--
-- It groups the valid cases and the error cases separately, passing each
-- vector its zero-based position in its list.
testAggPartials :: TestTree
testAggPartials =
  testGroup "BIP-0327 Signature Aggregation Vectors" [validGroup, errorGroup]
 where
  validGroup =
    testGroup
      "Valid Cases"
      [makeValidTestCase n v | (n, v) <- zip [0 ..] validTestVectors]
  errorGroup =
    testGroup
      "Error Cases"
      [makeErrorTestCase n v | (n, v) <- zip [0 ..] errorTestVectors]
