-
Notifications
You must be signed in to change notification settings - Fork 98
/
Copy pathsql_function.rs
122 lines (104 loc) · 3.79 KB
/
sql_function.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
use std::sync::Arc;
use dashmap::DashMap;
use pgt_text_size::TextRange;
use super::statement_identifier::StatementId;
/// The body of a `LANGUAGE sql` function together with its location in the source text.
#[derive(Clone, Debug)]
pub struct SQLFunctionBody {
    /// Range of `body` within the text the statement was parsed from
    /// (as produced by `get_sql_fn`, a byte-offset range).
    pub range: TextRange,
    /// The body string taken from the function's `AS` option.
    pub body: String,
}
/// Cache of extracted SQL function bodies, keyed by statement id.
///
/// Backed by a `DashMap`, so lookups, inserts and removals all take `&self`.
/// A `None` result is cached as well, so a statement that is not an SQL
/// function is only inspected once.
#[derive(Default)]
pub struct SQLFunctionBodyStore {
    db: DashMap<StatementId, Option<Arc<SQLFunctionBody>>>,
}

impl SQLFunctionBodyStore {
    /// Creates an empty store (equivalent to `SQLFunctionBodyStore::default()`).
    pub fn new() -> SQLFunctionBodyStore {
        SQLFunctionBodyStore { db: DashMap::new() }
    }

    /// Returns the SQL function body for `statement`, extracting and caching it on first use.
    ///
    /// On a cache miss the body is extracted from `ast` and located within `content`.
    /// The outcome is stored under `statement`, including `None` for statements that
    /// are not SQL functions.
    ///
    /// NOTE(review): a cached entry is returned regardless of the `ast`/`content`
    /// passed in. Callers presumably invalidate via `clear_statement` when a
    /// statement changes; confirm against callers.
    pub fn get_function_body(
        &self,
        statement: &StatementId,
        ast: &pgt_query_ext::NodeEnum,
        content: &str,
    ) -> Option<Arc<SQLFunctionBody>> {
        // Clone the cached value out so the map guard is released immediately;
        // holding it across the `insert` below would block on the same shard.
        if let Some(cached) = self.db.get(statement).map(|entry| entry.value().clone()) {
            return cached;
        }

        let fn_body = get_sql_fn(ast, content).map(Arc::new);
        self.db.insert(statement.clone(), fn_body.clone());
        fn_body
    }

    /// Removes the cached entry for `id` and, if `id` has a child id, that entry too.
    pub fn clear_statement(&self, id: &StatementId) {
        self.db.remove(id);
        if let Some(child_id) = id.get_child_id() {
            self.db.remove(&child_id);
        }
    }
}
/// Extracts the SQL function body and its text range from a `CreateFunctionStmt` node.
///
/// `content` is the text the statement was parsed from. The returned range is the
/// byte range of the first occurrence of the body string within `content`.
///
/// Returns `None` when:
/// - `ast` is not a `CreateFunctionStmt`;
/// - the `language` option is missing or is not exactly `"sql"`;
/// - there is no string `as` option;
/// - the body text cannot be found verbatim in `content` (e.g. the parser
///   unescaped doubled quotes, so the string differs from the source);
/// - an offset does not fit into a `TextSize`.
fn get_sql_fn(ast: &pgt_query_ext::NodeEnum, content: &str) -> Option<SQLFunctionBody> {
    let create_fn = match ast {
        pgt_query_ext::NodeEnum::CreateFunctionStmt(cf) => cf,
        _ => return None,
    };

    // Only SQL functions have a body we can analyse as SQL.
    let language = find_option_value(create_fn, "language")?;
    if language != "sql" {
        return None;
    }

    let body = find_option_value(create_fn, "as")?;

    // NOTE(review): this takes the first textual match, which might precede the
    // actual AS clause if the same text appears earlier in the statement.
    let start = content.find(body.as_str())?;
    let end = start + body.len();
    // Out-of-range offsets yield `None` instead of panicking.
    let range = TextRange::new(start.try_into().ok()?, end.try_into().ok()?);

    // `body` is owned and no longer needed here, so move it rather than clone it.
    Some(SQLFunctionBody { range, body })
}
/// Returns the string argument of the first `option_name` option on a `CreateFunctionStmt`.
///
/// The option's argument may be either:
/// - a `String` node, whose value is returned directly; or
/// - a `List` node, in which case the first `String` item is returned.
///
/// If a matching option's argument yields no string, that option is skipped and
/// the search continues with the remaining options.
fn find_option_value(
    create_fn: &pgt_query_ext::protobuf::CreateFunctionStmt,
    option_name: &str,
) -> Option<String> {
    /// Pulls a string out of a `DefElem` argument node, if it carries one.
    fn string_in(node: &pgt_query_ext::NodeEnum) -> Option<String> {
        match node {
            pgt_query_ext::NodeEnum::String(s) => Some(s.sval.clone()),
            pgt_query_ext::NodeEnum::List(list) => {
                list.items.iter().find_map(|item| match item.node.as_ref() {
                    Some(pgt_query_ext::NodeEnum::String(s)) => Some(s.sval.clone()),
                    _ => None,
                })
            }
            _ => None,
        }
    }

    for wrapper in &create_fn.options {
        let def_elem = match wrapper.node.as_ref() {
            Some(pgt_query_ext::NodeEnum::DefElem(d)) if d.defname == option_name => d,
            _ => continue,
        };

        let mut args = def_elem.arg.iter().filter_map(|arg| arg.node.as_ref());
        if let Some(value) = args.find_map(string_in) {
            return Some(value);
        }
    }

    None
}