Skip to content

Commit

Permalink
fix: more tests and minor compatibility fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
psteinroe committed Aug 10, 2024
1 parent 2f17cc6 commit 866ee74
Show file tree
Hide file tree
Showing 4 changed files with 326 additions and 107 deletions.
2 changes: 1 addition & 1 deletion crates/pg_statement_splitter/src/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,7 @@ pub static STATEMENT_DEFINITIONS: LazyLock<HashMap<SyntaxKind, Vec<StatementDefi
stmt: SyntaxKind::CreateSchemaStmt,
tokens: vec![
SyntaxDefinition::RequiredToken(SyntaxKind::Create),
SyntaxDefinition::OptionalToken(SyntaxKind::Schema),
SyntaxDefinition::RequiredToken(SyntaxKind::Schema),
],
});

Expand Down
32 changes: 32 additions & 0 deletions crates/pg_statement_splitter/src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,38 @@ impl Parser {
self.pos == self.tokens.len()
}

/// lookbehind method.
///
/// if `ignore_whitespace` is true, it will skip all whitespace tokens
pub fn lookbehind(&self, lookbehind: usize, ignore_whitespace: bool) -> Option<&Token> {
if ignore_whitespace {
let mut idx = 0;
let mut non_whitespace_token_ctr = 0;
loop {
match self.tokens.get(self.pos - idx) {
Some(token) => {
if !WHITESPACE_TOKENS.contains(&token.kind) {
non_whitespace_token_ctr += 1;
if non_whitespace_token_ctr == lookbehind {
return Some(token);
}
}
idx += 1;
}
None => {
if (self.pos - idx) > 0 {
idx += 1;
} else {
return None;
}
}
}
}
} else {
self.tokens.get(self.pos - lookbehind)
}
}

/// lookahead method.
///
/// if `ignore_whitespace` is true, it will skip all whitespace tokens
Expand Down
209 changes: 198 additions & 11 deletions crates/pg_statement_splitter/src/statement_splitter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@ impl<'a> StatementSplitter<'a> {

while !self.parser.eof() {
let at_token = self.parser.nth(0, false);
// println!("{:?}", at_token.kind);
// println!(
// "tracked stmts before {:?}",
// self.tracked_statements
// .iter()
// .map(|s| s.def.stmt)
// .collect::<Vec<_>>()
// );
println!("{:?}", at_token.kind);
println!(
"tracked stmts before {:?}",
self.tracked_statements
.iter()
.map(|s| s.def.stmt)
.collect::<Vec<_>>()
);
// TODO rename vars and add helpers to make distinciton between pos and text pos clear

if at_token.kind == SyntaxKind::BeginP {
Expand All @@ -59,8 +59,44 @@ impl<'a> StatementSplitter<'a> {
self.sub_stmt_depth -= 1;
}

self.tracked_statements
.retain_mut(|stmt| stmt.advance_with(&at_token.kind));
let mut removed_items = Vec::new();

self.tracked_statements.retain_mut(|stmt| {
let keep = stmt.advance_with(&at_token.kind);
if !keep {
removed_items.push(stmt.started_at);
}
keep
});

if self.tracked_statements.len() == 0 && removed_items.len() > 0 {
let any_stmt_after = removed_items.iter().min().unwrap();
println!("adding any statement: {:?}", any_stmt_after,);
ranges.push(StatementPosition {
kind: SyntaxKind::Any,
range: TextRange::new(
TextSize::try_from(
self.parser
.tokens
.get(*any_stmt_after)
.unwrap()
.span
.start(),
)
.unwrap(),
TextSize::try_from(self.parser.lookbehind(2, true).unwrap().span.end())
.unwrap(),
),
});
}

println!(
"tracked stmts after advance {:?}",
self.tracked_statements
.iter()
.map(|s| s.def.stmt)
.collect::<Vec<_>>()
);

if self.sub_trx_depth == 0
&& self.sub_stmt_depth == 0
Expand Down Expand Up @@ -103,6 +139,71 @@ impl<'a> StatementSplitter<'a> {
);
}

println!(
"tracked stmts after {:?}",
self.tracked_statements
.iter()
.map(|s| s.def.stmt)
.collect::<Vec<_>>()
);

if at_token.kind == SyntaxKind::Ascii59 {
// ;
// get earliest statement
if let Some(earliest_complete_stmt_started_at) = self
.tracked_statements
.iter()
.filter(|s| s.could_be_complete())
.min_by_key(|stmt| stmt.started_at)
.map(|stmt| stmt.started_at)
{
let earliest_complete_stmt = self
.tracked_statements
.iter()
.filter(|s| {
s.started_at == earliest_complete_stmt_started_at
&& s.could_be_complete()
})
.max_by_key(|stmt| stmt.current_pos)
.unwrap();

assert_eq!(
1,
self.tracked_statements
.iter()
.filter(|s| {
s.started_at == earliest_complete_stmt_started_at
&& s.could_be_complete()
&& s.current_pos == earliest_complete_stmt.current_pos
})
.count(),
"multiple complete statements at the same position"
);

let end_pos = at_token.span.end();
let start_pos = TextSize::try_from(
self.parser
.tokens
.get(earliest_complete_stmt.started_at)
.unwrap()
.span
.start(),
)
.unwrap();
println!(
"adding stmt from ';': {:?}",
earliest_complete_stmt.def.stmt
);
ranges.push(StatementPosition {
kind: earliest_complete_stmt.def.stmt,
range: TextRange::new(start_pos, end_pos),
});
}

self.tracked_statements.clear();
self.active_bridges.clear();
}

// if a statement is complete, check if there are any complete statements that start
// before the just completed one

Expand Down Expand Up @@ -167,6 +268,7 @@ impl<'a> StatementSplitter<'a> {
.parser
.tokens
.iter()
// .skip(latest_completed_stmt_started_at)
.filter_map(|t| {
if t.span.start() < latest_text_pos
&& !WHITESPACE_TOKENS.contains(&t.kind)
Expand All @@ -179,7 +281,7 @@ impl<'a> StatementSplitter<'a> {
.max()
.unwrap();

// println!("adding stmt: {:?}", latest_complete_before.def.stmt);
println!("adding stmt: {:?}", latest_complete_before.def.stmt);

ranges.push(StatementPosition {
kind: latest_complete_before.def.stmt,
Expand Down Expand Up @@ -247,6 +349,7 @@ impl<'a> StatementSplitter<'a> {
.parser
.tokens
.iter()
.skip(earliest_complete_stmt.started_at)
.filter_map(|t| {
if t.span.start() > earliest_text_pos && !WHITESPACE_TOKENS.contains(&t.kind) {
Some(t.span.end())
Expand All @@ -265,10 +368,38 @@ impl<'a> StatementSplitter<'a> {
.start(),
)
.unwrap();
println!("adding stmt at end: {:?}", earliest_complete_stmt.def.stmt);
println!("start: {:?}, end: {:?}", start_pos, end_pos);
ranges.push(StatementPosition {
kind: earliest_complete_stmt.def.stmt,
range: TextRange::new(start_pos, end_pos),
});

self.tracked_statements
.retain(|s| s.started_at > earliest_complete_stmt_started_at);
}

if let Some(earliest_stmt_started_at) = self
.tracked_statements
.iter()
.min_by_key(|stmt| stmt.started_at)
.map(|stmt| stmt.started_at)
{
let start_pos = TextSize::try_from(
self.parser
.tokens
.get(earliest_stmt_started_at)
.unwrap()
.span
.start(),
);
// end position is last non-whitespace token before or at the current position
let end_pos = TextSize::try_from(self.parser.lookbehind(1, true).unwrap().span.end());
println!("adding any stmt at end");
ranges.push(StatementPosition {
kind: SyntaxKind::Any,
range: TextRange::new(start_pos.unwrap(), end_pos.unwrap()),
});
}

ranges
Expand Down Expand Up @@ -494,5 +625,61 @@ DROP ROLE IF EXISTS regress_alter_generic_user1;";
let result = StatementSplitter::new(input).run();

assert_eq!(result.len(), 3);
assert_eq!(
"CREATE FUNCTION test_opclass_options_func(internal)\n RETURNS void\n AS :'regresslib', 'test_opclass_options_func'\n LANGUAGE C;",
input[result[0].range].to_string()
);
assert_eq!(SyntaxKind::CreateFunctionStmt, result[0].kind);
assert_eq!(
"SET client_min_messages TO 'warning';",
input[result[1].range].to_string()
);
assert_eq!(SyntaxKind::VariableSetStmt, result[1].kind);
assert_eq!(
"DROP ROLE IF EXISTS regress_alter_generic_user1;",
input[result[2].range].to_string()
);
assert_eq!(SyntaxKind::DropRoleStmt, result[2].kind);
}

#[test]
fn test_incomplete_statement() {
let input = "create\nselect 1;";

let result = StatementSplitter::new(input).run();

for r in &result {
println!("{:?} {:?}", r.kind, input[r.range].to_string());
}

assert_eq!(result.len(), 2);
assert_eq!("create", input[result[0].range].to_string());
assert_eq!(SyntaxKind::Any, result[0].kind);
assert_eq!("select 1;", input[result[1].range].to_string());
assert_eq!(SyntaxKind::SelectStmt, result[1].kind);
}

#[test]
fn test_incomplete_statement_at_end() {
let input = "select 1;\ncreate";

let result = StatementSplitter::new(input).run();

assert_eq!(result.len(), 2);
assert_eq!("select 1;", input[result[0].range].to_string());
assert_eq!(SyntaxKind::SelectStmt, result[0].kind);
assert_eq!("create", input[result[1].range].to_string());
assert_eq!(SyntaxKind::Any, result[1].kind);
}

#[test]
fn test_only_incomplete_statement() {
let input = " create ";

let result = StatementSplitter::new(input).run();

assert_eq!(result.len(), 1);
assert_eq!("create", input[result[0].range].to_string());
assert_eq!(SyntaxKind::Any, result[0].kind);
}
}
Loading

0 comments on commit 866ee74

Please sign in to comment.