config export

This commit is contained in:
Robin Appelman 2024-01-26 21:01:50 +01:00
commit 549c533076
7 changed files with 721 additions and 9 deletions

120
src/download.rs Normal file
View file

@ -0,0 +1,120 @@
use crate::error::DownloadError;
use crate::mqtt::MqttHelper;
use crate::Result;
use bytes::{Bytes, BytesMut};
use md5::{Digest, Md5};
use serde::{Deserialize, Serialize};
use serde_json::Value;
#[derive(Serialize)]
struct SendDownloadPayload<'a> {
password: &'a str,
#[serde(rename = "type")]
ty: u8,
binary: u8,
}
#[derive(Default, Debug)]
struct DownloadState {
name: String,
size: u32,
ty: u8,
id: u32,
data: BytesMut,
md5: [u8; 16],
}
#[derive(Debug)]
pub struct DownloadedFile {
pub name: String,
pub data: Bytes,
}
pub async fn download_config(
mqtt: &MqttHelper,
client: &str,
password: &str,
) -> Result<DownloadedFile> {
let mut rx = mqtt
.subscribe(format!("stat/{client}/FILEDOWNLOAD"))
.await?;
let topic = format!("cmnd/{client}/FILEDOWNLOAD");
mqtt.send(
&topic,
&SendDownloadPayload {
password,
ty: 2,
binary: 1,
},
)
.await?;
let mut state = DownloadState::default();
loop {
let msg = rx.recv().await.unwrap();
if let Ok(body) = serde_json::from_slice::<Value>(msg.payload.as_ref()) {
if let Some(status) = body.get("FileDownload") {
match status.as_str() {
Some("Started") => {
continue;
}
Some("Aborted") => {
return Err(DownloadError::DownloadAborted.into());
}
Some("Error 1") => {
return Err(DownloadError::InvalidPassword.into());
}
Some("Error 2") => {
return Err(DownloadError::BadChunkSize.into());
}
Some("Error 3") => {
return Err(DownloadError::InvalidFileType.into());
}
Some("Done") => {
break;
}
_ => {}
}
}
if let Some(name) = body.get("File").and_then(|v| v.as_str()) {
state.name = name.to_string();
}
if let Some(size) = body.get("Size").and_then(|v| v.as_u64()) {
state.size = size as u32;
}
if let Some(id) = body.get("Size").and_then(|v| v.as_u64()) {
state.id = id as u32;
}
if let Some(ty) = body.get("Type").and_then(|v| v.as_u64()) {
state.ty = ty as u8;
}
if let Some(md5) = body.get("Md5").and_then(|v| v.as_str()) {
hex::decode_to_slice(md5, &mut state.md5[..]).map_err(DownloadError::from)?;
}
} else {
state.data.extend(msg.payload);
}
mqtt.send_str(&topic, "?").await?;
}
if state.data.len() != state.size as usize {
return Err(DownloadError::MismatchedLength(state.size, state.data.len() as u32).into());
}
let mut hasher = Md5::new();
hasher.update(state.data.as_ref());
let hash = hasher.finalize();
if hash != state.md5.into() {
return Err(DownloadError::MismatchedHash(state.md5, hash.into()).into());
}
Ok(DownloadedFile {
name: state.name,
data: state.data.freeze(),
})
}

75
src/error.rs Normal file
View file

@ -0,0 +1,75 @@
use hex::FromHexError;
use rumqttc::{ClientError, ConnectionError};
use thiserror::Error;
pub type Result<T, E = Error> = std::result::Result<T, E>;
#[derive(Debug, Error)]
pub enum Error {
#[error("Error with mqtt transport: {0:#}")]
Mqtt(MqttError),
#[error("Topic {0} doesn't follow expected format")]
MalformedTopic(String),
#[error("Malformed json payload received: {0:#}")]
JsonPayload(serde_json::Error),
#[error(transparent)]
Download(#[from] DownloadError),
}
impl From<serde_json::Error> for Error {
fn from(value: serde_json::Error) -> Self {
Error::JsonPayload(value)
}
}
#[derive(Debug, Error)]
pub enum MqttError {
#[error("transparent")]
Client(ClientError),
#[error("transparent")]
Connection(ConnectionError),
}
impl From<MqttError> for Error {
fn from(value: MqttError) -> Self {
Error::Mqtt(value)
}
}
impl From<ClientError> for Error {
fn from(value: ClientError) -> Self {
MqttError::Client(value).into()
}
}
impl From<ConnectionError> for Error {
fn from(value: ConnectionError) -> Self {
MqttError::Connection(value).into()
}
}
#[derive(Debug, Error)]
pub enum DownloadError {
#[error("Aborted")]
DownloadAborted,
#[error("Invalid password for device")]
InvalidPassword,
#[error("Bad chunk size")]
BadChunkSize,
#[error("Invalid file type")]
InvalidFileType,
#[error("Received error code: {0}")]
Unknown(u32),
#[error("Mismatched payload length, expected {0} got {1}")]
MismatchedLength(u32, u32),
#[error("Received an invalid md5 hash")]
InvalidHash,
#[error("Received data doesn't match the expected md5 hash, expected {0:x?} got {1:x?}")]
MismatchedHash([u8; 16], [u8; 16]),
}
impl From<FromHexError> for DownloadError {
fn from(_: FromHexError) -> Self {
DownloadError::InvalidHash
}
}

View file

@ -1,14 +1,29 @@
pub fn add(left: usize, right: usize) -> usize {
left + right
mod download;
mod error;
mod mqtt;
use crate::download::download_config;
pub use crate::download::DownloadedFile;
use crate::mqtt::MqttHelper;
pub use error::{Error, Result};
use rumqttc::MqttOptions;
pub struct TasmotaClient {
mqtt: MqttHelper,
}
#[cfg(test)]
mod tests {
use super::*;
impl TasmotaClient {
pub fn connect(host: &str, port: u16, credentials: Option<(&str, &str)>) -> Result<Self> {
let mut mqtt_opts = MqttOptions::new("tasmota-client", host, port);
if let Some((username, password)) = credentials {
mqtt_opts.set_credentials(username, password);
}
Ok(TasmotaClient {
mqtt: MqttHelper::connect(mqtt_opts)?,
})
}
#[test]
fn it_works() {
let result = add(2, 2);
assert_eq!(result, 4);
pub async fn download_config(&self, client: &str, password: &str) -> Result<DownloadedFile> {
download_config(&self.mqtt, client, password).await
}
}

89
src/mqtt.rs Normal file
View file

@ -0,0 +1,89 @@
use crate::Result;
use async_stream::try_stream;
use rumqttc::{matches, AsyncClient, Event, EventLoop, MqttOptions, Packet, Publish, QoS};
use serde::Serialize;
use std::pin::pin;
use std::sync::Arc;
use tokio::spawn;
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::sync::Mutex;
use tokio_stream::{Stream, StreamExt};
use tracing::{debug, error};
pub struct MqttHelper {
client: AsyncClient,
listeners: Arc<Mutex<Vec<(String, Sender<Publish>)>>>,
}
impl MqttHelper {
pub fn connect(opts: MqttOptions) -> Result<Self> {
let (client, event_loop) = AsyncClient::new(opts, 10);
let listeners = Arc::<Mutex<Vec<(String, Sender<_>)>>>::default();
let senders = listeners.clone();
spawn(async move {
let stream = event_loop_to_stream(event_loop);
let messages = stream
.filter_map(|event| match event {
Ok(event) => {
debug!(event = ?event, "processing event");
Some(event)
}
Err(e) => {
error!(error = ?e, "error while receiving mqtt message");
None
}
})
.filter_map(|event| match event {
Event::Incoming(Packet::Publish(message)) => Some(message),
_ => None,
});
let mut messages = pin!(messages);
while let Some(message) = messages.next().await {
let message: Publish = message;
let mut listeners_ref = senders.lock().await;
listeners_ref.retain(|(_, sender)| !sender.is_closed());
for (filter, sender) in listeners_ref.iter() {
if matches(&message.topic, filter.as_str()) {
let _ = sender.send(message.clone()).await;
}
}
}
});
Ok(Self { client, listeners })
}
pub async fn send<B: Serialize>(&self, topic: &str, body: &B) -> Result<()> {
self.client
.publish(topic, QoS::AtLeastOnce, false, serde_json::to_vec(body)?)
.await?;
Ok(())
}
pub async fn send_str(&self, topic: &str, body: &str) -> Result<()> {
self.client
.publish(topic, QoS::AtLeastOnce, false, body)
.await?;
Ok(())
}
pub async fn subscribe(&self, topic: String) -> Result<Receiver<Publish>> {
self.client.subscribe(&topic, QoS::AtLeastOnce).await?;
let (tx, rx) = channel(10);
self.listeners.lock().await.push((topic, tx));
Ok(rx)
}
}
fn event_loop_to_stream(mut event_loop: EventLoop) -> impl Stream<Item = Result<Event>> {
try_stream! {
loop {
let event = event_loop.poll().await?;
yield event;
}
}
}