Route and forward for http(s)

This commit is contained in:
2024-01-14 02:01:56 +08:00
parent 5de1d13907
commit c991d0b54a
7 changed files with 732 additions and 22 deletions

View File

@@ -2,10 +2,12 @@ mod config;
mod proxies;
mod sideload;
use poem::{listener::TcpListener, Route, Server};
use poem::{listener::TcpListener, Endpoint, EndpointExt, Route, Server};
use poem_openapi::OpenApiService;
use tracing::{error, info, Level};
use crate::proxies::route;
#[tokio::main]
async fn main() -> Result<(), std::io::Error> {
// Setting up logging
@@ -35,10 +37,19 @@ async fn main() -> Result<(), std::io::Error> {
};
// Proxies
let proxies_server = Server::new(TcpListener::bind(
config::C
.read()
.unwrap()
.get_string("listen.proxies")
.unwrap_or("0.0.0.0:80".to_string()),
))
.run(route::handle.data(instance));
// Sideload
let sideload = OpenApiService::new(sideload::SideloadApi, "Sideload API", "1.0")
.server("http://localhost:3000/cgi");
let sideload_ui = sideload.swagger_ui();
let sideload_server = Server::new(TcpListener::bind(
config::C
@@ -47,11 +58,13 @@ async fn main() -> Result<(), std::io::Error> {
.get_string("listen.sideload")
.unwrap_or("0.0.0.0:81".to_string()),
))
.run(Route::new().nest("/cgi", sideload));
.run(
Route::new()
.nest("/cgi", sideload)
.nest("/swagger", sideload_ui),
);
tokio::try_join!(sideload_server)?;
tokio::try_join!(proxies_server, sideload_server)?;
Ok(())
}

View File

@@ -4,23 +4,70 @@ use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Region {
id: String,
locations: Vec<Location>,
pub id: String,
pub locations: Vec<Location>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Location {
id: String,
hosts: Vec<String>,
paths: Vec<String>,
headers: Option<Vec<HashMap<String, String>>>,
query_strings: Option<Vec<HashMap<String, String>>>,
destinations: Vec<Destination>,
pub id: String,
pub hosts: Vec<String>,
pub paths: Vec<String>,
pub headers: Option<HashMap<String, String>>,
pub queries: Option<Vec<String>>,
pub destinations: Vec<Destination>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Destination {
id: String,
uri: String,
timeout: Option<u32>,
pub id: String,
pub uri: String,
pub timeout: Option<u32>,
pub weight: Option<u32>,
}
pub enum DestinationType {
Hypertext,
StaticFiles,
Unknown,
}
impl Destination {
pub fn get_type(&self) -> DestinationType {
match self.get_protocol() {
"http" | "https" => DestinationType::Hypertext,
"file" | "files" => DestinationType::StaticFiles,
_ => DestinationType::Unknown,
}
}
pub fn get_protocol(&self) -> &str {
self.uri.as_str().splitn(2, "://").collect::<Vec<_>>()[0]
}
pub fn get_queries(&self) -> &str {
self.uri.as_str().splitn(2, "?").collect::<Vec<_>>()[1]
}
pub fn get_host(&self) -> &str {
(self.uri.as_str().splitn(2, "://").collect::<Vec<_>>()[1])
.splitn(2, "?")
.collect::<Vec<_>>()[0]
}
pub fn get_hypertext_uri(&self) -> Result<String, ()> {
match self.get_protocol() {
"http" => Ok("http://".to_string() + self.get_host()),
"https" => Ok("https://".to_string() + self.get_host()),
_ => Err(()),
}
}
pub fn get_websocket_uri(&self) -> Result<String, ()> {
let url = self.uri.as_str().splitn(2, "://").collect::<Vec<_>>()[1];
match self.get_protocol() {
"http" | "https" => Ok(url.replace("http", "ws")),
_ => Err(()),
}
}
}

View File

@@ -1,9 +1,13 @@
use poem::http::{HeaderMap, Uri};
use regex::Regex;
use serde::{Deserialize, Serialize};
use wildmatch::WildMatch;
use self::config::Region;
use self::config::{Location, Region};
pub mod config;
pub mod loader;
pub mod route;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Instance {
@@ -14,4 +18,44 @@ impl Instance {
pub fn new() -> Instance {
Instance { regions: vec![] }
}
pub fn filter(&self, uri: &Uri, headers: &HeaderMap) -> Option<&Location> {
self.regions.iter().find_map(|region| {
region.locations.iter().find(|location| {
let mut hosts = location.hosts.iter();
if !hosts.any(|item| {
WildMatch::new(item.as_str()).matches(uri.host().unwrap_or("localhost"))
}) {
return false;
}
let mut paths = location.paths.iter();
if !paths.any(|item| {
uri.path().starts_with(item)
|| Regex::new(item.as_str()).unwrap().is_match(uri.path())
}) {
return false;
}
if let Some(val) = location.headers.clone() {
match !val.keys().all(|item| {
headers.get(item).unwrap()
== location.headers.clone().unwrap().get(item).unwrap()
}) {
true => return false,
false => (),
}
};
if let Some(val) = location.queries.clone() {
let queries: Vec<&str> = uri.query().unwrap_or("").split('&').collect();
if !val.iter().all(|item| queries.contains(&item.as_str())) {
return false;
}
}
true
})
})
}
}

155
src/proxies/route.rs Normal file
View File

@@ -0,0 +1,155 @@
use futures_util::{SinkExt, StreamExt};
use poem::{
handler,
http::{HeaderMap, StatusCode, Uri},
web::{websocket::WebSocket, Data},
Body, Error, FromRequest, IntoResponse, Request, Response, Result,
};
use rand::seq::SliceRandom;
use reqwest::Method;
use lazy_static::lazy_static;
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio_tungstenite::connect_async;
use crate::proxies::config::{Destination, DestinationType};
lazy_static! {
pub static ref CLIENT: reqwest::Client = reqwest::Client::new();
}
#[handler]
pub async fn handle(
app: Data<&super::Instance>,
req: &Request,
uri: &Uri,
headers: &HeaderMap,
method: Method,
body: Body,
) -> Result<impl IntoResponse, Error> {
let location = match app.filter(uri, headers) {
Some(val) => val,
None => {
return Err(Error::from_string(
"There are no region be able to respone this request.",
StatusCode::NOT_FOUND,
))
}
};
let destination = location
.destinations
.choose_weighted(&mut rand::thread_rng(), |item| item.weight.unwrap_or(1))
.unwrap();
async fn forward(
end: &Destination,
req: &Request,
ori: &Uri,
headers: &HeaderMap,
method: Method,
body: Body,
) -> Result<impl IntoResponse, Error> {
// Handle websocket
if let Ok(ws) = WebSocket::from_request_without_body(req).await {
// Get uri
let Ok(uri) = end.get_websocket_uri() else {
return Err(Error::from_string(
"Proxy endpoint not configured to support websockets!",
StatusCode::NOT_IMPLEMENTED,
));
};
// Build request
let mut ws_req = http::Request::builder().uri(&uri);
for (key, value) in headers.iter() {
ws_req = ws_req.header(key, value);
}
// Start the websocket connection
return Ok(ws
.on_upgrade(move |socket| async move {
let (mut clientsink, mut clientstream) = socket.split();
// Start connection to server
let (serversocket, _) = connect_async(ws_req.body(()).unwrap()).await.unwrap();
let (mut serversink, mut serverstream) = serversocket.split();
let client_live = Arc::new(RwLock::new(true));
let server_live = client_live.clone();
tokio::spawn(async move {
while let Some(Ok(msg)) = clientstream.next().await {
match serversink.send(msg.into()).await {
Err(_) => break,
_ => {}
};
if !*client_live.read().await {
break;
};
}
*client_live.write().await = false;
});
// Relay server messages to the client
tokio::spawn(async move {
while let Some(Ok(msg)) = serverstream.next().await {
match clientsink.send(msg.into()).await {
Err(_) => break,
_ => {}
};
if !*server_live.read().await {
break;
};
}
*server_live.write().await = false;
});
})
.into_response());
}
// Handle normal web request
match end.get_type() {
DestinationType::Hypertext => {
let Ok(uri) = end.get_hypertext_uri() else {
return Err(Error::from_string(
"Proxy endpoint not configured to support web requests!",
StatusCode::NOT_IMPLEMENTED,
));
};
let res = CLIENT
.request(method, uri + ori.path() + ori.query().unwrap_or(""))
.headers(headers.clone())
.body(body.into_bytes().await.unwrap())
.send()
.await;
match res {
Ok(result) => {
let mut res = Response::default();
res.extensions().clone_from(&result.extensions());
result.headers().iter().for_each(|(key, val)| {
res.headers_mut().insert(key, val.to_owned());
});
res.set_status(result.status());
res.set_version(result.version());
res.set_body(result.bytes().await.unwrap());
Ok(res)
}
Err(error) => Err(Error::from_string(
error.to_string(),
error.status().unwrap_or(StatusCode::BAD_GATEWAY),
)),
}
}
_ => Err(Error::from_status(StatusCode::NOT_IMPLEMENTED)),
}
}
forward(destination, req, uri, headers, method, body).await
}