ys memos

Blog

RustでOpenAIのChatCompletionsをストリームで受けるmod 〜4.モジュール本体〜


rust

2023/08/31


本コードの動作確認は2023/05時点までしか行っていません。
当時動作していたコードをそのまま公開いたします。

現在のAPIでの動作確認はしていないこと、ご了承ください。

以前、OpenAIのChat Completionsを叩くアプリを作っていた。 しかし、いつしか熱が冷めて開発が止まっていたので、眠らせておくくらいならと思い、コードを公開する。

また、一つの記事にまとめたかったが、ボリュームが大きくなりすぎてしまったので、記事をいくつかに分割した。



モジュール本体の実装。

src/
└── chatter
    ├── error.rs
    ├── headers.rs
    ├── json.rs
    ├── message.rs
    ├── mod.rs             # <-
    ├── request_builder.rs
    ├── role.rs
    └── stream_data.rs


mod error;
mod headers;
mod json;
mod message;
mod request_builder;
mod role;
mod stream_data;

use futures_util::StreamExt;

use self::{
    error::{ChatterError, ChatterResult},
    message::ToMessage,
    role::Role,
    stream_data::ToStreamData,
};

pub struct Chatter {
    messages: Vec<self::message::Message>,
    is_streaming: bool,
}

impl Chatter {
    pub fn new() -> Self {
        Self {
            messages: vec![],
            is_streaming: false,
        }
    }

    pub async fn send<F>(&mut self, mut on_receive: F) -> ChatterResult<()>
    where
        F: FnMut(String) -> ChatterResult<()>,
    {
        if self.is_streaming {
            return Err(ChatterError::DuplicatedSendingError);
        }

        let json = self::json::Json::new(&self.messages);

        self.is_streaming = true;
        let received_text = self::request_builder::create()?
            .json(&json)
            .send()
            .await?
            .bytes_stream()
            .map(|chunk_result| match chunk_result {
                Ok(chunk) => {
                    let mut received = String::new();
                    for line in String::from_utf8_lossy(&chunk).as_ref().lines() {
                        if let Ok(data) = line.to_stream_data() {
                            for choice in data.choices {
                                if let Some(content) = choice.delta.content {
                                    received.push_str(content.as_str());
                                    if let Err(_) = on_receive(content) {
                                        self.is_streaming = false;
                                    }
                                }
                            }
                        }
                    }
                    received
                }
                Err(err) => {
                    eprintln!("Error while receiving chunk: {:?}", err);
                    "".into()
                }
            })
            .collect::<Vec<String>>()
            .await
            .join("");

        self.messages
            .push(received_text.to_message(Role::Assistant));

        self.is_streaming = false;

        Ok(())
    }

    pub fn push_user_message(&mut self, content: String) {
        self.messages.push(content.to_message(Role::User));
    }
}


mod error;
mod headers;
mod json;
mod message;
mod request_builder;
mod role;
mod stream_data;

use futures_util::StreamExt;

use self::{
    error::{ChatterError, ChatterResult},
    message::ToMessage,
    role::Role,
    stream_data::ToStreamData,
};

Chatterの定義。これそのものがMessage一覧を保持する。

is_streamingは、ストリーム受信中にメッセージを投げないためのフラグとしての意図で持たせた。

pub struct Chatter {
    messages: Vec<self::message::Message>,
    is_streaming: bool,
}

コンストラクタ。「Chatterを生成する」ということは、「新しいチャットルームを作成する」と同義である設計をしたため、このような実装になった。

「チャットルームへのアタッチ」を追加したければ、new_with_messages(messages: Vec<Message>) -> Selfなどのようにするとよい。

impl Chatter {
    pub fn new() -> Self {
        Self {
            messages: vec![],
            is_streaming: false,
        }
    }

リクエストを送る。ストリームが帰ってくるたびに、on_receiveが実行される。

    pub async fn send<F>(&mut self, mut on_receive: F) -> ChatterResult<()>
    where
        F: FnMut(String) -> ChatterResult<()>,
    {

ストリーム中の場合はエラーを返す。

        if self.is_streaming {
            return Err(ChatterError::DuplicatedSendingError);
        }

現在のメッセージ一覧からリクエストjsonを作成。

        let json = self::json::Json::new(&self.messages);

ストリームフラグをONにする。

        self.is_streaming = true;

received_textは、受け取ったDeltaをまとめた文字列が格納される。

.bytes_stream()により、レスポンスをバイトストリームとして受け取る。

.map()で処理しているのは、受け取ったDeltaを最後に接続する目的。

ストリームを受け取るたびに、バイト文字列変換→文字列StreamData変換→Delta文字列に対してon_receive実行、を行う。

        let received_text = self::request_builder::create()?
            .json(&json)
            .send()
            .await?
            .bytes_stream()
            .map(|chunk_result| match chunk_result {
                Ok(chunk) => {
                    let mut received = String::new();
                    for line in String::from_utf8_lossy(&chunk).as_ref().lines() {
                        if let Ok(data) = line.to_stream_data() {
                            for choice in data.choices {
                                if let Some(content) = choice.delta.content {
                                    received.push_str(content.as_str());
                                    if let Err(_) = on_receive(content) {
                                        self.is_streaming = false;
                                    }
                                }
                            }
                        }
                    }
                    received
                }
                Err(err) => {
                    eprintln!("Error while receiving chunk: {:?}", err);
                    "".into()
                }
            })
            .collect::<Vec<String>>()
            .await
            .join("");

受け取ったメッセージをChatterに追加し、ストリーム中フラグをOFFにする。

        self.messages
            .push(received_text.to_message(Role::Assistant));

        self.is_streaming = false;

        Ok(())
    }

Chatterにユーザメッセージを追加。

    pub fn push_user_message(&mut self, content: String) {
        self.messages.push(content.to_message(Role::User));
    }
}


関連タグを探す