🔥 使用 Rust 重构 #5

Closed
LittleSheep wants to merge 9 commits from refactor/rust into master
11 changed files with 418 additions and 94 deletions
Showing only changes of commit 91ecf9d7bb - Show all commits

37
Cargo.lock generated
View File

@ -852,6 +852,16 @@ version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
[[package]]
name = "mime_guess"
version = "2.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4192263c238a5f0d0c6bfd21f336a313a4ce1c450542449ca191bb657b4642ef"
dependencies = [
"mime",
"unicase",
]
[[package]] [[package]]
name = "minimal-lexical" name = "minimal-lexical"
version = "0.2.1" version = "0.2.1"
@ -1179,9 +1189,11 @@ dependencies = [
"headers", "headers",
"http", "http",
"http-body-util", "http-body-util",
"httpdate",
"hyper", "hyper",
"hyper-util", "hyper-util",
"mime", "mime",
"mime_guess",
"multer", "multer",
"nix", "nix",
"parking_lot", "parking_lot",
@ -1306,6 +1318,19 @@ dependencies = [
"unicode-ident", "unicode-ident",
] ]
[[package]]
name = "queryst"
version = "3.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c1cbeb75ac695daf201ca2d66d9c684f873b135f28af4f2c79952478cab3b9d9"
dependencies = [
"lazy_static",
"percent-encoding",
"regex",
"serde",
"serde_json",
]
[[package]] [[package]]
name = "quick-xml" name = "quick-xml"
version = "0.30.0" version = "0.30.0"
@ -1450,8 +1475,11 @@ dependencies = [
"http", "http",
"hyper-util", "hyper-util",
"lazy_static", "lazy_static",
"mime",
"percent-encoding",
"poem", "poem",
"poem-openapi", "poem-openapi",
"queryst",
"rand", "rand",
"regex", "regex",
"reqwest", "reqwest",
@ -2126,6 +2154,15 @@ dependencies = [
"version_check", "version_check",
] ]
[[package]]
name = "unicase"
version = "2.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7d2d4dafb69621809a81864c9c1b864479e1235c0dd4e199924b9742439ed89"
dependencies = [
"version_check",
]
[[package]] [[package]]
name = "unicode-bidi" name = "unicode-bidi"
version = "0.3.14" version = "0.3.14"

View File

@ -11,8 +11,11 @@ futures-util = "0.3.30"
http = "1.0.0" http = "1.0.0"
hyper-util = { version = "0.1.2", features = ["full"] } hyper-util = { version = "0.1.2", features = ["full"] }
lazy_static = "1.4.0" lazy_static = "1.4.0"
poem = { version = "2.0.0", features = ["tokio-metrics", "websocket"] } mime = "0.3.17"
percent-encoding = "2.3.1"
poem = { version = "2.0.0", features = ["tokio-metrics", "websocket", "static-files"] }
poem-openapi = { version = "4.0.0", features = ["swagger-ui"] } poem-openapi = { version = "4.0.0", features = ["swagger-ui"] }
queryst = "3.0.0"
rand = "0.8.5" rand = "0.8.5"
regex = "1.10.2" regex = "1.10.2"
reqwest = { git = "https://github.com/seanmonstar/reqwest.git", branch = "hyper-v1", version = "0.11.23" } reqwest = { git = "https://github.com/seanmonstar/reqwest.git", branch = "hyper-v1", version = "0.11.23" }

12
regions/index.html Normal file
View File

@ -0,0 +1,12 @@
<!doctype html>
<html>
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>Hello, World!</title>
</head>
<body>
<p>Hello, there!</p>
<p>Here's the roadsign benchmarking test data!</p>
</body>
</html>

View File

@ -5,5 +5,5 @@ id = "root"
hosts = ["localhost"] hosts = ["localhost"]
paths = ["/"] paths = ["/"]
[[locations.destinations]] [[locations.destinations]]
id = "echo" id = "static"
uri = "https://postman-echo.com/get" uri = "files://regions/index.html"

1
regions/kokodayo.txt Normal file
View File

@ -0,0 +1 @@
Ko Ko Da Yo~

View File

@ -2,7 +2,7 @@ mod config;
mod proxies; mod proxies;
mod sideload; mod sideload;
use poem::{listener::TcpListener, Endpoint, EndpointExt, Route, Server}; use poem::{listener::TcpListener, EndpointExt, Route, Server};
use poem_openapi::OpenApiService; use poem_openapi::OpenApiService;
use tracing::{error, info, Level}; use tracing::{error, info, Level};

52
src/proxies/browser.rs Normal file
View File

@ -0,0 +1,52 @@
use std::fmt::Write;
pub struct DirectoryTemplate<'a> {
pub path: &'a str,
pub files: Vec<FileRef>,
}
impl<'a> DirectoryTemplate<'a> {
pub fn render(&self) -> String {
let mut s = format!(
r#"
<html>
<head>
<title>Index of {}</title>
</head>
<body>
<h1>Index of /{}</h1>
<ul>"#,
self.path, self.path
);
for file in &self.files {
if file.is_dir {
let _ = write!(
s,
r#"<li><a href="{}">{}/</a></li>"#,
file.url, file.filename
);
} else {
let _ = write!(
s,
r#"<li><a href="{}">{}</a></li>"#,
file.url, file.filename
);
}
}
s.push_str(
r#"</ul>
</body>
</html>"#,
);
s
}
}
pub struct FileRef {
pub url: String,
pub filename: String,
pub is_dir: bool,
}

View File

@ -1,6 +1,10 @@
use std::collections::HashMap; use std::collections::HashMap;
use queryst::parse;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::json;
use super::responder::StaticResponderConfig;
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Region { pub struct Region {
@ -15,6 +19,7 @@ pub struct Location {
pub paths: Vec<String>, pub paths: Vec<String>,
pub headers: Option<HashMap<String, String>>, pub headers: Option<HashMap<String, String>>,
pub queries: Option<Vec<String>>, pub queries: Option<Vec<String>>,
pub methods: Option<Vec<String>>,
pub destinations: Vec<Destination>, pub destinations: Vec<Destination>,
} }
@ -46,15 +51,35 @@ impl Destination {
} }
pub fn get_queries(&self) -> &str { pub fn get_queries(&self) -> &str {
self.uri.as_str().splitn(2, "?").collect::<Vec<_>>()[1] self.uri
.as_str()
.splitn(2, '?')
.collect::<Vec<_>>()
.get(1)
.unwrap_or(&"")
} }
pub fn get_host(&self) -> &str { pub fn get_host(&self) -> &str {
(self.uri.as_str().splitn(2, "://").collect::<Vec<_>>()[1]) (self
.splitn(2, "?") .uri
.as_str()
.splitn(2, "://")
.collect::<Vec<_>>()
.get(1)
.unwrap_or(&""))
.splitn(2, '?')
.collect::<Vec<_>>()[0] .collect::<Vec<_>>()[0]
} }
pub fn get_websocket_uri(&self) -> Result<String, ()> {
let parts = self.uri.as_str().splitn(2, "://").collect::<Vec<_>>();
let url = parts.get(1).unwrap_or(&"");
match self.get_protocol() {
"http" | "https" => Ok(url.replace("http", "ws")),
_ => Err(()),
}
}
pub fn get_hypertext_uri(&self) -> Result<String, ()> { pub fn get_hypertext_uri(&self) -> Result<String, ()> {
match self.get_protocol() { match self.get_protocol() {
"http" => Ok("http://".to_string() + self.get_host()), "http" => Ok("http://".to_string() + self.get_host()),
@ -63,10 +88,32 @@ impl Destination {
} }
} }
pub fn get_websocket_uri(&self) -> Result<String, ()> { pub fn get_static_config(&self) -> Result<StaticResponderConfig, ()> {
let url = self.uri.as_str().splitn(2, "://").collect::<Vec<_>>()[1];
match self.get_protocol() { match self.get_protocol() {
"http" | "https" => Ok(url.replace("http", "ws")), "file" | "files" => {
let queries = parse(self.get_queries()).unwrap_or(json!({}));
Ok(StaticResponderConfig {
uri: self.get_host().to_string(),
utf8: queries
.get("utf8")
.and_then(|val| val.as_bool())
.unwrap_or(false),
with_slash: queries
.get("slash")
.and_then(|val| val.as_bool())
.unwrap_or(false),
browse: queries
.get("browse")
.and_then(|val| val.as_bool())
.unwrap_or(false),
index: queries
.get("index")
.and_then(|val| val.as_str().map(str::to_string)),
fallback: queries
.get("fallback")
.and_then(|val| val.as_str().map(str::to_string)),
})
}
_ => Err(()), _ => Err(()),
} }
} }

View File

@ -1,3 +1,4 @@
use http::Method;
use poem::http::{HeaderMap, Uri}; use poem::http::{HeaderMap, Uri};
use regex::Regex; use regex::Regex;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -5,8 +6,10 @@ use wildmatch::WildMatch;
use self::config::{Location, Region}; use self::config::{Location, Region};
pub mod browser;
pub mod config; pub mod config;
pub mod loader; pub mod loader;
pub mod responder;
pub mod route; pub mod route;
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@ -19,7 +22,7 @@ impl Instance {
Instance { regions: vec![] } Instance { regions: vec![] }
} }
pub fn filter(&self, uri: &Uri, headers: &HeaderMap) -> Option<&Location> { pub fn filter(&self, uri: &Uri, method: Method, headers: &HeaderMap) -> Option<&Location> {
self.regions.iter().find_map(|region| { self.regions.iter().find_map(|region| {
region.locations.iter().find(|location| { region.locations.iter().find(|location| {
let mut hosts = location.hosts.iter(); let mut hosts = location.hosts.iter();
@ -37,6 +40,12 @@ impl Instance {
return false; return false;
} }
if let Some(val) = location.methods.clone() {
if !val.iter().any(|item| *item == method.to_string()) {
return false;
}
}
if let Some(val) = location.headers.clone() { if let Some(val) = location.headers.clone() {
match !val.keys().all(|item| { match !val.keys().all(|item| {
headers.get(item).unwrap() headers.get(item).unwrap()

221
src/proxies/responder.rs Normal file
View File

@ -0,0 +1,221 @@
use futures_util::{SinkExt, StreamExt};
use http::{header, request::Builder, HeaderMap, Method, StatusCode, Uri};
use lazy_static::lazy_static;
use poem::{
web::{websocket::WebSocket, StaticFileRequest},
Body, Error, FromRequest, IntoResponse, Request, Response,
};
use std::{
ffi::OsStr,
path::{Path, PathBuf},
sync::Arc,
};
use tokio::sync::RwLock;
use tokio_tungstenite::connect_async;
use super::browser::{DirectoryTemplate, FileRef};
lazy_static! {
pub static ref CLIENT: reqwest::Client = reqwest::Client::new();
}
pub async fn repond_websocket(req: Builder, ws: WebSocket) -> Response {
ws.on_upgrade(move |socket| async move {
let (mut clientsink, mut clientstream) = socket.split();
// Start connection to server
let (serversocket, _) = connect_async(req.body(()).unwrap()).await.unwrap();
let (mut serversink, mut serverstream) = serversocket.split();
let client_live = Arc::new(RwLock::new(true));
let server_live = client_live.clone();
tokio::spawn(async move {
while let Some(Ok(msg)) = clientstream.next().await {
if (serversink.send(msg.into()).await).is_err() {
break;
};
if !*client_live.read().await {
break;
};
}
*client_live.write().await = false;
});
// Relay server messages to the client
tokio::spawn(async move {
while let Some(Ok(msg)) = serverstream.next().await {
if (clientsink.send(msg.into()).await).is_err() {
break;
};
if !*server_live.read().await {
break;
};
}
*server_live.write().await = false;
});
})
.into_response()
}
pub async fn respond_hypertext(
uri: String,
ori: &Uri,
method: Method,
body: Body,
headers: &HeaderMap,
) -> Result<Response, Error> {
let res = CLIENT
.request(method, uri + ori.path() + ori.query().unwrap_or(""))
.headers(headers.clone())
.body(body.into_bytes().await.unwrap())
.send()
.await;
match res {
Ok(result) => {
let mut res = Response::default();
res.extensions().clone_from(&result.extensions());
result.headers().iter().for_each(|(key, val)| {
res.headers_mut().insert(key, val.to_owned());
});
res.set_status(result.status());
res.set_version(result.version());
res.set_body(result.bytes().await.unwrap());
Ok(res)
}
Err(error) => Err(Error::from_string(
error.to_string(),
error.status().unwrap_or(StatusCode::BAD_GATEWAY),
)),
}
}
pub struct StaticResponderConfig {
pub uri: String,
pub utf8: bool,
pub with_slash: bool,
pub browse: bool,
pub index: Option<String>,
pub fallback: Option<String>,
}
pub async fn respond_static(
cfg: StaticResponderConfig,
method: Method,
req: &Request,
) -> Result<Response, Error> {
if method != Method::GET {
return Err(Error::from_string(
"This destination only support GET request.",
StatusCode::METHOD_NOT_ALLOWED,
));
}
let path = req
.uri()
.path()
.trim_start_matches('/')
.trim_end_matches('/');
let path = percent_encoding::percent_decode_str(path)
.decode_utf8()
.map_err(|_| Error::from_status(StatusCode::NOT_FOUND))?;
let base_path = cfg.uri.parse::<PathBuf>().unwrap();
let mut file_path = base_path.clone();
for p in Path::new(&*path) {
if p == OsStr::new(".") {
continue;
} else if p == OsStr::new("..") {
file_path.pop();
} else {
file_path.push(p);
}
}
if !file_path.starts_with(cfg.uri) {
return Err(Error::from_status(StatusCode::FORBIDDEN));
}
if !file_path.exists() {
if let Some(file) = cfg.fallback {
let fallback_path = base_path.join(file);
if fallback_path.is_file() {
return Ok(StaticFileRequest::from_request_without_body(req)
.await?
.create_response(&fallback_path, cfg.utf8)?
.into_response());
}
}
return Err(Error::from_status(StatusCode::NOT_FOUND));
}
if file_path.is_file() {
Ok(StaticFileRequest::from_request_without_body(req)
.await?
.create_response(&file_path, cfg.utf8)?
.into_response())
} else {
if cfg.with_slash
&& !req.original_uri().path().ends_with('/')
&& (cfg.index.is_some() || cfg.browse)
{
let redirect_to = format!("{}/", req.original_uri().path());
return Ok(Response::builder()
.status(StatusCode::FOUND)
.header(header::LOCATION, redirect_to)
.finish());
}
if let Some(index_file) = &cfg.index {
let index_path = file_path.join(index_file);
if index_path.is_file() {
return Ok(StaticFileRequest::from_request_without_body(req)
.await?
.create_response(&index_path, cfg.utf8)?
.into_response());
}
}
if cfg.browse {
let read_dir = file_path
.read_dir()
.map_err(|_| Error::from_status(StatusCode::FORBIDDEN))?;
let mut template = DirectoryTemplate {
path: &path,
files: Vec::new(),
};
for res in read_dir {
let entry = res.map_err(|_| Error::from_status(StatusCode::FORBIDDEN))?;
if let Some(filename) = entry.file_name().to_str() {
let mut base_url = req.original_uri().path().to_string();
if !base_url.ends_with('/') {
base_url.push('/');
}
let filename_url = percent_encoding::percent_encode(
filename.as_bytes(),
percent_encoding::NON_ALPHANUMERIC,
);
template.files.push(FileRef {
url: format!("{base_url}{filename_url}"),
filename: filename.to_string(),
is_dir: entry.path().is_dir(),
});
}
}
let html = template.render();
Ok(Response::builder()
.header(header::CONTENT_TYPE, mime::TEXT_HTML_UTF_8.as_ref())
.body(Body::from_string(html)))
} else {
Err(Error::from_status(StatusCode::NOT_FOUND))
}
}
}

View File

@ -1,4 +1,3 @@
use futures_util::{SinkExt, StreamExt};
use poem::{ use poem::{
handler, handler,
http::{HeaderMap, StatusCode, Uri}, http::{HeaderMap, StatusCode, Uri},
@ -8,16 +7,10 @@ use poem::{
use rand::seq::SliceRandom; use rand::seq::SliceRandom;
use reqwest::Method; use reqwest::Method;
use lazy_static::lazy_static; use crate::proxies::{
use std::sync::Arc; config::{Destination, DestinationType},
use tokio::sync::RwLock; responder,
use tokio_tungstenite::connect_async; };
use crate::proxies::config::{Destination, DestinationType};
lazy_static! {
pub static ref CLIENT: reqwest::Client = reqwest::Client::new();
}
#[handler] #[handler]
pub async fn handle( pub async fn handle(
@ -28,7 +21,7 @@ pub async fn handle(
method: Method, method: Method,
body: Body, body: Body,
) -> Result<impl IntoResponse, Error> { ) -> Result<impl IntoResponse, Error> {
let location = match app.filter(uri, headers) { let location = match app.filter(uri, method.clone(), headers) {
Some(val) => val, Some(val) => val,
None => { None => {
return Err(Error::from_string( return Err(Error::from_string(
@ -50,13 +43,13 @@ pub async fn handle(
headers: &HeaderMap, headers: &HeaderMap,
method: Method, method: Method,
body: Body, body: Body,
) -> Result<impl IntoResponse, Error> { ) -> Result<Response, Error> {
// Handle websocket // Handle websocket
if let Ok(ws) = WebSocket::from_request_without_body(req).await { if let Ok(ws) = WebSocket::from_request_without_body(req).await {
// Get uri // Get uri
let Ok(uri) = end.get_websocket_uri() else { let Ok(uri) = end.get_websocket_uri() else {
return Err(Error::from_string( return Err(Error::from_string(
"Proxy endpoint not configured to support websockets!", "This destination was not support websockets.",
StatusCode::NOT_IMPLEMENTED, StatusCode::NOT_IMPLEMENTED,
)); ));
}; };
@ -68,47 +61,7 @@ pub async fn handle(
} }
// Start the websocket connection // Start the websocket connection
return Ok(ws return Ok(responder::repond_websocket(ws_req, ws).await);
.on_upgrade(move |socket| async move {
let (mut clientsink, mut clientstream) = socket.split();
// Start connection to server
let (serversocket, _) = connect_async(ws_req.body(()).unwrap()).await.unwrap();
let (mut serversink, mut serverstream) = serversocket.split();
let client_live = Arc::new(RwLock::new(true));
let server_live = client_live.clone();
tokio::spawn(async move {
while let Some(Ok(msg)) = clientstream.next().await {
match serversink.send(msg.into()).await {
Err(_) => break,
_ => {}
};
if !*client_live.read().await {
break;
};
}
*client_live.write().await = false;
});
// Relay server messages to the client
tokio::spawn(async move {
while let Some(Ok(msg)) = serverstream.next().await {
match clientsink.send(msg.into()).await {
Err(_) => break,
_ => {}
};
if !*server_live.read().await {
break;
};
}
*server_live.write().await = false;
});
})
.into_response());
} }
// Handle normal web request // Handle normal web request
@ -116,40 +69,29 @@ pub async fn handle(
DestinationType::Hypertext => { DestinationType::Hypertext => {
let Ok(uri) = end.get_hypertext_uri() else { let Ok(uri) = end.get_hypertext_uri() else {
return Err(Error::from_string( return Err(Error::from_string(
"Proxy endpoint not configured to support web requests!", "This destination was not support web requests.",
StatusCode::NOT_IMPLEMENTED, StatusCode::NOT_IMPLEMENTED,
)); ));
}; };
let res = CLIENT responder::respond_hypertext(uri, ori, method, body, headers).await
.request(method, uri + ori.path() + ori.query().unwrap_or(""))
.headers(headers.clone())
.body(body.into_bytes().await.unwrap())
.send()
.await;
match res {
Ok(result) => {
let mut res = Response::default();
res.extensions().clone_from(&result.extensions());
result.headers().iter().for_each(|(key, val)| {
res.headers_mut().insert(key, val.to_owned());
});
res.set_status(result.status());
res.set_version(result.version());
res.set_body(result.bytes().await.unwrap());
Ok(res)
} }
DestinationType::StaticFiles => {
let Ok(cfg) = end.get_static_config() else {
return Err(Error::from_string(
"This destination was not support static files.",
StatusCode::NOT_IMPLEMENTED,
));
};
Err(error) => Err(Error::from_string( responder::respond_static(cfg, method, req).await
error.to_string(), }
error.status().unwrap_or(StatusCode::BAD_GATEWAY), _ => Err(Error::from_string(
"Unsupported destination protocol.",
StatusCode::NOT_IMPLEMENTED,
)), )),
} }
} }
_ => Err(Error::from_status(StatusCode::NOT_IMPLEMENTED)),
}
}
forward(destination, req, uri, headers, method, body).await forward(destination, req, uri, headers, method, body).await
} }