Skip to content

Commit

Permalink
✨ Add function types
Browse files Browse the repository at this point in the history
semver: minor
  • Loading branch information
Somfic committed Dec 30, 2024
1 parent 77bbda6 commit 9b422e3
Show file tree
Hide file tree
Showing 9 changed files with 202 additions and 21 deletions.
2 changes: 1 addition & 1 deletion src/lexer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ impl<'de> Iterator for Lexer<'de> {
'~' => Ok((TokenKind::Tilde, TokenValue::None)),
'?' => Ok((TokenKind::Question, TokenValue::None)),
':' => Ok((TokenKind::Colon, TokenValue::None)),
'-' => Ok((TokenKind::Minus, TokenValue::None)),
'-' => self.parse_compound_operator(TokenKind::Minus, TokenKind::Arrow, '>'),
'+' => Ok((TokenKind::Plus, TokenValue::None)),
'*' => Ok((TokenKind::Star, TokenValue::None)),
'/' => Ok((TokenKind::Slash, TokenValue::None)),
Expand Down
3 changes: 3 additions & 0 deletions src/lexer/token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ pub enum TokenKind {
Dollar,
/// A tilde sign; `~`.
Tilde,
/// An arrow; `->`.
Arrow,
/// A question mark; `?`.
Question,
/// A pipe; `|`.
Expand Down Expand Up @@ -226,6 +228,7 @@ impl Display for TokenKind {
TokenKind::Hash => write!(f, "`#`"),
TokenKind::Dollar => write!(f, "`$`"),
TokenKind::Tilde => write!(f, "`~`"),
TokenKind::Arrow => write!(f, "`->`"),
TokenKind::Question => write!(f, "`?`"),
TokenKind::Pipe => write!(f, "`|`"),
TokenKind::Caret => write!(f, "`^`"),
Expand Down
10 changes: 6 additions & 4 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@ pub mod lexer;
pub mod parser;

const INPUT: &str = "
enum Test: a, b, c;
fn add(left ~ int, right ~ int) -> fn(int, int) -> int {
left + right
}
fn main() {
let x = true;
let is_event = \"even\" if x % 2 == 0 else \"odd\";
let a = 1;
let b = 2;
let c = add(a, b);
}
";

Expand Down
37 changes: 36 additions & 1 deletion src/parser/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,17 @@ impl<'de> Type<'de> {
}
}

pub fn function(span: SourceSpan, parameters: Vec<Type<'de>>, return_type: Type<'de>) -> Self {
Self {
value: TypeValue::Function {
parameters,
return_type: Box::new(return_type),
},
span,
original_span: None,
}
}

pub fn span(mut self, span: SourceSpan) -> Self {
if self.original_span.is_none() {
self.original_span = Some(self.span);
Expand All @@ -120,6 +131,10 @@ pub enum TypeValue<'de> {
Symbol(Cow<'de, str>),
Collection(Box<Type<'de>>),
Set(Box<Type<'de>>),
Function {
parameters: Vec<Type<'de>>,
return_type: Box<Type<'de>>,
},
}

impl<'de> TypeValue<'de> {
Expand Down Expand Up @@ -208,9 +223,29 @@ impl Display for TypeValue<'_> {
TypeValue::Decimal => write!(f, "a decimal"),
TypeValue::Character => write!(f, "a character"),
TypeValue::String => write!(f, "a string"),
TypeValue::Symbol(name) => write!(f, "{}", name),
TypeValue::Symbol(name) => write!(f, "`{}`", name),
TypeValue::Collection(element) => write!(f, "[{}]", element),
TypeValue::Set(element) => write!(f, "{{{}}}", element),
TypeValue::Function {
parameters,
return_type,
} => {
write!(
f,
"fn ({})",
parameters
.iter()
.map(|p| p.to_string())
.collect::<Vec<_>>()
.join(", "),
)?;

if !return_type.value.is_unit() {
write!(f, " -> {}", return_type)?;
}

Ok(())
}
}
}
}
Expand Down
4 changes: 4 additions & 0 deletions src/parser/ast/untyped.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,24 +198,28 @@ pub struct FunctionHeader<'de> {
pub name: Cow<'de, str>,
pub parameters: Vec<ParameterDeclaration<'de>>,
pub explicit_return_type: Option<Type<'de>>,
pub span: miette::SourceSpan,
}

#[derive(Debug, Clone)]
pub struct ParameterDeclaration<'de> {
pub name: Cow<'de, str>,
pub explicit_type: Type<'de>,
pub span: miette::SourceSpan,
}

#[derive(Debug, Clone)]
pub struct StructMemberDeclaration<'de> {
pub name: Cow<'de, str>,
pub explicit_type: Type<'de>,
pub span: miette::SourceSpan,
}

#[derive(Debug, Clone)]
pub struct EnumMemberDeclaration<'de> {
pub name: Cow<'de, str>,
pub value_type: Option<Type<'de>>,
pub span: miette::SourceSpan,
}

#[derive(Debug, Clone)]
Expand Down
1 change: 1 addition & 0 deletions src/parser/lookup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ impl Default for Lookup<'_> {
.add_type_handler(TokenKind::StringType, typing::string)
.add_type_handler(TokenKind::SquareOpen, typing::collection)
.add_type_handler(TokenKind::CurlyOpen, typing::set)
.add_type_handler(TokenKind::Function, typing::function)
.add_statement_handler(TokenKind::Return, statement::return_)
.add_statement_handler(TokenKind::If, statement::if_)
}
Expand Down
28 changes: 19 additions & 9 deletions src/parser/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ use super::{
statement, typing, Parser,
};
use crate::lexer::{Token, TokenKind, TokenValue};
use miette::{Context, Result};
use crate::parser::ast::CombineSpan;
use miette::{Context, Result, SourceSpan};

pub fn parse<'de>(parser: &mut Parser<'de>, optional_semicolon: bool) -> Result<Statement<'de>> {
let token = match parser.lexer.peek().as_ref() {
Expand Down Expand Up @@ -115,7 +116,7 @@ pub fn struct_<'de>(parser: &mut Parser<'de>) -> Result<Statement<'de>> {
.lexer
.expect(TokenKind::Identifier, "expected a field name")?;

let field = match field.value {
let field_name = match field.value {
TokenValue::Identifier(field) => field,
_ => unreachable!(),
};
Expand All @@ -125,7 +126,8 @@ pub fn struct_<'de>(parser: &mut Parser<'de>) -> Result<Statement<'de>> {
let explicit_type = typing::parse(parser, BindingPower::None)?;

fields.push(StructMemberDeclaration {
name: field,
span: SourceSpan::combine(vec![field.span, explicit_type.span]),
name: field_name,
explicit_type,
});
}
Expand Down Expand Up @@ -173,13 +175,14 @@ pub fn enum_<'de>(parser: &mut Parser<'de>) -> Result<Statement<'de>> {
.lexer
.expect(TokenKind::Identifier, "expected an enum member name")?;

let variant = match variant.value {
let variant_name = match variant.value {
TokenValue::Identifier(variant) => variant,
_ => unreachable!(),
};

variants.push(EnumMemberDeclaration {
name: variant,
span: variant.span,
name: variant_name,
value_type: None,
});
}
Expand Down Expand Up @@ -324,7 +327,7 @@ fn parse_function_header<'de>(parser: &mut Parser<'de>) -> Result<FunctionHeader
.lexer
.expect(TokenKind::Identifier, "expected a parameter name")?;

let parameter = match parameter.value {
let parameter_name = match parameter.value {
TokenValue::Identifier(parameter) => parameter,
_ => unreachable!(),
};
Expand All @@ -334,24 +337,31 @@ fn parse_function_header<'de>(parser: &mut Parser<'de>) -> Result<FunctionHeader
let explicit_type = typing::parse(parser, BindingPower::None)?;

parameters.push(ParameterDeclaration {
name: parameter,
span: SourceSpan::combine(vec![parameter.span, explicit_type.span]),
name: parameter_name,
explicit_type,
});
}

parser
let close = parser
.lexer
.expect(TokenKind::ParenClose, "expected a close parenthesis")?;

let explicit_return_type = match parser.lexer.peek_expect(TokenKind::Tilde) {
let explicit_return_type = match parser.lexer.peek_expect(TokenKind::Arrow) {
None => None,
Some(_) => {
parser.lexer.next();
Some(typing::parse(parser, BindingPower::None)?)
}
};

let mut spans = vec![token.span, close.span];
if let Some(explicit_return_type) = &explicit_return_type {
spans.push(explicit_return_type.span);
}

Ok(FunctionHeader {
span: SourceSpan::combine(spans),
name,
parameters,
explicit_return_type,
Expand Down
80 changes: 77 additions & 3 deletions src/parser/typechecker/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,38 @@ impl<'ast> TypeChecker<'ast> {
}
}
untyped::StatementValue::Function { header, body } => {
self.check_expression(body, environment);
environment.set(
header.name.clone(),
Type::function(
header.span,
header
.parameters
.iter()
.map(|p| p.explicit_type.clone())
.collect(),
header
.explicit_return_type
.clone()
.unwrap_or(Type::unit(header.span)),
),
);

header.parameters.iter().for_each(|p| {
environment.set(p.name.clone(), p.explicit_type.clone());
});

let implicit_return_type = self
.check_expression(body, environment)
.unwrap_or(Type::unit(body.span));

self.expect_match(
&header
.explicit_return_type
.clone()
.unwrap_or(Type::unit(header.span)),
&implicit_return_type,
"explicit and implicit return types must match".into(),
);
}
untyped::StatementValue::Return(expr) => {
self.check_expression(expr, environment);
Expand All @@ -66,7 +97,6 @@ impl<'ast> TypeChecker<'ast> {
}
untyped::StatementValue::Assignment { name, value } => {
if let Some(expression_type) = self.check_expression(value, environment) {
println!("{}: {:?}", name, expression_type);
environment.set(name.clone(), expression_type);
}
}
Expand Down Expand Up @@ -177,6 +207,50 @@ impl<'ast> TypeChecker<'ast> {

Ok(truthy)
}
untyped::ExpressionValue::Call { callee, arguments } => {
let callee = self.type_of(callee, environment)?;

match callee.clone().value {
TypeValue::Function {
parameters,
return_type,
} => {
if parameters.len() != arguments.len() {
return Err(vec![MietteDiagnostic {
code: None,
severity: None,
url: None,
labels: Some(callee.label("function call")),
help: Some(format!(
"expected {} arguments, but found {}",
parameters.len(),
arguments.len()
)),
message: "incorrect number of arguments".to_owned(),
}]);
}

for (parameter, argument) in parameters.iter().zip(arguments) {
let argument = self.type_of(argument, environment)?;
self.expect_match(
&parameter,
&argument,
"argument and parameter must match".into(),
);
}

Ok((return_type.span(callee.span)).clone())
}
_ => Err(vec![MietteDiagnostic {
code: None,
severity: None,
url: None,
labels: Some(callee.label("function call")),
help: Some("only functions may be called".into()),
message: "not a function".to_owned(),
}]),
}
}
_ => todo!("type_of: {:?}", expression),
}
}
Expand Down Expand Up @@ -263,7 +337,7 @@ impl<'ast> TypeChecker<'ast> {
}

fn expect_types(&mut self, ty: &Type<'ast>, expected: &[TypeValue], message: String) {
if !expected.iter().any(|e| *e == ty.value) {
if !expected.iter().any(|ex| ty.value.matches(ex)) {
let mut labels = vec![];
labels.extend(ty.label(format!("{}", ty)));

Expand Down
Loading

0 comments on commit 9b422e3

Please sign in to comment.