Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -459,10 +459,13 @@ mod tests {
}

#[test]
fn arithmetic_scalar(){
fn arithmetic_scalar() {
let qs = "56";
let res = arithmetic(qs.as_bytes());
assert!(res.is_err());
assert_eq!(nom::Err::Error(nom::error::Error::new(qs.as_bytes(), ErrorKind::Tag)), res.err().unwrap());
assert_eq!(
nom::Err::Error(nom::error::Error::new(qs.as_bytes(), ErrorKind::Tag)),
res.err().unwrap()
);
}
}
15 changes: 15 additions & 0 deletions src/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,21 @@ impl PartialOrd for Column {
}
}

#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub enum SortingColumnIdentifier {
FunctionArguments(FunctionArgument),
Position(usize),
}

impl fmt::Display for SortingColumnIdentifier {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
SortingColumnIdentifier::FunctionArguments(c) => write!(f, "{}", c),
SortingColumnIdentifier::Position(p) => write!(f, "{}", p),
}
}
}

#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub enum ColumnConstraint {
NotNull,
Expand Down
145 changes: 94 additions & 51 deletions src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ use std::str::FromStr;

use arithmetic::{arithmetic_expression, ArithmeticExpression};
use case::case_when_column;
use column::{Column, FunctionArgument, FunctionArguments, FunctionExpression};
use column::{
Column, FunctionArgument, FunctionArguments, FunctionExpression, SortingColumnIdentifier,
};
use keywords::{escape_if_keyword, sql_keyword};
use nom::bytes::complete::{is_not, tag, tag_no_case, take, take_until, take_while1};
use nom::combinator::opt;
Expand Down Expand Up @@ -388,7 +390,7 @@ where
let (inp, _) = first.parse(inp)?;
let (inp, o2) = second.parse(inp)?;
third.parse(inp).map(|(i, _)| (i, o2))
},
}
}
}
}
Expand Down Expand Up @@ -641,7 +643,8 @@ pub fn function_argument_parser(i: &[u8]) -> IResult<&[u8], FunctionArgument> {
// present.
pub fn function_arguments(i: &[u8]) -> IResult<&[u8], (FunctionArgument, bool)> {
let distinct_parser = opt(tuple((tag_no_case("distinct"), multispace1)));
let (remaining_input, (distinct, args)) = tuple((distinct_parser, function_argument_parser))(i)?;
let (remaining_input, (distinct, args)) =
tuple((distinct_parser, function_argument_parser))(i)?;
Ok((remaining_input, (args, distinct.is_some())))
}

Expand Down Expand Up @@ -695,12 +698,25 @@ pub fn column_function(i: &[u8]) -> IResult<&[u8], FunctionExpression> {
FunctionExpression::GroupConcat(FunctionArgument::Column(col.clone()), sep)
},
),
map(tuple((sql_identifier, multispace0, tag("("), separated_list0(tag(","), delimited(multispace0, function_argument_parser, multispace0)), tag(")"))), |tuple| {
let (name, _, _, arguments, _) = tuple;
FunctionExpression::Generic(
str::from_utf8(name).unwrap().to_string(),
FunctionArguments::from(arguments))
})
map(
tuple((
sql_identifier,
multispace0,
tag("("),
separated_list0(
tag(","),
delimited(multispace0, function_argument_parser, multispace0),
),
tag(")"),
)),
|tuple| {
let (name, _, _, arguments, _) = tuple;
FunctionExpression::Generic(
str::from_utf8(name).unwrap().to_string(),
FunctionArguments::from(arguments),
)
},
),
))(i)
}

Expand Down Expand Up @@ -740,28 +756,48 @@ pub fn column_identifier(i: &[u8]) -> IResult<&[u8], Column> {
table: None,
function: Some(Box::new(tup.0)),
});
let col_w_table = map(
tuple((
opt(terminated(sql_identifier, tag("."))),
sql_identifier,
opt(as_alias),
)),
|tup| Column {
name: str::from_utf8(tup.1).unwrap().to_string(),
alias: match tup.2 {
let col_w_table = map(tuple((table_column_identifier, opt(as_alias))), |tup| {
Column {
name: tup.0 .1,
alias: match tup.1 {
None => None,
Some(a) => Some(String::from(a)),
},
table: match tup.0 {
None => None,
Some(t) => Some(str::from_utf8(t).unwrap().to_string()),
},
table: tup.0 .0,
function: None,
},
);
}
});
alt((col_func_no_table, col_w_table))(i)
}

// Parses a SQL column name preceded in the table.column format
pub fn table_column_identifier(i: &[u8]) -> IResult<&[u8], (Option<String>, String)> {
tuple((
map(opt(terminated(sql_identifier, tag("."))), |si| {
si.and_then(|si| Some(str::from_utf8(si).unwrap().to_string()))
}),
map(sql_identifier, |si| str::from_utf8(si).unwrap().to_string()),
))(i)
}

pub fn sorting_column_identifier(i: &[u8]) -> IResult<&[u8], SortingColumnIdentifier> {
alt((
map(digit1, |p| {
SortingColumnIdentifier::Position(usize::from_str(str::from_utf8(p).unwrap()).unwrap())
}),
map(function_argument_parser, |c| {
SortingColumnIdentifier::FunctionArguments(c)
}),
))(i)
}

pub fn group_by_column_identifier(i: &[u8]) -> IResult<&[u8], SortingColumnIdentifier> {
map(
tuple((sorting_column_identifier, opt(ws_sep_comma))),
|(c, _)| c,
)(i)
}

// Parses a SQL identifier (alphanumeric1 and "_").
pub fn sql_identifier(i: &[u8]) -> IResult<&[u8], &[u8]> {
alt((
Expand Down Expand Up @@ -1021,22 +1057,23 @@ pub fn value_list(i: &[u8]) -> IResult<&[u8], Vec<Literal>> {
// Parse a reference to a named schema.table, with an optional alias
pub fn schema_table_reference(i: &[u8]) -> IResult<&[u8], Table> {
map(
tuple((
opt(pair(sql_identifier, tag("."))),
sql_identifier,
opt(as_alias)
)),
|tup| Table {
name: String::from(str::from_utf8(tup.1).unwrap()),
alias: match tup.2 {
Some(a) => Some(String::from(a)),
None => None,
},
schema: match tup.0 {
Some((schema, _)) => Some(String::from(str::from_utf8(schema).unwrap())),
None => None,
tuple((
opt(pair(sql_identifier, tag("."))),
sql_identifier,
opt(as_alias),
)),
|tup| Table {
name: String::from(str::from_utf8(tup.1).unwrap()),
alias: match tup.2 {
Some(a) => Some(String::from(a)),
None => None,
},
schema: match tup.0 {
Some((schema, _)) => Some(String::from(str::from_utf8(schema).unwrap())),
None => None,
},
},
})(i)
)(i)
}

// Parse a reference to a named table, with an optional alias
Expand All @@ -1047,7 +1084,7 @@ pub fn table_reference(i: &[u8]) -> IResult<&[u8], Table> {
Some(a) => Some(String::from(a)),
None => None,
},
schema: None,
schema: None,
})(i)
}

Expand Down Expand Up @@ -1137,25 +1174,31 @@ mod tests {
name: String::from("max(addr_id)"),
alias: None,
table: None,
function: Some(Box::new(FunctionExpression::Max(
FunctionArgument::Column(Column::from("addr_id")),
))),
function: Some(Box::new(FunctionExpression::Max(FunctionArgument::Column(
Column::from("addr_id"),
)))),
};
assert_eq!(res.unwrap().1, expected);
}

#[test]
fn simple_generic_function() {
let qlist = ["coalesce(a,b,c)".as_bytes(), "coalesce (a,b,c)".as_bytes(), "coalesce(a ,b,c)".as_bytes(), "coalesce(a, b,c)".as_bytes()];
let qlist = [
"coalesce(a,b,c)".as_bytes(),
"coalesce (a,b,c)".as_bytes(),
"coalesce(a ,b,c)".as_bytes(),
"coalesce(a, b,c)".as_bytes(),
];
for q in qlist.iter() {
let res = column_function(q);
let expected = FunctionExpression::Generic("coalesce".to_string(),
FunctionArguments::from(
vec!(
FunctionArgument::Column(Column::from("a")),
FunctionArgument::Column(Column::from("b")),
FunctionArgument::Column(Column::from("c"))
)));
let expected = FunctionExpression::Generic(
"coalesce".to_string(),
FunctionArguments::from(vec![
FunctionArgument::Column(Column::from("a")),
FunctionArgument::Column(Column::from("b")),
FunctionArgument::Column(Column::from("c")),
]),
);
assert_eq!(res, Ok((&b""[..], expected)));
}
}
Expand Down
10 changes: 8 additions & 2 deletions src/compound_select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,12 +185,18 @@ mod tests {
assert!(&res.is_err());
assert_eq!(
res.unwrap_err(),
nom::Err::Error(nom::error::Error::new(");".as_bytes(), nom::error::ErrorKind::Tag))
nom::Err::Error(nom::error::Error::new(
");".as_bytes(),
nom::error::ErrorKind::Tag
))
);
assert!(&res2.is_err());
assert_eq!(
res2.unwrap_err(),
nom::Err::Error(nom::error::Error::new(";".as_bytes(), nom::error::ErrorKind::Tag))
nom::Err::Error(nom::error::Error::new(
";".as_bytes(),
nom::error::ErrorKind::Tag
))
);
assert!(&res3.is_err());
assert_eq!(
Expand Down
5 changes: 1 addition & 4 deletions src/condition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,10 +291,7 @@ fn predicate(i: &[u8]) -> IResult<&[u8], ConditionExpression> {
},
);

alt((
simple_expr,
nested_exists,
))(i)
alt((simple_expr, nested_exists))(i)
}

fn simple_expr(i: &[u8]) -> IResult<&[u8], ConditionExpression> {
Expand Down
6 changes: 3 additions & 3 deletions src/create.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use std::str::FromStr;

use column::{Column, ColumnConstraint, ColumnSpecification};
use common::{
column_identifier_no_alias, parse_comment, sql_identifier, statement_terminator,
schema_table_reference, type_identifier, ws_sep_comma, Literal, Real, SqlType, TableKey,
column_identifier_no_alias, parse_comment, schema_table_reference, sql_identifier,
statement_terminator, type_identifier, ws_sep_comma, Literal, Real, SqlType, TableKey,
};
use compound_select::{compound_selection, CompoundSelectStatement};
use create_table_options::table_options;
Expand Down Expand Up @@ -534,7 +534,7 @@ mod tests {
assert_eq!(
res.unwrap().1,
CreateTableStatement {
table: Table::from(("db1","t")),
table: Table::from(("db1", "t")),
fields: vec![ColumnSpecification::new(
Column::from("t.x"),
SqlType::Int(32)
Expand Down
4 changes: 2 additions & 2 deletions src/delete.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use nom::character::complete::multispace1;
use std::{fmt, str};

use common::{statement_terminator, schema_table_reference};
use common::{schema_table_reference, statement_terminator};
use condition::ConditionExpression;
use keywords::escape_if_keyword;
use nom::bytes::complete::tag_no_case;
Expand Down Expand Up @@ -77,7 +77,7 @@ mod tests {
assert_eq!(
res.unwrap().1,
DeleteStatement {
table: Table::from(("db1","users")),
table: Table::from(("db1", "users")),
..Default::default()
}
);
Expand Down
4 changes: 2 additions & 2 deletions src/insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::str;

use column::Column;
use common::{
assignment_expr_list, field_list, statement_terminator, schema_table_reference, value_list,
assignment_expr_list, field_list, schema_table_reference, statement_terminator, value_list,
ws_sep_comma, FieldValueExpression, Literal,
};
use keywords::escape_if_keyword;
Expand Down Expand Up @@ -145,7 +145,7 @@ mod tests {
assert_eq!(
res.unwrap().1,
InsertStatement {
table: Table::from(("db1","users")),
table: Table::from(("db1", "users")),
fields: None,
data: vec![vec![42.into(), "test".into()]],
..Default::default()
Expand Down
Loading