{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Ormolu.Printer.Operators
( OpTree (..),
opTreeLoc,
reassociateOpTree,
)
where
import Data.Function (on)
import qualified Data.List as L
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import Data.Maybe (fromMaybe, mapMaybe)
import GHC
import OccName (occNameString)
import Ormolu.Utils (unSrcSpan)
-- | Intermediate representation of a chain of binary operator applications.
--
-- @ty@ is the type of the operands (sub-expressions) and @op@ is the type of
-- the operators.
data OpTree ty op
  = -- | A single operand.
    OpNode ty
  | -- | An operator applied to a left and a right sub-tree.
    OpBranch
      (OpTree ty op)
      -- ^ Left operand
      op
      -- ^ Operator
      (OpTree ty op)
      -- ^ Right operand
-- | The source span covering every operand of an 'OpTree'.
--
-- Only the spans of the leaves are combined; the spans of the operators
-- themselves are not consulted.
opTreeLoc :: OpTree (Located a) b -> SrcSpan
opTreeLoc = \case
  OpNode (L leafSpan _) -> leafSpan
  OpBranch lhs _ rhs -> opTreeLoc lhs `combineSrcSpans` opTreeLoc rhs
-- | Re-associate an 'OpTree' using operator precedences inferred from how
-- the operators are laid out in the source.
--
-- The tree is first flattened with 'normalizeOpTree'. A fixity map is then
-- inferred from that flattened tree, and the same tree is re-associated
-- with it.
reassociateOpTree ::
  -- | How to get the name of an operator
  (op -> Maybe RdrName) ->
  -- | Original 'OpTree'
  OpTree (Located ty) (Located op) ->
  -- | Re-associated 'OpTree'
  OpTree (Located ty) (Located op)
reassociateOpTree getOpName = reassociate . normalizeOpTree
  where
    -- Fixity inference and re-association both work on the normalized tree.
    reassociate flat =
      reassociateOpTreeWith
        (buildFixityMap getOpName flat)
        (getOpName . unLoc)
        flat
-- | Re-associate an 'OpTree' given a map of operator fixities.
--
-- 'reassociateOpTree' calls this on the right-nested tree produced by
-- 'normalizeOpTree'.
reassociateOpTreeWith ::
  forall ty op.
  -- | Fixity map for operators, keyed by unqualified operator name
  Map String Fixity ->
  -- | How to get the name of an operator
  (op -> Maybe RdrName) ->
  -- | Original 'OpTree'
  OpTree ty op ->
  -- | Re-associated 'OpTree'
  OpTree ty op
reassociateOpTreeWith fixityMap getOpName = go
  where
    -- Fixity of an operator, looked up by its unqualified name. Falls back
    -- to 'defaultFixity' when the name is unavailable or not in the map.
    fixityOf :: op -> Fixity
    fixityOf op = fromMaybe defaultFixity $ do
      s <- occNameString . rdrNameOcc <$> getOpName op
      M.lookup s fixityMap
    -- Works left to right. The left sub-tree holds operators that have
    -- already been placed, and the right sub-tree holds the ones still to be
    -- inserted. Each step moves one operator from the right into its correct
    -- position on the left.
    --
    -- The second component of GHC's 'compareFixity' is True when the
    -- expression should associate to the right, i.e. the second operator
    -- binds tighter than the first.
    go :: OpTree ty op -> OpTree ty op
    -- Base cases: a single operand, or a single operator between two
    -- operands, needs no re-association.
    go t@(OpNode _) = t
    go t@(OpBranch (OpNode _) _ (OpNode _)) = t
    -- Left operand is a leaf: move the first operator of the right
    -- sub-tree into a left-nested branch so the general cases below apply.
    go (OpBranch l@(OpNode _) op (OpBranch l' op' r')) =
      go (OpBranch (OpBranch l op l') op' r')
    -- Last remaining operator (its right operand is a leaf). Either push it
    -- down into the right side of @op@ or keep it at the root. There is
    -- nothing left to shift, so do not recurse further at this level.
    go (OpBranch (OpBranch l op r) op' r'@(OpNode _)) =
      if snd $ compareFixity (fixityOf op) (fixityOf op')
        then OpBranch l op (go $ OpBranch r op' r')
        else OpBranch (OpBranch l op r) op' r'
    -- More operators remain. Place @op'@ (with its operand @l'@) relative
    -- to @op@, then continue with @op''@ and the rest of the right side.
    go (OpBranch (OpBranch l op r) op' (OpBranch l' op'' r')) =
      if snd $ compareFixity (fixityOf op) (fixityOf op')
        then go $ OpBranch (OpBranch l op (go $ OpBranch r op' l')) op'' r'
        else go $ OpBranch (OpBranch (OpBranch l op r) op' l') op'' r'
-- | How an occurrence of an operator is laid out relative to its operands.
-- 'buildFixityMap' uses it to infer relative precedence.
--
-- The derived 'Ord' follows constructor order:
-- @AtBeginning _ < AtEnd < InBetween@. 'AtBeginning' values are ordered by
-- column. Do not reorder the constructors.
data Score
  = -- | A line break separates the left operand from the operator. Carries
    -- the operator's starting column.
    AtBeginning Int
  | -- | A line break separates the operator from its right operand.
    AtEnd
  | -- | The operator shares lines with both of its operands.
    InBetween
  deriving (Eq, Ord)
-- | Build a map of inferred 'Fixity's from an 'OpTree', keyed by unqualified
-- operator name.
--
-- The algorithm has three steps:
--
-- 1. Every operator occurrence gets a 'Score' based on its placement in the
--    source.
-- 2. All occurrences of the same operator are merged into one score, and the
--    operators are ranked by score.
-- 3. Operators with equal scores share a left-associative precedence level.
--    Levels start at 2 and increase with rank, so lower scores bind looser.
--
-- A few well-known operators (@$@, @:@, @.@) are then forced to fixed
-- fixities regardless of layout. Operators whose name or source span cannot
-- be determined are left out of the map.
buildFixityMap ::
  forall ty op.
  -- | How to get the name of an operator
  (op -> Maybe RdrName) ->
  -- | Operator tree
  OpTree (Located ty) (Located op) ->
  -- | Fixity map
  Map String Fixity
buildFixityMap getOpName opTree =
  addOverrides
    . M.fromList
    . concatMap (\(i, ns) -> map (\(n, _) -> (n, fixity i InfixL)) ns)
    . zip [2 ..]
    . L.groupBy ((==) `on` snd)
    . selectScores
    $ score opTree
  where
    -- Hard-coded fixities. 'M.union' is left-biased, so these take
    -- precedence over anything inferred from layout.
    addOverrides :: Map String Fixity -> Map String Fixity
    addOverrides m =
      M.fromList
        [ ("$", fixity 0 InfixR),
          (":", fixity 1 InfixR),
          (".", fixity 100 InfixL)
        ]
        `M.union` m
    fixity = Fixity NoSourceText
    -- Score every operator occurrence in the tree. Both sub-trees are
    -- visited so that operators nested inside a left operand are not
    -- silently dropped. After 'normalizeOpTree' the left operand is always
    -- a leaf, but this function should not depend on that.
    score :: OpTree (Located ty) (Located op) -> [(String, Score)]
    score (OpNode _) = []
    score (OpBranch l o r) = score l ++ scoreOp ++ score r
      where
        -- An occurrence whose name or span is unavailable is skipped. Such
        -- operators fall back to 'defaultFixity' when re-associating.
        scoreOp = fromMaybe [] $ do
          le <- srcSpanEndLine <$> unSrcSpan (opTreeLoc l) -- left end
          ob <- srcSpanStartLine <$> unSrcSpan (getLoc o) -- operator begin
          oe <- srcSpanEndLine <$> unSrcSpan (getLoc o) -- operator end
          rb <- srcSpanStartLine <$> unSrcSpan (opTreeLoc r) -- right begin
          oc <- srcSpanStartCol <$> unSrcSpan (getLoc o) -- operator column
          opName <- occNameString . rdrNameOcc <$> getOpName (unLoc o)
          let s
                | le < ob = AtBeginning oc
                | oe < rb = AtEnd
                | otherwise = InBetween
          return [(opName, s)]
    -- Merge all occurrences of each operator into a single score, then
    -- order the operators by that score (ascending).
    selectScores :: [(String, Score)] -> [(String, Score)]
    selectScores =
      L.sortOn snd
        . mapMaybe
          ( \case
              [] -> Nothing
              xs@((n, _) : _) -> Just (n, selectScore $ map snd xs)
          )
        . L.groupBy ((==) `on` fst)
        . L.sort
    -- 'InBetween' occurrences carry no layout information. They are ignored
    -- unless they are all there is. Otherwise the greatest score wins: any
    -- 'AtEnd' beats every 'AtBeginning', and a later column beats an earlier
    -- one.
    selectScore :: [Score] -> Score
    selectScore xs =
      case filter (/= InBetween) xs of
        [] -> InBetween
        xs' -> maximum xs'
-- | Rebuild an 'OpTree' in right-nested form, discarding its original
-- grouping.
--
-- Operands and operators keep their left-to-right order, but every left
-- operand becomes a leaf. For example, @((a + b) * c) - d@ becomes
-- @a + (b * (c - d))@.
normalizeOpTree :: OpTree ty op -> OpTree ty op
normalizeOpTree tree = uncurry rebuild (flatten tree [])
  where
    -- @flatten t rest@ returns the leftmost operand of @t@ together with the
    -- (operator, operand) pairs that follow it in order, with @rest@
    -- appended.
    flatten (OpNode n) rest = (n, rest)
    flatten (OpBranch l o r) rest =
      let (rFirst, rRest) = flatten r rest
       in flatten l ((o, rFirst) : rRest)
    -- Chain the flattened sequence back together, nesting to the right.
    rebuild x [] = OpNode x
    rebuild x ((o, y) : ys) = OpBranch (OpNode x) o (rebuild y ys)