sheepd/src/api.rs
2025-05-02 12:53:28 +02:00

198 lines
4.9 KiB
Rust

use owl::{Deserialize, Serialize};
use rumqttc::{AsyncClient, Event, EventLoop, MqttOptions, Packet, Transport};
use std::time::Duration;
use tokio::time::sleep;
/// Normalize `host` into a base URL.
///
/// If `host` already starts with an `http://` or `https://` scheme (compared
/// case-insensitively, as URI schemes are), it is returned unchanged.
/// Otherwise `https://` is prepended, so plain hostnames default to TLS.
///
/// Only a real scheme prefix counts. A bare hostname that merely begins with
/// the letters "http" (e.g. `httpd.example.com`) still gets `https://` added.
pub fn domain(host: &str) -> String {
    // `str::get` returns `None` when the slice would run past the end or split
    // a multi-byte character, so short or non-ASCII input is handled safely.
    let has_scheme = ["http://", "https://"].iter().any(|scheme| {
        matches!(host.get(..scheme.len()), Some(prefix) if prefix.eq_ignore_ascii_case(scheme))
    });

    if has_scheme {
        host.to_string()
    } else {
        format!("https://{host}")
    }
}
/// Credentials submitted for a login request.
///
/// NOTE(review): keep this type free of a `Debug` derive so the password
/// cannot accidentally end up in log output.
#[derive(Deserialize, Serialize)]
pub struct LoginParam {
    /// Account name to log in as.
    pub username: String,
    /// Password for `username`.
    pub password: String,
}
/// Join Request: the parameters a machine supplies when asking to join.
#[derive(Deserialize, Serialize)]
pub struct JoinParams {
    /// Optional join token; `None` when no token is supplied.
    pub join_token: Option<String>,
    /// Machine ID of the joining host.
    pub machine_id: String,
    /// Hostname of the joining host.
    pub hostname: String,
    /// Public Key Identity, as a pair of strings (the meaning of each
    /// component is not visible here — TODO confirm).
    pub identity: (String, String),
}
/// Reply to a [`JoinParams`] request.
#[derive(Deserialize, Serialize)]
pub struct JoinResponse {
    /// Server Token issued on a successful join.
    pub token: String,
    /// Server Identity, as a pair of strings (component meaning not visible
    /// here — TODO confirm).
    pub identity: (String, String),
    /// MQTT endpoint; presumably the value later handed to `mqtt_connect`
    /// (verify against the caller).
    pub mqtt: String,
}
/// Create an MQTT client and its event loop for `machine_id`, talking to `mqtt`.
///
/// The transport is chosen from the scheme of `mqtt`:
/// * a `ws://` prefix selects a plain WebSocket connection and logs a warning;
/// * anything else is treated as a TLS WebSocket endpoint, normalized to a
///   `wss://` URL. This is the default.
///
/// The connection uses a 60-second keep-alive, and `10` is passed to
/// `AsyncClient::new` as its capacity.
pub fn mqtt_connect(machine_id: &str, mqtt: &str) -> (rumqttc::AsyncClient, rumqttc::EventLoop) {
    let (broker_url, transport) = if mqtt.starts_with("ws://") {
        log::warn!("Using unencrypted WebSocket connection");
        let url = format!("ws://{}", mqtt.trim_start_matches("ws://"));
        (url, Transport::Ws)
    } else {
        log::info!("Using encrypted WebSocket connection");
        let url = format!("wss://{}", mqtt.trim_start_matches("wss://"));
        (url, Transport::wss_with_default_config())
    };

    // NOTE(review): port 8000 is passed for both transports; confirm whether
    // rumqttc uses it for WebSocket connections or takes it from the URL.
    let mut options = MqttOptions::new(machine_id, broker_url, 8000);
    options.set_transport(transport);
    options.set_keep_alive(Duration::from_secs(60));

    AsyncClient::new(options, 10)
}
/// Drive the MQTT `eventloop` forever, dispatching every incoming publish.
///
/// For each `Publish` packet, `handle_payload` is called with the packet's
/// topic and a copy of its payload bytes. The returned future is spawned as a
/// detached Tokio task, so the loop keeps polling while the handler runs.
/// Every other incoming or outgoing event is only trace-logged. A poll error
/// is logged and followed by a one-second pause before polling resumes. The
/// loop has no exit, so this function never returns.
pub async fn run_event_loop<F, Fut>(mut eventloop: EventLoop, handle_payload: F)
where
    F: Fn(String, Vec<u8>) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = ()> + Send + 'static,
{
    log::info!("Handling MQTT events");
    loop {
        let event = match eventloop.poll().await {
            Ok(event) => event,
            Err(e) => {
                log::error!("MQTT eventloop error = {:?}", e);
                sleep(Duration::from_secs(1)).await;
                continue;
            }
        };

        match event {
            Event::Outgoing(outgoing) => log::trace!("Outgoing = {:?}", outgoing),
            Event::Incoming(incoming) => {
                log::trace!("Incoming = {:?}", incoming);
                if let Packet::Publish(publish) = incoming {
                    log::info!("Got payload with size {}", publish.size());
                    let payload = publish.payload.to_vec();
                    tokio::spawn(handle_payload(publish.topic, payload));
                }
            }
        }
    }
}
/// A listing of devices, each described by a [`DeviceEntry`].
#[derive(Deserialize, Serialize)]
pub struct DeviceList {
    /// The devices in this listing.
    pub devices: Vec<DeviceEntry>,
}
/// A single device as it appears in a [`DeviceList`].
#[derive(Deserialize, Serialize)]
pub struct DeviceEntry {
    /// Device identifier (presumably the machine ID — confirm against the server).
    pub id: String,
    /// Hostname of the device.
    pub hostname: String,
    /// Online status flag; how "online" is determined is not visible here.
    pub online: bool,
}
/// Generic JSON API result envelope.
///
/// The constructors in this module set exactly one of `ok` / `err`.
///
/// No trait bound is placed on `T` here. Bounds belong on the `impl` blocks
/// that need them, and the derives add their own `T: Serialize` /
/// `T: Deserialize` bounds to the generated impls.
#[derive(Deserialize, Serialize)]
pub struct Result<T> {
    /// Success payload, present when the call succeeded.
    pub ok: Option<T>,
    /// Error message, present when the call failed.
    pub err: Option<String>,
}
impl<T: Serialize> Result<T> {
    /// Build a successful result carrying `val`.
    // Variant-like casing is intentional: call sites read as `Result::OkVal(..)`,
    // mirroring `std::result::Result::Ok`.
    #[allow(non_snake_case)]
    pub fn OkVal(val: T) -> Self {
        Self {
            ok: Some(val),
            err: None,
        }
    }

    /// Convert this API envelope into a standard `Result`.
    ///
    /// `ok` takes precedence: if it is set, its value is returned as `Ok` even
    /// when `err` is also present.
    ///
    /// # Errors
    ///
    /// Returns the `err` message when `ok` is absent. If neither field is set
    /// (for instance, a response carrying neither key), a descriptive error is
    /// returned instead of panicking.
    pub fn as_result(self) -> std::result::Result<T, String> {
        match (self.ok, self.err) {
            (Some(ok), _) => Ok(ok),
            (None, Some(err)) => Err(err),
            (None, None) => Err(String::from(
                "API result contained neither `ok` nor `err`",
            )),
        }
    }
}
impl Result<i32> {
    /// Successful result without a meaningful payload; `ok` is set to `1`.
    // Variant-like casing mirrors `std::result::Result::Ok`.
    #[allow(non_snake_case)]
    pub fn Ok() -> Self {
        let marker = 1;
        Self {
            err: None,
            ok: Some(marker),
        }
    }

    /// Failed result whose `err` holds a copy of `msg`.
    // Variant-like casing mirrors `std::result::Result::Err`.
    #[allow(non_snake_case)]
    pub fn Err(msg: &str) -> Self {
        let message = msg.to_owned();
        Self {
            err: Some(message),
            ok: None,
        }
    }
}
/// An action tagged with a unique ID.
///
/// `ServerResponse::of` copies this ID into its response, so a response can
/// be matched back to the action that prompted it.
#[derive(Debug, Serialize, Deserialize)]
pub struct ClientAction {
    /// Unique identifier, freshly generated by [`ClientAction::new`].
    pub id: ulid::Ulid,
    /// The action itself.
    pub action: ClientActions,
}

impl ClientAction {
    /// Wrap `action` together with a newly generated ULID.
    pub fn new(action: ClientActions) -> Self {
        let id = ulid::Ulid::new();
        Self { id, action }
    }
}
/// The kinds of action a [`ClientAction`] can carry.
#[derive(Debug, Serialize, Deserialize)]
pub enum ClientActions {
    /// An osquery request; the string is presumably the query text (TODO confirm).
    OSQuery(String),
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ServerResponse {
pub id: ulid::Ulid,
pub response: ServerResponses,
}
impl ServerResponse {
pub fn of(client: &ClientAction, resp: ServerResponses) -> Self {
Self {
id: client.id,
response: resp,
}
}
}
/// Response payloads carried by a [`ServerResponse`].
#[derive(Debug, Serialize, Deserialize)]
pub enum ServerResponses {
    /// Output for an osquery request; presumably the result of the matching
    /// `ClientActions::OSQuery` (format not visible here — TODO confirm).
    OSQuery(String),
}