add digitalocean support

This commit is contained in:
Robin Appelman 2022-06-18 17:49:13 +02:00
commit 01e48520c6
8 changed files with 488 additions and 70 deletions

View file

@ -8,13 +8,20 @@ config_mode = "6v6" # 6v6 or 9v9, defaults to "6v6"
name = "Spire" # server name. optional, defaults to "Spire" name = "Spire" # server name. optional, defaults to "Spire"
tv_name = "SpireTV" # stv name. optional, defaults to "SpireTV" tv_name = "SpireTV" # stv name. optional, defaults to "SpireTV"
image = "spiretf/docker-spire-server" # docker image for the tf2 server. optional, defaults to "spiretf/docker-spire-server" image = "spiretf/docker-spire-server" # docker image for the tf2 server. optional, defaults to "spiretf/docker-spire-server"
ssh_key = "ssh-rsa AAAA..." # ssh key to add to the server. optional ssh_keys = ["ssh-rsa AAAA..."] # ssh key to add to the server. optional
manage_existing = false # whether to detect and manage server that are already running, optional, disabled by default manage_existing = false # whether to detect and manage server that are already running, optional, disabled by default
# Specify either the vultr settings or the digitalocean settings to pick the cloud provider
[vultr] [vultr]
api_key = "xxx" api_key = "xxx"
region = "ams" # see https://api.vultr.com/v2/regions for a list of regions region = "ams" # see https://api.vultr.com/v2/regions for a list of regions
plan = "vc2-1c-2gb" # optional, defaults to vc2-1c-2gb (2GB, $10/month) see https://api.vultr.com/v2/plans for a lis of plan plan = "vc2-1c-2gb" # optional, defaults to vc2-1c-2gb (2GB, $10/month) see https://api.vultr.com/v2/plans for a list of plans
[digitalocean]
api_key = "xxx"
region = "ams3" # see https://api.digitalocean.com/v2/apps/regions for a list of regions
plan = "s-1vcpu-2gb" # optional, defaults to s-1vcpu-2gb (2GB, $10/month) see https://api.digitalocean.com/v2/sizes for a list of plans
[dyndns] # optional dyndns2 details [dyndns] # optional dyndns2 details
update_url = "https://update.eurodyndns.org/update/" # Update url for dyndns2 update_url = "https://update.eurodyndns.org/update/" # Update url for dyndns2

355
src/cloud/digitalocean.rs Normal file
View file

@ -0,0 +1,355 @@
use crate::cloud::{Cloud, CloudError, Created, NetworkError, ResponseError, Result, Server};
use crate::CreatedAuth;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use futures_util::stream::FuturesUnordered;
use futures_util::TryStreamExt;
use petname::petname;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::net::{IpAddr, Ipv4Addr};
use std::sync::Arc;
use std::time::Duration;
use thrussh_keys::key::KeyPair;
use thrussh_keys::PublicKeyBase64;
use tokio::time::sleep;
pub struct DigitalOcean {
region: String,
plan: String,
token: String,
client: Client,
}
impl DigitalOcean {
pub fn new(token: String, region: String, plan: String) -> Self {
DigitalOcean {
token,
region,
plan,
client: Client::default(),
}
}
}
#[async_trait]
impl Cloud for DigitalOcean {
async fn list(&self) -> Result<Vec<Server>> {
let response = self
.client
.get("https://api.digitalocean.com/v2/droplets")
.bearer_auth(&self.token)
.send()
.await
.map_err(NetworkError::from)?;
CloudError::from_status_code(response.status())?;
let response: DigitalOceanListResponse =
response.json().await.map_err(ResponseError::from)?;
Ok(response
.droplets
.into_iter()
.filter(|instance| instance.tags.iter().any(|tag| tag == "spire"))
.map(Server::from)
.collect())
}
async fn spawn(&self, ssh_keys: &[String]) -> Result<Created> {
let startup_key = Arc::new(KeyPair::generate_ed25519().unwrap());
let startup_key_id = self
.create_key(
"Dispenser Deploy Key",
&format!(
"{} {} {}",
startup_key.name(),
startup_key.public_key_base64(),
"dispenser-deploy"
),
)
.await?;
let mut key_ids = ssh_keys
.iter()
.map(|key| self.get_ssh_key_id(key))
.collect::<FuturesUnordered<_>>()
.try_collect::<Vec<_>>()
.await?;
key_ids.push(startup_key_id);
let response_res = self
.client
.post("https://api.digitalocean.com/v2/droplets")
.bearer_auth(&self.token)
.json(&DigitalOceanCreateParams {
region: self.region.as_str(),
size: self.plan.as_str(),
tags: &["spire"],
name: petname(2, "-"),
image: "docker-20-04",
ssh_keys: key_ids,
ipv6: true,
})
.send()
.await
.map_err(NetworkError::from);
self.remove_key(startup_key_id).await?;
// remove the deploy key, even if the spawn request failed
let response = response_res?;
CloudError::from_status_code(response.status())?;
if response.status().is_success() {
let response: DigitalOceanCreateResponse =
response.json().await.map_err(ResponseError::from)?;
Ok((response.droplet, startup_key).into())
} else {
Err(ResponseError::Other(response.text().await.map_err(NetworkError::from)?).into())
}
}
async fn kill(&self, id: &str) -> Result<()> {
let response = self
.client
.delete(format!("https://api.digitalocean.com/v2/droplets/{}", id))
.bearer_auth(&self.token)
.send()
.await
.map_err(NetworkError::from)?;
CloudError::from_status_code(response.status())
}
async fn wait_for_ip(&self, id: &str) -> Result<Server> {
let instance = loop {
let instance = self.get_instance(id).await?;
let ip = instance.networks.v4().next();
if ip.is_some() {
break instance;
} else {
sleep(Duration::from_millis(500)).await;
}
};
Ok(instance.into())
}
}
impl DigitalOcean {
async fn get_instance(&self, id: &str) -> Result<DigitalOceanInstanceResponse> {
let response = self
.client
.get(format!("https://api.digitalocean.com/v2/droplets/{}", id))
.bearer_auth(&self.token)
.send()
.await
.map_err(NetworkError::from)?;
CloudError::from_status_code(response.status())?;
let response: DigitalOceanGetResponse =
response.json().await.map_err(ResponseError::from)?;
Ok(response.droplet)
}
async fn get_ssh_key_id(&self, ssh_key: &str) -> Result<u32> {
let response = self
.client
.get("https://api.digitalocean.com/v2/account/keys/")
.bearer_auth(&self.token)
.send()
.await
.map_err(NetworkError::from)?;
CloudError::from_status_code(response.status())?;
if !response.status().is_success() {
return Err(
ResponseError::Other(response.text().await.map_err(NetworkError::from)?).into(),
);
}
let response: DigitalOceanSshListResponse =
response.json().await.map_err(ResponseError::from)?;
if let Some(key) = response
.ssh_keys
.into_iter()
.find(|key| key.public_key == ssh_key)
{
Ok(key.id)
} else {
self.create_key("Dispenser Key", ssh_key).await
}
}
async fn create_key(&self, name: &str, ssh_key: &str) -> Result<u32> {
let response = self
.client
.post("https://api.digitalocean.com/v2/account/keys/")
.bearer_auth(&self.token)
.json(&DigitalOceanCreateSshKeyParams {
name,
public_key: ssh_key,
})
.send()
.await
.map_err(NetworkError::from)?;
CloudError::from_status_code(response.status())?;
let response = response.error_for_status().map_err(NetworkError)?;
let response: DigitalOceanSshCreateResponse =
response.json().await.map_err(ResponseError::from)?;
Ok(response.ssh_key.id)
}
async fn remove_key(&self, key_id: u32) -> Result<()> {
let response = self
.client
.delete(format!(
"https://api.digitalocean.com/v2/account/keys/{}",
key_id
))
.bearer_auth(&self.token)
.send()
.await
.map_err(NetworkError::from)?;
CloudError::from_status_code(response.status())?;
Ok(())
}
}
#[derive(Serialize)]
struct DigitalOceanCreateParams<'a> {
name: String,
region: &'a str,
size: &'a str,
tags: &'a [&'a str],
image: &'a str,
ssh_keys: Vec<u32>,
ipv6: bool,
}
#[derive(Debug, Deserialize)]
struct DigitalOceanListResponse {
droplets: Vec<DigitalOceanInstanceResponse>,
}
#[derive(Debug, Deserialize)]
struct DigitalOceanGetResponse {
droplet: DigitalOceanInstanceResponse,
}
#[derive(Debug, Deserialize)]
struct DigitalOceanCreateResponse {
droplet: DigitalOceanCreatedInstanceResponse,
}
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct DigitalOceanInstanceResponse {
id: u32,
memory: u64,
networks: DigitalOceanNetworks,
vcpus: u16,
created_at: DateTime<Utc>,
tags: Vec<String>,
}
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct DigitalOceanNetworks {
v4: Vec<DigitalOceanNetwork>,
v6: Vec<DigitalOceanNetwork>,
}
impl DigitalOceanNetworks {
fn v4(&self) -> impl Iterator<Item = IpAddr> + '_ {
self.v4
.iter()
.filter(|net| net.ty == DigitalOceanNetworkType::Public)
.map(|net| net.ip_address)
}
fn v6(&self) -> impl Iterator<Item = IpAddr> + '_ {
self.v6
.iter()
.filter(|net| net.ty == DigitalOceanNetworkType::Public)
.map(|net| net.ip_address)
}
}
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct DigitalOceanNetwork {
ip_address: IpAddr,
gateway: String,
#[serde(rename = "type")]
ty: DigitalOceanNetworkType,
}
#[allow(dead_code)]
#[derive(Debug, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
enum DigitalOceanNetworkType {
Private,
Public,
}
#[derive(Debug, Deserialize)]
struct DigitalOceanCreatedInstanceResponse {
id: u32,
}
impl From<DigitalOceanInstanceResponse> for Server {
fn from(instance: DigitalOceanInstanceResponse) -> Self {
Server {
id: instance.id.to_string(),
created: instance.created_at,
ip: instance
.networks
.v4()
.next()
.unwrap_or(IpAddr::V4(Ipv4Addr::UNSPECIFIED)),
ip_v6: instance.networks.v6().next(),
}
}
}
impl From<(DigitalOceanCreatedInstanceResponse, Arc<KeyPair>)> for Created {
fn from((instance, key): (DigitalOceanCreatedInstanceResponse, Arc<KeyPair>)) -> Self {
Created {
id: instance.id.to_string(),
auth: CreatedAuth::Ssh(key),
}
}
}
#[allow(dead_code)]
#[derive(Serialize)]
struct DigitalOceanCreateSshKeyParams<'a> {
name: &'a str,
public_key: &'a str,
}
#[derive(Debug, Deserialize)]
struct DigitalOceanSshCreateResponse {
ssh_key: DigitalOceanSshCreateKey,
}
#[derive(Debug, Deserialize)]
struct DigitalOceanSshCreateKey {
id: u32,
}
#[derive(Debug, Deserialize)]
struct DigitalOceanSshListResponse {
ssh_keys: Vec<DigitalOceanSshKey>,
}
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct DigitalOceanSshKey {
id: u32,
fingerprint: String,
public_key: String,
name: String,
}

View file

@ -1,10 +1,14 @@
use std::fmt::{Display, Formatter};
use std::net::IpAddr; use std::net::IpAddr;
use std::sync::Arc;
use async_trait::async_trait; use async_trait::async_trait;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use reqwest::StatusCode; use reqwest::StatusCode;
use thiserror::Error; use thiserror::Error;
use thrussh_keys::key::KeyPair;
pub mod digitalocean;
pub mod vultr; pub mod vultr;
#[derive(Debug, Error)] #[derive(Debug, Error)]
@ -63,13 +67,11 @@ pub trait Cloud: Send + Sync + 'static {
/// List all running servers on this cloud /// List all running servers on this cloud
async fn list(&self) -> Result<Vec<Server>>; async fn list(&self) -> Result<Vec<Server>>;
/// Create a new server with the given parameter /// Create a new server with the given parameter
async fn spawn(&self, ssh_key_id: Option<&str>) -> Result<Created>; async fn spawn(&self, ssh_keys: &[String]) -> Result<Created>;
/// Destroy a given server /// Destroy a given server
async fn kill(&self, id: &str) -> Result<()>; async fn kill(&self, id: &str) -> Result<()>;
/// Wait until the server has an ip /// Wait until the server has an ip
async fn wait_for_ip(&self, id: &str) -> Result<Server>; async fn wait_for_ip(&self, id: &str) -> Result<Server>;
/// Get the id for the given ssh key
async fn get_ssh_key_id(&self, key: &str) -> Result<String>;
} }
#[derive(Debug)] #[derive(Debug)]
@ -83,5 +85,20 @@ pub struct Server {
#[derive(Debug)] #[derive(Debug)]
pub struct Created { pub struct Created {
pub id: String, pub id: String,
pub password: String, pub auth: CreatedAuth,
}
#[derive(Debug)]
pub enum CreatedAuth {
Password(String),
Ssh(Arc<KeyPair>),
}
impl Display for CreatedAuth {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
CreatedAuth::Password(s) => s.fmt(f),
CreatedAuth::Ssh(_) => write!(f, "public key only"),
}
}
} }

View file

@ -1,6 +1,10 @@
use crate::cloud::{Cloud, CloudError, Created, NetworkError, ResponseError, Result, Server}; use crate::cloud::{
Cloud, CloudError, Created, CreatedAuth, NetworkError, ResponseError, Result, Server,
};
use async_trait::async_trait; use async_trait::async_trait;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use futures_util::stream::FuturesUnordered;
use futures_util::TryStreamExt;
use petname::petname; use petname::petname;
use reqwest::Client; use reqwest::Client;
use serde::{Deserialize, Deserializer, Serialize}; use serde::{Deserialize, Deserializer, Serialize};
@ -48,12 +52,14 @@ impl Cloud for Vultr {
.collect()) .collect())
} }
async fn spawn(&self, ssh_key_id: Option<&str>) -> Result<Created> { async fn spawn(&self, ssh_keys: &[String]) -> Result<Created> {
let key_ids = if let Some(key) = ssh_key_id { let key_ids = ssh_keys
vec![key] .iter()
} else { .map(|key| self.get_ssh_key_id(key))
vec![] .collect::<FuturesUnordered<_>>()
}; .try_collect::<Vec<String>>()
.await?;
let response = self let response = self
.client .client
.post("https://api.vultr.com/v2/instances") .post("https://api.vultr.com/v2/instances")
@ -64,7 +70,7 @@ impl Cloud for Vultr {
tag: "spire", tag: "spire",
label: petname(2, "-"), label: petname(2, "-"),
app_id: self.get_app_id("docker").await?, app_id: self.get_app_id("docker").await?,
sshkey_id: &key_ids, sshkey_id: key_ids,
enable_ipv6: true, enable_ipv6: true,
}) })
.send() .send()
@ -103,6 +109,40 @@ impl Cloud for Vultr {
}; };
Ok(instance.into()) Ok(instance.into())
} }
}
impl Vultr {
async fn get_app_id(&self, short_name: &str) -> Result<u16> {
let response = self
.client
.get("https://api.vultr.com/v2/applications")
.send()
.await
.map_err(NetworkError::from)?;
let response: VultrApplicationsResponse =
response.json().await.map_err(ResponseError::from)?;
Ok(response
.applications
.into_iter()
.find_map(|application| (application.short_name == short_name).then(|| application.id))
.ok_or_else(|| {
ResponseError::Other(format!("Application \"{}\" not found", short_name))
})?)
}
async fn get_instance(&self, id: &str) -> Result<VultrInstanceResponse> {
let response = self
.client
.get(format!("https://api.vultr.com/v2/instances/{}", id))
.bearer_auth(&self.token)
.send()
.await
.map_err(NetworkError::from)?;
CloudError::from_status_code(response.status())?;
let response: VultrGetResponse = response.json().await.map_err(ResponseError::from)?;
Ok(response.instance)
}
async fn get_ssh_key_id(&self, ssh_key: &str) -> Result<String> { async fn get_ssh_key_id(&self, ssh_key: &str) -> Result<String> {
let response = self let response = self
@ -148,40 +188,6 @@ impl Cloud for Vultr {
} }
} }
impl Vultr {
async fn get_app_id(&self, short_name: &str) -> Result<u16> {
let response = self
.client
.get("https://api.vultr.com/v2/applications")
.send()
.await
.map_err(NetworkError::from)?;
let response: VultrApplicationsResponse =
response.json().await.map_err(ResponseError::from)?;
Ok(response
.applications
.into_iter()
.find_map(|application| (application.short_name == short_name).then(|| application.id))
.ok_or_else(|| {
ResponseError::Other(format!("Application \"{}\" not found", short_name))
})?)
}
async fn get_instance(&self, id: &str) -> Result<VultrInstanceResponse> {
let response = self
.client
.get(format!("https://api.vultr.com/v2/instances/{}", id))
.bearer_auth(&self.token)
.send()
.await
.map_err(NetworkError::from)?;
CloudError::from_status_code(response.status())?;
let response: VultrGetResponse = response.json().await.map_err(ResponseError::from)?;
Ok(response.instance)
}
}
#[derive(Serialize)] #[derive(Serialize)]
struct VultrCreateParams<'a> { struct VultrCreateParams<'a> {
region: &'a str, region: &'a str,
@ -189,7 +195,7 @@ struct VultrCreateParams<'a> {
tag: &'a str, tag: &'a str,
label: String, label: String,
app_id: u16, app_id: u16,
sshkey_id: &'a [&'a str], sshkey_id: Vec<String>,
enable_ipv6: bool, enable_ipv6: bool,
} }
@ -253,7 +259,7 @@ impl From<VultrCreatedInstanceResponse> for Created {
fn from(instance: VultrCreatedInstanceResponse) -> Self { fn from(instance: VultrCreatedInstanceResponse) -> Self {
Created { Created {
id: instance.id, id: instance.id,
password: instance.default_password, auth: CreatedAuth::Password(instance.default_password),
} }
} }
} }

View file

@ -1,3 +1,4 @@
use crate::cloud::digitalocean::DigitalOcean;
use crate::cloud::vultr::Vultr; use crate::cloud::vultr::Vultr;
use crate::cloud::Cloud; use crate::cloud::Cloud;
use camino::Utf8PathBuf; use camino::Utf8PathBuf;
@ -15,6 +16,8 @@ pub enum ConfigError {
Toml(#[from] TomlError), Toml(#[from] TomlError),
#[error("No cloud provider configured")] #[error("No cloud provider configured")]
NoProvider, NoProvider,
#[error("Multiple cloud providers configured")]
MultipleProviders,
} }
/// Intentionally opaque error /// Intentionally opaque error
@ -31,6 +34,7 @@ impl From<toml::de::Error> for TomlError {
#[derive(Deserialize, Debug)] #[derive(Deserialize, Debug)]
pub struct Config { pub struct Config {
pub vultr: Option<VultrConfig>, pub vultr: Option<VultrConfig>,
pub digital_ocean: Option<DigitalOceanConfig>,
pub server: ServerConfig, pub server: ServerConfig,
pub dyndns: Option<DynDnsConfig>, pub dyndns: Option<DynDnsConfig>,
pub schedule: ScheduleConfig, pub schedule: ScheduleConfig,
@ -43,12 +47,20 @@ impl Config {
} }
pub fn cloud(&self) -> Result<Arc<dyn Cloud>, ConfigError> { pub fn cloud(&self) -> Result<Arc<dyn Cloud>, ConfigError> {
if let Some(vultr) = &self.vultr { if self.vultr.is_some() && self.digital_ocean.is_some() {
Err(ConfigError::NoProvider)
} else if let Some(vultr) = &self.vultr {
Ok(Arc::new(Vultr::new( Ok(Arc::new(Vultr::new(
vultr.api_key.clone(), vultr.api_key.clone(),
vultr.region.clone(), vultr.region.clone(),
vultr.plan.clone(), vultr.plan.clone(),
))) )))
} else if let Some(digital_ocean) = &self.digital_ocean {
Ok(Arc::new(DigitalOcean::new(
digital_ocean.api_key.clone(),
digital_ocean.region.clone(),
digital_ocean.plan.clone(),
)))
} else { } else {
Err(ConfigError::NoProvider) Err(ConfigError::NoProvider)
} }
@ -71,7 +83,8 @@ pub struct ServerConfig {
pub name: String, pub name: String,
#[serde(default = "server_default_tv_name")] #[serde(default = "server_default_tv_name")]
pub tv_name: String, pub tv_name: String,
pub ssh_key: Option<String>, #[serde(default)]
pub ssh_keys: Vec<String>,
#[serde(default)] #[serde(default)]
pub manage_existing: bool, pub manage_existing: bool,
} }
@ -110,6 +123,20 @@ fn vultr_default_plan() -> String {
String::from("vc2-1c-2gb") String::from("vc2-1c-2gb")
} }
#[derive(Deserialize, Debug)]
pub struct DigitalOceanConfig {
pub api_key: String,
/// See https://api.vultr.com/v2/regions for a list of plans
pub region: String,
/// See https://api.vultr.com/v2/plans for a list of plans
#[serde(default = "digital_ocean_default_plan")]
pub plan: String,
}
fn digital_ocean_default_plan() -> String {
String::from("s-2vcpu-2gb")
}
#[derive(Deserialize, Debug)] #[derive(Deserialize, Debug)]
pub struct DynDnsConfig { pub struct DynDnsConfig {
pub update_url: String, pub update_url: String,

View file

@ -17,6 +17,8 @@ pub enum DynDnsError {
NotYourDomain, NotYourDomain,
#[error("Invalid hostname")] #[error("Invalid hostname")]
InvalidHostname, InvalidHostname,
#[error("Rate limited")]
Abuse,
} }
impl DynDnsError { impl DynDnsError {
@ -69,6 +71,7 @@ impl DynDnsClient {
"!yours" => Err(DynDnsError::NotYourDomain), "!yours" => Err(DynDnsError::NotYourDomain),
"nochg" | "good" => Ok(()), "nochg" | "good" => Ok(()),
"notfqdn" | "nohost" | "numhost" => Err(DynDnsError::InvalidHostname), "notfqdn" | "nohost" | "numhost" => Err(DynDnsError::InvalidHostname),
"abuse" => Err(DynDnsError::Abuse),
_ => Err(DynDnsError::InvalidResponse(text)), _ => Err(DynDnsError::InvalidResponse(text)),
} }
} }

View file

@ -1,4 +1,6 @@
use crate::cloud::{Cloud, CloudError, Server}; extern crate core;
use crate::cloud::{Cloud, CloudError, CreatedAuth, Server};
use crate::config::{Config, ConfigError, ServerConfig}; use crate::config::{Config, ConfigError, ServerConfig};
use crate::dns::{DynDnsClient, DynDnsError}; use crate::dns::{DynDnsClient, DynDnsError};
use crate::rcon::Rcon; use crate::rcon::Rcon;
@ -296,18 +298,12 @@ async fn start(cloud: &dyn Cloud, config: &Config) -> Result<Server, Error> {
return Err(Error::AlreadyRunning(first)); return Err(Error::AlreadyRunning(first));
} }
let ssh_key = if let Some(key) = config.server.ssh_key.as_ref() { let created = cloud.spawn(&config.server.ssh_keys).await?;
Some(cloud.get_ssh_key_id(key).await?)
} else {
None
};
let created = cloud.spawn(ssh_key.as_deref()).await?;
let server = cloud.wait_for_ip(&created.id).await?; let server = cloud.wait_for_ip(&created.id).await?;
println!("Server is booting"); println!("Server is booting");
println!(" IP: {}", server.ip); println!(" IP: {}", server.ip);
println!(" Root Password: {}", created.password); println!(" Root Password: {}", created.auth);
let connect_host = if let Some(dns_config) = config.dyndns.as_ref() { let connect_host = if let Some(dns_config) = config.dyndns.as_ref() {
let dns = DynDnsClient::new( let dns = DynDnsClient::new(
@ -325,7 +321,7 @@ async fn start(cloud: &dyn Cloud, config: &Config) -> Result<Server, Error> {
format!("{}", server.ip) format!("{}", server.ip)
}; };
let mut ssh = connect_ssh(server.ip, &created.password).await?; let mut ssh = connect_ssh(server.ip, &created.auth).await?;
setup(&mut ssh, &config.server).await?; setup(&mut ssh, &config.server).await?;
ssh.close().await?; ssh.close().await?;
@ -338,14 +334,14 @@ async fn start(cloud: &dyn Cloud, config: &Config) -> Result<Server, Error> {
Ok(server) Ok(server)
} }
async fn connect_ssh(ip: IpAddr, password: &str) -> Result<SshSession, Error> { async fn connect_ssh(ip: IpAddr, auth: &CreatedAuth) -> Result<SshSession, Error> {
let mut tries = 0; let mut tries = 0;
loop { loop {
tries += 1; tries += 1;
sleep(Duration::from_secs(2)).await; sleep(Duration::from_secs(2)).await;
match SshSession::open(ip, password).await { match SshSession::open(ip, &auth).await {
Ok(ssh) => { Ok(ssh) => {
return Ok(ssh); return Ok(ssh);
} }
@ -357,7 +353,7 @@ async fn connect_ssh(ip: IpAddr, password: &str) -> Result<SshSession, Error> {
return Err(e.into()); return Err(e.into());
} }
Err(_) => { Err(_) => {
error!(tries = tries, "Failed to connect to ssh, giving up"); error!(tries = tries, "Failed to connect to ssh");
} }
} }
} }

View file

@ -1,3 +1,4 @@
use crate::CreatedAuth;
use futures_util::future::{self}; use futures_util::future::{self};
use std::convert::identity; use std::convert::identity;
use std::fmt::{Debug, Formatter}; use std::fmt::{Debug, Formatter};
@ -71,12 +72,12 @@ impl Debug for SshSession {
} }
impl SshSession { impl SshSession {
#[instrument(skip(password))] #[instrument(skip(auth))]
pub async fn open(ip: IpAddr, password: &str) -> Result<Self, SshError> { pub async fn open(ip: IpAddr, auth: &CreatedAuth) -> Result<Self, SshError> {
timeout(Duration::from_secs(5 * 60), async move { timeout(Duration::from_secs(5 * 60), async move {
loop { loop {
sleep(Duration::from_secs(1)).await; sleep(Duration::from_secs(1)).await;
match SshSession::open_impl(ip, password).await { match SshSession::open_impl(ip, auth).await {
Ok(ssh) => return Ok(ssh), Ok(ssh) => return Ok(ssh),
Err(SshError::ConnectionTimeout) => {} Err(SshError::ConnectionTimeout) => {}
Err(e) => return Err(e), Err(e) => return Err(e),
@ -88,13 +89,19 @@ impl SshSession {
.and_then(identity) .and_then(identity)
} }
async fn open_impl(ip: IpAddr, password: &str) -> Result<Self, SshError> { async fn open_impl(ip: IpAddr, auth: &CreatedAuth) -> Result<Self, SshError> {
let config = client::Config::default(); let config = client::Config::default();
let config = Arc::new(config); let config = Arc::new(config);
let sh = Client {}; let sh = Client {};
let mut handle = client::connect(config, (ip, 22), sh).await?; let mut handle = client::connect(config, (ip, 22), sh).await?;
if handle.authenticate_password("root", password).await? { let result = match auth {
CreatedAuth::Password(password) => {
handle.authenticate_password("root", password).await?
}
CreatedAuth::Ssh(key) => handle.authenticate_publickey("root", key.clone()).await?,
};
if result {
Ok(SshSession { ip, handle }) Ok(SshSession { ip, handle })
} else { } else {
Err(SshError::Unauthorized) Err(SshError::Unauthorized)