sheepd/src/api.rs
2025-04-30 10:02:04 +02:00

107 lines
3.1 KiB
Rust

use owl::{Deserialize, Serialize};
use rumqttc::{AsyncClient, Event, EventLoop, MqttOptions, Packet, Transport};
use std::time::Duration;
use tokio::time::sleep;
/// Parameters of a join request.
///
/// The type derives both `Deserialize` and `Serialize`, so the same
/// definition can be used to encode and to decode a request.
#[derive(Deserialize, Serialize)]
pub struct JoinParams {
    /// Join token; `None` when the request does not carry one.
    pub join_token: Option<String>,
    /// Machine ID.
    pub machine_id: String,
    /// Hostname.
    pub hostname: String,
    /// Public key identity, stored as a pair of strings.
    // NOTE(review): what each half of the pair contains is not visible in
    // this file — confirm against the code that builds the request.
    pub identity: (String, String),
}
/// Response to a join request.
///
/// The type derives both `Deserialize` and `Serialize`.
#[derive(Deserialize, Serialize)]
pub struct JoinResponse {
    /// Server token.
    pub token: String,
    /// Server identity, stored as a pair of strings.
    // NOTE(review): what each half of the pair contains is not visible in
    // this file — confirm against the code that produces the response.
    pub identity: (String, String),
    /// MQTT endpoint.
    // NOTE(review): presumably the value handed to `mqtt_connect` — confirm
    // against the caller.
    pub mqtt: String,
}
/// Create an MQTT client and event loop for `machine_id` that talk to the
/// broker at `mqtt` over WebSockets.
///
/// If `mqtt` starts with `ws://` an unencrypted WebSocket transport is used
/// (and a warning is logged). In every other case the address is treated as a
/// `wss://` endpoint: the scheme is prepended when missing and the default TLS
/// configuration is used.
///
/// The options are built with `8000` as the port argument and a keep-alive of
/// 60 seconds; the client is created with `AsyncClient::new(options, 10)`.
pub fn mqtt_connect(machine_id: &str, mqtt: &str) -> (rumqttc::AsyncClient, rumqttc::EventLoop) {
    // Decide once which scheme applies; everything below is shared.
    let insecure = mqtt.starts_with("ws://");
    let scheme = if insecure {
        log::warn!("Using unencrypted WebSocket connection");
        "ws://"
    } else {
        log::info!("Using encrypted WebSocket connection");
        "wss://"
    };

    // Strip any leading occurrence(s) of the scheme and put exactly one back,
    // so the broker address always carries the scheme chosen above.
    let broker_url = format!("{}{}", scheme, mqtt.trim_start_matches(scheme));

    // NOTE(review): with a WebSocket transport the port may be taken from the
    // URL rather than from this argument — confirm against the rumqttc docs.
    let mut options = MqttOptions::new(machine_id, &broker_url, 8000);

    let transport = if insecure {
        Transport::Ws
    } else {
        Transport::wss_with_default_config()
    };
    options.set_transport(transport);
    options.set_keep_alive(Duration::from_secs(60));

    AsyncClient::new(options, 10)
}
/// Drive the MQTT event loop and dispatch published messages.
///
/// `eventloop` is polled in an endless loop; this function never returns.
///
/// * Every incoming packet is logged at trace level. For a `Publish` packet
///   the topic and a copy of the payload bytes are passed to
///   `handle_payload`, and the returned future is spawned on its own Tokio
///   task, so the loop does not wait for the handler to finish.
/// * Outgoing events are only logged at trace level.
/// * When polling fails, the error is logged and the loop sleeps for one
///   second before polling again.
pub async fn run_event_loop<F, Fut>(mut eventloop: EventLoop, handle_payload: F)
where
    F: Fn(String, Vec<u8>) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = ()> + Send + 'static,
{
    log::info!("Handling MQTT events");

    loop {
        // Deal with the failure case first so the rest of the iteration only
        // has to handle a successfully received event.
        let event = match eventloop.poll().await {
            Ok(event) => event,
            Err(e) => {
                log::error!("MQTT eventloop error = {:?}", e);
                // Back off briefly before the next poll attempt.
                sleep(Duration::from_secs(1)).await;
                continue;
            }
        };

        match event {
            Event::Incoming(packet) => {
                log::trace!("Incoming = {:?}", packet);

                // Only published messages are forwarded; all other packet
                // types are ignored after being logged.
                if let Packet::Publish(message) = packet {
                    log::info!("Got payload with size {}", message.size());
                    let body = message.payload.to_vec();
                    // The task handle is dropped on purpose: the handler runs detached.
                    tokio::spawn(handle_payload(message.topic, body));
                }
            }
            Event::Outgoing(packet) => {
                log::trace!("Outgoing = {:?}", packet);
            }
        }
    }
}
/// Generic JSON API result.
///
/// Note that inside this module the name `Result` refers to this struct and
/// not to the standard library's `Result`.
#[derive(Deserialize, Serialize)]
pub struct Result {
    /// Status flag: the `Ok()` constructor sets it to `1`, the `Err()`
    /// constructor sets it to `0`.
    pub ok: u32,
}
// The constructors are deliberately capitalised, so the snake-case lint is
// silenced once for the whole block instead of on each function.
// NOTE(review): the names presumably mirror the standard `Ok`/`Err` — confirm.
#[allow(non_snake_case)]
impl Result {
    /// Build a result whose `ok` field is `1`.
    pub fn Ok() -> Self {
        Self { ok: 1 }
    }

    /// Build a result whose `ok` field is `0`.
    pub fn Err() -> Self {
        Self { ok: 0 }
    }
}