2023/08/31
はじめに
本コードの動作確認は2023/05時点までしか行っていません。
当時動作していたコードをそのまま公開いたします。
現在のAPIでの動作確認はしていないこと、ご了承ください。
以前、OpenAIのChat Completionsを叩くアプリを作っていた。 しかし、いつしか熱が冷めて開発が止まっていたので、眠らせておくくらいならと思い、コードを公開する。
また、一つの記事にまとめたかったが、ボリュームが大きくなりすぎてしまったので、記事をいくつかに分割した。
目次
- 0 概要
- 1 カスタムエラー
- 2 型定義
- 3 リクエストユーティリティ
- 4 モジュール本体
実装対象
モジュール本体の実装。
src/
└── chatter
├── error.rs
├── headers.rs
├── json.rs
├── message.rs
├── mod.rs # <-
├── request_builder.rs
├── role.rs
└── stream_data.rs
コード
mod.rs
mod error;
mod headers;
mod json;
mod message;
mod request_builder;
mod role;
mod stream_data;
use futures_util::StreamExt;
use self::{
error::{ChatterError, ChatterResult},
message::ToMessage,
role::Role,
stream_data::ToStreamData,
};
pub struct Chatter {
messages: Vec<self::message::Message>,
is_streaming: bool,
}
impl Chatter {
pub fn new() -> Self {
Self {
messages: vec![],
is_streaming: false,
}
}
pub async fn send<F>(&mut self, mut on_receive: F) -> ChatterResult<()>
where
F: FnMut(String) -> ChatterResult<()>,
{
if self.is_streaming {
return Err(ChatterError::DuplicatedSendingError);
}
let json = self::json::Json::new(&self.messages);
self.is_streaming = true;
let received_text = self::request_builder::create()?
.json(&json)
.send()
.await?
.bytes_stream()
.map(|chunk_result| match chunk_result {
Ok(chunk) => {
let mut received = String::new();
for line in String::from_utf8_lossy(&chunk).as_ref().lines() {
if let Ok(data) = line.to_stream_data() {
for choice in data.choices {
if let Some(content) = choice.delta.content {
received.push_str(content.as_str());
if let Err(_) = on_receive(content) {
self.is_streaming = false;
}
}
}
}
}
received
}
Err(err) => {
eprintln!("Error while receiving chunk: {:?}", err);
"".into()
}
})
.collect::<Vec<String>>()
.await
.join("");
self.messages
.push(received_text.to_message(Role::Assistant));
self.is_streaming = false;
Ok(())
}
pub fn push_user_message(&mut self, content: String) {
self.messages.push(content.to_message(Role::User));
}
}
解説
mod.rs
mod error;
mod headers;
mod json;
mod message;
mod request_builder;
mod role;
mod stream_data;
use futures_util::StreamExt;
use self::{
error::{ChatterError, ChatterResult},
message::ToMessage,
role::Role,
stream_data::ToStreamData,
};
Chatter
の定義。これそのものがMessage
一覧を保持する。
is_streaming
は、ストリーム受信中にメッセージを投げないためのフラグとしての意図で持たせた。
pub struct Chatter {
messages: Vec<self::message::Message>,
is_streaming: bool,
}
コンストラクタ。「Chatter
を生成する」ということは、「新しいチャットルームを作成する」と同義である設計をしたため、このような実装になった。
「チャットルームへのアタッチ」を追加したければ、new_with_messages(messages: Vec<Message>) -> Self
などのようにするとよい。
impl Chatter {
pub fn new() -> Self {
Self {
messages: vec![],
is_streaming: false,
}
}
リクエストを送る。ストリームが帰ってくるたびに、on_receive
が実行される。
pub async fn send<F>(&mut self, mut on_receive: F) -> ChatterResult<()>
where
F: FnMut(String) -> ChatterResult<()>,
{
ストリーム中の場合はエラーを返す。
if self.is_streaming {
return Err(ChatterError::DuplicatedSendingError);
}
現在のメッセージ一覧からリクエストjsonを作成。
let json = self::json::Json::new(&self.messages);
ストリームフラグをONにする。
self.is_streaming = true;
received_text
は、受け取ったDeltaをまとめた文字列が格納される。
.bytes_stream()
により、レスポンスをバイトストリームとして受け取る。
.map()
で処理しているのは、受け取ったDeltaを最後に接続する目的。
ストリームを受け取るたびに、バイト文字列変換→文字列StreamData変換→Delta文字列に対してon_receive実行、を行う。
let received_text = self::request_builder::create()?
.json(&json)
.send()
.await?
.bytes_stream()
.map(|chunk_result| match chunk_result {
Ok(chunk) => {
let mut received = String::new();
for line in String::from_utf8_lossy(&chunk).as_ref().lines() {
if let Ok(data) = line.to_stream_data() {
for choice in data.choices {
if let Some(content) = choice.delta.content {
received.push_str(content.as_str());
if let Err(_) = on_receive(content) {
self.is_streaming = false;
}
}
}
}
}
received
}
Err(err) => {
eprintln!("Error while receiving chunk: {:?}", err);
"".into()
}
})
.collect::<Vec<String>>()
.await
.join("");
受け取ったメッセージをChatter
に追加し、ストリーム中フラグをOFFにする。
self.messages
.push(received_text.to_message(Role::Assistant));
self.is_streaming = false;
Ok(())
}
Chatter
にユーザメッセージを追加。
pub fn push_user_message(&mut self, content: String) {
self.messages.push(content.to_message(Role::User));
}
}