module Bio.Align (
    Mode(..),
    myersAlign,
    showAligned
                 ) where

import Bio.Prelude       hiding ( lefts, rights )
import Foreign.C.String         ( CString )
import Foreign.C.Types          ( CInt(..) )
import Foreign.Marshal.Alloc    ( allocaBytes )

import qualified Data.ByteString.Char8      as S
import qualified Data.ByteString.Unsafe     as S
import qualified Data.ByteString.Lazy.Char8 as L

-- | Raw binding to the C alignment routine, used by 'myersAlign'.
--
-- Arguments, in order:
--
-- * sequence A and its length;
-- * the alignment mode, which 'myersAlign' passes as the 'fromEnum' of a 'Mode';
-- * sequence B and its length;
-- * the distance bound;
-- * two caller-allocated backtrace buffers, for A and for B.
--
-- 'myersAlign' treats a negative return value as failure. Otherwise it
-- reads the gapped sequences back from the start of the two buffers with
-- 'S.packCString'. The C side therefore presumably writes NUL-terminated
-- strings there; confirm against myers_align.h and the C source.
--
-- The call is declared @unsafe@, so the C code must not call back into
-- Haskell.
foreign import ccall unsafe "myers_align.h myers_diff" myers_diff ::
        CString -> CInt ->              -- sequence A and length A
        CInt ->                         -- mode (an enum)
        CString -> CInt ->              -- sequence B and length B
        CInt ->                         -- max distance
        CString ->                      -- backtracing space A
        CString ->                      -- backtracing space B
        IO CInt                         -- returns distance

-- | Mode argument for 'myersAlign', determines where free gaps are
-- allowed.
--
-- 'myersAlign' hands the mode to the C aligner as @fromEnum mode@, so
-- the constructor order is significant. It presumably mirrors the mode
-- enum in @myers_align.h@; do not reorder the constructors without
-- changing the C side to match.
data Mode = Globally  -- ^ align globally, without gaps at either end
          | HasPrefix -- ^ align so that the second sequence is a prefix of the first
          | IsPrefix  -- ^ align so that the first sequence is a prefix of the second
    deriving (Show, Eq, Enum, Bounded)

-- | Align two strings.  @myersAlign maxd seqA mode seqB@ tries to align
-- @seqA@ to @seqB@, which will work as long as no more than @maxd@ gaps
-- or mismatches are incurred.  The @mode@ argument determines if either
-- of the sequences is allowed to have an overhanging tail.
--
-- The result is the triple of the actual distance (gaps + mismatches)
-- and the two padded sequences.  These sequences are the original
-- sequences with dashes inserted for gaps.
--
-- If the alignment fails, the result is @(maxBound, S.empty, S.empty)@.
-- This happens in two cases:
--
-- * @maxd@ is negative. No distance can meet such a bound, so the C code
--   is not called at all.
-- * The C routine returns a negative value, presumably because the
--   distance would exceed @maxd@.
--
-- A @maxd@ larger than the combined length of both sequences is reduced
-- to that length before use. Gapping out all of one sequence against all
-- of the other never costs more, so a larger bound cannot change the
-- result. It would only inflate the buffers, and with e.g.
-- @myersAlign maxBound@ the buffer-size arithmetic would overflow.
--
-- The algorithm is the O(nd) algorithm by Myers, implemented in C.  A
-- gap and a mismatch score the same.  The strings are supposed to code
-- for DNA, the code understands IUPAC-IUB ambiguity codes.  Two
-- characters match iff there is at least one nucleotide both can code
-- for.  Note that N is a wildcard, while X matches nothing.
myersAlign :: Int -> Bytes -> Mode -> Bytes -> (Int, Bytes, Bytes)
myersAlign maxd seqA mode seqB
    | maxd < 0  = failed
    | otherwise =
        -- Hiding the IO relies on two assumptions:
        --   * the C code only reads seq_a/seq_b, as unsafeUseAsCStringLen
        --     requires (presumably true; confirm in the C source);
        --   * the backtrace buffers are private to this call.
        unsafePerformIO                              $
        S.unsafeUseAsCStringLen seqA                 $ \(seq_a, len_a) ->
        S.unsafeUseAsCStringLen seqB                 $ \(seq_b, len_b) ->
        let bound = min maxd (len_a + len_b) in
        -- size of output buffers derives from this:
        -- char *out_a = bt_a + len_a + maxd +2 ;
        -- char *out_b = bt_b + len_b + maxd +2 ;
        allocaBytes (len_a + bound + 2)              $ \bt_a ->
        allocaBytes (len_b + bound + 2)              $ \bt_b ->
        -- NOTE(review): lengths and bound are narrowed to CInt with
        -- fromIntegral, which wraps silently above maxBound :: CInt.
        myers_diff seq_a (fromIntegral len_a)
                   (fromIntegral $ fromEnum mode)
                   seq_b (fromIntegral len_b)
                   (fromIntegral bound) bt_a bt_b     >>= \dist ->
        if dist < 0
          then return failed
          else (,,) (fromIntegral dist) <$>
               S.packCString bt_a <*>
               S.packCString bt_b
  where
    -- Result reported when no alignment within the bound exists.
    failed :: (Int, Bytes, Bytes)
    failed = (maxBound, S.empty, S.empty)


-- | Nicely print an alignment.  An alignment is simply a list of
-- strings with inserted gaps to make them align.  We split them into
-- chunks of @w@ columns and stack them vertically. Each chunk is
-- followed by two lines:
--
-- * an agreement line, with an asterisk in every column where all
--   aligned strings agree;
-- * an empty line.
--
-- The result is /almost/ the Clustal format.
--
-- Strings of unequal length are accepted. A column counts as agreeing
-- when every string that reaches it has the same character there.
--
-- A non-positive width means "do not wrap": everything goes into one
-- chunk as wide as the longest string. Without this rule, such a width
-- would never consume any input and the result would be an infinite
-- list.
showAligned :: Int -> [Bytes] -> [L.ByteString]
showAligned w ss
    | all S.null ss = []
      -- Here ss is non-empty and holds a non-empty string, so 'maximum'
      -- is safe and the substituted width is at least 1.
    | w <= 0        = showAligned (maximum (map S.length ss)) ss
    | otherwise     = map L.fromStrict chunk
                      ++ L.pack agreement
                      :  L.empty
                      :  showAligned w rest
  where
    -- The first w columns of every string, and what is left over.
    (chunk, rest) = unzip $ map (S.splitAt w) ss

    -- One marker character per column of the current chunk.
    agreement :: String
    agreement = map star (S.transpose chunk)

    -- '*' when all characters in the column are identical, ' ' otherwise.
    star :: Bytes -> Char
    star col = case S.uncons col of
        Nothing        -> '*'
        Just (c, more) -> if S.all (== c) more then '*' else ' '