{- | Read and write the SFF file format used by
   Roche\/454 sequencing to store flowgram data.

   A flowgram is a series of values (intensities) representing the homopolymer
   runs of A, G, C, and T in a fixed cycle, and is usually displayed as a
   histogram.

   The Staden package's io_lib contains a C routine for parsing this format.
   According to comments in the sources, that implementation is based on a
   file called getsff.c, which I have been unable to track down.

   All values are believed to be stored big-endian.
-}

module Bio.Sequence.SFF ( SFF(..), CommonHeader(..)
                        , ReadHeader(..), ReadBlock(..)
                        , readSFF, writeSFF, writeSFF', recoverSFF
                        , sffToSequence, trim, trimFromTo, trimKey
                        , baseToFlowPos, flowToBasePos
                        , test, convert, flowgram
                        , packFlows, unpackFlows
                        , Flow, Qual, Index, SeqData, QualData
                        , ReadName (..), decodeReadName, encodeReadName
                        ) where

import Bio.Sequence.SeqData
import Bio.Sequence.SFF_name

import Data.Int
import qualified Data.ByteString.Lazy as LB
import qualified Data.ByteString.Lazy.Char8 as LBC
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as BC
import Data.ByteString (ByteString)
import Control.Monad (when,replicateM,replicateM_,liftM)

import Data.List (intersperse)
import Data.Binary
import Data.Binary.Get (getByteString,getLazyByteString)
import qualified Data.Binary.Get as G
import Data.Binary.Put (putByteString,putLazyByteString)
import Data.Char (toUpper, toLower)
import Text.Printf (printf)
import System.IO

-- | The type of a single flowgram value
type Flow = Int16
-- | The type of a per-base flow index (the increment relative to the previous base)
type Index = Word8

-- Constants describing static parts of the format
magic :: Int32
magic = 0x2e736666

versions :: [Int32]
versions = [1]

-- | Read an SFF file
readSFF :: FilePath -> IO SFF
readSFF f = return . decode =<< LB.readFile f

-- | Strip the key sequence (from the 'CommonHeader') from the start of a
--   sequence, returning 'Nothing' if it does not match.
trimKey :: CommonHeader -> Sequence Nuc -> Maybe (Sequence Nuc)
trimKey ch (Seq n s q) = let (k,s2) = LB.splitAt (fromIntegral $ key_length ch) s
                          in if LBC.map toLower k==LBC.map toLower (LB.fromChunks [key ch]) 
                             then Just $ Seq n s2 (liftM (LB.drop (fromIntegral $ key_length ch)) q)
                             else Nothing -- error ("Couldn't match key in sequence "++LBC.unpack n++" ("++LBC.unpack k++" vs. "++BC.unpack (key ch)++")!")

-- | Extract the reads of an SFF as sequences, annotating each label with the
--   quality clip points and marking the clipped ends in lower case.
sffToSequence :: SFF -> [Sequence Nuc]
sffToSequence (SFF ch rs) = map r2s rs
    where r2s r = clip (read_header r, bases r, quality r)
          clip (h, s, q) = let (left,right) = (clip_qual_left h,clip_qual_right h)
                               split x = let (a,b) = LB.splitAt (fromIntegral right) x 
                                             (c,d) = LB.splitAt (fromIntegral left-1) a
                                         in [c,d,b]
                           in {- trim_key $ -} Seq (LB.fromChunks [read_name h ,BC.pack (" qclip: "++show left++".."++show right)])
                                  (let [a,b,c] = split s in LBC.concat [LBC.map toLower a,LBC.map toUpper b,LBC.map toLower c])
                                  (Just q)
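
-- A usage sketch (the helper name and the FASTA-style output are illustrative,
-- not part of the module's API): print each clip-annotated read from a file.
_dumpReads :: FilePath -> IO ()
_dumpReads f = do
  sff <- readSFF f
  mapM_ (\(Seq n s _) -> LBC.putStrLn (LBC.cons '>' n) >> LBC.putStrLn s)
        (sffToSequence sff)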

-- Trimming the flowgram is necessary as well, but it is not obvious how to
-- deal with the resulting shift in flow positions - i.e. what to do when
-- trimming "splits" a flow into trimmed and untrimmed bases.
-- | Trim a read to the specified range of sequence positions.
-- The current implementation has the unintended side effect of always
-- trimming the flowgram down to a basecalled position.
trimFromTo :: (Integral i) => i -> i -> ReadBlock -> ReadBlock
trimFromTo l r rd = let trim_seq = LB.drop (fromIntegral l) . LB.take (fromIntegral r)
                        trim_seq' = B.drop (fromIntegral l) . B.take (fromIntegral r)
                        trim_flw = B.drop ((2*) $ fromIntegral $ baseToFlowPos rd l) . B.take ((2*) $ fromIntegral $ baseToFlowPos rd r)
                    in rd { read_header = read_header rd -- FIXME: Update num_bases?
                          , flow_data = trim_flw (flow_data rd)
                          , flow_index = trim_seq' (flow_index rd)
                          , bases = trim_seq (bases rd)
                          , quality = trim_seq (quality rd)
                          }

-- | Trim a read according to clipping information
trim :: ReadBlock -> ReadBlock
trim rb = let rh = read_header rb in trimFromTo (clip_qual_left rh-1) (clip_qual_right rh) rb
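
-- A usage sketch (assumed helper, not part of the API): count the bases that
-- survive quality clipping.  Note the FIXME in 'trimFromTo': the read headers
-- still carry the original 'num_bases', so we measure the trimmed 'bases'
-- field directly.
_clippedBaseCount :: FilePath -> IO Int64
_clippedBaseCount f = do
  SFF _ rds <- readSFF f
  return (sum (map (LB.length . bases . trim) rds))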

-- | Convert a flow position to the corresponding sequence position
flowToBasePos :: Integral i => ReadBlock -> i -> Int
flowToBasePos rd fp = length $ takeWhile (<fp) $ scanl (+) 0 $ map fromIntegral $ B.unpack $ flow_index rd

-- | Convert a sequence position to the corresponding flow position
baseToFlowPos :: Integral i => ReadBlock -> i -> Int
baseToFlowPos rd sp = sum $ map fromIntegral $ B.unpack $ B.take (fromIntegral sp) $ flow_index rd
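
-- An illustrative sketch (assumed helper, not part of the original API): the
-- 1-based flow number that called each base, obtained by accumulating the
-- per-base increments stored in 'flow_index'.  For example, an index of
-- [3,1,0,2] gives flow numbers [3,4,4,6].
_baseFlowNumbers :: ReadBlock -> [Int]
_baseFlowNumbers = tail . scanl (+) 0 . map fromIntegral . B.unpack . flow_index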

-- | Read an SFF file, trying to resynchronize if a read header looks corrupt
--   (e.g. when an index block is interleaved with the reads).
recoverSFF :: FilePath -> IO SFF
recoverSFF f = return . unRecovered . decode =<< LB.readFile f

-- | Write an 'SFF' to the specified file name
writeSFF :: FilePath -> SFF -> IO ()
writeSFF = encodeFile

-- | Write an 'SFF' to the specified file name, but go back and
--   update the read count.  Useful if you want to output a lazy
--   stream of 'ReadBlock's.  Returns the number of reads written.
writeSFF' :: FilePath -> SFF -> IO Int
writeSFF' f (SFF hs rs) = do
  h <- openFile f WriteMode
  LBC.hPut h $ encode hs
  c <- writeReads h (fromIntegral $ flow_length hs) rs
  hSeek h AbsoluteSeek 20   -- the num_reads field sits at byte offset 20 in the common header
  LBC.hPut h $ encode c
  hClose h
  return $ fromIntegral c
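
-- A usage sketch (file names and the 100-base cutoff are hypothetical): copy
-- an SFF file, keeping only reads with at least 100 called bases.  Since the
-- number of surviving reads isn't known up front, 'writeSFF'' goes back and
-- patches the read count in the header afterwards.
_filterShortReads :: FilePath -> FilePath -> IO Int
_filterShortReads infile outfile = do
  SFF hdr rds <- readSFF infile
  writeSFF' outfile (SFF hdr (filter ((>= 100) . num_bases . read_header) rds))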

writeReads :: Handle -> Int -> [ReadBlock] -> IO Int32
writeReads _ _ [] = return 0
writeReads h i (r:rs) = do
  LBC.hPut h $ encode (RBI i r)
  c <- writeReads h i rs
  return $! (c+1)

-- | Wrapper for 'ReadBlock's: writing one requires the flow length from the
--   'CommonHeader' as additional information.
data RBI = RBI Int ReadBlock

instance Binary RBI where
    put (RBI c r) = putRB c r
    get = error "RBI is only used for serialization"
      
-- --------------------------------------------------
-- | Test serialization by printing the header and first two reads of an SFF
--   file, and the same again after a decode/encode round trip.
test :: FilePath -> IO ()
test file = do
  (SFF h rs) <- readSFF file
  let sff = SFF h (take 2 rs)
  print sff
  putStrLn ""
  print (decode $ encode sff :: SFF)

-- --------------------------------------------------
-- | Convert a file by decoding and re-encoding it, writing the result to the
--   original name with ".out" appended.  This loses the index (which isn't
--   strictly necessary).
convert :: FilePath -> IO ()
convert file = writeSFF (file++".out") =<< readSFF file

-- | Output the given number of zero bytes (padding)
pad :: Integral a => a -> Put
pad x = replicateM_ (fromIntegral x) (put zero) where zero = 0 :: Word8 

-- | Skip the given number of (padding) bytes
skip :: Integral a => a -> Get ()
skip = G.skip . fromIntegral

-- | The data structure storing the contents of an SFF file (modulo the index)
data SFF = SFF !CommonHeader [ReadBlock]

instance Show SFF where 
    show (SFF h rs) = (show h ++ "Read Blocks:\n\n" ++ concatMap show rs)

instance Binary SFF where
    get = do
      -- Parse CommonHeader
      chead <- get
      -- Get the ReadBlocks
      rds <- replicateM (fromIntegral (num_reads chead))
                                   (do 
                                      rh <- get :: Get ReadHeader
                                      getRB chead rh
                                   )
      return (SFF chead rds)

    put (SFF hd rds) = do
      put hd
      mapM_ (put . RBI (fromIntegral $ flow_length hd)) rds

{-# INLINE getRB #-}
getRB :: CommonHeader -> ReadHeader -> Get ReadBlock
getRB chead rh = do
  let nb  = fromIntegral $ num_bases rh   -- as Int, for 'getByteString'
      nb' = fromIntegral $ num_bases rh   -- as Int64, for 'getLazyByteString'
      fl = fromIntegral $ flow_length chead
  fg <- getByteString (2*fl)
  fi <- getByteString nb
  bs <- getLazyByteString nb'
  qty <- getLazyByteString nb'
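  -- The read data section (two bytes per flow value, plus one byte each of
  -- index, base and quality per called base) is padded to an 8-byte boundary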
  let l = (fl*2+nb*3) `mod` 8
  when (l > 0) (skip (8-l))
  return (ReadBlock rh fg fi bs qty)

-- | A ReadBlock can't be an instance of Binary directly, since it depends on
--   information from the CommonHeader.
putRB :: Int -> ReadBlock -> Put
putRB fl rb = do
  put (read_header rb)
  putByteString (flow_data rb)
  -- Ensure that the flowgram has the correct length, zero-padding if the
  -- flow data has been trimmed
  replicateM_ (2*fl - B.length (flow_data rb)) (put (0::Word8))
  putByteString (flow_index rb)
  putLazyByteString (bases rb)
  putLazyByteString (quality rb)
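  -- Pad the read data section to an 8-byte boundary (mirrors the skip in 'getRB')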
  let nb = fromIntegral $ num_bases $ read_header rb
      l = (fl*2+nb*3) `mod` 8
  when (l > 0) (pad (8-l))

-- | Unpack the flow_data field into a list of flow values
unpackFlows :: ByteString -> [Flow]
unpackFlows = dec . map fromIntegral . B.unpack 
    where dec (d1:d2:rest) = d1*256+d2 : dec rest
          dec [] = []
          dec _  = error "odd flowgram length?!"

-- | Pack a list of flows into the corresponding binary structure (the flow_data field)
packFlows :: [Flow] -> ByteString
packFlows = B.pack . map fromIntegral . merge 
  where merge (x:xs) = let (a,b) = x `divMod` 256 in a:b:merge xs
        merge [] = []
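
-- An illustrative property sketch (assumed helper): packing and then unpacking
-- flow values gives back the original list, e.g. for intensities 1.05, 0.08
-- and 2.13 stored as [105, 8, 213].
_prop_flowRoundTrip :: [Flow] -> Bool
_prop_flowRoundTrip fs = unpackFlows (packFlows fs) == fs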

-- ----------------------------------------------------------
-- | SFF has a 31-byte common header.
--   Todo: remove items that are derivable (counters, magic, etc.)
--   The common header length field points to the first read header.
--   Also, the format allows the index to be placed anywhere between reads,
--   so we should really keep count and check before each read.  In practice,
--   it seems to be placed after the reads.
--
--   The following two fields are considered part of the header, but as
--   they are static, they are not part of the data structure
--     magic   :: Word32   -- ^ 0x2e736666, i.e. the string ".sff"
--     version :: Word32   -- ^ 0x00000001

data CommonHeader = CommonHeader {
          index_offset                            :: Int64    -- ^ Offset of the (optional) index section
        , index_length, num_reads                 :: Int32
        , key_length, flow_length                 :: Int16
        , flowgram_fmt                            :: Word8
        , flow, key                               :: ByteString 
        }

instance Show CommonHeader where
    show (CommonHeader io il nr kl fl fmt f k) =
        "Common Header:\n\n" ++ (unlines $ map ("    "++) 
                                 ["index_off:\t"++show io ++"\tindex_len:\t"++show il
                                 ,"num_reads:\t"++show nr
                                 ,"key_len:\t"  ++show kl ++ "\tflow_len:\t"++show fl
                                 ,"format\t:"   ++show fmt
                                 ,"flow\t:"     ++BC.unpack f
                                 ,"key\t:"      ++BC.unpack k
                                 , ""
                                 ])

instance Binary CommonHeader where
    get = do { m <- get ; when (m /= magic)   $ error (printf "Incorrect magic number - got %8x, expected %8x" m magic)
             ; v <- get ; when (not (v `elem` versions)) $ error (printf "Unexpected version - got %d, supported are: %s" v (unwords $ map show versions))
             ; io <- get ; ixl <- get ; nrd <- get
             ; chl <- get ; kl <- get ; fl <- get ; fmt <- get
             ; fw <- getByteString (fromIntegral fl)
             ; k  <- getByteString (fromIntegral kl)
             ; skip (chl-(31+fl+kl)) -- skip to boundary
             ; return (CommonHeader io ixl nrd kl fl fmt fw k)
             }

    -- No index is written out, so its offset is zeroed
    put ch = let CommonHeader io il nr kl fl fmt f k = ch { index_offset = 0 } in
        do { let cl = 31+fl+kl
                 l = cl `mod` 8
                 padding = if l > 0 then 8-l else 0
           ; put magic; put (last versions); put io; put il; put nr; put (cl+padding); put kl; put fl; put fmt
           ; putByteString f; putByteString k
           ; pad padding -- skip to boundary
           }
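
-- A usage sketch (assumed helper, not part of the original API): decode only
-- the common header of a file, e.g. to inspect the flow order or key sequence.
-- Only the leading header bytes of the lazily read file need to be parsed.
_readCommonHeader :: FilePath -> IO CommonHeader
_readCommonHeader f = return . decode =<< LB.readFile f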

-- ---------------------------------------------------------- 
-- | Each Read has a fixed read header
data ReadHeader = ReadHeader {
      name_length                           :: Int16
    , num_bases                             :: Int32
    , clip_qual_left, clip_qual_right
    , clip_adapter_left, clip_adapter_right :: Int16
    , read_name                             :: ByteString
}

instance Show ReadHeader where
    show (ReadHeader nl nb cql cqr cal car rn) =
        ("    Read Header:\n" ++) $ unlines $ map ("        "++) 
                    [ "name_len:\t"++show nl, "num_bases:\t"++show nb
                    , "clip_qual:\t"++show cql++"..."++show cqr
                    , "clip_adap:\t"++show cal++"..."++show car
                    , "read name:\t"++BC.unpack rn
                    , "" 
                    ]

instance Binary ReadHeader where
    get = do
      { rhl <- get; nl <- get; nb <- get
      ; cql <- get; cqr <- get ; cal <- get ; car <- get
      ; n <- getByteString (fromIntegral nl)
      ; skip (rhl - (16+ nl))
      ; return (ReadHeader nl nb cql cqr cal car n)
      }
    put (ReadHeader nl nb cql cqr cal car rn) = 
        do { let rl = 16+nl
                 l = rl `mod` 8
                 padding = if l > 0 then 8-l else 0
           ; put (rl+padding); put nl; put nb; put cql; put cqr; put cal; put car
           ; putByteString rn 
           ; pad padding
           }

-- ----------------------------------------------------------
-- | This contains the actual flowgram for a single read.
data ReadBlock = ReadBlock {
      read_header                :: !ReadHeader
    -- The data block
    , flow_data                  :: !ByteString -- NB: use 'unpackFlows' to decode the flow values
    , flow_index                 :: !ByteString
    , bases                      :: !SeqData
    , quality                    :: !QualData
    }

-- | Extract the flowgram of a read as a list of flow values
flowgram :: ReadBlock -> [Flow]
flowgram = unpackFlows . flow_data
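
-- An illustrative sketch (assumed helper): flow values rescaled to intensities,
-- assuming the common flowgram format code 1, where each stored value is the
-- intensity multiplied by 100.
_flowIntensities :: ReadBlock -> [Double]
_flowIntensities = map ((/ 100) . fromIntegral) . flowgram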

instance Show ReadBlock where
    show (ReadBlock h f i b q) =
        show h ++ unlines (map ("     "++) 
            ["flowgram:\t"++show (unpackFlows f)
            , "index:\t"++(concat . intersperse " " . map show . B.unpack) i
            , "bases:\t"++LBC.unpack b
            , "quality:\t"++(concat . intersperse " " . map show . LB.unpack) q
            , ""
            ])

-- ------------------------------------------------------------
-- | RSFF wraps an SFF to provide a Binary instance with more error checking:
--   if a read header looks corrupt, decoding skips ahead one byte at a time
--   until a plausible header is found (see 'getSaneHeader').
data RSFF = RSFF { unRecovered :: SFF }

instance Binary RSFF where 
    get = do
      -- Parse CommonHeader
      chead <- get
      -- Get the first read block
      r1 <- do rh <- get 
               getRB chead rh
      -- Get subsequent read blocks
      rds <- replicateM (fromIntegral (num_reads chead) - 1)  -- the first read has already been parsed
                                   (do rh <- getSaneHeader (take 4 $ BC.unpack $ read_name $ read_header r1)
                                       getRB chead rh)
      return (RSFF $ SFF chead (r1:rds))
    put = error "You should not serialize an RSFF"

-- | This allows us to decode the constant-size prefix of a read header, in
--   order to verify its correctness.
data PartialReadHeader = PartialReadHeader {
      _pread_header_length                    :: Int16
    , _pname_length                           :: Int16
    , _pnum_bases                             :: Int32
    , _pclip_qual_left, _pclip_qual_right
    , _pclip_adapter_left, _pclip_adapter_right :: Int16
    , _pread_name                               :: ByteString -- only the first four bytes of the name
}

instance Binary PartialReadHeader where
    get = do { rhl <- get; nl <- get; nb <- get; ql <- get; qr <- get; al <- get; ar <- get; rn <- getByteString 4 
             ; return (PartialReadHeader rhl nl nb ql qr al ar rn) }
    put = error "You should not serialize a PartialReadHeader"

-- | Ensure that the header we're about to decode matches our expectations,
--   by sanity-checking its fixed-size fields and comparing the start of the
--   read name against the given prefix.
getSaneHeader :: String -> Get ReadHeader
getSaneHeader prefix = do
  buf <- getLazyByteString 20
  decodeSaneH prefix buf  

decodeSaneH :: String -> LBC.ByteString -> Get ReadHeader
decodeSaneH prefix buf = do
  let PartialReadHeader rhl nl _nb _ql _qr _al _ar rn = decode buf
  if rhl >= 20 && nl > 0 && and (zipWith (==) prefix (BC.unpack rn))
      then do buf2 <- getLazyByteString (fromIntegral rhl-20)
              return (decode $ LB.concat [buf,buf2])
      else do x <- getLazyByteString 1 -- not a sane header: skip one byte and try again
              decodeSaneH prefix (LBC.concat [buf,x])