Skip to content

Commit

Permalink
Merge pull request #11 from fizruk/th-params
Browse files Browse the repository at this point in the history
Extend TH support to parametrised types
  • Loading branch information
fizruk authored Jun 17, 2024
2 parents 821f7f7 + 0120489 commit 27a262f
Show file tree
Hide file tree
Showing 16 changed files with 567 additions and 382 deletions.
11 changes: 1 addition & 10 deletions .github/workflows/haskell.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,21 +37,12 @@ jobs:
haskell/lambda-pi/src/Language/LambdaPi/Syntax/Lex.hs
haskell/lambda-pi/src/Language/LambdaPi/Syntax/Par.hs
- name: Check Syntax files exist
if: steps.restore-syntax-files.outputs.cache-hit == 'true'
shell: bash
id: check-syntax-files
run: |
source scripts/lib.sh
check_syntax_files_exist
printf "SYNTAX_FILES_EXIST=$SYNTAX_FILES_EXIST\n" >> $GITHUB_OUTPUT
- name: 🧰 Setup Stack
uses: freckle/stack-action@v5
with:
stack-build-arguments: --pedantic
stack-build-arguments-build: --dry-run
stack-build-arguments-test: --ghc-options -O2 ${{ steps.check-syntax-files.outputs.SYNTAX_FILES_EXIST == 'true' && ' ' || '--reconfigure --force-dirty --ghc-options -fforce-recomp' }}
stack-build-arguments-test: --ghc-options -O2 ${{ steps.restore-syntax-files.outputs.cache-hit == 'true' && ' ' || '--reconfigure --force-dirty --ghc-options -fforce-recomp' }}

- name: Save Syntax files
uses: actions/cache/save@v4
Expand Down
1 change: 1 addition & 0 deletions haskell/free-foil/free-foil.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ library
Control.Monad.Foil.TH.MkFromFoil
Control.Monad.Foil.TH.MkInstancesFoil
Control.Monad.Foil.TH.MkToFoil
Control.Monad.Foil.TH.Util
Control.Monad.Free.Foil
Control.Monad.Free.Foil.Example
other-modules:
Expand Down
57 changes: 28 additions & 29 deletions haskell/free-foil/src/Control/Monad/Foil/TH/MkFoilData.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TemplateHaskellQuotes #-}
{-# OPTIONS_GHC -Wno-incomplete-patterns #-}
Expand All @@ -9,6 +8,7 @@ module Control.Monad.Foil.TH.MkFoilData (mkFoilData) where
import Language.Haskell.TH

import qualified Control.Monad.Foil.Internal as Foil
import Control.Monad.Foil.TH.Util

-- | Generate scope-safe variants given names of types for the raw representation.
mkFoilData
Expand All @@ -20,20 +20,18 @@ mkFoilData
mkFoilData termT nameT scopeT patternT = do
n <- newName "n"
l <- newName "l"
TyConI (DataD _ctx _name _tvars _kind patternCons _deriv) <- reify patternT
TyConI (DataD _ctx _name _tvars _kind scopeCons _deriv) <- reify scopeT
TyConI (DataD _ctx _name _tvars _kind termCons _deriv) <- reify termT
TyConI (DataD _ctx _name patternTVars _kind patternCons _deriv) <- reify patternT
TyConI (DataD _ctx _name scopeTVars _kind scopeCons _deriv) <- reify scopeT
TyConI (DataD _ctx _name termTVars _kind termCons _deriv) <- reify termT

foilPatternCons <- mapM (toPatternCon n) patternCons
let foilScopeCons = map (toScopeCon n) scopeCons
let foilTermCons = map (toTermCon n l) termCons
foilPatternCons <- mapM (toPatternCon patternTVars n) patternCons
let foilScopeCons = map (toScopeCon scopeTVars n) scopeCons
let foilTermCons = map (toTermCon termTVars n l) termCons

return
[ DataD [] foilTermT [PlainTV n ()] Nothing foilTermCons []
, StandaloneDerivD Nothing [] (AppT (ConT ''Show) (AppT (ConT foilTermT) (VarT n)))
, DataD [] foilScopeT [PlainTV n ()] Nothing foilScopeCons [DerivClause Nothing [ConT ''Show]]
, DataD [] foilPatternT [PlainTV n (), PlainTV l ()] Nothing foilPatternCons []
, StandaloneDerivD Nothing [] (AppT (ConT ''Show) (AppT (AppT (ConT foilPatternT) (VarT n)) (VarT l)))
[ DataD [] foilTermT (termTVars ++ [KindedTV n () (PromotedT ''Foil.S)]) Nothing foilTermCons []
, DataD [] foilScopeT (scopeTVars ++ [KindedTV n () (PromotedT ''Foil.S)]) Nothing foilScopeCons []
, DataD [] foilPatternT (patternTVars ++ [KindedTV n () (PromotedT ''Foil.S), KindedTV l () (PromotedT ''Foil.S)]) Nothing foilPatternCons []
]
where
foilTermT = mkName ("Foil" ++ nameBase termT)
Expand All @@ -43,13 +41,14 @@ mkFoilData termT nameT scopeT patternT = do
-- | Convert a constructor declaration for a raw pattern type
-- into a constructor for the scope-safe pattern type.
toPatternCon
:: Name -- ^ Name for the starting scope type variable.
:: [TyVarBndr ()]
-> Name -- ^ Name for the starting scope type variable.
-> Con -- ^ Raw pattern constructor.
-> Q Con
toPatternCon n (NormalC conName params) = do
toPatternCon tvars n (NormalC conName params) = do
(lastScopeName, foilParams) <- toPatternConParams 1 n params
let foilConName = mkName ("Foil" ++ nameBase conName)
return (GadtC [foilConName] foilParams (AppT (AppT (ConT foilPatternT) (VarT n)) (VarT lastScopeName)))
return (GadtC [foilConName] foilParams (PeelConT foilPatternT (map (VarT . tvarName) tvars ++ [VarT n, VarT lastScopeName])))
where
-- | Process type parameters of a pattern,
-- introducing (existential) type variables for the intermediate scopes,
Expand All @@ -64,16 +63,16 @@ mkFoilData termT nameT scopeT patternT = do
case type_ of
-- if the current component is a variable identifier
-- then treat it as a single name binder (see 'Foil.NameBinder')
ConT tyName | tyName == nameT -> do
PeelConT tyName _tyParams | tyName == nameT -> do
l <- newName ("n" <> show i)
let type' = AppT (AppT (ConT ''Foil.NameBinder) (VarT p)) (VarT l)
(l', conParams') <- toPatternConParams (i+1) l conParams
return (l', (bang_, type') : conParams')
-- if the current component is a raw pattern
-- then convert it into a scope-safe pattern
ConT tyName | tyName == patternT -> do
PeelConT tyName tyParams | tyName == patternT -> do
l <- newName ("n" <> show i)
let type' = AppT (AppT (ConT foilPatternT) (VarT p)) (VarT l)
let type' = PeelConT foilPatternT (tyParams ++ [VarT p, VarT l])
(l', conParams') <- toPatternConParams (i+1) l conParams
return (l', (bang_, type') : conParams')
-- otherwise, ignore the component as non-binding
Expand All @@ -83,26 +82,26 @@ mkFoilData termT nameT scopeT patternT = do

-- | Convert a constructor declaration for a raw scoped term
-- into a constructor for the scope-safe scoped term.
toScopeCon :: Name -> Con -> Con
toScopeCon n (NormalC conName params) =
toScopeCon :: [TyVarBndr ()] -> Name -> Con -> Con
toScopeCon _tvars n (NormalC conName params) =
NormalC foilConName (map toScopeParam params)
where
foilConName = mkName ("Foil" ++ nameBase conName)
toScopeParam (_bang, ConT tyName)
| tyName == termT = (_bang, AppT (ConT foilTermT) (VarT n))
toScopeParam (_bang, PeelConT tyName tyParams)
| tyName == termT = (_bang, PeelConT foilTermT (tyParams ++ [VarT n]))
toScopeParam _bangType = _bangType

-- | Convert a constructor declaration for a raw term
-- into a constructor for the scope-safe term.
toTermCon :: Name -> Name -> Con -> Con
toTermCon n l (NormalC conName params) =
GadtC [foilConName] (map toTermParam params) (AppT (ConT foilTermT) (VarT n))
toTermCon :: [TyVarBndr ()] -> Name -> Name -> Con -> Con
toTermCon tvars n l (NormalC conName params) =
GadtC [foilConName] (map toTermParam params) (PeelConT foilTermT (map (VarT . tvarName) tvars ++ [VarT n]))
where
foilNames = [n, l]
foilConName = mkName ("Foil" ++ nameBase conName)
toTermParam (_bang, ConT tyName)
| tyName == patternT = (_bang, foldl AppT (ConT foilPatternT) (map VarT foilNames))
toTermParam (_bang, PeelConT tyName tyParams)
| tyName == patternT = (_bang, PeelConT foilPatternT (tyParams ++ map VarT foilNames))
| tyName == nameT = (_bang, AppT (ConT ''Foil.Name) (VarT n))
| tyName == scopeT = (_bang, AppT (ConT foilScopeT) (VarT l))
| tyName == termT = (_bang, AppT (ConT foilTermT) (VarT n))
| tyName == scopeT = (_bang, PeelConT foilScopeT (tyParams ++ [VarT l]))
| tyName == termT = (_bang, PeelConT foilTermT (tyParams ++ [VarT n]))
toTermParam _bangType = _bangType
52 changes: 30 additions & 22 deletions haskell/free-foil/src/Control/Monad/Foil/TH/MkFromFoil.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import Language.Haskell.TH
import Language.Haskell.TH.Syntax (addModFinalizer)

import qualified Control.Monad.Foil as Foil
import Control.Monad.Foil.TH.Util

-- | Generate conversion functions from raw to scope-safe representation.
mkFromFoil
Expand All @@ -19,34 +20,41 @@ mkFromFoil
-> Name -- ^ Type name for raw patterns.
-> Q [Dec]
mkFromFoil termT nameT scopeT patternT = do
TyConI (DataD _ctx _name _tvars _kind patternCons _deriv) <- reify patternT
TyConI (DataD _ctx _name _tvars _kind scopeCons _deriv) <- reify scopeT
TyConI (DataD _ctx _name _tvars _kind termCons _deriv) <- reify termT
n <- newName "n"
let ntype = return (VarT n)
l <- newName "l"
let ltype = return (VarT l)
r <- newName "r"
let rtype = return (VarT r)
TyConI (DataD _ctx _name patternTVars _kind patternCons _deriv) <- reify patternT
TyConI (DataD _ctx _name scopeTVars _kind scopeCons _deriv) <- reify scopeT
TyConI (DataD _ctx _name termTVars _kind termCons _deriv) <- reify termT

let termParams = map (VarT . tvarName) termTVars
let scopeParams = map (VarT . tvarName) scopeTVars
let patternParams = map (VarT . tvarName) patternTVars

fromFoilTermSignature <-
SigD fromFoilTermT <$>
[t| forall n.
[$(return (ConT nameT))]
-> Foil.NameMap n $(return (ConT nameT))
-> $(return (ConT foilTermT)) n
-> $(return (ConT termT))
[t| [$(return (ConT nameT))]
-> Foil.NameMap $ntype $(return (ConT nameT))
-> $(return (PeelConT foilTermT termParams)) $ntype
-> $(return (PeelConT termT termParams))
|]
fromFoilScopedSignature <-
SigD fromFoilScopedT <$>
[t| forall n.
[$(return (ConT nameT))]
-> Foil.NameMap n $(return (ConT nameT))
-> $(return (ConT foilScopeT)) n
-> $(return (ConT scopeT))
[t| [$(return (ConT nameT))]
-> Foil.NameMap $ntype $(return (ConT nameT))
-> $(return (PeelConT foilScopeT scopeParams)) $ntype
-> $(return (PeelConT scopeT scopeParams))
|]
fromFoilPatternSignature <-
SigD fromFoilPatternT <$>
[t| forall n l r.
[$(return (ConT nameT))]
-> Foil.NameMap n $(return (ConT nameT))
-> $(return (ConT foilPatternT)) n l
-> ([$(return (ConT nameT))] -> Foil.NameMap l $(return (ConT nameT)) -> $(return (ConT patternT)) -> r)
-> r
[t| [$(return (ConT nameT))]
-> Foil.NameMap $ntype $(return (ConT nameT))
-> $(return (PeelConT foilPatternT patternParams)) $ntype $ltype
-> ([$(return (ConT nameT))] -> Foil.NameMap $ltype $(return (ConT nameT)) -> $(return (PeelConT patternT patternParams)) -> $rtype)
-> $rtype
|]

addModFinalizer $ putDoc (DeclDoc fromFoilTermT)
Expand Down Expand Up @@ -86,7 +94,7 @@ mkFromFoil termT nameT scopeT patternT = do
conMatchBody = go 1 (VarE freshVars) (VarE env) (ConE conName) params

go _i _freshVars' _env' p [] = p
go i freshVars' env' p ((_bang, ConT tyName) : conParams)
go i freshVars' env' p ((_bang, PeelConT tyName _tyParams) : conParams)
| tyName == nameT =
go (i+1) freshVars' env' (AppE p (AppE (AppE (VarE 'Foil.lookupName) (VarE xi)) env')) conParams
| tyName == termT =
Expand Down Expand Up @@ -135,7 +143,7 @@ mkFromFoil termT nameT scopeT patternT = do
conMatchBody = go 1 (VarE freshVars) (VarE env) (ConE conName) params

go _i freshVars' env' p [] = AppE (AppE (AppE (VarE cont) freshVars') env') p
go i freshVars' env' p ((_bang, ConT tyName) : conParams)
go i freshVars' env' p ((_bang, PeelConT tyName _tyParams) : conParams)
| tyName == nameT =
CaseE freshVars'
[ Match (ListP []) (NormalB (AppE (VarE 'error) (LitE (StringL "not enough fresh variables")))) []
Expand Down Expand Up @@ -188,7 +196,7 @@ mkFromFoil termT nameT scopeT patternT = do
conMatchBody = go 1 (VarE freshVars) (VarE env) (ConE conName) params

go _i _freshVars' _env' p [] = p
go i freshVars' env' p ((_bang, ConT tyName) : conParams)
go i freshVars' env' p ((_bang, PeelConT tyName _tyParams) : conParams)
| tyName == nameT =
go (i+1) freshVars' env' (AppE p (AppE (AppE (VarE 'Foil.lookupName) (VarE xi)) env')) conParams
| tyName == termT =
Expand Down
17 changes: 9 additions & 8 deletions haskell/free-foil/src/Control/Monad/Foil/TH/MkInstancesFoil.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ module Control.Monad.Foil.TH.MkInstancesFoil (mkInstancesFoil) where
import Language.Haskell.TH

import qualified Control.Monad.Foil as Foil
import Control.Monad.Foil.TH.Util

-- | Generate 'Foil.Sinkable' and 'Foil.CoSinkable' instances.
mkInstancesFoil
Expand All @@ -17,18 +18,18 @@ mkInstancesFoil
-> Name -- ^ Type name for raw patterns.
-> Q [Dec]
mkInstancesFoil termT nameT scopeT patternT = do
TyConI (DataD _ctx _name _tvars _kind patternCons _deriv) <- reify patternT
TyConI (DataD _ctx _name _tvars _kind scopeCons _deriv) <- reify scopeT
TyConI (DataD _ctx _name _tvars _kind termCons _deriv) <- reify termT
TyConI (DataD _ctx _name patternTVars _kind patternCons _deriv) <- reify patternT
TyConI (DataD _ctx _name scopeTVars _kind scopeCons _deriv) <- reify scopeT
TyConI (DataD _ctx _name termTVars _kind termCons _deriv) <- reify termT

return
[ InstanceD Nothing [] (AppT (ConT ''Foil.Sinkable) (ConT foilScopeT))
[ InstanceD Nothing [] (AppT (ConT ''Foil.Sinkable) (PeelConT foilScopeT (map (VarT . tvarName) scopeTVars)))
[ FunD 'Foil.sinkabilityProof (map clauseScopedTerm scopeCons) ]

, InstanceD Nothing [] (AppT (ConT ''Foil.CoSinkable) (ConT foilPatternT))
, InstanceD Nothing [] (AppT (ConT ''Foil.CoSinkable) (PeelConT foilPatternT (map (VarT . tvarName) patternTVars)))
[ FunD 'Foil.coSinkabilityProof (map clausePattern patternCons) ]

, InstanceD Nothing [] (AppT (ConT ''Foil.Sinkable) (ConT foilTermT))
, InstanceD Nothing [] (AppT (ConT ''Foil.Sinkable) (PeelConT foilTermT (map (VarT . tvarName) termTVars)))
[ FunD 'Foil.sinkabilityProof (map clauseTerm termCons)]
]

Expand Down Expand Up @@ -58,7 +59,7 @@ mkInstancesFoil termT nameT scopeT patternT = do
mkConParamPattern _ i = VarP (mkName ("x" ++ show i))

go _i _rename' p [] = p
go i rename' p ((_bang, ConT tyName) : conParams)
go i rename' p ((_bang, PeelConT tyName _tyParams) : conParams)
| tyName == nameT =
go (i + 1) rename' (AppE p (AppE (VarE rename) (VarE xi))) conParams
| tyName == termT =
Expand Down Expand Up @@ -98,7 +99,7 @@ mkInstancesFoil termT nameT scopeT patternT = do
mkConParamPattern _ i = VarP (mkName ("x" ++ show i))

go _i rename' p [] = AppE (AppE (VarE cont) rename') p
go i rename' p ((_bang, ConT tyName) : conParams)
go i rename' p ((_bang, PeelConT tyName _tyParams) : conParams)
| tyName == nameT || tyName == patternT =
AppE
(AppE (AppE (VarE 'Foil.coSinkabilityProof) rename') (VarE xi))
Expand Down
Loading

0 comments on commit 27a262f

Please sign in to comment.