ys memos

Blog

Rustのaxumのextensionをよりシンプルに使う


2024/08/26


axumでカスタム構造体をハンドラーで使う方法を紹介しましたが、ハンドラー側ではExtension(hoge): Extension<Hoge>のような長い形式で書く必要がありました。

ここでは、それをよりシンプルに記述できるようにしてみます。


Routerへの追加では、参照先と同じようにExtensionを使う。

だが、ハンドラー側でExtensionの省略は、FromRequestPartsを実装することで実現可能。


参照先の記事に、以下のimplを追加する。

Extension(hoge) = Extension<Hoge>部分の記述を、このトレイトに実装しておくことで、ハンドラ側の引数をシンプルにすることができる。

#[axum::async_trait]
impl<S> axum::extract::FromRequestParts<S> for AccessCounter
where
    S: Send + Sync,
{
    type Rejection = ();

    async fn from_request_parts(
        parts: &mut axum::http::request::Parts,
        state: &S,
    ) -> Result<Self, Self::Rejection> {
        let Extension(service) = Extension::<Self>::from_request_parts(parts, state)
            .await
            .map_err(|_| ())?;
        Ok(service)
    }
}

use std::sync::{
    atomic::{AtomicUsize, Ordering},
    Arc,
};

#[derive(Clone)]
struct AccessCounter {
    counter: Arc<AtomicUsize>,
}

impl AccessCounter {
    pub fn new() -> Self {
        Self {
            counter: Arc::new(AtomicUsize::new(0)),
        }
    }

    pub fn access(&self) -> usize {
        self.counter.fetch_add(1, Ordering::Relaxed)
    }
}

#[axum::async_trait]
impl<S> axum::extract::FromRequestParts<S> for AccessCounter
where
    S: Send + Sync,
{
    type Rejection = ();

    async fn from_request_parts(
        parts: &mut axum::http::request::Parts,
        state: &S,
    ) -> Result<Self, Self::Rejection> {
        let Extension(service) = Extension::<Self>::from_request_parts(parts, state)
            .await
            .map_err(|_| ())?;
        Ok(service)
    }
}

use axum::{routing, Extension, Router};

async fn access_count(access_counter: AccessCounter) -> String {
    format!("{}\n", access_counter.access())
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let access_counter = AccessCounter::new();
    let app = Router::new()
        .route("/", routing::get(access_count))
        .layer(Extension(access_counter));
    let litener = tokio::net::TcpListener::bind("localhost:8080").await?;
    axum::serve(litener, app).await?;
    Ok(())
}
Cargo.toml
axum = "0.7.5"
tokio = { version = "1.39.3", features = ["full"] }


関連タグを探す