module Bio.Motif.Merge
( mergePWM
, mergePWMWeighted
, dilute
, trim
, mergeTree
, mergeTreeWeighted
, iterativeMerge
, buildTree
, cutTreeBy
)where
import AI.Clustering.Hierarchical hiding (size)
import Control.Arrow (first)
import Control.Monad (forM_, when)
import Control.Monad.ST (runST, ST)
import qualified Data.ByteString.Char8 as B
import Data.List (dropWhileEnd)
import qualified Data.Matrix.Symmetric.Mutable as MSU
import qualified Data.Matrix.Unboxed as MU
import Data.Maybe
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as VM
import qualified Data.Vector.Unboxed as U
import Bio.Motif
import Bio.Motif.Alignment
import Bio.Utils.Functions (kld)
-- | Merge two PWMs into one, averaging the rows where they overlap.
--
-- The argument is a triple @(m1, m2, shift)@. For @shift >= 0@, row @i@ of
-- @m1@ is paired with row @i - shift@ of @m2@. For @shift < 0@ the two
-- matrices swap roles and the offset is negated, so the inner 'merge' always
-- works with a non-negative offset. Rows covered by only one matrix are
-- copied unchanged. The result is built with @PWM Nothing@.
mergePWM :: (PWM, PWM, Int) -> PWM
mergePWM (m1, m2, shift)
    | shift >= 0 = merge shift (_mat m1) $ _mat m2
    | otherwise  = merge (-shift) (_mat m2) $ _mat m1
  where
    -- @s@ is the non-negative offset of @b@ relative to @a@.
    -- NOTE(review): the first guard reads row @i@ of @a@ whenever @i' < 0@,
    -- so this relies on @s <= n1@ — confirm callers never pass a larger shift.
    merge s a b = PWM Nothing $ MU.fromRows $ loop 0
      where
        n1 = MU.rows a
        n2 = MU.rows b
        loop i
            -- only @a@ covers this position
            | i' < 0 || (i < n1 && i' >= n2) = MU.takeRow a i : loop (i+1)
            -- both matrices cover this position: average them
            | i < n1 && i' < n2 = f (MU.takeRow a i) (MU.takeRow b i') : loop (i+1)
            -- only @b@ covers this position
            | i >= n1 && i' < n2 = MU.takeRow b i' : loop (i+1)
            | otherwise = []
          where
            -- index into @b@ corresponding to index @i@ of @a@
            i' = i - s
            f = U.zipWith (\x y -> (x+y)/2)
-- | Merge two PWMs, each carrying one integer weight per row.
--
-- For @shift >= 0@, row @i@ of the first PWM is paired with row @i - shift@
-- of the second. For @shift < 0@ the two inputs swap roles and the offset is
-- negated. Overlapping rows are combined as a weighted average and their
-- weights are summed. Rows covered by only one input keep their values and
-- weight. The resulting PWM is built with @PWM Nothing@.
--
-- Rows and weights are paired with 'zip', so a weight list shorter than the
-- number of rows silently truncates that input.
mergePWMWeighted :: (PWM, [Int])  -- ^ first PWM and its per-row weights
                 -> (PWM, [Int])  -- ^ second PWM and its per-row weights
                 -> Int           -- ^ offset of the second PWM relative to the first
                 -> (PWM, [Int])  -- ^ merged PWM and merged per-row weights
mergePWMWeighted m1 m2 shift
    | shift >= 0 = merge shift (first _mat m1) $ first _mat m2
    | otherwise  = merge (-shift) (first _mat m2) $ first _mat m1
  where
    -- @s@ is the non-negative offset of the second matrix relative to the first.
    merge s (p1,w1) (p2,w2) = first (PWM Nothing . MU.fromRows) $ unzip $ loop 0
      where
        a = V.fromList $ zip (MU.toRows p1) w1
        b = V.fromList $ zip (MU.toRows p2) w2
        n1 = V.length a
        n2 = V.length b
        loop i
            -- only @a@ covers this position
            | i' < 0 || (i < n1 && i' >= n2) = a V.! i : loop (i+1)
            -- both cover this position: weighted average
            | i < n1 && i' < n2 = f (a V.! i) (b V.! i') : loop (i+1)
            -- only @b@ covers this position
            | i >= n1 && i' < n2 = b V.! i' : loop (i+1)
            | otherwise = []
          where
            -- index into @b@ corresponding to index @i@ of @a@
            i' = i - s
            f (xs, wx) (ys, wy) = (U.zipWith (\x y ->
                (fromIntegral wx * x + fromIntegral wy * y) /
                fromIntegral (wx + wy)) xs ys, wx + wy)
-- | Pull under-supported rows of a weighted PWM towards a flat 0.25 value.
--
-- Let @n@ be the largest weight in the list. A row with weight @w < n@ has
-- each entry replaced by @(w * x + 0.25 * (n - w)) / n@, i.e. the @n - w@
-- missing contributions are filled in with 0.25. Rows whose weight equals
-- @n@ are returned unchanged. The result is built with @PWM Nothing@.
dilute :: (PWM, [Int]) -> PWM
dilute (pwm, ws) = PWM Nothing $ MU.fromRows $ zipWith f ws $ MU.toRows $ _mat pwm
  where
    f w r
        | w < n = let d = fromIntegral $ n - w  -- number of missing contributions
                  in U.map (\x -> (fromIntegral w * x + 0.25 * d) / fromIntegral n) r
        | otherwise = r
    -- Only forced when 'f' is applied, so an empty weight list never
    -- evaluates 'maximum' on @[]@.
    n = maximum ws
-- | Strip leading and trailing rows of a PWM that are close to the background.
--
-- A row is dropped from either end while its 'kld' against the background
-- vector @[a,c,g,t]@ is strictly below @cutoff@. Interior rows are never
-- removed. The result is built with @PWM Nothing@.
trim :: Bkgd -> Double -> PWM -> PWM
trim (BG (a,c,g,t)) cutoff pwm =
    PWM Nothing . MU.fromRows . dropWhileEnd nearBackground . dropWhile nearBackground $ allRows
  where
    background = U.fromList [a,c,g,t]
    allRows = MU.toRows (_mat pwm)
    nearBackground row = kld row background < cutoff
-- | Collapse a dendrogram of motifs into a single PWM.
--
-- A leaf yields the motif's own PWM. A branch merges the PWMs of its two
-- subtrees with 'mergePWM', using the offset reported by the alignment
-- function. When the alignment's boolean flag is 'False', the right-hand PWM
-- is passed through 'rcPWM' before merging.
mergeTree :: AlignFn -> Dendrogram Motif -> PWM
mergeTree _ (Leaf motif) = _pwm motif
mergeTree align (Branch _ _ lhs rhs) =
    combine (mergeTree align lhs) (mergeTree align rhs)
  where
    combine x y =
        let (_, (same, offset)) = align x y
            y' = if same then y else rcPWM y
        in mergePWM (x, y', offset)
-- | Collapse a dendrogram of motifs into a single PWM with per-row weights.
--
-- A leaf yields the motif's PWM with a weight of 1 for each of its @size@
-- positions. A branch merges its two subtrees with 'mergePWMWeighted', using
-- the offset reported by the alignment function. When the alignment's
-- boolean flag is 'False', the right-hand PWM is passed through 'rcPWM' and
-- its weight list is reversed to match.
mergeTreeWeighted :: AlignFn -> Dendrogram Motif -> (PWM, [Int])
mergeTreeWeighted _ (Leaf motif) = (pwm, replicate (size pwm) 1)
  where
    pwm = _pwm motif
mergeTreeWeighted align (Branch _ _ lhs rhs) =
    combine (mergeTreeWeighted align lhs) (mergeTreeWeighted align rhs)
  where
    combine (x, wx) (y, wy) =
        let (_, (same, offset)) = align x y
            other = if same then (y, wy) else (rcPWM y, reverse wy)
        in mergePWMWeighted (x, wx) other offset
-- | Greedily merge motifs until no remaining pair aligns below a threshold.
--
-- Every motif starts as its own group, with a weight of 1 for each of its
-- @size@ positions. On each round the pair @(i, j)@ with the smallest
-- alignment score (first component of the 'AlignFn' result) is located. If
-- that score is strictly below @th@, group @j@ is merged into group @i@ with
-- 'mergePWMWeighted', the name lists are concatenated, slot @j@ is cleared,
-- and the scores between the merged PWM and every surviving group are
-- recomputed. The process stops when the smallest score is not below @th@ or
-- no pair is left.
--
-- Each returned triple holds the names in a group, its merged PWM and the
-- per-row weights.
iterativeMerge :: AlignFn
               -> Double    -- ^ score threshold; pairs scoring below it are merged
               -> [Motif]   -- ^ motifs to merge
               -> [([B.ByteString], PWM, [Int])]
iterativeMerge align th motifs = runST $ do
    motifs' <- V.unsafeThaw $ V.fromList $ flip map motifs $ \x ->
        Just ([_name x], _pwm x, replicate (size $ _pwm x) 1)

    let n = VM.length motifs'
        iter mat = do
            -- Find the pair with the smallest score. The seed pairs an
            -- infinite score with placeholder indices and alignment data; it
            -- can never satisfy @d < th@, so the placeholders are never used.
            -- The alignment placeholder must be a real value, not
            -- 'undefined', because the tuple pattern below forces it before
            -- @d@ is inspected.
            ((i, j), (d, (isSame, pos))) <- loop ((-1,-1), (1/0, (True, 0))) 0 1
            if d < th
                then do
                    Just (nm1, pwm1, w1) <- VM.unsafeRead motifs' i
                    Just (nm2, pwm2, w2) <- VM.unsafeRead motifs' j

                    let merged = (nm1 ++ nm2, pwm', w')
                        (pwm',w') | isSame = mergePWMWeighted (pwm1, w1) (pwm2, w2) pos
                                  | otherwise = mergePWMWeighted (pwm1, w1)
                                      (rcPWM $ pwm2, reverse w2) pos

                    -- Drop group j: clear its scores, store the merged group
                    -- in slot i and empty slot j.
                    forM_ [0..n-1] $ \i' -> MSU.unsafeWrite mat (i',j) Nothing
                    VM.unsafeWrite motifs' i $ Just merged
                    VM.unsafeWrite motifs' j Nothing

                    -- Recompute scores between the merged group and every
                    -- surviving group, keeping the lower index as the first
                    -- argument to the alignment.
                    forM_ [0..n-1] $ \j' -> when (i /= j') $ do
                        x <- VM.unsafeRead motifs' j'
                        case x of
                            Just (_, pwm2',_) -> do
                                let ali | i < j' = Just $ align pwm' pwm2'
                                        | otherwise = Just $ align pwm2' pwm'
                                MSU.unsafeWrite mat (i,j') ali
                            _ -> return ()
                    iter mat
                else return ()
          where
            -- Scan the entries with @i < j@, carrying the best pair so far.
            loop ((i_min, j_min), d_min) i j
                | i >= n = return ((i_min, j_min), d_min)
                | j >= n = loop ((i_min, j_min), d_min) (i+1) (i+2)
                | otherwise = do
                    x <- MSU.unsafeRead mat (i,j)
                    case x of
                        Just d -> if fst d < fst d_min
                            then loop ((i,j), d) i (j+1)
                            else loop ((i_min, j_min), d_min) i (j+1)
                        _ -> loop ((i_min, j_min), d_min) i (j+1)

    -- Initialise the score matrix with all pairwise alignments.
    mat <- MSU.replicate (n,n) Nothing :: ST s (MSU.SymMMatrix VM.MVector s (Maybe (Double, (Bool, Int))))
    forM_ [0..n-1] $ \i -> forM_ [i+1 .. n-1] $ \j -> do
        Just (_, pwm1, _) <- VM.unsafeRead motifs' i
        Just (_, pwm2, _) <- VM.unsafeRead motifs' j
        MSU.unsafeWrite mat (i,j) $ Just $ align pwm1 pwm2

    iter mat
    results <- V.unsafeFreeze motifs'
    return $ V.toList $ V.map fromJust $ V.filter isJust results
-- | Cluster motifs hierarchically with 'Average' linkage.
--
-- The distance between two motifs is the first component of the alignment
-- function's result on their second constructor fields.
buildTree :: AlignFn -> [Motif] -> Dendrogram Motif
buildTree align motifs = hclust Average (V.fromList motifs) distance
  where
    distance (Motif _ p) (Motif _ q) = fst (align p q)
-- | Cut a dendrogram at successively lower heights until a predicate accepts
-- the resulting clusters.
--
-- Starting at @start@, the tree is cut with 'cutAt'. If @fn@ accepts the
-- clusters they are returned. Otherwise the height is lowered by @step@ and
-- the cut is retried, as long as the new height stays strictly positive.
-- Once lowering further would reach zero or below, the clusters from the
-- last attempted height are returned even though @fn@ rejected them.
--
-- NOTE(review): with @step <= 0@ the height never decreases, so this does
-- not terminate unless @fn@ accepts — confirm callers pass a positive step.
cutTreeBy :: Double   -- ^ starting height
          -> Double   -- ^ amount by which the height is lowered each round
          -> ([Dendrogram a] -> Bool) -> Dendrogram a -> [Dendrogram a]
cutTreeBy start step fn tree = go start
  where
    go x
        | fn clusters = clusters
        | x - step > 0 = go $ x - step
        | otherwise = clusters
      where
        clusters = cutAt tree x