random things

This commit is contained in:
tezlm 2023-09-29 15:55:26 -07:00
commit 9c73e38ac2
Signed by: tezlm
GPG key ID: 649733FCD94AFBBA
9 changed files with 1054 additions and 0 deletions

1
.gitignore vendored Normal file
View file

@ -0,0 +1 @@
/target

7
Cargo.lock generated Normal file
View file

@ -0,0 +1,7 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3
[[package]]
name = "lang"
version = "0.1.0"

8
Cargo.toml Normal file
View file

@ -0,0 +1,8 @@
[package]
name = "lang"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]

32
print.wat Normal file
View file

@ -0,0 +1,32 @@
(module
;; import from wasi
;; fn fd_write(fd, *iovs, iovs_len, nwritten) -> bytes_written
(import "wasi_unstable" "fd_write" (func $fd_write (param i32 i32 i32 i32) (result i32)))
;; create memory (size = 1 page = 64KiB)
(memory $foobar 1)
;; export memory - it's required, but we don't use it so the size is set to 0
(export "memory" (memory 0))
;; write string to memory (offset = 8 bytes)
(data (i32.const 8) "Hello, world!\n")
(func $main (export "_start")
;; iov.iov_base - pointer to string (offset = 0 bytes)
;; the string's offset is 8 bytes in memory
(i32.store (i32.const 0) (i32.const 8))
;; iov.iov_len - length of the hello world string (offset = 4 bytes)
;; the string's length is 14 bytes
(i32.store (i32.const 4) (i32.const 14))
(call $fd_write
(i32.const 1) ;; fd: stdout = 1
(i32.const 0) ;; data: pointer to memory - this is the first memory we create (index 0)
(i32.const 1) ;; data_len: there's 1 string
(i32.const 2468) ;; nwritten: i don't care about this, write it wherever
)
drop ;; drop number of bytes written
)
)

16
src/error.rs Normal file
View file

@ -0,0 +1,16 @@
#[derive(Debug)]
pub enum Error {
SyntaxError(String),
TypeError(String),
ReferenceError(String),
}
impl Error {
pub fn syn(what: &'static str) -> Error {
Error::SyntaxError(what.to_string())
}
pub fn ty(what: &'static str) -> Error {
Error::TypeError(what.to_string())
}
}

160
src/generator.rs Normal file
View file

@ -0,0 +1,160 @@
/*
optimizations
- use i32.eqz when comparing to zero
- write negative numberss directly instead of as positive + sign flip
*/
use crate::parser::{Expr, Literal, BinaryOp, PrefixOp, Statement, Context};
pub fn generate(expr: &Expr) {
println!();
println!();
let mut ctx = Context::new();
let mut exprs = Vec::new();
get_locals(expr, &mut ctx, &mut exprs);
// println!(r#"(module (func (export "_start") (local $match i32)"#);
println!(r#"(module (func (export "_start") (result i32) (local $match i32)"#);
// println!(r#"(module (func (export "_start") (result f64) (local $match f64)"#);
for (name, _) in &exprs {
let ty = match ctx.locals.get(name).unwrap() {
crate::parser::Type::Integer => "i32",
crate::parser::Type::Float => "f64",
_ => todo!(),
};
println!("(local ${name} {ty})");
}
for (name, expr) in &exprs {
gen_expr(expr, &ctx);
println!("(local.set ${name})");
}
gen_expr(&expr, &ctx);
// println!("drop");
println!("))");
}
fn gen_expr(expr: &Expr, ctx: &Context) {
match expr {
Expr::Literal(lit) => match lit {
Literal::Integer(int) => println!("i32.const {int}"),
Literal::Float(f) => println!("f64.const {f}"),
Literal::Boolean(b) => println!("i32.const {}", if *b { 1 } else { 0 }),
_ => todo!(),
}
Expr::Variable(name) => {
println!("local.get ${name}");
}
Expr::Binary(op, a, b) => {
gen_expr(a, ctx);
gen_expr(b, ctx);
let ty = match expr.infer(&ctx).unwrap() {
crate::parser::Type::Integer => "i32",
crate::parser::Type::Float => "f64",
crate::parser::Type::Boolean => "i32",
_ => todo!(),
};
match op {
BinaryOp::Add => println!("{ty}.add"),
BinaryOp::Mul => println!("{ty}.mul"),
BinaryOp::Sub => println!("{ty}.sub"),
BinaryOp::Div => println!("{ty}.div_u"), // do i _u or _s?
BinaryOp::Mod => println!("{ty}.rem_u"),
BinaryOp::Eq => println!("{ty}.eq"),
BinaryOp::Neq => println!("{ty}.neq"),
BinaryOp::Less => println!("{ty}.lt_u"),
BinaryOp::Greater => println!("{ty}.gt_u"),
_ => todo!(),
}
}
Expr::Unary(op, e) => {
gen_expr(e, ctx);
match op {
PrefixOp::Minus => {
// this is so inefficent, but i don't care
println!("i32.const -1");
println!("i32.mul");
}
PrefixOp::LogicNot => {
println!("i32.eqz");
}
PrefixOp::BitNot => {
// TODO: do i flip the sign bit?
println!("i32.const {}", i32::MAX);
println!("i32.xor");
}
}
}
Expr::Match(cond, arms) => {
println!(";; --- set match variable");
println!("(local.set $match (");
gen_expr(cond, ctx);
println!("))");
println!(";; --- generate match");
for (idx, (pat, expr)) in arms.iter().enumerate() {
// FIXME: hardcoded until patern matching works better
match pat {
crate::parser::Pattern::Literal(lit) => match lit {
Literal::Integer(int) => println!("i32.const {}", int),
Literal::Boolean(b) => println!("i32.const {}", if *b { 1 } else { 0 }),
_ => todo!(),
}
};
println!("local.get $match");
println!("i32.eq");
println!("(if (result i32) (then");
gen_expr(expr, ctx);
if idx == arms.len() - 1 {
// TODO: verify its actually unreachable earlier on
println!(") (else unreachable");
} else {
println!(") (else");
}
}
println!("{}", ")".repeat(arms.len() * 2));
println!(";; --- done");
}
Expr::Block(b) => {
for (i, stmt) in b.0.iter().enumerate() {
match stmt {
Statement::Expr(expr) => {
gen_expr(expr, &ctx);
if i < b.0.len() - 1 {
println!("drop");
}
}
_ => {},
}
}
}
};
}
fn get_locals(expr: &Expr, ctx: &mut Context, exprs: &mut Vec<(String, Expr)>) {
match expr {
Expr::Block(b) => {
for stmt in &b.0 {
match stmt {
Statement::Let(name, expr) => {
let ty = expr.infer(ctx).unwrap();
ctx.locals.insert(name.clone(), ty);
exprs.push((name.clone(), expr.clone()));
}
Statement::Expr(expr) => get_locals(&expr, ctx, exprs),
}
}
},
Expr::Unary(_, expr) => get_locals(&expr, ctx, exprs),
Expr::Binary(_, a, b) => {
get_locals(&a, ctx, exprs);
get_locals(&b, ctx, exprs);
}
_ => (),
}
}

325
src/lexer.rs Normal file
View file

@ -0,0 +1,325 @@
use crate::Error;
pub struct Lexer {
input: Vec<char>,
pos: usize,
}
#[rustfmt::skip]
#[derive(Debug, PartialEq, Eq)]
pub enum Token {
Number { radix: u32, text: String },
Ident(String),
String(String),
Char(char),
OpenParan, CloseParan,
OpenBrace, CloseBrace,
OpenBracket, CloseBracket,
Plus, Minus, Star, DoubleStar, Slash, Percent,
Pipe, DoublePipe, And, DoubleAnd, Carat, Shl, Shr,
PlusSet, MinusSet, StarSet, DoubleStarSet, SlashSet, PercentSet,
PipeSet, DoublePipeSet, AndSet, DoubleAndSet, CaratSet, ShlSet, ShrSet,
Set, Eq, Neq, Less, LessEq, Greater, GreaterEq, Not,
Dot, DoubleDot, TripleDot, Comma, Question, Colon, DoubleColon, Semicolon,
ThinArrow, FatArrow,
Let, Const, Type, Fn,
True, False,
If, Else, Match,
While, Loop, For, Break, Continue,
}
impl Lexer {
pub fn new(input: String) -> Lexer {
Lexer {
input: input.chars().collect(),
pos: 0,
}
}
pub fn next(&mut self) -> Result<Option<Token>, Error> {
let Some(&ch) = self.input.get(self.pos) else {
return Ok(None);
};
let tok = match ch {
'0'..='9' => {
let token = self.lex_number()?;
if self
.input
.get(self.pos)
.is_some_and(|c| c.is_ascii_alphanumeric())
{
panic!("unexpected char");
}
token
}
'\'' => {
self.pos += 1;
let ch = self.lex_char()?;
if self.input.get(self.pos).is_some_and(|c| *c != '\'') {
panic!("expected '");
}
self.pos += 1;
Token::Char(ch)
}
'"' => Token::String(self.lex_string()?),
ch if ch.is_alphabetic() || ch == '_' => match self.lex_ident().as_str() {
"let" => Token::Let,
"const" => Token::Const,
"type" => Token::Type,
"fn" => Token::Fn,
"true" => Token::True,
"false" => Token::False,
"if" => Token::If,
"else" => Token::Else,
"match" => Token::Match,
"while" => Token::While,
"loop" => Token::Loop,
"for" => Token::For,
"break" => Token::Break,
"continue" => Token::Continue,
ident => Token::Ident(ident.to_string()),
},
ch if ch.is_whitespace() => {
self.pos += 1;
return self.next();
}
_ => self.lex_op()?,
};
Ok(Some(tok))
}
fn lex_number(&mut self) -> Result<Token, Error> {
let mut buffer = String::new();
let radix = match (self.input[self.pos], self.input.get(self.pos + 1)) {
('0', Some(ch)) if ch.is_digit(10) => 10,
('0', Some(ch)) => {
self.pos += 2;
match ch {
'x' => 16,
'o' => 8,
'b' => 2,
_ if !ch.is_ascii_alphanumeric() => {
self.pos -= 2;
10
},
_ => return Err(Error::SyntaxError(format!("unknown number radix {ch}"))),
}
}
_ => 10,
};
let mut saw_decimal = false;
while let Some(&ch) = self.input.get(self.pos) {
if ch == '.' && !saw_decimal {
saw_decimal = true;
buffer.push(ch);
self.pos += 1;
} else if ch.is_digit(radix) {
buffer.push(ch);
self.pos += 1;
} else {
break;
}
}
Ok(Token::Number {
radix,
text: buffer,
})
}
fn lex_string(&mut self) -> Result<String, Error> {
self.pos += 1; // take "
let mut buffer = String::new();
while let Some(&ch) = self.input.get(self.pos) {
if ch == '"' {
break;
}
buffer.push(self.lex_char()?);
}
if !self.input.get(self.pos).is_some_and(|c| *c == '"') {
panic!("expected \"");
}
self.pos += 1;
Ok(buffer)
}
#[inline]
fn lex_char(&mut self) -> Result<char, Error> {
let ch = match self.input.get(self.pos) {
Some('\\') => {
self.pos += 1;
let ch = self
.input
.get(self.pos)
.ok_or_else(|| Error::syn("expected escape char"))?;
match ch {
'n' => '\n',
't' => '\t',
'\'' => '\'',
'"' => '\"',
// 'x' => '\x',
// 'u' => '\u',
_ => return Err(Error::syn("unknown escape char")),
}
}
Some(ch) => *ch,
None => return Err(Error::syn("expected char")),
};
self.pos += 1;
Ok(ch)
}
fn lex_ident(&mut self) -> String {
let rest: String = self.input[self.pos..]
.iter()
.copied()
.take_while(|c| c.is_alphanumeric() || *c == '_')
.collect();
self.pos += rest.len();
rest
}
fn lex_op(&mut self) -> Result<Token, Error> {
let ch = self.input[self.pos];
macro_rules! settable {
($normal:expr, $set:expr) => {
match self.input.get(self.pos + 1) {
Some('=') => {
self.pos += 1;
$set
}
_ => $normal,
}
};
}
let token = match ch {
'(' => Token::OpenParan,
')' => Token::CloseParan,
'[' => Token::OpenBracket,
']' => Token::CloseBracket,
'{' => Token::OpenBrace,
'}' => Token::CloseBrace,
'+' => settable!(Token::Plus, Token::PlusSet),
'-' => match self.input.get(self.pos + 1) {
Some('>') => {
self.pos += 1;
Token::ThinArrow
},
Some('=') => {
self.pos += 1;
Token::MinusSet
}
_ => Token::Minus,
},
'*' => match self.input.get(self.pos + 1) {
Some('*') => {
self.pos += 1;
settable!(Token::DoubleStar, Token::DoubleStarSet)
},
Some('=') => {
self.pos += 1;
Token::StarSet
}
_ => Token::Star,
},
// TODO: comments
'/' => settable!(Token::Slash, Token::SlashSet),
'%' => settable!(Token::Percent, Token::PercentSet),
'|' => match self.input.get(self.pos + 1) {
Some('|') => {
self.pos += 1;
settable!(Token::DoublePipe, Token::DoublePipeSet)
},
Some('=') => {
self.pos += 1;
Token::PipeSet
}
_ => Token::Pipe,
},
'&' => match self.input.get(self.pos + 1) {
Some('&') => {
self.pos += 1;
settable!(Token::DoubleAnd, Token::DoubleAndSet)
}
Some('=') => {
self.pos += 1;
Token::AndSet
}
_ => Token::And,
},
'^' => settable!(Token::Carat, Token::CaratSet),
'=' => match self.input.get(self.pos + 1) {
Some('=') => {
self.pos += 1;
Token::Eq
}
Some('>') => {
self.pos += 1;
Token::FatArrow
}
_ => Token::Set,
},
'!' => match self.input.get(self.pos + 1) {
Some('=') => {
self.pos += 1;
Token::Neq
}
_ => Token::Not,
},
'<' => match self.input.get(self.pos + 1) {
Some('=') => {
self.pos += 1;
Token::LessEq
}
Some('<') => {
self.pos += 1;
settable!(Token::Shl, Token::ShlSet)
}
_ => Token::Less,
},
'>' => match self.input.get(self.pos + 1) {
Some('=') => {
self.pos += 1;
Token::GreaterEq
}
Some('>') => {
self.pos += 1;
settable!(Token::Shr, Token::ShrSet)
}
_ => Token::Greater,
},
'.' => match self.input.get(self.pos + 1) {
Some('.') => match self.input.get(self.pos + 1) {
Some('.') => {
self.pos += 2;
Token::TripleDot
}
_ => {
self.pos += 1;
Token::DoubleDot
},
},
_ => Token::Dot,
},
':' => match self.input.get(self.pos + 1) {
Some(':') => {
self.pos += 1;
Token::DoubleColon
}
_ => Token::Colon,
},
',' => Token::Comma,
';' => Token::Semicolon,
'?' => Token::Question,
_ => return Err(Error::SyntaxError(format!("unexpected character {}", ch))),
};
self.pos += 1;
Ok(token)
}
}

59
src/main.rs Normal file
View file

@ -0,0 +1,59 @@
/*
typechecking is somewhat embedded in the parser and gets essentially run
a second time when generating (so the types are known), there should be
a better way
*/
mod error;
mod generator;
mod lexer;
mod parser;
pub use error::Error;
use parser::Context;
fn main() {
let mut lexer = lexer::Lexer::new("!{ let foo = 8; let bar = foo * 3; foo + bar < 10 }".into());
let mut tokens = vec![];
loop {
match lexer.next() {
Ok(None) => break,
Ok(Some(token)) => tokens.push(token),
Err(error) => {
eprintln!("error: {:?}", error);
return;
}
}
}
// dbg!(&tokens);
let mut parser = parser::Parser::new(tokens);
let mut statements = vec![];
loop {
match parser.next() {
Ok(None) => break,
Ok(Some(tree)) => {
dbg!(&tree);
match &tree {
parser::Statement::Let(..) => todo!(),
parser::Statement::Expr(expr) => match expr.infer(&Context::new()) {
Ok(ty) => eprintln!("type: {:?}", ty),
Err(err) => eprintln!("err: {:?}", err),
},
};
statements.push(tree);
}
Err(error) => {
eprintln!("error: {:?}", error);
return;
}
}
}
let expr = match &statements[0] {
crate::parser::Statement::Expr(expr) => expr,
_ => todo!(),
};
generator::generate(expr);
}

446
src/parser.rs Normal file
View file

@ -0,0 +1,446 @@
use std::collections::HashMap;
use crate::lexer::Token;
use crate::Error;
pub struct Parser {
tokens: Vec<Token>,
pos: usize,
}
#[derive(Debug, Clone)]
pub enum BinaryOp {
Pow,
Mul,
Div,
Mod,
Add,
Sub,
Shl,
Shr,
Less,
LessEq,
Greater,
GreaterEq,
Eq,
Neq,
BitAnd,
Xor,
BitOr,
LogicAnd,
LogicOr,
// TODO
// Set,
}
#[derive(Debug, Clone)]
pub enum PrefixOp {
Minus,
LogicNot,
BitNot,
}
#[derive(Debug, Clone)]
pub enum SuffixOp {
Unravel,
}
#[derive(Debug, Clone)]
pub enum Statement {
Let(String, Expr),
// Type(String, Type),
Expr(Expr),
// Func(String, ...),
// Break,
// Continue,
// Type,
}
#[derive(Debug, Clone)]
pub struct Block(pub Vec<Statement>);
#[derive(Debug, Clone)]
pub enum Expr {
Literal(Literal),
Variable(String),
Binary(BinaryOp, Box<Expr>, Box<Expr>),
Unary(PrefixOp, Box<Expr>),
Match(Box<Expr>, Vec<(Pattern, Expr)>),
// Call(String, Vec<Expr>),
Block(Block),
}
#[derive(Debug, Clone, PartialEq)]
pub enum Pattern {
Literal(Literal),
}
#[derive(Debug, Clone, PartialEq)]
pub enum Literal {
Integer(i64),
Float(f64),
Boolean(bool),
String(String),
Char(char),
}
#[derive(Debug, Clone)]
pub struct Context {
pub locals: HashMap<String, Type>,
}
impl Parser {
pub fn new(tokens: Vec<Token>) -> Parser {
Parser { tokens, pos: 0 }
}
fn peek_tok(&self) -> Option<&Token> {
self.tokens.get(self.pos)
}
fn next_tok(&mut self) -> Option<&Token> {
let tok = self.tokens.get(self.pos);
self.pos += 1;
tok
}
fn eat(&mut self, token: Token) -> Result<&Token, Error> {
match self.next_tok() {
Some(t) if t == &token => Ok(t),
Some(t) => Err(Error::SyntaxError(format!("expected {token:?}, got {t:?}"))),
None => Err(Error::SyntaxError(format!("expected {token:?}, got eof"))),
}
}
pub fn next(&mut self) -> Result<Option<Statement>, Error> {
self.parse_statement()
}
fn parse_statement(&mut self) -> Result<Option<Statement>, Error> {
let Some(tok) = self.peek_tok() else {
return Ok(None);
};
let stmt = match tok {
Token::Let => {
self.eat(Token::Let)?;
let name = match self.next_tok() {
Some(Token::Ident(ident)) => ident.to_string(),
Some(tk) => return Err(Error::SyntaxError(format!("expected identifier, got {tk:?}"))),
None => return Err(Error::SyntaxError(format!("expected identifier, got eof"))),
};
self.eat(Token::Set)?;
let expr = self.parse_expr(0)?;
Statement::Let(name, expr)
},
_ => Statement::Expr(self.parse_expr(0)?),
};
Ok(Some(stmt))
}
fn parse_block(&mut self) -> Result<Block, Error> {
let mut statements = vec![];
loop {
match self.peek_tok() {
Some(Token::CloseBrace) => break,
Some(_) => (),
None => return Err(Error::syn("missing closing brace")),
};
statements.push(self.parse_statement()?.unwrap());
match self.peek_tok() {
Some(Token::Semicolon) => self.next_tok(),
Some(Token::CloseBrace) => break,
Some(tok) => return Err(Error::SyntaxError(format!("unexpected token {tok:?}"))),
None => return Err(Error::syn("unexpected eof")),
};
}
Ok(Block(statements))
}
fn parse_expr(&mut self, binding: u8) -> Result<Expr, Error> {
let tok = self.next_tok().ok_or(Error::syn("expected a token"))?;
let mut expr = match tok {
Token::Number { radix: _, text } => {
if text.contains('.') {
Expr::Literal(Literal::Float(text.parse().unwrap()))
} else {
Expr::Literal(Literal::Integer(text.parse().unwrap()))
}
}
Token::Ident(ident) => Expr::Variable(ident.to_string()),
Token::False => Expr::Literal(Literal::Boolean(false)),
Token::True => Expr::Literal(Literal::Boolean(true)),
Token::If => {
let cond = self.parse_expr(0)?;
self.eat(Token::OpenBrace)?;
let block = self.parse_block()?;
self.eat(Token::CloseBrace)?;
let otherwise = if self.peek_tok().is_some_and(|t| *t == Token::Else) {
self.next_tok();
match self.peek_tok() {
Some(Token::OpenBrace) => {
self.eat(Token::OpenBrace)?;
let b = Some(self.parse_block()?);
self.eat(Token::CloseBrace)?;
b
}
Some(Token::If) => Some(Block(vec![Statement::Expr(self.parse_expr(0)?)])),
Some(_) => return Err(Error::syn("unexpected token")),
None => return Err(Error::syn("unexpected eof, wanted body for else")),
}
} else {
None
};
let mut map = vec![(Pattern::Literal(Literal::Boolean(true)), Expr::Block(block))];
if let Some(otherwise) = otherwise {
map.push((Pattern::Literal(Literal::Boolean(false)), Expr::Block(otherwise)));
}
Expr::Match(Box::new(cond), map)
}
Token::Minus => {
let expr = self.parse_expr(1)?;
Expr::Unary(PrefixOp::Minus, Box::new(expr))
}
Token::Not => {
let expr = self.parse_expr(1)?;
Expr::Unary(PrefixOp::LogicNot, Box::new(expr))
}
Token::Match => {
let expr = self.parse_expr(0)?;
let mut arms = vec![];
self.eat(Token::OpenBrace)?;
loop {
let pat = self.parse_pattern()?;
self.eat(Token::FatArrow)?;
let expr = self.parse_expr(0)?;
arms.push((pat, expr));
if self.peek_tok().is_some_and(|t| t == &Token::Comma) {
self.next_tok();
} else {
break;
}
}
self.eat(Token::CloseBrace)?;
Expr::Match(Box::new(expr), arms)
}
Token::OpenBrace => {
let b = Expr::Block(self.parse_block()?);
self.eat(Token::CloseBrace)?;
b
}
Token::OpenParan => {
let expr = self.parse_expr(0)?;
self.eat(Token::CloseParan)?;
expr
}
_ => return Err(Error::syn("unexpected token")),
};
while let Some(next) = self.peek_tok() {
let Some(op) = BinaryOp::from_token(next) else {
break;
};
let (bind_left, bind_right) = op.precedence();
if bind_left < binding {
break;
}
self.next_tok();
let rhs = self.parse_expr(bind_right)?;
expr = Expr::Binary(op, Box::new(expr), Box::new(rhs));
}
Ok(expr)
}
fn parse_pattern(&mut self) -> Result<Pattern, Error> {
let tok = self.next_tok().ok_or(Error::syn("expected a token"))?;
let pat = match tok {
Token::Number { radix: _, text } => {
if text.contains('.') {
Pattern::Literal(Literal::Float(text.parse().unwrap()))
} else {
Pattern::Literal(Literal::Integer(text.parse().unwrap()))
}
}
Token::False => Pattern::Literal(Literal::Boolean(false)),
Token::True => Pattern::Literal(Literal::Boolean(true)),
_ => todo!(),
};
Ok(pat)
}
}
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum Type {
Integer,
Float,
Boolean,
String,
Char,
Function(Vec<Type>, Box<Type>),
Tuple(Vec<Type>),
}
impl Expr {
pub fn infer(&self, ctx: &Context) -> Result<Type, Error> {
match self {
Self::Literal(lit) => lit.infer(),
Self::Binary(op, lhs, rhs) => Ok(op.infer(lhs.infer(ctx)?, rhs.infer(ctx)?)?),
Self::Unary(op, expr) => Ok(op.infer(expr.infer(ctx)?)?),
Self::Variable(name) => match ctx.locals.get(name) {
Some(ty) => Ok(ty.clone()),
None => Err(Error::ReferenceError(format!("cannot find variable {name}"))),
},
Self::Match(item, arms) => {
let mut match_ty = None;
let item_ty = item.infer(ctx)?;
for (pat, expr) in arms {
let ty = expr.infer(ctx)?;
let pat_ty = pat.infer()?;
if item_ty != pat_ty {
return Err(Error::ty("cannot compare different type"));
}
if match_ty.is_some_and(|mty| mty != ty) {
return Err(Error::ty("branch returns different type"));
}
match_ty = Some(ty);
}
// TODO: exhaustiveness checks
let Some(match_ty) = match_ty else {
// TODO: infallible types? `enum Nope {}`
return Err(Error::ty("match has no branches to infer"));
};
Ok(match_ty)
}
Self::Block(block) => block.infer(ctx),
}
}
}
impl BinaryOp {
pub fn infer(&self, a: Type, b: Type) -> Result<Type, Error> {
use BinaryOp as B;
use Type as T;
let ty = match (self, a, b) {
(B::Add | B::Sub | B::Mul | B::Div | B::Mod | B::Pow, T::Integer, T::Integer) => T::Integer,
(B::Eq | B::Neq | B::Less | B::LessEq | B::Greater | B::GreaterEq, T::Integer, T::Integer) => T::Boolean,
(B::Add | B::Sub | B::Mul | B::Div | B::Mod | B::Pow, T::Float, T::Float) => T::Float,
(B::Eq | B::Neq | B::Less | B::LessEq | B::Greater | B::GreaterEq, T::Float, T::Float) => T::Boolean,
// (B::Add | B::Sub | B::Mul | B::Div, T::Float, T::Float) => T::Float,
(op, a, b) => {
return Err(Error::TypeError(format!(
"operator {op:?} cannot be applied to {a:?} and {b:?}"
)))
}
};
Ok(ty)
}
fn precedence(&self) -> (u8, u8) {
match self {
Self::Pow => (22, 21),
Self::Mul | Self::Div | Self::Mod => (19, 20),
Self::Add | Self::Sub => (17, 18),
Self::Shl | Self::Shr => (15, 16),
Self::Less | Self::LessEq | Self::Greater | Self::GreaterEq => (13, 14),
Self::Eq | Self::Neq => (11, 12),
Self::BitAnd => (9, 10),
Self::Xor => (7, 8),
Self::BitOr => (5, 6),
Self::LogicAnd => (3, 4),
Self::LogicOr => (1, 2),
}
}
fn from_token(token: &Token) -> Option<Self> {
let op = match token {
Token::DoubleStar => Self::Pow,
Token::Star => Self::Mul,
Token::Slash => Self::Div,
Token::Percent => Self::Mod,
Token::Plus => Self::Add,
Token::Minus => Self::Sub,
Token::Shl => Self::Shl,
Token::Shr => Self::Shr,
Token::Less => Self::Less,
Token::LessEq => Self::LessEq,
Token::Greater => Self::Greater,
Token::GreaterEq => Self::GreaterEq,
Token::Eq => Self::Eq,
Token::Neq => Self::Neq,
Token::And => Self::BitAnd,
Token::Carat => Self::Xor,
Token::Pipe => Self::BitOr,
Token::DoubleAnd => Self::LogicAnd,
Token::DoublePipe => Self::LogicOr,
_ => return None,
};
Some(op)
}
}
impl PrefixOp {
pub fn infer(&self, a: Type) -> Result<Type, Error> {
use Type as T;
use PrefixOp as U;
let ty = match (self, a) {
(U::Minus, T::Integer) => T::Integer,
// (U::Minus, T::Float) => T::Float,
(U::LogicNot, T::Boolean) => T::Boolean,
(op, ty) => {
return Err(Error::TypeError(format!(
"operator {op:?} cannot be applied to {ty:?}"
)))
}
};
Ok(ty)
}
}
impl Block {
#[allow(clippy::never_loop)] // for now
pub fn infer(&self, ctx: &Context) -> Result<Type, Error> {
let mut ctx = ctx.clone();
let mut ty = Type::Tuple(vec![]);
for statement in &self.0 {
match statement {
Statement::Expr(expr) => ty = expr.infer(&ctx)?,
Statement::Let(name, expr) => {
let var_ty = expr.infer(&ctx)?;
ctx.locals.insert(name.clone(), var_ty);
ty = Type::Tuple(vec![]);
}
}
}
Ok(ty)
}
}
impl Literal {
fn infer(&self) -> Result<Type, Error> {
match self {
Literal::Integer(_) => Ok(Type::Integer),
Literal::Float(_) => Ok(Type::Float),
Literal::Boolean(_) => Ok(Type::Boolean),
Literal::String(_) => Ok(Type::String),
Literal::Char(_) => Ok(Type::Char),
}
}
}
impl Pattern {
fn infer(&self) -> Result<Type, Error> {
match self {
Pattern::Literal(lit) => lit.infer(),
}
}
}
impl Context {
pub fn new() -> Context {
Context {
locals: HashMap::new(),
}
}
}