From 7b27cdc614c673609b052eaeddad1ab5b3d3649c Mon Sep 17 00:00:00 2001 From: Julian Domke <68325451+juleswritescode@users.noreply.github.com> Date: Fri, 25 Apr 2025 20:18:31 +0200 Subject: [PATCH] fix(completions): improved completion in delete/update clauses (#371) * ok * fixie fixie * Update crates/pgt_completions/src/relevance/filtering.rs * set to workspace cargo.toml * OCD * seems valid * just some help * fixies * hope this works * ok * takin shape * ok * why * format too * this --- crates/pgt_cli/src/execute/mod.rs | 9 +- crates/pgt_completions/src/context.rs | 87 ++++++++++++++--- .../pgt_completions/src/providers/schemas.rs | 72 ++++++++++---- .../pgt_completions/src/providers/tables.rs | 97 ++++++++++++++++++- .../src/relevance/filtering.rs | 10 ++ .../pgt_completions/src/relevance/scoring.rs | 59 ++++++++--- crates/pgt_completions/src/test_helper.rs | 47 ++++++++- .../src/features/code_actions.rs | 1 - 8 files changed, 323 insertions(+), 59 deletions(-) diff --git a/crates/pgt_cli/src/execute/mod.rs b/crates/pgt_cli/src/execute/mod.rs index 90a5bb98..6cb01ca7 100644 --- a/crates/pgt_cli/src/execute/mod.rs +++ b/crates/pgt_cli/src/execute/mod.rs @@ -76,12 +76,11 @@ pub enum TraversalMode { Dummy, /// This mode is enabled when running the command `check` Check { - /// The type of fixes that should be applied when analyzing a file. - /// - /// It's [None] if the `check` command is called without `--apply` or `--apply-suggested` - /// arguments. + // The type of fixes that should be applied when analyzing a file. + // + // It's [None] if the `check` command is called without `--apply` or `--apply-suggested` + // arguments. // fix_file_mode: Option, - /// An optional tuple. /// 1. The virtual path to the file /// 2. The content of the file diff --git a/crates/pgt_completions/src/context.rs b/crates/pgt_completions/src/context.rs index 6005e07b..b16fd21c 100644 --- a/crates/pgt_completions/src/context.rs +++ b/crates/pgt_completions/src/context.rs @@ -30,7 +30,7 @@ impl TryFrom<&str> for ClauseType { match value { "select" => Ok(Self::Select), "where" => Ok(Self::Where), - "from" | "keyword_from" => Ok(Self::From), + "from" => Ok(Self::From), "update" => Ok(Self::Update), "delete" => Ok(Self::Delete), _ => { @@ -49,8 +49,52 @@ impl TryFrom<&str> for ClauseType { impl TryFrom for ClauseType { type Error = String; - fn try_from(value: String) -> Result { - ClauseType::try_from(value.as_str()) + fn try_from(value: String) -> Result { + Self::try_from(value.as_str()) + } +} + +/// We can map a few nodes, such as the "update" node, to actual SQL clauses. +/// That gives us a lot of insight for completions. +/// Other nodes, such as the "relation" node, gives us less but still +/// relevant information. +/// `WrappingNode` maps to such nodes. +/// +/// Note: This is not the direct parent of the `node_under_cursor`, but the closest +/// *relevant* parent. +#[derive(Debug, PartialEq, Eq)] +pub enum WrappingNode { + Relation, + BinaryExpression, + Assignment, +} + +impl TryFrom<&str> for WrappingNode { + type Error = String; + + fn try_from(value: &str) -> Result { + match value { + "relation" => Ok(Self::Relation), + "assignment" => Ok(Self::Assignment), + "binary_expression" => Ok(Self::BinaryExpression), + _ => { + let message = format!("Unimplemented Relation: {}", value); + + // Err on tests, so we notice that we're lacking an implementation immediately. + if cfg!(test) { + panic!("{}", message); + } + + Err(message) + } + } + } +} + +impl TryFrom for WrappingNode { + type Error = String; + fn try_from(value: String) -> Result { + Self::try_from(value.as_str()) } } @@ -64,6 +108,9 @@ pub(crate) struct CompletionContext<'a> { pub schema_name: Option, pub wrapping_clause_type: Option, + + pub wrapping_node_kind: Option, + pub is_invocation: bool, pub wrapping_statement_range: Option, @@ -80,6 +127,7 @@ impl<'a> CompletionContext<'a> { node_under_cursor: None, schema_name: None, wrapping_clause_type: None, + wrapping_node_kind: None, wrapping_statement_range: None, is_invocation: false, mentioned_relations: HashMap::new(), @@ -133,6 +181,15 @@ impl<'a> CompletionContext<'a> { }) } + pub fn get_node_under_cursor_content(&self) -> Option { + self.node_under_cursor + .and_then(|n| self.get_ts_node_content(n)) + .and_then(|txt| match txt { + NodeText::Replaced => None, + NodeText::Original(c) => Some(c.to_string()), + }) + } + fn gather_tree_context(&mut self) { let mut cursor = self.tree.root_node().walk(); @@ -163,15 +220,18 @@ impl<'a> CompletionContext<'a> { ) { let current_node = cursor.node(); + let parent_node_kind = parent_node.kind(); + let current_node_kind = current_node.kind(); + // prevent infinite recursion – this can happen if we only have a PROGRAM node - if current_node.kind() == parent_node.kind() { + if current_node_kind == parent_node_kind { self.node_under_cursor = Some(current_node); return; } - match parent_node.kind() { + match parent_node_kind { "statement" | "subquery" => { - self.wrapping_clause_type = current_node.kind().try_into().ok(); + self.wrapping_clause_type = current_node_kind.try_into().ok(); self.wrapping_statement_range = Some(parent_node.range()); } "invocation" => self.is_invocation = true, @@ -179,7 +239,7 @@ impl<'a> CompletionContext<'a> { _ => {} } - match current_node.kind() { + match current_node_kind { "object_reference" => { let content = self.get_ts_node_content(current_node); if let Some(node_txt) = content { @@ -195,13 +255,12 @@ impl<'a> CompletionContext<'a> { } } - // in Treesitter, the Where clause is nested inside other clauses - "where" => { - self.wrapping_clause_type = "where".try_into().ok(); + "where" | "update" | "select" | "delete" | "from" => { + self.wrapping_clause_type = current_node_kind.try_into().ok(); } - "keyword_from" => { - self.wrapping_clause_type = "keyword_from".try_into().ok(); + "relation" | "binary_expression" | "assignment" => { + self.wrapping_node_kind = current_node_kind.try_into().ok(); } _ => {} @@ -406,10 +465,6 @@ mod tests { ctx.get_ts_node_content(node), Some(NodeText::Original("from")) ); - assert_eq!( - ctx.wrapping_clause_type, - Some(crate::context::ClauseType::From) - ); } #[test] diff --git a/crates/pgt_completions/src/providers/schemas.rs b/crates/pgt_completions/src/providers/schemas.rs index eb493d0c..c28f831e 100644 --- a/crates/pgt_completions/src/providers/schemas.rs +++ b/crates/pgt_completions/src/providers/schemas.rs @@ -27,8 +27,8 @@ pub fn complete_schemas<'a>(ctx: &'a CompletionContext, builder: &mut Completion mod tests { use crate::{ - CompletionItemKind, complete, - test_helper::{CURSOR_POS, get_test_deps, get_test_params}, + CompletionItemKind, + test_helper::{CURSOR_POS, CompletionAssertion, assert_complete_results}, }; #[tokio::test] @@ -46,27 +46,59 @@ mod tests { ); "#; - let query = format!("select * from {}", CURSOR_POS); + assert_complete_results( + format!("select * from {}", CURSOR_POS).as_str(), + vec![ + CompletionAssertion::LabelAndKind("public".to_string(), CompletionItemKind::Schema), + CompletionAssertion::LabelAndKind("auth".to_string(), CompletionItemKind::Schema), + CompletionAssertion::LabelAndKind( + "internal".to_string(), + CompletionItemKind::Schema, + ), + CompletionAssertion::LabelAndKind( + "private".to_string(), + CompletionItemKind::Schema, + ), + CompletionAssertion::LabelAndKind( + "information_schema".to_string(), + CompletionItemKind::Schema, + ), + CompletionAssertion::LabelAndKind( + "pg_catalog".to_string(), + CompletionItemKind::Schema, + ), + CompletionAssertion::LabelAndKind( + "pg_toast".to_string(), + CompletionItemKind::Schema, + ), + CompletionAssertion::LabelAndKind("users".to_string(), CompletionItemKind::Table), + ], + setup, + ) + .await; + } - let (tree, cache) = get_test_deps(setup, query.as_str().into()).await; - let params = get_test_params(&tree, &cache, query.as_str().into()); - let items = complete(params); + #[tokio::test] + async fn suggests_tables_and_schemas_with_matching_keys() { + let setup = r#" + create schema ultimate; - assert!(!items.is_empty()); + -- add a table to compete against schemas + create table users ( + id serial primary key, + name text, + password text + ); + "#; - assert_eq!( - items - .into_iter() - .take(5) - .map(|i| (i.label, i.kind)) - .collect::>(), + assert_complete_results( + format!("select * from u{}", CURSOR_POS).as_str(), vec![ - ("public".to_string(), CompletionItemKind::Schema), - ("auth".to_string(), CompletionItemKind::Schema), - ("internal".to_string(), CompletionItemKind::Schema), - ("private".to_string(), CompletionItemKind::Schema), - ("users".to_string(), CompletionItemKind::Table), - ] - ); + CompletionAssertion::LabelAndKind("users".into(), CompletionItemKind::Table), + CompletionAssertion::LabelAndKind("ultimate".into(), CompletionItemKind::Schema), + ], + setup, + ) + .await; } } diff --git a/crates/pgt_completions/src/providers/tables.rs b/crates/pgt_completions/src/providers/tables.rs index 1da77e15..f9f922d1 100644 --- a/crates/pgt_completions/src/providers/tables.rs +++ b/crates/pgt_completions/src/providers/tables.rs @@ -31,7 +31,10 @@ mod tests { use crate::{ CompletionItem, CompletionItemKind, complete, - test_helper::{CURSOR_POS, get_test_deps, get_test_params}, + test_helper::{ + CURSOR_POS, CompletionAssertion, assert_complete_results, assert_no_complete_results, + get_test_deps, get_test_params, + }, }; #[tokio::test] @@ -178,4 +181,96 @@ mod tests { assert_eq!(label, "coos"); assert_eq!(kind, CompletionItemKind::Table); } + + #[tokio::test] + async fn suggests_tables_in_update() { + let setup = r#" + create table coos ( + id serial primary key, + name text + ); + "#; + + assert_complete_results( + format!("update {}", CURSOR_POS).as_str(), + vec![CompletionAssertion::LabelAndKind( + "public".into(), + CompletionItemKind::Schema, + )], + setup, + ) + .await; + + assert_complete_results( + format!("update public.{}", CURSOR_POS).as_str(), + vec![CompletionAssertion::LabelAndKind( + "coos".into(), + CompletionItemKind::Table, + )], + setup, + ) + .await; + + assert_no_complete_results(format!("update public.coos {}", CURSOR_POS).as_str(), setup) + .await; + + assert_complete_results( + format!("update coos set {}", CURSOR_POS).as_str(), + vec![ + CompletionAssertion::Label("id".into()), + CompletionAssertion::Label("name".into()), + ], + setup, + ) + .await; + + assert_complete_results( + format!("update coos set name = 'cool' where {}", CURSOR_POS).as_str(), + vec![ + CompletionAssertion::Label("id".into()), + CompletionAssertion::Label("name".into()), + ], + setup, + ) + .await; + } + + #[tokio::test] + async fn suggests_tables_in_delete() { + let setup = r#" + create table coos ( + id serial primary key, + name text + ); + "#; + + assert_no_complete_results(format!("delete {}", CURSOR_POS).as_str(), setup).await; + + assert_complete_results( + format!("delete from {}", CURSOR_POS).as_str(), + vec![ + CompletionAssertion::LabelAndKind("public".into(), CompletionItemKind::Schema), + CompletionAssertion::LabelAndKind("coos".into(), CompletionItemKind::Table), + ], + setup, + ) + .await; + + assert_complete_results( + format!("delete from public.{}", CURSOR_POS).as_str(), + vec![CompletionAssertion::Label("coos".into())], + setup, + ) + .await; + + assert_complete_results( + format!("delete from public.coos where {}", CURSOR_POS).as_str(), + vec![ + CompletionAssertion::Label("id".into()), + CompletionAssertion::Label("name".into()), + ], + setup, + ) + .await; + } } diff --git a/crates/pgt_completions/src/relevance/filtering.rs b/crates/pgt_completions/src/relevance/filtering.rs index 214fda56..69939e0b 100644 --- a/crates/pgt_completions/src/relevance/filtering.rs +++ b/crates/pgt_completions/src/relevance/filtering.rs @@ -35,6 +35,16 @@ impl CompletionFilter<'_> { return None; } + // No autocompletions if there are two identifiers without a separator. + if ctx.node_under_cursor.is_some_and(|n| { + n.prev_sibling().is_some_and(|p| { + (p.kind() == "identifier" || p.kind() == "object_reference") + && n.kind() == "identifier" + }) + }) { + return None; + } + Some(()) } diff --git a/crates/pgt_completions/src/relevance/scoring.rs b/crates/pgt_completions/src/relevance/scoring.rs index 7c3f3a06..2ef8edb6 100644 --- a/crates/pgt_completions/src/relevance/scoring.rs +++ b/crates/pgt_completions/src/relevance/scoring.rs @@ -1,4 +1,4 @@ -use crate::context::{ClauseType, CompletionContext, NodeText}; +use crate::context::{ClauseType, CompletionContext, WrappingNode}; use super::CompletionRelevanceData; @@ -28,20 +28,13 @@ impl CompletionScore<'_> { self.check_matches_query_input(ctx); self.check_is_invocation(ctx); self.check_matching_clause_type(ctx); + self.check_matching_wrapping_node(ctx); self.check_relations_in_stmt(ctx); } fn check_matches_query_input(&mut self, ctx: &CompletionContext) { - let node = match ctx.node_under_cursor { - Some(node) => node, - None => return, - }; - - let content = match ctx.get_ts_node_content(node) { - Some(c) => match c { - NodeText::Original(s) => s, - NodeText::Replaced => return, - }, + let content = match ctx.get_node_under_cursor_content() { + Some(c) => c, None => return, }; @@ -52,7 +45,7 @@ impl CompletionScore<'_> { CompletionRelevanceData::Schema(s) => s.name.as_str(), }; - if name.starts_with(content) { + if name.starts_with(content.as_str()) { let len: i32 = content .len() .try_into() @@ -69,12 +62,13 @@ impl CompletionScore<'_> { }; let has_mentioned_tables = !ctx.mentioned_relations.is_empty(); + let has_mentioned_schema = ctx.schema_name.is_some(); self.score += match self.data { CompletionRelevanceData::Table(_) => match clause_type { ClauseType::From => 5, - ClauseType::Update => 15, - ClauseType::Delete => 15, + ClauseType::Update => 10, + ClauseType::Delete => 10, _ => -50, }, CompletionRelevanceData::Function(_) => match clause_type { @@ -90,7 +84,42 @@ impl CompletionScore<'_> { _ => -15, }, CompletionRelevanceData::Schema(_) => match clause_type { - ClauseType::From => 10, + ClauseType::From if !has_mentioned_schema => 15, + ClauseType::Update if !has_mentioned_schema => 15, + ClauseType::Delete if !has_mentioned_schema => 15, + _ => -50, + }, + } + } + + fn check_matching_wrapping_node(&mut self, ctx: &CompletionContext) { + let wrapping_node = match ctx.wrapping_node_kind.as_ref() { + None => return, + Some(wn) => wn, + }; + + let has_mentioned_schema = ctx.schema_name.is_some(); + let has_node_text = ctx.get_node_under_cursor_content().is_some(); + + self.score += match self.data { + CompletionRelevanceData::Table(_) => match wrapping_node { + WrappingNode::Relation if has_mentioned_schema => 15, + WrappingNode::Relation if !has_mentioned_schema => 10, + WrappingNode::BinaryExpression => 5, + _ => -50, + }, + CompletionRelevanceData::Function(_) => match wrapping_node { + WrappingNode::Relation => 10, + _ => -50, + }, + CompletionRelevanceData::Column(_) => match wrapping_node { + WrappingNode::BinaryExpression => 15, + WrappingNode::Assignment => 15, + _ => -15, + }, + CompletionRelevanceData::Schema(_) => match wrapping_node { + WrappingNode::Relation if !has_mentioned_schema && !has_node_text => 15, + WrappingNode::Relation if !has_mentioned_schema && has_node_text => 0, _ => -50, }, } diff --git a/crates/pgt_completions/src/test_helper.rs b/crates/pgt_completions/src/test_helper.rs index b1c5b399..5eb5f53f 100644 --- a/crates/pgt_completions/src/test_helper.rs +++ b/crates/pgt_completions/src/test_helper.rs @@ -4,7 +4,7 @@ use pgt_schema_cache::SchemaCache; use pgt_test_utils::test_database::get_new_test_db; use sqlx::Executor; -use crate::CompletionParams; +use crate::{CompletionItem, CompletionItemKind, CompletionParams, complete}; pub static CURSOR_POS: char = '€'; @@ -141,3 +141,48 @@ mod tests { } } } + +#[derive(Debug, PartialEq, Eq)] +pub(crate) enum CompletionAssertion { + Label(String), + LabelAndKind(String, CompletionItemKind), +} + +impl CompletionAssertion { + fn assert_eq(self, item: CompletionItem) { + match self { + CompletionAssertion::Label(label) => { + assert_eq!(item.label, label); + } + CompletionAssertion::LabelAndKind(label, kind) => { + assert_eq!(item.label, label); + assert_eq!(item.kind, kind); + } + } + } +} + +pub(crate) async fn assert_complete_results( + query: &str, + assertions: Vec, + setup: &str, +) { + let (tree, cache) = get_test_deps(setup, query.into()).await; + let params = get_test_params(&tree, &cache, query.into()); + let items = complete(params); + + assertions + .into_iter() + .zip(items.into_iter()) + .for_each(|(assertion, result)| { + assertion.assert_eq(result); + }); +} + +pub(crate) async fn assert_no_complete_results(query: &str, setup: &str) { + let (tree, cache) = get_test_deps(setup, query.into()).await; + let params = get_test_params(&tree, &cache, query.into()); + let items = complete(params); + + assert_eq!(items.len(), 0) +} diff --git a/crates/pgt_workspace/src/features/code_actions.rs b/crates/pgt_workspace/src/features/code_actions.rs index 5e3cd883..22223dd3 100644 --- a/crates/pgt_workspace/src/features/code_actions.rs +++ b/crates/pgt_workspace/src/features/code_actions.rs @@ -46,7 +46,6 @@ pub struct CommandAction { #[derive(Debug, serde::Serialize, serde::Deserialize, strum::EnumIter)] #[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] - pub enum CommandActionCategory { ExecuteStatement(StatementId), }