mirror of
https://codeberg.org/icewind/tasmota-mqtt-client.git
synced 2026-06-03 10:14:10 +02:00
config export
This commit is contained in:
parent
04d8752b33
commit
549c533076
7 changed files with 721 additions and 9 deletions
120
src/download.rs
Normal file
120
src/download.rs
Normal file
|
|
@ -0,0 +1,120 @@
|
|||
use crate::error::DownloadError;
|
||||
use crate::mqtt::MqttHelper;
|
||||
use crate::Result;
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use md5::{Digest, Md5};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct SendDownloadPayload<'a> {
|
||||
password: &'a str,
|
||||
#[serde(rename = "type")]
|
||||
ty: u8,
|
||||
binary: u8,
|
||||
}
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
struct DownloadState {
|
||||
name: String,
|
||||
size: u32,
|
||||
ty: u8,
|
||||
id: u32,
|
||||
data: BytesMut,
|
||||
md5: [u8; 16],
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct DownloadedFile {
|
||||
pub name: String,
|
||||
pub data: Bytes,
|
||||
}
|
||||
|
||||
pub async fn download_config(
|
||||
mqtt: &MqttHelper,
|
||||
client: &str,
|
||||
password: &str,
|
||||
) -> Result<DownloadedFile> {
|
||||
let mut rx = mqtt
|
||||
.subscribe(format!("stat/{client}/FILEDOWNLOAD"))
|
||||
.await?;
|
||||
let topic = format!("cmnd/{client}/FILEDOWNLOAD");
|
||||
|
||||
mqtt.send(
|
||||
&topic,
|
||||
&SendDownloadPayload {
|
||||
password,
|
||||
ty: 2,
|
||||
binary: 1,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut state = DownloadState::default();
|
||||
|
||||
loop {
|
||||
let msg = rx.recv().await.unwrap();
|
||||
|
||||
if let Ok(body) = serde_json::from_slice::<Value>(msg.payload.as_ref()) {
|
||||
if let Some(status) = body.get("FileDownload") {
|
||||
match status.as_str() {
|
||||
Some("Started") => {
|
||||
continue;
|
||||
}
|
||||
Some("Aborted") => {
|
||||
return Err(DownloadError::DownloadAborted.into());
|
||||
}
|
||||
Some("Error 1") => {
|
||||
return Err(DownloadError::InvalidPassword.into());
|
||||
}
|
||||
Some("Error 2") => {
|
||||
return Err(DownloadError::BadChunkSize.into());
|
||||
}
|
||||
Some("Error 3") => {
|
||||
return Err(DownloadError::InvalidFileType.into());
|
||||
}
|
||||
Some("Done") => {
|
||||
break;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
if let Some(name) = body.get("File").and_then(|v| v.as_str()) {
|
||||
state.name = name.to_string();
|
||||
}
|
||||
if let Some(size) = body.get("Size").and_then(|v| v.as_u64()) {
|
||||
state.size = size as u32;
|
||||
}
|
||||
if let Some(id) = body.get("Size").and_then(|v| v.as_u64()) {
|
||||
state.id = id as u32;
|
||||
}
|
||||
if let Some(ty) = body.get("Type").and_then(|v| v.as_u64()) {
|
||||
state.ty = ty as u8;
|
||||
}
|
||||
if let Some(md5) = body.get("Md5").and_then(|v| v.as_str()) {
|
||||
hex::decode_to_slice(md5, &mut state.md5[..]).map_err(DownloadError::from)?;
|
||||
}
|
||||
} else {
|
||||
state.data.extend(msg.payload);
|
||||
}
|
||||
|
||||
mqtt.send_str(&topic, "?").await?;
|
||||
}
|
||||
|
||||
if state.data.len() != state.size as usize {
|
||||
return Err(DownloadError::MismatchedLength(state.size, state.data.len() as u32).into());
|
||||
}
|
||||
|
||||
let mut hasher = Md5::new();
|
||||
hasher.update(state.data.as_ref());
|
||||
let hash = hasher.finalize();
|
||||
|
||||
if hash != state.md5.into() {
|
||||
return Err(DownloadError::MismatchedHash(state.md5, hash.into()).into());
|
||||
}
|
||||
|
||||
Ok(DownloadedFile {
|
||||
name: state.name,
|
||||
data: state.data.freeze(),
|
||||
})
|
||||
}
|
||||
75
src/error.rs
Normal file
75
src/error.rs
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
use hex::FromHexError;
|
||||
use rumqttc::{ClientError, ConnectionError};
|
||||
use thiserror::Error;
|
||||
|
||||
pub type Result<T, E = Error> = std::result::Result<T, E>;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum Error {
|
||||
#[error("Error with mqtt transport: {0:#}")]
|
||||
Mqtt(MqttError),
|
||||
#[error("Topic {0} doesn't follow expected format")]
|
||||
MalformedTopic(String),
|
||||
#[error("Malformed json payload received: {0:#}")]
|
||||
JsonPayload(serde_json::Error),
|
||||
#[error(transparent)]
|
||||
Download(#[from] DownloadError),
|
||||
}
|
||||
|
||||
impl From<serde_json::Error> for Error {
|
||||
fn from(value: serde_json::Error) -> Self {
|
||||
Error::JsonPayload(value)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum MqttError {
|
||||
#[error("transparent")]
|
||||
Client(ClientError),
|
||||
#[error("transparent")]
|
||||
Connection(ConnectionError),
|
||||
}
|
||||
|
||||
impl From<MqttError> for Error {
|
||||
fn from(value: MqttError) -> Self {
|
||||
Error::Mqtt(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ClientError> for Error {
|
||||
fn from(value: ClientError) -> Self {
|
||||
MqttError::Client(value).into()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ConnectionError> for Error {
|
||||
fn from(value: ConnectionError) -> Self {
|
||||
MqttError::Connection(value).into()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum DownloadError {
|
||||
#[error("Aborted")]
|
||||
DownloadAborted,
|
||||
#[error("Invalid password for device")]
|
||||
InvalidPassword,
|
||||
#[error("Bad chunk size")]
|
||||
BadChunkSize,
|
||||
#[error("Invalid file type")]
|
||||
InvalidFileType,
|
||||
#[error("Received error code: {0}")]
|
||||
Unknown(u32),
|
||||
#[error("Mismatched payload length, expected {0} got {1}")]
|
||||
MismatchedLength(u32, u32),
|
||||
#[error("Received an invalid md5 hash")]
|
||||
InvalidHash,
|
||||
#[error("Received data doesn't match the expected md5 hash, expected {0:x?} got {1:x?}")]
|
||||
MismatchedHash([u8; 16], [u8; 16]),
|
||||
}
|
||||
|
||||
impl From<FromHexError> for DownloadError {
|
||||
fn from(_: FromHexError) -> Self {
|
||||
DownloadError::InvalidHash
|
||||
}
|
||||
}
|
||||
33
src/lib.rs
33
src/lib.rs
|
|
@ -1,14 +1,29 @@
|
|||
pub fn add(left: usize, right: usize) -> usize {
|
||||
left + right
|
||||
mod download;
|
||||
mod error;
|
||||
mod mqtt;
|
||||
|
||||
use crate::download::download_config;
|
||||
pub use crate::download::DownloadedFile;
|
||||
use crate::mqtt::MqttHelper;
|
||||
pub use error::{Error, Result};
|
||||
use rumqttc::MqttOptions;
|
||||
|
||||
pub struct TasmotaClient {
|
||||
mqtt: MqttHelper,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
impl TasmotaClient {
|
||||
pub fn connect(host: &str, port: u16, credentials: Option<(&str, &str)>) -> Result<Self> {
|
||||
let mut mqtt_opts = MqttOptions::new("tasmota-client", host, port);
|
||||
if let Some((username, password)) = credentials {
|
||||
mqtt_opts.set_credentials(username, password);
|
||||
}
|
||||
Ok(TasmotaClient {
|
||||
mqtt: MqttHelper::connect(mqtt_opts)?,
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_works() {
|
||||
let result = add(2, 2);
|
||||
assert_eq!(result, 4);
|
||||
pub async fn download_config(&self, client: &str, password: &str) -> Result<DownloadedFile> {
|
||||
download_config(&self.mqtt, client, password).await
|
||||
}
|
||||
}
|
||||
|
|
|
|||
89
src/mqtt.rs
Normal file
89
src/mqtt.rs
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
use crate::Result;
|
||||
use async_stream::try_stream;
|
||||
use rumqttc::{matches, AsyncClient, Event, EventLoop, MqttOptions, Packet, Publish, QoS};
|
||||
use serde::Serialize;
|
||||
use std::pin::pin;
|
||||
use std::sync::Arc;
|
||||
use tokio::spawn;
|
||||
use tokio::sync::mpsc::{channel, Receiver, Sender};
|
||||
use tokio::sync::Mutex;
|
||||
use tokio_stream::{Stream, StreamExt};
|
||||
use tracing::{debug, error};
|
||||
|
||||
pub struct MqttHelper {
|
||||
client: AsyncClient,
|
||||
listeners: Arc<Mutex<Vec<(String, Sender<Publish>)>>>,
|
||||
}
|
||||
|
||||
impl MqttHelper {
|
||||
pub fn connect(opts: MqttOptions) -> Result<Self> {
|
||||
let (client, event_loop) = AsyncClient::new(opts, 10);
|
||||
|
||||
let listeners = Arc::<Mutex<Vec<(String, Sender<_>)>>>::default();
|
||||
let senders = listeners.clone();
|
||||
|
||||
spawn(async move {
|
||||
let stream = event_loop_to_stream(event_loop);
|
||||
let messages = stream
|
||||
.filter_map(|event| match event {
|
||||
Ok(event) => {
|
||||
debug!(event = ?event, "processing event");
|
||||
Some(event)
|
||||
}
|
||||
Err(e) => {
|
||||
error!(error = ?e, "error while receiving mqtt message");
|
||||
None
|
||||
}
|
||||
})
|
||||
.filter_map(|event| match event {
|
||||
Event::Incoming(Packet::Publish(message)) => Some(message),
|
||||
_ => None,
|
||||
});
|
||||
|
||||
let mut messages = pin!(messages);
|
||||
|
||||
while let Some(message) = messages.next().await {
|
||||
let message: Publish = message;
|
||||
let mut listeners_ref = senders.lock().await;
|
||||
listeners_ref.retain(|(_, sender)| !sender.is_closed());
|
||||
for (filter, sender) in listeners_ref.iter() {
|
||||
if matches(&message.topic, filter.as_str()) {
|
||||
let _ = sender.send(message.clone()).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Self { client, listeners })
|
||||
}
|
||||
|
||||
pub async fn send<B: Serialize>(&self, topic: &str, body: &B) -> Result<()> {
|
||||
self.client
|
||||
.publish(topic, QoS::AtLeastOnce, false, serde_json::to_vec(body)?)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn send_str(&self, topic: &str, body: &str) -> Result<()> {
|
||||
self.client
|
||||
.publish(topic, QoS::AtLeastOnce, false, body)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn subscribe(&self, topic: String) -> Result<Receiver<Publish>> {
|
||||
self.client.subscribe(&topic, QoS::AtLeastOnce).await?;
|
||||
let (tx, rx) = channel(10);
|
||||
self.listeners.lock().await.push((topic, tx));
|
||||
Ok(rx)
|
||||
}
|
||||
}
|
||||
|
||||
fn event_loop_to_stream(mut event_loop: EventLoop) -> impl Stream<Item = Result<Event>> {
|
||||
try_stream! {
|
||||
loop {
|
||||
let event = event_loop.poll().await?;
|
||||
yield event;
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue