Creating a REST API with Axum + SQLx in Rust

I started using Axum a few weeks ago and, honestly, I'm a fan of the framework, so I'm writing this article to document my learning. In this article, we are going to build a small REST API using Axum as the web framework and SQLx for SQL queries.

If you don't know what Axum is, here is what its page says:

Axum is a web application framework that focuses on ergonomics and modularity.

High-level features:

  • Route requests to handlers with a macro-free API.

  • Declaratively parse requests using extractors.

  • Simple and predictable error handling model.

  • Generate responses with minimal boilerplate.

  • Take full advantage of the tower and tower-http ecosystem of middleware, services, and utilities.

In particular, the last point is what sets axum apart from other frameworks. axum doesn't have its own middleware system but instead uses tower::Service. This means axum gets timeouts, tracing, compression, authorization, and more, for free. It also enables you to share middleware with applications written using hyper or tonic.

Here is Axum's documentation.

About Sqlx:

SQLx is an async, pure Rust† SQL crate featuring compile-time checked queries without a DSL.

  • Truly Asynchronous. Built from the ground up using async/await for maximum concurrency.

  • Compile-time checked queries (if you want). See SQLx is not an ORM.

  • Database Agnostic. Support for PostgreSQL, MySQL, SQLite, and MSSQL.

  • Pure Rust. The Postgres and MySQL/MariaDB drivers are written in pure Rust using zero unsafe†† code.

  • Runtime Agnostic. Works on different runtimes (async-std / tokio / actix) and TLS backends (native-tls, rustls).

† The SQLite driver uses the libsqlite3 C library, as SQLite is an embedded database (the only way we could be pure Rust for SQLite is by porting all of SQLite to Rust).

†† SQLx uses #![forbid(unsafe_code)] unless the SQLite feature is enabled. As the SQLite driver interacts with C, those interactions are unsafe.

  • Cross-platform. Being native Rust, SQLx will compile anywhere Rust is supported.

  • Built-in connection pooling with sqlx::Pool.

  • Row streaming. Data is read asynchronously from the database and decoded on-demand.

  • Automatic statement preparation and caching. When using the high-level query API (sqlx::query), statements are prepared and cached per connection.

  • Simple (unprepared) query execution including fetching results into the same Row types used by the high-level API. Supports batch execution and returning results from all statements.

  • Transport Layer Security (TLS) where supported (MySQL and PostgreSQL).

  • Asynchronous notifications using LISTEN and NOTIFY for PostgreSQL.

  • Nested transactions with support for save points.

  • Any database driver for changing the database driver at runtime. An AnyPool connects to the driver indicated by the URL scheme.

Here is SQLx's documentation.

First, we generate our project folder.

cargo new axum_crud_api

Now we add the dependencies.

Cargo.toml

[package]
name = "axum_crud_api"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
axum = "0.5.9"
tokio = { version = "1.0", features = ["full"] }
serde = "1.0.137"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"]}

sqlx = { version = "0.5", features = ["runtime-tokio-native-tls", "json", "postgres"] }
anyhow = "1.0.58"
serde_json = "1.0.57"
tower-http = { version = "0.3.4", features = ["trace"] }

Let's write an example. This is like the Hello World example from Axum's GitHub page, with just a few changes. Here is the source code.

main.rs

use axum::{
    routing::{get},
    Router,
};

use std::net::SocketAddr;

#[tokio::main]
async fn main() {


    let app = Router::new()
        .route("/", get(root));

    let addr = SocketAddr::from(([127, 0, 0, 1], 8000));
    println!("listening on {}", addr);
    axum::Server::bind(&addr)
        .serve(app.into_make_service())
        .await
        .unwrap();

}

async fn root() -> &'static str {
    "Hello, World!"
}

Now we run the code

cargo run

It should print "listening on 127.0.0.1:8000" in our console, and if we open that address in our browser we should see this page:

[Screenshot: the browser showing "Hello, World!"]

This is going to be our directory structure:

axum_crud_api/
  ---migrations/
  ---.env
  ---src/
      ---main.rs
      ---errors.rs
      ---controllers/
          ---mod.rs
          ---task.rs
      ---models/
          ---mod.rs
          ---task.rs

Models

Let's create a models folder inside src to store our app's models, and create a file named task.rs in it. We also add a models/mod.rs file containing pub mod task; so that main.rs can find the module.

task.rs

use serde::{Deserialize, Serialize};

#[derive(sqlx::FromRow, Deserialize, Serialize)]
pub struct Task {
    pub id: i32,
    pub task: String,
}

#[derive(sqlx::FromRow, Deserialize, Serialize)]
pub struct NewTask {
    pub task: String,
}

#[derive(sqlx::FromRow, Deserialize, Serialize)]
pub struct UpdateTask {
    pub task: String,
}

We create a .env file in our root directory to store our database URL. Note there are no spaces around the =, since we will split the line on that character later:

DATABASE_URL=postgresql://user:password@localhost:5432/database

To create our database, we need to have sqlx-cli installed. Here are the instructions from the docs:

# supports all databases supported by SQLx
$ cargo install sqlx-cli

# only for Postgres
$ cargo install sqlx-cli --no-default-features --features native-tls,postgres

# use vendored OpenSSL (build from source)
$ cargo install sqlx-cli --features openssl-vendored

# use Rustls rather than OpenSSL (be sure to add the features for the databases you intend to use!)
$ cargo install sqlx-cli --no-default-features --features rustls

After we have sqlx-cli installed on our machine, we run the following command in our terminal to create our database:

sqlx database create

Then we run this command in our terminal. It creates a new file at migrations/<timestamp>_<name>.sql, and that is where we can add our schema:

sqlx migrate add task

In migrations/<timestamp>_task.sql:

CREATE TABLE task (
    id SERIAL PRIMARY KEY,
    task VARCHAR(255) NOT NULL
);

Then we run the following command in our terminal to apply the migrations:

sqlx migrate run

If the migration was applied, it will print a message like this in our terminal:

Applied <timestamp>/migrate task

Now, we will change a few things in our main.rs file to connect our app to the database.

main.rs

We are using Postgres, so we first need to import PgPoolOptions to handle the connection pool.

use axum::{
    extract::{Extension},routing::{get, post}, Router,
};

use sqlx::postgres::PgPoolOptions;
use std::net::SocketAddr;
use std::fs;
use anyhow::Context;

#[tokio::main]
async fn main() -> anyhow::Result<()> {

    // Read DATABASE_URL from the .env file; trim removes the trailing newline.
    let env = fs::read_to_string(".env").unwrap();
    let (key, database_url) = env.trim().split_once('=').unwrap();

    assert_eq!(key, "DATABASE_URL");

    tracing_subscriber::fmt::init();

    let pool = PgPoolOptions::new()
        .max_connections(50)
        .connect(database_url)
        .await
        .context("could not connect to database_url")?;

    let app = Router::new()
        .route("/hello", get(root));

    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    tracing::debug!("Listening on {}", addr);
    axum::Server::bind(&addr)
        .serve(app.into_make_service())
        .await?;

    Ok(())

}

We create a pool instance and set the maximum number of connections to 50. Then we pass our database URL to the connect method, which creates a new pool from PgPoolOptions and immediately establishes one connection. For more details, here is the doc.

We use std::fs to read the DATABASE_URL from our .env file and store it in the database_url variable.
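As a side note, this hand-rolled parsing is fragile: a trailing newline, or spaces around the =, would end up inside the URL. Here is a stdlib-only sketch of a slightly more defensive version (parse_env_line is a hypothetical helper written for this article, not part of any crate):

```rust
/// Parse a single KEY=VALUE line, trimming surrounding whitespace.
/// Returns None if the line contains no '='.
fn parse_env_line(line: &str) -> Option<(&str, &str)> {
    let (key, value) = line.trim().split_once('=')?;
    Some((key.trim(), value.trim()))
}

fn main() {
    // A line as it might come out of fs::read_to_string(".env").
    let line = "DATABASE_URL = postgresql://user:password@localhost:5432/database\n";
    let (key, value) = parse_env_line(line).expect("malformed .env line");
    assert_eq!(key, "DATABASE_URL");
    assert_eq!(value, "postgresql://user:password@localhost:5432/database");
    println!("{key} -> {value}");
}
```

In a real project, a crate like dotenvy also handles comments, quoting, and multiple variables; the point here is only that trimming keeps stray whitespace out of the connection URL.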

Controllers

We create a controllers folder inside src, with a mod.rs file containing pub mod task;, and a task.rs file.

task.rs

In this file, we are going to add the controllers to do CRUD operations.

To create our handlers, we need to import response::IntoResponse, http::StatusCode, Extension (to extract shared state), and Json.

GET

use axum::response::IntoResponse;
use axum::http::StatusCode;

use axum::{Extension, Json};
use sqlx::PgPool;

use crate::models::task;

pub async fn all_tasks(Extension(pool): Extension<PgPool>) -> impl IntoResponse {
    let sql = "SELECT * FROM task".to_string();

    let task = sqlx::query_as::<_, task::Task>(&sql)
        .fetch_all(&pool)
        .await
        .unwrap();

    (StatusCode::OK, Json(task))
}

The all_tasks controller retrieves all the tasks in our database. It receives the connection pool through the Extension extractor and returns all the tasks in JSON format.

We use query_as to make a SQL query that is mapped to a concrete type using FromRow, in this case task::Task, and use the fetch_all function, which executes the query and returns all the generated results collected into a Vec. More details here.

errors.rs

In this file, we are going to implement the IntoResponse trait to create custom errors and use them as a response for our controllers.

use axum::{http::StatusCode, response::IntoResponse, Json};
use serde_json::json;


pub enum CustomError {
    BadRequest,
    TaskNotFound,
    InternalServerError,
}

impl IntoResponse for CustomError {
    fn into_response(self) -> axum::response::Response {
        let (status, error_message) = match self {
            Self::InternalServerError => (
                StatusCode::INTERNAL_SERVER_ERROR,
                "Internal Server Error",
            ),
            Self::BadRequest => (StatusCode::BAD_REQUEST, "Bad Request"),
            Self::TaskNotFound => (StatusCode::NOT_FOUND, "Task Not Found"),
        };
        (status, Json(json!({"error": error_message}))).into_response()
    }
}

GET by Id

For this handler we also need the Path extractor and our CustomError type:

use axum::extract::Path;
use crate::errors::CustomError;

pub async fn task(
    Path(id): Path<i32>,
    Extension(pool): Extension<PgPool>,
) -> Result<Json<task::Task>, CustomError> {
    let sql = "SELECT * FROM task where id=$1".to_string();

    let task: task::Task = sqlx::query_as(&sql)
        .bind(id)
        .fetch_one(&pool)
        .await
        .map_err(|_| CustomError::TaskNotFound)?;

    Ok(Json(task))
}

In the task controller we pass an id and use the Path extractor to extract it from the URL. We pass the id to the query; if it is in the database, the controller returns the task as JSON, and if not, a Task Not Found message.

POST

pub async fn new_task(
    Json(task): Json<task::NewTask>,
    Extension(pool): Extension<PgPool>,
) -> Result<(StatusCode, Json<task::NewTask>), CustomError> {

    if task.task.is_empty() {
        return Err(CustomError::BadRequest)
    }
    let sql = "INSERT INTO task (task) values ($1)";

    let _ = sqlx::query(&sql)
        .bind(&task.task)
        .execute(&pool)
        .await
        .map_err(|_| {
            CustomError::InternalServerError
        })?;

    Ok((StatusCode::CREATED, Json(task)))
}

The new_task controller has a Json extractor as a parameter. According to its doc, Json is an extractor that consumes the request body and deserializes it as JSON into some target type; in this code, the target type is NewTask.

We check if the JSON has the task field empty, and if it does, the function returns a Bad Request error message.

We use sqlx::query to make an SQL query, store it in the sql variable, and use the bind function to bind a value to the query, in this case the task field of NewTask. If there is a problem with the query, the controller returns an Internal Server Error message.
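One small caveat: is_empty only rejects a zero-length string, so a task made entirely of spaces would still be inserted. A stricter check could trim first; here is a stdlib-only sketch (is_blank is a hypothetical helper, not part of axum or sqlx):

```rust
/// Returns true if the string is empty or contains only whitespace.
fn is_blank(s: &str) -> bool {
    s.trim().is_empty()
}

fn main() {
    assert!(is_blank(""));
    assert!(is_blank("   "));
    assert!(!is_blank("buy milk"));
    // In new_task, the guard would then read:
    // if is_blank(&task.task) { return Err(CustomError::BadRequest); }
    println!("whitespace-only tasks are rejected");
}
```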

PUT

pub async fn update_task(
    Path(id): Path<i32>,
    Json(task): Json<task::UpdateTask>,
    Extension(pool): Extension<PgPool>,
) -> Result<(StatusCode, Json<task::UpdateTask>), CustomError> {


    let sql = "SELECT * FROM task where id=$1".to_string();

    let _find: task::Task = sqlx::query_as(&sql)
        .bind(id)
        .fetch_one(&pool)
        .await
        .map_err(|_| {
            CustomError::TaskNotFound
        })?;

    sqlx::query("UPDATE task SET task=$1 WHERE id=$2")
        .bind(&task.task)
        .bind(id)
        .execute(&pool)
        .await
        .map_err(|_| CustomError::InternalServerError)?;

    Ok((StatusCode::OK, Json(task)))
}

In update_task we pass an id through the path and use the Path extractor, plus the JSON with the fields we want to update, in this case only the task field. But first, we check that the id passed is in the database; if it's not, the controller returns a Task Not Found message.

Then we pass the SQL query to the query function and pass the task and id to the bind function. We are using Postgres, so the first bind fills the $1 placeholder and the second fills $2. The controller returns the updated JSON.
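To make the ordering concrete, here is a toy, stdlib-only illustration of positional placeholders. This is not how sqlx works internally (it sends the parameters to Postgres separately from the statement, which is what prevents SQL injection); debug_bind is a hypothetical helper that only shows how $1 and $2 line up with the successive bind calls:

```rust
/// Toy substitution: splice the bound values into $1, $2, ... in order.
/// For illustration only; real drivers never build SQL strings this way.
fn debug_bind(sql: &str, params: &[&str]) -> String {
    let mut out = sql.to_string();
    for (i, p) in params.iter().enumerate() {
        out = out.replace(&format!("${}", i + 1), &format!("'{}'", p));
    }
    out
}

fn main() {
    // The first .bind() fills $1 (the task text), the second fills $2 (the id).
    let sql = "UPDATE task SET task=$1 WHERE id=$2";
    let rendered = debug_bind(sql, &["buy milk", "7"]);
    assert_eq!(rendered, "UPDATE task SET task='buy milk' WHERE id='7'");
    println!("{rendered}");
}
```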

DELETE

For this handler we also need Value and the json! macro from serde_json:

use serde_json::{json, Value};

pub async fn delete_task(
    Path(id): Path<i32>,
    Extension(pool): Extension<PgPool>,
) -> Result<(StatusCode, Json<Value>), CustomError> {


    let _find: task::Task = sqlx::query_as("SELECT * FROM task where id=$1")
        .bind(id)
        .fetch_one(&pool)
        .await
        .map_err(|_| {
            CustomError::TaskNotFound
        })?;

    sqlx::query("DELETE FROM task WHERE id=$1")
        .bind(id)
        .execute(&pool)
        .await
        .map_err(|_| CustomError::InternalServerError)?;

    Ok((StatusCode::OK, Json(json!({"msg": "Task Deleted"}))))
}

In delete_task we pass the id of the task we want to delete, check that it exists in the database, pass it to the bind function along with the DELETE statement, and return a message once the task is deleted.

Now, let's update our main.rs to add the controllers.

main.rs

use axum::{
    extract::{Extension},routing::{get, post, put, delete}, Router,
};

use sqlx::postgres::PgPoolOptions;
use std::net::SocketAddr;
use std::fs;
use anyhow::Context;
use tower_http::trace::TraceLayer;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

mod errors;
mod models;
mod controllers;

#[tokio::main]
async fn main() -> anyhow::Result<()> {

    // Read DATABASE_URL from the .env file; trim removes the trailing newline.
    let env = fs::read_to_string(".env").unwrap();
    let (key, database_url) = env.trim().split_once('=').unwrap();

    assert_eq!(key, "DATABASE_URL");

    tracing_subscriber::registry()
        .with(tracing_subscriber::EnvFilter::new(
            std::env::var("RUST_LOG")
                .unwrap_or_else(|_| "axum_crud_api=debug,tower_http=debug".into()),
        ))
        .with(tracing_subscriber::fmt::layer())
        .init();

    let pool = PgPoolOptions::new()
        .max_connections(50)
        .connect(database_url)
        .await
        .context("could not connect to database_url")?;

    let app = Router::new()
        .route("/hello", get(root))
        .route("/tasks", get(controllers::task::all_tasks))
        .route("/task", post(controllers::task::new_task))
        .route("/task/:id", get(controllers::task::task))
        .route("/task/:id", put(controllers::task::update_task))
        .route("/task/:id", delete(controllers::task::delete_task))
        .layer(Extension(pool))
        .layer(TraceLayer::new_for_http());

    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    tracing::debug!("Listening on {}", addr);
    axum::Server::bind(&addr)
        .serve(app.into_make_service())
        .await?;

    Ok(())

}

async fn root() -> &'static str {
    "Hello, World!"
}

We add tower_http::trace::TraceLayer to get request logging. To do that, we pass TraceLayer::new_for_http() as an argument to the layer function on our Router.

Here is the complete source code.

Thank you for taking the time to read this article.

If you have any recommendations about other packages, architectures, how to improve my code, my English, or anything else, please leave a comment or contact me through Twitter or LinkedIn.
