diff --git a/Cargo.lock b/Cargo.lock index 2b20f05..ad3dd63 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -397,6 +397,7 @@ checksum = "03fc05c17098f21b89bc7d98fe1dd3cce2c11c2ad8e145f2a44fe08ed28eb559" dependencies = [ "diesel_derives", "libsqlite3-sys", + "r2d2", "time", ] @@ -1530,6 +1531,17 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "r2d2" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51de85fb3fb6524929c8a2eb85e6b6d363de4e8c48f9e2c2eac4944abc181c93" +dependencies = [ + "log", + "parking_lot", + "scheduled-thread-pool", +] + [[package]] name = "rand" version = "0.8.5" @@ -1782,6 +1794,31 @@ dependencies = [ "uncased", ] +[[package]] +name = "rocket_sync_db_pools" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d83f32721ed79509adac4328e97f817a8f55a47c4b64799f6fd6cc3adb6e42ff" +dependencies = [ + "diesel", + "r2d2", + "rocket", + "rocket_sync_db_pools_codegen", + "serde", + "tokio", + "version_check", +] + +[[package]] +name = "rocket_sync_db_pools_codegen" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5cc890925dc79370c28eb15c9957677093fdb7e8c44966d189f38cedb995ee68" +dependencies = [ + "devise", + "quote", +] + [[package]] name = "rustc-demangle" version = "0.1.23" @@ -1869,6 +1906,15 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "scheduled-thread-pool" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3cbc66816425a074528352f5789333ecff06ca41b36b0b0efdfbb29edc391a19" +dependencies = [ + "parking_lot", +] + [[package]] name = "scoped-tls" version = "1.0.1" @@ -1960,6 +2006,7 @@ dependencies = [ "reqwest", "rocket", "rocket-accept-language", + "rocket_sync_db_pools", "serde", "serde_json", "tera", diff --git a/Cargo.toml b/Cargo.toml index 06d2c27..fb53a8f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,11 @@ version = "0.12" default-features = false features = ["rustls-tls", "socks"] +[dependencies.rocket_sync_db_pools] +version = "0.1" +default-features = false +features = ["diesel_sqlite_pool"] + [dependencies.tera] version = "1.19" default-features = false diff --git a/src/lib.rs b/src/lib.rs index 30a8a36..9a11f77 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,8 +4,6 @@ pub mod schema; use core::fmt::{Debug, Display, Formatter, Result}; use std::fs::read_to_string; -use diesel::prelude::*; -use diesel::sqlite::SqliteConnection; use kdl::KdlDocument; pub struct Config { @@ -40,11 +38,6 @@ pub fn get_config() -> Config { } } -pub fn establish_connection() -> SqliteConnection { - SqliteConnection::establish(&get_config().database) - .unwrap_or_else(|_| panic!("Error connecting to database")) -} - #[derive(rocket::FromFormField, Debug, PartialEq)] pub enum Software { PeerTube, diff --git a/src/main.rs b/src/main.rs index 62dcc87..9280c5d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,7 +3,7 @@ extern crate rocket; mod check; use crate::check::check; -use services::{establish_connection, get_config, models::*, Software}; +use services::{get_config, models::*, Config, Software}; use std::fs::File; @@ -12,6 +12,7 @@ use diesel::prelude::*; use fluent_templates::FluentLoader; use rocket::{ fairing::{Fairing, Info, Kind}, + figment::{util::map, value::Value}, form::{self, Error, Form, Strict}, http::Header, response::content::RawHtml, @@ -19,10 +20,11 @@ use rocket::{ Request, Response, State, }; use rocket_accept_language::{language, AcceptLanguage, LanguageIdentifier}; +use rocket_sync_db_pools::{database, diesel}; use serde::Serialize; use serde_json::json; use std::io::BufReader; -use tera::{Context, Tera, Value}; +use tera::{Context, Tera}; use unic_langid::{langid, subtags::Language}; const LANGUAGE_DEFAULT: Language = language!("en"); @@ -37,7 +39,7 @@ fluent_templates::static_loader! { }; } -fn gen_context(accept_language: &AcceptLanguage, values: Value) -> Context { +fn gen_context(accept_language: &AcceptLanguage, values: tera::Value) -> Context { let mut cont: Context = Context::from_value(json!({ "ln": &accept_language .get_appropriate_language_region(&SL) @@ -66,37 +68,49 @@ impl Fairing for HttpHeaders { } } +#[database("main_db")] +struct DbConn(diesel::SqliteConnection); + #[launch] fn rocket() -> _ { + let config = get_config(); + let mut tera = Tera::new("templates/*.html.tera").unwrap(); tera.register_function("fluent", FluentLoader::new(&*LOCALES)); - let db: Database = Database::from_reader(BufReader::new( - File::open(get_config().ip_to_asn).expect("unable to open ip2asn TSV file"), - )) - .unwrap(); - - rocket::build() - .manage(tera) - .manage(db) - .attach(Shield::new().enable(Referrer::NoReferrer)) - .attach(HttpHeaders) - .mount( - "/", - routes![ - list_services, - list_scans, - add_service_get, - add_service_post, - dl, - about, - ], - ) + rocket::custom(rocket::Config::figment().merge(( + "databases", + map!["main_db" => map!{ + "url" => Into::::into(config.database.clone()), + }], + ))) + .manage( + Database::from_reader(BufReader::new( + File::open(&config.ip_to_asn).expect("unable to open ip2asn TSV file"), + )) + .unwrap(), + ) + .manage(config) + .manage(tera) + .attach(DbConn::fairing()) + .attach(Shield::new().enable(Referrer::NoReferrer)) + .attach(HttpHeaders) + .mount( + "/", + routes![ + list_services, + list_scans, + add_service_get, + add_service_post, + dl, + about, + ], + ) } #[get("/services.db")] -fn dl() -> File { - File::open(get_config().database).unwrap() +fn dl(config: &State) -> File { + File::open(&config.database).unwrap() } #[derive(Serialize, Debug)] @@ -119,22 +133,27 @@ struct TemplateServices { ip_info: Vec, } #[get("/?")] -fn list_services( +async fn list_services( + conn: DbConn, tera: &State, - db: &State, + ipdb: &State, al: &AcceptLanguage, software: Option>, ) -> RawHtml { - let mut request = services::schema::services::dsl::services.into_boxed(); - if let Some(s) = software { - request = - request.filter(services::schema::services::software.eq(s.to_string().to_lowercase())); - } - let services = request - .limit(300) - .select(Services::as_select()) - .load(&mut establish_connection()) - .unwrap(); + let services = conn + .run(|c| { + let mut request = services::schema::services::dsl::services.into_boxed(); + if let Some(s) = software { + request = request + .filter(services::schema::services::software.eq(s.to_string().to_lowercase())); + } + request + .limit(300) + .select(Services::as_select()) + .load(c) + .unwrap() + }) + .await; let mut templates: Vec = vec![]; for service in &services { @@ -145,7 +164,7 @@ fn list_services( .iter() .filter(|ip| !ip.is_empty()) .for_each(|ip| { - match db.lookup(ip.parse().unwrap()).unwrap() { + match ipdb.lookup(ip.parse().unwrap()).unwrap() { asn_db2::IpEntry::V6(info) => ip_info.push(IpInfo { ip: ip.to_string(), subnet: info.subnet.to_string(), @@ -190,18 +209,18 @@ fn list_services( } #[get("/list-scans")] -fn list_scans(tera: &State, al: &AcceptLanguage) -> RawHtml { +async fn list_scans(conn: DbConn, tera: &State, al: &AcceptLanguage) -> RawHtml { RawHtml( tera.render( "list-scans.html.tera", &gen_context( al, json!({ - "scans": services::schema::scans::dsl::scans - .limit(1000) - .select(Scans::as_select()) - .load(&mut establish_connection()) - .unwrap() + "scans": conn.run(|c| services::schema::scans::dsl::scans + .limit(1000) + .select(Scans::as_select()) + .load(c) + .unwrap()).await }), ), ) @@ -242,6 +261,7 @@ struct Submission<'r> { #[post("/add-service", data = "")] async fn add_service_post( + conn: DbConn, submission: Form>>, tera: &State, al: &AcceptLanguage, @@ -262,20 +282,23 @@ async fn add_service_post( } }; - diesel::insert_into(services::table()) - .values(Services { - url: service.url, - software: service.software, - server: service.server, - ipv6: "".to_string(), - ipv4: "".to_string(), - availability_ipv6: "".to_string(), - availability_ipv4: "".to_string(), - address_ipv6: "".to_string(), - address_ipv4: "".to_string(), - }) - .execute(&mut establish_connection()) - .unwrap(); + conn.run(|c| { + diesel::insert_into(services::table()) + .values(Services { + url: service.url, + software: service.software, + server: service.server, + ipv6: "".to_string(), + ipv4: "".to_string(), + availability_ipv6: "".to_string(), + availability_ipv4: "".to_string(), + address_ipv6: "".to_string(), + address_ipv4: "".to_string(), + }) + .execute(c) + .unwrap() + }) + .await; RawHtml( tera.render("add-service.html.tera", &gen_context(al, json!({}))) @@ -284,14 +307,14 @@ async fn add_service_post( } #[get("/about")] -fn about(tera: &State, al: &AcceptLanguage) -> RawHtml { +fn about(config: &State, tera: &State, al: &AcceptLanguage) -> RawHtml { RawHtml( tera.render( "about.html.tera", &gen_context( al, json!({ - "source_code": get_config().source_code, + "source_code": config.source_code, }), ), ) diff --git a/src/updater.rs b/src/updater.rs index d62f990..d50c055 100644 --- a/src/updater.rs +++ b/src/updater.rs @@ -1,28 +1,32 @@ mod check; use crate::check::check; - -use ::services::establish_connection; -use ::services::models::Scans; -use ::services::models::Services; -use ::services::schema::scans::dsl::scans; -use ::services::schema::scans::installation; -use ::services::schema::services::dsl::*; -use diesel::associations::HasTable; -use diesel::ExpressionMethods; -use diesel::QueryDsl; -use diesel::RunQueryDsl; -use diesel::SelectableHelper; -use dns_lookup::lookup_host; +use ::services::{ + get_config, + models::{Scans, Services}, + schema::{ + scans::{dsl::scans, installation}, + services::dsl::*, + }, +}; use std::net::IpAddr; +use diesel::{ + associations::HasTable, Connection, ExpressionMethods, QueryDsl, RunQueryDsl, SelectableHelper, + SqliteConnection, +}; +use dns_lookup::lookup_host; + #[tokio::main] async fn main() { - let connection = &mut establish_connection(); + let config = get_config(); + + let mut conn = SqliteConnection::establish(&config.database) + .unwrap_or_else(|_| panic!("Error connecting to database")); for service in services .limit(300) .select(Services::as_select()) - .load(connection) + .load(&mut conn) .expect("Error loading services") { let result_ipv6 = match check(&service.url, Some(6)).await { @@ -46,14 +50,14 @@ async fn main() { result_ipv6: result_ipv6.clone(), result_ipv4: result_ipv4.clone(), }) - .execute(connection) + .execute(&mut conn) .unwrap(); let installation_scans = scans .filter(installation.eq(service.url.to_string())) .limit(100) .select(Scans::as_select()) - .load(connection) + .load(&mut conn) .unwrap(); let ipv6_successes = installation_scans @@ -99,7 +103,7 @@ async fn main() { address_ipv6.eq(&addr_ipv6.join(",")), address_ipv4.eq(&addr_ipv4.join(",")), )) - .execute(connection) + .execute(&mut conn) .unwrap(); } }