Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion compiler/src/lexer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ pub enum Token {
/// Pipe operator `|>`
Pipe,

/// Arrow operator `->` for lambda shorthand
Arrow,

/// Logging level tokens for different verbosity levels
LogDebug, // Debug log level
LogInfo, // Info log level
Expand Down Expand Up @@ -215,7 +218,13 @@ impl Lexer {
}
'-' => {
self.position += 1;
Some(Token::Minus)
// Check for ->
if self.position < self.input.len() && self.input[self.position] == '>' {
self.position += 1;
Some(Token::Arrow)
} else {
Some(Token::Minus)
}
}
'*' => {
self.position += 1;
Expand Down
20 changes: 19 additions & 1 deletion compiler/src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,25 @@ impl Parser {
}

// Try binary operations
self.parse_binary_operation()
let expr = self.parse_binary_operation()?;

// Check for arrow lambda shorthand: x -> body
// Desugars: x -> body → Function[{x}, body]
if let Expression::Identifier(param_name) = &expr {
if matches!(&self.current_token, Some(Token::Arrow)) {
self.advance(); // consume ->
let body = Box::new(self.parse_base_expression()?);
return Some(Expression::Lambda {
parameters: vec![TypeAnnotation {
name: param_name.clone(),
type_: Type::Int32, // Placeholder - will be inferred
}],
body,
});
}
}

Some(expr)
}

/// Parse either a function definition or function call
Expand Down
281 changes: 281 additions & 0 deletions compiler/tests/lambda_shorthand_tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,281 @@
use w::lexer::{Lexer, Token};
use w::ast::Expression;
use w::parser::Parser;
use w::rust_codegen::RustCodeGenerator;

// ============================================
// Lexer Tests - Arrow Token
// ============================================

#[test]
fn test_arrow_token() {
let mut lexer = Lexer::new("x -> x + 1".to_string());

assert_eq!(lexer.next_token(), Some(Token::Identifier("x".to_string())));
assert_eq!(lexer.next_token(), Some(Token::Arrow));
assert_eq!(lexer.next_token(), Some(Token::Identifier("x".to_string())));
assert_eq!(lexer.next_token(), Some(Token::Plus));
assert_eq!(lexer.next_token(), Some(Token::Number(1)));
assert_eq!(lexer.next_token(), None);
}

#[test]
fn test_arrow_token_no_spaces() {
let mut lexer = Lexer::new("x->y".to_string());

assert_eq!(lexer.next_token(), Some(Token::Identifier("x".to_string())));
assert_eq!(lexer.next_token(), Some(Token::Arrow));
assert_eq!(lexer.next_token(), Some(Token::Identifier("y".to_string())));
assert_eq!(lexer.next_token(), None);
}

#[test]
fn test_minus_not_confused_with_arrow() {
let mut lexer = Lexer::new("x - y".to_string());

assert_eq!(lexer.next_token(), Some(Token::Identifier("x".to_string())));
assert_eq!(lexer.next_token(), Some(Token::Minus));
assert_eq!(lexer.next_token(), Some(Token::Identifier("y".to_string())));
assert_eq!(lexer.next_token(), None);
}

// ============================================
// Parser Tests - Lambda Shorthand
// ============================================

#[test]
fn test_parse_arrow_lambda_simple() {
// x -> x * 2 desugars to Function[{x}, x * 2]
let mut parser = Parser::new("x -> x * 2".to_string());
let expr = parser.parse_expression().unwrap();

match expr {
Expression::Lambda { parameters, body: _ } => {
assert_eq!(parameters.len(), 1);
assert_eq!(parameters[0].name, "x");
}
_ => panic!("Expected Lambda expression, got {:?}", expr),
}
}

#[test]
fn test_parse_arrow_lambda_comparison() {
// x -> x > 100 desugars to Function[{x}, x > 100]
let mut parser = Parser::new("x -> x > 100".to_string());
let expr = parser.parse_expression().unwrap();

match expr {
Expression::Lambda { parameters, body } => {
assert_eq!(parameters.len(), 1);
assert_eq!(parameters[0].name, "x");
// Body should be a BinaryOp with >
match *body {
Expression::BinaryOp { left, operator, right } => {
assert_eq!(*left, Expression::Identifier("x".to_string()));
assert_eq!(operator, w::ast::Operator::GreaterThan);
assert_eq!(*right, Expression::Number(100));
}
_ => panic!("Expected BinaryOp body, got {:?}", body),
}
}
_ => panic!("Expected Lambda expression, got {:?}", expr),
}
}

#[test]
fn test_parse_arrow_lambda_in_filter() {
// Filter[x -> x > 100] desugars to Filter[Function[{x}, x > 100]]
let mut parser = Parser::new("Filter[x -> x > 100]".to_string());
let expr = parser.parse_expression().unwrap();

match expr {
Expression::FunctionCall { function, arguments } => {
assert_eq!(*function, Expression::Identifier("Filter".to_string()));
assert_eq!(arguments.len(), 1);
match &arguments[0] {
Expression::Lambda { parameters, .. } => {
assert_eq!(parameters.len(), 1);
assert_eq!(parameters[0].name, "x");
}
_ => panic!("Expected Lambda as argument, got {:?}", arguments[0]),
}
}
_ => panic!("Expected FunctionCall, got {:?}", expr),
}
}

#[test]
fn test_parse_arrow_lambda_in_map() {
// Map[x -> x * 2] desugars to Map[Function[{x}, x * 2]]
let mut parser = Parser::new("Map[x -> x * 2]".to_string());
let expr = parser.parse_expression().unwrap();

match expr {
Expression::FunctionCall { function, arguments } => {
assert_eq!(*function, Expression::Identifier("Map".to_string()));
assert_eq!(arguments.len(), 1);
match &arguments[0] {
Expression::Lambda { parameters, .. } => {
assert_eq!(parameters.len(), 1);
assert_eq!(parameters[0].name, "x");
}
_ => panic!("Expected Lambda as argument, got {:?}", arguments[0]),
}
}
_ => panic!("Expected FunctionCall, got {:?}", expr),
}
}

#[test]
fn test_parse_arrow_lambda_with_pipe() {
// data |> Filter[x -> x > 100] |> Map[x -> x * 2]
let mut parser = Parser::new("data |> Filter[x -> x > 100] |> Map[x -> x * 2]".to_string());
let expr = parser.parse_expression().unwrap();

// Outermost: Map[x -> x * 2, Filter[x -> x > 100, data]]
match expr {
Expression::FunctionCall { function, arguments } => {
assert_eq!(*function, Expression::Identifier("Map".to_string()));
assert_eq!(arguments.len(), 2);

// First arg: lambda x -> x * 2
match &arguments[0] {
Expression::Lambda { parameters, .. } => {
assert_eq!(parameters[0].name, "x");
}
_ => panic!("Expected Lambda, got {:?}", arguments[0]),
}

// Second arg: Filter[x -> x > 100, data]
match &arguments[1] {
Expression::FunctionCall { function: inner_fn, arguments: inner_args } => {
assert_eq!(**inner_fn, Expression::Identifier("Filter".to_string()));
assert_eq!(inner_args.len(), 2);

// Filter's first arg: lambda x -> x > 100
match &inner_args[0] {
Expression::Lambda { parameters, .. } => {
assert_eq!(parameters[0].name, "x");
}
_ => panic!("Expected Lambda in Filter, got {:?}", inner_args[0]),
}

// Filter's second arg: data (piped in)
assert_eq!(inner_args[1], Expression::Identifier("data".to_string()));
}
_ => panic!("Expected inner FunctionCall, got {:?}", arguments[1]),
}
}
_ => panic!("Expected FunctionCall, got {:?}", expr),
}
}

// ============================================
// Equivalence Tests - Arrow vs Function syntax
// ============================================

#[test]
fn test_arrow_equivalent_to_function_in_map() {
// These two should produce the same AST
let mut parser1 = Parser::new("Map[x -> x * 2, [1, 2, 3]]".to_string());
let expr1 = parser1.parse_expression().unwrap();

let mut parser2 = Parser::new("Map[Function[{x}, x * 2], [1, 2, 3]]".to_string());
let expr2 = parser2.parse_expression().unwrap();

// Both should be FunctionCalls to Map with a Lambda as first arg
match (&expr1, &expr2) {
(
Expression::FunctionCall { arguments: args1, .. },
Expression::FunctionCall { arguments: args2, .. },
) => {
assert_eq!(args1.len(), args2.len());
// Both first args should be Lambda with param "x"
match (&args1[0], &args2[0]) {
(
Expression::Lambda { parameters: p1, .. },
Expression::Lambda { parameters: p2, .. },
) => {
assert_eq!(p1[0].name, p2[0].name);
}
_ => panic!("Expected both to have Lambda first args"),
}
}
_ => panic!("Expected both to be FunctionCalls"),
}
}

// ============================================
// Code Generation Tests - Lambda Shorthand
// ============================================

#[test]
fn test_codegen_arrow_lambda() {
let mut parser = Parser::new("x -> x * 2".to_string());
let expr = parser.parse_expression().unwrap();

let mut codegen = RustCodeGenerator::new();
let rust_code = codegen.generate(&expr).unwrap();

assert!(rust_code.contains("|x|"),
"Should generate Rust closure, got: {}", rust_code);
assert!(rust_code.contains("x * 2"),
"Should contain closure body, got: {}", rust_code);
}

#[test]
fn test_codegen_arrow_map() {
let mut parser = Parser::new("Map[x -> x * 2, [1, 2, 3]]".to_string());
let expr = parser.parse_expression().unwrap();

let mut codegen = RustCodeGenerator::new();
let rust_code = codegen.generate(&expr).unwrap();

assert!(rust_code.contains(".into_iter().map("),
"Should generate iterator map, got: {}", rust_code);
assert!(rust_code.contains("|x| (x * 2)"),
"Should inline arrow lambda in map, got: {}", rust_code);
assert!(rust_code.contains(".collect::<Vec<_>>()"),
"Should collect into Vec, got: {}", rust_code);
}

#[test]
fn test_codegen_arrow_filter() {
let mut parser = Parser::new("Filter[x -> x > 5, [1, 10, 3, 8]]".to_string());
let expr = parser.parse_expression().unwrap();

let mut codegen = RustCodeGenerator::new();
let rust_code = codegen.generate(&expr).unwrap();

assert!(rust_code.contains(".into_iter().filter("),
"Should generate iterator filter, got: {}", rust_code);
assert!(rust_code.contains("|&x| (x > 5)"),
"Should use pattern matching in filter, got: {}", rust_code);
}

#[test]
fn test_codegen_arrow_with_pipe() {
let mut parser = Parser::new("[1, 2, 3] |> Map[x -> x * 2]".to_string());
let expr = parser.parse_expression().unwrap();

let mut codegen = RustCodeGenerator::new();
let rust_code = codegen.generate(&expr).unwrap();

assert!(rust_code.contains(".into_iter().map(|x| (x * 2))"),
"Should generate piped map with arrow lambda, got: {}", rust_code);
}

#[test]
fn test_codegen_arrow_pipe_chain() {
let input = "[1, 50, 150, 200] |> Filter[x -> x > 100] |> Map[x -> x * 2]";
let mut parser = Parser::new(input.to_string());
let expr = parser.parse_expression().unwrap();

let mut codegen = RustCodeGenerator::new();
let rust_code = codegen.generate(&expr).unwrap();

assert!(rust_code.contains(".filter("),
"Should contain filter, got: {}", rust_code);
assert!(rust_code.contains(".map("),
"Should contain map, got: {}", rust_code);
}