diff --git a/src/arithmetic.rs b/src/arithmetic.rs index 470b3eb..90a3943 100644 --- a/src/arithmetic.rs +++ b/src/arithmetic.rs @@ -459,10 +459,13 @@ mod tests { } #[test] - fn arithmetic_scalar(){ + fn arithmetic_scalar() { let qs = "56"; let res = arithmetic(qs.as_bytes()); assert!(res.is_err()); - assert_eq!(nom::Err::Error(nom::error::Error::new(qs.as_bytes(), ErrorKind::Tag)), res.err().unwrap()); + assert_eq!( + nom::Err::Error(nom::error::Error::new(qs.as_bytes(), ErrorKind::Tag)), + res.err().unwrap() + ); } } diff --git a/src/common.rs b/src/common.rs index 1d6cd4d..a40a045 100644 --- a/src/common.rs +++ b/src/common.rs @@ -1,7 +1,7 @@ use nom::branch::alt; use nom::character::complete::{alphanumeric1, digit1, line_ending, multispace0, multispace1}; use nom::character::is_alphanumeric; -use nom::combinator::{map, not, peek}; +use nom::combinator::{into, map, not, peek}; use nom::{IResult, InputLength, Parser}; use std::fmt::{self, Display}; use std::str; @@ -16,7 +16,9 @@ use nom::combinator::opt; use nom::error::{ErrorKind, ParseError}; use nom::multi::{fold_many0, many0, many1, separated_list0}; use nom::sequence::{delimited, pair, preceded, separated_pair, terminated, tuple}; -use table::Table; +use select::join_clause; +use table::{Table, TableObject, TablePartition, TablePartitionList}; +use JoinClause; #[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)] pub enum SqlType { @@ -388,7 +390,7 @@ where let (inp, _) = first.parse(inp)?; let (inp, o2) = second.parse(inp)?; third.parse(inp).map(|(i, _)| (i, o2)) - }, + } } } } @@ -641,7 +643,8 @@ pub fn function_argument_parser(i: &[u8]) -> IResult<&[u8], FunctionArgument> { // present. pub fn function_arguments(i: &[u8]) -> IResult<&[u8], (FunctionArgument, bool)> { let distinct_parser = opt(tuple((tag_no_case("distinct"), multispace1))); - let (remaining_input, (distinct, args)) = tuple((distinct_parser, function_argument_parser))(i)?; + let (remaining_input, (distinct, args)) = + tuple((distinct_parser, function_argument_parser))(i)?; Ok((remaining_input, (args, distinct.is_some()))) } @@ -695,12 +698,25 @@ pub fn column_function(i: &[u8]) -> IResult<&[u8], FunctionExpression> { FunctionExpression::GroupConcat(FunctionArgument::Column(col.clone()), sep) }, ), - map(tuple((sql_identifier, multispace0, tag("("), separated_list0(tag(","), delimited(multispace0, function_argument_parser, multispace0)), tag(")"))), |tuple| { - let (name, _, _, arguments, _) = tuple; - FunctionExpression::Generic( - str::from_utf8(name).unwrap().to_string(), - FunctionArguments::from(arguments)) - }) + map( + tuple(( + sql_identifier, + multispace0, + tag("("), + separated_list0( + tag(","), + delimited(multispace0, function_argument_parser, multispace0), + ), + tag(")"), + )), + |tuple| { + let (name, _, _, arguments, _) = tuple; + FunctionExpression::Generic( + str::from_utf8(name).unwrap().to_string(), + FunctionArguments::from(arguments), + ) + }, + ), ))(i) } @@ -893,6 +909,51 @@ pub fn table_list(i: &[u8]) -> IResult<&[u8], Vec> { many0(terminated(schema_table_reference, opt(ws_sep_comma)))(i) } +pub fn table_object_list(i: &[u8]) -> IResult<&[u8], Vec> { + many0(terminated(table_object, opt(ws_sep_comma)))(i) +} + +pub fn table_object(i: &[u8]) -> IResult<&[u8], TableObject> { + map( + tuple(( + schema_table_reference, + opt(preceded(multispace0, table_partition_list)), + )), + |(table, pl)| TableObject { + table, + partitions: pl.unwrap_or(vec![].into()), + }, + )(i) +} + +pub fn table_partition_list(i: &[u8]) -> IResult<&[u8], TablePartitionList> { + let (remaining, (_, _, p, _, _)) = tuple(( + tag_no_case("partition ("), + multispace0, + into(many1(terminated( + table_partition, + opt(tuple((multispace0, tag(","), multispace0))), + ))), + multispace0, + tag(")"), + ))(i)?; + + Ok((remaining, p)) +} + +pub fn table_partition(i: &[u8]) -> IResult<&[u8], TablePartition> { + into(sql_identifier)(i) +} + +pub fn from_clause(i: &[u8]) -> IResult<&[u8], Vec> { + let (i, (_, t)) = tuple(( + delimited(multispace0, tag_no_case("from"), multispace0), + many1(terminated(table_object, opt(ws_sep_comma))), + ))(i)?; + + Ok((i, t)) +} + // Integer literal value pub fn integer_literal(i: &[u8]) -> IResult<&[u8], Literal> { map(pair(opt(tag("-")), digit1), |tup| { @@ -1018,25 +1079,44 @@ pub fn value_list(i: &[u8]) -> IResult<&[u8], Vec> { many0(delimited(multispace0, literal, opt(ws_sep_comma)))(i) } +pub fn relational_objects_clauses(i: &[u8]) -> IResult<&[u8], (Vec, Vec)> { + match from_clause(i) { + Ok((i, f)) => { + let (i, j) = many0(join_clause)(i).unwrap_or((i, vec![])); + + Ok((i, (f, j))) + } + Err(e) => { + if join_clause(i).is_ok() { + //TODO: needs a more helpful error once error handling is improved upon + Err(e) + } else { + Ok((i, (vec![], vec![]))) + } + } + } +} + // Parse a reference to a named schema.table, with an optional alias pub fn schema_table_reference(i: &[u8]) -> IResult<&[u8], Table> { map( - tuple(( - opt(pair(sql_identifier, tag("."))), - sql_identifier, - opt(as_alias) - )), - |tup| Table { - name: String::from(str::from_utf8(tup.1).unwrap()), - alias: match tup.2 { - Some(a) => Some(String::from(a)), - None => None, - }, - schema: match tup.0 { - Some((schema, _)) => Some(String::from(str::from_utf8(schema).unwrap())), - None => None, + tuple(( + opt(pair(sql_identifier, tag("."))), + sql_identifier, + opt(as_alias), + )), + |tup| Table { + name: String::from(str::from_utf8(tup.1).unwrap()), + alias: match tup.2 { + Some(a) => Some(String::from(a)), + None => None, + }, + schema: match tup.0 { + Some((schema, _)) => Some(String::from(str::from_utf8(schema).unwrap())), + None => None, + }, }, - })(i) + )(i) } // Parse a reference to a named table, with an optional alias @@ -1047,7 +1127,7 @@ pub fn table_reference(i: &[u8]) -> IResult<&[u8], Table> { Some(a) => Some(String::from(a)), None => None, }, - schema: None, + schema: None, })(i) } @@ -1137,25 +1217,31 @@ mod tests { name: String::from("max(addr_id)"), alias: None, table: None, - function: Some(Box::new(FunctionExpression::Max( - FunctionArgument::Column(Column::from("addr_id")), - ))), + function: Some(Box::new(FunctionExpression::Max(FunctionArgument::Column( + Column::from("addr_id"), + )))), }; assert_eq!(res.unwrap().1, expected); } #[test] fn simple_generic_function() { - let qlist = ["coalesce(a,b,c)".as_bytes(), "coalesce (a,b,c)".as_bytes(), "coalesce(a ,b,c)".as_bytes(), "coalesce(a, b,c)".as_bytes()]; + let qlist = [ + "coalesce(a,b,c)".as_bytes(), + "coalesce (a,b,c)".as_bytes(), + "coalesce(a ,b,c)".as_bytes(), + "coalesce(a, b,c)".as_bytes(), + ]; for q in qlist.iter() { let res = column_function(q); - let expected = FunctionExpression::Generic("coalesce".to_string(), - FunctionArguments::from( - vec!( - FunctionArgument::Column(Column::from("a")), - FunctionArgument::Column(Column::from("b")), - FunctionArgument::Column(Column::from("c")) - ))); + let expected = FunctionExpression::Generic( + "coalesce".to_string(), + FunctionArguments::from(vec![ + FunctionArgument::Column(Column::from("a")), + FunctionArgument::Column(Column::from("b")), + FunctionArgument::Column(Column::from("c")), + ]), + ); assert_eq!(res, Ok((&b""[..], expected))); } } diff --git a/src/compound_select.rs b/src/compound_select.rs index d3ece89..764c40e 100644 --- a/src/compound_select.rs +++ b/src/compound_select.rs @@ -133,7 +133,7 @@ mod tests { use super::*; use column::Column; use common::{FieldDefinitionExpression, FieldValueExpression, Literal}; - use table::Table; + use table::TableObject; #[test] fn union() { @@ -143,7 +143,7 @@ mod tests { let res2 = compound_selection(qstr2.as_bytes()); let first_select = SelectStatement { - tables: vec![Table::from("Vote")], + tables: vec![TableObject::from("Vote")], fields: vec![ FieldDefinitionExpression::Col(Column::from("id")), FieldDefinitionExpression::Value(FieldValueExpression::Literal( @@ -153,7 +153,7 @@ mod tests { ..Default::default() }; let second_select = SelectStatement { - tables: vec![Table::from("Rating")], + tables: vec![TableObject::from("Rating")], fields: vec![ FieldDefinitionExpression::Col(Column::from("id")), FieldDefinitionExpression::Col(Column::from("stars")), @@ -185,12 +185,18 @@ mod tests { assert!(&res.is_err()); assert_eq!( res.unwrap_err(), - nom::Err::Error(nom::error::Error::new(");".as_bytes(), nom::error::ErrorKind::Tag)) + nom::Err::Error(nom::error::Error::new( + ");".as_bytes(), + nom::error::ErrorKind::Tag + )) ); assert!(&res2.is_err()); assert_eq!( res2.unwrap_err(), - nom::Err::Error(nom::error::Error::new(";".as_bytes(), nom::error::ErrorKind::Tag)) + nom::Err::Error(nom::error::Error::new( + ";".as_bytes(), + nom::error::ErrorKind::Tag + )) ); assert!(&res3.is_err()); assert_eq!( @@ -210,7 +216,7 @@ mod tests { let res = compound_selection(qstr.as_bytes()); let first_select = SelectStatement { - tables: vec![Table::from("Vote")], + tables: vec![TableObject::from("Vote")], fields: vec![ FieldDefinitionExpression::Col(Column::from("id")), FieldDefinitionExpression::Value(FieldValueExpression::Literal( @@ -220,7 +226,7 @@ mod tests { ..Default::default() }; let second_select = SelectStatement { - tables: vec![Table::from("Rating")], + tables: vec![TableObject::from("Rating")], fields: vec![ FieldDefinitionExpression::Col(Column::from("id")), FieldDefinitionExpression::Col(Column::from("stars")), @@ -228,7 +234,7 @@ mod tests { ..Default::default() }; let third_select = SelectStatement { - tables: vec![Table::from("Vote")], + tables: vec![TableObject::from("Vote")], fields: vec![ FieldDefinitionExpression::Value(FieldValueExpression::Literal( Literal::Integer(42).into(), @@ -259,7 +265,7 @@ mod tests { let res = compound_selection(qstr.as_bytes()); let first_select = SelectStatement { - tables: vec![Table::from("Vote")], + tables: vec![TableObject::from("Vote")], fields: vec![ FieldDefinitionExpression::Col(Column::from("id")), FieldDefinitionExpression::Value(FieldValueExpression::Literal( @@ -269,7 +275,7 @@ mod tests { ..Default::default() }; let second_select = SelectStatement { - tables: vec![Table::from("Rating")], + tables: vec![TableObject::from("Rating")], fields: vec![ FieldDefinitionExpression::Col(Column::from("id")), FieldDefinitionExpression::Col(Column::from("stars")), diff --git a/src/condition.rs b/src/condition.rs index a210b7a..c18497a 100644 --- a/src/condition.rs +++ b/src/condition.rs @@ -291,10 +291,7 @@ fn predicate(i: &[u8]) -> IResult<&[u8], ConditionExpression> { }, ); - alt(( - simple_expr, - nested_exists, - ))(i) + alt((simple_expr, nested_exists))(i) } fn simple_expr(i: &[u8]) -> IResult<&[u8], ConditionExpression> { @@ -332,6 +329,7 @@ mod tests { use arithmetic::{ArithmeticBase, ArithmeticOperator}; use column::Column; use common::{FieldDefinitionExpression, ItemPlaceholder, Literal, Operator}; + use table::TableObject; fn columns(cols: &[&str]) -> Vec { cols.iter() @@ -713,7 +711,6 @@ mod tests { fn nested_select() { use select::SelectStatement; use std::default::Default; - use table::Table; use ConditionBase::*; let cond = "bar in (select col from foo)"; @@ -721,7 +718,7 @@ mod tests { let res = condition_expr(cond.as_bytes()); let nested_select = Box::new(SelectStatement { - tables: vec![Table::from("foo")], + tables: vec![TableObject::from("foo")], fields: columns(&["col"]), ..Default::default() }); @@ -739,14 +736,13 @@ mod tests { fn exists_in_select() { use select::SelectStatement; use std::default::Default; - use table::Table; let cond = "exists ( select col from foo )"; let res = condition_expr(cond.as_bytes()); let nested_select = Box::new(SelectStatement { - tables: vec![Table::from("foo")], + tables: vec![TableObject::from("foo")], fields: columns(&["col"]), ..Default::default() }); @@ -760,14 +756,13 @@ mod tests { fn not_exists_in_select() { use select::SelectStatement; use std::default::Default; - use table::Table; let cond = "not exists (select col from foo)"; let res = condition_expr(cond.as_bytes()); let nested_select = Box::new(SelectStatement { - tables: vec![Table::from("foo")], + tables: vec![TableObject::from("foo")], fields: columns(&["col"]), ..Default::default() }); @@ -782,7 +777,6 @@ mod tests { fn and_with_nested_select() { use select::SelectStatement; use std::default::Default; - use table::Table; use ConditionBase::*; let cond = "paperId in (select paperId from PaperConflict) and size > 0"; @@ -790,7 +784,7 @@ mod tests { let res = condition_expr(cond.as_bytes()); let nested_select = Box::new(SelectStatement { - tables: vec![Table::from("PaperConflict")], + tables: vec![TableObject::from("PaperConflict")], fields: columns(&["paperId"]), ..Default::default() }); diff --git a/src/create.rs b/src/create.rs index a1b5afc..39adc4f 100644 --- a/src/create.rs +++ b/src/create.rs @@ -5,8 +5,8 @@ use std::str::FromStr; use column::{Column, ColumnConstraint, ColumnSpecification}; use common::{ - column_identifier_no_alias, parse_comment, sql_identifier, statement_terminator, - schema_table_reference, type_identifier, ws_sep_comma, Literal, Real, SqlType, TableKey, + column_identifier_no_alias, parse_comment, schema_table_reference, sql_identifier, + statement_terminator, type_identifier, ws_sep_comma, Literal, Real, SqlType, TableKey, }; use compound_select::{compound_selection, CompoundSelectStatement}; use create_table_options::table_options; @@ -534,7 +534,7 @@ mod tests { assert_eq!( res.unwrap().1, CreateTableStatement { - table: Table::from(("db1","t")), + table: Table::from(("db1", "t")), fields: vec![ColumnSpecification::new( Column::from("t.x"), SqlType::Int(32) @@ -795,7 +795,7 @@ mod tests { name: String::from("v"), fields: vec![], definition: Box::new(SelectSpecification::Simple(SelectStatement { - tables: vec![Table::from("users")], + tables: vec![Table::from("users").into()], fields: vec![FieldDefinitionExpression::All], where_clause: Some(ConditionExpression::ComparisonOp(ConditionTree { left: Box::new(ConditionExpression::Base(ConditionBase::Field( @@ -830,7 +830,7 @@ mod tests { ( None, SelectStatement { - tables: vec![Table::from("users")], + tables: vec![Table::from("users").into()], fields: vec![FieldDefinitionExpression::All], ..Default::default() }, @@ -838,7 +838,7 @@ mod tests { ( Some(CompoundSelectOperator::DistinctUnion), SelectStatement { - tables: vec![Table::from("old_users")], + tables: vec![Table::from("old_users").into()], fields: vec![FieldDefinitionExpression::All], ..Default::default() }, diff --git a/src/delete.rs b/src/delete.rs index 64ee1dc..0b1ee9b 100644 --- a/src/delete.rs +++ b/src/delete.rs @@ -1,26 +1,25 @@ use nom::character::complete::multispace1; use std::{fmt, str}; -use common::{statement_terminator, schema_table_reference}; +use common::{statement_terminator, table_object}; use condition::ConditionExpression; -use keywords::escape_if_keyword; use nom::bytes::complete::tag_no_case; use nom::combinator::opt; use nom::sequence::{delimited, tuple}; use nom::IResult; use select::where_clause; -use table::Table; +use table::TableObject; #[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize, Deserialize)] pub struct DeleteStatement { - pub table: Table, + pub table: TableObject, pub where_clause: Option, } impl fmt::Display for DeleteStatement { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "DELETE FROM ")?; - write!(f, "{}", escape_if_keyword(&self.table.name))?; + write!(f, "{}", self.table)?; if let Some(ref where_clause) = self.where_clause { write!(f, " WHERE ")?; write!(f, "{}", where_clause)?; @@ -33,7 +32,7 @@ pub fn deletion(i: &[u8]) -> IResult<&[u8], DeleteStatement> { let (remaining_input, (_, _, table, where_clause, _)) = tuple(( tag_no_case("delete"), delimited(multispace1, tag_no_case("from"), multispace1), - schema_table_reference, + table_object, opt(where_clause), statement_terminator, ))(i)?; @@ -55,7 +54,7 @@ mod tests { use condition::ConditionBase::*; use condition::ConditionExpression::*; use condition::ConditionTree; - use table::Table; + use table::{TableObject, TablePartitionList}; #[test] fn simple_delete() { @@ -64,7 +63,7 @@ mod tests { assert_eq!( res.unwrap().1, DeleteStatement { - table: Table::from("users"), + table: TableObject::from("users"), ..Default::default() } ); @@ -77,7 +76,7 @@ mod tests { assert_eq!( res.unwrap().1, DeleteStatement { - table: Table::from(("db1","users")), + table: TableObject::from(("db1", "users")), ..Default::default() } ); @@ -96,7 +95,30 @@ mod tests { assert_eq!( res.unwrap().1, DeleteStatement { - table: Table::from("users"), + table: TableObject::from("users"), + where_clause: expected_where_cond, + ..Default::default() + } + ); + } + + #[test] + fn delete_with_partition() { + let qstring = "DELETE FROM users PARTITION (u) WHERE id = 1;"; + let res = deletion(qstring.as_bytes()); + let expected_left = Base(Field(Column::from("id"))); + let expected_where_cond = Some(ComparisonOp(ConditionTree { + left: Box::new(expected_left), + right: Box::new(Base(Literal(Literal::Integer(1)))), + operator: Operator::Equal, + })); + assert_eq!( + res.unwrap().1, + DeleteStatement { + table: TableObject { + table: "users".into(), + partitions: TablePartitionList(vec!["u".into()]), + }, where_clause: expected_where_cond, ..Default::default() } diff --git a/src/insert.rs b/src/insert.rs index 9f55107..483686f 100644 --- a/src/insert.rs +++ b/src/insert.rs @@ -4,20 +4,19 @@ use std::str; use column::Column; use common::{ - assignment_expr_list, field_list, statement_terminator, schema_table_reference, value_list, - ws_sep_comma, FieldValueExpression, Literal, + assignment_expr_list, field_list, statement_terminator, table_object, value_list, ws_sep_comma, + FieldValueExpression, Literal, }; -use keywords::escape_if_keyword; use nom::bytes::complete::{tag, tag_no_case}; use nom::combinator::opt; use nom::multi::many1; use nom::sequence::{delimited, preceded, tuple}; use nom::IResult; -use table::Table; +use table::TableObject; #[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize, Deserialize)] pub struct InsertStatement { - pub table: Table, + pub table: TableObject, pub fields: Option>, pub data: Vec>, pub ignore: bool, @@ -26,7 +25,7 @@ pub struct InsertStatement { impl fmt::Display for InsertStatement { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "INSERT INTO {}", escape_if_keyword(&self.table.name))?; + write!(f, "INSERT INTO {}", self.table)?; if let Some(ref fields) = self.fields { write!( f, @@ -89,7 +88,7 @@ pub fn insertion(i: &[u8]) -> IResult<&[u8], InsertStatement> { multispace1, tag_no_case("into"), multispace1, - schema_table_reference, + table_object, multispace0, opt(fields), tag_no_case("values"), @@ -98,7 +97,7 @@ pub fn insertion(i: &[u8]) -> IResult<&[u8], InsertStatement> { opt(on_duplicate), statement_terminator, ))(i)?; - assert!(table.alias.is_none()); + assert!(table.table.alias.is_none()); let ignore = ignore_res.is_some(); Ok(( @@ -119,7 +118,7 @@ mod tests { use arithmetic::{ArithmeticBase, ArithmeticExpression, ArithmeticOperator}; use column::Column; use common::ItemPlaceholder; - use table::Table; + use table::{TableObject, TablePartitionList}; #[test] fn simple_insert() { @@ -129,7 +128,7 @@ mod tests { assert_eq!( res.unwrap().1, InsertStatement { - table: Table::from("users"), + table: TableObject::from("users"), fields: None, data: vec![vec![42.into(), "test".into()]], ..Default::default() @@ -145,7 +144,7 @@ mod tests { assert_eq!( res.unwrap().1, InsertStatement { - table: Table::from(("db1","users")), + table: TableObject::from(("db1", "users")), fields: None, data: vec![vec![42.into(), "test".into()]], ..Default::default() @@ -161,7 +160,7 @@ mod tests { assert_eq!( res.unwrap().1, InsertStatement { - table: Table::from("users"), + table: TableObject::from("users"), fields: None, data: vec![vec![ 42.into(), @@ -182,7 +181,7 @@ mod tests { assert_eq!( res.unwrap().1, InsertStatement { - table: Table::from("users"), + table: TableObject::from("users"), fields: Some(vec![Column::from("id"), Column::from("name")]), data: vec![vec![42.into(), "test".into()]], ..Default::default() @@ -199,7 +198,7 @@ mod tests { assert_eq!( res.unwrap().1, InsertStatement { - table: Table::from("users"), + table: TableObject::from("users"), fields: Some(vec![Column::from("id"), Column::from("name")]), data: vec![vec![42.into(), "test".into()]], ..Default::default() @@ -215,7 +214,7 @@ mod tests { assert_eq!( res.unwrap().1, InsertStatement { - table: Table::from("users"), + table: TableObject::from("users"), fields: Some(vec![Column::from("id"), Column::from("name")]), data: vec![ vec![42.into(), "test".into()], @@ -234,7 +233,7 @@ mod tests { assert_eq!( res.unwrap().1, InsertStatement { - table: Table::from("users"), + table: TableObject::from("users"), fields: Some(vec![Column::from("id"), Column::from("name")]), data: vec![vec![ Literal::Placeholder(ItemPlaceholder::QuestionMark), @@ -260,7 +259,7 @@ mod tests { assert_eq!( res.unwrap().1, InsertStatement { - table: Table::from("keystores"), + table: TableObject::from("keystores"), fields: Some(vec![Column::from("key"), Column::from("value")]), data: vec![vec![ Literal::Placeholder(ItemPlaceholder::DollarNumber(1)), @@ -283,7 +282,26 @@ mod tests { assert_eq!( res.unwrap().1, InsertStatement { - table: Table::from("users"), + table: TableObject::from("users"), + fields: Some(vec![Column::from("id"), Column::from("name")]), + data: vec![vec![42.into(), "test".into()]], + ..Default::default() + } + ); + } + + #[test] + fn insert_with_partitions() { + let qstring = "INSERT INTO users PARTITION (u) (id, name) VALUES (42, \"test\");"; + + let res = insertion(qstring.as_bytes()); + assert_eq!( + res.unwrap().1, + InsertStatement { + table: TableObject { + table: "users".into(), + partitions: TablePartitionList(vec!["u".into()]), + }, fields: Some(vec![Column::from("id"), Column::from("name")]), data: vec![vec![42.into(), "test".into()]], ..Default::default() diff --git a/src/join.rs b/src/join.rs index b91f5dc..83fe2d3 100644 --- a/src/join.rs +++ b/src/join.rs @@ -8,14 +8,14 @@ use nom::bytes::complete::tag_no_case; use nom::combinator::map; use nom::IResult; use select::{JoinClause, SelectStatement}; -use table::Table; +use table::TableObject; #[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)] pub enum JoinRightSide { /// A single table. - Table(Table), + Table(TableObject), /// A comma-separated (and implicitly joined) sequence of tables. - Tables(Vec
), + Tables(Vec), /// A nested selection, represented as (query, alias). NestedSelect(Box, Option), /// A nested join clause. @@ -112,6 +112,7 @@ mod tests { use condition::ConditionExpression::{self, *}; use condition::ConditionTree; use select::{selection, JoinClause, SelectStatement}; + use table::TablePartitionList; #[test] fn inner_join() { @@ -127,13 +128,50 @@ mod tests { }; let join_cond = ConditionExpression::ComparisonOp(ct); let expected_stmt = SelectStatement { - tables: vec![Table::from("tags")], + tables: vec![TableObject::from("tags")], + join: vec![JoinClause { + operator: JoinOperator::InnerJoin, + right: JoinRightSide::Table(TableObject::from("taggings")), + constraint: JoinConstraint::On(join_cond), + }], fields: vec![FieldDefinitionExpression::AllInTable("tags".into())], + ..Default::default() + }; + + let q = res.unwrap().1; + assert_eq!(q, expected_stmt); + assert_eq!(qstring, format!("{}", q)); + } + + #[test] + fn partitioned_inner_join() { + let qstring = "SELECT tags.* FROM tags PARTITION (t, a, g) \ + INNER JOIN taggings PARTITION (t, a, g) ON tags.id = taggings.tag_id"; + + let res = selection(qstring.as_bytes()); + + let expected_pl = TablePartitionList(vec!["t".into(), "a".into(), "g".into()]); + + let ct = ConditionTree { + left: Box::new(Base(Field(Column::from("tags.id")))), + right: Box::new(Base(Field(Column::from("taggings.tag_id")))), + operator: Operator::Equal, + }; + let join_cond = ConditionExpression::ComparisonOp(ct); + let expected_stmt = SelectStatement { + tables: vec![TableObject { + table: "tags".into(), + partitions: expected_pl.clone(), + }], join: vec![JoinClause { operator: JoinOperator::InnerJoin, - right: JoinRightSide::Table(Table::from("taggings")), + right: JoinRightSide::Table(TableObject { + table: "taggings".into(), + partitions: expected_pl, + }), constraint: JoinConstraint::On(join_cond), }], + fields: vec![FieldDefinitionExpression::AllInTable("tags".into())], ..Default::default() }; diff --git a/src/keywords.rs b/src/keywords.rs index 88612de..582efb3 100644 --- a/src/keywords.rs +++ b/src/keywords.rs @@ -127,11 +127,12 @@ fn keyword_i_to_o(i: &[u8]) -> IResult<&[u8], &[u8]> { ))(i) } -fn keyword_o_to_s(i: &[u8]) -> IResult<&[u8], &[u8]> { +fn keyword_o_to_r(i: &[u8]) -> IResult<&[u8], &[u8]> { alt(( terminated(tag_no_case("ON"), keyword_follow_char), terminated(tag_no_case("OR"), keyword_follow_char), terminated(tag_no_case("OUTER"), keyword_follow_char), + terminated(tag_no_case("PARTITION"), keyword_follow_char), terminated(tag_no_case("PLAN"), keyword_follow_char), terminated(tag_no_case("PRAGMA"), keyword_follow_char), terminated(tag_no_case("PRIMARY"), keyword_follow_char), @@ -139,6 +140,11 @@ fn keyword_o_to_s(i: &[u8]) -> IResult<&[u8], &[u8]> { terminated(tag_no_case("RAISE"), keyword_follow_char), terminated(tag_no_case("RECURSIVE"), keyword_follow_char), terminated(tag_no_case("REFERENCES"), keyword_follow_char), + ))(i) +} + +fn keyword_r_to_s(i: &[u8]) -> IResult<&[u8], &[u8]> { + alt(( terminated(tag_no_case("REGEXP"), keyword_follow_char), terminated(tag_no_case("REINDEX"), keyword_follow_char), terminated(tag_no_case("RELEASE"), keyword_follow_char), @@ -185,7 +191,8 @@ pub fn sql_keyword(i: &[u8]) -> IResult<&[u8], &[u8]> { keyword_c_to_e, keyword_e_to_i, keyword_i_to_o, - keyword_o_to_s, + keyword_o_to_r, + keyword_r_to_s, keyword_s_to_z, ))(i) } diff --git a/src/lib.rs b/src/lib.rs index da4d5d1..7c8a538 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,6 +7,7 @@ extern crate serde_derive; #[cfg(test)] #[macro_use] extern crate pretty_assertions; +extern crate core; pub use self::arithmetic::{ArithmeticBase, ArithmeticExpression, ArithmeticOperator}; pub use self::case::{CaseWhenExpression, ColumnOrLiteral}; diff --git a/src/parser.rs b/src/parser.rs index ae68592..e4a43e4 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -79,7 +79,7 @@ mod tests { use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; - use table::Table; + use table::TableObject; #[test] fn hash_query() { @@ -88,7 +88,7 @@ mod tests { assert!(res.is_ok()); let expected = SqlQuery::Insert(InsertStatement { - table: Table::from("users"), + table: TableObject::from("users"), fields: None, data: vec![vec![42.into(), "test".into()]], ..Default::default() diff --git a/src/select.rs b/src/select.rs index 8735e95..8d3aa2e 100644 --- a/src/select.rs +++ b/src/select.rs @@ -3,21 +3,20 @@ use std::fmt; use std::str; use column::Column; -use common::FieldDefinitionExpression; use common::{ - as_alias, field_definition_expr, field_list, statement_terminator, table_list, table_reference, - unsigned_number, + as_alias, field_definition_expr, field_list, relational_objects_clauses, statement_terminator, + table_object_list, unsigned_number, }; +use common::{table_object, FieldDefinitionExpression}; use condition::{condition_expr, ConditionExpression}; use join::{join_operator, JoinConstraint, JoinOperator, JoinRightSide}; use nom::branch::alt; use nom::bytes::complete::{tag, tag_no_case}; use nom::combinator::{map, opt}; -use nom::multi::many0; use nom::sequence::{delimited, preceded, terminated, tuple}; use nom::IResult; use order::{order_clause, OrderClause}; -use table::Table; +use table::TableObject; #[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)] pub struct GroupByClause { @@ -78,9 +77,9 @@ impl fmt::Display for LimitClause { #[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize, Deserialize)] pub struct SelectStatement { - pub tables: Vec
, pub distinct: bool, pub fields: Vec, + pub tables: Vec, pub join: Vec, pub where_clause: Option, pub group_by: Option, @@ -116,9 +115,11 @@ impl fmt::Display for SelectStatement { .join(", ") )?; } + for jc in &self.join { write!(f, " {}", jc)?; } + if let Some(ref where_clause) = self.where_clause { write!(f, " WHERE ")?; write!(f, "{}", where_clause)?; @@ -217,7 +218,7 @@ fn join_constraint(i: &[u8]) -> IResult<&[u8], JoinConstraint> { } // Parse JOIN clause -fn join_clause(i: &[u8]) -> IResult<&[u8], JoinClause> { +pub fn join_clause(i: &[u8]) -> IResult<&[u8], JoinClause> { let (remaining_input, (_, _natural, operator, _, right, _, constraint)) = tuple(( multispace0, opt(terminated(tag_no_case("natural"), multispace1)), @@ -249,8 +250,8 @@ fn join_rhs(i: &[u8]) -> IResult<&[u8], JoinRightSide> { let nested_join = map(delimited(tag("("), join_clause, tag(")")), |nj| { JoinRightSide::NestedJoin(Box::new(nj)) }); - let table = map(table_reference, |t| JoinRightSide::Table(t)); - let tables = map(delimited(tag("("), table_list, tag(")")), |tables| { + let table = map(table_object, |t| JoinRightSide::Table(t)); + let tables = map(delimited(tag("("), table_object_list, tag(")")), |tables| { JoinRightSide::Tables(tables) }); alt((nested_select, nested_join, table, tables))(i) @@ -276,16 +277,14 @@ pub fn selection(i: &[u8]) -> IResult<&[u8], SelectStatement> { pub fn nested_selection(i: &[u8]) -> IResult<&[u8], SelectStatement> { let ( remaining_input, - (_, _, distinct, _, fields, _, tables, join, where_clause, group_by, order, limit), + (_, _, distinct, _, fields, (tables, join), where_clause, group_by, order, limit), ) = tuple(( tag_no_case("select"), multispace1, opt(tag_no_case("distinct")), multispace0, field_definition_expr, - delimited(multispace0, tag_no_case("from"), multispace0), - table_list, - many0(join_clause), + relational_objects_clauses, opt(where_clause), opt(group_by_clause), opt(order_clause), @@ -295,9 +294,9 @@ pub fn nested_selection(i: &[u8]) -> IResult<&[u8], SelectStatement> { remaining_input, SelectStatement { tables, + join, distinct: distinct.is_some(), fields, - join, where_clause, group_by, order, @@ -318,7 +317,8 @@ mod tests { use condition::ConditionExpression::*; use condition::ConditionTree; use order::OrderType; - use table::Table; + use table::{Table, TablePartitionList}; + use OrderType::OrderAscending; fn columns(cols: &[&str]) -> Vec { cols.iter() @@ -334,7 +334,7 @@ mod tests { assert_eq!( res.unwrap().1, SelectStatement { - tables: vec![Table::from("users")], + tables: vec![TableObject::from("users")], fields: columns(&["id", "name"]), ..Default::default() } @@ -349,7 +349,7 @@ mod tests { assert_eq!( res.unwrap().1, SelectStatement { - tables: vec![Table::from("users")], + tables: vec![TableObject::from("users")], fields: columns(&["users.id", "users.name"]), ..Default::default() } @@ -368,7 +368,7 @@ mod tests { assert_eq!( res.unwrap().1, SelectStatement { - tables: vec![Table::from("users")], + tables: vec![TableObject::from("users")], fields: vec![ FieldDefinitionExpression::Value(FieldValueExpression::Literal( Literal::Null.into(), @@ -396,7 +396,7 @@ mod tests { assert_eq!( res.unwrap().1, SelectStatement { - tables: vec![Table::from("users")], + tables: vec![TableObject::from("users")], fields: vec![FieldDefinitionExpression::All], ..Default::default() } @@ -411,7 +411,34 @@ mod tests { assert_eq!( res.unwrap().1, SelectStatement { - tables: vec![Table::from("users"), Table::from("votes")], + tables: vec![TableObject::from("users"), TableObject::from("votes")], + fields: vec![FieldDefinitionExpression::AllInTable(String::from("users"))], + ..Default::default() + } + ); + } + + #[test] + fn select_all_in_partitioned_tables() { + let qstring = "SELECT users.* FROM users PARTITION (a, b), votes PARTITION (a, b);"; + + let expected_pl = TablePartitionList(vec!["a".into(), "b".into()]); + + let res = selection(qstring.as_bytes()); + assert_eq!( + res.unwrap().1, + SelectStatement { + distinct: false, + tables: vec![ + TableObject { + table: "users".into(), + partitions: expected_pl.clone(), + }, + TableObject { + table: "votes".into(), + partitions: expected_pl, + }, + ], fields: vec![FieldDefinitionExpression::AllInTable(String::from("users"))], ..Default::default() } @@ -426,7 +453,7 @@ mod tests { assert_eq!( res.unwrap().1, SelectStatement { - tables: vec![Table::from("users")], + tables: vec![TableObject::from("users")], fields: columns(&["id", "name"]), ..Default::default() } @@ -493,7 +520,7 @@ mod tests { assert_eq!( res.unwrap().1, SelectStatement { - tables: vec![Table::from("ContactInfo")], + tables: vec![TableObject::from("ContactInfo")], fields: vec![FieldDefinitionExpression::All], where_clause: expected_where_cond, ..Default::default() @@ -533,8 +560,9 @@ mod tests { tables: vec![Table { name: String::from("PaperTag"), alias: Some(String::from("t")), - schema: None, - },], + schema: None, + } + .into()], fields: vec![FieldDefinitionExpression::All], ..Default::default() } @@ -554,8 +582,9 @@ mod tests { tables: vec![Table { name: String::from("PaperTag"), alias: Some(String::from("t")), - schema: Some(String::from("db1")), - },], + schema: Some(String::from("db1")), + } + .into(),], fields: vec![FieldDefinitionExpression::All], ..Default::default() } @@ -573,7 +602,7 @@ mod tests { assert_eq!( res1.unwrap().1, SelectStatement { - tables: vec![Table::from("PaperTag")], + tables: vec![TableObject::from("PaperTag")], fields: vec![FieldDefinitionExpression::Col(Column { name: String::from("name"), alias: Some(String::from("TagName")), @@ -587,7 +616,7 @@ mod tests { assert_eq!( res2.unwrap().1, SelectStatement { - tables: vec![Table::from("PaperTag")], + tables: vec![TableObject::from("PaperTag")], fields: vec![FieldDefinitionExpression::Col(Column { name: String::from("name"), alias: Some(String::from("TagName")), @@ -608,7 +637,7 @@ mod tests { assert_eq!( res1.unwrap().1, SelectStatement { - tables: vec![Table::from("PaperTag")], + tables: vec![TableObject::from("PaperTag")], fields: vec![FieldDefinitionExpression::Col(Column { name: String::from("name"), alias: Some(String::from("TagName")), @@ -622,7 +651,7 @@ mod tests { assert_eq!( res2.unwrap().1, SelectStatement { - tables: vec![Table::from("PaperTag")], + tables: vec![TableObject::from("PaperTag")], fields: vec![FieldDefinitionExpression::Col(Column { name: String::from("name"), alias: Some(String::from("TagName")), @@ -650,7 +679,7 @@ mod tests { assert_eq!( res.unwrap().1, SelectStatement { - tables: vec![Table::from("PaperTag")], + tables: vec![TableObject::from("PaperTag")], distinct: true, fields: columns(&["tag"]), where_clause: expected_where_cond, @@ -689,7 +718,7 @@ mod tests { assert_eq!( res.unwrap().1, SelectStatement { - tables: vec![Table::from("PaperStorage")], + tables: vec![TableObject::from("PaperStorage")], fields: columns(&["infoJson"]), where_clause: expected_where_cond, ..Default::default() @@ -718,7 +747,7 @@ mod tests { assert_eq!( res.unwrap().1, SelectStatement { - tables: vec![Table::from("users")], + tables: vec![TableObject::from("users")], fields: vec![FieldDefinitionExpression::All], where_clause: expected_where_cond, limit: expected_lim, @@ -736,7 +765,7 @@ mod tests { assert_eq!( res.unwrap().1, SelectStatement { - tables: vec![Table::from("address")], + tables: vec![TableObject::from("address")], fields: vec![FieldDefinitionExpression::Col(Column { name: String::from("max(addr_id)"), alias: None, @@ -755,7 +784,7 @@ mod tests { let res = selection(qstring.as_bytes()); let agg_expr = FunctionExpression::Max(FunctionArgument::Column(Column::from("addr_id"))); let expected_stmt = SelectStatement { - tables: vec![Table::from("address")], + tables: vec![TableObject::from("address")], fields: vec![FieldDefinitionExpression::Col(Column { name: String::from("max_addr"), alias: Some(String::from("max_addr")), @@ -774,7 +803,7 @@ mod tests { let res = selection(qstring.as_bytes()); let agg_expr = FunctionExpression::CountStar; let expected_stmt = SelectStatement { - tables: vec![Table::from("votes")], + tables: vec![TableObject::from("votes")], fields: vec![FieldDefinitionExpression::Col(Column { name: String::from("count(*)"), alias: None, @@ -798,7 +827,7 @@ mod tests { let agg_expr = FunctionExpression::Count(FunctionArgument::Column(Column::from("vote_id")), true); let expected_stmt = SelectStatement { - tables: vec![Table::from("votes")], + tables: vec![TableObject::from("votes")], fields: vec![FieldDefinitionExpression::Col(Column { name: String::from("count(distinct vote_id)"), alias: None, @@ -834,7 +863,7 @@ mod tests { false, ); let expected_stmt = SelectStatement { - tables: vec![Table::from("votes")], + tables: vec![TableObject::from("votes")], fields: vec![FieldDefinitionExpression::Col(Column { name: format!("{}", agg_expr), alias: None, @@ -870,7 +899,7 @@ mod tests { false, ); let expected_stmt = SelectStatement { - tables: vec![Table::from("votes")], + tables: vec![TableObject::from("votes")], fields: vec![FieldDefinitionExpression::Col(Column { name: format!("{}", agg_expr), alias: None, @@ -907,7 +936,7 @@ mod tests { false, ); let expected_stmt = SelectStatement { - tables: vec![Table::from("votes")], + tables: vec![TableObject::from("votes")], fields: vec![FieldDefinitionExpression::Col(Column { name: format!("{}", agg_expr), alias: None, @@ -954,7 +983,7 @@ mod tests { false, ); let expected_stmt = SelectStatement { - tables: vec![Table::from("votes")], + tables: vec![TableObject::from("votes")], fields: vec![FieldDefinitionExpression::Col(Column { name: String::from("votes"), alias: Some(String::from("votes")), @@ -1001,7 +1030,7 @@ mod tests { }, ); let expected_stmt = SelectStatement { - tables: vec![Table::from("sometable")], + tables: vec![TableObject::from("sometable")], fields: vec![ FieldDefinitionExpression::Col(Column { name: String::from("x"), @@ -1045,7 +1074,7 @@ mod tests { assert_eq!( res.unwrap().1, SelectStatement { - tables: vec![Table::from("item"), Table::from("author")], + tables: vec![TableObject::from("item"), TableObject::from("author")], fields: vec![FieldDefinitionExpression::All], where_clause: expected_where_cond, order: Some(OrderClause { @@ -1066,13 +1095,13 @@ mod tests { let res = selection(qstring.as_bytes()); let expected_stmt = SelectStatement { - tables: vec![Table::from("PaperConflict")], - fields: columns(&["paperId"]), + tables: vec![TableObject::from("PaperConflict")], join: vec![JoinClause { operator: JoinOperator::Join, - right: JoinRightSide::Table(Table::from("PCMember")), + right: JoinRightSide::Table(TableObject::from("PCMember")), constraint: JoinConstraint::Using(vec![Column::from("contactId")]), }], + fields: columns(&["paperId"]), ..Default::default() }; assert_eq!(res.unwrap().1, expected_stmt); @@ -1094,13 +1123,13 @@ mod tests { }; let join_cond = ConditionExpression::ComparisonOp(ct); let expected = SelectStatement { - tables: vec![Table::from("PCMember")], - fields: columns(&["PCMember.contactId"]), + tables: vec![TableObject::from("PCMember")], join: vec![JoinClause { operator: JoinOperator::Join, - right: JoinRightSide::Table(Table::from("PaperReview")), + right: JoinRightSide::Table(TableObject::from("PaperReview")), constraint: JoinConstraint::On(join_cond), }], + fields: columns(&["PCMember.contactId"]), order: Some(OrderClause { columns: vec![("contactId".into(), OrderType::OrderAscending)], }), @@ -1145,19 +1174,14 @@ mod tests { let mkjoin = |tbl: &str, col: &str| -> JoinClause { JoinClause { operator: JoinOperator::LeftJoin, - right: JoinRightSide::Table(Table::from(tbl)), + right: JoinRightSide::Table(TableObject::from(tbl)), constraint: JoinConstraint::Using(vec![Column::from(col)]), } }; assert_eq!( res.unwrap().1, SelectStatement { - tables: vec![Table::from("ContactInfo")], - fields: columns(&[ - "PCMember.contactId", - "ChairAssistant.contactId", - "Chair.contactId" - ]), + tables: vec![TableObject::from("ContactInfo")], join: vec![ mkjoin("PaperReview", "contactId"), mkjoin("PaperConflict", "contactId"), @@ -1165,12 +1189,29 @@ mod tests { mkjoin("ChairAssistant", "contactId"), mkjoin("Chair", "contactId"), ], + fields: columns(&[ + "PCMember.contactId", + "ChairAssistant.contactId", + "Chair.contactId" + ]), where_clause: expected_where_cond, ..Default::default() } ); } + #[test] + fn out_of_order_joins_fail() { + let qstring0 = "select paperId join PCMember using (contactId);"; + let qstring1 = "select paperId join PCMember from PaperConflict using (contactId);"; + + let res0 = selection(qstring0.as_bytes()); + let res1 = selection(qstring1.as_bytes()); + + assert!(res0.is_err()); + assert!(res1.is_err()); + } + #[test] fn nested_select() { let qstr = "SELECT ol_i_id FROM orders, order_line \ @@ -1185,7 +1226,7 @@ mod tests { }); let inner_select = SelectStatement { - tables: vec![Table::from("orders"), Table::from("order_line")], + tables: vec![TableObject::from("orders"), TableObject::from("order_line")], fields: columns(&["o_c_id"]), where_clause: Some(inner_where_clause), ..Default::default() @@ -1198,7 +1239,7 @@ mod tests { }); let outer_select = SelectStatement { - tables: vec![Table::from("orders"), Table::from("order_line")], + tables: vec![TableObject::from("orders"), TableObject::from("order_line")], fields: columns(&["ol_i_id"]), where_clause: Some(outer_where_clause), ..Default::default() @@ -1218,7 +1259,7 @@ mod tests { let agg_expr = FunctionExpression::Max(FunctionArgument::Column(Column::from("o_id"))); let recursive_select = SelectStatement { - tables: vec![Table::from("orders")], + tables: vec![TableObject::from("orders")], fields: vec![FieldDefinitionExpression::Col(Column { name: String::from("max(o_id)"), alias: None, @@ -1247,7 +1288,7 @@ mod tests { }); let inner_select = SelectStatement { - tables: vec![Table::from("orders"), Table::from("order_line")], + tables: vec![TableObject::from("orders"), TableObject::from("order_line")], fields: columns(&["o_c_id"]), where_clause: Some(inner_where_clause), ..Default::default() @@ -1260,7 +1301,7 @@ mod tests { }); let outer_select = SelectStatement { - tables: vec![Table::from("orders"), Table::from("order_line")], + tables: vec![TableObject::from("orders"), TableObject::from("order_line")], fields: columns(&["ol_i_id"]), where_clause: Some(outer_where_clause), ..Default::default() @@ -1290,14 +1331,13 @@ mod tests { // N.B.: Don't alias the inner select to `inner`, which is, well, a SQL keyword! let inner_select = SelectStatement { - tables: vec![Table::from("order_line")], + tables: vec![TableObject::from("order_line")], fields: columns(&["ol_i_id"]), ..Default::default() }; let outer_select = SelectStatement { - tables: vec![Table::from("orders")], - fields: columns(&["o_id", "ol_i_id"]), + tables: vec![TableObject::from("orders")], join: vec![JoinClause { operator: JoinOperator::Join, right: JoinRightSide::NestedSelect(Box::new(inner_select), Some("ids".into())), @@ -1307,6 +1347,7 @@ mod tests { right: Box::new(Base(Field(Column::from("ids.ol_i_id")))), })), }], + fields: columns(&["o_id", "ol_i_id"]), ..Default::default() }; @@ -1321,7 +1362,7 @@ mod tests { let res = selection(qstr.as_bytes()); let expected = SelectStatement { - tables: vec![Table::from("orders")], + tables: vec![TableObject::from("orders")], fields: vec![FieldDefinitionExpression::Value( FieldValueExpression::Arithmetic(ArithmeticExpression::new( ArithmeticOperator::Subtract, @@ -1351,7 +1392,7 @@ mod tests { let res = selection(qstr.as_bytes()); let expected = SelectStatement { - tables: vec![Table::from("orders")], + tables: vec![TableObject::from("orders")], fields: vec![FieldDefinitionExpression::Value( FieldValueExpression::Arithmetic(ArithmeticExpression::new( ArithmeticOperator::Multiply, @@ -1389,24 +1430,92 @@ mod tests { })); let expected = SelectStatement { - tables: vec![Table::from("auth_permission")], - fields: vec![ - FieldDefinitionExpression::Col(Column::from("auth_permission.content_type_id")), - FieldDefinitionExpression::Col(Column::from("auth_permission.codename")), - ], + tables: vec![TableObject::from("auth_permission")], join: vec![JoinClause { operator: JoinOperator::Join, - right: JoinRightSide::Table(Table::from("django_content_type")), + right: JoinRightSide::Table(TableObject::from("django_content_type")), constraint: JoinConstraint::On(ComparisonOp(ConditionTree { operator: Operator::Equal, left: Box::new(Base(Field(Column::from("auth_permission.content_type_id")))), right: Box::new(Base(Field(Column::from("django_content_type.id")))), })), }], + fields: vec![ + FieldDefinitionExpression::Col(Column::from("auth_permission.content_type_id")), + FieldDefinitionExpression::Col(Column::from("auth_permission.codename")), + ], where_clause: expected_where_clause, ..Default::default() }; assert_eq!(res.unwrap().1, expected); } + + #[test] + fn literal_select() { + use arithmetic::{ArithmeticBase, ArithmeticExpression, ArithmeticOperator}; + + let qstr0 = "SELECT 1 + 1"; + let qstr1 = "SELECT 1 + 1 AS adder GROUP BY adder ORDER BY adder"; + + let res0 = selection(qstr0.as_bytes()); + let res1 = selection(qstr1.as_bytes()); + + let expected0 = SelectStatement { + distinct: false, + fields: vec![FieldDefinitionExpression::Value( + FieldValueExpression::Arithmetic(ArithmeticExpression::new( + ArithmeticOperator::Add, + ArithmeticBase::Scalar(1.into()), + ArithmeticBase::Scalar(1.into()), + None, + )), + )], + tables: vec![], + where_clause: None, + group_by: None, + order: None, + limit: None, + ..Default::default() + }; + + let expected1 = SelectStatement { + distinct: false, + fields: vec![FieldDefinitionExpression::Value( + FieldValueExpression::Arithmetic(ArithmeticExpression::new( + ArithmeticOperator::Add, + ArithmeticBase::Scalar(1.into()), + ArithmeticBase::Scalar(1.into()), + Some("adder".to_string()), + )), + )], + tables: vec![], + where_clause: None, + group_by: Some(GroupByClause { + columns: vec![Column { + name: "adder".to_string(), + alias: None, + table: None, + function: None, + }], + having: None, + }), + order: Some(OrderClause { + columns: vec![( + Column { + name: "adder".to_string(), + alias: None, + table: None, + function: None, + }, + OrderAscending, + )], + }), + limit: None, + ..Default::default() + }; + + assert_eq!(res0.unwrap().1, expected0); + assert_eq!(res1.unwrap().1, expected1); + } } diff --git a/src/table.rs b/src/table.rs index cfc62b6..a946c51 100644 --- a/src/table.rs +++ b/src/table.rs @@ -1,4 +1,5 @@ use std::fmt; +use std::fmt::{Display, Formatter}; use std::str; use keywords::escape_if_keyword; @@ -32,6 +33,7 @@ impl<'a> From<&'a str> for Table { } } } + impl<'a> From<(&'a str, &'a str)> for Table { fn from(t: (&str, &str)) -> Table { Table { @@ -41,3 +43,204 @@ impl<'a> From<(&'a str, &'a str)> for Table { } } } + +impl From
for TableObject { + fn from(table: Table) -> Self { + Self { + table, + partitions: Default::default(), + } + } +} + +#[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize, Deserialize)] +pub struct TableObject { + pub table: Table, + pub partitions: TablePartitionList, +} + +impl Display for TableObject { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.table)?; + + if self.partitions.0.len() > 0 { + write!(f, " {}", self.partitions)?; + } + + Ok(()) + } +} + +impl<'a> From<&'a str> for TableObject { + fn from(t: &str) -> Self { + TableObject { + table: t.into(), + partitions: Default::default(), + } + } +} +impl<'a> From<(&'a str, &'a str)> for TableObject { + fn from(t: (&str, &str)) -> Self { + TableObject { + table: t.into(), + partitions: Default::default(), + } + } +} + +#[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize, Deserialize)] +pub struct TablePartitionList(pub Vec); + +impl Display for TablePartitionList { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "PARTITION (")?; + + if self.0.len() > 0 { + write!( + f, + "{}", + self.0 + .iter() + .map(|p| p.to_string()) + .collect::>() + .join(", ") + )?; + } + + write!(f, ")") + } +} + +impl From> for TablePartitionList { + fn from(v: Vec) -> Self { + Self(v) + } +} + +#[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize, Deserialize)] +pub struct TablePartition { + name: String, +} + +impl Display for TablePartition { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.name) + } +} + +impl<'a> From<&'a str> for TablePartition { + fn from(name: &str) -> Self { + Self { + name: name.to_string(), + } + } +} + +impl<'a> From<&'a [u8]> for TablePartition { + fn from(name: &[u8]) -> Self { + Self { + name: String::from(str::from_utf8(name).unwrap()), + } + } +} + +#[cfg(test)] +mod test { + use super::*; + use common::table_object; + + #[test] + fn table_object_simple() { + let qstring0 = "table1"; + let qstring1 = "schema1.table1"; + + let res0 = table_object(qstring0.as_bytes()); + let res1 = table_object(qstring1.as_bytes()); + assert_eq!(res0.unwrap().1, TableObject::from("table1"),); + assert_eq!(res1.unwrap().1, TableObject::from(("schema1", "table1")),); + } + + #[test] + fn table_object_with_alias() { + let qstring0 = "table1 AS t1"; + let qstring1 = "schema1.table1 AS t1"; + + let res0 = table_object(qstring0.as_bytes()); + let res1 = table_object(qstring1.as_bytes()); + assert_eq!( + res0.unwrap().1, + Table { + name: "table1".to_string(), + alias: Some("t1".to_string()), + schema: None + } + .into() + ); + assert_eq!( + res1.unwrap().1, + Table { + name: "table1".to_string(), + alias: Some("t1".to_string()), + schema: Some("schema1".to_string()), + } + .into() + ); + } + + #[test] + fn table_object_with_partitioning() { + let qstring0 = "schema1.table1 partition (a)"; + let qstring1 = "schema1.table1 partition (a, b1_long, a2b)"; + let qstring2 = "schema1.table1 AS t1 partition (a)"; + let qstring3 = "schema1.table1 AS t1 partition (a, b1_long, a2b)"; + + let res0 = table_object(qstring0.as_bytes()); + let res1 = table_object(qstring1.as_bytes()); + let res2 = table_object(qstring2.as_bytes()); + let res3 = table_object(qstring3.as_bytes()); + assert_eq!( + res0.unwrap().1, + TableObject { + table: Table { + name: "table1".to_string(), + alias: None, + schema: Some("schema1".to_string()) + }, + partitions: TablePartitionList(vec!["a".into(),]) + } + ); + assert_eq!( + res1.unwrap().1, + TableObject { + table: Table { + name: "table1".to_string(), + alias: None, + schema: Some("schema1".to_string()) + }, + partitions: TablePartitionList(vec!["a".into(), "b1_long".into(), "a2b".into(),]) + } + ); + assert_eq!( + res2.unwrap().1, + TableObject { + table: Table { + name: "table1".to_string(), + alias: Some("t1".to_string()), + schema: Some("schema1".to_string()), + }, + partitions: TablePartitionList(vec!["a".into(),]) + } + ); + assert_eq!( + res3.unwrap().1, + TableObject { + table: Table { + name: "table1".to_string(), + alias: Some("t1".to_string()), + schema: Some("schema1".to_string()), + }, + partitions: TablePartitionList(vec!["a".into(), "b1_long".into(), "a2b".into(),]), + } + ); + } +} diff --git a/src/update.rs b/src/update.rs index 7366676..2296c8a 100644 --- a/src/update.rs +++ b/src/update.rs @@ -2,26 +2,25 @@ use nom::character::complete::{multispace0, multispace1}; use std::{fmt, str}; use column::Column; -use common::{assignment_expr_list, statement_terminator, table_reference, FieldValueExpression}; +use common::{assignment_expr_list, statement_terminator, table_object, FieldValueExpression}; use condition::ConditionExpression; -use keywords::escape_if_keyword; use nom::bytes::complete::tag_no_case; use nom::combinator::opt; use nom::sequence::tuple; use nom::IResult; use select::where_clause; -use table::Table; +use table::TableObject; #[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize, Deserialize)] pub struct UpdateStatement { - pub table: Table, + pub table: TableObject, pub fields: Vec<(Column, FieldValueExpression)>, pub where_clause: Option, } impl fmt::Display for UpdateStatement { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "UPDATE {} ", escape_if_keyword(&self.table.name))?; + write!(f, "UPDATE {} ", self.table)?; assert!(self.fields.len() > 0); write!( f, @@ -44,7 +43,7 @@ pub fn updating(i: &[u8]) -> IResult<&[u8], UpdateStatement> { let (remaining_input, (_, _, table, _, _, _, fields, _, where_clause, _)) = tuple(( tag_no_case("update"), multispace1, - table_reference, + table_object, multispace1, tag_no_case("set"), multispace1, @@ -72,7 +71,7 @@ mod tests { use condition::ConditionBase::*; use condition::ConditionExpression::*; use condition::ConditionTree; - use table::Table; + use table::{TableObject, TablePartitionList}; #[test] fn simple_update() { @@ -82,7 +81,7 @@ mod tests { assert_eq!( res.unwrap().1, UpdateStatement { - table: Table::from("users"), + table: TableObject::from("users"), fields: vec![ ( Column::from("id"), @@ -114,7 +113,7 @@ mod tests { assert_eq!( res.unwrap().1, UpdateStatement { - table: Table::from("users"), + table: TableObject::from("users"), fields: vec![ ( Column::from("id"), @@ -157,7 +156,7 @@ mod tests { assert_eq!( res.unwrap().1, UpdateStatement { - table: Table::from("stories"), + table: TableObject::from("stories"), fields: vec![( Column::from("hotness"), FieldValueExpression::Literal(LiteralExpression::from(Literal::FixedPoint( @@ -194,7 +193,7 @@ mod tests { assert_eq!( res.unwrap().1, UpdateStatement { - table: Table::from("users"), + table: TableObject::from("users"), fields: vec![( Column::from("karma"), FieldValueExpression::Arithmetic(expected_ae), @@ -219,7 +218,7 @@ mod tests { assert_eq!( res.unwrap().1, UpdateStatement { - table: Table::from("users"), + table: TableObject::from("users"), fields: vec![( Column::from("karma"), FieldValueExpression::Arithmetic(expected_ae), @@ -228,4 +227,33 @@ mod tests { } ); } + + #[test] + fn update_with_partitions() { + let qstring = "UPDATE users PARTITION (u) SET id = 42, name = 'test'"; + + let res = updating(qstring.as_bytes()); + assert_eq!( + res.unwrap().1, + UpdateStatement { + table: TableObject { + table: "users".into(), + partitions: TablePartitionList(vec!["u".into(),]), + }, + fields: vec![ + ( + Column::from("id"), + FieldValueExpression::Literal(LiteralExpression::from(Literal::from(42))), + ), + ( + Column::from("name"), + FieldValueExpression::Literal(LiteralExpression::from(Literal::from( + "test", + ))), + ), + ], + ..Default::default() + } + ); + } }