Skip to content

Commit ee45ea6

Browse files
authoredMar 30, 2025··
feat: execute stmt under cursor (#257)
1 parent 384f3fc commit ee45ea6

33 files changed

+726
-186
lines changed
 

Diff for: ‎.cargo/config.toml

-3
This file was deleted.

Diff for: ‎.vscode/settings.json

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
{
2+
"postgrestools.bin": "./target/debug/postgrestools"
3+
}

Diff for: ‎Cargo.lock

+33-12
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Diff for: ‎Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ serde = "1.0.195"
3939
serde_json = "1.0.114"
4040
similar = "2.6.0"
4141
smallvec = { version = "1.13.2", features = ["union", "const_new", "serde"] }
42+
strum = { version = "0.27.1", features = ["derive"] }
4243
# this will use tokio if available, otherwise async-std
4344
sqlx = { version = "0.8.2", features = ["runtime-tokio", "runtime-async-std", "postgres", "json"] }
4445
syn = "1.0.109"
@@ -56,7 +57,6 @@ pgt_analyse = { path = "./crates/pgt_analyse", version = "0.0.0"
5657
pgt_analyser = { path = "./crates/pgt_analyser", version = "0.0.0" }
5758
pgt_base_db = { path = "./crates/pgt_base_db", version = "0.0.0" }
5859
pgt_cli = { path = "./crates/pgt_cli", version = "0.0.0" }
59-
pgt_commands = { path = "./crates/pgt_commands", version = "0.0.0" }
6060
pgt_completions = { path = "./crates/pgt_completions", version = "0.0.0" }
6161
pgt_configuration = { path = "./crates/pgt_configuration", version = "0.0.0" }
6262
pgt_console = { path = "./crates/pgt_console", version = "0.0.0" }

Diff for: ‎crates/pgt_commands/Cargo.toml

-21
This file was deleted.

Diff for: ‎crates/pgt_commands/src/command.rs

-32
This file was deleted.

Diff for: ‎crates/pgt_commands/src/execute_statement.rs

-44
This file was deleted.

Diff for: ‎crates/pgt_commands/src/lib.rs

-5
This file was deleted.

Diff for: ‎crates/pgt_configuration/src/database.rs

+5
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use biome_deserialize::StringSet;
12
use biome_deserialize_macros::{Merge, Partial};
23
use bpaf::Bpaf;
34
use serde::{Deserialize, Serialize};
@@ -28,6 +29,9 @@ pub struct DatabaseConfiguration {
2829
#[partial(bpaf(long("database")))]
2930
pub database: String,
3031

32+
#[partial(bpaf(long("allow_statement_executions_against")))]
33+
pub allow_statement_executions_against: StringSet,
34+
3135
/// The connection timeout in seconds.
3236
#[partial(bpaf(long("conn_timeout_secs"), fallback(Some(10)), debug_fallback))]
3337
pub conn_timeout_secs: u16,
@@ -41,6 +45,7 @@ impl Default for DatabaseConfiguration {
4145
username: "postgres".to_string(),
4246
password: "postgres".to_string(),
4347
database: "postgres".to_string(),
48+
allow_statement_executions_against: Default::default(),
4449
conn_timeout_secs: 10,
4550
}
4651
}

Diff for: ‎crates/pgt_configuration/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ impl PartialConfiguration {
111111
password: Some("postgres".to_string()),
112112
database: Some("postgres".to_string()),
113113
conn_timeout_secs: Some(10),
114+
allow_statement_executions_against: Default::default(),
114115
}),
115116
}
116117
}

Diff for: ‎crates/pgt_lsp/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ pgt_workspace = { workspace = true }
2828
rustc-hash = { workspace = true }
2929
serde = { workspace = true, features = ["derive"] }
3030
serde_json = { workspace = true }
31+
strum = { workspace = true }
3132
tokio = { workspace = true, features = ["rt", "io-std"] }
3233
tower-lsp = { version = "0.20.0" }
3334
tracing = { workspace = true, features = ["attributes"] }

Diff for: ‎crates/pgt_lsp/src/capabilities.rs

+18-4
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
use pgt_lsp_converters::{PositionEncoding, WideEncoding, negotiated_encoding};
2+
use pgt_workspace::code_actions::{CommandActionCategory, CommandActionCategoryIter};
3+
use strum::{EnumIter, IntoEnumIterator};
24
use tower_lsp::lsp_types::{
3-
ClientCapabilities, CompletionOptions, PositionEncodingKind, SaveOptions, ServerCapabilities,
4-
TextDocumentSyncCapability, TextDocumentSyncKind, TextDocumentSyncOptions,
5-
TextDocumentSyncSaveOptions, WorkDoneProgressOptions,
5+
ClientCapabilities, CodeActionOptions, CompletionOptions, ExecuteCommandOptions,
6+
PositionEncodingKind, SaveOptions, ServerCapabilities, TextDocumentSyncCapability,
7+
TextDocumentSyncKind, TextDocumentSyncOptions, TextDocumentSyncSaveOptions,
8+
WorkDoneProgressOptions,
69
};
710

11+
use crate::handlers::code_actions::command_id;
12+
813
/// The capabilities to send from server as part of [`InitializeResult`]
914
///
1015
/// [`InitializeResult`]: lspower::lsp::InitializeResult
@@ -46,10 +51,19 @@ pub(crate) fn server_capabilities(capabilities: &ClientCapabilities) -> ServerCa
4651
work_done_progress: None,
4752
},
4853
}),
54+
execute_command_provider: Some(ExecuteCommandOptions {
55+
commands: CommandActionCategory::iter()
56+
.map(|c| command_id(&c))
57+
.collect::<Vec<String>>(),
58+
59+
..Default::default()
60+
}),
4961
document_formatting_provider: None,
5062
document_range_formatting_provider: None,
5163
document_on_type_formatting_provider: None,
52-
code_action_provider: None,
64+
code_action_provider: Some(tower_lsp::lsp_types::CodeActionProviderCapability::Simple(
65+
true,
66+
)),
5367
rename_provider: None,
5468
..Default::default()
5569
}

Diff for: ‎crates/pgt_lsp/src/handlers.rs

+2
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
1+
pub(crate) mod code_actions;
12
pub(crate) mod completions;
3+
mod helper;
24
pub(crate) mod text_document;

Diff for: ‎crates/pgt_lsp/src/handlers/code_actions.rs

+114
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
use crate::session::Session;
2+
use anyhow::{Result, anyhow};
3+
use tower_lsp::lsp_types::{
4+
self, CodeAction, CodeActionDisabled, CodeActionOrCommand, Command, ExecuteCommandParams,
5+
MessageType,
6+
};
7+
8+
use pgt_workspace::code_actions::{
9+
CodeActionKind, CodeActionsParams, CommandActionCategory, ExecuteStatementParams,
10+
};
11+
12+
use super::helper;
13+
14+
pub fn get_actions(
15+
session: &Session,
16+
params: lsp_types::CodeActionParams,
17+
) -> Result<lsp_types::CodeActionResponse> {
18+
let url = params.text_document.uri;
19+
let path = session.file_path(&url)?;
20+
21+
let cursor_position = helper::get_cursor_position(session, &url, params.range.start)?;
22+
23+
let workspace_actions = session.workspace.pull_code_actions(CodeActionsParams {
24+
path,
25+
cursor_position,
26+
only: vec![],
27+
skip: vec![],
28+
})?;
29+
30+
let actions: Vec<CodeAction> = workspace_actions
31+
.actions
32+
.into_iter()
33+
.filter_map(|action| match action.kind {
34+
CodeActionKind::Command(command) => {
35+
let command_id: String = command_id(&command.category);
36+
let title = action.title;
37+
38+
match command.category {
39+
CommandActionCategory::ExecuteStatement(stmt_id) => Some(CodeAction {
40+
title: title.clone(),
41+
kind: Some(lsp_types::CodeActionKind::EMPTY),
42+
command: Some({
43+
Command {
44+
title: title.clone(),
45+
command: command_id,
46+
arguments: Some(vec![
47+
serde_json::Value::Number(stmt_id.into()),
48+
serde_json::to_value(&url).unwrap(),
49+
]),
50+
}
51+
}),
52+
disabled: action
53+
.disabled_reason
54+
.map(|reason| CodeActionDisabled { reason }),
55+
..Default::default()
56+
}),
57+
}
58+
}
59+
60+
_ => todo!(),
61+
})
62+
.collect();
63+
64+
Ok(actions
65+
.into_iter()
66+
.map(|ac| CodeActionOrCommand::CodeAction(ac))
67+
.collect())
68+
}
69+
70+
pub fn command_id(command: &CommandActionCategory) -> String {
71+
match command {
72+
CommandActionCategory::ExecuteStatement(_) => "pgt.executeStatement".into(),
73+
}
74+
}
75+
76+
pub async fn execute_command(
77+
session: &Session,
78+
params: ExecuteCommandParams,
79+
) -> anyhow::Result<Option<serde_json::Value>> {
80+
let command = params.command;
81+
82+
match command.as_str() {
83+
"pgt.executeStatement" => {
84+
let id: usize = serde_json::from_value(params.arguments[0].clone())?;
85+
let doc_url: lsp_types::Url = serde_json::from_value(params.arguments[1].clone())?;
86+
87+
let path = session.file_path(&doc_url)?;
88+
89+
let result = session
90+
.workspace
91+
.execute_statement(ExecuteStatementParams {
92+
statement_id: id,
93+
path,
94+
})?;
95+
96+
/**
97+
* Updating all diagnostics: the changes caused by the statement execution
98+
* might affect many files.
99+
*
100+
* TODO: in test.sql, this seems to work after create table, but not after drop table.
101+
*/
102+
session.update_all_diagnostics().await;
103+
104+
session
105+
.client
106+
.show_message(MessageType::INFO, result.message)
107+
.await;
108+
109+
Ok(None)
110+
}
111+
112+
any => Err(anyhow!(format!("Unknown command: {}", any))),
113+
}
114+
}

Diff for: ‎crates/pgt_lsp/src/handlers/completions.rs

+7-16
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ use anyhow::Result;
33
use pgt_workspace::{WorkspaceError, workspace};
44
use tower_lsp::lsp_types::{self, CompletionItem, CompletionItemLabelDetails};
55

6+
use super::helper;
7+
68
#[tracing::instrument(level = "trace", skip_all)]
79
pub fn get_completions(
810
session: &Session,
@@ -11,27 +13,16 @@ pub fn get_completions(
1113
let url = params.text_document_position.text_document.uri;
1214
let path = session.file_path(&url)?;
1315

14-
let client_capabilities = session
15-
.client_capabilities()
16-
.expect("Client capabilities not established for current session.");
17-
18-
let line_index = session
19-
.document(&url)
20-
.map(|doc| doc.line_index)
21-
.map_err(|_| anyhow::anyhow!("Document not found."))?;
22-
23-
let offset = pgt_lsp_converters::from_proto::offset(
24-
&line_index,
25-
params.text_document_position.position,
26-
pgt_lsp_converters::negotiated_encoding(client_capabilities),
27-
)?;
28-
2916
let completion_result =
3017
match session
3118
.workspace
3219
.get_completions(workspace::GetCompletionsParams {
3320
path,
34-
position: offset,
21+
position: helper::get_cursor_position(
22+
session,
23+
&url,
24+
params.text_document_position.position,
25+
)?,
3526
}) {
3627
Ok(result) => result,
3728
Err(e) => match e {

Diff for: ‎crates/pgt_lsp/src/handlers/helper.rs

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
use crate::session::Session;
2+
use pgt_text_size::TextSize;
3+
use tower_lsp::lsp_types;
4+
5+
pub fn get_cursor_position(
6+
session: &Session,
7+
url: &lsp_types::Url,
8+
position: lsp_types::Position,
9+
) -> anyhow::Result<TextSize> {
10+
let client_capabilities = session
11+
.client_capabilities()
12+
.expect("Client capabilities not established for current session.");
13+
14+
let line_index = session
15+
.document(url)
16+
.map(|doc| doc.line_index)
17+
.map_err(|_| anyhow::anyhow!("Document not found."))?;
18+
19+
let cursor_pos = pgt_lsp_converters::from_proto::offset(
20+
&line_index,
21+
position,
22+
pgt_lsp_converters::negotiated_encoding(client_capabilities),
23+
)?;
24+
25+
Ok(cursor_pos)
26+
}

Diff for: ‎crates/pgt_lsp/src/server.rs

+22
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,28 @@ impl LanguageServer for LSPServer {
252252
Err(e) => LspResult::Err(into_lsp_error(e)),
253253
}
254254
}
255+
256+
#[tracing::instrument(level = "trace", skip(self))]
257+
async fn code_action(&self, params: CodeActionParams) -> LspResult<Option<CodeActionResponse>> {
258+
match handlers::code_actions::get_actions(&self.session, params) {
259+
Ok(result) => {
260+
tracing::info!("Got Code Actions: {:?}", result);
261+
return LspResult::Ok(Some(result));
262+
}
263+
Err(e) => LspResult::Err(into_lsp_error(e)),
264+
}
265+
}
266+
267+
#[tracing::instrument(level = "trace", skip(self))]
268+
async fn execute_command(
269+
&self,
270+
params: ExecuteCommandParams,
271+
) -> LspResult<Option<serde_json::Value>> {
272+
match handlers::code_actions::execute_command(&self.session, params).await {
273+
Ok(result) => LspResult::Ok(None),
274+
Err(err) => LspResult::Err(into_lsp_error(err)),
275+
}
276+
}
255277
}
256278

257279
impl Drop for LSPServer {

Diff for: ‎crates/pgt_lsp/tests/server.rs

+146
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ use pgt_test_utils::test_database::get_new_test_db;
1717
use pgt_workspace::DynRef;
1818
use serde::Serialize;
1919
use serde::de::DeserializeOwned;
20+
use serde_json::Value;
2021
use serde_json::{from_value, to_value};
2122
use sqlx::Executor;
2223
use std::any::type_name;
@@ -28,8 +29,13 @@ use tower_lsp::LspService;
2829
use tower_lsp::jsonrpc;
2930
use tower_lsp::jsonrpc::Response;
3031
use tower_lsp::lsp_types as lsp;
32+
use tower_lsp::lsp_types::CodeActionContext;
33+
use tower_lsp::lsp_types::CodeActionOrCommand;
34+
use tower_lsp::lsp_types::CodeActionParams;
35+
use tower_lsp::lsp_types::CodeActionResponse;
3136
use tower_lsp::lsp_types::CompletionParams;
3237
use tower_lsp::lsp_types::CompletionResponse;
38+
use tower_lsp::lsp_types::ExecuteCommandParams;
3339
use tower_lsp::lsp_types::PartialResultParams;
3440
use tower_lsp::lsp_types::Position;
3541
use tower_lsp::lsp_types::Range;
@@ -551,3 +557,143 @@ async fn test_completions() -> Result<()> {
551557

552558
Ok(())
553559
}
560+
561+
#[tokio::test]
562+
async fn test_execute_statement() -> Result<()> {
563+
let factory = ServerFactory::default();
564+
let mut fs = MemoryFileSystem::default();
565+
let test_db = get_new_test_db().await;
566+
567+
let database = test_db
568+
.connect_options()
569+
.get_database()
570+
.unwrap()
571+
.to_string();
572+
let host = test_db.connect_options().get_host().to_string();
573+
574+
let conf = PartialConfiguration {
575+
db: Some(PartialDatabaseConfiguration {
576+
database: Some(database),
577+
host: Some(host),
578+
..Default::default()
579+
}),
580+
..Default::default()
581+
};
582+
583+
fs.insert(
584+
url!("postgrestools.jsonc").to_file_path().unwrap(),
585+
serde_json::to_string_pretty(&conf).unwrap(),
586+
);
587+
588+
let (service, client) = factory
589+
.create_with_fs(None, DynRef::Owned(Box::new(fs)))
590+
.into_inner();
591+
592+
let (stream, sink) = client.split();
593+
let mut server = Server::new(service);
594+
595+
let (sender, _) = channel(CHANNEL_BUFFER_SIZE);
596+
let reader = tokio::spawn(client_handler(stream, sink, sender));
597+
598+
server.initialize().await?;
599+
server.initialized().await?;
600+
601+
server.load_configuration().await?;
602+
603+
let users_tbl_exists = async || {
604+
let result = sqlx::query!(
605+
r#"
606+
select exists (
607+
select 1 as exists
608+
from pg_catalog.pg_tables
609+
where tablename = 'users'
610+
);
611+
"#
612+
)
613+
.fetch_one(&test_db.clone())
614+
.await;
615+
616+
result.unwrap().exists.unwrap()
617+
};
618+
619+
assert_eq!(
620+
users_tbl_exists().await,
621+
false,
622+
"The user table shouldn't exist at this point."
623+
);
624+
625+
let doc_content = r#"
626+
create table users (
627+
id serial primary key,
628+
name text,
629+
email text
630+
);
631+
"#;
632+
633+
let doc_url = url!("test.sql");
634+
635+
server
636+
.open_named_document(doc_content.to_string(), doc_url.clone())
637+
.await?;
638+
639+
let code_actions_response = server
640+
.request::<CodeActionParams, CodeActionResponse>(
641+
"textDocument/codeAction",
642+
"_code_action",
643+
CodeActionParams {
644+
text_document: TextDocumentIdentifier {
645+
uri: doc_url.clone(),
646+
},
647+
range: Range {
648+
start: Position::new(3, 7),
649+
end: Position::new(3, 7),
650+
}, // just somewhere within the statement.
651+
context: CodeActionContext::default(),
652+
partial_result_params: PartialResultParams::default(),
653+
work_done_progress_params: WorkDoneProgressParams::default(),
654+
},
655+
)
656+
.await?
657+
.unwrap();
658+
659+
let exec_statement_command: (String, Vec<Value>) = code_actions_response
660+
.iter()
661+
.find_map(|action_or_cmd| match action_or_cmd {
662+
lsp::CodeActionOrCommand::CodeAction(code_action) => {
663+
let command = code_action.command.as_ref();
664+
if command.is_some_and(|cmd| &cmd.command == "pgt.executeStatement") {
665+
let command = command.unwrap();
666+
let arguments = command.arguments.as_ref().unwrap().clone();
667+
Some((command.command.clone(), arguments))
668+
} else {
669+
None
670+
}
671+
}
672+
673+
_ => None,
674+
})
675+
.expect("Did not find executeStatement command!");
676+
677+
server
678+
.request::<ExecuteCommandParams, Option<Value>>(
679+
"workspace/executeCommand",
680+
"_execStmt",
681+
ExecuteCommandParams {
682+
command: exec_statement_command.0,
683+
arguments: exec_statement_command.1,
684+
..Default::default()
685+
},
686+
)
687+
.await?;
688+
689+
assert_eq!(
690+
users_tbl_exists().await,
691+
true,
692+
"Users table did not exists even though it should've been created by the workspace/executeStatement command."
693+
);
694+
695+
server.shutdown().await?;
696+
reader.abort();
697+
698+
Ok(())
699+
}

Diff for: ‎crates/pgt_workspace/Cargo.toml

+6-3
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@ version = "0.0.0"
1212

1313

1414
[dependencies]
15-
biome_deserialize = "0.6.0"
16-
dashmap = "5.5.3"
17-
futures = "0.3.31"
15+
biome_deserialize = "0.6.0"
16+
dashmap = "5.5.3"
17+
futures = "0.3.31"
18+
globset = "0.4.16"
19+
1820
ignore = { workspace = true }
1921
pgt_analyse = { workspace = true, features = ["serde"] }
2022
pgt_analyser = { workspace = true }
@@ -33,6 +35,7 @@ schemars = { workspace = true, optional = true }
3335
serde = { workspace = true, features = ["derive"] }
3436
serde_json = { workspace = true, features = ["raw_value"] }
3537
sqlx.workspace = true
38+
strum = { workspace = true }
3639
tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
3740
tracing = { workspace = true, features = ["attributes", "log"] }
3841
tree-sitter.workspace = true

Diff for: ‎crates/pgt_workspace/src/code_actions.rs

+65
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
use crate::workspace::StatementId;
2+
use pgt_configuration::RuleSelector;
3+
use pgt_fs::PgTPath;
4+
use pgt_text_size::TextSize;
5+
6+
#[derive(Debug, serde::Serialize, serde::Deserialize)]
7+
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
8+
pub struct CodeActionsParams {
9+
pub path: PgTPath,
10+
pub cursor_position: TextSize,
11+
pub only: Vec<RuleSelector>,
12+
pub skip: Vec<RuleSelector>,
13+
}
14+
15+
#[derive(Debug, serde::Serialize, serde::Deserialize)]
16+
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
17+
pub struct CodeActionsResult {
18+
pub actions: Vec<CodeAction>,
19+
}
20+
21+
#[derive(Debug, serde::Serialize, serde::Deserialize)]
22+
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
23+
pub struct CodeAction {
24+
pub title: String,
25+
pub kind: CodeActionKind,
26+
pub disabled_reason: Option<String>,
27+
}
28+
29+
#[derive(Debug, serde::Serialize, serde::Deserialize)]
30+
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
31+
pub enum CodeActionKind {
32+
Edit(EditAction),
33+
Command(CommandAction),
34+
EditAndCommand(EditAction, CommandAction),
35+
}
36+
37+
#[derive(Debug, serde::Serialize, serde::Deserialize)]
38+
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
39+
pub struct EditAction {}
40+
41+
#[derive(Debug, serde::Serialize, serde::Deserialize)]
42+
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
43+
pub struct CommandAction {
44+
pub category: CommandActionCategory,
45+
}
46+
47+
#[derive(Debug, serde::Serialize, serde::Deserialize, strum::EnumIter)]
48+
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
49+
50+
pub enum CommandActionCategory {
51+
ExecuteStatement(StatementId),
52+
}
53+
54+
#[derive(Debug, serde::Serialize, serde::Deserialize)]
55+
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
56+
pub struct ExecuteStatementParams {
57+
pub statement_id: StatementId,
58+
pub path: PgTPath,
59+
}
60+
61+
#[derive(Debug, serde::Serialize, serde::Deserialize)]
62+
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
63+
pub struct ExecuteStatementResult {
64+
pub message: String,
65+
}

Diff for: ‎crates/pgt_workspace/src/lib.rs

+2
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@ use std::ops::{Deref, DerefMut};
33
use pgt_console::Console;
44
use pgt_fs::{FileSystem, OsFileSystem};
55

6+
pub mod code_actions;
67
pub mod configuration;
78
pub mod diagnostics;
89
pub mod dome;
910
pub mod matcher;
1011
pub mod settings;
1112
pub mod workspace;
13+
1214
#[cfg(feature = "schema")]
1315
pub mod workspace_types;
1416

Diff for: ‎crates/pgt_workspace/src/settings.rs

+65-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use biome_deserialize::StringSet;
2+
use globset::Glob;
23
use pgt_diagnostics::Category;
34
use std::{
45
borrow::Cow,
@@ -273,6 +274,7 @@ pub struct DatabaseSettings {
273274
pub password: String,
274275
pub database: String,
275276
pub conn_timeout_secs: Duration,
277+
pub allow_statement_executions: bool,
276278
}
277279

278280
impl Default for DatabaseSettings {
@@ -284,23 +286,44 @@ impl Default for DatabaseSettings {
284286
password: "postgres".to_string(),
285287
database: "postgres".to_string(),
286288
conn_timeout_secs: Duration::from_secs(10),
289+
allow_statement_executions: true,
287290
}
288291
}
289292
}
290293

291294
impl From<PartialDatabaseConfiguration> for DatabaseSettings {
292295
fn from(value: PartialDatabaseConfiguration) -> Self {
293296
let d = DatabaseSettings::default();
297+
298+
let database = value.database.unwrap_or(d.database);
299+
let host = value.host.unwrap_or(d.host);
300+
301+
let allow_statement_executions = value
302+
.allow_statement_executions_against
303+
.map(|stringset| {
304+
stringset.iter().any(|pattern| {
305+
let glob = Glob::new(pattern)
306+
.expect(format!("Invalid pattern: {}", pattern).as_str())
307+
.compile_matcher();
308+
309+
glob.is_match(format!("{}/{}", host, database))
310+
})
311+
})
312+
.unwrap_or(false);
313+
294314
Self {
295-
host: value.host.unwrap_or(d.host),
296315
port: value.port.unwrap_or(d.port),
297316
username: value.username.unwrap_or(d.username),
298317
password: value.password.unwrap_or(d.password),
299-
database: value.database.unwrap_or(d.database),
318+
database,
319+
host,
320+
300321
conn_timeout_secs: value
301322
.conn_timeout_secs
302323
.map(|s| Duration::from_secs(s.into()))
303324
.unwrap_or(d.conn_timeout_secs),
325+
326+
allow_statement_executions,
304327
}
305328
}
306329
}
@@ -415,3 +438,43 @@ impl PartialConfigurationExt for PartialConfiguration {
415438
Ok((None, vec![]))
416439
}
417440
}
441+
442+
#[cfg(test)]
443+
mod tests {
444+
use biome_deserialize::StringSet;
445+
use pgt_configuration::database::PartialDatabaseConfiguration;
446+
447+
use super::DatabaseSettings;
448+
449+
#[test]
450+
fn should_identify_allowed_statement_executions() {
451+
let partial_config = PartialDatabaseConfiguration {
452+
allow_statement_executions_against: Some(StringSet::from_iter(
453+
vec![String::from("localhost/*")].into_iter(),
454+
)),
455+
host: Some("localhost".into()),
456+
database: Some("test-db".into()),
457+
..Default::default()
458+
};
459+
460+
let config = DatabaseSettings::from(partial_config);
461+
462+
assert_eq!(config.allow_statement_executions, true)
463+
}
464+
465+
#[test]
466+
fn should_identify_not_allowed_statement_executions() {
467+
let partial_config = PartialDatabaseConfiguration {
468+
allow_statement_executions_against: Some(StringSet::from_iter(
469+
vec![String::from("localhost/*")].into_iter(),
470+
)),
471+
host: Some("production".into()),
472+
database: Some("test-db".into()),
473+
..Default::default()
474+
};
475+
476+
let config = DatabaseSettings::from(partial_config);
477+
478+
assert_eq!(config.allow_statement_executions, false)
479+
}
480+
}

Diff for: ‎crates/pgt_workspace/src/workspace.rs

+19-1
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,18 @@ use pgt_fs::PgTPath;
77
use pgt_text_size::{TextRange, TextSize};
88
use serde::{Deserialize, Serialize};
99

10-
use crate::WorkspaceError;
10+
use crate::{
11+
WorkspaceError,
12+
code_actions::{
13+
CodeActionsParams, CodeActionsResult, ExecuteStatementParams, ExecuteStatementResult,
14+
},
15+
};
1116

1217
mod client;
1318
mod server;
1419

20+
pub(crate) use server::StatementId;
21+
1522
#[derive(Debug, serde::Serialize, serde::Deserialize)]
1623
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
1724
pub struct OpenFileParams {
@@ -115,6 +122,12 @@ pub trait Workspace: Send + Sync + RefUnwindSafe {
115122
params: PullDiagnosticsParams,
116123
) -> Result<PullDiagnosticsResult, WorkspaceError>;
117124

125+
/// Retrieves a list of available code_actions for a file/cursor_position
126+
fn pull_code_actions(
127+
&self,
128+
params: CodeActionsParams,
129+
) -> Result<CodeActionsResult, WorkspaceError>;
130+
118131
fn get_completions(
119132
&self,
120133
params: GetCompletionsParams,
@@ -145,6 +158,11 @@ pub trait Workspace: Send + Sync + RefUnwindSafe {
145158
///
146159
/// If the file path matches, then `true` is returned, and it should be considered ignored.
147160
fn is_path_ignored(&self, params: IsPathIgnoredParams) -> Result<bool, WorkspaceError>;
161+
162+
fn execute_statement(
163+
&self,
164+
params: ExecuteStatementParams,
165+
) -> Result<ExecuteStatementResult, WorkspaceError>;
148166
}
149167

150168
/// Convenience function for constructing a server instance of [Workspace]

Diff for: ‎crates/pgt_workspace/src/workspace/client.rs

+14
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,20 @@ impl<T> Workspace for WorkspaceClient<T>
8989
where
9090
T: WorkspaceTransport + RefUnwindSafe + Send + Sync,
9191
{
92+
fn pull_code_actions(
93+
&self,
94+
params: crate::code_actions::CodeActionsParams,
95+
) -> Result<crate::code_actions::CodeActionsResult, WorkspaceError> {
96+
self.request("pgt/code_actions", params)
97+
}
98+
99+
fn execute_statement(
100+
&self,
101+
params: crate::code_actions::ExecuteStatementParams,
102+
) -> Result<crate::code_actions::ExecuteStatementResult, WorkspaceError> {
103+
self.request("pgt/execute_statement", params)
104+
}
105+
92106
fn open_file(&self, params: OpenFileParams) -> Result<(), WorkspaceError> {
93107
self.request("pgt/open_file", params)
94108
}

Diff for: ‎crates/pgt_workspace/src/workspace/server.rs

+101
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use async_helper::run_async;
55
use change::StatementChange;
66
use dashmap::DashMap;
77
use db_connection::DbConnection;
8+
pub(crate) use document::StatementId;
89
use document::{Document, Statement};
910
use futures::{StreamExt, stream};
1011
use pg_query::PgQueryStore;
@@ -14,11 +15,16 @@ use pgt_diagnostics::{Diagnostic, DiagnosticExt, Severity, serde::Diagnostic as
1415
use pgt_fs::{ConfigName, PgTPath};
1516
use pgt_typecheck::TypecheckParams;
1617
use schema_cache_manager::SchemaCacheManager;
18+
use sqlx::Executor;
1719
use tracing::info;
1820
use tree_sitter::TreeSitterStore;
1921

2022
use crate::{
2123
WorkspaceError,
24+
code_actions::{
25+
self, CodeAction, CodeActionKind, CodeActionsResult, CommandAction, CommandActionCategory,
26+
ExecuteStatementResult,
27+
},
2228
configuration::to_analyser_rules,
2329
settings::{Settings, SettingsHandle, SettingsHandleMut},
2430
workspace::PullDiagnosticsResult,
@@ -253,6 +259,101 @@ impl Workspace for WorkspaceServer {
253259
Ok(self.is_ignored(params.pgt_path.as_path()))
254260
}
255261

262+
fn pull_code_actions(
263+
&self,
264+
params: code_actions::CodeActionsParams,
265+
) -> Result<code_actions::CodeActionsResult, WorkspaceError> {
266+
let doc = self
267+
.documents
268+
.get(&params.path)
269+
.ok_or(WorkspaceError::not_found())?;
270+
271+
let eligible_statements = doc
272+
.iter_statements_with_text_and_range()
273+
.filter(|(_, range, _)| range.contains(params.cursor_position));
274+
275+
let mut actions: Vec<CodeAction> = vec![];
276+
277+
let settings = self
278+
.settings
279+
.read()
280+
.expect("Unable to read settings for Code Actions");
281+
282+
let disabled_reason: Option<String> = if settings.db.allow_statement_executions {
283+
None
284+
} else {
285+
Some("Statement execution not allowed against database.".into())
286+
};
287+
288+
for (stmt, range, txt) in eligible_statements {
289+
let title = format!(
290+
"Execute Statement: {}...",
291+
txt.chars().take(50).collect::<String>()
292+
);
293+
294+
actions.push(CodeAction {
295+
title,
296+
kind: CodeActionKind::Command(CommandAction {
297+
category: CommandActionCategory::ExecuteStatement(stmt.id),
298+
}),
299+
disabled_reason: disabled_reason.clone(),
300+
});
301+
}
302+
303+
Ok(CodeActionsResult { actions })
304+
}
305+
306+
fn execute_statement(
307+
&self,
308+
params: code_actions::ExecuteStatementParams,
309+
) -> Result<code_actions::ExecuteStatementResult, WorkspaceError> {
310+
let doc = self
311+
.documents
312+
.get(&params.path)
313+
.ok_or(WorkspaceError::not_found())?;
314+
315+
if self
316+
.pg_query
317+
.get_ast(&Statement {
318+
path: params.path,
319+
id: params.statement_id,
320+
})
321+
.is_none()
322+
{
323+
return Ok(ExecuteStatementResult {
324+
message: "Statement is invalid.".into(),
325+
});
326+
};
327+
328+
let sql: String = match doc.get_txt(params.statement_id) {
329+
Some(txt) => txt,
330+
None => {
331+
return Ok(ExecuteStatementResult {
332+
message: "Statement was not found in document.".into(),
333+
});
334+
}
335+
};
336+
337+
let conn = self.connection.write().unwrap();
338+
let pool = match conn.get_pool() {
339+
Some(p) => p,
340+
None => {
341+
return Ok(ExecuteStatementResult {
342+
message: "Not connected to database.".into(),
343+
});
344+
}
345+
};
346+
347+
let result = run_async(async move { pool.execute(sqlx::query(&sql)).await })??;
348+
349+
Ok(ExecuteStatementResult {
350+
message: format!(
351+
"Successfully executed statement. Rows affected: {}",
352+
result.rows_affected()
353+
),
354+
})
355+
}
356+
256357
fn pull_diagnostics(
257358
&self,
258359
params: super::PullDiagnosticsParams,

Diff for: ‎crates/pgt_workspace/src/workspace/server/document.rs

+11-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ pub(crate) struct Statement {
1111
pub(crate) id: StatementId,
1212
}
1313

14-
pub type StatementId = usize;
14+
pub(crate) type StatementId = usize;
1515

1616
type StatementPos = (StatementId, TextRange);
1717

@@ -103,6 +103,16 @@ impl Document {
103103
)
104104
})
105105
}
106+
107+
pub fn get_txt(&self, stmt_id: StatementId) -> Option<String> {
108+
self.positions
109+
.iter()
110+
.find(|pos| pos.0 == stmt_id)
111+
.map(|(_, range)| {
112+
let stmt = &self.content[range.start().into()..range.end().into()];
113+
stmt.to_owned()
114+
})
115+
}
106116
}
107117

108118
pub(crate) struct IdGenerator {

Diff for: ‎crates/pgt_workspace/src/workspace/server/schema_cache_manager.rs

+2
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ impl SchemaCacheManager {
4848
// return early if the connection string is the same
4949
let inner = self.inner.read().unwrap();
5050
if new_conn_str == inner.conn_str {
51+
tracing::info!("Same connection string, no updates.");
5152
return Ok(SchemaCacheHandle::wrap(inner));
5253
}
5354
}
@@ -63,6 +64,7 @@ impl SchemaCacheManager {
6364
if new_conn_str != inner.conn_str {
6465
inner.cache = refreshed;
6566
inner.conn_str = new_conn_str;
67+
tracing::info!("Refreshed connection.");
6668
}
6769
}
6870

Diff for: ‎docs/index.md

+9-8
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,15 @@ A collection of language tools and a Language Server Protocol (LSP) implementati
1414

1515
This project provides a toolchain for Postgres development
1616

17-
##### Postgres Language Server
17+
##### Postgres Language Server
1818

1919
![LSP Demo](images/lsp-demo.gif)
2020

2121
##### CLI Demo
2222

2323
![CLI Demo](images/cli-demo.png)
2424

25-
The toolchain is built on Postgres' own parser `libpg_query` to ensure 100% syntax compatibility. It uses a Server-Client architecture and is a transport-agnostic. This means all features can be accessed through the [Language Server Protocol](https://microsoft.github.io/language-server-protocol/) as well as various interfaces like a CLI, HTTP APIs, or a WebAssembly module.
25+
The toolchain is built on Postgres' own parser `libpg_query` to ensure 100% syntax compatibility. It uses a Server-Client architecture and is a transport-agnostic. This means all features can be accessed through the [Language Server Protocol](https://microsoft.github.io/language-server-protocol/) as well as various interfaces like a CLI, HTTP APIs, or a WebAssembly module.
2626

2727
The following features are implemented:
2828

@@ -50,7 +50,7 @@ Now you can use Postgres Tools by simply running `./postgrestools`.
5050

5151
### NPM
5252

53-
If you are using Node, you can install the CLI via NPM. Run the following commands in a directory containing a `package.json` file.
53+
If you are using Node, you can install the CLI via NPM. Run the following commands in a directory containing a `package.json` file.
5454

5555
```sh
5656
npm add --save-dev --save-exact @postgrestools/postgrestools
@@ -78,7 +78,7 @@ postgrestools init
7878

7979
You’ll now have a `postgrestools.jsonc` file in your directory:
8080

81-
[//]: # (BEGIN DEFAULT_CONFIGURATION)
81+
[//]: # "BEGIN DEFAULT_CONFIGURATION"
8282

8383
```json
8484
{
@@ -103,12 +103,13 @@ You’ll now have a `postgrestools.jsonc` file in your directory:
103103
"username": "postgres",
104104
"password": "postgres",
105105
"database": "postgres",
106-
"connTimeoutSecs": 10
106+
"connTimeoutSecs": 10,
107+
"allowStatementExecutionsAgainst": ["127.0.0.1/*", "localhost/*"]
107108
}
108109
}
109110
```
110111

111-
[//]: # (END DEFAULT_CONFIGURATION)
112+
[//]: # "END DEFAULT_CONFIGURATION"
112113

113114
Make sure to edit the database connection settings to connect to your local development database. To see all options, run `postgrestools --help`.
114115

@@ -129,9 +130,10 @@ Make sure to check out the other options by running `postgrestools --help`. We w
129130
#### Using the LSP Proxy
130131

131132
Postgres Tools has a command called `lsp-proxy`. When executed, two processes will spawn:
133+
132134
- a daemon that does execute the requested operations;
133135
- a server that functions as a proxy between the requests of the client - the editor - and the server - the daemon;
134-
If your editor is able to interact with a server and send [JSON-RPC](https://www.jsonrpc.org) requests, you only need to configure the editor to run that command.
136+
If your editor is able to interact with a server and send [JSON-RPC](https://www.jsonrpc.org) requests, you only need to configure the editor to run that command.
135137

136138
#### Using the daemon with the binary
137139

@@ -159,4 +161,3 @@ The daemon saves logs in your file system. Logs are stored in a folder called `p
159161
For other operative systems, you can find the folder in the system’s temporary directory.
160162

161163
You can change the location of the `pgt-logs` folder via the `PGT_LOG_PATH` variable.
162-

Diff for: ‎docs/schemas/0.0.0/schema.json

+10
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,16 @@
7373
"description": "The configuration of the database connection.",
7474
"type": "object",
7575
"properties": {
76+
"allowStatementExecutionsAgainst": {
77+
"anyOf": [
78+
{
79+
"$ref": "#/definitions/StringSet"
80+
},
81+
{
82+
"type": "null"
83+
}
84+
]
85+
},
7686
"connTimeoutSecs": {
7787
"description": "The connection timeout in seconds.",
7888
"type": [

Diff for: ‎docs/schemas/latest/schema.json

+10
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,16 @@
7373
"description": "The configuration of the database connection.",
7474
"type": "object",
7575
"properties": {
76+
"allowStatementExecutionsAgainst": {
77+
"anyOf": [
78+
{
79+
"$ref": "#/definitions/StringSet"
80+
},
81+
{
82+
"type": "null"
83+
}
84+
]
85+
},
7686
"connTimeoutSecs": {
7787
"description": "The connection timeout in seconds.",
7888
"type": [

Diff for: ‎packages/@postgrestools/backend-jsonrpc/src/workspace.ts

+1
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ export interface PartialConfiguration {
230230
* The configuration of the database connection.
231231
*/
232232
export interface PartialDatabaseConfiguration {
233+
allowStatementExecutionsAgainst?: StringSet;
233234
/**
234235
* The connection timeout in seconds.
235236
*/

Diff for: ‎postgrestools.jsonc

+25-24
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,27 @@
11
{
2-
"$schema": "./docs/schemas/latest/schema.json",
3-
"vcs": {
4-
"enabled": false,
5-
"clientKind": "git",
6-
"useIgnoreFile": false
7-
},
8-
"files": {
9-
"ignore": []
10-
},
11-
"linter": {
12-
"enabled": true,
13-
"rules": {
14-
"recommended": true
15-
}
16-
},
17-
// YOU CAN COMMENT ME OUT :)
18-
"db": {
19-
"host": "127.0.0.1",
20-
"port": 5432,
21-
"username": "postgres",
22-
"password": "postgres",
23-
"database": "postgres",
24-
"connTimeoutSecs": 10
25-
}
2+
"$schema": "./docs/schemas/latest/schema.json",
3+
"vcs": {
4+
"enabled": false,
5+
"clientKind": "git",
6+
"useIgnoreFile": false
7+
},
8+
"files": {
9+
"ignore": []
10+
},
11+
"linter": {
12+
"enabled": true,
13+
"rules": {
14+
"recommended": true
15+
}
16+
},
17+
// YOU CAN COMMENT ME OUT :)
18+
"db": {
19+
"host": "127.0.0.1",
20+
"port": 5432,
21+
"username": "postgres",
22+
"password": "postgres",
23+
"database": "postgres",
24+
"connTimeoutSecs": 10,
25+
"allowStatementExecutionsAgainst": ["127.0.0.1/*", "localhost/*"]
26+
}
2627
}

Diff for: ‎test.sql

+7-9
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
1-
select id, name, test1231234123, unknown from co;
1+
create table
2+
unknown_users (id serial primary key, address text, email text);
23

3-
select 14433313331333333333
4+
drop table unknown_users;
45

5-
select * from test;
6-
7-
alter tqjable test drop column id;
8-
9-
alter table test drop column id;
10-
11-
select lower();
6+
select
7+
*
8+
from
9+
unknown_users;

0 commit comments

Comments
 (0)
Please sign in to comment.