flams_database/
db.rs

1use argon2::PasswordHasher;
2use axum_login::{AuthUser, AuthnBackend, tower_sessions, tracing::Instrument};
3use flams_git::gl::auth::GitlabUser;
4use flams_system::settings::Settings;
5use flams_utils::unwrap;
6use password_hash::{SaltString, rand_core::OsRng};
7use sqlx::{SqlitePool, prelude::FromRow};
8
9#[derive(Clone, PartialEq, Eq, Debug, serde::Serialize, serde::Deserialize)]
10pub struct DBUser {
11    pub id: i64,
12    pub username: String,
13    pub session_auth_hash: Vec<u8>,
14    pub avatar_url: Option<String>,
15    pub is_admin: bool,
16    pub secret: String,
17}
18
19impl DBUser {
20    fn admin(hash: Vec<u8>) -> Self {
21        Self {
22            id: 0,
23            username: "admin".to_string(),
24            session_auth_hash: hash,
25            secret: String::new(),
26            is_admin: true,
27            avatar_url: None,
28        }
29    }
30}
31
32#[derive(Clone, Debug)]
33pub struct DBBackend {
34    pub pool: SqlitePool,
35    pub admin: Option<(String, String)>,
36}
37
38impl DBBackend {
39    /// ### Panics
40    pub async fn new() -> Self {
41        let settings = Settings::get();
42        let db_path = &settings.database;
43        let admin = settings.admin_pwd.as_ref().map(|pwd| {
44            let argon = argon2::Argon2::default();
45            let salt = SaltString::generate(&mut OsRng);
46            let pass_hash = argon
47                .hash_password(pwd.as_bytes(), &salt)
48                .expect("Failed to hash password");
49            let pass_hash_str = pass_hash.to_string();
50            let salt_str = salt.as_str().to_string();
51            (pass_hash_str, salt_str)
52        });
53        if !db_path.exists() {
54            tokio::fs::create_dir_all(db_path.parent().expect("Invalid database path"))
55                .await
56                .expect("Failed to create database directory");
57            tokio::fs::File::create(db_path)
58                .await
59                .expect("Failed to create database file");
60        }
61        let db_path = db_path
62            .as_os_str()
63            .to_str()
64            .expect("Failed to connect to database");
65        let pool = SqlitePool::connect(db_path)
66            .in_current_span()
67            .await
68            .expect("Failed to connect to database");
69        sqlx::migrate!("../../resources/migrations")
70            .run(&pool)
71            .in_current_span()
72            .await
73            .expect("Failed to run migrations");
74        Self { pool, admin }
75    }
76
77    /// #### Errors
78    pub async fn all_users(&self) -> Result<Vec<SqlUser>, UserError> {
79        sqlx::query_as!(SqlUser, "SELECT * FROM users")
80            .fetch_all(&self.pool)
81            //.in_current_span()
82            .await
83            .map_err(Into::into)
84    }
85
86    /// #### Errors
87    pub async fn set_admin(&self, id: i64, is_admin: bool) -> Result<(), UserError> {
88        sqlx::query!("UPDATE users SET is_admin=$2 WHERE id=$1", id, is_admin)
89            .execute(&self.pool)
90            //.in_current_span()
91            .await
92            .map_err(Into::into)
93            .map(|_| ())
94    }
95
96    /// #### Errors
97    pub async fn add_user(
98        &self,
99        user: GitlabUser,
100        secret: String,
101    ) -> Result<Option<DBUser>, UserError> {
102        #[derive(Debug)]
103        struct InsertUser {
104            pub id: i64,
105            is_admin: bool,
106        }
107        let GitlabUser {
108            id: gitlab_id,
109            name,
110            username,
111            avatar_url,
112            email,
113            can_create_group,
114            can_create_project,
115        } = user;
116
117        if username.len() < 2 {
118            return Err(UserError::InvalidUserName);
119        }
120        if secret.len() < 2 {
121            return Err(UserError::InvalidPassword);
122        }
123        let argon2 = argon2::Argon2::default();
124        let salt = SaltString::generate(&mut OsRng);
125        let pass_hash = argon2.hash_password(secret.as_bytes(), &salt)?;
126        let hash_bytes = pass_hash
127            .hash
128            .unwrap_or_else(|| unreachable!())
129            .as_bytes()
130            .to_owned();
131        //let salt = salt.as_str();
132        let new_id:InsertUser = sqlx::query_as!(InsertUser,
133            "INSERT INTO users (gitlab_id,name,username,email,avatar_url,can_create_group,can_create_project,secret,secret_hash,is_admin)
134            VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10)
135            ON CONFLICT (gitlab_id) DO UPDATE
136            SET name = excluded.name, username = excluded.username, email = excluded.email, avatar_url = excluded.avatar_url, can_create_group = excluded.can_create_group, can_create_project = excluded.can_create_project, secret = excluded.secret, secret_hash = excluded.secret_hash
137            RETURNING id,is_admin",
138            gitlab_id,name,username,email,avatar_url,can_create_group,can_create_project,secret,hash_bytes,false//,salt
139        ).fetch_one(&self.pool).await?;
140        let u = DBUser {
141            id: new_id.id,
142            username,
143            session_auth_hash: hash_bytes,
144            secret,
145            avatar_url: Some(avatar_url),
146            is_admin: new_id.is_admin,
147        };
148        Ok(Some(u))
149    }
150
151    /// #### Errors
152    pub async fn login_as_admin(
153        pwd: &str,
154        mut session: axum_login::AuthSession<Self>,
155    ) -> Result<(), UserError> {
156        use argon2::PasswordVerifier;
157        let (pass_hash, salt) = unwrap!(session.backend.admin.as_ref());
158        let hash = password_hash::PasswordHash::parse(pass_hash, password_hash::Encoding::B64)?;
159        let hasher = argon2::Argon2::default();
160        hasher.verify_password(pwd.as_bytes(), &hash)?;
161        session
162            .login(&DBUser::admin(salt.as_bytes().to_owned()))
163            .await?;
164        Ok(())
165    }
166}
167
168#[async_trait::async_trait]
169impl AuthnBackend for DBBackend {
170    type User = DBUser;
171    type Credentials = (i64, String);
172    type Error = UserError;
173
174    async fn authenticate(
175        &self,
176        (gitlab_id, secret): Self::Credentials,
177    ) -> Result<Option<Self::User>, Self::Error> {
178        let Some(user) =
179            sqlx::query_as!(SqlUser, "SELECT * FROM users WHERE gitlab_id=$1", gitlab_id)
180                .fetch_optional(&self.pool)
181                .await?
182        else {
183            return Ok(None);
184        };
185        if user.secret == secret {
186            Ok(Some(user.try_into()?))
187        } else {
188            Ok(None)
189        }
190    }
191
192    async fn get_user(&self, user_id: &i64) -> Result<Option<Self::User>, Self::Error> {
193        if *user_id == 0 {
194            return Ok(Some(DBUser::admin(
195                self.admin
196                    .as_ref()
197                    .unwrap_or_else(|| unreachable!())
198                    .1
199                    .as_bytes()
200                    .to_owned(),
201            )));
202        }
203        let Some(res) = sqlx::query_as!(SqlUser, "SELECT * FROM users WHERE id=$1", *user_id)
204            .fetch_optional(&self.pool)
205            //.in_current_span()
206            .await?
207        else {
208            return Ok(None);
209        };
210        Ok(Some(res.try_into()?))
211    }
212}
213
214impl AuthUser for DBUser {
215    type Id = i64;
216
217    #[inline]
218    fn id(&self) -> Self::Id {
219        self.id
220    }
221
222    #[inline]
223    fn session_auth_hash(&self) -> &[u8] {
224        &self.session_auth_hash
225    }
226}
227
228#[derive(Clone, PartialEq, Eq, Debug, FromRow)]
229pub struct SqlUser {
230    id: i64,
231    gitlab_id: i64,
232    name: String,
233    username: String,
234    email: String,
235    avatar_url: String,
236    can_create_group: bool,
237    can_create_project: bool,
238    secret: String,
239    secret_hash: Vec<u8>,
240    is_admin: bool, //salt: String,
241}
242
243impl TryFrom<SqlUser> for DBUser {
244    type Error = UserError;
245    fn try_from(value: SqlUser) -> Result<Self, Self::Error> {
246        Ok(Self {
247            id: value.id,
248            username: value.username,
249            session_auth_hash: value.secret_hash,
250            secret: value.secret,
251            is_admin: value.is_admin,
252            avatar_url: Some(value.avatar_url),
253        })
254    }
255}
256
257impl From<SqlUser> for super::UserData {
258    fn from(u: SqlUser) -> Self {
259        Self {
260            id: u.id,
261            name: u.name,
262            username: u.username,
263            email: u.email,
264            avatar_url: u.avatar_url,
265            is_admin: u.is_admin,
266        }
267    }
268}
269
270#[derive(Debug, strum::Display)]
271pub enum UserError {
272    #[strum(to_string = "Invalid password hash")]
273    PasswordHashNone,
274    #[strum(to_string = "{0}")]
275    PasswordHash(password_hash::errors::Error),
276    #[strum(to_string = "{0}")]
277    Sqlx(sqlx::Error),
278    #[strum(to_string = "{0}")]
279    Session(tower_sessions::session::Error),
280    #[strum(to_string = "Invalid username: needs to be at least two characters")]
281    InvalidUserName,
282    #[strum(to_string = "Invalid password: needs to be at least two characters")]
283    InvalidPassword,
284}
285impl std::error::Error for UserError {}
286impl From<password_hash::errors::Error> for UserError {
287    #[inline]
288    fn from(e: password_hash::errors::Error) -> Self {
289        Self::PasswordHash(e)
290    }
291}
292
293impl From<axum_login::Error<DBBackend>> for UserError {
294    #[inline]
295    fn from(e: axum_login::Error<DBBackend>) -> Self {
296        match e {
297            axum_login::Error::Session(e) => Self::Session(e),
298            axum_login::Error::Backend(e) => e,
299        }
300    }
301}
302
303impl From<sqlx::Error> for UserError {
304    #[inline]
305    fn from(e: sqlx::Error) -> Self {
306        Self::Sqlx(e)
307    }
308}
309
310impl From<UserError> for super::LoginError {
311    #[inline]
312    fn from(_: UserError) -> Self {
313        Self::WrongUsernameOrPassword
314    }
315}
316impl From<password_hash::Error> for super::LoginError {
317    #[inline]
318    fn from(_: password_hash::Error) -> Self {
319        Self::WrongUsernameOrPassword
320    }
321}
322impl From<axum_login::Error<DBBackend>> for super::LoginError {
323    fn from(_: axum_login::Error<DBBackend>) -> Self {
324        Self::InternalError
325    }
326}