Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions internal/provider/allowlist_resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,14 @@ func (r *allowListResource) Configure(
}
}

func (r *allowListResource) resourceType() string {
return providerResponseTypeName + "_allow_list"
}

func (r *allowListResource) Metadata(
_ context.Context, req resource.MetadataRequest, resp *resource.MetadataResponse,
_ context.Context, _ resource.MetadataRequest, resp *resource.MetadataResponse,
) {
resp.TypeName = req.ProviderTypeName + "_allow_list"
resp.TypeName = r.resourceType()
}

func (r *allowListResource) Create(
Expand Down Expand Up @@ -149,6 +153,7 @@ func (r *allowListResource) Create(
}

traceAPICall("AddAllowlistEntry")
ctx = contextWithResourceMetadata(ctx, r, entry.ID.ValueString())
_, _, err := r.provider.service.AddAllowlistEntry(ctx, entry.ClusterId.ValueString(), &allowList)
if err != nil {
resp.Diagnostics.AddError(
Expand Down
9 changes: 7 additions & 2 deletions internal/provider/cluster_resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,10 +256,14 @@ func (r *clusterResource) Schema(
}
}

func (r *clusterResource) resourceType() string {
return providerResponseTypeName + "_cluster"
}

func (r *clusterResource) Metadata(
_ context.Context, req resource.MetadataRequest, resp *resource.MetadataResponse,
_ context.Context, _ resource.MetadataRequest, resp *resource.MetadataResponse,
) {
resp.TypeName = req.ProviderTypeName + "_cluster"
resp.TypeName = r.resourceType()
}

func (r *clusterResource) Configure(
Expand Down Expand Up @@ -476,6 +480,7 @@ func (r *clusterResource) Read(
}

traceAPICall("GetCluster")
ctx = contextWithResourceMetadata(ctx, r, clusterID)
clusterObj, httpResp, err := r.provider.service.GetCluster(ctx, clusterID)
if err != nil {
if httpResp != nil && httpResp.StatusCode == http.StatusNotFound {
Expand Down
9 changes: 7 additions & 2 deletions internal/provider/cmek_resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ type cmekResource struct {
provider *provider
}

func (r *cmekResource) resourceType() string {
return providerResponseTypeName + "_cmek"
}

func (r *cmekResource) Schema(
_ context.Context, _ resource.SchemaRequest, resp *resource.SchemaResponse,
) {
Expand All @@ -120,9 +124,9 @@ func (r *cmekResource) Schema(
}

func (r *cmekResource) Metadata(
_ context.Context, req resource.MetadataRequest, resp *resource.MetadataResponse,
_ context.Context, _ resource.MetadataRequest, resp *resource.MetadataResponse,
) {
resp.TypeName = req.ProviderTypeName + "_cmek"
resp.TypeName = r.resourceType()
}

func (r *cmekResource) Configure(
Expand Down Expand Up @@ -169,6 +173,7 @@ func (r *cmekResource) Create(
cmekSpec.SetRegionSpecs(regionSpecs)

traceAPICall("EnableCMEKSpec")
ctx = contextWithResourceMetadata(ctx, r, plan.ID.ValueString())
cmekObj, _, err := r.provider.service.EnableCMEKSpec(ctx, plan.ID.ValueString(), cmekSpec)
if err != nil {
resp.Diagnostics.AddError(
Expand Down
10 changes: 9 additions & 1 deletion internal/provider/folder_resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,14 @@ func (r *folderResource) Schema(
}
}

func (r *folderResource) resourceType() string {
return providerResponseTypeName + "_folder"
}

func (r *folderResource) Metadata(
_ context.Context, req resource.MetadataRequest, resp *resource.MetadataResponse,
) {
resp.TypeName = req.ProviderTypeName + "_folder"
resp.TypeName = r.resourceType()
}

func (r *folderResource) Configure(
Expand Down Expand Up @@ -87,6 +91,7 @@ func (r *folderResource) Create(

parentID := plan.ParentId.ValueString()
traceAPICall("CreateFolder")
ctx = contextWithResourceMetadata(ctx, r, "")
folderObj, _, err := r.provider.service.CreateFolder(ctx, &client.CreateFolderRequest{
Name: plan.Name.ValueString(),
ParentId: &parentID,
Expand Down Expand Up @@ -139,6 +144,7 @@ func (r *folderResource) Read(
}

traceAPICall("GetFolder")
ctx = contextWithResourceMetadata(ctx, r, folderID)
folderObj, httpResp, err := r.provider.service.GetFolder(ctx, folderID)
if err != nil {
if httpResp != nil && httpResp.StatusCode == http.StatusNotFound {
Expand Down Expand Up @@ -187,6 +193,7 @@ func (r *folderResource) Update(
destParentID = plan.ParentId.ValueString()
)
traceAPICall("UpdateFolder")
ctx = contextWithResourceMetadata(ctx, r, plan.ID.ValueString())
folderObj, _, err := r.provider.service.UpdateFolder(
ctx,
plan.ID.ValueString(),
Expand Down Expand Up @@ -227,6 +234,7 @@ func (r *folderResource) Delete(
}

traceAPICall("DeleteFolder")
ctx = contextWithResourceMetadata(ctx, r, folderID.ValueString())
httpResp, err := r.provider.service.DeleteFolder(ctx, folderID.ValueString())
if err != nil {
if httpResp != nil && httpResp.StatusCode == http.StatusNotFound {
Expand Down
47 changes: 36 additions & 11 deletions internal/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package provider

import (
"context"
"net/http"
"os"

"github.com/cockroachdb/cockroach-cloud-sdk-go/pkg/client"
Expand All @@ -27,11 +28,14 @@ import (
"github.com/hashicorp/terraform-plugin-framework/provider/schema"
"github.com/hashicorp/terraform-plugin-framework/resource"
"github.com/hashicorp/terraform-plugin-framework/types"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/logging"
)

// NewService overrides the client method for testing.
var NewService = client.NewService

var providerResponseTypeName = "cockroach"

// provider satisfies the tfsdk.Provider interface and usually is included
// with all Resource and DataSource implementations.
type provider struct {
Expand Down Expand Up @@ -89,24 +93,18 @@ func (p *provider) Configure(
}
cfg.UserAgent = UserAgent

logLevel := os.Getenv("TF_LOG")
if logLevel == "DEBUG" || logLevel == "TRACE" {
cfg.Debug = true
} else {
logLevel = os.Getenv("TF_LOG_PROVIDER")
if logLevel == "DEBUG" || logLevel == "TRACE" {
cfg.Debug = true
}
}

// retryablehttp gives us automatic retries with exponential backoff.
httpClient := retryablehttp.NewClient()

// The TF framework will pick up the default global logger.
// HTTP requests are logged at DEBUG level.
httpClient.Logger = &leveledTFLogger{baseCtx: ctx}
httpClient.ErrorHandler = retryablehttp.PassthroughErrorHandler
httpClient.CheckRetry = retryGetRequestsOnly
cfg.HTTPClient = httpClient.StandardClient()
cfg.HTTPClient.Transport = ApiWrapperRoundTripper{
next: logging.NewLoggingHTTPTransport(cfg.HTTPClient.Transport),
}

cl := client.NewClient(cfg)
p.service = NewService(cl)
Expand All @@ -119,7 +117,7 @@ func (p *provider) Configure(
func (p *provider) Metadata(
_ context.Context, _ tf_provider.MetadataRequest, resp *tf_provider.MetadataResponse,
) {
resp.TypeName = "cockroach"
resp.TypeName = providerResponseTypeName
resp.Version = p.version
}

Expand Down Expand Up @@ -182,3 +180,30 @@ func New(version string) func() tf_provider.Provider {
}
}
}

// This type implements the http.RoundTripper interface
type ApiWrapperRoundTripper struct {
// TODO(fitzner): What's a good name for this?
next http.RoundTripper
}

func (rt ApiWrapperRoundTripper) RoundTrip(req *http.Request) (res *http.Response, e error) {

ctx := req.Context()

resourceType := ctx.Value(contextValResourceType)
resourceIDHash := ctx.Value(contextValResourceIDHash)
if resourceType != nil || resourceIDHash != nil {
// make a copy
req = req.Clone(ctx)

if resourceType != nil && resourceType.(string) != "" {
req.Header.Set("Cc-Tf-Resource-Type", resourceType.(string))
}
if resourceIDHash != nil && resourceIDHash.(string) != "" {
req.Header.Set("Cc-Tf-Resource-Id-Hash", resourceIDHash.(string))
}
}

return rt.next.RoundTrip(req)
}
27 changes: 27 additions & 0 deletions internal/provider/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@ package provider

import (
"context"
"encoding/base64"
"errors"
"fmt"
"hash/fnv"
"net/http"
"os"
"regexp"
"runtime"
"strings"
"testing"

"github.com/cockroachdb/cockroach-cloud-sdk-go/pkg/client"
Expand All @@ -26,6 +29,9 @@ import (
// failed cleanup.
const tfTestPrefix = "tftest"

const contextValResourceType = "RESOURCE_TYPE"
const contextValResourceIDHash = "RESOURCE_ID_HASH"

func addConfigureProviderErr(diagnostics *diag.Diagnostics) {
diagnostics.AddError(
"Provider not configured",
Expand Down Expand Up @@ -185,3 +191,24 @@ func traceAPICall(endpoint string) {
fmt.Printf("CC API Call: %s (%s)\n", endpoint, runtime.FuncForPC(pc).Name())
}
}

type resourceTyper interface {
resourceType() string
}

func contextWithResourceMetadata(ctx context.Context, resource resourceTyper, resourceID string) context.Context {
ctx = context.WithValue(ctx, contextValResourceType, resource.resourceType())
ctx = context.WithValue(ctx, contextValResourceIDHash, hashString(resourceID))
return ctx
}

func hashString(s string) string {
hasher := fnv.New32()
hasher.Write([]byte(s))
encoded := base64.StdEncoding.EncodeToString(hasher.Sum(nil))
fillerIndex := strings.Index(encoded, "=")
if fillerIndex != -1 {
encoded = encoded[:fillerIndex]
}
return encoded
}