use clap::Parser; use std::collections::HashMap; use std::error::Error; #[macro_use] extern crate rocket; use cached::proc_macro::cached; use reqwest::{Client, StatusCode}; use rocket::response::content::RawHtml; use rocket::serde::json::Json; mod cohost_account; mod cohost_posts; mod syndication; mod webfinger; use cohost_account::{CohostAccount, COHOST_ACCOUNT_API_URL}; use cohost_posts::{cohost_posts_api_url, CohostPost, CohostPostsPage}; use webfinger::CohostWebfingerResource; #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] struct Args { /// The base URL for the corobel instance #[clap(short, long, required = true)] domain: String, /// The base URL for the corobel instance #[clap(short, long, default_value_t = default_base_url() )] base_url: String, } fn default_base_url() -> String { "/".into() } fn user_agent() -> String { format!( "{}/{} (RSS feed converter) on {}", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION"), &ARGS.domain ) } static ARGS: once_cell::sync::Lazy = once_cell::sync::Lazy::new(|| Args::parse()); static CLIENT: once_cell::sync::Lazy = once_cell::sync::Lazy::new(|| { reqwest::Client::builder() .user_agent(user_agent()) .build() .unwrap() }); #[get("/")] fn index() -> RawHtml<&'static str> { RawHtml(include_str!("../static/index.html")) } #[derive(Responder)] #[response(content_type = "text/markdown")] struct MdResponse { inner: String, } #[derive(Debug, Clone, Responder)] #[response(content_type = "application/rss+xml")] struct RssResponse { inner: String, } #[derive(Debug, Responder)] #[response(content_type = "text/plain")] enum ErrorResponse { #[response(status = 404)] NotFound(String), #[response(status = 500)] InternalError(String), } #[cached(time = 60, result)] async fn get_post_from_page(project_id: String, post_id: u64) -> Result { let mut page = 0; loop { let new_page = get_page_data(project_id.clone(), page).await?; if new_page.items.is_empty() { // Once there are no posts, we're done. return Err(ErrorResponse::NotFound( "End of posts reached, ID not found.".into(), )); } else { page += 1; if let Some(post) = new_page.items.into_iter().find(|post| post.id == post_id) { return Ok(post); } } } } #[cached(time = 120, result)] async fn get_full_post_data(project_id: String) -> Result { let mut page = 0; let mut merged_page = get_page_data(project_id.clone(), page).await?; loop { let mut new_page = get_page_data(project_id.clone(), page).await?; if new_page.items.is_empty() { // Once there are no posts, we're done. break; } else { page += 1; merged_page.number_items += new_page.number_items; merged_page.items.append(&mut new_page.items); } } Ok(merged_page) } // Not cached because it's never used individually. async fn get_page_data(project_id: String, page: u64) -> Result { let posts_url = cohost_posts_api_url(&project_id, page); eprintln!("[INT] making request to {}", posts_url); match CLIENT.get(posts_url).send().await { Ok(v) => match v.status() { StatusCode::OK => match v.json::().await { Ok(page_data) => Ok(page_data), Err(e) => { let err = format!( "Couldn't deserialize Cohost posts page for '{}': {:?}", project_id, e ); eprintln!("[ERR] {}", err); return Err(ErrorResponse::InternalError(err)); } }, // TODO NORA: Handle possible redirects s => { let err = format!("Didn't receive status code 200 for posts for Cohost project '{}'; got {:?} instead.", page, s); eprintln!("[ERR] {}", err); return Err(ErrorResponse::NotFound(err)); } }, Err(e) => { let err = format!( "Error making request to Cohost for posts for project '{}': {:?}", project_id, e ); eprintln!("[ERR] {}", err); return Err(ErrorResponse::InternalError(err)); } } } #[cached(time = 60, result)] async fn get_project_data(project_id: String) -> Result { let project_url = format!("{}{}", COHOST_ACCOUNT_API_URL, project_id); eprintln!("[INT] making request to {}", project_url); match CLIENT.get(project_url).send().await { Ok(v) => match v.status() { StatusCode::OK => match v.json::().await { Ok(a) => Ok(a), Err(e) => { let err = format!( "Couldn't deserialize Cohost project '{}': {:?}", project_id, e ); eprintln!("[ERR] {}", err); Err(ErrorResponse::InternalError(err)) } }, // TODO NORA: Handle possible redirects s => { let err = format!( "Didn't receive status code 200 for Cohost project '{}'; got {:?} instead.", project_id, s ); eprintln!("[ERR] {}", err); Err(ErrorResponse::NotFound(err)) } }, Err(e) => { let err = format!( "Error making request to Cohost for project '{}': {:?}", project_id, e ); eprintln!("[ERR] {}", err); Err(ErrorResponse::InternalError(err)) } } } #[get("//originals.rss")] async fn syndication_originals_rss_route(project: String) -> Result { eprintln!("[EXT] Request to /{}/originals.rss", project); let project_data = get_project_data(project.clone()).await?; let page_data = get_full_post_data(project.clone()).await?; Ok(RssResponse { inner: syndication::channel_for_posts_page(project.clone(), project_data, page_data, true) .to_string(), }) } #[get("//feed.rss")] async fn syndication_rss_route(project: String) -> Result { eprintln!("[EXT] Request to /{}/feed.rss", project); let project_data = get_project_data(project.clone()).await?; let page_data = get_full_post_data(project.clone()).await?; Ok(RssResponse { inner: syndication::channel_for_posts_page(project.clone(), project_data, page_data, false) .to_string(), }) } #[get("//")] async fn post_md_route(project: String, id: u64) -> Result { eprintln!("[EXT] Request to /{}/{}", project, id); let _project_data = get_project_data(project.clone()).await?; let post_data = get_post_from_page(project.clone(), id).await?; Ok(MdResponse { inner: post_data.plain_body, }) } #[get("/.well-known/webfinger?")] async fn webfinger_route( params: HashMap, ) -> Result, ErrorResponse> { let mut url_params_string = String::new(); for (k, v) in params.iter() { url_params_string.push_str(&format!("{}={}&", k, v)); } eprintln!( "[EXT] Request to /.well_known/webfinger?{}", url_params_string ); if params.len() != 1 { let err = format!( "Too may or too few parameters. Expected 1, got {}", params.len() ); eprintln!("[ERR] {}", err); return Err(ErrorResponse::InternalError(err)); } if let Some(param) = params.iter().next() { let _project_data = get_project_data(param.0.clone()).await?; Ok(Json(CohostWebfingerResource::new( param.0.as_str(), &ARGS.domain, &ARGS.base_url, ))) } else { Err(ErrorResponse::NotFound("No project ID provided.".into())) } } #[rocket::main] async fn main() -> Result<(), Box> { // Set up the global config once_cell::sync::Lazy::force(&ARGS); let _rocket = rocket::build() .mount( &ARGS.base_url, routes![ index, webfinger_route, syndication_rss_route, syndication_originals_rss_route, post_md_route ], ) .ignite() .await? .launch() .await?; Ok(()) }