diff --git a/src/arithmetic.rs b/src/arithmetic.rs index 470b3eb..90a3943 100644 --- a/src/arithmetic.rs +++ b/src/arithmetic.rs @@ -459,10 +459,13 @@ mod tests { } #[test] - fn arithmetic_scalar(){ + fn arithmetic_scalar() { let qs = "56"; let res = arithmetic(qs.as_bytes()); assert!(res.is_err()); - assert_eq!(nom::Err::Error(nom::error::Error::new(qs.as_bytes(), ErrorKind::Tag)), res.err().unwrap()); + assert_eq!( + nom::Err::Error(nom::error::Error::new(qs.as_bytes(), ErrorKind::Tag)), + res.err().unwrap() + ); } } diff --git a/src/column.rs b/src/column.rs index 5c873e7..9537f1d 100644 --- a/src/column.rs +++ b/src/column.rs @@ -159,6 +159,21 @@ impl PartialOrd for Column { } } +#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)] +pub enum SortingColumnIdentifier { + FunctionArguments(FunctionArgument), + Position(usize), +} + +impl fmt::Display for SortingColumnIdentifier { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + SortingColumnIdentifier::FunctionArguments(c) => write!(f, "{}", c), + SortingColumnIdentifier::Position(p) => write!(f, "{}", p), + } + } +} + #[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)] pub enum ColumnConstraint { NotNull, diff --git a/src/common.rs b/src/common.rs index 1d6cd4d..e163288 100644 --- a/src/common.rs +++ b/src/common.rs @@ -9,7 +9,9 @@ use std::str::FromStr; use arithmetic::{arithmetic_expression, ArithmeticExpression}; use case::case_when_column; -use column::{Column, FunctionArgument, FunctionArguments, FunctionExpression}; +use column::{ + Column, FunctionArgument, FunctionArguments, FunctionExpression, SortingColumnIdentifier, +}; use keywords::{escape_if_keyword, sql_keyword}; use nom::bytes::complete::{is_not, tag, tag_no_case, take, take_until, take_while1}; use nom::combinator::opt; @@ -388,7 +390,7 @@ where let (inp, _) = first.parse(inp)?; let (inp, o2) = second.parse(inp)?; third.parse(inp).map(|(i, _)| (i, o2)) - }, + } } } } @@ -641,7 +643,8 @@ pub fn function_argument_parser(i: &[u8]) -> IResult<&[u8], FunctionArgument> { // present. pub fn function_arguments(i: &[u8]) -> IResult<&[u8], (FunctionArgument, bool)> { let distinct_parser = opt(tuple((tag_no_case("distinct"), multispace1))); - let (remaining_input, (distinct, args)) = tuple((distinct_parser, function_argument_parser))(i)?; + let (remaining_input, (distinct, args)) = + tuple((distinct_parser, function_argument_parser))(i)?; Ok((remaining_input, (args, distinct.is_some()))) } @@ -695,12 +698,25 @@ pub fn column_function(i: &[u8]) -> IResult<&[u8], FunctionExpression> { FunctionExpression::GroupConcat(FunctionArgument::Column(col.clone()), sep) }, ), - map(tuple((sql_identifier, multispace0, tag("("), separated_list0(tag(","), delimited(multispace0, function_argument_parser, multispace0)), tag(")"))), |tuple| { - let (name, _, _, arguments, _) = tuple; - FunctionExpression::Generic( - str::from_utf8(name).unwrap().to_string(), - FunctionArguments::from(arguments)) - }) + map( + tuple(( + sql_identifier, + multispace0, + tag("("), + separated_list0( + tag(","), + delimited(multispace0, function_argument_parser, multispace0), + ), + tag(")"), + )), + |tuple| { + let (name, _, _, arguments, _) = tuple; + FunctionExpression::Generic( + str::from_utf8(name).unwrap().to_string(), + FunctionArguments::from(arguments), + ) + }, + ), ))(i) } @@ -740,28 +756,48 @@ pub fn column_identifier(i: &[u8]) -> IResult<&[u8], Column> { table: None, function: Some(Box::new(tup.0)), }); - let col_w_table = map( - tuple(( - opt(terminated(sql_identifier, tag("."))), - sql_identifier, - opt(as_alias), - )), - |tup| Column { - name: str::from_utf8(tup.1).unwrap().to_string(), - alias: match tup.2 { + let col_w_table = map(tuple((table_column_identifier, opt(as_alias))), |tup| { + Column { + name: tup.0 .1, + alias: match tup.1 { None => None, Some(a) => Some(String::from(a)), }, - table: match tup.0 { - None => None, - Some(t) => Some(str::from_utf8(t).unwrap().to_string()), - }, + table: tup.0 .0, function: None, - }, - ); + } + }); alt((col_func_no_table, col_w_table))(i) } +// Parses a SQL column name preceded in the table.column format +pub fn table_column_identifier(i: &[u8]) -> IResult<&[u8], (Option, String)> { + tuple(( + map(opt(terminated(sql_identifier, tag("."))), |si| { + si.and_then(|si| Some(str::from_utf8(si).unwrap().to_string())) + }), + map(sql_identifier, |si| str::from_utf8(si).unwrap().to_string()), + ))(i) +} + +pub fn sorting_column_identifier(i: &[u8]) -> IResult<&[u8], SortingColumnIdentifier> { + alt(( + map(digit1, |p| { + SortingColumnIdentifier::Position(usize::from_str(str::from_utf8(p).unwrap()).unwrap()) + }), + map(function_argument_parser, |c| { + SortingColumnIdentifier::FunctionArguments(c) + }), + ))(i) +} + +pub fn group_by_column_identifier(i: &[u8]) -> IResult<&[u8], SortingColumnIdentifier> { + map( + tuple((sorting_column_identifier, opt(ws_sep_comma))), + |(c, _)| c, + )(i) +} + // Parses a SQL identifier (alphanumeric1 and "_"). pub fn sql_identifier(i: &[u8]) -> IResult<&[u8], &[u8]> { alt(( @@ -1021,22 +1057,23 @@ pub fn value_list(i: &[u8]) -> IResult<&[u8], Vec> { // Parse a reference to a named schema.table, with an optional alias pub fn schema_table_reference(i: &[u8]) -> IResult<&[u8], Table> { map( - tuple(( - opt(pair(sql_identifier, tag("."))), - sql_identifier, - opt(as_alias) - )), - |tup| Table { - name: String::from(str::from_utf8(tup.1).unwrap()), - alias: match tup.2 { - Some(a) => Some(String::from(a)), - None => None, - }, - schema: match tup.0 { - Some((schema, _)) => Some(String::from(str::from_utf8(schema).unwrap())), - None => None, + tuple(( + opt(pair(sql_identifier, tag("."))), + sql_identifier, + opt(as_alias), + )), + |tup| Table { + name: String::from(str::from_utf8(tup.1).unwrap()), + alias: match tup.2 { + Some(a) => Some(String::from(a)), + None => None, + }, + schema: match tup.0 { + Some((schema, _)) => Some(String::from(str::from_utf8(schema).unwrap())), + None => None, + }, }, - })(i) + )(i) } // Parse a reference to a named table, with an optional alias @@ -1047,7 +1084,7 @@ pub fn table_reference(i: &[u8]) -> IResult<&[u8], Table> { Some(a) => Some(String::from(a)), None => None, }, - schema: None, + schema: None, })(i) } @@ -1137,25 +1174,31 @@ mod tests { name: String::from("max(addr_id)"), alias: None, table: None, - function: Some(Box::new(FunctionExpression::Max( - FunctionArgument::Column(Column::from("addr_id")), - ))), + function: Some(Box::new(FunctionExpression::Max(FunctionArgument::Column( + Column::from("addr_id"), + )))), }; assert_eq!(res.unwrap().1, expected); } #[test] fn simple_generic_function() { - let qlist = ["coalesce(a,b,c)".as_bytes(), "coalesce (a,b,c)".as_bytes(), "coalesce(a ,b,c)".as_bytes(), "coalesce(a, b,c)".as_bytes()]; + let qlist = [ + "coalesce(a,b,c)".as_bytes(), + "coalesce (a,b,c)".as_bytes(), + "coalesce(a ,b,c)".as_bytes(), + "coalesce(a, b,c)".as_bytes(), + ]; for q in qlist.iter() { let res = column_function(q); - let expected = FunctionExpression::Generic("coalesce".to_string(), - FunctionArguments::from( - vec!( - FunctionArgument::Column(Column::from("a")), - FunctionArgument::Column(Column::from("b")), - FunctionArgument::Column(Column::from("c")) - ))); + let expected = FunctionExpression::Generic( + "coalesce".to_string(), + FunctionArguments::from(vec![ + FunctionArgument::Column(Column::from("a")), + FunctionArgument::Column(Column::from("b")), + FunctionArgument::Column(Column::from("c")), + ]), + ); assert_eq!(res, Ok((&b""[..], expected))); } } diff --git a/src/compound_select.rs b/src/compound_select.rs index d3ece89..53d3d08 100644 --- a/src/compound_select.rs +++ b/src/compound_select.rs @@ -185,12 +185,18 @@ mod tests { assert!(&res.is_err()); assert_eq!( res.unwrap_err(), - nom::Err::Error(nom::error::Error::new(");".as_bytes(), nom::error::ErrorKind::Tag)) + nom::Err::Error(nom::error::Error::new( + ");".as_bytes(), + nom::error::ErrorKind::Tag + )) ); assert!(&res2.is_err()); assert_eq!( res2.unwrap_err(), - nom::Err::Error(nom::error::Error::new(";".as_bytes(), nom::error::ErrorKind::Tag)) + nom::Err::Error(nom::error::Error::new( + ";".as_bytes(), + nom::error::ErrorKind::Tag + )) ); assert!(&res3.is_err()); assert_eq!( diff --git a/src/condition.rs b/src/condition.rs index a210b7a..ce9e1e8 100644 --- a/src/condition.rs +++ b/src/condition.rs @@ -291,10 +291,7 @@ fn predicate(i: &[u8]) -> IResult<&[u8], ConditionExpression> { }, ); - alt(( - simple_expr, - nested_exists, - ))(i) + alt((simple_expr, nested_exists))(i) } fn simple_expr(i: &[u8]) -> IResult<&[u8], ConditionExpression> { diff --git a/src/create.rs b/src/create.rs index a1b5afc..45b2cb3 100644 --- a/src/create.rs +++ b/src/create.rs @@ -5,8 +5,8 @@ use std::str::FromStr; use column::{Column, ColumnConstraint, ColumnSpecification}; use common::{ - column_identifier_no_alias, parse_comment, sql_identifier, statement_terminator, - schema_table_reference, type_identifier, ws_sep_comma, Literal, Real, SqlType, TableKey, + column_identifier_no_alias, parse_comment, schema_table_reference, sql_identifier, + statement_terminator, type_identifier, ws_sep_comma, Literal, Real, SqlType, TableKey, }; use compound_select::{compound_selection, CompoundSelectStatement}; use create_table_options::table_options; @@ -534,7 +534,7 @@ mod tests { assert_eq!( res.unwrap().1, CreateTableStatement { - table: Table::from(("db1","t")), + table: Table::from(("db1", "t")), fields: vec![ColumnSpecification::new( Column::from("t.x"), SqlType::Int(32) diff --git a/src/delete.rs b/src/delete.rs index 64ee1dc..40f3cd1 100644 --- a/src/delete.rs +++ b/src/delete.rs @@ -1,7 +1,7 @@ use nom::character::complete::multispace1; use std::{fmt, str}; -use common::{statement_terminator, schema_table_reference}; +use common::{schema_table_reference, statement_terminator}; use condition::ConditionExpression; use keywords::escape_if_keyword; use nom::bytes::complete::tag_no_case; @@ -77,7 +77,7 @@ mod tests { assert_eq!( res.unwrap().1, DeleteStatement { - table: Table::from(("db1","users")), + table: Table::from(("db1", "users")), ..Default::default() } ); diff --git a/src/insert.rs b/src/insert.rs index 9f55107..5c81427 100644 --- a/src/insert.rs +++ b/src/insert.rs @@ -4,7 +4,7 @@ use std::str; use column::Column; use common::{ - assignment_expr_list, field_list, statement_terminator, schema_table_reference, value_list, + assignment_expr_list, field_list, schema_table_reference, statement_terminator, value_list, ws_sep_comma, FieldValueExpression, Literal, }; use keywords::escape_if_keyword; @@ -145,7 +145,7 @@ mod tests { assert_eq!( res.unwrap().1, InsertStatement { - table: Table::from(("db1","users")), + table: Table::from(("db1", "users")), fields: None, data: vec![vec![42.into(), "test".into()]], ..Default::default() diff --git a/src/order.rs b/src/order.rs index 3db9ddc..db72bcc 100644 --- a/src/order.rs +++ b/src/order.rs @@ -2,15 +2,17 @@ use nom::character::complete::{multispace0, multispace1}; use std::fmt; use std::str; -use column::Column; -use common::{column_identifier_no_alias, ws_sep_comma}; +use column::SortingColumnIdentifier; +use common::{sorting_column_identifier, ws_sep_comma}; +use condition::condition_expr; use keywords::escape_if_keyword; use nom::branch::alt; use nom::bytes::complete::tag_no_case; use nom::combinator::{map, opt}; -use nom::multi::many0; +use nom::multi::many1; use nom::sequence::{preceded, tuple}; use nom::IResult; +use ConditionExpression; #[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)] pub enum OrderType { @@ -27,23 +29,38 @@ impl fmt::Display for OrderType { } } +#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)] +pub enum OrderingExpression { + Columns(Vec<(SortingColumnIdentifier, OrderType)>), + Condition(ConditionExpression), +} + #[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)] pub struct OrderClause { - pub columns: Vec<(Column, OrderType)>, // TODO(malte): can this be an arbitrary expr? + pub expression: OrderingExpression, } impl fmt::Display for OrderClause { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "ORDER BY ")?; - write!( - f, - "{}", - self.columns - .iter() - .map(|&(ref c, ref o)| format!("{} {}", escape_if_keyword(&c.name), o)) - .collect::>() - .join(", ") - ) + + match &self.expression { + OrderingExpression::Columns(c) => { + write!( + f, + "{}", + c.iter() + .map(|&(ref c, ref o)| { + format!("{} {}", escape_if_keyword(&c.to_string()), o) + }) + .collect::>() + .join(", ") + ) + } + OrderingExpression::Condition(c) => { + write!(f, "{}", c) + } + } } } @@ -54,9 +71,19 @@ pub fn order_type(i: &[u8]) -> IResult<&[u8], OrderType> { ))(i) } -fn order_expr(i: &[u8]) -> IResult<&[u8], (Column, OrderType)> { +fn order_expr(i: &[u8]) -> IResult<&[u8], OrderingExpression> { + alt(( + map(many1(order_sorting_column), |c| { + OrderingExpression::Columns(c) + }), + map(condition_expr, |c| OrderingExpression::Condition(c)), + ))(i) +} + +fn order_sorting_column(i: &[u8]) -> IResult<&[u8], (SortingColumnIdentifier, OrderType)> { let (remaining_input, (field_name, ordering, _)) = tuple(( - column_identifier_no_alias, + //column_identifier_no_alias, + sorting_column_identifier, opt(preceded(multispace0, order_type)), opt(ws_sep_comma), ))(i)?; @@ -69,38 +96,61 @@ fn order_expr(i: &[u8]) -> IResult<&[u8], (Column, OrderType)> { // Parse ORDER BY clause pub fn order_clause(i: &[u8]) -> IResult<&[u8], OrderClause> { - let (remaining_input, (_, _, _, columns)) = tuple(( + let (remaining_input, (_, _, _, oe)) = tuple(( multispace0, tag_no_case("order by"), multispace1, - many0(order_expr), + order_expr, ))(i)?; - Ok((remaining_input, OrderClause { columns })) + Ok((remaining_input, OrderClause { expression: oe })) } #[cfg(test)] mod tests { use super::*; + use common::Literal; + use condition::ConditionBase::*; + use condition::ConditionTree; use select::selection; + use Column; + use ConditionExpression::{Base, ComparisonOp}; + use {CaseWhenExpression, Operator}; + use {ColumnOrLiteral, FunctionArgument}; #[test] - fn order_clause() { + fn order_by_clause() { let qstring1 = "select * from users order by name desc\n"; let qstring2 = "select * from users order by name asc, age desc\n"; let qstring3 = "select * from users order by name\n"; let expected_ord1 = OrderClause { - columns: vec![("name".into(), OrderType::OrderDescending)], + expression: OrderingExpression::Columns(vec![( + SortingColumnIdentifier::FunctionArguments(FunctionArgument::Column("name".into())), + OrderType::OrderDescending, + )]), }; let expected_ord2 = OrderClause { - columns: vec![ - ("name".into(), OrderType::OrderAscending), - ("age".into(), OrderType::OrderDescending), - ], + expression: OrderingExpression::Columns(vec![ + ( + SortingColumnIdentifier::FunctionArguments(FunctionArgument::Column( + "name".into(), + )), + OrderType::OrderAscending, + ), + ( + SortingColumnIdentifier::FunctionArguments(FunctionArgument::Column( + "age".into(), + )), + OrderType::OrderDescending, + ), + ]), }; let expected_ord3 = OrderClause { - columns: vec![("name".into(), OrderType::OrderAscending)], + expression: OrderingExpression::Columns(vec![( + SortingColumnIdentifier::FunctionArguments(FunctionArgument::Column("name".into())), + OrderType::OrderAscending, + )]), }; let res1 = selection(qstring1.as_bytes()); @@ -110,4 +160,66 @@ mod tests { assert_eq!(res2.unwrap().1.order, Some(expected_ord2)); assert_eq!(res3.unwrap().1.order, Some(expected_ord3)); } + + #[test] + fn order_by_case() { + let qstring = "ORDER BY CASE WHEN vote_id > 10 THEN vote_id END DESC"; + + let res = order_clause(qstring.as_bytes()); + + let filter_cond = ComparisonOp(ConditionTree { + left: Box::new(Base(Field(Column::from("vote_id")))), + right: Box::new(Base(Literal(Literal::Integer(10.into())))), + operator: Operator::Greater, + }); + let expected = OrderClause { + expression: OrderingExpression::Columns(vec![( + SortingColumnIdentifier::FunctionArguments(FunctionArgument::Conditional( + CaseWhenExpression { + then_expr: ColumnOrLiteral::Column(Column::from("vote_id")), + else_expr: None, + condition: filter_cond, + }, + )), + OrderType::OrderDescending, + )]), + }; + + assert_eq!(res.unwrap().1, expected); + } + + #[test] + fn order_by_positionals() { + let qstring0 = "ORDER BY 1"; + let qstring1 = "ORDER BY 1, 5, 3"; + + let res0 = order_clause(qstring0.as_bytes()); + let res1 = order_clause(qstring1.as_bytes()); + + let expected0 = OrderClause { + expression: OrderingExpression::Columns(vec![( + SortingColumnIdentifier::Position(1), + OrderType::OrderAscending, + )]), + }; + let expected1 = OrderClause { + expression: OrderingExpression::Columns(vec![ + ( + SortingColumnIdentifier::Position(1), + OrderType::OrderAscending, + ), + ( + SortingColumnIdentifier::Position(5), + OrderType::OrderAscending, + ), + ( + SortingColumnIdentifier::Position(3), + OrderType::OrderAscending, + ), + ]), + }; + + assert_eq!(res0.unwrap().1, expected0); + assert_eq!(res1.unwrap().1, expected1); + } } diff --git a/src/select.rs b/src/select.rs index 8735e95..2cdce8a 100644 --- a/src/select.rs +++ b/src/select.rs @@ -2,18 +2,19 @@ use nom::character::complete::{multispace0, multispace1}; use std::fmt; use std::str; -use column::Column; -use common::FieldDefinitionExpression; +use column::SortingColumnIdentifier; use common::{ as_alias, field_definition_expr, field_list, statement_terminator, table_list, table_reference, unsigned_number, }; +use common::{group_by_column_identifier, FieldDefinitionExpression}; use condition::{condition_expr, ConditionExpression}; use join::{join_operator, JoinConstraint, JoinOperator, JoinRightSide}; +use keywords::escape_if_keyword; use nom::branch::alt; use nom::bytes::complete::{tag, tag_no_case}; use nom::combinator::{map, opt}; -use nom::multi::many0; +use nom::multi::{many0, many1}; use nom::sequence::{delimited, preceded, terminated, tuple}; use nom::IResult; use order::{order_clause, OrderClause}; @@ -21,22 +22,29 @@ use table::Table; #[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)] pub struct GroupByClause { - pub columns: Vec, + pub expression: GroupByExpression, pub having: Option, } impl fmt::Display for GroupByClause { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "GROUP BY ")?; - write!( - f, - "{}", - self.columns - .iter() - .map(|c| format!("{}", c)) - .collect::>() - .join(", ") - )?; + + match &self.expression { + GroupByExpression::Columns(c) => { + write!( + f, + "{}", + c.iter() + .map(|c| { format!("{}", escape_if_keyword(&c.to_string())) }) + .collect::>() + .join(", ") + ) + } + GroupByExpression::Condition(c) => { + write!(f, "{}", c) + } + }?; if let Some(ref having) = self.having { write!(f, " HAVING {}", having)?; } @@ -44,6 +52,12 @@ impl fmt::Display for GroupByClause { } } +#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)] +pub enum GroupByExpression { + Columns(Vec), + Condition(ConditionExpression), +} + #[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)] pub struct JoinClause { pub operator: JoinOperator, @@ -149,15 +163,17 @@ fn having_clause(i: &[u8]) -> IResult<&[u8], ConditionExpression> { // Parse GROUP BY clause pub fn group_by_clause(i: &[u8]) -> IResult<&[u8], GroupByClause> { - let (remaining_input, (_, _, _, columns, having)) = tuple(( + let (remaining_input, (_, _, _, expression, having)) = tuple(( multispace0, tag_no_case("group by"), multispace1, - field_list, + map(many1(group_by_column_identifier), |c| { + GroupByExpression::Columns(c) + }), opt(having_clause), ))(i)?; - Ok((remaining_input, GroupByClause { columns, having })) + Ok((remaining_input, GroupByClause { expression, having })) } fn offset(i: &[u8]) -> IResult<&[u8], u64> { @@ -310,14 +326,16 @@ pub fn nested_selection(i: &[u8]) -> IResult<&[u8], SelectStatement> { mod tests { use super::*; use case::{CaseWhenExpression, ColumnOrLiteral}; - use column::{Column, FunctionArgument, FunctionArguments, FunctionExpression}; + use column::{ + Column, FunctionArgument, FunctionArguments, FunctionExpression, SortingColumnIdentifier, + }; use common::{ FieldDefinitionExpression, FieldValueExpression, ItemPlaceholder, Literal, Operator, }; use condition::ConditionBase::*; use condition::ConditionExpression::*; use condition::ConditionTree; - use order::OrderType; + use order::{OrderType, OrderingExpression}; use table::Table; fn columns(cols: &[&str]) -> Vec { @@ -533,7 +551,7 @@ mod tests { tables: vec![Table { name: String::from("PaperTag"), alias: Some(String::from("t")), - schema: None, + schema: None, },], fields: vec![FieldDefinitionExpression::All], ..Default::default() @@ -554,7 +572,7 @@ mod tests { tables: vec![Table { name: String::from("PaperTag"), alias: Some(String::from("t")), - schema: Some(String::from("db1")), + schema: Some(String::from("db1")), },], fields: vec![FieldDefinitionExpression::All], ..Default::default() @@ -782,7 +800,11 @@ mod tests { function: Some(Box::new(agg_expr)), })], group_by: Some(GroupByClause { - columns: vec![Column::from("aid")], + expression: GroupByExpression::Columns(vec![ + SortingColumnIdentifier::FunctionArguments(FunctionArgument::Column( + "aid".into(), + )), + ]), having: None, }), ..Default::default() @@ -806,7 +828,11 @@ mod tests { function: Some(Box::new(agg_expr)), })], group_by: Some(GroupByClause { - columns: vec![Column::from("aid")], + expression: GroupByExpression::Columns(vec![ + SortingColumnIdentifier::FunctionArguments(FunctionArgument::Column( + "aid".into(), + )), + ]), having: None, }), ..Default::default() @@ -842,7 +868,11 @@ mod tests { function: Some(Box::new(agg_expr)), })], group_by: Some(GroupByClause { - columns: vec![Column::from("aid")], + expression: GroupByExpression::Columns(vec![ + SortingColumnIdentifier::FunctionArguments(FunctionArgument::Column( + "aid".into(), + )), + ]), having: None, }), ..Default::default() @@ -878,7 +908,11 @@ mod tests { function: Some(Box::new(agg_expr)), })], group_by: Some(GroupByClause { - columns: vec![Column::from("aid")], + expression: GroupByExpression::Columns(vec![ + SortingColumnIdentifier::FunctionArguments(FunctionArgument::Column( + "aid".into(), + )), + ]), having: None, }), ..Default::default() @@ -915,7 +949,11 @@ mod tests { function: Some(Box::new(agg_expr)), })], group_by: Some(GroupByClause { - columns: vec![Column::from("aid")], + expression: GroupByExpression::Columns(vec![ + SortingColumnIdentifier::FunctionArguments(FunctionArgument::Column( + "aid".into(), + )), + ]), having: None, }), ..Default::default() @@ -962,7 +1000,11 @@ mod tests { function: Some(Box::new(agg_expr)), })], group_by: Some(GroupByClause { - columns: vec![Column::from("votes.comment_id")], + expression: GroupByExpression::Columns(vec![ + SortingColumnIdentifier::FunctionArguments(FunctionArgument::Column( + "votes.comment_id".into(), + )), + ]), having: None, }), ..Default::default() @@ -1049,7 +1091,17 @@ mod tests { fields: vec![FieldDefinitionExpression::All], where_clause: expected_where_cond, order: Some(OrderClause { - columns: vec![("item.i_title".into(), OrderType::OrderAscending)], + expression: OrderingExpression::Columns(vec![( + SortingColumnIdentifier::FunctionArguments(FunctionArgument::Column( + Column { + name: "i_title".to_string(), + alias: None, + table: Some("item".to_string()), + function: None + } + )), + OrderType::OrderAscending + )]), }), limit: Some(LimitClause { limit: 50, @@ -1102,7 +1154,12 @@ mod tests { constraint: JoinConstraint::On(join_cond), }], order: Some(OrderClause { - columns: vec![("contactId".into(), OrderType::OrderAscending)], + expression: OrderingExpression::Columns(vec![( + SortingColumnIdentifier::FunctionArguments(FunctionArgument::Column( + "contactId".into(), + )), + OrderType::OrderAscending, + )]), }), ..Default::default() }; @@ -1409,4 +1466,57 @@ mod tests { assert_eq!(res.unwrap().1, expected); } + + #[test] + fn group_by_case() { + let qstring = "GROUP BY CASE WHEN vote_id > 10 THEN vote_id END"; + + let res = group_by_clause(qstring.as_bytes()); + + let filter_cond = ComparisonOp(ConditionTree { + left: Box::new(Base(Field(Column::from("vote_id")))), + right: Box::new(Base(Literal(Literal::Integer(10.into())))), + operator: Operator::Greater, + }); + let expected = GroupByClause { + expression: GroupByExpression::Columns(vec![ + SortingColumnIdentifier::FunctionArguments(FunctionArgument::Conditional( + CaseWhenExpression { + then_expr: ColumnOrLiteral::Column(Column::from("vote_id")), + else_expr: None, + condition: filter_cond, + }, + )), + ]), + having: None, + }; + + assert_eq!(res.unwrap().1, expected); + } + + #[test] + fn group_by_positionals() { + let qstring0 = "GROUP BY 1"; + let qstring1 = "GROUP BY 1, 5, 3"; + + let res0 = group_by_clause(qstring0.as_bytes()); + let res1 = group_by_clause(qstring1.as_bytes()); + + let expected0 = GroupByClause { + expression: GroupByExpression::Columns(vec![SortingColumnIdentifier::Position(1)]), + having: None, + }; + + let expected1 = GroupByClause { + expression: GroupByExpression::Columns(vec![ + SortingColumnIdentifier::Position(1), + SortingColumnIdentifier::Position(5), + SortingColumnIdentifier::Position(3), + ]), + having: None, + }; + + assert_eq!(res0.unwrap().1, expected0); + assert_eq!(res1.unwrap().1, expected1); + } }