module Data.IDA.Internal 
where
import Data.ByteString( ByteString )
import qualified Data.ByteString as B
import Data.Vector(Vector,(!))
import qualified Data.Vector as V
import Data.Array( array )
import qualified Data.Array as A
import Control.Exception
import Data.Typeable
import Data.Binary( Binary )
import GHC.Generics
import qualified Data.Matrix as M
import Data.IDA.FiniteField
import qualified Data.FiniteField.PrimeField as PF
   
-- | One share of a dispersed message, as produced by 'encode' and consumed
-- by 'decode'.
--
-- NOTE: 'Binary' is derived through 'Generic', so the order and types of the
-- fields below define the serialized format. Reordering or retyping fields
-- breaks compatibility with previously serialized fragments.
data Fragment = Fragment
  { fragmentId :: !Int
    -- ^ 1-based index of this fragment. 'encode' uses it as the row of
    -- @vmatrix@ that produced 'theContent'. 'decode' uses it to select the
    -- matching rows again.
  , trailLength :: !Int
    -- ^ Number of zero bytes appended to the message so that its length is
    -- a multiple of 'reconstructionThreshold'.
  , reconstructionThreshold :: !Int
    -- ^ Number of fragments (@m@) that 'decode' needs.
  , theContent :: ![FField]
    -- ^ One encoded field element per block of 'reconstructionThreshold'
    -- message symbols.
  , msgLength :: !Int
    -- ^ Length of the padded message: the original length plus
    -- 'trailLength'.
  }
  deriving (Typeable, Eq, Generic)

-- | Renders only the fragment id and the encoded content, as a pair.
instance Show Fragment where
  show (Fragment fid _ _ content _) = show (fid, content)

-- | Serialization derived from the 'Generic' representation.
instance Binary Fragment
-- | Disperses a message into @numFragments@ fragments so that any @m@ of
-- them (with distinct ids) are enough for 'decode' to rebuild it.
--
-- The message is padded with zero bytes to a multiple of @m@ and cut into
-- blocks of @m@ symbols. Fragment @i@ stores the dot product of row @i@ of
-- @vmatrix numFragments m@ with every block.
--
-- Throws 'AssertionFailed' when @numFragments@ is not in @[1, 1020]@, or
-- when @m@ is not in @[1, numFragments]@.
encode :: Int        -- ^ @m@: number of fragments required for reconstruction
       -> Int        -- ^ @numFragments@: total number of fragments to produce
       -> ByteString -- ^ message to disperse
       -> [Fragment]
encode m numFragments msg
  -- The bounds were previously joined with '&&', which can never hold, so
  -- invalid counts were never rejected. The 1021 bound presumably reflects
  -- the prime order of 'FField' (fragment ids must be distinct field
  -- elements). TODO confirm against Data.IDA.FiniteField.
  | numFragments >= 1021 || numFragments < 1 =
      throw $ AssertionFailed "encode: invalid number of fragments."
  -- m is used as a divisor when padding. With m > numFragments, 'decode'
  -- could never collect enough fragments.
  | m < 1 || m > numFragments =
      throw $ AssertionFailed "encode: invalid reconstruction threshold."
  | otherwise =
      let (intseq, trailLen) = toIntVec m msg
          len = V.length intseq
          blocks = V.fromList $ groupInto m intseq
          -- toIntVec pads to a multiple of m, so this division is exact.
          numBlocks = len `div` m
          vm = vmatrix numFragments m
          -- Fetch the row once per fragment rather than once per block.
          content i =
            let row = M.getRow i vm
            in [ dotProduct row (blocks ! (k - 1)) | k <- [1 .. numBlocks] ]
      in [ Fragment { fragmentId = i
                    , trailLength = trailLen
                    , reconstructionThreshold = m
                    , theContent = content i
                    , msgLength = len
                    }
         | i <- [1 .. numFragments]
         ]
   
-- | Rebuilds the original message from fragments produced by 'encode'.
--
-- Let @m@ be the 'reconstructionThreshold' of the first fragment. Only the
-- first @m@ fragments of the list are used. Their ids select an @m x m@
-- submatrix of @vmatrix@, whose inverse is applied to every block of
-- fragment symbols. Padding is then stripped using 'msgLength' and
-- 'trailLength'.
--
-- Throws 'AssertionFailed' when fewer than @m@ fragments are supplied.
--
-- NOTE(review): the fragments used must have distinct 'fragmentId's.
-- Duplicates presumably make the submatrix singular; confirm how 'inverse'
-- reacts.
decode :: [Fragment]
       -> ByteString
decode [] = throw $ AssertionFailed
      "decode: need at least m fragments for reconstruction."
decode pss@(p:_)
  | length pss < reconstructionThreshold p = throw $ AssertionFailed
      "decode: need at least m fragments for reconstruction."
  | otherwise =
  let m = reconstructionThreshold p
      used = take m pss
      idxs = map fragmentId used
      n = maximum idxs
      idxVec = V.fromList idxs
      -- Build the dispersal matrix once instead of once per entry of vecA.
      vm = vmatrix n m
      -- The m x m submatrix made of the rows that produced the used fragments.
      vecA = M.matrix m m $ \(i,j) -> vm M.! (idxVec ! (i - 1), j)
      matrixBInv = inverse vecA
      fragments :: [Vector FField]
      fragments = map (V.fromList . theContent) used
      numBlocks = V.length (head fragments)
      -- m x 1 column holding the k-th symbol of every used fragment.
      column k = M.transpose $ M.fromLists [ map (! (k - 1)) fragments ]
      -- One multiplication per block yields its m original symbols in order.
      recoverBlock k = concat $ M.toLists $ matrixBInv * column k
      mCont :: [FField]
      mCont = concatMap recoverBlock [1 .. numBlocks]
  in fromIntVec (msgLength p - trailLength p) $ V.fromList mCont
-- | Converts a message into field elements, one per byte. Zero bytes are
-- first appended so that the length becomes a multiple of @m@.
--
-- Returns the padded vector together with the number of zero bytes added.
-- @m@ is used as a divisor, so it must be non-zero.
toIntVec :: Int -> ByteString -> (Vector FField,Int)
toIntVec m bStr =
  let len = B.length bStr
      -- Distance to the next multiple of m; 0 when already aligned.
      -- Same value as ((len `div` m) + 1) * m - len for unaligned lengths.
      trailLen = (m - len `mod` m) `mod` m
      padded = bStr `B.append` B.replicate trailLen 0
  in (V.fromList $ map fromIntegral $ B.unpack padded, trailLen)
-- | Turns the first @originalLength@ field elements back into bytes.
--
-- Each element is mapped to its integer representative and then converted
-- with 'fromInteger', so values above 255 wrap modulo 256. The vector must
-- hold at least @originalLength@ elements, as 'V.slice' requires.
fromIntVec :: Int -> Vector FField -> ByteString
fromIntVec originalLength intVec = B.pack (V.toList (V.map toByte payload))
  where
    payload = V.slice 0 originalLength intVec
    toByte x = fromInteger (PF.toInteger (number x))
-- | Splits a vector into consecutive chunks of @size@ elements. The last
-- chunk holds whatever remains and may be shorter.
--
-- The result is never empty: an empty input gives a single empty chunk.
-- @size@ should be positive. With @size <= 0@ and a non-empty input, the
-- result is an infinite list of empty chunks.
groupInto :: Int -> Vector a -> [Vector a]
groupInto size = go
  where
    go xs = case V.splitAt size xs of
      (chunk, rest)
        | V.null rest -> [chunk]
        | otherwise   -> chunk : go rest