{-|
Module      : Utility.AOC
Description : Utility functions commonly used while solving Advent of Code puzzles
Copyright   : (c) M1n3c4rt, 2025
License     : BSD-3-Clause
Maintainer  : vedicbits@gmail.com
Stability   : stable
-}

{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveTraversable #-}
{-# OPTIONS_GHC -Wno-incomplete-patterns #-}
{-# OPTIONS_GHC -Wno-incomplete-uni-patterns #-}
{-# OPTIONS_GHC -Wno-type-defaults #-}
{-# LANGUAGE LambdaCase #-}

module Utility.AOC (
    -- * Pathfinding algorithms
    -- $cat1
    shortestDistance,
    shortestPaths,
    shortestDistanceWith,
    shortestPathsWith,
    -- * Neighbour functions
    neighbours4,
    neighbours8,
    neighbours6,
    neighbours26,
    -- * Taxicab (Manhattan) distance
    taxicab2,
    taxicab3,
    taxicab,
    -- * Input parsing
    -- $cat2
    enumerate',
    enumerateRead',
    enumerate,
    enumerateRead,
    enumerateHM,
    enumerateReadHM,
    enumerateFilter,
    enumerateFilterSet,
    numbers,
    numbers',
    binToDec,
    -- * Flood fill
    floodFill,
    floodFillWith,
    -- * List selection
    choose,
    permute,
    takeEvery,
    chunk,
    -- * Extrapolation
    extrapolate,
    -- * Debugging
    prettyPrintSet,
    prettyPrintSetWide,
    prettyPrintHM,
    prettyPrintHMWide,
    traceSleep,
    traceSleepSeconds,
    -- * Memoization
    -- $cat3
    memo2,
    memo3,
    memo4,
    -- * Range operations
    range,
    rangeIntersect,
    rangeSubtract,
    -- * ComplexL1 datatype
    ComplexL1 (..),
    realPart,
    imagPart,
    conjugate,
    magnitude,
    -- * Visualization
    displayPoints,
    displayPolygon,
    displayPolygons,
    displayGraphDirected,
    displayGraphUndirected,
    -- Export types/constraints used in top-level function type signatures
    Hashable,
    HM.HashMap,
    S.Set
) where

import qualified Data.HashMap.Strict as HM
import Data.Hashable (Hashable (hashWithSalt, hash))
import qualified Data.Set as S
import qualified Data.Heap as H
import qualified Data.Map as M
import Data.List (permutations, genericIndex, groupBy)
import Data.Maybe (fromMaybe, catMaybes, mapMaybe, listToMaybe)
import System.IO.Unsafe (unsafePerformIO)
import Control.Concurrent (threadDelay, newMVar, readMVar, modifyMVar_)
import Data.Function (on)
import SDL hiding (conjugate)
import Data.Text (pack)
-- all of these are for the ComplexL1 instances
import Data.Data (Data)
import GHC.Generics (Generic, Generic1)
import Foreign (Storable (sizeOf, peek, poke, peekElemOff, pokeElemOff, alignment), castPtr)
import qualified GHC.Ptr
import Control.Monad.Zip (MonadZip (mzipWith))
import Control.Monad.Fix (MonadFix (mfix))
import Data.Bifunctor (Bifunctor(bimap))
-- end

-- Builds a min-priority heap holding a single (priority, value) entry; it
-- seeds the search queue in 'shortestPathsWith'.
createMinPrioHeap :: Ord a1 => (a1,a) -> H.MinPrioHeap a1 a
createMinPrioHeap entry = H.singleton entry

-- $cat1
-- All of the following functions return distances as a @Maybe a@, where @a@ is the edge-weight type and @Nothing@ is returned if no path is found.
-- 
-- The graph is a @HashMap@ mapping each node to a sequence of (neighbour, edge weight) pairs.
-- 
-- Functions that take a graph assume that any neighbour that a node in the graph points to is also in the graph.

-- | Returns the shortest distance between two nodes in a graph.
--
-- If the search reaches a node that is not a key of the graph, 'error' is called.
shortestDistance :: (Foldable t, Hashable n, Ord a, Num a)
    => HM.HashMap n (t (n, a)) -- ^ Graph
    -> n -- ^ Start node
    -> n -- ^ End node
    -> Maybe a
shortestDistance graph start end = shortestDistanceWith neighboursOf start end
    where
        neighboursOf node = case HM.lookup node graph of
            Just edges -> edges
            Nothing -> error "Node not in graph"

-- | Returns the shortest distance between two nodes in a graph and a list of all possible paths from the ending node to the starting node.
--
-- The starting node is not included in each path. If the search reaches a
-- node that is not a key of the graph, 'error' is called.
shortestPaths :: (Foldable t, Hashable n, Ord a, Num a)
    => HM.HashMap n (t (n, a)) -- ^ Graph
    -> n -- ^ Start node
    -> n -- ^ End node
    -> (Maybe a, [[n]])
shortestPaths graph = shortestPathsWith lookupEdges
    where
        lookupEdges node = fromMaybe (error "Node not in graph") (HM.lookup node graph)

-- | Given a function that takes a node and returns a sequence of (neighbour,edge weight) pairs, returns the shortest distance between two nodes.
--
-- This is the distance component of 'shortestPathsWith'.
shortestDistanceWith :: (Foldable t, Hashable n, Ord a, Num a)
    => (n -> t (n, a)) -- ^ Function to generate neighbours
    -> n -- ^ Start node
    -> n -- ^ End node
    -> Maybe a
shortestDistanceWith f start = fst . shortestPathsWith f start

-- | Given a function that takes a node and returns a sequence of (neighbour,edge weight) pairs, returns the shortest distance between two nodes and a list of all possible paths from the ending node to the starting node.
-- 
-- The starting node is not included in each path.
--
-- This is Dijkstra's algorithm extended to keep every shortest path to each
-- node. The search does not stop at the end node. It runs until the queue is
-- empty, so everything reachable from the start node is explored.
--
-- NOTE(review): assumes non-negative edge weights. With zero-weight edges, a
-- node that has already been expanded can still gain extra equal-length paths.
-- Those paths are not propagated to its neighbours.
shortestPathsWith :: (Foldable t, Hashable n, Ord a, Num a)
    => (n -> t (n, a)) -- ^ Function to generate neighbours
    -> n -- ^ Start node
    -> n -- ^ End node
    -> (Maybe a, [[n]])
shortestPathsWith f start end =
    -- Queue entries are (tentative distance, node); the start node is at distance 0.
    let initQueue = createMinPrioHeap (0,start)
        -- For every node reached so far, this stores its best known distance and
        -- all paths achieving it. Each path lists nodes from that node back
        -- towards the start, excluding the start. The start has one empty path.
        initPaths = HM.singleton start (0,[[]])
        -- Repeatedly pop the closest queued node and relax its outgoing edges.
        helper (paths,queue) = case H.view queue of
            Nothing -> (paths,queue)
            Just ((_,n),ns) ->
                -- A node is only ever queued together with its path-table entry, so this lookup succeeds.
                let Just (currentDistance,currentPaths) = HM.lookup n paths
                    neighbours = f n
                    updateNeighbour (n',d') (p',q') = case HM.lookup n' p' of
                        -- First time n' is reached: record its paths and queue it.
                        -- The filter is a no-op here, since an unseen node cannot be queued.
                        Nothing -> (HM.insert n' (currentDistance+d',map (n':) currentPaths) p', H.insert (currentDistance+d',n') $ H.filter ((/=n') . snd) q')
                        Just (d'',ps'') ->
                            -- The known route is strictly shorter: keep it.
                            if d'' < currentDistance+d' then
                                (p',q')
                            -- A strictly shorter route: replace the paths and re-queue n'.
                            -- The filter drops the stale queue entry (a linear scan of the queue).
                            else if d'' > currentDistance+d' then
                                (HM.insert n' (currentDistance+d',map (n':) currentPaths) p', H.insert (currentDistance+d',n') $ H.filter ((/=n') . snd) q')
                            -- An equally short route: add its paths. The distance is unchanged, so the queue is too.
                            else
                                (HM.insert n' (currentDistance+d',ps'' ++ map (n':) currentPaths) p', q')
                in helper $ foldr updateNeighbour (paths,ns) neighbours

    in case HM.lookup end $ fst (helper (initPaths,initQueue)) of
        Nothing -> (Nothing, [])
        Just (d,ps) -> (Just d, ps)

-- | Returns the 4 points orthogonally adjacent to the given point, in the order
-- @(x+1,y)@, @(x,y+1)@, @(x-1,y)@, @(x,y-1)@.
neighbours4 :: (Num a, Num b) => (a, b) -> [(a, b)]
neighbours4 (x,y) = [xPlus, yPlus, xMinus, yMinus]
    where
        xPlus = (x + 1, y)
        yPlus = (x, y + 1)
        xMinus = (x - 1, y)
        yMinus = (x, y - 1)

-- | Returns the 8 points orthogonally or diagonally adjacent to the given point.
--
-- The x offset varies slowest. Offsets run over @-1, 0, 1@, skipping the point itself.
neighbours8 :: (Eq a, Eq b, Num a, Num b) => (a, b) -> [(a, b)]
neighbours8 (x,y) = [(x+dx,y+dy) | dx <- offsets, dy <- offsets, not (dx == 0 && dy == 0)]
    where
        offsets :: Num t => [t]
        offsets = [-1,0,1]

-- | Returns the 6 points orthogonally adjacent to the given point in 3D space.
--
-- The three positive steps come first (x, y, z), followed by the three negative steps.
neighbours6 :: (Num a, Num b, Num c) => (a, b, c) -> [(a, b, c)]
neighbours6 (x,y,z) = positive ++ negative
    where
        positive = [(x + 1, y, z), (x, y + 1, z), (x, y, z + 1)]
        negative = [(x - 1, y, z), (x, y - 1, z), (x, y, z - 1)]

-- | Returns the 26 points orthogonally or diagonally adjacent to the given point in 3D space.
--
-- Offsets run over @-1, 0, 1@ on each axis, with x varying slowest, and the point itself is skipped.
neighbours26 :: (Eq a, Eq b, Eq c, Num a, Num b, Num c) => (a, b, c) -> [(a, b, c)]
neighbours26 (x,y,z) =
    [ (x+dx,y+dy,z+dz)
    | dx <- offsets, dy <- offsets, dz <- offsets
    , not (dx == 0 && dy == 0 && dz == 0)
    ]
    where
        offsets :: Num t => [t]
        offsets = [-1,0,1]

-- | Returns the Taxicab/Manhattan distance between two points in 2D space.
taxicab2 :: Num a => (a, a) -> (a, a) -> a
taxicab2 (x1,y1) (x2,y2) = dx + dy
    where
        dx = abs (x1 - x2)
        dy = abs (y1 - y2)

-- | Returns the Taxicab/Manhattan distance between two points in 3D space.
taxicab3 :: Num a => (a, a, a) -> (a, a, a) -> a
taxicab3 (x1,y1,z1) (x2,y2,z2) = dx + dy + dz
    where
        dx = abs (x1 - x2)
        dy = abs (y1 - y2)
        dz = abs (z1 - z2)

-- | Returns the Taxicab/Manhattan distance between two points in n dimensions, where both points are lists of length n.
--
-- If the lists differ in length, the extra coordinates of the longer one are ignored.
taxicab :: Num a => [a] -> [a] -> a
taxicab ps qs = sum [abs (p - q) | (p, q) <- zip ps qs]

-- $cat2
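-- The following functions (those whose names begin with @enumerate@) operate on a grid of characters given as a string with a newline after each row, as in many Advent of Code puzzle inputs.
-- Coordinates are @(x,y)@, where @x@ is the column and @y@ is the row, both counted from 0 starting at the first character of the first line.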

-- | Converts a grid to a list of pairs @((x,y),c)@ representing xy coordinates and the character at that location.
--
-- Points are listed row by row, left to right, with @y = 0@ for the first line.
enumerate' :: (Num y, Num x) => String -> [((x, y), Char)]
enumerate' s = concat (zipWith row (iterate (+1) 0) (lines s))
    where
        row y line = zipWith (\x c -> ((x,y),c)) (iterate (+1) 0) line

-- | Enumerates a grid along with reading the characters (usually as integers), and returns a list of pairs.
--
-- Each character is parsed on its own with 'read', which fails on characters that do not parse as a @c@.
enumerateRead' :: (Read c, Num y, Num x) => String -> [((x, y), c)]
enumerateRead' s = [(pos, read [c]) | (pos, c) <- enumerate' s]

-- | Converts a grid to a list of triples @(x,y,c)@ representing xy coordinates and the character at that location.
enumerate :: (Num y, Num x) => String -> [(x, y, Char)]
enumerate s = [(x, y, c) | ((x, y), c) <- enumerate' s]

-- | Enumerates a grid along with reading the characters (usually as integers), and returns a list of triples.
--
-- Each character is parsed on its own with 'read'.
enumerateRead :: (Read c, Num y, Num x) => String -> [(x, y, c)]
enumerateRead s = [(x, y, read [c]) | ((x, y), c) <- enumerate' s]

-- | Enumerates a grid and stores it in a @HashMap@ where points are mapped to the character at that location.
enumerateHM :: (Num x, Num y, Enum x, Enum y, Hashable x, Hashable y) => String -> HM.HashMap (x, y) Char
enumerateHM s = HM.fromList (enumerate' s)

-- | Enumerates a grid and stores it in a @HashMap@ along with reading the characters (usually as integers).
--
-- Each character is parsed on its own with 'read'.
enumerateReadHM :: (Num x, Num y, Enum x, Enum y, Hashable x, Hashable y, Read c) => String -> HM.HashMap (x, y) c
enumerateReadHM s = HM.fromList [(pos, read [c]) | (pos, c) <- enumerate' s]

-- | Returns a list of points on a grid for which a certain condition is met.
--
-- The condition is applied to the character at each point; points keep grid order.
enumerateFilter :: (Num y, Num x) => (Char -> Bool) -> String -> [(x, y)]
enumerateFilter f s = [pos | (pos, c) <- enumerate' s, f c]

-- | Returns a set of points on a grid for which a certain condition is met.
enumerateFilterSet :: (Ord x, Ord y, Num y, Num x) => (Char -> Bool) -> String -> S.Set (x, y)
enumerateFilterSet f s = S.fromList (enumerateFilter f s)

-- | Returns all the integers in a string (including negative signs).
--
-- A @-@ counts as a sign only when a digit immediately follows it. Any other
-- @-@ is skipped, so @numbers "1-3"@ is @[1,-3]@ and @numbers "a - b"@ is @[]@.
-- Digits are scanned with 'span', so a run of dashes and digits such as
-- @"10-20"@ is split into separate numbers instead of being passed to 'read'
-- as a single invalid token.
numbers :: (Num a, Read a) => String -> [a]
numbers = go
    where
        isDigit = (`elem` "1234567890")
        go [] = []
        go ('-':c:cs)
            | isDigit c = token ('-':) (c:cs)
        go s@(c:cs)
            | isDigit c = token id s
            | otherwise = go cs
        -- Read one maximal run of digits (with an optional sign prefix), then continue after it.
        token sign s = let (ds, rest) = span isDigit s in read (sign ds) : go rest

-- | Returns all the integers in a string (excluding negative signs).
--
-- Every maximal run of decimal digits becomes one number; all other characters, including @-@, are separators.
numbers' :: (Num a, Read a) => String -> [a]
numbers' s = case dropWhile (not . isDigit) s of
    [] -> []
    rest -> let (digits, rest') = span isDigit rest in read digits : numbers' rest'
    where
        isDigit = (`elem` "1234567890")

-- | Converts a list of booleans (parsed as a binary number) to an integer.
--
-- The first element is the most significant bit; @True@ is 1 and @False@ is 0.
binToDec :: Num a => [Bool] -> a
binToDec bits = sum [2 ^ i * fromIntegral (fromEnum b) | (i, b) <- zip [0 ..] (reverse bits)]

-- | Worker for 'floodFill'. Explores every point reachable from @frontier@
-- through @neighbours@ without stepping onto a point of @blocks@. Returns
-- those points together with @finished@.
--
-- Points already in @finished@ count as visited and are not expanded. Points
-- in the initial @frontier@ are always included and expanded, even if blocked.
-- A single "seen" set (finished or queued points) replaces the former
-- 'notElem' scan of the frontier list, so each membership test is O(log n)
-- instead of O(n).
floodFill' :: Ord a => (a -> [a]) -> S.Set a -> [a] -> S.Set a -> S.Set a
floodFill' neighbours finished frontier blocks = go (S.union finished (S.fromList frontier)) frontier
    where
        -- seen holds points that are finished or already queued; the list holds points still to expand.
        go seen [] = seen
        go seen (p:pending) =
            let (seen', pending') = foldr visit (seen, pending) (neighbours p)
            in go seen' pending'
        -- Queue a neighbour the first time it is met, unless it is blocked.
        visit n (s, q)
            | n `S.member` s || n `S.member` blocks = (s, q)
            | otherwise = (S.insert n s, n : q)

-- | Worker for 'floodFillWith'. Explores every point reachable from
-- @frontier@, where a step from @p@ to a neighbour @n@ is allowed only when
-- @cond p n@ holds. Returns those points together with @finished@.
--
-- Points already in @finished@ count as visited and are not expanded. A
-- neighbour rejected by @cond@ from one point can still be reached from
-- another. A "seen" set replaces the former 'notElem' scan of the frontier
-- list, so each membership test is O(log n) instead of O(n).
floodFillWith' :: Ord a => (a -> a -> Bool) -> (a -> [a]) -> S.Set a -> [a] -> S.Set a
floodFillWith' cond neighbours finished frontier = go (S.union finished (S.fromList frontier)) frontier
    where
        -- seen holds points that are finished or already queued; the list holds points still to expand.
        go seen [] = seen
        go seen (p:pending) =
            let (seen', pending') = foldr (visit p) (seen, pending) (neighbours p)
            in go seen' pending'
        -- Queue neighbour n of p the first time it is met, if the step p -> n is allowed.
        visit p n (s, q)
            | n `S.member` s || not (cond p n) = (s, q)
            | otherwise = (S.insert n s, n : q)

-- | Applies a flood fill algorithm given a function to generate a point's neighbours, a starting set of points, and a set of points to avoid. Returns a set of all points covered.
--
-- Points in the initial set are always part of the result, even if they are also in the set to avoid.
floodFill :: Ord a
    => (a -> [a]) -- ^ Neighbour function
    -> S.Set a -- ^ Initial set of points
    -> S.Set a -- ^ Set of points to avoid
    -> S.Set a
floodFill neighbours frontier blocks = floodFill' neighbours S.empty start blocks
    where
        start = S.toList frontier

-- | Applies a flood fill algorithm given a function to generate a point's neighbours, a condition that filters out points generated by said function, and a starting set of points. Returns a set of all points covered.
-- 
-- The condition is of the form @a -> a -> Bool@, which returns @True@ if the second point is a valid neighbour of the first point and @False@ otherwise.
floodFillWith :: Ord a
    => (a -> a -> Bool) -- ^ Condition
    -> (a -> [a]) -- ^ Neighbour function
    -> S.Set a -- ^ Initial set of points
    -> S.Set a
floodFillWith cond neighbours = floodFillWith' cond neighbours S.empty . S.toList

-- | Generates a list of all possible lists of length n by taking elements from the provided list of length l.
-- 
-- Relative order is maintained, and the length of the returned list is \(_{l}C_{n} = \binom{l}{n}\).
-- @choose 0 xs@ is @[[]]@ (even for an infinite @xs@). The result is empty when n
-- is negative or exceeds l.
choose :: (Num n, Ord n) => n -> [a] -> [[a]]
choose n xs = go n (fromIntegral (length xs)) xs
    where
        -- len is the length of the remaining list. It is threaded through the
        -- recursion instead of being recomputed with 'length' at every step.
        go 0 _ _ = [[]]
        go _ _ [] = []
        go k len (y:ys)
            | k > len = []
            | otherwise = map (y:) (go (k - 1) (len - 1) ys) ++ go k (len - 1) ys

-- | Generates a list of all possible lists of length n by taking elements from the provided list of length l, in every order.
-- 
-- The length of the returned list is \(_{l}P_{n} = \frac{l!}{(l-n)!}\).
-- The result is grouped by the combination each arrangement comes from
-- ('choose'), and each combination is expanded with 'permutations'.
permute :: (Num n, Ord n) => n -> [a] -> [[a]]
permute n = concatMap permutations . choose n

-- | Takes every nth element from a list xs, starting from @xs !! (n-1)@.
--
-- A trailing group shorter than n contributes nothing. n must be positive: for
-- n <= 0 and a non-empty list, 'last' is called on an empty group and fails.
takeEvery :: Int -> [a] -> [a]
takeEvery n = go
    where
        go [] = []
        go xs = case splitAt n xs of
            (grp, rest)
                | length grp < n -> []
                | otherwise -> last grp : go rest

-- | Splits a list into sublists of size n. The length of the last sublist may be less than n.
--
-- n should be positive: for n <= 0 and a non-empty list, the result is an infinite list of empty lists.
chunk :: Int -> [a] -> [[a]]
chunk n = go
    where
        go [] = []
        go ys = let (front, back) = splitAt n ys in front : go back

-- | Gets the nth element (0-indexed) of an infinite list, assuming that each element in the list can be generated using the previous element, for example, a list generated with @iterate@.
--
-- The list is scanned until some value appears for the second time. If that
-- value first occurred at index @o@ and reappears at index @o + p@, the list is
-- treated as repeating with period @p@ from index @o@ onwards. Indices before
-- @o@ (the lead-in before the cycle) are read directly from the list.
--
-- Calls 'error' if the list ends before any value repeats.
extrapolate :: (Integral b, Ord a) => b -> [a] -> a
extrapolate n ls
    | n < start = ls `genericIndex` n
    | otherwise = ls `genericIndex` ((n - start) `mod` period + start)
    where
        -- (index where the cycle starts, cycle length)
        (start, period) = findCycle 0 M.empty ls
        -- seen maps every value met so far to the index where it first appeared.
        -- k is forced each step so the index does not build up a thunk chain.
        findCycle k seen (l:ls') = k `seq` case M.lookup l seen of
            Just o -> (o, k - o)
            Nothing -> findCycle (k + 1) (M.insert l k seen) ls'
        findCycle _ _ [] = error "extrapolate: the list ended before any value repeated"

-- | Converts a set of points @(x,y)@ to a string composed of @'#'@ and @' '@. This function is useful when displaying puzzle answers formed by a grid of points.
--
-- The output covers the bounding box of the points. Rows run from the largest
-- @y@ at the top down to the smallest @y@, so @y@ grows upwards. Columns run
-- from the smallest @x@ to the largest. 'enumerateFilterSet' numbers rows from
-- the top (@y = 0@ is the first line), so
-- @prettyPrintSet . enumerateFilterSet (==\'#\')@ reproduces the grid flipped
-- upside down and cropped to its bounding box.
--
-- An empty set yields the empty string.
prettyPrintSet :: (Enum b, Enum a, Ord a, Ord b) => S.Set (a, b) -> String
prettyPrintSet points
    | S.null points = ""
    | otherwise = unlines [[if (x,y) `S.member` points then '#' else ' ' | x <- [xmin..xmax]] | y <- reverse [ymin..ymax]]
    where
        xs = S.map fst points
        ys = S.map snd points
        (xmin,xmax,ymin,ymax) = (S.findMin xs, S.findMax xs, S.findMin ys, S.findMax ys)

-- | Same as @prettyPrintSet@, but displays points at double width to improve readability.
--
-- Every character except the newlines is written twice.
prettyPrintSetWide :: (Enum b, Enum a, Ord a, Ord b) => S.Set (a, b) -> String
prettyPrintSetWide = concatMap widen . prettyPrintSet
    where
        widen '\n' = "\n"
        widen c = [c, c]

-- | Converts a @HashMap@ of points @(x,y)@ and characters @c@ to a string with the corresponding character at each point. This function is useful when displaying puzzle answers formed by a grid of points.
--
-- The output covers the bounding box of the keys, and missing points are
-- shown as @' '@. Rows run from the largest @y@ at the top down to the
-- smallest, so @y@ grows upwards. 'enumerateHM' numbers rows from the top
-- (@y = 0@ is the first line), so @prettyPrintHM . enumerateHM@ reproduces
-- the grid flipped upside down.
--
-- An empty map yields the empty string.
prettyPrintHM :: (Enum b, Enum a, Hashable a, Hashable b, Ord a, Ord b) => HM.HashMap (a, b) Char -> String
prettyPrintHM points
    | HM.null points = ""
    | otherwise = unlines [[HM.lookupDefault ' ' (x,y) points | x <- [xmin..xmax]] | y <- reverse [ymin..ymax]]
    where
        (xs, ys) = unzip (HM.keys points)
        (xmin,xmax,ymin,ymax) = (minimum xs, maximum xs, minimum ys, maximum ys)

-- | Same as @prettyPrintHM@, but displays points at double width to improve readability.
--
-- Every character except the newlines is written twice.
prettyPrintHMWide :: (Enum b, Enum a, Hashable a, Hashable b, Ord a, Ord b) => HM.HashMap (a, b) Char -> String
prettyPrintHMWide = concatMap widen . prettyPrintHM
    where
        widen '\n' = "\n"
        widen c = [c, c]

{-# NOINLINE traceSleep #-}
-- | Pauses execution for n microseconds, before returning the second argument as its result. Useful for slowing down output that normally floods the terminal.
-- 
-- Like functions exported by Debug.Trace, this function should only be used for debugging.
-- 
-- The function is not referentially transparent. Its type says it is pure, but
-- forcing its result has the side effect of delaying execution.
traceSleep :: Int -> a -> a
traceSleep n x = unsafePerformIO (threadDelay n >> pure x)

{-# NOINLINE traceSleepSeconds #-}
-- | Pauses execution for n seconds. See @traceSleep@.
traceSleepSeconds :: Int -> a -> a
traceSleepSeconds seconds x = traceSleep (seconds * 1000000) x

-- $cat3
-- Memoize functions of two, three or four curried arguments. The implementation is adapted from @memo@ in @Data.MemoUgly@ (uglymemo package), with the cache keyed on a tuple of all the arguments.

-- | Memoizes a two-argument function.
--
-- Results are cached in a @Map@ keyed on the argument pair. The map lives in an
-- @MVar@ created when @memo2 f@ is evaluated. To share the cache between calls,
-- bind @memo2 f@ to a name and reuse that name. Entries are never removed.
--
-- Results are stored unevaluated, because @r@ is bound lazily and inserted
-- into a lazy map. The MVar is therefore not held while @f@ runs, and @f@ may
-- itself call the memoized function.
--
-- NOTE(review): the lookup ('readMVar') and the insert ('modifyMVar_') are
-- separate steps, so concurrent callers may compute the same entry twice.
-- Sharing relies on 'unsafePerformIO' and may depend on how GHC optimises call
-- sites; confirm whether a NOINLINE pragma is wanted.
memo2 :: (Ord a, Ord b) => (a -> b -> c) -> (a -> b -> c)
memo2 f = unsafePerformIO $ do
    v <- newMVar M.empty
    let f' a b = unsafePerformIO $ do
            m <- readMVar v
            case M.lookup (a,b) m of
                -- Miss: cache the (still unevaluated) result, then return it.
                Nothing -> do let { r = f a b }; modifyMVar_ v (return . M.insert (a,b) r); return r
                Just r  -> return r
    return f'
-- | Memoizes a three-argument function.
--
-- Results are cached in a @Map@ keyed on the argument triple. The map lives in
-- an @MVar@ created when @memo3 f@ is evaluated. Bind @memo3 f@ to a name and
-- reuse it to share the cache. Entries are never removed.
--
-- Results are stored unevaluated, so @f@ may call the memoized function
-- recursively. NOTE(review): the lookup and the insert are not atomic, so
-- concurrent callers may compute an entry twice.
memo3 :: (Ord a, Ord b, Ord c) => (a -> b -> c -> d) -> (a -> b -> c -> d)
memo3 f = unsafePerformIO $ do
    v <- newMVar M.empty
    let f' a b c = unsafePerformIO $ do
            m <- readMVar v
            case M.lookup (a,b,c) m of
                -- Miss: cache the (still unevaluated) result, then return it.
                Nothing -> do let { r = f a b c }; modifyMVar_ v (return . M.insert (a,b,c) r); return r
                Just r  -> return r
    return f'
-- | Memoizes a four-argument function.
--
-- Results are cached in a @Map@ keyed on the 4-tuple of arguments. The map
-- lives in an @MVar@ created when @memo4 f@ is evaluated. Bind @memo4 f@ to a
-- name and reuse it to share the cache. Entries are never removed.
--
-- Results are stored unevaluated, so @f@ may call the memoized function
-- recursively. NOTE(review): the lookup and the insert are not atomic, so
-- concurrent callers may compute an entry twice.
memo4 :: (Ord a, Ord b, Ord c, Ord d) => (a -> b -> c -> d -> e) -> (a -> b -> c -> d -> e)
memo4 f = unsafePerformIO $ do
    v <- newMVar M.empty
    let f' a b c d = unsafePerformIO $ do
            m <- readMVar v
            case M.lookup (a,b,c,d) m of
                -- Miss: cache the (still unevaluated) result, then return it.
                Nothing -> do let { r = f a b c d }; modifyMVar_ v (return . M.insert (a,b,c,d) r); return r
                Just r  -> return r
    return f'

-- | Generates a range with @[x..y]@, but reverses the list instead of returning an empty range if x > y.
--
-- For x > y the result is @[x, pred x .. y]@, counting down in steps of one.
range :: (Ord a, Enum a) => a -> a -> [a]
range x y
    | y < x = enumFromThenTo x (pred x) y
    | otherwise = enumFromTo x y

-- | Takes @(a,b)@ and @(c,d)@ as arguments and returns the intersection of the ranges @[a..b]@ and @[c..d]@ as another pair if it is not empty.
--
-- Both ranges are assumed to be non-empty (@a <= b@ and @c <= d@).
rangeIntersect :: Ord b => (b, b) -> (b, b) -> Maybe (b, b)
rangeIntersect (lo1,hi1) (lo2,hi2)
    | disjoint = Nothing
    | otherwise = Just (max lo1 lo2, min hi1 hi2)
    where
        disjoint = hi1 < lo2 || lo1 > hi2

-- | Takes @(a,b)@ and @(c,d)@ as arguments and returns every element in @[c..d]@ that is not in @[a..b]@ as a list of pairs of length at most 2.
--
-- The returned ranges are non-empty and in ascending order. Both inputs are
-- assumed to be non-empty (@a <= b@ and @c <= d@). When the ranges share an
-- endpoint, no empty pseudo-range such as @(c, c-1)@ is produced.
rangeSubtract :: (Ord a, Num a) => (a, a) -> (a, a) -> [(a, a)]
rangeSubtract (a,b) (c,d)
    -- Disjoint ranges: nothing of [c..d] is removed.
    | b < c || a > d = [(c,d)]
    -- Overlapping: keep the part of [c..d] left of a and the part right of b, if any.
    | otherwise = [(c, a - 1) | c < a] ++ [(b + 1, d) | b < d]

infix 6 :+
-- | A @Complex@ number type whose instance requires the constituent type to only be part of @Num@ instead of @RealFloat@.
-- 
-- As a consequence of this, the @abs@ and @magnitude@ functions use the L1 metric instead of the L2 metric (which is quite useful for Advent of Code puzzles).
-- 
-- Also, @signum@ is performed element-wise. For component types whose @signum@
-- returns @-1@, @0@ or @1@, it can only return one of 9 different values:
-- @0 :+ 0@ for zero, and 8 others for non-zero input.
-- 
-- This breaks the equality @abs z * signum z == z@ that is true in the Complex datatype.
-- 
-- The functions @mkPolar@, @cis@, @polar@ and @phase@ are unimplemented for obvious reasons.
-- 
-- Unlike @Complex@, Ord and Hashable are supported. The Ord instance treats the values as tuples,
-- comparing the real parts first and then the imaginary parts.
-- 
-- The name of the datatype contains \"L1\" in order to reflect these changes.
data ComplexL1 a = a :+ a deriving (Eq, Ord, Read, Show, Data, Generic, Generic1, Functor, Foldable, Traversable)

-- | Real part of a complex number.
realPart :: ComplexL1 a -> a
realPart z = case z of
    re :+ _ -> re
-- | Imaginary part of a complex number.
imagPart :: ComplexL1 a -> a
imagPart z = case z of
    _ :+ im -> im

-- | Conjugate of a complex number: the imaginary part is negated.
conjugate :: Num a => ComplexL1 a -> ComplexL1 a
conjugate z = case z of
    re :+ im -> re :+ negate im
-- | (Taxicab) magnitude of a complex number: the sum of the absolute values of both parts.
magnitude :: Num a => ComplexL1 a -> a
magnitude z = case z of
    re :+ im -> abs re + abs im

-- Addition, subtraction and negation act component-wise, and multiplication
-- is complex multiplication. @abs@ is the L1 magnitude on the real axis, and
-- @signum@ acts component-wise.
instance Num a => Num (ComplexL1 a) where
    (+) :: Num a => ComplexL1 a -> ComplexL1 a -> ComplexL1 a
    (a:+b) + (c:+d) = (a+c) :+ (b+d)
    (-) :: Num a => ComplexL1 a -> ComplexL1 a -> ComplexL1 a
    (a:+b) - (c:+d) = (a-c) :+ (b-d)
    (*) :: Num a => ComplexL1 a -> ComplexL1 a -> ComplexL1 a
    (a:+b) * (c:+d) = (a*c - b*d) :+ (b*c + a*d)
    abs :: Num a => ComplexL1 a -> ComplexL1 a
    abs (a :+ b) = (abs a + abs b) :+ 0
    signum :: Num a => ComplexL1 a -> ComplexL1 a
    signum (a :+ b) = signum a :+ signum b
    fromInteger :: Num a => Integer -> ComplexL1 a
    fromInteger n = fromInteger n :+ 0
    negate :: Num a => ComplexL1 a -> ComplexL1 a
    -- Negate each component directly rather than multiplying by (-1) :+ 0.
    -- The multiplication needs fromInteger (-1), which throws for types such as
    -- Natural, and it turns infinite floating-point components into NaN (inf * 0).
    negate (a :+ b) = negate a :+ negate b

-- Hashes the real part and then the imaginary part, threading the salt through both.
instance Hashable a => Hashable (ComplexL1 a) where
    hashWithSalt :: Hashable a => Int -> ComplexL1 a -> Int
    hashWithSalt salt z = case z of
        re :+ im -> hashWithSalt (salt `hashWithSalt` re) im

-- Stored as two consecutive @a@ values: the real part at element offset 0 and
-- the imaginary part at element offset 1.
instance Storable a => Storable (ComplexL1 a) where
    sizeOf :: Storable a => ComplexL1 a -> Int
    sizeOf z = sizeOf (realPart z) * 2
    alignment :: Storable a => ComplexL1 a -> Int
    alignment = alignment . realPart
    peek :: Storable a => GHC.Ptr.Ptr (ComplexL1 a) -> IO (ComplexL1 a)
    peek ptr = do
        re <- peek components
        im <- peekElemOff components 1
        pure (re :+ im)
      where
        components = castPtr ptr
    poke :: Storable a => GHC.Ptr.Ptr (ComplexL1 a) -> ComplexL1 a -> IO ()
    poke ptr (re :+ im) = do
        poke components re
        pokeElemOff components 1 im
      where
        components = castPtr ptr

-- Component-wise application: @pure@ fills both parts with the same value.
instance Applicative ComplexL1 where
  pure :: a -> ComplexL1 a
  pure v = v :+ v
  (<*>) :: ComplexL1 (a -> b) -> ComplexL1 a -> ComplexL1 b
  (fRe :+ fIm) <*> (re :+ im) = fRe re :+ fIm im
  liftA2 :: (a -> b -> c) -> ComplexL1 a -> ComplexL1 b -> ComplexL1 c
  liftA2 op (re1 :+ im1) (re2 :+ im2) = op re1 re2 :+ op im1 im2

-- Diagonal bind: the real part of @k re@ and the imaginary part of @k im@.
instance Monad ComplexL1 where
  (re :+ im) >>= k =
    let r :+ _ = k re
        _ :+ i = k im
    in r :+ i

-- Zipping combines matching components, like liftA2.
instance MonadZip ComplexL1 where
  mzipWith f (a :+ b) (c :+ d) = f a c :+ f b d

-- Each component is its own fixed point: the real part of @f@ applied to the
-- real result, and the imaginary part of @f@ applied to the imaginary result.
instance MonadFix ComplexL1 where
  mfix f = re :+ im
    where
      re = case f re of r :+ _ -> r
      im = case f im of _ :+ i -> i

-- Camera state for the SDL viewers. A world point (wx,wy) is drawn at pixel
-- ((wx - fx) * 1000 / zoom, (wy - fy) * 1000 / zoom), where (fx,fy) is the focus.
data ViewPort = V
    { focus :: (Float,Float) -- world coordinate drawn at pixel (0,0)
    , zoom :: Float -- world units spanned by 1000 pixels
    }
    deriving Show

-- | Displays a list of points in an external window.
--
-- Each point is drawn on its own as a small square.
displayPoints :: (Integral a) => [(a,a)] -> IO ()
displayPoints = displayPolys' . map (\p -> [bimap fromIntegral fromIntegral p])

-- | Displays a polygon in an external window. The last point in the list of points is not joined to the first point.
displayPolygon :: (Integral a) => [(a,a)] -> IO ()
displayPolygon ps = displayPolys' [map toFloat ps]
    where
        toFloat p = bimap fromIntegral fromIntegral p

-- | Displays a list of polygons in an external window. The last point in each list of points is not joined to the first point.
displayPolygons :: (Integral a) => [[(a,a)]] -> IO ()
displayPolygons = displayPolys' . map (map toFloat)
    where
        toFloat p = bimap fromIntegral fromIntegral p

-- Opens an SDL window and interactively draws the given point lists until the
-- window is closed. A list with exactly one point is drawn as a 6x6 pixel
-- square. Any other list is drawn as a polyline joining consecutive points,
-- and the last point is not joined back to the first.
--
-- Controls: dragging with the left mouse button pans, and the mouse wheel
-- zooms about the cursor. The world coordinate under the cursor, rounded to
-- integers, is printed next to it. The label is white when that coordinate is
-- one of the input points and grey otherwise.
--
-- NOTE(review): an empty input (no points at all) makes 'minimum' in
-- viewPortFromPoints fail. If every point is identical (e.g. a single point),
-- both extents are zero, so the initial zoom is 0 and pointFromViewPort
-- divides by zero. The renderer is never destroyed and there is no delay
-- between frames; confirm whether that is intended.
displayPolys' :: [[(Float,Float)]] -> IO ()
displayPolys' ps = do
    initializeAll
    window <- createWindow (pack $ "Displaying " ++ show (length (concat ps)) ++ " Points") defaultWindow { windowInitialSize = V2 1000 1000, windowResizable = True}
    renderer <- createRenderer window (-1) defaultRenderer
    appLoop viewPortFromPoints renderer
    destroyWindow window
    where
        -- One frame per iteration: draw the scene and the cursor label with the
        -- current viewport, then apply this frame's pan and zoom events.
        appLoop viewport@(V (x,y) s) r = do
            events <- pollEvents
            let payloads = map eventPayload events
            -- Clear to black, then draw every point list in white.
            clear r
            rendererDrawColor r $= V4 0 0 0 255
            fillRect r Nothing
            rendererDrawColor r $= V4 255 255 255 255
            mapM_ (\p -> if length p == 1 then drawPoint' viewport r (head p) else drawPoly viewport r p) ps
            -- Cursor position in pixels (assumed window-relative; TODO confirm).
            P (V2 mx' my') <- getAbsoluteMouseLocation
            -- Label colour: white if the rounded world point under the cursor is an input point, grey otherwise.
            let pcol = (if (`elem` concat ps) $ bimap (fromIntegral . round) (fromIntegral . round) $ pointToViewPort viewport (fromIntegral mx',fromIntegral my') then V4 255 255 255 255 else V4 128 128 128 255)
            -- The label sits 30 px right of and below the cursor. 'init . tail'
            -- strips the parentheses from the shown tuple, leaving "x,y".
            drawCs pcol r (mx'+30,my'+30) 3 $ init $ tail $ show $ bimap round round $ pointToViewPort viewport (fromIntegral mx',fromIntegral my')
            present r

            -- Left-button drag: move the focus against the mouse motion, scaled from pixels to world units.
            let movement = listToMaybe $ mapMaybe (\case MouseMotionEvent d -> if ButtonLeft `elem` mouseMotionEventState d then Just $ mouseMotionEventRelMotion d else Nothing; _ -> Nothing) payloads
                viewport'@(V (x1,y1) s1) = case movement of
                    Just (V2 x' y') -> V (x-fromIntegral x'*s/1000,y-fromIntegral y'*s/1000) s
                    Nothing -> viewport

            -- Mouse wheel: scale the zoom by 2^(-dy/5). The focus is moved so that
            -- the world point under the cursor stays at the same pixel.
            let (mx,my) = pointToViewPort viewport' (mx',my')
                scale = listToMaybe $ mapMaybe (\case MouseWheelEvent d -> Just $ mouseWheelEventPos d; _ -> Nothing) payloads
                viewport'' = case scale of
                    Just (V2 _ dy) -> let sf = 2**(-(fromIntegral dy/5)) in V (x1*sf+mx*(1-sf),y1*sf+my*(1-sf)) (s1*sf)
                    Nothing -> viewport'

            if QuitEvent `elem` payloads then return () else appLoop viewport'' r

        -- Polyline: one segment between each pair of consecutive points.
        drawPoly v r p = mapM_ (drawLine' v r) $ zip p (tail p)
        drawLine' v r ((x1,y1),(x2,y2)) = drawLine r (P $ V2 x1' y1') (P $ V2 x2' y2')
            where
                (x1',y1') = pointFromViewPort v (x1,y1)
                (x2',y2') = pointFromViewPort v (x2,y2)
        -- A 6x6 pixel square centred on the point.
        drawPoint' v r (x1,y1) = fillRect r $ Just $ Rectangle (P $ V2 (x1'-3) (y1'-3)) (V2 6 6)
            where (x1',y1') = pointFromViewPort v (x1,y1)
        -- World to pixel: offset from the focus, scaled so that the zoom value spans 1000 pixels.
        pointFromViewPort (V (x,y) s) (x',y') = (round $ (x'-x)*1000/ s,round $ (y'-y)*1000/ s)
        -- Pixel to world: the inverse of pointFromViewPort (up to rounding).
        pointToViewPort (V (x,y) s) (x',y') = (fromIntegral x'*s/1000+x,fromIntegral y'*s/1000+y)
        -- Initial viewport: focus on the bounding box's minimum corner, with the
        -- zoom set to the larger of the box's width and height.
        viewPortFromPoints = V (bx,by) (max (bx'-bx) (by'-by))
            where
                bx = minimum $ map fst $ concat ps
                bx' = maximum $ map fst $ concat ps
                by = minimum $ map snd $ concat ps
                by' = maximum $ map snd $ concat ps

        black = V4 0 0 0 255

        -- Draws one glyph of a 3x5 block font at (x,y) in colour g, with cells of dp x dp pixels.
        -- Most glyphs fill a 3x5 block with g and then carve out the gaps in black.
        -- Supports the digits, '-' and ','; any other character draws nothing.
        drawC g r (x,y) dp c = case c of
            '0' -> do
                rendererDrawColor r $= g
                fillRect r $ Just $ Rectangle (P (V2 x y)) (V2 (3*dp) (5*dp))
                rendererDrawColor r $= black
                fillRect r $ Just $ Rectangle (P (V2 (x+1*dp) (y+1*dp))) (V2 (1*dp) (3*dp))
            '1' -> do
                rendererDrawColor r $= g
                fillRect r $ Just $ Rectangle (P (V2 (x+2*dp) y)) (V2 (1*dp) (5*dp))
            '2' -> do
                rendererDrawColor r $= g
                fillRect r $ Just $ Rectangle (P (V2 x y)) (V2 (3*dp) (5*dp))
                rendererDrawColor r $= black
                fillRect r $ Just $ Rectangle (P (V2 x (y+1*dp))) (V2 (2*dp) dp)
                fillRect r $ Just $ Rectangle (P (V2 (x+1*dp) (y+3*dp))) (V2 (2*dp) dp)
            '3' -> do
                rendererDrawColor r $= g
                fillRect r $ Just $ Rectangle (P (V2 x y)) (V2 (3*dp) (5*dp))
                rendererDrawColor r $= black
                fillRect r $ Just $ Rectangle (P (V2 x (y+1*dp))) (V2 (2*dp) dp)
                fillRect r $ Just $ Rectangle (P (V2 x (y+3*dp))) (V2 (2*dp) dp)
            '4' -> do
                rendererDrawColor r $= g
                fillRect r $ Just $ Rectangle (P (V2 x y)) (V2 (3*dp) (5*dp))
                rendererDrawColor r $= black
                fillRect r $ Just $ Rectangle (P (V2 (x+1*dp) y)) (V2 dp (2*dp))
                fillRect r $ Just $ Rectangle (P (V2 x (y+3*dp))) (V2 (2*dp) (2*dp))
            '5' -> do
                rendererDrawColor r $= g
                fillRect r $ Just $ Rectangle (P (V2 x y)) (V2 (3*dp) (5*dp))
                rendererDrawColor r $= black
                fillRect r $ Just $ Rectangle (P (V2 (x+1*dp) (y+1*dp))) (V2 (2*dp) dp)
                fillRect r $ Just $ Rectangle (P (V2 x (y+3*dp))) (V2 (2*dp) dp)
            '6' -> do
                rendererDrawColor r $= g
                fillRect r $ Just $ Rectangle (P (V2 x y)) (V2 (3*dp) (5*dp))
                rendererDrawColor r $= black
                fillRect r $ Just $ Rectangle (P (V2 (x+1*dp) (y+1*dp))) (V2 (2*dp) dp)
                fillRect r $ Just $ Rectangle (P (V2 (x+1*dp) (y+3*dp))) (V2 dp dp)
            '7' -> do
                rendererDrawColor r $= g
                fillRect r $ Just $ Rectangle (P (V2 x y)) (V2 (3*dp) (5*dp))
                rendererDrawColor r $= black
                fillRect r $ Just $ Rectangle (P (V2 x (y+1*dp))) (V2 (2*dp) (4*dp))
            '8' -> do
                rendererDrawColor r $= g
                fillRect r $ Just $ Rectangle (P (V2 x y)) (V2 (3*dp) (5*dp))
                rendererDrawColor r $= black
                fillRect r $ Just $ Rectangle (P (V2 (x+1*dp) (y+1*dp))) (V2 dp dp)
                fillRect r $ Just $ Rectangle (P (V2 (x+1*dp) (y+3*dp))) (V2 dp dp)
            '9' -> do
                rendererDrawColor r $= g
                fillRect r $ Just $ Rectangle (P (V2 x y)) (V2 (3*dp) (5*dp))
                rendererDrawColor r $= black
                fillRect r $ Just $ Rectangle (P (V2 (x+1*dp) (y+1*dp))) (V2 dp dp)
                fillRect r $ Just $ Rectangle (P (V2 x (y+3*dp))) (V2 (2*dp) dp)
            '-' -> do
                rendererDrawColor r $= g
                fillRect r $ Just $ Rectangle (P (V2 x (y+2*dp))) (V2 (3*dp) dp)
            ',' -> do
                rendererDrawColor r $= g
                fillRect r $ Just $ Rectangle (P (V2 (x+1*dp) (y+4*dp))) (V2 dp (2*dp))
            _ -> return ()

        -- Draws a string left to right, advancing 4 cells (a 3-cell glyph plus a 1-cell gap) per character.
        drawCs g r (x,y) dp (c:cs) = do
            drawC g r (x,y) dp c
            drawCs g r (x+4*dp,y) dp cs
        drawCs _ _ _ _ [] = return ()

-- | Displays an undirected graph in an external window.
-- 
-- The graph is a @HashMap@ mapping each node to a sequence of neighbours.
--
-- It is assumed that any neighbour that a node in the graph points to is also in the graph.
--
-- Note that this function and @displayGraphDirected@ can be taxing on lower-end devices for well-connected graphs with more than a few hundred nodes.
displayGraphUndirected :: (Hashable n, Foldable t) => HM.HashMap n (t n) -> IO ()
displayGraphUndirected graph = displayGraph False graph
-- | Displays a directed graph in an external window.
--
-- The graph has the same shape as for @displayGraphUndirected@: a @HashMap@ mapping each node to a sequence of neighbours.
displayGraphDirected :: (Hashable n, Foldable t) => HM.HashMap n (t n) -> IO ()
displayGraphDirected graph = displayGraph True graph

-- | Shared implementation of 'displayGraphDirected' and 'displayGraphUndirected'.
--
-- Opens a resizable SDL window (initially 1000x1000) and lays the graph out
-- with a force-directed simulation that advances one step per frame. Each
-- node is:
--
-- * pulled towards its listed neighbours by springs,
-- * pushed away from nodes it is not adjacent to in either direction,
-- * slowed by friction and pulled weakly towards the origin.
--
-- Controls:
--
-- * left-click or left-drag a node to move it; grabbed nodes ignore the
--   simulation until the button is released,
-- * left-drag empty space to pan,
-- * scroll the mouse wheel to zoom around the cursor.
--
-- Returns once the window is closed.
--
-- The 'Bool' selects the edge style: 'True' draws arrows pointing at each
-- neighbour, and 'False' draws plain lines.
--
-- The simulation state per node is the tuple
-- @(selected, neighbours, position, displacement)@. The @displacement@ is
-- added to @position@ on the following step.
--
-- A viewport @V (x,y) s@ places world point @(x,y)@ at the window's top-left
-- pixel and maps @s@ world units onto 1000 pixels.
displayGraph :: (Hashable n, Foldable t) => Bool -> HM.HashMap n (t n) -> IO ()
displayGraph directed g = do
    initializeAll
    window <- createWindow (pack $ "Displaying " ++ show (HM.size g) ++ " Nodes") defaultWindow { windowInitialSize = V2 1000 1000, windowResizable = True}
    renderer <- createRenderer window (-1) defaultRenderer
    appLoop (stateFromGraph g) defaultViewPort renderer
    destroyWindow window
    where
        -- One frame: draw the current state, advance the simulation by one
        -- step, then apply this frame's mouse input (pan, drag, zoom) before
        -- recursing. Stops as soon as a quit event arrives.
        appLoop state viewport@(V (x,y) s) r = do
            events <- pollEvents
            let payloads = map eventPayload events
            clear r
            rendererDrawColor r $= V4 0 0 0 255
            fillRect r Nothing
            rendererDrawColor r $= V4 255 255 255 255
            P (V2 mx' my') <- getAbsoluteMouseLocation
            drawState viewport r state
            present r

            let state' = step state

            -- Left-drag moves the view with the cursor, unless a node is grabbed.
            -- (mx,my) is the cursor position in world coordinates after panning.
            let (mx,my) = pointToViewPort viewport' (mx',my')
                movement = listToMaybe $ mapMaybe (\case MouseMotionEvent d -> if ButtonLeft `elem` mouseMotionEventState d then Just $ mouseMotionEventRelMotion d else Nothing; _ -> Nothing) payloads
                viewport'@(V (x1,y1) s1) = case movement of
                    Just (V2 x' y') -> if any (\(a,_,_,_) -> a) state' then viewport else V (x-fromIntegral x'*s/1000,y-fromIntegral y'*s/1000) s
                    Nothing -> viewport

            -- Pressing (or holding) the left button grabs every node within 6
            -- world units of the cursor, and snaps all grabbed nodes onto it.
            -- Releasing the button drops every grabbed node.
            isHeld <- getMouseButtons
            let clicks = listToMaybe $ mapMaybe (\case MouseButtonEvent d -> if mouseButtonEventButton d == ButtonLeft then Just $ mouseButtonEventMotion d else Nothing; _ -> Nothing) payloads
                state'' = case clicks of
                    Just Pressed -> selNodes
                    Just Released -> relNodes
                    Nothing -> if isHeld ButtonLeft then selNodes else state'
                    where
                        inNode (x',y') = max (abs (x'-mx)) (abs (y'-my)) < 6
                        selNodes = HM.map (\(a,b,c,d) -> if inNode c || a then (True,b,(mx,my),(0,0)) else (False,b,c,d)) state'
                        relNodes = HM.map (\(_,b,c,d) -> (False,b,c,d)) state'

            -- A wheel step of dy multiplies the scale by 2^(-dy/5). The viewport
            -- origin is shifted so the world point under the cursor stays fixed.
            let scale = listToMaybe $ mapMaybe (\case MouseWheelEvent d -> Just $ mouseWheelEventPos d; _ -> Nothing) payloads
                viewport'' = case scale of
                    Just (V2 _ dy) -> let sf = 2**(-(fromIntegral dy/5)) in V (x1*sf+mx*(1-sf),y1*sf+my*(1-sf)) (s1*sf)
                    Nothing -> viewport'

            if QuitEvent `elem` payloads then return () else appLoop state'' viewport'' r

        -- Draws a node as a hollow square centred on its world position. The
        -- white outer square has a half-width of 6 world units, the same radius
        -- inNode uses for clicks. A black square of half-width 3 sits inside it.
        -- Sizes use floating-point division. Rounding the scale to an integer
        -- first gives 0 once it falls below 0.5 (after zooming far in), and an
        -- integer `div` by it then throws a division-by-zero error.
        drawPoint' v r (x1,y1) = do
            rendererDrawColor r $= V4 255 255 255 255
            fillRect r $ Just $ Rectangle (P $ V2 (x1'-outer) (y1'-outer)) (V2 (2*outer) (2*outer))
            rendererDrawColor r $= V4 0 0 0 255
            fillRect r $ Just $ Rectangle (P $ V2 (x1'-inner) (y1'-inner)) (V2 (2*inner) (2*inner))
            where
                (x1',y1') = pointFromViewPort v (x1,y1)
                -- half-widths in pixels: 6 and 3 world units, scaled by 1000 / zoom
                outer = round (6000 / zoom v)
                inner = round (3000 / zoom v)
        -- Draws a directed edge from (x1,y1) to (x2,y2), both in world
        -- coordinates. The shaft is trimmed by 10 world units at each end, and a
        -- two-stroke arrowhead sits at the (x2,y2) end. Coincident endpoints,
        -- such as a self-loop, have no direction, so nothing is drawn instead of
        -- dividing by a zero length.
        drawArrow v r ((x1,y1),(x2,y2))
            | len == 0 = return ()
            | otherwise = mapM_ (drawLine' v r) ls
            where
                len = sqrt ((x2-x1)**2 + (y2-y1)**2)
                -- vector of length 10 pointing from (x1,y1) towards (x2,y2)
                (hx,hy) = ((x2-x1)*10/len,(y2-y1)*10/len)
                rotate45 (x,y) = ((x+y)/sqrt 2,(y-x)/sqrt 2)
                rotateMinus45 (x,y) = ((x-y)/sqrt 2,(y+x)/sqrt 2)
                (p,q) = rotate45 (hx,hy)
                (p',q') = rotateMinus45 (hx,hy)
                ls = [((x2-hx,y2-hy),(x2-hx-p/2,y2-hy-q/2)),((x2-hx,y2-hy),(x2-hx-p'/2,y2-hy-q'/2)),((x1+hx,y1+hy),(x2-hx,y2-hy))]
        -- Draws an undirected edge as a straight line between two world points.
        drawLine' v r ((x1,y1),(x2,y2)) = drawLine r (P $ V2 x1' y1') (P $ V2 x2' y2')
            where
                (x1',y1') = pointFromViewPort v (x1,y1)
                (x2',y2') = pointFromViewPort v (x2,y2)
        -- World -> pixel and pixel -> world conversions for a viewport.
        pointFromViewPort (V (x,y) s) (x',y') = (round $ (x'-x)*1000/ s,round $ (y'-y)*1000/ s)
        pointToViewPort (V (x,y) s) (x',y') = (fromIntegral x'*s/1000+x,fromIntegral y'*s/1000+y)
        -- Initial view: world (-100,100) at the top-left pixel, 200 world units across 1000 pixels.
        defaultViewPort = V (-100,100) 200

        -- Draws all edges first, then all nodes on top of them.
        drawState v r state = do mapM_ helper2 $ HM.toList state; rendererDrawColor r $= V4 255 255 255 255; mapM_ helper1 $ HM.toList state
            where
                helper1 (_,(_,_,(px,py),_)) = drawPoint' v r (px,py)
                helper2 (_,(_,ns,(px,py),_)) = mapM_ (\n' -> (if directed then drawArrow else drawLine') v r ((px,py),(\(_,_,p,_) -> p) $ fromMaybe' $ HM.lookup n' state)) ns

        -- Advances the layout by one step. Every node that is not grabbed moves
        -- by the displacement stored on the previous step. A new displacement is
        -- then computed from the forces in the current state: spring attraction
        -- to listed neighbours, repulsion from non-adjacent nodes, friction
        -- against the previous displacement, and a weak pull towards the origin.
        -- Grabbed nodes get a zero displacement.
        step ns = updatedNodes
            where
                -- Spring force on a node at (a,b) towards (c,d). The distance is
                -- floored at 5 and the scale factor is capped at 1000.
                findForce (a,b) (c,d) =
                    let mag = max 5 $ sqrt $ (a-c)**2 + (b-d)**2
                        scale = min 1000 $ springConstant*(mag-ideal)/mag
                    in (scale*(c-a),scale*(d-b))
                -- Repulsive force on a node at (a,b), directed away from (c,d).
                findRepulsion (a,b) (c,d) =
                    let mag = max 5 $ sqrt $ (a-c)**2 + (b-d)**2
                        scale = negate $ min 1000 $ repulsionConstant/mag**1.5
                    in (scale*(c-a),scale*(d-b))
                findNetForce n =
                    let (selected,neighbours,pos,vel) = fromMaybe' $ HM.lookup n ns
                        adjacent = neighbours
                        -- nodes not linked to n in either direction
                        notAdjacent = filter (\k -> k `notElem` neighbours && n `notElem` (\(_,b,_,_) -> b) (fromMaybe' $ HM.lookup k ns)) $ HM.keys ns
                        springs = foldl (\c n' -> let c' = (\(_,_,k,_) -> k) $ fromMaybe' $ HM.lookup n' ns in addV c $ findForce pos c') (0,0) adjacent
                        repulsion = foldl (\c n' -> let c' = (\(_,_,k,_) -> k) $ fromMaybe' $ HM.lookup n' ns in addV c $ findRepulsion pos c') (0,0) notAdjacent
                        induced = (-(friction*fst vel)-gravity*fst pos,-(friction*snd vel)-gravity*snd pos)
                    in if selected then (0,0) else addV induced $ addV springs repulsion
                updatedNodes = HM.mapWithKey (\k (b,n,p,v) -> (b,n,if b then p else addV p v,findNetForce k)) ns

        addV (a,b) (c,d) = (a+c,b+d)
        -- Tuning constants for the layout simulation.
        ideal = 0.000125
        springConstant = 0.005
        repulsionConstant = 0.25
        friction = 0.015
        gravity = 0.0005
        -- Every neighbour must also be a key of the graph; a missing one raises this error.
        fromMaybe' = fromMaybe (error "Node not in graph")
        -- Initial state: nothing grabbed, zero displacement, and each node on the
        -- unit circle at an angle derived from its hash.
        stateFromGraph g' = let n = HM.size g' in HM.mapWithKey (\k ns' -> (False,ns',(cos (pi * fromIntegral (hash k)/fromIntegral n),sin (pi * fromIntegral (hash k)/fromIntegral n)),(0,0))) g'