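//! A pit full of rusty nails.
//!
//! Crate for defining and handling `nailpit` configuration. Defines the main
//! [`NailConfig`] struct, as well as the [`get_configuration`] function, which
//! builds the config object from a default `toml` file, an optional user `toml`
//! file, and an environment-variable override of the socket address.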
5
6use core::num::NonZero;
7use std::{ops::Deref, sync::Arc};
8
9use color_eyre::{Result, eyre::Context};
10use nailbox::try_arc_within;
11
12#[derive(Debug, serde::Serialize, serde::Deserialize)]
13pub struct NailConfig {
14 pub server: ServerConfig,
15 pub generator: GeneratorConfig,
16 #[serde(default)]
17 pub rate_limiting: RateLimitingConfig,
18 pub open_telemetry: OpenTelemetryConfig,
19}
20
21#[derive(Default, serde::Serialize, serde::Deserialize)]
22pub struct PromptsList(Vec<String>);
23
24impl std::fmt::Debug for PromptsList {
25 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26 f.debug_tuple("PromptsList").finish_non_exhaustive()
27 }
28}
29
30impl Deref for PromptsList {
31 type Target = [String];
32
33 fn deref(&self) -> &Self::Target {
34 &self.0
35 }
36}
37
38#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
39pub struct ServerConfig {
40 pub pit_routes: Vec<String>,
41 pub socket_addr: String,
42 pub worker_threads: NonZero<usize>,
43}
44
45#[derive(Debug, serde::Serialize, serde::Deserialize)]
46pub struct GeneratorConfig {
47 #[serde(default)]
48 pub prompts: PromptsList,
49 pub warning_template: String,
50 pub generated_template: String,
51 pub warning_message: String,
52 pub input_files: String,
53 pub min_paragraph_size: usize,
54 pub max_paragraph_size: Option<usize>,
55 pub payload_size: usize,
56 pub timeout: u64,
57 pub min_delay: u64,
58 pub max_delay: Option<u64>,
59 pub chunk_size: usize,
60 pub header_size: usize,
61 pub max_pit_links: usize,
62}
63
64#[derive(Debug, Default, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
65#[serde(tag = "mode")]
66pub enum DropBehavior {
67 #[default]
68 #[serde(rename = "normal")]
69 Normal,
70 #[serde(rename = "spicy")]
71 Spicy { payload: Vec<String> },
72}
73
74impl DropBehavior {
75 pub fn is_spicy(&self) -> bool {
76 matches!(self, Self::Spicy { .. })
77 }
78}
79
80#[derive(Debug, Default, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
81#[serde(tag = "type")]
82pub enum RateLimitingConfig {
83 #[default]
84 #[serde(rename = "no_limit")]
85 NoLimit,
86 #[serde(rename = "soft_limit")]
87 SoftLimit { soft_limit: u64, soft_delay: u64 },
88 #[serde(rename = "hard_limit")]
89 HardLimit {
90 hard_limit: u64,
91 drop_behavior: DropBehavior,
92 },
93 #[serde(rename = "soft_with_hard_limit")]
94 SoftWithHardLimit {
95 soft_limit: u64,
96 hard_limit: u64,
97 soft_delay: u64,
98 drop_behavior: DropBehavior,
99 },
100}
101
102#[derive(Debug, serde::Serialize, serde::Deserialize)]
103pub struct OpenTelemetryConfig {
104 pub endpoint: String,
105 pub service_name: String,
106 pub logs: bool,
107 pub traces: bool,
108}
109
110pub fn get_configuration() -> Result<Arc<NailConfig>> {
111 let socket_addr = std::env::var("NAILPIT_SOCKET").ok();
112
113 let working_dir = std::env::current_dir()?;
114
115 let config_dir = working_dir.join("configuration");
116 let default_dir = working_dir.join("defaults");
117
118 let config = config::Config::builder()
119 .add_source(
120 config::File::from(default_dir.join("pit.default.toml"))
121 .format(config::FileFormat::Toml),
122 )
123 .add_source(
124 config::File::from(config_dir.join("pit.toml"))
125 .required(false)
126 .format(config::FileFormat::Toml),
127 )
128 .set_override_option("server.socket_addr", socket_addr)?
129 .build()
130 .context("Unable to read configuration files")?;
131
132 let config = try_arc_within(|| config.try_deserialize())
133 .context("Failed to load configuration. Maybe the format is invalid")?;
134
135 Ok(config)
136}