Skip to main content

nautilus_infrastructure/sql/
pg.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use derive_builder::Builder;
17use regex::Regex;
18use sqlx::{ConnectOptions, PgPool, postgres::PgConnectOptions};
19
20fn validate_sql_identifier(value: &str, label: &str) -> anyhow::Result<()> {
21    if value.is_empty() {
22        anyhow::bail!("{label} must not be empty");
23    }
24
25    if !value.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') {
26        anyhow::bail!(
27            "{label} contains invalid characters (only alphanumeric and underscore allowed): {value}"
28        );
29    }
30    Ok(())
31}
32
/// Escapes `value` for embedding inside a single-quoted SQL string literal.
///
/// Each `'` is doubled to `''`; every other character is copied unchanged. Backslashes are not
/// escaped, which is only safe while the server treats them literally
/// (`standard_conforming_strings = on`).
fn escape_sql_string(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        if ch == '\'' {
            escaped.push_str("''");
        } else {
            escaped.push(ch);
        }
    }
    escaped
}
36
37#[derive(Debug, Clone, Builder)]
38#[builder(default)]
39#[cfg_attr(
40    feature = "python",
41    pyo3::pyclass(
42        module = "nautilus_trader.core.nautilus_pyo3.infrastructure",
43        from_py_object
44    )
45)]
46#[cfg_attr(
47    feature = "python",
48    pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.infrastructure")
49)]
50pub struct PostgresConnectOptions {
51    pub host: String,
52    pub port: u16,
53    pub username: String,
54    pub password: String,
55    pub database: String,
56}
57
58impl PostgresConnectOptions {
59    /// Creates a new [`PostgresConnectOptions`] instance.
60    #[must_use]
61    pub const fn new(
62        host: String,
63        port: u16,
64        username: String,
65        password: String,
66        database: String,
67    ) -> Self {
68        Self {
69            host,
70            port,
71            username,
72            password,
73            database,
74        }
75    }
76
77    #[must_use]
78    pub fn connection_string(&self) -> String {
79        format!(
80            "postgres://{username}:{password}@{host}:{port}/{database}",
81            username = self.username,
82            password = self.password,
83            host = self.host,
84            port = self.port,
85            database = self.database
86        )
87    }
88
89    /// Returns the connection string with the password masked for safe logging.
90    #[must_use]
91    pub fn connection_string_masked(&self) -> String {
92        format!(
93            "postgres://{username}:***@{host}:{port}/{database}",
94            username = self.username,
95            host = self.host,
96            port = self.port,
97            database = self.database
98        )
99    }
100
101    #[must_use]
102    pub fn default_administrator() -> Self {
103        Self::new(
104            String::from("localhost"),
105            5432,
106            String::from("nautilus"),
107            String::from("pass"),
108            String::from("nautilus"),
109        )
110    }
111}
112
113impl Default for PostgresConnectOptions {
114    fn default() -> Self {
115        Self::new(
116            String::from("localhost"),
117            5432,
118            String::from("nautilus"),
119            String::from("pass"),
120            String::from("nautilus"),
121        )
122    }
123}
124
125impl From<PostgresConnectOptions> for PgConnectOptions {
126    fn from(opt: PostgresConnectOptions) -> Self {
127        Self::new()
128            .host(opt.host.as_str())
129            .port(opt.port)
130            .username(opt.username.as_str())
131            .password(opt.password.as_str())
132            .database(opt.database.as_str())
133            .disable_statement_logging()
134    }
135}
136
137/// Constructs `PostgresConnectOptions` by merging provided arguments, environment variables, and defaults.
138///
139/// # Panics
140///
141/// Panics if an environment variable for port cannot be parsed into a `u16`.
142#[must_use]
143pub fn get_postgres_connect_options(
144    host: Option<String>,
145    port: Option<u16>,
146    username: Option<String>,
147    password: Option<String>,
148    database: Option<String>,
149) -> PostgresConnectOptions {
150    let defaults = PostgresConnectOptions::default_administrator();
151    let host = host
152        .or_else(|| std::env::var("POSTGRES_HOST").ok())
153        .unwrap_or(defaults.host);
154    let port = port
155        .or_else(|| {
156            std::env::var("POSTGRES_PORT")
157                .map(|port| port.parse::<u16>().unwrap())
158                .ok()
159        })
160        .unwrap_or(defaults.port);
161    let username = username
162        .or_else(|| std::env::var("POSTGRES_USERNAME").ok())
163        .unwrap_or(defaults.username);
164    let database = database
165        .or_else(|| std::env::var("POSTGRES_DATABASE").ok())
166        .unwrap_or(defaults.database);
167    let password = password
168        .or_else(|| std::env::var("POSTGRES_PASSWORD").ok())
169        .unwrap_or(defaults.password);
170    PostgresConnectOptions::new(host, port, username, password, database)
171}
172
173/// Connects to a Postgres database with the provided connection `options` returning a connection pool.
174///
175/// # Errors
176///
177/// Returns an error if establishing the database connection fails.
178pub async fn connect_pg(options: PgConnectOptions) -> anyhow::Result<PgPool> {
179    Ok(PgPool::connect_with(options).await?)
180}
181
182/// Scans the current working directory for the `nautilus_trader` repository
183/// and constructs the path to the SQL schema directory.
184///
185/// # Errors
186///
187/// Returns an error if the `SCHEMA_DIR` environment variable is not set and the repository
188/// cannot be located in the current directory path.
189///
190/// # Panics
191///
192/// Panics if the current working directory cannot be determined or contains invalid UTF-8.
193fn get_schema_dir() -> anyhow::Result<String> {
194    std::env::var("SCHEMA_DIR").or_else(|_| {
195        let nautilus_git_repo_name = "nautilus_trader";
196        let binding = std::env::current_dir().unwrap();
197        let current_dir = binding.to_str().unwrap();
198        match current_dir.find(nautilus_git_repo_name){
199            Some(index) => {
200                let schema_path = current_dir[0..index + nautilus_git_repo_name.len()].to_string() + "/schema/sql";
201                Ok(schema_path)
202            }
203            None => anyhow::bail!("Could not calculate schema dir from current directory path or SCHEMA_DIR env variable")
204        }
205    })
206}
207
208/// Initializes the Postgres database by creating schema, roles, and executing SQL files from `schema_dir`.
209///
210/// # Errors
211///
212/// Returns an error if any SQL execution or file system operation fails.
213///
214/// # Panics
215///
216/// Panics if `schema_dir` is missing and cannot be determined or if other unwraps fail.
217pub async fn init_postgres(
218    pg: &PgPool,
219    database: String,
220    password: String,
221    schema_dir: Option<String>,
222) -> anyhow::Result<()> {
223    log::info!("Initializing Postgres database with target permissions and schema");
224
225    validate_sql_identifier(&database, "database")?;
226
227    // Create public schema
228    match sqlx::query("CREATE SCHEMA IF NOT EXISTS public;")
229        .execute(pg)
230        .await
231    {
232        Ok(_) => log::info!("Schema public created successfully"),
233        Err(e) => log::error!("Error creating schema public: {e:?}"),
234    }
235
236    // Create role if not exists
237    let escaped_password = escape_sql_string(&password);
238    match sqlx::query(
239        format!("CREATE ROLE {database} PASSWORD '{escaped_password}' LOGIN;").as_str(),
240    )
241    .execute(pg)
242    .await
243    {
244        Ok(_) => log::info!("Role {database} created successfully"),
245        Err(e) => {
246            if e.to_string().contains("already exists") {
247                log::info!("Role {database} already exists");
248            } else {
249                log::error!("Error creating role {database}: {e:?}");
250            }
251        }
252    }
253
254    // Execute all the sql files in schema dir
255    let schema_dir = schema_dir.unwrap_or_else(|| get_schema_dir().unwrap());
256    let sql_files = vec!["types.sql", "functions.sql", "partitions.sql", "tables.sql"];
257    let plpgsql_regex =
258        Regex::new(r"\$\$ LANGUAGE plpgsql(?:[ \t\r\n]+SECURITY[ \t\r\n]+DEFINER)?;")?;
259
260    for file_name in &sql_files {
261        log::info!("Executing schema file: {file_name:?}");
262        let file_path = format!("{schema_dir}/{file_name}");
263        let sql_content = std::fs::read_to_string(&file_path)?;
264        let sql_statements: Vec<String> = match *file_name {
265            "functions.sql" | "partitions.sql" => {
266                let mut statements = Vec::new();
267                let mut last_end = 0;
268
269                for mat in plpgsql_regex.find_iter(&sql_content) {
270                    let statement = sql_content[last_end..mat.end()].to_string();
271                    if !statement.trim().is_empty() {
272                        statements.push(statement);
273                    }
274                    last_end = mat.end();
275                }
276                statements
277            }
278            _ => sql_content
279                .split(';')
280                .filter(|s| !s.trim().is_empty())
281                .map(|s| format!("{s};"))
282                .collect(),
283        };
284
285        for sql_statement in sql_statements {
286            sqlx::query(&sql_statement)
287                .execute(pg)
288                .await
289                .map_err(|e| {
290                    if e.to_string().contains("already exists") {
291                        log::info!("Already exists error on statement, skipping");
292                    } else {
293                        panic!("Error executing statement {sql_statement} with error: {e:?}")
294                    }
295                })
296                .unwrap();
297        }
298    }
299
300    // Grant connect
301    match sqlx::query(format!("GRANT CONNECT ON DATABASE {database} TO {database};").as_str())
302        .execute(pg)
303        .await
304    {
305        Ok(_) => log::info!("Connect privileges granted to role {database}"),
306        Err(e) => log::error!("Error granting connect privileges to role {database}: {e:?}"),
307    }
308
309    // Grant all schema privileges to the role
310    match sqlx::query(format!("GRANT ALL PRIVILEGES ON SCHEMA public TO {database};").as_str())
311        .execute(pg)
312        .await
313    {
314        Ok(_) => log::info!("All schema privileges granted to role {database}"),
315        Err(e) => log::error!("Error granting all privileges to role {database}: {e:?}"),
316    }
317
318    // Grant all table privileges to the role
319    match sqlx::query(
320        format!("GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO {database};").as_str(),
321    )
322    .execute(pg)
323    .await
324    {
325        Ok(_) => log::info!("All tables privileges granted to role {database}"),
326        Err(e) => log::error!("Error granting all privileges to role {database}: {e:?}"),
327    }
328
329    // Grant all sequence privileges to the role
330    match sqlx::query(
331        format!("GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public TO {database};").as_str(),
332    )
333    .execute(pg)
334    .await
335    {
336        Ok(_) => log::info!("All sequences privileges granted to role {database}"),
337        Err(e) => log::error!("Error granting all privileges to role {database}: {e:?}"),
338    }
339
340    // Grant all function privileges to the role
341    match sqlx::query(
342        format!("GRANT EXECUTE ON ALL FUNCTIONS IN SCHEMA public TO {database};").as_str(),
343    )
344    .execute(pg)
345    .await
346    {
347        Ok(_) => log::info!("All functions privileges granted to role {database}"),
348        Err(e) => log::error!("Error granting all privileges to role {database}: {e:?}"),
349    }
350
351    Ok(())
352}
353
354/// Drops the Postgres database with the given name using the provided connection pool.
355///
356/// # Errors
357///
358/// Returns an error if the DROP DATABASE command fails.
359pub async fn drop_postgres(pg: &PgPool, database: String) -> anyhow::Result<()> {
360    validate_sql_identifier(&database, "database")?;
361
362    // Execute drop owned
363    match sqlx::query(format!("DROP OWNED BY {database}").as_str())
364        .execute(pg)
365        .await
366    {
367        Ok(_) => log::info!("Dropped owned objects by role {database}"),
368        Err(e) => {
369            let err_msg = e.to_string();
370            if err_msg.contains("2BP01") || err_msg.contains("required by the database system") {
371                log::warn!("Skipping system-required objects for role {database}");
372            } else {
373                log::error!("Error dropping owned by role {database}: {e:?}");
374            }
375        }
376    }
377
378    // Revoke connect
379    match sqlx::query(format!("REVOKE CONNECT ON DATABASE {database} FROM {database};").as_str())
380        .execute(pg)
381        .await
382    {
383        Ok(_) => log::info!("Revoked connect privileges from role {database}"),
384        Err(e) => log::error!("Error revoking connect privileges from role {database}: {e:?}"),
385    }
386
387    // Revoke privileges
388    match sqlx::query(
389        format!("REVOKE ALL PRIVILEGES ON DATABASE {database} FROM {database};").as_str(),
390    )
391    .execute(pg)
392    .await
393    {
394        Ok(_) => log::info!("Revoked all privileges from role {database}"),
395        Err(e) => log::error!("Error revoking all privileges from role {database}: {e:?}"),
396    }
397
398    // Execute drop schema
399    match sqlx::query("DROP SCHEMA IF EXISTS public CASCADE")
400        .execute(pg)
401        .await
402    {
403        Ok(_) => log::info!("Dropped schema public"),
404        Err(e) => log::error!("Error dropping schema public: {e:?}"),
405    }
406
407    // Drop role
408    match sqlx::query(format!("DROP ROLE IF EXISTS {database};").as_str())
409        .execute(pg)
410        .await
411    {
412        Ok(_) => log::info!("Dropped role {database}"),
413        Err(e) => {
414            let err_msg = e.to_string();
415            if err_msg.contains("55006") || err_msg.contains("current user cannot be dropped") {
416                log::warn!("Cannot drop currently connected role {database}");
417            } else {
418                log::error!("Error dropping role {database}: {e:?}");
419            }
420        }
421    }
422    Ok(())
423}