nautilus_infrastructure/sql/
pg.rs1use derive_builder::Builder;
17use regex::Regex;
18use sqlx::{ConnectOptions, PgPool, postgres::PgConnectOptions};
19
20fn validate_sql_identifier(value: &str, label: &str) -> anyhow::Result<()> {
21 if value.is_empty() {
22 anyhow::bail!("{label} must not be empty");
23 }
24
25 if !value.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') {
26 anyhow::bail!(
27 "{label} contains invalid characters (only alphanumeric and underscore allowed): {value}"
28 );
29 }
30 Ok(())
31}
32
/// Doubles every single quote in `value` so the result can be embedded in a
/// single-quoted SQL string literal.
fn escape_sql_string(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        if ch == '\'' {
            escaped.push_str("''");
        } else {
            escaped.push(ch);
        }
    }
    escaped
}
36
37#[derive(Debug, Clone, Builder)]
38#[builder(default)]
39#[cfg_attr(
40 feature = "python",
41 pyo3::pyclass(
42 module = "nautilus_trader.core.nautilus_pyo3.infrastructure",
43 from_py_object
44 )
45)]
46#[cfg_attr(
47 feature = "python",
48 pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.infrastructure")
49)]
50pub struct PostgresConnectOptions {
51 pub host: String,
52 pub port: u16,
53 pub username: String,
54 pub password: String,
55 pub database: String,
56}
57
58impl PostgresConnectOptions {
59 #[must_use]
61 pub const fn new(
62 host: String,
63 port: u16,
64 username: String,
65 password: String,
66 database: String,
67 ) -> Self {
68 Self {
69 host,
70 port,
71 username,
72 password,
73 database,
74 }
75 }
76
77 #[must_use]
78 pub fn connection_string(&self) -> String {
79 format!(
80 "postgres://{username}:{password}@{host}:{port}/{database}",
81 username = self.username,
82 password = self.password,
83 host = self.host,
84 port = self.port,
85 database = self.database
86 )
87 }
88
89 #[must_use]
91 pub fn connection_string_masked(&self) -> String {
92 format!(
93 "postgres://{username}:***@{host}:{port}/{database}",
94 username = self.username,
95 host = self.host,
96 port = self.port,
97 database = self.database
98 )
99 }
100
101 #[must_use]
102 pub fn default_administrator() -> Self {
103 Self::new(
104 String::from("localhost"),
105 5432,
106 String::from("nautilus"),
107 String::from("pass"),
108 String::from("nautilus"),
109 )
110 }
111}
112
113impl Default for PostgresConnectOptions {
114 fn default() -> Self {
115 Self::new(
116 String::from("localhost"),
117 5432,
118 String::from("nautilus"),
119 String::from("pass"),
120 String::from("nautilus"),
121 )
122 }
123}
124
125impl From<PostgresConnectOptions> for PgConnectOptions {
126 fn from(opt: PostgresConnectOptions) -> Self {
127 Self::new()
128 .host(opt.host.as_str())
129 .port(opt.port)
130 .username(opt.username.as_str())
131 .password(opt.password.as_str())
132 .database(opt.database.as_str())
133 .disable_statement_logging()
134 }
135}
136
137#[must_use]
143pub fn get_postgres_connect_options(
144 host: Option<String>,
145 port: Option<u16>,
146 username: Option<String>,
147 password: Option<String>,
148 database: Option<String>,
149) -> PostgresConnectOptions {
150 let defaults = PostgresConnectOptions::default_administrator();
151 let host = host
152 .or_else(|| std::env::var("POSTGRES_HOST").ok())
153 .unwrap_or(defaults.host);
154 let port = port
155 .or_else(|| {
156 std::env::var("POSTGRES_PORT")
157 .map(|port| port.parse::<u16>().unwrap())
158 .ok()
159 })
160 .unwrap_or(defaults.port);
161 let username = username
162 .or_else(|| std::env::var("POSTGRES_USERNAME").ok())
163 .unwrap_or(defaults.username);
164 let database = database
165 .or_else(|| std::env::var("POSTGRES_DATABASE").ok())
166 .unwrap_or(defaults.database);
167 let password = password
168 .or_else(|| std::env::var("POSTGRES_PASSWORD").ok())
169 .unwrap_or(defaults.password);
170 PostgresConnectOptions::new(host, port, username, password, database)
171}
172
173pub async fn connect_pg(options: PgConnectOptions) -> anyhow::Result<PgPool> {
179 Ok(PgPool::connect_with(options).await?)
180}
181
182fn get_schema_dir() -> anyhow::Result<String> {
194 std::env::var("SCHEMA_DIR").or_else(|_| {
195 let nautilus_git_repo_name = "nautilus_trader";
196 let binding = std::env::current_dir().unwrap();
197 let current_dir = binding.to_str().unwrap();
198 match current_dir.find(nautilus_git_repo_name){
199 Some(index) => {
200 let schema_path = current_dir[0..index + nautilus_git_repo_name.len()].to_string() + "/schema/sql";
201 Ok(schema_path)
202 }
203 None => anyhow::bail!("Could not calculate schema dir from current directory path or SCHEMA_DIR env variable")
204 }
205 })
206}
207
208pub async fn init_postgres(
218 pg: &PgPool,
219 database: String,
220 password: String,
221 schema_dir: Option<String>,
222) -> anyhow::Result<()> {
223 log::info!("Initializing Postgres database with target permissions and schema");
224
225 validate_sql_identifier(&database, "database")?;
226
227 match sqlx::query("CREATE SCHEMA IF NOT EXISTS public;")
229 .execute(pg)
230 .await
231 {
232 Ok(_) => log::info!("Schema public created successfully"),
233 Err(e) => log::error!("Error creating schema public: {e:?}"),
234 }
235
236 let escaped_password = escape_sql_string(&password);
238 match sqlx::query(
239 format!("CREATE ROLE {database} PASSWORD '{escaped_password}' LOGIN;").as_str(),
240 )
241 .execute(pg)
242 .await
243 {
244 Ok(_) => log::info!("Role {database} created successfully"),
245 Err(e) => {
246 if e.to_string().contains("already exists") {
247 log::info!("Role {database} already exists");
248 } else {
249 log::error!("Error creating role {database}: {e:?}");
250 }
251 }
252 }
253
254 let schema_dir = schema_dir.unwrap_or_else(|| get_schema_dir().unwrap());
256 let sql_files = vec!["types.sql", "functions.sql", "partitions.sql", "tables.sql"];
257 let plpgsql_regex =
258 Regex::new(r"\$\$ LANGUAGE plpgsql(?:[ \t\r\n]+SECURITY[ \t\r\n]+DEFINER)?;")?;
259
260 for file_name in &sql_files {
261 log::info!("Executing schema file: {file_name:?}");
262 let file_path = format!("{schema_dir}/{file_name}");
263 let sql_content = std::fs::read_to_string(&file_path)?;
264 let sql_statements: Vec<String> = match *file_name {
265 "functions.sql" | "partitions.sql" => {
266 let mut statements = Vec::new();
267 let mut last_end = 0;
268
269 for mat in plpgsql_regex.find_iter(&sql_content) {
270 let statement = sql_content[last_end..mat.end()].to_string();
271 if !statement.trim().is_empty() {
272 statements.push(statement);
273 }
274 last_end = mat.end();
275 }
276 statements
277 }
278 _ => sql_content
279 .split(';')
280 .filter(|s| !s.trim().is_empty())
281 .map(|s| format!("{s};"))
282 .collect(),
283 };
284
285 for sql_statement in sql_statements {
286 sqlx::query(&sql_statement)
287 .execute(pg)
288 .await
289 .map_err(|e| {
290 if e.to_string().contains("already exists") {
291 log::info!("Already exists error on statement, skipping");
292 } else {
293 panic!("Error executing statement {sql_statement} with error: {e:?}")
294 }
295 })
296 .unwrap();
297 }
298 }
299
300 match sqlx::query(format!("GRANT CONNECT ON DATABASE {database} TO {database};").as_str())
302 .execute(pg)
303 .await
304 {
305 Ok(_) => log::info!("Connect privileges granted to role {database}"),
306 Err(e) => log::error!("Error granting connect privileges to role {database}: {e:?}"),
307 }
308
309 match sqlx::query(format!("GRANT ALL PRIVILEGES ON SCHEMA public TO {database};").as_str())
311 .execute(pg)
312 .await
313 {
314 Ok(_) => log::info!("All schema privileges granted to role {database}"),
315 Err(e) => log::error!("Error granting all privileges to role {database}: {e:?}"),
316 }
317
318 match sqlx::query(
320 format!("GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO {database};").as_str(),
321 )
322 .execute(pg)
323 .await
324 {
325 Ok(_) => log::info!("All tables privileges granted to role {database}"),
326 Err(e) => log::error!("Error granting all privileges to role {database}: {e:?}"),
327 }
328
329 match sqlx::query(
331 format!("GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public TO {database};").as_str(),
332 )
333 .execute(pg)
334 .await
335 {
336 Ok(_) => log::info!("All sequences privileges granted to role {database}"),
337 Err(e) => log::error!("Error granting all privileges to role {database}: {e:?}"),
338 }
339
340 match sqlx::query(
342 format!("GRANT EXECUTE ON ALL FUNCTIONS IN SCHEMA public TO {database};").as_str(),
343 )
344 .execute(pg)
345 .await
346 {
347 Ok(_) => log::info!("All functions privileges granted to role {database}"),
348 Err(e) => log::error!("Error granting all privileges to role {database}: {e:?}"),
349 }
350
351 Ok(())
352}
353
354pub async fn drop_postgres(pg: &PgPool, database: String) -> anyhow::Result<()> {
360 validate_sql_identifier(&database, "database")?;
361
362 match sqlx::query(format!("DROP OWNED BY {database}").as_str())
364 .execute(pg)
365 .await
366 {
367 Ok(_) => log::info!("Dropped owned objects by role {database}"),
368 Err(e) => {
369 let err_msg = e.to_string();
370 if err_msg.contains("2BP01") || err_msg.contains("required by the database system") {
371 log::warn!("Skipping system-required objects for role {database}");
372 } else {
373 log::error!("Error dropping owned by role {database}: {e:?}");
374 }
375 }
376 }
377
378 match sqlx::query(format!("REVOKE CONNECT ON DATABASE {database} FROM {database};").as_str())
380 .execute(pg)
381 .await
382 {
383 Ok(_) => log::info!("Revoked connect privileges from role {database}"),
384 Err(e) => log::error!("Error revoking connect privileges from role {database}: {e:?}"),
385 }
386
387 match sqlx::query(
389 format!("REVOKE ALL PRIVILEGES ON DATABASE {database} FROM {database};").as_str(),
390 )
391 .execute(pg)
392 .await
393 {
394 Ok(_) => log::info!("Revoked all privileges from role {database}"),
395 Err(e) => log::error!("Error revoking all privileges from role {database}: {e:?}"),
396 }
397
398 match sqlx::query("DROP SCHEMA IF EXISTS public CASCADE")
400 .execute(pg)
401 .await
402 {
403 Ok(_) => log::info!("Dropped schema public"),
404 Err(e) => log::error!("Error dropping schema public: {e:?}"),
405 }
406
407 match sqlx::query(format!("DROP ROLE IF EXISTS {database};").as_str())
409 .execute(pg)
410 .await
411 {
412 Ok(_) => log::info!("Dropped role {database}"),
413 Err(e) => {
414 let err_msg = e.to_string();
415 if err_msg.contains("55006") || err_msg.contains("current user cannot be dropped") {
416 log::warn!("Cannot drop currently connected role {database}");
417 } else {
418 log::error!("Error dropping role {database}: {e:?}");
419 }
420 }
421 }
422 Ok(())
423}