flams_database/
db.rs

1use argon2::PasswordHasher;
2use axum_login::{AuthUser, AuthnBackend, tower_sessions, tracing::Instrument};
3//use axum_login::tracing::Instrument;
4use flams_git::gl::auth::GitlabUser;
5use flams_system::settings::Settings;
6use flams_utils::unwrap;
7use password_hash::{SaltString, rand_core::OsRng};
8use sqlx::{SqlitePool, prelude::FromRow};
9
10#[derive(Clone, PartialEq, Eq, Debug, serde::Serialize, serde::Deserialize)]
11pub struct DBUser {
12    pub id: i64,
13    pub username: String,
14    pub session_auth_hash: Vec<u8>,
15    pub avatar_url: Option<String>,
16    pub is_admin: bool,
17    pub secret: String,
18}
19
20impl DBUser {
21    fn admin(hash: Vec<u8>) -> Self {
22        Self {
23            id: 0,
24            username: "admin".to_string(),
25            session_auth_hash: hash,
26            secret: String::new(),
27            is_admin: true,
28            avatar_url: None,
29        }
30    }
31}
32
33#[derive(Clone, Debug)]
34pub struct DBBackend {
35    pub pool: SqlitePool,
36    pub admin: Option<(String, String)>,
37}
38
39impl DBBackend {
40    /// ### Panics
41    pub async fn new() -> Self {
42        let settings = Settings::get();
43        let db_path = &settings.database;
44        let admin = settings.admin_pwd.as_ref().map(|pwd| {
45            let argon = argon2::Argon2::default();
46            let salt = SaltString::generate(&mut OsRng);
47            let pass_hash = argon
48                .hash_password(pwd.as_bytes(), &salt)
49                .expect("Failed to hash password");
50            let pass_hash_str = pass_hash.to_string();
51            let salt_str = salt.as_str().to_string();
52            (pass_hash_str, salt_str)
53        });
54        if !db_path.exists() {
55            tokio::fs::create_dir_all(db_path.parent().expect("Invalid database path"))
56                .await
57                .expect("Failed to create database directory");
58            tokio::fs::File::create(db_path)
59                .await
60                .expect("Failed to create database file");
61        }
62        let db_path = db_path
63            .as_os_str()
64            .to_str()
65            .expect("Failed to connect to database");
66        let pool = SqlitePool::connect(db_path)
67            .in_current_span()
68            .await
69            .expect("Failed to connect to database");
70        sqlx::migrate!("../../resources/migrations")
71            .run(&pool)
72            .in_current_span()
73            .await
74            .expect("Failed to run migrations");
75        Self { pool, admin }
76    }
77
78    /// #### Errors
79    pub async fn all_users(&self) -> Result<Vec<SqlUser>, UserError> {
80        sqlx::query_as!(SqlUser, "SELECT * FROM users")
81            .fetch_all(&self.pool)
82            //.in_current_span()
83            .await
84            .map_err(Into::into)
85    }
86
87    /// #### Errors
88    pub async fn set_admin(&self, id: i64, is_admin: bool) -> Result<(), UserError> {
89        sqlx::query!("UPDATE users SET is_admin=$2 WHERE id=$1", id, is_admin)
90            .execute(&self.pool)
91            //.in_current_span()
92            .await
93            .map_err(Into::into)
94            .map(|_| ())
95    }
96
97    /// #### Errors
98    pub async fn add_user(
99        &self,
100        user: GitlabUser,
101        secret: String,
102    ) -> Result<Option<DBUser>, UserError> {
103        #[derive(Debug)]
104        struct InsertUser {
105            pub id: i64,
106            is_admin: bool,
107        }
108        let GitlabUser {
109            id: gitlab_id,
110            name,
111            username,
112            avatar_url,
113            email,
114            can_create_group,
115            can_create_project,
116        } = user;
117
118        if username.len() < 2 {
119            return Err(UserError::InvalidUserName);
120        }
121        if secret.len() < 2 {
122            return Err(UserError::InvalidPassword);
123        }
124        let argon2 = argon2::Argon2::default();
125        let salt = SaltString::generate(&mut OsRng);
126        let pass_hash = argon2.hash_password(secret.as_bytes(), &salt)?;
127        let hash_bytes = pass_hash
128            .hash
129            .unwrap_or_else(|| unreachable!())
130            .as_bytes()
131            .to_owned();
132        //let salt = salt.as_str();
133        let new_id:InsertUser = sqlx::query_as!(InsertUser,
134            "INSERT INTO users (gitlab_id,name,username,email,avatar_url,can_create_group,can_create_project,secret,secret_hash,is_admin)
135            VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10)
136            ON CONFLICT (gitlab_id) DO UPDATE
137            SET name = excluded.name, username = excluded.username, email = excluded.email, avatar_url = excluded.avatar_url, can_create_group = excluded.can_create_group, can_create_project = excluded.can_create_project, secret = excluded.secret, secret_hash = excluded.secret_hash
138            RETURNING id,is_admin",
139            gitlab_id,name,username,email,avatar_url,can_create_group,can_create_project,secret,hash_bytes,false//,salt
140        ).fetch_one(&self.pool).await?;
141        let u = DBUser {
142            id: new_id.id,
143            username,
144            session_auth_hash: hash_bytes,
145            secret,
146            avatar_url: Some(avatar_url),
147            is_admin: new_id.is_admin,
148        };
149        Ok(Some(u))
150    }
151
152    /// #### Errors
153    pub async fn login_as_admin(
154        pwd: &str,
155        mut session: axum_login::AuthSession<Self>,
156    ) -> Result<(), UserError> {
157        use argon2::PasswordVerifier;
158        let (pass_hash, salt) = unwrap!(session.backend.admin.as_ref());
159        let hash = password_hash::PasswordHash::parse(pass_hash, password_hash::Encoding::B64)?;
160        let hasher = argon2::Argon2::default();
161        hasher.verify_password(pwd.as_bytes(), &hash)?;
162        session
163            .login(&DBUser::admin(salt.as_bytes().to_owned()))
164            .await?;
165        Ok(())
166    }
167}
168
169//#[async_trait::async_trait]
170impl AuthnBackend for DBBackend {
171    type User = DBUser;
172    type Credentials = (i64, String);
173    type Error = UserError;
174
175    async fn authenticate(
176        &self,
177        (gitlab_id, secret): Self::Credentials,
178    ) -> Result<Option<Self::User>, Self::Error> {
179        let Some(user) =
180            sqlx::query_as!(SqlUser, "SELECT * FROM users WHERE gitlab_id=$1", gitlab_id)
181                .fetch_optional(&self.pool)
182                .await?
183        else {
184            return Ok(None);
185        };
186        if user.secret == secret {
187            Ok(Some(user.try_into()?))
188        } else {
189            Ok(None)
190        }
191    }
192
193    async fn get_user(&self, user_id: &i64) -> Result<Option<Self::User>, Self::Error> {
194        if *user_id == 0 {
195            return Ok(Some(DBUser::admin(
196                self.admin
197                    .as_ref()
198                    .unwrap_or_else(|| unreachable!())
199                    .1
200                    .as_bytes()
201                    .to_owned(),
202            )));
203        }
204        let Some(res) = sqlx::query_as!(SqlUser, "SELECT * FROM users WHERE id=$1", *user_id)
205            .fetch_optional(&self.pool)
206            //.in_current_span()
207            .await?
208        else {
209            return Ok(None);
210        };
211        Ok(Some(res.try_into()?))
212    }
213}
214
215impl AuthUser for DBUser {
216    type Id = i64;
217
218    #[inline]
219    fn id(&self) -> Self::Id {
220        self.id
221    }
222
223    #[inline]
224    fn session_auth_hash(&self) -> &[u8] {
225        &self.session_auth_hash
226    }
227}
228
229#[derive(Clone, PartialEq, Eq, Debug, FromRow)]
230pub struct SqlUser {
231    id: i64,
232    gitlab_id: i64,
233    name: String,
234    username: String,
235    email: String,
236    avatar_url: String,
237    can_create_group: bool,
238    can_create_project: bool,
239    secret: String,
240    secret_hash: Vec<u8>,
241    is_admin: bool, //salt: String,
242}
243
244impl TryFrom<SqlUser> for DBUser {
245    type Error = UserError;
246    fn try_from(value: SqlUser) -> Result<Self, Self::Error> {
247        Ok(Self {
248            id: value.id,
249            username: value.username,
250            session_auth_hash: value.secret_hash,
251            secret: value.secret,
252            is_admin: value.is_admin,
253            avatar_url: Some(value.avatar_url),
254        })
255    }
256}
257
258impl From<SqlUser> for super::UserData {
259    fn from(u: SqlUser) -> Self {
260        Self {
261            id: u.id,
262            name: u.name,
263            username: u.username,
264            email: u.email,
265            avatar_url: u.avatar_url,
266            is_admin: u.is_admin,
267        }
268    }
269}
270
271#[derive(Debug, strum::Display)]
272pub enum UserError {
273    #[strum(to_string = "Invalid password hash")]
274    PasswordHashNone,
275    #[strum(to_string = "{0}")]
276    PasswordHash(password_hash::errors::Error),
277    #[strum(to_string = "{0}")]
278    Sqlx(sqlx::Error),
279    #[strum(to_string = "{0}")]
280    Session(tower_sessions::session::Error),
281    #[strum(to_string = "Invalid username: needs to be at least two characters")]
282    InvalidUserName,
283    #[strum(to_string = "Invalid password: needs to be at least two characters")]
284    InvalidPassword,
285}
286impl std::error::Error for UserError {}
287impl From<password_hash::errors::Error> for UserError {
288    #[inline]
289    fn from(e: password_hash::errors::Error) -> Self {
290        Self::PasswordHash(e)
291    }
292}
293
294impl From<axum_login::Error<DBBackend>> for UserError {
295    #[inline]
296    fn from(e: axum_login::Error<DBBackend>) -> Self {
297        match e {
298            axum_login::Error::Session(e) => Self::Session(e),
299            axum_login::Error::Backend(e) => e,
300        }
301    }
302}
303
304impl From<sqlx::Error> for UserError {
305    #[inline]
306    fn from(e: sqlx::Error) -> Self {
307        Self::Sqlx(e)
308    }
309}
310
311impl From<UserError> for super::LoginError {
312    #[inline]
313    fn from(_: UserError) -> Self {
314        Self::WrongUsernameOrPassword
315    }
316}
317impl From<password_hash::Error> for super::LoginError {
318    #[inline]
319    fn from(_: password_hash::Error) -> Self {
320        Self::WrongUsernameOrPassword
321    }
322}
323impl From<axum_login::Error<DBBackend>> for super::LoginError {
324    fn from(_: axum_login::Error<DBBackend>) -> Self {
325        Self::InternalError
326    }
327}