corobel/src/main.rs

272 lines
8.7 KiB
Rust

use clap::Parser;
use std::collections::HashMap;
use std::error::Error;
#[macro_use]
extern crate rocket;
use cached::proc_macro::cached;
use reqwest::{Client, StatusCode};
use rocket::response::content::RawHtml;
use rocket::serde::json::Json;
mod cohost_account;
mod cohost_posts;
mod syndication;
mod webfinger;
use cohost_account::{CohostAccount, COHOST_ACCOUNT_API_URL};
use cohost_posts::{cohost_posts_api_url, CohostPost, CohostPostsPage};
use webfinger::CohostWebfingerResource;
// Command-line configuration for a corobel instance.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// The domain name this corobel instance is reachable at; used in the
    /// User-Agent header and in WebFinger responses
    #[clap(short, long, required = true)]
    domain: String,
    /// The base path under which corobel's routes are mounted
    #[clap(short, long, default_value_t = default_base_url() )]
    base_url: String,
}
/// Default for `--base-url`: mount every route at the server root.
fn default_base_url() -> String {
    String::from("/")
}
/// Builds the User-Agent string used by the shared HTTP client:
/// `<crate name>/<crate version> (RSS feed converter) on <configured domain>`.
fn user_agent() -> String {
    let crate_name = env!("CARGO_PKG_NAME");
    let crate_version = env!("CARGO_PKG_VERSION");
    let domain: &str = &ARGS.domain;
    format!(
        "{}/{} (RSS feed converter) on {}",
        crate_name, crate_version, domain
    )
}
/// Parsed command-line arguments; initialised on first access (`main` forces this at startup).
static ARGS: once_cell::sync::Lazy<Args> = once_cell::sync::Lazy::new(Args::parse);
static CLIENT: once_cell::sync::Lazy<Client> = once_cell::sync::Lazy::new(|| {
reqwest::Client::builder()
.user_agent(user_agent())
.build()
.unwrap()
});
/// Serves the landing page, which is embedded into the binary at compile time.
#[get("/")]
fn index() -> RawHtml<&'static str> {
    const LANDING_PAGE: &str = include_str!("../static/index.html");
    RawHtml(LANDING_PAGE)
}
/// A response whose body is sent with a `text/markdown` content type
/// (used by `post_md_route` to return a post's plain body).
#[derive(Responder)]
#[response(content_type = "text/markdown")]
struct MdResponse {
    // The response body, sent verbatim.
    inner: String,
}
/// A response whose body is sent with an `application/rss+xml` content type
/// (returned by the `feed.rss` / `originals.rss` routes).
#[derive(Debug, Clone, Responder)]
#[response(content_type = "application/rss+xml")]
struct RssResponse {
    // The serialized RSS channel, sent verbatim as the body.
    inner: String,
}
/// Error type shared by the routes and the Cohost fetch helpers.
///
/// Rendered as a `text/plain` body containing the message; the HTTP status is
/// chosen by the variant.
#[derive(Debug, Responder)]
#[response(content_type = "text/plain")]
enum ErrorResponse {
    // HTTP 404. Used when Cohost answers with any non-200 status, when a post ID
    // isn't found after walking all pages, and when no project ID is given.
    #[response(status = 404)]
    NotFound(String),
    // HTTP 500. Used for failed requests to Cohost, undeserializable responses,
    // and malformed WebFinger queries.
    #[response(status = 500)]
    InternalError(String),
}
/// Looks up a single post of `project_id` by walking its post pages, starting at
/// page 0, until a post with `post_id` turns up.
///
/// An empty page marks the end of the project's posts; reaching it without a match
/// yields `ErrorResponse::NotFound`. Fetch errors from `get_page_data` propagate as-is.
/// Cached for 60 seconds via `#[cached]`.
#[cached(time = 60, result)]
async fn get_post_from_page(project_id: String, post_id: u64) -> Result<CohostPost, ErrorResponse> {
    let mut page_number = 0;
    loop {
        let page = get_page_data(project_id.clone(), page_number).await?;
        if page.items.is_empty() {
            return Err(ErrorResponse::NotFound(
                "End of posts reached, ID not found.".into(),
            ));
        }
        page_number += 1;
        let hit = page.items.into_iter().find(|candidate| candidate.id == post_id);
        if let Some(found) = hit {
            return Ok(found);
        }
    }
}
/// Fetches every post of `project_id` by requesting consecutive pages until an
/// empty one is returned, merging them into a single `CohostPostsPage`.
///
/// Page 0 seeds the result (keeping its other fields); each later page contributes
/// its `items` and adds its `number_items` to the running total.
/// Cached for 120 seconds via `#[cached]`.
///
/// # Errors
/// Propagates any `ErrorResponse` from `get_page_data`.
#[cached(time = 120, result)]
async fn get_full_post_data(project_id: String) -> Result<CohostPostsPage, ErrorResponse> {
    let mut merged_page = get_page_data(project_id.clone(), 0).await?;
    // An empty first page means the project has no posts; nothing more to fetch.
    if merged_page.items.is_empty() {
        return Ok(merged_page);
    }
    // Page 0 is already merged, so continue from page 1 (starting at 0 again would
    // duplicate the first page's posts and double-count its items).
    let mut page = 1;
    loop {
        let mut new_page = get_page_data(project_id.clone(), page).await?;
        if new_page.items.is_empty() {
            // Once there are no posts, we're done.
            break;
        }
        merged_page.number_items += new_page.number_items;
        merged_page.items.append(&mut new_page.items);
        page += 1;
    }
    Ok(merged_page)
}
// Not cached because it's never used individually; its callers
// (`get_post_from_page`, `get_full_post_data`) are cached themselves.
/// Fetches one page (`page`, starting at 0) of `project_id`'s posts from Cohost.
///
/// # Errors
/// - `ErrorResponse::InternalError` if the request fails or the body can't be
///   deserialized into a `CohostPostsPage`.
/// - `ErrorResponse::NotFound` if Cohost answers with any status other than 200.
async fn get_page_data(project_id: String, page: u64) -> Result<CohostPostsPage, ErrorResponse> {
    let posts_url = cohost_posts_api_url(&project_id, page);
    eprintln!("[INT] making request to {}", posts_url);
    match CLIENT.get(posts_url).send().await {
        Ok(v) => match v.status() {
            StatusCode::OK => match v.json::<CohostPostsPage>().await {
                Ok(page_data) => Ok(page_data),
                Err(e) => {
                    let err = format!(
                        "Couldn't deserialize Cohost posts page {} for '{}': {:?}",
                        page, project_id, e
                    );
                    eprintln!("[ERR] {}", err);
                    Err(ErrorResponse::InternalError(err))
                }
            },
            // TODO NORA: Handle possible redirects
            s => {
                // Report the project name (previously the page number was printed
                // in its place) together with the page that failed.
                let err = format!(
                    "Didn't receive status code 200 for posts for Cohost project '{}' (page {}); got {:?} instead.",
                    project_id, page, s
                );
                eprintln!("[ERR] {}", err);
                Err(ErrorResponse::NotFound(err))
            }
        },
        Err(e) => {
            let err = format!(
                "Error making request to Cohost for posts for project '{}': {:?}",
                project_id, e
            );
            eprintln!("[ERR] {}", err);
            Err(ErrorResponse::InternalError(err))
        }
    }
}
/// Fetches the Cohost account metadata for `project_id`.
///
/// Every failure is logged to stderr with an `[ERR]` prefix before being returned:
/// request errors and undeserializable bodies become `ErrorResponse::InternalError`,
/// and any non-200 status becomes `ErrorResponse::NotFound`.
/// Cached for 60 seconds via `#[cached]`.
#[cached(time = 60, result)]
async fn get_project_data(project_id: String) -> Result<CohostAccount, ErrorResponse> {
    let project_url = format!("{}{}", COHOST_ACCOUNT_API_URL, project_id);
    eprintln!("[INT] making request to {}", project_url);

    let response = match CLIENT.get(project_url).send().await {
        Ok(response) => response,
        Err(e) => {
            let message = format!(
                "Error making request to Cohost for project '{}': {:?}",
                project_id, e
            );
            eprintln!("[ERR] {}", message);
            return Err(ErrorResponse::InternalError(message));
        }
    };

    // TODO NORA: Handle possible redirects
    let status = response.status();
    if status != StatusCode::OK {
        let message = format!(
            "Didn't receive status code 200 for Cohost project '{}'; got {:?} instead.",
            project_id, status
        );
        eprintln!("[ERR] {}", message);
        return Err(ErrorResponse::NotFound(message));
    }

    response.json::<CohostAccount>().await.map_err(|e| {
        let message = format!(
            "Couldn't deserialize Cohost project '{}': {:?}",
            project_id, e
        );
        eprintln!("[ERR] {}", message);
        ErrorResponse::InternalError(message)
    })
}
/// Builds the RSS feed for `project`; shared by both feed routes.
///
/// Fetches the project's metadata and all of its posts, then renders them with
/// `syndication::channel_for_posts_page`. `originals_only` is forwarded as that
/// function's final argument: `true` for the `originals.rss` route, `false` for
/// `feed.rss` (presumably it filters out non-original posts; confirm in `syndication`).
async fn render_project_feed(
    project: String,
    originals_only: bool,
) -> Result<RssResponse, ErrorResponse> {
    let project_data = get_project_data(project.clone()).await?;
    let page_data = get_full_post_data(project.clone()).await?;
    Ok(RssResponse {
        inner: syndication::channel_for_posts_page(project, project_data, page_data, originals_only)
            .to_string(),
    })
}

/// `GET /<project>/originals.rss`: the project's feed with `originals_only = true`.
#[get("/<project>/originals.rss")]
async fn syndication_originals_rss_route(project: String) -> Result<RssResponse, ErrorResponse> {
    eprintln!("[EXT] Request to /{}/originals.rss", project);
    render_project_feed(project, true).await
}

/// `GET /<project>/feed.rss`: the project's feed with `originals_only = false`.
#[get("/<project>/feed.rss")]
async fn syndication_rss_route(project: String) -> Result<RssResponse, ErrorResponse> {
    eprintln!("[EXT] Request to /{}/feed.rss", project);
    render_project_feed(project, false).await
}
/// `GET /<project>/<id>`: returns the plain body of a single post as markdown.
///
/// The project lookup runs first so an unknown project fails before any post pages
/// are fetched; its data is otherwise unused.
#[get("/<project>/<id>")]
async fn post_md_route(project: String, id: u64) -> Result<MdResponse, ErrorResponse> {
    eprintln!("[EXT] Request to /{}/{}", project, id);
    get_project_data(project.clone()).await?;
    let post = get_post_from_page(project, id).await?;
    let body = post.plain_body;
    Ok(MdResponse { inner: body })
}
/// `GET /.well-known/webfinger?<project>`: WebFinger lookup for a Cohost project.
///
/// Exactly one query parameter is accepted; its *key* is taken as the project
/// handle. The project is looked up on Cohost (so unknown projects fail with that
/// lookup's error) before a `CohostWebfingerResource` is built from the handle and
/// the configured domain and base URL.
///
/// NOTE(review): RFC 7033 clients send `?resource=acct:user@host` (plus optional
/// `rel` parameters); this route treats the query key itself as the handle. Confirm
/// that this is intended.
///
/// # Errors
/// - `ErrorResponse::InternalError` if the query does not have exactly one parameter.
/// - Any error from `get_project_data`.
#[get("/.well-known/webfinger?<params..>")]
async fn webfinger_route(
    params: HashMap<String, String>,
) -> Result<Json<CohostWebfingerResource>, ErrorResponse> {
    // Reconstructed only for logging; HashMap iteration order is arbitrary.
    let query_string = params
        .iter()
        .map(|(k, v)| format!("{}={}", k, v))
        .collect::<Vec<_>>()
        .join("&");
    eprintln!("[EXT] Request to /.well-known/webfinger?{}", query_string);
    if params.len() != 1 {
        let err = format!(
            "Too many or too few parameters. Expected 1, got {}",
            params.len()
        );
        eprintln!("[ERR] {}", err);
        return Err(ErrorResponse::InternalError(err));
    }
    match params.keys().next() {
        Some(project_id) => {
            // Fails early (with Cohost's error) if the project doesn't exist.
            get_project_data(project_id.clone()).await?;
            Ok(Json(CohostWebfingerResource::new(
                project_id.as_str(),
                &ARGS.domain,
                &ARGS.base_url,
            )))
        }
        // Unreachable after the length check above; kept as a defensive fallback.
        None => Err(ErrorResponse::NotFound("No project ID provided.".into())),
    }
}
#[rocket::main]
async fn main() -> Result<(), Box<dyn Error>> {
// Set up the global config
once_cell::sync::Lazy::force(&ARGS);
let _rocket = rocket::build()
.mount(
&ARGS.base_url,
routes![
index,
webfinger_route,
syndication_rss_route,
syndication_originals_rss_route,
post_md_route
],
)
.ignite()
.await?
.launch()
.await?;
Ok(())
}