Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: pass parent to get_node_properties and add support for create aggregate with order by #72

Merged
merged 2 commits into from
Dec 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 51 additions & 1 deletion crates/codegen/src/get_node_properties.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ pub fn get_node_properties_mod(proto_file: &ProtoFile) -> proc_macro2::TokenStre
}
}

pub fn get_node_properties(node: &NodeEnum) -> Vec<TokenProperty> {
pub fn get_node_properties(node: &NodeEnum, parent: Option<&NodeEnum>) -> Vec<TokenProperty> {
let mut tokens: Vec<TokenProperty> = Vec::new();

match node {
Expand Down Expand Up @@ -481,6 +481,56 @@ fn custom_handlers(node: &Node) -> TokenStream {
tokens.push(TokenProperty::from(Token::As));
}
},
"List" => quote! {
if parent.is_some() {
// if parent is `DefineStmt`, we need to check whether an ORDER BY needs to be added
if let NodeEnum::DefineStmt(define_stmt) = parent.unwrap() {
// there *seems* to be an integer node in the last position of the DefineStmt args that
// defines whether the list contains an order by statement
let integer = define_stmt.args.last()
.and_then(|node| node.node.as_ref())
.and_then(|node| if let NodeEnum::Integer(n) = node { Some(n.ival) } else { None });
if integer.is_none() {
panic!("DefineStmt of type ObjectAggregate has no integer node in last position of args");
}
// if the integer is 1, then there is an order by statement
// we add it to the `List` node because that seems to make most sense based off the grammar definition
// ref: https://github.com/postgres/postgres/blob/REL_15_STABLE/src/backend/parser/gram.y#L8355
// ```
// aggr_args:
// | '(' aggr_args_list ORDER BY aggr_args_list ')'
// ```
if integer.unwrap() == 1 {
tokens.push(TokenProperty::from(Token::Order));
tokens.push(TokenProperty::from(Token::By));
}
}
}
},
"DefineStmt" => quote! {
tokens.push(TokenProperty::from(Token::Create));
if n.replace {
tokens.push(TokenProperty::from(Token::Or));
tokens.push(TokenProperty::from(Token::Replace));
}
match n.kind() {
protobuf::ObjectType::ObjectAggregate => {
tokens.push(TokenProperty::from(Token::Aggregate));

// n.args is always an array with two nodes
assert_eq!(n.args.len(), 2, "DefineStmt of type ObjectAggregate does not have exactly 2 args");
// the first is either a List or a Node { node: None }

if let Some(node) = &n.args.first() {
if node.node.is_none() {
// if first element is a Node { node: None }, then it's "*"
tokens.push(TokenProperty::from(Token::Ascii42));
} }
// if its a list, we handle it in the handler for `List`
},
_ => panic!("Unknown DefineStmt {:#?}", n.kind()),
}
},
"CreateSchemaStmt" => quote! {
tokens.push(TokenProperty::from(Token::Create));
tokens.push(TokenProperty::from(Token::Schema));
Expand Down
6 changes: 3 additions & 3 deletions crates/codegen/src/get_nodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pub fn get_nodes_mod(proto_file: &ProtoFile) -> proc_macro2::TokenStream {
let root_node_idx = g.add_node(Node {
kind: SyntaxKind::from(node),
depth: at_depth,
properties: get_node_properties(node),
properties: get_node_properties(node, None),
location: get_location(node),
});

Expand All @@ -45,12 +45,12 @@ pub fn get_nodes_mod(proto_file: &ProtoFile) -> proc_macro2::TokenStream {
NodeEnum::BitString(n) => true,
_ => false
} {
g[parent_idx].properties.extend(get_node_properties(&c));
g[parent_idx].properties.extend(get_node_properties(&c, Some(&node)));
} else {
let node_idx = g.add_node(Node {
kind: SyntaxKind::from(&c),
depth: current_depth,
properties: get_node_properties(&c),
properties: get_node_properties(&c, Some(&node)),
location: get_location(&c),
});
g.add_edge(parent_idx, node_idx, ());
Expand Down
1 change: 1 addition & 0 deletions crates/parser/src/parse/libpg_query_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ impl<'p> LibpgQueryNodeParser<'p> {
) -> LibpgQueryNodeParser<'p> {
let current_depth = parser.depth.clone();
debug!("Parsing node {:#?}", node);
println!("Parsing node {:#?}", node);
Self {
parser,
token_range,
Expand Down
10 changes: 10 additions & 0 deletions crates/parser/tests/data/statements/valid/0039.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
CREATE AGGREGATE aggregate1 (int4) (sfunc = sfunc1, stype = stype1);
CREATE AGGREGATE aggregate1 (int4, bool) (sfunc = sfunc1, stype = stype1);
CREATE AGGREGATE aggregate1 (*) (sfunc = sfunc1, stype = stype1);
CREATE AGGREGATE aggregate1 (int4) (sfunc = sfunc1, stype = stype1, finalfunc_extra, mfinalfuncextra);
CREATE AGGREGATE aggregate1 (int4) (sfunc = sfunc1, stype = stype1, finalfunc_modify = read_only, parallel = restricted);
CREATE AGGREGATE percentile_disc (float8 ORDER BY anyelement) (sfunc = ordered_set_transition, stype = internal, finalfunc = percentile_disc_final, finalfunc_extra);
CREATE AGGREGATE custom_aggregate (float8 ORDER BY column1, column2) (sfunc = sfunc1, stype = stype1);



Loading