Compare commits

..

No commits in common. "main" and "main" have entirely different histories.
main ... main

18 changed files with 876 additions and 1131 deletions

View file

@ -1,23 +0,0 @@
matrix:
platform:
- linux/amd64
- linux/arm64
labels:
platform: ${platform}
when:
- event: push
branch: main
steps:
- name: "PKGBUILD"
image: git.hydrar.de/jmarya/pacco:latest
commands:
- pacco build --ci --push navos
environment:
PACCO_HOST: "https://pac.hydrar.de"
PACCO_TOKEN:
from_secret: pacco_token
SIGN_KEY:
from_secret: navos_key

1114
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -3,18 +3,10 @@ name = "sheepd"
version = "0.1.0" version = "0.1.0"
edition = "2024" edition = "2024"
[lib]
name = "sheepd"
path = "src/lib.rs"
[[bin]] [[bin]]
name = "sheepd" name = "sheepd"
path = "src/sheepd.rs" path = "src/sheepd.rs"
[[bin]]
name = "sheepctl"
path = "src/sheepctl.rs"
[[bin]] [[bin]]
name = "herd" name = "herd"
path = "src/herd.rs" path = "src/herd.rs"
@ -38,14 +30,9 @@ axum-client-ip = { version = "1.0.0", optional = true }
toml = "0.8.21" toml = "0.8.21"
hex = "0.4.3" hex = "0.4.3"
rand = "0.9.1" rand = "0.9.1"
based_auth = { git = "https://git.hydrar.de/jmarya/based_auth", features = ["axum"] } based = { git = "https://git.hydrar.de/jmarya/based", branch = "owl" }
http2 = "0.4.21" http2 = "0.4.21"
ureq = { version = "3.0.11", features = ["json"] } ureq = { version = "3.0.11", features = ["json"] }
rumqttc = { version = "0.24.0", features = ["url", "websocket"] } rumqttc = { version = "0.24.0", features = ["url", "websocket"] }
sage = { git = "https://git.hydrar.de/jmarya/sage" } sage = { git = "https://git.hydrar.de/jmarya/sage" }
dashmap = "6.1.0" dashmap = "6.1.0"
ulid = { version = "1.2.1", features = ["serde"] }
chrono = "0.4.41"
directories = "6.0.0"
inquire = "0.7.5"
axum-extra = { version = "0.10.1", features = ["typed-header"] }

View file

@ -1,18 +0,0 @@
FROM rust:buster AS builder
RUN rustup default nightly
COPY . /app
WORKDIR /app
RUN cargo build --release --bin herd --features herd
FROM git.hydrar.de/navos/navos:latest
RUN pacman-key --init && pacman-key --populate archlinux && pacman-key --populate navos && pacman -Syu --noconfirm && pacman -Syu --noconfirm openssl-1.1
COPY --from=builder /app/target/release/herd /herd
WORKDIR /
CMD ["/herd"]

View file

@ -1,41 +0,0 @@
pkgbase=sheep
pkgname=('sheepd' 'sheepctl')
pkgver=2025.05.05_b010027
pkgrel=1
arch=('x86_64' 'aarch64')
url="https://git.hydrar.de/navos/sheepd"
license=('MIT')
makedepends=('rustup')
source=("repo::git+https://git.hydrar.de/navos/sheepd.git")
sha256sums=("SKIP")
pkgver() {
cd "$srcdir/repo"
echo "$(date +%Y.%m.%d)_$(git rev-parse --short HEAD)"
}
prepare() {
cd "$srcdir/repo"
rustup default nightly
cargo fetch
}
build() {
cd "$srcdir/repo"
cargo build --release --bin sheepd
cargo build --release --bin sheepctl
}
package_sheepd() {
pkgdesc="sheep daemon"
depends=('osquery')
cd "$srcdir/repo"
install -Dm755 "sheepd.service" "$pkgir/etc/systemd/systemd/sheepd.service"
install -Dm755 "target/release/sheepd" "$pkgdir/usr/bin/sheepd"
}
package_sheepctl() {
pkgdesc="CLI for controling your herd"
cd "$srcdir/repo"
install -Dm755 "target/release/sheepctl" "$pkgdir/usr/bin/sheepctl"
}

View file

@ -10,11 +10,3 @@ services:
- ./mosquitto/data:/mosquitto/data - ./mosquitto/data:/mosquitto/data
- ./mosquitto/log:/mosquitto/log - ./mosquitto/log:/mosquitto/log
restart: unless-stopped restart: unless-stopped
herd:
build: .
ports:
- 8080:8000
volumes:
- ./herd:/herd
restart: unless-stopped

View file

@ -1,11 +0,0 @@
[Unit]
Description=Sheep Daemon
After=network.target
[Service]
Type=simple
ExecStart=/usr/bin/sheepd
Restart=on-failure
[Install]
WantedBy=multi-user.target

View file

@ -3,21 +3,7 @@ use rumqttc::{AsyncClient, Event, EventLoop, MqttOptions, Packet, Transport};
use std::time::Duration; use std::time::Duration;
use tokio::time::sleep; use tokio::time::sleep;
pub fn domain(host: &str) -> String {
if host.starts_with("http") {
return host.to_string();
} else {
format!("https://{host}")
}
}
#[derive(Deserialize, Serialize)] #[derive(Deserialize, Serialize)]
pub struct LoginParam {
pub username: String,
pub password: String,
}
#[derive(Deserialize, Serialize, Clone)]
/// Join Request /// Join Request
pub struct JoinParams { pub struct JoinParams {
/// Optional join token /// Optional join token
@ -40,24 +26,6 @@ pub struct JoinResponse {
pub mqtt: String, pub mqtt: String,
} }
#[derive(Deserialize, Serialize)]
pub struct ShellParam {
pub cmd: String,
pub cwd: String,
}
#[derive(Deserialize, Serialize)]
pub struct QueryParam {
pub query: String,
}
#[derive(Deserialize, Serialize, Debug)]
pub struct ShellResponse {
pub stdout: String,
pub stderr: String,
pub status: i32,
}
/// Setup a MQTT connection for `machine_id` on `mqtt`. /// Setup a MQTT connection for `machine_id` on `mqtt`.
/// ///
/// This will connect either over `ws://` or `wss://` depending on the scheme of `mqtt`. By default it will use `wss://`. /// This will connect either over `ws://` or `wss://` depending on the scheme of `mqtt`. By default it will use `wss://`.
@ -120,99 +88,20 @@ where
} }
} }
#[derive(Deserialize, Serialize)]
pub struct DeviceList {
pub devices: Vec<DeviceEntry>,
}
#[derive(Deserialize, Serialize)]
pub struct DeviceEntry {
pub id: String,
pub hostname: String,
pub online: bool,
}
#[derive(Deserialize, Serialize)] #[derive(Deserialize, Serialize)]
/// Generic JSON API result /// Generic JSON API result
pub struct Result<T: Serialize> { pub struct Result {
pub ok: Option<T>, pub ok: u32,
pub err: Option<String>,
} }
impl<T: Serialize> Result<T> { impl Result {
#[allow(non_snake_case)]
pub fn OkVal(val: T) -> Self {
Self {
ok: Some(val),
err: None,
}
}
pub fn as_result(self) -> std::result::Result<T, String> {
if let Some(ok) = self.ok {
Ok(ok)
} else {
Err(self.err.unwrap())
}
}
#[allow(non_snake_case)]
pub fn Err(msg: &str) -> Self {
Self {
ok: None,
err: Some(msg.to_string()),
}
}
}
impl Result<i32> {
#[allow(non_snake_case)] #[allow(non_snake_case)]
pub fn Ok() -> Self { pub fn Ok() -> Self {
Self { Self { ok: 1 }
ok: Some(1),
err: None,
} }
#[allow(non_snake_case)]
pub fn Err() -> Self {
Self { ok: 0 }
} }
} }
#[derive(Debug, Serialize, Deserialize)]
pub struct ClientAction {
pub id: ulid::Ulid,
pub action: ClientActions,
}
impl ClientAction {
pub fn new(action: ClientActions) -> Self {
Self {
id: ulid::Ulid::new(),
action,
}
}
}
#[derive(Debug, Serialize, Deserialize)]
pub enum ClientActions {
OSQuery(String),
Shell(String, String),
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ServerResponse {
pub id: ulid::Ulid,
pub response: ServerResponses,
}
impl ServerResponse {
pub fn of(client: &ClientAction, resp: ServerResponses) -> Self {
Self {
id: client.id,
response: resp,
}
}
}
#[derive(Debug, Serialize, Deserialize)]
pub enum ServerResponses {
OSQuery(String),
Shell(ShellResponse),
}

View file

@ -1,33 +1,25 @@
use api::ServerResponse;
use axum::{ use axum::{
Router, Router,
routing::{get, post}, routing::{get, post},
}; };
use axum_client_ip::ClientIpSource; use axum_client_ip::ClientIpSource;
use dashmap::DashMap; use based::auth::User;
use owl::{prelude::*, set_global_db}; use owl::{prelude::*, set_global_db};
use rand::RngCore; use rand::RngCore;
use std::{net::SocketAddr, path::PathBuf}; use std::{net::SocketAddr, path::PathBuf};
mod api; mod api;
use based_auth::User;
mod herd_core; mod herd_core;
use crate::herd_core::mqtt::{handle_mqtt, listen_to_devices}; use crate::herd_core::mqtt::{handle_mqtt, listen_to_devices};
use herd_core::model::Machine;
use herd_core::{ use herd_core::{
config::Config, config::Config,
route::{device_get_api, device_osquery, device_shell_cmd, join_device, login_user}, route::{join_device, login_user},
}; };
use herd_core::{model::Machine, route::devices_list};
use rumqttc::AsyncClient;
use sage::Identity; use sage::Identity;
use tokio::sync::OnceCell; use tokio::sync::OnceCell;
pub static IDENTITY: OnceCell<Identity> = OnceCell::const_new(); pub static IDENTITY: OnceCell<Identity> = OnceCell::const_new();
pub static CONFIG: OnceCell<Config> = OnceCell::const_new(); pub static CONFIG: OnceCell<Config> = OnceCell::const_new();
pub static MQTT: OnceCell<AsyncClient> = OnceCell::const_new();
pub static ONLINE: OnceCell<DashMap<String, chrono::DateTime<chrono::Utc>>> = OnceCell::const_new();
pub static DISPATCH: OnceCell<DashMap<String, tokio::sync::oneshot::Sender<ServerResponse>>> =
OnceCell::const_new();
fn generate_token() -> String { fn generate_token() -> String {
let mut rng = rand::rng(); let mut rng = rand::rng();
@ -52,25 +44,17 @@ async fn main() {
let config = Config::default(); let config = Config::default();
let _ = crate::CONFIG.set(config); let _ = crate::CONFIG.set(config);
crate::ONLINE.set(DashMap::new()).unwrap();
crate::DISPATCH.set(DashMap::new()).unwrap();
let db = Database::filesystem("./herd/db"); let db = Database::filesystem("./herd/db");
set_global_db!(db); set_global_db!(db);
let _ = User::create("admin".to_string(), "admin", based_auth::UserRole::Admin).await; let _ = User::create("admin".to_string(), "admin", based::auth::UserRole::Admin).await;
let device = Router::new() let device = Router::new()
.route("/join", post(join_device)) .route("/join", post(join_device))
.layer(ClientIpSource::ConnectInfo.into_extension()); // Direct IP .layer(ClientIpSource::ConnectInfo.into_extension()); // Direct IP
// .layer(ClientIpSource::XRealIp.into_extension()) // Proxy // .layer(ClientIpSource::XRealIp.into_extension()) // Proxy
let user = Router::new() let user = Router::new().route("/login", post(login_user));
.route("/login", post(login_user))
.route("/device/{device_id}", get(device_get_api))
.route("/device/{device_id}/shell", post(device_shell_cmd))
.route("/device/{device_id}/osquery", post(device_osquery))
.route("/devices", get(devices_list));
let app = Router::new().merge(device).merge(user); let app = Router::new().merge(device).merge(user);
@ -79,7 +63,6 @@ async fn main() {
let (client, eventloop) = api::mqtt_connect("server", &crate::CONFIG.get().unwrap().mqtt); let (client, eventloop) = api::mqtt_connect("server", &crate::CONFIG.get().unwrap().mqtt);
listen_to_devices(&client).await; listen_to_devices(&client).await;
crate::MQTT.set(client);
tokio::spawn(async { tokio::spawn(async {
let listener = tokio::net::TcpListener::bind("0.0.0.0:8000").await.unwrap(); let listener = tokio::net::TcpListener::bind("0.0.0.0:8000").await.unwrap();

View file

@ -1,17 +1,9 @@
use std::sync::Arc;
use crate::Machine; use crate::Machine;
use crate::api::{ClientAction, ClientActions, ServerResponse};
use owl::prelude::*; use owl::prelude::*;
use owl::{Serialize, get, query}; use owl::{Serialize, get, query};
use rumqttc::AsyncClient; use rumqttc::AsyncClient;
use sage::PersonaIdentity; use sage::PersonaIdentity;
pub fn is_within_80_seconds(time: chrono::DateTime<chrono::Utc>) -> bool {
let now = chrono::Utc::now();
now.signed_duration_since(time).num_seconds() <= 80
}
/// Handle herd MQTT /// Handle herd MQTT
pub async fn handle_mqtt(topic: String, data: Vec<u8>) { pub async fn handle_mqtt(topic: String, data: Vec<u8>) {
log::info!("Received client request from {topic}"); log::info!("Received client request from {topic}");
@ -25,62 +17,18 @@ pub async fn handle_mqtt(topic: String, data: Vec<u8>) {
.unwrap(); .unwrap();
// TODO : check for recency // TODO : check for recency
println!("got raw: {}", String::from_utf8(dec.payload).unwrap());
match cat { match cat {
"online" => { "online" => {
if let Some(online) = crate::ONLINE.get().unwrap().get(client) { log::info!("Device {client} reported ONLINE");
if !is_within_80_seconds(*online) {
log::info!("Device {client} came back ONLINE");
}
} else {
log::info!("Device {client} went ONLINE");
}
crate::ONLINE
.get()
.unwrap()
.insert(client.to_string(), chrono::Utc::now());
}
"respond" => {
let resp: ServerResponse = serde_json::from_slice(&dec.payload).unwrap();
log::info!("Got response {:?}", resp);
let (id, entry) = crate::DISPATCH
.get()
.unwrap()
.remove(&resp.id.to_string())
.unwrap();
if entry.send(resp).is_err() {
log::error!(
"Could not send back response for action {id}. Probably due to timeout"
);
}
} }
_ => {} _ => {}
} }
} }
pub struct TaskWaiter {
pub id: ulid::Ulid,
pub recv: tokio::sync::oneshot::Receiver<ServerResponse>,
}
impl TaskWaiter {
pub async fn wait_for(self, timeout: std::time::Duration) -> Option<ServerResponse> {
if let Ok(in_time) = tokio::time::timeout(timeout, self.recv).await {
return in_time.ok();
}
None
}
}
/// Send a message to a registered `machine` /// Send a message to a registered `machine`
pub async fn send_msg( pub async fn send_msg<T: Serialize>(client: &AsyncClient, machine: &Model<Machine>, request: T) {
client: &AsyncClient,
machine: &Model<Machine>,
request: ClientAction,
) -> TaskWaiter {
let data = serde_json::to_string(&request).unwrap(); let data = serde_json::to_string(&request).unwrap();
let pk = &machine.read().identity; let pk = &machine.read().identity;
let rec = pk.enc_key().unwrap(); let rec = pk.enc_key().unwrap();
@ -96,17 +44,6 @@ pub async fn send_msg(
.publish(topic, rumqttc::QoS::AtMostOnce, true, payload) .publish(topic, rumqttc::QoS::AtMostOnce, true, payload)
.await .await
.unwrap(); .unwrap();
let (sender, recv) = tokio::sync::oneshot::channel();
crate::DISPATCH
.get()
.unwrap()
.insert(request.id.to_string(), sender);
TaskWaiter {
id: request.id,
recv,
}
} }
/// Subscribe to all `device->server` topics /// Subscribe to all `device->server` topics
@ -116,10 +53,6 @@ pub async fn listen_to_device(client: &AsyncClient, machine_id: &str) {
.subscribe(format!("{machine_id}/online"), rumqttc::QoS::AtMostOnce) .subscribe(format!("{machine_id}/online"), rumqttc::QoS::AtMostOnce)
.await .await
.unwrap(); .unwrap();
client
.subscribe(format!("{machine_id}/respond"), rumqttc::QoS::AtMostOnce)
.await
.unwrap();
} }
/// Subscibe to incoming messages from all registered machines /// Subscibe to incoming messages from all registered machines

View file

@ -1,165 +1,19 @@
use std::ops::Deref;
use crate::api; use crate::api;
use crate::api::ClientAction;
use crate::api::ClientActions;
use crate::api::JoinResponse; use crate::api::JoinResponse;
use crate::api::LoginParam;
use crate::api::ShellResponse;
use crate::herd_core::model::Machine; use crate::herd_core::model::Machine;
use crate::herd_core::mqtt::listen_to_device;
use axum::Json; use axum::Json;
use axum::extract::FromRequestParts;
use axum::extract::Path;
use axum::http::StatusCode; use axum::http::StatusCode;
use axum_client_ip::ClientIp; use axum_client_ip::ClientIp;
use axum_extra::TypedHeader; use based::auth::Sessions;
use axum_extra::headers::Authorization; use based::auth::User;
use axum_extra::headers::authorization::Bearer;
use based_auth::APIUser;
use based_auth::Sessions;
use based_auth::User;
use owl::get;
use owl::prelude::Model;
use owl::query;
use owl::save; use owl::save;
use serde::Deserialize; use serde::Deserialize;
use serde_json::json; use serde_json::json;
use sheepd::DeviceEntry;
use sheepd::DeviceList;
use super::mqtt::is_within_80_seconds; #[derive(Deserialize)]
use super::mqtt::send_msg; pub struct LoginParam {
username: String,
macro_rules! check_admin { password: String,
($user:ident) => {
if !$user.read().is_admin() {
return (
StatusCode::UNAUTHORIZED,
Json(api::Result::Err("Invalid credentials")),
);
}
};
}
pub async fn device_osquery(
Path(device_id): Path<String>,
APIUser(user): APIUser,
Json(payload): Json<api::QueryParam>,
) -> (StatusCode, Json<api::Result<String>>) {
check_admin!(user);
let machine: Option<Model<Machine>> = get!(device_id);
if let Some(machine) = machine {
let resp = send_msg(
crate::MQTT.get().unwrap(),
&machine,
ClientAction::new(ClientActions::OSQuery(payload.query)),
)
.await;
if let Some(resp) = resp.wait_for(std::time::Duration::from_secs(60)).await {
let r = match resp.response {
api::ServerResponses::OSQuery(res) => res,
_ => unreachable!(),
};
(StatusCode::OK, Json(api::Result::OkVal(r)))
} else {
(
StatusCode::BAD_GATEWAY,
Json(api::Result::Err("Did not receive response from device")),
)
}
} else {
(StatusCode::NOT_FOUND, Json(api::Result::Err("Not Found")))
}
}
pub async fn device_shell_cmd(
Path(device_id): Path<String>,
APIUser(user): APIUser,
Json(payload): Json<api::ShellParam>,
) -> (StatusCode, Json<api::Result<ShellResponse>>) {
check_admin!(user);
let machine: Option<Model<Machine>> = get!(device_id);
if let Some(machine) = machine {
let resp = send_msg(
crate::MQTT.get().unwrap(),
&machine,
ClientAction::new(ClientActions::Shell(payload.cmd, payload.cwd)),
)
.await;
if let Some(resp) = resp.wait_for(std::time::Duration::from_secs(60)).await {
let r = match resp.response {
api::ServerResponses::Shell(shell_response) => shell_response,
_ => unreachable!(),
};
(StatusCode::OK, Json(api::Result::OkVal(r)))
} else {
(
StatusCode::BAD_GATEWAY,
Json(api::Result::Err("Did not receive response from device")),
)
}
} else {
(StatusCode::NOT_FOUND, Json(api::Result::Err("Not Found")))
}
}
pub async fn device_get_api(
Path(device_id): Path<String>,
APIUser(user): APIUser,
) -> (StatusCode, Json<api::Result<DeviceEntry>>) {
check_admin!(user);
let machine: Option<Model<Machine>> = get!(device_id.clone());
if let Some(machine) = machine {
let api = machine.read();
let api = DeviceEntry {
id: device_id.clone(),
hostname: api.hostname.clone(),
online: device_online(&device_id),
};
(StatusCode::OK, Json(api::Result::OkVal(api)))
} else {
let res = api::Result::<DeviceEntry>::Err("Not Found");
(StatusCode::NOT_FOUND, Json(res))
}
}
pub fn device_online(id: &String) -> bool {
crate::ONLINE
.get()
.unwrap()
.get(id)
.map(|x| is_within_80_seconds(*x.deref()))
.unwrap_or(false)
}
pub async fn devices_list(APIUser(user): APIUser) -> (StatusCode, Json<api::Result<DeviceList>>) {
check_admin!(user);
let machines: Vec<Model<Machine>> = query!(|_| true);
let mut ret = vec![];
for mac in machines {
let id = mac.read().id.to_string().replace("-", "");
let online_state = device_online(&id);
ret.push(DeviceEntry {
id: id,
hostname: mac.read().hostname.clone(),
online: online_state,
});
}
(
StatusCode::OK,
Json(api::Result::OkVal(DeviceList { devices: ret })),
)
} }
pub async fn login_user(Json(payload): Json<LoginParam>) -> (StatusCode, Json<serde_json::Value>) { pub async fn login_user(Json(payload): Json<LoginParam>) -> (StatusCode, Json<serde_json::Value>) {
@ -167,15 +21,9 @@ pub async fn login_user(Json(payload): Json<LoginParam>) -> (StatusCode, Json<se
let u = User::find(&payload.username).await.unwrap(); let u = User::find(&payload.username).await.unwrap();
if u.read().verify_pw(&payload.password) { if u.read().verify_pw(&payload.password) {
let ses = u.read().session().await; let ses = u.read().session().await;
( (StatusCode::OK, Json(json!({"token": ses.read().token})))
StatusCode::OK,
Json(json!(api::Result::OkVal(ses.read().token.as_str()))),
)
} else { } else {
( (StatusCode::FORBIDDEN, Json(json!({"error": "invalid"})))
StatusCode::FORBIDDEN,
Json(json!(api::Result::<api::ShellResponse>::Err("invalid"))),
)
} }
} }
@ -193,11 +41,9 @@ pub async fn join_device(
payload.machine_id payload.machine_id
); );
let machine = Machine::from_join_param(payload.clone()); let machine = Machine::from_join_param(payload);
let new_token = machine.token.clone(); let new_token = machine.token.clone();
listen_to_device(crate::MQTT.get().unwrap(), &payload.machine_id).await;
save!(machine); save!(machine);
let i = crate::IDENTITY.get().unwrap(); let i = crate::IDENTITY.get().unwrap();

View file

@ -1,2 +0,0 @@
pub mod api;
pub use api::*;

View file

@ -1,20 +0,0 @@
use sheepctl_core::{
args::{DeviceCommands, SheepctlArgs, SheepctlCommand},
cmd::{interactive_shell, list_devices, login, run_osquery},
};
mod api;
mod sheepctl_core;
fn main() {
let args: SheepctlArgs = argh::from_env();
match args.command {
SheepctlCommand::Device(device_command) => match device_command.command {
DeviceCommands::List(list_devices_command) => list_devices(list_devices_command),
},
SheepctlCommand::Login(login_command) => login(login_command),
SheepctlCommand::Shell(shell_command) => interactive_shell(shell_command),
SheepctlCommand::Query(query_command) => run_osquery(query_command),
}
}

View file

@ -1,74 +0,0 @@
use argh::FromArgs;
#[derive(FromArgs)]
/// Control your herd
pub struct SheepctlArgs {
#[argh(subcommand)]
pub command: SheepctlCommand,
}
#[derive(FromArgs, PartialEq, Debug)]
#[argh(subcommand)]
pub enum SheepctlCommand {
Login(LoginCommand),
Device(DeviceCommand),
Shell(ShellCommand),
Query(QueryCommand),
}
#[derive(FromArgs, PartialEq, Debug)]
/// Login to a homeserver
#[argh(subcommand, name = "login")]
pub struct LoginCommand {
#[argh(positional)]
/// homeserver
pub home: String,
#[argh(positional)]
/// username
pub username: String,
}
#[derive(FromArgs, PartialEq, Debug)]
#[argh(subcommand, name = "device")]
/// Commands for devices
pub struct DeviceCommand {
#[argh(subcommand)]
pub command: DeviceCommands,
}
#[derive(FromArgs, PartialEq, Debug)]
#[argh(subcommand)]
pub enum DeviceCommands {
List(ListDevicesCommand),
}
#[derive(FromArgs, PartialEq, Debug)]
/// List devices
#[argh(subcommand, name = "ls")]
pub struct ListDevicesCommand {
#[argh(switch)]
/// only show online devices
pub online: bool,
}
#[derive(FromArgs, PartialEq, Debug)]
#[argh(subcommand, name = "shell")]
/// Enter interactive shell
pub struct ShellCommand {
#[argh(positional)]
/// device ID
pub device: String,
}
#[derive(FromArgs, PartialEq, Debug)]
#[argh(subcommand, name = "query")]
/// Run an osquery
pub struct QueryCommand {
#[argh(positional)]
/// device ID
pub device: String,
#[argh(positional)]
/// query
pub query: String,
}

View file

@ -1,207 +0,0 @@
use std::{io::Write, path::PathBuf};
use owl::{Deserialize, Serialize};
use sheepd::{DeviceList, LoginParam, ShellResponse};
use super::args::{ListDevicesCommand, LoginCommand, QueryCommand, ShellCommand};
use crate::api::{DeviceEntry, QueryParam, ShellParam, domain};
/// Make an POST API call to `path` with `data` returning `Result<T>`
pub fn api_call<T: Serialize + for<'a> Deserialize<'a>, I: Serialize>(
server: &str,
path: &str,
data: I,
) -> crate::api::Result<T> {
let url = format!("{}/{path}", domain(server));
let mut res = ureq::post(url).send_json(data).unwrap();
let res: crate::api::Result<T> = res.body_mut().read_json().unwrap();
res
}
pub fn api_call_post_auth<T: Serialize + for<'a> Deserialize<'a>, I: Serialize>(
server: &str,
path: &str,
token: &str,
data: I,
) -> crate::api::Result<T> {
let url = format!("{}/{path}", domain(server));
let mut res = ureq::post(url)
.header("Authorization", format!("Bearer {token}"))
.send_json(data)
.unwrap();
let res: crate::api::Result<T> = res.body_mut().read_json().unwrap();
res
}
pub fn api_call_get<T: Serialize + for<'a> Deserialize<'a>>(
server: &str,
path: &str,
token: &str,
) -> crate::api::Result<T> {
let url = format!("{}/{path}", domain(server));
let mut res = ureq::get(url)
.header("Authorization", format!("Bearer {token}"))
.force_send_body()
.send_empty()
.unwrap();
let res: crate::api::Result<T> = res.body_mut().read_json().unwrap();
res
}
fn get_config_path() -> Option<PathBuf> {
directories::ProjectDirs::from("de", "Hydrar", "sheepd")
.map(|proj_dirs| proj_dirs.config_dir().join("config.toml"))
}
#[derive(Debug, Serialize, Deserialize)]
pub struct CtlConfig {
pub home: String,
pub token: String,
}
impl CtlConfig {
pub fn load() -> Option<Self> {
let c = std::fs::read_to_string(get_config_path()?).ok()?;
toml::from_str(&c).ok()
}
pub fn save(&self) {
let s = toml::to_string(self).unwrap();
let config = get_config_path().unwrap();
let _ = std::fs::create_dir_all(config.parent().unwrap());
std::fs::write(get_config_path().unwrap(), s).unwrap();
}
}
pub fn get_machine_api(home: &str, token: &str, id: &str) -> Option<DeviceEntry> {
let res = api_call_get::<DeviceEntry>(home, &format!("device/{id}"), token);
res.as_result().ok()
}
pub fn list_devices(arg: ListDevicesCommand) {
let conf = CtlConfig::load().unwrap();
if let Ok(devices) = api_call_get::<DeviceList>(&conf.home, "devices", &conf.token).as_result()
{
println!("Hosts:");
for d in devices.devices {
println!(
"- {} [{}]{}",
d.hostname,
d.id,
if d.online { " [ONLINE]" } else { "" }
);
}
}
}
pub fn interactive_shell(arg: ShellCommand) {
let conf = CtlConfig::load().unwrap();
let machine = arg.device;
if let Some(machine) = get_machine_api(&conf.home, &conf.token, &machine) {
if !machine.online {
println!("Device not online.");
std::process::exit(1);
}
let mut cwd = "/".to_string();
loop {
print!("{} [{}]: {cwd} $ ", machine.hostname, machine.id);
std::io::stdout().flush().unwrap();
let mut read = String::new();
std::io::stdin().read_line(&mut read).unwrap();
if read == "exit" {
break;
}
if read.starts_with("cd") {
let dir = read.trim_start_matches("cd ").trim_end_matches(";");
cwd = dir.to_string();
continue;
}
let res = api_call_post_auth::<ShellResponse, _>(
&conf.home,
&format!("device/{}/shell", machine.id),
&conf.token,
ShellParam {
cmd: read.clone(),
cwd: cwd.clone(),
},
);
if let Ok(resp) = res.as_result() {
println!("{} #{}\n{}", read, resp.status, resp.stdout);
if !resp.stderr.is_empty() {
println!("Stderr: {}", resp.stderr);
}
} else {
println!("Command execution failed");
}
}
} else {
println!("No device with ID {machine}");
}
}
pub fn run_osquery(args: QueryCommand) {
// TODO : sanity checks
let conf = CtlConfig::load().unwrap();
let machine = args.device;
if let Some(machine) = get_machine_api(&conf.home, &conf.token, &machine) {
if !machine.online {
println!("Device not online.");
std::process::exit(1);
}
let res = api_call_post_auth::<String, _>(
&conf.home,
&format!("device/{}/osquery", machine.id),
&conf.token,
QueryParam { query: args.query },
);
if let Ok(res) = res.as_result() {
println!("{res}");
} else {
println!("Error doing query");
}
} else {
println!("No device with ID {machine}");
}
}
pub fn login(arg: LoginCommand) {
if let Some(conf) = CtlConfig::load() {
println!("You are already logged in to {}", conf.home);
std::process::exit(1);
}
let password = inquire::prompt_secret("Password: ").unwrap();
// login request
if let Result::Ok(token) = api_call::<String, _>(
&arg.home,
"login",
LoginParam {
username: arg.username,
password: password,
},
)
.as_result()
{
// save token to config
CtlConfig {
home: arg.home,
token,
}
.save();
} else {
println!("Login failed");
}
}

View file

@ -1,2 +0,0 @@
pub mod args;
pub mod cmd;

View file

@ -8,7 +8,14 @@ use crate::{
}; };
use super::args::JoinCommand; use super::args::JoinCommand;
use crate::api::domain;
fn domain(host: &str) -> String {
if host.starts_with("http") {
return host.to_string();
} else {
format!("https://{host}")
}
}
/// Join a herd as client /// Join a herd as client
pub fn join(conf: JoinCommand) { pub fn join(conf: JoinCommand) {

View file

@ -1,11 +1,7 @@
use std::process::Stdio;
use owl::Serialize; use owl::Serialize;
use rumqttc::AsyncClient; use rumqttc::AsyncClient;
use sage::PersonaIdentity; use sage::PersonaIdentity;
use crate::api::{ClientAction, ServerResponse, ServerResponses, ShellResponse};
// Client MQTT // Client MQTT
pub async fn handle_mqtt(topic: String, data: Vec<u8>) { pub async fn handle_mqtt(topic: String, data: Vec<u8>) {
//println!("got real raw: {}", String::from_utf8_lossy(&data)); //println!("got real raw: {}", String::from_utf8_lossy(&data));
@ -15,59 +11,15 @@ pub async fn handle_mqtt(topic: String, data: Vec<u8>) {
); );
let pk = pk.sign_key().unwrap(); let pk = pk.sign_key().unwrap();
let payload = crate::IDENTITY.get().unwrap().decrypt(&data, &pk).unwrap(); let payload = crate::IDENTITY.get().unwrap().decrypt(&data, &pk).unwrap();
println!(
let action: ClientAction = serde_json::from_slice(&payload.payload).unwrap(); "got payload {}",
log::info!("Got action {action:?}"); String::from_utf8(payload.payload).unwrap()
);
match &action.action {
crate::api::ClientActions::OSQuery(query) => {
log::info!("Doing osquery with {query}");
let res = osquery(&query);
send_back(
crate::MQTT.get().unwrap(),
"respond",
ServerResponse::of(&action, ServerResponses::OSQuery(res)),
)
.await;
}
crate::api::ClientActions::Shell(cmd, cwd) => {
log::info!("Received shell command: {cmd} in {cwd}");
let res = std::process::Command::new("sh")
.arg("-c")
.arg(cmd)
.current_dir(cwd)
.output()
.unwrap();
send_back(
crate::MQTT.get().unwrap(),
"respond",
ServerResponse::of(
&action,
ServerResponses::Shell(ShellResponse {
stdout: String::from_utf8_lossy(&res.stdout).to_string(),
stderr: String::from_utf8_lossy(&res.stderr).to_string(),
status: res.status.code().unwrap(),
}),
),
)
.await;
}
}
}
pub fn osquery(query: &str) -> String {
let cmd = std::process::Command::new("osqueryi")
.arg("--csv")
.arg(query)
.stdout(Stdio::piped())
.output()
.unwrap();
String::from_utf8(cmd.stdout).unwrap()
} }
/// Send something back to the server on `topic` /// Send something back to the server on `topic`
pub async fn send_back<T: Serialize>(client: &AsyncClient, topic: &str, data: T) { pub async fn send_back<T: Serialize>(client: &AsyncClient, topic: &str, request: T) {
let data = serde_json::to_string(&data).unwrap(); let data = serde_json::to_string(&request).unwrap();
let pk = crate::AGENT.get().unwrap(); let pk = crate::AGENT.get().unwrap();
let pk = (pk.server_age.clone(), String::new()); let pk = (pk.server_age.clone(), String::new());
@ -83,7 +35,7 @@ pub async fn send_back<T: Serialize>(client: &AsyncClient, topic: &str, data: T)
.encrypt(data.as_bytes(), &rec); .encrypt(data.as_bytes(), &rec);
let topic = format!("{machine_id}/{topic}"); let topic = format!("{machine_id}/{topic}");
log::info!("Publish to {topic}"); log::info!("Publish to {machine_id}{topic}");
client client
.publish(topic, rumqttc::QoS::AtMostOnce, true, payload) .publish(topic, rumqttc::QoS::AtMostOnce, true, payload)
.await .await