{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeFamilyDependencies #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE RecursiveDo #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS -Wno-orphans #-}
module Codec.Candid.TypTable where

import qualified Data.Map as M
import Control.Monad
import Control.Monad.State.Lazy
import Data.Void
import Prettyprinter
import Data.DList (singleton, DList)
import Data.Graph
import Data.Foldable

import Codec.Candid.Types

-- | A list of Candid types together with the table of named type
-- definitions that those types (and the table entries themselves) refer to
-- by key.
--
-- The key type @k@ is existentially hidden: a consumer only gets to use its
-- 'Pretty' and 'Ord' instances.
data SeqDesc where
    SeqDesc
        :: forall k. (Pretty k, Ord k)
        => M.Map k (Type k)  -- named definitions, indexed by key
        -> [Type k]          -- the described types
        -> SeqDesc

-- | Renders a 'SeqDesc' as a pair: the definition table (as an association
-- list in key order) and the list of described types.
instance Pretty SeqDesc where
    pretty (SeqDesc m ts) = pretty (M.toList m, ts)

-- | A named reference: a key paired with the definition it stands for.
--
-- The definition's own references are again 'Ref's, so a recursive type
-- can be represented by a cyclic value (as built by 'unrollTypeTable'').
-- The field is deliberately lazy so that such cycles can be constructed.
data Ref k f where
    Ref :: k -> f (Ref k f) -> Ref k f

-- | A reference is printed as its key alone. The attached definition is
-- never inspected, which matters because it may be cyclic.
instance Pretty k => Pretty (Ref k f) where
    pretty (Ref k _) = pretty k
-- | Two references are equal exactly when their keys are equal; the
-- attached definitions are never compared (they may be cyclic).
instance Eq k => Eq (Ref k f) where
    Ref k1 _ == Ref k2 _ = k1 == k2
-- | References are ordered by key alone, consistently with the 'Eq'
-- instance; the attached definitions are ignored.
instance Ord k => Ord (Ref k f) where
    compare (Ref k1 _) (Ref k2 _) = compare k1 k2

-- | Opens a 'SeqDesc' and passes its types to the continuation, with every
-- named reference replaced by a 'Ref' to its unrolled definition (see
-- 'unrollTypeTable'').
--
-- The continuation must be polymorphic in the key type, because the key
-- type inside a 'SeqDesc' is existentially hidden.
unrollTypeTable :: SeqDesc -> (forall k. (Pretty k, Ord k) => [Type (Ref k Type)] -> r) -> r
unrollTypeTable (SeqDesc m ts) k = k (unrollTypeTable' m ts)

-- | Replaces every key occurring in the given types (and in the table's own
-- entries) with @RefT (Ref key def)@, where @def@ is the key's definition,
-- itself already unrolled.
--
-- The unrolled table @m'@ is defined in terms of itself, so recursive
-- definitions become cyclic values. This only works because the lookup
-- inside 'Ref' is not forced until the definition is actually inspected.
--
-- Forcing the definition of a key that is absent from the table calls
-- 'error' with a descriptive message.
unrollTypeTable' :: forall k. Ord k => M.Map k (Type k) -> [Type k] -> [Type (Ref k Type)]
unrollTypeTable' m ts = ts'
  where
    f :: k -> Type (Ref k Type)
    f k = RefT (Ref k (M.findWithDefault danglingRef k m'))

    m' :: M.Map k (Type (Ref k Type))
    m' = (>>= f) <$> m

    ts' :: [Type (Ref k Type)]
    ts' = (>>= f) <$> ts

    -- Only evaluated if a dangling reference's definition is forced.
    danglingRef :: Type (Ref k Type)
    danglingRef = error "Codec.Candid.TypTable.unrollTypeTable': reference to a type that is not in the type table"

-- | Converts unrolled types back into a 'SeqDesc'. It walks the types,
-- replaces each 'Ref' with its key, and records the (likewise converted)
-- definition behind each key in the table.
--
-- Each key's definition is traversed only the first time the key is seen.
-- That keeps the walk finite on cyclic 'Ref' structures, assuming equal
-- keys carry equal definitions (only the first one encountered is kept).
buildSeqDesc :: forall k. (Pretty k, Ord k) => [Type (Ref k Type)] -> SeqDesc
buildSeqDesc ts = SeqDesc m ts'
  where
    (ts', m) = runState (mapM (mapM go) ts) mempty

    -- Registers the definition behind a reference (on first sight) and
    -- returns the reference's key.
    go :: Ref k Type -> State (M.Map k (Type k)) k
    go (Ref k t) = do
        seen <- gets (M.member k)
        unless seen $ mdo
            -- The key is inserted *before* recursing, so a cycle that leads
            -- back to k finds it already present and stops. The value t' is
            -- only produced by the recursive traversal below; recursive do
            -- (mfix of the lazy State monad) lets us insert it early.
            modify (M.insert k t')
            t' <- mapM go t
            return ()
        return k

-- | Replaces with 'EmptyT' every named type that lies on a cycle of
-- references passing only through records (as collected by 'underRec').
--
-- Presumably such types are treated as empty because a value of them would
-- have to contain itself through record fields alone.
-- NOTE(review): types that merely refer to such a cycle (without being on
-- it) are left unchanged.
voidEmptyTypes :: SeqDesc -> SeqDesc
voidEmptyTypes (SeqDesc m ts) = SeqDesc m' ts
  where
    -- One graph node per table entry, with an edge to every key reachable
    -- from its definition through records only.
    edges = [ (k, k, toList (underRec t)) | (k, t) <- M.toList m ]
    sccs = stronglyConnComp edges
    -- A single node with a self-edge is reported as a CyclicSCC as well,
    -- so direct self-reference is included.
    bad = concat [ ks | CyclicSCC ks <- sccs ]
    m' = foldl' (\acc k -> M.insert k EmptyT acc) m bad


-- | Collects the keys a type refers to without passing through anything
-- other than records. A bare reference yields its key, and a record yields
-- (recursively) the keys of its field types. Every other type constructor
-- yields nothing.
underRec :: Type k -> DList k
underRec (RefT x) = singleton x
underRec (RecT fs) = foldMap (underRec . snd) fs
underRec _ = mempty

-- | Takes a type description and replaces all named types with their
-- definitions.
--
-- This can produce an infinite type! Only use it in sufficiently lazy
-- contexts, or when the type is known not to be recursive.
--
-- Resolving a name that has no definition in the table calls 'error' with a
-- descriptive message.
tieKnot :: SeqDesc -> [Type Void]
tieKnot (SeqDesc m (ts :: [Type k])) = ts'
  where
    f :: k -> Type Void
    f k = M.findWithDefault danglingRef k m'

    -- Defined in terms of itself (through f); recursion is resolved lazily.
    m' :: M.Map k (Type Void)
    m' = (>>= f) <$> m

    ts' :: [Type Void]
    ts' = (>>= f) <$> ts

    -- Only evaluated if a dangling reference is actually resolved.
    danglingRef :: Type Void
    danglingRef = error "Codec.Candid.TypTable.tieKnot: reference to a type that is not in the type table"