diff --git a/app/__init__.py b/app/__init__.py
index aaf0713..969e8ed 100644
--- a/app/__init__.py
+++ b/app/__init__.py
@@ -3,6 +3,7 @@ from flask_sqlalchemy import SQLAlchemy
from flask_migrate import Migrate
from flask_login import LoginManager
from flask_mail import Mail
+from flask_wtf.csrf import CSRFProtect
from config import Config
app = Flask(__name__)
@@ -15,11 +16,13 @@ if Config.SECRET_KEY == "a-random-secret-key":
db = SQLAlchemy(app)
migrate = Migrate(app, db)
mail = Mail(app)
+csrf = CSRFProtect(app)
login = LoginManager(app)
login.login_view = 'login'
login.login_message = "Veuillez vous authentifier avant de continuer."
+
# Register converters (needed for routing)
from app.utils.converters import *
app.url_map.converters['forum'] = ForumConverter
diff --git a/app/routes/account/login.py b/app/routes/account/login.py
index 058219f..b00293b 100644
--- a/app/routes/account/login.py
+++ b/app/routes/account/login.py
@@ -7,6 +7,7 @@ from app.models.user import Member
from app.models.priv import Group
from app.utils.render import render
from app.utils.send_mail import send_validation_mail
+from app.utils.check_csrf import check_csrf
import datetime
@@ -68,6 +69,7 @@ def login():
@app.route('/deconnexion')
@login_required
+@check_csrf
def logout():
logout_user()
flash('Déconnexion réussie', 'info')
diff --git a/app/templates/base/navbar/account.html b/app/templates/base/navbar/account.html
index d34cea4..000e0a7 100644
--- a/app/templates/base/navbar/account.html
+++ b/app/templates/base/navbar/account.html
@@ -36,7 +36,7 @@
Paramètres
-
+
Déconnexion
diff --git a/app/utils/check_csrf.py b/app/utils/check_csrf.py
new file mode 100644
index 0000000..971d572
--- /dev/null
+++ b/app/utils/check_csrf.py
@@ -0,0 +1,18 @@
+from functools import wraps
+from flask import request, abort
+from flask_wtf import csrf
+from wtforms.validators import ValidationError
+from app import app
+
+def check_csrf(func):
+ """
+ Check csrf_token GET parameter
+ """
+ @wraps(func)
+ def wrapped(*args, **kwargs):
+ try:
+ csrf.validate_csrf(request.args.get('csrf_token'))
+ except ValidationError:
+ abort(404)
+ return func(*args, **kwargs)
+ return wrapped