module Bio.Structure.Functions
  ( filterAtomsOfModel
  , chain, globalBond
  , residue
  , atom, localBond
  , renameChains
  ) where

import           Bio.Structure   (Atom (..), Bond (..), Chain (..),
                                  GlobalID (..), LocalID (..), Model (..),
                                  Residue (..), atoms, chains, globalBonds,
                                  localBonds, residues)
import           Control.Lens    (Traversal', each, (%~), (&))
import           Data.Map.Strict (Map)
import qualified Data.Map.Strict as M (fromList, (!?))
import qualified Bio.Utils.Map   as M ((!?!))
import           Data.Set        (Set)
import qualified Data.Set        as S (fromList, notMember, unions)
import           Data.Text       (Text)
import           Data.Vector     (Vector)
import qualified Data.Vector     as V (filter, fromList, length, toList, unzip)

-- | Traversal for every 'Chain' of the 'Model'.
--
--   Built by composing the 'chains' lens with 'each', so it visits the
--   elements of the model's chain vector one by one.
--
chain :: Traversal' Model Chain
chain = chains . each

-- | Traversal for every global 'Bond' of the 'Model'.
--
--   Built by composing the 'globalBonds' lens with 'each', so it visits the
--   elements of the model's bond vector one by one.
--
globalBond :: Traversal' Model (Bond GlobalID)
globalBond = globalBonds . each

-- | Traversal for every 'Residue' of the 'Chain'.
--
--   Built by composing the 'residues' lens with 'each', so it visits the
--   elements of the chain's residue vector one by one.
--
residue :: Traversal' Chain Residue
residue = residues . each

-- | Traversal for every 'Atom' of the 'Residue'.
--
--   Built by composing the 'atoms' lens with 'each', so it visits the
--   elements of the residue's atom vector one by one.
--
atom :: Traversal' Residue Atom
atom = atoms . each

-- | Traversal for every local 'Bond' of the 'Residue'.
--
--   Built by composing the 'localBonds' lens with 'each', so it visits the
--   elements of the residue's bond vector one by one.
--
localBond :: Traversal' Residue (Bond LocalID)
localBond = localBonds . each

-- | Rename chains of a given model according to the given mapping.
--   The mapping goes from current chain names to new chain names.
--   If a chain's name is not a key of the mapping, the chain is left unchanged.
--
renameChains :: Model -> Map Text Text -> Model
renameChains model mapping = model & chain %~ renameChain
  where
    -- Looks the current name up in the mapping and falls back to it
    -- when the mapping has no entry.
    renameChain :: Chain -> Chain
    renameChain ch = ch { chainName = maybe oldName id (mapping M.!? oldName) }
      where
        oldName = chainName ch

-- | Takes predicate on 'Atom's of 'Model' and returns new 'Model' containing only atoms
--   satisfying given predicate.
--
--   Atoms failing the predicate are removed from every residue of every chain.
--   A global bond is kept only if neither of its two endpoints refers to a
--   removed atom.
--
filterAtomsOfModel :: (Atom -> Bool) -> Model -> Model
filterAtomsOfModel p model = model { modelChains = newChains, modelBonds = newBonds }
  where
    -- 'removeAtomsFromChain' expects a predicate selecting atoms to REMOVE,
    -- hence the negation of the "keep" predicate.
    (newChains, removedPerChain) =
      V.unzip $ fmap (removeAtomsFromChain (not . p)) (modelChains model)

    -- Global identifiers of every atom removed from any chain.
    removed :: Set GlobalID
    removed = S.unions $ V.toList removedPerChain

    newBonds :: Vector (Bond GlobalID)
    newBonds = V.filter bondSurvives (modelBonds model)

    bondSurvives :: Bond GlobalID -> Bool
    bondSurvives (Bond l r _) = l `S.notMember` removed && r `S.notMember` removed

-- | Removes from every residue of the 'Chain' the atoms satisfying the predicate.
--
--   Returns the updated chain (same name, processed residues) together with
--   the union of the 'GlobalID' sets reported by 'removeAtomsFromResidue'
--   for each residue.
--
removeAtomsFromChain :: (Atom -> Bool) -> Chain -> (Chain, Set GlobalID)
removeAtomsFromChain p ch =
  (ch { chainResidues = newResidues }, S.unions $ V.toList removedPerResidue)
  where
    (newResidues, removedPerResidue) =
      V.unzip $ fmap (removeAtomsFromResidue p) (chainResidues ch)

-- | Removes from the 'Residue' the atoms satisfying the predicate.
--
--   Local bonds with at least one endpoint at a removed atom are dropped.
--   The endpoints of the remaining bonds are re-indexed so that they refer
--   to positions in the new (shorter) atom vector. Atom order is preserved.
--
--   Returns the updated residue and the set of 'atomId's of the removed atoms.
--
removeAtomsFromResidue :: (Atom -> Bool) -> Residue -> (Residue, Set GlobalID)
removeAtomsFromResidue p r' = (res, S.fromList $ fmap (atomId . snd) removed)
  where
    -- Every atom paired with its position in the original atom vector.
    indexed :: [(Int, Atom)]
    indexed = zip [0 ..] $ V.toList (resAtoms r')

    -- One order-preserving pass splitting atoms into removed and kept ones.
    removed, kept :: [(Int, Atom)]
    (removed, kept) = foldr split ([], []) indexed

    split :: (Int, Atom) -> ([(Int, Atom)], [(Int, Atom)]) -> ([(Int, Atom)], [(Int, Atom)])
    split ia@(_, a) ~(del, keep)
      | p a       = (ia : del, keep)
      | otherwise = (del, ia : keep)

    -- Original positions of removed atoms; a set makes the per-bond check cheap.
    removedInds :: Set Int
    removedInds = S.fromList $ fmap fst removed

    -- Maps the original position of every kept atom to its position in the
    -- new atom vector (i.e. its rank among the kept atoms).
    oldIndsToNew :: Map Int Int
    oldIndsToNew = M.fromList $ zip (fmap fst kept) [0 ..]

    res :: Residue
    res = r' { resAtoms = V.fromList $ fmap snd kept
             , resBonds = fmap modifyBond $ V.filter leaveBond (resBonds r')
             }

    -- A bond survives only if neither endpoint was removed.
    leaveBond :: Bond LocalID -> Bool
    leaveBond (Bond (LocalID l) (LocalID r) _) =
      l `S.notMember` removedInds && r `S.notMember` removedInds

    -- Only applied to bonds accepted by 'leaveBond', so both endpoints are
    -- expected to be keys of 'oldIndsToNew'.
    -- NOTE(review): a bond whose index lies outside the atom vector is not a
    -- key; '!?!' presumably fails with a call stack in that case — confirm.
    modifyBond :: Bond LocalID -> Bond LocalID
    modifyBond (Bond (LocalID l) (LocalID r) t) =
      Bond (LocalID $ oldIndsToNew M.!?! l) (LocalID $ oldIndsToNew M.!?! r) t