-
Notifications
You must be signed in to change notification settings - Fork 2
/
Inference.hs
279 lines (243 loc) · 9.42 KB
/
Inference.hs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
{-# LANGUAGE OverloadedStrings #-}
module Inference where
import Control.Monad.State
import qualified Data.Map as M
import qualified Data.Set as S
import System.IO.Unsafe
import Data.IORef
import qualified Data.ByteString.Char8
import Control.Applicative
import Data.Maybe
import Syntax
-- | A type scheme: @Forall vs ty@ quantifies the type variables @vs@ in
-- @ty@. @Forall [] ty@ (see 'monomorphic') is a plain monomorphic type.
data TypeSchema = Forall [Var] Type
-- | Typing environment: maps term variables to their type schemes.
type TEnv = M.Map Var TypeSchema
-- | Unification environment: the current bindings of type variables, as
-- recorded by 'addToUEnv' and followed by 'deref' / 'derefAll'.
type UEnv = M.Map Var Type -- unification environment
-- | The typechecking monad: unification state layered over 'IO' (IO is
-- used by the global unique-name supply behind 'mkUniqueTV').
type Typecheck = StateT UEnv IO
-- | Run a typechecking computation from an empty unification environment,
-- returning its result and discarding the final environment.
runTypecheck :: Typecheck a -> IO a
runTypecheck = flip evalStateT M.empty
-- | Wrap a type as a scheme with no quantified variables.
monomorphic :: Type -> TypeSchema
monomorphic ty = Forall [] ty
-- | Typecheck all top-level 'VarD' declarations of a module as one recursive
-- group. Other declaration forms are skipped, and the inferred schemes are
-- discarded.
typecheckModule :: TEnv -> Module -> Typecheck ()
typecheckModule tenv decls =
  () <$ typecheckRecursiveGroup tenv [ (v, e) | VarD v e <- decls ]
-- | Infer the type of an expression under the given typing environment.
--
-- Unifications are recorded in the 'Typecheck' state, so the returned type
-- may still mention variables that are bound there; apply 'derefAll' to
-- resolve it fully. Type errors are reported with 'fail'.
--
-- NOTE(review): every environment extension below is @M.union tenv ...@, and
-- 'M.union' is left-biased, so on a name clash the OUTER binding wins. This
-- is presumably fine only because variables carry unique ids (assigned by an
-- earlier renaming pass?) — confirm before relying on shadowing.
typecheckExpr :: TEnv -> Expr -> Typecheck Type
typecheckExpr tenv expr = case expr of
  App fun args -> do
    funType <- typecheckExpr tenv fun
    argTypes <- mapM (typecheckExpr tenv) args
    resultType <- mkUniqueTV
    -- the callee must be a function from the argument types to a fresh result
    unify funType $ funT argTypes resultType
    return resultType
  Let pat defn body -> do
    defnType <- typecheckExpr tenv defn
    (patType, bound) <- typecheckPat tenv pat
    unify defnType patType
    -- pattern variables are added monomorphically: 'Let' does not
    -- generalize (within expressions only 'Letrec' does)
    typecheckExpr (M.union tenv $ fmap monomorphic bound) body
  Letrec binds body -> do
    additionalEnv <- typecheckRecursiveGroup tenv binds
    let finalEnv = M.union tenv additionalEnv
    typecheckExpr finalEnv body
  Lam params body -> do
    -- each parameter gets a fresh, monomorphic type variable
    paramTypes <- mapM (const mkUniqueTV) params
    let bodyEnv = M.union tenv $ M.fromList $ zip params (map monomorphic paramTypes)
    resultType <- typecheckExpr bodyEnv body
    return $ funT paramTypes resultType
  Case scrt branches -> do
    scrtType <- typecheckExpr tenv scrt
    resultType <- mkUniqueTV
    -- every pattern must match the scrutinee's type and every branch body
    -- must agree on one common result type
    forM_ branches $ \b -> do
      (patType, bodyType) <- typecheckBranch tenv b
      unify scrtType patType
      unify resultType bodyType
    return resultType
  Ref var -> case M.lookup var tenv of
    Nothing -> fail $ "unbound variable: " ++ show var
    -- each use of a variable gets a fresh instance of its scheme
    Just ty -> instantiate ty
  Lit (IntL _) -> return intT
  Lit (StringL _) -> return stringT
  Assign pat body -> do
    bodyType <- typecheckExpr tenv body
    (patType, binds) <- typecheckPat tenv pat
    unify bodyType patType
    -- assignment only updates existing variables, and only monomorphic ones;
    -- the new value's type must agree with the variable's declared type
    forM_ (M.toList binds) $ \(var, ty) -> case M.lookup var tenv of
      Nothing -> fail "attempt to assign to nonexistent variable"
      Just (Forall [] realTy) -> unify ty realTy
      Just _ -> fail "attempt to assign to a polymorphic variable"
    return unitT
  While cond body -> do
    -- the condition must be bool and the body unit; the loop itself is unit
    unify boolT =<< typecheckExpr tenv cond
    unify unitT =<< typecheckExpr tenv body
    return unitT
  -- embedded JavaScript is not checked: it gets an unconstrained fresh type
  JS _ -> mkUniqueTV
  Typed body ty -> do
    inferred <- typecheckExpr tenv body
    -- the annotation is unified with the inferred type; the inferred type
    -- (equal to the annotation up to unification) is returned
    unify inferred ty
    return inferred
-- | Infer types for a group of (possibly mutually) recursive bindings and
-- return an environment holding their generalized schemes.
--
-- Each binder first gets a fresh monomorphic type variable, so the
-- definitions can refer to one another. After all definitions have been
-- checked, each binder's type is fully dereferenced and then generalized
-- over the variables that are not free in the surrounding environment.
typecheckRecursiveGroup :: TEnv -> [(Var, Expr)] -> Typecheck TEnv
typecheckRecursiveGroup tenv binds = do
  let (vars, defns) = unzip binds
  temporaryVarTypes <- mapM (const mkUniqueTV) binds
  let temporaryEnv = M.union tenv $ M.fromList $
        zip vars (map monomorphic temporaryVarTypes)
  forM_ (zip defns temporaryVarTypes) $ \(defn, ty) ->
    unify ty =<< typecheckExpr temporaryEnv defn
  -- The free type variables of the outer environment may already be bound
  -- in the unification state, so they must be resolved the same way as the
  -- binder types before comparing. Otherwise a variable that an outer
  -- monomorphic binding now refers to indirectly would be wrongly
  -- quantified, which is unsound.
  resolvedEnv <- traverse resolveSchema tenv
  finalVarTypes <- forM temporaryVarTypes $ \tmpType -> do
    tmpType' <- derefAll tmpType
    return $! generalize resolvedEnv tmpType'
  return $! M.fromList $ zip vars finalVarTypes
  where
    -- | Resolve a scheme's body through the current unification bindings.
    resolveSchema :: TypeSchema -> Typecheck TypeSchema
    resolveSchema (Forall vs ty) = Forall vs <$> derefAll ty
-- | Quantify a type over every type variable that is not free in the given
-- environment.
--
-- This is pure and does not consult the unification state, so the caller
-- must pass an already-resolved type and environment. The quantified list
-- is forced before the scheme is built.
generalize :: TEnv -> Type -> TypeSchema
generalize tenv ty =
  let quantified = S.toList (independentVariables tenv ty)
  in quantified `seq` Forall quantified ty
-- | Type variables of @ty@ that do not occur free anywhere in @tenv@.
independentVariables :: TEnv -> Type -> S.Set Var
independentVariables tenv ty =
  typeVariables ty `S.difference` foldMap freeVariables tenv
-- | Type variables of a scheme's body that are not bound by its quantifier.
freeVariables :: TypeSchema -> S.Set Var
freeVariables (Forall quantified body) =
  typeVariables body S.\\ S.fromList quantified
-- | All type variables that occur anywhere in a type.
typeVariables :: Type -> S.Set Var
typeVariables (VarT var) = S.singleton var
typeVariables ConstT{} = S.empty
typeVariables (AppT f xs) = foldMap typeVariables (f : xs)
-- | Replace each quantified variable of a scheme with a fresh type variable.
instantiate :: TypeSchema -> Typecheck Type
instantiate (Forall quantified body) = do
  freshTypes <- traverse (const mkUniqueTV) quantified
  let renaming = M.fromList (zip quantified freshTypes)
  return $! substitute renaming body
-- | Apply a variable-to-type substitution throughout a type. Variables that
-- are not in the map are left unchanged.
substitute :: M.Map Var Type -> Type -> Type
substitute trans ty@(VarT var) = M.findWithDefault ty var trans
substitute _ ty@(ConstT _) = ty
substitute trans (AppT f xs) = AppT (go f) (map go xs)
  where
    go = substitute trans
-- | Typecheck one @case@ branch and return @(pattern type, body type)@.
-- The body is checked with the pattern's variables in scope (monomorphic).
typecheckBranch :: TEnv -> Branch -> Typecheck (Type, Type)
typecheckBranch tenv (pat, body) = do
  (patType, patBindings) <- typecheckPat tenv pat
  let branchEnv = M.union tenv (M.map monomorphic patBindings)
  (,) patType <$> typecheckExpr branchEnv body
-- | Infer the type of a pattern and collect the variables it binds.
--
-- Returns the pattern's type together with a map from each bound variable
-- to its monomorphic type. A constructor pattern instantiates the
-- constructor's scheme from the environment and unifies it with a function
-- from the sub-pattern types to the pattern's result type. Fails on an
-- unknown constructor.
typecheckPat :: TEnv -> Pat -> Typecheck (Type, M.Map Var Type)
typecheckPat tenv pat = case pat of
  VarP var -> do
    resultType <- mkUniqueTV
    return (resultType, M.singleton var resultType)
  ConstructorP cons args -> do
    resultType <- mkUniqueTV
    (argTypes, argEnvs) <- unzip <$> mapM (typecheckPat tenv) args
    case M.lookup cons tenv of
      Nothing -> fail $ "pattern match on unknown constructor: " ++ show cons
      Just consSchema -> do
        consType <- instantiate consSchema
        unify consType $ funT argTypes resultType
        -- NOTE(review): M.unions is left-biased, so a variable bound twice in
        -- one pattern silently keeps its first binding; non-linear patterns
        -- are presumably rejected earlier — confirm.
        return (resultType, M.unions argEnvs)
  IntP _ -> return (intT, M.empty)
  StringP _ -> return (stringT, M.empty)
  UnitP -> return (unitT, M.empty)
  WildcardP -> do
    resultType <- mkUniqueTV
    return (resultType, M.empty)
-- | Unify two types, recording variable bindings in the unification state
-- and failing (via 'fail') on a mismatch.
--
-- Clause order matters: the two 'VarT' clauses come first, so the remaining
-- clauses never see a bare type variable on either side.
unify :: Type -> Type -> Typecheck ()
unify (VarT x) ty = unifyVar x ty
unify ty (VarT x) = unifyVar x ty
unify (AppT f xs) ty = case ty of
  -- applications unify head-wise and argument-wise; differing argument
  -- counts fail the guard and fall through to the failure branch
  AppT g ys
    | length xs == length ys -> do
        unify f g
        sequence_ $ zipWith unify xs ys
  _ -> fail "unification failed"
unify (ConstT v) ty = case ty of
  ConstT w
    | v == w -> return ()
  _ -> fail $ "unification on const failed: " ++ show v ++ " <-> " ++ show ty
-- | Unify type variable @x@ with @ty@.
--
-- @x@ is first followed through its existing bindings. If it resolves to a
-- non-variable type, that type is unified with @ty@. Otherwise the final
-- variable @v@ in the chain is bound to @ty@, unless @ty@ already resolves
-- to @v@ itself. The occurs check rejects bindings that would build an
-- infinite type.
unifyVar :: Var -> Type -> Typecheck ()
unifyVar x ty = do
  xty <- deref x
  case xty of
    VarT v -> do
      ty' <- derefAll ty
      when (ty' /= VarT v) $ do
        when (occurs v ty') $ fail $
          "occurs check failed: cannot unify " ++ show (v, ty')
        -- the un-dereferenced ty is stored; later lookups follow the chain
        addToUEnv v ty
    _ -> unify xty ty
-- | Does the type variable @v@ appear anywhere inside the type?
occurs :: Var -> Type -> Bool
occurs v (VarT u) = v == u
occurs _ ConstT{} = False
occurs v (AppT f xs) = any (occurs v) (f : xs)
-- | Record the binding @var := ty@ in the unification state, replacing any
-- previous binding for @var@. The new map is forced before it is stored.
addToUEnv :: Var -> Type -> Typecheck ()
addToUEnv var ty = get >>= \current -> put $! M.insert var ty current
-- | Follow the chain of variable-to-variable bindings starting at @x@.
-- Returns the first non-variable type bound along the chain, or the last
-- (unbound) variable. Sub-terms of that type are not resolved.
deref :: Var -> Typecheck Type
deref x = do
  binding <- M.lookup x <$> get
  case binding of
    Just (VarT next) -> deref next
    Just bound -> return bound
    Nothing -> return (VarT x)
-- | Resolve every bound type variable in a type, recursively, using the
-- current unification state.
derefAll :: Type -> Typecheck Type
derefAll ty = case ty of
  VarT v -> do
    binding <- M.lookup v <$> get
    maybe (return ty) derefAll binding
  ConstT{} -> return ty
  AppT f xs -> AppT <$> derefAll f <*> traverse derefAll xs
-- | Build an n-ary function type: the @->@ constructor applied to the
-- argument types followed by the result type.
funT :: [Type] -> Type -> Type
funT argTypes resultType = AppT (ConstT v_fun) (foldr (:) [resultType] argTypes)
-- | Built-in base types.
intT :: Type
intT = ConstT v_int
stringT :: Type
stringT = ConstT v_string
unitT :: Type
unitT = ConstT v_unit
boolT :: Type
boolT = ConstT v_bool
-- Reserved names for the built-in type constructors. They use negative ids,
-- presumably so they cannot clash with ids from 'mkUniqueInt', which counts
-- up from 0.
v_fun, v_int, v_string, v_unit, v_bool :: Var
v_fun = Var (-1) "->"
v_int = Var (-2) "int"
v_string = Var (-3) "string"
v_unit = Var (-4) "()"
v_bool = Var (-5) "bool"
-- | A fresh type variable drawn from the global unique supply.
mkUniqueTV :: Typecheck Type
mkUniqueTV = VarT <$> liftIO mkUniqueVar
-- | A fresh anonymous variable (empty name) with a unique id.
mkUniqueVar :: IO Var
mkUniqueVar = fmap (\n -> Var n "") mkUniqueInt
-- | Return the current value of the global counter and advance it by one.
mkUniqueInt :: IO Int
mkUniqueInt = atomicModifyIORef' uniqueSupply (\n -> (n + 1, n))
-- | Global counter behind 'mkUniqueInt', starting at 0.
--
-- 'unsafePerformIO' lets the IORef be a top-level value. The NOINLINE pragma
-- stops GHC from inlining it, which could otherwise create a separate IORef
-- at each use site. Nothing in this module resets the counter, so ids keep
-- increasing across separate 'runTypecheck' calls.
{-# NOINLINE uniqueSupply #-}
uniqueSupply :: IORef Int
uniqueSupply = unsafePerformIO $ newIORef 0
---- tests
-- | Test helper: an anonymous variable with the given id.
tc_var :: Int -> Var
tc_var n = Var n ""
-- | Test helper: a reference to @tc_var n@.
tc_ref :: Int -> Expr
tc_ref n = Ref (tc_var n)
-- | Sample expressions for exercising 'test'.
tc_expr1, tc_expr2, tc_expr3, tc_expr4, tc_expr5, tc_expr6, tc_expr7,
  tc_expr8, tc_expr9, tc_expr10, tc_expr11, tc_expr12, tc_expr13 :: Expr
tc_expr1 = Lam [tc_var 0] (tc_ref 0) -- \x -> x
-- expected to fail the occurs check
tc_expr2 = Lam [tc_var 0] (App (tc_ref 0) [tc_ref 0]) -- \x -> x x
tc_expr3 = Lam [tc_var 0, tc_var 1] (tc_ref 0) -- \x y -> x
tc_expr4 = Lam [tc_var 0, tc_var 1] (App (tc_ref 1) [tc_ref 0]) -- \x y -> y x
tc_expr5 = Lam [tc_var 0] (App (tc_ref 0) [Lit (IntL 100)]) -- \x -> x 100
-- expected to fail: an int cannot be applied
tc_expr6 = Lam [tc_var 0] (App (Lit (IntL 100)) [tc_ref 0]) -- \x -> 100 x
tc_expr7 = Lam [tc_var 0] (Typed (App (JS "console.log") [tc_ref 0]) unitT) -- \x -> JS"console.log" x :: ()
tc_expr8 = Lam [tc_var 0] (Case (tc_ref 0) [(IntP 0, (Lit (IntL 100))), (WildcardP, tc_ref 0)])
-- \x -> case x of 0 -> 100; _ -> x
tc_expr9 = Lam [tc_var 0] (Case (tc_ref 0) [(VarP (tc_var 1), tc_ref 1)])
-- \x -> case x of y -> y
tc_expr10 = Lam [tc_var 0] (Assign (VarP (tc_var 0)) (Lit (IntL 3)))
-- \x -> x := 3
tc_expr11 = Lam [tc_var 0, tc_var 1] $ Let (VarP (tc_var 2)) (App (tc_ref 0) [tc_ref 1]) (App (tc_ref 0) [tc_ref 2])
-- \x y -> let z = x y in x z
tc_expr12 = Letrec [(tc_var 0, Lam [tc_var 1, tc_var 2, tc_var 3] $ Case (App (tc_ref 1) [tc_ref 3]) [(IntP 0, tc_ref 3), (WildcardP, App (tc_ref 0) [tc_ref 1, tc_ref 2, App (tc_ref 2) [tc_ref 3]])])] (tc_ref 0)
-- letrec a = \b c d -> case b d of 0 -> d; _ -> a b c (c d) in a
tc_expr13 = Letrec [(tc_var 0, Lam [tc_var 1] (tc_ref 1))] $ App (App (tc_ref 0) [tc_ref 0]) [Lit (IntL 4)]
-- letrec a = \b -> b in (a a) 4
-- | Infer the fully resolved type of a closed expression, in a fresh
-- unification state.
test :: Expr -> IO Type
test expr = runTypecheck (derefAll =<< typecheckExpr M.empty expr)