diff --git a/src/arithmetic.rs b/src/arithmetic.rs
index 470b3eb..90a3943 100644
--- a/src/arithmetic.rs
+++ b/src/arithmetic.rs
@@ -459,10 +459,13 @@ mod tests {
}
#[test]
- fn arithmetic_scalar(){
+ fn arithmetic_scalar() {
let qs = "56";
let res = arithmetic(qs.as_bytes());
assert!(res.is_err());
- assert_eq!(nom::Err::Error(nom::error::Error::new(qs.as_bytes(), ErrorKind::Tag)), res.err().unwrap());
+ assert_eq!(
+ nom::Err::Error(nom::error::Error::new(qs.as_bytes(), ErrorKind::Tag)),
+ res.err().unwrap()
+ );
}
}
diff --git a/src/common.rs b/src/common.rs
index 1d6cd4d..a40a045 100644
--- a/src/common.rs
+++ b/src/common.rs
@@ -1,7 +1,7 @@
use nom::branch::alt;
use nom::character::complete::{alphanumeric1, digit1, line_ending, multispace0, multispace1};
use nom::character::is_alphanumeric;
-use nom::combinator::{map, not, peek};
+use nom::combinator::{into, map, not, peek};
use nom::{IResult, InputLength, Parser};
use std::fmt::{self, Display};
use std::str;
@@ -16,7 +16,9 @@ use nom::combinator::opt;
use nom::error::{ErrorKind, ParseError};
use nom::multi::{fold_many0, many0, many1, separated_list0};
use nom::sequence::{delimited, pair, preceded, separated_pair, terminated, tuple};
-use table::Table;
+use select::join_clause;
+use table::{Table, TableObject, TablePartition, TablePartitionList};
+use JoinClause;
#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub enum SqlType {
@@ -388,7 +390,7 @@ where
let (inp, _) = first.parse(inp)?;
let (inp, o2) = second.parse(inp)?;
third.parse(inp).map(|(i, _)| (i, o2))
- },
+ }
}
}
}
@@ -641,7 +643,8 @@ pub fn function_argument_parser(i: &[u8]) -> IResult<&[u8], FunctionArgument> {
// present.
pub fn function_arguments(i: &[u8]) -> IResult<&[u8], (FunctionArgument, bool)> {
let distinct_parser = opt(tuple((tag_no_case("distinct"), multispace1)));
- let (remaining_input, (distinct, args)) = tuple((distinct_parser, function_argument_parser))(i)?;
+ let (remaining_input, (distinct, args)) =
+ tuple((distinct_parser, function_argument_parser))(i)?;
Ok((remaining_input, (args, distinct.is_some())))
}
@@ -695,12 +698,25 @@ pub fn column_function(i: &[u8]) -> IResult<&[u8], FunctionExpression> {
FunctionExpression::GroupConcat(FunctionArgument::Column(col.clone()), sep)
},
),
- map(tuple((sql_identifier, multispace0, tag("("), separated_list0(tag(","), delimited(multispace0, function_argument_parser, multispace0)), tag(")"))), |tuple| {
- let (name, _, _, arguments, _) = tuple;
- FunctionExpression::Generic(
- str::from_utf8(name).unwrap().to_string(),
- FunctionArguments::from(arguments))
- })
+ map(
+ tuple((
+ sql_identifier,
+ multispace0,
+ tag("("),
+ separated_list0(
+ tag(","),
+ delimited(multispace0, function_argument_parser, multispace0),
+ ),
+ tag(")"),
+ )),
+ |tuple| {
+ let (name, _, _, arguments, _) = tuple;
+ FunctionExpression::Generic(
+ str::from_utf8(name).unwrap().to_string(),
+ FunctionArguments::from(arguments),
+ )
+ },
+ ),
))(i)
}
@@ -893,6 +909,51 @@ pub fn table_list(i: &[u8]) -> IResult<&[u8], Vec
> {
many0(terminated(schema_table_reference, opt(ws_sep_comma)))(i)
}
+pub fn table_object_list(i: &[u8]) -> IResult<&[u8], Vec> {
+ many0(terminated(table_object, opt(ws_sep_comma)))(i)
+}
+
+pub fn table_object(i: &[u8]) -> IResult<&[u8], TableObject> {
+ map(
+ tuple((
+ schema_table_reference,
+ opt(preceded(multispace0, table_partition_list)),
+ )),
+ |(table, pl)| TableObject {
+ table,
+ partitions: pl.unwrap_or(vec![].into()),
+ },
+ )(i)
+}
+
+pub fn table_partition_list(i: &[u8]) -> IResult<&[u8], TablePartitionList> {
+ let (remaining, (_, _, p, _, _)) = tuple((
+ tag_no_case("partition ("),
+ multispace0,
+ into(many1(terminated(
+ table_partition,
+ opt(tuple((multispace0, tag(","), multispace0))),
+ ))),
+ multispace0,
+ tag(")"),
+ ))(i)?;
+
+ Ok((remaining, p))
+}
+
+pub fn table_partition(i: &[u8]) -> IResult<&[u8], TablePartition> {
+ into(sql_identifier)(i)
+}
+
+pub fn from_clause(i: &[u8]) -> IResult<&[u8], Vec> {
+ let (i, (_, t)) = tuple((
+ delimited(multispace0, tag_no_case("from"), multispace0),
+ many1(terminated(table_object, opt(ws_sep_comma))),
+ ))(i)?;
+
+ Ok((i, t))
+}
+
// Integer literal value
pub fn integer_literal(i: &[u8]) -> IResult<&[u8], Literal> {
map(pair(opt(tag("-")), digit1), |tup| {
@@ -1018,25 +1079,44 @@ pub fn value_list(i: &[u8]) -> IResult<&[u8], Vec> {
many0(delimited(multispace0, literal, opt(ws_sep_comma)))(i)
}
+pub fn relational_objects_clauses(i: &[u8]) -> IResult<&[u8], (Vec, Vec)> {
+ match from_clause(i) {
+ Ok((i, f)) => {
+ let (i, j) = many0(join_clause)(i).unwrap_or((i, vec![]));
+
+ Ok((i, (f, j)))
+ }
+ Err(e) => {
+ if join_clause(i).is_ok() {
+ //TODO: needs a more helpful error once error handling is improved upon
+ Err(e)
+ } else {
+ Ok((i, (vec![], vec![])))
+ }
+ }
+ }
+}
+
// Parse a reference to a named schema.table, with an optional alias
pub fn schema_table_reference(i: &[u8]) -> IResult<&[u8], Table> {
map(
- tuple((
- opt(pair(sql_identifier, tag("."))),
- sql_identifier,
- opt(as_alias)
- )),
- |tup| Table {
- name: String::from(str::from_utf8(tup.1).unwrap()),
- alias: match tup.2 {
- Some(a) => Some(String::from(a)),
- None => None,
- },
- schema: match tup.0 {
- Some((schema, _)) => Some(String::from(str::from_utf8(schema).unwrap())),
- None => None,
+ tuple((
+ opt(pair(sql_identifier, tag("."))),
+ sql_identifier,
+ opt(as_alias),
+ )),
+ |tup| Table {
+ name: String::from(str::from_utf8(tup.1).unwrap()),
+ alias: match tup.2 {
+ Some(a) => Some(String::from(a)),
+ None => None,
+ },
+ schema: match tup.0 {
+ Some((schema, _)) => Some(String::from(str::from_utf8(schema).unwrap())),
+ None => None,
+ },
},
- })(i)
+ )(i)
}
// Parse a reference to a named table, with an optional alias
@@ -1047,7 +1127,7 @@ pub fn table_reference(i: &[u8]) -> IResult<&[u8], Table> {
Some(a) => Some(String::from(a)),
None => None,
},
- schema: None,
+ schema: None,
})(i)
}
@@ -1137,25 +1217,31 @@ mod tests {
name: String::from("max(addr_id)"),
alias: None,
table: None,
- function: Some(Box::new(FunctionExpression::Max(
- FunctionArgument::Column(Column::from("addr_id")),
- ))),
+ function: Some(Box::new(FunctionExpression::Max(FunctionArgument::Column(
+ Column::from("addr_id"),
+ )))),
};
assert_eq!(res.unwrap().1, expected);
}
#[test]
fn simple_generic_function() {
- let qlist = ["coalesce(a,b,c)".as_bytes(), "coalesce (a,b,c)".as_bytes(), "coalesce(a ,b,c)".as_bytes(), "coalesce(a, b,c)".as_bytes()];
+ let qlist = [
+ "coalesce(a,b,c)".as_bytes(),
+ "coalesce (a,b,c)".as_bytes(),
+ "coalesce(a ,b,c)".as_bytes(),
+ "coalesce(a, b,c)".as_bytes(),
+ ];
for q in qlist.iter() {
let res = column_function(q);
- let expected = FunctionExpression::Generic("coalesce".to_string(),
- FunctionArguments::from(
- vec!(
- FunctionArgument::Column(Column::from("a")),
- FunctionArgument::Column(Column::from("b")),
- FunctionArgument::Column(Column::from("c"))
- )));
+ let expected = FunctionExpression::Generic(
+ "coalesce".to_string(),
+ FunctionArguments::from(vec![
+ FunctionArgument::Column(Column::from("a")),
+ FunctionArgument::Column(Column::from("b")),
+ FunctionArgument::Column(Column::from("c")),
+ ]),
+ );
assert_eq!(res, Ok((&b""[..], expected)));
}
}
diff --git a/src/compound_select.rs b/src/compound_select.rs
index d3ece89..764c40e 100644
--- a/src/compound_select.rs
+++ b/src/compound_select.rs
@@ -133,7 +133,7 @@ mod tests {
use super::*;
use column::Column;
use common::{FieldDefinitionExpression, FieldValueExpression, Literal};
- use table::Table;
+ use table::TableObject;
#[test]
fn union() {
@@ -143,7 +143,7 @@ mod tests {
let res2 = compound_selection(qstr2.as_bytes());
let first_select = SelectStatement {
- tables: vec![Table::from("Vote")],
+ tables: vec![TableObject::from("Vote")],
fields: vec![
FieldDefinitionExpression::Col(Column::from("id")),
FieldDefinitionExpression::Value(FieldValueExpression::Literal(
@@ -153,7 +153,7 @@ mod tests {
..Default::default()
};
let second_select = SelectStatement {
- tables: vec![Table::from("Rating")],
+ tables: vec![TableObject::from("Rating")],
fields: vec![
FieldDefinitionExpression::Col(Column::from("id")),
FieldDefinitionExpression::Col(Column::from("stars")),
@@ -185,12 +185,18 @@ mod tests {
assert!(&res.is_err());
assert_eq!(
res.unwrap_err(),
- nom::Err::Error(nom::error::Error::new(");".as_bytes(), nom::error::ErrorKind::Tag))
+ nom::Err::Error(nom::error::Error::new(
+ ");".as_bytes(),
+ nom::error::ErrorKind::Tag
+ ))
);
assert!(&res2.is_err());
assert_eq!(
res2.unwrap_err(),
- nom::Err::Error(nom::error::Error::new(";".as_bytes(), nom::error::ErrorKind::Tag))
+ nom::Err::Error(nom::error::Error::new(
+ ";".as_bytes(),
+ nom::error::ErrorKind::Tag
+ ))
);
assert!(&res3.is_err());
assert_eq!(
@@ -210,7 +216,7 @@ mod tests {
let res = compound_selection(qstr.as_bytes());
let first_select = SelectStatement {
- tables: vec![Table::from("Vote")],
+ tables: vec![TableObject::from("Vote")],
fields: vec![
FieldDefinitionExpression::Col(Column::from("id")),
FieldDefinitionExpression::Value(FieldValueExpression::Literal(
@@ -220,7 +226,7 @@ mod tests {
..Default::default()
};
let second_select = SelectStatement {
- tables: vec![Table::from("Rating")],
+ tables: vec![TableObject::from("Rating")],
fields: vec![
FieldDefinitionExpression::Col(Column::from("id")),
FieldDefinitionExpression::Col(Column::from("stars")),
@@ -228,7 +234,7 @@ mod tests {
..Default::default()
};
let third_select = SelectStatement {
- tables: vec![Table::from("Vote")],
+ tables: vec![TableObject::from("Vote")],
fields: vec![
FieldDefinitionExpression::Value(FieldValueExpression::Literal(
Literal::Integer(42).into(),
@@ -259,7 +265,7 @@ mod tests {
let res = compound_selection(qstr.as_bytes());
let first_select = SelectStatement {
- tables: vec![Table::from("Vote")],
+ tables: vec![TableObject::from("Vote")],
fields: vec![
FieldDefinitionExpression::Col(Column::from("id")),
FieldDefinitionExpression::Value(FieldValueExpression::Literal(
@@ -269,7 +275,7 @@ mod tests {
..Default::default()
};
let second_select = SelectStatement {
- tables: vec![Table::from("Rating")],
+ tables: vec![TableObject::from("Rating")],
fields: vec![
FieldDefinitionExpression::Col(Column::from("id")),
FieldDefinitionExpression::Col(Column::from("stars")),
diff --git a/src/condition.rs b/src/condition.rs
index a210b7a..c18497a 100644
--- a/src/condition.rs
+++ b/src/condition.rs
@@ -291,10 +291,7 @@ fn predicate(i: &[u8]) -> IResult<&[u8], ConditionExpression> {
},
);
- alt((
- simple_expr,
- nested_exists,
- ))(i)
+ alt((simple_expr, nested_exists))(i)
}
fn simple_expr(i: &[u8]) -> IResult<&[u8], ConditionExpression> {
@@ -332,6 +329,7 @@ mod tests {
use arithmetic::{ArithmeticBase, ArithmeticOperator};
use column::Column;
use common::{FieldDefinitionExpression, ItemPlaceholder, Literal, Operator};
+ use table::TableObject;
fn columns(cols: &[&str]) -> Vec {
cols.iter()
@@ -713,7 +711,6 @@ mod tests {
fn nested_select() {
use select::SelectStatement;
use std::default::Default;
- use table::Table;
use ConditionBase::*;
let cond = "bar in (select col from foo)";
@@ -721,7 +718,7 @@ mod tests {
let res = condition_expr(cond.as_bytes());
let nested_select = Box::new(SelectStatement {
- tables: vec![Table::from("foo")],
+ tables: vec![TableObject::from("foo")],
fields: columns(&["col"]),
..Default::default()
});
@@ -739,14 +736,13 @@ mod tests {
fn exists_in_select() {
use select::SelectStatement;
use std::default::Default;
- use table::Table;
let cond = "exists ( select col from foo )";
let res = condition_expr(cond.as_bytes());
let nested_select = Box::new(SelectStatement {
- tables: vec![Table::from("foo")],
+ tables: vec![TableObject::from("foo")],
fields: columns(&["col"]),
..Default::default()
});
@@ -760,14 +756,13 @@ mod tests {
fn not_exists_in_select() {
use select::SelectStatement;
use std::default::Default;
- use table::Table;
let cond = "not exists (select col from foo)";
let res = condition_expr(cond.as_bytes());
let nested_select = Box::new(SelectStatement {
- tables: vec![Table::from("foo")],
+ tables: vec![TableObject::from("foo")],
fields: columns(&["col"]),
..Default::default()
});
@@ -782,7 +777,6 @@ mod tests {
fn and_with_nested_select() {
use select::SelectStatement;
use std::default::Default;
- use table::Table;
use ConditionBase::*;
let cond = "paperId in (select paperId from PaperConflict) and size > 0";
@@ -790,7 +784,7 @@ mod tests {
let res = condition_expr(cond.as_bytes());
let nested_select = Box::new(SelectStatement {
- tables: vec![Table::from("PaperConflict")],
+ tables: vec![TableObject::from("PaperConflict")],
fields: columns(&["paperId"]),
..Default::default()
});
diff --git a/src/create.rs b/src/create.rs
index a1b5afc..39adc4f 100644
--- a/src/create.rs
+++ b/src/create.rs
@@ -5,8 +5,8 @@ use std::str::FromStr;
use column::{Column, ColumnConstraint, ColumnSpecification};
use common::{
- column_identifier_no_alias, parse_comment, sql_identifier, statement_terminator,
- schema_table_reference, type_identifier, ws_sep_comma, Literal, Real, SqlType, TableKey,
+ column_identifier_no_alias, parse_comment, schema_table_reference, sql_identifier,
+ statement_terminator, type_identifier, ws_sep_comma, Literal, Real, SqlType, TableKey,
};
use compound_select::{compound_selection, CompoundSelectStatement};
use create_table_options::table_options;
@@ -534,7 +534,7 @@ mod tests {
assert_eq!(
res.unwrap().1,
CreateTableStatement {
- table: Table::from(("db1","t")),
+ table: Table::from(("db1", "t")),
fields: vec![ColumnSpecification::new(
Column::from("t.x"),
SqlType::Int(32)
@@ -795,7 +795,7 @@ mod tests {
name: String::from("v"),
fields: vec![],
definition: Box::new(SelectSpecification::Simple(SelectStatement {
- tables: vec![Table::from("users")],
+ tables: vec![Table::from("users").into()],
fields: vec![FieldDefinitionExpression::All],
where_clause: Some(ConditionExpression::ComparisonOp(ConditionTree {
left: Box::new(ConditionExpression::Base(ConditionBase::Field(
@@ -830,7 +830,7 @@ mod tests {
(
None,
SelectStatement {
- tables: vec![Table::from("users")],
+ tables: vec![Table::from("users").into()],
fields: vec![FieldDefinitionExpression::All],
..Default::default()
},
@@ -838,7 +838,7 @@ mod tests {
(
Some(CompoundSelectOperator::DistinctUnion),
SelectStatement {
- tables: vec![Table::from("old_users")],
+ tables: vec![Table::from("old_users").into()],
fields: vec![FieldDefinitionExpression::All],
..Default::default()
},
diff --git a/src/delete.rs b/src/delete.rs
index 64ee1dc..0b1ee9b 100644
--- a/src/delete.rs
+++ b/src/delete.rs
@@ -1,26 +1,25 @@
use nom::character::complete::multispace1;
use std::{fmt, str};
-use common::{statement_terminator, schema_table_reference};
+use common::{statement_terminator, table_object};
use condition::ConditionExpression;
-use keywords::escape_if_keyword;
use nom::bytes::complete::tag_no_case;
use nom::combinator::opt;
use nom::sequence::{delimited, tuple};
use nom::IResult;
use select::where_clause;
-use table::Table;
+use table::TableObject;
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub struct DeleteStatement {
- pub table: Table,
+ pub table: TableObject,
pub where_clause: Option,
}
impl fmt::Display for DeleteStatement {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "DELETE FROM ")?;
- write!(f, "{}", escape_if_keyword(&self.table.name))?;
+ write!(f, "{}", self.table)?;
if let Some(ref where_clause) = self.where_clause {
write!(f, " WHERE ")?;
write!(f, "{}", where_clause)?;
@@ -33,7 +32,7 @@ pub fn deletion(i: &[u8]) -> IResult<&[u8], DeleteStatement> {
let (remaining_input, (_, _, table, where_clause, _)) = tuple((
tag_no_case("delete"),
delimited(multispace1, tag_no_case("from"), multispace1),
- schema_table_reference,
+ table_object,
opt(where_clause),
statement_terminator,
))(i)?;
@@ -55,7 +54,7 @@ mod tests {
use condition::ConditionBase::*;
use condition::ConditionExpression::*;
use condition::ConditionTree;
- use table::Table;
+ use table::{TableObject, TablePartitionList};
#[test]
fn simple_delete() {
@@ -64,7 +63,7 @@ mod tests {
assert_eq!(
res.unwrap().1,
DeleteStatement {
- table: Table::from("users"),
+ table: TableObject::from("users"),
..Default::default()
}
);
@@ -77,7 +76,7 @@ mod tests {
assert_eq!(
res.unwrap().1,
DeleteStatement {
- table: Table::from(("db1","users")),
+ table: TableObject::from(("db1", "users")),
..Default::default()
}
);
@@ -96,7 +95,30 @@ mod tests {
assert_eq!(
res.unwrap().1,
DeleteStatement {
- table: Table::from("users"),
+ table: TableObject::from("users"),
+ where_clause: expected_where_cond,
+ ..Default::default()
+ }
+ );
+ }
+
+ #[test]
+ fn delete_with_partition() {
+ let qstring = "DELETE FROM users PARTITION (u) WHERE id = 1;";
+ let res = deletion(qstring.as_bytes());
+ let expected_left = Base(Field(Column::from("id")));
+ let expected_where_cond = Some(ComparisonOp(ConditionTree {
+ left: Box::new(expected_left),
+ right: Box::new(Base(Literal(Literal::Integer(1)))),
+ operator: Operator::Equal,
+ }));
+ assert_eq!(
+ res.unwrap().1,
+ DeleteStatement {
+ table: TableObject {
+ table: "users".into(),
+ partitions: TablePartitionList(vec!["u".into()]),
+ },
where_clause: expected_where_cond,
..Default::default()
}
diff --git a/src/insert.rs b/src/insert.rs
index 9f55107..483686f 100644
--- a/src/insert.rs
+++ b/src/insert.rs
@@ -4,20 +4,19 @@ use std::str;
use column::Column;
use common::{
- assignment_expr_list, field_list, statement_terminator, schema_table_reference, value_list,
- ws_sep_comma, FieldValueExpression, Literal,
+ assignment_expr_list, field_list, statement_terminator, table_object, value_list, ws_sep_comma,
+ FieldValueExpression, Literal,
};
-use keywords::escape_if_keyword;
use nom::bytes::complete::{tag, tag_no_case};
use nom::combinator::opt;
use nom::multi::many1;
use nom::sequence::{delimited, preceded, tuple};
use nom::IResult;
-use table::Table;
+use table::TableObject;
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub struct InsertStatement {
- pub table: Table,
+ pub table: TableObject,
pub fields: Option>,
pub data: Vec>,
pub ignore: bool,
@@ -26,7 +25,7 @@ pub struct InsertStatement {
impl fmt::Display for InsertStatement {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
- write!(f, "INSERT INTO {}", escape_if_keyword(&self.table.name))?;
+ write!(f, "INSERT INTO {}", self.table)?;
if let Some(ref fields) = self.fields {
write!(
f,
@@ -89,7 +88,7 @@ pub fn insertion(i: &[u8]) -> IResult<&[u8], InsertStatement> {
multispace1,
tag_no_case("into"),
multispace1,
- schema_table_reference,
+ table_object,
multispace0,
opt(fields),
tag_no_case("values"),
@@ -98,7 +97,7 @@ pub fn insertion(i: &[u8]) -> IResult<&[u8], InsertStatement> {
opt(on_duplicate),
statement_terminator,
))(i)?;
- assert!(table.alias.is_none());
+ assert!(table.table.alias.is_none());
let ignore = ignore_res.is_some();
Ok((
@@ -119,7 +118,7 @@ mod tests {
use arithmetic::{ArithmeticBase, ArithmeticExpression, ArithmeticOperator};
use column::Column;
use common::ItemPlaceholder;
- use table::Table;
+ use table::{TableObject, TablePartitionList};
#[test]
fn simple_insert() {
@@ -129,7 +128,7 @@ mod tests {
assert_eq!(
res.unwrap().1,
InsertStatement {
- table: Table::from("users"),
+ table: TableObject::from("users"),
fields: None,
data: vec![vec![42.into(), "test".into()]],
..Default::default()
@@ -145,7 +144,7 @@ mod tests {
assert_eq!(
res.unwrap().1,
InsertStatement {
- table: Table::from(("db1","users")),
+ table: TableObject::from(("db1", "users")),
fields: None,
data: vec![vec![42.into(), "test".into()]],
..Default::default()
@@ -161,7 +160,7 @@ mod tests {
assert_eq!(
res.unwrap().1,
InsertStatement {
- table: Table::from("users"),
+ table: TableObject::from("users"),
fields: None,
data: vec![vec![
42.into(),
@@ -182,7 +181,7 @@ mod tests {
assert_eq!(
res.unwrap().1,
InsertStatement {
- table: Table::from("users"),
+ table: TableObject::from("users"),
fields: Some(vec![Column::from("id"), Column::from("name")]),
data: vec![vec![42.into(), "test".into()]],
..Default::default()
@@ -199,7 +198,7 @@ mod tests {
assert_eq!(
res.unwrap().1,
InsertStatement {
- table: Table::from("users"),
+ table: TableObject::from("users"),
fields: Some(vec![Column::from("id"), Column::from("name")]),
data: vec![vec![42.into(), "test".into()]],
..Default::default()
@@ -215,7 +214,7 @@ mod tests {
assert_eq!(
res.unwrap().1,
InsertStatement {
- table: Table::from("users"),
+ table: TableObject::from("users"),
fields: Some(vec![Column::from("id"), Column::from("name")]),
data: vec![
vec![42.into(), "test".into()],
@@ -234,7 +233,7 @@ mod tests {
assert_eq!(
res.unwrap().1,
InsertStatement {
- table: Table::from("users"),
+ table: TableObject::from("users"),
fields: Some(vec![Column::from("id"), Column::from("name")]),
data: vec![vec![
Literal::Placeholder(ItemPlaceholder::QuestionMark),
@@ -260,7 +259,7 @@ mod tests {
assert_eq!(
res.unwrap().1,
InsertStatement {
- table: Table::from("keystores"),
+ table: TableObject::from("keystores"),
fields: Some(vec![Column::from("key"), Column::from("value")]),
data: vec![vec![
Literal::Placeholder(ItemPlaceholder::DollarNumber(1)),
@@ -283,7 +282,26 @@ mod tests {
assert_eq!(
res.unwrap().1,
InsertStatement {
- table: Table::from("users"),
+ table: TableObject::from("users"),
+ fields: Some(vec![Column::from("id"), Column::from("name")]),
+ data: vec![vec![42.into(), "test".into()]],
+ ..Default::default()
+ }
+ );
+ }
+
+ #[test]
+ fn insert_with_partitions() {
+ let qstring = "INSERT INTO users PARTITION (u) (id, name) VALUES (42, \"test\");";
+
+ let res = insertion(qstring.as_bytes());
+ assert_eq!(
+ res.unwrap().1,
+ InsertStatement {
+ table: TableObject {
+ table: "users".into(),
+ partitions: TablePartitionList(vec!["u".into()]),
+ },
fields: Some(vec![Column::from("id"), Column::from("name")]),
data: vec![vec![42.into(), "test".into()]],
..Default::default()
diff --git a/src/join.rs b/src/join.rs
index b91f5dc..83fe2d3 100644
--- a/src/join.rs
+++ b/src/join.rs
@@ -8,14 +8,14 @@ use nom::bytes::complete::tag_no_case;
use nom::combinator::map;
use nom::IResult;
use select::{JoinClause, SelectStatement};
-use table::Table;
+use table::TableObject;
#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub enum JoinRightSide {
/// A single table.
- Table(Table),
+ Table(TableObject),
/// A comma-separated (and implicitly joined) sequence of tables.
- Tables(Vec),
+ Tables(Vec),
/// A nested selection, represented as (query, alias).
NestedSelect(Box, Option),
/// A nested join clause.
@@ -112,6 +112,7 @@ mod tests {
use condition::ConditionExpression::{self, *};
use condition::ConditionTree;
use select::{selection, JoinClause, SelectStatement};
+ use table::TablePartitionList;
#[test]
fn inner_join() {
@@ -127,13 +128,50 @@ mod tests {
};
let join_cond = ConditionExpression::ComparisonOp(ct);
let expected_stmt = SelectStatement {
- tables: vec![Table::from("tags")],
+ tables: vec![TableObject::from("tags")],
+ join: vec![JoinClause {
+ operator: JoinOperator::InnerJoin,
+ right: JoinRightSide::Table(TableObject::from("taggings")),
+ constraint: JoinConstraint::On(join_cond),
+ }],
fields: vec![FieldDefinitionExpression::AllInTable("tags".into())],
+ ..Default::default()
+ };
+
+ let q = res.unwrap().1;
+ assert_eq!(q, expected_stmt);
+ assert_eq!(qstring, format!("{}", q));
+ }
+
+ #[test]
+ fn partitioned_inner_join() {
+ let qstring = "SELECT tags.* FROM tags PARTITION (t, a, g) \
+ INNER JOIN taggings PARTITION (t, a, g) ON tags.id = taggings.tag_id";
+
+ let res = selection(qstring.as_bytes());
+
+ let expected_pl = TablePartitionList(vec!["t".into(), "a".into(), "g".into()]);
+
+ let ct = ConditionTree {
+ left: Box::new(Base(Field(Column::from("tags.id")))),
+ right: Box::new(Base(Field(Column::from("taggings.tag_id")))),
+ operator: Operator::Equal,
+ };
+ let join_cond = ConditionExpression::ComparisonOp(ct);
+ let expected_stmt = SelectStatement {
+ tables: vec![TableObject {
+ table: "tags".into(),
+ partitions: expected_pl.clone(),
+ }],
join: vec![JoinClause {
operator: JoinOperator::InnerJoin,
- right: JoinRightSide::Table(Table::from("taggings")),
+ right: JoinRightSide::Table(TableObject {
+ table: "taggings".into(),
+ partitions: expected_pl,
+ }),
constraint: JoinConstraint::On(join_cond),
}],
+ fields: vec![FieldDefinitionExpression::AllInTable("tags".into())],
..Default::default()
};
diff --git a/src/keywords.rs b/src/keywords.rs
index 88612de..582efb3 100644
--- a/src/keywords.rs
+++ b/src/keywords.rs
@@ -127,11 +127,12 @@ fn keyword_i_to_o(i: &[u8]) -> IResult<&[u8], &[u8]> {
))(i)
}
-fn keyword_o_to_s(i: &[u8]) -> IResult<&[u8], &[u8]> {
+fn keyword_o_to_r(i: &[u8]) -> IResult<&[u8], &[u8]> {
alt((
terminated(tag_no_case("ON"), keyword_follow_char),
terminated(tag_no_case("OR"), keyword_follow_char),
terminated(tag_no_case("OUTER"), keyword_follow_char),
+ terminated(tag_no_case("PARTITION"), keyword_follow_char),
terminated(tag_no_case("PLAN"), keyword_follow_char),
terminated(tag_no_case("PRAGMA"), keyword_follow_char),
terminated(tag_no_case("PRIMARY"), keyword_follow_char),
@@ -139,6 +140,11 @@ fn keyword_o_to_s(i: &[u8]) -> IResult<&[u8], &[u8]> {
terminated(tag_no_case("RAISE"), keyword_follow_char),
terminated(tag_no_case("RECURSIVE"), keyword_follow_char),
terminated(tag_no_case("REFERENCES"), keyword_follow_char),
+ ))(i)
+}
+
+fn keyword_r_to_s(i: &[u8]) -> IResult<&[u8], &[u8]> {
+ alt((
terminated(tag_no_case("REGEXP"), keyword_follow_char),
terminated(tag_no_case("REINDEX"), keyword_follow_char),
terminated(tag_no_case("RELEASE"), keyword_follow_char),
@@ -185,7 +191,8 @@ pub fn sql_keyword(i: &[u8]) -> IResult<&[u8], &[u8]> {
keyword_c_to_e,
keyword_e_to_i,
keyword_i_to_o,
- keyword_o_to_s,
+ keyword_o_to_r,
+ keyword_r_to_s,
keyword_s_to_z,
))(i)
}
diff --git a/src/lib.rs b/src/lib.rs
index da4d5d1..7c8a538 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -7,6 +7,7 @@ extern crate serde_derive;
#[cfg(test)]
#[macro_use]
extern crate pretty_assertions;
+extern crate core;
pub use self::arithmetic::{ArithmeticBase, ArithmeticExpression, ArithmeticOperator};
pub use self::case::{CaseWhenExpression, ColumnOrLiteral};
diff --git a/src/parser.rs b/src/parser.rs
index ae68592..e4a43e4 100644
--- a/src/parser.rs
+++ b/src/parser.rs
@@ -79,7 +79,7 @@ mod tests {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
- use table::Table;
+ use table::TableObject;
#[test]
fn hash_query() {
@@ -88,7 +88,7 @@ mod tests {
assert!(res.is_ok());
let expected = SqlQuery::Insert(InsertStatement {
- table: Table::from("users"),
+ table: TableObject::from("users"),
fields: None,
data: vec![vec![42.into(), "test".into()]],
..Default::default()
diff --git a/src/select.rs b/src/select.rs
index 8735e95..8d3aa2e 100644
--- a/src/select.rs
+++ b/src/select.rs
@@ -3,21 +3,20 @@ use std::fmt;
use std::str;
use column::Column;
-use common::FieldDefinitionExpression;
use common::{
- as_alias, field_definition_expr, field_list, statement_terminator, table_list, table_reference,
- unsigned_number,
+ as_alias, field_definition_expr, field_list, relational_objects_clauses, statement_terminator,
+ table_object_list, unsigned_number,
};
+use common::{table_object, FieldDefinitionExpression};
use condition::{condition_expr, ConditionExpression};
use join::{join_operator, JoinConstraint, JoinOperator, JoinRightSide};
use nom::branch::alt;
use nom::bytes::complete::{tag, tag_no_case};
use nom::combinator::{map, opt};
-use nom::multi::many0;
use nom::sequence::{delimited, preceded, terminated, tuple};
use nom::IResult;
use order::{order_clause, OrderClause};
-use table::Table;
+use table::TableObject;
#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub struct GroupByClause {
@@ -78,9 +77,9 @@ impl fmt::Display for LimitClause {
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub struct SelectStatement {
- pub tables: Vec,
pub distinct: bool,
pub fields: Vec,
+ pub tables: Vec,
pub join: Vec,
pub where_clause: Option,
pub group_by: Option,
@@ -116,9 +115,11 @@ impl fmt::Display for SelectStatement {
.join(", ")
)?;
}
+
for jc in &self.join {
write!(f, " {}", jc)?;
}
+
if let Some(ref where_clause) = self.where_clause {
write!(f, " WHERE ")?;
write!(f, "{}", where_clause)?;
@@ -217,7 +218,7 @@ fn join_constraint(i: &[u8]) -> IResult<&[u8], JoinConstraint> {
}
// Parse JOIN clause
-fn join_clause(i: &[u8]) -> IResult<&[u8], JoinClause> {
+pub fn join_clause(i: &[u8]) -> IResult<&[u8], JoinClause> {
let (remaining_input, (_, _natural, operator, _, right, _, constraint)) = tuple((
multispace0,
opt(terminated(tag_no_case("natural"), multispace1)),
@@ -249,8 +250,8 @@ fn join_rhs(i: &[u8]) -> IResult<&[u8], JoinRightSide> {
let nested_join = map(delimited(tag("("), join_clause, tag(")")), |nj| {
JoinRightSide::NestedJoin(Box::new(nj))
});
- let table = map(table_reference, |t| JoinRightSide::Table(t));
- let tables = map(delimited(tag("("), table_list, tag(")")), |tables| {
+ let table = map(table_object, |t| JoinRightSide::Table(t));
+ let tables = map(delimited(tag("("), table_object_list, tag(")")), |tables| {
JoinRightSide::Tables(tables)
});
alt((nested_select, nested_join, table, tables))(i)
@@ -276,16 +277,14 @@ pub fn selection(i: &[u8]) -> IResult<&[u8], SelectStatement> {
pub fn nested_selection(i: &[u8]) -> IResult<&[u8], SelectStatement> {
let (
remaining_input,
- (_, _, distinct, _, fields, _, tables, join, where_clause, group_by, order, limit),
+ (_, _, distinct, _, fields, (tables, join), where_clause, group_by, order, limit),
) = tuple((
tag_no_case("select"),
multispace1,
opt(tag_no_case("distinct")),
multispace0,
field_definition_expr,
- delimited(multispace0, tag_no_case("from"), multispace0),
- table_list,
- many0(join_clause),
+ relational_objects_clauses,
opt(where_clause),
opt(group_by_clause),
opt(order_clause),
@@ -295,9 +294,9 @@ pub fn nested_selection(i: &[u8]) -> IResult<&[u8], SelectStatement> {
remaining_input,
SelectStatement {
tables,
+ join,
distinct: distinct.is_some(),
fields,
- join,
where_clause,
group_by,
order,
@@ -318,7 +317,8 @@ mod tests {
use condition::ConditionExpression::*;
use condition::ConditionTree;
use order::OrderType;
- use table::Table;
+ use table::{Table, TablePartitionList};
+ use OrderType::OrderAscending;
fn columns(cols: &[&str]) -> Vec {
cols.iter()
@@ -334,7 +334,7 @@ mod tests {
assert_eq!(
res.unwrap().1,
SelectStatement {
- tables: vec![Table::from("users")],
+ tables: vec![TableObject::from("users")],
fields: columns(&["id", "name"]),
..Default::default()
}
@@ -349,7 +349,7 @@ mod tests {
assert_eq!(
res.unwrap().1,
SelectStatement {
- tables: vec![Table::from("users")],
+ tables: vec![TableObject::from("users")],
fields: columns(&["users.id", "users.name"]),
..Default::default()
}
@@ -368,7 +368,7 @@ mod tests {
assert_eq!(
res.unwrap().1,
SelectStatement {
- tables: vec![Table::from("users")],
+ tables: vec![TableObject::from("users")],
fields: vec![
FieldDefinitionExpression::Value(FieldValueExpression::Literal(
Literal::Null.into(),
@@ -396,7 +396,7 @@ mod tests {
assert_eq!(
res.unwrap().1,
SelectStatement {
- tables: vec![Table::from("users")],
+ tables: vec![TableObject::from("users")],
fields: vec![FieldDefinitionExpression::All],
..Default::default()
}
@@ -411,7 +411,34 @@ mod tests {
assert_eq!(
res.unwrap().1,
SelectStatement {
- tables: vec![Table::from("users"), Table::from("votes")],
+ tables: vec![TableObject::from("users"), TableObject::from("votes")],
+ fields: vec![FieldDefinitionExpression::AllInTable(String::from("users"))],
+ ..Default::default()
+ }
+ );
+ }
+
+ #[test]
+ fn select_all_in_partitioned_tables() {
+ let qstring = "SELECT users.* FROM users PARTITION (a, b), votes PARTITION (a, b);";
+
+ let expected_pl = TablePartitionList(vec!["a".into(), "b".into()]);
+
+ let res = selection(qstring.as_bytes());
+ assert_eq!(
+ res.unwrap().1,
+ SelectStatement {
+ distinct: false,
+ tables: vec![
+ TableObject {
+ table: "users".into(),
+ partitions: expected_pl.clone(),
+ },
+ TableObject {
+ table: "votes".into(),
+ partitions: expected_pl,
+ },
+ ],
fields: vec![FieldDefinitionExpression::AllInTable(String::from("users"))],
..Default::default()
}
@@ -426,7 +453,7 @@ mod tests {
assert_eq!(
res.unwrap().1,
SelectStatement {
- tables: vec![Table::from("users")],
+ tables: vec![TableObject::from("users")],
fields: columns(&["id", "name"]),
..Default::default()
}
@@ -493,7 +520,7 @@ mod tests {
assert_eq!(
res.unwrap().1,
SelectStatement {
- tables: vec![Table::from("ContactInfo")],
+ tables: vec![TableObject::from("ContactInfo")],
fields: vec![FieldDefinitionExpression::All],
where_clause: expected_where_cond,
..Default::default()
@@ -533,8 +560,9 @@ mod tests {
tables: vec![Table {
name: String::from("PaperTag"),
alias: Some(String::from("t")),
- schema: None,
- },],
+ schema: None,
+ }
+ .into()],
fields: vec![FieldDefinitionExpression::All],
..Default::default()
}
@@ -554,8 +582,9 @@ mod tests {
tables: vec![Table {
name: String::from("PaperTag"),
alias: Some(String::from("t")),
- schema: Some(String::from("db1")),
- },],
+ schema: Some(String::from("db1")),
+ }
+ .into(),],
fields: vec![FieldDefinitionExpression::All],
..Default::default()
}
@@ -573,7 +602,7 @@ mod tests {
assert_eq!(
res1.unwrap().1,
SelectStatement {
- tables: vec![Table::from("PaperTag")],
+ tables: vec![TableObject::from("PaperTag")],
fields: vec![FieldDefinitionExpression::Col(Column {
name: String::from("name"),
alias: Some(String::from("TagName")),
@@ -587,7 +616,7 @@ mod tests {
assert_eq!(
res2.unwrap().1,
SelectStatement {
- tables: vec![Table::from("PaperTag")],
+ tables: vec![TableObject::from("PaperTag")],
fields: vec![FieldDefinitionExpression::Col(Column {
name: String::from("name"),
alias: Some(String::from("TagName")),
@@ -608,7 +637,7 @@ mod tests {
assert_eq!(
res1.unwrap().1,
SelectStatement {
- tables: vec![Table::from("PaperTag")],
+ tables: vec![TableObject::from("PaperTag")],
fields: vec![FieldDefinitionExpression::Col(Column {
name: String::from("name"),
alias: Some(String::from("TagName")),
@@ -622,7 +651,7 @@ mod tests {
assert_eq!(
res2.unwrap().1,
SelectStatement {
- tables: vec![Table::from("PaperTag")],
+ tables: vec![TableObject::from("PaperTag")],
fields: vec![FieldDefinitionExpression::Col(Column {
name: String::from("name"),
alias: Some(String::from("TagName")),
@@ -650,7 +679,7 @@ mod tests {
assert_eq!(
res.unwrap().1,
SelectStatement {
- tables: vec![Table::from("PaperTag")],
+ tables: vec![TableObject::from("PaperTag")],
distinct: true,
fields: columns(&["tag"]),
where_clause: expected_where_cond,
@@ -689,7 +718,7 @@ mod tests {
assert_eq!(
res.unwrap().1,
SelectStatement {
- tables: vec![Table::from("PaperStorage")],
+ tables: vec![TableObject::from("PaperStorage")],
fields: columns(&["infoJson"]),
where_clause: expected_where_cond,
..Default::default()
@@ -718,7 +747,7 @@ mod tests {
assert_eq!(
res.unwrap().1,
SelectStatement {
- tables: vec![Table::from("users")],
+ tables: vec![TableObject::from("users")],
fields: vec![FieldDefinitionExpression::All],
where_clause: expected_where_cond,
limit: expected_lim,
@@ -736,7 +765,7 @@ mod tests {
assert_eq!(
res.unwrap().1,
SelectStatement {
- tables: vec![Table::from("address")],
+ tables: vec![TableObject::from("address")],
fields: vec![FieldDefinitionExpression::Col(Column {
name: String::from("max(addr_id)"),
alias: None,
@@ -755,7 +784,7 @@ mod tests {
let res = selection(qstring.as_bytes());
let agg_expr = FunctionExpression::Max(FunctionArgument::Column(Column::from("addr_id")));
let expected_stmt = SelectStatement {
- tables: vec![Table::from("address")],
+ tables: vec![TableObject::from("address")],
fields: vec![FieldDefinitionExpression::Col(Column {
name: String::from("max_addr"),
alias: Some(String::from("max_addr")),
@@ -774,7 +803,7 @@ mod tests {
let res = selection(qstring.as_bytes());
let agg_expr = FunctionExpression::CountStar;
let expected_stmt = SelectStatement {
- tables: vec![Table::from("votes")],
+ tables: vec![TableObject::from("votes")],
fields: vec![FieldDefinitionExpression::Col(Column {
name: String::from("count(*)"),
alias: None,
@@ -798,7 +827,7 @@ mod tests {
let agg_expr =
FunctionExpression::Count(FunctionArgument::Column(Column::from("vote_id")), true);
let expected_stmt = SelectStatement {
- tables: vec![Table::from("votes")],
+ tables: vec![TableObject::from("votes")],
fields: vec![FieldDefinitionExpression::Col(Column {
name: String::from("count(distinct vote_id)"),
alias: None,
@@ -834,7 +863,7 @@ mod tests {
false,
);
let expected_stmt = SelectStatement {
- tables: vec![Table::from("votes")],
+ tables: vec![TableObject::from("votes")],
fields: vec![FieldDefinitionExpression::Col(Column {
name: format!("{}", agg_expr),
alias: None,
@@ -870,7 +899,7 @@ mod tests {
false,
);
let expected_stmt = SelectStatement {
- tables: vec![Table::from("votes")],
+ tables: vec![TableObject::from("votes")],
fields: vec![FieldDefinitionExpression::Col(Column {
name: format!("{}", agg_expr),
alias: None,
@@ -907,7 +936,7 @@ mod tests {
false,
);
let expected_stmt = SelectStatement {
- tables: vec![Table::from("votes")],
+ tables: vec![TableObject::from("votes")],
fields: vec![FieldDefinitionExpression::Col(Column {
name: format!("{}", agg_expr),
alias: None,
@@ -954,7 +983,7 @@ mod tests {
false,
);
let expected_stmt = SelectStatement {
- tables: vec![Table::from("votes")],
+ tables: vec![TableObject::from("votes")],
fields: vec![FieldDefinitionExpression::Col(Column {
name: String::from("votes"),
alias: Some(String::from("votes")),
@@ -1001,7 +1030,7 @@ mod tests {
},
);
let expected_stmt = SelectStatement {
- tables: vec![Table::from("sometable")],
+ tables: vec![TableObject::from("sometable")],
fields: vec![
FieldDefinitionExpression::Col(Column {
name: String::from("x"),
@@ -1045,7 +1074,7 @@ mod tests {
assert_eq!(
res.unwrap().1,
SelectStatement {
- tables: vec![Table::from("item"), Table::from("author")],
+ tables: vec![TableObject::from("item"), TableObject::from("author")],
fields: vec![FieldDefinitionExpression::All],
where_clause: expected_where_cond,
order: Some(OrderClause {
@@ -1066,13 +1095,13 @@ mod tests {
let res = selection(qstring.as_bytes());
let expected_stmt = SelectStatement {
- tables: vec![Table::from("PaperConflict")],
- fields: columns(&["paperId"]),
+ tables: vec![TableObject::from("PaperConflict")],
join: vec![JoinClause {
operator: JoinOperator::Join,
- right: JoinRightSide::Table(Table::from("PCMember")),
+ right: JoinRightSide::Table(TableObject::from("PCMember")),
constraint: JoinConstraint::Using(vec![Column::from("contactId")]),
}],
+ fields: columns(&["paperId"]),
..Default::default()
};
assert_eq!(res.unwrap().1, expected_stmt);
@@ -1094,13 +1123,13 @@ mod tests {
};
let join_cond = ConditionExpression::ComparisonOp(ct);
let expected = SelectStatement {
- tables: vec![Table::from("PCMember")],
- fields: columns(&["PCMember.contactId"]),
+ tables: vec![TableObject::from("PCMember")],
join: vec![JoinClause {
operator: JoinOperator::Join,
- right: JoinRightSide::Table(Table::from("PaperReview")),
+ right: JoinRightSide::Table(TableObject::from("PaperReview")),
constraint: JoinConstraint::On(join_cond),
}],
+ fields: columns(&["PCMember.contactId"]),
order: Some(OrderClause {
columns: vec![("contactId".into(), OrderType::OrderAscending)],
}),
@@ -1145,19 +1174,14 @@ mod tests {
let mkjoin = |tbl: &str, col: &str| -> JoinClause {
JoinClause {
operator: JoinOperator::LeftJoin,
- right: JoinRightSide::Table(Table::from(tbl)),
+ right: JoinRightSide::Table(TableObject::from(tbl)),
constraint: JoinConstraint::Using(vec![Column::from(col)]),
}
};
assert_eq!(
res.unwrap().1,
SelectStatement {
- tables: vec![Table::from("ContactInfo")],
- fields: columns(&[
- "PCMember.contactId",
- "ChairAssistant.contactId",
- "Chair.contactId"
- ]),
+ tables: vec![TableObject::from("ContactInfo")],
join: vec![
mkjoin("PaperReview", "contactId"),
mkjoin("PaperConflict", "contactId"),
@@ -1165,12 +1189,29 @@ mod tests {
mkjoin("ChairAssistant", "contactId"),
mkjoin("Chair", "contactId"),
],
+ fields: columns(&[
+ "PCMember.contactId",
+ "ChairAssistant.contactId",
+ "Chair.contactId"
+ ]),
where_clause: expected_where_cond,
..Default::default()
}
);
}
+ #[test]
+ fn out_of_order_joins_fail() {
+ let qstring0 = "select paperId join PCMember using (contactId);";
+ let qstring1 = "select paperId join PCMember from PaperConflict using (contactId);";
+
+ let res0 = selection(qstring0.as_bytes());
+ let res1 = selection(qstring1.as_bytes());
+
+ assert!(res0.is_err());
+ assert!(res1.is_err());
+ }
+
#[test]
fn nested_select() {
let qstr = "SELECT ol_i_id FROM orders, order_line \
@@ -1185,7 +1226,7 @@ mod tests {
});
let inner_select = SelectStatement {
- tables: vec![Table::from("orders"), Table::from("order_line")],
+ tables: vec![TableObject::from("orders"), TableObject::from("order_line")],
fields: columns(&["o_c_id"]),
where_clause: Some(inner_where_clause),
..Default::default()
@@ -1198,7 +1239,7 @@ mod tests {
});
let outer_select = SelectStatement {
- tables: vec![Table::from("orders"), Table::from("order_line")],
+ tables: vec![TableObject::from("orders"), TableObject::from("order_line")],
fields: columns(&["ol_i_id"]),
where_clause: Some(outer_where_clause),
..Default::default()
@@ -1218,7 +1259,7 @@ mod tests {
let agg_expr = FunctionExpression::Max(FunctionArgument::Column(Column::from("o_id")));
let recursive_select = SelectStatement {
- tables: vec![Table::from("orders")],
+ tables: vec![TableObject::from("orders")],
fields: vec![FieldDefinitionExpression::Col(Column {
name: String::from("max(o_id)"),
alias: None,
@@ -1247,7 +1288,7 @@ mod tests {
});
let inner_select = SelectStatement {
- tables: vec![Table::from("orders"), Table::from("order_line")],
+ tables: vec![TableObject::from("orders"), TableObject::from("order_line")],
fields: columns(&["o_c_id"]),
where_clause: Some(inner_where_clause),
..Default::default()
@@ -1260,7 +1301,7 @@ mod tests {
});
let outer_select = SelectStatement {
- tables: vec![Table::from("orders"), Table::from("order_line")],
+ tables: vec![TableObject::from("orders"), TableObject::from("order_line")],
fields: columns(&["ol_i_id"]),
where_clause: Some(outer_where_clause),
..Default::default()
@@ -1290,14 +1331,13 @@ mod tests {
// N.B.: Don't alias the inner select to `inner`, which is, well, a SQL keyword!
let inner_select = SelectStatement {
- tables: vec![Table::from("order_line")],
+ tables: vec![TableObject::from("order_line")],
fields: columns(&["ol_i_id"]),
..Default::default()
};
let outer_select = SelectStatement {
- tables: vec![Table::from("orders")],
- fields: columns(&["o_id", "ol_i_id"]),
+ tables: vec![TableObject::from("orders")],
join: vec![JoinClause {
operator: JoinOperator::Join,
right: JoinRightSide::NestedSelect(Box::new(inner_select), Some("ids".into())),
@@ -1307,6 +1347,7 @@ mod tests {
right: Box::new(Base(Field(Column::from("ids.ol_i_id")))),
})),
}],
+ fields: columns(&["o_id", "ol_i_id"]),
..Default::default()
};
@@ -1321,7 +1362,7 @@ mod tests {
let res = selection(qstr.as_bytes());
let expected = SelectStatement {
- tables: vec![Table::from("orders")],
+ tables: vec![TableObject::from("orders")],
fields: vec![FieldDefinitionExpression::Value(
FieldValueExpression::Arithmetic(ArithmeticExpression::new(
ArithmeticOperator::Subtract,
@@ -1351,7 +1392,7 @@ mod tests {
let res = selection(qstr.as_bytes());
let expected = SelectStatement {
- tables: vec![Table::from("orders")],
+ tables: vec![TableObject::from("orders")],
fields: vec![FieldDefinitionExpression::Value(
FieldValueExpression::Arithmetic(ArithmeticExpression::new(
ArithmeticOperator::Multiply,
@@ -1389,24 +1430,92 @@ mod tests {
}));
let expected = SelectStatement {
- tables: vec![Table::from("auth_permission")],
- fields: vec![
- FieldDefinitionExpression::Col(Column::from("auth_permission.content_type_id")),
- FieldDefinitionExpression::Col(Column::from("auth_permission.codename")),
- ],
+ tables: vec![TableObject::from("auth_permission")],
join: vec![JoinClause {
operator: JoinOperator::Join,
- right: JoinRightSide::Table(Table::from("django_content_type")),
+ right: JoinRightSide::Table(TableObject::from("django_content_type")),
constraint: JoinConstraint::On(ComparisonOp(ConditionTree {
operator: Operator::Equal,
left: Box::new(Base(Field(Column::from("auth_permission.content_type_id")))),
right: Box::new(Base(Field(Column::from("django_content_type.id")))),
})),
}],
+ fields: vec![
+ FieldDefinitionExpression::Col(Column::from("auth_permission.content_type_id")),
+ FieldDefinitionExpression::Col(Column::from("auth_permission.codename")),
+ ],
where_clause: expected_where_clause,
..Default::default()
};
assert_eq!(res.unwrap().1, expected);
}
+
+ #[test]
+ fn literal_select() {
+ use arithmetic::{ArithmeticBase, ArithmeticExpression, ArithmeticOperator};
+
+ let qstr0 = "SELECT 1 + 1";
+ let qstr1 = "SELECT 1 + 1 AS adder GROUP BY adder ORDER BY adder";
+
+ let res0 = selection(qstr0.as_bytes());
+ let res1 = selection(qstr1.as_bytes());
+
+ let expected0 = SelectStatement {
+ distinct: false,
+ fields: vec![FieldDefinitionExpression::Value(
+ FieldValueExpression::Arithmetic(ArithmeticExpression::new(
+ ArithmeticOperator::Add,
+ ArithmeticBase::Scalar(1.into()),
+ ArithmeticBase::Scalar(1.into()),
+ None,
+ )),
+ )],
+ tables: vec![],
+ where_clause: None,
+ group_by: None,
+ order: None,
+ limit: None,
+ ..Default::default()
+ };
+
+ let expected1 = SelectStatement {
+ distinct: false,
+ fields: vec![FieldDefinitionExpression::Value(
+ FieldValueExpression::Arithmetic(ArithmeticExpression::new(
+ ArithmeticOperator::Add,
+ ArithmeticBase::Scalar(1.into()),
+ ArithmeticBase::Scalar(1.into()),
+ Some("adder".to_string()),
+ )),
+ )],
+ tables: vec![],
+ where_clause: None,
+ group_by: Some(GroupByClause {
+ columns: vec![Column {
+ name: "adder".to_string(),
+ alias: None,
+ table: None,
+ function: None,
+ }],
+ having: None,
+ }),
+ order: Some(OrderClause {
+ columns: vec![(
+ Column {
+ name: "adder".to_string(),
+ alias: None,
+ table: None,
+ function: None,
+ },
+ OrderAscending,
+ )],
+ }),
+ limit: None,
+ ..Default::default()
+ };
+
+ assert_eq!(res0.unwrap().1, expected0);
+ assert_eq!(res1.unwrap().1, expected1);
+ }
}
diff --git a/src/table.rs b/src/table.rs
index cfc62b6..a946c51 100644
--- a/src/table.rs
+++ b/src/table.rs
@@ -1,4 +1,5 @@
use std::fmt;
+use std::fmt::{Display, Formatter};
use std::str;
use keywords::escape_if_keyword;
@@ -32,6 +33,7 @@ impl<'a> From<&'a str> for Table {
}
}
}
+
impl<'a> From<(&'a str, &'a str)> for Table {
fn from(t: (&str, &str)) -> Table {
Table {
@@ -41,3 +43,204 @@ impl<'a> From<(&'a str, &'a str)> for Table {
}
}
}
+
+impl From for TableObject {
+ fn from(table: Table) -> Self {
+ Self {
+ table,
+ partitions: Default::default(),
+ }
+ }
+}
+
+#[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize, Deserialize)]
+pub struct TableObject {
+ pub table: Table,
+ pub partitions: TablePartitionList,
+}
+
+impl Display for TableObject {
+ fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+ write!(f, "{}", self.table)?;
+
+ if self.partitions.0.len() > 0 {
+ write!(f, " {}", self.partitions)?;
+ }
+
+ Ok(())
+ }
+}
+
+impl<'a> From<&'a str> for TableObject {
+ fn from(t: &str) -> Self {
+ TableObject {
+ table: t.into(),
+ partitions: Default::default(),
+ }
+ }
+}
+impl<'a> From<(&'a str, &'a str)> for TableObject {
+ fn from(t: (&str, &str)) -> Self {
+ TableObject {
+ table: t.into(),
+ partitions: Default::default(),
+ }
+ }
+}
+
+#[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize, Deserialize)]
+pub struct TablePartitionList(pub Vec);
+
+impl Display for TablePartitionList {
+ fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+ write!(f, "PARTITION (")?;
+
+ if self.0.len() > 0 {
+ write!(
+ f,
+ "{}",
+ self.0
+ .iter()
+ .map(|p| p.to_string())
+ .collect::>()
+ .join(", ")
+ )?;
+ }
+
+ write!(f, ")")
+ }
+}
+
+impl From> for TablePartitionList {
+ fn from(v: Vec) -> Self {
+ Self(v)
+ }
+}
+
+#[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize, Deserialize)]
+pub struct TablePartition {
+ name: String,
+}
+
+impl Display for TablePartition {
+ fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+ write!(f, "{}", self.name)
+ }
+}
+
+impl<'a> From<&'a str> for TablePartition {
+ fn from(name: &str) -> Self {
+ Self {
+ name: name.to_string(),
+ }
+ }
+}
+
+impl<'a> From<&'a [u8]> for TablePartition {
+ fn from(name: &[u8]) -> Self {
+ Self {
+ name: String::from(str::from_utf8(name).unwrap()),
+ }
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+ use common::table_object;
+
+ #[test]
+ fn table_object_simple() {
+ let qstring0 = "table1";
+ let qstring1 = "schema1.table1";
+
+ let res0 = table_object(qstring0.as_bytes());
+ let res1 = table_object(qstring1.as_bytes());
+ assert_eq!(res0.unwrap().1, TableObject::from("table1"),);
+ assert_eq!(res1.unwrap().1, TableObject::from(("schema1", "table1")),);
+ }
+
+ #[test]
+ fn table_object_with_alias() {
+ let qstring0 = "table1 AS t1";
+ let qstring1 = "schema1.table1 AS t1";
+
+ let res0 = table_object(qstring0.as_bytes());
+ let res1 = table_object(qstring1.as_bytes());
+ assert_eq!(
+ res0.unwrap().1,
+ Table {
+ name: "table1".to_string(),
+ alias: Some("t1".to_string()),
+ schema: None
+ }
+ .into()
+ );
+ assert_eq!(
+ res1.unwrap().1,
+ Table {
+ name: "table1".to_string(),
+ alias: Some("t1".to_string()),
+ schema: Some("schema1".to_string()),
+ }
+ .into()
+ );
+ }
+
+ #[test]
+ fn table_object_with_partitioning() {
+ let qstring0 = "schema1.table1 partition (a)";
+ let qstring1 = "schema1.table1 partition (a, b1_long, a2b)";
+ let qstring2 = "schema1.table1 AS t1 partition (a)";
+ let qstring3 = "schema1.table1 AS t1 partition (a, b1_long, a2b)";
+
+ let res0 = table_object(qstring0.as_bytes());
+ let res1 = table_object(qstring1.as_bytes());
+ let res2 = table_object(qstring2.as_bytes());
+ let res3 = table_object(qstring3.as_bytes());
+ assert_eq!(
+ res0.unwrap().1,
+ TableObject {
+ table: Table {
+ name: "table1".to_string(),
+ alias: None,
+ schema: Some("schema1".to_string())
+ },
+ partitions: TablePartitionList(vec!["a".into(),])
+ }
+ );
+ assert_eq!(
+ res1.unwrap().1,
+ TableObject {
+ table: Table {
+ name: "table1".to_string(),
+ alias: None,
+ schema: Some("schema1".to_string())
+ },
+ partitions: TablePartitionList(vec!["a".into(), "b1_long".into(), "a2b".into(),])
+ }
+ );
+ assert_eq!(
+ res2.unwrap().1,
+ TableObject {
+ table: Table {
+ name: "table1".to_string(),
+ alias: Some("t1".to_string()),
+ schema: Some("schema1".to_string()),
+ },
+ partitions: TablePartitionList(vec!["a".into(),])
+ }
+ );
+ assert_eq!(
+ res3.unwrap().1,
+ TableObject {
+ table: Table {
+ name: "table1".to_string(),
+ alias: Some("t1".to_string()),
+ schema: Some("schema1".to_string()),
+ },
+ partitions: TablePartitionList(vec!["a".into(), "b1_long".into(), "a2b".into(),]),
+ }
+ );
+ }
+}
diff --git a/src/update.rs b/src/update.rs
index 7366676..2296c8a 100644
--- a/src/update.rs
+++ b/src/update.rs
@@ -2,26 +2,25 @@ use nom::character::complete::{multispace0, multispace1};
use std::{fmt, str};
use column::Column;
-use common::{assignment_expr_list, statement_terminator, table_reference, FieldValueExpression};
+use common::{assignment_expr_list, statement_terminator, table_object, FieldValueExpression};
use condition::ConditionExpression;
-use keywords::escape_if_keyword;
use nom::bytes::complete::tag_no_case;
use nom::combinator::opt;
use nom::sequence::tuple;
use nom::IResult;
use select::where_clause;
-use table::Table;
+use table::TableObject;
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub struct UpdateStatement {
- pub table: Table,
+ pub table: TableObject,
pub fields: Vec<(Column, FieldValueExpression)>,
pub where_clause: Option,
}
impl fmt::Display for UpdateStatement {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
- write!(f, "UPDATE {} ", escape_if_keyword(&self.table.name))?;
+ write!(f, "UPDATE {} ", self.table)?;
assert!(self.fields.len() > 0);
write!(
f,
@@ -44,7 +43,7 @@ pub fn updating(i: &[u8]) -> IResult<&[u8], UpdateStatement> {
let (remaining_input, (_, _, table, _, _, _, fields, _, where_clause, _)) = tuple((
tag_no_case("update"),
multispace1,
- table_reference,
+ table_object,
multispace1,
tag_no_case("set"),
multispace1,
@@ -72,7 +71,7 @@ mod tests {
use condition::ConditionBase::*;
use condition::ConditionExpression::*;
use condition::ConditionTree;
- use table::Table;
+ use table::{TableObject, TablePartitionList};
#[test]
fn simple_update() {
@@ -82,7 +81,7 @@ mod tests {
assert_eq!(
res.unwrap().1,
UpdateStatement {
- table: Table::from("users"),
+ table: TableObject::from("users"),
fields: vec![
(
Column::from("id"),
@@ -114,7 +113,7 @@ mod tests {
assert_eq!(
res.unwrap().1,
UpdateStatement {
- table: Table::from("users"),
+ table: TableObject::from("users"),
fields: vec![
(
Column::from("id"),
@@ -157,7 +156,7 @@ mod tests {
assert_eq!(
res.unwrap().1,
UpdateStatement {
- table: Table::from("stories"),
+ table: TableObject::from("stories"),
fields: vec![(
Column::from("hotness"),
FieldValueExpression::Literal(LiteralExpression::from(Literal::FixedPoint(
@@ -194,7 +193,7 @@ mod tests {
assert_eq!(
res.unwrap().1,
UpdateStatement {
- table: Table::from("users"),
+ table: TableObject::from("users"),
fields: vec![(
Column::from("karma"),
FieldValueExpression::Arithmetic(expected_ae),
@@ -219,7 +218,7 @@ mod tests {
assert_eq!(
res.unwrap().1,
UpdateStatement {
- table: Table::from("users"),
+ table: TableObject::from("users"),
fields: vec![(
Column::from("karma"),
FieldValueExpression::Arithmetic(expected_ae),
@@ -228,4 +227,33 @@ mod tests {
}
);
}
+
+ #[test]
+ fn update_with_partitions() {
+ let qstring = "UPDATE users PARTITION (u) SET id = 42, name = 'test'";
+
+ let res = updating(qstring.as_bytes());
+ assert_eq!(
+ res.unwrap().1,
+ UpdateStatement {
+ table: TableObject {
+ table: "users".into(),
+ partitions: TablePartitionList(vec!["u".into(),]),
+ },
+ fields: vec![
+ (
+ Column::from("id"),
+ FieldValueExpression::Literal(LiteralExpression::from(Literal::from(42))),
+ ),
+ (
+ Column::from("name"),
+ FieldValueExpression::Literal(LiteralExpression::from(Literal::from(
+ "test",
+ ))),
+ ),
+ ],
+ ..Default::default()
+ }
+ );
+ }
}