diff --git a/src/arithmetic.rs b/src/arithmetic.rs
index 470b3eb..90a3943 100644
--- a/src/arithmetic.rs
+++ b/src/arithmetic.rs
@@ -459,10 +459,13 @@ mod tests {
}
#[test]
- fn arithmetic_scalar(){
+ fn arithmetic_scalar() {
let qs = "56";
let res = arithmetic(qs.as_bytes());
assert!(res.is_err());
- assert_eq!(nom::Err::Error(nom::error::Error::new(qs.as_bytes(), ErrorKind::Tag)), res.err().unwrap());
+ assert_eq!(
+ nom::Err::Error(nom::error::Error::new(qs.as_bytes(), ErrorKind::Tag)),
+ res.err().unwrap()
+ );
}
}
diff --git a/src/common.rs b/src/common.rs
index 1d6cd4d..9002b92 100644
--- a/src/common.rs
+++ b/src/common.rs
@@ -16,7 +16,9 @@ use nom::combinator::opt;
use nom::error::{ErrorKind, ParseError};
use nom::multi::{fold_many0, many0, many1, separated_list0};
use nom::sequence::{delimited, pair, preceded, separated_pair, terminated, tuple};
+use select::join_clause;
use table::Table;
+use JoinClause;
#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub enum SqlType {
@@ -388,7 +390,7 @@ where
let (inp, _) = first.parse(inp)?;
let (inp, o2) = second.parse(inp)?;
third.parse(inp).map(|(i, _)| (i, o2))
- },
+ }
}
}
}
@@ -641,7 +643,8 @@ pub fn function_argument_parser(i: &[u8]) -> IResult<&[u8], FunctionArgument> {
// present.
pub fn function_arguments(i: &[u8]) -> IResult<&[u8], (FunctionArgument, bool)> {
let distinct_parser = opt(tuple((tag_no_case("distinct"), multispace1)));
- let (remaining_input, (distinct, args)) = tuple((distinct_parser, function_argument_parser))(i)?;
+ let (remaining_input, (distinct, args)) =
+ tuple((distinct_parser, function_argument_parser))(i)?;
Ok((remaining_input, (args, distinct.is_some())))
}
@@ -695,12 +698,25 @@ pub fn column_function(i: &[u8]) -> IResult<&[u8], FunctionExpression> {
FunctionExpression::GroupConcat(FunctionArgument::Column(col.clone()), sep)
},
),
- map(tuple((sql_identifier, multispace0, tag("("), separated_list0(tag(","), delimited(multispace0, function_argument_parser, multispace0)), tag(")"))), |tuple| {
- let (name, _, _, arguments, _) = tuple;
- FunctionExpression::Generic(
- str::from_utf8(name).unwrap().to_string(),
- FunctionArguments::from(arguments))
- })
+ map(
+ tuple((
+ sql_identifier,
+ multispace0,
+ tag("("),
+ separated_list0(
+ tag(","),
+ delimited(multispace0, function_argument_parser, multispace0),
+ ),
+ tag(")"),
+ )),
+ |tuple| {
+ let (name, _, _, arguments, _) = tuple;
+ FunctionExpression::Generic(
+ str::from_utf8(name).unwrap().to_string(),
+ FunctionArguments::from(arguments),
+ )
+ },
+ ),
))(i)
}
@@ -893,6 +909,15 @@ pub fn table_list(i: &[u8]) -> IResult<&[u8], Vec
> {
many0(terminated(schema_table_reference, opt(ws_sep_comma)))(i)
}
+pub fn from_clause(i: &[u8]) -> IResult<&[u8], Vec> {
+ let (i, (_, t)) = tuple((
+ delimited(multispace0, tag_no_case("from"), multispace0),
+ many1(terminated(schema_table_reference, opt(ws_sep_comma))),
+ ))(i)?;
+
+ Ok((i, t))
+}
+
// Integer literal value
pub fn integer_literal(i: &[u8]) -> IResult<&[u8], Literal> {
map(pair(opt(tag("-")), digit1), |tup| {
@@ -1018,25 +1043,44 @@ pub fn value_list(i: &[u8]) -> IResult<&[u8], Vec> {
many0(delimited(multispace0, literal, opt(ws_sep_comma)))(i)
}
+pub fn relational_objects_clauses(i: &[u8]) -> IResult<&[u8], (Vec, Vec)> {
+ match from_clause(i) {
+ Ok((i, f)) => {
+ let (i, j) = many0(join_clause)(i).unwrap_or((i, vec![]));
+
+ Ok((i, (f, j)))
+ }
+ Err(e) => {
+ if join_clause(i).is_ok() {
+ //TODO: needs a more helpful error once error handling is improved upon
+ Err(e)
+ } else {
+ Ok((i, (vec![], vec![])))
+ }
+ }
+ }
+}
+
// Parse a reference to a named schema.table, with an optional alias
pub fn schema_table_reference(i: &[u8]) -> IResult<&[u8], Table> {
map(
- tuple((
- opt(pair(sql_identifier, tag("."))),
- sql_identifier,
- opt(as_alias)
- )),
- |tup| Table {
- name: String::from(str::from_utf8(tup.1).unwrap()),
- alias: match tup.2 {
- Some(a) => Some(String::from(a)),
- None => None,
- },
- schema: match tup.0 {
- Some((schema, _)) => Some(String::from(str::from_utf8(schema).unwrap())),
- None => None,
+ tuple((
+ opt(pair(sql_identifier, tag("."))),
+ sql_identifier,
+ opt(as_alias),
+ )),
+ |tup| Table {
+ name: String::from(str::from_utf8(tup.1).unwrap()),
+ alias: match tup.2 {
+ Some(a) => Some(String::from(a)),
+ None => None,
+ },
+ schema: match tup.0 {
+ Some((schema, _)) => Some(String::from(str::from_utf8(schema).unwrap())),
+ None => None,
+ },
},
- })(i)
+ )(i)
}
// Parse a reference to a named table, with an optional alias
@@ -1047,7 +1091,7 @@ pub fn table_reference(i: &[u8]) -> IResult<&[u8], Table> {
Some(a) => Some(String::from(a)),
None => None,
},
- schema: None,
+ schema: None,
})(i)
}
@@ -1137,25 +1181,31 @@ mod tests {
name: String::from("max(addr_id)"),
alias: None,
table: None,
- function: Some(Box::new(FunctionExpression::Max(
- FunctionArgument::Column(Column::from("addr_id")),
- ))),
+ function: Some(Box::new(FunctionExpression::Max(FunctionArgument::Column(
+ Column::from("addr_id"),
+ )))),
};
assert_eq!(res.unwrap().1, expected);
}
#[test]
fn simple_generic_function() {
- let qlist = ["coalesce(a,b,c)".as_bytes(), "coalesce (a,b,c)".as_bytes(), "coalesce(a ,b,c)".as_bytes(), "coalesce(a, b,c)".as_bytes()];
+ let qlist = [
+ "coalesce(a,b,c)".as_bytes(),
+ "coalesce (a,b,c)".as_bytes(),
+ "coalesce(a ,b,c)".as_bytes(),
+ "coalesce(a, b,c)".as_bytes(),
+ ];
for q in qlist.iter() {
let res = column_function(q);
- let expected = FunctionExpression::Generic("coalesce".to_string(),
- FunctionArguments::from(
- vec!(
- FunctionArgument::Column(Column::from("a")),
- FunctionArgument::Column(Column::from("b")),
- FunctionArgument::Column(Column::from("c"))
- )));
+ let expected = FunctionExpression::Generic(
+ "coalesce".to_string(),
+ FunctionArguments::from(vec![
+ FunctionArgument::Column(Column::from("a")),
+ FunctionArgument::Column(Column::from("b")),
+ FunctionArgument::Column(Column::from("c")),
+ ]),
+ );
assert_eq!(res, Ok((&b""[..], expected)));
}
}
diff --git a/src/compound_select.rs b/src/compound_select.rs
index d3ece89..53d3d08 100644
--- a/src/compound_select.rs
+++ b/src/compound_select.rs
@@ -185,12 +185,18 @@ mod tests {
assert!(&res.is_err());
assert_eq!(
res.unwrap_err(),
- nom::Err::Error(nom::error::Error::new(");".as_bytes(), nom::error::ErrorKind::Tag))
+ nom::Err::Error(nom::error::Error::new(
+ ");".as_bytes(),
+ nom::error::ErrorKind::Tag
+ ))
);
assert!(&res2.is_err());
assert_eq!(
res2.unwrap_err(),
- nom::Err::Error(nom::error::Error::new(";".as_bytes(), nom::error::ErrorKind::Tag))
+ nom::Err::Error(nom::error::Error::new(
+ ";".as_bytes(),
+ nom::error::ErrorKind::Tag
+ ))
);
assert!(&res3.is_err());
assert_eq!(
diff --git a/src/condition.rs b/src/condition.rs
index a210b7a..ce9e1e8 100644
--- a/src/condition.rs
+++ b/src/condition.rs
@@ -291,10 +291,7 @@ fn predicate(i: &[u8]) -> IResult<&[u8], ConditionExpression> {
},
);
- alt((
- simple_expr,
- nested_exists,
- ))(i)
+ alt((simple_expr, nested_exists))(i)
}
fn simple_expr(i: &[u8]) -> IResult<&[u8], ConditionExpression> {
diff --git a/src/create.rs b/src/create.rs
index a1b5afc..45b2cb3 100644
--- a/src/create.rs
+++ b/src/create.rs
@@ -5,8 +5,8 @@ use std::str::FromStr;
use column::{Column, ColumnConstraint, ColumnSpecification};
use common::{
- column_identifier_no_alias, parse_comment, sql_identifier, statement_terminator,
- schema_table_reference, type_identifier, ws_sep_comma, Literal, Real, SqlType, TableKey,
+ column_identifier_no_alias, parse_comment, schema_table_reference, sql_identifier,
+ statement_terminator, type_identifier, ws_sep_comma, Literal, Real, SqlType, TableKey,
};
use compound_select::{compound_selection, CompoundSelectStatement};
use create_table_options::table_options;
@@ -534,7 +534,7 @@ mod tests {
assert_eq!(
res.unwrap().1,
CreateTableStatement {
- table: Table::from(("db1","t")),
+ table: Table::from(("db1", "t")),
fields: vec![ColumnSpecification::new(
Column::from("t.x"),
SqlType::Int(32)
diff --git a/src/delete.rs b/src/delete.rs
index 64ee1dc..40f3cd1 100644
--- a/src/delete.rs
+++ b/src/delete.rs
@@ -1,7 +1,7 @@
use nom::character::complete::multispace1;
use std::{fmt, str};
-use common::{statement_terminator, schema_table_reference};
+use common::{schema_table_reference, statement_terminator};
use condition::ConditionExpression;
use keywords::escape_if_keyword;
use nom::bytes::complete::tag_no_case;
@@ -77,7 +77,7 @@ mod tests {
assert_eq!(
res.unwrap().1,
DeleteStatement {
- table: Table::from(("db1","users")),
+ table: Table::from(("db1", "users")),
..Default::default()
}
);
diff --git a/src/insert.rs b/src/insert.rs
index 9f55107..5c81427 100644
--- a/src/insert.rs
+++ b/src/insert.rs
@@ -4,7 +4,7 @@ use std::str;
use column::Column;
use common::{
- assignment_expr_list, field_list, statement_terminator, schema_table_reference, value_list,
+ assignment_expr_list, field_list, schema_table_reference, statement_terminator, value_list,
ws_sep_comma, FieldValueExpression, Literal,
};
use keywords::escape_if_keyword;
@@ -145,7 +145,7 @@ mod tests {
assert_eq!(
res.unwrap().1,
InsertStatement {
- table: Table::from(("db1","users")),
+ table: Table::from(("db1", "users")),
fields: None,
data: vec![vec![42.into(), "test".into()]],
..Default::default()
diff --git a/src/join.rs b/src/join.rs
index b91f5dc..2f1ca4c 100644
--- a/src/join.rs
+++ b/src/join.rs
@@ -128,12 +128,12 @@ mod tests {
let join_cond = ConditionExpression::ComparisonOp(ct);
let expected_stmt = SelectStatement {
tables: vec![Table::from("tags")],
- fields: vec![FieldDefinitionExpression::AllInTable("tags".into())],
join: vec![JoinClause {
operator: JoinOperator::InnerJoin,
right: JoinRightSide::Table(Table::from("taggings")),
constraint: JoinConstraint::On(join_cond),
}],
+ fields: vec![FieldDefinitionExpression::AllInTable("tags".into())],
..Default::default()
};
diff --git a/src/select.rs b/src/select.rs
index 8735e95..08fd893 100644
--- a/src/select.rs
+++ b/src/select.rs
@@ -3,17 +3,16 @@ use std::fmt;
use std::str;
use column::Column;
-use common::FieldDefinitionExpression;
use common::{
- as_alias, field_definition_expr, field_list, statement_terminator, table_list, table_reference,
- unsigned_number,
+ as_alias, field_definition_expr, field_list, relational_objects_clauses, statement_terminator,
+ table_reference, unsigned_number,
};
+use common::{table_list, FieldDefinitionExpression};
use condition::{condition_expr, ConditionExpression};
use join::{join_operator, JoinConstraint, JoinOperator, JoinRightSide};
use nom::branch::alt;
use nom::bytes::complete::{tag, tag_no_case};
use nom::combinator::{map, opt};
-use nom::multi::many0;
use nom::sequence::{delimited, preceded, terminated, tuple};
use nom::IResult;
use order::{order_clause, OrderClause};
@@ -78,9 +77,9 @@ impl fmt::Display for LimitClause {
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub struct SelectStatement {
- pub tables: Vec,
pub distinct: bool,
pub fields: Vec,
+ pub tables: Vec,
pub join: Vec,
pub where_clause: Option,
pub group_by: Option,
@@ -116,9 +115,11 @@ impl fmt::Display for SelectStatement {
.join(", ")
)?;
}
+
for jc in &self.join {
write!(f, " {}", jc)?;
}
+
if let Some(ref where_clause) = self.where_clause {
write!(f, " WHERE ")?;
write!(f, "{}", where_clause)?;
@@ -217,7 +218,7 @@ fn join_constraint(i: &[u8]) -> IResult<&[u8], JoinConstraint> {
}
// Parse JOIN clause
-fn join_clause(i: &[u8]) -> IResult<&[u8], JoinClause> {
+pub fn join_clause(i: &[u8]) -> IResult<&[u8], JoinClause> {
let (remaining_input, (_, _natural, operator, _, right, _, constraint)) = tuple((
multispace0,
opt(terminated(tag_no_case("natural"), multispace1)),
@@ -276,16 +277,14 @@ pub fn selection(i: &[u8]) -> IResult<&[u8], SelectStatement> {
pub fn nested_selection(i: &[u8]) -> IResult<&[u8], SelectStatement> {
let (
remaining_input,
- (_, _, distinct, _, fields, _, tables, join, where_clause, group_by, order, limit),
+ (_, _, distinct, _, fields, (tables, join), where_clause, group_by, order, limit),
) = tuple((
tag_no_case("select"),
multispace1,
opt(tag_no_case("distinct")),
multispace0,
field_definition_expr,
- delimited(multispace0, tag_no_case("from"), multispace0),
- table_list,
- many0(join_clause),
+ relational_objects_clauses,
opt(where_clause),
opt(group_by_clause),
opt(order_clause),
@@ -295,9 +294,9 @@ pub fn nested_selection(i: &[u8]) -> IResult<&[u8], SelectStatement> {
remaining_input,
SelectStatement {
tables,
+ join,
distinct: distinct.is_some(),
fields,
- join,
where_clause,
group_by,
order,
@@ -319,6 +318,7 @@ mod tests {
use condition::ConditionTree;
use order::OrderType;
use table::Table;
+ use OrderType::OrderAscending;
fn columns(cols: &[&str]) -> Vec {
cols.iter()
@@ -533,8 +533,8 @@ mod tests {
tables: vec![Table {
name: String::from("PaperTag"),
alias: Some(String::from("t")),
- schema: None,
- },],
+ schema: None,
+ }],
fields: vec![FieldDefinitionExpression::All],
..Default::default()
}
@@ -554,7 +554,7 @@ mod tests {
tables: vec![Table {
name: String::from("PaperTag"),
alias: Some(String::from("t")),
- schema: Some(String::from("db1")),
+ schema: Some(String::from("db1")),
},],
fields: vec![FieldDefinitionExpression::All],
..Default::default()
@@ -1067,12 +1067,12 @@ mod tests {
let res = selection(qstring.as_bytes());
let expected_stmt = SelectStatement {
tables: vec![Table::from("PaperConflict")],
- fields: columns(&["paperId"]),
join: vec![JoinClause {
operator: JoinOperator::Join,
right: JoinRightSide::Table(Table::from("PCMember")),
constraint: JoinConstraint::Using(vec![Column::from("contactId")]),
}],
+ fields: columns(&["paperId"]),
..Default::default()
};
assert_eq!(res.unwrap().1, expected_stmt);
@@ -1095,12 +1095,12 @@ mod tests {
let join_cond = ConditionExpression::ComparisonOp(ct);
let expected = SelectStatement {
tables: vec![Table::from("PCMember")],
- fields: columns(&["PCMember.contactId"]),
join: vec![JoinClause {
operator: JoinOperator::Join,
right: JoinRightSide::Table(Table::from("PaperReview")),
constraint: JoinConstraint::On(join_cond),
}],
+ fields: columns(&["PCMember.contactId"]),
order: Some(OrderClause {
columns: vec![("contactId".into(), OrderType::OrderAscending)],
}),
@@ -1153,11 +1153,6 @@ mod tests {
res.unwrap().1,
SelectStatement {
tables: vec![Table::from("ContactInfo")],
- fields: columns(&[
- "PCMember.contactId",
- "ChairAssistant.contactId",
- "Chair.contactId"
- ]),
join: vec![
mkjoin("PaperReview", "contactId"),
mkjoin("PaperConflict", "contactId"),
@@ -1165,12 +1160,29 @@ mod tests {
mkjoin("ChairAssistant", "contactId"),
mkjoin("Chair", "contactId"),
],
+ fields: columns(&[
+ "PCMember.contactId",
+ "ChairAssistant.contactId",
+ "Chair.contactId"
+ ]),
where_clause: expected_where_cond,
..Default::default()
}
);
}
+ #[test]
+ fn out_of_order_joins_fail() {
+ let qstring0 = "select paperId join PCMember using (contactId);";
+ let qstring1 = "select paperId join PCMember from PaperConflict using (contactId);";
+
+ let res0 = selection(qstring0.as_bytes());
+ let res1 = selection(qstring1.as_bytes());
+
+ assert!(res0.is_err());
+ assert!(res1.is_err());
+ }
+
#[test]
fn nested_select() {
let qstr = "SELECT ol_i_id FROM orders, order_line \
@@ -1297,7 +1309,6 @@ mod tests {
let outer_select = SelectStatement {
tables: vec![Table::from("orders")],
- fields: columns(&["o_id", "ol_i_id"]),
join: vec![JoinClause {
operator: JoinOperator::Join,
right: JoinRightSide::NestedSelect(Box::new(inner_select), Some("ids".into())),
@@ -1307,6 +1318,7 @@ mod tests {
right: Box::new(Base(Field(Column::from("ids.ol_i_id")))),
})),
}],
+ fields: columns(&["o_id", "ol_i_id"]),
..Default::default()
};
@@ -1390,10 +1402,6 @@ mod tests {
let expected = SelectStatement {
tables: vec![Table::from("auth_permission")],
- fields: vec![
- FieldDefinitionExpression::Col(Column::from("auth_permission.content_type_id")),
- FieldDefinitionExpression::Col(Column::from("auth_permission.codename")),
- ],
join: vec![JoinClause {
operator: JoinOperator::Join,
right: JoinRightSide::Table(Table::from("django_content_type")),
@@ -1403,10 +1411,82 @@ mod tests {
right: Box::new(Base(Field(Column::from("django_content_type.id")))),
})),
}],
+ fields: vec![
+ FieldDefinitionExpression::Col(Column::from("auth_permission.content_type_id")),
+ FieldDefinitionExpression::Col(Column::from("auth_permission.codename")),
+ ],
where_clause: expected_where_clause,
..Default::default()
};
assert_eq!(res.unwrap().1, expected);
}
+
+ #[test]
+ fn literal_select() {
+ use arithmetic::{ArithmeticBase, ArithmeticExpression, ArithmeticOperator};
+
+ let qstr0 = "SELECT 1 + 1";
+ let qstr1 = "SELECT 1 + 1 AS adder GROUP BY adder ORDER BY adder";
+
+ let res0 = selection(qstr0.as_bytes());
+ let res1 = selection(qstr1.as_bytes());
+
+ let expected0 = SelectStatement {
+ distinct: false,
+ fields: vec![FieldDefinitionExpression::Value(
+ FieldValueExpression::Arithmetic(ArithmeticExpression::new(
+ ArithmeticOperator::Add,
+ ArithmeticBase::Scalar(1.into()),
+ ArithmeticBase::Scalar(1.into()),
+ None,
+ )),
+ )],
+ tables: vec![],
+ where_clause: None,
+ group_by: None,
+ order: None,
+ limit: None,
+ ..Default::default()
+ };
+
+ let expected1 = SelectStatement {
+ distinct: false,
+ fields: vec![FieldDefinitionExpression::Value(
+ FieldValueExpression::Arithmetic(ArithmeticExpression::new(
+ ArithmeticOperator::Add,
+ ArithmeticBase::Scalar(1.into()),
+ ArithmeticBase::Scalar(1.into()),
+ Some("adder".to_string()),
+ )),
+ )],
+ tables: vec![],
+ where_clause: None,
+ group_by: Some(GroupByClause {
+ columns: vec![Column {
+ name: "adder".to_string(),
+ alias: None,
+ table: None,
+ function: None,
+ }],
+ having: None,
+ }),
+ order: Some(OrderClause {
+ columns: vec![(
+ Column {
+ name: "adder".to_string(),
+ alias: None,
+ table: None,
+ function: None,
+ },
+ OrderAscending,
+ )],
+ }),
+ limit: None,
+ ..Default::default()
+ };
+
+ assert_eq!(res0.unwrap().1, expected0);
+ assert_eq!(res1.unwrap().1, expected1);
+ }
}