diff --git a/.env b/.env index c9e35841a..a4c685393 100644 --- a/.env +++ b/.env @@ -13,13 +13,14 @@ POSTGRES_URL_NON_POOLING="postgresql://postgres:postgres@localhost:5432/dsek?sch # AUTH # Auth.js uses the following environment variables. AUTH_SECRET="4e0b5eed97d12748be91415ac2716b9e91deb57198c7b3662afe7f1649089b54" # required, generate a 32-bit random string, for example with openssl rand -hex 32 -AUTH_TRUST_HOST=true # set to true for authentication in build environment +AUTH_TRUST_HOST=true # set to true for authentication in build environment AUTH_AUTHENTIK_CLIENT_ID="SvRybUTCGqhNiw2Y3gn1wqt0YxpjW2sv9fbPsUaP" AUTH_AUTHENTIK_CLIENT_SECRET="" # public dev client, only allows redirect to localhost PUBLIC_AUTH_AUTHENTIK_ISSUER="https://auth.dsek.se/application/o/dev/" +PUBLIC_AUTH_AUTHENTIK_TOKEN_ENDPOINT="https://auth.dsek.se/application/o/token/" # AUTHENTIK -# Used to connect to the Authentik server to keep the +# Used to connect to the Authentik server to keep the # roles and permissions in sync with the webpage. AUTHENTIK_API_TOKEN= AUTHENTIK_ENDPOINT=https://auth.dsek.se/api/v3 @@ -28,8 +29,8 @@ AUTHENTIK_ENABLED=false # set to false to avoid syncing with authentik # FILE STORAGE # Used to connect to the MinIO file server. # Different types of files are stored in different buckets. -MINIO_ROOT_USER= # <| -MINIO_ROOT_PASSWORD= # <| Do not forget to put the values here! Otherwise it wont work! +MINIO_ROOT_USER= # <| +MINIO_ROOT_PASSWORD= # <| Do not forget to put the values here! Otherwise it wont work! PUBLIC_MINIO_ENDPOINT=files-sandbox.dsek.se PUBLIC_MINIO_PORT=443 PUBLIC_MINIO_USE_SSL=true @@ -94,4 +95,8 @@ BOOKKEEPING_EMAIL_TO_ADDRESS=bookkeeping@example.com BOOKKEEPING_EMAIL_FROM_ADDRESS=automatic-expensing@dsek.se BOOKKEEPING_CC_TO_ADDRESS=skattm@dsek.se # comma separated list -SYNC_PASSWORD=password123 \ No newline at end of file +SYNC_PASSWORD=password123 + +# SCHEDULER +SCHEDULER_ENDPOINT=http://localhost:8080/schedule +SCHEDULER_PASSWORD=supersecretpassword diff --git a/scheduler-service/.env b/scheduler-service/.env new file mode 100644 index 000000000..98f0c3083 --- /dev/null +++ b/scheduler-service/.env @@ -0,0 +1,13 @@ +POSTGRES_HOST=localhost +POSTGRES_USER=postgres +POSTGRES_PASSWORD=postgres +POSTGRES_DB=postgres +POSTGRES_PORT=5431 + +PASSWORD=supersecretpassword + +SERVER_PORT=8080 + +JWT_ISSUER=https://auth.dsek.se/application/o/dev/ +JWT_AUDIENCE=SvRybUTCGqhNiw2Y3gn1wqt0YxpjW2sv9fbPsUaP +JWKS_ENDPOINT=https://auth.dsek.se/application/o/dev/jwks/ diff --git a/scheduler-service/authMiddleware.go b/scheduler-service/authMiddleware.go new file mode 100644 index 000000000..02191dccc --- /dev/null +++ b/scheduler-service/authMiddleware.go @@ -0,0 +1,96 @@ +package main + +import ( + "context" + "log" + "log/slog" + "net/http" + "os" + "strings" + "time" + + "github.com/lestrrat-go/httprc/v3" + "github.com/lestrrat-go/httprc/v3/tracesink" + "github.com/lestrrat-go/jwx/v3/jwk" + "github.com/lestrrat-go/jwx/v3/jwt" +) + +var cachedJWKS jwk.Set + +func AuthMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if cachedJWKS == nil { + if err := createJWKCache(); err != nil { + log.Printf("Failed to create JWK cache: %s", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + + return + } + } + + parseOptions := []jwt.ParseOption{ + jwt.WithKeySet(cachedJWKS), + jwt.WithIssuer(JWTIssuer), + jwt.WithAudience(JWTAudience), + } + + if _, err := jwt.Parse([]byte(getTokenFromHeader(r)), parseOptions...); err != nil { + log.Printf("Failed to parse JWT: %s", err) + http.Error(w, "Unauthorized", http.StatusUnauthorized) + + return + } + + next.ServeHTTP(w, r) + }) +} + +func getTokenFromHeader(r *http.Request) string { + stringToken := r.Header.Get("Authorization") + + const bearerPrefix = "Bearer " + if !strings.HasPrefix(stringToken, bearerPrefix) { + return "" + } + + stringToken = strings.TrimPrefix(stringToken, bearerPrefix) + + return stringToken +} + +func createJWKCache() error { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + jwkCache, err := jwk.NewCache( + ctx, + httprc.NewClient( + httprc.WithTraceSink(tracesink.NewSlog(slog.New(slog.NewJSONHandler(os.Stderr, nil)))), + ), + ) + if err != nil { + log.Printf("Failed to create JWK cache: %s", err) + + return err + } + + if err = jwkCache.Register( + ctx, + JWKSEndpoint, + jwk.WithMaxInterval(24*time.Hour*7), + jwk.WithMinInterval(5*time.Minute), + ); err != nil { + log.Printf("Failed to register JWK endpoint: %s", err) + + return err + } + + cachedJWKS, err = jwkCache.CachedSet(JWKSEndpoint) + if err != nil { + log.Printf("Failed to get cached JWK set: %s", err) + + return err + } + + return nil +} diff --git a/scheduler-service/databaseHandler.go b/scheduler-service/databaseHandler.go new file mode 100644 index 000000000..7a366663a --- /dev/null +++ b/scheduler-service/databaseHandler.go @@ -0,0 +1,21 @@ +package main + +import ( + "fmt" + "os" + + "gorm.io/driver/postgres" + "gorm.io/gorm" +) + +func openDatabaseConnection(db **gorm.DB) error { + host, user, password, name, port := os.Getenv("POSTGRES_HOST"), os.Getenv("POSTGRES_USER"), os.Getenv("POSTGRES_PASSWORD"), os.Getenv("POSTGRES_DB"), os.Getenv("POSTGRES_PORT") + dsn := fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%s sslmode=disable TimeZone=UTC", host, user, password, name, port) + + var err error + if *db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{}); err != nil { + return err + } + + return nil +} diff --git a/scheduler-service/go.mod b/scheduler-service/go.mod new file mode 100644 index 000000000..fb22bc657 --- /dev/null +++ b/scheduler-service/go.mod @@ -0,0 +1,35 @@ +module github.com/Dsek-LTH/scheduler + +go 1.25.4 + +require ( + github.com/joho/godotenv v1.5.1 + github.com/lestrrat-go/httprc/v3 v3.0.1 + github.com/lestrrat-go/jwx/v3 v3.0.12 + golang.org/x/time v0.14.0 + gorm.io/driver/postgres v1.6.0 + gorm.io/gorm v1.31.1 +) + +require ( + github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 // indirect + github.com/goccy/go-json v0.10.3 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/pgx/v5 v5.6.0 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + github.com/lestrrat-go/blackmagic v1.0.4 // indirect + github.com/lestrrat-go/dsig v1.0.0 // indirect + github.com/lestrrat-go/dsig-secp256k1 v1.0.0 // indirect + github.com/lestrrat-go/httpcc v1.0.1 // indirect + github.com/lestrrat-go/option v1.0.1 // indirect + github.com/lestrrat-go/option/v2 v2.0.0 // indirect + github.com/segmentio/asm v1.2.1 // indirect + github.com/valyala/fastjson v1.6.4 // indirect + golang.org/x/crypto v0.43.0 // indirect + golang.org/x/sync v0.17.0 // indirect + golang.org/x/sys v0.37.0 // indirect + golang.org/x/text v0.30.0 // indirect +) diff --git a/scheduler-service/go.sum b/scheduler-service/go.sum new file mode 100644 index 000000000..5fb660e26 --- /dev/null +++ b/scheduler-service/go.sum @@ -0,0 +1,68 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 h1:NMZiJj8QnKe1LgsbDayM4UoHwbvwDRwnI3hwNaAHRnc= +github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0/go.mod h1:ZXNYxsqcloTdSy/rNShjYzMhyjf0LaoftYK0p+A3h40= +github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= +github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY= +github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/lestrrat-go/blackmagic v1.0.4 h1:IwQibdnf8l2KoO+qC3uT4OaTWsW7tuRQXy9TRN9QanA= +github.com/lestrrat-go/blackmagic v1.0.4/go.mod h1:6AWFyKNNj0zEXQYfTMPfZrAXUWUfTIZ5ECEUEJaijtw= +github.com/lestrrat-go/dsig v1.0.0 h1:OE09s2r9Z81kxzJYRn07TFM9XA4akrUdoMwr0L8xj38= +github.com/lestrrat-go/dsig v1.0.0/go.mod h1:dEgoOYYEJvW6XGbLasr8TFcAxoWrKlbQvmJgCR0qkDo= +github.com/lestrrat-go/dsig-secp256k1 v1.0.0 h1:JpDe4Aybfl0soBvoVwjqDbp+9S1Y2OM7gcrVVMFPOzY= +github.com/lestrrat-go/dsig-secp256k1 v1.0.0/go.mod h1:CxUgAhssb8FToqbL8NjSPoGQlnO4w3LG1P0qPWQm/NU= +github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZrIE= +github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E= +github.com/lestrrat-go/httprc/v3 v3.0.1 h1:3n7Es68YYGZb2Jf+k//llA4FTZMl3yCwIjFIk4ubevI= +github.com/lestrrat-go/httprc/v3 v3.0.1/go.mod h1:2uAvmbXE4Xq8kAUjVrZOq1tZVYYYs5iP62Cmtru00xk= +github.com/lestrrat-go/jwx/v3 v3.0.12 h1:p25r68Y4KrbBdYjIsQweYxq794CtGCzcrc5dGzJIRjg= +github.com/lestrrat-go/jwx/v3 v3.0.12/go.mod h1:HiUSaNmMLXgZ08OmGBaPVvoZQgJVOQphSrGr5zMamS8= +github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU= +github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I= +github.com/lestrrat-go/option/v2 v2.0.0 h1:XxrcaJESE1fokHy3FpaQ/cXW8ZsIdWcdFzzLOcID3Ss= +github.com/lestrrat-go/option/v2 v2.0.0/go.mod h1:oSySsmzMoR0iRzCDCaUfsCzxQHUEuhOViQObyy7S6Vg= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/segmentio/asm v1.2.1 h1:DTNbBqs57ioxAD4PrArqftgypG4/qNpXoJx8TVXxPR0= +github.com/segmentio/asm v1.2.1/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/valyala/fastjson v1.6.4 h1:uAUNq9Z6ymTgGhcm0UynUAB6tlbakBrz6CQFax3BXVQ= +github.com/valyala/fastjson v1.6.4/go.mod h1:CLCAqky6SMuOcxStkYQvblddUtoRxhYMGLrsQns1aXY= +golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04= +golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0= +golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= +golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= +golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k= +golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM= +golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= +golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/postgres v1.6.0 h1:2dxzU8xJ+ivvqTRph34QX+WrRaJlmfyPqXmoGVjMBa4= +gorm.io/driver/postgres v1.6.0/go.mod h1:vUw0mrGgrTK+uPHEhAdV4sfFELrByKVGnaVRkXDhtWo= +gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg= +gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs= diff --git a/scheduler-service/main.go b/scheduler-service/main.go new file mode 100644 index 000000000..215e0f0ce --- /dev/null +++ b/scheduler-service/main.go @@ -0,0 +1,50 @@ +package main + +import ( + "context" + "fmt" + "log" + "net/http" + "os" + + "github.com/joho/godotenv" + "gorm.io/gorm" +) + +var ( + db *gorm.DB + JWKSEndpoint string + JWTIssuer string + JWTAudience string +) + +func main() { + if err := godotenv.Load(); err != nil { + log.Println("Error loading .env file") + } + + JWKSEndpoint = os.Getenv("JWKS_ENDPOINT") + JWTIssuer = os.Getenv("JWT_ISSUER") + JWTAudience = os.Getenv("JWT_AUDIENCE") + + if err := openDatabaseConnection(&db); err != nil { + log.Fatal("Failed to connect to database:", err) + } + + if err := db.AutoMigrate(&ScheduledTask{}); err != nil { + log.Fatal("Failed to migrate database:", err) + } + + if scheduledTasks, err := gorm.G[ScheduledTask](db).Where("has_executed = ?", false).Find(context.Background()); err != nil { + log.Println("Error fetching scheduled tasks:", err) + } else { + for _, task := range scheduledTasks { + go scheduleTaskExecution(context.Background(), task) + } + } + + http.HandleFunc("/schedule", handleRequest) + + log.Printf("Server running on :%s", os.Getenv("SERVER_PORT")) + log.Fatal(http.ListenAndServe(fmt.Sprintf(":%s", os.Getenv("SERVER_PORT")), nil)) +} diff --git a/scheduler-service/rateLimitMiddleware.go b/scheduler-service/rateLimitMiddleware.go new file mode 100644 index 000000000..619e2f479 --- /dev/null +++ b/scheduler-service/rateLimitMiddleware.go @@ -0,0 +1,69 @@ +package main + +import ( + "log" + "net" + "net/http" + "sync" + "time" + + "golang.org/x/time/rate" +) + +var ( + limiters = make(map[string]*trackedLimiter) + mu sync.Mutex +) + +type trackedLimiter struct { + *rate.Limiter + lastSeen time.Time +} + +func cleanupLimiters(expiry time.Duration) { + now := time.Now() + for ip, limiter := range limiters { + if now.Sub(limiter.lastSeen) > expiry { + delete(limiters, ip) + } + } +} + +func getLimiter(ip string) *rate.Limiter { + mu.Lock() + defer mu.Unlock() + + cleanupLimiters(10 * time.Minute) + log.Println("Current limiters:", len(limiters)) + + lim, exists := limiters[ip] + if !exists { + lim = &trackedLimiter{ + Limiter: rate.NewLimiter(1, 5), + lastSeen: time.Now(), + } + limiters[ip] = lim + } else { + lim.lastSeen = time.Now() + } + + return lim.Limiter +} + +func rateLimitMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + host = r.RemoteAddr + } + + limiter := getLimiter(host) + if !limiter.Allow() { + http.Error(w, "Too Many Requests", http.StatusTooManyRequests) + + return + } + + next.ServeHTTP(w, r) + }) +} diff --git a/scheduler-service/requestHandler.go b/scheduler-service/requestHandler.go new file mode 100644 index 000000000..98c044375 --- /dev/null +++ b/scheduler-service/requestHandler.go @@ -0,0 +1,130 @@ +package main + +import ( + "context" + "encoding/json" + "log" + "net/http" + "os" + + "github.com/lestrrat-go/jwx/v3/jwt" + "gorm.io/gorm" +) + +type ScheduleTaskRequestData struct { + RunTimestamp string `json:"runTimestamp"` + EndpointURL string `json:"endpointURL"` + Body string `json:"body"` + Password string `json:"password"` + Token string `json:"token,omitempty"` +} + +func handleRequest(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodPost: + rateLimitMiddleware(http.HandlerFunc(handlePost)).ServeHTTP(w, r) + + case http.MethodGet: + rateLimitMiddleware(AuthMiddleware(http.HandlerFunc(handleGet))).ServeHTTP(w, r) + + default: + http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed) + } +} + +func handlePost(w http.ResponseWriter, r *http.Request) { + var data ScheduleTaskRequestData + + if err := json.NewDecoder(r.Body).Decode(&data); err != nil { + http.Error(w, "Invalid JSON", http.StatusBadRequest) + + return + } + + if data.Password != os.Getenv("PASSWORD") { + log.Printf("Unauthorised access attempt from %s with password: %s", r.RemoteAddr, data.Password) + http.Error(w, "Unauthorized", http.StatusUnauthorized) + + return + } + + var subject string + if data.Token != "" { + parseOptions := []jwt.ParseOption{ + jwt.WithVerify(false), + jwt.WithIssuer(JWTIssuer), + jwt.WithAudience(JWTAudience), + } + + if token, err := jwt.Parse([]byte(data.Token), parseOptions...); err != nil { + log.Printf("Failed to parse JWT: %s", err) + http.Error(w, "Unauthorized", http.StatusUnauthorized) + + return + } else { + subject, _ = token.Subject() + } + } + + newTask := ScheduledTask{ + RunTimestamp: data.RunTimestamp, + EndpointURL: data.EndpointURL, + Body: data.Body, + HasExecuted: false, + CreatedBy: &subject, + } + + if err := gorm.G[ScheduledTask](db).Create(r.Context(), &newTask); err != nil { + http.Error(w, "Failed to write to database", http.StatusInternalServerError) + + return + } else { + log.Printf("Scheduled task created: %+v", struct { + RunTimestamp string + EndpointURL string + Body string + }{ + RunTimestamp: newTask.RunTimestamp, + EndpointURL: newTask.EndpointURL, + Body: newTask.Body, + }) + } + + scheduleTaskExecution(context.Background(), newTask) + + w.WriteHeader(http.StatusCreated) +} + +func handleGet(w http.ResponseWriter, r *http.Request) { + query := r.URL.Query() + password := query.Get("password") + + if password != os.Getenv("PASSWORD") { + log.Printf("Unauthorised access attempt from %s with password: %s", r.RemoteAddr, password) + http.Error(w, "Unauthorized", http.StatusUnauthorized) + + return + } + + parsedToken, err := jwt.Parse([]byte(getTokenFromHeader(r)), jwt.WithVerify(false)) + if err != nil { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + + return + } + subject, _ := parsedToken.Subject() + + tasks, err := gorm.G[ScheduledTask](db). + Where("created_by = ? AND has_executed = ?", subject, false). + Find(r.Context()) + if err != nil { + http.Error(w, "Failed to read from database", http.StatusInternalServerError) + + return + } + + w.Header().Set("Content-Type", "application/json") + if err = json.NewEncoder(w).Encode(tasks); err != nil { + log.Printf("Error encoding response: %v", err) + } +} diff --git a/scheduler-service/taskHandler.go b/scheduler-service/taskHandler.go new file mode 100644 index 000000000..413614c9e --- /dev/null +++ b/scheduler-service/taskHandler.go @@ -0,0 +1,104 @@ +package main + +import ( + "bytes" + "context" + "encoding/json" + "log" + "net/http" + "os" + "time" + + "gorm.io/gorm" +) + +type ScheduledTask struct { + gorm.Model + RunTimestamp string + EndpointURL string + Body string + HasExecuted bool + CreatedBy *string +} + +func scheduleTaskExecution(ctx context.Context, task ScheduledTask) { + runTime, err := time.Parse(time.RFC3339, task.RunTimestamp) + if err != nil { + log.Printf("Failed to parse RunTimestamp for task ID %d: %v", task.ID, err) + + return + } + + delay := time.Until(runTime) + + if delay <= 0 { + log.Printf("RunTimestamp for task ID %d is in the past. Executing immediately.", task.ID) + go executeTask(ctx, task) + + return + } + + log.Printf("Scheduling task ID %d to run at %s", task.ID, runTime.Format(time.RFC1123)) + + time.AfterFunc(delay, func() { + executeTask(ctx, task) + }) +} + +// TODO: Decide how to handle failures/retries +func executeTask(ctx context.Context, task ScheduledTask) { + log.Printf("Executing task ID: %d to %s", task.ID, task.EndpointURL) + + var bodyMap map[string]any + if err := json.Unmarshal([]byte(task.Body), &bodyMap); err != nil { + log.Printf("Error unmarshalling body for task ID %d: %v", task.ID, err) + + return + } + bodyMap["password"] = os.Getenv("PASSWORD") + bodyBytes, err := json.Marshal(bodyMap) + if err != nil { + log.Printf("Error marshalling body for task ID %d: %v", task.ID, err) + + return + } + req, err := http.NewRequest(http.MethodPost, task.EndpointURL, bytes.NewBuffer(bodyBytes)) + if err != nil { + log.Printf("Error creating request for task ID %d: %v", task.ID, err) + + return + } + + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{ + Timeout: 30 * time.Second, + } + + resp, err := client.Do(req) + if err != nil { + log.Printf("Error executing request for task ID %d: %v", task.ID, err) + + return + } + defer func() { + _ = resp.Body.Close() + }() + + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + log.Printf("Task ID %d executed successfully. Status: %d", task.ID, resp.StatusCode) + } else { + log.Printf("Task ID %d executed with non-success status: %d", task.ID, resp.StatusCode) + + return + } + + setTaskExecuted(ctx, task.ID) +} + +// TODO: Decide how to handle failures/retries +func setTaskExecuted(ctx context.Context, taskID uint) { + if _, err := gorm.G[ScheduledTask](db).Where("id = ?", taskID).Update(ctx, "has_executed", true); err != nil { + log.Printf("Failed to update database for task ID %d: %v", taskID, err) + } +} diff --git a/src/hooks.server.ts b/src/hooks.server.ts index 2d66687c3..9eaae3cc8 100644 --- a/src/hooks.server.ts +++ b/src/hooks.server.ts @@ -52,10 +52,13 @@ const { handle: authHandle } = SvelteKitAuth({ group_list: profile.groups, }; }, + authorization: { + params: { scope: "openid profile email offline_access" }, + }, }), ], callbacks: { - jwt({ token, user }) { + async jwt({ token, user, account }) { if (user) { token.student_id = user.student_id; token.group_list = user.group_list ?? []; @@ -63,16 +66,67 @@ const { handle: authHandle } = SvelteKitAuth({ token.family_name = user.family_name; token.email = user.email; } + + if (account) { + token.refresh_token = account.refresh_token; + token.id_token = account.id_token; + token.expires_at = account.expires_at; + } else if (token.expires_at && Date.now() < token.expires_at * 1000) { + return token; + } else { + if (!token.refresh_token) throw new Error("Missing refresh_token"); + + try { + const response = await fetch( + envPublic.PUBLIC_AUTH_AUTHENTIK_TOKEN_ENDPOINT, + { + method: "POST", + body: new URLSearchParams({ + client_id: env.AUTH_AUTHENTIK_CLIENT_ID, + client_secret: env.AUTH_AUTHENTIK_CLIENT_SECRET, + grant_type: "refresh_token", + refresh_token: token.refresh_token, + }), + }, + ); + + const tokensOrError = await response.json(); + + if (!response.ok) throw tokensOrError; + + token.id_token = tokensOrError.id_token; + token.expires_at = + Math.floor(Date.now() / 1000) + tokensOrError.expires_in; + token.refresh_token = + tokensOrError.refresh_token ?? token.refresh_token; + + return token; + } catch (error) { + console.error("Error refreshing Authentik access_token:", error); + token.error = "RefreshTokenError"; + + return token; + } + } + return token; }, session({ session, token }) { - if (token && session?.user) { - session.user.student_id = token.student_id; - session.user.email = token.email ?? ""; - session.user.group_list = token.group_list; - session.user.given_name = token.given_name; - session.user.family_name = token.family_name; + if (token) { + if (session?.user) { + session.user.student_id = token.student_id; + session.user.email = token.email ?? ""; + session.user.group_list = token.group_list; + session.user.given_name = token.given_name; + session.user.family_name = token.family_name; + } + + session.error = token.error; + if (session.error) { + throw redirect(302, "/signout"); + } } + return session; }, /** @@ -96,7 +150,7 @@ const { handle: authHandle } = SvelteKitAuth({ }); const databaseHandle: Handle = async ({ event, resolve }) => { - const session = await event.locals.getSession(); + const session = await event.locals.auth(); const studentId = session?.user.student_id; const aClient = authorizedPrismaClient; @@ -120,7 +174,6 @@ const databaseHandle: Handle = async ({ event, resolve }) => { ) ?? sourceLanguageTag; event.locals.language = lang; setLanguageTag(lang); - const prisma = getExtendedPrismaClient(lang, session?.user.student_id); if (!session?.user) { diff --git a/src/lib/news/schema.ts b/src/lib/news/schema.ts index dbffa9274..f29dd5251 100644 --- a/src/lib/news/schema.ts +++ b/src/lib/news/schema.ts @@ -20,6 +20,7 @@ export const articleSchema = z.object({ imageUrls: z.string().array().optional(), imageUrl: z.string().optional().nullable(), youtubeUrl: z.string().optional().nullable(), + publishTime: z.date().optional().nullable(), // https://github.com/colinhacks/zod/pull/3118 images: z .instanceof(File, { message: "Please upload a file." }) diff --git a/src/lib/news/server/actions.ts b/src/lib/news/server/actions.ts index da9d400dd..5779b173e 100644 --- a/src/lib/news/server/actions.ts +++ b/src/lib/news/server/actions.ts @@ -2,8 +2,6 @@ import { PUBLIC_BUCKETS_FILES } from "$env/static/public"; import { uploadFile } from "$lib/files/uploadFiles"; import { createSchema, updateSchema } from "$lib/news/schema"; import authorizedPrismaClient from "$lib/server/authorizedPrisma"; -import sendNotification from "$lib/utils/notifications"; -import { NotificationType } from "$lib/utils/notifications/types"; import { redirect } from "$lib/utils/redirect"; import { slugWithCount, slugify } from "$lib/utils/slugify"; import * as m from "$paraglide/messages"; @@ -13,8 +11,8 @@ import type { AuthUser } from "@zenstackhq/runtime"; import { zod } from "sveltekit-superforms/adapters"; import { message, superValidate, fail } from "sveltekit-superforms"; import DOMPurify from "isomorphic-dompurify"; -import { markdownToTxt } from "markdown-to-txt"; -import type { ExtendedPrismaModel } from "$lib/server/extendedPrisma"; +import { sendNewArticleNotification } from "./notifications"; +import { scheduleExecution } from "$lib/server/scheduleExecution"; const uploadImage = async (user: AuthUser, image: File, slug: string) => { const randomName = (Math.random() + 1).toString(36).substring(2); @@ -33,42 +31,6 @@ const uploadImage = async (user: AuthUser, image: File, slug: string) => { return imageUrl; }; -const sendNewArticleNotification = async ( - article: ExtendedPrismaModel<"Article"> & { - tags: Array, "id">>; - author: ExtendedPrismaModel<"Author">; - }, - notificationText: string | null | undefined, -) => { - console.log("notifications: getting members"); - const subscribedMembers = await authorizedPrismaClient.member.findMany({ - where: { - subscribedTags: { - some: { - id: { - in: article.tags.map(({ id }) => id), - }, - }, - }, - }, - select: { - id: true, - }, - }); - - console.log("notifications: sending"); - await sendNotification({ - title: article.header, - message: notificationText - ? notificationText - : markdownToTxt(article.body).slice(0, 254), - type: NotificationType.NEW_ARTICLE, - link: `/news/${article.slug}`, - fromAuthor: article.author, - memberIds: subscribedMembers.map(({ id }) => id), - }); -}; - export const createArticle: Action = async (event) => { const { request, locals } = event; const { prisma, user } = locals; @@ -86,6 +48,7 @@ export const createArticle: Action = async (event) => { images, bodySv, bodyEn, + publishTime, ...rest } = form.data; const existingAuthor = await prisma.author.findFirst({ @@ -110,50 +73,65 @@ export const createArticle: Action = async (event) => { await Promise.resolve(); rest.imageUrls = await Promise.all(tasks); - const result = await prisma.article.create({ - data: { - slug, - headerSv: headerSv, - headerEn: headerEn, - bodySv: DOMPurify.sanitize(bodySv), - bodyEn: bodyEn ? DOMPurify.sanitize(bodyEn) : bodyEn, - ...rest, - author: { - connect: existingAuthor - ? { - id: existingAuthor.id, - } - : undefined, - create: !existingAuthor - ? { - member: { - connect: { studentId: user?.studentId }, - }, - mandate: author.mandateId - ? { - connect: { - member: { studentId: user?.studentId }, - id: author.mandateId, - }, - } - : undefined, - customAuthor: author.customId - ? { - connect: { id: author.customId }, - } - : undefined, - } - : undefined, - }, - tags: { - connect: tags - .filter((tag) => !!tag) - .map((tag) => ({ - id: tag.id, - })), - }, - publishedAt: new Date(), + const data = { + slug, + headerSv: headerSv, + headerEn: headerEn, + bodySv: DOMPurify.sanitize(bodySv), + bodyEn: bodyEn ? DOMPurify.sanitize(bodyEn) : bodyEn, + ...rest, + author: { + connect: existingAuthor + ? { + id: existingAuthor.id, + } + : undefined, + create: !existingAuthor + ? { + member: { + connect: { studentId: user?.studentId }, + }, + mandate: author.mandateId + ? { + connect: { + member: { studentId: user?.studentId }, + id: author.mandateId, + }, + } + : undefined, + customAuthor: author.customId + ? { + connect: { id: author.customId }, + } + : undefined, + } + : undefined, }, + tags: { + connect: tags + .filter((tag) => !!tag) + .map((tag) => ({ + id: tag.id, + })), + }, + publishedAt: publishTime ?? new Date(), + }; + + if (publishTime && publishTime > new Date()) { + return await scheduleExecution( + request, + data, + publishTime, + form, + m.news_errors_schedulingFailed(), + m.news_articleScheduled(), + "/news", + event, + ); + } + + const result = await prisma.article.create({ + data, include: { author: true, }, @@ -180,6 +158,7 @@ export const createArticle: Action = async (event) => { event, ); }; + export const updateArticle: Action<{ slug: string }> = async (event) => { const { request, locals } = event; const { prisma, user } = locals; diff --git a/src/lib/news/server/notifications.ts b/src/lib/news/server/notifications.ts new file mode 100644 index 000000000..454268d82 --- /dev/null +++ b/src/lib/news/server/notifications.ts @@ -0,0 +1,41 @@ +import authorizedPrismaClient from "$lib/server/authorizedPrisma"; +import type { ExtendedPrismaModel } from "$lib/server/extendedPrisma"; +import sendNotification from "$lib/utils/notifications"; +import { NotificationType } from "$lib/utils/notifications/types"; +import markdownToTxt from "markdown-to-txt"; + +export const sendNewArticleNotification = async ( + article: ExtendedPrismaModel<"Article"> & { + tags: Array, "id">>; + author: ExtendedPrismaModel<"Author">; + }, + notificationText: string | null | undefined, +) => { + console.log("notifications: getting members"); + const subscribedMembers = await authorizedPrismaClient.member.findMany({ + where: { + subscribedTags: { + some: { + id: { + in: article.tags.map(({ id }) => id), + }, + }, + }, + }, + select: { + id: true, + }, + }); + + console.log("notifications: sending"); + await sendNotification({ + title: article.header, + message: notificationText + ? notificationText + : markdownToTxt(article.body).slice(0, 254), + type: NotificationType.NEW_ARTICLE, + link: `/news/${article.slug}`, + fromAuthor: article.author, + memberIds: subscribedMembers.map(({ id }) => id), + }); +}; diff --git a/src/lib/server/getDecryptedJWT.ts b/src/lib/server/getDecryptedJWT.ts new file mode 100644 index 000000000..294485661 --- /dev/null +++ b/src/lib/server/getDecryptedJWT.ts @@ -0,0 +1,6 @@ +import { env } from "$env/dynamic/private"; +import { getToken, type JWT } from "@auth/core/jwt"; + +export const getDecryptedJWT = async (req: Request): Promise => { + return await getToken({ req, secret: env.AUTH_SECRET }); +}; diff --git a/src/lib/server/scheduleExecution.ts b/src/lib/server/scheduleExecution.ts new file mode 100644 index 000000000..b4253bf16 --- /dev/null +++ b/src/lib/server/scheduleExecution.ts @@ -0,0 +1,53 @@ +import { redirect } from "$lib/utils/redirect"; +import { getDecryptedJWT } from "$lib/server/getDecryptedJWT"; +import { env } from "$env/dynamic/private"; +import { type SuperValidated, fail } from "sveltekit-superforms"; +import type { RequestEvent } from "@sveltejs/kit"; + +export const scheduleExecution = async ( + request: Request, + data: Record, + publishTime: Date, + form: SuperValidated>, + errorMessage: string, + successMessage: string, + redirectEndpoint: string, + event: RequestEvent, +) => { + const jwt = await getDecryptedJWT(request); + let result; + try { + result = await fetch(env.SCHEDULER_ENDPOINT, { + method: "POST", + body: JSON.stringify({ + body: JSON.stringify(data), + endpointURL: request.url, + runTimestamp: publishTime, + password: env.SCHEDULER_PASSWORD, + token: jwt?.["id_token"], + }), + headers: { "Content-Type": "application/json" }, + }); + } catch (error) { + return fail(500, { + form, + message: `${errorMessage}: ${error}`, + }); + } + + if (!result.ok) { + return fail(500, { + form, + message: errorMessage, + }); + } + + throw redirect( + redirectEndpoint, + { + message: successMessage, + type: "success", + }, + event, + ); +}; diff --git a/src/routes/(app)/api/schedule/news/+server.ts b/src/routes/(app)/api/schedule/news/+server.ts new file mode 100644 index 000000000..e10b435a0 --- /dev/null +++ b/src/routes/(app)/api/schedule/news/+server.ts @@ -0,0 +1,34 @@ +import { env } from "$env/dynamic/private"; +import { sendNewArticleNotification } from "$lib/news/server/notifications"; +import type { RequestHandler } from "@sveltejs/kit"; + +export const POST: RequestHandler = async ({ request, locals }) => { + const body = await request.json(); + const { + password, + sendNotification: shouldSendNotification, + ...newsItem + } = body; + if (!password || password !== env.SCHEDULER_PASSWORD) { + return new Response("Unauthorized", { status: 401 }); + } + + const { prisma } = locals; + const result = await prisma.article.create({ + data: newsItem, + include: { author: true }, + }); + + if (shouldSendNotification) { + console.log("send notifications"); + await sendNewArticleNotification( + { + ...result, + tags: newsItem.tags, + }, + newsItem.notificationText, + ); + } + + return new Response("Article created", { status: 201 }); +}; diff --git a/src/routes/(app)/news/+page.server.ts b/src/routes/(app)/news/+page.server.ts index 444525e2f..9cc0bfc98 100644 --- a/src/routes/(app)/news/+page.server.ts +++ b/src/routes/(app)/news/+page.server.ts @@ -8,8 +8,10 @@ import { getPageOrThrowSvelteError, getPageSizeOrThrowSvelteError, } from "$lib/utils/url.server"; +import { getDecryptedJWT } from "$lib/server/getDecryptedJWT"; +import { env } from "$env/dynamic/private"; -export const load: PageServerLoad = async ({ locals, url }) => { +export const load: PageServerLoad = async ({ locals, url, request }) => { const { prisma } = locals; const articleCount = await prisma.article.count(); const pageSize = getPageSizeOrThrowSvelteError(url); @@ -27,11 +29,42 @@ export const load: PageServerLoad = async ({ locals, url }) => { }), getAllTags(prisma), ]); + const jwt = await getDecryptedJWT(request); + const scheduledTasks: ScheduledTaskParsed[] = []; + if (jwt) { + try { + const result = await fetch( + `${env.SCHEDULER_ENDPOINT}?password=${env.SCHEDULER_PASSWORD}`, + { + headers: { + Authorization: `Bearer ${jwt["id_token"]}`, + }, + }, + ); + + if (!result.ok) { + console.error( + `Failed to fetch scheduled tasks: ${result.status} ${result.statusText}`, + ); + } else { + for (const task of (await result.json()) as ScheduledTaskRaw[]) { + scheduledTasks.push({ + ID: task.ID, + RunTimestamp: task.RunTimestamp, + Body: JSON.parse(task.Body) as NewsArticleData, + }); + } + } + } catch (error) { + console.error("Error fetching or parsing scheduled tasks:", error); + } + } return { articles, pageCount, allTags, likeForm: await superValidate(zod(likeSchema)), + scheduledTasks: scheduledTasks, }; }; @@ -39,3 +72,65 @@ export const actions: Actions = { like: likesAction(true), dislike: likesAction(false), }; + +type ScheduledTaskRaw = { + ID: string; + RunTimestamp: string; + Body: string; +}; + +type ScheduledTaskParsed = { + ID: string; + RunTimestamp: string; + Body: NewsArticleData; +}; + +type NewsArticleData = { + author: { + connect: + | { + id: string; + } + | undefined; + create: + | { + member: { + connect: { + studentId: string | undefined; + }; + }; + mandate: + | { + connect: { + member: { + studentId: string | undefined; + }; + id: string; + }; + } + | undefined; + customAuthor: + | { + connect: { + id: string; + }; + } + | undefined; + } + | undefined; + }; + tags: { + connect: Array<{ + id: string; + }>; + }; + publishedAt: Date; + imageUrl?: string | null | undefined; + imageUrls?: string[] | undefined; + youtubeUrl?: string | null | undefined; + slug: string; + headerSv: string; + headerEn: string | null; + bodySv: string; + bodyEn: string | null; +}; diff --git a/src/routes/(app)/news/+page.svelte b/src/routes/(app)/news/+page.svelte index d54589d43..55affbad8 100644 --- a/src/routes/(app)/news/+page.svelte +++ b/src/routes/(app)/news/+page.svelte @@ -18,6 +18,8 @@ ); let form: HTMLFormElement; + const { scheduledTasks } = data; + let showScheduled = false; @@ -54,9 +56,33 @@ {#if isAuthorized(apiNames.NEWS.CREATE, data.user)} + {m.news_create()} {/if} + {#if scheduledTasks.length > 0} + + {/if} + {#if showScheduled} +
+

{m.news_scheduledNews()}

+
+ {#each scheduledTasks as task (task.ID)} +
+

{`${m.news_title()}:`} {task.Body.headerSv}

+

+ {`${m.news_publishDate()}:`} + {new Date(task.RunTimestamp).toLocaleString()} +

+
+ {/each} +
+
+ {/if} +
{#each data.articles as article (article.id)} diff --git a/src/routes/(app)/news/ArticleForm.svelte b/src/routes/(app)/news/ArticleForm.svelte index f9e1a1442..bdb1556c8 100644 --- a/src/routes/(app)/news/ArticleForm.svelte +++ b/src/routes/(app)/news/ArticleForm.svelte @@ -1,4 +1,5 @@