-- | Parsers for BAM and SAM.

module Bio.Bam.Reader (
    decodeBam,
    decodeBamFile,
    decodeBamFiles,
    IncompatibleRefs(..),

    decodePlainBam,
    decodePlainSam,
    getBamMeta,
    getBamRaw,
    getSamRec,

    concatInputs,
    mergeInputsOn,
    guardRefCompat,
    coordinates,
    qnames
                      ) where

import Bio.Bam.Header
import Bio.Bam.Rec
import Bio.Bam.Writer               ( packBam )
import Bio.Streaming
import Bio.Streaming.Bgzf           ( getBgzfHdr, bgunzip )
import Bio.Prelude
import Data.Attoparsec.ByteString   ( anyWord8 )

import qualified Data.Attoparsec.ByteString.Char8   as P
import qualified Data.ByteString                    as B
import qualified Data.ByteString.Char8              as C
import qualified Data.HashMap.Strict                as M
import qualified Data.Vector.Generic                as V
import qualified Data.Vector.Storable               as W
import qualified Data.Vector.Unboxed                as U
import qualified Bio.Streaming.Bytes                as S
import qualified Bio.Streaming.Parse                as S
import qualified Streaming.Prelude                  as Q

{- | Decodes either BAM or SAM.

The input can be plain, gzip'ed or bgzf'd and either BAM or SAM.  BAM
is reliably recognized by its four-byte magic number, anything else is
treated as SAM.  The offsets stored in BAM records make sense only for
uncompressed or bgzf'd BAM.
-}
decodeBam :: (MonadIO m, MonadLog m)
          => S.ByteStream m r
          -> m (BamMeta, Stream (Of BamRaw) m r)
decodeBam raw = do
    (bgzf, hdr, body) <- getBgzfHdr raw
    -- 'getBgzfHdr' hands back the bytes it inspected; put them in front
    -- again, then use 'bgunzip' if a bgzf header was seen, 'S.gunzip'
    -- otherwise.
    let restored = S.consChunk hdr body
        plain    = case bgzf of
                        Nothing -> S.gunzip restored
                        Just _  -> bgunzip restored
    -- Peek at four bytes: the BAM magic selects the BAM parser; anything
    -- else is re-attached and parsed as SAM.
    S.splitAt' 4 plain >>= \(magic :> rest) ->
        if magic == "BAM\SOH"
            then decodePlainBam rest
            else decodePlainSam (S.consChunk magic rest)
{-# INLINE decodeBam #-}

-- | Opens one file with 'streamFile', decodes it with 'decodeBam' (so BAM
-- or SAM, compressed or not) and passes header and record stream to the
-- continuation.  The continuation runs inside 'streamFile'; presumably the
-- records must be consumed before it returns (TODO confirm that
-- 'streamFile' closes the file on exit).
decodeBamFile :: (MonadIO m, MonadLog m, MonadMask m) => FilePath -> (BamMeta -> Stream (Of BamRaw) m () -> m r) -> m r
decodeBamFile f k = streamFile f $ \bytes -> do
    (meta, records) <- decodeBam bytes
    k meta records
{-# INLINE decodeBamFile #-}

{- | Reads multiple bam files.

A continuation is run on the list of headers and streams.  Since no
attempt is made to unify the headers, this will work for completely
unrelated bam files.  All files are opened at the same time, which might
run into the file descriptor limit given some ridiculous workflows.  The
file name \"-\" stands for stdin.
-}
decodeBamFiles :: (MonadMask m, MonadLog m, MonadIO m) => [FilePath] -> ([(BamMeta, Stream (Of BamRaw) m ())] -> m r) -> m r
decodeBamFiles fs k = case fs of
    []          -> k []
    path : more -> withInput path $ \bytes -> do
        decoded <- decodeBam bytes
        -- Open the remaining files while this one is still open, then
        -- prepend this file's result to theirs.
        decodeBamFiles more (k . (decoded :))
  where
    -- "-" is read from stdin; every other name is opened as a file.
    withInput "-"  body = body (streamHandle stdin)
    withInput name body = streamFile name body
{-# INLINE decodeBamFiles #-}

-- | Decodes uncompressed BAM whose magic number has already been removed.
--
-- On a header parse error, the error is logged, the rest of the input is
-- run to completion, and an empty header with no records is returned.  If
-- the input ends inside the header, an 'S.EofException' is logged and,
-- again, nothing is returned.  Otherwise the records are parsed lazily by
-- 'getBamRaw'.
decodePlainBam :: MonadLog m => S.ByteStream m r -> m (BamMeta, Stream (Of BamRaw) m r)
decodePlainBam bytes =
    S.parse (const getBamMeta) bytes >>= either failed (either truncated parsed)
  where
    failed (exception, rest) = do
        logMsg Error exception
        r <- S.effects rest
        pure (mempty, pure r)

    truncated r = do
        logMsg Error S.EofException
        pure (mempty, pure r)

    parsed (h, s) = return (h, Q.unfoldr (S.parseLog Error getBamRaw) s)

-- | Parses the header of an uncompressed BAM stream whose magic number has
-- already been consumed: a length-prefixed block of SAM header text,
-- followed by the count-prefixed binary reference list.  The two are
-- combined by @mmerge@ below.
getBamMeta :: Monad m => S.Parser r m BamMeta
getBamMeta = liftM2 mmerge get_bam_header get_ref_array
  where
    -- SAM header text, confined to the announced number of bytes.
    get_bam_header  = do hdr_len <- S.getWord32
                         S.isolate (fromIntegral hdr_len) (S.atto parseBamMeta)

    -- Each reference is a length-prefixed name followed by the sequence
    -- length.  The stored name is NUL-terminated, so we cut at the first
    -- NUL.  An unconditional 'init' would crash on a (malformed) empty
    -- name and would truncate a name that lacks the terminator.
    get_ref_array = do nref <- S.getWord32
                       V.fromList `liftM` replicateM (fromIntegral nref)
                            (do nm <- S.getWord32 >>= S.getString . fromIntegral
                                ln <- S.getWord32
                                return $! BamSQ (B.takeWhile (/= 0) nm) (fromIntegral ln) [])

    -- Need to merge information from header into actual reference list.
    -- The latter is the authoritative source for the *order* of the
    -- sequences, so leftovers from the header are discarded.  Merging
    -- is by name.  So we merge information from the header into the
    -- list, then replace the header information.
    mmerge meta refs =
        let tbl = M.fromList [ (sq_name sq, sq) | sq <- V.toList (unRefs (meta_refs meta)) ]
        in meta { meta_refs = Refs $ fmap (\s -> maybe s (mmerge' s) (M.lookup (sq_name s) tbl)) refs }

    -- @l@ is the binary entry, @r@ the header entry of the same name.  The
    -- extra header fields are only taken over when the lengths agree.
    mmerge' l r | sq_length l == sq_length r = l { sq_other_shit = sq_other_shit l ++ sq_other_shit r }
                | otherwise                  = l -- contradiction in header, but we'll just ignore it
{-# INLINABLE getBamMeta #-}

-- | Reads one BAM record: a 32-bit length, then that many bytes, which are
-- handed to 'bamRaw' together with the offset @o@.  If fewer bytes than
-- announced are available, the parse is aborted.
getBamRaw :: Monad m => Int64 -> S.Parser r m BamRaw
getBamRaw o = do
    n       <- fromIntegral <$> S.getWord32
    payload <- S.getString n
    unless (n == B.length payload) S.abortParse
    bamRaw o payload
{-# INLINABLE getBamRaw #-}

{- | Streaming parser for SAM files.

It parses plain uncompressed SAM and returns a result compatible with
'decodePlainBam'.  Since it is supposed to work the same way as the BAM
parser, it requires a symbol table for the reference names.  This is
extracted from the \@SQ lines in the header; names not found there map
to 'invalidRefseq'.  Lines that fail to parse are logged as errors and
skipped.  Note that reading SAM tends to be inefficient; if you care
about performance at all, use BAM.  -}
decodePlainSam :: (MonadLog m, MonadIO m) => S.ByteStream m r -> m (BamMeta, Stream (Of BamRaw) m r)
decodePlainSam s = do
    (hdr, rest) <- either noHeader id <$> S.parseIO (const $ S.atto parseBamMeta) s
    let names      = [ sq_name sq | sq <- V.toList (unRefs (meta_refs hdr)) ]
        !refs      = M.fromList (zip names [toEnum 0 ..])
        refOf nm   = M.lookupDefault invalidRefseq nm refs
        convert ln = getSamRec refOf ln >>= either discard keep
        discard e  = Nothing <$ logMsg Error e
        keep br    = Just <$> liftIO (packBam br)
    return (hdr, Q.concat (Q.mapM convert (S.lines' rest)))
  where
    -- 'S.parseIO' gave back only the stream's result (presumably the input
    -- ended before a header was parsed): no header and no records.
    noHeader r = (mempty, pure r)


-- | Parses one line of SAM text into a 'BamRec'.
--
-- The first argument maps reference names (RNAME and RNEXT) to 'Refseq's.
-- If the parser rejects the line, the result is 'Left' with an
-- 'S.ParseError' whose context is the unpacked line.  If a quality string
-- is present but its length differs from the sequence length, a
-- 'LengthMismatch' warning is logged and the qualities are dropped.
getSamRec :: MonadLog m => (Bytes -> Refseq) -> Bytes -> m (Either S.ParseError BamRec)
getSamRec ref s = case P.parseOnly record s of
    Left  e                                         -> pure . Left $ S.ParseError [unpack s] e
    Right b -> case b_qual b of
        Nothing                                     -> pure $ Right b
        Just qs | W.length qs == V.length (b_seq b) -> pure $ Right b
                | otherwise                         -> do logMsg Warning $ LengthMismatch (b_qname b)
                                                          pure . Right $ b { b_qual = Nothing }
  where
    -- The eleven mandatory columns in order, then the optional fields.
    -- 'P.parseOnly' does not demand end of input, so any text after the
    -- last optional field that parses is silently ignored.
    record = do b_qname <- word
                b_flag  <- num
                b_rname <- ref <$> word
                b_pos   <- subtract 1 <$> num          -- presumably 1-based SAM POS to 0-based
                b_mapq  <- Q <$> num'
                b_cigar <- W.fromList <$> cigar
                b_mrnm  <- rnext <*> pure b_rname      -- "=" means "same as RNAME"
                b_mpos  <- subtract 1 <$> num          -- shifted like b_pos
                b_isize <- snum
                b_seq   <- sequ
                b_qual  <- quals
                b_exts  <- exts
                let b_virtual_offset = 0               -- no file offset is tracked for SAM text
                return BamRec{..}

    -- A field ends at a tab or at the end of the line.
    sep      = P.endOfInput <|> () <$ P.char '\t'
    word     = P.takeTill ('\t' ==) <* sep
    -- 'num' and 'num'' are identical; presumably duplicated because these
    -- argument-less bindings are monomorphic, and they are needed at two
    -- different result types (plain numbers vs. the 'Q'-wrapped MAPQ).
    num      = P.decimal <* sep
    num'     = P.decimal <* sep
    snum     = P.signed P.decimal <* sep

    -- RNEXT: "=" reuses RNAME (applied in 'record'), otherwise a name lookup.
    rnext    = id <$ P.char '=' <* sep <|> const . ref <$> word
    -- SEQ: "*" is empty; otherwise only letters in 'is_nuc' are accepted.
    sequ     = (V.empty <$ P.char '*' <|>
               V.fromList . map toNucleotides . B.unpack <$> P.takeWhile is_nuc) <* sep

    -- QUAL: "*" means absent; otherwise each byte minus 33.
    quals    = Nothing <$ P.char '*' <* sep <|> bsToVec <$> word
        where
            bsToVec = Just . W.fromList . map (Q . subtract 33) . B.unpack

    -- CIGAR: "*" is empty; otherwise (length, op) pairs up to the separator.
    cigar    = [] <$ P.char '*' <* sep <|>
               P.manyTill (flip (:*) <$> P.decimal <*> cigop) sep

    -- Only these seven operators are recognized; a CIGAR containing '=' or
    -- 'X' makes the whole line fail to parse.
    cigop    = P.choice $ zipWith (\c r -> r <$ P.char c) "MIDNSHP" [Mat,Ins,Del,Nop,SMa,HMa,Pad]
    -- Optional fields, "TT:t:value", separated like the mandatory ones.
    -- 'P.sepBy' stops at the first field that does not parse.
    exts     = ext `P.sepBy` sep
    ext      = (\a b v -> (fromString [a,b],v)) <$> P.anyChar <*> P.anyChar <*> (P.char ':' *> value)

    -- Value by type letter: A (single byte), i, Z, H, f, and B arrays.
    value    = P.char 'A' *> P.char ':' *> (Char <$>               anyWord8) <|>
               P.char 'i' *> P.char ':' *> (Int  <$>     P.signed P.decimal) <|>
               P.char 'Z' *> P.char ':' *> (Text <$>   P.takeTill ('\t' ==)) <|>
               P.char 'H' *> P.char ':' *> (Bin  <$>               hexarray) <|>
               P.char 'f' *> P.char ':' *> (Float . realToFrac <$> P.double) <|>
               P.char 'B' *> P.char ':' *> (
                    P.satisfy (P.inClass "cCsSiI") *> (intArr   <$> many (P.char ',' *> P.signed P.decimal)) <|>
                    P.char 'f'                     *> (floatArr <$> many (P.char ',' *> P.double)))

    intArr   is = IntArr   $ U.fromList is
    floatArr fs = FloatArr $ U.fromList $ map realToFrac fs
    -- Hex digits are paired into bytes; an odd trailing digit is dropped.
    hexarray    = B.pack . repack . C.unpack <$> P.takeWhile (P.inClass "0-9A-Fa-f")
    repack (a:b:cs) = fromIntegral (digitToInt a * 16 + digitToInt b) : repack cs ; repack _ = []
    -- IUPAC nucleotide letters in both cases; '=' and '.' are not included.
    is_nuc = P.inClass "acgtswkmrybdhvnACGTSWKMRYBDHVN"


-- | Thrown by 'guardRefCompat' when the reference list of one input is not
-- a prefix of another's.  Carries the two file names involved.
data IncompatibleRefs = IncompatibleRefs FilePath FilePath
    deriving (Typeable, Show)

instance Exception IncompatibleRefs where
    displayException (IncompatibleRefs a b) =
        concat [ "references in ", a, " and ", b, " are incompatible" ]

-- | Succeeds if the reference list of the second header is a prefix of the
-- first header's reference list.  Otherwise it throws
-- 'IncompatibleRefs' naming both files.
guardRefCompat :: MonadThrow m => (FilePath,BamMeta) -> (FilePath,BamMeta) -> m ()
guardRefCompat (f0,hdr0) (f1,hdr1)
    | refsOf hdr1 `isPrefixOf` refsOf hdr0 = return ()
    | otherwise                            = throwM (IncompatibleRefs f0 f1)
  where
    refsOf = V.toList . unRefs . meta_refs


{- | Reads multiple bam inputs in sequence.

Only one file is opened at a time, so they must also be consumed in
sequence.  If you can afford to open all inputs simultaneously, you
probably want to use 'mergeInputsOn' instead.  The filename \"-\" refers
to stdin; if no filenames are given, stdin is read.  Since we can't look
ahead into further files, the header of the first input is used for the
result, and an 'IncompatibleRefs' exception is thrown if one of the
subsequent headers is incompatible with the first one.
-}
concatInputs :: (MonadIO m, MonadLog m, MonadMask m) => [FilePath] -> (BamMeta -> Stream (Of BamRaw) m () -> m r) -> m r
concatInputs fs0 k = streamInputs fs0 (goFirst fs0)
  where
    -- Splits off the name of the current input.  Once the given names run
    -- out (in particular when none were given and stdin is read), inputs
    -- are reported as "-".
    nextName (f:fs) = (f, fs)
    nextName []     = ("-", [])

    -- The first input supplies the header handed to the continuation.
    goFirst names = inspect >=> \case
        Left () -> k mempty (pure ())
        Right s -> do
            let (f0, later) = nextName names
            (hdr0, recs) <- decodeBam s
            k hdr0 (recs >>= goRest f0 hdr0 later)

    -- Each further input is checked against the first header, then its
    -- records are appended to the output.
    goRest f0 hdr0 names = lift . inspect >=> \case
        Left () -> pure ()
        Right s -> do
            let (f1, later) = nextName names
            (hdr1, recs) <- lift (decodeBam s)
            lift (guardRefCompat (f0, hdr0) (f1, hdr1))
            recs >>= goRest f0 hdr0 later
{-# INLINABLE concatInputs #-}

{- | Reads multiple bam files and merges them.

If the inputs are all sorted by the thing being merged on, the output
will be sorted, too.  The headers are all merged sensibly, even if their
reference lists differ.  However, for performance reasons, we don't want
to change the rname and mrnm fields in potentially all records.  So
instead of allowing arbitrary reference lists to be merged, we throw an
exception unless every input is compatible with the effective reference
list.  If no files are given, stdin is read.
-}
mergeInputsOn :: (Ord x, MonadIO m, MonadLog m, MonadMask m)
              => (BamRaw -> x) -> [FilePath]
              -> (BamMeta -> Stream (Of BamRaw) m () -> m r) -> m r
mergeInputsOn _ [] k = decodeBam (streamHandle stdin) >>= uncurry k
mergeInputsOn p fs k = decodeBamFiles fs $ \inputs -> do
    let hdr = foldMap fst inputs
    -- Every input's reference list must be a prefix of the merged one.
    sequence_ [ guardRefCompat ("*", hdr) (f, h) | (f, (h, _)) <- zip fs inputs ]
    k hdr (foldr (mergeInto . snd) (pure ()) inputs)
  where
    mergeInto recs acc = void (mergeStreamsOn p recs acc)
{-# INLINABLE mergeInputsOn #-}

-- | Reference id and position of a record, read from its unpacked form.
coordinates :: BamRaw -> (Refseq, Int)
coordinates raw = let b = unpackBam raw in (b_rname b, b_pos b)
{-# INLINE coordinates #-}

-- | Read name of a record, read from its unpacked form.
qnames :: BamRaw -> Bytes
qnames raw = b_qname (unpackBam raw)
{-# INLINE qnames #-}