Database Access

Cornucopia is a code generator that takes small snippets of SQL and turns them into Rust functions.

We'll turn our crates/db folder into a crate so we can keep all our database logic in one place.

Run the following from the root of the project:

$ cargo init --lib crates/db
Created library package

Installation

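Cornucopia comes in two parts: the cornucopia_async crate, which the generated code depends on, and the cornucopia command-line tool, which our build script runs. Install the command-line tool with cargo install cornucopia if your environment doesn't already provide it. To add the crate, cd into your crates/db folder: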

cd crates/db
cargo add cornucopia_async

Creating a SQL definition

In a folder called crates/db/queries, create a file called users.sql with the following content.

--: User()

--! get_users : User
SELECT 
    id, 
    email
FROM users;

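Cornucopia uses this definition to generate a Rust function called get_users, along with a User struct for the rows it returns. Note that cornucopia checks the query against Postgres at code-generation time, so mistakes such as a misspelled column are reported when you build rather than when you run.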
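To make the annotation syntax concrete, here is a toy parser for the --! line. It is a hypothetical illustration, not cornucopia's own parser, and it only handles the simple form used above: the name before the colon becomes the generated function (get_users) and the name after it names the row type (User). QueryAnnotation and parse_query_annotation are made up for this sketch.

```rust
/// A parsed `--! name : RowType` annotation.
/// Hypothetical toy model; cornucopia's real parser supports much more syntax.
#[derive(Debug, PartialEq)]
struct QueryAnnotation {
    /// Name of the generated Rust function, e.g. "get_users".
    function_name: String,
    /// Name of the row type after the colon, e.g. "User", if present.
    row_type: Option<String>,
}

/// Parses one `--!` line; returns None for any other line (including `--:`).
fn parse_query_annotation(line: &str) -> Option<QueryAnnotation> {
    let rest = line.trim().strip_prefix("--!")?;
    // Split "get_users : User" into the function name and the row type.
    let (name, row) = match rest.split_once(':') {
        Some((name, row)) => (name.trim(), Some(row.trim())),
        None => (rest.trim(), None),
    };
    if name.is_empty() {
        return None;
    }
    Some(QueryAnnotation {
        function_name: name.to_string(),
        row_type: row.filter(|r| !r.is_empty()).map(str::to_string),
    })
}

fn main() {
    let sql = "--: User()\n\n--! get_users : User\nSELECT id, email FROM users;";
    for line in sql.lines() {
        if let Some(annotation) = parse_query_annotation(line) {
            // Prints: QueryAnnotation { function_name: "get_users", row_type: Some("User") }
            println!("{annotation:?}");
        }
    }
}
```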

Creating build.rs

Create a crates/db/build.rs file and add the following content. This file compiles our .sql files into Rust code whenever they change.

use std::env;
use std::path::Path;
use std::process::Command;

fn main() {
    // Compile our SQL
    cornucopia();
}

fn cornucopia() {
    // For the sake of simplicity, this example uses the defaults.
    let queries_path = "queries";
    let migrations_path = "migrations";

    let out_dir = env::var_os("OUT_DIR").expect("cargo sets OUT_DIR for build scripts");
    let file_path = Path::new(&out_dir).join("cornucopia.rs");

    // cornucopia checks the queries against a live database, so DATABASE_URL must be set.
    let db_url = env::var_os("DATABASE_URL").expect("DATABASE_URL must be set to build the db crate");

    // Rerun this build script if the queries, migrations or database URL change.
    println!("cargo:rerun-if-changed={queries_path}");
    println!("cargo:rerun-if-changed={migrations_path}");
    println!("cargo:rerun-if-env-changed=DATABASE_URL");

    // Call the cornucopia CLI (it must be installed and on your PATH).
    let output = Command::new("cornucopia")
        .arg("-q")
        .arg(queries_path)
        .arg("--serialize")
        .arg("-d")
        .arg(&file_path)
        .arg("live")
        .arg(db_url)
        .output()
        .expect("failed to run the cornucopia CLI; is it installed?");

    // If Cornucopia couldn't run properly, display its error output.
    if !output.status.success() {
        panic!("{}", String::from_utf8_lossy(&output.stderr));
    }
}
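If you want to unit-test the build script's command line without actually running cornucopia, one option is to build the argument list in a plain function. The sketch below is a hypothetical refactor (cornucopia_args is not part of cornucopia); it returns the same flags, in the same order, that the build script passes, and the path and database URL in main are placeholders.

```rust
use std::ffi::{OsStr, OsString};
use std::path::Path;

/// Hypothetical helper: the arguments the build script passes to the cornucopia CLI,
/// in the same order as the Command chain in build.rs.
fn cornucopia_args(queries_path: &str, destination: &Path, db_url: &OsStr) -> Vec<OsString> {
    vec![
        "-q".into(),
        queries_path.into(),
        "--serialize".into(),
        "-d".into(),
        destination.as_os_str().to_os_string(),
        "live".into(),
        db_url.to_os_string(),
    ]
}

fn main() {
    // Placeholder values for illustration only.
    let args = cornucopia_args(
        "queries",
        Path::new("target/cornucopia.rs"),
        OsStr::new("postgres://localhost/app"),
    );
    // build.rs would then run: std::process::Command::new("cornucopia").args(&args).output()
    println!("{args:?}");
}
```

Keeping the arguments in a Vec also makes them easy to print with a cargo:warning= line when a build fails and you want to see exactly what was run.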

Add a function to do connection pooling

Add the following code to crates/db/src/lib.rs. We'll use it to turn our DATABASE_URL environment variable into a connection pool (via deadpool-postgres) that the cornucopia-generated functions can use.

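use std::str::FromStr;
use std::sync::Arc;

pub use cornucopia_async::Params;
pub use deadpool_postgres::{Pool, PoolError, Transaction};
use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
use rustls::{DigitallySignedStruct, SignatureScheme};
use rustls_pki_types::{CertificateDer, ServerName, UnixTime};
pub use tokio_postgres::Error as TokioPostgresError;

pub use queries::users::User;

/// Creates a connection pool from a Postgres connection string such as DATABASE_URL.
/// TLS (via rustls) is used unless the connection string sets sslmode=disable.
pub fn create_pool(database_url: &str) -> Pool {
    let config = tokio_postgres::Config::from_str(database_url).unwrap();

    let manager = if config.get_ssl_mode() != tokio_postgres::config::SslMode::Disable {
        let tls_config = rustls::ClientConfig::builder()
            .dangerous()
            .with_custom_certificate_verifier(Arc::new(DummyTlsVerifier))
            .with_no_client_auth();

        let tls = tokio_postgres_rustls::MakeRustlsConnect::new(tls_config);
        deadpool_postgres::Manager::new(config, tls)
    } else {
        deadpool_postgres::Manager::new(config, tokio_postgres::NoTls)
    };

    Pool::builder(manager).build().unwrap()
}

/// A certificate verifier that accepts any server certificate.
/// WARNING: the connection is encrypted but the server is not authenticated,
/// so this gives no protection against man-in-the-middle attacks.
#[derive(Debug)]
struct DummyTlsVerifier;

impl ServerCertVerifier for DummyTlsVerifier {
    fn verify_server_cert(
        &self,
        _end_entity: &CertificateDer,
        _intermediates: &[CertificateDer],
        _server_name: &ServerName,
        _ocsp_response: &[u8],
        _now: UnixTime,
    ) -> Result<ServerCertVerified, rustls::Error> {
        Ok(ServerCertVerified::assertion())
    }

    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer,
        _dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, rustls::Error> {
        Ok(HandshakeSignatureValid::assertion())
    }

    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer,
        _dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, rustls::Error> {
        Ok(HandshakeSignatureValid::assertion())
    }

    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
        // The client advertises these schemes during the TLS handshake; an empty
        // list leaves the server nothing to choose and the handshake fails.
        vec![
            SignatureScheme::ECDSA_NISTP256_SHA256,
            SignatureScheme::ECDSA_NISTP384_SHA384,
            SignatureScheme::ED25519,
            SignatureScheme::RSA_PSS_SHA256,
            SignatureScheme::RSA_PSS_SHA384,
            SignatureScheme::RSA_PSS_SHA512,
            SignatureScheme::RSA_PKCS1_SHA256,
            SignatureScheme::RSA_PKCS1_SHA384,
            SignatureScheme::RSA_PKCS1_SHA512,
        ]
    }
}

// Include the code cornucopia generated into OUT_DIR; this defines the `queries` module.
include!(concat!(env!("OUT_DIR"), "/cornucopia.rs"));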
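create_pool only switches to TLS when the connection string's SSL mode isn't disable, and tokio-postgres falls back to prefer when no sslmode is given. The std-only sketch below mirrors that decision for URL-style connection strings, so you can check which of your DATABASE_URL values will go through rustls. wants_tls is a hypothetical helper: in real code tokio_postgres::Config::get_ssl_mode is the source of truth, and this toy ignores key=value connection strings and percent-encoding.

```rust
/// Hypothetical helper mirroring create_pool's check
/// `config.get_ssl_mode() != SslMode::Disable` for URL-style connection strings.
/// Toy parser: it ignores key=value strings and percent-encoding.
fn wants_tls(database_url: &str) -> bool {
    let query = match database_url.split_once('?') {
        Some((_, query)) => query,
        // No query string means no sslmode, and tokio-postgres defaults to `prefer`.
        None => return true,
    };
    // TLS is skipped only when sslmode=disable is present.
    !query
        .split('&')
        .filter_map(|pair| pair.split_once('='))
        .any(|(key, value)| key == "sslmode" && value == "disable")
}

fn main() {
    for url in [
        "postgres://postgres:secret@db:5432/app?sslmode=disable",
        "postgres://postgres:secret@db:5432/app?sslmode=require",
        "postgres://postgres:secret@db:5432/app",
    ] {
        println!("{url} -> TLS: {}", wants_tls(url));
    }
}
```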

Folder Structure

You should now have a folder structure something like this.

.
├── .devcontainer/
│   └── ...
├── crates/
│   ├── axum-server/
│   │   ├── main.rs
│   │   └── Cargo.toml
│   └── db/
│       ├── migrations/
│       │   └── 20220330110026_user_tables.sql
│       ├── queries/
│       │   └── users.sql
│       ├── src/
│       │   └── lib.rs
│       ├── build.rs
│       └── Cargo.toml
├── Cargo.toml
└── Cargo.lock

Testing our database crate

Make sure you're in the crates/db folder.

First, add the client-side dependencies to the project:

cargo add tokio-postgres
cargo add deadpool-postgres
cargo add tokio-postgres-rustls
cargo add postgres-types
cargo add tokio --features macros,rt-multi-thread
cargo add rustls
cargo add rustls-pki-types
cargo add webpki-roots
cargo add futures
cargo add serde --features derive

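Make sure everything builds. The build script reads DATABASE_URL and runs cornucopia against that database, so the variable must be set and Postgres must be reachable.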

cargo build

Add the following code to the bottom of your crates/db/src/lib.rs.

#[cfg(test)]
mod tests {
    use super::*;
    #[tokio::test]
    async fn load_users() {

        let db_url = std::env::var("DATABASE_URL").unwrap();
        let pool = create_pool(&db_url);

        let client = pool.get().await.unwrap();
    
        let users = crate::queries::users::get_users()
            .bind(&client)
            .all()
            .await
            .unwrap();
    
        dbg!(users);
    }
}

Run cargo test -- --nocapture and you should see output like the following (the users listed depend on what's in your database):

Running unittests src/lib.rs (/workspace/target/debug/deps/db-1a59f4c51c8578ce)

running 1 test
[crates/db/src/lib.rs:56] users = [
    User {
        id: 1,
        email: "[email protected]",
    },
    User {
        id: 2,
        email: "[email protected]",
    },
    User {
        id: 3,
        email: "[email protected]",
    },
]

test tests::load_users ... ok
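The load_users test unwraps DATABASE_URL, so it panics on a machine without a database. If you'd rather skip it there, a small helper like the hypothetical database_url_or_skip below can turn a missing or blank variable into None so the test returns early. This is a sketch of one option rather than something the setup requires; keep in mind that skipping silently can hide a misconfigured CI job.

```rust
use std::env::{self, VarError};

/// Hypothetical helper: Some(url) when DATABASE_URL is set and non-blank,
/// None otherwise, so a database test can skip instead of panicking.
fn database_url_or_skip(var: Result<String, VarError>) -> Option<String> {
    match var {
        Ok(url) if !url.trim().is_empty() => Some(url),
        _ => None,
    }
}

fn main() {
    // In the load_users test you would call this first and return early on None.
    let Some(_db_url) = database_url_or_skip(env::var("DATABASE_URL")) else {
        eprintln!("DATABASE_URL is not set; skipping database test");
        return;
    };
    println!("DATABASE_URL is set; the test would connect to the database");
}
```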