diff --git a/src/arithmetic.rs b/src/arithmetic.rs index 470b3eb..90a3943 100644 --- a/src/arithmetic.rs +++ b/src/arithmetic.rs @@ -459,10 +459,13 @@ mod tests { } #[test] - fn arithmetic_scalar(){ + fn arithmetic_scalar() { let qs = "56"; let res = arithmetic(qs.as_bytes()); assert!(res.is_err()); - assert_eq!(nom::Err::Error(nom::error::Error::new(qs.as_bytes(), ErrorKind::Tag)), res.err().unwrap()); + assert_eq!( + nom::Err::Error(nom::error::Error::new(qs.as_bytes(), ErrorKind::Tag)), + res.err().unwrap() + ); } } diff --git a/src/common.rs b/src/common.rs index 1d6cd4d..cdc52f6 100644 --- a/src/common.rs +++ b/src/common.rs @@ -388,7 +388,7 @@ where let (inp, _) = first.parse(inp)?; let (inp, o2) = second.parse(inp)?; third.parse(inp).map(|(i, _)| (i, o2)) - }, + } } } } @@ -641,7 +641,8 @@ pub fn function_argument_parser(i: &[u8]) -> IResult<&[u8], FunctionArgument> { // present. pub fn function_arguments(i: &[u8]) -> IResult<&[u8], (FunctionArgument, bool)> { let distinct_parser = opt(tuple((tag_no_case("distinct"), multispace1))); - let (remaining_input, (distinct, args)) = tuple((distinct_parser, function_argument_parser))(i)?; + let (remaining_input, (distinct, args)) = + tuple((distinct_parser, function_argument_parser))(i)?; Ok((remaining_input, (args, distinct.is_some()))) } @@ -695,12 +696,25 @@ pub fn column_function(i: &[u8]) -> IResult<&[u8], FunctionExpression> { FunctionExpression::GroupConcat(FunctionArgument::Column(col.clone()), sep) }, ), - map(tuple((sql_identifier, multispace0, tag("("), separated_list0(tag(","), delimited(multispace0, function_argument_parser, multispace0)), tag(")"))), |tuple| { - let (name, _, _, arguments, _) = tuple; - FunctionExpression::Generic( - str::from_utf8(name).unwrap().to_string(), - FunctionArguments::from(arguments)) - }) + map( + tuple(( + sql_identifier, + multispace0, + tag("("), + separated_list0( + tag(","), + delimited(multispace0, function_argument_parser, multispace0), + ), + tag(")"), + )), + |tuple| { + let (name, _, _, arguments, _) = tuple; + FunctionExpression::Generic( + str::from_utf8(name).unwrap().to_string(), + FunctionArguments::from(arguments), + ) + }, + ), ))(i) } @@ -1021,22 +1035,23 @@ pub fn value_list(i: &[u8]) -> IResult<&[u8], Vec> { // Parse a reference to a named schema.table, with an optional alias pub fn schema_table_reference(i: &[u8]) -> IResult<&[u8], Table> { map( - tuple(( - opt(pair(sql_identifier, tag("."))), - sql_identifier, - opt(as_alias) - )), - |tup| Table { - name: String::from(str::from_utf8(tup.1).unwrap()), - alias: match tup.2 { - Some(a) => Some(String::from(a)), - None => None, - }, - schema: match tup.0 { - Some((schema, _)) => Some(String::from(str::from_utf8(schema).unwrap())), - None => None, + tuple(( + opt(pair(sql_identifier, tag("."))), + sql_identifier, + opt(as_alias), + )), + |tup| Table { + name: String::from(str::from_utf8(tup.1).unwrap()), + alias: match tup.2 { + Some(a) => Some(String::from(a)), + None => None, + }, + schema: match tup.0 { + Some((schema, _)) => Some(String::from(str::from_utf8(schema).unwrap())), + None => None, + }, }, - })(i) + )(i) } // Parse a reference to a named table, with an optional alias @@ -1047,7 +1062,7 @@ pub fn table_reference(i: &[u8]) -> IResult<&[u8], Table> { Some(a) => Some(String::from(a)), None => None, }, - schema: None, + schema: None, })(i) } @@ -1137,25 +1152,31 @@ mod tests { name: String::from("max(addr_id)"), alias: None, table: None, - function: Some(Box::new(FunctionExpression::Max( - FunctionArgument::Column(Column::from("addr_id")), - ))), + function: Some(Box::new(FunctionExpression::Max(FunctionArgument::Column( + Column::from("addr_id"), + )))), }; assert_eq!(res.unwrap().1, expected); } #[test] fn simple_generic_function() { - let qlist = ["coalesce(a,b,c)".as_bytes(), "coalesce (a,b,c)".as_bytes(), "coalesce(a ,b,c)".as_bytes(), "coalesce(a, b,c)".as_bytes()]; + let qlist = [ + "coalesce(a,b,c)".as_bytes(), + "coalesce (a,b,c)".as_bytes(), + "coalesce(a ,b,c)".as_bytes(), + "coalesce(a, b,c)".as_bytes(), + ]; for q in qlist.iter() { let res = column_function(q); - let expected = FunctionExpression::Generic("coalesce".to_string(), - FunctionArguments::from( - vec!( - FunctionArgument::Column(Column::from("a")), - FunctionArgument::Column(Column::from("b")), - FunctionArgument::Column(Column::from("c")) - ))); + let expected = FunctionExpression::Generic( + "coalesce".to_string(), + FunctionArguments::from(vec![ + FunctionArgument::Column(Column::from("a")), + FunctionArgument::Column(Column::from("b")), + FunctionArgument::Column(Column::from("c")), + ]), + ); assert_eq!(res, Ok((&b""[..], expected))); } } diff --git a/src/compound_select.rs b/src/compound_select.rs index d3ece89..70a5679 100644 --- a/src/compound_select.rs +++ b/src/compound_select.rs @@ -7,10 +7,10 @@ use nom::branch::alt; use nom::bytes::complete::{tag, tag_no_case}; use nom::combinator::{map, opt}; use nom::multi::many1; -use nom::sequence::{delimited, preceded, tuple}; +use nom::sequence::{preceded, tuple}; use nom::IResult; use order::{order_clause, OrderClause}; -use select::{limit_clause, nested_selection, LimitClause, SelectStatement}; +use select::{limit_clause, nested_simple_selection, LimitClause, Selection}; #[derive(Clone, Debug, Eq, Hash, PartialEq, Deserialize, Serialize)] pub enum CompoundSelectOperator { @@ -33,7 +33,7 @@ impl fmt::Display for CompoundSelectOperator { #[derive(Clone, Debug, Eq, Hash, PartialEq, Deserialize, Serialize)] pub struct CompoundSelectStatement { - pub selects: Vec<(Option, SelectStatement)>, + pub selects: Vec<(Option, Selection)>, pub order: Option, pub limit: Option, } @@ -89,43 +89,78 @@ fn compound_op(i: &[u8]) -> IResult<&[u8], CompoundSelectOperator> { ))(i) } -fn other_selects(i: &[u8]) -> IResult<&[u8], (Option, SelectStatement)> { - let (remaining_input, (_, op, _, select)) = tuple(( - multispace0, - compound_op, - multispace1, - opt_delimited( - tag("("), - delimited(multispace0, nested_selection, multispace0), - tag(")"), +// Parse terminated compound selection +pub fn compound_selection(i: &[u8]) -> IResult<&[u8], CompoundSelectStatement> { + let (remaining_input, (compound_selection, _, _)) = + tuple((nested_compound_selection, multispace0, statement_terminator))(i)?; + + Ok((remaining_input, compound_selection)) +} + +pub fn compound_selection_part(i: &[u8]) -> IResult<&[u8], Selection> { + alt(( + map(compound_selection_compound_part, |cs| cs.into()), + map( + opt_delimited(tag("("), nested_simple_selection, tag(")")), + |s| s.into(), ), + ))(i) +} + +pub fn compound_selection_compound_part(i: &[u8]) -> IResult<&[u8], CompoundSelectStatement> { + let (remaining_input, (_, lhs, op_rhs, _)) = tuple(( + tag("("), + opt_delimited(tag("("), nested_simple_selection, tag(")")), + many1(tuple((multispace1, compound_op_selection_part))), + tag(")"), ))(i)?; - Ok((remaining_input, (Some(op), select))) + let mut css = CompoundSelectStatement { + selects: vec![], + order: None, + limit: None, + }; + + css.selects.push((None, lhs.into())); + + for (_, (op, rhs)) in op_rhs { + css.selects.push((Some(op), rhs.into())) + } + + Ok((remaining_input, css)) } -// Parse compound selection -pub fn compound_selection(i: &[u8]) -> IResult<&[u8], CompoundSelectStatement> { - let (remaining_input, (first_select, other_selects, _, order, limit, _)) = tuple(( - opt_delimited(tag("("), nested_selection, tag(")")), - many1(other_selects), - multispace0, +pub fn compound_op_selection_part(i: &[u8]) -> IResult<&[u8], (CompoundSelectOperator, Selection)> { + let (remaining_input, (op, _, selection)) = + tuple((compound_op, multispace1, compound_selection_part))(i)?; + + Ok((remaining_input, (op, selection))) +} + +// Parse nested compound selection +pub fn nested_compound_selection(i: &[u8]) -> IResult<&[u8], CompoundSelectStatement> { + let (remaining_input, ((first, other_selects), order, limit)) = tuple(( + tuple(( + compound_selection_part, + many1(tuple((multispace1, compound_op_selection_part))), + )), opt(order_clause), opt(limit_clause), - statement_terminator, ))(i)?; - let mut selects = vec![(None, first_select)]; - selects.extend(other_selects); - - Ok(( - remaining_input, - CompoundSelectStatement { - selects, - order, - limit, - }, - )) + let mut css = CompoundSelectStatement { + selects: vec![], + order, + limit, + }; + + css.selects.push((None, first.into())); + + for os in other_selects { + css.selects.push((Some(os.1 .0), os.1 .1.into())); + } + + Ok((remaining_input, css)) } #[cfg(test)] @@ -133,14 +168,16 @@ mod tests { use super::*; use column::Column; use common::{FieldDefinitionExpression, FieldValueExpression, Literal}; + use select::selection; use table::Table; + use SelectStatement; #[test] fn union() { let qstr = "SELECT id, 1 FROM Vote UNION SELECT id, stars from Rating;"; let qstr2 = "(SELECT id, 1 FROM Vote) UNION (SELECT id, stars from Rating);"; - let res = compound_selection(qstr.as_bytes()); - let res2 = compound_selection(qstr2.as_bytes()); + let res = selection(qstr.as_bytes()); + let res2 = selection(qstr2.as_bytes()); let first_select = SelectStatement { tables: vec![Table::from("Vote")], @@ -162,15 +199,18 @@ mod tests { }; let expected = CompoundSelectStatement { selects: vec![ - (None, first_select), - (Some(CompoundSelectOperator::DistinctUnion), second_select), + (None, first_select.into()), + ( + Some(CompoundSelectOperator::DistinctUnion), + second_select.into(), + ), ], order: None, limit: None, }; - assert_eq!(res.unwrap().1, expected); - assert_eq!(res2.unwrap().1, expected); + assert_eq!(res.unwrap().1, expected.clone().into()); + assert_eq!(res2.unwrap().1, expected.into()); } #[test] @@ -185,29 +225,38 @@ mod tests { assert!(&res.is_err()); assert_eq!( res.unwrap_err(), - nom::Err::Error(nom::error::Error::new(");".as_bytes(), nom::error::ErrorKind::Tag)) + nom::Err::Error(nom::error::Error::new( + ");".as_bytes(), + nom::error::ErrorKind::MultiSpace + )) ); assert!(&res2.is_err()); assert_eq!( res2.unwrap_err(), - nom::Err::Error(nom::error::Error::new(";".as_bytes(), nom::error::ErrorKind::Tag)) + nom::Err::Error(nom::error::Error::new( + ";".as_bytes(), + nom::error::ErrorKind::Tag + )) ); assert!(&res3.is_err()); assert_eq!( res3.unwrap_err(), nom::Err::Error(nom::error::Error::new( ") UNION (SELECT id, stars from Rating;".as_bytes(), - nom::error::ErrorKind::Tag + nom::error::ErrorKind::MultiSpace )) ); } #[test] fn multi_union() { - let qstr = "SELECT id, 1 FROM Vote \ - UNION SELECT id, stars from Rating \ - UNION DISTINCT SELECT 42, 5 FROM Vote;"; - let res = compound_selection(qstr.as_bytes()); + let q = "SELECT id, 1 FROM Vote UNION SELECT id, stars from Rating UNION DISTINCT SELECT 42, 5 FROM Vote"; + let qstr0 = format!("{};", q); + let qstr1 = format!("({}) UNION ALL ({});", q, q); + let qstr2 = format!("{} UNION ALL {};", q, q); + let res0 = selection(qstr0.as_bytes()); + let res1 = selection(qstr1.as_bytes()); + let res2 = selection(qstr2.as_bytes()); let first_select = SelectStatement { tables: vec![Table::from("Vote")], @@ -240,23 +289,71 @@ mod tests { ..Default::default() }; - let expected = CompoundSelectStatement { + let expected0 = CompoundSelectStatement { + selects: vec![ + (None, first_select.clone().into()), + ( + Some(CompoundSelectOperator::DistinctUnion), + second_select.clone().into(), + ), + ( + Some(CompoundSelectOperator::DistinctUnion), + third_select.clone().into(), + ), + ], + order: None, + limit: None, + }; + + let expected1 = CompoundSelectStatement { selects: vec![ - (None, first_select), - (Some(CompoundSelectOperator::DistinctUnion), second_select), - (Some(CompoundSelectOperator::DistinctUnion), third_select), + (None, expected0.clone().into()), + ( + Some(CompoundSelectOperator::Union), + expected0.clone().into(), + ), ], order: None, limit: None, }; - assert_eq!(res.unwrap().1, expected); + let expected2 = CompoundSelectStatement { + selects: vec![ + (None, first_select.clone().into()), + ( + Some(CompoundSelectOperator::DistinctUnion), + second_select.clone().into(), + ), + ( + Some(CompoundSelectOperator::DistinctUnion), + third_select.clone().into(), + ), + ( + Some(CompoundSelectOperator::Union), + first_select.clone().into(), + ), + ( + Some(CompoundSelectOperator::DistinctUnion), + second_select.clone().into(), + ), + ( + Some(CompoundSelectOperator::DistinctUnion), + third_select.into(), + ), + ], + order: None, + limit: None, + }; + + assert_eq!(res0.unwrap().1, expected0.into()); + assert_eq!(res1.unwrap().1, expected1.into()); + assert_eq!(res2.unwrap().1, expected2.into()); } #[test] fn union_all() { let qstr = "SELECT id, 1 FROM Vote UNION ALL SELECT id, stars from Rating;"; - let res = compound_selection(qstr.as_bytes()); + let res = selection(qstr.as_bytes()); let first_select = SelectStatement { tables: vec![Table::from("Vote")], @@ -278,13 +375,13 @@ mod tests { }; let expected = CompoundSelectStatement { selects: vec![ - (None, first_select), - (Some(CompoundSelectOperator::Union), second_select), + (None, first_select.into()), + (Some(CompoundSelectOperator::Union), second_select.into()), ], order: None, limit: None, }; - assert_eq!(res.unwrap().1, expected); + assert_eq!(res.unwrap().1, expected.into()); } } diff --git a/src/condition.rs b/src/condition.rs index a210b7a..865c989 100644 --- a/src/condition.rs +++ b/src/condition.rs @@ -14,7 +14,7 @@ use nom::bytes::complete::{tag, tag_no_case}; use nom::combinator::{map, opt}; use nom::sequence::{delimited, pair, preceded, separated_pair, terminated, tuple}; use nom::IResult; -use select::{nested_selection, SelectStatement}; +use select::{nested_selection, nested_simple_selection, SelectStatement, Selection}; #[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)] pub enum ConditionBase { @@ -89,7 +89,7 @@ pub enum ConditionExpression { ComparisonOp(ConditionTree), LogicalOp(ConditionTree), NegationOp(Box), - ExistsOp(Box), + ExistsOp(Box), Base(ConditionBase), Arithmetic(Box), Bracketed(Box), @@ -227,9 +227,10 @@ fn in_operation(i: &[u8]) -> IResult<&[u8], (Operator, ConditionExpression)> { opt(terminated(tag_no_case("not"), multispace1)), terminated(tag_no_case("in"), multispace0), alt(( - map(delimited(tag("("), nested_selection, tag(")")), |s| { - ConditionBase::NestedSelect(Box::new(s)) - }), + map( + delimited(tag("("), nested_simple_selection, tag(")")), + |s| ConditionBase::NestedSelect(Box::new(s)), + ), map(delimited(tag("("), value_list, tag(")")), |vs| { ConditionBase::LiteralList(vs) }), @@ -291,10 +292,7 @@ fn predicate(i: &[u8]) -> IResult<&[u8], ConditionExpression> { }, ); - alt(( - simple_expr, - nested_exists, - ))(i) + alt((simple_expr, nested_exists))(i) } fn simple_expr(i: &[u8]) -> IResult<&[u8], ConditionExpression> { @@ -320,9 +318,10 @@ fn simple_expr(i: &[u8]) -> IResult<&[u8], ConditionExpression> { map(column_identifier, |f| { ConditionExpression::Base(ConditionBase::Field(f)) }), - map(delimited(tag("("), nested_selection, tag(")")), |s| { - ConditionExpression::Base(ConditionBase::NestedSelect(Box::new(s))) - }), + map( + delimited(tag("("), nested_simple_selection, tag(")")), + |s| ConditionExpression::Base(ConditionBase::NestedSelect(Box::new(s))), + ), ))(i) } @@ -745,11 +744,14 @@ mod tests { let res = condition_expr(cond.as_bytes()); - let nested_select = Box::new(SelectStatement { - tables: vec![Table::from("foo")], - fields: columns(&["col"]), - ..Default::default() - }); + let nested_select = Box::new( + SelectStatement { + tables: vec![Table::from("foo")], + fields: columns(&["col"]), + ..Default::default() + } + .into(), + ); let expected = ConditionExpression::ExistsOp(nested_select); @@ -766,11 +768,14 @@ mod tests { let res = condition_expr(cond.as_bytes()); - let nested_select = Box::new(SelectStatement { - tables: vec![Table::from("foo")], - fields: columns(&["col"]), - ..Default::default() - }); + let nested_select = Box::new( + SelectStatement { + tables: vec![Table::from("foo")], + fields: columns(&["col"]), + ..Default::default() + } + .into(), + ); let expected = ConditionExpression::NegationOp(Box::new(ConditionExpression::ExistsOp(nested_select))); diff --git a/src/create.rs b/src/create.rs index a1b5afc..f3f11cf 100644 --- a/src/create.rs +++ b/src/create.rs @@ -5,8 +5,8 @@ use std::str::FromStr; use column::{Column, ColumnConstraint, ColumnSpecification}; use common::{ - column_identifier_no_alias, parse_comment, sql_identifier, statement_terminator, - schema_table_reference, type_identifier, ws_sep_comma, Literal, Real, SqlType, TableKey, + column_identifier_no_alias, parse_comment, schema_table_reference, sql_identifier, + statement_terminator, type_identifier, ws_sep_comma, Literal, Real, SqlType, TableKey, }; use compound_select::{compound_selection, CompoundSelectStatement}; use create_table_options::table_options; @@ -18,7 +18,7 @@ use nom::multi::{many0, many1}; use nom::sequence::{delimited, preceded, terminated, tuple}; use nom::IResult; use order::{order_type, OrderType}; -use select::{nested_selection, SelectStatement}; +use select::{nested_simple_selection, SelectStatement}; use table::Table; #[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize, Deserialize)] @@ -431,7 +431,7 @@ pub fn view_creation(i: &[u8]) -> IResult<&[u8], CreateViewStatement> { multispace1, alt(( map(compound_selection, |s| SelectSpecification::Compound(s)), - map(nested_selection, |s| SelectSpecification::Simple(s)), + map(nested_simple_selection, |s| SelectSpecification::Simple(s)), )), statement_terminator, ))(i)?; @@ -534,7 +534,7 @@ mod tests { assert_eq!( res.unwrap().1, CreateTableStatement { - table: Table::from(("db1","t")), + table: Table::from(("db1", "t")), fields: vec![ColumnSpecification::new( Column::from("t.x"), SqlType::Int(32) @@ -833,7 +833,8 @@ mod tests { tables: vec![Table::from("users")], fields: vec![FieldDefinitionExpression::All], ..Default::default() - }, + } + .into(), ), ( Some(CompoundSelectOperator::DistinctUnion), @@ -841,7 +842,8 @@ mod tests { tables: vec![Table::from("old_users")], fields: vec![FieldDefinitionExpression::All], ..Default::default() - }, + } + .into(), ), ], order: None, diff --git a/src/delete.rs b/src/delete.rs index 64ee1dc..40f3cd1 100644 --- a/src/delete.rs +++ b/src/delete.rs @@ -1,7 +1,7 @@ use nom::character::complete::multispace1; use std::{fmt, str}; -use common::{statement_terminator, schema_table_reference}; +use common::{schema_table_reference, statement_terminator}; use condition::ConditionExpression; use keywords::escape_if_keyword; use nom::bytes::complete::tag_no_case; @@ -77,7 +77,7 @@ mod tests { assert_eq!( res.unwrap().1, DeleteStatement { - table: Table::from(("db1","users")), + table: Table::from(("db1", "users")), ..Default::default() } ); diff --git a/src/insert.rs b/src/insert.rs index 9f55107..5c81427 100644 --- a/src/insert.rs +++ b/src/insert.rs @@ -4,7 +4,7 @@ use std::str; use column::Column; use common::{ - assignment_expr_list, field_list, statement_terminator, schema_table_reference, value_list, + assignment_expr_list, field_list, schema_table_reference, statement_terminator, value_list, ws_sep_comma, FieldValueExpression, Literal, }; use keywords::escape_if_keyword; @@ -145,7 +145,7 @@ mod tests { assert_eq!( res.unwrap().1, InsertStatement { - table: Table::from(("db1","users")), + table: Table::from(("db1", "users")), fields: None, data: vec![vec![42.into(), "test".into()]], ..Default::default() diff --git a/src/join.rs b/src/join.rs index b91f5dc..35e3638 100644 --- a/src/join.rs +++ b/src/join.rs @@ -7,7 +7,7 @@ use nom::branch::alt; use nom::bytes::complete::tag_no_case; use nom::combinator::map; use nom::IResult; -use select::{JoinClause, SelectStatement}; +use select::{JoinClause, Selection}; use table::Table; #[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)] @@ -17,7 +17,7 @@ pub enum JoinRightSide { /// A comma-separated (and implicitly joined) sequence of tables. Tables(Vec), /// A nested selection, represented as (query, alias). - NestedSelect(Box, Option), + NestedSelect(Box, Option), /// A nested join clause. NestedJoin(Box), } @@ -111,14 +111,14 @@ mod tests { use condition::ConditionBase::*; use condition::ConditionExpression::{self, *}; use condition::ConditionTree; - use select::{selection, JoinClause, SelectStatement}; + use select::{simple_selection, JoinClause, SelectStatement}; #[test] fn inner_join() { let qstring = "SELECT tags.* FROM tags \ INNER JOIN taggings ON tags.id = taggings.tag_id"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); let ct = ConditionTree { left: Box::new(Base(Field(Column::from("tags.id")))), diff --git a/src/order.rs b/src/order.rs index 3db9ddc..273908a 100644 --- a/src/order.rs +++ b/src/order.rs @@ -82,7 +82,7 @@ pub fn order_clause(i: &[u8]) -> IResult<&[u8], OrderClause> { #[cfg(test)] mod tests { use super::*; - use select::selection; + use select::simple_selection; #[test] fn order_clause() { @@ -103,9 +103,9 @@ mod tests { columns: vec![("name".into(), OrderType::OrderAscending)], }; - let res1 = selection(qstring1.as_bytes()); - let res2 = selection(qstring2.as_bytes()); - let res3 = selection(qstring3.as_bytes()); + let res1 = simple_selection(qstring1.as_bytes()); + let res2 = simple_selection(qstring2.as_bytes()); + let res3 = simple_selection(qstring3.as_bytes()); assert_eq!(res1.unwrap().1.order, Some(expected_ord1)); assert_eq!(res2.unwrap().1.order, Some(expected_ord2)); assert_eq!(res3.unwrap().1.order, Some(expected_ord3)); diff --git a/src/parser.rs b/src/parser.rs index ae68592..a0dfc42 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -1,7 +1,7 @@ use std::fmt; use std::str; -use compound_select::{compound_selection, CompoundSelectStatement}; +use compound_select::CompoundSelectStatement; use create::{creation, view_creation, CreateTableStatement, CreateViewStatement}; use delete::{deletion, DeleteStatement}; use drop::{drop_table, DropTableStatement}; @@ -9,7 +9,7 @@ use insert::{insertion, InsertStatement}; use nom::branch::alt; use nom::combinator::map; use nom::IResult; -use select::{selection, SelectStatement}; +use select::{selection, SelectStatement, Selection}; use set::{set, SetStatement}; use update::{updating, UpdateStatement}; @@ -42,12 +42,20 @@ impl fmt::Display for SqlQuery { } } +impl From for SqlQuery { + fn from(s: Selection) -> Self { + match s { + Selection::Statement(ss) => SqlQuery::Select(ss), + Selection::Compound(css) => SqlQuery::CompoundSelect(css), + } + } +} + pub fn sql_query(i: &[u8]) -> IResult<&[u8], SqlQuery> { alt(( map(creation, |c| SqlQuery::CreateTable(c)), map(insertion, |i| SqlQuery::Insert(i)), - map(compound_selection, |cs| SqlQuery::CompoundSelect(cs)), - map(selection, |s| SqlQuery::Select(s)), + map(selection, |s| s.into()), map(deletion, |d| SqlQuery::Delete(d)), map(drop_table, |dt| SqlQuery::DropTable(dt)), map(updating, |u| SqlQuery::Update(u)), diff --git a/src/select.rs b/src/select.rs index 8735e95..5428367 100644 --- a/src/select.rs +++ b/src/select.rs @@ -1,23 +1,26 @@ use nom::character::complete::{multispace0, multispace1}; use std::fmt; +use std::fmt::{Display, Formatter}; use std::str; use column::Column; -use common::FieldDefinitionExpression; use common::{ as_alias, field_definition_expr, field_list, statement_terminator, table_list, table_reference, unsigned_number, }; +use common::{sql_identifier, FieldDefinitionExpression}; +use compound_select::nested_compound_selection; use condition::{condition_expr, ConditionExpression}; use join::{join_operator, JoinConstraint, JoinOperator, JoinRightSide}; use nom::branch::alt; use nom::bytes::complete::{tag, tag_no_case}; use nom::combinator::{map, opt}; -use nom::multi::many0; +use nom::multi::{many0, separated_list1}; use nom::sequence::{delimited, preceded, terminated, tuple}; use nom::IResult; use order::{order_clause, OrderClause}; use table::Table; +use CompoundSelectStatement; #[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)] pub struct GroupByClause { @@ -76,8 +79,93 @@ impl fmt::Display for LimitClause { } } +#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)] +pub enum Selection { + Statement(SelectStatement), + Compound(CompoundSelectStatement), +} + +impl Display for Selection { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Self::Statement(s) => write!(f, "{}", s), + Self::Compound(cs) => write!(f, "{}", cs), + } + } +} + +impl From for Selection { + fn from(ss: SelectStatement) -> Self { + Self::Statement(ss) + } +} + +impl From for Selection { + fn from(css: CompoundSelectStatement) -> Self { + Self::Compound(css) + } +} + +#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)] +pub struct WithClause { + pub recursive: bool, + pub subclauses: Vec, +} + +impl fmt::Display for WithClause { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "WITH ")?; + + if self.recursive { + write!(f, "RECURSIVE ")?; + } + + write!( + f, + "{}", + self.subclauses + .iter() + .map(|c| format!("{}", c)) + .collect::>() + .join(", ") + )?; + + Ok(()) + } +} + +#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)] +pub struct WithSubclause { + pub name: String, + pub columns: Vec, + pub selection: Box, +} + +impl fmt::Display for WithSubclause { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{} ", self.name)?; + + if self.columns.len() > 0 { + write!( + f, + "({}) ", + self.columns + .iter() + .map(|c| format!("{}", c)) + .collect::>() + .join(", ") + )?; + } + + write!(f, "AS ({})", self.selection)?; + + Ok(()) + } +} + #[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize, Deserialize)] pub struct SelectStatement { + pub with: Option, pub tables: Vec
, pub distinct: bool, pub fields: Vec, @@ -90,6 +178,10 @@ pub struct SelectStatement { impl fmt::Display for SelectStatement { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + if let Some(ref with_clause) = self.with { + write!(f, "{}", with_clause)?; + } + write!(f, "SELECT ")?; if self.distinct { write!(f, "DISTINCT ")?; @@ -268,16 +360,30 @@ pub fn where_clause(i: &[u8]) -> IResult<&[u8], ConditionExpression> { Ok((remaining_input, where_condition)) } -// Parse rule for a SQL selection query. -pub fn selection(i: &[u8]) -> IResult<&[u8], SelectStatement> { - terminated(nested_selection, statement_terminator)(i) +pub fn selection(i: &[u8]) -> IResult<&[u8], Selection> { + terminated(nested_selection, opt(statement_terminator))(i) } -pub fn nested_selection(i: &[u8]) -> IResult<&[u8], SelectStatement> { +pub fn nested_selection(i: &[u8]) -> IResult<&[u8], Selection> { + alt(( + map(nested_compound_selection, |cs| Selection::Compound(cs)), + map(nested_simple_selection, |s| Selection::Statement(s)), + ))(i) +} + +#[cfg(test)] +// Parse rule for a simple SQL selection query, currently only used to simplify tests +pub fn simple_selection(i: &[u8]) -> IResult<&[u8], SelectStatement> { + terminated(nested_simple_selection, statement_terminator)(i) +} + +pub fn nested_simple_selection(i: &[u8]) -> IResult<&[u8], SelectStatement> { let ( remaining_input, - (_, _, distinct, _, fields, _, tables, join, where_clause, group_by, order, limit), + (with, _, _, _, distinct, _, fields, _, tables, join, where_clause, group_by, order, limit), ) = tuple(( + opt(with_clause), + multispace0, tag_no_case("select"), multispace1, opt(tag_no_case("distinct")), @@ -294,6 +400,7 @@ pub fn nested_selection(i: &[u8]) -> IResult<&[u8], SelectStatement> { Ok(( remaining_input, SelectStatement { + with, tables, distinct: distinct.is_some(), fields, @@ -306,6 +413,60 @@ pub fn nested_selection(i: &[u8]) -> IResult<&[u8], SelectStatement> { )) } +pub fn with_clause(i: &[u8]) -> IResult<&[u8], WithClause> { + map( + tuple(( + tag_no_case("with"), + multispace1, + opt(tag_no_case("recursive")), + multispace0, + separated_list1(tuple((multispace0, tag(","), multispace0)), with_subclause), + )), + |(_, _, recursive, _, subclauses)| WithClause { + recursive: recursive.is_some(), + subclauses, + }, + )(i) +} + +pub fn with_subclause(i: &[u8]) -> IResult<&[u8], WithSubclause> { + map( + tuple(( + sql_identifier, + multispace1, + opt(with_clause_column_list), + multispace0, + tag_no_case("as"), + multispace1, + tag("("), + multispace0, + nested_selection, + multispace0, + tag(")"), + )), + |(name, _, columns, _, _, _, _, _, selection, _, _)| WithSubclause { + name: str::from_utf8(name).unwrap().to_string(), + columns: columns.unwrap_or(vec![]), + selection: Box::new(selection), + }, + )(i) +} + +pub fn with_clause_column_list(i: &[u8]) -> IResult<&[u8], Vec> { + let (i, (_, _, columns, _, _)) = tuple(( + tag("("), + multispace0, + separated_list1( + tuple((multispace0, tag(","), multispace0)), + map(sql_identifier, |si| str::from_utf8(si).unwrap().into()), + ), + multispace0, + tag(")"), + ))(i)?; + + Ok((i, columns)) +} + #[cfg(test)] mod tests { use super::*; @@ -330,7 +491,7 @@ mod tests { fn simple_select() { let qstring = "SELECT id, name FROM users;"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); assert_eq!( res.unwrap().1, SelectStatement { @@ -345,7 +506,7 @@ mod tests { fn more_involved_select() { let qstring = "SELECT users.id, users.name FROM users;"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); assert_eq!( res.unwrap().1, SelectStatement { @@ -364,7 +525,7 @@ mod tests { // TODO: doesn't support selecting literals without a FROM clause, which is still valid SQL // let qstring = "SELECT NULL, 1, \"foo\";"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); assert_eq!( res.unwrap().1, SelectStatement { @@ -392,7 +553,7 @@ mod tests { fn select_all() { let qstring = "SELECT * FROM users;"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); assert_eq!( res.unwrap().1, SelectStatement { @@ -407,7 +568,7 @@ mod tests { fn select_all_in_table() { let qstring = "SELECT users.* FROM users, votes;"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); assert_eq!( res.unwrap().1, SelectStatement { @@ -422,7 +583,7 @@ mod tests { fn spaces_optional() { let qstring = "SELECT id,name FROM users;"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); assert_eq!( res.unwrap().1, SelectStatement { @@ -439,8 +600,8 @@ mod tests { let qstring_uc = "SELECT id, name FROM users;"; assert_eq!( - selection(qstring_lc.as_bytes()).unwrap(), - selection(qstring_uc.as_bytes()).unwrap() + simple_selection(qstring_lc.as_bytes()).unwrap(), + simple_selection(qstring_uc.as_bytes()).unwrap() ); } @@ -450,9 +611,9 @@ mod tests { let qstring_nosem = "select id, name from users"; let qstring_linebreak = "select id, name from users\n"; - let r1 = selection(qstring_sem.as_bytes()).unwrap(); - let r2 = selection(qstring_nosem.as_bytes()).unwrap(); - let r3 = selection(qstring_linebreak.as_bytes()).unwrap(); + let r1 = simple_selection(qstring_sem.as_bytes()).unwrap(); + let r2 = simple_selection(qstring_nosem.as_bytes()).unwrap(); + let r3 = simple_selection(qstring_linebreak.as_bytes()).unwrap(); assert_eq!(r1, r2); assert_eq!(r2, r3); } @@ -482,7 +643,7 @@ mod tests { } fn where_clause_with_variable_placeholder(qstring: &str, literal: Literal) { - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); let expected_left = Base(Field(Column::from("email"))); let expected_where_cond = Some(ComparisonOp(ConditionTree { @@ -515,8 +676,8 @@ mod tests { offset: 10, }; - let res1 = selection(qstring1.as_bytes()); - let res2 = selection(qstring2.as_bytes()); + let res1 = simple_selection(qstring1.as_bytes()); + let res2 = simple_selection(qstring2.as_bytes()); assert_eq!(res1.unwrap().1.limit, Some(expected_lim1)); assert_eq!(res2.unwrap().1.limit, Some(expected_lim2)); } @@ -526,14 +687,14 @@ mod tests { let qstring1 = "select * from PaperTag as t;"; // let qstring2 = "select * from PaperTag t;"; - let res1 = selection(qstring1.as_bytes()); + let res1 = simple_selection(qstring1.as_bytes()); assert_eq!( res1.unwrap().1, SelectStatement { tables: vec![Table { name: String::from("PaperTag"), alias: Some(String::from("t")), - schema: None, + schema: None, },], fields: vec![FieldDefinitionExpression::All], ..Default::default() @@ -547,14 +708,14 @@ mod tests { fn table_schema() { let qstring1 = "select * from db1.PaperTag as t;"; - let res1 = selection(qstring1.as_bytes()); + let res1 = simple_selection(qstring1.as_bytes()); assert_eq!( res1.unwrap().1, SelectStatement { tables: vec![Table { name: String::from("PaperTag"), alias: Some(String::from("t")), - schema: Some(String::from("db1")), + schema: Some(String::from("db1")), },], fields: vec![FieldDefinitionExpression::All], ..Default::default() @@ -569,7 +730,7 @@ mod tests { let qstring1 = "select name as TagName from PaperTag;"; let qstring2 = "select PaperTag.name as TagName from PaperTag;"; - let res1 = selection(qstring1.as_bytes()); + let res1 = simple_selection(qstring1.as_bytes()); assert_eq!( res1.unwrap().1, SelectStatement { @@ -583,7 +744,7 @@ mod tests { ..Default::default() } ); - let res2 = selection(qstring2.as_bytes()); + let res2 = simple_selection(qstring2.as_bytes()); assert_eq!( res2.unwrap().1, SelectStatement { @@ -604,7 +765,7 @@ mod tests { let qstring1 = "select name TagName from PaperTag;"; let qstring2 = "select PaperTag.name TagName from PaperTag;"; - let res1 = selection(qstring1.as_bytes()); + let res1 = simple_selection(qstring1.as_bytes()); assert_eq!( res1.unwrap().1, SelectStatement { @@ -618,7 +779,7 @@ mod tests { ..Default::default() } ); - let res2 = selection(qstring2.as_bytes()); + let res2 = simple_selection(qstring2.as_bytes()); assert_eq!( res2.unwrap().1, SelectStatement { @@ -638,7 +799,7 @@ mod tests { fn distinct() { let qstring = "select distinct tag from PaperTag where paperId=?;"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); let expected_left = Base(Field(Column::from("paperId"))); let expected_where_cond = Some(ComparisonOp(ConditionTree { left: Box::new(expected_left), @@ -663,7 +824,7 @@ mod tests { fn simple_condition_expr() { let qstring = "select infoJson from PaperStorage where paperId=? and paperStorageId=?;"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); let left_ct = ConditionTree { left: Box::new(Base(Field(Column::from("paperId")))), @@ -700,7 +861,7 @@ mod tests { #[test] fn where_and_limit_clauses() { let qstring = "select * from users where id = ? limit 10\n"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); let expected_lim = Some(LimitClause { limit: 10, @@ -731,7 +892,7 @@ mod tests { fn aggregation_column() { let qstring = "SELECT max(addr_id) FROM address;"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); let agg_expr = FunctionExpression::Max(FunctionArgument::Column(Column::from("addr_id"))); assert_eq!( res.unwrap().1, @@ -752,7 +913,7 @@ mod tests { fn aggregation_column_with_alias() { let qstring = "SELECT max(addr_id) AS max_addr FROM address;"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); let agg_expr = FunctionExpression::Max(FunctionArgument::Column(Column::from("addr_id"))); let expected_stmt = SelectStatement { tables: vec![Table::from("address")], @@ -771,7 +932,7 @@ mod tests { fn count_all() { let qstring = "SELECT COUNT(*) FROM votes GROUP BY aid;"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); let agg_expr = FunctionExpression::CountStar; let expected_stmt = SelectStatement { tables: vec![Table::from("votes")], @@ -794,7 +955,7 @@ mod tests { fn count_distinct() { let qstring = "SELECT COUNT(DISTINCT vote_id) FROM votes GROUP BY aid;"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); let agg_expr = FunctionExpression::Count(FunctionArgument::Column(Column::from("vote_id")), true); let expected_stmt = SelectStatement { @@ -818,7 +979,7 @@ mod tests { fn count_filter() { let qstring = "SELECT COUNT(CASE WHEN vote_id > 10 THEN vote_id END) FROM votes GROUP BY aid;"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); let filter_cond = ComparisonOp(ConditionTree { left: Box::new(Base(Field(Column::from("vote_id")))), @@ -854,7 +1015,7 @@ mod tests { fn sum_filter() { let qstring = "SELECT SUM(CASE WHEN sign = 1 THEN vote_id END) FROM votes GROUP BY aid;"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); let filter_cond = ComparisonOp(ConditionTree { left: Box::new(Base(Field(Column::from("sign")))), @@ -891,7 +1052,7 @@ mod tests { let qstring = "SELECT SUM(CASE WHEN sign = 1 THEN vote_id ELSE 6 END) FROM votes GROUP BY aid;"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); let filter_cond = ComparisonOp(ConditionTree { left: Box::new(Base(Field(Column::from("sign")))), @@ -930,7 +1091,7 @@ mod tests { FROM votes GROUP BY votes.comment_id;"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); let filter_cond = LogicalOp(ConditionTree { left: Box::new(ComparisonOp(ConditionTree { @@ -974,7 +1135,7 @@ mod tests { fn generic_function_query() { let qstring = "SELECT coalesce(a, b,c) as x,d FROM sometable;"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); let agg_expr = FunctionExpression::Generic( String::from("coalesce"), FunctionArguments { @@ -1026,7 +1187,7 @@ mod tests { let qstring = "SELECT * FROM item, author WHERE item.i_a_id = author.a_id AND \ item.i_subject = ? ORDER BY item.i_title limit 50;"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); let expected_where_cond = Some(LogicalOp(ConditionTree { left: Box::new(ComparisonOp(ConditionTree { left: Box::new(Base(Field(Column::from("item.i_a_id")))), @@ -1064,7 +1225,7 @@ mod tests { fn simple_joins() { let qstring = "select paperId from PaperConflict join PCMember using (contactId);"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); let expected_stmt = SelectStatement { tables: vec![Table::from("PaperConflict")], fields: columns(&["paperId"]), @@ -1086,7 +1247,7 @@ mod tests { join PaperReview on (PCMember.contactId=PaperReview.contactId) \ order by contactId;"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); let ct = ConditionTree { left: Box::new(Base(Field(Column::from("PCMember.contactId")))), right: Box::new(Base(Field(Column::from("PaperReview.contactId")))), @@ -1113,7 +1274,7 @@ mod tests { from PCMember \ join PaperReview on PCMember.contactId=PaperReview.contactId \ order by contactId;"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); assert_eq!(res.unwrap().1, expected); } @@ -1133,7 +1294,7 @@ mod tests { (contactId) left join ChairAssistant using (contactId) left join Chair \ using (contactId) where ContactInfo.contactId=?;"; - let res = selection(qstring.as_bytes()); + let res = simple_selection(qstring.as_bytes()); let ct = ConditionTree { left: Box::new(Base(Field(Column::from("ContactInfo.contactId")))), right: Box::new(Base(Literal(Literal::Placeholder( @@ -1177,7 +1338,7 @@ mod tests { WHERE orders.o_c_id IN (SELECT o_c_id FROM orders, order_line \ WHERE orders.o_id = order_line.ol_o_id);"; - let res = selection(qstr.as_bytes()); + let res = simple_selection(qstr.as_bytes()); let inner_where_clause = ComparisonOp(ConditionTree { left: Box::new(Base(Field(Column::from("orders.o_id")))), right: Box::new(Base(Field(Column::from("order_line.ol_o_id")))), @@ -1214,7 +1375,7 @@ mod tests { WHERE orders.o_id = order_line.ol_o_id \ AND orders.o_id > (SELECT MAX(o_id) FROM orders));"; - let res = selection(qstr.as_bytes()); + let res = simple_selection(qstr.as_bytes()); let agg_expr = FunctionExpression::Max(FunctionArgument::Column(Column::from("o_id"))); let recursive_select = SelectStatement { @@ -1286,7 +1447,7 @@ mod tests { let qstr_with_alias = "SELECT o_id, ol_i_id FROM orders JOIN \ (SELECT ol_i_id FROM order_line) AS ids \ ON (orders.o_id = ids.ol_i_id);"; - let res = selection(qstr_with_alias.as_bytes()); + let res = simple_selection(qstr_with_alias.as_bytes()); // N.B.: Don't alias the inner select to `inner`, which is, well, a SQL keyword! let inner_select = SelectStatement { @@ -1300,7 +1461,10 @@ mod tests { fields: columns(&["o_id", "ol_i_id"]), join: vec![JoinClause { operator: JoinOperator::Join, - right: JoinRightSide::NestedSelect(Box::new(inner_select), Some("ids".into())), + right: JoinRightSide::NestedSelect( + Box::new(inner_select.into()), + Some("ids".into()), + ), constraint: JoinConstraint::On(ComparisonOp(ConditionTree { operator: Operator::Equal, left: Box::new(Base(Field(Column::from("orders.o_id")))), @@ -1340,7 +1504,7 @@ mod tests { ..Default::default() }; - assert_eq!(res.unwrap().1, expected); + assert_eq!(res.unwrap().1, expected.into()); } #[test] @@ -1370,7 +1534,7 @@ mod tests { ..Default::default() }; - assert_eq!(res.unwrap().1, expected); + assert_eq!(res.unwrap().1, expected.into()); } #[test] @@ -1407,6 +1571,115 @@ mod tests { ..Default::default() }; - assert_eq!(res.unwrap().1, expected); + assert_eq!(res.unwrap().1, expected.into()); + } + + #[test] + fn with() { + let qstr0 = "WITH cte1 AS (SELECT a, b FROM table1)"; + let qstr1 = "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, d FROM table2)"; + let qstr2 = + "WITH cte1 (e, f) AS (SELECT a, b FROM table1), cte2 AS (SELECT c, d FROM table2)"; + let qstr3 = "WITH RECURSIVE cte1 AS (SELECT a, b FROM table1)"; + let res0 = with_clause(qstr0.as_bytes()); + let res1 = with_clause(qstr1.as_bytes()); + let res2 = with_clause(qstr2.as_bytes()); + let res3 = with_clause(qstr3.as_bytes()); + + let expected_ss0 = Box::new(Selection::Statement(SelectStatement { + with: None, + tables: vec![Table { + name: "table1".to_string(), + alias: None, + schema: None, + }], + distinct: false, + fields: vec![ + FieldDefinitionExpression::Col(Column::from("a")), + FieldDefinitionExpression::Col(Column::from("b")), + ], + join: vec![], + where_clause: None, + group_by: None, + order: None, + limit: None, + })); + let expected_ss1 = Box::new(Selection::Statement(SelectStatement { + tables: vec![Table { + name: "table2".to_string(), + alias: None, + schema: None, + }], + fields: vec![ + FieldDefinitionExpression::Col(Column::from("c")), + FieldDefinitionExpression::Col(Column::from("d")), + ], + ..Default::default() + })); + + let expected0 = WithClause { + recursive: false, + subclauses: vec![WithSubclause { + name: "cte1".to_string(), + columns: vec![], + selection: expected_ss0.clone(), + }], + }; + let expected1 = WithClause { + recursive: false, + subclauses: vec![ + WithSubclause { + name: "cte1".to_string(), + columns: vec![], + selection: expected_ss0.clone(), + }, + WithSubclause { + name: "cte2".to_string(), + columns: vec![], + selection: expected_ss1.clone(), + }, + ], + }; + let expected2 = WithClause { + recursive: false, + subclauses: vec![ + WithSubclause { + name: "cte1".to_string(), + columns: vec![ + Column { + name: "e".to_string(), + alias: None, + table: None, + function: None, + }, + Column { + name: "f".to_string(), + alias: None, + table: None, + function: None, + }, + ], + selection: expected_ss0.clone(), + }, + WithSubclause { + name: "cte2".to_string(), + columns: vec![], + selection: expected_ss1.clone(), + }, + ], + }; + let expected3 = WithClause { + recursive: true, + subclauses: vec![WithSubclause { + name: "cte1".to_string(), + columns: vec![], + selection: expected_ss0.clone(), + }], + }; + + assert_eq!(res0.unwrap().1, expected0); + assert_eq!(res1.unwrap().1, expected1); + assert_eq!(res2.unwrap().1, expected2); + assert_eq!(res3.unwrap().1, expected3); } } diff --git a/tests/cte-queries.txt b/tests/cte-queries.txt new file mode 100644 index 0000000..112d8f3 --- /dev/null +++ b/tests/cte-queries.txt @@ -0,0 +1,17 @@ +-- simple CTE +WITH cte1 AS (SELECT a, b FROM table1) SELECT b, d FROM cte1; + +-- 2 CTEs +WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, d FROM table2) SELECT b, d FROM cte1 JOIN cte2 ON cte1.a = cte2.c; + +-- CTE in an exists +SELECT 'found' FROM DUAL WHERE EXISTS (WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, d FROM table2) SELECT b, d FROM cte1 JOIN cte2 ON cte1.a = cte2.c); + +-- recursive cte +WITH RECURSIVE cte1 AS (SELECT 1 AS a, 0 AS b FROM dual UNION ALL SELECT cte1.a + 1, cte1.b - 1 FROM table1), cte2 AS (SELECT c, d FROM table2) SELECT b, d FROM cte1 JOIN cte2 ON cte1.a = cte2.c; + +-- recursive cte with multiple initialization parts +WITH RECURSIVE cte1 AS (SELECT 1 AS a, 0 AS b FROM dual UNION SELECT MAX(a) as a, MAX(b) as b FROM dual UNION ALL SELECT cte1.a + 1 AS a, cte1.b - 1 AS b FROM cte1), cte2 AS (SELECT c, d FROM table2) SELECT b, d FROM cte1 JOIN cte2 ON cte1.a = cte2.c; + +-- recursive cte with multiple recursive parts +WITH RECURSIVE cte1 AS (SELECT 1 AS a, 0 AS b FROM dual UNION SELECT MAX(a) as a, MAX(b) as b FROM dual UNION ALL SELECT cte1.a + 1, cte1.b - 1 FROM cte1 UNION SELECT MAX(a) as a, MAX(b) as b FROM table2 ), cte2 AS (SELECT c, d FROM table2) SELECT b, d FROM cte1 JOIN cte2 ON cte1.a = cte2.c; diff --git a/tests/exists-queries.txt b/tests/exists-queries.txt index 2180af6..8513bd3 100644 --- a/tests/exists-queries.txt +++ b/tests/exists-queries.txt @@ -3,4 +3,5 @@ SELECT * FROM employees e WHERE exists(SELECT id FROM eotm_dyn d WHERE d.employe SELECT * FROM employees e WHERE not exists ( SELECT id FROM eotm_dyn d WHERE d.employeeID = e.id) SELECT * FROM employees e WHERE not (exists ( SELECT id FROM eotm_dyn d WHERE d.employeeID = e.id)) SELECT * FROM employees e WHERE x > 3 and not exists (SELECT id FROM eotm_dyn d WHERE d.employeeID = e.id ) and y < 3 +SELECT * FROM employees e WHERE x > 3 and not exists (SELECT id FROM eotm_dyn d WHERE d.employeeID = e.id UNION SELECT id FROM eotm_dyn d WHERE d.employeeID IS NULL ) and y < 3 diff --git a/tests/lib.rs b/tests/lib.rs index e0d7263..daac868 100644 --- a/tests/lib.rs +++ b/tests/lib.rs @@ -123,6 +123,14 @@ fn tpcw_test_tables() { assert_eq!(res.unwrap(), 10); } +#[test] +fn cte_queries() { + let res = test_queries_from_file(Path::new("tests/cte-queries.txt"), "CTE queries"); + assert!(res.is_ok()); + // There are 6 queries + assert_eq!(res.unwrap(), 6); +} + #[test] fn exists_test_queries() { let res = test_queries_from_file( @@ -131,7 +139,7 @@ fn exists_test_queries() { ); assert!(res.is_ok()); // There are 4 queries - assert_eq!(res.unwrap(), 4); + assert_eq!(res.unwrap(), 5); } #[test] @@ -201,6 +209,14 @@ fn parse_comments() { assert_eq!(fail, 0); } +#[test] +fn parse_nested_compound_selects() { + let (ok, fail) = parse_file("tests/nested-compound-selects.txt"); + + assert_eq!(ok, 4); + assert_eq!(fail, 0); +} + #[test] fn parse_autoincrement() { let (ok, fail) = parse_file("tests/autoincrement.txt"); diff --git a/tests/nested-compound-selects.txt b/tests/nested-compound-selects.txt new file mode 100644 index 0000000..888f2b6 --- /dev/null +++ b/tests/nested-compound-selects.txt @@ -0,0 +1,5 @@ +SELECT a, b FROM table1 JOIN (SELECT c, d FROM table2 WHERE c = d UNION SELECT a, b FROM table 3 WHERE a = b); +SELECT c, d FROM table2 WHERE c = d UNION SELECT a, b FROM table3 WHERE a = b UNION SELECT c, d FROM table2 WHERE c = d UNION SELECT a, b FROM table3 WHERE a = b; +(SELECT c, d FROM table2 WHERE c = d UNION SELECT a, b FROM table3 WHERE a = b UNION SELECT c, d FROM table2 WHERE c = d UNION SELECT a, b FROM table3 WHERE a = b) UNION ALL (SELECT c, d FROM table2 WHERE c = d UNION SELECT a, b FROM table3 WHERE a = b UNION SELECT c, d FROM table2 WHERE c = d UNION SELECT a, b FROM table3 WHERE a = b); +SELECT a, b FROM table1 WHERE a IN (SELECT c FROM table2 WHERE c = d UNION SELECT b FROM table3 WHERE a > b); +