{-# LANGUAGE UndecidableInstances #-}
module Bio.Bam.Header (
        BamMeta(..),
        parseBamMeta,
        showBamMeta,
        addPG,

        BamKey(..),
        BamHeader(..),
        BamSQ(..),
        BamSorting(..),
        BamOtherShit,

        Refseq(..),
        invalidRefseq,
        isValidRefseq,
        invalidPos,
        isValidPos,
        unknownMapq,
        isKnownMapq,

        Refs(..),
        getRef,

        compareNames,

        flagPaired,
        flagProperlyPaired,
        flagUnmapped,
        flagMateUnmapped,
        flagReversed,
        flagMateReversed,
        flagFirstMate,
        flagSecondMate,
        flagAuxillary,
        flagSecondary,
        flagFailsQC,
        flagDuplicate,
        flagSupplementary,
        eflagTrimmed,
        eflagMerged,
        eflagAlternative,
        eflagExactIndex,

        distinctBin,

        MdOp(..),
        readMd,
        showMd
    ) where

import Bio.Prelude           hiding ( uncons )
import Bio.Util.Nub
import Control.Monad.Trans.RWS
import Data.ByteString              ( uncons )
import Data.ByteString.Builder      ( Builder, byteString, char7, intDec, word16LE )

import qualified Data.Attoparsec.ByteString.Char8   as P
import qualified Data.ByteString                    as B
import qualified Data.ByteString.Char8              as S
import qualified Data.HashMap.Strict                as H
import qualified Data.Vector                        as V

-- | Everything found in the header of a BAM file.
data BamMeta = BamMeta
    { meta_hdr        :: !BamHeader                 -- ^ contents of the \@HD line
    , meta_refs       :: !Refs                      -- ^ reference sequences (\@SQ lines)
    , meta_pgs        :: [Fix BamPG]                -- ^ program chains (\@PG lines)
    , meta_other_shit :: [(BamKey, BamOtherShit)]   -- ^ header lines of any other type
    , meta_comment    :: [Bytes]                    -- ^ comments (\@CO lines)
    } deriving ( Show, Generic )

-- | Exactly two characters, for the \"named\" fields in bam.  The first
-- character occupies the low byte, the second one the high byte.
newtype BamKey = BamKey Word16
    deriving ( Eq, Ord, Hashable, Generic )

instance IsString BamKey where
    {-# INLINE fromString #-}
    fromString s = case s of
        [a,b] | ord a < 256 && ord b < 256
              -> BamKey . fromIntegral $ ord a + 256 * ord b
        _     -> error $ "Not a legal BAM key: " ++ show s

instance Show BamKey where
    show (BamKey w) = map (chr . fromIntegral) [ w .&. 0xff, shiftR w 8 .&. 0xff ]

-- | Adds a new program line to a header.  The new entry is
-- (arbitrarily) prepended to the first existing chain, or forms a new
-- singleton chain if none exists.  Its ID is the program name, and it
-- carries the tags PN (program name), CL (the command line arguments,
-- joined by spaces) and, if a version is given, VN.
addPG :: MonadIO m => Maybe Version -> m (BamMeta -> BamMeta)
addPG vn = liftIO $ do
    args <- getArgs
    name <- S.pack <$> getProgName

    let tags = [ ("PN", name), ("CL", S.pack (unwords args)) ]
               ++ [ ("VN", S.pack (showVersion v)) | Just v <- [vn] ]

        new_pg prev = Fix (BamPG name prev tags)

        prepend []       = [ new_pg Nothing ]
        prepend (pg:pgs) = new_pg (Just pg) : pgs

    return $ \bm -> bm { meta_pgs = prepend (meta_pgs bm) }


-- | Combining headers is 'combineBamMeta'.
instance Semigroup BamMeta where
    (<>) = combineBamMeta

-- | The empty header: default \@HD line and no other lines.
instance Monoid BamMeta where
    mempty  = BamMeta mempty mempty [] [] []
    mappend = (<>)

{- | Combines two bam headers into one.

The overarching goal is to combine headers in such a way that no
information is lost, but redundant information is removed.  In
particular, we sometimes \"merge\" headers with the same references, at
other times we \"meld\" headers with entirely different references.  In
the former case, we want to keep the reference list as is, in the
latter case we must concatenate the reference lists.  Appending the
lists and running the result through 'nub' serves both purposes.

* If both headers have a version number, the result is the larger of
  the two.  (The version of 'mempty' is 1.0, so 'mempty' stays neutral
  for every header with version 1.0 or later.)

* The resulting sort order is the most specific one compatible with both
  input sort orders.  The stupid 'Unknown' state is compatible with
  everything.

* Reference sequences are appended and run through 'nub'.  The numbering
  of reference may thus change, which has to be dealt with in an
  appropriate way, see 'concatInputs', 'mergeInputsOn', and \"bam-meld\"
  for details.  (It is also possible that different sequences are left
  with the same name.  We cannot solve this right here, and there is no
  reliable way to do it in general.)

* Comments and header lines of other types are appended and run
  through 'nub'.  This should work in most cases, and if it doesn't,
  someone needs to \"samtools reheader\" the file anyway.

* Program chains are just collected, but when formatting, they are
  (effectively) run through 'nub' and are potentially assigned new
  unique identifiers.
-}
combineBamMeta :: BamMeta -> BamMeta -> BamMeta
combineBamMeta a b = BamMeta
    { meta_hdr        = meta_hdr a <> meta_hdr b
    , meta_refs       = meta_refs a <> meta_refs b
    , meta_pgs        = meta_pgs a <> meta_pgs b
    , meta_other_shit = nubHash $ meta_other_shit a ++ meta_other_shit b
    , meta_comment    = nubHash $ meta_comment a ++ meta_comment b }

-- | Contents of the \@HD line: format version, declared sort order and
-- any further tags.
data BamHeader = BamHeader
    { hdr_version    :: (Int, Int)
    , hdr_sorting    :: BamSorting
    , hdr_other_shit :: BamOtherShit
    } deriving (Show, Eq)

-- | The empty \@HD line claims version 1.0 and an unknown sort order.
instance Monoid BamHeader where
    mappend = (<>)
    mempty  = BamHeader (1,0) Unknown []

-- | Keeps the larger version, combines sort orders, and keeps only the
-- first occurrence of each extra tag key.
instance Semigroup BamHeader where
    x <> y = BamHeader
        { hdr_version    = hdr_version x `max` hdr_version y
        , hdr_sorting    = hdr_sorting x <> hdr_sorting y
        , hdr_other_shit = nubHashBy fst (hdr_other_shit x ++ hdr_other_shit y) }

-- | A reference sequence, as described by an \@SQ line.
data BamSQ = BamSQ
    { sq_name       :: Bytes          -- ^ the SN field
    , sq_length     :: Int            -- ^ the LN field
    , sq_other_shit :: BamOtherShit   -- ^ all remaining tags
    } deriving (Show, Eq, Generic)

instance Hashable BamSQ

-- | One \@PG line.  The parameter is the representation of the
-- predecessor: its ID while parsing ('BamPG' 'Bytes'), the resolved
-- predecessor itself afterwards ('Fix' 'BamPG').
data BamPG pp = BamPG
    { pg_pref_name  :: Bytes          -- ^ preferred ID (the ID field)
    , pg_prev_pg    :: Maybe pp       -- ^ predecessor (the PP field), if any
    , pg_other_shit :: BamOtherShit   -- ^ all remaining tags
    } deriving (Show, Eq, Generic1)

-- | Fixed point of a type constructor; ties a 'BamPG' to its predecessors.
newtype Fix f = Fix (f (Fix f))

instance Eq (f (Fix f)) => Eq (Fix f) where
    (==) (Fix x) (Fix y) = x == y

instance Show (f (Fix f)) => Show (Fix f) where
    showsPrec d (Fix x) = showsPrec d x

-- | Hashes the name, then the predecessor (if any), then the other tags.
instance Hashable (Fix BamPG) where
    hashWithSalt salt (Fix (BamPG n mp o)) =
        let s1 = hashWithSalt salt n
            s2 = maybe s1 (hashWithSalt s1) mp
        in  hashWithSalt s2 o


-- | Possible sorting orders from bam header.  Thanks to samtools, which
-- doesn't declare sorted files properly, we have to have the stupid
-- 'Unknown' state, too.
data BamSorting = Unknown       -- ^ undeclared sort order
                | Unsorted      -- ^ definitely not sorted
                | Grouped       -- ^ grouped by query name
                | Queryname     -- ^ sorted by query name
                | Coordinate    -- ^ sorted by coordinate
    deriving (Show, Eq)

-- | Yields the most specific order compatible with both arguments.
-- 'Unknown' is neutral; query-name sorted input is also grouped.
instance Semigroup BamSorting where
    x <> y = case (x, y) of
        (Unknown, _)             -> y
        (_, Unknown)             -> x
        (Queryname, Queryname)   -> Queryname
        (Coordinate, Coordinate) -> Coordinate
        _ | grouped x && grouped y -> Grouped
          | otherwise              -> Unsorted
      where
        grouped s = s == Grouped || s == Queryname


-- | Tags of a header line that are not interpreted, as key\/value pairs.
type BamOtherShit = [ (BamKey, Bytes) ]

-- | Parses a complete textual bam header.  Every line must be terminated
-- by a newline (trailing tabs are tolerated), and the whole input must
-- be consumed.
parseBamMeta :: P.Parser BamMeta
parseBamMeta = do
    updates <- many (parseBamMetaLine <* P.skipWhile (== '\t') <* P.char '\n')
    P.endOfInput
    pure . fixupMeta $ foldl' (\acc upd -> upd acc) emptyHeader updates

-- Bam header in the process of being parsed.  Better suited for
-- collecting lines than 'BamMeta':  lists are built in reverse order and
-- program lines are indexed by their ID.
data PreBamMeta = PreBamMeta
    { pmeta_hdr        :: BamHeader
    , pmeta_refs       :: [BamSQ]
    , pmeta_pgs        :: HashMap Bytes (BamPG Bytes)
    , pmeta_other_shit :: [(BamKey, BamOtherShit)]
    , pmeta_comment    :: [Bytes] }

-- The state before any header line has been seen.
emptyHeader :: PreBamMeta
emptyHeader = PreBamMeta
    { pmeta_hdr        = mempty
    , pmeta_refs       = []
    , pmeta_pgs        = H.empty
    , pmeta_other_shit = []
    , pmeta_comment    = [] }


-- | Fixes a bam header after parsing.  It restores the input order of
-- the accumulated lists, turns the references into a vector, and it
-- handles the program (PG) lines.  Program lines come in as an
-- arbitrary graph.  It should be a linear chain, but this isn't
-- guaranteed in practice.  We decompose the graph into chains by
-- tracing backwards (along the PP links) from nodes that are nobody's
-- predecessor, or from an arbitrary node if every node is somebody's
-- predecessor (i.e. there is a cycle).  Tracing stops if it would form
-- a cycle.  Only the last node of each chain is recorded in
-- 'meta_pgs'; its predecessors are reachable through 'pg_prev_pg'.
-- This way, the head of 'meta_pgs' is the end of a chain, which is what
-- 'addPG' attaches a new program to.
fixupMeta :: PreBamMeta -> BamMeta
fixupMeta PreBamMeta{..} = BamMeta
    { meta_hdr        = pmeta_hdr
    , meta_refs       = Refs . V.fromList . reverse $ pmeta_refs
    , meta_pgs        = snd $ evalRWS trace_pgs () pmeta_pgs
    , meta_other_shit = reverse pmeta_other_shit
    , meta_comment    = reverse pmeta_comment  }
  where
    -- keep tracing from chain ends until no nodes are left
    trace_pgs :: RWS () [Fix BamPG] (HashMap Bytes (BamPG Bytes)) ()
    trace_pgs = do
        gg <- get
        -- remaining nodes that no remaining node names as its predecessor
        let ends = foldl' (flip H.delete) gg
                          [ pp | p <- H.elems gg
                               , pp <- maybe [] pure (pg_prev_pg p) ]
        case (H.keys ends, H.keys gg) of
            -- the empty graph has no chains:
            (_, [])   -> return ()
            -- every node is somebody's predecessor, so there is a cycle;
            -- an arbitrary node is picked as the end of a chain:
            ([], k:_) -> emit_chain k >> trace_pgs
            -- nodes without a successor end chains:
            (ks, _)   -> mapM_ emit_chain ks >> trace_pgs

    -- Trace the chain ending in the named node and record only its end.
    emit_chain :: Bytes -> RWS () [Fix BamPG] (HashMap Bytes x) ()
    emit_chain k = trace_pg H.empty k >>= maybe (return ()) (\pg -> tell [pg])

    -- Trace one PG line.  Do not trace into nodes in the 'closed' set,
    -- remove reached nodes from the 'open' set (the state) and add them
    -- to the 'closed' set.  Unknown predecessor IDs are dropped.
    trace_pg :: HashMap Bytes () -> Bytes -> RWS () [Fix BamPG] (HashMap Bytes x) (Maybe (Fix BamPG))
    trace_pg closed name =
        case H.lookup name pmeta_pgs of
            _ | H.member name closed -> return Nothing
            Nothing                  -> return Nothing
            Just pg -> do
                modify $ H.delete name
                pp <- mapM (trace_pg (H.insert name () closed)) (pg_prev_pg pg)
                return . Just . Fix $ pg { pg_prev_pg = join pp }


-- | Parses one header line, from the leading \@ up to (but not
-- including) the terminating newline, into a function that records the
-- line in a 'PreBamMeta'.  HD, SQ, PG and CO lines are interpreted; any
-- other line type is kept as its key and a list of tags.
parseBamMetaLine :: P.Parser (PreBamMeta -> PreBamMeta)
parseBamMetaLine = P.char '@' >> P.choice [hdLine, sqLine, pgLine, coLine, otherLine]
  where
    hdLine = P.string "HD\t" >>
             (\fns meta -> meta { pmeta_hdr = foldr ($) (pmeta_hdr meta) fns })
               <$> P.sepBy1 (P.choice [hdvn, hdso, hdother]) tabs

    -- SN and LN are mandatory; if a guard fails, the line is tried as
    -- an 'otherLine' next.
    sqLine = do _ <- P.string "SQ\t"
                fns <- P.sepBy1 (P.choice [sqnm, sqln, sqother]) tabs
                let sq = foldr ($) (BamSQ "" (-1) []) fns
                guard (not . B.null $ sq_name sq) P.<?> "SQ:NM field"
                guard (sq_length sq >= 0) P.<?> "SQ:LN field"
                pure $ \meta -> meta { pmeta_refs = sq : pmeta_refs meta }

    -- ID is mandatory; program lines are indexed by it.
    pgLine = do _ <- P.string "PG\t"
                fns <- P.sepBy1 (P.choice [pgid, pgpp, pgother]) tabs
                let pg = foldr ($) (BamPG "" Nothing []) fns
                guard (not . B.null $ pg_pref_name pg) P.<?> "PG:ID field"
                pure $ \meta -> meta { pmeta_pgs = H.insert (pg_pref_name pg) pg (pmeta_pgs meta) }

    hdvn = P.string "VN:" >>
           (\a b hdr -> hdr { hdr_version = (a,b) })
             <$> P.decimal <*> ((P.char '.' <|> P.char ':') >> P.decimal)

    -- unrecognized sort orders are treated as 'Unknown'
    hdso = P.string "SO:" >>
           (\s hdr -> hdr { hdr_sorting = s })
             <$> P.choice [ Grouped     <$ P.string "grouped"
                          , Queryname   <$ P.string "queryname"
                          , Coordinate  <$ P.string "coordinate"
                          , Unsorted    <$ P.string "unsorted"
                          , Unknown     <$ P.skipWhile (\c -> c/='\t' && c/='\n') ]

    sqnm = P.string "SN:" >> (\s sq -> sq { sq_name = s }) <$> pall
    sqln = P.string "LN:" >> (\i sq -> sq { sq_length = i }) <$> P.decimal

    pgid = P.string "ID:" >> (\s pg -> pg { pg_pref_name =      s }) <$> pall
    pgpp = P.string "PP:" >> (\s pg -> pg { pg_prev_pg   = Just s }) <$> pall

    hdother = (\t hdr -> hdr { hdr_other_shit = t : hdr_other_shit hdr }) <$> tagother
    sqother = (\t sq  -> sq  { sq_other_shit  = t : sq_other_shit  sq  }) <$> tagother
    pgother = (\t p   -> p   { pg_other_shit  = t : pg_other_shit  p   }) <$> tagother

    -- A comment extends to the end of the line (tabs included); the
    -- newline itself is consumed by 'parseBamMeta'.
    coLine = P.string "CO\t" >>
             (\s meta -> s `seq` meta { pmeta_comment = s : pmeta_comment meta })
               <$> P.takeWhile (/= '\n')

    otherLine = (\k ts meta -> meta { pmeta_other_shit = (k,ts) : pmeta_other_shit meta })
                  <$> bamkey <*> (tabs >> P.sepBy1 tagother tabs)

    tagother :: P.Parser (BamKey,Bytes)
    tagother = (,) <$> bamkey <*> (P.char ':' >> pall)

    tabs = P.char '\t' >> P.skipWhile (== '\t')

    pall :: P.Parser Bytes
    pall = P.takeWhile (\c -> c/='\t' && c/='\n')

    bamkey :: P.Parser BamKey
    bamkey = (\a b -> fromString [a,b]) <$> P.anyChar <*> P.anyChar


-- | Creates the textual form of Bam meta data.
--
-- Lines are emitted in this order: the \@HD line, one \@SQ line per
-- reference, the \@PG lines, header lines of other types, and finally
-- the \@CO comments.  Every line is terminated by a newline.
--
-- Formatting is straightforward, only program lines are a bit
-- involved.  Our multiple chains may lead to common nodes, and we do
-- not want to print multiple identical lines.  At the same time, we may
-- need to print multiple different lines that carry the same id.  The
-- solution is to memoize printed lines, and to reuse their identity if
-- an identical line is needed.  When printing a line, it gets its
-- preferred identifier, but if it's already taken, a new identifier is
-- made up by first removing any trailing number and then by appending
-- numeric suffixes.

showBamMeta :: BamMeta -> Builder
showBamMeta (BamMeta h (Refs ss) pgs os cs) =
    show_bam_meta_hdr h <>
    foldMap show_bam_meta_seq ss <>
    show_bam_pgs <>
    foldMap show_bam_meta_other os <>
    foldMap show_bam_meta_comment cs
  where
    -- The SO tag is left out entirely for the 'Unknown' sort order.
    show_bam_meta_hdr (BamHeader (major,minor) so os') =
        "@HD\tVN:" <>
        intDec major <> char7 '.' <> intDec minor <>
        byteString (case so of Unsorted    -> "\tSO:unsorted"
                               Grouped     -> "\tSO:grouped"
                               Queryname   -> "\tSO:queryname"
                               Coordinate  -> "\tSO:coordinate"
                               Unknown     -> mempty) <>
        show_bam_others os'

    show_bam_meta_seq (BamSQ nm ln ts) =
        byteString "@SQ\tSN:" <> byteString nm <>
        byteString "\tLN:" <> intDec ln <> show_bam_others ts

    show_bam_meta_comment cm = byteString "@CO\t" <> byteString cm <> char7 '\n'

    -- word16LE writes the low byte (the key's first character) first.
    show_bam_meta_other (BamKey k,ts) =
        char7 '@' <> word16LE k <> show_bam_others ts

    -- Appends the uninterpreted tags and terminates the line.
    show_bam_others ts =
        foldMap show_bam_other ts <> char7 '\n'

    show_bam_other (BamKey k,v) =
        char7 '\t' <> word16LE k <> char7 ':' <> byteString v

    -- State: (memo from already printed PG entries to the ID they were
    -- given, set of IDs already in use).  The writer collects the output.
    show_bam_pgs = snd $ evalRWS (mapM_ show_bam_pg pgs) () (H.empty, H.empty)

    -- Prints a PG entry after its predecessor (so the predecessor's ID is
    -- known for the PP tag) and returns the ID it was printed under.  An
    -- entry that was printed before is not printed again; its ID is reused.
    show_bam_pg p@(Fix (BamPG pn pp po)) = do
        ppid <- case pp of Nothing -> return Nothing
                           Just p' -> Just <$> show_bam_pg p'

        gets (H.lookup p . fst) >>= \case
            Just pid -> return pid
            Nothing  -> do
                -- preferred name without a trailing dash-and-number
                let pn' = case dropWhile isDigit . reverse $ S.unpack pn of
                            '-':xs -> reverse xs
                            _      -> S.unpack pn

                -- find unused preferable PG:ID:  try preferred name,
                -- preferred name without number, preferred name
                -- without number and increasing numbers attached
                -- (the candidate list is infinite, so 'head' succeeds)
                pid <- gets $ \(_,hs) ->
                            head . filter (not . flip H.member hs) $
                            pn : S.pack pn' : [ S.pack $ pn' ++ '-' : (show i) | i <- [2::Int ..] ]

                modify . first $ H.insert p pid
                modify . second $ H.insert pid ()

                tell $ byteString "@PG\tID:" <> byteString pid <>
                       maybe mempty (\x -> byteString "\tPP:" <> byteString x) ppid <>
                       show_bam_others po
                return pid


-- | Reference sequence in Bam
-- Bam enumerates the reference sequences and then sorts by index.  We
-- need to track that index if we want to reproduce the sorting order.
newtype Refseq = Refseq { unRefseq :: Word32 } deriving (Eq, Ord, Ix, Bounded, Hashable)

-- | Shows just the underlying index.
instance Show Refseq where
    showsPrec d = showsPrec d . unRefseq

-- | Enumeration follows the underlying index.
instance Enum Refseq where
    succ (Refseq r) = Refseq (succ r)
    pred (Refseq r) = Refseq (pred r)
    toEnum i = Refseq (fromIntegral i)
    fromEnum (Refseq r) = fromIntegral r
    enumFrom (Refseq a) = Refseq <$> enumFrom a
    enumFromThen (Refseq a) (Refseq b) = Refseq <$> enumFromThen a b
    enumFromTo (Refseq a) (Refseq b) = Refseq <$> enumFromTo a b
    enumFromThenTo (Refseq a) (Refseq b) (Refseq c) = Refseq <$> enumFromThenTo a b c


-- | Tests whether a reference sequence is valid, i.e. whether it is
-- anything other than 'invalidRefseq'.
isValidRefseq :: Refseq -> Bool
isValidRefseq r = r /= invalidRefseq

-- | The invalid Refseq, the index with all bits set.
-- Bam uses this value to encode a missing reference sequence.
invalidRefseq :: Refseq
invalidRefseq = Refseq maxBound

-- | The invalid position.
-- Bam uses this value to encode a missing position.
{-# INLINE invalidPos #-}
invalidPos :: Int
invalidPos = negate 1

-- | Tests whether a position is valid, i.e. anything other than
-- 'invalidPos'.
{-# INLINE isValidPos #-}
isValidPos :: Int -> Bool
isValidPos p = p /= invalidPos

-- | The mapping quality value that marks the quality as unknown.
{-# INLINE unknownMapq #-}
unknownMapq :: Int
unknownMapq = 255

-- | Tests whether a mapping quality is anything other than 'unknownMapq'.
isKnownMapq :: Int -> Bool
isKnownMapq q = q /= unknownMapq

-- | A list of reference sequences.
newtype Refs = Refs { unRefs :: V.Vector BamSQ } deriving Show

-- | The empty list of references.
instance Monoid Refs where
    mappend = (<>)
    mempty  = Refs V.empty

-- | Concatenates two lists of references, keeping only the first
-- occurrence of each duplicate.
instance Semigroup Refs where
    Refs a <> Refs b = Refs . V.fromList . nubHash . V.toList $ a V.++ b

-- | Looks up a reference sequence by its index.  Indices outside the list
-- (which in practice includes 'invalidRefseq') yield a placeholder named
-- @*@ with length 0.
getRef :: Refs -> Refseq -> BamSQ
getRef (Refs refs) (Refseq i) =
    case refs V.!? fromIntegral i of
        Just sq -> sq
        Nothing -> BamSQ "*" 0 []

-- | Bits of the FLAG field of a BAM record.
flagPaired, flagProperlyPaired, flagUnmapped, flagMateUnmapped,
 flagReversed, flagMateReversed, flagFirstMate, flagSecondMate,
 flagAuxillary, flagSecondary, flagFailsQC, flagDuplicate,
 flagSupplementary :: Int

flagPaired         = 0x001
flagProperlyPaired = 0x002
flagUnmapped       = 0x004
flagMateUnmapped   = 0x008
flagReversed       = 0x010
flagMateReversed   = 0x020
flagFirstMate      = 0x040
flagSecondMate     = 0x080
flagSecondary      = 0x100
flagFailsQC        = 0x200
flagDuplicate      = 0x400
flagSupplementary  = 0x800

-- Another name for the same bit as 'flagSecondary' (0x100).
flagAuxillary      = flagSecondary

-- Extended flag bits.  They reuse the values 0x1..0x8 of the standard
-- flags above, so they presumably live in a separate field.
-- NOTE(review): confirm where these bits are stored.
eflagTrimmed, eflagMerged, eflagAlternative, eflagExactIndex :: Int
eflagTrimmed     = 0x1
eflagMerged      = 0x2
eflagAlternative = 0x4
eflagExactIndex  = 0x8


-- | Compares two sequence names the way samtools does.
-- samtools sorts by \"strnum_cmp\":
--
-- * if both strings start with a digit, parse the initial
--   sequence of digits and compare numerically, if equal,
--   continue behind the numbers
-- * else compare the first characters (possibly NUL), if equal
--   continue behind them
-- * else both strings ended and the shorter one counts as
--   smaller (and that part is stupid)
--
-- Numbers are only parsed when /both/ names start with a digit, so a
-- leading sign character is compared like any other character.  Digit
-- runs with equal values (e.g. @01@ and @1@) compare equal here, and
-- comparison continues behind them.
compareNames :: Bytes -> Bytes -> Ordering
compareNames n m = case (uncons n, uncons m) of
        ( Nothing, Nothing ) -> EQ
        ( Just  _, Nothing ) -> GT
        ( Nothing, Just  _ ) -> LT
        ( Just (c,n'), Just (d,m') )
            | is_digit c && is_digit d
            , Just (u,n'') <- S.readInt n
            , Just (v,m'') <- S.readInt m ->
                -- Ordering's (<>) only looks at the rest if this is EQ
                compare u v <> compareNames n'' m''
            | otherwise ->
                compare c d <> compareNames n' m'
  where
    is_digit c = c2w '0' <= c && c <= c2w '9'


-- | One element of an MD string, as produced by 'readMd'.
data MdOp
    = MdNum Int              -- ^ a run of digits, i.e. a count
    | MdRep Nucleotides      -- ^ a single character that is neither a digit nor @^@
    | MdDel [Nucleotides]    -- ^ the characters after a @^@, up to the next digit
    deriving Show

-- | Parses an MD string into a list of 'MdOp's: digit runs become
-- 'MdNum', @^@ followed by non-digits becomes 'MdDel', and any other
-- single character becomes 'MdRep'.
readMd :: Bytes -> Maybe [MdOp]
readMd s = case S.uncons s of
    Nothing -> Just []
    Just (c, rest)
        | isDigit c -> do (n, t) <- S.readInt s
                          (MdNum n :) <$> readMd t
        | c == '^'  -> let (deleted, more) = S.break isDigit rest
                       in  (MdDel (map toNucleotides (B.unpack deleted)) :) <$> readMd more
        | otherwise -> (MdRep (toNucleotides (B.head s)) :) <$> readMd rest

-- | Normalizes a series of 'MdOp's and encodes them in the way BAM and
-- SAM expect it.
--
-- Normalization merges adjacent counts and adjacent deletions, and
-- drops zero counts and empty deletions.  A deletion directly followed
-- by a replacement gets an explicit @0@ in between, so that the
-- replacement is not read back as part of the deletion.
--
-- NOTE(review): replacements are not separated from each other (or
-- from the start of the string) by a @0@.  For example, a replacement,
-- a zero count and another replacement encode with no digit between
-- the two replacements.  I believe the SAM specification's MD grammar
-- requires a number there; confirm whether callers rely on this before
-- changing it.
showMd :: [MdOp] -> Bytes
showMd = S.pack . flip s1 []
  where
    -- counts: merge neighbours, drop zeros
    s1 (MdNum  i : MdNum  j : ms) = s1 (MdNum (i+j) : ms)
    s1 (MdNum  0            : ms) = s1 ms
    s1 (MdNum  i            : ms) = shows i . s1 ms

    -- replacements: printed as they are
    s1 (MdRep  r            : ms) = shows r . s1 ms

    -- deletions: merge neighbours, drop empty ones, separate from a
    -- following replacement by a zero count
    s1 (MdDel d1 : MdDel d2 : ms) = s1 (MdDel (d1++d2) : ms)
    s1 (MdDel []            : ms) = s1 ms
    s1 (MdDel ns : MdRep  r : ms) = (:) '^' . shows ns . (:) '0' . shows r . s1 ms
    s1 (MdDel ns            : ms) = (:) '^' . shows ns . s1 ms
    s1 [                        ] = id


-- | Computes the \"distinct bin\" according to the BAM binning scheme.  If
-- an alignment starts at @pos@ and its CIGAR implies a length of @len@
-- on the reference, then it goes into bin @distinctBin pos len@.  This
-- is the smallest bin (checked at levels 14, 17, 20, 23, 26 bits) that
-- contains both the first and the last covered position, or bin 0 if
-- there is none.
distinctBin :: Int -> Int -> Int
distinctBin beg len =
    case [ offset n + beg `shiftR` n | n <- [14, 17, 20, 23, 26], sameBin n ] of
        b : _ -> b
        []    -> 0
  where
    end       = beg + len - 1
    sameBin n = beg `shiftR` n == end `shiftR` n
    -- number of bins on all coarser levels
    offset n  = ((1 `shiftL` (29 - n)) - 1) `div` 7