diff --git a/src/arithmetic.rs b/src/arithmetic.rs index 470b3eb..90a3943 100644 --- a/src/arithmetic.rs +++ b/src/arithmetic.rs @@ -459,10 +459,13 @@ mod tests { } #[test] - fn arithmetic_scalar(){ + fn arithmetic_scalar() { let qs = "56"; let res = arithmetic(qs.as_bytes()); assert!(res.is_err()); - assert_eq!(nom::Err::Error(nom::error::Error::new(qs.as_bytes(), ErrorKind::Tag)), res.err().unwrap()); + assert_eq!( + nom::Err::Error(nom::error::Error::new(qs.as_bytes(), ErrorKind::Tag)), + res.err().unwrap() + ); } } diff --git a/src/common.rs b/src/common.rs index 1d6cd4d..9002b92 100644 --- a/src/common.rs +++ b/src/common.rs @@ -16,7 +16,9 @@ use nom::combinator::opt; use nom::error::{ErrorKind, ParseError}; use nom::multi::{fold_many0, many0, many1, separated_list0}; use nom::sequence::{delimited, pair, preceded, separated_pair, terminated, tuple}; +use select::join_clause; use table::Table; +use JoinClause; #[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)] pub enum SqlType { @@ -388,7 +390,7 @@ where let (inp, _) = first.parse(inp)?; let (inp, o2) = second.parse(inp)?; third.parse(inp).map(|(i, _)| (i, o2)) - }, + } } } } @@ -641,7 +643,8 @@ pub fn function_argument_parser(i: &[u8]) -> IResult<&[u8], FunctionArgument> { // present. pub fn function_arguments(i: &[u8]) -> IResult<&[u8], (FunctionArgument, bool)> { let distinct_parser = opt(tuple((tag_no_case("distinct"), multispace1))); - let (remaining_input, (distinct, args)) = tuple((distinct_parser, function_argument_parser))(i)?; + let (remaining_input, (distinct, args)) = + tuple((distinct_parser, function_argument_parser))(i)?; Ok((remaining_input, (args, distinct.is_some()))) } @@ -695,12 +698,25 @@ pub fn column_function(i: &[u8]) -> IResult<&[u8], FunctionExpression> { FunctionExpression::GroupConcat(FunctionArgument::Column(col.clone()), sep) }, ), - map(tuple((sql_identifier, multispace0, tag("("), separated_list0(tag(","), delimited(multispace0, function_argument_parser, multispace0)), tag(")"))), |tuple| { - let (name, _, _, arguments, _) = tuple; - FunctionExpression::Generic( - str::from_utf8(name).unwrap().to_string(), - FunctionArguments::from(arguments)) - }) + map( + tuple(( + sql_identifier, + multispace0, + tag("("), + separated_list0( + tag(","), + delimited(multispace0, function_argument_parser, multispace0), + ), + tag(")"), + )), + |tuple| { + let (name, _, _, arguments, _) = tuple; + FunctionExpression::Generic( + str::from_utf8(name).unwrap().to_string(), + FunctionArguments::from(arguments), + ) + }, + ), ))(i) } @@ -893,6 +909,15 @@ pub fn table_list(i: &[u8]) -> IResult<&[u8], Vec> { many0(terminated(schema_table_reference, opt(ws_sep_comma)))(i) } +pub fn from_clause(i: &[u8]) -> IResult<&[u8], Vec
> { + let (i, (_, t)) = tuple(( + delimited(multispace0, tag_no_case("from"), multispace0), + many1(terminated(schema_table_reference, opt(ws_sep_comma))), + ))(i)?; + + Ok((i, t)) +} + // Integer literal value pub fn integer_literal(i: &[u8]) -> IResult<&[u8], Literal> { map(pair(opt(tag("-")), digit1), |tup| { @@ -1018,25 +1043,44 @@ pub fn value_list(i: &[u8]) -> IResult<&[u8], Vec> { many0(delimited(multispace0, literal, opt(ws_sep_comma)))(i) } +pub fn relational_objects_clauses(i: &[u8]) -> IResult<&[u8], (Vec
, Vec)> { + match from_clause(i) { + Ok((i, f)) => { + let (i, j) = many0(join_clause)(i).unwrap_or((i, vec![])); + + Ok((i, (f, j))) + } + Err(e) => { + if join_clause(i).is_ok() { + //TODO: needs a more helpful error once error handling is improved upon + Err(e) + } else { + Ok((i, (vec![], vec![]))) + } + } + } +} + // Parse a reference to a named schema.table, with an optional alias pub fn schema_table_reference(i: &[u8]) -> IResult<&[u8], Table> { map( - tuple(( - opt(pair(sql_identifier, tag("."))), - sql_identifier, - opt(as_alias) - )), - |tup| Table { - name: String::from(str::from_utf8(tup.1).unwrap()), - alias: match tup.2 { - Some(a) => Some(String::from(a)), - None => None, - }, - schema: match tup.0 { - Some((schema, _)) => Some(String::from(str::from_utf8(schema).unwrap())), - None => None, + tuple(( + opt(pair(sql_identifier, tag("."))), + sql_identifier, + opt(as_alias), + )), + |tup| Table { + name: String::from(str::from_utf8(tup.1).unwrap()), + alias: match tup.2 { + Some(a) => Some(String::from(a)), + None => None, + }, + schema: match tup.0 { + Some((schema, _)) => Some(String::from(str::from_utf8(schema).unwrap())), + None => None, + }, }, - })(i) + )(i) } // Parse a reference to a named table, with an optional alias @@ -1047,7 +1091,7 @@ pub fn table_reference(i: &[u8]) -> IResult<&[u8], Table> { Some(a) => Some(String::from(a)), None => None, }, - schema: None, + schema: None, })(i) } @@ -1137,25 +1181,31 @@ mod tests { name: String::from("max(addr_id)"), alias: None, table: None, - function: Some(Box::new(FunctionExpression::Max( - FunctionArgument::Column(Column::from("addr_id")), - ))), + function: Some(Box::new(FunctionExpression::Max(FunctionArgument::Column( + Column::from("addr_id"), + )))), }; assert_eq!(res.unwrap().1, expected); } #[test] fn simple_generic_function() { - let qlist = ["coalesce(a,b,c)".as_bytes(), "coalesce (a,b,c)".as_bytes(), "coalesce(a ,b,c)".as_bytes(), "coalesce(a, b,c)".as_bytes()]; + let qlist = [ + "coalesce(a,b,c)".as_bytes(), + "coalesce (a,b,c)".as_bytes(), + "coalesce(a ,b,c)".as_bytes(), + "coalesce(a, b,c)".as_bytes(), + ]; for q in qlist.iter() { let res = column_function(q); - let expected = FunctionExpression::Generic("coalesce".to_string(), - FunctionArguments::from( - vec!( - FunctionArgument::Column(Column::from("a")), - FunctionArgument::Column(Column::from("b")), - FunctionArgument::Column(Column::from("c")) - ))); + let expected = FunctionExpression::Generic( + "coalesce".to_string(), + FunctionArguments::from(vec![ + FunctionArgument::Column(Column::from("a")), + FunctionArgument::Column(Column::from("b")), + FunctionArgument::Column(Column::from("c")), + ]), + ); assert_eq!(res, Ok((&b""[..], expected))); } } diff --git a/src/compound_select.rs b/src/compound_select.rs index d3ece89..53d3d08 100644 --- a/src/compound_select.rs +++ b/src/compound_select.rs @@ -185,12 +185,18 @@ mod tests { assert!(&res.is_err()); assert_eq!( res.unwrap_err(), - nom::Err::Error(nom::error::Error::new(");".as_bytes(), nom::error::ErrorKind::Tag)) + nom::Err::Error(nom::error::Error::new( + ");".as_bytes(), + nom::error::ErrorKind::Tag + )) ); assert!(&res2.is_err()); assert_eq!( res2.unwrap_err(), - nom::Err::Error(nom::error::Error::new(";".as_bytes(), nom::error::ErrorKind::Tag)) + nom::Err::Error(nom::error::Error::new( + ";".as_bytes(), + nom::error::ErrorKind::Tag + )) ); assert!(&res3.is_err()); assert_eq!( diff --git a/src/condition.rs b/src/condition.rs index a210b7a..ce9e1e8 100644 --- a/src/condition.rs +++ b/src/condition.rs @@ -291,10 +291,7 @@ fn predicate(i: &[u8]) -> IResult<&[u8], ConditionExpression> { }, ); - alt(( - simple_expr, - nested_exists, - ))(i) + alt((simple_expr, nested_exists))(i) } fn simple_expr(i: &[u8]) -> IResult<&[u8], ConditionExpression> { diff --git a/src/create.rs b/src/create.rs index a1b5afc..45b2cb3 100644 --- a/src/create.rs +++ b/src/create.rs @@ -5,8 +5,8 @@ use std::str::FromStr; use column::{Column, ColumnConstraint, ColumnSpecification}; use common::{ - column_identifier_no_alias, parse_comment, sql_identifier, statement_terminator, - schema_table_reference, type_identifier, ws_sep_comma, Literal, Real, SqlType, TableKey, + column_identifier_no_alias, parse_comment, schema_table_reference, sql_identifier, + statement_terminator, type_identifier, ws_sep_comma, Literal, Real, SqlType, TableKey, }; use compound_select::{compound_selection, CompoundSelectStatement}; use create_table_options::table_options; @@ -534,7 +534,7 @@ mod tests { assert_eq!( res.unwrap().1, CreateTableStatement { - table: Table::from(("db1","t")), + table: Table::from(("db1", "t")), fields: vec![ColumnSpecification::new( Column::from("t.x"), SqlType::Int(32) diff --git a/src/delete.rs b/src/delete.rs index 64ee1dc..40f3cd1 100644 --- a/src/delete.rs +++ b/src/delete.rs @@ -1,7 +1,7 @@ use nom::character::complete::multispace1; use std::{fmt, str}; -use common::{statement_terminator, schema_table_reference}; +use common::{schema_table_reference, statement_terminator}; use condition::ConditionExpression; use keywords::escape_if_keyword; use nom::bytes::complete::tag_no_case; @@ -77,7 +77,7 @@ mod tests { assert_eq!( res.unwrap().1, DeleteStatement { - table: Table::from(("db1","users")), + table: Table::from(("db1", "users")), ..Default::default() } ); diff --git a/src/insert.rs b/src/insert.rs index 9f55107..5c81427 100644 --- a/src/insert.rs +++ b/src/insert.rs @@ -4,7 +4,7 @@ use std::str; use column::Column; use common::{ - assignment_expr_list, field_list, statement_terminator, schema_table_reference, value_list, + assignment_expr_list, field_list, schema_table_reference, statement_terminator, value_list, ws_sep_comma, FieldValueExpression, Literal, }; use keywords::escape_if_keyword; @@ -145,7 +145,7 @@ mod tests { assert_eq!( res.unwrap().1, InsertStatement { - table: Table::from(("db1","users")), + table: Table::from(("db1", "users")), fields: None, data: vec![vec![42.into(), "test".into()]], ..Default::default() diff --git a/src/join.rs b/src/join.rs index b91f5dc..2f1ca4c 100644 --- a/src/join.rs +++ b/src/join.rs @@ -128,12 +128,12 @@ mod tests { let join_cond = ConditionExpression::ComparisonOp(ct); let expected_stmt = SelectStatement { tables: vec![Table::from("tags")], - fields: vec![FieldDefinitionExpression::AllInTable("tags".into())], join: vec![JoinClause { operator: JoinOperator::InnerJoin, right: JoinRightSide::Table(Table::from("taggings")), constraint: JoinConstraint::On(join_cond), }], + fields: vec![FieldDefinitionExpression::AllInTable("tags".into())], ..Default::default() }; diff --git a/src/select.rs b/src/select.rs index 8735e95..08fd893 100644 --- a/src/select.rs +++ b/src/select.rs @@ -3,17 +3,16 @@ use std::fmt; use std::str; use column::Column; -use common::FieldDefinitionExpression; use common::{ - as_alias, field_definition_expr, field_list, statement_terminator, table_list, table_reference, - unsigned_number, + as_alias, field_definition_expr, field_list, relational_objects_clauses, statement_terminator, + table_reference, unsigned_number, }; +use common::{table_list, FieldDefinitionExpression}; use condition::{condition_expr, ConditionExpression}; use join::{join_operator, JoinConstraint, JoinOperator, JoinRightSide}; use nom::branch::alt; use nom::bytes::complete::{tag, tag_no_case}; use nom::combinator::{map, opt}; -use nom::multi::many0; use nom::sequence::{delimited, preceded, terminated, tuple}; use nom::IResult; use order::{order_clause, OrderClause}; @@ -78,9 +77,9 @@ impl fmt::Display for LimitClause { #[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize, Deserialize)] pub struct SelectStatement { - pub tables: Vec
, pub distinct: bool, pub fields: Vec, + pub tables: Vec
, pub join: Vec, pub where_clause: Option, pub group_by: Option, @@ -116,9 +115,11 @@ impl fmt::Display for SelectStatement { .join(", ") )?; } + for jc in &self.join { write!(f, " {}", jc)?; } + if let Some(ref where_clause) = self.where_clause { write!(f, " WHERE ")?; write!(f, "{}", where_clause)?; @@ -217,7 +218,7 @@ fn join_constraint(i: &[u8]) -> IResult<&[u8], JoinConstraint> { } // Parse JOIN clause -fn join_clause(i: &[u8]) -> IResult<&[u8], JoinClause> { +pub fn join_clause(i: &[u8]) -> IResult<&[u8], JoinClause> { let (remaining_input, (_, _natural, operator, _, right, _, constraint)) = tuple(( multispace0, opt(terminated(tag_no_case("natural"), multispace1)), @@ -276,16 +277,14 @@ pub fn selection(i: &[u8]) -> IResult<&[u8], SelectStatement> { pub fn nested_selection(i: &[u8]) -> IResult<&[u8], SelectStatement> { let ( remaining_input, - (_, _, distinct, _, fields, _, tables, join, where_clause, group_by, order, limit), + (_, _, distinct, _, fields, (tables, join), where_clause, group_by, order, limit), ) = tuple(( tag_no_case("select"), multispace1, opt(tag_no_case("distinct")), multispace0, field_definition_expr, - delimited(multispace0, tag_no_case("from"), multispace0), - table_list, - many0(join_clause), + relational_objects_clauses, opt(where_clause), opt(group_by_clause), opt(order_clause), @@ -295,9 +294,9 @@ pub fn nested_selection(i: &[u8]) -> IResult<&[u8], SelectStatement> { remaining_input, SelectStatement { tables, + join, distinct: distinct.is_some(), fields, - join, where_clause, group_by, order, @@ -319,6 +318,7 @@ mod tests { use condition::ConditionTree; use order::OrderType; use table::Table; + use OrderType::OrderAscending; fn columns(cols: &[&str]) -> Vec { cols.iter() @@ -533,8 +533,8 @@ mod tests { tables: vec![Table { name: String::from("PaperTag"), alias: Some(String::from("t")), - schema: None, - },], + schema: None, + }], fields: vec![FieldDefinitionExpression::All], ..Default::default() } @@ -554,7 +554,7 @@ mod tests { tables: vec![Table { name: String::from("PaperTag"), alias: Some(String::from("t")), - schema: Some(String::from("db1")), + schema: Some(String::from("db1")), },], fields: vec![FieldDefinitionExpression::All], ..Default::default() @@ -1067,12 +1067,12 @@ mod tests { let res = selection(qstring.as_bytes()); let expected_stmt = SelectStatement { tables: vec![Table::from("PaperConflict")], - fields: columns(&["paperId"]), join: vec![JoinClause { operator: JoinOperator::Join, right: JoinRightSide::Table(Table::from("PCMember")), constraint: JoinConstraint::Using(vec![Column::from("contactId")]), }], + fields: columns(&["paperId"]), ..Default::default() }; assert_eq!(res.unwrap().1, expected_stmt); @@ -1095,12 +1095,12 @@ mod tests { let join_cond = ConditionExpression::ComparisonOp(ct); let expected = SelectStatement { tables: vec![Table::from("PCMember")], - fields: columns(&["PCMember.contactId"]), join: vec![JoinClause { operator: JoinOperator::Join, right: JoinRightSide::Table(Table::from("PaperReview")), constraint: JoinConstraint::On(join_cond), }], + fields: columns(&["PCMember.contactId"]), order: Some(OrderClause { columns: vec![("contactId".into(), OrderType::OrderAscending)], }), @@ -1153,11 +1153,6 @@ mod tests { res.unwrap().1, SelectStatement { tables: vec![Table::from("ContactInfo")], - fields: columns(&[ - "PCMember.contactId", - "ChairAssistant.contactId", - "Chair.contactId" - ]), join: vec![ mkjoin("PaperReview", "contactId"), mkjoin("PaperConflict", "contactId"), @@ -1165,12 +1160,29 @@ mod tests { mkjoin("ChairAssistant", "contactId"), mkjoin("Chair", "contactId"), ], + fields: columns(&[ + "PCMember.contactId", + "ChairAssistant.contactId", + "Chair.contactId" + ]), where_clause: expected_where_cond, ..Default::default() } ); } + #[test] + fn out_of_order_joins_fail() { + let qstring0 = "select paperId join PCMember using (contactId);"; + let qstring1 = "select paperId join PCMember from PaperConflict using (contactId);"; + + let res0 = selection(qstring0.as_bytes()); + let res1 = selection(qstring1.as_bytes()); + + assert!(res0.is_err()); + assert!(res1.is_err()); + } + #[test] fn nested_select() { let qstr = "SELECT ol_i_id FROM orders, order_line \ @@ -1297,7 +1309,6 @@ mod tests { let outer_select = SelectStatement { tables: vec![Table::from("orders")], - fields: columns(&["o_id", "ol_i_id"]), join: vec![JoinClause { operator: JoinOperator::Join, right: JoinRightSide::NestedSelect(Box::new(inner_select), Some("ids".into())), @@ -1307,6 +1318,7 @@ mod tests { right: Box::new(Base(Field(Column::from("ids.ol_i_id")))), })), }], + fields: columns(&["o_id", "ol_i_id"]), ..Default::default() }; @@ -1390,10 +1402,6 @@ mod tests { let expected = SelectStatement { tables: vec![Table::from("auth_permission")], - fields: vec![ - FieldDefinitionExpression::Col(Column::from("auth_permission.content_type_id")), - FieldDefinitionExpression::Col(Column::from("auth_permission.codename")), - ], join: vec![JoinClause { operator: JoinOperator::Join, right: JoinRightSide::Table(Table::from("django_content_type")), @@ -1403,10 +1411,82 @@ mod tests { right: Box::new(Base(Field(Column::from("django_content_type.id")))), })), }], + fields: vec![ + FieldDefinitionExpression::Col(Column::from("auth_permission.content_type_id")), + FieldDefinitionExpression::Col(Column::from("auth_permission.codename")), + ], where_clause: expected_where_clause, ..Default::default() }; assert_eq!(res.unwrap().1, expected); } + + #[test] + fn literal_select() { + use arithmetic::{ArithmeticBase, ArithmeticExpression, ArithmeticOperator}; + + let qstr0 = "SELECT 1 + 1"; + let qstr1 = "SELECT 1 + 1 AS adder GROUP BY adder ORDER BY adder"; + + let res0 = selection(qstr0.as_bytes()); + let res1 = selection(qstr1.as_bytes()); + + let expected0 = SelectStatement { + distinct: false, + fields: vec![FieldDefinitionExpression::Value( + FieldValueExpression::Arithmetic(ArithmeticExpression::new( + ArithmeticOperator::Add, + ArithmeticBase::Scalar(1.into()), + ArithmeticBase::Scalar(1.into()), + None, + )), + )], + tables: vec![], + where_clause: None, + group_by: None, + order: None, + limit: None, + ..Default::default() + }; + + let expected1 = SelectStatement { + distinct: false, + fields: vec![FieldDefinitionExpression::Value( + FieldValueExpression::Arithmetic(ArithmeticExpression::new( + ArithmeticOperator::Add, + ArithmeticBase::Scalar(1.into()), + ArithmeticBase::Scalar(1.into()), + Some("adder".to_string()), + )), + )], + tables: vec![], + where_clause: None, + group_by: Some(GroupByClause { + columns: vec![Column { + name: "adder".to_string(), + alias: None, + table: None, + function: None, + }], + having: None, + }), + order: Some(OrderClause { + columns: vec![( + Column { + name: "adder".to_string(), + alias: None, + table: None, + function: None, + }, + OrderAscending, + )], + }), + limit: None, + ..Default::default() + }; + + assert_eq!(res0.unwrap().1, expected0); + assert_eq!(res1.unwrap().1, expected1); + } }