proper error when a client disconnects mid-download

This commit is contained in:
Robin Appelman 2024-06-25 22:48:20 +02:00
commit 3395e052bd
3 changed files with 20 additions and 3 deletions

View file

@ -1,9 +1,11 @@
use crate::error::DownloadError; use crate::error::DownloadError;
use crate::mqtt::MqttHelper; use crate::mqtt::MqttHelper;
use crate::Result; use crate::{DeviceUpdate, Result};
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use md5::{Digest, Md5}; use md5::{Digest, Md5};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio::select;
use tokio::sync::broadcast::Receiver;
use tracing::debug; use tracing::debug;
#[derive(Serialize)] #[derive(Serialize)]
@ -47,6 +49,7 @@ pub async fn download_config(
mqtt: &MqttHelper, mqtt: &MqttHelper,
client: &str, client: &str,
password: &str, password: &str,
mut device_update: Receiver<DeviceUpdate>,
) -> Result<DownloadedFile> { ) -> Result<DownloadedFile> {
let mut rx = mqtt let mut rx = mqtt
.subscribe(format!("stat/{client}/FILEDOWNLOAD")) .subscribe(format!("stat/{client}/FILEDOWNLOAD"))
@ -66,7 +69,19 @@ pub async fn download_config(
let mut state = DownloadState::default(); let mut state = DownloadState::default();
loop { loop {
let msg = rx.recv().await.unwrap(); let msg = select! {
msg = rx.recv() => {
msg.unwrap()
}
discovery = device_update.recv() => {
if let Ok(DeviceUpdate::Removed(device)) = discovery {
if device.as_str() == client {
return Err(DownloadError::Gone.into());
}
}
continue;
}
};
if let Ok(response) = serde_json::from_slice::<DownloadResponse>(msg.payload.as_ref()) { if let Ok(response) = serde_json::from_slice::<DownloadResponse>(msg.payload.as_ref()) {
debug!(message = ?response, "processing download status message"); debug!(message = ?response, "processing download status message");

View file

@ -72,6 +72,8 @@ pub enum DownloadError {
InvalidHash, InvalidHash,
#[error("Received data doesn't match the expected md5 hash, expected {0:x?} got {1:x?}")] #[error("Received data doesn't match the expected md5 hash, expected {0:x?} got {1:x?}")]
MismatchedHash([u8; 16], [u8; 16]), MismatchedHash([u8; 16], [u8; 16]),
#[error("Device has disconnected during the download")]
Gone,
} }
impl From<FromHexError> for DownloadError { impl From<FromHexError> for DownloadError {

View file

@ -101,7 +101,7 @@ impl TasmotaClient {
/// The password is the mqtt password used by the device, which might be different from the mqtt password used by this client /// The password is the mqtt password used by the device, which might be different from the mqtt password used by this client
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub async fn download_config(&self, client: &str, password: &str) -> Result<DownloadedFile> { pub async fn download_config(&self, client: &str, password: &str) -> Result<DownloadedFile> {
download_config(&self.mqtt, client, password).await download_config(&self.mqtt, client, password, self.device_update.subscribe()).await
} }
/// Get the list of known devices at this point in time /// Get the list of known devices at this point in time