{-# LANGUAGE CPP                        #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE NoMonomorphismRestriction  #-}
{-# LANGUAGE UndecidableInstances       #-}

-- | Syntax of size expressions and constraints.

module Agda.TypeChecking.SizedTypes.Syntax where

#if MIN_VERSION_base(4,11,0)
import Prelude hiding ( (<>), null )
#else
import Prelude hiding ( null )
#endif

import Data.Maybe
import Data.Foldable (Foldable)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Set (Set)
import qualified Data.Set as Set
import Data.Traversable (Traversable)

import Agda.TypeChecking.SizedTypes.Utils

import Agda.Utils.Functor
import Agda.Utils.Null
import Agda.Utils.Pretty

-- * Syntax

-- | Constant finite sizes @n >= 0@.
--
--   A wrapper around 'Int'.  The constructor itself accepts any 'Int',
--   so non-negativity is a convention rather than something the type
--   enforces.  'Num' and 'Enum' are derived for the newtype, so numeric
--   literals and arithmetic can be used directly at type 'Offset'.
newtype Offset = O Int
  deriving (Eq, Ord, Num, Enum)

-- This Show instance is ok because of the Enum constraint.
--
-- An offset is shown as its bare number.  The instance is defined through
-- 'showsPrec' on the wrapped 'Int' (rather than through 'show' alone) so
-- that the surrounding precedence is respected: a negative offset nested
-- under a constructor application is parenthesised, e.g. @Just (-1)@
-- instead of @Just -1@.  The output of a top-level 'show' is unchanged.
instance Show Offset where
  showsPrec p (O n) = showsPrec p n

-- | An offset is pretty-printed like the integer it wraps.
instance Pretty Offset where
  pretty off = case off of
    O k -> pretty k

-- | The meet of two offsets is the smaller of the two.
instance MeetSemiLattice Offset where
  meet m n = min m n

-- | Offsets are combined by unwrapping both, applying 'plus' to the
--   wrapped integers, and wrapping the result again.
instance Plus Offset Offset Offset where
  plus a b = case (a, b) of
    (O i, O j) -> O $ plus i j

-- | Fixed size variables @i@.
--
--   A rigid variable is identified by its name, stored as a 'String'.
newtype Rigid  = RigidId { rigidId :: String }
  deriving (Eq, Ord)

-- | Shows a rigid variable as the constructor applied to its quoted name,
--   e.g. @RigidId \"i\"@.
--
--   Defined through 'showsPrec' so that the value is parenthesised when it
--   appears as an argument of another constructor (@Just (RigidId \"i\")@);
--   the output of a top-level 'show' is the same as before.
instance Show Rigid where
  showsPrec p (RigidId s) =
    showParen (p > 10) $ showString "RigidId " . showsPrec 11 s

-- | A rigid variable is printed as its bare name.
instance Pretty Rigid where
  pretty (RigidId name) = text name

-- | Size meta variables @X@ to solve for.
--
--   A flexible variable is identified by its name, stored as a 'String'.
newtype Flex   = FlexId { flexId :: String }
  deriving (Eq, Ord)

-- | Shows a flexible variable as the constructor applied to its quoted
--   name, e.g. @FlexId \"X\"@.
--
--   Defined through 'showsPrec' so that the value is parenthesised when it
--   appears as an argument of another constructor (@Just (FlexId \"X\")@);
--   the output of a top-level 'show' is the same as before.
instance Show Flex where
  showsPrec p (FlexId s) =
    showParen (p > 10) $ showString "FlexId " . showsPrec 11 s

-- | A flexible variable is printed as its bare name.
instance Pretty Flex where
  pretty (FlexId name) = text name

-- | Size expressions appearing in constraints.
--
--   The type is parameterised over the representation of rigid and of
--   flexible variables.
--
--   Note that the record selectors are partial: 'offset' is not defined
--   for 'Infty', 'rigid' is defined only for 'Rigid', and 'flex' only for
--   'Flex'.
data SizeExpr' rigid flex
  = Const { offset :: Offset }                   -- ^ Constant number @n@.
  | Rigid { rigid  :: rigid, offset :: Offset }  -- ^ Variable plus offset @i + n@.
  | Infty                                        -- ^ Infinity @∞@.
  | Flex  { flex   :: flex, offset :: Offset }   -- ^ Meta variable @X + n@.
    deriving (Show, Eq, Ord, Functor, Foldable, Traversable)

-- | Size expressions over the concrete variable types 'Rigid' and 'Flex'.
type SizeExpr = SizeExpr' Rigid Flex

-- | Comparison operator, e.g. for size expression.
--
--   The derived 'Bounded' and 'Enum' instances follow the constructor
--   order, so @'minBound' = 'Lt'@ and @'maxBound' = 'Le'@.
data Cmp
  = Lt  -- ^ @<@.
  | Le  -- ^ @≤@.
  deriving (Show, Eq, Bounded, Enum)

-- | Composing two comparisons yields the smaller one with respect to the
--   ordering @'Lt' < 'Le'@, i.e. the result is 'Lt' as soon as one of the
--   two is 'Lt'.  'Le' is the unit of composition.
instance Dioid Cmp where
  compose c1 c2 = min c1 c2
  unitCompose   = Le

-- | Comparison operators are ordered by @'Lt' < 'Le'@.
--
--   Only @('<=')@ is given; the remaining 'Ord' methods come from the
--   class defaults.
instance Ord Cmp where
  l <= r = case l of
    -- 'Lt' is the least element: the right operand is not inspected.
    Lt -> True
    Le -> case r of
      Lt -> False
      Le -> True

-- | The meet of two comparisons is the smaller one, where @'Lt' < 'Le'@.
instance MeetSemiLattice Cmp where
  meet c1 c2 = min c1 c2

-- | The top element is the largest comparison operator, which is 'Le'.
instance Top Cmp where
  top = maxBound

-- | Constraint: an inequation between size expressions,
--   e.g. @X < ∞@ or @i + 3 ≤ j@.
--
--   A value @Constraint l c r@ is read as @l c r@.
data Constraint' rigid flex = Constraint
  { leftExpr  :: SizeExpr' rigid flex  -- ^ Left-hand side of the inequation.
  , cmp       :: Cmp                   -- ^ Strict ('Lt') or non-strict ('Le').
  , rightExpr :: SizeExpr' rigid flex  -- ^ Right-hand side of the inequation.
  }
  deriving (Show, Functor, Foldable, Traversable)

-- | Constraints over the concrete variable types 'Rigid' and 'Flex'.
type Constraint = Constraint' Rigid Flex

-- * Polarities to specify solutions.
------------------------------------------------------------------------

-- | What type of solution are we looking for?
--
--   Either the 'Least' or the 'Greatest' solution for a flexible variable.
--   'getPolarity' falls back to 'Least' when nothing is recorded.
data Polarity = Least | Greatest
  deriving (Eq, Ord)

-- | Assigning a polarity to a flexible variable.
--
--   The polarity comes first, the variable second.
data PolarityAssignment flex = PolarityAssignment Polarity flex

-- | Type of solution wanted for each flexible.
--
--   A finite map from flexible variables to their polarity.  Variables
--   without an entry are treated as 'Least' by 'getPolarity'.
type Polarities flex = Map flex Polarity

-- | The empty polarity map: no flexible variable has a recorded polarity.
emptyPolarities :: Polarities flex
emptyPolarities = Map.empty

-- | Collect a list of polarity assignments into a polarity map, keyed by
--   the flexible variable.  The map is built with 'Map.fromList'.
polaritiesFromAssignments :: Ord flex => [PolarityAssignment flex] -> Polarities flex
polaritiesFromAssignments assignments =
  Map.fromList [ (var, pol) | PolarityAssignment pol var <- assignments ]

-- | Look up the polarity recorded for a flexible variable.
--   Variables without an entry get the default polarity 'Least'.
getPolarity :: Ord flex => Polarities flex -> flex -> Polarity
getPolarity pols x =
  case Map.lookup x pols of
    Just pol -> pol
    Nothing  -> Least

-- * Solutions.
------------------------------------------------------------------------

-- | Partial substitution from flexible variables to size expression.
--
--   Flexible variables without an entry in the map are left untouched
--   when the solution is applied to a size expression with 'subst'.
newtype Solution rigid flex = Solution { theSolution :: Map flex (SizeExpr' rigid flex) }
  deriving (Show, Null)

-- | A solution is printed as the list of its bindings, each rendered as
--   @X := e@, in the key order of the underlying map.
instance (Pretty r, Pretty f) => Pretty (Solution r f) where
  pretty solution = prettyList (map binding (Map.toList (theSolution solution)))
    where
      -- Render one entry of the substitution.
      binding (var, expr) = pretty var <+> ":=" <+> pretty expr

-- | The solution that assigns nothing to any flexible variable.
emptySolution :: Solution r f
emptySolution = Solution { theSolution = Map.empty }

-- | Executing a substitution.
--
--   @subst sol a@ applies the solution @sol@ to @a@; the result has the
--   same type as @a@.
class Substitute r f a where
  subst :: Solution r f -> a -> a

-- | Substituting into a size expression: a flexible variable @X + n@ that
--   has a binding @X := e@ becomes @e + n@.  Unbound flexible variables and
--   all other expressions are returned unchanged.
instance Ord f => Substitute r f (SizeExpr' r f) where
  subst (Solution sol) e
    | Flex x n <- e
    , Just e'  <- Map.lookup x sol = e' `plus` n
    | otherwise                    = e

-- | Substituting into a constraint substitutes into both sides and keeps
--   the comparison operator.
instance Ord f => Substitute r f (Constraint' r f) where
  subst sol c = c
    { leftExpr  = subst sol (leftExpr c)
    , rightExpr = subst sol (rightExpr c)
    }

-- | Substituting into a list substitutes into every element.
instance Substitute r f a => Substitute r f [a] where
  subst sol xs = [ subst sol x | x <- xs ]

-- | Substituting into a map substitutes into every value; keys are kept.
instance Substitute r f a => Substitute r f (Map k a) where
  subst sol m = Map.map (subst sol) m

-- | Substituting into a solution applies the substitution to every size
--   expression on the right-hand side of a binding.
instance Ord f => Substitute r f (Solution r f) where
  subst s (Solution bindings) = Solution (subst s bindings)

-- | Add offset to size expression.
--
--   The offset is added to the offset already carried by the expression;
--   'Infty' absorbs any offset.
instance Plus (SizeExpr' r f) Offset (SizeExpr' r f) where
  plus (Const   n) m = Const   (n + m)
  plus (Rigid i n) m = Rigid i (n + m)
  plus (Flex  x n) m = Flex  x (n + m)
  plus Infty       _ = Infty

-- * Constraint simplification

-- | A constraint transformer: maps one constraint either to an error
--   message ('Left') or to a list of constraints ('Right').
type CTrans r f = Constraint' r f -> Either String [Constraint' r f]

-- | Returns an error message if we have a contradictory constraint.
--
--   Otherwise the result is a list of constraints replacing the input; the
--   list is empty when the constraint is found to hold trivially.
--
--   The first argument @test@ is a fallback transformer.  Every constraint
--   handed to @test@ contains a rigid variable and no flexible variable,
--   and the result of @test@ is returned as is.
--
--   Where a new constraint is built, the offsets are shifted so that one
--   of the two sides carries offset @0@.
simplify1 :: (Pretty f, Pretty r, Eq r) => CTrans r f -> CTrans r f
simplify1 test c = do
  -- Error result shared by all contradictory cases.
  let err = Left $ "size constraint " ++ prettyShow c ++ " is inconsistent"
  case c of
    -- rhs is Infty
    -- Anything is ≤ ∞, and a constant is < ∞: nothing remains to be solved.
    Constraint a           Le  Infty -> return []
    Constraint Const{}     Lt  Infty -> return []
    Constraint Infty       Lt  Infty -> err
    -- For variables the offset on the left is dropped.
    Constraint (Rigid i n) Lt  Infty -> test $ Constraint (Rigid i 0) Lt Infty
    Constraint a@Flex{}    Lt  Infty -> return [c { leftExpr = a { offset = 0 }}]

    -- rhs is Const
    -- Two constants are compared directly.
    Constraint (Const n)   cmp (Const m) -> if compareOffset n cmp m then return [] else err
    Constraint Infty       cmp  Const{}  -> err
    -- Variable against constant: if already the offset violates the bound,
    -- fail; otherwise move the offset to the right and turn @<@ into @≤@ by
    -- subtracting one more.
    Constraint (Rigid i n) cmp (Const m) ->
      if compareOffset n cmp m then
        test (Constraint (Rigid i 0) Le (Const (m - n - ifLe cmp 0 1)))
       else err
    Constraint (Flex x n)  cmp (Const m) ->
      if compareOffset n cmp m
       then return [Constraint (Flex x 0) Le (Const (m - n - ifLe cmp 0 1))]
       else err

    -- rhs is Rigid
    Constraint Infty cmp Rigid{} -> err
    -- Constant against rigid: trivially satisfied if the constant already
    -- compares against the rigid's offset; otherwise ask @test@ about the
    -- remaining difference.
    Constraint (Const m) cmp (Rigid i n) ->
      if compareOffset m cmp n then return []
      else test (Constraint (Const $ m - n) cmp (Rigid i 0))
    -- The same rigid variable on both sides: only the offsets matter.
    Constraint (Rigid j m) cmp (Rigid i n) | i == j ->
      if compareOffset m cmp n then return [] else err
    -- Two different rigid variables: passed to @test@ unchanged.
    Constraint (Rigid j m) cmp (Rigid i n) -> test c
    -- Flexible against rigid: the result is always a @≤@ constraint; the
    -- offset difference is put on whichever side keeps it non-negative.
    Constraint (Flex x m)  cmp (Rigid i n) ->
      if compareOffset m cmp n
       then return [Constraint (Flex x 0) Le (Rigid i (n - m - ifLe cmp 0 1))]
       else return [Constraint (Flex x $ m - n + ifLe cmp 0 1) Le (Rigid i 0)]

    -- rhs is Flex
    -- ∞ ≤ X + n is kept with the offset dropped; ∞ < X + n is contradictory.
    Constraint Infty Le (Flex x n) -> return [Constraint Infty Le (Flex x 0)]
    Constraint Infty Lt (Flex x n) -> err
    -- Constant against flexible: trivially satisfied if the constant already
    -- compares against the offset; otherwise a @≤@ lower bound on the
    -- variable remains.
    Constraint (Const m) cmp (Flex x n) ->
      if compareOffset m cmp n then return []
      else return [Constraint (Const $ m - n + ifLe cmp 0 1) Le (Flex x 0)]
    -- Variable against flexible: the comparison operator is kept and the
    -- offset difference is put on one side only.
    Constraint (Rigid i m) cmp (Flex x n) ->
      if compareOffset m cmp n
      then return [Constraint (Rigid i 0) cmp (Flex x $ n - m)]
      else return [Constraint (Rigid i $ m - n) cmp (Flex x 0)]
    Constraint (Flex y m) cmp (Flex x n) ->
      if compareOffset m cmp n
      then return [Constraint (Flex y 0) cmp (Flex x $ n - m)]
      else return [Constraint (Flex y $ m - n) cmp (Flex x 0)]

-- | Conditional on a comparison operator: selects the first alternative
--   for 'Le' and the second one for 'Lt'.
ifLe :: Cmp -> a -> a -> a
ifLe cmp a b =
  case cmp of
    Le -> a
    Lt -> b

-- | Decide whether two offsets are related by the given comparison
--   operator: 'Le' stands for @<=@ and 'Lt' for @<@.
compareOffset :: Offset -> Cmp -> Offset -> Bool
compareOffset n cmp m =
  case cmp of
    Le -> n <= m
    Lt -> n <  m

-- * Printing

-- | Constants are printed as numbers, infinity as @∞@, and variables as
--   their name followed by @+n@; the suffix is omitted when the offset
--   is 0.
instance (Pretty r, Pretty f) => Pretty (SizeExpr' r f) where
  pretty e =
    case e of
      Const n   -> pretty n
      Infty     -> "∞"
      Rigid i n -> withOffset (pretty i) n
      Flex  x n -> withOffset (pretty x) n
    where
      -- Append the offset to an already printed variable, unless it is 0.
      withOffset d 0 = d
      withOffset d k = d <> text ("+" ++ show k)

-- | 'Least' is printed as @-@ and 'Greatest' as @+@.
instance Pretty Polarity where
  pretty pol =
    case pol of
      Least    -> "-"
      Greatest -> "+"

-- | A polarity assignment is printed as the polarity sign immediately
--   followed by the variable, without a separating space.
instance Pretty flex => Pretty (PolarityAssignment flex) where
  pretty assignment =
    case assignment of
      PolarityAssignment p x -> pretty p <> pretty x

-- | Comparison operators are printed as the symbols @≤@ and @<@.
instance Pretty Cmp where
  pretty op =
    case op of
      Le -> "≤"
      Lt -> "<"

-- | A constraint is printed as left-hand side, comparison operator and
--   right-hand side, separated by spaces.
instance (Pretty r, Pretty f) => Pretty (Constraint' r f) where
  pretty c = pretty (leftExpr c) <+> pretty (cmp c) <+> pretty (rightExpr c)

-- * Wellformedness

-- | Offsets @+ n@ must be non-negative.
--
--   'validOffset' checks this wellformedness condition.
class ValidOffset a where
  validOffset :: a -> Bool

-- | An offset is valid if it is not negative.
instance ValidOffset Offset where
  validOffset n = n >= 0

-- | 'Infty' carries no offset and is always valid; every other size
--   expression is valid if its offset is.
instance ValidOffset (SizeExpr' r f) where
  validOffset Infty = True
  validOffset e     = validOffset (offset e)

-- | Make offsets non-negative by rounding up.
--
--   That is, a negative offset is replaced by @0@; non-negative offsets
--   are kept.
class TruncateOffset a where
  truncateOffset :: a -> a

-- | Negative offsets become @0@; all other offsets are returned as is.
instance TruncateOffset Offset where
  truncateOffset = max 0

-- | Truncates the offset carried by a size expression.  'Infty' has no
--   offset and is returned unchanged.
instance TruncateOffset (SizeExpr' r f) where
  truncateOffset Infty       = Infty
  truncateOffset (Const n)   = Const   (truncateOffset n)
  truncateOffset (Rigid i n) = Rigid i (truncateOffset n)
  truncateOffset (Flex  x n) = Flex  x (truncateOffset n)

-- * Computing variable sets

-- | The rigid variables contained in a piece of syntax.
--
--   'rigids' collects them into a 'Set'.
class Rigids r a where
  rigids :: a -> Set r

-- | The rigid variables of a list are the union of the rigid variables
--   of its elements.
instance (Ord r, Rigids r a) => Rigids r [a] where
  rigids as = Set.unions [ rigids a | a <- as ]

-- | A size expression contains at most one rigid variable, namely the one
--   of a 'Rigid' expression.
instance Rigids r (SizeExpr' r f) where
  rigids e =
    case e of
      Rigid x _ -> Set.singleton x
      _         -> Set.empty

-- | The rigid variables of a constraint are those of its two sides.
instance Ord r => Rigids r (Constraint' r f) where
  rigids c = rigids (leftExpr c) `Set.union` rigids (rightExpr c)

-- | The flexible variables contained in a piece of syntax.
--
--   'flexs' collects them into a 'Set'.  The functional dependency says
--   that the type of the syntax determines the type of flexible variables.
class Flexs flex a | a -> flex where
  flexs :: a -> Set flex

-- | The flexible variables of a list are the union of the flexible
--   variables of its elements.
instance (Ord flex, Flexs flex a) => Flexs flex [a] where
  flexs as = Set.unions [ flexs a | a <- as ]

-- | A size expression contains at most one flexible variable, namely the
--   one of a 'Flex' expression.
instance Flexs flex (SizeExpr' rigid flex) where
  flexs e =
    case e of
      Flex x _ -> Set.singleton x
      _        -> Set.empty

-- | The flexible variables of a constraint are those of its two sides.
instance (Ord flex) => Flexs flex (Constraint' rigid flex) where
  flexs c = flexs (leftExpr c) `Set.union` flexs (rightExpr c)