{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE NoMonomorphismRestriction  #-}
{-# LANGUAGE UndecidableInstances       #-}
module Agda.TypeChecking.SizedTypes.Syntax where
import Prelude hiding ( null )
import Data.Foldable (Foldable)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Set (Set)
import qualified Data.Set as Set
import Data.Traversable (Traversable)
import Agda.TypeChecking.SizedTypes.Utils
import Agda.Utils.Functor
import Agda.Utils.Null
import Agda.Utils.Pretty
-- | An integer offset, wrapped in a newtype.
--
--   'Num' and 'Enum' are derived from the underlying 'Int', so numeric
--   literals and arithmetic can be used on offsets directly.  Nothing in
--   this declaration rules out negative values; 'validOffset' is the check
--   for non-negativity.
newtype Offset = O Int
  deriving (Eq, Ord, Num, Enum)
-- | An offset is shown as the bare integer, without the 'O' constructor.
instance Show Offset where
  show (O k) = show k
-- | An offset is pretty-printed exactly like the wrapped 'Int'.
instance Pretty Offset where
  pretty (O k) = pretty k
-- | The meet of two offsets is the smaller of the two.
instance MeetSemiLattice Offset where
  meet a b = min a b
-- | Adding two offsets unwraps them, adds the integers with 'plus', and
--   wraps the result again.
instance Plus Offset Offset Offset where
  plus (O a) (O b) = O $ plus a b
-- | Identifier type used for the @rigid@ parameter in the 'SizeExpr' and
--   'Constraint' aliases.  It wraps a 'String'.
newtype Rigid  = RigidId { rigidId :: String }
  deriving (Eq, Ord)
-- | Shown as the constructor name followed by the quoted identifier string.
instance Show Rigid where
  show (RigidId name) = showString "RigidId " (show name)
-- | Pretty-printed as the plain identifier string.
instance Pretty Rigid where
  pretty r = text (rigidId r)
-- | Identifier type used for the @flex@ parameter in the 'SizeExpr' and
--   'Constraint' aliases.  It wraps a 'String'.
newtype Flex   = FlexId { flexId :: String }
  deriving (Eq, Ord)
-- | Shown as the constructor name followed by the quoted identifier string.
instance Show Flex where
  show (FlexId name) = showString "FlexId " (show name)
-- | Pretty-printed as the plain identifier string.
instance Pretty Flex where
  pretty x = text (flexId x)
-- | Size expressions, parameterised over the types of rigid and flexible
--   identifiers.
--
--   NB: the record selectors are partial.  'offset' is undefined on
--   'Infty'; 'rigid' and 'flex' are each defined on one constructor only.
data SizeExpr' rigid flex
  = Const { offset :: Offset }                   -- ^ A constant offset.
  | Rigid { rigid  :: rigid, offset :: Offset }  -- ^ A rigid identifier plus an offset.
  | Infty                                        -- ^ Infinity; it carries no offset.
  | Flex  { flex   :: flex, offset :: Offset }   -- ^ A flexible identifier plus an offset.
    deriving (Show, Eq, Ord, Functor, Foldable, Traversable)
-- | Size expressions over the concrete identifier types 'Rigid' and 'Flex'.
type SizeExpr = SizeExpr' Rigid Flex
-- | Comparison operators that can appear in a constraint.
data Cmp
  = Lt  -- ^ Strictly less than; pretty-printed as @<@.
  | Le  -- ^ Less than or equal; pretty-printed as @≤@.
  deriving (Show, Eq, Bounded, Enum)
-- | Composing two operators yields the smaller one in the 'Ord' 'Cmp'
--   order; 'Le' is the unit of composition.
instance Dioid Cmp where
  compose c1 c2 = min c1 c2
  unitCompose   = Le
-- | The operators are ordered @'Lt' < 'Le'@.  Only '<=' is given; the
--   remaining 'Ord' methods come from the class defaults.
instance Ord Cmp where
  -- 'Lt' is the least element, so it is below everything.
  Lt <= _ = True
  -- 'Le' is the greatest element, so it is only below itself.
  Le <= c = c == Le
-- | The meet of two operators is the smaller one in the 'Ord' 'Cmp' order.
instance MeetSemiLattice Cmp where
  meet c1 c2 = min c1 c2
-- | 'Le' is the 'top' element; it is also the greatest value under the
--   'Ord' 'Cmp' instance.
instance Top Cmp where
  top = Le
-- | A constraint @leftExpr cmp rightExpr@ relating two size expressions.
data Constraint' rigid flex = Constraint
  { leftExpr  :: SizeExpr' rigid flex  -- ^ Left-hand side.
  , cmp       :: Cmp                   -- ^ Comparison operator.
  , rightExpr :: SizeExpr' rigid flex  -- ^ Right-hand side.
  }
  deriving (Show, Functor, Foldable, Traversable)
-- | Constraints over the concrete identifier types 'Rigid' and 'Flex'.
type Constraint = Constraint' Rigid Flex
-- | Polarity tag that can be attached to a flexible identifier (see
--   'Polarities').
--   NOTE(review): presumably selects whether the least or the greatest
--   solution is wanted for that identifier — confirm against the code that
--   consumes it.
data Polarity = Least | Greatest
  deriving (Eq, Ord)
-- | Pairs a polarity with the identifier it is assigned to.
data PolarityAssignment flex = PolarityAssignment Polarity flex
-- | Finite map from identifiers to their polarity.
type Polarities flex = Map flex Polarity
-- | The polarity map without any entries.  Looking up any identifier in it
--   with 'getPolarity' yields that function's default, 'Least'.
emptyPolarities :: Polarities flex
emptyPolarities = Map.empty
-- | Build a polarity map from a list of assignments.  The map is built with
--   'Map.fromList', so when an identifier occurs more than once the last
--   assignment in the list wins.
polaritiesFromAssignments :: Ord flex => [PolarityAssignment flex] -> Polarities flex
polaritiesFromAssignments assignments =
  Map.fromList [ (x, p) | PolarityAssignment p x <- assignments ]
-- | Look up the polarity of an identifier.  Identifiers without an entry in
--   the map default to 'Least'.
getPolarity :: Ord flex => Polarities flex -> flex -> Polarity
getPolarity pols x =
  case Map.lookup x pols of
    Just p  -> p
    Nothing -> Least
-- | A solution maps flexible identifiers to size expressions.
newtype Solution rigid flex = Solution { theSolution :: Map flex (SizeExpr' rigid flex) }
  deriving (Show, Null)
-- | A solution is rendered as a list of bindings @x := e@, one per map
--   entry, in the key order produced by 'Map.toList'.
instance (Pretty r, Pretty f) => Pretty (Solution r f) where
  pretty (Solution sol) =
    prettyList [ pretty x <+> ":=" <+> pretty e | (x, e) <- Map.toList sol ]
-- | The solution that binds no identifier.
emptySolution :: Solution r f
emptySolution = Solution Map.empty
-- | Types @a@ to which a 'Solution' over @r@ and @f@ can be applied.
class Substitute r f a where
  -- | Apply the solution to a value, yielding a value of the same type.
  subst :: Solution r f -> a -> a
-- | Only 'Flex' expressions are affected.  If the identifier is bound in the
--   solution, the bound expression replaces it and the original offset is
--   added on with 'plus'.  Everything else is returned unchanged.
instance Ord f => Substitute r f (SizeExpr' r f) where
  subst (Solution sol) e
    | Flex x n <- e
    , Just e'  <- Map.lookup x sol = e' `plus` n
    | otherwise                    = e
-- | Substitute in both sides of the constraint; the comparison operator is
--   kept as it is.
instance Ord f => Substitute r f (Constraint' r f) where
  subst sol (Constraint lhs op rhs) = Constraint (apply lhs) op (apply rhs)
    where
      apply = subst sol
-- | Substitute in every element of the list.
instance Substitute r f a => Substitute r f [a] where
  subst sol xs = map (subst sol) xs
-- | Substitute in every value of the map; the keys are left untouched.
instance Substitute r f a => Substitute r f (Map k a) where
  subst sol m = fmap (subst sol) m
-- | Apply one solution to the right-hand sides of another.  The set of
--   bound identifiers of the second solution does not change.
instance Ord f => Substitute r f (Solution r f) where
  subst s (Solution m) = Solution (subst s m)
-- | Add an offset to a size expression.  The offset is added to the one the
--   expression already carries; 'Infty' absorbs any offset.
instance Plus (SizeExpr' r f) Offset (SizeExpr' r f) where
  plus (Const   n) m = Const   (n + m)
  plus (Rigid i n) m = Rigid i (n + m)
  plus (Flex  x n) m = Flex  x (n + m)
  plus Infty       _ = Infty
-- | A constraint transformation: it either fails with an error message or
--   returns a list of constraints.
type CTrans r f = Constraint' r f -> Either String [Constraint' r f]
-- | Simplify a single constraint by case analysis on the shape of its
--   right-hand side and then its left-hand side.
--
--   The result is
--
--   * @Right []@ when the constraint is decided to hold here,
--
--   * @Right [c']@ with a single rewritten constraint when it is kept, and
--
--   * @Left msg@ when it is found to be inconsistent.
--
--   Constraints that involve a rigid identifier but no flexible one, and
--   that are not decided here, are handed to the @test@ argument, whose
--   result is returned as is.  Constraints that involve a flexible
--   identifier never reach @test@.
simplify1 :: (Pretty f, Pretty r, Eq r) => CTrans r f -> CTrans r f
simplify1 test c = do
  -- Error result shared by all inconsistent cases; it quotes the original
  -- constraint.
  let err = Left $ "size constraint " ++ prettyShow c ++ " is inconsistent"
  case c of
    -- Right-hand side is Infty.
    -- Anything ≤ ∞ is dropped.
    Constraint a           Le  Infty -> return []
    Constraint Const{}     Lt  Infty -> return []
    Constraint Infty       Lt  Infty -> err
    -- For a variable on the left of < ∞ the offset is reset to 0.
    Constraint (Rigid i n) Lt  Infty -> test $ Constraint (Rigid i 0) Lt Infty
    Constraint a@Flex{}    Lt  Infty -> return [c { leftExpr = a { offset = 0 }}]
    -- Right-hand side is Const.
    Constraint (Const n)   cmp (Const m) -> if compareOffset n cmp m then return [] else err
    Constraint Infty       cmp  Const{}  -> err
    -- A variable against a constant: if the offsets alone already violate
    -- the comparison the constraint is rejected (presumably because
    -- variables are taken to be non-negative — TODO confirm).  Otherwise
    -- the left offset is moved to the right and a strict comparison is
    -- turned into a non-strict one by subtracting 1.
    Constraint (Rigid i n) cmp (Const m) ->
      if compareOffset n cmp m then
        test (Constraint (Rigid i 0) Le (Const (m - n - ifLe cmp 0 1)))
       else err
    Constraint (Flex x n)  cmp (Const m) ->
      if compareOffset n cmp m
       then return [Constraint (Flex x 0) Le (Const (m - n - ifLe cmp 0 1))]
       else err
    -- Right-hand side is Rigid.
    Constraint Infty cmp Rigid{} -> err
    -- A constant against a rigid variable: dropped when the offsets alone
    -- satisfy the comparison; otherwise the right offset is moved to the
    -- left and the result is passed to @test@.
    Constraint (Const m) cmp (Rigid i n) ->
      if compareOffset m cmp n then return []
      else test (Constraint (Const $ m - n) cmp (Rigid i 0))
    -- The same rigid variable on both sides: only the offsets matter.
    Constraint (Rigid j m) cmp (Rigid i n) | i == j ->
      if compareOffset m cmp n then return [] else err
    -- Two different rigid variables: left entirely to @test@.
    Constraint (Rigid j m) cmp (Rigid i n) -> test c
    -- Flexible against rigid: rewritten to a ≤ constraint in which one side
    -- has offset 0; a strict comparison is absorbed by adjusting the
    -- remaining offset by 1.
    Constraint (Flex x m)  cmp (Rigid i n) ->
      if compareOffset m cmp n
       then return [Constraint (Flex x 0) Le (Rigid i (n - m - ifLe cmp 0 1))]
       else return [Constraint (Flex x $ m - n + ifLe cmp 0 1) Le (Rigid i 0)]
    -- Right-hand side is Flex.
    -- ∞ ≤ x + n is kept with the offset reset to 0; ∞ < x + n is rejected.
    Constraint Infty Le (Flex x n) -> return [Constraint Infty Le (Flex x 0)]
    Constraint Infty Lt (Flex x n) -> err
    -- A constant against a flexible variable: dropped when the offsets
    -- alone satisfy the comparison; otherwise rewritten to a ≤ constraint
    -- with offset 0 on the right.
    Constraint (Const m) cmp (Flex x n) ->
      if compareOffset m cmp n then return []
      else return [Constraint (Const $ m - n + ifLe cmp 0 1) Le (Flex x 0)]
    -- Variable against flexible variable: the comparison operator is kept
    -- and the smaller offset is subtracted from both sides, so one side
    -- ends up with offset 0.
    Constraint (Rigid i m) cmp (Flex x n) ->
      if compareOffset m cmp n
      then return [Constraint (Rigid i 0) cmp (Flex x $ n - m)]
      else return [Constraint (Rigid i $ m - n) cmp (Flex x 0)]
    Constraint (Flex y m) cmp (Flex x n) ->
      if compareOffset m cmp n
      then return [Constraint (Flex y 0) cmp (Flex x $ n - m)]
      else return [Constraint (Flex y $ m - n) cmp (Flex x 0)]
-- | @ifLe cmp a b@ selects @a@ when @cmp@ is 'Le' and @b@ when it is 'Lt'.
ifLe :: Cmp -> a -> a -> a
ifLe c onLe onLt = case c of
  Le -> onLe
  Lt -> onLt
-- | Interpret a comparison operator on two offsets: 'Le' as @<=@ and 'Lt'
--   as @<@.
compareOffset :: Offset -> Cmp -> Offset -> Bool
compareOffset n cmp m = case cmp of
  Le -> n <= m
  Lt -> n <  m
-- | Constants print as their offset and 'Infty' as @∞@.  An identifier
--   prints on its own when its offset is 0, and as @x+n@ otherwise.
instance (Pretty r, Pretty f) => Pretty (SizeExpr' r f) where
  pretty e = case e of
      Const n   -> pretty n
      Infty     -> "∞"
      Rigid i n -> withOffset (pretty i) n
      Flex  x n -> withOffset (pretty x) n
    where
      -- Append the offset to an already rendered identifier, omitting a
      -- zero offset.
      withOffset doc 0 = doc
      withOffset doc n = doc <> text ("+" ++ show n)
-- | 'Least' prints as @-@ and 'Greatest' as @+@.
instance Pretty Polarity where
  pretty pol = case pol of
    Least    -> "-"
    Greatest -> "+"
-- | Printed as the polarity sign immediately followed by the identifier.
instance Pretty flex => Pretty (PolarityAssignment flex) where
  pretty (PolarityAssignment p x) = pretty p <> pretty x
-- | 'Le' prints as @≤@ and 'Lt' as @<@.
instance Pretty Cmp where
  pretty op = case op of
    Le -> "≤"
    Lt -> "<"
-- | Printed infix: left-hand side, operator, right-hand side, separated by
--   spaces.
instance (Pretty r, Pretty f) => Pretty (Constraint' r f) where
  pretty (Constraint lhs op rhs) = pretty lhs <+> pretty op <+> pretty rhs
-- | Types whose offset can be checked for validity.
class ValidOffset a where
  -- | 'True' if the offset is valid.  For the instances in this module
  --   that means it is non-negative.
  validOffset :: a -> Bool
-- | An offset is valid when it is not negative.
instance ValidOffset Offset where
  validOffset n = n >= 0
-- | 'Infty' is always valid; every other expression is valid exactly when
--   the offset it carries is.
instance ValidOffset (SizeExpr' r f) where
  validOffset Infty       = True
  validOffset (Const   n) = validOffset n
  validOffset (Rigid _ n) = validOffset n
  validOffset (Flex  _ n) = validOffset n
-- | Types whose offset can be truncated.
class TruncateOffset a where
  -- | Truncate the offset.  For the instances in this module a negative
  --   offset becomes 0 and a non-negative one is left unchanged.
  truncateOffset :: a -> a
-- | Negative offsets are clamped to 0; non-negative ones are unchanged.
instance TruncateOffset Offset where
  truncateOffset = max 0
-- | Truncate the offset carried by the expression.  'Infty' has none and is
--   returned unchanged.
instance TruncateOffset (SizeExpr' r f) where
  truncateOffset Infty       = Infty
  truncateOffset (Const   n) = Const   (truncateOffset n)
  truncateOffset (Rigid i n) = Rigid i (truncateOffset n)
  truncateOffset (Flex  x n) = Flex  x (truncateOffset n)
-- | Types from which the set of rigid identifiers can be collected.
class Rigids r a where
  -- | The set of rigid identifiers occurring in the value.
  rigids :: a -> Set r
-- | The rigid identifiers of a list are the union over its elements.
instance (Ord r, Rigids r a) => Rigids r [a] where
  rigids = Set.unions . map rigids
-- | A 'Rigid' expression contributes its identifier; every other expression
--   contributes nothing.
instance Rigids r (SizeExpr' r f) where
  rigids e = case e of
    Rigid x _ -> Set.singleton x
    _         -> Set.empty
-- | The rigid identifiers of a constraint are those of its two sides.
instance Ord r => Rigids r (Constraint' r f) where
  rigids (Constraint lhs _ rhs) = rigids lhs `Set.union` rigids rhs
-- | Types from which the set of flexible identifiers can be collected.
--   The functional dependency says that the type @a@ determines the
--   identifier type @flex@.
class Flexs flex a | a -> flex where
  -- | The set of flexible identifiers occurring in the value.
  flexs :: a -> Set flex
-- | The flexible identifiers of a list are the union over its elements.
instance (Ord flex, Flexs flex a) => Flexs flex [a] where
  flexs = Set.unions . map flexs
-- | A 'Flex' expression contributes its identifier; every other expression
--   contributes nothing.
instance Flexs flex (SizeExpr' rigid flex) where
  flexs e = case e of
    Flex x _ -> Set.singleton x
    _        -> Set.empty
-- | The flexible identifiers of a constraint are those of its two sides.
instance (Ord flex) => Flexs flex (Constraint' rigid flex) where
  flexs (Constraint lhs _ rhs) = flexs lhs `Set.union` flexs rhs