initial commit
This commit is contained in:
17
internal/pkg/azure/azblob/azblob.go
Normal file
17
internal/pkg/azure/azblob/azblob.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package azblob
|
||||
|
||||
import (
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
|
||||
)
|
||||
|
||||
func New(logger zerolog.Logger, cred *azidentity.DefaultAzureCredential) (*azblob.Client, error) {
|
||||
client, err := azblob.NewClientFromConnectionString("", nil)
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msg("failed to create azure blob storage client")
|
||||
return nil, err
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
27
internal/pkg/azure/azbus/azbus.go
Normal file
27
internal/pkg/azure/azbus/azbus.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package azbus
|
||||
|
||||
import (
|
||||
"github.com/ThreeDotsLabs/watermill"
|
||||
"github.com/ThreeDotsLabs/watermill/message"
|
||||
"github.com/ThreeDotsLabs/watermill/pubsub/gochannel"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"base/config"
|
||||
"base/pkg/watermill/azsb"
|
||||
)
|
||||
|
||||
func New(cfg *config.AppConfig, logger zerolog.Logger) (message.Subscriber, message.Publisher, error) {
|
||||
if cfg.Environment == config.Local {
|
||||
gch := gochannel.NewGoChannel(gochannel.Config{}, watermill.NewStdLogger(true, true))
|
||||
return gch, gch, nil
|
||||
}
|
||||
|
||||
return azsb.NewAzBus(
|
||||
azsb.Config{
|
||||
ConnectionString: cfg.AzureServiceBus.ConnectionString,
|
||||
UseManagedIdentity: cfg.AzureServiceBus.UseManagedIdentity,
|
||||
Namespace: cfg.AzureServiceBus.Namespace,
|
||||
},
|
||||
logger,
|
||||
)
|
||||
}
|
||||
15
internal/pkg/azure/azureidentity/azidentity.go
Normal file
15
internal/pkg/azure/azureidentity/azidentity.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package azureidentity
|
||||
|
||||
import (
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
func New(logger zerolog.Logger) (*azidentity.DefaultAzureCredential, error) {
|
||||
cred, err := azidentity.NewDefaultAzureCredential(nil)
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msg("azure identity error")
|
||||
return nil, err
|
||||
}
|
||||
return cred, nil
|
||||
}
|
||||
143
internal/pkg/azure/communication/azcommunication.go
Normal file
143
internal/pkg/azure/communication/azcommunication.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package communication
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"base/config"
|
||||
"base/pkg/email"
|
||||
)
|
||||
|
||||
type client struct {
|
||||
logger zerolog.Logger
|
||||
endpoint string
|
||||
accessKey string
|
||||
apiVersion string
|
||||
senderAddress string
|
||||
templates *template.Template
|
||||
}
|
||||
|
||||
func New(logger zerolog.Logger, config *config.AppConfig) email.Email {
|
||||
return &client{
|
||||
logger: logger,
|
||||
endpoint: config.AzureCommunicationConfig.Endpoint,
|
||||
accessKey: config.AzureCommunicationConfig.AccessKey,
|
||||
apiVersion: config.AzureCommunicationConfig.ApiVersion,
|
||||
senderAddress: config.AzureCommunicationConfig.SenderAddress,
|
||||
}
|
||||
}
|
||||
|
||||
func (c client) Send(ctx context.Context, params email.Request) (*email.Response, error) {
|
||||
var tpl bytes.Buffer
|
||||
if err := c.templates.ExecuteTemplate(&tpl, generateTemplateName(params.Template.EmailTemplateName), params.Template.Data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
html := tpl.String()
|
||||
|
||||
request := &ApiRequest{
|
||||
SenderAddress: c.senderAddress,
|
||||
Content: ApiContentDto{
|
||||
Subject: params.Subject,
|
||||
Html: html,
|
||||
},
|
||||
Recipients: ApiRecipientDto{
|
||||
To: []ApiRecipientDetailDto{
|
||||
{
|
||||
Address: params.RecipientAddress,
|
||||
DisplayName: params.UserFullName,
|
||||
},
|
||||
},
|
||||
CC: make([]ApiRecipientDetailDto, 0),
|
||||
BCC: make([]ApiRecipientDetailDto, 0),
|
||||
},
|
||||
}
|
||||
byteBody, err := json.Marshal(&request)
|
||||
if err != nil {
|
||||
return nil, errors.New("marshaling error")
|
||||
}
|
||||
|
||||
method := "POST"
|
||||
endpoint := c.endpoint
|
||||
u, _ := url.Parse(endpoint)
|
||||
snedPathAndQuery := fmt.Sprintf(
|
||||
"/emails:send?api-version=%s",
|
||||
c.apiVersion,
|
||||
)
|
||||
date := time.Now().In(time.FixedZone("GMT", 0)).Format("Mon, 02 Jan 2006 15:04:05 GMT")
|
||||
host := u.Host
|
||||
|
||||
contentHash := computeContentHash(byteBody)
|
||||
|
||||
stringToSign := fmt.Sprintf("%s\n%s\n%s;%s;%s", method, snedPathAndQuery, date, host, contentHash)
|
||||
signature, err := computeSignature(stringToSign, c.accessKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authHeader := fmt.Sprintf("HMAC-SHA256 SignedHeaders=x-ms-date;host;x-ms-content-sha256&Signature=%s", signature)
|
||||
fullURL := endpoint + snedPathAndQuery
|
||||
req, _ := http.NewRequest(method, fullURL, bytes.NewReader(byteBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("x-ms-date", date)
|
||||
req.Header.Set("x-ms-content-sha256", contentHash)
|
||||
req.Header.Set("Authorization", authHeader)
|
||||
req.Header.Set("Host", host)
|
||||
client := &http.Client{Timeout: 15 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusAccepted {
|
||||
response := &ApiErrorResponse{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(response); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.logger.Info().Msgf("email sending failed. %v", response)
|
||||
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
c.logger.Info().Msgf("email sending done. %v", resp.Body)
|
||||
response := &email.Response{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(response); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func computeContentHash(body []byte) string {
|
||||
sum := sha256.Sum256(body)
|
||||
return base64.StdEncoding.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func computeSignature(stringToSign, base64AccessKey string) (string, error) {
|
||||
key, err := base64.StdEncoding.DecodeString(base64AccessKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
mac := hmac.New(sha256.New, key)
|
||||
_, err = mac.Write([]byte(stringToSign))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sig := mac.Sum(nil)
|
||||
return base64.StdEncoding.EncodeToString(sig), nil
|
||||
}
|
||||
|
||||
func generateTemplateName(emailTemplateName email.Template) string {
|
||||
return fmt.Sprintf("%s.html", emailTemplateName.String())
|
||||
}
|
||||
41
internal/pkg/azure/communication/dto.go
Normal file
41
internal/pkg/azure/communication/dto.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package communication
|
||||
|
||||
type ApiResponse struct {
|
||||
ID string `json:"id"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
type ApiContentDto struct {
|
||||
Subject string `json:"subject"`
|
||||
Html string `json:"html"`
|
||||
PlainText string `json:"plainText"`
|
||||
}
|
||||
type ApiRecipientDetailDto struct {
|
||||
Address string `json:"address"`
|
||||
DisplayName string `json:"displayName"`
|
||||
}
|
||||
|
||||
type ApiRecipientDto struct {
|
||||
To []ApiRecipientDetailDto `json:"to"`
|
||||
CC []ApiRecipientDetailDto `json:"cc"`
|
||||
BCC []ApiRecipientDetailDto `json:"bcc"`
|
||||
}
|
||||
|
||||
type ApiRequest struct {
|
||||
SenderAddress string `json:"senderAddress"`
|
||||
Content ApiContentDto `json:"content"`
|
||||
Recipients ApiRecipientDto `json:"recipients"`
|
||||
}
|
||||
|
||||
type ApiErrorResponse struct {
|
||||
Error struct {
|
||||
AdditionalInfo []struct {
|
||||
Info any `json:"info"`
|
||||
Type string `json:"type"`
|
||||
} `json:"additionalInfo"`
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Target string `json:"target"`
|
||||
Details any `json:"details"`
|
||||
} `json:"error"`
|
||||
}
|
||||
99
internal/pkg/database/database.go
Normal file
99
internal/pkg/database/database.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"base/config"
|
||||
"base/pkg/metrics"
|
||||
)
|
||||
|
||||
// NewRWDatabaseConnection creates a new database connection
|
||||
func NewRWDatabaseConnection(cfg *config.AppConfig, logger zerolog.Logger, metric *metrics.Metrics) (*gorm.DB, error) {
|
||||
start := time.Now()
|
||||
|
||||
lg := logger.
|
||||
With().
|
||||
Str("module", "database").
|
||||
Int("maxOpenConnection", cfg.Database.MaxOpenConns).
|
||||
Int("maxIdleConnection", cfg.Database.MaxIdleConns).
|
||||
Logger()
|
||||
|
||||
gormConfig := &gorm.Config{Logger: NewGormLogger(logger, time.Second*5)}
|
||||
|
||||
wrDB, sqlDB, err := wr(cfg, gormConfig)
|
||||
if err != nil {
|
||||
fmt.Println("[DATABASE CONNECTION ERROR]Failed to connect to database", err.Error())
|
||||
metric.RecordDatabaseQuery("ConnectWR", "database", time.Since(start), err)
|
||||
|
||||
lg.Error().
|
||||
Err(err).
|
||||
Msg("failed to connect to database")
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
metric.RecordDatabaseQuery("ConnectWR", "database", time.Since(start), nil)
|
||||
|
||||
// Start monitoring connection pool metrics
|
||||
go monitorConnectionPool(sqlDB, metric, logger)
|
||||
|
||||
duration := time.Since(start)
|
||||
metric.RecordDatabaseQuery("Connect", "database", duration, nil)
|
||||
|
||||
lg.Info().Msg("Database connection established")
|
||||
|
||||
return wrDB, nil
|
||||
}
|
||||
|
||||
func wr(config *config.AppConfig, gormConfig *gorm.Config) (*gorm.DB, *sql.DB, error) {
|
||||
// PostgreSQL DSN format: postgres://user:password@host:port/dbname?sslmode=disable
|
||||
dsn := fmt.Sprintf(
|
||||
"host=%s user=%s password=%s dbname=%s port=%d sslmode=%s TimeZone=UTC",
|
||||
config.PgDatabaseConfig.Host,
|
||||
config.PgDatabaseConfig.User,
|
||||
config.PgDatabaseConfig.Password,
|
||||
config.PgDatabaseConfig.Name,
|
||||
config.PgDatabaseConfig.Port,
|
||||
config.PgDatabaseConfig.SSLMode,
|
||||
)
|
||||
|
||||
db, err := gorm.Open(postgres.Open(dsn), gormConfig)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Get the underlying sql.DB
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
sqlDB.SetMaxIdleConns(int(config.PgDatabaseConfig.PoolConfig.MinConn))
|
||||
sqlDB.SetMaxOpenConns(int(config.PgDatabaseConfig.PoolConfig.MaxConn))
|
||||
|
||||
// Parse and set connection timeouts from config
|
||||
// TODO: this is not type safe
|
||||
if config.PgDatabaseConfig.PoolConfig.MaxConnIdleTime.String() != "" {
|
||||
if idleTime, parseDurationErr := time.ParseDuration(config.PgDatabaseConfig.PoolConfig.MaxConnIdleTime.String()); parseDurationErr == nil {
|
||||
sqlDB.SetConnMaxIdleTime(idleTime)
|
||||
} else {
|
||||
sqlDB.SetConnMaxIdleTime(5 * time.Minute)
|
||||
}
|
||||
}
|
||||
|
||||
if config.PgDatabaseConfig.PoolConfig.MaxConnLifetime.String() != "" {
|
||||
if lifetime, parseDurationErr := time.ParseDuration(config.PgDatabaseConfig.PoolConfig.MaxConnLifetime.String()); parseDurationErr == nil {
|
||||
sqlDB.SetConnMaxLifetime(lifetime)
|
||||
} else {
|
||||
sqlDB.SetConnMaxLifetime(30 * time.Minute)
|
||||
}
|
||||
}
|
||||
|
||||
return db, sqlDB, nil
|
||||
}
|
||||
94
internal/pkg/database/logger.go
Normal file
94
internal/pkg/database/logger.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
type GormLogger struct {
|
||||
logger zerolog.Logger
|
||||
slowThreshold time.Duration
|
||||
logLevel logger.LogLevel
|
||||
}
|
||||
|
||||
func (l *GormLogger) LogMode(level logger.LogLevel) logger.Interface {
|
||||
newLogger := *l
|
||||
newLogger.logLevel = level
|
||||
return &newLogger
|
||||
}
|
||||
|
||||
func (l *GormLogger) Info(ctx context.Context, msg string, data ...interface{}) {
|
||||
if l.logLevel >= logger.Info {
|
||||
l.logger.Info().Msgf(msg, data...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *GormLogger) Warn(ctx context.Context, msg string, data ...interface{}) {
|
||||
if l.logLevel >= logger.Warn {
|
||||
l.logger.Warn().Msgf(msg, data...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *GormLogger) Error(ctx context.Context, msg string, data ...interface{}) {
|
||||
if l.logLevel >= logger.Error {
|
||||
l.logger.Error().Msgf(msg, data...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *GormLogger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
|
||||
if l.logLevel <= logger.Silent {
|
||||
return
|
||||
}
|
||||
|
||||
elapsed := time.Since(begin)
|
||||
sql, rows := fc()
|
||||
|
||||
switch {
|
||||
// cache miss / record not found - expected, don't log as error
|
||||
case err != nil && errors.Is(err, gorm.ErrRecordNotFound):
|
||||
if l.logLevel >= logger.Info {
|
||||
l.logger.Debug().
|
||||
Str("sql", sql).
|
||||
Int64("rows", rows).
|
||||
Dur("elapsed", elapsed).
|
||||
Msg("QueryCacheMiss")
|
||||
}
|
||||
// error query
|
||||
case err != nil && l.logLevel >= logger.Error:
|
||||
l.logger.Error().
|
||||
Err(err).
|
||||
Str("sql", sql).
|
||||
Int64("rows", rows).
|
||||
Dur("elapsed", elapsed).
|
||||
Msg("QueryError")
|
||||
|
||||
// slow query
|
||||
case elapsed > l.slowThreshold && l.logLevel >= logger.Warn:
|
||||
l.logger.Warn().
|
||||
Str("sql", sql).
|
||||
Int64("rows", rows).
|
||||
Dur("elapsed", elapsed).
|
||||
Msg("SlowQuery")
|
||||
|
||||
// normal query
|
||||
case l.logLevel >= logger.Info:
|
||||
l.logger.Debug().
|
||||
Str("sql", sql).
|
||||
Int64("rows", rows).
|
||||
Dur("elapsed", elapsed).
|
||||
Msg("Query")
|
||||
}
|
||||
}
|
||||
|
||||
func NewGormLogger(serviceLogger zerolog.Logger, threshold time.Duration) *GormLogger {
|
||||
return &GormLogger{
|
||||
logger: serviceLogger.With().Str("module", "gorm").Logger(),
|
||||
slowThreshold: threshold,
|
||||
logLevel: logger.Warn, // default
|
||||
}
|
||||
}
|
||||
56
internal/pkg/database/utils.go
Normal file
56
internal/pkg/database/utils.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"base/pkg/metrics"
|
||||
)
|
||||
|
||||
// monitorConnectionPool periodically monitors and records connection pool metrics
|
||||
func monitorConnectionPool(sqlDB *sql.DB, metric *metrics.Metrics, logger zerolog.Logger) {
|
||||
ticker := time.NewTicker(30 * time.Second) // Monitor every 30 seconds
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
stats := sqlDB.Stats()
|
||||
|
||||
// Record connection pool metrics using available methods
|
||||
// Note: Connection pool size metrics are not available in current metrics package
|
||||
// Consider adding them if needed for monitoring
|
||||
|
||||
// Record wait time if there are any waits
|
||||
if stats.WaitCount > 0 {
|
||||
avgWaitTime := time.Duration(stats.WaitDuration.Nanoseconds() / stats.WaitCount)
|
||||
metric.RecordDatabaseQuery("WaitTime", "database", avgWaitTime, nil)
|
||||
}
|
||||
|
||||
// Log connection pool stats at info level for better visibility
|
||||
logger.Info().
|
||||
Int("open_connections", stats.OpenConnections).
|
||||
Int("in_use", stats.InUse).
|
||||
Int("idle", stats.Idle).
|
||||
Int("max_open", stats.MaxOpenConnections).
|
||||
Int64("wait_count", stats.WaitCount).
|
||||
Int64("wait_duration_ms", stats.WaitDuration.Milliseconds()).
|
||||
Msg("Database connection pool stats")
|
||||
|
||||
// Alert if we're approaching connection limits
|
||||
if stats.OpenConnections >= 7 { // 7 out of 8 max connections
|
||||
logger.Warn().
|
||||
Int("open_connections", stats.OpenConnections).
|
||||
Int("max_open", stats.MaxOpenConnections).
|
||||
Msg("Database connection pool approaching limit - consider reducing concurrent operations")
|
||||
}
|
||||
|
||||
// Alert if there are connection waits
|
||||
if stats.WaitCount > 0 {
|
||||
logger.Warn().
|
||||
Int64("wait_count", stats.WaitCount).
|
||||
Int64("wait_duration_ms", stats.WaitDuration.Milliseconds()).
|
||||
Msg("Database connections are being waited for - possible connection pool exhaustion")
|
||||
}
|
||||
}
|
||||
}
|
||||
128
internal/pkg/logger/logger.go
Normal file
128
internal/pkg/logger/logger.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/syslog"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/pkgerrors"
|
||||
|
||||
"base/config"
|
||||
)
|
||||
|
||||
type Level string
|
||||
|
||||
type LogConfig struct {
|
||||
Environment string
|
||||
AppName string
|
||||
LogLevel Level
|
||||
Host string
|
||||
Port string
|
||||
Protocol string
|
||||
}
|
||||
|
||||
const (
|
||||
TRACE Level = "TRACE"
|
||||
DEBUG Level = "DEBUG"
|
||||
INFO Level = "INFO"
|
||||
WARN Level = "WARN"
|
||||
ERROR Level = "ERROR"
|
||||
PANIC Level = "PANIC"
|
||||
)
|
||||
|
||||
func New(appCfg *config.AppConfig) zerolog.Logger {
|
||||
zerolog.TimeFieldFormat = zerolog.TimeFormatUnixMs
|
||||
zerolog.ErrorStackMarshaler = pkgerrors.MarshalStack
|
||||
|
||||
time.Sleep(700 * time.Millisecond)
|
||||
|
||||
// Determine log level from configuration
|
||||
logLevel := INFO
|
||||
|
||||
// Override with syslog log level if configured
|
||||
if appCfg.Syslog.LogLevel != "" {
|
||||
configuredLevel := Level(appCfg.Syslog.LogLevel)
|
||||
// Validate the configured level
|
||||
switch configuredLevel {
|
||||
case TRACE, DEBUG, INFO, WARN, ERROR, PANIC:
|
||||
logLevel = configuredLevel
|
||||
default:
|
||||
// If invalid level is configured, keep the default
|
||||
fmt.Printf("Invalid log level configured: %s, using default: %s\n", appCfg.Syslog.LogLevel, logLevel)
|
||||
}
|
||||
}
|
||||
|
||||
cfg := LogConfig{
|
||||
Environment: appCfg.Environment,
|
||||
AppName: appCfg.Name, // You can customize this or extract from config
|
||||
LogLevel: logLevel,
|
||||
Host: appCfg.Syslog.Host, // You may want to add these to your config
|
||||
Port: appCfg.Syslog.Port, // Default syslog port
|
||||
Protocol: "udp", // Default syslog protocol
|
||||
}
|
||||
|
||||
switch cfg.Environment {
|
||||
case "development", "local":
|
||||
fmt.Printf("app %s using log level: %s Syslog.LogLevel: %s in %s \n", cfg.AppName, cfg.LogLevel, appCfg.Syslog.LogLevel, cfg.Environment)
|
||||
return zerolog.New(
|
||||
zerolog.NewConsoleWriter(
|
||||
func(w *zerolog.ConsoleWriter) {
|
||||
w.TimeFormat = "03:04:05.000PM"
|
||||
})).
|
||||
Level(logLevelToZero(cfg.LogLevel)).
|
||||
With().
|
||||
Caller().
|
||||
Timestamp().
|
||||
Logger()
|
||||
default:
|
||||
fmt.Printf("app %s using log level: %s Syslog.LogLevel: %s in %s \n", cfg.AppName, cfg.LogLevel, appCfg.Syslog.LogLevel, cfg.Environment)
|
||||
syslogWriter, err := syslog.Dial(
|
||||
cfg.Protocol,
|
||||
fmt.Sprintf("%s:%s", cfg.Host, cfg.Port),
|
||||
syslog.LOG_INFO,
|
||||
cfg.AppName,
|
||||
)
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to connect to syslog: %s\n", err)
|
||||
return zerolog.New(os.Stdout).
|
||||
Level(logLevelToZero(cfg.LogLevel)).
|
||||
With().
|
||||
Timestamp().
|
||||
Logger()
|
||||
}
|
||||
|
||||
return zerolog.
|
||||
New(zerolog.SyslogLevelWriter(syslogWriter)).
|
||||
Level(logLevelToZero(cfg.LogLevel)).
|
||||
With().
|
||||
Caller().
|
||||
Timestamp().
|
||||
Logger()
|
||||
}
|
||||
}
|
||||
|
||||
func logLevelToZero(level Level) zerolog.Level {
|
||||
switch level {
|
||||
case PANIC:
|
||||
return zerolog.PanicLevel
|
||||
case ERROR:
|
||||
return zerolog.ErrorLevel
|
||||
case WARN:
|
||||
return zerolog.WarnLevel
|
||||
case INFO:
|
||||
return zerolog.InfoLevel
|
||||
case DEBUG:
|
||||
return zerolog.DebugLevel
|
||||
case TRACE:
|
||||
return zerolog.TraceLevel
|
||||
default:
|
||||
return zerolog.InfoLevel
|
||||
}
|
||||
}
|
||||
|
||||
// NewTestLogger creates a no-op logger for tests
|
||||
func NewTestLogger() zerolog.Logger {
|
||||
return zerolog.New(nil).Level(zerolog.Disabled)
|
||||
}
|
||||
36
internal/pkg/module.go
Normal file
36
internal/pkg/module.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package pkg
|
||||
|
||||
import (
|
||||
"go.uber.org/fx"
|
||||
|
||||
"base/internal/dto"
|
||||
"base/internal/pkg/azure/azbus"
|
||||
"base/internal/pkg/azure/communication"
|
||||
"base/internal/pkg/database"
|
||||
"base/internal/pkg/logger"
|
||||
"base/internal/pkg/oauth"
|
||||
"base/pkg/cache"
|
||||
"base/pkg/metrics"
|
||||
"base/pkg/store"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func NewLandingCache(db *gorm.DB, lg zerolog.Logger, m *metrics.Metrics) cache.Cache[dto.Landing] {
|
||||
return cache.New(store.NewPostgresStore[dto.Landing](db, lg, m))
|
||||
}
|
||||
|
||||
var Module = fx.Module(
|
||||
"pkg",
|
||||
fx.Provide(
|
||||
logger.New,
|
||||
database.NewRWDatabaseConnection,
|
||||
communication.New,
|
||||
oauth.New,
|
||||
azbus.New,
|
||||
fx.Annotate(store.NewPostgresStore[string], fx.ResultTags(`name:"verification_store"`)),
|
||||
fx.Annotate(store.NewPostgresStore[string], fx.ResultTags(`name:"reset_password_store"`)),
|
||||
NewLandingCache,
|
||||
),
|
||||
)
|
||||
107
internal/pkg/oauth/github/client.go
Normal file
107
internal/pkg/oauth/github/client.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package github
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/github"
|
||||
|
||||
"base/internal/pkg/oauth/types"
|
||||
)
|
||||
|
||||
type client struct {
|
||||
oauthConfig *oauth2.Config
|
||||
}
|
||||
|
||||
func New(config oauth2.Config) types.Oauth {
|
||||
oauthConfig := &oauth2.Config{
|
||||
ClientID: config.ClientID,
|
||||
ClientSecret: config.ClientSecret,
|
||||
Endpoint: github.Endpoint,
|
||||
RedirectURL: config.RedirectURL,
|
||||
Scopes: config.Scopes,
|
||||
}
|
||||
return &client{oauthConfig: oauthConfig}
|
||||
}
|
||||
|
||||
func (g client) GetConsentAuthUrl(ctx context.Context, state string) string {
|
||||
return g.oauthConfig.AuthCodeURL(state, oauth2.AccessTypeOffline)
|
||||
}
|
||||
|
||||
func (g client) ExchangeCodeWithToken(ctx context.Context, code string) (*types.Token, error) {
|
||||
exchange, err := g.oauthConfig.Exchange(ctx, code, oauth2.AccessTypeOffline)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
token, err := g.oauthConfig.TokenSource(ctx, exchange).Token()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &types.Token{
|
||||
AccessToken: token.AccessToken,
|
||||
TokenType: token.TokenType,
|
||||
RefreshToken: token.RefreshToken,
|
||||
ExpiresIn: token.ExpiresIn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (g client) GetUserInfo(ctx context.Context, token, _ string) (types.UserInfo, error) {
|
||||
oauthClient := g.oauthConfig.Client(ctx, &oauth2.Token{AccessToken: token})
|
||||
|
||||
resp, err := oauthClient.Get("https://api.github.com/user")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
data, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return nil, readErr
|
||||
}
|
||||
|
||||
var user UserInfo
|
||||
if err = json.Unmarshal(data, &user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// GitHub /user often returns null for email; fetch from /user/emails (requires user:email scope)
|
||||
if user.GEmail == "" {
|
||||
user.GEmail = g.fetchPrimaryEmail(ctx, oauthClient)
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// fetchPrimaryEmail gets the primary email from GitHub /user/emails (requires user:email scope).
|
||||
func (g client) fetchPrimaryEmail(_ context.Context, oauthClient *http.Client) string {
|
||||
resp, err := oauthClient.Get("https://api.github.com/user/emails")
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
var emails []struct {
|
||||
Email string `json:"email"`
|
||||
Primary bool `json:"primary"`
|
||||
Verified bool `json:"verified"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &emails); err != nil {
|
||||
return ""
|
||||
}
|
||||
for _, e := range emails {
|
||||
if e.Primary && e.Verified {
|
||||
return e.Email
|
||||
}
|
||||
}
|
||||
if len(emails) > 0 {
|
||||
return emails[0].Email
|
||||
}
|
||||
return ""
|
||||
}
|
||||
59
internal/pkg/oauth/github/user.go
Normal file
59
internal/pkg/oauth/github/user.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package github
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
type UserInfo struct {
|
||||
Login string `json:"login"`
|
||||
Id int `json:"id"`
|
||||
NodeId string `json:"node_id"`
|
||||
AvatarUrl string `json:"avatar_url"`
|
||||
GravatarId string `json:"gravatar_id"`
|
||||
Url string `json:"url"`
|
||||
HtmlUrl string `json:"html_url"`
|
||||
FollowersUrl string `json:"followers_url"`
|
||||
FollowingUrl string `json:"following_url"`
|
||||
GistsUrl string `json:"gists_url"`
|
||||
StarredUrl string `json:"starred_url"`
|
||||
SubscriptionsUrl string `json:"subscriptions_url"`
|
||||
OrganizationsUrl string `json:"organizations_url"`
|
||||
ReposUrl string `json:"repos_url"`
|
||||
EventsUrl string `json:"events_url"`
|
||||
ReceivedEventsUrl string `json:"received_events_url"`
|
||||
Type string `json:"type"`
|
||||
UserViewType string `json:"user_view_type"`
|
||||
SiteAdmin bool `json:"site_admin"`
|
||||
Name string `json:"name"`
|
||||
Company interface{} `json:"company"`
|
||||
Blog string `json:"blogusecase"`
|
||||
Location interface{} `json:"location"`
|
||||
GEmail string `json:"email"`
|
||||
Hireable interface{} `json:"hireable"`
|
||||
Bio string `json:"bio"`
|
||||
TwitterUsername string `json:"twitter_username"`
|
||||
NotificationEmail string `json:"notification_email"`
|
||||
PublicRepos int `json:"public_repos"`
|
||||
PublicGists int `json:"public_gists"`
|
||||
Followers int `json:"followers"`
|
||||
Following int `json:"following"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
func (u UserInfo) ID() string {
|
||||
return fmt.Sprintf("%d", u.Id)
|
||||
}
|
||||
|
||||
func (u UserInfo) Email() string {
|
||||
return u.GEmail
|
||||
}
|
||||
|
||||
func (u UserInfo) FirstName() string {
|
||||
return u.Name
|
||||
}
|
||||
|
||||
func (u UserInfo) LastName() string {
|
||||
return u.Name
|
||||
}
|
||||
77
internal/pkg/oauth/google/client.go
Normal file
77
internal/pkg/oauth/google/client.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package google
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
|
||||
"base/internal/pkg/oauth/types"
|
||||
)
|
||||
|
||||
type client struct {
|
||||
oauthConfig *oauth2.Config
|
||||
}
|
||||
|
||||
func New(config oauth2.Config) types.Oauth {
|
||||
oauthConfig := &oauth2.Config{
|
||||
ClientID: config.ClientID,
|
||||
ClientSecret: config.ClientSecret,
|
||||
Endpoint: google.Endpoint,
|
||||
RedirectURL: config.RedirectURL,
|
||||
Scopes: config.Scopes,
|
||||
}
|
||||
return &client{oauthConfig: oauthConfig}
|
||||
}
|
||||
|
||||
func (g client) GetConsentAuthUrl(ctx context.Context, state string) string {
|
||||
return g.oauthConfig.AuthCodeURL(state, oauth2.AccessTypeOffline)
|
||||
}
|
||||
|
||||
func (g client) ExchangeCodeWithToken(ctx context.Context, code string) (*types.Token, error) {
|
||||
exchange, err := g.oauthConfig.Exchange(ctx, code, oauth2.AccessTypeOffline)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
token, err := g.oauthConfig.TokenSource(ctx, exchange).Token()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &types.Token{
|
||||
AccessToken: token.AccessToken,
|
||||
TokenType: token.TokenType,
|
||||
RefreshToken: token.RefreshToken,
|
||||
ExpiresIn: token.ExpiresIn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (g client) GetUserInfo(
|
||||
ctx context.Context,
|
||||
accessToken string,
|
||||
refreshToken string,
|
||||
) (types.UserInfo, error) {
|
||||
resp, err := g.oauthConfig.Client(
|
||||
ctx,
|
||||
&oauth2.Token{
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
}).Get("https://www.googleapis.com/oauth2/v2/userinfo")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var user UserInfo
|
||||
|
||||
if err = json.Unmarshal(data, &user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, err
|
||||
}
|
||||
28
internal/pkg/oauth/google/user.go
Normal file
28
internal/pkg/oauth/google/user.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package google
|
||||
|
||||
type UserInfo struct {
|
||||
Id string `json:"id"`
|
||||
GEmail string `json:"email"`
|
||||
VerifiedEmail bool `json:"verified_email"`
|
||||
Name string `json:"name"`
|
||||
GivenName string `json:"given_name"`
|
||||
FamilyName string `json:"family_name"`
|
||||
Picture string `json:"picture"`
|
||||
Locale string `json:"locale"`
|
||||
}
|
||||
|
||||
func (u UserInfo) ID() string {
|
||||
return u.Id
|
||||
}
|
||||
|
||||
func (u UserInfo) Email() string {
|
||||
return u.GEmail
|
||||
}
|
||||
|
||||
func (u UserInfo) FirstName() string {
|
||||
return u.Name
|
||||
}
|
||||
|
||||
func (u UserInfo) LastName() string {
|
||||
return u.Name
|
||||
}
|
||||
74
internal/pkg/oauth/linkedin/linkedin.go
Normal file
74
internal/pkg/oauth/linkedin/linkedin.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package linkedin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/linkedin"
|
||||
"io"
|
||||
|
||||
"base/internal/pkg/oauth/types"
|
||||
)
|
||||
|
||||
type client struct {
|
||||
oauthConfig *oauth2.Config
|
||||
}
|
||||
|
||||
func New(config oauth2.Config) types.Oauth {
|
||||
oauthConfig := &oauth2.Config{
|
||||
ClientID: config.ClientID,
|
||||
ClientSecret: config.ClientSecret,
|
||||
Endpoint: linkedin.Endpoint,
|
||||
RedirectURL: config.RedirectURL,
|
||||
Scopes: config.Scopes,
|
||||
}
|
||||
return &client{oauthConfig: oauthConfig}
|
||||
}
|
||||
|
||||
func (l client) GetConsentAuthUrl(ctx context.Context, state string) string {
|
||||
return l.oauthConfig.AuthCodeURL(state, oauth2.AccessTypeOffline)
|
||||
}
|
||||
|
||||
func (l client) ExchangeCodeWithToken(ctx context.Context, code string) (*types.Token, error) {
|
||||
exchange, err := l.oauthConfig.Exchange(ctx, code, oauth2.AccessTypeOffline)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
token, err := l.oauthConfig.TokenSource(ctx, exchange).Token()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &types.Token{
|
||||
AccessToken: token.AccessToken,
|
||||
TokenType: token.TokenType,
|
||||
RefreshToken: token.RefreshToken,
|
||||
ExpiresIn: token.ExpiresIn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (l client) GetUserInfo(
|
||||
ctx context.Context,
|
||||
accessToken string,
|
||||
refreshToken string,
|
||||
) (types.UserInfo, error) {
|
||||
resp, err := l.oauthConfig.Client(ctx, &oauth2.Token{
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
}).Get("https://api.linkedin.com/v2/me")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var user UserInfo
|
||||
|
||||
if err = json.Unmarshal(data, &user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
57
internal/pkg/oauth/linkedin/user.go
Normal file
57
internal/pkg/oauth/linkedin/user.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package linkedin
|
||||
|
||||
type UserInfo struct {
|
||||
Id string `json:"id"`
|
||||
LocalizedFirstName string `json:"localizedFirstName"`
|
||||
LocalizedHeadline string `json:"localizedHeadline"`
|
||||
VanityName string `json:"vanityName"`
|
||||
LocalizedLastName string `json:"localizedLastName"`
|
||||
Firstname UserInfoFirstName `json:"firstName"`
|
||||
Lastname UserInfoLastName `json:"lastName"`
|
||||
Headline UserInfoHeadline `json:"headline"`
|
||||
ProfilePicture UserInfoProfilePicture `json:"profilePicture"`
|
||||
}
|
||||
|
||||
type UserInfoFirstName struct {
|
||||
Localized Localized `json:"localized"`
|
||||
PreferredLocale PreferredLocale `json:"preferredLocale"`
|
||||
}
|
||||
|
||||
type UserInfoLastName struct {
|
||||
Localized Localized `json:"localized"`
|
||||
PreferredLocale PreferredLocale `json:"preferredLocale"`
|
||||
}
|
||||
|
||||
type Localized struct {
|
||||
EnUS string `json:"en_US"`
|
||||
}
|
||||
|
||||
type PreferredLocale struct {
|
||||
Country string `json:"country"`
|
||||
Language string `json:"language"`
|
||||
}
|
||||
|
||||
type UserInfoHeadline struct {
|
||||
Localized Localized `json:"localized"`
|
||||
PreferredLocale PreferredLocale `json:"preferredLocale"`
|
||||
}
|
||||
|
||||
type UserInfoProfilePicture struct {
|
||||
DisplayImage string `json:"displayImage"`
|
||||
}
|
||||
|
||||
func (u UserInfo) ID() string {
|
||||
return u.Id
|
||||
}
|
||||
|
||||
func (u UserInfo) Email() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (u UserInfo) FirstName() string {
|
||||
return u.Firstname.Localized.EnUS
|
||||
}
|
||||
|
||||
func (u UserInfo) LastName() string {
|
||||
return u.Lastname.Localized.EnUS
|
||||
}
|
||||
81
internal/pkg/oauth/mock/client.go
Normal file
81
internal/pkg/oauth/mock/client.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package mock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"base/internal/pkg/oauth/types"
|
||||
)
|
||||
|
||||
type client struct {
|
||||
oauthConfig *oauth2.Config
|
||||
userinfoURL string
|
||||
}
|
||||
|
||||
// New creates a mock OAuth client that uses a local mock OAuth server.
|
||||
// Use for local development when real Google/GitHub credentials are not available.
|
||||
func New(config oauth2.Config, baseURL string) types.Oauth {
|
||||
baseURL = strings.TrimSuffix(baseURL, "/")
|
||||
oauthConfig := &oauth2.Config{
|
||||
ClientID: config.ClientID,
|
||||
ClientSecret: config.ClientSecret,
|
||||
RedirectURL: config.RedirectURL,
|
||||
Scopes: config.Scopes,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: baseURL + "/authorize",
|
||||
TokenURL: baseURL + "/token",
|
||||
},
|
||||
}
|
||||
return &client{
|
||||
oauthConfig: oauthConfig,
|
||||
userinfoURL: baseURL + "/userinfo",
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) GetConsentAuthUrl(ctx context.Context, state string) string {
|
||||
return c.oauthConfig.AuthCodeURL(state, oauth2.AccessTypeOffline)
|
||||
}
|
||||
|
||||
func (c *client) ExchangeCodeWithToken(ctx context.Context, code string) (*types.Token, error) {
|
||||
exchange, err := c.oauthConfig.Exchange(ctx, code, oauth2.AccessTypeOffline)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
token, err := c.oauthConfig.TokenSource(ctx, exchange).Token()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &types.Token{
|
||||
AccessToken: token.AccessToken,
|
||||
TokenType: token.TokenType,
|
||||
RefreshToken: token.RefreshToken,
|
||||
ExpiresIn: token.ExpiresIn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *client) GetUserInfo(ctx context.Context, accessToken, _ string) (types.UserInfo, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.userinfoURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var user UserInfo
|
||||
if err := json.Unmarshal(data, &user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
25
internal/pkg/oauth/mock/user.go
Normal file
25
internal/pkg/oauth/mock/user.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package mock
|
||||
|
||||
// UserInfo matches the mock server's /userinfo response (Google-like format)
|
||||
type UserInfo struct {
|
||||
MID string `json:"id"`
|
||||
MEmail string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
GivenName string `json:"given_name"`
|
||||
FamilyName string `json:"family_name"`
|
||||
}
|
||||
|
||||
func (u UserInfo) ID() string { return u.MID }
|
||||
func (u UserInfo) Email() string { return u.MEmail }
|
||||
func (u UserInfo) FirstName() string {
|
||||
if u.GivenName != "" {
|
||||
return u.GivenName
|
||||
}
|
||||
return u.Name
|
||||
}
|
||||
func (u UserInfo) LastName() string {
|
||||
if u.FamilyName != "" {
|
||||
return u.FamilyName
|
||||
}
|
||||
return u.Name
|
||||
}
|
||||
119
internal/pkg/oauth/oauth.go
Normal file
119
internal/pkg/oauth/oauth.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"base/config"
|
||||
|
||||
"base/internal/pkg/oauth/github"
|
||||
"base/internal/pkg/oauth/google"
|
||||
"base/internal/pkg/oauth/linkedin"
|
||||
"base/internal/pkg/oauth/mock"
|
||||
"base/internal/pkg/oauth/types"
|
||||
)
|
||||
|
||||
// Token is an alias for types.Token for backward compatibility
|
||||
type Token = types.Token
|
||||
|
||||
type OAuth struct {
|
||||
google types.Oauth
|
||||
linkedin types.Oauth
|
||||
github types.Oauth
|
||||
mock types.Oauth
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
GoogleConfig oauth2.Config
|
||||
GitHubConfig oauth2.Config
|
||||
LinkedinConfig oauth2.Config
|
||||
}
|
||||
|
||||
func New(cfg *config.AppConfig) OAuth {
|
||||
oauthConfig := Config{
|
||||
GoogleConfig: oauth2.Config{
|
||||
ClientID: cfg.OAuth.Google.ClientID,
|
||||
ClientSecret: cfg.OAuth.Google.ClientSecret,
|
||||
RedirectURL: cfg.OAuth.Google.RedirectURL,
|
||||
Scopes: cfg.OAuth.Google.Scopes,
|
||||
},
|
||||
GitHubConfig: oauth2.Config{
|
||||
ClientID: cfg.OAuth.GitHub.ClientID,
|
||||
ClientSecret: cfg.OAuth.GitHub.ClientSecret,
|
||||
RedirectURL: cfg.OAuth.GitHub.RedirectURL,
|
||||
Scopes: cfg.OAuth.GitHub.Scopes,
|
||||
},
|
||||
LinkedinConfig: oauth2.Config{
|
||||
ClientID: cfg.OAuth.LinkedIn.ClientID,
|
||||
ClientSecret: cfg.OAuth.LinkedIn.ClientSecret,
|
||||
RedirectURL: cfg.OAuth.LinkedIn.RedirectURL,
|
||||
Scopes: cfg.OAuth.LinkedIn.Scopes,
|
||||
},
|
||||
}
|
||||
|
||||
o := OAuth{
|
||||
google: google.New(oauthConfig.GoogleConfig),
|
||||
linkedin: linkedin.New(oauthConfig.LinkedinConfig),
|
||||
github: github.New(oauthConfig.GitHubConfig),
|
||||
}
|
||||
|
||||
if cfg.OAuth.Mock.Enabled && strings.TrimSpace(cfg.OAuth.Mock.BaseURL) != "" {
|
||||
baseURL := strings.TrimSuffix(strings.TrimSpace(cfg.OAuth.Mock.BaseURL), "/")
|
||||
mockConfig := oauth2.Config{
|
||||
ClientID: cfg.OAuth.Mock.ClientID,
|
||||
ClientSecret: cfg.OAuth.Mock.ClientSecret,
|
||||
RedirectURL: cfg.OAuth.Mock.RedirectURL,
|
||||
Scopes: cfg.OAuth.Mock.Scopes,
|
||||
}
|
||||
if mockConfig.ClientID == "" {
|
||||
mockConfig.ClientID = "mock-client"
|
||||
}
|
||||
if mockConfig.ClientSecret == "" {
|
||||
mockConfig.ClientSecret = "mock-secret"
|
||||
}
|
||||
if mockConfig.RedirectURL == "" {
|
||||
mockConfig.RedirectURL = "http://localhost:3000/auth/callback"
|
||||
}
|
||||
o.mock = mock.New(mockConfig, baseURL)
|
||||
}
|
||||
return o
|
||||
}
|
||||
|
||||
func (a OAuth) Client(provider Provider) types.Oauth {
|
||||
switch provider {
|
||||
case Google:
|
||||
return a.google
|
||||
case Linkedin:
|
||||
return a.linkedin
|
||||
case GitHub:
|
||||
return a.github
|
||||
case Mock:
|
||||
if a.mock != nil {
|
||||
return a.mock
|
||||
}
|
||||
return disabledMockClient{}
|
||||
default:
|
||||
return a.google
|
||||
}
|
||||
}
|
||||
|
||||
// ErrMockNotEnabled is returned when mock provider is used but not configured
|
||||
var ErrMockNotEnabled = errors.New("oauth mock is not enabled - set oauth.mock.enabled=true and oauth.mock.base_url")
|
||||
|
||||
// disabledMockClient is used when mock is requested but not configured
|
||||
type disabledMockClient struct{}
|
||||
|
||||
func (disabledMockClient) GetConsentAuthUrl(_ context.Context, _ string) string {
|
||||
panic("oauth mock is not enabled - set oauth.mock.enabled=true and oauth.mock.base_url")
|
||||
}
|
||||
|
||||
func (disabledMockClient) ExchangeCodeWithToken(context.Context, string) (*types.Token, error) {
|
||||
return nil, ErrMockNotEnabled
|
||||
}
|
||||
|
||||
func (disabledMockClient) GetUserInfo(context.Context, string, string) (types.UserInfo, error) {
|
||||
return nil, ErrMockNotEnabled
|
||||
}
|
||||
51
internal/pkg/oauth/provider.go
Normal file
51
internal/pkg/oauth/provider.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
//go:generate stringer -type=Provider
|
||||
type Provider int
|
||||
|
||||
const (
|
||||
Unknown Provider = iota
|
||||
Credentials
|
||||
Google
|
||||
GitHub
|
||||
Linkedin
|
||||
Mock
|
||||
)
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler so Provider accepts string in JSON (e.g. "mock", "google")
|
||||
func (p *Provider) UnmarshalJSON(data []byte) error {
|
||||
var s string
|
||||
if err := json.Unmarshal(data, &s); err != nil {
|
||||
return err
|
||||
}
|
||||
parsed, err := ParseProvider(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*p = parsed
|
||||
return nil
|
||||
}
|
||||
|
||||
// ParseProvider parses a provider string and returns the corresponding Provider enum
|
||||
func ParseProvider(provider string) (Provider, error) {
|
||||
switch strings.ToLower(provider) {
|
||||
case "credentials":
|
||||
return Credentials, nil
|
||||
case "google":
|
||||
return Google, nil
|
||||
case "github":
|
||||
return GitHub, nil
|
||||
case "linkedin":
|
||||
return Linkedin, nil
|
||||
case "mock":
|
||||
return Mock, nil
|
||||
default:
|
||||
return Unknown, fmt.Errorf("unknown provider: %s", provider)
|
||||
}
|
||||
}
|
||||
28
internal/pkg/oauth/provider_string.go
Normal file
28
internal/pkg/oauth/provider_string.go
Normal file
@@ -0,0 +1,28 @@
|
||||
// Code generated by "stringer -type=Provider"; DO NOT EDIT.
|
||||
|
||||
package oauth
|
||||
|
||||
import "strconv"
|
||||
|
||||
func _() {
|
||||
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||
// Re-run the stringer command to generate them again.
|
||||
var x [1]struct{}
|
||||
_ = x[Unknown-0]
|
||||
_ = x[Credentials-1]
|
||||
_ = x[Google-2]
|
||||
_ = x[GitHub-3]
|
||||
_ = x[Linkedin-4]
|
||||
_ = x[Mock-5]
|
||||
}
|
||||
|
||||
const _Provider_name = "UnknownCredentialsGoogleGitHubLinkedinMock"
|
||||
|
||||
var _Provider_index = [...]uint8{0, 7, 18, 24, 30, 38, 42}
|
||||
|
||||
func (i Provider) String() string {
|
||||
if i < 0 || i >= Provider(len(_Provider_index)-1) {
|
||||
return "Provider(" + strconv.FormatInt(int64(i), 10) + ")"
|
||||
}
|
||||
return _Provider_name[_Provider_index[i]:_Provider_index[i+1]]
|
||||
}
|
||||
25
internal/pkg/oauth/types/types.go
Normal file
25
internal/pkg/oauth/types/types.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
type Token struct {
|
||||
AccessToken string
|
||||
TokenType string
|
||||
RefreshToken string
|
||||
ExpiresIn int64
|
||||
}
|
||||
|
||||
type Oauth interface {
|
||||
GetConsentAuthUrl(ctx context.Context, state string) string
|
||||
ExchangeCodeWithToken(ctx context.Context, code string) (*Token, error)
|
||||
GetUserInfo(ctx context.Context, accessToken, refreshToken string) (UserInfo, error)
|
||||
}
|
||||
|
||||
type UserInfo interface {
|
||||
ID() string
|
||||
Email() string
|
||||
FirstName() string
|
||||
LastName() string
|
||||
}
|
||||
Reference in New Issue
Block a user