From 9c73e38ac29bc702611c79ea4492540cd7fa387e Mon Sep 17 00:00:00 2001 From: tezlm Date: Fri, 29 Sep 2023 15:55:26 -0700 Subject: [PATCH] random things --- .gitignore | 1 + Cargo.lock | 7 + Cargo.toml | 8 + print.wat | 32 ++++ src/error.rs | 16 ++ src/generator.rs | 160 +++++++++++++++++ src/lexer.rs | 325 ++++++++++++++++++++++++++++++++++ src/main.rs | 59 +++++++ src/parser.rs | 446 +++++++++++++++++++++++++++++++++++++++++++++++ 9 files changed, 1054 insertions(+) create mode 100644 .gitignore create mode 100644 Cargo.lock create mode 100644 Cargo.toml create mode 100644 print.wat create mode 100644 src/error.rs create mode 100644 src/generator.rs create mode 100644 src/lexer.rs create mode 100644 src/main.rs create mode 100644 src/parser.rs diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ea8c4bf --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +/target diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 0000000..afaae83 --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,7 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "lang" +version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..7541543 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "lang" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] diff --git a/print.wat b/print.wat new file mode 100644 index 0000000..41c085d --- /dev/null +++ b/print.wat @@ -0,0 +1,32 @@ +(module + ;; import from wasi + ;; fn fd_write(fd, *iovs, iovs_len, nwritten) -> bytes_written + (import "wasi_unstable" "fd_write" (func $fd_write (param i32 i32 i32 i32) (result i32))) + + ;; create memory (size = 1 page = 64KiB) + (memory $foobar 1) + + ;; export memory - it's required, but we don't use it so the size is set to 0 + (export "memory" (memory 0)) + + ;; write string to memory (offset = 8 bytes) + (data (i32.const 8) "Hello, world!\n") + + (func $main (export "_start") + ;; iov.iov_base - pointer to string (offset = 0 bytes) + ;; the string's offset is 8 bytes in memory + (i32.store (i32.const 0) (i32.const 8)) + + ;; iov.iov_len - length of the hello world string (offset = 4 bytes) + ;; the string's length is 14 bytes + (i32.store (i32.const 4) (i32.const 14)) + + (call $fd_write + (i32.const 1) ;; fd: stdout = 1 + (i32.const 0) ;; data: pointer to memory - this is the first memory we create (index 0) + (i32.const 1) ;; data_len: there's 1 string + (i32.const 2468) ;; nwritten: i don't care about this, write it wherever + ) + drop ;; drop number of bytes written + ) +) diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..2cb8549 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,16 @@ +#[derive(Debug)] +pub enum Error { + SyntaxError(String), + TypeError(String), + ReferenceError(String), +} + +impl Error { + pub fn syn(what: &'static str) -> Error { + Error::SyntaxError(what.to_string()) + } + + pub fn ty(what: &'static str) -> Error { + Error::TypeError(what.to_string()) + } +} diff --git a/src/generator.rs b/src/generator.rs new file mode 100644 index 0000000..e0581bd --- /dev/null +++ b/src/generator.rs @@ -0,0 +1,160 @@ +/* +optimizations + +- use i32.eqz when comparing to zero +- write negative numberss directly instead of as positive + sign flip +*/ + +use crate::parser::{Expr, Literal, BinaryOp, PrefixOp, Statement, Context}; + +pub fn generate(expr: &Expr) { + println!(); + println!(); + + let mut ctx = Context::new(); + let mut exprs = Vec::new(); + get_locals(expr, &mut ctx, &mut exprs); + // println!(r#"(module (func (export "_start") (local $match i32)"#); + println!(r#"(module (func (export "_start") (result i32) (local $match i32)"#); + // println!(r#"(module (func (export "_start") (result f64) (local $match f64)"#); + for (name, _) in &exprs { + let ty = match ctx.locals.get(name).unwrap() { + crate::parser::Type::Integer => "i32", + crate::parser::Type::Float => "f64", + _ => todo!(), + }; + println!("(local ${name} {ty})"); + } + for (name, expr) in &exprs { + gen_expr(expr, &ctx); + println!("(local.set ${name})"); + } + + gen_expr(&expr, &ctx); + // println!("drop"); + println!("))"); +} + +fn gen_expr(expr: &Expr, ctx: &Context) { + match expr { + Expr::Literal(lit) => match lit { + Literal::Integer(int) => println!("i32.const {int}"), + Literal::Float(f) => println!("f64.const {f}"), + Literal::Boolean(b) => println!("i32.const {}", if *b { 1 } else { 0 }), + _ => todo!(), + } + Expr::Variable(name) => { + println!("local.get ${name}"); + } + Expr::Binary(op, a, b) => { + gen_expr(a, ctx); + gen_expr(b, ctx); + + let ty = match expr.infer(&ctx).unwrap() { + crate::parser::Type::Integer => "i32", + crate::parser::Type::Float => "f64", + crate::parser::Type::Boolean => "i32", + _ => todo!(), + }; + match op { + BinaryOp::Add => println!("{ty}.add"), + BinaryOp::Mul => println!("{ty}.mul"), + BinaryOp::Sub => println!("{ty}.sub"), + BinaryOp::Div => println!("{ty}.div_u"), // do i _u or _s? + BinaryOp::Mod => println!("{ty}.rem_u"), + BinaryOp::Eq => println!("{ty}.eq"), + BinaryOp::Neq => println!("{ty}.neq"), + BinaryOp::Less => println!("{ty}.lt_u"), + BinaryOp::Greater => println!("{ty}.gt_u"), + _ => todo!(), + } + } + Expr::Unary(op, e) => { + gen_expr(e, ctx); + match op { + PrefixOp::Minus => { + // this is so inefficent, but i don't care + println!("i32.const -1"); + println!("i32.mul"); + } + PrefixOp::LogicNot => { + println!("i32.eqz"); + } + PrefixOp::BitNot => { + // TODO: do i flip the sign bit? + println!("i32.const {}", i32::MAX); + println!("i32.xor"); + } + } + } + Expr::Match(cond, arms) => { + println!(";; --- set match variable"); + println!("(local.set $match ("); + gen_expr(cond, ctx); + println!("))"); + + println!(";; --- generate match"); + + for (idx, (pat, expr)) in arms.iter().enumerate() { + // FIXME: hardcoded until patern matching works better + match pat { + crate::parser::Pattern::Literal(lit) => match lit { + Literal::Integer(int) => println!("i32.const {}", int), + Literal::Boolean(b) => println!("i32.const {}", if *b { 1 } else { 0 }), + _ => todo!(), + } + }; + + println!("local.get $match"); + println!("i32.eq"); + println!("(if (result i32) (then"); + gen_expr(expr, ctx); + + if idx == arms.len() - 1 { + // TODO: verify its actually unreachable earlier on + println!(") (else unreachable"); + } else { + println!(") (else"); + } + } + println!("{}", ")".repeat(arms.len() * 2)); + println!(";; --- done"); + } + Expr::Block(b) => { + for (i, stmt) in b.0.iter().enumerate() { + match stmt { + Statement::Expr(expr) => { + gen_expr(expr, &ctx); + if i < b.0.len() - 1 { + println!("drop"); + } + } + _ => {}, + } + } + } + }; +} + +fn get_locals(expr: &Expr, ctx: &mut Context, exprs: &mut Vec<(String, Expr)>) { + match expr { + Expr::Block(b) => { + for stmt in &b.0 { + match stmt { + Statement::Let(name, expr) => { + let ty = expr.infer(ctx).unwrap(); + ctx.locals.insert(name.clone(), ty); + exprs.push((name.clone(), expr.clone())); + } + Statement::Expr(expr) => get_locals(&expr, ctx, exprs), + } + } + }, + Expr::Unary(_, expr) => get_locals(&expr, ctx, exprs), + Expr::Binary(_, a, b) => { + get_locals(&a, ctx, exprs); + get_locals(&b, ctx, exprs); + } + _ => (), + } +} diff --git a/src/lexer.rs b/src/lexer.rs new file mode 100644 index 0000000..41542fe --- /dev/null +++ b/src/lexer.rs @@ -0,0 +1,325 @@ +use crate::Error; + +pub struct Lexer { + input: Vec, + pos: usize, +} + +#[rustfmt::skip] +#[derive(Debug, PartialEq, Eq)] +pub enum Token { + Number { radix: u32, text: String }, + Ident(String), + String(String), + Char(char), + + OpenParan, CloseParan, + OpenBrace, CloseBrace, + OpenBracket, CloseBracket, + + Plus, Minus, Star, DoubleStar, Slash, Percent, + Pipe, DoublePipe, And, DoubleAnd, Carat, Shl, Shr, + + PlusSet, MinusSet, StarSet, DoubleStarSet, SlashSet, PercentSet, + PipeSet, DoublePipeSet, AndSet, DoubleAndSet, CaratSet, ShlSet, ShrSet, + + Set, Eq, Neq, Less, LessEq, Greater, GreaterEq, Not, + Dot, DoubleDot, TripleDot, Comma, Question, Colon, DoubleColon, Semicolon, + ThinArrow, FatArrow, + + Let, Const, Type, Fn, + True, False, + If, Else, Match, + While, Loop, For, Break, Continue, +} + +impl Lexer { + pub fn new(input: String) -> Lexer { + Lexer { + input: input.chars().collect(), + pos: 0, + } + } + + pub fn next(&mut self) -> Result, Error> { + let Some(&ch) = self.input.get(self.pos) else { + return Ok(None); + }; + let tok = match ch { + '0'..='9' => { + let token = self.lex_number()?; + if self + .input + .get(self.pos) + .is_some_and(|c| c.is_ascii_alphanumeric()) + { + panic!("unexpected char"); + } + token + } + '\'' => { + self.pos += 1; + let ch = self.lex_char()?; + if self.input.get(self.pos).is_some_and(|c| *c != '\'') { + panic!("expected '"); + } + self.pos += 1; + Token::Char(ch) + } + '"' => Token::String(self.lex_string()?), + ch if ch.is_alphabetic() || ch == '_' => match self.lex_ident().as_str() { + "let" => Token::Let, + "const" => Token::Const, + "type" => Token::Type, + "fn" => Token::Fn, + "true" => Token::True, + "false" => Token::False, + "if" => Token::If, + "else" => Token::Else, + "match" => Token::Match, + "while" => Token::While, + "loop" => Token::Loop, + "for" => Token::For, + "break" => Token::Break, + "continue" => Token::Continue, + ident => Token::Ident(ident.to_string()), + }, + ch if ch.is_whitespace() => { + self.pos += 1; + return self.next(); + } + _ => self.lex_op()?, + }; + Ok(Some(tok)) + } + + fn lex_number(&mut self) -> Result { + let mut buffer = String::new(); + let radix = match (self.input[self.pos], self.input.get(self.pos + 1)) { + ('0', Some(ch)) if ch.is_digit(10) => 10, + ('0', Some(ch)) => { + self.pos += 2; + match ch { + 'x' => 16, + 'o' => 8, + 'b' => 2, + _ if !ch.is_ascii_alphanumeric() => { + self.pos -= 2; + 10 + }, + _ => return Err(Error::SyntaxError(format!("unknown number radix {ch}"))), + } + } + _ => 10, + }; + let mut saw_decimal = false; + while let Some(&ch) = self.input.get(self.pos) { + if ch == '.' && !saw_decimal { + saw_decimal = true; + buffer.push(ch); + self.pos += 1; + } else if ch.is_digit(radix) { + buffer.push(ch); + self.pos += 1; + } else { + break; + } + } + Ok(Token::Number { + radix, + text: buffer, + }) + } + + fn lex_string(&mut self) -> Result { + self.pos += 1; // take " + let mut buffer = String::new(); + while let Some(&ch) = self.input.get(self.pos) { + if ch == '"' { + break; + } + buffer.push(self.lex_char()?); + } + if !self.input.get(self.pos).is_some_and(|c| *c == '"') { + panic!("expected \""); + } + self.pos += 1; + Ok(buffer) + } + + #[inline] + fn lex_char(&mut self) -> Result { + let ch = match self.input.get(self.pos) { + Some('\\') => { + self.pos += 1; + let ch = self + .input + .get(self.pos) + .ok_or_else(|| Error::syn("expected escape char"))?; + match ch { + 'n' => '\n', + 't' => '\t', + '\'' => '\'', + '"' => '\"', + // 'x' => '\x', + // 'u' => '\u', + _ => return Err(Error::syn("unknown escape char")), + } + } + Some(ch) => *ch, + None => return Err(Error::syn("expected char")), + }; + self.pos += 1; + Ok(ch) + } + + fn lex_ident(&mut self) -> String { + let rest: String = self.input[self.pos..] + .iter() + .copied() + .take_while(|c| c.is_alphanumeric() || *c == '_') + .collect(); + self.pos += rest.len(); + rest + } + + fn lex_op(&mut self) -> Result { + let ch = self.input[self.pos]; + + macro_rules! settable { + ($normal:expr, $set:expr) => { + match self.input.get(self.pos + 1) { + Some('=') => { + self.pos += 1; + $set + } + _ => $normal, + } + }; + } + + let token = match ch { + '(' => Token::OpenParan, + ')' => Token::CloseParan, + '[' => Token::OpenBracket, + ']' => Token::CloseBracket, + '{' => Token::OpenBrace, + '}' => Token::CloseBrace, + '+' => settable!(Token::Plus, Token::PlusSet), + '-' => match self.input.get(self.pos + 1) { + Some('>') => { + self.pos += 1; + Token::ThinArrow + }, + Some('=') => { + self.pos += 1; + Token::MinusSet + } + _ => Token::Minus, + }, + '*' => match self.input.get(self.pos + 1) { + Some('*') => { + self.pos += 1; + settable!(Token::DoubleStar, Token::DoubleStarSet) + }, + Some('=') => { + self.pos += 1; + Token::StarSet + } + _ => Token::Star, + }, + // TODO: comments + '/' => settable!(Token::Slash, Token::SlashSet), + '%' => settable!(Token::Percent, Token::PercentSet), + '|' => match self.input.get(self.pos + 1) { + Some('|') => { + self.pos += 1; + settable!(Token::DoublePipe, Token::DoublePipeSet) + }, + Some('=') => { + self.pos += 1; + Token::PipeSet + } + _ => Token::Pipe, + }, + '&' => match self.input.get(self.pos + 1) { + Some('&') => { + self.pos += 1; + settable!(Token::DoubleAnd, Token::DoubleAndSet) + } + Some('=') => { + self.pos += 1; + Token::AndSet + } + _ => Token::And, + }, + '^' => settable!(Token::Carat, Token::CaratSet), + '=' => match self.input.get(self.pos + 1) { + Some('=') => { + self.pos += 1; + Token::Eq + } + Some('>') => { + self.pos += 1; + Token::FatArrow + } + _ => Token::Set, + }, + '!' => match self.input.get(self.pos + 1) { + Some('=') => { + self.pos += 1; + Token::Neq + } + _ => Token::Not, + }, + '<' => match self.input.get(self.pos + 1) { + Some('=') => { + self.pos += 1; + Token::LessEq + } + Some('<') => { + self.pos += 1; + settable!(Token::Shl, Token::ShlSet) + } + _ => Token::Less, + }, + '>' => match self.input.get(self.pos + 1) { + Some('=') => { + self.pos += 1; + Token::GreaterEq + } + Some('>') => { + self.pos += 1; + settable!(Token::Shr, Token::ShrSet) + } + _ => Token::Greater, + }, + '.' => match self.input.get(self.pos + 1) { + Some('.') => match self.input.get(self.pos + 1) { + Some('.') => { + self.pos += 2; + Token::TripleDot + } + _ => { + self.pos += 1; + Token::DoubleDot + }, + }, + _ => Token::Dot, + }, + ':' => match self.input.get(self.pos + 1) { + Some(':') => { + self.pos += 1; + Token::DoubleColon + } + _ => Token::Colon, + }, + ',' => Token::Comma, + ';' => Token::Semicolon, + '?' => Token::Question, + _ => return Err(Error::SyntaxError(format!("unexpected character {}", ch))), + }; + self.pos += 1; + Ok(token) + } +} diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..5f9d059 --- /dev/null +++ b/src/main.rs @@ -0,0 +1,59 @@ +/* +typechecking is somewhat embedded in the parser and gets essentially run +a second time when generating (so the types are known), there should be +a better way +*/ + +mod error; +mod generator; +mod lexer; +mod parser; + +pub use error::Error; +use parser::Context; + +fn main() { + let mut lexer = lexer::Lexer::new("!{ let foo = 8; let bar = foo * 3; foo + bar < 10 }".into()); + + let mut tokens = vec![]; + loop { + match lexer.next() { + Ok(None) => break, + Ok(Some(token)) => tokens.push(token), + Err(error) => { + eprintln!("error: {:?}", error); + return; + } + } + } + // dbg!(&tokens); + let mut parser = parser::Parser::new(tokens); + let mut statements = vec![]; + loop { + match parser.next() { + Ok(None) => break, + Ok(Some(tree)) => { + dbg!(&tree); + match &tree { + parser::Statement::Let(..) => todo!(), + parser::Statement::Expr(expr) => match expr.infer(&Context::new()) { + Ok(ty) => eprintln!("type: {:?}", ty), + Err(err) => eprintln!("err: {:?}", err), + }, + }; + statements.push(tree); + } + Err(error) => { + eprintln!("error: {:?}", error); + return; + } + } + } + + let expr = match &statements[0] { + crate::parser::Statement::Expr(expr) => expr, + _ => todo!(), + }; + + generator::generate(expr); +} diff --git a/src/parser.rs b/src/parser.rs new file mode 100644 index 0000000..9dc2e26 --- /dev/null +++ b/src/parser.rs @@ -0,0 +1,446 @@ +use std::collections::HashMap; + +use crate::lexer::Token; +use crate::Error; + +pub struct Parser { + tokens: Vec, + pos: usize, +} + +#[derive(Debug, Clone)] +pub enum BinaryOp { + Pow, + Mul, + Div, + Mod, + Add, + Sub, + Shl, + Shr, + Less, + LessEq, + Greater, + GreaterEq, + Eq, + Neq, + BitAnd, + Xor, + BitOr, + LogicAnd, + LogicOr, + // TODO + // Set, +} + +#[derive(Debug, Clone)] +pub enum PrefixOp { + Minus, + LogicNot, + BitNot, +} + +#[derive(Debug, Clone)] +pub enum SuffixOp { + Unravel, +} + +#[derive(Debug, Clone)] +pub enum Statement { + Let(String, Expr), + // Type(String, Type), + Expr(Expr), + // Func(String, ...), + // Break, + // Continue, + // Type, +} + +#[derive(Debug, Clone)] +pub struct Block(pub Vec); + +#[derive(Debug, Clone)] +pub enum Expr { + Literal(Literal), + Variable(String), + Binary(BinaryOp, Box, Box), + Unary(PrefixOp, Box), + Match(Box, Vec<(Pattern, Expr)>), + // Call(String, Vec), + Block(Block), +} + +#[derive(Debug, Clone, PartialEq)] +pub enum Pattern { + Literal(Literal), +} + +#[derive(Debug, Clone, PartialEq)] +pub enum Literal { + Integer(i64), + Float(f64), + Boolean(bool), + String(String), + Char(char), +} + +#[derive(Debug, Clone)] +pub struct Context { + pub locals: HashMap, +} + +impl Parser { + pub fn new(tokens: Vec) -> Parser { + Parser { tokens, pos: 0 } + } + + fn peek_tok(&self) -> Option<&Token> { + self.tokens.get(self.pos) + } + + fn next_tok(&mut self) -> Option<&Token> { + let tok = self.tokens.get(self.pos); + self.pos += 1; + tok + } + + fn eat(&mut self, token: Token) -> Result<&Token, Error> { + match self.next_tok() { + Some(t) if t == &token => Ok(t), + Some(t) => Err(Error::SyntaxError(format!("expected {token:?}, got {t:?}"))), + None => Err(Error::SyntaxError(format!("expected {token:?}, got eof"))), + } + } + + pub fn next(&mut self) -> Result, Error> { + self.parse_statement() + } + + fn parse_statement(&mut self) -> Result, Error> { + let Some(tok) = self.peek_tok() else { + return Ok(None); + }; + let stmt = match tok { + Token::Let => { + self.eat(Token::Let)?; + let name = match self.next_tok() { + Some(Token::Ident(ident)) => ident.to_string(), + Some(tk) => return Err(Error::SyntaxError(format!("expected identifier, got {tk:?}"))), + None => return Err(Error::SyntaxError(format!("expected identifier, got eof"))), + }; + self.eat(Token::Set)?; + let expr = self.parse_expr(0)?; + Statement::Let(name, expr) + }, + _ => Statement::Expr(self.parse_expr(0)?), + }; + Ok(Some(stmt)) + } + + fn parse_block(&mut self) -> Result { + let mut statements = vec![]; + loop { + match self.peek_tok() { + Some(Token::CloseBrace) => break, + Some(_) => (), + None => return Err(Error::syn("missing closing brace")), + }; + statements.push(self.parse_statement()?.unwrap()); + match self.peek_tok() { + Some(Token::Semicolon) => self.next_tok(), + Some(Token::CloseBrace) => break, + Some(tok) => return Err(Error::SyntaxError(format!("unexpected token {tok:?}"))), + None => return Err(Error::syn("unexpected eof")), + }; + } + Ok(Block(statements)) + } + + fn parse_expr(&mut self, binding: u8) -> Result { + let tok = self.next_tok().ok_or(Error::syn("expected a token"))?; + let mut expr = match tok { + Token::Number { radix: _, text } => { + if text.contains('.') { + Expr::Literal(Literal::Float(text.parse().unwrap())) + } else { + Expr::Literal(Literal::Integer(text.parse().unwrap())) + } + } + Token::Ident(ident) => Expr::Variable(ident.to_string()), + Token::False => Expr::Literal(Literal::Boolean(false)), + Token::True => Expr::Literal(Literal::Boolean(true)), + Token::If => { + let cond = self.parse_expr(0)?; + self.eat(Token::OpenBrace)?; + let block = self.parse_block()?; + self.eat(Token::CloseBrace)?; + let otherwise = if self.peek_tok().is_some_and(|t| *t == Token::Else) { + self.next_tok(); + match self.peek_tok() { + Some(Token::OpenBrace) => { + self.eat(Token::OpenBrace)?; + let b = Some(self.parse_block()?); + self.eat(Token::CloseBrace)?; + b + } + Some(Token::If) => Some(Block(vec![Statement::Expr(self.parse_expr(0)?)])), + Some(_) => return Err(Error::syn("unexpected token")), + None => return Err(Error::syn("unexpected eof, wanted body for else")), + } + } else { + None + }; + let mut map = vec![(Pattern::Literal(Literal::Boolean(true)), Expr::Block(block))]; + if let Some(otherwise) = otherwise { + map.push((Pattern::Literal(Literal::Boolean(false)), Expr::Block(otherwise))); + } + Expr::Match(Box::new(cond), map) + } + Token::Minus => { + let expr = self.parse_expr(1)?; + Expr::Unary(PrefixOp::Minus, Box::new(expr)) + } + Token::Not => { + let expr = self.parse_expr(1)?; + Expr::Unary(PrefixOp::LogicNot, Box::new(expr)) + } + Token::Match => { + let expr = self.parse_expr(0)?; + let mut arms = vec![]; + self.eat(Token::OpenBrace)?; + loop { + let pat = self.parse_pattern()?; + self.eat(Token::FatArrow)?; + let expr = self.parse_expr(0)?; + arms.push((pat, expr)); + if self.peek_tok().is_some_and(|t| t == &Token::Comma) { + self.next_tok(); + } else { + break; + } + } + self.eat(Token::CloseBrace)?; + Expr::Match(Box::new(expr), arms) + } + Token::OpenBrace => { + let b = Expr::Block(self.parse_block()?); + self.eat(Token::CloseBrace)?; + b + } + Token::OpenParan => { + let expr = self.parse_expr(0)?; + self.eat(Token::CloseParan)?; + expr + } + _ => return Err(Error::syn("unexpected token")), + }; + while let Some(next) = self.peek_tok() { + let Some(op) = BinaryOp::from_token(next) else { + break; + }; + let (bind_left, bind_right) = op.precedence(); + if bind_left < binding { + break; + } + self.next_tok(); + let rhs = self.parse_expr(bind_right)?; + expr = Expr::Binary(op, Box::new(expr), Box::new(rhs)); + } + Ok(expr) + } + + fn parse_pattern(&mut self) -> Result { + let tok = self.next_tok().ok_or(Error::syn("expected a token"))?; + let pat = match tok { + Token::Number { radix: _, text } => { + if text.contains('.') { + Pattern::Literal(Literal::Float(text.parse().unwrap())) + } else { + Pattern::Literal(Literal::Integer(text.parse().unwrap())) + } + } + Token::False => Pattern::Literal(Literal::Boolean(false)), + Token::True => Pattern::Literal(Literal::Boolean(true)), + _ => todo!(), + }; + Ok(pat) + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum Type { + Integer, + Float, + Boolean, + String, + Char, + Function(Vec, Box), + Tuple(Vec), +} + +impl Expr { + pub fn infer(&self, ctx: &Context) -> Result { + match self { + Self::Literal(lit) => lit.infer(), + Self::Binary(op, lhs, rhs) => Ok(op.infer(lhs.infer(ctx)?, rhs.infer(ctx)?)?), + Self::Unary(op, expr) => Ok(op.infer(expr.infer(ctx)?)?), + Self::Variable(name) => match ctx.locals.get(name) { + Some(ty) => Ok(ty.clone()), + None => Err(Error::ReferenceError(format!("cannot find variable {name}"))), + }, + Self::Match(item, arms) => { + let mut match_ty = None; + let item_ty = item.infer(ctx)?; + for (pat, expr) in arms { + let ty = expr.infer(ctx)?; + let pat_ty = pat.infer()?; + if item_ty != pat_ty { + return Err(Error::ty("cannot compare different type")); + } + if match_ty.is_some_and(|mty| mty != ty) { + return Err(Error::ty("branch returns different type")); + } + match_ty = Some(ty); + } + // TODO: exhaustiveness checks + let Some(match_ty) = match_ty else { + // TODO: infallible types? `enum Nope {}` + return Err(Error::ty("match has no branches to infer")); + }; + Ok(match_ty) + } + Self::Block(block) => block.infer(ctx), + } + } +} + +impl BinaryOp { + pub fn infer(&self, a: Type, b: Type) -> Result { + use BinaryOp as B; + use Type as T; + + let ty = match (self, a, b) { + (B::Add | B::Sub | B::Mul | B::Div | B::Mod | B::Pow, T::Integer, T::Integer) => T::Integer, + (B::Eq | B::Neq | B::Less | B::LessEq | B::Greater | B::GreaterEq, T::Integer, T::Integer) => T::Boolean, + (B::Add | B::Sub | B::Mul | B::Div | B::Mod | B::Pow, T::Float, T::Float) => T::Float, + (B::Eq | B::Neq | B::Less | B::LessEq | B::Greater | B::GreaterEq, T::Float, T::Float) => T::Boolean, + // (B::Add | B::Sub | B::Mul | B::Div, T::Float, T::Float) => T::Float, + (op, a, b) => { + return Err(Error::TypeError(format!( + "operator {op:?} cannot be applied to {a:?} and {b:?}" + ))) + } + }; + + Ok(ty) + } + + fn precedence(&self) -> (u8, u8) { + match self { + Self::Pow => (22, 21), + Self::Mul | Self::Div | Self::Mod => (19, 20), + Self::Add | Self::Sub => (17, 18), + Self::Shl | Self::Shr => (15, 16), + Self::Less | Self::LessEq | Self::Greater | Self::GreaterEq => (13, 14), + Self::Eq | Self::Neq => (11, 12), + Self::BitAnd => (9, 10), + Self::Xor => (7, 8), + Self::BitOr => (5, 6), + Self::LogicAnd => (3, 4), + Self::LogicOr => (1, 2), + } + } + + fn from_token(token: &Token) -> Option { + let op = match token { + Token::DoubleStar => Self::Pow, + Token::Star => Self::Mul, + Token::Slash => Self::Div, + Token::Percent => Self::Mod, + Token::Plus => Self::Add, + Token::Minus => Self::Sub, + Token::Shl => Self::Shl, + Token::Shr => Self::Shr, + Token::Less => Self::Less, + Token::LessEq => Self::LessEq, + Token::Greater => Self::Greater, + Token::GreaterEq => Self::GreaterEq, + Token::Eq => Self::Eq, + Token::Neq => Self::Neq, + Token::And => Self::BitAnd, + Token::Carat => Self::Xor, + Token::Pipe => Self::BitOr, + Token::DoubleAnd => Self::LogicAnd, + Token::DoublePipe => Self::LogicOr, + _ => return None, + }; + Some(op) + } +} + +impl PrefixOp { + pub fn infer(&self, a: Type) -> Result { + use Type as T; + use PrefixOp as U; + + let ty = match (self, a) { + (U::Minus, T::Integer) => T::Integer, + // (U::Minus, T::Float) => T::Float, + (U::LogicNot, T::Boolean) => T::Boolean, + (op, ty) => { + return Err(Error::TypeError(format!( + "operator {op:?} cannot be applied to {ty:?}" + ))) + } + }; + + Ok(ty) + } +} + +impl Block { + #[allow(clippy::never_loop)] // for now + pub fn infer(&self, ctx: &Context) -> Result { + let mut ctx = ctx.clone(); + let mut ty = Type::Tuple(vec![]); + for statement in &self.0 { + match statement { + Statement::Expr(expr) => ty = expr.infer(&ctx)?, + Statement::Let(name, expr) => { + let var_ty = expr.infer(&ctx)?; + ctx.locals.insert(name.clone(), var_ty); + ty = Type::Tuple(vec![]); + } + } + } + Ok(ty) + } +} + +impl Literal { + fn infer(&self) -> Result { + match self { + Literal::Integer(_) => Ok(Type::Integer), + Literal::Float(_) => Ok(Type::Float), + Literal::Boolean(_) => Ok(Type::Boolean), + Literal::String(_) => Ok(Type::String), + Literal::Char(_) => Ok(Type::Char), + } + } +} + +impl Pattern { + fn infer(&self) -> Result { + match self { + Pattern::Literal(lit) => lit.infer(), + } + } +} + +impl Context { + pub fn new() -> Context { + Context { + locals: HashMap::new(), + } + } +}