module Bio.Bam.Index (
    BamIndex(..),
    withIndexedBam,
    readBamIndex,
    readBaiIndex,
    readTabix,
    IndexFormatError(..),

    Region(..),
    Subsequence(..),
    streamBamRefseq,
    streamBamRegions,
    streamBamSubseq,
    streamBamUnaligned
) where

import Bio.Bam.Header
import Bio.Bam.Reader
import Bio.Bam.Rec
import Bio.Bam.Regions              ( Region(..), Subsequence(..) )
import Bio.Prelude
import Bio.Streaming
import Bio.Streaming.Bgzf           ( bgunzip )
import System.Directory             ( doesFileExist )

import qualified Bio.Bam.Regions                as R
import qualified Bio.Streaming.Bytes            as S
import qualified Bio.Streaming.Parse            as P
import qualified Data.ByteString                as B
import qualified Data.IntMap.Strict             as M
import qualified Data.List                      as L
import qualified Data.Vector                    as V
import qualified Data.Vector.Algorithms.Intro   as N
import qualified Data.Vector.Mutable            as W
import qualified Data.Vector.Unboxed            as U
import qualified Data.Vector.Unboxed.Mutable    as N
import qualified Streaming.Prelude              as Q

-- | Full index, unifying the BAI and CSI layouts.
--
-- Both formats use the same hierarchical binning scheme; its parameters
-- are fixed in BAI and stored in the file in CSI.  Checkpoints come from
-- the linear index in BAI and from the per-bin @loffset@ field in CSI.
-- The type parameter carries format-specific extras (tabix metadata).
data BamIndex a = BamIndex
    { minshift        :: {-# UNPACK #-} !Int
      -- ^ Bit shift of the finest binning level (CSI @min_shift@; 14 for BAI)
    , depth           :: {-# UNPACK #-} !Int
      -- ^ Number of binning levels below the root (CSI @depth@; 5 for BAI)
    , unaln_off       :: {-# UNPACK #-} !Int64
      -- ^ Best guess at the virtual offset where the unaligned records start
    , extensions      :: a
      -- ^ Room for format-specific data (needed for tabix)
    , refseq_bins     :: {-# UNPACK #-} !(V.Vector Bins)
      -- ^ Per reference: the binning index, giving each bin its segments
    , refseq_ckpoints :: {-# UNPACK #-} !(V.Vector Ckpoints)
      -- ^ Per reference: checkpoints @(pos,off)@, where @off@ is the
      -- virtual offset of the first record crossing @pos@
    }
  deriving Show

-- | The chunks of one bin as @(start, end)@ virtual file offsets; they
-- are sorted when read from an index file.
type Segments = U.Vector (Int64,Int64)

-- | Mapping from bin number to the chunks stored for that bin.
type Bins = IntMap Segments

-- | Checkpoints, keyed by coordinate.  Each maps a position to the
-- virtual offset where the first alignment crossing that position is
-- found.  BAI supplies them through its linear index ('ioffset'), CSI
-- through the per-bin 'loffset' field:  "Given a region [beg,end), we
-- only need to visit chunks whose end file offset is larger than
-- 'ioffset' of the 16kB window containing 'beg'."  (Sounds like a
-- marginal gain, though.)
type Ckpoints = IntMap Int64


-- | Decode only those reads that fall into one of several regions.
-- Strategy:  We will scan the file mostly linearly, but only those
-- regions that are actually needed.  We filter the decoded stuff so
-- that it actually overlaps our regions.
--
-- From the binning index, we get a list of segments per requested
-- region.  Using the checkpoints, we prune them:  if we have a
-- checkpoint to the left of the beginning of the interesting region, we
-- can move the start of each segment forward to the checkpoint.  If
-- that makes the segment empty, it can be dropped.
--
-- The resulting segment lists are merged, then traversed.  We seek to
-- the beginning of the earliest segment and start decoding.  Once the
-- virtual file position leaves the segment or the alignment position
-- moves past the end of the requested region, we move to the next.
-- Moving is a seek if it spans a sufficiently large gap or points
-- backwards, else we just keep going.

-- | A 'Segment' is a stretch of the file to decode: its start and end
-- virtual offsets, plus the end coordinate of the region it was derived
-- from (records positioned beyond it are not needed).
data Segment = Segment
    {-# UNPACK #-} !Int64   -- start virtual offset
    {-# UNPACK #-} !Int64   -- end virtual offset
    {-# UNPACK #-} !Int     -- end coordinate of the originating region
  deriving Show

-- | One list of segments per interval of the subsequence, in the order
-- of 'M.toList' on its interval map.  Yields no lists at all if the
-- reference has no entry in the index.
segmentLists :: BamIndex a -> Refseq -> R.Subsequence -> [[Segment]]
segmentLists bi@BamIndex{..} (Refseq ref) (R.Subsequence imap) =
    case (,) <$> refseq_bins V.!? i <*> refseq_ckpoints V.!? i of
        Just (bins, cpts) -> [ rgnToSegments bi beg end bins cpts | (beg, end) <- M.toList imap ]
        Nothing           -> []
  where
    i = fromIntegral ref

-- | Segments to visit for the region @[beg,end)@ on one reference.
--
-- Collects the chunks of every bin that may overlap the region.  Each
-- chunk start is moved forward to the checkpoint at or left of @beg@
-- (if any), and chunks that become empty are dropped.
--
-- The result is sorted by start offset, and overlapping segments are
-- coalesced.  This matters because 'binList' yields bins level by level,
-- so their chunks are not in file order.  '~~' requires sorted input.
-- 'streamBamSegment' continues without seeking whenever the stream is
-- already at or past a segment's start, which is only safe if
-- earlier-starting segments were visited first.
rgnToSegments :: BamIndex a -> Int -> Int -> Bins -> Ckpoints -> [Segment]
rgnToSegments bi beg end bins cpts =
    coalesce . L.sortBy byStart $
        [ Segment boff' eoff end
        | bin <- binList bi beg end
        , (boff,eoff) <- maybe [] U.toList $ M.lookup bin bins
        , let boff' = max boff cpt
        , boff' < eoff ]
  where
    -- virtual offset of the checkpoint at or left of 'beg'; 0 if none
    !cpt = maybe 0 snd $ M.lookupLE beg cpts

    byStart (Segment a _ _) (Segment u _ _) = compare a u

    -- input is sorted by start: absorb each segment into its predecessor
    -- as long as it starts no later than the predecessor ends
    coalesce (Segment a b e : Segment u v f : rest)
        | u <= b = coalesce (Segment a (max b v) (max e f) : rest)
    coalesce (s : rest) = s : coalesce rest
    coalesce []         = []

-- | Numbers of all bins that may hold records overlapping @[beg,end)@,
-- from the root level down to the finest.  Adapted from htslib's
-- @reg2bins@.
--
-- As in htslib, the region is clipped to the coordinate range covered
-- by the binning scheme, @[0, 2^(minshift+3*depth))@, and an empty
-- region yields no bins.  Without clipping, a large @end@ makes each
-- level's range run into the bin numbers of the next level (or past
-- the last level).  With a huge @end@, the enumeration would be
-- practically endless.
binList :: BamIndex a -> Int -> Int -> [Int]
binList BamIndex{..} beg0 end0
    | beg >= end = []
    | otherwise  = levels 0 top 0
  where
    top = minshift + 3*depth
    beg = max 0 beg0
    end | top < 62  = min (1 `shiftL` top) end0
        | otherwise = end0

    -- l: current level, s: coordinate shift at that level,
    -- t: first bin number of that level
    levels l s t
        | l > depth = []
        | otherwise = [t + (beg `shiftR` s) .. t + ((end-1) `shiftR` s)]
                      ++ levels (l+1) (s-3) (t + (1 `shiftL` (3*l)))


-- | Merges two lists of segments into one sorted list without overlaps.
--
-- Both inputs must be sorted by start offset.  Two segments overlap
-- unless one ends before the other starts; overlapping segments are
-- coalesced into one spanning both, keeping the larger end coordinate.
-- A coalesced segment is compared again with the segments that follow,
-- so one segment overlapping several others absorbs all of them.
infix 4 ~~
(~~) :: [Segment] -> [Segment] -> [Segment]
xs ~~ ys = coalesce (mergeByStart xs ys)
  where
    -- ordinary merge of two sorted lists, ordered by start offset
    mergeByStart ps [] = ps
    mergeByStart [] qs = qs
    mergeByStart ps@(p@(Segment a _ _) : ps') qs@(q@(Segment u _ _) : qs')
        | u < a     = q : mergeByStart ps qs'
        | otherwise = p : mergeByStart ps' qs

    -- input is sorted by start: absorb each segment into its predecessor
    -- as long as it starts no later than the predecessor ends
    coalesce (Segment a b e : Segment u v f : rest)
        | u <= b = coalesce (Segment a (max b v) (max e f) : rest)
    coalesce (s : rest) = s : coalesce rest
    coalesce []         = []


-- | Thrown when an index does not start with a recognized magic number;
-- carries the four bytes that were found instead.
data IndexFormatError = IndexFormatError Bytes deriving (Typeable, Show)

instance Exception IndexFormatError where
    displayException (IndexFormatError sig) =
        concat [ "index signature ", show sig, " not recognized" ]

{- | Reads any index we can find for a file.

If the file name already ends in .bai or .csi, optionally followed by
.gz, that file itself is read as the index.  Otherwise the extensions
are tried in turn.  For each one we look for the file name with the
extension appended, then for the file name with its last extension
replaced by it.  The first candidate that exists is read.  If none
does, the file itself is read as an index.
-}
readBamIndex :: FilePath -> IO (BamIndex ())
readBamIndex fp1
    | any (`isSuffixOf` fp1) exts = streamFile fp1 readBaiIndex
    | otherwise                   = probe exts
  where
    exts = words ".bai .bai.gz .csi .csi.gz"

    -- fp1 without the extension of its last path component
    fp2 = reverse (stripExt revName ++ revDir)
    (revName, revDir) = break (== '/') (reverse fp1)
    stripExt n = case dropWhile (/= '.') n of
                     []       -> n
                     _ : stem -> stem

    probe [] = streamFile fp1 readBaiIndex
    probe (e:es) = do
        let appended = fp1 ++ e
            replaced = fp2 ++ e
        hasAppended <- doesFileExist appended
        hasReplaced <- doesFileExist replaced
        if hasAppended then streamFile appended readBaiIndex
        else if hasReplaced then streamFile replaced readBaiIndex
        else probe es

-- | Reads an index in BAI or CSI format, recognized automatically by its
-- four-byte magic.  The index can be compressed (the input is passed
-- through 'S.gunzip'), even though this isn't standard.
--
-- Throws 'IndexFormatError' for an unknown magic, and 'P.EofException'
-- if 'P.parseIO' returns a 'Left'.
readBaiIndex :: MonadIO m => ByteStream m r -> m (BamIndex ())
readBaiIndex = either (const . liftIO $ throwM P.EofException) (return . fst) <=<
               P.parseIO (const $ P.getString 4 >>= switch) . S.gunzip
  where
    -- BAI: n_ref, then per reference the bins followed by the linear
    -- index (read by 'getIntervals').  The binning parameters are fixed.
    switch "BAI\1" = do nref <- fromIntegral `liftM` P.getWord32
                        getIndexArrays nref 14 5 (const return) getIntervals

    -- CSI: binning parameters, auxiliary data (skipped), n_ref, then per
    -- reference the bins, each carrying an 'loffset' checkpoint.  Nothing
    -- is read after the bins (no linear index).
    switch "CSI\1" = do minshift <- fromIntegral `liftM` P.getWord32
                        depth <- fromIntegral `liftM` P.getWord32
                        P.getWord32 >>= P.drop . fromIntegral -- aux data
                        nref <- fromIntegral `liftM` P.getWord32
                        getIndexArrays nref minshift depth (addOneCheckpoint minshift depth) return

    switch magic   = throwM $ IndexFormatError magic


    -- Insert one checkpoint, keyed by the left coordinate of the bin.  If
    -- we already have an entry (can happen if it comes from a different
    -- bin with the same left limit), we conservatively take the min.
    addOneCheckpoint minshift depth bin cp = do
            loffset <- fromIntegral `liftM` P.getWord64
            let key = llim (fromIntegral bin) (3*depth) minshift
            return $! M.insertWith min key loffset cp

    -- Compute the left limit (first coordinate) of a bin.  Start at the
    -- finest level (dp = 3*depth, shift sf = minshift).  ix is the first
    -- bin number of the level, (8^level - 1) / 7.  Walk towards the root
    -- (dp-3, sf+3) until the bin number falls into the level.
    -- NOTE(review): bin numbers beyond the finest level (e.g. a pseudo-bin)
    -- are treated as finest-level bins and get a large key.
    llim bin dp sf | dp  ==  0 = 0
                   | bin >= ix = (bin - ix) `shiftL` sf
                   | otherwise = llim bin (dp-3) (sf+3)
            where ix = (1 `shiftL` dp - 1) `div` 7

-- | Opens an indexed BAM file and runs a continuation on it.
--
-- The index is located with 'readBamIndex' before the file is opened.
-- The header is then decoded from the open handle, and the continuation
-- receives the header, the index and the handle.  The handle is closed
-- by 'bracket' when the continuation returns or throws.
withIndexedBam :: (MonadIO m, MonadLog m, MonadMask m) => FilePath -> (BamMeta -> BamIndex () -> Handle -> m r) -> m r
withIndexedBam f k = do
    idx <- liftIO (readBamIndex f)
    let acquire = liftIO (openBinaryFile f ReadMode)
        release = liftIO . hClose
    bracket acquire release $ \hdl ->
        decodeBam (streamHandle hdl) >>= \(hdr, _) -> k hdr idx hdl


-- | Index of a tabix-indexed file: the common binning index together
-- with tabix-specific metadata.
type TabIndex = BamIndex TabMeta

-- | Tabix header fields, in the order they are stored in the file.
data TabMeta = TabMeta
    { format       :: TabFormat       -- ^ format of the indexed file
    , col_seq      :: Int             -- ^ column for the sequence name
    , col_beg      :: Int             -- ^ column for the start of a region
    , col_end      :: Int             -- ^ column for the end of a region
    , comment_char :: Char            -- ^ comment character
    , skip_lines   :: Int             -- ^ number of leading lines to skip
    , names        :: V.Vector Bytes  -- ^ names of the reference sequences
    }
  deriving Show

-- | Format code of a tabix index, as decoded by 'readTabix':
-- 'SamFormat' for 1, 'VcfFormat' for 2, otherwise 'ZeroBased' if bit 16
-- is set, else 'Generic'.
data TabFormat = Generic | SamFormat | VcfFormat | ZeroBased
  deriving Show

-- | Reads a Tabix index.  Note that tabix indices are compressed; this
-- is taken care of automatically.
--
-- Throws 'IndexFormatError' for an unknown magic number, and
-- 'P.EofException' if 'P.parseIO' returns a 'Left'.
--
-- The optional trailing @n_no_coor@ field of the format counts unplaced
-- reads.  It is not a file offset, so it is not stored in 'unaln_off';
-- that field keeps the estimate computed by 'getIndexArrays'.
readTabix :: MonadIO m => ByteStream m r -> m TabIndex
readTabix = either (const . liftIO $ throwM P.EofException) (return . fst) <=<
            P.parseIO (const $ P.getString 4 >>= switch) . S.gunzip
  where
    switch "TBI\1" = do nref <- fromIntegral `liftM` P.getWord32
                        format       <- liftM toFormat     P.getWord32
                        col_seq      <- liftM fromIntegral P.getWord32
                        col_beg      <- liftM fromIntegral P.getWord32
                        col_end      <- liftM fromIntegral P.getWord32
                        comment_char <- liftM (chr . fromIntegral) P.getWord32
                        skip_lines   <- liftM fromIntegral P.getWord32
                        names        <- liftM splitNames . P.getString . fromIntegral =<< P.getWord32

                        ix <- getIndexArrays nref 14 5 (const return) getIntervals
                        return $! ix { extensions = TabMeta{..} }

    switch magic   = throwM $ IndexFormatError magic

    toFormat 1 = SamFormat
    toFormat 2 = VcfFormat
    toFormat x = if testBit x 16 then ZeroBased else Generic

    -- The names block stores NUL-terminated names back to back, so
    -- splitting on NUL leaves one empty piece after the final terminator.
    -- Drop that piece so there is exactly one entry per name.
    splitNames s
        | not (B.null s) && B.last s == 0 = V.init pieces
        | otherwise                        = pieces
      where
        pieces = V.fromList (B.split 0 s)

-- | Reads a BAI/tabix linear index: a 32-bit count, then one 64-bit
-- virtual offset per 16 KiB (0x4000) window.  Every non-zero offset
-- becomes a checkpoint keyed by its window's start coordinate.  Every
-- offset, zero or not, updates the running maximum.
getIntervals :: Monad m => (IntMap Int64, Int64) -> P.Parser r m (IntMap Int64, Int64)
getIntervals (cp,mx0) = do
    nintv <- fromIntegral `liftM` P.getWord32
    reduceM 0 nintv (cp,mx0) step
  where
    step (!im,!mx) win = do
        ioff <- fromIntegral `liftM` P.getWord64
        let im' | ioff == 0 = im
                | otherwise = M.insert (win * 0x4000) ioff im
        return (im', max mx ioff)


-- | Reads the per-reference part shared by BAI, CSI and tabix indices.
--
-- For each of the @nref@ references, it reads:
--
-- * the number of bins;
-- * for every bin: its number, whatever @addOneCheckpoint@ reads (CSI
--   passes a reader for the bin's @loffset@; BAI and tabix read nothing),
--   and its chunk list;
-- * after the bins, whatever @addManyCheckpoints@ reads (BAI and tabix
--   pass 'getIntervals' for the linear index; CSI reads nothing).
--
-- 'unaln_off' of the result is a running maximum over the end offset of
-- each bin's last chunk and anything @addManyCheckpoints@ folds in.
getIndexArrays :: MonadIO m => Int -> Int -> Int
               -> (Word32 -> Ckpoints -> P.Parser r m Ckpoints)
               -> ((Ckpoints, Int64) -> P.Parser r m (Ckpoints, Int64))
               -> P.Parser r m (BamIndex ())
getIndexArrays nref minshift depth addOneCheckpoint addManyCheckpoints
    -- no references: empty index with an unaligned-offset guess of 0
    | nref  < 1 = return $ BamIndex minshift depth 0 () V.empty V.empty
    | otherwise = do
        -- one slot per reference, filled in file order and frozen at the end
        rbins  <- liftIO $ W.new nref
        rckpts <- liftIO $ W.new nref
        mxR <- reduceM 0 nref 0 $ \mx0 r -> do
                nbins <- P.getWord32
                (!bins,!cpts,!mx1) <- reduceM 0 nbins (M.empty,M.empty,mx0) $ \(!im,!cp,!mx) _ -> do
                        bin <- P.getWord32 -- the "distinct bin"
                        cp' <- addOneCheckpoint bin cp
                        segsarr <- getSegmentArray
                        -- chunks come back sorted, so this is the end of the chunk
                        -- with the largest start (the largest end unless chunks nest)
                        let !mx' = if U.null segsarr then mx else max mx (snd (U.last segsarr))
                        return (M.insert (fromIntegral bin) segsarr im, cp', mx')
                (!cpts',!mx2) <- addManyCheckpoints (cpts,mx1)
                liftIO $ W.write rbins r bins >> W.write rckpts r cpts'
                return mx2
        liftM2 (BamIndex minshift depth mxR ()) (liftIO $ V.unsafeFreeze rbins) (liftIO $ V.unsafeFreeze rckpts)

-- | Reads one bin's chunk list: a 32-bit count followed by that many
-- pairs of 64-bit @(start, end)@ virtual offsets.  The chunks are
-- returned sorted.
getSegmentArray :: MonadIO m => P.Parser r m Segments
getSegmentArray = do
    count <- liftM fromIntegral P.getWord32
    buf   <- liftIO (N.new count)
    loopM 0 count $ \i ->
        liftM2 (,) (liftM fromIntegral P.getWord64) (liftM fromIntegral P.getWord64)
            >>= liftIO . N.write buf i
    liftIO (N.sort buf >> U.unsafeFreeze buf)

{-# INLINE reduceM #-}
-- | Monadic left fold over the indices @beg, succ beg, ...@ up to but
-- excluding @end@, threading an accumulator through @cons@.  @end@ must
-- be reachable from @beg@ by repeated 'succ', or the loop never stops.
reduceM :: (Monad m, Enum ix, Eq ix) => ix -> ix -> a -> (a -> ix -> m a) -> m a
reduceM beg end acc cons = go beg acc
  where
    go i a
        | i == end  = return a
        | otherwise = cons a i >>= go (succ i)

{-# INLINE loopM #-}
-- | Runs @k@ on each index from @beg@ up to but excluding @end@, in
-- ascending 'succ' order, discarding the results.
loopM :: (Monad m, Enum ix, Eq ix) => ix -> ix -> (ix -> m ()) -> m ()
loopM beg end k = go beg
  where
    go i | i == end  = return ()
         | otherwise = k i >> go (succ i)

{-| Seeks to a virtual offset in a BGZF file and streams from there.

The high bits of the offset (shifted right by 16) give the file
position to seek to.  The low 16 bits give the number of bytes to drop
from the decompressed stream.  An offset of 0 means no seek: streaming
starts at the handle's current position.  If the optional end offset is
supplied, the raw stream is cut with 'S.trim' at that offset.
Otherwise, streaming goes on to the end of file.
-}
streamBgzf :: MonadIO m => Handle -> Int64 -> Maybe Int64 -> ByteStream m ()
streamBgzf hdl off eoff = S.drop inBlock (bgunzip compressed)
  where
    inBlock    = off .&. 0xffff
    blockStart = fromIntegral (off `shiftR` 16)
    compressed = do
        when (off /= 0) . liftIO $ hSeek hdl AbsoluteSeek blockStart
        case eoff of
            Nothing -> streamHandle hdl
            Just e  -> S.trim e (streamHandle hdl)
{-# INLINE streamBgzf #-}

{- | Streams one reference from a bam file.

Seeks to a given sequence in a Bam file and enumerates only those
records aligning to that reference.  Decoding starts at the checkpoint
with the smallest coordinate known for the sequence, which requires an
appropriate index.  Streams the 'BamRaw' records of the correct
reference sequence only.  Produces an empty stream if the sequence
isn't found, has no checkpoint, or its first checkpoint is 0.  Records
are parsed with 'P.parseLog' at level 'Error'.
-}
streamBamRefseq :: (MonadIO m, MonadLog m) => BamIndex b -> Handle -> Refseq -> Stream (Of BamRaw) m ()
streamBamRefseq BamIndex{..} hdl (Refseq r) =
    case firstCkpoint of
        Just voff | voff /= 0 ->
            void $ Q.takeWhile onThisRef $ Q.unfoldr (P.parseLog Error getBamRaw) $ streamBgzf hdl voff Nothing
        _ -> pure ()
  where
    firstCkpoint = fst <$> (M.minView =<< refseq_ckpoints V.!? fromIntegral r)
    onThisRef    = (Refseq r ==) . b_rname . unpackBam


{- | Reads from a Bam file the part with unaligned reads.

Sort of the dual to 'streamBamRefseq'.  The index does not actually
point to the unaligned part at the end.  Instead, decoding starts at
the best guess 'unaln_off', and every record that still has a valid
reference is dropped.  If that guess is 0, decoding starts at the
handle's current position (the \"fallback guess\").  This only works if
something else already consumed the Bam header.
-}
streamBamUnaligned :: MonadIO m => BamIndex b -> Handle -> Stream (Of BamRaw) m ()
streamBamUnaligned BamIndex{..} hdl = Q.filter isUnaligned records
  where
    records     = Q.unfoldr (P.parseIO getBamRaw) (streamBgzf hdl unaln_off Nothing)
    isUnaligned = not . isValidRefseq . b_rname . unpackBam

{- | Streams one 'Segment'.

Takes a 'Handle', a 'Segment' and a 'Stream' coming from that handle.
If the next record of the stream lies at or after the segment's start
offset, and less than 512 KiB of compressed data past it, the stream
is simply continued.  Otherwise (including when the stream is
exhausted), we seek the handle to the start offset and stream from it.

Either way, the part of the stream before it crosses either the end
offset or the max position is returned.  The remaining stream is
returned in its functorial value, so it can be passed to another
invocation of e.g. 'streamBamSegment'.  Note that the stream passed in
becomes unusable.

Continuing without a seek assumes the records between the start offset
and the current position were already produced.  That holds when
segments are visited in order of their start offsets.
-}
streamBamSegment :: MonadIO m
                 => Handle -> Segment -> Stream (Of BamRaw) m ()
                 -> Stream (Of BamRaw) m (Stream (Of BamRaw) m ())
streamBamSegment hdl (Segment beg end mpos) =
    lift . Q.uncons >=> \case
        -- don't seek if the next record is at or after 'beg' and less
        -- than 512k of compressed data past it; just keep going.
        -- NOTE(review): a continued stream was created for an earlier
        -- segment and may have been cut at that segment's end offset by
        -- 'S.trim' in 'streamBgzf'; confirm this cannot lose records of
        -- the current segment.
        Just (br,brs) | near (virt_offset br)
            -> Q.span in_seg $ Q.cons br brs
        -- otherwise seek to the segment start and decode up to its end
        _   -> Q.span in_seg $ Q.unfoldr (P.parseIO getBamRaw) $ streamBgzf hdl beg (Just end)
  where
    -- 0x800000000 = 0x80000 << 16: 512 KiB of compressed file offset
    near    o = beg <= fromIntegral o && beg + 0x800000000 > fromIntegral o
    -- a record belongs to the segment while its virtual offset is at most
    -- 'end' and its position does not exceed the region's end 'mpos'
    in_seg br = virt_offset br <= end && b_pos (unpackBam br) <= mpos

-- | Streams the records of one reference that overlap a subsequence.
--
-- The segments of all intervals are merged with '~~' and visited in
-- turn by 'streamBamSegment', starting from the given stream.  Records
-- are kept only if they lie on the reference and their aligned span
-- overlaps the subsequence.  The stream left over after the last
-- segment is returned so it can be reused for further subsequences.
streamBamSubseq :: MonadIO m
                => BamIndex b -> Handle -> Refseq -> R.Subsequence -> Stream (Of BamRaw) m ()
                -> Stream (Of BamRaw) m (Stream (Of BamRaw) m ())
streamBamSubseq bi hdl ref subs str =
    Q.filter wanted $ foldM (\s seg -> streamBamSegment hdl seg s) str merged
  where
    merged = foldr (~~) [] (segmentLists bi ref subs)
    wanted br = b_rname b == ref && R.overlaps (b_pos b) (b_pos b + alignedLength (b_cigar b)) subs
      where
        b = unpackBam br

-- | Streams all records overlapping any of the given regions.
--
-- The regions are grouped by reference via 'R.fromList' and handled in
-- 'R.toList' order.  Each reference's part reuses the stream left over
-- by the previous one, starting from an empty stream.
streamBamRegions :: MonadIO m => BamIndex b -> Handle -> [R.Region] -> Stream (Of BamRaw) m ()
streamBamRegions bi hdl regions =
    void $ foldM visit (pure ()) (R.toList (R.fromList regions))
  where
    visit rest (ref, subs) = streamBamSubseq bi hdl ref subs rest