{-# LANGUAGE FlexibleContexts  #-}
{-# LANGUAGE OverloadedStrings #-}

-- | Merging of position weight matrices (PWMs): pairwise, along a guide
-- tree, or iteratively by repeatedly merging the closest pair.
module Bio.Motif.Merge
    ( mergePWM
    , mergePWMWeighted
    , dilute
    , trim
    , mergeTree
    , mergeTreeWeighted
    , iterativeMerge
    , buildTree
    , cutTreeBy
    ) where

import           AI.Clustering.Hierarchical    hiding (size)
import           Control.Arrow                 (first)
import           Control.Monad                 (forM_, when)
import           Control.Monad.ST              (runST, ST)
import qualified Data.ByteString.Char8         as B
import           Data.List                     (dropWhileEnd)
import qualified Data.Matrix.Symmetric.Mutable as MSU
import qualified Data.Matrix.Unboxed           as MU
import           Data.Maybe
import qualified Data.Vector                   as V
import qualified Data.Vector.Mutable           as VM
import qualified Data.Vector.Unboxed           as U

import           Bio.Motif
import           Bio.Motif.Alignment
import           Bio.Utils.Functions           (kld)

-- | Merge two PWMs, given the offset of the second one relative to the first.
--
-- Overlapping rows are averaged with equal weight; rows covered by only one
-- matrix are copied unchanged. A negative shift swaps the roles of the two
-- matrices (the second one then leads by @-shift@ rows).
--
-- NOTE(review): if the shift exceeds the number of rows of the leading
-- matrix, the rows in between are covered by neither matrix and
-- 'MU.takeRow' is called out of range. Presumably the aligner never returns
-- such a shift; confirm.
mergePWM :: (PWM, PWM, Int) -> PWM
mergePWM (m1, m2, shift)
    | shift >= 0 = merge shift (_mat m1) $ _mat m2
    | otherwise = merge (-shift) (_mat m2) $ _mat m1
  where
    merge s a b = PWM Nothing $ MU.fromRows $ loop 0
      where
        n1 = MU.rows a
        n2 = MU.rows b
        -- row i of the result lines up with row i of a and row (i - s) of b
        loop i
            | i' < 0 || (i < n1 && i' >= n2) = MU.takeRow a i : loop (i+1)
            | i < n1 && i' < n2 = f (MU.takeRow a i) (MU.takeRow b i') : loop (i+1)
            | i >= n1 && i' < n2 = MU.takeRow b i' : loop (i+1)
            | otherwise = []
          where
            i' = i - s
        f = U.zipWith (\x y -> (x+y)/2)

-- | Weighted variant of 'mergePWM'.
--
-- Each PWM comes with one weight per position. Overlapping rows are combined
-- as a weighted average and their weights are summed. Rows covered by only
-- one matrix keep their row and weight. A negative shift swaps the roles of
-- the two inputs. Rows and weights are paired positionally, so the usable
-- length of each input is the shorter of its row count and weight list.
--
-- NOTE(review): as in 'mergePWM', a shift larger than the leading matrix's
-- length indexes out of range; presumably never produced by the aligner.
mergePWMWeighted :: (PWM, [Int])  -- ^ pwm and weights at each position
                 -> (PWM, [Int])
                 -> Int           -- ^ shift
                 -> (PWM, [Int])
mergePWMWeighted m1 m2 shift
    | shift >= 0 = merge shift (first _mat m1) $ first _mat m2
    | otherwise = merge (-shift) (first _mat m2) $ first _mat m1
  where
    merge s (p1,w1) (p2,w2) = first (PWM Nothing . MU.fromRows) $ unzip $ loop 0
      where
        a = V.fromList $ zip (MU.toRows p1) w1
        b = V.fromList $ zip (MU.toRows p2) w2
        n1 = V.length a
        n2 = V.length b
        loop i
            | i' < 0 || (i < n1 && i' >= n2) = a V.! i : loop (i+1)
            | i < n1 && i' < n2 = f (a V.! i) (b V.! i') : loop (i+1)
            | i >= n1 && i' < n2 = b V.! i' : loop (i+1)
            | otherwise = []
          where
            i' = i - s
        f (xs, wx) (ys, wy) =
            ( U.zipWith (\x y -> (fromIntegral wx * x + fromIntegral wy * y) / fromIntegral (wx + wy)) xs ys
            , wx + wy )

-- | Dilute positions in a PWM that are associated with low weights.
--
-- Let @n@ be the largest weight. Every row whose weight @w@ is below @n@ is
-- mixed with a flat distribution (0.25 per entry) in proportion
-- @w : (n - w)@. Rows carrying the largest weight are returned unchanged.
-- Rows and weights are paired positionally.
dilute :: (PWM, [Int]) -> PWM
dilute (pwm, ws) = PWM Nothing $ MU.fromRows $ zipWith f ws $ MU.toRows $ _mat pwm
  where
    f w r
        | w < n =
            let d = fromIntegral $ n - w
            in U.map (\x -> (fromIntegral w * x + 0.25 * d) / fromIntegral n) r
        | otherwise = r
    -- only forced when there is at least one row/weight pair, so an empty
    -- weight list never reaches 'maximum'
    n = maximum ws
{-# INLINE dilute #-}

-- | Remove uninformative positions from both ends of a PWM.
--
-- A row is uninformative when @kld row background@ (presumably the KL
-- divergence from the background) is below @cutoff@. Interior rows are kept
-- even if they fall below the cutoff.
--
-- NOTE(review): if every row is below the cutoff, the result is built from
-- an empty row list; confirm 'MU.fromRows' handles that as intended.
trim :: Bkgd -> Double -> PWM -> PWM
trim (BG (a,c,g,t)) cutoff pwm = PWM Nothing $ MU.fromRows $ dropWhileEnd f $ dropWhile f rs
  where
    f x = kld x bg < cutoff
    rs = MU.toRows $ _mat pwm
    bg = U.fromList [a,c,g,t]
{-# INLINE trim #-}

-- | Collapse a guide tree into a single PWM by merging the two subtrees at
-- every branch with 'mergePWM'.
--
-- The left and right results are aligned. The alignment's orientation flag
-- decides whether the right PWM is reverse-complemented, and its offset is
-- used as the merge shift.
mergeTree :: AlignFn -> Dendrogram Motif -> PWM
mergeTree align t = case t of
    Branch _ _ left right -> f (mergeTree align left) $ mergeTree align right
    Leaf a -> _pwm a
  where
    f a b
        | isSame = mergePWM (a, b, i)
        | otherwise = mergePWM (a, rcPWM b, i)
      where
        (_, (isSame, i)) = align a b

-- | Like 'mergeTree', but tracks per-position weights and merges with
-- 'mergePWMWeighted'.
--
-- Leaves start with weight 1 at every position. When the right PWM is
-- reverse-complemented, its weights are reversed to stay aligned with its
-- rows.
mergeTreeWeighted :: AlignFn -> Dendrogram Motif -> (PWM, [Int])
mergeTreeWeighted align t = case t of
    Branch _ _ left right -> f (mergeTreeWeighted align left) $ mergeTreeWeighted align right
    Leaf a -> (_pwm a, replicate (size $ _pwm a) 1)
  where
    f (a,w1) (b,w2)
        | isSame = mergePWMWeighted (a,w1) (b,w2) i
        | otherwise = mergePWMWeighted (a,w1) (rcPWM b, reverse w2) i
      where
        (_, (isSame, i)) = align a b
{-# INLINE mergeTreeWeighted #-}

-- | Greedily merge motifs.
--
-- Repeatedly takes the pair with the smallest alignment score (the first
-- component of the 'AlignFn' result) and merges it while that score is below
-- the cutoff. After each merge, the merged motif is re-aligned against every
-- remaining one.
--
-- Each result holds the names of the motifs merged into it, the merged PWM,
-- and its per-position weights (the number of motifs contributing to each
-- position). Results follow the order of the input slots.
--
-- An empty or single-motif input is returned as-is, as is the case where
-- everything collapses into one motif.
iterativeMerge :: AlignFn
               -> Double   -- ^ cutoff
               -> [Motif]  -- ^ Motifs to be merged. Motifs must have unique name.
               -> [([B.ByteString], PWM, [Int])]
iterativeMerge align th motifs = runST $ do
    -- One slot per input motif; a slot becomes Nothing once it has been
    -- merged into another one.
    motifs' <- V.unsafeThaw $ V.fromList $ flip map motifs $ \x ->
        Just ([_name x], _pwm x, replicate (size $ _pwm x) 1)
    let n = VM.length motifs'
        iter mat = do
            -- retrieve the closest remaining pair (if any pair is left)
            closest <- loop Nothing 0 1
            case closest of
                Just ((i, j), (d, (isSame, pos))) | d < th -> do
                    Just (nm1, pwm1, w1) <- VM.unsafeRead motifs' i
                    Just (nm2, pwm2, w2) <- VM.unsafeRead motifs' j
                    let merged = (nm1 ++ nm2, pwm', w')
                        (pwm', w')
                            | isSame = mergePWMWeighted (pwm1, w1) (pwm2, w2) pos
                            | otherwise = mergePWMWeighted (pwm1, w1) (rcPWM pwm2, reverse w2) pos
                    -- update: retire slot j, store the merged motif in slot i
                    forM_ [0..n-1] $ \i' -> MSU.unsafeWrite mat (i',j) Nothing
                    VM.unsafeWrite motifs' i $ Just merged
                    VM.unsafeWrite motifs' j Nothing
                    -- re-align the merged motif against every remaining motif,
                    -- keeping the lower-index motif as the first argument
                    forM_ [0..n-1] $ \j' -> when (i /= j') $ do
                        x <- VM.unsafeRead motifs' j'
                        case x of
                            Just (_, pwm2', _) -> do
                                let ali | i < j' = Just $ align pwm' pwm2'
                                        | otherwise = Just $ align pwm2' pwm'
                                MSU.unsafeWrite mat (i,j') ali
                            _ -> return ()
                    iter mat
                -- no pair left, or the closest pair is not below the cutoff
                _ -> return ()
          where
            -- Scan the upper triangle (i < j) for the smallest score. A
            -- candidate must be strictly smaller than the current best
            -- (+Infinity when none has been found yet).
            loop best i j
                | i >= n = return best
                | j >= n = loop best (i+1) (i+2)
                | otherwise = do
                    x <- MSU.unsafeRead mat (i,j)
                    let bestScore = maybe (1/0) (fst . snd) best
                    case x of
                        Just d | fst d < bestScore -> loop (Just ((i,j), d)) i (j+1)
                        _ -> loop best i (j+1)
    -- initialization: alignment result for every pair i < j
    mat <- MSU.replicate (n,n) Nothing :: ST s (MSU.SymMMatrix VM.MVector s (Maybe (Double, (Bool, Int))))
    forM_ [0..n-1] $ \i -> forM_ [i+1 .. n-1] $ \j -> do
        Just (_, pwm1, _) <- VM.unsafeRead motifs' i
        Just (_, pwm2, _) <- VM.unsafeRead motifs' j
        MSU.unsafeWrite mat (i,j) $ Just $ align pwm1 pwm2
    iter mat
    results <- V.unsafeFreeze motifs'
    return $ catMaybes $ V.toList results
{-# INLINE iterativeMerge #-}

-- | Build a guide tree from a set of motifs by average-linkage hierarchical
-- clustering. The distance between two motifs is the first component of
-- their alignment result.
buildTree :: AlignFn -> [Motif] -> Dendrogram Motif
buildTree align motifs = hclust Average (V.fromList motifs) dist
  where
    dist (Motif _ x) (Motif _ y) = fst $ align x y

-- | Cut a dendrogram at decreasing heights until the resulting clusters
-- satisfy the predicate.
--
-- The heights tried are @start@, @start - step@, @start - 2*step@, and so on,
-- as long as they stay positive. If the predicate never holds, the clusters
-- from the last height tried are returned. A non-positive @step@ never
-- approaches zero, so in that case only @start@ is tried.
cutTreeBy :: Double  -- ^ start
          -> Double  -- ^ step
          -> ([Dendrogram a] -> Bool)
          -> Dendrogram a
          -> [Dendrogram a]
cutTreeBy start step fn tree = go start
  where
    go x
        | fn clusters = clusters
        -- without step > 0, x never decreases and the recursion would not
        -- terminate when fn is never satisfied
        | step > 0 && x - step > 0 = go $ x - step
        | otherwise = clusters
      where
        clusters = cutAt tree x