diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..d7dcbc9 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,11 @@ +.git +.gitignore +.dockerignore +.gitlab-ci.yml +.travis.yml +.idea +.github/ +host_key +host_key.pub +s3-sftp-proxy +s3-sftp-proxy.toml diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..346056a --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +host_key +host_key.pub +s3-sftp-proxy +s3-sftp-proxy.toml diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..1607528 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,21 @@ +FROM golang:alpine AS build + +ENV GO111MODULE=on + +SHELL ["/bin/sh", "-x", "-c"] +COPY . /go/src/s3-sftp-proxy/ +WORKDIR /go/src/s3-sftp-proxy/ +RUN go build -ldflags "-s -w" + + +FROM alpine:3.10 + +COPY --from=build /go/src/s3-sftp-proxy/s3-sftp-proxy /usr/local/bin + +RUN addgroup -g 1000 -S sftp && \ + adduser -u 1000 -S sftp -G sftp + +WORKDIR /home/sftp +USER sftp +ENTRYPOINT ["/usr/local/bin/s3-sftp-proxy"] +CMD ["--config", "/etc/s3-sftp-proxy.conf"] diff --git a/README.md b/README.md index be07198..dc2b709 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ Usage of s3-sftp-proxy: Turn on debug logging. The output will be more verbose. - + ## Configuation The configuration file is in [TOML](https://github.com/toml-lang/toml) format. Refer to that page for the detailed explanation of the syntax. @@ -43,6 +43,14 @@ reader_lookback_buffer_size = 1048576 reader_min_chunk_size = 262144 lister_lookback_buffer_size = 100 +upload_memory_buffer_size = 5242880 +upload_memory_buffer_pool_size = 10 +upload_memory_buffer_pool_timeout = "5s" +upload_workers_count = 2 + +metrics_bind = ":2112" +metrics_endpoint = "/metrics" + # buckets and authantication settings follow... ``` @@ -60,6 +68,14 @@ lister_lookback_buffer_size = 100 Specifies the local address and port to listen on. +* `metrics_bind` (optional, defaults to `":2112"`) + + Specifies the local address and port to serve Prometheus metrics on. + +* `metrics_endpoint` (optional, defaults to `"/metrics"`) + + Specifies the HTTP path of the metrics endpoint. + * `banner` (optional, defaults to an empty string) A banner is a message text that will be sent to the client when the connection is esablished to the server prior to any authentication steps. @@ -72,10 +88,26 @@ lister_lookback_buffer_size = 100 Specifies the amount of data fetched from S3 at once. Increase the value when you experience quite a poor performance. -* `lister_lookback_buffer_size` (optional, defalts to `100`) +* `lister_lookback_buffer_size` (optional, defaults to `100`) Contrary to the people's expectation, SFTP also requires file listings to be retrieved in random-access as well. +* `upload_memory_buffer_size` (optional, defaults to `5242880`) + + Size in bytes of each internal memory buffer used to upload files to S3; it also determines how a file is divided into parts for multipart uploads (see the [Uploads section](#uploads) for details). + +* `upload_memory_buffer_pool_size` (optional, defaults to `10`) + + Number of internal memory buffers of size `upload_memory_buffer_size` kept for uploads. See the [Uploads section](#uploads) for details. + +* `upload_memory_buffer_pool_timeout` (optional, defaults to `"5s"`) + + Maximum amount of time to wait for an available memory buffer from the pool during uploads. The timeout takes effect when the pool is exhausted. See the [Uploads section](#uploads) for details. + +* `upload_workers_count` (optional, defaults to `2`) + + Number of workers used to upload parts to S3. See the [Uploads section](#uploads) for details.
+ * `buckets` (required) `buckets` contains records for bucket declarations. See [Bucket Settings](#bucket-settings) for detail. @@ -115,8 +147,8 @@ aws_secret_access_key = "bbb" Specifies s3 endpoint (server) different from AWS. * `s3_force_path_style` (optional) - This option should be set to `true` if ypu use endpount different from AWS. - + This option should be set to `true` if you use endpoint different from AWS. + Set this to `true` to force the request to use path-style addressing, i.e., `http://s3.amazonaws.com/BUCKET/KEY`. By default, the S3 client will use virtual hosted bucket addressing when possible (`http://BUCKET.s3.amazonaws.com/KEY`). * `disable_ssl` (optional) @@ -127,7 +159,7 @@ aws_secret_access_key = "bbb" Specifies the bucket name. * `key_prefix` (required when `bucket_url` is unspecified) - + Specifies the prefix prepended to the file path sent from the client. The key string is derived as follows: `key` = `key_prefix` + `path` @@ -147,11 +179,11 @@ aws_secret_access_key = "bbb" * `credentials` (optional) * `credentials.aws_access_key_id` (required) - + Specifies the AWS access key. * `credentials.aws_secret_access_key` (required) - + Specifies the AWS secret access key. * `max_object_size` (optional, defaults to unlimited) @@ -232,7 +264,7 @@ ssh-rsa AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA """ ``` -Or +Or ```toml [auth.test] @@ -255,3 +287,48 @@ user1 = { password="test", public_keys="..." } Specifies the public keys authorized to use in authentication. Multiple keys can be specified by delimiting them by newlines. +### Prometheus metrics + +* `sftp_operation_status` _(counter)_ + + Count of SFTP operations by method and status. + +* `sftp_aws_session_error` _(counter)_ + + Count of AWS S3 session errors. + +* `sftp_permissions_error` _(counter)_ + + Count of bucket permission errors by method. + +* `sftp_users_connected` _(gauge)_ + + Number of users currently connected to the server. + +* `sftp_memory_buffer_pool_max` _(gauge)_ + + Number of memory buffers that can be requested from the pool. + +* `sftp_memory_buffer_pool_used` _(gauge)_ + + Number of memory buffers currently in use from the pool. + +* `sftp_memory_buffer_pool_timeouts` _(counter)_ + + Number of timeouts produced in the pool when a memory buffer was requested. + +## Internals + +### Uploads + +`s3-sftp-proxy` uses S3 multipart upload (see the [multipart upload overview](https://docs.aws.amazon.com/AmazonS3/latest/dev/mpuoverview.html)) for those +objects whose size is greater than or equal to the `upload_memory_buffer_size` parameter, and S3 put object (see [PutObject](https://docs.aws.amazon.com/AmazonS3/latest/API/API_PutObject.html)) otherwise. +In order to optimize uploads to S3 and reduce the amount of memory needed, `s3-sftp-proxy` uses internal memory buffers (called the memory pool internally). The size of each buffer is defined by `upload_memory_buffer_size`, while the total number is defined by `upload_memory_buffer_pool_size`. Since the pool can be filled completely (all available buffers in use), `upload_memory_buffer_pool_timeout` defines how long an upload may wait for a memory buffer before an error is raised. + +Finally, in order to upload parts to S3 concurrently, several upload workers are started. The number of workers is defined by `upload_workers_count`. + +Given the above, the maximum amount of memory used internally for upload buffers can be calculated as `upload_memory_buffer_size * upload_memory_buffer_pool_size` (50 MB with the default values).
This amount of memory is considerably lower than storing the entire file in memory. However, if the pool +is full, an error will be raised and the file will not be uploaded. This kind of error can be easily detected via the `sftp_memory_buffer_pool_timeouts` metric. + +As an example, imagine you want to upload a 12 MB file (assuming the default value for `upload_memory_buffer_size`, which is 5 MB) using the `sftp` tool. This tool uploads 32 KB chunks in parallel, so chunks arrive at the server out of order. When the first chunk is received on the server, `s3-sftp-proxy` gets a memory buffer from the pool and writes the data at the corresponding offset. When the buffer is full (5 MB are present on the server), a [CreateMultipartUpload](https://docs.aws.amazon.com/AmazonS3/latest/API/API_CreateMultipartUpload.html) request is sent to S3 and an upload is enqueued for the workers. One upload worker will take this upload from the queue, upload its content to S3 using an [UploadPart](https://docs.aws.amazon.com/AmazonS3/latest/API/API_UploadPart.html) request, and return the memory buffer to the pool (releasing it). Meanwhile, more data from the client is received and stored in a different buffer. When the client has sent the entire file, any pending data is uploaded to S3 via UploadPart. Finally, when all data is present on S3, a [CompleteMultipartUpload](https://docs.aws.amazon.com/AmazonS3/latest/API/API_CompleteMultipartUpload.html) request is sent to S3 to finish the upload. diff --git a/bucketio.go b/bucketio.go index 8b4a560..0f421e9 100644 --- a/bucketio.go +++ b/bucketio.go @@ -1,7 +1,6 @@ package main import ( - "bytes" "context" "fmt" "io" @@ -14,6 +13,7 @@ import ( aws_session "github.com/aws/aws-sdk-go/aws/session" aws_s3 "github.com/aws/aws-sdk-go/service/s3" "github.com/pkg/sftp" + "github.com/prometheus/client_golang/prometheus" // s3crypto "github.com/aws/aws-sdk-go/service/s3/s3crypto" ) @@ -146,74 +146,11 @@ func (oor *S3GetObjectOutputReader) ReadAt(buf []byte, off int64) (int, error) { } copy(buf[i:], oor.spooled[s:be]) return be - s, nil - } else { - return 0, io.EOF } + return 0, io.EOF } } -type S3PutObjectWriter struct { - Ctx context.Context - Bucket string - Key Path - S3 *aws_s3.S3 - ServerSideEncryption *ServerSideEncryptionConfig - Log interface { - DebugLogger - ErrorLogger - } - MaxObjectSize int64 - Info *PhantomObjectInfo - PhantomObjectMap *PhantomObjectMap - mtx sync.Mutex - writer *BytesWriter -} - -func (oow *S3PutObjectWriter) Close() error { - F(oow.Log.Debug, "S3PutObjectWriter.Close") - oow.mtx.Lock() - defer oow.mtx.Unlock() - phInfo := oow.Info.GetOne() - oow.PhantomObjectMap.RemoveByInfoPtr(oow.Info) - key := phInfo.Key.String() - sse := oow.ServerSideEncryption - F(oow.Log.Debug, "PutObject(Bucket=%s, Key=%s, Sse=%v)", oow.Bucket, key, sse) - _, err := oow.S3.PutObject( - &aws_s3.PutObjectInput{ - ACL: &aclPrivate, - Body: bytes.NewReader(oow.writer.Bytes()), - Bucket: &oow.Bucket, - Key: &key, - ServerSideEncryption: sseTypes[sse.Type], - SSECustomerAlgorithm: nilIfEmpty(sse.CustomerAlgorithm()), - SSECustomerKey: nilIfEmpty(sse.CustomerKey), - SSECustomerKeyMD5: nilIfEmpty(sse.CustomerKeyMD5), - SSEKMSKeyId: nilIfEmpty(sse.KMSKeyId), - }, - ) - if err != nil { - oow.Log.Debug("=> ", err) - F(oow.Log.Error, "failed to put object: %s", err.Error()) - } else { - oow.Log.Debug("=> OK") - } - return nil -} - -func (oow *S3PutObjectWriter) WriteAt(buf []byte, off int64) (int, error) { - oow.mtx.Lock() - defer oow.mtx.Unlock() - if oow.MaxObjectSize >= 0 { - if 
int64(len(buf))+off > oow.MaxObjectSize { - return 0, fmt.Errorf("file too large: maximum allowed size is %d bytes", oow.MaxObjectSize) - } - } - F(oow.Log.Debug, "len(buf)=%d, off=%d", len(buf), off) - n, err := oow.writer.WriteAt(buf, off) - oow.Info.SetSize(oow.writer.Size()) - return n, err -} - type ObjectFileInfo struct { _Name string _LastModified time.Time @@ -300,12 +237,16 @@ func aclToMode(owner *aws_s3.Owner, grants []*aws_s3.Grant) os.FileMode { } func (sol *S3ObjectLister) ListAt(result []os.FileInfo, o int64) (int, error) { + lSuccess := prometheus.Labels{"method": "Ls", "status": "success"} + lFailure := prometheus.Labels{"method": "Ls", "status": "failure"} _o, err := castInt64ToInt(o) if err != nil { + mOperationStatus.With(lFailure).Inc() return 0, err } if _o < sol.spoolOffset { + mOperationStatus.With(lFailure).Inc() return 0, fmt.Errorf("supplied position is out of range") } @@ -327,6 +268,7 @@ func (sol *S3ObjectLister) ListAt(result []os.FileInfo, o int64) (int, error) { if sol.noMore { if i == 0 { + mOperationStatus.With(lSuccess).Inc() return 0, io.EOF } else { return i, nil @@ -382,6 +324,7 @@ func (sol *S3ObjectLister) ListAt(result []os.FileInfo, o int64) (int, error) { ) if err != nil { sol.Debug("=> ", err) + mOperationStatus.With(lFailure).Inc() return i, err } F(sol.Debug, "=> { CommonPrefixes=len(%d), Contents=len(%d) }", len(out.CommonPrefixes), len(out.Contents)) @@ -438,16 +381,21 @@ type S3ObjectStat struct { func (sos *S3ObjectStat) ListAt(result []os.FileInfo, o int64) (int, error) { F(sos.Debug, "S3ObjectStat.ListAt: len(result)=%d offset=%d", len(result), o) + lFailure := prometheus.Labels{"method": "Stat", "status": "failure"} + lNoObject := prometheus.Labels{"method": "Stat", "status": "noSuchObject"} _o, err := castInt64ToInt(o) if err != nil { + mOperationStatus.With(lFailure).Inc() return 0, err } if len(result) == 0 { + mOperationStatus.With(lFailure).Inc() return 0, nil } if _o > 0 { + mOperationStatus.With(lFailure).Inc() return 0, fmt.Errorf("supplied position is out of range") } @@ -514,6 +462,7 @@ func (sos *S3ObjectStat) ListAt(result []os.FileInfo, o int64) (int, error) { ) if err != nil || (!sos.Root && len(out.CommonPrefixes) == 0) { sos.Debug("=> ", err) + mOperationStatus.With(lNoObject).Inc() return 0, os.ErrNotExist } F(sos.Debug, "=> { CommonPrefixes=len(%d), Contents=len(%d) }", len(out.CommonPrefixes), len(out.Contents)) @@ -535,34 +484,35 @@ type S3BucketIO struct { ReaderLookbackBufferSize int ReaderMinChunkSize int ListerLookbackBufferSize int + UploadMemoryBufferPool *MemoryBufferPool PhantomObjectMap *PhantomObjectMap Perms Perms ServerSideEncryption *ServerSideEncryptionConfig Now func() time.Time Log interface { ErrorLogger + WarnLogger DebugLogger } + UserInfo *UserInfo + UploadChan chan<- *S3PartToUpload } func buildKey(s3b *S3Bucket, path string) Path { return s3b.KeyPrefix.Join(SplitIntoPath(path)) } -func buildPath(s3b *S3Bucket, key string) (string, bool) { - _key := SplitIntoPath(key) - if !_key.IsPrefixed(s3b.KeyPrefix) { - return "", false - } - return "/" + _key[len(s3b.KeyPrefix):].String(), true -} - func (s3io *S3BucketIO) Fileread(req *sftp.Request) (io.ReaderAt, error) { + lSuccess := prometheus.Labels{"method": req.Method, "status": "success"} + lFailure := prometheus.Labels{"method": req.Method, "status": "failure"} if !s3io.Perms.Readable { + mOperationStatus.With(lFailure).Inc() return nil, fmt.Errorf("read operation not allowed as per configuration") } sess, err := aws_session.NewSession() if err != 
nil { + mOperationStatus.With(lFailure).Inc() + mAWSSessionError.Inc() return nil, err } s3 := s3io.Bucket.S3(sess) @@ -570,11 +520,13 @@ func (s3io *S3BucketIO) Fileread(req *sftp.Request) (io.ReaderAt, error) { phInfo := s3io.PhantomObjectMap.Get(key) if phInfo != nil { - return bytes.NewReader(phInfo.Opaque.(*S3PutObjectWriter).writer.Bytes()), nil + mOperationStatus.With(lFailure).Inc() + return nil, fmt.Errorf("trying to download an uploading file") } keyStr := key.String() ctx := combineContext(s3io.Ctx, req.Context()) + F(s3io.Log.Warn, "Audit: User %s downloaded file \"%s\"", s3io.UserInfo.String(), keyStr) F(s3io.Log.Debug, "GetObject(Bucket=%s, Key=%s)", s3io.Bucket.Bucket, keyStr) sse := s3io.ServerSideEncryption goo, err := s3.GetObjectWithContext( @@ -588,23 +540,30 @@ func (s3io *S3BucketIO) Fileread(req *sftp.Request) (io.ReaderAt, error) { }, ) if err != nil { + mOperationStatus.With(lFailure).Inc() return nil, err } - return &S3GetObjectOutputReader{ + oor := &S3GetObjectOutputReader{ Ctx: ctx, Goo: goo, Log: s3io.Log, Lookback: s3io.ReaderLookbackBufferSize, MinChunkSize: s3io.ReaderMinChunkSize, - }, nil + } + mOperationStatus.With(lSuccess).Inc() + return oor, nil } func (s3io *S3BucketIO) Filewrite(req *sftp.Request) (io.WriterAt, error) { + lFailure := prometheus.Labels{"method": req.Method, "status": "failure"} if !s3io.Perms.Writable { + mOperationStatus.With(lFailure).Inc() return nil, fmt.Errorf("write operation not allowed as per configuration") } sess, err := aws_session.NewSession() if err != nil { + mOperationStatus.With(lFailure).Inc() + mAWSSessionError.Inc() return nil, err } maxObjectSize := s3io.Bucket.MaxObjectSize @@ -617,43 +576,53 @@ func (s3io *S3BucketIO) Filewrite(req *sftp.Request) (io.WriterAt, error) { Size: 0, LastModified: s3io.Now(), } - F(s3io.Log.Debug, "S3PutObjectWriter.New(key=%s)", key) - oow := &S3PutObjectWriter{ - Ctx: combineContext(s3io.Ctx, req.Context()), - Bucket: s3io.Bucket.Bucket, - Key: key, - S3: s3io.Bucket.S3(sess), - ServerSideEncryption: s3io.ServerSideEncryption, - Log: s3io.Log, - MaxObjectSize: maxObjectSize, - PhantomObjectMap: s3io.PhantomObjectMap, - Info: info, - writer: NewBytesWriter(), - } - info.Opaque = oow + F(s3io.Log.Warn, "Audit: User %s uploaded file \"%s\"", s3io.UserInfo.String(), key) + F(s3io.Log.Debug, "S3MultipartUploadWriter.New(key=%s)", key) + oow := &S3MultipartUploadWriter{ + Ctx: combineContext(s3io.Ctx, req.Context()), + Bucket: s3io.Bucket.Bucket, + Key: key, + S3: s3io.Bucket.S3(sess), + ServerSideEncryption: s3io.ServerSideEncryption, + Log: s3io.Log, + MaxObjectSize: maxObjectSize, + UploadMemoryBufferPool: s3io.UploadMemoryBufferPool, + PhantomObjectMap: s3io.PhantomObjectMap, + Info: info, + RequestMethod: req.Method, + UploadChan: s3io.UploadChan, + } s3io.PhantomObjectMap.Add(info) return oow, nil } func (s3io *S3BucketIO) Filecmd(req *sftp.Request) error { + lSuccess := prometheus.Labels{"method": req.Method, "status": "success"} + lFailure := prometheus.Labels{"method": req.Method, "status": "failure"} + lIgnored := prometheus.Labels{"method": req.Method, "status": "ignored"} switch req.Method { case "Rename": if !s3io.Perms.Writable { + mOperationStatus.With(lFailure).Inc() return fmt.Errorf("write operation not allowed as per configuration") } src := buildKey(s3io.Bucket, req.Filepath) dest := buildKey(s3io.Bucket, req.Target) if s3io.PhantomObjectMap.Rename(src, dest) { + mOperationStatus.With(lIgnored).Inc() return nil } sess, err := aws_session.NewSession() if err != nil { + 
mOperationStatus.With(lFailure).Inc() + mAWSSessionError.Inc() return err } srcStr := src.String() destStr := dest.String() copySource := s3io.Bucket.Bucket + "/" + srcStr sse := s3io.ServerSideEncryption + F(s3io.Log.Warn, "Audit: User %s renamed \"%s\" to \"%s\"", s3io.UserInfo.String(), srcStr, destStr) F(s3io.Log.Debug, "CopyObject(Bucket=%s, Key=%s, CopySource=%s, Sse=%v)", s3io.Bucket.Bucket, destStr, copySource, sse.Type) _, err = s3io.Bucket.S3(sess).CopyObjectWithContext( combineContext(s3io.Ctx, req.Context()), @@ -671,6 +640,7 @@ func (s3io *S3BucketIO) Filecmd(req *sftp.Request) error { ) if err != nil { s3io.Log.Debug("=> ", err) + mOperationStatus.With(lFailure).Inc() return err } F(s3io.Log.Debug, "DeleteObject(Bucket=%s, Key=%s)", s3io.Bucket.Bucket, srcStr) @@ -683,21 +653,28 @@ func (s3io *S3BucketIO) Filecmd(req *sftp.Request) error { ) if err != nil { s3io.Log.Debug("=> ", err) + mOperationStatus.With(lFailure).Inc() return err } + mOperationStatus.With(lSuccess).Inc() case "Remove": if !s3io.Perms.Writable { + mOperationStatus.With(lFailure).Inc() return fmt.Errorf("write operation not allowed as per configuration") } key := buildKey(s3io.Bucket, req.Filepath) if s3io.PhantomObjectMap.Remove(key) != nil { + mOperationStatus.With(lIgnored).Inc() return nil } sess, err := aws_session.NewSession() if err != nil { + mOperationStatus.With(lFailure).Inc() + mAWSSessionError.Inc() return err } keyStr := key.String() + F(s3io.Log.Warn, "Audit: User %s deleted file \"%s\"", s3io.UserInfo.String(), key) F(s3io.Log.Debug, "DeleteObject(Bucket=%s, Key=%s)", s3io.Bucket.Bucket, key) _, err = s3io.Bucket.S3(sess).DeleteObjectWithContext( combineContext(s3io.Ctx, req.Context()), @@ -708,22 +685,80 @@ func (s3io *S3BucketIO) Filecmd(req *sftp.Request) error { ) if err != nil { s3io.Log.Debug("=> ", err) + mOperationStatus.With(lFailure).Inc() + return err + } + mOperationStatus.With(lSuccess).Inc() + case "Mkdir": + if !s3io.Perms.Writable { + mOperationStatus.With(lFailure).Inc() + return fmt.Errorf("write operation not allowed as per configuration") + } + key := buildKey(s3io.Bucket, req.Filepath) + keyStr := fmt.Sprintf("%s/", key.String()) + sess, err := aws_session.NewSession() + if err != nil { + mOperationStatus.With(lFailure).Inc() + mAWSSessionError.Inc() + return err + } + F(s3io.Log.Debug, "Mkdir(Bucket=%s, Key=%s)", s3io.Bucket.Bucket, keyStr) + _, err = s3io.Bucket.S3(sess).PutObject( + &aws_s3.PutObjectInput{ + Bucket: &s3io.Bucket.Bucket, + Key: &keyStr, + }, + ) + if err != nil { + s3io.Log.Debug("=> ", err) + mOperationStatus.With(lFailure).Inc() + return err + } + mOperationStatus.With(lSuccess).Inc() + case "Rmdir": + if !s3io.Perms.Writable { + mOperationStatus.With(lFailure).Inc() + return fmt.Errorf("write operation not allowed as per configuration") + } + key := buildKey(s3io.Bucket, req.Filepath) + keyStr := fmt.Sprintf("%s/", key.String()) + sess, err := aws_session.NewSession() + if err != nil { + mOperationStatus.With(lFailure).Inc() + mAWSSessionError.Inc() + return err + } + F(s3io.Log.Debug, "Rmdir(Bucket=%s, Key=%s)", s3io.Bucket.Bucket, keyStr) + _, err = s3io.Bucket.S3(sess).DeleteObject( + &aws_s3.DeleteObjectInput{ + Bucket: &s3io.Bucket.Bucket, + Key: &keyStr, + }, + ) + if err != nil { + s3io.Log.Debug("=> ", err) + mOperationStatus.With(lFailure).Inc() return err } + mOperationStatus.With(lSuccess).Inc() } return nil } func (s3io *S3BucketIO) Filelist(req *sftp.Request) (sftp.ListerAt, error) { + lPermErr := prometheus.Labels{"method": req.Method} 
sess, err := aws_session.NewSession() if err != nil { + mAWSSessionError.Inc() return nil, err } switch req.Method { case "Stat", "ReadLink": if !s3io.Perms.Readable && !s3io.Perms.Listable { + mPermissionsError.With(lPermErr).Inc() return nil, fmt.Errorf("stat operation not allowed as per configuration") } + F(s3io.Log.Warn, "Audit: User %s read path stats \"%s\"", s3io.UserInfo.String(), req.Filepath) key := buildKey(s3io.Bucket, req.Filepath) return &S3ObjectStat{ DebugLogger: s3io.Log, @@ -736,8 +771,10 @@ func (s3io *S3BucketIO) Filelist(req *sftp.Request) (sftp.ListerAt, error) { }, nil case "List": if !s3io.Perms.Listable { + mPermissionsError.With(lPermErr).Inc() return nil, fmt.Errorf("listing operation not allowed as per configuration") } + F(s3io.Log.Warn, "Audit: User %s listed path \"%s\"", s3io.UserInfo.String(), req.Filepath) return &S3ObjectLister{ DebugLogger: s3io.Log, Ctx: combineContext(s3io.Ctx, req.Context()), @@ -748,6 +785,7 @@ func (s3io *S3BucketIO) Filelist(req *sftp.Request) (sftp.ListerAt, error) { PhantomObjectMap: s3io.PhantomObjectMap, }, nil default: + mPermissionsError.With(lPermErr).Inc() return nil, fmt.Errorf("unsupported method: %s", req.Method) } } diff --git a/config.go b/config.go index 3916120..749d863 100644 --- a/config.go +++ b/config.go @@ -2,17 +2,23 @@ package main import ( "fmt" - "github.com/BurntSushi/toml" - "github.com/pkg/errors" "io/ioutil" "net/url" + "time" + + "github.com/BurntSushi/toml" + "github.com/pkg/errors" ) var ( - minReaderLookbackBufferSize = 1048576 - minReaderMinChunkSize = 262144 - minListerLookbackBufferSize = 100 - vTrue = true + minReaderLookbackBufferSize = 1048576 + minReaderMinChunkSize = 262144 + minListerLookbackBufferSize = 100 + defaultUploadMemoryBufferSize = 5 * 1024 * 1024 // 5 MB + defaultUploadMemoryBufferPoolSize = 10 + defaultUploadMemoryBufferPoolTimeout = 5 * time.Second + defaultUploadWorkersCount = 2 + vTrue = true ) type URL struct { @@ -24,6 +30,15 @@ func (u *URL) UnmarshalText(text []byte) (err error) { return } +type duration struct { + time.Duration +} + +func (d *duration) UnmarshalText(text []byte) (err error) { + d.Duration, err = time.ParseDuration(string(text)) + return err +} + type AWSCredentialsConfig struct { AWSAccessKeyID string `toml:"aws_access_key_id"` AWSSecretAccessKey string `toml:"aws_secret_access_key"` @@ -41,7 +56,7 @@ type S3BucketConfig struct { BucketUrl *URL `toml:"bucket_url"` Auth string `toml:"auth"` MaxObjectSize *int64 `toml:"max_object_size"` - Readable *bool `toml:"readble"` + Readable *bool `toml:"readable"` Writable *bool `toml:"writable"` Listable *bool `toml:"listable"` ServerSideEncryption ServerSideEncryptionType `toml:"server_side_encryption"` @@ -63,14 +78,20 @@ type AuthConfig struct { } type S3SFTPProxyConfig struct { - Bind string `toml:"bind"` - HostKeyFile string `toml:"host_key_file"` - Banner string `toml:"banner"` - ReaderLookbackBufferSize *int `toml:"reader_lookback_buffer_size"` - ReaderMinChunkSize *int `toml:"reader_min_chunk_size"` - ListerLookbackBufferSize *int `toml:"lister_lookback_buffer_size"` - Buckets map[string]*S3BucketConfig `toml:"buckets"` - AuthConfigs map[string]*AuthConfig `toml:"auth"` + Bind string `toml:"bind"` + HostKeyFile string `toml:"host_key_file"` + Banner string `toml:"banner"` + ReaderLookbackBufferSize *int `toml:"reader_lookback_buffer_size"` + ReaderMinChunkSize *int `toml:"reader_min_chunk_size"` + ListerLookbackBufferSize *int `toml:"lister_lookback_buffer_size"` + UploadMemoryBufferSize *int 
`toml:"upload_memory_buffer_size"` + UploadMemoryBufferPoolSize *int `toml:"upload_memory_buffer_pool_size"` + UploadMemoryBufferPoolTimeout *duration `toml:"upload_memory_buffer_pool_timeout"` + UploadWorkersCount *int `toml:"upload_workers_count"` + Buckets map[string]*S3BucketConfig `toml:"buckets"` + AuthConfigs map[string]*AuthConfig `toml:"auth"` + MetricsBind string `toml:"metrics_bind"` + MetricsEndpoint string `toml:"metrics_endpoint"` } func validateAndFixupBucketConfig(bCfg *S3BucketConfig) error { @@ -179,6 +200,22 @@ func ReadConfig(tomlStr string) (*S3SFTPProxyConfig, error) { return nil, fmt.Errorf("lister_lookback_buffer_size must be equal to or greater than %d", minListerLookbackBufferSize) } + if cfg.UploadMemoryBufferSize == nil { + cfg.UploadMemoryBufferSize = &defaultUploadMemoryBufferSize + } + + if cfg.UploadMemoryBufferPoolSize == nil { + cfg.UploadMemoryBufferPoolSize = &defaultUploadMemoryBufferPoolSize + } + + if cfg.UploadMemoryBufferPoolTimeout == nil { + cfg.UploadMemoryBufferPoolTimeout = &duration{defaultUploadMemoryBufferPoolTimeout} + } + + if cfg.UploadWorkersCount == nil { + cfg.UploadWorkersCount = &defaultUploadWorkersCount + } + for name, bCfg := range cfg.Buckets { err := validateAndFixupBucketConfig(bCfg) if err != nil { diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..c1f6d35 --- /dev/null +++ b/go.mod @@ -0,0 +1,16 @@ +module github.com/moriyoshi/s3-sftp-proxy + +go 1.12 + +require ( + github.com/BurntSushi/toml v0.3.1 + github.com/aws/aws-sdk-go v1.25.3 + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.1 // indirect + github.com/pkg/errors v0.8.1 + github.com/pkg/sftp v1.10.1 + github.com/prometheus/client_golang v1.3.0 + github.com/sirupsen/logrus v1.4.2 + github.com/stretchr/testify v1.4.0 + golang.org/x/crypto v0.0.0-20191001170739-f9e2070545dc +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..8c95faa --- /dev/null +++ b/go.sum @@ -0,0 +1,103 @@ +github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= +github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= +github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= +github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= +github.com/aws/aws-sdk-go v1.25.3 h1:uM16hIw9BotjZKMZlX05SN2EFtaWfi/NonPKIARiBLQ= +github.com/aws/aws-sdk-go v1.25.3/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= +github.com/aws/aws-sdk-go v1.26.8 h1:W+MPuCFLSO/itZkZ5GFOui0YC1j3lZ507/m5DFPtzE4= +github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= +github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY= +github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.0/go.mod 
h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= +github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= +github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= +github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= +github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af h1:pmfjZENx5imkbgOkpRUYLnmbU7UEFbjtDA2hxJ1ichM= +github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= +github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= +github.com/json-iterator/go v1.1.8/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8= +github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= +github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= +github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= +github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= +github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/sftp v1.10.1 h1:VasscCm72135zRysgrJDKsntdmPN+OuU3+nnHYA9wyc= +github.com/pkg/sftp v1.10.1/go.mod h1:lYOWFsE0bwd1+KfKJaKeuokY15vzFx25BLbzYYoAxZI= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod 
h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= +github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= +github.com/prometheus/client_golang v1.3.0 h1:miYCvYqFXtl/J9FIy8eNpBfYthAEFg+Ys0XyUVEcDsc= +github.com/prometheus/client_golang v1.3.0/go.mod h1:hJaj2vgQTGQmVCsAACORcieXFeDPbaTKGT+JTgUa3og= +github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= +github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/prometheus/client_model v0.1.0 h1:ElTg5tNp4DqfV7UQjDqv2+RJlNzsDtvNAWccbItceIE= +github.com/prometheus/client_model v0.1.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= +github.com/prometheus/common v0.7.0 h1:L+1lyG48J1zAQXA3RBX/nG/B3gjlHq0zTt2tlbJLyCY= +github.com/prometheus/common v0.7.0/go.mod h1:DjGbpBbp5NYNiECxcL/VnbXCCaQpKd3tt26CguLLsqA= +github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= +github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= +github.com/prometheus/procfs v0.0.8 h1:+fpWZdT24pJBiqJdAwYBjPSk+5YmQzYNPYzQsdzLkt8= +github.com/prometheus/procfs v0.0.8/go.mod h1:7Qr8sr6344vo1JqZ6HhLceV9o3AJ1Ff+GxbHq6oeK9A= +github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= +github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20191001170739-f9e2070545dc h1:KyTYo8xkh/2WdbFLUyQwBS0Jfn3qfZ9QmuPbok2oENE= +golang.org/x/crypto v0.0.0-20191001170739-f9e2070545dc/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync 
v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894 h1:Cz4ceDQGXuKRnVBDTS23GTn/pU5OE2C0WrNTOYK1Uuc= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191220142924-d4481acd189f h1:68K/z8GLUxV76xGSqwTWw2gyk/jwn79LUL43rES2g8o= +golang.org/x/sys v0.0.0-20191220142924-d4481acd189f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/logging.go b/logging.go index 361c8b2..24f6856 100644 --- a/logging.go +++ b/logging.go @@ -8,6 +8,10 @@ type InfoLogger interface { Info(args ...interface{}) } +type WarnLogger interface { + Warn(args ...interface{}) +} + type ErrorLogger interface { Error(args ...interface{}) } diff --git a/main.go b/main.go index 721b8b2..f512d0b 100644 --- a/main.go +++ b/main.go @@ -7,11 +7,12 @@ import ( "fmt" "io/ioutil" "net" + "net/http" "os" "os/signal" - "time" "github.com/pkg/errors" + "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" ) @@ -157,24 +158,52 @@ func main() { defer lsnr.Close() logger.Info("Listen on ", _bind) + metricsBind := cfg.MetricsBind + if metricsBind == "" { + metricsBind = ":2112" + } + + metricsEndpoint := cfg.MetricsEndpoint + if metricsEndpoint == "" { + metricsEndpoint = "/metrics" + } + + http.Handle(metricsEndpoint, promhttp.Handler()) + + go func() { + http.ListenAndServe(metricsBind, nil) + }() + + logger.Info("Metrics listen on ", metricsBind, metricsEndpoint) + ctx, cancel := context.WithCancel(context.Background()) - defer cancel() sigChan := make(chan os.Signal) signal.Notify(sigChan, os.Interrupt) + uploadWorkers := NewS3UploadWorkers(ctx, *cfg.UploadWorkersCount, logger) + uploadChan := uploadWorkers.Start() + + defer func() { + cancel() + uploadWorkers.WaitForCompletion() + }() + errChan := make(chan error) go func() { - errChan <- (&Server{ - S3Buckets: buckets, - ServerConfig: sCfg, - Log: logger, - ReaderLookbackBufferSize: *cfg.ReaderLookbackBufferSize, - ReaderMinChunkSize: *cfg.ReaderMinChunkSize, - ListerLookbackBufferSize: *cfg.ListerLookbackBufferSize, - PhantomObjectMap: NewPhantomObjectMap(), - Now: time.Now, - }).RunListenerEventLoop(ctx, lsnr.(*net.TCPListener)) + errChan <- NewServer( + ctx, + buckets, + sCfg, + logger, + *cfg.ReaderLookbackBufferSize, + *cfg.ReaderMinChunkSize, + *cfg.ListerLookbackBufferSize, + *cfg.UploadMemoryBufferSize, + *cfg.UploadMemoryBufferPoolSize, + (*cfg.UploadMemoryBufferPoolTimeout).Duration, + uploadChan, + 
).RunListenerEventLoop(ctx, lsnr.(*net.TCPListener)) }() outer: diff --git a/memory_buffer_pool.go b/memory_buffer_pool.go new file mode 100644 index 0000000..325ce3f --- /dev/null +++ b/memory_buffer_pool.go @@ -0,0 +1,55 @@ +package main + +import ( + "context" + "fmt" + "sync/atomic" + "time" +) + +// MemoryBufferPool pool of memory buffers +// Used to reduce the GC generated when a memory buffer of the same size is needed +type MemoryBufferPool struct { + BufSize int + Used int32 + ch chan []byte + ctx context.Context + timeout time.Duration +} + +// NewMemoryBufferPool creates a new partition pool giving its size +func NewMemoryBufferPool(ctx context.Context, bufSize int, poolSize int, timeout time.Duration) *MemoryBufferPool { + mbp := &MemoryBufferPool{ + BufSize: bufSize, + ch: make(chan []byte, poolSize), + ctx: ctx, + timeout: timeout, + } + mMemoryBufferPoolMax.Add(float64(poolSize)) + for ; poolSize > 0; poolSize-- { + mbp.ch <- make([]byte, mbp.BufSize) + } + return mbp +} + +// Get gets a buffer from the pool +func (mbp *MemoryBufferPool) Get() ([]byte, error) { + select { + case <-mbp.ctx.Done(): + return nil, fmt.Errorf("partition pool get canceled") + case <-time.After(mbp.timeout): + mMemoryBufferPoolTimeouts.Inc() + return nil, fmt.Errorf("timeout getting partition from pool") + case res := <-mbp.ch: + mMemoryBufferPoolUsed.Inc() + atomic.AddInt32(&mbp.Used, 1) + return res, nil + } +} + +// Put returns a buffer into the pool +func (mbp *MemoryBufferPool) Put(buf []byte) { + mbp.ch <- buf + atomic.AddInt32(&mbp.Used, -1) + mMemoryBufferPoolUsed.Dec() +} diff --git a/memory_buffer_pool_test.go b/memory_buffer_pool_test.go new file mode 100644 index 0000000..60d3729 --- /dev/null +++ b/memory_buffer_pool_test.go @@ -0,0 +1,38 @@ +package main + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestMemoryBufferPoolGetBasic(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + p := NewMemoryBufferPool(ctx, 10, 1, 5*time.Second) + buf, err := p.Get() + assert.NotNil(t, buf) + assert.NoError(t, err) +} + +func TestMemoryBufferPoolGetTimeout(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + p := NewMemoryBufferPool(ctx, 10, 0, 100*time.Millisecond) + buf, err := p.Get() + assert.Nil(t, buf) + assert.Error(t, err) + assert.Regexp(t, ".*timeout.*", err.Error()) +} + +func TestMemoryBufferPoolGetCanceled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + p := NewMemoryBufferPool(ctx, 10, 0, 10*time.Second) + go cancel() + buf, err := p.Get() + assert.Nil(t, buf) + assert.Error(t, err) + assert.Regexp(t, ".*canceled.*", err.Error()) +} diff --git a/metrics.go b/metrics.go new file mode 100644 index 0000000..57f611f --- /dev/null +++ b/metrics.go @@ -0,0 +1,46 @@ +package main + +import ( + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +var ( + mOperationStatus = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "sftp_operation_status", + Help: "Represents SFTP operation statuses", + }, + []string{"method", "status"}, + ) + mAWSSessionError = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftp_aws_session_error", + Help: "The total number of session errors", + }, + ) + mPermissionsError = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "sftp_permissions_error", + Help: "The total number of permission errors", + }, + []string{"method"}, + ) + 
mUsersConnected = promauto.NewGauge(prometheus.GaugeOpts{ + Name: "sftp_users_connected", + Help: "The number of users connected now", + }, + ) + mMemoryBufferPoolMax = promauto.NewGauge(prometheus.GaugeOpts{ + Name: "sftp_memory_buffer_pool_max", + Help: "The number of maximum memory buffers in the pool", + }, + ) + mMemoryBufferPoolUsed = promauto.NewGauge(prometheus.GaugeOpts{ + Name: "sftp_memory_buffer_pool_used", + Help: "The number of memory buffers used in the pool", + }, + ) + mMemoryBufferPoolTimeouts = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftp_memory_buffer_pool_timeouts", + Help: "The total number of timeouts produced in the pool", + }, + ) +) diff --git a/multipart_upload.go b/multipart_upload.go new file mode 100644 index 0000000..bb637b2 --- /dev/null +++ b/multipart_upload.go @@ -0,0 +1,518 @@ +package main + +import ( + "bytes" + "context" + "fmt" + "sync" + + "github.com/moriyoshi/s3-sftp-proxy/util" + + aws_s3 "github.com/aws/aws-sdk-go/service/s3" + "github.com/aws/aws-sdk-go/service/s3/s3iface" + "github.com/prometheus/client_golang/prometheus" +) + +// S3PartUploadState state in which a part upload is +type S3PartUploadState int + +const ( + // S3PartUploadStateAdding adding data to content + S3PartUploadStateAdding = iota + // S3PartUploadStateFull all data present, ready to be sent + S3PartUploadStateFull + // S3PartUploadStateSent already sent to S3 + S3PartUploadStateSent + // S3PartUploadErrorSending error sending part to S3 + S3PartUploadErrorSending + // S3PartUploadCancelled cancelled due to previous error + S3PartUploadCancelled +) + +// S3PartToUpload S3 part to be uploaded +type S3PartToUpload struct { + // Part content + content []byte + // Part number (starting from 1) + partNumber int64 + // Offset ranges already filled + o *util.OffsetRanges + // S3MultipartUploadWriter that contains this part + uw *S3MultipartUploadWriter + // Mutex to avoid problems accessing to the same part upload + mtx sync.Mutex + // State to know how to treat this part + state S3PartUploadState +} + +func (part *S3PartToUpload) getContent() ([]byte, error) { + end := part.o.GetMaxValidOffset() + if end == -1 { + return nil, fmt.Errorf("Trying to obtain content of incomplete part %d", part.partNumber) + } + return part.content[0:end], nil +} + +func (part *S3PartToUpload) copy(buf []byte, start int64, end int64) { + copy(part.content[start:end], buf) + part.o.Add(start, end) +} + +func (part *S3PartToUpload) isFull() bool { + return part.o.IsFull() +} + +// S3MultipartUploadWriter uploads multiple parts to S3 having a writer interface +type S3MultipartUploadWriter struct { + Ctx context.Context + Bucket string + Key Path + S3 s3iface.S3API + ServerSideEncryption *ServerSideEncryptionConfig + Log interface { + DebugLogger + WarnLogger + ErrorLogger + } + MaxObjectSize int64 + UploadMemoryBufferPool *MemoryBufferPool + Info *PhantomObjectInfo + PhantomObjectMap *PhantomObjectMap + RequestMethod string + mtx sync.Mutex + completedParts []*aws_s3.CompletedPart + parts []*S3PartToUpload + multiPartUploadID *string + err error + uploadGroup sync.WaitGroup + UploadChan chan<- *S3PartToUpload +} + +// Close closes multipart upload writer +func (u *S3MultipartUploadWriter) Close() error { + F(u.Log.Debug, "S3MultipartUploadWriter.Close") + + u.PhantomObjectMap.RemoveByInfoPtr(u.Info) + + u.mtx.Lock() + defer u.mtx.Unlock() + + err := u.err + if err == nil { + // Only one part -> use PutObject + if len(u.parts) == 1 && u.multiPartUploadID == nil { + part := u.parts[0] + + 
var content []byte + content, err = part.getContent() + if err == nil { + err = u.s3PutObject(content) + u.UploadMemoryBufferPool.Put(part.content) + + if err == nil { + part.state = S3PartUploadStateSent + } else { + part.state = S3PartUploadErrorSending + } + } else { + u.UploadMemoryBufferPool.Put(part.content) + part.state = S3PartUploadErrorSending + } + } else { + // More than 1 part -> MultiPartUpload used before, we have to send latest part, wait until all parts will be uploaded and then complete the job + u.mtx.Unlock() + + err = u.enqueueUpload(u.parts[len(u.parts)-1]) + u.uploadGroup.Wait() + + u.mtx.Lock() + if err == nil { + pending := u.closePartsInStateAdding() + if pending > 0 { + err = fmt.Errorf("Closing upload and having %d pending parts to fill", pending) + } else { + err = u.err + if err == nil { + err = u.s3CompleteMultipartUpload() + } + } + } + } + } + + if err != nil { + u.s3AbortMultipartUpload() + u.closePartsInStateAdding() + mOperationStatus.With(prometheus.Labels{"method": u.RequestMethod, "status": "failure"}).Inc() + } else { + mOperationStatus.With(prometheus.Labels{"method": u.RequestMethod, "status": "success"}).Inc() + } + return err +} + +// WriteAt stores on memory the data sent to be uploaded and uploads it when a part +// is completed +func (u *S3MultipartUploadWriter) WriteAt(buf []byte, off int64) (int, error) { + pending := int64(len(buf)) + offFinal := off + pending + partSize := int64(u.UploadMemoryBufferPool.BufSize) + partNumberInitial := int(off / partSize) + partOffsetInitial := off % partSize + bufOffset := int64(0) + + var err error + u.mtx.Lock() + err = u.err + if err == nil && u.MaxObjectSize >= 0 && offFinal > u.MaxObjectSize { + err = fmt.Errorf("file too large: maximum allowed size is %d bytes", u.MaxObjectSize) + } + + if err != nil { + F(u.Log.Debug, "Error on WriteAt: %s", err.Error()) + u.s3AbortMultipartUpload() + u.closePartsInStateAdding() + u.err = err + u.mtx.Unlock() + mOperationStatus.With(prometheus.Labels{"method": u.RequestMethod, "status": "failure"}).Inc() + return 0, err + } + + partNumberFinal := int((off + pending - 1) / partSize) + + F(u.Log.Debug, "len(buf)=%d, off=%d, partNumberInitial=%d, partOffsetInitial=%d", len(buf), off, partNumberInitial, partOffsetInitial) + u.Info.SetSizeIfGreater(offFinal) + if len(u.parts) <= partNumberFinal { + newParts := make([]*S3PartToUpload, partNumberFinal+1) + copy(newParts, u.parts) + u.parts = newParts + } + u.mtx.Unlock() + + partNumber := partNumberInitial + partOffset := partOffsetInitial + for pending > 0 { + u.mtx.Lock() + part := u.parts[partNumber] + if part == nil { + F(u.Log.Debug, "Getting space from partition pool for part number: %d", partNumber) + buf, err := u.UploadMemoryBufferPool.Get() + if err != nil { + F(u.Log.Debug, "Error getting a partition pool: %s", err.Error()) + u.s3AbortMultipartUpload() + u.closePartsInStateAdding() + u.err = err + u.mtx.Unlock() + mOperationStatus.With(prometheus.Labels{"method": u.RequestMethod, "status": "failure"}).Inc() + return 0, err + } + + part = &S3PartToUpload{ + content: buf, + o: util.NewOffsetRanges(partSize), + uw: u, + state: S3PartUploadStateAdding, + partNumber: int64(partNumber + 1), + } + u.parts[partNumber] = part + } + u.mtx.Unlock() + + partOffsetFinal := partOffset + pending + if partOffsetFinal > partSize { + partOffsetFinal = partSize + } + partCopied := partOffsetFinal - partOffset + + part.mtx.Lock() + if part.state < S3PartUploadStateFull { + part.copy(buf[bufOffset:bufOffset+partCopied], partOffset, 
partOffsetFinal) + if part.isFull() { + err = u.enqueueUpload(part) + if err != nil { + part.mtx.Unlock() + u.mtx.Lock() + F(u.Log.Debug, "Error enqueuing a part upload: %s", err.Error()) + u.s3AbortMultipartUpload() + u.closePartsInStateAdding() + u.err = err + u.mtx.Unlock() + mOperationStatus.With(prometheus.Labels{"method": u.RequestMethod, "status": "failure"}).Inc() + return 0, err + } + } + } else { + F(u.Log.Debug, "Trying to add more data to a part already full") + } + part.mtx.Unlock() + partNumber++ + pending -= partCopied + bufOffset += partCopied + partOffset = 0 + } + return len(buf), nil +} + +func (u *S3MultipartUploadWriter) enqueueUpload(part *S3PartToUpload) error { + if part.state < S3PartUploadStateFull { + u.mtx.Lock() + if u.multiPartUploadID == nil { + if err := u.s3CreateMultipartUpload(); err != nil { + u.mtx.Unlock() + return err + } + } + u.mtx.Unlock() + + F(u.Log.Debug, "Enqueuing part %d to be uploaded", part.partNumber) + part.state = S3PartUploadStateFull + u.uploadGroup.Add(1) + select { + case <-u.Ctx.Done(): + return fmt.Errorf("Enqueue upload cancelled") + case u.UploadChan <- part: + } + } + return nil +} + +func (u *S3MultipartUploadWriter) closePartsInStateAdding() int { + pending := 0 + if u.parts != nil { + for i := len(u.parts) - 1; i >= 0; i-- { + part := u.parts[i] + if part != nil { + part.mtx.Lock() + if part.state == S3PartUploadStateAdding { + u.UploadMemoryBufferPool.Put(part.content) + part.state = S3PartUploadCancelled + pending++ + } + part.mtx.Unlock() + } + } + } + return pending +} + +// S3 related actions +func (u *S3MultipartUploadWriter) s3CreateMultipartUpload() error { + key := u.Info.GetOne().Key.String() + sse := u.ServerSideEncryption + F(u.Log.Debug, "CreateMultipartUpload(Bucket=%s, Key=%s, Sse=%v)", u.Bucket, key, sse) + + params := &aws_s3.CreateMultipartUploadInput{ + ACL: &aclPrivate, + Bucket: &u.Bucket, + Key: &key, + ServerSideEncryption: sseTypes[sse.Type], + SSECustomerAlgorithm: nilIfEmpty(sse.CustomerAlgorithm()), + SSECustomerKey: nilIfEmpty(sse.CustomerKey), + SSECustomerKeyMD5: nilIfEmpty(sse.CustomerKeyMD5), + SSEKMSKeyId: nilIfEmpty(sse.KMSKeyId), + } + + resp, err := u.S3.CreateMultipartUploadWithContext(u.Ctx, params) + if err != nil { + u.Log.Debug("=> ", err) + F(u.Log.Error, "failed to create multipart upload: %s", err.Error()) + return err + } + F(u.Log.Debug, "=> OK, uploadId=%s", *resp.UploadId) + u.multiPartUploadID = resp.UploadId + return nil +} + +func (u *S3MultipartUploadWriter) s3PutObject(content []byte) error { + key := u.Info.GetOne().Key.String() + sse := u.ServerSideEncryption + F(u.Log.Debug, "PutObject(Bucket=%s, Key=%s, Sse=%v)", u.Bucket, key, sse) + + params := &aws_s3.PutObjectInput{ + ACL: &aclPrivate, + Body: bytes.NewReader(content), + Bucket: &u.Bucket, + Key: &key, + ServerSideEncryption: sseTypes[sse.Type], + SSECustomerAlgorithm: nilIfEmpty(sse.CustomerAlgorithm()), + SSECustomerKey: nilIfEmpty(sse.CustomerKey), + SSECustomerKeyMD5: nilIfEmpty(sse.CustomerKeyMD5), + SSEKMSKeyId: nilIfEmpty(sse.KMSKeyId), + } + if _, err := u.S3.PutObjectWithContext(u.Ctx, params); err != nil { + u.Log.Debug("=> ", err) + F(u.Log.Error, "failed to put object: %s", err.Error()) + return err + } + u.Log.Debug("=> OK") + + return nil +} + +func (u *S3MultipartUploadWriter) s3AbortMultipartUpload() error { + if u.multiPartUploadID != nil { + key := u.Info.GetOne().Key.String() + sse := u.ServerSideEncryption + F(u.Log.Debug, "AbortMultipartUpload(Bucket=%s, Key=%s, Sse=%v)", u.Bucket, key, sse) + + 
params := &aws_s3.AbortMultipartUploadInput{ + Bucket: &u.Bucket, + Key: &key, + UploadId: u.multiPartUploadID, + } + u.multiPartUploadID = nil + if _, err := u.S3.AbortMultipartUploadWithContext(u.Ctx, params); err != nil { + u.Log.Debug("=> ", err) + F(u.Log.Error, "failed to abort multipart upload: %s", err.Error()) + return err + } + u.Log.Debug("=> OK") + } + + return nil +} + +func (u *S3MultipartUploadWriter) s3CompleteMultipartUpload() error { + key := u.Info.GetOne().Key.String() + sse := u.ServerSideEncryption + F(u.Log.Debug, "CompleteMultipartUpload(Bucket=%s, Key=%s, Sse=%v)", u.Bucket, key, sse) + + params := &aws_s3.CompleteMultipartUploadInput{ + Bucket: &u.Bucket, + Key: &key, + UploadId: u.multiPartUploadID, + MultipartUpload: &aws_s3.CompletedMultipartUpload{Parts: u.completedParts}, + } + if _, err := u.S3.CompleteMultipartUploadWithContext(u.Ctx, params); err != nil { + u.Log.Debug("=> ", err) + F(u.Log.Error, "failed to complete multipart upload: %s", err.Error()) + return err + } + u.Log.Debug("=> OK") + return nil +} + +func (u *S3MultipartUploadWriter) s3UploadPart(part *S3PartToUpload) error { + key := u.Info.GetOne().Key.String() + sse := u.ServerSideEncryption + F(u.Log.Debug, "UploadPart(Bucket=%s, Key=%s, Sse=%v, uploadId=%s, part=%d)", u.Bucket, key, sse, *u.multiPartUploadID, part.partNumber) + + var content []byte + var err error + + content, err = part.getContent() + if err != nil { + return err + } + + params := &aws_s3.UploadPartInput{ + Bucket: &u.Bucket, + Key: &key, + Body: bytes.NewReader(content), + UploadId: u.multiPartUploadID, + SSECustomerAlgorithm: nilIfEmpty(sse.CustomerAlgorithm()), + SSECustomerKey: nilIfEmpty(sse.CustomerKey), + SSECustomerKeyMD5: nilIfEmpty(sse.CustomerKeyMD5), + PartNumber: &part.partNumber, + } + + resp, err := u.S3.UploadPartWithContext(u.Ctx, params) + + if err != nil { + u.Log.Debug("=> ", err) + F(u.Log.Error, "failed to upload part: %s", err.Error()) + return err + } + + if len(u.completedParts) < len(u.parts) { + newCompletedParts := make([]*aws_s3.CompletedPart, len(u.parts)) + copy(newCompletedParts, u.completedParts) + u.completedParts = newCompletedParts + } + completed := &aws_s3.CompletedPart{ETag: resp.ETag, PartNumber: &(part.partNumber)} + u.completedParts[part.partNumber-1] = completed + return nil +} + +// seterr is a thread-safe setter for the error object +func (u *S3MultipartUploadWriter) seterr(e error) { + u.mtx.Lock() + defer u.mtx.Unlock() + + u.err = e +} + +// S3UploadWorkers object to manage S3 upload workers +type S3UploadWorkers struct { + ctx context.Context + workers int + log DebugLogger + wg sync.WaitGroup +} + +// NewS3UploadWorkers creates new upload workers to take pending part uploads from a channel and +// upload them to S3 +func NewS3UploadWorkers(ctx context.Context, workers int, log DebugLogger) *S3UploadWorkers { + return &S3UploadWorkers{ + ctx: ctx, + workers: workers, + log: log, + } +} + +// Start starts workers +func (w *S3UploadWorkers) Start() chan<- *S3PartToUpload { + uploadChan := make(chan *S3PartToUpload) + + w.wg.Add(w.workers) + for c := 0; c < w.workers; c++ { + go func(wn int) { + defer w.wg.Done() + + F(w.log.Debug, "S3 upload worker %d waiting for upload jobs", wn) + for { + select { + case <-w.ctx.Done(): + F(w.log.Debug, "S3 upload worker %d ended", wn) + return + case part, ok := <-uploadChan: + if !ok { + F(w.log.Debug, "S3 upload worker %d => closed channel", wn) + return + } + F(w.log.Debug, "S3 upload worker %d => uploading part %d", wn, part.partNumber) 
+ w.uploadPart(part) + } + } + }(c) + } + + return uploadChan +} + +// WaitForCompletion waits until all workers finish their job +func (w *S3UploadWorkers) WaitForCompletion() { + w.wg.Wait() +} + +func (w *S3UploadWorkers) uploadPart(part *S3PartToUpload) { + part.mtx.Lock() + defer part.mtx.Unlock() + + u := part.uw + defer u.uploadGroup.Done() + + if part.state != S3PartUploadStateFull { + F(w.log.Debug, "Part to upload state invalid.") + return + } + + err := u.s3UploadPart(part) + u.UploadMemoryBufferPool.Put(part.content) + + if err != nil { + part.state = S3PartUploadErrorSending + u.seterr(err) + } else { + part.state = S3PartUploadStateSent + } +} diff --git a/multipart_upload_test.go b/multipart_upload_test.go new file mode 100644 index 0000000..ff21cfb --- /dev/null +++ b/multipart_upload_test.go @@ -0,0 +1,767 @@ +package main + +import ( + "context" + "fmt" + "io/ioutil" + "testing" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/request" + aws_s3 "github.com/aws/aws-sdk-go/service/s3" + "github.com/aws/aws-sdk-go/service/s3/s3iface" + "github.com/stretchr/testify/assert" +) + +type mockedS3 struct { + s3iface.S3API + + totalBytes int + + errorUploadPartCalls int + errorPutObjectCalls int + errorCreateMultipartUploadCalls int + errorCompleteMultipartUploadCalls int + errorAbortMultipartUploadCalls int + + partSize int + uploadPartCalls int + putObjectCalls int + createMultipartUploadCalls int + completeMultipartUploadCalls int + abortMultipartUploadCalls int +} + +var TestUploadID = "test" + +func (m *mockedS3) PutObjectWithContext(ctx aws.Context, input *aws_s3.PutObjectInput, _ ...request.Option) (*aws_s3.PutObjectOutput, error) { + m.putObjectCalls++ + if m.errorPutObjectCalls != 0 && m.errorPutObjectCalls >= m.putObjectCalls { + return nil, fmt.Errorf("Error on input test") + } + + b, err := ioutil.ReadAll(input.Body) + if err != nil { + return nil, err + } + + for i := len(b) - 1; i >= 0; i-- { + if b[i] != byte(i%10+48) { + return nil, fmt.Errorf("Found incorrect char at pos %d -> %c != %c", i, b[i], byte(i%10+49)) + } + } + m.totalBytes += len(b) + return nil, nil +} + +func (m *mockedS3) CreateMultipartUploadWithContext(_ aws.Context, _ *aws_s3.CreateMultipartUploadInput, _ ...request.Option) (*aws_s3.CreateMultipartUploadOutput, error) { + m.createMultipartUploadCalls++ + if m.errorCreateMultipartUploadCalls != 0 && m.errorCreateMultipartUploadCalls >= m.createMultipartUploadCalls { + return nil, fmt.Errorf("Error on input test") + } + + return &aws_s3.CreateMultipartUploadOutput{ + UploadId: &TestUploadID, + }, nil +} + +func (m *mockedS3) UploadPartWithContext(_ aws.Context, input *aws_s3.UploadPartInput, _ ...request.Option) (*aws_s3.UploadPartOutput, error) { + m.uploadPartCalls++ + if m.errorUploadPartCalls != 0 && m.errorUploadPartCalls >= m.uploadPartCalls { + return nil, fmt.Errorf("Error on input test") + } + + off := int((*input.PartNumber)-1) * m.partSize + b, err := ioutil.ReadAll(input.Body) + if err != nil { + return nil, err + } + + for i := len(b) - 1; i >= 0; i-- { + if b[i] != byte((off+i)%10+48) { + return nil, fmt.Errorf("Found incorrect char at part %d, pos %d -> %c != %c", *input.PartNumber, i, b[i], byte((off+i)%10+49)) + } + } + m.totalBytes += len(b) + etag := fmt.Sprintf("etagTest%d", *input.PartNumber-1) + return &aws_s3.UploadPartOutput{ + ETag: &etag, + }, nil +} + +func (m *mockedS3) CompleteMultipartUploadWithContext(_ aws.Context, input *aws_s3.CompleteMultipartUploadInput, _ ...request.Option) 
(*aws_s3.CompleteMultipartUploadOutput, error) { + m.completeMultipartUploadCalls++ + if m.errorCompleteMultipartUploadCalls != 0 && m.errorCompleteMultipartUploadCalls >= m.completeMultipartUploadCalls { + return nil, fmt.Errorf("Error on input test") + } + + if *input.UploadId != TestUploadID { + return nil, fmt.Errorf("UploadID does not match") + } + + for i, v := range input.MultipartUpload.Parts { + expectedEtag := fmt.Sprintf("etagTest%d", i) + if *v.PartNumber != int64(i+1) || *v.ETag != expectedEtag { + return nil, fmt.Errorf("etag or partnumber does not match: PartNumber(%d), ETag (%s)", *v.PartNumber, *v.ETag) + } + } + return nil, nil +} + +func (m *mockedS3) AbortMultipartUploadWithContext(_ aws.Context, _ *aws_s3.AbortMultipartUploadInput, _ ...request.Option) (*aws_s3.AbortMultipartUploadOutput, error) { + m.abortMultipartUploadCalls++ + if m.errorAbortMultipartUploadCalls != 0 && m.errorAbortMultipartUploadCalls >= m.abortMultipartUploadCalls { + return nil, fmt.Errorf("Error on input test") + } + return nil, nil +} + +type FakeLog struct{} + +func (f FakeLog) Debug(args ...interface{}) {} +func (f FakeLog) Info(args ...interface{}) {} +func (f FakeLog) Warn(args ...interface{}) {} +func (f FakeLog) Error(args ...interface{}) {} + +// Tests +func TestMultipartUploadSinglePart(t *testing.T) { + partSize := 50 + log := &FakeLog{} + w := NewS3UploadWorkers(context.Background(), 1, log) + ch := w.Start() + m := &mockedS3{ + partSize: partSize, + } + u := &S3MultipartUploadWriter{ + Ctx: context.Background(), + S3: m, + UploadMemoryBufferPool: NewMemoryBufferPool(context.Background(), partSize, 1, 5*time.Second), + RequestMethod: "read", + Log: log, + PhantomObjectMap: NewPhantomObjectMap(), + Info: &PhantomObjectInfo{Key: Path{"", "a", "b"}}, + UploadChan: ch, + MaxObjectSize: -1, + ServerSideEncryption: &ServerSideEncryptionConfig{}, + } + _, err := u.WriteAt([]byte("01234567890"), 0) + assert.NoError(t, err) + assert.NoError(t, u.Close()) + close(ch) + w.WaitForCompletion() + assert.Equal(t, 1, m.putObjectCalls) + assert.Equal(t, 0, m.uploadPartCalls) + assert.Equal(t, 0, m.createMultipartUploadCalls) + assert.Equal(t, 0, m.completeMultipartUploadCalls) + assert.Equal(t, 0, m.abortMultipartUploadCalls) + assert.Equal(t, 11, m.totalBytes) + assertPartsWithState(t, u, 0, S3PartUploadStateAdding) +} + +func TestMultipartUploadPendingPartOnSinglePut(t *testing.T) { + partSize := 20 + log := &FakeLog{} + w := NewS3UploadWorkers(context.Background(), 1, log) + ch := w.Start() + m := &mockedS3{ + partSize: partSize, + } + u := &S3MultipartUploadWriter{ + Ctx: context.Background(), + S3: m, + UploadMemoryBufferPool: NewMemoryBufferPool(context.Background(), partSize, 1, 5*time.Second), + RequestMethod: "read", + Log: log, + PhantomObjectMap: NewPhantomObjectMap(), + Info: &PhantomObjectInfo{Key: Path{"", "a", "b"}}, + UploadChan: ch, + MaxObjectSize: -1, + ServerSideEncryption: &ServerSideEncryptionConfig{}, + } + _, err := u.WriteAt([]byte("0123456789"), 7) + assert.NoError(t, err) + assert.Error(t, u.Close()) + close(ch) + w.WaitForCompletion() + assert.Equal(t, 0, m.putObjectCalls) + assert.Equal(t, 0, m.uploadPartCalls) + assert.Equal(t, 0, m.createMultipartUploadCalls) + assert.Equal(t, 0, m.completeMultipartUploadCalls) + assert.Equal(t, 0, m.abortMultipartUploadCalls) + assert.Equal(t, 0, m.totalBytes) + assertPartsWithState(t, u, 0, S3PartUploadStateAdding) +} + +func TestMultipartUploadErrorPutObject(t *testing.T) { + partSize := 20 + log := &FakeLog{} + w := 
NewS3UploadWorkers(context.Background(), 1, log) + ch := w.Start() + m := &mockedS3{ + partSize: partSize, + errorPutObjectCalls: 1, + } + u := &S3MultipartUploadWriter{ + Ctx: context.Background(), + S3: m, + UploadMemoryBufferPool: NewMemoryBufferPool(context.Background(), partSize, 1, 5*time.Second), + RequestMethod: "read", + Log: log, + PhantomObjectMap: NewPhantomObjectMap(), + Info: &PhantomObjectInfo{Key: Path{"", "a", "b"}}, + UploadChan: ch, + MaxObjectSize: -1, + ServerSideEncryption: &ServerSideEncryptionConfig{}, + } + _, err := u.WriteAt([]byte("0123456789"), 0) + assert.NoError(t, err) + assert.Error(t, u.Close()) + close(ch) + w.WaitForCompletion() + assert.Equal(t, 1, m.putObjectCalls) + assert.Equal(t, 0, m.uploadPartCalls) + assert.Equal(t, 0, m.createMultipartUploadCalls) + assert.Equal(t, 0, m.completeMultipartUploadCalls) + assert.Equal(t, 0, m.abortMultipartUploadCalls) + assert.Equal(t, 0, m.totalBytes) + assertPartsWithState(t, u, 0, S3PartUploadStateAdding) +} + +func TestMultipartUploadSinglePartFullUsesMultipart(t *testing.T) { + partSize := 10 + log := &FakeLog{} + w := NewS3UploadWorkers(context.Background(), 1, log) + ch := w.Start() + m := &mockedS3{ + partSize: partSize, + } + u := &S3MultipartUploadWriter{ + Ctx: context.Background(), + S3: m, + UploadMemoryBufferPool: NewMemoryBufferPool(context.Background(), partSize, 1, 5*time.Second), + RequestMethod: "read", + Log: log, + PhantomObjectMap: NewPhantomObjectMap(), + Info: &PhantomObjectInfo{Key: Path{"", "a", "b"}}, + UploadChan: ch, + MaxObjectSize: -1, + ServerSideEncryption: &ServerSideEncryptionConfig{}, + } + _, err := u.WriteAt([]byte("0123456789"), 0) + assert.NoError(t, err) + assert.NoError(t, u.Close()) + close(ch) + w.WaitForCompletion() + assert.Equal(t, 0, m.putObjectCalls) + assert.Equal(t, 1, m.uploadPartCalls) + assert.Equal(t, 1, m.createMultipartUploadCalls) + assert.Equal(t, 1, m.completeMultipartUploadCalls) + assert.Equal(t, 0, m.abortMultipartUploadCalls) + assert.Equal(t, 10, m.totalBytes) + assertPartsWithState(t, u, 0, S3PartUploadStateAdding) +} + +func TestMultipartUploadFillingPart(t *testing.T) { + partSize := 10 + log := &FakeLog{} + w := NewS3UploadWorkers(context.Background(), 1, log) + ch := w.Start() + m := &mockedS3{ + partSize: partSize, + } + u := &S3MultipartUploadWriter{ + Ctx: context.Background(), + S3: m, + UploadMemoryBufferPool: NewMemoryBufferPool(context.Background(), partSize, 10, 5*time.Second), + RequestMethod: "read", + Log: log, + PhantomObjectMap: NewPhantomObjectMap(), + Info: &PhantomObjectInfo{Key: Path{"", "a", "b"}}, + UploadChan: ch, + MaxObjectSize: -1, + ServerSideEncryption: &ServerSideEncryptionConfig{}, + } + _, err := u.WriteAt([]byte("01234"), 0) + assert.NoError(t, err) + _, err = u.WriteAt([]byte("56789"), 5) + assert.NoError(t, err) + _, err = u.WriteAt([]byte("0123456789"), 10) + assert.NoError(t, err) + assert.NoError(t, u.Close()) + close(ch) + w.WaitForCompletion() + assert.Equal(t, 0, m.putObjectCalls) + assert.Equal(t, 2, m.uploadPartCalls) + assert.Equal(t, 1, m.createMultipartUploadCalls) + assert.Equal(t, 1, m.completeMultipartUploadCalls) + assert.Equal(t, 0, m.abortMultipartUploadCalls) + assert.Equal(t, 20, m.totalBytes) + assertPartsWithState(t, u, 0, S3PartUploadStateAdding) +} + +func TestMultipartUploadMultiPartNotLockings(t *testing.T) { + partSize := 10 + log := &FakeLog{} + w := NewS3UploadWorkers(context.Background(), 1, log) + ch := w.Start() + m := &mockedS3{ + partSize: partSize, + } + u := 
&S3MultipartUploadWriter{ + Ctx: context.Background(), + S3: m, + UploadMemoryBufferPool: NewMemoryBufferPool(context.Background(), partSize, 10, 5*time.Second), + RequestMethod: "read", + Log: log, + PhantomObjectMap: NewPhantomObjectMap(), + Info: &PhantomObjectInfo{Key: Path{"", "a", "b"}}, + UploadChan: ch, + MaxObjectSize: -1, + ServerSideEncryption: &ServerSideEncryptionConfig{}, + } + _, err := u.WriteAt([]byte("0123456789"), 0) + assert.NoError(t, err) + _, err = u.WriteAt([]byte("0123456789"), 0) + assert.NoError(t, err) + assert.NoError(t, u.Close()) + close(ch) + w.WaitForCompletion() + assert.Equal(t, 0, m.putObjectCalls) + assert.Equal(t, 1, m.uploadPartCalls) + assert.Equal(t, 1, m.createMultipartUploadCalls) + assert.Equal(t, 1, m.completeMultipartUploadCalls) + assert.Equal(t, 0, m.abortMultipartUploadCalls) + assert.Equal(t, 10, m.totalBytes) + assertPartsWithState(t, u, 0, S3PartUploadStateAdding) +} + +func TestMultipartUploadMultiplePartSingleWriteAt(t *testing.T) { + partSize := 15 + log := &FakeLog{} + w := NewS3UploadWorkers(context.Background(), 1, log) + ch := w.Start() + m := &mockedS3{ + partSize: partSize, + } + u := &S3MultipartUploadWriter{ + Ctx: context.Background(), + S3: m, + UploadMemoryBufferPool: NewMemoryBufferPool(context.Background(), partSize, 1, 5*time.Second), + RequestMethod: "read", + Log: log, + PhantomObjectMap: NewPhantomObjectMap(), + Info: &PhantomObjectInfo{Key: Path{"", "a", "b"}}, + UploadChan: ch, + MaxObjectSize: -1, + ServerSideEncryption: &ServerSideEncryptionConfig{}, + } + _, err := u.WriteAt([]byte("01234567890123456789012345"), 0) + assert.NoError(t, err) + assert.NoError(t, u.Close()) + close(ch) + w.WaitForCompletion() + assert.Equal(t, 0, m.putObjectCalls) + assert.Equal(t, 2, m.uploadPartCalls) + assert.Equal(t, 1, m.createMultipartUploadCalls) + assert.Equal(t, 1, m.completeMultipartUploadCalls) + assert.Equal(t, 0, m.abortMultipartUploadCalls) + assert.Equal(t, 26, m.totalBytes) + assertPartsWithState(t, u, 0, S3PartUploadStateAdding) +} + +func TestMultipartUploadMultiplePartMultipleWriteAt(t *testing.T) { + partSize := 15 + log := &FakeLog{} + w := NewS3UploadWorkers(context.Background(), 1, log) + ch := w.Start() + m := &mockedS3{ + partSize: partSize, + } + u := &S3MultipartUploadWriter{ + Ctx: context.Background(), + S3: m, + UploadMemoryBufferPool: NewMemoryBufferPool(context.Background(), partSize, 1, 5*time.Second), + RequestMethod: "read", + Log: log, + PhantomObjectMap: NewPhantomObjectMap(), + Info: &PhantomObjectInfo{Key: Path{"", "a", "b"}}, + UploadChan: ch, + MaxObjectSize: -1, + ServerSideEncryption: &ServerSideEncryptionConfig{}, + } + for c := 0; c < 2; c++ { + _, err := u.WriteAt([]byte("0123456789"), 10*int64(c)) + assert.NoError(t, err) + } + assert.NoError(t, u.Close()) + close(ch) + w.WaitForCompletion() + assert.Equal(t, 0, m.putObjectCalls) + assert.Equal(t, 2, m.uploadPartCalls) + assert.Equal(t, 1, m.createMultipartUploadCalls) + assert.Equal(t, 1, m.completeMultipartUploadCalls) + assert.Equal(t, 0, m.abortMultipartUploadCalls) + assertPartsWithState(t, u, 0, S3PartUploadStateAdding) +} + +func TestMultipartUploadMultiplePartIgnoredOverlapping(t *testing.T) { + partSize := 15 + log := &FakeLog{} + w := NewS3UploadWorkers(context.Background(), 1, log) + ch := w.Start() + m := &mockedS3{ + partSize: partSize, + } + u := &S3MultipartUploadWriter{ + Ctx: context.Background(), + S3: m, + UploadMemoryBufferPool: NewMemoryBufferPool(context.Background(), partSize, 1, 5*time.Second), + RequestMethod: 
"read", + Log: log, + PhantomObjectMap: NewPhantomObjectMap(), + Info: &PhantomObjectInfo{Key: Path{"", "a", "b"}}, + UploadChan: ch, + MaxObjectSize: -1, + ServerSideEncryption: &ServerSideEncryptionConfig{}, + } + for c := 0; c < 2; c++ { + _, err := u.WriteAt([]byte("0123456789"), 0) + assert.NoError(t, err) + } + assert.NoError(t, u.Close()) + close(ch) + w.WaitForCompletion() + assert.Equal(t, 1, m.putObjectCalls) + assert.Equal(t, 0, m.uploadPartCalls) + assert.Equal(t, 0, m.createMultipartUploadCalls) + assert.Equal(t, 0, m.completeMultipartUploadCalls) + assert.Equal(t, 0, m.abortMultipartUploadCalls) + assert.Equal(t, 10, m.totalBytes) + assertPartsWithState(t, u, 0, S3PartUploadStateAdding) +} + +func TestMultipartUploadMaxObjectSizeErrorSingleWrite(t *testing.T) { + partSize := 15 + log := &FakeLog{} + w := NewS3UploadWorkers(context.Background(), 1, log) + ch := w.Start() + m := &mockedS3{ + partSize: partSize, + } + u := &S3MultipartUploadWriter{ + Ctx: context.Background(), + S3: m, + UploadMemoryBufferPool: NewMemoryBufferPool(context.Background(), partSize, 1, 5*time.Second), + RequestMethod: "read", + Log: log, + PhantomObjectMap: NewPhantomObjectMap(), + Info: &PhantomObjectInfo{Key: Path{"", "a", "b"}}, + UploadChan: ch, + MaxObjectSize: 3, + ServerSideEncryption: &ServerSideEncryptionConfig{}, + } + _, err := u.WriteAt([]byte("0123456789"), 0) + assert.Error(t, err) + assert.Equal(t, u.Close(), err) + close(ch) + w.WaitForCompletion() + assert.Equal(t, 0, m.putObjectCalls) + assert.Equal(t, 0, m.uploadPartCalls) + assert.Equal(t, 0, m.createMultipartUploadCalls) + assert.Equal(t, 0, m.completeMultipartUploadCalls) + assert.Equal(t, 0, m.abortMultipartUploadCalls) + assert.Equal(t, 0, m.totalBytes) + assertPartsWithState(t, u, 0, S3PartUploadStateAdding) +} + +func TestMultipartUploadMaxObjectSizeErrorSeveralWrites(t *testing.T) { + partSize := 7 + log := &FakeLog{} + w := NewS3UploadWorkers(context.Background(), 1, log) + ch := w.Start() + m := &mockedS3{ + partSize: partSize, + } + u := &S3MultipartUploadWriter{ + Ctx: context.Background(), + S3: m, + UploadMemoryBufferPool: NewMemoryBufferPool(context.Background(), partSize, 1, 5*time.Second), + RequestMethod: "read", + Log: log, + PhantomObjectMap: NewPhantomObjectMap(), + Info: &PhantomObjectInfo{Key: Path{"", "a", "b"}}, + UploadChan: ch, + MaxObjectSize: 12, + ServerSideEncryption: &ServerSideEncryptionConfig{}, + } + _, err := u.WriteAt([]byte("0123456789"), 0) + assert.NoError(t, err) + _, err = u.WriteAt([]byte("0123456789"), 10) + assert.Error(t, err) + assert.Equal(t, u.Close(), err) + close(ch) + w.WaitForCompletion() + assert.Equal(t, 0, m.putObjectCalls) + assert.Equal(t, 1, m.uploadPartCalls) + assert.Equal(t, 1, m.createMultipartUploadCalls) + assert.Equal(t, 0, m.completeMultipartUploadCalls) + assert.Equal(t, 1, m.abortMultipartUploadCalls) + assert.Equal(t, 7, m.totalBytes) + assertPartsWithState(t, u, 0, S3PartUploadStateAdding) +} + +func TestMultipartUploadPendingParts(t *testing.T) { + partSize := 10 + log := &FakeLog{} + w := NewS3UploadWorkers(context.Background(), 1, log) + ch := w.Start() + m := &mockedS3{ + partSize: partSize, + } + u := &S3MultipartUploadWriter{ + Ctx: context.Background(), + S3: m, + UploadMemoryBufferPool: NewMemoryBufferPool(context.Background(), partSize, 5, 5*time.Second), + RequestMethod: "read", + Log: log, + PhantomObjectMap: NewPhantomObjectMap(), + Info: &PhantomObjectInfo{Key: Path{"", "a", "b"}}, + UploadChan: ch, + MaxObjectSize: -1, + ServerSideEncryption: 
&ServerSideEncryptionConfig{}, + } + _, err := u.WriteAt([]byte("0123456789"), 7) + assert.NoError(t, err) + assert.Error(t, u.Close()) + close(ch) + w.WaitForCompletion() + assert.Equal(t, 0, m.putObjectCalls) + assert.Equal(t, 1, m.uploadPartCalls) + assert.Equal(t, 1, m.createMultipartUploadCalls) + assert.Equal(t, 0, m.completeMultipartUploadCalls) + assert.Equal(t, 1, m.abortMultipartUploadCalls) + assert.Equal(t, 0, m.totalBytes) + assertPartsWithState(t, u, 0, S3PartUploadStateAdding) +} + +func TestMultipartUploadErrorCreatingMultipartUploadOnClose(t *testing.T) { + partSize := 10 + log := &FakeLog{} + w := NewS3UploadWorkers(context.Background(), 1, log) + ch := w.Start() + m := &mockedS3{ + partSize: partSize, + errorCreateMultipartUploadCalls: 1, + } + u := &S3MultipartUploadWriter{ + Ctx: context.Background(), + S3: m, + UploadMemoryBufferPool: NewMemoryBufferPool(context.Background(), partSize, 5, 5*time.Second), + RequestMethod: "read", + Log: log, + PhantomObjectMap: NewPhantomObjectMap(), + Info: &PhantomObjectInfo{Key: Path{"", "a", "b"}}, + UploadChan: ch, + MaxObjectSize: -1, + ServerSideEncryption: &ServerSideEncryptionConfig{}, + } + _, err := u.WriteAt([]byte("0123456789"), 7) + assert.NoError(t, err) + assert.Error(t, u.Close()) + close(ch) + w.WaitForCompletion() + assert.Equal(t, 0, m.putObjectCalls) + assert.Equal(t, 0, m.uploadPartCalls) + assert.Equal(t, 1, m.createMultipartUploadCalls) + assert.Equal(t, 0, m.completeMultipartUploadCalls) + assert.Equal(t, 0, m.abortMultipartUploadCalls) + assert.Equal(t, 0, m.totalBytes) + assertPartsWithState(t, u, 0, S3PartUploadStateAdding) +} + +func TestMultipartUploadErrorCreatingMultipartUploadOnWrite(t *testing.T) { + partSize := 10 + log := &FakeLog{} + w := NewS3UploadWorkers(context.Background(), 1, log) + ch := w.Start() + m := &mockedS3{ + partSize: partSize, + errorCreateMultipartUploadCalls: 1, + } + u := &S3MultipartUploadWriter{ + Ctx: context.Background(), + S3: m, + UploadMemoryBufferPool: NewMemoryBufferPool(context.Background(), partSize, 5, 5*time.Second), + RequestMethod: "read", + Log: log, + PhantomObjectMap: NewPhantomObjectMap(), + Info: &PhantomObjectInfo{Key: Path{"", "a", "b"}}, + UploadChan: ch, + MaxObjectSize: -1, + ServerSideEncryption: &ServerSideEncryptionConfig{}, + } + _, err := u.WriteAt([]byte("0123456789012"), 0) + assert.Error(t, err) + close(ch) + w.WaitForCompletion() + assert.Equal(t, 0, m.putObjectCalls) + assert.Equal(t, 0, m.uploadPartCalls) + assert.Equal(t, 1, m.createMultipartUploadCalls) + assert.Equal(t, 0, m.completeMultipartUploadCalls) + assert.Equal(t, 0, m.abortMultipartUploadCalls) + assert.Equal(t, 0, m.totalBytes) + assertPartsWithState(t, u, 0, S3PartUploadStateAdding) +} + +func TestMultipartUploadErrorUploadingPartDetectedOnNextWrite(t *testing.T) { + partSize := 10 + log := &FakeLog{} + w := NewS3UploadWorkers(context.Background(), 1, log) + ch := w.Start() + m := &mockedS3{ + partSize: partSize, + errorUploadPartCalls: 1, + } + u := &S3MultipartUploadWriter{ + Ctx: context.Background(), + S3: m, + UploadMemoryBufferPool: NewMemoryBufferPool(context.Background(), partSize, 5, 5*time.Second), + RequestMethod: "read", + Log: log, + PhantomObjectMap: NewPhantomObjectMap(), + Info: &PhantomObjectInfo{Key: Path{"", "a", "b"}}, + UploadChan: ch, + MaxObjectSize: -1, + ServerSideEncryption: &ServerSideEncryptionConfig{}, + } + _, err := u.WriteAt([]byte("012345678901234"), 0) + assert.NoError(t, err) + close(ch) + w.WaitForCompletion() + _, err = u.WriteAt([]byte("01"), 
14)
+	assert.Error(t, err)
+	assert.Equal(t, 0, m.putObjectCalls)
+	assert.Equal(t, 1, m.uploadPartCalls)
+	assert.Equal(t, 1, m.createMultipartUploadCalls)
+	assert.Equal(t, 0, m.completeMultipartUploadCalls)
+	assert.Equal(t, 1, m.abortMultipartUploadCalls)
+	assert.Equal(t, 0, m.totalBytes)
+	assertPartsWithState(t, u, 0, S3PartUploadStateAdding)
+}
+
+func TestMultipartUploadPartPending(t *testing.T) {
+	partSize := 10
+	log := &FakeLog{}
+	w := NewS3UploadWorkers(context.Background(), 1, log)
+	ch := w.Start()
+	m := &mockedS3{
+		partSize:             partSize,
+		errorUploadPartCalls: 1,
+	}
+	u := &S3MultipartUploadWriter{
+		Ctx:                    context.Background(),
+		S3:                     m,
+		UploadMemoryBufferPool: NewMemoryBufferPool(context.Background(), partSize, 5, 5*time.Second),
+		RequestMethod:          "read",
+		Log:                    log,
+		PhantomObjectMap:       NewPhantomObjectMap(),
+		Info:                   &PhantomObjectInfo{Key: Path{"", "a", "b"}},
+		UploadChan:             ch,
+		MaxObjectSize:          -1,
+		ServerSideEncryption:   &ServerSideEncryptionConfig{},
+	}
+	_, err := u.WriteAt([]byte("012345678901"), 0)
+	assert.NoError(t, err)
+	_, err = u.WriteAt([]byte("01"), 16)
+	assert.NoError(t, err)
+	assert.Error(t, u.Close())
+	close(ch)
+	w.WaitForCompletion()
+	assert.Equal(t, 0, m.putObjectCalls)
+	assert.Equal(t, 1, m.uploadPartCalls)
+	assert.Equal(t, 1, m.createMultipartUploadCalls)
+	assert.Equal(t, 0, m.completeMultipartUploadCalls)
+	assert.Equal(t, 1, m.abortMultipartUploadCalls)
+	assert.Equal(t, 0, m.totalBytes)
+	assertPartsWithState(t, u, 0, S3PartUploadStateAdding)
+}
+
+func TestMultipartUploadPoolFull(t *testing.T) {
+	partSize := 10
+	log := &FakeLog{}
+	w := NewS3UploadWorkers(context.Background(), 1, log)
+	ch := w.Start()
+	m := &mockedS3{
+		partSize:             partSize,
+		errorUploadPartCalls: 1,
+	}
+	u := &S3MultipartUploadWriter{
+		Ctx:                    context.Background(),
+		S3:                     m,
+		UploadMemoryBufferPool: NewMemoryBufferPool(context.Background(), partSize, 1, 100*time.Millisecond),
+		RequestMethod:          "read",
+		Log:                    log,
+		PhantomObjectMap:       NewPhantomObjectMap(),
+		Info:                   &PhantomObjectInfo{Key: Path{"", "a", "b"}},
+		UploadChan:             ch,
+		MaxObjectSize:          -1,
+		ServerSideEncryption:   &ServerSideEncryptionConfig{},
+	}
+	_, err := u.WriteAt([]byte("012345678901"), 7)
+	assert.Error(t, err)
+	close(ch)
+	w.WaitForCompletion()
+	assert.Equal(t, 0, m.putObjectCalls)
+	assert.Equal(t, 0, m.uploadPartCalls)
+	assert.Equal(t, 0, m.createMultipartUploadCalls)
+	assert.Equal(t, 0, m.completeMultipartUploadCalls)
+	assert.Equal(t, 0, m.abortMultipartUploadCalls)
+	assert.Equal(t, 0, m.totalBytes)
+	assertPartsWithState(t, u, 0, S3PartUploadStateAdding)
+}
+
+func TestMultipartUploadPoolFullTimeout(t *testing.T) {
+	partSize := 10
+	log := &FakeLog{}
+	w := NewS3UploadWorkers(context.Background(), 1, log)
+	ch := w.Start()
+	m := &mockedS3{
+		partSize:             partSize,
+		errorUploadPartCalls: 1,
+	}
+	u := &S3MultipartUploadWriter{
+		Ctx:                    context.Background(),
+		S3:                     m,
+		UploadMemoryBufferPool: NewMemoryBufferPool(context.Background(), partSize, 1, 100*time.Millisecond),
+		RequestMethod:          "read",
+		Log:                    log,
+		PhantomObjectMap:       NewPhantomObjectMap(),
+		Info:                   &PhantomObjectInfo{Key: Path{"", "a", "b"}},
+		UploadChan:             ch,
+		MaxObjectSize:          -1,
+		ServerSideEncryption:   &ServerSideEncryptionConfig{},
+	}
+	_, err := u.WriteAt([]byte("012345678901"), 7)
+	assert.Error(t, err)
+	close(ch)
+	w.WaitForCompletion()
+	assert.Equal(t, 0, m.putObjectCalls)
+	assert.Equal(t, 0, m.uploadPartCalls)
+	assert.Equal(t, 0, m.createMultipartUploadCalls)
+	assert.Equal(t, 0, m.completeMultipartUploadCalls)
+	assert.Equal(t, 0, 
m.abortMultipartUploadCalls) + assert.Equal(t, 0, m.totalBytes) + assertPartsWithState(t, u, 0, S3PartUploadStateAdding) +} + +// Helpers +func assertPartsWithState(t *testing.T, u *S3MultipartUploadWriter, expected int, state S3PartUploadState) { + res := 0 + for _, part := range u.parts { + if part != nil && part.state == state { + res++ + } + } + assert.Equal(t, expected, res, "Found %d parts with state %d, and expected %d", res, state, expected) +} diff --git a/phantom_object_map.go b/phantom_object_map.go index 68ca606..1b501ce 100644 --- a/phantom_object_map.go +++ b/phantom_object_map.go @@ -9,7 +9,6 @@ type PhantomObjectInfo struct { Key Path LastModified time.Time Size int64 - Opaque interface{} Mtx sync.Mutex } @@ -31,10 +30,12 @@ func (info *PhantomObjectInfo) SetLastModified(v time.Time) { info.LastModified = v } -func (info *PhantomObjectInfo) SetSize(v int64) { +func (info *PhantomObjectInfo) SetSizeIfGreater(v int64) { info.Mtx.Lock() defer info.Mtx.Unlock() - info.Size = v + if v > info.Size { + info.Size = v + } } type phantomObjectInfoMap map[string]*PhantomObjectInfo diff --git a/server.go b/server.go index 0052f01..68f5658 100644 --- a/server.go +++ b/server.go @@ -12,19 +12,40 @@ import ( "golang.org/x/crypto/ssh" ) +type ServerLogger interface { + DebugLogger + InfoLogger + WarnLogger + ErrorLogger +} + type Server struct { *ssh.ServerConfig *S3Buckets *PhantomObjectMap + UploadMemoryBufferPool *MemoryBufferPool ReaderLookbackBufferSize int ReaderMinChunkSize int ListerLookbackBufferSize int - Log interface { - DebugLogger - InfoLogger - ErrorLogger + Log ServerLogger + Now func() time.Time + UploadChan chan<- *S3PartToUpload +} + +// NewServer creates a new sftp server +func NewServer(ctx context.Context, buckets *S3Buckets, serverConfig *ssh.ServerConfig, logger ServerLogger, readerLookbackBufferSize int, readerMinChunkSize int, listerLookbackBufferSize int, partSize int, uploadMemoryBufferPoolSize int, uploadMemoryBufferPoolTimeout time.Duration, uploadChan chan<- *S3PartToUpload) *Server { + return &Server{ + S3Buckets: buckets, + ServerConfig: serverConfig, + Log: logger, + ReaderLookbackBufferSize: readerLookbackBufferSize, + ReaderMinChunkSize: readerMinChunkSize, + ListerLookbackBufferSize: listerLookbackBufferSize, + UploadMemoryBufferPool: NewMemoryBufferPool(ctx, partSize, uploadMemoryBufferPoolSize, uploadMemoryBufferPoolTimeout), + PhantomObjectMap: NewPhantomObjectMap(), + Now: time.Now, + UploadChan: uploadChan, } - Now func() time.Time } func asHandlers(handlers interface { @@ -36,7 +57,7 @@ func asHandlers(handlers interface { return sftp.Handlers{handlers, handlers, handlers, handlers} } -func (s *Server) HandleChannel(ctx context.Context, bucket *S3Bucket, sshCh ssh.Channel, reqs <-chan *ssh.Request) { +func (s *Server) HandleChannel(ctx context.Context, bucket *S3Bucket, sshCh ssh.Channel, reqs <-chan *ssh.Request, userInfo *UserInfo) { defer s.Log.Debug("HandleChannel ended") server := sftp.NewRequestServer( sshCh, @@ -47,11 +68,14 @@ func (s *Server) HandleChannel(ctx context.Context, bucket *S3Bucket, sshCh ssh. 
ReaderLookbackBufferSize: s.ReaderLookbackBufferSize, ReaderMinChunkSize: s.ReaderMinChunkSize, ListerLookbackBufferSize: s.ListerLookbackBufferSize, + UploadMemoryBufferPool: s.UploadMemoryBufferPool, Log: s.Log, PhantomObjectMap: s.PhantomObjectMap, Perms: bucket.Perms, ServerSideEncryption: &bucket.ServerSideEncryption, Now: s.Now, + UserInfo: userInfo, + UploadChan: s.UploadChan, }, ), ) @@ -103,9 +127,11 @@ func (s *Server) HandleChannel(ctx context.Context, bucket *S3Bucket, sshCh ssh. func (s *Server) HandleClient(ctx context.Context, conn *net.TCPConn) error { defer s.Log.Debug("HandleClient ended") defer func() { + mUsersConnected.Dec() F(s.Log.Info, "connection from client %s closed", conn.RemoteAddr().String()) conn.Close() }() + F(s.Log.Info, "connected from client %s", conn.RemoteAddr().String()) innerCtx, cancel := context.WithCancel(ctx) @@ -122,7 +148,13 @@ func (s *Server) HandleClient(ctx context.Context, conn *net.TCPConn) error { return err } + userInfo := &UserInfo{ + Addr: conn.RemoteAddr(), + User: sconn.User(), + } + F(s.Log.Info, "user %s logged in", sconn.User()) + mUsersConnected.Inc() bucket, ok := s.UserToBucketMap[sconn.User()] if !ok { return fmt.Errorf("unknown error: no bucket designated to user %s found", sconn.User()) @@ -160,7 +192,7 @@ func (s *Server) HandleClient(ctx context.Context, conn *net.TCPConn) error { wg.Add(1) go func() { defer wg.Done() - s.HandleChannel(innerCtx, bucket, sshCh, reqs) + s.HandleChannel(innerCtx, bucket, sshCh, reqs, userInfo) }() } }(chans) @@ -216,7 +248,7 @@ outer: } // drain - for _ = range connChan { + for range connChan { } wg.Wait() diff --git a/user.go b/user.go index b7db187..d2876ae 100644 --- a/user.go +++ b/user.go @@ -2,9 +2,11 @@ package main import ( "fmt" + "io/ioutil" + "net" + "github.com/pkg/errors" "golang.org/x/crypto/ssh" - "io/ioutil" ) type User struct { @@ -19,6 +21,15 @@ type UserStore struct { usersMap map[string]*User } +type UserInfo struct { + Addr net.Addr + User string +} + +func (ui *UserInfo) String() string { + return fmt.Sprintf("%s from %s", ui.User, ui.Addr.String()) +} + type UserStores map[string]UserStore func (us *UserStore) Add(u *User) { diff --git a/util/offset_ranges.go b/util/offset_ranges.go new file mode 100644 index 0000000..ca214a2 --- /dev/null +++ b/util/offset_ranges.go @@ -0,0 +1,105 @@ +package util + +import ( + "container/list" + "fmt" +) + +type offsetRange struct { + // Start offset + s int64 + + // End offset + e int64 +} + +// OffsetRanges contains used offset ranges. +// A sorted double linked list is used internally to add ranges. 
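+// Overlapping ranges, and ranges that share an endpoint, are merged into a
+// single range when they are added.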
+type OffsetRanges struct { + maxOffset int64 + l *list.List +} + +// NewOffsetRanges creates a new offset ranges +func NewOffsetRanges(maxOffset int64) *OffsetRanges { + return &OffsetRanges{ + maxOffset: maxOffset, + l: list.New(), + } +} + +// Add adds a range +func (o *OffsetRanges) Add(start int64, end int64) error { + if end > o.maxOffset { + return fmt.Errorf("End range is higher than maximum offset") + } + + for e := o.l.Front(); e != nil; e = e.Next() { + r := e.Value.(*offsetRange) + if start >= r.s && start <= r.e { + if end >= r.s && end <= r.e { + return nil + } + r.e = o.cleanInnerRanges(e, end) + return nil + } else if end >= r.s && end <= r.e { + r.s = start + return nil + } else if start < r.s && end > r.e { + r.s = start + r.e = o.cleanInnerRanges(e, end) + return nil + } else if end < r.s { + o.l.InsertBefore(&offsetRange{s: start, e: end}, e) + return nil + } + } + o.l.PushBack(&offsetRange{s: start, e: end}) + return nil +} + +// MustAdd adds a range and panics if it returns an error. If not, it returns the initial object +func (o *OffsetRanges) MustAdd(start int64, end int64) *OffsetRanges { + if err := o.Add(start, end); err != nil { + panic("Add returned error") + } + return o +} + +// IsFull checks if offset ranges is already full +func (o *OffsetRanges) IsFull() bool { + if o.l.Len() == 1 { + r := o.l.Front().Value.(*offsetRange) + return r.s == 0 && r.e == o.maxOffset + } + return false +} + +// GetMaxValidOffset checks if there is only one offset range that +// starts by 0 and returns the end of this offset. It returns -1 in +// other cases +func (o *OffsetRanges) GetMaxValidOffset() int64 { + if o.l.Len() == 1 { + r := o.l.Front().Value.(*offsetRange) + if r.s == 0 { + return r.e + } + } + return -1 +} + +func (o *OffsetRanges) cleanInnerRanges(e *list.Element, end int64) int64 { + for e = e.Next(); e != nil; { + rn := e.Value.(*offsetRange) + if end < rn.s { + return end + } + del := e + e = e.Next() + o.l.Remove(del) + if end <= rn.e { + return rn.e + } + } + return end +} diff --git a/util/offset_ranges_test.go b/util/offset_ranges_test.go new file mode 100644 index 0000000..463c321 --- /dev/null +++ b/util/offset_ranges_test.go @@ -0,0 +1,64 @@ +package util + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestOffsetRangesIsFullDirectly(t *testing.T) { + o := NewOffsetRanges(10) + assert.Equal(t, false, o.IsFull()) + o.Add(0, 10) + assert.Equal(t, true, o.IsFull()) +} + +func TestOffsetRangesAddMergeSameRange(t *testing.T) { + assertOnlyContainsRange(t, NewOffsetRanges(10).MustAdd(0, 5).MustAdd(0, 3), 0, 5) + assertOnlyContainsRange(t, NewOffsetRanges(10).MustAdd(0, 5).MustAdd(2, 5), 0, 5) + assertOnlyContainsRange(t, NewOffsetRanges(10).MustAdd(2, 5).MustAdd(0, 3), 0, 5) + assertOnlyContainsRange(t, NewOffsetRanges(10).MustAdd(0, 5).MustAdd(5, 10), 0, 10) + assertOnlyContainsRange(t, NewOffsetRanges(10).MustAdd(2, 5).MustAdd(0, 10), 0, 10) +} + +func TestOffsetRangesAddMergeSeveralRanges(t *testing.T) { + assertOnlyContainsRange(t, NewOffsetRanges(10).MustAdd(7, 10).MustAdd(0, 3).MustAdd(2, 8), 0, 10) + assertOnlyContainsRange(t, NewOffsetRanges(100).MustAdd(7, 10).MustAdd(0, 3).MustAdd(2, 12), 0, 12) + assertOnlyContainsRange(t, NewOffsetRanges(100).MustAdd(7, 10).MustAdd(2, 5).MustAdd(0, 8), 0, 10) + assertOnlyContainsRange(t, NewOffsetRanges(100).MustAdd(7, 10).MustAdd(2, 5).MustAdd(0, 12), 0, 12) + assertOnlyContainsRange(t, NewOffsetRanges(10).MustAdd(7, 8).MustAdd(0, 3).MustAdd(3, 7), 0, 8) +} + +func 
TestOffsetRangesAddNewRange(t *testing.T) { + assertRanges(t, NewOffsetRanges(10).MustAdd(7, 8).MustAdd(0, 5), []*offsetRange{&offsetRange{0, 5}, &offsetRange{7, 8}}) + assertRanges(t, NewOffsetRanges(10).MustAdd(0, 5).MustAdd(7, 8), []*offsetRange{&offsetRange{0, 5}, &offsetRange{7, 8}}) + assertRanges(t, NewOffsetRanges(10).MustAdd(2, 3).MustAdd(5, 6).MustAdd(8, 9).MustAdd(0, 7), []*offsetRange{&offsetRange{0, 7}, &offsetRange{8, 9}}) + assertRanges(t, NewOffsetRanges(10).MustAdd(0, 3).MustAdd(5, 6).MustAdd(8, 9).MustAdd(1, 7), []*offsetRange{&offsetRange{0, 7}, &offsetRange{8, 9}}) +} + +func TestOffsetRangesGetMaxValidOffset(t *testing.T) { + assert.Equal(t, NewOffsetRanges(10).MustAdd(3, 8).GetMaxValidOffset(), int64(-1)) + assert.Equal(t, NewOffsetRanges(10).MustAdd(3, 8).MustAdd(0, 2).GetMaxValidOffset(), int64(-1)) + assert.Equal(t, NewOffsetRanges(10).MustAdd(0, 3).GetMaxValidOffset(), int64(3)) +} + +func assertOnlyContainsRange(t *testing.T, o *OffsetRanges, s int64, e int64) { + assert.Equal(t, 1, o.l.Len(), "Expected only 1 range") + assertRange(t, o.l.Front().Value, s, e) +} + +func assertRange(t *testing.T, elem interface{}, s int64, e int64) { + r := elem.(*offsetRange) + if r.s != s || r.e != e { + assert.Failf(t, "Range mismatches", "Expected [%d,%d] and found [%d,%d]", s, e, r.s, r.e) + } +} + +func assertRanges(t *testing.T, o *OffsetRanges, ranges []*offsetRange) { + assert.Equal(t, len(ranges), o.l.Len(), "Ranges length mismatch") + i := 0 + for e := o.l.Front(); e != nil && i < len(ranges); e = e.Next() { + assertRange(t, e.Value, ranges[i].s, ranges[i].e) + i++ + } +}