diff --git a/internal/daemon/controller/controller.go b/internal/daemon/controller/controller.go index a56ed92bc6..ce25afa11e 100644 --- a/internal/daemon/controller/controller.go +++ b/internal/daemon/controller/controller.go @@ -50,6 +50,7 @@ import ( external_plugins "github.com/hashicorp/boundary/sdk/plugins" "github.com/hashicorp/boundary/version" "github.com/hashicorp/go-hclog" + "github.com/hashicorp/go-rate" "github.com/hashicorp/go-secure-stdlib/mlock" "github.com/hashicorp/go-secure-stdlib/pluginutil/v2" "github.com/hashicorp/nodeenrollment" @@ -119,6 +120,8 @@ type Controller struct { apiGrpcServerListener grpcServerListener apiGrpcGatewayTicket string + rateLimiter *rate.Limiter + // Repo factory methods AuthTokenRepoFn common.AuthTokenRepoFactory VaultCredentialRepoFn common.VaultCredentialRepoFactory @@ -245,6 +248,16 @@ func New(ctx context.Context, conf *Config) (*Controller, error) { } c.clusterListener = clusterListeners[0] + rateLimits, err := conf.RawConfig.Controller.ApiRateLimits.Limits(c.baseContext) + if err != nil { + return nil, fmt.Errorf("error parsing rate limit configuration: %w", err) + } + + c.rateLimiter, err = rate.NewLimiter(rateLimits, conf.RawConfig.Controller.ApiRateLimiterMaxEntries) + if err != nil { + return nil, fmt.Errorf("error initializing rate limiter: %w", err) + } + var pluginLogger hclog.Logger for _, enabledPlugin := range c.enabledPlugins { if pluginLogger == nil { diff --git a/internal/daemon/controller/handler.go b/internal/daemon/controller/handler.go index b461c3cb3d..9731befd51 100644 --- a/internal/daemon/controller/handler.go +++ b/internal/daemon/controller/handler.go @@ -46,6 +46,7 @@ import ( "github.com/hashicorp/boundary/internal/gen/controller/api/services" authpb "github.com/hashicorp/boundary/internal/gen/controller/auth" opsservices "github.com/hashicorp/boundary/internal/gen/ops/services" + "github.com/hashicorp/boundary/internal/ratelimit" "github.com/hashicorp/go-cleanhttp" "github.com/hashicorp/go-secure-stdlib/listenerutil" "github.com/hashicorp/go-secure-stdlib/strutil" @@ -73,7 +74,7 @@ func createMuxWithEndpoints(c *Controller, props HandlerProperties) (http.Handle } mux := http.NewServeMux() - mux.Handle("/v1/", grpcGwMux) + mux.Handle("/v1/", ratelimit.Handler(c.rateLimiter, grpcGwMux)) mux.Handle(uiPath, handleUi(c)) isUiRequest := func(req *http.Request) bool {