shop.tonybtw.com

shop.tonybtw.com

https://git.tonybtw.com/shop.tonybtw.com.git git://git.tonybtw.com/shop.tonybtw.com.git
1,210 bytes raw
1
package lib
2
3
import (
	"crypto/rand"
	"crypto/subtle"
	"encoding/base64"
	"net/http"
	"sync"
)
9
10
// csrf_tokens maps a session ID to the CSRF token issued for that session.
// It lives only in process memory.
var csrf_tokens = make(map[string]string)

// csrf_mutex must be held (read or write, as appropriate) around every
// access to csrf_tokens.
var csrf_mutex sync.RWMutex
14
15
// Generate_CSRF_Token returns a new random CSRF token: 32 bytes read from
// crypto/rand, encoded with the padded URL-safe base64 alphabet (44
// characters).
//
// It panics if the random source cannot be read. The function has no error
// result, and silently continuing would encode a zeroed or partially filled
// buffer into a predictable token.
func Generate_CSRF_Token() string {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		panic("csrf: reading random bytes: " + err.Error())
	}
	return base64.URLEncoding.EncodeToString(b)
}
20
21
func Get_CSRF_Token(session_id string) string {
22
	csrf_mutex.RLock()
23
	token, exists := csrf_tokens[session_id]
24
	csrf_mutex.RUnlock()
25
26
	if !exists {
27
		token = Generate_CSRF_Token()
28
		csrf_mutex.Lock()
29
		csrf_tokens[session_id] = token
30
		csrf_mutex.Unlock()
31
	}
32
33
	return token
34
}
35
36
func Verify_CSRF_Token(session_id, token string) bool {
37
	csrf_mutex.RLock()
38
	expected := csrf_tokens[session_id]
39
	csrf_mutex.RUnlock()
40
41
	return token == expected && token != ""
42
}
43
44
func Require_CSRF(next http.HandlerFunc) http.HandlerFunc {
45
	return func(w http.ResponseWriter, r *http.Request) {
46
		if r.Method == "POST" {
47
			session_id := Get_Session_ID(r)
48
			if err := r.ParseForm(); err != nil {
49
				http.Error(w, "Invalid form data", http.StatusBadRequest)
50
				return
51
			}
52
53
			token := r.FormValue("csrf_token")
54
			if !Verify_CSRF_Token(session_id, token) {
55
				http.Error(w, "Invalid CSRF token", http.StatusForbidden)
56
				return
57
			}
58
		}
59
		next(w, r)
60
	}
61
}