module Bio.Structure.Functions
  ( filterAtomsOfModel
  , chain, globalBond
  , residue
  , atom, localBond
  , renameChains
  ) where

import           Bio.Structure   (Atom (..), Bond (..), Chain (..),
                                  GlobalID (..), LocalID (..), Model (..),
                                  Residue (..), atoms, chains, globalBonds,
                                  localBonds, residues)
import           Control.Lens    (Traversal', each, (%~), (&))
import           Data.Map.Strict (Map)
import qualified Data.Map.Strict as M (fromList, (!?))
import qualified Bio.Utils.Map   as M ((!?!))
import           Data.Set        (Set)
import qualified Data.Set        as S (fromList, notMember, unions)
import           Data.Text       (Text)
import           Data.Vector     (Vector)
import qualified Data.Vector     as V (filter, fromList, length, toList, unzip)

-- | Traversal for every 'Chain' of the 'Model'.
--
--   Composes the 'chains' lens with 'each', visiting chains in vector order.
--
chain :: Traversal' Model Chain
chain = chains . each

-- | Traversal for every 'Bond' of the 'Model'.
--
--   Composes the 'globalBonds' lens with 'each'; the bonds' endpoints are 'GlobalID's.
--
globalBond :: Traversal' Model (Bond GlobalID)
globalBond = globalBonds . each

-- | Traversal for every 'Residue' of the 'Chain'.
--
--   Composes the 'residues' lens with 'each', visiting residues in vector order.
--
residue :: Traversal' Chain Residue
residue = residues . each

-- | Traversal for every 'Atom' of the 'Residue'.
--
--   Composes the 'atoms' lens with 'each', visiting atoms in vector order.
--
atom :: Traversal' Residue Atom
atom = atoms . each

-- | Traversal for every 'Bond' of the 'Residue'.
--
--   Composes the 'localBonds' lens with 'each'; the bonds' endpoints are 'LocalID's.
--
localBond :: Traversal' Residue (Bond LocalID)
localBond = localBonds . each

-- | Rename chains of a given model according to the given mapping
--   (old chain name -> new chain name).
--   A chain whose name is not a key of the mapping keeps its name unchanged.
--
renameChains :: Model -> Map Text Text -> Model
renameChains model mapping = model & chain %~ renameChain
  where
    -- Only the name is touched; the chain's residues are left as they are.
    renameChain :: Chain -> Chain
    renameChain ch =
      case mapping M.!? chainName ch of
        Just newName -> ch { chainName = newName }
        Nothing      -> ch

-- | Takes a predicate on 'Atom's of a 'Model' and returns a new 'Model' containing only
--   the atoms satisfying that predicate.
--
--   Every bond with a removed atom at either end is dropped as well. This covers the
--   local bonds inside residues and the global bonds in 'modelBonds'. The remaining
--   local bonds are renumbered to the atoms' new positions.
--
filterAtomsOfModel :: (Atom -> Bool) -> Model -> Model
filterAtomsOfModel p model@Model{..} =
  model { modelChains = newChains, modelBonds = newBonds }
  where
    -- The helpers below take the predicate of atoms to *remove*.
    removePred :: Atom -> Bool
    removePred = not . p

    (newChains, removedPerChain) =
      V.unzip $ fmap (removeAtomsFromChain removePred) modelChains

    -- Global IDs of every atom removed anywhere in the model.
    removedIds :: Set GlobalID
    removedIds = S.unions removedPerChain

    -- Keep a global bond only if neither endpoint was removed.
    newBonds :: Vector (Bond GlobalID)
    newBonds = V.filter survives modelBonds
      where
        survives (Bond l r _) = l `S.notMember` removedIds && r `S.notMember` removedIds

-- | Removes, from every residue of a 'Chain', the atoms satisfying the given predicate.
--
--   Returns the updated chain together with the 'GlobalID's of all removed atoms.
--   The chain's name is preserved.
--
removeAtomsFromChain :: (Atom -> Bool) -> Chain -> (Chain, Set GlobalID)
removeAtomsFromChain p ch@Chain{..} =
  (ch { chainResidues = newResidues }, S.unions removedPerResidue)
  where
    (newResidues, removedPerResidue) =
      V.unzip $ fmap (removeAtomsFromResidue p) chainResidues

-- | Removes from a 'Residue' every atom satisfying the given predicate.
--
--   Returns the updated residue together with the 'GlobalID's of the removed atoms.
--   Local bonds with a removed atom at either end are dropped. The 'LocalID's of the
--   remaining local bonds are renumbered to the atoms' new positions in 'resAtoms'.
--
removeAtomsFromResidue :: (Atom -> Bool) -> Residue -> (Residue, Set GlobalID)
removeAtomsFromResidue p r'@Residue{..} =
  (res, S.fromList $ V.toList $ fmap atomId removed)
  where
    (removed, kept, indsToDelete) = partitionAndInds resAtoms

    -- Deleted positions as a set: O(log n) membership per bond endpoint instead of a
    -- linear scan over a list.
    deleted :: Set Int
    deleted = S.fromList indsToDelete

    -- Old position of every kept atom -> its position in 'kept'. Only kept positions
    -- are ever looked up, because 'leaveBond' drops every bond touching a deleted atom
    -- before 'modifyBond' runs. Built in a single pass; the previous version counted
    -- the deleted indices below each position separately, which is quadratic.
    oldIndsToNew :: Map Int Int
    oldIndsToNew = M.fromList $ zip keptInds [0 ..]
      where
        keptInds = [i | i <- [0 .. V.length resAtoms - 1], i `S.notMember` deleted]

    newBonds :: Vector (Bond LocalID)
    newBonds = fmap modifyBond $ V.filter leaveBond resBonds

    res :: Residue
    res = r' { resAtoms = kept, resBonds = newBonds }

    -- Keep a local bond only if neither endpoint is a deleted atom.
    leaveBond :: Bond LocalID -> Bool
    leaveBond (Bond (LocalID l) (LocalID r) _) =
      l `S.notMember` deleted && r `S.notMember` deleted

    -- Renumber both endpoints of a kept bond; the bond's third field is preserved.
    -- NOTE(review): an endpoint outside [0 .. length resAtoms - 1] is absent from
    -- 'oldIndsToNew'; '!?!' (Bio.Utils.Map) presumably fails loudly then — confirm.
    modifyBond :: Bond LocalID -> Bond LocalID
    modifyBond (Bond (LocalID l) (LocalID r) t) =
      Bond (LocalID $ oldIndsToNew M.!?! l) (LocalID $ oldIndsToNew M.!?! r) t

    -- Splits atoms into (satisfying p, not satisfying p, positions of those satisfying p).
    -- All three results keep the original order.
    partitionAndInds :: Vector Atom -> (Vector Atom, Vector Atom, [Int])
    partitionAndInds = go 0 ([], [], []) . V.toList
      where
        go :: Int -> ([Atom], [Atom], [Int]) -> [Atom] -> (Vector Atom, Vector Atom, [Int])
        go _ (sat, notSat, inds) [] =
          (V.fromList $ reverse sat, V.fromList $ reverse notSat, reverse inds)
        go i (sat, notSat, inds) (x : xs)
          | p x       = go (i + 1) (x : sat, notSat, i : inds) xs
          | otherwise = go (i + 1) (sat, x : notSat, inds) xs