{- |

   This module implements the 2bit format for sequences.

   Based on: <http://genome.ucsc.edu/FAQ/FAQformat#format7>
   Note! the description is not accurate, it is missing a reserved word
         in each sequence record.

   There are also other, completely different ideas of the 2bit format, e.g.
      <http://jcomeau.freeshell.org/www/genome/2bitformat.html>
-}


module Bio.Sequence.TwoBit
   ( decode2Bit, 
     read2Bit, 
     hRead2Bit,
     encode2Bit,
     write2Bit,
     hWrite2Bit
   ) where


import Bio.Sequence.SeqData
import qualified Data.ByteString.Lazy as BB
import qualified Data.ByteString.Lazy.Char8 as B

import System.IO
import Control.Monad

import Data.Char
import Data.Binary
import Data.Int
import Data.List
import Data.Bits

-- import Test.QuickCheck hiding (check)    -- QC 1.0
import Test.QuickCheck hiding ((.&.)) -- QC 2.0

-- constants

-- Magic number that the Header instance writes first and expects (possibly
-- byte-swapped) as the first word when reading.
default_magic :: Word32
default_magic = 0x1A412743

-- The only format version the Header instance accepts and writes.
default_version :: Word32
default_version = 0


-- binary extras

-- Pass a value through unchanged when it satisfies the predicate; otherwise
-- abort the surrounding monadic computation via 'fail'.
check :: Monad m => (a -> Bool) ->  a -> m a
check accept value
    | accept value = return value
    | otherwise    = fail "check failed"

-- Reverse the byte order of a value treated as 'width' bytes wide.
-- The least-significant-first byte list from 'bytes' is zero-padded up to
-- 'width' bytes, reversed, and reassembled with 'unbytes'.
bswap :: Integral a => Int -> a -> a
bswap width value = unbytes (reverse padded)
    where
       littleEndian = bytes value
       padded       = littleEndian ++ replicate (width - length littleEndian) 0

-- Split a number into base-256 digits, least significant first.
-- Zero yields the empty list, so the result carries no leading-zero padding.
bytes :: Integral a => a -> [Word8]
bytes = Data.List.unfoldr step
    where
       step w
           | q == 0 && r == 0 = Nothing
           | otherwise        = Just (fromIntegral r, q)
           where (q, r) = w `quotRem` 256
-- Reassemble a number from base-256 digits given least significant first
-- (the inverse of 'bytes').
unbytes :: Integral a => [Word8] -> a
unbytes = Data.List.foldr (\digit acc -> acc * 256 + fromIntegral digit) 0



-- Conflicts with Bio.Util.TestBase
-- instance Arbitrary Word8 where
--    arbitrary = choose (0,255::Int) >>= return . fromIntegral

-- prop_bswap :: Word8 -> Word8 -> Word8 -> Word8 -> Bool
-- prop_bswap a1 a2 a3 a4 = (bswap 4 . decode . BB.pack) [a1,a2,a3,a4] == ((decode . BB.pack) [a4,a3,a2,a1] :: Word32)





-- "in-core" representation of 2Bit data types

-- File header as held in memory.
data Header = Header
    { swap     :: Bool    -- True when the magic number read was not default_magic
    , version  :: Word32
    , count    :: Word32  -- number of index entries that follow the header
    , reserved :: Word32
    }

-- Shows only version, count and reserved; the swap flag is omitted.
instance Show Header where
    show (Header _ ver cnt rsv) = "H " ++ show (ver, cnt, rsv)

instance Binary Header where
    -- Read the four header words.  The magic number is validated first and a
    -- mismatch is reported through 'fail' in the Get monad, instead of being
    -- hidden in a lazy 'error' thunk inside the returned fields.  When the
    -- magic number is the byte-swapped form, the remaining words are
    -- byte-swapped as well and the 'swap' flag is set.
    get = do
       m <- get
       unswap <- if m == default_magic
                    then return id
                    else if m == bswap 4 default_magic
                            then return (bswap 4)
                            else fail "2bit decode: incorrect magic number"
       v <- get >>= check (==default_version)
       c <- get
       r <- get
       return (Header (m /= default_magic) (unswap v) (unswap c) (unswap r))
    -- Always writes default_magic and default_version followed by the entry
    -- count and a zero reserved word; the swap, version and reserved fields
    -- of the value are not consulted.
    put (Header _ _ c _) = do
       put default_magic
       put default_version
       put c
       put (0 :: Word32)



-- One index entry: a sequence name and the offset stored alongside it.
data Entry = Entry
    { name   :: B.ByteString
    , offset :: Word32
    } deriving Show

-- Reverse the byte order of an Entry's 4-byte offset; the name is untouched.
swapEntry :: Entry -> Entry
swapEntry (Entry label pos) = Entry label (bswap 4 pos)

instance Binary Entry where
    -- An entry is stored as a one-byte name length, that many name bytes,
    -- and a 32-bit offset (byte order is corrected later by swapEntry when
    -- the header says the file is swapped).
    get = do
       len <- getWord8
       nameBytes <- replicateM (fromIntegral len) getWord8
       off <- get
       return (Entry (BB.pack nameBytes) off)
    -- The length field is a single byte, so at most 255 name bytes can be
    -- represented.  Longer names are truncated to 255 bytes; previously the
    -- length silently wrapped modulo 256 while every byte was still written,
    -- which made the entry unreadable.
    put (Entry byteString off) = do
       let stored = BB.take 255 byteString
           len    = fromIntegral (BB.length stored) :: Word8
       put len
       mapM_ put $ BB.unpack stored
       put off



-- The header together with the index entries that follow it.
data Entries = Entries Header [Entry]

-- One line for the header, then one line per entry.
instance Show Entries where
    show (Entries hdr entries) = unlines (show hdr : map show entries)

instance Binary Entries where
    -- Read the header, then exactly 'count' index entries.  Entry offsets are
    -- byte-swapped when the header reports a swapped file.
    get = do
       h <- get
       es <- replicateM (fromIntegral $ count h) get
       return (Entries h $ if swap h then map swapEntry es else es)
    -- Write the header followed by the entries back to back.  The entries are
    -- written one by one rather than through the Binary instance for lists:
    -- that instance emits its own length prefix, which 'get' above does not
    -- read and which is not part of the format.
    put (Entries h es) = do
       put h
       mapM_ put es



{-

   Sequence Record definition

 dnaSize         - number of bases of DNA in the sequence
 nBlockCount     - the number of blocks of Ns in the file (representing unknown sequence)
 nBlockStarts    - the starting position for each block of Ns
 nBlockSizes     - the size of each block of Ns
 maskBlockCount  - the number of masked (lower-case) blocks
 maskBlockStarts - the starting position for each masked block
 maskBlockSizes  - the size of each masked block
 reserved        - a reserved 32-bit word (read and kept, always written as 0)
 packedDna       - the DNA packed to two bits per base

-}


-- In-memory form of one sequence record.  The field order matches the
-- positional construction used by the SRBE and SRLE decoders.
data SR = SR
    { dnaSize         :: Word32    -- number of bases
    , nBlockCount     :: Word32
    , nBlockStarts    :: [Word32]
    , nBlockSizes     :: [Word32]
    , maskBlockCount  :: Word32
    , maskBlockStarts :: [Word32]
    , maskBlockSizes  :: [Word32]
    , packedDna       :: [Word8]   -- four bases per byte
    , reserved2       :: Word32
    } deriving Show



-- big- and little-endian variants (what a mess)
-- Both wrap the same SR record; they differ only in their Binary instances.
-- SRBE reads and writes the 32-bit words as Data.Binary encodes them, while
-- SRLE byte-swaps every 32-bit word on the way in and out.
newtype SRBE = SRBE SR deriving Show
newtype SRLE = SRLE SR deriving Show


instance Binary SRBE where
    -- Field order: dnaSize, N-block count/starts/sizes, mask-block
    -- count/starts/sizes, the reserved word, then (dnaSize+3)/4 packed bytes.
    get = do
       size    <- word32
       nCount  <- word32
       nStarts <- word32s nCount
       nSizes  <- word32s nCount
       mCount  <- word32
       mStarts <- word32s mCount
       mSizes  <- word32s mCount
       rsv     <- word32 -- reserved word missing from the format description
       packed  <- replicateM ((fromIntegral size + 3) `div` 4) get
       return (SRBE (SR size nCount nStarts nSizes mCount mStarts mSizes packed rsv))
      where
       word32 :: Get Word32
       word32 = get
       word32s :: Word32 -> Get [Word32]
       word32s k = replicateM (fromIntegral k) word32
    -- Writes the fields in the same order as 'get' reads them.  The reserved
    -- word is always written as 0, whatever reserved2 holds.
    put (SRBE sr) = do
       put (dnaSize sr)
       put (nBlockCount sr)
       forM_ (nBlockStarts sr) put
       forM_ (nBlockSizes sr) put
       put (maskBlockCount sr)
       forM_ (maskBlockStarts sr) put
       forM_ (maskBlockSizes sr) put
       put (0 :: Word32)
       forM_ (packedDna sr) put


instance Binary SRLE where
    -- Same layout as SRBE, but every 32-bit word is byte-swapped as it is
    -- read, so the resulting SR holds the swapped values.  Packed DNA bytes
    -- are taken as they are.
    get = do
       size    <- swapped
       nCount  <- swapped
       nStarts <- swappedList nCount
       nSizes  <- swappedList nCount
       mCount  <- swapped
       mStarts <- swappedList mCount
       mSizes  <- swappedList mCount
       rsv     <- swapped -- reserved word missing from the format description
       packed  <- replicateM ((fromIntegral size + 3) `div` 4) get
       return (SRLE (SR size nCount nStarts nSizes mCount mStarts mSizes packed rsv))
      where
       swapped :: Get Word32
       swapped = liftM (bswap 4) get
       swappedList :: Word32 -> Get [Word32]
       swappedList k = replicateM (fromIntegral k) swapped
    -- Byte-swaps each 32-bit word before writing.  The reserved word is always
    -- written as 0; the packed DNA bytes are written unchanged.
    put (SRLE sr) = do
       putSwapped (dnaSize sr)
       putSwapped (nBlockCount sr)
       forM_ (nBlockStarts sr) putSwapped
       forM_ (nBlockSizes sr) putSwapped
       putSwapped (maskBlockCount sr)
       forM_ (maskBlockStarts sr) putSwapped
       forM_ (maskBlockSizes sr) putSwapped
       put (0 :: Word32)
       forM_ (packedDna sr) put
      where
       putSwapped :: Word32 -> Put
       putSwapped = put . bswap 4



-- Used to convert from sequence data in the Sequence data structure to ByteString
--
-- Unpacks the 2-bit bases of a sequence record into nucleotide characters,
-- one per base up to dnaSize.  Positions covered by an N block are emitted
-- as 'N' and positions covered by a mask block are lower-cased (so a position
-- in both comes out as 'n').  Block positions are consumed front to back, so
-- they are expected to be in increasing order.
fromSR :: SR -> B.ByteString
fromSR sr = B.unfoldr go (0, low, unknowns, take (fromIntegral $ dnaSize sr) dna)
    where
       low      = combine (maskBlockStarts sr) (maskBlockSizes sr)
       unknowns = combine (nBlockStarts sr) (nBlockSizes sr)

       -- Expand (start, size) pairs into the list of covered positions.
       -- Counting elements off from the start, rather than enumerating up to
       -- start+size-1, keeps a zero-sized block empty even at position 0,
       -- where the subtraction would wrap around in Word32.
       combine :: [Word32] -> [Word32] -> [Word32]
       combine starts sizes = concat (zipWith (\p l -> genericTake l [p ..]) starts sizes)

       -- Unpack a 2Bit packed DNA sequence into a list of 2 bit values 0-3,
       -- most significant pair first.
       dna = decodeDNA $ packedDna sr

       decodeDNA :: [Word8] -> [Word8]
       decodeDNA = concatMap (\x -> [ (x `shiftR` s) .&. 0x03 | s <- [6, 4, 2, 0] ])

       -- Map 2Bit nucleotide encodings to their character equivalents
       dec1 :: Word8 -> Char
       dec1 0 = 'T'
       dec1 1 = 'C'
       dec1 2 = 'A'
       dec1 3 = 'G'
       dec1 x = error ("can't decode value "++show x)

       -- One unfold step.  State: current position, remaining masked
       -- positions, remaining unknown positions, remaining 2-bit values.
       -- The local signatures use concrete types: the comparisons and 'show'
       -- need Eq/Show, which a bare Num constraint no longer provides.
       go :: (Word32, [Word32], [Word32], [Word8])
          -> Maybe (Char, (Word32, [Word32], [Word32], [Word8]))
       go (_, _, _, []) = Nothing
       go (pos, ls, us, d:ds) =
           -- force the counter so it does not pile up as a chain of (+1)
           -- thunks when there are no blocks to compare it against
           next `seq` Just (ch, (next, ls', us', ds))
           where
              next            = pos + 1
              (masked, ls')   = hit ls
              (unknown, us')  = hit us
              -- does the next listed position coincide with the current one?
              hit (x:xs) | x == pos = (True, xs)
              hit xs                = (False, xs)
              base = if unknown then 'N' else dec1 d
              ch   = if masked then toLower base else base
--    go x = error (show x)



-- Placeholder for the inverse of fromSR; it has never been implemented.
-- Evaluating the result raises an error that names this function, rather
-- than the anonymous failure of a bare 'undefined'.
toSR :: B.ByteString -> SR
toSR _ = error "Bio.Sequence.TwoBit.toSR: not implemented"



-- Cut a ByteString into consecutive pieces of the given lengths.  Whatever
-- remains after the last length is returned as one final piece, so the
-- result always has one more element than the list of lengths.
splits :: [Int64] -> B.ByteString -> [B.ByteString]
splits sizes = foldr cut (\remainder -> [remainder]) sizes
    where
       cut size continue bytes =
           case B.splitAt size bytes of
              (piece, remainder) -> piece : continue remainder




-- | Parse a (lazy) ByteString as sequences in the 2bit format.
decode2Bit :: B.ByteString -> [Sequence Unknown]
decode2Bit cs = zipWith mkSeq (map name es) records
    where
       -- header and index entries from the front of the input
       Entries h es = decode cs

       -- the 32-bit index offsets, widened to the 64-bit type that the
       -- ByteString functions take
       starts :: [Int64]
       starts = map (fromIntegral . offset) es

       -- the first offset says how much to skip before the first record;
       -- the differences between neighbouring offsets are the record sizes
       (firstStart : gaps) = zipWith (-) starts (0 : starts)

       -- one raw, still undecoded chunk per sequence record
       raw = splits gaps (B.drop firstStart cs)

       -- decode every chunk with the byte order announced by the header
       records
           | swap h    = map (unSRLE . decode) raw
           | otherwise = map (unSRBE . decode) raw

       -- a sequence carries its index name, the decoded bases and no quality
       mkSeq nm sr = Seq nm (fromSR sr) Nothing
                   



-- Aliases naming the three parts that encode2Bit extracts from a Sequence.
type SequenceLabel = SeqData
type SequenceSize = Offset
type SequenceData = SeqData 
-- (label, length, data) for one sequence
type TwoBitData = (SequenceLabel, SequenceSize, SequenceData)

-- | Marshall from neutral representation to the 2Bit ByteString rep.
--
-- The output is the header, then one index entry per sequence, then one
-- sequence record per sequence.  Each index entry holds the byte offset of
-- its record, computed from the actual sizes of the header, the index and the
-- preceding records.
--
-- Upper- and lower-case A, C, G and T are packed two bits per base.  Runs of
-- lower-case characters are recorded as mask blocks, and runs of any other
-- character (such as N) are recorded as N blocks and packed as T.  Labels
-- longer than 255 bytes are truncated, because the index stores the name
-- length in a single byte.
encode2Bit :: [Sequence a] -> B.ByteString
encode2Bit ss = B.concat (encode header : map encode entries ++ map (encode . SRBE) records)
    where
       -- (label, length, data) for every input sequence
       items :: [TwoBitData]
       items = [ (B.take 255 (seqlabel s), seqlength s, seqdata s) | s <- ss ]

       records :: [SR]
       records = [ buildSR size dna | (_, size, dna) <- items ]

       -- The Header instance ignores 'swap' when writing.
       header :: Header
       header = Header { swap     = False
                       , version  = default_version
                       , count    = fromIntegral (length ss)
                       , reserved = 0 }

       -- The header is four 32-bit words.
       headerSize :: Int64
       headerSize = 16

       -- Each index entry is a length byte, the name, and a 32-bit offset.
       indexSize :: Int64
       indexSize = sum [ 1 + B.length nm + 4 | (nm, _, _) <- items ]

       -- Encoded size of a record: four fixed 32-bit words (dnaSize, the two
       -- block counts, reserved), two 32-bit words per block, and one byte
       -- per four bases, rounded up.
       recordSize :: SR -> Int64
       recordSize sr = 16
                     + 8 * (fromIntegral (nBlockCount sr) + fromIntegral (maskBlockCount sr))
                     + (fromIntegral (dnaSize sr) + 3) `div` 4

       -- Offset of every record: the first one starts right after the index.
       offsets :: [Int64]
       offsets = scanl (+) (headerSize + indexSize) (map recordSize records)

       entries :: [Entry]
       entries = zipWith (\(nm, _, _) off -> Entry { name = nm, offset = fromIntegral off })
                         items offsets

       -- Build a 2Bit Sequence Record!!
       buildSR :: SequenceSize -> SequenceData -> SR
       buildSR size dna = SR { dnaSize         = fromIntegral size
                             , nBlockCount     = genericLength nRuns
                             , nBlockStarts    = map fst nRuns
                             , nBlockSizes     = map snd nRuns
                             , maskBlockCount  = genericLength maskRuns
                             , maskBlockStarts = map fst maskRuns
                             , maskBlockSizes  = map snd maskRuns
                             , packedDna       = packBases (map enc1 (B.unpack dna))
                             , reserved2       = 0 }
           where
              nRuns    = runsOf isUnknown dna
              maskRuns = runsOf isLower dna

       -- Anything that is not A, C, G or T (in either case) is unknown.
       isUnknown :: Char -> Bool
       isUnknown c = toUpper c `notElem` "ACGT"

       -- (start, size) of every maximal run of characters satisfying p.
       runsOf :: (Char -> Bool) -> SeqData -> [(Word32, Word32)]
       runsOf p dna = [ (fromIntegral start, fromIntegral (B.length run))
                      | (start, run) <- zip starts runs
                      , p (B.head run) ]   -- groupBy never yields empty runs
           where
              runs   = B.groupBy (\a b -> p a == p b) dna
              starts = scanl (+) 0 (map B.length runs)

       -- Map a nucleotide to its respective 2Bit encoding.  Case is ignored;
       -- unknown characters get the code of 'T' and are covered by an N block.
       enc1 :: Char -> Word8
       enc1 c = case toUpper c of
                   'T' -> 0
                   'C' -> 1
                   'A' -> 2
                   'G' -> 3
                   _   -> 0

       -- Build a 2Bit packed DNA sequence: four codes per byte, first code in
       -- the two most significant bits; a short final group is zero-padded.
       packBases :: [Word8] -> [Word8]
       packBases [] = []
       packBases codes = packQuad quad : packBases rest
           where
              (quad, rest) = splitAt 4 codes
              packQuad = foldl' (\acc code -> shiftL acc 2 .|. code) 0 . take 4 . (++ repeat 0)



-- Strip the SRBE wrapper.
unSRBE :: SRBE -> SR
unSRBE wrapped = case wrapped of
    SRBE sr -> sr

-- Strip the SRLE wrapper.
unSRLE :: SRLE -> SR
unSRLE wrapped = case wrapped of
    SRLE sr -> sr




-- | Read sequences from a file in 2bit format and
-- | unmarshall/deserialize into Sequence format.
read2Bit  :: FilePath -> IO [Sequence Unknown]
read2Bit path = liftM decode2Bit (B.readFile path)

-- | Read sequences from a file handle in the 2bit format and
-- | unmarshall/deserialize into Sequence format.
hRead2Bit :: Handle   -> IO [Sequence Unknown]
hRead2Bit handle = liftM decode2Bit (B.hGetContents handle)




-- | Marshall/serialize [Sequence] into 2Bit format and write to a file.
write2Bit  :: FilePath -> [Sequence a] -> IO ()
write2Bit path sequences = B.writeFile path (encode2Bit sequences)

-- | Marshall/serialize [Sequence] into 2Bit format and write to a file using handle.
hWrite2Bit :: Handle   -> [Sequence a] -> IO ()
hWrite2Bit handle sequences = B.hPut handle (encode2Bit sequences)