{-# LANGUAGE MultiParamTypeClasses #-} 
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE FunctionalDependencies #-}

module DCLabel.Core where

import Data.List (nub, sort, (\\))
import DCLabel.Lattice

-- | A principal appearing in labels, identified by its name.
newtype Principal = MkPrincipal
  { name :: String
  } deriving (Eq, Ord)

-- | Builds a principal from its name.
principal :: String -> Principal
principal = MkPrincipal

-- | A principal is shown as its name rendered by the 'String' 'Show'
-- instance, i.e. in double quotes.
instance Show Principal where
    show = show . name

-- | Reads a 'String' (in its 'Show' syntax) and wraps it as a principal,
-- mirroring the 'Show' instance.
instance Read Principal where
    readsPrec d s = [ (principal p, rest) | (p, rest) <- readsPrec d s ]

-- | Data type to represent disjunctions.
newtype Disj a = MkDisj { disj :: [a] } deriving (Eq, Ord)

-- | Shows a disjunction as its elements joined by " \\/ " and enclosed in
-- "[ " and " ]".
--
-- Each element is rendered with its own 'show', minus one pair of enclosing
-- double quotes when present, so principals and strings print without
-- quotes.  Elements whose 'show' output is not quoted (e.g. numbers) are
-- printed unchanged rather than truncated or crashing.
instance Show a => Show (Disj a) where
  show (MkDisj xs) = "[ " ++ showCat xs ++ " ]"
    where
      showCat []       = ""
      showCat [x]      = unquote (show x)
      showCat (x : ys) = unquote (show x) ++ " \\/ " ++ showCat ys

      -- Strip exactly one pair of enclosing double quotes, if present.
      -- The guard on a non-empty tail keeps 'last'/'init' total.
      unquote s@('"' : rest@(_ : _))
        | last rest == '"' = init rest
        | otherwise        = s
      unquote s = s

-- | Data type to represent conjunctions.
newtype Conj a = MkConj { conj :: [a] }
  deriving (Eq, Ord)

-- | Shows the conjuncts separated by " /\\ "; the empty conjunction is
-- shown as "None".
instance Show a => Show (Conj a) where
  show (MkConj xs) = case xs of
    []         -> "None"
    [x]        -> show x
    (x : rest) -> show x ++ " /\\ " ++ show (MkConj rest)

-- | Data type to represent labels, i.e. conjunctions of disjunctions.
--
-- 'MkLabelAll' is the special label that implies every other label (see
-- 'implies'); 'MkLabel' holds an explicit conjunction of disjunctions.
data Label a = MkLabelAll
             | MkLabel { label :: Conj (Disj a) }

instance Show a => Show (Label a) where
  show MkLabelAll  = "All"
  show (MkLabel l) = show l

-- | Labels are equal up to 'reduceLabel': two 'MkLabel's are equal when
-- their reduced forms coincide.  'MkLabelAll' is only equal to itself.
instance Ord a => Eq (Label a) where
  MkLabelAll == MkLabelAll = True
  MkLabelAll == _          = False
  _          == MkLabelAll = False
  l1         == l2         = label (reduceLabel l1) == label (reduceLabel l2)

-- | Ordering consistent with the 'Eq' instance: 'MkLabel's are compared by
-- their reduced forms, so @compare l1 l2 == EQ@ exactly when @l1 == l2@.
-- A derived (structural) ordering would violate this and misbehave in
-- ordered containers.  'MkLabelAll' sorts before every 'MkLabel', as the
-- derived instance did.
instance Ord a => Ord (Label a) where
  compare MkLabelAll MkLabelAll = EQ
  compare MkLabelAll _          = LT
  compare _          MkLabelAll = GT
  compare l1         l2         =
    compare (label (reduceLabel l1)) (label (reduceLabel l2))

-- | It takes two labels and makes the conjunction of them.
--
-- If either label is 'MkLabelAll' the result is 'MkLabelAll'; otherwise
-- the disjunctions of both labels are concatenated (no reduction is done).
and_label :: Label a -> Label a -> Label a
and_label MkLabelAll _          = MkLabelAll
and_label _          MkLabelAll = MkLabelAll
and_label (MkLabel c1) (MkLabel c2) = MkLabel (MkConj (conj c1 ++ conj c2))


{-|
It takes two labels and distributes the disjunction. For example,

>>> singleton "Alice" `or_label` ( "Bob" ./\. "Carla" )
[ Alice \/ Bob ] /\ [ Alice \/ Carla ]

If either label is empty the result is 'emptyLabel'.  'MkLabelAll' acts as
the identity: the other label is returned unchanged.  Empty disjunctions
inside either argument are skipped while distributing.
-}
or_label :: Ord a => Label a -> Label a -> Label a
or_label l1 l2
  | isEmptyLabel l2 || isEmptyLabel l1 = emptyLabel
or_label l1 MkLabelAll = l1
or_label MkLabelAll l2 = l2
or_label (MkLabel c1) (MkLabel c2) =
  MkLabel . MkConj $
    [ MkDisj (disj d1 ++ disj d2) | d1 <- nonEmpty c1, d2 <- nonEmpty c2 ]
  where
    -- Only non-empty disjunctions take part in the distribution.
    nonEmpty = filter (not . null . disj) . conj

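{-
   Why or_label skips empty disjunctions: they impose no constraint, so they
   must not take part in the distribution.  For example, with
   I1 = [A] /\ [] /\ [] and I2 = [B], distributing over every disjunction
   would give [A \/ B] /\ [B] /\ [B], which reduces to [B].  That is wrong:
   I1 is just [A], so the result should be [A \/ B].
   Alternative (unused) approach: normalise both arguments first:

               where l1' = nubLabel l1
                     l2' = nubLabel l2
-}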
               
-- | A label without any disjunctions or conjunctions.  It imposes no
-- constraint: every label implies it (see 'implies').
emptyLabel :: Label a
emptyLabel = MkLabel (MkConj { conj = [] })

-- | The label that implies every other label (see 'implies').
allLabel :: Label a
allLabel = MkLabelAll

-- | Predicate function that returns @True@ if the label corresponds to
-- the 'emptyLabel'.  That is, it is a 'MkLabel' all of whose disjunctions
-- are empty, including the label with no disjunctions at all.
isEmptyLabel :: Label a -> Bool
isEmptyLabel MkLabelAll  = False
isEmptyLabel (MkLabel c) = all (null . disj) (conj c)

-- | Predicate function that returns @True@ if the label corresponds to
-- the 'allLabel'.
isAllLabel :: Label a -> Bool
isAllLabel MkLabelAll  = True
isAllLabel (MkLabel _) = False


{-|
   It determines if a conjunction of disjunctions (i.e. a label) implies (in
   the logical sense) a disjunction @d@.  In other words, it checks whether
   d1 /\ ... /\ dn => d.  This holds exactly when some conjunct di is a
   sub-disjunction of d (every element of di occurs in d).

   Empty disjunctions carry no constraint in this module: 'nubLabel' and
   'or_label' drop them, and 'isEmptyLabel' ignores them.  Accordingly:

   * empty conjuncts of the label are skipped.  Otherwise an unreduced label
     such as [ ] /\ ... would be treated as implying every disjunction.

   * an empty target disjunction is implied by every label, matching
     'implies' on a label made only of empty disjunctions.
-}
impliesDisj :: Ord a => Label a -> Disj a -> Bool
impliesDisj MkLabelAll _ = True   -- 'MkLabelAll' implies every disjunction
impliesDisj (MkLabel c) d
  | null (disj d) = True
  | otherwise     = any (`subsetOf` disj d) nonEmptyConjuncts
  where
    nonEmptyConjuncts = [ disj d1 | d1 <- conj c, not (null (disj d1)) ]
    subsetOf xs ys    = all (`elem` ys) xs

{-|
   It determines if a label implies (in the logical sense) another label,
   i.e. whether d1 /\ ... /\ dn => d1' /\ ... /\ dm'.  This holds when the
   first label implies every disjunction of the reduced second label.
   'MkLabelAll' implies every label and is implied only by itself.
-}
implies :: Ord a => Label a -> Label a -> Bool
implies MkLabelAll _          = True
implies _          MkLabelAll = False
implies l1 l2 = all (impliesDisj l1) (conj (label (reduceLabel l2)))


-- | It removes extraneous disjunctions from a given label.
--
-- The label is first normalised with 'nubLabel'.  Then every disjunction
-- that is implied by a different disjunction of the same label (i.e. is a
-- superset of it) is dropped.  'MkLabelAll' is returned unchanged.
reduceLabel :: Ord a => Label a -> Label a
reduceLabel MkLabelAll = MkLabelAll
reduceLabel l = MkLabel (MkConj (filter (`notElem` extraneous) ds))
  where
    -- Distinct, sorted, non-empty disjunctions.
    ds = conj (label (nubLabel l))
    extraneous =
      [ d2
      | d1 <- ds
      , d2 <- ds
      , d1 /= d2
      , impliesDisj (MkLabel (MkConj [d1])) d2
      ]

-- | It removes repeated principals from a given label.
--
-- Principals inside each disjunction are deduplicated and sorted, and empty
-- disjunctions are dropped.  The resulting disjunctions are themselves
-- deduplicated and sorted.  'MkLabelAll' is returned unchanged.
nubLabel :: Ord a => Label a -> Label a
nubLabel MkLabelAll  = MkLabelAll
nubLabel (MkLabel c) = MkLabel (MkConj (sortUniq (map normalise nonEmpty)))
  where
    nonEmpty = filter (not . null . disj) (conj c)
    normalise d = MkDisj (sortUniq (disj d))
    sortUniq :: Ord b => [b] -> [b]
    sortUniq xs = sort (nub xs)


-- | It represents privilege, wrapping the 'DCLabel' it grants.
data Priv a = MkPriv
  { priv :: DCLabel a
  } deriving Eq

instance Show a => Show (Priv a) where
  show (MkPriv dcl) = "Privilege: " ++ show dcl

-- | It represents disjunction category labels: a secrecy label paired with
-- an integrity label.
data DCLabel a = MkDCLabel
  { secrecy   :: Label a
  , integrity :: Label a
  } deriving (Eq, Ord)

instance Show a => Show (DCLabel a) where
  show dcl = concat ["S=", show (secrecy dcl), " I=", show (integrity dcl)]

-- | It computes the canonical form of a 'DCLabel' by applying 'reduceLabel'
-- to both its secrecy and its integrity component.
reduceDC :: Ord a => DCLabel a -> DCLabel a
reduceDC l = MkDCLabel (reduceLabel (secrecy l)) (reduceLabel (integrity l))

----------------------------------------
---- DC Labels are elements of a lattice
----------------------------------------

-- | It defines that 'DCLabel' are elements of a lattice.
--
-- 'top' pairs 'MkLabelAll' secrecy with 'emptyLabel' integrity, and
-- 'bottom' is the reverse.  Information may flow from @l1@ to @l2@ when the
-- secrecy of @l2@ implies that of @l1@ and the integrity of @l1@ implies
-- that of @l2@, both compared in reduced form.
instance Ord a => Lattice (DCLabel a) where
  top    = MkDCLabel MkLabelAll emptyLabel
  bottom = MkDCLabel emptyLabel MkLabelAll

  canflowto l1 l2 =
    (secrecy r2 `implies` secrecy r1) && (integrity r1 `implies` integrity r2)
    where
      r1 = reduceDC l1
      r2 = reduceDC l2

  -- Join: conjoin the secrecy labels, disjoin the integrity labels.
  join l1 l2 = reduceDC $ MkDCLabel
    (secrecy l1 `and_label` secrecy l2)
    (integrity l1 `or_label` integrity l2)

  -- Meet: disjoin the secrecy labels, conjoin the integrity labels.
  meet l1 l2 = reduceDC $ MkDCLabel
    (secrecy l1 `or_label` secrecy l2)
    (integrity l1 `and_label` integrity l2)
  

-- | It checks if a 'DCLabel' can flow to another 'DCLabel' given some
-- privilege.
--
-- The privilege's secrecy is conjoined to the destination's secrecy, and
-- the privilege's integrity is conjoined to the source's integrity.  The
-- plain 'canflowto' check is then applied.
canflowto_p :: Lattice (DCLabel a) => Priv a -> DCLabel a -> DCLabel a -> Bool
canflowto_p p dcl1 dcl2 = canflowto source target
  where
    granted = priv p
    source  = newDC (secrecy dcl1)
                    (integrity granted `and_label` integrity dcl1)
    target  = newDC (secrecy granted `and_label` secrecy dcl2)
                    (integrity dcl2)

--
--Tests
--

a :: Disj String
a = MkDisj { disj = ["A"] }

aorb, cord, eorf, gorh :: Disj String
aorb = MkDisj { disj = ["A", "B"] }
cord = MkDisj { disj = ["C", "D"] }
eorf = MkDisj { disj = ["E", "F"] }
gorh = MkDisj { disj = ["G", "H"] }

-- The repeated "A" exercises the deduplication done by 'nubLabel'.
aorborc :: Disj String
aorborc = MkDisj { disj = ["A", "B", "C", "A"] }

conjaorb :: Label String
conjaorb = MkLabel $ MkConj { conj = [aorb] }

conjaorborc :: Label String
conjaorborc = MkLabel $ MkConj { conj = [aorborc, aorborc, aorb] }

-- Expected values.  conjaorborc has aorb itself as a conjunct, so it
-- implies aorb (t1) and conjaorb (t4).
t1, t2, t3, t4 :: Bool
t1 = impliesDisj conjaorborc aorb -- True
t2 = impliesDisj conjaorb aorborc -- True
t3 = implies conjaorb conjaorborc -- True
t4 = implies conjaorborc conjaorb -- True

aorbANDcord, eorfANDgorh :: Label String
aorbANDcord = MkLabel $ MkConj { conj = [aorb, cord] }
eorfANDgorh = MkLabel $ MkConj { conj = [eorf, gorh] }

-- Expected:
-- [ A \/ B \/ E \/ F ] /\ [ A \/ B \/ G \/ H ] /\ [ C \/ D \/ E \/ F ] /\ [ C \/ D \/ G \/ H ]
t5 :: Label String
t5 = aorbANDcord `or_label` eorfANDgorh





{- (S=[ A ] /\ [ C \/ C ] /\ [  ] I=[ A ] /\ [ A \/ A ] /\ [  ],S=[ A ] /\ [  ] I=[ B ] /\ [  ] /\ [  ]) -}

-- 


---
---
--- Small DSL for working with labels. It is a simplified version of what David did
---
---

-- | Class that allows the creation of labels which are just singletons.
class Singleton a b | a -> b where
  -- | It creates a singleton label, i.e. a label with only one principal.
  singleton :: a -> Label b

-- | A principal @p@ becomes the label with the single disjunction @[ p ]@.
instance Singleton Principal Principal where
  singleton p = MkLabel (MkConj [MkDisj [p]])

-- | A string is first turned into a 'Principal' with 'principal'.
instance Singleton String Principal where
  singleton = singleton . principal


-- Disjunction binds tighter than conjunction, so that
-- "A" .\/. "B" ./\. "C" parses as ("A" .\/. "B") ./\. "C".
infixl 7 .\/.
infixl 6 ./\.

-- | Class used to create disjunctions.  Every instance yields a
-- 'Label' of 'Principal's; labels are combined with 'or_label'.
class DisjunctionOf a b c | a b -> c  where
  (.\/.) :: a -> b -> c

instance DisjunctionOf Principal Principal (Label Principal) where
  p1 .\/. p2 = MkLabel (MkConj [MkDisj [p1, p2]])

instance DisjunctionOf Principal (Label Principal) (Label Principal) where
  p .\/. l = singleton p `or_label` l

instance DisjunctionOf (Label Principal) Principal (Label Principal) where
  l .\/. p = singleton p `or_label` l

instance DisjunctionOf (Label Principal) (Label Principal) (Label Principal) where
  (.\/.) = or_label

instance DisjunctionOf String String (Label Principal) where
  s1 .\/. s2 = singleton s1 `or_label` singleton s2

instance DisjunctionOf String (Label Principal) (Label Principal) where
  s .\/. l = singleton s `or_label` l

instance DisjunctionOf (Label Principal) String (Label Principal) where
  l .\/. s = singleton s `or_label` l


-- | Class used to create conjunctions.  Every instance yields a 'Label' of
-- 'Principal's; labels are combined with 'and_label'.
class ConjunctionOf a b c | a b -> c where
   (./\.) :: a -> b -> c

instance ConjunctionOf Principal Principal (Label Principal) where
   p1 ./\. p2 = MkLabel (MkConj [MkDisj [p1], MkDisj [p2]])

instance ConjunctionOf Principal (Label Principal) (Label Principal) where
   p ./\. l = singleton p `and_label` l

instance ConjunctionOf (Label Principal) Principal (Label Principal) where
   l ./\. p = singleton p `and_label` l

instance ConjunctionOf (Label Principal) (Label Principal) (Label Principal) where
   (./\.) = and_label

-- Instances using strings instead of principals; each string is first
-- turned into a singleton label with 'singleton'.
instance ConjunctionOf String String (Label Principal) where
   s1 ./\. s2 = singleton s1 `and_label` singleton s2

instance ConjunctionOf String (Label Principal) (Label Principal) where
   s ./\. l = singleton s `and_label` l

instance ConjunctionOf (Label Principal) String (Label Principal) where
   l ./\. s = singleton s `and_label` l


-- | Shorthand for 'emptyLabel'.
--
-- NOTE(review): this name clashes with the Prelude's 'Semigroup' operator
-- '<>' (exported by the Prelude since base-4.11).  Unqualified uses of '<>'
-- in a module that imports this one are therefore ambiguous; import one of
-- them qualified or hide it.
(<>) :: Label a
(<>) = emptyLabel

---
---
--- Generic API
---
---

-- | It checks whether a principal appears in some disjunction of a label.
-- The label is converted with 'labelToList', which does not support
-- 'MkLabelAll'.
isPrincipal :: Principal -> Label Principal -> Bool
isPrincipal p l = any (p `elem`) (labelToList l)

-- | It converts a label into a list of disjunctions, each given as the list
-- of its principals.
--
-- 'MkLabelAll' has no list representation, so it is rejected with an
-- explicit error rather than a non-exhaustive pattern-match failure.
labelToList :: Label Principal -> [ [Principal] ]
labelToList MkLabelAll  =
  error "DCLabel.Core.labelToList: MkLabelAll has no list representation"
labelToList (MkLabel l) = map disj (conj l)

-- | It extracts the label corresponding to the secrecy component of a
-- 'DCLabel'.
secrecyDC :: DCLabel Principal -> Label Principal
secrecyDC = secrecy

-- | It extracts the label corresponding to the integrity component of a
-- 'DCLabel'.
integrityDC :: DCLabel Principal -> Label Principal
integrityDC = integrity

-- | It creates a 'DCLabel' from two labels: the first is the secrecy
-- component, the second the integrity component.
newDC :: Label a -> Label a -> DCLabel a
newDC l1 l2 = MkDCLabel l1 l2

-- | It checks if a privilege is owned, i.e. whether the label carried by
-- the privilege can flow to the given 'DCLabel'.
own :: Ord a => Priv a -> DCLabel a -> Bool
own p l = canflowto (priv p) l

-- | It returns the privilege in the form of a label.
privToLabel :: Priv a -> DCLabel a
privToLabel (MkPriv l) = l

-- | It creates privileges.  This function must be used only by trustworthy
-- code.
createPriv :: DCLabel a -> Priv a
createPriv l = MkPriv l